Update pytorch_lightning.py

This commit is contained in:
Phúc H. Lê Khắc 2023-08-02 16:43:51 +02:00 committed by GitHub
parent f85dc71f46
commit 0eae807fe8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -133,7 +133,7 @@ class GpuStatsLogger(Callback): # pylint: disable=too-many-instance-attributes
self._snap_inter_step_time = None
@rank_zero_only
def on_train_batch_start(self, trainer, **kwargs) -> None: # pylint: disable=arguments-differ
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx) -> None: # pylint: disable=arguments-differ
if self._intra_step_time:
self._snap_intra_step_time = time.monotonic()
@ -148,7 +148,7 @@ class GpuStatsLogger(Callback): # pylint: disable=too-many-instance-attributes
trainer.logger.log_metrics(logs, step=trainer.global_step)
@rank_zero_only
def on_train_batch_end(self, trainer, **kwargs) -> None: # pylint: disable=arguments-differ
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None: # pylint: disable=arguments-differ
if self._inter_step_time:
self._snap_inter_step_time = time.monotonic()