From d8e61cafb7426d10ebcafec1ae3ab7a34d913492 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 13 Jul 2022 22:09:17 +0800 Subject: [PATCH] fix(core/device): fix CUDA_VISIBLE_DEVICES parsing for MIG UUID Signed-off-by: Xuehai Pan --- nvitop/core/device.py | 63 ++++++++++++++++++++++++++----------------- 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/nvitop/core/device.py b/nvitop/core/device.py index 4d5806b..a6dff4f 100644 --- a/nvitop/core/device.py +++ b/nvitop/core/device.py @@ -353,7 +353,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me devices = [] for cuda_index in cuda_indices: if not 0 <= cuda_index < cuda_device_count: - raise RuntimeError('CUDA Error: invalid device ordinal') + raise RuntimeError('CUDA Error: invalid device ordinal: {!r}.'.format(cuda_index)) device = cuda_devices[cuda_index] devices.append(device) @@ -422,7 +422,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me use_integer_identifiers = None for identifier in map(str.strip, cuda_visible_devices.split(',')): if identifier in presented: - raise RuntimeError('CUDA Error: invalid device ordinal') + raise RuntimeError('CUDA Error: duplicate device ordinal: {!r}.'.format(identifier)) try: device = from_index_or_uuid(identifier) @@ -477,34 +477,44 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me 'but (index, uuid, bus_id) = {!r} were given'.format((index, uuid, bus_id)) ) + if cls is not Device: + return super().__new__(cls) + match = None if isinstance(index, str): match = cls.UUID_PATTERN.match(index) if match is not None: # passed by UUID index, uuid = None, index + elif isinstance(uuid, str): + match = cls.UUID_PATTERN.match(uuid) - if cls is Device: - if index is not None: - if not isinstance(index, int): - if not ( - isinstance(index, tuple) and len(index) == 2 and - isinstance(index[0], int) and isinstance(index[1], int) - ): - raise TypeError( - 'index for MIG device must be a tuple of 2 integers ' - 'but index = {!r} was given'.format((index)) - ) - return super().__new__(MigDevice) - elif uuid is not None: - if match is not None and match.group('MigMode') is not None: - return super().__new__(MigDevice) - return super().__new__(PhysicalDevice) - return super().__new__(cls) + if index is not None: + if not isinstance(index, int): + if not ( + isinstance(index, tuple) and len(index) == 2 and + isinstance(index[0], int) and isinstance(index[1], int) + ): + raise TypeError( + 'index for MIG device must be a tuple of 2 integers ' + 'but index = {!r} was given'.format((index)) + ) + return super().__new__(MigDevice) + elif uuid is not None: + if match is not None and match.group('MigMode') is not None: + return super().__new__(MigDevice) + return super().__new__(PhysicalDevice) def __init__(self, index: Optional[Union[int, str]] = None, *, uuid: Optional[str] = None, bus_id: Optional[str] = None) -> None: - """Initializes the instance created by ``__new__()``.""" + """Initializes the instance created by ``__new__()``. + + Raises: + NVMLError_NotFound: + If the device is not found for the given NVML identifier. + NVMLError_InvalidArgument: + If the device index is out of range. + """ if isinstance(index, str) and self.UUID_PATTERN.match(index) is not None: # passed by UUID index, uuid = None, index @@ -1983,7 +1993,6 @@ class CudaDevice(Device): If the number of non-None arguments is not exactly 1. TypeError: If the given index is a tuple but is not consist of two integers. - Raises: RuntimeError: If the environment variable ``CUDA_VISIBLE_DEVICES`` is invalid (e.g. duplicate entries). RuntimeError: @@ -1992,8 +2001,8 @@ class CudaDevice(Device): if cuda_index is not None and nvml_index is None and uuid is None: cuda_visible_devices = cls.parse_cuda_visible_devices() - if not 0 <= cuda_index < len(cuda_visible_devices): - raise RuntimeError('CUDA Error: invalid device ordinal') + if not isinstance(cuda_index, int) or not 0 <= cuda_index < len(cuda_visible_devices): + raise RuntimeError('CUDA Error: invalid device ordinal: {!r}.'.format(cuda_index)) nvml_index = cuda_visible_devices[cuda_index] if not isinstance(nvml_index, int) or is_mig_device_uuid(uuid): @@ -2007,14 +2016,18 @@ class CudaDevice(Device): """Initializes the instance created by ``__new__()``. Raises: + NVMLError_NotFound: + If the device is not found for the given NVML identifier. + NVMLError_InvalidArgument: + If the NVML index is out of range. RuntimeError: The given device is not visible to CUDA applications. """ if cuda_index is not None and nvml_index is None and uuid is None: cuda_visible_devices = self.parse_cuda_visible_devices() - if not 0 <= cuda_index < len(cuda_visible_devices): - raise RuntimeError('CUDA Error: invalid device ordinal') + if not isinstance(cuda_index, int) or not 0 <= cuda_index < len(cuda_visible_devices): + raise RuntimeError('CUDA Error: invalid device ordinal: {!r}.'.format(cuda_index)) nvml_index = cuda_visible_devices[cuda_index] super().__init__(index=nvml_index, uuid=uuid)