[NEW][Feature]Support SegNeXt(NeurIPS'2022) in master branch (#2600)

## Motivation

Support SegNeXt.

Due to many commits & changed files caused by WIP too long (perhaps it
could be resolved by `git merge` or `git rebase`).

This PR is created only for backup of old PR
https://github.com/open-mmlab/mmsegmentation/pull/2247

Co-authored-by: MeowZheng <meowzheng@outlook.com>
Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
pull/2651/head
MengzhangLI 2023-02-24 16:08:27 +08:00 committed by GitHub
parent b2fdae7ce7
commit 70477d21ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 1216 additions and 2 deletions

View File

@ -145,6 +145,7 @@ Supported backbones:
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
- [x] [MAE (CVPR'2022)](configs/mae)
- [x] [PoolFormer (CVPR'2022)](configs/poolformer)
- [x] [SegNeXt (NeurIPS'2022)](configs/segnext)
Supported methods:

View File

@ -128,6 +128,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
- [x] [MAE (CVPR'2022)](configs/mae)
- [x] [PoolFormer (CVPR'2022)](configs/poolformer)
- [x] [SegNeXt (NeurIPS'2022)](configs/segnext)
已支持的算法:

View File

@ -0,0 +1,63 @@
# SegNeXt
[SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation](https://arxiv.org/abs/2209.08575)
## Introduction
<!-- [ALGORITHM] -->
<a href="https://github.com/visual-attention-network/segnext">Official Repo</a>
<a href="https://github.com/open-mmlab/mmsegmentation/blob/v0.31.0/mmseg/models/backbones/mscan.py#L328">Code Snippet</a>
## Abstract
<!-- [ABSTRACT] -->
We present SegNeXt, a simple convolutional network architecture for semantic segmentation. Recent transformer-based models have dominated the field of semantic segmentation due to the efficiency of self-attention in encoding spatial information. In this paper, we show that convolutional attention is a more efficient and effective way to encode contextual information than the self-attention mechanism in transformers. By re-examining the characteristics owned by successful segmentation models, we discover several key components leading to the performance improvement of segmentation models. This motivates us to design a novel convolutional attention network that uses cheap convolutional operations. Without bells and whistles, our SegNeXt significantly improves the performance of previous state-of-the-art methods on popular benchmarks, including ADE20K, Cityscapes, COCO-Stuff, Pascal VOC, Pascal Context, and iSAID. Notably, SegNeXt outperforms EfficientNet-L2 w/ NAS-FPN and achieves 90.6% mIoU on the Pascal VOC 2012 test leaderboard using only 1/10 parameters of it. On average, SegNeXt achieves about 2.0% mIoU improvements compared to the state-of-the-art methods on the ADE20K datasets with the same or fewer computations. Code is available at [this https URL](https://github.com/uyzhang/JSeg) (Jittor) and [this https URL](https://github.com/Visual-Attention-Network/SegNeXt) (Pytorch).
<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/24582831/215688018-5d4c8366-7793-4fdf-9397-960a09fac951.png" width="70%"/>
</div>
```bibtex
@article{guo2022segnext,
title={SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation},
author={Guo, Meng-Hao and Lu, Cheng-Ze and Hou, Qibin and Liu, Zhengning and Cheng, Ming-Ming and Hu, Shi-Min},
journal={arXiv preprint arXiv:2209.08575},
year={2022}
}
```
## Pretrained model
The pretrained model could be found [here](https://cloud.tsinghua.edu.cn/d/c15b25a6745946618462/) from [original repo](https://github.com/Visual-Attention-Network/SegNeXt). You can download and put them in `./pretrain` folder.
## Results and models
### ADE20K
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------- | -------- | --------- | ------- | -------- | -------------- | ----- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| SegNeXt | MSCAN-T | 512x512 | 160000 | 17.88 | 52.38 | 41.50 | 42.59 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k_20230210_140244-05bd8466.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k_20230210_140244.log.json) |
| SegNeXt | MSCAN-S | 512x512 | 160000 | 21.47 | 42.27 | 44.16 | 45.81 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k_20230214_113014-43013668.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k_20230214_113014.log.json) |
| SegNeXt | MSCAN-B | 512x512 | 160000 | 31.03 | 35.15 | 48.03 | 49.68 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k_20230209_172053-b6f6c70c.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k_20230209_172053.log.json) |
| SegNeXt | MSCAN-L | 512x512 | 160000 | 43.32 | 22.91 | 50.99 | 52.10 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k_20230209_172055-19b14b63.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k_20230209_172055.log.json) |
Note:
- The total batch size is 16. We trained for SegNeXt with a single GPU as the performance degrades significantly when using`SyncBN` (mainly in `OverlapPatchEmbed` modules of `MSCAN`) of PyTorch 1.9.
- There will be subtle differences when model testing as Non-negative Matrix Factorization (NMF) in `LightHamHead` will be initialized randomly. To control this randomness, please set the random seed when model testing. You can modify [`./tools/test.py`](https://github.com/open-mmlab/mmsegmentation/blob/master/tools/test.py) like:
```python
def main():
from mmseg.apis import set_random_seed
random_seed = xxx # set random seed recorded in training log
set_random_seed(random_seed, deterministic=False)
...
```
- This model performance is sensitive to the seed values used, please refer to the log file for the specific settings of the seed. If you choose a different seed, the results might differ from the table results. Take SegNeXt Large for example, its results range from 49.60 to 51.0.

View File

@ -0,0 +1,103 @@
Collections:
- Name: SegNeXt
Metadata:
Training Data:
- ADE20K
Paper:
URL: https://arxiv.org/abs/2209.08575
Title: 'SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation'
README: configs/segnext/README.md
Code:
URL: https://github.com/open-mmlab/mmsegmentation/blob/v0.31.0/mmseg/models/backbones/mscan.py#L328
Version: v0.31.0
Converted From:
Code: https://github.com/visual-attention-network/segnext
Models:
- Name: segnext_mscan-t_1x16_512x512_adamw_160k_ade20k
In Collection: SegNeXt
Metadata:
backbone: MSCAN-T
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 19.09
hardware: A100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 17.88
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 41.5
mIoU(ms+flip): 42.59
Config: configs/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k_20230210_140244-05bd8466.pth
- Name: segnext_mscan-s_1x16_512x512_adamw_160k_ade20k
In Collection: SegNeXt
Metadata:
backbone: MSCAN-S
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 23.66
hardware: A100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 21.47
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 44.16
mIoU(ms+flip): 45.81
Config: configs/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k_20230214_113014-43013668.pth
- Name: segnext_mscan-b_1x16_512x512_adamw_160k_ade20k
In Collection: SegNeXt
Metadata:
backbone: MSCAN-B
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 28.45
hardware: A100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 31.03
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 48.03
mIoU(ms+flip): 49.68
Config: configs/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k_20230209_172053-b6f6c70c.pth
- Name: segnext_mscan-l_1x16_512x512_adamw_160k_ade20k
In Collection: SegNeXt
Metadata:
backbone: MSCAN-L
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 43.65
hardware: A100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 43.32
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 50.99
mIoU(ms+flip): 52.1
Config: configs/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k_20230209_172055-19b14b63.pth

View File

@ -0,0 +1,26 @@
_base_ = './segnext_mscan-t_1x16_512x512_adamw_160k_ade20k.py'
# model settings
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
type='EncoderDecoder',
backbone=dict(
embed_dims=[64, 128, 320, 512],
depths=[3, 3, 12, 3],
init_cfg=dict(type='Pretrained', checkpoint='pretrain/mscan_b.pth'),
drop_path_rate=0.1,
norm_cfg=dict(type='BN', requires_grad=True)),
decode_head=dict(
type='LightHamHead',
in_channels=[128, 320, 512],
in_index=[1, 2, 3],
channels=512,
ham_channels=512,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=ham_norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

View File

@ -0,0 +1,26 @@
_base_ = './segnext_mscan-t_1x16_512x512_adamw_160k_ade20k.py'
# model settings
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
type='EncoderDecoder',
backbone=dict(
embed_dims=[64, 128, 320, 512],
depths=[3, 5, 27, 3],
init_cfg=dict(type='Pretrained', checkpoint='pretrain/mscan_l.pth'),
drop_path_rate=0.3,
norm_cfg=dict(type='BN', requires_grad=True)),
decode_head=dict(
type='LightHamHead',
in_channels=[128, 320, 512],
in_index=[1, 2, 3],
channels=1024,
ham_channels=1024,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=ham_norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

View File

@ -0,0 +1,26 @@
_base_ = './segnext_mscan-t_1x16_512x512_adamw_160k_ade20k.py'
# model settings
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
type='EncoderDecoder',
backbone=dict(
embed_dims=[64, 128, 320, 512],
depths=[2, 2, 4, 2],
init_cfg=dict(type='Pretrained', checkpoint='./pretrain/mscan_s.pth'),
norm_cfg=dict(type='BN', requires_grad=True)),
decode_head=dict(
type='LightHamHead',
in_channels=[128, 320, 512],
in_index=[1, 2, 3],
channels=256,
ham_channels=256,
ham_kwargs=dict(MD_R=16),
dropout_ratio=0.1,
num_classes=150,
norm_cfg=ham_norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

View File

@ -0,0 +1,125 @@
_base_ = [
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
# model settings
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(
type='MSCAN',
init_cfg=dict(type='Pretrained', checkpoint='./pretrain/mscan_t.pth'),
embed_dims=[32, 64, 160, 256],
mlp_ratios=[8, 8, 4, 4],
drop_rate=0.0,
drop_path_rate=0.1,
depths=[3, 3, 5, 2],
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='BN', requires_grad=True)),
decode_head=dict(
type='LightHamHead',
in_channels=[64, 160, 256],
in_index=[1, 2, 3],
channels=256,
ham_channels=256,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=ham_norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
ham_kwargs=dict(
MD_S=1,
MD_R=16,
train_steps=6,
eval_steps=7,
inv_t=100,
rand_init=True)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
# dataset settings
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='ResizeToMultiple', size_divisor=32),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=16,
workers_per_gpu=4,
train=dict(
type='RepeatDataset',
times=50,
dataset=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/training',
ann_dir='annotations/training',
pipeline=train_pipeline)),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline))
# optimizer
optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.00006,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={
'pos_block': dict(decay_mult=0.),
'norm': dict(decay_mult=0.),
'head': dict(lr_mult=10.)
}))
lr_config = dict(
_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)

View File

@ -11,6 +11,7 @@ from .mae import MAE
from .mit import MixVisionTransformer
from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
from .mscan import MSCAN
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1c, ResNetV1d
from .resnext import ResNeXt
@ -26,5 +27,5 @@ __all__ = [
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE'
'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'MSCAN'
]

View File

@ -0,0 +1,469 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Originally from https://github.com/visual-attention-network/segnext
# Licensed under the Apache License, Version 2.0 (the "License")
import math
import warnings
import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
trunc_normal_init)
from mmcv.runner import BaseModule
from mmseg.models.builder import BACKBONES
class Mlp(BaseModule):
"""Multi Layer Perceptron (MLP) Module.
Args:
in_features (int): The dimension of input features.
hidden_features (int): The dimension of hidden features.
Defaults: None.
out_features (int): The dimension of output features.
Defaults: None.
act_cfg (dict): Config dict for activation layer in block.
Default: dict(type='GELU').
drop (float): The number of dropout rate in MLP block.
Defaults: 0.0.
"""
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_cfg=dict(type='GELU'),
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
self.dwconv = nn.Conv2d(
hidden_features,
hidden_features,
3,
1,
1,
bias=True,
groups=hidden_features)
self.act = build_activation_layer(act_cfg)
self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
self.drop = nn.Dropout(drop)
def forward(self, x):
"""Forward function."""
x = self.fc1(x)
x = self.dwconv(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class StemConv(BaseModule):
"""Stem Block at the beginning of Semantic Branch.
Args:
in_channels (int): The dimension of input channels.
out_channels (int): The dimension of output channels.
act_cfg (dict): Config dict for activation layer in block.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Defaults: dict(type='SyncBN', requires_grad=True).
"""
def __init__(self,
in_channels,
out_channels,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='SyncBN', requires_grad=True)):
super(StemConv, self).__init__()
self.proj = nn.Sequential(
nn.Conv2d(
in_channels,
out_channels // 2,
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1)),
build_norm_layer(norm_cfg, out_channels // 2)[1],
build_activation_layer(act_cfg),
nn.Conv2d(
out_channels // 2,
out_channels,
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1)),
build_norm_layer(norm_cfg, out_channels)[1],
)
def forward(self, x):
"""Forward function."""
x = self.proj(x)
_, _, H, W = x.size()
x = x.flatten(2).transpose(1, 2)
return x, H, W
class MSCAAttention(BaseModule):
"""Attention Module in Multi-Scale Convolutional Attention Module (MSCA).
Args:
channels (int): The dimension of channels.
kernel_sizes (list): The size of attention
kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
paddings (list): The number of
corresponding padding value in attention module.
Defaults: [2, [0, 3], [0, 5], [0, 10]].
"""
def __init__(self,
channels,
kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
paddings=[2, [0, 3], [0, 5], [0, 10]]):
super().__init__()
self.conv0 = nn.Conv2d(
channels,
channels,
kernel_size=kernel_sizes[0],
padding=paddings[0],
groups=channels)
for i, (kernel_size,
padding) in enumerate(zip(kernel_sizes[1:], paddings[1:])):
kernel_size_ = [kernel_size, kernel_size[::-1]]
padding_ = [padding, padding[::-1]]
conv_name = [f'conv{i}_1', f'conv{i}_2']
for i_kernel, i_pad, i_conv in zip(kernel_size_, padding_,
conv_name):
self.add_module(
i_conv,
nn.Conv2d(
channels,
channels,
tuple(i_kernel),
padding=i_pad,
groups=channels))
self.conv3 = nn.Conv2d(channels, channels, 1)
def forward(self, x):
"""Forward function."""
u = x.clone()
attn = self.conv0(x)
# Multi-Scale Feature extraction
attn_0 = self.conv0_1(attn)
attn_0 = self.conv0_2(attn_0)
attn_1 = self.conv1_1(attn)
attn_1 = self.conv1_2(attn_1)
attn_2 = self.conv2_1(attn)
attn_2 = self.conv2_2(attn_2)
attn = attn + attn_0 + attn_1 + attn_2
# Channel Mixing
attn = self.conv3(attn)
# Convolutional Attention
x = attn * u
return x
class MSCASpatialAttention(BaseModule):
"""Spatial Attention Module in Multi-Scale Convolutional Attention Module
(MSCA).
Args:
in_channels (int): The dimension of channels.
attention_kernel_sizes (list): The size of attention
kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
attention_kernel_paddings (list): The number of
corresponding padding value in attention module.
Defaults: [2, [0, 3], [0, 5], [0, 10]].
act_cfg (dict): Config dict for activation layer in block.
Default: dict(type='GELU').
"""
def __init__(self,
in_channels,
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
act_cfg=dict(type='GELU')):
super().__init__()
self.proj_1 = nn.Conv2d(in_channels, in_channels, 1)
self.activation = build_activation_layer(act_cfg)
self.spatial_gating_unit = MSCAAttention(in_channels,
attention_kernel_sizes,
attention_kernel_paddings)
self.proj_2 = nn.Conv2d(in_channels, in_channels, 1)
def forward(self, x):
"""Forward function."""
shorcut = x.clone()
x = self.proj_1(x)
x = self.activation(x)
x = self.spatial_gating_unit(x)
x = self.proj_2(x)
x = x + shorcut
return x
class MSCABlock(BaseModule):
"""Basic Multi-Scale Convolutional Attention Block. It leverage the large-
kernel attention (LKA) mechanism to build both channel and spatial
attention. In each branch, it uses two depth-wise strip convolutions to
approximate standard depth-wise convolutions with large kernels. The kernel
size for each branch is set to 7, 11, and 21, respectively.
Args:
channels (int): The dimension of channels.
attention_kernel_sizes (list): The size of attention
kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
attention_kernel_paddings (list): The number of
corresponding padding value in attention module.
Defaults: [2, [0, 3], [0, 5], [0, 10]].
mlp_ratio (float): The ratio of multiple input dimension to
calculate hidden feature in MLP layer. Defaults: 4.0.
drop (float): The number of dropout rate in MLP block.
Defaults: 0.0.
drop_path (float): The ratio of drop paths.
Defaults: 0.0.
act_cfg (dict): Config dict for activation layer in block.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Defaults: dict(type='SyncBN', requires_grad=True).
"""
def __init__(self,
channels,
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
mlp_ratio=4.,
drop=0.,
drop_path=0.,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='SyncBN', requires_grad=True)):
super().__init__()
self.norm1 = build_norm_layer(norm_cfg, channels)[1]
self.attn = MSCASpatialAttention(channels, attention_kernel_sizes,
attention_kernel_paddings, act_cfg)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = build_norm_layer(norm_cfg, channels)[1]
mlp_hidden_channels = int(channels * mlp_ratio)
self.mlp = Mlp(
in_features=channels,
hidden_features=mlp_hidden_channels,
act_cfg=act_cfg,
drop=drop)
layer_scale_init_value = 1e-2
self.layer_scale_1 = nn.Parameter(
layer_scale_init_value * torch.ones((channels)),
requires_grad=True)
self.layer_scale_2 = nn.Parameter(
layer_scale_init_value * torch.ones((channels)),
requires_grad=True)
def forward(self, x, H, W):
"""Forward function."""
B, N, C = x.shape
x = x.permute(0, 2, 1).view(B, C, H, W)
x = x + self.drop_path(
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) *
self.attn(self.norm1(x)))
x = x + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) *
self.mlp(self.norm2(x)))
x = x.view(B, C, N).permute(0, 2, 1)
return x
class OverlapPatchEmbed(BaseModule):
"""Image to Patch Embedding.
Args:
patch_size (int): The patch size.
Defaults: 7.
stride (int): Stride of the convolutional layer.
Default: 4.
in_channels (int): The number of input channels.
Defaults: 3.
embed_dims (int): The dimensions of embedding.
Defaults: 768.
norm_cfg (dict): Config dict for normalization layer.
Defaults: dict(type='SyncBN', requires_grad=True).
"""
def __init__(self,
patch_size=7,
stride=4,
in_channels=3,
embed_dim=768,
norm_cfg=dict(type='SyncBN', requires_grad=True)):
super().__init__()
self.proj = nn.Conv2d(
in_channels,
embed_dim,
kernel_size=patch_size,
stride=stride,
padding=patch_size // 2)
self.norm = build_norm_layer(norm_cfg, embed_dim)[1]
def forward(self, x):
"""Forward function."""
x = self.proj(x)
_, _, H, W = x.shape
x = self.norm(x)
x = x.flatten(2).transpose(1, 2)
return x, H, W
@BACKBONES.register_module()
class MSCAN(BaseModule):
"""SegNeXt Multi-Scale Convolutional Attention Network (MCSAN) backbone.
This backbone is the implementation of `SegNeXt: Rethinking
Convolutional Attention Design for Semantic
Segmentation <https://arxiv.org/abs/2209.08575>`_.
Inspiration from https://github.com/visual-attention-network/segnext.
Args:
in_channels (int): The number of input channels. Defaults: 3.
embed_dims (list[int]): Embedding dimension.
Defaults: [64, 128, 256, 512].
mlp_ratios (list[int]): Ratio of mlp hidden dim to embedding dim.
Defaults: [4, 4, 4, 4].
drop_rate (float): Dropout rate. Defaults: 0.
drop_path_rate (float): Stochastic depth rate. Defaults: 0.
depths (list[int]): Depths of each Swin Transformer stage.
Default: [3, 4, 6, 3].
num_stages (int): MSCAN stages. Default: 4.
attention_kernel_sizes (list): Size of attention kernel in
Attention Module (Figure 2(b) of original paper).
Defaults: [5, [1, 7], [1, 11], [1, 21]].
attention_kernel_paddings (list): Size of attention paddings
in Attention Module (Figure 2(b) of original paper).
Defaults: [2, [0, 3], [0, 5], [0, 10]].
norm_cfg (dict): Config of norm layers.
Defaults: dict(type='SyncBN', requires_grad=True).
pretrained (str, optional): model pretrained path.
Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
in_channels=3,
embed_dims=[64, 128, 256, 512],
mlp_ratios=[4, 4, 4, 4],
drop_rate=0.,
drop_path_rate=0.,
depths=[3, 4, 6, 3],
num_stages=4,
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='SyncBN', requires_grad=True),
pretrained=None,
init_cfg=None):
super(MSCAN, self).__init__(init_cfg=init_cfg)
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is not None:
raise TypeError('pretrained must be a str or None')
self.depths = depths
self.num_stages = num_stages
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
cur = 0
for i in range(num_stages):
if i == 0:
patch_embed = StemConv(3, embed_dims[0], norm_cfg=norm_cfg)
else:
patch_embed = OverlapPatchEmbed(
patch_size=7 if i == 0 else 3,
stride=4 if i == 0 else 2,
in_channels=in_channels if i == 0 else embed_dims[i - 1],
embed_dim=embed_dims[i],
norm_cfg=norm_cfg)
block = nn.ModuleList([
MSCABlock(
channels=embed_dims[i],
attention_kernel_sizes=attention_kernel_sizes,
attention_kernel_paddings=attention_kernel_paddings,
mlp_ratio=mlp_ratios[i],
drop=drop_rate,
drop_path=dpr[cur + j],
act_cfg=act_cfg,
norm_cfg=norm_cfg) for j in range(depths[i])
])
norm = nn.LayerNorm(embed_dims[i])
cur += depths[i]
setattr(self, f'patch_embed{i + 1}', patch_embed)
setattr(self, f'block{i + 1}', block)
setattr(self, f'norm{i + 1}', norm)
def init_weights(self):
"""Initialize modules of MSCAN."""
print('init cfg', self.init_cfg)
if self.init_cfg is None:
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.LayerNorm):
constant_init(m, val=1.0, bias=0.)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[
1] * m.out_channels
fan_out //= m.groups
normal_init(
m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
else:
super(MSCAN, self).init_weights()
def forward(self, x):
"""Forward function."""
B = x.shape[0]
outs = []
for i in range(self.num_stages):
patch_embed = getattr(self, f'patch_embed{i + 1}')
block = getattr(self, f'block{i + 1}')
norm = getattr(self, f'norm{i + 1}')
x, H, W = patch_embed(x)
for blk in block:
x = blk(x, H, W)
x = norm(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
return outs

View File

@ -12,6 +12,7 @@ from .enc_head import EncHead
from .fcn_head import FCNHead
from .fpn_head import FPNHead
from .gc_head import GCHead
from .ham_head import LightHamHead
from .isa_head import ISAHead
from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator
from .lraspp_head import LRASPPHead
@ -36,5 +37,5 @@ __all__ = [
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead',
'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
'KernelUpdateHead', 'KernelUpdator'
'KernelUpdateHead', 'KernelUpdator', 'LightHamHead'
]

View File

@ -0,0 +1,258 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Originally from https://github.com/visual-attention-network/segnext
# Licensed under the Apache License, Version 2.0 (the "License")
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
class Matrix_Decomposition_2D_Base(nn.Module):
"""Base class of 2D Matrix Decomposition.
Args:
MD_S (int): The number of spatial coefficient in
Matrix Decomposition, it may be used for calculation
of the number of latent dimension D in Matrix
Decomposition. Defaults: 1.
MD_R (int): The number of latent dimension R in
Matrix Decomposition. Defaults: 64.
train_steps (int): The number of iteration steps in
Multiplicative Update (MU) rule to solve Non-negative
Matrix Factorization (NMF) in training. Defaults: 6.
eval_steps (int): The number of iteration steps in
Multiplicative Update (MU) rule to solve Non-negative
Matrix Factorization (NMF) in evaluation. Defaults: 7.
inv_t (int): Inverted multiple number to make coefficient
smaller in softmax. Defaults: 100.
rand_init (bool): Whether to initialize randomly.
Defaults: True.
"""
def __init__(self,
MD_S=1,
MD_R=64,
train_steps=6,
eval_steps=7,
inv_t=100,
rand_init=True):
super().__init__()
self.S = MD_S
self.R = MD_R
self.train_steps = train_steps
self.eval_steps = eval_steps
self.inv_t = inv_t
self.rand_init = rand_init
def _build_bases(self, B, S, D, R, cuda=False):
raise NotImplementedError
def local_step(self, x, bases, coef):
raise NotImplementedError
def local_inference(self, x, bases):
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
coef = torch.bmm(x.transpose(1, 2), bases)
coef = F.softmax(self.inv_t * coef, dim=-1)
steps = self.train_steps if self.training else self.eval_steps
for _ in range(steps):
bases, coef = self.local_step(x, bases, coef)
return bases, coef
def compute_coef(self, x, bases, coef):
raise NotImplementedError
def forward(self, x, return_bases=False):
"""Forward Function."""
B, C, H, W = x.shape
# (B, C, H, W) -> (B * S, D, N)
D = C // self.S
N = H * W
x = x.view(B * self.S, D, N)
cuda = 'cuda' in str(x.device)
if not self.rand_init and not hasattr(self, 'bases'):
bases = self._build_bases(1, self.S, D, self.R, cuda=cuda)
self.register_buffer('bases', bases)
# (S, D, R) -> (B * S, D, R)
if self.rand_init:
bases = self._build_bases(B, self.S, D, self.R, cuda=cuda)
else:
bases = self.bases.repeat(B, 1, 1)
bases, coef = self.local_inference(x, bases)
# (B * S, N, R)
coef = self.compute_coef(x, bases, coef)
# (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
x = torch.bmm(bases, coef.transpose(1, 2))
# (B * S, D, N) -> (B, C, H, W)
x = x.view(B, C, H, W)
return x
class NMF2D(Matrix_Decomposition_2D_Base):
"""Non-negative Matrix Factorization (NMF) module.
It is inherited from ``Matrix_Decomposition_2D_Base`` module.
"""
def __init__(self, args=dict()):
super().__init__(**args)
self.inv_t = 1
def _build_bases(self, B, S, D, R, cuda=False):
"""Build bases in initialization."""
if cuda:
bases = torch.rand((B * S, D, R)).cuda()
else:
bases = torch.rand((B * S, D, R))
bases = F.normalize(bases, dim=1)
return bases
def local_step(self, x, bases, coef):
"""Local step in iteration to renew bases and coefficient."""
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
numerator = torch.bmm(x.transpose(1, 2), bases)
# (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
# Multiplicative Update
coef = coef * numerator / (denominator + 1e-6)
# (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
numerator = torch.bmm(x, coef)
# (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
# Multiplicative Update
bases = bases * numerator / (denominator + 1e-6)
return bases, coef
def compute_coef(self, x, bases, coef):
"""Compute coefficient."""
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
numerator = torch.bmm(x.transpose(1, 2), bases)
# (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
# multiplication update
coef = coef * numerator / (denominator + 1e-6)
return coef
class Hamburger(nn.Module):
"""Hamburger Module. It consists of one slice of "ham" (matrix
decomposition) and two slices of "bread" (linear transformation).
Args:
ham_channels (int): Input and output channels of feature.
ham_kwargs (dict): Config of matrix decomposition module.
norm_cfg (dict | None): Config of norm layers.
"""
def __init__(self,
ham_channels=512,
ham_kwargs=dict(),
norm_cfg=None,
**kwargs):
super().__init__()
self.ham_in = ConvModule(
ham_channels, ham_channels, 1, norm_cfg=None, act_cfg=None)
self.ham = NMF2D(ham_kwargs)
self.ham_out = ConvModule(
ham_channels, ham_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
def forward(self, x):
enjoy = self.ham_in(x)
enjoy = F.relu(enjoy, inplace=True)
enjoy = self.ham(enjoy)
enjoy = self.ham_out(enjoy)
ham = F.relu(x + enjoy, inplace=True)
return ham
@HEADS.register_module()
class LightHamHead(BaseDecodeHead):
"""SegNeXt decode head.
This decode head is the implementation of `SegNeXt: Rethinking
Convolutional Attention Design for Semantic
Segmentation <https://arxiv.org/abs/2209.08575>`_.
Inspiration from https://github.com/visual-attention-network/segnext.
Specifically, LightHamHead is inspired by HamNet from
`Is Attention Better Than Matrix Decomposition?
<https://arxiv.org/abs/2109.04553>`.
Args:
ham_channels (int): input channels for Hamburger.
Defaults: 512.
ham_kwargs (int): kwagrs for Ham. Defaults: dict().
"""
def __init__(self, ham_channels=512, ham_kwargs=dict(), **kwargs):
super(LightHamHead, self).__init__(
input_transform='multiple_select', **kwargs)
self.ham_channels = ham_channels
self.squeeze = ConvModule(
sum(self.in_channels),
self.ham_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.hamburger = Hamburger(ham_channels, ham_kwargs, **kwargs)
self.align = ConvModule(
self.ham_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs):
"""Forward function."""
inputs = self._transform_inputs(inputs)
inputs = [
resize(
level,
size=inputs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners) for level in inputs
]
inputs = torch.cat(inputs, dim=1)
# apply a conv block to squeeze feature map
x = self.squeeze(inputs)
# apply hamburger module
x = self.hamburger(x)
# apply a conv block to align feature map
output = self.align(x)
output = self.cls_seg(output)
return output

View File

@ -36,6 +36,7 @@ Import:
- configs/resnest/resnest.yml
- configs/segformer/segformer.yml
- configs/segmenter/segmenter.yml
- configs/segnext/segnext.yml
- configs/sem_fpn/sem_fpn.yml
- configs/setr/setr.yml
- configs/stdc/stdc.yml

View File

@ -0,0 +1,69 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmseg.models.backbones import MSCAN
from mmseg.models.backbones.mscan import (MSCAAttention, MSCASpatialAttention,
OverlapPatchEmbed, StemConv)
def test_mscan_backbone():
# Test MSCAN Standard Forward
model = MSCAN(
embed_dims=[8, 16, 32, 64],
norm_cfg=dict(type='BN', requires_grad=True))
model.init_weights()
model.train()
batch_size = 2
imgs = torch.randn(batch_size, 3, 64, 128)
feat = model(imgs)
assert len(feat) == 4
# output for segment Head
assert feat[0].shape == torch.Size([batch_size, 8, 16, 32])
assert feat[1].shape == torch.Size([batch_size, 16, 8, 16])
assert feat[2].shape == torch.Size([batch_size, 32, 4, 8])
assert feat[3].shape == torch.Size([batch_size, 64, 2, 4])
# Test input with rare shape
batch_size = 2
imgs = torch.randn(batch_size, 3, 95, 27)
feat = model(imgs)
assert len(feat) == 4
def test_mscan_overlap_patch_embed_module():
x_overlap_patch_embed = OverlapPatchEmbed(
norm_cfg=dict(type='BN', requires_grad=True))
assert x_overlap_patch_embed.proj.in_channels == 3
assert x_overlap_patch_embed.norm.weight.shape == torch.Size([768])
x = torch.randn(2, 3, 16, 32)
x_out, H, W = x_overlap_patch_embed(x)
assert x_out.shape == torch.Size([2, 32, 768])
def test_mscan_spatial_attention_module():
x_spatial_attention = MSCASpatialAttention(8)
assert x_spatial_attention.proj_1.kernel_size == (1, 1)
assert x_spatial_attention.proj_2.stride == (1, 1)
x = torch.randn(2, 8, 16, 32)
x_out = x_spatial_attention(x)
assert x_out.shape == torch.Size([2, 8, 16, 32])
def test_mscan_attention_module():
x_attention = MSCAAttention(8)
assert x_attention.conv0.weight.shape[0] == 8
assert x_attention.conv3.kernel_size == (1, 1)
x = torch.randn(2, 8, 16, 32)
x_out = x_attention(x)
assert x_out.shape == torch.Size([2, 8, 16, 32])
def test_mscan_stem_module():
x_stem = StemConv(8, 8, norm_cfg=dict(type='BN', requires_grad=True))
assert x_stem.proj[0].weight.shape[0] == 4
assert x_stem.proj[-1].weight.shape[0] == 8
x = torch.randn(2, 8, 16, 32)
x_out, H, W = x_stem(x)
assert x_out.shape == torch.Size([2, 32, 8])
assert (H, W) == (4, 8)

View File

@ -0,0 +1,44 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmseg.models.decode_heads import LightHamHead
from .utils import _conv_has_norm, to_cuda
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
def test_ham_head():
# test without sync_bn
head = LightHamHead(
in_channels=[16, 32, 64],
in_index=[1, 2, 3],
channels=64,
ham_channels=64,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=ham_norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
ham_kwargs=dict(
MD_S=1,
MD_R=64,
train_steps=6,
eval_steps=7,
inv_t=100,
rand_init=True))
assert not _conv_has_norm(head, sync_bn=False)
inputs = [
torch.randn(1, 8, 32, 32),
torch.randn(1, 16, 16, 16),
torch.randn(1, 32, 8, 8),
torch.randn(1, 64, 4, 4)
]
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
assert head.in_channels == [16, 32, 64]
assert head.hamburger.ham_in.in_channels == 64
outputs = head(inputs)
assert outputs.shape == (1, head.num_classes, 16, 16)