Mirror of https://github.com/open-mmlab/mmsegmentation.git, synced 2025-06-03 22:03:48 +08:00
[Feature] Add segformer decode head and related train config (#599)
* [Feature] SegFormer re-implementation
* Use act_cfg and norm_cfg to control activation and normalization
* Split this PR into several little PRs
* Fix lint error
* Remove SegFormerHead
* [Feature] Add segformer decode head and related train config
* Add ADE20K trainval support for segformer: 1. Add related train and val configs; 2. Add AlignedResize;
* Set arg: find_unused_parameters = True
* Parameters init refactor
* 1. Refactor segformer backbone parameters init; 2. Remove redundant functions and unit tests;
* Remove redundant codes
* Replace Linear layer with 1x1 Conv
* Use nn.ModuleList to refactor segformer head.
* Remove local to_xtuple
* 1. Remove redundant codes; 2. Modify module name;
* Refactor the backbone of segformer using mmcv.cnn.bricks.transformer.py
* Fix some code logic bugs.
* Add mit_convert.py to match pretrain keys of segformer.
* Resolve some comments.
* 1. Add some asserts to ensure right params; 2. Support flexible peconv position;
* Add pe_index assert and fix unit test.
* 1. Add doc string for MixVisionTransformer; 2. Add some unit tests for MixVisionTransformer;
* Use hw_shape to pass shape of feature map.
* 1. Fix doc string of MixVisionTransformer; 2. Simplify MixFFN; 3. Modify H, W to hw_shape;
* Add more unit tests.
* Add doc strings for shape conversion functions.
* Add some unit tests to improve code coverage.
* Fix SegFormer backbone pretrain weights matching bug.
* Modify configs of segformer.
* Resolve the shape conversion functions' doc strings.
* Add pad_to_patch_size arg.
* Support progressive test with lower memory cost.
* Modify default value of pad_to_patch_size arg.
* Temp code
* Use processor to refactor evaluation workflow.
* Refactor eval hook.
* Fix progress bar.
* Fix middle save argument.
* Modify some variable names of dataset evaluate api.
* Modify some variable names of eval hook.
* Fix some priority bugs of eval hook.
* Fix some bugs about model loading and eval hook.
* Add ADE20K 640x640 dataset.
* Fix related segformer configs.
* Deprecate efficient_test.
* Fix training progress blocked by eval hook.
* Deprecate old test api.
* Fix erroneous patch size.
* Fix pretrain of mit_b0
* Fix the test api error.
* Modify dataset base config.
* Fix test api error.
* Modify outer api.
* Build a sampler test api.
* TODO: Refactor format_results.
* Modify variable names.
* Fix num_classes bug.
* Fix sampler index bug.
* Fix grammar bug.
* Add part of benchmark results.
* Support batch sampler.
* More readable test api.
* Remove some command args and fix eval hook bug.
* Support format-only arg.
* Modify format_results of datasets.
* Modify tools which use test apis.
* Update readme.
* Update readme of segformer.
* Update readme of segformer.
* Update segformer readme and fix segformer mit_b4.
* Update readme of segformer.
* Clean AlignedResize related config.
* Clean code from pr #709
* Clean code from pr #709
* Add 512x512 segformer_mit-b5.
* Fix lint.
* Fix some segformer head bugs.
* Add segformer unit tests.
* Replace AlignedResize with ResizeToMultiple.
* Modify readme of segformer.
* Fix bug of ResizeToMultiple.
* Add ResizeToMultiple unit tests.
* Resolve conflict.
* Simplify the implementation of ResizeToMultiple.
* Update test results.
* Fix multi-scale test error when resize_ratio=1.75 and input size=640x640.
* Update segformer results.
* Update SegFormer results.
* Fix some url bugs and pipelines bug.
* Move ckpt conversion to tools.
* Add segformer official pretrain weights usage.
* Clean redundant codes.
* Remove redundant codes.
* Unified format.
* Add description for segformer converter.
* Update workers.
This commit is contained in:
parent f6dca38283
commit bcafcdd2aa
configs/_base_/models/segformer_mit-b0.py (new file, 34 lines)

@@ -0,0 +1,34 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(
        type='MixVisionTransformer',
        in_channels=3,
        embed_dims=32,
        num_stages=4,
        num_layers=[2, 2, 2, 2],
        num_heads=[1, 2, 5, 8],
        patch_sizes=[7, 3, 3, 3],
        sr_ratios=[8, 4, 2, 1],
        out_indices=(0, 1, 2, 3),
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1),
    decode_head=dict(
        type='SegformerHead',
        in_channels=[32, 64, 160, 256],
        in_index=[0, 1, 2, 3],
        channels=256,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
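As context for readers new to the config system: a base file like this is never run directly; train configs list it in `_base_` and mmcv merges the dicts. A minimal sketch of inspecting the merged result, assuming an mmsegmentation checkout as the working directory and `mmcv` installed (the train config path comes from the files added below):

```python
from mmcv import Config

# Load a train config that inherits from the base model file above.
# The merge resolves `_base_` recursively and applies overrides.
cfg = Config.fromfile(
    'configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py')

print(cfg.model.backbone.embed_dims)      # 32, from the base file
print(cfg.model.decode_head.num_classes)  # 150, overridden for ADE20K
```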
configs/segformer/readme.md (new file, 73 lines)

@@ -0,0 +1,73 @@
# SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers

## Introduction

<!-- [ALGORITHM] -->

```latex
@article{xie2021segformer,
  title={SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers},
  author={Xie, Enze and Wang, Wenhai and Yu, Zhiding and Anandkumar, Anima and Alvarez, Jose M and Luo, Ping},
  journal={arXiv preprint arXiv:2105.15203},
  year={2021}
}
```

## Results and models

### ADE20K

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| --------- | ------ | --------- | ------: | -------: | -------------: | ----: | ------------: | ------ | -------- |
| Segformer | MIT-B0 | 512x512 | 160000 | 2.1 | 51.32 | 37.41 | 38.34 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b0_512x512_160k_ade20k/segformer_mit-b0_512x512_160k_ade20k_20210726_101530-8ffa8fda.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b0_512x512_160k_ade20k/segformer_mit-b0_512x512_160k_ade20k_20210726_101530.log.json) |
| Segformer | MIT-B1 | 512x512 | 160000 | 2.6 | 47.66 | 40.97 | 42.54 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b1_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b1_512x512_160k_ade20k/segformer_mit-b1_512x512_160k_ade20k_20210726_112106-d70e859d.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b1_512x512_160k_ade20k/segformer_mit-b1_512x512_160k_ade20k_20210726_112106.log.json) |
| Segformer | MIT-B2 | 512x512 | 160000 | 3.6 | 30.88 | 45.58 | 47.03 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b2_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b2_512x512_160k_ade20k/segformer_mit-b2_512x512_160k_ade20k_20210726_112103-cbd414ac.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b2_512x512_160k_ade20k/segformer_mit-b2_512x512_160k_ade20k_20210726_112103.log.json) |
| Segformer | MIT-B3 | 512x512 | 160000 | 4.8 | 22.11 | 47.82 | 48.81 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b3_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b3_512x512_160k_ade20k/segformer_mit-b3_512x512_160k_ade20k_20210726_081410-962b98d2.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b3_512x512_160k_ade20k/segformer_mit-b3_512x512_160k_ade20k_20210726_081410.log.json) |
| Segformer | MIT-B4 | 512x512 | 160000 | 6.1 | 15.45 | 48.46 | 49.76 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b4_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b4_512x512_160k_ade20k/segformer_mit-b4_512x512_160k_ade20k_20210728_183055-7f509d7d.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b4_512x512_160k_ade20k/segformer_mit-b4_512x512_160k_ade20k_20210728_183055.log.json) |
| Segformer | MIT-B5 | 512x512 | 160000 | 7.2 | 11.89 | 49.13 | 50.22 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b5_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_512x512_160k_ade20k/segformer_mit-b5_512x512_160k_ade20k_20210726_145235-94cedf59.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_512x512_160k_ade20k/segformer_mit-b5_512x512_160k_ade20k_20210726_145235.log.json) |
| Segformer | MIT-B5 | 640x640 | 160000 | 11.5 | 11.30 | 49.62 | 50.36 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b5_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_640x640_160k_ade20k/segformer_mit-b5_640x640_160k_ade20k_20210801_121243-41d2845b.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_640x640_160k_ade20k/segformer_mit-b5_640x640_160k_ade20k_20210801_121243.log.json) |

Evaluation with `AlignedResize`:

| Method | Backbone | Crop Size | Lr schd | mIoU | mIoU(ms+flip) |
| --------- | ------ | --------- | ------: | ----: | ------------: |
| Segformer | MIT-B0 | 512x512 | 160000 | 38.1 | 38.57 |
| Segformer | MIT-B1 | 512x512 | 160000 | 41.64 | 42.76 |
| Segformer | MIT-B2 | 512x512 | 160000 | 46.53 | 47.49 |
| Segformer | MIT-B3 | 512x512 | 160000 | 48.46 | 49.14 |
| Segformer | MIT-B4 | 512x512 | 160000 | 49.34 | 50.29 |
| Segformer | MIT-B5 | 512x512 | 160000 | 50.08 | 50.72 |
| Segformer | MIT-B5 | 640x640 | 160000 | 50.58 | 50.8 |

We replace `AlignedResize` in the original implementation with `Resize + ResizeToMultiple`. If you want to test with `AlignedResize`, you can change the dataset pipeline like this:

```python
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 512),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            # Resize images to a multiple of 32; improves SegFormer by 0.5-1.0 mIoU.
            dict(type='ResizeToMultiple', size_divisor=32),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
```

## How to use segformer official pretrain weights

We convert the backbone weights from the official repo (https://github.com/NVlabs/SegFormer) with `tools/model_converters/mit_convert.py`.

You may follow the steps below to prepare for SegFormer training:

1. Download the SegFormer pretrained weights (we suggest putting them in `pretrain/`);
2. Run the convert script on the official pretrained weights: `python tools/model_converters/mit_convert.py pretrain/mit_b0.pth pretrain/mit_b0.pth`;
3. Modify the `pretrained` field of the SegFormer model config; for example, `pretrained` of `segformer_mit-b0_512x512_160k_ade20k.py` is set to `pretrain/mit_b0.pth`.
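As a sanity check after step 2, you can verify that the converted checkpoint uses the mmseg-style key layout before pointing a config at it. A minimal sketch, assuming the converted file sits at `pretrain/mit_b0.pth`:

```python
import torch

state_dict = torch.load('pretrain/mit_b0.pth', map_location='cpu')

# The converter renames e.g. 'patch_embed1.proj.*' to 'layers.0.0.projection.*'
# and 'block1.*' to 'layers.0.1.*', so converted keys should start with 'layers.'.
assert all(k.startswith('layers.') for k in state_dict), \
    'found official-style keys; did the convert script run?'
print(f'{len(state_dict)} converted parameter tensors look OK')
```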
configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py (new file, 33 lines)

@@ -0,0 +1,33 @@
_base_ = [
    '../_base_/models/segformer_mit-b0.py', '../_base_/datasets/ade20k.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]

model = dict(
    pretrained='pretrain/mit_b0.pth', decode_head=dict(num_classes=150))

# optimizer
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_block': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
            'head': dict(lr_mult=10.)
        }))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

data = dict(samples_per_gpu=2, workers_per_gpu=2)
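The `custom_keys` above give positional-encoding blocks and normalization layers zero weight decay and the head a 10x learning rate. A minimal sketch of how `paramwise_cfg` plays out, assuming `mmcv` and `torch` are installed (the toy module names are illustrative only; they just need to contain the matched substrings):

```python
import torch.nn as nn
from mmcv.runner import build_optimizer

# A toy model whose parameter names contain the substrings matched by
# `custom_keys` ('norm' and 'head').
model = nn.Sequential()
model.add_module('backbone', nn.Linear(8, 8))
model.add_module('norm', nn.LayerNorm(8))
model.add_module('head', nn.Linear(8, 2))

optimizer_cfg = dict(
    type='AdamW',
    lr=6e-5,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'norm': dict(decay_mult=0.),
            'head': dict(lr_mult=10.)
        }))
optimizer = build_optimizer(model, optimizer_cfg)

for group in optimizer.param_groups:
    print(group['lr'], group['weight_decay'])
# Parameters under 'norm' get weight_decay 0; those under 'head' get lr 6e-4.
```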
configs/segformer/segformer_mit-b1_512x512_160k_ade20k.py (new file, 8 lines)

@@ -0,0 +1,8 @@
_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']

# model settings
model = dict(
    pretrained='pretrain/mit_b1.pth',
    backbone=dict(
        embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[2, 2, 2, 2]),
    decode_head=dict(in_channels=[64, 128, 320, 512]))
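The decode head's `in_channels` follows from the backbone widths: in this MixVisionTransformer implementation each stage's channel count is `embed_dims * num_heads[i]`, so raising `embed_dims` from 32 to 64 yields the [64, 128, 320, 512] used by all b1-b5 variants below. A quick check in pure Python, with values taken from the configs:

```python
embed_dims = 64  # mit-b1 and larger
num_heads = [1, 2, 5, 8]

stage_channels = [embed_dims * h for h in num_heads]
print(stage_channels)  # [64, 128, 320, 512] == decode_head.in_channels
```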
configs/segformer/segformer_mit-b2_512x512_160k_ade20k.py (new file, 8 lines)

@@ -0,0 +1,8 @@
_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']

# model settings
model = dict(
    pretrained='pretrain/mit_b2.pth',
    backbone=dict(
        embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 4, 6, 3]),
    decode_head=dict(in_channels=[64, 128, 320, 512]))
configs/segformer/segformer_mit-b3_512x512_160k_ade20k.py (new file, 8 lines)

@@ -0,0 +1,8 @@
_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']

# model settings
model = dict(
    pretrained='pretrain/mit_b3.pth',
    backbone=dict(
        embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 4, 18, 3]),
    decode_head=dict(in_channels=[64, 128, 320, 512]))
configs/segformer/segformer_mit-b4_512x512_160k_ade20k.py (new file, 8 lines)

@@ -0,0 +1,8 @@
_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']

# model settings
model = dict(
    pretrained='pretrain/mit_b4.pth',
    backbone=dict(
        embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 8, 27, 3]),
    decode_head=dict(in_channels=[64, 128, 320, 512]))
configs/segformer/segformer_mit-b5_512x512_160k_ade20k.py (new file, 8 lines)

@@ -0,0 +1,8 @@
_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']

# model settings
model = dict(
    pretrained='pretrain/mit_b5.pth',
    backbone=dict(
        embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 6, 40, 3]),
    decode_head=dict(in_channels=[64, 128, 320, 512]))
configs/segformer/segformer_mit-b5_640x640_160k_ade20k.py (new file, 44 lines)

@@ -0,0 +1,44 @@
_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']

# dataset settings
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (640, 640)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='Resize', img_scale=(2048, 640), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 640),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))

# model settings
model = dict(
    pretrained='pretrain/mit_b5.pth',
    backbone=dict(
        embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 6, 40, 3]),
    decode_head=dict(in_channels=[64, 128, 320, 512]))
mmseg/datasets/pipelines/transforms.py

@@ -6,6 +6,63 @@ from numpy import random
 from ..builder import PIPELINES


+@PIPELINES.register_module()
+class ResizeToMultiple(object):
+    """Resize images & seg to a multiple of the divisor.
+
+    Args:
+        size_divisor (int): images and gt seg maps need to resize to a
+            multiple of size_divisor. Default: 32.
+        interpolation (str, optional): The interpolation mode of image
+            resize. Default: None
+    """
+
+    def __init__(self, size_divisor=32, interpolation=None):
+        self.size_divisor = size_divisor
+        self.interpolation = interpolation
+
+    def __call__(self, results):
+        """Call function to resize images and semantic segmentation maps to
+        a multiple of the size divisor.
+
+        Args:
+            results (dict): Result dict from the loading pipeline.
+
+        Returns:
+            dict: Resized results; 'img_shape' and 'pad_shape' keys are
+                updated.
+        """
+        # Align the image to a multiple of the size divisor.
+        img = results['img']
+        img = mmcv.imresize_to_multiple(
+            img,
+            self.size_divisor,
+            scale_factor=1,
+            interpolation=self.interpolation
+            if self.interpolation else 'bilinear')
+
+        results['img'] = img
+        results['img_shape'] = img.shape
+        results['pad_shape'] = img.shape
+
+        # Align segmentation maps to a multiple of the size divisor, using
+        # nearest interpolation so label values are preserved.
+        for key in results.get('seg_fields', []):
+            gt_seg = results[key]
+            gt_seg = mmcv.imresize_to_multiple(
+                gt_seg,
+                self.size_divisor,
+                scale_factor=1,
+                interpolation='nearest')
+            results[key] = gt_seg
+
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += (f'(size_divisor={self.size_divisor}, '
+                     f'interpolation={self.interpolation})')
+        return repr_str
+
+
 @PIPELINES.register_module()
 class Resize(object):
     """Resize images & seg.
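For intuition, `mmcv.imresize_to_multiple` with `scale_factor=1` rounds each spatial dimension up to the next multiple of the divisor; a minimal standalone sketch (the array sizes are illustrative, chosen to match the unit test added below):

```python
import mmcv
import numpy as np

img = np.random.randint(0, 255, (213, 232, 3), dtype=np.uint8)

# With scale_factor=1, each side is rounded up to the next multiple of 32:
# 213 -> 224, 232 -> 256.
out = mmcv.imresize_to_multiple(img, 32, scale_factor=1,
                                interpolation='bilinear')
print(out.shape)  # (224, 256, 3)
```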
mmseg/models/backbones/mit.py

@@ -11,7 +11,7 @@ from mmcv.runner import BaseModule, ModuleList, Sequential, _load_checkpoint

 from ...utils import get_root_logger
 from ..builder import BACKBONES
-from ..utils import PatchEmbed, mit_convert, nchw_to_nlc, nlc_to_nchw
+from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw


 class MixFFN(BaseModule):
@@ -159,7 +159,13 @@ class EfficientMultiheadAttention(MultiheadAttention):
         if identity is None:
             identity = x_q

-        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
+        # With `need_weights=True`, nn.MultiheadAttention returns
+        # `attn_output, attn_output_weights.sum(dim=1) / num_heads`.
+        # The `attn_output_weights.sum(dim=1)` may cause a cuda error, so we
+        # set `need_weights=False` to skip computing it.
+        # https://github.com/pytorch/pytorch/issues/37583 reports that a
+        # large-scale tensor sum operation may cause a cuda error.
+        out = self.attn(query=x_q, key=x_kv, value=x_kv, need_weights=False)[0]

         return identity + self.dropout_layer(self.proj_drop(out))
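A minimal standalone sketch of the `need_weights` switch on stock `torch.nn.MultiheadAttention` (the sizes here are illustrative, not taken from the configs):

```python
import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=32, num_heads=1)
x = torch.randn(4096, 2, 32)  # (seq_len, batch, embed_dim)

out, weights = attn(x, x, x, need_weights=False)
assert weights is None  # the per-head weight average is never materialized
print(out.shape)  # torch.Size([4096, 2, 32])
```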
@@ -387,17 +393,9 @@ class MixVisionTransformer(BaseModule):
                 self.pretrained, logger=logger, map_location='cpu')
             if 'state_dict' in checkpoint:
                 state_dict = checkpoint['state_dict']
-            elif 'model' in checkpoint:
-                state_dict = checkpoint['model']
             else:
                 state_dict = checkpoint
-
-            if self.pretrain_style == 'official':
-                # Because the segformer backbone is not supported by mmcls,
-                # we need to convert the pretrained weights to match this
-                # implementation.
-                state_dict = mit_convert(state_dict)

             self.load_state_dict(state_dict, False)

     def forward(self, x):
mmseg/models/decode_heads/__init__.py

@@ -16,6 +16,7 @@ from .ocr_head import OCRHead
 from .point_head import PointHead
 from .psa_head import PSAHead
 from .psp_head import PSPHead
+from .segformer_head import SegformerHead
 from .sep_aspp_head import DepthwiseSeparableASPPHead
 from .sep_fcn_head import DepthwiseSeparableFCNHead
 from .setr_mla_head import SETRMLAHead
@@ -26,5 +27,6 @@ __all__ = [
     'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
     'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
     'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
-    'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead', 'SETRMLAHead'
+    'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
+    'SETRMLAHead', 'SegformerHead'
 ]
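With `SegformerHead` registered, the head can be built from a plain config dict through the registry; a minimal sketch, assuming mmsegmentation is installed and a recent mmcv where `Registry.build` is available:

```python
from mmseg.models.builder import HEADS

head_cfg = dict(
    type='SegformerHead',
    in_channels=[32, 64, 160, 256],
    in_index=[0, 1, 2, 3],
    channels=256,
    num_classes=19)
head = HEADS.build(head_cfg)  # instantiates the class registered below
print(type(head).__name__)    # SegformerHead
```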
mmseg/models/decode_heads/segformer_head.py (new file, 65 lines)

@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.models.builder import HEADS
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.ops import resize


@HEADS.register_module()
class SegformerHead(BaseDecodeHead):
    """The all-MLP head of SegFormer.

    This head is the implementation of
    `SegFormer <https://arxiv.org/abs/2105.15203>`_.

    Args:
        interpolate_mode: The interpolation mode of the MLP head upsample
            operation. Default: 'bilinear'.
    """

    def __init__(self, interpolate_mode='bilinear', **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)

        self.interpolate_mode = interpolate_mode
        num_inputs = len(self.in_channels)

        assert num_inputs == len(self.in_index)

        self.convs = nn.ModuleList()
        for i in range(num_inputs):
            self.convs.append(
                ConvModule(
                    in_channels=self.in_channels[i],
                    out_channels=self.channels,
                    kernel_size=1,
                    stride=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))

        self.fusion_conv = ConvModule(
            in_channels=self.channels * num_inputs,
            out_channels=self.channels,
            kernel_size=1,
            norm_cfg=self.norm_cfg)

    def forward(self, inputs):
        # Receive 4 stage backbone feature maps: 1/4, 1/8, 1/16, 1/32.
        inputs = self._transform_inputs(inputs)
        outs = []
        for idx in range(len(inputs)):
            x = inputs[idx]
            conv = self.convs[idx]
            # Project each level to `channels` dims and upsample it to the
            # resolution of the largest (1/4) feature map.
            outs.append(
                resize(
                    input=conv(x),
                    size=inputs[0].shape[2:],
                    mode=self.interpolate_mode,
                    align_corners=self.align_corners))

        # Fuse the concatenated multi-level features with a 1x1 conv.
        out = self.fusion_conv(torch.cat(outs, dim=1))

        out = self.cls_seg(out)

        return out
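The fusion arithmetic in `forward` can be seen in isolation with plain torch; a minimal sketch of the shape flow, where the channel sizes follow the mit-b0 base config and everything else is illustrative:

```python
import torch

# Four backbone stages at 1/4, 1/8, 1/16 and 1/32 of a 64x64 input.
feats = [
    torch.randn(1, c, s, s)
    for c, s in zip((32, 64, 160, 256), (16, 8, 4, 2))
]

# Each level is projected to a common width (256 here) and upsampled to
# the 1/4 resolution before the concat + 1x1 fusion.
projected = [
    torch.nn.functional.interpolate(
        torch.nn.Conv2d(f.shape[1], 256, 1)(f), size=(16, 16),
        mode='bilinear', align_corners=False)
    for f in feats
]

fused = torch.cat(projected, dim=1)  # (1, 1024, 16, 16)
print(fused.shape)
```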
mmseg/models/utils/__init__.py

@@ -1,4 +1,4 @@
-from .ckpt_convert import mit_convert, swin_convert, vit_convert
+from .ckpt_convert import swin_convert, vit_convert
 from .embed import PatchEmbed
 from .inverted_residual import InvertedResidual, InvertedResidualV3
 from .make_divisible import make_divisible
@@ -11,5 +11,5 @@ from .up_conv_block import UpConvBlock
 __all__ = [
     'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
     'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'vit_convert',
-    'mit_convert', 'swin_convert', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw'
+    'swin_convert', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw'
 ]
mmseg/models/utils/ckpt_convert.py

@@ -1,7 +1,5 @@
 from collections import OrderedDict

-import torch
-

 def swin_convert(ckpt):
     new_ckpt = OrderedDict()
@@ -90,50 +88,3 @@ def vit_convert(ckpt):
         new_ckpt[new_k] = v

     return new_ckpt
-
-
-def mit_convert(ckpt):
-    new_ckpt = OrderedDict()
-    # Process the concat between q linear weights and kv linear weights
-    for k, v in ckpt.items():
-        if k.startswith('head'):
-            continue
-        elif k.startswith('patch_embed'):
-            stage_i = int(k.split('.')[0].replace('patch_embed', ''))
-            new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0')
-            new_v = v
-            if 'proj.' in new_k:
-                new_k = new_k.replace('proj.', 'projection.')
-        elif k.startswith('block'):
-            stage_i = int(k.split('.')[0].replace('block', ''))
-            new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1')
-            new_v = v
-            if 'attn.q.' in new_k:
-                sub_item_k = k.replace('q.', 'kv.')
-                new_k = new_k.replace('q.', 'attn.in_proj_')
-                new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
-            elif 'attn.kv.' in new_k:
-                continue
-            elif 'attn.proj.' in new_k:
-                new_k = new_k.replace('proj.', 'attn.out_proj.')
-            elif 'attn.sr.' in new_k:
-                new_k = new_k.replace('sr.', 'sr.')
-            elif 'mlp.' in new_k:
-                string = f'{new_k}-'
-                new_k = new_k.replace('mlp.', 'ffn.layers.')
-                if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
-                    new_v = v.reshape((*v.shape, 1, 1))
-                new_k = new_k.replace('fc1.', '0.')
-                new_k = new_k.replace('dwconv.dwconv.', '1.')
-                new_k = new_k.replace('fc2.', '4.')
-                string += f'{new_k} {v.shape}-{new_v.shape}'
-                # print(string)
-        elif k.startswith('norm'):
-            stage_i = int(k.split('.')[0].replace('norm', ''))
-            new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2')
-            new_v = v
-        else:
-            new_k = k
-            new_v = v
-        new_ckpt[new_k] = new_v
-    return new_ckpt
tests/test_data/test_transform.py

@@ -10,6 +10,26 @@ from PIL import Image
 from mmseg.datasets.builder import PIPELINES


+def test_resize_to_multiple():
+    transform = dict(type='ResizeToMultiple', size_divisor=32)
+    transform = build_from_cfg(transform, PIPELINES)
+
+    img = np.random.randn(213, 232, 3)
+    seg = np.random.randint(0, 19, (213, 232))
+    results = dict()
+    results['img'] = img
+    results['gt_semantic_seg'] = seg
+    results['seg_fields'] = ['gt_semantic_seg']
+    results['img_shape'] = img.shape
+    results['pad_shape'] = img.shape
+
+    results = transform(results)
+    assert results['img'].shape == (224, 256, 3)
+    assert results['gt_semantic_seg'].shape == (224, 256)
+    assert results['img_shape'] == (224, 256, 3)
+    assert results['pad_shape'] == (224, 256, 3)
+
+
 def test_resize():
     # test assertion if img_scale is a list
     with pytest.raises(AssertionError):
tests/test_models/test_heads/test_segformer_head.py (new file, 39 lines)

@@ -0,0 +1,39 @@
import pytest
import torch

from mmseg.models.decode_heads import SegformerHead


def test_segformer_head():
    with pytest.raises(AssertionError):
        # `in_channels` must have the same length as `in_index`.
        SegformerHead(
            in_channels=(1, 2, 3), in_index=(0, 1), channels=5, num_classes=2)

    H, W = (64, 64)
    in_channels = (32, 64, 160, 256)
    shapes = [(H // 2**(i + 2), W // 2**(i + 2))
              for i in range(len(in_channels))]
    model = SegformerHead(
        in_channels=in_channels,
        in_index=[0, 1, 2, 3],
        channels=256,
        num_classes=19)

    with pytest.raises(IndexError):
        # `in_index` must match the input feature maps.
        inputs = [
            torch.randn((1, in_channel, *shape))
            for in_channel, shape in zip(in_channels, shapes)
        ][:3]
        temp = model(inputs)

    # Normal input:
    # ((1, 32, 16, 16), (1, 64, 8, 8), (1, 160, 4, 4), (1, 256, 2, 2))
    inputs = [
        torch.randn((1, in_channel, *shape))
        for in_channel, shape in zip(in_channels, shapes)
    ]
    temp = model(inputs)

    assert temp.shape == (1, 19, H // 4, W // 4)
tools/model_converters/mit_convert.py (new file, 76 lines)

@@ -0,0 +1,76 @@
import argparse
from collections import OrderedDict

import torch


def mit_convert(ckpt):
    new_ckpt = OrderedDict()
    # Process the concat between q linear weights and kv linear weights.
    for k, v in ckpt.items():
        if k.startswith('head'):
            # Drop the classification head; mmseg only needs the backbone.
            continue
        # patch embedding conversion
        elif k.startswith('patch_embed'):
            stage_i = int(k.split('.')[0].replace('patch_embed', ''))
            new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0')
            new_v = v
            if 'proj.' in new_k:
                new_k = new_k.replace('proj.', 'projection.')
        # transformer encoder layer conversion
        elif k.startswith('block'):
            stage_i = int(k.split('.')[0].replace('block', ''))
            new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1')
            new_v = v
            if 'attn.q.' in new_k:
                # Concatenate the q and kv projections into the stacked
                # in_proj layout expected by nn.MultiheadAttention.
                sub_item_k = k.replace('q.', 'kv.')
                new_k = new_k.replace('q.', 'attn.in_proj_')
                new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
            elif 'attn.kv.' in new_k:
                # Already merged into in_proj above.
                continue
            elif 'attn.proj.' in new_k:
                new_k = new_k.replace('proj.', 'attn.out_proj.')
            elif 'attn.sr.' in new_k:
                # Spatial-reduction conv keys keep their names (no-op).
                new_k = new_k.replace('sr.', 'sr.')
            elif 'mlp.' in new_k:
                string = f'{new_k}-'  # debug bookkeeping, not used
                new_k = new_k.replace('mlp.', 'ffn.layers.')
                if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
                    # The FFN linears become 1x1 convs in mmseg.
                    new_v = v.reshape((*v.shape, 1, 1))
                new_k = new_k.replace('fc1.', '0.')
                new_k = new_k.replace('dwconv.dwconv.', '1.')
                new_k = new_k.replace('fc2.', '4.')
                string += f'{new_k} {v.shape}-{new_v.shape}'
        # norm layer conversion
        elif k.startswith('norm'):
            stage_i = int(k.split('.')[0].replace('norm', ''))
            new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2')
            new_v = v
        else:
            new_k = k
            new_v = v
        new_ckpt[new_k] = new_v
    return new_ckpt


def parse_args():
    parser = argparse.ArgumentParser(
        'Convert official segformer backbone weights to mmseg style.')
    parser.add_argument(
        'src', help='Source path of official segformer backbone weights.')
    parser.add_argument(
        'dst',
        help='Destination path of converted segformer backbone weights.')

    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    src_path = args.src
    dst_path = args.dst

    ckpt = torch.load(src_path, map_location='cpu')

    ckpt = mit_convert(ckpt)
    torch.save(ckpt, dst_path)
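The only non-trivial rewrite in the converter is the attention projection: the official MiT checkpoints store separate q and kv linears, while `nn.MultiheadAttention` expects a single stacked `in_proj`. A minimal standalone sketch of that concat, with illustrative sizes:

```python
import torch

embed_dims = 32
q_weight = torch.randn(embed_dims, embed_dims)       # official 'attn.q.weight'
kv_weight = torch.randn(2 * embed_dims, embed_dims)  # official 'attn.kv.weight'

# Stacked as [q; k; v], the layout nn.MultiheadAttention uses for
# in_proj_weight.
in_proj_weight = torch.cat([q_weight, kv_weight], dim=0)
print(in_proj_weight.shape)  # torch.Size([96, 32])
```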