mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix device idx split
This commit is contained in:
parent
59239d9df5
commit
47c9bc4dc6
@@ -92,7 +92,7 @@ def init_distributed_device(args):
|
|||||||
args.world_size = result['world_size']
|
args.world_size = result['world_size']
|
||||||
args.rank = result['global_rank']
|
args.rank = result['global_rank']
|
||||||
args.local_rank = result['local_rank']
|
args.local_rank = result['local_rank']
|
||||||
args.distributed = args.world_size > 1
|
args.distributed = result['distributed']
|
||||||
device = torch.device(args.device)
|
device = torch.device(args.device)
|
||||||
return device
|
return device
|
||||||
|
|
||||||
@@ -154,12 +154,12 @@ def init_distributed_device_so(
|
|||||||
assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
|
assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
|
||||||
|
|
||||||
if distributed and device != 'cpu':
|
if distributed and device != 'cpu':
|
||||||
device, device_idx = device.split(':', maxsplit=1)
|
device, *device_idx = device.split(':', maxsplit=1)
|
||||||
|
|
||||||
# Ignore manually specified device index in distributed mode and
|
# Ignore manually specified device index in distributed mode and
|
||||||
# override with resolved local rank, fewer headaches in most setups.
|
# override with resolved local rank, fewer headaches in most setups.
|
||||||
if device_idx:
|
if device_idx:
|
||||||
_logger.warning(f'device index {device_idx} removed from specified ({device}).')
|
_logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
|
||||||
|
|
||||||
device = f'{device}:{local_rank}'
|
device = f'{device}:{local_rank}'
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user