34 lines
1.0 KiB
Python
34 lines
1.0 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import platform
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from mmselfsup.models.algorithms import MaskFeat
|
|
|
|
backbone = dict(
|
|
type='MaskFeatViT',
|
|
arch='b',
|
|
patch_size=16,
|
|
drop_path_rate=0,
|
|
)
|
|
head = dict(type='MaskFeatPretrainHead', hog_dim=108)
|
|
hog_para = dict(nbins=9, pool=8, gaussian_window=16)
|
|
|
|
|
|
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
|
def test_maskfeat():
|
|
with pytest.raises(AssertionError):
|
|
alg = MaskFeat(backbone=backbone, head=None, hog_para=hog_para)
|
|
with pytest.raises(AssertionError):
|
|
alg = MaskFeat(backbone=None, head=head, hog_para=hog_para)
|
|
alg = MaskFeat(backbone=backbone, head=head, hog_para=hog_para)
|
|
|
|
fake_img = torch.randn((2, 3, 224, 224))
|
|
fake_mask = torch.randn((2, 14, 14)).bool()
|
|
fake_input = (fake_img, fake_mask)
|
|
fake_loss = alg.forward_train(fake_input)
|
|
fake_feature = alg.extract_feat(fake_input)
|
|
assert isinstance(fake_loss['loss'].item(), float)
|
|
assert list(fake_feature.shape) == [2, 197, 768]
|