tang576225574 0bfe255fe6
[Project] Added support for Visual Attention Network (VAN) (#2987)
## Motivation
The original version of Visual Attention Network (VAN) can be found at
https://github.com/Visual-Attention-Network/VAN-Segmentation
This PR adds support for Visual Attention Network (VAN).



## Modification
Added a single folder, mmsegmentation/projects/van/, containing 13 configs in total; performance is basically aligned with the original implementation (not all configs have been run).


## Use cases (Optional)
Before running, you may need to download the pretrained model from
https://cloud.tsinghua.edu.cn/d/0100f0cea37d41ba8d08/
and move it to the folder mmsegmentation/pretrained/, e.g.
"mmsegmentation/pretrained/van_b2.pth".
After that, run the following commands:

    cd mmsegmentation
    bash tools/dist_train.sh projects/van/configs/van/van-b2_pre1k_upernet_4xb2-160k_ade20k-512x512.py 4
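
For reference, the downloaded checkpoint is consumed by the backbone through `init_cfg`. The excerpt below is a minimal sketch of what the backbone section of that config might look like; the `embed_dims` and `depths` values are assumptions based on the VAN-B2 variant described in the VAN paper, not a verbatim copy of the shipped config.

```python
# Hypothetical backbone excerpt from
# van-b2_pre1k_upernet_4xb2-160k_ade20k-512x512.py; values are assumptions.
model = dict(
    backbone=dict(
        type='VAN',  # registered by the module shown below
        embed_dims=[64, 128, 320, 512],  # assumed VAN-B2 widths
        depths=[3, 3, 12, 3],  # assumed VAN-B2 depths
        init_cfg=dict(
            type='Pretrained', checkpoint='pretrained/van_b2.pth')))
```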

---------

Co-authored-by: xiexinch <xiexinch@outlook.com>
2023-05-22 20:26:26 +08:00


# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
from mmengine.model import BaseModule

from mmseg.models.backbones.mscan import (MSCAN, MSCABlock,
                                          MSCASpatialAttention,
                                          OverlapPatchEmbed)
from mmseg.registry import MODELS


class VANAttentionModule(BaseModule):
    """Large Kernel Attention (LKA) of VAN.

    A 5x5 depth-wise conv, a 7x7 depth-wise dilated conv (dilation 3) and a
    1x1 point-wise conv together approximate a large-kernel convolution; the
    result is used as an attention map to gate the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        # 5x5 depth-wise convolution.
        self.conv0 = nn.Conv2d(
            in_channels, in_channels, 5, padding=2, groups=in_channels)
        # 7x7 depth-wise dilated convolution; with dilation 3 it spans a
        # 19x19 receptive field on its own.
        self.conv_spatial = nn.Conv2d(
            in_channels,
            in_channels,
            7,
            stride=1,
            padding=9,
            groups=in_channels,
            dilation=3)
        # 1x1 convolution for channel mixing.
        self.conv1 = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x):
        u = x.clone()
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)
        # Gate the input with the attention map.
        return u * attn


class VANSpatialAttention(MSCASpatialAttention):
    """MSCA spatial attention with its gating unit replaced by VAN's LKA."""

    def __init__(self, in_channels, act_cfg=dict(type='GELU')):
        super().__init__(in_channels, act_cfg=act_cfg)
        self.spatial_gating_unit = VANAttentionModule(in_channels)


class VANBlock(MSCABlock):
    """MSCA block whose attention is replaced by VAN spatial attention."""

    def __init__(self,
                 channels,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super().__init__(
            channels,
            mlp_ratio=mlp_ratio,
            drop=drop,
            drop_path=drop_path,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg)
        self.attn = VANSpatialAttention(channels)


@MODELS.register_module()
class VAN(MSCAN):
    """Visual Attention Network (VAN) backbone.

    Reuses MSCAN's four-stage layout and inherited forward pass, but builds
    each stage from VANBlocks instead of MSCABlocks.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=[64, 128, 256, 512],
                 mlp_ratios=[8, 8, 4, 4],
                 drop_rate=0.,
                 drop_path_rate=0.,
                 depths=[3, 4, 6, 3],
                 num_stages=4,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='SyncBN', requires_grad=True),
                 pretrained=None,
                 init_cfg=None):
        # Skip MSCAN.__init__ and call BaseModule.__init__ directly, since
        # the stages are rebuilt below with VANBlocks.
        super(MSCAN, self).__init__(init_cfg=init_cfg)

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')

        self.depths = depths
        self.num_stages = num_stages

        # Stochastic depth decay rule: the drop path rate grows linearly
        # over the total number of blocks.
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]
        cur = 0

        for i in range(num_stages):
            # Each stage: overlapping patch embedding, a stack of VANBlocks
            # and a LayerNorm over the flattened tokens.
            patch_embed = OverlapPatchEmbed(
                patch_size=7 if i == 0 else 3,
                stride=4 if i == 0 else 2,
                in_channels=in_channels if i == 0 else embed_dims[i - 1],
                embed_dim=embed_dims[i],
                norm_cfg=norm_cfg)
            block = nn.ModuleList([
                VANBlock(
                    channels=embed_dims[i],
                    mlp_ratio=mlp_ratios[i],
                    drop=drop_rate,
                    drop_path=dpr[cur + j],
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg) for j in range(depths[i])
            ])
            norm = nn.LayerNorm(embed_dims[i])
            cur += depths[i]

            setattr(self, f'patch_embed{i + 1}', patch_embed)
            setattr(self, f'block{i + 1}', block)
            setattr(self, f'norm{i + 1}', norm)

    def init_weights(self):
        # Reuse MSCAN's weight initialization.
        return super().init_weights()
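
As a quick sanity check, the backbone can be instantiated directly and run on a dummy input. This is a minimal sketch, not part of the PR: the import path is hypothetical (adjust it to wherever the file above lives), and `norm_cfg` is overridden with plain `'BN'` because the default `'SyncBN'` requires an initialized distributed process group.

```python
import torch

# Hypothetical import path; adjust to the actual location of the module above.
from projects.van.backbones.van import VAN

# Small illustrative variant (these dims/depths are not a named VAN size).
model = VAN(
    embed_dims=[32, 64, 160, 256],
    depths=[2, 2, 2, 2],
    norm_cfg=dict(type='BN'))  # plain BN, so no distributed setup is needed
model.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 224, 224)
    # The inherited MSCAN forward returns one feature map per stage,
    # at strides 4, 8, 16 and 32.
    outs = model(x)

for i, out in enumerate(outs):
    print(f'stage {i + 1}: {tuple(out.shape)}')
# Expected: (1, 32, 56, 56), (1, 64, 28, 28), (1, 160, 14, 14), (1, 256, 7, 7)
```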