From 095ed243c04a3d00751c1cf2f657a548cae34155 Mon Sep 17 00:00:00 2001
From: sennnnn <58427300+sennnnn@users.noreply.github.com>
Date: Tue, 20 Jul 2021 00:40:40 +0800
Subject: [PATCH] [Feature] Segformer backbone re-implementation (#594)

* [Feature] Segformer re-implementation

* Use act_cfg and norm_cfg to control activation and normalization

* Split this PR into several little PRs

* Fix lint error

* Remove SegFormerHead

* Refactor parameter initialization

* 1. Refactor segformer backbone parameter init; 2. Remove redundant
functions and unit tests;

* Remove redundant code

* 1. Remove redundant code; 2. Modify module name;

* Refactor the backbone of segformer using mmcv.cnn.bricks.transformer.py

* Fix some code logic bugs.

* Add mit_convert.py to match pretrained keys of segformer.

* Resolve some comments.

* 1. Add some asserts to ensure correct params; 2. Support flexible
pe_conv position;

* Add pe_index assert and fix unit test.

* 1. Add docstring for MixVisionTransformer; 2. Add some unit tests for
MixVisionTransformer;

* Use hw_shape to pass shape of feature map.

* 1. Fix docstring of MixVisionTransformer; 2. Simplify MixFFN; 3. Modify
H, W to hw_shape;

* Add more unit tests.

* Add docstrings for shape conversion functions.

* Add some unit tests to improve code coverage.

* Fix Segformer backbone pretrained weights matching bug.

* Resolve the shape conversion function docstrings.

* Add pad_to_patch_size arg.

* Modify default value of pad_to_patch_size arg.
---
 mmseg/models/backbones/__init__.py           |   3 +-
 mmseg/models/backbones/mit.py                | 416 +++++++++++++++++++
 mmseg/models/backbones/swin.py               |   1 +
 mmseg/models/backbones/vit.py                |   1 +
 mmseg/models/utils/__init__.py               |   5 +-
 mmseg/models/utils/ckpt_convert.py           |  49 +++
 mmseg/models/utils/embed.py                  |  26 +-
 mmseg/models/utils/shape_convert.py          |  28 ++
 tests/test_models/test_backbones/test_mit.py |  60 +++
 9 files changed, 578 insertions(+), 11 deletions(-)
 create mode 100644 mmseg/models/backbones/mit.py
 create mode 100644 mmseg/models/utils/shape_convert.py
 create mode 100644 tests/test_models/test_backbones/test_mit.py

diff --git a/mmseg/models/backbones/__init__.py b/mmseg/models/backbones/__init__.py
index 43690d6c8..b8c17b218 100644
--- a/mmseg/models/backbones/__init__.py
+++ b/mmseg/models/backbones/__init__.py
@@ -1,6 +1,7 @@
 from .cgnet import CGNet
 from .fast_scnn import FastSCNN
 from .hrnet import HRNet
+from .mit import MixVisionTransformer
 from .mobilenet_v2 import MobileNetV2
 from .mobilenet_v3 import MobileNetV3
 from .resnest import ResNeSt
@@ -13,5 +14,5 @@ from .vit import VisionTransformer
 __all__ = [
     'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
     'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
-    'VisionTransformer', 'SwinTransformer'
+    'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer'
 ]
diff --git a/mmseg/models/backbones/mit.py b/mmseg/models/backbones/mit.py
new file mode 100644
index 000000000..cad0b4313
--- /dev/null
+++ b/mmseg/models/backbones/mit.py
@@ -0,0 +1,416 @@
+import math
+import warnings
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import (Conv2d, build_activation_layer, build_norm_layer,
+                      constant_init, normal_init, trunc_normal_init)
+from mmcv.cnn.bricks.drop import build_dropout
+from mmcv.cnn.bricks.transformer import MultiheadAttention
+from mmcv.runner import BaseModule, ModuleList, Sequential, _load_checkpoint
+
+from ...utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import PatchEmbed, mit_convert, nchw_to_nlc, nlc_to_nchw
+
+
+class MixFFN(BaseModule):
+    """An implementation of MixFFN of Segformer.
+
+    The differences between MixFFN & FFN:
+        1. Use 1x1 Conv to replace Linear layer.
+        2. Introduce 3x3 depth-wise Conv to encode positional information.
+
+    Args:
+        embed_dims (int): The feature dimension. Same as
+            `MultiheadAttention`. Default: 256.
+        feedforward_channels (int): The hidden dimension of FFNs.
+            Default: 1024.
+        act_cfg (dict, optional): The activation config for FFNs.
+            Default: dict(type='GELU').
+        ffn_drop (float, optional): Probability of an element to be
+            zeroed in FFN. Default: 0.0.
+        dropout_layer (obj:`ConfigDict`): The dropout_layer used
+            when adding the shortcut. Default: None.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
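+
+    Example (a minimal shape check mirroring ``test_mit.py``; the channel
+    and token numbers are illustrative only):
+        >>> import torch
+        >>> ffn = MixFFN(embed_dims=128, feedforward_channels=512)
+        >>> x = torch.randn(1, 32 * 32, 128)  # a flattened 32x32 feature map
+        >>> ffn(x, hw_shape=(32, 32)).shape
+        torch.Size([1, 1024, 128])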
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 feedforward_channels,
+                 act_cfg=dict(type='GELU'),
+                 ffn_drop=0.,
+                 dropout_layer=None,
+                 init_cfg=None):
+        super(MixFFN, self).__init__(init_cfg)
+
+        self.embed_dims = embed_dims
+        self.feedforward_channels = feedforward_channels
+        self.act_cfg = act_cfg
+        self.activate = build_activation_layer(act_cfg)
+
+        in_channels = embed_dims
+        fc1 = Conv2d(
+            in_channels=in_channels,
+            out_channels=feedforward_channels,
+            kernel_size=1,
+            stride=1,
+            bias=True)
+        # 3x3 depth-wise conv to provide positional encoding information.
+        pe_conv = Conv2d(
+            in_channels=feedforward_channels,
+            out_channels=feedforward_channels,
+            kernel_size=3,
+            stride=1,
+            padding=(3 - 1) // 2,
+            bias=True,
+            groups=feedforward_channels)
+        fc2 = Conv2d(
+            in_channels=feedforward_channels,
+            out_channels=in_channels,
+            kernel_size=1,
+            stride=1,
+            bias=True)
+        drop = nn.Dropout(ffn_drop)
+        layers = [fc1, pe_conv, self.activate, drop, fc2, drop]
+        self.layers = Sequential(*layers)
+        self.dropout_layer = build_dropout(
+            dropout_layer) if dropout_layer else torch.nn.Identity()
+
+    def forward(self, x, hw_shape, identity=None):
+        out = nlc_to_nchw(x, hw_shape)
+        out = self.layers(out)
+        out = nchw_to_nlc(out)
+        if identity is None:
+            identity = x
+        return identity + self.dropout_layer(out)
+
+
+class EfficientMultiheadAttention(MultiheadAttention):
+    """An implementation of Efficient Multi-head Attention of Segformer.
+
+    This module is modified from MultiheadAttention, which is a module from
+    mmcv.cnn.bricks.transformer.
+
+    Args:
+        embed_dims (int): The embedding dimension.
+        num_heads (int): Parallel attention heads.
+        attn_drop (float): A Dropout layer on attn_output_weights.
+            Default: 0.0.
+        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
+            Default: 0.0.
+        dropout_layer (obj:`ConfigDict`): The dropout_layer used
+            when adding the shortcut. Default: None.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+        batch_first (bool): Whether Key, Query and Value are of shape
+            (batch, n, embed_dim) rather than (n, batch, embed_dim).
+            Default: True.
+        qkv_bias (bool): Enable bias for qkv if True. Default: False.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='LN').
+        sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
+            Attention of Segformer. Default: 1.
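+
+    Example (a minimal shape check mirroring ``test_mit.py``; the sizes
+    are illustrative only):
+        >>> import torch
+        >>> attn = EfficientMultiheadAttention(embed_dims=128, num_heads=2)
+        >>> x = torch.randn(1, 32 * 32, 128)
+        >>> attn(x, hw_shape=(32, 32)).shape
+        torch.Size([1, 1024, 128])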
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 num_heads,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 dropout_layer=None,
+                 init_cfg=None,
+                 batch_first=True,
+                 qkv_bias=False,
+                 norm_cfg=dict(type='LN'),
+                 sr_ratio=1):
+        super().__init__(
+            embed_dims,
+            num_heads,
+            attn_drop,
+            proj_drop,
+            dropout_layer=dropout_layer,
+            init_cfg=init_cfg,
+            batch_first=batch_first,
+            bias=qkv_bias)
+
+        self.sr_ratio = sr_ratio
+        if sr_ratio > 1:
+            self.sr = Conv2d(
+                in_channels=embed_dims,
+                out_channels=embed_dims,
+                kernel_size=sr_ratio,
+                stride=sr_ratio)
+            # The ret[0] of build_norm_layer is the norm name.
+            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
+
+    def forward(self, x, hw_shape, identity=None):
+
+        x_q = x
+        if self.sr_ratio > 1:
+            x_kv = nlc_to_nchw(x, hw_shape)
+            x_kv = self.sr(x_kv)
+            x_kv = nchw_to_nlc(x_kv)
+            x_kv = self.norm(x_kv)
+        else:
+            x_kv = x
+
+        if identity is None:
+            identity = x_q
+
+        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
+
+        return identity + self.dropout_layer(self.proj_drop(out))
+
+
+class TransformerEncoderLayer(BaseModule):
+    """Implements one encoder layer in Segformer.
+
+    Args:
+        embed_dims (int): The feature dimension.
+        num_heads (int): Parallel attention heads.
+        feedforward_channels (int): The hidden dimension for FFNs.
+        drop_rate (float): Probability of an element to be zeroed
+            after the feed forward layer. Default: 0.0.
+        attn_drop_rate (float): The drop out rate for attention layer.
+            Default: 0.0.
+        drop_path_rate (float): Stochastic depth rate. Default: 0.0.
+        qkv_bias (bool): Enable bias for qkv if True.
+            Default: True.
+        act_cfg (dict): The activation config for FFNs.
+            Default: dict(type='GELU').
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='LN').
+        batch_first (bool): Whether Key, Query and Value are of shape
+            (batch, n, embed_dim) rather than (n, batch, embed_dim).
+            Default: True.
+        init_cfg (dict, optional): Initialization config dict.
+            Default: None.
+        sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
+            Attention of Segformer. Default: 1.
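+
+    Example (an illustrative shape check; the sizes are placeholders, not
+    a recommended configuration):
+        >>> import torch
+        >>> layer = TransformerEncoderLayer(
+        ...     embed_dims=64, num_heads=2, feedforward_channels=256)
+        >>> x = torch.randn(1, 16 * 16, 64)
+        >>> layer(x, hw_shape=(16, 16)).shape
+        torch.Size([1, 256, 64])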
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 num_heads,
+                 feedforward_channels,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 qkv_bias=True,
+                 act_cfg=dict(type='GELU'),
+                 norm_cfg=dict(type='LN'),
+                 batch_first=True,
+                 sr_ratio=1):
+        super(TransformerEncoderLayer, self).__init__()
+
+        # The ret[0] of build_norm_layer is the norm name.
+        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
+
+        self.attn = EfficientMultiheadAttention(
+            embed_dims=embed_dims,
+            num_heads=num_heads,
+            attn_drop=attn_drop_rate,
+            proj_drop=drop_rate,
+            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+            batch_first=batch_first,
+            qkv_bias=qkv_bias,
+            norm_cfg=norm_cfg,
+            sr_ratio=sr_ratio)
+
+        # The ret[0] of build_norm_layer is the norm name.
+        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
+
+        self.ffn = MixFFN(
+            embed_dims=embed_dims,
+            feedforward_channels=feedforward_channels,
+            ffn_drop=drop_rate,
+            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+            act_cfg=act_cfg)
+
+    def forward(self, x, hw_shape):
+        x = self.attn(self.norm1(x), hw_shape, identity=x)
+        x = self.ffn(self.norm2(x), hw_shape, identity=x)
+        return x
+
+
+@BACKBONES.register_module()
+class MixVisionTransformer(BaseModule):
+    """The backbone of Segformer.
+
+    A PyTorch implementation of `SegFormer: Simple and Efficient Design for
+    Semantic Segmentation with Transformers` -
+    https://arxiv.org/pdf/2105.15203.pdf
+
+    Args:
+        in_channels (int): Number of input channels. Default: 3.
+        embed_dims (int): Embedding dimension. Default: 64.
+        num_stages (int): The num of stages. Default: 4.
+        num_layers (Sequence[int]): The layer number of each transformer
+            encode layer. Default: [3, 4, 6, 3].
+        num_heads (Sequence[int]): The attention heads of each transformer
+            encode layer. Default: [1, 2, 4, 8].
+        patch_sizes (Sequence[int]): The patch_size of each overlapped patch
+            embedding. Default: [7, 3, 3, 3].
+        strides (Sequence[int]): The stride of each overlapped patch
+            embedding. Default: [4, 2, 2, 2].
+        sr_ratios (Sequence[int]): The spatial reduction rate of each
+            transformer encode layer. Default: [8, 4, 2, 1].
+        out_indices (Sequence[int] | int): Output from which stages.
+            Default: (0, 1, 2, 3).
+        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
+            Default: 4.
+        qkv_bias (bool): Enable bias for qkv if True. Default: True.
+        drop_rate (float): Probability of an element to be zeroed.
+            Default: 0.0.
+        attn_drop_rate (float): The drop out rate for attention layer.
+            Default: 0.0.
+        drop_path_rate (float): Stochastic depth rate. Default: 0.0.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='LN', eps=1e-6).
+        act_cfg (dict): The activation config for FFNs.
+            Default: dict(type='GELU').
+        pretrain_style (str): Choose to use official or mmcls pretrain
+            weights. Default: 'official'.
+        pretrained (str, optional): Model pretrained path. Default: None.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
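+
+    Example (a minimal forward pass mirroring ``test_mit.py``; the
+    configuration and input size are illustrative only):
+        >>> import torch
+        >>> model = MixVisionTransformer(
+        ...     embed_dims=32, num_heads=[1, 2, 5, 8], out_indices=(0, 1, 2, 3))
+        >>> outs = model(torch.randn(1, 3, 224, 224))
+        >>> [tuple(out.shape) for out in outs]
+        [(1, 32, 56, 56), (1, 64, 28, 28), (1, 160, 14, 14), (1, 256, 7, 7)]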
+    """
+
+    def __init__(self,
+                 in_channels=3,
+                 embed_dims=64,
+                 num_stages=4,
+                 num_layers=[3, 4, 6, 3],
+                 num_heads=[1, 2, 4, 8],
+                 patch_sizes=[7, 3, 3, 3],
+                 strides=[4, 2, 2, 2],
+                 sr_ratios=[8, 4, 2, 1],
+                 out_indices=(0, 1, 2, 3),
+                 mlp_ratio=4,
+                 qkv_bias=True,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 act_cfg=dict(type='GELU'),
+                 norm_cfg=dict(type='LN', eps=1e-6),
+                 pretrain_style='official',
+                 pretrained=None,
+                 init_cfg=None):
+        super().__init__()
+
+        assert pretrain_style in [
+            'official', 'mmcls'
+        ], 'we only support official weights or mmcls weights.'
+
+        if isinstance(pretrained, str) or pretrained is None:
+            warnings.warn('DeprecationWarning: pretrained is deprecated, '
+                          'please use "init_cfg" instead')
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+        self.embed_dims = embed_dims
+
+        self.num_stages = num_stages
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.patch_sizes = patch_sizes
+        self.strides = strides
+        self.sr_ratios = sr_ratios
+        assert num_stages == len(num_layers) == len(num_heads) \
+            == len(patch_sizes) == len(strides) == len(sr_ratios)
+
+        self.out_indices = out_indices
+        assert max(out_indices) < self.num_stages
+        self.pretrain_style = pretrain_style
+        self.pretrained = pretrained
+        self.init_cfg = init_cfg
+
+        # transformer encoder
+        dpr = [
+            x.item()
+            for x in torch.linspace(0, drop_path_rate, sum(num_layers))
+        ]  # stochastic depth decay rule
+
+        cur = 0
+        self.layers = ModuleList()
+        for i, num_layer in enumerate(num_layers):
+            embed_dims_i = embed_dims * num_heads[i]
+            patch_embed = PatchEmbed(
+                in_channels=in_channels,
+                embed_dims=embed_dims_i,
+                kernel_size=patch_sizes[i],
+                stride=strides[i],
+                padding=patch_sizes[i] // 2,
+                pad_to_patch_size=False,
+                norm_cfg=norm_cfg)
+            layer = ModuleList([
+                TransformerEncoderLayer(
+                    embed_dims=embed_dims_i,
+                    num_heads=num_heads[i],
+                    feedforward_channels=mlp_ratio * embed_dims_i,
+                    drop_rate=drop_rate,
+                    attn_drop_rate=attn_drop_rate,
+                    drop_path_rate=dpr[cur + idx],
+                    qkv_bias=qkv_bias,
+                    act_cfg=act_cfg,
+                    norm_cfg=norm_cfg,
+                    sr_ratio=sr_ratios[i]) for idx in range(num_layer)
+            ])
+            in_channels = embed_dims_i
+            # The ret[0] of build_norm_layer is the norm name.
+            norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
+            self.layers.append(ModuleList([patch_embed, layer, norm]))
+            cur += num_layer
+
+    def init_weights(self):
+        if self.pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Linear):
+                    trunc_normal_init(m.weight, std=.02)
+                    if m.bias is not None:
+                        constant_init(m.bias, 0)
+                elif isinstance(m, nn.LayerNorm):
+                    constant_init(m.bias, 0)
+                    constant_init(m.weight, 1.0)
+                elif isinstance(m, nn.Conv2d):
+                    fan_out = m.kernel_size[0] * m.kernel_size[
+                        1] * m.out_channels
+                    fan_out //= m.groups
+                    normal_init(m.weight, 0, math.sqrt(2.0 / fan_out))
+                    if m.bias is not None:
+                        constant_init(m.bias, 0)
+        elif isinstance(self.pretrained, str):
+            logger = get_root_logger()
+            checkpoint = _load_checkpoint(
+                self.pretrained, logger=logger, map_location='cpu')
+            if 'state_dict' in checkpoint:
+                state_dict = checkpoint['state_dict']
+            elif 'model' in checkpoint:
+                state_dict = checkpoint['model']
+            else:
+                state_dict = checkpoint
+
+            if self.pretrain_style == 'official':
+                # The segformer backbone is not supported by mmcls, so we
+                # need to convert the pretrained weights to match this
+                # implementation.
+                state_dict = mit_convert(state_dict)
+
+            self.load_state_dict(state_dict, False)
+
+    def forward(self, x):
+        outs = []
+
+        for i, layer in enumerate(self.layers):
+            x, H, W = layer[0](x), layer[0].DH, layer[0].DW
+            hw_shape = (H, W)
+            for block in layer[1]:
+                x = block(x, hw_shape)
+            x = layer[2](x)
+            x = nlc_to_nchw(x, hw_shape)
+            if i in self.out_indices:
+                outs.append(x)
+
+        return outs
diff --git a/mmseg/models/backbones/swin.py b/mmseg/models/backbones/swin.py
index a798ad1eb..1ea6389fa 100644
--- a/mmseg/models/backbones/swin.py
+++ b/mmseg/models/backbones/swin.py
@@ -628,6 +628,7 @@ class SwinTransformer(BaseModule):
             conv_type='Conv2d',
             kernel_size=patch_size,
             stride=strides[0],
+            pad_to_patch_size=True,
             norm_cfg=norm_cfg if patch_norm else None,
             init_cfg=None)
diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py
index 33176351e..021bf0933 100644
--- a/mmseg/models/backbones/vit.py
+++ b/mmseg/models/backbones/vit.py
@@ -210,6 +210,7 @@ class VisionTransformer(BaseModule):
             conv_type='Conv2d',
             kernel_size=patch_size,
             stride=patch_size,
+            pad_to_patch_size=True,
             norm_cfg=norm_cfg if patch_norm else None,
             init_cfg=None,
         )
diff --git a/mmseg/models/utils/__init__.py b/mmseg/models/utils/__init__.py
index 277dd2676..32a953b83 100644
--- a/mmseg/models/utils/__init__.py
+++ b/mmseg/models/utils/__init__.py
@@ -1,14 +1,15 @@
-from .ckpt_convert import swin_convert, vit_convert
+from .ckpt_convert import mit_convert, swin_convert, vit_convert
 from .embed import PatchEmbed
 from .inverted_residual import InvertedResidual, InvertedResidualV3
 from .make_divisible import make_divisible
 from .res_layer import ResLayer
 from .se_layer import SELayer
 from .self_attention_block import SelfAttentionBlock
+from .shape_convert import nchw_to_nlc, nlc_to_nchw
 from .up_conv_block import UpConvBlock

 __all__ = [
     'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
     'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'vit_convert',
-    'swin_convert', 'PatchEmbed'
+    'mit_convert', 'swin_convert', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw'
 ]
diff --git a/mmseg/models/utils/ckpt_convert.py b/mmseg/models/utils/ckpt_convert.py
index 0b1b27707..26a1b96df 100644
--- a/mmseg/models/utils/ckpt_convert.py
+++ b/mmseg/models/utils/ckpt_convert.py
@@ -1,5 +1,7 @@
 from collections import OrderedDict

+import torch
+

 def swin_convert(ckpt):
     new_ckpt = OrderedDict()
@@ -88,3 +90,50 @@ def vit_convert(ckpt):
         new_ckpt[new_k] = v

     return new_ckpt
+
+
+def mit_convert(ckpt):
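+    """Convert official SegFormer pretrained weights to this implementation.
+
+    A rough sketch of the key renaming (illustrative; ``i`` is the 1-based
+    stage index and ``j`` the block index):
+
+    - ``patch_embed{i}.proj.*`` -> ``layers.{i-1}.0.projection.*``
+    - ``block{i}.{j}.attn.q.*`` -> ``layers.{i-1}.1.{j}.attn.attn.in_proj_*``
+      (the q and kv weights are concatenated along dim 0)
+    - ``block{i}.{j}.mlp.*``    -> ``layers.{i-1}.1.{j}.ffn.layers.*``
+      (linear fc weights are reshaped to 1x1 conv weights)
+    - ``norm{i}.*``             -> ``layers.{i-1}.2.*``
+    """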
+    new_ckpt = OrderedDict()
+    # Process the concat between q linear weights and kv linear weights.
+    for k, v in ckpt.items():
+        if k.startswith('head'):
+            continue
+        elif k.startswith('patch_embed'):
+            stage_i = int(k.split('.')[0].replace('patch_embed', ''))
+            new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0')
+            new_v = v
+            if 'proj.' in new_k:
+                new_k = new_k.replace('proj.', 'projection.')
+        elif k.startswith('block'):
+            stage_i = int(k.split('.')[0].replace('block', ''))
+            new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1')
+            new_v = v
+            if 'attn.q.' in new_k:
+                sub_item_k = k.replace('q.', 'kv.')
+                new_k = new_k.replace('q.', 'attn.in_proj_')
+                new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
+            elif 'attn.kv.' in new_k:
+                continue
+            elif 'attn.proj.' in new_k:
+                new_k = new_k.replace('proj.', 'attn.out_proj.')
+            elif 'attn.sr.' in new_k:
+                # The spatial reduction conv keeps the same key name.
+                pass
+            elif 'mlp.' in new_k:
+                new_k = new_k.replace('mlp.', 'ffn.layers.')
+                if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
+                    new_v = v.reshape((*v.shape, 1, 1))
+                new_k = new_k.replace('fc1.', '0.')
+                new_k = new_k.replace('dwconv.dwconv.', '1.')
+                new_k = new_k.replace('fc2.', '4.')
+        elif k.startswith('norm'):
+            stage_i = int(k.split('.')[0].replace('norm', ''))
+            new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2')
+            new_v = v
+        else:
+            new_k = k
+            new_v = v
+        new_ckpt[new_k] = new_v
+    return new_ckpt
diff --git a/mmseg/models/utils/embed.py b/mmseg/models/utils/embed.py
index 3bbb45b37..73d8ed1f1 100644
--- a/mmseg/models/utils/embed.py
+++ b/mmseg/models/utils/embed.py
@@ -19,6 +19,8 @@ class PatchEmbed(BaseModule):
             Default: None (Default to be equal with kernel_size).
         padding (int): The padding length of embedding conv. Default: 0.
         dilation (int): The dilation rate of embedding conv. Default: 1.
+        pad_to_patch_size (bool, optional): Whether to pad the feature map
+            shape to a multiple of the patch size. Default: True.
         norm_cfg (dict, optional): Config dict for normalization layer.
         init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
             Default: None.
@@ -32,6 +34,7 @@ class PatchEmbed(BaseModule):
                  stride=16,
                  padding=0,
                  dilation=1,
+                 pad_to_patch_size=True,
                  norm_cfg=None,
                  init_cfg=None):
         super(PatchEmbed, self).__init__()
@@ -42,7 +45,9 @@ class PatchEmbed(BaseModule):
         if stride is None:
             stride = kernel_size

-        # The default setting of patch size is eaual to kernel size.
+        self.pad_to_patch_size = pad_to_patch_size
+
+        # The default setting of patch size is equal to kernel size.
         patch_size = kernel_size
         if isinstance(patch_size, int):
             patch_size = to_2tuple(patch_size)
@@ -56,7 +61,7 @@ class PatchEmbed(BaseModule):
         self.patch_size = patch_size

         # Use conv layer to embed
-        conv_type = conv_type or dict(type='Conv2d')
+        conv_type = conv_type or 'Conv2d'
         self.projection = build_conv_layer(
             dict(type=conv_type),
             in_channels=in_channels,
@@ -73,12 +78,17 @@ class PatchEmbed(BaseModule):

     def forward(self, x):
         H, W = x.shape[2], x.shape[3]
-        if H % self.patch_size[0] != 0:
-            x = F.pad(x,
-                      (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
-        if W % self.patch_size[1] != 0:
-            x = F.pad(x,
-                      (0, self.patch_size[1] - W % self.patch_size[1], 0, 0))
+
+        # TODO: Process overlapping op
+        if self.pad_to_patch_size:
+            # Pad H, W to multiples of the patch size.
+            if H % self.patch_size[0] != 0:
+                x = F.pad(
+                    x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+            if W % self.patch_size[1] != 0:
+                x = F.pad(
+                    x, (0, self.patch_size[1] - W % self.patch_size[1], 0, 0))
+
         x = self.projection(x)
         self.DH, self.DW = x.shape[2], x.shape[3]
         x = x.flatten(2).transpose(1, 2)
diff --git a/mmseg/models/utils/shape_convert.py b/mmseg/models/utils/shape_convert.py
new file mode 100644
index 000000000..744416092
--- /dev/null
+++ b/mmseg/models/utils/shape_convert.py
@@ -0,0 +1,28 @@
+def nlc_to_nchw(x, hw_shape):
+    """Convert a [N, L, C] shape tensor to a [N, C, H, W] shape tensor.
+
+    Args:
+        x (Tensor): The input tensor of shape [N, L, C] before conversion.
+        hw_shape (Sequence[int]): The height and width of the output feature
+            map.
+
+    Returns:
+        Tensor: The output tensor of shape [N, C, H, W] after conversion.
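+
+    Example (an illustrative round trip with ``nchw_to_nlc``):
+        >>> import torch
+        >>> x = torch.randn(2, 4 * 6, 8)
+        >>> nlc_to_nchw(x, (4, 6)).shape
+        torch.Size([2, 8, 4, 6])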
+    """
+    H, W = hw_shape
+    assert len(x.shape) == 3
+    B, L, C = x.shape
+    assert L == H * W, 'The seq_len doesn\'t match H, W'
+    return x.transpose(1, 2).reshape(B, C, H, W)
+
+
+def nchw_to_nlc(x):
+    """Flatten a [N, C, H, W] shape tensor to a [N, L, C] shape tensor.
+
+    Args:
+        x (Tensor): The input tensor of shape [N, C, H, W] before conversion.
+
+    Returns:
+        Tensor: The output tensor of shape [N, L, C] after conversion.
+    """
+    assert len(x.shape) == 4
+    return x.flatten(2).transpose(1, 2).contiguous()
diff --git a/tests/test_models/test_backbones/test_mit.py b/tests/test_models/test_backbones/test_mit.py
new file mode 100644
index 000000000..bf6cca164
--- /dev/null
+++ b/tests/test_models/test_backbones/test_mit.py
@@ -0,0 +1,60 @@
+import pytest
+import torch
+
+from mmseg.models.backbones import MixVisionTransformer
+from mmseg.models.backbones.mit import EfficientMultiheadAttention, MixFFN
+
+
+def test_mit():
+    with pytest.raises(AssertionError):
+        # Only official style and mmcls style are supported now.
+        MixVisionTransformer(pretrain_style='timm')
+
+    with pytest.raises(TypeError):
+        # pretrained represents a pretrained model url or path and must be
+        # a str or None.
+        MixVisionTransformer(pretrained=123)
+
+    # Test normal input
+    H, W = (224, 224)
+    temp = torch.randn((1, 3, H, W))
+    model = MixVisionTransformer(
+        embed_dims=32, num_heads=[1, 2, 5, 8], out_indices=(0, 1, 2, 3))
+    model.init_weights()
+    outs = model(temp)
+    assert outs[0].shape == (1, 32, H // 4, W // 4)
+    assert outs[1].shape == (1, 64, H // 8, W // 8)
+    assert outs[2].shape == (1, 160, H // 16, W // 16)
+    assert outs[3].shape == (1, 256, H // 32, W // 32)
+
+    # Test non-squared input
+    H, W = (224, 320)
+    temp = torch.randn((1, 3, H, W))
+    outs = model(temp)
+    assert outs[0].shape == (1, 32, H // 4, W // 4)
+    assert outs[1].shape == (1, 64, H // 8, W // 8)
+    assert outs[2].shape == (1, 160, H // 16, W // 16)
+    assert outs[3].shape == (1, 256, H // 32, W // 32)
+
+    # Test MixFFN
+    FFN = MixFFN(128, 512)
+    hw_shape = (32, 32)
+    token_len = 32 * 32
+    temp = torch.randn((1, token_len, 128))
+    # Self identity
+    out = FFN(temp, hw_shape)
+    assert out.shape == (1, token_len, 128)
+    # Out identity
+    outs = FFN(temp, hw_shape, temp)
+    assert outs.shape == (1, token_len, 128)
+
+    # Test EfficientMHA
+    MHA = EfficientMultiheadAttention(128, 2)
+    hw_shape = (32, 32)
+    token_len = 32 * 32
+    temp = torch.randn((1, token_len, 128))
+    # Self identity
+    out = MHA(temp, hw_shape)
+    assert out.shape == (1, token_len, 128)
+    # Out identity
+    outs = MHA(temp, hw_shape, temp)
+    assert outs.shape == (1, token_len, 128)