[Feature]: Add BEiT Support (#425)

* [Feature]: Add BEiT Support

* [Fix]: fix bugs after update

* [Fix]: fix bugs in backbone

* [Refactor]: refactor config

* [Feature]: Support BEiTv2

* [Fix]: Fix UT

* [Fix]: rename some configs

* [Fix]: fix beitv2neck

* [Refactor]: refactor beitv2

* [Fix]: fix lint

* refactor configs

* refactor beitv2

* update configs

* add dalle target generator

* refactor for beitv1

* refactor rel_pos_bias of beit

* update configs

* update configs

* update v1 configs

* update v2 configs

* refactor layer decay

* update unittest

* fix lint

* fix ut

* add docstrings

* rename

* fix lint

* add beit model and log links

* fix lint

* update according to review

* update

* update

* update LearningRateDecayOptimWrapperConstructor related configs

* update init and backbone

* update neck and vqkd

* refactor neck

* fix lint

* add some comments

* fix typo

Co-authored-by: 任琴 <PJLAB\renqin@shai14001114l.pjlab.org>
Co-authored-by: fangyixiao18 <fangyx18@hotmail.com>
RenQin 2022-12-06 16:40:05 +08:00 committed by GitHub
parent ce08509014
commit 7a7b048f23
44 changed files with 2475 additions and 50 deletions

View File

@ -19,14 +19,14 @@ optim_wrapper = dict(
optimizer=dict(
type='AdamW', lr=5e-3, model_type='swin', layer_decay_rate=0.9),
clip_grad=dict(max_norm=5.0),
constructor='mmselfsup.LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
norm_decay_mult=0.0,
bias_decay_mult=0.0,
custom_keys={
'.norm': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),
'.absolute_pos_embed': dict(decay_mult=0.0),
'.relative_position_bias_table': dict(decay_mult=0.0)
}),
constructor='mmselfsup.LearningRateDecayOptimWrapperConstructor')
}))
# learning rate scheduler
param_scheduler = [

View File

@ -97,7 +97,14 @@ optim_wrapper = dict(
weight_decay=0.05,
model_type='vit', # layer-wise lr decay type
layer_decay_rate=0.65), # layer-wise lr decay factor
constructor='mmselfsup.LearningRateDecayOptimWrapperConstructor')
constructor='mmselfsup.LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
custom_keys={
'.ln': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),
'.cls_token': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0)
}))
# learning rate scheduler
param_scheduler = [

View File

@ -88,9 +88,9 @@ optim_wrapper = dict(
layer_decay_rate=0.65), # layer-wise lr decay factor
constructor='mmselfsup.LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
norm_decay_mult=0.0,
bias_decay_mult=0.0,
custom_keys={
'.ln': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),
'.cls_token': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0)
}))

View File

@ -0,0 +1,52 @@
# dataset settings
dataset_type = 'mmcls.ImageNet'
data_root = 'data/imagenet/'
file_client_args = dict(backend='disk')
train_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='ColorJitter',
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandomResizedCropAndInterpolationWithTwoPic',
size=224,
second_size=112,
interpolation='bicubic',
second_interpolation='lanczos',
scale=(0.08, 1.0)),
dict(
type='BEiTMaskGenerator',
input_size=(14, 14),
num_masking_patches=75,
max_num_patches=None,
min_num_patches=16),
dict(
type='PackSelfSupInputs',
algorithm_keys=['mask'],
meta_keys=['img_path'])
]
data_preprocessor = dict(
type='TwoNormDataPreprocessor',
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
second_mean=(-20.4, -20.4, -20.4),
second_std=(204., 204., 204.),
bgr_to_rgb=True)
train_dataloader = dict(
batch_size=256,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='meta/train.txt',
data_prefix=dict(img_path='train/'),
pipeline=train_pipeline))
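
For intuition, the `BEiTMaskGenerator` above performs block-wise masking on the 14x14 patch grid until about 75 of the 196 patches are masked. Below is a simplified stand-in sketch of the idea (not the actual implementation; the aspect-ratio bounds and the overshoot behaviour are assumptions):

```python
import math
import random

import numpy as np


def blockwise_mask(input_size=(14, 14), num_masking_patches=75,
                   min_num_patches=16, max_num_patches=None):
    """Toy block-wise masking in the spirit of BEiTMaskGenerator.

    Repeatedly stamps rectangular blocks onto the patch grid until at least
    ``num_masking_patches`` patches are masked (it may overshoot slightly,
    unlike the real generator, which caps the count exactly).
    """
    height, width = input_size
    max_num_patches = num_masking_patches if max_num_patches is None \
        else max_num_patches
    mask = np.zeros(input_size, dtype=int)
    while mask.sum() < num_masking_patches:
        # sample a block area and a log-uniform aspect ratio (assumed bounds)
        target_area = random.uniform(min_num_patches, max_num_patches)
        aspect = math.exp(random.uniform(math.log(0.3), math.log(1 / 0.3)))
        block_h = int(round(math.sqrt(target_area * aspect)))
        block_w = int(round(math.sqrt(target_area / aspect)))
        if 0 < block_h <= height and 0 < block_w <= width:
            top = random.randint(0, height - block_h)
            left = random.randint(0, width - block_w)
            mask[top:top + block_h, left:left + block_w] = 1
    return mask  # (14, 14) array of {0, 1}, roughly 75 ones
```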

View File

@ -0,0 +1,52 @@
# dataset settings
dataset_type = 'mmcls.ImageNet'
data_root = 'data/imagenet/'
file_client_args = dict(backend='disk')
train_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='ColorJitter',
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandomResizedCropAndInterpolationWithTwoPic',
size=224,
second_size=224,
interpolation='bicubic',
second_interpolation='bicubic',
scale=(0.2, 1.0)),
dict(
type='BEiTMaskGenerator',
input_size=(14, 14),
num_masking_patches=75,
max_num_patches=75,
min_num_patches=16),
dict(
type='PackSelfSupInputs',
algorithm_keys=['mask'],
meta_keys=['img_path'])
]
data_preprocessor = dict(
type='TwoNormDataPreprocessor',
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
second_mean=(127.5, 127.5, 127.5),
second_std=(127.5, 127.5, 127.5),
bgr_to_rgb=True)
train_dataloader = dict(
batch_size=256,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='meta/train.txt',
data_prefix=dict(img_path='train/'),
pipeline=train_pipeline))

View File

@ -0,0 +1,28 @@
# model settings
model = dict(
type='BEiT',
backbone=dict(
type='BEiTViT',
arch='base',
patch_size=16,
drop_path_rate=0.1,
final_norm=True,
layer_scale_init_value=0.1,
init_cfg=[
dict(type='TruncNormal', std=0.02, layer='Linear'),
dict(type='TruncNormal', std=0.02, layer='Conv2d'),
dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)
]),
neck=None,
head=dict(
type='BEiTV1Head',
embed_dims=768,
num_embed=8192,
loss=dict(type='BEiTLoss')),
target_generator=dict(
type='DALL-E',
init_cfg=dict(
type='Pretrained',
checkpoint= # noqa: E251
'https://download.openmmlab.com/mmselfsup/1.x/target_generator_ckpt/dalle_encoder.pth', # noqa: E501
)))
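
As a quick check, the algorithm described by this config can be built through the registry. A minimal sketch, assuming a 1.x mmselfsup install; the config path is inferred from the `_base_` references used later in this commit:

```python
from mmengine.config import Config

from mmselfsup.registry import MODELS
from mmselfsup.utils import register_all_modules

register_all_modules()  # register mmselfsup modules into the default scope
cfg = Config.fromfile(
    'configs/selfsup/_base_/models/beit_vit-base-p16.py')  # inferred path
model = MODELS.build(cfg.model)  # BEiT algorithm: BEiTViT + BEiTV1Head + DALL-E
```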

View File

@ -0,0 +1,59 @@
# model settings
vqkd_encoder = dict(
arch='base',
img_size=224,
patch_size=16,
in_channels=3,
out_indices=-1,
drop_rate=0.,
drop_path_rate=0.,
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=True,
with_cls_token=True,
avg_token=False,
frozen_stages=-1,
output_cls_token=False,
use_abs_pos_emb=True,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,
layer_scale_init_value=0.,
interpolate_mode='bicubic',
patch_cfg=dict(),
layer_cfgs=dict(),
init_cfg=None)
layer_scale_init_value = 0.1
drop_path_rate = 0. # 0. for 300 epochs and 0.1 for 1600 epochs.
model = dict(
type='BEiT',
backbone=dict(
type='BEiTViT',
arch='base',
patch_size=16,
out_indices=[-4, -1],
drop_path_rate=drop_path_rate,
final_norm=False,
layer_scale_init_value=layer_scale_init_value,
init_cfg=[
dict(type='TruncNormal', std=0.02, layer='Linear'),
dict(type='TruncNormal', std=0.02, layer='Conv2d'),
dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)
]),
neck=dict(
type='BEiTV2Neck',
num_layers=2,
early_layers=9,
backbone_arch='base',
drop_path_rate=drop_path_rate,
layer_scale_init_value=layer_scale_init_value,
),
head=dict(
type='BEiTV2Head',
embed_dims=768,
num_embed=8192,
loss=dict(type='BEiTLoss')),
target_generator=dict(
type='VQKD',
encoder_config=vqkd_encoder,
init_cfg=dict(
type='Pretrained', checkpoint='beit_ckpt/vqkd_encoder.pth')))

View File

@ -0,0 +1,60 @@
# BEiT
> [BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254)
<!-- [ALGORITHM] -->
## Abstract
We introduce a self-supervised vision representation model BEiT, which stands for Bidirectional Encoder representation from Image Transformers. Following BERT developed in the natural language processing area, we propose a masked image modeling task to pretrain vision Transformers. Specifically, each image has two views in our pre-training, i.e., image patches (such as 16x16 pixels), and visual tokens (i.e., discrete tokens). We first "tokenize" the original image into visual tokens. Then we randomly mask some image patches and feed them into the backbone Transformer. The pre-training objective is to recover the original visual tokens based on the corrupted image patches. After pre-training BEiT, we directly fine-tune the model parameters on downstream tasks by appending task layers upon the pretrained encoder. Experimental results on image classification and semantic segmentation show that our model achieves competitive results with previous pre-training methods. For example, base-size BEiT achieves 83.2% top-1 accuracy on ImageNet-1K, significantly outperforming from-scratch DeiT training (81.8%) with the same setup. Moreover, large-size BEiT obtains 86.3% only using ImageNet-1K, even outperforming ViT-L with supervised pre-training on ImageNet-22K (85.2%).
<div align="center">
<img src="https://user-images.githubusercontent.com/36138628/203688351-adac7146-4e71-4ab6-8958-5cfe643a2dc5.png" width="70%"/>
</div>
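
In code terms, the objective reduces to a cross-entropy over the masked positions. Below is a minimal sketch of one loss computation (hedged pseudo-code, not the MMSelfSup implementation; tensor shapes assume ViT-base/16 on 224x224 inputs, and the `Linear` layer stands in for `BEiTV1Head.cls_head`):

```python
import torch
import torch.nn.functional as F

# one training step, shapes for ViT-base/16 on 224x224 images
img_latent = torch.randn(2, 197, 768)         # backbone output (cls + 196 patches)
target_logits = torch.randn(2, 8192, 14, 14)  # DALL-E encoder output on the second view
mask = torch.zeros(2, 14, 14, dtype=torch.bool)
mask[:, 3:8, 3:8] = True                      # ~25 masked patches

head = torch.nn.Linear(768, 8192)             # stand-in for the classification head
token_ids = target_logits.argmax(dim=1).flatten(1)           # (2, 196) visual tokens
logits = head(img_latent[:, 1:][mask.flatten(1)])            # logits at masked patches
loss = F.cross_entropy(logits, token_ids[mask.flatten(1)])   # recover masked tokens
```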
## Models and Benchmarks
Here, we report the results of the model on ImageNet. The details are below:
<table class="docutils">
<thead>
<tr>
<th rowspan="2">Algorithm</th>
<th rowspan="2">Backbone</th>
<th rowspan="2">Epoch</th>
<th rowspan="2">Batch Size</th>
<th colspan="2" align="center">Results (Top-1 %)</th>
<th colspan="3" align="center">Links</th>
</tr>
<tr>
<th>Linear Eval</th>
<th>Fine-tuning</th>
<th>Pretrain</th>
<th>Linear Eval</th>
<th>Fine-tuning</th>
</tr>
</thead>
<tbody>
<tr>
<td>BEiT</td>
<td>ViT-base</td>
<td>300</td>
<td>2048</td>
<td>/</td>
<td>83.1</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k_20221128-ab79e626.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k_20221123_103802.json'>log</a></td>
<td>/</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/beit/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20221128-0ca393e9.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20221127_162126.json'>log</a></td>
</tr>
</tbody>
</table>
## Citation
```bibtex
@inproceedings{bao2022beit,
title={{BE}iT: {BERT} Pre-Training of Image Transformers},
author={Hangbo Bao and Li Dong and Songhao Piao and Furu Wei},
booktitle={International Conference on Learning Representations},
year={2022},
}
```

View File

@ -0,0 +1,56 @@
_base_ = [
'../_base_/models/beit_vit-base-p16.py',
'../_base_/datasets/imagenet_beit.py',
'../_base_/schedules/adamw_coslr-300e_in1k.py',
'../_base_/default_runtime.py',
]
# optimizer wrapper
optimizer = dict(
type='AdamW', lr=1.5e-3, betas=(0.9, 0.999), weight_decay=0.05)
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=optimizer,
clip_grad=dict(max_norm=3.0),
paramwise_cfg=dict(
custom_keys={
# the following configurations are designed for BEiT
'.ln': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),
'q_bias': dict(decay_mult=0.0),
'v_bias': dict(decay_mult=0.0),
'.cls_token': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0),
'.gamma': dict(decay_mult=0.0),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=10,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
eta_min=1e-5,
by_epoch=True,
begin=10,
end=300,
convert_to_iter_based=True)
]
# runtime settings
default_hooks = dict(
logger=dict(type='LoggerHook', interval=100),
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
# randomness
randomness = dict(seed=0, diff_rank_seed=True)
find_unused_parameters = True
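
The usual entry points are `tools/train.py` and `tools/dist_train.sh`, but the config can also be driven programmatically through MMEngine. A hedged sketch (the work directory is an assumption):

```python
from mmengine.config import Config
from mmengine.runner import Runner

from mmselfsup.utils import register_all_modules

register_all_modules()
cfg = Config.fromfile(
    'configs/selfsup/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k.py')
cfg.work_dir = 'work_dirs/beit_vit-base-p16'  # assumed output directory
runner = Runner.from_cfg(cfg)
runner.train()
```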

View File

@ -0,0 +1,134 @@
# mmcls:: means we use the default settings from MMClassification
_base_ = [
'mmcls::_base_/datasets/imagenet_bs64_swin_224.py',
'mmcls::_base_/schedules/imagenet_bs1024_adamw_swin.py',
'mmcls::_base_/default_runtime.py'
]
data_preprocessor = dict(
num_classes=1000,
mean=[127.5, 127.5, 127.5],
std=[127.5, 127.5, 127.5],
to_rgb=True,
)
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='BEiT',
arch='base',
img_size=224,
patch_size=16,
drop_path_rate=0.1,
avg_token=True,
output_cls_token=False,
use_abs_pos_emb=False,
use_rel_pos_bias=True,
use_shared_rel_pos_bias=False),
neck=None,
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=768,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
init_cfg=[dict(type='TruncNormal', layer='Linear', std=0.02)]),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]))
file_client_args = dict(backend='disk')
train_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=0.3333333333333333,
fill_color=[103.53, 116.28, 123.675],
fill_std=[57.375, 57.12, 58.395]),
dict(type='PackClsInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='ResizeEdge',
scale=256,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackClsInputs')
]
train_dataloader = dict(batch_size=128, dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(batch_size=128, dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader
# optimizer wrapper
optim_wrapper = dict(
optimizer=dict(
type='AdamW',
lr=4e-3,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999),
model_type='vit',
layer_decay_rate=0.65),
constructor='mmselfsup.LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
_delete_=True,
custom_keys={
# the following configurations are designed for BEiT
'.ln': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),
'q_bias': dict(decay_mult=0.0),
'v_bias': dict(decay_mult=0.0),
'.cls_token': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0),
'.gamma': dict(decay_mult=0.0),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=20,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
by_epoch=True,
begin=20,
end=100,
eta_min=1e-6,
convert_to_iter_based=True)
]
# runtime settings
default_hooks = dict(
# save checkpoint per epoch.
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=2))
train_cfg = dict(by_epoch=True, max_epochs=100)
randomness = dict(seed=0)

View File

@ -0,0 +1,35 @@
Collections:
- Name: BEiT
Metadata:
Training Data: ImageNet-1k
Training Techniques:
- AdamW
Training Resources: 8x A100-80G GPUs
Architecture:
- ViT
Paper:
URL: https://arxiv.org/abs/2106.08254
Title: "BEiT: BERT Pre-Training of Image Transformers"
README: configs/selfsup/beit/README.md
Models:
- Name: beit_vit-base-p16_8xb256-amp-coslr-300e_in1k
In Collection: BEiT
Metadata:
Epochs: 300
Batch Size: 2048
Results: null
Config: configs/selfsup/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k.py
Weights: https://download.openmmlab.com/mmselfsup/1.x/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k_20221128-ab79e626.pth
Downstream:
- Type: Image Classification
Metadata:
Epochs: 100
Batch Size: 1024
Results:
- Task: Fine-tuning
Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.1
Config: configs/selfsup/beit/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py
Weights: https://download.openmmlab.com/mmselfsup/1.x/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20221128-0ca393e9.pth

View File

@ -0,0 +1,60 @@
# BEiT
> [BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers](https://arxiv.org/abs/2208.06366)
<!-- [ALGORITHM] -->
## Abstract
Masked image modeling (MIM) has demonstrated impressive results in self-supervised representation learning by recovering corrupted image patches. However, most existing studies operate on low-level image pixels, which hinders the exploitation of high-level semantics for representation models. In this work, we propose to use a semantic-rich visual tokenizer as the reconstruction target for masked prediction, providing a systematic way to promote MIM from pixel-level to semantic-level. Specifically, we propose vector-quantized knowledge distillation to train the tokenizer, which discretizes a continuous semantic space to compact codes. We then pretrain vision Transformers by predicting the original visual tokens for the masked image patches. Furthermore, we introduce a patch aggregation strategy which associates discrete image patches to enhance global semantic representation. Experiments on image classification and semantic segmentation show that BEiT v2 outperforms all compared MIM methods. On ImageNet-1K (224 size), the base-size BEiT v2 achieves 85.5% top-1 accuracy for fine-tuning and 80.1% top-1 accuracy for linear probing. The large-size BEiT v2 obtains 87.3% top-1 accuracy for ImageNet-1K (224 size) fine-tuning, and 56.7% mIoU on ADE20K for semantic segmentation.
<div align="center">
<img src="https://user-images.githubusercontent.com/36138628/203912182-5967a520-d455-49ea-bc67-dcbd500d76bf.png" width="70%"/>
</div>
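
Relative to BEiT v1, the pieces added in this commit are a VQ-KD tokenizer as the target generator and a small neck (`BEiTV2Neck`) that pairs the final-layer cls_token with early-layer patch tokens, so two sets of logits predict the same visual tokens through a shared head. A hedged sketch of the resulting two-term loss (shapes assume ViT-base/16; all tensors below are random stand-ins):

```python
import torch
import torch.nn.functional as F

B, L, D, V = 2, 196, 768, 8192
feats = torch.randn(B, L, D)              # final-layer patch features (cls removed)
feats_cls_pt = torch.randn(B, L, D)       # neck output: cls_token + early patch tokens
token_ids = torch.randint(0, V, (B, L))   # targets from the VQ-KD tokenizer
mask = torch.zeros(B, L, dtype=torch.bool)
mask[:, :75] = True

cls_head = torch.nn.Linear(D, V)          # shared classification head
loss_1 = F.cross_entropy(cls_head(feats[mask]), token_ids[mask])
loss_2 = F.cross_entropy(cls_head(feats_cls_pt[mask]), token_ids[mask])
loss = loss_1 + loss_2                    # both branches share the same targets
```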
## Models and Benchmarks
Here, we report the results of the model on ImageNet. The details are below:
<table class="docutils">
<thead>
<tr>
<th rowspan="2">Algorithm</th>
<th rowspan="2">Backbone</th>
<th rowspan="2">Epoch</th>
<th rowspan="2">Batch Size</th>
<th colspan="2" align="center">Results (Top-1 %)</th>
<th colspan="3" align="center">Links</th>
</tr>
<tr>
<th>Linear Eval</th>
<th>Fine-tuning</th>
<th>Pretrain</th>
<th>Linear Eval</th>
<th>Fine-tuning</th>
</tr>
</thead>
<tbody>
<tr>
<td>BEiT</td>
<td>ViT-base</td>
<td>300</td>
<td>2048</td>
<td>/</td>
<td></td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/beitv2/beitv2_vit-base-p16_8xb256-amp-coslr-300e_in1k.py'>config</a> | <a href=''>model</a> | <a href=''>log</a></td>
<td>/</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/beitv2/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py'>config</a> | <a href=''>model</a> | <a href=''>log</a></td>
</tr>
</tbody>
</table>
## Citation
```bibtex
@article{beitv2,
title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers},
author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei},
journal={ArXiv},
year={2022}
}
```

View File

@ -0,0 +1,55 @@
_base_ = [
'../_base_/models/beitv2_vit-base-p16.py',
'../_base_/datasets/imagenet_beitv2.py',
'../_base_/schedules/adamw_coslr-300e_in1k.py',
'../_base_/default_runtime.py',
]
# optimizer wrapper
optimizer = dict(type='AdamW', lr=1.5e-3, betas=(0.9, 0.98), weight_decay=0.05)
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=optimizer,
clip_grad=dict(max_norm=3.0),
paramwise_cfg=dict(
custom_keys={
# the following configurations are designed for BEiT
'.ln': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),
'q_bias': dict(decay_mult=0.0),
'v_bias': dict(decay_mult=0.0),
'.cls_token': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0),
'.gamma': dict(decay_mult=0.0),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=10,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
eta_min=1e-5,
by_epoch=True,
begin=10,
end=300,
convert_to_iter_based=True)
]
# runtime settings
default_hooks = dict(
logger=dict(type='LoggerHook', interval=100),
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
# randomness
randomness = dict(seed=0, diff_rank_seed=True)
find_unused_parameters = True

View File

@ -0,0 +1,129 @@
# mmcls:: means we use the default settings from MMClassification
_base_ = [
'mmcls::_base_/datasets/imagenet_bs64_swin_224.py',
'mmcls::_base_/schedules/imagenet_bs1024_adamw_swin.py',
'mmcls::_base_/default_runtime.py'
]
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='BEiT',
arch='base',
img_size=224,
patch_size=16,
# 0.2 for 1600 epochs pretrained models and 0.1 for 300 epochs.
drop_path_rate=0.2,
avg_token=True,
output_cls_token=False,
use_abs_pos_emb=False,
use_rel_pos_bias=True,
use_shared_rel_pos_bias=False),
neck=None,
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=768,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
init_cfg=[dict(type='TruncNormal', layer='Linear', std=0.02)]),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]))
file_client_args = dict(backend='disk')
train_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=0.3333333333333333,
fill_color=[103.53, 116.28, 123.675],
fill_std=[57.375, 57.12, 58.395]),
dict(type='PackClsInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='ResizeEdge',
scale=256,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackClsInputs')
]
train_dataloader = dict(batch_size=128, dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(batch_size=128, dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader
# optimizer wrapper
optim_wrapper = dict(
optimizer=dict(
type='AdamW',
lr=5e-4,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999),
model_type='vit',
# 0.6 for 1600 epochs pretrained models and 0.65 for 300 epochs
layer_decay_rate=0.6),
constructor='mmselfsup.LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
_delete_=True,
custom_keys={
# the following configurations are designed for BEiT
'.ln': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),
'q_bias': dict(decay_mult=0.0),
'v_bias': dict(decay_mult=0.0),
'.cls_token': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0),
'.gamma': dict(decay_mult=0.0),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=20,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
by_epoch=True,
begin=20,
end=100,
eta_min=1e-6,
convert_to_iter_based=True)
]
# runtime settings
default_hooks = dict(
# save checkpoint per epoch.
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=2))
train_cfg = dict(by_epoch=True, max_epochs=100)
randomness = dict(seed=0)

View File

@ -0,0 +1,126 @@
# mmcls:: means we use the default settings from MMClassification
_base_ = [
'mmcls::_base_/datasets/imagenet_bs64_swin_224.py',
'mmcls::_base_/schedules/imagenet_bs1024_adamw_swin.py',
'mmcls::_base_/default_runtime.py'
]
# Fine-tuning for 30 epochs is for models which have an intermediate
# fine-tuning stage on ImageNet-21k after self-supervised pretraining.
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='BEiT',
arch='base',
img_size=224,
patch_size=16,
drop_path_rate=0.1,
avg_token=True,
output_cls_token=False,
use_abs_pos_emb=False,
use_rel_pos_bias=True,
use_shared_rel_pos_bias=False),
neck=None,
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=768,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
init_cfg=[dict(type='TruncNormal', layer='Linear', std=0.02)]),
)
file_client_args = dict(backend='disk')
train_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=0.3333333333333333,
fill_color=[103.53, 116.28, 123.675],
fill_std=[57.375, 57.12, 58.395]),
dict(type='PackClsInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='ResizeEdge',
scale=256,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackClsInputs')
]
train_dataloader = dict(batch_size=128, dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(batch_size=128, dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader
# optimizer wrapper
optim_wrapper = dict(
optimizer=dict(
type='AdamW',
lr=5e-5,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999),
model_type='vit', # layer-wise lr decay type
layer_decay_rate=0.75), # layer-wise lr decay factor
constructor='mmselfsup.LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
_delete_=True,
custom_keys={
# the following configurations are designed for BEiTs
'.ln': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),
'q_bias': dict(decay_mult=0.0),
'v_bias': dict(decay_mult=0.0),
'.cls_token': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0),
'.gamma': dict(decay_mult=0.0),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=20,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
by_epoch=True,
begin=20,
end=30,
eta_min=1e-6,
convert_to_iter_based=True)
]
# runtime settings
default_hooks = dict(
# save checkpoint per epoch.
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=2))
train_cfg = dict(by_epoch=True, max_epochs=30)
randomness = dict(seed=0)

View File

@ -0,0 +1,35 @@
Collections:
- Name: BEiTv2
Metadata:
Training Data: ImageNet-1k
Training Techniques:
- AdamW
Training Resources: 8x A100-80G GPUs
Architecture:
- ViT
Paper:
URL: https://arxiv.org/abs/2208.06366
Title: 'BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers'
README: configs/selfsup/beitv2/README.md
Models:
- Name: beitv2_vit-base-p16_8xb256-amp-coslr-300e_in1k
In Collection: BEiTv2
Metadata:
Epochs: 300
Batch Size: 2048
Results: null
Config: configs/selfsup/beitv2/beitv2_vit-base-p16_8xb256-amp-coslr-300e_in1k.py
Weights:
Downstream:
- Type: Image Classification
Metadata:
Epochs: 100
Batch Size: 1024
Results:
- Task: Fine-tuning
Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy:
Config:
Weights:

View File

@ -393,5 +393,16 @@ ImageNet has multiple versions, but the most commonly used one is ILSVRC 2012. T
<td>/</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/benchmarks/classification/imagenet/vit-base-p16_ft-8xb256-coslr-100e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/maskfeat/maskfeat_vit-base-p16_8xb256-amp-coslr-300e_in1k/vit-base-p16_ft-8xb256-coslr-100e_in1k/vit-base-p16_ft-8xb256-coslr-100e_in1k_20221028-5134431c.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/maskfeat/maskfeat_vit-base-p16_8xb256-amp-coslr-300e_in1k/vit-base-p16_ft-8xb256-coslr-100e_in1k/vit-base-p16_ft-8xb256-coslr-100e_in1k_20221026_105344.json'>log</a></td>
</tr>
<tr>
<td>BEiT</td>
<td>ViT-base</td>
<td>300</td>
<td>2048</td>
<td>/</td>
<td>83.1</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k_20221128-ab79e626.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k_20221123_103802.json'>log</a></td>
<td>/</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/beit/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20221128-0ca393e9.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20221127_162126.json'>log</a></td>
</tr>
</tbody>
</table>

View File

@ -393,5 +393,16 @@ ImageNet has multiple versions, but the most commonly used one is ILSVRC 2012. We provide
<td>/</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/benchmarks/classification/imagenet/vit-base-p16_ft-8xb256-coslr-100e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/maskfeat/maskfeat_vit-base-p16_8xb256-amp-coslr-300e_in1k/vit-base-p16_ft-8xb256-coslr-100e_in1k/vit-base-p16_ft-8xb256-coslr-100e_in1k_20221028-5134431c.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/maskfeat/maskfeat_vit-base-p16_8xb256-amp-coslr-300e_in1k/vit-base-p16_ft-8xb256-coslr-100e_in1k/vit-base-p16_ft-8xb256-coslr-100e_in1k_20221026_105344.json'>log</a></td>
</tr>
<tr>
<td>BEiT</td>
<td>ViT-base</td>
<td>300</td>
<td>2048</td>
<td>/</td>
<td>83.1</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k_20221128-ab79e626.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k_20221123_103802.json'>log</a></td>
<td>/</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/beit/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20221128-0ca393e9.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20221127_162126.json'>log</a></td>
</tr>
</tbody>
</table>

View File

@ -66,11 +66,8 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
Note: Currently, this optimizer constructor is built for ViT and Swin.
In addition to applying layer-wise learning rate decay schedule, this
module will not apply weight decay to ``normalization parameters``,
``bias``, ``position embedding``, ``class token``, and
``relative position bias table, automatically. What's more, the
``paramwise_cfg`` in the base module will be ignored.
In addition to applying the layer-wise learning rate decay schedule, the
``paramwise_cfg`` only supports weight decay customization.
"""
def add_params(self, params: List[dict], module: nn.Module,
@ -87,14 +84,27 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
optimizer_cfg (dict): The configuration of optimizer.
prefix (str): The prefix of the module.
"""
# get param-wise options
custom_keys = self.paramwise_cfg.get('custom_keys', {})
# first sort with alphabet order and then sort with reversed len of str
sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
# get logger
logger = MMLogger.get_current_instance()
logger.warning(
'LearningRateDecayOptimWrapperConstructor is refactored in '
'v1.0.0rc4, which requires configuring zero weight decay manually. '
'The previous versions would set zero weight decay according to '
'the dimension of parameter. Please specify weight decay settings '
'of different layers in config if needed.')
# Check if self.param_cfg is not None
if len(self.paramwise_cfg) > 0:
logger.info('The paramwise_cfg will be ignored, and normalization \
parameters, bias, position embedding, class token and \
relative position bias table will not be decayed by \
default.')
logger.info(
'The paramwise_cfg only supports weight decay customization '
'in LearningRateDecayOptimWrapperConstructor, please indicate '
'the specific weight decay settings of different layers in '
'config if needed.')
model_type = optimizer_cfg.pop('model_type', None)
# model_type should not be None
@ -111,24 +121,25 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
elif model_type == 'swin':
num_layers = sum(module.backbone.depths) + 2
weight_decay = self.base_wd
# if layer_decay_rate is not provided, do not decay
decay_rate = optimizer_cfg.pop('layer_decay_rate', 1.0)
parameter_groups = {}
assert self.base_wd is not None
for name, param in module.named_parameters():
if not param.requires_grad:
continue # frozen weights
# will not decay normalization params, bias, position embedding
# class token, relative position bias table
if len(param.shape) == 1 or name.endswith('.bias') or name in (
'backbone.pos_embed', 'backbone.cls_token'
) or 'relative_position_bias_table' in name:
this_weight_decay = self.base_wd
for key in sorted_keys:
if key in name:
decay_mult = custom_keys[key].get('decay_mult', 1.)
this_weight_decay = self.base_wd * decay_mult
if this_weight_decay == 0:
group_name = 'no_decay'
this_weight_decay = 0.
else:
group_name = 'decay'
this_weight_decay = weight_decay
if model_type == 'vit':
layer_id = get_layer_id_for_vit(name, num_layers)

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .barlowtwins import BarlowTwins
from .base import BaseModel
from .beit import BEiT
from .byol import BYOL
from .cae import CAE
from .deepcluster import DeepCluster
@ -19,7 +20,7 @@ from .simsiam import SimSiam
from .swav import SwAV
__all__ = [
'BaseModel', 'BarlowTwins', 'BYOL', 'DeepCluster', 'DenseCL', 'MoCo',
'NPID', 'ODC', 'RelativeLoc', 'RotationPred', 'SimCLR', 'SimSiam', 'SwAV',
'MAE', 'MoCoV3', 'SimMIM', 'CAE', 'MaskFeat'
'BaseModel', 'BarlowTwins', 'BEiT', 'BYOL', 'DeepCluster', 'DenseCL',
'MoCo', 'NPID', 'ODC', 'RelativeLoc', 'RotationPred', 'SimCLR', 'SimSiam',
'SwAV', 'MAE', 'MoCoV3', 'SimMIM', 'CAE', 'MaskFeat'
]

View File

@ -0,0 +1,67 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Tuple
import torch
from mmselfsup.registry import MODELS
from mmselfsup.structures import SelfSupDataSample
from .base import BaseModel
@MODELS.register_module()
class BEiT(BaseModel):
"""BEiT v1/v2.
Implementation of `BEiT: BERT Pre-Training of Image Transformers
<https://arxiv.org/abs/2106.08254>`_. Implementation of `BEiT v2: Masked
Image Modeling with Vector-Quantized Visual Tokenizers
<https://arxiv.org/abs/2208.06366>`_.
"""
def loss(self, batch_inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""The forward function in training.
Args:
batch_inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
mask = torch.stack(
[data_sample.mask.value for data_sample in data_samples])
img_latent = self.backbone(batch_inputs[0], mask)
# batch_inputs[1] is the target image
with torch.no_grad():
target = self.target_generator(batch_inputs[1])
target = target.detach()
if self.with_neck:
# BEiT v2
feats, feats_cls_pt = self.neck(
img_latent, rel_pos_bias=self.backbone.shared_rel_pos_bias)
loss = self.head(feats, feats_cls_pt, target, mask)
else:
# BEiT v1
loss = self.head(img_latent[0], target, mask)
if isinstance(loss, torch.Tensor):
losses = dict(loss=loss)
return losses
elif isinstance(loss, Tuple):
# the loss_1 and loss_2 are general reconstruction loss (patch
# feature vectors from last layer of backbone) and early state
# reconstruction loss (patch feature vectors from intermediate
# layer of backbone)
loss_1, loss_2 = loss[0], loss[1]
losses = dict()
# the key with prefix 'loss', like loss_1 and loss_2, will be used
# as the final criterion
losses['loss_1'] = loss_1
losses['loss_2'] = loss_2
return losses
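
For reference, MMEngine's ``BaseModel.parse_losses`` sums every entry of the returned dict whose key contains ``'loss'``, so ``loss_1`` and ``loss_2`` above are both optimized. A minimal illustration of that reduction (toy values):

```python
import torch

losses = dict(loss_1=torch.tensor(2.3), loss_2=torch.tensor(1.7))
# mmengine's BaseModel.parse_losses sums every value whose key contains 'loss'
total_loss = sum(v for k, v in losses.items() if 'loss' in k)  # tensor(4.)
```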

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .beit_vit import BEiTViT
from .cae_vit import CAEViT
from .mae_vit import MAEViT
from .maskfeat_vit import MaskFeatViT
@ -9,5 +10,5 @@ from .simmim_swin import SimMIMSwinTransformer
__all__ = [
'ResNet', 'ResNetSobel', 'ResNetV1d', 'ResNeXt', 'MAEViT', 'MoCoV3ViT',
'SimMIMSwinTransformer', 'CAEViT', 'MaskFeatViT'
'SimMIMSwinTransformer', 'CAEViT', 'MaskFeatViT', 'BEiTViT'
]

View File

@ -0,0 +1,188 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Tuple, Union
import torch
from mmcls.models import BEiT, resize_pos_embed
from mmengine.model.weight_init import trunc_normal_
from torch import nn
from mmselfsup.registry import MODELS
@MODELS.register_module()
class BEiTViT(BEiT):
"""Vision Transformer for BEiT pre-training.
Rewritten version of: `An Image is Worth 16x16 Words: Transformers
for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_
Args:
arch (str | dict): Vision Transformer architecture. If use string,
choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small'
and 'deit-base'. If use dict, it should have below keys:
- **embed_dims** (int): The dimensions of embedding.
- **num_layers** (int): The number of transformer encoder layers.
- **num_heads** (int): The number of heads in attention modules.
- **feedforward_channels** (int): The hidden dimensions in
feedforward modules.
Defaults to 'base'.
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 224.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 16.
in_channels (int): The num of input channels. Defaults to 3.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, which means the last stage.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
qkv_bias (bool): Whether to add bias for qkv in attention modules.
Defaults to True.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add an additional layer to normalize the
final feature map. Defaults to True.
with_cls_token (bool): Whether concatenating class token into image
tokens as transformer input. Defaults to True.
avg_token (bool): Whether or not to use the mean patch token for
classification. If True, the model will only take the average
of all patch tokens. Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
output_cls_token (bool): Whether output the cls_token. If set True,
``with_cls_token`` must be True. Defaults to True.
use_abs_pos_emb (bool): Whether or not to use absolute position embedding.
Defaults to False.
use_rel_pos_bias (bool): Whether or not to use relative position bias.
Defaults to False.
use_shared_rel_pos_bias (bool): Whether or not to use shared relative
position bias. Defaults to True.
layer_scale_init_value (float): The initialization value for
the learnable scaling of attention and FFN. Defaults to 0.1.
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
arch: str = 'base',
img_size: int = 224,
patch_size: int = 16,
in_channels: int = 3,
out_indices: int = -1,
drop_rate: float = 0,
drop_path_rate: float = 0,
norm_cfg: dict = dict(type='LN', eps=1e-6),
final_norm: bool = True,
avg_token: bool = False,
frozen_stages: int = -1,
output_cls_token: bool = True,
use_abs_pos_emb: bool = False,
use_rel_pos_bias: bool = False,
use_shared_rel_pos_bias: bool = True,
layer_scale_init_value: int = 0.1,
interpolate_mode: str = 'bicubic',
patch_cfg: dict = dict(padding=0),
layer_cfgs: dict = dict(),
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(
arch=arch,
img_size=img_size,
patch_size=patch_size,
in_channels=in_channels,
out_indices=out_indices,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
norm_cfg=norm_cfg,
final_norm=final_norm,
avg_token=avg_token,
frozen_stages=frozen_stages,
output_cls_token=output_cls_token,
use_abs_pos_emb=use_abs_pos_emb,
use_shared_rel_pos_bias=use_shared_rel_pos_bias,
use_rel_pos_bias=use_rel_pos_bias,
layer_scale_init_value=layer_scale_init_value,
interpolate_mode=interpolate_mode,
patch_cfg=patch_cfg,
layer_cfgs=layer_cfgs,
init_cfg=init_cfg)
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
def init_weights(self) -> None:
"""Initialize position embedding, patch embedding and cls token."""
super().init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress default init if use pretrained model.
return
trunc_normal_(self.cls_token, std=0.02)
trunc_normal_(self.mask_token, std=0.02)
self.rescale_init_weight()
def rescale_init_weight(self) -> None:
"""Rescale the initialized weights."""
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.layers):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.ffn.layers[1].weight.data, layer_id + 1)
def forward(self, x: torch.Tensor,
mask: torch.Tensor) -> Tuple[torch.Tensor]:
"""The BEiT style forward function.
Args:
x (torch.Tensor): Input images, which is of shape (B x C x H x W).
mask (torch.Tensor): Mask for input, which is of shape
(B x patch_resolution[0] x patch_resolution[1]).
Returns:
Tuple[torch.Tensor]: Hidden features.
"""
x, patch_resolution = self.patch_embed(x)
# replace the masked visual tokens by mask_token
B, L, _ = x.shape
mask_token = self.mask_token.expand(B, L, -1)
w = mask.flatten(1).unsqueeze(-1).type_as(mask_token)
x = x * (1. - w) + mask_token * w
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None:
x = x + resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=self.num_extra_tokens)
x = self.drop_after_pos(x)
self.shared_rel_pos_bias = self.rel_pos_bias().to(
mask.device) if self.rel_pos_bias is not None else None
outs = []
for i, layer in enumerate(self.layers):
x = layer(x, rel_pos_bias=self.shared_rel_pos_bias)
if i == len(self.layers) - 1 and self.final_norm:
x = self.norm1(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
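
A minimal usage sketch of the backbone defined above (assuming mmcls and this version of mmselfsup are installed; default base/16 settings, so 224x224 inputs give a 14x14 patch grid):

```python
import torch

from mmselfsup.models.backbones import BEiTViT

backbone = BEiTViT(arch='base', patch_size=16)
imgs = torch.randn(2, 3, 224, 224)
mask = torch.zeros(2, 14, 14, dtype=torch.bool)
mask[:, :5, :5] = True                 # mask 25 of the 196 patches

feats = backbone(imgs, mask)           # tuple with one (2, 197, 768) tensor
```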

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .beitv1_head import BEiTV1Head
from .beitv2_head import BEiTV2Head
from .cae_head import CAEHead
from .cls_head import ClsHead
from .contrastive_head import ContrastiveHead
@ -11,7 +13,8 @@ from .simmim_head import SimMIMHead
from .swav_head import SwAVHead
__all__ = [
'ContrastiveHead', 'ClsHead', 'LatentPredictHead',
'LatentCrossCorrelationHead', 'MultiClsHead', 'MAEPretrainHead',
'MoCoV3Head', 'SimMIMHead', 'CAEHead', 'SwAVHead', 'MaskFeatPretrainHead'
'BEiTV1Head', 'BEiTV2Head', 'ContrastiveHead', 'ClsHead',
'LatentPredictHead', 'LatentCrossCorrelationHead', 'MultiClsHead',
'MAEPretrainHead', 'MoCoV3Head', 'SimMIMHead', 'CAEHead', 'SwAVHead',
'MaskFeatPretrainHead'
]

View File

@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import torch
import torch.nn as nn
from mmengine.model import BaseModule
from mmselfsup.registry import MODELS
@MODELS.register_module()
class BEiTV1Head(BaseModule):
"""Pretrain Head for BEiT v1.
Compute the logits and the cross entropy loss.
Args:
embed_dims (int): The dimension of embedding.
num_embed (int): The number of classification types.
loss (dict): The config of loss.
init_cfg (dict or List[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(
self,
embed_dims: int,
num_embed: int,
loss: dict,
init_cfg: Optional[Union[dict, List[dict]]] = dict(
type='TruncNormal', layer='Linear', std=0.02, bias=0)
) -> None:
super().__init__(init_cfg=init_cfg)
self.cls_head = nn.Linear(embed_dims, num_embed)
self.loss = MODELS.build(loss)
def forward(self, feats: torch.Tensor, target: torch.Tensor,
mask: torch.Tensor) -> torch.Tensor:
"""Generate loss.
Args:
feats (torch.Tensor): Features from backbone.
target (torch.Tensor): Target generated by target_generator.
mask (torch.Tensor): Generated mask for pretraining.
"""
mask = mask.flatten(1).to(torch.bool)
target = torch.argmax(target, dim=1).flatten(1)
target = target[mask]
# remove cls_token
feats = feats[:, 1:]
logits = self.cls_head(feats[mask])
loss = self.loss(logits, target)
return loss

View File

@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import torch
import torch.nn as nn
from mmengine.model import BaseModule
from mmselfsup.registry import MODELS
@MODELS.register_module()
class BEiTV2Head(BaseModule):
"""Pretrain Head for BEiT.
Compute the logits and the cross entropy loss.
Args:
embed_dims (int): The dimension of embedding.
num_embed (int): The number of classification types.
loss (dict): The config of loss.
init_cfg (dict or List[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(
self,
embed_dims: int,
num_embed: int,
loss: dict,
init_cfg: Optional[Union[dict, List[dict]]] = dict(
type='TruncNormal', layer='Linear', std=0.02, bias=0)
) -> None:
super().__init__(init_cfg=init_cfg)
self.cls_head = nn.Linear(embed_dims, num_embed)
self.loss = MODELS.build(loss)
def forward(self, feats: torch.Tensor, feats_cls_pt: torch.Tensor,
target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""Generate loss.
Args:
feats (torch.Tensor): Features from backbone.
feats_cls_pt (torch.Tensor): Features of the final-layer cls_token
concatenated with early-layer patch tokens, produced by the neck.
target (torch.Tensor): Target generated by target_generator.
mask (torch.Tensor): Generated mask for pretraining.
"""
mask = mask.flatten(1).to(torch.bool)
target = target[mask]
# shared cls head
logits = self.cls_head(feats[mask])
logits_cls_pt = self.cls_head(feats_cls_pt[mask])
loss = self.loss((logits, logits_cls_pt), target)
return loss

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .beit_loss import BEiTLoss
from .cae_loss import CAELoss
from .cosine_similarity_loss import CosineSimilarityLoss
from .cross_correlation_loss import CrossCorrelationLoss
@ -8,7 +9,7 @@ from .simmim_loss import SimMIMReconstructionLoss
from .swav_loss import SwAVLoss
__all__ = [
'CAELoss', 'CrossCorrelationLoss', 'CosineSimilarityLoss',
'BEiTLoss', 'CAELoss', 'CrossCorrelationLoss', 'CosineSimilarityLoss',
'MAEReconstructionLoss', 'SimMIMReconstructionLoss', 'SwAVLoss',
'PixelReconstructionLoss'
]

View File

@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union
import torch
from mmengine.model import BaseModule
from torch import nn
from mmselfsup.registry import MODELS
@MODELS.register_module()
class BEiTLoss(BaseModule):
"""Loss function for BEiT.
The BEiTLoss supports two different sets of logits sharing one target, as in BEiT v2.
"""
def __init__(self) -> None:
super().__init__()
self.loss_cross_entropy = nn.CrossEntropyLoss()
def forward(self, logits: Union[Tuple[torch.Tensor], torch.Tensor],
target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward function of BEiT Loss.
Args:
logits (Union[Tuple[torch.Tensor], torch.Tensor]): A single tensor of
logits, or a tuple of two logits tensors sharing the same target.
target (torch.Tensor): The target visual tokens generated by the
target generator (DALL-E for BEiT v1, VQ-KD for BEiT v2).
Returns:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The loss, or a
tuple of two losses when two sets of logits are given.
"""
if isinstance(logits, torch.Tensor):
loss = self.loss_cross_entropy(logits, target)
return loss
elif isinstance(logits, Tuple):
loss_1 = self.loss_cross_entropy(logits[0], target)
loss_2 = self.loss_cross_entropy(logits[1], target)
return loss_1, loss_2

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .avgpool2d_neck import AvgPool2dNeck
from .beitv2_neck import BEiTV2Neck
from .cae_neck import CAENeck
from .densecl_neck import DenseCLNeck
from .linear_neck import LinearNeck
@ -12,7 +13,7 @@ from .simmim_neck import SimMIMNeck
from .swav_neck import SwAVNeck
__all__ = [
'AvgPool2dNeck', 'DenseCLNeck', 'LinearNeck', 'MoCoV2Neck',
'AvgPool2dNeck', 'BEiTV2Neck', 'DenseCLNeck', 'LinearNeck', 'MoCoV2Neck',
'NonLinearNeck', 'ODCNeck', 'RelativeLocNeck', 'SwAVNeck',
'MAEPretrainDecoder', 'SimMIMNeck', 'CAENeck', 'ClsBatchNormNeck'
]

View File

@ -0,0 +1,153 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from mmcls.models.backbones.beit import BEiTTransformerEncoderLayer
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule
from mmselfsup.registry import MODELS
@MODELS.register_module()
class BEiTV2Neck(BaseModule):
"""Neck for BEiTV2 Pre-training.
This module constructs the decoder for the final prediction.
Args:
num_layers (int): Number of encoder layers of neck. Defaults to 2.
early_layers (int): The layer index of the early output from the
backbone. Defaults to 9.
backbone_arch (str): Vision Transformer architecture. Defaults to base.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
layer_scale_init_value (float): The initialization value for the
learnable scaling of attention and FFN. Defaults to 0.1.
use_rel_pos_bias (bool): Whether to use a unique relative position bias;
if False, use the shared relative position bias defined in the backbone.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
arch_zoo = {
**dict.fromkeys(
['b', 'base'], {
'embed_dims': 768,
'depth': 12,
'num_heads': 12,
'feedforward_channels': 3072,
}),
**dict.fromkeys(
['l', 'large'], {
'embed_dims': 1024,
'depth': 24,
'num_heads': 16,
'feedforward_channels': 4096,
}),
}
def __init__(
self,
num_layers: int = 2,
early_layers: int = 9,
backbone_arch: str = 'base',
drop_rate: float = 0.,
drop_path_rate: float = 0.,
layer_scale_init_value: float = 0.1,
use_rel_pos_bias: bool = False,
norm_cfg: dict = dict(type='LN', eps=1e-6),
init_cfg: Optional[Union[dict, List[dict]]] = dict(
type='TruncNormal', layer='Linear', std=0.02, bias=0)
) -> None:
super().__init__(init_cfg=init_cfg)
if isinstance(backbone_arch, str):
backbone_arch = backbone_arch.lower()
assert backbone_arch in set(self.arch_zoo), \
(f'Arch {backbone_arch} is not in default archs '
f'{set(self.arch_zoo)}')
self.arch_settings = self.arch_zoo[backbone_arch]
else:
essential_keys = {
'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
}
assert isinstance(backbone_arch, dict) and essential_keys <= set(
backbone_arch
), f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = backbone_arch
# stochastic depth decay rule
self.early_layers = early_layers
depth = self.arch_settings['depth']
dpr = np.linspace(0, drop_path_rate,
max(depth, early_layers + num_layers))
self.patch_aggregation = nn.ModuleList()
for i in range(early_layers, early_layers + num_layers):
_layer_cfg = dict(
embed_dims=self.arch_settings['embed_dims'],
num_heads=self.arch_settings['num_heads'],
feedforward_channels=self.
arch_settings['feedforward_channels'],
drop_rate=drop_rate,
drop_path_rate=dpr[i],
norm_cfg=norm_cfg,
layer_scale_init_value=layer_scale_init_value,
window_size=None,
use_rel_pos_bias=use_rel_pos_bias)
self.patch_aggregation.append(
BEiTTransformerEncoderLayer(**_layer_cfg))
self.rescale_patch_aggregation_init_weight()
embed_dims = self.arch_settings['embed_dims']
_, norm = build_norm_layer(norm_cfg, embed_dims)
self.add_module('norm', norm)
def rescale_patch_aggregation_init_weight(self):
"""Rescale the initialized weights."""
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.patch_aggregation):
rescale(layer.attn.proj.weight.data,
self.early_layers + layer_id + 1)
rescale(layer.ffn.layers[1].weight.data,
self.early_layers + layer_id + 1)
def forward(self, inputs: Tuple[torch.Tensor], rel_pos_bias: torch.Tensor,
**kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get the latent prediction and final prediction.
Args:
inputs (Tuple[torch.Tensor]): Features of tokens from the early and
final layers of the backbone.
rel_pos_bias (torch.Tensor): Shared relative position bias table.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- ``x``: The final layer features from backbone, which are normed
in ``BEiTV2Neck``.
- ``x_cls_pt``: The early state features from backbone, which
consist of the final layer cls_token and the early state patch
tokens, and are sent to the patch aggregation layers in the neck.
"""
early_states, x = inputs[0], inputs[1]
x_cls_pt = torch.cat([x[:, [0]], early_states[:, 1:]], dim=1)
for layer in self.patch_aggregation:
x_cls_pt = layer(x_cls_pt, rel_pos_bias=rel_pos_bias)
# shared norm
x, x_cls_pt = self.norm(x), self.norm(x_cls_pt)
# remove cls_token
x = x[:, 1:]
x_cls_pt = x_cls_pt[:, 1:]
return x, x_cls_pt
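
A hedged shape walk-through of the neck for the base architecture (197 tokens = cls_token + 196 patches; the random tensors stand in for real backbone outputs):

```python
import torch

# BEiTViT with out_indices=[-4, -1] returns (early_states, final_states)
early_states = torch.randn(2, 197, 768)   # intermediate (early_layers-th) output
final_states = torch.randn(2, 197, 768)   # last-layer output

# keep the final cls_token, re-use the early patch tokens
x_cls_pt = torch.cat([final_states[:, :1], early_states[:, 1:]], dim=1)
assert x_cls_pt.shape == (2, 197, 768)
# x_cls_pt then passes through the extra transformer layers, is normed together
# with ``final_states`` by the shared LayerNorm, and both have the cls_token
# stripped before being handed to BEiTV2Head.
```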

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dall_e import Encoder
from .hog_generator import HOGGenerator
from .vqkd import VQKD
__all__ = [
'HOGGenerator',
]
__all__ = ['HOGGenerator', 'VQKD', 'Encoder']

View File

@ -0,0 +1,180 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from BEiT
# https://github.com/microsoft/unilm/blob/master/beit/dall_e/encoder.py
import math
from collections import OrderedDict
from functools import partial
from typing import List, Optional, Union
import attr
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmselfsup.registry import MODELS
@attr.s(eq=False)
class Conv2d(nn.Module):
n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
n_out: int = attr.ib(validator=lambda i, a, x: x >= 1)
kw: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 2 == 1)
use_float16: bool = attr.ib(default=True)
device: torch.device = attr.ib(default=torch.device('cpu'))
requires_grad: bool = attr.ib(default=False)
def __attrs_post_init__(self) -> None:
super().__init__()
w = torch.empty((self.n_out, self.n_in, self.kw, self.kw),
dtype=torch.float32,
device=self.device,
requires_grad=self.requires_grad)
w.normal_(std=1 / math.sqrt(self.n_in * self.kw**2))
b = torch.zeros((self.n_out, ),
dtype=torch.float32,
device=self.device,
requires_grad=self.requires_grad)
self.w, self.b = nn.Parameter(w), nn.Parameter(b)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_float16 and 'cuda' in self.w.device.type:
if x.dtype != torch.float16:
x = x.half()
w, b = self.w.half(), self.b.half()
else:
if x.dtype != torch.float32:
x = x.float()
w, b = self.w, self.b
return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)
@attr.s(eq=False, repr=False)
class EncoderBlock(nn.Module):
n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 == 0)
n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)
device: torch.device = attr.ib(default=None)
requires_grad: bool = attr.ib(default=False)
def __attrs_post_init__(self) -> None:
super().__init__()
self.n_hid = self.n_out // 4
self.post_gain = 1 / (self.n_layers**2)
make_conv = partial(
Conv2d, device=self.device, requires_grad=self.requires_grad)
self.id_path = make_conv(
self.n_in, self.n_out,
1) if self.n_in != self.n_out else nn.Identity()
self.res_path = nn.Sequential(
OrderedDict([
('relu_1', nn.ReLU()),
('conv_1', make_conv(self.n_in, self.n_hid, 3)),
('relu_2', nn.ReLU()),
('conv_2', make_conv(self.n_hid, self.n_hid, 3)),
('relu_3', nn.ReLU()),
('conv_3', make_conv(self.n_hid, self.n_hid, 3)),
('relu_4', nn.ReLU()),
('conv_4', make_conv(self.n_hid, self.n_out, 1)),
]))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.id_path(x) + self.post_gain * self.res_path(x)
@attr.s(eq=False, repr=False)
@MODELS.register_module(name='DALL-E')
class Encoder(BaseModule):
group_count: int = 4
n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64)
n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1)
input_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1)
vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)
device: torch.device = attr.ib(default=torch.device('cpu'))
requires_grad: bool = attr.ib(default=False)
use_mixed_precision: bool = attr.ib(default=True)
init_cfg: Optional[Union[dict, List[dict]]] = attr.ib(default=None)
def __attrs_post_init__(self) -> None:
super().__init__(init_cfg=self.init_cfg)
blk_range = range(self.n_blk_per_group)
n_layers = self.group_count * self.n_blk_per_group
make_conv = partial(
Conv2d, device=self.device, requires_grad=self.requires_grad)
make_blk = partial(
EncoderBlock,
n_layers=n_layers,
device=self.device,
requires_grad=self.requires_grad)
self.blocks = nn.Sequential(
OrderedDict([
('input', make_conv(self.input_channels, 1 * self.n_hid, 7)),
('group_1',
nn.Sequential(
OrderedDict([
*[(f'block_{i + 1}',
make_blk(1 * self.n_hid, 1 * self.n_hid))
for i in blk_range],
('pool', nn.MaxPool2d(kernel_size=2)),
]))),
('group_2',
nn.Sequential(
OrderedDict([
*[(f'block_{i + 1}',
make_blk(
1 * self.n_hid if i == 0 else 2 * self.n_hid,
2 * self.n_hid)) for i in blk_range],
('pool', nn.MaxPool2d(kernel_size=2)),
]))),
('group_3',
nn.Sequential(
OrderedDict([
*[(f'block_{i + 1}',
make_blk(
2 * self.n_hid if i == 0 else 4 * self.n_hid,
4 * self.n_hid)) for i in blk_range],
('pool', nn.MaxPool2d(kernel_size=2)),
]))),
('group_4',
nn.Sequential(
OrderedDict([
*[(f'block_{i + 1}',
make_blk(
4 * self.n_hid if i == 0 else 8 * self.n_hid,
8 * self.n_hid)) for i in blk_range],
]))),
('output',
nn.Sequential(
OrderedDict([
('relu', nn.ReLU()),
('conv',
make_conv(
8 * self.n_hid,
self.vocab_size,
1,
use_float16=False)),
]))),
]))
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.float()
if len(x.shape) != 4:
raise ValueError(f'input shape {x.shape} is not 4d')
if x.shape[1] != self.input_channels:
raise ValueError(f'input has {x.shape[1]} channels but model '
                 f'built for {self.input_channels}')
if x.dtype != torch.float32:
raise ValueError('input must have dtype torch.float32')
return self.blocks(x)
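# A hedged usage sketch (default constructor arguments and the 112 x 112
# target size are assumptions taken from the unit tests): the encoder maps a
# normalized image to per-position logits over the 8192-token vocabulary,
# which are typically reduced to discrete visual tokens with an argmax.
import torch

encoder = Encoder()                        # n_hid=256, vocab_size=8192 by default
img = torch.rand(1, 3, 112, 112)           # already-normalized target image
logits = encoder(img)                      # (1, 8192, 14, 14) after three 2x poolings
tokens = logits.argmax(dim=1).flatten(1)   # (1, 196) visual-token ids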

View File

@ -0,0 +1,104 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional, Tuple
import torch
from einops import rearrange
from mmcls.models import BEiT
from mmengine.model import BaseModule
from torch import nn
from mmselfsup.models.utils import NormEMAVectorQuantizer
from mmselfsup.registry import MODELS
@MODELS.register_module()
class VQKD(BaseModule):
"""Vector-Quantized Knowledge Distillation.
This module only contains the encoder and the VectorQuantizer part.
Modified from https://github.com/microsoft/unilm/blob/master/beit2/modeling_vqkd.py
Args:
encoder_config (dict): The config of encoder.
decoder_config (dict, optional): The config of decoder. Currently,
VQKD only supports building the encoder. Defaults to None.
num_embed (int): Number of embedding vectors in the codebook. Defaults
to 8192.
embed_dims (int) : The dimension of embedding vectors in the codebook.
Defaults to 32.
decay (float): The decay parameter of EMA. Defaults to 0.99.
beta (float): The multiplier for VectorQuantizer loss. Defaults to 1.
quantize_kmeans_init (bool): Whether to use k-means to initialize the
VectorQuantizer. Defaults to True.
init_cfg (dict or List[dict], optional): Initialization config dict.
Defaults to None.
""" # noqa: E501
def __init__(self,
encoder_config: dict,
decoder_config: Optional[dict] = None,
num_embed: int = 8192,
embed_dims: int = 32,
decay: float = 0.99,
beta: float = 1.0,
quantize_kmeans_init: bool = True,
init_cfg: Optional[dict] = None) -> None:
super().__init__(init_cfg=init_cfg)
self.encoder = BEiT(**encoder_config)
if decoder_config is not None:
self.decoder = BEiT(**decoder_config)
self.quantize = NormEMAVectorQuantizer(
num_embed=num_embed,
embed_dims=embed_dims,
beta=beta,
decay=decay,
kmeans_init=quantize_kmeans_init,
)
# task layer
self.encode_task_layer = nn.Sequential(
nn.Linear(self.encoder.arch_settings['embed_dims'],
self.encoder.arch_settings['embed_dims']), nn.Tanh(),
nn.Linear(self.encoder.arch_settings['embed_dims'], embed_dims))
def get_tokens(self, x: torch.Tensor) -> dict:
"""Get tokens for beit pre-training."""
_, embed_ind, _ = self.encode(x)
output = {}
output['token'] = embed_ind.view(x.shape[0], -1)
output['input_img'] = x
return output
def encode(
self, x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Encode the input images and get corresponding results."""
encoder_features = self.encoder(x)[0]
B, C, N1, N2 = encoder_features.shape
encoder_features = encoder_features.permute(0, 2, 3,
1).reshape(B, N1 * N2, C)
with torch.cuda.amp.autocast(enabled=False):
to_quantizer_features = self.encode_task_layer(
encoder_features.type_as(self.encode_task_layer[-1].weight))
N = to_quantizer_features.shape[1]
h, w = int(math.sqrt(N)), int(math.sqrt(N))
to_quantizer_features = rearrange(
to_quantizer_features, 'b (h w) c -> b c h w', h=h,
w=w) # reshape for quantizer
quantize, loss, embed_ind = self.quantize(to_quantizer_features)
return quantize, embed_ind, loss
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""The forward function.
Currently, it only supports getting tokens.
"""
return self.get_tokens(x)['token']
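# A shape-level sketch (hypothetical tensors) of the encode() path above for a
# 224 x 224 input with 16 x 16 patches: the 196 patch features are projected
# token-wise into the 32-dim codebook space before quantization.
import torch
import torch.nn as nn

B, C, N1, N2 = 2, 768, 14, 14
encoder_features = torch.randn(B, C, N1, N2)
flat = encoder_features.permute(0, 2, 3, 1).reshape(B, N1 * N2, C)  # (2, 196, 768)
task_layer = nn.Sequential(nn.Linear(C, C), nn.Tanh(), nn.Linear(C, 32))
projected = task_layer(flat)
assert projected.shape == (B, 196, 32)
# NormEMAVectorQuantizer then assigns one codebook index per patch: (2, 196)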

View File

@ -3,7 +3,8 @@ from .dall_e import Encoder
from .data_preprocessor import (CAEDataPreprocessor,
RelativeLocDataPreprocessor,
RotationPredDataPreprocessor,
SelfSupDataPreprocessor)
SelfSupDataPreprocessor,
TwoNormDataPreprocessor)
from .ema import CosineEMA
from .extractor import Extractor
from .gather_layer import GatherLayer
@ -13,6 +14,7 @@ from .position_embedding import build_2d_sincos_position_embedding
from .sobel import Sobel
from .transformer_blocks import (CAETransformerRegressorLayer,
MultiheadAttention, TransformerEncoderLayer)
from .vector_quantizer import NormEMAVectorQuantizer
try:
from .res_layer_extra_norm import ResLayerExtraNorm
@ -20,9 +22,22 @@ except ImportError:
ResLayerExtraNorm = None
__all__ = [
'Extractor', 'GatherLayer', 'MultiPooling', 'MultiPrototypes',
'build_2d_sincos_position_embedding', 'Sobel', 'MultiheadAttention',
'TransformerEncoderLayer', 'CAETransformerRegressorLayer', 'Encoder',
'CosineEMA', 'SelfSupDataPreprocessor', 'RelativeLocDataPreprocessor',
'RotationPredDataPreprocessor', 'CAEDataPreprocessor', 'ResLayerExtraNorm'
'Extractor',
'GatherLayer',
'MultiPooling',
'MultiPrototypes',
'build_2d_sincos_position_embedding',
'Sobel',
'MultiheadAttention',
'TransformerEncoderLayer',
'CAETransformerRegressorLayer',
'Encoder',
'CosineEMA',
'SelfSupDataPreprocessor',
'RelativeLocDataPreprocessor',
'RotationPredDataPreprocessor',
'CAEDataPreprocessor',
'ResLayerExtraNorm',
'NormEMAVectorQuantizer',
'TwoNormDataPreprocessor',
]

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple
from typing import List, Optional, Sequence, Tuple, Union
import torch
from mmengine.model import ImgDataPreprocessor
@ -179,3 +179,114 @@ class CAEDataPreprocessor(SelfSupDataPreprocessor):
batch_inputs[1] / 255. * 0.8 + 0.1]
return batch_inputs, batch_data_samples
@MODELS.register_module()
class TwoNormDataPreprocessor(SelfSupDataPreprocessor):
"""Image pre-processor for CAE, BEiT v1/v2, etc.
Compared with the :class:`mmselfsup.SelfSupDataPreprocessor`, this module
will normalize the prediction image and target image with different
normalization parameters.
Args:
mean (Sequence[float or int], optional): The pixel mean of image
channels. If ``bgr_to_rgb=True`` it means the mean value of R,
G, B channels. If the length of ``mean`` is 1, it means all
channels have the same mean value, or the input is a gray image.
If it is not specified, images will not be normalized. Defaults
to None.
std (Sequence[float or int], optional): The pixel standard deviation of
image channels. If ``bgr_to_rgb=True`` it means the standard
deviation of R, G, B channels. If the length of ``std`` is 1,
it means all channels have the same standard deviation, or the
input is a gray image. If it is not specified, images will
not be normalized. Defaults to None.
second_mean (Sequence[float or int], optional): It is similar to
``mean``, but customized for the target image. Defaults to None.
second_std (Sequence[float or int], optional): It is similar to
``std``, but customized for the target image. Defaults to None.
pad_size_divisor (int): The size of padded image should be
divisible by ``pad_size_divisor``. Defaults to 1.
pad_value (float or int): The padded pixel value. Defaults to 0.
bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
Defaults to False.
rgb_to_bgr (bool): Whether to convert image from RGB to BGR.
Defaults to False.
non_blocking (bool): Whether to block the current process when
transferring data to device. Defaults to False.
"""
def __init__(self,
mean: Optional[Sequence[Union[float, int]]] = None,
std: Optional[Sequence[Union[float, int]]] = None,
second_mean: Sequence[Union[float, int]] = None,
second_std: Sequence[Union[float, int]] = None,
pad_size_divisor: int = 1,
pad_value: Union[float, int] = 0,
bgr_to_rgb: bool = False,
rgb_to_bgr: bool = False,
non_blocking: Optional[bool] = False):
super().__init__(
mean=mean,
std=std,
pad_size_divisor=pad_size_divisor,
pad_value=pad_value,
bgr_to_rgb=bgr_to_rgb,
rgb_to_bgr=rgb_to_bgr,
non_blocking=non_blocking)
assert (second_mean is not None) and (second_std is not None), (
'`second_mean` and `second_std` should not be None while using '
'`TwoNormDataPreprocessor`')
assert len(second_mean) == 3 or len(second_mean) == 1, (
'`second_mean` should have 1 or 3 values, to be compatible with '
f'RGB or gray image, but got {len(second_mean)} values')
assert len(second_std) == 3 or len(second_std) == 1, (
'`second_std` should have 1 or 3 values, to be compatible with '
f'RGB or gray image, but got {len(second_std)} values')
self.register_buffer('second_mean',
torch.tensor(second_mean).view(-1, 1, 1), False)
self.register_buffer('second_std',
torch.tensor(second_std).view(-1, 1, 1), False)
def forward(
self,
data: dict,
training: bool = False
) -> Tuple[List[torch.Tensor], Optional[list]]:
"""Performs normalization、padding and bgr2rgb conversion based on
``BaseDataPreprocessor``.
Args:
data (dict): data sampled from dataloader.
training (bool): Whether to enable training time augmentation. If
subclasses override this method, they can perform different
preprocessing strategies for training and testing based on the
value of ``training``.
Returns:
Tuple[List[torch.Tensor], Optional[list]]: Data in the same format as the
model input.
"""
data = [val for _, val in data.items()]
batch_inputs, batch_data_samples = self.cast_data(data)
# channel transform
if self._channel_conversion:
batch_inputs = [
_input[:, [2, 1, 0], ...] for _input in batch_inputs
]
# Convert to float after channel conversion to ensure
# efficiency
batch_inputs = [input_.float() for input_ in batch_inputs]
# Normalization. Here is what is different from
# :class:`mmselfsup.SelfSupDataPreprocessor`. Normalize the target
# image and prediction image with different normalization params
if self._enable_normalize:
batch_inputs = [
(batch_inputs[0] - self.mean) / self.std,
(batch_inputs[1] - self.second_mean) / self.second_std
]
return batch_inputs, batch_data_samples
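# A minimal sketch (hypothetical tensors; the statistics mirror those used in
# the configs and unit tests) of the dual normalization above: the prediction
# view uses (mean, std) while the target view uses (second_mean, second_std),
# each broadcast over the channel dimension.
import torch

imgs = [torch.rand(4, 3, 224, 224) * 255, torch.rand(4, 3, 224, 224) * 255]
mean = torch.tensor((123.675, 116.28, 103.53)).view(-1, 1, 1)
std = torch.tensor((58.395, 57.12, 57.375)).view(-1, 1, 1)
second_mean = torch.tensor((127.5, 127.5, 127.5)).view(-1, 1, 1)
second_std = torch.tensor((127.5, 127.5, 127.5)).view(-1, 1, 1)
normed = [(imgs[0] - mean) / std, (imgs[1] - second_mean) / second_std]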

View File

@ -0,0 +1,232 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) 2022 Microsoft
# Modified from
# https://github.com/microsoft/unilm/blob/master/beit2/norm_ema_quantizer.py
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from mmengine.dist import all_reduce
def ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor,
decay: torch.Tensor) -> None:
"""Update moving average."""
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
def norm_ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor,
decay: torch.Tensor) -> None:
"""Update moving average with norm data."""
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
moving_avg.data.copy_(F.normalize(moving_avg.data, p=2, dim=-1))
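# A tiny numeric check of the EMA update above (hypothetical values):
# moving_avg <- decay * moving_avg + (1 - decay) * new.
import torch

avg, new = torch.tensor([1.0]), torch.tensor([3.0])
ema_inplace(avg, new, 0.99)
assert torch.allclose(avg, torch.tensor([1.02]))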
def sample_vectors(samples: torch.Tensor, num: int) -> torch.Tensor:
"""Sample vectors according to the given number."""
num_samples, device = samples.shape[0], samples.device
if num_samples >= num:
indices = torch.randperm(num_samples, device=device)[:num]
else:
indices = torch.randint(0, num_samples, (num, ), device=device)
return samples[indices]
def kmeans(samples: torch.Tensor,
num_clusters: int,
num_iters: int = 10,
use_cosine_sim: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
"""Run k-means algorithm."""
dim, dtype, _ = samples.shape[-1], samples.dtype, samples.device
means = sample_vectors(samples, num_clusters)
for _ in range(num_iters):
if use_cosine_sim:
dists = samples @ means.t()
else:
diffs = rearrange(samples, 'n d -> n () d') \
- rearrange(means, 'c d -> () c d')
dists = -(diffs**2).sum(dim=-1)
buckets = dists.max(dim=-1).indices
bins = torch.bincount(buckets, minlength=num_clusters)
zero_mask = bins == 0
bins_min_clamped = bins.masked_fill(zero_mask, 1)
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
new_means = new_means / bins_min_clamped[..., None]
if use_cosine_sim:
new_means = F.normalize(new_means, p=2, dim=-1)
means = torch.where(zero_mask[..., None], means, new_means)
return means, bins
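# A small self-contained check (random data) of the cosine-similarity k-means
# above: unit-norm samples are grouped into unit-norm cluster centers.
import torch
import torch.nn.functional as F

samples = F.normalize(torch.randn(1000, 32), p=2, dim=-1)
means, bins = kmeans(samples, num_clusters=16, num_iters=10, use_cosine_sim=True)
assert means.shape == (16, 32) and bins.shape == (16,)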
class EmbeddingEMA(nn.Module):
"""The codebook of embedding vectors.
Args:
num_tokens (int): Number of embedding vectors in the codebook.
codebook_dim (int) : The dimension of embedding vectors in the
codebook.
kmeans_init (bool): Whether to use k-means to initialize the
VectorQuantizer. Defaults to True.
codebook_init_path (str): The initialization checkpoint for codebook.
Defaults to None.
"""
def __init__(self,
num_tokens: int,
codebook_dim: int,
kmeans_init: bool = True,
codebook_init_path: Optional[str] = None):
super().__init__()
self.num_tokens = num_tokens
self.codebook_dim = codebook_dim
if codebook_init_path is None:
if not kmeans_init:
weight = torch.randn(num_tokens, codebook_dim)
weight = F.normalize(weight, p=2, dim=-1)
else:
weight = torch.zeros(num_tokens, codebook_dim)
self.register_buffer('initted', torch.Tensor([not kmeans_init]))
else:
print(f'load init codebook weight from {codebook_init_path}')
codebook_ckpt_weight = torch.load(
codebook_init_path, map_location='cpu')
weight = codebook_ckpt_weight.clone()
self.register_buffer('initted', torch.Tensor([True]))
self.weight = nn.Parameter(weight, requires_grad=False)
self.update = True
@torch.jit.ignore
def init_embed_(self, data: torch.Tensor) -> None:
"""Initialize embedding vectors of codebook."""
if self.initted:
return
print('Performing K-means init for codebook')
embed, _ = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
self.weight.data.copy_(embed)
self.initted.data.copy_(torch.Tensor([True]))
def forward(self, embed_id: torch.Tensor) -> torch.Tensor:
"""Get embedding vectors."""
return F.embedding(embed_id, self.weight)
class NormEMAVectorQuantizer(nn.Module):
"""Normed EMA vector quantizer module.
Args:
num_embed (int): Number of embedding vectors in the codebook. Defaults
to 8192.
embed_dims (int) : The dimension of embedding vectors in the codebook.
Defaults to 32.
beta (float): The multiplier for VectorQuantizer embedding loss.
Defaults to 1.
decay (float): The decay parameter of EMA. Defaults to 0.99.
statistic_code_usage (bool): Whether to use cluster_size to record
statistic. Defaults to True.
kmeans_init (bool): Whether to use k-means to initialize the
VectorQuantizer. Defaults to True.
codebook_init_path (str): The initialization checkpoint for codebook.
Defaults to None.
"""
def __init__(self,
num_embed: int,
embed_dims: int,
beta: float,
decay: float = 0.99,
statistic_code_usage: bool = True,
kmeans_init: bool = True,
codebook_init_path: Optional[str] = None) -> None:
super().__init__()
self.codebook_dim = embed_dims
self.num_tokens = num_embed
self.beta = beta
self.decay = decay
# learnable = True if orthogonal_reg_weight > 0 else False
self.embedding = EmbeddingEMA(
num_tokens=self.num_tokens,
codebook_dim=self.codebook_dim,
kmeans_init=kmeans_init,
codebook_init_path=codebook_init_path)
self.statistic_code_usage = statistic_code_usage
if statistic_code_usage:
self.register_buffer('cluster_size', torch.zeros(num_embed))
def reset_cluster_size(self, device):
"""Reset the ``cluster_size`` statistic buffer on the given device."""
if self.statistic_code_usage:
self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
self.cluster_size = self.cluster_size.to(device)
def forward(self, z):
"""Forward function."""
# reshape z -> (batch, height, width, channel)
z = rearrange(z, 'b c h w -> b h w c')
z = F.normalize(z, p=2, dim=-1)
z_flattened = z.reshape(-1, self.codebook_dim)
self.embedding.init_embed_(z_flattened)
# 'n d -> d n'
d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
self.embedding.weight.pow(2).sum(dim=1) - 2 * \
torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)
encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(encoding_indices).view(z.shape)
encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
if not self.training:
with torch.no_grad():
cluster_size = encodings.sum(0)
all_reduce(cluster_size)
ema_inplace(self.cluster_size, cluster_size, self.decay)
if self.training and self.embedding.update:
# update cluster size with EMA
bins = encodings.sum(0)
all_reduce(bins)
ema_inplace(self.cluster_size, bins, self.decay)
zero_mask = (bins == 0)
bins = bins.masked_fill(zero_mask, 1.)
embed_sum = z_flattened.t() @ encodings
all_reduce(embed_sum)
embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
embed_normalized = F.normalize(embed_normalized, p=2, dim=-1)
embed_normalized = torch.where(zero_mask[..., None],
self.embedding.weight,
embed_normalized)
# Update embedding vectors with EMA
norm_ema_inplace(self.embedding.weight, embed_normalized,
self.decay)
# compute loss for embedding
loss = self.beta * F.mse_loss(z_q.detach(), z)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = rearrange(z_q, 'b h w c -> b c h w')
return z_q, loss, encoding_indices
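# A hedged usage sketch (direct instantiation with these arguments is an
# assumption): quantizing a (B, C, H, W) feature map returns the quantized map
# with a straight-through gradient, the commitment loss and the flat indices.
import torch

quantizer = NormEMAVectorQuantizer(num_embed=8192, embed_dims=32, beta=1.0)
z = torch.randn(2, 32, 14, 14)
z_q, loss, indices = quantizer(z)
assert z_q.shape == z.shape and indices.shape == (2 * 14 * 14,)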

View File

@ -17,3 +17,4 @@ Import:
- configs/selfsup/barlowtwins/metafile.yml
- configs/selfsup/cae/metafile.yml
- configs/selfsup/maskfeat/metafile.yml
- configs/selfsup/beit/metafile.yml

View File

@ -12,7 +12,7 @@ class ToyViTBackbone(nn.Module):
def __init__(self):
super().__init__()
self.cls_token = nn.Parameter(torch.ones(1))
self.patch_embed = nn.Parameter(torch.ones(1))
self.pos_embed = nn.Parameter(torch.ones(1))
self.layers = nn.ModuleList()
for _ in range(2):
layer = nn.Conv2d(3, 3, 1)
@ -87,32 +87,47 @@ def test_learning_rate_decay_optimizer_wrapper_constructor():
weight_decay=base_wd,
model_type='vit',
layer_decay_rate=2.0))
paramwise_cfg = dict(
custom_keys={
'.bias': dict(decay_mult=0.0),
'.cls_token': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0),
})
# test when model_type is None
with pytest.raises(AssertionError):
optimizer_wrapper_constructor = LearningRateDecayOptimWrapperConstructor( # noqa
optim_wrapper_cfg=optim_wrapper_cfg)
optim_wrapper_cfg=optim_wrapper_cfg,
paramwise_cfg=paramwise_cfg)
optim_wrapper_cfg['optimizer']['model_type'] = None
optimizer_wrapper = optimizer_wrapper_constructor(model)
# test when model_type is invalid
with pytest.raises(AssertionError):
optimizer_wrapper_constructor = LearningRateDecayOptimWrapperConstructor( # noqa
optim_wrapper_cfg=optim_wrapper_cfg)
optim_wrapper_cfg=optim_wrapper_cfg,
paramwise_cfg=paramwise_cfg)
optim_wrapper_cfg['optimizer']['model_type'] = 'invalid'
optimizer_wrapper = optimizer_wrapper_constructor(model)
# test vit
optimizer_wrapper_constructor = LearningRateDecayOptimWrapperConstructor(
optim_wrapper_cfg=optim_wrapper_cfg)
optim_wrapper_cfg=optim_wrapper_cfg, paramwise_cfg=paramwise_cfg)
optim_wrapper_cfg['optimizer']['model_type'] = 'vit'
optimizer_wrapper = optimizer_wrapper_constructor(model)
check_optimizer_lr_wd(optimizer_wrapper, expected_layer_wise_wd_lr_vit)
# test swin
paramwise_cfg = dict(
custom_keys={
'.norm': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),
'.absolute_pos_embed': dict(decay_mult=0.0),
'.relative_position_bias_table': dict(decay_mult=0.0)
})
model = ToySwin()
optimizer_wrapper_constructor = LearningRateDecayOptimWrapperConstructor(
optim_wrapper_cfg=optim_wrapper_cfg)
optim_wrapper_cfg=optim_wrapper_cfg, paramwise_cfg=paramwise_cfg)
optim_wrapper_cfg['optimizer']['model_type'] = 'swin'
optimizer_wrapper = optimizer_wrapper_constructor(model)
assert optimizer_wrapper.optimizer.param_groups[-1]['lr_scale'] == 1.0
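# A rough illustration of layer-wise lr decay (the exact grouping and exponent
# convention are assumptions, not the constructor's verified implementation):
# each parameter group is scaled by a power of the decay rate, so earlier
# layers receive smaller learning rates and the deepest group is scaled by
# 1.0, which matches the assertion above.
layer_decay_rate, num_groups = 0.65, 4
scales = [layer_decay_rate**(num_groups - 1 - i) for i in range(num_groups)]
# -> [0.274625, 0.4225, 0.65, 1.0]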

View File

@ -0,0 +1,65 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
import torch
from mmengine.structures import InstanceData
from mmselfsup.models import BEiT
from mmselfsup.structures import SelfSupDataSample
from mmselfsup.utils import register_all_modules
data_preprocessor = dict(
type='TwoNormDataPreprocessor',
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
second_mean=(-20.4, -20.4, -20.4),
second_std=(204., 204., 204.),
bgr_to_rgb=True)
# model settings
backbone = dict(
type='BEiTViT',
arch='base',
patch_size=16,
drop_path_rate=0.1,
final_norm=True,
layer_scale_init_value=0.1,
)
neck = None
head = dict(
type='BEiTV1Head',
embed_dims=768,
num_embed=8192,
loss=dict(type='BEiTLoss'))
target_generator = dict(type='DALL-E')
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_beitv1():
register_all_modules()
model = BEiT(
backbone=backbone,
neck=neck,
head=head,
target_generator=target_generator,
data_preprocessor=data_preprocessor)
fake_img = torch.rand((1, 3, 224, 224))
fake_target_img = torch.rand((1, 3, 112, 112))
fake_mask = torch.zeros((196)).bool()
fake_mask[75:150] = 1
fake_data_sample = SelfSupDataSample()
fake_mask = InstanceData(value=fake_mask)
fake_data_sample.mask = fake_mask
fake_data_sample = [fake_data_sample]
fake_data = {
'inputs': [fake_img, fake_target_img],
'data_sample': fake_data_sample
}
fake_batch_inputs, fake_data_samples = model.data_preprocessor(fake_data)
fake_outputs = model(fake_batch_inputs, fake_data_samples, mode='loss')
assert isinstance(fake_outputs['loss'].item(), float)

View File

@ -0,0 +1,99 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
import torch
from mmengine.structures import InstanceData
from mmselfsup.models import BEiT
from mmselfsup.structures import SelfSupDataSample
from mmselfsup.utils import register_all_modules
data_preprocessor = dict(
type='TwoNormDataPreprocessor',
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
second_mean=(127.5, 127.5, 127.5),
second_std=(127.5, 127.5, 127.5),
bgr_to_rgb=True)
# model settings
vqkd_encoder = dict(
arch='base',
img_size=224,
patch_size=16,
in_channels=3,
out_indices=-1,
drop_rate=0.,
drop_path_rate=0.,
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=True,
with_cls_token=True,
avg_token=False,
frozen_stages=-1,
output_cls_token=False,
use_abs_pos_emb=True,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,
layer_scale_init_value=0.,
interpolate_mode='bicubic',
patch_cfg=dict(),
layer_cfgs=dict(),
init_cfg=None)
layer_scale_init_value = 0.1
drop_path_rate = 0. # 0. for 300 epochs and 0.1 for 1600 epochs.
backbone = dict(
type='BEiTViT',
arch='base',
patch_size=16,
out_indices=[-4, -1],
drop_path_rate=drop_path_rate,
final_norm=False,
layer_scale_init_value=layer_scale_init_value,
)
neck = dict(
type='BEiTV2Neck',
num_layers=1,
early_layers=9,
backbone_arch='base',
drop_path_rate=drop_path_rate,
layer_scale_init_value=layer_scale_init_value,
)
head = dict(
type='BEiTV2Head',
embed_dims=768,
num_embed=8192,
loss=dict(type='BEiTLoss'))
target_generator = dict(type='VQKD', encoder_config=vqkd_encoder)
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_beitv2():
register_all_modules()
model = BEiT(
backbone=backbone,
neck=neck,
head=head,
target_generator=target_generator,
data_preprocessor=data_preprocessor)
fake_img = torch.rand((1, 3, 224, 224))
fake_target_img = torch.rand((1, 3, 224, 224))
fake_mask = torch.zeros((196)).bool()
fake_mask[75:150] = 1
fake_data_sample = SelfSupDataSample()
fake_mask = InstanceData(value=fake_mask)
fake_data_sample.mask = fake_mask
fake_data_sample = [fake_data_sample]
fake_data = {
'inputs': [fake_img, fake_target_img],
'data_sample': fake_data_sample
}
fake_batch_inputs, fake_data_samples = model.data_preprocessor(fake_data)
fake_outputs = model(fake_batch_inputs, fake_data_samples, mode='loss')
assert isinstance(fake_outputs['loss_1'].item(), float)
assert isinstance(fake_outputs['loss_2'].item(), float)

View File

@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
import torch
from mmselfsup.models.backbones import BEiTViT
backbone = dict(
arch='base',
patch_size=16,
drop_path_rate=0.1,
final_norm=True,
layer_scale_init_value=0.1,
)
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_beit_vit():
beit_backbone = BEiTViT(**backbone)
beit_backbone.init_weights()
fake_inputs = torch.randn((2, 3, 224, 224))
fake_mask = torch.zeros((2, 196))
fake_mask[:, 75:150] = 1
fake_outputs = beit_backbone(fake_inputs, fake_mask)
assert list(fake_outputs[0].shape) == [2, 197, 768]

View File

@ -4,7 +4,7 @@ import platform
import pytest
import torch
from mmselfsup.models.utils import Encoder
from mmselfsup.models.target_generators import Encoder
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')

View File

@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
import torch
from mmselfsup.models.target_generators import VQKD
vqkd_encoder = dict(
arch='base',
img_size=224,
patch_size=16,
in_channels=3,
out_indices=-1,
drop_rate=0.,
drop_path_rate=0.,
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=True,
with_cls_token=True,
avg_token=False,
frozen_stages=-1,
output_cls_token=False,
use_abs_pos_emb=True,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,
layer_scale_init_value=0.,
interpolate_mode='bicubic',
patch_cfg=dict(),
layer_cfgs=dict(),
init_cfg=None)
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_vqkd():
model = VQKD(encoder_config=vqkd_encoder)
fake_inputs = torch.rand((2, 3, 224, 224))
fake_outputs = model(fake_inputs)
assert list(fake_outputs.shape) == [2, 196]

View File

@ -1,7 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmselfsup.models.utils import SelfSupDataPreprocessor
from mmselfsup.models.utils import (SelfSupDataPreprocessor,
TwoNormDataPreprocessor)
from mmselfsup.structures import SelfSupDataSample
@ -16,3 +18,51 @@ def test_selfsup_data_preprocessor():
fake_batches, fake_samples = data_preprocessor(fake_data)
assert len(fake_batches) == 1
assert len(fake_samples) == 2
def test_two_norm_data_preprocessor():
with pytest.raises(AssertionError):
data_preprocessor = TwoNormDataPreprocessor(
rgb_to_bgr=True,
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
)
with pytest.raises(AssertionError):
data_preprocessor = TwoNormDataPreprocessor(
rgb_to_bgr=True,
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
second_mean=(127.5, 127.5),
second_std=(127.5, 127.5, 127.5),
)
with pytest.raises(AssertionError):
data_preprocessor = TwoNormDataPreprocessor(
rgb_to_bgr=True,
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
second_mean=(127.5, 127.5, 127.5),
second_std=(127.5, 127.5),
)
data_preprocessor = dict(
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
second_mean=(127.5, 127.5, 127.5),
second_std=(127.5, 127.5, 127.5),
bgr_to_rgb=True)
data_preprocessor = TwoNormDataPreprocessor(**data_preprocessor)
fake_data = {
'inputs':
[torch.randn((4, 3, 224, 224)),
torch.randn((4, 3, 224, 224))],
'data_sample': [
SelfSupDataSample(),
SelfSupDataSample(),
SelfSupDataSample(),
SelfSupDataSample()
]
}
fake_batches, fake_samples = data_preprocessor(fake_data)
assert len(fake_batches) == 2
assert len(fake_samples) == 4