Improving device flexibility in train. Fix #2081

Ross Wightman 2024-01-20 15:10:20 -08:00
parent 53a4888328
commit a48ab818f5
2 changed files with 87 additions and 38 deletions

timm/utils/distributed.py

@@ -2,18 +2,17 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
+import logging
 import os
+from typing import Optional
 import torch
 from torch import distributed as dist
-try:
-    import horovod.torch as hvd
-except ImportError:
-    hvd = None
 from .model import unwrap_model
+_logger = logging.getLogger(__name__)
 def reduce_tensor(tensor, n):
     rt = tensor.clone()
@@ -84,9 +83,38 @@ def init_distributed_device(args):
     args.world_size = 1
     args.rank = 0  # global rank
     args.local_rank = 0
+    result = init_distributed_device_so(
+        device=getattr(args, 'device', 'cuda'),
+        dist_backend=getattr(args, 'dist_backend', None),
+        dist_url=getattr(args, 'dist_url', None),
+    )
+    args.device = result['device']
+    args.world_size = result['world_size']
+    args.rank = result['global_rank']
+    args.local_rank = result['local_rank']
+    device = torch.device(args.device)
+    return device
+
+
+def init_distributed_device_so(
+        device: str = 'cuda',
+        dist_backend: Optional[str] = None,
+        dist_url: Optional[str] = None,
+):
+    # Distributed training = training on more than one GPU.
+    # Works in both single and multi-node scenarios.
+    distributed = False
+    world_size = 1
+    global_rank = 0
+    local_rank = 0
+    if dist_backend is None:
+        # FIXME sane defaults for other device backends?
+        dist_backend = 'nccl' if 'cuda' in device else 'gloo'
+    dist_url = dist_url or 'env://'
     # TBD, support horovod?
     # if args.horovod:
+    #     import horovod.torch as hvd
     #     assert hvd is not None, "Horovod is not installed"
     #     hvd.init()
     #     args.local_rank = int(hvd.local_rank())
@@ -96,42 +124,51 @@ def init_distributed_device(args):
     #     os.environ['LOCAL_RANK'] = str(args.local_rank)
     #     os.environ['RANK'] = str(args.rank)
     #     os.environ['WORLD_SIZE'] = str(args.world_size)
-    dist_backend = getattr(args, 'dist_backend', 'nccl')
-    dist_url = getattr(args, 'dist_url', 'env://')
     if is_distributed_env():
         if 'SLURM_PROCID' in os.environ:
             # DDP via SLURM
-            args.local_rank, args.rank, args.world_size = world_info_from_env()
+            local_rank, global_rank, world_size = world_info_from_env()
             # SLURM var -> torch.distributed vars in case needed
-            os.environ['LOCAL_RANK'] = str(args.local_rank)
-            os.environ['RANK'] = str(args.rank)
-            os.environ['WORLD_SIZE'] = str(args.world_size)
+            os.environ['LOCAL_RANK'] = str(local_rank)
+            os.environ['RANK'] = str(global_rank)
+            os.environ['WORLD_SIZE'] = str(world_size)
             torch.distributed.init_process_group(
                 backend=dist_backend,
                 init_method=dist_url,
-                world_size=args.world_size,
-                rank=args.rank,
+                world_size=world_size,
+                rank=global_rank,
             )
         else:
             # DDP via torchrun, torch.distributed.launch
-            args.local_rank, _, _ = world_info_from_env()
+            local_rank, _, _ = world_info_from_env()
             torch.distributed.init_process_group(
                 backend=dist_backend,
                 init_method=dist_url,
             )
-            args.world_size = torch.distributed.get_world_size()
-            args.rank = torch.distributed.get_rank()
-        args.distributed = True
+            world_size = torch.distributed.get_world_size()
+            global_rank = torch.distributed.get_rank()
+        distributed = True
-    if torch.cuda.is_available():
-        if args.distributed:
-            device = 'cuda:%d' % args.local_rank
-        else:
-            device = 'cuda:0'
+    if 'cuda' in device:
+        assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
+
+    if distributed and device != 'cpu':
+        device, *device_idx = device.split(':', maxsplit=1)
+        # Ignore manually specified device index in distributed mode and
+        # override with resolved local rank, fewer headaches in most setups.
+        if device_idx:
+            _logger.warning(f'device index {device_idx} removed from specified ({device}).')
+        device = f'{device}:{local_rank}'
+
+    if device.startswith('cuda:'):
         torch.cuda.set_device(device)
-    else:
-        device = 'cpu'
-    args.device = device
-    device = torch.device(device)
-    return device
+
+    return dict(
+        device=device,
+        global_rank=global_rank,
+        local_rank=local_rank,
+        world_size=world_size,
+        distributed=distributed,
+    )
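
For reference, a minimal usage sketch of the new helper (not taken from the commit), assuming init_distributed_device_so is importable from timm.utils.distributed and the process was launched with torchrun so the usual RANK / WORLD_SIZE / LOCAL_RANK environment variables are set:

import torch
from timm.utils.distributed import init_distributed_device_so

# Resolve device + distributed state; backend defaults to 'nccl' for cuda, 'gloo' otherwise.
dist_info = init_distributed_device_so(device='cuda')
device = torch.device(dist_info['device'])  # e.g. 'cuda:0', index taken from the local rank

model = torch.nn.Linear(8, 8).to(device)
if dist_info['distributed']:
    # Wrap with DDP only when more than one process is participating.
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[dist_info['local_rank']] if 'cuda' in dist_info['device'] else None,
    )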

train.py

@@ -15,6 +15,7 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
 Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
 """
 import argparse
+import importlib
 import json
 import logging
 import os
@@ -168,6 +169,24 @@ scripting_group.add_argument('--torchscript', dest='torchscript', action='store_
 scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
                              help="Enable compilation w/ specified backend (default: inductor).")
+
+# Device & distributed
+group = parser.add_argument_group('Device parameters')
+group.add_argument('--device', default='cuda', type=str,
+                   help="Device (accelerator) to use.")
+group.add_argument('--amp', action='store_true', default=False,
+                   help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
+group.add_argument('--amp-dtype', default='float16', type=str,
+                   help='lower precision AMP dtype (default: float16)')
+group.add_argument('--amp-impl', default='native', type=str,
+                   help='AMP impl to use, "native" or "apex" (default: native)')
+group.add_argument('--no-ddp-bb', action='store_true', default=False,
+                   help='Force broadcast buffers for native DDP to off.')
+group.add_argument('--synchronize-step', action='store_true', default=False,
+                   help='torch.cuda.synchronize() end of each step')
+group.add_argument("--local_rank", default=0, type=int)
+parser.add_argument('--device-modules', default=None, type=str, nargs='+',
+                    help="Python imports for device backend modules.")
+
 # Optimizer parameters
 group = parser.add_argument_group('Optimizer parameters')
 group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
@@ -350,16 +369,6 @@ group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many training processes to use (default: 4)')
 group.add_argument('--save-images', action='store_true', default=False,
                    help='save images of input bathes every log interval for debugging')
-group.add_argument('--amp', action='store_true', default=False,
-                   help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
-group.add_argument('--amp-dtype', default='float16', type=str,
-                   help='lower precision AMP dtype (default: float16)')
-group.add_argument('--amp-impl', default='native', type=str,
-                   help='AMP impl to use, "native" or "apex" (default: native)')
-group.add_argument('--no-ddp-bb', action='store_true', default=False,
-                   help='Force broadcast buffers for native DDP to off.')
-group.add_argument('--synchronize-step', action='store_true', default=False,
-                   help='torch.cuda.synchronize() end of each step')
 group.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
 group.add_argument('--no-prefetcher', action='store_true', default=False,
@@ -372,7 +381,6 @@ group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METR
                    help='Best metric (default: "top1"')
 group.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
-group.add_argument("--local_rank", default=0, type=int)
 group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                    help='use the multi-epochs-loader to save time at the beginning of every epoch')
 group.add_argument('--log-wandb', action='store_true', default=False,
@@ -400,6 +408,10 @@ def main():
     utils.setup_default_logging()
     args, args_text = _parse_args()
+
+    if args.device_modules:
+        for module in args.device_modules:
+            importlib.import_module(module)
+
     if torch.cuda.is_available():
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.benchmark = True
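
For context, a minimal sketch (not the actual train.py wiring) of how the new arguments are meant to interact at startup: the modules named by --device-modules are imported first, so an out-of-tree accelerator backend can register its torch device before --device is resolved by utils.init_distributed_device():

import argparse
import importlib

from timm import utils

parser = argparse.ArgumentParser()
parser.add_argument('--device', default='cuda', type=str)
parser.add_argument('--device-modules', default=None, type=str, nargs='+')
args = parser.parse_args()

# 1. Import backend modules first so any non-CUDA device backend is registered
#    with torch before the device string below is resolved.
if args.device_modules:
    for module in args.device_modules:
        importlib.import_module(module)

# 2. Resolve device and distributed state; this sets args.rank, args.world_size,
#    args.local_rank and args.device, and returns a torch.device.
device = utils.init_distributed_device(args)
print(f'rank {args.rank} / world size {args.world_size} on {device}')

In the real script the same ordering appears in main() above, with the arguments coming from _parse_args().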