Mirror of https://github.com/open-mmlab/mmcv.git
Ema (#421)
* add ema hook
* add ema hook resume
* add ema hook test
* fix typo
* fix according to comment
* delete logger
* fix according to comment
* fix unitest
* fix typo
* fix according to comment
* change to resume_from
* typo
* fix isort
This commit is contained in:
parent 4676031c8b
commit 1830347f8b
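For orientation, a minimal sketch of how the new hook is meant to be wired into an epoch-based runner, modeled on the test added in this commit; ToyModel, the work_dir path, and the optimizer settings are illustrative placeholders, not part of the change:

import logging

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from mmcv.runner import EMAHook, EpochBasedRunner


class ToyModel(nn.Module):
    """Tiny model exposing the train_step() interface the runner expects."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1, 1)

    def forward(self, x):
        return self.fc(x).sum()

    def train_step(self, x, optimizer, **kwargs):
        return dict(loss=self(x))


model = ToyModel()
runner = EpochBasedRunner(
    model,
    optimizer=torch.optim.SGD(model.parameters(), lr=0.02),
    work_dir='./work_dir',  # placeholder path
    logger=logging.getLogger())
# Register with highest priority so the EMA update runs before
# evaluation/checkpoint hooks, as the hook's docstring recommends.
runner.register_hook(
    EMAHook(momentum=0.0002, interval=1, warm_up=100), priority='HIGHEST')
runner.run([DataLoader(torch.ones((1, 1)))], [('train', 1)], 1)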
@@ -6,8 +6,8 @@ from .dist_utils import get_dist_info, init_dist, master_only
 from .epoch_based_runner import EpochBasedRunner, Runner
 from .fp16_utils import auto_fp16, force_fp32
 from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistSamplerSeedHook,
-                    Fp16OptimizerHook, Hook, IterTimerHook, LoggerHook,
-                    LrUpdaterHook, MlflowLoggerHook, OptimizerHook,
+                    EMAHook, Fp16OptimizerHook, Hook, IterTimerHook,
+                    LoggerHook, LrUpdaterHook, MlflowLoggerHook, OptimizerHook,
                     PaviLoggerHook, SyncBuffersHook, TensorboardLoggerHook,
                     TextLoggerHook, WandbLoggerHook)
 from .iter_based_runner import IterBasedRunner, IterLoader
@@ -30,5 +30,5 @@ __all__ = [
     'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
     'build_optimizer', 'build_optimizer_constructor', 'IterLoader',
     'set_random_seed', 'auto_fp16', 'force_fp32', 'Fp16OptimizerHook',
-    'SyncBuffersHook'
+    'SyncBuffersHook', 'EMAHook'
 ]
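With the export added above, the hook is importable from the top-level runner package; a one-line example (default values taken from the hook's signature later in this commit):

from mmcv.runner import EMAHook

ema_hook = EMAHook()  # momentum=0.0002, interval=1, warm_up=100 by default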
@@ -307,10 +307,13 @@ class BaseRunner(metaclass=ABCMeta):
                resume_optimizer=True,
                map_location='default'):
         if map_location == 'default':
-            device_id = torch.cuda.current_device()
-            checkpoint = self.load_checkpoint(
-                checkpoint,
-                map_location=lambda storage, loc: storage.cuda(device_id))
+            if torch.cuda.is_available():
+                device_id = torch.cuda.current_device()
+                checkpoint = self.load_checkpoint(
+                    checkpoint,
+                    map_location=lambda storage, loc: storage.cuda(device_id))
+            else:
+                checkpoint = self.load_checkpoint(checkpoint)
         else:
             checkpoint = self.load_checkpoint(
                 checkpoint, map_location=map_location)
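The effect of the change above, sketched under the assumption that `runner` is an already-constructed runner and the checkpoint path is a placeholder:

# On a machine with CUDA available, the checkpoint is mapped to the current
# GPU as before; on a CPU-only machine it is now loaded onto the CPU
# (previously the default path assumed CUDA was available).
runner.resume('./work_dir/latest.pth')

# An explicit map_location is still passed through unchanged.
runner.resume('./work_dir/latest.pth', map_location='cpu')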
@@ -1,6 +1,7 @@
 # Copyright (c) Open-MMLab. All rights reserved.
 from .checkpoint import CheckpointHook
 from .closure import ClosureHook
+from .ema import EMAHook
 from .hook import HOOKS, Hook
 from .iter_timer import IterTimerHook
 from .logger import (LoggerHook, MlflowLoggerHook, PaviLoggerHook,
@@ -17,5 +18,5 @@ __all__ = [
     'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
     'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
     'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
-    'WandbLoggerHook', 'MomentumUpdaterHook', 'SyncBuffersHook'
+    'WandbLoggerHook', 'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook'
 ]
mmcv/runner/hooks/ema.py (new file, 88 lines)
@@ -0,0 +1,88 @@
from ...parallel import is_module_wrapper
from ..hooks.hook import HOOKS, Hook


@HOOKS.register_module()
class EMAHook(Hook):
    r"""Exponential Moving Average Hook.

    Use Exponential Moving Average on all parameters of the model during
    training. Every parameter has an EMA backup, which is updated by the
    formula below. EMAHook takes priority over EvalHook and
    CheckpointSaverHook.

    .. math::

        X_{\text{ema},t+1} = (1 - \text{momentum}) \times
        X_{\text{ema},t} + \text{momentum} \times X_t

    Args:
        momentum (float): The momentum used for updating the EMA parameters.
            Defaults to 0.0002.
        interval (int): Update the EMA parameters every `interval` iterations.
            Defaults to 1.
        warm_up (int): During the first `warm_up` steps, a smaller momentum
            may be used so that the EMA parameters update more slowly.
            Defaults to 100.
        resume_from (str): The checkpoint path to resume from. Defaults to
            None.
    """

    def __init__(self,
                 momentum=0.0002,
                 interval=1,
                 warm_up=100,
                 resume_from=None):
        assert isinstance(interval, int) and interval > 0
        self.warm_up = warm_up
        self.interval = interval
        assert momentum > 0 and momentum < 1
        self.momentum = momentum**interval
        self.checkpoint = resume_from

    def before_run(self, runner):
        """Register the EMA parameters as named buffers of the model so that
        the model can be resumed together with its EMA parameters."""
        model = runner.model
        if is_module_wrapper(model):
            model = model.module
        self.param_ema_buffer = {}
        self.model_parameters = dict(model.named_parameters(recurse=True))
        for name, value in self.model_parameters.items():
            # "." is not allowed in a module's buffer name
            buffer_name = f"ema_{name.replace('.', '_')}"
            self.param_ema_buffer[name] = buffer_name
            model.register_buffer(buffer_name, value.data.clone())
        self.model_buffers = dict(model.named_buffers(recurse=True))
        if self.checkpoint is not None:
            runner.resume(self.checkpoint)

    def after_train_iter(self, runner):
        """Update the EMA parameters every ``self.interval`` iterations."""
        curr_step = runner.iter
        # Warm up the momentum to account for instability at the beginning
        momentum = min(self.momentum,
                       (1 + curr_step) / (self.warm_up + curr_step))
        if curr_step % self.interval != 0:
            return
        for name, parameter in self.model_parameters.items():
            buffer_name = self.param_ema_buffer[name]
            buffer_parameter = self.model_buffers[buffer_name]
            buffer_parameter.mul_(1 - momentum).add_(momentum, parameter.data)

    def after_train_epoch(self, runner):
        """Load parameter values from the EMA backup into the model before
        EvalHook runs."""
        self._swap_ema_parameters()

    def before_train_epoch(self, runner):
        """Restore the model's parameters from the EMA backup after the
        previous epoch's EvalHook."""
        self._swap_ema_parameters()

    def _swap_ema_parameters(self):
        """Swap the model's parameters with the parameters in the EMA
        buffers."""
        for name, value in self.model_parameters.items():
            temp = value.data.clone()
            ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
            value.data.copy_(ema_buffer.data)
            ema_buffer.data.copy_(temp)
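As a quick numerical check of the update rule used in after_train_iter (values are illustrative only; plain tensors stand in for a model parameter and its EMA buffer, and the modern add_ keyword form is used in place of the hook's positional call):

import torch

momentum = 0.1                 # illustrative; far larger than the 0.0002 default
ema = torch.tensor([1.0])      # EMA buffer, X_ema,t
param = torch.tensor([3.0])    # current parameter, X_t

# Equivalent to the hook's buffer.mul_(1 - momentum).add_(momentum, param):
# X_ema,t+1 = (1 - momentum) * X_ema,t + momentum * X_t
ema.mul_(1 - momentum).add_(param, alpha=momentum)
print(ema)  # tensor([1.2000]) = 0.9 * 1.0 + 0.1 * 3.0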
@@ -14,13 +14,85 @@ from unittest.mock import MagicMock, call
 import pytest
 import torch
 import torch.nn as nn
+from torch.nn.init import constant_
 from torch.utils.data import DataLoader
 
-from mmcv.runner import (EpochBasedRunner, IterTimerHook, MlflowLoggerHook,
-                         PaviLoggerHook, WandbLoggerHook)
+from mmcv.runner import (CheckpointHook, EMAHook, EpochBasedRunner,
+                         IterTimerHook, MlflowLoggerHook, PaviLoggerHook,
+                         WandbLoggerHook)
 from mmcv.runner.hooks.lr_updater import CosineRestartLrUpdaterHook
 
 
+def test_ema_hook():
+    """xdoctest -m tests/test_hooks.py test_ema_hook."""
+
+    class DemoModel(nn.Module):
+
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv2d(
+                in_channels=1,
+                out_channels=2,
+                kernel_size=1,
+                padding=1,
+                bias=True)
+            self._init_weight()
+
+        def _init_weight(self):
+            constant_(self.conv.weight, 0)
+            constant_(self.conv.bias, 0)
+
+        def forward(self, x):
+            return self.conv(x).sum()
+
+        def train_step(self, x, optimizer, **kwargs):
+            return dict(loss=self(x))
+
+        def val_step(self, x, optimizer, **kwargs):
+            return dict(loss=self(x))
+
+    loader = DataLoader(torch.ones((1, 1, 1, 1)))
+    runner = _build_demo_runner()
+    demo_model = DemoModel()
+    runner.model = demo_model
+    emahook = EMAHook(momentum=0.1, interval=2, warm_up=100, resume_from=None)
+    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
+    runner.register_hook(emahook, priority='HIGHEST')
+    runner.register_hook(checkpointhook)
+    runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
+    checkpoint = torch.load(f'{runner.work_dir}/epoch_1.pth')
+    contain_ema_buffer = False
+    for name, value in checkpoint['state_dict'].items():
+        if 'ema' in name:
+            contain_ema_buffer = True
+            assert value.sum() == 0
+            value.fill_(1)
+        else:
+            assert value.sum() == 0
+    assert contain_ema_buffer
+    torch.save(checkpoint, f'{runner.work_dir}/epoch_1.pth')
+    work_dir = runner.work_dir
+    resume_ema_hook = EMAHook(
+        momentum=0.5, warm_up=0, resume_from=f'{work_dir}/epoch_1.pth')
+    runner = _build_demo_runner()
+    runner.model = demo_model
+    runner.register_hook(resume_ema_hook, priority='HIGHEST')
+    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
+    runner.register_hook(checkpointhook)
+    runner.run([loader, loader], [('train', 1), ('val', 1)], 2)
+    checkpoint = torch.load(f'{runner.work_dir}/epoch_2.pth')
+    contain_ema_buffer = False
+    for name, value in checkpoint['state_dict'].items():
+        if 'ema' in name:
+            contain_ema_buffer = True
+            assert value.sum() == 2
+        else:
+            assert value.sum() == 1
+    assert contain_ema_buffer
+    shutil.rmtree(runner.work_dir)
+    shutil.rmtree(work_dir)
+
+
 def test_pavi_hook():
     sys.modules['pavi'] = MagicMock()