From 0eae807fe8f57fa5f0e11c666561fc55843d802f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ph=C3=BAc=20H=2E=20L=C3=AA=20Kh=E1=BA=AFc?= Date: Wed, 2 Aug 2023 16:43:51 +0200 Subject: [PATCH] Update pytorch_lightning.py --- nvitop/callbacks/pytorch_lightning.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nvitop/callbacks/pytorch_lightning.py b/nvitop/callbacks/pytorch_lightning.py index 1bd68a9..5a11006 100644 --- a/nvitop/callbacks/pytorch_lightning.py +++ b/nvitop/callbacks/pytorch_lightning.py @@ -133,7 +133,7 @@ class GpuStatsLogger(Callback): # pylint: disable=too-many-instance-attributes self._snap_inter_step_time = None @rank_zero_only - def on_train_batch_start(self, trainer, **kwargs) -> None: # pylint: disable=arguments-differ + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx) -> None: # pylint: disable=arguments-differ if self._intra_step_time: self._snap_intra_step_time = time.monotonic() @@ -148,7 +148,7 @@ class GpuStatsLogger(Callback): # pylint: disable=too-many-instance-attributes trainer.logger.log_metrics(logs, step=trainer.global_step) @rank_zero_only - def on_train_batch_end(self, trainer, **kwargs) -> None: # pylint: disable=arguments-differ + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None: # pylint: disable=arguments-differ if self._inter_step_time: self._snap_inter_step_time = time.monotonic()