add initial version of torchpack

pull/7/head
Kai Chen 2018-08-28 23:27:59 +08:00
parent 02ceae8327
commit ffdc1d457f
22 changed files with 1522 additions and 0 deletions

View File

@ -0,0 +1,31 @@
# model settings
model = 'resnet18'
# dataset settings
data_root = '/mnt/SSD/dataset/cifar10'
mean = [0.4914, 0.4822, 0.4465]
std = [0.2023, 0.1994, 0.2010]
batch_size = 64
# optimizer and learning rate
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=5e-4)
lr_policy = dict(policy='step', step=2)
# runtime settings
work_dir = './demo'
gpus = range(2)
data_workers = 2 # data workers per gpu
checkpoint_cfg = dict(interval=1) # save checkpoint at every epoch
workflow = [('train', 1), ('val', 1)]
max_epoch = 6
resume_from = None
load_from = None
# logging settings
log_level = 'INFO'
log_cfg = dict(
# log at every 50 iterations
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log'),
])

View File

@ -0,0 +1,132 @@
# copied from https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
import torch.nn as nn
import torch.nn.functional as F
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(
in_planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False), nn.BatchNorm2d(self.expansion * planes))
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(
planes, self.expansion * planes, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False), nn.BatchNorm2d(self.expansion * planes))
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=10):
super(ResNet, self).__init__()
self.in_planes = 64
self.conv1 = nn.Conv2d(
3, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.linear = nn.Linear(512 * block.expansion, num_classes)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
out = self.linear(out)
return out
def resnet18():
return ResNet(BasicBlock, [2, 2, 2, 2])
def resnet34():
return ResNet(BasicBlock, [3, 4, 6, 3])
def resnet50():
return ResNet(Bottleneck, [3, 4, 6, 3])
def resnet101():
return ResNet(Bottleneck, [3, 4, 23, 3])
def resnet152():
return ResNet(Bottleneck, [3, 8, 36, 3])

View File

@ -0,0 +1,102 @@
from argparse import ArgumentParser
from collections import OrderedDict
import torch
import torch.nn.functional as F
from mmcv import Config
from mmcv.torchpack import Runner
from torchvision import datasets, transforms
import resnet_cifar
def accuracy(output, target, topk=(1, )):
"""Computes the precision@k for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def batch_processor(model, data, train_mode):
img, label = data
label = label.cuda(non_blocking=True)
pred = model(img)
loss = F.cross_entropy(pred, label)
acc_top1, acc_top5 = accuracy(pred, label, topk=(1, 5))
log_vars = OrderedDict()
log_vars['loss'] = loss.item()
log_vars['acc_top1'] = acc_top1.item()
log_vars['acc_top5'] = acc_top5.item()
outputs = dict(loss=loss, log_vars=log_vars, num_samples=img.size(0))
return outputs
def parse_args():
parser = ArgumentParser(description='Train CIFAR-10 classification')
parser.add_argument('config', help='train config file path')
return parser.parse_args()
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
model = getattr(resnet_cifar, cfg.model)()
model = torch.nn.DataParallel(model, device_ids=cfg.gpus).cuda()
normalize = transforms.Normalize(mean=cfg.mean, std=cfg.std)
train_dataset = datasets.CIFAR10(
root=cfg.data_root,
train=True,
transform=transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
val_dataset = datasets.CIFAR10(
root=cfg.data_root,
transform=transforms.Compose([
transforms.ToTensor(),
normalize,
]))
num_workers = cfg.data_workers * len(cfg.gpus)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=cfg.batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=cfg.batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=True)
runner = Runner(model, cfg.optimizer, batch_processor, cfg.work_dir)
runner.register_default_hooks(
lr_config=cfg.lr_policy,
checkpoint_config=cfg.checkpoint_cfg,
log_config=cfg.log_cfg)
if cfg.get('resume_from') is not None:
runner.resume(cfg.resume_from)
elif cfg.get('load_from') is not None:
runner.load_checkpoint(cfg.load_from)
runner.run([train_loader, val_loader], cfg.workflow, cfg.max_epoch)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,5 @@
from .hooks import *
from .io import *
from .parallel import *
from .runner import *
from .utils import *

View File

@ -0,0 +1,7 @@
from .hook import Hook
from .checkpoint_saver import CheckpointSaverHook
from .closure import ClosureHook
from .lr_updater import LrUpdaterHook
from .optimizer_stepper import OptimizerStepperHook
from .iter_timer import IterTimerHook
from .logger import *

View File

@ -0,0 +1,25 @@
from .hook import Hook
from ..utils import master_only
class CheckpointSaverHook(Hook):
def __init__(self,
interval=-1,
save_optimizer=True,
out_dir=None,
**kwargs):
self.interval = interval
self.save_optimizer = save_optimizer
self.out_dir = out_dir
self.args = kwargs
@master_only
def after_train_epoch(self, runner):
if not self.every_n_epochs(runner, self.interval):
return
if not self.out_dir:
self.out_dir = runner.work_dir
runner.save_checkpoint(
self.out_dir, save_optimizer=self.save_optimizer, **self.args)

View File

@ -0,0 +1,9 @@
from .hook import Hook
class ClosureHook(Hook):
def __init__(self, fn_name, fn):
assert hasattr(self, fn_name)
assert callable(fn)
setattr(self, fn_name, fn)

View File

@ -0,0 +1,55 @@
class Hook(object):
def before_run(self, runner):
pass
def after_run(self, runner):
pass
def before_epoch(self, runner):
pass
def after_epoch(self, runner):
pass
def before_iter(self, runner):
pass
def after_iter(self, runner):
pass
def before_train_epoch(self, runner):
self.before_epoch(runner)
def before_val_epoch(self, runner):
self.before_epoch(runner)
def after_train_epoch(self, runner):
self.after_epoch(runner)
def after_val_epoch(self, runner):
self.after_epoch(runner)
def before_train_iter(self, runner):
self.before_iter(runner)
def before_val_iter(self, runner):
self.before_iter(runner)
def after_train_iter(self, runner):
self.after_iter(runner)
def after_val_iter(self, runner):
self.after_iter(runner)
def every_n_epochs(self, runner, n):
return (runner.epoch + 1) % n == 0 if n > 0 else False
def every_n_inner_iters(self, runner, n):
return (runner.inner_iter + 1) % n == 0 if n > 0 else False
def every_n_iters(self, runner, n):
return (runner.iter + 1) % n == 0 if n > 0 else False
def end_of_epoch(self, runner):
return runner.inner_iter + 1 == len(runner.data_loader)

View File

@ -0,0 +1,16 @@
import time
from .hook import Hook
class IterTimerHook(Hook):
def before_epoch(self, runner):
self.t = time.time()
def before_iter(self, runner):
runner.log_buffer.update({'data_time': time.time() - self.t})
def after_iter(self, runner):
runner.log_buffer.update({'time': time.time() - self.t})
self.t = time.time()

View File

@ -0,0 +1,4 @@
from .base import LoggerHook
from .pavi import PaviClient, PaviLoggerHook
from .tensorboard import TensorboardLoggerHook
from .text import TextLoggerHook

View File

@ -0,0 +1,49 @@
from abc import ABCMeta, abstractmethod
from ..hook import Hook
class LoggerHook(Hook):
"""Base class for logger hooks."""
__metaclass__ = ABCMeta
def __init__(self, interval=10, ignore_last=True, reset_flag=False):
self.interval = interval
self.ignore_last = ignore_last
self.reset_flag = reset_flag
@abstractmethod
def log(self, runner):
pass
def before_run(self, runner):
for hook in runner.hooks[::-1]:
if isinstance(hook, LoggerHook):
hook.reset_flag = True
break
def before_epoch(self, runner):
runner.log_buffer.clear() # clear logs of last epoch
def after_train_iter(self, runner):
if self.every_n_inner_iters(runner, self.interval):
runner.log_buffer.average(self.interval)
elif self.end_of_epoch(runner) and not self.ignore_last:
# not precise but more stable
runner.log_buffer.average(self.interval)
if runner.log_buffer.ready:
self.log(runner)
if self.reset_flag:
runner.log_buffer.clear_output()
def after_train_epoch(self, runner):
if runner.log_buffer.ready:
self.log(runner)
def after_val_epoch(self, runner):
runner.log_buffer.average()
self.log(runner)
if self.reset_flag:
runner.log_buffer.clear_output()

View File

@ -0,0 +1,147 @@
from __future__ import print_function
import os
import time
from datetime import datetime
from threading import Thread
import requests
from six.moves.queue import Empty, Queue
from .base import LoggerHook
from ...utils import master_only, get_host_info
class PaviClient(object):
def __init__(self, url, username=None, password=None, instance_id=None):
self.url = url
self.username = self._get_env_var(username, 'PAVI_USERNAME')
self.password = self._get_env_var(password, 'PAVI_PASSWORD')
self.instance_id = instance_id
self.log_queue = None
def _get_env_var(self, var, env_var):
if var is not None:
return str(var)
var = os.getenv(env_var)
if not var:
raise ValueError(
'"{}" is neither specified nor defined as env variables'.
format(env_var))
return var
def connect(self,
model_name,
work_dir=None,
info=dict(),
timeout=5,
logger=None):
if logger:
log_info = logger.info
log_error = logger.error
else:
log_info = log_error = print
log_info('connecting pavi service {}...'.format(self.url))
post_data = dict(
time=str(datetime.now()),
username=self.username,
password=self.password,
instance_id=self.instance_id,
model=model_name,
work_dir=os.path.abspath(work_dir) if work_dir else '',
session_file=info.get('session_file', ''),
session_text=info.get('session_text', ''),
model_text=info.get('model_text', ''),
device=get_host_info())
try:
response = requests.post(self.url, json=post_data, timeout=timeout)
except Exception as ex:
log_error('fail to connect to pavi service: {}'.format(ex))
else:
if response.status_code == 200:
self.instance_id = response.text
log_info('pavi service connected, instance_id: {}'.format(
self.instance_id))
self.log_queue = Queue()
self.log_thread = Thread(target=self.post_worker_fn)
self.log_thread.daemon = True
self.log_thread.start()
return True
else:
log_error('fail to connect to pavi service, status code: '
'{}, err message: {}'.format(response.status_code,
response.reason))
return False
def post_worker_fn(self, max_retry=3, queue_timeout=1, req_timeout=3):
while True:
try:
log = self.log_queue.get(timeout=queue_timeout)
except Empty:
time.sleep(1)
except Exception as ex:
print('fail to get logs from queue: {}'.format(ex))
else:
retry = 0
while retry < max_retry:
try:
response = requests.post(
self.url, json=log, timeout=req_timeout)
except Exception as ex:
retry += 1
print('error when posting logs to pavi: {}'.format(ex))
else:
status_code = response.status_code
if status_code == 200:
break
else:
print('unexpected status code: %d, err msg: %s',
status_code, response.reason)
retry += 1
if retry == max_retry:
print('fail to send logs of iteration %d', log['iter_num'])
def log(self, phase, iter, outputs):
if self.log_queue is not None:
logs = {
'time': str(datetime.now()),
'instance_id': self.instance_id,
'flow_id': phase,
'iter_num': iter,
'outputs': outputs,
'msg': ''
}
self.log_queue.put(logs)
class PaviLoggerHook(LoggerHook):
def __init__(self,
url,
username=None,
password=None,
instance_id=None,
interval=10,
reset_meter=True,
ignore_last=True):
self.pavi = PaviClient(url, username, password, instance_id)
super(PaviLoggerHook, self).__init__(interval, reset_meter,
ignore_last)
@master_only
def connect(self,
model_name,
work_dir=None,
info=dict(),
timeout=5,
logger=None):
return self.pavi.connect(model_name, work_dir, info, timeout, logger)
@master_only
def log(self, runner):
log_outs = runner.log_buffer.output.copy()
log_outs.pop('time', None)
log_outs.pop('data_time', None)
self.pavi.log(runner.mode, runner.iter, log_outs)

View File

@ -0,0 +1,37 @@
from .base import LoggerHook
from ...utils import master_only
class TensorboardLoggerHook(LoggerHook):
def __init__(self,
log_dir,
interval=10,
reset_meter=True,
ignore_last=True):
super(TensorboardLoggerHook, self).__init__(interval, reset_meter,
ignore_last)
self.log_dir = log_dir
@master_only
def before_run(self, runner):
try:
from tensorboardX import SummaryWriter
except ImportError:
raise ImportError('Please install tensorflow and tensorboardX '
'to use TensorboardLoggerHook.')
else:
self.writer = SummaryWriter(self.log_dir)
@master_only
def log(self, runner):
for var in runner.log_buffer.output:
if var in ['time', 'data_time']:
continue
tag = '{}/{}'.format(var, runner.mode)
self.writer.add_scalar(tag, runner.log_buffer.output[var],
runner.iter)
@master_only
def after_run(self, runner):
self.writer.close()

View File

@ -0,0 +1,26 @@
from .base import LoggerHook
class TextLoggerHook(LoggerHook):
def log(self, runner):
if runner.mode == 'train':
lr_str = ', '.join(
['{:.5f}'.format(lr) for lr in runner.current_lr()])
log_str = 'Epoch [{}][{}/{}]\tlr: {}, '.format(
runner.epoch + 1, runner.inner_iter + 1,
len(runner.data_loader), lr_str)
else:
log_str = 'Epoch({}) [{}][{}]\t'.format(runner.mode, runner.epoch,
runner.inner_iter + 1)
if 'time' in runner.log_buffer.output:
log_str += (
'time: {log[time]:.3f}, data_time: {log[data_time]:.3f}, '.
format(log=runner.log_buffer.output))
log_items = []
for name, val in runner.log_buffer.output.items():
if name in ['time', 'data_time']:
continue
log_items.append('{}: {:.4f}'.format(name, val))
log_str += ', '.join(log_items)
runner.logger.info(log_str)

View File

@ -0,0 +1,163 @@
from __future__ import division
from .hook import Hook
class LrUpdaterHook(Hook):
def __init__(self,
by_epoch=True,
warmup=None,
warmup_iters=0,
warmup_ratio=0.1,
**kwargs):
# validate the "warmup" argument
if warmup is not None:
if warmup not in ['constant', 'linear', 'exp']:
raise ValueError(
'"{}" is not a supported type for warming up, valid types'
' are "constant" and "linear"'.format(warmup))
if warmup is not None:
assert warmup_iters > 0, \
'"warmup_iters" must be a positive integer'
assert 0 < warmup_ratio <= 1.0, \
'"warmup_ratio" must be in range (0,1]'
self.by_epoch = by_epoch
self.warmup = warmup
self.warmup_iters = warmup_iters
self.warmup_ratio = warmup_ratio
self.base_lr = [] # initial lr for all param groups
self.regular_lr = [] # expected lr if no warming up is performed
def _set_lr(self, runner, lr_groups):
for param_group, lr in zip(runner.optimizer.param_groups, lr_groups):
param_group['lr'] = lr
def get_lr(self, runner, base_lr):
raise NotImplementedError
def get_regular_lr(self, runner):
return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]
def get_warmup_lr(self, cur_iters):
if self.warmup == 'constant':
warmup_lr = [_lr * self.warmup_ratio for _lr in self.regular_lr]
elif self.warmup == 'linear':
k = (1 - cur_iters / self.warmup_iters) * (1 - self.warmup_ratio)
warmup_lr = [_lr * (1 - k) for _lr in self.regular_lr]
elif self.warmup == 'exp':
k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
warmup_lr = [_lr * k for _lr in self.regular_lr]
return warmup_lr
def before_run(self, runner):
# NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
# it will be set according to the optimizer params
for group in runner.optimizer.param_groups:
group.setdefault('initial_lr', group['lr'])
self.base_lr = [
group['initial_lr'] for group in runner.optimizer.param_groups
]
def before_train_epoch(self, runner):
if not self.by_epoch:
return
self.regular_lr = self.get_regular_lr(runner)
self._set_lr(runner, self.regular_lr)
def before_train_iter(self, runner):
cur_iter = runner.iter
if not self.by_epoch:
self.regular_lr = self.get_regular_lr(runner)
if self.warmup is None or cur_iter >= self.warmup_iters:
self._set_lr(runner, self.regular_lr)
else:
warmup_lr = self.get_warmup_lr(cur_iter)
self._set_lr(runner, warmup_lr)
elif self.by_epoch:
if self.warmup is None or cur_iter > self.warmup_iters:
return
elif cur_iter == self.warmup_iters:
self._set_lr(runner, self.regular_lr)
else:
warmup_lr = self.get_warmup_lr(cur_iter)
self._set_lr(runner, warmup_lr)
class FixedLrUpdaterHook(LrUpdaterHook):
def __init__(self, **kwargs):
super(FixedLrUpdaterHook, self).__init__(**kwargs)
def get_lr(self, runner, base_lr):
return base_lr
class StepLrUpdaterHook(LrUpdaterHook):
def __init__(self, step, gamma=0.1, **kwargs):
assert isinstance(step, (list, int))
if isinstance(step, list):
for s in step:
assert isinstance(s, int) and s > 0
elif isinstance(step, int):
assert step > 0
else:
raise TypeError('"step" must be a list or integer')
self.step = step
self.gamma = gamma
super(StepLrUpdaterHook, self).__init__(**kwargs)
def get_lr(self, runner, base_lr):
progress = runner.epoch if self.by_epoch else runner.iter
if isinstance(self.step, int):
return base_lr * (self.gamma**(progress // self.step))
exp = len(self.step)
for i, s in enumerate(self.step):
if progress < s:
exp = i
break
return base_lr * self.gamma**exp
class ExpLrUpdaterHook(LrUpdaterHook):
def __init__(self, gamma, **kwargs):
self.gamma = gamma
super(ExpLrUpdaterHook, self).__init__(**kwargs)
def get_lr(self, runner, base_lr):
progress = runner.epoch if self.by_epoch else runner.iter
return base_lr * self.gamma**progress
class PolyLrUpdaterHook(LrUpdaterHook):
def __init__(self, power=1., **kwargs):
self.power = power
super(PolyLrUpdaterHook, self).__init__(**kwargs)
def get_lr(self, runner, base_lr):
if self.by_epoch:
progress = runner.epoch
max_progress = runner.max_epochs
else:
progress = runner.iter
max_progress = runner.max_iters
return base_lr * (1 - progress / max_progress)**self.power
class InvLrUpdaterHook(LrUpdaterHook):
def __init__(self, gamma, power=1., **kwargs):
self.gamma = gamma
self.power = power
super(InvLrUpdaterHook, self).__init__(**kwargs)
def get_lr(self, runner, base_lr):
progress = runner.epoch if self.by_epoch else runner.iter
return base_lr * (1 + self.gamma * progress)**(-self.power)

View File

@ -0,0 +1,21 @@
from torch.nn.utils import clip_grad
from .hook import Hook
class OptimizerStepperHook(Hook):
def __init__(self, grad_clip=False, max_norm=35, norm_type=2):
self.grad_clip = grad_clip
self.max_norm = max_norm
self.norm_type = norm_type
def after_train_iter(self, runner):
runner.optimizer.zero_grad()
runner.outputs['loss'].backward()
if self.grad_clip:
clip_grad.clip_grad_norm_(
filter(lambda p: p.requires_grad, runner.model.parameters()),
max_norm=self.max_norm,
norm_type=self.norm_type)
runner.optimizer.step()

View File

@ -0,0 +1,157 @@
import os.path as osp
import time
from collections import OrderedDict
import mmcv
import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.utils import model_zoo
def load_state_dict(module, state_dict, strict=False, logger=None):
"""Load state_dict to a module.
This method is modified from :meth:`torch.nn.Module.load_state_dict`.
Default value for ``strict`` is set to ``False`` and the message for
param mismatch will be shown even if strict is False.
Args:
module (Module): Module that receives the state_dict.
state_dict (OrderedDict): Weights.
strict (bool): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
logger (:obj:`logging.Logger`, optional): Logger to log the error
message. If not specified, print function will be used.
"""
unexpected_keys = []
own_state = module.state_dict()
for name, param in state_dict.items():
if name not in own_state:
unexpected_keys.append(name)
continue
if isinstance(param, torch.nn.Parameter):
# backwards compatibility for serialized parameters
param = param.data
try:
own_state[name].copy_(param)
except Exception:
raise RuntimeError('While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(name, own_state[name].size(),
param.size()))
missing_keys = set(own_state.keys()) - set(state_dict.keys())
err_msg = []
if unexpected_keys:
err_msg.append('unexpected key in source state_dict: {}\n'.format(
', '.join(unexpected_keys)))
if missing_keys:
err_msg.append('missing keys in source state_dict: {}\n'.format(
', '.join(missing_keys)))
err_msg = '\n'.join(err_msg)
if err_msg:
if strict:
raise RuntimeError(err_msg)
elif logger is not None:
logger.warn(err_msg)
else:
print(err_msg)
def load_checkpoint(model,
filename,
map_location=None,
strict=False,
logger=None):
"""Load checkpoint from a file or URI.
Args:
model (Module): Module to load checkpoint.
filename (str): Either a filepath or URL or modelzoll://xxxxxxx.
map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and
checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message.
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
# load checkpoint from modelzoo or file or url
if filename.startswith('modelzoo://'):
from torchvision.models.resnet import model_urls
model_name = filename[11:]
checkpoint = model_zoo.load_url(model_urls[model_name])
elif filename.startswith(('http://', 'https://')):
checkpoint = model_zoo.load_url(filename)
else:
if not osp.isfile(filename):
raise IOError('{} is not a checkpoint file'.format(filename))
checkpoint = torch.load(filename, map_location=map_location)
# get state_dict from checkpoint
if isinstance(checkpoint, OrderedDict):
state_dict = checkpoint
elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
raise RuntimeError(
'No state_dict found in checkpoint file {}'.format(filename))
# strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}
# load state_dict
if isinstance(model, (DataParallel, DistributedDataParallel)):
load_state_dict(model.module, state_dict, strict, logger)
else:
load_state_dict(model, state_dict, strict, logger)
return checkpoint
def weights_to_cpu(state_dict):
"""Copy a model state_dict to cpu.
Args:
state_dict (OrderedDict): Model weights on GPU.
Returns:
OrderedDict: Model weights on GPU.
"""
state_dict_cpu = OrderedDict()
for key, val in state_dict.items():
state_dict_cpu[key] = val.cpu()
return state_dict_cpu
def save_checkpoint(model, filename, optimizer=None, meta=None):
"""Save checkpoint to file.
The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
``optimizer``. By default ``meta`` will contain version and time info.
Args:
model (Module): Module whose params are to be saved.
filename (str): Checkpoint filename.
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
meta (dict, optional): Metadata to be saved in checkpoint.
"""
if meta is None:
meta = {}
elif not isinstance(meta, dict):
raise TypeError('meta must be a dict or None, but got {}'.format(
type(meta)))
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
mmcv.mkdir_or_exist(osp.dirname(filename))
if isinstance(model, (DataParallel, DistributedDataParallel)):
model = model.module
checkpoint = {
'meta': meta,
'state_dict': weights_to_cpu(model.state_dict())
}
if optimizer is not None:
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, filename)

