Add IterBasedRunner (#314)

* feat: add IterBasedRunner

* fix: unittest

* feat: more unittest

* fix: expose dataloader len

* minor updates of BaseRunner

* refactor: remove CosineRestartLrUpdaterHook

* style: add docstring

* refactor: update IterTextLoggerHook: fstring and exp_name

* fix: epoch_runner unittest

* refactor: remove IterBasedTextLogger

* fix: old IterTextLoggerHook issue

* refactor: remove __len__ of IterLoader

* feat: add IterBasedRunner to init

* feat: add __len__ to IterLoader

* fix some docstrings

* refactor: use is_parallel_module

* fix: import issue

* fix: runner unittest missing logger

* fix checkpoints

* feat: add by_epoch default value to IterBasedRunner registering logger_hook

* refactor: remove setting by_epoch in log_config

* minor refactoring

* docs: add docstring

* fix: remove unused doc

* update the log info for saving checkpoints

Co-authored-by: Kai Chen <chenkaidev@gmail.com>
Harry 2020-06-11 13:35:34 +08:00 committed by GitHub
parent 61f9e91c9f
commit 67a26da917
12 changed files with 595 additions and 74 deletions

View File

@@ -1,10 +1,84 @@
# Copyright (c) Open-MMLab. All rights reserved.
from torch.nn.parallel import DistributedDataParallel
import torch
from torch.nn.parallel.distributed import (DistributedDataParallel,
_find_tensors)
from .scatter_gather import scatter_kwargs
class MMDistributedDataParallel(DistributedDataParallel):
"""The DDP module that supports DataContainer.
MMDDP has two main differences from PyTorch DDP:
- It supports a custom type :class:`DataContainer` which allows more
flexible control of input data.
- It implements two APIs: ``train_step()`` and ``val_step()``.
"""
def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def train_step(self, *inputs, **kwargs):
"""train_step() API for module wrapped by DistributedDataParallel.
This method is basically the same as
``DistributedDataParallel.forward()``, while replacing
``self.module.forward()`` with ``self.module.train_step()``.
It is compatible with PyTorch 1.1 - 1.5.
"""
if getattr(self, 'require_forward_param_sync', True):
self._sync_params()
if self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
output = self.module.train_step(*inputs[0], **kwargs[0])
else:
outputs = self.parallel_apply(
self._module_copies[:len(inputs)], inputs, kwargs)
output = self.gather(outputs, self.output_device)
else:
output = self.module.train_step(*inputs, **kwargs)
if torch.is_grad_enabled() and getattr(
self, 'require_backward_grad_sync', True):
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
if torch.__version__ > '1.2':
self.require_forward_param_sync = False
return output
def val_step(self, *inputs, **kwargs):
"""val_step() API for module wrapped by DistributedDataParallel.
This method is basically the same as
``DistributedDataParallel.forward()``, while replacing
``self.module.forward()`` with ``self.module.val_step()``.
It is compatible with PyTorch 1.1 - 1.5.
"""
if getattr(self, 'require_forward_param_sync', True):
self._sync_params()
if self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
output = self.module.val_step(*inputs[0], **kwargs[0])
else:
outputs = self.parallel_apply(
self._module_copies[:len(inputs)], inputs, kwargs)
output = self.gather(outputs, self.output_device)
else:
output = self.module.val_step(*inputs, **kwargs)
if torch.is_grad_enabled() and getattr(
self, 'require_backward_grad_sync', True):
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
if torch.__version__ > '1.2':
self.require_forward_param_sync = False
return output
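A minimal usage sketch of the two new entry points (the ToyModel, the data dict and the single-GPU setup below are assumptions, not part of this commit; a default process group must already be initialized):

import torch
import torch.nn as nn
from mmcv.parallel import MMDistributedDataParallel

class ToyModel(nn.Module):
    # a toy module following the train_step()/val_step() convention used by mmcv runners
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)

    def forward(self, x):
        return self.fc(x)

    def train_step(self, data, optimizer=None):
        loss = self.forward(data['x']).mean()
        return dict(loss=loss, log_vars=dict(loss=loss.item()),
                    num_samples=data['x'].size(0))

    def val_step(self, data, optimizer=None):
        return self.train_step(data, optimizer)

# assumes torch.distributed.init_process_group(...) has been called beforehand
model = MMDistributedDataParallel(
    ToyModel().cuda(), device_ids=[torch.cuda.current_device()])
# the wrapper scatters the inputs and dispatches to ToyModel.train_step()
outputs = model.train_step(dict(x=torch.rand(4, 2).cuda()))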

View File

@@ -52,3 +52,15 @@ class MMDistributedDataParallel(nn.Module):
inputs, kwargs = self.scatter(inputs, kwargs,
[torch.cuda.current_device()])
return self.module(*inputs[0], **kwargs[0])
def train_step(self, *inputs, **kwargs):
inputs, kwargs = self.scatter(inputs, kwargs,
[torch.cuda.current_device()])
output = self.module.train_step(*inputs[0], **kwargs[0])
return output
def val_step(self, *inputs, **kwargs):
inputs, kwargs = self.scatter(inputs, kwargs,
[torch.cuda.current_device()])
output = self.module.val_step(*inputs[0], **kwargs[0])
return output

View File

@@ -8,6 +8,7 @@ from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistSamplerSeedHook,
Hook, IterTimerHook, LoggerHook, LrUpdaterHook,
MlflowLoggerHook, OptimizerHook, PaviLoggerHook,
TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook)
from .iter_based_runner import IterBasedRunner, IterLoader
from .log_buffer import LogBuffer
from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
DefaultOptimizerConstructor, build_optimizer,
@@ -16,14 +17,15 @@ from .priority import Priority, get_priority
from .utils import get_host_info, get_time_str, obj_from_dict
__all__ = [
'BaseRunner', 'Runner', 'EpochBasedRunner', 'LogBuffer', 'HOOKS', 'Hook',
'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook',
'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', 'PaviLoggerHook',
'TextLoggerHook', 'TensorboardLoggerHook', 'WandbLoggerHook',
'MlflowLoggerHook', '_load_checkpoint', 'load_state_dict',
'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority',
'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict',
'init_dist', 'get_dist_info', 'master_only', 'OPTIMIZER_BUILDERS',
'OPTIMIZERS', 'DefaultOptimizerConstructor', 'build_optimizer',
'build_optimizer_constructor'
'BaseRunner', 'Runner', 'EpochBasedRunner', 'IterBasedRunner', 'LogBuffer',
'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook',
'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
'WandbLoggerHook', 'MlflowLoggerHook', '_load_checkpoint',
'load_state_dict', 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint',
'Priority', 'get_priority', 'get_host_info', 'get_time_str',
'obj_from_dict', 'init_dist', 'get_dist_info', 'master_only',
'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
'build_optimizer', 'build_optimizer_constructor', 'IterLoader',
'IterBasedRunner'
]

View File

@@ -5,6 +5,7 @@ import warnings
from abc import ABCMeta, abstractmethod
import torch
from torch.optim import Optimizer
import mmcv
from ..parallel import is_parallel_module
@@ -31,12 +32,14 @@ class BaseRunner(metaclass=ABCMeta):
batch_processor (callable): A callable method that process a data
batch. The interface of this method should be
`batch_processor(model, data, train_mode) -> dict`
optimizer (dict or :obj:`torch.optim.Optimizer`): If it is a dict,
runner will construct an optimizer according to it.
optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
optimizer (in most cases) or a dict of optimizers (in models that
require more than one optimizer, e.g., GANs).
work_dir (str, optional): The working directory to save checkpoints
and logs. Defaults to None.
logger (:obj:`logging.Logger`): Logger used during training.
Defaults to None.
Defaults to None. (The default value is just for backward
compatibility)
meta (dict | None): A dict that records some important information, such as
environment info and seed, which will be logged in the logger hook.
Defaults to None.
@@ -67,9 +70,34 @@
'cannot be both available.')
else:
assert hasattr(model, 'train_step')
# check the type of `optimizer`
if isinstance(optimizer, dict):
for name, optim in optimizer.items():
if not isinstance(optim, Optimizer):
raise TypeError(
f'optimizer must be a dict of torch.optim.Optimizers, '
f'but optimizer["{name}"] is a {type(optim)}')
elif not isinstance(optimizer, Optimizer) and optimizer is not None:
raise TypeError(
f'optimizer must be a torch.optim.Optimizer object '
f'or dict or None, but got {type(optimizer)}')
# check the type of `logger`
if not isinstance(logger, logging.Logger):
raise TypeError(f'logger must be a logging.Logger object, '
f'but got {type(logger)}')
# check the type of `meta`
if meta is not None and not isinstance(meta, dict):
raise TypeError(
f'meta must be a dict or None, but got {type(meta)}')
self.model = model
self.batch_processor = batch_processor
self.optimizer = optimizer
self.logger = logger
self.meta = meta
# create work_dir
if mmcv.is_str(work_dir):
@@ -86,13 +114,6 @@
else:
self._model_name = self.model.__class__.__name__
assert logging is not None
self.logger = logger
if meta is not None:
assert isinstance(meta, dict), '"meta" must be a dict or None'
self.meta = meta
self._rank, self._world_size = get_dist_info()
self.timestamp = get_time_str()
self.mode = None
@@ -176,30 +197,50 @@
"""Get current learning rates.
Returns:
list: Current learning rate of all param groups.
list[float] | dict[str, list[float]]: Current learning rates of all
param groups. If the runner has a dict of optimizers, this
method will return a dict.
"""
if self.optimizer is None:
if isinstance(self.optimizer, torch.optim.Optimizer):
lr = [group['lr'] for group in self.optimizer.param_groups]
elif isinstance(self.optimizer, dict):
lr = dict()
for name, optim in self.optimizer.items():
lr[name] = [group['lr'] for group in optim.param_groups]
else:
raise RuntimeError(
'lr is not applicable because optimizer does not exist.')
return [group['lr'] for group in self.optimizer.param_groups]
return lr
def current_momentum(self):
"""Get current momentums.
Returns:
list: Current momentum of all param groups.
list[float] | dict[str, list[float]]: Current momentums of all
param groups. If the runner has a dict of optimizers, this
method will return a dict.
"""
def _get_momentum(optimizer):
momentums = []
for group in optimizer.param_groups:
if 'momentum' in group.keys():
momentums.append(group['momentum'])
elif 'betas' in group.keys():
momentums.append(group['betas'][0])
else:
momentums.append(0)
return momentums
if self.optimizer is None:
raise RuntimeError(
'momentum is not applicable because optimizer does not exist.')
momentums = []
for group in self.optimizer.param_groups:
if 'momentum' in group.keys():
momentums.append(group['momentum'])
elif 'betas' in group.keys():
momentums.append(group['betas'][0])
else:
momentums.append(0)
elif isinstance(self.optimizer, torch.optim.Optimizer):
momentums = _get_momentum(self.optimizer)
elif isinstance(self.optimizer, dict):
momentums = dict()
for name, optim in self.optimizer.items():
momentums[name] = _get_momentum(optim)
return momentums
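A hedged sketch of the multi-optimizer support above; the GAN-style model and the learning rates are hypothetical, only the dict-of-optimizers interface and the return shape of current_lr()/current_momentum() come from this change:

import logging
import torch
import torch.nn as nn
from mmcv.runner import EpochBasedRunner

class ToyGAN(nn.Module):
    def __init__(self):
        super().__init__()
        self.generator = nn.Linear(2, 2)
        self.discriminator = nn.Linear(2, 1)

    def train_step(self, data_batch, optimizer, **kwargs):
        # `optimizer` arrives as the same dict that was passed to the runner
        return dict(loss=torch.tensor(0.), log_vars={}, num_samples=1)

model = ToyGAN()
optimizer = dict(
    generator=torch.optim.Adam(model.generator.parameters(), lr=1e-4),
    discriminator=torch.optim.Adam(model.discriminator.parameters(), lr=4e-4))
runner = EpochBasedRunner(
    model, optimizer=optimizer, work_dir='./work_dir',
    logger=logging.getLogger(__name__))
print(runner.current_lr())        # {'generator': [0.0001], 'discriminator': [0.0004]}
print(runner.current_momentum())  # {'generator': [0.9], 'discriminator': [0.9]} (Adam beta1)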
def register_hook(self, hook, priority='NORMAL'):

View File

@@ -9,10 +9,12 @@ from importlib import import_module
import torch
import torchvision
from torch.optim import Optimizer
from torch.utils import model_zoo
import mmcv
from ..fileio import load as load_file
from ..parallel import is_parallel_module
from ..utils import mkdir_or_exist
from .dist_utils import get_dist_info
@@ -59,6 +61,10 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
# use _load_from_state_dict to enable checkpoint version control
def load(module, prefix=''):
# recursively check parallel module in case that the model has a
# complicated structure, e.g., nn.Module(nn.Module(DDP))
if is_parallel_module(module):
module = module.module
local_metadata = {} if metadata is None else metadata.get(
prefix[:-1], {})
module._load_from_state_dict(state_dict, prefix, local_metadata, True,
@@ -228,10 +234,7 @@ def load_checkpoint(model,
if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}
# load state_dict
if hasattr(model, 'module'):
load_state_dict(model.module, state_dict, strict, logger)
else:
load_state_dict(model, state_dict, strict, logger)
load_state_dict(model, state_dict, strict, logger)
return checkpoint
@@ -269,15 +272,20 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
mmcv.mkdir_or_exist(osp.dirname(filename))
if hasattr(model, 'module'):
if is_parallel_module(model):
model = model.module
checkpoint = {
'meta': meta,
'state_dict': weights_to_cpu(model.state_dict())
}
if optimizer is not None:
# save optimizer state dict in the checkpoint
if isinstance(optimizer, Optimizer):
checkpoint['optimizer'] = optimizer.state_dict()
elif isinstance(optimizer, dict):
checkpoint['optimizer'] = {}
for name, optim in optimizer.items():
checkpoint['optimizer'][name] = optim.state_dict()
# immediately flush buffer
with open(filename, 'wb') as f:
torch.save(checkpoint, f)
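A short sketch of the unwrapping and multi-optimizer behaviour above (the toy module, optimizer names and file path are illustrative, not from this commit):

import torch
import torch.nn as nn
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint, save_checkpoint

model = MMDataParallel(nn.Linear(2, 2))
optims = dict(
    a=torch.optim.SGD(model.parameters(), lr=0.1),
    b=torch.optim.SGD(model.parameters(), lr=0.01))
save_checkpoint(model, '/tmp/demo.pth', optimizer=optims)  # weights stored without the 'module.' prefix
ckpt = torch.load('/tmp/demo.pth')
assert set(ckpt['optimizer']) == {'a', 'b'}                # one state_dict per named optimizer
load_checkpoint(model, '/tmp/demo.pth')                    # the wrapped model is unwrapped internally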

View File

@@ -7,14 +7,34 @@ from .hook import HOOKS, Hook
@HOOKS.register_module()
class CheckpointHook(Hook):
"""Save checkpoints periodically.
Args:
interval (int): The saving period. If ``by_epoch=True``, interval
indicates epochs, otherwise it indicates iterations.
Default: -1, which means "never".
by_epoch (bool): Saving checkpoints by epoch or by iteration.
Default: True.
save_optimizer (bool): Whether to save optimizer state_dict in the
checkpoint. It is usually used for resuming experiments.
Default: True.
out_dir (str, optional): The directory to save checkpoints. If not
specified, ``runner.work_dir`` will be used by default.
max_keep_ckpts (int, optional): The maximum checkpoints to keep.
In some cases we want only the latest few checkpoints and would
like to delete old ones to save the disk space.
Default: -1, which means unlimited.
"""
def __init__(self,
interval=-1,
by_epoch=True,
save_optimizer=True,
out_dir=None,
max_keep_ckpts=-1,
**kwargs):
self.interval = interval
self.by_epoch = by_epoch
self.save_optimizer = save_optimizer
self.out_dir = out_dir
self.max_keep_ckpts = max_keep_ckpts
@@ -22,9 +42,10 @@ class CheckpointHook(Hook):
@master_only
def after_train_epoch(self, runner):
if not self.every_n_epochs(runner, self.interval):
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
return
runner.logger.info(f'Saving checkpoint at {runner.epoch + 1} epochs')
if not self.out_dir:
self.out_dir = runner.work_dir
runner.save_checkpoint(
@@ -41,3 +62,29 @@
os.remove(ckpt_path)
else:
break
@master_only
def after_train_iter(self, runner):
if self.by_epoch or not self.every_n_iters(runner, self.interval):
return
runner.logger.info(
f'Saving checkpoint at {runner.iter + 1} iterations')
if not self.out_dir:
self.out_dir = runner.work_dir
runner.save_checkpoint(
self.out_dir, save_optimizer=self.save_optimizer, **self.args)
# remove other checkpoints
if self.max_keep_ckpts > 0:
filename_tmpl = self.args.get('filename_tmpl', 'iter_{}.pth')
current_iter = runner.iter + 1
for _iter in range(
current_iter - self.max_keep_ckpts * self.interval, 0,
-self.interval):
ckpt_path = os.path.join(self.out_dir,
filename_tmpl.format(_iter))
if os.path.exists(ckpt_path):
os.remove(ckpt_path)
else:
break
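An illustrative checkpoint_config for iteration-based training (the numbers are placeholders): save every 1000 iterations and keep only the three newest iter_*.pth files.

checkpoint_config = dict(by_epoch=False, interval=1000, max_keep_ckpts=3)
# IterBasedRunner.register_training_hooks() already defaults by_epoch to False,
# so the flag is written out here only for clarity.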

View File

@@ -71,8 +71,22 @@ class PaviLoggerHook(LoggerHook):
for tag, val in runner.log_buffer.output.items():
if tag not in ['time', 'data_time'] and is_scalar(val):
tags[tag] = val
tags['learning_rate'] = runner.current_lr()[0]
tags['momentum'] = runner.current_momentum()[0]
# add learning rate
lrs = runner.current_lr()
if isinstance(lrs, dict):
for name, value in lrs.items():
tags[f'learning_rate/{name}'] = value[0]
else:
tags['learning_rate'] = lrs[0]
# add momentum
momentums = runner.current_momentum()
if isinstance(momentums, dict):
for name, value in momentums.items():
tags[f'momentum/{name}'] = value[0]
else:
tags['momentum'] = momentums[0]
if tags:
self.writer.add_scalars(runner.mode, tags, runner.iter)

View File

@@ -53,10 +53,22 @@ class TensorboardLoggerHook(LoggerHook):
else:
self.writer.add_scalar(tag, runner.log_buffer.output[var],
runner.iter)
self.writer.add_scalar('learning_rate',
runner.current_lr()[0], runner.iter)
self.writer.add_scalar('momentum',
runner.current_momentum()[0], runner.iter)
# add learning rate
lrs = runner.current_lr()
if isinstance(lrs, dict):
for name, value in lrs.items():
self.writer.add_scalar(f'learning_rate/{name}', value[0],
runner.iter)
else:
self.writer.add_scalar('learning_rate', lrs[0], runner.iter)
# add momentum
momentums = runner.current_momentum()
if isinstance(momentums, dict):
for name, value in momentums.items():
self.writer.add_scalar(f'momentum/{name}', value[0],
runner.iter)
else:
self.writer.add_scalar('momentum', momentums[0], runner.iter)
@master_only
def after_run(self, runner):

View File

@@ -19,6 +19,7 @@ class TextLoggerHook(LoggerHook):
saved in json file.
Args:
by_epoch (bool): Whether EpochBasedRunner is used.
interval (int): Logging interval (every k iterations).
ignore_last (bool): Ignore the log of last iterations in each epoch
if less than `interval`.
@@ -29,11 +30,13 @@
"""
def __init__(self,
by_epoch=True,
interval=10,
ignore_last=True,
reset_flag=False,
interval_exp_name=1000):
super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag)
self.by_epoch = by_epoch
self.time_sec_tot = 0
self.interval_exp_name = interval_exp_name
@@ -55,17 +58,32 @@
return mem_mb.item()
def _log_info(self, log_dict, runner):
# print exp name for users to distinguish experiments
# at every ``interval_exp_name`` iterations and the end of each epoch
if runner.meta is not None and 'exp_name' in runner.meta:
if (self.every_n_inner_iters(
runner,
self.interval_exp_name)) or self.end_of_epoch(runner):
exp_info = f"Exp name: {runner.meta['exp_name']}"
if (self.every_n_inner_iters(runner, self.interval_exp_name)) or (
self.by_epoch and self.end_of_epoch(runner)):
exp_info = f'Exp name: {runner.meta["exp_name"]}'
runner.logger.info(exp_info)
if runner.mode == 'train':
log_str = f'Epoch [{log_dict["epoch"]}]' \
f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t' \
f'lr: {log_dict["lr"]:.5f}, '
if isinstance(log_dict['lr'], dict):
lr_str = []
for k, val in log_dict['lr'].items():
lr_str.append(f'lr_{k}: {val:.3e}')
lr_str = ' '.join(lr_str)
else:
lr_str = f'lr: {log_dict["lr"]:.3e}'
# by epoch: Epoch [4][100/1000]
# by iter: Iter [100/100000]
if self.by_epoch:
log_str = f'Epoch [{log_dict["epoch"]}]' \
f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t'
else:
log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t'
log_str += f'{lr_str}, '
if 'time' in log_dict.keys():
self.time_sec_tot += (log_dict['time'] * self.interval)
time_sec_avg = self.time_sec_tot / (
@@ -79,8 +97,12 @@
if torch.cuda.is_available():
log_str += f'memory: {log_dict["memory"]}, '
else:
log_str = f'Epoch({log_dict["mode"]}) ' \
f'[{log_dict["epoch"] - 1}][{log_dict["iter"]}]\t'
if self.by_epoch:
log_str = f'Epoch({log_dict["mode"]}) ' \
f'[{log_dict["epoch"] - 1}][{log_dict["iter"]}]\t'
else:
log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'
log_items = []
for name, val in log_dict.items():
# TODO: resolve this hack
@@ -94,6 +116,7 @@
val = f'{val:.4f}'
log_items.append(f'{name}: {val}')
log_str += ', '.join(log_items)
runner.logger.info(log_str)
def _dump_log(self, log_dict, runner):
@@ -123,7 +146,16 @@
log_dict['epoch'] = runner.epoch + 1
log_dict['iter'] = runner.inner_iter + 1
# only record lr of the first param group
log_dict['lr'] = runner.current_lr()[0]
cur_lr = runner.current_lr()
if isinstance(cur_lr, list):
log_dict['lr'] = cur_lr[0]
else:
assert isinstance(cur_lr, dict)
log_dict['lr'] = {}
for k, lr_ in cur_lr.items():
assert isinstance(lr_, list)
log_dict['lr'].update({k: lr_[0]})
if mode == 'train':
log_dict['time'] = runner.log_buffer.output['time']
log_dict['data_time'] = runner.log_buffer.output['data_time']
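An illustrative log_config exercising the iteration-style formatting above (the interval and hook list are arbitrary):

# With by_epoch=False the train lines read like "Iter [100/100000]  lr: 1.000e-02, ..."
# instead of "Epoch [4][100/1000]"; a dict of optimizers is rendered as
# "lr_generator: 1.000e-04 lr_discriminator: 4.000e-04".
log_config = dict(
    interval=50,
    hooks=[dict(type='TextLoggerHook', by_epoch=False)])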

View File

@@ -55,14 +55,31 @@
self.regular_lr = [] # expected lr if no warming up is performed
def _set_lr(self, runner, lr_groups):
for param_group, lr in zip(runner.optimizer.param_groups, lr_groups):
param_group['lr'] = lr
if isinstance(runner.optimizer, dict):
for k, optim in runner.optimizer.items():
for param_group, lr in zip(optim.param_groups, lr_groups[k]):
param_group['lr'] = lr
else:
for param_group, lr in zip(runner.optimizer.param_groups,
lr_groups):
param_group['lr'] = lr
def get_lr(self, runner, base_lr):
raise NotImplementedError
def get_regular_lr(self, runner):
return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]
if isinstance(runner.optimizer, dict):
lr_groups = {}
for k in runner.optimizer.keys():
_lr_group = [
self.get_lr(runner, _base_lr)
for _base_lr in self.base_lr[k]
]
lr_groups.update({k: _lr_group})
return lr_groups
else:
return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]
def get_warmup_lr(self, cur_iters):
if self.warmup == 'constant':
@@ -78,11 +95,21 @@
def before_run(self, runner):
# NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
# it will be set according to the optimizer params
for group in runner.optimizer.param_groups:
group.setdefault('initial_lr', group['lr'])
self.base_lr = [
group['initial_lr'] for group in runner.optimizer.param_groups
]
if isinstance(runner.optimizer, dict):
self.base_lr = {}
for k, optim in runner.optimizer.items():
for group in optim.param_groups:
group.setdefault('initial_lr', group['lr'])
_base_lr = [
group['initial_lr'] for group in optim.param_groups
]
self.base_lr.update({k: _base_lr})
else:
for group in runner.optimizer.param_groups:
group.setdefault('initial_lr', group['lr'])
self.base_lr = [
group['initial_lr'] for group in runner.optimizer.param_groups
]
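The hook now mirrors the optimizer layout: with a dict of optimizers, base_lr and the regular LR groups are keyed by optimizer name, so a single schedule drives every optimizer. An illustrative config (the policy and step values are assumptions, not from this commit):

# Passed through register_lr_hook()/register_training_hooks(); with
# runner.optimizer being a dict, self.base_lr ends up as
# {'generator': [...], 'discriminator': [...]} and each optimizer is stepped.
lr_config = dict(policy='step', by_epoch=False, step=[60000, 80000])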
def before_train_epoch(self, runner):
if not self.by_epoch:
@@ -213,6 +240,7 @@ class CosineAnealingLrUpdaterHook(LrUpdaterHook):
else:
progress = runner.iter
max_progress = runner.max_iters
if self.min_lr_ratio is not None:
target_lr = base_lr * self.min_lr_ratio
else:
@@ -224,7 +252,7 @@
class CyclicLrUpdaterHook(LrUpdaterHook):
"""Cyclic LR Scheduler
Implemet the cyclical learning rate policy (CLR) described in
Implement the cyclical learning rate policy (CLR) described in
https://arxiv.org/pdf/1506.01186.pdf
Different from the original paper, we use cosine anealing rather than

View File

@@ -0,0 +1,223 @@
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import time
import torch
from torch.optim import Optimizer
import mmcv
from .base_runner import BaseRunner
from .checkpoint import save_checkpoint
from .hooks import IterTimerHook
from .utils import get_host_info
class IterLoader:
def __init__(self, dataloader):
self._dataloader = dataloader
self.iter_loader = iter(self._dataloader)
self._epoch = 0
@property
def epoch(self):
return self._epoch
def __next__(self):
try:
data = next(self.iter_loader)
except StopIteration:
self._epoch += 1
if hasattr(self._dataloader.sampler, 'set_epoch'):
self._dataloader.sampler.set_epoch(self._epoch)
self.iter_loader = iter(self._dataloader)
data = next(self.iter_loader)
return data
def __len__(self):
return len(self._dataloader)
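A minimal sketch of the IterLoader behaviour (the TensorDataset is an assumption): it turns a finite DataLoader into an endless iterator that restarts, and bumps the sampler epoch, whenever the underlying loader is exhausted.

import torch
from torch.utils.data import DataLoader, TensorDataset
from mmcv.runner import IterLoader

loader = IterLoader(DataLoader(TensorDataset(torch.arange(6)), batch_size=2))
for _ in range(5):        # more steps than one pass over the 3 batches
    batch = next(loader)  # silently wraps around at the end of the dataset
print(loader.epoch)       # 1: the underlying loader has been restarted once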
class IterBasedRunner(BaseRunner):
"""Iteration-based Runner.
This runner trains models iteration by iteration.
"""
def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._epoch = data_loader.epoch
self.call_hook('before_train_iter')
data_batch = next(data_loader)
outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('model.train_step() must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_train_iter')
self._inner_iter += 1
self._iter += 1
def val(self, data_loader, **kwargs):
self.model.eval()
self.mode = 'val'
self._inner_iter = 0
self.data_loader = data_loader
self.call_hook('before_val_iter')
data_batch = next(data_loader)
outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('model.val_step() must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_val_iter')
self._inner_iter += 1
def run(self, data_loaders, workflow, max_iters, **kwargs):
"""Start running.
Args:
data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
and validation.
workflow (list[tuple]): A list of (phase, iters) to specify the
running order and iterations. E.g., [('train', 10000),
('val', 1000)] means running 10000 iterations for training and
1000 iterations for validation, iteratively.
max_iters (int): Total training iterations.
"""
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
self._max_iters = max_iters
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('workflow: %s, max: %d iters', workflow, max_iters)
self.call_hook('before_run')
iter_loaders = [IterLoader(x) for x in data_loaders]
self.call_hook('before_epoch')
while self.iter < max_iters:
for i, flow in enumerate(workflow):
mode, iters = flow
if not isinstance(mode, str) or not hasattr(self, mode):
raise ValueError(
'runner has no method named "{}" to run a workflow'.
format(mode))
iter_runner = getattr(self, mode)
for _ in range(iters):
if mode == 'train' and self.iter >= max_iters:
return
iter_runner(iter_loaders[i], **kwargs)
time.sleep(1) # wait for some hooks like loggers to finish
self.call_hook('after_epoch')
self.call_hook('after_run')
def resume(self,
checkpoint,
resume_optimizer=True,
map_location='default'):
"""Resume model from checkpoint.
Args:
checkpoint (str): Checkpoint to resume from.
resume_optimizer (bool, optional): Whether to resume the optimizer(s)
if the checkpoint file includes optimizer(s). Defaults to True.
map_location (str, optional): Same as :func:`torch.load`.
Defaults to 'default'.
"""
if map_location == 'default':
device_id = torch.cuda.current_device()
checkpoint = self.load_checkpoint(
checkpoint,
map_location=lambda storage, loc: storage.cuda(device_id))
else:
checkpoint = self.load_checkpoint(
checkpoint, map_location=map_location)
self._epoch = checkpoint['meta']['epoch']
self._iter = checkpoint['meta']['iter']
self._inner_iter = checkpoint['meta']['iter']
if 'optimizer' in checkpoint and resume_optimizer:
if isinstance(self.optimizer, Optimizer):
self.optimizer.load_state_dict(checkpoint['optimizer'])
elif isinstance(self.optimizer, dict):
for k in self.optimizer.keys():
self.optimizer[k].load_state_dict(
checkpoint['optimizer'][k])
self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')
def save_checkpoint(self,
out_dir,
filename_tmpl='iter_{}.pth',
meta=None,
save_optimizer=True,
create_symlink=True):
"""Save checkpoint to file.
Args:
out_dir (str): Directory to save checkpoint files.
filename_tmpl (str, optional): Checkpoint file template.
Defaults to 'iter_{}.pth'.
meta (dict, optional): Metadata to be saved in checkpoint.
Defaults to None.
save_optimizer (bool, optional): Whether to save the optimizer.
Defaults to True.
create_symlink (bool, optional): Whether to create a symlink to the
latest checkpoint file. Defaults to True.
"""
if meta is None:
meta = dict(iter=self.iter + 1, epoch=self.epoch + 1)
elif isinstance(meta, dict):
meta.update(iter=self.iter + 1, epoch=self.epoch + 1)
else:
raise TypeError(
f'meta should be a dict or None, but got {type(meta)}')
meta.update(self.meta)
filename = filename_tmpl.format(self.iter + 1)
filepath = osp.join(out_dir, filename)
optimizer = self.optimizer if save_optimizer else None
save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
# in some environments, `os.symlink` is not supported, you may need to
# set `create_symlink` to False
if create_symlink:
mmcv.symlink(filename, osp.join(out_dir, 'latest.pth'))
def register_training_hooks(self,
lr_config,
optimizer_config=None,
checkpoint_config=None,
log_config=None,
momentum_config=None):
"""Register default hooks for iter-based training.
Default hooks include:
- LrUpdaterHook
- MomentumUpdaterHook
- OptimizerStepperHook
- CheckpointSaverHook
- IterTimerHook
- LoggerHook(s)
"""
if checkpoint_config is not None:
checkpoint_config.setdefault('by_epoch', False)
if lr_config is not None:
lr_config.setdefault('by_epoch', False)
self.register_lr_hook(lr_config)
self.register_momentum_hook(momentum_config)
self.register_optimizer_hook(optimizer_config)
self.register_checkpoint_hook(checkpoint_config)
self.register_hook(IterTimerHook())
self.register_logger_hooks(log_config)
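A hedged end-to-end sketch of the new runner; the toy model, dataset, hyper-parameters and hook settings below are illustrative, not taken from this commit:

import logging
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from mmcv.runner import IterBasedRunner

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)

    def train_step(self, data_batch, optimizer, **kwargs):
        x, y = data_batch
        loss = ((self.fc(x) - y) ** 2).mean()
        # no OptimizerHook is registered below, so the model steps the optimizer itself
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return dict(loss=loss, log_vars=dict(loss=loss.item()),
                    num_samples=x.size(0))

model = ToyModel()
data_loader = DataLoader(
    TensorDataset(torch.rand(100, 2), torch.rand(100, 1)), batch_size=10)
runner = IterBasedRunner(
    model,
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
    work_dir='./work_dir',
    logger=logging.getLogger(__name__),
    meta=dict(exp_name='toy_iter_example'))
runner.register_training_hooks(
    lr_config=dict(policy='fixed'),
    checkpoint_config=dict(interval=100),
    log_config=dict(interval=20,
                    hooks=[dict(type='TextLoggerHook', by_epoch=False)]))
runner.run([data_loader], [('train', 200)], max_iters=200)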

View File

@@ -39,22 +39,48 @@ def test_epoch_based_runner():
def batch_processor():
pass
_ = EpochBasedRunner(model, batch_processor)
_ = EpochBasedRunner(
model, batch_processor, logger=logging.getLogger())
with pytest.raises(TypeError):
# batch_processor must be callable
model = OldStyleModel()
_ = EpochBasedRunner(model, batch_processor=0)
_ = EpochBasedRunner(
model, batch_processor=0, logger=logging.getLogger())
with pytest.raises(TypeError):
# optimizer must be an optimizer or a dict of optimizers
model = Model()
optimizer = 'NotAOptimizer'
_ = EpochBasedRunner(
model, optimizer=optimizer, logger=logging.getLogger())
with pytest.raises(TypeError):
# optimizer must be an optimizer or a dict of optimizers
model = Model()
optimizers = dict(optim1=torch.optim.Adam(), optim2='NotAOptimizer')
_ = EpochBasedRunner(
model, optimizer=optimizers, logger=logging.getLogger())
with pytest.raises(TypeError):
# logger must be a logging.Logger
model = Model()
_ = EpochBasedRunner(model, logger=None)
with pytest.raises(TypeError):
# meta must be a dict or None
model = Model()
_ = EpochBasedRunner(model, logger=logging.getLogger(), meta=['list'])
with pytest.raises(AssertionError):
# model must implement the method train_step()
model = OldStyleModel()
_ = EpochBasedRunner(model)
_ = EpochBasedRunner(model, logger=logging.getLogger())
with pytest.raises(TypeError):
# work_dir must be a str or None
model = Model()
_ = EpochBasedRunner(model, work_dir=1)
_ = EpochBasedRunner(model, work_dir=1, logger=logging.getLogger())
with pytest.raises(RuntimeError):
# batch_processor and train_step() cannot be both set
@@ -63,7 +89,8 @@
pass
model = Model()
_ = EpochBasedRunner(model, batch_processor)
_ = EpochBasedRunner(
model, batch_processor, logger=logging.getLogger())
# test work_dir
model = Model()
@@ -71,9 +98,9 @@
dir_name = ''.join(
[random.choice(string.ascii_letters) for _ in range(10)])
work_dir = osp.join(temp_root, dir_name)
_ = EpochBasedRunner(model, work_dir=work_dir)
_ = EpochBasedRunner(model, work_dir=work_dir, logger=logging.getLogger())
assert osp.isdir(work_dir)
_ = EpochBasedRunner(model, work_dir=work_dir)
_ = EpochBasedRunner(model, work_dir=work_dir, logger=logging.getLogger())
assert osp.isdir(work_dir)
os.removedirs(work_dir)
@@ -84,7 +111,7 @@
pass
model = MMDataParallel(OldStyleModel())
_ = EpochBasedRunner(model, batch_processor)
_ = EpochBasedRunner(model, batch_processor, logger=logging.getLogger())
with pytest.raises(RuntimeError):
# batch_processor and train_step() cannot be both set
@@ -93,7 +120,8 @@
pass
model = MMDataParallel(Model())
_ = EpochBasedRunner(model, batch_processor)
_ = EpochBasedRunner(
model, batch_processor, logger=logging.getLogger())
def test_save_checkpoint():