mirror of https://github.com/open-mmlab/mmcv.git
add initial version of torchpack
parent
02ceae8327
commit
ffdc1d457f
|
@ -0,0 +1,31 @@
|
|||
# model settings
# Name of a factory function in resnet_cifar (e.g. resnet18/resnet34/...).
model = 'resnet18'
# dataset settings
data_root = '/mnt/SSD/dataset/cifar10'
# Per-channel normalization statistics for CIFAR-10 (RGB order).
mean = [0.4914, 0.4822, 0.4465]
std = [0.2023, 0.1994, 0.2010]
batch_size = 64

# optimizer and learning rate
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=5e-4)
# step policy: decay the lr every 2 epochs (see StepLrUpdaterHook).
lr_policy = dict(policy='step', step=2)

# runtime settings
work_dir = './demo'
gpus = range(2)
data_workers = 2  # data workers per gpu
checkpoint_cfg = dict(interval=1)  # save checkpoint at every epoch
# Alternate one training epoch with one validation epoch.
workflow = [('train', 1), ('val', 1)]
max_epoch = 6
# resume_from restores optimizer/epoch state; load_from only loads weights.
resume_from = None
load_from = None

# logging settings
log_level = 'INFO'
log_cfg = dict(
    # log at every 50 iterations
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log'),
    ])
|
|
@ -0,0 +1,132 @@
|
|||
# copied from https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
    """Two 3x3-conv residual block used by ResNet-18/34 (expansion = 1)."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        out_planes = self.expansion * planes
        # main path: conv-bn -> conv-bn (ReLU applied in forward)
        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # identity shortcut unless the shape changes; then a 1x1 projection
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                nn.BatchNorm2d(out_planes))
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return F.relu(out)
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (expansion = 4)."""

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        out_planes = self.expansion * planes
        # reduce -> spatial conv (may downsample) -> expand
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # 1x1 projection shortcut when shape changes, identity otherwise
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                nn.BatchNorm2d(out_planes))
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        return F.relu(out)
|
||||
|
||||
|
||||
class ResNet(nn.Module):
    """CIFAR-style ResNet built from a residual block class.

    Expects 32x32 RGB inputs: the four stages use strides 1/2/2/2, leaving a
    4x4 feature map that is average-pooled before the final linear layer.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64
        # stem: single 3x3 conv, no downsampling
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # four residual stages; only the first keeps the input resolution
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # only the first block of a stage may downsample; the rest use stride 1
        layers = []
        for s in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return self.linear(out)
|
||||
|
||||
|
||||
def resnet18():
    """ResNet-18: BasicBlock with [2, 2, 2, 2] blocks per stage."""
    return ResNet(BasicBlock, [2, 2, 2, 2])


def resnet34():
    """ResNet-34: BasicBlock with [3, 4, 6, 3] blocks per stage."""
    return ResNet(BasicBlock, [3, 4, 6, 3])


def resnet50():
    """ResNet-50: Bottleneck with [3, 4, 6, 3] blocks per stage."""
    return ResNet(Bottleneck, [3, 4, 6, 3])


def resnet101():
    """ResNet-101: Bottleneck with [3, 4, 23, 3] blocks per stage."""
    return ResNet(Bottleneck, [3, 4, 23, 3])


def resnet152():
    """ResNet-152: Bottleneck with [3, 8, 36, 3] blocks per stage."""
    return ResNet(Bottleneck, [3, 8, 36, 3])
|
|
@ -0,0 +1,102 @@
|
|||
from argparse import ArgumentParser
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmcv import Config
|
||||
from mmcv.torchpack import Runner
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
import resnet_cifar
|
||||
|
||||
|
||||
def accuracy(output, target, topk=(1, )):
    """Computes the precision@k for the specified values of k.

    Args:
        output (Tensor): Class scores of shape (N, num_classes).
        target (Tensor): Ground-truth labels of shape (N,).
        topk (tuple[int]): Which top-k accuracies to compute.

    Returns:
        list[Tensor]: One single-element tensor (a percentage) per k.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # FIX: use reshape instead of view — `correct[:k]` is a slice of
            # a transposed (non-contiguous) tensor, and .view(-1) raises a
            # RuntimeError on non-contiguous tensors for k > 1.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
|
||||
|
||||
|
||||
def batch_processor(model, data, train_mode):
    """Run one forward pass and package loss/accuracy for the Runner."""
    img, label = data
    label = label.cuda(non_blocking=True)
    pred = model(img)
    loss = F.cross_entropy(pred, label)
    acc_top1, acc_top5 = accuracy(pred, label, topk=(1, 5))
    # scalar values that the logger hooks will average and print
    log_vars = OrderedDict(
        loss=loss.item(),
        acc_top1=acc_top1.item(),
        acc_top5=acc_top5.item())
    return dict(loss=loss, log_vars=log_vars, num_samples=img.size(0))
|
||||
|
||||
|
||||
def parse_args():
    """Parse command line arguments: a single positional config file path."""
    parser = ArgumentParser(description='Train CIFAR-10 classification')
    parser.add_argument('config', help='train config file path')
    args = parser.parse_args()
    return args
