Yixiao Fang 08dc8c75d3
[Refactor] Add selfsup algorithms. (#1389)
* remove basehead

* add moco series

* add byol simclr simsiam

* add ut

* update configs

* add simsiam hook

* add and refactor beit

* update ut

* add cae

* update extract_feat

* refactor cae

* add mae

* refactor data preprocessor

* update heads

* add maskfeat

* add milan

* add simmim

* add mixmim

* fix lint

* fix ut

* fix lint

* add eva

* add densecl

* add barlowtwins

* add swav

* fix lint

* update readthedocs rst

* update docs

* update

* Decrease UT memory usage

* Fix docstring

* update DALLEEncoder

* Update model docs

* refactor dalle encoder

* update docstring

* fix ut

* fix config error

* add val_cfg and test_cfg

* refactor clip generator

* fix lint

* pass check

* fix ut

* add lars

* update type of BEiT in configs

* Use MMEngine style momentum in EMA.

* apply mmpretrain solarize

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-03-06 16:53:15 +08:00

88 lines
3.3 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from math import cos, pi
from typing import Optional
import torch
import torch.nn as nn
from mmengine.logging import MessageHub
from mmengine.model import ExponentialMovingAverage
from mmpretrain.registry import MODELS
@MODELS.register_module()
class CosineEMA(ExponentialMovingAverage):
    r"""CosineEMA is implemented for updating momentum parameter, used in BYOL,
    MoCoV3, etc.

    All parameters are updated by the formula as below:

    .. math::

        X'_{t+1} = (1 - m) * X'_t + m * X_t

    Where :math:`m` is the momentum parameter. And it's updated with cosine
    annealing, including momentum adjustment following:

    .. math::

        m = m_{end} + (m_{start} - m_{end}) * (\cos\frac{k\pi}{K} + 1) / 2

    where :math:`k` is the current step, :math:`K` is the total steps.
    With this schedule, :math:`m` starts at :math:`m_{start}` (``momentum``)
    when :math:`k = 0` and anneals to :math:`m_{end}` (``end_momentum``)
    when :math:`k = K`.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically,
        :math:`X'_{t}` is the moving average and :math:`X_t` is the new
        observed value. The value of momentum is usually a small number,
        allowing observed values to slowly update the ema parameters. See also
        :external:py:class:`torch.nn.BatchNorm2d`.

    Args:
        model (nn.Module): The model to be averaged.
        momentum (float): The start momentum value. Defaults to 0.004.
        end_momentum (float): The end momentum value for cosine annealing.
            Defaults to 0.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): if True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.
    """

    def __init__(self,
                 model: nn.Module,
                 momentum: float = 0.004,
                 end_momentum: float = 0.,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        super().__init__(
            model=model,
            momentum=momentum,
            interval=interval,
            device=device,
            update_buffers=update_buffers)
        # End value of the cosine schedule; ``self.momentum`` (set by the
        # parent class) is the start value.
        self.end_momentum = end_momentum

    def avg_func(self, averaged_param: torch.Tensor,
                 source_param: torch.Tensor, steps: int) -> None:
        """Compute the moving average of the parameters using the cosine
        momentum strategy.

        The averaged parameters are updated **in place**; nothing is
        returned.

        Args:
            averaged_param (Tensor): The averaged parameters.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """
        # The total step count is published to the MessageHub by the runner;
        # it drives the cosine annealing schedule below.
        message_hub = MessageHub.get_current_instance()
        max_iters = message_hub.get_info('max_iters')
        # (cos(pi * k / K) + 1) / 2 decays from 1 (k=0) to 0 (k=K), so the
        # effective momentum anneals from ``self.momentum`` to
        # ``self.end_momentum``.
        cosine_annealing = (cos(pi * steps / float(max_iters)) + 1) / 2
        momentum = self.end_momentum - (self.end_momentum -
                                        self.momentum) * cosine_annealing
        # In-place EMA update: X' = (1 - m) * X' + m * X
        averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum)