[Feature] Support EMA and SWA. (#239)

* [Feature] Support EMA and SWA.

* add ema hook

* add avg model ut

* add more unit tests

* resolve comments

* fix warmup ema

* rename

* fix comments

* add assert

* fix typehint

* add comments
Author: RangiLyu, 2022-05-19 18:53:04 +08:00 (committed by GitHub)
parent 86ffc19c9c
commit 0279ac2e8d
7 changed files with 751 additions and 9 deletions

mmengine/hooks/__init__.py

@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .checkpoint_hook import CheckpointHook
+from .ema_hook import EMAHook
 from .empty_cache_hook import EmptyCacheHook
 from .hook import Hook
 from .iter_timer_hook import IterTimerHook
@@ -13,5 +14,5 @@ from .sync_buffer_hook import SyncBuffersHook
 __all__ = [
     'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
     'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook',
-    'LoggerHook', 'NaiveVisualizationHook'
+    'LoggerHook', 'NaiveVisualizationHook', 'EMAHook'
 ]

mmengine/hooks/ema_hook.py

@@ -0,0 +1,94 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import Optional
from mmengine.model import is_model_wrapper
from mmengine.registry import HOOKS, MODELS
from .hook import DATA_BATCH, Hook
@HOOKS.register_module()
class EMAHook(Hook):
"""A Hook to apply Exponential Moving Average (EMA) on the model during
training.
Note:
- EMAHook takes priority over CheckpointHook.
- The original model parameters are actually saved in the ema field after training.
Args:
ema_type (str): The type of EMA strategy to use. You can find the
supported strategies in ``mmengine.model.averaged_model``.
Defaults to 'ExponentialMovingAverage'.
"""
def __init__(self, ema_type: str = 'ExponentialMovingAverage', **kwargs):
self.ema_cfg = dict(type=ema_type, **kwargs)
def before_run(self, runner) -> None:
"""Create an ema copy of the model."""
model = runner.model
if is_model_wrapper(model):
model = model.module
self.src_model = model
self.ema_model = MODELS.build(
self.ema_cfg, default_args=dict(model=self.src_model))
def after_train_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[dict] = None) -> None:
"""Update ema parameter."""
self.ema_model.update_parameters(self.src_model)
def before_val_epoch(self, runner) -> None:
"""We load parameter values from ema model to source model before
validation."""
self._swap_ema_parameters()
def after_val_epoch(self, runner) -> None:
"""We recover source model's parameter from ema model after
validation."""
self._swap_ema_parameters()
def before_test_epoch(self, runner) -> None:
"""We load parameter values from ema model to source model before
test."""
self._swap_ema_parameters()
def after_test_epoch(self, runner) -> None:
"""We recover source model's parameter from ema model after test."""
self._swap_ema_parameters()
def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
"""Save ema parameters to checkpoint."""
# save ema parameters to the source model's state dict so that we can
# directly load the averaged model weights for deployment.
self._swap_ema_parameters()
checkpoint['ema_state_dict'] = self.ema_model.state_dict()
self._swap_ema_parameters()
def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
"""Resume ema parameters from checkpoint."""
self.ema_model.load_state_dict(checkpoint['ema_state_dict'])
# The original model parameters are actually saved in the ema field.
# Swap the weights back to resume the ema state.
self._swap_ema_parameters()
def _swap_ema_parameters(self) -> None:
"""Swap the parameter of model with ema_model."""
avg_param = (
itertools.chain(self.ema_model.module.parameters(),
self.ema_model.module.buffers())
if self.ema_model.update_buffers else
self.ema_model.module.parameters())
src_param = (
itertools.chain(self.src_model.parameters(),
self.src_model.buffers())
if self.ema_model.update_buffers else self.src_model.parameters())
for p_avg, p_src in zip(avg_param, src_param):
tmp = p_avg.data.clone()
p_avg.data.copy_(p_src.data)
p_src.data.copy_(tmp)
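For orientation, here is a minimal usage sketch (not part of this diff). The hook is enabled through ``custom_hooks``; any keyword arguments besides ``ema_type`` are forwarded to the averaged-model class via ``self.ema_cfg``, and the ``momentum`` value below is simply the ``ExponentialMovingAverage`` default.

# Sketch: enabling EMA through the hook in a runner config (illustrative values).
custom_hooks = [
    dict(type='EMAHook', ema_type='ExponentialMovingAverage', momentum=0.0002)
]

Per the Note in the class docstring, the saved checkpoint is intended to carry the averaged weights in ``state_dict`` for direct deployment, while ``ema_state_dict`` keeps the original training parameters that ``after_load_checkpoint`` swaps back when resuming.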

mmengine/model/__init__.py

@@ -1,5 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .averaged_model import (ExponentialMovingAverage, MomentumAnnealingEMA,
+                             StochasticWeightAverage)
 from .wrappers import (MMDataParallel, MMDistributedDataParallel,
                        is_model_wrapper)
-__all__ = ['MMDistributedDataParallel', 'MMDataParallel', 'is_model_wrapper']
+__all__ = [
+    'MMDistributedDataParallel', 'MMDataParallel', 'is_model_wrapper',
+    'StochasticWeightAverage', 'ExponentialMovingAverage',
+    'MomentumAnnealingEMA'
+]

mmengine/model/averaged_model.py

@@ -0,0 +1,239 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from abc import abstractmethod
from copy import deepcopy
from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor
from mmengine.registry import MODELS
class BaseAveragedModel(nn.Module):
"""A base class for averaging model weights.
Weight averaging, such as SWA and EMA, is a widely used technique for
training neural networks. This class implements the averaging process
for a model. All subclasses must implement the `avg_func` method.
This class creates a copy of the provided module :attr:`model`
on the device :attr:`device` and allows computing running averages of the
parameters of the :attr:`model`.
The code is referenced from: https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py
In mmengine, we provide two ways to use the model averaging:
1. Use the model averaging module in hook:
We provide an EMAHook to apply the model averaging during training.
Add ``custom_hooks=[dict(type='EMAHook')]`` to the config or the runner.
The hook is implemented in mmengine/hooks/ema_hook.py
2. Use the model averaging module directly in the algorithm. Take the ema
teacher in semi-supervised learning as an example:
>>> from mmengine.model import ExponentialMovingAverage
>>> student = ResNet(depth=50)
>>> # use ema model as teacher
>>> ema_teacher = ExponentialMovingAverage(student)
Args:
model (nn.Module): The model to be averaged.
interval (int): Interval between two updates. Defaults to 1.
device (torch.device, optional): If provided, the averaged model will
be stored on the :attr:`device`. Defaults to None.
update_buffers (bool): if True, it will compute running averages for
both the parameters and the buffers of the model. Defaults to
False.
""" # noqa: E501
def __init__(self,
model: nn.Module,
interval: int = 1,
device: Optional[torch.device] = None,
update_buffers: bool = False) -> None:
super().__init__()
self.module = deepcopy(model)
self.interval = interval
if device is not None:
self.module = self.module.to(device)
self.register_buffer('steps',
torch.tensor(0, dtype=torch.long, device=device))
self.update_buffers = update_buffers
@abstractmethod
def avg_func(self, averaged_param: Tensor, source_param: Tensor,
steps: int) -> Tensor:
"""Compute the average of the parameters. All subclasses must implement
this method.
Args:
averaged_param (Tensor): The averaged parameters.
source_param (Tensor): The source parameters.
steps (int): The number of times the parameters have been
updated.
"""
def forward(self, *args, **kwargs):
"""Forward method of the averaged model."""
return self.module(*args, **kwargs)
def update_parameters(self, model: nn.Module) -> None:
"""Update the parameters of the model. This method will execute the
``avg_func`` to compute the new parameters and update the model's
parameters.
Args:
model (nn.Module): The model whose parameters will be averaged.
"""
if self.steps % self.interval == 0:
avg_param = (
itertools.chain(self.module.parameters(),
self.module.buffers())
if self.update_buffers else self.parameters())
src_param = (
itertools.chain(model.parameters(), model.buffers())
if self.update_buffers else model.parameters())
for p_avg, p_src in zip(avg_param, src_param):
device = p_avg.device
p_src_ = p_src.detach().to(device)
if self.steps == 0:
p_avg.detach().copy_(p_src_)
else:
p_avg.detach().copy_(
self.avg_func(p_avg.detach(), p_src_,
self.steps.to(device)))
self.steps += 1
@MODELS.register_module()
class StochasticWeightAverage(BaseAveragedModel):
"""Implements the stochastic weight averaging (SWA) of the model.
Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
Wider Optima and Better Generalization, UAI 2018.
<https://arxiv.org/abs/1803.05407>`_ by Pavel Izmailov, Dmitrii
Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson.
"""
def avg_func(self, averaged_param: Tensor, source_param: Tensor,
steps: int) -> Tensor:
"""Compute the average of the parameters using stochastic weight
average.
Args:
averaged_param (Tensor): The averaged parameters.
source_param (Tensor): The source parameters.
steps (int): The number of times the parameters have been
updated.
Returns:
Tensor: The averaged parameters.
"""
return averaged_param + (source_param - averaged_param) / (
steps // self.interval + 1)
@MODELS.register_module()
class ExponentialMovingAverage(BaseAveragedModel):
"""Implements the exponential moving average (EMA) of the model.
All parameters are updated by the formula as below:
.. math::
Xema\_{t+1} = (1 - \text{momentum}) \times
Xema\_{t} + \text{momentum} \times X_t
Args:
model (nn.Module): The model to be averaged.
momentum (float): The momentum used for updating the ema parameters.
The ema parameters are updated with the formula:
`averaged_param = (1-momentum) * averaged_param + momentum *
source_param`. Defaults to 0.0002.
interval (int): Interval between two updates. Defaults to 1.
device (torch.device, optional): If provided, the averaged model will
be stored on the :attr:`device`. Defaults to None.
update_buffers (bool): if True, it will compute running averages for
both the parameters and the buffers of the model. Defaults to
False.
""" # noqa: W605
def __init__(self,
model: nn.Module,
momentum: float = 0.0002,
interval: int = 1,
device: Optional[torch.device] = None,
update_buffers: bool = False) -> None:
super().__init__(model, interval, device, update_buffers)
assert 0.0 < momentum < 1.0, 'momentum must be in range (0.0, 1.0) '\
f'but got {momentum}'
self.momentum = momentum
def avg_func(self, averaged_param: Tensor, source_param: Tensor,
steps: int) -> Tensor:
"""Compute the moving average of the parameters using exponential
moving average.
Args:
averaged_param (Tensor): The averaged parameters.
source_param (Tensor): The source parameters.
steps (int): The number of times the parameters have been
updated.
Returns:
Tensor: The averaged parameters.
"""
return averaged_param * (1 -
self.momentum) + source_param * self.momentum
@MODELS.register_module()
class MomentumAnnealingEMA(ExponentialMovingAverage):
"""Exponential moving average (EMA) with momentum annealing strategy.
Args:
model (nn.Module): The model to be averaged.
momentum (float): The momentum used for updating the ema parameters.
The ema parameters are updated with the formula:
`averaged_param = (1-momentum) * averaged_param + momentum *
source_param`. Defaults to 0.0002.
gamma (int): Use a larger momentum early in training and gradually
anneal it to a smaller value to update the ema model smoothly. The
momentum is calculated as ``max(momentum, gamma / (gamma + steps))``.
Defaults to 100.
interval (int): Interval between two updates. Defaults to 1.
device (torch.device, optional): If provided, the averaged model will
be stored on the :attr:`device`. Defaults to None.
update_buffers (bool): if True, it will compute running averages for
both the parameters and the buffers of the model. Defaults to
False.
"""
def __init__(self,
model: nn.Module,
momentum: float = 0.0002,
gamma: int = 100,
interval: int = 1,
device: Optional[torch.device] = None,
update_buffers: bool = False) -> None:
super().__init__(
model=model,
momentum=momentum,
interval=interval,
device=device,
update_buffers=update_buffers)
assert gamma > 0, f'gamma must be greater than 0, but got {gamma}'
self.gamma = gamma
def avg_func(self, averaged_param: Tensor, source_param: Tensor,
steps: int) -> Tensor:
"""Compute the moving average of the parameters using the linear
momentum strategy.
Args:
averaged_param (Tensor): The averaged parameters.
source_param (Tensor): The source parameters.
steps (int): The number of times the parameters have been
updated.
Returns:
Tensor: The averaged parameters.
"""
momentum = max(self.momentum, self.gamma / (self.gamma + self.steps))
return averaged_param * (1 - momentum) + source_param * momentum
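The averaging classes can also be used without the hook. The sketch below (not from this diff; the toy model and hyper-parameters are illustrative) exercises the API defined above: wrap a model, call ``update_parameters`` after each optimizer step, and subclass ``BaseAveragedModel`` when a custom strategy is needed.

# Sketch: direct use of the averaged-model API defined in this file.
import torch
import torch.nn as nn

from mmengine.model import ExponentialMovingAverage
from mmengine.model.averaged_model import BaseAveragedModel

model = nn.Linear(2, 2)
ema_model = ExponentialMovingAverage(model, momentum=0.001)

for _ in range(10):
    # ... an optimizer step on `model` would go here ...
    ema_model.update_parameters(model)  # `steps` increments once per call

# `ema_model.module` holds the averaged copy; calling the wrapper forwards
# through it.
out = ema_model(torch.randn(1, 2))

# A custom strategy only needs `avg_func`; this running mean is the same
# rule StochasticWeightAverage implements above.
class RunningMean(BaseAveragedModel):

    def avg_func(self, averaged_param, source_param, steps):
        return averaged_param + (source_param - averaged_param) / (
            steps // self.interval + 1)

MomentumAnnealingEMA replaces the fixed momentum with ``max(momentum, gamma / (gamma + steps))``: large updates early in training that decay towards the ``momentum`` floor as ``steps`` grows.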

mmengine/runner/runner.py

@@ -1233,10 +1233,12 @@ class Runner:
         self._val_loop = self.build_val_loop(
             self._val_loop)  # type: ignore
-        self.load_or_resume()
+        # TODO: add a contextmanager to avoid calling `before_run` many times
         self.call_hook('before_run')
+        # make sure checkpoint-related hooks are triggered after `before_run`
+        self.load_or_resume()
         self.train_loop.run()  # type: ignore
         self.call_hook('after_run')
@@ -1250,9 +1252,11 @@ class Runner:
         self._val_loop = self.build_val_loop(self._val_loop)  # type: ignore
+        self.call_hook('before_run')
+        # make sure checkpoint-related hooks are triggered after `before_run`
         self.load_or_resume()
-        self.call_hook('before_run')
         self.val_loop.run()  # type: ignore
         self.call_hook('after_run')
@@ -1266,9 +1270,11 @@ class Runner:
         self._test_loop = self.build_test_loop(self._test_loop)  # type: ignore
+        self.call_hook('before_run')
+        # make sure checkpoint-related hooks are triggered after `before_run`
         self.load_or_resume()
-        self.call_hook('before_run')
         self.test_loop.run()  # type: ignore
         self.call_hook('after_run')
@@ -1535,7 +1541,7 @@ class Runner:
         checkpoint = _load_checkpoint(filename, map_location=map_location)
         # Add comments to describe the usage of `after_load_ckpt`
-        self.call_hook('after_load_ckpt', checkpoint=checkpoint)
+        self.call_hook('after_load_checkpoint', checkpoint=checkpoint)
         if is_model_wrapper(self.model):
             model = self.model.module
@@ -1627,8 +1633,7 @@ class Runner:
             state_dict = _scheduler.state_dict()  # type: ignore
             checkpoint['param_schedulers'].append(state_dict)
-        self.call_hook('before_save_ckpt', checkpoint=checkpoint)
+        self.call_hook('before_save_checkpoint', checkpoint=checkpoint)
         save_checkpoint(checkpoint, filepath)
         # in some environments, `os.symlink` is not supported, you may need to
         # set `create_symlink` to False
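This reordering matters for EMAHook: ``before_run`` is what builds ``self.ema_model``, and ``after_load_checkpoint`` needs that attribute when a checkpoint is loaded or resumed. A schematic of the resulting call order (comments only, not code from this diff):

# train() after this change; val() and test() follow the same pattern:
#   call_hook('before_run')   # EMAHook.before_run builds self.ema_model
#   load_or_resume()          # may trigger after_load_checkpoint, which
#                             # restores checkpoint['ema_state_dict']
#   train_loop.run()          # after_train_iter updates the EMA weights
#   call_hook('after_run')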

tests/test_hooks/test_ema_hook.py

@@ -0,0 +1,142 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase
from unittest.mock import Mock
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from mmengine.hooks import EMAHook
from mmengine.model import ExponentialMovingAverage
from mmengine.registry import DATASETS, MODEL_WRAPPERS
from mmengine.runner import Runner
class ToyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(2, 1)
def forward(self, data_batch, return_loss=False):
inputs, labels = [], []
for x in data_batch:
inputs.append(x['inputs'])
labels.append(x['data_sample'])
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
inputs = torch.stack(inputs).to(device)
labels = torch.stack(labels).to(device)
outputs = self.linear(inputs)
if return_loss:
loss = (labels - outputs).sum()
outputs = dict(loss=loss, log_vars=dict(loss=loss.item()))
return outputs
else:
outputs = dict(log_vars=dict(a=1, b=0.5))
return outputs
@DATASETS.register_module()
class DummyDataset(Dataset):
METAINFO = dict() # type: ignore
data = torch.randn(12, 2)
label = torch.ones(12)
def __len__(self):
return self.data.size(0)
def __getitem__(self, index):
return dict(inputs=self.data[index], data_sample=self.label[index])
class TestEMAHook(TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
def tearDown(self):
self.temp_dir.cleanup()
def test_ema_hook(self):
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = ToyModel().to(device)
evaluator = Mock()
evaluator.evaluate = Mock(return_value=dict(acc=0.5))
runner = Runner(
model=model,
train_dataloader=dict(
dataset=dict(type='DummyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
batch_size=3,
num_workers=0),
val_dataloader=dict(
dataset=dict(type='DummyDataset'),
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=3,
num_workers=0),
val_evaluator=evaluator,
work_dir=self.temp_dir.name,
optimizer=torch.optim.Adam(ToyModel().parameters()),
train_cfg=dict(by_epoch=True, max_epochs=2),
val_cfg=dict(interval=1),
default_hooks=dict(logger=None),
custom_hooks=[dict(type='EMAHook', )],
experiment_name='test1')
runner.train()
for hook in runner.hooks:
if isinstance(hook, EMAHook):
self.assertTrue(
isinstance(hook.ema_model, ExponentialMovingAverage))
self.assertTrue(
osp.exists(osp.join(self.temp_dir.name, 'epoch_2.pth')))
checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
self.assertTrue('ema_state_dict' in checkpoint)
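# 12 samples with batch_size 3 give 4 iterations per epoch, so 2 epochs
# produce 8 EMA updates, hence steps == 8.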
self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8)
# load the checkpoint and test
runner = Runner(
model=model,
test_dataloader=dict(
dataset=dict(type='DummyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
batch_size=3,
num_workers=0),
test_evaluator=evaluator,
test_cfg=dict(),
work_dir=self.temp_dir.name,
load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'),
default_hooks=dict(logger=None),
custom_hooks=[dict(type='EMAHook')],
experiment_name='test2')
runner.test()
@MODEL_WRAPPERS.register_module()
class DummyWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.module = model
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
# with model wrapper
runner = Runner(
model=DummyWrapper(model),
test_dataloader=dict(
dataset=dict(type='DummyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
batch_size=3,
num_workers=0),
test_evaluator=evaluator,
test_cfg=dict(),
work_dir=self.temp_dir.name,
load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'),
default_hooks=dict(logger=None),
custom_hooks=[dict(type='EMAHook')],
experiment_name='test3')
runner.test()

tests/test_model/test_averaged_model.py

@@ -0,0 +1,255 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from unittest import TestCase
import torch
from mmengine.model import (ExponentialMovingAverage, MomentumAnnealingEMA,
StochasticWeightAverage)
from mmengine.testing import assert_allclose
class TestAveragedModel(TestCase):
"""Test the AveragedModel class.
Some test cases are referenced from https://github.com/pytorch/pytorch/blob/master/test/test_optim.py
""" # noqa: E501
def _test_swa_model(self, net_device, avg_device):
model = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3),
torch.nn.Linear(5, 10)).to(net_device)
averaged_model = StochasticWeightAverage(model, device=avg_device)
averaged_params = [
torch.zeros_like(param) for param in model.parameters()
]
n_updates = 10
for i in range(n_updates):
for p, p_avg in zip(model.parameters(), averaged_params):
p.detach().add_(torch.randn_like(p))
p_avg += p.detach() / n_updates
averaged_model.update_parameters(model)
for p_avg, p_swa in zip(averaged_params, averaged_model.parameters()):
# Check that AveragedModel is on the correct device
self.assertTrue(p_swa.device == avg_device)
self.assertTrue(p.device == net_device)
assert_allclose(p_avg, p_swa.to(p_avg.device))
self.assertTrue(averaged_model.steps.device == avg_device)
def test_averaged_model_all_devices(self):
cpu = torch.device('cpu')
self._test_swa_model(cpu, cpu)
if torch.cuda.is_available():
cuda = torch.device(0)
self._test_swa_model(cuda, cpu)
self._test_swa_model(cpu, cuda)
self._test_swa_model(cuda, cuda)
def test_swa_mixed_device(self):
if not torch.cuda.is_available():
return
model = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
model[0].cuda()
model[1].cpu()
averaged_model = StochasticWeightAverage(model)
averaged_params = [
torch.zeros_like(param) for param in model.parameters()
]
n_updates = 10
for i in range(n_updates):
for p, p_avg in zip(model.parameters(), averaged_params):
p.detach().add_(torch.randn_like(p))
p_avg += p.detach() / n_updates
averaged_model.update_parameters(model)
for p_avg, p_swa in zip(averaged_params, averaged_model.parameters()):
assert_allclose(p_avg, p_swa)
# Check that AveragedModel is on the correct device
self.assertTrue(p_avg.device == p_swa.device)
def test_swa_state_dict(self):
model = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
averaged_model = StochasticWeightAverage(model)
averaged_model2 = StochasticWeightAverage(model)
n_updates = 10
for i in range(n_updates):
for p in model.parameters():
p.detach().add_(torch.randn_like(p))
averaged_model.update_parameters(model)
averaged_model2.load_state_dict(averaged_model.state_dict())
for p_swa, p_swa2 in zip(averaged_model.parameters(),
averaged_model2.parameters()):
assert_allclose(p_swa, p_swa2)
self.assertTrue(averaged_model.steps == averaged_model2.steps)
def test_ema(self):
# test invalid momentum
with self.assertRaisesRegex(AssertionError,
'momentum must be in range'):
model = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
ExponentialMovingAverage(model, momentum=3)
# test EMA
model = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
momentum = 0.1
ema_model = ExponentialMovingAverage(model, momentum=momentum)
averaged_params = [
torch.zeros_like(param) for param in model.parameters()
]
n_updates = 10
for i in range(n_updates):
updated_averaged_params = []
for p, p_avg in zip(model.parameters(), averaged_params):
p.detach().add_(torch.randn_like(p))
if i == 0:
updated_averaged_params.append(p.clone())
else:
updated_averaged_params.append(
(p_avg * (1 - momentum) + p * momentum).clone())
ema_model.update_parameters(model)
averaged_params = updated_averaged_params
for p_target, p_ema in zip(averaged_params, ema_model.parameters()):
assert_allclose(p_target, p_ema)
def test_ema_update_buffers(self):
# Test EMA and update_buffers as True.
model = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3),
torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10))
momentum = 0.1
ema_model = ExponentialMovingAverage(
model, momentum=momentum, update_buffers=True)
averaged_params = [
torch.zeros_like(param)
for param in itertools.chain(model.parameters(), model.buffers())
if param.size() != torch.Size([])
]
n_updates = 10
for i in range(n_updates):
updated_averaged_params = []
params = [
param for param in itertools.chain(model.parameters(),
model.buffers())
if param.size() != torch.Size([])
]
for p, p_avg in zip(params, averaged_params):
p.detach().add_(torch.randn_like(p))
if i == 0:
updated_averaged_params.append(p.clone())
else:
updated_averaged_params.append(
(p_avg * (1 - momentum) + p * momentum).clone())
ema_model.update_parameters(model)
averaged_params = updated_averaged_params
ema_params = [
param for param in itertools.chain(ema_model.module.parameters(),
ema_model.module.buffers())
if param.size() != torch.Size([])
]
for p_target, p_ema in zip(averaged_params, ema_params):
assert_allclose(p_target, p_ema)
def test_momentum_annealing_ema(self):
model = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3),
torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10))
# Test invalid gamma
with self.assertRaisesRegex(AssertionError,
'gamma must be greater than 0'):
MomentumAnnealingEMA(model, gamma=-1)
# Test EMA with momentum annealing.
momentum = 0.1
gamma = 4
ema_model = MomentumAnnealingEMA(
model, gamma=gamma, momentum=momentum, update_buffers=True)
averaged_params = [
torch.zeros_like(param)
for param in itertools.chain(model.parameters(), model.buffers())
if param.size() != torch.Size([])
]
n_updates = 10
for i in range(n_updates):
updated_averaged_params = []
params = [
param for param in itertools.chain(model.parameters(),
model.buffers())
if param.size() != torch.Size([])
]
for p, p_avg in zip(params, averaged_params):
p.detach().add_(torch.randn_like(p))
if i == 0:
updated_averaged_params.append(p.clone())
else:
m = max(momentum, gamma / (gamma + i))
updated_averaged_params.append(
(p_avg * (1 - m) + p * m).clone())
ema_model.update_parameters(model)
averaged_params = updated_averaged_params
ema_params = [
param for param in itertools.chain(ema_model.module.parameters(),
ema_model.module.buffers())
if param.size() != torch.Size([])
]
for p_target, p_ema in zip(averaged_params, ema_params):
assert_allclose(p_target, p_ema)
def test_momentum_annealing_ema_with_interval(self):
# Test EMA with momentum annealing and interval
model = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3),
torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10))
momentum = 0.1
gamma = 4
interval = 3
ema_model = MomentumAnnealingEMA(
model,
gamma=gamma,
momentum=momentum,
interval=interval,
update_buffers=True)
averaged_params = [
torch.zeros_like(param)
for param in itertools.chain(model.parameters(), model.buffers())
if param.size() != torch.Size([])
]
n_updates = 10
for i in range(n_updates):
updated_averaged_params = []
params = [
param for param in itertools.chain(model.parameters(),
model.buffers())
if param.size() != torch.Size([])
]
for p, p_avg in zip(params, averaged_params):
p.detach().add_(torch.randn_like(p))
if i == 0:
updated_averaged_params.append(p.clone())
elif i % interval == 0:
m = max(momentum, gamma / (gamma + i))
updated_averaged_params.append(
(p_avg * (1 - m) + p * m).clone())
else:
updated_averaged_params.append(p_avg.clone())
ema_model.update_parameters(model)
averaged_params = updated_averaged_params
ema_params = [
param for param in itertools.chain(ema_model.module.parameters(),
ema_model.module.buffers())
if param.size() != torch.Size([])
]
for p_target, p_ema in zip(averaged_params, ema_params):
assert_allclose(p_target, p_ema)