[Feature] Support EMA and SWA. (#239)
* [Feature] Support EMA and SWA.
* add ema hook
* add avg model ut
* add more unit tests
* resolve comments
* fix warmup ema
* rename
* fix comments
* add assert
* fix typehint
* add comments
This commit is contained in:
parent 86ffc19c9c
commit 0279ac2e8d
mmengine/hooks/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .checkpoint_hook import CheckpointHook
+from .ema_hook import EMAHook
 from .empty_cache_hook import EmptyCacheHook
 from .hook import Hook
 from .iter_timer_hook import IterTimerHook
@@ -13,5 +14,5 @@ from .sync_buffer_hook import SyncBuffersHook
 __all__ = [
     'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
     'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook',
-    'LoggerHook', 'NaiveVisualizationHook'
+    'LoggerHook', 'NaiveVisualizationHook', 'EMAHook'
 ]
mmengine/hooks/ema_hook.py  (new file, 94 lines)
@@ -0,0 +1,94 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import Optional

from mmengine.model import is_model_wrapper
from mmengine.registry import HOOKS, MODELS
from .hook import DATA_BATCH, Hook


@HOOKS.register_module()
class EMAHook(Hook):
    """A Hook to apply Exponential Moving Average (EMA) on the model during
    training.

    Note:
        - EMAHook takes priority over CheckpointHook.
        - The original model parameters are actually saved in ema field after
          train.

    Args:
        ema_type (str): The type of EMA strategy to use. You can find the
            supported strategies in ``mmengine.model.averaged_model``.
            Defaults to 'ExponentialMovingAverage'.
    """

    def __init__(self, ema_type: str = 'ExponentialMovingAverage', **kwargs):
        self.ema_cfg = dict(type=ema_type, **kwargs)

    def before_run(self, runner) -> None:
        """Create an ema copy of the model."""
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        self.src_model = model
        self.ema_model = MODELS.build(
            self.ema_cfg, default_args=dict(model=self.src_model))

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Update ema parameter."""
        self.ema_model.update_parameters(self.src_model)

    def before_val_epoch(self, runner) -> None:
        """We load parameter values from ema model to source model before
        validation."""
        self._swap_ema_parameters()

    def after_val_epoch(self, runner) -> None:
        """We recover source model's parameter from ema model after
        validation."""
        self._swap_ema_parameters()

    def before_test_epoch(self, runner) -> None:
        """We load parameter values from ema model to source model before
        test."""
        self._swap_ema_parameters()

    def after_test_epoch(self, runner) -> None:
        """We recover source model's parameter from ema model after test."""
        self._swap_ema_parameters()

    def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
        """Save ema parameters to checkpoint."""
        # save ema parameters to the source model's state dict so that we can
        # directly load the averaged model weights for deployment.
        self._swap_ema_parameters()
        checkpoint['ema_state_dict'] = self.ema_model.state_dict()
        self._swap_ema_parameters()

    def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
        """Resume ema parameters from checkpoint."""
        self.ema_model.load_state_dict(checkpoint['ema_state_dict'])
        # The original model parameters are actually saved in ema field.
        # swap the weights back to resume ema state.
        self._swap_ema_parameters()

    def _swap_ema_parameters(self) -> None:
        """Swap the parameter of model with ema_model."""
        avg_param = (
            itertools.chain(self.ema_model.module.parameters(),
                            self.ema_model.module.buffers())
            if self.ema_model.update_buffers else
            self.ema_model.module.parameters())
        src_param = (
            itertools.chain(self.src_model.parameters(),
                            self.src_model.buffers())
            if self.ema_model.update_buffers else self.src_model.parameters())
        for p_avg, p_src in zip(avg_param, src_param):
            tmp = p_avg.data.clone()
            p_avg.data.copy_(p_src.data)
            p_src.data.copy_(tmp)
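For orientation (not part of the commit): a minimal, self-contained sketch of how the new hook is driven. In normal use you would simply add ``custom_hooks=[dict(type='EMAHook')]`` to the runner config, as the unit test further below does; here a `SimpleNamespace` stands in for the Runner, and the extra `momentum` kwarg is forwarded through `ema_cfg` to the averaged model that `before_run` builds.

# A minimal sketch (not part of this commit); SimpleNamespace is a stand-in
# for a Runner, used only because EMAHook.before_run reads `runner.model`.
from types import SimpleNamespace

import torch.nn as nn

from mmengine.hooks import EMAHook
from mmengine.model import ExponentialMovingAverage

hook = EMAHook(ema_type='ExponentialMovingAverage', momentum=0.001)
fake_runner = SimpleNamespace(model=nn.Linear(2, 1))   # stand-in for Runner
hook.before_run(fake_runner)                            # builds hook.ema_model
assert isinstance(hook.ema_model, ExponentialMovingAverage)
hook.after_train_iter(fake_runner, batch_idx=0)         # one EMA update
print(int(hook.ema_model.steps))                        # 1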
mmengine/model/__init__.py
@@ -1,5 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .averaged_model import (ExponentialMovingAverage, MomentumAnnealingEMA,
+                             StochasticWeightAverage)
 from .wrappers import (MMDataParallel, MMDistributedDataParallel,
                        is_model_wrapper)
 
-__all__ = ['MMDistributedDataParallel', 'MMDataParallel', 'is_model_wrapper']
+__all__ = [
+    'MMDistributedDataParallel', 'MMDataParallel', 'is_model_wrapper',
+    'StochasticWeightAverage', 'ExponentialMovingAverage',
+    'MomentumAnnealingEMA'
+]
mmengine/model/averaged_model.py  (new file, 239 lines)
@@ -0,0 +1,239 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from abc import abstractmethod
from copy import deepcopy
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor

from mmengine.registry import MODELS


class BaseAveragedModel(nn.Module):
    """A base class for averaging model weights.

    Weight averaging, such as SWA and EMA, is a widely used technique for
    training neural networks. This class implements the averaging process
    for a model. All subclasses must implement the `avg_func` method.
    This class creates a copy of the provided module :attr:`model`
    on the device :attr:`device` and allows computing running averages of the
    parameters of the :attr:`model`.
    The code is referenced from: https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py

    In mmengine, we provide two ways to use the model averaging:

    1. Use the model averaging module in hook:
       We provide an EMAHook to apply the model averaging during training.
       Add ``custom_hooks=[dict(type='EMAHook')]`` to the config or the runner.
       The hook is implemented in mmengine/hooks/ema_hook.py

    2. Use the model averaging module directly in the algorithm. Take the ema
       teacher in semi-supervised learning as an example:

       >>> from mmengine.model import ExponentialMovingAverage
       >>> student = ResNet(depth=50)
       >>> # use ema model as teacher
       >>> ema_teacher = ExponentialMovingAverage(student)

    Args:
        model (nn.Module): The model to be averaged.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): if True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.
    """  # noqa: E501

    def __init__(self,
                 model: nn.Module,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        super().__init__()
        self.module = deepcopy(model)
        self.interval = interval
        if device is not None:
            self.module = self.module.to(device)
        self.register_buffer('steps',
                             torch.tensor(0, dtype=torch.long, device=device))
        self.update_buffers = update_buffers

    @abstractmethod
    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> Tensor:
        """Compute the average of the parameters. All subclasses must
        implement this method.

        Args:
            averaged_param (Tensor): The averaged parameters.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """

    def forward(self, *args, **kwargs):
        """Forward method of the averaged model."""
        return self.module(*args, **kwargs)

    def update_parameters(self, model: nn.Module) -> None:
        """Update the parameters of the model. This method will execute the
        ``avg_func`` to compute the new parameters and update the model's
        parameters.

        Args:
            model (nn.Module): The model whose parameters will be averaged.
        """
        if self.steps % self.interval == 0:
            avg_param = (
                itertools.chain(self.module.parameters(),
                                self.module.buffers())
                if self.update_buffers else self.parameters())
            src_param = (
                itertools.chain(model.parameters(), model.buffers())
                if self.update_buffers else model.parameters())
            for p_avg, p_src in zip(avg_param, src_param):
                device = p_avg.device
                p_src_ = p_src.detach().to(device)
                if self.steps == 0:
                    p_avg.detach().copy_(p_src_)
                else:
                    p_avg.detach().copy_(
                        self.avg_func(p_avg.detach(), p_src_,
                                      self.steps.to(device)))
        self.steps += 1


@MODELS.register_module()
class StochasticWeightAverage(BaseAveragedModel):
    """Implements the stochastic weight averaging (SWA) of the model.

    Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
    Wider Optima and Better Generalization, UAI 2018.
    <https://arxiv.org/abs/1803.05407>`_ by Pavel Izmailov, Dmitrii
    Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson.
    """

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> Tensor:
        """Compute the average of the parameters using stochastic weight
        average.

        Args:
            averaged_param (Tensor): The averaged parameters.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        Returns:
            Tensor: The averaged parameters.
        """
        return averaged_param + (source_param - averaged_param) / (
            steps // self.interval + 1)


@MODELS.register_module()
class ExponentialMovingAverage(BaseAveragedModel):
    """Implements the exponential moving average (EMA) of the model.

    All parameters are updated by the formula as below:

    .. math::

        Xema\_{t+1} = (1 - \text{momentum}) \times
        Xema\_{t} + \text{momentum} \times X_t

    Args:
        model (nn.Module): The model to be averaged.
        momentum (float): The momentum used for updating ema parameter.
            Ema's parameters are updated with the formula:
            `averaged_param = (1-momentum) * averaged_param + momentum *
            source_param`. Defaults to 0.0002.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): if True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.
    """  # noqa: W605

    def __init__(self,
                 model: nn.Module,
                 momentum: float = 0.0002,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        super().__init__(model, interval, device, update_buffers)
        assert 0.0 < momentum < 1.0, 'momentum must be in range (0.0, 1.0) '\
            f'but got {momentum}'
        self.momentum = momentum

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> Tensor:
        """Compute the moving average of the parameters using exponential
        moving average.

        Args:
            averaged_param (Tensor): The averaged parameters.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        Returns:
            Tensor: The averaged parameters.
        """
        return averaged_param * (1 -
                                 self.momentum) + source_param * self.momentum


@MODELS.register_module()
class MomentumAnnealingEMA(ExponentialMovingAverage):
    """Exponential moving average (EMA) with momentum annealing strategy.

    Args:
        model (nn.Module): The model to be averaged.
        momentum (float): The momentum used for updating ema parameter.
            Ema's parameters are updated with the formula:
            `averaged_param = (1-momentum) * averaged_param + momentum *
            source_param`. Defaults to 0.0002.
        gamma (int): Use a larger momentum early in training and gradually
            annealing to a smaller value to update the ema model smoothly. The
            momentum is calculated as max(momentum, gamma / (gamma + steps)).
            Defaults to 100.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): if True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.
    """

    def __init__(self,
                 model: nn.Module,
                 momentum: float = 0.0002,
                 gamma: int = 100,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        super().__init__(
            model=model,
            momentum=momentum,
            interval=interval,
            device=device,
            update_buffers=update_buffers)
        assert gamma > 0, f'gamma must be greater than 0, but got {gamma}'
        self.gamma = gamma

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> Tensor:
        """Compute the moving average of the parameters using the linear
        momentum strategy.

        Args:
            averaged_param (Tensor): The averaged parameters.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        Returns:
            Tensor: The averaged parameters.
        """
        momentum = max(self.momentum, self.gamma / (self.gamma + self.steps))
        return averaged_param * (1 - momentum) + source_param * momentum
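A small numeric check (not part of the commit) of the two `avg_func` rules above, run on a single-weight `nn.Linear`; all constants are made up for illustration.

# Numeric sketch (not part of this commit): EMA vs. SWA on one scalar weight.
import torch
import torch.nn as nn

from mmengine.model import ExponentialMovingAverage, StochasticWeightAverage

net = nn.Linear(1, 1, bias=False)
with torch.no_grad():
    net.weight.fill_(1.0)

ema = ExponentialMovingAverage(net, momentum=0.1)  # averaged copy starts at 1.0
swa = StochasticWeightAverage(net)

for value in (2.0, 3.0):
    with torch.no_grad():
        net.weight.fill_(value)
    ema.update_parameters(net)
    swa.update_parameters(net)

# EMA: the first update copies the source (2.0); the second gives
# 2.0 * (1 - 0.1) + 3.0 * 0.1 = 2.1
print(round(ema.module.weight.item(), 4))  # 2.1
# SWA: incremental mean of the observed weights 2.0 and 3.0 -> 2.5
print(round(swa.module.weight.item(), 4))  # 2.5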
mmengine/runner/runner.py
@@ -1233,10 +1233,12 @@ class Runner:
         self._val_loop = self.build_val_loop(
             self._val_loop)  # type: ignore
 
-        self.load_or_resume()
-
         # TODO: add a contextmanager to avoid calling `before_run` many times
         self.call_hook('before_run')
 
+        # make sure checkpoint-related hooks are triggered after `before_run`
+        self.load_or_resume()
+
         self.train_loop.run()  # type: ignore
         self.call_hook('after_run')
@@ -1250,9 +1252,11 @@ class Runner:
         self._val_loop = self.build_val_loop(self._val_loop)  # type: ignore
 
+        self.call_hook('before_run')
+
+        # make sure checkpoint-related hooks are triggered after `before_run`
         self.load_or_resume()
 
-        self.call_hook('before_run')
         self.val_loop.run()  # type: ignore
         self.call_hook('after_run')
@@ -1266,9 +1270,11 @@ class Runner:
         self._test_loop = self.build_test_loop(self._test_loop)  # type: ignore
 
+        self.call_hook('before_run')
+
+        # make sure checkpoint-related hooks are triggered after `before_run`
         self.load_or_resume()
 
-        self.call_hook('before_run')
         self.test_loop.run()  # type: ignore
         self.call_hook('after_run')
@@ -1535,7 +1541,7 @@ class Runner:
         checkpoint = _load_checkpoint(filename, map_location=map_location)
 
         # Add comments to describe the usage of `after_load_ckpt`
-        self.call_hook('after_load_ckpt', checkpoint=checkpoint)
+        self.call_hook('after_load_checkpoint', checkpoint=checkpoint)
 
         if is_model_wrapper(self.model):
             model = self.model.module
@@ -1627,8 +1633,7 @@ class Runner:
                 state_dict = _scheduler.state_dict()  # type: ignore
                 checkpoint['param_schedulers'].append(state_dict)
 
-        self.call_hook('before_save_ckpt', checkpoint=checkpoint)
-
+        self.call_hook('before_save_checkpoint', checkpoint=checkpoint)
         save_checkpoint(checkpoint, filepath)
         # in some environments, `os.symlink` is not supported, you may need to
         # set `create_symlink` to False
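The Runner changes above move `load_or_resume()` so that it runs after the 'before_run' hooks have fired. A toy, self-contained illustration (not the actual Runner code) of why that ordering matters for a hook like EMAHook, which builds its EMA model in `before_run` and touches it in `after_load_checkpoint`:

# Toy illustration (not the actual Runner code) of the hook-ordering fix.
class ToyEMAHook:

    def before_run(self):
        self.ema_state = {}                # EMAHook builds its ema_model here

    def after_load_checkpoint(self, checkpoint):
        self.ema_state.update(checkpoint)  # would fail if before_run had not run


hook = ToyEMAHook()
hook.before_run()                          # 1. 'before_run' hooks fire first
hook.after_load_checkpoint({'steps': 8})   # 2. load_or_resume() fires this later
print(hook.ema_state)                      # {'steps': 8}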
tests/test_hook/test_ema_hook.py  (new file, 142 lines)
@@ -0,0 +1,142 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase
from unittest.mock import Mock

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from mmengine.hooks import EMAHook
from mmengine.model import ExponentialMovingAverage
from mmengine.registry import DATASETS, MODEL_WRAPPERS
from mmengine.runner import Runner


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, data_batch, return_loss=False):
        inputs, labels = [], []
        for x in data_batch:
            inputs.append(x['inputs'])
            labels.append(x['data_sample'])

        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        inputs = torch.stack(inputs).to(device)
        labels = torch.stack(labels).to(device)
        outputs = self.linear(inputs)
        if return_loss:
            loss = (labels - outputs).sum()
            outputs = dict(loss=loss, log_vars=dict(loss=loss.item()))
            return outputs
        else:
            outputs = dict(log_vars=dict(a=1, b=0.5))
            return outputs


@DATASETS.register_module()
class DummyDataset(Dataset):
    METAINFO = dict()  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        return dict(inputs=self.data[index], data_sample=self.label[index])


class TestEMAHook(TestCase):

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def test_ema_hook(self):
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        model = ToyModel().to(device)
        evaluator = Mock()
        evaluator.evaluate = Mock(return_value=dict(acc=0.5))
        runner = Runner(
            model=model,
            train_dataloader=dict(
                dataset=dict(type='DummyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            val_dataloader=dict(
                dataset=dict(type='DummyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            val_evaluator=evaluator,
            work_dir=self.temp_dir.name,
            optimizer=torch.optim.Adam(ToyModel().parameters()),
            train_cfg=dict(by_epoch=True, max_epochs=2),
            val_cfg=dict(interval=1),
            default_hooks=dict(logger=None),
            custom_hooks=[dict(type='EMAHook', )],
            experiment_name='test1')
        runner.train()
        for hook in runner.hooks:
            if isinstance(hook, EMAHook):
                self.assertTrue(
                    isinstance(hook.ema_model, ExponentialMovingAverage))

        self.assertTrue(
            osp.exists(osp.join(self.temp_dir.name, 'epoch_2.pth')))
        checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
        self.assertTrue('ema_state_dict' in checkpoint)
        self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8)

        # load and testing
        runner = Runner(
            model=model,
            test_dataloader=dict(
                dataset=dict(type='DummyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            test_evaluator=evaluator,
            test_cfg=dict(),
            work_dir=self.temp_dir.name,
            load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'),
            default_hooks=dict(logger=None),
            custom_hooks=[dict(type='EMAHook')],
            experiment_name='test2')
        runner.test()

        @MODEL_WRAPPERS.register_module()
        class DummyWrapper(nn.Module):

            def __init__(self, model):
                super().__init__()
                self.module = model

            def forward(self, *args, **kwargs):
                return self.module(*args, **kwargs)

        # with model wrapper
        runner = Runner(
            model=DummyWrapper(model),
            test_dataloader=dict(
                dataset=dict(type='DummyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            test_evaluator=evaluator,
            test_cfg=dict(),
            work_dir=self.temp_dir.name,
            load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'),
            default_hooks=dict(logger=None),
            custom_hooks=[dict(type='EMAHook')],
            experiment_name='test3')
        runner.test()
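A quick sanity check (not part of the commit) of the `steps == 8` assertion in the test above: EMAHook updates the averaged model once per training iteration, and DummyDataset has 12 samples loaded with batch_size=3 for max_epochs=2.

# Arithmetic behind the `steps == 8` assertion above.
samples, batch_size, max_epochs = 12, 3, 2
iters_per_epoch = samples // batch_size      # 4 iterations per epoch
print(iters_per_epoch * max_epochs)          # 8 EMA update steps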
tests/test_model/test_averaged_model.py  (new file, 255 lines)
@@ -0,0 +1,255 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from unittest import TestCase

import torch

from mmengine.model import (ExponentialMovingAverage, MomentumAnnealingEMA,
                            StochasticWeightAverage)
from mmengine.testing import assert_allclose


class TestAveragedModel(TestCase):
    """Test the AveragedModel class.

    Some test cases are referenced from https://github.com/pytorch/pytorch/blob/master/test/test_optim.py
    """  # noqa: E501

    def _test_swa_model(self, net_device, avg_device):
        model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.Linear(5, 10)).to(net_device)

        averaged_model = StochasticWeightAverage(model, device=avg_device)
        averaged_params = [
            torch.zeros_like(param) for param in model.parameters()
        ]
        n_updates = 10
        for i in range(n_updates):
            for p, p_avg in zip(model.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                p_avg += p.detach() / n_updates
            averaged_model.update_parameters(model)

        for p_avg, p_swa in zip(averaged_params, averaged_model.parameters()):
            # Check that AveragedModel is on the correct device
            self.assertTrue(p_swa.device == avg_device)
            self.assertTrue(p.device == net_device)
            assert_allclose(p_avg, p_swa.to(p_avg.device))
        self.assertTrue(averaged_model.steps.device == avg_device)

    def test_averaged_model_all_devices(self):
        cpu = torch.device('cpu')
        self._test_swa_model(cpu, cpu)
        if torch.cuda.is_available():
            cuda = torch.device(0)
            self._test_swa_model(cuda, cpu)
            self._test_swa_model(cpu, cuda)
            self._test_swa_model(cuda, cuda)

    def test_swa_mixed_device(self):
        if not torch.cuda.is_available():
            return
        model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
        model[0].cuda()
        model[1].cpu()
        averaged_model = StochasticWeightAverage(model)
        averaged_params = [
            torch.zeros_like(param) for param in model.parameters()
        ]
        n_updates = 10
        for i in range(n_updates):
            for p, p_avg in zip(model.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                p_avg += p.detach() / n_updates
            averaged_model.update_parameters(model)

        for p_avg, p_swa in zip(averaged_params, averaged_model.parameters()):
            assert_allclose(p_avg, p_swa)
            # Check that AveragedModel is on the correct device
            self.assertTrue(p_avg.device == p_swa.device)

    def test_swa_state_dict(self):
        model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
        averaged_model = StochasticWeightAverage(model)
        averaged_model2 = StochasticWeightAverage(model)
        n_updates = 10
        for i in range(n_updates):
            for p in model.parameters():
                p.detach().add_(torch.randn_like(p))
            averaged_model.update_parameters(model)
        averaged_model2.load_state_dict(averaged_model.state_dict())
        for p_swa, p_swa2 in zip(averaged_model.parameters(),
                                 averaged_model2.parameters()):
            assert_allclose(p_swa, p_swa2)
        self.assertTrue(averaged_model.steps == averaged_model2.steps)

    def test_ema(self):
        # test invalid momentum
        with self.assertRaisesRegex(AssertionError,
                                    'momentum must be in range'):
            model = torch.nn.Sequential(
                torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
            ExponentialMovingAverage(model, momentum=3)
        # test EMA
        model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
        momentum = 0.1

        ema_model = ExponentialMovingAverage(model, momentum=momentum)
        averaged_params = [
            torch.zeros_like(param) for param in model.parameters()
        ]
        n_updates = 10
        for i in range(n_updates):
            updated_averaged_params = []
            for p, p_avg in zip(model.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                if i == 0:
                    updated_averaged_params.append(p.clone())
                else:
                    updated_averaged_params.append(
                        (p_avg * (1 - momentum) + p * momentum).clone())
            ema_model.update_parameters(model)
            averaged_params = updated_averaged_params

        for p_target, p_ema in zip(averaged_params, ema_model.parameters()):
            assert_allclose(p_target, p_ema)

    def test_ema_update_buffers(self):
        # Test EMA and update_buffers as True.
        model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10))
        momentum = 0.1

        ema_model = ExponentialMovingAverage(
            model, momentum=momentum, update_buffers=True)
        averaged_params = [
            torch.zeros_like(param)
            for param in itertools.chain(model.parameters(), model.buffers())
            if param.size() != torch.Size([])
        ]
        n_updates = 10
        for i in range(n_updates):
            updated_averaged_params = []
            params = [
                param for param in itertools.chain(model.parameters(),
                                                   model.buffers())
                if param.size() != torch.Size([])
            ]
            for p, p_avg in zip(params, averaged_params):
                p.detach().add_(torch.randn_like(p))
                if i == 0:
                    updated_averaged_params.append(p.clone())
                else:
                    updated_averaged_params.append(
                        (p_avg * (1 - momentum) + p * momentum).clone())
            ema_model.update_parameters(model)
            averaged_params = updated_averaged_params

        ema_params = [
            param for param in itertools.chain(ema_model.module.parameters(),
                                               ema_model.module.buffers())
            if param.size() != torch.Size([])
        ]
        for p_target, p_ema in zip(averaged_params, ema_params):
            assert_allclose(p_target, p_ema)

    def test_momentum_annealing_ema(self):
        model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10))
        # Test invalid gamma
        with self.assertRaisesRegex(AssertionError,
                                    'gamma must be greater than 0'):
            MomentumAnnealingEMA(model, gamma=-1)

        # Test EMA with momentum annealing.
        momentum = 0.1
        gamma = 4

        ema_model = MomentumAnnealingEMA(
            model, gamma=gamma, momentum=momentum, update_buffers=True)
        averaged_params = [
            torch.zeros_like(param)
            for param in itertools.chain(model.parameters(), model.buffers())
            if param.size() != torch.Size([])
        ]
        n_updates = 10
        for i in range(n_updates):
            updated_averaged_params = []
            params = [
                param for param in itertools.chain(model.parameters(),
                                                   model.buffers())
                if param.size() != torch.Size([])
            ]
            for p, p_avg in zip(params, averaged_params):
                p.detach().add_(torch.randn_like(p))
                if i == 0:
                    updated_averaged_params.append(p.clone())
                else:
                    m = max(momentum, gamma / (gamma + i))
                    updated_averaged_params.append(
                        (p_avg * (1 - m) + p * m).clone())
            ema_model.update_parameters(model)
            averaged_params = updated_averaged_params

        ema_params = [
            param for param in itertools.chain(ema_model.module.parameters(),
                                               ema_model.module.buffers())
            if param.size() != torch.Size([])
        ]
        for p_target, p_ema in zip(averaged_params, ema_params):
            assert_allclose(p_target, p_ema)

    def test_momentum_annealing_ema_with_interval(self):
        # Test EMA with momentum annealing and interval
        model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10))
        momentum = 0.1
        gamma = 4
        interval = 3

        ema_model = MomentumAnnealingEMA(
            model,
            gamma=gamma,
            momentum=momentum,
            interval=interval,
            update_buffers=True)
        averaged_params = [
            torch.zeros_like(param)
            for param in itertools.chain(model.parameters(), model.buffers())
            if param.size() != torch.Size([])
        ]
        n_updates = 10
        for i in range(n_updates):
            updated_averaged_params = []
            params = [
                param for param in itertools.chain(model.parameters(),
                                                   model.buffers())
                if param.size() != torch.Size([])
            ]
            for p, p_avg in zip(params, averaged_params):
                p.detach().add_(torch.randn_like(p))
                if i == 0:
                    updated_averaged_params.append(p.clone())
                elif i % interval == 0:
                    m = max(momentum, gamma / (gamma + i))
                    updated_averaged_params.append(
                        (p_avg * (1 - m) + p * m).clone())
                else:
                    updated_averaged_params.append(p_avg.clone())
            ema_model.update_parameters(model)
            averaged_params = updated_averaged_params

        ema_params = [
            param for param in itertools.chain(ema_model.module.parameters(),
                                               ema_model.module.buffers())
            if param.size() != torch.Size([])
        ]
        for p_target, p_ema in zip(averaged_params, ema_params):
            assert_allclose(p_target, p_ema)
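The SWA tests above accumulate a reference mean as `p_avg += p.detach() / n_updates`; the incremental rule in `StochasticWeightAverage.avg_func` yields the same arithmetic mean, as this plain-float check (not part of the commit) shows:

# Plain-float check that  avg <- avg + (x - avg) / (n + 1)  is an incremental
# arithmetic mean, matching the p_avg reference used in the tests above.
xs = [2.0, 4.0, 6.0]

incremental = 0.0
for n, x in enumerate(xs):
    incremental = incremental + (x - incremental) / (n + 1)

print(incremental)            # 4.0
print(sum(xs) / len(xs))      # 4.0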