add configs for vit backbone plus decode_heads (#520)

* add config

* add cityscapes config

* add default value to docstring

* fix lint

* add deit-s and deit-b

* add readme

* add eps at norm_cfg

* add drop_path_rate experiment

* add deit case at init_weight

* add upernet result

* update result and add upernet 160k config

* update upernet result and fix settings

* Update iters number

* update result and delete some configs

* fix import error

* fix drop_path_rate

* update result and restore config

* update benchmark result

* remove cityscapes exp

* remove neck

* neck exp

* add more configs

* fix init error

* fix ffn setting

* update result

* update results

* update result

* update results and fill table

* delete or rename configs

* fix link delimiter

* rename configs and fix link

* rename neck to mln
谢昕辰 2021-07-01 23:00:39 +08:00 committed by GitHub
parent 36c81441c1
commit 737544f1c5
15 changed files with 270 additions and 2 deletions

configs/_base_/models/upernet_vit-b16_ln_mln.py
@@ -0,0 +1,58 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',  # noqa
    backbone=dict(
        type='VisionTransformer',
        img_size=(512, 512),
        patch_size=16,
        in_channels=3,
        embed_dims=768,
        num_layers=12,
        num_heads=12,
        mlp_ratio=4,
        out_indices=(2, 5, 8, 11),
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        with_cls_token=True,
        norm_cfg=dict(type='LN', eps=1e-6),
        act_cfg=dict(type='GELU'),
        norm_eval=False,
        out_shape='NCHW',
        interpolate_mode='bicubic'),
    neck=dict(
        type='MultiLevelNeck',
        in_channels=[768, 768, 768, 768],
        out_channels=768,
        scales=[4, 2, 1, 0.5]),
    decode_head=dict(
        type='UPerHead',
        in_channels=[768, 768, 768, 768],
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=512,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=768,
        in_index=3,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))  # yapf: disable
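
The backbone starts from weights pretrained at 224x224 while training here runs at 512x512, so the pretrained 14x14 grid of position embeddings has to be resized to 32x32; `interpolate_mode='bicubic'` selects the resampling filter. A minimal sketch of that resize with a hypothetical helper, not the exact mmseg implementation:

```python
import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed, src_hw, dst_hw, mode='bicubic'):
    """Resize a [1, 1 + H*W, C] position embedding (cls token kept first)."""
    cls_pe, patch_pe = pos_embed[:, :1], pos_embed[:, 1:]
    (src_h, src_w), (dst_h, dst_w) = src_hw, dst_hw
    # [1, N, C] -> [1, C, H, W] for spatial interpolation
    patch_pe = patch_pe.reshape(1, src_h, src_w, -1).permute(0, 3, 1, 2)
    patch_pe = F.interpolate(
        patch_pe, size=(dst_h, dst_w), mode=mode, align_corners=False)
    patch_pe = patch_pe.permute(0, 2, 3, 1).reshape(1, dst_h * dst_w, -1)
    return torch.cat([cls_pe, patch_pe], dim=1)

# 224 / 16 = 14 -> 512 / 16 = 32
pe = torch.randn(1, 1 + 14 * 14, 768)
print(resize_pos_embed(pe, (14, 14), (32, 32)).shape)  # [1, 1025, 768]
```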

configs/vit/README.md
@@ -0,0 +1,32 @@
# Vision Transformer

## Introduction

<!-- [ALGORITHM] -->

```latex
@article{dosovitskiy2020,
    title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
    author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and Uszkoreit, Jakob and Houlsby, Neil},
    journal={arXiv preprint arXiv:2010.11929},
    year={2020}
}
```

## Results and models

### ADE20K

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | ---------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| UPerNet | ViT-B + MLN | 512x512 | 80000 | 9.20 | 6.94 | 47.71 | 49.51 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_vit-b16_mln_512x512_80k_ade20k/upernet_vit-b16_mln_512x512_80k_ade20k-0403cee1.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_vit-b16_mln_512x512_80k_ade20k/20210624_130547.log.json) |
| UPerNet | ViT-B + MLN | 512x512 | 160000 | 9.20 | 7.58 | 46.75 | 48.46 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_vit-b16_mln_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_vit-b16_mln_512x512_160k_ade20k/upernet_vit-b16_mln_512x512_160k_ade20k-852fa768.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_vit-b16_mln_512x512_160k_ade20k/20210623_192432.log.json) |
| UPerNet | ViT-B + LN + MLN | 512x512 | 160000 | 9.21 | 6.82 | 47.73 | 49.95 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k/upernet_vit-b16_ln_mln_512x512_160k_ade20k-f444c077.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k/20210621_172828.log.json) |
| UPerNet | DeiT-S | 512x512 | 80000 | 4.68 | 29.85 | 42.96 | 43.79 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_deit-s16_512x512_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-s16_512x512_80k_ade20k/upernet_deit-s16_512x512_80k_ade20k-afc93ec2.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-s16_512x512_80k_ade20k/20210624_095228.log.json) |
| UPerNet | DeiT-S | 512x512 | 160000 | 4.68 | 29.19 | 42.87 | 43.79 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_deit-s16_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-s16_512x512_160k_ade20k/upernet_deit-s16_512x512_160k_ade20k-5110d916.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-s16_512x512_160k_ade20k/20210621_160903.log.json) |
| UPerNet | DeiT-S + MLN | 512x512 | 160000 | 5.69 | 11.18 | 43.82 | 45.07 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_deit-s16_mln_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-s16_mln_512x512_160k_ade20k/upernet_deit-s16_mln_512x512_160k_ade20k-fb9a5dfb.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-s16_mln_512x512_160k_ade20k/20210621_161021.log.json) |
| UPerNet | DeiT-S + LN + MLN | 512x512 | 160000 | 5.69 | 12.39 | 43.52 | 45.01 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k/upernet_deit-s16_ln_mln_512x512_160k_ade20k-c0cd652f.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k/20210621_161021.log.json) |
| UPerNet | DeiT-B | 512x512 | 80000 | 7.75 | 9.69 | 45.24 | 46.73 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_deit-b16_512x512_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-b16_512x512_80k_ade20k/upernet_deit-b16_512x512_80k_ade20k-1e090789.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-b16_512x512_80k_ade20k/20210624_130529.log.json) |
| UPerNet | DeiT-B | 512x512 | 160000 | 7.75 | 10.39 | 45.36 | 47.16 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_deit-b16_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-b16_512x512_160k_ade20k/upernet_deit-b16_512x512_160k_ade20k-828705d7.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-b16_512x512_160k_ade20k/20210621_180100.log.json) |
| UPerNet | DeiT-B + MLN | 512x512 | 160000 | 9.21 | 7.78 | 45.46 | 47.16 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_deit-b16_mln_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-b16_mln_512x512_160k_ade20k/upernet_deit-b16_mln_512x512_160k_ade20k-4e1450f3.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-b16_mln_512x512_160k_ade20k/20210621_191949.log.json) |
| UPerNet | DeiT-B + LN + MLN | 512x512 | 160000 | 9.21 | 7.75 | 45.37 | 47.23 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_deit-b16_ln_mln_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-b16_ln_mln_512x512_160k_ade20k/upernet_deit-b16_ln_mln_512x512_160k_ade20k-8a959c14.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-b16_ln_mln_512x512_160k_ade20k/20210623_153535.log.json) |
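
Any row above can be exercised with the standard mmsegmentation v0.x Python API. A minimal inference sketch, assuming a checkout of the repo and a checkpoint downloaded from the table (file paths are placeholders):

```python
from mmseg.apis import inference_segmentor, init_segmentor

config_file = 'configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py'
checkpoint_file = 'upernet_vit-b16_mln_512x512_80k_ade20k-0403cee1.pth'

# build the model from the config and load the trained weights
model = init_segmentor(config_file, checkpoint_file, device='cuda:0')
# returns a list with one H x W array of ADE20K class indices
result = inference_segmentor(model, 'demo.png')
```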

configs/vit/upernet_deit-b16_512x512_160k_ade20k.py
@@ -0,0 +1,6 @@
_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
model = dict(
    pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',  # noqa
    backbone=dict(drop_path_rate=0.1),
    neck=None)  # yapf: disable
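
All the DeiT configs enable `drop_path_rate=0.1`, i.e. stochastic depth: during training, each residual branch is zeroed for a random subset of samples and the surviving ones are rescaled by `1 / keep_prob`. A minimal sketch of the operation (not the exact mmcv implementation):

```python
import torch

def drop_path(x, drop_prob=0.1, training=True):
    """Drop the residual branch per sample (stochastic depth)."""
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # one Bernoulli draw per sample, broadcast over the remaining dims
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = (torch.rand(shape, device=x.device) < keep_prob).to(x.dtype)
    return x / keep_prob * mask

# inside a transformer block: x = x + drop_path(attn(x), 0.1, self.training)
```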

configs/vit/upernet_deit-b16_512x512_80k_ade20k.py
@@ -0,0 +1,6 @@
_base_ = './upernet_vit-b16_mln_512x512_80k_ade20k.py'
model = dict(
    pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',  # noqa
    backbone=dict(drop_path_rate=0.1),
    neck=None)  # yapf: disable

configs/vit/upernet_deit-b16_ln_mln_512x512_160k_ade20k.py
@@ -0,0 +1,5 @@
_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
model = dict(
    pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',  # noqa
    backbone=dict(drop_path_rate=0.1, final_norm=True))  # yapf: disable
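
The `ln` variants additionally set `final_norm=True`, which applies the backbone's own pretrained final LayerNorm to the tokens of the last output stage before they reach the neck. A toy illustration of the switch — hypothetical code, not the actual mmseg `VisionTransformer`:

```python
import torch
import torch.nn as nn

class TinyBackbone(nn.Module):
    """Hypothetical, simplified stand-in for the final_norm behaviour."""

    def __init__(self, dim=8, depth=4, out_indices=(1, 3), final_norm=True):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
        self.norm = nn.LayerNorm(dim)
        self.out_indices = out_indices
        self.final_norm = final_norm

    def forward(self, x):
        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i == len(self.layers) - 1 and self.final_norm:
                x = self.norm(x)  # normalize only the last-stage tokens
            if i in self.out_indices:
                outs.append(x)
        return outs

print([o.shape for o in TinyBackbone()(torch.randn(2, 16, 8))])
```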

configs/vit/upernet_deit-b16_mln_512x512_160k_ade20k.py
@@ -0,0 +1,5 @@
_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
model = dict(
    pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',  # noqa
    backbone=dict(drop_path_rate=0.1))  # yapf: disable

configs/vit/upernet_deit-s16_512x512_160k_ade20k.py
@@ -0,0 +1,8 @@
_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
model = dict(
    pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',  # noqa
    backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1),
    decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]),
    neck=None,
    auxiliary_head=dict(num_classes=150, in_channels=384))  # yapf: disable

configs/vit/upernet_deit-s16_512x512_80k_ade20k.py
@@ -0,0 +1,8 @@
_base_ = './upernet_vit-b16_mln_512x512_80k_ade20k.py'
model = dict(
    pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',  # noqa
    backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1),
    decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]),
    neck=None,
    auxiliary_head=dict(num_classes=150, in_channels=384))  # yapf: disable
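
These derived configs rely on mmcv's config inheritance: the `_base_` file is loaded first, nested dicts are merged key by key, and `neck=None` removes the inherited MLN neck outright. A quick way to inspect the merged result, assuming a checkout of the repo:

```python
from mmcv import Config

cfg = Config.fromfile('configs/vit/upernet_deit-s16_512x512_80k_ade20k.py')
print(cfg.model.backbone.embed_dims)  # 384, overriding the inherited 768
print(cfg.model.neck)                 # None: the MultiLevelNeck is dropped
```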

configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py
@@ -0,0 +1,12 @@
_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
model = dict(
    pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',  # noqa
    backbone=dict(
        num_heads=6,
        embed_dims=384,
        drop_path_rate=0.1,
        final_norm=True),
    decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]),
    neck=dict(in_channels=[384, 384, 384, 384], out_channels=384),
    auxiliary_head=dict(num_classes=150, in_channels=384))  # yapf: disable

configs/vit/upernet_deit-s16_mln_512x512_160k_ade20k.py
@@ -0,0 +1,8 @@
_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'
model = dict(
    pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',  # noqa
    backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1),
    decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]),
    neck=dict(in_channels=[384, 384, 384, 384], out_channels=384),
    auxiliary_head=dict(num_classes=150, in_channels=384))  # yapf: disable

configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py
@@ -0,0 +1,38 @@
_base_ = [
    '../_base_/models/upernet_vit-b16_ln_mln.py',
    '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_160k.py'
]
model = dict(
    backbone=dict(drop_path_rate=0.1, final_norm=True),
    decode_head=dict(num_classes=150),
    auxiliary_head=dict(num_classes=150))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_embed': dict(decay_mult=0.),
            'cls_token': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
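
`custom_keys` is matched as a substring of each parameter name, so `pos_embed`, `cls_token`, and every parameter with `norm` in its name end up with zero weight decay. A simplified sketch of the grouping logic; mmcv's `DefaultOptimizerConstructor` does the real work, this is only an illustration:

```python
import torch

def build_param_groups(model, base_wd=0.01, custom_keys=None):
    """Assign a per-parameter weight decay from substring-matched keys."""
    custom_keys = custom_keys or {}
    groups = []
    for name, param in model.named_parameters():
        wd = base_wd
        for key, opts in custom_keys.items():
            if key in name:  # substring match, as in paramwise_cfg
                wd = base_wd * opts.get('decay_mult', 1.0)
                break
        groups.append({'params': [param], 'weight_decay': wd})
    return groups

# torch.optim.AdamW(build_param_groups(model, custom_keys={
#     'pos_embed': dict(decay_mult=0.), 'cls_token': dict(decay_mult=0.),
#     'norm': dict(decay_mult=0.)}), lr=6e-5, betas=(0.9, 0.999))
```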

configs/vit/upernet_vit-b16_mln_512x512_160k_ade20k.py
@@ -0,0 +1,36 @@
_base_ = [
    '../_base_/models/upernet_vit-b16_ln_mln.py',
    '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_160k.py'
]
model = dict(
    decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_embed': dict(decay_mult=0.),
            'cls_token': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
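
The schedule is a 1500-iteration linear warmup into a polynomial decay with `power=1.0`, i.e. a linear ramp down to `min_lr=0.0` over the 160k iterations. A sketch of the resulting curve, following mmcv's warmup and poly formulas as I understand them:

```python
def poly_lr(it, max_iters=160000, base_lr=6e-5, min_lr=0.0, power=1.0,
            warmup_iters=1500, warmup_ratio=1e-6):
    """Learning rate at iteration `it` (linear warmup + poly decay)."""
    regular = (1 - it / max_iters)**power * (base_lr - min_lr) + min_lr
    if it < warmup_iters:
        k = (1 - it / warmup_iters) * (1 - warmup_ratio)
        return regular * (1 - k)
    return regular

print(poly_lr(0))       # ~6e-11: base_lr * warmup_ratio
print(poly_lr(1500))    # ~5.94e-05: warmup finished, on the poly curve
print(poly_lr(160000))  # 0.0
```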

configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py
@@ -0,0 +1,36 @@
_base_ = [
    '../_base_/models/upernet_vit-b16_ln_mln.py',
    '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_80k.py'
]
model = dict(
    decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_embed': dict(decay_mult=0.),
            'cls_token': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)

mmseg/models/necks/multilevel_neck.py

@@ -1,6 +1,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
-from mmcv.cnn import ConvModule
+from mmcv.cnn import ConvModule, xavier_init

 from ..builder import NECKS
@@ -13,7 +13,8 @@ class MultiLevelNeck(nn.Module):
     Args:
         in_channels (List[int]): Number of input channels per scale.
         out_channels (int): Number of output channels (used at each scale).
-        scales (List[int]): Scale factors for each input feature map.
+        scales (List[float]): Scale factors for each input feature map.
+            Default: [0.5, 1, 2, 4]
         norm_cfg (dict): Config dict for normalization layer. Default: None.
         act_cfg (dict): Config dict for activation layer in ConvModule.
             Default: None.
@@ -52,6 +53,12 @@ class MultiLevelNeck(nn.Module):
                     norm_cfg=norm_cfg,
                     act_cfg=act_cfg))

+    # default init_weights for conv(msra) and norm in ConvModule
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                xavier_init(m, distribution='uniform')
+
     def forward(self, inputs):
         assert len(inputs) == len(self.in_channels)
         inputs = [
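For context on what `MultiLevelNeck` consumes here: the ViT backbone emits four feature maps that all sit at 1/16 of the input resolution, and the neck projects them to a shared width and rescales them by `scales` (e.g. `[4, 2, 1, 0.5]` in the base config) so that UPerHead sees a feature pyramid. A minimal functional sketch of the rescaling step; the real module also applies 1x1 and 3x3 ConvModules:

```python
import torch
import torch.nn.functional as F

def multilevel_resize(feats, scales=(4, 2, 1, 0.5)):
    """Rescale same-resolution ViT features into a pyramid."""
    return [
        F.interpolate(x, scale_factor=s, mode='bilinear', align_corners=False)
        for x, s in zip(feats, scales)
    ]

feats = [torch.randn(1, 768, 32, 32) for _ in range(4)]  # 512 / 16 = 32
print([tuple(f.shape[2:]) for f in multilevel_resize(feats)])
# [(128, 128), (64, 64), (32, 32), (16, 16)]
```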
tests/test_models/test_necks.py

@@ -5,6 +5,9 @@ from mmseg.models import MultiLevelNeck
 def test_multilevel_neck():
     # Test init_weights
     MultiLevelNeck([266], 256).init_weights()
+    # Test multi feature maps
+    in_channels = [256, 512, 1024, 2048]
+    inputs = [torch.randn(1, c, 14, 14) for i, c in enumerate(in_channels)]
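
A plausible continuation of this test, assuming the default `scales=[0.5, 1, 2, 4]`: the four 14x14 inputs should come back at a common width of 256 channels and spatial sizes 7, 14, 28 and 56:

```python
# hypothetical continuation of the snippet above
neck = MultiLevelNeck(in_channels, out_channels=256)
neck.init_weights()
outputs = neck(inputs)
for out, size in zip(outputs, [7, 14, 28, 56]):
    assert out.shape == torch.Size([1, 256, size, size])
```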