mmselfsup/tests/test_models/test_algorithms/test_maskfeat.py

34 lines
1.0 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
import torch
from mmselfsup.models.algorithms import MaskFeat
backbone = dict(
type='MaskFeatViT',
arch='b',
patch_size=16,
drop_path_rate=0,
)
head = dict(type='MaskFeatPretrainHead', hog_dim=108)
hog_para = dict(nbins=9, pool=8, gaussian_window=16)
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_maskfeat():
with pytest.raises(AssertionError):
alg = MaskFeat(backbone=backbone, head=None, hog_para=hog_para)
with pytest.raises(AssertionError):
alg = MaskFeat(backbone=None, head=head, hog_para=hog_para)
alg = MaskFeat(backbone=backbone, head=head, hog_para=hog_para)
fake_img = torch.randn((2, 3, 224, 224))
fake_mask = torch.randn((2, 14, 14)).bool()
fake_input = (fake_img, fake_mask)
fake_loss = alg.forward_train(fake_input)
fake_feature = alg.extract_feat(fake_input)
assert isinstance(fake_loss['loss'].item(), float)
assert list(fake_feature.shape) == [2, 197, 768]