Merge pull request #2181 from huggingface/Delaunay-dist-backend

Delaunay dist backend flag
commit 27fd2f35d3
Ross Wightman, 2024-05-15 10:00:59 -07:00, committed by GitHub

@@ -108,9 +108,16 @@ def init_distributed_device_so(
     world_size = 1
     global_rank = 0
     local_rank = 0
+    device_type, *device_idx = device.split(':', maxsplit=1)
+
     if dist_backend is None:
-        # FIXME sane defaults for other device backends?
-        dist_backend = 'nccl' if 'cuda' in device else 'gloo'
+        # FIXME: verify that ROCm transform nccl to rccl
+        dist_backends = {
+            "xpu": "ccl",
+            "hpu": "hccl",
+            "cuda": "nccl",
+        }
+        dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'
 
     # TBD, support horovod?
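
The first hunk replaces the hard-coded nccl/gloo choice with a lookup keyed on the device type parsed from the device string. A minimal standalone sketch of that selection logic, under the same mapping; the helper name pick_dist_backend and the example calls are illustrative, not part of this commit:

def pick_dist_backend(device: str, dist_backend=None) -> str:
    # An explicit backend wins; otherwise resolve from the device type,
    # i.e. the part of the device string before any ':' index.
    if dist_backend is not None:
        return dist_backend
    device_type, *_ = device.split(':', maxsplit=1)
    dist_backends = {
        "xpu": "ccl",    # Intel devices via oneCCL bindings
        "hpu": "hccl",   # Habana Gaudi
        "cuda": "nccl",  # NVIDIA GPUs; expected to map to RCCL on ROCm builds (see FIXME above)
    }
    return dist_backends.get(device_type, 'gloo')

# pick_dist_backend('cuda:1') -> 'nccl'
# pick_dist_backend('hpu')    -> 'hccl'
# pick_dist_backend('cpu')    -> 'gloo'
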
@@ -150,18 +157,15 @@ def init_distributed_device_so(
         global_rank = torch.distributed.get_rank()
         distributed = True
 
-    if 'cuda' in device:
+    if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
 
     if distributed and device != 'cpu':
-        device, *device_idx = device.split(':', maxsplit=1)
-
         # Ignore manually specified device index in distributed mode and
         # override with resolved local rank, fewer headaches in most setups.
         if device_idx:
             _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
-        device = f'{device}:{local_rank}'
+        device = f'{device_type}:{local_rank}'
 
     if device.startswith('cuda:'):
         torch.cuda.set_device(device)
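
Taken together, the two hunks let a caller pass dist_backend explicitly while still getting a per-device-type default, and they reuse the parsed device_type when re-attaching the resolved local rank as the device index. A hedged usage sketch, assuming init_distributed_device_so is importable from timm.utils.distributed and that device and dist_backend are keyword arguments; the return value is not shown in this diff, so it is captured opaquely rather than unpacked:

from timm.utils.distributed import init_distributed_device_so

# Default: backend resolved from the device type ('cuda' -> 'nccl').
result = init_distributed_device_so(device='cuda')

# Explicit override, e.g. force 'gloo' while debugging collective hangs.
result = init_distributed_device_so(device='cuda:0', dist_backend='gloo')

# In distributed mode, any manually specified device index (the ':0' above)
# is dropped with a warning and replaced by the resolved local rank.
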