mmclassification/mmcls/models/backbones/vision_transformer.py

557 lines
22 KiB
Python
Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
2022-07-12 16:10:59 +08:00
from mmengine.model import BaseModule, ModuleList
from mmengine.model.utils import trunc_normal_
from mmcls.registry import MODELS
from ..utils import (BEiTAttention, MultiheadAttention, resize_pos_embed,
to_2tuple)
from .base_backbone import BaseBackbone
class TransformerEncoderLayer(BaseModule):
    """One pre-norm encoder layer of the Vision Transformer.

    The input is normalized before the attention block and before the FFN.
    The attention output is added back to the input here, while the FFN is
    handed the residual via its ``identity`` argument.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Number of parallel attention heads.
        feedforward_channels (int): The hidden dimension of the FFN.
        drop_rate (float): Probability of an element to be zeroed after the
            attention projection and inside the FFN. Defaults to 0.
        attn_drop_rate (float): Dropout rate applied to the attention
            weights. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate used by both the
            attention and the FFN branch. Defaults to 0.
        num_fcs (int): The number of fully-connected layers in the FFN.
            Defaults to 2.
        qkv_bias (bool): Whether to add a bias to the qkv projection.
            Defaults to True.
        act_cfg (dict): The activation config for the FFN.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for the normalization layers.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.embed_dims = embed_dims

        # Normalization + attention sub-block.
        self.norm1_name, first_norm = build_norm_layer(
            norm_cfg, self.embed_dims, postfix=1)
        self.add_module(self.norm1_name, first_norm)
        self.attn = MultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias)

        # Normalization + feed-forward sub-block.
        self.norm2_name, second_norm = build_norm_layer(
            norm_cfg, self.embed_dims, postfix=2)
        self.add_module(self.norm2_name, second_norm)
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

    @property
    def norm1(self):
        """The normalization layer applied before attention."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """The normalization layer applied before the FFN."""
        return getattr(self, self.norm2_name)

    def init_weights(self):
        """Run the base initialization, then re-init the FFN linears.

        Every ``nn.Linear`` inside the FFN gets Xavier-uniform weights and
        near-zero normally distributed biases.
        """
        super().init_weights()
        ffn_linears = (mod for mod in self.ffn.modules()
                       if isinstance(mod, nn.Linear))
        for linear in ffn_linears:
            nn.init.xavier_uniform_(linear.weight)
            nn.init.normal_(linear.bias, std=1e-6)

    def forward(self, x):
        """Apply the attention and FFN sub-blocks with residuals."""
        attended = self.attn(self.norm1(x))
        x = x + attended
        # The FFN adds ``identity`` back onto its own output.
        return self.ffn(self.norm2(x), identity=x)
class BEiTTransformerEncoderLayer(TransformerEncoderLayer):
    """Implements one encoder layer in BEiT.

    Compared with the conventional ``TransformerEncoderLayer``, this module
    scales both residual branches with learnable per-channel weights
    (``gamma_1`` and ``gamma_2``) and uses ``BEiTAttention`` instead of the
    original ``MultiheadAttention``.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        layer_scale_init_value (float): The initialization value for
            the learnable scaling of attention and FFN.
        window_size (tuple[int]): The height and width of the window,
            forwarded to ``BEiTAttention``. Required.
        drop_rate (float): Probability of an element to be zeroed after the
            attention projection and inside the FFN. Defaults to 0.
        attn_drop_rate (float): The dropout rate for attention weights.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate, applied once to each
            residual branch in :meth:`forward`. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        bias (bool | str): The option to add learnable bias for q, k, v.
            If bias is True, it will add learnable bias. If bias is
            'qv_bias', it will only add learnable bias for q, v. If bias is
            False, it will not add bias for q, k, v. Defaults to 'qv_bias'.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        attn_cfg (dict): Extra configuration for the attention layer. It is
            copied, never modified in place. Defaults to an empty dict.
        ffn_cfg (dict): Extra configuration for the ffn layer. It is copied,
            never modified in place. It should keep ``add_identity=False``
            because the residual connection is added in :meth:`forward`.
            Defaults to ``dict(add_identity=False)``.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims: int,
                 num_heads: int,
                 feedforward_channels: int,
                 layer_scale_init_value: float,
                 window_size: Tuple[int, int],
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 num_fcs: int = 2,
                 bias: Union[str, bool] = 'qv_bias',
                 act_cfg: dict = dict(type='GELU'),
                 norm_cfg: dict = dict(type='LN'),
                 attn_cfg: dict = dict(),
                 ffn_cfg: dict = dict(add_identity=False),
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        # The parent builds norm1/norm2 plus a default attention and FFN;
        # the latter two are replaced below.
        super().__init__(
            embed_dims=embed_dims,
            num_heads=num_heads,
            feedforward_channels=feedforward_channels,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=0.,
            drop_rate=0.,
            num_fcs=num_fcs,
            qkv_bias=bias,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg)

        # Build from shallow copies. The defaults above are shared dict
        # objects, and a caller may reuse one config dict for every layer,
        # so neither may be updated in place.
        attn_cfg = dict(attn_cfg)
        attn_cfg.update(
            dict(
                window_size=window_size,
                qk_scale=None,
                embed_dims=embed_dims,
                num_heads=num_heads,
                attn_drop=attn_drop_rate,
                proj_drop=drop_rate,
                bias=bias))
        # overwrite the default attention layer in TransformerEncoderLayer
        self.attn = BEiTAttention(**attn_cfg)

        ffn_cfg = dict(ffn_cfg)
        ffn_cfg.update(
            dict(
                embed_dims=embed_dims,
                feedforward_channels=feedforward_channels,
                num_fcs=num_fcs,
                ffn_drop=drop_rate,
                # Stochastic depth on this branch is applied by
                # ``self.drop_path`` in ``forward``. Giving the FFN its own
                # DropPath as well would drop the branch twice.
                dropout_layer=None,
                act_cfg=act_cfg))
        # overwrite the default ffn layer in TransformerEncoderLayer
        self.ffn = FFN(**ffn_cfg)

        # NOTE: drop path for stochastic depth, shared by both branches.
        self.drop_path = build_dropout(
            dict(type='DropPath', drop_prob=drop_path_rate))

        # Learnable per-channel scales for the two residual branches.
        self.gamma_1 = nn.Parameter(
            layer_scale_init_value * torch.ones(embed_dims),
            requires_grad=True)
        self.gamma_2 = nn.Parameter(
            layer_scale_init_value * torch.ones(embed_dims),
            requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer-scaled attention and FFN residual branches."""
        x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.gamma_2 * self.ffn(self.norm2(x)))
        return x
@MODELS.register_module()
class VisionTransformer(BaseBackbone):
    """Vision Transformer.

    A PyTorch implementation of `An Image is Worth 16x16 Words:
    Transformers for Image Recognition at Scale
    <https://arxiv.org/abs/2010.11929>`_

    Args:
        arch (str | dict): Vision Transformer architecture. If use string,
            choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small'
            and 'deit-base' (or the short aliases 's', 'b', 'l', 'deit-t',
            'deit-s' and 'deit-b'). If use dict, it should have below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **num_layers** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.
            - **feedforward_channels** (int): The hidden dimensions in
              feedforward modules.

            Defaults to 'base'.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 16.
        in_channels (int): The num of input channels. Defaults to 3.
        out_indices (Sequence | int): Indices of the encoder layers whose
            outputs are returned. Negative values count from the last layer.
            Defaults to -1, means the last layer.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. The per-layer rate
            grows linearly from 0 to this value. Defaults to 0.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Ignored when ``beit_style`` is True. Defaults to True.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        final_norm (bool): Whether to add an additional layer to normalize
            the final feature map. Defaults to True.
        with_cls_token (bool): Whether concatenating class token into image
            tokens as transformer input. Defaults to True.
        avg_token (bool): Whether or not to use the mean patch token for
            classification. If True, the model will only take the average
            of all patch tokens. Defaults to False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        output_cls_token (bool): Whether output the cls_token. If set True,
            ``with_cls_token`` must be True. Defaults to True.
        beit_style (bool): Whether or not use BEiT-style. Defaults to False.
        layer_scale_init_value (float): The initialization value for
            the learnable scaling of attention and FFN. Only used when
            ``beit_style`` is True. Defaults to 0.1.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """
    arch_zoo = {
        **dict.fromkeys(
            ['s', 'small'], {
                'embed_dims': 768,
                'num_layers': 8,
                'num_heads': 8,
                'feedforward_channels': 768 * 3,
            }),
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 3072
            }),
        **dict.fromkeys(
            ['l', 'large'], {
                'embed_dims': 1024,
                'num_layers': 24,
                'num_heads': 16,
                'feedforward_channels': 4096
            }),
        **dict.fromkeys(
            ['deit-t', 'deit-tiny'], {
                'embed_dims': 192,
                'num_layers': 12,
                'num_heads': 3,
                'feedforward_channels': 192 * 4
            }),
        **dict.fromkeys(
            ['deit-s', 'deit-small'], {
                'embed_dims': 384,
                'num_layers': 12,
                'num_heads': 6,
                'feedforward_channels': 384 * 4
            }),
        **dict.fromkeys(
            ['deit-b', 'deit-base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 768 * 4
            }),
    }
    # Some structures have multiple extra tokens, like DeiT.
    num_extra_tokens = 1  # cls_token

    def __init__(self,
                 arch='base',
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 out_indices=-1,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 final_norm=True,
                 with_cls_token=True,
                 avg_token=False,
                 frozen_stages=-1,
                 output_cls_token=True,
                 beit_style=False,
                 layer_scale_init_value=0.1,
                 interpolate_mode='bicubic',
                 patch_cfg=dict(),
                 layer_cfgs=dict(),
                 init_cfg=None):
        super(VisionTransformer, self).__init__(init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
            }
            assert isinstance(arch, dict) and essential_keys <= set(arch), \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.num_layers = self.arch_settings['num_layers']
        self.img_size = to_2tuple(img_size)

        # Set patch embedding
        _patch_cfg = dict(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size
        num_patches = self.patch_resolution[0] * self.patch_resolution[1]

        # Set cls token
        if output_cls_token:
            assert with_cls_token is True, f'with_cls_token must be True if ' \
                f'set output_cls_token to True, but got {with_cls_token}'
        self.with_cls_token = with_cls_token
        self.output_cls_token = output_cls_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

        # Set position embedding
        self.interpolate_mode = interpolate_mode
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + self.num_extra_tokens,
                        self.embed_dims))
        # Resize checkpoint position embeddings on load when shapes differ.
        self._register_load_state_dict_pre_hook(self._prepare_pos_embed)

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        # Normalize into a fresh list: rewriting negative indices in place
        # would fail for tuples and would mutate the caller's list.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_layers + index
            # ``forward`` enumerates layers 0 .. num_layers - 1, so an index
            # equal to num_layers would silently produce no output.
            assert 0 <= out_indices[i] < self.num_layers, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        # stochastic depth decay rule
        dpr = np.linspace(0, drop_path_rate, self.num_layers)

        self.layers = ModuleList()
        if isinstance(layer_cfgs, dict):
            layer_cfgs = [layer_cfgs] * self.num_layers
        for i in range(self.num_layers):
            _layer_cfg = dict(
                embed_dims=self.embed_dims,
                num_heads=self.arch_settings['num_heads'],
                feedforward_channels=self.
                arch_settings['feedforward_channels'],
                drop_rate=drop_rate,
                drop_path_rate=dpr[i],
                qkv_bias=qkv_bias,
                norm_cfg=norm_cfg)
            _layer_cfg.update(layer_cfgs[i])
            if beit_style:
                _layer_cfg.update(
                    dict(
                        layer_scale_init_value=layer_scale_init_value,
                        window_size=self.patch_resolution))
                # BEiT layers take ``bias`` instead of ``qkv_bias``.
                _layer_cfg.pop('qkv_bias')
                self.layers.append(BEiTTransformerEncoderLayer(**_layer_cfg))
            else:
                self.layers.append(TransformerEncoderLayer(**_layer_cfg))

        self.frozen_stages = frozen_stages
        self.final_norm = final_norm
        if final_norm:
            self.norm1_name, norm1 = build_norm_layer(
                norm_cfg, self.embed_dims, postfix=1)
            self.add_module(self.norm1_name, norm1)

        self.avg_token = avg_token
        if avg_token:
            self.norm2_name, norm2 = build_norm_layer(
                norm_cfg, self.embed_dims, postfix=2)
            self.add_module(self.norm2_name, norm2)

        # freeze stages only when self.frozen_stages > 0
        if self.frozen_stages > 0:
            self._freeze_stages()

    @property
    def norm1(self):
        """Final norm layer (only exists when ``final_norm`` is True)."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """Averaged-token norm layer (only exists when ``avg_token``)."""
        return getattr(self, self.norm2_name)

    def init_weights(self):
        """Initialize weights; ``pos_embed`` is truncated-normal initialized
        unless the init config requests a pretrained checkpoint."""
        super(VisionTransformer, self).init_weights()

        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg.get('type') == 'Pretrained'):
            trunc_normal_(self.pos_embed, std=0.02)

    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
        """State-dict pre-hook that resizes a checkpoint's ``pos_embed``.

        When the stored position embedding has a different shape from this
        model's, it is interpolated to ``self.patch_embed.init_out_size``.
        The checkpoint's patch grid is assumed to be square (its side is
        recovered with a square root).
        """
        name = prefix + 'pos_embed'
        if name not in state_dict:
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if self.pos_embed.shape != ckpt_pos_embed_shape:
            from mmengine.logging import MMLogger
            logger = MMLogger.get_current_instance()
            logger.info(
                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
                f'to {self.pos_embed.shape}.')

            ckpt_pos_embed_shape = to_2tuple(
                int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
            pos_embed_shape = self.patch_embed.init_out_size

            state_dict[name] = resize_pos_embed(state_dict[name],
                                                ckpt_pos_embed_shape,
                                                pos_embed_shape,
                                                self.interpolate_mode,
                                                self.num_extra_tokens)

    @staticmethod
    def resize_pos_embed(*args, **kwargs):
        """Interface for backward-compatibility."""
        return resize_pos_embed(*args, **kwargs)

    def _freeze_stages(self):
        """Stop gradients and set eval mode for the first
        ``frozen_stages`` layers and the embeddings before them."""
        # freeze position embedding
        self.pos_embed.requires_grad = False
        # set dropout to eval model
        self.drop_after_pos.eval()
        # freeze patch embedding
        self.patch_embed.eval()
        for param in self.patch_embed.parameters():
            param.requires_grad = False
        # freeze cls_token
        self.cls_token.requires_grad = False
        # freeze layers
        for i in range(1, self.frozen_stages + 1):
            m = self.layers[i - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False
        # freeze the last layer norm
        if self.frozen_stages == len(self.layers) and self.final_norm:
            self.norm1.eval()
            for param in self.norm1.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Encode images and collect outputs of the selected layers.

        Returns:
            tuple: One entry per index in ``out_indices``. Each entry is the
            patch-token feature map (permuted to channels-first), or the
            normalized mean patch token when ``avg_token`` is True. When
            ``output_cls_token`` is True, each entry is instead the list
            ``[patch_token, cls_token]``.
        """
        B = x.shape[0]
        x, patch_resolution = self.patch_embed(x)

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)

        if not self.with_cls_token:
            # Remove class token for transformer encoder input
            x = x[:, 1:]

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)

            if i == len(self.layers) - 1 and self.final_norm:
                x = self.norm1(x)

            if i in self.out_indices:
                B, _, C = x.shape
                if self.with_cls_token:
                    patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
                    patch_token = patch_token.permute(0, 3, 1, 2)
                    cls_token = x[:, 0]
                else:
                    patch_token = x.reshape(B, *patch_resolution, C)
                    patch_token = patch_token.permute(0, 3, 1, 2)
                    cls_token = None
                if self.avg_token:
                    patch_token = patch_token.permute(0, 2, 3, 1)
                    patch_token = patch_token.reshape(
                        B, patch_resolution[0] * patch_resolution[1],
                        C).mean(dim=1)
                    patch_token = self.norm2(patch_token)
                if self.output_cls_token:
                    out = [patch_token, cls_token]
                else:
                    out = patch_token
                outs.append(out)

        return tuple(outs)