mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
* [Feature]: Add MaskfeatMaskGenerator Pipeline * [Feature]: Add HogLayerC for MaskFeat * [Feature]: Add Backbone of MaskFeat * [Feature]: Add Head of MaskFeat * [Feature]: Add Algorithms of MaskFeat * [Feature]: Add Config of MaskFeat * [Doc] Update Readme of MaskFeat * [Fix] fix ut and hog_layer. * [fix] Add and correct docstring * [Fix] Refine the docstring of MaskFeat * [fix] fix value of trunc_normal_ * [fix] rename the finetune config of maskfeat * [fix] rename the fine-tuning config of maskfeat * [fix] rename the fine-tuning config of maskfeat * [fix] add new paramwise_options in fine-tuning config * [fix] update the top-1 accuracy of maskfeat * [fix] update the top-1 accuracy of maskfeat in model_zoo * [fix] rename MaskfeatMaskGenerator
34 lines
1.0 KiB
Python
34 lines
1.0 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import platform
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from mmselfsup.models.algorithms import MaskFeat
|
|
|
|
# Model settings shared by the MaskFeat unit test below.
backbone = {
    'type': 'MaskFeatViT',
    'arch': 'b',
    'patch_size': 16,
    'drop_path_rate': 0,
}
head = {'type': 'MaskFeatPretrainHead', 'hog_dim': 108}
# Keyword arguments passed to MaskFeat through its ``hog_para`` argument.
hog_para = {'nbins': 9, 'pool': 8, 'gaussian_window': 16}
|
|
|
|
|
|
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_maskfeat():
    """Smoke-test construction, ``forward_train`` and ``extract_feat``.

    Checks that ``MaskFeat`` raises ``AssertionError`` when either the head
    or the backbone is missing, that ``forward_train`` returns a scalar
    ``'loss'``, and that ``extract_feat`` yields features of shape
    ``(2, 197, 768)`` for a batch of two 224x224 RGB images.
    """
    # Both sub-modules are mandatory; construction must fail without either.
    with pytest.raises(AssertionError):
        MaskFeat(backbone=backbone, head=None, hog_para=hog_para)
    with pytest.raises(AssertionError):
        MaskFeat(backbone=None, head=head, hog_para=hog_para)

    alg = MaskFeat(backbone=backbone, head=head, hog_para=hog_para)

    fake_img = torch.randn((2, 3, 224, 224))
    # Presumably one flag per 16x16 patch (224 / 16 = 14). The previous
    # ``torch.randn(...).bool()`` was True almost everywhere (only exact zeros
    # map to False), so draw a genuinely mixed mask to exercise both masked
    # and visible tokens.
    fake_mask = torch.rand((2, 14, 14)) < 0.5
    fake_input = (fake_img, fake_mask)

    fake_loss = alg.forward_train(fake_input)
    fake_feature = alg.extract_feat(fake_input)

    assert isinstance(fake_loss['loss'].item(), float)
    # 197 = 14 * 14 patch tokens plus, presumably, one class token.
    assert list(fake_feature.shape) == [2, 197, 768]
|