mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
* Adjust vision transformer backbone architectures; * Add DropPath, trunc_normal_ for VisionTransformer implementation; * Add class token during intermediate period and remove it during final period; * Fix some parameters loss bug; * * Store intermediate token features and impose no processes on them; * Remove class token and reshape entire token feature from NLC to NCHW; * Fix some doc error * Add an arg for VisionTransformer backbone to control if input class token into transformer; * Add stochastic depth decay rule for DropPath; * * Fix output bug when input_cls_token=False; * Add related unit test; * Re-implement of SETR * Add two head -- SETRUPHead (Naive, PUP) & SETRMLAHead (MLA); * * Modify some docs of heads of SETR; * Add MLA auxiliary head of SETR; * * Modify some arg of setr heads; * Add unit test for setr heads; * * Add 768x768 cityscapes dataset config; * Add Backbone: SETR -- Backbone: MLA, PUP, Naive; * Add SETR cityscapes training & testing config; * * Fix the low code coverage of unit test about heads of setr; * Remove some redundant error capture; * * Add pascal context dataset & ade20k dataset config; * Modify auxiliary head relative config; * Modify folder structure. * add setr * modify vit * Fix the test_cfg arg position; * Fix some learning schedule bug; * optimize setr code * Add arg: final_reshape to control if converting output feature information from NLC to NCHW; * Fix the default value of final_reshape; * Modify arg: final_reshape to arg: out_shape; * Fix some unit test bug; * Add MLA neck; * Modify setr configs to add MLA neck; * Modify MLA decode head to remove redundant structure; * Remove some redundant files. * * Fix the code style bug; * Remove some redundant files; * Modify some unit tests of SETR; * Ignoring CityscapesCoarseDataset and MapillaryDataset. 
* Fix the activation function loss bug; * Fix the img_size bug of SETR_PUP_ADE20K * * Fix the lint bug of transformers.py; * Add mla neck unit test; * Convert vit of setr out shape from NLC to NCHW. * * Modify Resize action of data pipeline; * Fix deit related bug; * Set find_unused_parameters=False for pascal context dataset; * Remove arg: find_unused_parameters which is False by default. * Error auxiliary head of PUP deit * Remove the minimal restriction of slide inference. * Modify doc string of Resize * Separate this part of code to a new PR #544 * * Remove some redundant code; * Modify unit tests of SETR heads; * Fix the tuple in_channels of mla_deit. * Modify code style * Move detailed definition of auxiliary head into model config dict; * Add some setr config for default cityscapes.py; * Fix the doc string of SETR head; * Modify implementation of SETR Heads * Remove setr aux head and use fcn head to replace it; * Remove arg: img_size and remove last interpolate op of heads; * Rename arg: conv3x3_conv1x1 to kernel_size of SETRUPHead; * non-square input support for setr heads * Modify config argument for above commits * Remove norm_layer argument of SETRMLAHead * Add mla_align_corners for MLAModule interpolate * [Refactor]Refactor of SETRMLAHead * Modify Head implementation; * Modify Head unit test; * Modify related config file; * [Refactor]MLA Neck * Fix config bug * [Refactor]SETR Naive Head and SETR PUP Head * [Fix]Fix the lack of arg: act_cfg and arg: norm_cfg * Fix config error * Refactor of SETR MLA, Naive, PUP heads. * Modify some attribute name of SETR Heads. * Modify setr configs to adapt new vit code. * Fix trunc_normal_ bug * Parameters init adjustment. * Remove redundant doc string of SETRUPHead * Fix pretrained bug * [Fix] Fix vit init bug * Add some vit unit tests * Modify module import * Remove norm from PatchEmbed * Fix pretrain weights bug * Modify pretrained judge * Fix some gradient backward bugs. 
* Add some unit tests to improve code cov * Fix init_weights of setr up head * Add DropPath in FFN * Finish benchmark of SETR 1. Add benchmark information into README.MD of SETR; 2. Fix some name bugs of vit; * Remove DropPath implementation and use DropPath from mmcv. * Modify out_indices arg * Fix out_indices bug. * Remove cityscapes base dataset config. Co-authored-by: sennnnn <201730271412@mail.scut.edu.cn> Co-authored-by: CuttlefishXuan <zhaoxinxuan1997@gmail.com>
118 lines
3.7 KiB
Python
118 lines
3.7 KiB
Python
import torch.nn as nn
|
|
from mmcv.cnn import ConvModule, build_norm_layer
|
|
|
|
from ..builder import NECKS
|
|
|
|
|
|
class MLAModule(nn.Module):
    """Multi-level feature aggregation module.

    Each input level is projected to ``out_channels`` by a 1x1 ConvModule.
    The projected features are then aggregated top-down (deepest level
    first) with a running element-wise sum, and every aggregated feature is
    refined by a 3x3 ConvModule. No resizing is performed, so the projected
    features must be shape-compatible for element-wise addition.

    Args:
        in_channels (Sequence[int]): Number of channels of each input level,
            in the same order as the inputs passed to ``forward`` (labelled
            p2..p5 in the comments below).
            Default: (1024, 1024, 1024, 1024).
        out_channels (int): Number of output channels at every level.
            Default: 256.
        norm_cfg (dict, optional): Config dict for the normalization layer
            of each ConvModule. Default: None.
        act_cfg (dict, optional): Config dict for the activation layer of
            each ConvModule. Default: None.
    """

    def __init__(self,
                 in_channels=(1024, 1024, 1024, 1024),
                 out_channels=256,
                 norm_cfg=None,
                 act_cfg=None):
        super(MLAModule, self).__init__()
        # 1x1 projection of every level to a common channel width.
        self.channel_proj = nn.ModuleList([
            ConvModule(
                in_channels=channels,
                out_channels=out_channels,
                kernel_size=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg) for channels in in_channels
        ])
        # 3x3 refinement, one per level, applied to the aggregated features.
        self.feat_extract = nn.ModuleList([
            ConvModule(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg) for _ in in_channels
        ])

    def forward(self, inputs):
        """Aggregate multi-level features.

        Args:
            inputs (Sequence[Tensor]): One feature map per level, in the same
                order as ``in_channels``. Pairing uses ``zip``, so processing
                stops at the shorter of ``inputs`` and the configured levels.

        Returns:
            tuple[Tensor]: One refined feature per level, ordered from the
            deepest level to the shallowest, i.e. ``(o5, o4, o3, o2)``.
        """
        # Project every level: projected -> [p2, p3, p4, p5].
        projected = [conv(x) for x, conv in zip(inputs, self.channel_proj)]

        # Top-down running sum over [p5, p4, p3, p2]:
        # aggregated -> [m5, m4, m3, m2] where m5 = p5 and m_i = m_{i+1} + p_i.
        aggregated = []
        for feat in reversed(projected):
            if aggregated:
                feat = aggregated[-1] + feat
            aggregated.append(feat)

        # Refine each aggregated map. NOTE: the result keeps the deepest-first
        # order of ``aggregated`` -> [o5, o4, o3, o2] (feat_extract[0] is
        # applied to m5); it is intentionally not re-reversed.
        return tuple(
            conv(mid) for mid, conv in zip(aggregated, self.feat_extract))
