Fix device idx split

Ross Wightman 2024-02-10 21:41:14 -08:00
parent 59239d9df5
commit 47c9bc4dc6


@@ -92,7 +92,7 @@ def init_distributed_device(args):
     args.world_size = result['world_size']
     args.rank = result['global_rank']
     args.local_rank = result['local_rank']
-    args.distributed = args.world_size > 1
+    args.distributed = result['distributed']
     device = torch.device(args.device)
     return device
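
The first hunk stops re-deriving the flag from world size and instead trusts the flag resolved by the init routine. A minimal sketch of the assumed shape of result, built only from the keys this hunk reads; the rationale in the comments is an inference, not stated in the commit:

# Hypothetical result dict, inferred from the keys read in the hunk above.
result = {
    'world_size': 1,
    'global_rank': 0,
    'local_rank': 0,
    'distributed': False,   # resolved during environment detection
}

# Old: recomputed the flag locally, so it could disagree with whatever
#      the resolver decided (assumption: e.g. a launcher that marks the
#      run distributed even when world_size == 1).
distributed_old = result['world_size'] > 1

# New: a single source of truth for the distributed flag.
distributed_new = result['distributed']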
@@ -154,12 +154,12 @@ def init_distributed_device_so(
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
     if distributed and device != 'cpu':
-        device, device_idx = device.split(':', maxsplit=1)
+        device, *device_idx = device.split(':', maxsplit=1)
         # Ignore manually specified device index in distributed mode and
         # override with resolved local rank, fewer headaches in most setups.
         if device_idx:
-            _logger.warning(f'device index {device_idx} removed from specified ({device}).')
+            _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
         device = f'{device}:{local_rank}'
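
The second hunk is the fix named in the commit title: two-target unpacking raises ValueError when the device string carries no index (e.g. 'cuda' rather than 'cuda:1'), because split(':', maxsplit=1) returns a single-element list. Star-unpacking yields an empty list in that case, so the truthiness check on device_idx still works. A minimal standalone sketch of the two behaviours; the helper names are illustrative, not from the diff:

def split_old(device: str):
    # Old code path: raises ValueError ("not enough values to unpack")
    # for 'cuda', since the split produces only one element.
    dev, dev_idx = device.split(':', maxsplit=1)
    return dev, dev_idx

def split_new(device: str):
    # New code path: star-unpacking tolerates a missing index,
    # giving [] for 'cuda' and ['1'] for 'cuda:1'.
    dev, *dev_idx = device.split(':', maxsplit=1)
    return dev, dev_idx

print(split_new('cuda'))    # ('cuda', [])
print(split_new('cuda:1'))  # ('cuda', ['1'])

This is also why the warning now logs device_idx[0]: with star-unpacking, device_idx is a list, so the first element is the stripped index string.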