mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Tweak dist_backend to use device_type (before possible :)
This commit is contained in:
parent
6ca92570f7
commit
e57625e814
@ -108,6 +108,8 @@ def init_distributed_device_so(
|
|||||||
world_size = 1
|
world_size = 1
|
||||||
global_rank = 0
|
global_rank = 0
|
||||||
local_rank = 0
|
local_rank = 0
|
||||||
|
device_type, *device_idx = device.split(':', maxsplit=1)
|
||||||
|
|
||||||
if dist_backend is None:
|
if dist_backend is None:
|
||||||
# FIXME: verify that ROCm transform nccl to rccl
|
# FIXME: verify that ROCm transform nccl to rccl
|
||||||
dist_backends = {
|
dist_backends = {
|
||||||
@ -115,7 +117,7 @@ def init_distributed_device_so(
|
|||||||
"hpu": "hccl",
|
"hpu": "hccl",
|
||||||
"cuda": "nccl",
|
"cuda": "nccl",
|
||||||
}
|
}
|
||||||
dist_backend = dist_backends.get(device, 'gloo')
|
dist_backend = dist_backends.get(device_type, 'gloo')
|
||||||
dist_url = dist_url or 'env://'
|
dist_url = dist_url or 'env://'
|
||||||
|
|
||||||
# TBD, support horovod?
|
# TBD, support horovod?
|
||||||
@ -155,18 +157,15 @@ def init_distributed_device_so(
|
|||||||
global_rank = torch.distributed.get_rank()
|
global_rank = torch.distributed.get_rank()
|
||||||
distributed = True
|
distributed = True
|
||||||
|
|
||||||
if 'cuda' in device:
|
if device_type == 'cuda':
|
||||||
assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
|
assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
|
||||||
|
|
||||||
if distributed and device != 'cpu':
|
if distributed and device != 'cpu':
|
||||||
device, *device_idx = device.split(':', maxsplit=1)
|
|
||||||
|
|
||||||
# Ignore manually specified device index in distributed mode and
|
# Ignore manually specified device index in distributed mode and
|
||||||
# override with resolved local rank, fewer headaches in most setups.
|
# override with resolved local rank, fewer headaches in most setups.
|
||||||
if device_idx:
|
if device_idx:
|
||||||
_logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
|
_logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
|
||||||
|
device = f'{device_type}:{local_rank}'
|
||||||
device = f'{device}:{local_rank}'
|
|
||||||
|
|
||||||
if device.startswith('cuda:'):
|
if device.startswith('cuda:'):
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user