* add ema hook

* add ema hook resume

* add ema hook test

* fix typo

* fix according to comment

* delete logger

* fix according to comment

* fix unittest

* fix typo

* fix according to comment

* change to resume_from

* typo

* fix isort
shilong 2020-07-30 22:06:19 +08:00 committed by GitHub
parent 4676031c8b
commit 1830347f8b
5 changed files with 174 additions and 10 deletions

mmcv/runner/__init__.py

@@ -6,8 +6,8 @@ from .dist_utils import get_dist_info, init_dist, master_only
 from .epoch_based_runner import EpochBasedRunner, Runner
 from .fp16_utils import auto_fp16, force_fp32
 from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistSamplerSeedHook,
-                    Fp16OptimizerHook, Hook, IterTimerHook, LoggerHook,
-                    LrUpdaterHook, MlflowLoggerHook, OptimizerHook,
+                    EMAHook, Fp16OptimizerHook, Hook, IterTimerHook,
+                    LoggerHook, LrUpdaterHook, MlflowLoggerHook, OptimizerHook,
                     PaviLoggerHook, SyncBuffersHook, TensorboardLoggerHook,
                     TextLoggerHook, WandbLoggerHook)
 from .iter_based_runner import IterBasedRunner, IterLoader
@@ -30,5 +30,5 @@ __all__ = [
     'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
     'build_optimizer', 'build_optimizer_constructor', 'IterLoader',
     'set_random_seed', 'auto_fp16', 'force_fp32', 'Fp16OptimizerHook',
-    'SyncBuffersHook'
+    'SyncBuffersHook', 'EMAHook'
 ]

mmcv/runner/base_runner.py

@@ -307,10 +307,13 @@ class BaseRunner(metaclass=ABCMeta):
                resume_optimizer=True,
                map_location='default'):
         if map_location == 'default':
-            device_id = torch.cuda.current_device()
-            checkpoint = self.load_checkpoint(
-                checkpoint,
-                map_location=lambda storage, loc: storage.cuda(device_id))
+            if torch.cuda.is_available():
+                device_id = torch.cuda.current_device()
+                checkpoint = self.load_checkpoint(
+                    checkpoint,
+                    map_location=lambda storage, loc: storage.cuda(device_id))
+            else:
+                checkpoint = self.load_checkpoint(checkpoint)
         else:
             checkpoint = self.load_checkpoint(
                 checkpoint, map_location=map_location)
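
For reference, the map_location forwarded to load_checkpoint above follows the standard torch.load convention: a callable that receives each serialized storage and its original device tag and returns the storage on the desired device. A minimal standalone sketch of the same default behaviour (the checkpoint path is a placeholder for illustration, not part of this commit):

import torch

ckpt_path = 'work_dir/epoch_1.pth'  # placeholder path for illustration
if torch.cuda.is_available():
    device_id = torch.cuda.current_device()
    # Remap every storage in the checkpoint onto the current GPU.
    ckpt = torch.load(
        ckpt_path, map_location=lambda storage, loc: storage.cuda(device_id))
else:
    # No GPU available: fall back to torch.load's default (CPU) mapping.
    ckpt = torch.load(ckpt_path)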

mmcv/runner/hooks/__init__.py

@@ -1,6 +1,7 @@
 # Copyright (c) Open-MMLab. All rights reserved.
 from .checkpoint import CheckpointHook
 from .closure import ClosureHook
+from .ema import EMAHook
 from .hook import HOOKS, Hook
 from .iter_timer import IterTimerHook
 from .logger import (LoggerHook, MlflowLoggerHook, PaviLoggerHook,
@@ -17,5 +18,5 @@ __all__ = [
     'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
     'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
     'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
-    'WandbLoggerHook', 'MomentumUpdaterHook', 'SyncBuffersHook'
+    'WandbLoggerHook', 'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook'
 ]

mmcv/runner/hooks/ema.py Normal file

@@ -0,0 +1,88 @@
from ...parallel import is_module_wrapper
from ..hooks.hook import HOOKS, Hook


@HOOKS.register_module()
class EMAHook(Hook):
    r"""Exponential Moving Average Hook.

    Use Exponential Moving Average on all parameters of the model during
    training. Every parameter has an EMA backup, which is updated by the
    formula below. EMAHook takes priority over EvalHook and CheckpointHook.

    .. math::

        X_{\text{ema},t+1} = (1 - \text{momentum}) \times
        X_{\text{ema},t} + \text{momentum} \times X_t

    Args:
        momentum (float): The momentum used for updating the EMA parameters.
            Defaults to 0.0002.
        interval (int): Update the EMA parameters every ``interval``
            iterations. Defaults to 1.
        warm_up (int): During the first ``warm_up`` steps, a smaller momentum
            may be used so that the EMA parameters update more slowly.
            Defaults to 100.
        resume_from (str): The checkpoint path to resume from.
            Defaults to None.
    """

    def __init__(self,
                 momentum=0.0002,
                 interval=1,
                 warm_up=100,
                 resume_from=None):
        assert isinstance(interval, int) and interval > 0
        self.warm_up = warm_up
        self.interval = interval
        assert momentum > 0 and momentum < 1
        self.momentum = momentum**interval
        self.checkpoint = resume_from

    def before_run(self, runner):
        """Register each EMA parameter as a ``named_buffer`` of the model, so
        that the EMA state is saved in checkpoints and the model can be
        resumed together with its EMA parameters."""
        model = runner.model
        if is_module_wrapper(model):
            model = model.module
        self.param_ema_buffer = {}
        self.model_parameters = dict(model.named_parameters(recurse=True))
        for name, value in self.model_parameters.items():
            # "." is not allowed in a module's buffer name
            buffer_name = f"ema_{name.replace('.', '_')}"
            self.param_ema_buffer[name] = buffer_name
            model.register_buffer(buffer_name, value.data.clone())
        self.model_buffers = dict(model.named_buffers(recurse=True))
        if self.checkpoint is not None:
            runner.resume(self.checkpoint)

    def after_train_iter(self, runner):
        """Update the EMA parameters every ``self.interval`` iterations."""
        curr_step = runner.iter
        # Warm up the momentum to account for instability at the beginning
        # of training.
        momentum = min(self.momentum,
                       (1 + curr_step) / (self.warm_up + curr_step))
        if curr_step % self.interval != 0:
            return
        for name, parameter in self.model_parameters.items():
            buffer_name = self.param_ema_buffer[name]
            buffer_parameter = self.model_buffers[buffer_name]
            buffer_parameter.mul_(1 - momentum).add_(momentum, parameter.data)

    def after_train_epoch(self, runner):
        """Load the EMA parameter values into the model before the EvalHook
        runs."""
        self._swap_ema_parameters()

    def before_train_epoch(self, runner):
        """Restore the model's own parameters, which were swapped out after
        the previous epoch's EvalHook."""
        self._swap_ema_parameters()

    def _swap_ema_parameters(self):
        """Swap the model's parameters with the values in the EMA buffers."""
        for name, value in self.model_parameters.items():
            temp = value.data.clone()
            ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
            value.data.copy_(ema_buffer.data)
            ema_buffer.data.copy_(temp)
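
Since EMAHook is registered in the HOOKS registry and exported from mmcv.runner, it can be attached to a runner like any other hook. The snippet below is a minimal usage sketch rather than part of the commit: the toy model, optimizer, and temporary work directory are placeholders, and it assumes an mmcv version contemporary with this change.

import logging
import tempfile

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from mmcv.runner import EMAHook, EpochBasedRunner


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x).sum()

    # EpochBasedRunner drives training through train_step/val_step.
    def train_step(self, x, optimizer, **kwargs):
        return dict(loss=self(x))

    def val_step(self, x, optimizer, **kwargs):
        return dict(loss=self(x))


model = ToyModel()
runner = EpochBasedRunner(
    model=model,
    optimizer=torch.optim.SGD(model.parameters(), lr=0.02),
    work_dir=tempfile.mkdtemp(),
    logger=logging.getLogger())
# Register EMAHook with the highest priority so the EMA buffers are updated
# and swapped before checkpointing or evaluation hooks run.
runner.register_hook(
    EMAHook(momentum=0.0002, interval=1, warm_up=100), priority='HIGHEST')
loader = DataLoader(torch.ones((1, 2)))
runner.run([loader], [('train', 1)], 1)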

tests/test_hooks.py

@@ -14,13 +14,85 @@ from unittest.mock import MagicMock, call
 import pytest
 import torch
 import torch.nn as nn
+from torch.nn.init import constant_
 from torch.utils.data import DataLoader
 
-from mmcv.runner import (EpochBasedRunner, IterTimerHook, MlflowLoggerHook,
-                         PaviLoggerHook, WandbLoggerHook)
+from mmcv.runner import (CheckpointHook, EMAHook, EpochBasedRunner,
+                         IterTimerHook, MlflowLoggerHook, PaviLoggerHook,
+                         WandbLoggerHook)
 from mmcv.runner.hooks.lr_updater import CosineRestartLrUpdaterHook
 
 
+def test_ema_hook():
+    """xdoctest -m tests/test_hooks.py test_ema_hook."""
+
+    class DemoModel(nn.Module):
+
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv2d(
+                in_channels=1,
+                out_channels=2,
+                kernel_size=1,
+                padding=1,
+                bias=True)
+            self._init_weight()
+
+        def _init_weight(self):
+            constant_(self.conv.weight, 0)
+            constant_(self.conv.bias, 0)
+
+        def forward(self, x):
+            return self.conv(x).sum()
+
+        def train_step(self, x, optimizer, **kwargs):
+            return dict(loss=self(x))
+
+        def val_step(self, x, optimizer, **kwargs):
+            return dict(loss=self(x))
+
+    loader = DataLoader(torch.ones((1, 1, 1, 1)))
+    runner = _build_demo_runner()
+    demo_model = DemoModel()
+    runner.model = demo_model
+    emahook = EMAHook(momentum=0.1, interval=2, warm_up=100, resume_from=None)
+    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
+    runner.register_hook(emahook, priority='HIGHEST')
+    runner.register_hook(checkpointhook)
+    runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
+    checkpoint = torch.load(f'{runner.work_dir}/epoch_1.pth')
+    contain_ema_buffer = False
+    for name, value in checkpoint['state_dict'].items():
+        if 'ema' in name:
+            contain_ema_buffer = True
+            assert value.sum() == 0
+            value.fill_(1)
+        else:
+            assert value.sum() == 0
+    assert contain_ema_buffer
+    torch.save(checkpoint, f'{runner.work_dir}/epoch_1.pth')
+    work_dir = runner.work_dir
+    resume_ema_hook = EMAHook(
+        momentum=0.5, warm_up=0, resume_from=f'{work_dir}/epoch_1.pth')
+    runner = _build_demo_runner()
+    runner.model = demo_model
+    runner.register_hook(resume_ema_hook, priority='HIGHEST')
+    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
+    runner.register_hook(checkpointhook)
+    runner.run([loader, loader], [('train', 1), ('val', 1)], 2)
+    checkpoint = torch.load(f'{runner.work_dir}/epoch_2.pth')
+    contain_ema_buffer = False
+    for name, value in checkpoint['state_dict'].items():
+        if 'ema' in name:
+            contain_ema_buffer = True
+            assert value.sum() == 2
+        else:
+            assert value.sum() == 1
+    assert contain_ema_buffer
+    shutil.rmtree(runner.work_dir)
+    shutil.rmtree(work_dir)
+
+
 def test_pavi_hook():
     sys.modules['pavi'] = MagicMock()