mirror of
https://github.com/open-mmlab/mmpretrain.git
synced 2025-06-03 14:59:18 +08:00
* remove basehead * add moco series * add byol simclr simsiam * add ut * update configs * add simsiam hook * add and refactor beit * update ut * add cae * update extract_feat * refactor cae * add mae * refactor data preprocessor * update heads * add maskfeat * add milan * add simmim * add mixmim * fix lint * fix ut * fix lint * add eva * add densecl * add barlowtwins * add swav * fix lint * update readthedocs rst * update docs * update * Decrease UT memory usage * Fix docstring * update DALLEEncoder * Update model docs * refactor dalle encoder * update docstring * fix ut * fix config error * add val_cfg and test_cfg * refactor clip generator * fix lint * pass check * fix ut * add lars * update type of BEiT in configs * Use MMEngine style momentum in EMA. * apply mmpretrain solarize --------- Co-authored-by: mzr1996 <mzr1996@163.com>
59 lines
1.5 KiB
Python
59 lines
1.5 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import platform
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from mmpretrain.models import MoCo
|
|
from mmpretrain.structures import DataSample
|
|
|
|
# Shared, deliberately tiny MoCo configuration consumed by ``test_moco``.
queue_len = 32  # length of the feature queue (second dim of ``queue``)
feat_dim = 2  # feature dimension stored in the queue (first dim)
momentum = 0.001  # forwarded to ``MoCo(momentum=...)``

# ResNet-18 backbone with plain BatchNorm.
backbone = {'type': 'ResNet', 'depth': 18, 'norm_cfg': {'type': 'BN'}}

# Projection neck: 512-channel backbone output -> 2-d features.
neck = {
    'type': 'MoCoV2Neck',
    'in_channels': 512,
    'hid_channels': 2,
    'out_channels': 2,
    'with_avg_pool': True,
}

# Contrastive head with a cross-entropy loss and temperature 0.2.
head = {
    'type': 'ContrastiveHead',
    'loss': {'type': 'CrossEntropyLoss'},
    'temperature': 0.2,
}
|
|
|
|
|
|
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_moco():
    """Smoke-test the MoCo algorithm end to end.

    Builds a small MoCo model from the module-level ``backbone``/``neck``/
    ``head`` configs, then checks that:

    * the feature queue has shape ``(feat_dim, queue_len)``;
    * a ``'loss'`` forward on two random 224x224 views yields a positive
      loss and leaves ``queue_ptr`` equal to the batch size;
    * a ``'tensor'`` forward on one view returns a first feature map of
      shape ``(batch_size, 512, 7, 7)``.

    Skipped on Windows because of its memory limit.
    """
    batch_size = 2
    img_shape = (batch_size, 3, 224, 224)

    preprocess_cfg = dict(
        mean=(123.675, 116.28, 103.53),
        std=(58.395, 57.12, 57.375),
        to_rgb=True)
    model = MoCo(
        backbone=backbone,
        neck=neck,
        head=head,
        queue_len=queue_len,
        feat_dim=feat_dim,
        momentum=momentum,
        data_preprocessor=preprocess_cfg)
    assert model.queue.shape == (feat_dim, queue_len)

    # Two random image batches standing in for the two views.
    views = [torch.randn(img_shape) for _ in range(2)]
    samples = [DataSample() for _ in range(batch_size)]
    processed = model.data_preprocessor(
        {'inputs': views, 'data_samples': samples})

    losses = model(**processed, mode='loss')
    assert losses['loss'] > 0
    # After a single loss step the queue pointer sits at ``batch_size``.
    assert model.queue_ptr.item() == batch_size

    # 'tensor' mode: feature extraction on the first view only.
    feats = model(processed['inputs'][0], mode='tensor')
    assert feats[0].shape == (batch_size, 512, 7, 7)
|