diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ceb8a37..d006f4c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,7 +25,7 @@ repos: - id: debug-statements - id: double-quote-string-fixer - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.282 + rev: v0.0.284 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/nvitop/api/collector.py b/nvitop/api/collector.py index 2180164..227c781 100644 --- a/nvitop/api/collector.py +++ b/nvitop/api/collector.py @@ -396,7 +396,7 @@ class ResourceMetricCollector: # pylint: disable=too-many-instance-attributes self, devices: Iterable[Device] | None = None, root_pids: Iterable[int] | None = None, - interval: int | float = 1.0, + interval: float = 1.0, ) -> None: """Initialize the resource metric collector.""" if isinstance(interval, (int, float)) and interval > 0: diff --git a/nvitop/api/device.py b/nvitop/api/device.py index 0515307..bcf0b6f 100644 --- a/nvitop/api/device.py +++ b/nvitop/api/device.py @@ -142,6 +142,7 @@ from nvitop.api.utils import ( if TYPE_CHECKING: from typing_extensions import Literal # Python 3.8+ + from typing_extensions import Self # Python 3.11+ __all__ = [ @@ -562,7 +563,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me *, uuid: str | None = None, bus_id: str | None = None, - ) -> Device: + ) -> Self: """Create a new instance of Device. The type of the result is determined by the given argument. @@ -592,8 +593,10 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me ) if cls is not Device: + # Use the subclass type if the type is explicitly specified return super().__new__(cls) + # Auto subclass type inference logic goes here when `cls` is `Device` (e.g., calls `Device(...)`) match: re.Match | None = None if isinstance(index, str): match = cls.UUID_PATTERN.match(index) @@ -616,10 +619,10 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me f'index for MIG device must be a tuple of two integers ' f'but index = {index!r} was given', ) - return super().__new__(MigDevice) + return super().__new__(MigDevice) # type: ignore[return-value] elif uuid is not None and match is not None and match.group('MigMode') is not None: - return super().__new__(MigDevice) - return super().__new__(PhysicalDevice) + return super().__new__(MigDevice) # type: ignore[return-value] + return super().__new__(PhysicalDevice) # type: ignore[return-value] def __init__( self, @@ -1527,7 +1530,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me @memoize_when_activated def nvlink_throughput( self, - interval: int | float | None = None, + interval: float | None = None, ) -> list[ThroughputInfo]: # in KiB/s """The current NVLink throughput for each NVLink in KiB/s. @@ -1536,7 +1539,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me the first data counters. Args: - interval (Optional[Union[int, float]]): + interval (Optional[float]): The interval in seconds between two calls to get the NVLink throughput. If ``interval`` is a positive number, compares throughput counters before and after the interval (blocking). If ``interval`` is :const`0.0` or :data:`None`, compares @@ -1605,7 +1608,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me def nvlink_mean_throughput( self, - interval: int | float | None = None, + interval: float | None = None, ) -> ThroughputInfo: # in KiB/s """The mean NVLink throughput for all NVLinks in KiB/s. @@ -1614,7 +1617,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me the first data counters. Args: - interval (Optional[Union[int, float]]): + interval (Optional[float]): The interval in seconds between two calls to get the NVLink throughput. If ``interval`` is a positive number, compares throughput counters before and after the interval (blocking). If ``interval`` is :const`0.0` or :data:`None`, compares @@ -1638,7 +1641,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me def nvlink_tx_throughput( self, - interval: int | float | None = None, + interval: float | None = None, ) -> list[int | NaType]: # in KiB/s """The current NVLink transmit data throughput in KiB/s for each NVLink. @@ -1647,7 +1650,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me the first data counters. Args: - interval (Optional[Union[int, float]]): + interval (Optional[float]): The interval in seconds between two calls to get the NVLink throughput. If ``interval`` is a positive number, compares throughput counters before and after the interval (blocking). If ``interval`` is :const`0.0` or :data:`None`, compares @@ -1661,7 +1664,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me def nvlink_mean_tx_throughput( self, - interval: int | float | None = None, + interval: float | None = None, ) -> int | NaType: # in KiB/s """The mean NVLink transmit data throughput for all NVLinks in KiB/s. @@ -1670,7 +1673,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me the first data counters. Args: - interval (Optional[Union[int, float]]): + interval (Optional[float]): The interval in seconds between two calls to get the NVLink throughput. If ``interval`` is a positive number, compares throughput counters before and after the interval (blocking). If ``interval`` is :const`0.0` or :data:`None`, compares @@ -1684,7 +1687,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me def nvlink_rx_throughput( self, - interval: int | float | None = None, + interval: float | None = None, ) -> list[int | NaType]: # in KiB/s """The current NVLink receive data throughput for each NVLink in KiB/s. @@ -1693,7 +1696,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me the first data counters. Args: - interval (Optional[Union[int, float]]): + interval (Optional[float]): The interval in seconds between two calls to get the NVLink throughput. If ``interval`` is a positive number, compares throughput counters before and after the interval (blocking). If ``interval`` is :const`0.0` or :data:`None`, compares @@ -1707,7 +1710,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me def nvlink_mean_rx_throughput( self, - interval: int | float | None = None, + interval: float | None = None, ) -> int | NaType: # in KiB/s """The mean NVLink receive data throughput for all NVLinks in KiB/s. @@ -1716,7 +1719,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me the first data counters. Args: - interval (Optional[Union[int, float]]): + interval (Optional[float]): The interval in seconds between two calls to get the NVLink throughput. If ``interval`` is a positive number, compares throughput counters before and after the interval (blocking). If ``interval`` is :const`0.0` or :data:`None`, compares @@ -1730,7 +1733,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me def nvlink_tx_throughput_human( self, - interval: int | float | None = None, + interval: float | None = None, ) -> list[str | NaType]: # in human readable """The current NVLink transmit data throughput for each NVLink in human readable format. @@ -1739,7 +1742,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me the first data counters. Args: - interval (Optional[Union[int, float]]): + interval (Optional[float]): The interval in seconds between two calls to get the NVLink throughput. If ``interval`` is a positive number, compares throughput counters before and after the interval (blocking). If ``interval`` is :const`0.0` or :data:`None`, compares @@ -1756,7 +1759,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me def nvlink_mean_tx_throughput_human( self, - interval: int | float | None = None, + interval: float | None = None, ) -> str | NaType: # in human readable """The mean NVLink transmit data throughput for all NVLinks in human readable format. @@ -1765,7 +1768,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me the first data counters. Args: - interval (Optional[Union[int, float]]): + interval (Optional[float]): The interval in seconds between two calls to get the NVLink throughput. If ``interval`` is a positive number, compares throughput counters before and after the interval (blocking). If ``interval`` is :const`0.0` or :data:`None`, compares @@ -1782,7 +1785,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me def nvlink_rx_throughput_human( self, - interval: int | float | None = None, + interval: float | None = None, ) -> list[str | NaType]: # in human readable """The current NVLink receive data throughput for each NVLink in human readable format. @@ -1791,7 +1794,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me the first data counters. Args: - interval (Optional[Union[int, float]]): + interval (Optional[float]): The interval in seconds between two calls to get the NVLink throughput. If ``interval`` is a positive number, compares throughput counters before and after the interval (blocking). If ``interval`` is :const`0.0` or :data:`None`, compares @@ -1808,7 +1811,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me def nvlink_mean_rx_throughput_human( self, - interval: int | float | None = None, + interval: float | None = None, ) -> str | NaType: # in human readable """The mean NVLink receive data throughput for all NVLinks in human readable format. @@ -1817,7 +1820,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me the first data counters. Args: - interval (Optional[Union[int, float]]): + interval (Optional[float]): The interval in seconds between two calls to get the NVLink throughput. If ``interval`` is a positive number, compares throughput counters before and after the interval (blocking). If ``interval`` is :const`0.0` or :data:`None`, compares @@ -2653,7 +2656,7 @@ class CudaDevice(Device): *, nvml_index: int | tuple[int, int] | None = None, uuid: str | None = None, - ) -> CudaDevice: + ) -> Self: """Create a new instance of CudaDevice. The type of the result is determined by the given argument. @@ -2690,10 +2693,14 @@ class CudaDevice(Device): raise RuntimeError(f'CUDA Error: invalid device ordinal: {cuda_index!r}.') nvml_index = cuda_visible_devices[cuda_index] + if cls is not CudaDevice: + # Use the subclass type if the type is explicitly specified + return super().__new__(cls, index=nvml_index, uuid=uuid) + + # Auto subclass type inference logic goes here when `cls` is `CudaDevice` (e.g., calls `CudaDevice(...)`) if (nvml_index is not None and not isinstance(nvml_index, int)) or is_mig_device_uuid(uuid): return super().__new__(CudaMigDevice, index=nvml_index, uuid=uuid) # type: ignore[return-value] - - return super().__new__(cls, index=nvml_index, uuid=uuid) # type: ignore[return-value] + return super().__new__(CudaDevice, index=nvml_index, uuid=uuid) # type: ignore[return-value] def __init__( self, diff --git a/nvitop/api/libcuda.py b/nvitop/api/libcuda.py index ba5b9d2..c966916 100644 --- a/nvitop/api/libcuda.py +++ b/nvitop/api/libcuda.py @@ -33,6 +33,7 @@ from typing import ClassVar as _ClassVar if _TYPE_CHECKING: + from typing_extensions import Self as _Self # Python 3.11+ from typing_extensions import TypeAlias as _TypeAlias # Python 3.10+ @@ -41,7 +42,9 @@ class _struct_c_CUdevice_t(_ctypes.Structure): pass # opaque handle -_c_CUdevice_t: _TypeAlias = _ctypes.POINTER(_struct_c_CUdevice_t) # type: ignore[valid-type] +_c_CUdevice_t: _TypeAlias = _ctypes.POINTER( # type: ignore[valid-type] # noqa: PYI042 + _struct_c_CUdevice_t, +) _CUresult_t: _TypeAlias = _ctypes.c_uint @@ -237,11 +240,11 @@ class CUDAError(Exception): _errcode_to_name: _ClassVar[dict[int, str]] = {} value: int - def __new__(cls, value: int) -> CUDAError: + def __new__(cls, value: int) -> _Self: """Map value to a proper subclass of :class:`CUDAError`.""" if cls is CUDAError: # pylint: disable-next=self-cls-assignment - cls = CUDAError._value_class_mapping.get(value, cls) + cls = CUDAError._value_class_mapping.get(value, cls) # type: ignore[assignment] obj = Exception.__new__(cls) obj.value = value return obj diff --git a/nvitop/api/libcudart.py b/nvitop/api/libcudart.py index 995afdc..17ab5d1 100644 --- a/nvitop/api/libcudart.py +++ b/nvitop/api/libcudart.py @@ -26,11 +26,16 @@ import os as _os import platform as _platform import sys as _sys import threading as _threading +from typing import TYPE_CHECKING as _TYPE_CHECKING from typing import Any as _Any from typing import Callable as _Callable from typing import ClassVar as _ClassVar +if _TYPE_CHECKING: + from typing_extensions import Self as _Self # Python 3.11+ + + _cudaError_t = _ctypes.c_int # Error codes # @@ -283,11 +288,11 @@ class cudaError(Exception): _errcode_to_name: _ClassVar[dict[int, str]] = {} value: int - def __new__(cls, value: int) -> cudaError: + def __new__(cls, value: int) -> _Self: """Map value to a proper subclass of :class:`cudaError`.""" if cls is cudaError: # pylint: disable-next=self-cls-assignment - cls = cudaError._value_class_mapping.get(value, cls) + cls = cudaError._value_class_mapping.get(value, cls) # type: ignore[assignment] obj = Exception.__new__(cls) obj.value = value return obj diff --git a/nvitop/api/libnvml.py b/nvitop/api/libnvml.py index a47366b..b1618dd 100644 --- a/nvitop/api/libnvml.py +++ b/nvitop/api/libnvml.py @@ -47,6 +47,7 @@ from nvitop.api.utils import colored as __colored if _TYPE_CHECKING: + from typing_extensions import Self as _Self # Python 3.11+ from typing_extensions import TypeAlias as _TypeAlias # Python 3.10+ @@ -91,7 +92,7 @@ for _name, _attr in _vars_pynvml.items(): if _name in {'nvmlInit', 'nvmlInitWithFlags', 'nvmlShutdown'}: continue if _name.startswith(('NVML_ERROR_', 'NVMLError_')): - __all__.append(_name) + __all__.append(_name) # noqa: PYI056 if _name.startswith('NVML_ERROR_'): _errcode_to_name[_attr] = _name _const_names.append(_name) @@ -103,7 +104,7 @@ for _name, _attr in _vars_pynvml.items(): if (_name.startswith('NVML_') and not _name.startswith('NVML_ERROR_')) or ( _name.startswith('nvml') and isinstance(_attr, _FunctionType) ): - __all__.append(_name) + __all__.append(_name) # noqa: PYI056 if _name.startswith('NVML_'): _const_names.append(_name) @@ -173,8 +174,8 @@ del ( # 5. Add explicit references to appease linters # pylint: disable=no-member -c_nvmlDevice_t: _TypeAlias = _pynvml.c_nvmlDevice_t -c_nvmlFieldValue_t: _TypeAlias = _pynvml.c_nvmlFieldValue_t +c_nvmlDevice_t: _TypeAlias = _pynvml.c_nvmlDevice_t # noqa: PYI042 +c_nvmlFieldValue_t: _TypeAlias = _pynvml.c_nvmlFieldValue_t # noqa: PYI042 NVML_SUCCESS: int = _pynvml.NVML_SUCCESS NVML_ERROR_INSUFFICIENT_SIZE: int = _pynvml.NVML_ERROR_INSUFFICIENT_SIZE NVMLError_FunctionNotFound: _TypeAlias = _pynvml.NVMLError_FunctionNotFound @@ -905,12 +906,12 @@ class _CustomModule(_ModuleType): except AttributeError: return getattr(_pynvml, name) - def __enter__(self) -> _CustomModule: + def __enter__(self) -> _Self: """Entry of the context manager for ``with`` statement.""" _lazy_init() return self - def __exit__(self, *args: _Any, **kwargs: _Any) -> None: + def __exit__(self, *exc: object) -> None: """Shutdown the NVML context in the context manager for ``with`` statement.""" try: nvmlShutdown() diff --git a/nvitop/api/process.py b/nvitop/api/process.py index 74fccb0..004660d 100644 --- a/nvitop/api/process.py +++ b/nvitop/api/process.py @@ -43,6 +43,8 @@ from nvitop.api.utils import ( if TYPE_CHECKING: + from typing_extensions import Self # Python 3.11+ + from nvitop.api.device import Device @@ -191,7 +193,7 @@ class HostProcess(host.Process, metaclass=ABCMeta): _ident: tuple _lock: threading.RLock - def __new__(cls, pid: int | None = None) -> HostProcess: + def __new__(cls, pid: int | None = None) -> Self: """Return the cached instance of :class:`HostProcess`.""" if pid is None: pid = os.getpid() @@ -471,7 +473,7 @@ class GpuProcess: # pylint: disable=too-many-instance-attributes,too-many-publi gpu_cc_protected_memory: int | NaType | None = None, type: str | NaType | None = None, # pylint: disable=redefined-builtin # pylint: enable=unused-argument - ) -> GpuProcess: + ) -> Self: """Return the cached instance of :class:`GpuProcess`.""" if pid is None: pid = os.getpid() @@ -480,7 +482,7 @@ class GpuProcess: # pylint: disable=too-many-instance-attributes,too-many-publi try: instance = cls.INSTANCES[(pid, device)] if instance.is_running(): - return instance + return instance # type: ignore[return-value] except KeyError: pass diff --git a/nvitop/api/utils.py b/nvitop/api/utils.py index 545f44a..8b5b909 100644 --- a/nvitop/api/utils.py +++ b/nvitop/api/utils.py @@ -153,7 +153,8 @@ class NaType(str): nan """ - def __new__(cls) -> NaType: + # NOTE: Decorate this class with `@final` and remove `noqa` when we drop Python 3.7 support. + def __new__(cls) -> NaType: # noqa: PYI034 """Get the singleton instance (:const:`nvitop.NA`).""" if not hasattr(cls, '_instance'): cls._instance = super().__new__(cls, 'N/A') @@ -527,7 +528,7 @@ SIZE_PATTERN: re.Pattern = re.compile( # pylint: disable-next=too-many-return-statements,too-many-branches def bytes2human( - b: int | float | NaType, + b: int | float | NaType, # noqa: PYI041 *, min_unit: int = 1, ) -> str: @@ -599,7 +600,7 @@ def human2bytes(s: int | str) -> int: def timedelta2human( - dt: int | float | datetime.timedelta | NaType, + dt: int | float | datetime.timedelta | NaType, # noqa: PYI041 *, round: bool = False, # pylint: disable=redefined-builtin ) -> str: @@ -619,7 +620,7 @@ def timedelta2human( return '{:d}:{:02d}'.format(*divmod(seconds, 60)) -def utilization2string(utilization: int | float | NaType) -> str: +def utilization2string(utilization: int | float | NaType) -> str: # noqa: PYI041 """Convert a utilization rate to string.""" if utilization != NA: if isinstance(utilization, int): diff --git a/nvitop/gui/screens/main/process.py b/nvitop/gui/screens/main/process.py index f07d5e6..e548f9e 100644 --- a/nvitop/gui/screens/main/process.py +++ b/nvitop/gui/screens/main/process.py @@ -6,8 +6,8 @@ import itertools import threading import time -from collections import namedtuple from operator import attrgetter, xor +from typing import Any, Callable, NamedTuple from cachetools.func import ttl_cache @@ -29,7 +29,13 @@ from nvitop.gui.library import ( ) -Order = namedtuple('Order', ['key', 'reverse', 'offset', 'column', 'previous', 'next']) +class Order(NamedTuple): + key: Callable[[Any], Any] + reverse: bool + offset: int + column: str + previous: str + next: str class ProcessPanel(Displayable): # pylint: disable=too-many-instance-attributes