mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
* [Fix]: Set qkv bias to False for cae and True for mae (#303) * [Fix]: Add mmcls transformer layer choice * [Fix]: Fix transformer encoder layer bug * [Fix]: Change UT of cae * [Feature]: Change the file name of cosine annealing hook (#304) * [Feature]: Change cosine annealing hook file name * [Feature]: Add UT for cosine annealing hook * [Fix]: Fix lint * read tutorials and fix typo (#308) * [Fix] fix config errors in MAE (#307) * update readthedocs algorithm readme (#310) * [Docs] Replace markdownlint with mdformat (#311) * Replace markdownlint with mdformat to avoid installing ruby * fix typo * add 'ba' to codespell ignore-words-list * Configure Myst-parser to parse anchor tag (#309) * [Docs] rewrite install.md (#317) * rewrite the install.md * add faq.md * fix lint * add FAQ to README * add Chinese version * fix typo * fix format * remove modification * fix format * [Docs] refine README.md file (#318) * refine README.md file * fix lint * format language button * rename getting_started.md * revise index.rst * add model_zoo.md to index.rst * fix lint * refine readme Co-authored-by: Jiahao Xie <52497952+Jiahao000@users.noreply.github.com> * [Enhance] update byol models and results (#319) * Update version information (#321) Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com> Co-authored-by: Yi Lu <21515006@zju.edu.cn> Co-authored-by: RenQin <45731309+soonera@users.noreply.github.com> Co-authored-by: Jiahao Xie <52497952+Jiahao000@users.noreply.github.com>
49 lines
1.4 KiB
Python
49 lines
1.4 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import platform
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from mmselfsup.models.algorithms import CAE
|
|
|
|
# model settings
# Component configs consumed by the CAE constructor in the test below.
# The backbone is configured with qkv_bias explicitly disabled.
backbone = dict(
    type='CAEViT', arch='b', patch_size=16, init_values=0.1, qkv_bias=False)
neck = dict(
    type='CAENeck',
    patch_size=16,
    embed_dims=768,
    num_heads=12,
    regressor_depth=4,
    decoder_depth=4,
    mlp_ratio=4,
    init_values=0.1,
)
head = dict(
    type='CAEHead', tokenizer_path='cae_ckpt/encoder_stat_dict.pth', lambd=2)
|
|
|
|
|
|
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_cae():
    """Check CAE construction, training forward pass and feature extraction.

    The test first verifies that omitting any one of ``backbone``, ``neck``
    or ``head`` raises ``AssertionError``, then builds a complete model and
    runs ``forward_train`` and ``extract_feat`` on random inputs.
    """
    # Each of the three components is mandatory: passing None for any one
    # of them is expected to trip an assertion in the constructor. The
    # result is not bound to a name because construction never completes.
    with pytest.raises(AssertionError):
        CAE(backbone=None, neck=neck, head=head)
    with pytest.raises(AssertionError):
        CAE(backbone=backbone, neck=None, head=head)
    with pytest.raises(AssertionError):
        CAE(backbone=backbone, neck=neck, head=None)

    model = CAE(backbone=backbone, neck=neck, head=head)
    model.init_weights()

    fake_input = torch.rand((1, 3, 224, 224))
    fake_target = torch.rand((1, 3, 112, 112))
    # Boolean mask over 196 positions with 75 of them (indices 75..149)
    # set to True.
    fake_mask = torch.zeros((1, 196)).bool()
    fake_mask[:, 75:150] = 1

    inputs = (fake_input, fake_target, fake_mask)

    fake_loss = model.forward_train(inputs)
    fake_feat = model.extract_feat(fake_input, fake_mask)
    assert isinstance(fake_loss['loss'].item(), float)
    # NOTE(review): 122 presumably equals the 121 unmasked positions
    # (196 - 75) plus one class token -- confirm against CAEViT.
    assert list(fake_feat.shape) == [1, 122, 768]
|