View File

@ -0,0 +1,74 @@
import multiprocessing
import torch
from .io import load_checkpoint
def worker_func(model_cls, model_kwargs, checkpoint, dataset, data_func,
gpu_id, idx_queue, result_queue):
model = model_cls(**model_kwargs)
load_checkpoint(model, checkpoint, map_location='cpu')
torch.cuda.set_device(gpu_id)
model.cuda()
model.eval()
with torch.no_grad():
while True:
idx = idx_queue.get()
data = dataset[idx]
result = model(**data_func(data, gpu_id))
result_queue.put((idx, result))
def parallel_test(model_cls,
model_kwargs,
checkpoint,
dataset,
data_func,
gpus,
workers_per_gpu=1):
"""Parallel testing on multiple GPUs.
Args:
model_cls (type): Model class type.
model_kwargs (dict): Arguments to init the model.
checkpoint (str): Checkpoint filepath.
dataset (:obj:`Dataset`): The dataset to be tested.
data_func (callable): The function that generates model inputs.
gpus (list[int]): GPU ids to be used.
workers_per_gpu (int): Number of processes on each GPU. It is possible
to run multiple workers on each GPU.
Returns:
list: Test results.
"""
ctx = multiprocessing.get_context('spawn')
idx_queue = ctx.Queue()
result_queue = ctx.Queue()
num_workers = len(gpus) * workers_per_gpu
workers = [
ctx.Process(
target=worker_func,
args=(model_cls, model_kwargs, checkpoint, dataset, data_func,
gpus[i % len(gpus)], idx_queue, result_queue))
for i in range(num_workers)
]
for w in workers:
w.daemon = True
w.start()
for i in range(len(dataset)):
idx_queue.put(i)
results = [None for _ in range(len(dataset))]
import cvbase as cvb
prog_bar = cvb.ProgressBar(task_num=len(dataset))
for _ in range(len(dataset)):
idx, res = result_queue.get()
results[idx] = res
prog_bar.update()
print('\n')
for worker in workers:
worker.terminate()
return results

