[Enhance] Raise warning for abnormal momentum (#655)
parent
4a9df3bd3b
commit
f2b0540f58
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Optional
|
||||
|
@ -151,6 +152,13 @@ class ExponentialMovingAverage(BaseAveragedModel):
|
|||
|
||||
Xema_{t+1} = (1 - momentum) * Xema_{t} + momentum * X_t
|
||||
|
||||
.. note::
|
||||
This :attr:`momentum` argument is different from the one used in optimizer
|
||||
classes and the conventional notion of momentum. Mathematically,
|
||||
:math:`Xema_{t+1}` is the moving average and :math:`X_t` is the
|
||||
new observed value. The value of momentum is usually a small number,
|
||||
allowing observed values to slowly update the ema parameters.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to be averaged.
|
||||
momentum (float): The momentum used for updating ema parameter.
|
||||
|
@ -175,6 +183,12 @@ class ExponentialMovingAverage(BaseAveragedModel):
|
|||
super().__init__(model, interval, device, update_buffers)
|
||||
assert 0.0 < momentum < 1.0, 'momentum must be in range (0.0, 1.0)'\
|
||||
f' but got {momentum}'
|
||||
if momentum > 0.5:
|
||||
warnings.warn(
|
||||
'The value of momentum in EMA is usually a small number, '
|
||||
'which is different from the conventional notion of '
|
||||
f'momentum but got {momentum}. Please make sure the '
|
||||
f'value is correct.')
|
||||
self.momentum = momentum
|
||||
|
||||
def avg_func(self, averaged_param: Tensor, source_param: Tensor,
|
||||
|
|
|
@ -93,6 +93,13 @@ class TestAveragedModel(TestCase):
|
|||
model = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
|
||||
ExponentialMovingAverage(model, momentum=3)
|
||||
|
||||
with self.assertWarnsRegex(
|
||||
Warning,
|
||||
'The value of momentum in EMA is usually a small number'):
|
||||
model = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
|
||||
ExponentialMovingAverage(model, momentum=0.9)
|
||||
# test EMA
|
||||
model = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
|
||||
|
|
Loading…
Reference in New Issue