|
||||
|
||||
|
||||
def main():
    """Build the model and CIFAR-10 pipeline from the config, then train."""
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # cfg.model names a factory in resnet_cifar, e.g. 'resnet18'
    model = getattr(resnet_cifar, cfg.model)()
    model = torch.nn.DataParallel(model, device_ids=cfg.gpus).cuda()

    normalize = transforms.Normalize(mean=cfg.mean, std=cfg.std)
    train_dataset = datasets.CIFAR10(
        root=cfg.data_root,
        train=True,
        transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    # FIX: train=False selects the test split; without it torchvision's
    # default (train=True) made validation run on the training data.
    val_dataset = datasets.CIFAR10(
        root=cfg.data_root,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    # total workers scale with the number of GPUs
    num_workers = cfg.data_workers * len(cfg.gpus)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)

    runner = Runner(model, cfg.optimizer, batch_processor, cfg.work_dir)
    runner.register_default_hooks(
        lr_config=cfg.lr_policy,
        checkpoint_config=cfg.checkpoint_cfg,
        log_config=cfg.log_cfg)

    # resuming (weights + optimizer + epoch) takes precedence over a plain
    # weight load
    if cfg.get('resume_from') is not None:
        runner.resume(cfg.resume_from)
    elif cfg.get('load_from') is not None:
        runner.load_checkpoint(cfg.load_from)

    runner.run([train_loader, val_loader], cfg.workflow, cfg.max_epoch)


if __name__ == '__main__':
    main()
|
|
@ -0,0 +1,5 @@
|
|||
from .hooks import *
|
||||
from .io import *
|
||||
from .parallel import *
|
||||
from .runner import *
|
||||
from .utils import *
|
|
@ -0,0 +1,7 @@
|
|||
from .hook import Hook
|
||||
from .checkpoint_saver import CheckpointSaverHook
|
||||
from .closure import ClosureHook
|
||||
from .lr_updater import LrUpdaterHook
|
||||
from .optimizer_stepper import OptimizerStepperHook
|
||||
from .iter_timer import IterTimerHook
|
||||
from .logger import *
|
|
@ -0,0 +1,25 @@
|
|||
from .hook import Hook
|
||||
from ..utils import master_only
|
||||
|
||||
|
||||
class CheckpointSaverHook(Hook):
    """Saves a checkpoint via ``runner.save_checkpoint`` every few epochs.

    Attributes:
        interval: save every ``interval`` epochs; <= 0 disables saving.
        save_optimizer: whether to include the optimizer state.
        out_dir: target directory; falls back to ``runner.work_dir``.
        args: extra kwargs forwarded to ``runner.save_checkpoint``.
    """

    def __init__(self,
                 interval=-1,
                 save_optimizer=True,
                 out_dir=None,
                 **kwargs):
        self.interval = interval
        self.save_optimizer = save_optimizer
        self.out_dir = out_dir
        self.args = kwargs

    @master_only  # presumably only rank 0 writes checkpoints — see utils
    def after_train_epoch(self, runner):
        if not self.every_n_epochs(runner, self.interval):
            return

        if not self.out_dir:
            # resolved lazily: work_dir is only known once the runner exists
            self.out_dir = runner.work_dir
        runner.save_checkpoint(
            self.out_dir, save_optimizer=self.save_optimizer, **self.args)
|
|
@ -0,0 +1,9 @@
|
|||
from .hook import Hook
|
||||
|
||||
|
||||
class ClosureHook(Hook):
    """Hook that binds an arbitrary callable to one hook event.

    The callable replaces the method named ``fn_name`` (e.g.
    ``'after_train_epoch'``) on this instance, so it is invoked with the
    runner at that point of the loop.
    """

    def __init__(self, fn_name, fn):
        # fn_name must be an existing Hook event method
        assert hasattr(self, fn_name)
        assert callable(fn)
        setattr(self, fn_name, fn)
|
|
@ -0,0 +1,55 @@
|
|||
class Hook(object):
    """Base class for runner hooks.

    Provides no-op callbacks for every stage of the training loop plus a few
    interval helpers; subclasses override only the events they care about.
    The train/val-specific callbacks default to the generic epoch/iter ones.
    """

    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_epoch(self, runner):
        pass

    def after_epoch(self, runner):
        pass

    def before_iter(self, runner):
        pass

    def after_iter(self, runner):
        pass

    def before_train_epoch(self, runner):
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        self.before_epoch(runner)

    def after_train_epoch(self, runner):
        self.after_epoch(runner)

    def after_val_epoch(self, runner):
        self.after_epoch(runner)

    def before_train_iter(self, runner):
        self.before_iter(runner)

    def before_val_iter(self, runner):
        self.before_iter(runner)

    def after_train_iter(self, runner):
        self.after_iter(runner)

    def after_val_iter(self, runner):
        self.after_iter(runner)

    def every_n_epochs(self, runner, n):
        # n <= 0 means "never"
        if n <= 0:
            return False
        return (runner.epoch + 1) % n == 0

    def every_n_inner_iters(self, runner, n):
        if n <= 0:
            return False
        return (runner.inner_iter + 1) % n == 0

    def every_n_iters(self, runner, n):
        if n <= 0:
            return False
        return (runner.iter + 1) % n == 0

    def end_of_epoch(self, runner):
        # true on the last batch of the current data loader
        return runner.inner_iter + 1 == len(runner.data_loader)
|
|
@ -0,0 +1,16 @@
|
|||
import time
|
||||
|
||||
from .hook import Hook
|
||||
|
||||
|
||||
class IterTimerHook(Hook):
    """Measures per-iteration wall time and data-loading time.

    Pushes two scalars into ``runner.log_buffer``: ``data_time`` (time spent
    fetching the batch) and ``time`` (total iteration time).
    """

    def before_epoch(self, runner):
        # restart the clock at the beginning of every epoch
        self.t = time.time()

    def before_iter(self, runner):
        # elapsed since the last tick == time spent in the data loader
        runner.log_buffer.update({'data_time': time.time() - self.t})

    def after_iter(self, runner):
        # total time for the whole iteration, then reset the tick
        runner.log_buffer.update({'time': time.time() - self.t})
        self.t = time.time()
|
|
@ -0,0 +1,4 @@
|
|||
from .base import LoggerHook
|
||||
from .pavi import PaviClient, PaviLoggerHook
|
||||
from .tensorboard import TensorboardLoggerHook
|
||||
from .text import TextLoggerHook
|
|
@ -0,0 +1,49 @@
|
|||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
from ..hook import Hook
|
||||
|
||||
|
||||
class LoggerHook(Hook):
    """Base class for logger hooks."""

    __metaclass__ = ABCMeta

    def __init__(self, interval=10, ignore_last=True, reset_flag=False):
        # interval: log every `interval` inner iterations
        # ignore_last: skip the shorter trailing partial interval of an epoch
        # reset_flag: clear the buffered output after this hook logs it
        self.interval = interval
        self.ignore_last = ignore_last
        self.reset_flag = reset_flag

    @abstractmethod
    def log(self, runner):
        # Subclasses emit runner.log_buffer.output to their backend here.
        pass

    def before_run(self, runner):
        # Make only the LAST registered logger hook clear the buffer, so all
        # logger hooks see the same averaged output before it is reset.
        for hook in runner.hooks[::-1]:
            if isinstance(hook, LoggerHook):
                hook.reset_flag = True
                break

    def before_epoch(self, runner):
        runner.log_buffer.clear()  # clear logs of last epoch

    def after_train_iter(self, runner):
        if self.every_n_inner_iters(runner, self.interval):
            runner.log_buffer.average(self.interval)
        elif self.end_of_epoch(runner) and not self.ignore_last:
            # not precise but more stable
            runner.log_buffer.average(self.interval)

        # log_buffer.ready is presumably set by average() — confirm in
        # the LogBuffer implementation
        if runner.log_buffer.ready:
            self.log(runner)
            if self.reset_flag:
                runner.log_buffer.clear_output()

    def after_train_epoch(self, runner):
        if runner.log_buffer.ready:
            self.log(runner)

    def after_val_epoch(self, runner):
        # validation logs once per epoch, averaged over all iterations
        runner.log_buffer.average()
        self.log(runner)
        if self.reset_flag:
            runner.log_buffer.clear_output()
|
|
@ -0,0 +1,147 @@
|
|||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from threading import Thread
|
||||
|
||||
import requests
|
||||
from six.moves.queue import Empty, Queue
|
||||
|
||||
from .base import LoggerHook
|
||||
from ...utils import master_only, get_host_info
|
||||
|
||||
|
||||
class PaviClient(object):
    """HTTP client that registers a training session with a PAVI server and
    streams log dicts to it from a background daemon thread."""

    def __init__(self, url, username=None, password=None, instance_id=None):
        self.url = url
        # credentials fall back to environment variables when not given
        self.username = self._get_env_var(username, 'PAVI_USERNAME')
        self.password = self._get_env_var(password, 'PAVI_PASSWORD')
        self.instance_id = instance_id
        # created on successful connect(); None means "not connected"
        self.log_queue = None

    def _get_env_var(self, var, env_var):
        # Return `var` if given, otherwise read it from os.environ[env_var].
        if var is not None:
            return str(var)

        var = os.getenv(env_var)
        if not var:
            raise ValueError(
                '"{}" is neither specified nor defined as env variables'.
                format(env_var))
        return var

    def connect(self,
                model_name,
                work_dir=None,
                info=dict(),
                timeout=5,
                logger=None):
        """Register this session with the server; start the sender thread.

        Returns True on success, a falsy value otherwise.
        """
        if logger:
            log_info = logger.info
            log_error = logger.error
        else:
            log_info = log_error = print
        log_info('connecting pavi service {}...'.format(self.url))
        post_data = dict(
            time=str(datetime.now()),
            username=self.username,
            password=self.password,
            instance_id=self.instance_id,
            model=model_name,
            work_dir=os.path.abspath(work_dir) if work_dir else '',
            session_file=info.get('session_file', ''),
            session_text=info.get('session_text', ''),
            model_text=info.get('model_text', ''),
            device=get_host_info())
        try:
            response = requests.post(self.url, json=post_data, timeout=timeout)
        except Exception as ex:
            log_error('fail to connect to pavi service: {}'.format(ex))
        else:
            if response.status_code == 200:
                # server replies with the session's instance id in the body
                self.instance_id = response.text
                log_info('pavi service connected, instance_id: {}'.format(
                    self.instance_id))
                self.log_queue = Queue()
                self.log_thread = Thread(target=self.post_worker_fn)
                self.log_thread.daemon = True
                self.log_thread.start()
                return True
            else:
                log_error('fail to connect to pavi service, status code: '
                          '{}, err message: {}'.format(response.status_code,
                                                       response.reason))
        # NOTE(review): the flattened diff does not show whether this return
        # belongs to the inner else-branch or the function body; placed at
        # function level (except-path then returns False too) — confirm
        # against the original file.
        return False

    def post_worker_fn(self, max_retry=3, queue_timeout=1, req_timeout=3):
        # Daemon-thread loop: pull log dicts off the queue and POST each one,
        # retrying a failed POST up to max_retry times.
        while True:
            try:
                log = self.log_queue.get(timeout=queue_timeout)
            except Empty:
                time.sleep(1)
            except Exception as ex:
                print('fail to get logs from queue: {}'.format(ex))
            else:
                retry = 0
                while retry < max_retry:
                    try:
                        response = requests.post(
                            self.url, json=log, timeout=req_timeout)
                    except Exception as ex:
                        retry += 1
                        print('error when posting logs to pavi: {}'.format(ex))
                    else:
                        status_code = response.status_code
                        if status_code == 200:
                            break
                        else:
                            # NOTE(review): printf-style placeholders passed
                            # to print() are emitted literally, not formatted
                            print('unexpected status code: %d, err msg: %s',
                                  status_code, response.reason)
                            retry += 1
                if retry == max_retry:
                    print('fail to send logs of iteration %d', log['iter_num'])

    def log(self, phase, iter, outputs):
        """Queue one log record; silently dropped when not connected."""
        if self.log_queue is not None:
            logs = {
                'time': str(datetime.now()),
                'instance_id': self.instance_id,
                'flow_id': phase,
                'iter_num': iter,
                'outputs': outputs,
                'msg': ''
            }
            self.log_queue.put(logs)
|
||||
|
||||
|
||||
class PaviLoggerHook(LoggerHook):
    """Logger hook that streams the runner's log buffer to a PAVI server.

    Args:
        url (str): PAVI service endpoint.
        username/password (str, optional): credentials; fall back to the
            PAVI_USERNAME / PAVI_PASSWORD environment variables.
        instance_id (str, optional): existing session id to attach to.
        interval (int): log every ``interval`` inner iterations.
        reset_meter (bool): clear the log buffer output after logging
            (mapped to the base class ``reset_flag``).
        ignore_last (bool): skip the trailing partial interval of an epoch.
    """

    def __init__(self,
                 url,
                 username=None,
                 password=None,
                 instance_id=None,
                 interval=10,
                 reset_meter=True,
                 ignore_last=True):
        self.pavi = PaviClient(url, username, password, instance_id)
        # FIX: the base signature is (interval, ignore_last, reset_flag);
        # the original positional call (interval, reset_meter, ignore_last)
        # swapped ignore_last and the reset flag. Pass by keyword.
        super(PaviLoggerHook, self).__init__(
            interval, ignore_last=ignore_last, reset_flag=reset_meter)

    @master_only
    def connect(self,
                model_name,
                work_dir=None,
                info=dict(),
                timeout=5,
                logger=None):
        """Open the PAVI session (rank 0 only)."""
        return self.pavi.connect(model_name, work_dir, info, timeout, logger)

    @master_only
    def log(self, runner):
        # timing entries are text-log material only; do not upload them
        log_outs = runner.log_buffer.output.copy()
        log_outs.pop('time', None)
        log_outs.pop('data_time', None)
        self.pavi.log(runner.mode, runner.iter, log_outs)
|
|
@ -0,0 +1,37 @@
|
|||
from .base import LoggerHook
|
||||
from ...utils import master_only
|
||||
|
||||
|
||||
class TensorboardLoggerHook(LoggerHook):
    """Logger hook that writes scalars to TensorBoard via tensorboardX.

    Args:
        log_dir (str): Directory for the event files.
        interval (int): log every ``interval`` inner iterations.
        reset_meter (bool): clear the log buffer output after logging
            (mapped to the base class ``reset_flag``).
        ignore_last (bool): skip the trailing partial interval of an epoch.
    """

    def __init__(self,
                 log_dir,
                 interval=10,
                 reset_meter=True,
                 ignore_last=True):
        # FIX: the base signature is (interval, ignore_last, reset_flag);
        # the original positional call (interval, reset_meter, ignore_last)
        # swapped ignore_last and the reset flag. Pass by keyword.
        super(TensorboardLoggerHook, self).__init__(
            interval, ignore_last=ignore_last, reset_flag=reset_meter)
        self.log_dir = log_dir

    @master_only
    def before_run(self, runner):
        try:
            from tensorboardX import SummaryWriter
        except ImportError:
            raise ImportError('Please install tensorflow and tensorboardX '
                              'to use TensorboardLoggerHook.')
        else:
            self.writer = SummaryWriter(self.log_dir)

    @master_only
    def log(self, runner):
        # skip the timing entries; they are only useful in the text log
        for var in runner.log_buffer.output:
            if var in ['time', 'data_time']:
                continue
            tag = '{}/{}'.format(var, runner.mode)
            self.writer.add_scalar(tag, runner.log_buffer.output[var],
                                   runner.iter)

    @master_only
    def after_run(self, runner):
        self.writer.close()
|
|
@ -0,0 +1,26 @@
|
|||
from .base import LoggerHook
|
||||
|
||||
|
||||
class TextLoggerHook(LoggerHook):
    """Logger hook that writes a human-readable line via ``runner.logger``."""

    def log(self, runner):
        if runner.mode == 'train':
            # include the current learning rate(s) during training
            lr_str = ', '.join(
                ['{:.5f}'.format(lr) for lr in runner.current_lr()])
            log_str = 'Epoch [{}][{}/{}]\tlr: {}, '.format(
                runner.epoch + 1, runner.inner_iter + 1,
                len(runner.data_loader), lr_str)
        else:
            log_str = 'Epoch({}) [{}][{}]\t'.format(runner.mode, runner.epoch,
                                                    runner.inner_iter + 1)
        if 'time' in runner.log_buffer.output:
            # iteration / data-loading times come from IterTimerHook
            log_str += (
                'time: {log[time]:.3f}, data_time: {log[data_time]:.3f}, '.
                format(log=runner.log_buffer.output))
        # append every remaining scalar except the timing entries
        log_items = []
        for name, val in runner.log_buffer.output.items():
            if name in ['time', 'data_time']:
                continue
            log_items.append('{}: {:.4f}'.format(name, val))
        log_str += ', '.join(log_items)
        runner.logger.info(log_str)
|
|
@ -0,0 +1,163 @@
|
|||
from __future__ import division
|
||||
|
||||
from .hook import Hook
|
||||
|
||||
|
||||
class LrUpdaterHook(Hook):
    """Base hook for learning rate schedules with optional warmup.

    Subclasses implement :meth:`get_lr` to define the regular schedule; this
    base class applies it per epoch or per iteration (``by_epoch``) and
    blends in a warmup phase during the first ``warmup_iters`` iterations.
    """

    def __init__(self,
                 by_epoch=True,
                 warmup=None,
                 warmup_iters=0,
                 warmup_ratio=0.1,
                 **kwargs):
        # validate the "warmup" argument
        if warmup is not None:
            if warmup not in ['constant', 'linear', 'exp']:
                raise ValueError(
                    '"{}" is not a supported type for warming up, valid types'
                    ' are "constant" and "linear"'.format(warmup))
        if warmup is not None:
            assert warmup_iters > 0, \
                '"warmup_iters" must be a positive integer'
            assert 0 < warmup_ratio <= 1.0, \
                '"warmup_ratio" must be in range (0,1]'

        self.by_epoch = by_epoch
        self.warmup = warmup
        self.warmup_iters = warmup_iters
        self.warmup_ratio = warmup_ratio

        self.base_lr = []  # initial lr for all param groups
        self.regular_lr = []  # expected lr if no warming up is performed

    def _set_lr(self, runner, lr_groups):
        # write one lr into each optimizer param group
        for param_group, lr in zip(runner.optimizer.param_groups, lr_groups):
            param_group['lr'] = lr

    def get_lr(self, runner, base_lr):
        # Subclasses return the scheduled lr for a single base lr.
        raise NotImplementedError

    def get_regular_lr(self, runner):
        return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]

    def get_warmup_lr(self, cur_iters):
        # ramp from warmup_ratio * regular_lr toward regular_lr; 'constant'
        # stays at the ratio until warmup ends
        if self.warmup == 'constant':
            warmup_lr = [_lr * self.warmup_ratio for _lr in self.regular_lr]
        elif self.warmup == 'linear':
            k = (1 - cur_iters / self.warmup_iters) * (1 - self.warmup_ratio)
            warmup_lr = [_lr * (1 - k) for _lr in self.regular_lr]
        elif self.warmup == 'exp':
            k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
            warmup_lr = [_lr * k for _lr in self.regular_lr]
        return warmup_lr

    def before_run(self, runner):
        # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
        # it will be set according to the optimizer params
        for group in runner.optimizer.param_groups:
            group.setdefault('initial_lr', group['lr'])
        self.base_lr = [
            group['initial_lr'] for group in runner.optimizer.param_groups
        ]

    def before_train_epoch(self, runner):
        # epoch-based schedules update the lr once per epoch
        if not self.by_epoch:
            return
        self.regular_lr = self.get_regular_lr(runner)
        self._set_lr(runner, self.regular_lr)

    def before_train_iter(self, runner):
        cur_iter = runner.iter
        if not self.by_epoch:
            # iter-based schedule: recompute the regular lr every iteration
            self.regular_lr = self.get_regular_lr(runner)
            if self.warmup is None or cur_iter >= self.warmup_iters:
                self._set_lr(runner, self.regular_lr)
            else:
                warmup_lr = self.get_warmup_lr(cur_iter)
                self._set_lr(runner, warmup_lr)
        elif self.by_epoch:
            # epoch-based schedule: only the warmup phase acts per iteration
            if self.warmup is None or cur_iter > self.warmup_iters:
                return
            elif cur_iter == self.warmup_iters:
                # warmup just finished: restore the regular lr
                self._set_lr(runner, self.regular_lr)
            else:
                warmup_lr = self.get_warmup_lr(cur_iter)
                self._set_lr(runner, warmup_lr)
