* add ema hook

* add ema hook resume

* add ema hook test

* fix typo

* fix according to comment

* delete logger

* fix according to comment

* fix unittest

* fix typo

* fix according to comment

* change to resume_from

* typo

* fix isort
shilong 2020-07-30 22:06:19 +08:00 committed by GitHub
parent 4676031c8b
commit 1830347f8b
5 changed files with 174 additions and 10 deletions

mmcv/runner/__init__.py

@@ -6,8 +6,8 @@ from .dist_utils import get_dist_info, init_dist, master_only
 from .epoch_based_runner import EpochBasedRunner, Runner
 from .fp16_utils import auto_fp16, force_fp32
 from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistSamplerSeedHook,
-                    Fp16OptimizerHook, Hook, IterTimerHook, LoggerHook,
-                    LrUpdaterHook, MlflowLoggerHook, OptimizerHook,
+                    EMAHook, Fp16OptimizerHook, Hook, IterTimerHook,
+                    LoggerHook, LrUpdaterHook, MlflowLoggerHook, OptimizerHook,
                     PaviLoggerHook, SyncBuffersHook, TensorboardLoggerHook,
                     TextLoggerHook, WandbLoggerHook)
 from .iter_based_runner import IterBasedRunner, IterLoader
@@ -30,5 +30,5 @@ __all__ = [
     'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
     'build_optimizer', 'build_optimizer_constructor', 'IterLoader',
     'set_random_seed', 'auto_fp16', 'force_fp32', 'Fp16OptimizerHook',
-    'SyncBuffersHook'
+    'SyncBuffersHook', 'EMAHook'
 ]
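With EMAHook now exported from mmcv.runner, it can be attached to a runner like any other hook. A minimal usage sketch, not part of this commit, assuming an already-built EpochBasedRunner named runner:

    from mmcv.runner import EMAHook

    # Register with the highest priority so the EMA buffers are created and
    # updated before checkpoint/eval hooks run (mirrors the unit test below).
    ema_hook = EMAHook(momentum=0.0002, interval=1, warm_up=100)
    runner.register_hook(ema_hook, priority='HIGHEST')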

mmcv/runner/base_runner.py

@@ -307,10 +307,13 @@ class BaseRunner(metaclass=ABCMeta):
                resume_optimizer=True,
                map_location='default'):
         if map_location == 'default':
-            device_id = torch.cuda.current_device()
-            checkpoint = self.load_checkpoint(
-                checkpoint,
-                map_location=lambda storage, loc: storage.cuda(device_id))
+            if torch.cuda.is_available():
+                device_id = torch.cuda.current_device()
+                checkpoint = self.load_checkpoint(
+                    checkpoint,
+                    map_location=lambda storage, loc: storage.cuda(device_id))
+            else:
+                checkpoint = self.load_checkpoint(checkpoint)
         else:
             checkpoint = self.load_checkpoint(
                 checkpoint, map_location=map_location)
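This hunk lets resume() with the default map_location work on CPU-only machines instead of failing on torch.cuda.current_device(). A standalone sketch of the same fallback, using a hypothetical helper name (not MMCV API):

    import torch

    def load_for_resume(path):
        # Map checkpoint storages onto the current GPU when one is available;
        # otherwise keep them on CPU.
        if torch.cuda.is_available():
            device_id = torch.cuda.current_device()
            return torch.load(
                path, map_location=lambda storage, loc: storage.cuda(device_id))
        return torch.load(path, map_location='cpu')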

mmcv/runner/hooks/__init__.py

@@ -1,6 +1,7 @@
 # Copyright (c) Open-MMLab. All rights reserved.
 from .checkpoint import CheckpointHook
 from .closure import ClosureHook
+from .ema import EMAHook
 from .hook import HOOKS, Hook
 from .iter_timer import IterTimerHook
 from .logger import (LoggerHook, MlflowLoggerHook, PaviLoggerHook,
@@ -17,5 +18,5 @@ __all__ = [
     'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
     'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
     'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
-    'WandbLoggerHook', 'MomentumUpdaterHook', 'SyncBuffersHook'
+    'WandbLoggerHook', 'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook'
 ]
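Because the hook is registered in the HOOKS registry (see ema.py below), it can also be built from a plain config dict. A hedged sketch using mmcv.build_from_cfg; the momentum value shown is just the documented default:

    from mmcv.utils import build_from_cfg
    from mmcv.runner.hooks import HOOKS

    # 'type' selects the registered class; equivalent to EMAHook(momentum=0.0002).
    ema_hook = build_from_cfg(dict(type='EMAHook', momentum=0.0002), HOOKS)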

mmcv/runner/hooks/ema.py (new file)

@@ -0,0 +1,88 @@
+from ...parallel import is_module_wrapper
+from ..hooks.hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class EMAHook(Hook):
+    r"""Exponential Moving Average Hook.
+
+    Use Exponential Moving Average on all parameters of model in training
+    process. All parameters have a ema backup, which update by the formula
+    as below. EMAHook takes priority over EvalHook and CheckpointSaverHook.
+
+        .. math::
+
+            \text{Xema_{t+1}} = (1 - \text{momentum}) \times
+            \text{Xema_{t}} + \text{momentum} \times X_t
+
+    Args:
+        momentum (float): The momentum used for updating ema parameter.
+            Defaults to 0.0002.
+        interval (int): Update ema parameter every interval iteration.
+            Defaults to 1.
+        warm_up (int): During first warm_up steps, we may use smaller momentum
+            to update ema parameters more slowly. Defaults to 100.
+        resume_from (str): The checkpoint path. Defaults to None.
+    """
+
+    def __init__(self,
+                 momentum=0.0002,
+                 interval=1,
+                 warm_up=100,
+                 resume_from=None):
+        assert isinstance(interval, int) and interval > 0
+        self.warm_up = warm_up
+        self.interval = interval
+        assert momentum > 0 and momentum < 1
+        self.momentum = momentum**interval
+        self.checkpoint = resume_from
+
+    def before_run(self, runner):
+        """To resume model with it's ema parameters more friendly.
+
+        Register ema parameter as ``named_buffer`` to model
+        """
+        model = runner.model
+        if is_module_wrapper(model):
+            model = model.module
+        self.param_ema_buffer = {}
+        self.model_parameters = dict(model.named_parameters(recurse=True))
+        for name, value in self.model_parameters.items():
+            # "." is not allowed in module's buffer name
+            buffer_name = f"ema_{name.replace('.', '_')}"
+            self.param_ema_buffer[name] = buffer_name
+            model.register_buffer(buffer_name, value.data.clone())
+        self.model_buffers = dict(model.named_buffers(recurse=True))
+        if self.checkpoint is not None:
+            runner.resume(self.checkpoint)
+
+    def after_train_iter(self, runner):
+        """Update ema parameter every self.interval iterations."""
+        curr_step = runner.iter
+        # We warm up the momentum considering the instability at beginning
+        momentum = min(self.momentum,
+                       (1 + curr_step) / (self.warm_up + curr_step))
+        if curr_step % self.interval != 0:
+            return
+        for name, parameter in self.model_parameters.items():
+            buffer_name = self.param_ema_buffer[name]
+            buffer_parameter = self.model_buffers[buffer_name]
+            buffer_parameter.mul_(1 - momentum).add_(momentum, parameter.data)
+
+    def after_train_epoch(self, runner):
+        """We load parameter values from ema backup to model before the
+        EvalHook."""
+        self._swap_ema_parameters()
+
+    def before_train_epoch(self, runner):
+        """We recover model's parameter from ema backup after last epoch's
+        EvalHook."""
+        self._swap_ema_parameters()
+
+    def _swap_ema_parameters(self):
+        """Swap the parameter of model with parameter in ema_buffer."""
+        for name, value in self.model_parameters.items():
+            temp = value.data.clone()
+            ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
+            value.data.copy_(ema_buffer.data)
+            ema_buffer.data.copy_(temp)
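The effective momentum is momentum**interval, further limited by the warm-up term (1 + iter) / (warm_up + iter) via the min() above. A scalar sketch of one update with the hyper-parameters later used in the unit test (momentum=0.1, interval=2, warm_up=100); the step and tensor values are made up for illustration:

    momentum = 0.1 ** 2                  # self.momentum = momentum**interval -> 0.01
    curr_step = 10                       # hypothetical runner.iter
    m = min(momentum, (1 + curr_step) / (100 + curr_step))  # warm-up cap -> 0.01
    ema, param = 0.0, 1.0                # scalar stand-ins for a buffer/parameter pair
    ema = (1 - m) * ema + m * param      # buffer.mul_(1 - m).add_(m, param)
    print(m, ema)                        # 0.01 0.01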

tests/test_hooks.py

@@ -14,13 +14,85 @@ from unittest.mock import MagicMock, call
 import pytest
 import torch
 import torch.nn as nn
+from torch.nn.init import constant_
 from torch.utils.data import DataLoader
 
-from mmcv.runner import (EpochBasedRunner, IterTimerHook, MlflowLoggerHook,
-                         PaviLoggerHook, WandbLoggerHook)
+from mmcv.runner import (CheckpointHook, EMAHook, EpochBasedRunner,
+                         IterTimerHook, MlflowLoggerHook, PaviLoggerHook,
+                         WandbLoggerHook)
 from mmcv.runner.hooks.lr_updater import CosineRestartLrUpdaterHook
 
 
+def test_ema_hook():
+    """xdoctest -m tests/test_hooks.py test_ema_hook."""
+
+    class DemoModel(nn.Module):
+
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv2d(
+                in_channels=1,
+                out_channels=2,
+                kernel_size=1,
+                padding=1,
+                bias=True)
+            self._init_weight()
+
+        def _init_weight(self):
+            constant_(self.conv.weight, 0)
+            constant_(self.conv.bias, 0)
+
+        def forward(self, x):
+            return self.conv(x).sum()
+
+        def train_step(self, x, optimizer, **kwargs):
+            return dict(loss=self(x))
+
+        def val_step(self, x, optimizer, **kwargs):
+            return dict(loss=self(x))
+
+    loader = DataLoader(torch.ones((1, 1, 1, 1)))
+    runner = _build_demo_runner()
+    demo_model = DemoModel()
+    runner.model = demo_model
+    emahook = EMAHook(momentum=0.1, interval=2, warm_up=100, resume_from=None)
+    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
+    runner.register_hook(emahook, priority='HIGHEST')
+    runner.register_hook(checkpointhook)
+    runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
+    checkpoint = torch.load(f'{runner.work_dir}/epoch_1.pth')
+    contain_ema_buffer = False
+    for name, value in checkpoint['state_dict'].items():
+        if 'ema' in name:
+            contain_ema_buffer = True
+            assert value.sum() == 0
+            value.fill_(1)
+        else:
+            assert value.sum() == 0
+    assert contain_ema_buffer
+    torch.save(checkpoint, f'{runner.work_dir}/epoch_1.pth')
+    work_dir = runner.work_dir
+    resume_ema_hook = EMAHook(
+        momentum=0.5, warm_up=0, resume_from=f'{work_dir}/epoch_1.pth')
+    runner = _build_demo_runner()
+    runner.model = demo_model
+    runner.register_hook(resume_ema_hook, priority='HIGHEST')
+    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
+    runner.register_hook(checkpointhook)
+    runner.run([loader, loader], [('train', 1), ('val', 1)], 2)
+    checkpoint = torch.load(f'{runner.work_dir}/epoch_2.pth')
+    contain_ema_buffer = False
+    for name, value in checkpoint['state_dict'].items():
+        if 'ema' in name:
+            contain_ema_buffer = True
+            assert value.sum() == 2
+        else:
+            assert value.sum() == 1
+    assert contain_ema_buffer
+    shutil.rmtree(runner.work_dir)
+    shutil.rmtree(work_dir)
+
+
 def test_pavi_hook():
     sys.modules['pavi'] = MagicMock()