mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[Feature]: Support BEiT Transformer layer. (#919)
* [Feature]: Add BEiT-style transformer encoder layer * [Feature]: Add average token * [Fix]: Fix lint * [Fix]: Refactor CAE config * [Fix]: Change cv2 backend to pillow backend * [Fix]: Fix MAE and CAE reshape bug * [Feature]: Add freeze vit layers * [Feature]: Add mc * [Fix]: Fix lint * [Fix]: Fix dataset bug * [Fix]: Delete cae selfsup config * [Fix]: docstring * [Refactor]: Add init_values to layer_scalue_init_value * [Fix]: Refine the docstring of avg_token * [Fix]: Call super init weight in beit attention * [Fix]: remove mc * [Fix]: Fix docstring * [Fix]: Fix docstring * [Fix]: Fix lint * [Fix]: Fix init_value bug and change the logic of outputting cls token * [Fix]: Fix docstring
This commit is contained in:
parent
b8b31e9343
commit
e4252d6848
@ -1,16 +1,18 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Sequence
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.drop import build_dropout
|
||||
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.model.utils import trunc_normal_
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
|
||||
from ..utils import (BEiTAttention, MultiheadAttention, resize_pos_embed,
|
||||
to_2tuple)
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
@ -98,6 +100,116 @@ class TransformerEncoderLayer(BaseModule):
|
||||
return x
|
||||
|
||||
|
||||
class BEiTTransformerEncoderLayer(TransformerEncoderLayer):
|
||||
"""Implements one encoder layer in BEiT.
|
||||
|
||||
Comparing with conventional ``TransformerEncoderLayer``, this module
|
||||
adds weights to the shortcut connection. In addition, ``BEiTAttention``
|
||||
is used to replace the original ``MultiheadAttention`` in
|
||||
``TransformerEncoderLayer``.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs.
|
||||
layer_scale_init_value (float): The initialization value for
|
||||
the learnable scaling of attention and FFN.
|
||||
drop_rate (float): Probability of an element to be zeroed
|
||||
after the feed forward layer. Defaults to 0.
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
Defaults to None.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Defaults to 0.0.
|
||||
drop_path_rate (float): Stochastic depth rate. Default 0.0.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Defaults to 2.
|
||||
bias (bool | str): The option to add leanable bias for q, k, v. If bias
|
||||
is True, it will add leanable bias. If bias is 'qv_bias', it will
|
||||
only add leanable bias for q, v. If bias is False, it will not add
|
||||
bias for q, k, v. Default to 'qv_bias'.
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Defaults to ``dict(type='GELU')``.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to dict(type='LN').
|
||||
attn_cfg (dict): The configuration for the attention layer.
|
||||
Defaults to an empty dict.
|
||||
ffn_cfg (dict): The configuration for the ffn layer.
|
||||
Defaults to ``dict(add_identity=False)``.
|
||||
init_cfg (dict or List[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims: int,
|
||||
num_heads: int,
|
||||
feedforward_channels: int,
|
||||
layer_scale_init_value: float,
|
||||
window_size: Tuple[int, int],
|
||||
drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.,
|
||||
num_fcs: int = 2,
|
||||
bias: Union[str, bool] = 'qv_bias',
|
||||
act_cfg: dict = dict(type='GELU'),
|
||||
norm_cfg: dict = dict(type='LN'),
|
||||
attn_cfg: dict = dict(),
|
||||
ffn_cfg: dict = dict(add_identity=False),
|
||||
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
|
||||
attn_cfg.update(dict(window_size=window_size, qk_scale=None))
|
||||
|
||||
super().__init__(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
feedforward_channels=feedforward_channels,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=0.,
|
||||
drop_rate=0.,
|
||||
num_fcs=num_fcs,
|
||||
qkv_bias=bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
# overwrite the default attention layer in TransformerEncoderLayer
|
||||
attn_cfg.update(
|
||||
dict(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
attn_drop=attn_drop_rate,
|
||||
proj_drop=drop_rate,
|
||||
bias=bias))
|
||||
self.attn = BEiTAttention(**attn_cfg)
|
||||
|
||||
# overwrite the default ffn layer in TransformerEncoderLayer
|
||||
ffn_cfg.update(
|
||||
dict(
|
||||
embed_dims=embed_dims,
|
||||
feedforward_channels=feedforward_channels,
|
||||
num_fcs=num_fcs,
|
||||
ffn_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate)
|
||||
if drop_path_rate > 0 else None,
|
||||
act_cfg=act_cfg))
|
||||
self.ffn = FFN(**ffn_cfg)
|
||||
|
||||
# NOTE: drop path for stochastic depth, we shall see if
|
||||
# this is better than dropout here
|
||||
dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate)
|
||||
self.drop_path = build_dropout(
|
||||
dropout_layer) if dropout_layer else nn.Identity()
|
||||
self.gamma_1 = nn.Parameter(
|
||||
layer_scale_init_value * torch.ones((embed_dims)),
|
||||
requires_grad=True)
|
||||
self.gamma_2 = nn.Parameter(
|
||||
layer_scale_init_value * torch.ones((embed_dims)),
|
||||
requires_grad=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
|
||||
x = x + self.drop_path(self.gamma_2 * self.ffn(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class VisionTransformer(BaseBackbone):
|
||||
"""Vision Transformer.
|
||||
@ -136,8 +248,16 @@ class VisionTransformer(BaseBackbone):
|
||||
final feature map. Defaults to True.
|
||||
with_cls_token (bool): Whether concatenating class token into image
|
||||
tokens as transformer input. Defaults to True.
|
||||
avg_token (bool): Whether or not to use the mean patch token for
|
||||
classification. If True, the model will only take the average
|
||||
of all patch tokens. Defaults to False.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters. Defaults to -1.
|
||||
output_cls_token (bool): Whether output the cls_token. If set True,
|
||||
``with_cls_token`` must be True. Defaults to True.
|
||||
beit_style (bool): Whether or not use BEiT-style. Defaults to False.
|
||||
layer_scale_init_value (float): The initialization value for
|
||||
the learnable scaling of attention and FFN. Defaults to 0.1.
|
||||
interpolate_mode (str): Select the interpolate mode for position
|
||||
embeding vector resize. Defaults to "bicubic".
|
||||
patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
|
||||
@ -205,7 +325,11 @@ class VisionTransformer(BaseBackbone):
|
||||
norm_cfg=dict(type='LN', eps=1e-6),
|
||||
final_norm=True,
|
||||
with_cls_token=True,
|
||||
avg_token=False,
|
||||
frozen_stages=-1,
|
||||
output_cls_token=True,
|
||||
beit_style=False,
|
||||
layer_scale_init_value=0.1,
|
||||
interpolate_mode='bicubic',
|
||||
patch_cfg=dict(),
|
||||
layer_cfgs=dict(),
|
||||
@ -289,18 +413,40 @@ class VisionTransformer(BaseBackbone):
|
||||
qkv_bias=qkv_bias,
|
||||
norm_cfg=norm_cfg)
|
||||
_layer_cfg.update(layer_cfgs[i])
|
||||
self.layers.append(TransformerEncoderLayer(**_layer_cfg))
|
||||
if beit_style:
|
||||
_layer_cfg.update(
|
||||
dict(
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
window_size=self.patch_resolution))
|
||||
_layer_cfg.pop('qkv_bias')
|
||||
self.layers.append(BEiTTransformerEncoderLayer(**_layer_cfg))
|
||||
else:
|
||||
self.layers.append(TransformerEncoderLayer(**_layer_cfg))
|
||||
|
||||
self.frozen_stages = frozen_stages
|
||||
self.final_norm = final_norm
|
||||
if final_norm:
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
norm_cfg, self.embed_dims, postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
|
||||
self.avg_token = avg_token
|
||||
if avg_token:
|
||||
self.norm2_name, norm2 = build_norm_layer(
|
||||
norm_cfg, self.embed_dims, postfix=2)
|
||||
self.add_module(self.norm2_name, norm2)
|
||||
# freeze stages only when self.frozen_stages > 0
|
||||
if self.frozen_stages > 0:
|
||||
self._freeze_stages()
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
@property
|
||||
def norm2(self):
|
||||
return getattr(self, self.norm2_name)
|
||||
|
||||
def init_weights(self):
|
||||
super(VisionTransformer, self).init_weights()
|
||||
|
||||
@ -336,6 +482,29 @@ class VisionTransformer(BaseBackbone):
|
||||
"""Interface for backward-compatibility."""
|
||||
return resize_pos_embed(*args, **kwargs)
|
||||
|
||||
def _freeze_stages(self):
|
||||
# freeze position embedding
|
||||
self.pos_embed.requires_grad = False
|
||||
# set dropout to eval model
|
||||
self.drop_after_pos.eval()
|
||||
# freeze patch embedding
|
||||
self.patch_embed.eval()
|
||||
for param in self.patch_embed.parameters():
|
||||
param.requires_grad = False
|
||||
# freeze cls_token
|
||||
self.cls_token.requires_grad = False
|
||||
# freeze layers
|
||||
for i in range(1, self.frozen_stages + 1):
|
||||
m = self.layers[i - 1]
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
# freeze the last layer norm
|
||||
if self.frozen_stages == len(self.layers) and self.final_norm:
|
||||
self.norm1.eval()
|
||||
for param in self.norm1.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, x):
|
||||
B = x.shape[0]
|
||||
x, patch_resolution = self.patch_embed(x)
|
||||
@ -372,6 +541,12 @@ class VisionTransformer(BaseBackbone):
|
||||
patch_token = x.reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = None
|
||||
if self.avg_token:
|
||||
patch_token = patch_token.permute(0, 2, 3, 1)
|
||||
patch_token = patch_token.reshape(
|
||||
B, patch_resolution[0] * patch_resolution[1],
|
||||
C).mean(dim=1)
|
||||
patch_token = self.norm2(patch_token)
|
||||
if self.output_cls_token:
|
||||
out = [patch_token, cls_token]
|
||||
else:
|
||||
|
@ -1,5 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .attention import MultiheadAttention, ShiftWindowMSA
|
||||
from .attention import BEiTAttention, MultiheadAttention, ShiftWindowMSA
|
||||
from .batch_augments import CutMix, Mixup, RandomBatchAugment, ResizeMix
|
||||
from .channel_shuffle import channel_shuffle
|
||||
from .data_preprocessor import ClsDataPreprocessor
|
||||
@ -17,5 +17,5 @@ __all__ = [
|
||||
'PatchMerging', 'HybridEmbed', 'RandomBatchAugment', 'ShiftWindowMSA',
|
||||
'is_tracing', 'MultiheadAttention', 'ConditionalPositionEncoding',
|
||||
'resize_pos_embed', 'resize_relative_position_bias_table',
|
||||
'ClsDataPreprocessor', 'Mixup', 'CutMix', 'ResizeMix'
|
||||
'ClsDataPreprocessor', 'Mixup', 'CutMix', 'ResizeMix', 'BEiTAttention'
|
||||
]
|
||||
|
@ -382,3 +382,134 @@ class MultiheadAttention(BaseModule):
|
||||
if self.v_shortcut:
|
||||
x = v.squeeze(1) + x
|
||||
return x
|
||||
|
||||
|
||||
class BEiTAttention(BaseModule):
|
||||
"""Window based multi-head self-attention (W-MSA) module with relative
|
||||
position bias.
|
||||
|
||||
The initial implementation is in MMSegmentation.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
bias (str): The option to add leanable bias for q, k, v. If bias is
|
||||
True, it will add leanable bias. If bias is 'qv_bias', it will only
|
||||
add leanable bias for q, v. If bias is False, it will not add bias
|
||||
for q, k, v. Default to 'qv_bias'.
|
||||
qk_scale (float | None, optional): Override default qk scale of
|
||||
head_dim ** -0.5 if set. Default: None.
|
||||
attn_drop_rate (float): Dropout ratio of attention weight.
|
||||
Default: 0.0
|
||||
proj_drop_rate (float): Dropout ratio of output. Default: 0.
|
||||
init_cfg (dict | None, optional): The Config for initialization.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
window_size,
|
||||
bias='qv_bias',
|
||||
qk_scale=None,
|
||||
attn_drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
init_cfg=None,
|
||||
**kwargs):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.embed_dims = embed_dims
|
||||
self.num_heads = num_heads
|
||||
head_embed_dims = embed_dims // num_heads
|
||||
self.bias = bias
|
||||
self.scale = qk_scale or head_embed_dims**-0.5
|
||||
|
||||
qkv_bias = bias
|
||||
if bias == 'qv_bias':
|
||||
self._init_qv_bias()
|
||||
qkv_bias = False
|
||||
|
||||
self.window_size = window_size
|
||||
self._init_rel_pos_embedding()
|
||||
|
||||
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop_rate)
|
||||
self.proj = nn.Linear(embed_dims, embed_dims)
|
||||
self.proj_drop = nn.Dropout(proj_drop_rate)
|
||||
|
||||
def _init_qv_bias(self):
|
||||
self.q_bias = nn.Parameter(torch.zeros(self.embed_dims))
|
||||
self.v_bias = nn.Parameter(torch.zeros(self.embed_dims))
|
||||
|
||||
def _init_rel_pos_embedding(self):
|
||||
Wh, Ww = self.window_size
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
self.num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3
|
||||
# relative_position_bias_table shape is (2*Wh-1 * 2*Ww-1 + 3, nH)
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(self.num_relative_distance, self.num_heads))
|
||||
|
||||
# get pair-wise relative position index for
|
||||
# each token inside the window
|
||||
coords_h = torch.arange(Wh)
|
||||
coords_w = torch.arange(Ww)
|
||||
# coords shape is (2, Wh, Ww)
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
|
||||
# coords_flatten shape is (2, Wh*Ww)
|
||||
coords_flatten = torch.flatten(coords, 1)
|
||||
relative_coords = (
|
||||
coords_flatten[:, :, None] - coords_flatten[:, None, :])
|
||||
# relative_coords shape is (Wh*Ww, Wh*Ww, 2)
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
||||
# shift to start from 0
|
||||
relative_coords[:, :, 0] += Wh - 1
|
||||
relative_coords[:, :, 1] += Ww - 1
|
||||
relative_coords[:, :, 0] *= 2 * Ww - 1
|
||||
relative_position_index = torch.zeros(
|
||||
size=(Wh * Ww + 1, ) * 2, dtype=relative_coords.dtype)
|
||||
# relative_position_index shape is (Wh*Ww, Wh*Ww)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(-1)
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
|
||||
self.register_buffer('relative_position_index',
|
||||
relative_position_index)
|
||||
|
||||
def init_weights(self):
|
||||
super().init_weights()
|
||||
trunc_normal_(self.relative_position_bias_table, std=0.02)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (tensor): input features with shape of (num_windows*B, N, C).
|
||||
"""
|
||||
B, N, C = x.shape
|
||||
|
||||
if self.bias == 'qv_bias':
|
||||
k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
|
||||
qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias))
|
||||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
else:
|
||||
qkv = self.qkv(x)
|
||||
|
||||
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
if self.relative_position_bias_table is not None:
|
||||
Wh = self.window_size[0]
|
||||
Ww = self.window_size[1]
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)].view(
|
||||
Wh * Ww + 1, Wh * Ww + 1, -1)
|
||||
relative_position_bias = relative_position_bias.permute(
|
||||
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
Loading…
x
Reference in New Issue
Block a user