mirror of
https://github.com/XuehaiPan/nvitop.git
synced 2026-05-21 06:45:24 -06:00
refactor(core/libcuda): replace local function with functools.partial
Signed-off-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
This commit is contained in:
parent
9597468a19
commit
be4b6fdee8
2 changed files with 12 additions and 13 deletions
|
|
@ -131,7 +131,7 @@ UtilizationRates = NamedTuple(
|
|||
],
|
||||
)
|
||||
|
||||
_SENTINEL = object()
|
||||
_VALUE_OMITTED = object()
|
||||
|
||||
|
||||
class Device: # pylint: disable=too-many-instance-attributes,too-many-public-methods
|
||||
|
|
@ -350,7 +350,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me
|
|||
|
||||
@staticmethod
|
||||
def parse_cuda_visible_devices(
|
||||
cuda_visible_devices: Optional[str] = _SENTINEL,
|
||||
cuda_visible_devices: Optional[str] = _VALUE_OMITTED,
|
||||
) -> Union[List[int], List[Tuple[int, int]]]:
|
||||
"""Parses the given ``CUDA_VISIBLE_DEVICES`` value into NVML device indices.
|
||||
|
||||
|
|
@ -373,7 +373,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me
|
|||
If the ``CUDA_VISIBLE_DEVICES`` environment variable is invalid (e.g. duplicate entries).
|
||||
""" # pylint: disable=line-too-long
|
||||
|
||||
if cuda_visible_devices is _SENTINEL:
|
||||
if cuda_visible_devices is _VALUE_OMITTED:
|
||||
cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', default=None)
|
||||
|
||||
return Device._parse_cuda_visible_devices(cuda_visible_devices)
|
||||
|
|
@ -2286,7 +2286,7 @@ def _cuda_visible_devices_parser(
|
|||
|
||||
|
||||
def parse_cuda_visible_devices_to_uuids(
|
||||
cuda_visible_devices: Optional[str] = _SENTINEL,
|
||||
cuda_visible_devices: Optional[str] = _VALUE_OMITTED,
|
||||
verbose=True,
|
||||
) -> List[str]:
|
||||
"""Parses the given ``CUDA_VISIBLE_DEVICES`` environment variable in a separate process and
|
||||
|
|
@ -2308,7 +2308,7 @@ def parse_cuda_visible_devices_to_uuids(
|
|||
If failed to parse the ``CUDA_VISIBLE_DEVICES`` environment variable.
|
||||
"""
|
||||
|
||||
if cuda_visible_devices is _SENTINEL:
|
||||
if cuda_visible_devices is _VALUE_OMITTED:
|
||||
cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', default=None)
|
||||
|
||||
# Do not inherit file descriptors and handles from the parent process
|
||||
|
|
|
|||
|
|
@ -273,6 +273,8 @@ def _extract_cuda_errors_as_classes() -> None:
|
|||
e.g. :data:`CUDA_ERROR_INVALID_VALUE` will be turned into :class:`CUDAError_InvalidValue`.
|
||||
"""
|
||||
|
||||
import functools # pylint: disable=import-outside-toplevel
|
||||
|
||||
this_module = _sys.modules[__name__]
|
||||
cuda_error_names = [x for x in dir(this_module) if x.startswith('CUDA_ERROR_')]
|
||||
for err_name in cuda_error_names:
|
||||
|
|
@ -281,15 +283,12 @@ def _extract_cuda_errors_as_classes() -> None:
|
|||
class_name = 'CUDAError_{}'.format(pascal_case)
|
||||
err_val = getattr(this_module, err_name)
|
||||
|
||||
def gen_new(value):
|
||||
def new(cls):
|
||||
obj = CUDAError.__new__(cls, value)
|
||||
return obj
|
||||
|
||||
return new
|
||||
|
||||
# pylint: disable=protected-access
|
||||
new_error_class = type(class_name, (CUDAError,), {'__new__': gen_new(err_val)})
|
||||
new_error_class = type(
|
||||
class_name,
|
||||
(CUDAError,),
|
||||
{'__new__': functools.partial(CUDAError.__new__, value=err_val)},
|
||||
)
|
||||
new_error_class.__module__ = __name__
|
||||
if err_val in CUDAError._errcode_to_string:
|
||||
new_error_class.__doc__ = 'CUDA Error: {} Code: :data:`{}` ({}).'.format(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue