# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import (build_activation_layer, build_conv_layer,
                      build_norm_layer, kaiming_init)

from ..builder import BACKBONES
from ..utils import to_2tuple
from .base_backbone import BaseBackbone


class FFN(nn.Module):
    """Implements feed-forward networks (FFNs) with residual connection.

    Args:
        embed_dims (int): The feature dimension. Same as `MultiheadAttention`.
        feedforward_channels (int): The hidden dimension of FFNs.
        num_fcs (int, optional): The number of fully-connected layers in
            FFNs. Defaults to 2.
        act_cfg (dict, optional): The activation config for FFNs.
        dropout (float, optional): Probability of an element to be zeroed.
            Default 0.0.
        add_residual (bool, optional): Whether to add the residual
            connection. Defaults to True.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 num_fcs=2,
                 act_cfg=dict(type='GELU'),
                 dropout=0.0,
                 add_residual=True):
        super(FFN, self).__init__()
        assert num_fcs >= 2, 'num_fcs should be no less ' \
            f'than 2. got {num_fcs}.'
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)

        layers = nn.ModuleList()
        in_channels = embed_dims
        for _ in range(num_fcs - 1):
            layers.append(
                nn.Sequential(
                    nn.Linear(in_channels, feedforward_channels),
                    self.activate, nn.Dropout(dropout)))
            in_channels = feedforward_channels
        layers.append(nn.Linear(feedforward_channels, embed_dims))
        self.layers = nn.Sequential(*layers)
        self.dropout = nn.Dropout(dropout)
        self.add_residual = add_residual
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # xavier_init(m, distribution='uniform')
                # The bias init differs from our API, therefore the weight
                # and bias are initialized separately.
                # The initialization is synced with ClassyVision.
                nn.init.xavier_normal_(m.weight)
                nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x, residual=None):
        """Forward function for `FFN`."""
        out = self.layers(x)
        if not self.add_residual:
            return out
        if residual is None:
            residual = x
        return residual + self.dropout(out)

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(embed_dims={self.embed_dims}, '
        repr_str += f'feedforward_channels={self.feedforward_channels}, '
        repr_str += f'num_fcs={self.num_fcs}, '
        repr_str += f'act_cfg={self.act_cfg}, '
        repr_str += f'dropout={self.dropout}, '
        repr_str += f'add_residual={self.add_residual})'
        return repr_str

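# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: with the defaults
# (num_fcs=2, add_residual=True) the FFN above is
# Linear -> GELU -> Dropout -> Linear plus a residual shortcut, so the token
# shape [batch_size, num_query, embed_dims] is preserved. The helper name
# below is hypothetical.
# ---------------------------------------------------------------------------
def _example_ffn():
    ffn = FFN(embed_dims=768, feedforward_channels=3072)
    tokens = torch.randn(2, 197, 768)  # ViT-Base: 196 patches + 1 cls token
    out = ffn(tokens)  # residual defaults to the input itself
    assert out.shape == tokens.shape
    return out
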
class MultiheadAttention(nn.Module):
    """A wrapper for torch.nn.MultiheadAttention.

    This module implements MultiheadAttention with residual connection.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads. Same as
            `nn.MultiheadAttention`.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default 0.0.
        proj_drop (float): The drop out rate after attention. Default 0.0.
    """

    def __init__(self, embed_dims, num_heads, attn_drop=0.0, proj_drop=0.0):
        super(MultiheadAttention, self).__init__()
        assert embed_dims % num_heads == 0, 'embed_dims must be ' \
            f'divisible by num_heads. got {embed_dims} and {num_heads}.'

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop)
        self.dropout = nn.Dropout(proj_drop)

    def forward(self,
                x,
                key=None,
                value=None,
                residual=None,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None):
        """Forward function for `MultiheadAttention`.

        Args:
            x (Tensor): The input query with shape [num_query, bs,
                embed_dims]. Same in `nn.MultiheadAttention.forward`.
            key (Tensor): The key tensor with shape [num_key, bs,
                embed_dims]. Same in `nn.MultiheadAttention.forward`.
                Default None. If None, the `query` will be used.
            value (Tensor): The value tensor with the same shape as `key`.
                Same in `nn.MultiheadAttention.forward`. Default None.
                If None, the `key` will be used.
            residual (Tensor): The tensor used for addition, with the
                same shape as `x`. Default None. If None, `x` will be used.
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `x`. Default None. If not None, it will
                be added to `x` before the forward function.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. Default None. If not None, it will be
                added to `key` before the forward function. If None, and
                `query_pos` has the same shape as `key`, then `query_pos`
                will be used for `key_pos`.
            attn_mask (Tensor): ByteTensor mask with shape [num_query,
                num_key]. Same in `nn.MultiheadAttention.forward`.
                Default None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_key].
                Same in `nn.MultiheadAttention.forward`. Default None.

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        query = x
        if key is None:
            key = query
        if value is None:
            value = key
        if residual is None:
            residual = x
        if key_pos is None:
            if query_pos is not None and key is not None:
                if query_pos.shape == key.shape:
                    key_pos = query_pos
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos
        out = self.attn(
            query,
            key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)[0]

        return residual + self.dropout(out)

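# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: the wrapper follows
# nn.MultiheadAttention's [num_query, batch_size, embed_dims] layout and
# returns residual + dropout(attention output), so the query shape is kept.
# The helper name below is hypothetical.
# ---------------------------------------------------------------------------
def _example_multihead_attention():
    attn = MultiheadAttention(embed_dims=768, num_heads=12)
    query = torch.randn(197, 2, 768)  # [num_query, batch_size, embed_dims]
    out = attn(query)  # key, value and residual default to the query
    assert out.shape == query.shape
    return out
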
""" def __init__(self, embed_dims, num_heads, feedforward_channels, attn_drop=0., proj_drop=0., act_cfg=dict(type='GELU'), norm_cfg=dict(type='LN'), num_fcs=2): super(TransformerEncoderLayer, self).__init__() self.norm1_name, norm1 = build_norm_layer( norm_cfg, embed_dims, postfix=1) self.add_module(self.norm1_name, norm1) self.attn = MultiheadAttention( embed_dims, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop) self.norm2_name, norm2 = build_norm_layer( norm_cfg, embed_dims, postfix=2) self.add_module(self.norm2_name, norm2) self.mlp = FFN(embed_dims, feedforward_channels, num_fcs, act_cfg, proj_drop) @property def norm1(self): return getattr(self, self.norm1_name) @property def norm2(self): return getattr(self, self.norm2_name) def forward(self, x): norm_x = self.norm1(x) # Reason for permute: as the shape of input from pretrained weight # from pytorch-image-models is [batch_size, num_query, embed_dim], # but the one from nn.MultiheadAttention is # [num_query, batch_size, embed_dim] x = x.permute(1, 0, 2) norm_x = norm_x.permute(1, 0, 2) x = self.attn(norm_x, residual=x) # Convert the shape back to [batch_size, num_query, embed_dim] in # order to make use of the pretrained weight x = x.permute(1, 0, 2) x = self.mlp(self.norm2(x), residual=x) return x class PatchEmbed(nn.Module): """Image to Patch Embedding. Args: img_size (int | tuple): The size of input image. patch_size (int): The size of one patch in_channels (int): The num of input channels. embed_dim (int): The dimensions of embedding. conv_cfg (dict | None): The config dict for conv layers. Default: None. """ def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768, conv_cfg=None): super(PatchEmbed, self).__init__() if isinstance(img_size, int): img_size = to_2tuple(img_size) elif isinstance(img_size, tuple): if len(img_size) == 1: img_size = to_2tuple(img_size[0]) assert len(img_size) == 2, \ f'The size of image should have length 1 or 2, ' \ f'but got {len(img_size)}' self.img_size = img_size self.patch_size = to_2tuple(patch_size) num_patches = (self.img_size[1] // self.patch_size[1]) * ( self.img_size[0] // self.patch_size[0]) assert num_patches * self.patch_size[0] * self.patch_size[1] == \ self.img_size[0] * self.img_size[1], \ 'The image size H*W must be divisible by patch size' self.num_patches = num_patches # Use conv layer to embed self.projection = build_conv_layer( conv_cfg, in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) self.init_weights() def init_weights(self): # Lecun norm from ClassyVision kaiming_init(self.projection, mode='fan_in', nonlinearity='linear') def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't " \ f'match model ({self.img_size[0]}*{self.img_size[1]}).' # The output size is (B, N, D), where N=H*W/P/P, D is embid_dim x = self.projection(x).flatten(2).transpose(1, 2) return x class HybridEmbed(nn.Module): """CNN Feature Map Embedding. Extract feature map from CNN, flatten, project to embedding dim. 
""" def __init__(self, backbone, img_size=224, feature_size=None, in_channels=3, embed_dim=768, conv_cfg=None): super().__init__() assert isinstance(backbone, nn.Module) if isinstance(img_size, int): img_size = to_2tuple(img_size) elif isinstance(img_size, tuple): if len(img_size) == 1: img_size = to_2tuple(img_size[0]) assert len(img_size) == 2, \ f'The size of image should have length 1 or 2, ' \ f'but got {len(img_size)}' self.img_size = img_size self.backbone = backbone if feature_size is None: with torch.no_grad(): # FIXME this is hacky, but most reliable way of # determining the exact dim of the output feature # map for all networks, the feature metadata has # reliable channel and stride info, but using # stride to calc feature dim requires info about padding of # each stage that isn't captured. training = backbone.training if training: backbone.eval() o = self.backbone( torch.zeros(1, in_channels, img_size[0], img_size[1])) if isinstance(o, (list, tuple)): # last feature if backbone outputs list/tuple of features o = o[-1] feature_size = o.shape[-2:] feature_dim = o.shape[1] backbone.train(training) else: feature_size = to_2tuple(feature_size) if hasattr(self.backbone, 'feature_info'): feature_dim = self.backbone.feature_info.channels()[-1] else: feature_dim = self.backbone.num_features self.num_patches = feature_size[0] * feature_size[1] # Use conv layer to embed self.projection = build_conv_layer( conv_cfg, feature_dim, embed_dim, kernel_size=1, stride=1) self.init_weights() def init_weights(self): # Lecun norm from ClassyVision kaiming_init(self.projection, mode='fan_in', nonlinearity='linear') def forward(self, x): x = self.backbone(x) if isinstance(x, (list, tuple)): # last feature if backbone outputs list/tuple of features x = x[-1] x = self.projection(x).flatten(2).transpose(1, 2) return x @BACKBONES.register_module() class VisionTransformer(BaseBackbone): """Vision Transformer. A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_ Args: num_layers (int): Depth of transformer embed_dim (int): Embedding dimension num_heads (int): Number of attention heads img_size (int | tuple): Input image size patch_size (int | tuple): The patch size in_channels (int): Number of input channels feedforward_channels (int): The hidden dimension for FFNs. drop_rate (float): Probability of an element to be zeroed. Default 0.0. attn_drop (float): The drop out rate for attention layer. Default 0.0. hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module. Default None. norm_cfg norm_cfg (dict): Config dict for normalization layer. Default layer normalization. act_cfg (dict): The activation config for FFNs. Defaults to GELU. num_fcs (int): The number of fully-connected layers for FFNs. Default 2. 
""" def __init__(self, num_layers=12, embed_dim=768, num_heads=12, img_size=224, patch_size=16, in_channels=3, feedforward_channels=3072, drop_rate=0., attn_drop_rate=0., hybrid_backbone=None, norm_cfg=dict(type='LN'), act_cfg=dict(type='GELU'), num_fcs=2): super(VisionTransformer, self).__init__() self.embed_dim = embed_dim if hybrid_backbone is not None: self.patch_embed = HybridEmbed( hybrid_backbone, img_size=img_size, in_channels=in_channels, embed_dim=embed_dim) else: self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter( torch.zeros(1, num_patches + 1, embed_dim)) self.drop_after_pos = nn.Dropout(p=drop_rate) self.layers = nn.ModuleList() for _ in range(num_layers): self.layers.append( TransformerEncoderLayer( embed_dim, num_heads, feedforward_channels, attn_drop=attn_drop_rate, proj_drop=drop_rate, act_cfg=act_cfg, norm_cfg=norm_cfg, num_fcs=num_fcs)) self.norm1_name, norm1 = build_norm_layer( norm_cfg, embed_dim, postfix=1) self.add_module(self.norm1_name, norm1) def init_weights(self): super(VisionTransformer, self).init_weights() nn.init.normal_(self.pos_embed, std=0.02) @property def norm1(self): return getattr(self, self.norm1_name) def forward(self, x): B = x.shape[0] x = self.patch_embed(x) cls_tokens = self.cls_token.expand( B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks x = torch.cat((cls_tokens, x), dim=1) x = x + self.pos_embed x = self.drop_after_pos(x) for layer in self.layers: x = layer(x) x = self.norm1(x)[:, 0] return x