diff --git a/nvitop/api/device.py b/nvitop/api/device.py index 8fb78ea..0515307 100644 --- a/nvitop/api/device.py +++ b/nvitop/api/device.py @@ -1428,9 +1428,9 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me power_usage = self.power_usage() power_limit = self.power_limit() if libnvml.nvmlCheckReturn(power_usage, int): - power_usage = f'{round(power_usage / 1000.0)}W' + power_usage = f'{round(power_usage / 1000)}W' # type: ignore[assignment] if libnvml.nvmlCheckReturn(power_limit, int): - power_limit = f'{round(power_limit / 1000.0)}W' + power_limit = f'{round(power_limit / 1000)}W' # type: ignore[assignment] return f'{power_usage} / {power_limit}' def pcie_throughput(self) -> ThroughputInfo: # in KiB/s @@ -1487,9 +1487,9 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me The current PCIe transmit throughput in human readable format, or :const:`nvitop.NA` when not applicable. """ - tx_throughput = self.pcie_tx_throughput() - if libnvml.nvmlCheckReturn(tx_throughput, int): - return f'{bytes2human(tx_throughput << 10)}/s' + tx = self.pcie_tx_throughput() + if libnvml.nvmlCheckReturn(tx, int): + return f'{bytes2human(tx * 1024)}/s' return NA def pcie_rx_throughput_human(self) -> str | NaType: # in human readable @@ -1502,9 +1502,9 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me The current PCIe receive throughput in human readable format, or :const:`nvitop.NA` when not applicable. 
""" - rx_throughput = self.pcie_rx_throughput() - if libnvml.nvmlCheckReturn(rx_throughput, int): - return f'{bytes2human(rx_throughput << 10)}/s' + rx = self.pcie_rx_throughput() + if libnvml.nvmlCheckReturn(rx, int): + return f'{bytes2human(rx * 1024)}/s' return NA def nvlink_link_count(self) -> int: @@ -1525,12 +1525,22 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me return self._nvlink_link_count # type: ignore[return-value] @memoize_when_activated - def nvlink_throughput(self) -> list[ThroughputInfo]: # in KiB/s + def nvlink_throughput( + self, + interval: int | float | None = None, + ) -> list[ThroughputInfo]: # in KiB/s """The current NVLink throughput for each NVLink in KiB/s. This function is querying data counters between methods calls and thus is the NVLink - throughput over that interval. For the first call, the function is blocking for 100ms to - get the first data counters. + throughput over that interval. For the first call, the function is blocking for 20ms to get + the first data counters. + + Args: + interval (Optional[Union[int, float]]): + The interval in seconds between two calls to get the NVLink throughput. If + ``interval`` is a positive number, compares throughput counters before and after the + interval (blocking). If ``interval`` is :const:`0.0` or :data:`None`, compares + throughput counters since the last call, returning immediately (non-blocking). 
Returns: List[ThroughputInfo(tx, rx)] A list of named tuples with current NVLink throughput for each NVLink in KiB/s, the item @@ -1555,9 +1565,17 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me ), ) + if interval is not None: + if not interval >= 0.0: + raise ValueError(f'`interval` must be a non-negative number, got {interval!r}.') + if interval > 0.0: + self._nvlink_throughput_counters = query_nvlink_throughput_counters() + time.sleep(interval) + if self._nvlink_throughput_counters is None: self._nvlink_throughput_counters = query_nvlink_throughput_counters() - time.sleep(0.1) + time.sleep(0.02) # 20ms + old_throughput_counters = self._nvlink_throughput_counters new_throughput_counters = query_nvlink_throughput_counters() @@ -1585,12 +1603,22 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me for tx, rx in zip(throughputs[:nvlink_link_count], throughputs[nvlink_link_count:]) ] - def nvlink_mean_throughput(self) -> ThroughputInfo: # in KiB/s + def nvlink_mean_throughput( + self, + interval: int | float | None = None, + ) -> ThroughputInfo: # in KiB/s """The mean NVLink throughput for all NVLinks in KiB/s. This function is querying data counters between methods calls and thus is the NVLink - throughput over that interval. For the first call, the function is blocking for 100ms to - get the first data counters. + throughput over that interval. For the first call, the function is blocking for 20ms to get + the first data counters. + + Args: + interval (Optional[Union[int, float]]): + The interval in seconds between two calls to get the NVLink throughput. If + ``interval`` is a positive number, compares throughput counters before and after the + interval (blocking). If ``interval`` is :const:`0.0` or :data:`None`, compares + throughput counters since the last call, returning immediately (non-blocking). 
Returns: ThroughputInfo(tx, rx) A named tuple with the mean NVLink throughput for all NVLinks in KiB/s, the item could @@ -1598,7 +1626,7 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me """ tx_throughputs = [] rx_throughputs = [] - for tx, rx in self.nvlink_throughput(): + for tx, rx in self.nvlink_throughput(interval=interval): if libnvml.nvmlCheckReturn(tx, int): tx_throughputs.append(tx) if libnvml.nvmlCheckReturn(rx, int): @@ -1608,88 +1636,200 @@ class Device: # pylint: disable=too-many-instance-attributes,too-many-public-me rx=round(sum(rx_throughputs) / len(rx_throughputs)) if rx_throughputs else NA, ) - def nvlink_tx_throughput(self) -> list[int | NaType]: # in KiB/s + def nvlink_tx_throughput( + self, + interval: int | float | None = None, + ) -> list[int | NaType]: # in KiB/s """The current NVLink transmit data throughput in KiB/s for each NVLink. + This function is querying data counters between methods calls and thus is the NVLink + throughput over that interval. For the first call, the function is blocking for 20ms to get + the first data counters. + + Args: + interval (Optional[Union[int, float]]): + The interval in seconds between two calls to get the NVLink throughput. If + ``interval`` is a positive number, compares throughput counters before and after the + interval (blocking). If ``interval`` is :const:`0.0` or :data:`None`, compares + throughput counters since the last call, returning immediately (non-blocking). + Returns: List[Union[int, NaType]] The current NVLink transmit data throughput in KiB/s for each NVLink, or :const:`nvitop.NA` when not applicable. 
""" - return [tx for tx, _ in self.nvlink_throughput()] + return [tx for tx, _ in self.nvlink_throughput(interval=interval)] - def nvlink_mean_tx_throughput(self) -> int | NaType: # in KiB/s + def nvlink_mean_tx_throughput( + self, + interval: int | float | None = None, + ) -> int | NaType: # in KiB/s """The mean NVLink transmit data throughput for all NVLinks in KiB/s. + This function is querying data counters between methods calls and thus is the NVLink + throughput over that interval. For the first call, the function is blocking for 20ms to get + the first data counters. + + Args: + interval (Optional[Union[int, float]]): + The interval in seconds between two calls to get the NVLink throughput. If + ``interval`` is a positive number, compares throughput counters before and after the + interval (blocking). If ``interval`` is :const:`0.0` or :data:`None`, compares + throughput counters since the last call, returning immediately (non-blocking). + Returns: Union[int, NaType] The mean NVLink transmit data throughput for all NVLinks in KiB/s, or :const:`nvitop.NA` when not applicable. """ - return self.nvlink_mean_throughput().tx + return self.nvlink_mean_throughput(interval=interval).tx - def nvlink_rx_throughput(self) -> list[int | NaType]: # in KiB/s + def nvlink_rx_throughput( + self, + interval: int | float | None = None, + ) -> list[int | NaType]: # in KiB/s """The current NVLink receive data throughput for each NVLink in KiB/s. + This function is querying data counters between methods calls and thus is the NVLink + throughput over that interval. For the first call, the function is blocking for 20ms to get + the first data counters. + + Args: + interval (Optional[Union[int, float]]): + The interval in seconds between two calls to get the NVLink throughput. If + ``interval`` is a positive number, compares throughput counters before and after the + interval (blocking). 
If ``interval`` is :const:`0.0` or :data:`None`, compares + throughput counters since the last call, returning immediately (non-blocking). + Returns: Union[int, NaType] The current NVLink receive data throughput for each NVLink in KiB/s, or :const:`nvitop.NA` when not applicable. """ - return [rx for _, rx in self.nvlink_throughput()] + return [rx for _, rx in self.nvlink_throughput(interval=interval)] - def nvlink_mean_rx_throughput(self) -> int | NaType: # in KiB/s + def nvlink_mean_rx_throughput( + self, + interval: int | float | None = None, + ) -> int | NaType: # in KiB/s """The mean NVLink receive data throughput for all NVLinks in KiB/s. + This function is querying data counters between methods calls and thus is the NVLink + throughput over that interval. For the first call, the function is blocking for 20ms to get + the first data counters. + + Args: + interval (Optional[Union[int, float]]): + The interval in seconds between two calls to get the NVLink throughput. If + ``interval`` is a positive number, compares throughput counters before and after the + interval (blocking). If ``interval`` is :const:`0.0` or :data:`None`, compares + throughput counters since the last call, returning immediately (non-blocking). + Returns: Union[int, NaType] The mean NVLink receive data throughput for all NVLinks in KiB/s, or :const:`nvitop.NA` when not applicable. """ - return self.nvlink_mean_throughput().rx + return self.nvlink_mean_throughput(interval=interval).rx - def nvlink_tx_throughput_human(self) -> list[str | NaType]: # in human readable + def nvlink_tx_throughput_human( + self, + interval: int | float | None = None, + ) -> list[str | NaType]: # in human readable """The current NVLink transmit data throughput for each NVLink in human readable format. + This function is querying data counters between methods calls and thus is the NVLink + throughput over that interval. For the first call, the function is blocking for 20ms to get + the first data counters. 
+ + Args: + interval (Optional[Union[int, float]]): + The interval in seconds between two calls to get the NVLink throughput. If + ``interval`` is a positive number, compares throughput counters before and after the + interval (blocking). If ``interval`` is :const:`0.0` or :data:`None`, compares + throughput counters since the last call, returning immediately (non-blocking). + Returns: Union[str, NaType] The current NVLink transmit data throughput for each NVLink in human readable format, or :const:`nvitop.NA` when not applicable. """ return [ - f'{bytes2human(tx << 10)}/s' if libnvml.nvmlCheckReturn(tx, int) else NA # type: ignore[operator] - for tx in self.nvlink_tx_throughput() + f'{bytes2human(tx * 1024)}/s' if libnvml.nvmlCheckReturn(tx, int) else NA + for tx in self.nvlink_tx_throughput(interval=interval) ] - def nvlink_mean_tx_throughput_human(self) -> str | NaType: # in human readable + def nvlink_mean_tx_throughput_human( + self, + interval: int | float | None = None, + ) -> str | NaType: # in human readable """The mean NVLink transmit data throughput for all NVLinks in human readable format. + This function is querying data counters between methods calls and thus is the NVLink + throughput over that interval. For the first call, the function is blocking for 20ms to get + the first data counters. + + Args: + interval (Optional[Union[int, float]]): + The interval in seconds between two calls to get the NVLink throughput. If + ``interval`` is a positive number, compares throughput counters before and after the + interval (blocking). If ``interval`` is :const:`0.0` or :data:`None`, compares + throughput counters since the last call, returning immediately (non-blocking). + Returns: Union[str, NaType] The mean NVLink transmit data throughput for all NVLinks in human readable format, or :const:`nvitop.NA` when not applicable. 
""" - mean_tx = self.nvlink_mean_tx_throughput() + mean_tx = self.nvlink_mean_tx_throughput(interval=interval) if libnvml.nvmlCheckReturn(mean_tx, int): - return f'{bytes2human(mean_tx << 10)}/s' # type: ignore[operator] + return f'{bytes2human(mean_tx * 1024)}/s' return NA - def nvlink_rx_throughput_human(self) -> list[str | NaType]: # in human readable + def nvlink_rx_throughput_human( + self, + interval: int | float | None = None, + ) -> list[str | NaType]: # in human readable """The current NVLink receive data throughput for each NVLink in human readable format. + This function is querying data counters between methods calls and thus is the NVLink + throughput over that interval. For the first call, the function is blocking for 20ms to get + the first data counters. + + Args: + interval (Optional[Union[int, float]]): + The interval in seconds between two calls to get the NVLink throughput. If + ``interval`` is a positive number, compares throughput counters before and after the + interval (blocking). If ``interval`` is :const:`0.0` or :data:`None`, compares + throughput counters since the last call, returning immediately (non-blocking). + Returns: Union[str, NaType] The current NVLink receive data throughput for each NVLink in human readable format, or :const:`nvitop.NA` when not applicable. """ return [ - f'{bytes2human(rx << 10)}/s' if libnvml.nvmlCheckReturn(rx, int) else NA # type: ignore[operator] - for rx in self.nvlink_rx_throughput() + f'{bytes2human(rx * 1024)}/s' if libnvml.nvmlCheckReturn(rx, int) else NA + for rx in self.nvlink_rx_throughput(interval=interval) ] - def nvlink_mean_rx_throughput_human(self) -> str | NaType: # in human readable + def nvlink_mean_rx_throughput_human( + self, + interval: int | float | None = None, + ) -> str | NaType: # in human readable """The mean NVLink receive data throughput for all NVLinks in human readable format. 
+ This function is querying data counters between methods calls and thus is the NVLink + throughput over that interval. For the first call, the function is blocking for 20ms to get + the first data counters. + + Args: + interval (Optional[Union[int, float]]): + The interval in seconds between two calls to get the NVLink throughput. If + ``interval`` is a positive number, compares throughput counters before and after the + interval (blocking). If ``interval`` is :const:`0.0` or :data:`None`, compares + throughput counters since the last call, returning immediately (non-blocking). + Returns: Union[str, NaType] The mean NVLink receive data throughput for all NVLinks in human readable format, or :const:`nvitop.NA` when not applicable. """ - mean_rx = self.nvlink_mean_rx_throughput() + mean_rx = self.nvlink_mean_rx_throughput(interval=interval) if libnvml.nvmlCheckReturn(mean_rx, int): - return f'{bytes2human(mean_rx << 10)}/s' # type: ignore[operator] + return f'{bytes2human(mean_rx * 1024)}/s' return NA def display_active(self) -> str | NaType: diff --git a/nvitop/api/utils.py b/nvitop/api/utils.py index d331721..545f44a 100644 --- a/nvitop/api/utils.py +++ b/nvitop/api/utils.py @@ -29,7 +29,7 @@ import re import sys import time from collections.abc import KeysView -from typing import Any, Callable, Generator, Iterable, Iterator +from typing import Any, Callable, Generator, Iterable, Iterator, TypeVar from psutil import WINDOWS @@ -717,8 +717,11 @@ class Snapshot: return KeysView(self) # type: ignore[arg-type] +Method = TypeVar('Method', bound=Callable[..., Any]) + + # Modified from psutil (https://github.com/giampaolo/psutil) -def memoize_when_activated(method: Callable[[Any], Any]) -> Callable[[Any], Any]: +def memoize_when_activated(method: Method) -> Method: """A memoize decorator which is disabled by default. It can be activated and deactivated on request. 
For efficiency reasons it can be used only @@ -726,17 +729,17 @@ def memoize_when_activated(method: Callable[[Any], Any]) -> Callable[[Any], Any] """ @functools.wraps(method) - def wrapped(self): # noqa: ANN001,ANN202 + def wrapped(self, *args, **kwargs): # noqa: ANN001,ANN002,ANN003,ANN202 try: # case 1: we previously entered oneshot() ctx ret = self._cache[method] # pylint: disable=protected-access except AttributeError: # case 2: we never entered oneshot() ctx - return method(self) + return method(self, *args, **kwargs) except KeyError: # case 3: we entered oneshot() ctx but there's no cache # for this entry yet - ret = method(self) + ret = method(self, *args, **kwargs) try: self._cache[method] = ret # pylint: disable=protected-access except AttributeError: @@ -762,4 +765,4 @@ def memoize_when_activated(method: Callable[[Any], Any]) -> Callable[[Any], Any] wrapped.cache_activate = cache_activate # type: ignore[attr-defined] wrapped.cache_deactivate = cache_deactivate # type: ignore[attr-defined] - return wrapped + return wrapped # type: ignore[return-value]