From 0279ac2e8de45d6963a7429033c5223a74ee2aa6 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Thu, 19 May 2022 18:53:04 +0800 Subject: [PATCH] [Feature] Support EMA and SWA. (#239) * [Feature] Support EMA and SWA. * add ema hook * add avg model ut * add more unit tests * resolve comments * fix warmup ema * rename * fix comments * add assert * fix typehint * add comments --- mmengine/hooks/__init__.py | 3 +- mmengine/hooks/ema_hook.py | 94 +++++++++ mmengine/model/__init__.py | 8 +- mmengine/model/averaged_model.py | 239 ++++++++++++++++++++++ mmengine/runner/runner.py | 19 +- tests/test_hook/test_ema_hook.py | 142 +++++++++++++ tests/test_model/test_averaged_model.py | 255 ++++++++++++++++++++++++ 7 files changed, 751 insertions(+), 9 deletions(-) create mode 100644 mmengine/hooks/ema_hook.py create mode 100644 mmengine/model/averaged_model.py create mode 100644 tests/test_hook/test_ema_hook.py create mode 100644 tests/test_model/test_averaged_model.py diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py index 44698bbf..5252e052 100644 --- a/mmengine/hooks/__init__.py +++ b/mmengine/hooks/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .checkpoint_hook import CheckpointHook +from .ema_hook import EMAHook from .empty_cache_hook import EmptyCacheHook from .hook import Hook from .iter_timer_hook import IterTimerHook @@ -13,5 +14,5 @@ from .sync_buffer_hook import SyncBuffersHook __all__ = [ 'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook', 'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook', - 'LoggerHook', 'NaiveVisualizationHook' + 'LoggerHook', 'NaiveVisualizationHook', 'EMAHook' ] diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py new file mode 100644 index 00000000..6d8a398e --- /dev/null +++ b/mmengine/hooks/ema_hook.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import itertools +from typing import Optional + +from mmengine.model import is_model_wrapper +from mmengine.registry import HOOKS, MODELS +from .hook import DATA_BATCH, Hook + + +@HOOKS.register_module() +class EMAHook(Hook): + """A Hook to apply Exponential Moving Average (EMA) on the model during + training. + + Note: + - EMAHook takes priority over CheckpointHook. + - The original model parameters are actually saved in ema field after + train. + + Args: + ema_type (str): The type of EMA strategy to use. You can find the + supported strategies in ``mmengine.model.averaged_model``. 
+ Defaults to 'ExponentialMovingAverage' + """ + + def __init__(self, ema_type: str = 'ExponentialMovingAverage', **kwargs): + self.ema_cfg = dict(type=ema_type, **kwargs) + + def before_run(self, runner) -> None: + """Create an ema copy of the model.""" + model = runner.model + if is_model_wrapper(model): + model = model.module + self.src_model = model + self.ema_model = MODELS.build( + self.ema_cfg, default_args=dict(model=self.src_model)) + + def after_train_iter(self, + runner, + batch_idx: int, + data_batch: DATA_BATCH = None, + outputs: Optional[dict] = None) -> None: + """Update ema parameter.""" + self.ema_model.update_parameters(self.src_model) + + def before_val_epoch(self, runner) -> None: + """We load parameter values from ema model to source model before + validation.""" + self._swap_ema_parameters() + + def after_val_epoch(self, runner) -> None: + """We recover source model's parameter from ema model after + validation.""" + self._swap_ema_parameters() + + def before_test_epoch(self, runner) -> None: + """We load parameter values from ema model to source model before + test.""" + self._swap_ema_parameters() + + def after_test_epoch(self, runner) -> None: + """We recover source model's parameter from ema model after test.""" + self._swap_ema_parameters() + + def before_save_checkpoint(self, runner, checkpoint: dict) -> None: + """Save ema parameters to checkpoint.""" + # save ema parameters to the source model's state dict so that we can + # directly load the averaged model weights for deployment. + self._swap_ema_parameters() + checkpoint['ema_state_dict'] = self.ema_model.state_dict() + self._swap_ema_parameters() + + def after_load_checkpoint(self, runner, checkpoint: dict) -> None: + """Resume ema parameters from checkpoint.""" + self.ema_model.load_state_dict(checkpoint['ema_state_dict']) + # The original model parameters are actually saved in ema field. + # swap the weights back to resume ema state. + self._swap_ema_parameters() + + def _swap_ema_parameters(self) -> None: + """Swap the parameter of model with ema_model.""" + avg_param = ( + itertools.chain(self.ema_model.module.parameters(), + self.ema_model.module.buffers()) + if self.ema_model.update_buffers else + self.ema_model.module.parameters()) + src_param = ( + itertools.chain(self.src_model.parameters(), + self.src_model.buffers()) + if self.ema_model.update_buffers else self.src_model.parameters()) + for p_avg, p_src in zip(avg_param, src_param): + tmp = p_avg.data.clone() + p_avg.data.copy_(p_src.data) + p_src.data.copy_(tmp) diff --git a/mmengine/model/__init__.py b/mmengine/model/__init__.py index 3620b7ff..082f9131 100644 --- a/mmengine/model/__init__.py +++ b/mmengine/model/__init__.py @@ -1,5 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .averaged_model import (ExponentialMovingAverage, MomentumAnnealingEMA, + StochasticWeightAverage) from .wrappers import (MMDataParallel, MMDistributedDataParallel, is_model_wrapper) -__all__ = ['MMDistributedDataParallel', 'MMDataParallel', 'is_model_wrapper'] +__all__ = [ + 'MMDistributedDataParallel', 'MMDataParallel', 'is_model_wrapper', + 'StochasticWeightAverage', 'ExponentialMovingAverage', + 'MomentumAnnealingEMA' +] diff --git a/mmengine/model/averaged_model.py b/mmengine/model/averaged_model.py new file mode 100644 index 00000000..e42defe5 --- /dev/null +++ b/mmengine/model/averaged_model.py @@ -0,0 +1,239 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import itertools +from abc import abstractmethod +from copy import deepcopy +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from mmengine.registry import MODELS + + +class BaseAveragedModel(nn.Module): + """A base class for averaging model weights. + + Weight averaging, such as SWA and EMA, is a widely used technique for + training neural networks. This class implements the averaging process + for a model. All subclasses must implement the `avg_func` method. + This class creates a copy of the provided module :attr:`model` + on the device :attr:`device` and allows computing running averages of the + parameters of the :attr:`model`. + The code is referenced from: https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py + + In mmengine, we provide two ways to use the model averaging: + 1. Use the model averaging module in hook: + We provide an EMAHook to apply the model averaging during training. + Add ``custom_hooks=[dict(type='EMAHook')]`` to the config or the runner. + The hook is implemented in mmengine/hooks/ema_hook.py + + 2. Use the model averaging module directly in the algorithm. Take the ema + teacher in semi-supervise as an example: + >>> from mmengine.model import ExponentialMovingAverage + >>> student = ResNet(depth=50) + >>> # use ema model as teacher + >>> ema_teacher = ExponentialMovingAverage(student) + + Args: + model (nn.Module): The model to be averaged. + interval (int): Interval between two updates. Defaults to 1. + device (torch.device, optional): If provided, the averaged model will + be stored on the :attr:`device`. Defaults to None. + update_buffers (bool): if True, it will compute running averages for + both the parameters and the buffers of the model. Defaults to + False. + """ # noqa: E501 + + def __init__(self, + model: nn.Module, + interval: int = 1, + device: Optional[torch.device] = None, + update_buffers: bool = False) -> None: + super().__init__() + self.module = deepcopy(model) + self.interval = interval + if device is not None: + self.module = self.module.to(device) + self.register_buffer('steps', + torch.tensor(0, dtype=torch.long, device=device)) + self.update_buffers = update_buffers + + @abstractmethod + def avg_func(self, averaged_param: Tensor, source_param: Tensor, + steps: int) -> Tensor: + """Compute the average of the parameters. All subclasses must implement + this method. + + Args: + averaged_param (Tensor): The averaged parameters. + source_param (Tensor): The source parameters. + steps (int): The number of times the parameters have been + updated. + """ + + def forward(self, *args, **kwargs): + """Forward method of the averaged model.""" + return self.module(*args, **kwargs) + + def update_parameters(self, model: nn.Module) -> None: + """Update the parameters of the model. This method will execute the + ``avg_func`` to compute the new parameters and update the model's + parameters. + + Args: + model (nn.Module): The model whose parameters will be averaged. 
+ """ + if self.steps % self.interval == 0: + avg_param = ( + itertools.chain(self.module.parameters(), + self.module.buffers()) + if self.update_buffers else self.parameters()) + src_param = ( + itertools.chain(model.parameters(), model.buffers()) + if self.update_buffers else model.parameters()) + for p_avg, p_src in zip(avg_param, src_param): + device = p_avg.device + p_src_ = p_src.detach().to(device) + if self.steps == 0: + p_avg.detach().copy_(p_src_) + else: + p_avg.detach().copy_( + self.avg_func(p_avg.detach(), p_src_, + self.steps.to(device))) + self.steps += 1 + + +@MODELS.register_module() +class StochasticWeightAverage(BaseAveragedModel): + """Implements the stochastic weight averaging (SWA) of the model. + + Stochastic Weight Averaging was proposed in `Averaging Weights Leads to + Wider Optima and Better Generalization, UAI 2018. + `_ by Pavel Izmailov, Dmitrii + Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson. + """ + + def avg_func(self, averaged_param: Tensor, source_param: Tensor, + steps: int) -> Tensor: + """Compute the average of the parameters using stochastic weight + average. + + Args: + averaged_param (Tensor): The averaged parameters. + source_param (Tensor): The source parameters. + steps (int): The number of times the parameters have been + updated. + Returns: + Tensor: The averaged parameters. + """ + return averaged_param + (source_param - averaged_param) / ( + steps // self.interval + 1) + + +@MODELS.register_module() +class ExponentialMovingAverage(BaseAveragedModel): + """Implements the exponential moving average (EMA) of the model. + + All parameters are updated by the formula as below: + + .. math:: + + Xema\_{t+1} = (1 - \text{momentum}) \times + Xema\_{t} + \text{momentum} \times X_t + + Args: + model (nn.Module): The model to be averaged. + momentum (float): The momentum used for updating ema parameter. + Ema's parameter are updated with the formula: + `averaged_param = (1-momentum) * averaged_param + momentum * + source_param`. Defaults to 0.0002. + interval (int): Interval between two updates. Defaults to 1. + device (torch.device, optional): If provided, the averaged model will + be stored on the :attr:`device`. Defaults to None. + update_buffers (bool): if True, it will compute running averages for + both the parameters and the buffers of the model. Defaults to + False. + """ # noqa: W605 + + def __init__(self, + model: nn.Module, + momentum: float = 0.0002, + interval: int = 1, + device: Optional[torch.device] = None, + update_buffers: bool = False) -> None: + super().__init__(model, interval, device, update_buffers) + assert 0.0 < momentum < 1.0, 'momentum must be in range (0.0, 1.0)'\ + f'but got {momentum}' + self.momentum = momentum + + def avg_func(self, averaged_param: Tensor, source_param: Tensor, + steps: int) -> Tensor: + """Compute the moving average of the parameters using exponential + moving average. + + Args: + averaged_param (Tensor): The averaged parameters. + source_param (Tensor): The source parameters. + steps (int): The number of times the parameters have been + updated. + Returns: + Tensor: The averaged parameters. + """ + return averaged_param * (1 - + self.momentum) + source_param * self.momentum + + +@MODELS.register_module() +class MomentumAnnealingEMA(ExponentialMovingAverage): + """Exponential moving average (EMA) with momentum annealing strategy. + + Args: + model (nn.Module): The model to be averaged. + momentum (float): The momentum used for updating ema parameter. 
+ Ema's parameter are updated with the formula: + `averaged_param = (1-momentum) * averaged_param + momentum * + source_param`. Defaults to 0.0002. + gamma (int): Use a larger momentum early in training and gradually + annealing to a smaller value to update the ema model smoothly. The + momentum is calculated as max(momentum, gamma / (gamma + steps)) + Defaults to 100. + interval (int): Interval between two updates. Defaults to 1. + device (torch.device, optional): If provided, the averaged model will + be stored on the :attr:`device`. Defaults to None. + update_buffers (bool): if True, it will compute running averages for + both the parameters and the buffers of the model. Defaults to + False. + """ + + def __init__(self, + model: nn.Module, + momentum: float = 0.0002, + gamma: int = 100, + interval: int = 1, + device: Optional[torch.device] = None, + update_buffers: bool = False) -> None: + super().__init__( + model=model, + momentum=momentum, + interval=interval, + device=device, + update_buffers=update_buffers) + assert gamma > 0, f'gamma must be greater than 0, but got {gamma}' + self.gamma = gamma + + def avg_func(self, averaged_param: Tensor, source_param: Tensor, + steps: int) -> Tensor: + """Compute the moving average of the parameters using the linear + momentum strategy. + + Args: + averaged_param (Tensor): The averaged parameters. + source_param (Tensor): The source parameters. + steps (int): The number of times the parameters have been + updated. + Returns: + Tensor: The averaged parameters. + """ + momentum = max(self.momentum, self.gamma / (self.gamma + self.steps)) + return averaged_param * (1 - momentum) + source_param * momentum diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index ead5c6e1..91b5a6d2 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -1233,10 +1233,12 @@ class Runner: self._val_loop = self.build_val_loop( self._val_loop) # type: ignore - self.load_or_resume() - # TODO: add a contextmanager to avoid calling `before_run` many times self.call_hook('before_run') + + # make sure checkpoint-related hooks are triggered after `before_run` + self.load_or_resume() + self.train_loop.run() # type: ignore self.call_hook('after_run') @@ -1250,9 +1252,11 @@ class Runner: self._val_loop = self.build_val_loop(self._val_loop) # type: ignore + self.call_hook('before_run') + + # make sure checkpoint-related hooks are triggered after `before_run` self.load_or_resume() - self.call_hook('before_run') self.val_loop.run() # type: ignore self.call_hook('after_run') @@ -1266,9 +1270,11 @@ class Runner: self._test_loop = self.build_test_loop(self._test_loop) # type: ignore + self.call_hook('before_run') + + # make sure checkpoint-related hooks are triggered after `before_run` self.load_or_resume() - self.call_hook('before_run') self.test_loop.run() # type: ignore self.call_hook('after_run') @@ -1535,7 +1541,7 @@ class Runner: checkpoint = _load_checkpoint(filename, map_location=map_location) # Add comments to describe the usage of `after_load_ckpt` - self.call_hook('after_load_ckpt', checkpoint=checkpoint) + self.call_hook('after_load_checkpoint', checkpoint=checkpoint) if is_model_wrapper(self.model): model = self.model.module @@ -1627,8 +1633,7 @@ class Runner: state_dict = _scheduler.state_dict() # type: ignore checkpoint['param_schedulers'].append(state_dict) - self.call_hook('before_save_ckpt', checkpoint=checkpoint) - + self.call_hook('before_save_checkpoint', checkpoint=checkpoint) save_checkpoint(checkpoint, filepath) # in some 
environments, `os.symlink` is not supported, you may need to # set `create_symlink` to False diff --git a/tests/test_hook/test_ema_hook.py b/tests/test_hook/test_ema_hook.py new file mode 100644 index 00000000..995d6f8e --- /dev/null +++ b/tests/test_hook/test_ema_hook.py @@ -0,0 +1,142 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import tempfile +from unittest import TestCase +from unittest.mock import Mock + +import torch +import torch.nn as nn +from torch.utils.data import Dataset + +from mmengine.hooks import EMAHook +from mmengine.model import ExponentialMovingAverage +from mmengine.registry import DATASETS, MODEL_WRAPPERS +from mmengine.runner import Runner + + +class ToyModel(nn.Module): + + def __init__(self): + super().__init__() + self.linear = nn.Linear(2, 1) + + def forward(self, data_batch, return_loss=False): + inputs, labels = [], [] + for x in data_batch: + inputs.append(x['inputs']) + labels.append(x['data_sample']) + + device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + inputs = torch.stack(inputs).to(device) + labels = torch.stack(labels).to(device) + outputs = self.linear(inputs) + if return_loss: + loss = (labels - outputs).sum() + outputs = dict(loss=loss, log_vars=dict(loss=loss.item())) + return outputs + else: + outputs = dict(log_vars=dict(a=1, b=0.5)) + return outputs + + +@DATASETS.register_module() +class DummyDataset(Dataset): + METAINFO = dict() # type: ignore + data = torch.randn(12, 2) + label = torch.ones(12) + + def __len__(self): + return self.data.size(0) + + def __getitem__(self, index): + return dict(inputs=self.data[index], data_sample=self.label[index]) + + +class TestEMAHook(TestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + + def tearDown(self): + self.temp_dir.cleanup() + + def test_ema_hook(self): + device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + model = ToyModel().to(device) + evaluator = Mock() + evaluator.evaluate = Mock(return_value=dict(acc=0.5)) + runner = Runner( + model=model, + train_dataloader=dict( + dataset=dict(type='DummyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + batch_size=3, + num_workers=0), + val_dataloader=dict( + dataset=dict(type='DummyDataset'), + sampler=dict(type='DefaultSampler', shuffle=False), + batch_size=3, + num_workers=0), + val_evaluator=evaluator, + work_dir=self.temp_dir.name, + optimizer=torch.optim.Adam(ToyModel().parameters()), + train_cfg=dict(by_epoch=True, max_epochs=2), + val_cfg=dict(interval=1), + default_hooks=dict(logger=None), + custom_hooks=[dict(type='EMAHook', )], + experiment_name='test1') + runner.train() + for hook in runner.hooks: + if isinstance(hook, EMAHook): + self.assertTrue( + isinstance(hook.ema_model, ExponentialMovingAverage)) + + self.assertTrue( + osp.exists(osp.join(self.temp_dir.name, 'epoch_2.pth'))) + checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth')) + self.assertTrue('ema_state_dict' in checkpoint) + self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8) + + # load and testing + runner = Runner( + model=model, + test_dataloader=dict( + dataset=dict(type='DummyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + batch_size=3, + num_workers=0), + test_evaluator=evaluator, + test_cfg=dict(), + work_dir=self.temp_dir.name, + load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'), + default_hooks=dict(logger=None), + custom_hooks=[dict(type='EMAHook')], + experiment_name='test2') + runner.test() + + @MODEL_WRAPPERS.register_module() + class 
DummyWrapper(nn.Module): + + def __init__(self, model): + super().__init__() + self.module = model + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) + + # with model wrapper + runner = Runner( + model=DummyWrapper(model), + test_dataloader=dict( + dataset=dict(type='DummyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + batch_size=3, + num_workers=0), + test_evaluator=evaluator, + test_cfg=dict(), + work_dir=self.temp_dir.name, + load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'), + default_hooks=dict(logger=None), + custom_hooks=[dict(type='EMAHook')], + experiment_name='test3') + runner.test() diff --git a/tests/test_model/test_averaged_model.py b/tests/test_model/test_averaged_model.py new file mode 100644 index 00000000..f4ed1186 --- /dev/null +++ b/tests/test_model/test_averaged_model.py @@ -0,0 +1,255 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import itertools +from unittest import TestCase + +import torch + +from mmengine.model import (ExponentialMovingAverage, MomentumAnnealingEMA, + StochasticWeightAverage) +from mmengine.testing import assert_allclose + + +class TestAveragedModel(TestCase): + """Test the AveragedModel class. + + Some test cases are referenced from https://github.com/pytorch/pytorch/blob/master/test/test_optim.py + """ # noqa: E501 + + def _test_swa_model(self, net_device, avg_device): + model = torch.nn.Sequential( + torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.Linear(5, 10)).to(net_device) + + averaged_model = StochasticWeightAverage(model, device=avg_device) + averaged_params = [ + torch.zeros_like(param) for param in model.parameters() + ] + n_updates = 10 + for i in range(n_updates): + for p, p_avg in zip(model.parameters(), averaged_params): + p.detach().add_(torch.randn_like(p)) + p_avg += p.detach() / n_updates + averaged_model.update_parameters(model) + + for p_avg, p_swa in zip(averaged_params, averaged_model.parameters()): + # Check that AveragedModel is on the correct device + self.assertTrue(p_swa.device == avg_device) + self.assertTrue(p.device == net_device) + assert_allclose(p_avg, p_swa.to(p_avg.device)) + self.assertTrue(averaged_model.steps.device == avg_device) + + def test_averaged_model_all_devices(self): + cpu = torch.device('cpu') + self._test_swa_model(cpu, cpu) + if torch.cuda.is_available(): + cuda = torch.device(0) + self._test_swa_model(cuda, cpu) + self._test_swa_model(cpu, cuda) + self._test_swa_model(cuda, cuda) + + def test_swa_mixed_device(self): + if not torch.cuda.is_available(): + return + model = torch.nn.Sequential( + torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + model[0].cuda() + model[1].cpu() + averaged_model = StochasticWeightAverage(model) + averaged_params = [ + torch.zeros_like(param) for param in model.parameters() + ] + n_updates = 10 + for i in range(n_updates): + for p, p_avg in zip(model.parameters(), averaged_params): + p.detach().add_(torch.randn_like(p)) + p_avg += p.detach() / n_updates + averaged_model.update_parameters(model) + + for p_avg, p_swa in zip(averaged_params, averaged_model.parameters()): + assert_allclose(p_avg, p_swa) + # Check that AveragedModel is on the correct device + self.assertTrue(p_avg.device == p_swa.device) + + def test_swa_state_dict(self): + model = torch.nn.Sequential( + torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + averaged_model = StochasticWeightAverage(model) + averaged_model2 = StochasticWeightAverage(model) + n_updates = 10 + for i in range(n_updates): + for p in 
model.parameters(): + p.detach().add_(torch.randn_like(p)) + averaged_model.update_parameters(model) + averaged_model2.load_state_dict(averaged_model.state_dict()) + for p_swa, p_swa2 in zip(averaged_model.parameters(), + averaged_model2.parameters()): + assert_allclose(p_swa, p_swa2) + self.assertTrue(averaged_model.steps == averaged_model2.steps) + + def test_ema(self): + # test invalid momentum + with self.assertRaisesRegex(AssertionError, + 'momentum must be in range'): + model = torch.nn.Sequential( + torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + ExponentialMovingAverage(model, momentum=3) + # test EMA + model = torch.nn.Sequential( + torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + momentum = 0.1 + + ema_model = ExponentialMovingAverage(model, momentum=momentum) + averaged_params = [ + torch.zeros_like(param) for param in model.parameters() + ] + n_updates = 10 + for i in range(n_updates): + updated_averaged_params = [] + for p, p_avg in zip(model.parameters(), averaged_params): + p.detach().add_(torch.randn_like(p)) + if i == 0: + updated_averaged_params.append(p.clone()) + else: + updated_averaged_params.append( + (p_avg * (1 - momentum) + p * momentum).clone()) + ema_model.update_parameters(model) + averaged_params = updated_averaged_params + + for p_target, p_ema in zip(averaged_params, ema_model.parameters()): + assert_allclose(p_target, p_ema) + + def test_ema_update_buffers(self): + # Test EMA and update_buffers as True. + model = torch.nn.Sequential( + torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10)) + momentum = 0.1 + + ema_model = ExponentialMovingAverage( + model, momentum=momentum, update_buffers=True) + averaged_params = [ + torch.zeros_like(param) + for param in itertools.chain(model.parameters(), model.buffers()) + if param.size() != torch.Size([]) + ] + n_updates = 10 + for i in range(n_updates): + updated_averaged_params = [] + params = [ + param for param in itertools.chain(model.parameters(), + model.buffers()) + if param.size() != torch.Size([]) + ] + for p, p_avg in zip(params, averaged_params): + p.detach().add_(torch.randn_like(p)) + if i == 0: + updated_averaged_params.append(p.clone()) + else: + updated_averaged_params.append( + (p_avg * (1 - momentum) + p * momentum).clone()) + ema_model.update_parameters(model) + averaged_params = updated_averaged_params + + ema_params = [ + param for param in itertools.chain(ema_model.module.parameters(), + ema_model.module.buffers()) + if param.size() != torch.Size([]) + ] + for p_target, p_ema in zip(averaged_params, ema_params): + assert_allclose(p_target, p_ema) + + def test_momentum_annealing_ema(self): + model = torch.nn.Sequential( + torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10)) + # Test invalid gamma + with self.assertRaisesRegex(AssertionError, + 'gamma must be greater than 0'): + MomentumAnnealingEMA(model, gamma=-1) + + # Test EMA with momentum annealing. 
+ momentum = 0.1 + gamma = 4 + + ema_model = MomentumAnnealingEMA( + model, gamma=gamma, momentum=momentum, update_buffers=True) + averaged_params = [ + torch.zeros_like(param) + for param in itertools.chain(model.parameters(), model.buffers()) + if param.size() != torch.Size([]) + ] + n_updates = 10 + for i in range(n_updates): + updated_averaged_params = [] + params = [ + param for param in itertools.chain(model.parameters(), + model.buffers()) + if param.size() != torch.Size([]) + ] + for p, p_avg in zip(params, averaged_params): + p.detach().add_(torch.randn_like(p)) + if i == 0: + updated_averaged_params.append(p.clone()) + else: + m = max(momentum, gamma / (gamma + i)) + updated_averaged_params.append( + (p_avg * (1 - m) + p * m).clone()) + ema_model.update_parameters(model) + averaged_params = updated_averaged_params + + ema_params = [ + param for param in itertools.chain(ema_model.module.parameters(), + ema_model.module.buffers()) + if param.size() != torch.Size([]) + ] + for p_target, p_ema in zip(averaged_params, ema_params): + assert_allclose(p_target, p_ema) + + def test_momentum_annealing_ema_with_interval(self): + # Test EMA with momentum annealing and interval + model = torch.nn.Sequential( + torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10)) + momentum = 0.1 + gamma = 4 + interval = 3 + + ema_model = MomentumAnnealingEMA( + model, + gamma=gamma, + momentum=momentum, + interval=interval, + update_buffers=True) + averaged_params = [ + torch.zeros_like(param) + for param in itertools.chain(model.parameters(), model.buffers()) + if param.size() != torch.Size([]) + ] + n_updates = 10 + for i in range(n_updates): + updated_averaged_params = [] + params = [ + param for param in itertools.chain(model.parameters(), + model.buffers()) + if param.size() != torch.Size([]) + ] + for p, p_avg in zip(params, averaged_params): + p.detach().add_(torch.randn_like(p)) + if i == 0: + updated_averaged_params.append(p.clone()) + elif i % interval == 0: + m = max(momentum, gamma / (gamma + i)) + updated_averaged_params.append( + (p_avg * (1 - m) + p * m).clone()) + else: + updated_averaged_params.append(p_avg.clone()) + ema_model.update_parameters(model) + averaged_params = updated_averaged_params + + ema_params = [ + param for param in itertools.chain(ema_model.module.parameters(), + ema_model.module.buffers()) + if param.size() != torch.Size([]) + ] + for p_target, p_ema in zip(averaged_params, ema_params): + assert_allclose(p_target, p_ema)
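Usage notes (illustrative, not part of the patch itself). The averaged models introduced here can be driven in two ways: through ``EMAHook``, by adding ``custom_hooks=[dict(type='EMAHook')]`` to the runner config as exercised in ``tests/test_hook/test_ema_hook.py``, or directly inside an algorithm, as the ``BaseAveragedModel`` docstring suggests for an EMA teacher. A minimal, self-contained sketch of the direct usage, assuming only this patch and a toy ``nn.Sequential`` model::

    import torch
    import torch.nn as nn
    from mmengine.model import ExponentialMovingAverage

    # Toy model standing in for the "student" from the docstring example.
    student = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    # EMA copy of the student; momentum=0.01 is illustrative, default is 0.0002.
    ema_teacher = ExponentialMovingAverage(student, momentum=0.01)

    optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
    for _ in range(5):
        loss = student(torch.randn(3, 4)).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Fold the freshly optimized weights into the running average.
        ema_teacher.update_parameters(student)

    # The averaged copy behaves like a regular module at inference time.
    with torch.no_grad():
        out = ema_teacher(torch.randn(1, 4))

Because ``update_parameters`` copies the source weights verbatim on its first call (``steps == 0``), the averaged copy starts from the current model rather than from its randomly initialized clone.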
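The update rules themselves are easy to verify with plain tensors. A standalone sketch (assuming ``interval=1`` and no mmengine involved) of what ``MomentumAnnealingEMA.avg_func`` computes: a copy on the first update, then an exponential moving average whose momentum decays from ``gamma / (gamma + steps)`` towards the configured floor::

    import torch

    momentum, gamma = 0.0002, 100           # defaults from the patch
    averaged = torch.zeros(3)
    for steps in range(5):
        source = torch.full((3,), float(steps + 1))
        if steps == 0:
            averaged = source.clone()       # first call copies the source weights
        else:
            m = max(momentum, gamma / (gamma + steps))
            averaged = averaged * (1 - m) + source * m

With the default ``gamma=100`` the effective momentum stays close to 1 for the first few hundred updates, so the averaged model tracks the source closely early in training and smooths more aggressively later.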
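``StochasticWeightAverage.avg_func`` is the incremental form of an arithmetic mean, so after ``n`` updates (again assuming ``interval=1``) the averaged weights equal the mean of the ``n`` snapshots it has seen::

    import torch

    snapshots = [torch.full((3,), v) for v in (1.0, 2.0, 6.0)]
    averaged = torch.zeros(3)
    for steps, source in enumerate(snapshots):
        if steps == 0:
            averaged = source.clone()
        else:
            averaged = averaged + (source - averaged) / (steps + 1)
    # averaged is now tensor([3., 3., 3.]), the mean of 1.0, 2.0 and 6.0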
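Finally, checkpoints written while ``EMAHook`` is active gain an extra ``ema_state_dict`` entry (see the assertions in ``test_ema_hook.py``), whose ``steps`` buffer records how many updates the averaged model has absorbed. A short inspection sketch, with a hypothetical checkpoint path::

    import torch

    # 'work_dir/epoch_2.pth' is a placeholder for a checkpoint saved by a run
    # that had EMAHook enabled.
    ckpt = torch.load('work_dir/epoch_2.pth', map_location='cpu')
    assert 'ema_state_dict' in ckpt
    print(ckpt['ema_state_dict']['steps'])  # number of recorded EMA updates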