fix(core/device): fix CUDA_VISIBLE_DEVICES parsing for MIG UUID

Signed-off-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
This commit is contained in:
Xuehai Pan 2022-07-13 22:09:17 +08:00
parent 58c4c017a0
commit d8e61cafb7

View file

@ -353,7 +353,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me
devices = []
for cuda_index in cuda_indices:
if not 0 <= cuda_index < cuda_device_count:
raise RuntimeError('CUDA Error: invalid device ordinal')
raise RuntimeError('CUDA Error: invalid device ordinal: {!r}.'.format(cuda_index))
device = cuda_devices[cuda_index]
devices.append(device)
@ -422,7 +422,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me
use_integer_identifiers = None
for identifier in map(str.strip, cuda_visible_devices.split(',')):
if identifier in presented:
raise RuntimeError('CUDA Error: invalid device ordinal')
raise RuntimeError('CUDA Error: duplicate device ordinal: {!r}.'.format(identifier))
try:
device = from_index_or_uuid(identifier)
@ -477,34 +477,44 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me
'but (index, uuid, bus_id) = {!r} were given'.format((index, uuid, bus_id))
)
if cls is not Device:
return super().__new__(cls)
match = None
if isinstance(index, str):
match = cls.UUID_PATTERN.match(index)
if match is not None: # passed by UUID
index, uuid = None, index
elif isinstance(uuid, str):
match = cls.UUID_PATTERN.match(uuid)
if cls is Device:
if index is not None:
if not isinstance(index, int):
if not (
isinstance(index, tuple) and len(index) == 2 and
isinstance(index[0], int) and isinstance(index[1], int)
):
raise TypeError(
'index for MIG device must be a tuple of 2 integers '
'but index = {!r} was given'.format((index))
)
return super().__new__(MigDevice)
elif uuid is not None:
if match is not None and match.group('MigMode') is not None:
return super().__new__(MigDevice)
return super().__new__(PhysicalDevice)
return super().__new__(cls)
if index is not None:
if not isinstance(index, int):
if not (
isinstance(index, tuple) and len(index) == 2 and
isinstance(index[0], int) and isinstance(index[1], int)
):
raise TypeError(
'index for MIG device must be a tuple of 2 integers '
'but index = {!r} was given'.format((index))
)
return super().__new__(MigDevice)
elif uuid is not None:
if match is not None and match.group('MigMode') is not None:
return super().__new__(MigDevice)
return super().__new__(PhysicalDevice)
def __init__(self, index: Optional[Union[int, str]] = None, *,
uuid: Optional[str] = None,
bus_id: Optional[str] = None) -> None:
"""Initializes the instance created by ``__new__()``."""
"""Initializes the instance created by ``__new__()``.
Raises:
NVMLError_NotFound:
If the device is not found for the given NVML identifier.
NVMLError_InvalidArgument:
If the device index is out of range.
"""
if isinstance(index, str) and self.UUID_PATTERN.match(index) is not None: # passed by UUID
index, uuid = None, index
@ -1983,7 +1993,6 @@ class CudaDevice(Device):
If the number of non-None arguments is not exactly 1.
TypeError:
If the given index is a tuple but does not consist of two integers.
Raises:
RuntimeError:
If the environment variable ``CUDA_VISIBLE_DEVICES`` is invalid (e.g. duplicate entries).
RuntimeError:
@ -1992,8 +2001,8 @@ class CudaDevice(Device):
if cuda_index is not None and nvml_index is None and uuid is None:
cuda_visible_devices = cls.parse_cuda_visible_devices()
if not 0 <= cuda_index < len(cuda_visible_devices):
raise RuntimeError('CUDA Error: invalid device ordinal')
if not isinstance(cuda_index, int) or not 0 <= cuda_index < len(cuda_visible_devices):
raise RuntimeError('CUDA Error: invalid device ordinal: {!r}.'.format(cuda_index))
nvml_index = cuda_visible_devices[cuda_index]
if not isinstance(nvml_index, int) or is_mig_device_uuid(uuid):
@ -2007,14 +2016,18 @@ class CudaDevice(Device):
"""Initializes the instance created by ``__new__()``.
Raises:
NVMLError_NotFound:
If the device is not found for the given NVML identifier.
NVMLError_InvalidArgument:
If the NVML index is out of range.
RuntimeError:
The given device is not visible to CUDA applications.
"""
if cuda_index is not None and nvml_index is None and uuid is None:
cuda_visible_devices = self.parse_cuda_visible_devices()
if not 0 <= cuda_index < len(cuda_visible_devices):
raise RuntimeError('CUDA Error: invalid device ordinal')
if not isinstance(cuda_index, int) or not 0 <= cuda_index < len(cuda_visible_devices):
raise RuntimeError('CUDA Error: invalid device ordinal: {!r}.'.format(cuda_index))
nvml_index = cuda_visible_devices[cuda_index]
super().__init__(index=nvml_index, uuid=uuid)