Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00

Merge pull request #2181 from huggingface/Delaunay-dist-backend: Delaunay dist backend flag

Commit 27fd2f35d3
@@ -108,9 +108,16 @@ def init_distributed_device_so(
     world_size = 1
     global_rank = 0
     local_rank = 0
+    device_type, *device_idx = device.split(':', maxsplit=1)
+
     if dist_backend is None:
-        # FIXME sane defaults for other device backends?
-        dist_backend = 'nccl' if 'cuda' in device else 'gloo'
+        # FIXME: verify that ROCm transform nccl to rccl
+        dist_backends = {
+            "xpu": "ccl",
+            "hpu": "hccl",
+            "cuda": "nccl",
+        }
+        dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'
 
     # TBD, support horovod?
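The new default-backend selection in this hunk can be exercised on its own. Below is a minimal sketch assuming only a device string is available; the helper name resolve_dist_backend is illustrative and not part of timm:

from typing import Optional


def resolve_dist_backend(device: str, dist_backend: Optional[str] = None) -> str:
    # Mirrors the hunk above: split off any explicit index ('cuda:1' -> 'cuda'),
    # then map the device type to a torch.distributed backend, defaulting to gloo.
    device_type, *_ = device.split(':', maxsplit=1)
    if dist_backend is not None:
        return dist_backend  # an explicitly requested backend wins
    dist_backends = {
        "xpu": "ccl",    # Intel oneCCL
        "hpu": "hccl",   # Habana
        "cuda": "nccl",  # NVIDIA; per the FIXME above, ROCm is expected to map nccl -> rccl
    }
    return dist_backends.get(device_type, 'gloo')


assert resolve_dist_backend('cuda:0') == 'nccl'
assert resolve_dist_backend('xpu') == 'ccl'
assert resolve_dist_backend('cpu') == 'gloo'
assert resolve_dist_backend('cuda', dist_backend='gloo') == 'gloo'

In the full function, the resolved backend string is what init_distributed_device_so ultimately hands to torch.distributed.init_process_group, so the table only supplies a per-device-type default when no backend flag is given.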
@@ -150,18 +157,15 @@ def init_distributed_device_so(
         global_rank = torch.distributed.get_rank()
         distributed = True
 
-    if 'cuda' in device:
+    if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
 
     if distributed and device != 'cpu':
-        device, *device_idx = device.split(':', maxsplit=1)
-
         # Ignore manually specified device index in distributed mode and
         # override with resolved local rank, fewer headaches in most setups.
         if device_idx:
             _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
-
-        device = f'{device}:{local_rank}'
+        device = f'{device_type}:{local_rank}'
 
     if device.startswith('cuda:'):
         torch.cuda.set_device(device)