[Feature] Segformer backbone re-implementation (#594)
* [Feature]Segformer re-implementation * Using act_cfg and norm_cfg to control activation and normalization * Split this PR into several little PRs * Fix lint error * Remove SegFormerHead * parameters init refactor * 1. Refactor segformer backbone parameters init; 2. Remove rebundant functions and unit tests; * Remove rebundant codes * 1. Remove rebundant codes; 2. Modify module name; * Refactor the backbone of segformer using mmcv.cnn.bricks.transformer.py * Fix some code logic bugs. * Add mit_convert.py to match pretrain keys of segformer. * Resolve some comments. * 1. Add some assert to ensure right params; 2. Support flexible peconv position; * Add pe_index assert and fix unit test. * 1. Add doc string for MixVisionTransformer; 2. Add some unit tests for MixVisionTransformer; * Use hw_shape to pass shape of feature map. * 1. Fix doc string of MixVisionTransformer; 2. Simplify MixFFN; 3. Modify H, W to hw_shape; * Add more unit tests. * Add doc string for shape convertion functions. * Add some unit tests to improve code coverage. * Fix Segformer backbone pretrain weights match bug. * resolve the shape convertion functions doc string. * Add pad_to_patch_size arg. * Modify default value of pad_to_patch_size arg.pull/1801/head
parent
f6246d6eaa
commit
095ed243c0
|
@ -1,6 +1,7 @@
|
|||
from .cgnet import CGNet
|
||||
from .fast_scnn import FastSCNN
|
||||
from .hrnet import HRNet
|
||||
from .mit import MixVisionTransformer
|
||||
from .mobilenet_v2 import MobileNetV2
|
||||
from .mobilenet_v3 import MobileNetV3
|
||||
from .resnest import ResNeSt
|
||||
|
@ -13,5 +14,5 @@ from .vit import VisionTransformer
|
|||
__all__ = [
|
||||
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
|
||||
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
|
||||
'VisionTransformer', 'SwinTransformer'
|
||||
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,416 @@
|
|||
import math
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import (Conv2d, build_activation_layer, build_norm_layer,
|
||||
constant_init, normal_init, trunc_normal_init)
|
||||
from mmcv.cnn.bricks.drop import build_dropout
|
||||
from mmcv.cnn.bricks.transformer import MultiheadAttention
|
||||
from mmcv.runner import BaseModule, ModuleList, Sequential, _load_checkpoint
|
||||
|
||||
from ...utils import get_root_logger
|
||||
from ..builder import BACKBONES
|
||||
from ..utils import PatchEmbed, mit_convert, nchw_to_nlc, nlc_to_nchw
|
||||
|
||||
|
||||
class MixFFN(BaseModule):
|
||||
"""An implementation of MixFFN of Segformer.
|
||||
|
||||
The differences between MixFFN & FFN:
|
||||
1. Use 1X1 Conv to replace Linear layer.
|
||||
2. Introduce 3X3 Conv to encode positional information.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension. Same as
|
||||
`MultiheadAttention`. Defaults: 256.
|
||||
feedforward_channels (int): The hidden dimension of FFNs.
|
||||
Defaults: 1024.
|
||||
act_cfg (dict, optional): The activation config for FFNs.
|
||||
Default: dict(type='ReLU')
|
||||
ffn_drop (float, optional): Probability of an element to be
|
||||
zeroed in FFN. Default 0.0.
|
||||
dropout_layer (obj:`ConfigDict`): The dropout_layer used
|
||||
when adding the shortcut.
|
||||
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
feedforward_channels,
|
||||
act_cfg=dict(type='GELU'),
|
||||
ffn_drop=0.,
|
||||
dropout_layer=None,
|
||||
init_cfg=None):
|
||||
super(MixFFN, self).__init__(init_cfg)
|
||||
|
||||
self.embed_dims = embed_dims
|
||||
self.feedforward_channels = feedforward_channels
|
||||
self.act_cfg = act_cfg
|
||||
self.activate = build_activation_layer(act_cfg)
|
||||
|
||||
in_channels = embed_dims
|
||||
fc1 = Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=feedforward_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
bias=True)
|
||||
# 3x3 depth wise conv to provide positional encode information
|
||||
pe_conv = Conv2d(
|
||||
in_channels=feedforward_channels,
|
||||
out_channels=feedforward_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=(3 - 1) // 2,
|
||||
bias=True,
|
||||
groups=feedforward_channels)
|
||||
fc2 = Conv2d(
|
||||
in_channels=feedforward_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
bias=True)
|
||||
drop = nn.Dropout(ffn_drop)
|
||||
layers = [fc1, pe_conv, self.activate, drop, fc2, drop]
|
||||
self.layers = Sequential(*layers)
|
||||
self.dropout_layer = build_dropout(
|
||||
dropout_layer) if dropout_layer else torch.nn.Identity()
|
||||
|
||||
def forward(self, x, hw_shape, identity=None):
|
||||
out = nlc_to_nchw(x, hw_shape)
|
||||
out = self.layers(out)
|
||||
out = nchw_to_nlc(out)
|
||||
if identity is None:
|
||||
identity = x
|
||||
return identity + self.dropout_layer(out)
|
||||
|
||||
|
||||
class EfficientMultiheadAttention(MultiheadAttention):
|
||||
"""An implementation of Efficient Multi-head Attention of Segformer.
|
||||
|
||||
This module is modified from MultiheadAttention which is a module from
|
||||
mmcv.cnn.bricks.transformer.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The embedding dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
attn_drop (float): A Dropout layer on attn_output_weights.
|
||||
Default: 0.0.
|
||||
proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
|
||||
Default: 0.0.
|
||||
dropout_layer (obj:`ConfigDict`): The dropout_layer used
|
||||
when adding the shortcut. Default: None.
|
||||
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
|
||||
Default: None.
|
||||
batch_first (bool): Key, Query and Value are shape of
|
||||
(batch, n, embed_dim)
|
||||
or (n, batch, embed_dim). Default: False.
|
||||
qkv_bias (bool): enable bias for qkv if True. Default True.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN').
|
||||
sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
|
||||
Attention of Segformer. Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.,
|
||||
dropout_layer=None,
|
||||
init_cfg=None,
|
||||
batch_first=True,
|
||||
qkv_bias=False,
|
||||
norm_cfg=dict(type='LN'),
|
||||
sr_ratio=1):
|
||||
super().__init__(
|
||||
embed_dims,
|
||||
num_heads,
|
||||
attn_drop,
|
||||
proj_drop,
|
||||
dropout_layer=dropout_layer,
|
||||
init_cfg=init_cfg,
|
||||
batch_first=batch_first,
|
||||
bias=qkv_bias)
|
||||
|
||||
self.sr_ratio = sr_ratio
|
||||
if sr_ratio > 1:
|
||||
self.sr = Conv2d(
|
||||
in_channels=embed_dims,
|
||||
out_channels=embed_dims,
|
||||
kernel_size=sr_ratio,
|
||||
stride=sr_ratio)
|
||||
# The ret[0] of build_norm_layer is norm name.
|
||||
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
|
||||
def forward(self, x, hw_shape, identity=None):
|
||||
|
||||
x_q = x
|
||||
if self.sr_ratio > 1:
|
||||
x_kv = nlc_to_nchw(x, hw_shape)
|
||||
x_kv = self.sr(x_kv)
|
||||
x_kv = nchw_to_nlc(x_kv)
|
||||
x_kv = self.norm(x_kv)
|
||||
else:
|
||||
x_kv = x
|
||||
|
||||
if identity is None:
|
||||
identity = x_q
|
||||
|
||||
out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
|
||||
|
||||
return identity + self.dropout_layer(self.proj_drop(out))
|
||||
|
||||
|
||||
class TransformerEncoderLayer(BaseModule):
|
||||
"""Implements one encoder layer in Segformer.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
after the feed forward layer. Default 0.0.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.0.
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.0.
|
||||
qkv_bias (bool): enable bias for qkv if True.
|
||||
Default: True.
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Defalut: dict(type='GELU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN').
|
||||
batch_first (bool): Key, Query and Value are shape of
|
||||
(batch, n, embed_dim)
|
||||
or (n, batch, embed_dim). Default: False.
|
||||
init_cfg (dict, optional): Initialization config dict.
|
||||
Default:None.
|
||||
sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
|
||||
Attention of Segformer. Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
feedforward_channels,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
qkv_bias=True,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
batch_first=True,
|
||||
sr_ratio=1):
|
||||
super(TransformerEncoderLayer, self).__init__()
|
||||
|
||||
# The ret[0] of build_norm_layer is norm name.
|
||||
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
|
||||
self.attn = EfficientMultiheadAttention(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
attn_drop=attn_drop_rate,
|
||||
proj_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||
batch_first=batch_first,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_cfg=norm_cfg,
|
||||
sr_ratio=sr_ratio)
|
||||
|
||||
# The ret[0] of build_norm_layer is norm name.
|
||||
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
|
||||
self.ffn = MixFFN(
|
||||
embed_dims=embed_dims,
|
||||
feedforward_channels=feedforward_channels,
|
||||
ffn_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, x, hw_shape):
|
||||
x = self.attn(self.norm1(x), hw_shape, identity=x)
|
||||
x = self.ffn(self.norm2(x), hw_shape, identity=x)
|
||||
return x
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class MixVisionTransformer(BaseModule):
|
||||
"""The backbone of Segformer.
|
||||
|
||||
A PyTorch implement of : `SegFormer: Simple and Efficient Design for
|
||||
Semantic Segmentation with Transformers` -
|
||||
https://arxiv.org/pdf/2105.15203.pdf
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels. Default: 3.
|
||||
embed_dims (int): Embedding dimension. Default: 768.
|
||||
num_stags (int): The num of stages. Default: 4.
|
||||
num_layers (Sequence[int]): The layer number of each transformer encode
|
||||
layer. Default: [3, 4, 6, 3].
|
||||
num_heads (Sequence[int]): The attention heads of each transformer
|
||||
encode layer. Default: [1, 2, 4, 8].
|
||||
patch_sizes (Sequence[int]): The patch_size of each overlapped patch
|
||||
embedding. Default: [7, 3, 3, 3].
|
||||
strides (Sequence[int]): The stride of each overlapped patch embedding.
|
||||
Default: [4, 2, 2, 2].
|
||||
sr_ratios (Sequence[int]): The spatial reduction rate of each
|
||||
transformer encode layer. Default: [8, 4, 2, 1].
|
||||
out_indices (Sequence[int] | int): Output from which stages.
|
||||
Default: (0, 1, 2, 3).
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
|
||||
Default: 4.
|
||||
qkv_bias (bool): Enable bias for qkv if True. Default: True.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Default 0.0
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.0
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.0
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN')
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Defalut: dict(type='GELU').
|
||||
pretrain_style (str): Choose to use official or mmcls pretrain weights.
|
||||
Default: official.
|
||||
pretrained (str, optional): model pretrained path. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
embed_dims=64,
|
||||
num_stages=4,
|
||||
num_layers=[3, 4, 6, 3],
|
||||
num_heads=[1, 2, 4, 8],
|
||||
patch_sizes=[7, 3, 3, 3],
|
||||
strides=[4, 2, 2, 2],
|
||||
sr_ratios=[8, 4, 2, 1],
|
||||
out_indices=(0, 1, 2, 3),
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN', eps=1e-6),
|
||||
pretrain_style='official',
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__()
|
||||
|
||||
assert pretrain_style in [
|
||||
'official', 'mmcls'
|
||||
], 'we only support official weights or mmcls weights.'
|
||||
|
||||
if isinstance(pretrained, str) or pretrained is None:
|
||||
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
self.embed_dims = embed_dims
|
||||
|
||||
self.num_stages = num_stages
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
self.patch_sizes = patch_sizes
|
||||
self.strides = strides
|
||||
self.sr_ratios = sr_ratios
|
||||
assert num_stages == len(num_layers) == len(num_heads) \
|
||||
== len(patch_sizes) == len(strides) == len(sr_ratios)
|
||||
|
||||
self.out_indices = out_indices
|
||||
assert max(out_indices) < self.num_stages
|
||||
self.pretrain_style = pretrain_style
|
||||
self.pretrained = pretrained
|
||||
self.init_cfg = init_cfg
|
||||
|
||||
# transformer encoder
|
||||
dpr = [
|
||||
x.item()
|
||||
for x in torch.linspace(0, drop_path_rate, sum(num_layers))
|
||||
] # stochastic num_layer decay rule
|
||||
|
||||
cur = 0
|
||||
self.layers = ModuleList()
|
||||
for i, num_layer in enumerate(num_layers):
|
||||
embed_dims_i = embed_dims * num_heads[i]
|
||||
patch_embed = PatchEmbed(
|
||||
in_channels=in_channels,
|
||||
embed_dims=embed_dims_i,
|
||||
kernel_size=patch_sizes[i],
|
||||
stride=strides[i],
|
||||
padding=patch_sizes[i] // 2,
|
||||
pad_to_patch_size=False,
|
||||
norm_cfg=norm_cfg)
|
||||
layer = ModuleList([
|
||||
TransformerEncoderLayer(
|
||||
embed_dims=embed_dims_i,
|
||||
num_heads=num_heads[i],
|
||||
feedforward_channels=mlp_ratio * embed_dims_i,
|
||||
drop_rate=drop_rate,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=dpr[cur + idx],
|
||||
qkv_bias=qkv_bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
sr_ratio=sr_ratios[i]) for idx in range(num_layer)
|
||||
])
|
||||
in_channels = embed_dims_i
|
||||
# The ret[0] of build_norm_layer is norm name.
|
||||
norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
|
||||
self.layers.append(ModuleList([patch_embed, layer, norm]))
|
||||
cur += num_layer
|
||||
|
||||
def init_weights(self):
|
||||
if self.pretrained is None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_init(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
constant_init(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
constant_init(m.bias, 0)
|
||||
constant_init(m.weight, 1.0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[
|
||||
1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
normal_init(m.weight, 0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
constant_init(m.bias, 0)
|
||||
elif isinstance(self.pretrained, str):
|
||||
logger = get_root_logger()
|
||||
checkpoint = _load_checkpoint(
|
||||
self.pretrained, logger=logger, map_location='cpu')
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
elif 'model' in checkpoint:
|
||||
state_dict = checkpoint['model']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
if self.pretrain_style == 'official':
|
||||
# Because segformer backbone is not support by mmcls,
|
||||
# so we need to convert pretrain weights to match this
|
||||
# implementation.
|
||||
state_dict = mit_convert(state_dict)
|
||||
|
||||
self.load_state_dict(state_dict, False)
|
||||
|
||||
def forward(self, x):
|
||||
outs = []
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
x, H, W = layer[0](x), layer[0].DH, layer[0].DW
|
||||
hw_shape = (H, W)
|
||||
for block in layer[1]:
|
||||
x = block(x, hw_shape)
|
||||
x = layer[2](x)
|
||||
x = nlc_to_nchw(x, hw_shape)
|
||||
if i in self.out_indices:
|
||||
outs.append(x)
|
||||
|
||||
return outs
|
|
@ -628,6 +628,7 @@ class SwinTransformer(BaseModule):
|
|||
conv_type='Conv2d',
|
||||
kernel_size=patch_size,
|
||||
stride=strides[0],
|
||||
pad_to_patch_size=True,
|
||||
norm_cfg=norm_cfg if patch_norm else None,
|
||||
init_cfg=None)
|
||||
|
||||
|
|
|
@ -210,6 +210,7 @@ class VisionTransformer(BaseModule):
|
|||
conv_type='Conv2d',
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
pad_to_patch_size=True,
|
||||
norm_cfg=norm_cfg if patch_norm else None,
|
||||
init_cfg=None,
|
||||
)
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
from .ckpt_convert import swin_convert, vit_convert
|
||||
from .ckpt_convert import mit_convert, swin_convert, vit_convert
|
||||
from .embed import PatchEmbed
|
||||
from .inverted_residual import InvertedResidual, InvertedResidualV3
|
||||
from .make_divisible import make_divisible
|
||||
from .res_layer import ResLayer
|
||||
from .se_layer import SELayer
|
||||
from .self_attention_block import SelfAttentionBlock
|
||||
from .shape_convert import nchw_to_nlc, nlc_to_nchw
|
||||
from .up_conv_block import UpConvBlock
|
||||
|
||||
__all__ = [
|
||||
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
|
||||
'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'vit_convert',
|
||||
'swin_convert', 'PatchEmbed'
|
||||
'mit_convert', 'swin_convert', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw'
|
||||
]
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def swin_convert(ckpt):
|
||||
new_ckpt = OrderedDict()
|
||||
|
@ -88,3 +90,50 @@ def vit_convert(ckpt):
|
|||
new_ckpt[new_k] = v
|
||||
|
||||
return new_ckpt
|
||||
|
||||
|
||||
def mit_convert(ckpt):
|
||||
new_ckpt = OrderedDict()
|
||||
# Process the concat between q linear weights and kv linear weights
|
||||
for k, v in ckpt.items():
|
||||
if k.startswith('head'):
|
||||
continue
|
||||
elif k.startswith('patch_embed'):
|
||||
stage_i = int(k.split('.')[0].replace('patch_embed', ''))
|
||||
new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0')
|
||||
new_v = v
|
||||
if 'proj.' in new_k:
|
||||
new_k = new_k.replace('proj.', 'projection.')
|
||||
elif k.startswith('block'):
|
||||
stage_i = int(k.split('.')[0].replace('block', ''))
|
||||
new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1')
|
||||
new_v = v
|
||||
if 'attn.q.' in new_k:
|
||||
sub_item_k = k.replace('q.', 'kv.')
|
||||
new_k = new_k.replace('q.', 'attn.in_proj_')
|
||||
new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
|
||||
elif 'attn.kv.' in new_k:
|
||||
continue
|
||||
elif 'attn.proj.' in new_k:
|
||||
new_k = new_k.replace('proj.', 'attn.out_proj.')
|
||||
elif 'attn.sr.' in new_k:
|
||||
new_k = new_k.replace('sr.', 'sr.')
|
||||
elif 'mlp.' in new_k:
|
||||
string = f'{new_k}-'
|
||||
new_k = new_k.replace('mlp.', 'ffn.layers.')
|
||||
if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
|
||||
new_v = v.reshape((*v.shape, 1, 1))
|
||||
new_k = new_k.replace('fc1.', '0.')
|
||||
new_k = new_k.replace('dwconv.dwconv.', '1.')
|
||||
new_k = new_k.replace('fc2.', '4.')
|
||||
string += f'{new_k} {v.shape}-{new_v.shape}'
|
||||
# print(string)
|
||||
elif k.startswith('norm'):
|
||||
stage_i = int(k.split('.')[0].replace('norm', ''))
|
||||
new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2')
|
||||
new_v = v
|
||||
else:
|
||||
new_k = k
|
||||
new_v = v
|
||||
new_ckpt[new_k] = new_v
|
||||
return new_ckpt
|
||||
|
|
|
@ -19,6 +19,8 @@ class PatchEmbed(BaseModule):
|
|||
Default: None (Default to be equal with kernel_size).
|
||||
padding (int): The padding length of embedding conv. Default: 0.
|
||||
dilation (int): The dilation rate of embedding conv. Default: 1.
|
||||
pad_to_patch_size (bool, optional): Whether to pad feature map shape
|
||||
to multiple patch size. Default: True.
|
||||
norm_cfg (dict, optional): Config dict for normalization layer.
|
||||
init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
|
||||
Default: None.
|
||||
|
@ -32,6 +34,7 @@ class PatchEmbed(BaseModule):
|
|||
stride=16,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
pad_to_patch_size=True,
|
||||
norm_cfg=None,
|
||||
init_cfg=None):
|
||||
super(PatchEmbed, self).__init__()
|
||||
|
@ -42,7 +45,9 @@ class PatchEmbed(BaseModule):
|
|||
if stride is None:
|
||||
stride = kernel_size
|
||||
|
||||
# The default setting of patch size is eaual to kernel size.
|
||||
self.pad_to_patch_size = pad_to_patch_size
|
||||
|
||||
# The default setting of patch size is equal to kernel size.
|
||||
patch_size = kernel_size
|
||||
if isinstance(patch_size, int):
|
||||
patch_size = to_2tuple(patch_size)
|
||||
|
@ -56,7 +61,7 @@ class PatchEmbed(BaseModule):
|
|||
self.patch_size = patch_size
|
||||
|
||||
# Use conv layer to embed
|
||||
conv_type = conv_type or dict(type='Conv2d')
|
||||
conv_type = conv_type or 'Conv2d'
|
||||
self.projection = build_conv_layer(
|
||||
dict(type=conv_type),
|
||||
in_channels=in_channels,
|
||||
|
@ -73,12 +78,17 @@ class PatchEmbed(BaseModule):
|
|||
|
||||
def forward(self, x):
|
||||
H, W = x.shape[2], x.shape[3]
|
||||
if H % self.patch_size[0] != 0:
|
||||
x = F.pad(x,
|
||||
(0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
||||
if W % self.patch_size[1] != 0:
|
||||
x = F.pad(x,
|
||||
(0, self.patch_size[1] - W % self.patch_size[1], 0, 0))
|
||||
|
||||
# TODO: Process overlapping op
|
||||
if self.pad_to_patch_size:
|
||||
# Modify H, W to multiple of patch size.
|
||||
if H % self.patch_size[0] != 0:
|
||||
x = F.pad(
|
||||
x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
||||
if W % self.patch_size[1] != 0:
|
||||
x = F.pad(
|
||||
x, (0, self.patch_size[1] - W % self.patch_size[1], 0, 0))
|
||||
|
||||
x = self.projection(x)
|
||||
self.DH, self.DW = x.shape[2], x.shape[3]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
def nlc_to_nchw(x, hw_shape):
|
||||
"""Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.
|
||||
|
||||
Args:
|
||||
x (Tensor): The input tensor of shape [N, L, C] before convertion.
|
||||
hw_shape (Sequence[int]): The height and width of output feature map.
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor of shape [N, C, H, W] after convertion.
|
||||
"""
|
||||
H, W = hw_shape
|
||||
assert len(x.shape) == 3
|
||||
B, L, C = x.shape
|
||||
assert L == H * W, 'The seq_len doesn\'t match H, W'
|
||||
return x.transpose(1, 2).reshape(B, C, H, W)
|
||||
|
||||
|
||||
def nchw_to_nlc(x):
|
||||
"""Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.
|
||||
|
||||
Args:
|
||||
x (Tensor): The input tensor of shape [N, C, H, W] before convertion.
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor of shape [N, L, C] after convertion.
|
||||
"""
|
||||
assert len(x.shape) == 4
|
||||
return x.flatten(2).transpose(1, 2).contiguous()
|
|
@ -0,0 +1,60 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones import MixVisionTransformer
|
||||
from mmseg.models.backbones.mit import EfficientMultiheadAttention, MixFFN
|
||||
|
||||
|
||||
def test_mit():
|
||||
with pytest.raises(AssertionError):
|
||||
# It's only support official style and mmcls style now.
|
||||
MixVisionTransformer(pretrain_style='timm')
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# Pretrained represents pretrain url and must be str or None.
|
||||
MixVisionTransformer(pretrained=123)
|
||||
|
||||
# Test normal input
|
||||
H, W = (224, 224)
|
||||
temp = torch.randn((1, 3, H, W))
|
||||
model = MixVisionTransformer(
|
||||
embed_dims=32, num_heads=[1, 2, 5, 8], out_indices=(0, 1, 2, 3))
|
||||
model.init_weights()
|
||||
outs = model(temp)
|
||||
assert outs[0].shape == (1, 32, H // 4, W // 4)
|
||||
assert outs[1].shape == (1, 64, H // 8, W // 8)
|
||||
assert outs[2].shape == (1, 160, H // 16, W // 16)
|
||||
assert outs[3].shape == (1, 256, H // 32, W // 32)
|
||||
|
||||
# Test non-squared input
|
||||
H, W = (224, 320)
|
||||
temp = torch.randn((1, 3, H, W))
|
||||
outs = model(temp)
|
||||
assert outs[0].shape == (1, 32, H // 4, W // 4)
|
||||
assert outs[1].shape == (1, 64, H // 8, W // 8)
|
||||
assert outs[2].shape == (1, 160, H // 16, W // 16)
|
||||
assert outs[3].shape == (1, 256, H // 32, W // 32)
|
||||
|
||||
# Test MixFFN
|
||||
FFN = MixFFN(128, 512)
|
||||
hw_shape = (32, 32)
|
||||
token_len = 32 * 32
|
||||
temp = torch.randn((1, token_len, 128))
|
||||
# Self identity
|
||||
out = FFN(temp, hw_shape)
|
||||
assert out.shape == (1, token_len, 128)
|
||||
# Out identity
|
||||
outs = FFN(temp, hw_shape, temp)
|
||||
assert out.shape == (1, token_len, 128)
|
||||
|
||||
# Test EfficientMHA
|
||||
MHA = EfficientMultiheadAttention(128, 2)
|
||||
hw_shape = (32, 32)
|
||||
token_len = 32 * 32
|
||||
temp = torch.randn((1, token_len, 128))
|
||||
# Self identity
|
||||
out = MHA(temp, hw_shape)
|
||||
assert out.shape == (1, token_len, 128)
|
||||
# Out identity
|
||||
outs = MHA(temp, hw_shape, temp)
|
||||
assert out.shape == (1, token_len, 128)
|
Loading…
Reference in New Issue