diff --git a/setup.py b/setup.py index 6cb159a..b88b75e 100755 --- a/setup.py +++ b/setup.py @@ -70,26 +70,48 @@ def vcs_version(name: str, path: Path | str) -> Generator[ModuleType]: file.write(content) +extra_requirements = { + 'lint': [ + 'black >= 24.0.0, < 25.0.0a0', + 'pylint[spelling]', + 'mypy', + 'typing-extensions', + 'pre-commit', + ], + 'cuda10': ['nvidia-ml-py == 11.450.51'], +} + + with vcs_version( name='nvitop.version', path=HERE / 'nvitop' / 'version.py', ) as version: + for pynvml_major in sorted( + {int(pynvml.partition('.')[0]) for pynvml in version.PYNVML_VERSION_CANDIDATES}, + ): + pynvml_range = [ + pynvml + for pynvml in version.PYNVML_VERSION_CANDIDATES + if pynvml.startswith(f'{pynvml_major}.') + ] + if len(pynvml_range) == 1: + extra_requirements[f'cuda{pynvml_major}'] = [ + f'nvidia-ml-py == {pynvml_range[0]}', + ] + elif len(pynvml_range) >= 2: + extra_requirements[f'cuda{pynvml_major}'] = [ + f'nvidia-ml-py >= {pynvml_range[0]}, <= {pynvml_range[-1]}', + ] + extra_requirements.update( + { + # The identifier could not start with numbers, add a prefix `pynvml-` + f'pynvml-{pynvml}': [f'nvidia-ml-py == {pynvml}'] + for pynvml in version.PYNVML_VERSION_CANDIDATES + }, + ) + setup( name='nvitop', version=version.__version__, - extras_require={ - 'lint': [ - 'black >= 24.0.0, < 25.0.0a0', - 'pylint[spelling]', - 'mypy', - 'typing-extensions', - 'pre-commit', - ], - 'cuda10': ['nvidia-ml-py == 11.450.51'], - **{ - # The identifier could not start with numbers, add a prefix `pynvml-` - f'pynvml-{pynvml}': [f'nvidia-ml-py == {pynvml}'] - for pynvml in version.PYNVML_VERSION_CANDIDATES - }, - }, + extras_require=extra_requirements, )