# Copyright (c) Open-MMLab. All rights reserved.
import logging
import os.path as osp
import time

import torch

import mmcv
from .checkpoint import load_checkpoint, save_checkpoint
from .dist_utils import get_dist_info
from .hooks import HOOKS, Hook, IterTimerHook
from .log_buffer import LogBuffer
from .priority import get_priority
from .utils import get_host_info, get_time_str, obj_from_dict


class Runner(object):
    """A training helper for PyTorch.

    Args:
        model (:obj:`torch.nn.Module`): The model to be run.
        batch_processor (callable): A callable method that processes a data
            batch. The interface of this method should be
            `batch_processor(model, data, train_mode) -> dict`
        optimizer (dict or :obj:`torch.optim.Optimizer`): If it is a dict,
            runner will construct an optimizer according to it.
        work_dir (str, optional): The working directory to save checkpoints
            and logs.
        log_level (int): Logging level.
        logger (:obj:`logging.Logger`): Custom logger. If `None`, use the
            default logger.
        meta (dict | None): A dict that records some important information,
            such as environment info and seed, which will be logged in the
            logger hook.
    """

    def __init__(self,
                 model,
                 batch_processor,
                 optimizer=None,
                 work_dir=None,
                 log_level=logging.INFO,
                 logger=None,
                 meta=None):
        assert callable(batch_processor)
        self.model = model
        if optimizer is not None:
            self.optimizer = self.init_optimizer(optimizer)
        else:
            self.optimizer = None
        self.batch_processor = batch_processor

        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class
        if hasattr(self.model, 'module'):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        if logger is None:
            self.logger = self.init_logger(work_dir, log_level)
        else:
            self.logger = logger
        self.log_buffer = LogBuffer()

        if meta is not None:
            assert isinstance(meta, dict), '"meta" must be a dict or None'
        self.meta = meta

        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0
        self._max_epochs = 0
        self._max_iters = 0

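    # Illustrative sketch (not part of the upstream API docs): a minimal
    # `batch_processor` and runner construction. The model, loss and work_dir
    # names below are assumptions for the example only.
    #
    # >>> def batch_processor(model, data, train_mode):
    # ...     loss = model(data)  # assumed: the model returns a scalar loss
    # ...     return dict(loss=loss,
    # ...                 log_vars=dict(loss=loss.item()),
    # ...                 num_samples=data.size(0))
    # >>> runner = Runner(model, batch_processor,
    # ...                 optimizer=dict(type='SGD', lr=0.01, momentum=0.9),
    # ...                 work_dir='./work_dirs/example')
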
    @property
    def model_name(self):
        """str: Name of the model, usually the module class name."""
        return self._model_name

    @property
    def rank(self):
        """int: Rank of current process. (distributed training)"""
        return self._rank

    @property
    def world_size(self):
        """int: Number of processes participating in the job.
        (distributed training)"""
        return self._world_size

    @property
    def hooks(self):
        """list[:obj:`Hook`]: A list of registered hooks."""
        return self._hooks

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    @property
    def inner_iter(self):
        """int: Iteration in an epoch."""
        return self._inner_iter

    @property
    def max_epochs(self):
        """int: Maximum training epochs."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Maximum training iterations."""
        return self._max_iters

    def init_optimizer(self, optimizer):
        """Init the optimizer.

        Args:
            optimizer (dict or :obj:`~torch.optim.Optimizer`): Either an
                optimizer object or a dict used for constructing the
                optimizer.

        Returns:
            :obj:`~torch.optim.Optimizer`: An optimizer object.

        Examples:
            >>> optimizer = dict(type='SGD', lr=0.01, momentum=0.9)
            >>> type(runner.init_optimizer(optimizer))
            <class 'torch.optim.sgd.SGD'>
        """
        if isinstance(optimizer, dict):
            optimizer = obj_from_dict(optimizer, torch.optim,
                                      dict(params=self.model.parameters()))
        elif not isinstance(optimizer, torch.optim.Optimizer):
            raise TypeError(
                'optimizer must be either an Optimizer object or a dict, '
                f'but got {type(optimizer)}')
        return optimizer

    def _add_file_handler(self,
                          logger,
                          filename=None,
                          mode='w',
                          level=logging.INFO):
        # TODO: move this method out of runner
        file_handler = logging.FileHandler(filename, mode)
        file_handler.setFormatter(
            logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        file_handler.setLevel(level)
        logger.addHandler(file_handler)
        return logger

    def init_logger(self, log_dir=None, level=logging.INFO):
        """Init the logger.

        Args:
            log_dir (str, optional): Log file directory. If not specified, no
                log file will be used.
            level (int or str): See the built-in python logging module.

        Returns:
            :obj:`~logging.Logger`: Python logger.
        """
        logging.basicConfig(
            format='%(asctime)s - %(levelname)s - %(message)s', level=level)
        logger = logging.getLogger(__name__)
        if log_dir and self.rank == 0:
            filename = f'{self.timestamp}.log'
            log_file = osp.join(log_dir, filename)
            self._add_file_handler(logger, log_file, level=level)
        return logger

    def current_lr(self):
        """Get current learning rates.

        Returns:
            list: Current learning rate of all param groups.
        """
        if self.optimizer is None:
            raise RuntimeError(
                'lr is not applicable because optimizer does not exist.')
        return [group['lr'] for group in self.optimizer.param_groups]

    def current_momentum(self):
        """Get current momentums.

        Returns:
            list: Current momentum of all param groups.
        """
        if self.optimizer is None:
            raise RuntimeError(
                'momentum is not applicable because optimizer does not '
                'exist.')
        momentums = []
        for group in self.optimizer.param_groups:
            if 'momentum' in group.keys():
                momentums.append(group['momentum'])
            elif 'betas' in group.keys():
                momentums.append(group['betas'][0])
            else:
                momentums.append(0)
        return momentums

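    # Illustrative note (the optimizer config is an assumption): with an
    # Adam-style optimizer, `current_momentum()` reports `betas[0]` for each
    # param group.
    #
    # >>> runner = Runner(model, batch_processor,
    # ...                 optimizer=dict(type='Adam', lr=1e-3,
    # ...                                betas=(0.9, 0.999)))
    # >>> runner.current_momentum()
    # [0.9]
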
    def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list.

        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int or str or :obj:`Priority`): Hook priority.
                Lower value means higher priority.
        """
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # insert the hook to a sorted list
        inserted = False
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            self._hooks.insert(0, hook)

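    # Illustrative sketch (the hook class below is an assumption, not part of
    # mmcv): hooks registered with a lower priority value run earlier within
    # each stage.
    #
    # >>> class PrintEpochHook(Hook):
    # ...     def before_train_epoch(self, runner):
    # ...         runner.logger.info('starting epoch %d', runner.epoch)
    # >>> runner.register_hook(PrintEpochHook(), priority='HIGH')
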
    def call_hook(self, fn_name):
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def load_checkpoint(self, filename, map_location='cpu', strict=False):
        self.logger.info('load checkpoint from %s', filename)
        return load_checkpoint(self.model, filename, map_location, strict,
                               self.logger)

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        else:
            meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # in some environments, `os.symlink` is not supported, so you may
        # need to set `create_symlink` to False
        if create_symlink:
            mmcv.symlink(filename, osp.join(out_dir, 'latest.pth'))

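    # Illustrative usage (the directory is an assumption): after epoch 3 this
    # writes `epoch_4.pth` and refreshes the `latest.pth` symlink.
    #
    # >>> runner.save_checkpoint('./work_dirs/example',
    # ...                        filename_tmpl='epoch_{}.pth',
    # ...                        save_optimizer=True)
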
    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader

        self.call_hook('before_train_epoch')
        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=True, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('batch_processor() must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            self.outputs = outputs
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')

        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            with torch.no_grad():
                outputs = self.batch_processor(
                    self.model, data_batch, train_mode=False, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('batch_processor() must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            self.outputs = outputs
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        if map_location == 'default':
            device_id = torch.cuda.current_device()
            checkpoint = self.load_checkpoint(
                checkpoint,
                map_location=lambda storage, loc: storage.cuda(device_id))
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)

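    # Illustrative usage (the path is an assumption): with
    # `map_location='default'` the checkpoint is loaded onto the current CUDA
    # device; pass `map_location='cpu'` on CPU-only machines.
    #
    # >>> runner.resume('./work_dirs/example/latest.pth',
    # ...               map_location='cpu')
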
    def run(self, data_loaders, workflow, max_epochs, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g., [('train', 2), ('val', 1)]
                means running 2 epochs for training and 1 epoch for
                validation, iteratively.
            max_epochs (int): Total training epochs.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        self._max_epochs = max_epochs
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
        self.call_hook('before_run')

        while self.epoch < max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                elif callable(mode):  # custom train()
                    epoch_runner = mode
                else:
                    raise TypeError('mode in workflow must be a str or '
                                    f'callable function, not {type(mode)}')
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= max_epochs:
                        return
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

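    # Illustrative usage (the data loaders are assumptions): train for one
    # epoch, then validate for one epoch, repeating until 12 training epochs
    # have finished.
    #
    # >>> runner.run([train_loader, val_loader],
    # ...            [('train', 1), ('val', 1)],
    # ...            max_epochs=12)
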
    def register_lr_hook(self, lr_config):
        if isinstance(lr_config, dict):
            assert 'policy' in lr_config
            policy_type = lr_config.pop('policy')
            # If the type of policy is all in lower case, e.g., 'cyclic',
            # then its first letter will be capitalized, e.g., to be 'Cyclic'.
            # This is for the convenient usage of the Lr updater.
            # Since this is not applicable for `CosineAnnealingLrUpdaterHook`,
            # the string will not be changed if it contains capital letters.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'LrUpdaterHook'
            lr_config['type'] = hook_type
            hook = mmcv.build_from_cfg(lr_config, HOOKS)
        else:
            hook = lr_config
        self.register_hook(hook)

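    # Illustrative configs (parameter values are assumptions): a lower-case
    # policy such as 'step' is title-cased into 'StepLrUpdaterHook', while a
    # mixed-case policy such as 'CosineAnnealing' is used verbatim. Either
    # call below would register an LR updater hook.
    #
    # >>> runner.register_lr_hook(dict(policy='step', step=[8, 11]))
    # >>> runner.register_lr_hook(dict(policy='CosineAnnealing', min_lr=1e-5))
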
    def register_optimizer_hook(self, optimizer_config):
        if optimizer_config is None:
            return
        if isinstance(optimizer_config, dict):
            optimizer_config.setdefault('type', 'OptimizerHook')
            hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
        else:
            hook = optimizer_config
        self.register_hook(hook)

    def register_checkpoint_hook(self, checkpoint_config):
        if checkpoint_config is None:
            return
        if isinstance(checkpoint_config, dict):
            checkpoint_config.setdefault('type', 'CheckpointHook')
            hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
        else:
            hook = checkpoint_config
        self.register_hook(hook)

    def register_momentum_hook(self, momentum_config):
        if momentum_config is None:
            return
        if isinstance(momentum_config, dict):
            assert 'policy' in momentum_config
            policy_type = momentum_config.pop('policy')
            # If the type of policy is all in lower case, e.g., 'cyclic',
            # then its first letter will be capitalized, e.g., to be 'Cyclic'.
            # This is for the convenient usage of the momentum updater.
            # Since this is not applicable for
            # `CosineAnnealingMomentumUpdaterHook`, the string will not be
            # changed if it contains capital letters.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'MomentumUpdaterHook'
            momentum_config['type'] = hook_type
            hook = mmcv.build_from_cfg(momentum_config, HOOKS)
        else:
            hook = momentum_config
        self.register_hook(hook)

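    # Illustrative config (values are assumptions): a cyclic momentum schedule
    # is typically paired with a cyclic LR schedule.
    #
    # >>> runner.register_momentum_hook(
    # ...     dict(policy='cyclic', target_ratio=(0.85 / 0.95, 1)))
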
    def register_logger_hooks(self, log_config):
        log_interval = log_config['interval']
        for info in log_config['hooks']:
            logger_hook = mmcv.build_from_cfg(
                info, HOOKS, default_args=dict(interval=log_interval))
            self.register_hook(logger_hook, priority='VERY_LOW')

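    # Illustrative config (the interval value is an assumption): a text logger
    # that reports averaged log_vars every 50 iterations.
    #
    # >>> runner.register_logger_hooks(
    # ...     dict(interval=50, hooks=[dict(type='TextLoggerHook')]))
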
    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None):
        """Register default hooks for training.

        Default hooks include:

        - LrUpdaterHook
        - MomentumUpdaterHook
        - OptimizerHook
        - CheckpointHook
        - IterTimerHook
        - LoggerHook(s)
        """
        self.register_lr_hook(lr_config)
        self.register_momentum_hook(momentum_config)
        self.register_optimizer_hook(optimizer_config)
        self.register_checkpoint_hook(checkpoint_config)
        self.register_hook(IterTimerHook())
        self.register_logger_hooks(log_config)
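
    # Illustrative end-to-end registration (all config values are
    # assumptions): one call wires up LR scheduling, gradient stepping,
    # checkpointing and logging before `runner.run(...)` is invoked.
    #
    # >>> runner.register_training_hooks(
    # ...     lr_config=dict(policy='step', step=[8, 11]),
    # ...     optimizer_config=dict(grad_clip=dict(max_norm=35, norm_type=2)),
    # ...     checkpoint_config=dict(interval=1),
    # ...     log_config=dict(interval=50,
    # ...                     hooks=[dict(type='TextLoggerHook')]))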