mirror of
https://github.com/XuehaiPan/nvitop.git
synced 2026-05-21 06:45:24 -06:00
style(api): add more type annotations
This commit is contained in:
parent
74295ca2ae
commit
15095dc624
9 changed files with 61 additions and 46 deletions
|
|
@ -250,7 +250,7 @@ def collect_in_background(
|
|||
raise ValueError(f'Invalid argument interval={interval!r}')
|
||||
interval = min(interval, collector.interval)
|
||||
|
||||
def target():
|
||||
def target() -> None:
|
||||
if on_start is not None:
|
||||
on_start(collector)
|
||||
try:
|
||||
|
|
@ -402,10 +402,7 @@ class ResourceMetricCollector: # pylint: disable=too-many-instance-attributes
|
|||
if devices is None:
|
||||
devices = Device.all()
|
||||
|
||||
if root_pids is None:
|
||||
root_pids = {os.getpid()}
|
||||
else:
|
||||
root_pids = set(root_pids)
|
||||
root_pids = {os.getpid()} if root_pids is None else set(root_pids)
|
||||
|
||||
self.interval = interval
|
||||
|
||||
|
|
|
|||
|
|
@ -2019,7 +2019,7 @@ class MigDevice(Device): # pylint: disable=too-many-instance-attributes
|
|||
|
||||
return snapshot
|
||||
|
||||
SNAPSHOT_KEYS = Device.SNAPSHOT_KEYS + ['gpu_instance_id', 'compute_instance_id']
|
||||
SNAPSHOT_KEYS = [*Device.SNAPSHOT_KEYS, 'gpu_instance_id', 'compute_instance_id']
|
||||
|
||||
|
||||
class CudaDevice(Device):
|
||||
|
|
@ -2387,15 +2387,14 @@ def normalize_cuda_visible_devices(cuda_visible_devices: Optional[str] = _VALUE_
|
|||
|
||||
# Helper functions #################################################################################
|
||||
|
||||
_PhysicalDeviceAttrs = NamedTuple(
|
||||
'PhysicalDeviceAttrs',
|
||||
[
|
||||
('index', int),
|
||||
('name', str),
|
||||
('uuid', str),
|
||||
('support_mig_mode', bool),
|
||||
],
|
||||
)
|
||||
|
||||
class _PhysicalDeviceAttrs(NamedTuple):
|
||||
index: int
|
||||
name: str
|
||||
uuid: str
|
||||
support_mig_mode: bool
|
||||
|
||||
|
||||
_PHYSICAL_DEVICE_ATTRS = None
|
||||
_GLOBAL_PHYSICAL_DEVICE = None
|
||||
_GLOBAL_PHYSICAL_DEVICE_LOCK = threading.RLock()
|
||||
|
|
@ -2551,7 +2550,7 @@ def _parse_cuda_visible_devices( # pylint: disable=too-many-branches,too-many-s
|
|||
|
||||
def _parse_cuda_visible_devices_to_uuids(
|
||||
cuda_visible_devices: Optional[str] = _VALUE_OMITTED,
|
||||
verbose=True,
|
||||
verbose: bool = True,
|
||||
) -> List[str]:
|
||||
"""Parse the given ``CUDA_VISIBLE_DEVICES`` environment variable in a separate process and return a list of device UUIDs.
|
||||
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ __all__[__all__.index('Error')] = 'PsutilError'
|
|||
|
||||
|
||||
PsutilError = Error # make alias # noqa: F405
|
||||
del Error # noqa: F405,F821 # pylint: disable=undefined-variable
|
||||
del Error # noqa: F821 # pylint: disable=undefined-variable
|
||||
|
||||
|
||||
cpu_percent = _ttl_cache(ttl=0.25)(_psutil.cpu_percent)
|
||||
|
|
|
|||
|
|
@ -483,6 +483,12 @@ def __LoadCudaLibrary() -> None: # pylint: disable=too-many-branches
|
|||
_glob.iglob(_os.path.join(candidate_dir, f'cudart{bits}*.dll'))
|
||||
)
|
||||
|
||||
# Normalize paths and remove duplicates
|
||||
candidate_paths = list(
|
||||
dict.fromkeys(
|
||||
_os.path.normpath(_os.path.normcase(p)) for p in candidate_paths
|
||||
)
|
||||
)
|
||||
for lib_filename in candidate_paths:
|
||||
try:
|
||||
__cudaLib = _ctypes.CDLL(lib_filename)
|
||||
|
|
|
|||
|
|
@ -391,10 +391,7 @@ def nvmlQuery(
|
|||
retval = func(*args, **kwargs)
|
||||
except NVMLError_FunctionNotFound as e2:
|
||||
if not ignore_function_not_found:
|
||||
if func.__name__ == '<lambda>':
|
||||
identifier = _inspect.getsource(func)
|
||||
else:
|
||||
identifier = repr(func)
|
||||
identifier = _inspect.getsource(func) if func.__name__ == '<lambda>' else repr(func)
|
||||
with __lock:
|
||||
if (
|
||||
identifier not in UNKNOWN_FUNCTIONS
|
||||
|
|
|
|||
|
|
@ -100,7 +100,9 @@ _RAISE = object()
|
|||
_USE_FALLBACK_WHEN_RAISE = threading.local() # see also `GpuProcess.failsafe`
|
||||
|
||||
|
||||
def auto_garbage_clean(fallback=_RAISE):
|
||||
def auto_garbage_clean(
|
||||
fallback: Any = _RAISE,
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
"""Remove the object references in the instance cache if the method call fails (the process is gone).
|
||||
|
||||
The fallback value will be used with `:meth:`GpuProcess.failsafe`` context manager, otherwise
|
||||
|
|
@ -979,7 +981,7 @@ class GpuProcess: # pylint: disable=too-many-instance-attributes,too-many-publi
|
|||
|
||||
@classmethod
|
||||
def take_snapshots( # batched version of `as_snapshot`
|
||||
cls, gpu_processes: Iterable['GpuProcess'], *, failsafe=False
|
||||
cls, gpu_processes: Iterable['GpuProcess'], *, failsafe: bool = False
|
||||
) -> List[Snapshot]:
|
||||
"""Take snapshots for a list of :class:`GpuProcess` instances.
|
||||
|
||||
|
|
|
|||
|
|
@ -17,8 +17,28 @@
|
|||
|
||||
# pylint: disable=missing-module-docstring
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Union
|
||||
|
||||
def add_scalar_dict(writer, main_tag, tag_scalar_dict, global_step=None, walltime=None):
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from tensorboard.summary import Writer as SummaryWriter
|
||||
except ImportError:
|
||||
try:
|
||||
from tensorboardX import SummaryWriter
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def add_scalar_dict(
|
||||
writer: 'SummaryWriter',
|
||||
main_tag: str,
|
||||
tag_scalar_dict: Dict[str, Union[int, float, 'np.floating']],
|
||||
global_step: Optional[Union[int, 'np.integer']] = None,
|
||||
walltime: Optional[float] = None,
|
||||
) -> None:
|
||||
"""Add a batch of scalars to the writer.
|
||||
|
||||
Batched version of ``writer.add_scalar``.
|
||||
|
|
|
|||
|
|
@ -19,13 +19,14 @@ NVITOP_MONITOR_MODE = set(
|
|||
)
|
||||
|
||||
|
||||
def parse_arguments(): # pylint: disable=too-many-branches,too-many-statements
|
||||
# pylint: disable=too-many-branches,too-many-statements
|
||||
def parse_arguments() -> argparse.Namespace:
|
||||
"""Parse command-line arguments for ``nvtiop``."""
|
||||
coloring_rules = '{} < th1 %% <= {} < th2 %% <= {}'.format(
|
||||
colored('light', 'green'), colored('moderate', 'yellow'), colored('heavy', 'red')
|
||||
)
|
||||
|
||||
def posint(argstring):
|
||||
def posint(argstring: str) -> int:
|
||||
num = int(argstring)
|
||||
if num <= 0:
|
||||
raise ValueError
|
||||
|
|
@ -251,7 +252,8 @@ def parse_arguments(): # pylint: disable=too-many-branches,too-many-statements
|
|||
return args
|
||||
|
||||
|
||||
def main(): # pylint: disable=too-many-branches,too-many-statements,too-many-locals
|
||||
# pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
|
||||
def main() -> None:
|
||||
"""Main function for ``nvitop`` CLI."""
|
||||
args = parse_arguments()
|
||||
|
||||
|
|
@ -272,10 +274,7 @@ def main(): # pylint: disable=too-many-branches,too-many-statements,too-many-lo
|
|||
|
||||
if hasattr(args, 'monitor') and args.monitor is None:
|
||||
mode = NVITOP_MONITOR_MODE.intersection({'auto', 'full', 'compact'})
|
||||
if len(mode) != 1:
|
||||
mode = 'auto'
|
||||
else:
|
||||
mode = mode.pop()
|
||||
mode = 'auto' if len(mode) != 1 else mode.pop()
|
||||
args.monitor = mode
|
||||
|
||||
if not setlocale_utf8():
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ import math
|
|||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from nvitop.api import Device, GpuProcess, colored, human2bytes, libnvml
|
||||
from nvitop.version import __version__
|
||||
|
|
@ -76,8 +76,7 @@ except ModuleNotFoundError:
|
|||
TTY = sys.stdout.isatty()
|
||||
|
||||
|
||||
# pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
|
||||
def select_devices(
|
||||
def select_devices( # pylint: disable=too-many-branches,too-many-statements,too-many-locals,unused-argument
|
||||
devices: Iterable[Device] = None,
|
||||
*,
|
||||
format: str = 'index', # pylint: disable=redefined-builtin
|
||||
|
|
@ -91,7 +90,7 @@ def select_devices(
|
|||
tolerance: int = 0, # in percentage
|
||||
free_accounts: List[str] = None,
|
||||
sort: bool = True,
|
||||
**kwargs, # fmt: skip # pylint: disable=unused-argument
|
||||
**kwargs: Any,
|
||||
) -> Union[List[int], List[Tuple[int, int]], List[str]]:
|
||||
"""Select a subset of devices satisfying the specified criteria.
|
||||
|
||||
|
|
@ -163,7 +162,7 @@ def select_devices(
|
|||
if isinstance(min_total_memory, str):
|
||||
min_total_memory = human2bytes(min_total_memory)
|
||||
|
||||
available_devices = [] # type: Iterable[DeviceSnapshot]
|
||||
available_devices = []
|
||||
for device in devices:
|
||||
available_devices.extend(dev.as_snapshot() for dev in device.to_leaf_devices())
|
||||
for device in available_devices:
|
||||
|
|
@ -261,20 +260,16 @@ def select_devices(
|
|||
|
||||
if format == 'device':
|
||||
return [device.real for device in available_devices]
|
||||
|
||||
if format == 'uuid':
|
||||
identifiers = [device.uuid for device in available_devices] # type: List[str]
|
||||
else:
|
||||
identifiers = [
|
||||
device.index for device in available_devices
|
||||
] # type: List[int, Tuple[int, int]]
|
||||
return identifiers
|
||||
return [device.uuid for device in available_devices]
|
||||
return [device.index for device in available_devices]
|
||||
|
||||
|
||||
def parse_arguments(): # pylint: disable=too-many-branches,too-many-statements
|
||||
# pylint: disable-next=too-many-branches,too-many-statements
|
||||
def parse_arguments() -> argparse.Namespace:
|
||||
"""Parse command-line arguments for ``nvisel``."""
|
||||
|
||||
def non_negint(argstring):
|
||||
def non_negint(argstring: str) -> int:
|
||||
num = int(argstring)
|
||||
if num < 0:
|
||||
raise ValueError
|
||||
|
|
@ -489,7 +484,7 @@ def parse_arguments(): # pylint: disable=too-many-branches,too-many-statements
|
|||
return args
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
"""Main function for ``nvisel`` CLI."""
|
||||
args = parse_arguments()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue