diff --git a/mmengine/model/averaged_model.py b/mmengine/model/averaged_model.py index 94da035b..84a85c1c 100644 --- a/mmengine/model/averaged_model.py +++ b/mmengine/model/averaged_model.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import warnings from abc import abstractmethod from copy import deepcopy from typing import Optional @@ -151,6 +152,13 @@ class ExponentialMovingAverage(BaseAveragedModel): Xema_{t+1} = (1 - momentum) * Xema_{t} + momentum * X_t + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, + :math:`Xema_{t+1}` is the moving average and :math:`X_t` is the + new observed value. The value of momentum is usually a small number, + allowing observed values to slowly update the ema parameters. + Args: model (nn.Module): The model to be averaged. momentum (float): The momentum used for updating ema parameter. @@ -175,6 +183,12 @@ class ExponentialMovingAverage(BaseAveragedModel): super().__init__(model, interval, device, update_buffers) assert 0.0 < momentum < 1.0, 'momentum must be in range (0.0, 1.0)'\ f'but got {momentum}' + if momentum > 0.5: + warnings.warn( + 'The value of momentum in EMA is usually a small number,' + 'which is different from the conventional notion of ' + f'momentum but got {momentum}. Please make sure the ' + f'value is correct.') self.momentum = momentum def avg_func(self, averaged_param: Tensor, source_param: Tensor, diff --git a/tests/test_model/test_averaged_model.py b/tests/test_model/test_averaged_model.py index 8151902c..efbe3349 100644 --- a/tests/test_model/test_averaged_model.py +++ b/tests/test_model/test_averaged_model.py @@ -93,6 +93,13 @@ class TestAveragedModel(TestCase): model = torch.nn.Sequential( torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) ExponentialMovingAverage(model, momentum=3) + + with self.assertWarnsRegex( + Warning, + 'The value of momentum in EMA is usually a small number'): + model = torch.nn.Sequential( + torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + ExponentialMovingAverage(model, momentum=0.9) # test EMA model = torch.nn.Sequential( torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))