[Feature] Support HorNet Backbone for dev1.x. (#1094)
* add hornet * add hornet * fix mixup config * add optim cfgs Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>pull/1102/head
parent
b16938dc59
commit
d05cbbcf9b
|
@ -144,7 +144,7 @@ def show_summary(summary_data, args):
|
|||
if args.inference_time:
|
||||
table.add_column('Inference Time (std) (ms/im)')
|
||||
if args.flops:
|
||||
table.add_column('Flops', justify='right')
|
||||
table.add_column('Flops', justify='right', width=11)
|
||||
table.add_column('Params', justify='right')
|
||||
|
||||
for model_name, summary in summary_data.items():
|
||||
|
|
|
@ -148,6 +148,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
|
|||
- [x] [MobileOne](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobileone)
|
||||
- [x] [EfficientFormer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/efficientformer)
|
||||
- [x] [MViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mvit)
|
||||
- [x] [HorNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hornet)
|
||||
- [x] [MobileViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilevit)
|
||||
|
||||
</details>
|
||||
|
|
|
@ -147,6 +147,7 @@ mim install -e .
|
|||
- [x] [MobileOne](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobileone)
|
||||
- [x] [EfficientFormer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/efficientformer)
|
||||
- [x] [MViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mvit)
|
||||
- [x] [HorNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hornet)
|
||||
- [x] [MobileViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilevit)
|
||||
|
||||
</details>
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='HorNet', arch='base-gf', drop_path_rate=0.5),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=1024,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||
dict(type='Constant', layer=['LayerScale'], val=1e-6)
|
||||
],
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0)
|
||||
]))
|
|
@ -0,0 +1,21 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='HorNet', arch='base', drop_path_rate=0.5),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=1024,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||
dict(type='Constant', layer=['LayerScale'], val=1e-6)
|
||||
],
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0)
|
||||
]))
|
|
@ -0,0 +1,21 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='HorNet', arch='large-gf', drop_path_rate=0.2),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=1536,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||
dict(type='Constant', layer=['LayerScale'], val=1e-6)
|
||||
],
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0)
|
||||
]))
|
|
@ -0,0 +1,17 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='HorNet', arch='large-gf384', drop_path_rate=0.4),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=1536,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||
dict(type='Constant', layer=['LayerScale'], val=1e-6)
|
||||
])
|
|
@ -0,0 +1,21 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='HorNet', arch='large', drop_path_rate=0.2),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=1536,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||
dict(type='Constant', layer=['LayerScale'], val=1e-6)
|
||||
],
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0)
|
||||
]))
|
|
@ -0,0 +1,21 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='HorNet', arch='small-gf', drop_path_rate=0.4),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=768,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||
dict(type='Constant', layer=['LayerScale'], val=1e-6)
|
||||
],
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0)
|
||||
]))
|
|
@ -0,0 +1,21 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='HorNet', arch='small', drop_path_rate=0.4),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=768,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||
dict(type='Constant', layer=['LayerScale'], val=1e-6)
|
||||
],
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0)
|
||||
]))
|
|
@ -0,0 +1,21 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='HorNet', arch='tiny-gf', drop_path_rate=0.2),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=512,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||
dict(type='Constant', layer=['LayerScale'], val=1e-6)
|
||||
],
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0)
|
||||
]))
|
|
@ -0,0 +1,21 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='HorNet', arch='tiny', drop_path_rate=0.2),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=512,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||
dict(type='Constant', layer=['LayerScale'], val=1e-6)
|
||||
],
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0)
|
||||
]))
|
|
@ -0,0 +1,51 @@
|
|||
# HorNet
|
||||
|
||||
> [HorNet: Efficient High-Order Spatial Interactions with Recursive Gated Convolutions](https://arxiv.org/pdf/2207.14284v2.pdf)
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
Recent progress in vision Transformers exhibits great success in various tasks driven by the new spatial modeling mechanism based on dot-product self-attention. In this paper, we show that the key ingredients behind the vision Transformers, namely input-adaptive, long-range and high-order spatial interactions, can also be efficiently implemented with a convolution-based framework. We present the Recursive Gated Convolution (g nConv) that performs high-order spatial interactions with gated convolutions and recursive designs. The new operation is highly flexible and customizable, which is compatible with various variants of convolution and extends the two-order interactions in self-attention to arbitrary orders without introducing significant extra computation. g nConv can serve as a plug-and-play module to improve various vision Transformers and convolution-based models. Based on the operation, we construct a new family of generic vision backbones named HorNet. Extensive experiments on ImageNet classification, COCO object detection and ADE20K semantic segmentation show HorNet outperform Swin Transformers and ConvNeXt by a significant margin with similar overall architecture and training configurations. HorNet also shows favorable scalability to more training data and a larger model size. Apart from the effectiveness in visual encoders, we also show g nConv can be applied to task-specific decoders and consistently improve dense prediction performance with less computation. Our results demonstrate that g nConv can be a new basic module for visual modeling that effectively combines the merits of both vision Transformers and CNNs. Code is available at https://github.com/raoyongming/HorNet.
|
||||
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/24734142/188356236-b8e3db94-eaa6-48e9-b323-15e5ba7f2991.png" width="80%"/>
|
||||
</div>
|
||||
|
||||
## Results and models
|
||||
|
||||
### ImageNet-1k
|
||||
|
||||
| Model | Pretrain | resolution | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|
||||
| :-----------: | :----------: | :--------: | :-------: | :------: | :-------: | :-------: | :--------------------------------------------------------------: | :----------------------------------------------------------------: |
|
||||
| HorNet-T\* | From scratch | 224x224 | 22.41 | 3.98 | 82.84 | 96.24 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hornet/hornet-tiny_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-tiny_3rdparty_in1k_20220915-0e8eedff.pth) |
|
||||
| HorNet-T-GF\* | From scratch | 224x224 | 22.99 | 3.9 | 82.98 | 96.38 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hornet/hornet-tiny-gf_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-tiny-gf_3rdparty_in1k_20220915-4c35a66b.pth) |
|
||||
| HorNet-S\* | From scratch | 224x224 | 49.53 | 8.83 | 83.79 | 96.75 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hornet/hornet-small_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-small_3rdparty_in1k_20220915-5935f60f.pth) |
|
||||
| HorNet-S-GF\* | From scratch | 224x224 | 50.4 | 8.71 | 83.98 | 96.77 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hornet/hornet-small-gf_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-small-gf_3rdparty_in1k_20220915-649ca492.pth) |
|
||||
| HorNet-B\* | From scratch | 224x224 | 87.26 | 15.59 | 84.24 | 96.94 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hornet/hornet-base_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-base_3rdparty_in1k_20220915-a06176bb.pth) |
|
||||
| HorNet-B-GF\* | From scratch | 224x224 | 88.42 | 15.42 | 84.32 | 96.95 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hornet/hornet-base-gf_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-base-gf_3rdparty_in1k_20220915-82c06fa7.pth) |
|
||||
|
||||
\*Models with * are converted from [the official repo](https://github.com/raoyongming/HorNet). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.
|
||||
|
||||
### Pre-trained Models
|
||||
|
||||
The pre-trained models on ImageNet-21k are used to fine-tune on the downstream tasks.
|
||||
|
||||
| Model | Pretrain | resolution | Params(M) | Flops(G) | Download |
|
||||
| :--------------: | :----------: | :--------: | :-------: | :------: | :------------------------------------------------------------------------------------------------------------------------: |
|
||||
| HorNet-L\* | ImageNet-21k | 224x224 | 194.54 | 34.83 | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-large_3rdparty_in21k_20220909-9ccef421.pth) |
|
||||
| HorNet-L-GF\* | ImageNet-21k | 224x224 | 196.29 | 34.58 | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-large-gf_3rdparty_in21k_20220909-3aea3b61.pth) |
|
||||
| HorNet-L-GF384\* | ImageNet-21k | 384x384 | 201.23 | 101.63 | [model](https://download.openmmlab.com/mmclassification/v0/hornet/hornet-large-gf384_3rdparty_in21k_20220909-80894290.pth) |
|
||||
|
||||
\*Models with * are converted from [the official repo](https://github.com/raoyongming/HorNet).
|
||||
|
||||
## Citation
|
||||
|
||||
```
|
||||
@article{rao2022hornet,
|
||||
title={HorNet: Efficient High-Order Spatial Interactions with Recursive Gated Convolutions},
|
||||
author={Rao, Yongming and Zhao, Wenliang and Tang, Yansong and Zhou, Jie and Lim, Ser-Lam and Lu, Jiwen},
|
||||
journal={arXiv preprint arXiv:2207.14284},
|
||||
year={2022}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,12 @@
|
|||
_base_ = [
|
||||
'../_base_/models/hornet/hornet-base-gf.py',
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
data = dict(samples_per_gpu=64)
|
||||
|
||||
optim_wrapper = dict(optimizer=dict(lr=4e-3), clip_grad=dict(max_norm=1.0))
|
||||
|
||||
custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
|
|
@ -0,0 +1,12 @@
|
|||
_base_ = [
|
||||
'../_base_/models/hornet/hornet-base.py',
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
data = dict(samples_per_gpu=64)
|
||||
|
||||
optim_wrapper = dict(optimizer=dict(lr=4e-3), clip_grad=dict(max_norm=5.0))
|
||||
|
||||
custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
|
|
@ -0,0 +1,12 @@
|
|||
_base_ = [
|
||||
'../_base_/models/hornet/hornet-small-gf.py',
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
data = dict(samples_per_gpu=64)
|
||||
|
||||
optim_wrapper = dict(optimizer=dict(lr=4e-3), clip_grad=dict(max_norm=1.0))
|
||||
|
||||
custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
|
|
@ -0,0 +1,12 @@
|
|||
_base_ = [
|
||||
'../_base_/models/hornet/hornet-small.py',
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
data = dict(samples_per_gpu=64)
|
||||
|
||||
optim_wrapper = dict(optimizer=dict(lr=4e-3), clip_grad=dict(max_norm=5.0))
|
||||
|
||||
custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
|
|
@ -0,0 +1,12 @@
|
|||
_base_ = [
|
||||
'../_base_/models/hornet/hornet-tiny-gf.py',
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
data = dict(samples_per_gpu=128)
|
||||
|
||||
optim_wrapper = dict(optimizer=dict(lr=4e-3), clip_grad=dict(max_norm=1.0))
|
||||
|
||||
custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
|
|
@ -0,0 +1,12 @@
|
|||
_base_ = [
|
||||
'../_base_/models/hornet/hornet-tiny.py',
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
data = dict(samples_per_gpu=128)
|
||||
|
||||
optim_wrapper = dict(optimizer=dict(lr=4e-3), clip_grad=dict(max_norm=100.0))
|
||||
|
||||
custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
|
|
@ -0,0 +1,97 @@
|
|||
Collections:
|
||||
- Name: HorNet
|
||||
Metadata:
|
||||
Training Data: ImageNet-1k
|
||||
Training Techniques:
|
||||
- AdamW
|
||||
- Weight Decay
|
||||
Architecture:
|
||||
- HorNet
|
||||
- gnConv
|
||||
Paper:
|
||||
URL: https://arxiv.org/pdf/2207.14284v2.pdf
|
||||
Title: "HorNet: Efficient High-Order Spatial Interactions with Recursive Gated Convolutions"
|
||||
README: configs/hornet/README.md
|
||||
Code:
|
||||
Version: v0.24.0
|
||||
URL: https://github.com/open-mmlab/mmclassification/blob/v0.24.0/mmcls/models/backbones/hornet.py
|
||||
|
||||
Models:
|
||||
- Name: hornet-tiny_3rdparty_in1k
|
||||
Metadata:
|
||||
FLOPs: 3976156352 # 3.98G
|
||||
Parameters: 22409512 # 22.41M
|
||||
In Collection: HorNet
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 82.84
|
||||
Top 5 Accuracy: 96.24
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/hornet/hornet-tiny_3rdparty_in1k_20220915-0e8eedff.pth
|
||||
Config: configs/hornet/hornet-tiny_8xb128_in1k.py
|
||||
- Name: hornet-tiny-gf_3rdparty_in1k
|
||||
Metadata:
|
||||
FLOPs: 3896472160 # 3.9G
|
||||
Parameters: 22991848 # 22.99M
|
||||
In Collection: HorNet
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 82.98
|
||||
Top 5 Accuracy: 96.38
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/hornet/hornet-tiny-gf_3rdparty_in1k_20220915-4c35a66b.pth
|
||||
Config: configs/hornet/hornet-tiny-gf_8xb128_in1k.py
|
||||
- Name: hornet-small_3rdparty_in1k
|
||||
Metadata:
|
||||
FLOPs: 8825621280 # 8.83G
|
||||
Parameters: 49528264 # 49.53M
|
||||
In Collection: HorNet
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 83.79
|
||||
Top 5 Accuracy: 96.75
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/hornet/hornet-small_3rdparty_in1k_20220915-5935f60f.pth
|
||||
Config: configs/hornet/hornet-small_8xb64_in1k.py
|
||||
- Name: hornet-small-gf_3rdparty_in1k
|
||||
Metadata:
|
||||
FLOPs: 8706094992 # 8.71G
|
||||
Parameters: 50401768 # 50.4M
|
||||
In Collection: HorNet
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 83.98
|
||||
Top 5 Accuracy: 96.77
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/hornet/hornet-small-gf_3rdparty_in1k_20220915-649ca492.pth
|
||||
Config: configs/hornet/hornet-small-gf_8xb64_in1k.py
|
||||
- Name: hornet-base_3rdparty_in1k
|
||||
Metadata:
|
||||
FLOPs: 15582677376 # 15.59G
|
||||
Parameters: 87256680 # 87.26M
|
||||
In Collection: HorNet
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 84.24
|
||||
Top 5 Accuracy: 96.94
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/hornet/hornet-base_3rdparty_in1k_20220915-a06176bb.pth
|
||||
Config: configs/hornet/hornet-base_8xb64_in1k.py
|
||||
- Name: hornet-base-gf_3rdparty_in1k
|
||||
Metadata:
|
||||
FLOPs: 15423308992 # 15.42G
|
||||
Parameters: 88421352 # 88.42M
|
||||
In Collection: HorNet
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 84.32
|
||||
Top 5 Accuracy: 96.95
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/hornet/hornet-base-gf_3rdparty_in1k_20220915-82c06fa7.pth
|
||||
Config: configs/hornet/hornet-base-gf_8xb64_in1k.py
|
|
@ -69,6 +69,7 @@ Backbones
|
|||
EdgeNeXt
|
||||
EfficientFormer
|
||||
EfficientNet
|
||||
HorNet
|
||||
HRNet
|
||||
InceptionV3
|
||||
LeNet5
|
||||
|
|
|
@ -10,6 +10,7 @@ from .densenet import DenseNet
|
|||
from .edgenext import EdgeNeXt
|
||||
from .efficientformer import EfficientFormer
|
||||
from .efficientnet import EfficientNet
|
||||
from .hornet import HorNet
|
||||
from .hrnet import HRNet
|
||||
from .inception_v3 import InceptionV3
|
||||
from .lenet import LeNet5
|
||||
|
@ -90,5 +91,6 @@ __all__ = [
|
|||
'SwinTransformerV2',
|
||||
'MViT',
|
||||
'DeiT3',
|
||||
'HorNet',
|
||||
'MobileViT',
|
||||
]
|
||||
|
|
|
@ -0,0 +1,495 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Adapted from official impl at https://github.com/raoyongming/HorNet.
|
||||
try:
|
||||
import torch.fft
|
||||
fft = True
|
||||
except ImportError:
|
||||
fft = None
|
||||
|
||||
import copy
|
||||
from functools import partial
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from mmcv.cnn.bricks import DropPath
|
||||
|
||||
from mmcls.models.backbones.base_backbone import BaseBackbone
|
||||
from mmcls.registry import MODELS
|
||||
from ..utils import LayerScale
|
||||
|
||||
|
||||
def get_dwconv(dim, kernel_size, bias=True):
|
||||
"""build a pepth-wise convolution."""
|
||||
return nn.Conv2d(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size=kernel_size,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
bias=bias,
|
||||
groups=dim)
|
||||
|
||||
|
||||
class HorNetLayerNorm(nn.Module):
|
||||
"""An implementation of LayerNorm of HorNet.
|
||||
|
||||
The differences between HorNetLayerNorm & torch LayerNorm:
|
||||
1. Supports two data formats channels_last or channels_first.
|
||||
Args:
|
||||
normalized_shape (int or list or torch.Size): input shape from an
|
||||
expected input of size.
|
||||
eps (float): a value added to the denominator for numerical stability.
|
||||
Defaults to 1e-5.
|
||||
data_format (str): The ordering of the dimensions in the inputs.
|
||||
channels_last corresponds to inputs with shape (batch_size, height,
|
||||
width, channels) while channels_first corresponds to inputs with
|
||||
shape (batch_size, channels, height, width).
|
||||
Defaults to 'channels_last'.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
normalized_shape,
|
||||
eps=1e-6,
|
||||
data_format='channels_last'):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
||||
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
||||
self.eps = eps
|
||||
self.data_format = data_format
|
||||
if self.data_format not in ['channels_last', 'channels_first']:
|
||||
raise ValueError(
|
||||
'data_format must be channels_last or channels_first')
|
||||
self.normalized_shape = (normalized_shape, )
|
||||
|
||||
def forward(self, x):
|
||||
if self.data_format == 'channels_last':
|
||||
return F.layer_norm(x, self.normalized_shape, self.weight,
|
||||
self.bias, self.eps)
|
||||
elif self.data_format == 'channels_first':
|
||||
u = x.mean(1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.eps)
|
||||
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
||||
return x
|
||||
|
||||
|
||||
class GlobalLocalFilter(nn.Module):
|
||||
"""A GlobalLocalFilter of HorNet.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
h (int): Height of complex_weight.
|
||||
Defaults to 14.
|
||||
w (int): Width of complex_weight.
|
||||
Defaults to 8.
|
||||
"""
|
||||
|
||||
def __init__(self, dim, h=14, w=8):
|
||||
super().__init__()
|
||||
self.dw = nn.Conv2d(
|
||||
dim // 2,
|
||||
dim // 2,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=False,
|
||||
groups=dim // 2)
|
||||
self.complex_weight = nn.Parameter(
|
||||
torch.randn(dim // 2, h, w, 2, dtype=torch.float32) * 0.02)
|
||||
self.pre_norm = HorNetLayerNorm(
|
||||
dim, eps=1e-6, data_format='channels_first')
|
||||
self.post_norm = HorNetLayerNorm(
|
||||
dim, eps=1e-6, data_format='channels_first')
|
||||
|
||||
def forward(self, x):
|
||||
x = self.pre_norm(x)
|
||||
x1, x2 = torch.chunk(x, 2, dim=1)
|
||||
x1 = self.dw(x1)
|
||||
|
||||
x2 = x2.to(torch.float32)
|
||||
B, C, a, b = x2.shape
|
||||
x2 = torch.fft.rfft2(x2, dim=(2, 3), norm='ortho')
|
||||
|
||||
weight = self.complex_weight
|
||||
if not weight.shape[1:3] == x2.shape[2:4]:
|
||||
weight = F.interpolate(
|
||||
weight.permute(3, 0, 1, 2),
|
||||
size=x2.shape[2:4],
|
||||
mode='bilinear',
|
||||
align_corners=True).permute(1, 2, 3, 0)
|
||||
|
||||
weight = torch.view_as_complex(weight.contiguous())
|
||||
|
||||
x2 = x2 * weight
|
||||
x2 = torch.fft.irfft2(x2, s=(a, b), dim=(2, 3), norm='ortho')
|
||||
|
||||
x = torch.cat([x1.unsqueeze(2), x2.unsqueeze(2)],
|
||||
dim=2).reshape(B, 2 * C, a, b)
|
||||
x = self.post_norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class gnConv(nn.Module):
|
||||
"""A gnConv of HorNet.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
order (int): Order of gnConv.
|
||||
Defaults to 5.
|
||||
dw_cfg (dict): The Config for dw conv.
|
||||
Defaults to ``dict(type='DW', kernel_size=7)``.
|
||||
scale (float): Scaling parameter of gflayer outputs.
|
||||
Defaults to 1.0.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
order=5,
|
||||
dw_cfg=dict(type='DW', kernel_size=7),
|
||||
scale=1.0):
|
||||
super().__init__()
|
||||
self.order = order
|
||||
self.dims = [dim // 2**i for i in range(order)]
|
||||
self.dims.reverse()
|
||||
self.proj_in = nn.Conv2d(dim, 2 * dim, 1)
|
||||
|
||||
cfg = copy.deepcopy(dw_cfg)
|
||||
dw_type = cfg.pop('type')
|
||||
assert dw_type in ['DW', 'GF'],\
|
||||
'dw_type should be `DW` or `GF`'
|
||||
if dw_type == 'DW':
|
||||
self.dwconv = get_dwconv(sum(self.dims), **cfg)
|
||||
elif dw_type == 'GF':
|
||||
self.dwconv = GlobalLocalFilter(sum(self.dims), **cfg)
|
||||
|
||||
self.proj_out = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
self.projs = nn.ModuleList([
|
||||
nn.Conv2d(self.dims[i], self.dims[i + 1], 1)
|
||||
for i in range(order - 1)
|
||||
])
|
||||
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj_in(x)
|
||||
y, x = torch.split(x, (self.dims[0], sum(self.dims)), dim=1)
|
||||
|
||||
x = self.dwconv(x) * self.scale
|
||||
|
||||
dw_list = torch.split(x, self.dims, dim=1)
|
||||
x = y * dw_list[0]
|
||||
|
||||
for i in range(self.order - 1):
|
||||
x = self.projs[i](x) * dw_list[i + 1]
|
||||
|
||||
x = self.proj_out(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class HorNetBlock(nn.Module):
|
||||
"""A block of HorNet.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
order (int): Order of gnConv.
|
||||
Defaults to 5.
|
||||
dw_cfg (dict): The Config for dw conv.
|
||||
Defaults to ``dict(type='DW', kernel_size=7)``.
|
||||
scale (float): Scaling parameter of gflayer outputs.
|
||||
Defaults to 1.0.
|
||||
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
|
||||
use_layer_scale (bool): Whether to use use_layer_scale in HorNet
|
||||
block. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
order=5,
|
||||
dw_cfg=dict(type='DW', kernel_size=7),
|
||||
scale=1.0,
|
||||
drop_path_rate=0.,
|
||||
use_layer_scale=True):
|
||||
super().__init__()
|
||||
self.out_channels = dim
|
||||
|
||||
self.norm1 = HorNetLayerNorm(
|
||||
dim, eps=1e-6, data_format='channels_first')
|
||||
self.gnconv = gnConv(dim, order, dw_cfg, scale)
|
||||
self.norm2 = HorNetLayerNorm(dim, eps=1e-6)
|
||||
self.pwconv1 = nn.Linear(dim, 4 * dim)
|
||||
self.act = nn.GELU()
|
||||
self.pwconv2 = nn.Linear(4 * dim, dim)
|
||||
|
||||
if use_layer_scale:
|
||||
self.gamma1 = LayerScale(dim, data_format='channels_first')
|
||||
self.gamma2 = LayerScale(dim)
|
||||
else:
|
||||
self.gamma1, self.gamma2 = nn.Identity(), nn.Identity()
|
||||
|
||||
self.drop_path = DropPath(
|
||||
drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.gamma1(self.gnconv(self.norm1(x))))
|
||||
|
||||
input = x
|
||||
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
||||
x = self.norm2(x)
|
||||
x = self.pwconv1(x)
|
||||
x = self.act(x)
|
||||
x = self.pwconv2(x)
|
||||
x = self.gamma2(x)
|
||||
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
|
||||
|
||||
x = input + self.drop_path(x)
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class HorNet(BaseBackbone):
|
||||
"""HorNet
|
||||
A PyTorch impl of : `HorNet: Efficient High-Order Spatial Interactions
|
||||
with Recursive Gated Convolutions`
|
||||
Inspiration from
|
||||
https://github.com/raoyongming/HorNet
|
||||
Args:
|
||||
arch (str | dict): HorNet architecture.
|
||||
If use string, choose from 'tiny', 'small', 'base' and 'large'.
|
||||
If use dict, it should have below keys:
|
||||
- **base_dim** (int): The base dimensions of embedding.
|
||||
- **depths** (List[int]): The number of blocks in each stage.
|
||||
- **orders** (List[int]): The number of order of gnConv in each
|
||||
stage.
|
||||
- **dw_cfg** (List[dict]): The Config for dw conv.
|
||||
Defaults to 'tiny'.
|
||||
in_channels (int): Number of input image channels. Defaults to 3.
|
||||
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
|
||||
scale (float): Scaling parameter of gflayer outputs. Defaults to 1/3.
|
||||
use_layer_scale (bool): Whether to use use_layer_scale in HorNet
|
||||
block. Defaults to True.
|
||||
out_indices (Sequence[int]): Output from which stages.
|
||||
Default: ``(3, )``.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters. Defaults to -1.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Defaults to False.
|
||||
gap_before_final_norm (bool): Whether to globally average the feature
|
||||
map before the final norm layer. In the official repo, it's only
|
||||
used in classification task. Defaults to True.
|
||||
init_cfg (dict, optional): The Config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
arch_zoo = {
|
||||
**dict.fromkeys(['t', 'tiny'],
|
||||
{'base_dim': 64,
|
||||
'depths': [2, 3, 18, 2],
|
||||
'orders': [2, 3, 4, 5],
|
||||
'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}),
|
||||
**dict.fromkeys(['t-gf', 'tiny-gf'],
|
||||
{'base_dim': 64,
|
||||
'depths': [2, 3, 18, 2],
|
||||
'orders': [2, 3, 4, 5],
|
||||
'dw_cfg': [
|
||||
dict(type='DW', kernel_size=7),
|
||||
dict(type='DW', kernel_size=7),
|
||||
dict(type='GF', h=14, w=8),
|
||||
dict(type='GF', h=7, w=4)]}),
|
||||
**dict.fromkeys(['s', 'small'],
|
||||
{'base_dim': 96,
|
||||
'depths': [2, 3, 18, 2],
|
||||
'orders': [2, 3, 4, 5],
|
||||
'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}),
|
||||
**dict.fromkeys(['s-gf', 'small-gf'],
|
||||
{'base_dim': 96,
|
||||
'depths': [2, 3, 18, 2],
|
||||
'orders': [2, 3, 4, 5],
|
||||
'dw_cfg': [
|
||||
dict(type='DW', kernel_size=7),
|
||||
dict(type='DW', kernel_size=7),
|
||||
dict(type='GF', h=14, w=8),
|
||||
dict(type='GF', h=7, w=4)]}),
|
||||
**dict.fromkeys(['b', 'base'],
|
||||
{'base_dim': 128,
|
||||
'depths': [2, 3, 18, 2],
|
||||
'orders': [2, 3, 4, 5],
|
||||
'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}),
|
||||
**dict.fromkeys(['b-gf', 'base-gf'],
|
||||
{'base_dim': 128,
|
||||
'depths': [2, 3, 18, 2],
|
||||
'orders': [2, 3, 4, 5],
|
||||
'dw_cfg': [
|
||||
dict(type='DW', kernel_size=7),
|
||||
dict(type='DW', kernel_size=7),
|
||||
dict(type='GF', h=14, w=8),
|
||||
dict(type='GF', h=7, w=4)]}),
|
||||
**dict.fromkeys(['b-gf384', 'base-gf384'],
|
||||
{'base_dim': 128,
|
||||
'depths': [2, 3, 18, 2],
|
||||
'orders': [2, 3, 4, 5],
|
||||
'dw_cfg': [
|
||||
dict(type='DW', kernel_size=7),
|
||||
dict(type='DW', kernel_size=7),
|
||||
dict(type='GF', h=24, w=12),
|
||||
dict(type='GF', h=13, w=7)]}),
|
||||
**dict.fromkeys(['l', 'large'],
|
||||
{'base_dim': 192,
|
||||
'depths': [2, 3, 18, 2],
|
||||
'orders': [2, 3, 4, 5],
|
||||
'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}),
|
||||
**dict.fromkeys(['l-gf', 'large-gf'],
|
||||
{'base_dim': 192,
|
||||
'depths': [2, 3, 18, 2],
|
||||
'orders': [2, 3, 4, 5],
|
||||
'dw_cfg': [
|
||||
dict(type='DW', kernel_size=7),
|
||||
dict(type='DW', kernel_size=7),
|
||||
dict(type='GF', h=14, w=8),
|
||||
dict(type='GF', h=7, w=4)]}),
|
||||
**dict.fromkeys(['l-gf384', 'large-gf384'],
|
||||
{'base_dim': 192,
|
||||
'depths': [2, 3, 18, 2],
|
||||
'orders': [2, 3, 4, 5],
|
||||
'dw_cfg': [
|
||||
dict(type='DW', kernel_size=7),
|
||||
dict(type='DW', kernel_size=7),
|
||||
dict(type='GF', h=24, w=12),
|
||||
dict(type='GF', h=13, w=7)]}),
|
||||
} # yapf: disable
|
||||
|
||||
def __init__(self,
|
||||
arch='tiny',
|
||||
in_channels=3,
|
||||
drop_path_rate=0.,
|
||||
scale=1 / 3,
|
||||
use_layer_scale=True,
|
||||
out_indices=(3, ),
|
||||
frozen_stages=-1,
|
||||
with_cp=False,
|
||||
gap_before_final_norm=True,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
if fft is None:
|
||||
raise RuntimeError(
|
||||
'Failed to import torch.fft. Please install "torch>=1.7".')
|
||||
|
||||
if isinstance(arch, str):
|
||||
arch = arch.lower()
|
||||
assert arch in set(self.arch_zoo), \
|
||||
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
|
||||
self.arch_settings = self.arch_zoo[arch]
|
||||
else:
|
||||
essential_keys = {'base_dim', 'depths', 'orders', 'dw_cfg'}
|
||||
assert isinstance(arch, dict) and set(arch) == essential_keys, \
|
||||
f'Custom arch needs a dict with keys {essential_keys}'
|
||||
self.arch_settings = arch
|
||||
|
||||
self.scale = scale
|
||||
self.out_indices = out_indices
|
||||
self.frozen_stages = frozen_stages
|
||||
self.with_cp = with_cp
|
||||
self.gap_before_final_norm = gap_before_final_norm
|
||||
|
||||
base_dim = self.arch_settings['base_dim']
|
||||
dims = list(map(lambda x: 2**x * base_dim, range(4)))
|
||||
|
||||
self.downsample_layers = nn.ModuleList()
|
||||
stem = nn.Sequential(
|
||||
nn.Conv2d(in_channels, dims[0], kernel_size=4, stride=4),
|
||||
HorNetLayerNorm(dims[0], eps=1e-6, data_format='channels_first'))
|
||||
self.downsample_layers.append(stem)
|
||||
for i in range(3):
|
||||
downsample_layer = nn.Sequential(
|
||||
HorNetLayerNorm(
|
||||
dims[i], eps=1e-6, data_format='channels_first'),
|
||||
nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
|
||||
)
|
||||
self.downsample_layers.append(downsample_layer)
|
||||
|
||||
total_depth = sum(self.arch_settings['depths'])
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
|
||||
] # stochastic depth decay rule
|
||||
|
||||
cur_block_idx = 0
|
||||
self.stages = nn.ModuleList()
|
||||
for i in range(4):
|
||||
stage = nn.Sequential(*[
|
||||
HorNetBlock(
|
||||
dim=dims[i],
|
||||
order=self.arch_settings['orders'][i],
|
||||
dw_cfg=self.arch_settings['dw_cfg'][i],
|
||||
scale=self.scale,
|
||||
drop_path_rate=dpr[cur_block_idx + j],
|
||||
use_layer_scale=use_layer_scale)
|
||||
for j in range(self.arch_settings['depths'][i])
|
||||
])
|
||||
self.stages.append(stage)
|
||||
cur_block_idx += self.arch_settings['depths'][i]
|
||||
|
||||
if isinstance(out_indices, int):
|
||||
out_indices = [out_indices]
|
||||
assert isinstance(out_indices, Sequence), \
|
||||
f'"out_indices" must by a sequence or int, ' \
|
||||
f'get {type(out_indices)} instead.'
|
||||
out_indices = list(out_indices)
|
||||
for i, index in enumerate(out_indices):
|
||||
if index < 0:
|
||||
out_indices[i] = len(self.stages) + index
|
||||
assert 0 <= out_indices[i] <= len(self.stages), \
|
||||
f'Invalid out_indices {index}.'
|
||||
self.out_indices = out_indices
|
||||
|
||||
norm_layer = partial(
|
||||
HorNetLayerNorm, eps=1e-6, data_format='channels_first')
|
||||
for i_layer in out_indices:
|
||||
layer = norm_layer(dims[i_layer])
|
||||
layer_name = f'norm{i_layer}'
|
||||
self.add_module(layer_name, layer)
|
||||
|
||||
def train(self, mode=True):
|
||||
super(HorNet, self).train(mode)
|
||||
self._freeze_stages()
|
||||
|
||||
def _freeze_stages(self):
|
||||
for i in range(0, self.frozen_stages + 1):
|
||||
# freeze patch embed
|
||||
m = self.downsample_layers[i]
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# freeze blocks
|
||||
m = self.stages[i]
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if i in self.out_indices:
|
||||
# freeze norm
|
||||
m = getattr(self, f'norm{i + 1}')
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, x):
|
||||
outs = []
|
||||
for i in range(4):
|
||||
x = self.downsample_layers[i](x)
|
||||
if self.with_cp:
|
||||
x = checkpoint.checkpoint_sequential(self.stages[i],
|
||||
len(self.stages[i]), x)
|
||||
else:
|
||||
x = self.stages[i](x)
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
if self.gap_before_final_norm:
|
||||
gap = x.mean([-2, -1], keepdim=True)
|
||||
outs.append(norm_layer(gap).flatten(1))
|
||||
else:
|
||||
# The output of LayerNorm2d may be discontiguous, which
|
||||
# may cause some problem in the downstream tasks
|
||||
outs.append(norm_layer(x).contiguous())
|
||||
return tuple(outs)
|
|
@ -34,4 +34,5 @@ Import:
|
|||
- configs/efficientformer/metafile.yml
|
||||
- configs/swin_transformer_v2/metafile.yml
|
||||
- configs/deit3/metafile.yml
|
||||
- configs/hornet/metafile.yml
|
||||
- configs/mobilevit/metafile.yml
|
||||
|
|
|
@ -0,0 +1,174 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from itertools import chain
|
||||
from unittest import TestCase
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from mmengine.utils import digit_version
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
from torch import nn
|
||||
|
||||
from mmcls.models.backbones import HorNet
|
||||
|
||||
|
||||
def check_norm_state(modules, train_state):
|
||||
"""Check if norm layer is in correct train state."""
|
||||
for mod in modules:
|
||||
if isinstance(mod, _BatchNorm):
|
||||
if mod.training != train_state:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
digit_version(torch.__version__) < digit_version('1.7.0'),
|
||||
reason='torch.fft is not available before 1.7.0')
|
||||
class TestHorNet(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.cfg = dict(
|
||||
arch='t', drop_path_rate=0.1, gap_before_final_norm=False)
|
||||
|
||||
def test_arch(self):
|
||||
# Test invalid default arch
|
||||
with self.assertRaisesRegex(AssertionError, 'not in default archs'):
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = 'unknown'
|
||||
HorNet(**cfg)
|
||||
|
||||
# Test invalid custom arch
|
||||
with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = {
|
||||
'depths': [1, 1, 1, 1],
|
||||
'orders': [1, 1, 1, 1],
|
||||
}
|
||||
HorNet(**cfg)
|
||||
|
||||
# Test custom arch
|
||||
cfg = deepcopy(self.cfg)
|
||||
base_dim = 64
|
||||
depths = [2, 3, 18, 2]
|
||||
embed_dims = [base_dim, base_dim * 2, base_dim * 4, base_dim * 8]
|
||||
cfg['arch'] = {
|
||||
'base_dim':
|
||||
base_dim,
|
||||
'depths':
|
||||
depths,
|
||||
'orders': [2, 3, 4, 5],
|
||||
'dw_cfg': [
|
||||
dict(type='DW', kernel_size=7),
|
||||
dict(type='DW', kernel_size=7),
|
||||
dict(type='GF', h=14, w=8),
|
||||
dict(type='GF', h=7, w=4)
|
||||
],
|
||||
}
|
||||
model = HorNet(**cfg)
|
||||
|
||||
for i in range(len(depths)):
|
||||
stage = model.stages[i]
|
||||
self.assertEqual(stage[-1].out_channels, embed_dims[i])
|
||||
self.assertEqual(len(stage), depths[i])
|
||||
|
||||
def test_init_weights(self):
|
||||
# test weight init cfg
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['init_cfg'] = [
|
||||
dict(
|
||||
type='Kaiming',
|
||||
layer='Conv2d',
|
||||
mode='fan_in',
|
||||
nonlinearity='linear')
|
||||
]
|
||||
model = HorNet(**cfg)
|
||||
ori_weight = model.downsample_layers[0][0].weight.clone().detach()
|
||||
|
||||
model.init_weights()
|
||||
initialized_weight = model.downsample_layers[0][0].weight
|
||||
self.assertFalse(torch.allclose(ori_weight, initialized_weight))
|
||||
|
||||
def test_forward(self):
|
||||
imgs = torch.randn(3, 3, 224, 224)
|
||||
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = HorNet(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
feat = outs[-1]
|
||||
self.assertEqual(feat.shape, (3, 512, 7, 7))
|
||||
|
||||
# test multiple output indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['out_indices'] = (0, 1, 2, 3)
|
||||
model = HorNet(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 4)
|
||||
for emb_size, stride, out in zip([64, 128, 256, 512], [1, 2, 4, 8],
|
||||
outs):
|
||||
self.assertEqual(out.shape,
|
||||
(3, emb_size, 56 // stride, 56 // stride))
|
||||
|
||||
# test with dynamic input shape
|
||||
imgs1 = torch.randn(3, 3, 224, 224)
|
||||
imgs2 = torch.randn(3, 3, 256, 256)
|
||||
imgs3 = torch.randn(3, 3, 256, 309)
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = HorNet(**cfg)
|
||||
for imgs in [imgs1, imgs2, imgs3]:
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
feat = outs[-1]
|
||||
expect_feat_shape = (math.floor(imgs.shape[2] / 32),
|
||||
math.floor(imgs.shape[3] / 32))
|
||||
self.assertEqual(feat.shape, (3, 512, *expect_feat_shape))
|
||||
|
||||
def test_structure(self):
|
||||
# test drop_path_rate decay
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['drop_path_rate'] = 0.2
|
||||
model = HorNet(**cfg)
|
||||
depths = model.arch_settings['depths']
|
||||
stages = model.stages
|
||||
blocks = chain(*[stage for stage in stages])
|
||||
total_depth = sum(depths)
|
||||
dpr = [
|
||||
x.item()
|
||||
for x in torch.linspace(0, cfg['drop_path_rate'], total_depth)
|
||||
]
|
||||
for i, (block, expect_prob) in enumerate(zip(blocks, dpr)):
|
||||
if expect_prob == 0:
|
||||
assert isinstance(block.drop_path, nn.Identity)
|
||||
else:
|
||||
self.assertAlmostEqual(block.drop_path.drop_prob, expect_prob)
|
||||
|
||||
# test VAN with first stage frozen.
|
||||
cfg = deepcopy(self.cfg)
|
||||
frozen_stages = 0
|
||||
cfg['frozen_stages'] = frozen_stages
|
||||
cfg['out_indices'] = (0, 1, 2, 3)
|
||||
model = HorNet(**cfg)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
# the patch_embed and first stage should not require grad.
|
||||
for i in range(frozen_stages + 1):
|
||||
down = model.downsample_layers[i]
|
||||
for param in down.parameters():
|
||||
self.assertFalse(param.requires_grad)
|
||||
blocks = model.stages[i]
|
||||
for param in blocks.parameters():
|
||||
self.assertFalse(param.requires_grad)
|
||||
|
||||
# the second stage should require grad.
|
||||
for i in range(frozen_stages + 1, 4):
|
||||
down = model.downsample_layers[i]
|
||||
for param in down.parameters():
|
||||
self.assertTrue(param.requires_grad)
|
||||
blocks = model.stages[i]
|
||||
for param in blocks.parameters():
|
||||
self.assertTrue(param.requires_grad)
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
|
||||
import mmengine
|
||||
import torch
|
||||
from mmengine.runner import CheckpointLoader
|
||||
|
||||
|
||||
def convert_hornet(ckpt):
|
||||
|
||||
new_ckpt = OrderedDict()
|
||||
|
||||
for k, v in list(ckpt.items()):
|
||||
new_v = v
|
||||
if k.startswith('head'):
|
||||
new_k = k.replace('head.', 'head.fc.')
|
||||
new_ckpt[new_k] = new_v
|
||||
continue
|
||||
elif k.startswith('norm'):
|
||||
new_k = k.replace('norm.', 'norm3.')
|
||||
elif 'gnconv.pws' in k:
|
||||
new_k = k.replace('gnconv.pws', 'gnconv.projs')
|
||||
elif 'gamma1' in k:
|
||||
new_k = k.replace('gamma1', 'gamma1.weight')
|
||||
elif 'gamma2' in k:
|
||||
new_k = k.replace('gamma2', 'gamma2.weight')
|
||||
else:
|
||||
new_k = k
|
||||
|
||||
if not new_k.startswith('head'):
|
||||
new_k = 'backbone.' + new_k
|
||||
new_ckpt[new_k] = new_v
|
||||
return new_ckpt
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Convert keys in pretrained van models to mmcls style.')
|
||||
parser.add_argument('src', help='src model path or url')
|
||||
# The dst path must be a full path of the new checkpoint.
|
||||
parser.add_argument('dst', help='save path')
|
||||
args = parser.parse_args()
|
||||
|
||||
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
|
||||
|
||||
if 'model' in checkpoint:
|
||||
state_dict = checkpoint['model']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
weight = convert_hornet(state_dict)
|
||||
mmengine.mkdir_or_exist(osp.dirname(args.dst))
|
||||
torch.save(weight, args.dst)
|
||||
|
||||
print('Done!!')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue