# Copyright (c) OpenMMLab. All rights reserved. import torch import torch.nn as nn from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer from mmcv.cnn.bricks import DropPath from mmcv.cnn.bricks.transformer import PatchEmbed from mmcv.runner import BaseModule, ModuleList from mmcv.utils.parrots_wrapper import _BatchNorm from ..builder import BACKBONES from .base_backbone import BaseBackbone class MixFFN(BaseModule): """An implementation of MixFFN of VAN. Refer to mmdetection/mmdet/models/backbones/pvt.py. The differences between MixFFN & FFN: 1. Use 1X1 Conv to replace Linear layer. 2. Introduce 3X3 Depth-wise Conv to encode positional information. Args: embed_dims (int): The feature dimension. Same as `MultiheadAttention`. feedforward_channels (int): The hidden dimension of FFNs. act_cfg (dict, optional): The activation config for FFNs. Default: dict(type='GELU'). ffn_drop (float, optional): Probability of an element to be zeroed in FFN. Default 0.0. init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. Default: None. """ def __init__(self, embed_dims, feedforward_channels, act_cfg=dict(type='GELU'), ffn_drop=0., init_cfg=None): super(MixFFN, self).__init__(init_cfg=init_cfg) self.embed_dims = embed_dims self.feedforward_channels = feedforward_channels self.act_cfg = act_cfg self.fc1 = Conv2d( in_channels=embed_dims, out_channels=feedforward_channels, kernel_size=1) self.dwconv = Conv2d( in_channels=feedforward_channels, out_channels=feedforward_channels, kernel_size=3, stride=1, padding=1, bias=True, groups=feedforward_channels) self.act = build_activation_layer(act_cfg) self.fc2 = Conv2d( in_channels=feedforward_channels, out_channels=embed_dims, kernel_size=1) self.drop = nn.Dropout(ffn_drop) def forward(self, x): x = self.fc1(x) x = self.dwconv(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class LKA(BaseModule): """Large Kernel Attention(LKA) of VAN. .. code:: text DW_conv (depth-wise convolution) | | DW_D_conv (depth-wise dilation convolution) | | Transition Convolution (1×1 convolution) Args: embed_dims (int): Number of input channels. init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. Default: None. """ def __init__(self, embed_dims, init_cfg=None): super(LKA, self).__init__(init_cfg=init_cfg) # a spatial local convolution (depth-wise convolution) self.DW_conv = Conv2d( in_channels=embed_dims, out_channels=embed_dims, kernel_size=5, padding=2, groups=embed_dims) # a spatial long-range convolution (depth-wise dilation convolution) self.DW_D_conv = Conv2d( in_channels=embed_dims, out_channels=embed_dims, kernel_size=7, stride=1, padding=9, groups=embed_dims, dilation=3) self.conv1 = Conv2d( in_channels=embed_dims, out_channels=embed_dims, kernel_size=1) def forward(self, x): u = x.clone() attn = self.DW_conv(x) attn = self.DW_D_conv(attn) attn = self.conv1(attn) return u * attn class SpatialAttention(BaseModule): """Basic attention module in VANBloack. Args: embed_dims (int): Number of input channels. act_cfg (dict, optional): The activation config for FFNs. Default: dict(type='GELU'). init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. Default: None. """ def __init__(self, embed_dims, act_cfg=dict(type='GELU'), init_cfg=None): super(SpatialAttention, self).__init__(init_cfg=init_cfg) self.proj_1 = Conv2d( in_channels=embed_dims, out_channels=embed_dims, kernel_size=1) self.activation = build_activation_layer(act_cfg) self.spatial_gating_unit = LKA(embed_dims) self.proj_2 = Conv2d( in_channels=embed_dims, out_channels=embed_dims, kernel_size=1) def forward(self, x): shorcut = x.clone() x = self.proj_1(x) x = self.activation(x) x = self.spatial_gating_unit(x) x = self.proj_2(x) x = x + shorcut return x class VANBlock(BaseModule): """A block of VAN. Args: embed_dims (int): Number of input channels. ffn_ratio (float): The expansion ratio of feedforward network hidden layer channels. Defaults to 4. drop_rate (float): Dropout rate after embedding. Defaults to 0. drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. act_cfg (dict, optional): The activation config for FFNs. Default: dict(type='GELU'). layer_scale_init_value (float): Init value for Layer Scale. Defaults to 1e-2. init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. Default: None. """ def __init__(self, embed_dims, ffn_ratio=4., drop_rate=0., drop_path_rate=0., act_cfg=dict(type='GELU'), norm_cfg=dict(type='BN', eps=1e-5), layer_scale_init_value=1e-2, init_cfg=None): super(VANBlock, self).__init__(init_cfg=init_cfg) self.out_channels = embed_dims self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] self.attn = SpatialAttention(embed_dims, act_cfg=act_cfg) self.drop_path = DropPath( drop_path_rate) if drop_path_rate > 0. else nn.Identity() self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] mlp_hidden_dim = int(embed_dims * ffn_ratio) self.mlp = MixFFN( embed_dims=embed_dims, feedforward_channels=mlp_hidden_dim, act_cfg=act_cfg, ffn_drop=drop_rate) self.layer_scale_1 = nn.Parameter( layer_scale_init_value * torch.ones((embed_dims)), requires_grad=True) if layer_scale_init_value > 0 else None self.layer_scale_2 = nn.Parameter( layer_scale_init_value * torch.ones((embed_dims)), requires_grad=True) if layer_scale_init_value > 0 else None def forward(self, x): identity = x x = self.norm1(x) x = self.attn(x) if self.layer_scale_1 is not None: x = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * x x = identity + self.drop_path(x) identity = x x = self.norm2(x) x = self.mlp(x) if self.layer_scale_2 is not None: x = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * x x = identity + self.drop_path(x) return x class VANPatchEmbed(PatchEmbed): """Image to Patch Embedding of VAN. The differences between VANPatchEmbed & PatchEmbed: 1. Use BN. 2. Do not use 'flatten' and 'transpose'. """ def __init__(self, *args, norm_cfg=dict(type='BN'), **kwargs): super(VANPatchEmbed, self).__init__(*args, norm_cfg=norm_cfg, **kwargs) def forward(self, x): """ Args: x (Tensor): Has shape (B, C, H, W). In most case, C is 3. Returns: tuple: Contains merged results and its spatial shape. - x (Tensor): Has shape (B, out_h * out_w, embed_dims) - out_size (tuple[int]): Spatial shape of x, arrange as (out_h, out_w). """ if self.adaptive_padding: x = self.adaptive_padding(x) x = self.projection(x) out_size = (x.shape[2], x.shape[3]) if self.norm is not None: x = self.norm(x) return x, out_size @BACKBONES.register_module() class VAN(BaseBackbone): """Visual Attention Network. A PyTorch implement of : `Visual Attention Network `_ Inspiration from https://github.com/Visual-Attention-Network/VAN-Classification Args: arch (str | dict): Visual Attention Network architecture. If use string, choose from 'tiny', 'small', 'base' and 'large'. If use dict, it should have below keys: - **embed_dims** (List[int]): The dimensions of embedding. - **depths** (List[int]): The number of blocks in each stage. - **ffn_ratios** (List[int]): The number of expansion ratio of feedforward network hidden layer channels. Defaults to 'tiny'. patch_sizes (List[int | tuple]): The patch size in patch embeddings. Defaults to [7, 3, 3, 3]. in_channels (int): The num of input channels. Defaults to 3. drop_rate (float): Dropout rate after embedding. Defaults to 0. drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. out_indices (Sequence[int]): Output from which stages. Default: ``(3, )``. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. Defaults to -1. norm_eval (bool): Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: Effect on Batch Norm and its variants only. Defaults to False. norm_cfg (dict): Config dict for normalization layer for all output features. Defaults to ``dict(type='LN')`` block_cfgs (Sequence[dict] | dict): The extra config of each block. Defaults to empty dicts. init_cfg (dict, optional): The Config for initialization. Defaults to None. Examples: >>> from mmcls.models import VAN >>> import torch >>> cfg = dict(arch='tiny') >>> model = VAN(**cfg) >>> inputs = torch.rand(1, 3, 224, 224) >>> outputs = model(inputs) >>> for out in outputs: >>> print(out.size()) (1, 256, 7, 7) """ arch_zoo = { **dict.fromkeys(['t', 'tiny'], {'embed_dims': [32, 64, 160, 256], 'depths': [3, 3, 5, 2], 'ffn_ratios': [8, 8, 4, 4]}), **dict.fromkeys(['s', 'small'], {'embed_dims': [64, 128, 320, 512], 'depths': [2, 2, 4, 2], 'ffn_ratios': [8, 8, 4, 4]}), **dict.fromkeys(['b', 'base'], {'embed_dims': [64, 128, 320, 512], 'depths': [3, 3, 12, 3], 'ffn_ratios': [8, 8, 4, 4]}), **dict.fromkeys(['l', 'large'], {'embed_dims': [64, 128, 320, 512], 'depths': [3, 5, 27, 3], 'ffn_ratios': [8, 8, 4, 4]}), } # yapf: disable def __init__(self, arch='tiny', patch_sizes=[7, 3, 3, 3], in_channels=3, drop_rate=0., drop_path_rate=0., out_indices=(3, ), frozen_stages=-1, norm_eval=False, norm_cfg=dict(type='LN'), block_cfgs=dict(), init_cfg=None): super(VAN, self).__init__(init_cfg=init_cfg) if isinstance(arch, str): arch = arch.lower() assert arch in set(self.arch_zoo), \ f'Arch {arch} is not in default archs {set(self.arch_zoo)}' self.arch_settings = self.arch_zoo[arch] else: essential_keys = {'embed_dims', 'depths', 'ffn_ratios'} assert isinstance(arch, dict) and set(arch) == essential_keys, \ f'Custom arch needs a dict with keys {essential_keys}' self.arch_settings = arch self.embed_dims = self.arch_settings['embed_dims'] self.depths = self.arch_settings['depths'] self.ffn_ratios = self.arch_settings['ffn_ratios'] self.num_stages = len(self.depths) self.out_indices = out_indices self.frozen_stages = frozen_stages self.norm_eval = norm_eval total_depth = sum(self.depths) dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, total_depth) ] # stochastic depth decay rule cur_block_idx = 0 for i, depth in enumerate(self.depths): patch_embed = VANPatchEmbed( in_channels=in_channels if i == 0 else self.embed_dims[i - 1], input_size=None, embed_dims=self.embed_dims[i], kernel_size=patch_sizes[i], stride=patch_sizes[i] // 2 + 1, padding=(patch_sizes[i] // 2, patch_sizes[i] // 2), norm_cfg=dict(type='BN')) blocks = ModuleList([ VANBlock( embed_dims=self.embed_dims[i], ffn_ratio=self.ffn_ratios[i], drop_rate=drop_rate, drop_path_rate=dpr[cur_block_idx + j], **block_cfgs) for j in range(depth) ]) cur_block_idx += depth norm = build_norm_layer(norm_cfg, self.embed_dims[i])[1] self.add_module(f'patch_embed{i + 1}', patch_embed) self.add_module(f'blocks{i + 1}', blocks) self.add_module(f'norm{i + 1}', norm) def train(self, mode=True): super(VAN, self).train(mode) self._freeze_stages() if mode and self.norm_eval: for m in self.modules(): # trick: eval have effect on BatchNorm only if isinstance(m, _BatchNorm): m.eval() def _freeze_stages(self): for i in range(0, self.frozen_stages + 1): # freeze patch embed m = getattr(self, f'patch_embed{i + 1}') m.eval() for param in m.parameters(): param.requires_grad = False # freeze blocks m = getattr(self, f'blocks{i + 1}') m.eval() for param in m.parameters(): param.requires_grad = False # freeze norm m = getattr(self, f'norm{i + 1}') m.eval() for param in m.parameters(): param.requires_grad = False def forward(self, x): outs = [] for i in range(self.num_stages): patch_embed = getattr(self, f'patch_embed{i + 1}') blocks = getattr(self, f'blocks{i + 1}') norm = getattr(self, f'norm{i + 1}') x, hw_shape = patch_embed(x) for block in blocks: x = block(x) x = x.flatten(2).transpose(1, 2) x = norm(x) x = x.reshape(-1, *hw_shape, block.out_channels).permute(0, 3, 1, 2).contiguous() if i in self.out_indices: outs.append(x) return tuple(outs)