[Feature] Official implementation of SETR (#531)

* Adjust vision transformer backbone architectures;

* Add DropPath, trunc_normal_ for VisionTransformer implementation;

* Add class token during the intermediate period and remove it during the final period;

* Fix a bug that caused some parameters to be lost;

* Store intermediate token features and impose no processing on them;

* Remove the class token and reshape the entire token feature from NLC to NCHW;

* Fix some doc errors

* Add an arg for the VisionTransformer backbone to control whether the class token is fed into the transformer;

* Add stochastic depth decay rule for DropPath;

* Fix output bug when input_cls_token=False;

* Add related unit tests;

* Re-implementation of SETR

* Add two heads -- SETRUPHead (Naive, PUP) & SETRMLAHead (MLA);

* Modify some docs of SETR heads;

* Add MLA auxiliary head of SETR;

* Modify some args of SETR heads;

* Add unit tests for SETR heads;

* Add 768x768 Cityscapes dataset config;

* Add SETR backbones: MLA, PUP, Naive;

* Add SETR Cityscapes training & testing configs;

* Improve the low unit test coverage of the SETR heads;

* Remove some redundant error capturing;

* Add Pascal Context & ADE20K dataset configs;

* Modify auxiliary-head-related config;

* Modify folder structure.

* Add SETR

* Modify ViT

* Fix the test_cfg arg position;

* Fix some learning schedule bugs;

* Optimize SETR code

* Add arg: final_reshape to control whether output features are converted from NLC to NCHW;

* Fix the default value of final_reshape;

* Modify arg: final_reshape to arg: out_shape;

* Fix some unit test bugs;

* Add MLA neck;

* Modify setr configs to add MLA neck;

* Modify MLA decode head to remove redundant structure;

* Remove some redundant files.

* Fix the code style bug;

* Remove some redundant files;

* Modify some unit tests of SETR;

* Ignoring CityscapesCoarseDataset and MapillaryDataset.

* Fix a bug where the activation function was lost;

* Fix the img_size bug of SETR_PUP_ADE20K

* Fix the lint bug in transformers.py;

* Add MLA neck unit test;

* Convert the SETR ViT output shape from NLC to NCHW.

* Modify the Resize action of the data pipeline;

* Fix DeiT-related bug;

* Set find_unused_parameters=False for the Pascal Context dataset;

* Remove arg: find_unused_parameters which is False by default.

* Fix error in the auxiliary head of PUP DeiT

* Remove the minimum size restriction of slide inference.

* Modify doc string of Resize

* Separate this part of code into a new PR #544

* Remove some redundant code;

* Modify unit tests of SETR heads;

* Fix the tuple in_channels of mla_deit.

* Modify code style

* Move detailed definition of auxiliary head into model config dict;

* Add some SETR configs for the default cityscapes.py;

* Fix the doc string of SETR head;

* Modify implementation of SETR Heads

* Remove the SETR aux head and replace it with the FCN head;

* Remove arg: img_size and remove last interpolate op of heads;

* Rename arg: conv3x3_conv1x1 to kernel_size of SETRUPHead;

* Add non-square input support for SETR heads

* Modify config argument for above commits

* Remove norm_layer argument of SETRMLAHead

* Add mla_align_corners for MLAModule interpolate

* [Refactor] Refactor of SETRMLAHead

* Modify Head implementation;

* Modify Head unit test;

* Modify related config file;

* [Refactor] MLA Neck

* Fix config bug

* [Refactor] SETR Naive Head and SETR PUP Head

* [Fix] Fix the missing args: act_cfg and norm_cfg

* Fix config error

* Refactor of SETR MLA, Naive, PUP heads.

* Modify some attribute name of SETR Heads.

* Modify SETR configs to adapt to the new ViT code.

* Fix trunc_normal_ bug

* Parameters init adjustment.

* Remove redundant doc string of SETRUPHead

* Fix pretrained bug

* [Fix] Fix vit init bug

* Add some vit unit tests

* Modify module import

* Remove norm from PatchEmbed

* Fix pretrained weights bug

* Modify the pretrained weight check

* Fix some gradient backward bugs.

* Add some unit tests to improve code coverage

* Fix init_weights of setr up head

* Add DropPath in FFN

* Finish benchmark of SETR

1. Add benchmark information into README.md of SETR;

2. Fix some naming bugs of ViT;

* Remove DropPath implementation and use DropPath from mmcv.

* Modify out_indices arg

* Fix out_indices bug.

* Remove Cityscapes base dataset config.

Co-authored-by: sennnnn <201730271412@mail.scut.edu.cn>
Co-authored-by: CuttlefishXuan <zhaoxinxuan1997@gmail.com>
Sixiao Zheng 2021-06-24 00:39:29 +08:00 committed by GitHub
parent 3e70d93285
commit ec91893931
21 changed files with 914 additions and 101 deletions


@ -0,0 +1,96 @@
# model settings
backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=\
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', # noqa
backbone=dict(
type='VisionTransformer',
img_size=(768, 768),
patch_size=16,
in_channels=3,
embed_dims=1024,
num_layers=24,
num_heads=16,
out_indices=(5, 11, 17, 23),
drop_rate=0.1,
norm_cfg=backbone_norm_cfg,
with_cls_token=False,
interpolate_mode='bilinear',
),
neck=dict(
type='MLANeck',
in_channels=[1024, 1024, 1024, 1024],
out_channels=256,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU'),
),
decode_head=dict(
type='SETRMLAHead',
in_channels=(256, 256, 256, 256),
channels=512,
in_index=(0, 1, 2, 3),
dropout_ratio=0,
mla_channels=128,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=[
dict(
type='FCNHead',
in_channels=256,
channels=256,
in_index=0,
dropout_ratio=0,
num_convs=0,
kernel_size=1,
concat_input=False,
num_classes=19,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='FCNHead',
in_channels=256,
channels=256,
in_index=1,
dropout_ratio=0,
num_convs=0,
kernel_size=1,
concat_input=False,
num_classes=19,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='FCNHead',
in_channels=256,
channels=256,
in_index=2,
dropout_ratio=0,
num_convs=0,
kernel_size=1,
concat_input=False,
num_classes=19,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='FCNHead',
in_channels=256,
channels=256,
in_index=3,
dropout_ratio=0,
num_convs=0,
kernel_size=1,
concat_input=False,
num_classes=19,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
],
train_cfg=dict(),
test_cfg=dict(mode='whole'))
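
The block above is a base model config meant to be inherited and overridden rather than used directly. A minimal sketch of that workflow, assuming mmcv is installed and the file is saved as `configs/_base_/models/setr_mla.py` (the path is an assumption inferred from the `_base_` references in the dataset configs further below):

```python
# Hypothetical usage sketch: load the base config and override dataset-specific
# fields, mirroring what the ADE20K configs below do via `_base_` inheritance.
from mmcv import Config

cfg = Config.fromfile('configs/_base_/models/setr_mla.py')  # assumed path
cfg.model.decode_head.num_classes = 150          # e.g. ADE20K
for aux_head in cfg.model.auxiliary_head:
    aux_head.num_classes = 150
print(cfg.model.backbone.out_indices)            # (5, 11, 17, 23)
```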


@ -0,0 +1,81 @@
# model settings
backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=\
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', # noqa
backbone=dict(
type='VisionTransformer',
img_size=(768, 768),
patch_size=16,
in_channels=3,
embed_dims=1024,
num_layers=24,
num_heads=16,
out_indices=(9, 14, 19, 23),
drop_rate=0.1,
norm_cfg=backbone_norm_cfg,
with_cls_token=True,
interpolate_mode='bilinear',
),
decode_head=dict(
type='SETRUPHead',
in_channels=1024,
channels=256,
in_index=3,
num_classes=19,
dropout_ratio=0,
norm_cfg=norm_cfg,
num_convs=1,
up_scale=4,
kernel_size=1,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=[
dict(
type='SETRUPHead',
in_channels=1024,
channels=256,
in_index=0,
num_classes=19,
dropout_ratio=0,
norm_cfg=norm_cfg,
num_convs=1,
up_scale=4,
kernel_size=1,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='SETRUPHead',
in_channels=1024,
channels=256,
in_index=1,
num_classes=19,
dropout_ratio=0,
norm_cfg=norm_cfg,
num_convs=1,
up_scale=4,
kernel_size=1,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='SETRUPHead',
in_channels=1024,
channels=256,
in_index=2,
num_classes=19,
dropout_ratio=0,
norm_cfg=norm_cfg,
num_convs=1,
up_scale=4,
kernel_size=1,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4))
],
train_cfg=dict(),
test_cfg=dict(mode='whole'))
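
A quick arithmetic check of the Naive decode head above, using only values from this config; the final resize of logits to the label size is standard BaseDecodeHead behaviour in mmseg, not something defined here:

```python
# With patch_size=16 the ViT feature map is 1/16 of the 768x768 input; the
# Naive head applies one conv plus a single 4x upsample, and the remaining
# upsampling to full resolution happens when the logits are resized to the
# label size for the loss / final prediction.
img_size, patch_size, num_convs, up_scale = 768, 16, 1, 4
feat = img_size // patch_size              # 48
head_out = feat * up_scale ** num_convs    # 192
print(feat, head_out)                      # 48 192
```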


@ -0,0 +1,81 @@
# model settings
backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=\
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', # noqa
backbone=dict(
type='VisionTransformer',
img_size=(768, 768),
patch_size=16,
in_channels=3,
embed_dims=1024,
num_layers=24,
num_heads=16,
out_indices=(9, 14, 19, 23),
drop_rate=0.1,
norm_cfg=backbone_norm_cfg,
with_cls_token=True,
interpolate_mode='bilinear',
),
decode_head=dict(
type='SETRUPHead',
in_channels=1024,
channels=256,
in_index=3,
num_classes=19,
dropout_ratio=0,
norm_cfg=norm_cfg,
num_convs=4,
up_scale=2,
kernel_size=3,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=[
dict(
type='SETRUPHead',
in_channels=1024,
channels=256,
in_index=0,
num_classes=19,
dropout_ratio=0,
norm_cfg=norm_cfg,
num_convs=1,
up_scale=4,
kernel_size=3,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='SETRUPHead',
in_channels=1024,
channels=256,
in_index=1,
num_classes=19,
dropout_ratio=0,
norm_cfg=norm_cfg,
num_convs=1,
up_scale=4,
kernel_size=3,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='SETRUPHead',
in_channels=1024,
channels=256,
in_index=2,
num_classes=19,
dropout_ratio=0,
norm_cfg=norm_cfg,
num_convs=1,
up_scale=4,
kernel_size=3,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
],
train_cfg=dict(),
test_cfg=dict(mode='whole'))
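
The PUP head above instead recovers the full resolution inside the head; a small arithmetic sketch with the values from this config:

```python
# Progressive upsampling: num_convs=4 conv + 2x-upsample stages applied to the
# 1/16-resolution ViT feature map bring it back to the 768x768 input size.
img_size, patch_size, num_convs, up_scale = 768, 16, 4, 2
feat = img_size // patch_size             # 48
out = feat * up_scale ** num_convs        # 48 * 16 = 768
assert out == img_size
```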


@ -0,0 +1,25 @@
# Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers
## Introduction
<!-- [ALGORITHM] -->
```latex
@article{zheng2020rethinking,
title={Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers},
author={Zheng, Sixiao and Lu, Jiachen and Zhao, Hengshuang and Zhu, Xiatian and Luo, Zekun and Wang, Yabiao and Fu, Yanwei and Feng, Jianfeng and Xiang, Tao and Torr, Philip HS and others},
journal={arXiv preprint arXiv:2012.15840},
year={2020}
}
```
## Results and models
### ADE20K
| Method | Backbone | Crop Size | Batch Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ---------- | ------- | -------- | -------------- | ----- | ------------: | ------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| SETR-Naive | ViT-L | 512x512 | 16 | 160000 | 18.40 | 4.72 | 48.28 | 49.56 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/setr/setr_naive_512x512_160k_b16_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/setr/setr_naive_512x512_160k_b16_ade20k/setr_naive_512x512_160k_b16_ade20k_20210619_191258-061f24f5.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/setr/setr_naive_512x512_160k_b16_ade20k/setr_naive_512x512_160k_b16_ade20k_20210619_191258.log.json) |
| SETR-PUP | ViT-L | 512x512 | 16 | 160000 | 19.54 | 4.50 | 48.24 | 49.99 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/setr/setr_pup_512x512_160k_b16_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/setr/setr_pup_512x512_160k_b16_ade20k/setr_pup_512x512_160k_b16_ade20k_20210619_191343-7e0ce826.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/setr/setr_pup_512x512_160k_b16_ade20k/setr_pup_512x512_160k_b16_ade20k_20210619_191343.log.json) |
| SETR-MLA | ViT-L | 512x512 | 8 | 160000 | 10.96 | - | 47.34 | 49.05 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/setr/setr_mla_512x512_160k_b8_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/setr/setr_mla_512x512_160k_b8_ade20k/setr_mla_512x512_160k_b8_ade20k_20210619_191118-c6d21df0.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/setr/setr_mla_512x512_160k_b8_ade20k/setr_mla_512x512_160k_b8_ade20k_20210619_191118.log.json) |
| SETR-MLA | ViT-L | 512x512 | 16 | 160000 | 17.30 | 5.25 | 47.54 | 49.37 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/setr/setr_mla_512x512_160k_b16_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/setr/setr_mla_512x512_160k_b16_ade20k/setr_mla_512x512_160k_b16_ade20k_20210619_191057-f9741de7.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/setr/setr_mla_512x512_160k_b16_ade20k/setr_mla_512x512_160k_b16_ade20k_20210619_191057.log.json) |
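
The checkpoints in the table can be used directly for inference. A minimal sketch, assuming mmseg v0.x is installed, the repository root is the working directory, and the checkpoint from the first row has been downloaded locally (`demo.png` is a placeholder image path):

```python
from mmseg.apis import inference_segmentor, init_segmentor

config_file = 'configs/setr/setr_naive_512x512_160k_b16_ade20k.py'
checkpoint_file = 'setr_naive_512x512_160k_b16_ade20k_20210619_191258-061f24f5.pth'

model = init_segmentor(config_file, checkpoint_file, device='cuda:0')
result = inference_segmentor(model, 'demo.png')  # list with one HxW label map
```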


@ -0,0 +1,4 @@
_base_ = ['./setr_mla_512x512_160k_b8_ade20k.py']
# num_gpus: 8 -> batch_size: 16
data = dict(samples_per_gpu=2)


@ -0,0 +1,80 @@
_base_ = [
'../_base_/models/setr_mla.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
backbone=dict(img_size=(512, 512), drop_rate=0.),
decode_head=dict(num_classes=150),
auxiliary_head=[
dict(
type='FCNHead',
in_channels=256,
channels=256,
in_index=0,
dropout_ratio=0,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU'),
num_convs=0,
kernel_size=1,
concat_input=False,
num_classes=150,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='FCNHead',
in_channels=256,
channels=256,
in_index=1,
dropout_ratio=0,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU'),
num_convs=0,
kernel_size=1,
concat_input=False,
num_classes=150,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='FCNHead',
in_channels=256,
channels=256,
in_index=2,
dropout_ratio=0,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU'),
num_convs=0,
kernel_size=1,
concat_input=False,
num_classes=150,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='FCNHead',
in_channels=256,
channels=256,
in_index=3,
dropout_ratio=0,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU'),
num_convs=0,
kernel_size=1,
concat_input=False,
num_classes=150,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
],
test_cfg=dict(mode='slide', crop_size=(512, 512), stride=(341, 341)),
)
optimizer = dict(
lr=0.001,
weight_decay=0.0,
paramwise_cfg=dict(custom_keys={'head': dict(lr_mult=10.)}))
# num_gpus: 8 -> batch_size: 8
data = dict(samples_per_gpu=1)
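
A rough sketch (not the runner's actual wiring) of what the `paramwise_cfg` above does: parameters whose names contain `'head'` train with `lr * lr_mult`, everything else with the base lr. The parameter names below are illustrative only:

```python
# Example parameter names, for illustration only.
base_lr, lr_mult = 0.001, 10.0
param_names = ['backbone.layers.0.attn.qkv.weight', 'decode_head.conv_seg.weight']
for name in param_names:
    lr = base_lr * lr_mult if 'head' in name else base_lr
    print(name, lr)   # backbone params -> 0.001, head params -> 0.01
```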


@ -0,0 +1,62 @@
_base_ = [
'../_base_/models/setr_naive.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
backbone=dict(img_size=(512, 512), drop_rate=0.),
decode_head=dict(num_classes=150),
auxiliary_head=[
dict(
type='SETRUPHead',
in_channels=1024,
channels=256,
in_index=0,
num_classes=150,
dropout_ratio=0,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU'),
num_convs=2,
kernel_size=1,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='SETRUPHead',
in_channels=1024,
channels=256,
in_index=1,
num_classes=150,
dropout_ratio=0,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU'),
num_convs=2,
kernel_size=1,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='SETRUPHead',
in_channels=1024,
channels=256,
in_index=2,
num_classes=150,
dropout_ratio=0,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU'),
num_convs=2,
kernel_size=1,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4))
],
test_cfg=dict(mode='slide', crop_size=(512, 512), stride=(341, 341)),
)
optimizer = dict(
lr=0.01,
weight_decay=0.0,
paramwise_cfg=dict(custom_keys={'head': dict(lr_mult=10.)}))
# num_gpus: 8 -> batch_size: 16
data = dict(samples_per_gpu=2)


@ -0,0 +1,62 @@
_base_ = [
'../_base_/models/setr_pup.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
backbone=dict(img_size=(512, 512), drop_rate=0.),
decode_head=dict(num_classes=150),
auxiliary_head=[
dict(
type='SETRUPHead',
in_channels=1024,
channels=256,
in_index=0,
num_classes=150,
dropout_ratio=0,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU'),
num_convs=2,
kernel_size=3,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='SETRUPHead',
in_channels=1024,
channels=256,
in_index=1,
num_classes=150,
dropout_ratio=0,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU'),
num_convs=2,
kernel_size=3,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='SETRUPHead',
in_channels=1024,
channels=256,
in_index=2,
num_classes=150,
dropout_ratio=0,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU'),
num_convs=2,
kernel_size=3,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
],
test_cfg=dict(mode='slide', crop_size=(512, 512), stride=(341, 341)),
)
optimizer = dict(
lr=0.001,
weight_decay=0.0,
paramwise_cfg=dict(custom_keys={'head': dict(lr_mult=10.)}))
# num_gpus: 8 -> batch_size: 16
data = dict(samples_per_gpu=2)
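
The `slide` test mode above evaluates overlapping 512x512 windows with stride 341 and averages the logits where windows overlap. A small sketch of the window-count arithmetic; the 512x683 image size is just an example:

```python
import math

h_img, w_img = 512, 683            # example ADE20K image size
h_crop = w_crop = 512
h_stride = w_stride = 341
h_grids = max(math.ceil((h_img - h_crop) / h_stride) + 1, 1)   # 1
w_grids = max(math.ceil((w_img - w_crop) / w_stride) + 1, 1)   # 2
print(h_grids * w_grids)           # 2 sliding windows in total
```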


@ -32,10 +32,13 @@ class Resize(object):
Args:
img_scale (tuple or list[tuple]): Images scales for resizing.
Default: None.
multiscale_mode (str): Either "range" or "value".
Default: 'range'
ratio_range (tuple[float]): (min_ratio, max_ratio).
Default: None
keep_ratio (bool): Whether to keep the aspect ratio when resizing the
image. Default: True
"""
def __init__(self,


@ -20,23 +20,24 @@ class TransformerEncoderLayer(BaseModule):
"""Implements one encoder layer in Vision Transformer.
Args:
embed_dims (int): The feature dimension
num_heads (int): Parallel attention heads
feedforward_channels (int): The hidden dimension for FFNs
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Default 0.0
after the feed forward layer. Default: 0.0.
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.0
Default: 0.0.
drop_path_rate (float): stochastic depth rate. Default 0.0.
num_fcs (int): The number of fully-connected layers for FFNs. Default 2
qkv_bias (bool): enable bias for qkv if True. Default True
act_cfg (dict): The activation config for FFNs. Defalut GELU
norm_cfg (dict): Config dict for normalization layer. Default
layer normalization
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
qkv_bias (bool): enable bias for qkv if True. Default: True
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default to False.
init_cfg (dict, optional): Initialization config dict
or (n, batch, embed_dim). Default: True.
"""
def __init__(self,
@ -50,7 +51,7 @@ class TransformerEncoderLayer(BaseModule):
qkv_bias=True,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
batch_first=False):
batch_first=True):
super(TransformerEncoderLayer, self).__init__()
self.norm1_name, norm1 = build_norm_layer(
@ -75,7 +76,7 @@ class TransformerEncoderLayer(BaseModule):
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=None,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg)
@property
@ -97,45 +98,32 @@ class PatchEmbed(BaseModule):
"""Image to Patch Embedding.
Args:
img_size (int | tuple): The size of input image.
patch_size (int): The size of one patch
in_channels (int): The num of input channels.
embed_dim (int): The dimensions of embedding.
embed_dims (int): The dimensions of embedding.
norm_cfg (dict, optional): Config dict for normalization layer.
conv_cfg (dict, optional): The config dict for conv layers.
Default: None.
"""
def __init__(self,
img_size=224,
patch_size=16,
in_channels=3,
embed_dim=768,
embed_dims=768,
norm_cfg=None,
conv_cfg=None):
super(PatchEmbed, self).__init__()
self.img_size = img_size
self.patch_size = to_2tuple(patch_size)
patches_resolution = [
img_size[0] // self.patch_size[0],
img_size[1] // self.patch_size[1]
]
num_patches = patches_resolution[0] * patches_resolution[1]
self.patches_resolution = patches_resolution
self.num_patches = num_patches
# Use conv layer to embed
self.projection = build_conv_layer(
conv_cfg,
in_channels,
embed_dim,
embed_dims,
kernel_size=patch_size,
stride=patch_size)
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, embed_dim)[1]
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
else:
self.norm = None
@ -209,7 +197,7 @@ class VisionTransformer(BaseModule):
num_layers=12,
num_heads=12,
mlp_ratio=4,
out_indices=11,
out_indices=-1,
qkv_bias=True,
drop_rate=0.,
attn_drop_rate=0.,
@ -260,13 +248,13 @@ class VisionTransformer(BaseModule):
self.init_cfg = init_cfg
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=embed_dims,
embed_dims=embed_dims,
norm_cfg=norm_cfg if patch_norm else None)
num_patches = self.patch_embed.num_patches
num_patches = (img_size[0] // patch_size) * \
(img_size[1] // patch_size)
self.with_cls_token = with_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
@ -275,6 +263,8 @@ class VisionTransformer(BaseModule):
self.drop_after_pos = nn.Dropout(p=drop_rate)
if isinstance(out_indices, int):
if out_indices == -1:
out_indices = num_layers - 1
self.out_indices = [out_indices]
elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
self.out_indices = out_indices
@ -302,6 +292,7 @@ class VisionTransformer(BaseModule):
batch_first=True))
self.final_norm = final_norm
self.out_shape = out_shape
if final_norm:
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, embed_dims, postfix=1)
@ -314,7 +305,8 @@ class VisionTransformer(BaseModule):
def init_weights(self):
if isinstance(self.pretrained, str):
logger = get_root_logger()
checkpoint = _load_checkpoint(self.pretrained, logger=logger)
checkpoint = _load_checkpoint(
self.pretrained, logger=logger, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
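
A minimal sketch of the refactored backbone interface, using only argument names that appear in the configs above; the toy sizes are assumptions so the snippet runs quickly. With the new default NCHW output shape, one feature map is returned per entry in `out_indices`:

```python
import torch
from mmseg.models import build_backbone

backbone = build_backbone(
    dict(
        type='VisionTransformer',
        img_size=(32, 32),       # toy size; the configs above use (768, 768)
        patch_size=16,
        in_channels=3,
        embed_dims=64,
        num_layers=4,
        num_heads=4,
        out_indices=(0, 1, 2, 3),
        with_cls_token=False))
feats = backbone(torch.randn(1, 3, 32, 32))
# One NCHW feature map per out_index, at 1/16 of the input resolution.
assert all(f.shape == (1, 64, 32 // 16, 32 // 16) for f in feats)
```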


@ -18,11 +18,13 @@ from .psa_head import PSAHead
from .psp_head import PSPHead
from .sep_aspp_head import DepthwiseSeparableASPPHead
from .sep_fcn_head import DepthwiseSeparableFCNHead
from .setr_mla_head import SETRMLAHead
from .setr_up_head import SETRUPHead
from .uper_head import UPerHead
__all__ = [
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead'
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead', 'SETRMLAHead'
]


@ -0,0 +1,61 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from ..builder import HEADS
from .decode_head import BaseDecodeHead
@HEADS.register_module()
class SETRMLAHead(BaseDecodeHead):
"""Multi level feature aggretation head of SETR.
MLA head of `SETR <https://arxiv.org/pdf/2012.15840.pdf>`.
Args:
mlahead_channels (int): Channels of conv-conv-4x of multi-level feature
aggregation. Default: 128.
up_scale (int): The scale factor of interpolate. Default:4.
"""
def __init__(self, mla_channels=128, up_scale=4, **kwargs):
super(SETRMLAHead, self).__init__(
input_transform='multiple_select', **kwargs)
self.mla_channels = mla_channels
num_inputs = len(self.in_channels)
# Refer to self.cls_seg settings of BaseDecodeHead
assert self.channels == num_inputs * mla_channels
self.up_convs = nn.ModuleList()
for i in range(num_inputs):
self.up_convs.append(
nn.Sequential(
ConvModule(
in_channels=self.in_channels[i],
out_channels=mla_channels,
kernel_size=3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(
in_channels=mla_channels,
out_channels=mla_channels,
kernel_size=3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Upsample(
scale_factor=up_scale,
mode='bilinear',
align_corners=self.align_corners)))
def forward(self, inputs):
inputs = self._transform_inputs(inputs)
outs = []
for x, up_conv in zip(inputs, self.up_convs):
outs.append(up_conv(x))
out = torch.cat(outs, dim=1)
out = self.cls_seg(out)
return out


@ -0,0 +1,75 @@
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer, constant_init
from ..builder import HEADS
from .decode_head import BaseDecodeHead
@HEADS.register_module()
class SETRUPHead(BaseDecodeHead):
"""Naive upsampling head and Progressive upsampling head of SETR.
Naive or PUP head of `SETR <https://arxiv.org/pdf/2012.15840.pdf>`.
Args:
norm_layer (dict): Config dict for input normalization.
Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True).
num_convs (int): Number of decoder convolutions. Default: 1.
up_scale (int): The scale factor of interpolate. Default: 4.
kernel_size (int): The kernel size of convolution when decoding
feature information from backbone. Default: 3.
"""
def __init__(self,
norm_layer=dict(type='LN', eps=1e-6, requires_grad=True),
num_convs=1,
up_scale=4,
kernel_size=3,
**kwargs):
assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.'
super(SETRUPHead, self).__init__(**kwargs)
assert isinstance(self.in_channels, int)
_, self.norm = build_norm_layer(norm_layer, self.in_channels)
self.up_convs = nn.ModuleList()
in_channels = self.in_channels
out_channels = self.channels
for i in range(num_convs):
self.up_convs.append(
nn.Sequential(
ConvModule(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=1,
padding=int(kernel_size - 1) // 2,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Upsample(
scale_factor=up_scale,
mode='bilinear',
align_corners=self.align_corners)))
in_channels = out_channels
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.LayerNorm):
constant_init(m.bias, 0)
constant_init(m.weight, 1.0)
def forward(self, x):
x = self._transform_inputs(x)
n, c, h, w = x.shape
x = x.reshape(n, c, h * w).transpose(2, 1).contiguous()
x = self.norm(x)
x = x.transpose(1, 2).reshape(n, c, h, w).contiguous()
for up_conv in self.up_convs:
x = up_conv(x)
out = self.cls_seg(x)
return out
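
The NCHW -> NLC -> NCHW round trip in `forward` above exists so that the input LayerNorm (`norm_layer`) can be applied over the channel dimension; a standalone sketch of that pattern:

```python
import torch
import torch.nn as nn

n, c, h, w = 2, 8, 4, 4
x = torch.randn(n, c, h, w)
norm = nn.LayerNorm(c)                          # normalizes the last dimension
x = x.reshape(n, c, h * w).transpose(2, 1)      # NCHW -> NLC
x = norm(x)
x = x.transpose(1, 2).reshape(n, c, h, w)       # NLC -> NCHW
assert x.shape == (n, c, h, w)
```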


@ -1,4 +1,5 @@
from .fpn import FPN
from .mla_neck import MLANeck
from .multilevel_neck import MultiLevelNeck
__all__ = ['FPN', 'MultiLevelNeck']
__all__ = ['FPN', 'MultiLevelNeck', 'MLANeck']


@ -0,0 +1,117 @@
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer
from ..builder import NECKS
class MLAModule(nn.Module):
def __init__(self,
in_channels=[1024, 1024, 1024, 1024],
out_channels=256,
norm_cfg=None,
act_cfg=None):
super(MLAModule, self).__init__()
self.channel_proj = nn.ModuleList()
for i in range(len(in_channels)):
self.channel_proj.append(
ConvModule(
in_channels=in_channels[i],
out_channels=out_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
self.feat_extract = nn.ModuleList()
for i in range(len(in_channels)):
self.feat_extract.append(
ConvModule(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
def forward(self, inputs):
# feat_list -> [p2, p3, p4, p5]
feat_list = []
for x, conv in zip(inputs, self.channel_proj):
feat_list.append(conv(x))
# feat_list -> [p5, p4, p3, p2]
# mid_list -> [m5, m4, m3, m2]
feat_list = feat_list[::-1]
mid_list = []
for feat in feat_list:
if len(mid_list) == 0:
mid_list.append(feat)
else:
mid_list.append(mid_list[-1] + feat)
# mid_list -> [m5, m4, m3, m2]
# out_list -> [o2, o3, o4, o5]
out_list = []
for mid, conv in zip(mid_list, self.feat_extract):
out_list.append(conv(mid))
return tuple(out_list)
@NECKS.register_module()
class MLANeck(nn.Module):
"""Multi-level Feature Aggregation.
The Multi-level Feature Aggregation construction of SETR:
https://arxiv.org/pdf/2012.15840.pdf
Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale).
norm_layer (dict): Config dict for input normalization.
Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True).
norm_cfg (dict): Config dict for normalization layer. Default: None.
act_cfg (dict): Config dict for activation layer in ConvModule.
Default: None.
"""
def __init__(self,
in_channels,
out_channels,
norm_layer=dict(type='LN', eps=1e-6, requires_grad=True),
norm_cfg=None,
act_cfg=None):
super(MLANeck, self).__init__()
assert isinstance(in_channels, list)
self.in_channels = in_channels
self.out_channels = out_channels
# In order to build general vision transformer backbone, we have to
# move MLA to neck.
self.norm = nn.ModuleList([
build_norm_layer(norm_layer, in_channels[i])[1]
for i in range(len(in_channels))
])
self.mla = MLAModule(
in_channels=in_channels,
out_channels=out_channels,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def forward(self, inputs):
assert len(inputs) == len(self.in_channels)
# Convert from nchw to nlc
outs = []
for i in range(len(inputs)):
x = inputs[i]
n, c, h, w = x.shape
x = x.reshape(n, c, h * w).transpose(2, 1).contiguous()
x = self.norm[i](x)
x = x.transpose(1, 2).reshape(n, c, h, w).contiguous()
outs.append(x)
outs = self.mla(outs)
return tuple(outs)
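
In `MLAModule.forward` above, the projected features are accumulated from the deepest level downward before each partial sum goes through its own 3x3 conv. Written out for four levels (tensor sizes are illustrative, matching the 256-channel projection used in the configs):

```python
import torch

# p2..p5 stand for the 1x1-projected features, from shallow to deep.
p2, p3, p4, p5 = [torch.randn(1, 256, 24, 24) for _ in range(4)]
m5 = p5
m4 = m5 + p4
m3 = m4 + p3
m2 = m3 + p2
# Each of m5, m4, m3, m2 is then refined by its own 3x3 ConvModule and the
# results are returned as a tuple, one feature map per input level.
```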


@ -1,4 +1,3 @@
from .drop import DropPath
from .inverted_residual import InvertedResidual, InvertedResidualV3
from .make_divisible import make_divisible
from .res_layer import ResLayer
@ -9,5 +8,5 @@ from .up_conv_block import UpConvBlock
__all__ = [
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'DropPath', 'vit_convert'
'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'vit_convert'
]


@ -1,31 +0,0 @@
"""Modified from https://github.com/rwightman/pytorch-image-
models/blob/master/timm/models/layers/drop.py."""
import torch
from torch import nn
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of
residual blocks).
Args:
drop_prob (float): Drop rate for paths of model. Dropout rate has
to be between 0 and 1. Default: 0.
"""
def __init__(self, drop_prob=0.):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.keep_prob = 1 - drop_prob
def forward(self, x):
if self.drop_prob == 0. or not self.training:
return x
shape = (x.shape[0], ) + (1, ) * (
x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = self.keep_prob + torch.rand(
shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(self.keep_prob) * random_tensor
return output
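
The removal above corresponds to the commit item "Remove DropPath implementation and use DropPath from mmcv". A sketch of the replacement, assuming an mmcv version (>= 1.3) that registers `DropPath` in its dropout registry, as the `dropout_layer=dict(type='DropPath', ...)` line in the transformer diff suggests:

```python
import torch
from mmcv.cnn.bricks.drop import build_dropout   # module path assumed from mmcv >= 1.3

drop_path = build_dropout(dict(type='DropPath', drop_prob=0.1))
x = torch.randn(2, 16, 32)
drop_path.eval()
assert torch.equal(drop_path(x), x)   # identity in eval mode
drop_path.train()
y = drop_path(x)                      # paths stochastically dropped per sample
```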


@ -0,0 +1,62 @@
import pytest
import torch
from mmseg.models.decode_heads import SETRMLAHead
from .utils import to_cuda
def test_setr_mla_head(capsys):
with pytest.raises(AssertionError):
# MLA requires input multiple stage feature information.
SETRMLAHead(in_channels=32, channels=16, num_classes=19, in_index=1)
with pytest.raises(AssertionError):
# multiple in_index values require multiple in_channels.
SETRMLAHead(
in_channels=32, channels=16, num_classes=19, in_index=(0, 1, 2, 3))
with pytest.raises(AssertionError):
# channels should be len(in_channels) * mla_channels
SETRMLAHead(
in_channels=(32, 32, 32, 32),
channels=32,
mla_channels=16,
in_index=(0, 1, 2, 3),
num_classes=19)
# test inference of MLA head
img_size = (32, 32)
patch_size = 16
head = SETRMLAHead(
in_channels=(32, 32, 32, 32),
channels=64,
mla_channels=16,
in_index=(0, 1, 2, 3),
num_classes=19,
norm_cfg=dict(type='BN'))
h, w = img_size[0] // patch_size, img_size[1] // patch_size
# Input square NCHW format feature information
x = [
torch.randn(1, 32, h, w),
torch.randn(1, 32, h, w),
torch.randn(1, 32, h, w),
torch.randn(1, 32, h, w)
]
if torch.cuda.is_available():
head, x = to_cuda(head, x)
out = head(x)
assert out.shape == (1, head.num_classes, h * 4, w * 4)
# Input non-square NCHW format feature information
x = [
torch.randn(1, 32, h, w * 2),
torch.randn(1, 32, h, w * 2),
torch.randn(1, 32, h, w * 2),
torch.randn(1, 32, h, w * 2)
]
if torch.cuda.is_available():
head, x = to_cuda(head, x)
out = head(x)
assert out.shape == (1, head.num_classes, h * 4, w * 8)


@ -0,0 +1,54 @@
import pytest
import torch
from mmseg.models.decode_heads import SETRUPHead
from .utils import to_cuda
def test_setr_up_head(capsys):
with pytest.raises(AssertionError):
# kernel_size must be 1 or 3
SETRUPHead(num_classes=19, kernel_size=2)
with pytest.raises(AssertionError):
# in_channels must be int type and in_channels must be same
# as embed_dim.
SETRUPHead(in_channels=(32, 32), channels=16, num_classes=19)
# test init_weights of head
head = SETRUPHead(
in_channels=32,
channels=16,
norm_cfg=dict(type='SyncBN'),
num_classes=19)
head.init_weights()
# test inference of Naive head
# the auxiliary head of Naive head is same as Naive head
img_size = (32, 32)
patch_size = 16
head = SETRUPHead(
in_channels=32,
channels=16,
num_classes=19,
num_convs=1,
up_scale=4,
kernel_size=1,
norm_cfg=dict(type='BN'))
h, w = img_size[0] // patch_size, img_size[1] // patch_size
# Input square NCHW format feature information
x = [torch.randn(1, 32, h, w)]
if torch.cuda.is_available():
head, x = to_cuda(head, x)
out = head(x)
assert out.shape == (1, head.num_classes, h * 4, w * 4)
# Input non-square NCHW format feature information
x = [torch.randn(1, 32, h, w * 2)]
if torch.cuda.is_available():
head, x = to_cuda(head, x)
out = head(x)
assert out.shape == (1, head.num_classes, h * 4, w * 8)


@ -0,0 +1,15 @@
import torch
from mmseg.models import MLANeck
def test_mla():
in_channels = [1024, 1024, 1024, 1024]
mla = MLANeck(in_channels, 256)
inputs = [torch.randn(1, c, 24, 24) for i, c in enumerate(in_channels)]
outputs = mla(inputs)
assert outputs[0].shape == torch.Size([1, 256, 24, 24])
assert outputs[1].shape == torch.Size([1, 256, 24, 24])
assert outputs[2].shape == torch.Size([1, 256, 24, 24])
assert outputs[3].shape == torch.Size([1, 256, 24, 24])


@ -1,28 +0,0 @@
import torch
from mmseg.models.utils import DropPath
def test_drop_path():
# zero drop
layer = DropPath()
# input NLC format feature
x = torch.randn((1, 16, 32))
layer(x)
# input NLHW format feature
x = torch.randn((1, 32, 4, 4))
layer(x)
# non-zero drop
layer = DropPath(0.1)
# input NLC format feature
x = torch.randn((1, 16, 32))
layer(x)
# input NLHW format feature
x = torch.randn((1, 32, 4, 4))
layer(x)