mmpretrain/mmcls/models/backbones/vision_transformer.py

import torch
import torch.nn as nn
from mmcv.cnn import (build_activation_layer, build_conv_layer,
build_norm_layer, kaiming_init)
from ..builder import BACKBONES
from ..utils import to_2tuple
from .base_backbone import BaseBackbone
class FFN(nn.Module):
"""Implements feed-forward networks (FFNs) with residual connection.
Args:
embed_dims (int): The feature dimension. Same as
`MultiheadAttention`.
feedforward_channels (int): The hidden dimension of FFNs.
num_fcs (int, optional): The number of fully-connected layers in
            FFNs. Defaults to 2.
act_cfg (dict, optional): The activation config for FFNs.
dropout (float, optional): Probability of an element to be
zeroed. Default 0.0.
        add_residual (bool, optional): Whether to add the residual
            connection. Defaults to True.
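    Example:
        A minimal usage sketch; the shapes below are illustrative and
        follow the ViT-Base defaults.
        >>> ffn = FFN(embed_dims=768, feedforward_channels=3072)
        >>> x = torch.rand(2, 197, 768)
        >>> ffn(x).shape  # the residual connection keeps the shape
        torch.Size([2, 197, 768])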
"""
def __init__(self,
embed_dims,
feedforward_channels,
num_fcs=2,
act_cfg=dict(type='GELU'),
dropout=0.0,
add_residual=True):
super(FFN, self).__init__()
assert num_fcs >= 2, 'num_fcs should be no less ' \
f'than 2. got {num_fcs}.'
self.embed_dims = embed_dims
self.feedforward_channels = feedforward_channels
self.num_fcs = num_fcs
self.act_cfg = act_cfg
self.activate = build_activation_layer(act_cfg)
layers = nn.ModuleList()
in_channels = embed_dims
for _ in range(num_fcs - 1):
layers.append(
nn.Sequential(
nn.Linear(in_channels, feedforward_channels),
self.activate, nn.Dropout(dropout)))
in_channels = feedforward_channels
layers.append(nn.Linear(feedforward_channels, embed_dims))
self.layers = nn.Sequential(*layers)
self.dropout = nn.Dropout(dropout)
self.add_residual = add_residual
self.init_weights()
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
# xavier_init(m, distribution='uniform')
                # The bias init is different from our API, therefore
                # initialize them separately.
                # The initialization is synced with ClassyVision.
nn.init.xavier_normal_(m.weight)
nn.init.normal_(m.bias, std=1e-6)
def forward(self, x, residual=None):
"""Forward function for `FFN`."""
out = self.layers(x)
if not self.add_residual:
return out
if residual is None:
residual = x
return residual + self.dropout(out)
def __repr__(self):
"""str: a string that describes the module"""
repr_str = self.__class__.__name__
repr_str += f'(embed_dims={self.embed_dims}, '
repr_str += f'feedforward_channels={self.feedforward_channels}, '
repr_str += f'num_fcs={self.num_fcs}, '
repr_str += f'act_cfg={self.act_cfg}, '
repr_str += f'dropout={self.dropout}, '
repr_str += f'add_residual={self.add_residual})'
return repr_str
class MultiheadAttention(nn.Module):
"""A warpper for torch.nn.MultiheadAttention.
This module implements MultiheadAttention with residual connection.
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads. Same as
`nn.MultiheadAttention`.
        attn_drop (float): Dropout rate of the attention weights.
            Default 0.0.
        proj_drop (float): The dropout rate after the attention layer.
            Default 0.0.
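    Example:
        A minimal usage sketch; the input layout follows
        `nn.MultiheadAttention`, i.e. [num_query, bs, embed_dims].
        >>> attn = MultiheadAttention(embed_dims=768, num_heads=12)
        >>> x = torch.rand(197, 2, 768)
        >>> attn(x).shape
        torch.Size([197, 2, 768])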
"""
def __init__(self, embed_dims, num_heads, attn_drop=0.0, proj_drop=0.0):
super(MultiheadAttention, self).__init__()
assert embed_dims % num_heads == 0, 'embed_dims must be ' \
f'divisible by num_heads. got {embed_dims} and {num_heads}.'
self.embed_dims = embed_dims
self.num_heads = num_heads
self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop)
self.dropout = nn.Dropout(proj_drop)
def forward(self,
x,
key=None,
value=None,
residual=None,
query_pos=None,
key_pos=None,
attn_mask=None,
key_padding_mask=None):
"""Forward function for `MultiheadAttention`.
Args:
x (Tensor): The input query with shape [num_query, bs,
embed_dims]. Same in `nn.MultiheadAttention.forward`.
key (Tensor): The key tensor with shape [num_key, bs,
embed_dims]. Same in `nn.MultiheadAttention.forward`.
Default None. If None, the `query` will be used.
value (Tensor): The value tensor with same shape as `key`.
Same in `nn.MultiheadAttention.forward`. Default None.
If None, the `key` will be used.
residual (Tensor): The tensor used for addition, with the
same shape as `x`. Default None. If None, `x` will be used.
query_pos (Tensor): The positional encoding for query, with
the same shape as `x`. Default None. If not None, it will
be added to `x` before forward function.
key_pos (Tensor): The positional encoding for `key`, with the
same shape as `key`. Default None. If not None, it will
be added to `key` before forward function. If None, and
`query_pos` has the same shape as `key`, then `query_pos`
will be used for `key_pos`.
attn_mask (Tensor): ByteTensor mask with shape [num_query,
num_key]. Same in `nn.MultiheadAttention.forward`.
Default None.
key_padding_mask (Tensor): ByteTensor with shape [bs, num_key].
Same in `nn.MultiheadAttention.forward`. Default None.
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
query = x
if key is None:
key = query
if value is None:
value = key
if residual is None:
residual = x
if key_pos is None:
if query_pos is not None and key is not None:
if query_pos.shape == key.shape:
key_pos = query_pos
if query_pos is not None:
query = query + query_pos
if key_pos is not None:
key = key + key_pos
out = self.attn(
query,
key,
value=value,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask)[0]
return residual + self.dropout(out)
class TransformerEncoderLayer(nn.Module):
"""Implements one encoder layer in Vision Transformer.
Args:
embed_dims (int): The feature dimension. Same as `FFN`.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
        attn_drop (float): The dropout rate for the attention layer.
            Default 0.0.
        proj_drop (float): Probability of an element to be zeroed
            after the feed-forward layer. Default 0.0.
        act_cfg (dict): The activation config for FFNs. Default GELU.
norm_cfg (dict): Config dict for normalization layer. Default
layer normalization.
num_fcs (int): The number of fully-connected layers for FFNs.
Default 2.
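    Example:
        A minimal usage sketch; the layer expects inputs of shape
        [batch_size, num_query, embed_dims].
        >>> layer = TransformerEncoderLayer(
        ...     embed_dims=768, num_heads=12, feedforward_channels=3072)
        >>> x = torch.rand(2, 197, 768)
        >>> layer(x).shape
        torch.Size([2, 197, 768])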
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
attn_drop=0.,
proj_drop=0.,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
num_fcs=2):
super(TransformerEncoderLayer, self).__init__()
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self.attn = MultiheadAttention(
embed_dims,
num_heads=num_heads,
attn_drop=attn_drop,
proj_drop=proj_drop)
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, embed_dims, postfix=2)
self.add_module(self.norm2_name, norm2)
self.mlp = FFN(embed_dims, feedforward_channels, num_fcs, act_cfg,
proj_drop)
@property
def norm1(self):
return getattr(self, self.norm1_name)
@property
def norm2(self):
return getattr(self, self.norm2_name)
def forward(self, x):
norm_x = self.norm1(x)
# Reason for permute: as the shape of input from pretrained weight
# from pytorch-image-models is [batch_size, num_query, embed_dim],
# but the one from nn.MultiheadAttention is
# [num_query, batch_size, embed_dim]
x = x.permute(1, 0, 2)
norm_x = norm_x.permute(1, 0, 2)
x = self.attn(norm_x, residual=x)
# Convert the shape back to [batch_size, num_query, embed_dim] in
# order to make use of the pretrained weight
x = x.permute(1, 0, 2)
x = self.mlp(self.norm2(x), residual=x)
return x
class PatchEmbed(nn.Module):
"""Image to Patch Embedding.
Args:
img_size (int | tuple): The size of input image.
        patch_size (int): The size of one patch.
in_channels (int): The num of input channels.
embed_dim (int): The dimensions of embedding.
conv_cfg (dict | None): The config dict for conv layers.
Default: None.
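    Example:
        A minimal usage sketch; with the default 224x224 input and 16x16
        patches, 14 * 14 = 196 patch tokens are produced.
        >>> patch_embed = PatchEmbed(img_size=224, patch_size=16, embed_dim=768)
        >>> x = torch.rand(1, 3, 224, 224)
        >>> patch_embed(x).shape
        torch.Size([1, 196, 768])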
"""
def __init__(self,
img_size=224,
patch_size=16,
in_channels=3,
embed_dim=768,
conv_cfg=None):
super(PatchEmbed, self).__init__()
if isinstance(img_size, int):
img_size = to_2tuple(img_size)
elif isinstance(img_size, tuple):
if len(img_size) == 1:
img_size = to_2tuple(img_size[0])
assert len(img_size) == 2, \
f'The size of image should have length 1 or 2, ' \
f'but got {len(img_size)}'
self.img_size = img_size
self.patch_size = to_2tuple(patch_size)
num_patches = (self.img_size[1] // self.patch_size[1]) * (
self.img_size[0] // self.patch_size[0])
assert num_patches * self.patch_size[0] * self.patch_size[1] == \
self.img_size[0] * self.img_size[1], \
'The image size H*W must be divisible by patch size'
self.num_patches = num_patches
# Use conv layer to embed
self.projection = build_conv_layer(
conv_cfg,
in_channels,
embed_dim,
kernel_size=patch_size,
stride=patch_size)
self.init_weights()
def init_weights(self):
# Lecun norm from ClassyVision
kaiming_init(self.projection, mode='fan_in', nonlinearity='linear')
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't " \
f'match model ({self.img_size[0]}*{self.img_size[1]}).'
        # The output size is (B, N, D), where N = H * W / P / P and D is embed_dim
x = self.projection(x).flatten(2).transpose(1, 2)
return x
class HybridEmbed(nn.Module):
"""CNN Feature Map Embedding.
Extract feature map from CNN, flatten, project to embedding dim.
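    Args:
        backbone (nn.Module): CNN backbone used to extract the feature map.
        img_size (int | tuple): The size of input image. Default: 224.
        feature_size (int | tuple | None): The spatial size of the backbone
            output feature map. If None, it is inferred by forwarding a
            dummy input through the backbone. Default: None.
        in_channels (int): The num of input channels. Default: 3.
        embed_dim (int): The dimensions of embedding. Default: 768.
        conv_cfg (dict | None): The config dict for the projection conv
            layer. Default: None.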
"""
def __init__(self,
backbone,
img_size=224,
feature_size=None,
in_channels=3,
embed_dim=768,
conv_cfg=None):
super().__init__()
assert isinstance(backbone, nn.Module)
if isinstance(img_size, int):
img_size = to_2tuple(img_size)
elif isinstance(img_size, tuple):
if len(img_size) == 1:
img_size = to_2tuple(img_size[0])
assert len(img_size) == 2, \
f'The size of image should have length 1 or 2, ' \
f'but got {len(img_size)}'
self.img_size = img_size
self.backbone = backbone
if feature_size is None:
with torch.no_grad():
# FIXME this is hacky, but most reliable way of
# determining the exact dim of the output feature
# map for all networks, the feature metadata has
# reliable channel and stride info, but using
# stride to calc feature dim requires info about padding of
# each stage that isn't captured.
training = backbone.training
if training:
backbone.eval()
o = self.backbone(
torch.zeros(1, in_channels, img_size[0], img_size[1]))
if isinstance(o, (list, tuple)):
# last feature if backbone outputs list/tuple of features
o = o[-1]
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
if hasattr(self.backbone, 'feature_info'):
feature_dim = self.backbone.feature_info.channels()[-1]
else:
feature_dim = self.backbone.num_features
self.num_patches = feature_size[0] * feature_size[1]
# Use conv layer to embed
self.projection = build_conv_layer(
conv_cfg, feature_dim, embed_dim, kernel_size=1, stride=1)
self.init_weights()
def init_weights(self):
# Lecun norm from ClassyVision
kaiming_init(self.projection, mode='fan_in', nonlinearity='linear')
def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
# last feature if backbone outputs list/tuple of features
x = x[-1]
x = self.projection(x).flatten(2).transpose(1, 2)
return x
@BACKBONES.register_module()
class VisionTransformer(BaseBackbone):
""" Vision Transformer
A PyTorch impl of : `An Image is Worth 16x16 Words:
Transformers for Image Recognition at Scale` -
https://arxiv.org/abs/2010.11929
Args:
num_layers (int): Depth of transformer
embed_dim (int): Embedding dimension
num_heads (int): Number of attention heads
img_size (int | tuple): Input image size
patch_size (int | tuple): The patch size
in_channels (int): Number of input channels
feedforward_channels (int): The hidden dimension for FFNs.
drop_rate (float): Probability of an element to be zeroed.
Default 0.0.
        attn_drop_rate (float): The dropout rate for the attention layer.
            Default 0.0.
        hybrid_backbone (nn.Module): CNN backbone to use in place of the
            PatchEmbed module. Default None.
        norm_cfg (dict): Config dict for normalization layer. Default
            layer normalization.
        act_cfg (dict): The activation config for FFNs. Default GELU.
num_fcs (int): The number of fully-connected layers for FFNs.
Default 2.
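    Example:
        A minimal usage sketch; with the ViT-Base defaults the backbone
        returns the class-token feature of each image.
        >>> model = VisionTransformer()
        >>> x = torch.rand(2, 3, 224, 224)
        >>> model(x).shape
        torch.Size([2, 768])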
"""
def __init__(self,
num_layers=12,
embed_dim=768,
num_heads=12,
img_size=224,
patch_size=16,
in_channels=3,
feedforward_channels=3072,
drop_rate=0.,
attn_drop_rate=0.,
hybrid_backbone=None,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
num_fcs=2):
super(VisionTransformer, self).__init__()
self.embed_dim = embed_dim
if hybrid_backbone is not None:
self.patch_embed = HybridEmbed(
hybrid_backbone,
img_size=img_size,
in_channels=in_channels,
embed_dim=embed_dim)
else:
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
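        # A learnable class token is prepended to the patch tokens, and the
        # learnable position embedding covers the patch tokens plus the
        # class token (hence num_patches + 1 positions).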
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dim))
self.drop_after_pos = nn.Dropout(p=drop_rate)
self.layers = nn.ModuleList()
for _ in range(num_layers):
self.layers.append(
TransformerEncoderLayer(
embed_dim,
num_heads,
feedforward_channels,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
num_fcs=num_fcs))
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, embed_dim, postfix=1)
self.add_module(self.norm1_name, norm1)
self.init_weights()
def init_weights(self):
super(VisionTransformer, self).init_weights()
nn.init.normal_(self.pos_embed, std=0.02)
@property
def norm1(self):
return getattr(self, self.norm1_name)
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(
B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
x = self.drop_after_pos(x)
for layer in self.layers:
x = layer(x)
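        # Apply the final norm and return only the class-token feature.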
x = self.norm1(x)[:, 0]
return x