RenQin 7a7b048f23
[Feature]: Add BEiT Support (#425)
* [Feature]: Add BEiT Support

* [Fix]: fix bugs after update

* [Fix]: fix bugs in backbone

* [Refactor]: refactor config

* [Feature]: Support BEiTv2

* [Fix]: Fix UT

* [Fix]: rename some configs

* [Fix]: fix beitv2neck

* [Refactor]: refactor beitv2

* [Fix]: fix lint

* refactor configs

* refactor beitv2

* update configs

* add dalle target generator

* refactor for beitv1

* refactor rel_pos_bias of beit

* update configs

* update configs

* update v1 configs

* update v2 configs

* refactor layer decay

* update unittest

* fix lint

* fix ut

* add docstrings

* rename

* fix lint

* add beit model and log links

* fix lint

* update according to review

* update

* update

* update LearningRateDecayOptimWrapperConstructor related configs

* update init and backbone

* update neck and vqkd

* refactor neck

* fix lint

* add some comments

* fix typo

Co-authored-by: 任琴 <PJLAB\renqin@shai14001114l.pjlab.org>
Co-authored-by: fangyixiao18 <fangyx18@hotmail.com>
2022-12-06 16:40:05 +08:00


# Copyright (c) OpenMMLab. All rights reserved.
import platform

import pytest
import torch
from mmengine.structures import InstanceData

from mmselfsup.models import BEiT
from mmselfsup.structures import SelfSupDataSample
from mmselfsup.utils import register_all_modules

data_preprocessor = dict(
    type='TwoNormDataPreprocessor',
    mean=(123.675, 116.28, 103.53),
    std=(58.395, 57.12, 57.375),
    second_mean=(-20.4, -20.4, -20.4),
    second_std=(204., 204., 204.),
    bgr_to_rgb=True)

# model settings
backbone = dict(
    type='BEiTViT',
    arch='base',
    patch_size=16,
    drop_path_rate=0.1,
    final_norm=True,
    layer_scale_init_value=0.1,
)
neck = None
head = dict(
    type='BEiTV1Head',
    embed_dims=768,
    num_embed=8192,
    loss=dict(type='BEiTLoss'))
target_generator = dict(type='DALL-E')
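# Note (added for clarity, not part of the original file): 'DALL-E' refers to
# the pretrained dVAE tokenizer BEiT v1 uses to produce discrete visual tokens
# as reconstruction targets; its 8192-entry codebook is what num_embed=8192 in
# BEiTV1Head corresponds to. The tokenizer consumes the second input view
# normalized by TwoNormDataPreprocessor, while the backbone sees the first.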

@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_beitv1():
    register_all_modules()

    model = BEiT(
        backbone=backbone,
        neck=neck,
        head=head,
        target_generator=target_generator,
        data_preprocessor=data_preprocessor)

    fake_img = torch.rand((1, 3, 224, 224))
    fake_target_img = torch.rand((1, 3, 112, 112))
    fake_mask = torch.zeros((196)).bool()
    fake_mask[75:150] = 1
    fake_data_sample = SelfSupDataSample()
    fake_mask = InstanceData(value=fake_mask)
    fake_data_sample.mask = fake_mask
    fake_data_sample = [fake_data_sample]

    fake_data = {
        'inputs': [fake_img, fake_target_img],
        'data_sample': fake_data_sample
    }

    fake_batch_inputs, fake_data_samples = model.data_preprocessor(fake_data)
    fake_outputs = model(fake_batch_inputs, fake_data_samples, mode='loss')
    assert isinstance(fake_outputs['loss'].item(), float)
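
As a side note (not part of the test file): the mask length of 196 follows directly from the 224x224 input and the 16x16 patch size, and the 75:150 slice marks 75 of those patches for reconstruction. A minimal sketch of that arithmetic:

img_size, patch_size = 224, 16
patch_grid = img_size // patch_size   # 14 patches per side
num_patches = patch_grid ** 2         # 14 * 14 = 196, the mask length used above
num_masked = 150 - 75                 # 75 patches flagged as masked
assert num_patches == 196 and num_masked == 75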