diff --git a/README.md b/README.md
index ad948bed..c8ff976b 100644
--- a/README.md
+++ b/README.md
@@ -151,6 +151,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
- [x] [HorNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hornet)
- [x] [MobileViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilevit)
- [x] [DaViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/davit)
+- [x] [RepLKNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/replknet)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index d6e9665e..555e3b44 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -150,6 +150,7 @@ mim install -e .
- [x] [HorNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hornet)
- [x] [MobileViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilevit)
- [x] [DaViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/davit)
+- [x] [RepLKNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/replknet)
diff --git a/configs/_base_/datasets/imagenet_bs16_pil_bicubic_384.py b/configs/_base_/datasets/imagenet_bs16_pil_bicubic_384.py
new file mode 100644
index 00000000..4ca5c828
--- /dev/null
+++ b/configs/_base_/datasets/imagenet_bs16_pil_bicubic_384.py
@@ -0,0 +1,57 @@
+# dataset settings
+dataset_type = 'ImageNet'
+data_preprocessor = dict(
+ # RGB format normalization parameters
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ # convert image from BGR to RGB
+ to_rgb=True,
+)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='RandomResizedCrop',
+ scale=384,
+ backend='pillow',
+ interpolation='bicubic'),
+ dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+ dict(type='PackClsInputs'),
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=384, backend='pillow', interpolation='bicubic'),
+ dict(type='PackClsInputs'),
+]
+
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=5,
+ dataset=dict(
+ type=dataset_type,
+ data_root='data/imagenet',
+ ann_file='meta/train.txt',
+ data_prefix='train',
+ pipeline=train_pipeline),
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ persistent_workers=True,
+)
+
+val_dataloader = dict(
+ batch_size=16,
+ num_workers=5,
+ dataset=dict(
+ type=dataset_type,
+ data_root='data/imagenet',
+ ann_file='meta/val.txt',
+ data_prefix='val',
+ pipeline=test_pipeline),
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ persistent_workers=True,
+)
+val_evaluator = dict(type='Accuracy', topk=(1, 5))
+
+# If you want standard test, please manually configure the test dataset
+test_dataloader = val_dataloader
+test_evaluator = val_evaluator
diff --git a/configs/_base_/datasets/imagenet_bs8_pil_bicubic_320.py b/configs/_base_/datasets/imagenet_bs8_pil_bicubic_320.py
new file mode 100644
index 00000000..f65e70d9
--- /dev/null
+++ b/configs/_base_/datasets/imagenet_bs8_pil_bicubic_320.py
@@ -0,0 +1,63 @@
+# dataset settings
+dataset_type = 'ImageNet'
+data_preprocessor = dict(
+ # RGB format normalization parameters
+ mean=[122.5, 122.5, 122.5],
+ std=[122.5, 122.5, 122.5],
+ # convert image from BGR to RGB
+ to_rgb=True,
+)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='RandomResizedCrop',
+ scale=320,
+ backend='pillow',
+ interpolation='bicubic'),
+ dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+ dict(type='PackClsInputs'),
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeEdge',
+ scale=int(320 / 224 * 256),
+ edge='short',
+ backend='pillow',
+ interpolation='bicubic'),
+ dict(type='CenterCrop', crop_size=320),
+ dict(type='PackClsInputs'),
+]
+
+train_dataloader = dict(
+ batch_size=8,
+ num_workers=5,
+ dataset=dict(
+ type=dataset_type,
+ data_root='data/imagenet',
+ ann_file='meta/train.txt',
+ data_prefix='train',
+ pipeline=train_pipeline),
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ persistent_workers=True,
+)
+
+val_dataloader = dict(
+ batch_size=8,
+ num_workers=5,
+ dataset=dict(
+ type=dataset_type,
+ data_root='data/imagenet',
+ ann_file='meta/val.txt',
+ data_prefix='val',
+ pipeline=test_pipeline),
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ persistent_workers=True,
+)
+val_evaluator = dict(type='Accuracy', topk=(1, 5))
+
+# If you want standard test, please manually configure the test dataset
+test_dataloader = val_dataloader
+test_evaluator = val_evaluator
diff --git a/configs/_base_/models/replknet-31B_in1k.py b/configs/_base_/models/replknet-31B_in1k.py
new file mode 100644
index 00000000..a6839537
--- /dev/null
+++ b/configs/_base_/models/replknet-31B_in1k.py
@@ -0,0 +1,25 @@
+from mmcls.models import build_classifier
+
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(
+ type='RepLKNet',
+ arch='31B',
+ out_indices=(3, ),
+ ),
+ neck=dict(type='GlobalAveragePooling'),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=1024,
+ loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+ topk=(1, 5),
+ ))
+
+if __name__ == '__main__':
+ # model.pop('type')
+ model = build_classifier(model)
+ model.eval()
+ print('------------------- training-time model -------------')
+ for i in model.state_dict().keys():
+ print(i)
diff --git a/configs/_base_/models/replknet-31L_in1k.py b/configs/_base_/models/replknet-31L_in1k.py
new file mode 100644
index 00000000..7830fb06
--- /dev/null
+++ b/configs/_base_/models/replknet-31L_in1k.py
@@ -0,0 +1,15 @@
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(
+ type='RepLKNet',
+ arch='31L',
+ out_indices=(3, ),
+ ),
+ neck=dict(type='GlobalAveragePooling'),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=1536,
+ loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+ topk=(1, 5),
+ ))
diff --git a/configs/_base_/models/replknet-XL_in1k.py b/configs/_base_/models/replknet-XL_in1k.py
new file mode 100644
index 00000000..b63f3459
--- /dev/null
+++ b/configs/_base_/models/replknet-XL_in1k.py
@@ -0,0 +1,15 @@
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(
+ type='RepLKNet',
+ arch='XL',
+ out_indices=(3, ),
+ ),
+ neck=dict(type='GlobalAveragePooling'),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=2048,
+ loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+ topk=(1, 5),
+ ))
diff --git a/configs/replknet/README.md b/configs/replknet/README.md
new file mode 100644
index 00000000..1714b108
--- /dev/null
+++ b/configs/replknet/README.md
@@ -0,0 +1,95 @@
+# RepLKNet
+
+> [Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs](https://arxiv.org/abs/2203.06717)
+
+
+
+## Abstract
+
+We revisit large kernel design in modern convolutional neural networks (CNNs). Inspired by recent advances in vision transformers (ViTs), in this paper, we demonstrate that using a few large convolutional kernels instead of a stack of small kernels could be a more powerful paradigm. We suggested five guidelines, e.g., applying re-parameterized large depth-wise convolutions, to design efficient highperformance large-kernel CNNs. Following the guidelines, we propose RepLKNet, a pure CNN architecture whose kernel size is as large as 31×31, in contrast to commonly used 3×3. RepLKNet greatly closes the performance gap between CNNs and ViTs, e.g., achieving comparable or superior results than Swin Transformer on ImageNet and a few typical downstream tasks, with lower latency. RepLKNet also shows nice scalability to big data and large models, obtaining 87.8% top-1 accuracy on ImageNet and 56.0% mIoU on ADE20K, which is very competitive among the state-of-the-arts with similar model sizes. Our study further reveals that, in contrast to small-kernel CNNs, large kernel CNNs have much larger effective receptive fields and higher shape bias rather than texture bias.
+
+
+

+
+
+## Results and models
+
+### ImageNet-1k
+
+| Model | Resolution | Pretrained Dataset | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
+| :------------: | :--------: | :----------------: | :-----------------------------: | :-----------------------------: | :-------: | :-------: | :------------------------------------: | :--------------------------------------: |
+| RepLKNet-31B\* | 224x224 | From Scratch | 79.9(train) \| 79.5 (deploy) | 15.6 (train) \| 15.4 (deploy) | 83.48 | 96.57 | [config (train)](./replknet-31B_32xb64_in1k.py) \| [config (deploy)](./deploy/replknet-31B-deploy_32xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_3rdparty_in1k_20221118-fd08e268.pth) |
+| RepLKNet-31B\* | 384x384 | From Scratch | 79.9(train) \| 79.5 (deploy) | 46.0 (train) \| 45.3 (deploy) | 84.84 | 97.34 | [config (train)](./replknet-31B_32xb64_in1k-384px.py) \| [config (deploy)](./deploy/replknet-31B-deploy_32xb64_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_3rdparty_in1k-384px_20221118-03a170ce.pth) |
+| RepLKNet-31B\* | 224x224 | ImageNet-21K | 79.9(train) \| 79.5 (deploy) | 15.6 (train) \| 15.4 (deploy) | 85.20 | 97.56 | [config (train)](./replknet-31B_32xb64_in1k.py) \| [config (deploy)](./deploy/replknet-31B-deploy_32xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_in21k-pre_3rdparty_in1k_20221118-54ed5c46.pth) |
+| RepLKNet-31B\* | 384x384 | ImageNet-21K | 79.9(train) \| 79.5 (deploy) | 46.0 (train) \| 45.3 (deploy) | 85.99 | 97.75 | [config (train)](./replknet-31B_32xb64_in1k-384px.py) \| [config (deploy)](./deploy/replknet-31B-deploy_32xb64_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_in21k-pre_3rdparty_in1k-384px_20221118-76c92b24.pth) |
+| RepLKNet-31L\* | 384x384 | ImageNet-21K | 172.7(train) \| 172.0 (deploy) | 97.2 (train) \| 97.0 (deploy) | 86.63 | 98.00 | [config (train)](./replknet-31L_32xb64_in1k-384px.py) \| [config (deploy)](./deploy/replknet-31L-deploy_32xb64_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31L_in21k-pre_3rdparty_in1k-384px_20221118-dc3fc07c.pth) |
+| RepLKNet-XL\* | 320x320 | MegData-73M | 335.4(train) \| 335.0 (deploy) | 129.6 (train) \| 129.0 (deploy) | 87.57 | 98.39 | [config (train)](./replknet-XL_32xb64_in1k-320px.py) \| [config (deploy)](./deploy/replknet-XL-deploy_32xb64_in1k-320px.py) | [model](https://download.openmmlab.com/mmclassification/v0/replknet/replknet-XL_meg73m-pre_3rdparty_in1k-320px_20221118-88259b1d.pth) |
+
+*Models with * are converted from the [official repo](https://github.com/DingXiaoH/RepVGG). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
+
+## How to use
+
+The checkpoints provided are all `training-time` models. Use the reparameterize tool to switch them to more efficient `inference-time` architecture, which not only has fewer parameters but also less calculations.
+
+### Use tool
+
+Use provided tool to reparameterize the given model and save the checkpoint:
+
+```bash
+python tools/convert_models/reparameterize_model.py ${CFG_PATH} ${SRC_CKPT_PATH} ${TARGET_CKPT_PATH}
+```
+
+`${CFG_PATH}` is the config file, `${SRC_CKPT_PATH}` is the source chenpoint file, `${TARGET_CKPT_PATH}` is the target deploy weight file path.
+
+To use reparameterized weights, the config file must switch to the deploy config files.
+
+```bash
+python tools/test.py ${Deploy_CFG} ${Deploy_Checkpoint} --metrics accuracy
+```
+
+### In the code
+
+Use `backbone.switch_to_deploy()` or `classificer.backbone.switch_to_deploy()` to switch to the deploy mode. For example:
+
+```python
+from mmcls.models import build_backbone
+
+backbone_cfg=dict(type='RepLKNet',arch='31B'),
+backbone = build_backbone(backbone_cfg)
+backbone.switch_to_deploy()
+```
+
+or
+
+```python
+from mmcls.models import build_classifier
+
+cfg = dict(
+ type='ImageClassifier',
+ backbone=dict(
+ type='RepLKNet',
+ arch='31B'),
+ neck=dict(type='GlobalAveragePooling'),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=1024,
+ loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+ topk=(1, 5),
+ ))
+
+classifier = build_classifier(cfg)
+classifier.backbone.switch_to_deploy()
+```
+
+## Citation
+
+```
+@inproceedings{ding2022scaling,
+ title={Scaling up your kernels to 31x31: Revisiting large kernel design in cnns},
+ author={Ding, Xiaohan and Zhang, Xiangyu and Han, Jungong and Ding, Guiguang},
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+ pages={11963--11975},
+ year={2022}
+}
+```
diff --git a/configs/replknet/deploy/replknet-31B-deploy_32xb64_in1k-384px.py b/configs/replknet/deploy/replknet-31B-deploy_32xb64_in1k-384px.py
new file mode 100644
index 00000000..a14fe63e
--- /dev/null
+++ b/configs/replknet/deploy/replknet-31B-deploy_32xb64_in1k-384px.py
@@ -0,0 +1,3 @@
+_base_ = '../replknet-31B_32xb64_in1k-384px.py'
+
+model = dict(backbone=dict(small_kernel_merged=True))
diff --git a/configs/replknet/deploy/replknet-31B-deploy_32xb64_in1k.py b/configs/replknet/deploy/replknet-31B-deploy_32xb64_in1k.py
new file mode 100644
index 00000000..4f92c494
--- /dev/null
+++ b/configs/replknet/deploy/replknet-31B-deploy_32xb64_in1k.py
@@ -0,0 +1,3 @@
+_base_ = '../replknet-31B_32xb64_in1k.py'
+
+model = dict(backbone=dict(small_kernel_merged=True))
diff --git a/configs/replknet/deploy/replknet-31L-deploy_32xb64_in1k-384px.py b/configs/replknet/deploy/replknet-31L-deploy_32xb64_in1k-384px.py
new file mode 100644
index 00000000..63e590f9
--- /dev/null
+++ b/configs/replknet/deploy/replknet-31L-deploy_32xb64_in1k-384px.py
@@ -0,0 +1,3 @@
+_base_ = '../replknet-31L_32xb64_in1k-384px.py'
+
+model = dict(backbone=dict(small_kernel_merged=True))
diff --git a/configs/replknet/deploy/replknet-XL-deploy_32xb64_in1k-320px.py b/configs/replknet/deploy/replknet-XL-deploy_32xb64_in1k-320px.py
new file mode 100644
index 00000000..a0a8ed5f
--- /dev/null
+++ b/configs/replknet/deploy/replknet-XL-deploy_32xb64_in1k-320px.py
@@ -0,0 +1,3 @@
+_base_ = '../replknet-XL_32xb64_in1k-320px.py'
+
+model = dict(backbone=dict(small_kernel_merged=True))
diff --git a/configs/replknet/metafile.yml b/configs/replknet/metafile.yml
new file mode 100644
index 00000000..05f19b79
--- /dev/null
+++ b/configs/replknet/metafile.yml
@@ -0,0 +1,129 @@
+Collections:
+ - Name: RepLKNet
+ Metadata:
+ Training Data: ImageNet-1k
+ Architecture:
+ - Large-Kernel Convolution
+ - VGG-style Neural Network
+ Paper:
+ URL: https://arxiv.org/abs/2203.06717
+ Title: 'Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs'
+ README: configs/replknet/README.md
+ Code:
+ URL: https://github.com/open-mmlab/mmclassification/blob/v1.0.0rc3/mmcls/models/backbones/replknet.py
+ Version: v1.0.0rc3
+
+Models:
+ - Name: replknet-31B_3rdparty_in1k
+ In Collection: RepLKNet
+ Config: configs/replknet/replknet-31B_32xb64_in1k.py
+ Metadata:
+ FLOPs: 15636547584
+ Parameters: 79864168
+ Results:
+ - Dataset: ImageNet-1k
+ Task: Image Classification
+ Metrics:
+ Top 1 Accuracy: 83.48
+ Top 5 Accuracy: 96.57
+ Weights: https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_3rdparty_in1k_20221118-fd08e268.pth
+ Converted From:
+ Weights: https://drive.google.com/u/0/uc?id=1azQUiCxK9feYVkkrPqwVPBtNsTzDrX7S&export=download
+ Code: https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/replknet.py
+
+ - Name: replknet-31B_3rdparty_in1k-384px
+ In Collection: RepLKNet
+ Config: configs/replknet/replknet-31B_32xb64_in1k-384px.py
+ Metadata:
+ FLOPs: 45952303104
+ Parameters: 79864168
+ Results:
+ - Dataset: ImageNet-1k
+ Task: Image Classification
+ Metrics:
+ Top 1 Accuracy: 84.84
+ Top 5 Accuracy: 97.34
+ Weights: https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_3rdparty_in1k-384px_20221118-03a170ce.pth
+ Converted From:
+ Weights: https://drive.google.com/u/0/uc?id=1vo-P3XB6mRLUeDzmgv90dOu73uCeLfZN&export=download
+ Code: https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/replknet.py
+
+ - Name: replknet-31B_in21k-pre_3rdparty_in1k
+ In Collection: RepLKNet
+ Config: configs/replknet/replknet-31B_32xb64_in1k.py
+ Metadata:
+ Training Data:
+ - ImageNet-21k
+ - ImageNet-1k
+ FLOPs: 15636547584
+ Parameters: 79864168
+ Results:
+ - Dataset: ImageNet-1k
+ Task: Image Classification
+ Metrics:
+ Top 1 Accuracy: 85.20
+ Top 5 Accuracy: 97.56
+ Weights: https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_in21k-pre_3rdparty_in1k_20221118-54ed5c46.pth
+ Converted From:
+ Weights: https://drive.google.com/u/0/uc?id=1DslZ2voXZQR1QoFY9KnbsHAeF84hzS0s&export=download
+ Code: https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/replknet.py
+
+ - Name: replknet-31B_in21k-pre_3rdparty_in1k-384px
+ In Collection: RepLKNet
+ Config: configs/replknet/replknet-31B_32xb64_in1k-384px.py
+ Metadata:
+ Training Data:
+ - ImageNet-21k
+ - ImageNet-1k
+ FLOPs: 45952303104
+ Parameters: 79864168
+ Results:
+ - Dataset: ImageNet-1k
+ Task: Image Classification
+ Metrics:
+ Top 1 Accuracy: 85.99
+ Top 5 Accuracy: 97.75
+ Weights: https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_in21k-pre_3rdparty_in1k-384px_20221118-76c92b24.pth
+ Converted From:
+ Weights: https://drive.google.com/u/0/uc?id=1Sc46BWdXXm2fVP-K_hKKU_W8vAB-0duX&export=download
+ Code: https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/replknet.py
+
+ - Name: replknet-31L_in21k-pre_3rdparty_in1k-384px
+ In Collection: RepLKNet
+ Config: configs/replknet/replknet-31L_32xb64_in1k-384px.py
+ Metadata:
+ Training Data:
+ - ImageNet-21k
+ - ImageNet-1k
+ FLOPs: 97240006656
+ Parameters: 172671016
+ Results:
+ - Dataset: ImageNet-1k
+ Task: Image Classification
+ Metrics:
+ Top 1 Accuracy: 86.63
+ Top 5 Accuracy: 98.00
+ Weights: https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31L_in21k-pre_3rdparty_in1k-384px_20221118-dc3fc07c.pth
+ Converted From:
+ Weights: https://drive.google.com/u/0/uc?id=1JYXoNHuRvC33QV1pmpzMTKEni1hpWfBl&export=download
+ Code: https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/replknet.py
+
+ - Name: replknet-XL_meg73m-pre_3rdparty_in1k-320px
+ In Collection: RepLKNet
+ Config: configs/replknet/replknet-XL_32xb64_in1k-320px.py
+ Metadata:
+ Training Data:
+ - MegData-73M
+ - ImageNet-1k
+ FLOPs: 129570201600
+ Parameters: 335435752
+ Results:
+ - Dataset: ImageNet-1k
+ Task: Image Classification
+ Metrics:
+ Top 1 Accuracy: 87.57
+ Top 5 Accuracy: 98.39
+ Weights: https://download.openmmlab.com/mmclassification/v0/replknet/replknet-XL_meg73m-pre_3rdparty_in1k-320px_20221118-88259b1d.pth
+ Converted From:
+ Weights: https://drive.google.com/u/0/uc?id=1tPC60El34GntXByIRHb-z-Apm4Y5LX1T&export=download
+ Code: https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/replknet.py
diff --git a/configs/replknet/replknet-31B_32xb64_in1k-384px.py b/configs/replknet/replknet-31B_32xb64_in1k-384px.py
new file mode 100644
index 00000000..4e714f34
--- /dev/null
+++ b/configs/replknet/replknet-31B_32xb64_in1k-384px.py
@@ -0,0 +1,12 @@
+_base_ = [
+ '../_base_/models/replknet-31B_in1k.py',
+ '../_base_/datasets/imagenet_bs16_pil_bicubic_384.py',
+ '../_base_/schedules/imagenet_bs256_coslr.py',
+ '../_base_/default_runtime.py'
+]
+
+# schedule settings
+param_scheduler = dict(
+ type='CosineAnnealingLR', T_max=300, by_epoch=True, begin=0, end=300)
+
+train_cfg = dict(by_epoch=True, max_epochs=300)
diff --git a/configs/replknet/replknet-31B_32xb64_in1k.py b/configs/replknet/replknet-31B_32xb64_in1k.py
new file mode 100644
index 00000000..cf06f2d8
--- /dev/null
+++ b/configs/replknet/replknet-31B_32xb64_in1k.py
@@ -0,0 +1,12 @@
+_base_ = [
+ '../_base_/models/replknet-31B_in1k.py',
+ '../_base_/datasets/imagenet_bs32_pil_bicubic.py',
+ '../_base_/schedules/imagenet_bs256_coslr.py',
+ '../_base_/default_runtime.py'
+]
+
+# schedule settings
+param_scheduler = dict(
+ type='CosineAnnealingLR', T_max=300, by_epoch=True, begin=0, end=300)
+
+train_cfg = dict(by_epoch=True, max_epochs=300)
diff --git a/configs/replknet/replknet-31L_32xb64_in1k-384px.py b/configs/replknet/replknet-31L_32xb64_in1k-384px.py
new file mode 100644
index 00000000..8cdab249
--- /dev/null
+++ b/configs/replknet/replknet-31L_32xb64_in1k-384px.py
@@ -0,0 +1,12 @@
+_base_ = [
+ '../_base_/models/replknet-31L_in1k.py',
+ '../_base_/datasets/imagenet_bs16_pil_bicubic_384.py',
+ '../_base_/schedules/imagenet_bs256_coslr.py',
+ '../_base_/default_runtime.py'
+]
+
+# schedule settings
+param_scheduler = dict(
+ type='CosineAnnealingLR', T_max=300, by_epoch=True, begin=0, end=300)
+
+train_cfg = dict(by_epoch=True, max_epochs=300)
diff --git a/configs/replknet/replknet-XL_32xb64_in1k-320px.py b/configs/replknet/replknet-XL_32xb64_in1k-320px.py
new file mode 100644
index 00000000..9b0aab11
--- /dev/null
+++ b/configs/replknet/replknet-XL_32xb64_in1k-320px.py
@@ -0,0 +1,12 @@
+_base_ = [
+ '../_base_/models/replknet-XL_in1k.py',
+ '../_base_/datasets/imagenet_bs8_pil_bicubic_320.py',
+ '../_base_/schedules/imagenet_bs256_coslr.py',
+ '../_base_/default_runtime.py'
+]
+
+# schedule settings
+param_scheduler = dict(
+ type='CosineAnnealingLR', T_max=300, by_epoch=True, begin=0, end=300)
+
+train_cfg = dict(by_epoch=True, max_epochs=300)
diff --git a/docs/en/api/models.rst b/docs/en/api/models.rst
index 8442b7a2..f72539ea 100644
--- a/docs/en/api/models.rst
+++ b/docs/en/api/models.rst
@@ -85,6 +85,7 @@ Backbones
PCPVT
PoolFormer
RegNet
+ RepLKNet
RepMLPNet
RepVGG
Res2Net
diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py
index 35f410b7..6bfe8af0 100644
--- a/mmcls/models/backbones/__init__.py
+++ b/mmcls/models/backbones/__init__.py
@@ -23,6 +23,7 @@ from .mobilevit import MobileViT
from .mvit import MViT
from .poolformer import PoolFormer
from .regnet import RegNet
+from .replknet import RepLKNet
from .repmlp import RepMLPNet
from .repvgg import RepVGG
from .res2net import Res2Net
@@ -82,6 +83,7 @@ __all__ = [
'CSPResNet',
'CSPResNeXt',
'CSPNet',
+ 'RepLKNet',
'RepMLPNet',
'PoolFormer',
'DenseNet',
diff --git a/mmcls/models/backbones/replknet.py b/mmcls/models/backbones/replknet.py
new file mode 100644
index 00000000..3611c8b7
--- /dev/null
+++ b/mmcls/models/backbones/replknet.py
@@ -0,0 +1,668 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+from mmcv.cnn import build_activation_layer, build_norm_layer
+from mmcv.cnn.bricks import DropPath
+from mmengine.model import BaseModule
+from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
+
+from mmcls.registry import MODELS
+from .base_backbone import BaseBackbone
+
+
+def conv_bn(in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ groups,
+ dilation=1,
+ norm_cfg=dict(type='BN')):
+ """Construct a sequential conv and bn.
+
+ Args:
+ in_channels (int): Dimension of input features.
+ out_channels (int): Dimension of output features.
+ kernel_size (int): kernel_size of the convolution.
+ stride (int): stride of the convolution.
+ padding (int): stride of the convolution.
+ groups (int): groups of the convolution.
+ dilation (int): dilation of the convolution. Default to 1.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default to ``dict(type='BN', requires_grad=True)``.
+
+ Returns:
+ nn.Sequential(): A conv layer and a batch norm layer.
+ """
+ if padding is None:
+ padding = kernel_size // 2
+ result = nn.Sequential()
+ result.add_module(
+ 'conv',
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=False))
+ result.add_module('bn', build_norm_layer(norm_cfg, out_channels)[1])
+ return result
+
+
+def conv_bn_relu(in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ groups,
+ dilation=1):
+ """Construct a sequential conv, bn and relu.
+
+ Args:
+ in_channels (int): Dimension of input features.
+ out_channels (int): Dimension of output features.
+ kernel_size (int): kernel_size of the convolution.
+ stride (int): stride of the convolution.
+ padding (int): stride of the convolution.
+ groups (int): groups of the convolution.
+ dilation (int): dilation of the convolution. Default to 1.
+
+ Returns:
+ nn.Sequential(): A conv layer, batch norm layer and a relu function.
+ """
+
+ if padding is None:
+ padding = kernel_size // 2
+ result = conv_bn(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ dilation=dilation)
+ result.add_module('nonlinear', nn.ReLU())
+ return result
+
+
+def fuse_bn(conv, bn):
+ """Fuse the parameters in a branch with a conv and bn.
+
+ Args:
+ conv (nn.Conv2d): The convolution module to fuse.
+ bn (nn.BatchNorm2d): The batch normalization to fuse.
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: The parameters obtained after
+ fusing the parameters of conv and bn in one branch.
+ The first element is the weight and the second is the bias.
+ """
+ kernel = conv.weight
+ running_mean = bn.running_mean
+ running_var = bn.running_var
+ gamma = bn.weight
+ beta = bn.bias
+ eps = bn.eps
+ std = (running_var + eps).sqrt()
+ t = (gamma / std).reshape(-1, 1, 1, 1)
+ return kernel * t, beta - running_mean * gamma / std
+
+
+class ReparamLargeKernelConv(BaseModule):
+ """Super large kernel implemented by with large convolutions.
+
+ Input: Tensor with shape [B, C, H, W].
+ Output: Tensor with shape [B, C, H, W].
+
+ Args:
+ in_channels (int): Dimension of input features.
+ out_channels (int): Dimension of output features.
+ kernel_size (int): kernel_size of the large convolution.
+ stride (int): stride of the large convolution.
+ groups (int): groups of the large convolution.
+ small_kernel (int): kernel_size of the small convolution.
+ small_kernel_merged (bool): Whether to switch the model structure to
+ deployment mode (merge the small kernel to the large kernel).
+ Default to False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Defaults to None
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ groups,
+ small_kernel,
+ small_kernel_merged=False,
+ init_cfg=None):
+ super(ReparamLargeKernelConv, self).__init__(init_cfg)
+ self.kernel_size = kernel_size
+ self.small_kernel = small_kernel
+ self.small_kernel_merged = small_kernel_merged
+ # We assume the conv does not change the feature map size,
+ # so padding = k//2.
+ # Otherwise, you may configure padding as you wish,
+ # and change the padding of small_conv accordingly.
+ padding = kernel_size // 2
+ if small_kernel_merged:
+ self.lkb_reparam = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=1,
+ groups=groups,
+ bias=True)
+ else:
+ self.lkb_origin = conv_bn(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=1,
+ groups=groups)
+ if small_kernel is not None:
+ assert small_kernel <= kernel_size
+ self.small_conv = conv_bn(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=small_kernel,
+ stride=stride,
+ padding=small_kernel // 2,
+ groups=groups,
+ dilation=1)
+
+ def forward(self, inputs):
+ if hasattr(self, 'lkb_reparam'):
+ out = self.lkb_reparam(inputs)
+ else:
+ out = self.lkb_origin(inputs)
+ if hasattr(self, 'small_conv'):
+ out += self.small_conv(inputs)
+ return out
+
+ def get_equivalent_kernel_bias(self):
+ eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
+ if hasattr(self, 'small_conv'):
+ small_k, small_b = fuse_bn(self.small_conv.conv,
+ self.small_conv.bn)
+ eq_b += small_b
+ # add to the central part
+ eq_k += nn.functional.pad(
+ small_k, [(self.kernel_size - self.small_kernel) // 2] * 4)
+ return eq_k, eq_b
+
+ def merge_kernel(self):
+ """Switch the model structure from training mode to deployment mode."""
+ if self.small_kernel_merged:
+ return
+ eq_k, eq_b = self.get_equivalent_kernel_bias()
+ self.lkb_reparam = nn.Conv2d(
+ in_channels=self.lkb_origin.conv.in_channels,
+ out_channels=self.lkb_origin.conv.out_channels,
+ kernel_size=self.lkb_origin.conv.kernel_size,
+ stride=self.lkb_origin.conv.stride,
+ padding=self.lkb_origin.conv.padding,
+ dilation=self.lkb_origin.conv.dilation,
+ groups=self.lkb_origin.conv.groups,
+ bias=True)
+
+ self.lkb_reparam.weight.data = eq_k
+ self.lkb_reparam.bias.data = eq_b
+ self.__delattr__('lkb_origin')
+ if hasattr(self, 'small_conv'):
+ self.__delattr__('small_conv')
+
+ self.small_kernel_merged = True
+
+
+class ConvFFN(BaseModule):
+ """Mlp implemented by with 1*1 convolutions.
+
+ Input: Tensor with shape [B, C, H, W].
+ Output: Tensor with shape [B, C, H, W].
+
+ Args:
+ in_channels (int): Dimension of input features.
+ internal_channels (int): Dimension of hidden features.
+ out_channels (int): Dimension of output features.
+ drop_path (float): Stochastic depth rate. Defaults to 0.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default to ``dict(type='BN', requires_grad=True)``.
+ act_cfg (dict): The config dict for activation between pointwise
+ convolution. Defaults to ``dict(type='GELU')``.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ in_channels,
+ internal_channels,
+ out_channels,
+ drop_path,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='GELU'),
+ init_cfg=None):
+ super(ConvFFN, self).__init__(init_cfg)
+ self.drop_path = DropPath(
+ drop_prob=drop_path) if drop_path > 0. else nn.Identity()
+ self.preffn_bn = build_norm_layer(norm_cfg, in_channels)[1]
+ self.pw1 = conv_bn(
+ in_channels=in_channels,
+ out_channels=internal_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1)
+ self.pw2 = conv_bn(
+ in_channels=internal_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1)
+ self.nonlinear = build_activation_layer(act_cfg)
+
+ def forward(self, x):
+ out = self.preffn_bn(x)
+ out = self.pw1(out)
+ out = self.nonlinear(out)
+ out = self.pw2(out)
+ return x + self.drop_path(out)
+
+
+class RepLKBlock(BaseModule):
+ """RepLKBlock for RepLKNet backbone.
+
+ Args:
+ in_channels (int): The input channels of the block.
+ dw_channels (int): The intermediate channels of the block,
+ i.e., input channels of the large kernel convolution.
+ block_lk_size (int): size of the super large kernel. Defaults: 31.
+ small_kernel (int): size of the parallel small kernel. Defaults: 5.
+ drop_path (float): Stochastic depth rate. Defaults: 0.
+ small_kernel_merged (bool): Whether to switch the model structure to
+ deployment mode (merge the small kernel to the large kernel).
+ Default to False.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default to ``dict(type='BN', requires_grad=True)``.
+ act_cfg (dict): Config dict for activation layer.
+ Default to ``dict(type='ReLU')``.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default to None
+ """
+
+ def __init__(self,
+ in_channels,
+ dw_channels,
+ block_lk_size,
+ small_kernel,
+ drop_path,
+ small_kernel_merged=False,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ init_cfg=None):
+ super(RepLKBlock, self).__init__(init_cfg)
+ self.pw1 = conv_bn_relu(in_channels, dw_channels, 1, 1, 0, groups=1)
+ self.pw2 = conv_bn(dw_channels, in_channels, 1, 1, 0, groups=1)
+ self.large_kernel = ReparamLargeKernelConv(
+ in_channels=dw_channels,
+ out_channels=dw_channels,
+ kernel_size=block_lk_size,
+ stride=1,
+ groups=dw_channels,
+ small_kernel=small_kernel,
+ small_kernel_merged=small_kernel_merged)
+ self.lk_nonlinear = build_activation_layer(act_cfg)
+ self.prelkb_bn = build_norm_layer(norm_cfg, in_channels)[1]
+ self.drop_path = DropPath(
+ drop_prob=drop_path) if drop_path > 0. else nn.Identity()
+ # print('drop path:', self.drop_path)
+
+ def forward(self, x):
+ out = self.prelkb_bn(x)
+ out = self.pw1(out)
+ out = self.large_kernel(out)
+ out = self.lk_nonlinear(out)
+ out = self.pw2(out)
+ return x + self.drop_path(out)
+
+
+class RepLKNetStage(BaseModule):
+ """
+ generate RepLKNet blocks for a stage
+ return: RepLKNet blocks
+
+ Args:
+ channels (int): The input channels of the stage.
+ num_blocks (int): The number of blocks of the stage.
+ stage_lk_size (int): size of the super large kernel. Defaults: 31.
+ drop_path (float): Stochastic depth rate. Defaults: 0.
+ small_kernel (int): size of the parallel small kernel. Defaults: 5.
+ dw_ratio (float): The intermediate channels
+ expansion ratio of the block. Defaults: 1.
+ ffn_ratio (float): Mlp expansion ratio. Defaults to 4.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default to False.
+ small_kernel_merged (bool): Whether to switch the model structure to
+ deployment mode (merge the small kernel to the large kernel).
+ Default to False.
+ norm_intermediate_features (bool): Construct and config norm layer
+ or not.
+ Using True will normalize the intermediate features for
+ downstream dense prediction tasks.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default to ``dict(type='BN', requires_grad=True)``.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default to None
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_blocks,
+ stage_lk_size,
+ drop_path,
+ small_kernel,
+ dw_ratio=1,
+ ffn_ratio=4,
+ with_cp=False, # train with torch.utils.checkpoint to save memory
+ small_kernel_merged=False,
+ norm_intermediate_features=False,
+ norm_cfg=dict(type='BN'),
+ init_cfg=None):
+ super(RepLKNetStage, self).__init__(init_cfg)
+ self.with_cp = with_cp
+ blks = []
+ for i in range(num_blocks):
+ block_drop_path = drop_path[i] if isinstance(drop_path,
+ list) else drop_path
+ # Assume all RepLK Blocks within a stage share the same lk_size.
+ # You may tune it on your own model.
+ replk_block = RepLKBlock(
+ in_channels=channels,
+ dw_channels=int(channels * dw_ratio),
+ block_lk_size=stage_lk_size,
+ small_kernel=small_kernel,
+ drop_path=block_drop_path,
+ small_kernel_merged=small_kernel_merged)
+ convffn_block = ConvFFN(
+ in_channels=channels,
+ internal_channels=int(channels * ffn_ratio),
+ out_channels=channels,
+ drop_path=block_drop_path)
+ blks.append(replk_block)
+ blks.append(convffn_block)
+ self.blocks = nn.ModuleList(blks)
+ if norm_intermediate_features:
+ self.norm = build_norm_layer(norm_cfg, channels)[1]
+ else:
+ self.norm = nn.Identity()
+
+ def forward(self, x):
+ for blk in self.blocks:
+ if self.with_cp:
+ x = checkpoint.checkpoint(blk, x) # Save training memory
+ else:
+ x = blk(x)
+ return x
+
+
+@MODELS.register_module()
+class RepLKNet(BaseBackbone):
+ """RepLKNet backbone.
+
+ A PyTorch impl of :
+ `Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs
+ `_
+
+ Args:
+ arch (str | dict): The parameter of RepLKNet.
+ If it's a dict, it should contain the following keys:
+
+ - large_kernel_sizes (Sequence[int]):
+ Large kernel size in each stage.
+ - layers (Sequence[int]): Number of blocks in each stage.
+ - channels (Sequence[int]): Number of channels in each stage.
+ - small_kernel (int): size of the parallel small kernel.
+ - dw_ratio (float): The intermediate channels
+ expansion ratio of the block.
+ in_channels (int): Number of input image channels. Default to 3.
+ ffn_ratio (float): Mlp expansion ratio. Defaults to 4.
+ out_indices (Sequence[int]): Output from which stages.
+ Default to (3, ).
+ strides (Sequence[int]): Strides of the first block of each stage.
+ Default to (2, 2, 2, 2).
+ dilations (Sequence[int]): Dilation of each stage.
+ Default to (1, 1, 1, 1).
+ frozen_stages (int): Stages to be frozen
+ (all param fixed). -1 means not freezing any parameters.
+ Default to -1.
+ conv_cfg (dict | None): The config dict for conv layers.
+ Default to None.
+ norm_cfg (dict): The config dict for norm layers.
+ Default to ``dict(type='BN')``.
+ act_cfg (dict): Config dict for activation layer.
+ Default to ``dict(type='ReLU')``.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default to False.
+ deploy (bool): Whether to switch the model structure to deployment
+ mode. Default to False.
+ norm_intermediate_features (bool): Construct and
+ config norm layer or not.
+ Using True will normalize the intermediate features
+ for downstream dense prediction tasks.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default to False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ arch_settings = {
+ '31B':
+ dict(
+ large_kernel_sizes=[31, 29, 27, 13],
+ layers=[2, 2, 18, 2],
+ channels=[128, 256, 512, 1024],
+ small_kernel=5,
+ dw_ratio=1),
+ '31L':
+ dict(
+ large_kernel_sizes=[31, 29, 27, 13],
+ layers=[2, 2, 18, 2],
+ channels=[192, 384, 768, 1536],
+ small_kernel=5,
+ dw_ratio=1),
+ 'XL':
+ dict(
+ large_kernel_sizes=[27, 27, 27, 13],
+ layers=[2, 2, 18, 2],
+ channels=[256, 512, 1024, 2048],
+ small_kernel=None,
+ dw_ratio=1.5),
+ }
+
+ def __init__(self,
+ arch,
+ in_channels=3,
+ ffn_ratio=4,
+ out_indices=(3, ),
+ strides=(2, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ with_cp=False,
+ drop_path_rate=0.3,
+ small_kernel_merged=False,
+ norm_intermediate_features=False,
+ norm_eval=False,
+ init_cfg=[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]):
+ super(RepLKNet, self).__init__(init_cfg)
+
+ if isinstance(arch, str):
+ assert arch in self.arch_settings, \
+ f'"arch": "{arch}" is not one of the arch_settings'
+ arch = self.arch_settings[arch]
+ elif not isinstance(arch, dict):
+ raise TypeError('Expect "arch" to be either a string '
+ f'or a dict, got {type(arch)}')
+
+ assert len(arch['layers']) == len(
+ arch['channels']) == len(strides) == len(dilations)
+ assert max(out_indices) < len(arch['layers'])
+
+ self.arch = arch
+ self.in_channels = in_channels
+ self.out_indices = out_indices
+ self.strides = strides
+ self.dilations = dilations
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.with_cp = with_cp
+ self.drop_path_rate = drop_path_rate
+ self.small_kernel_merged = small_kernel_merged
+ self.norm_eval = norm_eval
+ self.norm_intermediate_features = norm_intermediate_features
+
+ self.out_indices = out_indices
+
+ base_width = self.arch['channels'][0]
+ self.norm_intermediate_features = norm_intermediate_features
+ self.num_stages = len(self.arch['layers'])
+ self.stem = nn.ModuleList([
+ conv_bn_relu(
+ in_channels=in_channels,
+ out_channels=base_width,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ groups=1),
+ conv_bn_relu(
+ in_channels=base_width,
+ out_channels=base_width,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=base_width),
+ conv_bn_relu(
+ in_channels=base_width,
+ out_channels=base_width,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1),
+ conv_bn_relu(
+ in_channels=base_width,
+ out_channels=base_width,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ groups=base_width)
+ ])
+ # stochastic depth. We set block-wise drop-path rate.
+ # The higher level blocks are more likely to be dropped.
+ # This implementation follows Swin.
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate,
+ sum(self.arch['layers']))
+ ]
+ self.stages = nn.ModuleList()
+ self.transitions = nn.ModuleList()
+ for stage_idx in range(self.num_stages):
+ layer = RepLKNetStage(
+ channels=self.arch['channels'][stage_idx],
+ num_blocks=self.arch['layers'][stage_idx],
+ stage_lk_size=self.arch['large_kernel_sizes'][stage_idx],
+ drop_path=dpr[sum(self.arch['layers'][:stage_idx]
+ ):sum(self.arch['layers'][:stage_idx + 1])],
+ small_kernel=self.arch['small_kernel'],
+ dw_ratio=self.arch['dw_ratio'],
+ ffn_ratio=ffn_ratio,
+ with_cp=with_cp,
+ small_kernel_merged=small_kernel_merged,
+ norm_intermediate_features=(stage_idx in out_indices))
+ self.stages.append(layer)
+ if stage_idx < len(self.arch['layers']) - 1:
+ transition = nn.Sequential(
+ conv_bn_relu(
+ self.arch['channels'][stage_idx],
+ self.arch['channels'][stage_idx + 1],
+ 1,
+ 1,
+ 0,
+ groups=1),
+ conv_bn_relu(
+ self.arch['channels'][stage_idx + 1],
+ self.arch['channels'][stage_idx + 1],
+ 3,
+ stride=2,
+ padding=1,
+ groups=self.arch['channels'][stage_idx + 1]))
+ self.transitions.append(transition)
+
+ def forward_features(self, x):
+ x = self.stem[0](x)
+ for stem_layer in self.stem[1:]:
+ if self.with_cp:
+ x = checkpoint.checkpoint(stem_layer, x) # save memory
+ else:
+ x = stem_layer(x)
+
+ # Need the intermediate feature maps
+ outs = []
+ for stage_idx in range(self.num_stages):
+ x = self.stages[stage_idx](x)
+ if stage_idx in self.out_indices:
+ outs.append(self.stages[stage_idx].norm(x))
+ # For RepLKNet-XL normalize the features
+ # before feeding them into the heads
+ if stage_idx < self.num_stages - 1:
+ x = self.transitions[stage_idx](x)
+ return outs
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ return tuple(x)
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.stem.eval()
+ for param in self.stem.parameters():
+ param.requires_grad = False
+ for i in range(self.frozen_stages):
+ stage = self.stages[i]
+ stage.eval()
+ for param in stage.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(RepLKNet, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+ def switch_to_deploy(self):
+ for m in self.modules():
+ if hasattr(m, 'merge_kernel'):
+ m.merge_kernel()
+ self.small_kernel_merged = True
diff --git a/model-index.yml b/model-index.yml
index c190b5a8..fbcaebcb 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -38,3 +38,4 @@ Import:
- configs/hornet/metafile.yml
- configs/mobilevit/metafile.yml
- configs/davit/metafile.yml
+ - configs/replknet/metafile.yml
diff --git a/tests/test_models/test_backbones/test_replknet.py b/tests/test_models/test_backbones/test_replknet.py
new file mode 100644
index 00000000..a7ad48ab
--- /dev/null
+++ b/tests/test_models/test_backbones/test_replknet.py
@@ -0,0 +1,304 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import tempfile
+
+import pytest
+import torch
+from mmengine.runner import load_checkpoint, save_checkpoint
+from torch import nn
+from torch.nn.modules import GroupNorm
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmcls.models.backbones import RepLKNet
+from mmcls.models.backbones.replknet import ReparamLargeKernelConv
+
+
+def check_norm_state(modules, train_state):
+ """Check if norm layer is in correct train state."""
+ for mod in modules:
+ if isinstance(mod, _BatchNorm):
+ if mod.training != train_state:
+ return False
+ return True
+
+
+def is_norm(modules):
+ """Check if is one of the norms."""
+ if isinstance(modules, (GroupNorm, _BatchNorm)):
+ return True
+ return False
+
+
+def is_replk_block(modules):
+ if isinstance(modules, ReparamLargeKernelConv):
+ return True
+ return False
+
+
+def test_replknet_replkblock():
+ # Test ReparamLargeKernelConv with in_channels != out_channels,
+ # kernel_size = 31, stride = 1, groups=in_channels, small_kernel = 5
+ block = ReparamLargeKernelConv(
+ 5, 10, kernel_size=31, stride=1, groups=5, small_kernel=5)
+ block.eval()
+ x = torch.randn(1, 5, 64, 64)
+ x_out_not_deploy = block(x)
+ assert block.small_kernel <= block.kernel_size
+ assert not hasattr(block, 'lkb_reparam')
+ assert hasattr(block, 'lkb_origin')
+ assert hasattr(block, 'small_conv')
+ assert x_out_not_deploy.shape == torch.Size((1, 10, 64, 64))
+ block.merge_kernel()
+ assert block.small_kernel_merged is True
+ x_out_deploy = block(x)
+ assert x_out_deploy.shape == torch.Size((1, 10, 64, 64))
+ assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)
+
+ # Test ReparamLargeKernelConv with in_channels == out_channels,
+ # kernel_size = 31, stride = 1, groups=in_channels, small_kernel = 5
+ block = ReparamLargeKernelConv(
+ 12, 12, kernel_size=31, stride=1, groups=12, small_kernel=5)
+ block.eval()
+ x = torch.randn(1, 12, 64, 64)
+ x_out_not_deploy = block(x)
+ assert block.small_kernel <= block.kernel_size
+ assert not hasattr(block, 'lkb_reparam')
+ assert hasattr(block, 'lkb_origin')
+ assert hasattr(block, 'small_conv')
+ assert x_out_not_deploy.shape == torch.Size((1, 12, 64, 64))
+ block.merge_kernel()
+ assert block.small_kernel_merged is True
+ x_out_deploy = block(x)
+ assert x_out_deploy.shape == torch.Size((1, 12, 64, 64))
+ assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)
+
+ # Test ReparamLargeKernelConv with in_channels == out_channels,
+ # kernel_size = 31, stride = 2, groups=in_channels, small_kernel = 5
+ block = ReparamLargeKernelConv(
+ 16, 16, kernel_size=31, stride=2, groups=16, small_kernel=5)
+ block.eval()
+ x = torch.randn(1, 16, 64, 64)
+ x_out_not_deploy = block(x)
+ assert block.small_kernel <= block.kernel_size
+ assert not hasattr(block, 'lkb_reparam')
+ assert hasattr(block, 'lkb_origin')
+ assert hasattr(block, 'small_conv')
+ assert x_out_not_deploy.shape == torch.Size((1, 16, 32, 32))
+ block.merge_kernel()
+ assert block.small_kernel_merged is True
+ x_out_deploy = block(x)
+ assert x_out_deploy.shape == torch.Size((1, 16, 32, 32))
+ assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)
+
+ # Test ReparamLargeKernelConv with in_channels == out_channels,
+ # kernel_size = 27, stride = 1, groups=in_channels, small_kernel = 5
+ block = ReparamLargeKernelConv(
+ 12, 12, kernel_size=27, stride=1, groups=12, small_kernel=5)
+ block.eval()
+ x = torch.randn(1, 12, 48, 48)
+ x_out_not_deploy = block(x)
+ assert block.small_kernel <= block.kernel_size
+ assert not hasattr(block, 'lkb_reparam')
+ assert hasattr(block, 'lkb_origin')
+ assert hasattr(block, 'small_conv')
+ assert x_out_not_deploy.shape == torch.Size((1, 12, 48, 48))
+ block.merge_kernel()
+ assert block.small_kernel_merged is True
+ x_out_deploy = block(x)
+ assert x_out_deploy.shape == torch.Size((1, 12, 48, 48))
+ assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)
+
+ # Test ReparamLargeKernelConv with in_channels == out_channels,
+ # kernel_size = 31, stride = 1, groups=in_channels, small_kernel = 7
+ block = ReparamLargeKernelConv(
+ 12, 12, kernel_size=31, stride=1, groups=12, small_kernel=7)
+ block.eval()
+ x = torch.randn(1, 12, 64, 64)
+ x_out_not_deploy = block(x)
+ assert block.small_kernel <= block.kernel_size
+ assert not hasattr(block, 'lkb_reparam')
+ assert hasattr(block, 'lkb_origin')
+ assert hasattr(block, 'small_conv')
+ assert x_out_not_deploy.shape == torch.Size((1, 12, 64, 64))
+ block.merge_kernel()
+ assert block.small_kernel_merged is True
+ x_out_deploy = block(x)
+ assert x_out_deploy.shape == torch.Size((1, 12, 64, 64))
+ assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)
+
+ # Test ReparamLargeKernelConv with deploy == True
+ block = ReparamLargeKernelConv(
+ 8,
+ 8,
+ kernel_size=31,
+ stride=1,
+ groups=8,
+ small_kernel=5,
+ small_kernel_merged=True)
+ assert isinstance(block.lkb_reparam, nn.Conv2d)
+ assert not hasattr(block, 'lkb_origin')
+ assert not hasattr(block, 'small_conv')
+ x = torch.randn(1, 8, 48, 48)
+ x_out = block(x)
+ assert x_out.shape == torch.Size((1, 8, 48, 48))
+
+
+def test_replknet_backbone():
+ with pytest.raises(TypeError):
+ # arch must be str or dict
+ RepLKNet(arch=[4, 6, 16, 1])
+
+ with pytest.raises(AssertionError):
+ # arch must in arch_settings
+ RepLKNet(arch='31C')
+
+ with pytest.raises(KeyError):
+ # arch must have num_blocks and width_factor
+ arch = dict(large_kernel_sizes=[31, 29, 27, 13])
+ RepLKNet(arch=arch)
+
+ with pytest.raises(KeyError):
+ # arch must have num_blocks and width_factor
+ arch = dict(large_kernel_sizes=[31, 29, 27, 13], layers=[2, 2, 18, 2])
+ RepLKNet(arch=arch)
+
+ with pytest.raises(KeyError):
+ # arch must have num_blocks and width_factor
+ arch = dict(
+ large_kernel_sizes=[31, 29, 27, 13],
+ layers=[2, 2, 18, 2],
+ channels=[128, 256, 512, 1024])
+ RepLKNet(arch=arch)
+
+ # len(arch['large_kernel_sizes']) == arch['layers'])
+ # == len(arch['channels'])
+ # == len(strides) == len(dilations)
+ with pytest.raises(AssertionError):
+ arch = dict(
+ large_kernel_sizes=[31, 29, 27, 13],
+ layers=[2, 2, 18, 2],
+ channels=[128, 256, 1024],
+ small_kernel=5,
+ dw_ratio=1)
+ RepLKNet(arch=arch)
+
+ # len(strides) must equal to 4
+ with pytest.raises(AssertionError):
+ RepLKNet('31B', strides=(2, 2, 2))
+
+ # len(dilations) must equal to 4
+ with pytest.raises(AssertionError):
+ RepLKNet('31B', strides=(2, 2, 2, 2), dilations=(1, 1, 1))
+
+ # max(out_indices) < len(arch['num_blocks'])
+ with pytest.raises(AssertionError):
+ RepLKNet('31B', out_indices=(5, ))
+
+ # Test RepLKNet norm state
+ model = RepLKNet('31B')
+ model.train()
+ assert check_norm_state(model.modules(), True)
+
+ # Test RepLKNet with first stage frozen
+ frozen_stages = 1
+ model = RepLKNet('31B', frozen_stages=frozen_stages)
+ model.train()
+ for param in model.stem.parameters():
+ assert param.requires_grad is False
+ for i in range(0, frozen_stages):
+ stage = model.stages[i]
+ for mod in stage.modules():
+ if isinstance(mod, _BatchNorm):
+ assert mod.training is False
+ for param in stage.parameters():
+ assert param.requires_grad is False
+
+ # Test RepLKNet with norm_eval
+ model = RepLKNet('31B', norm_eval=True)
+ model.train()
+ assert check_norm_state(model.modules(), False)
+
+ # Test RepLKNet forward with layer 3 forward
+ model = RepLKNet('31B', out_indices=(3, ))
+ model.init_weights()
+ model.train()
+
+ for m in model.modules():
+ if is_norm(m):
+ assert isinstance(m, _BatchNorm)
+
+ imgs = torch.randn(1, 3, 224, 224)
+ feat = model(imgs)
+ assert isinstance(feat, tuple)
+ assert len(feat) == 1
+ assert isinstance(feat[0], torch.Tensor)
+ assert feat[0].shape == torch.Size((1, 1024, 7, 7))
+
+ # Test RepLKNet forward
+ model_test_settings = [
+ dict(model_name='31B', out_sizes=(128, 256, 512, 1024)),
+ # dict(model_name='31L', out_sizes=(192, 384, 768, 1536)),
+ # dict(model_name='XL', out_sizes=(256, 512, 1024, 2048))
+ ]
+
+ choose_models = ['31B']
+ # Test RepLKNet model forward
+ for model_test_setting in model_test_settings:
+ if model_test_setting['model_name'] not in choose_models:
+ continue
+ model = RepLKNet(
+ model_test_setting['model_name'], out_indices=(0, 1, 2, 3))
+ model.init_weights()
+
+ # Test Norm
+ for m in model.modules():
+ if is_norm(m):
+ assert isinstance(m, _BatchNorm)
+
+ model.train()
+ imgs = torch.randn(1, 3, 224, 224)
+ feat = model(imgs)
+ assert feat[0].shape == torch.Size(
+ (1, model_test_setting['out_sizes'][0], 56, 56))
+ assert feat[1].shape == torch.Size(
+ (1, model_test_setting['out_sizes'][1], 28, 28))
+ assert feat[2].shape == torch.Size(
+ (1, model_test_setting['out_sizes'][2], 14, 14))
+ assert feat[3].shape == torch.Size(
+ (1, model_test_setting['out_sizes'][3], 7, 7))
+
+ # Test eval of "train" mode and "deploy" mode
+ gap = nn.AdaptiveAvgPool2d(output_size=(1))
+ fc = nn.Linear(model_test_setting['out_sizes'][3], 10)
+ model.eval()
+ feat = model(imgs)
+ pred = fc(gap(feat[3]).flatten(1))
+ model.switch_to_deploy()
+ for m in model.modules():
+ if isinstance(m, ReparamLargeKernelConv):
+ assert m.small_kernel_merged is True
+ feat_deploy = model(imgs)
+ pred_deploy = fc(gap(feat_deploy[3]).flatten(1))
+ for i in range(4):
+ torch.allclose(feat[i], feat_deploy[i])
+ torch.allclose(pred, pred_deploy)
+
+
+def test_replknet_load():
+ # Test output before and load from deploy checkpoint
+ model = RepLKNet('31B', out_indices=(0, 1, 2, 3))
+ inputs = torch.randn((1, 3, 224, 224))
+ ckpt_path = os.path.join(tempfile.gettempdir(), 'ckpt.pth')
+ model.switch_to_deploy()
+ model.eval()
+ outputs = model(inputs)
+
+ model_deploy = RepLKNet(
+ '31B', out_indices=(0, 1, 2, 3), small_kernel_merged=True)
+ model_deploy.eval()
+ save_checkpoint(model.state_dict(), ckpt_path)
+ load_checkpoint(model_deploy, ckpt_path, strict=True)
+
+ outputs_load = model_deploy(inputs)
+ for feat, feat_load in zip(outputs, outputs_load):
+ assert torch.allclose(feat, feat_load)
diff --git a/tests/test_models/test_backbones/test_repvgg.py b/tests/test_models/test_backbones/test_repvgg.py
index 7ac066ac..4976fdb3 100644
--- a/tests/test_models/test_backbones/test_repvgg.py
+++ b/tests/test_models/test_backbones/test_repvgg.py
@@ -342,6 +342,7 @@ def test_repvgg_load():
outputs = model(inputs)
model_deploy = RepVGG('A1', out_indices=(0, 1, 2, 3), deploy=True)
+ model_deploy.eval()
save_checkpoint(model.state_dict(), ckpt_path)
load_checkpoint(model_deploy, ckpt_path, strict=True)
diff --git a/tools/model_converters/replknet_to_mmcls.py b/tools/model_converters/replknet_to_mmcls.py
new file mode 100644
index 00000000..584b4403
--- /dev/null
+++ b/tools/model_converters/replknet_to_mmcls.py
@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+from collections import OrderedDict
+from pathlib import Path
+
+import torch
+
+
+def convert(src, dst):
+ print('Converting...')
+ blobs = torch.load(src, map_location='cpu')
+ converted_state_dict = OrderedDict()
+
+ for key in blobs:
+ splited_key = key.split('.')
+ print(splited_key)
+ splited_key = [
+ 'backbone.stem' if i[:4] == 'stem' else i for i in splited_key
+ ]
+ splited_key = [
+ 'backbone.stages' if i[:6] == 'stages' else i for i in splited_key
+ ]
+ splited_key = [
+ 'backbone.transitions' if i[:11] == 'transitions' else i
+ for i in splited_key
+ ]
+ splited_key = [
+ 'backbone.stages.3.norm' if i[:4] == 'norm' else i
+ for i in splited_key
+ ]
+ splited_key = [
+ 'head.fc' if i[:4] == 'head' else i for i in splited_key
+ ]
+
+ new_key = '.'.join(splited_key)
+ converted_state_dict[new_key] = blobs[key]
+
+ torch.save(converted_state_dict, dst)
+ print('Done!')
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Convert model keys')
+ parser.add_argument('src', help='src detectron model path')
+ parser.add_argument('dst', help='save path')
+ args = parser.parse_args()
+
+ dst = Path(args.dst)
+ if dst.suffix != '.pth':
+ print('The path should contain the name of the pth format file.')
+ exit(1)
+ dst.parent.mkdir(parents=True, exist_ok=True)
+
+ convert(args.src, args.dst)
+
+
+if __name__ == '__main__':
+ main()