chore(setup): add extra cuda11 and cuda12

This commit is contained in:
Xuehai Pan 2025-01-28 21:34:54 +08:00
parent 9d64ea83aa
commit d72bc712b8

View file

@ -70,26 +70,48 @@ def vcs_version(name: str, path: Path | str) -> Generator[ModuleType]:
file.write(content)
extra_requirements = {
'lint': [
'black >= 24.0.0, < 25.0.0a0',
'pylint[spelling]',
'mypy',
'typing-extensions',
'pre-commit',
],
'cuda10': ['nvidia-ml-py == 11.450.51'],
}
with vcs_version(
name='nvitop.version',
path=HERE / 'nvitop' / 'version.py',
) as version:
for pynvml_major in sorted(
{int(pynvml.partition('.')[0]) for pynvml in version.PYNVML_VERSION_CANDIDATES},
):
pynvml_range = [
pynvml
for pynvml in version.PYNVML_VERSION_CANDIDATES
if pynvml.startswith(f'{pynvml_major}.')
]
if len(pynvml_range) == 1:
extra_requirements[f'cuda{pynvml_major}'] = [
f'nvidia-ml-py == {pynvml_range[0]}',
]
elif len(pynvml_range) >= 2:
extra_requirements[f'cuda{pynvml_major}'] = [
f'nvidia-ml-py >= {pynvml_range[0]}, <= {pynvml_range[-1]}',
]
extra_requirements.update(
{
# The identifier could not start with numbers, add a prefix `pynvml-`
f'pynvml-{pynvml}': [f'nvidia-ml-py == {pynvml}']
for pynvml in version.PYNVML_VERSION_CANDIDATES
},
)
setup(
name='nvitop',
version=version.__version__,
extras_require={
'lint': [
'black >= 24.0.0, < 25.0.0a0',
'pylint[spelling]',
'mypy',
'typing-extensions',
'pre-commit',
],
'cuda10': ['nvidia-ml-py == 11.450.51'],
**{
# The identifier could not start with numbers, add a prefix `pynvml-`
f'pynvml-{pynvml}': [f'nvidia-ml-py == {pynvml}']
for pynvml in version.PYNVML_VERSION_CANDIDATES
},
},
extras_require=extra_requirements,
)