refactor(core/libcuda): replace local function with functools.partial

Signed-off-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
This commit is contained in:
Xuehai Pan 2022-07-24 16:13:10 +08:00
parent 9597468a19
commit be4b6fdee8
2 changed files with 12 additions and 13 deletions

View file

@ -131,7 +131,7 @@ UtilizationRates = NamedTuple(
],
)
_SENTINEL = object()
_VALUE_OMITTED = object()
class Device: # pylint: disable=too-many-instance-attributes,too-many-public-methods
@ -350,7 +350,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me
@staticmethod
def parse_cuda_visible_devices(
cuda_visible_devices: Optional[str] = _SENTINEL,
cuda_visible_devices: Optional[str] = _VALUE_OMITTED,
) -> Union[List[int], List[Tuple[int, int]]]:
"""Parses the given ``CUDA_VISIBLE_DEVICES`` value into NVML device indices.
@ -373,7 +373,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me
If the ``CUDA_VISIBLE_DEVICES`` environment variable is invalid (e.g. duplicate entries).
""" # pylint: disable=line-too-long
if cuda_visible_devices is _SENTINEL:
if cuda_visible_devices is _VALUE_OMITTED:
cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', default=None)
return Device._parse_cuda_visible_devices(cuda_visible_devices)
@ -2286,7 +2286,7 @@ def _cuda_visible_devices_parser(
def parse_cuda_visible_devices_to_uuids(
cuda_visible_devices: Optional[str] = _SENTINEL,
cuda_visible_devices: Optional[str] = _VALUE_OMITTED,
verbose=True,
) -> List[str]:
"""Parses the given ``CUDA_VISIBLE_DEVICES`` environment variable in a separate process and
@ -2308,7 +2308,7 @@ def parse_cuda_visible_devices_to_uuids(
If failed to parse the ``CUDA_VISIBLE_DEVICES`` environment variable.
"""
if cuda_visible_devices is _SENTINEL:
if cuda_visible_devices is _VALUE_OMITTED:
cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', default=None)
# Do not inherit file descriptors and handles from the parent process

View file

@ -273,6 +273,8 @@ def _extract_cuda_errors_as_classes() -> None:
e.g. :data:`CUDA_ERROR_INVALID_VALUE` will be turned into :class:`CUDAError_InvalidValue`.
"""
import functools # pylint: disable=import-outside-toplevel
this_module = _sys.modules[__name__]
cuda_error_names = [x for x in dir(this_module) if x.startswith('CUDA_ERROR_')]
for err_name in cuda_error_names:
@ -281,15 +283,12 @@ def _extract_cuda_errors_as_classes() -> None:
class_name = 'CUDAError_{}'.format(pascal_case)
err_val = getattr(this_module, err_name)
def gen_new(value):
def new(cls):
obj = CUDAError.__new__(cls, value)
return obj
return new
# pylint: disable=protected-access
new_error_class = type(class_name, (CUDAError,), {'__new__': gen_new(err_val)})
new_error_class = type(
class_name,
(CUDAError,),
{'__new__': functools.partial(CUDAError.__new__, value=err_val)},
)
new_error_class.__module__ = __name__
if err_val in CUDAError._errcode_to_string:
new_error_class.__doc__ = 'CUDA Error: {} Code: :data:`{}` ({}).'.format(