update riformer mmpretrain

pull/1453/head
techmonsterwang 2023-04-03 21:07:46 +08:00 committed by Ezra-Yu
parent 1ee9bbe050
commit 0b70c108b0
35 changed files with 1541 additions and 8 deletions

View File

@ -70,7 +70,7 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
-ann_file='meta/val.txt',
+# ann_file='meta/val.txt',
data_prefix='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
@ -79,4 +79,4 @@ val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator

View File

@ -70,7 +70,7 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
-ann_file='meta/val.txt',
+# ann_file='meta/val.txt',
data_prefix='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
@ -79,4 +79,4 @@ val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator

View File

@ -0,0 +1,82 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
bgr_mean = data_preprocessor['mean'][::-1]
bgr_std = data_preprocessor['std'][::-1]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=384,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=404,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=384),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=128,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
batch_size=16,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
# ann_file='meta/val.txt',
data_prefix='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
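For the standard ImageNet evaluation protocol, the commented-out annotation file above can be restored in a config that inherits from this one. A minimal sketch (assuming the usual `meta/val.txt` annotation file shown above is available):

```python
# Sketch: in a config that inherits from this dataset config, re-enable the
# standard validation annotation file for both validation and test.
val_dataloader = dict(dataset=dict(ann_file='meta/val.txt'))
test_dataloader = dict(dataset=dict(ann_file='meta/val.txt'))
```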

View File

@ -0,0 +1,82 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
bgr_mean = data_preprocessor['mean'][::-1]
bgr_std = data_preprocessor['std'][::-1]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=384,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=426,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=384),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=128,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
batch_size=32,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
# ann_file='meta/val.txt',
data_prefix='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator

View File

@ -0,0 +1,22 @@
# Model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='RIFormer',
arch='m36',
drop_path_rate=0.1,
init_cfg=[
dict(
type='TruncNormal',
layer=['Conv2d', 'Linear'],
std=.02,
bias=0.),
dict(type='Constant', layer=['GroupNorm'], val=1., bias=0.),
]),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=768,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
))

View File

@ -0,0 +1,22 @@
# Model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='RIFormer',
arch='m48',
drop_path_rate=0.1,
init_cfg=[
dict(
type='TruncNormal',
layer=['Conv2d', 'Linear'],
std=.02,
bias=0.),
dict(type='Constant', layer=['GroupNorm'], val=1., bias=0.),
]),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=768,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
))

View File

@ -0,0 +1,22 @@
# Model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='RIFormer',
arch='s12',
drop_path_rate=0.1,
init_cfg=[
dict(
type='TruncNormal',
layer=['Conv2d', 'Linear'],
std=.02,
bias=0.),
dict(type='Constant', layer=['GroupNorm'], val=1., bias=0.),
]),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=512,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
))

View File

@ -0,0 +1,22 @@
# Model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='RIFormer',
arch='s24',
drop_path_rate=0.1,
init_cfg=[
dict(
type='TruncNormal',
layer=['Conv2d', 'Linear'],
std=.02,
bias=0.),
dict(type='Constant', layer=['GroupNorm'], val=1., bias=0.),
]),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=512,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
))

View File

@ -0,0 +1,22 @@
# Model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='RIFormer',
arch='s36',
drop_path_rate=0.1,
init_cfg=[
dict(
type='TruncNormal',
layer=['Conv2d', 'Linear'],
std=.02,
bias=0.),
dict(type='Constant', layer=['GroupNorm'], val=1., bias=0.),
]),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=512,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
))

View File

@ -0,0 +1,207 @@
# RIFormer
> [RIFormer: Keep Your Vision Backbone Effective But Removing Token Mixer](https://arxiv.org/abs/xxxx.xxxxx)
<!-- [ALGORITHM] -->
## Introduction
RIFormer is a way to keep a vision backbone effective while removing the token mixers in its basic building blocks. Equipped with the proposed optimization strategy, we are able to build an extremely simple vision backbone with encouraging performance, while enjoying high efficiency during inference. RIFormer shares nearly the same macro and micro design as MetaFormer, but safely removes all token mixers. Quantitative results show that our networks outperform many prevailing backbones with faster inference speed on ImageNet-1K.
<div align=center>
<img src="https://user-images.githubusercontent.com/48375204/223930120-dc075c8e-0513-42eb-9830-469a45c1d941.png" width="60%"/>
</div>
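At deployment time, the residual affine token mixer is folded into the preceding normalization layer. The following is a minimal PyTorch sketch of that fusion, mirroring the `fuse_affine` logic in the backbone implementation added by this commit (dimensions and initial values are illustrative):

```python
import torch
import torch.nn as nn

dim = 8
x = torch.randn(2, dim, 4, 4)

# Training-time branch: GroupNorm followed by the residual affine token mixer.
norm = nn.GroupNorm(1, dim)
affine = nn.Conv2d(dim, dim, kernel_size=1, groups=dim, bias=True)
for p in (norm.weight, norm.bias, affine.weight, affine.bias):
    nn.init.normal_(p)
y_train = affine(norm(x)) - norm(x)  # token_mixer(norm1(x)) in RIFormerBlock

# Deploy-time branch: a single norm layer with fused scale and bias.
gamma_affn = affine.weight.reshape(-1) - 1            # (W - 1) per channel
norm_reparam = nn.GroupNorm(1, dim)
norm_reparam.weight.data = norm.weight * gamma_affn
norm_reparam.bias.data = norm.bias * gamma_affn + affine.bias
y_deploy = norm_reparam(x)

assert torch.allclose(y_train, y_deploy, rtol=1e-4, atol=1e-5)
```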
## Abstract
<details>
<summary>Show the paper's abstract</summary>
<br>
This paper studies how to keep a vision backbone effective while removing token mixers in its basic building blocks. Token mixers, as self-attention for vision transformers (ViTs), are intended to perform information communication between different spatial tokens but suffer from considerable computational cost and latency. However, directly removing them will lead to an incomplete model structure prior, and thus brings a significant accuracy drop. To this end, we first develop an RepIdentityFormer base on the re-parameterizing idea, to study the token mixer free model architecture. And we then explore the improved learning paradigm to break the limitation of simple token mixer free backbone, and summarize the empirical practice into 5 guidelines. Equipped with the proposed optimization strategy, we are able to build an extremely simple vision backbone with encouraging performance, while enjoying the high efficiency during inference. Extensive experiments and ablative analysis also demonstrate that the inductive bias of network architecture, can be incorporated into simple network structure with appropriate optimization strategy. We hope this work can serve as a starting point for the exploration of optimization-driven efficient network design.
</br>
</details>
## How to use
The checkpoints provided are all `training-time` models. Use the reparameterize tool or the `switch_to_deploy` interface to switch them to the more efficient `inference-time` architecture, which not only has fewer parameters but also requires less computation.
<!-- [TABS-BEGIN] -->
**Predict image**
Use the `classifier.backbone.switch_to_deploy()` interface to switch RIFormer models into inference mode.
```python
>>> import torch
>>> from mmcls.apis import init_model, inference_model
>>>
>>> model = init_model('configs/riformer/riformer-s12_32xb128_in1k.py', '/home/PJLAB/wangjiahao/project/RepIndentityFormer/mmcls_pretrained/224/repidentityformer-s12.pth.tar')
>>> results = inference_model(model, 'demo/demo.JPEG')
>>> print( (results['pred_class'], results['pred_score']) )
('sea snake', 0.7827475666999817)
>>>
>>> # switch to deploy mode
>>> model.backbone.switch_to_deploy()
>>> results = inference_model(model, 'demo/demo.JPEG')
>>> print( (results['pred_class'], results['pred_score']) )
('sea snake', 0.7827480435371399)
```
**Use the model**
```python
>>> import torch
>>> from mmcls.apis import init_model
>>>
>>> model = init_model('configs/riformer/riformer-s12_32xb128_in1k.py', '/home/PJLAB/wangjiahao/project/RepIndentityFormer/mmcls_pretrained/224/repidentityformer-s12.pth.tar')
>>> model.eval()
>>> inputs = torch.rand(1, 3, 224, 224).to(model.data_preprocessor.device)
>>> # To get classification scores.
>>> out = model(inputs)
>>> print(out.shape)
torch.Size([1, 1000])
>>> # To extract features.
>>> outs = model.extract_feat(inputs)
>>> print(outs[0].shape)
torch.Size([1, 512])
>>>
>>> # switch to deploy mode
>>> model.backbone.switch_to_deploy()
>>> out_deploy = model(inputs)
>>> print(out_deploy.shape)
torch.Size([1, 1000])
>>> assert torch.allclose(out, out_deploy, rtol=1e-4, atol=1e-5) # pass without error
```
**Test Command**
Place the ImageNet dataset in the `data/imagenet/` directory, or prepare datasets according to the [docs](https://mmclassification.readthedocs.io/en/1.x/user_guides/dataset_prepare.html#prepare-dataset).
*224×224*
Download Checkpoint:
```shell
wget /home/PJLAB/wangjiahao/project/RepIndentityFormer/mmcls_pretrained/224/repidentityformer-s12.pth.tar
```
Test with the unfused model:
```shell
python tools/test.py configs/riformer/riformer-s12_32xb128_in1k.py /home/PJLAB/wangjiahao/project/RepIndentityFormer/mmcls_pretrained/224/repidentityformer-s12.pth.tar
```
Reparameterize checkpoint:
```shell
python tools/model_converters/reparameterize_model.py configs/riformer/riformer-s12_32xb128_in1k.py /home/PJLAB/wangjiahao/project/RepIndentityFormer/mmcls_pretrained/224/repidentityformer-s12.pth.tar /home/PJLAB/wangjiahao/project/RepIndentityFormer/mmcls_pretrained_deploy/224/repidentityformer-s12.pth.tar
```
Test with the fused model:
```shell
python tools/test.py configs/riformer/deploy/riformer-s12-deploy_32xb128_in1k.py /home/PJLAB/wangjiahao/project/RepIndentityFormer/mmcls_pretrained_deploy/224/repidentityformer-s12.pth.tar
```
*384×384*
Download Checkpoint:
```shell
wget /home/PJLAB/wangjiahao/project/RepIndentityFormer/mmcls_pretrained/384/repidentityformer-s12.pth.tar
```
Test with the unfused model:
```shell
python tools/test.py configs/riformer/riformer-s12_32xb128_in1k_384.py /home/PJLAB/wangjiahao/project/RepIndentityFormer/mmcls_pretrained/384/repidentityformer-s12.pth.tar
```
Reparameterize checkpoint:
```shell
python tools/model_converters/reparameterize_model.py configs/riformer/riformer-s12_32xb128_in1k_384.py /home/PJLAB/wangjiahao/project/RepIndentityFormer/mmcls_pretrained/384/repidentityformer-s12.pth.tar /home/PJLAB/wangjiahao/project/RepIndentityFormer/mmcls_pretrained_deploy/384/repidentityformer-s12.pth.tar
```
Test with the fused model:
```shell
python tools/test.py configs/riformer/deploy/riformer-s12-deploy_32xb128_in1k_384.py /home/PJLAB/wangjiahao/project/RepIndentityFormer/mmcls_pretrained_deploy/384/repidentityformer-s12.pth.tar
```
<!-- [TABS-END] -->
For more configurable parameters, please refer to the [API](https://mmclassification.readthedocs.io/en/1.x/api/generated/mmcls.models.backbones.RIFormer.html#mmcls.models.backbones.RIFormer).
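If you only need the backbone, it can also be built directly through the model registry, which is a convenient way to try `switch_to_deploy` in isolation. A sketch (not from the original usage examples; output shapes follow the s12 architecture):

```python
>>> import torch
>>> from mmpretrain.registry import MODELS
>>>
>>> backbone = MODELS.build(dict(type='RIFormer', arch='s12'))
>>> backbone.eval()
>>> inputs = torch.rand(1, 3, 224, 224)
>>> feat = backbone(inputs)[-1]
>>> print(feat.shape)
torch.Size([1, 512, 7, 7])
>>>
>>> backbone.switch_to_deploy()
>>> feat_deploy = backbone(inputs)[-1]
>>> torch.allclose(feat, feat_deploy, rtol=1e-4, atol=1e-5)
True
```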
<details>
<summary><b>How to use the reparameterization tool</b> (click to show)</summary>
<br>
Use the provided tool to reparameterize the given model and save the checkpoint:
```bash
python tools/model_converters/reparameterize_model.py ${CFG_PATH} ${SRC_CKPT_PATH} ${TARGET_CKPT_PATH}
```
`${CFG_PATH}` is the config file path, `${SRC_CKPT_PATH}` is the source checkpoint file path, and `${TARGET_CKPT_PATH}` is the target deploy weight file path.
For example:
```shell
# download the weight
wget /home/PJLAB/wangjiahao/project/RepIndentityFormer/mmcls_pretrained/224/repidentityformer-s12.pth.tar
# reparameterize unfused weight to fused weight
python tools/model_converters/reparameterize_model.py configs/riformer/riformer-s12_32xb128_in1k.py /home/PJLAB/wangjiahao/project/RepIndentityFormer/mmcls_pretrained/224/repidentityformer-s12.pth.tar /home/PJLAB/wangjiahao/project/RepIndentityFormer/mmcls_pretrained_deploy/224/repidentityformer-s12.pth.tar
```
To use reparameterized weights, you can use the deploy model config file such as the [s12_deploy example](./deploy/riformer-s12-deploy_32xb128_in1k.py):
```text
# in riformer-s12-deploy_32xb128_in1k.py
_base_ = '../riformer-s12_32xb128_in1k.py' # basic s12 config
model = dict(backbone=dict(deploy=True)) # switch model into deploy mode
```
```shell
python tools/test.py configs/riformer/deploy/riformer-s12-deploy_32xb128_in1k.py /home/PJLAB/wangjiahao/project/RepIndentityFormer/mmcls_pretrained_deploy/224/repidentityformer-s12.pth.tar
```
</br>
</details>
## Results and models
### ImageNet-1k
| Model | resolution | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
| :----------: | :--------: | :-------: | :------: | :-------: | :-------: | :------------------------------------------: | :-------------------------------------------------------------------------------------------------: |
| RIFormer-S12 | 224x224 | 11.92 | 1.82 | 76.90 | 93.06 | [config](./riformer-s12_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-s12_3rdparty_32xb128_in1k_20220414-f8d83051.pth) |
| RIFormer-S24 | 224x224 | 21.39 | 3.41 | 80.28 | 94.80 | [config](./riformer-s24_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-s24_3rdparty_32xb128_in1k_20220414-d7055904.pth) |
| RIFormer-S36 | 224x224 | 30.86 | 5.00 | 81.29 | 95.41 | [config](./riformer-s36_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-s36_3rdparty_32xb128_in1k_20220414-d78ff3e8.pth) |
| RIFormer-M36 | 224x224 | 56.17 | 8.80 | 82.57 | 95.99 | [config](./riformer-m36_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-m36_3rdparty_32xb128_in1k_20220414-c55e0949.pth) |
| RIFormer-M48 | 224x224 | 73.47 | 11.59 | 82.75 | 96.11 | [config](./riformer-m48_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-m48_3rdparty_32xb128_in1k_20220414-9378f3eb.pth) |
| RIFormer-S12 | 384x384 | 11.92 | 5.36 | 78.29 | 93.93 | [config](./riformer-s12_32xb128_in1k_384.py) | [model](https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-s12_3rdparty_32xb128_in1k_20220414-f8d83051.pth) |
| RIFormer-S24 | 384x384 | 21.39 | 10.03 | 81.36 | 95.40 | [config](./riformer-s24_32xb128_in1k_384.py) | [model](https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-s24_3rdparty_32xb128_in1k_20220414-d7055904.pth) |
| RIFormer-S36 | 384x384 | 30.86 | 14.70 | 82.22 | 95.95 | [config](./riformer-s36_32xb128_in1k_384.py) | [model](https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-s36_3rdparty_32xb128_in1k_20220414-d78ff3e8.pth) |
| RIFormer-M36 | 384x384 | 56.17 | 25.87 | 83.39 | 96.40 | [config](./riformer-m36_32xb128_in1k_384.py) | [model](https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-m36_3rdparty_32xb128_in1k_20220414-c55e0949.pth) |
| RIFormer-M48 | 384x384 | 73.47 | 34.06 | 83.70 | 96.60 | [config](./riformer-m48_32xb128_in1k_384.py) | [model](https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-m48_3rdparty_32xb128_in1k_20220414-9378f3eb.pth) |
The config files of these models are only for inference.
## Citation
```bibtex
@inproceedings{wang2023riformer,
title={RIFormer: Keep Your Vision Backbone Effective But Removing Token Mixer},
author={Wang, Jiahao and Zhang, Songyang and Liu, Yong and Wu, Taiqiang and Yang, Yujiu and Liu, Xihui and Chen, Kai and Luo, Ping and Lin, Dahua},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2023}
}
```

View File

@ -0,0 +1,3 @@
_base_ = '../riformer-m36_32xb128_in1k.py'
model = dict(backbone=dict(deploy=True))

View File

@ -0,0 +1,3 @@
_base_ = '../riformer-m36_32xb128_in1k_384.py'
model = dict(backbone=dict(deploy=True))

View File

@ -0,0 +1,3 @@
_base_ = '../riformer-m48_32xb128_in1k.py'
model = dict(backbone=dict(deploy=True))

View File

@ -0,0 +1,3 @@
_base_ = '../riformer-m48_32xb128_in1k_384.py'
model = dict(backbone=dict(deploy=True))

View File

@ -0,0 +1,3 @@
_base_ = '../riformer-s12_32xb128_in1k.py'
model = dict(backbone=dict(deploy=True))

View File

@ -0,0 +1,3 @@
_base_ = '../riformer-s12_32xb128_in1k_384.py'
model = dict(backbone=dict(deploy=True))

View File

@ -0,0 +1,3 @@
_base_ = '../riformer-s24_32xb128_in1k.py'
model = dict(backbone=dict(deploy=True))

View File

@ -0,0 +1,3 @@
_base_ = '../riformer-s24_32xb128_in1k_384.py'
model = dict(backbone=dict(deploy=True))

View File

@ -0,0 +1,3 @@
_base_ = '../riformer-s36_32xb128_in1k.py'
model = dict(backbone=dict(deploy=True))

View File

@ -0,0 +1,3 @@
_base_ = '../riformer-s36_32xb128_in1k_384.py'
model = dict(backbone=dict(deploy=True))

View File

@ -0,0 +1,183 @@
Collections:
- Name: RIFormer
Metadata:
Training Data: ImageNet-1k
Architecture:
- Affine
- 1x1 Convolution
- LayerScale
Paper:
URL: https://arxiv.org/abs/xxxx.xxxxx
Title: 'RIFormer: Keep Your Vision Backbone Effective But Removing Token Mixer'
README: configs/riformer/README.md
Code:
Version: v1.0.rc6
URL:
Models:
- Name: riformer-s12_in1k
Metadata:
FLOPs: 1822000000
Parameters: 11915000
In Collection: RIFormer
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 76.90
Top 5 Accuracy: 93.06
Task: Image Classification
Weights:
Config: configs/riformer/riformer-s12_32xb128_in1k.py
Converted From:
Weights:
Code:
- Name: riformer-s24_in1k
Metadata:
Training Data: ImageNet-1k
FLOPs: 3412000000
Parameters: 21389000
In Collection: RIFormer
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 80.28
Top 5 Accuracy: 94.80
Task: Image Classification
Weights:
Config: configs/riformer/riformer-s24_32xb128_in1k.py
Converted From:
Weights:
Code:
- Name: riformer-s36_in1k
Metadata:
FLOPs: 5003000000
Parameters: 30863000
In Collection: RIFormer
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 81.29
Top 5 Accuracy: 95.41
Task: Image Classification
Weights:
Config: configs/riformer/riformer-s36_32xb128_in1k.py
Converted From:
Weights:
Code:
- Name: riformer-m36_in1k
Metadata:
Training Data: ImageNet-1k
FLOPs: 8801000000
Parameters: 56173000
In Collection: RIFormer
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 82.57
Top 5 Accuracy: 95.99
Task: Image Classification
Weights:
Config: configs/riformer/riformer-m36_32xb128_in1k.py
Converted From:
Weights:
Code:
- Name: riformer-m48_in1k
Metadata:
FLOPs: 11590000000
Parameters: 73473000
In Collection: RIFormer
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 82.75
Top 5 Accuracy: 96.11
Task: Image Classification
Weights:
Config: configs/riformer/riformer-m48_32xb128_in1k.py
Converted From:
Weights:
Code:
- Name: riformer-s12_384_in1k
Metadata:
FLOPs: 5355000000
Parameters: 11915000
In Collection: RIFormer
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 78.29
Top 5 Accuracy: 93.93
Task: Image Classification
Weights:
Config: configs/riformer/riformer-s12_32xb128_in1k_384.py
Converted From:
Weights:
Code:
- Name: riformer-s24_384_in1k
Metadata:
Training Data: ImageNet-1k
FLOPs: 10028000000
Parameters: 21389000
In Collection: RIFormer
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 81.36
Top 5 Accuracy: 95.40
Task: Image Classification
Weights:
Config: configs/riformer/riformer-s24_32xb128_in1k_384.py
Converted From:
Weights:
Code:
- Name: riformer-s36_384_in1k
Metadata:
FLOPs: 14702000000
Parameters: 30863000
In Collection: RIFormer
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 82.22
Top 5 Accuracy: 95.95
Task: Image Classification
Weights:
Config: configs/riformer/riformer-s36_32xb128_in1k_384.py
Converted From:
Weights:
Code:
- Name: riformer-m36_384_in1k
Metadata:
Training Data: ImageNet-1k
FLOPs: 25865000000
Parameters: 56173000
In Collection: RIFormer
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.39
Top 5 Accuracy: 96.40
Task: Image Classification
Weights:
Config: configs/riformer/riformer-m36_32xb128_in1k_384.py
Converted From:
Weights:
Code:
- Name: riformer-m48_384_in1k
Metadata:
FLOPs: 34060000000
Parameters: 73473000
In Collection: RIFormer
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.70
Top 5 Accuracy: 96.60
Task: Image Classification
Weights:
Config: configs/riformer/riformer-m48_32xb128_in1k_384.py
Converted From:
Weights:
Code:

View File

@ -0,0 +1,17 @@
_base_ = [
'../_base_/models/riformer/riformer_m36.py',
'../_base_/datasets/imagenet_bs128_poolformer_medium_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
# schedule settings
optim_wrapper = dict(
optimizer=dict(lr=4e-3),
clip_grad=dict(max_norm=5.0),
)
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
# base_batch_size = (32 GPUs) x (128 samples per GPU)
auto_scale_lr = dict(base_batch_size=4096)
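As a concrete illustration of the note above, the linear scaling rule applied when automatic LR scaling is enabled amounts to the following arithmetic (a sketch, not part of the config; the GPU count is an example):

```python
# Sketch: linear LR scaling relative to the reference batch size above.
base_lr = 4e-3
base_batch_size = 32 * 128        # 4096, as configured
actual_batch_size = 8 * 128       # e.g. 8 GPUs x 128 samples per GPU
scaled_lr = base_lr * actual_batch_size / base_batch_size
print(scaled_lr)                  # 0.001
```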

View File

@ -0,0 +1,17 @@
_base_ = [
'../_base_/models/riformer/riformer_m36.py',
'../_base_/datasets/imagenet_bs128_riformer_medium_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
# schedule settings
optim_wrapper = dict(
optimizer=dict(lr=4e-3),
clip_grad=dict(max_norm=5.0),
)
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
# base_batch_size = (32 GPUs) x (128 samples per GPU)
auto_scale_lr = dict(base_batch_size=4096)

View File

@ -0,0 +1,17 @@
_base_ = [
'../_base_/models/riformer/riformer_m48.py',
'../_base_/datasets/imagenet_bs128_poolformer_medium_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
# schedule settings
optim_wrapper = dict(
optimizer=dict(lr=4e-3),
clip_grad=dict(max_norm=5.0),
)
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
# base_batch_size = (32 GPUs) x (128 samples per GPU)
auto_scale_lr = dict(base_batch_size=4096)

View File

@ -0,0 +1,17 @@
_base_ = [
'../_base_/models/riformer/riformer_m48.py',
'../_base_/datasets/imagenet_bs128_riformer_medium_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
# schedule settings
optim_wrapper = dict(
optimizer=dict(lr=4e-3),
clip_grad=dict(max_norm=5.0),
)
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
# base_batch_size = (32 GPUs) x (128 samples per GPU)
auto_scale_lr = dict(base_batch_size=4096)

View File

@ -0,0 +1,17 @@
_base_ = [
'../_base_/models/riformer/riformer_s12.py',
'../_base_/datasets/imagenet_bs128_poolformer_small_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
# schedule settings
optim_wrapper = dict(
optimizer=dict(lr=4e-3),
clip_grad=dict(max_norm=5.0),
)
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
# base_batch_size = (32 GPUs) x (128 samples per GPU)
auto_scale_lr = dict(base_batch_size=4096)

View File

@ -0,0 +1,17 @@
_base_ = [
'../_base_/models/riformer/riformer_s12.py',
'../_base_/datasets/imagenet_bs128_riformer_small_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
# schedule settings
optim_wrapper = dict(
optimizer=dict(lr=4e-3),
clip_grad=dict(max_norm=5.0),
)
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
# base_batch_size = (32 GPUs) x (128 samples per GPU)
auto_scale_lr = dict(base_batch_size=4096)

View File

@ -0,0 +1,17 @@
_base_ = [
'../_base_/models/riformer/riformer_s24.py',
'../_base_/datasets/imagenet_bs128_poolformer_small_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
# schedule settings
optim_wrapper = dict(
optimizer=dict(lr=4e-3),
clip_grad=dict(max_norm=5.0),
)
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
# base_batch_size = (32 GPUs) x (128 samples per GPU)
auto_scale_lr = dict(base_batch_size=4096)

View File

@ -0,0 +1,17 @@
_base_ = [
'../_base_/models/riformer/riformer_s24.py',
'../_base_/datasets/imagenet_bs128_riformer_small_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
# schedule settings
optim_wrapper = dict(
optimizer=dict(lr=4e-3),
clip_grad=dict(max_norm=5.0),
)
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
# base_batch_size = (32 GPUs) x (128 samples per GPU)
auto_scale_lr = dict(base_batch_size=4096)

View File

@ -0,0 +1,17 @@
_base_ = [
'../_base_/models/riformer/riformer_s36.py',
'../_base_/datasets/imagenet_bs128_poolformer_small_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
# schedule settings
optim_wrapper = dict(
optimizer=dict(lr=4e-3),
clip_grad=dict(max_norm=5.0),
)
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
# base_batch_size = (32 GPUs) x (128 samples per GPU)
auto_scale_lr = dict(base_batch_size=4096)

View File

@ -0,0 +1,17 @@
_base_ = [
'../_base_/models/riformer/riformer_s36.py',
'../_base_/datasets/imagenet_bs128_riformer_small_384.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
# schedule settings
optim_wrapper = dict(
optimizer=dict(lr=4e-3),
clip_grad=dict(max_norm=5.0),
)
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
# base_batch_size = (32 GPUs) x (128 samples per GPU)
auto_scale_lr = dict(base_batch_size=4096)

View File

@ -26,6 +26,7 @@ from .mobileone import MobileOne
from .mobilevit import MobileViT
from .mvit import MViT
from .poolformer import PoolFormer
+from .riformer import RIFormer
from .regnet import RegNet
from .replknet import RepLKNet
from .repmlp import RepMLPNet
@ -95,6 +96,7 @@ __all__ = [
'RepLKNet',
'RepMLPNet',
'PoolFormer',
+'RIFormer',
'DenseNet',
'VAN',
'InceptionV3',

View File

@ -0,0 +1,498 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
import torch
import torch.nn as nn
from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone
class PatchEmbed(nn.Module):
"""Patch Embedding module implemented by a layer of convolution.
Input: tensor in shape [B, C, H, W]
Output: tensor in shape [B, C, H/stride, W/stride]
Args:
patch_size (int): Patch size of the patch embedding. Defaults to 16.
stride (int): Stride of the patch embedding. Defaults to 16.
padding (int): Padding of the patch embedding. Defaults to 0.
in_chans (int): Input channels. Defaults to 3.
embed_dim (int): Output dimension of the patch embedding.
Defaults to 768.
norm_layer (module): Normalization module. Defaults to None (not use).
"""
def __init__(self,
patch_size=16,
stride=16,
padding=0,
in_chans=3,
embed_dim=768,
norm_layer=None):
super().__init__()
self.proj = nn.Conv2d(
in_chans,
embed_dim,
kernel_size=patch_size,
stride=stride,
padding=padding)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class Pooling(nn.Module):
"""Pooling module.
Args:
pool_size (int): Pooling size. Defaults to 3.
"""
def __init__(self, pool_size=3):
super().__init__()
self.pool = nn.AvgPool2d(
pool_size,
stride=1,
padding=pool_size // 2,
count_include_pad=False)
def forward(self, x):
return self.pool(x) - x
class Affine(nn.Module):
"""Affine Transformation module.
Args:
in_features (int): Input dimension. Defaults to None.
"""
def __init__(self, in_features=None):
super().__init__()
self.affine = nn.Conv2d(
in_features,
in_features,
kernel_size=1,
stride=1,
padding=0,
groups=in_features,
bias=True)
def forward(self, x):
return self.affine(x) - x
class Mlp(nn.Module):
"""Mlp implemented by with 1*1 convolutions.
Input: Tensor with shape [B, C, H, W].
Output: Tensor with shape [B, C, H, W].
Args:
in_features (int): Dimension of input features.
hidden_features (int): Dimension of hidden features.
out_features (int): Dimension of output features.
act_cfg (dict): The config dict for activation between pointwise
convolution. Defaults to ``dict(type='GELU')``.
drop (float): Dropout rate. Defaults to 0.0.
"""
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_cfg=dict(type='GELU'),
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
self.act = build_activation_layer(act_cfg)
self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class RIFormerBlock(BaseModule):
"""RIFormer Block.
Args:
dim (int): Embedding dim.
mlp_ratio (float): Mlp expansion ratio. Defaults to 4.
norm_cfg (dict): The config dict for norm layers.
Defaults to ``dict(type='GN', num_groups=1)``.
act_cfg (dict): The config dict for activation between pointwise
convolution. Defaults to ``dict(type='GELU')``.
drop (float): Dropout rate. Defaults to 0.
drop_path (float): Stochastic depth rate. Defaults to 0.
layer_scale_init_value (float): Init value for Layer Scale.
Defaults to 1e-5.
deploy (bool): Whether to switch the model structure to
deployment mode. Default: False.
"""
def __init__(self,
dim,
mlp_ratio=4.,
norm_cfg=dict(type='GN', num_groups=1),
act_cfg=dict(type='GELU'),
drop=0.,
drop_path=0.,
layer_scale_init_value=1e-5,
deploy=False):
super().__init__()
if deploy:
self.norm_reparam = build_norm_layer(norm_cfg, dim)[1]
else:
self.norm1 = build_norm_layer(norm_cfg, dim)[1]
self.token_mixer = Affine(in_features=dim)
self.norm2 = build_norm_layer(norm_cfg, dim)[1]
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_cfg=act_cfg,
drop=drop)
# The following two techniques are useful to train deep RIFormers.
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.layer_scale_1 = nn.Parameter(
layer_scale_init_value * torch.ones((dim)), requires_grad=True)
self.layer_scale_2 = nn.Parameter(
layer_scale_init_value * torch.ones((dim)), requires_grad=True)
self.norm_cfg = norm_cfg
self.dim = dim
self.deploy = deploy
def forward(self, x):
if hasattr(self, 'norm_reparam'):
x = x + self.drop_path(
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) *
self.norm_reparam(x))
x = x + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) *
self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) *
self.token_mixer(self.norm1(x)))
x = x + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) *
self.mlp(self.norm2(x)))
return x
def fuse_affine(self, norm, token_mixer):
gamma_affn = token_mixer.affine.weight.reshape(-1)
gamma_affn = gamma_affn - torch.ones_like(gamma_affn)
beta_affn = token_mixer.affine.bias
gamma_ln = norm.weight
beta_ln = norm.bias
print('gamma_affn:', gamma_affn.shape)
print('beta_affn:', beta_affn.shape)
print('gamma_ln:', gamma_ln.shape)
print('beta_ln:', beta_ln.shape)
return (gamma_ln * gamma_affn), (beta_ln * gamma_affn + beta_affn)
def get_equivalent_scale_bias(self):
eq_s, eq_b = self.fuse_affine(self.norm1, self.token_mixer)
return eq_s, eq_b
def switch_to_deploy(self):
if self.deploy:
return
eq_s, eq_b = self.get_equivalent_scale_bias()
self.norm_reparam = build_norm_layer(self.norm_cfg, self.dim)[1]
self.norm_reparam.weight.data = eq_s
self.norm_reparam.bias.data = eq_b
self.__delattr__('norm1')
if hasattr(self, 'token_mixer'):
self.__delattr__('token_mixer')
self.deploy = True
def basic_blocks(dim,
index,
layers,
mlp_ratio=4.,
norm_cfg=dict(type='GN', num_groups=1),
act_cfg=dict(type='GELU'),
drop_rate=.0,
drop_path_rate=0.,
layer_scale_init_value=1e-5,
deploy=False):
"""
generate RIFormer blocks for a stage
return: RIFormer blocks
"""
blocks = []
for block_idx in range(layers[index]):
block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (
sum(layers) - 1)
blocks.append(
RIFormerBlock(
dim,
mlp_ratio=mlp_ratio,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
drop=drop_rate,
drop_path=block_dpr,
layer_scale_init_value=layer_scale_init_value,
deploy=deploy,
))
blocks = nn.Sequential(*blocks)
return blocks
@MODELS.register_module()
class RIFormer(BaseBackbone):
"""RIFormer.
A PyTorch implementation of RIFormer introduced by:
`RIFormer: Keep Your Vision Backbone Effective But Removing Token Mixer <https://arxiv.org/abs/xxxx.xxxxx>`_
Modified from the `official repo
<https://github.com/techmonsterwang/RIFormer.py>`.
Args:
arch (str | dict): The model's architecture. If string, it should be
one of the architectures in ``RIFormer.arch_settings``. If a dict, it
should include the following keys:
- layers (list[int]): Number of blocks at each stage.
- embed_dims (list[int]): The number of channels at each stage.
- mlp_ratios (list[int]): Expansion ratio of MLPs.
- layer_scale_init_value (float): Init value for Layer Scale.
Defaults to 'S12'.
norm_cfg (dict): The config dict for norm layers.
Defaults to ``dict(type='LN2d', eps=1e-6)``.
act_cfg (dict): The config dict for activation between pointwise
convolution. Defaults to ``dict(type='GELU')``.
in_patch_size (int): The patch size of input image patch embedding.
Defaults to 7.
in_stride (int): The stride of input image patch embedding.
Defaults to 4.
in_pad (int): The padding of input image patch embedding.
Defaults to 2.
down_patch_size (int): The patch size of downsampling patch embedding.
Defaults to 3.
down_stride (int): The stride of downsampling patch embedding.
Defaults to 2.
down_pad (int): The padding of downsampling patch embedding.
Defaults to 1.
drop_rate (float): Dropout rate. Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
out_indices (Sequence | int): Output from which network position.
Index 0-6 respectively corresponds to
[stage1, downsampling, stage2, downsampling, stage3, downsampling, stage4]
Defaults to -1, means the last stage.
frozen_stages (int): Stages to be frozen (all param fixed).
Defaults to 0, which means not freezing any parameters.
deploy (bool): Whether to switch the model structure to
deployment mode. Default: False.
init_cfg (dict, optional): Initialization config dict
""" # noqa: E501
# --layers: [x,x,x,x], numbers of layers for the four stages
# --embed_dims, --mlp_ratios:
# embedding dims and mlp ratios for the four stages
# --downsamples: flags to apply downsampling or not in four blocks
arch_settings = {
's12': {
'layers': [2, 2, 6, 2],
'embed_dims': [64, 128, 320, 512],
'mlp_ratios': [4, 4, 4, 4],
'layer_scale_init_value': 1e-5,
},
's24': {
'layers': [4, 4, 12, 4],
'embed_dims': [64, 128, 320, 512],
'mlp_ratios': [4, 4, 4, 4],
'layer_scale_init_value': 1e-5,
},
's36': {
'layers': [6, 6, 18, 6],
'embed_dims': [64, 128, 320, 512],
'mlp_ratios': [4, 4, 4, 4],
'layer_scale_init_value': 1e-6,
},
'm36': {
'layers': [6, 6, 18, 6],
'embed_dims': [96, 192, 384, 768],
'mlp_ratios': [4, 4, 4, 4],
'layer_scale_init_value': 1e-6,
},
'm48': {
'layers': [8, 8, 24, 8],
'embed_dims': [96, 192, 384, 768],
'mlp_ratios': [4, 4, 4, 4],
'layer_scale_init_value': 1e-6,
},
}
def __init__(self,
arch='s12',
norm_cfg=dict(type='GN', num_groups=1),
act_cfg=dict(type='GELU'),
in_patch_size=7,
in_stride=4,
in_pad=2,
down_patch_size=3,
down_stride=2,
down_pad=1,
drop_rate=0.,
drop_path_rate=0.,
out_indices=-1,
frozen_stages=0,
init_cfg=None,
deploy=False):
super().__init__(init_cfg=init_cfg)
if isinstance(arch, str):
assert arch in self.arch_settings, \
f'Unavailable arch, please choose from ' \
f'({set(self.arch_settings)}) or pass a dict.'
arch = self.arch_settings[arch]
elif isinstance(arch, dict):
assert 'layers' in arch and 'embed_dims' in arch, \
f'The arch dict must have "layers" and "embed_dims", ' \
f'but got {list(arch.keys())}.'
layers = arch['layers']
embed_dims = arch['embed_dims']
mlp_ratios = arch['mlp_ratios'] \
if 'mlp_ratios' in arch else [4, 4, 4, 4]
layer_scale_init_value = arch['layer_scale_init_value'] \
if 'layer_scale_init_value' in arch else 1e-5
self.patch_embed = PatchEmbed(
patch_size=in_patch_size,
stride=in_stride,
padding=in_pad,
in_chans=3,
embed_dim=embed_dims[0])
# set the main block in network
network = []
for i in range(len(layers)):
stage = basic_blocks(
embed_dims[i],
i,
layers,
mlp_ratio=mlp_ratios[i],
norm_cfg=norm_cfg,
act_cfg=act_cfg,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
layer_scale_init_value=layer_scale_init_value,
deploy=deploy)
network.append(stage)
if i >= len(layers) - 1:
break
if embed_dims[i] != embed_dims[i + 1]:
# downsampling between two stages
network.append(
PatchEmbed(
patch_size=down_patch_size,
stride=down_stride,
padding=down_pad,
in_chans=embed_dims[i],
embed_dim=embed_dims[i + 1]))
self.network = nn.ModuleList(network)
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must by a sequence or int, ' \
f'get {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = 7 + index
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
self.out_indices = out_indices
if self.out_indices:
for i_layer in self.out_indices:
layer = build_norm_layer(norm_cfg,
embed_dims[(i_layer + 1) // 2])[1]
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
self.frozen_stages = frozen_stages
self._freeze_stages()
self.deploy = deploy
def forward_embeddings(self, x):
x = self.patch_embed(x)
return x
def forward_tokens(self, x):
outs = []
for idx, block in enumerate(self.network):
x = block(x)
if idx in self.out_indices:
norm_layer = getattr(self, f'norm{idx}')
x_out = norm_layer(x)
outs.append(x_out)
return tuple(outs)
def forward(self, x):
# input embedding
x = self.forward_embeddings(x)
# through backbone
x = self.forward_tokens(x)
return x
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
for i in range(self.frozen_stages):
# Include both block and downsample layer.
module = self.network[i]
module.eval()
for param in module.parameters():
param.requires_grad = False
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
norm_layer.eval()
for param in norm_layer.parameters():
param.requires_grad = False
def train(self, mode=True):
super(RIFormer, self).train(mode)
self._freeze_stages()
def switch_to_deploy(self):
for m in self.modules():
if isinstance(m, RIFormerBlock):
m.switch_to_deploy()
self.deploy = True
if __name__ == '__main__':
model = RIFormer(arch='s12', deploy=False)
model.eval()
print('------------------- training-time model -------------')
for i in model.state_dict().keys():
print(i)

View File

@ -0,0 +1,169 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from unittest import TestCase
import torch
import torch.nn as nn
from mmpretrain.models.backbones import RIFormer
from mmpretrain.models.backbones.riformer import RIFormerBlock
class TestRIFormer(TestCase):
def setUp(self):
arch = 's12'
self.cfg = dict(arch=arch, drop_path_rate=0.1)
self.arch = RIFormer.arch_settings[arch]
def test_arch(self):
# Test invalid default arch
with self.assertRaisesRegex(AssertionError, 'Unavailable arch'):
cfg = deepcopy(self.cfg)
cfg['arch'] = 'unknown'
RIFormer(**cfg)
# Test invalid custom arch
with self.assertRaisesRegex(AssertionError, 'must have "layers"'):
cfg = deepcopy(self.cfg)
cfg['arch'] = {
'embed_dims': 96,
'num_heads': [3, 6, 12, 16],
}
RIFormer(**cfg)
# Test custom arch
cfg = deepcopy(self.cfg)
layers = [2, 2, 4, 2]
embed_dims = [6, 12, 6, 12]
mlp_ratios = [2, 3, 4, 4]
layer_scale_init_value = 1e-4
cfg['arch'] = dict(
layers=layers,
embed_dims=embed_dims,
mlp_ratios=mlp_ratios,
layer_scale_init_value=layer_scale_init_value,
)
model = RIFormer(**cfg)
for i, stage in enumerate(model.network):
if not isinstance(stage, RIFormerBlock):
continue
self.assertEqual(len(stage), layers[i])
self.assertEqual(stage[0].mlp.fc1.in_channels, embed_dims[i])
self.assertEqual(stage[0].mlp.fc1.out_channels,
embed_dims[i] * mlp_ratios[i])
self.assertTrue(
torch.allclose(stage[0].layer_scale_1,
torch.tensor(layer_scale_init_value)))
self.assertTrue(
torch.allclose(stage[0].layer_scale_2,
torch.tensor(layer_scale_init_value)))
def test_init_weights(self):
# test weight init cfg
cfg = deepcopy(self.cfg)
cfg['init_cfg'] = [
dict(
type='Kaiming',
layer='Conv2d',
mode='fan_in',
nonlinearity='linear')
]
model = RIFormer(**cfg)
ori_weight = model.patch_embed.proj.weight.clone().detach()
model.init_weights()
initialized_weight = model.patch_embed.proj.weight
self.assertFalse(torch.allclose(ori_weight, initialized_weight))
def test_forward(self):
imgs = torch.randn(1, 3, 224, 224)
cfg = deepcopy(self.cfg)
model = RIFormer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
feat = outs[-1]
self.assertEqual(feat.shape, (1, 512, 7, 7))
# test multiple output indices
cfg = deepcopy(self.cfg)
cfg['out_indices'] = (0, 2, 4, 6)
model = RIFormer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 4)
for dim, stride, out in zip(self.arch['embed_dims'], [1, 2, 4, 8],
outs):
self.assertEqual(out.shape, (1, dim, 56 // stride, 56 // stride))
def test_reparameterization(self):
# Test eval of "train" mode and "deploy" mode
imgs = torch.randn(1, 3, 224, 224)
gap = nn.AdaptiveAvgPool2d(output_size=(1))
fc = nn.Linear(self.arch['embed_dims'][3], 10)
cfg = deepcopy(self.cfg)
cfg['out_indices'] = (0, 2, 4, 6)
model = RIFormer(**cfg)
model.eval()
feats = model(imgs)
self.assertIsInstance(feats, tuple)
feat = feats[-1]
pred = fc(gap(feat).flatten(1))
model.switch_to_deploy()
for m in model.modules():
if isinstance(m, RIFormerBlock):
assert m.deploy is True
feats_deploy = model(imgs)
pred_deploy = fc(gap(feats_deploy[-1]).flatten(1))
for i in range(4):
self.assertTrue(torch.allclose(feats[i], feats_deploy[i], atol=1e-5))
self.assertTrue(torch.allclose(pred, pred_deploy, atol=1e-5))
def test_structure(self):
# test drop_path_rate decay
cfg = deepcopy(self.cfg)
cfg['drop_path_rate'] = 0.2
model = RIFormer(**cfg)
layers = self.arch['layers']
for i, block in enumerate(model.network):
expect_prob = 0.2 / (sum(layers) - 1) * i
if hasattr(block, 'drop_path'):
if expect_prob == 0:
self.assertIsInstance(block.drop_path, torch.nn.Identity)
else:
self.assertAlmostEqual(block.drop_path.drop_prob,
expect_prob)
# test with first stage frozen.
cfg = deepcopy(self.cfg)
frozen_stages = 1
cfg['frozen_stages'] = frozen_stages
cfg['out_indices'] = (0, 2, 4, 6)
model = RIFormer(**cfg)
model.init_weights()
model.train()
# the patch_embed and first stage should not require grad.
self.assertFalse(model.patch_embed.training)
for param in model.patch_embed.parameters():
self.assertFalse(param.requires_grad)
for i in range(frozen_stages):
module = model.network[i]
for param in module.parameters():
self.assertFalse(param.requires_grad)
for param in model.norm0.parameters():
self.assertFalse(param.requires_grad)
# the second stage should require grad.
for i in range(frozen_stages + 1, 7):
module = model.network[i]
for param in module.parameters():
self.assertTrue(param.requires_grad)
if hasattr(model, f'norm{i}'):
norm = getattr(model, f'norm{i}')
for param in norm.parameters():
self.assertTrue(param.requires_grad)

View File

@ -4,8 +4,8 @@ from pathlib import Path
import torch
-from mmpretrain.apis import init_model
+from mmcls.apis import init_model
-from mmpretrain.models.classifiers import ImageClassifier
+from mmcls.models.classifiers import ImageClassifier
def convert_classifier_to_deploy(model, checkpoint, save_path):
@ -39,7 +39,7 @@ def main():
args = parser.parse_args()
save_path = Path(args.save_path)
-if save_path.suffix != '.pth':
+if save_path.suffix != '.pth' and save_path.suffix != '.tar':
print('The path should contain the name of the pth format file.')
exit()
save_path.parent.mkdir(parents=True, exist_ok=True)
@ -47,7 +47,7 @@ def main():
model = init_model(
args.config_path, checkpoint=args.checkpoint_path, device='cpu')
assert isinstance(model, ImageClassifier), \
-'`model` must be a `mmpretrain.classifiers.ImageClassifier` instance.'
+'`model` must be a `mmcls.classifiers.ImageClassifier` instance.'
checkpoint = torch.load(args.checkpoint_path)
convert_classifier_to_deploy(model, checkpoint, args.save_path)