[NEW] [Feature] Support SegNeXt (NeurIPS'2022) in master branch (#2600)

## Motivation

Support SegNeXt.

The old WIP branch accumulated many commits and changed files because it stayed open for too long (perhaps this could have been resolved with `git merge` or `git rebase`).

This PR is created only as a backup of the old PR
https://github.com/open-mmlab/mmsegmentation/pull/2247

Co-authored-by: MeowZheng <meowzheng@outlook.com>
Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
MengzhangLI 2023-02-24 16:08:27 +08:00 committed by GitHub
parent b2fdae7ce7
commit 70477d21ad
15 changed files with 1216 additions and 2 deletions


@@ -145,6 +145,7 @@ Supported backbones:
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
- [x] [MAE (CVPR'2022)](configs/mae)
- [x] [PoolFormer (CVPR'2022)](configs/poolformer)
- [x] [SegNeXt (NeurIPS'2022)](configs/segnext)
Supported methods:


@@ -128,6 +128,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
- [x] [MAE (CVPR'2022)](configs/mae)
- [x] [PoolFormer (CVPR'2022)](configs/poolformer)
- [x] [SegNeXt (NeurIPS'2022)](configs/segnext)
已支持的算法:


@@ -0,0 +1,63 @@
# SegNeXt
[SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation](https://arxiv.org/abs/2209.08575)
## Introduction
<!-- [ALGORITHM] -->
<a href="https://github.com/visual-attention-network/segnext">Official Repo</a>
<a href="https://github.com/open-mmlab/mmsegmentation/blob/v0.31.0/mmseg/models/backbones/mscan.py#L328">Code Snippet</a>
## Abstract
<!-- [ABSTRACT] -->
We present SegNeXt, a simple convolutional network architecture for semantic segmentation. Recent transformer-based models have dominated the field of semantic segmentation due to the efficiency of self-attention in encoding spatial information. In this paper, we show that convolutional attention is a more efficient and effective way to encode contextual information than the self-attention mechanism in transformers. By re-examining the characteristics owned by successful segmentation models, we discover several key components leading to the performance improvement of segmentation models. This motivates us to design a novel convolutional attention network that uses cheap convolutional operations. Without bells and whistles, our SegNeXt significantly improves the performance of previous state-of-the-art methods on popular benchmarks, including ADE20K, Cityscapes, COCO-Stuff, Pascal VOC, Pascal Context, and iSAID. Notably, SegNeXt outperforms EfficientNet-L2 w/ NAS-FPN and achieves 90.6% mIoU on the Pascal VOC 2012 test leaderboard using only 1/10 parameters of it. On average, SegNeXt achieves about 2.0% mIoU improvements compared to the state-of-the-art methods on the ADE20K datasets with the same or fewer computations. Code is available at [this https URL](https://github.com/uyzhang/JSeg) (Jittor) and [this https URL](https://github.com/Visual-Attention-Network/SegNeXt) (Pytorch).
<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/24582831/215688018-5d4c8366-7793-4fdf-9397-960a09fac951.png" width="70%"/>
</div>
```bibtex
@article{guo2022segnext,
title={SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation},
author={Guo, Meng-Hao and Lu, Cheng-Ze and Hou, Qibin and Liu, Zhengning and Cheng, Ming-Ming and Hu, Shi-Min},
journal={arXiv preprint arXiv:2209.08575},
year={2022}
}
```
## Pretrained model
The pretrained MSCAN models can be downloaded [here](https://cloud.tsinghua.edu.cn/d/c15b25a6745946618462/) (provided by the [original repo](https://github.com/Visual-Attention-Network/SegNeXt)). Download them and put them in the `./pretrain` folder.
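The configs in this folder reference these weights through the backbone's `init_cfg`; a minimal sketch of the relevant snippet (the checkpoint path assumes the `./pretrain` layout described above):

```python
# Excerpt of how the SegNeXt configs reference the ImageNet-pretrained MSCAN
# weights: the checkpoint only needs to be placed under ./pretrain.
model = dict(
    backbone=dict(
        type='MSCAN',
        init_cfg=dict(type='Pretrained', checkpoint='./pretrain/mscan_t.pth')))
```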
## Results and models
### ADE20K
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------- | -------- | --------- | ------- | -------- | -------------- | ----- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| SegNeXt | MSCAN-T | 512x512 | 160000 | 17.88 | 52.38 | 41.50 | 42.59 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k_20230210_140244-05bd8466.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k_20230210_140244.log.json) |
| SegNeXt | MSCAN-S | 512x512 | 160000 | 21.47 | 42.27 | 44.16 | 45.81 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k_20230214_113014-43013668.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k_20230214_113014.log.json) |
| SegNeXt | MSCAN-B | 512x512 | 160000 | 31.03 | 35.15 | 48.03 | 49.68 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k_20230209_172053-b6f6c70c.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k_20230209_172053.log.json) |
| SegNeXt | MSCAN-L | 512x512 | 160000 | 43.32 | 22.91 | 50.99 | 52.10 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k_20230209_172055-19b14b63.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k_20230209_172055.log.json) |
Note:
- The total batch size is 16. We trained SegNeXt with a single GPU because performance degrades significantly with `SyncBN` (mainly in the `OverlapPatchEmbed` modules of `MSCAN`) under PyTorch 1.9.
- There will be subtle differences during model testing because the Non-negative Matrix Factorization (NMF) in `LightHamHead` is initialized randomly. To control this randomness, set the random seed when testing. You can modify [`./tools/test.py`](https://github.com/open-mmlab/mmsegmentation/blob/master/tools/test.py) like:
```python
def main():
from mmseg.apis import set_random_seed
random_seed = xxx # set random seed recorded in training log
set_random_seed(random_seed, deterministic=False)
...
```
- Model performance is sensitive to the seed value, so please refer to the log file for the exact seed used. If you choose a different seed, the results may differ from those in the table; taking SegNeXt-L as an example, its mIoU ranges from 49.60 to 51.0. A seeded-inference sketch follows below.
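For reference, here is a minimal seeded-inference sketch: it fixes the seed before building the model so that the randomly initialized NMF bases are reproducible. The checkpoint file name matches the table above and is assumed to be downloaded locally; `demo/demo.png` is just an example image from the repository.

```python
# Minimal sketch: fix the random seed so the randomly initialized NMF bases in
# LightHamHead are reproducible, then run single-image inference.
from mmseg.apis import inference_segmentor, init_segmentor, set_random_seed

set_random_seed(0, deterministic=False)  # use the seed recorded in the training log
config = 'configs/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k.py'
checkpoint = 'segnext_mscan-t_1x16_512x512_adamw_160k_ade20k_20230210_140244-05bd8466.pth'
model = init_segmentor(config, checkpoint, device='cuda:0')
result = inference_segmentor(model, 'demo/demo.png')  # one per-pixel label map per image
```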


@@ -0,0 +1,103 @@
Collections:
- Name: SegNeXt
Metadata:
Training Data:
- ADE20K
Paper:
URL: https://arxiv.org/abs/2209.08575
Title: 'SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation'
README: configs/segnext/README.md
Code:
URL: https://github.com/open-mmlab/mmsegmentation/blob/v0.31.0/mmseg/models/backbones/mscan.py#L328
Version: v0.31.0
Converted From:
Code: https://github.com/visual-attention-network/segnext
Models:
- Name: segnext_mscan-t_1x16_512x512_adamw_160k_ade20k
In Collection: SegNeXt
Metadata:
backbone: MSCAN-T
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 19.09
hardware: A100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 17.88
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 41.5
mIoU(ms+flip): 42.59
Config: configs/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k_20230210_140244-05bd8466.pth
- Name: segnext_mscan-s_1x16_512x512_adamw_160k_ade20k
In Collection: SegNeXt
Metadata:
backbone: MSCAN-S
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 23.66
hardware: A100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 21.47
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 44.16
mIoU(ms+flip): 45.81
Config: configs/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k_20230214_113014-43013668.pth
- Name: segnext_mscan-b_1x16_512x512_adamw_160k_ade20k
In Collection: SegNeXt
Metadata:
backbone: MSCAN-B
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 28.45
hardware: A100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 31.03
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 48.03
mIoU(ms+flip): 49.68
Config: configs/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k_20230209_172053-b6f6c70c.pth
- Name: segnext_mscan-l_1x16_512x512_adamw_160k_ade20k
In Collection: SegNeXt
Metadata:
backbone: MSCAN-L
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 43.65
hardware: A100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 43.32
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 50.99
mIoU(ms+flip): 52.1
Config: configs/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k_20230209_172055-19b14b63.pth


@@ -0,0 +1,26 @@
_base_ = './segnext_mscan-t_1x16_512x512_adamw_160k_ade20k.py'
# model settings
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
type='EncoderDecoder',
backbone=dict(
embed_dims=[64, 128, 320, 512],
depths=[3, 3, 12, 3],
init_cfg=dict(type='Pretrained', checkpoint='pretrain/mscan_b.pth'),
drop_path_rate=0.1,
norm_cfg=dict(type='BN', requires_grad=True)),
decode_head=dict(
type='LightHamHead',
in_channels=[128, 320, 512],
in_index=[1, 2, 3],
channels=512,
ham_channels=512,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=ham_norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))


@@ -0,0 +1,26 @@
_base_ = './segnext_mscan-t_1x16_512x512_adamw_160k_ade20k.py'
# model settings
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
type='EncoderDecoder',
backbone=dict(
embed_dims=[64, 128, 320, 512],
depths=[3, 5, 27, 3],
init_cfg=dict(type='Pretrained', checkpoint='pretrain/mscan_l.pth'),
drop_path_rate=0.3,
norm_cfg=dict(type='BN', requires_grad=True)),
decode_head=dict(
type='LightHamHead',
in_channels=[128, 320, 512],
in_index=[1, 2, 3],
channels=1024,
ham_channels=1024,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=ham_norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))


@@ -0,0 +1,26 @@
_base_ = './segnext_mscan-t_1x16_512x512_adamw_160k_ade20k.py'
# model settings
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
type='EncoderDecoder',
backbone=dict(
embed_dims=[64, 128, 320, 512],
depths=[2, 2, 4, 2],
init_cfg=dict(type='Pretrained', checkpoint='./pretrain/mscan_s.pth'),
norm_cfg=dict(type='BN', requires_grad=True)),
decode_head=dict(
type='LightHamHead',
in_channels=[128, 320, 512],
in_index=[1, 2, 3],
channels=256,
ham_channels=256,
ham_kwargs=dict(MD_R=16),
dropout_ratio=0.1,
num_classes=150,
norm_cfg=ham_norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))


@@ -0,0 +1,125 @@
_base_ = [
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
# model settings
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(
type='MSCAN',
init_cfg=dict(type='Pretrained', checkpoint='./pretrain/mscan_t.pth'),
embed_dims=[32, 64, 160, 256],
mlp_ratios=[8, 8, 4, 4],
drop_rate=0.0,
drop_path_rate=0.1,
depths=[3, 3, 5, 2],
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='BN', requires_grad=True)),
decode_head=dict(
type='LightHamHead',
in_channels=[64, 160, 256],
in_index=[1, 2, 3],
channels=256,
ham_channels=256,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=ham_norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
ham_kwargs=dict(
MD_S=1,
MD_R=16,
train_steps=6,
eval_steps=7,
inv_t=100,
rand_init=True)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
# dataset settings
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='ResizeToMultiple', size_divisor=32),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=16,
workers_per_gpu=4,
train=dict(
type='RepeatDataset',
times=50,
dataset=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/training',
ann_dir='annotations/training',
pipeline=train_pipeline)),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline))
# optimizer
optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.00006,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={
'pos_block': dict(decay_mult=0.),
'norm': dict(decay_mult=0.),
'head': dict(lr_mult=10.)
}))
lr_config = dict(
_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)
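As a small sketch of how the settings referenced in the README notes can be double-checked, the base config can be loaded with `mmcv.Config` (assuming the repository root as the working directory):

```python
# Load the base config and inspect the settings the README notes refer to.
from mmcv import Config

cfg = Config.fromfile(
    'configs/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k.py')
print(cfg.data.samples_per_gpu)          # 16 -> total batch size of 16 on a single GPU
print(cfg.model.decode_head.ham_kwargs)  # rand_init=True is why testing is seed-sensitive
```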


@@ -11,6 +11,7 @@ from .mae import MAE
from .mit import MixVisionTransformer
from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
from .mscan import MSCAN
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1c, ResNetV1d
from .resnext import ResNeXt
@@ -26,5 +27,5 @@ __all__ = [
    'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
    'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
    'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
    'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'MSCAN'
]


@@ -0,0 +1,469 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Originally from https://github.com/visual-attention-network/segnext
# Licensed under the Apache License, Version 2.0 (the "License")
import math
import warnings
import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
trunc_normal_init)
from mmcv.runner import BaseModule
from mmseg.models.builder import BACKBONES
class Mlp(BaseModule):
"""Multi Layer Perceptron (MLP) Module.
Args:
in_features (int): The dimension of input features.
hidden_features (int): The dimension of hidden features.
Defaults: None.
out_features (int): The dimension of output features.
Defaults: None.
act_cfg (dict): Config dict for activation layer in block.
Default: dict(type='GELU').
        drop (float): Dropout rate in the MLP block.
            Defaults: 0.0.
"""
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_cfg=dict(type='GELU'),
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
self.dwconv = nn.Conv2d(
hidden_features,
hidden_features,
3,
1,
1,
bias=True,
groups=hidden_features)
self.act = build_activation_layer(act_cfg)
self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
self.drop = nn.Dropout(drop)
def forward(self, x):
"""Forward function."""
x = self.fc1(x)
x = self.dwconv(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class StemConv(BaseModule):
"""Stem Block at the beginning of Semantic Branch.
Args:
in_channels (int): The dimension of input channels.
out_channels (int): The dimension of output channels.
act_cfg (dict): Config dict for activation layer in block.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Defaults: dict(type='SyncBN', requires_grad=True).
"""
def __init__(self,
in_channels,
out_channels,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='SyncBN', requires_grad=True)):
super(StemConv, self).__init__()
self.proj = nn.Sequential(
nn.Conv2d(
in_channels,
out_channels // 2,
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1)),
build_norm_layer(norm_cfg, out_channels // 2)[1],
build_activation_layer(act_cfg),
nn.Conv2d(
out_channels // 2,
out_channels,
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1)),
build_norm_layer(norm_cfg, out_channels)[1],
)
def forward(self, x):
"""Forward function."""
x = self.proj(x)
_, _, H, W = x.size()
x = x.flatten(2).transpose(1, 2)
return x, H, W
class MSCAAttention(BaseModule):
"""Attention Module in Multi-Scale Convolutional Attention Module (MSCA).
Args:
channels (int): The dimension of channels.
kernel_sizes (list): The size of attention
kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
        paddings (list): The corresponding padding values in the
            attention module. Defaults: [2, [0, 3], [0, 5], [0, 10]].
"""
def __init__(self,
channels,
kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
paddings=[2, [0, 3], [0, 5], [0, 10]]):
super().__init__()
self.conv0 = nn.Conv2d(
channels,
channels,
kernel_size=kernel_sizes[0],
padding=paddings[0],
groups=channels)
for i, (kernel_size,
padding) in enumerate(zip(kernel_sizes[1:], paddings[1:])):
kernel_size_ = [kernel_size, kernel_size[::-1]]
padding_ = [padding, padding[::-1]]
conv_name = [f'conv{i}_1', f'conv{i}_2']
for i_kernel, i_pad, i_conv in zip(kernel_size_, padding_,
conv_name):
self.add_module(
i_conv,
nn.Conv2d(
channels,
channels,
tuple(i_kernel),
padding=i_pad,
groups=channels))
self.conv3 = nn.Conv2d(channels, channels, 1)
def forward(self, x):
"""Forward function."""
u = x.clone()
attn = self.conv0(x)
# Multi-Scale Feature extraction
attn_0 = self.conv0_1(attn)
attn_0 = self.conv0_2(attn_0)
attn_1 = self.conv1_1(attn)
attn_1 = self.conv1_2(attn_1)
attn_2 = self.conv2_1(attn)
attn_2 = self.conv2_2(attn_2)
attn = attn + attn_0 + attn_1 + attn_2
# Channel Mixing
attn = self.conv3(attn)
# Convolutional Attention
x = attn * u
return x
class MSCASpatialAttention(BaseModule):
"""Spatial Attention Module in Multi-Scale Convolutional Attention Module
(MSCA).
Args:
in_channels (int): The dimension of channels.
attention_kernel_sizes (list): The size of attention
kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
        attention_kernel_paddings (list): The corresponding padding values
            in the attention module. Defaults: [2, [0, 3], [0, 5], [0, 10]].
act_cfg (dict): Config dict for activation layer in block.
Default: dict(type='GELU').
"""
def __init__(self,
in_channels,
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
act_cfg=dict(type='GELU')):
super().__init__()
self.proj_1 = nn.Conv2d(in_channels, in_channels, 1)
self.activation = build_activation_layer(act_cfg)
self.spatial_gating_unit = MSCAAttention(in_channels,
attention_kernel_sizes,
attention_kernel_paddings)
self.proj_2 = nn.Conv2d(in_channels, in_channels, 1)
def forward(self, x):
"""Forward function."""
        shortcut = x.clone()
x = self.proj_1(x)
x = self.activation(x)
x = self.spatial_gating_unit(x)
x = self.proj_2(x)
        x = x + shortcut
return x
class MSCABlock(BaseModule):
"""Basic Multi-Scale Convolutional Attention Block. It leverage the large-
kernel attention (LKA) mechanism to build both channel and spatial
attention. In each branch, it uses two depth-wise strip convolutions to
approximate standard depth-wise convolutions with large kernels. The kernel
size for each branch is set to 7, 11, and 21, respectively.
Args:
channels (int): The dimension of channels.
attention_kernel_sizes (list): The size of attention
kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
        attention_kernel_paddings (list): The corresponding padding values
            in the attention module. Defaults: [2, [0, 3], [0, 5], [0, 10]].
        mlp_ratio (float): Ratio used to compute the hidden dimension of the
            MLP layer from the input dimension. Defaults: 4.0.
        drop (float): Dropout rate in the MLP block.
            Defaults: 0.0.
drop_path (float): The ratio of drop paths.
Defaults: 0.0.
act_cfg (dict): Config dict for activation layer in block.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Defaults: dict(type='SyncBN', requires_grad=True).
"""
def __init__(self,
channels,
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
mlp_ratio=4.,
drop=0.,
drop_path=0.,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='SyncBN', requires_grad=True)):
super().__init__()
self.norm1 = build_norm_layer(norm_cfg, channels)[1]
self.attn = MSCASpatialAttention(channels, attention_kernel_sizes,
attention_kernel_paddings, act_cfg)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = build_norm_layer(norm_cfg, channels)[1]
mlp_hidden_channels = int(channels * mlp_ratio)
self.mlp = Mlp(
in_features=channels,
hidden_features=mlp_hidden_channels,
act_cfg=act_cfg,
drop=drop)
layer_scale_init_value = 1e-2
self.layer_scale_1 = nn.Parameter(
layer_scale_init_value * torch.ones((channels)),
requires_grad=True)
self.layer_scale_2 = nn.Parameter(
layer_scale_init_value * torch.ones((channels)),
requires_grad=True)
def forward(self, x, H, W):
"""Forward function."""
B, N, C = x.shape
x = x.permute(0, 2, 1).view(B, C, H, W)
x = x + self.drop_path(
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) *
self.attn(self.norm1(x)))
x = x + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) *
self.mlp(self.norm2(x)))
x = x.view(B, C, N).permute(0, 2, 1)
return x
class OverlapPatchEmbed(BaseModule):
"""Image to Patch Embedding.
Args:
patch_size (int): The patch size.
Defaults: 7.
stride (int): Stride of the convolutional layer.
Default: 4.
in_channels (int): The number of input channels.
Defaults: 3.
        embed_dim (int): The dimensions of embedding.
Defaults: 768.
norm_cfg (dict): Config dict for normalization layer.
Defaults: dict(type='SyncBN', requires_grad=True).
"""
def __init__(self,
patch_size=7,
stride=4,
in_channels=3,
embed_dim=768,
norm_cfg=dict(type='SyncBN', requires_grad=True)):
super().__init__()
self.proj = nn.Conv2d(
in_channels,
embed_dim,
kernel_size=patch_size,
stride=stride,
padding=patch_size // 2)
self.norm = build_norm_layer(norm_cfg, embed_dim)[1]
def forward(self, x):
"""Forward function."""
x = self.proj(x)
_, _, H, W = x.shape
x = self.norm(x)
x = x.flatten(2).transpose(1, 2)
return x, H, W
@BACKBONES.register_module()
class MSCAN(BaseModule):
"""SegNeXt Multi-Scale Convolutional Attention Network (MCSAN) backbone.
This backbone is the implementation of `SegNeXt: Rethinking
Convolutional Attention Design for Semantic
Segmentation <https://arxiv.org/abs/2209.08575>`_.
Inspiration from https://github.com/visual-attention-network/segnext.
Args:
in_channels (int): The number of input channels. Defaults: 3.
embed_dims (list[int]): Embedding dimension.
Defaults: [64, 128, 256, 512].
mlp_ratios (list[int]): Ratio of mlp hidden dim to embedding dim.
Defaults: [4, 4, 4, 4].
drop_rate (float): Dropout rate. Defaults: 0.
drop_path_rate (float): Stochastic depth rate. Defaults: 0.
        depths (list[int]): Depths of each MSCAN stage.
Default: [3, 4, 6, 3].
num_stages (int): MSCAN stages. Default: 4.
attention_kernel_sizes (list): Size of attention kernel in
Attention Module (Figure 2(b) of original paper).
Defaults: [5, [1, 7], [1, 11], [1, 21]].
attention_kernel_paddings (list): Size of attention paddings
in Attention Module (Figure 2(b) of original paper).
Defaults: [2, [0, 3], [0, 5], [0, 10]].
norm_cfg (dict): Config of norm layers.
Defaults: dict(type='SyncBN', requires_grad=True).
pretrained (str, optional): model pretrained path.
Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
in_channels=3,
embed_dims=[64, 128, 256, 512],
mlp_ratios=[4, 4, 4, 4],
drop_rate=0.,
drop_path_rate=0.,
depths=[3, 4, 6, 3],
num_stages=4,
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='SyncBN', requires_grad=True),
pretrained=None,
init_cfg=None):
super(MSCAN, self).__init__(init_cfg=init_cfg)
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is not None:
raise TypeError('pretrained must be a str or None')
self.depths = depths
self.num_stages = num_stages
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
cur = 0
for i in range(num_stages):
if i == 0:
patch_embed = StemConv(3, embed_dims[0], norm_cfg=norm_cfg)
else:
patch_embed = OverlapPatchEmbed(
patch_size=7 if i == 0 else 3,
stride=4 if i == 0 else 2,
in_channels=in_channels if i == 0 else embed_dims[i - 1],
embed_dim=embed_dims[i],
norm_cfg=norm_cfg)
block = nn.ModuleList([
MSCABlock(
channels=embed_dims[i],
attention_kernel_sizes=attention_kernel_sizes,
attention_kernel_paddings=attention_kernel_paddings,
mlp_ratio=mlp_ratios[i],
drop=drop_rate,
drop_path=dpr[cur + j],
act_cfg=act_cfg,
norm_cfg=norm_cfg) for j in range(depths[i])
])
norm = nn.LayerNorm(embed_dims[i])
cur += depths[i]
setattr(self, f'patch_embed{i + 1}', patch_embed)
setattr(self, f'block{i + 1}', block)
setattr(self, f'norm{i + 1}', norm)
def init_weights(self):
"""Initialize modules of MSCAN."""
print('init cfg', self.init_cfg)
if self.init_cfg is None:
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.LayerNorm):
constant_init(m, val=1.0, bias=0.)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[
1] * m.out_channels
fan_out //= m.groups
normal_init(
m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
else:
super(MSCAN, self).init_weights()
def forward(self, x):
"""Forward function."""
B = x.shape[0]
outs = []
for i in range(self.num_stages):
patch_embed = getattr(self, f'patch_embed{i + 1}')
block = getattr(self, f'block{i + 1}')
norm = getattr(self, f'norm{i + 1}')
x, H, W = patch_embed(x)
for blk in block:
x = blk(x, H, W)
x = norm(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
return outs


@@ -12,6 +12,7 @@ from .enc_head import EncHead
from .fcn_head import FCNHead
from .fpn_head import FPNHead
from .gc_head import GCHead
from .ham_head import LightHamHead
from .isa_head import ISAHead
from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator
from .lraspp_head import LRASPPHead
@@ -36,5 +37,5 @@ __all__ = [
    'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
    'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead',
    'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
    'KernelUpdateHead', 'KernelUpdator', 'LightHamHead'
]


@@ -0,0 +1,258 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Originally from https://github.com/visual-attention-network/segnext
# Licensed under the Apache License, Version 2.0 (the "License")
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
class Matrix_Decomposition_2D_Base(nn.Module):
"""Base class of 2D Matrix Decomposition.
Args:
        MD_S (int): The number of spatial coefficients in Matrix
            Decomposition; it is used to calculate the latent dimension D
            in Matrix Decomposition. Defaults: 1.
MD_R (int): The number of latent dimension R in
Matrix Decomposition. Defaults: 64.
train_steps (int): The number of iteration steps in
Multiplicative Update (MU) rule to solve Non-negative
Matrix Factorization (NMF) in training. Defaults: 6.
eval_steps (int): The number of iteration steps in
Multiplicative Update (MU) rule to solve Non-negative
Matrix Factorization (NMF) in evaluation. Defaults: 7.
        inv_t (int): Inverse temperature applied to the coefficients before
            the softmax; larger values produce a sharper distribution.
            Defaults: 100.
rand_init (bool): Whether to initialize randomly.
Defaults: True.
"""
def __init__(self,
MD_S=1,
MD_R=64,
train_steps=6,
eval_steps=7,
inv_t=100,
rand_init=True):
super().__init__()
self.S = MD_S
self.R = MD_R
self.train_steps = train_steps
self.eval_steps = eval_steps
self.inv_t = inv_t
self.rand_init = rand_init
def _build_bases(self, B, S, D, R, cuda=False):
raise NotImplementedError
def local_step(self, x, bases, coef):
raise NotImplementedError
def local_inference(self, x, bases):
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
coef = torch.bmm(x.transpose(1, 2), bases)
coef = F.softmax(self.inv_t * coef, dim=-1)
steps = self.train_steps if self.training else self.eval_steps
for _ in range(steps):
bases, coef = self.local_step(x, bases, coef)
return bases, coef
def compute_coef(self, x, bases, coef):
raise NotImplementedError
def forward(self, x, return_bases=False):
"""Forward Function."""
B, C, H, W = x.shape
# (B, C, H, W) -> (B * S, D, N)
D = C // self.S
N = H * W
x = x.view(B * self.S, D, N)
cuda = 'cuda' in str(x.device)
if not self.rand_init and not hasattr(self, 'bases'):
bases = self._build_bases(1, self.S, D, self.R, cuda=cuda)
self.register_buffer('bases', bases)
# (S, D, R) -> (B * S, D, R)
if self.rand_init:
bases = self._build_bases(B, self.S, D, self.R, cuda=cuda)
else:
bases = self.bases.repeat(B, 1, 1)
bases, coef = self.local_inference(x, bases)
# (B * S, N, R)
coef = self.compute_coef(x, bases, coef)
# (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
x = torch.bmm(bases, coef.transpose(1, 2))
# (B * S, D, N) -> (B, C, H, W)
x = x.view(B, C, H, W)
return x
class NMF2D(Matrix_Decomposition_2D_Base):
"""Non-negative Matrix Factorization (NMF) module.
It is inherited from ``Matrix_Decomposition_2D_Base`` module.
"""
def __init__(self, args=dict()):
super().__init__(**args)
self.inv_t = 1
def _build_bases(self, B, S, D, R, cuda=False):
"""Build bases in initialization."""
if cuda:
bases = torch.rand((B * S, D, R)).cuda()
else:
bases = torch.rand((B * S, D, R))
bases = F.normalize(bases, dim=1)
return bases
def local_step(self, x, bases, coef):
"""Local step in iteration to renew bases and coefficient."""
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
numerator = torch.bmm(x.transpose(1, 2), bases)
# (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
# Multiplicative Update
coef = coef * numerator / (denominator + 1e-6)
# (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
numerator = torch.bmm(x, coef)
# (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
# Multiplicative Update
bases = bases * numerator / (denominator + 1e-6)
return bases, coef
def compute_coef(self, x, bases, coef):
"""Compute coefficient."""
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
numerator = torch.bmm(x.transpose(1, 2), bases)
# (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
# multiplication update
coef = coef * numerator / (denominator + 1e-6)
return coef
class Hamburger(nn.Module):
"""Hamburger Module. It consists of one slice of "ham" (matrix
decomposition) and two slices of "bread" (linear transformation).
Args:
ham_channels (int): Input and output channels of feature.
ham_kwargs (dict): Config of matrix decomposition module.
norm_cfg (dict | None): Config of norm layers.
"""
def __init__(self,
ham_channels=512,
ham_kwargs=dict(),
norm_cfg=None,
**kwargs):
super().__init__()
self.ham_in = ConvModule(
ham_channels, ham_channels, 1, norm_cfg=None, act_cfg=None)
self.ham = NMF2D(ham_kwargs)
self.ham_out = ConvModule(
ham_channels, ham_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
def forward(self, x):
enjoy = self.ham_in(x)
enjoy = F.relu(enjoy, inplace=True)
enjoy = self.ham(enjoy)
enjoy = self.ham_out(enjoy)
ham = F.relu(x + enjoy, inplace=True)
return ham
@HEADS.register_module()
class LightHamHead(BaseDecodeHead):
"""SegNeXt decode head.
This decode head is the implementation of `SegNeXt: Rethinking
Convolutional Attention Design for Semantic
Segmentation <https://arxiv.org/abs/2209.08575>`_.
Inspiration from https://github.com/visual-attention-network/segnext.
Specifically, LightHamHead is inspired by HamNet from
`Is Attention Better Than Matrix Decomposition?
<https://arxiv.org/abs/2109.04553>`.
Args:
ham_channels (int): input channels for Hamburger.
Defaults: 512.
        ham_kwargs (dict): kwargs for Ham. Defaults: dict().
"""
def __init__(self, ham_channels=512, ham_kwargs=dict(), **kwargs):
super(LightHamHead, self).__init__(
input_transform='multiple_select', **kwargs)
self.ham_channels = ham_channels
self.squeeze = ConvModule(
sum(self.in_channels),
self.ham_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.hamburger = Hamburger(ham_channels, ham_kwargs, **kwargs)
self.align = ConvModule(
self.ham_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs):
"""Forward function."""
inputs = self._transform_inputs(inputs)
inputs = [
resize(
level,
size=inputs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners) for level in inputs
]
inputs = torch.cat(inputs, dim=1)
# apply a conv block to squeeze feature map
x = self.squeeze(inputs)
# apply hamburger module
x = self.hamburger(x)
# apply a conv block to align feature map
output = self.align(x)
output = self.cls_seg(output)
return output


@@ -36,6 +36,7 @@ Import:
- configs/resnest/resnest.yml
- configs/segformer/segformer.yml
- configs/segmenter/segmenter.yml
- configs/segnext/segnext.yml
- configs/sem_fpn/sem_fpn.yml
- configs/setr/setr.yml
- configs/stdc/stdc.yml


@@ -0,0 +1,69 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmseg.models.backbones import MSCAN
from mmseg.models.backbones.mscan import (MSCAAttention, MSCASpatialAttention,
OverlapPatchEmbed, StemConv)
def test_mscan_backbone():
# Test MSCAN Standard Forward
model = MSCAN(
embed_dims=[8, 16, 32, 64],
norm_cfg=dict(type='BN', requires_grad=True))
model.init_weights()
model.train()
batch_size = 2
imgs = torch.randn(batch_size, 3, 64, 128)
feat = model(imgs)
assert len(feat) == 4
# output for segment Head
assert feat[0].shape == torch.Size([batch_size, 8, 16, 32])
assert feat[1].shape == torch.Size([batch_size, 16, 8, 16])
assert feat[2].shape == torch.Size([batch_size, 32, 4, 8])
assert feat[3].shape == torch.Size([batch_size, 64, 2, 4])
# Test input with rare shape
batch_size = 2
imgs = torch.randn(batch_size, 3, 95, 27)
feat = model(imgs)
assert len(feat) == 4
def test_mscan_overlap_patch_embed_module():
x_overlap_patch_embed = OverlapPatchEmbed(
norm_cfg=dict(type='BN', requires_grad=True))
assert x_overlap_patch_embed.proj.in_channels == 3
assert x_overlap_patch_embed.norm.weight.shape == torch.Size([768])
x = torch.randn(2, 3, 16, 32)
x_out, H, W = x_overlap_patch_embed(x)
assert x_out.shape == torch.Size([2, 32, 768])
def test_mscan_spatial_attention_module():
x_spatial_attention = MSCASpatialAttention(8)
assert x_spatial_attention.proj_1.kernel_size == (1, 1)
assert x_spatial_attention.proj_2.stride == (1, 1)
x = torch.randn(2, 8, 16, 32)
x_out = x_spatial_attention(x)
assert x_out.shape == torch.Size([2, 8, 16, 32])
def test_mscan_attention_module():
x_attention = MSCAAttention(8)
assert x_attention.conv0.weight.shape[0] == 8
assert x_attention.conv3.kernel_size == (1, 1)
x = torch.randn(2, 8, 16, 32)
x_out = x_attention(x)
assert x_out.shape == torch.Size([2, 8, 16, 32])
def test_mscan_stem_module():
x_stem = StemConv(8, 8, norm_cfg=dict(type='BN', requires_grad=True))
assert x_stem.proj[0].weight.shape[0] == 4
assert x_stem.proj[-1].weight.shape[0] == 8
x = torch.randn(2, 8, 16, 32)
x_out, H, W = x_stem(x)
assert x_out.shape == torch.Size([2, 32, 8])
assert (H, W) == (4, 8)


@@ -0,0 +1,44 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmseg.models.decode_heads import LightHamHead
from .utils import _conv_has_norm, to_cuda
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
def test_ham_head():
# test without sync_bn
head = LightHamHead(
in_channels=[16, 32, 64],
in_index=[1, 2, 3],
channels=64,
ham_channels=64,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=ham_norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
ham_kwargs=dict(
MD_S=1,
MD_R=64,
train_steps=6,
eval_steps=7,
inv_t=100,
rand_init=True))
assert not _conv_has_norm(head, sync_bn=False)
inputs = [
torch.randn(1, 8, 32, 32),
torch.randn(1, 16, 16, 16),
torch.randn(1, 32, 8, 8),
torch.randn(1, 64, 4, 4)
]
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
assert head.in_channels == [16, 32, 64]
assert head.hamburger.ham_in.in_channels == 64
outputs = head(inputs)
assert outputs.shape == (1, head.num_classes, 16, 16)