diff --git a/data/random_erasing.py b/data/random_erasing.py
index b352c12d..478b6253 100644
--- a/data/random_erasing.py
+++ b/data/random_erasing.py
@@ -91,15 +91,13 @@ class RandomErasingTorch:
     def __init__(
             self,
             probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
-            per_pixel=False, rand_color=False,
-            device='cuda'):
+            per_pixel=False, rand_color=False):
         self.probability = probability
         self.sl = sl
         self.sh = sh
         self.min_aspect = min_aspect
         self.per_pixel = per_pixel  # per pixel random, bounded by [pl, ph]
         self.rand_color = rand_color  # per block random, bounded by [pl, ph]
-        self.device = device
 
     def __call__(self, batch):
         batch_size, chan, img_h, img_w = batch.size()
@@ -115,15 +113,15 @@ class RandomErasingTorch:
                 h = int(round(math.sqrt(target_area * aspect_ratio)))
                 w = int(round(math.sqrt(target_area / aspect_ratio)))
                 if self.rand_color:
-                    c = torch.empty((chan, 1, 1), dtype=batch.dtype, device=self.device).normal_()
+                    c = torch.empty((chan, 1, 1), dtype=batch.dtype).cuda().normal_()
                 elif not self.per_pixel:
-                    c = torch.zeros((chan, 1, 1), dtype=batch.dtype, device=self.device)
+                    c = torch.zeros((chan, 1, 1), dtype=batch.dtype).cuda()
                 if w < img_w and h < img_h:
                     top = random.randint(0, img_h - h)
                     left = random.randint(0, img_w - w)
                     if self.per_pixel:
                         img[:, top:top + h, left:left + w] = torch.empty(
-                            (chan, h, w), dtype=batch.dtype, device=self.device).normal_()
+                            (chan, h, w), dtype=batch.dtype).cuda().normal_()
                     else:
                         img[:, top:top + h, left:left + w] = c
                     break
diff --git a/data/utils.py b/data/utils.py
index bda111f3..836df37f 100644
--- a/data/utils.py
+++ b/data/utils.py
@@ -18,25 +18,19 @@ class PrefetchLoader:
 
     def __init__(self,
                  loader,
-                 fp16=False,
                  random_erasing=0.,
                  mean=IMAGENET_DEFAULT_MEAN,
                  std=IMAGENET_DEFAULT_STD):
         self.loader = loader
-        self.fp16 = fp16
         self.random_erasing = random_erasing
         self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
         self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
         if random_erasing:
             self.random_erasing = RandomErasingTorch(
-                probability=random_erasing, per_pixel=True)
+                probability=random_erasing, per_pixel=False)
         else:
             self.random_erasing = None
 
-        if self.fp16:
-            self.mean = self.mean.half()
-            self.std = self.std.half()
-
     def __iter__(self):
         stream = torch.cuda.Stream()
         first = True
@@ -45,10 +39,7 @@ class PrefetchLoader:
             with torch.cuda.stream(stream):
                 next_input = next_input.cuda(non_blocking=True)
                 next_target = next_target.cuda(non_blocking=True)
-                if self.fp16:
-                    next_input = next_input.half()
-                else:
-                    next_input = next_input.float()
+                next_input = next_input.float()
                 next_input = next_input.sub_(self.mean).div_(self.std)
                 if self.random_erasing is not None:
                     next_input = self.random_erasing(next_input)
@@ -67,6 +58,10 @@ class PrefetchLoader:
     def __len__(self):
         return len(self.loader)
 
+    @property
+    def sampler(self):
+        return self.loader.sampler
+
 
 def create_loader(
         dataset,
@@ -78,6 +73,7 @@
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,
         num_workers=1,
+        distributed=False,
 ):
 
     if is_training:
@@ -95,11 +91,16 @@
 
     dataset.transform = transform
 
+    sampler = None
+    if distributed:
+        sampler = tdata.distributed.DistributedSampler(dataset)
+
     loader = tdata.DataLoader(
         dataset,
         batch_size=batch_size,
-        shuffle=is_training,
+        shuffle=sampler is None and is_training,
         num_workers=num_workers,
+        sampler=sampler,
         collate_fn=fast_collate if use_prefetcher else tdata.dataloader.default_collate,
     )
    if use_prefetcher:
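Note: the new `sampler` property simply forwards to the wrapped DataLoader so the training loop can reach `DistributedSampler.set_epoch()` through the prefetch wrapper (see the train.py changes below). For context, `PrefetchLoader.__iter__` is built on the CUDA stream-overlap idiom: the next batch is copied host-to-device on a side stream while the current batch is consumed on the default stream. A minimal self-contained sketch of that idiom (hypothetical `CudaPrefetcher`, not the repo's exact code):

    import torch

    class CudaPrefetcher:
        """Sketch of the stream-overlap idiom behind PrefetchLoader."""

        def __init__(self, loader):
            self.loader = loader
            self.stream = torch.cuda.Stream()  # side stream for H2D copies

        def __iter__(self):
            prev = None
            for batch in self.loader:
                with torch.cuda.stream(self.stream):
                    # non-blocking copy overlaps compute on the default stream
                    batch = tuple(t.cuda(non_blocking=True) for t in batch)
                if prev is not None:
                    yield prev
                # default stream must wait for the copy before using the batch
                torch.cuda.current_stream().wait_stream(self.stream)
                prev = batch
            if prev is not None:
                yield prev

Keep in mind that `non_blocking=True` can only truly overlap the copy when the host tensors live in pinned memory.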
diff --git a/distributed_train.sh b/distributed_train.sh
new file mode 100755
index 00000000..a1c68c0a
--- /dev/null
+++ b/distributed_train.sh
@@ -0,0 +1,5 @@
+#!/bin/bash
+NUM_PROC=$1
+shift
+python -m torch.distributed.launch --nproc_per_node=$NUM_PROC train.py "$@"
+
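Note: the helper consumes the process count as its first argument and forwards everything else to the training script. `torch.distributed.launch` then spawns NUM_PROC local workers, exports `MASTER_ADDR`, `MASTER_PORT`, `WORLD_SIZE`, and `RANK`, and appends `--local_rank` to each worker's argv, which is exactly what the new `--local_rank` argument and the `'WORLD_SIZE' in os.environ` check in train.py below rely on. A simplified sketch of what the launcher does per worker (illustrative only; the real launcher supports many more options):

    import os
    import subprocess
    import sys

    def launch(nproc, script, script_args):
        """Spawn nproc copies of script, one per GPU, wired for env:// init."""
        env = dict(os.environ,
                   MASTER_ADDR='127.0.0.1', MASTER_PORT='29500',
                   WORLD_SIZE=str(nproc))
        procs = []
        for rank in range(nproc):
            worker_env = dict(env, RANK=str(rank))
            cmd = [sys.executable, script, '--local_rank=%d' % rank] + list(script_args)
            procs.append(subprocess.Popen(cmd, env=worker_env))
        for p in procs:
            p.wait()

A typical invocation of the wrapper would look like `./distributed_train.sh 4 /data/imagenet --model resnet50` (the argument set after the process count is illustrative).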
diff --git a/train.py b/train.py
index 195dfee8..89632d39 100644
--- a/train.py
+++ b/train.py
@@ -1,19 +1,29 @@
+
 import argparse
 import time
 from collections import OrderedDict
 from datetime import datetime
 
+try:
+    from apex import amp
+    from apex.parallel import DistributedDataParallel as DDP
+    has_apex = True
+except ImportError:
+    has_apex = False
+
 from data import *
 from models import model_factory
 from utils import *
 from optim import Nadam, AdaBound
+from loss import LabelSmoothingCrossEntropy
 import scheduler
 
 import torch
-import torch.nn
+import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 import torch.utils.data as data
+import torch.distributed as dist
 import torchvision.utils
 
 torch.backends.cudnn.benchmark = True
@@ -45,6 +55,8 @@ parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                     help='manual epoch number (useful on restarts)')
 parser.add_argument('--decay-epochs', type=int, default=30, metavar='N',
                     help='epoch interval to decay LR')
+parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
+                    help='epochs to warmup LR, if scheduler supports')
 parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                     help='LR decay rate (default: 0.1)')
 parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
@@ -53,10 +65,14 @@ parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
                     help='Dropout rate (default: 0.1)')
 parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                     help='learning rate (default: 0.01)')
+parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
+                    help='warmup learning rate (default: 0.0001)')
 parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                     help='SGD momentum (default: 0.9)')
-parser.add_argument('--weight-decay', type=float, default=0.0005, metavar='M',
+parser.add_argument('--weight-decay', type=float, default=0.0001, metavar='M',
                     help='weight decay (default: 0.0001)')
+parser.add_argument('--smoothing', type=float, default=0.1, metavar='M',
+                    help='label smoothing (default: 0.1)')
 parser.add_argument('--seed', type=int, default=42, metavar='S',
                     help='random seed (default: 42)')
 parser.add_argument('--log-interval', type=int, default=50, metavar='N',
@@ -73,22 +89,51 @@ parser.add_argument('--resume', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
 parser.add_argument('--save-images', action='store_true', default=False,
                     help='save images of input batches every log interval for debugging')
+parser.add_argument('--amp', action='store_true', default=False,
+                    help='use NVIDIA amp for mixed precision training')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='path to output folder (default: none, current dir)')
+parser.add_argument("--local_rank", default=0, type=int)
 
 
 def main():
     args = parser.parse_args()
 
-    if args.output:
-        output_base = args.output
+    args.distributed = False
+    if 'WORLD_SIZE' in os.environ:
+        args.distributed = int(os.environ['WORLD_SIZE']) > 1
+    if args.distributed and args.num_gpu > 1:
+        print('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
+        args.num_gpu = 1
+
+    args.device = 'cuda:0'
+    args.world_size = 1
+    r = -1
+    if args.distributed:
+        args.device = 'cuda:%d' % args.local_rank
+        torch.cuda.set_device(args.local_rank)
+        torch.distributed.init_process_group(backend='nccl',
+                                             init_method='env://')
+        args.world_size = torch.distributed.get_world_size()
+        r = torch.distributed.get_rank()
+
+    if args.distributed:
+        print('Training in distributed mode with %d processes, 1 GPU per process. Process %d.'
+              % (args.world_size, r))
     else:
-        output_base = './output'
-    exp_name = '-'.join([
-        datetime.now().strftime("%Y%m%d-%H%M%S"),
-        args.model,
-        str(args.img_size)])
-    output_dir = get_outdir(output_base, 'train', exp_name)
+        print('Training with a single process with %d GPUs.' % args.num_gpu)
+
+    output_dir = ''
+    if args.local_rank == 0:
+        if args.output:
+            output_base = args.output
+        else:
+            output_base = './output'
+        exp_name = '-'.join([
+            datetime.now().strftime("%Y%m%d-%H%M%S"),
+            args.model,
+            str(args.img_size)])
+        output_dir = get_outdir(output_base, 'train', exp_name)
 
     batch_size = args.batch_size
     torch.manual_seed(args.seed)
@@ -103,10 +148,11 @@
         batch_size=batch_size,
         is_training=True,
         use_prefetcher=True,
-        random_erasing=0.5,
+        random_erasing=0.3,
         mean=data_mean,
         std=data_std,
         num_workers=args.workers,
+        distributed=args.distributed,
     )
 
     dataset_eval = Dataset(os.path.join(args.data, 'validation'))
@@ -120,6 +166,7 @@
         mean=data_mean,
         std=data_std,
         num_workers=args.workers,
+        distributed=args.distributed,
    )
@@ -156,28 +203,53 @@
         print("=> no checkpoint found at '{}'".format(args.resume))
         return False
 
+    if args.smoothing:
+        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
+        validate_loss_fn = nn.CrossEntropyLoss().cuda()
+    else:
+        train_loss_fn = nn.CrossEntropyLoss().cuda()
+        validate_loss_fn = train_loss_fn
+
     if args.num_gpu > 1:
-        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
+        if args.amp:
+            print('Warning: AMP does not work well with nn.DataParallel, disabling. '
+                  'Use distributed mode for multi-GPU AMP.')
+            args.amp = False
+        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
     else:
         model.cuda()
 
-    train_loss_fn = validate_loss_fn = torch.nn.CrossEntropyLoss().cuda()
-
     optimizer = create_optimizer(args, model.parameters())
     if optimizer_state is not None:
         optimizer.load_state_dict(optimizer_state)
 
-    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
-    print(num_epochs)
+    if has_apex and args.amp:
+        model, optimizer = amp.initialize(model, optimizer, opt_level='O3')
+        use_amp = True
+        print('AMP enabled')
+    else:
+        use_amp = False
+        print('AMP disabled')
 
-    saver = CheckpointSaver(checkpoint_dir=output_dir)
+    if args.distributed:
+        model = DDP(model, delay_allreduce=True)
+
+    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
+    if args.local_rank == 0:
+        print('Scheduled epochs: ', num_epochs)
+
+    saver = None
+    if output_dir:
+        saver = CheckpointSaver(checkpoint_dir=output_dir)
 
     best_loss = None
     try:
         for epoch in range(start_epoch, num_epochs):
+            if args.distributed:
+                loader_train.sampler.set_epoch(epoch)
             train_metrics = train_epoch(
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
-                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir)
+                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp)
 
             eval_metrics = validate(
                 model, loader_eval, validate_loss_fn, args)
@@ -189,16 +261,17 @@
                 epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                 write_header=best_loss is None)
 
-            # save proper checkpoint with eval metric
-            best_loss = saver.save_checkpoint({
-                'epoch': epoch + 1,
-                'arch': args.model,
-                'state_dict': model.state_dict(),
-                'optimizer': optimizer.state_dict(),
-                'args': args,
-                },
-                epoch=epoch + 1,
-                metric=eval_metrics['eval_loss'])
+            if saver is not None:
+                # save proper checkpoint with eval metric
+                best_loss = saver.save_checkpoint({
+                    'epoch': epoch + 1,
+                    'arch': args.model,
+                    'state_dict': model.state_dict(),
+                    'optimizer': optimizer.state_dict(),
+                    'args': args,
+                    },
+                    epoch=epoch + 1,
+                    metric=eval_metrics['eval_loss'])
 
     except KeyboardInterrupt:
         pass
@@ -207,7 +280,7 @@
 
 def train_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir=''):
+        lr_scheduler=None, saver=None, output_dir='', use_amp=False):
 
     batch_time_m = AverageMeter()
     data_time_m = AverageMeter()
@@ -225,10 +298,15 @@
         output = model(input)
         loss = loss_fn(output, target)
 
-        losses_m.update(loss.item(), input.size(0))
+        if not args.distributed:
+            losses_m.update(loss.item(), input.size(0))
 
         optimizer.zero_grad()
-        loss.backward()
+        if use_amp:
+            with amp.scale_loss(loss, optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            loss.backward()
         optimizer.step()
 
        torch.cuda.synchronize()
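Note: `amp.scale_loss` wraps the backward pass with loss scaling: the loss is multiplied by a scale factor so that small fp16 gradients do not underflow to zero, and the gradients are unscaled again before `optimizer.step()` consumes them. A conceptual sketch with a static scale (a hypothetical helper for illustration; Apex manages the scale dynamically, backing off when overflows are detected, and the behavior depends on the chosen opt_level):

    def scaled_backward(loss, model, optimizer, scale=2. ** 7):
        """Static loss scaling: the effect amp.scale_loss approximates."""
        (loss * scale).backward()        # scaled grads stay above fp16 underflow
        for p in model.parameters():
            if p.grad is not None:
                p.grad.div_(scale)       # unscale before the optimizer step
        optimizer.step()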
@@ -239,30 +317,36 @@
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
             lr = sum(lrl) / len(lrl)
 
-            print('Train: {} [{}/{} ({:.0f}%)] '
-                  'Loss: {loss.val:.6f} ({loss.avg:.4f}) '
-                  'Time: {batch_time.val:.3f}s, {rate:.3f}/s '
-                  '({batch_time.avg:.3f}s, {rate_avg:.3f}/s) '
-                  'LR: {lr:.4f} '
-                  'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
-                epoch,
-                batch_idx, len(loader),
-                100. * batch_idx / last_idx,
-                loss=losses_m,
-                batch_time=batch_time_m,
-                rate=input.size(0) / batch_time_m.val,
-                rate_avg=input.size(0) / batch_time_m.avg,
-                lr=lr,
-                data_time=data_time_m))
+            if args.distributed:
+                reduced_loss = reduce_tensor(loss.data, args.world_size)
+                losses_m.update(reduced_loss.item(), input.size(0))
 
-            if args.save_images:
-                torchvision.utils.save_image(
-                    input,
-                    os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
-                    padding=0,
-                    normalize=True)
+            if args.local_rank == 0:
+                print('Train: {} [{}/{} ({:.0f}%)] '
+                      'Loss: {loss.val:.6f} ({loss.avg:.4f}) '
+                      'Time: {batch_time.val:.3f}s, {rate:.3f}/s '
+                      '({batch_time.avg:.3f}s, {rate_avg:.3f}/s) '
+                      'LR: {lr:.4f} '
+                      'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
+                    epoch,
+                    batch_idx, len(loader),
+                    100. * batch_idx / last_idx,
+                    loss=losses_m,
+                    batch_time=batch_time_m,
+                    rate=input.size(0) * args.world_size / batch_time_m.val,
+                    rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
+                    lr=lr,
+                    data_time=data_time_m))
 
-        if saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0:
+                if args.save_images and output_dir:
+                    torchvision.utils.save_image(
+                        input,
+                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
+                        padding=0,
+                        normalize=True)
+
+        if args.local_rank == 0 and saver is not None and (
+                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
             save_epoch = epoch + 1 if last_batch else epoch
             saver.save_recovery({
                 'epoch': save_epoch,
@@ -309,15 +393,22 @@
             loss = loss_fn(output, target)
             prec1, prec5 = accuracy(output, target, topk=(1, 5))
 
+            if args.distributed:
+                reduced_loss = reduce_tensor(loss.data, args.world_size)
+                prec1 = reduce_tensor(prec1, args.world_size)
+                prec5 = reduce_tensor(prec5, args.world_size)
+            else:
+                reduced_loss = loss.data
+
             torch.cuda.synchronize()
 
-            losses_m.update(loss.item(), input.size(0))
+            losses_m.update(reduced_loss.item(), input.size(0))
             prec1_m.update(prec1.item(), output.size(0))
             prec5_m.update(prec5.item(), output.size(0))
 
             batch_time_m.update(time.time() - end)
             end = time.time()
-            if last_batch or batch_idx % args.log_interval == 0:
+            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                 print('Test: [{0}/{1}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                       'Loss {loss.val:.4f} ({loss.avg:.4f}) '
@@ -362,6 +453,7 @@
 
 def create_scheduler(args, optimizer):
     num_epochs = args.epochs
+    # FIXME expose cycle params of the scheduler config to arguments
     if args.sched == 'cosine':
         lr_scheduler = scheduler.CosineLRScheduler(
             optimizer,
@@ -369,8 +461,8 @@
             t_mul=1.0,
             lr_min=1e-5,
             decay_rate=args.decay_rate,
-            warmup_lr_init=1e-4,
-            warmup_t=0,
+            warmup_lr_init=args.warmup_lr,
+            warmup_t=args.warmup_epochs,
             cycle_limit=1,
             t_in_epochs=True,
         )
@@ -381,8 +473,8 @@
             t_initial=num_epochs,
             t_mul=1.0,
             lr_min=1e-5,
-            warmup_lr_init=.001,
-            warmup_t=3,
+            warmup_lr_init=args.warmup_lr,
+            warmup_t=args.warmup_epochs,
             cycle_limit=1,
             t_in_epochs=True,
         )
@@ -392,9 +484,18 @@
             optimizer,
             decay_t=args.decay_epochs,
             decay_rate=args.decay_rate,
+            warmup_lr_init=args.warmup_lr,
+            warmup_t=args.warmup_epochs,
         )
     return lr_scheduler, num_epochs
 
 
+def reduce_tensor(tensor, n):
+    rt = tensor.clone()
+    dist.all_reduce(rt, op=dist.reduce_op.SUM)
+    rt /= n
+    return rt
+
+
 if __name__ == '__main__':
    main()
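Note: the `loss` module imported at the top of train.py is not part of this diff. For reference, the standard uniform label-smoothing cross entropy it is expected to provide looks like the following sketch (assuming the usual formulation where the target class gets weight 1 - smoothing and the remaining mass is spread evenly over all classes):

    import torch.nn as nn
    import torch.nn.functional as F

    class LabelSmoothingCrossEntropy(nn.Module):
        """Sketch of the assumed loss: NLL mixed with a uniform prior."""

        def __init__(self, smoothing=0.1):
            super(LabelSmoothingCrossEntropy, self).__init__()
            self.smoothing = smoothing
            self.confidence = 1. - smoothing

        def forward(self, x, target):
            logprobs = F.log_softmax(x, dim=-1)
            # NLL of the true class, per sample
            nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
            # mean log-prob over all classes = loss against the uniform prior
            smooth_loss = -logprobs.mean(dim=-1)
            loss = self.confidence * nll_loss + self.smoothing * smooth_loss
            return loss.mean()

Also worth noting: `reduce_tensor` divides the all-reduced sum by the world size, so the value logged on rank 0 is the average across processes and stays comparable to single-process runs.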
diff --git a/validate.py b/validate.py
index f1b6df1b..9d21c873 100644
--- a/validate.py
+++ b/validate.py
@@ -9,6 +9,7 @@ import torch
 import torch.backends.cudnn as cudnn
 import torch.nn as nn
 import torch.nn.parallel
+from collections import OrderedDict
 
 from models import create_model
 from data import Dataset, create_loader, get_model_meanstd
@@ -60,7 +61,14 @@
         print("=> loading checkpoint '{}'".format(args.checkpoint))
         checkpoint = torch.load(args.checkpoint)
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
-            model.load_state_dict(checkpoint['state_dict'])
+            new_state_dict = OrderedDict()
+            for k, v in checkpoint['state_dict'].items():
+                if k.startswith('module'):
+                    name = k[7:]  # remove `module.`
+                else:
+                    name = k
+                new_state_dict[name] = v
+            model.load_state_dict(new_state_dict)
         else:
             model.load_state_dict(checkpoint)
        print("=> loaded checkpoint '{}'".format(args.checkpoint))
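Note: the checkpoint-key rewrite above exists because `nn.DataParallel` (and DDP) register the wrapped network under a `module` attribute, so every key in their `state_dict()` gains a `module.` prefix, which is 7 characters, hence `k[7:]`. A quick demonstration:

    import torch.nn as nn

    net = nn.Linear(2, 2)
    wrapped = nn.DataParallel(net)
    print(list(net.state_dict()))      # ['weight', 'bias']
    print(list(wrapped.state_dict()))  # ['module.weight', 'module.bias']

An alternative is to save `model.module.state_dict()` at checkpoint time, but stripping the prefix at load time keeps checkpoints from existing multi-GPU runs usable.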