[Feature] Support RepLKnet backbone. (#1129)
* update replknet configs * update replknet test * update replknet model * update replknet model * update replknet model * update replknet model * Fix docs and config names Co-authored-by: mzr1996 <mzr1996@163.com>pull/1177/head
parent
c3c1cb93aa
commit
72c6bc4864
|
@ -151,6 +151,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
|
||||||
- [x] [HorNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hornet)
|
- [x] [HorNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hornet)
|
||||||
- [x] [MobileViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilevit)
|
- [x] [MobileViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilevit)
|
||||||
- [x] [DaViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/davit)
|
- [x] [DaViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/davit)
|
||||||
|
- [x] [RepLKNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/replknet)
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
|
@ -150,6 +150,7 @@ mim install -e .
|
||||||
- [x] [HorNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hornet)
|
- [x] [HorNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hornet)
|
||||||
- [x] [MobileViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilevit)
|
- [x] [MobileViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilevit)
|
||||||
- [x] [DaViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/davit)
|
- [x] [DaViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/davit)
|
||||||
|
- [x] [RepLKNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/replknet)
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
# dataset settings
|
||||||
|
dataset_type = 'ImageNet'
|
||||||
|
data_preprocessor = dict(
|
||||||
|
# RGB format normalization parameters
|
||||||
|
mean=[123.675, 116.28, 103.53],
|
||||||
|
std=[58.395, 57.12, 57.375],
|
||||||
|
# convert image from BGR to RGB
|
||||||
|
to_rgb=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
train_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(
|
||||||
|
type='RandomResizedCrop',
|
||||||
|
scale=384,
|
||||||
|
backend='pillow',
|
||||||
|
interpolation='bicubic'),
|
||||||
|
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||||
|
dict(type='PackClsInputs'),
|
||||||
|
]
|
||||||
|
|
||||||
|
test_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(type='Resize', scale=384, backend='pillow', interpolation='bicubic'),
|
||||||
|
dict(type='PackClsInputs'),
|
||||||
|
]
|
||||||
|
|
||||||
|
train_dataloader = dict(
|
||||||
|
batch_size=16,
|
||||||
|
num_workers=5,
|
||||||
|
dataset=dict(
|
||||||
|
type=dataset_type,
|
||||||
|
data_root='data/imagenet',
|
||||||
|
ann_file='meta/train.txt',
|
||||||
|
data_prefix='train',
|
||||||
|
pipeline=train_pipeline),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||||
|
persistent_workers=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
val_dataloader = dict(
|
||||||
|
batch_size=16,
|
||||||
|
num_workers=5,
|
||||||
|
dataset=dict(
|
||||||
|
type=dataset_type,
|
||||||
|
data_root='data/imagenet',
|
||||||
|
ann_file='meta/val.txt',
|
||||||
|
data_prefix='val',
|
||||||
|
pipeline=test_pipeline),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||||
|
persistent_workers=True,
|
||||||
|
)
|
||||||
|
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||||
|
|
||||||
|
# If you want standard test, please manually configure the test dataset
|
||||||
|
test_dataloader = val_dataloader
|
||||||
|
test_evaluator = val_evaluator
|
|
@ -0,0 +1,63 @@
|
||||||
|
# dataset settings
|
||||||
|
dataset_type = 'ImageNet'
|
||||||
|
data_preprocessor = dict(
|
||||||
|
# RGB format normalization parameters
|
||||||
|
mean=[122.5, 122.5, 122.5],
|
||||||
|
std=[122.5, 122.5, 122.5],
|
||||||
|
# convert image from BGR to RGB
|
||||||
|
to_rgb=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
train_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(
|
||||||
|
type='RandomResizedCrop',
|
||||||
|
scale=320,
|
||||||
|
backend='pillow',
|
||||||
|
interpolation='bicubic'),
|
||||||
|
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||||
|
dict(type='PackClsInputs'),
|
||||||
|
]
|
||||||
|
|
||||||
|
test_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(
|
||||||
|
type='ResizeEdge',
|
||||||
|
scale=int(320 / 224 * 256),
|
||||||
|
edge='short',
|
||||||
|
backend='pillow',
|
||||||
|
interpolation='bicubic'),
|
||||||
|
dict(type='CenterCrop', crop_size=320),
|
||||||
|
dict(type='PackClsInputs'),
|
||||||
|
]
|
||||||
|
|
||||||
|
train_dataloader = dict(
|
||||||
|
batch_size=8,
|
||||||
|
num_workers=5,
|
||||||
|
dataset=dict(
|
||||||
|
type=dataset_type,
|
||||||
|
data_root='data/imagenet',
|
||||||
|
ann_file='meta/train.txt',
|
||||||
|
data_prefix='train',
|
||||||
|
pipeline=train_pipeline),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||||
|
persistent_workers=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
val_dataloader = dict(
|
||||||
|
batch_size=8,
|
||||||
|
num_workers=5,
|
||||||
|
dataset=dict(
|
||||||
|
type=dataset_type,
|
||||||
|
data_root='data/imagenet',
|
||||||
|
ann_file='meta/val.txt',
|
||||||
|
data_prefix='val',
|
||||||
|
pipeline=test_pipeline),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||||
|
persistent_workers=True,
|
||||||
|
)
|
||||||
|
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||||
|
|
||||||
|
# If you want standard test, please manually configure the test dataset
|
||||||
|
test_dataloader = val_dataloader
|
||||||
|
test_evaluator = val_evaluator
|
|
@ -0,0 +1,25 @@
|
||||||
|
from mmcls.models import build_classifier
|
||||||
|
|
||||||
|
model = dict(
|
||||||
|
type='ImageClassifier',
|
||||||
|
backbone=dict(
|
||||||
|
type='RepLKNet',
|
||||||
|
arch='31B',
|
||||||
|
out_indices=(3, ),
|
||||||
|
),
|
||||||
|
neck=dict(type='GlobalAveragePooling'),
|
||||||
|
head=dict(
|
||||||
|
type='LinearClsHead',
|
||||||
|
num_classes=1000,
|
||||||
|
in_channels=1024,
|
||||||
|
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||||
|
topk=(1, 5),
|
||||||
|
))
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# model.pop('type')
|
||||||
|
model = build_classifier(model)
|
||||||
|
model.eval()
|
||||||
|
print('------------------- training-time model -------------')
|
||||||
|
for i in model.state_dict().keys():
|
||||||
|
print(i)
|
|
@ -0,0 +1,15 @@
|
||||||
|
model = dict(
|
||||||
|
type='ImageClassifier',
|
||||||
|
backbone=dict(
|
||||||
|
type='RepLKNet',
|
||||||
|
arch='31L',
|
||||||
|
out_indices=(3, ),
|
||||||
|
),
|
||||||
|
neck=dict(type='GlobalAveragePooling'),
|
||||||
|
head=dict(
|
||||||
|
type='LinearClsHead',
|
||||||
|
num_classes=1000,
|
||||||
|
in_channels=1536,
|
||||||
|
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||||
|
topk=(1, 5),
|
||||||
|
))
|
|
@ -0,0 +1,15 @@
|
||||||
|
model = dict(
|
||||||
|
type='ImageClassifier',
|
||||||
|
backbone=dict(
|
||||||
|
type='RepLKNet',
|
||||||
|
arch='XL',
|
||||||
|
out_indices=(3, ),
|
||||||
|
),
|
||||||
|
neck=dict(type='GlobalAveragePooling'),
|
||||||
|
head=dict(
|
||||||
|
type='LinearClsHead',
|
||||||
|
num_classes=1000,
|
||||||
|
in_channels=2048,
|
||||||
|
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||||
|
topk=(1, 5),
|
||||||
|
))
|
|
@ -0,0 +1,95 @@
|
||||||
|
# RepLKNet
|
||||||
|
|
||||||
|
> [Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs](https://arxiv.org/abs/2203.06717)
|
||||||
|
|
||||||
|
<!-- [ALGORITHM] -->
|
||||||
|
|
||||||
|
## Abstract
|
||||||
|
|
||||||
|
We revisit large kernel design in modern convolutional neural networks (CNNs). Inspired by recent advances in vision transformers (ViTs), in this paper, we demonstrate that using a few large convolutional kernels instead of a stack of small kernels could be a more powerful paradigm. We suggested five guidelines, e.g., applying re-parameterized large depth-wise convolutions, to design efficient highperformance large-kernel CNNs. Following the guidelines, we propose RepLKNet, a pure CNN architecture whose kernel size is as large as 31×31, in contrast to commonly used 3×3. RepLKNet greatly closes the performance gap between CNNs and ViTs, e.g., achieving comparable or superior results than Swin Transformer on ImageNet and a few typical downstream tasks, with lower latency. RepLKNet also shows nice scalability to big data and large models, obtaining 87.8% top-1 accuracy on ImageNet and 56.0% mIoU on ADE20K, which is very competitive among the state-of-the-arts with similar model sizes. Our study further reveals that, in contrast to small-kernel CNNs, large kernel CNNs have much larger effective receptive fields and higher shape bias rather than texture bias.
|
||||||
|
|
||||||
|
<div align=center>
|
||||||
|
<img src="https://user-images.githubusercontent.com/48375204/197546040-cdf078c3-7fbd-400f-8b27-01668c8dfebf.png" width="60%"/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
## Results and models
|
||||||
|
|
||||||
|
### ImageNet-1k
|
||||||
|
|
||||||
|
| Model | Resolution | Pretrained Dataset | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|
||||||
|
| :------------: | :--------: | :----------------: | :-----------------------------: | :-----------------------------: | :-------: | :-------: | :------------------------------------: | :--------------------------------------: |
|
||||||
|
| RepLKNet-31B\* | 224x224 | From Scratch | 79.9(train) \| 79.5 (deploy) | 15.6 (train) \| 15.4 (deploy) | 83.48 | 96.57 | [config (train)](./replknet-31B_32xb64_in1k.py) \| [config (deploy)](./deploy/replknet-31B-deploy_32xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_3rdparty_in1k_20221118-fd08e268.pth) |
|
||||||
|
| RepLKNet-31B\* | 384x384 | From Scratch | 79.9(train) \| 79.5 (deploy) | 46.0 (train) \| 45.3 (deploy) | 84.84 | 97.34 | [config (train)](./replknet-31B_32xb64_in1k-384px.py) \| [config (deploy)](./deploy/replknet-31B-deploy_32xb64_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_3rdparty_in1k-384px_20221118-03a170ce.pth) |
|
||||||
|
| RepLKNet-31B\* | 224x224 | ImageNet-21K | 79.9(train) \| 79.5 (deploy) | 15.6 (train) \| 15.4 (deploy) | 85.20 | 97.56 | [config (train)](./replknet-31B_32xb64_in1k.py) \| [config (deploy)](./deploy/replknet-31B-deploy_32xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_in21k-pre_3rdparty_in1k_20221118-54ed5c46.pth) |
|
||||||
|
| RepLKNet-31B\* | 384x384 | ImageNet-21K | 79.9(train) \| 79.5 (deploy) | 46.0 (train) \| 45.3 (deploy) | 85.99 | 97.75 | [config (train)](./replknet-31B_32xb64_in1k-384px.py) \| [config (deploy)](./deploy/replknet-31B-deploy_32xb64_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_in21k-pre_3rdparty_in1k-384px_20221118-76c92b24.pth) |
|
||||||
|
| RepLKNet-31L\* | 384x384 | ImageNet-21K | 172.7(train) \| 172.0 (deploy) | 97.2 (train) \| 97.0 (deploy) | 86.63 | 98.00 | [config (train)](./replknet-31L_32xb64_in1k-384px.py) \| [config (deploy)](./deploy/replknet-31L-deploy_32xb64_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31L_in21k-pre_3rdparty_in1k-384px_20221118-dc3fc07c.pth) |
|
||||||
|
| RepLKNet-XL\* | 320x320 | MegData-73M | 335.4(train) \| 335.0 (deploy) | 129.6 (train) \| 129.0 (deploy) | 87.57 | 98.39 | [config (train)](./replknet-XL_32xb64_in1k-320px.py) \| [config (deploy)](./deploy/replknet-XL-deploy_32xb64_in1k-320px.py) | [model](https://download.openmmlab.com/mmclassification/v0/replknet/replknet-XL_meg73m-pre_3rdparty_in1k-320px_20221118-88259b1d.pth) |
|
||||||
|
|
||||||
|
*Models with * are converted from the [official repo](https://github.com/DingXiaoH/RepVGG). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
|
||||||
|
|
||||||
|
## How to use
|
||||||
|
|
||||||
|
The checkpoints provided are all `training-time` models. Use the reparameterize tool to switch them to more efficient `inference-time` architecture, which not only has fewer parameters but also less calculations.
|
||||||
|
|
||||||
|
### Use tool
|
||||||
|
|
||||||
|
Use provided tool to reparameterize the given model and save the checkpoint:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python tools/convert_models/reparameterize_model.py ${CFG_PATH} ${SRC_CKPT_PATH} ${TARGET_CKPT_PATH}
|
||||||
|
```
|
||||||
|
|
||||||
|
`${CFG_PATH}` is the config file, `${SRC_CKPT_PATH}` is the source chenpoint file, `${TARGET_CKPT_PATH}` is the target deploy weight file path.
|
||||||
|
|
||||||
|
To use reparameterized weights, the config file must switch to the deploy config files.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python tools/test.py ${Deploy_CFG} ${Deploy_Checkpoint} --metrics accuracy
|
||||||
|
```
|
||||||
|
|
||||||
|
### In the code
|
||||||
|
|
||||||
|
Use `backbone.switch_to_deploy()` or `classificer.backbone.switch_to_deploy()` to switch to the deploy mode. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from mmcls.models import build_backbone
|
||||||
|
|
||||||
|
backbone_cfg=dict(type='RepLKNet',arch='31B'),
|
||||||
|
backbone = build_backbone(backbone_cfg)
|
||||||
|
backbone.switch_to_deploy()
|
||||||
|
```
|
||||||
|
|
||||||
|
or
|
||||||
|
|
||||||
|
```python
|
||||||
|
from mmcls.models import build_classifier
|
||||||
|
|
||||||
|
cfg = dict(
|
||||||
|
type='ImageClassifier',
|
||||||
|
backbone=dict(
|
||||||
|
type='RepLKNet',
|
||||||
|
arch='31B'),
|
||||||
|
neck=dict(type='GlobalAveragePooling'),
|
||||||
|
head=dict(
|
||||||
|
type='LinearClsHead',
|
||||||
|
num_classes=1000,
|
||||||
|
in_channels=1024,
|
||||||
|
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||||
|
topk=(1, 5),
|
||||||
|
))
|
||||||
|
|
||||||
|
classifier = build_classifier(cfg)
|
||||||
|
classifier.backbone.switch_to_deploy()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
```
|
||||||
|
@inproceedings{ding2022scaling,
|
||||||
|
title={Scaling up your kernels to 31x31: Revisiting large kernel design in cnns},
|
||||||
|
author={Ding, Xiaohan and Zhang, Xiangyu and Han, Jungong and Ding, Guiguang},
|
||||||
|
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
||||||
|
pages={11963--11975},
|
||||||
|
year={2022}
|
||||||
|
}
|
||||||
|
```
|
|
@ -0,0 +1,3 @@
|
||||||
|
_base_ = '../replknet-31B_32xb64_in1k-384px.py'
|
||||||
|
|
||||||
|
model = dict(backbone=dict(small_kernel_merged=True))
|
|
@ -0,0 +1,3 @@
|
||||||
|
_base_ = '../replknet-31B_32xb64_in1k.py'
|
||||||
|
|
||||||
|
model = dict(backbone=dict(small_kernel_merged=True))
|
|
@ -0,0 +1,3 @@
|
||||||
|
_base_ = '../replknet-31L_32xb64_in1k-384px.py'
|
||||||
|
|
||||||
|
model = dict(backbone=dict(small_kernel_merged=True))
|
|
@ -0,0 +1,3 @@
|
||||||
|
_base_ = '../replknet-XL_32xb64_in1k-320px.py'
|
||||||
|
|
||||||
|
model = dict(backbone=dict(small_kernel_merged=True))
|
|
@ -0,0 +1,129 @@
|
||||||
|
Collections:
|
||||||
|
- Name: RepLKNet
|
||||||
|
Metadata:
|
||||||
|
Training Data: ImageNet-1k
|
||||||
|
Architecture:
|
||||||
|
- Large-Kernel Convolution
|
||||||
|
- VGG-style Neural Network
|
||||||
|
Paper:
|
||||||
|
URL: https://arxiv.org/abs/2203.06717
|
||||||
|
Title: 'Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs'
|
||||||
|
README: configs/replknet/README.md
|
||||||
|
Code:
|
||||||
|
URL: https://github.com/open-mmlab/mmclassification/blob/v1.0.0rc3/mmcls/models/backbones/replknet.py
|
||||||
|
Version: v1.0.0rc3
|
||||||
|
|
||||||
|
Models:
|
||||||
|
- Name: replknet-31B_3rdparty_in1k
|
||||||
|
In Collection: RepLKNet
|
||||||
|
Config: configs/replknet/replknet-31B_32xb64_in1k.py
|
||||||
|
Metadata:
|
||||||
|
FLOPs: 15636547584
|
||||||
|
Parameters: 79864168
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Task: Image Classification
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 83.48
|
||||||
|
Top 5 Accuracy: 96.57
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_3rdparty_in1k_20221118-fd08e268.pth
|
||||||
|
Converted From:
|
||||||
|
Weights: https://drive.google.com/u/0/uc?id=1azQUiCxK9feYVkkrPqwVPBtNsTzDrX7S&export=download
|
||||||
|
Code: https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/replknet.py
|
||||||
|
|
||||||
|
- Name: replknet-31B_3rdparty_in1k-384px
|
||||||
|
In Collection: RepLKNet
|
||||||
|
Config: configs/replknet/replknet-31B_32xb64_in1k-384px.py
|
||||||
|
Metadata:
|
||||||
|
FLOPs: 45952303104
|
||||||
|
Parameters: 79864168
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Task: Image Classification
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 84.84
|
||||||
|
Top 5 Accuracy: 97.34
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_3rdparty_in1k-384px_20221118-03a170ce.pth
|
||||||
|
Converted From:
|
||||||
|
Weights: https://drive.google.com/u/0/uc?id=1vo-P3XB6mRLUeDzmgv90dOu73uCeLfZN&export=download
|
||||||
|
Code: https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/replknet.py
|
||||||
|
|
||||||
|
- Name: replknet-31B_in21k-pre_3rdparty_in1k
|
||||||
|
In Collection: RepLKNet
|
||||||
|
Config: configs/replknet/replknet-31B_32xb64_in1k.py
|
||||||
|
Metadata:
|
||||||
|
Training Data:
|
||||||
|
- ImageNet-21k
|
||||||
|
- ImageNet-1k
|
||||||
|
FLOPs: 15636547584
|
||||||
|
Parameters: 79864168
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Task: Image Classification
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 85.20
|
||||||
|
Top 5 Accuracy: 97.56
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_in21k-pre_3rdparty_in1k_20221118-54ed5c46.pth
|
||||||
|
Converted From:
|
||||||
|
Weights: https://drive.google.com/u/0/uc?id=1DslZ2voXZQR1QoFY9KnbsHAeF84hzS0s&export=download
|
||||||
|
Code: https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/replknet.py
|
||||||
|
|
||||||
|
- Name: replknet-31B_in21k-pre_3rdparty_in1k-384px
|
||||||
|
In Collection: RepLKNet
|
||||||
|
Config: configs/replknet/replknet-31B_32xb64_in1k-384px.py
|
||||||
|
Metadata:
|
||||||
|
Training Data:
|
||||||
|
- ImageNet-21k
|
||||||
|
- ImageNet-1k
|
||||||
|
FLOPs: 45952303104
|
||||||
|
Parameters: 79864168
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Task: Image Classification
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 85.99
|
||||||
|
Top 5 Accuracy: 97.75
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_in21k-pre_3rdparty_in1k-384px_20221118-76c92b24.pth
|
||||||
|
Converted From:
|
||||||
|
Weights: https://drive.google.com/u/0/uc?id=1Sc46BWdXXm2fVP-K_hKKU_W8vAB-0duX&export=download
|
||||||
|
Code: https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/replknet.py
|
||||||
|
|
||||||
|
- Name: replknet-31L_in21k-pre_3rdparty_in1k-384px
|
||||||
|
In Collection: RepLKNet
|
||||||
|
Config: configs/replknet/replknet-31L_32xb64_in1k-384px.py
|
||||||
|
Metadata:
|
||||||
|
Training Data:
|
||||||
|
- ImageNet-21k
|
||||||
|
- ImageNet-1k
|
||||||
|
FLOPs: 97240006656
|
||||||
|
Parameters: 172671016
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Task: Image Classification
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 86.63
|
||||||
|
Top 5 Accuracy: 98.00
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31L_in21k-pre_3rdparty_in1k-384px_20221118-dc3fc07c.pth
|
||||||
|
Converted From:
|
||||||
|
Weights: https://drive.google.com/u/0/uc?id=1JYXoNHuRvC33QV1pmpzMTKEni1hpWfBl&export=download
|
||||||
|
Code: https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/replknet.py
|
||||||
|
|
||||||
|
- Name: replknet-XL_meg73m-pre_3rdparty_in1k-320px
|
||||||
|
In Collection: RepLKNet
|
||||||
|
Config: configs/replknet/replknet-XL_32xb64_in1k-320px.py
|
||||||
|
Metadata:
|
||||||
|
Training Data:
|
||||||
|
- MegData-73M
|
||||||
|
- ImageNet-1k
|
||||||
|
FLOPs: 129570201600
|
||||||
|
Parameters: 335435752
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Task: Image Classification
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 87.57
|
||||||
|
Top 5 Accuracy: 98.39
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/replknet/replknet-XL_meg73m-pre_3rdparty_in1k-320px_20221118-88259b1d.pth
|
||||||
|
Converted From:
|
||||||
|
Weights: https://drive.google.com/u/0/uc?id=1tPC60El34GntXByIRHb-z-Apm4Y5LX1T&export=download
|
||||||
|
Code: https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/replknet.py
|
|
@ -0,0 +1,12 @@
|
||||||
|
_base_ = [
|
||||||
|
'../_base_/models/replknet-31B_in1k.py',
|
||||||
|
'../_base_/datasets/imagenet_bs16_pil_bicubic_384.py',
|
||||||
|
'../_base_/schedules/imagenet_bs256_coslr.py',
|
||||||
|
'../_base_/default_runtime.py'
|
||||||
|
]
|
||||||
|
|
||||||
|
# schedule settings
|
||||||
|
param_scheduler = dict(
|
||||||
|
type='CosineAnnealingLR', T_max=300, by_epoch=True, begin=0, end=300)
|
||||||
|
|
||||||
|
train_cfg = dict(by_epoch=True, max_epochs=300)
|
|
@ -0,0 +1,12 @@
|
||||||
|
_base_ = [
|
||||||
|
'../_base_/models/replknet-31B_in1k.py',
|
||||||
|
'../_base_/datasets/imagenet_bs32_pil_bicubic.py',
|
||||||
|
'../_base_/schedules/imagenet_bs256_coslr.py',
|
||||||
|
'../_base_/default_runtime.py'
|
||||||
|
]
|
||||||
|
|
||||||
|
# schedule settings
|
||||||
|
param_scheduler = dict(
|
||||||
|
type='CosineAnnealingLR', T_max=300, by_epoch=True, begin=0, end=300)
|
||||||
|
|
||||||
|
train_cfg = dict(by_epoch=True, max_epochs=300)
|
|
@ -0,0 +1,12 @@
|
||||||
|
_base_ = [
|
||||||
|
'../_base_/models/replknet-31L_in1k.py',
|
||||||
|
'../_base_/datasets/imagenet_bs16_pil_bicubic_384.py',
|
||||||
|
'../_base_/schedules/imagenet_bs256_coslr.py',
|
||||||
|
'../_base_/default_runtime.py'
|
||||||
|
]
|
||||||
|
|
||||||
|
# schedule settings
|
||||||
|
param_scheduler = dict(
|
||||||
|
type='CosineAnnealingLR', T_max=300, by_epoch=True, begin=0, end=300)
|
||||||
|
|
||||||
|
train_cfg = dict(by_epoch=True, max_epochs=300)
|
|
@ -0,0 +1,12 @@
|
||||||
|
_base_ = [
|
||||||
|
'../_base_/models/replknet-XL_in1k.py',
|
||||||
|
'../_base_/datasets/imagenet_bs8_pil_bicubic_320.py',
|
||||||
|
'../_base_/schedules/imagenet_bs256_coslr.py',
|
||||||
|
'../_base_/default_runtime.py'
|
||||||
|
]
|
||||||
|
|
||||||
|
# schedule settings
|
||||||
|
param_scheduler = dict(
|
||||||
|
type='CosineAnnealingLR', T_max=300, by_epoch=True, begin=0, end=300)
|
||||||
|
|
||||||
|
train_cfg = dict(by_epoch=True, max_epochs=300)
|
|
@ -85,6 +85,7 @@ Backbones
|
||||||
PCPVT
|
PCPVT
|
||||||
PoolFormer
|
PoolFormer
|
||||||
RegNet
|
RegNet
|
||||||
|
RepLKNet
|
||||||
RepMLPNet
|
RepMLPNet
|
||||||
RepVGG
|
RepVGG
|
||||||
Res2Net
|
Res2Net
|
||||||
|
|
|
@ -23,6 +23,7 @@ from .mobilevit import MobileViT
|
||||||
from .mvit import MViT
|
from .mvit import MViT
|
||||||
from .poolformer import PoolFormer
|
from .poolformer import PoolFormer
|
||||||
from .regnet import RegNet
|
from .regnet import RegNet
|
||||||
|
from .replknet import RepLKNet
|
||||||
from .repmlp import RepMLPNet
|
from .repmlp import RepMLPNet
|
||||||
from .repvgg import RepVGG
|
from .repvgg import RepVGG
|
||||||
from .res2net import Res2Net
|
from .res2net import Res2Net
|
||||||
|
@ -82,6 +83,7 @@ __all__ = [
|
||||||
'CSPResNet',
|
'CSPResNet',
|
||||||
'CSPResNeXt',
|
'CSPResNeXt',
|
||||||
'CSPNet',
|
'CSPNet',
|
||||||
|
'RepLKNet',
|
||||||
'RepMLPNet',
|
'RepMLPNet',
|
||||||
'PoolFormer',
|
'PoolFormer',
|
||||||
'DenseNet',
|
'DenseNet',
|
||||||
|
|
|
@ -0,0 +1,668 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.utils.checkpoint as checkpoint
|
||||||
|
from mmcv.cnn import build_activation_layer, build_norm_layer
|
||||||
|
from mmcv.cnn.bricks import DropPath
|
||||||
|
from mmengine.model import BaseModule
|
||||||
|
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||||
|
|
||||||
|
from mmcls.registry import MODELS
|
||||||
|
from .base_backbone import BaseBackbone
|
||||||
|
|
||||||
|
|
||||||
|
def conv_bn(in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
groups,
|
||||||
|
dilation=1,
|
||||||
|
norm_cfg=dict(type='BN')):
|
||||||
|
"""Construct a sequential conv and bn.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Dimension of input features.
|
||||||
|
out_channels (int): Dimension of output features.
|
||||||
|
kernel_size (int): kernel_size of the convolution.
|
||||||
|
stride (int): stride of the convolution.
|
||||||
|
padding (int): stride of the convolution.
|
||||||
|
groups (int): groups of the convolution.
|
||||||
|
dilation (int): dilation of the convolution. Default to 1.
|
||||||
|
norm_cfg (dict): dictionary to construct and config norm layer.
|
||||||
|
Default to ``dict(type='BN', requires_grad=True)``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Sequential(): A conv layer and a batch norm layer.
|
||||||
|
"""
|
||||||
|
if padding is None:
|
||||||
|
padding = kernel_size // 2
|
||||||
|
result = nn.Sequential()
|
||||||
|
result.add_module(
|
||||||
|
'conv',
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
bias=False))
|
||||||
|
result.add_module('bn', build_norm_layer(norm_cfg, out_channels)[1])
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def conv_bn_relu(in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
groups,
|
||||||
|
dilation=1):
|
||||||
|
"""Construct a sequential conv, bn and relu.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Dimension of input features.
|
||||||
|
out_channels (int): Dimension of output features.
|
||||||
|
kernel_size (int): kernel_size of the convolution.
|
||||||
|
stride (int): stride of the convolution.
|
||||||
|
padding (int): stride of the convolution.
|
||||||
|
groups (int): groups of the convolution.
|
||||||
|
dilation (int): dilation of the convolution. Default to 1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Sequential(): A conv layer, batch norm layer and a relu function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if padding is None:
|
||||||
|
padding = kernel_size // 2
|
||||||
|
result = conv_bn(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
groups=groups,
|
||||||
|
dilation=dilation)
|
||||||
|
result.add_module('nonlinear', nn.ReLU())
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def fuse_bn(conv, bn):
|
||||||
|
"""Fuse the parameters in a branch with a conv and bn.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conv (nn.Conv2d): The convolution module to fuse.
|
||||||
|
bn (nn.BatchNorm2d): The batch normalization to fuse.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[torch.Tensor, torch.Tensor]: The parameters obtained after
|
||||||
|
fusing the parameters of conv and bn in one branch.
|
||||||
|
The first element is the weight and the second is the bias.
|
||||||
|
"""
|
||||||
|
kernel = conv.weight
|
||||||
|
running_mean = bn.running_mean
|
||||||
|
running_var = bn.running_var
|
||||||
|
gamma = bn.weight
|
||||||
|
beta = bn.bias
|
||||||
|
eps = bn.eps
|
||||||
|
std = (running_var + eps).sqrt()
|
||||||
|
t = (gamma / std).reshape(-1, 1, 1, 1)
|
||||||
|
return kernel * t, beta - running_mean * gamma / std
|
||||||
|
|
||||||
|
|
||||||
|
class ReparamLargeKernelConv(BaseModule):
|
||||||
|
"""Super large kernel implemented by with large convolutions.
|
||||||
|
|
||||||
|
Input: Tensor with shape [B, C, H, W].
|
||||||
|
Output: Tensor with shape [B, C, H, W].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Dimension of input features.
|
||||||
|
out_channels (int): Dimension of output features.
|
||||||
|
kernel_size (int): kernel_size of the large convolution.
|
||||||
|
stride (int): stride of the large convolution.
|
||||||
|
groups (int): groups of the large convolution.
|
||||||
|
small_kernel (int): kernel_size of the small convolution.
|
||||||
|
small_kernel_merged (bool): Whether to switch the model structure to
|
||||||
|
deployment mode (merge the small kernel to the large kernel).
|
||||||
|
Default to False.
|
||||||
|
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||||
|
Defaults to None
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
groups,
|
||||||
|
small_kernel,
|
||||||
|
small_kernel_merged=False,
|
||||||
|
init_cfg=None):
|
||||||
|
super(ReparamLargeKernelConv, self).__init__(init_cfg)
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.small_kernel = small_kernel
|
||||||
|
self.small_kernel_merged = small_kernel_merged
|
||||||
|
# We assume the conv does not change the feature map size,
|
||||||
|
# so padding = k//2.
|
||||||
|
# Otherwise, you may configure padding as you wish,
|
||||||
|
# and change the padding of small_conv accordingly.
|
||||||
|
padding = kernel_size // 2
|
||||||
|
if small_kernel_merged:
|
||||||
|
self.lkb_reparam = nn.Conv2d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=1,
|
||||||
|
groups=groups,
|
||||||
|
bias=True)
|
||||||
|
else:
|
||||||
|
self.lkb_origin = conv_bn(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=1,
|
||||||
|
groups=groups)
|
||||||
|
if small_kernel is not None:
|
||||||
|
assert small_kernel <= kernel_size
|
||||||
|
self.small_conv = conv_bn(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=small_kernel,
|
||||||
|
stride=stride,
|
||||||
|
padding=small_kernel // 2,
|
||||||
|
groups=groups,
|
||||||
|
dilation=1)
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
if hasattr(self, 'lkb_reparam'):
|
||||||
|
out = self.lkb_reparam(inputs)
|
||||||
|
else:
|
||||||
|
out = self.lkb_origin(inputs)
|
||||||
|
if hasattr(self, 'small_conv'):
|
||||||
|
out += self.small_conv(inputs)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def get_equivalent_kernel_bias(self):
|
||||||
|
eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
|
||||||
|
if hasattr(self, 'small_conv'):
|
||||||
|
small_k, small_b = fuse_bn(self.small_conv.conv,
|
||||||
|
self.small_conv.bn)
|
||||||
|
eq_b += small_b
|
||||||
|
# add to the central part
|
||||||
|
eq_k += nn.functional.pad(
|
||||||
|
small_k, [(self.kernel_size - self.small_kernel) // 2] * 4)
|
||||||
|
return eq_k, eq_b
|
||||||
|
|
||||||
|
def merge_kernel(self):
|
||||||
|
"""Switch the model structure from training mode to deployment mode."""
|
||||||
|
if self.small_kernel_merged:
|
||||||
|
return
|
||||||
|
eq_k, eq_b = self.get_equivalent_kernel_bias()
|
||||||
|
self.lkb_reparam = nn.Conv2d(
|
||||||
|
in_channels=self.lkb_origin.conv.in_channels,
|
||||||
|
out_channels=self.lkb_origin.conv.out_channels,
|
||||||
|
kernel_size=self.lkb_origin.conv.kernel_size,
|
||||||
|
stride=self.lkb_origin.conv.stride,
|
||||||
|
padding=self.lkb_origin.conv.padding,
|
||||||
|
dilation=self.lkb_origin.conv.dilation,
|
||||||
|
groups=self.lkb_origin.conv.groups,
|
||||||
|
bias=True)
|
||||||
|
|
||||||
|
self.lkb_reparam.weight.data = eq_k
|
||||||
|
self.lkb_reparam.bias.data = eq_b
|
||||||
|
self.__delattr__('lkb_origin')
|
||||||
|
if hasattr(self, 'small_conv'):
|
||||||
|
self.__delattr__('small_conv')
|
||||||
|
|
||||||
|
self.small_kernel_merged = True
|
||||||
|
|
||||||
|
|
||||||
|
class ConvFFN(BaseModule):
|
||||||
|
"""Mlp implemented by with 1*1 convolutions.
|
||||||
|
|
||||||
|
Input: Tensor with shape [B, C, H, W].
|
||||||
|
Output: Tensor with shape [B, C, H, W].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Dimension of input features.
|
||||||
|
internal_channels (int): Dimension of hidden features.
|
||||||
|
out_channels (int): Dimension of output features.
|
||||||
|
drop_path (float): Stochastic depth rate. Defaults to 0.
|
||||||
|
norm_cfg (dict): dictionary to construct and config norm layer.
|
||||||
|
Default to ``dict(type='BN', requires_grad=True)``.
|
||||||
|
act_cfg (dict): The config dict for activation between pointwise
|
||||||
|
convolution. Defaults to ``dict(type='GELU')``.
|
||||||
|
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||||
|
Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
internal_channels,
|
||||||
|
out_channels,
|
||||||
|
drop_path,
|
||||||
|
norm_cfg=dict(type='BN'),
|
||||||
|
act_cfg=dict(type='GELU'),
|
||||||
|
init_cfg=None):
|
||||||
|
super(ConvFFN, self).__init__(init_cfg)
|
||||||
|
self.drop_path = DropPath(
|
||||||
|
drop_prob=drop_path) if drop_path > 0. else nn.Identity()
|
||||||
|
self.preffn_bn = build_norm_layer(norm_cfg, in_channels)[1]
|
||||||
|
self.pw1 = conv_bn(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=internal_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
groups=1)
|
||||||
|
self.pw2 = conv_bn(
|
||||||
|
in_channels=internal_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
groups=1)
|
||||||
|
self.nonlinear = build_activation_layer(act_cfg)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.preffn_bn(x)
|
||||||
|
out = self.pw1(out)
|
||||||
|
out = self.nonlinear(out)
|
||||||
|
out = self.pw2(out)
|
||||||
|
return x + self.drop_path(out)
|
||||||
|
|
||||||
|
|
||||||
|
class RepLKBlock(BaseModule):
|
||||||
|
"""RepLKBlock for RepLKNet backbone.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): The input channels of the block.
|
||||||
|
dw_channels (int): The intermediate channels of the block,
|
||||||
|
i.e., input channels of the large kernel convolution.
|
||||||
|
block_lk_size (int): size of the super large kernel. Defaults: 31.
|
||||||
|
small_kernel (int): size of the parallel small kernel. Defaults: 5.
|
||||||
|
drop_path (float): Stochastic depth rate. Defaults: 0.
|
||||||
|
small_kernel_merged (bool): Whether to switch the model structure to
|
||||||
|
deployment mode (merge the small kernel to the large kernel).
|
||||||
|
Default to False.
|
||||||
|
norm_cfg (dict): dictionary to construct and config norm layer.
|
||||||
|
Default to ``dict(type='BN', requires_grad=True)``.
|
||||||
|
act_cfg (dict): Config dict for activation layer.
|
||||||
|
Default to ``dict(type='ReLU')``.
|
||||||
|
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||||
|
Default to None
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
dw_channels,
|
||||||
|
block_lk_size,
|
||||||
|
small_kernel,
|
||||||
|
drop_path,
|
||||||
|
small_kernel_merged=False,
|
||||||
|
norm_cfg=dict(type='BN'),
|
||||||
|
act_cfg=dict(type='ReLU'),
|
||||||
|
init_cfg=None):
|
||||||
|
super(RepLKBlock, self).__init__(init_cfg)
|
||||||
|
self.pw1 = conv_bn_relu(in_channels, dw_channels, 1, 1, 0, groups=1)
|
||||||
|
self.pw2 = conv_bn(dw_channels, in_channels, 1, 1, 0, groups=1)
|
||||||
|
self.large_kernel = ReparamLargeKernelConv(
|
||||||
|
in_channels=dw_channels,
|
||||||
|
out_channels=dw_channels,
|
||||||
|
kernel_size=block_lk_size,
|
||||||
|
stride=1,
|
||||||
|
groups=dw_channels,
|
||||||
|
small_kernel=small_kernel,
|
||||||
|
small_kernel_merged=small_kernel_merged)
|
||||||
|
self.lk_nonlinear = build_activation_layer(act_cfg)
|
||||||
|
self.prelkb_bn = build_norm_layer(norm_cfg, in_channels)[1]
|
||||||
|
self.drop_path = DropPath(
|
||||||
|
drop_prob=drop_path) if drop_path > 0. else nn.Identity()
|
||||||
|
# print('drop path:', self.drop_path)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.prelkb_bn(x)
|
||||||
|
out = self.pw1(out)
|
||||||
|
out = self.large_kernel(out)
|
||||||
|
out = self.lk_nonlinear(out)
|
||||||
|
out = self.pw2(out)
|
||||||
|
return x + self.drop_path(out)
|
||||||
|
|
||||||
|
|
||||||
|
class RepLKNetStage(BaseModule):
|
||||||
|
"""
|
||||||
|
generate RepLKNet blocks for a stage
|
||||||
|
return: RepLKNet blocks
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channels (int): The input channels of the stage.
|
||||||
|
num_blocks (int): The number of blocks of the stage.
|
||||||
|
stage_lk_size (int): size of the super large kernel. Defaults: 31.
|
||||||
|
drop_path (float): Stochastic depth rate. Defaults: 0.
|
||||||
|
small_kernel (int): size of the parallel small kernel. Defaults: 5.
|
||||||
|
dw_ratio (float): The intermediate channels
|
||||||
|
expansion ratio of the block. Defaults: 1.
|
||||||
|
ffn_ratio (float): Mlp expansion ratio. Defaults to 4.
|
||||||
|
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||||
|
memory while slowing down the training speed. Default to False.
|
||||||
|
small_kernel_merged (bool): Whether to switch the model structure to
|
||||||
|
deployment mode (merge the small kernel to the large kernel).
|
||||||
|
Default to False.
|
||||||
|
norm_intermediate_features (bool): Construct and config norm layer
|
||||||
|
or not.
|
||||||
|
Using True will normalize the intermediate features for
|
||||||
|
downstream dense prediction tasks.
|
||||||
|
norm_cfg (dict): dictionary to construct and config norm layer.
|
||||||
|
Default to ``dict(type='BN', requires_grad=True)``.
|
||||||
|
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||||
|
Default to None
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels,
|
||||||
|
num_blocks,
|
||||||
|
stage_lk_size,
|
||||||
|
drop_path,
|
||||||
|
small_kernel,
|
||||||
|
dw_ratio=1,
|
||||||
|
ffn_ratio=4,
|
||||||
|
with_cp=False, # train with torch.utils.checkpoint to save memory
|
||||||
|
small_kernel_merged=False,
|
||||||
|
norm_intermediate_features=False,
|
||||||
|
norm_cfg=dict(type='BN'),
|
||||||
|
init_cfg=None):
|
||||||
|
super(RepLKNetStage, self).__init__(init_cfg)
|
||||||
|
self.with_cp = with_cp
|
||||||
|
blks = []
|
||||||
|
for i in range(num_blocks):
|
||||||
|
block_drop_path = drop_path[i] if isinstance(drop_path,
|
||||||
|
list) else drop_path
|
||||||
|
# Assume all RepLK Blocks within a stage share the same lk_size.
|
||||||
|
# You may tune it on your own model.
|
||||||
|
replk_block = RepLKBlock(
|
||||||
|
in_channels=channels,
|
||||||
|
dw_channels=int(channels * dw_ratio),
|
||||||
|
block_lk_size=stage_lk_size,
|
||||||
|
small_kernel=small_kernel,
|
||||||
|
drop_path=block_drop_path,
|
||||||
|
small_kernel_merged=small_kernel_merged)
|
||||||
|
convffn_block = ConvFFN(
|
||||||
|
in_channels=channels,
|
||||||
|
internal_channels=int(channels * ffn_ratio),
|
||||||
|
out_channels=channels,
|
||||||
|
drop_path=block_drop_path)
|
||||||
|
blks.append(replk_block)
|
||||||
|
blks.append(convffn_block)
|
||||||
|
self.blocks = nn.ModuleList(blks)
|
||||||
|
if norm_intermediate_features:
|
||||||
|
self.norm = build_norm_layer(norm_cfg, channels)[1]
|
||||||
|
else:
|
||||||
|
self.norm = nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for blk in self.blocks:
|
||||||
|
if self.with_cp:
|
||||||
|
x = checkpoint.checkpoint(blk, x) # Save training memory
|
||||||
|
else:
|
||||||
|
x = blk(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@MODELS.register_module()
|
||||||
|
class RepLKNet(BaseBackbone):
|
||||||
|
"""RepLKNet backbone.
|
||||||
|
|
||||||
|
A PyTorch impl of :
|
||||||
|
`Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs
|
||||||
|
<https://arxiv.org/abs/2203.06717>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arch (str | dict): The parameter of RepLKNet.
|
||||||
|
If it's a dict, it should contain the following keys:
|
||||||
|
|
||||||
|
- large_kernel_sizes (Sequence[int]):
|
||||||
|
Large kernel size in each stage.
|
||||||
|
- layers (Sequence[int]): Number of blocks in each stage.
|
||||||
|
- channels (Sequence[int]): Number of channels in each stage.
|
||||||
|
- small_kernel (int): size of the parallel small kernel.
|
||||||
|
- dw_ratio (float): The intermediate channels
|
||||||
|
expansion ratio of the block.
|
||||||
|
in_channels (int): Number of input image channels. Default to 3.
|
||||||
|
ffn_ratio (float): Mlp expansion ratio. Defaults to 4.
|
||||||
|
out_indices (Sequence[int]): Output from which stages.
|
||||||
|
Default to (3, ).
|
||||||
|
strides (Sequence[int]): Strides of the first block of each stage.
|
||||||
|
Default to (2, 2, 2, 2).
|
||||||
|
dilations (Sequence[int]): Dilation of each stage.
|
||||||
|
Default to (1, 1, 1, 1).
|
||||||
|
frozen_stages (int): Stages to be frozen
|
||||||
|
(all param fixed). -1 means not freezing any parameters.
|
||||||
|
Default to -1.
|
||||||
|
conv_cfg (dict | None): The config dict for conv layers.
|
||||||
|
Default to None.
|
||||||
|
norm_cfg (dict): The config dict for norm layers.
|
||||||
|
Default to ``dict(type='BN')``.
|
||||||
|
act_cfg (dict): Config dict for activation layer.
|
||||||
|
Default to ``dict(type='ReLU')``.
|
||||||
|
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||||
|
memory while slowing down the training speed. Default to False.
|
||||||
|
deploy (bool): Whether to switch the model structure to deployment
|
||||||
|
mode. Default to False.
|
||||||
|
norm_intermediate_features (bool): Construct and
|
||||||
|
config norm layer or not.
|
||||||
|
Using True will normalize the intermediate features
|
||||||
|
for downstream dense prediction tasks.
|
||||||
|
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||||
|
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||||
|
and its variants only. Default to False.
|
||||||
|
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||||
|
"""
|
||||||
|
|
||||||
|
arch_settings = {
|
||||||
|
'31B':
|
||||||
|
dict(
|
||||||
|
large_kernel_sizes=[31, 29, 27, 13],
|
||||||
|
layers=[2, 2, 18, 2],
|
||||||
|
channels=[128, 256, 512, 1024],
|
||||||
|
small_kernel=5,
|
||||||
|
dw_ratio=1),
|
||||||
|
'31L':
|
||||||
|
dict(
|
||||||
|
large_kernel_sizes=[31, 29, 27, 13],
|
||||||
|
layers=[2, 2, 18, 2],
|
||||||
|
channels=[192, 384, 768, 1536],
|
||||||
|
small_kernel=5,
|
||||||
|
dw_ratio=1),
|
||||||
|
'XL':
|
||||||
|
dict(
|
||||||
|
large_kernel_sizes=[27, 27, 27, 13],
|
||||||
|
layers=[2, 2, 18, 2],
|
||||||
|
channels=[256, 512, 1024, 2048],
|
||||||
|
small_kernel=None,
|
||||||
|
dw_ratio=1.5),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
arch,
|
||||||
|
in_channels=3,
|
||||||
|
ffn_ratio=4,
|
||||||
|
out_indices=(3, ),
|
||||||
|
strides=(2, 2, 2, 2),
|
||||||
|
dilations=(1, 1, 1, 1),
|
||||||
|
frozen_stages=-1,
|
||||||
|
conv_cfg=None,
|
||||||
|
norm_cfg=dict(type='BN'),
|
||||||
|
act_cfg=dict(type='ReLU'),
|
||||||
|
with_cp=False,
|
||||||
|
drop_path_rate=0.3,
|
||||||
|
small_kernel_merged=False,
|
||||||
|
norm_intermediate_features=False,
|
||||||
|
norm_eval=False,
|
||||||
|
init_cfg=[
|
||||||
|
dict(type='Kaiming', layer=['Conv2d']),
|
||||||
|
dict(
|
||||||
|
type='Constant',
|
||||||
|
val=1,
|
||||||
|
layer=['_BatchNorm', 'GroupNorm'])
|
||||||
|
]):
|
||||||
|
super(RepLKNet, self).__init__(init_cfg)
|
||||||
|
|
||||||
|
if isinstance(arch, str):
|
||||||
|
assert arch in self.arch_settings, \
|
||||||
|
f'"arch": "{arch}" is not one of the arch_settings'
|
||||||
|
arch = self.arch_settings[arch]
|
||||||
|
elif not isinstance(arch, dict):
|
||||||
|
raise TypeError('Expect "arch" to be either a string '
|
||||||
|
f'or a dict, got {type(arch)}')
|
||||||
|
|
||||||
|
assert len(arch['layers']) == len(
|
||||||
|
arch['channels']) == len(strides) == len(dilations)
|
||||||
|
assert max(out_indices) < len(arch['layers'])
|
||||||
|
|
||||||
|
self.arch = arch
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_indices = out_indices
|
||||||
|
self.strides = strides
|
||||||
|
self.dilations = dilations
|
||||||
|
self.frozen_stages = frozen_stages
|
||||||
|
self.conv_cfg = conv_cfg
|
||||||
|
self.norm_cfg = norm_cfg
|
||||||
|
self.act_cfg = act_cfg
|
||||||
|
self.with_cp = with_cp
|
||||||
|
self.drop_path_rate = drop_path_rate
|
||||||
|
self.small_kernel_merged = small_kernel_merged
|
||||||
|
self.norm_eval = norm_eval
|
||||||
|
self.norm_intermediate_features = norm_intermediate_features
|
||||||
|
|
||||||
|
self.out_indices = out_indices
|
||||||
|
|
||||||
|
base_width = self.arch['channels'][0]
|
||||||
|
self.norm_intermediate_features = norm_intermediate_features
|
||||||
|
self.num_stages = len(self.arch['layers'])
|
||||||
|
self.stem = nn.ModuleList([
|
||||||
|
conv_bn_relu(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=base_width,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
groups=1),
|
||||||
|
conv_bn_relu(
|
||||||
|
in_channels=base_width,
|
||||||
|
out_channels=base_width,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
groups=base_width),
|
||||||
|
conv_bn_relu(
|
||||||
|
in_channels=base_width,
|
||||||
|
out_channels=base_width,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
groups=1),
|
||||||
|
conv_bn_relu(
|
||||||
|
in_channels=base_width,
|
||||||
|
out_channels=base_width,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
groups=base_width)
|
||||||
|
])
|
||||||
|
# stochastic depth. We set block-wise drop-path rate.
|
||||||
|
# The higher level blocks are more likely to be dropped.
|
||||||
|
# This implementation follows Swin.
|
||||||
|
dpr = [
|
||||||
|
x.item() for x in torch.linspace(0, drop_path_rate,
|
||||||
|
sum(self.arch['layers']))
|
||||||
|
]
|
||||||
|
self.stages = nn.ModuleList()
|
||||||
|
self.transitions = nn.ModuleList()
|
||||||
|
for stage_idx in range(self.num_stages):
|
||||||
|
layer = RepLKNetStage(
|
||||||
|
channels=self.arch['channels'][stage_idx],
|
||||||
|
num_blocks=self.arch['layers'][stage_idx],
|
||||||
|
stage_lk_size=self.arch['large_kernel_sizes'][stage_idx],
|
||||||
|
drop_path=dpr[sum(self.arch['layers'][:stage_idx]
|
||||||
|
):sum(self.arch['layers'][:stage_idx + 1])],
|
||||||
|
small_kernel=self.arch['small_kernel'],
|
||||||
|
dw_ratio=self.arch['dw_ratio'],
|
||||||
|
ffn_ratio=ffn_ratio,
|
||||||
|
with_cp=with_cp,
|
||||||
|
small_kernel_merged=small_kernel_merged,
|
||||||
|
norm_intermediate_features=(stage_idx in out_indices))
|
||||||
|
self.stages.append(layer)
|
||||||
|
if stage_idx < len(self.arch['layers']) - 1:
|
||||||
|
transition = nn.Sequential(
|
||||||
|
conv_bn_relu(
|
||||||
|
self.arch['channels'][stage_idx],
|
||||||
|
self.arch['channels'][stage_idx + 1],
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
groups=1),
|
||||||
|
conv_bn_relu(
|
||||||
|
self.arch['channels'][stage_idx + 1],
|
||||||
|
self.arch['channels'][stage_idx + 1],
|
||||||
|
3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
groups=self.arch['channels'][stage_idx + 1]))
|
||||||
|
self.transitions.append(transition)
|
||||||
|
|
||||||
|
def forward_features(self, x):
|
||||||
|
x = self.stem[0](x)
|
||||||
|
for stem_layer in self.stem[1:]:
|
||||||
|
if self.with_cp:
|
||||||
|
x = checkpoint.checkpoint(stem_layer, x) # save memory
|
||||||
|
else:
|
||||||
|
x = stem_layer(x)
|
||||||
|
|
||||||
|
# Need the intermediate feature maps
|
||||||
|
outs = []
|
||||||
|
for stage_idx in range(self.num_stages):
|
||||||
|
x = self.stages[stage_idx](x)
|
||||||
|
if stage_idx in self.out_indices:
|
||||||
|
outs.append(self.stages[stage_idx].norm(x))
|
||||||
|
# For RepLKNet-XL normalize the features
|
||||||
|
# before feeding them into the heads
|
||||||
|
if stage_idx < self.num_stages - 1:
|
||||||
|
x = self.transitions[stage_idx](x)
|
||||||
|
return outs
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.forward_features(x)
|
||||||
|
return tuple(x)
|
||||||
|
|
||||||
|
def _freeze_stages(self):
|
||||||
|
if self.frozen_stages >= 0:
|
||||||
|
self.stem.eval()
|
||||||
|
for param in self.stem.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
for i in range(self.frozen_stages):
|
||||||
|
stage = self.stages[i]
|
||||||
|
stage.eval()
|
||||||
|
for param in stage.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def train(self, mode=True):
|
||||||
|
super(RepLKNet, self).train(mode)
|
||||||
|
self._freeze_stages()
|
||||||
|
if mode and self.norm_eval:
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, _BatchNorm):
|
||||||
|
m.eval()
|
||||||
|
|
||||||
|
def switch_to_deploy(self):
|
||||||
|
for m in self.modules():
|
||||||
|
if hasattr(m, 'merge_kernel'):
|
||||||
|
m.merge_kernel()
|
||||||
|
self.small_kernel_merged = True
|
|
@ -38,3 +38,4 @@ Import:
|
||||||
- configs/hornet/metafile.yml
|
- configs/hornet/metafile.yml
|
||||||
- configs/mobilevit/metafile.yml
|
- configs/mobilevit/metafile.yml
|
||||||
- configs/davit/metafile.yml
|
- configs/davit/metafile.yml
|
||||||
|
- configs/replknet/metafile.yml
|
||||||
|
|
|
@ -0,0 +1,304 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from mmengine.runner import load_checkpoint, save_checkpoint
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn.modules import GroupNorm
|
||||||
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
|
from mmcls.models.backbones import RepLKNet
|
||||||
|
from mmcls.models.backbones.replknet import ReparamLargeKernelConv
|
||||||
|
|
||||||
|
|
||||||
|
def check_norm_state(modules, train_state):
|
||||||
|
"""Check if norm layer is in correct train state."""
|
||||||
|
for mod in modules:
|
||||||
|
if isinstance(mod, _BatchNorm):
|
||||||
|
if mod.training != train_state:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def is_norm(modules):
|
||||||
|
"""Check if is one of the norms."""
|
||||||
|
if isinstance(modules, (GroupNorm, _BatchNorm)):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_replk_block(modules):
|
||||||
|
if isinstance(modules, ReparamLargeKernelConv):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def test_replknet_replkblock():
|
||||||
|
# Test ReparamLargeKernelConv with in_channels != out_channels,
|
||||||
|
# kernel_size = 31, stride = 1, groups=in_channels, small_kernel = 5
|
||||||
|
block = ReparamLargeKernelConv(
|
||||||
|
5, 10, kernel_size=31, stride=1, groups=5, small_kernel=5)
|
||||||
|
block.eval()
|
||||||
|
x = torch.randn(1, 5, 64, 64)
|
||||||
|
x_out_not_deploy = block(x)
|
||||||
|
assert block.small_kernel <= block.kernel_size
|
||||||
|
assert not hasattr(block, 'lkb_reparam')
|
||||||
|
assert hasattr(block, 'lkb_origin')
|
||||||
|
assert hasattr(block, 'small_conv')
|
||||||
|
assert x_out_not_deploy.shape == torch.Size((1, 10, 64, 64))
|
||||||
|
block.merge_kernel()
|
||||||
|
assert block.small_kernel_merged is True
|
||||||
|
x_out_deploy = block(x)
|
||||||
|
assert x_out_deploy.shape == torch.Size((1, 10, 64, 64))
|
||||||
|
assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)
|
||||||
|
|
||||||
|
# Test ReparamLargeKernelConv with in_channels == out_channels,
|
||||||
|
# kernel_size = 31, stride = 1, groups=in_channels, small_kernel = 5
|
||||||
|
block = ReparamLargeKernelConv(
|
||||||
|
12, 12, kernel_size=31, stride=1, groups=12, small_kernel=5)
|
||||||
|
block.eval()
|
||||||
|
x = torch.randn(1, 12, 64, 64)
|
||||||
|
x_out_not_deploy = block(x)
|
||||||
|
assert block.small_kernel <= block.kernel_size
|
||||||
|
assert not hasattr(block, 'lkb_reparam')
|
||||||
|
assert hasattr(block, 'lkb_origin')
|
||||||
|
assert hasattr(block, 'small_conv')
|
||||||
|
assert x_out_not_deploy.shape == torch.Size((1, 12, 64, 64))
|
||||||
|
block.merge_kernel()
|
||||||
|
assert block.small_kernel_merged is True
|
||||||
|
x_out_deploy = block(x)
|
||||||
|
assert x_out_deploy.shape == torch.Size((1, 12, 64, 64))
|
||||||
|
assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)
|
||||||
|
|
||||||
|
# Test ReparamLargeKernelConv with in_channels == out_channels,
|
||||||
|
# kernel_size = 31, stride = 2, groups=in_channels, small_kernel = 5
|
||||||
|
block = ReparamLargeKernelConv(
|
||||||
|
16, 16, kernel_size=31, stride=2, groups=16, small_kernel=5)
|
||||||
|
block.eval()
|
||||||
|
x = torch.randn(1, 16, 64, 64)
|
||||||
|
x_out_not_deploy = block(x)
|
||||||
|
assert block.small_kernel <= block.kernel_size
|
||||||
|
assert not hasattr(block, 'lkb_reparam')
|
||||||
|
assert hasattr(block, 'lkb_origin')
|
||||||
|
assert hasattr(block, 'small_conv')
|
||||||
|
assert x_out_not_deploy.shape == torch.Size((1, 16, 32, 32))
|
||||||
|
block.merge_kernel()
|
||||||
|
assert block.small_kernel_merged is True
|
||||||
|
x_out_deploy = block(x)
|
||||||
|
assert x_out_deploy.shape == torch.Size((1, 16, 32, 32))
|
||||||
|
assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)
|
||||||
|
|
||||||
|
# Test ReparamLargeKernelConv with in_channels == out_channels,
|
||||||
|
# kernel_size = 27, stride = 1, groups=in_channels, small_kernel = 5
|
||||||
|
block = ReparamLargeKernelConv(
|
||||||
|
12, 12, kernel_size=27, stride=1, groups=12, small_kernel=5)
|
||||||
|
block.eval()
|
||||||
|
x = torch.randn(1, 12, 48, 48)
|
||||||
|
x_out_not_deploy = block(x)
|
||||||
|
assert block.small_kernel <= block.kernel_size
|
||||||
|
assert not hasattr(block, 'lkb_reparam')
|
||||||
|
assert hasattr(block, 'lkb_origin')
|
||||||
|
assert hasattr(block, 'small_conv')
|
||||||
|
assert x_out_not_deploy.shape == torch.Size((1, 12, 48, 48))
|
||||||
|
block.merge_kernel()
|
||||||
|
assert block.small_kernel_merged is True
|
||||||
|
x_out_deploy = block(x)
|
||||||
|
assert x_out_deploy.shape == torch.Size((1, 12, 48, 48))
|
||||||
|
assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)
|
||||||
|
|
||||||
|
# Test ReparamLargeKernelConv with in_channels == out_channels,
|
||||||
|
# kernel_size = 31, stride = 1, groups=in_channels, small_kernel = 7
|
||||||
|
block = ReparamLargeKernelConv(
|
||||||
|
12, 12, kernel_size=31, stride=1, groups=12, small_kernel=7)
|
||||||
|
block.eval()
|
||||||
|
x = torch.randn(1, 12, 64, 64)
|
||||||
|
x_out_not_deploy = block(x)
|
||||||
|
assert block.small_kernel <= block.kernel_size
|
||||||
|
assert not hasattr(block, 'lkb_reparam')
|
||||||
|
assert hasattr(block, 'lkb_origin')
|
||||||
|
assert hasattr(block, 'small_conv')
|
||||||
|
assert x_out_not_deploy.shape == torch.Size((1, 12, 64, 64))
|
||||||
|
block.merge_kernel()
|
||||||
|
assert block.small_kernel_merged is True
|
||||||
|
x_out_deploy = block(x)
|
||||||
|
assert x_out_deploy.shape == torch.Size((1, 12, 64, 64))
|
||||||
|
assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)
|
||||||
|
|
||||||
|
# Test ReparamLargeKernelConv with deploy == True
|
||||||
|
block = ReparamLargeKernelConv(
|
||||||
|
8,
|
||||||
|
8,
|
||||||
|
kernel_size=31,
|
||||||
|
stride=1,
|
||||||
|
groups=8,
|
||||||
|
small_kernel=5,
|
||||||
|
small_kernel_merged=True)
|
||||||
|
assert isinstance(block.lkb_reparam, nn.Conv2d)
|
||||||
|
assert not hasattr(block, 'lkb_origin')
|
||||||
|
assert not hasattr(block, 'small_conv')
|
||||||
|
x = torch.randn(1, 8, 48, 48)
|
||||||
|
x_out = block(x)
|
||||||
|
assert x_out.shape == torch.Size((1, 8, 48, 48))
|
||||||
|
|
||||||
|
|
||||||
|
def test_replknet_backbone():
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
# arch must be str or dict
|
||||||
|
RepLKNet(arch=[4, 6, 16, 1])
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
# arch must in arch_settings
|
||||||
|
RepLKNet(arch='31C')
|
||||||
|
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
# arch must have num_blocks and width_factor
|
||||||
|
arch = dict(large_kernel_sizes=[31, 29, 27, 13])
|
||||||
|
RepLKNet(arch=arch)
|
||||||
|
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
# arch must have num_blocks and width_factor
|
||||||
|
arch = dict(large_kernel_sizes=[31, 29, 27, 13], layers=[2, 2, 18, 2])
|
||||||
|
RepLKNet(arch=arch)
|
||||||
|
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
# arch must have num_blocks and width_factor
|
||||||
|
arch = dict(
|
||||||
|
large_kernel_sizes=[31, 29, 27, 13],
|
||||||
|
layers=[2, 2, 18, 2],
|
||||||
|
channels=[128, 256, 512, 1024])
|
||||||
|
RepLKNet(arch=arch)
|
||||||
|
|
||||||
|
# len(arch['large_kernel_sizes']) == arch['layers'])
|
||||||
|
# == len(arch['channels'])
|
||||||
|
# == len(strides) == len(dilations)
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
arch = dict(
|
||||||
|
large_kernel_sizes=[31, 29, 27, 13],
|
||||||
|
layers=[2, 2, 18, 2],
|
||||||
|
channels=[128, 256, 1024],
|
||||||
|
small_kernel=5,
|
||||||
|
dw_ratio=1)
|
||||||
|
RepLKNet(arch=arch)
|
||||||
|
|
||||||
|
# len(strides) must equal to 4
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
RepLKNet('31B', strides=(2, 2, 2))
|
||||||
|
|
||||||
|
# len(dilations) must equal to 4
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
RepLKNet('31B', strides=(2, 2, 2, 2), dilations=(1, 1, 1))
|
||||||
|
|
||||||
|
# max(out_indices) < len(arch['num_blocks'])
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
RepLKNet('31B', out_indices=(5, ))
|
||||||
|
|
||||||
|
# Test RepLKNet norm state
|
||||||
|
model = RepLKNet('31B')
|
||||||
|
model.train()
|
||||||
|
assert check_norm_state(model.modules(), True)
|
||||||
|
|
||||||
|
# Test RepLKNet with first stage frozen
|
||||||
|
frozen_stages = 1
|
||||||
|
model = RepLKNet('31B', frozen_stages=frozen_stages)
|
||||||
|
model.train()
|
||||||
|
for param in model.stem.parameters():
|
||||||
|
assert param.requires_grad is False
|
||||||
|
for i in range(0, frozen_stages):
|
||||||
|
stage = model.stages[i]
|
||||||
|
for mod in stage.modules():
|
||||||
|
if isinstance(mod, _BatchNorm):
|
||||||
|
assert mod.training is False
|
||||||
|
for param in stage.parameters():
|
||||||
|
assert param.requires_grad is False
|
||||||
|
|
||||||
|
# Test RepLKNet with norm_eval
|
||||||
|
model = RepLKNet('31B', norm_eval=True)
|
||||||
|
model.train()
|
||||||
|
assert check_norm_state(model.modules(), False)
|
||||||
|
|
||||||
|
# Test RepLKNet forward with layer 3 forward
|
||||||
|
model = RepLKNet('31B', out_indices=(3, ))
|
||||||
|
model.init_weights()
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
for m in model.modules():
|
||||||
|
if is_norm(m):
|
||||||
|
assert isinstance(m, _BatchNorm)
|
||||||
|
|
||||||
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
|
feat = model(imgs)
|
||||||
|
assert isinstance(feat, tuple)
|
||||||
|
assert len(feat) == 1
|
||||||
|
assert isinstance(feat[0], torch.Tensor)
|
||||||
|
assert feat[0].shape == torch.Size((1, 1024, 7, 7))
|
||||||
|
|
||||||
|
# Test RepLKNet forward
|
||||||
|
model_test_settings = [
|
||||||
|
dict(model_name='31B', out_sizes=(128, 256, 512, 1024)),
|
||||||
|
# dict(model_name='31L', out_sizes=(192, 384, 768, 1536)),
|
||||||
|
# dict(model_name='XL', out_sizes=(256, 512, 1024, 2048))
|
||||||
|
]
|
||||||
|
|
||||||
|
choose_models = ['31B']
|
||||||
|
# Test RepLKNet model forward
|
||||||
|
for model_test_setting in model_test_settings:
|
||||||
|
if model_test_setting['model_name'] not in choose_models:
|
||||||
|
continue
|
||||||
|
model = RepLKNet(
|
||||||
|
model_test_setting['model_name'], out_indices=(0, 1, 2, 3))
|
||||||
|
model.init_weights()
|
||||||
|
|
||||||
|
# Test Norm
|
||||||
|
for m in model.modules():
|
||||||
|
if is_norm(m):
|
||||||
|
assert isinstance(m, _BatchNorm)
|
||||||
|
|
||||||
|
model.train()
|
||||||
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
|
feat = model(imgs)
|
||||||
|
assert feat[0].shape == torch.Size(
|
||||||
|
(1, model_test_setting['out_sizes'][0], 56, 56))
|
||||||
|
assert feat[1].shape == torch.Size(
|
||||||
|
(1, model_test_setting['out_sizes'][1], 28, 28))
|
||||||
|
assert feat[2].shape == torch.Size(
|
||||||
|
(1, model_test_setting['out_sizes'][2], 14, 14))
|
||||||
|
assert feat[3].shape == torch.Size(
|
||||||
|
(1, model_test_setting['out_sizes'][3], 7, 7))
|
||||||
|
|
||||||
|
# Test eval of "train" mode and "deploy" mode
|
||||||
|
gap = nn.AdaptiveAvgPool2d(output_size=(1))
|
||||||
|
fc = nn.Linear(model_test_setting['out_sizes'][3], 10)
|
||||||
|
model.eval()
|
||||||
|
feat = model(imgs)
|
||||||
|
pred = fc(gap(feat[3]).flatten(1))
|
||||||
|
model.switch_to_deploy()
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, ReparamLargeKernelConv):
|
||||||
|
assert m.small_kernel_merged is True
|
||||||
|
feat_deploy = model(imgs)
|
||||||
|
pred_deploy = fc(gap(feat_deploy[3]).flatten(1))
|
||||||
|
for i in range(4):
|
||||||
|
torch.allclose(feat[i], feat_deploy[i])
|
||||||
|
torch.allclose(pred, pred_deploy)
|
||||||
|
|
||||||
|
|
||||||
|
def test_replknet_load():
|
||||||
|
# Test output before and load from deploy checkpoint
|
||||||
|
model = RepLKNet('31B', out_indices=(0, 1, 2, 3))
|
||||||
|
inputs = torch.randn((1, 3, 224, 224))
|
||||||
|
ckpt_path = os.path.join(tempfile.gettempdir(), 'ckpt.pth')
|
||||||
|
model.switch_to_deploy()
|
||||||
|
model.eval()
|
||||||
|
outputs = model(inputs)
|
||||||
|
|
||||||
|
model_deploy = RepLKNet(
|
||||||
|
'31B', out_indices=(0, 1, 2, 3), small_kernel_merged=True)
|
||||||
|
model_deploy.eval()
|
||||||
|
save_checkpoint(model.state_dict(), ckpt_path)
|
||||||
|
load_checkpoint(model_deploy, ckpt_path, strict=True)
|
||||||
|
|
||||||
|
outputs_load = model_deploy(inputs)
|
||||||
|
for feat, feat_load in zip(outputs, outputs_load):
|
||||||
|
assert torch.allclose(feat, feat_load)
|
|
@ -342,6 +342,7 @@ def test_repvgg_load():
|
||||||
outputs = model(inputs)
|
outputs = model(inputs)
|
||||||
|
|
||||||
model_deploy = RepVGG('A1', out_indices=(0, 1, 2, 3), deploy=True)
|
model_deploy = RepVGG('A1', out_indices=(0, 1, 2, 3), deploy=True)
|
||||||
|
model_deploy.eval()
|
||||||
save_checkpoint(model.state_dict(), ckpt_path)
|
save_checkpoint(model.state_dict(), ckpt_path)
|
||||||
load_checkpoint(model_deploy, ckpt_path, strict=True)
|
load_checkpoint(model_deploy, ckpt_path, strict=True)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import argparse
|
||||||
|
from collections import OrderedDict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def convert(src, dst):
|
||||||
|
print('Converting...')
|
||||||
|
blobs = torch.load(src, map_location='cpu')
|
||||||
|
converted_state_dict = OrderedDict()
|
||||||
|
|
||||||
|
for key in blobs:
|
||||||
|
splited_key = key.split('.')
|
||||||
|
print(splited_key)
|
||||||
|
splited_key = [
|
||||||
|
'backbone.stem' if i[:4] == 'stem' else i for i in splited_key
|
||||||
|
]
|
||||||
|
splited_key = [
|
||||||
|
'backbone.stages' if i[:6] == 'stages' else i for i in splited_key
|
||||||
|
]
|
||||||
|
splited_key = [
|
||||||
|
'backbone.transitions' if i[:11] == 'transitions' else i
|
||||||
|
for i in splited_key
|
||||||
|
]
|
||||||
|
splited_key = [
|
||||||
|
'backbone.stages.3.norm' if i[:4] == 'norm' else i
|
||||||
|
for i in splited_key
|
||||||
|
]
|
||||||
|
splited_key = [
|
||||||
|
'head.fc' if i[:4] == 'head' else i for i in splited_key
|
||||||
|
]
|
||||||
|
|
||||||
|
new_key = '.'.join(splited_key)
|
||||||
|
converted_state_dict[new_key] = blobs[key]
|
||||||
|
|
||||||
|
torch.save(converted_state_dict, dst)
|
||||||
|
print('Done!')
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='Convert model keys')
|
||||||
|
parser.add_argument('src', help='src detectron model path')
|
||||||
|
parser.add_argument('dst', help='save path')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
dst = Path(args.dst)
|
||||||
|
if dst.suffix != '.pth':
|
||||||
|
print('The path should contain the name of the pth format file.')
|
||||||
|
exit(1)
|
||||||
|
dst.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
convert(args.src, args.dst)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
Reference in New Issue