[Feature] Support PoolFormer in MMSegmentation 2.0 (#2191)

* [Feature] 2.0 PoolFormer

* fix mmcls version

* fix ut error

* fix ut

* fix ut
This commit is contained in:
MengzhangLI 2022-10-19 13:08:07 +08:00 committed by GitHub
parent e8af7a0ed0
commit 25604a151b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 360 additions and 5 deletions

View File

@ -102,6 +102,7 @@ Supported backbones:
- [x] [BEiT (ICLR'2022)](configs/beit)
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
- [x] [MAE (CVPR'2022)](configs/mae)
- [x] [PoolFormer (CVPR'2022)](configs/poolformer)
Supported methods:

View File

@ -96,6 +96,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [BEiT (ICLR'2022)](configs/beit)
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
- [x] [MAE (CVPR'2022)](configs/mae)
- [x] [PoolFormer (CVPR'2022)](configs/poolformer)
已支持的算法:

View File

@ -0,0 +1,50 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-s12_3rdparty_32xb128_in1k_20220414-f8d83051.pth' # noqa
custom_imports = dict(imports='mmcls.models', allow_failed_imports=False)
data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255)
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
backbone=dict(
type='mmcls.PoolFormer',
arch='s12',
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_file, prefix='backbone.'),
in_patch_size=7,
in_stride=4,
in_pad=2,
down_patch_size=3,
down_stride=2,
down_pad=1,
drop_rate=0.,
drop_path_rate=0.,
out_indices=(0, 2, 4, 6),
frozen_stages=0,
),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=4),
decode_head=dict(
type='FPNHead',
in_channels=[256, 256, 256, 256],
in_index=[0, 1, 2, 3],
feature_strides=[4, 8, 16, 32],
channels=128,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

View File

@ -0,0 +1,65 @@
# PoolFormer
[MetaFormer is Actually What You Need for Vision](https://arxiv.org/abs/2111.11418)
## Introduction
<!-- [BACKBONE] -->
<a href="https://github.com/sail-sg/poolformer/tree/main/segmentation">Official Repo</a>
<a href="https://github.com/open-mmlab/mmclassification/blob/v0.23.0/mmcls/models/backbones/poolformer.py#L198">Code Snippet</a>
## Abstract
<!-- [ABSTRACT] -->
Transformers have shown great potential in computer vision tasks. A common belief is their attention-based token mixer module contributes most to their competence. However, recent works show the attention-based module in transformers can be replaced by spatial MLPs and the resulted models still perform quite well. Based on this observation, we hypothesize that the general architecture of the transformers, instead of the specific token mixer module, is more essential to the model's performance. To verify this, we deliberately replace the attention module in transformers with an embarrassingly simple spatial pooling operator to conduct only the most basic token mixing. Surprisingly, we observe that the derived model, termed as PoolFormer, achieves competitive performance on multiple computer vision tasks. For example, on ImageNet-1K, PoolFormer achieves 82.1% top-1 accuracy, surpassing well-tuned vision transformer/MLP-like baselines DeiT-B/ResMLP-B24 by 0.3%/1.1% accuracy with 35%/52% fewer parameters and 48%/60% fewer MACs. The effectiveness of PoolFormer verifies our hypothesis and urges us to initiate the concept of "MetaFormer", a general architecture abstracted from transformers without specifying the token mixer. Based on the extensive experiments, we argue that MetaFormer is the key player in achieving superior results for recent transformer and MLP-like models on vision tasks. This work calls for more future research dedicated to improving MetaFormer instead of focusing on the token mixer modules. Additionally, our proposed PoolFormer could serve as a starting baseline for future MetaFormer architecture design. Code is available at [this https URL](https://github.com/sail-sg/poolformer)
<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/15921929/144710761-1635f59a-abde-4946-984c-a2c3f22a19d2.png" width="70%"/>
</div>
## Citation
```bibtex
@inproceedings{yu2022metaformer,
title={Metaformer is actually what you need for vision},
author={Yu, Weihao and Luo, Mi and Zhou, Pan and Si, Chenyang and Zhou, Yichen and Wang, Xinchao and Feng, Jiashi and Yan, Shuicheng},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={10819--10829},
year={2022}
}
```
### Usage
- PoolFormer backbone needs to install [MMClassification](https://github.com/open-mmlab/mmclassification) first, which has abundant backbones for downstream tasks.
```shell
pip install "mmcls>=1.0.0rc0"
```
- The pretrained models could also be downloaded from [PoolFormer config of MMClassification](https://github.com/open-mmlab/mmclassification/tree/master/configs/poolformer).
## Results and models
### ADE20K
| Method | Backbone | Crop Size | pretrain | Batch Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | mIoU\* | mIoU\*(ms+flip) | config | download |
| ------ | -------------- | --------- | ----------- | ---------- | ------- | -------- | -------------- | ----- | ------------: | ------ | --------------: | ------------------------------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| FPN | PoolFormer-S12 | 512x512 | ImageNet-1K | 32 | 40000 | 4.17 | 23.48 | 36.68 | - | 37.07 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/poolformer/fpn_poolformer_s12_8xb4-40k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s12_8x4_512x512_40k_ade20k/fpn_poolformer_s12_8x4_512x512_40k_ade20k_20220501_115154-b5aa2f49.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s12_8x4_512x512_40k_ade20k/fpn_poolformer_s12_8x4_512x512_40k_ade20k_20220501_115154.log.json) |
| FPN | PoolFormer-S24 | 512x512 | ImageNet-1K | 32 | 40000 | 5.47 | 15.74 | 40.12 | - | 40.36 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/poolformer/fpn_poolformer_s24_8xb4-40k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s24_8x4_512x512_40k_ade20k/fpn_poolformer_s24_8x4_512x512_40k_ade20k_20220503_222049-394a7cf7.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s24_8x4_512x512_40k_ade20k/fpn_poolformer_s24_8x4_512x512_40k_ade20k_20220503_222049.log.json) |
| FPN | PoolFormer-S36 | 512x512 | ImageNet-1K | 32 | 40000 | 6.77 | 11.34 | 41.61 | - | 41.81 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/poolformer/fpn_poolformer_s36_8xb4-40k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s36_8x4_512x512_40k_ade20k/fpn_poolformer_s36_8x4_512x512_40k_ade20k_20220501_151122-b47e607d.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s36_8x4_512x512_40k_ade20k/fpn_poolformer_s36_8x4_512x512_40k_ade20k_20220501_151122.log.json) |
| FPN | PoolFormer-M36 | 512x512 | ImageNet-1K | 32 | 40000 | 8.59 | 8.97 | 41.95 | - | 42.35 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/poolformer/fpn_poolformer_m36_8xb4-40k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_m36_8x4_512x512_40k_ade20k/fpn_poolformer_m36_8x4_512x512_40k_ade20k_20220501_164230-3dc83921.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_m36_8x4_512x512_40k_ade20k/fpn_poolformer_m36_8x4_512x512_40k_ade20k_20220501_164230.log.json) |
| FPN | PoolFormer-M48 | 512x512 | ImageNet-1K | 32 | 40000 | 10.48 | 6.69 | 42.43 | - | 42.76 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/poolformer/fpn_poolformer_m48_8xb4-40k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_m48_8x4_512x512_40k_ade20k/fpn_poolformer_m48_8x4_512x512_40k_ade20k_20220504_003923-64168d3b.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_m48_8x4_512x512_40k_ade20k/fpn_poolformer_m48_8x4_512x512_40k_ade20k_20220504_003923.log.json) |
Note:
- We replace `AlignedResize` in original PoolFormer implementation to `Resize + ResizeToMultiple`.
- `mIoU` with * is collected when `Resize + ResizeToMultiple` is adopted in `test_pipeline`, so do `mIoU` in logs.
- The Test Time Augmentation i.e., "ms+flip" in MMSegmentation v1.x is developing, stay tuned!

View File

@ -0,0 +1,11 @@
_base_ = './fpn_poolformer_s12_8xb4-40k_ade20k-512x512.py'
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-m36_3rdparty_32xb128_in1k_20220414-c55e0949.pth' # noqa
# model settings
model = dict(
backbone=dict(
arch='m36',
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_file,
prefix='backbone.')),
neck=dict(in_channels=[96, 192, 384, 768]))

View File

@ -0,0 +1,11 @@
_base_ = './fpn_poolformer_s12_8xb4-40k_ade20k-512x512.py'
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-m48_3rdparty_32xb128_in1k_20220414-9378f3eb.pth' # noqa
# model settings
model = dict(
backbone=dict(
arch='m48',
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_file,
prefix='backbone.')),
neck=dict(in_channels=[96, 192, 384, 768]))

View File

@ -0,0 +1,91 @@
_base_ = [
'../_base_/models/fpn_poolformer_s12.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_40k.py'
]
# dataset settings
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
data_preprocessor = dict(size=crop_size)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(
type='RandomResize',
scale=(2048, 512),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 512), keep_ratio=True),
dict(type='ResizeToMultiple', size_divisor=32),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='PackSegInputs')
]
train_dataloader = dict(
batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type='RepeatDataset',
times=50,
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/training',
seg_map_path='annotations/training'),
pipeline=train_pipeline)))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/validation',
seg_map_path='annotations/validation'),
pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
# model settings
model = dict(
data_preprocessor=data_preprocessor,
neck=dict(in_channels=[64, 128, 320, 512]),
decode_head=dict(num_classes=150))
# optimizer
# optimizer = dict(_delete_=True, type='AdamW', lr=0.0002, weight_decay=0.0001)
# optimizer_config = dict()
# # learning policy
# lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=False)
optim_wrapper = dict(
_delete_=True,
type='AmpOptimWrapper',
optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.0001))
param_scheduler = [
dict(
type='PolyLR',
power=0.9,
begin=0,
end=40000,
eta_min=0.0,
by_epoch=False,
)
]

