refactor(api/libnvml): refactor symbol lookup logic

This commit is contained in:
Xuehai Pan 2023-08-26 10:30:10 +00:00
parent 6a9663b33f
commit 9f3ff53425

View file

@ -595,56 +595,21 @@ if not _pynvml_installation_corrupted:
_nvmlGetFunctionPointer = _pynvml._nvmlGetFunctionPointer
__get_running_processes_version_suffix = '_v3'
def lookup(symbol: str) -> _Any:
def lookup(symbol: str) -> _Any | None:
try:
ptr = _nvmlGetFunctionPointer(symbol)
except NVMLError_FunctionNotFound:
LOGGER.debug('Failed to found symbol `%s`.', symbol)
raise
return None
LOGGER.debug('Found symbol `%s`.', symbol)
return ptr
try:
lookup('nvmlDeviceGetConfComputeMemSizeInfo')
except NVMLError_FunctionNotFound:
c_nvmlProcessInfo_t = c_nvmlProcessInfo_v2_t
LOGGER.debug(
'NVML get running process version 3 API with v3 type struct is not available '
'due to incompatible NVIDIA driver. Fallback to use get running process '
'version 3 API with v2 type struct.',
)
try:
lookup('nvmlDeviceGetComputeRunningProcesses_v3')
except NVMLError_FunctionNotFound:
__get_running_processes_version_suffix = '_v2'
if lookup('nvmlDeviceGetConfComputeMemSizeInfo'):
if lookup('nvmlDeviceGetComputeRunningProcesses_v3'):
LOGGER.debug(
'NVML get running process version 3 API with v2 type struct is not '
'available due to incompatible NVIDIA driver. Fallback to use get running '
'process version 2 API with v2 type struct.',
'NVML get running process version 3 API with v3 type struct is available.',
)
try:
lookup('nvmlDeviceGetComputeRunningProcesses_v2')
except NVMLError_FunctionNotFound:
c_nvmlProcessInfo_t = c_nvmlProcessInfo_v1_t
__get_running_processes_version_suffix = ''
LOGGER.debug(
'NVML get running process version 2 API with v2 type struct is not '
'available due to incompatible NVIDIA driver. Fallback to use get '
'running process version 1 API with v1 type struct.',
)
else:
LOGGER.debug(
'NVML get running process version 2 API with v2 type struct is '
'available.',
)
else:
LOGGER.debug(
'NVML get running process version 3 API with v2 type struct is available.',
)
else:
try:
lookup('nvmlDeviceGetComputeRunningProcesses_v3')
except NVMLError_FunctionNotFound:
c_nvmlProcessInfo_t = c_nvmlProcessInfo_v2_t
__get_running_processes_version_suffix = '_v2'
LOGGER.debug(
@ -652,10 +617,37 @@ if not _pynvml_installation_corrupted:
'available due to incompatible NVIDIA driver. Fallback to use get running '
'process version 2 API with v2 type struct.',
)
else:
else:
c_nvmlProcessInfo_t = c_nvmlProcessInfo_v2_t
LOGGER.debug(
'NVML get running process version 3 API with v3 type struct is not available '
'due to incompatible NVIDIA driver. Fallback to use get running process '
'version 3 API with v2 type struct.',
)
if lookup('nvmlDeviceGetComputeRunningProcesses_v3'):
LOGGER.debug(
'NVML get running process version 3 API with v3 type struct is available.',
'NVML get running process version 3 API with v2 type struct is available.',
)
else:
__get_running_processes_version_suffix = '_v2'
LOGGER.debug(
'NVML get running process version 3 API with v2 type struct is not '
'available due to incompatible NVIDIA driver. Fallback to use get running '
'process version 2 API with v2 type struct.',
)
if lookup('nvmlDeviceGetComputeRunningProcesses_v2'):
LOGGER.debug(
'NVML get running process version 2 API with v2 type struct is '
'available.',
)
else:
c_nvmlProcessInfo_t = c_nvmlProcessInfo_v1_t
__get_running_processes_version_suffix = ''
LOGGER.debug(
'NVML get running process version 2 API with v2 type struct is not '
'available due to incompatible NVIDIA driver. Fallback to use get '
'running process version 1 API with v1 type struct.',
)
return __get_running_processes_version_suffix