diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py
index 286e8ba4..18f526bb 100644
--- a/timm/utils/distributed.py
+++ b/timm/utils/distributed.py
@@ -108,9 +108,16 @@ def init_distributed_device_so(
     world_size = 1
     global_rank = 0
     local_rank = 0
+    device_type, *device_idx = device.split(':', maxsplit=1)
+
     if dist_backend is None:
-        # FIXME sane defaults for other device backends?
-        dist_backend = 'nccl' if 'cuda' in device else 'gloo'
+        # FIXME: verify that ROCm transforms nccl to rccl
+        dist_backends = {
+            "xpu": "ccl",
+            "hpu": "hccl",
+            "cuda": "nccl",
+        }
+        dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'
 
     # TBD, support horovod?
@@ -150,18 +157,15 @@ def init_distributed_device_so(
         global_rank = torch.distributed.get_rank()
         distributed = True
 
-    if 'cuda' in device:
+    if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
 
     if distributed and device != 'cpu':
-        device, *device_idx = device.split(':', maxsplit=1)
-
         # Ignore manually specified device index in distributed mode and
         # override with resolved local rank, fewer headaches in most setups.
         if device_idx:
            _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
-
-        device = f'{device}:{local_rank}'
+        device = f'{device_type}:{local_rank}'
 
     if device.startswith('cuda:'):
         torch.cuda.set_device(device)
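
For context, a minimal sketch of the backend-selection behaviour this patch introduces. The resolve_dist_backend helper below is hypothetical (init_distributed_device_so performs the same steps inline); only the device-string parsing and the backend mapping mirror the diff.

def resolve_dist_backend(device, dist_backend=None):
    # Parse the device type from strings like 'cuda:1', 'xpu:0', or 'cpu'.
    device_type, *_ = device.split(':', maxsplit=1)
    if dist_backend is None:
        # Same mapping as the patch; unknown device types fall back to 'gloo'.
        dist_backends = {
            "xpu": "ccl",
            "hpu": "hccl",
            "cuda": "nccl",
        }
        dist_backend = dist_backends.get(device_type, 'gloo')
    return dist_backend

assert resolve_dist_backend('cuda:1') == 'nccl'
assert resolve_dist_backend('xpu:0') == 'ccl'
assert resolve_dist_backend('cpu') == 'gloo'
assert resolve_dist_backend('cuda', dist_backend='gloo') == 'gloo'  # an explicit backend is kept

A side effect of parsing device_type once at the top of the function is that the later device-index stripping no longer needs to re-split the device string, which is why the second hunk removes the duplicate split.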