diff --git a/examples/config_cifar10.py b/examples/config_cifar10.py new file mode 100644 index 000000000..e5431f9db --- /dev/null +++ b/examples/config_cifar10.py @@ -0,0 +1,31 @@ +# model settings +model = 'resnet18' +# dataset settings +data_root = '/mnt/SSD/dataset/cifar10' +mean = [0.4914, 0.4822, 0.4465] +std = [0.2023, 0.1994, 0.2010] +batch_size = 64 + +# optimizer and learning rate +optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=5e-4) +lr_policy = dict(policy='step', step=2) + +# runtime settings +work_dir = './demo' +gpus = range(2) +data_workers = 2 # data workers per gpu +checkpoint_cfg = dict(interval=1) # save checkpoint at every epoch +workflow = [('train', 1), ('val', 1)] +max_epoch = 6 +resume_from = None +load_from = None + +# logging settings +log_level = 'INFO' +log_cfg = dict( + # log at every 50 iterations + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log'), + ]) diff --git a/examples/resnet_cifar.py b/examples/resnet_cifar.py new file mode 100644 index 000000000..84e213f3c --- /dev/null +++ b/examples/resnet_cifar.py @@ -0,0 +1,132 @@ +# copied from https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py + +import torch.nn as nn +import torch.nn.functional as F + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d( + in_planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d( + in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), nn.BatchNorm2d(self.expansion * planes)) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d( + planes, self.expansion * planes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion * planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d( + in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), nn.BatchNorm2d(self.expansion * planes)) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class ResNet(nn.Module): + + def __init__(self, block, num_blocks, num_classes=10): + super(ResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + 
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.linear = nn.Linear(512 * block.expansion, num_classes) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +def resnet18(): + return ResNet(BasicBlock, [2, 2, 2, 2]) + + +def resnet34(): + return ResNet(BasicBlock, [3, 4, 6, 3]) + + +def resnet50(): + return ResNet(Bottleneck, [3, 4, 6, 3]) + + +def resnet101(): + return ResNet(Bottleneck, [3, 4, 23, 3]) + + +def resnet152(): + return ResNet(Bottleneck, [3, 8, 36, 3]) diff --git a/examples/train_cifar10.py b/examples/train_cifar10.py new file mode 100644 index 000000000..b52f65411 --- /dev/null +++ b/examples/train_cifar10.py @@ -0,0 +1,102 @@ +from argparse import ArgumentParser +from collections import OrderedDict + +import torch +import torch.nn.functional as F +from mmcv import Config +from mmcv.torchpack import Runner +from torchvision import datasets, transforms + +import resnet_cifar + + +def accuracy(output, target, topk=(1, )): + """Compute the top-k accuracy for the specified values of k.""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def batch_processor(model, data, train_mode): + img, label = data + label = label.cuda(non_blocking=True) + pred = model(img) + loss = F.cross_entropy(pred, label) + acc_top1, acc_top5 = accuracy(pred, label, topk=(1, 5)) + log_vars = OrderedDict() + log_vars['loss'] = loss.item() + log_vars['acc_top1'] = acc_top1.item() + log_vars['acc_top5'] = acc_top5.item() + outputs = dict(loss=loss, log_vars=log_vars, num_samples=img.size(0)) + return outputs + + +def parse_args(): + parser = ArgumentParser(description='Train CIFAR-10 classification') + parser.add_argument('config', help='train config file path') + return parser.parse_args() + + +def main(): + args = parse_args() + cfg = Config.fromfile(args.config) + model = getattr(resnet_cifar, cfg.model)() + model = torch.nn.DataParallel(model, device_ids=cfg.gpus).cuda() + + normalize = transforms.Normalize(mean=cfg.mean, std=cfg.std) + train_dataset = datasets.CIFAR10( + root=cfg.data_root, + train=True, + transform=transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + val_dataset = datasets.CIFAR10( + root=cfg.data_root, + train=False, + transform=transforms.Compose([ + transforms.ToTensor(), + normalize, + ])) + + num_workers = cfg.data_workers * len(cfg.gpus) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=cfg.batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=True) + val_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=cfg.batch_size, + shuffle=False, + num_workers=num_workers, +
pin_memory=True) + + runner = Runner(model, cfg.optimizer, batch_processor, cfg.work_dir) + runner.register_default_hooks( + lr_config=cfg.lr_policy, + checkpoint_config=cfg.checkpoint_cfg, + log_config=cfg.log_cfg) + + if cfg.get('resume_from') is not None: + runner.resume(cfg.resume_from) + elif cfg.get('load_from') is not None: + runner.load_checkpoint(cfg.load_from) + + runner.run([train_loader, val_loader], cfg.workflow, cfg.max_epoch) + + +if __name__ == '__main__': + main() diff --git a/mmcv/torchpack/__init__.py b/mmcv/torchpack/__init__.py new file mode 100644 index 000000000..607008200 --- /dev/null +++ b/mmcv/torchpack/__init__.py @@ -0,0 +1,5 @@ +from .hooks import * +from .io import * +from .parallel import * +from .runner import * +from .utils import * diff --git a/mmcv/torchpack/hooks/__init__.py b/mmcv/torchpack/hooks/__init__.py new file mode 100644 index 000000000..ac1636301 --- /dev/null +++ b/mmcv/torchpack/hooks/__init__.py @@ -0,0 +1,7 @@ +from .hook import Hook +from .checkpoint_saver import CheckpointSaverHook +from .closure import ClosureHook +from .lr_updater import LrUpdaterHook +from .optimizer_stepper import OptimizerStepperHook +from .iter_timer import IterTimerHook +from .logger import * diff --git a/mmcv/torchpack/hooks/checkpoint_saver.py b/mmcv/torchpack/hooks/checkpoint_saver.py new file mode 100644 index 000000000..1256cda5b --- /dev/null +++ b/mmcv/torchpack/hooks/checkpoint_saver.py @@ -0,0 +1,25 @@ +from .hook import Hook +from ..utils import master_only + + +class CheckpointSaverHook(Hook): + + def __init__(self, + interval=-1, + save_optimizer=True, + out_dir=None, + **kwargs): + self.interval = interval + self.save_optimizer = save_optimizer + self.out_dir = out_dir + self.args = kwargs + + @master_only + def after_train_epoch(self, runner): + if not self.every_n_epochs(runner, self.interval): + return + + if not self.out_dir: + self.out_dir = runner.work_dir + runner.save_checkpoint( + self.out_dir, save_optimizer=self.save_optimizer, **self.args) diff --git a/mmcv/torchpack/hooks/closure.py b/mmcv/torchpack/hooks/closure.py new file mode 100644 index 000000000..8087d985b --- /dev/null +++ b/mmcv/torchpack/hooks/closure.py @@ -0,0 +1,9 @@ +from .hook import Hook + + +class ClosureHook(Hook): + + def __init__(self, fn_name, fn): + assert hasattr(self, fn_name) + assert callable(fn) + setattr(self, fn_name, fn) diff --git a/mmcv/torchpack/hooks/hook.py b/mmcv/torchpack/hooks/hook.py new file mode 100644 index 000000000..c2e07579e --- /dev/null +++ b/mmcv/torchpack/hooks/hook.py @@ -0,0 +1,55 @@ +class Hook(object): + + def before_run(self, runner): + pass + + def after_run(self, runner): + pass + + def before_epoch(self, runner): + pass + + def after_epoch(self, runner): + pass + + def before_iter(self, runner): + pass + + def after_iter(self, runner): + pass + + def before_train_epoch(self, runner): + self.before_epoch(runner) + + def before_val_epoch(self, runner): + self.before_epoch(runner) + + def after_train_epoch(self, runner): + self.after_epoch(runner) + + def after_val_epoch(self, runner): + self.after_epoch(runner) + + def before_train_iter(self, runner): + self.before_iter(runner) + + def before_val_iter(self, runner): + self.before_iter(runner) + + def after_train_iter(self, runner): + self.after_iter(runner) + + def after_val_iter(self, runner): + self.after_iter(runner) + + def every_n_epochs(self, runner, n): + return (runner.epoch + 1) % n == 0 if n > 0 else False + + def every_n_inner_iters(self, runner, n): + return 
(runner.inner_iter + 1) % n == 0 if n > 0 else False + + def every_n_iters(self, runner, n): + return (runner.iter + 1) % n == 0 if n > 0 else False + + def end_of_epoch(self, runner): + return runner.inner_iter + 1 == len(runner.data_loader) diff --git a/mmcv/torchpack/hooks/iter_timer.py b/mmcv/torchpack/hooks/iter_timer.py new file mode 100644 index 000000000..13b2876ff --- /dev/null +++ b/mmcv/torchpack/hooks/iter_timer.py @@ -0,0 +1,16 @@ +import time + +from .hook import Hook + + +class IterTimerHook(Hook): + + def before_epoch(self, runner): + self.t = time.time() + + def before_iter(self, runner): + runner.log_buffer.update({'data_time': time.time() - self.t}) + + def after_iter(self, runner): + runner.log_buffer.update({'time': time.time() - self.t}) + self.t = time.time() diff --git a/mmcv/torchpack/hooks/logger/__init__.py b/mmcv/torchpack/hooks/logger/__init__.py new file mode 100644 index 000000000..cb3a51f18 --- /dev/null +++ b/mmcv/torchpack/hooks/logger/__init__.py @@ -0,0 +1,4 @@ +from .base import LoggerHook +from .pavi import PaviClient, PaviLoggerHook +from .tensorboard import TensorboardLoggerHook +from .text import TextLoggerHook diff --git a/mmcv/torchpack/hooks/logger/base.py b/mmcv/torchpack/hooks/logger/base.py new file mode 100644 index 000000000..e35e049ac --- /dev/null +++ b/mmcv/torchpack/hooks/logger/base.py @@ -0,0 +1,49 @@ +from abc import ABCMeta, abstractmethod + +from ..hook import Hook + + +class LoggerHook(Hook): + """Base class for logger hooks.""" + + __metaclass__ = ABCMeta + + def __init__(self, interval=10, ignore_last=True, reset_flag=False): + self.interval = interval + self.ignore_last = ignore_last + self.reset_flag = reset_flag + + @abstractmethod + def log(self, runner): + pass + + def before_run(self, runner): + for hook in runner.hooks[::-1]: + if isinstance(hook, LoggerHook): + hook.reset_flag = True + break + + def before_epoch(self, runner): + runner.log_buffer.clear() # clear logs of last epoch + + def after_train_iter(self, runner): + if self.every_n_inner_iters(runner, self.interval): + runner.log_buffer.average(self.interval) + elif self.end_of_epoch(runner) and not self.ignore_last: + # not precise but more stable + runner.log_buffer.average(self.interval) + + if runner.log_buffer.ready: + self.log(runner) + if self.reset_flag: + runner.log_buffer.clear_output() + + def after_train_epoch(self, runner): + if runner.log_buffer.ready: + self.log(runner) + + def after_val_epoch(self, runner): + runner.log_buffer.average() + self.log(runner) + if self.reset_flag: + runner.log_buffer.clear_output() diff --git a/mmcv/torchpack/hooks/logger/pavi.py b/mmcv/torchpack/hooks/logger/pavi.py new file mode 100644 index 000000000..f07dfb0f5 --- /dev/null +++ b/mmcv/torchpack/hooks/logger/pavi.py @@ -0,0 +1,147 @@ +from __future__ import print_function + +import os +import time +from datetime import datetime +from threading import Thread + +import requests +from six.moves.queue import Empty, Queue + +from .base import LoggerHook +from ...utils import master_only, get_host_info + + +class PaviClient(object): + + def __init__(self, url, username=None, password=None, instance_id=None): + self.url = url + self.username = self._get_env_var(username, 'PAVI_USERNAME') + self.password = self._get_env_var(password, 'PAVI_PASSWORD') + self.instance_id = instance_id + self.log_queue = None + + def _get_env_var(self, var, env_var): + if var is not None: + return str(var) + + var = os.getenv(env_var) + if not var: + raise ValueError( + '"{}" is neither 
specified nor defined as env variables'. + format(env_var)) + return var + + def connect(self, + model_name, + work_dir=None, + info=dict(), + timeout=5, + logger=None): + if logger: + log_info = logger.info + log_error = logger.error + else: + log_info = log_error = print + log_info('connecting pavi service {}...'.format(self.url)) + post_data = dict( + time=str(datetime.now()), + username=self.username, + password=self.password, + instance_id=self.instance_id, + model=model_name, + work_dir=os.path.abspath(work_dir) if work_dir else '', + session_file=info.get('session_file', ''), + session_text=info.get('session_text', ''), + model_text=info.get('model_text', ''), + device=get_host_info()) + try: + response = requests.post(self.url, json=post_data, timeout=timeout) + except Exception as ex: + log_error('fail to connect to pavi service: {}'.format(ex)) + else: + if response.status_code == 200: + self.instance_id = response.text + log_info('pavi service connected, instance_id: {}'.format( + self.instance_id)) + self.log_queue = Queue() + self.log_thread = Thread(target=self.post_worker_fn) + self.log_thread.daemon = True + self.log_thread.start() + return True + else: + log_error('fail to connect to pavi service, status code: ' + '{}, err message: {}'.format(response.status_code, + response.reason)) + return False + + def post_worker_fn(self, max_retry=3, queue_timeout=1, req_timeout=3): + while True: + try: + log = self.log_queue.get(timeout=queue_timeout) + except Empty: + time.sleep(1) + except Exception as ex: + print('fail to get logs from queue: {}'.format(ex)) + else: + retry = 0 + while retry < max_retry: + try: + response = requests.post( + self.url, json=log, timeout=req_timeout) + except Exception as ex: + retry += 1 + print('error when posting logs to pavi: {}'.format(ex)) + else: + status_code = response.status_code + if status_code == 200: + break + else: + print('unexpected status code: %d, err msg: %s', + status_code, response.reason) + retry += 1 + if retry == max_retry: + print('fail to send logs of iteration %d', log['iter_num']) + + def log(self, phase, iter, outputs): + if self.log_queue is not None: + logs = { + 'time': str(datetime.now()), + 'instance_id': self.instance_id, + 'flow_id': phase, + 'iter_num': iter, + 'outputs': outputs, + 'msg': '' + } + self.log_queue.put(logs) + + +class PaviLoggerHook(LoggerHook): + + def __init__(self, + url, + username=None, + password=None, + instance_id=None, + interval=10, + reset_meter=True, + ignore_last=True): + self.pavi = PaviClient(url, username, password, instance_id) + super(PaviLoggerHook, self).__init__(interval, reset_meter, + ignore_last) + + @master_only + def connect(self, + model_name, + work_dir=None, + info=dict(), + timeout=5, + logger=None): + return self.pavi.connect(model_name, work_dir, info, timeout, logger) + + @master_only + def log(self, runner): + log_outs = runner.log_buffer.output.copy() + log_outs.pop('time', None) + log_outs.pop('data_time', None) + self.pavi.log(runner.mode, runner.iter, log_outs) diff --git a/mmcv/torchpack/hooks/logger/tensorboard.py b/mmcv/torchpack/hooks/logger/tensorboard.py new file mode 100644 index 000000000..6c5dd300b --- /dev/null +++ b/mmcv/torchpack/hooks/logger/tensorboard.py @@ -0,0 +1,37 @@ +from .base import LoggerHook +from ...utils import master_only + + +class TensorboardLoggerHook(LoggerHook): + + def __init__(self, + log_dir, + interval=10, + reset_meter=True, + ignore_last=True): + super(TensorboardLoggerHook, self).__init__(interval, reset_meter, + 
ignore_last) + self.log_dir = log_dir + + @master_only + def before_run(self, runner): + try: + from tensorboardX import SummaryWriter + except ImportError: + raise ImportError('Please install tensorflow and tensorboardX ' + 'to use TensorboardLoggerHook.') + else: + self.writer = SummaryWriter(self.log_dir) + + @master_only + def log(self, runner): + for var in runner.log_buffer.output: + if var in ['time', 'data_time']: + continue + tag = '{}/{}'.format(var, runner.mode) + self.writer.add_scalar(tag, runner.log_buffer.output[var], + runner.iter) + + @master_only + def after_run(self, runner): + self.writer.close() diff --git a/mmcv/torchpack/hooks/logger/text.py b/mmcv/torchpack/hooks/logger/text.py new file mode 100644 index 000000000..06ac3475e --- /dev/null +++ b/mmcv/torchpack/hooks/logger/text.py @@ -0,0 +1,26 @@ +from .base import LoggerHook + + +class TextLoggerHook(LoggerHook): + + def log(self, runner): + if runner.mode == 'train': + lr_str = ', '.join( + ['{:.5f}'.format(lr) for lr in runner.current_lr()]) + log_str = 'Epoch [{}][{}/{}]\tlr: {}, '.format( + runner.epoch + 1, runner.inner_iter + 1, + len(runner.data_loader), lr_str) + else: + log_str = 'Epoch({}) [{}][{}]\t'.format(runner.mode, runner.epoch, + runner.inner_iter + 1) + if 'time' in runner.log_buffer.output: + log_str += ( + 'time: {log[time]:.3f}, data_time: {log[data_time]:.3f}, '. + format(log=runner.log_buffer.output)) + log_items = [] + for name, val in runner.log_buffer.output.items(): + if name in ['time', 'data_time']: + continue + log_items.append('{}: {:.4f}'.format(name, val)) + log_str += ', '.join(log_items) + runner.logger.info(log_str) diff --git a/mmcv/torchpack/hooks/lr_updater.py b/mmcv/torchpack/hooks/lr_updater.py new file mode 100644 index 000000000..27709bb17 --- /dev/null +++ b/mmcv/torchpack/hooks/lr_updater.py @@ -0,0 +1,163 @@ +from __future__ import division + +from .hook import Hook + + +class LrUpdaterHook(Hook): + + def __init__(self, + by_epoch=True, + warmup=None, + warmup_iters=0, + warmup_ratio=0.1, + **kwargs): + # validate the "warmup" argument + if warmup is not None: + if warmup not in ['constant', 'linear', 'exp']: + raise ValueError( + '"{}" is not a supported type for warming up, valid types' + ' are "constant" and "linear"'.format(warmup)) + if warmup is not None: + assert warmup_iters > 0, \ + '"warmup_iters" must be a positive integer' + assert 0 < warmup_ratio <= 1.0, \ + '"warmup_ratio" must be in range (0,1]' + + self.by_epoch = by_epoch + self.warmup = warmup + self.warmup_iters = warmup_iters + self.warmup_ratio = warmup_ratio + + self.base_lr = [] # initial lr for all param groups + self.regular_lr = [] # expected lr if no warming up is performed + + def _set_lr(self, runner, lr_groups): + for param_group, lr in zip(runner.optimizer.param_groups, lr_groups): + param_group['lr'] = lr + + def get_lr(self, runner, base_lr): + raise NotImplementedError + + def get_regular_lr(self, runner): + return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr] + + def get_warmup_lr(self, cur_iters): + if self.warmup == 'constant': + warmup_lr = [_lr * self.warmup_ratio for _lr in self.regular_lr] + elif self.warmup == 'linear': + k = (1 - cur_iters / self.warmup_iters) * (1 - self.warmup_ratio) + warmup_lr = [_lr * (1 - k) for _lr in self.regular_lr] + elif self.warmup == 'exp': + k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters) + warmup_lr = [_lr * k for _lr in self.regular_lr] + return warmup_lr + + def before_run(self, runner): + # NOTE: when resuming from a 
checkpoint, if 'initial_lr' is not saved, + # it will be set according to the optimizer params + for group in runner.optimizer.param_groups: + group.setdefault('initial_lr', group['lr']) + self.base_lr = [ + group['initial_lr'] for group in runner.optimizer.param_groups + ] + + def before_train_epoch(self, runner): + if not self.by_epoch: + return + self.regular_lr = self.get_regular_lr(runner) + self._set_lr(runner, self.regular_lr) + + def before_train_iter(self, runner): + cur_iter = runner.iter + if not self.by_epoch: + self.regular_lr = self.get_regular_lr(runner) + if self.warmup is None or cur_iter >= self.warmup_iters: + self._set_lr(runner, self.regular_lr) + else: + warmup_lr = self.get_warmup_lr(cur_iter) + self._set_lr(runner, warmup_lr) + elif self.by_epoch: + if self.warmup is None or cur_iter > self.warmup_iters: + return + elif cur_iter == self.warmup_iters: + self._set_lr(runner, self.regular_lr) + else: + warmup_lr = self.get_warmup_lr(cur_iter) + self._set_lr(runner, warmup_lr) + + +class FixedLrUpdaterHook(LrUpdaterHook): + + def __init__(self, **kwargs): + super(FixedLrUpdaterHook, self).__init__(**kwargs) + + def get_lr(self, runner, base_lr): + return base_lr + + +class StepLrUpdaterHook(LrUpdaterHook): + + def __init__(self, step, gamma=0.1, **kwargs): + assert isinstance(step, (list, int)) + if isinstance(step, list): + for s in step: + assert isinstance(s, int) and s > 0 + elif isinstance(step, int): + assert step > 0 + else: + raise TypeError('"step" must be a list or integer') + self.step = step + self.gamma = gamma + super(StepLrUpdaterHook, self).__init__(**kwargs) + + def get_lr(self, runner, base_lr): + progress = runner.epoch if self.by_epoch else runner.iter + + if isinstance(self.step, int): + return base_lr * (self.gamma**(progress // self.step)) + + exp = len(self.step) + for i, s in enumerate(self.step): + if progress < s: + exp = i + break + return base_lr * self.gamma**exp + + +class ExpLrUpdaterHook(LrUpdaterHook): + + def __init__(self, gamma, **kwargs): + self.gamma = gamma + super(ExpLrUpdaterHook, self).__init__(**kwargs) + + def get_lr(self, runner, base_lr): + progress = runner.epoch if self.by_epoch else runner.iter + return base_lr * self.gamma**progress + + +class PolyLrUpdaterHook(LrUpdaterHook): + + def __init__(self, power=1., **kwargs): + self.power = power + super(PolyLrUpdaterHook, self).__init__(**kwargs) + + def get_lr(self, runner, base_lr): + if self.by_epoch: + progress = runner.epoch + max_progress = runner.max_epochs + else: + progress = runner.iter + max_progress = runner.max_iters + return base_lr * (1 - progress / max_progress)**self.power + + +class InvLrUpdaterHook(LrUpdaterHook): + + def __init__(self, gamma, power=1., **kwargs): + self.gamma = gamma + self.power = power + super(InvLrUpdaterHook, self).__init__(**kwargs) + + def get_lr(self, runner, base_lr): + progress = runner.epoch if self.by_epoch else runner.iter + return base_lr * (1 + self.gamma * progress)**(-self.power) diff --git a/mmcv/torchpack/hooks/optimizer_stepper.py b/mmcv/torchpack/hooks/optimizer_stepper.py new file mode 100644 index 000000000..dea06530d --- /dev/null +++ b/mmcv/torchpack/hooks/optimizer_stepper.py @@ -0,0 +1,21 @@ +from torch.nn.utils import clip_grad + +from .hook import Hook + + +class OptimizerStepperHook(Hook): + + def __init__(self, grad_clip=False, max_norm=35, norm_type=2): + self.grad_clip = grad_clip + self.max_norm = max_norm + self.norm_type = norm_type + + def after_train_iter(self, runner): + runner.optimizer.zero_grad() + 
runner.outputs['loss'].backward() + if self.grad_clip: + clip_grad.clip_grad_norm_( + filter(lambda p: p.requires_grad, runner.model.parameters()), + max_norm=self.max_norm, + norm_type=self.norm_type) + runner.optimizer.step() diff --git a/mmcv/torchpack/io.py b/mmcv/torchpack/io.py new file mode 100644 index 000000000..bc94b3070 --- /dev/null +++ b/mmcv/torchpack/io.py @@ -0,0 +1,157 @@ +import os.path as osp +import time +from collections import OrderedDict + +import mmcv +import torch +from torch.nn.parallel import DataParallel, DistributedDataParallel +from torch.utils import model_zoo + + +def load_state_dict(module, state_dict, strict=False, logger=None): + """Load state_dict to a module. + + This method is modified from :meth:`torch.nn.Module.load_state_dict`. + Default value for ``strict`` is set to ``False`` and the message for + param mismatch will be shown even if strict is False. + + Args: + module (Module): Module that receives the state_dict. + state_dict (OrderedDict): Weights. + strict (bool): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. + logger (:obj:`logging.Logger`, optional): Logger to log the error + message. If not specified, print function will be used. + """ + unexpected_keys = [] + own_state = module.state_dict() + for name, param in state_dict.items(): + if name not in own_state: + unexpected_keys.append(name) + continue + if isinstance(param, torch.nn.Parameter): + # backwards compatibility for serialized parameters + param = param.data + + try: + own_state[name].copy_(param) + except Exception: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' + .format(name, own_state[name].size(), + param.size())) + missing_keys = set(own_state.keys()) - set(state_dict.keys()) + + err_msg = [] + if unexpected_keys: + err_msg.append('unexpected key in source state_dict: {}\n'.format( + ', '.join(unexpected_keys))) + if missing_keys: + err_msg.append('missing keys in source state_dict: {}\n'.format( + ', '.join(missing_keys))) + err_msg = '\n'.join(err_msg) + if err_msg: + if strict: + raise RuntimeError(err_msg) + elif logger is not None: + logger.warn(err_msg) + else: + print(err_msg) + + +def load_checkpoint(model, + filename, + map_location=None, + strict=False, + logger=None): + """Load checkpoint from a file or URI. + + Args: + model (Module): Module to load checkpoint. + filename (str): Either a filepath or URL or modelzoo://xxxxxxx. + map_location (str): Same as :func:`torch.load`. + strict (bool): Whether to strictly enforce that the keys in the + checkpoint match the keys of the model. + logger (:mod:`logging.Logger` or None): The logger for error message. + + Returns: + dict or OrderedDict: The loaded checkpoint. 
+ """ + # load checkpoint from modelzoo or file or url + if filename.startswith('modelzoo://'): + from torchvision.models.resnet import model_urls + model_name = filename[11:] + checkpoint = model_zoo.load_url(model_urls[model_name]) + elif filename.startswith(('http://', 'https://')): + checkpoint = model_zoo.load_url(filename) + else: + if not osp.isfile(filename): + raise IOError('{} is not a checkpoint file'.format(filename)) + checkpoint = torch.load(filename, map_location=map_location) + # get state_dict from checkpoint + if isinstance(checkpoint, OrderedDict): + state_dict = checkpoint + elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + raise RuntimeError( + 'No state_dict found in checkpoint file {}'.format(filename)) + # strip prefix of state_dict + if list(state_dict.keys())[0].startswith('module.'): + state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()} + # load state_dict + if isinstance(model, (DataParallel, DistributedDataParallel)): + load_state_dict(model.module, state_dict, strict, logger) + else: + load_state_dict(model, state_dict, strict, logger) + return checkpoint + + +def weights_to_cpu(state_dict): + """Copy a model state_dict to cpu. + + Args: + state_dict (OrderedDict): Model weights on GPU. + + Returns: + OrderedDict: Model weights on GPU. + """ + state_dict_cpu = OrderedDict() + for key, val in state_dict.items(): + state_dict_cpu[key] = val.cpu() + return state_dict_cpu + + +def save_checkpoint(model, filename, optimizer=None, meta=None): + """Save checkpoint to file. + + The checkpoint will have 3 fields: ``meta``, ``state_dict`` and + ``optimizer``. By default ``meta`` will contain version and time info. + + Args: + model (Module): Module whose params are to be saved. + filename (str): Checkpoint filename. + optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. + meta (dict, optional): Metadata to be saved in checkpoint. + """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError('meta must be a dict or None, but got {}'.format( + type(meta))) + meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) + + mmcv.mkdir_or_exist(osp.dirname(filename)) + if isinstance(model, (DataParallel, DistributedDataParallel)): + model = model.module + + checkpoint = { + 'meta': meta, + 'state_dict': weights_to_cpu(model.state_dict()) + } + if optimizer is not None: + checkpoint['optimizer'] = optimizer.state_dict() + + torch.save(checkpoint, filename) diff --git a/mmcv/torchpack/parallel.py b/mmcv/torchpack/parallel.py new file mode 100644 index 000000000..923d066b1 --- /dev/null +++ b/mmcv/torchpack/parallel.py @@ -0,0 +1,74 @@ +import multiprocessing + +import torch + +from .io import load_checkpoint + + +def worker_func(model_cls, model_kwargs, checkpoint, dataset, data_func, + gpu_id, idx_queue, result_queue): + model = model_cls(**model_kwargs) + load_checkpoint(model, checkpoint, map_location='cpu') + torch.cuda.set_device(gpu_id) + model.cuda() + model.eval() + with torch.no_grad(): + while True: + idx = idx_queue.get() + data = dataset[idx] + result = model(**data_func(data, gpu_id)) + result_queue.put((idx, result)) + + +def parallel_test(model_cls, + model_kwargs, + checkpoint, + dataset, + data_func, + gpus, + workers_per_gpu=1): + """Parallel testing on multiple GPUs. + + Args: + model_cls (type): Model class type. + model_kwargs (dict): Arguments to init the model. + checkpoint (str): Checkpoint filepath. 
+ dataset (:obj:`Dataset`): The dataset to be tested. + data_func (callable): The function that generates model inputs. + gpus (list[int]): GPU ids to be used. + workers_per_gpu (int): Number of processes on each GPU. It is possible + to run multiple workers on each GPU. + + Returns: + list: Test results. + """ + ctx = multiprocessing.get_context('spawn') + idx_queue = ctx.Queue() + result_queue = ctx.Queue() + num_workers = len(gpus) * workers_per_gpu + workers = [ + ctx.Process( + target=worker_func, + args=(model_cls, model_kwargs, checkpoint, dataset, data_func, + gpus[i % len(gpus)], idx_queue, result_queue)) + for i in range(num_workers) + ] + for w in workers: + w.daemon = True + w.start() + + for i in range(len(dataset)): + idx_queue.put(i) + + results = [None for _ in range(len(dataset))] + import cvbase as cvb + prog_bar = cvb.ProgressBar(task_num=len(dataset)) + for _ in range(len(dataset)): + idx, res = result_queue.get() + results[idx] = res + prog_bar.update() + print('\n') + for worker in workers: + worker.terminate() + + return results diff --git a/mmcv/torchpack/runner/__init__.py b/mmcv/torchpack/runner/__init__.py new file mode 100644 index 000000000..b64d41e6d --- /dev/null +++ b/mmcv/torchpack/runner/__init__.py @@ -0,0 +1,2 @@ +from .log_buffer import LogBuffer +from .runner import Runner diff --git a/mmcv/torchpack/runner/log_buffer.py b/mmcv/torchpack/runner/log_buffer.py new file mode 100644 index 000000000..c83d132cc --- /dev/null +++ b/mmcv/torchpack/runner/log_buffer.py @@ -0,0 +1,39 @@ +from collections import OrderedDict +import numpy as np + + +class LogBuffer(object): + + def __init__(self): + self.val_history = OrderedDict() + self.n_history = OrderedDict() + self.output = OrderedDict() + self.ready = False + + def clear(self): + self.val_history.clear() + self.n_history.clear() + self.clear_output() + + def clear_output(self): + self.output.clear() + self.ready = False + + def update(self, vars, count=1): + assert isinstance(vars, dict) + for key, var in vars.items(): + if key not in self.val_history: + self.val_history[key] = [] + self.n_history[key] = [] + self.val_history[key].append(var) + self.n_history[key].append(count) + + def average(self, n=0): + """Average latest n values or all values""" + assert n >= 0 + for key in self.val_history: + values = np.array(self.val_history[key][-n:]) + nums = np.array(self.n_history[key][-n:]) + avg = np.sum(values * nums) / np.sum(nums) + self.output[key] = avg + self.ready = True diff --git a/mmcv/torchpack/runner/runner.py b/mmcv/torchpack/runner/runner.py new file mode 100644 index 000000000..ba4679d4b --- /dev/null +++ b/mmcv/torchpack/runner/runner.py @@ -0,0 +1,344 @@ +import logging +import os.path as osp +import time + +import mmcv +import torch +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from .log_buffer import LogBuffer +from .. 
import hooks +from ..hooks import (Hook, LrUpdaterHook, CheckpointSaverHook, IterTimerHook, + OptimizerStepperHook) +from ..io import load_checkpoint, save_checkpoint +from ..utils import (get_dist_info, get_host_info, get_time_str, + add_file_handler, obj_from_dict) + + +class Runner(object): + """A training helper for PyTorch.""" + + def __init__(self, + model, + optimizer, + batch_processor, + work_dir=None, + log_level=logging.INFO): + assert callable(batch_processor) + self.model = model + self.optimizer = self.init_optimizer(optimizer) + self.batch_processor = batch_processor + + # create work_dir + if mmcv.is_str(work_dir): + self.work_dir = osp.abspath(work_dir) + mmcv.mkdir_or_exist(self.work_dir) + elif work_dir is None: + self.work_dir = None + else: + raise TypeError('"work_dir" must be a str or None') + + # get model name from the model class + if isinstance(self.model, (DataParallel, DistributedDataParallel)): + self._model_name = self.model.module.__class__.__name__ + else: + self._model_name = self.model.__class__.__name__ + + self._rank, self._world_size = get_dist_info() + self.logger = self.init_logger(work_dir, log_level) + self.log_buffer = LogBuffer() + + self.mode = None + self._hooks = [] + self._epoch = 0 + self._iter = 0 + self._inner_iter = 0 + self._max_epochs = 0 + self._max_iters = 0 + + @property + def model_name(self): + """str: Name of the model, usually the module class name.""" + return self._model_name + + @property + def rank(self): + """int: Rank of current process. (distributed training)""" + return self._rank + + @property + def world_size(self): + """int: Number of processes participating in the job. + (distributed training)""" + return self._world_size + + @property + def hooks(self): + """list[:obj:`Hook`]: A list of registered hooks.""" + return self._hooks + + @property + def epoch(self): + """int: Current epoch.""" + return self._epoch + + @property + def iter(self): + """int: Current iteration.""" + return self._iter + + @property + def inner_iter(self): + """int: Iteration in an epoch.""" + return self._inner_iter + + @property + def max_epochs(self): + """int: Maximum training epochs.""" + return self._max_epochs + + @property + def max_iters(self): + """int: Maximum training iterations.""" + return self._max_iters + + def init_optimizer(self, optimizer): + """Init the optimizer. + + Args: + optimizer (dict or :obj:`~torch.optim.Optimizer`): Either an + optimizer object or a dict used for constructing the optimizer. + An example of the dict: ``{'algorithm': 'SGD', 'lr': 0.02, + 'momentum': 0.9, 'weight_decay': 0.0001}``. + + Returns: + :obj:`~torch.optim.Optimizer`: An optimizer object. + """ + if isinstance(optimizer, dict): + optimizer = obj_from_dict( + optimizer, torch.optim, dict(params=self.model.parameters())) + elif not isinstance(optimizer, torch.optim.Optimizer): + raise TypeError( + 'optimizer must be either an Optimizer object or a dict, ' + 'but got {}'.format(type(optimizer))) + return optimizer + + def init_logger(self, log_dir=None, level=logging.INFO): + """Init the logger. + + Args: + log_dir(str, optional): Log file directory. If not specified, no + log file will be used. + level (int or str): See the built-in python logging module. + + Returns: + :obj:`~logging.Logger`: Python logger. 
+ """ + logging.basicConfig( + format='%(asctime)s - %(levelname)s - %(message)s', level=level) + logger = logging.getLogger(__name__) + if log_dir: + filename = '{}_{}.log'.format(get_time_str(), self.rank) + log_file = osp.join(log_dir, filename) + add_file_handler(logger, log_file, level=level) + return logger + + def current_lr(self): + """Get current learning rates. + + Returns: + list: Current learning rate of all param groups. + """ + return [group['lr'] for group in self.optimizer.param_groups] + + def register_hook(self, hook, priority=50): + """Register a hook into the hook list. + + Args: + hook (:obj:`Hook`): The hook to be registered. + priority (int): Hook priority. Lower value means higher priority. + """ + assert isinstance(hook, Hook) + assert isinstance(priority, int) and priority >= 0 and priority <= 100 + if hasattr(hook, 'priority'): + raise ValueError('"priority" is a reserved attribute for hooks') + hook.priority = priority + # insert the hook to a sorted list + inserted = False + for i in range(len(self._hooks) - 1, -1, -1): + if priority >= self._hooks[i].priority: + self._hooks.insert(i + 1, hook) + inserted = True + break + if not inserted: + self._hooks.insert(0, hook) + + def call_hook(self, fn_name): + for hook in self._hooks: + getattr(hook, fn_name)(self) + + def load_checkpoint(self, filename, map_location='cpu', strict=False): + self.logger.info('load checkpoint from %s', filename) + return load_checkpoint(self.model, filename, map_location, strict, + self.logger) + + def save_checkpoint(self, + out_dir, + filename_tmpl='epoch_{}.pth', + save_optimizer=True, + meta=None): + if meta is None: + meta = dict(epoch=self.epoch + 1, iter=self.iter) + else: + meta.update(epoch=self.epoch + 1, iter=self.iter) + + filename = osp.join(out_dir, filename_tmpl.format(self.epoch)) + linkname = osp.join(out_dir, 'latest.pth') + optimizer = self.optimizer if save_optimizer else None + save_checkpoint(self.model, filename, optimizer=optimizer, meta=meta) + mmcv.symlink(filename, linkname) + + def train(self, data_loader, **kwargs): + self.model.train() + self.mode = 'train' + self.data_loader = data_loader + self._max_iters = self._max_epochs * len(data_loader) + self.call_hook('before_train_epoch') + + for i, data_batch in enumerate(data_loader): + self._inner_iter = i + self.call_hook('before_train_iter') + outputs = self.batch_processor( + self.model, data_batch, train_mode=True, **kwargs) + if not isinstance(outputs, dict): + raise TypeError('batch_processor() must return a dict') + if 'log_vars' in outputs: + self.log_buffer.update(outputs['log_vars'], + outputs['num_samples']) + self.outputs = outputs + self.call_hook('after_train_iter') + self._iter += 1 + + self.call_hook('after_train_epoch') + self._epoch += 1 + + def val(self, data_loader, **kwargs): + self.model.eval() + self.mode = 'val' + self.data_loader = data_loader + self.call_hook('before_val_epoch') + + for i, data_batch in enumerate(data_loader): + self._inner_iter = i + self.call_hook('before_val_iter') + outputs = self.batch_processor( + self.model, data_batch, train_mode=False, **kwargs) + if not isinstance(outputs, dict): + raise TypeError('batch_processor() must return a dict') + if 'log_vars' in outputs: + self.log_buffer.update(outputs['log_vars'], + outputs['num_samples']) + self.outputs = outputs + self.call_hook('after_val_iter') + + self.call_hook('after_val_epoch') + + def resume(self, checkpoint, resume_optimizer=True, + map_location='default'): + if map_location == 'default': + device_id = 
torch.cuda.current_device() + checkpoint = self.load_checkpoint( + checkpoint, + map_location=lambda storage, loc: storage.cuda(device_id)) + else: + checkpoint = self.load_checkpoint( + checkpoint, map_location=map_location) + + self._epoch = checkpoint['meta']['epoch'] + self._iter = checkpoint['meta']['iter'] + if 'optimizer' in checkpoint and resume_optimizer: + self.optimizer.load_state_dict(checkpoint['optimizer']) + + self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter) + + def run(self, data_loaders, workflow, max_epochs, **kwargs): + assert isinstance(data_loaders, list) + assert mmcv.is_list_of(workflow, tuple) + assert len(data_loaders) == len(workflow) + + self._max_epochs = max_epochs + work_dir = self.work_dir if self.work_dir is not None else 'NONE' + self.logger.info('Start running, host: %s, work_dir: %s', + get_host_info(), work_dir) + self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs) + self.call_hook('before_run') + + while self.epoch < max_epochs: + for i, flow in enumerate(workflow): + mode, epochs = flow + if isinstance(mode, str): # self.train() + if not hasattr(self, mode): + raise ValueError( + 'runner has no method named "{}" to run an epoch'. + format(mode)) + epoch_runner = getattr(self, mode) + elif callable(mode): # custom train() + epoch_runner = mode + else: + raise TypeError('mode in workflow must be a str or ' + 'callable function, not {}'.format( + type(mode))) + for _ in range(epochs): + if mode == 'train' and self.epoch >= max_epochs: + return + epoch_runner(data_loaders[i], **kwargs) + + time.sleep(1) # wait for some hooks like loggers to finish + self.call_hook('after_run') + + def register_lr_hooks(self, lr_config): + if isinstance(lr_config, LrUpdaterHook): + self.register_hook(lr_config) + elif isinstance(lr_config, dict): + assert 'policy' in lr_config + from ..hooks import lr_updater + hook_name = lr_config['policy'].title() + 'LrUpdaterHook' + if not hasattr(lr_updater, hook_name): + raise ValueError('"{}" does not exist'.format(hook_name)) + hook_cls = getattr(lr_updater, hook_name) + self.register_hook(hook_cls(**lr_config)) + else: + raise TypeError('"lr_config" must be either a LrUpdaterHook object' + ' or dict, not {}'.format(type(lr_config))) + + def register_logger_hooks(self, log_config): + log_interval = log_config['interval'] + for info in log_config['hooks']: + logger_hook = obj_from_dict( + info, hooks, default_args=dict(interval=log_interval)) + self.register_hook(logger_hook, priority=60) + + def register_default_hooks(self, + lr_config, + grad_clip_config=None, + checkpoint_config=None, + log_config=None): + """Register several default hooks. 
+ + Default hooks include: + - LrUpdaterHook + - OptimizerStepperHook + - CheckpointSaverHook + - IterTimerHook + - LoggerHook + """ + if grad_clip_config is None: + grad_clip_config = {} + if checkpoint_config is None: + checkpoint_config = {} + self.register_lr_hooks(lr_config) + self.register_hook(OptimizerStepperHook(**grad_clip_config)) + self.register_hook(CheckpointSaverHook(**checkpoint_config)) + self.register_hook(IterTimerHook()) + if log_config is not None: + self.register_logger_hooks(log_config) diff --git a/mmcv/torchpack/utils.py b/mmcv/torchpack/utils.py new file mode 100644 index 000000000..eaef7c47c --- /dev/null +++ b/mmcv/torchpack/utils.py @@ -0,0 +1,77 @@ +import functools +import logging +import time +from getpass import getuser +from socket import gethostname + +import mmcv +import torch.distributed as dist + + +def get_host_info(): + return '{}@{}'.format(getuser(), gethostname()) + + +def get_dist_info(): + if dist._initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper + + +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + +def add_file_handler(logger, filename=None, mode='w', level=logging.INFO): + file_handler = logging.FileHandler(filename, mode) + file_handler.setFormatter( + logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) + file_handler.setLevel(level) + logger.addHandler(file_handler) + return logger + + +def obj_from_dict(info, module, default_args=None): + """Initialize an object from a dict. + + The dict must contain the key "type", which indicates the object type. It + can be either a string or a type, such as "list" or ``list``. Remaining + fields are treated as the arguments for constructing the object. + + Args: + info (dict): Object type and arguments. + module (:class:`module`): The module which may contain the expected + object classes. + default_args (dict, optional): Default arguments for initializing the + object. + + Returns: + any type: Object constructed with the specified type and arguments. + """ + assert isinstance(info, dict) and 'type' in info + assert isinstance(default_args, dict) or default_args is None + args = info.copy() + obj_type = args.pop('type') + if mmcv.is_str(obj_type): + obj_type = getattr(module, obj_type) + elif not isinstance(obj_type, type): + raise TypeError('type must be a str or valid type, but got {}'.format( + type(obj_type))) + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + return obj_type(**args)
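
Usage note (reviewer sketch, not part of the patch): the snippet below condenses examples/train_cifar10.py above into a minimal, self-contained illustration of how the new Runner, default hooks, and workflow fit together. The toy linear model, random tensors, work_dir path, and the ClosureHook callback are hypothetical placeholders chosen only to keep the sketch runnable on a single GPU; they are not defined anywhere in this patch.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from mmcv.torchpack import ClosureHook, Runner


def batch_processor(model, data, train_mode):
    img, label = data
    label = label.cuda(non_blocking=True)
    pred = model(img)
    loss = F.cross_entropy(pred, label)
    # 'loss' is backpropagated by OptimizerStepperHook; 'log_vars' and
    # 'num_samples' are accumulated in the runner's LogBuffer.
    return dict(
        loss=loss, log_vars={'loss': loss.item()}, num_samples=img.size(0))


# toy model and random data, standing in for a real network and dataset
model = nn.DataParallel(nn.Linear(16, 4)).cuda()
loader = DataLoader(
    TensorDataset(torch.randn(64, 16), torch.randint(0, 4, (64, ))),
    batch_size=8)

runner = Runner(
    model,
    dict(type='SGD', lr=0.01, momentum=0.9),  # built through obj_from_dict
    batch_processor,
    work_dir='./toy_work_dir')
runner.register_default_hooks(
    lr_config=dict(policy='step', step=2),
    checkpoint_config=dict(interval=1),
    log_config=dict(interval=1, hooks=[dict(type='TextLoggerHook')]))
# arbitrary extra behaviour can be attached with a ClosureHook
runner.register_hook(
    ClosureHook('after_train_epoch',
                lambda r: r.logger.info('finished epoch %d', r.epoch + 1)))
runner.run([loader, loader], [('train', 1), ('val', 1)], max_epochs=4)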