[Enhance] Raise warning for abnormal momentum (#655)

pull/718/head^2
RangiLyu 2022-11-01 14:20:22 +08:00 committed by Zaida Zhou
parent 4a9df3bd3b
commit f2b0540f58
2 changed files with 21 additions and 0 deletions

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import abstractmethod
from copy import deepcopy
from typing import Optional
@ -151,6 +152,13 @@ class ExponentialMovingAverage(BaseAveragedModel):
Xema_{t+1} = (1 - momentum) * Xema_{t} + momentum * X_t
.. note::
This :attr:`momentum` argument is different from one used in optimizer
classes and the conventional notion of momentum. Mathematically,
:math:`Xema_{t+1}` is the moving average and :math:`X_t` is the
new observed value. The value of momentum is usually a small number,
allowing observed values to slowly update the ema parameters.
Args:
model (nn.Module): The model to be averaged.
momentum (float): The momentum used for updating ema parameter.
@ -175,6 +183,12 @@ class ExponentialMovingAverage(BaseAveragedModel):
super().__init__(model, interval, device, update_buffers)
assert 0.0 < momentum < 1.0, 'momentum must be in range (0.0, 1.0)'\
f'but got {momentum}'
if momentum > 0.5:
warnings.warn(
'The value of momentum in EMA is usually a small number,'
'which is different from the conventional notion of '
f'momentum but got {momentum}. Please make sure the '
f'value is correct.')
self.momentum = momentum
def avg_func(self, averaged_param: Tensor, source_param: Tensor,

View File

@ -93,6 +93,13 @@ class TestAveragedModel(TestCase):
model = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
ExponentialMovingAverage(model, momentum=3)
with self.assertWarnsRegex(
Warning,
'The value of momentum in EMA is usually a small number'):
model = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
ExponentialMovingAverage(model, momentum=0.9)
# test EMA
model = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))