[Project] Add support for Visual Attention Network (VAN) (#2987)

## Motivation
Add support for Visual Attention Network (VAN). The original implementation can be found at
https://github.com/Visual-Attention-Network/VAN-Segmentation



## Modification
Added a new folder `mmsegmentation/projects/van/` with 13 config files in total. Performance is basically aligned with the original implementation (not every config has been re-run).


## Use cases (Optional)
Before running, you may need to download the pretrained models from
https://cloud.tsinghua.edu.cn/d/0100f0cea37d41ba8d08/
and move them to the folder mmsegmentation/pretrained/, e.g.
"mmsegmentation/pretrained/van_b2.pth".
After that, run the following commands:

    cd mmsegmentation
    bash tools/dist_train.sh projects/van/configs/van/van-b2_pre1k_upernet_4xb2-160k_ade20k-512x512.py 4

---------

Co-authored-by: xiexinch <xiexinch@outlook.com>
tang576225574 2023-05-22 20:26:26 +08:00 committed by GitHub
parent 7d6156776e
commit 0bfe255fe6
16 changed files with 509 additions and 0 deletions

projects/van/README.md
@@ -0,0 +1,101 @@
# Visual Attention Network (VAN) for Segmentation
This repo is a PyTorch implementation of applying **VAN** (**Visual Attention Network**) to semantic segmentation.
The code is adapted from [VAN-Segmentation](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/README.md?plain=1).
More details can be found in [**Visual Attention Network**](https://arxiv.org/abs/2202.09741).
## Citation
```bib
@article{guo2022visual,
title={Visual Attention Network},
author={Guo, Meng-Hao and Lu, Cheng-Ze and Liu, Zheng-Ning and Cheng, Ming-Ming and Hu, Shi-Min},
journal={arXiv preprint arXiv:2202.09741},
year={2022}
}
```
## Results
**Note**: Pre-trained models can be found at [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/0100f0cea37d41ba8d08/).
Results can be found in [VAN-Segmentation](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/README.md?plain=1)
We provide evaluation results of the converted weights.
| Method | Backbone | mIoU | Download |
| :-----: | :----------: | :---: | :--------------------------------------------------------------------------------------------------------------------------------------------: |
| UPerNet | VAN-B2 | 49.35 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b2-in1kpre_upernet_3rdparty_512x512-ade20k_20230522-19c58aee.pth) |
| UPerNet | VAN-B3 | 49.71 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b3-in1kpre_upernet_3rdparty_512x512-ade20k_20230522-653bd6b7.pth) |
| UPerNet | VAN-B4 | 51.56 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b4-in1kpre_upernet_3rdparty_512x512-ade20k_20230522-653bd6b7.pth) |
| UPerNet | VAN-B4-in22k | 52.61 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b4-in22kpre_upernet_3rdparty_512x512-ade20k_20230522-4a4d744a.pth) |
| UPerNet | VAN-B5-in22k | 53.11 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b5-in22kpre_upernet_3rdparty_512x512-ade20k_20230522-5bb6f2b4.pth) |
| UPerNet | VAN-B6-in22k | 54.25 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b6-in22kpre_upernet_3rdparty_512x512-ade20k_20230522-e226b363.pth) |
| FPN | VAN-B0 | 38.65 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b0-in1kpre_fpn_3rdparty_512x512-ade20k_20230522-75a76298.pth) |
| FPN | VAN-B1 | 43.22 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b1-in1kpre_fpn_3rdparty_512x512-ade20k_20230522-104499ff.pth) |
| FPN | VAN-B2 | 46.84 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b2-in1kpre_fpn_3rdparty_512x512-ade20k_20230522-7074e6f8.pth) |
| FPN | VAN-B3 | 48.32 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b3-in1kpre_fpn_3rdparty_512x512-ade20k_20230522-2c3b7f5e.pth) |
## Preparation
Install MMSegmentation and download ADE20K according to the guidelines in MMSegmentation.
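A minimal sketch of the expected dataset location, assuming the standard MMSegmentation ADE20K preparation (`/path/to/ADEChallengeData2016` is a placeholder):
```shell
# Expected layout (assumed):
#   mmsegmentation/data/ade/ADEChallengeData2016/images/{training,validation}
#   mmsegmentation/data/ade/ADEChallengeData2016/annotations/{training,validation}
cd mmsegmentation
mkdir -p data/ade
ln -s /path/to/ADEChallengeData2016 data/ade/  # or copy the dataset here
```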
## Requirement
**Step 0.** Install [MMCV](https://github.com/open-mmlab/mmcv) using [MIM](https://github.com/open-mmlab/mim).
```shell
pip install -U openmim
mim install mmengine
mim install "mmcv>=2.0.0"
```
**Step 1.** Install MMSegmentation.
Case a: If you develop and run mmseg directly, install it from source:
```shell
git clone -b main https://github.com/open-mmlab/mmsegmentation.git
cd mmsegmentation
pip install -v -e .
```
Case b: If you use mmsegmentation as a dependency or third-party package, install it with pip:
```shell
pip install "mmsegmentation>=1.0.0"
```
## Training
To train with 4 GPUs (the default setting for these configs), run:
```bash
bash tools/dist_train.sh projects/van/configs/van/van-b2_pre1k_upernet_4xb2-160k_ade20k-512x512.py 4
```
## Evaluation
To evaluate a trained model, for example, run:
```bash
bash tools/dist_test.sh projects/van/configs/van/van-b2_pre1k_upernet_4xb2-160k_ade20k-512x512.py work_dirs/van-b2_pre1k_upernet_4xb2-160k_ade20k-512x512/iter_160000.pth 4
```
## FLOPs
To calculate FLOPs for a model, run:
```bash
python tools/analysis_tools/get_flops.py projects/van/configs/van/van-b2_pre1k_upernet_4xb2-160k_ade20k-512x512.py --shape 512 512
```
## Acknowledgment
Our implementation is mainly based on [mmsegmentation](https://github.com/open-mmlab/mmsegmentation/tree/v0.12.0), [Swin-Transformer](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation), [PoolFormer](https://github.com/sail-sg/poolformer), [Enjoy-Hamburger](https://github.com/Gsunshine/Enjoy-Hamburger) and [VAN-Segmentation](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/README.md?plain=1). Thanks to their authors.
## LICENSE
This repo is under the Apache-2.0 license. For commercial use, please contact the authors.

@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .van import VAN
__all__ = ['VAN']

@@ -0,0 +1,124 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch
import torch.nn as nn
from mmengine.model import BaseModule
from mmseg.models.backbones.mscan import (MSCAN, MSCABlock,
MSCASpatialAttention,
OverlapPatchEmbed)
from mmseg.registry import MODELS
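# Large Kernel Attention (LKA) from the VAN paper: a 5x5 depth-wise conv,
# a 7x7 depth-wise dilated conv (dilation 3) and a 1x1 point-wise conv whose
# output gates the input features element-wise.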
class VANAttentionModule(BaseModule):
def __init__(self, in_channels):
super().__init__()
self.conv0 = nn.Conv2d(
in_channels, in_channels, 5, padding=2, groups=in_channels)
self.conv_spatial = nn.Conv2d(
in_channels,
in_channels,
7,
stride=1,
padding=9,
groups=in_channels,
dilation=3)
self.conv1 = nn.Conv2d(in_channels, in_channels, 1)
def forward(self, x):
u = x.clone()
attn = self.conv0(x)
attn = self.conv_spatial(attn)
attn = self.conv1(attn)
return u * attn
class VANSpatialAttention(MSCASpatialAttention):
def __init__(self, in_channels, act_cfg=dict(type='GELU')):
super().__init__(in_channels, act_cfg=act_cfg)
self.spatial_gating_unit = VANAttentionModule(in_channels)
class VANBlock(MSCABlock):
def __init__(self,
channels,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='SyncBN', requires_grad=True)):
super().__init__(
channels,
mlp_ratio=mlp_ratio,
drop=drop,
drop_path=drop_path,
act_cfg=act_cfg,
norm_cfg=norm_cfg)
self.attn = VANSpatialAttention(channels)
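# VAN reuses the MSCAN (SegNeXt) backbone skeleton, only swapping its
# multi-scale convolutional attention for the LKA-based VANSpatialAttention
# defined above.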
@MODELS.register_module()
class VAN(MSCAN):
def __init__(self,
in_channels=3,
embed_dims=[64, 128, 256, 512],
mlp_ratios=[8, 8, 4, 4],
drop_rate=0.,
drop_path_rate=0.,
depths=[3, 4, 6, 3],
num_stages=4,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='SyncBN', requires_grad=True),
pretrained=None,
init_cfg=None):
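        # Intentionally call BaseModule.__init__ and skip MSCAN.__init__:
        # the stages are rebuilt below with VANBlock instead of MSCABlock.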
super(MSCAN, self).__init__(init_cfg=init_cfg)
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is not None:
raise TypeError('pretrained must be a str or None')
self.depths = depths
self.num_stages = num_stages
# stochastic depth decay rule
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
]
cur = 0
for i in range(num_stages):
patch_embed = OverlapPatchEmbed(
patch_size=7 if i == 0 else 3,
stride=4 if i == 0 else 2,
in_channels=in_channels if i == 0 else embed_dims[i - 1],
embed_dim=embed_dims[i],
norm_cfg=norm_cfg)
block = nn.ModuleList([
VANBlock(
channels=embed_dims[i],
mlp_ratio=mlp_ratios[i],
drop=drop_rate,
drop_path=dpr[cur + j],
act_cfg=act_cfg,
norm_cfg=norm_cfg) for j in range(depths[i])
])
norm = nn.LayerNorm(embed_dims[i])
cur += depths[i]
setattr(self, f'patch_embed{i + 1}', patch_embed)
setattr(self, f'block{i + 1}', block)
setattr(self, f'norm{i + 1}', norm)
def init_weights(self):
return super().init_weights()
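
For reference, a minimal usage sketch of the backbone above (not part of the file; it assumes the snippet is run from the mmsegmentation repository root so that `projects` is importable):

```python
import torch
from mmengine.registry import init_default_scope

# Importing the project package registers the 'VAN' backbone, mirroring the
# `custom_imports` entry used by the configs below.
import projects.van.backbones  # noqa: F401
from mmseg.registry import MODELS

init_default_scope('mmseg')
backbone = MODELS.build(
    dict(
        type='VAN',
        embed_dims=[32, 64, 160, 256],
        depths=[3, 3, 5, 2],
        # Plain BN instead of SyncBN so the snippet runs in a single process.
        norm_cfg=dict(type='BN', requires_grad=True)))
backbone.eval()
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 512, 512))
print([f.shape for f in feats])  # four feature maps at strides 4, 8, 16, 32
```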

@@ -0,0 +1,14 @@
# dataset settings
_base_ = '../../../../../configs/_base_/datasets/ade20k.py'
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 512), keep_ratio=True),
dict(type='ResizeToMultiple', size_divisor=32),
    # Load annotations after ``Resize`` because the ground truth
    # does not need the resize transform.
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='PackSegInputs')
]
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

@@ -0,0 +1,43 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255,
size=(512, 512))
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
backbone=dict(
type='VAN',
embed_dims=[32, 64, 160, 256],
drop_rate=0.0,
drop_path_rate=0.1,
depths=[3, 3, 5, 2],
act_cfg=dict(type='GELU'),
norm_cfg=norm_cfg,
init_cfg=dict()),
neck=dict(
type='FPN',
in_channels=[32, 64, 160, 256],
out_channels=256,
num_outs=4),
decode_head=dict(
type='FPNHead',
in_channels=[256, 256, 256, 256],
in_index=[0, 1, 2, 3],
feature_strides=[4, 8, 16, 32],
channels=128,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

@@ -0,0 +1,51 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255,
size=(512, 512))
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
backbone=dict(
type='VAN',
embed_dims=[32, 64, 160, 256],
drop_rate=0.0,
drop_path_rate=0.1,
depths=[3, 3, 5, 2],
act_cfg=dict(type='GELU'),
norm_cfg=norm_cfg,
init_cfg=dict()),
decode_head=dict(
type='UPerHead',
in_channels=[32, 64, 160, 256],
in_index=[0, 1, 2, 3],
pool_scales=(1, 2, 3, 6),
channels=512,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=160,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

@@ -0,0 +1,8 @@
_base_ = './van-b2_fpn_8xb4-40k_ade20k-512x512.py'
ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b0_3rdparty_20230522-956f5e0d.pth' # noqa
model = dict(
backbone=dict(
embed_dims=[32, 64, 160, 256],
depths=[3, 3, 5, 2],
init_cfg=dict(type='Pretrained', checkpoint=ckpt_path)),
neck=dict(in_channels=[32, 64, 160, 256]))

@@ -0,0 +1,6 @@
_base_ = './van-b2_fpn_8xb4-40k_ade20k-512x512.py'
ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b1_3rdparty_20230522-3adb117f.pth' # noqa
model = dict(
backbone=dict(
depths=[2, 2, 4, 2],
init_cfg=dict(type='Pretrained', checkpoint=ckpt_path)))

@@ -0,0 +1,53 @@
_base_ = [
'../_base_/models/van_fpn.py',
'../_base_/datasets/ade20k.py',
'../../../../configs/_base_/default_runtime.py',
]
custom_imports = dict(imports=['projects.van.backbones'])
ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b2_3rdparty_20230522-636fac93.pth' # noqa
model = dict(
type='EncoderDecoder',
backbone=dict(
embed_dims=[64, 128, 320, 512],
depths=[3, 3, 12, 3],
init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
drop_path_rate=0.2),
neck=dict(in_channels=[64, 128, 320, 512]),
decode_head=dict(num_classes=150))
train_dataloader = dict(batch_size=4)
# We use 8 GPUs instead of the 4 used in mmsegmentation, so lr*2 and max_iters/2.
gpu_multiples = 2
max_iters = 80000 // gpu_multiples
interval = 8000 // gpu_multiples
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(
type='AdamW',
lr=0.0001 * gpu_multiples,
# betas=(0.9, 0.999),
weight_decay=0.0001),
clip_grad=None)
# learning policy
param_scheduler = [
dict(
type='PolyLR',
power=0.9,
eta_min=0.0,
begin=0,
end=max_iters,
by_epoch=False,
)
]
train_cfg = dict(
type='IterBasedTrainLoop', max_iters=max_iters, val_interval=interval)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=interval),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))

@@ -0,0 +1,46 @@
_base_ = [
'../_base_/models/van_upernet.py', '../_base_/datasets/ade20k.py',
'../../../../configs/_base_/default_runtime.py',
'../../../../configs/_base_/schedules/schedule_160k.py'
]
custom_imports = dict(imports=['projects.van.backbones'])
ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b2_3rdparty_20230522-636fac93.pth' # noqa
model = dict(
type='EncoderDecoder',
backbone=dict(
embed_dims=[64, 128, 320, 512],
depths=[3, 3, 12, 3],
init_cfg=dict(type='Pretrained', checkpoint=ckpt_path)),
decode_head=dict(in_channels=[64, 128, 320, 512], num_classes=150),
auxiliary_head=dict(in_channels=320, num_classes=150))
# AdamW optimizer
# no weight decay for position embedding & layer norm in backbone
optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(
type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
clip_grad=None,
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))
# learning policy
param_scheduler = [
dict(
type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
dict(
type='PolyLR',
power=1.0,
begin=1500,
end=_base_.train_cfg.max_iters,
eta_min=0.0,
by_epoch=False,
)
]
# By default, models are trained on 8 GPUs with 2 images per GPU
train_dataloader = dict(batch_size=2)

@@ -0,0 +1,11 @@
_base_ = './van-b2_fpn_8xb4-40k_ade20k-512x512.py'
ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b3_3rdparty_20230522-a184e051.pth' # noqa
model = dict(
type='EncoderDecoder',
backbone=dict(
embed_dims=[64, 128, 320, 512],
depths=[3, 5, 27, 3],
init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
drop_path_rate=0.3),
neck=dict(in_channels=[64, 128, 320, 512]))
train_dataloader = dict(batch_size=4)

@@ -0,0 +1,8 @@
_base_ = './van-b2_upernet_4xb2-160k_ade20k-512x512.py'
ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b3_3rdparty_20230522-a184e051.pth' # noqa
model = dict(
type='EncoderDecoder',
backbone=dict(
depths=[3, 5, 27, 3],
init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
drop_path_rate=0.3))

@@ -0,0 +1,10 @@
_base_ = './van-b2_upernet_4xb2-160k_ade20k-512x512.py'
ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b4-in22k_3rdparty_20230522-5e31cafb.pth' # noqa
model = dict(
backbone=dict(
depths=[3, 6, 40, 3],
init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
drop_path_rate=0.4))
# By default, models are trained on 8 GPUs with 2 images per GPU
train_dataloader = dict(batch_size=4)

@@ -0,0 +1,10 @@
_base_ = './van-b2_upernet_4xb2-160k_ade20k-512x512.py'
ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b4_3rdparty_20230522-1d71c077.pth' # noqa
model = dict(
backbone=dict(
depths=[3, 6, 40, 3],
init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
drop_path_rate=0.4))
# By default, models are trained on 4 GPUs with 4 images per GPU
train_dataloader = dict(batch_size=4)

@@ -0,0 +1,10 @@
_base_ = './van-b2_upernet_4xb2-160k_ade20k-512x512.py'
ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b5-in22k_3rdparty_20230522-b26134d7.pth' # noqa
model = dict(
backbone=dict(
embed_dims=[96, 192, 480, 768],
depths=[3, 3, 24, 3],
init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
drop_path_rate=0.4),
decode_head=dict(in_channels=[96, 192, 480, 768], num_classes=150),
auxiliary_head=dict(in_channels=480, num_classes=150))

@@ -0,0 +1,10 @@
_base_ = './van-b2_upernet_4xb2-160k_ade20k-512x512.py'
ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b6-in22k_3rdparty_20230522-5e5172a3.pth' # noqa
model = dict(
backbone=dict(
embed_dims=[96, 192, 384, 768],
depths=[6, 6, 90, 6],
init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
drop_path_rate=0.5),
decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
auxiliary_head=dict(in_channels=384, num_classes=150))