import logging
import os
from argparse import ArgumentParser
from collections import OrderedDict

import resnet_cifar
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms

from mmcv import Config
from mmcv.runner import DistSamplerSeedHook, Runner


def accuracy(output, target, topk=(1, )):
    """Computes the precision@k for the specified values of k."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # Use reshape() rather than view(): `correct` may be
            # non-contiguous because `pred` was transposed above, and
            # view() raises on non-contiguous tensors.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def batch_processor(model, data, train_mode):
    """Compute the loss and accuracies for one batch.

    The images stay on the CPU here; DataParallel/DistributedDataParallel
    move them to the right device inside the forward call.
    """
    img, label = data
    label = label.cuda(non_blocking=True)
    pred = model(img)
    loss = F.cross_entropy(pred, label)
    acc_top1, acc_top5 = accuracy(pred, label, topk=(1, 5))
    log_vars = OrderedDict()
    log_vars['loss'] = loss.item()
    log_vars['acc_top1'] = acc_top1.item()
    log_vars['acc_top5'] = acc_top5.item()
    outputs = dict(loss=loss, log_vars=log_vars, num_samples=img.size(0))
    return outputs


def get_logger(log_level):
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
    logger = logging.getLogger()
    return logger


def init_dist(backend='nccl', **kwargs):
    """Initialize the process group; each process owns one GPU on its node."""
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def parse_args():
    parser = ArgumentParser(description='Train CIFAR-10 classification')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    return parser.parse_args()
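# ---------------------------------------------------------------------------
# The config file loaded via Config.fromfile() is a plain Python module.
# The sketch below is a hypothetical example, not an official mmcv config:
# the field names are inferred from the cfg.* accesses in main() and the
# values (model name, mean/std, schedule) are placeholder assumptions.
#
#     model = 'resnet18'                    # attribute looked up in resnet_cifar
#     gpus = [0]                            # only used in non-distributed mode
#     data_root = './data'
#     mean = [0.4914, 0.4822, 0.4465]       # commonly used CIFAR-10 statistics
#     std = [0.2470, 0.2435, 0.2616]
#     batch_size = 128                      # global; split across processes
#     data_workers = 2
#     optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=5e-4)
#     optimizer_config = dict(grad_clip=None)
#     lr_config = dict(policy='step', step=[100, 150])
#     checkpoint_config = dict(interval=10)
#     log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
#     work_dir = './work_dir'
#     log_level = 'INFO'
#     workflow = [('train', 1), ('val', 1)]
#     total_epochs = 200
#     dist_params = dict(backend='nccl')
# ---------------------------------------------------------------------------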
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)

    logger = get_logger(cfg.log_level)

    # Init the distributed environment if necessary. The flag is named
    # `distributed` so it does not shadow the `torch.distributed as dist`
    # import above.
    if args.launcher == 'none':
        distributed = False
        logger.info('Disabled distributed training.')
    else:
        distributed = True
        init_dist(**cfg.dist_params)
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        if rank != 0:
            # Only rank 0 prints full logs; other ranks report errors only.
            logger.setLevel('ERROR')
        logger.info('Enabled distributed training.')

    # Build datasets and dataloaders.
    normalize = transforms.Normalize(mean=cfg.mean, std=cfg.std)
    train_dataset = datasets.CIFAR10(
        root=cfg.data_root,
        train=True,
        transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    val_dataset = datasets.CIFAR10(
        root=cfg.data_root,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    if distributed:
        # cfg.batch_size is the global batch size; split it evenly
        # across processes.
        num_workers = cfg.data_workers
        assert cfg.batch_size % world_size == 0
        batch_size = cfg.batch_size // world_size
        train_sampler = DistributedSampler(train_dataset, world_size, rank)
        val_sampler = DistributedSampler(val_dataset, world_size, rank)
        # The sampler handles shuffling, so the loader must not.
        shuffle = False
    else:
        num_workers = cfg.data_workers * len(cfg.gpus)
        batch_size = cfg.batch_size
        train_sampler = None
        val_sampler = None
        shuffle = True
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=train_sampler,
        num_workers=num_workers)
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        sampler=val_sampler,
        num_workers=num_workers)

    # Build the model.
    model = getattr(resnet_cifar, cfg.model)()
    if distributed:
        model = DistributedDataParallel(
            model.cuda(), device_ids=[torch.cuda.current_device()])
    else:
        model = DataParallel(model, device_ids=cfg.gpus).cuda()

    # Build the runner and register hooks.
    runner = Runner(
        model,
        batch_processor,
        cfg.optimizer,
        cfg.work_dir,
        log_level=cfg.log_level)
    runner.register_training_hooks(
        lr_config=cfg.lr_config,
        optimizer_config=cfg.optimizer_config,
        checkpoint_config=cfg.checkpoint_config,
        log_config=cfg.log_config)
    if distributed:
        # Re-seed the DistributedSampler each epoch so shuffling differs
        # across epochs.
        runner.register_hook(DistSamplerSeedHook())

    # Load params (if necessary) and run.
    if cfg.get('resume_from') is not None:
        runner.resume(cfg.resume_from)
    elif cfg.get('load_from') is not None:
        runner.load_checkpoint(cfg.load_from)
    runner.run([train_loader, val_loader], cfg.workflow, cfg.total_epochs)


if __name__ == '__main__':
    main()
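# ---------------------------------------------------------------------------
# Usage sketch, assuming this script is saved as train_cifar10.py and the
# config path below is hypothetical:
#
#   Single machine, DataParallel across cfg.gpus:
#       python train_cifar10.py configs/cifar10.py
#
#   Distributed, one process per GPU (torch.distributed.launch sets the
#   RANK environment variable that init_dist() reads):
#       python -m torch.distributed.launch --nproc_per_node=8 \
#           train_cifar10.py configs/cifar10.py --launcher pytorch
# ---------------------------------------------------------------------------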