|
|
|
|
|
|
@NECKS.register_module()
class MLANeck(nn.Module):
    """Multi-level Feature Aggregation.

    The Multi-level Feature Aggregation construction of SETR:
    https://arxiv.org/pdf/2012.15840.pdf

    Each input feature map is first normalized with ``norm_layer`` (the map
    is temporarily flattened from NCHW to NLC for this), and the normalized
    maps are then fused by an ``MLAModule``.

    Args:
        in_channels (List[int] | Tuple[int]): Number of input channels per
            scale.
        out_channels (int): Number of output channels (used at each scale).
        norm_layer (dict, optional): Config dict for input normalization,
            applied to each scale before aggregation. ``None`` selects the
            default. Default: dict(type='LN', eps=1e-6, requires_grad=True).
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        act_cfg (dict, optional): Config dict for activation layer in
            ConvModule. Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 norm_layer=None,
                 norm_cfg=None,
                 act_cfg=None):
        super(MLANeck, self).__init__()
        assert isinstance(in_channels, (list, tuple))
        if norm_layer is None:
            # Built per call so no mutable default is shared across instances.
            norm_layer = dict(type='LN', eps=1e-6, requires_grad=True)
        self.in_channels = list(in_channels)
        self.out_channels = out_channels

        # In order to build general vision transformer backbone, we have to
        # move MLA to neck.
        self.norm = nn.ModuleList([
            build_norm_layer(norm_layer, channels)[1]
            for channels in self.in_channels
        ])

        self.mla = MLAModule(
            in_channels=self.in_channels,
            out_channels=out_channels,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, inputs):
        """Normalize every scale and aggregate them with MLA.

        Args:
            inputs (Sequence[Tensor]): One NCHW feature map per scale; its
                length must equal ``len(self.in_channels)``.

        Returns:
            tuple[Tensor]: The per-level outputs of ``self.mla``.
        """
        assert len(inputs) == len(self.in_channels)

        outs = []
        for x, norm in zip(inputs, self.norm):
            n, c, h, w = x.shape
            # NCHW -> NLC, so a LayerNorm-style layer normalizes over the
            # channel (last) dimension.
            x = x.reshape(n, c, h * w).transpose(2, 1).contiguous()
            x = norm(x)
            # NLC -> NCHW for the convolutional MLA module.
            x = x.transpose(1, 2).reshape(n, c, h, w).contiguous()
            outs.append(x)

        return tuple(self.mla(outs))
|