[Feature] Support HorNet Backbone for dev1.x. (#1094)

* add hornet

* add hornet

* fix mixup config

* add optim cfgs

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
takuoko 2022-11-04 16:33:46 +09:00 committed by GitHub
parent b16938dc59
commit d05cbbcf9b
26 changed files with 1141 additions and 1 deletion

@@ -144,7 +144,7 @@ def show_summary(summary_data, args):
     if args.inference_time:
         table.add_column('Inference Time (std) (ms/im)')
     if args.flops:
-        table.add_column('Flops', justify='right')
+        table.add_column('Flops', justify='right', width=11)
     table.add_column('Params', justify='right')
     for model_name, summary in summary_data.items():

@@ -148,6 +148,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
 - [x] [MobileOne](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobileone)
 - [x] [EfficientFormer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/efficientformer)
 - [x] [MViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mvit)
+- [x] [HorNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/hornet)
 - [x] [MobileViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilevit)
 </details>

@@ -147,6 +147,7 @@ mim install -e .
 - [x] [MobileOne](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobileone)
 - [x] [EfficientFormer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/efficientformer)
 - [x] [MViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mvit)
+- [x] [HorNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/hornet)
 - [x] [MobileViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilevit)
 </details>

@@ -0,0 +1,20 @@
model = dict(
    type='ImageClassifier',
    backbone=dict(type='HorNet', arch='base-gf', drop_path_rate=0.5),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1024,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
        dict(type='Constant', layer=['LayerScale'], val=1e-6)
    ],
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0)
    ]))
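
As a quick sanity check, a base model config like the one above can be built directly through the registry. The following is a minimal sketch, assuming mmcls 1.x with mmengine is installed and the path is resolved from the repository root:

```python
from mmengine.config import Config

from mmcls.registry import MODELS
from mmcls.utils import register_all_modules

register_all_modules()  # register mmcls modules into the default scope

# Load the base model config defined above and build the classifier.
cfg = Config.fromfile('configs/_base_/models/hornet/hornet-base-gf.py')
model = MODELS.build(cfg.model)
print(type(model).__name__)  # ImageClassifier
```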

@@ -0,0 +1,21 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='HorNet', arch='base', drop_path_rate=0.5),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1024,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
        dict(type='Constant', layer=['LayerScale'], val=1e-6)
    ],
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0)
    ]))

@@ -0,0 +1,21 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='HorNet', arch='large-gf', drop_path_rate=0.2),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1536,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
        dict(type='Constant', layer=['LayerScale'], val=1e-6)
    ],
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0)
    ]))

@@ -0,0 +1,17 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='HorNet', arch='large-gf384', drop_path_rate=0.4),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1536,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
        dict(type='Constant', layer=['LayerScale'], val=1e-6)
    ])

@@ -0,0 +1,21 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='HorNet', arch='large', drop_path_rate=0.2),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1536,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
        dict(type='Constant', layer=['LayerScale'], val=1e-6)
    ],
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0)
    ]))

@@ -0,0 +1,21 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='HorNet', arch='small-gf', drop_path_rate=0.4),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=768,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
        dict(type='Constant', layer=['LayerScale'], val=1e-6)
    ],
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0)
    ]))

@@ -0,0 +1,21 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='HorNet', arch='small', drop_path_rate=0.4),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=768,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
        dict(type='Constant', layer=['LayerScale'], val=1e-6)
    ],
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0)
    ]))

@@ -0,0 +1,21 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='HorNet', arch='tiny-gf', drop_path_rate=0.2),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
        dict(type='Constant', layer=['LayerScale'], val=1e-6)
    ],
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0)
    ]))

@@ -0,0 +1,21 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='HorNet', arch='tiny', drop_path_rate=0.2),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
        dict(type='Constant', layer=['LayerScale'], val=1e-6)
    ],
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0)
    ]))

configs/hornet/README.md (new file)

@@ -0,0 +1,51 @@
# HorNet

> [HorNet: Efficient High-Order Spatial Interactions with Recursive Gated Convolutions](https://arxiv.org/pdf/2207.14284v2.pdf)

<!-- [ALGORITHM] -->

## Abstract

Recent progress in vision Transformers exhibits great success in various tasks driven by the new spatial modeling mechanism based on dot-product self-attention. In this paper, we show that the key ingredients behind the vision Transformers, namely input-adaptive, long-range and high-order spatial interactions, can also be efficiently implemented with a convolution-based framework. We present the Recursive Gated Convolution (gnConv) that performs high-order spatial interactions with gated convolutions and recursive designs. The new operation is highly flexible and customizable, which is compatible with various variants of convolution and extends the two-order interactions in self-attention to arbitrary orders without introducing significant extra computation. gnConv can serve as a plug-and-play module to improve various vision Transformers and convolution-based models. Based on the operation, we construct a new family of generic vision backbones named HorNet. Extensive experiments on ImageNet classification, COCO object detection and ADE20K semantic segmentation show that HorNet outperforms Swin Transformers and ConvNeXt by a significant margin with similar overall architecture and training configurations. HorNet also shows favorable scalability to more training data and larger model sizes. Apart from its effectiveness in visual encoders, we also show that gnConv can be applied to task-specific decoders and consistently improve dense prediction performance with less computation. Our results demonstrate that gnConv can be a new basic module for visual modeling that effectively combines the merits of both vision Transformers and CNNs. Code is available at https://github.com/raoyongming/HorNet.

<div align=center>
<img src="https://user-images.githubusercontent.com/24734142/188356236-b8e3db94-eaa6-48e9-b323-15e5ba7f2991.png" width="80%"/>
</div>

## Results and models

### ImageNet-1k

| Model | Pretrain | resolution | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
| :-----------: | :----------: | :--------: | :-------: | :------: | :-------: | :-------: | :--------------------------------------------------------------: | :----------------------------------------------------------------: |
| HorNet-T\* | From scratch | 224x224 | 22.41 | 3.98 | 82.84 | 96.24 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hornet/hornet-tiny_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-tiny_3rdparty_in1k_20220915-0e8eedff.pth) |
| HorNet-T-GF\* | From scratch | 224x224 | 22.99 | 3.9 | 82.98 | 96.38 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hornet/hornet-tiny-gf_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-tiny-gf_3rdparty_in1k_20220915-4c35a66b.pth) |
| HorNet-S\* | From scratch | 224x224 | 49.53 | 8.83 | 83.79 | 96.75 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hornet/hornet-small_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-small_3rdparty_in1k_20220915-5935f60f.pth) |
| HorNet-S-GF\* | From scratch | 224x224 | 50.4 | 8.71 | 83.98 | 96.77 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hornet/hornet-small-gf_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-small-gf_3rdparty_in1k_20220915-649ca492.pth) |
| HorNet-B\* | From scratch | 224x224 | 87.26 | 15.59 | 84.24 | 96.94 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hornet/hornet-base_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-base_3rdparty_in1k_20220915-a06176bb.pth) |
| HorNet-B-GF\* | From scratch | 224x224 | 88.42 | 15.42 | 84.32 | 96.95 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hornet/hornet-base-gf_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-base-gf_3rdparty_in1k_20220915-82c06fa7.pth) |

\*Models with * are converted from [the official repo](https://github.com/raoyongming/HorNet). The config files of these models are only for validation. We don't ensure the training accuracy of these config files and welcome you to contribute your reproduction results.

### Pre-trained Models

The pre-trained models on ImageNet-21k are used to fine-tune on downstream tasks.

| Model | Pretrain | resolution | Params(M) | Flops(G) | Download |
| :--------------: | :----------: | :--------: | :-------: | :------: | :------------------------------------------------------------------------------------------------------------------------: |
| HorNet-L\* | ImageNet-21k | 224x224 | 194.54 | 34.83 | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-large_3rdparty_in21k_20220909-9ccef421.pth) |
| HorNet-L-GF\* | ImageNet-21k | 224x224 | 196.29 | 34.58 | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-large-gf_3rdparty_in21k_20220909-3aea3b61.pth) |
| HorNet-L-GF384\* | ImageNet-21k | 384x384 | 201.23 | 101.63 | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-large-gf384_3rdparty_in21k_20220909-80894290.pth) |

\*Models with * are converted from [the official repo](https://github.com/raoyongming/HorNet).
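
The snippet below is a minimal inference sketch; it assumes mmcls 1.x is installed and downloads the converted HorNet-T weights listed above.

```python
import torch

from mmcls.apis import init_model

config = 'configs/hornet/hornet-tiny_8xb128_in1k.py'
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/hornet/hornet-tiny_3rdparty_in1k_20220915-0e8eedff.pth'

model = init_model(config, checkpoint, device='cpu')
model.eval()

# mode='tensor' (the default forward mode) returns raw class logits.
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```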

## Citation

```
@article{rao2022hornet,
  title={HorNet: Efficient High-Order Spatial Interactions with Recursive Gated Convolutions},
  author={Rao, Yongming and Zhao, Wenliang and Tang, Yansong and Zhou, Jie and Lim, Ser-Nam and Lu, Jiwen},
  journal={arXiv preprint arXiv:2207.14284},
  year={2022}
}
```

@@ -0,0 +1,12 @@
_base_ = [
    '../_base_/models/hornet/hornet-base-gf.py',
    '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py',
]

train_dataloader = dict(batch_size=64)

optim_wrapper = dict(optimizer=dict(lr=4e-3), clip_grad=dict(max_norm=1.0))

custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
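
For orientation, a hedged sketch of how this file merges with its `_base_` configs (assumes mmengine is installed and the path is resolved from the repository root):

```python
from mmengine.config import Config

# _base_ files are merged recursively; keys set here override the base values.
cfg = Config.fromfile('configs/hornet/hornet-base-gf_8xb64_in1k.py')

print(cfg.model.backbone)  # dict(type='HorNet', arch='base-gf', drop_path_rate=0.5)
print(cfg.optim_wrapper)   # AdamW from the swin schedule, lr overridden to 4e-3
print(cfg.custom_hooks)    # the EMA hook declared above
```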

@@ -0,0 +1,12 @@
_base_ = [
    '../_base_/models/hornet/hornet-base.py',
    '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py',
]

train_dataloader = dict(batch_size=64)

optim_wrapper = dict(optimizer=dict(lr=4e-3), clip_grad=dict(max_norm=5.0))

custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]

@@ -0,0 +1,12 @@
_base_ = [
    '../_base_/models/hornet/hornet-small-gf.py',
    '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py',
]

train_dataloader = dict(batch_size=64)

optim_wrapper = dict(optimizer=dict(lr=4e-3), clip_grad=dict(max_norm=1.0))

custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]

@@ -0,0 +1,12 @@
_base_ = [
    '../_base_/models/hornet/hornet-small.py',
    '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py',
]

train_dataloader = dict(batch_size=64)

optim_wrapper = dict(optimizer=dict(lr=4e-3), clip_grad=dict(max_norm=5.0))

custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]

@@ -0,0 +1,12 @@
_base_ = [
    '../_base_/models/hornet/hornet-tiny-gf.py',
    '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py',
]

train_dataloader = dict(batch_size=128)

optim_wrapper = dict(optimizer=dict(lr=4e-3), clip_grad=dict(max_norm=1.0))

custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]

@@ -0,0 +1,12 @@
_base_ = [
    '../_base_/models/hornet/hornet-tiny.py',
    '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py',
]

train_dataloader = dict(batch_size=128)

optim_wrapper = dict(optimizer=dict(lr=4e-3), clip_grad=dict(max_norm=100.0))

custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]

@@ -0,0 +1,97 @@
Collections:
  - Name: HorNet
    Metadata:
      Training Data: ImageNet-1k
      Training Techniques:
        - AdamW
        - Weight Decay
      Architecture:
        - HorNet
        - gnConv
    Paper:
      URL: https://arxiv.org/pdf/2207.14284v2.pdf
      Title: "HorNet: Efficient High-Order Spatial Interactions with Recursive Gated Convolutions"
    README: configs/hornet/README.md
    Code:
      Version: v0.24.0
      URL: https://github.com/open-mmlab/mmclassification/blob/v0.24.0/mmcls/models/backbones/hornet.py

Models:
  - Name: hornet-tiny_3rdparty_in1k
    Metadata:
      FLOPs: 3976156352  # 3.98G
      Parameters: 22409512  # 22.41M
    In Collection: HorNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 82.84
          Top 5 Accuracy: 96.24
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/hornet/hornet-tiny_3rdparty_in1k_20220915-0e8eedff.pth
    Config: configs/hornet/hornet-tiny_8xb128_in1k.py
  - Name: hornet-tiny-gf_3rdparty_in1k
    Metadata:
      FLOPs: 3896472160  # 3.9G
      Parameters: 22991848  # 22.99M
    In Collection: HorNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 82.98
          Top 5 Accuracy: 96.38
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/hornet/hornet-tiny-gf_3rdparty_in1k_20220915-4c35a66b.pth
    Config: configs/hornet/hornet-tiny-gf_8xb128_in1k.py
  - Name: hornet-small_3rdparty_in1k
    Metadata:
      FLOPs: 8825621280  # 8.83G
      Parameters: 49528264  # 49.53M
    In Collection: HorNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 83.79
          Top 5 Accuracy: 96.75
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/hornet/hornet-small_3rdparty_in1k_20220915-5935f60f.pth
    Config: configs/hornet/hornet-small_8xb64_in1k.py
  - Name: hornet-small-gf_3rdparty_in1k
    Metadata:
      FLOPs: 8706094992  # 8.71G
      Parameters: 50401768  # 50.4M
    In Collection: HorNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 83.98
          Top 5 Accuracy: 96.77
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/hornet/hornet-small-gf_3rdparty_in1k_20220915-649ca492.pth
    Config: configs/hornet/hornet-small-gf_8xb64_in1k.py
  - Name: hornet-base_3rdparty_in1k
    Metadata:
      FLOPs: 15582677376  # 15.59G
      Parameters: 87256680  # 87.26M
    In Collection: HorNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 84.24
          Top 5 Accuracy: 96.94
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/hornet/hornet-base_3rdparty_in1k_20220915-a06176bb.pth
    Config: configs/hornet/hornet-base_8xb64_in1k.py
  - Name: hornet-base-gf_3rdparty_in1k
    Metadata:
      FLOPs: 15423308992  # 15.42G
      Parameters: 88421352  # 88.42M
    In Collection: HorNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 84.32
          Top 5 Accuracy: 96.95
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/hornet/hornet-base-gf_3rdparty_in1k_20220915-82c06fa7.pth
    Config: configs/hornet/hornet-base-gf_8xb64_in1k.py
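
The parameter counts above can be cross-checked by building the corresponding model; a hedged sketch (assumes mmcls 1.x; the count covers the backbone plus the 1000-class head):

```python
from mmengine.config import Config

from mmcls.registry import MODELS
from mmcls.utils import register_all_modules

register_all_modules()
cfg = Config.fromfile('configs/hornet/hornet-tiny_8xb128_in1k.py')
model = MODELS.build(cfg.model)

n_params = sum(p.numel() for p in model.parameters())
print(f'{n_params:,}')  # expected to match the 22,409,512 listed for HorNet-T
```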

@@ -69,6 +69,7 @@ Backbones
 EdgeNeXt
 EfficientFormer
 EfficientNet
+HorNet
 HRNet
 InceptionV3
 LeNet5

@@ -10,6 +10,7 @@ from .densenet import DenseNet
 from .edgenext import EdgeNeXt
 from .efficientformer import EfficientFormer
 from .efficientnet import EfficientNet
+from .hornet import HorNet
 from .hrnet import HRNet
 from .inception_v3 import InceptionV3
 from .lenet import LeNet5
@@ -90,5 +91,6 @@ __all__ = [
     'SwinTransformerV2',
     'MViT',
     'DeiT3',
+    'HorNet',
     'MobileViT',
 ]

@@ -0,0 +1,495 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from official impl at https://github.com/raoyongming/HorNet.
try:
    import torch.fft
    fft = True
except ImportError:
    fft = None

import copy
from functools import partial
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from mmcv.cnn.bricks import DropPath

from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcls.registry import MODELS
from ..utils import LayerScale


def get_dwconv(dim, kernel_size, bias=True):
    """Build a depth-wise convolution."""
    return nn.Conv2d(
        dim,
        dim,
        kernel_size=kernel_size,
        padding=(kernel_size - 1) // 2,
        bias=bias,
        groups=dim)


class HorNetLayerNorm(nn.Module):
    """An implementation of LayerNorm of HorNet.

    The differences between HorNetLayerNorm & torch LayerNorm:
        1. Supports two data formats channels_last or channels_first.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an
            expected input of size.
        eps (float): a value added to the denominator for numerical stability.
            Defaults to 1e-6.
        data_format (str): The ordering of the dimensions in the inputs.
            channels_last corresponds to inputs with shape (batch_size,
            height, width, channels) while channels_first corresponds to
            inputs with shape (batch_size, channels, height, width).
            Defaults to 'channels_last'.
    """

    def __init__(self,
                 normalized_shape,
                 eps=1e-6,
                 data_format='channels_last'):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ['channels_last', 'channels_first']:
            raise ValueError(
                'data_format must be channels_last or channels_first')
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == 'channels_last':
            return F.layer_norm(x, self.normalized_shape, self.weight,
                                self.bias, self.eps)
        elif self.data_format == 'channels_first':
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


class GlobalLocalFilter(nn.Module):
    """A GlobalLocalFilter of HorNet.

    Args:
        dim (int): Number of input channels.
        h (int): Height of complex_weight. Defaults to 14.
        w (int): Width of complex_weight. Defaults to 8.
    """

    def __init__(self, dim, h=14, w=8):
        super().__init__()
        self.dw = nn.Conv2d(
            dim // 2,
            dim // 2,
            kernel_size=3,
            padding=1,
            bias=False,
            groups=dim // 2)
        self.complex_weight = nn.Parameter(
            torch.randn(dim // 2, h, w, 2, dtype=torch.float32) * 0.02)
        self.pre_norm = HorNetLayerNorm(
            dim, eps=1e-6, data_format='channels_first')
        self.post_norm = HorNetLayerNorm(
            dim, eps=1e-6, data_format='channels_first')

    def forward(self, x):
        x = self.pre_norm(x)
        x1, x2 = torch.chunk(x, 2, dim=1)
        x1 = self.dw(x1)

        x2 = x2.to(torch.float32)
        B, C, a, b = x2.shape
        x2 = torch.fft.rfft2(x2, dim=(2, 3), norm='ortho')

        weight = self.complex_weight
        if not weight.shape[1:3] == x2.shape[2:4]:
            weight = F.interpolate(
                weight.permute(3, 0, 1, 2),
                size=x2.shape[2:4],
                mode='bilinear',
                align_corners=True).permute(1, 2, 3, 0)
        weight = torch.view_as_complex(weight.contiguous())

        x2 = x2 * weight
        x2 = torch.fft.irfft2(x2, s=(a, b), dim=(2, 3), norm='ortho')

        x = torch.cat([x1.unsqueeze(2), x2.unsqueeze(2)],
                      dim=2).reshape(B, 2 * C, a, b)
        x = self.post_norm(x)
        return x


class gnConv(nn.Module):
    """A gnConv of HorNet.

    Args:
        dim (int): Number of input channels.
        order (int): Order of gnConv. Defaults to 5.
        dw_cfg (dict): The Config for dw conv.
            Defaults to ``dict(type='DW', kernel_size=7)``.
        scale (float): Scaling parameter of gflayer outputs.
            Defaults to 1.0.
    """

    def __init__(self,
                 dim,
                 order=5,
                 dw_cfg=dict(type='DW', kernel_size=7),
                 scale=1.0):
        super().__init__()
        self.order = order
        self.dims = [dim // 2**i for i in range(order)]
        self.dims.reverse()
        self.proj_in = nn.Conv2d(dim, 2 * dim, 1)

        cfg = copy.deepcopy(dw_cfg)
        dw_type = cfg.pop('type')
        assert dw_type in ['DW', 'GF'], \
            'dw_type should be `DW` or `GF`'
        if dw_type == 'DW':
            self.dwconv = get_dwconv(sum(self.dims), **cfg)
        elif dw_type == 'GF':
            self.dwconv = GlobalLocalFilter(sum(self.dims), **cfg)

        self.proj_out = nn.Conv2d(dim, dim, 1)

        self.projs = nn.ModuleList([
            nn.Conv2d(self.dims[i], self.dims[i + 1], 1)
            for i in range(order - 1)
        ])

        self.scale = scale

    def forward(self, x):
        x = self.proj_in(x)
        y, x = torch.split(x, (self.dims[0], sum(self.dims)), dim=1)

        x = self.dwconv(x) * self.scale

        dw_list = torch.split(x, self.dims, dim=1)
        x = y * dw_list[0]

        for i in range(self.order - 1):
            x = self.projs[i](x) * dw_list[i + 1]

        x = self.proj_out(x)
        return x


class HorNetBlock(nn.Module):
    """A block of HorNet.

    Args:
        dim (int): Number of input channels.
        order (int): Order of gnConv. Defaults to 5.
        dw_cfg (dict): The Config for dw conv.
            Defaults to ``dict(type='DW', kernel_size=7)``.
        scale (float): Scaling parameter of gflayer outputs.
            Defaults to 1.0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        use_layer_scale (bool): Whether to use layer scale in the HorNet
            block. Defaults to True.
    """

    def __init__(self,
                 dim,
                 order=5,
                 dw_cfg=dict(type='DW', kernel_size=7),
                 scale=1.0,
                 drop_path_rate=0.,
                 use_layer_scale=True):
        super().__init__()
        self.out_channels = dim

        self.norm1 = HorNetLayerNorm(
            dim, eps=1e-6, data_format='channels_first')
        self.gnconv = gnConv(dim, order, dw_cfg, scale)
        self.norm2 = HorNetLayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)

        if use_layer_scale:
            self.gamma1 = LayerScale(dim, data_format='channels_first')
            self.gamma2 = LayerScale(dim)
        else:
            self.gamma1, self.gamma2 = nn.Identity(), nn.Identity()

        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path(self.gamma1(self.gnconv(self.norm1(x))))

        input = x
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm2(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        x = self.gamma2(x)
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x


@MODELS.register_module()
class HorNet(BaseBackbone):
    """HorNet backbone.

    A PyTorch implementation of `HorNet: Efficient High-Order Spatial
    Interactions with Recursive Gated Convolutions
    <https://arxiv.org/abs/2207.14284>`_.

    Inspiration from https://github.com/raoyongming/HorNet

    Args:
        arch (str | dict): HorNet architecture. If a string, choose from
            'tiny', 'small', 'base' and 'large'. If a dict, it should have
            the below keys:

            - **base_dim** (int): The base dimension of the embedding.
            - **depths** (List[int]): The number of blocks in each stage.
            - **orders** (List[int]): The order of gnConv in each stage.
            - **dw_cfg** (List[dict]): The Config for the dw conv in each
              stage.

            Defaults to 'tiny'.
        in_channels (int): Number of input image channels. Defaults to 3.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        scale (float): Scaling parameter of gflayer outputs. Defaults to 1/3.
        use_layer_scale (bool): Whether to use layer scale in the HorNet
            blocks. Defaults to True.
        out_indices (Sequence[int]): Output from which stages.
            Defaults to ``(3, )``.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        gap_before_final_norm (bool): Whether to globally average the feature
            map before the final norm layer. In the official repo, it's only
            used in the classification task. Defaults to True.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """
    arch_zoo = {
        **dict.fromkeys(['t', 'tiny'],
                        {'base_dim': 64,
                         'depths': [2, 3, 18, 2],
                         'orders': [2, 3, 4, 5],
                         'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}),
        **dict.fromkeys(['t-gf', 'tiny-gf'],
                        {'base_dim': 64,
                         'depths': [2, 3, 18, 2],
                         'orders': [2, 3, 4, 5],
                         'dw_cfg': [
                             dict(type='DW', kernel_size=7),
                             dict(type='DW', kernel_size=7),
                             dict(type='GF', h=14, w=8),
                             dict(type='GF', h=7, w=4)]}),
        **dict.fromkeys(['s', 'small'],
                        {'base_dim': 96,
                         'depths': [2, 3, 18, 2],
                         'orders': [2, 3, 4, 5],
                         'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}),
        **dict.fromkeys(['s-gf', 'small-gf'],
                        {'base_dim': 96,
                         'depths': [2, 3, 18, 2],
                         'orders': [2, 3, 4, 5],
                         'dw_cfg': [
                             dict(type='DW', kernel_size=7),
                             dict(type='DW', kernel_size=7),
                             dict(type='GF', h=14, w=8),
                             dict(type='GF', h=7, w=4)]}),
        **dict.fromkeys(['b', 'base'],
                        {'base_dim': 128,
                         'depths': [2, 3, 18, 2],
                         'orders': [2, 3, 4, 5],
                         'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}),
        **dict.fromkeys(['b-gf', 'base-gf'],
                        {'base_dim': 128,
                         'depths': [2, 3, 18, 2],
                         'orders': [2, 3, 4, 5],
                         'dw_cfg': [
                             dict(type='DW', kernel_size=7),
                             dict(type='DW', kernel_size=7),
                             dict(type='GF', h=14, w=8),
                             dict(type='GF', h=7, w=4)]}),
        **dict.fromkeys(['b-gf384', 'base-gf384'],
                        {'base_dim': 128,
                         'depths': [2, 3, 18, 2],
                         'orders': [2, 3, 4, 5],
                         'dw_cfg': [
                             dict(type='DW', kernel_size=7),
                             dict(type='DW', kernel_size=7),
                             dict(type='GF', h=24, w=12),
                             dict(type='GF', h=13, w=7)]}),
        **dict.fromkeys(['l', 'large'],
                        {'base_dim': 192,
                         'depths': [2, 3, 18, 2],
                         'orders': [2, 3, 4, 5],
                         'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}),
        **dict.fromkeys(['l-gf', 'large-gf'],
                        {'base_dim': 192,
                         'depths': [2, 3, 18, 2],
                         'orders': [2, 3, 4, 5],
                         'dw_cfg': [
                             dict(type='DW', kernel_size=7),
                             dict(type='DW', kernel_size=7),
                             dict(type='GF', h=14, w=8),
                             dict(type='GF', h=7, w=4)]}),
        **dict.fromkeys(['l-gf384', 'large-gf384'],
                        {'base_dim': 192,
                         'depths': [2, 3, 18, 2],
                         'orders': [2, 3, 4, 5],
                         'dw_cfg': [
                             dict(type='DW', kernel_size=7),
                             dict(type='DW', kernel_size=7),
                             dict(type='GF', h=24, w=12),
                             dict(type='GF', h=13, w=7)]}),
    }  # yapf: disable

    def __init__(self,
                 arch='tiny',
                 in_channels=3,
                 drop_path_rate=0.,
                 scale=1 / 3,
                 use_layer_scale=True,
                 out_indices=(3, ),
                 frozen_stages=-1,
                 with_cp=False,
                 gap_before_final_norm=True,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        if fft is None:
            raise RuntimeError(
                'Failed to import torch.fft. Please install "torch>=1.7".')

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {'base_dim', 'depths', 'orders', 'dw_cfg'}
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.scale = scale
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.with_cp = with_cp
        self.gap_before_final_norm = gap_before_final_norm

        base_dim = self.arch_settings['base_dim']
        dims = list(map(lambda x: 2**x * base_dim, range(4)))

        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(in_channels, dims[0], kernel_size=4, stride=4),
            HorNetLayerNorm(dims[0], eps=1e-6, data_format='channels_first'))
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                HorNetLayerNorm(
                    dims[i], eps=1e-6, data_format='channels_first'),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        total_depth = sum(self.arch_settings['depths'])
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]  # stochastic depth decay rule

        cur_block_idx = 0
        self.stages = nn.ModuleList()
        for i in range(4):
            stage = nn.Sequential(*[
                HorNetBlock(
                    dim=dims[i],
                    order=self.arch_settings['orders'][i],
                    dw_cfg=self.arch_settings['dw_cfg'][i],
                    scale=self.scale,
                    drop_path_rate=dpr[cur_block_idx + j],
                    use_layer_scale=use_layer_scale)
                for j in range(self.arch_settings['depths'][i])
            ])
            self.stages.append(stage)
            cur_block_idx += self.arch_settings['depths'][i]

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = len(self.stages) + index
            assert 0 <= out_indices[i] <= len(self.stages), \
                f'Invalid out_indices {index}.'
        self.out_indices = out_indices

        norm_layer = partial(
            HorNetLayerNorm, eps=1e-6, data_format='channels_first')
        for i_layer in out_indices:
            layer = norm_layer(dims[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

    def train(self, mode=True):
        super(HorNet, self).train(mode)
        self._freeze_stages()

    def _freeze_stages(self):
        for i in range(0, self.frozen_stages + 1):
            # freeze patch embed
            m = self.downsample_layers[i]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

            # freeze blocks
            m = self.stages[i]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

            if i in self.out_indices:
                # freeze the norm layer registered for this stage; it is
                # named `norm{i}` in __init__.
                m = getattr(self, f'norm{i}')
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def forward(self, x):
        outs = []
        for i in range(4):
            x = self.downsample_layers[i](x)
            if self.with_cp:
                x = checkpoint.checkpoint_sequential(self.stages[i],
                                                     len(self.stages[i]), x)
            else:
                x = self.stages[i](x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                if self.gap_before_final_norm:
                    gap = x.mean([-2, -1], keepdim=True)
                    outs.append(norm_layer(gap).flatten(1))
                else:
                    # The output of LayerNorm2d may be discontiguous, which
                    # may cause some problems in downstream tasks.
                    outs.append(norm_layer(x).contiguous())
        return tuple(outs)
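
For reference, a minimal smoke test of the backbone above (assumes mmcls 1.x and torch>=1.7 so that `torch.fft` is available); the expected shapes follow from the stride-4 stem, the three stride-2 downsample layers, and `base_dim=64` for the 'tiny' arch:

```python
import torch

from mmcls.models import HorNet

# Request all four stage outputs as spatial feature maps.
model = HorNet(arch='tiny', out_indices=(0, 1, 2, 3),
               gap_before_final_norm=False)
model.eval()

with torch.no_grad():
    outs = model(torch.randn(1, 3, 224, 224))
for out in outs:
    print(tuple(out.shape))
# (1, 64, 56, 56), (1, 128, 28, 28), (1, 256, 14, 14), (1, 512, 7, 7)
```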

@@ -34,4 +34,5 @@ Import:
 - configs/efficientformer/metafile.yml
 - configs/swin_transformer_v2/metafile.yml
 - configs/deit3/metafile.yml
+- configs/hornet/metafile.yml
 - configs/mobilevit/metafile.yml

@@ -0,0 +1,174 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from copy import deepcopy
from itertools import chain
from unittest import TestCase

import pytest
import torch
from mmengine.utils import digit_version
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from torch import nn

from mmcls.models.backbones import HorNet


def check_norm_state(modules, train_state):
    """Check if norm layer is in correct train state."""
    for mod in modules:
        if isinstance(mod, _BatchNorm):
            if mod.training != train_state:
                return False
    return True


@pytest.mark.skipif(
    digit_version(torch.__version__) < digit_version('1.7.0'),
    reason='torch.fft is not available before 1.7.0')
class TestHorNet(TestCase):

    def setUp(self):
        self.cfg = dict(
            arch='t', drop_path_rate=0.1, gap_before_final_norm=False)

    def test_arch(self):
        # Test invalid default arch
        with self.assertRaisesRegex(AssertionError, 'not in default archs'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = 'unknown'
            HorNet(**cfg)

        # Test invalid custom arch
        with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = {
                'depths': [1, 1, 1, 1],
                'orders': [1, 1, 1, 1],
            }
            HorNet(**cfg)

        # Test custom arch
        cfg = deepcopy(self.cfg)
        base_dim = 64
        depths = [2, 3, 18, 2]
        embed_dims = [base_dim, base_dim * 2, base_dim * 4, base_dim * 8]
        cfg['arch'] = {
            'base_dim': base_dim,
            'depths': depths,
            'orders': [2, 3, 4, 5],
            'dw_cfg': [
                dict(type='DW', kernel_size=7),
                dict(type='DW', kernel_size=7),
                dict(type='GF', h=14, w=8),
                dict(type='GF', h=7, w=4)
            ],
        }
        model = HorNet(**cfg)
        for i in range(len(depths)):
            stage = model.stages[i]
            self.assertEqual(stage[-1].out_channels, embed_dims[i])
            self.assertEqual(len(stage), depths[i])

    def test_init_weights(self):
        # test weight init cfg
        cfg = deepcopy(self.cfg)
        cfg['init_cfg'] = [
            dict(
                type='Kaiming',
                layer='Conv2d',
                mode='fan_in',
                nonlinearity='linear')
        ]
        model = HorNet(**cfg)
        ori_weight = model.downsample_layers[0][0].weight.clone().detach()

        model.init_weights()
        initialized_weight = model.downsample_layers[0][0].weight
        self.assertFalse(torch.allclose(ori_weight, initialized_weight))

    def test_forward(self):
        imgs = torch.randn(3, 3, 224, 224)

        cfg = deepcopy(self.cfg)
        model = HorNet(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        feat = outs[-1]
        self.assertEqual(feat.shape, (3, 512, 7, 7))

        # test multiple output indices
        cfg = deepcopy(self.cfg)
        cfg['out_indices'] = (0, 1, 2, 3)
        model = HorNet(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 4)
        for emb_size, stride, out in zip([64, 128, 256, 512], [1, 2, 4, 8],
                                         outs):
            self.assertEqual(out.shape,
                             (3, emb_size, 56 // stride, 56 // stride))

        # test with dynamic input shape
        imgs1 = torch.randn(3, 3, 224, 224)
        imgs2 = torch.randn(3, 3, 256, 256)
        imgs3 = torch.randn(3, 3, 256, 309)
        cfg = deepcopy(self.cfg)
        model = HorNet(**cfg)
        for imgs in [imgs1, imgs2, imgs3]:
            outs = model(imgs)
            self.assertIsInstance(outs, tuple)
            self.assertEqual(len(outs), 1)
            feat = outs[-1]
            expect_feat_shape = (math.floor(imgs.shape[2] / 32),
                                 math.floor(imgs.shape[3] / 32))
            self.assertEqual(feat.shape, (3, 512, *expect_feat_shape))

    def test_structure(self):
        # test drop_path_rate decay
        cfg = deepcopy(self.cfg)
        cfg['drop_path_rate'] = 0.2
        model = HorNet(**cfg)
        depths = model.arch_settings['depths']
        stages = model.stages
        blocks = chain(*[stage for stage in stages])
        total_depth = sum(depths)
        dpr = [
            x.item()
            for x in torch.linspace(0, cfg['drop_path_rate'], total_depth)
        ]
        for i, (block, expect_prob) in enumerate(zip(blocks, dpr)):
            if expect_prob == 0:
                assert isinstance(block.drop_path, nn.Identity)
            else:
                self.assertAlmostEqual(block.drop_path.drop_prob, expect_prob)

        # test HorNet with the first stage frozen.
        cfg = deepcopy(self.cfg)
        frozen_stages = 0
        cfg['frozen_stages'] = frozen_stages
        cfg['out_indices'] = (0, 1, 2, 3)
        model = HorNet(**cfg)
        model.init_weights()
        model.train()

        # the patch_embed and first stage should not require grad.
        for i in range(frozen_stages + 1):
            down = model.downsample_layers[i]
            for param in down.parameters():
                self.assertFalse(param.requires_grad)
            blocks = model.stages[i]
            for param in blocks.parameters():
                self.assertFalse(param.requires_grad)

        # the subsequent stages should require grad.
        for i in range(frozen_stages + 1, 4):
            down = model.downsample_layers[i]
            for param in down.parameters():
                self.assertTrue(param.requires_grad)
            blocks = model.stages[i]
            for param in blocks.parameters():
                self.assertTrue(param.requires_grad)
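
A hedged way to run the suite above; the file path is an assumed location rather than something stated in this diff:

```python
# The skipif marker on TestHorNet handles torch<1.7 automatically.
import subprocess

subprocess.run(
    ['pytest', 'tests/test_models/test_backbones/test_hornet.py', '-v'],
    check=True)
```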

@@ -0,0 +1,61 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_hornet(ckpt):

    new_ckpt = OrderedDict()

    for k, v in list(ckpt.items()):
        new_v = v
        if k.startswith('head'):
            new_k = k.replace('head.', 'head.fc.')
            new_ckpt[new_k] = new_v
            continue
        elif k.startswith('norm'):
            new_k = k.replace('norm.', 'norm3.')
        elif 'gnconv.pws' in k:
            new_k = k.replace('gnconv.pws', 'gnconv.projs')
        elif 'gamma1' in k:
            new_k = k.replace('gamma1', 'gamma1.weight')
        elif 'gamma2' in k:
            new_k = k.replace('gamma2', 'gamma2.weight')
        else:
            new_k = k
        if not new_k.startswith('head'):
            new_k = 'backbone.' + new_k
        new_ckpt[new_k] = new_v
    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained hornet models to mmcls '
        'style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')

    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    weight = convert_hornet(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)

    print('Done!!')


if __name__ == '__main__':
    main()
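
A hedged sketch of using the converter programmatically; the checkpoint filenames below are placeholders, and `convert_hornet` is the function defined above:

```python
# Equivalent to the CLI usage of this script, e.g. (script path assumed):
#   python tools/model_converters/hornet2mmcls.py hornet_tiny.pth out.pth
import torch
from mmengine.runner import CheckpointLoader

ckpt = CheckpointLoader.load_checkpoint('hornet_tiny.pth', map_location='cpu')
state_dict = ckpt['model'] if 'model' in ckpt else ckpt
new_state = convert_hornet(state_dict)  # remap official keys to mmcls names
torch.save(new_state, 'out.pth')
```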