Mirror of https://github.com/open-mmlab/mmcv.git (synced 2025-06-03 21:54:52 +08:00)
Ema (#421)

* add ema hook
* add ema hook resume
* add ema hook test
* fix typo
* fix according to comment
* delete logger
* fix according to comment
* fix unitest
* fix typo
* fix according to comment
* change to resume_from
* typo
* fix isort
This commit is contained in:
parent 4676031c8b
commit 1830347f8b
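For orientation, here is a minimal usage sketch of the new hook (not part of the commit; it assumes an already built mmcv runner instance named runner, and the hyper-parameter values are just the documented defaults):

from mmcv.runner import EMAHook

# Assumes `runner` is an existing EpochBasedRunner / IterBasedRunner instance.
ema_hook = EMAHook(momentum=0.0002, interval=1, warm_up=100)
# Register with highest priority so the hook runs before EvalHook and
# CheckpointHook, as the EMAHook docstring below requires.
runner.register_hook(ema_hook, priority='HIGHEST')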
mmcv/runner/__init__.py
@@ -6,8 +6,8 @@ from .dist_utils import get_dist_info, init_dist, master_only
 from .epoch_based_runner import EpochBasedRunner, Runner
 from .fp16_utils import auto_fp16, force_fp32
 from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistSamplerSeedHook,
-                    Fp16OptimizerHook, Hook, IterTimerHook, LoggerHook,
-                    LrUpdaterHook, MlflowLoggerHook, OptimizerHook,
+                    EMAHook, Fp16OptimizerHook, Hook, IterTimerHook,
+                    LoggerHook, LrUpdaterHook, MlflowLoggerHook, OptimizerHook,
                     PaviLoggerHook, SyncBuffersHook, TensorboardLoggerHook,
                     TextLoggerHook, WandbLoggerHook)
 from .iter_based_runner import IterBasedRunner, IterLoader
@@ -30,5 +30,5 @@ __all__ = [
     'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
     'build_optimizer', 'build_optimizer_constructor', 'IterLoader',
     'set_random_seed', 'auto_fp16', 'force_fp32', 'Fp16OptimizerHook',
-    'SyncBuffersHook'
+    'SyncBuffersHook', 'EMAHook'
 ]
mmcv/runner/base_runner.py
@@ -307,10 +307,13 @@ class BaseRunner(metaclass=ABCMeta):
                resume_optimizer=True,
                map_location='default'):
         if map_location == 'default':
-            device_id = torch.cuda.current_device()
-            checkpoint = self.load_checkpoint(
-                checkpoint,
-                map_location=lambda storage, loc: storage.cuda(device_id))
+            if torch.cuda.is_available():
+                device_id = torch.cuda.current_device()
+                checkpoint = self.load_checkpoint(
+                    checkpoint,
+                    map_location=lambda storage, loc: storage.cuda(device_id))
+            else:
+                checkpoint = self.load_checkpoint(checkpoint)
         else:
             checkpoint = self.load_checkpoint(
                 checkpoint, map_location=map_location)
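The resume change above only affects map_location='default': previously it always assumed a CUDA device, and it now falls back to plain CPU loading when CUDA is unavailable. A hedged sketch of the two call patterns (the checkpoint path is illustrative):

# 'default' maps onto the current GPU when CUDA is available, otherwise the
# checkpoint is now loaded without a CUDA mapping; any explicit map_location
# is passed through unchanged.
runner.resume('work_dir/epoch_1.pth')                      # map_location='default'
runner.resume('work_dir/epoch_1.pth', map_location='cpu')  # explicit mapping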
mmcv/runner/hooks/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) Open-MMLab. All rights reserved.
 from .checkpoint import CheckpointHook
 from .closure import ClosureHook
+from .ema import EMAHook
 from .hook import HOOKS, Hook
 from .iter_timer import IterTimerHook
 from .logger import (LoggerHook, MlflowLoggerHook, PaviLoggerHook,
@@ -17,5 +18,5 @@ __all__ = [
     'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
     'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
     'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
-    'WandbLoggerHook', 'MomentumUpdaterHook', 'SyncBuffersHook'
+    'WandbLoggerHook', 'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook'
 ]
mmcv/runner/hooks/ema.py (new file, 88 lines)
@@ -0,0 +1,88 @@
+from ...parallel import is_module_wrapper
+from ..hooks.hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class EMAHook(Hook):
+    r"""Exponential Moving Average Hook.
+
+    Use Exponential Moving Average on all parameters of model in training
+    process. All parameters have an ema backup, updated by the formula
+    below. EMAHook takes priority over EvalHook and CheckpointSaverHook.
+
+    .. math::
+
+        X_{\text{ema},t+1} = (1 - \text{momentum}) \times
+        X_{\text{ema},t} + \text{momentum} \times X_t
+
+    Args:
+        momentum (float): The momentum used for updating ema parameter.
+            Defaults to 0.0002.
+        interval (int): Update ema parameter every interval iteration.
+            Defaults to 1.
+        warm_up (int): During first warm_up steps, we may use smaller momentum
+            to update ema parameters more slowly. Defaults to 100.
+        resume_from (str): The checkpoint path. Defaults to None.
+    """
+
+    def __init__(self,
+                 momentum=0.0002,
+                 interval=1,
+                 warm_up=100,
+                 resume_from=None):
+        assert isinstance(interval, int) and interval > 0
+        self.warm_up = warm_up
+        self.interval = interval
+        assert momentum > 0 and momentum < 1
+        self.momentum = momentum**interval
+        self.checkpoint = resume_from
+
+    def before_run(self, runner):
+        """To resume the model with its ema parameters more friendly.
+
+        Register ema parameters as ``named_buffer`` to model.
+        """
+        model = runner.model
+        if is_module_wrapper(model):
+            model = model.module
+        self.param_ema_buffer = {}
+        self.model_parameters = dict(model.named_parameters(recurse=True))
+        for name, value in self.model_parameters.items():
+            # "." is not allowed in module's buffer name
+            buffer_name = f"ema_{name.replace('.', '_')}"
+            self.param_ema_buffer[name] = buffer_name
+            model.register_buffer(buffer_name, value.data.clone())
+        self.model_buffers = dict(model.named_buffers(recurse=True))
+        if self.checkpoint is not None:
+            runner.resume(self.checkpoint)
+
+    def after_train_iter(self, runner):
+        """Update ema parameter every self.interval iterations."""
+        curr_step = runner.iter
+        # We warm up the momentum considering the instability at beginning
+        momentum = min(self.momentum,
+                       (1 + curr_step) / (self.warm_up + curr_step))
+        if curr_step % self.interval != 0:
+            return
+        for name, parameter in self.model_parameters.items():
+            buffer_name = self.param_ema_buffer[name]
+            buffer_parameter = self.model_buffers[buffer_name]
+            buffer_parameter.mul_(1 - momentum).add_(momentum, parameter.data)
+
+    def after_train_epoch(self, runner):
+        """We load parameter values from ema backup to model before the
+        EvalHook."""
+        self._swap_ema_parameters()
+
+    def before_train_epoch(self, runner):
+        """We recover model's parameters from ema backup after last epoch's
+        EvalHook."""
+        self._swap_ema_parameters()
+
+    def _swap_ema_parameters(self):
+        """Swap the parameters of model with parameters in ema_buffer."""
+        for name, value in self.model_parameters.items():
+            temp = value.data.clone()
+            ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
+            value.data.copy_(ema_buffer.data)
+            ema_buffer.data.copy_(temp)
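For readers new to EMA, a standalone sketch (not part of the commit; the momentum, warm_up and tensor values are illustrative) of the effective-momentum warm-up and the in-place update performed in after_train_iter:

import torch

momentum = 0.1   # larger than the 0.0002 default so the warm-up is visible
warm_up = 100

param = torch.ones(3)       # stands in for a model parameter
ema_param = torch.zeros(3)  # stands in for its ema_* buffer

for curr_step in range(0, 500, 100):
    # Early on, (1 + step) / (warm_up + step) is below the configured
    # momentum, so the ema buffer moves more slowly; the effective momentum
    # approaches the configured value as training progresses.
    m = min(momentum, (1 + curr_step) / (warm_up + curr_step))
    ema_param.mul_(1 - m).add_(param, alpha=m)  # ema = (1 - m) * ema + m * param
    print(curr_step, round(m, 4), ema_param[0].item())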
tests/test_hooks.py
@@ -14,13 +14,85 @@ from unittest.mock import MagicMock, call
 import pytest
 import torch
 import torch.nn as nn
 from torch.nn.init import constant_
 from torch.utils.data import DataLoader
 
-from mmcv.runner import (EpochBasedRunner, IterTimerHook, MlflowLoggerHook,
-                         PaviLoggerHook, WandbLoggerHook)
+from mmcv.runner import (CheckpointHook, EMAHook, EpochBasedRunner,
+                         IterTimerHook, MlflowLoggerHook, PaviLoggerHook,
+                         WandbLoggerHook)
 from mmcv.runner.hooks.lr_updater import CosineRestartLrUpdaterHook
 
 
+def test_ema_hook():
+    """xdoctest -m tests/test_hooks.py test_ema_hook."""
+
+    class DemoModel(nn.Module):
+
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv2d(
+                in_channels=1,
+                out_channels=2,
+                kernel_size=1,
+                padding=1,
+                bias=True)
+            self._init_weight()
+
+        def _init_weight(self):
+            constant_(self.conv.weight, 0)
+            constant_(self.conv.bias, 0)
+
+        def forward(self, x):
+            return self.conv(x).sum()
+
+        def train_step(self, x, optimizer, **kwargs):
+            return dict(loss=self(x))
+
+        def val_step(self, x, optimizer, **kwargs):
+            return dict(loss=self(x))
+
+    loader = DataLoader(torch.ones((1, 1, 1, 1)))
+    runner = _build_demo_runner()
+    demo_model = DemoModel()
+    runner.model = demo_model
+    emahook = EMAHook(momentum=0.1, interval=2, warm_up=100, resume_from=None)
+    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
+    runner.register_hook(emahook, priority='HIGHEST')
+    runner.register_hook(checkpointhook)
+    runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
+    checkpoint = torch.load(f'{runner.work_dir}/epoch_1.pth')
+    contain_ema_buffer = False
+    for name, value in checkpoint['state_dict'].items():
+        if 'ema' in name:
+            contain_ema_buffer = True
+            assert value.sum() == 0
+            value.fill_(1)
+        else:
+            assert value.sum() == 0
+    assert contain_ema_buffer
+    torch.save(checkpoint, f'{runner.work_dir}/epoch_1.pth')
+    work_dir = runner.work_dir
+    resume_ema_hook = EMAHook(
+        momentum=0.5, warm_up=0, resume_from=f'{work_dir}/epoch_1.pth')
+    runner = _build_demo_runner()
+    runner.model = demo_model
+    runner.register_hook(resume_ema_hook, priority='HIGHEST')
+    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
+    runner.register_hook(checkpointhook)
+    runner.run([loader, loader], [('train', 1), ('val', 1)], 2)
+    checkpoint = torch.load(f'{runner.work_dir}/epoch_2.pth')
+    contain_ema_buffer = False
+    for name, value in checkpoint['state_dict'].items():
+        if 'ema' in name:
+            contain_ema_buffer = True
+            assert value.sum() == 2
+        else:
+            assert value.sum() == 1
+    assert contain_ema_buffer
+    shutil.rmtree(runner.work_dir)
+    shutil.rmtree(work_dir)
+
+
 def test_pavi_hook():
     sys.modules['pavi'] = MagicMock()
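The new test can be run on its own with pytest, e.g. pytest tests/test_hooks.py::test_ema_hook (the xdoctest invocation named in its docstring also works).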