Yixiao Fang e4c4a81b56
[Feature] Support iTPN and HiViT (#1584)
* hivit added

* Update hivit.py

* Update hivit.py

* Add files via upload

* Update __init__.py

* Add files via upload

* Update __init__.py

* Add files via upload

* Update hivit.py

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Update itpn.py

* Add files via upload

* Update __init__.py

* Update mae_hivit-base-p16.py

* Delete mim_itpn-base-p16.py

* Add files via upload

* Update itpn_hivit-base-p16.py

* Update itpn.py

* Update hivit.py

* Update __init__.py

* Update mae.py

* Delete hivit.py

* Update __init__.py

* Delete configs/itpn directory

* Add files via upload

* Add files via upload

* Delete configs/hivit directory

* Add files via upload

* refactor and add metafile and readme

* update clip

* add ut

* update ut

* update

* update docstring

* update model.rst

---------

Co-authored-by: 田运杰 <48153283+sunsmarterjie@users.noreply.github.com>
2023-05-26 12:08:34 +08:00

58 lines
1.5 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
import torch
from mmpretrain.models import iTPN
from mmpretrain.structures import DataSample
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_itpn():
    """Smoke-test the iTPN self-supervised pre-training pipeline.

    Builds an ``iTPN`` algorithm from config dicts (HiViT backbone with
    pixel-reconstruction masking, iTPN pretrain decoder neck, MAE-style
    pixel-reconstruction head), runs a fake 2-sample batch through the
    ``loss`` forward mode, and checks that a scalar loss is produced.
    Skipped on Windows because the model exceeds CI memory limits there.
    """
    # Normalize inputs to [-1, 1]; convert BGR loader output to RGB.
    data_preprocessor = {
        'mean': [0.5, 0.5, 0.5],
        'std': [0.5, 0.5, 0.5],
        'to_rgb': True
    }

    # HiViT backbone in MIM mode: 75% of patches are masked and the
    # model must reconstruct raw pixels for the masked ones.
    backbone = dict(
        type='iTPNHiViT',
        arch='base',
        reconstruction_type='pixel',
        mask_ratio=0.75)

    # 196 patches = (224 / 16)^2 for a 224x224 input with 16px patches.
    neck = dict(
        type='iTPNPretrainDecoder',
        num_patches=196,
        patch_size=16,
        in_chans=3,
        embed_dim=512,
        decoder_embed_dim=512,
        decoder_depth=6,
        decoder_num_heads=16,
        mlp_ratio=4.,
        reconstruction_type='pixel',
        # transformer pyramid
        fpn_dim=256,
        fpn_depth=2,
        num_outs=3,
    )

    # MAE head: L2 loss on per-patch-normalized pixel targets.
    head = dict(
        type='MAEPretrainHead',
        norm_pix=True,
        patch_size=16,
        loss=dict(type='PixelReconstructionLoss', criterion='L2'))

    alg = iTPN(
        backbone=backbone,
        neck=neck,
        head=head,
        data_preprocessor=data_preprocessor)

    # Minimal 2-image batch; DataSample instances carry no annotations
    # since pre-training is label-free.
    fake_data = {
        'inputs': torch.randn((2, 3, 224, 224)),
        'data_samples': [DataSample() for _ in range(2)]
    }
    fake_inputs = alg.data_preprocessor(fake_data)
    fake_outputs = alg(**fake_inputs, mode='loss')
    # 'loss' mode must return a dict containing a scalar loss tensor.
    assert isinstance(fake_outputs['loss'].item(), float)