[Feature] Support Segmenter (#955)

* segmenter: add model

* update

* readme: update

* config: update

* segmenter: update readme

* segmenter: update

* segmenter: update

* segmenter: update

* configs: set checkpoint path to pretrain folder

* segmenter: modify vit-s/lin, remove data config

* readme: update

* configs: transfer from _base_ to segmenter

* configs: add 8x1 suffix

* configs: remove redundant lines

* configs: cleanup

* first attempt

* swipe CI error

* Update mmseg/models/decode_heads/__init__.py

Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>

* segmenter_linear: use fcn backbone

* segmenter_mask: update

* models: add segmenter vit

* decoders: yapf+remove unused imports

* apply precommit

* segmenter/linear_head: fix

* segmenter/linear_header: fix

* segmenter: fix mask transformer

* fix error

* segmenter/mask_head: use trunc_normal init

* refactor segmenter head

* Fetch upstream (#1)

* [Feature] Change options to cfg-option (#1129)

* [Feature] Change option to cfg-option

* add expire date and fix the docs

* modify docstring

* [Fix] Add <!-- [ABSTRACT] --> in metafile #1127

* [Fix] Fix correct num_classes of HRNet in LoveDA dataset #1136

* Bump to v0.20.1 (#1138)

* bump version 0.20.1

* bump version 0.20.1

* [Fix] revise --option to --options #1140

Co-authored-by: Rockey <41846794+RockeyCoss@users.noreply.github.com>
Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>

* decode_head: switch from linear to fcn

* fix init list formatting

* configs: remove variants, keep only vit-s on ade

* align inference metric of vit-s-mask

* configs: add vit t/b/l

* Update mmseg/models/decode_heads/segmenter_mask_head.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update mmseg/models/decode_heads/segmenter_mask_head.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update mmseg/models/decode_heads/segmenter_mask_head.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update mmseg/models/decode_heads/segmenter_mask_head.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update mmseg/models/decode_heads/segmenter_mask_head.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* model_converters: use torch instead of einops

* setup: remove einops

* segmenter_mask: fix missing imports

* add necessary imported init function

* segmenter/seg-l: set resolution to 640

* segmenter/seg-l: fix test size

* fix vitjax2mmseg

* add README and unittest

* fix unittest

* add docstring

* refactor config and add pretrained link

* fix typo

* add paper name in readme

* change segmenter config names

* fix typo in readme

* fix typos in readme

* fix segmenter typo

* fix segmenter typo

* delete redundant comma in config files

* delete redundant comma in config files

* fix convert script

* update latest master version

Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>
Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
Co-authored-by: Rockey <41846794+RockeyCoss@users.noreply.github.com>
Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
rstrudel authored on 2022-01-26 06:50:51 +01:00 (committed by GitHub)
parent 80a48c840e
commit cb1bf9f372
16 changed files with 754 additions and 2 deletions

@@ -118,6 +118,7 @@ Supported methods:
- [x] [STDC (CVPR'2021)](configs/stdc)
- [x] [SETR (CVPR'2021)](configs/setr)
- [x] [DPT (ArXiv'2021)](configs/dpt)
- [x] [Segmenter (ICCV'2021)](configs/segmenter)
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)
Supported datasets:

@@ -117,6 +117,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [STDC (CVPR'2021)](configs/stdc)
- [x] [SETR (CVPR'2021)](configs/setr)
- [x] [DPT (ArXiv'2021)](configs/dpt)
- [x] [Segmenter (ICCV'2021)](configs/segmenter)
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)
已支持的数据集:

@@ -0,0 +1,35 @@
# model settings
backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='pretrain/vit_base_p16_384.pth',
backbone=dict(
type='VisionTransformer',
img_size=(512, 512),
patch_size=16,
in_channels=3,
embed_dims=768,
num_layers=12,
num_heads=12,
drop_path_rate=0.1,
attn_drop_rate=0.0,
drop_rate=0.0,
final_norm=True,
norm_cfg=backbone_norm_cfg,
with_cls_token=True,
interpolate_mode='bicubic',
),
decode_head=dict(
type='SegmenterMaskTransformerHead',
in_channels=768,
channels=768,
num_classes=150,
num_layers=2,
num_heads=12,
embed_dims=768,
dropout_ratio=0.0,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
),
test_cfg=dict(mode='slide', crop_size=(512, 512), stride=(480, 480)),
)

@@ -0,0 +1,73 @@
# Segmenter
[Segmenter: Transformer for Semantic Segmentation](https://arxiv.org/abs/2105.05633)
## Introduction
<!-- [ALGORITHM] -->
<a href="https://github.com/rstrudel/segmenter">Official Repo</a>
<a href="https://github.com/open-mmlab/mmsegmentation/blob/v0.21.0/mmseg/models/decode_heads/segmenter_mask_head.py#L15">Code Snippet</a>
## Abstract
<!-- [ABSTRACT] -->
Image segmentation is often ambiguous at the level of individual image patches and requires contextual information to reach label consensus. In this paper we introduce Segmenter, a transformer model for semantic segmentation. In contrast to convolution-based methods, our approach allows to model global context already at the first layer and throughout the network. We build on the recent Vision Transformer (ViT) and extend it to semantic segmentation. To do so, we rely on the output embeddings corresponding to image patches and obtain class labels from these embeddings with a point-wise linear decoder or a mask transformer decoder. We leverage models pre-trained for image classification and show that we can fine-tune them on moderate sized datasets available for semantic segmentation. The linear decoder allows to obtain excellent results already, but the performance can be further improved by a mask transformer generating class masks. We conduct an extensive ablation study to show the impact of the different parameters, in particular the performance is better for large models and small patch sizes. Segmenter attains excellent results for semantic segmentation. It outperforms the state of the art on both ADE20K and Pascal Context datasets and is competitive on Cityscapes.
<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/24582831/148507554-87eb80bd-02c7-4c31-b102-c6141e231ec8.png" width="70%"/>
</div>
```bibtex
@article{strudel2021Segmenter,
title={Segmenter: Transformer for Semantic Segmentation},
author={Strudel, Robin and Garcia, Ricardo and Laptev, Ivan and Schmid, Cordelia},
journal={arXiv preprint arXiv:2105.05633},
year={2021}
}
```
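
The decoding scheme described in the abstract amounts to processing a set of learnable class embeddings jointly with the ViT patch embeddings in a small transformer, then taking the dot product between the two sets of tokens as per-class mask logits. Below is a minimal, self-contained sketch of that idea only; the names (`ToyMaskDecoder`, etc.) are made up for illustration, it uses PyTorch's built-in `nn.TransformerEncoder` rather than the mmcv layers, and it is not the `SegmenterMaskTransformerHead` added by this PR (see `segmenter_mask_head.py` further down).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyMaskDecoder(nn.Module):
    """Illustrative mask-transformer decoder, not the mmseg implementation."""

    def __init__(self, embed_dims=192, num_classes=150, num_layers=2, num_heads=3):
        super().__init__()
        self.num_classes = num_classes
        # one learnable embedding per class, appended to the patch tokens
        self.cls_emb = nn.Parameter(torch.randn(1, num_classes, embed_dims))
        layer = nn.TransformerEncoderLayer(
            embed_dims, num_heads, dim_feedforward=4 * embed_dims, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, num_layers)

    def forward(self, patch_tokens):  # (B, N, C) patch embeddings from a ViT
        b = patch_tokens.size(0)
        x = torch.cat([patch_tokens, self.cls_emb.expand(b, -1, -1)], dim=1)
        x = self.blocks(x)  # patch and class tokens attend to each other
        patches = F.normalize(x[:, :-self.num_classes], dim=-1)
        classes = F.normalize(x[:, -self.num_classes:], dim=-1)
        # cosine similarity between patch and class embeddings -> mask logits
        return patches @ classes.transpose(1, 2)  # (B, N, num_classes)


masks = ToyMaskDecoder()(torch.randn(2, 32 * 32, 192))  # 512x512 input, patch 16
print(masks.shape)  # torch.Size([2, 1024, 150])
```
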
## Usage
To use the pre-trained ViT models from [Segmenter](https://github.com/rstrudel/segmenter), their keys first need to be converted.
We provide a script [`vitjax2mmseg.py`](../../tools/model_converters/vitjax2mmseg.py) in the tools directory to convert the keys of [ViT-AugReg](https://github.com/rwightman/pytorch-image-models/blob/f55c22bebf9d8afc449d317a723231ef72e0d662/timm/models/vision_transformer.py#L54-L106) models to MMSegmentation style.
```shell
python tools/model_converters/vitjax2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
```
E.g.
```shell
python tools/model_converters/vitjax2mmseg.py \
Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz \
pretrain/vit_tiny_p16_384.pth
```
This script converts the model from `PRETRAIN_PATH` and stores the converted model in `STORE_PATH`.
In our default setting, the pretrained models and their corresponding [ViT-AugReg](https://github.com/rwightman/pytorch-image-models/blob/f55c22bebf9d8afc449d317a723231ef72e0d662/timm/models/vision_transformer.py#L54-L106) models are listed below:
| pretrained models | original models |
| ------ | -------- |
|vit_tiny_p16_384.pth | ['vit_tiny_patch16_384'](https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz) |
|vit_small_p16_384.pth | ['vit_small_patch16_384'](https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz) |
|vit_base_p16_384.pth | ['vit_base_patch16_384'](https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz) |
|vit_large_p16_384.pth | ['vit_large_patch16_384'](https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz) |
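
As a quick sanity check, the converted checkpoint can be inspected with plain PyTorch. A minimal sketch, assuming `pretrain/vit_tiny_p16_384.pth` was produced by the command above; the key names are the ones written by `vitjax2mmseg.py` in this PR:

```python
import torch

state_dict = torch.load('pretrain/vit_tiny_p16_384.pth', map_location='cpu')

# keys follow the MMSegmentation VisionTransformer naming produced by vitjax2mmseg.py
for key in ['cls_token', 'pos_embed', 'patch_embed.projection.weight',
            'layers.0.attn.attn.in_proj_weight', 'ln1.weight']:
    assert key in state_dict, f'missing key: {key}'
    print(key, tuple(state_dict[key].shape))
```
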
## Results and models
### ADE20K
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ---------- | ------- | -------- | --- | --- | -------------- | ----- |
| Segmenter-Mask | ViT-T_16 | 512x512 | 160000 | 1.21 | 27.98 | 39.99 | 40.83 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k/segmenter_vit-t_mask_8x1_512x512_160k_ade20k_20220105_151706-ffcf7509.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k/segmenter_vit-t_mask_8x1_512x512_160k_ade20k_20220105_151706.log.json) |
| Segmenter-Linear | ViT-S_16 | 512x512 | 160000 | 1.78 | 28.07 | 45.75 | 46.82 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k/segmenter_vit-s_linear_8x1_512x512_160k_ade20k_20220105_151713-39658c46.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k/segmenter_vit-s_linear_8x1_512x512_160k_ade20k_20220105_151713.log.json) |
| Segmenter-Mask | ViT-S_16 | 512x512 | 160000 | 2.03 | 24.80 | 46.19 | 47.85 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k/segmenter_vit-s_mask_8x1_512x512_160k_ade20k_20220105_151706-511bb103.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k/segmenter_vit-s_mask_8x1_512x512_160k_ade20k_20220105_151706.log.json) |
| Segmenter-Mask | ViT-B_16 |512x512 | 160000 | 4.20 | 13.20 | 49.60 | 51.07 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segmenter/segmenter_vit-b_mask_8x1_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-b_mask_8x1_512x512_160k_ade20k/segmenter_vit-b_mask_8x1_512x512_160k_ade20k_20220105_151706-bc533b08.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-b_mask_8x1_512x512_160k_ade20k/segmenter_vit-b_mask_8x1_512x512_160k_ade20k_20220105_151706.log.json) |
| Segmenter-Mask | ViT-L_16 |640x640 | 160000 | 16.56 | 2.62 | 52.16 | 53.65 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segmenter/segmenter_vit-l_mask_8x1_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-l_mask_8x1_512x512_160k_ade20k/segmenter_vit-l_mask_8x1_512x512_160k_ade20k_20220105_162750-7ef345be.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-l_mask_8x1_512x512_160k_ade20k/segmenter_vit-l_mask_8x1_512x512_160k_ade20k_20220105_162750.log.json) |
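
With a config and checkpoint from the table above downloaded locally, single-image inference goes through the standard MMSegmentation inference API. A minimal sketch; the image path `demo.jpg` and the local checkpoint filename are placeholders:

```python
from mmseg.apis import inference_segmentor, init_segmentor

config_file = 'configs/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py'
checkpoint_file = 'segmenter_vit-s_mask_8x1_512x512_160k_ade20k_20220105_151706-511bb103.pth'

# build the model from the config and the downloaded checkpoint
model = init_segmentor(config_file, checkpoint_file, device='cuda:0')

# returns a list with one per-pixel class-index map for the input image
result = inference_segmentor(model, 'demo.jpg')
model.show_result('demo.jpg', result, out_file='result.jpg', opacity=0.5)
```
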

@@ -0,0 +1,125 @@
Collections:
- Name: segmenter
Metadata:
Training Data:
- ADE20K
Paper:
URL: https://arxiv.org/abs/2105.05633
Title: 'Segmenter: Transformer for Semantic Segmentation'
README: configs/segmenter/README.md
Code:
URL: https://github.com/open-mmlab/mmsegmentation/blob/v0.21.0/mmseg/models/decode_heads/segmenter_mask_head.py#L15
Version: v0.21.0
Converted From:
Code: https://github.com/rstrudel/segmenter
Models:
- Name: segmenter_vit-t_mask_8x1_512x512_160k_ade20k
In Collection: segmenter
Metadata:
backbone: ViT-T_16
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 35.74
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 1.21
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 39.99
mIoU(ms+flip): 40.83
Config: configs/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k/segmenter_vit-t_mask_8x1_512x512_160k_ade20k_20220105_151706-ffcf7509.pth
- Name: segmenter_vit-s_linear_8x1_512x512_160k_ade20k
In Collection: segmenter
Metadata:
backbone: ViT-S_16
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 35.63
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 1.78
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 45.75
mIoU(ms+flip): 46.82
Config: configs/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k/segmenter_vit-s_linear_8x1_512x512_160k_ade20k_20220105_151713-39658c46.pth
- Name: segmenter_vit-s_mask_8x1_512x512_160k_ade20k
In Collection: segmenter
Metadata:
backbone: ViT-S_16
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 40.32
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 2.03
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 46.19
mIoU(ms+flip): 47.85
Config: configs/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k/segmenter_vit-s_mask_8x1_512x512_160k_ade20k_20220105_151706-511bb103.pth
- Name: segmenter_vit-b_mask_8x1_512x512_160k_ade20k
In Collection: segmenter
Metadata:
backbone: ViT-B_16
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 75.76
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 4.2
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 49.6
mIoU(ms+flip): 51.07
Config: configs/segmenter/segmenter_vit-b_mask_8x1_512x512_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-b_mask_8x1_512x512_160k_ade20k/segmenter_vit-b_mask_8x1_512x512_160k_ade20k_20220105_151706-bc533b08.pth
- Name: segmenter_vit-l_mask_8x1_512x512_160k_ade20k
In Collection: segmenter
Metadata:
backbone: ViT-L_16
crop size: (640,640)
lr schd: 160000
inference time (ms/im):
- value: 381.68
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (640,640)
Training Memory (GB): 16.56
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 52.16
mIoU(ms+flip): 53.65
Config: configs/segmenter/segmenter_vit-l_mask_8x1_512x512_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-l_mask_8x1_512x512_160k_ade20k/segmenter_vit-l_mask_8x1_512x512_160k_ade20k_20220105_162750-7ef345be.pth

@@ -0,0 +1,43 @@
_base_ = [
'../_base_/models/segmenter_vit-b16_mask.py',
'../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_160k.py'
]
optimizer = dict(lr=0.001, weight_decay=0.0)
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]
data = dict(
# num_gpus: 8 -> batch_size: 8
samples_per_gpu=1,
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))

@@ -0,0 +1,60 @@
_base_ = [
'../_base_/models/segmenter_vit-b16_mask.py',
'../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_160k.py'
]
model = dict(
pretrained='pretrain/vit_large_p16_384.pth',
backbone=dict(
type='VisionTransformer',
img_size=(640, 640),
embed_dims=1024,
num_layers=24,
num_heads=16),
decode_head=dict(
type='SegmenterMaskTransformerHead',
in_channels=1024,
channels=1024,
num_heads=16,
embed_dims=1024),
test_cfg=dict(mode='slide', crop_size=(640, 640), stride=(608, 608)))
optimizer = dict(lr=0.001, weight_decay=0.0)
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
crop_size = (640, 640)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 640), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 640),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]
data = dict(
# num_gpus: 8 -> batch_size: 8
samples_per_gpu=1,
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))

@@ -0,0 +1,14 @@
_base_ = './segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py'
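# Segmenter-Linear decoder: with num_convs=0 and concat_input=False, FCNHead
# reduces to a single 1x1 conv_seg classifier applied to the ViT patch features.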
model = dict(
decode_head=dict(
_delete_=True,
type='FCNHead',
in_channels=384,
channels=384,
num_convs=0,
dropout_ratio=0.0,
concat_input=False,
num_classes=150,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)))

@@ -0,0 +1,64 @@
_base_ = [
'../_base_/models/segmenter_vit-b16_mask.py',
'../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_160k.py'
]
backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
model = dict(
pretrained='pretrain/vit_small_p16_384.pth',
backbone=dict(
img_size=(512, 512),
embed_dims=384,
num_heads=6,
),
decode_head=dict(
type='SegmenterMaskTransformerHead',
in_channels=384,
channels=384,
num_classes=150,
num_layers=2,
num_heads=6,
embed_dims=384,
dropout_ratio=0.0,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)))
optimizer = dict(lr=0.001, weight_decay=0.0)
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]
data = dict(
# num_gpus: 8 -> batch_size: 8
samples_per_gpu=1,
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))

@@ -0,0 +1,54 @@
_base_ = [
'../_base_/models/segmenter_vit-b16_mask.py',
'../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_160k.py'
]
model = dict(
pretrained='pretrain/vit_tiny_p16_384.pth',
backbone=dict(embed_dims=192, num_heads=3),
decode_head=dict(
type='SegmenterMaskTransformerHead',
in_channels=192,
channels=192,
num_heads=3,
embed_dims=192))
optimizer = dict(lr=0.001, weight_decay=0.0)
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]
data = dict(
# num_gpus: 8 -> batch_size: 8
samples_per_gpu=1,
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))

@@ -1,5 +1,6 @@
## Changelog
### V0.20.2 (12/15/2021)
**Bug Fixes**

@@ -20,6 +20,7 @@ from .point_head import PointHead
from .psa_head import PSAHead
from .psp_head import PSPHead
from .segformer_head import SegformerHead
from .segmenter_mask_head import SegmenterMaskTransformerHead
from .sep_aspp_head import DepthwiseSeparableASPPHead
from .sep_fcn_head import DepthwiseSeparableFCNHead
from .setr_mla_head import SETRMLAHead
@@ -32,6 +33,6 @@ __all__ = [
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegformerHead', 'ISAHead',
'STDCHead'
'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead',
'SegformerHead', 'ISAHead', 'STDCHead'
]

@@ -0,0 +1,133 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
trunc_normal_init)
from mmcv.runner import ModuleList
from mmseg.models.backbones.vit import TransformerEncoderLayer
from ..builder import HEADS
from .decode_head import BaseDecodeHead
@HEADS.register_module()
class SegmenterMaskTransformerHead(BaseDecodeHead):
"""Segmenter: Transformer for Semantic Segmentation.
This head is the implementation of
`Segmenter: <https://arxiv.org/abs/2105.05633>`_.
Args:
        in_channels (int): The number of channels of the input feature map.
num_layers (int): The depth of transformer.
num_heads (int): The number of attention heads.
        embed_dims (int): The embedding dimension.
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
Default: 4.
drop_path_rate (float): stochastic depth rate. Default 0.1.
drop_rate (float): Probability of an element to be zeroed.
Default 0.0
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.0
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
qkv_bias (bool): Enable bias for qkv if True. Default: True.
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
init_std (float): The value of std in weight initialization.
Default: 0.02.
"""
def __init__(
self,
in_channels,
num_layers,
num_heads,
embed_dims,
mlp_ratio=4,
drop_path_rate=0.1,
drop_rate=0.0,
attn_drop_rate=0.0,
num_fcs=2,
qkv_bias=True,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_std=0.02,
**kwargs,
):
super(SegmenterMaskTransformerHead, self).__init__(
in_channels=in_channels, **kwargs)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
self.layers = ModuleList()
for i in range(num_layers):
self.layers.append(
TransformerEncoderLayer(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=mlp_ratio * embed_dims,
attn_drop_rate=attn_drop_rate,
drop_rate=drop_rate,
drop_path_rate=dpr[i],
num_fcs=num_fcs,
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
batch_first=True,
))
self.dec_proj = nn.Linear(in_channels, embed_dims)
self.cls_emb = nn.Parameter(
torch.randn(1, self.num_classes, embed_dims))
self.patch_proj = nn.Linear(embed_dims, embed_dims, bias=False)
self.classes_proj = nn.Linear(embed_dims, embed_dims, bias=False)
self.decoder_norm = build_norm_layer(
norm_cfg, embed_dims, postfix=1)[1]
self.mask_norm = build_norm_layer(
norm_cfg, self.num_classes, postfix=2)[1]
self.init_std = init_std
delattr(self, 'conv_seg')
def init_weights(self):
trunc_normal_(self.cls_emb, std=self.init_std)
trunc_normal_init(self.patch_proj, std=self.init_std)
trunc_normal_init(self.classes_proj, std=self.init_std)
for n, m in self.named_modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=self.init_std, bias=0)
elif isinstance(m, nn.LayerNorm):
constant_init(m, val=1.0, bias=0.0)
    def forward(self, inputs):
        x = self._transform_inputs(inputs)
        b, c, h, w = x.shape
        # flatten the feature map to a token sequence: (B, C, H, W) -> (B, H*W, C)
        x = x.permute(0, 2, 3, 1).contiguous().view(b, -1, c)
        x = self.dec_proj(x)
        # append one learnable embedding per class to the patch tokens
        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
        x = torch.cat((x, cls_emb), 1)
        # joint attention over patch and class tokens
        for layer in self.layers:
            x = layer(x)
        x = self.decoder_norm(x)
        # split the sequence back into patch and class embeddings
        patches = self.patch_proj(x[:, :-self.num_classes])
        cls_seg_feat = self.classes_proj(x[:, -self.num_classes:])
        # dot product of the L2-normalized embeddings gives per-class mask logits
        patches = F.normalize(patches, dim=2, p=2)
        cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2)
        masks = patches @ cls_seg_feat.transpose(1, 2)
        masks = self.mask_norm(masks)
        # reshape to (B, num_classes, H, W) for the segmentation loss
        masks = masks.permute(0, 2, 1).contiguous().view(b, -1, h, w)
        return masks

@@ -30,6 +30,7 @@ Import:
- configs/pspnet/pspnet.yml
- configs/resnest/resnest.yml
- configs/segformer/segformer.yml
- configs/segmenter/segmenter.yml
- configs/sem_fpn/sem_fpn.yml
- configs/setr/setr.yml
- configs/stdc/stdc.yml

@@ -0,0 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmseg.models.decode_heads import SegmenterMaskTransformerHead
from .utils import _conv_has_norm, to_cuda
def test_segmenter_mask_transformer_head():
head = SegmenterMaskTransformerHead(
in_channels=2,
channels=2,
num_classes=150,
num_layers=2,
num_heads=3,
embed_dims=192,
dropout_ratio=0.0)
assert _conv_has_norm(head, sync_bn=True)
head.init_weights()
inputs = [torch.randn(1, 2, 32, 32)]
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
outputs = head(inputs)
assert outputs.shape == (1, head.num_classes, 32, 32)

@@ -0,0 +1,122 @@
import argparse
import os.path as osp
import mmcv
import numpy as np
import torch
def vit_jax_to_torch(jax_weights, num_layer=12):
torch_weights = dict()
# patch embedding
conv_filters = jax_weights['embedding/kernel']
conv_filters = conv_filters.permute(3, 2, 0, 1)
torch_weights['patch_embed.projection.weight'] = conv_filters
torch_weights['patch_embed.projection.bias'] = jax_weights[
'embedding/bias']
# pos embedding
torch_weights['pos_embed'] = jax_weights[
'Transformer/posembed_input/pos_embedding']
# cls token
torch_weights['cls_token'] = jax_weights['cls']
# head
torch_weights['ln1.weight'] = jax_weights['Transformer/encoder_norm/scale']
torch_weights['ln1.bias'] = jax_weights['Transformer/encoder_norm/bias']
# transformer blocks
for i in range(num_layer):
jax_block = f'Transformer/encoderblock_{i}'
torch_block = f'layers.{i}'
# attention norm
torch_weights[f'{torch_block}.ln1.weight'] = jax_weights[
f'{jax_block}/LayerNorm_0/scale']
torch_weights[f'{torch_block}.ln1.bias'] = jax_weights[
f'{jax_block}/LayerNorm_0/bias']
# attention
query_weight = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/query/kernel']
query_bias = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/query/bias']
key_weight = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/key/kernel']
key_bias = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/key/bias']
value_weight = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/value/kernel']
value_bias = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/value/bias']
qkv_weight = torch.from_numpy(
np.stack((query_weight, key_weight, value_weight), 1))
qkv_weight = torch.flatten(qkv_weight, start_dim=1)
qkv_bias = torch.from_numpy(
np.stack((query_bias, key_bias, value_bias), 0))
qkv_bias = torch.flatten(qkv_bias, start_dim=0)
torch_weights[f'{torch_block}.attn.attn.in_proj_weight'] = qkv_weight
torch_weights[f'{torch_block}.attn.attn.in_proj_bias'] = qkv_bias
to_out_weight = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/out/kernel']
to_out_weight = torch.flatten(to_out_weight, start_dim=0, end_dim=1)
torch_weights[
f'{torch_block}.attn.attn.out_proj.weight'] = to_out_weight
torch_weights[f'{torch_block}.attn.attn.out_proj.bias'] = jax_weights[
f'{jax_block}/MultiHeadDotProductAttention_1/out/bias']
# mlp norm
torch_weights[f'{torch_block}.ln2.weight'] = jax_weights[
f'{jax_block}/LayerNorm_2/scale']
torch_weights[f'{torch_block}.ln2.bias'] = jax_weights[
f'{jax_block}/LayerNorm_2/bias']
# mlp
torch_weights[f'{torch_block}.ffn.layers.0.0.weight'] = jax_weights[
f'{jax_block}/MlpBlock_3/Dense_0/kernel']
torch_weights[f'{torch_block}.ffn.layers.0.0.bias'] = jax_weights[
f'{jax_block}/MlpBlock_3/Dense_0/bias']
torch_weights[f'{torch_block}.ffn.layers.1.weight'] = jax_weights[
f'{jax_block}/MlpBlock_3/Dense_1/kernel']
torch_weights[f'{torch_block}.ffn.layers.1.bias'] = jax_weights[
f'{jax_block}/MlpBlock_3/Dense_1/bias']
# transpose weights
for k, v in torch_weights.items():
if 'weight' in k and 'patch_embed' not in k and 'ln' not in k:
v = v.permute(1, 0)
torch_weights[k] = v
return torch_weights
def main():
# stole refactoring code from Robin Strudel, thanks
parser = argparse.ArgumentParser(
description='Convert keys from jax official pretrained vit models to '
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
jax_weights = np.load(args.src)
jax_weights_tensor = {}
for key in jax_weights.files:
value = torch.from_numpy(jax_weights[key])
jax_weights_tensor[key] = value
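    # ViT-Large checkpoints have 24 transformer blocks; the Ti/S/B variants have 12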
if 'L_16-i21k' in args.src:
num_layer = 24
else:
num_layer = 12
torch_weights = vit_jax_to_torch(jax_weights_tensor, num_layer)
mmcv.mkdir_or_exist(osp.dirname(args.dst))
torch.save(torch_weights, args.dst)
if __name__ == '__main__':
main()