diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py
index ee9a358c..95655f2c 100644
--- a/timm/utils/distributed.py
+++ b/timm/utils/distributed.py
@@ -2,18 +2,17 @@
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
+import logging
 import os
+from typing import Optional
 
 import torch
 from torch import distributed as dist
 
-try:
-    import horovod.torch as hvd
-except ImportError:
-    hvd = None
-
 from .model import unwrap_model
 
+_logger = logging.getLogger(__name__)
+
 
 def reduce_tensor(tensor, n):
     rt = tensor.clone()
@@ -84,9 +83,38 @@ def init_distributed_device(args):
     args.world_size = 1
     args.rank = 0  # global rank
     args.local_rank = 0
+    result = init_distributed_device_so(
+        device=getattr(args, 'device', 'cuda'),
+        dist_backend=getattr(args, 'dist_backend', None),
+        dist_url=getattr(args, 'dist_url', None),
+    )
+    args.device = result['device']
+    args.world_size = result['world_size']
+    args.rank = result['global_rank']
+    args.local_rank = result['local_rank']
+    device = torch.device(args.device)
+    return device
+
+
+def init_distributed_device_so(
+        device: str = 'cuda',
+        dist_backend: Optional[str] = None,
+        dist_url: Optional[str] = None,
+):
+    # Distributed training = training on more than one GPU.
+    # Works in both single and multi-node scenarios.
+    distributed = False
+    world_size = 1
+    global_rank = 0
+    local_rank = 0
+    if dist_backend is None:
+        # FIXME sane defaults for other device backends?
+        dist_backend = 'nccl' if 'cuda' in device else 'gloo'
+    dist_url = dist_url or 'env://'
 
     # TBD, support horovod?
     # if args.horovod:
+    #     import horovod.torch as hvd
     #     assert hvd is not None, "Horovod is not installed"
     #     hvd.init()
     #     args.local_rank = int(hvd.local_rank())
@@ -96,42 +124,51 @@ def init_distributed_device(args):
     #     os.environ['LOCAL_RANK'] = str(args.local_rank)
     #     os.environ['RANK'] = str(args.rank)
     #     os.environ['WORLD_SIZE'] = str(args.world_size)
-    dist_backend = getattr(args, 'dist_backend', 'nccl')
-    dist_url = getattr(args, 'dist_url', 'env://')
     if is_distributed_env():
         if 'SLURM_PROCID' in os.environ:
             # DDP via SLURM
-            args.local_rank, args.rank, args.world_size = world_info_from_env()
+            local_rank, global_rank, world_size = world_info_from_env()
             # SLURM var -> torch.distributed vars in case needed
-            os.environ['LOCAL_RANK'] = str(args.local_rank)
-            os.environ['RANK'] = str(args.rank)
-            os.environ['WORLD_SIZE'] = str(args.world_size)
+            os.environ['LOCAL_RANK'] = str(local_rank)
+            os.environ['RANK'] = str(global_rank)
+            os.environ['WORLD_SIZE'] = str(world_size)
             torch.distributed.init_process_group(
                 backend=dist_backend,
                 init_method=dist_url,
-                world_size=args.world_size,
-                rank=args.rank,
+                world_size=world_size,
+                rank=global_rank,
             )
         else:
             # DDP via torchrun, torch.distributed.launch
-            args.local_rank, _, _ = world_info_from_env()
+            local_rank, _, _ = world_info_from_env()
             torch.distributed.init_process_group(
                 backend=dist_backend,
                 init_method=dist_url,
             )
-        args.world_size = torch.distributed.get_world_size()
-        args.rank = torch.distributed.get_rank()
-        args.distributed = True
+        world_size = torch.distributed.get_world_size()
+        global_rank = torch.distributed.get_rank()
+        distributed = True
 
-    if torch.cuda.is_available():
-        if args.distributed:
-            device = 'cuda:%d' % args.local_rank
-        else:
-            device = 'cuda:0'
+    if 'cuda' in device:
+        assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
+
+    if distributed and device != 'cpu':
+        device, *device_idx = device.split(':', maxsplit=1)
+
+        # Ignore manually specified device index in distributed mode and
+        # override with resolved local rank, fewer headaches in most setups.
+        if device_idx:
+            _logger.warning(f'device index {device_idx} removed from specified ({device}).')
+
+        device = f'{device}:{local_rank}'
+
+    if device.startswith('cuda:'):
         torch.cuda.set_device(device)
-    else:
-        device = 'cpu'
 
-    args.device = device
-    device = torch.device(device)
-    return device
+    return dict(
+        device=device,
+        global_rank=global_rank,
+        local_rank=local_rank,
+        world_size=world_size,
+        distributed=distributed,
+    )
diff --git a/train.py b/train.py
index e3b3e037..a773d855 100755
--- a/train.py
+++ b/train.py
@@ -15,6 +15,7 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
 Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
 """
 import argparse
+import importlib
 import json
 import logging
 import os
@@ -168,6 +169,24 @@ scripting_group.add_argument('--torchscript', dest='torchscript', action='store_
 scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
                              help="Enable compilation w/ specified backend (default: inductor).")
 
+# Device & distributed
+group = parser.add_argument_group('Device parameters')
+group.add_argument('--device', default='cuda', type=str,
+                   help="Device (accelerator) to use.")
+group.add_argument('--amp', action='store_true', default=False,
+                   help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
+group.add_argument('--amp-dtype', default='float16', type=str,
+                   help='lower precision AMP dtype (default: float16)')
+group.add_argument('--amp-impl', default='native', type=str,
+                   help='AMP impl to use, "native" or "apex" (default: native)')
+group.add_argument('--no-ddp-bb', action='store_true', default=False,
+                   help='Force broadcast buffers for native DDP to off.')
+group.add_argument('--synchronize-step', action='store_true', default=False,
+                   help='torch.cuda.synchronize() end of each step')
+group.add_argument("--local_rank", default=0, type=int)
+parser.add_argument('--device-modules', default=None, type=str, nargs='+',
+                    help="Python imports for device backend modules.")
+
 # Optimizer parameters
 group = parser.add_argument_group('Optimizer parameters')
 group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
@@ -352,16 +371,6 @@ group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many training processes to use (default: 4)')
 group.add_argument('--save-images', action='store_true', default=False,
                    help='save images of input bathes every log interval for debugging')
-group.add_argument('--amp', action='store_true', default=False,
-                   help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
-group.add_argument('--amp-dtype', default='float16', type=str,
-                   help='lower precision AMP dtype (default: float16)')
-group.add_argument('--amp-impl', default='native', type=str,
-                   help='AMP impl to use, "native" or "apex" (default: native)')
-group.add_argument('--no-ddp-bb', action='store_true', default=False,
-                   help='Force broadcast buffers for native DDP to off.')
-group.add_argument('--synchronize-step', action='store_true', default=False,
-                   help='torch.cuda.synchronize() end of each step')
 group.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
 group.add_argument('--no-prefetcher', action='store_true', default=False,
@@ -374,7 +383,6 @@ group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METR
                    help='Best metric (default: "top1"')
 group.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
-group.add_argument("--local_rank", default=0, type=int)
 group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                    help='use the multi-epochs-loader to save time at the beginning of every epoch')
 group.add_argument('--log-wandb', action='store_true', default=False,
@@ -402,6 +410,10 @@ def main():
     utils.setup_default_logging()
     args, args_text = _parse_args()
 
+    if args.device_modules:
+        for module in args.device_modules:
+            importlib.import_module(module)
+
     if torch.cuda.is_available():
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.benchmark = True
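
For context, a minimal usage sketch of the refactored helper, assuming only what the diff above shows: init_distributed_device_so() lives in timm/utils/distributed.py, takes device / dist_backend / dist_url, and returns a plain dict (device, global_rank, local_rank, world_size, distributed) instead of mutating an argparse namespace.

    import torch
    from timm.utils.distributed import init_distributed_device_so

    # dist_backend/dist_url are optional; per the diff, they default to
    # 'nccl' (for CUDA devices) and 'env://'. The helper asserts that CUDA
    # is actually available when a 'cuda' device is requested.
    result = init_distributed_device_so(device='cuda')
    device = torch.device(result['device'])

    if result['distributed']:
        # The process group is initialized; ranks come from the returned
        # dict rather than from args.rank / args.local_rank.
        print(f"rank {result['global_rank']}/{result['world_size']} on {device}")

Separately, the new --device-modules flag in train.py imports each listed module via importlib.import_module() before device setup, so out-of-tree device backend modules can register themselves first.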