View File

@ -0,0 +1,9 @@
_base_ = './fpn_poolformer_s12_8xb4-40k_ade20k-512x512.py'
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-s24_3rdparty_32xb128_in1k_20220414-d7055904.pth' # noqa
# model settings
model = dict(
backbone=dict(
arch='s24',
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_file,
prefix='backbone.')))

View File

@ -0,0 +1,10 @@
_base_ = './fpn_poolformer_s12_8xb4-40k_ade20k-512x512.py'
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-s36_3rdparty_32xb128_in1k_20220414-d78ff3e8.pth' # noqa
# model settings
model = dict(
backbone=dict(
arch='s36',
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_file,
prefix='backbone.')))

View File

@ -0,0 +1,106 @@
Models:
- Name: fpn_poolformer_s12_8xb4-40k_ade20k-512x512
In Collection: FPN
Metadata:
backbone: PoolFormer-S12
crop size: (512,512)
lr schd: 40000
inference time (ms/im):
- value: 42.59
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 4.17
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 36.68
Config: configs/poolformer/fpn_poolformer_s12_8xb4-40k_ade20k-512x512.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s12_8x4_512x512_40k_ade20k/fpn_poolformer_s12_8x4_512x512_40k_ade20k_20220501_115154-b5aa2f49.pth
- Name: fpn_poolformer_s24_8xb4-40k_ade20k-512x512
In Collection: FPN
Metadata:
backbone: PoolFormer-S24
crop size: (512,512)
lr schd: 40000
inference time (ms/im):
- value: 63.53
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 5.47
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 40.12
Config: configs/poolformer/fpn_poolformer_s24_8xb4-40k_ade20k-512x512.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s24_8x4_512x512_40k_ade20k/fpn_poolformer_s24_8x4_512x512_40k_ade20k_20220503_222049-394a7cf7.pth
- Name: ''
In Collection: FPN
Metadata:
backbone: PoolFormer-S36
crop size: (512,512)
lr schd: 40000
inference time (ms/im):
- value: 88.18
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 6.77
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 41.61
Config: ''
Weights: ''
- Name: fpn_poolformer_m36_8xb4-40k_ade20k-512x512
In Collection: FPN
Metadata:
backbone: PoolFormer-M36
crop size: (512,512)
lr schd: 40000
inference time (ms/im):
- value: 111.48
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 8.59
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 41.95
Config: configs/poolformer/fpn_poolformer_m36_8xb4-40k_ade20k-512x512.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_m36_8x4_512x512_40k_ade20k/fpn_poolformer_m36_8x4_512x512_40k_ade20k_20220501_164230-3dc83921.pth
- Name: fpn_poolformer_m48_8xb4-40k_ade20k-512x512
In Collection: FPN
Metadata:
backbone: PoolFormer-M48
crop size: (512,512)
lr schd: 40000
inference time (ms/im):
- value: 149.48
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 10.48
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 42.43
Config: configs/poolformer/fpn_poolformer_m48_8xb4-40k_ade20k-512x512.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_m48_8x4_512x512_40k_ade20k/fpn_poolformer_m48_8x4_512x512_40k_ade20k_20220504_003923-64168d3b.pth

