[Feature] Support VAN. (#739)
* add van
* fix config
* add metafile
* add test
* model convert script
* fix review
* fix lint
* fix the configs and improve docs
* rm debug lines
* add VAN into api

Co-authored-by: Yu Zhaohui <1105212286@qq.com>

parent 504e71c3e0
commit df6edd7f5a
@@ -122,6 +122,7 @@ venv.bak/
*.log.json
/work_dirs
/mmcls/.mim
.DS_Store

# Pytorch
*.pth
@@ -138,6 +138,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
- [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientnet)
- [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/master/configs/convnext)
- [x] [HRNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hrnet)
- [x] [VAN](https://github.com/open-mmlab/mmclassification/tree/master/configs/van)
- [x] [ConvMixer](https://github.com/open-mmlab/mmclassification/tree/master/configs/convmixer)
- [x] [CSPNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/cspnet)
- [x] [PoolFormer](https://github.com/open-mmlab/mmclassification/tree/master/configs/poolformer)
@@ -136,6 +136,7 @@ pip3 install -e .
- [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientnet)
- [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/master/configs/convnext)
- [x] [HRNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hrnet)
- [x] [VAN](https://github.com/open-mmlab/mmclassification/tree/master/configs/van)
- [x] [ConvMixer](https://github.com/open-mmlab/mmclassification/tree/master/configs/convmixer)
- [x] [CSPNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/cspnet)
- [x] [PoolFormer](https://github.com/open-mmlab/mmclassification/tree/master/configs/poolformer)
@@ -0,0 +1,13 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='VAN', arch='base', drop_path_rate=0.1),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False))
@@ -0,0 +1,13 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='VAN', arch='large', drop_path_rate=0.2),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False))
@@ -0,0 +1,21 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='VAN', arch='small', drop_path_rate=0.1),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
    ],
    train_cfg=dict(augments=[
        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
    ]))
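# Editor's note: since the two probs above sum to 1.0, mmcls' batch augment
# wrapper applies exactly one of Mixup or CutMix to every training batch,
# each picked with probability 0.5.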
@@ -0,0 +1,21 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='VAN', arch='tiny', drop_path_rate=0.1),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=256,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
    ],
    train_cfg=dict(augments=[
        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
    ]))
@@ -0,0 +1,37 @@
# Visual Attention Network

> [Visual Attention Network](https://arxiv.org/pdf/2202.09741v2.pdf)

<!-- [ALGORITHM] -->

## Abstract

While originally designed for natural language processing (NLP) tasks, the self-attention mechanism has recently taken various computer vision areas by storm. However, the 2D nature of images brings three challenges for applying self-attention in computer vision. (1) Treating images as 1D sequences neglects their 2D structures. (2) The quadratic complexity is too expensive for high-resolution images. (3) It only captures spatial adaptability but ignores channel adaptability. In this paper, we propose a novel large kernel attention (LKA) module to enable self-adaptive and long-range correlations in self-attention while avoiding the above issues. We further introduce a novel neural network based on LKA, namely Visual Attention Network (VAN). While extremely simple and efficient, VAN outperforms the state-of-the-art vision transformers and convolutional neural networks with a large margin in extensive experiments, including image classification, object detection, semantic segmentation, instance segmentation, etc.

<div align=center>
<img src="https://user-images.githubusercontent.com/24734142/157409411-2f622ba7-553c-4702-91be-eba03f9ea04f.png" width="80%"/>
</div>

## Results and models

### ImageNet-1k

| Model | Pretrain | resolution | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------:|:------------:|:-----------:|:---------:|:---------:|:---------:|:---------:|:------:|:--------:|
| VAN-T\* | From scratch | 224x224 | 4.11 | 0.88 | 75.41 | 93.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-tiny_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-tiny_8xb128_in1k_20220427-8ac0feec.pth) |
| VAN-S\* | From scratch | 224x224 | 13.86 | 2.52 | 81.01 | 95.63 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-small_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-small_8xb128_in1k_20220427-bd6a9edd.pth) |
| VAN-B\* | From scratch | 224x224 | 26.58 | 5.03 | 82.80 | 96.21 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-base_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-base_8xb128_in1k_20220427-5275471d.pth) |
| VAN-L\* | From scratch | 224x224 | 44.77 | 8.99 | 83.86 | 96.73 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-large_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-large_8xb128_in1k_20220427-56159105.pth) |

*Models with \* are converted from [the official repo](https://github.com/Visual-Attention-Network/VAN-Classification). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
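
For a quick sanity check of a converted checkpoint, the snippet below is a minimal inference sketch using mmcls' high-level APIs (config and checkpoint URLs come from the table above; `demo/demo.JPEG` is the sample image shipped with the repository):

```python
from mmcls.apis import inference_model, init_model

config = 'configs/van/van-tiny_8xb128_in1k.py'
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/van/van-tiny_8xb128_in1k_20220427-8ac0feec.pth'

# Build the classifier from the config and load the converted weights.
model = init_model(config, checkpoint, device='cpu')
# Single-image inference; the result dict holds the class name and score.
result = inference_model(model, 'demo/demo.JPEG')
print(result['pred_class'], result['pred_score'])
```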

## Citation

```
@article{guo2022visual,
  title={Visual Attention Network},
  author={Guo, Meng-Hao and Lu, Cheng-Ze and Liu, Zheng-Ning and Cheng, Ming-Ming and Hu, Shi-Min},
  journal={arXiv preprint arXiv:2202.09741},
  year={2022}
}
```
@@ -0,0 +1,70 @@
Collections:
  - Name: Visual-Attention-Network
    Metadata:
      Training Data: ImageNet-1k
      Training Techniques:
        - AdamW
        - Weight Decay
      Architecture:
        - Visual Attention Network
    Paper:
      URL: https://arxiv.org/pdf/2202.09741v2.pdf
      Title: "Visual Attention Network"
    README: configs/van/README.md
    Code:
      URL: https://github.com/open-mmlab/mmclassification/blob/v0.23.0/mmcls/models/backbones/van.py
      Version: v0.23.0

Models:
  - Name: van-tiny_8xb128_in1k
    Metadata:
      FLOPs: 880000000  # 0.88G
      Parameters: 4110000  # 4.11M
    In Collection: Visual-Attention-Network
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 75.41
          Top 5 Accuracy: 93.02
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/van/van-tiny_8xb128_in1k_20220427-8ac0feec.pth
    Config: configs/van/van-tiny_8xb128_in1k.py
  - Name: van-small_8xb128_in1k
    Metadata:
      FLOPs: 2520000000  # 2.52G
      Parameters: 13860000  # 13.86M
    In Collection: Visual-Attention-Network
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 81.01
          Top 5 Accuracy: 95.63
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/van/van-small_8xb128_in1k_20220427-bd6a9edd.pth
    Config: configs/van/van-small_8xb128_in1k.py
  - Name: van-base_8xb128_in1k
    Metadata:
      FLOPs: 5030000000  # 5.03G
      Parameters: 26580000  # 26.58M
    In Collection: Visual-Attention-Network
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 82.80
          Top 5 Accuracy: 96.21
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/van/van-base_8xb128_in1k_20220427-5275471d.pth
    Config: configs/van/van-base_8xb128_in1k.py
  - Name: van-large_8xb128_in1k
    Metadata:
      FLOPs: 8990000000  # 8.99G
      Parameters: 44770000  # 44.77M
    In Collection: Visual-Attention-Network
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 83.86
          Top 5 Accuracy: 96.73
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/van/van-large_8xb128_in1k_20220427-56159105.pth
    Config: configs/van/van-large_8xb128_in1k.py
@@ -0,0 +1,61 @@
_base_ = [
    '../_base_/models/van/van_base.py',
    '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py'
]

# Note that the mean and std used here differ from other configs:
# they normalize pixel values to the range [-1, 1].
img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=224,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(
        type='RandAugment',
        policies={{_base_.rand_increasing_policies}},
        num_policies=2,
        total_level=10,
        magnitude_level=9,
        magnitude_std=0.5,
        hparams=dict(
            pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
            interpolation='bicubic')),
    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
    dict(
        type='RandomErasing',
        erase_prob=0.25,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=1 / 3,
        fill_color=img_norm_cfg['mean'][::-1],
        fill_std=img_norm_cfg['std'][::-1]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        size=(248, -1),
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
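
# Editor's note: in the test pipeline above, size=(248, -1) resizes the
# shorter image edge to 248 pixels while keeping the aspect ratio, before
# the 224x224 center crop.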

data = dict(
    samples_per_gpu=128,
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))
@@ -0,0 +1,61 @@
_base_ = [
    '../_base_/models/van/van_large.py',
    '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py'
]

# Note that the mean and std used here differ from other configs:
# they normalize pixel values to the range [-1, 1].
img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=224,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(
        type='RandAugment',
        policies={{_base_.rand_increasing_policies}},
        num_policies=2,
        total_level=10,
        magnitude_level=9,
        magnitude_std=0.5,
        hparams=dict(
            pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
            interpolation='bicubic')),
    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
    dict(
        type='RandomErasing',
        erase_prob=0.25,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=1 / 3,
        fill_color=img_norm_cfg['mean'][::-1],
        fill_std=img_norm_cfg['std'][::-1]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        size=(248, -1),
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]

data = dict(
    samples_per_gpu=128,
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))
@@ -0,0 +1,61 @@
_base_ = [
    '../_base_/models/van/van_small.py',
    '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py'
]

# Note that the mean and std used here differ from other configs:
# they normalize pixel values to the range [-1, 1].
img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=224,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(
        type='RandAugment',
        policies={{_base_.rand_increasing_policies}},
        num_policies=2,
        total_level=10,
        magnitude_level=9,
        magnitude_std=0.5,
        hparams=dict(
            pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
            interpolation='bicubic')),
    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
    dict(
        type='RandomErasing',
        erase_prob=0.25,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=1 / 3,
        fill_color=img_norm_cfg['mean'][::-1],
        fill_std=img_norm_cfg['std'][::-1]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        size=(248, -1),
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]

data = dict(
    samples_per_gpu=128,
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))
@@ -0,0 +1,61 @@
_base_ = [
    '../_base_/models/van/van_tiny.py',
    '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py'
]

# Note that the mean and std used here differ from other configs:
# they normalize pixel values to the range [-1, 1].
img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=224,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(
        type='RandAugment',
        policies={{_base_.rand_increasing_policies}},
        num_policies=2,
        total_level=10,
        magnitude_level=9,
        magnitude_std=0.5,
        hparams=dict(
            pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
            interpolation='bicubic')),
    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
    dict(
        type='RandomErasing',
        erase_prob=0.25,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=1 / 3,
        fill_color=img_norm_cfg['mean'][::-1],
        fill_std=img_norm_cfg['std'][::-1]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        size=(248, -1),
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]

data = dict(
    samples_per_gpu=128,
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))
@@ -83,6 +83,7 @@ Backbones
   T2T_ViT
   TIMMBackbone
   TNT
   VAN
   VGG
   VisionTransformer
@@ -133,6 +133,10 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
| CSPDarkNet50\* | 27.64 | 5.04 | 80.05 | 95.07 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/cspnet/cspdarknet50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/cspnet/cspdarknet50_3rdparty_8xb32_in1k_20220329-bd275287.pth) |
| CSPResNet50\* | 21.62 | 3.48 | 79.55 | 94.68 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/cspnet/cspresnet50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/cspnet/cspresnet50_3rdparty_8xb32_in1k_20220329-dd6dddfb.pth) |
| CSPResNeXt50\* | 20.57 | 3.11 | 79.96 | 94.96 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/cspnet/cspresnext50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/cspnet/cspresnext50_3rdparty_8xb32_in1k_20220329-2cc84d21.pth) |
| VAN-T\* | 4.11 | 0.88 | 75.41 | 93.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-tiny_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-tiny_8xb128_in1k_20220427-8ac0feec.pth) |
| VAN-S\* | 13.86 | 2.52 | 81.01 | 95.63 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-small_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-small_8xb128_in1k_20220427-bd6a9edd.pth) |
| VAN-B\* | 26.58 | 5.03 | 82.80 | 96.21 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-base_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-base_8xb128_in1k_20220427-5275471d.pth) |
| VAN-L\* | 44.77 | 8.99 | 83.86 | 96.73 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-large_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-large_8xb128_in1k_20220427-56159105.pth) |

*Models with \* are converted from other repos, others are trained by ourselves.*
@@ -1 +1 @@
../en/changelog.md
../en/changelog.md

@@ -1 +1 @@
../en/model_zoo.md
../en/model_zoo.md
@@ -29,46 +29,17 @@ from .t2t_vit import T2T_ViT
from .timm_backbone import TIMMBackbone
from .tnt import TNT
from .twins import PCPVT, SVT
from .van import VAN
from .vgg import VGG
from .vision_transformer import VisionTransformer

__all__ = [
    'LeNet5',
    'AlexNet',
    'VGG',
    'RegNet',
    'ResNet',
    'ResNeXt',
    'ResNetV1d',
    'ResNeSt',
    'ResNet_CIFAR',
    'SEResNet',
    'SEResNeXt',
    'ShuffleNetV1',
    'ShuffleNetV2',
    'MobileNetV2',
    'MobileNetV3',
    'VisionTransformer',
    'SwinTransformer',
    'TNT',
    'TIMMBackbone',
    'T2T_ViT',
    'Res2Net',
    'RepVGG',
    'Conformer',
    'MlpMixer',
    'DistilledVisionTransformer',
    'PCPVT',
    'SVT',
    'EfficientNet',
    'ConvNeXt',
    'HRNet',
    'ResNetV1c',
    'ConvMixer',
    'CSPDarkNet',
    'CSPResNet',
    'CSPResNeXt',
    'CSPNet',
    'RepMLPNet',
    'PoolFormer',
    'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
    'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
    'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
    'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
    'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT',
    'EfficientNet', 'ConvNeXt', 'HRNet', 'ResNetV1c', 'ConvMixer',
    'CSPDarkNet', 'CSPResNet', 'CSPResNeXt', 'CSPNet', 'RepMLPNet',
    'PoolFormer', 'VAN'
]
@@ -0,0 +1,434 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmcv.cnn.bricks.transformer import PatchEmbed
from mmcv.runner import BaseModule, ModuleList
from mmcv.utils.parrots_wrapper import _BatchNorm

from ..builder import BACKBONES
from .base_backbone import BaseBackbone


class MixFFN(BaseModule):
    """An implementation of MixFFN of VAN. Refer to
    mmdetection/mmdet/models/backbones/pvt.py.

    The differences between MixFFN & FFN:
    1. Use 1x1 Conv to replace Linear layer.
    2. Introduce 3x3 depth-wise Conv to encode positional information.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`.
        feedforward_channels (int): The hidden dimension of FFNs.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='GELU').
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='GELU'),
                 ffn_drop=0.,
                 init_cfg=None):
        super(MixFFN, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg

        self.fc1 = Conv2d(
            in_channels=embed_dims,
            out_channels=feedforward_channels,
            kernel_size=1)
        self.dwconv = Conv2d(
            in_channels=feedforward_channels,
            out_channels=feedforward_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            groups=feedforward_channels)
        self.act = build_activation_layer(act_cfg)
        self.fc2 = Conv2d(
            in_channels=feedforward_channels,
            out_channels=embed_dims,
            kernel_size=1)
        self.drop = nn.Dropout(ffn_drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class LKA(BaseModule):
    """Large Kernel Attention (LKA) of VAN.

    .. code:: text

            DW_conv (depth-wise convolution)
                           |
                           |
        DW_D_conv (depth-wise dilation convolution)
                           |
                           |
        Transition Convolution (1x1 convolution)

    Args:
        embed_dims (int): Number of input channels.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self, embed_dims, init_cfg=None):
        super(LKA, self).__init__(init_cfg=init_cfg)

        # a spatial local convolution (depth-wise convolution)
        self.DW_conv = Conv2d(
            in_channels=embed_dims,
            out_channels=embed_dims,
            kernel_size=5,
            padding=2,
            groups=embed_dims)

        # a spatial long-range convolution (depth-wise dilation convolution)
        self.DW_D_conv = Conv2d(
            in_channels=embed_dims,
            out_channels=embed_dims,
            kernel_size=7,
            stride=1,
            padding=9,
            groups=embed_dims,
            dilation=3)

        self.conv1 = Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
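
        # Editor's note: stacking the 5x5 depth-wise conv, the 7x7 depth-wise
        # conv with dilation 3 and the 1x1 conv above approximates a single
        # large-kernel (21x21 in the paper) convolution at a fraction of the
        # parameters and FLOPs.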

    def forward(self, x):
        u = x.clone()
        attn = self.DW_conv(x)
        attn = self.DW_D_conv(attn)
        attn = self.conv1(attn)

        return u * attn


class SpatialAttention(BaseModule):
    """Basic attention module in VANBlock.

    Args:
        embed_dims (int): Number of input channels.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='GELU').
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self, embed_dims, act_cfg=dict(type='GELU'), init_cfg=None):
        super(SpatialAttention, self).__init__(init_cfg=init_cfg)

        self.proj_1 = Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        self.activation = build_activation_layer(act_cfg)
        self.spatial_gating_unit = LKA(embed_dims)
        self.proj_2 = Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)

    def forward(self, x):
        shortcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        x = x + shortcut
        return x


class VANBlock(BaseModule):
    """A block of VAN.

    Args:
        embed_dims (int): Number of input channels.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
        drop_rate (float): Dropout rate after embedding. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='GELU').
        layer_scale_init_value (float): Init value for Layer Scale.
            Defaults to 1e-2.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 ffn_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='BN', eps=1e-5),
                 layer_scale_init_value=1e-2,
                 init_cfg=None):
        super(VANBlock, self).__init__(init_cfg=init_cfg)
        self.out_channels = embed_dims

        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = SpatialAttention(embed_dims, act_cfg=act_cfg)
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        mlp_hidden_dim = int(embed_dims * ffn_ratio)
        self.mlp = MixFFN(
            embed_dims=embed_dims,
            feedforward_channels=mlp_hidden_dim,
            act_cfg=act_cfg,
            ffn_drop=drop_rate)
        self.layer_scale_1 = nn.Parameter(
            layer_scale_init_value * torch.ones((embed_dims)),
            requires_grad=True) if layer_scale_init_value > 0 else None
        self.layer_scale_2 = nn.Parameter(
            layer_scale_init_value * torch.ones((embed_dims)),
            requires_grad=True) if layer_scale_init_value > 0 else None

    def forward(self, x):
        identity = x
        x = self.norm1(x)
        x = self.attn(x)
        if self.layer_scale_1 is not None:
            x = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * x
        x = identity + self.drop_path(x)

        identity = x
        x = self.norm2(x)
        x = self.mlp(x)
        if self.layer_scale_2 is not None:
            x = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * x
        x = identity + self.drop_path(x)

        return x


class VANPatchEmbed(PatchEmbed):
    """Image to Patch Embedding of VAN.

    The differences between VANPatchEmbed & PatchEmbed:
    1. Use BN.
    2. Do not use 'flatten' and 'transpose'.
    """

    def __init__(self, *args, norm_cfg=dict(type='BN'), **kwargs):
        super(VANPatchEmbed, self).__init__(*args, norm_cfg=norm_cfg, **kwargs)

    def forward(self, x):
        """
        Args:
            x (Tensor): Has shape (B, C, H, W). In most cases, C is 3.
        Returns:
            tuple: Contains merged results and its spatial shape.

            - x (Tensor): Has shape (B, embed_dims, out_h, out_w), since
              VANPatchEmbed does not flatten the spatial dimensions.
            - out_size (tuple[int]): Spatial shape of x, arranged as
              (out_h, out_w).
        """

        if self.adaptive_padding:
            x = self.adaptive_padding(x)

        x = self.projection(x)
        out_size = (x.shape[2], x.shape[3])
        if self.norm is not None:
            x = self.norm(x)
        return x, out_size


@BACKBONES.register_module()
class VAN(BaseBackbone):
    """Visual Attention Network.

    A PyTorch implementation of `Visual Attention Network
    <https://arxiv.org/pdf/2202.09741v2.pdf>`_

    Inspiration from
    https://github.com/Visual-Attention-Network/VAN-Classification

    Args:
        arch (str | dict): Visual Attention Network architecture.
            If use string, choose from 'tiny', 'small', 'base' and 'large'.
            If use dict, it should have below keys:

            - **embed_dims** (List[int]): The dimensions of embedding.
            - **depths** (List[int]): The number of blocks in each stage.
            - **ffn_ratios** (List[int]): The expansion ratio of feedforward
              network hidden layer channels in each stage.

            Defaults to 'tiny'.
        patch_sizes (List[int | tuple]): The patch size in patch embeddings.
            Defaults to [7, 3, 3, 3].
        in_channels (int): The number of input channels. Defaults to 3.
        drop_rate (float): Dropout rate after embedding. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
        out_indices (Sequence[int]): Output from which stages.
            Default: ``(3, )``.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        norm_cfg (dict): Config dict for normalization layer for all output
            features. Defaults to ``dict(type='LN')``.
        block_cfgs (Sequence[dict] | dict): The extra config of each block.
            Defaults to empty dicts.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Examples:
        >>> from mmcls.models import VAN
        >>> import torch
        >>> cfg = dict(arch='tiny')
        >>> model = VAN(**cfg)
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> outputs = model(inputs)
        >>> for out in outputs:
        ...     print(out.size())
        torch.Size([1, 256, 7, 7])
    """
    arch_zoo = {
        **dict.fromkeys(['t', 'tiny'],
                        {'embed_dims': [32, 64, 160, 256],
                         'depths': [3, 3, 5, 2],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [2, 2, 4, 2],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [3, 3, 12, 3],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['l', 'large'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [3, 5, 27, 3],
                         'ffn_ratios': [8, 8, 4, 4]}),
    }  # yapf: disable

    def __init__(self,
                 arch='tiny',
                 patch_sizes=[7, 3, 3, 3],
                 in_channels=3,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 out_indices=(3, ),
                 frozen_stages=-1,
                 norm_eval=False,
                 norm_cfg=dict(type='LN'),
                 block_cfgs=dict(),
                 init_cfg=None):
        super(VAN, self).__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {'embed_dims', 'depths', 'ffn_ratios'}
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.ffn_ratios = self.arch_settings['ffn_ratios']
        self.num_stages = len(self.depths)
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval

        total_depth = sum(self.depths)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]  # stochastic depth decay rule

        cur_block_idx = 0
        for i, depth in enumerate(self.depths):
            patch_embed = VANPatchEmbed(
                in_channels=in_channels if i == 0 else self.embed_dims[i - 1],
                input_size=None,
                embed_dims=self.embed_dims[i],
                kernel_size=patch_sizes[i],
                stride=patch_sizes[i] // 2 + 1,
                padding=(patch_sizes[i] // 2, patch_sizes[i] // 2),
                norm_cfg=dict(type='BN'))

            blocks = ModuleList([
                VANBlock(
                    embed_dims=self.embed_dims[i],
                    ffn_ratio=self.ffn_ratios[i],
                    drop_rate=drop_rate,
                    drop_path_rate=dpr[cur_block_idx + j],
                    **block_cfgs) for j in range(depth)
            ])
            cur_block_idx += depth
            norm = build_norm_layer(norm_cfg, self.embed_dims[i])[1]

            self.add_module(f'patch_embed{i + 1}', patch_embed)
            self.add_module(f'blocks{i + 1}', blocks)
            self.add_module(f'norm{i + 1}', norm)

    def train(self, mode=True):
        super(VAN, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval has effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()

    def _freeze_stages(self):
        for i in range(0, self.frozen_stages + 1):
            # freeze patch embed
            m = getattr(self, f'patch_embed{i + 1}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

            # freeze blocks
            m = getattr(self, f'blocks{i + 1}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

            # freeze norm
            m = getattr(self, f'norm{i + 1}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def forward(self, x):
        outs = []
        for i in range(self.num_stages):
            patch_embed = getattr(self, f'patch_embed{i + 1}')
            blocks = getattr(self, f'blocks{i + 1}')
            norm = getattr(self, f'norm{i + 1}')
            x, hw_shape = patch_embed(x)
            for block in blocks:
                x = block(x)
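            # Editor's note: LayerNorm expects (B, L, C), so flatten the
            # spatial dims before normalizing, then restore the (B, C, H, W)
            # layout for the next stage / output.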
            x = x.flatten(2).transpose(1, 2)
            x = norm(x)
            x = x.reshape(-1, *hw_shape,
                          block.out_channels).permute(0, 3, 1, 2).contiguous()
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)
@@ -22,6 +22,7 @@ Import:
  - configs/hrnet/metafile.yml
  - configs/repmlp/metafile.yml
  - configs/wrn/metafile.yml
  - configs/van/metafile.yml
  - configs/cspnet/metafile.yml
  - configs/convmixer/metafile.yml
  - configs/poolformer/metafile.yml
@@ -0,0 +1,188 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from copy import deepcopy
from itertools import chain
from unittest import TestCase

import torch
from mmcv.utils.parrots_wrapper import _BatchNorm
from torch import nn

from mmcls.models.backbones import VAN


def check_norm_state(modules, train_state):
    """Check if norm layer is in correct train state."""
    for mod in modules:
        if isinstance(mod, _BatchNorm):
            if mod.training != train_state:
                return False
    return True


class TestVAN(TestCase):

    def setUp(self):
        self.cfg = dict(arch='t', drop_path_rate=0.1)

    def test_arch(self):
        # Test invalid default arch
        with self.assertRaisesRegex(AssertionError, 'not in default archs'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = 'unknown'
            VAN(**cfg)

        # Test invalid custom arch
        with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = {
                'embed_dims': [32, 64, 160, 256],
                'ffn_ratios': [8, 8, 4, 4],
            }
            VAN(**cfg)

        # Test custom arch
        cfg = deepcopy(self.cfg)
        embed_dims = [32, 64, 160, 256]
        depths = [3, 3, 5, 2]
        ffn_ratios = [8, 8, 4, 4]
        cfg['arch'] = {
            'embed_dims': embed_dims,
            'depths': depths,
            'ffn_ratios': ffn_ratios
        }
        model = VAN(**cfg)

        for i in range(len(depths)):
            stage = getattr(model, f'blocks{i + 1}')
            self.assertEqual(stage[-1].out_channels, embed_dims[i])
            self.assertEqual(len(stage), depths[i])

    def test_init_weights(self):
        # test weight init cfg
        cfg = deepcopy(self.cfg)
        cfg['init_cfg'] = [
            dict(
                type='Kaiming',
                layer='Conv2d',
                mode='fan_in',
                nonlinearity='linear')
        ]
        model = VAN(**cfg)
        ori_weight = model.patch_embed1.projection.weight.clone().detach()

        model.init_weights()
        initialized_weight = model.patch_embed1.projection.weight
        self.assertFalse(torch.allclose(ori_weight, initialized_weight))

    def test_forward(self):
        imgs = torch.randn(3, 3, 224, 224)

        cfg = deepcopy(self.cfg)
        model = VAN(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        feat = outs[-1]
        self.assertEqual(feat.shape, (3, 256, 7, 7))

        # test with patch_sizes
        cfg = deepcopy(self.cfg)
        cfg['patch_sizes'] = [7, 5, 5, 5]
        model = VAN(**cfg)
        outs = model(torch.randn(3, 3, 224, 224))
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        feat = outs[-1]
        self.assertEqual(feat.shape, (3, 256, 3, 3))

        # test multiple output indices
        cfg = deepcopy(self.cfg)
        cfg['out_indices'] = (0, 1, 2, 3)
        model = VAN(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 4)
        for emb_size, stride, out in zip([32, 64, 160, 256], [1, 2, 4, 8],
                                         outs):
            self.assertEqual(out.shape,
                             (3, emb_size, 56 // stride, 56 // stride))

        # test with dynamic input shape
        imgs1 = torch.randn(3, 3, 224, 224)
        imgs2 = torch.randn(3, 3, 256, 256)
        imgs3 = torch.randn(3, 3, 256, 309)
        cfg = deepcopy(self.cfg)
        model = VAN(**cfg)
        for imgs in [imgs1, imgs2, imgs3]:
            outs = model(imgs)
            self.assertIsInstance(outs, tuple)
            self.assertEqual(len(outs), 1)
            feat = outs[-1]
            expect_feat_shape = (math.ceil(imgs.shape[2] / 32),
                                 math.ceil(imgs.shape[3] / 32))
            self.assertEqual(feat.shape, (3, 256, *expect_feat_shape))

    def test_structure(self):
        # test drop_path_rate decay
        cfg = deepcopy(self.cfg)
        cfg['drop_path_rate'] = 0.2
        model = VAN(**cfg)
        depths = model.arch_settings['depths']
        stages = [model.blocks1, model.blocks2, model.blocks3, model.blocks4]
        blocks = chain(*stages)
        total_depth = sum(depths)
        dpr = [
            x.item()
            for x in torch.linspace(0, cfg['drop_path_rate'], total_depth)
        ]
        for block, expect_prob in zip(blocks, dpr):
            if expect_prob == 0:
                self.assertIsInstance(block.drop_path, nn.Identity)
            else:
                self.assertAlmostEqual(block.drop_path.drop_prob, expect_prob)

        # test VAN with norm_eval=True
        cfg = deepcopy(self.cfg)
        cfg['norm_eval'] = True
        cfg['norm_cfg'] = dict(type='BN')
        model = VAN(**cfg)
        model.init_weights()
        model.train()
        self.assertTrue(check_norm_state(model.modules(), False))

        # test VAN with first stage frozen.
        cfg = deepcopy(self.cfg)
        frozen_stages = 0
        cfg['frozen_stages'] = frozen_stages
        cfg['out_indices'] = (0, 1, 2, 3)
        model = VAN(**cfg)
        model.init_weights()
        model.train()

        # the patch_embed and first stage should not require grad.
        self.assertFalse(model.patch_embed1.training)
        for param in model.patch_embed1.parameters():
            self.assertFalse(param.requires_grad)
        for i in range(frozen_stages + 1):
            patch = getattr(model, f'patch_embed{i + 1}')
            for param in patch.parameters():
                self.assertFalse(param.requires_grad)
            blocks = getattr(model, f'blocks{i + 1}')
            for param in blocks.parameters():
                self.assertFalse(param.requires_grad)
            norm = getattr(model, f'norm{i + 1}')
            for param in norm.parameters():
                self.assertFalse(param.requires_grad)

        # the remaining stages should still require grad.
        for i in range(frozen_stages + 1, 4):
            patch = getattr(model, f'patch_embed{i + 1}')
            for param in patch.parameters():
                self.assertTrue(param.requires_grad)
            blocks = getattr(model, f'blocks{i + 1}')
            for param in blocks.parameters():
                self.assertTrue(param.requires_grad)
            norm = getattr(model, f'norm{i + 1}')
            for param in norm.parameters():
                self.assertTrue(param.requires_grad)
@@ -0,0 +1,65 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmcv
import torch
from mmcv.runner import CheckpointLoader


def convert_van(ckpt):

    new_ckpt = OrderedDict()

    for k, v in list(ckpt.items()):
        new_v = v
        if k.startswith('head'):
            new_k = k.replace('head.', 'head.fc.')
            new_ckpt[new_k] = new_v
            continue
        elif k.startswith('patch_embed'):
            if 'proj.' in k:
                new_k = k.replace('proj.', 'projection.')
            else:
                new_k = k
        elif k.startswith('block'):
            new_k = k.replace('block', 'blocks')
            if 'attn.spatial_gating_unit' in new_k:
                new_k = new_k.replace('conv0', 'DW_conv')
                new_k = new_k.replace('conv_spatial', 'DW_D_conv')
            if 'dwconv.dwconv' in new_k:
                new_k = new_k.replace('dwconv.dwconv', 'dwconv')
        else:
            new_k = k

        if not new_k.startswith('head'):
            new_k = 'backbone.' + new_k
        new_ckpt[new_k] = new_v
    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained van models to mmcls style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')

    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    weight = convert_van(state_dict)
    mmcv.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)

    print('Done!!')


if __name__ == '__main__':
    main()
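
# Example invocation (editor's sketch; the file names below are hypothetical,
# official weights come from
# https://github.com/Visual-Attention-Network/VAN-Classification):
#   python this_script.py /path/to/official_van_tiny.pth van-tiny_converted.pth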