# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings

import torch
import torch.nn as nn
from mmcv.cnn import (build_activation_layer, build_conv_layer,
                      build_norm_layer)
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.utils.weight_init import trunc_normal_init
from mmcv.runner import BaseModule, ModuleList, Sequential
from mmcv.utils import to_2tuple

from mmseg.models import BACKBONES
from mmseg.models.utils.embed import AdaptivePadding


class SimplifiedPatchEmbed(BaseModule):
    """Image to Patch Embedding.

    We use a conv layer to implement SimplifiedPatchEmbed.

    Args:
        in_channels (int): The number of input channels. Default: 3.
        embed_dims (int): The dimensions of embedding. Default: 768.
        conv_type (str): The type of conv layer used for the embedding.
            Default: 'Conv2d'.
        kernel_size (int): The kernel size of the embedding conv. Default: 16.
        stride (int, optional): The stride of the embedding conv.
            Default: None (set to ``kernel_size``).
        padding (int | tuple | string): The padding of the embedding conv.
            When it is a string, it means the mode of adaptive padding;
            'same' and 'corner' are supported. Default: 'corner'.
        dilation (int): The dilation rate of the embedding conv. Default: 1.
        bias (bool): Bias of the embedding conv. Default: True.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        input_size (int | tuple | None): The size of the input, which will be
            used to calculate the output size. Default: None.
        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=768,
                 conv_type='Conv2d',
                 kernel_size=16,
                 stride=None,
                 padding='corner',
                 dilation=1,
                 bias=True,
                 norm_cfg=None,
                 input_size=None,
                 init_cfg=None):
        super(SimplifiedPatchEmbed, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        if stride is None:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adap_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of conv
            padding = 0
        else:
            self.adap_padding = None
        padding = to_2tuple(padding)

        self.projection = build_conv_layer(
            dict(type=conv_type),
            in_channels=in_channels,
            out_channels=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None

        if input_size:
            input_size = to_2tuple(input_size)
            # `init_out_size` would be used outside to
            # calculate the num_patches
            # when `use_abs_pos_embed` outside
            self.init_input_size = input_size
            if self.adap_padding:
                pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
                input_h, input_w = input_size
                input_h = input_h + pad_h
                input_w = input_w + pad_w
                input_size = (input_h, input_w)

            # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
            h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
                     (kernel_size[0] - 1) - 1) // stride[0] + 1
            w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
                     (kernel_size[1] - 1) - 1) // stride[1] + 1
            self.init_out_size = (h_out, w_out)
        else:
            self.init_input_size = None
            self.init_out_size = None

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # `trunc_normal_init` expects a module, not a weight tensor.
            trunc_normal_init(m, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        """
        Args:
            x (Tensor): Has shape (B, C, H, W). In most cases, C is 3.

        Returns:
            tuple: Contains merged results and its spatial shape.

                - x (Tensor): Has shape (B, embed_dims, out_h * out_w).
                - out_size (tuple[int]): Spatial shape of x, arranged as
                  (out_h, out_w).
        """
        if self.adap_padding:
            x = self.adap_padding(x)

        x = self.projection(x)
        out_size = (x.shape[2], x.shape[3])
        if self.norm is not None:
            x = self.norm(x)
        x = x.flatten(2)
        return x, out_size
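# A quick shape sketch for ``SimplifiedPatchEmbed`` (illustrative only; the
# 224x224 input and embed_dims=32 below are assumptions, not values fixed
# elsewhere in this file):
#
#     patch_embed = SimplifiedPatchEmbed(
#         in_channels=3, embed_dims=32, kernel_size=7, stride=4)
#     x = torch.randn(2, 3, 224, 224)
#     tokens, out_size = patch_embed(x)
#     # tokens.shape == (2, 32, 56 * 56), out_size == (56, 56)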
class DWConv(nn.Module):

    def __init__(self, dims):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dims, dims, 3, 1, 1, bias=True, groups=dims)

    def forward(self, x, H, W):
        B, C, N = x.shape
        x = x.reshape(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2)
        return x


class MixFFN(nn.Module):
    """An implementation of MixFFN of DEST.

    The differences between MixFFN & FFN:
        1. Use 1X1 Conv to replace Linear layer.
        2. Introduce 3X3 Conv to encode positional information.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`. Defaults: 256.
        feedforward_channels (int): The hidden dimension of FFNs.
            Defaults: 1024.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='ReLU').
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default: 0.0.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='SyncBN', requires_grad=True).
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='ReLU'),
                 ffn_drop=0.,
                 norm_cfg=dict(type='SyncBN', requires_grad=True),
                 dropout_layer=None,
                 init_cfg=None):
        super(MixFFN, self).__init__()

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg
        activate = build_activation_layer(act_cfg)

        in_channels = embed_dims
        fc1 = nn.Conv1d(
            in_channels=in_channels,
            out_channels=feedforward_channels,
            kernel_size=1,
            stride=1)
        norm1 = build_norm_layer(norm_cfg, feedforward_channels)[1]
        self.dwconv = DWConv(feedforward_channels)
        norm2 = build_norm_layer(norm_cfg, feedforward_channels)[1]
        fc2 = nn.Conv1d(
            in_channels=feedforward_channels,
            out_channels=in_channels,
            kernel_size=1,
            stride=1)
        drop = nn.Dropout(ffn_drop)
        pre_layers = [fc1, norm1]
        post_layers = [norm2, activate, drop, fc2, drop]
        self.pre_layers = Sequential(*pre_layers)
        self.post_layers = Sequential(*post_layers)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv1d):
            trunc_normal_init(m, std=.02, bias=0.)

    def forward(self, x, hw_shape, identity):
        out = self.pre_layers(x)
        out = self.dwconv(out, hw_shape[0], hw_shape[1])
        out = self.post_layers(out)
        return identity + self.dropout_layer(out)


class SimplifiedAttention(nn.Module):
    """An implementation of Simplified Multi-head Attention of DEST.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        sr_ratio (int): The ratio of spatial reduction of Efficient
            Multi-head Attention of Segformer. Default: 1.
        qkv_bias (bool): Enable bias for qkv if True. Default: False.
        qk_scale (float, optional): Scales for query and key.
            Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='SyncBN', requires_grad=True).
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 sr_ratio=1,
                 qkv_bias=False,
                 qk_scale=None,
                 dropout_layer=None,
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super().__init__()

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        head_dim = embed_dims // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.q = nn.Conv1d(embed_dims, embed_dims, 1, bias=qkv_bias)
        self.k = nn.Conv1d(embed_dims, embed_dims, 1, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv1d(embed_dims, embed_dims, 1)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(
                embed_dims, embed_dims, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]

        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_init(m, std=.02, bias=0.)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Conv1d):
            trunc_normal_init(m, std=.02, bias=0.)

    def forward(self, x, hw_shape, identity):
        H, W = hw_shape
        B, C, N = x.shape
        q = self.q(x)
        q = q.reshape(B, self.num_heads, C // self.num_heads, N)
        q = q.permute(0, 1, 3, 2)

        if self.sr_ratio > 1:
            x_ = x.reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1)
            x_ = self.norm1(x_)
            k = self.k(x_).reshape(B, self.num_heads, C // self.num_heads, -1)
        else:
            k = self.k(x).reshape(B, self.num_heads, C // self.num_heads, -1)

        v = torch.mean(x, 2, True).repeat(1, 1,
                                          self.num_heads).transpose(-2, -1)

        attn = (q @ k) * self.scale
        attn, _ = torch.max(attn, -1)
        out = (attn.transpose(-2, -1) @ v)
        out = out.transpose(-2, -1)
        out = self.proj(out)
        return identity + self.dropout_layer(out)
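# A quick shape sketch for ``SimplifiedAttention`` on stage-0 sized tokens
# (illustrative only; ``norm_cfg`` defaults to SyncBN, so the GN config below
# is an assumption made so the snippet could run standalone on CPU):
#
#     attn = SimplifiedAttention(
#         embed_dims=32, num_heads=1, sr_ratio=8,
#         norm_cfg=dict(type='GN', num_groups=1))
#     x = torch.randn(2, 32, 56 * 56)                # (B, C, N) token layout
#     out = attn(x, hw_shape=(56, 56), identity=x)   # -> (2, 32, 56 * 56)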
class SimpliefiedTransformerEncoderLayer(BaseModule):
    """Implements one encoder layer in DEST.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default: 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.0.
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='ReLU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='SyncBN').
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim) or (n, batch, embed_dim). Default: True.
        qk_scale (float, optional): Scales for query and key.
            Default: None.
        sr_ratio (int): The ratio of spatial reduction of Efficient
            Multi-head Attention of Segformer. Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 act_cfg=dict(type='ReLU'),
                 norm_cfg=dict(type='SyncBN'),
                 batch_first=True,
                 qk_scale=None,
                 sr_ratio=1,
                 with_cp=False):
        super(SimpliefiedTransformerEncoderLayer, self).__init__()

        # The ret[0] of build_norm_layer is norm name.
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]

        self.attn = SimplifiedAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            sr_ratio=sr_ratio,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate))

        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]

        self.ffn = MixFFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate))

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_init(m, std=.02, bias=0.)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, hw_shape):
        x = self.attn(self.norm1(x), hw_shape, identity=x)
        x = self.ffn(self.norm2(x), hw_shape, identity=x)
        return x
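# A quick shape sketch for one encoder layer (illustrative only; the layer and
# its attention/FFN sub-modules default to SyncBN, so a standalone CPU run
# would also need mmcv's revert_sync_batchnorm, as in the demo at the end of
# this file):
#
#     block = SimpliefiedTransformerEncoderLayer(
#         embed_dims=32, num_heads=1, feedforward_channels=8 * 32, sr_ratio=8)
#     x = torch.randn(2, 32, 56 * 56)
#     out = block(x, hw_shape=(56, 56))   # -> (2, 32, 56 * 56)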
""" def __init__(self, embed_dims, num_heads, feedforward_channels, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., qkv_bias=True, act_cfg=dict(type='ReLU'), norm_cfg=dict(type='SyncBN'), batch_first=True, qk_scale=None, sr_ratio=1, with_cp=False): super(SimpliefiedTransformerEncoderLayer, self).__init__() # The ret[0] of build_norm_layer is norm name. self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] self.attn = SimplifiedAttention( embed_dims=embed_dims, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop_rate, proj_drop=drop_rate, sr_ratio=sr_ratio, dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate)) self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] self.ffn = MixFFN( embed_dims=embed_dims, feedforward_channels=feedforward_channels, ffn_drop=drop_rate, dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate)) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_init(m, std=.02, bias=0.) elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels fan_out //= m.groups m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) if m.bias is not None: m.bias.data.zero_() def forward(self, x, hw_shape): x = self.attn(self.norm1(x), hw_shape, identity=x) x = self.ffn(self.norm2(x), hw_shape, identity=x) return x @BACKBONES.register_module() class SimplifiedMixTransformer(BaseModule): """The backbone of DEST. This backbone is the implementation of `SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers `_. Args: in_channels (int): Number of input channels. Default: 3. embed_dims (Sequence[int]): Embedding dimensions of each transformer encode layer. Default: [32, 64, 160, 256]. num_stags (int): The num of stages. Default: 4. num_layers (Sequence[int]): The layer number of each transformer encode layer. Default: [3, 4, 6, 3]. num_heads (Sequence[int]): The attention heads of each transformer encode layer. Default: [1, 2, 4, 8]. patch_sizes (Sequence[int]): The patch_size of each overlapped patch embedding. Default: [7, 3, 3, 3]. strides (Sequence[int]): The stride of each overlapped patch embedding. Default: [4, 2, 2, 2]. sr_ratios (Sequence[int]): The spatial reduction rate of each transformer encode layer. Default: [8, 4, 2, 1]. out_indices (Sequence[int] | int): Output from which stages. Default: (0, 1, 2, 3). mlp_ratios (Sequence[int]): ratios of mlp hidden dim to embedding dim. Default: [8, 8, 4, 4]. qkv_bias (bool): Enable bias for qkv if True. Default: True. drop_rate (float): Probability of an element to be zeroed. Default 0.0 attn_drop_rate (float): The drop out rate for attention layer. Default 0.0 drop_path_rate (float): stochastic depth rate. Default 0.0 norm_cfg (dict): Config dict for normalization layer. Default: dict(type='LN') act_cfg (dict): The activation config for FFNs. Default: dict(type='GELU'). pretrained (str, optional): model pretrained path. Default: None. init_cfg (dict or list[dict], optional): Initialization config dict. Default: None. with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. 
""" def __init__(self, in_channels=3, embed_dims=[32, 64, 160, 256], num_stages=4, num_layers=[2, 2, 2, 2], num_heads=[1, 2, 4, 8], patch_sizes=[7, 3, 3, 3], strides=[4, 2, 2, 2], sr_ratios=[8, 4, 2, 1], out_indices=(0, 1, 2, 3), mlp_ratios=[8, 8, 4, 4], qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., act_cfg=dict(type='ReLU'), norm_cfg=dict(type='SyncBN', requires_grad=True), pretrained=None, init_cfg=None, with_cp=False): super(SimplifiedMixTransformer, self).__init__(init_cfg=init_cfg) assert not (init_cfg and pretrained), \ 'init_cfg and pretrained cannot be set at the same time' if isinstance(pretrained, str): warnings.warn('DeprecationWarning: pretrained is deprecated, ' 'please use "init_cfg" instead') self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) elif pretrained is not None: raise TypeError('pretrained must be a str or None') self.embed_dims = embed_dims self.num_stages = num_stages self.num_layers = num_layers self.num_heads = num_heads self.patch_sizes = patch_sizes self.strides = strides self.sr_ratios = sr_ratios self.with_cp = with_cp assert num_stages == len(num_layers) == len(num_heads) \ == len(patch_sizes) == len(strides) == len(sr_ratios) self.out_indices = out_indices assert max(out_indices) < self.num_stages # transformer encoder dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, sum(num_layers)) ] # stochastic num_layer decay rule cur = 0 self.layers = ModuleList() for i, num_layer in enumerate(num_layers): patch_embed = SimplifiedPatchEmbed( in_channels=in_channels, embed_dims=embed_dims[i], kernel_size=patch_sizes[i], stride=strides[i], padding=patch_sizes[i] // 2, norm_cfg=norm_cfg) layer = ModuleList([ SimpliefiedTransformerEncoderLayer( embed_dims=embed_dims[i], num_heads=num_heads[i], feedforward_channels=mlp_ratios[i] * embed_dims[i], drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=dpr[cur + idx], qkv_bias=qkv_bias, act_cfg=act_cfg, norm_cfg=norm_cfg, with_cp=with_cp, sr_ratio=sr_ratios[i]) for idx in range(num_layer) ]) in_channels = embed_dims[i] # The ret[0] of build_norm_layer is norm name. norm = build_norm_layer(norm_cfg, embed_dims[i])[1] self.layers.append(ModuleList([patch_embed, layer, norm])) cur += num_layer self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_init(m, std=.02, bias=0.) elif isinstance(m, nn.GroupNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels fan_out //= m.groups m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) if m.bias is not None: m.bias.data.zero_() def forward(self, x): outs = [] for i, layer in enumerate(self.layers): x, (H, W) = layer[0](x) for block in layer[1]: x = block(x, (H, W)) x = layer[2](x) N, C, L = x.shape x = x.reshape(N, C, H, W) outs.append(x) return outs