View File

@ -30,6 +30,7 @@ Import:
- configs/nonlocal_net/nonlocal_net.yml
- configs/ocrnet/ocrnet.yml
- configs/point_rend/point_rend.yml
- configs/poolformer/poolformer.yml
- configs/psanet/psanet.yml
- configs/pspnet/pspnet.yml
- configs/resnest/resnest.yml

View File

@ -89,10 +89,10 @@ def test_config_data_pipeline():
# remove loading pipeline
load_img_pipeline = config_mod.train_pipeline.pop(0)
to_float32 = load_img_pipeline.get('to_float32', False)
config_mod.train_pipeline.pop(0)
config_mod.test_pipeline.pop(0)
del config_mod.train_pipeline[0]
del config_mod.test_pipeline[0]
# remove loading annotation in test pipeline
config_mod.test_pipeline.pop(1)
del config_mod.test_pipeline[-2]
train_pipeline = Compose(config_mod.train_pipeline)
test_pipeline = Compose(config_mod.test_pipeline)
@ -120,8 +120,7 @@ def test_config_data_pipeline():
ori_filename='test_img.png',
img=img,
img_shape=img.shape,
ori_shape=img.shape,
)
ori_shape=img.shape)
print(f'Test testing data pipeline: \n{test_pipeline!r}')
output_results = test_pipeline(results)
assert output_results is not None