Merge branch 'device_flex' into mesa_ema

commit 5e4a4b2adc

@@ -2,18 +2,17 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
+import logging
 import os
+from typing import Optional

 import torch
 from torch import distributed as dist

-try:
-    import horovod.torch as hvd
-except ImportError:
-    hvd = None

 from .model import unwrap_model

+_logger = logging.getLogger(__name__)


 def reduce_tensor(tensor, n):
     rt = tensor.clone()
@@ -84,9 +83,38 @@ def init_distributed_device(args):
     args.world_size = 1
     args.rank = 0  # global rank
     args.local_rank = 0
+    result = init_distributed_device_so(
+        device=getattr(args, 'device', 'cuda'),
+        dist_backend=getattr(args, 'dist_backend', None),
+        dist_url=getattr(args, 'dist_url', None),
+    )
+    args.device = result['device']
+    args.world_size = result['world_size']
+    args.rank = result['global_rank']
+    args.local_rank = result['local_rank']
+    device = torch.device(args.device)
+    return device
+
+
+def init_distributed_device_so(
+        device: str = 'cuda',
+        dist_backend: Optional[str] = None,
+        dist_url: Optional[str] = None,
+):
+    # Distributed training = training on more than one GPU.
+    # Works in both single and multi-node scenarios.
+    distributed = False
+    world_size = 1
+    global_rank = 0
+    local_rank = 0
+    if dist_backend is None:
+        # FIXME sane defaults for other device backends?
+        dist_backend = 'nccl' if 'cuda' in device else 'gloo'
+    dist_url = dist_url or 'env://'
+
     # TBD, support horovod?
     # if args.horovod:
     #     import horovod.torch as hvd
     #     assert hvd is not None, "Horovod is not installed"
     #     hvd.init()
     #     args.local_rank = int(hvd.local_rank())
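[For orientation, not part of the diff: a minimal sketch of how the rewritten wrapper behaves when no distributed environment variables are set. The getattr defaults mean a bare argparse.Namespace is enough; unset attributes fall back to device='cuda' and an auto-selected backend/URL. CPU is used here so the sketch runs without a GPU, and it assumes init_distributed_device from the module above is in scope.]

    import argparse

    args = argparse.Namespace(device='cpu')  # no dist_backend / dist_url attributes set
    device = init_distributed_device(args)   # delegates to init_distributed_device_so()
    # args.device, args.world_size, args.rank, args.local_rank are now filled in
    print(device, args.world_size, args.rank)  # -> cpu 1 0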
@@ -96,42 +124,51 @@ def init_distributed_device(args):
     #     os.environ['LOCAL_RANK'] = str(args.local_rank)
     #     os.environ['RANK'] = str(args.rank)
     #     os.environ['WORLD_SIZE'] = str(args.world_size)
-    dist_backend = getattr(args, 'dist_backend', 'nccl')
-    dist_url = getattr(args, 'dist_url', 'env://')
     if is_distributed_env():
         if 'SLURM_PROCID' in os.environ:
             # DDP via SLURM
-            args.local_rank, args.rank, args.world_size = world_info_from_env()
+            local_rank, global_rank, world_size = world_info_from_env()
             # SLURM var -> torch.distributed vars in case needed
-            os.environ['LOCAL_RANK'] = str(args.local_rank)
-            os.environ['RANK'] = str(args.rank)
-            os.environ['WORLD_SIZE'] = str(args.world_size)
+            os.environ['LOCAL_RANK'] = str(local_rank)
+            os.environ['RANK'] = str(global_rank)
+            os.environ['WORLD_SIZE'] = str(world_size)
             torch.distributed.init_process_group(
                 backend=dist_backend,
                 init_method=dist_url,
-                world_size=args.world_size,
-                rank=args.rank,
+                world_size=world_size,
+                rank=global_rank,
             )
         else:
             # DDP via torchrun, torch.distributed.launch
-            args.local_rank, _, _ = world_info_from_env()
+            local_rank, _, _ = world_info_from_env()
             torch.distributed.init_process_group(
                 backend=dist_backend,
                 init_method=dist_url,
             )
-            args.world_size = torch.distributed.get_world_size()
-            args.rank = torch.distributed.get_rank()
-        args.distributed = True
+            world_size = torch.distributed.get_world_size()
+            global_rank = torch.distributed.get_rank()
+        distributed = True

-    if torch.cuda.is_available():
-        if args.distributed:
-            device = 'cuda:%d' % args.local_rank
-        else:
-            device = 'cuda:0'
+    if 'cuda' in device:
+        assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
+
+    if distributed and device != 'cpu':
+        device, *device_idx = device.split(':', maxsplit=1)
+
+        # Ignore manually specified device index in distributed mode and
+        # override with resolved local rank, fewer headaches in most setups.
+        if device_idx:
+            _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
+
+        device = f'{device}:{local_rank}'
+
+    if device.startswith('cuda:'):
         torch.cuda.set_device(device)
-    else:
-        device = 'cpu'

-    args.device = device
-    device = torch.device(device)
-    return device
+    return dict(
+        device=device,
+        global_rank=global_rank,
+        local_rank=local_rank,
+        world_size=world_size,
+        distributed=distributed,
+    )
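[A standalone sketch of the device-index override above, with the logic copied from the hunk and local_rank hard-coded for illustration: in distributed mode a user-supplied index such as 'cuda:3' is dropped and replaced with the resolved local rank. Note the star-unpack, so a bare 'cuda' with no index yields an empty list rather than a ValueError.]

    local_rank = 1
    device = 'cuda:3'
    dev_type, *device_idx = device.split(':', maxsplit=1)
    if device_idx:
        print(f'device index {device_idx[0]} removed from specified ({dev_type}).')
    device = f'{dev_type}:{local_rank}'
    assert device == 'cuda:1'

The helper now also returns a plain dict (device / global_rank / local_rank / world_size / distributed) instead of mutating an args namespace, which is what lets the argparse-facing wrapper earlier in the diff stay backward compatible.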
train.py (34 changed lines)
@@ -15,6 +15,7 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
 Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
 """
 import argparse
+import importlib
 import json
 import logging
 import os
@@ -168,6 +169,24 @@ scripting_group.add_argument('--torchscript', dest='torchscript', action='store_
 scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
                              help="Enable compilation w/ specified backend (default: inductor).")

+# Device & distributed
+group = parser.add_argument_group('Device parameters')
+group.add_argument('--device', default='cuda', type=str,
+                   help="Device (accelerator) to use.")
+group.add_argument('--amp', action='store_true', default=False,
+                   help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
+group.add_argument('--amp-dtype', default='float16', type=str,
+                   help='lower precision AMP dtype (default: float16)')
+group.add_argument('--amp-impl', default='native', type=str,
+                   help='AMP impl to use, "native" or "apex" (default: native)')
+group.add_argument('--no-ddp-bb', action='store_true', default=False,
+                   help='Force broadcast buffers for native DDP to off.')
+group.add_argument('--synchronize-step', action='store_true', default=False,
+                   help='torch.cuda.synchronize() end of each step')
+group.add_argument("--local_rank", default=0, type=int)
+parser.add_argument('--device-modules', default=None, type=str, nargs='+',
+                    help="Python imports for device backend modules.")
+
 # Optimizer parameters
 group = parser.add_argument_group('Optimizer parameters')
 group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
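[To illustrate the relocated flags, a self-contained sketch that re-creates just the new 'Device parameters' group; this is not train.py's full parser, and only a subset of the options from the hunk above is reproduced.]

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('Device parameters')
    group.add_argument('--device', default='cuda', type=str)
    group.add_argument('--amp', action='store_true', default=False)
    group.add_argument('--amp-dtype', default='float16', type=str)

    args = parser.parse_args(['--device', 'cuda:0', '--amp', '--amp-dtype', 'bfloat16'])
    assert args.device == 'cuda:0' and args.amp and args.amp_dtype == 'bfloat16'

Argument groups only affect --help layout, so moving --amp and friends into the device group changes no parsing behavior, only where they appear in the help text.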
@@ -352,16 +371,6 @@ group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many training processes to use (default: 4)')
 group.add_argument('--save-images', action='store_true', default=False,
                    help='save images of input batches every log interval for debugging')
-group.add_argument('--amp', action='store_true', default=False,
-                   help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
-group.add_argument('--amp-dtype', default='float16', type=str,
-                   help='lower precision AMP dtype (default: float16)')
-group.add_argument('--amp-impl', default='native', type=str,
-                   help='AMP impl to use, "native" or "apex" (default: native)')
-group.add_argument('--no-ddp-bb', action='store_true', default=False,
-                   help='Force broadcast buffers for native DDP to off.')
-group.add_argument('--synchronize-step', action='store_true', default=False,
-                   help='torch.cuda.synchronize() end of each step')
 group.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
 group.add_argument('--no-prefetcher', action='store_true', default=False,
@@ -374,7 +383,6 @@ group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METR
                    help='Best metric (default: "top1")')
 group.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
-group.add_argument("--local_rank", default=0, type=int)
 group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                    help='use the multi-epochs-loader to save time at the beginning of every epoch')
 group.add_argument('--log-wandb', action='store_true', default=False,
@@ -402,6 +410,10 @@ def main():
     utils.setup_default_logging()
     args, args_text = _parse_args()

+    if args.device_modules:
+        for module in args.device_modules:
+            importlib.import_module(module)
+
     if torch.cuda.is_available():
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.benchmark = True
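[The new --device-modules hook simply imports the named modules for their side effects before any device setup, e.g. an out-of-tree PyTorch backend that registers its device type on import. A hedged sketch follows; torch_npu (Ascend) is one real example of such a module, but whether it is installed depends on the environment.]

    import importlib

    # roughly what happens for: python train.py --device-modules torch_npu ...
    for module in ['torch_npu']:
        importlib.import_module(module)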