mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
[Fix]: Add PixMIM UT
This commit is contained in:
parent
e621733ecd
commit
4efb0249dc
63
tests/test_models/test_algorithms/test_pixmim.py
Normal file
63
tests/test_models/test_algorithms/test_pixmim.py
Normal file
@ -0,0 +1,63 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import platform
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.algorithms.pixmim import PixMIM
|
||||
from mmselfsup.structures import SelfSupDataSample
|
||||
from mmselfsup.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
backbone = dict(type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75)
|
||||
neck = dict(
|
||||
type='MAEPretrainDecoder',
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
embed_dim=768,
|
||||
decoder_embed_dim=512,
|
||||
decoder_depth=8,
|
||||
decoder_num_heads=16,
|
||||
mlp_ratio=4.,
|
||||
)
|
||||
loss = dict(type='MAEReconstructionLoss')
|
||||
head = dict(type='MAEPretrainHead', norm_pix=False, patch_size=16, loss=loss)
|
||||
target_generator = dict(type='LowFreqTargetGenerator', radius=40, img_size=224)
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
def test_pixmim():
|
||||
data_preprocessor = {
|
||||
'mean': [0.5, 0.5, 0.5],
|
||||
'std': [0.5, 0.5, 0.5],
|
||||
'bgr_to_rgb': True
|
||||
}
|
||||
|
||||
alg = PixMIM(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=head,
|
||||
data_preprocessor=copy.deepcopy(data_preprocessor),
|
||||
target_generator=target_generator)
|
||||
|
||||
fake_data = {
|
||||
'inputs': [torch.randn((2, 3, 224, 224))],
|
||||
'data_sample': [SelfSupDataSample() for _ in range(2)]
|
||||
}
|
||||
fake_batch_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
|
||||
fake_outputs = alg(fake_batch_inputs, fake_data_samples, mode='loss')
|
||||
assert isinstance(fake_outputs['loss'].item(), float)
|
||||
|
||||
# test extraction
|
||||
fake_feats = alg(fake_batch_inputs, fake_data_samples, mode='tensor')
|
||||
assert list(fake_feats.shape) == [2, 196, 768]
|
||||
|
||||
# test reconstruct
|
||||
mean = fake_feats.mean(dim=-1, keepdim=True)
|
||||
std = (fake_feats.var(dim=-1, keepdim=True) + 1.e-6)**.5
|
||||
results = alg.reconstruct(
|
||||
fake_feats, fake_data_samples, mean=mean, std=std)
|
||||
assert list(results.mask.value.shape) == [2, 224, 224, 3]
|
||||
assert list(results.pred.value.shape) == [2, 224, 224, 3]
|
Loading…
x
Reference in New Issue
Block a user