Mirror of https://github.com/open-mmlab/mmpretrain.git, synced 2025-06-03 14:59:18 +08:00
* remove basehead
* add moco series
* add byol simclr simsiam
* add ut
* update configs
* add simsiam hook
* add and refactor beit
* update ut
* add cae
* update extract_feat
* refactor cae
* add mae
* refactor data preprocessor
* update heads
* add maskfeat
* add milan
* add simmim
* add mixmim
* fix lint
* fix ut
* fix lint
* add eva
* add densecl
* add barlowtwins
* add swav
* fix lint
* update readthedocs rst
* update docs
* update
* Decrease UT memory usage
* Fix docstring
* update DALLEEncoder
* Update model docs
* refactor dalle encoder
* update docstring
* fix ut
* fix config error
* add val_cfg and test_cfg
* refactor clip generator
* fix lint
* pass check
* fix ut
* add lars
* update type of BEiT in configs
* Use MMEngine style momentum in EMA.
* apply mmpretrain solarize

---------
Co-authored-by: mzr1996 <mzr1996@163.com>
49 lines
1.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
import math
from unittest import TestCase

import torch
import torch.nn as nn
from mmengine.logging import MessageHub
from mmengine.testing import assert_allclose

from mmpretrain.models.utils import CosineEMA


class TestEMA(TestCase):

    def test_cosine_ema(self):
        model = nn.Sequential(nn.Conv2d(1, 5, kernel_size=3), nn.Linear(5, 10))

        # init message hub
        max_iters = 5
        test = dict(name='ema_test')
        message_hub = MessageHub.get_instance(**test)
        message_hub.update_info('max_iters', max_iters)

        # test EMA
        momentum = 0.996
        end_momentum = 1.

        ema_model = CosineEMA(model, momentum=1 - momentum)
        averaged_params = [
            torch.zeros_like(param) for param in model.parameters()
        ]

        for i in range(max_iters):
            updated_averaged_params = []
            for p, p_avg in zip(model.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                if i == 0:
                    updated_averaged_params.append(p.clone())
                else:
                    m = end_momentum - (end_momentum - momentum) * (
                        math.cos(math.pi * i / float(max_iters)) + 1) / 2
                    updated_averaged_params.append(
                        (p_avg * m + p * (1 - m)).clone())
            ema_model.update_parameters(model)
            averaged_params = updated_averaged_params

        for p_target, p_ema in zip(averaged_params, ema_model.parameters()):
            assert_allclose(p_target, p_ema)
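For reference, the momentum value the test recomputes at each step follows a half-cosine schedule that anneals from the starting momentum towards the end momentum over max_iters. A minimal sketch of that schedule as a standalone helper is shown below; the function name cosine_momentum is illustrative only and not part of mmpretrain, but the formula mirrors the reference computation inside the test loop.

import math

def cosine_momentum(i, max_iters, momentum=0.996, end_momentum=1.0):
    # Weight kept on the running average at iteration i: starts at
    # `momentum` when i == 0 and anneals towards `end_momentum` along a
    # half-cosine curve, matching the reference value the test compares
    # against CosineEMA.
    return end_momentum - (end_momentum - momentum) * (
        math.cos(math.pi * i / float(max_iters)) + 1) / 2

Note that the test constructs CosineEMA with momentum=1 - momentum (i.e. 0.004), so the constructor argument is the weight applied to the incoming model parameters, whereas the reference loop keeps m as the weight on the running average; this corresponds to the "MMEngine style momentum" mentioned in the commit message above.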