From 4efb0249dce598da48e49332dd1c79dc01ff80c0 Mon Sep 17 00:00:00 2001 From: liuyuan <3463423099@qq.com> Date: Wed, 22 Mar 2023 11:26:09 +0800 Subject: [PATCH] [Fix]: Add PixMIM UT --- .../test_algorithms/test_pixmim.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 tests/test_models/test_algorithms/test_pixmim.py diff --git a/tests/test_models/test_algorithms/test_pixmim.py b/tests/test_models/test_algorithms/test_pixmim.py new file mode 100644 index 00000000..0a904520 --- /dev/null +++ b/tests/test_models/test_algorithms/test_pixmim.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import platform + +import pytest +import torch + +from mmselfsup.models.algorithms.pixmim import PixMIM +from mmselfsup.structures import SelfSupDataSample +from mmselfsup.utils import register_all_modules + +register_all_modules() + +backbone = dict(type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75) +neck = dict( + type='MAEPretrainDecoder', + patch_size=16, + in_chans=3, + embed_dim=768, + decoder_embed_dim=512, + decoder_depth=8, + decoder_num_heads=16, + mlp_ratio=4., +) +loss = dict(type='MAEReconstructionLoss') +head = dict(type='MAEPretrainHead', norm_pix=False, patch_size=16, loss=loss) +target_generator = dict(type='LowFreqTargetGenerator', radius=40, img_size=224) + + +@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') +def test_pixmim(): + data_preprocessor = { + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'bgr_to_rgb': True + } + + alg = PixMIM( + backbone=backbone, + neck=neck, + head=head, + data_preprocessor=copy.deepcopy(data_preprocessor), + target_generator=target_generator) + + fake_data = { + 'inputs': [torch.randn((2, 3, 224, 224))], + 'data_sample': [SelfSupDataSample() for _ in range(2)] + } + fake_batch_inputs, fake_data_samples = alg.data_preprocessor(fake_data) + fake_outputs = alg(fake_batch_inputs, fake_data_samples, mode='loss') + assert isinstance(fake_outputs['loss'].item(), float) + + # test extraction + fake_feats = alg(fake_batch_inputs, fake_data_samples, mode='tensor') + assert list(fake_feats.shape) == [2, 196, 768] + + # test reconstruct + mean = fake_feats.mean(dim=-1, keepdim=True) + std = (fake_feats.var(dim=-1, keepdim=True) + 1.e-6)**.5 + results = alg.reconstruct( + fake_feats, fake_data_samples, mean=mean, std=std) + assert list(results.mask.value.shape) == [2, 224, 224, 3] + assert list(results.pred.value.shape) == [2, 224, 224, 3]