# Copyright (c) OpenMMLab. All rights reserved.
import math
from functools import partial
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks import ConvModule, DropPath
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model import BaseModule, Sequential
from mmengine.model.weight_init import trunc_normal_
from mmengine.utils import digit_version

from mmpretrain.registry import MODELS
from ..utils import build_norm_layer, to_2tuple
from .base_backbone import BaseBackbone

# ``torch.div(..., rounding_mode='floor')`` is used from torch 1.8.0 on;
# older versions fall back to ``torch.floor_divide``.
if digit_version(torch.__version__) < digit_version('1.8.0'):
    floor_div = torch.floor_divide
else:
    floor_div = partial(torch.div, rounding_mode='floor')


class ClassAttntion(BaseModule):
    """Class Attention Module.

    A PyTorch implementation of Class Attention Module introduced by:
    `Going deeper with Image Transformers <https://arxiv.org/abs/2103.17239>`_

    taken from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications to do CA

    Only the first token of the input (the cls token) is used as query; all
    tokens are used as keys and values, so the output has a single token.

    Args:
        dim (int): The feature dimension.
        num_heads (int): Parallel attention heads. Defaults to 8.
        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
        attn_drop (float): The drop out rate for attention output weights.
            Defaults to 0.
        proj_drop (float): The drop out rate for linear output weights.
            Defaults to 0.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """  # noqa: E501

    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 qkv_bias: bool = False,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 init_cfg=None):

        super(ClassAttntion, self).__init__(init_cfg=init_cfg)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend from the first token to all tokens.

        Args:
            x (torch.Tensor): Input of shape (B, N, C); token 0 is used as
                the query.

        Returns:
            torch.Tensor: The updated first token, of shape (B, 1, C).
        """
        B, N, C = x.shape
        head_dim = C // self.num_heads

        # We only need to calculate query of cls token.
        q = self.q(x[:, 0]).unsqueeze(1)
        q = q.reshape(B, 1, self.num_heads, head_dim).permute(0, 2, 1, 3)
        k = self.k(x).reshape(B, N, self.num_heads,
                              head_dim).permute(0, 2, 1, 3)

        q = q * self.scale
        v = self.v(x).reshape(B, N, self.num_heads,
                              head_dim).permute(0, 2, 1, 3)

        # (B, heads, 1, head_dim) @ (B, heads, head_dim, N) -> (B, heads, 1, N)
        attn = (q @ k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C)
        x_cls = self.proj(x_cls)
        x_cls = self.proj_drop(x_cls)

        return x_cls


class PositionalEncodingFourier(BaseModule):
    """Positional Encoding using a fourier kernel.

    A PyTorch implementation of Positional Encoding relying on
    a fourier kernel introduced by:
    `Attention is all you Need <https://arxiv.org/abs/1706.03762>`_

    Based on the `official XCiT code
    <https://github.com/facebookresearch/xcit>`_

    Args:
        hidden_dim (int): The hidden feature dimension. Defaults to 32.
        dim (int): The output feature dimension. Defaults to 768.
        temperature (int): A control variable for position encoding.
            Defaults to 10000.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 hidden_dim: int = 32,
                 dim: int = 768,
                 temperature: int = 10000,
                 init_cfg=None):
        super(PositionalEncodingFourier, self).__init__(init_cfg=init_cfg)

        # 1x1 conv projecting the 2 * hidden_dim sin/cos channels to ``dim``.
        self.token_projection = ConvModule(
            in_channels=hidden_dim * 2,
            out_channels=dim,
            kernel_size=1,
            conv_cfg=None,
            norm_cfg=None,
            act_cfg=None)
        self.scale = 2 * math.pi
        self.temperature = temperature
        self.hidden_dim = hidden_dim
        self.dim = dim
        self.eps = 1e-6

    def forward(self, B: int, H: int, W: int):
        """Build the positional encoding for a ``H x W`` grid.

        Args:
            B (int): Batch size the encoding is repeated to.
            H (int): Grid height.
            W (int): Grid width.

        Returns:
            torch.Tensor: Encoding of shape (B, C, H, W).
        """
        device = self.token_projection.conv.weight.device

        # 1-based row / column indices, each expanded to (1, H, W).
        y_embed = torch.arange(
            1, H + 1, device=device).unsqueeze(1).repeat(1, 1, W).float()
        x_embed = torch.arange(1, W + 1, device=device).repeat(1, H, 1).float()
        # Normalize by the last row / column so values lie in (0, 2 * pi].
        y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale
        x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale

        dim_t = torch.arange(self.hidden_dim, device=device).float()
        dim_t = floor_div(dim_t, 2)
        dim_t = self.temperature**(2 * dim_t / self.hidden_dim)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sin of even channels with cos of odd channels.
        pos_x = torch.stack(
            [pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()],
            dim=4).flatten(3)
        pos_y = torch.stack(
            [pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()],
            dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        pos = self.token_projection(pos)

        return pos.repeat(B, 1, 1, 1)  # (B, C, H, W)


class ConvPatchEmbed(BaseModule):
    """Patch Embedding using multiple convolution layers.

    Args:
        img_size (int, tuple): input image size.
            Defaults to 224, means the size is 224*224.
        patch_size (int): The patch size in conv patch embedding.
            Must be 16 or 8. Defaults to 16.
        in_channels (int): The input channels of this module.
            Defaults to 3.
        embed_dims (int): The feature dimension. Defaults to 768.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.

    Raises:
        ValueError: If ``patch_size`` is neither 16 nor 8.
    """

    def __init__(self,
                 img_size: Union[int, tuple] = 224,
                 patch_size: int = 16,
                 in_channels: int = 3,
                 embed_dims: int = 768,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super(ConvPatchEmbed, self).__init__(init_cfg=init_cfg)
        img_size = to_2tuple(img_size)
        num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        # Every conv is 3x3 with stride 2, i.e. halves the resolution.
        conv = partial(
            ConvModule,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
        )

        layer = []
        if patch_size == 16:
            # Four stride-2 convs in total -> 16x downsampling.
            layer.append(
                conv(in_channels=in_channels, out_channels=embed_dims // 8))
            layer.append(
                conv(
                    in_channels=embed_dims // 8,
                    out_channels=embed_dims // 4))
        elif patch_size == 8:
            # Three stride-2 convs in total -> 8x downsampling.
            layer.append(
                conv(in_channels=in_channels, out_channels=embed_dims // 4))
        else:
            raise ValueError('For patch embedding, the patch size must be 16 '
                             f'or 8, but get patch size {self.patch_size}.')

        layer.append(
            conv(in_channels=embed_dims // 4, out_channels=embed_dims // 2))
        # The last conv has no activation.
        layer.append(
            conv(
                in_channels=embed_dims // 2,
                out_channels=embed_dims,
                act_cfg=None,
            ))

        self.proj = Sequential(*layer)

    def forward(self, x: torch.Tensor):
        """Embed an image batch into a token sequence.

        Args:
            x (torch.Tensor): Input passed to the conv stack.

        Returns:
            tuple: The tokens of shape (B, N, C) and the patch resolution
            ``(Hp, Wp)`` with ``N = Hp * Wp``.
        """
        x = self.proj(x)
        Hp, Wp = x.shape[2], x.shape[3]
        x = x.flatten(2).transpose(1, 2)  # (B, N, C)
        return x, (Hp, Wp)


class ClassAttentionBlock(BaseModule):
    """Transformer block using Class Attention.

    Args:
        dim (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        mlp_ratio (float): The hidden dimension ratio for FFN.
            Defaults to 4.
        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
        drop (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop (float): The drop out rate for attention output weights.
            Defaults to 0.
        drop_path (float): Stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): The initial value for layer scale.
            If not positive, no learnable layer scale is used.
            Defaults to 1.
        tokens_norm (bool): Whether to normalize all tokens or just the
            cls_token in the CA. Defaults to False.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = False,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 layer_scale_init_value=1.,
                 tokens_norm=False,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):

        super(ClassAttentionBlock, self).__init__(init_cfg=init_cfg)

        self.norm1 = build_norm_layer(norm_cfg, dim)

        self.attn = ClassAttntion(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = build_norm_layer(norm_cfg, dim)

        self.ffn = FFN(
            embed_dims=dim,
            feedforward_channels=int(dim * mlp_ratio),
            act_cfg=act_cfg,
            ffn_drop=drop,
        )

        if layer_scale_init_value > 0:
            self.gamma1 = nn.Parameter(layer_scale_init_value *
                                       torch.ones(dim))
            self.gamma2 = nn.Parameter(layer_scale_init_value *
                                       torch.ones(dim))
        else:
            # Plain scalars: the residual branches are left unscaled.
            self.gamma1, self.gamma2 = 1.0, 1.0

        # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721  # noqa: E501
        self.tokens_norm = tokens_norm

    def forward(self, x):
        """Update the cls token (token 0) of ``x``.

        Args:
            x (torch.Tensor): Tokens of shape (B, N, C) with the cls token
                first.

        Returns:
            torch.Tensor: Tensor of the same shape as ``x``.
        """
        x_norm1 = self.norm1(x)
        # Attention only yields a new cls token; the remaining (normalized)
        # tokens are concatenated back unchanged.
        x_attn = torch.cat([self.attn(x_norm1), x_norm1[:, 1:]], dim=1)
        x = x + self.drop_path(self.gamma1 * x_attn)
        if self.tokens_norm:
            x = self.norm2(x)
        else:
            x = torch.cat([self.norm2(x[:, 0:1]), x[:, 1:]], dim=1)
        x_res = x
        cls_token = x[:, 0:1]
        # ``identity=0`` makes the FFN add 0 instead of its input; the
        # residual is added explicitly below.
        cls_token = self.gamma2 * self.ffn(cls_token, identity=0)
        x = torch.cat([cls_token, x[:, 1:]], dim=1)
        x = x_res + self.drop_path(x)
        return x


class LPI(BaseModule):
    """Local Patch Interaction module.

    A PyTorch implementation of Local Patch Interaction module
    as in XCiT introduced by `XCiT: Cross-Covariance Image Transformers
    <https://arxiv.org/abs/2106.09681>`_

    Local Patch Interaction module that allows explicit communication between
    tokens in 3x3 windows to augment the implicit communication performed by
    the block diagonal scatter attention. Implemented using 2 layers of
    separable 3x3 convolutions with GeLU and BatchNorm2d

    Args:
        in_features (int): The input channels.
        out_features (int, optional): The output channels. Defaults to None,
            which means the same as ``in_features``.
        kernel_size (int): The kernel_size in ConvModule. Defaults to 3.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_features: int,
                 out_features: Optional[int] = None,
                 kernel_size: int = 3,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super(LPI, self).__init__(init_cfg=init_cfg)

        out_features = out_features or in_features
        padding = kernel_size // 2

        # Depthwise conv followed by activation and then normalization.
        self.conv1 = ConvModule(
            in_channels=in_features,
            out_channels=in_features,
            kernel_size=kernel_size,
            padding=padding,
            groups=in_features,
            bias=True,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            order=('conv', 'act', 'norm'))

        self.conv2 = ConvModule(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=kernel_size,
            padding=padding,
            groups=out_features,
            norm_cfg=None,
            act_cfg=None)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """Apply the two convolutions on the tokens arranged as a 2D map.

        Args:
            x (torch.Tensor): Tokens of shape (B, N, C) with ``N = H * W``.
            H (int): Height of the token grid.
            W (int): Width of the token grid.

        Returns:
            torch.Tensor: Tokens of shape (B, N, C).
        """
        B, N, C = x.shape
        # (B, N, C) -> (B, C, H, W)
        x = x.permute(0, 2, 1).reshape(B, C, H, W)
        x = self.conv1(x)
        x = self.conv2(x)
        # (B, C, H, W) -> (B, N, C)
        x = x.reshape(B, C, N).permute(0, 2, 1)
        return x


class XCA(BaseModule):
    r"""Cross-Covariance Attention module.

    A PyTorch implementation of Cross-Covariance Attention module
    as in XCiT introduced by `XCiT: Cross-Covariance Image Transformers
    <https://arxiv.org/abs/2106.09681>`_

    In Cross-Covariance Attention (XCA), the channels are updated using a
    weighted sum. The weights are obtained from the (softmax normalized)
    Cross-covariance matrix :math:`(Q^T \cdot K \in d_h \times d_h)`

    Args:
        dim (int): The feature dimension.
        num_heads (int): Parallel attention heads. Defaults to 8.
        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
        attn_drop (float): The drop out rate for attention output weights.
            Defaults to 0.
        proj_drop (float): The drop out rate for linear output weights.
            Defaults to 0.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 qkv_bias: bool = False,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 init_cfg=None):
        super(XCA, self).__init__(init_cfg=init_cfg)
        self.num_heads = num_heads
        # One learnable temperature per head.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply cross-covariance attention.

        Args:
            x (torch.Tensor): Tokens of shape (B, N, C).

        Returns:
            torch.Tensor: Tokens of shape (B, N, C).
        """
        B, N, C = x.shape
        # (qkv, B, num_heads, channels per head, N)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 4, 1)
        q, k, v = qkv.unbind(0)

        # Paper section 3.2 l2-Normalization and temperature scaling
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # (B, num_heads, C', N) -> (B, N, num_heads, C') -> (B, N C)
        x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class XCABlock(BaseModule):
    """Transformer block using XCA.

    Args:
        dim (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        mlp_ratio (float): The hidden dimension ratio for FFNs.
            Defaults to 4.
        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
        drop (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop (float): The drop out rate for attention output weights.
            Defaults to 0.
        drop_path (float): Stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): The initial value for layer scale.
            Defaults to 1.
        bn_norm_cfg (dict): Config dict for batchnorm in LPI and
            ConvPatchEmbed. Defaults to ``dict(type='BN')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict | list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = False,
                 drop: float = 0.,
                 attn_drop: float = 0.,
                 drop_path: float = 0.,
                 layer_scale_init_value: float = 1.,
                 bn_norm_cfg=dict(type='BN'),
                 norm_cfg=dict(type='LN', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super(XCABlock, self).__init__(init_cfg=init_cfg)

        self.norm1 = build_norm_layer(norm_cfg, dim)
        self.attn = XCA(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

        self.norm3 = build_norm_layer(norm_cfg, dim)
        self.local_mp = LPI(
            in_features=dim,
            norm_cfg=bn_norm_cfg,
            act_cfg=act_cfg,
        )

        self.norm2 = build_norm_layer(norm_cfg, dim)
        self.ffn = FFN(
            embed_dims=dim,
            feedforward_channels=int(dim * mlp_ratio),
            act_cfg=act_cfg,
            ffn_drop=drop,
        )

        # Learnable per-channel layer scales, one per residual branch.
        self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
        self.gamma3 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
        self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones(dim))

    def forward(self, x, H: int, W: int):
        """Run the XCA, LPI and FFN residual branches in that order.

        Args:
            x (torch.Tensor): Tokens of shape (B, N, C) with ``N = H * W``.
            H (int): Height of the token grid.
            W (int): Width of the token grid.

        Returns:
            torch.Tensor: Tokens of shape (B, N, C).
        """
        x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
        # NOTE official code has 3 then 2, so keeping it the same to be
        # consistent with loaded weights See
        # https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721  # noqa: E501
        x = x + self.drop_path(
            self.gamma3 * self.local_mp(self.norm3(x), H, W))
        x = x + self.drop_path(
            self.gamma2 * self.ffn(self.norm2(x), identity=0))
        return x


@MODELS.register_module()
class XCiT(BaseBackbone):
    """XCiT backbone.

    A PyTorch implementation of XCiT backbone introduced by:
    `XCiT: Cross-Covariance Image Transformers
    <https://arxiv.org/abs/2106.09681>`_

    Args:
        img_size (int, tuple): Input image size. Defaults to 224.
        patch_size (int): Patch size. Defaults to 16.
        in_channels (int): Number of input channels. Defaults to 3.
        embed_dims (int): Embedding dimension. Defaults to 768.
        depth (int): depth of vision transformer. Defaults to 12.
        cls_attn_layers (int): Depth of Class attention layers.
            Defaults to 2.
        num_heads (int): Number of attention heads. Defaults to 12.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            Defaults to 4.
        qkv_bias (bool): enable bias for qkv if True. Defaults to True.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention output
            weights. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        use_pos_embed (bool): Whether to use positional encoding.
            Defaults to True.
        layer_scale_init_value (float): The initial value for layer scale.
            Defaults to 1.
        tokens_norm (bool): Whether to normalize all tokens or just the
            cls_token in the CA. Defaults to False.
        out_type (str): The type of output features. Choose from:

            - ``"cls_token"``: The class token, of shape (B, C). Only
              available for outputs of the class attention layers.
            - ``"featmap"``: The patch tokens reshaped to (B, C, H, W).
            - ``"avg_featmap"``: The patch tokens averaged over the token
              dimension, of shape (B, C).
            - ``"raw"``: The token sequence as is, of shape (B, L, C).

            Defaults to ``"cls_token"``.
        out_indices (Sequence[int]): Output from which layers.
            Defaults to (-1, ).
        final_norm (bool): Whether to apply a norm layer after the last
            class attention layer. Defaults to True.
        frozen_stages (int): Layers to be frozen (all param fixed), and 0
            means to freeze the stem stage. Defaults to -1, which means
            not freeze any parameters.
        bn_norm_cfg (dict): Config dict for the batch norm layers in LPI and
            ConvPatchEmbed. Defaults to ``dict(type='BN')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict | list[dict], optional): Initialization config dict.

    Raises:
        ValueError: If ``patch_size`` does not divide ``img_size`` evenly or
            ``frozen_stages`` is larger than the number of layers plus one.
    """

    def __init__(self,
                 img_size: Union[int, tuple] = 224,
                 patch_size: int = 16,
                 in_channels: int = 3,
                 embed_dims: int = 768,
                 depth: int = 12,
                 cls_attn_layers: int = 2,
                 num_heads: int = 12,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = True,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 use_pos_embed: bool = True,
                 layer_scale_init_value: float = 1.,
                 tokens_norm: bool = False,
                 out_type: str = 'cls_token',
                 out_indices: Sequence[int] = (-1, ),
                 final_norm: bool = True,
                 frozen_stages: int = -1,
                 bn_norm_cfg=dict(type='BN'),
                 norm_cfg=dict(type='LN', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 init_cfg=dict(type='TruncNormal', layer='Linear')):
        super(XCiT, self).__init__(init_cfg=init_cfg)

        img_size = to_2tuple(img_size)
        if (img_size[0] % patch_size != 0) or (img_size[1] % patch_size != 0):
            raise ValueError(f'`patch_size` ({patch_size}) should divide '
                             f'the image shape ({img_size}) evenly.')

        self.embed_dims = embed_dims

        assert out_type in ('raw', 'featmap', 'avg_featmap', 'cls_token')
        self.out_type = out_type

        self.patch_embed = ConvPatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dims=embed_dims,
            norm_cfg=bn_norm_cfg,
            act_cfg=act_cfg,
        )

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
        self.use_pos_embed = use_pos_embed
        if use_pos_embed:
            self.pos_embed = PositionalEncodingFourier(dim=embed_dims)
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.xca_layers = nn.ModuleList()
        self.ca_layers = nn.ModuleList()
        self.num_layers = depth + cls_attn_layers

        for _ in range(depth):
            self.xca_layers.append(
                XCABlock(
                    dim=embed_dims,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=drop_path_rate,
                    bn_norm_cfg=bn_norm_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    layer_scale_init_value=layer_scale_init_value,
                ))

        for _ in range(cls_attn_layers):
            self.ca_layers.append(
                ClassAttentionBlock(
                    dim=embed_dims,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    layer_scale_init_value=layer_scale_init_value,
                    tokens_norm=tokens_norm,
                ))

        # Remember the flag: ``self.norm`` only exists when it is True, so
        # ``forward`` and ``_freeze_stages`` must check it before using it.
        self.final_norm = final_norm
        if final_norm:
            self.norm = build_norm_layer(norm_cfg, embed_dims)

        # Transform out_indices
        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                # Negative indices count from the last layer.
                out_indices[i] = self.num_layers + index
            assert 0 <= out_indices[i] <= self.num_layers, \
                f'Invalid out_indices {index}.'
        self.out_indices = out_indices

        if frozen_stages > self.num_layers + 1:
            raise ValueError('frozen_stages must be no more than '
                             f'{self.num_layers + 1} but get {frozen_stages}')
        self.frozen_stages = frozen_stages

    def init_weights(self):
        """Initialize weights; also initialize ``cls_token`` unless the
        weights come from a ``Pretrained`` init config."""
        super().init_weights()

        # ``init_cfg`` may also be a list of dicts, which cannot be indexed
        # by ``'type'``; only a single dict can describe ``Pretrained``.
        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            return

        trunc_normal_(self.cls_token, std=.02)

    def _freeze_stages(self):
        """Put the frozen stages in eval mode and stop their gradients."""
        if self.frozen_stages < 0:
            return

        # freeze position embedding
        if self.use_pos_embed:
            self.pos_embed.eval()
            for param in self.pos_embed.parameters():
                param.requires_grad = False
        # freeze patch embedding
        self.patch_embed.eval()
        for param in self.patch_embed.parameters():
            param.requires_grad = False
        # set dropout to eval model
        self.pos_drop.eval()
        # freeze cls_token, only use in self.Clslayers
        if self.frozen_stages > len(self.xca_layers):
            self.cls_token.requires_grad = False
        # freeze layers
        # NOTE(review): stage ``i`` maps to block ``i - 1`` and the loop stops
        # before ``frozen_stages``, so ``frozen_stages=k`` freezes the first
        # ``k - 1`` blocks. Kept as is to not change what existing configs
        # freeze -- confirm whether freezing ``k`` blocks was intended.
        for i in range(1, self.frozen_stages):
            if i <= len(self.xca_layers):
                m = self.xca_layers[i - 1]
            else:
                m = self.ca_layers[i - len(self.xca_layers) - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False
        # freeze the last layer norm if all_stages are frozen.
        # ``>=`` also covers ``frozen_stages == num_layers + 1``, which the
        # constructor accepts and which freezes every block above.
        num_layers = len(self.xca_layers) + len(self.ca_layers)
        if self.final_norm and self.frozen_stages >= num_layers:
            self.norm.eval()
            for param in self.norm.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Extract features from the layers listed in ``out_indices``.

        Args:
            x (torch.Tensor): Input passed to the patch embedding.

        Returns:
            tuple: One output per selected layer, formatted according to
            ``out_type``.

        Raises:
            ValueError: If ``out_type`` is ``'cls_token'`` and an XCA layer
                (which has no cls token) is selected by ``out_indices``.
        """
        outs = []
        B = x.shape[0]
        # x is (B, N, C). (Hp, Wp) is the patch resolution
        x, (Hp, Wp) = self.patch_embed(x)

        if self.use_pos_embed:
            # (B, C, Hp, Wp) -> (B, C, N) -> (B, N, C)
            pos_encoding = self.pos_embed(B, Hp, Wp)
            x = x + pos_encoding.reshape(B, -1, x.size(1)).permute(0, 2, 1)
        x = self.pos_drop(x)

        for i, layer in enumerate(self.xca_layers):
            x = layer(x, Hp, Wp)
            if i in self.out_indices:
                outs.append(self._format_output(x, (Hp, Wp), False))

        # The cls token is only prepended for the class attention layers.
        x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)

        last_ca_index = len(self.ca_layers) - 1
        for i, layer in enumerate(self.ca_layers):
            x = layer(x)
            # ``self.norm`` is only built when ``final_norm`` is True.
            if self.final_norm and i == last_ca_index:
                x = self.norm(x)
            if i + len(self.xca_layers) in self.out_indices:
                outs.append(self._format_output(x, (Hp, Wp), True))

        return tuple(outs)

    def _format_output(self, x, hw, with_cls_token: bool):
        """Format the tokens of one layer according to ``self.out_type``.

        Args:
            x (torch.Tensor): Tokens of shape (B, L, C).
            hw (tuple): The patch resolution ``(H, W)``.
            with_cls_token (bool): Whether the first token of ``x`` is the
                cls token.

        Returns:
            torch.Tensor: The formatted output.

        Raises:
            ValueError: If the cls token is requested but ``x`` has none.
        """
        if self.out_type == 'raw':
            return x
        if self.out_type == 'cls_token':
            if not with_cls_token:
                raise ValueError(
                    'Cannot output cls_token since there is no cls_token.')
            return x[:, 0]

        patch_token = x[:, 1:] if with_cls_token else x
        if self.out_type == 'featmap':
            B = x.size(0)
            # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
            return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
        if self.out_type == 'avg_featmap':
            return patch_token.mean(dim=1)

    def train(self, mode=True):
        """Switch the training mode and then re-apply the frozen stages."""
        super().train(mode)
        self._freeze_stages()