[Refactor] Refactor BEiT backbone and support v1/v2 inference. (#1144)
* refactor beit backbone * use LinearClsHead * fix mean and std value * fix lint * support inference if beit-v2 * update encoder layer and init * update * add ut * add prepare_relative_position_bias_table function * add cls_token * fix lint * add pos_embed check * update metafile and readme * update weights link * update link of weights * update metafile * update * update docstrings * update according to review * rename readme * update docstring * fix lintpull/1125/merge
parent
35fb03a577
commit
d80ec5a4b8
|
@ -152,6 +152,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
|
|||
- [x] [MobileViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilevit)
|
||||
- [x] [DaViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/davit)
|
||||
- [x] [RepLKNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/replknet)
|
||||
- [x] [BEiT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/beit) / [BEiT v2](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/beitv2)
|
||||
|
||||
</details>
|
||||
|
||||
|
|
|
@ -158,6 +158,7 @@ mim install -e .
|
|||
- [x] [MobileViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilevit)
|
||||
- [x] [DaViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/davit)
|
||||
- [x] [RepLKNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/replknet)
|
||||
- [x] [BEiT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/beit) / [BEiT v2](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/beitv2)
|
||||
|
||||
</details>
|
||||
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
# BEiT
|
||||
|
||||
> [BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254)
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
We introduce a self-supervised vision representation model BEiT, which stands for Bidirectional Encoder representation from Image Transformers. Following BERT developed in the natural language processing area, we propose a masked image modeling task to pretrain vision Transformers. Specifically, each image has two views in our pre-training, i.e, image patches (such as 16x16 pixels), and visual tokens (i.e., discrete tokens). We first "tokenize" the original image into visual tokens. Then we randomly mask some image patches and fed them into the backbone Transformer. The pre-training objective is to recover the original visual tokens based on the corrupted image patches. After pre-training BEiT, we directly fine-tune the model parameters on downstream tasks by appending task layers upon the pretrained encoder. Experimental results on image classification and semantic segmentation show that our model achieves competitive results with previous pre-training methods. For example, base-size BEiT achieves 83.2% top-1 accuracy on ImageNet-1K, significantly outperforming from-scratch DeiT training (81.8%) with the same setup. Moreover, large-size BEiT obtains 86.3% only using ImageNet-1K, even outperforming ViT-L with supervised pre-training on ImageNet-22K (85.2%).
|
||||
|
||||
<div align="center">
|
||||
<img src="https://user-images.githubusercontent.com/36138628/203688351-adac7146-4e71-4ab6-8958-5cfe643a2dc5.png" width="70%"/>
|
||||
</div>
|
||||
|
||||
## Results and models
|
||||
|
||||
### ImageNet-1k
|
||||
|
||||
| Model | Pretrain | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|
||||
| :---------: | :----------: | :-------: | :------: | :-------: | :-------: | :-------------------------------------: | :-----------------------------------------------------------------------------------------------------: |
|
||||
| BEiT-base\* | ImageNet-21k | 86.53 | 17.58 | 85.28 | 97.59 | [config](./beit-base-p16_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/beit/beit-base_3rdparty_in1k_20221114-c0a4df23.pth) |
|
||||
|
||||
*Models with * are converted from the [official repo](https://github.com/microsoft/unilm/tree/master/beit). The config files of these models are only for inference.*
|
||||
|
||||
For BEiT self-supervised learning algorithm, welcome to [MMSelfSup page](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/beit) to get more information.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{beit,
|
||||
title={{BEiT}: {BERT} Pre-Training of Image Transformers},
|
||||
author={Hangbo Bao and Li Dong and Furu Wei},
|
||||
year={2021},
|
||||
eprint={2106.08254},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CV}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,44 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[127.5, 127.5, 127.5],
|
||||
std=[127.5, 127.5, 127.5],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='BEiT',
|
||||
arch='base',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
avg_token=True,
|
||||
output_cls_token=False,
|
||||
use_abs_pos_emb=False,
|
||||
use_rel_pos_bias=True,
|
||||
use_shared_rel_pos_bias=False,
|
||||
),
|
||||
neck=None,
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=768,
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
|
||||
),
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=.02),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||
],
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0)
|
||||
]))
|
|
@ -0,0 +1,41 @@
|
|||
Collections:
|
||||
- Name: BEiT
|
||||
Metadata:
|
||||
Architecture:
|
||||
- Attention Dropout
|
||||
- Convolution
|
||||
- Dense Connections
|
||||
- Dropout
|
||||
- GELU
|
||||
- Layer Normalization
|
||||
- Multi-Head Attention
|
||||
- Scaled Dot-Product Attention
|
||||
- Tanh Activation
|
||||
Paper:
|
||||
URL: https://arxiv.org/abs/2106.08254
|
||||
Title: 'BEiT: BERT Pre-Training of Image Transformers'
|
||||
README: configs/beit/README.md
|
||||
Code:
|
||||
URL: https://github.com/open-mmlab/mmclassification/blob/dev-1.x/mmcls/models/backbones/beit.py
|
||||
Version: v1.0.0rc4
|
||||
|
||||
Models:
|
||||
- Name: beit-base_3rdparty_in1k
|
||||
In Collection: BEiT
|
||||
Metadata:
|
||||
FLOPs: 17581219584
|
||||
Parameters: 86530984
|
||||
Training Data:
|
||||
- ImageNet-21k
|
||||
- ImageNet-1k
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Task: Image Classification
|
||||
Metrics:
|
||||
Top 1 Accuracy: 85.28
|
||||
Top 5 Accuracy: 97.59
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/beit/beit-base_3rdparty_in1k_20221114-c0a4df23.pth
|
||||
Converted From:
|
||||
Weights: https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth
|
||||
Code: https://github.com/microsoft/unilm/tree/master/beit
|
||||
Config: configs/beit/beit-base-p16_8xb64_in1k.py
|
|
@ -0,0 +1,38 @@
|
|||
# BEiT V2
|
||||
|
||||
> [BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers](https://arxiv.org/abs/2208.06366)
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
Masked image modeling (MIM) has demonstrated impressive results in self-supervised representation learning by recovering corrupted image patches. However, most existing studies operate on low-level image pixels, which hinders the exploitation of high-level semantics for representation models. In this work, we propose to use a semantic-rich visual tokenizer as the reconstruction target for masked prediction, providing a systematic way to promote MIM from pixel-level to semantic-level. Specifically, we propose vector-quantized knowledge distillation to train the tokenizer, which discretizes a continuous semantic space to compact codes. We then pretrain vision Transformers by predicting the original visual tokens for the masked image patches. Furthermore, we introduce a patch aggregation strategy which associates discrete image patches to enhance global semantic representation. Experiments on image classification and semantic segmentation show that BEiT v2 outperforms all compared MIM methods. On ImageNet-1K (224 size), the base-size BEiT v2 achieves 85.5% top-1 accuracy for fine-tuning and 80.1% top-1 accuracy for linear probing. The large-size BEiT v2 obtains 87.3% top-1 accuracy for ImageNet-1K (224 size) fine-tuning, and 56.7% mIoU on ADE20K for semantic segmentation.
|
||||
|
||||
<div align="center">
|
||||
<img src="https://user-images.githubusercontent.com/36138628/203912182-5967a520-d455-49ea-bc67-dcbd500d76bf.png" width="70%"/>
|
||||
</div>
|
||||
|
||||
## Results and models
|
||||
|
||||
### ImageNet-1k
|
||||
|
||||
| Model | Pretrain | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|
||||
| :-----------: | :------------------------: | :-------: | :------: | :-------: | :-------: | :---------------------------------------: | :-----------------------------------------------------------------------------------: |
|
||||
| BEiTv2-base\* | ImageNet-1k & ImageNet-21k | 86.53 | 17.58 | 86.47 | 97.99 | [config](./beitv2-base-p16_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/beit/beitv2-base_3rdparty_in1k_20221114-73e11905.pth) |
|
||||
|
||||
*Models with * are converted from the [official repo](https://github.com/microsoft/unilm/tree/master/beit2). The config files of these models are only for inference.*
|
||||
|
||||
For BEiTv2 self-supervised learning algorithm, welcome to [MMSelfSup page](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/beitv2) to get more information.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{beitv2,
|
||||
title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers},
|
||||
author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei},
|
||||
year={2022},
|
||||
eprint={2208.06366},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CV}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,35 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='BEiT',
|
||||
arch='base',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
avg_token=True,
|
||||
output_cls_token=False,
|
||||
use_abs_pos_emb=False,
|
||||
use_rel_pos_bias=True,
|
||||
use_shared_rel_pos_bias=False,
|
||||
),
|
||||
neck=None,
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=768,
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
|
||||
),
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=.02),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||
],
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0)
|
||||
]))
|
|
@ -0,0 +1,41 @@
|
|||
Collections:
|
||||
- Name: BEiTv2
|
||||
Metadata:
|
||||
Architecture:
|
||||
- Attention Dropout
|
||||
- Convolution
|
||||
- Dense Connections
|
||||
- Dropout
|
||||
- GELU
|
||||
- Layer Normalization
|
||||
- Multi-Head Attention
|
||||
- Scaled Dot-Product Attention
|
||||
- Tanh Activation
|
||||
Paper:
|
||||
URL: https://arxiv.org/abs/2208.06366
|
||||
Title: 'BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers'
|
||||
README: configs/beitv2/README.md
|
||||
Code:
|
||||
URL: https://github.com/open-mmlab/mmclassification/blob/dev-1.x/mmcls/models/backbones/beit.py
|
||||
Version: v1.0.0rc4
|
||||
|
||||
Models:
|
||||
- Name: beitv2-base_3rdparty_in1k
|
||||
In Collection: BEiTv2
|
||||
Metadata:
|
||||
FLOPs: 17581219584
|
||||
Parameters: 86530984
|
||||
Training Data:
|
||||
- ImageNet-21k
|
||||
- ImageNet-1k
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Task: Image Classification
|
||||
Metrics:
|
||||
Top 1 Accuracy: 86.47
|
||||
Top 5 Accuracy: 97.99
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/beit/beitv2-base_3rdparty_in1k_20221114-73e11905.pth
|
||||
Converted From:
|
||||
Weights: https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth
|
||||
Code: https://github.com/microsoft/unilm/tree/master/beit2
|
||||
Config: configs/beitv2/beitv2-base-p16_8xb64_in1k.py
|
|
@ -58,6 +58,7 @@ Backbones
|
|||
:template: classtemplate.rst
|
||||
|
||||
AlexNet
|
||||
BEiT
|
||||
CSPDarkNet
|
||||
CSPNet
|
||||
CSPResNeXt
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .alexnet import AlexNet
|
||||
from .beit import BEiT
|
||||
from .conformer import Conformer
|
||||
from .convmixer import ConvMixer
|
||||
from .convnext import ConvNeXt
|
||||
|
@ -97,4 +98,5 @@ __all__ = [
|
|||
'HorNet',
|
||||
'MobileViT',
|
||||
'DaViT',
|
||||
'BEiT',
|
||||
]
|
||||
|
|
|
@ -0,0 +1,515 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.drop import build_dropout
|
||||
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from ..utils import (BEiTAttention, resize_pos_embed,
|
||||
resize_relative_position_bias_table, to_2tuple)
|
||||
from .vision_transformer import TransformerEncoderLayer, VisionTransformer
|
||||
|
||||
|
||||
class RelativePositionBias(BaseModule):
|
||||
"""Relative Position Bias.
|
||||
|
||||
This module is copied from
|
||||
https://github.com/microsoft/unilm/blob/master/beit/modeling_finetune.py#L209.
|
||||
|
||||
Args:
|
||||
window_size (Sequence[int]): The window size of the relative
|
||||
position bias.
|
||||
num_heads (int): The number of head in multi-head attention.
|
||||
with_cls_token (bool): To indicate the backbone has cls_token or not.
|
||||
Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
window_size: Sequence[int],
|
||||
num_heads: int,
|
||||
with_cls_token: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
if with_cls_token:
|
||||
num_extra_tokens = 3
|
||||
else:
|
||||
num_extra_tokens = 0
|
||||
# cls to token & token to cls & cls to cls
|
||||
self.num_relative_distance = (2 * window_size[0] - 1) * (
|
||||
2 * window_size[1] - 1) + num_extra_tokens
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(self.num_relative_distance,
|
||||
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
||||
|
||||
# get pair-wise relative position index for each
|
||||
# token inside the window
|
||||
coords_h = torch.arange(window_size[0])
|
||||
coords_w = torch.arange(window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = coords_flatten[:, :, None] -\
|
||||
coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(
|
||||
1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||
if with_cls_token:
|
||||
relative_position_index = torch.zeros(
|
||||
size=(window_size[0] * window_size[1] + 1, ) * 2,
|
||||
dtype=relative_coords.dtype)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(
|
||||
-1) # Wh*Ww, Wh*Ww
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
else:
|
||||
relative_position_index = torch.zeros(
|
||||
size=(window_size[0] * window_size[1], ) * 2,
|
||||
dtype=relative_coords.dtype)
|
||||
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
|
||||
self.register_buffer('relative_position_index',
|
||||
relative_position_index)
|
||||
|
||||
def forward(self) -> torch.Tensor:
|
||||
# Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
self.window_size[0] * self.window_size[1] + 1, -1)
|
||||
return relative_position_bias.permute(
|
||||
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
|
||||
|
||||
class BEiTTransformerEncoderLayer(TransformerEncoderLayer):
|
||||
"""Implements one encoder layer in BEiT.
|
||||
|
||||
Comparing with conventional ``TransformerEncoderLayer``, this module
|
||||
adds weights to the shortcut connection. In addition, ``BEiTAttention``
|
||||
is used to replace the original ``MultiheadAttention`` in
|
||||
``TransformerEncoderLayer``.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs.
|
||||
layer_scale_init_value (float): The initialization value for
|
||||
the learnable scaling of attention and FFN. 1 means no scaling.
|
||||
drop_rate (float): Probability of an element to be zeroed
|
||||
after the feed forward layer. Defaults to 0.
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
Defaults to None.
|
||||
use_rel_pos_bias (bool): Whether to use unique relative position bias,
|
||||
if False, use shared relative position bias defined in backbone.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Defaults to 0.0.
|
||||
drop_path_rate (float): Stochastic depth rate. Default 0.0.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Defaults to 2.
|
||||
bias (bool | str): The option to add leanable bias for q, k, v. If bias
|
||||
is True, it will add leanable bias. If bias is 'qv_bias', it will
|
||||
only add leanable bias for q, v. If bias is False, it will not add
|
||||
bias for q, k, v. Default to 'qv_bias'.
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Defaults to ``dict(type='GELU')``.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to dict(type='LN').
|
||||
attn_cfg (dict): The configuration for the attention layer.
|
||||
Defaults to an empty dict.
|
||||
ffn_cfg (dict): The configuration for the ffn layer.
|
||||
Defaults to ``dict(add_identity=False)``.
|
||||
init_cfg (dict or List[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims: int,
|
||||
num_heads: int,
|
||||
feedforward_channels: int,
|
||||
layer_scale_init_value: float,
|
||||
window_size: Tuple[int, int],
|
||||
use_rel_pos_bias: bool,
|
||||
drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.,
|
||||
num_fcs: int = 2,
|
||||
bias: Union[str, bool] = 'qv_bias',
|
||||
act_cfg: dict = dict(type='GELU'),
|
||||
norm_cfg: dict = dict(type='LN'),
|
||||
attn_cfg: dict = dict(),
|
||||
ffn_cfg: dict = dict(add_identity=False),
|
||||
init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
|
||||
super().__init__(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
feedforward_channels=feedforward_channels,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=0.,
|
||||
drop_rate=0.,
|
||||
num_fcs=num_fcs,
|
||||
qkv_bias=bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
attn_cfg = dict(
|
||||
window_size=window_size,
|
||||
use_rel_pos_bias=use_rel_pos_bias,
|
||||
qk_scale=None,
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
attn_drop=attn_drop_rate,
|
||||
proj_drop=drop_rate,
|
||||
bias=bias)
|
||||
self.attn = BEiTAttention(**attn_cfg)
|
||||
|
||||
ffn_cfg = dict(
|
||||
embed_dims=embed_dims,
|
||||
feedforward_channels=feedforward_channels,
|
||||
num_fcs=num_fcs,
|
||||
ffn_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||
act_cfg=act_cfg)
|
||||
self.ffn = FFN(**ffn_cfg)
|
||||
|
||||
# NOTE: drop path for stochastic depth, we shall see if
|
||||
# this is better than dropout here
|
||||
dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate)
|
||||
self.drop_path = build_dropout(
|
||||
dropout_layer) if dropout_layer else nn.Identity()
|
||||
|
||||
if layer_scale_init_value > 0:
|
||||
self.gamma_1 = nn.Parameter(
|
||||
layer_scale_init_value * torch.ones((embed_dims)),
|
||||
requires_grad=True)
|
||||
self.gamma_2 = nn.Parameter(
|
||||
layer_scale_init_value * torch.ones((embed_dims)),
|
||||
requires_grad=True)
|
||||
else:
|
||||
self.gamma_1, self.gamma_2 = None, None
|
||||
|
||||
def forward(self, x: torch.Tensor,
|
||||
rel_pos_bias: torch.Tensor) -> torch.Tensor:
|
||||
if self.gamma_1 is None:
|
||||
x = x + self.drop_path(
|
||||
self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
|
||||
x = x + self.drop_path(self.ffn(self.norm2(x)))
|
||||
else:
|
||||
x = x + self.drop_path(self.gamma_1 * self.attn(
|
||||
self.norm1(x), rel_pos_bias=rel_pos_bias))
|
||||
x = x + self.drop_path(self.gamma_2 * self.ffn(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class BEiT(VisionTransformer):
|
||||
"""Backbone for BEiT.
|
||||
|
||||
A PyTorch implement of : `BEiT: BERT Pre-Training of Image Transformers
|
||||
<https://arxiv.org/abs/2106.08254>`_
|
||||
A PyTorch implement of : `BEiT v2: Masked Image Modeling with
|
||||
Vector-Quantized Visual Tokenizers <https://arxiv.org/abs/2208.06366>`_
|
||||
|
||||
Args:
|
||||
arch (str | dict): BEiT architecture. If use string, choose from
|
||||
'base', 'large'. If use dict, it should have below keys:
|
||||
|
||||
- **embed_dims** (int): The dimensions of embedding.
|
||||
- **num_layers** (int): The number of transformer encoder layers.
|
||||
- **num_heads** (int): The number of heads in attention modules.
|
||||
- **feedforward_channels** (int): The hidden dimensions in
|
||||
feedforward modules.
|
||||
|
||||
Defaults to 'base'.
|
||||
img_size (int | tuple): The expected input image shape. Because we
|
||||
support dynamic input shape, just set the argument to the most
|
||||
common input image shape. Defaults to 224.
|
||||
patch_size (int | tuple): The patch size in patch embedding.
|
||||
Defaults to 16.
|
||||
in_channels (int): The num of input channels. Defaults to 3.
|
||||
out_indices (Sequence | int): Output from which stages.
|
||||
Defaults to -1, means the last stage.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): stochastic depth rate. Defaults to 0.
|
||||
qkv_bias (bool): Whether to add bias for qkv in attention modules.
|
||||
Defaults to True.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
final_norm (bool): Whether to add a additional layer to normalize
|
||||
final feature map. Defaults to True.
|
||||
with_cls_token (bool): Whether concatenating class token into image
|
||||
tokens as transformer input. Defaults to True.
|
||||
avg_token (bool): Whether or not to use the mean patch token for
|
||||
classification. If True, the model will only take the average
|
||||
of all patch tokens. Defaults to False.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters. Defaults to -1.
|
||||
output_cls_token (bool): Whether output the cls_token. If set True,
|
||||
``with_cls_token`` must be True. Defaults to True.
|
||||
use_abs_pos_emb (bool): Use position embedding like vanilla ViT.
|
||||
Defaults to False.
|
||||
use_rel_pos_bias (bool): Use relative position embedding in each
|
||||
transformer encoder layer. Defaults to True.
|
||||
use_shared_rel_pos_bias (bool): Use shared relative position embedding,
|
||||
all transformer encoder layers share the same relative position
|
||||
embedding. Defaults to False.
|
||||
layer_scale_init_value (float): The initialization value for
|
||||
the learnable scaling of attention and FFN. Defaults to 0.1.
|
||||
interpolate_mode (str): Select the interpolate mode for position
|
||||
embeding vector resize. Defaults to "bicubic".
|
||||
patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
|
||||
layer_cfgs (Sequence | dict): Configs of each transformer layer in
|
||||
encoder. Defaults to an empty dict.
|
||||
init_cfg (dict, optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
arch='base',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
out_indices=-1,
|
||||
drop_rate=0,
|
||||
drop_path_rate=0,
|
||||
norm_cfg=dict(type='LN', eps=1e-6),
|
||||
final_norm=False,
|
||||
with_cls_token=True,
|
||||
avg_token=True,
|
||||
frozen_stages=-1,
|
||||
output_cls_token=False,
|
||||
use_abs_pos_emb=False,
|
||||
use_rel_pos_bias=True,
|
||||
use_shared_rel_pos_bias=False,
|
||||
layer_scale_init_value=0.1,
|
||||
interpolate_mode='bicubic',
|
||||
patch_cfg=dict(),
|
||||
layer_cfgs=dict(),
|
||||
init_cfg=None):
|
||||
super(VisionTransformer, self).__init__(init_cfg)
|
||||
|
||||
if isinstance(arch, str):
|
||||
arch = arch.lower()
|
||||
assert arch in set(self.arch_zoo), \
|
||||
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
|
||||
self.arch_settings = self.arch_zoo[arch]
|
||||
else:
|
||||
essential_keys = {
|
||||
'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
|
||||
}
|
||||
assert isinstance(arch, dict) and essential_keys <= set(arch), \
|
||||
f'Custom arch needs a dict with keys {essential_keys}'
|
||||
self.arch_settings = arch
|
||||
|
||||
self.embed_dims = self.arch_settings['embed_dims']
|
||||
self.num_layers = self.arch_settings['num_layers']
|
||||
self.img_size = to_2tuple(img_size)
|
||||
|
||||
# Set patch embedding
|
||||
_patch_cfg = dict(
|
||||
in_channels=in_channels,
|
||||
input_size=img_size,
|
||||
embed_dims=self.embed_dims,
|
||||
conv_type='Conv2d',
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
)
|
||||
_patch_cfg.update(patch_cfg)
|
||||
self.patch_embed = PatchEmbed(**_patch_cfg)
|
||||
self.patch_resolution = self.patch_embed.init_out_size
|
||||
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
|
||||
|
||||
# Set cls token
|
||||
if output_cls_token:
|
||||
assert with_cls_token is True, f'with_cls_token must be True if' \
|
||||
f'set output_cls_token to True, but got {with_cls_token}'
|
||||
self.with_cls_token = with_cls_token
|
||||
self.output_cls_token = output_cls_token
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
|
||||
|
||||
self.interpolate_mode = interpolate_mode
|
||||
|
||||
# Set position embedding
|
||||
if use_abs_pos_emb:
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, num_patches + self.num_extra_tokens,
|
||||
self.embed_dims))
|
||||
self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
|
||||
assert not (use_rel_pos_bias and use_shared_rel_pos_bias), (
|
||||
'`use_rel_pos_bias` and `use_shared_rel_pos_bias` cannot be set '
|
||||
'to True at the same time')
|
||||
self.use_rel_pos_bias = use_rel_pos_bias
|
||||
|
||||
if use_shared_rel_pos_bias:
|
||||
self.rel_pos_bias = RelativePositionBias(
|
||||
window_size=self.patch_resolution,
|
||||
num_heads=self.arch_settings['num_heads'])
|
||||
else:
|
||||
self.rel_pos_bias = None
|
||||
self._register_load_state_dict_pre_hook(
|
||||
self._prepare_relative_position_bias_table)
|
||||
|
||||
if isinstance(out_indices, int):
|
||||
out_indices = [out_indices]
|
||||
assert isinstance(out_indices, Sequence), \
|
||||
f'"out_indices" must by a sequence or int, ' \
|
||||
f'get {type(out_indices)} instead.'
|
||||
for i, index in enumerate(out_indices):
|
||||
if index < 0:
|
||||
out_indices[i] = self.num_layers + index
|
||||
assert 0 <= out_indices[i] <= self.num_layers, \
|
||||
f'Invalid out_indices {index}'
|
||||
self.out_indices = out_indices
|
||||
|
||||
# stochastic depth decay rule
|
||||
dpr = np.linspace(0, drop_path_rate, self.num_layers)
|
||||
|
||||
self.layers = ModuleList()
|
||||
if isinstance(layer_cfgs, dict):
|
||||
layer_cfgs = [layer_cfgs] * self.num_layers
|
||||
for i in range(self.num_layers):
|
||||
_layer_cfg = dict(
|
||||
embed_dims=self.embed_dims,
|
||||
num_heads=self.arch_settings['num_heads'],
|
||||
feedforward_channels=self.
|
||||
arch_settings['feedforward_channels'],
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
window_size=self.patch_resolution,
|
||||
use_rel_pos_bias=use_rel_pos_bias,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
norm_cfg=norm_cfg)
|
||||
_layer_cfg.update(layer_cfgs[i])
|
||||
self.layers.append(BEiTTransformerEncoderLayer(**_layer_cfg))
|
||||
|
||||
self.frozen_stages = frozen_stages
|
||||
self.final_norm = final_norm
|
||||
if final_norm:
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
norm_cfg, self.embed_dims, postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
|
||||
self.avg_token = avg_token
|
||||
if avg_token:
|
||||
self.norm2_name, norm2 = build_norm_layer(
|
||||
norm_cfg, self.embed_dims, postfix=2)
|
||||
self.add_module(self.norm2_name, norm2)
|
||||
|
||||
# freeze stages only when self.frozen_stages > 0
|
||||
if self.frozen_stages > 0:
|
||||
self._freeze_stages()
|
||||
|
||||
def forward(self, x):
|
||||
B = x.shape[0]
|
||||
x, patch_resolution = self.patch_embed(x)
|
||||
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
|
||||
if self.pos_embed is not None:
|
||||
x = x + resize_pos_embed(
|
||||
self.pos_embed,
|
||||
self.patch_resolution,
|
||||
patch_resolution,
|
||||
mode=self.interpolate_mode,
|
||||
num_extra_tokens=self.num_extra_tokens)
|
||||
x = self.drop_after_pos(x)
|
||||
|
||||
rel_pos_bias = self.rel_pos_bias() \
|
||||
if self.rel_pos_bias is not None else None
|
||||
|
||||
if not self.with_cls_token:
|
||||
# Remove class token for transformer encoder input
|
||||
x = x[:, 1:]
|
||||
|
||||
outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x, rel_pos_bias)
|
||||
|
||||
if i == len(self.layers) - 1 and self.final_norm:
|
||||
x = self.norm1(x)
|
||||
|
||||
if i in self.out_indices:
|
||||
B, _, C = x.shape
|
||||
if self.with_cls_token:
|
||||
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = x[:, 0]
|
||||
else:
|
||||
patch_token = x.reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = None
|
||||
|
||||
if self.avg_token:
|
||||
patch_token = patch_token.permute(0, 2, 3, 1)
|
||||
patch_token = patch_token.reshape(
|
||||
B, patch_resolution[0] * patch_resolution[1],
|
||||
C).mean(dim=1)
|
||||
patch_token = self.norm2(patch_token)
|
||||
if self.output_cls_token:
|
||||
out = [patch_token, cls_token]
|
||||
else:
|
||||
out = patch_token
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
||||
|
||||
def _prepare_relative_position_bias_table(self, state_dict, prefix, *args,
|
||||
**kwargs):
|
||||
from mmengine.logging import MMLogger
|
||||
logger = MMLogger.get_current_instance()
|
||||
|
||||
if self.use_rel_pos_bias and 'rel_pos_bias.relative_position_bias_table' in state_dict: # noqa:E501
|
||||
logger.info('Expand the shared relative position embedding to '
|
||||
'each transformer block.')
|
||||
rel_pos_bias = state_dict[
|
||||
'rel_pos_bias.relative_position_bias_table']
|
||||
for i in range(self.num_layers):
|
||||
state_dict[
|
||||
f'layers.{i}.attn.relative_position_bias_table'] = \
|
||||
rel_pos_bias.clone()
|
||||
state_dict.pop('rel_pos_bias.relative_position_bias_table')
|
||||
|
||||
state_dict_model = self.state_dict()
|
||||
all_keys = list(state_dict_model.keys())
|
||||
for key in all_keys:
|
||||
if 'relative_position_bias_table' in key:
|
||||
ckpt_key = prefix + key
|
||||
if ckpt_key not in state_dict:
|
||||
continue
|
||||
rel_pos_bias_pretrained = state_dict[ckpt_key]
|
||||
rel_pos_bias_current = state_dict_model[key]
|
||||
L1, nH1 = rel_pos_bias_pretrained.size()
|
||||
L2, nH2 = rel_pos_bias_current.size()
|
||||
src_size = int((L1 - 3)**0.5)
|
||||
dst_size = int((L2 - 3)**0.5)
|
||||
if L1 != L2:
|
||||
extra_tokens = rel_pos_bias_pretrained[-3:, :]
|
||||
rel_pos_bias = rel_pos_bias_pretrained[:-3, :]
|
||||
|
||||
new_rel_pos_bias = resize_relative_position_bias_table(
|
||||
src_size, dst_size, rel_pos_bias, nH1)
|
||||
new_rel_pos_bias = torch.cat(
|
||||
(new_rel_pos_bias, extra_tokens), dim=0)
|
||||
logger.info('Resize the relative_position_bias_table from '
|
||||
f'{state_dict[ckpt_key].shape} to '
|
||||
f'{new_rel_pos_bias.shape}')
|
||||
state_dict[ckpt_key] = new_rel_pos_bias
|
||||
|
||||
# The index buffer need to be re-generated.
|
||||
index_buffer = ckpt_key.replace('bias_table', 'index')
|
||||
del state_dict[index_buffer]
|
|
@ -1,18 +1,16 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.drop import build_dropout
|
||||
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.model.weight_init import trunc_normal_
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from ..utils import (BEiTAttention, MultiheadAttention, resize_pos_embed,
|
||||
to_2tuple)
|
||||
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
|
@ -100,116 +98,6 @@ class TransformerEncoderLayer(BaseModule):
|
|||
return x
|
||||
|
||||
|
||||
class BEiTTransformerEncoderLayer(TransformerEncoderLayer):
|
||||
"""Implements one encoder layer in BEiT.
|
||||
|
||||
Comparing with conventional ``TransformerEncoderLayer``, this module
|
||||
adds weights to the shortcut connection. In addition, ``BEiTAttention``
|
||||
is used to replace the original ``MultiheadAttention`` in
|
||||
``TransformerEncoderLayer``.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs.
|
||||
layer_scale_init_value (float): The initialization value for
|
||||
the learnable scaling of attention and FFN.
|
||||
drop_rate (float): Probability of an element to be zeroed
|
||||
after the feed forward layer. Defaults to 0.
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
Defaults to None.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Defaults to 0.0.
|
||||
drop_path_rate (float): Stochastic depth rate. Default 0.0.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Defaults to 2.
|
||||
bias (bool | str): The option to add leanable bias for q, k, v. If bias
|
||||
is True, it will add leanable bias. If bias is 'qv_bias', it will
|
||||
only add leanable bias for q, v. If bias is False, it will not add
|
||||
bias for q, k, v. Default to 'qv_bias'.
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Defaults to ``dict(type='GELU')``.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to dict(type='LN').
|
||||
attn_cfg (dict): The configuration for the attention layer.
|
||||
Defaults to an empty dict.
|
||||
ffn_cfg (dict): The configuration for the ffn layer.
|
||||
Defaults to ``dict(add_identity=False)``.
|
||||
init_cfg (dict or List[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims: int,
|
||||
num_heads: int,
|
||||
feedforward_channels: int,
|
||||
layer_scale_init_value: float,
|
||||
window_size: Tuple[int, int],
|
||||
drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.,
|
||||
num_fcs: int = 2,
|
||||
bias: Union[str, bool] = 'qv_bias',
|
||||
act_cfg: dict = dict(type='GELU'),
|
||||
norm_cfg: dict = dict(type='LN'),
|
||||
attn_cfg: dict = dict(),
|
||||
ffn_cfg: dict = dict(add_identity=False),
|
||||
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
|
||||
attn_cfg.update(dict(window_size=window_size, qk_scale=None))
|
||||
|
||||
super().__init__(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
feedforward_channels=feedforward_channels,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=0.,
|
||||
drop_rate=0.,
|
||||
num_fcs=num_fcs,
|
||||
qkv_bias=bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
# overwrite the default attention layer in TransformerEncoderLayer
|
||||
attn_cfg.update(
|
||||
dict(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
attn_drop=attn_drop_rate,
|
||||
proj_drop=drop_rate,
|
||||
bias=bias))
|
||||
self.attn = BEiTAttention(**attn_cfg)
|
||||
|
||||
# overwrite the default ffn layer in TransformerEncoderLayer
|
||||
ffn_cfg.update(
|
||||
dict(
|
||||
embed_dims=embed_dims,
|
||||
feedforward_channels=feedforward_channels,
|
||||
num_fcs=num_fcs,
|
||||
ffn_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate)
|
||||
if drop_path_rate > 0 else None,
|
||||
act_cfg=act_cfg))
|
||||
self.ffn = FFN(**ffn_cfg)
|
||||
|
||||
# NOTE: drop path for stochastic depth, we shall see if
|
||||
# this is better than dropout here
|
||||
dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate)
|
||||
self.drop_path = build_dropout(
|
||||
dropout_layer) if dropout_layer else nn.Identity()
|
||||
self.gamma_1 = nn.Parameter(
|
||||
layer_scale_init_value * torch.ones((embed_dims)),
|
||||
requires_grad=True)
|
||||
self.gamma_2 = nn.Parameter(
|
||||
layer_scale_init_value * torch.ones((embed_dims)),
|
||||
requires_grad=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
|
||||
x = x + self.drop_path(self.gamma_2 * self.ffn(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class VisionTransformer(BaseBackbone):
|
||||
"""Vision Transformer.
|
||||
|
@ -255,9 +143,6 @@ class VisionTransformer(BaseBackbone):
|
|||
-1 means not freezing any parameters. Defaults to -1.
|
||||
output_cls_token (bool): Whether output the cls_token. If set True,
|
||||
``with_cls_token`` must be True. Defaults to True.
|
||||
beit_style (bool): Whether or not use BEiT-style. Defaults to False.
|
||||
layer_scale_init_value (float): The initialization value for
|
||||
the learnable scaling of attention and FFN. Defaults to 0.1.
|
||||
interpolate_mode (str): Select the interpolate mode for position
|
||||
embeding vector resize. Defaults to "bicubic".
|
||||
patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
|
||||
|
@ -338,8 +223,6 @@ class VisionTransformer(BaseBackbone):
|
|||
avg_token=False,
|
||||
frozen_stages=-1,
|
||||
output_cls_token=True,
|
||||
beit_style=False,
|
||||
layer_scale_init_value=0.1,
|
||||
interpolate_mode='bicubic',
|
||||
patch_cfg=dict(),
|
||||
layer_cfgs=dict(),
|
||||
|
@ -423,15 +306,7 @@ class VisionTransformer(BaseBackbone):
|
|||
qkv_bias=qkv_bias,
|
||||
norm_cfg=norm_cfg)
|
||||
_layer_cfg.update(layer_cfgs[i])
|
||||
if beit_style:
|
||||
_layer_cfg.update(
|
||||
dict(
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
window_size=self.patch_resolution))
|
||||
_layer_cfg.pop('qkv_bias')
|
||||
self.layers.append(BEiTTransformerEncoderLayer(**_layer_cfg))
|
||||
else:
|
||||
self.layers.append(TransformerEncoderLayer(**_layer_cfg))
|
||||
self.layers.append(TransformerEncoderLayer(**_layer_cfg))
|
||||
|
||||
self.frozen_stages = frozen_stages
|
||||
self.final_norm = final_norm
|
||||
|
@ -462,7 +337,8 @@ class VisionTransformer(BaseBackbone):
|
|||
|
||||
if not (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg['type'] == 'Pretrained'):
|
||||
trunc_normal_(self.pos_embed, std=0.02)
|
||||
if self.pos_embed is not None:
|
||||
trunc_normal_(self.pos_embed, std=0.02)
|
||||
|
||||
def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
|
||||
name = prefix + 'pos_embed'
|
||||
|
@ -494,7 +370,8 @@ class VisionTransformer(BaseBackbone):
|
|||
|
||||
def _freeze_stages(self):
|
||||
# freeze position embedding
|
||||
self.pos_embed.requires_grad = False
|
||||
if self.pos_embed is not None:
|
||||
self.pos_embed.requires_grad = False
|
||||
# set dropout to eval model
|
||||
self.drop_after_pos.eval()
|
||||
# freeze patch embedding
|
||||
|
|
|
@ -565,6 +565,8 @@ class BEiTAttention(BaseModule):
|
|||
embed_dims (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
use_rel_pos_bias (bool): Whether to use unique relative position bias,
|
||||
if False, use shared relative position bias defined in backbone.
|
||||
bias (str): The option to add leanable bias for q, k, v. If bias is
|
||||
True, it will add leanable bias. If bias is 'qv_bias', it will only
|
||||
add leanable bias for q, v. If bias is False, it will not add bias
|
||||
|
@ -582,6 +584,7 @@ class BEiTAttention(BaseModule):
|
|||
embed_dims,
|
||||
num_heads,
|
||||
window_size,
|
||||
use_rel_pos_bias,
|
||||
bias='qv_bias',
|
||||
qk_scale=None,
|
||||
attn_drop_rate=0.,
|
||||
|
@ -601,6 +604,7 @@ class BEiTAttention(BaseModule):
|
|||
qkv_bias = False
|
||||
|
||||
self.window_size = window_size
|
||||
self.use_rel_pos_bias = use_rel_pos_bias
|
||||
self._init_rel_pos_embedding()
|
||||
|
||||
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
|
||||
|
@ -613,48 +617,56 @@ class BEiTAttention(BaseModule):
|
|||
self.v_bias = nn.Parameter(torch.zeros(self.embed_dims))
|
||||
|
||||
def _init_rel_pos_embedding(self):
|
||||
Wh, Ww = self.window_size
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
self.num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3
|
||||
# relative_position_bias_table shape is (2*Wh-1 * 2*Ww-1 + 3, nH)
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(self.num_relative_distance, self.num_heads))
|
||||
if self.use_rel_pos_bias:
|
||||
Wh, Ww = self.window_size
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
self.num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3
|
||||
# relative_position_bias_table shape is (2*Wh-1 * 2*Ww-1 + 3, nH)
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(self.num_relative_distance, self.num_heads))
|
||||
|
||||
# get pair-wise relative position index for
|
||||
# each token inside the window
|
||||
coords_h = torch.arange(Wh)
|
||||
coords_w = torch.arange(Ww)
|
||||
# coords shape is (2, Wh, Ww)
|
||||
coords = torch.stack(torch_meshgrid([coords_h, coords_w]))
|
||||
# coords_flatten shape is (2, Wh*Ww)
|
||||
coords_flatten = torch.flatten(coords, 1)
|
||||
relative_coords = (
|
||||
coords_flatten[:, :, None] - coords_flatten[:, None, :])
|
||||
# relative_coords shape is (Wh*Ww, Wh*Ww, 2)
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
||||
# shift to start from 0
|
||||
relative_coords[:, :, 0] += Wh - 1
|
||||
relative_coords[:, :, 1] += Ww - 1
|
||||
relative_coords[:, :, 0] *= 2 * Ww - 1
|
||||
relative_position_index = torch.zeros(
|
||||
size=(Wh * Ww + 1, ) * 2, dtype=relative_coords.dtype)
|
||||
# relative_position_index shape is (Wh*Ww, Wh*Ww)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(-1)
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
# get pair-wise relative position index for
|
||||
# each token inside the window
|
||||
coords_h = torch.arange(Wh)
|
||||
coords_w = torch.arange(Ww)
|
||||
# coords shape is (2, Wh, Ww)
|
||||
coords = torch.stack(torch_meshgrid([coords_h, coords_w]))
|
||||
# coords_flatten shape is (2, Wh*Ww)
|
||||
coords_flatten = torch.flatten(coords, 1)
|
||||
relative_coords = (
|
||||
coords_flatten[:, :, None] - coords_flatten[:, None, :])
|
||||
# relative_coords shape is (Wh*Ww, Wh*Ww, 2)
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
||||
# shift to start from 0
|
||||
relative_coords[:, :, 0] += Wh - 1
|
||||
relative_coords[:, :, 1] += Ww - 1
|
||||
relative_coords[:, :, 0] *= 2 * Ww - 1
|
||||
relative_position_index = torch.zeros(
|
||||
size=(Wh * Ww + 1, ) * 2, dtype=relative_coords.dtype)
|
||||
# relative_position_index shape is (Wh*Ww, Wh*Ww)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(-1)
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
|
||||
self.register_buffer('relative_position_index',
|
||||
relative_position_index)
|
||||
self.register_buffer('relative_position_index',
|
||||
relative_position_index)
|
||||
else:
|
||||
self.window_size = None
|
||||
self.relative_position_bias_table = None
|
||||
self.relative_position_index = None
|
||||
|
||||
def init_weights(self):
|
||||
super().init_weights()
|
||||
trunc_normal_(self.relative_position_bias_table, std=0.02)
|
||||
if self.use_rel_pos_bias:
|
||||
trunc_normal_(self.relative_position_bias_table, std=0.02)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, rel_pos_bias=None):
|
||||
"""
|
||||
Args:
|
||||
x (tensor): input features with shape of (num_windows*B, N, C).
|
||||
rel_pos_bias (tensor): input relative position bias with shape of
|
||||
(num_heads, N, N).
|
||||
"""
|
||||
B, N, C = x.shape
|
||||
|
||||
|
@ -678,6 +690,11 @@ class BEiTAttention(BaseModule):
|
|||
relative_position_bias = relative_position_bias.permute(
|
||||
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if rel_pos_bias is not None:
|
||||
# use shared relative position bias
|
||||
attn = attn + rel_pos_bias
|
||||
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
|
|
|
@ -40,3 +40,5 @@ Import:
|
|||
- configs/davit/metafile.yml
|
||||
- configs/replknet/metafile.yml
|
||||
- configs/csra/metafile.yml
|
||||
- configs/beit/metafile.yml
|
||||
- configs/beitv2/metafile.yml
|
||||
|
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from copy import deepcopy
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
|
||||
from mmcls.models.backbones import BEiT
|
||||
|
||||
|
||||
class TestBEiT(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.cfg = dict(
|
||||
arch='b', img_size=224, patch_size=16, drop_path_rate=0.1)
|
||||
|
||||
def test_structure(self):
|
||||
# Test invalid default arch
|
||||
with self.assertRaisesRegex(AssertionError, 'not in default archs'):
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = 'unknown'
|
||||
BEiT(**cfg)
|
||||
|
||||
# Test invalid custom arch
|
||||
with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = {
|
||||
'num_layers': 24,
|
||||
'num_heads': 16,
|
||||
'feedforward_channels': 4096
|
||||
}
|
||||
BEiT(**cfg)
|
||||
|
||||
# Test custom arch
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = {
|
||||
'embed_dims': 128,
|
||||
'num_layers': 24,
|
||||
'num_heads': 16,
|
||||
'feedforward_channels': 1024
|
||||
}
|
||||
model = BEiT(**cfg)
|
||||
self.assertEqual(model.embed_dims, 128)
|
||||
self.assertEqual(model.num_layers, 24)
|
||||
self.assertIsNone(model.pos_embed)
|
||||
self.assertIsNone(model.rel_pos_bias)
|
||||
for layer in model.layers:
|
||||
self.assertEqual(layer.attn.num_heads, 16)
|
||||
self.assertEqual(layer.ffn.feedforward_channels, 1024)
|
||||
|
||||
# Test out_indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['out_indices'] = {1: 1}
|
||||
with self.assertRaisesRegex(AssertionError, "get <class 'dict'>"):
|
||||
BEiT(**cfg)
|
||||
cfg['out_indices'] = [0, 13]
|
||||
with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'):
|
||||
BEiT(**cfg)
|
||||
|
||||
# Test pos_embed
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['use_abs_pos_emb'] = True
|
||||
model = BEiT(**cfg)
|
||||
self.assertEqual(model.pos_embed.shape, (1, 197, 768))
|
||||
|
||||
# Test model structure
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['drop_path_rate'] = 0.1
|
||||
model = BEiT(**cfg)
|
||||
self.assertEqual(len(model.layers), 12)
|
||||
dpr_inc = 0.1 / (12 - 1)
|
||||
dpr = 0
|
||||
for layer in model.layers:
|
||||
self.assertEqual(layer.gamma_1.shape, (768, ))
|
||||
self.assertEqual(layer.gamma_2.shape, (768, ))
|
||||
self.assertEqual(layer.attn.embed_dims, 768)
|
||||
self.assertEqual(layer.attn.num_heads, 12)
|
||||
self.assertEqual(layer.ffn.feedforward_channels, 3072)
|
||||
self.assertAlmostEqual(layer.ffn.dropout_layer.drop_prob, dpr)
|
||||
dpr += dpr_inc
|
||||
|
||||
def test_forward(self):
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
|
||||
# test with output_cls_token
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['output_cls_token'] = True
|
||||
model = BEiT(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token, cls_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (1, 768))
|
||||
self.assertEqual(cls_token.shape, (1, 768))
|
||||
|
||||
# test without output_cls_token
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = BEiT(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (1, 768))
|
||||
|
||||
# test without average
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['avg_token'] = False
|
||||
model = BEiT(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (1, 768, 14, 14))
|
||||
|
||||
# Test forward with multi out indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['out_indices'] = [-3, -2, -1]
|
||||
model = BEiT(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 3)
|
||||
for out in outs:
|
||||
patch_token = out
|
||||
self.assertEqual(patch_token.shape, (1, 768))
|
Loading…
Reference in New Issue