# mmcv/examples/train_cifar10.py
# Train a CIFAR-10 classifier with mmcv's Runner, optionally distributed.
import logging
import os
from argparse import ArgumentParser
from collections import OrderedDict

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms

from mmcv import Config
from mmcv.runner import DistSamplerSeedHook, Runner

import resnet_cifar

def accuracy(output, target, topk=(1, )):
    """Computes the precision@k for the specified values of k.

    Args:
        output (Tensor): model scores of shape (batch, num_classes).
        target (Tensor): ground-truth class indices of shape (batch,).
        topk (tuple[int]): the values of k to evaluate.

    Returns:
        list[Tensor]: one single-element tensor per k in ``topk``, holding
            the percentage (0-100) of samples whose target appears among
            the top-k predictions.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        # Indices of the top-maxk classes, transposed to (maxk, batch) so
        # row i holds every sample's i-th best guess.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # reshape(-1) instead of view(-1): the slice is not guaranteed
            # to be contiguous on all torch versions, and view() raises on
            # non-contiguous tensors (well-known topk-accuracy pitfall).
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
def batch_processor(model, data, train_mode):
    """Run one batch through ``model`` and package outputs for the Runner.

    Returns a dict with the differentiable ``loss``, scalar ``log_vars``
    (loss and top-1/top-5 accuracy) and ``num_samples`` in the batch.
    """
    img, label = data
    label = label.cuda(non_blocking=True)
    pred = model(img)
    loss = F.cross_entropy(pred, label)
    acc_top1, acc_top5 = accuracy(pred, label, topk=(1, 5))
    # Only detached scalars go into log_vars; the tensor loss stays
    # separate so the Runner can backprop through it.
    log_vars = OrderedDict(
        loss=loss.item(),
        acc_top1=acc_top1.item(),
        acc_top5=acc_top5.item())
    return dict(loss=loss, log_vars=log_vars, num_samples=img.size(0))
def get_logger(log_level):
    """Configure root logging with a timestamped format and return it."""
    fmt = '%(asctime)s - %(levelname)s - %(message)s'
    logging.basicConfig(format=fmt, level=log_level)
    return logging.getLogger()
def init_dist(backend='nccl', **kwargs):
    """Initialize the default process group for distributed training.

    The process rank is read from the ``RANK`` environment variable and
    each process is pinned to GPU ``rank % num_gpus``. Extra kwargs are
    forwarded to ``dist.init_process_group``.
    """
    # CUDA does not survive fork(); make sure workers are spawned.
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    proc_rank = int(os.environ['RANK'])
    gpu_count = torch.cuda.device_count()
    torch.cuda.set_device(proc_rank % gpu_count)
    dist.init_process_group(backend=backend, **kwargs)
def parse_args():
    """Parse command-line arguments for CIFAR-10 training.

    Positional ``config`` is the path to the train config file;
    ``--launcher`` selects the distributed launcher (default: none);
    ``--local_rank`` is supplied by torch.distributed.launch.
    """
    parser = ArgumentParser(description='Train CIFAR-10 classification')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--launcher',
        default='none',
        choices=['none', 'pytorch'],
        help='job launcher')
    parser.add_argument('--local_rank', default=0, type=int)
    return parser.parse_args()
def main():
    """Entry point: parse args, set up (distributed) data/model, train."""
    args = parse_args()

    cfg = Config.fromfile(args.config)

    logger = get_logger(cfg.log_level)

    # init distributed environment if necessary
    # NOTE(review): the local boolean `dist` shadows the module-level
    # `torch.distributed` import inside this function; distributed calls
    # below therefore spell out `torch.distributed` explicitly.
    if args.launcher == 'none':
        dist = False
        logger.info('Disabled distributed training.')
    else:
        dist = True
        init_dist(**cfg.dist_params)
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        if rank != 0:
            # Silence non-master processes so logs are not duplicated.
            logger.setLevel('ERROR')
        logger.info('Enabled distributed training.')

    # build datasets and dataloaders
    normalize = transforms.Normalize(mean=cfg.mean, std=cfg.std)
    train_dataset = datasets.CIFAR10(
        root=cfg.data_root,
        train=True,
        transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    val_dataset = datasets.CIFAR10(
        root=cfg.data_root,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    if dist:
        num_workers = cfg.data_workers
        # cfg.batch_size is the global batch; split it evenly per process.
        assert cfg.batch_size % world_size == 0
        batch_size = cfg.batch_size // world_size
        train_sampler = DistributedSampler(train_dataset, world_size, rank)
        val_sampler = DistributedSampler(val_dataset, world_size, rank)
        # DistributedSampler does the shuffling; DataLoader must not.
        shuffle = False
    else:
        num_workers = cfg.data_workers * len(cfg.gpus)
        batch_size = cfg.batch_size
        train_sampler = None
        val_sampler = None
        shuffle = True
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=train_sampler,
        num_workers=num_workers)
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        sampler=val_sampler,
        num_workers=num_workers)

    # build model
    # cfg.model names a class/factory in the local resnet_cifar module.
    model = getattr(resnet_cifar, cfg.model)()
    if dist:
        model = DistributedDataParallel(
            model.cuda(), device_ids=[torch.cuda.current_device()])
    else:
        model = DataParallel(model, device_ids=cfg.gpus).cuda()

    # build runner and register hooks
    runner = Runner(
        model,
        batch_processor,
        cfg.optimizer,
        cfg.work_dir,
        log_level=cfg.log_level)
    runner.register_training_hooks(
        lr_config=cfg.lr_config,
        optimizer_config=cfg.optimizer_config,
        checkpoint_config=cfg.checkpoint_config,
        log_config=cfg.log_config)
    if dist:
        # Re-seed the sampler each epoch so shards are reshuffled.
        runner.register_hook(DistSamplerSeedHook())

    # load param (if necessary) and run
    if cfg.get('resume_from') is not None:
        runner.resume(cfg.resume_from)
    elif cfg.get('load_from') is not None:
        runner.load_checkpoint(cfg.load_from)
    runner.run([train_loader, val_loader], cfg.workflow, cfg.total_epochs)
# Allow use both as a script and as an importable module.
if __name__ == '__main__':
    main()