diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py
index 5cfe1a9f..18f526bb 100644
--- a/timm/utils/distributed.py
+++ b/timm/utils/distributed.py
@@ -108,6 +108,8 @@ def init_distributed_device_so(
     world_size = 1
     global_rank = 0
     local_rank = 0
+    device_type, *device_idx = device.split(':', maxsplit=1)
+
     if dist_backend is None:
         # FIXME: verify that ROCm transform nccl to rccl
         dist_backends = {
@@ -115,7 +117,7 @@ def init_distributed_device_so(
             "hpu": "hccl",
             "cuda": "nccl",
         }
-        dist_backend = dist_backends.get(device, 'gloo')
+        dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'
 
     # TBD, support horovod?
@@ -155,18 +157,15 @@ def init_distributed_device_so(
         global_rank = torch.distributed.get_rank()
         distributed = True
 
-    if 'cuda' in device:
+    if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
 
     if distributed and device != 'cpu':
-        device, *device_idx = device.split(':', maxsplit=1)
-
         # Ignore manually specified device index in distributed mode and
         # override with resolved local rank, fewer headaches in most setups.
         if device_idx:
             _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
-
-        device = f'{device}:{local_rank}'
+        device = f'{device_type}:{local_rank}'
 
     if device.startswith('cuda:'):
         torch.cuda.set_device(device)
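
For context, a minimal standalone sketch (not part of the patch) of the parsing behaviour the change relies on: splitting the device string once on ':' yields the device type and an optional index, so the backend lookup keys on the type alone. The `dist_backends` mapping is copied from the patched function; the example device strings are illustrative.

```python
device_examples = ['cuda', 'cuda:1', 'xpu:0', 'cpu']

dist_backends = {  # same mapping as in the patched function
    "xpu": "ccl",
    "hpu": "hccl",
    "cuda": "nccl",
}

for device in device_examples:
    # 'cuda:1' -> type 'cuda', idx ['1']; 'cpu' -> type 'cpu', idx []
    device_type, *device_idx = device.split(':', maxsplit=1)
    backend = dist_backends.get(device_type, 'gloo')
    print(device, device_type, device_idx, backend)
```

With the old code, `dist_backends.get(device, 'gloo')` silently fell back to 'gloo' for any indexed device string such as 'cuda:1'; keying on `device_type` picks 'nccl' regardless of whether an index was supplied.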