View File

@ -0,0 +1,2 @@
from .log_buffer import LogBuffer
from .runner import Runner

View File

@ -0,0 +1,39 @@
from collections import OrderedDict
import numpy as np
class LogBuffer(object):
def __init__(self):
self.val_history = OrderedDict()
self.n_history = OrderedDict()
self.output = OrderedDict()
self.ready = False
def clear(self):
self.val_history.clear()
self.n_history.clear()
self.clear_output()
def clear_output(self):
self.output.clear()
self.ready = False
def update(self, vars, count=1):
assert isinstance(vars, dict)
for key, var in vars.items():
if key not in self.val_history:
self.val_history[key] = []
self.n_history[key] = []
self.val_history[key].append(var)
self.n_history[key].append(count)
def average(self, n=0):
"""Average latest n values or all values"""
assert n >= 0
for key in self.val_history:
values = np.array(self.val_history[key][-n:])
nums = np.array(self.n_history[key][-n:])
avg = np.sum(values * nums) / np.sum(nums)
self.output[key] = avg
self.ready = True

View File

@ -0,0 +1,344 @@
import logging
import os.path as osp
import time
import mmcv
import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel
from .log_buffer import LogBuffer
from .. import hooks
from ..hooks import (Hook, LrUpdaterHook, CheckpointSaverHook, IterTimerHook,
OptimizerStepperHook)
from ..io import load_checkpoint, save_checkpoint
from ..utils import (get_dist_info, get_host_info, get_time_str,
add_file_handler, obj_from_dict)
class Runner(object):
"""A training helper for PyTorch."""
def __init__(self,
model,
optimizer,
batch_processor,
work_dir=None,
log_level=logging.INFO):
assert callable(batch_processor)
self.model = model
self.optimizer = self.init_optimizer(optimizer)
self.batch_processor = batch_processor
# create work_dir
if mmcv.is_str(work_dir):
self.work_dir = osp.abspath(work_dir)
mmcv.mkdir_or_exist(self.work_dir)
elif work_dir is None:
self.work_dir = None
else:
raise TypeError('"work_dir" must be a str or None')
# get model name from the model class
if isinstance(self.model, (DataParallel, DistributedDataParallel)):
self._model_name = self.model.module.__class__.__name__
else:
self._model_name = self.model.__class__.__name__
self._rank, self._world_size = get_dist_info()
self.logger = self.init_logger(work_dir, log_level)
self.log_buffer = LogBuffer()
self.mode = None
self._hooks = []
self._epoch = 0
self._iter = 0
self._inner_iter = 0
self._max_epochs = 0
self._max_iters = 0
@property
def model_name(self):
"""str: Name of the model, usually the module class name."""
return self._model_name
@property
def rank(self):
"""int: Rank of current process. (distributed training)"""
return self._rank
@property
def world_size(self):
"""int: Number of processes participating in the job.
(distributed training)"""
return self._world_size
@property
def hooks(self):
"""list[:obj:`Hook`]: A list of registered hooks."""
return self._hooks
@property
def epoch(self):
"""int: Current epoch."""
return self._epoch
@property
def iter(self):
"""int: Current iteration."""
return self._iter
@property
def inner_iter(self):
"""int: Iteration in an epoch."""
return self._inner_iter
@property
def max_epochs(self):
"""int: Maximum training epochs."""
return self._max_epochs
@property
def max_iters(self):
"""int: Maximum training iterations."""
return self._max_iters
def init_optimizer(self, optimizer):
"""Init the optimizer.
Args:
optimizer (dict or :obj:`~torch.optim.Optimizer`): Either an
optimizer object or a dict used for constructing the optimizer.
An example of the dict: ``{'algorithm': 'SGD', 'lr': 0.02,
'momentum': 0.9, 'weight_decay': 0.0001}``.
Returns:
:obj:`~torch.optim.Optimizer`: An optimizer object.
"""
if isinstance(optimizer, dict):
optimizer = obj_from_dict(
optimizer, torch.optim, dict(params=self.model.parameters()))
elif not isinstance(optimizer, torch.optim.Optimizer):
raise TypeError(
'optimizer must be either an Optimizer object or a dict, '
'but got {}'.format(type(optimizer)))
return optimizer
def init_logger(self, log_dir=None, level=logging.INFO):
"""Init the logger.
Args:
log_dir(str, optional): Log file directory. If not specified, no
log file will be used.
level (int or str): See the built-in python logging module.
Returns:
:obj:`~logging.Logger`: Python logger.
"""
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s', level=level)
logger = logging.getLogger(__name__)
if log_dir:
filename = '{}_{}.log'.format(get_time_str(), self.rank)
log_file = osp.join(log_dir, filename)
add_file_handler(logger, log_file, level=level)
return logger
def current_lr(self):
"""Get current learning rates.
Returns:
list: Current learning rate of all param groups.
"""
return [group['lr'] for group in self.optimizer.param_groups]
def register_hook(self, hook, priority=50):
"""Register a hook into the hook list.
Args:
hook (:obj:`Hook`): The hook to be registered.
priority (int): Hook priority. Lower value means higher priority.
"""
assert isinstance(hook, Hook)
assert isinstance(priority, int) and priority >= 0 and priority <= 100
if hasattr(hook, 'priority'):
raise ValueError('"priority" is a reserved attribute for hooks')
hook.priority = priority
# insert the hook to a sorted list
inserted = False
for i in range(len(self._hooks) - 1, -1, -1):
if priority >= self._hooks[i].priority:
self._hooks.insert(i + 1, hook)
inserted = True
break
if not inserted:
self._hooks.insert(0, hook)
def call_hook(self, fn_name):
for hook in self._hooks:
getattr(hook, fn_name)(self)
def load_checkpoint(self, filename, map_location='cpu', strict=False):
self.logger.info('load checkpoint from %s', filename)
return load_checkpoint(self.model, filename, map_location, strict,
self.logger)
def save_checkpoint(self,
out_dir,
filename_tmpl='epoch_{}.pth',
save_optimizer=True,
meta=None):
if meta is None:
meta = dict(epoch=self.epoch + 1, iter=self.iter)
else:
meta.update(epoch=self.epoch + 1, iter=self.iter)
filename = osp.join(out_dir, filename_tmpl.format(self.epoch))
linkname = osp.join(out_dir, 'latest.pth')
optimizer = self.optimizer if save_optimizer else None
save_checkpoint(self.model, filename, optimizer=optimizer, meta=meta)
mmcv.symlink(filename, linkname)
def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(data_loader)
self.call_hook('before_train_epoch')
for i, data_batch in enumerate(data_loader):
self._inner_iter = i
self.call_hook('before_train_iter')
outputs = self.batch_processor(
self.model, data_batch, train_mode=True, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('batch_processor() must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'],
outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_train_iter')
self._iter += 1
self.call_hook('after_train_epoch')
self._epoch += 1
def val(self, data_loader, **kwargs):
self.model.eval()
self.mode = 'val'
self.data_loader = data_loader
self.call_hook('before_val_epoch')
for i, data_batch in enumerate(data_loader):
self._inner_iter = i
self.call_hook('before_val_iter')
outputs = self.batch_processor(
self.model, data_batch, train_mode=False, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('batch_processor() must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'],
outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_val_iter')
self.call_hook('after_val_epoch')
def resume(self, checkpoint, resume_optimizer=True,
map_location='default'):
if map_location == 'default':
device_id = torch.cuda.current_device()
checkpoint = self.load_checkpoint(
checkpoint,
map_location=lambda storage, loc: storage.cuda(device_id))
else:
checkpoint = self.load_checkpoint(
checkpoint, map_location=map_location)
self._epoch = checkpoint['meta']['epoch']
self._iter = checkpoint['meta']['iter']
if 'optimizer' in checkpoint and resume_optimizer:
self.optimizer.load_state_dict(checkpoint['optimizer'])
self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
def run(self, data_loaders, workflow, max_epochs, **kwargs):
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
self._max_epochs = max_epochs
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
self.call_hook('before_run')
while self.epoch < max_epochs:
for i, flow in enumerate(workflow):
mode, epochs = flow
if isinstance(mode, str): # self.train()
if not hasattr(self, mode):
raise ValueError(
'runner has no method named "{}" to run an epoch'.
format(mode))
epoch_runner = getattr(self, mode)
elif callable(mode): # custom train()
epoch_runner = mode
else:
raise TypeError('mode in workflow must be a str or '
'callable function, not {}'.format(
type(mode)))
for _ in range(epochs):
if mode == 'train' and self.epoch >= max_epochs:
return
epoch_runner(data_loaders[i], **kwargs)
time.sleep(1) # wait for some hooks like loggers to finish
self.call_hook('after_run')
def register_lr_hooks(self, lr_config):
if isinstance(lr_config, LrUpdaterHook):
self.register_hook(lr_config)
elif isinstance(lr_config, dict):
assert 'policy' in lr_config
from ..hooks import lr_updater
hook_name = lr_config['policy'].title() + 'LrUpdaterHook'
if not hasattr(lr_updater, hook_name):
raise ValueError('"{}" does not exist'.format(hook_name))
hook_cls = getattr(lr_updater, hook_name)
self.register_hook(hook_cls(**lr_config))
else:
raise TypeError('"lr_config" must be either a LrUpdaterHook object'
' or dict, not {}'.format(type(lr_config)))
def register_logger_hooks(self, log_config):
log_interval = log_config['interval']
for info in log_config['hooks']:
logger_hook = obj_from_dict(
info, hooks, default_args=dict(interval=log_interval))
self.register_hook(logger_hook, priority=60)
def register_default_hooks(self,
lr_config,
grad_clip_config=None,
checkpoint_config=None,
log_config=None):
"""Register several default hooks.
Default hooks include:
- LrUpdaterHook
- OptimizerStepperHook
- CheckpointSaverHook
- IterTimerHook
- LoggerHook
"""
if grad_clip_config is None:
grad_clip_config = {}
if checkpoint_config is None:
checkpoint_config = {}
self.register_lr_hooks(lr_config)
self.register_hook(OptimizerStepperHook(**grad_clip_config))
self.register_hook(CheckpointSaverHook(**checkpoint_config))
self.register_hook(IterTimerHook())
if log_config is not None:
self.register_logger_hooks(log_config)

View File

@ -0,0 +1,77 @@
import functools
import logging
import time
from getpass import getuser
from socket import gethostname
import mmcv
import torch.distributed as dist
def get_host_info():
return '{}@{}'.format(getuser(), gethostname())
def get_dist_info():
if dist._initialized:
rank = dist.get_rank()
world_size = dist.get_world_size()
else:
rank = 0
world_size = 1
return rank, world_size
def master_only(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
rank, _ = get_dist_info()
if rank == 0:
return func(*args, **kwargs)
return wrapper
def get_time_str():
return time.strftime('%Y%m%d_%H%M%S', time.localtime())
def add_file_handler(logger, filename=None, mode='w', level=logging.INFO):
file_handler = logging.FileHandler(filename, mode)
file_handler.setFormatter(
logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(file_handler)
return logger
def obj_from_dict(info, module, default_args=None):
"""Initialize an object from dict.
The dict must contain the key "type", which indicates the object type, it
can be either a string or type, such as "list" or ``list``. Remaining
fields are treated as the arguments for constructing the object.
Args:
info (dict): Object types and arguments.
module (:class:`module`): Module which may containing expected object
classes.
default_args (dict, optional): Default arguments for initializing the
object.
Returns:
"""
assert isinstance(info, dict) and 'type' in info
assert isinstance(default_args, dict) or default_args is None
args = info.copy()
obj_type = args.pop('type')
if mmcv.is_str(obj_type):
obj_type = getattr(module, obj_type)
elif not isinstance(obj_type, type):
raise TypeError('type must be a str or valid type, but got {}'.format(
type(obj_type)))
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
return obj_type(**args)