From 47c9bc4dc675daeaea337ded67fd02ecf415f5bf Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sat, 10 Feb 2024 21:41:14 -0800
Subject: [PATCH] Fix device idx split

---
 timm/utils/distributed.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py
index 92b8a6b8..286e8ba4 100644
--- a/timm/utils/distributed.py
+++ b/timm/utils/distributed.py
@@ -92,7 +92,7 @@ def init_distributed_device(args):
     args.world_size = result['world_size']
     args.rank = result['global_rank']
     args.local_rank = result['local_rank']
-    args.distributed = args.world_size > 1
+    args.distributed = result['distributed']
     device = torch.device(args.device)
     return device
 
@@ -154,12 +154,12 @@ def init_distributed_device_so(
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
 
     if distributed and device != 'cpu':
-        device, device_idx = device.split(':', maxsplit=1)
+        device, *device_idx = device.split(':', maxsplit=1)
 
         # Ignore manually specified device index in distributed mode and
         # override with resolved local rank, fewer headaches in most setups.
         if device_idx:
-            _logger.warning(f'device index {device_idx} removed from specified ({device}).')
+            _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
 
         device = f'{device}:{local_rank}'
 
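
Reviewer's note (not part of the patch): the second hunk replaces two-target
tuple unpacking with starred unpacking so that a device string without an
explicit index no longer raises. A minimal standalone sketch of the behavior
difference, using 'cuda' and 'cuda:1' as assumed example inputs:

    for spec in ('cuda', 'cuda:1'):
        # A starred target absorbs zero or more trailing elements, so a
        # missing ':<idx>' yields an empty list instead of an exception.
        device, *device_idx = spec.split(':', maxsplit=1)
        print(device, device_idx)   # 'cuda' []  /  'cuda' ['1']

    # The old two-target form fails on an index-less spec:
    #   device, device_idx = 'cuda'.split(':', maxsplit=1)
    #   ValueError: not enough values to unpack (expected 2, got 1)

Since device_idx is now a (possibly empty) list, the warning in the same hunk
indexes it as device_idx[0] to log the stripped index.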