diff --git a/mmcv/runner/__init__.py b/mmcv/runner/__init__.py
index 8c80d32a0..809cb1fd0 100644
--- a/mmcv/runner/__init__.py
+++ b/mmcv/runner/__init__.py
@@ -6,8 +6,8 @@ from .dist_utils import get_dist_info, init_dist, master_only
 from .epoch_based_runner import EpochBasedRunner, Runner
 from .fp16_utils import auto_fp16, force_fp32
 from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistSamplerSeedHook,
-                    Fp16OptimizerHook, Hook, IterTimerHook, LoggerHook,
-                    LrUpdaterHook, MlflowLoggerHook, OptimizerHook,
+                    EMAHook, Fp16OptimizerHook, Hook, IterTimerHook,
+                    LoggerHook, LrUpdaterHook, MlflowLoggerHook, OptimizerHook,
                     PaviLoggerHook, SyncBuffersHook, TensorboardLoggerHook,
                     TextLoggerHook, WandbLoggerHook)
 from .iter_based_runner import IterBasedRunner, IterLoader
@@ -30,5 +30,5 @@ __all__ = [
     'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
     'build_optimizer', 'build_optimizer_constructor', 'IterLoader',
     'set_random_seed', 'auto_fp16', 'force_fp32', 'Fp16OptimizerHook',
-    'SyncBuffersHook'
+    'SyncBuffersHook', 'EMAHook'
 ]
diff --git a/mmcv/runner/base_runner.py b/mmcv/runner/base_runner.py
index 80db664eb..912d84bd2 100644
--- a/mmcv/runner/base_runner.py
+++ b/mmcv/runner/base_runner.py
@@ -307,10 +307,13 @@ class BaseRunner(metaclass=ABCMeta):
                resume_optimizer=True,
                map_location='default'):
         if map_location == 'default':
-            device_id = torch.cuda.current_device()
-            checkpoint = self.load_checkpoint(
-                checkpoint,
-                map_location=lambda storage, loc: storage.cuda(device_id))
+            if torch.cuda.is_available():
+                device_id = torch.cuda.current_device()
+                checkpoint = self.load_checkpoint(
+                    checkpoint,
+                    map_location=lambda storage, loc: storage.cuda(device_id))
+            else:
+                checkpoint = self.load_checkpoint(checkpoint)
         else:
             checkpoint = self.load_checkpoint(
                 checkpoint, map_location=map_location)
diff --git a/mmcv/runner/hooks/__init__.py b/mmcv/runner/hooks/__init__.py
index 66e390a05..c2d5a9514 100644
--- a/mmcv/runner/hooks/__init__.py
+++ b/mmcv/runner/hooks/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) Open-MMLab. All rights reserved.
 from .checkpoint import CheckpointHook
 from .closure import ClosureHook
+from .ema import EMAHook
 from .hook import HOOKS, Hook
 from .iter_timer import IterTimerHook
 from .logger import (LoggerHook, MlflowLoggerHook, PaviLoggerHook,
@@ -17,5 +18,5 @@ __all__ = [
     'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
     'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
     'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
-    'WandbLoggerHook', 'MomentumUpdaterHook', 'SyncBuffersHook'
+    'WandbLoggerHook', 'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook'
 ]
diff --git a/mmcv/runner/hooks/ema.py b/mmcv/runner/hooks/ema.py
new file mode 100644
index 000000000..d5fe738dc
--- /dev/null
+++ b/mmcv/runner/hooks/ema.py
@@ -0,0 +1,88 @@
+from ...parallel import is_module_wrapper
+from ..hooks.hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class EMAHook(Hook):
+    r"""Exponential Moving Average Hook.
+
+    Apply an exponential moving average to all model parameters during
+    training. Every parameter keeps an ema backup that is updated by the
+    formula below. EMAHook takes priority over EvalHook and CheckpointHook.
+
+    .. math::
+
+        X_{\text{ema},t+1} = (1 - \text{momentum}) \times
+        X_{\text{ema},t} + \text{momentum} \times X_t
+
+    Args:
+        momentum (float): The momentum used for updating ema parameters.
+            Defaults to 0.0002.
+        interval (int): Update ema parameters every ``interval`` iterations.
+            Defaults to 1.
+        warm_up (int): During the first warm_up steps, a smaller momentum is
+            used to update ema parameters more slowly. Defaults to 100.
+        resume_from (str): The checkpoint path. Defaults to None.
+    """
+
+    def __init__(self,
+                 momentum=0.0002,
+                 interval=1,
+                 warm_up=100,
+                 resume_from=None):
+        assert isinstance(interval, int) and interval > 0
+        self.warm_up = warm_up
+        self.interval = interval
+        assert momentum > 0 and momentum < 1
+        self.momentum = momentum**interval
+        self.checkpoint = resume_from
+
+    def before_run(self, runner):
+        """Register ema parameters as named buffers of the model.
+
+        This makes resuming a model together with its ema parameters easier.
+        """
+        model = runner.model
+        if is_module_wrapper(model):
+            model = model.module
+        self.param_ema_buffer = {}
+        self.model_parameters = dict(model.named_parameters(recurse=True))
+        for name, value in self.model_parameters.items():
+            # "." is not allowed in a module's buffer name
+            buffer_name = f"ema_{name.replace('.', '_')}"
+            self.param_ema_buffer[name] = buffer_name
+            model.register_buffer(buffer_name, value.data.clone())
+        self.model_buffers = dict(model.named_buffers(recurse=True))
+        if self.checkpoint is not None:
+            runner.resume(self.checkpoint)
+
+    def after_train_iter(self, runner):
+        """Update ema parameters every self.interval iterations."""
+        curr_step = runner.iter
+        # Warm up the momentum to account for instability at the beginning
+        momentum = min(self.momentum,
+                       (1 + curr_step) / (self.warm_up + curr_step))
+        if curr_step % self.interval != 0:
+            return
+        for name, parameter in self.model_parameters.items():
+            buffer_name = self.param_ema_buffer[name]
+            buffer_parameter = self.model_buffers[buffer_name]
+            buffer_parameter.mul_(1 - momentum).add_(momentum, parameter.data)
+
+    def after_train_epoch(self, runner):
+        """Swap ema parameters into the model so that EvalHook evaluates the
+        ema weights."""
+        self._swap_ema_parameters()
+
+    def before_train_epoch(self, runner):
+        """Swap the original parameters back into the model after the
+        previous epoch's EvalHook."""
+        self._swap_ema_parameters()
+
+    def _swap_ema_parameters(self):
+        """Swap the model parameters with the ema buffers."""
+        for name, value in self.model_parameters.items():
+            temp = value.data.clone()
+            ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
+            value.data.copy_(ema_buffer.data)
+            ema_buffer.data.copy_(temp)
diff --git a/tests/test_runner/test_hooks.py b/tests/test_runner/test_hooks.py
index bd72e720c..000d86ff6 100644
--- a/tests/test_runner/test_hooks.py
+++ b/tests/test_runner/test_hooks.py
@@ -14,13 +14,85 @@ from unittest.mock import MagicMock, call
 
 import pytest
 import torch
 import torch.nn as nn
+from torch.nn.init import constant_
 from torch.utils.data import DataLoader
 
-from mmcv.runner import (EpochBasedRunner, IterTimerHook, MlflowLoggerHook,
-                         PaviLoggerHook, WandbLoggerHook)
+from mmcv.runner import (CheckpointHook, EMAHook, EpochBasedRunner,
+                         IterTimerHook, MlflowLoggerHook, PaviLoggerHook,
+                         WandbLoggerHook)
 from mmcv.runner.hooks.lr_updater import CosineRestartLrUpdaterHook
 
 
+def test_ema_hook():
+    """xdoctest -m tests/test_hooks.py test_ema_hook."""
+
+    class DemoModel(nn.Module):
+
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv2d(
+                in_channels=1,
+                out_channels=2,
+                kernel_size=1,
+                padding=1,
+                bias=True)
+            self._init_weight()
+
+        def _init_weight(self):
+            constant_(self.conv.weight, 0)
+            constant_(self.conv.bias, 0)
+
+        def forward(self, x):
+            return self.conv(x).sum()
+
+        def train_step(self, x, optimizer, **kwargs):
+            return dict(loss=self(x))
+
+        def val_step(self, x, optimizer, **kwargs):
+            return dict(loss=self(x))
+
+    loader = DataLoader(torch.ones((1, 1, 1, 1)))
+    runner = _build_demo_runner()
+    demo_model = DemoModel()
+    runner.model = demo_model
+    emahook = EMAHook(momentum=0.1, interval=2, warm_up=100, resume_from=None)
+    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
+    runner.register_hook(emahook, priority='HIGHEST')
+    runner.register_hook(checkpointhook)
+    runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
+    checkpoint = torch.load(f'{runner.work_dir}/epoch_1.pth')
+    contain_ema_buffer = False
+    for name, value in checkpoint['state_dict'].items():
+        if 'ema' in name:
+            contain_ema_buffer = True
+            assert value.sum() == 0
+            value.fill_(1)
+        else:
+            assert value.sum() == 0
+    assert contain_ema_buffer
+    torch.save(checkpoint, f'{runner.work_dir}/epoch_1.pth')
+    work_dir = runner.work_dir
+    resume_ema_hook = EMAHook(
+        momentum=0.5, warm_up=0, resume_from=f'{work_dir}/epoch_1.pth')
+    runner = _build_demo_runner()
+    runner.model = demo_model
+    runner.register_hook(resume_ema_hook, priority='HIGHEST')
+    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
+    runner.register_hook(checkpointhook)
+    runner.run([loader, loader], [('train', 1), ('val', 1)], 2)
+    checkpoint = torch.load(f'{runner.work_dir}/epoch_2.pth')
+    contain_ema_buffer = False
+    for name, value in checkpoint['state_dict'].items():
+        if 'ema' in name:
+            contain_ema_buffer = True
+            assert value.sum() == 2
+        else:
+            assert value.sum() == 1
+    assert contain_ema_buffer
+    shutil.rmtree(runner.work_dir)
+    shutil.rmtree(work_dir)
+
+
 def test_pavi_hook():
     sys.modules['pavi'] = MagicMock()
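
Usage note (a minimal sketch, not part of the diff): how the new EMAHook could be attached to a runner, mirroring the HIGHEST-priority registration used in the test above so that the EMA swap runs before checkpointing and evaluation. The model, optimizer, work_dir and logger below are placeholders.

import logging

import torch.nn as nn
import torch.optim as optim

from mmcv.runner import EMAHook, EpochBasedRunner

# placeholder model and optimizer; a real model also needs train_step/val_step
model = nn.Linear(2, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01)

runner = EpochBasedRunner(
    model=model,
    optimizer=optimizer,
    work_dir='./work_dir',  # hypothetical output directory
    logger=logging.getLogger(),
    max_epochs=12)
# keep an EMA copy of every parameter, refreshed every 2 training iterations
runner.register_hook(
    EMAHook(momentum=0.0002, interval=2, warm_up=100), priority='HIGHEST')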
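
For reference, a standalone sketch of the update rule that after_train_iter applies: the ema buffer is updated in place as x_ema <- (1 - m) * x_ema + m * x, where m = min(momentum, (1 + t) / (warm_up + t)) damps a large momentum during the first warm_up iterations (inside the hook, momentum is additionally raised to the power interval and the update only runs every interval iterations). The numbers below are made up for illustration.

import torch

momentum, warm_up = 0.5, 10
x_ema = torch.zeros(3)  # ema buffer, initialised from a copy of the parameter
for t in range(5):
    x = torch.ones(3)   # stand-in for the parameter after an optimizer step
    # warmed-up momentum, mirroring EMAHook.after_train_iter
    m = min(momentum, (1 + t) / (warm_up + t))
    x_ema.mul_(1 - m).add_(x, alpha=m)  # in-place EMA update
    print(f'step {t}: momentum={m:.3f}, ema={x_ema[0].item():.4f}')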