|
||||
|
||||
|
||||
class FixedLrUpdaterHook(LrUpdaterHook):
    """Schedule that keeps the learning rate fixed at its base value."""

    def __init__(self, **kwargs):
        super(FixedLrUpdaterHook, self).__init__(**kwargs)

    def get_lr(self, runner, base_lr):
        # constant schedule: always return the base lr unchanged
        return base_lr
|
||||
|
||||
|
||||
class StepLrUpdaterHook(LrUpdaterHook):
    """Decays the lr by ``gamma`` at a fixed period or at given milestones.

    ``step`` may be a single positive int (decay every ``step`` units of
    progress) or a list of positive ints (decay once each milestone is
    passed). Progress is epochs or iterations depending on ``by_epoch``.
    """

    def __init__(self, step, gamma=0.1, **kwargs):
        assert isinstance(step, (list, int))
        if isinstance(step, list):
            # every milestone must be a positive integer
            assert all(isinstance(s, int) and s > 0 for s in step)
        elif isinstance(step, int):
            assert step > 0
        else:
            raise TypeError('"step" must be a list or integer')
        self.step = step
        self.gamma = gamma
        super(StepLrUpdaterHook, self).__init__(**kwargs)

    def get_lr(self, runner, base_lr):
        progress = runner.epoch if self.by_epoch else runner.iter

        if isinstance(self.step, int):
            # periodic decay: one gamma factor per completed period
            return base_lr * (self.gamma**(progress // self.step))

        # milestone decay: exponent = index of the first milestone not yet
        # reached (all milestones passed -> len(self.step))
        exp = next((i for i, s in enumerate(self.step) if progress < s),
                   len(self.step))
        return base_lr * self.gamma**exp
|
||||
|
||||
|
||||
class ExpLrUpdaterHook(LrUpdaterHook):
    """Exponential decay: lr = base_lr * gamma ** progress."""

    def __init__(self, gamma, **kwargs):
        self.gamma = gamma
        super(ExpLrUpdaterHook, self).__init__(**kwargs)

    def get_lr(self, runner, base_lr):
        # progress is epochs or iterations depending on `by_epoch`
        progress = runner.epoch if self.by_epoch else runner.iter
        return base_lr * self.gamma**progress
|
||||
|
||||
|
||||
class PolyLrUpdaterHook(LrUpdaterHook):
    """Polynomial decay: lr = base_lr * (1 - progress / max) ** power."""

    def __init__(self, power=1., **kwargs):
        self.power = power
        super(PolyLrUpdaterHook, self).__init__(**kwargs)

    def get_lr(self, runner, base_lr):
        # decay toward zero over the full run length
        if self.by_epoch:
            progress = runner.epoch
            max_progress = runner.max_epochs
        else:
            progress = runner.iter
            max_progress = runner.max_iters
        return base_lr * (1 - progress / max_progress)**self.power
|
||||
|
||||
|
||||
class InvLrUpdaterHook(LrUpdaterHook):
    """Inverse decay: lr = base_lr * (1 + gamma * progress) ** -power."""

    def __init__(self, gamma, power=1., **kwargs):
        self.gamma = gamma
        self.power = power
        super(InvLrUpdaterHook, self).__init__(**kwargs)

    def get_lr(self, runner, base_lr):
        # progress is epochs or iterations depending on `by_epoch`
        progress = runner.epoch if self.by_epoch else runner.iter
        return base_lr * (1 + self.gamma * progress)**(-self.power)
|
|
@ -0,0 +1,21 @@
|
|||
from torch.nn.utils import clip_grad
|
||||
|
||||
from .hook import Hook
|
||||
|
||||
|
||||
class OptimizerStepperHook(Hook):
    """Runs the backward/step cycle after each training iteration.

    Optionally clips gradient norms before stepping when ``grad_clip`` is
    enabled.
    """

    def __init__(self, grad_clip=False, max_norm=35, norm_type=2):
        self.grad_clip = grad_clip
        self.max_norm = max_norm
        self.norm_type = norm_type

    def after_train_iter(self, runner):
        # standard cycle: zero grads -> backward -> (clip) -> step
        runner.optimizer.zero_grad()
        runner.outputs['loss'].backward()
        if self.grad_clip:
            trainable = filter(lambda p: p.requires_grad,
                               runner.model.parameters())
            clip_grad.clip_grad_norm_(
                trainable, max_norm=self.max_norm, norm_type=self.norm_type)
        runner.optimizer.step()
|
|
@ -0,0 +1,157 @@
|
|||
import os.path as osp
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
||||
from torch.utils import model_zoo
|
||||
|
||||
|
||||
def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.
    """
    unexpected_keys = []
    own_state = module.state_dict()
    for name, param in state_dict.items():
        if name not in own_state:
            unexpected_keys.append(name)
            continue
        if isinstance(param, torch.nn.Parameter):
            # backwards compatibility for serialized parameters
            param = param.data

        try:
            own_state[name].copy_(param)
        except Exception:
            raise RuntimeError('While copying the parameter named {}, '
                               'whose dimensions in the model are {} and '
                               'whose dimensions in the checkpoint are {}.'
                               .format(name, own_state[name].size(),
                                       param.size()))
    missing_keys = set(own_state.keys()) - set(state_dict.keys())

    # collect both kinds of mismatch into one message
    err_msg = []
    if unexpected_keys:
        err_msg.append('unexpected key in source state_dict: {}\n'.format(
            ', '.join(unexpected_keys)))
    if missing_keys:
        err_msg.append('missing keys in source state_dict: {}\n'.format(
            ', '.join(missing_keys)))
    err_msg = '\n'.join(err_msg)
    if err_msg:
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            # FIX: Logger.warn is a deprecated alias; use warning()
            logger.warning(err_msg)
        else:
            print(err_msg)
|
||||
|
||||
|
||||
def load_checkpoint(model,
                    filename,
                    map_location=None,
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Either a filepath or URL or modelzoo://xxxxxxx.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    # load checkpoint from modelzoo or file or url
    if filename.startswith('modelzoo://'):
        from torchvision.models.resnet import model_urls
        model_name = filename[11:]
        checkpoint = model_zoo.load_url(model_urls[model_name])
    elif filename.startswith(('http://', 'https://')):
        checkpoint = model_zoo.load_url(filename)
    else:
        if not osp.isfile(filename):
            raise IOError('{} is not a checkpoint file'.format(filename))
        checkpoint = torch.load(filename, map_location=map_location)
    # get state_dict from checkpoint
    if isinstance(checkpoint, OrderedDict):
        state_dict = checkpoint
    elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        raise RuntimeError(
            'No state_dict found in checkpoint file {}'.format(filename))
    # strip the 'module.' prefix left by (Distributed)DataParallel
    if list(state_dict.keys())[0].startswith('module.'):
        # FIX: iterate `state_dict`, not `checkpoint['state_dict']` — the
        # latter raises KeyError when the checkpoint is a bare OrderedDict.
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    # load state_dict
    if isinstance(model, (DataParallel, DistributedDataParallel)):
        load_state_dict(model.module, state_dict, strict, logger)
    else:
        load_state_dict(model, state_dict, strict, logger)
    return checkpoint
|
||||
|
||||
|
||||
def weights_to_cpu(state_dict):
    """Copy a model state_dict to cpu.

    Args:
        state_dict (OrderedDict): Model weights on GPU.

    Returns:
        OrderedDict: Model weights on CPU.
    """
    # FIX (doc): the Returns section previously said "on GPU"; the function
    # moves every tensor to CPU.
    state_dict_cpu = OrderedDict()
    for key, val in state_dict.items():
        state_dict_cpu[key] = val.cpu()
    return state_dict_cpu
|
||||
|
||||
|
||||
def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError('meta must be a dict or None, but got {}'.format(
            type(meta)))
    # always record library version and save time in the metadata
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    mmcv.mkdir_or_exist(osp.dirname(filename))
    # unwrap (Distributed)DataParallel so keys carry no 'module.' prefix
    if isinstance(model, (DataParallel, DistributedDataParallel)):
        model = model.module

    # weights are moved to CPU so the checkpoint loads on any device
    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(model.state_dict())
    }
    if optimizer is not None:
        checkpoint['optimizer'] = optimizer.state_dict()

    torch.save(checkpoint, filename)
|
|
@ -0,0 +1,74 @@
|
|||
import multiprocessing
|
||||
|
||||
import torch
|
||||
|
||||
from .io import load_checkpoint
|
||||
|
||||
|
||||
def worker_func(model_cls, model_kwargs, checkpoint, dataset, data_func,
                gpu_id, idx_queue, result_queue):
    """Inference worker loop used by ``parallel_test``.

    Builds a model, loads the checkpoint, moves the model to ``gpu_id``
    and then serves forever: pull a dataset index from ``idx_queue``, run
    the model on that sample, and put ``(idx, result)`` on
    ``result_queue``.  The loop never exits on its own; the parent
    process terminates the worker.
    """
    model = model_cls(**model_kwargs)
    # load weights on CPU first; the model is moved to the GPU below
    load_checkpoint(model, checkpoint, map_location='cpu')
    torch.cuda.set_device(gpu_id)
    model.cuda()
    model.eval()
    with torch.no_grad():
        while True:
            idx = idx_queue.get()
            data = dataset[idx]
            # data_func presumably maps a raw sample to model kwargs on
            # the given GPU -- TODO confirm against callers
            result = model(**data_func(data, gpu_id))
            result_queue.put((idx, result))
|
||||
|
||||
|
||||
def parallel_test(model_cls,
                  model_kwargs,
                  checkpoint,
                  dataset,
                  data_func,
                  gpus,
                  workers_per_gpu=1):
    """Parallel testing on multiple GPUs.

    Args:
        model_cls (type): Model class type.
        model_kwargs (dict): Arguments to init the model.
        checkpoint (str): Checkpoint filepath.
        dataset (:obj:`Dataset`): The dataset to be tested.
        data_func (callable): The function that generates model inputs.
        gpus (list[int]): GPU ids to be used.
        workers_per_gpu (int): Number of processes on each GPU. It is possible
            to run multiple workers on each GPU.

    Returns:
        list: Test results, ordered like ``dataset``.
    """
    # local import (progress bar only); hoisted from the middle of the
    # function body where it was previously buried
    import cvbase as cvb

    # 'spawn' is required so each worker gets a clean CUDA context
    ctx = multiprocessing.get_context('spawn')
    idx_queue = ctx.Queue()
    result_queue = ctx.Queue()
    num_workers = len(gpus) * workers_per_gpu
    workers = [
        ctx.Process(
            target=worker_func,
            args=(model_cls, model_kwargs, checkpoint, dataset, data_func,
                  gpus[i % len(gpus)], idx_queue, result_queue))
        for i in range(num_workers)
    ]
    for w in workers:
        w.daemon = True
        w.start()

    # enqueue every sample index; workers pull them on demand
    for i in range(len(dataset)):
        idx_queue.put(i)

    results = [None] * len(dataset)
    prog_bar = cvb.ProgressBar(task_num=len(dataset))
    # collect exactly one result per sample; order is restored via idx
    for _ in range(len(dataset)):
        idx, res = result_queue.get()
        results[idx] = res
        prog_bar.update()
    print('\n')
    # workers loop forever, so terminate them explicitly
    for worker in workers:
        worker.terminate()

    return results
|
|
@ -0,0 +1,2 @@
|
|||
from .log_buffer import LogBuffer
|
||||
from .runner import Runner
|
|
@ -0,0 +1,39 @@
|
|||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
|
||||
|
||||
class LogBuffer(object):
    """Accumulates per-key scalar logs and averages them on demand.

    Values are stored together with their sample counts so that
    :meth:`average` can produce count-weighted means.
    """

    def __init__(self):
        # parallel per-key histories of values and their sample counts
        self.val_history = OrderedDict()
        self.n_history = OrderedDict()
        # count-weighted averages computed by average()
        self.output = OrderedDict()
        self.ready = False

    def clear(self):
        """Drop all recorded history and any computed output."""
        self.val_history.clear()
        self.n_history.clear()
        self.clear_output()

    def clear_output(self):
        """Discard the averaged output and mark the buffer as not ready."""
        self.output.clear()
        self.ready = False

    def update(self, vars, count=1):
        """Record one value per key, each weighted by ``count`` samples."""
        assert isinstance(vars, dict)
        for key, var in vars.items():
            self.val_history.setdefault(key, []).append(var)
            self.n_history.setdefault(key, []).append(count)

    def average(self, n=0):
        """Average latest n values or all values"""
        assert n >= 0
        for key, vals in self.val_history.items():
            # n == 0 takes the full history, since vals[-0:] == vals[:]
            counts = np.array(self.n_history[key][-n:])
            weighted = np.array(vals[-n:]) * counts
            self.output[key] = np.sum(weighted) / np.sum(counts)
        self.ready = True
|
|
@ -0,0 +1,344 @@
|
|||
import logging
|
||||
import os.path as osp
|
||||
import time
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
||||
|
||||
from .log_buffer import LogBuffer
|
||||
from .. import hooks
|
||||
from ..hooks import (Hook, LrUpdaterHook, CheckpointSaverHook, IterTimerHook,
|
||||
OptimizerStepperHook)
|
||||
from ..io import load_checkpoint, save_checkpoint
|
||||
from ..utils import (get_dist_info, get_host_info, get_time_str,
|
||||
add_file_handler, obj_from_dict)
|
||||
|
||||
|
||||
class Runner(object):
    """A training helper for PyTorch.

    Drives epoch-based training/validation of a model via a user-supplied
    ``batch_processor`` callable, with lifecycle extension points
    (before/after run, epoch and iter) implemented as :class:`Hook`
    objects registered on the runner.
    """

    def __init__(self,
                 model,
                 optimizer,
                 batch_processor,
                 work_dir=None,
                 log_level=logging.INFO):
        """
        Args:
            model (torch.nn.Module): Model to run, possibly wrapped in
                (Distributed)DataParallel.
            optimizer (dict or :obj:`~torch.optim.Optimizer`): Optimizer
                object or a config dict used to construct one
                (see :meth:`init_optimizer`).
            batch_processor (callable): Called as ``batch_processor(model,
                data_batch, train_mode, **kwargs)``; must return a dict
                (see :meth:`train` / :meth:`val`).
            work_dir (str, optional): Directory for logs and checkpoints.
            log_level (int): Level for the runner's logger.
        """
        assert callable(batch_processor)
        self.model = model
        self.optimizer = self.init_optimizer(optimizer)
        self.batch_processor = batch_processor

        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class
        if isinstance(self.model, (DataParallel, DistributedDataParallel)):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.logger = self.init_logger(work_dir, log_level)
        self.log_buffer = LogBuffer()

        # runtime state; the underscored counters are updated by
        # train()/val()/resume() and exposed via read-only properties
        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0
        self._max_epochs = 0
        self._max_iters = 0

    @property
    def model_name(self):
        """str: Name of the model, usually the module class name."""
        return self._model_name

    @property
    def rank(self):
        """int: Rank of current process. (distributed training)"""
        return self._rank

    @property
    def world_size(self):
        """int: Number of processes participating in the job.
        (distributed training)"""
        return self._world_size

    @property
    def hooks(self):
        """list[:obj:`Hook`]: A list of registered hooks."""
        return self._hooks

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    @property
    def inner_iter(self):
        """int: Iteration in an epoch."""
        return self._inner_iter

    @property
    def max_epochs(self):
        """int: Maximum training epochs."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Maximum training iterations."""
        return self._max_iters

    def init_optimizer(self, optimizer):
        """Init the optimizer.

        Args:
            optimizer (dict or :obj:`~torch.optim.Optimizer`): Either an
                optimizer object or a dict used for constructing the optimizer.
                An example of the dict: ``{'algorithm': 'SGD', 'lr': 0.02,
                'momentum': 0.9, 'weight_decay': 0.0001}``.

        Returns:
            :obj:`~torch.optim.Optimizer`: An optimizer object.

        Raises:
            TypeError: If ``optimizer`` is neither a dict nor an Optimizer.
        """
        if isinstance(optimizer, dict):
            # 'type' key in the dict selects the class from torch.optim;
            # model parameters are injected as a default argument
            optimizer = obj_from_dict(
                optimizer, torch.optim, dict(params=self.model.parameters()))
        elif not isinstance(optimizer, torch.optim.Optimizer):
            raise TypeError(
                'optimizer must be either an Optimizer object or a dict, '
                'but got {}'.format(type(optimizer)))
        return optimizer

    def init_logger(self, log_dir=None, level=logging.INFO):
        """Init the logger.

        Args:
            log_dir(str, optional): Log file directory. If not specified, no
                log file will be used.
            level (int or str): See the built-in python logging module.

        Returns:
            :obj:`~logging.Logger`: Python logger.
        """
        logging.basicConfig(
            format='%(asctime)s - %(levelname)s - %(message)s', level=level)
        logger = logging.getLogger(__name__)
        if log_dir:
            # one log file per process, disambiguated by rank
            filename = '{}_{}.log'.format(get_time_str(), self.rank)
            log_file = osp.join(log_dir, filename)
            add_file_handler(logger, log_file, level=level)
        return logger

    def current_lr(self):
        """Get current learning rates.

        Returns:
            list: Current learning rate of all param groups.
        """
        return [group['lr'] for group in self.optimizer.param_groups]

    def register_hook(self, hook, priority=50):
        """Register a hook into the hook list.

        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int): Hook priority. Lower value means higher priority.
        """
        assert isinstance(hook, Hook)
        assert isinstance(priority, int) and priority >= 0 and priority <= 100
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        hook.priority = priority
        # insert the hook to a sorted list
        inserted = False
        for i in range(len(self._hooks) - 1, -1, -1):
            # scan from the end so equal-priority hooks keep insertion order
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            self._hooks.insert(0, hook)

    def call_hook(self, fn_name):
        """Invoke ``fn_name`` (e.g. 'before_train_epoch') on every hook."""
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def load_checkpoint(self, filename, map_location='cpu', strict=False):
        """Load a checkpoint file into ``self.model``; returns the
        checkpoint dict."""
        self.logger.info('load checkpoint from %s', filename)
        return load_checkpoint(self.model, filename, map_location, strict,
                               self.logger)

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None):
        """Save a checkpoint to ``out_dir`` and update the 'latest.pth'
        symlink to point at it."""
        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        else:
            meta.update(epoch=self.epoch + 1, iter=self.iter)

        # NOTE(review): meta records epoch + 1 while the filename uses
        # self.epoch -- the two look inconsistent when called from an
        # after-epoch hook (before _epoch is incremented); confirm intent
        filename = osp.join(out_dir, filename_tmpl.format(self.epoch))
        linkname = osp.join(out_dir, 'latest.pth')
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filename, optimizer=optimizer, meta=meta)
        mmcv.symlink(filename, linkname)

    def train(self, data_loader, **kwargs):
        """Run one training epoch over ``data_loader``.

        Calls ``batch_processor(..., train_mode=True)`` per batch; any
        'log_vars' in its returned dict are fed into the log buffer.
        """
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(data_loader)
        self.call_hook('before_train_epoch')

        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=True, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('batch_processor() must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            # expose outputs so hooks (e.g. the optimizer stepper) can
            # read the loss after this iteration
            self.outputs = outputs
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    def val(self, data_loader, **kwargs):
        """Run one validation epoch over ``data_loader``.

        Same shape as :meth:`train` but with ``train_mode=False`` and no
        epoch/iter counter updates.
        """
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')

        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=False, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('batch_processor() must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            self.outputs = outputs
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def resume(self, checkpoint, resume_optimizer=True,
               map_location='default'):
        """Resume epoch/iter counters (and optionally optimizer state)
        from a checkpoint file."""
        if map_location == 'default':
            # map storages onto this process's current GPU
            device_id = torch.cuda.current_device()
            checkpoint = self.load_checkpoint(
                checkpoint,
                map_location=lambda storage, loc: storage.cuda(device_id))
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)

    def run(self, data_loaders, workflow, max_epochs, **kwargs):
        """Run the workflow until ``max_epochs`` training epochs complete.

        Args:
            data_loaders (list): One data loader per workflow entry.
            workflow (list[tuple]): ``(mode, epochs)`` pairs, e.g.
                ``[('train', 1), ('val', 1)]``; ``mode`` may also be a
                callable taking ``(data_loader, **kwargs)``.
            max_epochs (int): Total number of training epochs to run.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        self._max_epochs = max_epochs
        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
        self.call_hook('before_run')

        while self.epoch < max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            'runner has no method named "{}" to run an epoch'.
                            format(mode))
                    epoch_runner = getattr(self, mode)
                elif callable(mode):  # custom train()
                    epoch_runner = mode
                else:
                    raise TypeError('mode in workflow must be a str or '
                                    'callable function, not {}'.format(
                                        type(mode)))
                for _ in range(epochs):
                    # stop as soon as the training-epoch budget is used up
                    if mode == 'train' and self.epoch >= max_epochs:
                        return
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

    def register_lr_hooks(self, lr_config):
        """Register a learning-rate updater hook from a hook object or a
        config dict whose 'policy' selects the LrUpdaterHook subclass."""
        if isinstance(lr_config, LrUpdaterHook):
            self.register_hook(lr_config)
        elif isinstance(lr_config, dict):
            assert 'policy' in lr_config
            from ..hooks import lr_updater
            # e.g. policy='step' -> 'StepLrUpdaterHook'
            hook_name = lr_config['policy'].title() + 'LrUpdaterHook'
            if not hasattr(lr_updater, hook_name):
                raise ValueError('"{}" does not exist'.format(hook_name))
            hook_cls = getattr(lr_updater, hook_name)
            self.register_hook(hook_cls(**lr_config))
        else:
            raise TypeError('"lr_config" must be either a LrUpdaterHook object'
                            ' or dict, not {}'.format(type(lr_config)))

    def register_logger_hooks(self, log_config):
        """Register logger hooks described by ``log_config['hooks']``,
        all sharing ``log_config['interval']`` as a default interval."""
        log_interval = log_config['interval']
        for info in log_config['hooks']:
            logger_hook = obj_from_dict(
                info, hooks, default_args=dict(interval=log_interval))
            # low priority (60) so loggers run after other hooks
            self.register_hook(logger_hook, priority=60)

    def register_default_hooks(self,
                               lr_config,
                               grad_clip_config=None,
                               checkpoint_config=None,
                               log_config=None):
        """Register several default hooks.

        Default hooks include:
        - LrUpdaterHook
        - OptimizerStepperHook
        - CheckpointSaverHook
        - IterTimerHook
        - LoggerHook
        """
        if grad_clip_config is None:
            grad_clip_config = {}
        if checkpoint_config is None:
            checkpoint_config = {}
        self.register_lr_hooks(lr_config)
        self.register_hook(OptimizerStepperHook(**grad_clip_config))
        self.register_hook(CheckpointSaverHook(**checkpoint_config))
        self.register_hook(IterTimerHook())
        if log_config is not None:
            self.register_logger_hooks(log_config)
|
|
@ -0,0 +1,77 @@
|
|||
import functools
|
||||
import logging
|
||||
import time
|
||||
from getpass import getuser
|
||||
from socket import gethostname
|
||||
|
||||
import mmcv
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def get_host_info():
    """Return a ``user@hostname`` string identifying this process."""
    return '%s@%s' % (getuser(), gethostname())
|
||||
|
||||
|
||||
def get_dist_info():
    """Return ``(rank, world_size)`` of the current process.

    Falls back to ``(0, 1)`` when distributed training is unavailable or
    the process group has not been initialized.

    Returns:
        tuple[int, int]: (rank, world_size)
    """
    # bug fix: ``dist._initialized`` is a private attribute that was
    # removed from torch.distributed; use the public query API instead
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size
|
||||
|
||||
|
||||
def master_only(func):
    """Decorator that executes ``func`` only on the master (rank 0)
    process; on other ranks the call is a no-op returning None."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank != 0:
            return None
        return func(*args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
def get_time_str():
    """Current local time formatted as ``YYYYmmdd_HHMMSS``."""
    # strftime uses the current local time when no struct_time is given
    return time.strftime('%Y%m%d_%H%M%S')
|
||||
|
||||
|
||||
def add_file_handler(logger, filename=None, mode='w', level=logging.INFO):
    """Attach a :class:`logging.FileHandler` to ``logger``.

    Args:
        logger (logging.Logger): Logger to extend.
        filename (str): Path of the log file.
        mode (str): File open mode, defaults to 'w'.
        level (int or str): Logging level for the new handler.

    Returns:
        logging.Logger: The same logger, for chaining.
    """
    file_handler = logging.FileHandler(filename, mode)
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    # bug fix: ``level`` was accepted but never applied to the handler
    file_handler.setLevel(level)
    logger.addHandler(file_handler)
    return logger
|
||||
|
||||
|
||||
def obj_from_dict(info, module, default_args=None):
    """Initialize an object from dict.

    The dict must contain the key "type", which indicates the object type, it
    can be either a string or type, such as "list" or ``list``. Remaining
    fields are treated as the arguments for constructing the object.

    Args:
        info (dict): Object types and arguments.
        module (:class:`module`): Module which may containing expected object
            classes.
        default_args (dict, optional): Default arguments for initializing the
            object.

    Returns:
        any type: Object constructed with the given type and arguments.
    """
    assert isinstance(info, dict) and 'type' in info
    assert isinstance(default_args, dict) or default_args is None
    # work on a copy so the caller's dict is not mutated by pop()
    args = info.copy()
    obj_type = args.pop('type')
    if mmcv.is_str(obj_type):
        obj_type = getattr(module, obj_type)
    elif not isinstance(obj_type, type):
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            # defaults never override explicitly provided arguments
            args.setdefault(name, value)
    return obj_type(**args)
|
Loading…
Reference in New Issue