diff --git a/nvitop/callbacks/pytorch_lightning.py b/nvitop/callbacks/pytorch_lightning.py index 0c49ce4..1bd68a9 100644 --- a/nvitop/callbacks/pytorch_lightning.py +++ b/nvitop/callbacks/pytorch_lightning.py @@ -119,7 +119,7 @@ class GpuStatsLogger(Callback): # pylint: disable=too-many-instance-attributes f'The root device type is {trainer.strategy.root_device.type}.', ) - device_ids = trainer.data_parallel_device_ids + device_ids = trainer.device_ids try: self._devices = get_devices_by_logical_ids(device_ids, unique=True) except (libnvml.NVMLError, RuntimeError) as ex: