mmselfsup/tests/test_models/test_algorithms/test_mae.py

41 lines
1.3 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
import torch
from mmselfsup.models.algorithms import MAE
backbone = dict(type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75)
neck = dict(
type='MAEPretrainDecoder',
patch_size=16,
in_chans=3,
embed_dim=768,
decoder_embed_dim=512,
decoder_depth=8,
decoder_num_heads=16,
mlp_ratio=4.,
)
head = dict(type='MAEPretrainHead', norm_pix=False, patch_size=16)
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_mae():
with pytest.raises(AssertionError):
alg = MAE(backbone=backbone, neck=None, head=head)
with pytest.raises(AssertionError):
alg = MAE(backbone=backbone, neck=neck, head=None)
with pytest.raises(AssertionError):
alg = MAE(backbone=None, neck=neck, head=head)
alg = MAE(backbone=backbone, neck=neck, head=head)
fake_input = torch.randn((2, 3, 224, 224))
fake_loss = alg.forward_train(fake_input)
fake_feature = alg.extract_feat(fake_input)
mask, pred = alg.forward_test(fake_input)
assert isinstance(fake_loss['loss'].item(), float)
assert list(fake_feature[0].shape) == [2, 50, 768]
assert list(mask.shape) == [2, 224, 224, 3]
assert list(pred.shape) == [2, 224, 224, 3]