From 8848dad362141b2d819322fdb527bf0579424aea Mon Sep 17 00:00:00 2001
From: Setepenre
Date: Mon, 13 May 2024 16:55:42 -0400
Subject: [PATCH 1/2] Update distributed.py

---
 timm/utils/distributed.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py
index 286e8ba4..5cfe1a9f 100644
--- a/timm/utils/distributed.py
+++ b/timm/utils/distributed.py
@@ -109,8 +109,13 @@ def init_distributed_device_so(
     global_rank = 0
     local_rank = 0
     if dist_backend is None:
-        # FIXME sane defaults for other device backends?
-        dist_backend = 'nccl' if 'cuda' in device else 'gloo'
+        # FIXME: verify that ROCm transform nccl to rccl
+        dist_backends = {
+            "xpu": "ccl",
+            "hpu": "hccl",
+            "cuda": "nccl",
+        }
+        dist_backend = dist_backends.get(device, 'gloo')
     dist_url = dist_url or 'env://'
 
     # TBD, support horovod?

From e57625e8140f7aa2aa662b3417f1c0203ec9ffa2 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 15 May 2024 08:49:25 -0700
Subject: [PATCH 2/2] Tweak dist_backend to use device_type (before possible :)

---
 timm/utils/distributed.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py
index 5cfe1a9f..18f526bb 100644
--- a/timm/utils/distributed.py
+++ b/timm/utils/distributed.py
@@ -108,6 +108,8 @@ def init_distributed_device_so(
     world_size = 1
     global_rank = 0
     local_rank = 0
+    device_type, *device_idx = device.split(':', maxsplit=1)
+
     if dist_backend is None:
         # FIXME: verify that ROCm transform nccl to rccl
         dist_backends = {
@@ -115,7 +117,7 @@
             "hpu": "hccl",
             "cuda": "nccl",
         }
-        dist_backend = dist_backends.get(device, 'gloo')
+        dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'
 
     # TBD, support horovod?
@@ -155,18 +157,15 @@
         global_rank = torch.distributed.get_rank()
         distributed = True
 
-    if 'cuda' in device:
+    if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
 
     if distributed and device != 'cpu':
-        device, *device_idx = device.split(':', maxsplit=1)
-
         # Ignore manually specified device index in distributed mode and
         # override with resolved local rank, fewer headaches in most setups.
         if device_idx:
             _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
-
-        device = f'{device}:{local_rank}'
+        device = f'{device_type}:{local_rank}'
 
     if device.startswith('cuda:'):
         torch.cuda.set_device(device)
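
Note (not part of the patches above): a minimal standalone sketch of the backend selection the two commits end up with, i.e. split off any device index before the lookup so that e.g. 'cuda:1' still resolves to 'nccl'. The helper name and the example device strings are illustrative assumptions, not timm API.

# Hypothetical helper mirroring the patched logic in init_distributed_device_so;
# the function name is made up for illustration only.
def _resolve_dist_backend(device: str, dist_backend=None) -> str:
    # Strip an optional ':<index>' suffix before looking up the backend.
    device_type, *device_idx = device.split(':', maxsplit=1)
    if dist_backend is None:
        dist_backends = {
            "xpu": "ccl",
            "hpu": "hccl",
            "cuda": "nccl",
        }
        dist_backend = dist_backends.get(device_type, 'gloo')
    return dist_backend


if __name__ == '__main__':
    # Expected mapping after PATCH 2/2: a trailing device index no longer
    # breaks the lookup, and unknown device types fall back to 'gloo'.
    for dev in ('cuda', 'cuda:1', 'xpu', 'hpu', 'cpu', 'mps'):
        print(dev, '->', _resolve_dist_backend(dev))
    # cuda -> nccl, cuda:1 -> nccl, xpu -> ccl, hpu -> hccl, cpu -> gloo, mps -> gloo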