[Enhancement] Handle the case for Multi-Instance GPUs when using cuda_visible_devices (#1164)

This commit is contained in:
Adrian Joshua Strutt 2023-05-28 15:30:36 +02:00 committed by GitHub
parent 691500dce6
commit 0ff89f7c36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -376,7 +376,12 @@ def _get_device_id():
cuda_visible_devices = list(range(num_device))
else:
cuda_visible_devices = cuda_visible_devices.split(',')
return int(cuda_visible_devices[local_rank])
if cuda_visible_devices[local_rank].isdigit():
return int(cuda_visible_devices[local_rank])
else:
# handle case for Multi-Instance GPUs
# see #1148 for details
return cuda_visible_devices[local_rank]
def _get_host_info() -> str: