lkylkylky 9e015762d1 [Feature] Add Maskfeat Support (#485)
* [Feature]: Add MaskfeatMaskGenerator Pipeline

* [Feature]: Add HogLayerC for MaskFeat

* [Feature]: Add Backbone of MaskFeat

* [Feature]: Add Head of MaskFeat

* [Feature]: Add Algorithms of MaskFeat

* [Feature]: Add Config of MaskFeat

* [Doc] Update Readme of MaskFeat

* [Fix] Fix unit tests and hog_layer.

* [fix] Add and correct docstring

* [Fix] Refine the docstring of MaskFeat

* [fix] fix value of trunc_normal_

* [fix] rename the finetune config of maskfeat

* [fix] rename the fine-tuning config of maskfeat

* [fix] rename the fine-tuning config of maskfeat

* [fix] add new paramwise_options in fine-tuning config

* [fix] update the top-1 accuracy of maskfeat

* [fix] update the top-1 accuracy of maskfeat in model_zoo

* [fix] rename MaskfeatMaskGenerator
2022-10-01 13:39:27 +08:00

34 lines
1.0 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
import torch
from mmselfsup.models.algorithms import MaskFeat
# Shared component configs for the MaskFeat unit test below.

# ViT backbone config (arch 'b', 16x16 patches); drop_path_rate is set to 0,
# presumably disabling stochastic depth for a deterministic smoke test.
backbone = {
    'type': 'MaskFeatViT',
    'arch': 'b',
    'patch_size': 16,
    'drop_path_rate': 0,
}

# Pretraining head config. NOTE(review): hog_dim=108 presumably equals
# 3 image channels x nbins (9) x (patch_size / pool)**2 (= 4); confirm
# against the HOG layer implementation.
head = {'type': 'MaskFeatPretrainHead', 'hog_dim': 108}

# HOG target parameters, passed to MaskFeat through ``hog_para``.
hog_para = {'nbins': 9, 'pool': 8, 'gaussian_window': 16}
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_maskfeat():
    """Smoke-test the MaskFeat algorithm.

    Checks that construction raises ``AssertionError`` when either the head
    or the backbone config is missing. It then runs ``forward_train`` and
    ``extract_feat`` on a random batch and checks the loss type and the
    extracted feature shape.
    """
    # Missing components must be rejected at construction time.
    with pytest.raises(AssertionError):
        MaskFeat(backbone=backbone, head=None, hog_para=hog_para)
    with pytest.raises(AssertionError):
        MaskFeat(backbone=None, head=head, hog_para=hog_para)

    alg = MaskFeat(backbone=backbone, head=head, hog_para=hog_para)

    fake_img = torch.randn((2, 3, 224, 224))
    # One mask entry per patch: presumably 224 / 16 = 14 patches per side.
    # Threshold uniform noise so that roughly half of the patches are masked.
    # ``torch.randn(...).bool()`` would be True almost everywhere (a normal
    # sample is essentially never exactly 0). That would mask every patch,
    # so only the fully-masked case would ever be exercised.
    fake_mask = torch.rand((2, 14, 14)) > 0.5
    fake_input = (fake_img, fake_mask)

    fake_loss = alg.forward_train(fake_input)
    fake_feature = alg.extract_feat(fake_input)

    assert isinstance(fake_loss['loss'].item(), float)
    # Presumably 14 * 14 = 196 patch tokens + 1 cls token, each of the
    # arch-'b' embedding width (768) — confirm against MaskFeatViT.
    assert list(fake_feature.shape) == [2, 197, 768]