diff --git a/nvitop/callbacks/pytorch_lightning.py b/nvitop/callbacks/pytorch_lightning.py index 1bd68a9..5a11006 100644 --- a/nvitop/callbacks/pytorch_lightning.py +++ b/nvitop/callbacks/pytorch_lightning.py @@ -133,7 +133,7 @@ class GpuStatsLogger(Callback): # pylint: disable=too-many-instance-attributes self._snap_inter_step_time = None @rank_zero_only - def on_train_batch_start(self, trainer, **kwargs) -> None: # pylint: disable=arguments-differ + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx) -> None: # pylint: disable=arguments-differ if self._intra_step_time: self._snap_intra_step_time = time.monotonic() @@ -148,7 +148,7 @@ class GpuStatsLogger(Callback): # pylint: disable=too-many-instance-attributes trainer.logger.log_metrics(logs, step=trainer.global_step) @rank_zero_only - def on_train_batch_end(self, trainer, **kwargs) -> None: # pylint: disable=arguments-differ + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None: # pylint: disable=arguments-differ if self._inter_step_time: self._snap_inter_step_time = time.monotonic()