# mmcv/mmcv/runner/base_runner.py
#
# 397 lines
# 14 KiB
# Python

# Copyright (c) Open-MMLab. All rights reserved.
import logging
import os.path as osp
import warnings
from abc import ABCMeta, abstractmethod
import torch
from torch.optim import Optimizer
import mmcv
from ..parallel import is_module_wrapper
from .checkpoint import load_checkpoint
from .dist_utils import get_dist_info
from .hooks import HOOKS, Hook, IterTimerHook
from .log_buffer import LogBuffer
from .priority import get_priority
from .utils import get_time_str
class BaseRunner(metaclass=ABCMeta):
    """The base class of Runner, a training helper for PyTorch.

    All subclasses should implement the following APIs:

    - ``run()``
    - ``train()``
    - ``val()``
    - ``save_checkpoint()``

    Args:
        model (:obj:`torch.nn.Module`): The model to be run.
        batch_processor (callable): A callable method that processes a data
            batch. The interface of this method should be
            `batch_processor(model, data, train_mode) -> dict`.
            Deprecated: implement ``train_step()`` and ``val_step()`` in the
            model instead.
        optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
            optimizer (in most cases) or a dict of optimizers (in models that
            require more than one optimizer, e.g., GAN).
        work_dir (str, optional): The working directory to save checkpoints
            and logs. Defaults to None.
        logger (:obj:`logging.Logger`): Logger used during training.
            The default of None exists only for backward compatibility of the
            signature: a ``TypeError`` is raised if the value is not a
            :obj:`logging.Logger`.
        meta (dict | None): A dict that records some important information
            such as environment info and seed, which will be logged in the
            logger hook. Defaults to None.
    """

    def __init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None):
        if batch_processor is not None:
            if not callable(batch_processor):
                raise TypeError('batch_processor must be callable, '
                                f'but got {type(batch_processor)}')
            warnings.warn('batch_processor is deprecated, please implement '
                          'train_step() and val_step() in the model instead.')
            # Raise an error if `batch_processor` is not None and
            # `model.train_step()`/`model.val_step()` exists, since it would
            # be ambiguous which one should drive the iteration.
            if is_module_wrapper(model):
                _model = model.module
            else:
                _model = model
            if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
                raise RuntimeError(
                    'batch_processor and model.train_step()/model.val_step() '
                    'cannot be both available.')
        else:
            # Without a batch_processor the model itself must provide the
            # training step.
            assert hasattr(model, 'train_step')

        # check the type of `optimizer`
        if isinstance(optimizer, dict):
            for name, optim in optimizer.items():
                if not isinstance(optim, Optimizer):
                    raise TypeError(
                        f'optimizer must be a dict of torch.optim.Optimizers, '
                        f'but optimizer["{name}"] is a {type(optim)}')
        elif not isinstance(optimizer, Optimizer) and optimizer is not None:
            raise TypeError(
                f'optimizer must be a torch.optim.Optimizer object '
                f'or dict or None, but got {type(optimizer)}')

        # check the type of `logger` (None is rejected as well)
        if not isinstance(logger, logging.Logger):
            raise TypeError(f'logger must be a logging.Logger object, '
                            f'but got {type(logger)}')

        # check the type of `meta`
        if meta is not None and not isinstance(meta, dict):
            raise TypeError(
                f'meta must be a dict or None, but got {type(meta)}')

        self.model = model
        self.batch_processor = batch_processor
        self.optimizer = optimizer
        self.logger = logger
        self.meta = meta

        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class (unwrap a `.module` wrapper)
        if hasattr(self.model, 'module'):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0
        self._max_epochs = 0
        self._max_iters = 0
        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
        self.log_buffer = LogBuffer()

    @property
    def model_name(self):
        """str: Name of the model, usually the module class name."""
        return self._model_name

    @property
    def rank(self):
        """int: Rank of current process. (distributed training)"""
        return self._rank

    @property
    def world_size(self):
        """int: Number of processes participating in the job.
        (distributed training)"""
        return self._world_size

    @property
    def hooks(self):
        """list[:obj:`Hook`]: A list of registered hooks."""
        return self._hooks

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    @property
    def inner_iter(self):
        """int: Iteration in an epoch."""
        return self._inner_iter

    @property
    def max_epochs(self):
        """int: Maximum training epochs."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Maximum training iterations."""
        return self._max_iters

    @abstractmethod
    def train(self):
        """Run training; implemented by subclasses."""
        pass

    @abstractmethod
    def val(self):
        """Run validation; implemented by subclasses."""
        pass

    @abstractmethod
    def run(self, data_loaders, workflow, **kwargs):
        """Run the given workflow over the data loaders; implemented by
        subclasses."""
        pass

    @abstractmethod
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl,
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        """Save a checkpoint into ``out_dir``; implemented by subclasses."""
        pass

    def current_lr(self):
        """Get current learning rates.

        Returns:
            list[float] | dict[str, list[float]]: Current learning rates of all
                param groups. If the runner has a dict of optimizers, this
                method will return a dict.

        Raises:
            RuntimeError: If the runner has no (supported) optimizer.
        """
        if isinstance(self.optimizer, Optimizer):
            return [group['lr'] for group in self.optimizer.param_groups]
        if isinstance(self.optimizer, dict):
            return {
                name: [group['lr'] for group in optim.param_groups]
                for name, optim in self.optimizer.items()
            }
        raise RuntimeError(
            'lr is not applicable because optimizer does not exist.')

    def current_momentum(self):
        """Get current momentums.

        For each param group the ``momentum`` entry is used; otherwise the
        first element of ``betas`` (Adam-style optimizers); otherwise 0.

        Returns:
            list[float] | dict[str, list[float]]: Current momentums of all
                param groups. If the runner has a dict of optimizers, this
                method will return a dict.

        Raises:
            RuntimeError: If the runner has no (supported) optimizer.
        """

        def _get_momentum(optimizer):
            momentums = []
            for group in optimizer.param_groups:
                if 'momentum' in group:
                    momentums.append(group['momentum'])
                elif 'betas' in group:
                    momentums.append(group['betas'][0])
                else:
                    momentums.append(0)
            return momentums

        if isinstance(self.optimizer, Optimizer):
            return _get_momentum(self.optimizer)
        if isinstance(self.optimizer, dict):
            return {
                name: _get_momentum(optim)
                for name, optim in self.optimizer.items()
            }
        # Covers `None` as well as any unsupported type, which previously
        # fell through to an UnboundLocalError.
        raise RuntimeError(
            'momentum is not applicable because optimizer does not exist.')

    def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list.

        The hook will be inserted into a priority queue, with the specified
        priority (See :class:`Priority` for details of priorities).
        For hooks with the same priority, they will be triggered in the same
        order as they are registered.

        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int or str or :obj:`Priority`): Hook priority.
                Lower value means higher priority.

        Raises:
            ValueError: If ``hook`` already has a ``priority`` attribute
                (e.g. it was registered before).
        """
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # Keep the list sorted by priority value. Scanning from the back and
        # inserting after the last hook with priority <= ours keeps hooks of
        # equal priority in registration order.
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                break
        else:
            self._hooks.insert(0, hook)

    def call_hook(self, fn_name):
        """Call all hooks.

        Args:
            fn_name (str): The function name in each hook to be called, such as
                "before_train_epoch".
        """
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def load_checkpoint(self, filename, map_location='cpu', strict=False):
        """Load a checkpoint into ``self.model``.

        Args:
            filename (str): Checkpoint to load.
            map_location (str | callable): Forwarded to the module-level
                ``load_checkpoint``. Defaults to ``'cpu'``.
            strict (bool): Forwarded to the module-level ``load_checkpoint``.

        Returns:
            Whatever the module-level ``load_checkpoint`` returns; ``resume``
            indexes it as a dict (``['meta']``, ``['optimizer']``).
        """
        self.logger.info('load checkpoint from %s', filename)
        return load_checkpoint(self.model, filename, map_location, strict,
                               self.logger)

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """Resume epoch, iteration and (optionally) optimizer state.

        Args:
            checkpoint (str): Checkpoint to resume from; passed to
                :meth:`load_checkpoint`.
            resume_optimizer (bool): Whether to restore optimizer state when
                the checkpoint contains an ``'optimizer'`` entry.
                Defaults to True.
            map_location (str | callable): ``'default'`` maps storages onto
                the current CUDA device when CUDA is available and onto the
                CPU otherwise. Any other value is passed through unchanged.

        Raises:
            TypeError: If optimizer state should be restored but
                ``self.optimizer`` is neither an Optimizer nor a dict.
        """
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                # `torch.cuda.current_device()` raises without CUDA, so fall
                # back to loading onto the CPU.
                checkpoint = self.load_checkpoint(
                    checkpoint, map_location='cpu')
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                # Assumes checkpoint['optimizer'] is a dict keyed by the same
                # names as self.optimizer -- confirm against save_checkpoint.
                for name, optim in self.optimizer.items():
                    optim.load_state_dict(checkpoint['optimizer'][name])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)

    @staticmethod
    def _build_policy_hook(config, hook_suffix):
        """Build an updater hook from a config dict with a ``policy`` key.

        The caller's dict is not modified.
        """
        assert 'policy' in config
        # Shallow copy: popping 'policy' from the caller's dict would make a
        # second registration with the same config fail the assert above.
        config = dict(config)
        policy_type = config.pop('policy')
        # If the type of policy is all in lower case, e.g., 'cyclic',
        # then its first letter will be capitalized, e.g., to be 'Cyclic'.
        # This is for the convenient usage of the updater hooks.
        # Since title() would mangle names such as 'CosineAnnealing', the
        # string is not changed if it contains capital letters.
        if policy_type == policy_type.lower():
            policy_type = policy_type.title()
        config['type'] = policy_type + hook_suffix
        return mmcv.build_from_cfg(config, HOOKS)

    def register_lr_hook(self, lr_config):
        """Register a learning-rate updater hook.

        Args:
            lr_config (dict | :obj:`Hook` | None): Either a config dict with a
                ``policy`` key (built as ``<Policy>LrUpdaterHook`` from the
                ``HOOKS`` registry) or an already-built hook. ``None``
                registers nothing.
        """
        if lr_config is None:
            return
        if isinstance(lr_config, dict):
            hook = self._build_policy_hook(lr_config, 'LrUpdaterHook')
        else:
            hook = lr_config
        self.register_hook(hook)

    def register_momentum_hook(self, momentum_config):
        """Register a momentum updater hook.

        Args:
            momentum_config (dict | :obj:`Hook` | None): Either a config dict
                with a ``policy`` key (built as
                ``<Policy>MomentumUpdaterHook`` from the ``HOOKS`` registry)
                or an already-built hook. ``None`` registers nothing.
        """
        if momentum_config is None:
            return
        if isinstance(momentum_config, dict):
            hook = self._build_policy_hook(momentum_config,
                                           'MomentumUpdaterHook')
        else:
            hook = momentum_config
        self.register_hook(hook)

    def register_optimizer_hook(self, optimizer_config):
        """Register an optimizer hook.

        Args:
            optimizer_config (dict | :obj:`Hook` | None): Config dict (its
                ``type`` defaults to ``'OptimizerHook'``) or a built hook.
                ``None`` registers nothing.
        """
        if optimizer_config is None:
            return
        if isinstance(optimizer_config, dict):
            # Copy so the caller's config is not modified by setdefault().
            optimizer_config = dict(optimizer_config)
            optimizer_config.setdefault('type', 'OptimizerHook')
            hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
        else:
            hook = optimizer_config
        self.register_hook(hook)

    def register_checkpoint_hook(self, checkpoint_config):
        """Register a checkpoint hook.

        Args:
            checkpoint_config (dict | :obj:`Hook` | None): Config dict (its
                ``type`` defaults to ``'CheckpointHook'``) or a built hook.
                ``None`` registers nothing.
        """
        if checkpoint_config is None:
            return
        if isinstance(checkpoint_config, dict):
            # Copy so the caller's config is not modified by setdefault().
            checkpoint_config = dict(checkpoint_config)
            checkpoint_config.setdefault('type', 'CheckpointHook')
            hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
        else:
            hook = checkpoint_config
        self.register_hook(hook)

    def register_logger_hooks(self, log_config):
        """Register logger hooks with ``'VERY_LOW'`` priority.

        Args:
            log_config (dict | None): Dict with an ``interval`` and a list of
                hook configs under ``hooks``; ``interval`` is passed to each
                hook as a default argument. ``None`` registers nothing.
        """
        if log_config is None:
            return
        log_interval = log_config['interval']
        for info in log_config['hooks']:
            logger_hook = mmcv.build_from_cfg(
                info, HOOKS, default_args=dict(interval=log_interval))
            self.register_hook(logger_hook, priority='VERY_LOW')

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None):
        """Register default hooks for training.

        Default hooks include:

        - LrUpdaterHook
        - MomentumUpdaterHook
        - OptimizerHook
        - CheckpointHook
        - IterTimerHook
        - LoggerHook(s)

        Any config given as ``None`` skips the corresponding hook, except the
        ``IterTimerHook`` which is always registered.
        """
        self.register_lr_hook(lr_config)
        self.register_momentum_hook(momentum_config)
        self.register_optimizer_hook(optimizer_config)
        self.register_checkpoint_hook(checkpoint_config)
        self.register_hook(IterTimerHook())
        self.register_logger_hooks(log_config)