[Feature]: Support BEiT Transformer layer. (#919)

* [Feature]: Add BEiT-style transformer encoder layer

* [Feature]: Add average token

* [Fix]: Fix lint

* [Fix]: Refactor CAE config

* [Fix]: Change cv2 backend to pillow backend

* [Fix]: Fix MAE and CAE reshape bug

* [Feature]: Add freeze vit layers

* [Feature]: Add mc

* [Fix]: Fix lint

* [Fix]: Fix dataset bug

* [Fix]: Delete cae selfsup config

* [Fix]: docstring

* [Refactor]: Rename init_values to layer_scale_init_value

* [Fix]: Refine the docstring of avg_token

* [Fix]: Call super init weight in beit attention

* [Fix]: remove mc

* [Fix]: Fix docstring

* [Fix]: Fix docstring

* [Fix]: Fix lint

* [Fix]: Fix init_value bug and change the logic of outputting cls token

* [Fix]: Fix docstring
Yuan Liu 2022-08-17 00:07:06 +08:00 committed by GitHub
parent b8b31e9343
commit e4252d6848
3 changed files with 311 additions and 5 deletions


@@ -1,16 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
from typing import Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.model.utils import trunc_normal_
from mmcls.registry import MODELS
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
from ..utils import (BEiTAttention, MultiheadAttention, resize_pos_embed,
to_2tuple)
from .base_backbone import BaseBackbone
@@ -98,6 +100,116 @@ class TransformerEncoderLayer(BaseModule):
return x
class BEiTTransformerEncoderLayer(TransformerEncoderLayer):
"""Implements one encoder layer in BEiT.
Compared with the conventional ``TransformerEncoderLayer``, this module
adds learnable scaling weights (layer scale) to the shortcut connections.
In addition, ``BEiTAttention`` replaces the original ``MultiheadAttention``
in ``TransformerEncoderLayer``.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
layer_scale_init_value (float): The initialization value for
the learnable scaling of attention and FFN.
window_size (tuple[int]): The height and width of the window.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.
attn_drop_rate (float): The dropout rate for the attention layer.
Defaults to 0.0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.0.
num_fcs (int): The number of fully-connected layers for FFNs.
Defaults to 2.
bias (bool | str): The option to add learnable bias for q, k, v. If bias
is True, it will add learnable bias. If bias is 'qv_bias', it will
only add learnable bias for q, v. If bias is False, it will not add
bias for q, k, v. Defaults to 'qv_bias'.
act_cfg (dict): The activation config for FFNs.
Defaults to ``dict(type='GELU')``.
norm_cfg (dict): Config dict for normalization layer.
Defaults to dict(type='LN').
attn_cfg (dict): The configuration for the attention layer.
Defaults to an empty dict.
ffn_cfg (dict): The configuration for the ffn layer.
Defaults to ``dict(add_identity=False)``.
init_cfg (dict or List[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
embed_dims: int,
num_heads: int,
feedforward_channels: int,
layer_scale_init_value: float,
window_size: Tuple[int, int],
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
num_fcs: int = 2,
bias: Union[str, bool] = 'qv_bias',
act_cfg: dict = dict(type='GELU'),
norm_cfg: dict = dict(type='LN'),
attn_cfg: dict = dict(),
ffn_cfg: dict = dict(add_identity=False),
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
attn_cfg.update(dict(window_size=window_size, qk_scale=None))
super().__init__(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=feedforward_channels,
attn_drop_rate=attn_drop_rate,
drop_path_rate=0.,
drop_rate=0.,
num_fcs=num_fcs,
qkv_bias=bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
init_cfg=init_cfg)
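# NOTE: drop_rate and drop_path_rate are intentionally zeroed in the
# parent __init__; the attention and FFN built there are replaced below
# with modules that apply the real rates.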
# overwrite the default attention layer in TransformerEncoderLayer
attn_cfg.update(
dict(
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
bias=bias))
self.attn = BEiTAttention(**attn_cfg)
# overwrite the default ffn layer in TransformerEncoderLayer
ffn_cfg.update(
dict(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate)
if drop_path_rate > 0 else None,
act_cfg=act_cfg))
self.ffn = FFN(**ffn_cfg)
# NOTE: drop path for stochastic depth, we shall see if
# this is better than dropout here
dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate)
self.drop_path = build_dropout(
dropout_layer) if dropout_layer else nn.Identity()
self.gamma_1 = nn.Parameter(
layer_scale_init_value * torch.ones((embed_dims)),
requires_grad=True)
self.gamma_2 = nn.Parameter(
layer_scale_init_value * torch.ones((embed_dims)),
requires_grad=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
x = x + self.drop_path(self.gamma_2 * self.ffn(self.norm2(x)))
return x
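A minimal usage sketch of the new layer. The direct import path below is an assumption based on the file being edited; in normal use the layer is built internally by ``VisionTransformer`` with ``beit_style=True``.

import torch

from mmcls.models.backbones.vision_transformer import \
    BEiTTransformerEncoderLayer  # assumed import path

layer = BEiTTransformerEncoderLayer(
    embed_dims=768,
    num_heads=12,
    feedforward_channels=3072,
    layer_scale_init_value=0.1,
    window_size=(14, 14),  # e.g. 224x224 input with 16x16 patches
    drop_path_rate=0.1)

# 14 * 14 patch tokens plus one cls token -> 197 tokens
tokens = torch.randn(2, 197, 768)
out = layer(tokens)  # output shape stays (2, 197, 768)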
@MODELS.register_module()
class VisionTransformer(BaseBackbone):
"""Vision Transformer.
@@ -136,8 +248,16 @@ class VisionTransformer(BaseBackbone):
final feature map. Defaults to True.
with_cls_token (bool): Whether concatenating class token into image
tokens as transformer input. Defaults to True.
avg_token (bool): Whether or not to use the mean patch token for
classification. If True, the model will only take the average
of all patch tokens. Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
output_cls_token (bool): Whether output the cls_token. If set True,
``with_cls_token`` must be True. Defaults to True.
beit_style (bool): Whether or not to use the BEiT-style transformer
encoder layer. Defaults to False.
layer_scale_init_value (float): The initialization value for
the learnable scaling of attention and FFN. Defaults to 0.1.
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
@@ -205,7 +325,11 @@ class VisionTransformer(BaseBackbone):
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=True,
with_cls_token=True,
avg_token=False,
frozen_stages=-1,
output_cls_token=True,
beit_style=False,
layer_scale_init_value=0.1,
interpolate_mode='bicubic',
patch_cfg=dict(),
layer_cfgs=dict(),
@@ -289,18 +413,40 @@ class VisionTransformer(BaseBackbone):
qkv_bias=qkv_bias,
norm_cfg=norm_cfg)
_layer_cfg.update(layer_cfgs[i])
self.layers.append(TransformerEncoderLayer(**_layer_cfg))
if beit_style:
_layer_cfg.update(
dict(
layer_scale_init_value=layer_scale_init_value,
window_size=self.patch_resolution))
_layer_cfg.pop('qkv_bias')
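# BEiTTransformerEncoderLayer controls the q/k/v bias through its own
# `bias` argument ('qv_bias' by default), so the boolean qkv_bias flag
# is dropped from the config.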
self.layers.append(BEiTTransformerEncoderLayer(**_layer_cfg))
else:
self.layers.append(TransformerEncoderLayer(**_layer_cfg))
self.frozen_stages = frozen_stages
self.final_norm = final_norm
if final_norm:
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, self.embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self.avg_token = avg_token
if avg_token:
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, self.embed_dims, postfix=2)
self.add_module(self.norm2_name, norm2)
# freeze stages only when self.frozen_stages > 0
if self.frozen_stages > 0:
self._freeze_stages()
@property
def norm1(self):
return getattr(self, self.norm1_name)
@property
def norm2(self):
return getattr(self, self.norm2_name)
def init_weights(self):
super(VisionTransformer, self).init_weights()
@@ -336,6 +482,29 @@ class VisionTransformer(BaseBackbone):
"""Interface for backward-compatibility."""
return resize_pos_embed(*args, **kwargs)
def _freeze_stages(self):
# freeze position embedding
self.pos_embed.requires_grad = False
# set dropout to eval mode
self.drop_after_pos.eval()
# freeze patch embedding
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
# freeze cls_token
self.cls_token.requires_grad = False
# freeze layers
for i in range(1, self.frozen_stages + 1):
m = self.layers[i - 1]
m.eval()
for param in m.parameters():
param.requires_grad = False
# freeze the last layer norm
if self.frozen_stages == len(self.layers) and self.final_norm:
self.norm1.eval()
for param in self.norm1.parameters():
param.requires_grad = False
def forward(self, x):
B = x.shape[0]
x, patch_resolution = self.patch_embed(x)
@@ -372,6 +541,12 @@ class VisionTransformer(BaseBackbone):
patch_token = x.reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = None
if self.avg_token:
patch_token = patch_token.permute(0, 2, 3, 1)
patch_token = patch_token.reshape(
B, patch_resolution[0] * patch_resolution[1],
C).mean(dim=1)
patch_token = self.norm2(patch_token)
if self.output_cls_token:
out = [patch_token, cls_token]
else:

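Putting the new backbone options together, a configuration along these lines exercises the BEiT-style layers, the average-token output, and layer freezing. This is a sketch with illustrative values, assuming an mmcls 1.x (mmengine-based) environment.

import torch

from mmcls.models import VisionTransformer

model = VisionTransformer(
    arch='base',                 # 12 layers, embed_dims=768
    img_size=224,
    patch_size=16,
    beit_style=True,             # build BEiTTransformerEncoderLayer blocks
    layer_scale_init_value=0.1,  # init value of gamma_1 / gamma_2
    avg_token=True,              # output the mean of all patch tokens
    output_cls_token=False,
    frozen_stages=12)            # freeze all encoder layers (and norm1)

model.eval()
with torch.no_grad():
    # `feats` is a tuple with one tensor per entry of `out_indices`
    feats = model(torch.randn(1, 3, 224, 224))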

@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .attention import MultiheadAttention, ShiftWindowMSA
from .attention import BEiTAttention, MultiheadAttention, ShiftWindowMSA
from .batch_augments import CutMix, Mixup, RandomBatchAugment, ResizeMix
from .channel_shuffle import channel_shuffle
from .data_preprocessor import ClsDataPreprocessor
@@ -17,5 +17,5 @@ __all__ = [
'PatchMerging', 'HybridEmbed', 'RandomBatchAugment', 'ShiftWindowMSA',
'is_tracing', 'MultiheadAttention', 'ConditionalPositionEncoding',
'resize_pos_embed', 'resize_relative_position_bias_table',
'ClsDataPreprocessor', 'Mixup', 'CutMix', 'ResizeMix'
'ClsDataPreprocessor', 'Mixup', 'CutMix', 'ResizeMix', 'BEiTAttention'
]


@@ -382,3 +382,134 @@ class MultiheadAttention(BaseModule):
if self.v_shortcut:
x = v.squeeze(1) + x
return x
class BEiTAttention(BaseModule):
"""Window based multi-head self-attention (W-MSA) module with relative
position bias.
The initial implementation is in MMSegmentation.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (tuple[int]): The height and width of the window.
bias (str): The option to add learnable bias for q, k, v. If bias is
True, it will add learnable bias. If bias is 'qv_bias', it will only
add learnable bias for q, v. If bias is False, it will not add bias
for q, k, v. Defaults to 'qv_bias'.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
attn_drop_rate (float): Dropout ratio of attention weight.
Default: 0.0
proj_drop_rate (float): Dropout ratio of output. Default: 0.
init_cfg (dict | None, optional): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size,
bias='qv_bias',
qk_scale=None,
attn_drop_rate=0.,
proj_drop_rate=0.,
init_cfg=None,
**kwargs):
super().__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
self.num_heads = num_heads
head_embed_dims = embed_dims // num_heads
self.bias = bias
self.scale = qk_scale or head_embed_dims**-0.5
qkv_bias = bias
if bias == 'qv_bias':
self._init_qv_bias()
qkv_bias = False
self.window_size = window_size
self._init_rel_pos_embedding()
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop_rate)
self.proj = nn.Linear(embed_dims, embed_dims)
self.proj_drop = nn.Dropout(proj_drop_rate)
def _init_qv_bias(self):
self.q_bias = nn.Parameter(torch.zeros(self.embed_dims))
self.v_bias = nn.Parameter(torch.zeros(self.embed_dims))
def _init_rel_pos_embedding(self):
Wh, Ww = self.window_size
# cls to token & token 2 cls & cls to cls
self.num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3
# relative_position_bias_table shape is ((2*Wh-1) * (2*Ww-1) + 3, nH)
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, self.num_heads))
# get pair-wise relative position index for
# each token inside the window
coords_h = torch.arange(Wh)
coords_w = torch.arange(Ww)
# coords shape is (2, Wh, Ww)
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
# coords_flatten shape is (2, Wh*Ww)
coords_flatten = torch.flatten(coords, 1)
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :])
# relative_coords shape is (Wh*Ww, Wh*Ww, 2)
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
# shift to start from 0
relative_coords[:, :, 0] += Wh - 1
relative_coords[:, :, 1] += Ww - 1
relative_coords[:, :, 0] *= 2 * Ww - 1
relative_position_index = torch.zeros(
size=(Wh * Ww + 1, ) * 2, dtype=relative_coords.dtype)
# relative_position_index shape is (Wh*Ww + 1, Wh*Ww + 1)
relative_position_index[1:, 1:] = relative_coords.sum(-1)
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
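# With a (2, 2) window, for example, there are Wh*Ww = 4 patch tokens
# plus one cls token, so relative_position_index is 5x5 and the bias
# table holds (2*2-1) * (2*2-1) + 3 = 12 entries. The three assignments
# above reserve the last three table entries for the cls token:
#   row 0 (cls attending to patch tokens)    -> num_relative_distance - 3
#   column 0 (patch tokens attending to cls) -> num_relative_distance - 2
#   entry (0, 0) (cls attending to cls)      -> num_relative_distance - 1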
self.register_buffer('relative_position_index',
relative_position_index)
def init_weights(self):
super().init_weights()
trunc_normal_(self.relative_position_bias_table, std=0.02)
def forward(self, x):
"""
Args:
x (tensor): input features with shape of (num_windows*B, N, C).
"""
B, N, C = x.shape
if self.bias == 'qv_bias':
k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias))
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
else:
qkv = self.qkv(x)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
if self.relative_position_bias_table is not None:
Wh = self.window_size[0]
Ww = self.window_size[1]
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(
Wh * Ww + 1, Wh * Ww + 1, -1)
relative_position_bias = relative_position_bias.permute(
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
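A quick shape check for ``BEiTAttention`` on its own. The import relies on the ``__all__`` entry added in this PR; the window size and batch values are illustrative.

import torch

from mmcls.models.utils import BEiTAttention

attn = BEiTAttention(
    embed_dims=768,
    num_heads=12,
    window_size=(14, 14),
    bias='qv_bias')
attn.init_weights()  # trunc_normal_ init of the relative bias table

# one cls token + 14 * 14 patch tokens
x = torch.randn(2, 14 * 14 + 1, 768)
out = attn(x)
assert out.shape == (2, 197, 768)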