diff --git a/mmcls/models/backbones/vision_transformer.py b/mmcls/models/backbones/vision_transformer.py index 00fcf3218..3c8b53084 100644 --- a/mmcls/models/backbones/vision_transformer.py +++ b/mmcls/models/backbones/vision_transformer.py @@ -1,16 +1,18 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Sequence +from typing import Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch import torch.nn as nn from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.drop import build_dropout from mmcv.cnn.bricks.transformer import FFN, PatchEmbed from mmengine.model import BaseModule, ModuleList from mmengine.model.utils import trunc_normal_ from mmcls.registry import MODELS -from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple +from ..utils import (BEiTAttention, MultiheadAttention, resize_pos_embed, + to_2tuple) from .base_backbone import BaseBackbone @@ -98,6 +100,116 @@ class TransformerEncoderLayer(BaseModule): return x +class BEiTTransformerEncoderLayer(TransformerEncoderLayer): + """Implements one encoder layer in BEiT. + + Comparing with conventional ``TransformerEncoderLayer``, this module + adds weights to the shortcut connection. In addition, ``BEiTAttention`` + is used to replace the original ``MultiheadAttention`` in + ``TransformerEncoderLayer``. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + layer_scale_init_value (float): The initialization value for + the learnable scaling of attention and FFN. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + window_size (tuple[int]): The height and width of the window. + Defaults to None. + attn_drop_rate (float): The drop out rate for attention layer. + Defaults to 0.0. + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. + bias (bool | str): The option to add leanable bias for q, k, v. If bias + is True, it will add leanable bias. If bias is 'qv_bias', it will + only add leanable bias for q, v. If bias is False, it will not add + bias for q, k, v. Default to 'qv_bias'. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to dict(type='LN'). + attn_cfg (dict): The configuration for the attention layer. + Defaults to an empty dict. + ffn_cfg (dict): The configuration for the ffn layer. + Defaults to ``dict(add_identity=False)``. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + embed_dims: int, + num_heads: int, + feedforward_channels: int, + layer_scale_init_value: float, + window_size: Tuple[int, int], + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + num_fcs: int = 2, + bias: Union[str, bool] = 'qv_bias', + act_cfg: dict = dict(type='GELU'), + norm_cfg: dict = dict(type='LN'), + attn_cfg: dict = dict(), + ffn_cfg: dict = dict(add_identity=False), + init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None: + attn_cfg.update(dict(window_size=window_size, qk_scale=None)) + + super().__init__( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=feedforward_channels, + attn_drop_rate=attn_drop_rate, + drop_path_rate=0., + drop_rate=0., + num_fcs=num_fcs, + qkv_bias=bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + init_cfg=init_cfg) + + # overwrite the default attention layer in TransformerEncoderLayer + attn_cfg.update( + dict( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + bias=bias)) + self.attn = BEiTAttention(**attn_cfg) + + # overwrite the default ffn layer in TransformerEncoderLayer + ffn_cfg.update( + dict( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate) + if drop_path_rate > 0 else None, + act_cfg=act_cfg)) + self.ffn = FFN(**ffn_cfg) + + # NOTE: drop path for stochastic depth, we shall see if + # this is better than dropout here + dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate) + self.drop_path = build_dropout( + dropout_layer) if dropout_layer else nn.Identity() + self.gamma_1 = nn.Parameter( + layer_scale_init_value * torch.ones((embed_dims)), + requires_grad=True) + self.gamma_2 = nn.Parameter( + layer_scale_init_value * torch.ones((embed_dims)), + requires_grad=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) + x = x + self.drop_path(self.gamma_2 * self.ffn(self.norm2(x))) + return x + + @MODELS.register_module() class VisionTransformer(BaseBackbone): """Vision Transformer. @@ -136,8 +248,16 @@ class VisionTransformer(BaseBackbone): final feature map. Defaults to True. with_cls_token (bool): Whether concatenating class token into image tokens as transformer input. Defaults to True. + avg_token (bool): Whether or not to use the mean patch token for + classification. If True, the model will only take the average + of all patch tokens. Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. output_cls_token (bool): Whether output the cls_token. If set True, ``with_cls_token`` must be True. Defaults to True. + beit_style (bool): Whether or not use BEiT-style. Defaults to False. + layer_scale_init_value (float): The initialization value for + the learnable scaling of attention and FFN. Defaults to 0.1. interpolate_mode (str): Select the interpolate mode for position embeding vector resize. Defaults to "bicubic". patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. @@ -205,7 +325,11 @@ class VisionTransformer(BaseBackbone): norm_cfg=dict(type='LN', eps=1e-6), final_norm=True, with_cls_token=True, + avg_token=False, + frozen_stages=-1, output_cls_token=True, + beit_style=False, + layer_scale_init_value=0.1, interpolate_mode='bicubic', patch_cfg=dict(), layer_cfgs=dict(), @@ -289,18 +413,40 @@ class VisionTransformer(BaseBackbone): qkv_bias=qkv_bias, norm_cfg=norm_cfg) _layer_cfg.update(layer_cfgs[i]) - self.layers.append(TransformerEncoderLayer(**_layer_cfg)) + if beit_style: + _layer_cfg.update( + dict( + layer_scale_init_value=layer_scale_init_value, + window_size=self.patch_resolution)) + _layer_cfg.pop('qkv_bias') + self.layers.append(BEiTTransformerEncoderLayer(**_layer_cfg)) + else: + self.layers.append(TransformerEncoderLayer(**_layer_cfg)) + self.frozen_stages = frozen_stages self.final_norm = final_norm if final_norm: self.norm1_name, norm1 = build_norm_layer( norm_cfg, self.embed_dims, postfix=1) self.add_module(self.norm1_name, norm1) + self.avg_token = avg_token + if avg_token: + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, self.embed_dims, postfix=2) + self.add_module(self.norm2_name, norm2) + # freeze stages only when self.frozen_stages > 0 + if self.frozen_stages > 0: + self._freeze_stages() + @property def norm1(self): return getattr(self, self.norm1_name) + @property + def norm2(self): + return getattr(self, self.norm2_name) + def init_weights(self): super(VisionTransformer, self).init_weights() @@ -336,6 +482,29 @@ class VisionTransformer(BaseBackbone): """Interface for backward-compatibility.""" return resize_pos_embed(*args, **kwargs) + def _freeze_stages(self): + # freeze position embedding + self.pos_embed.requires_grad = False + # set dropout to eval model + self.drop_after_pos.eval() + # freeze patch embedding + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + # freeze cls_token + self.cls_token.requires_grad = False + # freeze layers + for i in range(1, self.frozen_stages + 1): + m = self.layers[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + # freeze the last layer norm + if self.frozen_stages == len(self.layers) and self.final_norm: + self.norm1.eval() + for param in self.norm1.parameters(): + param.requires_grad = False + def forward(self, x): B = x.shape[0] x, patch_resolution = self.patch_embed(x) @@ -372,6 +541,12 @@ class VisionTransformer(BaseBackbone): patch_token = x.reshape(B, *patch_resolution, C) patch_token = patch_token.permute(0, 3, 1, 2) cls_token = None + if self.avg_token: + patch_token = patch_token.permute(0, 2, 3, 1) + patch_token = patch_token.reshape( + B, patch_resolution[0] * patch_resolution[1], + C).mean(dim=1) + patch_token = self.norm2(patch_token) if self.output_cls_token: out = [patch_token, cls_token] else: diff --git a/mmcls/models/utils/__init__.py b/mmcls/models/utils/__init__.py index 98ec41a68..e1553bc85 100644 --- a/mmcls/models/utils/__init__.py +++ b/mmcls/models/utils/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .attention import MultiheadAttention, ShiftWindowMSA +from .attention import BEiTAttention, MultiheadAttention, ShiftWindowMSA from .batch_augments import CutMix, Mixup, RandomBatchAugment, ResizeMix from .channel_shuffle import channel_shuffle from .data_preprocessor import ClsDataPreprocessor @@ -17,5 +17,5 @@ __all__ = [ 'PatchMerging', 'HybridEmbed', 'RandomBatchAugment', 'ShiftWindowMSA', 'is_tracing', 'MultiheadAttention', 'ConditionalPositionEncoding', 'resize_pos_embed', 'resize_relative_position_bias_table', - 'ClsDataPreprocessor', 'Mixup', 'CutMix', 'ResizeMix' + 'ClsDataPreprocessor', 'Mixup', 'CutMix', 'ResizeMix', 'BEiTAttention' ] diff --git a/mmcls/models/utils/attention.py b/mmcls/models/utils/attention.py index 835a7fd41..2da1ac148 100644 --- a/mmcls/models/utils/attention.py +++ b/mmcls/models/utils/attention.py @@ -382,3 +382,134 @@ class MultiheadAttention(BaseModule): if self.v_shortcut: x = v.squeeze(1) + x return x + + +class BEiTAttention(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + The initial implementation is in MMSegmentation. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): The height and width of the window. + bias (str): The option to add leanable bias for q, k, v. If bias is + True, it will add leanable bias. If bias is 'qv_bias', it will only + add leanable bias for q, v. If bias is False, it will not add bias + for q, k, v. Default to 'qv_bias'. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float): Dropout ratio of output. Default: 0. + init_cfg (dict | None, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + bias='qv_bias', + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + init_cfg=None, + **kwargs): + super().__init__(init_cfg=init_cfg) + self.embed_dims = embed_dims + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.bias = bias + self.scale = qk_scale or head_embed_dims**-0.5 + + qkv_bias = bias + if bias == 'qv_bias': + self._init_qv_bias() + qkv_bias = False + + self.window_size = window_size + self._init_rel_pos_embedding() + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + + def _init_qv_bias(self): + self.q_bias = nn.Parameter(torch.zeros(self.embed_dims)) + self.v_bias = nn.Parameter(torch.zeros(self.embed_dims)) + + def _init_rel_pos_embedding(self): + Wh, Ww = self.window_size + # cls to token & token 2 cls & cls to cls + self.num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3 + # relative_position_bias_table shape is (2*Wh-1 * 2*Ww-1 + 3, nH) + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, self.num_heads)) + + # get pair-wise relative position index for + # each token inside the window + coords_h = torch.arange(Wh) + coords_w = torch.arange(Ww) + # coords shape is (2, Wh, Ww) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + # coords_flatten shape is (2, Wh*Ww) + coords_flatten = torch.flatten(coords, 1) + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :]) + # relative_coords shape is (Wh*Ww, Wh*Ww, 2) + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + # shift to start from 0 + relative_coords[:, :, 0] += Wh - 1 + relative_coords[:, :, 1] += Ww - 1 + relative_coords[:, :, 0] *= 2 * Ww - 1 + relative_position_index = torch.zeros( + size=(Wh * Ww + 1, ) * 2, dtype=relative_coords.dtype) + # relative_position_index shape is (Wh*Ww, Wh*Ww) + relative_position_index[1:, 1:] = relative_coords.sum(-1) + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer('relative_position_index', + relative_position_index) + + def init_weights(self): + super().init_weights() + trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x): + """ + Args: + x (tensor): input features with shape of (num_windows*B, N, C). + """ + B, N, C = x.shape + + if self.bias == 'qv_bias': + k_bias = torch.zeros_like(self.v_bias, requires_grad=False) + qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + else: + qkv = self.qkv(x) + + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + if self.relative_position_bias_table is not None: + Wh = self.window_size[0] + Ww = self.window_size[1] + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + Wh * Ww + 1, Wh * Ww + 1, -1) + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x