mirror of
https://github.com/open-mmlab/mmpretrain.git
synced 2025-06-03 14:59:18 +08:00
* remove basehead * add moco series * add byol simclr simsiam * add ut * update configs * add simsiam hook * add and refactor beit * update ut * add cae * update extract_feat * refactor cae * add mae * refactor data preprocessor * update heads * add maskfeat * add milan * add simmim * add mixmim * fix lint * fix ut * fix lint * add eva * add densecl * add barlowtwins * add swav * fix lint * update readtherdocs rst * update docs * update * Decrease UT memory usage * Fix docstring * update DALLEEncoder * Update model docs * refactor dalle encoder * update docstring * fix ut * fix config error * add val_cfg and test_cfg * refactor clip generator * fix lint * pass check * fix ut * add lars * update type of BEiT in configs * Use MMEngine style momentum in EMA. * apply mmpretrain solarize --------- Co-authored-by: mzr1996 <mzr1996@163.com>
71 lines
2.1 KiB
Python
71 lines
2.1 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import platform
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from mmpretrain.models import SimMIM, SimMIMSwinTransformer
|
|
from mmpretrain.structures import DataSample
|
|
|
|
|
|
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_simmim_swin():
    """Check ``SimMIMSwinTransformer`` output shapes with/without a mask."""
    backbone_cfg = {
        'arch': 'B',
        'img_size': 192,
        'stage_cfgs': {'block_cfgs': {'window_size': 6}},
    }
    swin = SimMIMSwinTransformer(**backbone_cfg)
    swin.init_weights()

    # Random tensors are drawn after weight init: images first, then mask.
    images = torch.randn((2, 3, 192, 192))
    mask = torch.rand((2, 48, 48))

    # Both forward variants are expected to produce this feature-map shape.
    expected_shape = torch.Size([2, 1024, 6, 6])

    # Forward with a mask: only the first returned element is inspected.
    masked_feat = swin(images, mask)[0]
    assert masked_feat.shape == expected_shape

    # Forward with ``None`` as the mask: exactly one output is expected.
    unmasked_outs = swin(images, None)
    assert len(unmasked_outs) == 1
    assert unmasked_outs[0].shape == expected_shape
|
|
|
|
|
|
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_simmim():
    """Build a ``SimMIM`` model and check ``mode='loss'`` gives a float."""
    preprocessor_cfg = dict(
        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], to_rgb=True)

    # Component configs handed to the ``SimMIM`` constructor.
    backbone_cfg = {
        'type': 'SimMIMSwinTransformer',
        'arch': 'B',
        'img_size': 192,
        'stage_cfgs': {'block_cfgs': {'window_size': 6}},
    }
    neck_cfg = {
        'type': 'SimMIMLinearDecoder',
        'in_channels': 128 * 2**3,  # evaluates to 1024
        'encoder_stride': 32,
    }
    loss_cfg = {
        'type': 'PixelReconstructionLoss',
        'criterion': 'L1',
        'channel': 3,
    }
    head_cfg = {'type': 'SimMIMHead', 'patch_size': 4, 'loss': loss_cfg}

    simmim = SimMIM(
        backbone=backbone_cfg,
        neck=neck_cfg,
        head=head_cfg,
        data_preprocessor=preprocessor_cfg)

    # A single data sample carrying a random mask; the very same object is
    # referenced twice below to pair with the two input images.
    sample = DataSample()
    sample.set_mask(torch.rand((48, 48)))

    batch = {
        'inputs': torch.randn((2, 3, 192, 192)),
        'data_samples': [sample] * 2,
    }

    # Run the model's own preprocessor, then forward in loss mode.
    processed = simmim.data_preprocessor(batch)
    losses = simmim(**processed, mode='loss')
    assert isinstance(losses['loss'].item(), float)
|