style(api): add more type annotations

This commit is contained in:
Xuehai Pan 2023-02-20 05:27:19 +00:00
parent 74295ca2ae
commit 15095dc624
9 changed files with 61 additions and 46 deletions

View file

@ -250,7 +250,7 @@ def collect_in_background(
raise ValueError(f'Invalid argument interval={interval!r}')
interval = min(interval, collector.interval)
def target():
def target() -> None:
if on_start is not None:
on_start(collector)
try:
@ -402,10 +402,7 @@ class ResourceMetricCollector: # pylint: disable=too-many-instance-attributes
if devices is None:
devices = Device.all()
if root_pids is None:
root_pids = {os.getpid()}
else:
root_pids = set(root_pids)
root_pids = {os.getpid()} if root_pids is None else set(root_pids)
self.interval = interval

View file

@ -2019,7 +2019,7 @@ class MigDevice(Device): # pylint: disable=too-many-instance-attributes
return snapshot
SNAPSHOT_KEYS = Device.SNAPSHOT_KEYS + ['gpu_instance_id', 'compute_instance_id']
SNAPSHOT_KEYS = [*Device.SNAPSHOT_KEYS, 'gpu_instance_id', 'compute_instance_id']
class CudaDevice(Device):
@ -2387,15 +2387,14 @@ def normalize_cuda_visible_devices(cuda_visible_devices: Optional[str] = _VALUE_
# Helper functions #################################################################################
_PhysicalDeviceAttrs = NamedTuple(
'PhysicalDeviceAttrs',
[
('index', int),
('name', str),
('uuid', str),
('support_mig_mode', bool),
],
)
class _PhysicalDeviceAttrs(NamedTuple):
index: int
name: str
uuid: str
support_mig_mode: bool
_PHYSICAL_DEVICE_ATTRS = None
_GLOBAL_PHYSICAL_DEVICE = None
_GLOBAL_PHYSICAL_DEVICE_LOCK = threading.RLock()
@ -2551,7 +2550,7 @@ def _parse_cuda_visible_devices( # pylint: disable=too-many-branches,too-many-s
def _parse_cuda_visible_devices_to_uuids(
cuda_visible_devices: Optional[str] = _VALUE_OMITTED,
verbose=True,
verbose: bool = True,
) -> List[str]:
"""Parse the given ``CUDA_VISIBLE_DEVICES`` environment variable in a separate process and return a list of device UUIDs.

View file

@ -40,7 +40,7 @@ __all__[__all__.index('Error')] = 'PsutilError'
PsutilError = Error # make alias # noqa: F405
del Error # noqa: F405,F821 # pylint: disable=undefined-variable
del Error # noqa: F821 # pylint: disable=undefined-variable
cpu_percent = _ttl_cache(ttl=0.25)(_psutil.cpu_percent)

View file

@ -483,6 +483,12 @@ def __LoadCudaLibrary() -> None: # pylint: disable=too-many-branches
_glob.iglob(_os.path.join(candidate_dir, f'cudart{bits}*.dll'))
)
# Normalize paths and remove duplicates
candidate_paths = list(
dict.fromkeys(
_os.path.normpath(_os.path.normcase(p)) for p in candidate_paths
)
)
for lib_filename in candidate_paths:
try:
__cudaLib = _ctypes.CDLL(lib_filename)

View file

@ -391,10 +391,7 @@ def nvmlQuery(
retval = func(*args, **kwargs)
except NVMLError_FunctionNotFound as e2:
if not ignore_function_not_found:
if func.__name__ == '<lambda>':
identifier = _inspect.getsource(func)
else:
identifier = repr(func)
identifier = _inspect.getsource(func) if func.__name__ == '<lambda>' else repr(func)
with __lock:
if (
identifier not in UNKNOWN_FUNCTIONS

View file

@ -100,7 +100,9 @@ _RAISE = object()
_USE_FALLBACK_WHEN_RAISE = threading.local() # see also `GpuProcess.failsafe`
def auto_garbage_clean(fallback=_RAISE):
def auto_garbage_clean(
fallback: Any = _RAISE,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Remove the object references in the instance cache if the method call fails (the process is gone).
The fallback value will be used with `:meth:`GpuProcess.failsafe`` context manager, otherwise
@ -979,7 +981,7 @@ class GpuProcess: # pylint: disable=too-many-instance-attributes,too-many-publi
@classmethod
def take_snapshots( # batched version of `as_snapshot`
cls, gpu_processes: Iterable['GpuProcess'], *, failsafe=False
cls, gpu_processes: Iterable['GpuProcess'], *, failsafe: bool = False
) -> List[Snapshot]:
"""Take snapshots for a list of :class:`GpuProcess` instances.

View file

@ -17,8 +17,28 @@
# pylint: disable=missing-module-docstring
from typing import TYPE_CHECKING, Dict, Optional, Union
def add_scalar_dict(writer, main_tag, tag_scalar_dict, global_step=None, walltime=None):
if TYPE_CHECKING:
import numpy as np
try:
from tensorboard.summary import Writer as SummaryWriter
except ImportError:
try:
from tensorboardX import SummaryWriter
except ImportError:
pass
def add_scalar_dict(
writer: 'SummaryWriter',
main_tag: str,
tag_scalar_dict: Dict[str, Union[int, float, 'np.floating']],
global_step: Optional[Union[int, 'np.integer']] = None,
walltime: Optional[float] = None,
) -> None:
"""Add a batch of scalars to the writer.
Batched version of ``writer.add_scalar``.

View file

@ -19,13 +19,14 @@ NVITOP_MONITOR_MODE = set(
)
def parse_arguments(): # pylint: disable=too-many-branches,too-many-statements
# pylint: disable=too-many-branches,too-many-statements
def parse_arguments() -> argparse.Namespace:
"""Parse command-line arguments for ``nvtiop``."""
coloring_rules = '{} < th1 %% <= {} < th2 %% <= {}'.format(
colored('light', 'green'), colored('moderate', 'yellow'), colored('heavy', 'red')
)
def posint(argstring):
def posint(argstring: str) -> int:
num = int(argstring)
if num <= 0:
raise ValueError
@ -251,7 +252,8 @@ def parse_arguments(): # pylint: disable=too-many-branches,too-many-statements
return args
def main(): # pylint: disable=too-many-branches,too-many-statements,too-many-locals
# pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
def main() -> None:
"""Main function for ``nvitop`` CLI."""
args = parse_arguments()
@ -272,10 +274,7 @@ def main(): # pylint: disable=too-many-branches,too-many-statements,too-many-lo
if hasattr(args, 'monitor') and args.monitor is None:
mode = NVITOP_MONITOR_MODE.intersection({'auto', 'full', 'compact'})
if len(mode) != 1:
mode = 'auto'
else:
mode = mode.pop()
mode = 'auto' if len(mode) != 1 else mode.pop()
args.monitor = mode
if not setlocale_utf8():

View file

@ -60,7 +60,7 @@ import math
import os
import sys
import warnings
from typing import Iterable, List, Optional, Tuple, Union
from typing import Any, Iterable, List, Optional, Tuple, Union
from nvitop.api import Device, GpuProcess, colored, human2bytes, libnvml
from nvitop.version import __version__
@ -76,8 +76,7 @@ except ModuleNotFoundError:
TTY = sys.stdout.isatty()
# pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
def select_devices(
def select_devices( # pylint: disable=too-many-branches,too-many-statements,too-many-locals,unused-argument
devices: Iterable[Device] = None,
*,
format: str = 'index', # pylint: disable=redefined-builtin
@ -91,7 +90,7 @@ def select_devices(
tolerance: int = 0, # in percentage
free_accounts: List[str] = None,
sort: bool = True,
**kwargs, # fmt: skip # pylint: disable=unused-argument
**kwargs: Any,
) -> Union[List[int], List[Tuple[int, int]], List[str]]:
"""Select a subset of devices satisfying the specified criteria.
@ -163,7 +162,7 @@ def select_devices(
if isinstance(min_total_memory, str):
min_total_memory = human2bytes(min_total_memory)
available_devices = [] # type: Iterable[DeviceSnapshot]
available_devices = []
for device in devices:
available_devices.extend(dev.as_snapshot() for dev in device.to_leaf_devices())
for device in available_devices:
@ -261,20 +260,16 @@ def select_devices(
if format == 'device':
return [device.real for device in available_devices]
if format == 'uuid':
identifiers = [device.uuid for device in available_devices] # type: List[str]
else:
identifiers = [
device.index for device in available_devices
] # type: List[int, Tuple[int, int]]
return identifiers
return [device.uuid for device in available_devices]
return [device.index for device in available_devices]
def parse_arguments(): # pylint: disable=too-many-branches,too-many-statements
# pylint: disable-next=too-many-branches,too-many-statements
def parse_arguments() -> argparse.Namespace:
"""Parse command-line arguments for ``nvisel``."""
def non_negint(argstring):
def non_negint(argstring: str) -> int:
num = int(argstring)
if num < 0:
raise ValueError
@ -489,7 +484,7 @@ def parse_arguments(): # pylint: disable=too-many-branches,too-many-statements
return args
def main():
def main() -> None:
"""Main function for ``nvisel`` CLI."""
args = parse_arguments()