Tweak dist_backend to use device_type (before possible :)

Ross Wightman 2024-05-15 08:49:25 -07:00
parent 6ca92570f7
commit e57625e814

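For context, the problem this commit addresses: the backend lookup was keyed on the full device string, so a device such as 'cuda:0' never matched the 'cuda' entry and silently fell back to 'gloo'. Below is a minimal standalone sketch of the lookup before and after the change, using only the mapping entries visible in the diff; the surrounding wrapper code is illustrative, not timm's API.

# Sketch of the backend resolution before/after this commit (illustrative only).
dist_backends = {
    "hpu": "hccl",
    "cuda": "nccl",
}

device = "cuda:0"

# Before: the full device string is the key, so any "cuda:N" misses the
# mapping and silently falls back to 'gloo'.
old_backend = dist_backends.get(device, 'gloo')              # -> 'gloo'

# After: split off the device type ahead of a possible ':index' and
# look that up instead.
device_type, *device_idx = device.split(':', maxsplit=1)     # -> 'cuda', ['0']
new_backend = dist_backends.get(device_type, 'gloo')         # -> 'nccl'

print(old_backend, new_backend)  # gloo nccl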

@@ -108,6 +108,8 @@ def init_distributed_device_so(
     world_size = 1
     global_rank = 0
     local_rank = 0
+    device_type, *device_idx = device.split(':', maxsplit=1)
+
     if dist_backend is None:
         # FIXME: verify that ROCm transform nccl to rccl
         dist_backends = {
@@ -115,7 +117,7 @@ def init_distributed_device_so(
"hpu": "hccl",
"cuda": "nccl",
}
dist_backend = dist_backends.get(device, 'gloo')
dist_backend = dist_backends.get(device_type, 'gloo')
dist_url = dist_url or 'env://'
# TBD, support horovod?
@@ -155,18 +157,15 @@ def init_distributed_device_so(
         global_rank = torch.distributed.get_rank()
         distributed = True

-    if 'cuda' in device:
+    if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'

     if distributed and device != 'cpu':
-        device, *device_idx = device.split(':', maxsplit=1)
-
         # Ignore manually specified device index in distributed mode and
         # override with resolved local rank, fewer headaches in most setups.
         if device_idx:
             _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
-
-        device = f'{device}:{local_rank}'
+        device = f'{device_type}:{local_rank}'

     if device.startswith('cuda:'):
         torch.cuda.set_device(device)
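
A small usage sketch of the resolved behavior in the distributed branch above; the helper name _resolve_device and its arguments are hypothetical stand-ins for the surrounding function state, not actual timm API.

import logging

_logger = logging.getLogger(__name__)

def _resolve_device(device: str, local_rank: int, distributed: bool) -> str:
    # Mirrors the logic in the diff: split the type ahead of a possible ':index'.
    device_type, *device_idx = device.split(':', maxsplit=1)
    if distributed and device != 'cpu':
        # Any manually specified index is ignored in distributed mode and
        # replaced with the resolved local rank.
        if device_idx:
            _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
        device = f'{device_type}:{local_rank}'
    return device

print(_resolve_device('cuda:3', local_rank=1, distributed=True))   # cuda:1
print(_resolve_device('cuda', local_rank=0, distributed=False))    # cuda

Dropping any user-supplied index and substituting the resolved local rank keeps per-process device assignment consistent across launchers, which is the "fewer headaches in most setups" the in-code comment refers to.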