From 0bfe255fe6da38caa6a50d593f39a288ed1e820a Mon Sep 17 00:00:00 2001
From: tang576225574 <576225574@qq.com>
Date: Mon, 22 May 2023 20:26:26 +0800
Subject: [PATCH] [Project] Add support for Visual Attention Network (VAN)
 (#2987)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

The original version of Visual Attention Network (VAN) can be found at
https://github.com/Visual-Attention-Network/VAN-Segmentation

This PR adds support for Visual Attention Network (VAN).

## Modification

Added a single folder, mmsegmentation/projects/van/, containing 13 config
files in total; performance is basically aligned with the original (not
every config was re-run).

## Use cases (Optional)

Before running, you may need to download the pretrained models from
https://cloud.tsinghua.edu.cn/d/0100f0cea37d41ba8d08/ and then move them to
the folder mmsegmentation/pretrained/, e.g.
"mmsegmentation/pretrained/van_b2.pth". After that, run the following
command:

cd mmsegmentation
bash tools/dist_train.sh projects/van/configs/van/van-b2_upernet_4xb2-160k_ade20k-512x512.py 4

---------

Co-authored-by: xiexinch
---
 projects/van/README.md                        | 101 ++++++++++++++
 projects/van/backbones/__init__.py            |   4 +
 projects/van/backbones/van.py                 | 124 ++++++++++++++++++
 .../van/configs/_base_/datasets/ade20k.py     |  14 ++
 projects/van/configs/_base_/models/van_fpn.py |  43 ++++++
 .../van/configs/_base_/models/van_upernet.py  |  51 +++++++
 .../van/van-b0_fpn_8xb4-40k_ade20k-512x512.py |   8 ++
 .../van/van-b1_fpn_8xb4-40k_ade20k-512x512.py |   6 +
 .../van/van-b2_fpn_8xb4-40k_ade20k-512x512.py |  53 ++++++++
 ...van-b2_upernet_4xb2-160k_ade20k-512x512.py |  46 +++++++
 .../van/van-b3_fpn_8xb4-40k_ade20k-512x512.py |  11 ++
 ...van-b3_upernet_4xb2-160k_ade20k-512x512.py |   8 ++
 ...22kpre_upernet_4xb4-160k_ade20k-512x512.py |  10 ++
 ...van-b4_upernet_4xb4-160k_ade20k-512x512.py |  10 ++
 ...22kpre_upernet_4xb2-160k_ade20k-512x512.py |  10 ++
 ...22kpre_upernet_4xb2-160k_ade20k-512x512.py |  10 ++
 16 files changed, 509 insertions(+)
 create mode 100644 projects/van/README.md
 create mode 100644 projects/van/backbones/__init__.py
 create mode 100644 projects/van/backbones/van.py
 create mode 100644 projects/van/configs/_base_/datasets/ade20k.py
 create mode 100644 projects/van/configs/_base_/models/van_fpn.py
 create mode 100644 projects/van/configs/_base_/models/van_upernet.py
 create mode 100644 projects/van/configs/van/van-b0_fpn_8xb4-40k_ade20k-512x512.py
 create mode 100644 projects/van/configs/van/van-b1_fpn_8xb4-40k_ade20k-512x512.py
 create mode 100644 projects/van/configs/van/van-b2_fpn_8xb4-40k_ade20k-512x512.py
 create mode 100644 projects/van/configs/van/van-b2_upernet_4xb2-160k_ade20k-512x512.py
 create mode 100644 projects/van/configs/van/van-b3_fpn_8xb4-40k_ade20k-512x512.py
 create mode 100644 projects/van/configs/van/van-b3_upernet_4xb2-160k_ade20k-512x512.py
 create mode 100644 projects/van/configs/van/van-b4-in22kpre_upernet_4xb4-160k_ade20k-512x512.py
 create mode 100644 projects/van/configs/van/van-b4_upernet_4xb4-160k_ade20k-512x512.py
 create mode 100644 projects/van/configs/van/van-b5-in22kpre_upernet_4xb2-160k_ade20k-512x512.py
 create mode 100644 projects/van/configs/van/van-b6-in22kpre_upernet_4xb2-160k_ade20k-512x512.py

diff --git a/projects/van/README.md b/projects/van/README.md
new file mode 100644
index 000000000..be0ba362f
--- /dev/null
+++ b/projects/van/README.md
+# Visual Attention Network (VAN) for Segmentation
+
+This repo is a PyTorch implementation of applying **VAN** (**Visual Attention Network**) to semantic segmentation.
+
+The code is adapted from [VAN-Segmentation](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/README.md?plain=1).
+
+More details can be found in [**Visual Attention Network**](https://arxiv.org/abs/2202.09741).
+
+## Citation
+
+```bib
+@article{guo2022visual,
+  title={Visual Attention Network},
+  author={Guo, Meng-Hao and Lu, Cheng-Ze and Liu, Zheng-Ning and Cheng, Ming-Ming and Hu, Shi-Min},
+  journal={arXiv preprint arXiv:2202.09741},
+  year={2022}
+}
+```
+
+## Results
+
+**Notes**: Pre-trained models can be found in [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/0100f0cea37d41ba8d08/).
+
+The original results can be found in [VAN-Segmentation](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/README.md?plain=1).
+
+We provide evaluation results of the converted weights.
+
+| Method  | Backbone     | mIoU  | Download |
+| :-----: | :----------: | :---: | :------: |
+| UPerNet | VAN-B2       | 49.35 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b2-in1kpre_upernet_3rdparty_512x512-ade20k_20230522-19c58aee.pth) |
+| UPerNet | VAN-B3       | 49.71 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b3-in1kpre_upernet_3rdparty_512x512-ade20k_20230522-653bd6b7.pth) |
+| UPerNet | VAN-B4       | 51.56 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b4-in1kpre_upernet_3rdparty_512x512-ade20k_20230522-653bd6b7.pth) |
+| UPerNet | VAN-B4-in22k | 52.61 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b4-in22kpre_upernet_3rdparty_512x512-ade20k_20230522-4a4d744a.pth) |
+| UPerNet | VAN-B5-in22k | 53.11 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b5-in22kpre_upernet_3rdparty_512x512-ade20k_20230522-5bb6f2b4.pth) |
+| UPerNet | VAN-B6-in22k | 54.25 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b6-in22kpre_upernet_3rdparty_512x512-ade20k_20230522-e226b363.pth) |
+| FPN     | VAN-B0       | 38.65 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b0-in1kpre_fpn_3rdparty_512x512-ade20k_20230522-75a76298.pth) |
+| FPN     | VAN-B1       | 43.22 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b1-in1kpre_fpn_3rdparty_512x512-ade20k_20230522-104499ff.pth) |
+| FPN     | VAN-B2       | 46.84 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b2-in1kpre_fpn_3rdparty_512x512-ade20k_20230522-7074e6f8.pth) |
+| FPN     | VAN-B3       | 48.32 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b3-in1kpre_fpn_3rdparty_512x512-ade20k_20230522-2c3b7f5e.pth) |
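+
+As a quick check, the converted weights can be used directly for inference through the MMSegmentation 1.x Python API. The snippet below is a minimal sketch, not part of the original project: the checkpoint filename and test image are placeholders, and it assumes the repository root is the working directory so that the `custom_imports` in the config can register the VAN backbone when the config is loaded.
+
+```python
+from mmseg.apis import inference_model, init_model
+
+config = 'projects/van/configs/van/van-b2_upernet_4xb2-160k_ade20k-512x512.py'
+# placeholder: any checkpoint from the table above, downloaded locally
+checkpoint = 'van-b2-in1kpre_upernet_3rdparty_512x512-ade20k_20230522-19c58aee.pth'
+
+model = init_model(config, checkpoint, device='cuda:0')  # or device='cpu'
+result = inference_model(model, 'demo/demo.png')
+print(result.pred_sem_seg.data.shape)  # (1, H, W) tensor of class indices
+```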
+
+## Preparation
+
+Install MMSegmentation and download ADE20K according to the guidelines in MMSegmentation.
+
+## Requirements
+
+**Step 0.** Install [MMCV](https://github.com/open-mmlab/mmcv) using [MIM](https://github.com/open-mmlab/mim).
+
+```shell
+pip install -U openmim
+mim install mmengine
+mim install "mmcv>=2.0.0"
+```
+
+**Step 1.** Install MMSegmentation.
+
+Case a: If you develop and run mmseg directly, install it from source:
+
+```shell
+git clone -b main https://github.com/open-mmlab/mmsegmentation.git
+cd mmsegmentation
+pip install -v -e .
+```
+
+Case b: If you use mmsegmentation as a dependency or third-party package, install it with pip:
+
+```shell
+pip install "mmsegmentation>=1.0.0"
+```
+
+## Training
+
+By default, 4 GPUs are used for training. Run:
+
+```bash
+bash tools/dist_train.sh projects/van/configs/van/van-b2_upernet_4xb2-160k_ade20k-512x512.py 4
+```
+
+## Evaluation
+
+To evaluate a trained model, for example:
+
+```bash
+bash tools/dist_test.sh projects/van/configs/van/van-b2_upernet_4xb2-160k_ade20k-512x512.py work_dirs/van-b2_upernet_4xb2-160k_ade20k-512x512/iter_160000.pth 4
+```
+
+## FLOPs
+
+To calculate the FLOPs of a model, run:
+
+```bash
+python tools/analysis_tools/get_flops.py projects/van/configs/van/van-b2_upernet_4xb2-160k_ade20k-512x512.py --shape 512 512
+```
+
+## Acknowledgment
+
+Our implementation is mainly based on [mmsegmentation](https://github.com/open-mmlab/mmsegmentation/tree/v0.12.0), [Swin-Transformer](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation), [PoolFormer](https://github.com/sail-sg/poolformer), [Enjoy-Hamburger](https://github.com/Gsunshine/Enjoy-Hamburger) and [VAN-Segmentation](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/README.md?plain=1). Thanks to their authors.
+
+## LICENSE
+
+This repo is under the Apache-2.0 license. For commercial use, please contact the authors.
diff --git a/projects/van/backbones/__init__.py b/projects/van/backbones/__init__.py
new file mode 100644
index 000000000..071995de2
--- /dev/null
+++ b/projects/van/backbones/__init__.py
+# Copyright (c) OpenMMLab. All rights reserved.
+from .van import VAN
+
+__all__ = ['VAN']
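The backbone added in the next file reuses MSCAN from MMSegmentation and swaps in VAN's Large Kernel Attention (LKA): a 5x5 depth-wise convolution, followed by a 7x7 depth-wise convolution with dilation 3 and a 1x1 point-wise convolution, which together approximate a 21x21 large-kernel convolution at a fraction of the cost. The resulting attention map gates the input feature element-wise (the `u * attn` in the forward pass below).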
diff --git a/projects/van/backbones/van.py b/projects/van/backbones/van.py
new file mode 100644
index 000000000..301834a75
--- /dev/null
+++ b/projects/van/backbones/van.py
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+import torch.nn as nn
+from mmengine.model import BaseModule
+
+from mmseg.models.backbones.mscan import (MSCAN, MSCABlock,
+                                          MSCASpatialAttention,
+                                          OverlapPatchEmbed)
+from mmseg.registry import MODELS
+
+
+class VANAttentionModule(BaseModule):
+    """Large Kernel Attention (LKA).
+
+    A 5x5 depth-wise conv, a 7x7 depth-wise conv with dilation 3 and a
+    1x1 conv together approximate a 21x21 large-kernel convolution; the
+    result gates the input feature map.
+    """
+
+    def __init__(self, in_channels):
+        super().__init__()
+        self.conv0 = nn.Conv2d(
+            in_channels, in_channels, 5, padding=2, groups=in_channels)
+        self.conv_spatial = nn.Conv2d(
+            in_channels,
+            in_channels,
+            7,
+            stride=1,
+            padding=9,
+            groups=in_channels,
+            dilation=3)
+        self.conv1 = nn.Conv2d(in_channels, in_channels, 1)
+
+    def forward(self, x):
+        u = x.clone()
+        attn = self.conv0(x)
+        attn = self.conv_spatial(attn)
+        attn = self.conv1(attn)
+        return u * attn  # attention map gates the input
+
+
+class VANSpatialAttention(MSCASpatialAttention):
+    """MSCA spatial attention with its gating unit replaced by LKA."""
+
+    def __init__(self, in_channels, act_cfg=dict(type='GELU')):
+        super().__init__(in_channels, act_cfg=act_cfg)
+        self.spatial_gating_unit = VANAttentionModule(in_channels)
+
+
+class VANBlock(MSCABlock):
+    """MSCA block with its attention replaced by VAN spatial attention."""
+
+    def __init__(self,
+                 channels,
+                 mlp_ratio=4.,
+                 drop=0.,
+                 drop_path=0.,
+                 act_cfg=dict(type='GELU'),
+                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
+        super().__init__(
+            channels,
+            mlp_ratio=mlp_ratio,
+            drop=drop,
+            drop_path=drop_path,
+            act_cfg=act_cfg,
+            norm_cfg=norm_cfg)
+        self.attn = VANSpatialAttention(channels)
+
+
+@MODELS.register_module()
+class VAN(MSCAN):
+    """Visual Attention Network backbone, built on top of MSCAN."""
+
+    def __init__(self,
+                 in_channels=3,
+                 embed_dims=[64, 128, 256, 512],
+                 mlp_ratios=[8, 8, 4, 4],
+                 drop_rate=0.,
+                 drop_path_rate=0.,
+                 depths=[3, 4, 6, 3],
+                 num_stages=4,
+                 act_cfg=dict(type='GELU'),
+                 norm_cfg=dict(type='SyncBN', requires_grad=True),
+                 pretrained=None,
+                 init_cfg=None):
+        # bypass MSCAN.__init__ and rebuild the stages with VAN blocks
+        super(MSCAN, self).__init__(init_cfg=init_cfg)
+
+        assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be set at the same time'
+        if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is deprecated, '
+                          'please use "init_cfg" instead')
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+        elif pretrained is not None:
+            raise TypeError('pretrained must be a str or None')
+
+        self.depths = depths
+        self.num_stages = num_stages
+
+        # stochastic depth decay rule
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+        ]
+        cur = 0
+
+        for i in range(num_stages):
+            patch_embed = OverlapPatchEmbed(
+                patch_size=7 if i == 0 else 3,
+                stride=4 if i == 0 else 2,
+                in_channels=in_channels if i == 0 else embed_dims[i - 1],
+                embed_dim=embed_dims[i],
+                norm_cfg=norm_cfg)
+
+            block = nn.ModuleList([
+                VANBlock(
+                    channels=embed_dims[i],
+                    mlp_ratio=mlp_ratios[i],
+                    drop=drop_rate,
+                    drop_path=dpr[cur + j],
+                    act_cfg=act_cfg,
+                    norm_cfg=norm_cfg) for j in range(depths[i])
+            ])
+            norm = nn.LayerNorm(embed_dims[i])
+            cur += depths[i]
+
+            setattr(self, f'patch_embed{i + 1}', patch_embed)
+            setattr(self, f'block{i + 1}', block)
+            setattr(self, f'norm{i + 1}', norm)
+
+    def init_weights(self):
+        return super().init_weights()
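As a sanity check for the backbone above, here is a minimal sketch of a standalone forward pass. It assumes the patch is applied and the repository root is the working directory; plain `BN` is substituted for the default `SyncBN` so the sketch runs on a single, non-distributed device.

```python
import torch

from projects.van.backbones import VAN  # registers VAN in MODELS

# VAN-B0-sized backbone; BN instead of SyncBN for single-device use
model = VAN(
    embed_dims=[32, 64, 160, 256],
    depths=[3, 3, 5, 2],
    norm_cfg=dict(type='BN', requires_grad=True))
model.eval()

with torch.no_grad():
    outs = model(torch.randn(1, 3, 512, 512))

# four pyramid levels at strides 4/8/16/32:
# (1, 32, 128, 128), (1, 64, 64, 64), (1, 160, 32, 32), (1, 256, 16, 16)
for out in outs:
    print(tuple(out.shape))
```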
diff --git a/projects/van/configs/_base_/datasets/ade20k.py b/projects/van/configs/_base_/datasets/ade20k.py
new file mode 100644
index 000000000..69b3c2a73
--- /dev/null
+++ b/projects/van/configs/_base_/datasets/ade20k.py
+# dataset settings
+_base_ = '../../../../../configs/_base_/datasets/ade20k.py'
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='Resize', scale=(2048, 512), keep_ratio=True),
+    dict(type='ResizeToMultiple', size_divisor=32),
+    # load annotations after ``Resize`` because the ground truth
+    # does not need the resize data transform
+    dict(type='LoadAnnotations', reduce_zero_label=True),
+    dict(type='PackSegInputs')
+]
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+test_dataloader = val_dataloader
diff --git a/projects/van/configs/_base_/models/van_fpn.py b/projects/van/configs/_base_/models/van_fpn.py
new file mode 100644
index 000000000..c7fd7391f
--- /dev/null
+++ b/projects/van/configs/_base_/models/van_fpn.py
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+
+data_preprocessor = dict(
+    type='SegDataPreProcessor',
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    bgr_to_rgb=True,
+    pad_val=0,
+    seg_pad_val=255,
+    size=(512, 512))
+model = dict(
+    type='EncoderDecoder',
+    data_preprocessor=data_preprocessor,
+    backbone=dict(
+        type='VAN',
+        embed_dims=[32, 64, 160, 256],
+        drop_rate=0.0,
+        drop_path_rate=0.1,
+        depths=[3, 3, 5, 2],
+        act_cfg=dict(type='GELU'),
+        norm_cfg=norm_cfg,
+        init_cfg=dict()),
+    neck=dict(
+        type='FPN',
+        in_channels=[32, 64, 160, 256],
+        out_channels=256,
+        num_outs=4),
+    decode_head=dict(
+        type='FPNHead',
+        in_channels=[256, 256, 256, 256],
+        in_index=[0, 1, 2, 3],
+        feature_strides=[4, 8, 16, 32],
+        channels=128,
+        dropout_ratio=0.1,
+        num_classes=150,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/projects/van/configs/_base_/models/van_upernet.py b/projects/van/configs/_base_/models/van_upernet.py
new file mode 100644
index 000000000..8f94c0d9d
--- /dev/null
+++ b/projects/van/configs/_base_/models/van_upernet.py
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+
+data_preprocessor = dict(
+    type='SegDataPreProcessor',
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    bgr_to_rgb=True,
+    pad_val=0,
+    seg_pad_val=255,
+    size=(512, 512))
+model = dict(
+    type='EncoderDecoder',
+    data_preprocessor=data_preprocessor,
+    backbone=dict(
+        type='VAN',
+        embed_dims=[32, 64, 160, 256],
+        drop_rate=0.0,
+        drop_path_rate=0.1,
+        depths=[3, 3, 5, 2],
+        act_cfg=dict(type='GELU'),
+        norm_cfg=norm_cfg,
+        init_cfg=dict()),
+    decode_head=dict(
+        type='UPerHead',
+        in_channels=[32, 64, 160, 256],
+        in_index=[0, 1, 2, 3],
+        pool_scales=(1, 2, 3, 6),
+        channels=512,
+        dropout_ratio=0.1,
+        num_classes=150,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=160,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=150,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
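The two `_base_` model files above are sized for VAN-B0; each per-model config below overrides `embed_dims`, `depths` and the pretrained checkpoint for its variant. A minimal sketch of inspecting a merged config with MMEngine (assuming the repository root as the working directory, so the `custom_imports` resolve):

```python
from mmengine.config import Config

cfg = Config.fromfile(
    'projects/van/configs/van/van-b2_fpn_8xb4-40k_ade20k-512x512.py')

# the B2 file overrides the B0 defaults inherited from van_fpn.py
print(cfg.model.backbone.embed_dims)  # [64, 128, 320, 512]
print(cfg.model.backbone.depths)      # [3, 3, 12, 3]
print(cfg.model.neck.in_channels)     # [64, 128, 320, 512]
```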
diff --git a/projects/van/configs/van/van-b0_fpn_8xb4-40k_ade20k-512x512.py b/projects/van/configs/van/van-b0_fpn_8xb4-40k_ade20k-512x512.py
new file mode 100644
index 000000000..2faf3788a
--- /dev/null
+++ b/projects/van/configs/van/van-b0_fpn_8xb4-40k_ade20k-512x512.py
+_base_ = './van-b2_fpn_8xb4-40k_ade20k-512x512.py'
+ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b0_3rdparty_20230522-956f5e0d.pth'  # noqa
+model = dict(
+    backbone=dict(
+        embed_dims=[32, 64, 160, 256],
+        depths=[3, 3, 5, 2],
+        init_cfg=dict(type='Pretrained', checkpoint=ckpt_path)),
+    neck=dict(in_channels=[32, 64, 160, 256]))
diff --git a/projects/van/configs/van/van-b1_fpn_8xb4-40k_ade20k-512x512.py b/projects/van/configs/van/van-b1_fpn_8xb4-40k_ade20k-512x512.py
new file mode 100644
index 000000000..cf64a7138
--- /dev/null
+++ b/projects/van/configs/van/van-b1_fpn_8xb4-40k_ade20k-512x512.py
+_base_ = './van-b2_fpn_8xb4-40k_ade20k-512x512.py'
+ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b1_3rdparty_20230522-3adb117f.pth'  # noqa
+model = dict(
+    backbone=dict(
+        depths=[2, 2, 4, 2],
+        init_cfg=dict(type='Pretrained', checkpoint=ckpt_path)))
diff --git a/projects/van/configs/van/van-b2_fpn_8xb4-40k_ade20k-512x512.py b/projects/van/configs/van/van-b2_fpn_8xb4-40k_ade20k-512x512.py
new file mode 100644
index 000000000..965fa1cd3
--- /dev/null
+++ b/projects/van/configs/van/van-b2_fpn_8xb4-40k_ade20k-512x512.py
+_base_ = [
+    '../_base_/models/van_fpn.py',
+    '../_base_/datasets/ade20k.py',
+    '../../../../configs/_base_/default_runtime.py',
+]
+custom_imports = dict(imports=['projects.van.backbones'])
+ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b2_3rdparty_20230522-636fac93.pth'  # noqa
+model = dict(
+    type='EncoderDecoder',
+    backbone=dict(
+        embed_dims=[64, 128, 320, 512],
+        depths=[3, 3, 12, 3],
+        init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
+        drop_path_rate=0.2),
+    neck=dict(in_channels=[64, 128, 320, 512]),
+    decode_head=dict(num_classes=150))
+
+train_dataloader = dict(batch_size=4)
+
+# we use 8 GPUs instead of 4 in mmsegmentation, so lr*2 and max_iters/2
+gpu_multiples = 2
+max_iters = 80000 // gpu_multiples
+interval = 8000 // gpu_multiples
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(
+        type='AdamW',
+        lr=0.0001 * gpu_multiples,
+        # betas=(0.9, 0.999),
+        weight_decay=0.0001),
+    clip_grad=None)
+# learning policy
+param_scheduler = [
+    dict(
+        type='PolyLR',
+        power=0.9,
+        eta_min=0.0,
+        begin=0,
+        end=max_iters,
+        by_epoch=False,
+    )
+]
+train_cfg = dict(
+    type='IterBasedTrainLoop', max_iters=max_iters, val_interval=interval)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+default_hooks = dict(
+    timer=dict(type='IterTimerHook'),
+    logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
+    param_scheduler=dict(type='ParamSchedulerHook'),
+    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=interval),
+    sampler_seed=dict(type='DistSamplerSeedHook'),
+    visualization=dict(type='SegVisualizationHook'))
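A note on the arithmetic in the FPN schedule above: the reference recipe uses 4 GPUs, so with `gpu_multiples = 2` (8 GPUs) the effective batch size doubles. Following the linear scaling rule, the learning rate is doubled (0.0001 to 0.0002), while `max_iters` (80000 to 40000) and the checkpoint/validation `interval` (8000 to 4000) are halved, keeping the total number of samples seen during training unchanged.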
diff --git a/projects/van/configs/van/van-b2_upernet_4xb2-160k_ade20k-512x512.py b/projects/van/configs/van/van-b2_upernet_4xb2-160k_ade20k-512x512.py
new file mode 100644
index 000000000..c529606a2
--- /dev/null
+++ b/projects/van/configs/van/van-b2_upernet_4xb2-160k_ade20k-512x512.py
+_base_ = [
+    '../_base_/models/van_upernet.py', '../_base_/datasets/ade20k.py',
+    '../../../../configs/_base_/default_runtime.py',
+    '../../../../configs/_base_/schedules/schedule_160k.py'
+]
+custom_imports = dict(imports=['projects.van.backbones'])
+ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b2_3rdparty_20230522-636fac93.pth'  # noqa
+model = dict(
+    type='EncoderDecoder',
+    backbone=dict(
+        embed_dims=[64, 128, 320, 512],
+        depths=[3, 3, 12, 3],
+        init_cfg=dict(type='Pretrained', checkpoint=ckpt_path)),
+    decode_head=dict(in_channels=[64, 128, 320, 512], num_classes=150),
+    auxiliary_head=dict(in_channels=320, num_classes=150))
+
+# AdamW optimizer
+# no weight decay for position embedding & layer norm in backbone
+optim_wrapper = dict(
+    _delete_=True,
+    type='OptimWrapper',
+    optimizer=dict(
+        type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
+    clip_grad=None,
+    paramwise_cfg=dict(
+        custom_keys={
+            # the two pos-embed keys match no VAN parameters; they are
+            # kept from the Swin training recipe
+            'absolute_pos_embed': dict(decay_mult=0.),
+            'relative_position_bias_table': dict(decay_mult=0.),
+            'norm': dict(decay_mult=0.)
+        }))
+# learning policy
+param_scheduler = [
+    dict(
+        type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
+    dict(
+        type='PolyLR',
+        power=1.0,
+        begin=1500,
+        end=_base_.train_cfg.max_iters,
+        eta_min=0.0,
+        by_epoch=False,
+    )
+]
+
+# By default, models are trained on 4 GPUs with 2 images per GPU
+train_dataloader = dict(batch_size=2)
diff --git a/projects/van/configs/van/van-b3_fpn_8xb4-40k_ade20k-512x512.py b/projects/van/configs/van/van-b3_fpn_8xb4-40k_ade20k-512x512.py
new file mode 100644
index 000000000..b0493fe4f
--- /dev/null
+++ b/projects/van/configs/van/van-b3_fpn_8xb4-40k_ade20k-512x512.py
+_base_ = './van-b2_fpn_8xb4-40k_ade20k-512x512.py'
+ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b3_3rdparty_20230522-a184e051.pth'  # noqa
+model = dict(
+    type='EncoderDecoder',
+    backbone=dict(
+        embed_dims=[64, 128, 320, 512],
+        depths=[3, 5, 27, 3],
+        init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
+        drop_path_rate=0.3),
+    neck=dict(in_channels=[64, 128, 320, 512]))
+train_dataloader = dict(batch_size=4)
diff --git a/projects/van/configs/van/van-b3_upernet_4xb2-160k_ade20k-512x512.py b/projects/van/configs/van/van-b3_upernet_4xb2-160k_ade20k-512x512.py
new file mode 100644
index 000000000..8201801d9
--- /dev/null
+++ b/projects/van/configs/van/van-b3_upernet_4xb2-160k_ade20k-512x512.py
+_base_ = './van-b2_upernet_4xb2-160k_ade20k-512x512.py'
+ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b3_3rdparty_20230522-a184e051.pth'  # noqa
+model = dict(
+    type='EncoderDecoder',
+    backbone=dict(
+        depths=[3, 5, 27, 3],
+        init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
+        drop_path_rate=0.3))
diff --git a/projects/van/configs/van/van-b4-in22kpre_upernet_4xb4-160k_ade20k-512x512.py b/projects/van/configs/van/van-b4-in22kpre_upernet_4xb4-160k_ade20k-512x512.py
new file mode 100644
index 000000000..15c8f7ca6
--- /dev/null
+++ b/projects/van/configs/van/van-b4-in22kpre_upernet_4xb4-160k_ade20k-512x512.py
+_base_ = './van-b2_upernet_4xb2-160k_ade20k-512x512.py'
+ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b4-in22k_3rdparty_20230522-5e31cafb.pth'  # noqa
+model = dict(
+    backbone=dict(
+        depths=[3, 6, 40, 3],
+        init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
+        drop_path_rate=0.4))
+
+# By default, models are trained on 4 GPUs with 4 images per GPU
+train_dataloader = dict(batch_size=4)
diff --git a/projects/van/configs/van/van-b4_upernet_4xb4-160k_ade20k-512x512.py b/projects/van/configs/van/van-b4_upernet_4xb4-160k_ade20k-512x512.py
new file mode 100644
index 000000000..33ae049d0
--- /dev/null
+++ b/projects/van/configs/van/van-b4_upernet_4xb4-160k_ade20k-512x512.py
+_base_ = './van-b2_upernet_4xb2-160k_ade20k-512x512.py'
+ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b4_3rdparty_20230522-1d71c077.pth'  # noqa
+model = dict(
+    backbone=dict(
+        depths=[3, 6, 40, 3],
+        init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
+        drop_path_rate=0.4))
+
+# By default, models are trained on 4 GPUs with 4 images per GPU
+train_dataloader = dict(batch_size=4)
diff --git a/projects/van/configs/van/van-b5-in22kpre_upernet_4xb2-160k_ade20k-512x512.py b/projects/van/configs/van/van-b5-in22kpre_upernet_4xb2-160k_ade20k-512x512.py
new file mode 100644
index 000000000..f36c6242b
--- /dev/null
+++ b/projects/van/configs/van/van-b5-in22kpre_upernet_4xb2-160k_ade20k-512x512.py
+_base_ = './van-b2_upernet_4xb2-160k_ade20k-512x512.py'
+ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b5-in22k_3rdparty_20230522-b26134d7.pth'  # noqa
+model = dict(
+    backbone=dict(
+        embed_dims=[96, 192, 480, 768],
+        depths=[3, 3, 24, 3],
+        init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
+        drop_path_rate=0.4),
+    decode_head=dict(in_channels=[96, 192, 480, 768], num_classes=150),
+    auxiliary_head=dict(in_channels=480, num_classes=150))
diff --git a/projects/van/configs/van/van-b6-in22kpre_upernet_4xb2-160k_ade20k-512x512.py b/projects/van/configs/van/van-b6-in22kpre_upernet_4xb2-160k_ade20k-512x512.py
new file mode 100644
index 000000000..aa529efed
--- /dev/null
+++ b/projects/van/configs/van/van-b6-in22kpre_upernet_4xb2-160k_ade20k-512x512.py
+_base_ = './van-b2_upernet_4xb2-160k_ade20k-512x512.py'
+ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b6-in22k_3rdparty_20230522-5e5172a3.pth'  # noqa
+model = dict(
+    backbone=dict(
+        embed_dims=[96, 192, 384, 768],
+        depths=[6, 6, 90, 6],
+        init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
+        drop_path_rate=0.5),
+    decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
+    auxiliary_head=dict(in_channels=384, num_classes=150))