[Project] Added support for Visual Attention Network (VAN) (#2987)
## Motivation

The original version of Visual Attention Network (VAN) can be found at https://github.com/Visual-Attention-Network/VAN-Segmentation. This PR adds support for VAN.

## Modification

Added a single folder, mmsegmentation/projects/van/, with 13 configs in total; performance is basically aligned with the original (not every config was re-run).

## Use cases (Optional)

Before running, you may need to download the pretrained models from https://cloud.tsinghua.edu.cn/d/0100f0cea37d41ba8d08/ and move them to the folder mmsegmentation/pretrained/, e.g. "mmsegmentation/pretrained/van_b2.pth". After that, run the following command:

cd mmsegmentation
bash tools/dist_train.sh projects/van/configs/van/van-b2_pre1k_upernet_4xb2-160k_ade20k-512x512.py 4

---------

Co-authored-by: xiexinch <xiexinch@outlook.com>
This commit is contained in:
parent
7d6156776e
commit
0bfe255fe6
projects/van/README.md (new file, 101 lines)
@@ -0,0 +1,101 @@
# Visual Attention Network (VAN) for Segmentation

This repo is a PyTorch implementation of applying **VAN** (**Visual Attention Network**) to semantic segmentation.

The code is an integration from [VAN-Segmentation](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/README.md?plain=1).

More details can be found in [**Visual Attention Network**](https://arxiv.org/abs/2202.09741).

## Citation

```bib
@article{guo2022visual,
  title={Visual Attention Network},
  author={Guo, Meng-Hao and Lu, Cheng-Ze and Liu, Zheng-Ning and Cheng, Ming-Ming and Hu, Shi-Min},
  journal={arXiv preprint arXiv:2202.09741},
  year={2022}
}
```
## Results

**Notes**: Pre-trained models can be found in [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/0100f0cea37d41ba8d08/).

Results of the original implementation can be found in [VAN-Segmentation](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/README.md?plain=1).

We provide evaluation results of the converted weights.

| Method  | Backbone     | mIoU  | Download |
| :-----: | :----------: | :---: | :------: |
| UPerNet | VAN-B2       | 49.35 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b2-in1kpre_upernet_3rdparty_512x512-ade20k_20230522-19c58aee.pth) |
| UPerNet | VAN-B3       | 49.71 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b3-in1kpre_upernet_3rdparty_512x512-ade20k_20230522-653bd6b7.pth) |
| UPerNet | VAN-B4       | 51.56 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b4-in1kpre_upernet_3rdparty_512x512-ade20k_20230522-653bd6b7.pth) |
| UPerNet | VAN-B4-in22k | 52.61 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b4-in22kpre_upernet_3rdparty_512x512-ade20k_20230522-4a4d744a.pth) |
| UPerNet | VAN-B5-in22k | 53.11 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b5-in22kpre_upernet_3rdparty_512x512-ade20k_20230522-5bb6f2b4.pth) |
| UPerNet | VAN-B6-in22k | 54.25 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b6-in22kpre_upernet_3rdparty_512x512-ade20k_20230522-e226b363.pth) |
| FPN     | VAN-B0       | 38.65 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b0-in1kpre_fpn_3rdparty_512x512-ade20k_20230522-75a76298.pth) |
| FPN     | VAN-B1       | 43.22 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b1-in1kpre_fpn_3rdparty_512x512-ade20k_20230522-104499ff.pth) |
| FPN     | VAN-B2       | 46.84 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b2-in1kpre_fpn_3rdparty_512x512-ade20k_20230522-7074e6f8.pth) |
| FPN     | VAN-B3       | 48.32 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b3-in1kpre_fpn_3rdparty_512x512-ade20k_20230522-2c3b7f5e.pth) |
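The converted checkpoints can be used directly for inference through the MMSegmentation Python API. Below is a minimal sketch (not part of the project files); it assumes you run it from the mmsegmentation root so that the `custom_imports` entry in the config can resolve `projects.van.backbones`, and it uses the demo image shipped with the repository.

```python
from mmseg.apis import inference_model, init_model, show_result_pyplot

config = 'projects/van/configs/van/van-b2_pre1k_upernet_4xb2-160k_ade20k-512x512.py'
checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b2-in1kpre_upernet_3rdparty_512x512-ade20k_20230522-19c58aee.pth'  # noqa

# Build the UPerNet + VAN-B2 model and load the converted weights.
model = init_model(config, checkpoint, device='cuda:0')

# Run inference on a single image and save the visualized prediction.
result = inference_model(model, 'demo/demo.png')
show_result_pyplot(model, 'demo/demo.png', result, show=False, out_file='van_b2_demo.png')
```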
## Preparation

Install MMSegmentation and download ADE20K according to the guidelines in MMSegmentation.

## Requirements

**Step 0.** Install [MMCV](https://github.com/open-mmlab/mmcv) using [MIM](https://github.com/open-mmlab/mim).

```shell
pip install -U openmim
mim install mmengine
mim install "mmcv>=2.0.0"
```

**Step 1.** Install MMSegmentation.

Case a: If you develop and run mmseg directly, install it from source:

```shell
git clone -b main https://github.com/open-mmlab/mmsegmentation.git
cd mmsegmentation
pip install -v -e .
```

Case b: If you use mmsegmentation as a dependency or third-party package, install it with pip:

```shell
pip install "mmsegmentation>=1.0.0"
```
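Optionally, you can sanity-check the installation by printing the installed versions:

```shell
python -c "import mmcv, mmseg; print(mmcv.__version__, mmseg.__version__)"
```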
## Training

By default we use 4 GPUs for training. Run:

```bash
bash tools/dist_train.sh projects/van/configs/van/van-b2_pre1k_upernet_4xb2-160k_ade20k-512x512.py 4
```
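The last argument is the number of GPUs. For a single-GPU run you can keep the distributed launcher, which also keeps the SyncBN layers in the base configs working, e.g.:

```bash
bash tools/dist_train.sh projects/van/configs/van/van-b2_pre1k_upernet_4xb2-160k_ade20k-512x512.py 1
```

Note that the learning rates in the configs are tuned for the default GPU counts.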
## Evaluation

To evaluate a model on 4 GPUs, an example is:

```bash
bash tools/dist_test.sh projects/van/configs/van/van-b2_pre1k_upernet_4xb2-160k_ade20k-512x512.py work_dirs/van-b2_pre1k_upernet_4xb2-160k_ade20k-512x512/iter_160000.pth 4
```
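For a single-GPU run, the non-distributed test script can be used instead (a minimal sketch):

```bash
python tools/test.py projects/van/configs/van/van-b2_pre1k_upernet_4xb2-160k_ade20k-512x512.py work_dirs/van-b2_pre1k_upernet_4xb2-160k_ade20k-512x512/iter_160000.pth
```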
## FLOPs

To calculate FLOPs for a model, run:

```bash
python tools/analysis_tools/get_flops.py projects/van/configs/van/van-b2_pre1k_upernet_4xb2-160k_ade20k-512x512.py --shape 512 512
```
## Acknowledgment

Our implementation is mainly based on [mmsegmentation](https://github.com/open-mmlab/mmsegmentation/tree/v0.12.0), [Swin-Transformer](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation), [PoolFormer](https://github.com/sail-sg/poolformer), [Enjoy-Hamburger](https://github.com/Gsunshine/Enjoy-Hamburger) and [VAN-Segmentation](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/README.md?plain=1). Thanks to their authors.

## LICENSE

This repo is under the Apache-2.0 license. For commercial use, please contact the authors.
projects/van/backbones/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .van import VAN

__all__ = ['VAN']
projects/van/backbones/van.py (new file, 124 lines)
@@ -0,0 +1,124 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
from mmengine.model import BaseModule

from mmseg.models.backbones.mscan import (MSCAN, MSCABlock,
                                           MSCASpatialAttention,
                                           OverlapPatchEmbed)
from mmseg.registry import MODELS


class VANAttentionModule(BaseModule):
    """Large Kernel Attention (LKA) of VAN.

    A large-kernel convolution is decomposed into a 5x5 depth-wise conv,
    a 7x7 depth-wise conv with dilation 3 and a 1x1 point-wise conv; the
    result is used to re-weight (gate) the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv0 = nn.Conv2d(
            in_channels, in_channels, 5, padding=2, groups=in_channels)
        self.conv_spatial = nn.Conv2d(
            in_channels,
            in_channels,
            7,
            stride=1,
            padding=9,
            groups=in_channels,
            dilation=3)
        self.conv1 = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x):
        u = x.clone()
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)
        return u * attn


class VANSpatialAttention(MSCASpatialAttention):
    """Spatial attention of VAN: the MSCA wrapper with the LKA gating unit."""

    def __init__(self, in_channels, act_cfg=dict(type='GELU')):
        super().__init__(in_channels, act_cfg=act_cfg)
        self.spatial_gating_unit = VANAttentionModule(in_channels)


class VANBlock(MSCABlock):
    """VAN block: an MSCA block whose attention is replaced by LKA."""

    def __init__(self,
                 channels,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super().__init__(
            channels,
            mlp_ratio=mlp_ratio,
            drop=drop,
            drop_path=drop_path,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg)
        self.attn = VANSpatialAttention(channels)


@MODELS.register_module()
class VAN(MSCAN):
    """Visual Attention Network (VAN) backbone.

    Reuses the MSCAN implementation and only swaps in VAN blocks.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=[64, 128, 256, 512],
                 mlp_ratios=[8, 8, 4, 4],
                 drop_rate=0.,
                 drop_path_rate=0.,
                 depths=[3, 4, 6, 3],
                 num_stages=4,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='SyncBN', requires_grad=True),
                 pretrained=None,
                 init_cfg=None):
        super(MSCAN, self).__init__(init_cfg=init_cfg)

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')

        self.depths = depths
        self.num_stages = num_stages

        # stochastic depth decay rule
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]
        cur = 0

        for i in range(num_stages):
            patch_embed = OverlapPatchEmbed(
                patch_size=7 if i == 0 else 3,
                stride=4 if i == 0 else 2,
                in_channels=in_channels if i == 0 else embed_dims[i - 1],
                embed_dim=embed_dims[i],
                norm_cfg=norm_cfg)

            block = nn.ModuleList([
                VANBlock(
                    channels=embed_dims[i],
                    mlp_ratio=mlp_ratios[i],
                    drop=drop_rate,
                    drop_path=dpr[cur + j],
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg) for j in range(depths[i])
            ])
            norm = nn.LayerNorm(embed_dims[i])
            cur += depths[i]

            setattr(self, f'patch_embed{i + 1}', patch_embed)
            setattr(self, f'block{i + 1}', block)
            setattr(self, f'norm{i + 1}', norm)

    def init_weights(self):
        return super().init_weights()
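The snippet below is a quick, self-contained sanity check of the backbone defined above (not part of the project files). It assumes it is run from the mmsegmentation root so that `projects.van.backbones` is importable, and it switches the norm layers to plain BN so the forward pass works outside of distributed training.

```python
import torch

import projects.van.backbones  # noqa: F401, registers `VAN` in the MODELS registry
from mmseg.registry import MODELS

# VAN-B2 settings taken from the configs below; plain BN instead of SyncBN
# so that the snippet runs in a single, non-distributed process.
backbone = MODELS.build(
    dict(
        type='VAN',
        embed_dims=[64, 128, 320, 512],
        depths=[3, 3, 12, 3],
        norm_cfg=dict(type='BN', requires_grad=True)))
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 512, 512))

# Four feature maps with strides 4/8/16/32, e.g.
# (1, 64, 128, 128), (1, 128, 64, 64), (1, 320, 32, 32), (1, 512, 16, 16)
for feat in feats:
    print(tuple(feat.shape))
```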
projects/van/configs/_base_/datasets/ade20k.py (new file, 14 lines)
@@ -0,0 +1,14 @@
# dataset settings
_base_ = '../../../../../configs/_base_/datasets/ade20k.py'

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 512), keep_ratio=True),
    dict(type='ResizeToMultiple', size_divisor=32),
    # add loading annotation after ``Resize`` because ground truth
    # does not need to do resize data transform
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='PackSegInputs')
]
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader
projects/van/configs/_base_/models/van_fpn.py (new file, 43 lines)
@@ -0,0 +1,43 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)

data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255,
    size=(512, 512))
model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    backbone=dict(
        type='VAN',
        embed_dims=[32, 64, 160, 256],
        drop_rate=0.0,
        drop_path_rate=0.1,
        depths=[3, 3, 5, 2],
        act_cfg=dict(type='GELU'),
        norm_cfg=norm_cfg,
        init_cfg=dict()),
    neck=dict(
        type='FPN',
        in_channels=[32, 64, 160, 256],
        out_channels=256,
        num_outs=4),
    decode_head=dict(
        type='FPNHead',
        in_channels=[256, 256, 256, 256],
        in_index=[0, 1, 2, 3],
        feature_strides=[4, 8, 16, 32],
        channels=128,
        dropout_ratio=0.1,
        num_classes=150,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
projects/van/configs/_base_/models/van_upernet.py (new file, 51 lines)
@@ -0,0 +1,51 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)

data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255,
    size=(512, 512))
model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    backbone=dict(
        type='VAN',
        embed_dims=[32, 64, 160, 256],
        drop_rate=0.0,
        drop_path_rate=0.1,
        depths=[3, 3, 5, 2],
        act_cfg=dict(type='GELU'),
        norm_cfg=norm_cfg,
        init_cfg=dict()),
    decode_head=dict(
        type='UPerHead',
        in_channels=[32, 64, 160, 256],
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=512,
        dropout_ratio=0.1,
        num_classes=150,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=160,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=150,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
@@ -0,0 +1,8 @@
_base_ = './van-b2_fpn_8xb4-40k_ade20k-512x512.py'
ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b0_3rdparty_20230522-956f5e0d.pth'  # noqa
model = dict(
    backbone=dict(
        embed_dims=[32, 64, 160, 256],
        depths=[3, 3, 5, 2],
        init_cfg=dict(type='Pretrained', checkpoint=ckpt_path)),
    neck=dict(in_channels=[32, 64, 160, 256]))
@@ -0,0 +1,6 @@
_base_ = './van-b2_fpn_8xb4-40k_ade20k-512x512.py'
ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b1_3rdparty_20230522-3adb117f.pth'  # noqa
model = dict(
    backbone=dict(
        depths=[2, 2, 4, 2],
        init_cfg=dict(type='Pretrained', checkpoint=ckpt_path)))
@@ -0,0 +1,53 @@
_base_ = [
    '../_base_/models/van_fpn.py',
    '../_base_/datasets/ade20k.py',
    '../../../../configs/_base_/default_runtime.py',
]
custom_imports = dict(imports=['projects.van.backbones'])
ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b2_3rdparty_20230522-636fac93.pth'  # noqa
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        embed_dims=[64, 128, 320, 512],
        depths=[3, 3, 12, 3],
        init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
        drop_path_rate=0.2),
    neck=dict(in_channels=[64, 128, 320, 512]),
    decode_head=dict(num_classes=150))

train_dataloader = dict(batch_size=4)

# we use 8 gpu instead of 4 in mmsegmentation, so lr*2 and max_iters/2
gpu_multiples = 2
max_iters = 80000 // gpu_multiples
interval = 8000 // gpu_multiples
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW',
        lr=0.0001 * gpu_multiples,
        # betas=(0.9, 0.999),
        weight_decay=0.0001),
    clip_grad=None)
# learning policy
param_scheduler = [
    dict(
        type='PolyLR',
        power=0.9,
        eta_min=0.0,
        begin=0,
        end=max_iters,
        by_epoch=False,
    )
]
train_cfg = dict(
    type='IterBasedTrainLoop', max_iters=max_iters, val_interval=interval)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=interval),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))
@@ -0,0 +1,46 @@
_base_ = [
    '../_base_/models/van_upernet.py', '../_base_/datasets/ade20k.py',
    '../../../../configs/_base_/default_runtime.py',
    '../../../../configs/_base_/schedules/schedule_160k.py'
]
custom_imports = dict(imports=['projects.van.backbones'])
ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b2_3rdparty_20230522-636fac93.pth'  # noqa
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        embed_dims=[64, 128, 320, 512],
        depths=[3, 3, 12, 3],
        init_cfg=dict(type='Pretrained', checkpoint=ckpt_path)),
    decode_head=dict(in_channels=[64, 128, 320, 512], num_classes=150),
    auxiliary_head=dict(in_channels=320, num_classes=150))

# AdamW optimizer
# no weight decay for position embedding & layer norm in backbone
optim_wrapper = dict(
    _delete_=True,
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
    clip_grad=None,
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))
# learning policy
param_scheduler = [
    dict(
        type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
    dict(
        type='PolyLR',
        power=1.0,
        begin=1500,
        end=_base_.train_cfg.max_iters,
        eta_min=0.0,
        by_epoch=False,
    )
]

# By default, models are trained on 8 GPUs with 2 images per GPU
train_dataloader = dict(batch_size=2)
@@ -0,0 +1,11 @@
_base_ = './van-b2_fpn_8xb4-40k_ade20k-512x512.py'
ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b3_3rdparty_20230522-a184e051.pth'  # noqa
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        embed_dims=[64, 128, 320, 512],
        depths=[3, 5, 27, 3],
        init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
        drop_path_rate=0.3),
    neck=dict(in_channels=[64, 128, 320, 512]))
train_dataloader = dict(batch_size=4)
@@ -0,0 +1,8 @@
_base_ = './van-b2_upernet_4xb2-160k_ade20k-512x512.py'
ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b3_3rdparty_20230522-a184e051.pth'  # noqa
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        depths=[3, 5, 27, 3],
        init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
        drop_path_rate=0.3))
@@ -0,0 +1,10 @@
_base_ = './van-b2_upernet_4xb2-160k_ade20k-512x512.py'
ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b4-in22k_3rdparty_20230522-5e31cafb.pth'  # noqa
model = dict(
    backbone=dict(
        depths=[3, 6, 40, 3],
        init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
        drop_path_rate=0.4))

# By default, models are trained on 8 GPUs with 2 images per GPU
train_dataloader = dict(batch_size=4)
@@ -0,0 +1,10 @@
_base_ = './van-b2_upernet_4xb2-160k_ade20k-512x512.py'
ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b4_3rdparty_20230522-1d71c077.pth'  # noqa
model = dict(
    backbone=dict(
        depths=[3, 6, 40, 3],
        init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
        drop_path_rate=0.4))

# By default, models are trained on 4 GPUs with 4 images per GPU
train_dataloader = dict(batch_size=4)
@@ -0,0 +1,10 @@
_base_ = './van-b2_upernet_4xb2-160k_ade20k-512x512.py'
ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b5-in22k_3rdparty_20230522-b26134d7.pth'  # noqa
model = dict(
    backbone=dict(
        embed_dims=[96, 192, 480, 768],
        depths=[3, 3, 24, 3],
        init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
        drop_path_rate=0.4),
    decode_head=dict(in_channels=[96, 192, 480, 768], num_classes=150),
    auxiliary_head=dict(in_channels=480, num_classes=150))
@@ -0,0 +1,10 @@
_base_ = './van-b2_upernet_4xb2-160k_ade20k-512x512.py'
ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b6-in22k_3rdparty_20230522-5e5172a3.pth'  # noqa
model = dict(
    backbone=dict(
        embed_dims=[96, 192, 384, 768],
        depths=[6, 6, 90, 6],
        init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
        drop_path_rate=0.5),
    decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
    auxiliary_head=dict(in_channels=384, num_classes=150))