diff --git a/README.md b/README.md
index eec6036b..1eab19a3 100644
--- a/README.md
+++ b/README.md
@@ -144,6 +144,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
- [x] [PoolFormer](https://github.com/open-mmlab/mmclassification/tree/master/configs/poolformer)
- [x] [MViT](https://github.com/open-mmlab/mmclassification/tree/master/configs/mvit)
- [x] [EfficientFormer](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientformer)
+- [x] [HorNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hornet)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index f6235a0b..6fee274c 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -143,6 +143,8 @@ pip3 install -e .
- [x] [CSPNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/cspnet)
- [x] [PoolFormer](https://github.com/open-mmlab/mmclassification/tree/master/configs/poolformer)
- [x] [MViT](https://github.com/open-mmlab/mmclassification/tree/master/configs/mvit)
+- [x] [EfficientFormer](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientformer)
+- [x] [HorNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hornet)
diff --git a/configs/_base_/models/hornet/hornet-base-gf.py b/configs/_base_/models/hornet/hornet-base-gf.py
new file mode 100644
index 00000000..7544970f
--- /dev/null
+++ b/configs/_base_/models/hornet/hornet-base-gf.py
@@ -0,0 +1,21 @@
+# model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(type='HorNet', arch='base-gf', drop_path_rate=0.5),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=1024,
+ init_cfg=None, # suppress the default init_cfg of LinearClsHead.
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ cal_acc=False),
+ init_cfg=[
+ dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
+ dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+ dict(type='Constant', layer=['LayerScale'], val=1e-6)
+ ],
+ train_cfg=dict(augments=[
+ dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
+ dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
+ ]))
diff --git a/configs/_base_/models/hornet/hornet-base.py b/configs/_base_/models/hornet/hornet-base.py
new file mode 100644
index 00000000..82764146
--- /dev/null
+++ b/configs/_base_/models/hornet/hornet-base.py
@@ -0,0 +1,21 @@
+# model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(type='HorNet', arch='base', drop_path_rate=0.5),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=1024,
+ init_cfg=None, # suppress the default init_cfg of LinearClsHead.
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ cal_acc=False),
+ init_cfg=[
+ dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
+ dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+ dict(type='Constant', layer=['LayerScale'], val=1e-6)
+ ],
+ train_cfg=dict(augments=[
+ dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
+ dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
+ ]))
diff --git a/configs/_base_/models/hornet/hornet-large-gf.py b/configs/_base_/models/hornet/hornet-large-gf.py
new file mode 100644
index 00000000..a5b55113
--- /dev/null
+++ b/configs/_base_/models/hornet/hornet-large-gf.py
@@ -0,0 +1,21 @@
+# model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(type='HorNet', arch='large-gf', drop_path_rate=0.2),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=1536,
+ init_cfg=None, # suppress the default init_cfg of LinearClsHead.
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ cal_acc=False),
+ init_cfg=[
+ dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
+ dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+ dict(type='Constant', layer=['LayerScale'], val=1e-6)
+ ],
+ train_cfg=dict(augments=[
+ dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
+ dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
+ ]))
diff --git a/configs/_base_/models/hornet/hornet-large-gf384.py b/configs/_base_/models/hornet/hornet-large-gf384.py
new file mode 100644
index 00000000..fbb54787
--- /dev/null
+++ b/configs/_base_/models/hornet/hornet-large-gf384.py
@@ -0,0 +1,17 @@
+# model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(type='HorNet', arch='large-gf384', drop_path_rate=0.4),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=1536,
+ init_cfg=None, # suppress the default init_cfg of LinearClsHead.
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ cal_acc=False),
+ init_cfg=[
+ dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
+ dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+ dict(type='Constant', layer=['LayerScale'], val=1e-6)
+ ])
diff --git a/configs/_base_/models/hornet/hornet-large.py b/configs/_base_/models/hornet/hornet-large.py
new file mode 100644
index 00000000..26d99e1a
--- /dev/null
+++ b/configs/_base_/models/hornet/hornet-large.py
@@ -0,0 +1,21 @@
+# model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(type='HorNet', arch='large', drop_path_rate=0.2),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=1536,
+ init_cfg=None, # suppress the default init_cfg of LinearClsHead.
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ cal_acc=False),
+ init_cfg=[
+ dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
+ dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+ dict(type='Constant', layer=['LayerScale'], val=1e-6)
+ ],
+ train_cfg=dict(augments=[
+ dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
+ dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
+ ]))
diff --git a/configs/_base_/models/hornet/hornet-small-gf.py b/configs/_base_/models/hornet/hornet-small-gf.py
new file mode 100644
index 00000000..42d9d119
--- /dev/null
+++ b/configs/_base_/models/hornet/hornet-small-gf.py
@@ -0,0 +1,21 @@
+# model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(type='HorNet', arch='small-gf', drop_path_rate=0.4),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=768,
+ init_cfg=None, # suppress the default init_cfg of LinearClsHead.
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ cal_acc=False),
+ init_cfg=[
+ dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
+ dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+ dict(type='Constant', layer=['LayerScale'], val=1e-6)
+ ],
+ train_cfg=dict(augments=[
+ dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
+ dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
+ ]))
diff --git a/configs/_base_/models/hornet/hornet-small.py b/configs/_base_/models/hornet/hornet-small.py
new file mode 100644
index 00000000..e8039765
--- /dev/null
+++ b/configs/_base_/models/hornet/hornet-small.py
@@ -0,0 +1,21 @@
+# model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(type='HorNet', arch='small', drop_path_rate=0.4),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=768,
+ init_cfg=None, # suppress the default init_cfg of LinearClsHead.
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ cal_acc=False),
+ init_cfg=[
+ dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
+ dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+ dict(type='Constant', layer=['LayerScale'], val=1e-6)
+ ],
+ train_cfg=dict(augments=[
+ dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
+ dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
+ ]))
diff --git a/configs/_base_/models/hornet/hornet-tiny-gf.py b/configs/_base_/models/hornet/hornet-tiny-gf.py
new file mode 100644
index 00000000..0e417d04
--- /dev/null
+++ b/configs/_base_/models/hornet/hornet-tiny-gf.py
@@ -0,0 +1,21 @@
+# model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(type='HorNet', arch='tiny-gf', drop_path_rate=0.2),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=512,
+ init_cfg=None, # suppress the default init_cfg of LinearClsHead.
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ cal_acc=False),
+ init_cfg=[
+ dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
+ dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+ dict(type='Constant', layer=['LayerScale'], val=1e-6)
+ ],
+ train_cfg=dict(augments=[
+ dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
+ dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
+ ]))
diff --git a/configs/_base_/models/hornet/hornet-tiny.py b/configs/_base_/models/hornet/hornet-tiny.py
new file mode 100644
index 00000000..068d7d6b
--- /dev/null
+++ b/configs/_base_/models/hornet/hornet-tiny.py
@@ -0,0 +1,21 @@
+# model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(type='HorNet', arch='tiny', drop_path_rate=0.2),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=512,
+ init_cfg=None, # suppress the default init_cfg of LinearClsHead.
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ cal_acc=False),
+ init_cfg=[
+ dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
+ dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+ dict(type='Constant', layer=['LayerScale'], val=1e-6)
+ ],
+ train_cfg=dict(augments=[
+ dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
+ dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
+ ]))
diff --git a/configs/hornet/README.md b/configs/hornet/README.md
new file mode 100644
index 00000000..7c1b9a9b
--- /dev/null
+++ b/configs/hornet/README.md
@@ -0,0 +1,51 @@
+# HorNet
+
+> [HorNet: Efficient High-Order Spatial Interactions with Recursive Gated Convolutions](https://arxiv.org/pdf/2207.14284v2.pdf)
+
+
+
+## Abstract
+
+Recent progress in vision Transformers exhibits great success in various tasks driven by the new spatial modeling mechanism based on dot-product self-attention. In this paper, we show that the key ingredients behind the vision Transformers, namely input-adaptive, long-range and high-order spatial interactions, can also be efficiently implemented with a convolution-based framework. We present the Recursive Gated Convolution (gnConv) that performs high-order spatial interactions with gated convolutions and recursive designs. The new operation is highly flexible and customizable, which is compatible with various variants of convolution and extends the two-order interactions in self-attention to arbitrary orders without introducing significant extra computation. gnConv can serve as a plug-and-play module to improve various vision Transformers and convolution-based models. Based on the operation, we construct a new family of generic vision backbones named HorNet. Extensive experiments on ImageNet classification, COCO object detection and ADE20K semantic segmentation show that HorNet outperforms Swin Transformers and ConvNeXt by a significant margin with similar overall architecture and training configurations. HorNet also shows favorable scalability to more training data and a larger model size. Apart from the effectiveness in visual encoders, we also show gnConv can be applied to task-specific decoders and consistently improve dense prediction performance with less computation. Our results demonstrate that gnConv can be a new basic module for visual modeling that effectively combines the merits of both vision Transformers and CNNs. Code is available at https://github.com/raoyongming/HorNet.
+
+
+

+
+
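+To make the recursive gating concrete, here is a minimal, illustrative sketch of the gnConv idea (simplified from `mmcls/models/backbones/hornet.py`; the class name and the plain depth-wise convolution are ours for exposition, not the official API):
+
+```python
+import torch
+import torch.nn as nn
+
+
+class TinyGnConv(nn.Module):
+    """Order-n recursive gated convolution (illustrative sketch)."""
+
+    def __init__(self, dim, order=3):
+        super().__init__()
+        # channel width doubles at each interaction order
+        self.dims = [dim // 2**i for i in range(order)][::-1]
+        self.proj_in = nn.Conv2d(dim, 2 * dim, 1)
+        self.dwconv = nn.Conv2d(
+            sum(self.dims), sum(self.dims), 7, padding=3,
+            groups=sum(self.dims))
+        self.projs = nn.ModuleList(
+            nn.Conv2d(self.dims[i], self.dims[i + 1], 1)
+            for i in range(order - 1))
+        self.proj_out = nn.Conv2d(dim, dim, 1)
+
+    def forward(self, x):
+        x = self.proj_in(x)
+        y, x = torch.split(x, (self.dims[0], sum(self.dims)), dim=1)
+        branches = torch.split(self.dwconv(x), self.dims, dim=1)
+        x = y * branches[0]  # first-order gating
+        for proj, branch in zip(self.projs, branches[1:]):
+            x = proj(x) * branch  # raise the interaction order by one
+        return self.proj_out(x)
+
+
+out = TinyGnConv(64, order=3)(torch.rand(2, 64, 56, 56))
+print(out.shape)  # torch.Size([2, 64, 56, 56])
+```
+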
+## Results and models
+
+### ImageNet-1k
+
+| Model | Pretrain | resolution | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
+| :-----------: | :----------: | :--------: | :-------: | :------: | :-------: | :-------: | :--------------------------------------------------------------: | :----------------------------------------------------------------: |
+| HorNet-T\* | From scratch | 224x224 | 22.41 | 3.98 | 82.84 | 96.24 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hornet/hornet-tiny_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-tiny_3rdparty_in1k_20220915-0e8eedff.pth) |
+| HorNet-T-GF\* | From scratch | 224x224 | 22.99 | 3.9 | 82.98 | 96.38 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hornet/hornet-tiny-gf_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-tiny-gf_3rdparty_in1k_20220915-4c35a66b.pth) |
+| HorNet-S\* | From scratch | 224x224 | 49.53 | 8.83 | 83.79 | 96.75 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hornet/hornet-small_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-small_3rdparty_in1k_20220915-5935f60f.pth) |
+| HorNet-S-GF\* | From scratch | 224x224 | 50.4 | 8.71 | 83.98 | 96.77 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hornet/hornet-small-gf_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-small-gf_3rdparty_in1k_20220915-649ca492.pth) |
+| HorNet-B\* | From scratch | 224x224 | 87.26 | 15.59 | 84.24 | 96.94 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hornet/hornet-base_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-base_3rdparty_in1k_20220915-a06176bb.pth) |
+| HorNet-B-GF\* | From scratch | 224x224 | 88.42 | 15.42 | 84.32 | 96.95 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hornet/hornet-base-gf_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-base-gf_3rdparty_in1k_20220915-82c06fa7.pth) |
+
+\*Models with * are converted from [the official repo](https://github.com/raoyongming/HorNet). The config files of these models are only for validation. We cannot guarantee the training accuracy of these configs, and you are welcome to contribute your reproduction results.
+
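+The snippet below sketches how to run inference with one of the converted checkpoints through the standard `mmcls.apis` helpers (the demo image path is a placeholder):
+
+```python
+from mmcls.apis import inference_model, init_model
+
+config = 'configs/hornet/hornet-tiny_8xb128_in1k.py'
+checkpoint = 'https://download.openmmlab.com/mmclassification/v0/hornet/hornet-tiny_3rdparty_in1k_20220915-0e8eedff.pth'
+model = init_model(config, checkpoint, device='cpu')
+result = inference_model(model, 'demo/demo.JPEG')
+print(result['pred_class'], result['pred_score'])
+```
+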
+### Pre-trained Models
+
+The models pre-trained on ImageNet-21k are used for fine-tuning on downstream tasks.
+
+| Model | Pretrain | resolution | Params(M) | Flops(G) | Download |
+| :--------------: | :----------: | :--------: | :-------: | :------: | :------------------------------------------------------------------------------------------------------------------------: |
+| HorNet-L\* | ImageNet-21k | 224x224 | 194.54 | 34.83 | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-large_3rdparty_in21k_20220909-9ccef421.pth) |
+| HorNet-L-GF\* | ImageNet-21k | 224x224 | 196.29 | 34.58 | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-large-gf_3rdparty_in21k_20220909-3aea3b61.pth) |
+| HorNet-L-GF384\* | ImageNet-21k | 384x384 | 201.23 | 101.63 | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-large-gf384_3rdparty_in21k_20220909-80894290.pth) |
+
+\*Models with * are converted from [the official repo](https://github.com/raoyongming/HorNet).
+
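+To fine-tune from one of these ImageNet-21k checkpoints, the usual MMClassification pattern is to point the backbone's `init_cfg` at the weights in your downstream config (a sketch; the dataset, schedule and `num_classes` are placeholders):
+
+```python
+model = dict(
+    backbone=dict(
+        init_cfg=dict(
+            type='Pretrained',
+            checkpoint='https://download.openmmlab.com/mmclassification/v0/hornet/hornet-large_3rdparty_in21k_20220909-9ccef421.pth',
+            prefix='backbone')),
+    head=dict(num_classes=10))
+```
+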
+## Citation
+
+```
+@article{rao2022hornet,
+ title={HorNet: Efficient High-Order Spatial Interactions with Recursive Gated Convolutions},
+  author={Rao, Yongming and Zhao, Wenliang and Tang, Yansong and Zhou, Jie and Lim, Ser-Nam and Lu, Jiwen},
+ journal={arXiv preprint arXiv:2207.14284},
+ year={2022}
+}
+```
diff --git a/configs/hornet/hornet-base-gf_8xb64_in1k.py b/configs/hornet/hornet-base-gf_8xb64_in1k.py
new file mode 100644
index 00000000..6c29de66
--- /dev/null
+++ b/configs/hornet/hornet-base-gf_8xb64_in1k.py
@@ -0,0 +1,13 @@
+_base_ = [
+ '../_base_/models/hornet/hornet-base-gf.py',
+ '../_base_/datasets/imagenet_bs64_swin_224.py',
+ '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+ '../_base_/default_runtime.py',
+]
+
+data = dict(samples_per_gpu=64)
+
+optimizer = dict(lr=4e-3)
+optimizer_config = dict(grad_clip=dict(max_norm=1.0), _delete_=True)
+
+custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
diff --git a/configs/hornet/hornet-base_8xb64_in1k.py b/configs/hornet/hornet-base_8xb64_in1k.py
new file mode 100644
index 00000000..969d8b95
--- /dev/null
+++ b/configs/hornet/hornet-base_8xb64_in1k.py
@@ -0,0 +1,13 @@
+_base_ = [
+ '../_base_/models/hornet/hornet-base.py',
+ '../_base_/datasets/imagenet_bs64_swin_224.py',
+ '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+ '../_base_/default_runtime.py',
+]
+
+data = dict(samples_per_gpu=64)
+
+optimizer = dict(lr=4e-3)
+optimizer_config = dict(grad_clip=dict(max_norm=5.0), _delete_=True)
+
+custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
diff --git a/configs/hornet/hornet-small-gf_8xb64_in1k.py b/configs/hornet/hornet-small-gf_8xb64_in1k.py
new file mode 100644
index 00000000..deb570eb
--- /dev/null
+++ b/configs/hornet/hornet-small-gf_8xb64_in1k.py
@@ -0,0 +1,13 @@
+_base_ = [
+ '../_base_/models/hornet/hornet-small-gf.py',
+ '../_base_/datasets/imagenet_bs64_swin_224.py',
+ '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+ '../_base_/default_runtime.py',
+]
+
+data = dict(samples_per_gpu=64)
+
+optimizer = dict(lr=4e-3)
+optimizer_config = dict(grad_clip=dict(max_norm=1.0), _delete_=True)
+
+custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
diff --git a/configs/hornet/hornet-small_8xb64_in1k.py b/configs/hornet/hornet-small_8xb64_in1k.py
new file mode 100644
index 00000000..c07fa60d
--- /dev/null
+++ b/configs/hornet/hornet-small_8xb64_in1k.py
@@ -0,0 +1,13 @@
+_base_ = [
+ '../_base_/models/hornet/hornet-small.py',
+ '../_base_/datasets/imagenet_bs64_swin_224.py',
+ '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+ '../_base_/default_runtime.py',
+]
+
+data = dict(samples_per_gpu=64)
+
+optimizer = dict(lr=4e-3)
+optimizer_config = dict(grad_clip=dict(max_norm=5.0), _delete_=True)
+
+custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
diff --git a/configs/hornet/hornet-tiny-gf_8xb128_in1k.py b/configs/hornet/hornet-tiny-gf_8xb128_in1k.py
new file mode 100644
index 00000000..3a1d1a7a
--- /dev/null
+++ b/configs/hornet/hornet-tiny-gf_8xb128_in1k.py
@@ -0,0 +1,13 @@
+_base_ = [
+ '../_base_/models/hornet/hornet-tiny-gf.py',
+ '../_base_/datasets/imagenet_bs64_swin_224.py',
+ '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+ '../_base_/default_runtime.py',
+]
+
+data = dict(samples_per_gpu=128)
+
+optimizer = dict(lr=4e-3)
+optimizer_config = dict(grad_clip=dict(max_norm=1.0), _delete_=True)
+
+custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
diff --git a/configs/hornet/hornet-tiny_8xb128_in1k.py b/configs/hornet/hornet-tiny_8xb128_in1k.py
new file mode 100644
index 00000000..69a7cdf0
--- /dev/null
+++ b/configs/hornet/hornet-tiny_8xb128_in1k.py
@@ -0,0 +1,13 @@
+_base_ = [
+ '../_base_/models/hornet/hornet-tiny.py',
+ '../_base_/datasets/imagenet_bs64_swin_224.py',
+ '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+ '../_base_/default_runtime.py',
+]
+
+data = dict(samples_per_gpu=128)
+
+optimizer = dict(lr=4e-3)
+optimizer_config = dict(grad_clip=dict(max_norm=100.0), _delete_=True)
+
+custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
diff --git a/configs/hornet/metafile.yml b/configs/hornet/metafile.yml
new file mode 100644
index 00000000..71207722
--- /dev/null
+++ b/configs/hornet/metafile.yml
@@ -0,0 +1,97 @@
+Collections:
+ - Name: HorNet
+ Metadata:
+ Training Data: ImageNet-1k
+ Training Techniques:
+ - AdamW
+ - Weight Decay
+ Architecture:
+ - HorNet
+ - gnConv
+ Paper:
+ URL: https://arxiv.org/pdf/2207.14284v2.pdf
+ Title: "HorNet: Efficient High-Order Spatial Interactions with Recursive Gated Convolutions"
+ README: configs/hornet/README.md
+ Code:
+ Version: v0.24.0
+ URL: https://github.com/open-mmlab/mmclassification/blob/v0.24.0/mmcls/models/backbones/hornet.py
+
+Models:
+ - Name: hornet-tiny_3rdparty_in1k
+ Metadata:
+ FLOPs: 3980000000 # 3.98G
+ Parameters: 22410000 # 22.41M
+ In Collection: HorNet
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 82.84
+ Top 5 Accuracy: 96.24
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/hornet/hornet-tiny_3rdparty_in1k_20220915-0e8eedff.pth
+ Config: configs/hornet/hornet-tiny_8xb128_in1k.py
+ - Name: hornet-tiny-gf_3rdparty_in1k
+ Metadata:
+ FLOPs: 3900000000 # 3.9G
+ Parameters: 22990000 # 22.99M
+ In Collection: HorNet
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 82.98
+ Top 5 Accuracy: 96.38
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/hornet/hornet-tiny-gf_3rdparty_in1k_20220915-4c35a66b.pth
+ Config: configs/hornet/hornet-tiny-gf_8xb128_in1k.py
+ - Name: hornet-small_3rdparty_in1k
+ Metadata:
+ FLOPs: 8830000000 # 8.83G
+ Parameters: 49530000 # 49.53M
+ In Collection: HorNet
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 83.79
+ Top 5 Accuracy: 96.75
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/hornet/hornet-small_3rdparty_in1k_20220915-5935f60f.pth
+ Config: configs/hornet/hornet-small_8xb64_in1k.py
+ - Name: hornet-small-gf_3rdparty_in1k
+ Metadata:
+ FLOPs: 8710000000 # 8.71G
+ Parameters: 50400000 # 50.4M
+ In Collection: HorNet
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 83.98
+ Top 5 Accuracy: 96.77
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/hornet/hornet-small-gf_3rdparty_in1k_20220915-649ca492.pth
+ Config: configs/hornet/hornet-small-gf_8xb64_in1k.py
+ - Name: hornet-base_3rdparty_in1k
+ Metadata:
+ FLOPs: 15590000000 # 15.59G
+ Parameters: 87260000 # 87.26M
+ In Collection: HorNet
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 84.24
+ Top 5 Accuracy: 96.94
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/hornet/hornet-base_3rdparty_in1k_20220915-a06176bb.pth
+ Config: configs/hornet/hornet-base_8xb64_in1k.py
+ - Name: hornet-base-gf_3rdparty_in1k
+ Metadata:
+ FLOPs: 15420000000 # 15.42G
+ Parameters: 88420000 # 88.42M
+ In Collection: HorNet
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 84.32
+ Top 5 Accuracy: 96.95
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/hornet/hornet-base-gf_3rdparty_in1k_20220915-82c06fa7.pth
+ Config: configs/hornet/hornet-base-gf_8xb64_in1k.py
diff --git a/docs/en/api/models.rst b/docs/en/api/models.rst
index 37938e34..0c317916 100644
--- a/docs/en/api/models.rst
+++ b/docs/en/api/models.rst
@@ -88,6 +88,7 @@ Backbones
VGG
VisionTransformer
EfficientFormer
+ HorNet
.. _necks:
diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py
index ad7b8189..a919a42c 100644
--- a/mmcls/models/backbones/__init__.py
+++ b/mmcls/models/backbones/__init__.py
@@ -8,6 +8,7 @@ from .deit import DistilledVisionTransformer
from .densenet import DenseNet
from .efficientformer import EfficientFormer
from .efficientnet import EfficientNet
+from .hornet import HorNet
from .hrnet import HRNet
from .lenet import LeNet5
from .mlp_mixer import MlpMixer
@@ -45,5 +46,6 @@ __all__ = [
'Res2Net', 'RepVGG', 'Conformer', 'MlpMixer', 'DistilledVisionTransformer',
'PCPVT', 'SVT', 'EfficientNet', 'ConvNeXt', 'HRNet', 'ResNetV1c',
'ConvMixer', 'CSPDarkNet', 'CSPResNet', 'CSPResNeXt', 'CSPNet',
- 'RepMLPNet', 'PoolFormer', 'DenseNet', 'VAN', 'MViT', 'EfficientFormer'
+ 'RepMLPNet', 'PoolFormer', 'DenseNet', 'VAN', 'MViT', 'EfficientFormer',
+ 'HorNet'
]
diff --git a/mmcls/models/backbones/efficientformer.py b/mmcls/models/backbones/efficientformer.py
index fa3b14eb..173444ff 100644
--- a/mmcls/models/backbones/efficientformer.py
+++ b/mmcls/models/backbones/efficientformer.py
@@ -9,6 +9,7 @@ from mmcv.cnn.bricks import (ConvModule, DropPath, build_activation_layer,
from mmcv.runner import BaseModule, ModuleList, Sequential
from ..builder import BACKBONES
+from ..utils import LayerScale
from .base_backbone import BaseBackbone
from .poolformer import Pooling
@@ -201,38 +202,6 @@ class ConvMlp(BaseModule):
return x
-class LayerScale(nn.Module):
- """LayerScale layer.
-
- Args:
- dim (int): Dimension of input features.
- inplace (bool): inplace: can optionally do the
- operation in-place. Default: ``False``
- data_format (str): The input data format, can be 'channels_last'
- and 'channels_first', representing (B, C, H, W) and
- (B, N, C) format data respectively.
- """
-
- def __init__(self,
- dim: int,
- inplace: bool = False,
- data_format: str = 'channels_last'):
- super().__init__()
- assert data_format in ('channels_last', 'channels_first'), \
- "'data_format' could only be channels_last or channels_first."
- self.inplace = inplace
- self.data_format = data_format
- self.weight = nn.Parameter(torch.ones(dim) * 1e-5)
-
- def forward(self, x):
- if self.data_format == 'channels_first':
- if self.inplace:
- return x.mul_(self.weight.view(-1, 1, 1))
- else:
- return x * self.weight.view(-1, 1, 1)
- return x.mul_(self.weight) if self.inplace else x * self.weight
-
-
class Meta3D(BaseModule):
"""Meta Former block using 3 dimensions inputs, ``torch.Tensor`` with shape
(B, N, C)."""
diff --git a/mmcls/models/backbones/hornet.py b/mmcls/models/backbones/hornet.py
new file mode 100644
index 00000000..1822b7c0
--- /dev/null
+++ b/mmcls/models/backbones/hornet.py
@@ -0,0 +1,499 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Adapted from official impl at https://github.com/raoyongming/HorNet.
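+# `torch.fft` is only available in PyTorch >= 1.7; record its availability
+# here so __init__ can raise a clear error on older versions.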
+try:
+ import torch.fft
+ fft = True
+except ImportError:
+ fft = None
+
+import copy
+from functools import partial
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from mmcv.cnn.bricks import DropPath
+
+from mmcls.models.builder import BACKBONES
+from ..utils import LayerScale
+from .base_backbone import BaseBackbone
+
+
+def get_dwconv(dim, kernel_size, bias=True):
+ """build a pepth-wise convolution."""
+ return nn.Conv2d(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ padding=(kernel_size - 1) // 2,
+ bias=bias,
+ groups=dim)
+
+
+class HorNetLayerNorm(nn.Module):
+ """An implementation of LayerNorm of HorNet.
+
+    The difference between HorNetLayerNorm and torch's LayerNorm:
+    1. Supports two data formats: channels_last and channels_first.
+
+ Args:
+        normalized_shape (int or list or torch.Size): Shape of the input
+            over which the normalization is applied.
+ eps (float): a value added to the denominator for numerical stability.
+ Defaults to 1e-5.
+ data_format (str): The ordering of the dimensions in the inputs.
+ channels_last corresponds to inputs with shape (batch_size, height,
+ width, channels) while channels_first corresponds to inputs with
+ shape (batch_size, channels, height, width).
+ Defaults to 'channels_last'.
+ """
+
+ def __init__(self,
+ normalized_shape,
+ eps=1e-6,
+ data_format='channels_last'):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
+ self.eps = eps
+ self.data_format = data_format
+ if self.data_format not in ['channels_last', 'channels_first']:
+ raise ValueError(
+ 'data_format must be channels_last or channels_first')
+ self.normalized_shape = (normalized_shape, )
+
+ def forward(self, x):
+ if self.data_format == 'channels_last':
+ return F.layer_norm(x, self.normalized_shape, self.weight,
+ self.bias, self.eps)
+ elif self.data_format == 'channels_first':
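+            # LayerNorm computed manually over the channel dim of an
+            # (N, C, H, W) input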
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
+
+
+class GlobalLocalFilter(nn.Module):
+ """A GlobalLocalFilter of HorNet.
+
+ Args:
+ dim (int): Number of input channels.
+ h (int): Height of complex_weight.
+ Defaults to 14.
+ w (int): Width of complex_weight.
+ Defaults to 8.
+ """
+
+ def __init__(self, dim, h=14, w=8):
+ super().__init__()
+ self.dw = nn.Conv2d(
+ dim // 2,
+ dim // 2,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ groups=dim // 2)
+ self.complex_weight = nn.Parameter(
+ torch.randn(dim // 2, h, w, 2, dtype=torch.float32) * 0.02)
+ self.pre_norm = HorNetLayerNorm(
+ dim, eps=1e-6, data_format='channels_first')
+ self.post_norm = HorNetLayerNorm(
+ dim, eps=1e-6, data_format='channels_first')
+
+ def forward(self, x):
+ x = self.pre_norm(x)
+ x1, x2 = torch.chunk(x, 2, dim=1)
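+        # x1: local branch (depth-wise conv);
+        # x2: global branch (learned filter in the frequency domain)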
+ x1 = self.dw(x1)
+
+ x2 = x2.to(torch.float32)
+ B, C, a, b = x2.shape
+ x2 = torch.fft.rfft2(x2, dim=(2, 3), norm='ortho')
+
+ weight = self.complex_weight
+        if weight.shape[1:3] != x2.shape[2:4]:
+ weight = F.interpolate(
+ weight.permute(3, 0, 1, 2),
+ size=x2.shape[2:4],
+ mode='bilinear',
+ align_corners=True).permute(1, 2, 3, 0)
+
+ weight = torch.view_as_complex(weight.contiguous())
+
+ x2 = x2 * weight
+ x2 = torch.fft.irfft2(x2, s=(a, b), dim=(2, 3), norm='ortho')
+
+ x = torch.cat([x1.unsqueeze(2), x2.unsqueeze(2)],
+ dim=2).reshape(B, 2 * C, a, b)
+ x = self.post_norm(x)
+ return x
+
+
+class gnConv(nn.Module):
+ """A gnConv of HorNet.
+
+ Args:
+ dim (int): Number of input channels.
+ order (int): Order of gnConv.
+ Defaults to 5.
+        dw_cfg (dict): The config of the depth-wise conv layer.
+ Defaults to ``dict(type='DW', kernel_size=7)``.
+ scale (float): Scaling parameter of gflayer outputs.
+ Defaults to 1.0.
+ """
+
+ def __init__(self,
+ dim,
+ order=5,
+ dw_cfg=dict(type='DW', kernel_size=7),
+ scale=1.0):
+ super().__init__()
+ self.order = order
+ self.dims = [dim // 2**i for i in range(order)]
+ self.dims.reverse()
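+        # e.g. order=5, dim=512 -> dims = [32, 64, 128, 256, 512],
+        # from the lowest to the highest interaction order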
+ self.proj_in = nn.Conv2d(dim, 2 * dim, 1)
+
+ cfg = copy.deepcopy(dw_cfg)
+ dw_type = cfg.pop('type')
+ assert dw_type in ['DW', 'GF'],\
+ 'dw_type should be `DW` or `GF`'
+ if dw_type == 'DW':
+ self.dwconv = get_dwconv(sum(self.dims), **cfg)
+ elif dw_type == 'GF':
+ self.dwconv = GlobalLocalFilter(sum(self.dims), **cfg)
+
+ self.proj_out = nn.Conv2d(dim, dim, 1)
+
+ self.projs = nn.ModuleList([
+ nn.Conv2d(self.dims[i], self.dims[i + 1], 1)
+ for i in range(order - 1)
+ ])
+
+ self.scale = scale
+
+ def forward(self, x):
+ x = self.proj_in(x)
+ y, x = torch.split(x, (self.dims[0], sum(self.dims)), dim=1)
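+        # y gates the lowest-order branch; x holds features for every order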
+
+ x = self.dwconv(x) * self.scale
+
+ dw_list = torch.split(x, self.dims, dim=1)
+ x = y * dw_list[0]
+
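+        # recursively raise the interaction order: project the gated
+        # features and multiply with the next depth-wise branch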
+ for i in range(self.order - 1):
+ x = self.projs[i](x) * dw_list[i + 1]
+
+ x = self.proj_out(x)
+
+ return x
+
+
+class HorNetBlock(nn.Module):
+ """A block of HorNet.
+
+ Args:
+ dim (int): Number of input channels.
+ order (int): Order of gnConv.
+ Defaults to 5.
+        dw_cfg (dict): The config of the depth-wise conv layer.
+ Defaults to ``dict(type='DW', kernel_size=7)``.
+ scale (float): Scaling parameter of gflayer outputs.
+ Defaults to 1.0.
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+        use_layer_scale (bool): Whether to use layer scale in the HorNet
+            block. Defaults to True.
+ """
+
+ def __init__(self,
+ dim,
+ order=5,
+ dw_cfg=dict(type='DW', kernel_size=7),
+ scale=1.0,
+ drop_path_rate=0.,
+ use_layer_scale=True):
+ super().__init__()
+ self.out_channels = dim
+
+ self.norm1 = HorNetLayerNorm(
+ dim, eps=1e-6, data_format='channels_first')
+ self.gnconv = gnConv(dim, order, dw_cfg, scale)
+ self.norm2 = HorNetLayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(dim, 4 * dim)
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(4 * dim, dim)
+
+ if use_layer_scale:
+ self.gamma1 = LayerScale(dim, data_format='channels_first')
+ self.gamma2 = LayerScale(dim)
+ else:
+ self.gamma1, self.gamma2 = nn.Identity(), nn.Identity()
+
+ self.drop_path = DropPath(
+ drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+ def forward(self, x):
+ x = x + self.drop_path(self.gamma1(self.gnconv(self.norm1(x))))
+
+ input = x
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ x = self.norm2(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ x = self.gamma2(x)
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+ x = input + self.drop_path(x)
+ return x
+
+
+@BACKBONES.register_module()
+class HorNet(BaseBackbone):
+ """HorNet
+ A PyTorch impl of : `HorNet: Efficient High-Order Spatial Interactions
+ with Recursive Gated Convolutions`
+
+ Inspiration from
+ https://github.com/raoyongming/HorNet
+
+ Args:
+ arch (str | dict): HorNet architecture.
+ If use string, choose from 'tiny', 'small', 'base' and 'large'.
+ If use dict, it should have below keys:
+ - **base_dim** (int): The base dimensions of embedding.
+ - **depths** (List[int]): The number of blocks in each stage.
+ - **orders** (List[int]): The number of order of gnConv in each
+ stage.
+            - **dw_cfg** (List[dict]): The config of the depth-wise conv in
+              each stage.
+
+ Defaults to 'tiny'.
+ in_channels (int): Number of input image channels. Defaults to 3.
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+ scale (float): Scaling parameter of gflayer outputs. Defaults to 1/3.
+        use_layer_scale (bool): Whether to use layer scale in the HorNet
+            block. Defaults to True.
+ out_indices (Sequence[int]): Output from which stages.
+ Default: ``(3, )``.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Defaults to -1.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Defaults to False.
+ gap_before_final_norm (bool): Whether to globally average the feature
+            map before the final norm layer. In the official repo, it's only
+            used in the classification task. Defaults to True.
+ init_cfg (dict, optional): The Config for initialization.
+ Defaults to None.
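+
+    Example:
+        >>> import torch
+        >>> from mmcls.models import HorNet
+        >>> model = HorNet(arch='tiny')
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> outputs = model(inputs)
+        >>> # by default, features are pooled before the final norm
+        >>> print(outputs[-1].shape)
+        torch.Size([1, 512])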
+ """
+ arch_zoo = {
+ **dict.fromkeys(['t', 'tiny'],
+ {'base_dim': 64,
+ 'depths': [2, 3, 18, 2],
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}),
+ **dict.fromkeys(['t-gf', 'tiny-gf'],
+ {'base_dim': 64,
+ 'depths': [2, 3, 18, 2],
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [
+ dict(type='DW', kernel_size=7),
+ dict(type='DW', kernel_size=7),
+ dict(type='GF', h=14, w=8),
+ dict(type='GF', h=7, w=4)]}),
+ **dict.fromkeys(['s', 'small'],
+ {'base_dim': 96,
+ 'depths': [2, 3, 18, 2],
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}),
+ **dict.fromkeys(['s-gf', 'small-gf'],
+ {'base_dim': 96,
+ 'depths': [2, 3, 18, 2],
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [
+ dict(type='DW', kernel_size=7),
+ dict(type='DW', kernel_size=7),
+ dict(type='GF', h=14, w=8),
+ dict(type='GF', h=7, w=4)]}),
+ **dict.fromkeys(['b', 'base'],
+ {'base_dim': 128,
+ 'depths': [2, 3, 18, 2],
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}),
+ **dict.fromkeys(['b-gf', 'base-gf'],
+ {'base_dim': 128,
+ 'depths': [2, 3, 18, 2],
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [
+ dict(type='DW', kernel_size=7),
+ dict(type='DW', kernel_size=7),
+ dict(type='GF', h=14, w=8),
+ dict(type='GF', h=7, w=4)]}),
+ **dict.fromkeys(['b-gf384', 'base-gf384'],
+ {'base_dim': 128,
+ 'depths': [2, 3, 18, 2],
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [
+ dict(type='DW', kernel_size=7),
+ dict(type='DW', kernel_size=7),
+ dict(type='GF', h=24, w=12),
+ dict(type='GF', h=13, w=7)]}),
+ **dict.fromkeys(['l', 'large'],
+ {'base_dim': 192,
+ 'depths': [2, 3, 18, 2],
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}),
+ **dict.fromkeys(['l-gf', 'large-gf'],
+ {'base_dim': 192,
+ 'depths': [2, 3, 18, 2],
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [
+ dict(type='DW', kernel_size=7),
+ dict(type='DW', kernel_size=7),
+ dict(type='GF', h=14, w=8),
+ dict(type='GF', h=7, w=4)]}),
+ **dict.fromkeys(['l-gf384', 'large-gf384'],
+ {'base_dim': 192,
+ 'depths': [2, 3, 18, 2],
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [
+ dict(type='DW', kernel_size=7),
+ dict(type='DW', kernel_size=7),
+ dict(type='GF', h=24, w=12),
+ dict(type='GF', h=13, w=7)]}),
+ } # yapf: disable
+
+ def __init__(self,
+ arch='tiny',
+ in_channels=3,
+ drop_path_rate=0.,
+ scale=1 / 3,
+ use_layer_scale=True,
+ out_indices=(3, ),
+ frozen_stages=-1,
+ with_cp=False,
+ gap_before_final_norm=True,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ if fft is None:
+ raise RuntimeError(
+ 'Failed to import torch.fft. Please install "torch>=1.7".')
+
+ if isinstance(arch, str):
+ arch = arch.lower()
+ assert arch in set(self.arch_zoo), \
+ f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+ self.arch_settings = self.arch_zoo[arch]
+ else:
+ essential_keys = {'base_dim', 'depths', 'orders', 'dw_cfg'}
+ assert isinstance(arch, dict) and set(arch) == essential_keys, \
+ f'Custom arch needs a dict with keys {essential_keys}'
+ self.arch_settings = arch
+
+ self.scale = scale
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.with_cp = with_cp
+ self.gap_before_final_norm = gap_before_final_norm
+
+ base_dim = self.arch_settings['base_dim']
+ dims = list(map(lambda x: 2**x * base_dim, range(4)))
+
+ self.downsample_layers = nn.ModuleList()
+ stem = nn.Sequential(
+ nn.Conv2d(in_channels, dims[0], kernel_size=4, stride=4),
+ HorNetLayerNorm(dims[0], eps=1e-6, data_format='channels_first'))
+ self.downsample_layers.append(stem)
+ for i in range(3):
+ downsample_layer = nn.Sequential(
+ HorNetLayerNorm(
+ dims[i], eps=1e-6, data_format='channels_first'),
+ nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
+ )
+ self.downsample_layers.append(downsample_layer)
+
+ total_depth = sum(self.arch_settings['depths'])
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
+ ] # stochastic depth decay rule
+
+ cur_block_idx = 0
+ self.stages = nn.ModuleList()
+ for i in range(4):
+ stage = nn.Sequential(*[
+ HorNetBlock(
+ dim=dims[i],
+ order=self.arch_settings['orders'][i],
+ dw_cfg=self.arch_settings['dw_cfg'][i],
+ scale=self.scale,
+ drop_path_rate=dpr[cur_block_idx + j],
+ use_layer_scale=use_layer_scale)
+ for j in range(self.arch_settings['depths'][i])
+ ])
+ self.stages.append(stage)
+ cur_block_idx += self.arch_settings['depths'][i]
+
+ if isinstance(out_indices, int):
+ out_indices = [out_indices]
+ assert isinstance(out_indices, Sequence), \
+ f'"out_indices" must by a sequence or int, ' \
+ f'get {type(out_indices)} instead.'
+ out_indices = list(out_indices)
+ for i, index in enumerate(out_indices):
+ if index < 0:
+ out_indices[i] = len(self.stages) + index
+ assert 0 <= out_indices[i] <= len(self.stages), \
+ f'Invalid out_indices {index}.'
+ self.out_indices = out_indices
+
+ norm_layer = partial(
+ HorNetLayerNorm, eps=1e-6, data_format='channels_first')
+ for i_layer in out_indices:
+ layer = norm_layer(dims[i_layer])
+ layer_name = f'norm{i_layer}'
+ self.add_module(layer_name, layer)
+
+ def train(self, mode=True):
+ super(HorNet, self).train(mode)
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ for i in range(0, self.frozen_stages + 1):
+            # freeze the stem / downsample layer
+ m = self.downsample_layers[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ # freeze blocks
+ m = self.stages[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ if i in self.out_indices:
+ # freeze norm
+            m = getattr(self, f'norm{i}')
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def forward(self, x):
+ outs = []
+ for i in range(4):
+ x = self.downsample_layers[i](x)
+ if self.with_cp:
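+                # gradient checkpointing trades extra compute in the
+                # backward pass for lower activation memory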
+ x = checkpoint.checkpoint_sequential(self.stages[i],
+ len(self.stages[i]), x)
+ else:
+ x = self.stages[i](x)
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ if self.gap_before_final_norm:
+ gap = x.mean([-2, -1], keepdim=True)
+ outs.append(norm_layer(gap).flatten(1))
+ else:
+                # The output of the channels_first LayerNorm may be
+                # discontiguous, which may cause problems in downstream tasks.
+ outs.append(norm_layer(x).contiguous())
+ return tuple(outs)
diff --git a/mmcls/models/utils/__init__.py b/mmcls/models/utils/__init__.py
index 09d72735..05af4db9 100644
--- a/mmcls/models/utils/__init__.py
+++ b/mmcls/models/utils/__init__.py
@@ -6,6 +6,7 @@ from .embed import (HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed,
resize_relative_position_bias_table)
from .helpers import is_tracing, to_2tuple, to_3tuple, to_4tuple, to_ntuple
from .inverted_residual import InvertedResidual
+from .layer_scale import LayerScale
from .make_divisible import make_divisible
from .position_encoding import ConditionalPositionEncoding
from .se_layer import SELayer
@@ -15,5 +16,5 @@ __all__ = [
'to_ntuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'PatchEmbed',
'PatchMerging', 'HybridEmbed', 'Augments', 'ShiftWindowMSA', 'is_tracing',
'MultiheadAttention', 'ConditionalPositionEncoding', 'resize_pos_embed',
- 'resize_relative_position_bias_table', 'WindowMSAV2'
+ 'resize_relative_position_bias_table', 'WindowMSAV2', 'LayerScale'
]
diff --git a/mmcls/models/utils/layer_scale.py b/mmcls/models/utils/layer_scale.py
new file mode 100644
index 00000000..fbd89bc2
--- /dev/null
+++ b/mmcls/models/utils/layer_scale.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ """LayerScale layer.
+
+ Args:
+ dim (int): Dimension of input features.
+        inplace (bool): Whether to do the operation in-place.
+            Defaults to ``False``.
+        data_format (str): The input data format, can be 'channels_last'
+            or 'channels_first', representing (B, N, C) and
+            (B, C, H, W) format data respectively.
+ """
+
+ def __init__(self,
+ dim: int,
+ inplace: bool = False,
+ data_format: str = 'channels_last'):
+ super().__init__()
+ assert data_format in ('channels_last', 'channels_first'), \
+ "'data_format' could only be channels_last or channels_first."
+ self.inplace = inplace
+ self.data_format = data_format
+ self.weight = nn.Parameter(torch.ones(dim) * 1e-5)
+
+ def forward(self, x):
+ if self.data_format == 'channels_first':
+ if self.inplace:
+ return x.mul_(self.weight.view(-1, 1, 1))
+ else:
+ return x * self.weight.view(-1, 1, 1)
+ return x.mul_(self.weight) if self.inplace else x * self.weight
diff --git a/model-index.yml b/model-index.yml
index a48ab85a..56c7dc97 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -31,3 +31,4 @@ Import:
- configs/csra/metafile.yml
- configs/mvit/metafile.yml
- configs/efficientformer/metafile.yml
+ - configs/hornet/metafile.yml
diff --git a/tests/test_models/test_backbones/test_efficientformer.py b/tests/test_models/test_backbones/test_efficientformer.py
index 01d9daea..88aad529 100644
--- a/tests/test_models/test_backbones/test_efficientformer.py
+++ b/tests/test_models/test_backbones/test_efficientformer.py
@@ -8,52 +8,10 @@ from torch import nn
from mmcls.models.backbones import EfficientFormer
from mmcls.models.backbones.efficientformer import (AttentionWithBias, Flat,
- LayerScale, Meta3D, Meta4D)
+ Meta3D, Meta4D)
from mmcls.models.backbones.poolformer import Pooling
-class TestLayerScale(TestCase):
-
- def test_init(self):
- with self.assertRaisesRegex(AssertionError, "'data_format' could"):
- cfg = dict(
- dim=10,
- inplace=False,
- data_format='BNC',
- )
- LayerScale(**cfg)
-
- cfg = dict(dim=10)
- ls = LayerScale(**cfg)
- assert torch.equal(ls.weight,
- torch.ones(10, requires_grad=True) * 1e-5)
-
- def forward(self):
- # Test channels_last
- cfg = dict(dim=256, inplace=False, data_format='channels_last')
- ls_channels_last = LayerScale(**cfg)
- x = torch.randn((4, 49, 256))
- out = ls_channels_last(x)
- self.assertEqual(tuple(out.size()), (4, 49, 256))
- assert torch.equal(x * 1e-5, out)
-
- # Test channels_first
- cfg = dict(dim=256, inplace=False, data_format='channels_first')
- ls_channels_first = LayerScale(**cfg)
- x = torch.randn((4, 256, 7, 7))
- out = ls_channels_first(x)
- self.assertEqual(tuple(out.size()), (4, 256, 7, 7))
- assert torch.equal(x * 1e-5, out)
-
- # Test inplace True
- cfg = dict(dim=256, inplace=True, data_format='channels_first')
- ls_channels_first = LayerScale(**cfg)
- x = torch.randn((4, 256, 7, 7))
- out = ls_channels_first(x)
- self.assertEqual(tuple(out.size()), (4, 256, 7, 7))
- self.assertIs(x, out)
-
-
class TestEfficientFormer(TestCase):
def setUp(self):
diff --git a/tests/test_models/test_backbones/test_hornet.py b/tests/test_models/test_backbones/test_hornet.py
new file mode 100644
index 00000000..5fdd84b3
--- /dev/null
+++ b/tests/test_models/test_backbones/test_hornet.py
@@ -0,0 +1,174 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from copy import deepcopy
+from itertools import chain
+from unittest import TestCase
+
+import pytest
+import torch
+from mmcv.utils import digit_version
+from mmcv.utils.parrots_wrapper import _BatchNorm
+from torch import nn
+
+from mmcls.models.backbones import HorNet
+
+
+def check_norm_state(modules, train_state):
+ """Check if norm layer is in correct train state."""
+ for mod in modules:
+ if isinstance(mod, _BatchNorm):
+ if mod.training != train_state:
+ return False
+ return True
+
+
+@pytest.mark.skipif(
+ digit_version(torch.__version__) < digit_version('1.7.0'),
+ reason='torch.fft is not available before 1.7.0')
+class TestHorNet(TestCase):
+
+ def setUp(self):
+ self.cfg = dict(
+ arch='t', drop_path_rate=0.1, gap_before_final_norm=False)
+
+ def test_arch(self):
+ # Test invalid default arch
+ with self.assertRaisesRegex(AssertionError, 'not in default archs'):
+ cfg = deepcopy(self.cfg)
+ cfg['arch'] = 'unknown'
+ HorNet(**cfg)
+
+ # Test invalid custom arch
+ with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
+ cfg = deepcopy(self.cfg)
+ cfg['arch'] = {
+ 'depths': [1, 1, 1, 1],
+ 'orders': [1, 1, 1, 1],
+ }
+ HorNet(**cfg)
+
+ # Test custom arch
+ cfg = deepcopy(self.cfg)
+ base_dim = 64
+ depths = [2, 3, 18, 2]
+ embed_dims = [base_dim, base_dim * 2, base_dim * 4, base_dim * 8]
+ cfg['arch'] = {
+ 'base_dim':
+ base_dim,
+ 'depths':
+ depths,
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [
+ dict(type='DW', kernel_size=7),
+ dict(type='DW', kernel_size=7),
+ dict(type='GF', h=14, w=8),
+ dict(type='GF', h=7, w=4)
+ ],
+ }
+ model = HorNet(**cfg)
+
+ for i in range(len(depths)):
+ stage = model.stages[i]
+ self.assertEqual(stage[-1].out_channels, embed_dims[i])
+ self.assertEqual(len(stage), depths[i])
+
+ def test_init_weights(self):
+ # test weight init cfg
+ cfg = deepcopy(self.cfg)
+ cfg['init_cfg'] = [
+ dict(
+ type='Kaiming',
+ layer='Conv2d',
+ mode='fan_in',
+ nonlinearity='linear')
+ ]
+ model = HorNet(**cfg)
+ ori_weight = model.downsample_layers[0][0].weight.clone().detach()
+
+ model.init_weights()
+ initialized_weight = model.downsample_layers[0][0].weight
+ self.assertFalse(torch.allclose(ori_weight, initialized_weight))
+
+ def test_forward(self):
+ imgs = torch.randn(3, 3, 224, 224)
+
+ cfg = deepcopy(self.cfg)
+ model = HorNet(**cfg)
+ outs = model(imgs)
+ self.assertIsInstance(outs, tuple)
+ self.assertEqual(len(outs), 1)
+ feat = outs[-1]
+ self.assertEqual(feat.shape, (3, 512, 7, 7))
+
+ # test multiple output indices
+ cfg = deepcopy(self.cfg)
+ cfg['out_indices'] = (0, 1, 2, 3)
+ model = HorNet(**cfg)
+ outs = model(imgs)
+ self.assertIsInstance(outs, tuple)
+ self.assertEqual(len(outs), 4)
+ for emb_size, stride, out in zip([64, 128, 256, 512], [1, 2, 4, 8],
+ outs):
+ self.assertEqual(out.shape,
+ (3, emb_size, 56 // stride, 56 // stride))
+
+ # test with dynamic input shape
+ imgs1 = torch.randn(3, 3, 224, 224)
+ imgs2 = torch.randn(3, 3, 256, 256)
+ imgs3 = torch.randn(3, 3, 256, 309)
+ cfg = deepcopy(self.cfg)
+ model = HorNet(**cfg)
+ for imgs in [imgs1, imgs2, imgs3]:
+ outs = model(imgs)
+ self.assertIsInstance(outs, tuple)
+ self.assertEqual(len(outs), 1)
+ feat = outs[-1]
+ expect_feat_shape = (math.floor(imgs.shape[2] / 32),
+ math.floor(imgs.shape[3] / 32))
+ self.assertEqual(feat.shape, (3, 512, *expect_feat_shape))
+
+ def test_structure(self):
+ # test drop_path_rate decay
+ cfg = deepcopy(self.cfg)
+ cfg['drop_path_rate'] = 0.2
+ model = HorNet(**cfg)
+ depths = model.arch_settings['depths']
+ stages = model.stages
+        blocks = chain(*stages)
+ total_depth = sum(depths)
+ dpr = [
+ x.item()
+ for x in torch.linspace(0, cfg['drop_path_rate'], total_depth)
+ ]
+ for i, (block, expect_prob) in enumerate(zip(blocks, dpr)):
+ if expect_prob == 0:
+ assert isinstance(block.drop_path, nn.Identity)
+ else:
+ self.assertAlmostEqual(block.drop_path.drop_prob, expect_prob)
+
+        # test HorNet with the first stage frozen.
+ cfg = deepcopy(self.cfg)
+ frozen_stages = 0
+ cfg['frozen_stages'] = frozen_stages
+ cfg['out_indices'] = (0, 1, 2, 3)
+ model = HorNet(**cfg)
+ model.init_weights()
+ model.train()
+
+        # the stem and first stage should not require grad.
+ for i in range(frozen_stages + 1):
+ down = model.downsample_layers[i]
+ for param in down.parameters():
+ self.assertFalse(param.requires_grad)
+ blocks = model.stages[i]
+ for param in blocks.parameters():
+ self.assertFalse(param.requires_grad)
+
+        # the remaining stages should require grad.
+ for i in range(frozen_stages + 1, 4):
+ down = model.downsample_layers[i]
+ for param in down.parameters():
+ self.assertTrue(param.requires_grad)
+ blocks = model.stages[i]
+ for param in blocks.parameters():
+ self.assertTrue(param.requires_grad)
diff --git a/tests/test_models/test_utils/test_layer_scale.py b/tests/test_models/test_utils/test_layer_scale.py
new file mode 100644
index 00000000..824be998
--- /dev/null
+++ b/tests/test_models/test_utils/test_layer_scale.py
@@ -0,0 +1,48 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+
+from mmcls.models.utils import LayerScale
+
+
+class TestLayerScale(TestCase):
+
+ def test_init(self):
+ with self.assertRaisesRegex(AssertionError, "'data_format' could"):
+ cfg = dict(
+ dim=10,
+ inplace=False,
+ data_format='BNC',
+ )
+ LayerScale(**cfg)
+
+ cfg = dict(dim=10)
+ ls = LayerScale(**cfg)
+ assert torch.equal(ls.weight,
+ torch.ones(10, requires_grad=True) * 1e-5)
+
+ def forward(self):
+ # Test channels_last
+ cfg = dict(dim=256, inplace=False, data_format='channels_last')
+ ls_channels_last = LayerScale(**cfg)
+ x = torch.randn((4, 49, 256))
+ out = ls_channels_last(x)
+ self.assertEqual(tuple(out.size()), (4, 49, 256))
+ assert torch.equal(x * 1e-5, out)
+
+ # Test channels_first
+ cfg = dict(dim=256, inplace=False, data_format='channels_first')
+ ls_channels_first = LayerScale(**cfg)
+ x = torch.randn((4, 256, 7, 7))
+ out = ls_channels_first(x)
+ self.assertEqual(tuple(out.size()), (4, 256, 7, 7))
+ assert torch.equal(x * 1e-5, out)
+
+ # Test inplace True
+ cfg = dict(dim=256, inplace=True, data_format='channels_first')
+ ls_channels_first = LayerScale(**cfg)
+ x = torch.randn((4, 256, 7, 7))
+ out = ls_channels_first(x)
+ self.assertEqual(tuple(out.size()), (4, 256, 7, 7))
+ self.assertIs(x, out)
diff --git a/tools/convert_models/hornet2mmcls.py b/tools/convert_models/hornet2mmcls.py
new file mode 100644
index 00000000..6f39ffb2
--- /dev/null
+++ b/tools/convert_models/hornet2mmcls.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os.path as osp
+from collections import OrderedDict
+
+import mmcv
+import torch
+from mmcv.runner import CheckpointLoader
+
+
+def convert_hornet(ckpt):
+
+ new_ckpt = OrderedDict()
+
+ for k, v in list(ckpt.items()):
+ new_v = v
+ if k.startswith('head'):
+ new_k = k.replace('head.', 'head.fc.')
+ new_ckpt[new_k] = new_v
+ continue
+ elif k.startswith('norm'):
+ new_k = k.replace('norm.', 'norm3.')
+ elif 'gnconv.pws' in k:
+ new_k = k.replace('gnconv.pws', 'gnconv.projs')
+ elif 'gamma1' in k:
+ new_k = k.replace('gamma1', 'gamma1.weight')
+ elif 'gamma2' in k:
+ new_k = k.replace('gamma2', 'gamma2.weight')
+ else:
+ new_k = k
+
+ if not new_k.startswith('head'):
+ new_k = 'backbone.' + new_k
+ new_ckpt[new_k] = new_v
+ return new_ckpt
+
+
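+# Example invocation (hypothetical file names):
+#   python tools/convert_models/hornet2mmcls.py \
+#       hornet_tiny_official.pth hornet-tiny_3rdparty.pth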
+def main():
+ parser = argparse.ArgumentParser(
+        description='Convert keys in pretrained HorNet models to mmcls style.')
+ parser.add_argument('src', help='src model path or url')
+ # The dst path must be a full path of the new checkpoint.
+ parser.add_argument('dst', help='save path')
+ args = parser.parse_args()
+
+ checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
+
+ if 'model' in checkpoint:
+ state_dict = checkpoint['model']
+ else:
+ state_dict = checkpoint
+
+ weight = convert_hornet(state_dict)
+ mmcv.mkdir_or_exist(osp.dirname(args.dst))
+ torch.save(weight, args.dst)
+
+ print('Done!!')
+
+
+if __name__ == '__main__':
+ main()