[Feature] Support ConvNext (#1216)

* upload original backbone and configs

* ConvNext Refactor

* ConvNext Refactor

* convnext customization refactor with mmseg style

* convnext customization refactor with mmseg style

* add ade20k_640x640.py

* upload files for training

* delete dist_optimizer_hook and remove layer_decay_optimizer_constructor

* check max(out_indices) < num_stages

* add unittest

* fix lint error

* use MMClassification backbone

* fix bugs in base_1k

* add mmcls in requirements/mminstall.txt

* add mmcls in requirements/mminstall.txt

* fix drop_path_rate and layer_scale_init_value

* use logger.info instead of print

* add mmcls in runtime.txt

* fix f string && delete

* add doctring in LearningRateDecayOptimizerConstructor and fix mmcls version in requirements

* fix typo in LearningRateDecayOptimizerConstructor

* use ConvNext models in unit test for LearningRateDecayOptimizerConstructor

* add unit test

* fix typo

* fix typo

* add layer_wise and fix redundant backbone.downsample_norm in it

* fix unit test

* give a ground truth lr_scale and weight_decay

* upload models and readme

* delete 'backbone.stem_norm' and 'backbone.downsample_norm' in get_num_layer()

* fix unit test and use mmcls url

* update md2yml.py and metafile

* fix typo
This commit is contained in:
MengzhangLI 2022-03-04 15:52:01 +08:00 committed by GitHub
parent 369a2ee9bb
commit 7ddd2fe2ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 936 additions and 4 deletions

View File

@ -162,8 +162,9 @@ def parse_md(md_file):
model_name = fn[:-3]
fps = els[fps_id] if els[fps_id] != '-' and els[
fps_id] != '' else -1
mem = els[mem_id] if els[mem_id] != '-' and els[
mem_id] != '' else -1
mem = els[mem_id].split(
'\\'
)[0] if els[mem_id] != '-' and els[mem_id] != '' else -1
crop_size = els[crop_size_id].split('x')
assert len(crop_size) == 2
method = els[method_id].split()[0].split('-')[-1]

View File

@ -84,6 +84,7 @@ Supported backbones:
- [x] [Vision Transformer (ICLR'2021)](configs/vit)
- [x] [Swin Transformer (ICCV'2021)](configs/swin)
- [x] [Twins (NeurIPS'2021)](configs/twins)
- [x] [ConvNeXt (ArXiv'2022)](configs/convnext)
Supported methods:

View File

@ -83,6 +83,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [Vision Transformer (ICLR'2021)](configs/vit)
- [x] [Swin Transformer (ICCV'2021)](configs/swin)
- [x] [Twins (NeurIPS'2021)](configs/twins)
- [x] [ConvNeXt (ArXiv'2022)](configs/convnext)
已支持的算法:

View File

@ -0,0 +1,54 @@
# dataset settings
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (640, 640)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2560, 640), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2560, 640),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/training',
ann_dir='annotations/training',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline))

View File

@ -0,0 +1,44 @@
norm_cfg = dict(type='SyncBN', requires_grad=True)
custom_imports = dict(imports='mmcls.models', allow_failed_imports=False)
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-base_3rdparty_32xb128-noema_in1k_20220301-2a0ee547.pth' # noqa
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(
type='mmcls.ConvNeXt',
arch='base',
out_indices=[0, 1, 2, 3],
drop_path_rate=0.4,
layer_scale_init_value=1.0,
gap_before_final_norm=False,
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_file,
prefix='backbone.')),
decode_head=dict(
type='UPerHead',
in_channels=[128, 256, 512, 1024],
in_index=[0, 1, 2, 3],
pool_scales=(1, 2, 3, 6),
channels=512,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=384,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

View File

@ -0,0 +1,71 @@
# ConvNeXt
[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545)
## Introduction
<!-- [BACKBONE] -->
<a href="https://github.com/facebookresearch/ConvNeXt">Official Repo</a>
<a href="https://github.com/open-mmlab/mmclassification/blob/v0.20.1/mmcls/models/backbones/convnext.py#L133">Code Snippet</a>
## Abstract
<!-- [ABSTRACT] -->
The "Roaring 20s" of visual recognition began with the introduction of Vision Transformers (ViTs), which quickly superseded ConvNets as the state-of-the-art image classification model. A vanilla ViT, on the other hand, faces difficulties when applied to general computer vision tasks such as object detection and semantic segmentation. It is the hierarchical Transformers (e.g., Swin Transformers) that reintroduced several ConvNet priors, making Transformers practically viable as a generic vision backbone and demonstrating remarkable performance on a wide variety of vision tasks. However, the effectiveness of such hybrid approaches is still largely credited to the intrinsic superiority of Transformers, rather than the inherent inductive biases of convolutions. In this work, we reexamine the design spaces and test the limits of what a pure ConvNet can achieve. We gradually "modernize" a standard ResNet toward the design of a vision Transformer, and discover several key components that contribute to the performance difference along the way. The outcome of this exploration is a family of pure ConvNet models dubbed ConvNeXt. Constructed entirely from standard ConvNet modules, ConvNeXts compete favorably with Transformers in terms of accuracy and scalability, achieving 87.8% ImageNet top-1 accuracy and outperforming Swin Transformers on COCO detection and ADE20K segmentation, while maintaining the simplicity and efficiency of standard ConvNets.
<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/8370623/148624004-e9581042-ea4d-4e10-b3bd-42c92b02053b.png" width="90%"/>
</div>
```bibtex
@article{liu2022convnet,
title={A ConvNet for the 2020s},
author={Liu, Zhuang and Mao, Hanzi and Wu, Chao-Yuan and Feichtenhofer, Christoph and Darrell, Trevor and Xie, Saining},
journal={arXiv preprint arXiv:2201.03545},
year={2022}
}
```
### Usage
- This backbone need to install [MMClassification](https://github.com/open-mmlab/mmclassification) first, which has abundant backbones for downstream tasks.
```shell
pip install mmcls>=0.20.1
```
### Pre-trained Models
The pre-trained models on ImageNet-1k or ImageNet-21k are used to fine-tune on the downstream tasks.
| Model | Training Data | Params(M) | Flops(G) | Download |
|:--------------:|:-------------:|:---------:|:--------:|:--------:|
| ConvNeXt-T\* | ImageNet-1k | 28.59 | 4.46 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-tiny_3rdparty_32xb128-noema_in1k_20220301-795e9634.pth) |
| ConvNeXt-S\* | ImageNet-1k | 50.22 | 8.69 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-small_3rdparty_32xb128-noema_in1k_20220301-303e75e3.pth) |
| ConvNeXt-B\* | ImageNet-1k | 88.59 | 15.36 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-base_3rdparty_32xb128-noema_in1k_20220301-2a0ee547.pth) |
| ConvNeXt-B\* | ImageNet-21k | 88.59 | 15.36 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-base_3rdparty_in21k_20220301-262fd037.pth) |
| ConvNeXt-L\* | ImageNet-21k | 197.77 | 34.37 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-large_3rdparty_in21k_20220301-e6e0ea0a.pth) |
| ConvNeXt-XL\* | ImageNet-21k | 350.20 | 60.93 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-xlarge_3rdparty_in21k_20220301-08aa5ddc.pth) |
*Models with \* are converted from the [official repo](https://github.com/facebookresearch/ConvNeXt/tree/main/semantic_segmentation#results-and-fine-tuned-models).*
## Results and models
### ADE20K
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ---------- | ------- | -------- | --- | --- | -------------- | ----- |
| UperNet | ConvNeXt-T | 512x512 | 160000 | 4.23 | 19.90 | 46.11 | 46.62 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/convnext/upernet_convnext_tiny_fp16_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_tiny_fp16_512x512_160k_ade20k/upernet_convnext_tiny_fp16_512x512_160k_ade20k_20220227_124553-cad485de.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_tiny_fp16_512x512_160k_ade20k/upernet_convnext_tiny_fp16_512x512_160k_ade20k_20220227_124553.log.json) |
| UperNet | ConvNeXt-S | 512x512 | 160000 | 5.16 | 15.18 | 48.56 | 49.02 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/convnext/upernet_convnext_small_fp16_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_small_fp16_512x512_160k_ade20k/upernet_convnext_small_fp16_512x512_160k_ade20k_20220227_131208-1b1e394f.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_small_fp16_512x512_160k_ade20k/upernet_convnext_small_fp16_512x512_160k_ade20k_20220227_131208.log.json) |
| UperNet | ConvNeXt-B | 512x512 | 160000 | 6.33 | 14.41 | 48.71 | 49.54 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k/upernet_convnext_base_fp16_512x512_160k_ade20k_20220227_181227-02a24fc6.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k/upernet_convnext_base_fp16_512x512_160k_ade20k_20220227_181227.log.json) |
| UperNet | ConvNeXt-B |640x640 | 160000 | 8.53 | 10.88 | 52.13 | 52.66 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/convnext/upernet_convnext_base_fp16_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_base_fp16_640x640_160k_ade20k/upernet_convnext_base_fp16_640x640_160k_ade20k_20220227_182859-9280e39b.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_base_fp16_640x640_160k_ade20k/upernet_convnext_base_fp16_640x640_160k_ade20k_20220227_182859.log.json) |
| UperNet | ConvNeXt-L |640x640 | 160000 | 12.08 | 7.69 | 53.16 | 53.38 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/convnext/upernet_convnext_large_fp16_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_large_fp16_640x640_160k_ade20k/upernet_convnext_large_fp16_640x640_160k_ade20k_20220226_040532-e57aa54d.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_large_fp16_640x640_160k_ade20k/upernet_convnext_large_fp16_640x640_160k_ade20k_20220226_040532.log.json) |
| UperNet | ConvNeXt-XL |640x640 | 160000 | 26.16\* | 6.33 | 53.58 | 54.11 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/convnext/upernet_convnext_xlarge_fp16_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_xlarge_fp16_640x640_160k_ade20k/upernet_convnext_xlarge_fp16_640x640_160k_ade20k_20220226_080344-95fc38c2.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_xlarge_fp16_640x640_160k_ade20k/upernet_convnext_xlarge_fp16_640x640_160k_ade20k_20220226_080344.log.json) |
Note:
- `Mem (GB)` with \* is collected when `cudnn_benchmark=True`, and hardware is V100.

View File

@ -0,0 +1,133 @@
Models:
- Name: upernet_convnext_tiny_fp16_512x512_160k_ade20k
In Collection: UperNet
Metadata:
backbone: ConvNeXt-T
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 50.25
hardware: V100
backend: PyTorch
batch size: 1
mode: FP16
resolution: (512,512)
Training Memory (GB): 4.23
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 46.11
mIoU(ms+flip): 46.62
Config: configs/convnext/upernet_convnext_tiny_fp16_512x512_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_tiny_fp16_512x512_160k_ade20k/upernet_convnext_tiny_fp16_512x512_160k_ade20k_20220227_124553-cad485de.pth
- Name: upernet_convnext_small_fp16_512x512_160k_ade20k
In Collection: UperNet
Metadata:
backbone: ConvNeXt-S
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 65.88
hardware: V100
backend: PyTorch
batch size: 1
mode: FP16
resolution: (512,512)
Training Memory (GB): 5.16
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 48.56
mIoU(ms+flip): 49.02
Config: configs/convnext/upernet_convnext_small_fp16_512x512_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_small_fp16_512x512_160k_ade20k/upernet_convnext_small_fp16_512x512_160k_ade20k_20220227_131208-1b1e394f.pth
- Name: upernet_convnext_base_fp16_512x512_160k_ade20k
In Collection: UperNet
Metadata:
backbone: ConvNeXt-B
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 69.4
hardware: V100
backend: PyTorch
batch size: 1
mode: FP16
resolution: (512,512)
Training Memory (GB): 6.33
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 48.71
mIoU(ms+flip): 49.54
Config: configs/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k/upernet_convnext_base_fp16_512x512_160k_ade20k_20220227_181227-02a24fc6.pth
- Name: upernet_convnext_base_fp16_640x640_160k_ade20k
In Collection: UperNet
Metadata:
backbone: ConvNeXt-B
crop size: (640,640)
lr schd: 160000
inference time (ms/im):
- value: 91.91
hardware: V100
backend: PyTorch
batch size: 1
mode: FP16
resolution: (640,640)
Training Memory (GB): 8.53
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 52.13
mIoU(ms+flip): 52.66
Config: configs/convnext/upernet_convnext_base_fp16_640x640_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_base_fp16_640x640_160k_ade20k/upernet_convnext_base_fp16_640x640_160k_ade20k_20220227_182859-9280e39b.pth
- Name: upernet_convnext_large_fp16_640x640_160k_ade20k
In Collection: UperNet
Metadata:
backbone: ConvNeXt-L
crop size: (640,640)
lr schd: 160000
inference time (ms/im):
- value: 130.04
hardware: V100
backend: PyTorch
batch size: 1
mode: FP16
resolution: (640,640)
Training Memory (GB): 12.08
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 53.16
mIoU(ms+flip): 53.38
Config: configs/convnext/upernet_convnext_large_fp16_640x640_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_large_fp16_640x640_160k_ade20k/upernet_convnext_large_fp16_640x640_160k_ade20k_20220226_040532-e57aa54d.pth
- Name: upernet_convnext_xlarge_fp16_640x640_160k_ade20k
In Collection: UperNet
Metadata:
backbone: ConvNeXt-XL
crop size: (640,640)
lr schd: 160000
inference time (ms/im):
- value: 157.98
hardware: V100
backend: PyTorch
batch size: 1
mode: FP16
resolution: (640,640)
Training Memory (GB): 26.16
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 53.58
mIoU(ms+flip): 54.11
Config: configs/convnext/upernet_convnext_xlarge_fp16_640x640_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_xlarge_fp16_640x640_160k_ade20k/upernet_convnext_xlarge_fp16_640x640_160k_ade20k_20220226_080344-95fc38c2.pth

View File

@ -0,0 +1,40 @@
_base_ = [
'../_base_/models/upernet_convnext.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
crop_size = (512, 512)
model = dict(
decode_head=dict(in_channels=[128, 256, 512, 1024], num_classes=150),
auxiliary_head=dict(in_channels=512, num_classes=150),
test_cfg=dict(mode='slide', crop_size=crop_size, stride=(341, 341)),
)
optimizer = dict(
constructor='LearningRateDecayOptimizerConstructor',
_delete_=True,
type='AdamW',
lr=0.0001,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_cfg={
'decay_rate': 0.9,
'decay_type': 'stage_wise',
'num_layers': 12
})
lr_config = dict(
_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)
# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
# fp16 settings
optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
# fp16 placeholder
fp16 = dict()

View File

@ -0,0 +1,55 @@
_base_ = [
'../_base_/models/upernet_convnext.py',
'../_base_/datasets/ade20k_640x640.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_160k.py'
]
crop_size = (640, 640)
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-base_3rdparty_in21k_20220301-262fd037.pth' # noqa
model = dict(
backbone=dict(
type='mmcls.ConvNeXt',
arch='base',
out_indices=[0, 1, 2, 3],
drop_path_rate=0.4,
layer_scale_init_value=1.0,
gap_before_final_norm=False,
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_file,
prefix='backbone.')),
decode_head=dict(
in_channels=[128, 256, 512, 1024],
num_classes=150,
),
auxiliary_head=dict(in_channels=512, num_classes=150),
test_cfg=dict(mode='slide', crop_size=crop_size, stride=(426, 426)),
)
optimizer = dict(
constructor='LearningRateDecayOptimizerConstructor',
_delete_=True,
type='AdamW',
lr=0.0001,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_cfg={
'decay_rate': 0.9,
'decay_type': 'stage_wise',
'num_layers': 12
})
lr_config = dict(
_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)
# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
# fp16 settings
optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
# fp16 placeholder
fp16 = dict()

View File

@ -0,0 +1,55 @@
_base_ = [
'../_base_/models/upernet_convnext.py',
'../_base_/datasets/ade20k_640x640.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_160k.py'
]
crop_size = (640, 640)
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-large_3rdparty_in21k_20220301-e6e0ea0a.pth' # noqa
model = dict(
backbone=dict(
type='mmcls.ConvNeXt',
arch='large',
out_indices=[0, 1, 2, 3],
drop_path_rate=0.4,
layer_scale_init_value=1.0,
gap_before_final_norm=False,
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_file,
prefix='backbone.')),
decode_head=dict(
in_channels=[192, 384, 768, 1536],
num_classes=150,
),
auxiliary_head=dict(in_channels=768, num_classes=150),
test_cfg=dict(mode='slide', crop_size=crop_size, stride=(426, 426)),
)
optimizer = dict(
constructor='LearningRateDecayOptimizerConstructor',
_delete_=True,
type='AdamW',
lr=0.0001,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_cfg={
'decay_rate': 0.9,
'decay_type': 'stage_wise',
'num_layers': 12
})
lr_config = dict(
_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)
# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
# fp16 settings
optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
# fp16 placeholder
fp16 = dict()

View File

@ -0,0 +1,54 @@
_base_ = [
'../_base_/models/upernet_convnext.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
crop_size = (512, 512)
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-small_3rdparty_32xb128-noema_in1k_20220301-303e75e3.pth' # noqa
model = dict(
backbone=dict(
type='mmcls.ConvNeXt',
arch='small',
out_indices=[0, 1, 2, 3],
drop_path_rate=0.3,
layer_scale_init_value=1.0,
gap_before_final_norm=False,
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_file,
prefix='backbone.')),
decode_head=dict(
in_channels=[96, 192, 384, 768],
num_classes=150,
),
auxiliary_head=dict(in_channels=384, num_classes=150),
test_cfg=dict(mode='slide', crop_size=crop_size, stride=(341, 341)),
)
optimizer = dict(
constructor='LearningRateDecayOptimizerConstructor',
_delete_=True,
type='AdamW',
lr=0.0001,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_cfg={
'decay_rate': 0.9,
'decay_type': 'stage_wise',
'num_layers': 12
})
lr_config = dict(
_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)
# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
# fp16 settings
optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
# fp16 placeholder
fp16 = dict()

View File

@ -0,0 +1,54 @@
_base_ = [
'../_base_/models/upernet_convnext.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
crop_size = (512, 512)
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-tiny_3rdparty_32xb128-noema_in1k_20220301-795e9634.pth' # noqa
model = dict(
backbone=dict(
type='mmcls.ConvNeXt',
arch='tiny',
out_indices=[0, 1, 2, 3],
drop_path_rate=0.4,
layer_scale_init_value=1.0,
gap_before_final_norm=False,
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_file,
prefix='backbone.')),
decode_head=dict(
in_channels=[96, 192, 384, 768],
num_classes=150,
),
auxiliary_head=dict(in_channels=384, num_classes=150),
test_cfg=dict(mode='slide', crop_size=crop_size, stride=(341, 341)),
)
optimizer = dict(
constructor='LearningRateDecayOptimizerConstructor',
_delete_=True,
type='AdamW',
lr=0.0001,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_cfg={
'decay_rate': 0.9,
'decay_type': 'stage_wise',
'num_layers': 6
})
lr_config = dict(
_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)
# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
# fp16 settings
optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
# fp16 placeholder
fp16 = dict()

View File

@ -0,0 +1,55 @@
_base_ = [
'../_base_/models/upernet_convnext.py',
'../_base_/datasets/ade20k_640x640.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_160k.py'
]
crop_size = (640, 640)
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-xlarge_3rdparty_in21k_20220301-08aa5ddc.pth' # noqa
model = dict(
backbone=dict(
type='mmcls.ConvNeXt',
arch='xlarge',
out_indices=[0, 1, 2, 3],
drop_path_rate=0.4,
layer_scale_init_value=1.0,
gap_before_final_norm=False,
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_file,
prefix='backbone.')),
decode_head=dict(
in_channels=[256, 512, 1024, 2048],
num_classes=150,
),
auxiliary_head=dict(in_channels=1024, num_classes=150),
test_cfg=dict(mode='slide', crop_size=crop_size, stride=(426, 426)),
)
optimizer = dict(
constructor='LearningRateDecayOptimizerConstructor',
_delete_=True,
type='AdamW',
lr=0.00008,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_cfg={
'decay_rate': 0.9,
'decay_type': 'stage_wise',
'num_layers': 12
})
lr_config = dict(
_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)
# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
# fp16 settings
optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
# fp16 placeholder
fp16 = dict()

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .layer_decay_optimizer_constructor import \
LearningRateDecayOptimizerConstructor
from .misc import add_prefix
__all__ = ['add_prefix']
__all__ = ['add_prefix', 'LearningRateDecayOptimizerConstructor']

View File

@ -0,0 +1,148 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
from mmcv.runner import (OPTIMIZER_BUILDERS, DefaultOptimizerConstructor,
get_dist_info)
from ...utils import get_root_logger
def get_num_layer_layer_wise(var_name, num_max_layer=12):
"""Get the layer id to set the different learning rates in ``layer_wise``
decay_type.
Args:
var_name (str): The key of the model.
num_max_layer (int): Maximum number of backbone layers.
Returns:
int: The id number corresponding to different learning rate in
``LearningRateDecayOptimizerConstructor``.
"""
if var_name in ('backbone.cls_token', 'backbone.mask_token',
'backbone.pos_embed'):
return 0
elif var_name.startswith('backbone.downsample_layers'):
stage_id = int(var_name.split('.')[2])
if stage_id == 0:
layer_id = 0
elif stage_id == 1:
layer_id = 2
elif stage_id == 2:
layer_id = 3
elif stage_id == 3:
layer_id = num_max_layer
return layer_id
elif var_name.startswith('backbone.stages'):
stage_id = int(var_name.split('.')[2])
block_id = int(var_name.split('.')[3])
if stage_id == 0:
layer_id = 1
elif stage_id == 1:
layer_id = 2
elif stage_id == 2:
layer_id = 3 + block_id // 3
elif stage_id == 3:
layer_id = num_max_layer
return layer_id
else:
return num_max_layer + 1
def get_num_layer_stage_wise(var_name, num_max_layer):
"""Get the layer id to set the different learning rates in ``stage_wise``
decay_type.
Args:
var_name (str): The key of the model.
num_max_layer (int): Maximum number of backbone layers.
Returns:
int: The id number corresponding to different learning rate in
``LearningRateDecayOptimizerConstructor``.
"""
if var_name in ('backbone.cls_token', 'backbone.mask_token',
'backbone.pos_embed'):
return 0
elif var_name.startswith('backbone.downsample_layers'):
return 0
elif var_name.startswith('backbone.stages'):
stage_id = int(var_name.split('.')[2])
return stage_id + 1
else:
return num_max_layer - 1
@OPTIMIZER_BUILDERS.register_module()
class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
"""Different learning rates are set for different layers of backbone."""
def add_params(self, params, module):
"""Add all parameters of module to the params list.
The parameters of the given module will be added to the list of param
groups, with specific rules defined by paramwise_cfg.
Args:
params (list[dict]): A list of param groups, it will be modified
in place.
module (nn.Module): The module to be added.
"""
logger = get_root_logger()
parameter_groups = {}
logger.info(f'self.paramwise_cfg is {self.paramwise_cfg}')
num_layers = self.paramwise_cfg.get('num_layers') + 2
decay_rate = self.paramwise_cfg.get('decay_rate')
decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise')
logger.info('Build LearningRateDecayOptimizerConstructor '
f'{decay_type} {decay_rate} - {num_layers}')
weight_decay = self.base_wd
for name, param in module.named_parameters():
if not param.requires_grad:
continue # frozen weights
if len(param.shape) == 1 or name.endswith('.bias') or name in (
'pos_embed', 'cls_token'):
group_name = 'no_decay'
this_weight_decay = 0.
else:
group_name = 'decay'
this_weight_decay = weight_decay
if decay_type == 'layer_wise':
layer_id = get_num_layer_layer_wise(
name, self.paramwise_cfg.get('num_layers'))
logger.info(f'set param {name} as id {layer_id}')
elif decay_type == 'stage_wise':
layer_id = get_num_layer_stage_wise(name, num_layers)
logger.info(f'set param {name} as id {layer_id}')
group_name = f'layer_{layer_id}_{group_name}'
if group_name not in parameter_groups:
scale = decay_rate**(num_layers - layer_id - 1)
parameter_groups[group_name] = {
'weight_decay': this_weight_decay,
'params': [],
'param_names': [],
'lr_scale': scale,
'group_name': group_name,
'lr': scale * self.base_lr,
}
parameter_groups[group_name]['params'].append(param)
parameter_groups[group_name]['param_names'].append(name)
rank, _ = get_dist_info()
if rank == 0:
to_display = {}
for key in parameter_groups:
to_display[key] = {
'param_names': parameter_groups[key]['param_names'],
'lr_scale': parameter_groups[key]['lr_scale'],
'lr': parameter_groups[key]['lr'],
'weight_decay': parameter_groups[key]['weight_decay'],
}
logger.info(f'Param groups = {json.dumps(to_display, indent=2)}')
params.extend(parameter_groups.values())

View File

@ -5,6 +5,7 @@ Import:
- configs/bisenetv2/bisenetv2.yml
- configs/ccnet/ccnet.yml
- configs/cgnet/cgnet.yml
- configs/convnext/convnext.yml
- configs/danet/danet.yml
- configs/deeplabv3/deeplabv3.yml
- configs/deeplabv3plus/deeplabv3plus.yml

View File

@ -1 +1,2 @@
mmcv-full>=1.3.1,<=1.4.0
mmcls>=0.20.1
mmcv-full>=1.4.4,<=1.5.0

View File

@ -1,4 +1,5 @@
matplotlib
mmcls>=0.20.1
numpy
packaging
prettytable

View File

@ -0,0 +1,161 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.core.utils.layer_decay_optimizer_constructor import \
LearningRateDecayOptimizerConstructor
base_lr = 1
decay_rate = 2
base_wd = 0.05
weight_decay = 0.05
stage_wise_gt_lst = [{
'weight_decay': 0.0,
'lr_scale': 128
}, {
'weight_decay': 0.0,
'lr_scale': 1
}, {
'weight_decay': 0.05,
'lr_scale': 64
}, {
'weight_decay': 0.0,
'lr_scale': 64
}, {
'weight_decay': 0.05,
'lr_scale': 32
}, {
'weight_decay': 0.0,
'lr_scale': 32
}, {
'weight_decay': 0.05,
'lr_scale': 16
}, {
'weight_decay': 0.0,
'lr_scale': 16
}, {
'weight_decay': 0.05,
'lr_scale': 8
}, {
'weight_decay': 0.0,
'lr_scale': 8
}, {
'weight_decay': 0.05,
'lr_scale': 128
}, {
'weight_decay': 0.05,
'lr_scale': 1
}]
layer_wise_gt_lst = [{
'weight_decay': 0.0,
'lr_scale': 128
}, {
'weight_decay': 0.0,
'lr_scale': 1
}, {
'weight_decay': 0.05,
'lr_scale': 64
}, {
'weight_decay': 0.0,
'lr_scale': 64
}, {
'weight_decay': 0.05,
'lr_scale': 32
}, {
'weight_decay': 0.0,
'lr_scale': 32
}, {
'weight_decay': 0.05,
'lr_scale': 16
}, {
'weight_decay': 0.0,
'lr_scale': 16
}, {
'weight_decay': 0.05,
'lr_scale': 2
}, {
'weight_decay': 0.0,
'lr_scale': 2
}, {
'weight_decay': 0.05,
'lr_scale': 128
}, {
'weight_decay': 0.05,
'lr_scale': 1
}]
class ConvNeXtExampleModel(nn.Module):
def __init__(self):
super().__init__()
self.backbone = nn.ModuleList()
self.backbone.stages = nn.ModuleList()
for i in range(4):
stage = nn.Sequential(ConvModule(3, 4, kernel_size=1, bias=True))
self.backbone.stages.append(stage)
self.backbone.norm0 = nn.BatchNorm2d(2)
# add some variables to meet unit test coverate rate
self.backbone.cls_token = nn.Parameter(torch.ones(1))
self.backbone.mask_token = nn.Parameter(torch.ones(1))
self.backbone.pos_embed = nn.Parameter(torch.ones(1))
self.backbone.stem_norm = nn.Parameter(torch.ones(1))
self.backbone.downsample_norm0 = nn.BatchNorm2d(2)
self.backbone.downsample_norm1 = nn.BatchNorm2d(2)
self.backbone.downsample_norm2 = nn.BatchNorm2d(2)
self.backbone.lin = nn.Parameter(torch.ones(1))
self.backbone.lin.requires_grad = False
self.backbone.downsample_layers = nn.ModuleList()
for i in range(4):
stage = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1, bias=True))
self.backbone.downsample_layers.append(stage)
self.decode_head = nn.Conv2d(2, 2, kernel_size=1, groups=2)
class PseudoDataParallel(nn.Module):
def __init__(self):
super().__init__()
self.module = ConvNeXtExampleModel()
def forward(self, x):
return x
def check_convnext_adamw_optimizer(optimizer, gt_lst):
assert isinstance(optimizer, torch.optim.AdamW)
assert optimizer.defaults['lr'] == base_lr
assert optimizer.defaults['weight_decay'] == base_wd
param_groups = optimizer.param_groups
assert len(param_groups) == 12
for i, param_dict in enumerate(param_groups):
assert param_dict['weight_decay'] == gt_lst[i]['weight_decay']
assert param_dict['lr_scale'] == gt_lst[i]['lr_scale']
assert param_dict['lr_scale'] == param_dict['lr']
def test_convnext_learning_rate_decay_optimizer_constructor():
# paramwise_cfg with ConvNeXtExampleModel
model = ConvNeXtExampleModel()
optimizer_cfg = dict(
type='AdamW', lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05)
stagewise_paramwise_cfg = dict(
decay_rate=decay_rate, decay_type='stage_wise', num_layers=6)
optim_constructor = LearningRateDecayOptimizerConstructor(
optimizer_cfg, stagewise_paramwise_cfg)
optimizer = optim_constructor(model)
check_convnext_adamw_optimizer(optimizer, stage_wise_gt_lst)
layerwise_paramwise_cfg = dict(
decay_rate=decay_rate, decay_type='layer_wise', num_layers=6)
optim_constructor = LearningRateDecayOptimizerConstructor(
optimizer_cfg, layerwise_paramwise_cfg)
optimizer = optim_constructor(model)
check_convnext_adamw_optimizer(optimizer, layer_wise_gt_lst)