Sixiao Zheng ec91893931
[Feature] Official implementation of SETR (#531)
* Adjust vision transformer backbone architectures;

* Add DropPath, trunc_normal_ for VisionTransformer implementation;

* Add class token during intermediate period and remove it during final period;

* Fix some parameters loss bug;

* Store intermediate token features and impose no processes on them;

* Remove class token and reshape entire token feature from NLC to NCHW;

* Fix some doc error

* Add an arg for the VisionTransformer backbone to control whether the class token is input into the transformer;

* Add stochastic depth decay rule for DropPath;

* Fix output bug when input_cls_token=False;

* Add related unit test;

* Re-implement SETR

* Add two heads -- SETRUPHead (Naive, PUP) & SETRMLAHead (MLA);

* Modify some docs of SETR heads;

* Add MLA auxiliary head of SETR;

* Modify some args of SETR heads;

* Add unit test for SETR heads;

* Add 768x768 cityscapes dataset config;

* Add Backbone: SETR -- Backbone: MLA, PUP, Naive;

* Add SETR cityscapes training & testing config;

* Fix the low code coverage of unit tests for SETR heads;

* Remove some redundant error capture;

* Add pascal context dataset & ade20k dataset config;

* Modify auxiliary head related config;

* Modify folder structure.

* add setr

* modify vit

* Fix the test_cfg arg position;

* Fix some learning schedule bug;

* optimize setr code

* Add arg: final_reshape to control if converting output feature information from NLC to NCHW;

* Fix the default value of final_reshape;

* Modify arg: final_reshape to arg: out_shape;

* Fix some unit test bug;

* Add MLA neck;

* Modify setr configs to add MLA neck;

* Modify MLA decode head to remove redundant structure;

* Remove some redundant files.

* Fix the code style bug;

* Remove some redundant files;

* Modify some unit tests of SETR;

* Ignore CityscapesCoarseDataset and MapillaryDataset.

* Fix the activation function loss bug;

* Fix the img_size bug of SETR_PUP_ADE20K

* Fix the lint bug of transformers.py;

* Add mla neck unit test;

* Convert the output shape of the SETR ViT from NLC to NCHW.

* Modify Resize action of data pipeline;

* Fix deit related bug;

* Set find_unused_parameters=False for pascal context dataset;

* Remove arg: find_unused_parameters which is False by default.

* Error auxiliary head of PUP deit

* Remove the minimal restrict of slide inference.

* Modify doc string of Resize

* Separate this part of code into a new PR #544

* Remove some redundant code;

* Modify unit tests of SETR heads;

* Fix the tuple in_channels of mla_deit.

* Modify code style

* Move detailed definition of auxiliary head into model config dict;

* Add some setr config for default cityscapes.py;

* Fix the doc string of SETR head;

* Modify implementation of SETR Heads

* Remove setr aux head and use fcn head to replace it;

* Remove arg: img_size and remove last interpolate op of heads;

* Rename arg: conv3x3_conv1x1 to kernel_size of SETRUPHead;

* non-square input support for setr heads

* Modify config argument for above commits

* Remove norm_layer argument of SETRMLAHead

* Add mla_align_corners for MLAModule interpolate

* [Refactor]Refactor of SETRMLAHead

* Modify Head implementation;

* Modify Head unit test;

* Modify related config file;

* [Refactor]MLA Neck

* Fix config bug

* [Refactor]SETR Naive Head and SETR PUP Head

* [Fix]Fix the lack of arg: act_cfg and arg: norm_cfg

* Fix config error

* Refactor of SETR MLA, Naive, PUP heads.

* Modify some attribute name of SETR Heads.

* Modify setr configs to adapt new vit code.

* Fix trunc_normal_ bug

* Parameters init adjustment.

* Remove redundant doc string of SETRUPHead

* Fix pretrained bug

* [Fix] Fix vit init bug

* Add some vit unit tests

* Modify module import

* Remove norm from PatchEmbed

* Fix pretrain weights bug

* Modify pretrained judge

* Fix some gradient backward bugs.

* Add some unit tests to improve code cov

* Fix init_weights of setr up head

* Add DropPath in FFN

* Finish benchmark of SETR

1. Add benchmark information into README.MD of SETR;

2. Fix some name bugs of vit;

* Remove DropPath implementation and use DropPath from mmcv.

* Modify out_indices arg

* Fix out_indices bug.

* Remove cityscapes base dataset config.

Co-authored-by: sennnnn <201730271412@mail.scut.edu.cn>
Co-authored-by: CuttlefishXuan <zhaoxinxuan1997@gmail.com>
2021-06-23 09:39:29 -07:00
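The `MLAModule` in the file below aggregates levels top-down: after projecting every level to `out_channels`, it starts from the deepest level and adds each shallower level to the running sum of all deeper ones. A minimal sketch of just that summation in plain Python follows, assuming each feature map can be modeled as a flat tuple of floats of identical length (a stand-in for tensors of identical shape) and leaving out the 1x1 and 3x3 `ConvModule`s entirely; `add_maps` and `top_down_aggregate` are hypothetical names, not part of mmseg.

```python
from itertools import accumulate
from typing import List, Sequence, Tuple

# Hypothetical stand-in for a tensor: a flat tuple of floats.
FeatureMap = Tuple[float, ...]


def add_maps(a: FeatureMap, b: FeatureMap) -> FeatureMap:
    """Element-wise sum of two equally sized feature maps."""
    assert len(a) == len(b), "all levels must share one shape"
    return tuple(x + y for x, y in zip(a, b))


def top_down_aggregate(projected: Sequence[FeatureMap]) -> List[FeatureMap]:
    """Cumulative top-down sum over levels given as [p2, p3, p4, p5].

    Returns [m5, m4, m3, m2] with m5 = p5 and m_k = m_{k+1} + p_k,
    the same order in which MLAModule.forward builds mid_list.
    """
    return list(accumulate(reversed(projected), add_maps))


p2, p3, p4, p5 = (1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (4.0, 4.0)
mids = top_down_aggregate([p2, p3, p4, p5])
print(mids)  # [(4.0, 4.0), (7.0, 7.0), (9.0, 9.0), (10.0, 10.0)]
```

Because the projected list is reversed before summing, the running sums, and therefore the 3x3 outputs computed from them, come out deepest level first. The element-wise sum is also why every input level must have the same spatial size.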

import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer

from ..builder import NECKS


class MLAModule(nn.Module):
    """Multi-level aggregation module used by :class:`MLANeck`.

    Each input level is projected to ``out_channels`` by a 1x1 ConvModule,
    the projected levels are summed top-down (deepest level first), and each
    running sum is refined by a 3x3 ConvModule. All input levels must share
    the same spatial size because they are summed element-wise. Outputs are
    returned deepest level first.
    """

    def __init__(self,
                 in_channels=[1024, 1024, 1024, 1024],
                 out_channels=256,
                 norm_cfg=None,
                 act_cfg=None):
        super(MLAModule, self).__init__()
        # 1x1 projections, one per level: in_channels[i] -> out_channels
        self.channel_proj = nn.ModuleList()
        for i in range(len(in_channels)):
            self.channel_proj.append(
                ConvModule(
                    in_channels=in_channels[i],
                    out_channels=out_channels,
                    kernel_size=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
        # 3x3 refinements applied to the aggregated features
        self.feat_extract = nn.ModuleList()
        for i in range(len(in_channels)):
            self.feat_extract.append(
                ConvModule(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

    def forward(self, inputs):
        # feat_list -> [p2, p3, p4, p5]
        feat_list = []
        for x, conv in zip(inputs, self.channel_proj):
            feat_list.append(conv(x))

        # feat_list -> [p5, p4, p3, p2]
        # mid_list -> [m5, m4, m3, m2]
        feat_list = feat_list[::-1]
        mid_list = []
        for feat in feat_list:
            if len(mid_list) == 0:
                mid_list.append(feat)
            else:
                mid_list.append(mid_list[-1] + feat)

        # mid_list -> [m5, m4, m3, m2]
        # out_list -> [o5, o4, o3, o2]
        out_list = []
        for mid, conv in zip(mid_list, self.feat_extract):
            out_list.append(conv(mid))

        return tuple(out_list)


@NECKS.register_module()
class MLANeck(nn.Module):
    """Multi-level Feature Aggregation.

    The Multi-level Feature Aggregation construction of SETR:
    https://arxiv.org/pdf/2012.15840.pdf

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        norm_layer (dict): Config dict for the normalization applied to each
            input (over its channels) before aggregation.
            Default: dict(type='LN', eps=1e-6, requires_grad=True).
        norm_cfg (dict): Config dict for the normalization layers inside the
            ConvModules of MLAModule. Default: None.
        act_cfg (dict): Config dict for activation layer in ConvModule.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 norm_layer=dict(type='LN', eps=1e-6, requires_grad=True),
                 norm_cfg=None,
                 act_cfg=None):
        super(MLANeck, self).__init__()
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels

        # MLA lives in the neck rather than in the backbone so that the
        # vision transformer backbone stays general-purpose.
        self.norm = nn.ModuleList([
            build_norm_layer(norm_layer, in_channels[i])[1]
            for i in range(len(in_channels))
        ])

        self.mla = MLAModule(
            in_channels=in_channels,
            out_channels=out_channels,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, inputs):
        assert len(inputs) == len(self.in_channels)

        # Convert each input from NCHW to NLC, normalize every token over
        # its channels, then convert back to NCHW.
        outs = []
        for i in range(len(inputs)):
            x = inputs[i]
            n, c, h, w = x.shape
            x = x.reshape(n, c, h * w).transpose(2, 1).contiguous()
            x = self.norm[i](x)
            x = x.transpose(1, 2).reshape(n, c, h, w).contiguous()
            outs.append(x)

        outs = self.mla(outs)
        return tuple(outs)
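`MLANeck.forward` reshapes each (N, C, H, W) input to (N, H*W, C) so that the LayerNorm that `build_norm_layer` creates from `dict(type='LN', ...)` normalizes the C channels of every spatial position independently, and then reshapes the result back. The sketch below reproduces that round trip in plain Python for a batch of one, assuming nested lists in place of tensors and omitting LayerNorm's learned weight and bias (which PyTorch initializes to ones and zeros, so they do not change the result before training). All helper names here are hypothetical.

```python
import math
from typing import List

# Batch size 1: a feature map is indexed as x[channel][row][col].
Grid = List[List[List[float]]]
Tokens = List[List[float]]  # indexed as tokens[position][channel]


def nchw_to_nlc(x: Grid) -> Tokens:
    """Flatten (C, H, W) into H*W tokens of length C, row-major like reshape."""
    c, h, w = len(x), len(x[0]), len(x[0][0])
    return [[x[k][i][j] for k in range(c)] for i in range(h) for j in range(w)]


def nlc_to_nchw(tokens: Tokens, h: int, w: int) -> Grid:
    """Inverse of nchw_to_nlc: token i * w + j goes back to position (i, j)."""
    c = len(tokens[0])
    return [[[tokens[i * w + j][k] for j in range(w)] for i in range(h)]
            for k in range(c)]


def layer_norm(token: List[float], eps: float = 1e-6) -> List[float]:
    """LayerNorm over one token's channels, without the learned affine."""
    mean = sum(token) / len(token)
    var = sum((v - mean) ** 2 for v in token) / len(token)  # biased variance
    return [(v - mean) / math.sqrt(var + eps) for v in token]


def norm_like_mla_neck(x: Grid, eps: float = 1e-6) -> Grid:
    """NCHW -> NLC, normalize each token over channels, then NLC -> NCHW."""
    h, w = len(x[0]), len(x[0][0])
    tokens = [layer_norm(t, eps) for t in nchw_to_nlc(x)]
    return nlc_to_nchw(tokens, h, w)


# Two channels on a 1x2 grid: the tokens are [1.0, 3.0] and [3.0, 1.0].
x = [[[1.0, 3.0]], [[3.0, 1.0]]]
y = norm_like_mla_neck(x)
print([[[round(v, 4) for v in row] for row in ch] for ch in y])
# [[[-1.0, 1.0]], [[1.0, -1.0]]]
```

The transposes exist only to move C into the last dimension, which is where `nn.LayerNorm(C)` expects its normalized shape; the statistics are therefore computed per spatial position, not over the whole map.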