# Copyright (c) OpenMMLab. All rights reserved.
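import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer

from ..builder import NECKS


class MLAModule(nn.Module):
    """Multi-level feature aggregation used by :class:`MLANeck`.

    Each input level is projected to ``out_channels`` by a 1x1 ConvModule,
    the projected maps are summed cumulatively from the deepest level to the
    shallowest, and each aggregated map is refined by a 3x3 ConvModule. The
    levels are added without resizing, so all inputs must share the same
    spatial size (as the outputs of a plain vision transformer do).

    Args:
        in_channels (List[int]): Number of input channels per level.
            Default: [1024, 1024, 1024, 1024].
        out_channels (int): Number of output channels at each level.
            Default: 256.
        norm_cfg (dict): Config dict for normalization layer in ConvModule.
            Default: None.
        act_cfg (dict): Config dict for activation layer in ConvModule.
            Default: None.
    """

    def __init__(self,
                 in_channels=[1024, 1024, 1024, 1024],
                 out_channels=256,
                 norm_cfg=None,
                 act_cfg=None):
        super(MLAModule, self).__init__()
        # 1x1 convs projecting every level to out_channels
        self.channel_proj = nn.ModuleList()
        for i in range(len(in_channels)):
            self.channel_proj.append(
                ConvModule(
                    in_channels=in_channels[i],
                    out_channels=out_channels,
                    kernel_size=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
        # 3x3 convs refining each aggregated map
        self.feat_extract = nn.ModuleList()
        for i in range(len(in_channels)):
            self.feat_extract.append(
                ConvModule(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))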

    def forward(self, inputs):
        # feat_list -> [p2, p3, p4, p5]: project every level to out_channels
        feat_list = []
        for x, conv in zip(inputs, self.channel_proj):
            feat_list.append(conv(x))

        # feat_list -> [p5, p4, p3, p2]
        # mid_list -> [m5, m4, m3, m2], where m_k = p5 + ... + p_k
        feat_list = feat_list[::-1]
        mid_list = []
        for feat in feat_list:
            if len(mid_list) == 0:
                mid_list.append(feat)
            else:
                mid_list.append(mid_list[-1] + feat)

        # mid_list -> [m5, m4, m3, m2]
        # out_list -> [o5, o4, o3, o2] (same order as mid_list)
        out_list = []
        for mid, conv in zip(mid_list, self.feat_extract):
            out_list.append(conv(mid))

        return tuple(out_list)
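

# Illustrative sketch only: ``_mla_topdown_sum_sketch`` is a hypothetical
# helper that MLANeck does not use. It restates the top-down aggregation of
# ``MLAModule.forward`` in pure Python, assuming each projected level is
# flattened to a list of floats of equal length (MLA adds levels without
# resizing, so all levels must share one spatial size). Given levels ordered
# shallow to deep, [p2, p3, p4, p5], it returns [m5, m4, m3, m2] with
# m_k = p5 + ... + p_k, the order in which ``feat_extract`` consumes them.
# For example, [[1.], [2.], [3.], [4.]] gives [[4.], [7.], [9.], [10.]].
def _mla_topdown_sum_sketch(feats):
    mids = []
    for feat in reversed(feats):  # deepest level first
        if not mids:
            mids.append(list(feat))
        else:
            # element-wise sum of the running total and the current level
            mids.append([a + b for a, b in zip(mids[-1], feat)])
    return mids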


@NECKS.register_module()
class MLANeck(nn.Module):
    """Multi-level Feature Aggregation.

    This neck implements the Multi-level Feature Aggregation (MLA)
    construction of `SETR <https://arxiv.org/abs/2012.15840>`_.
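
    Example:
        An illustrative config; the channel values mirror the ``MLAModule``
        defaults and the SyncBN/ReLU choices are assumptions. Given four
        inputs of shape (N, 1024, H, W), the neck returns a tuple of four
        tensors of shape (N, 256, H, W).

        >>> neck = dict(
        ...     type='MLANeck',
        ...     in_channels=[1024, 1024, 1024, 1024],
        ...     out_channels=256,
        ...     norm_cfg=dict(type='SyncBN', requires_grad=True),
        ...     act_cfg=dict(type='ReLU'))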

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        norm_layer (dict): Config dict for the normalization applied to each
            input level before MLA.
            Default: dict(type='LN', eps=1e-6, requires_grad=True).
        norm_cfg (dict): Config dict for normalization layer in ConvModule.
            Default: None.
        act_cfg (dict): Config dict for activation layer in ConvModule.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 norm_layer=dict(type='LN', eps=1e-6, requires_grad=True),
                 norm_cfg=None,
                 act_cfg=None):
        super(MLANeck, self).__init__()
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels

        # MLA is implemented as a neck rather than inside the backbone so
        # that the vision transformer backbone stays general.
        # One norm layer per input level.
        self.norm = nn.ModuleList([
            build_norm_layer(norm_layer, in_channels[i])[1]
            for i in range(len(in_channels))
        ])

        self.mla = MLAModule(
            in_channels=in_channels,
            out_channels=out_channels,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, inputs):
        assert len(inputs) == len(self.in_channels)

        # Convert each level from NCHW to NLC, apply its norm layer (with the
        # default LN this normalizes every token across channels), then
        # convert back to NCHW.
        outs = []
        for i in range(len(inputs)):
            x = inputs[i]
            n, c, h, w = x.shape
            x = x.reshape(n, c, h * w).transpose(2, 1).contiguous()
            x = self.norm[i](x)
            x = x.transpose(1, 2).reshape(n, c, h, w).contiguous()
            outs.append(x)

        outs = self.mla(outs)
        return tuple(outs)
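

# Illustrative sketch only: ``_channel_layernorm_sketch`` is a hypothetical
# helper that MLANeck does not use. It shows, for one sample in pure Python,
# what the NCHW -> NLC -> LayerNorm -> NCHW sequence in ``MLANeck.forward``
# computes with the default ``norm_layer``: each of the H * W tokens is a
# length-C vector, so every spatial location is normalized across channels
# independently. It assumes the LayerNorm affine parameters still hold their
# initial values (weight 1, bias 0) and uses eps=1e-6 as in the default.
# For example, [[[1., 5.]], [[3., 5.]]] gives roughly [[[-1., 0.]], [[1., 0.]]].
def _channel_layernorm_sketch(x, eps=1e-6):
    """Normalize a nested list of shape [C][H][W]; returns the same shape."""
    c, h, w = len(x), len(x[0]), len(x[0][0])
    out = [[[0.0] * w for _ in range(h)] for _ in range(c)]
    for i in range(h):
        for j in range(w):
            token = [x[k][i][j] for k in range(c)]  # one row of the NLC view
            mean = sum(token) / c
            # biased variance, as used by LayerNorm
            var = sum((v - mean) ** 2 for v in token) / c
            for k in range(c):
                out[k][i][j] = (token[k] - mean) / (var + eps) ** 0.5
    return out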