Add IterBasedRunner (#314)
* feat: add IterBasedRunner
* fix: unittest
* feat: more unittest
* fix: expose dataloader len
* minor updates of BaseRunner
* refactor: remove CosineRestartLrUpdaterHook
* style: add docstring
* refactor: update IterTextLoggerHook: fstring and exp_name
* fix: epoch_runner unittest
* refactor: remove IterBasedTextLogger
* fix: old IterTextLoggerHook issue
* refactor: remove __len__ of IterLoader
* feat: add IterBasedRunner to init
* feat: add __len__ to IterLoader
* fix some docstrings
* refactor: use is_parallel_module
* fix: import issue
* fix: runner unittest missing logger
* fix checkpoints
* feat: add by_epoch default value to IterBasedRunner registering logger_hook
* refactor: remove setting by_epoch in log_config
* minor refactoring
* docs: add docstring
* fix: remove unused doc
* update the log info for saving checkpoints

Co-authored-by: Kai Chen <chenkaidev@gmail.com>
This commit is contained in: parent 61f9e91c9f, commit 67a26da917.
@@ -1,10 +1,84 @@
# Copyright (c) Open-MMLab. All rights reserved.
from torch.nn.parallel import DistributedDataParallel
import torch
from torch.nn.parallel.distributed import (DistributedDataParallel,
                                            _find_tensors)

from .scatter_gather import scatter_kwargs


class MMDistributedDataParallel(DistributedDataParallel):
    """The DDP module that supports DataContainer.

    MMDDP has two main differences with PyTorch DDP:

    - It supports a custom type :class:`DataContainer` which allows more
      flexible control of input data.
    - It implements two APIs ``train_step()`` and ``val_step()``.
    """

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def train_step(self, *inputs, **kwargs):
        """train_step() API for module wrapped by DistributedDataParallel.

        This method is basically the same as
        ``DistributedDataParallel.forward()``, while replacing
        ``self.module.forward()`` with ``self.module.train_step()``.
        It is compatible with PyTorch 1.1 - 1.5.
        """
        if getattr(self, 'require_forward_param_sync', True):
            self._sync_params()
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                output = self.module.train_step(*inputs[0], **kwargs[0])
            else:
                outputs = self.parallel_apply(
                    self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            output = self.module.train_step(*inputs, **kwargs)

        if torch.is_grad_enabled() and getattr(
                self, 'require_backward_grad_sync', True):
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        else:
            if torch.__version__ > '1.2':
                self.require_forward_param_sync = False
        return output

    def val_step(self, *inputs, **kwargs):
        """val_step() API for module wrapped by DistributedDataParallel.

        This method is basically the same as
        ``DistributedDataParallel.forward()``, while replacing
        ``self.module.forward()`` with ``self.module.val_step()``.
        It is compatible with PyTorch 1.1 - 1.5.
        """
        if getattr(self, 'require_forward_param_sync', True):
            self._sync_params()
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                output = self.module.val_step(*inputs[0], **kwargs[0])
            else:
                outputs = self.parallel_apply(
                    self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            output = self.module.val_step(*inputs, **kwargs)

        if torch.is_grad_enabled() and getattr(
                self, 'require_backward_grad_sync', True):
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        else:
            if torch.__version__ > '1.2':
                self.require_forward_param_sync = False
        return output
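To make the ``train_step()``/``val_step()`` contract concrete, below is a minimal sketch (not part of this diff) of a module exposing ``train_step()`` and returning the dict a runner expects; ``ToyModel`` and its loss are illustrative assumptions.

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical model exposing the train_step() interface."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)

    def forward(self, x):
        return self.fc(x)

    def train_step(self, data_batch, optimizer=None, **kwargs):
        # a runner (or the DDP wrapper above) calls this instead of forward()
        loss = self(data_batch).mean()
        return dict(
            loss=loss,
            log_vars=dict(loss=loss.item()),
            num_samples=data_batch.size(0))


# MMDistributedDataParallel.train_step() ultimately dispatches to this method
# after scattering the inputs onto the right device(s):
out = ToyModel().train_step(torch.randn(4, 2))
assert set(out) == {'loss', 'log_vars', 'num_samples'}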
@@ -52,3 +52,15 @@ class MMDistributedDataParallel(nn.Module):
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        return self.module(*inputs[0], **kwargs[0])

    def train_step(self, *inputs, **kwargs):
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        output = self.module.train_step(*inputs[0], **kwargs[0])
        return output

    def val_step(self, *inputs, **kwargs):
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        output = self.module.val_step(*inputs[0], **kwargs[0])
        return output
@@ -8,6 +8,7 @@ from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistSamplerSeedHook,
                    Hook, IterTimerHook, LoggerHook, LrUpdaterHook,
                    MlflowLoggerHook, OptimizerHook, PaviLoggerHook,
                    TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook)
from .iter_based_runner import IterBasedRunner, IterLoader
from .log_buffer import LogBuffer
from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
                        DefaultOptimizerConstructor, build_optimizer,
@@ -16,14 +17,15 @@ from .priority import Priority, get_priority
from .utils import get_host_info, get_time_str, obj_from_dict

__all__ = [
    'BaseRunner', 'Runner', 'EpochBasedRunner', 'LogBuffer', 'HOOKS', 'Hook',
    'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook',
    'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', 'PaviLoggerHook',
    'TextLoggerHook', 'TensorboardLoggerHook', 'WandbLoggerHook',
    'MlflowLoggerHook', '_load_checkpoint', 'load_state_dict',
    'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority',
    'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict',
    'init_dist', 'get_dist_info', 'master_only', 'OPTIMIZER_BUILDERS',
    'OPTIMIZERS', 'DefaultOptimizerConstructor', 'build_optimizer',
    'build_optimizer_constructor'
    'BaseRunner', 'Runner', 'EpochBasedRunner', 'IterBasedRunner', 'LogBuffer',
    'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
    'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook',
    'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
    'WandbLoggerHook', 'MlflowLoggerHook', '_load_checkpoint',
    'load_state_dict', 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint',
    'Priority', 'get_priority', 'get_host_info', 'get_time_str',
    'obj_from_dict', 'init_dist', 'get_dist_info', 'master_only',
    'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
    'build_optimizer', 'build_optimizer_constructor', 'IterLoader',
    'IterBasedRunner'
]
@@ -5,6 +5,7 @@ import warnings
from abc import ABCMeta, abstractmethod

import torch
from torch.optim import Optimizer

import mmcv
from ..parallel import is_parallel_module
@@ -31,12 +32,14 @@ class BaseRunner(metaclass=ABCMeta):
        batch_processor (callable): A callable method that processes a data
            batch. The interface of this method should be
            `batch_processor(model, data, train_mode) -> dict`
        optimizer (dict or :obj:`torch.optim.Optimizer`): If it is a dict,
            runner will construct an optimizer according to it.
        optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
            optimizer (in most cases) or a dict of optimizers (in models that
            require more than one optimizer, e.g., GAN).
        work_dir (str, optional): The working directory to save checkpoints
            and logs. Defaults to None.
        logger (:obj:`logging.Logger`): Logger used during training.
            Defaults to None.
            Defaults to None. (The default value is just for backward
            compatibility)
        meta (dict | None): A dict that records some important information,
            such as environment info and seed, which will be logged in the
            logger hook. Defaults to None.
@@ -67,9 +70,34 @@
                'cannot be both available.')
        else:
            assert hasattr(model, 'train_step')

        # check the type of `optimizer`
        if isinstance(optimizer, dict):
            for name, optim in optimizer.items():
                if not isinstance(optim, Optimizer):
                    raise TypeError(
                        f'optimizer must be a dict of torch.optim.Optimizers, '
                        f'but optimizer["{name}"] is a {type(optim)}')
        elif not isinstance(optimizer, Optimizer) and optimizer is not None:
            raise TypeError(
                f'optimizer must be a torch.optim.Optimizer object '
                f'or dict or None, but got {type(optimizer)}')

        # check the type of `logger`
        if not isinstance(logger, logging.Logger):
            raise TypeError(f'logger must be a logging.Logger object, '
                            f'but got {type(logger)}')

        # check the type of `meta`
        if meta is not None and not isinstance(meta, dict):
            raise TypeError(
                f'meta must be a dict or None, but got {type(meta)}')

        self.model = model
        self.batch_processor = batch_processor
        self.optimizer = optimizer
        self.logger = logger
        self.meta = meta

        # create work_dir
        if mmcv.is_str(work_dir):
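As a usage sketch of the new checks (not from this diff), a runner can now be constructed with a dict of optimizers; ``ToyGAN`` and the ``generator``/``discriminator`` names are illustrative, and the snippet assumes the corresponding mmcv version is installed.

import logging

import torch
import torch.nn as nn

from mmcv.runner import EpochBasedRunner  # IterBasedRunner takes the same arguments


class ToyGAN(nn.Module):
    """Hypothetical two-optimizer model."""

    def __init__(self):
        super().__init__()
        self.generator = nn.Linear(2, 2)
        self.discriminator = nn.Linear(2, 1)

    def train_step(self, data_batch, optimizer, **kwargs):
        # `optimizer` arrives as the same dict that was passed to the runner
        return dict(loss=torch.tensor(0.), log_vars={}, num_samples=1)


model = ToyGAN()
optimizer = dict(
    generator=torch.optim.Adam(model.generator.parameters(), lr=1e-4),
    discriminator=torch.optim.Adam(model.discriminator.parameters(), lr=4e-4))
runner = EpochBasedRunner(
    model, optimizer=optimizer, logger=logging.getLogger(__name__))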
@@ -86,13 +114,6 @@
        else:
            self._model_name = self.model.__class__.__name__

        assert logging is not None
        self.logger = logger

        if meta is not None:
            assert isinstance(meta, dict), '"meta" must be a dict or None'
        self.meta = meta

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        self.mode = None
@@ -176,30 +197,50 @@
        """Get current learning rates.

        Returns:
            list: Current learning rate of all param groups.
            list[float] | dict[str, list[float]]: Current learning rates of all
                param groups. If the runner has a dict of optimizers, this
                method will return a dict.
        """
        if self.optimizer is None:
        if isinstance(self.optimizer, torch.optim.Optimizer):
            lr = [group['lr'] for group in self.optimizer.param_groups]
        elif isinstance(self.optimizer, dict):
            lr = dict()
            for name, optim in self.optimizer.items():
                lr[name] = [group['lr'] for group in optim.param_groups]
        else:
            raise RuntimeError(
                'lr is not applicable because optimizer does not exist.')
            return [group['lr'] for group in self.optimizer.param_groups]
        return lr

    def current_momentum(self):
        """Get current momentums.

        Returns:
            list: Current momentum of all param groups.
            list[float] | dict[str, list[float]]: Current momentums of all
                param groups. If the runner has a dict of optimizers, this
                method will return a dict.
        """

        def _get_momentum(optimizer):
            momentums = []
            for group in optimizer.param_groups:
                if 'momentum' in group.keys():
                    momentums.append(group['momentum'])
                elif 'betas' in group.keys():
                    momentums.append(group['betas'][0])
                else:
                    momentums.append(0)
            return momentums

        if self.optimizer is None:
            raise RuntimeError(
                'momentum is not applicable because optimizer does not exist.')
        momentums = []
        for group in self.optimizer.param_groups:
            if 'momentum' in group.keys():
                momentums.append(group['momentum'])
            elif 'betas' in group.keys():
                momentums.append(group['betas'][0])
            else:
                momentums.append(0)
        elif isinstance(self.optimizer, torch.optim.Optimizer):
            momentums = _get_momentum(self.optimizer)
        elif isinstance(self.optimizer, dict):
            momentums = dict()
            for name, optim in self.optimizer.items():
                momentums[name] = _get_momentum(optim)
        return momentums

    def register_hook(self, hook, priority='NORMAL'):
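Continuing the ``ToyGAN`` sketch above (values depend on the optimizers actually constructed), the two accessors now return a dict keyed by optimizer name instead of a flat list:

runner.current_lr()        # {'generator': [0.0001], 'discriminator': [0.0004]}
runner.current_momentum()  # {'generator': [0.9], 'discriminator': [0.9]}, read from Adam's betas[0]
# With a single torch.optim.Optimizer both calls still return plain lists.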
@@ -9,10 +9,12 @@ from importlib import import_module

import torch
import torchvision
from torch.optim import Optimizer
from torch.utils import model_zoo

import mmcv
from ..fileio import load as load_file
from ..parallel import is_parallel_module
from ..utils import mkdir_or_exist
from .dist_utils import get_dist_info
@@ -59,6 +61,10 @@ def load_state_dict(module, state_dict, strict=False, logger=None):

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_parallel_module(module):
            module = module.module
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
@@ -228,10 +234,7 @@ def load_checkpoint(model,
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}
    # load state_dict
    if hasattr(model, 'module'):
        load_state_dict(model.module, state_dict, strict, logger)
    else:
        load_state_dict(model, state_dict, strict, logger)
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint

@@ -269,15 +272,20 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    mmcv.mkdir_or_exist(osp.dirname(filename))
    if hasattr(model, 'module'):
    if is_parallel_module(model):
        model = model.module

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(model.state_dict())
    }
    if optimizer is not None:
        # save optimizer state dict in the checkpoint
        if isinstance(optimizer, Optimizer):
            checkpoint['optimizer'] = optimizer.state_dict()
        elif isinstance(optimizer, dict):
            checkpoint['optimizer'] = {}
            for name, optim in optimizer.items():
                checkpoint['optimizer'][name] = optim.state_dict()
    # immediately flush buffer
    with open(filename, 'wb') as f:
        torch.save(checkpoint, f)
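A small sketch of the resulting checkpoint layout (continuing the ``ToyGAN`` example; the temporary path is illustrative): the ``'optimizer'`` entry mirrors the keys of the optimizer dict.

import os.path as osp
import tempfile

import torch

from mmcv.runner import save_checkpoint

ckpt_path = osp.join(tempfile.mkdtemp(), 'toy.pth')
save_checkpoint(model, ckpt_path, optimizer=optimizer)  # model/optimizer from the sketch above
ckpt = torch.load(ckpt_path)
assert set(ckpt) == {'meta', 'state_dict', 'optimizer'}
assert set(ckpt['optimizer']) == {'generator', 'discriminator'}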
@@ -7,14 +7,34 @@ from .hook import HOOKS, Hook

@HOOKS.register_module()
class CheckpointHook(Hook):
    """Save checkpoints periodically.

    Args:
        interval (int): The saving period. If ``by_epoch=True``, interval
            indicates epochs, otherwise it indicates iterations.
            Default: -1, which means "never".
        by_epoch (bool): Saving checkpoints by epoch or by iteration.
            Default: True.
        save_optimizer (bool): Whether to save optimizer state_dict in the
            checkpoint. It is usually used for resuming experiments.
            Default: True.
        out_dir (str, optional): The directory to save checkpoints. If not
            specified, ``runner.work_dir`` will be used by default.
        max_keep_ckpts (int, optional): The maximum checkpoints to keep.
            In some cases we want only the latest few checkpoints and would
            like to delete old ones to save the disk space.
            Default: -1, which means unlimited.
    """

    def __init__(self,
                 interval=-1,
                 by_epoch=True,
                 save_optimizer=True,
                 out_dir=None,
                 max_keep_ckpts=-1,
                 **kwargs):
        self.interval = interval
        self.by_epoch = by_epoch
        self.save_optimizer = save_optimizer
        self.out_dir = out_dir
        self.max_keep_ckpts = max_keep_ckpts
@@ -22,9 +42,10 @@ class CheckpointHook(Hook):

    @master_only
    def after_train_epoch(self, runner):
        if not self.every_n_epochs(runner, self.interval):
        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
            return

        runner.logger.info(f'Saving checkpoint at {runner.epoch + 1} epochs')
        if not self.out_dir:
            self.out_dir = runner.work_dir
        runner.save_checkpoint(
@@ -41,3 +62,29 @@ class CheckpointHook(Hook):
                    os.remove(ckpt_path)
                else:
                    break

    @master_only
    def after_train_iter(self, runner):
        if self.by_epoch or not self.every_n_iters(runner, self.interval):
            return

        runner.logger.info(
            f'Saving checkpoint at {runner.iter + 1} iterations')
        if not self.out_dir:
            self.out_dir = runner.work_dir
        runner.save_checkpoint(
            self.out_dir, save_optimizer=self.save_optimizer, **self.args)

        # remove other checkpoints
        if self.max_keep_ckpts > 0:
            filename_tmpl = self.args.get('filename_tmpl', 'iter_{}.pth')
            current_iter = runner.iter + 1
            for _iter in range(
                    current_iter - self.max_keep_ckpts * self.interval, 0,
                    -self.interval):
                ckpt_path = os.path.join(self.out_dir,
                                         filename_tmpl.format(_iter))
                if os.path.exists(ckpt_path):
                    os.remove(ckpt_path)
                else:
                    break
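For reference, typical checkpoint configs that end up constructing this hook (values are illustrative; a runner's ``register_checkpoint_hook`` typically expands such a dict into ``CheckpointHook(**checkpoint_config)``):

# epoch-based training: save every 5 epochs, keep only the 3 newest checkpoints
checkpoint_config = dict(interval=5, by_epoch=True, max_keep_ckpts=3)

# iteration-based training: save every 10000 iterations into an explicit directory
checkpoint_config = dict(
    interval=10000, by_epoch=False, max_keep_ckpts=3, out_dir='./work_dirs/ckpts')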
@@ -71,8 +71,22 @@ class PaviLoggerHook(LoggerHook):
        for tag, val in runner.log_buffer.output.items():
            if tag not in ['time', 'data_time'] and is_scalar(val):
                tags[tag] = val
        tags['learning_rate'] = runner.current_lr()[0]
        tags['momentum'] = runner.current_momentum()[0]
        # add learning rate
        lrs = runner.current_lr()
        if isinstance(lrs, dict):
            for name, value in lrs.items():
                tags[f'learning_rate/{name}'] = value[0]
        else:
            tags['learning_rate'] = lrs[0]

        # add momentum
        momentums = runner.current_momentum()
        if isinstance(momentums, dict):
            for name, value in momentums.items():
                tags[f'momentum/{name}'] = value[0]
        else:
            tags['momentum'] = momentums[0]

        if tags:
            self.writer.add_scalars(runner.mode, tags, runner.iter)
@@ -53,10 +53,22 @@ class TensorboardLoggerHook(LoggerHook):
            else:
                self.writer.add_scalar(tag, runner.log_buffer.output[var],
                                       runner.iter)
        self.writer.add_scalar('learning_rate',
                               runner.current_lr()[0], runner.iter)
        self.writer.add_scalar('momentum',
                               runner.current_momentum()[0], runner.iter)
        # add learning rate
        lrs = runner.current_lr()
        if isinstance(lrs, dict):
            for name, value in lrs.items():
                self.writer.add_scalar(f'learning_rate/{name}', value[0],
                                       runner.iter)
        else:
            self.writer.add_scalar('learning_rate', lrs[0], runner.iter)
        # add momentum
        momentums = runner.current_momentum()
        if isinstance(momentums, dict):
            for name, value in momentums.items():
                self.writer.add_scalar(f'momentum/{name}', value[0],
                                       runner.iter)
        else:
            self.writer.add_scalar('momentum', momentums[0], runner.iter)

    @master_only
    def after_run(self, runner):
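The practical effect (an illustration, with optimizer names taken from the ``ToyGAN`` sketch) is that per-optimizer scalars are logged under namespaced tags:

# single optimizer    -> 'learning_rate', 'momentum'
# dict of optimizers  -> 'learning_rate/generator', 'learning_rate/discriminator',
#                        'momentum/generator', 'momentum/discriminator'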
@@ -19,6 +19,7 @@ class TextLoggerHook(LoggerHook):
    saved in json file.

    Args:
        by_epoch (bool): Whether EpochBasedRunner is used.
        interval (int): Logging interval (every k iterations).
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`.
@@ -29,11 +30,13 @@
    """

    def __init__(self,
                 by_epoch=True,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 interval_exp_name=1000):
        super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag)
        self.by_epoch = by_epoch
        self.time_sec_tot = 0
        self.interval_exp_name = interval_exp_name
@@ -55,17 +58,32 @@ class TextLoggerHook(LoggerHook):
        return mem_mb.item()

    def _log_info(self, log_dict, runner):
        # print exp name for users to distinguish experiments
        # at every ``interval_exp_name`` iterations and the end of each epoch
        if runner.meta is not None and 'exp_name' in runner.meta:
            if (self.every_n_inner_iters(
                    runner,
                    self.interval_exp_name)) or self.end_of_epoch(runner):
                exp_info = f"Exp name: {runner.meta['exp_name']}"
            if (self.every_n_inner_iters(runner, self.interval_exp_name)) or (
                    self.by_epoch and self.end_of_epoch(runner)):
                exp_info = f'Exp name: {runner.meta["exp_name"]}'
                runner.logger.info(exp_info)

        if runner.mode == 'train':
            log_str = f'Epoch [{log_dict["epoch"]}]' \
                      f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t' \
                      f'lr: {log_dict["lr"]:.5f}, '
            if isinstance(log_dict['lr'], dict):
                lr_str = []
                for k, val in log_dict['lr'].items():
                    lr_str.append(f'lr_{k}: {val:.3e}')
                lr_str = ' '.join(lr_str)
            else:
                lr_str = f'lr: {log_dict["lr"]:.3e}'

            # by epoch: Epoch [4][100/1000]
            # by iter: Iter [100/100000]
            if self.by_epoch:
                log_str = f'Epoch [{log_dict["epoch"]}]' \
                          f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t'
            else:
                log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t'
            log_str += f'{lr_str}, '

            if 'time' in log_dict.keys():
                self.time_sec_tot += (log_dict['time'] * self.interval)
                time_sec_avg = self.time_sec_tot / (
@@ -79,8 +97,12 @@ class TextLoggerHook(LoggerHook):
            if torch.cuda.is_available():
                log_str += f'memory: {log_dict["memory"]}, '
        else:
            log_str = f'Epoch({log_dict["mode"]}) ' \
                      f'[{log_dict["epoch"] - 1}][{log_dict["iter"]}]\t'
            if self.by_epoch:
                log_str = f'Epoch({log_dict["mode"]}) ' \
                          f'[{log_dict["epoch"] - 1}][{log_dict["iter"]}]\t'
            else:
                log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'

        log_items = []
        for name, val in log_dict.items():
            # TODO: resolve this hack
@@ -94,6 +116,7 @@ class TextLoggerHook(LoggerHook):
                val = f'{val:.4f}'
            log_items.append(f'{name}: {val}')
        log_str += ', '.join(log_items)

        runner.logger.info(log_str)

    def _dump_log(self, log_dict, runner):
@@ -123,7 +146,16 @@ class TextLoggerHook(LoggerHook):
        log_dict['epoch'] = runner.epoch + 1
        log_dict['iter'] = runner.inner_iter + 1
        # only record lr of the first param group
        log_dict['lr'] = runner.current_lr()[0]
        cur_lr = runner.current_lr()
        if isinstance(cur_lr, list):
            log_dict['lr'] = cur_lr[0]
        else:
            assert isinstance(cur_lr, dict)
            log_dict['lr'] = {}
            for k, lr_ in cur_lr.items():
                assert isinstance(lr_, list)
                log_dict['lr'].update({k: lr_[0]})

        if mode == 'train':
            log_dict['time'] = runner.log_buffer.output['time']
            log_dict['data_time'] = runner.log_buffer.output['data_time']
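Putting the pieces above together, the console line now differs by mode roughly as follows (values illustrative):

# by_epoch=True : Epoch [4][100/1000]   lr: 1.000e-02, time: ..., loss: ...
# by_epoch=False: Iter [100/100000]     lr: 1.000e-02, time: ..., loss: ...
# with a dict of optimizers the lr field expands per key, e.g.
#                 lr_generator: 1.000e-04 lr_discriminator: 4.000e-04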
@@ -55,14 +55,31 @@ class LrUpdaterHook(Hook):
        self.regular_lr = []  # expected lr if no warming up is performed

    def _set_lr(self, runner, lr_groups):
        for param_group, lr in zip(runner.optimizer.param_groups, lr_groups):
            param_group['lr'] = lr
        if isinstance(runner.optimizer, dict):
            for k, optim in runner.optimizer.items():
                for param_group, lr in zip(optim.param_groups, lr_groups[k]):
                    param_group['lr'] = lr
        else:
            for param_group, lr in zip(runner.optimizer.param_groups,
                                       lr_groups):
                param_group['lr'] = lr

    def get_lr(self, runner, base_lr):
        raise NotImplementedError

    def get_regular_lr(self, runner):
        return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]
        if isinstance(runner.optimizer, dict):
            lr_groups = {}
            for k in runner.optimizer.keys():
                _lr_group = [
                    self.get_lr(runner, _base_lr)
                    for _base_lr in self.base_lr[k]
                ]
                lr_groups.update({k: _lr_group})

            return lr_groups
        else:
            return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]

    def get_warmup_lr(self, cur_iters):
        if self.warmup == 'constant':
@@ -78,11 +95,21 @@
    def before_run(self, runner):
        # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
        # it will be set according to the optimizer params
        for group in runner.optimizer.param_groups:
            group.setdefault('initial_lr', group['lr'])
        self.base_lr = [
            group['initial_lr'] for group in runner.optimizer.param_groups
        ]
        if isinstance(runner.optimizer, dict):
            self.base_lr = {}
            for k, optim in runner.optimizer.items():
                for group in optim.param_groups:
                    group.setdefault('initial_lr', group['lr'])
                _base_lr = [
                    group['initial_lr'] for group in optim.param_groups
                ]
                self.base_lr.update({k: _base_lr})
        else:
            for group in runner.optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
            self.base_lr = [
                group['initial_lr'] for group in runner.optimizer.param_groups
            ]

    def before_train_epoch(self, runner):
        if not self.by_epoch:
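A short sketch of the bookkeeping this enables (optimizer names again from the ``ToyGAN`` example; the ``lr_config`` shown is an illustrative config that a runner's ``register_lr_hook`` would turn into a step LR updater hook):

# with runner.optimizer == {'generator': ..., 'discriminator': ...}
#   self.base_lr            -> {'generator': [1e-4], 'discriminator': [4e-4]}
#   self.get_regular_lr(..) -> {'generator': [...],  'discriminator': [...]}
#   self._set_lr(..)        -> writes lr_groups[k] back into optimizer k's param groups
lr_config = dict(policy='step', step=[60000, 80000], by_epoch=False)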
@@ -213,6 +240,7 @@ class CosineAnealingLrUpdaterHook(LrUpdaterHook):
        else:
            progress = runner.iter
            max_progress = runner.max_iters

        if self.min_lr_ratio is not None:
            target_lr = base_lr * self.min_lr_ratio
        else:
@@ -224,7 +252,7 @@
class CyclicLrUpdaterHook(LrUpdaterHook):
    """Cyclic LR Scheduler.

    Implemet the cyclical learning rate policy (CLR) described in
    Implement the cyclical learning rate policy (CLR) described in
    https://arxiv.org/pdf/1506.01186.pdf

    Different from the original paper, we use cosine annealing rather than
mmcv/runner/iter_based_runner.py (new file, 223 lines)
@@ -0,0 +1,223 @@
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import time

import torch
from torch.optim import Optimizer

import mmcv
from .base_runner import BaseRunner
from .checkpoint import save_checkpoint
from .hooks import IterTimerHook
from .utils import get_host_info


class IterLoader:

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 0

    @property
    def epoch(self):
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            self._epoch += 1
            if hasattr(self._dataloader.sampler, 'set_epoch'):
                self._dataloader.sampler.set_epoch(self._epoch)
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        return data

    def __len__(self):
        return len(self._dataloader)

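A minimal sketch of ``IterLoader`` on its own (the toy dataset is illustrative): it keeps yielding batches past the end of the underlying ``DataLoader`` by re-creating the iterator and counting how many passes have elapsed.

import torch
from torch.utils.data import DataLoader, TensorDataset

from mmcv.runner import IterLoader

loader = DataLoader(TensorDataset(torch.arange(10.).unsqueeze(1)), batch_size=4)
iter_loader = IterLoader(loader)
for _ in range(6):              # more steps than one pass (3 batches) over the data
    batch = next(iter_loader)   # transparently restarts the DataLoader when exhausted
print(len(iter_loader))         # 3, the length of one pass
print(iter_loader.epoch)        # 1, the loader has been re-created once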
class IterBasedRunner(BaseRunner):
    """Iteration-based Runner.

    This runner trains models iteration by iteration.
    """

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._epoch = data_loader.epoch
        self.call_hook('before_train_iter')
        data_batch = next(data_loader)
        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('model.train_step() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_train_iter')
        self._inner_iter += 1
        self._iter += 1

    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self._inner_iter = 0
        self.data_loader = data_loader
        self.call_hook('before_val_iter')
        data_batch = next(data_loader)
        outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('model.val_step() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_val_iter')
        self._inner_iter += 1
    def run(self, data_loaders, workflow, max_iters, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, iters) to specify the
                running order and iterations. E.g., [('train', 10000),
                ('val', 1000)] means running 10000 iterations for training and
                1000 iterations for validation, iteratively.
            max_iters (int): Total training iterations.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        self._max_iters = max_iters
        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d iters', workflow, max_iters)
        self.call_hook('before_run')

        iter_loaders = [IterLoader(x) for x in data_loaders]

        self.call_hook('before_epoch')

        while self.iter < max_iters:
            for i, flow in enumerate(workflow):
                mode, iters = flow
                if not isinstance(mode, str) or not hasattr(self, mode):
                    raise ValueError(
                        'runner has no method named "{}" to run a workflow'.
                        format(mode))
                iter_runner = getattr(self, mode)
                for _ in range(iters):
                    if mode == 'train' and self.iter >= max_iters:
                        return
                    iter_runner(iter_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_epoch')
        self.call_hook('after_run')
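A usage sketch of the whole loop (not from this diff): ``train_loader``/``val_loader`` are assumed ``DataLoader`` instances, the model is any module with ``train_step()`` (and ``val_step()`` if a 'val' phase is used), and hook registration is shown after ``register_training_hooks`` below.

import logging

from mmcv.runner import IterBasedRunner

runner = IterBasedRunner(
    model=model,                      # e.g. the ToyModel/ToyGAN sketches above
    optimizer=optimizer,
    work_dir='./work_dirs/toy',       # illustrative path
    logger=logging.getLogger(__name__))

# 10000 train iterations, then 1000 val iterations, repeated until 90000
# training iterations have been run in total
runner.run([train_loader, val_loader], [('train', 10000), ('val', 1000)],
           max_iters=90000)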
    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """Resume model from checkpoint.

        Args:
            checkpoint (str): Checkpoint to resume from.
            resume_optimizer (bool, optional): Whether to resume the
                optimizer(s) if the checkpoint file includes optimizer(s).
                Defaults to True.
            map_location (str, optional): Same as :func:`torch.load`.
                Defaults to 'default'.
        """
        if map_location == 'default':
            device_id = torch.cuda.current_device()
            checkpoint = self.load_checkpoint(
                checkpoint,
                map_location=lambda storage, loc: storage.cuda(device_id))
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        self._inner_iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])

        self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='iter_{}.pth',
                        meta=None,
                        save_optimizer=True,
                        create_symlink=True):
        """Save checkpoint to file.

        Args:
            out_dir (str): Directory to save checkpoint files.
            filename_tmpl (str, optional): Checkpoint file template.
                Defaults to 'iter_{}.pth'.
            meta (dict, optional): Metadata to be saved in checkpoint.
                Defaults to None.
            save_optimizer (bool, optional): Whether to save the optimizer.
                Defaults to True.
            create_symlink (bool, optional): Whether to create a symlink to
                the latest checkpoint file. Defaults to True.
        """
        if meta is None:
            meta = dict(iter=self.iter + 1, epoch=self.epoch + 1)
        elif isinstance(meta, dict):
            meta.update(iter=self.iter + 1, epoch=self.epoch + 1)
        else:
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        meta.update(self.meta)

        filename = filename_tmpl.format(self.iter + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # in some environments, `os.symlink` is not supported, you may need to
        # set `create_symlink` to False
        if create_symlink:
            mmcv.symlink(filename, osp.join(out_dir, 'latest.pth'))
    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None):
        """Register default hooks for iter-based training.

        Default hooks include:

        - LrUpdaterHook
        - MomentumUpdaterHook
        - OptimizerStepperHook
        - CheckpointSaverHook
        - IterTimerHook
        - LoggerHook(s)
        """
        if checkpoint_config is not None:
            checkpoint_config.setdefault('by_epoch', False)
        if lr_config is not None:
            lr_config.setdefault('by_epoch', False)
        self.register_lr_hook(lr_config)
        self.register_momentum_hook(momentum_config)
        self.register_optimizer_hook(optimizer_config)
        self.register_checkpoint_hook(checkpoint_config)
        self.register_hook(IterTimerHook())
        self.register_logger_hooks(log_config)
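Continuing the sketch, a typical set of hook configs passed to this method (values illustrative); ``by_epoch`` is defaulted to ``False`` here for the lr and checkpoint configs, while the text logger is told explicitly:

runner.register_training_hooks(
    lr_config=dict(policy='step', step=[60000, 80000]),
    optimizer_config=dict(grad_clip=None),
    checkpoint_config=dict(interval=10000, max_keep_ckpts=3),
    log_config=dict(
        interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)]))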
@@ -39,22 +39,48 @@ def test_epoch_based_runner():
        def batch_processor():
            pass

        _ = EpochBasedRunner(model, batch_processor)
        _ = EpochBasedRunner(
            model, batch_processor, logger=logging.getLogger())

    with pytest.raises(TypeError):
        # batch_processor must be callable
        model = OldStyleModel()
        _ = EpochBasedRunner(model, batch_processor=0)
        _ = EpochBasedRunner(
            model, batch_processor=0, logger=logging.getLogger())

    with pytest.raises(TypeError):
        # optimizer must be an optimizer or a dict of optimizers
        model = Model()
        optimizer = 'NotAOptimizer'
        _ = EpochBasedRunner(
            model, optimizer=optimizer, logger=logging.getLogger())

    with pytest.raises(TypeError):
        # optimizer must be an optimizer or a dict of optimizers
        model = Model()
        optimizers = dict(optim1=torch.optim.Adam(), optim2='NotAOptimizer')
        _ = EpochBasedRunner(
            model, optimizer=optimizers, logger=logging.getLogger())

    with pytest.raises(TypeError):
        # logger must be a logging.Logger
        model = Model()
        _ = EpochBasedRunner(model, logger=None)

    with pytest.raises(TypeError):
        # meta must be a dict or None
        model = Model()
        _ = EpochBasedRunner(model, logger=logging.getLogger(), meta=['list'])

    with pytest.raises(AssertionError):
        # model must implement the method train_step()
        model = OldStyleModel()
        _ = EpochBasedRunner(model)
        _ = EpochBasedRunner(model, logger=logging.getLogger())

    with pytest.raises(TypeError):
        # work_dir must be a str or None
        model = Model()
        _ = EpochBasedRunner(model, work_dir=1)
        _ = EpochBasedRunner(model, work_dir=1, logger=logging.getLogger())

    with pytest.raises(RuntimeError):
        # batch_processor and train_step() cannot be both set
@@ -63,7 +89,8 @@ def test_epoch_based_runner():
            pass

        model = Model()
        _ = EpochBasedRunner(model, batch_processor)
        _ = EpochBasedRunner(
            model, batch_processor, logger=logging.getLogger())

    # test work_dir
    model = Model()
@@ -71,9 +98,9 @@ def test_epoch_based_runner():
    dir_name = ''.join(
        [random.choice(string.ascii_letters) for _ in range(10)])
    work_dir = osp.join(temp_root, dir_name)
    _ = EpochBasedRunner(model, work_dir=work_dir)
    _ = EpochBasedRunner(model, work_dir=work_dir, logger=logging.getLogger())
    assert osp.isdir(work_dir)
    _ = EpochBasedRunner(model, work_dir=work_dir)
    _ = EpochBasedRunner(model, work_dir=work_dir, logger=logging.getLogger())
    assert osp.isdir(work_dir)
    os.removedirs(work_dir)

@@ -84,7 +111,7 @@ def test_runner_with_parallel():
        pass

    model = MMDataParallel(OldStyleModel())
    _ = EpochBasedRunner(model, batch_processor)
    _ = EpochBasedRunner(model, batch_processor, logger=logging.getLogger())

    with pytest.raises(RuntimeError):
        # batch_processor and train_step() cannot be both set
@@ -93,7 +120,8 @@ def test_runner_with_parallel():
            pass

        model = MMDataParallel(Model())
        _ = EpochBasedRunner(model, batch_processor)
        _ = EpochBasedRunner(
            model, batch_processor, logger=logging.getLogger())


def test_save_checkpoint():