[Feature] Implement the conformer backbone. (#494)
* implement the conformer
* format code style
* format code style
* reuse the TransformerEncoderLayer in the vision_transformer.py
* Modify variable name
* delete unused params
* Remove warning info in Conformer head since it already exists in Conformer.
* Rename some variables
* Add unit tests
* Use `getattr` instead of `get_submodule`.
* Remove some useless layers
* Refactor conformer and add configs
* Update configs and add metafile.
* Fix unit tests
* Update README

Co-authored-by: mzr1996 <mzr1996@163.com>

parent 0aa789f3c3
commit 18f6bb0b10
configs/_base_/models/conformer/base-p16.py
@ -0,0 +1,22 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='Conformer', arch='base', drop_path_rate=0.1, init_cfg=None),
    neck=None,
    head=dict(
        type='ConformerHead',
        num_classes=1000,
        in_channels=[1536, 576],
        init_cfg=None,
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
    ],
    train_cfg=dict(augments=[
        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
    ]))
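The head's `in_channels` pair corresponds to the two Conformer branches: the CNN branch ends at `base_channels * channel_ratio * 4` channels (the `ConvBlock` expansion factor is 4) and the transformer branch at `embed_dims`. A quick sanity check for the 'base' settings, using the `arch_zoo` values from the backbone implementation further down:

```python
# Sanity check of the head's in_channels pair for arch='base'
# (base_channels=64, channel_ratio=6, embed_dims=576 per arch_zoo).
base_channels, channel_ratio, embed_dims = 64, 6, 576
conv_dim = base_channels * channel_ratio * 4  # ConvBlock expansion is 4
assert [conv_dim, embed_dims] == [1536, 576]
```

The same rule gives `[1024, 384]` for 'small' (channel_ratio 4) and `[256, 384]` for 'tiny' (channel_ratio 1), matching the configs below.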
configs/_base_/models/conformer/small-p16.py
@ -0,0 +1,22 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='Conformer', arch='small', drop_path_rate=0.1, init_cfg=None),
    neck=None,
    head=dict(
        type='ConformerHead',
        num_classes=1000,
        in_channels=[1024, 384],
        init_cfg=None,
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
    ],
    train_cfg=dict(augments=[
        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
    ]))
configs/_base_/models/conformer/small-p32.py
@ -0,0 +1,26 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='Conformer',
        arch='small',
        patch_size=32,
        drop_path_rate=0.1,
        init_cfg=None),
    neck=None,
    head=dict(
        type='ConformerHead',
        num_classes=1000,
        in_channels=[1024, 384],
        init_cfg=None,
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
    ],
    train_cfg=dict(augments=[
        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
    ]))
configs/_base_/models/conformer/tiny-p16.py
@ -0,0 +1,22 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='Conformer', arch='tiny', drop_path_rate=0.1, init_cfg=None),
    neck=None,
    head=dict(
        type='ConformerHead',
        num_classes=1000,
        in_channels=[256, 384],
        init_cfg=None,
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
    ],
    train_cfg=dict(augments=[
        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
    ]))
configs/_base_/schedules/imagenet_bs1024_adamw_conformer.py
@ -0,0 +1,29 @@
paramwise_cfg = dict(
    norm_decay_mult=0.0,
    bias_decay_mult=0.0,
    custom_keys={
        '.cls_token': dict(decay_mult=0.0),
    })

# With a batch size of 128 per GPU on 8 GPUs:
# lr = 5e-4 * 128 * 8 / 512 = 0.001
optimizer = dict(
    type='AdamW',
    lr=5e-4 * 128 * 8 / 512,
    weight_decay=0.05,
    eps=1e-8,
    betas=(0.9, 0.999),
    paramwise_cfg=paramwise_cfg)
optimizer_config = dict(grad_clip=None)

# learning policy
lr_config = dict(
    policy='CosineAnnealing',
    by_epoch=False,
    min_lr_ratio=1e-2,
    warmup='linear',
    warmup_ratio=1e-3,
    warmup_iters=5 * 1252,
    warmup_by_epoch=False)

runner = dict(type='EpochBasedRunner', max_epochs=300)
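The warmup length of `5 * 1252` iterations corresponds to 5 warmup epochs: with the total batch size of 1024 (128 per GPU × 8 GPUs), ImageNet-1k's 1,281,167 training images give roughly 1252 iterations per epoch. A small sketch of the arithmetic (the image count is the standard ImageNet-1k training-set size, not stated in this diff):

```python
# Why warmup_iters = 5 * 1252: iterations per epoch at batch size 1024.
import math

num_train_images = 1281167      # standard ImageNet-1k train split
batch_size = 128 * 8            # samples_per_gpu * number of GPUs
iters_per_epoch = math.ceil(num_train_images / batch_size)
assert iters_per_epoch == 1252  # hence warmup_iters = 5 * 1252
```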
configs/conformer/README.md
@ -0,0 +1,39 @@
# Conformer: Local Features Coupling Global Representations for Visual Recognition
<!-- {Conformer} -->
<!-- [ALGORITHM] -->

## Abstract

<!-- [ABSTRACT] -->
Within Convolutional Neural Networks (CNNs), convolution operations are good at extracting local features but have difficulty capturing global representations. Within visual transformers, the cascaded self-attention modules can capture long-distance feature dependencies but unfortunately deteriorate local feature details. In this paper, we propose a hybrid network structure, termed Conformer, to take advantage of convolutional operations and self-attention mechanisms for enhanced representation learning. Conformer is rooted in the Feature Coupling Unit (FCU), which fuses local features and global representations under different resolutions in an interactive fashion. Conformer adopts a concurrent structure so that local features and global representations are retained to the maximum extent. Experiments show that Conformer, under comparable parameter complexity, outperforms the visual transformer (DeiT-B) by 2.3% on ImageNet. On MSCOCO, it outperforms ResNet-101 by 3.7% and 3.6% mAP for object detection and instance segmentation, respectively, demonstrating its great potential to be a general backbone network.

<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/26739999/144957687-926390ed-6119-4e4c-beaa-9bc0017fe953.png" width="90%"/>
</div>

## Citation

```latex
@article{peng2021conformer,
  title={Conformer: Local Features Coupling Global Representations for Visual Recognition},
  author={Zhiliang Peng and Wei Huang and Shanzhi Gu and Lingxi Xie and Yaowei Wang and Jianbin Jiao and Qixiang Ye},
  journal={arXiv preprint arXiv:2105.03889},
  year={2021},
}
```

## Results and models

Some pre-trained models are converted from the [official repo](https://github.com/pengzhiliang/Conformer).

### ImageNet-1k

| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
| Conformer-tiny-p16\*  | 23.52 | 4.90  | 81.31 | 95.60 | [config](configs/conformer/conformer-tiny-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth) |
| Conformer-small-p32   | 38.85 | 7.09  | 81.96 | 96.02 | [config](configs/conformer/conformer-small-p32_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth) |
| Conformer-small-p16\* | 37.67 | 10.31 | 83.32 | 96.46 | [config](configs/conformer/conformer-small-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth) |
| Conformer-base-p16\*  | 83.29 | 22.89 | 83.82 | 96.59 | [config](configs/conformer/conformer-base-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth) |

*Models with \* are converted from other repos.*
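For reference, a minimal inference sketch with one of the checkpoints above, assuming the mmclassification v0.x high-level APIs (`init_model`/`inference_model`) and locally downloaded files; the config, checkpoint, and image paths are illustrative:

```python
# Minimal inference sketch (paths are illustrative assumptions).
from mmcls.apis import inference_model, init_model

config = 'configs/conformer/conformer-small-p32_8xb128_in1k.py'
checkpoint = 'conformer-small-p32_8xb128_in1k_20211206-947a0816.pth'

model = init_model(config, checkpoint, device='cpu')
result = inference_model(model, 'demo/demo.JPEG')
print(result['pred_class'], result['pred_score'])
```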
configs/conformer/conformer-base-p16_8xb128_in1k.py
@ -0,0 +1,9 @@
_base_ = [
    '../_base_/models/conformer/base-p16.py',
    '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_conformer.py',
    '../_base_/default_runtime.py'
]

data = dict(samples_per_gpu=128)
evaluation = dict(interval=1, metric='accuracy')
configs/conformer/conformer-small-p16_8xb128_in1k.py
@ -0,0 +1,9 @@
_base_ = [
    '../_base_/models/conformer/small-p16.py',
    '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_conformer.py',
    '../_base_/default_runtime.py'
]

data = dict(samples_per_gpu=128)
evaluation = dict(interval=1, metric='accuracy')
configs/conformer/conformer-small-p32_8xb128_in1k.py
@ -0,0 +1,9 @@
_base_ = [
    '../_base_/models/conformer/small-p32.py',
    '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_conformer.py',
    '../_base_/default_runtime.py'
]

data = dict(samples_per_gpu=128)
evaluation = dict(interval=1, metric='accuracy')
configs/conformer/conformer-tiny-p16_8xb128_in1k.py
@ -0,0 +1,9 @@
_base_ = [
    '../_base_/models/conformer/tiny-p16.py',
    '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_conformer.py',
    '../_base_/default_runtime.py'
]

data = dict(samples_per_gpu=128)
evaluation = dict(interval=1, metric='accuracy')
configs/conformer/metafile.yml
@ -0,0 +1,78 @@
Collections:
  - Name: Conformer
    Metadata:
      Training Data: ImageNet-1k
      Architecture:
        - Layer Normalization
        - Scaled Dot-Product Attention
        - Dropout
    Paper:
      URL: https://arxiv.org/abs/2105.03889
      Title: "Conformer: Local Features Coupling Global Representations for Visual Recognition"
    README: configs/conformer/README.md
    # Code:
    #   URL: # todo
    #   Version: # todo

Models:
  - Name: conformer-tiny-p16_3rdparty_8xb128_in1k
    In Collection: Conformer
    Config: configs/conformer/conformer-tiny-p16_8xb128_in1k.py
    Metadata:
      FLOPs: 4899611328
      Parameters: 23524704
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 81.31
          Top 5 Accuracy: 95.60
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth
    Converted From:
      Weights: https://drive.google.com/file/d/19SxGhKcWOR5oQSxNUWUM2MGYiaWMrF1z/view?usp=sharing
      Code: https://github.com/pengzhiliang/Conformer/blob/main/models.py#L65
  - Name: conformer-small-p16_3rdparty_8xb128_in1k
    In Collection: Conformer
    Config: configs/conformer/conformer-small-p16_8xb128_in1k.py
    Metadata:
      FLOPs: 10311309312
      Parameters: 37673424
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 83.32
          Top 5 Accuracy: 96.46
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth
    Converted From:
      Weights: https://drive.google.com/file/d/1mpOlbLaVxOfEwV4-ha78j_1Ebqzj2B83/view?usp=sharing
      Code: https://github.com/pengzhiliang/Conformer/blob/main/models.py#L73
  - Name: conformer-small-p32_8xb128_in1k
    In Collection: Conformer
    Config: configs/conformer/conformer-small-p32_8xb128_in1k.py
    Metadata:
      FLOPs: 7087281792
      Parameters: 38853072
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 81.96
          Top 5 Accuracy: 96.02
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth
  - Name: conformer-base-p16_3rdparty_8xb128_in1k
    In Collection: Conformer
    Config: configs/conformer/conformer-base-p16_8xb128_in1k.py
    Metadata:
      FLOPs: 22892078080
      Parameters: 83289136
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 83.82
          Top 5 Accuracy: 96.59
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth
    Converted From:
      Weights: https://drive.google.com/file/d/1oeQ9LSOGKEUaYGu7WTlUGl3KDsQIi0MA/view?usp=sharing
      Code: https://github.com/pengzhiliang/Conformer/blob/main/models.py#L89
@ -23,7 +23,7 @@ Transformers, which are popular for language modeling, have been explored for so

 ## Pretrain model

-The pre-trained modles are converted from [official repo](https://github.com/yitu-opensource/T2T-ViT/tree/main#2-t2t-vit-models).
+The pre-trained models are converted from [official repo](https://github.com/yitu-opensource/T2T-ViT/tree/main#2-t2t-vit-models).

 ### ImageNet-1k

@ -63,6 +63,10 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
 | T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth) | [log]()|
 | Mixer-B/16\* | 59.88 | 12.61 | 76.68 | 92.25 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth) | [log]()|
 | Mixer-L/16\* | 208.2 | 44.57 | 72.34 | 88.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth) | [log]()|
+| Conformer-tiny-p16\* | 23.52 | 4.90 | 81.31 | 95.60 | [config](configs/conformer/conformer-tiny-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth) | [log]()|
+| Conformer-small-p32 | 38.85 | 7.09 | 81.96 | 96.02 | [config](configs/conformer/conformer-small-p32_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth) | [log]()|
+| Conformer-small-p16\* | 37.67 | 10.31 | 83.32 | 96.46 | [config](configs/conformer/conformer-small-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth) | [log]()|
+| Conformer-base-p16\* | 83.29 | 22.89 | 83.82 | 96.59 | [config](configs/conformer/conformer-base-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth) | [log]()|


 Models with * are converted from other repos, others are trained by ourselves.
mmcls/models/backbones/__init__.py
@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .alexnet import AlexNet
+from .conformer import Conformer
 from .lenet import LeNet5
 from .mlp_mixer import MlpMixer
 from .mobilenet_v2 import MobileNetV2
@ -27,5 +28,5 @@ __all__ = [
     'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
     'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
     'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
-    'MlpMixer'
+    'Conformer', 'MlpMixer'
 ]
mmcls/models/backbones/conformer.py
@ -0,0 +1,616 @@
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.drop import DropPath
from mmcv.cnn.utils.weight_init import trunc_normal_

from mmcls.utils import get_root_logger
from ..builder import BACKBONES
from .base_backbone import BaseBackbone, BaseModule
from .vision_transformer import TransformerEncoderLayer


class ConvBlock(BaseModule):
    """Basic convolution block used in Conformer.

    This block includes three convolution modules, and supports three new
    functions:
    1. Returns the output of both the final layers and the second convolution
       module.
    2. Fuses the input of the second convolution module with an extra input
       feature map.
    3. Supports adding an extra convolution module to the identity connection.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        stride (int): The stride of the second convolution module.
            Defaults to 1.
        groups (int): The groups of the second convolution module.
            Defaults to 1.
        drop_path_rate (float): The rate of the DropPath layer. Defaults to 0.
        with_residual_conv (bool): Whether to add an extra convolution module
            to the identity connection. Defaults to False.
        norm_cfg (dict): The config of normalization layers.
            Defaults to ``dict(type='BN', eps=1e-6)``.
        act_cfg (dict): The config of activation functions.
            Defaults to ``dict(type='ReLU', inplace=True)``.
        init_cfg (dict, optional): The extra config to initialize the module.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 groups=1,
                 drop_path_rate=0.,
                 with_residual_conv=False,
                 norm_cfg=dict(type='BN', eps=1e-6),
                 act_cfg=dict(type='ReLU', inplace=True),
                 init_cfg=None):
        super(ConvBlock, self).__init__(init_cfg=init_cfg)

        expansion = 4
        mid_channels = out_channels // expansion

        self.conv1 = nn.Conv2d(
            in_channels,
            mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False)
        self.bn1 = build_norm_layer(norm_cfg, mid_channels)[1]
        self.act1 = build_activation_layer(act_cfg)

        self.conv2 = nn.Conv2d(
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=stride,
            groups=groups,
            padding=1,
            bias=False)
        self.bn2 = build_norm_layer(norm_cfg, mid_channels)[1]
        self.act2 = build_activation_layer(act_cfg)

        self.conv3 = nn.Conv2d(
            mid_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False)
        self.bn3 = build_norm_layer(norm_cfg, out_channels)[1]
        self.act3 = build_activation_layer(act_cfg)

        if with_residual_conv:
            self.residual_conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                padding=0,
                bias=False)
            self.residual_bn = build_norm_layer(norm_cfg, out_channels)[1]

        self.with_residual_conv = with_residual_conv
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def zero_init_last_bn(self):
        nn.init.zeros_(self.bn3.weight)

    def forward(self, x, fusion_features=None, out_conv2=True):
        identity = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)

        x = self.conv2(x) if fusion_features is None else self.conv2(
            x + fusion_features)
        x = self.bn2(x)
        x2 = self.act2(x)

        x = self.conv3(x2)
        x = self.bn3(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        if self.with_residual_conv:
            identity = self.residual_conv(identity)
            identity = self.residual_bn(identity)

        x += identity
        x = self.act3(x)

        if out_conv2:
            return x, x2
        else:
            return x

class FCUDown(BaseModule):
    """CNN feature maps -> Transformer patch embeddings."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 down_stride,
                 with_cls_token=True,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super(FCUDown, self).__init__(init_cfg=init_cfg)
        self.down_stride = down_stride
        self.with_cls_token = with_cls_token

        self.conv_project = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.sample_pooling = nn.AvgPool2d(
            kernel_size=down_stride, stride=down_stride)

        self.ln = build_norm_layer(norm_cfg, out_channels)[1]
        self.act = build_activation_layer(act_cfg)

    def forward(self, x, x_t):
        x = self.conv_project(x)  # [N, C, H, W]

        x = self.sample_pooling(x).flatten(2).transpose(1, 2)
        x = self.ln(x)
        x = self.act(x)

        if self.with_cls_token:
            x = torch.cat([x_t[:, 0][:, None, :], x], dim=1)

        return x


class FCUUp(BaseModule):
    """Transformer patch embeddings -> CNN feature maps."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 up_stride,
                 with_cls_token=True,
                 norm_cfg=dict(type='BN', eps=1e-6),
                 act_cfg=dict(type='ReLU', inplace=True),
                 init_cfg=None):
        super(FCUUp, self).__init__(init_cfg=init_cfg)

        self.up_stride = up_stride
        self.with_cls_token = with_cls_token

        self.conv_project = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.bn = build_norm_layer(norm_cfg, out_channels)[1]
        self.act = build_activation_layer(act_cfg)

    def forward(self, x, H, W):
        B, _, C = x.shape
        # [N, 197, 384] -> [N, 196, 384] -> [N, 384, 196] -> [N, 384, 14, 14]
        if self.with_cls_token:
            x_r = x[:, 1:].transpose(1, 2).reshape(B, C, H, W)
        else:
            x_r = x.transpose(1, 2).reshape(B, C, H, W)

        x_r = self.act(self.bn(self.conv_project(x_r)))

        return F.interpolate(
            x_r, size=(H * self.up_stride, W * self.up_stride))

class ConvTransBlock(BaseModule):
    """Basic module for Conformer.

    This module is a fusion of a CNN block and a transformer encoder block.

    Args:
        in_channels (int): The number of input channels in conv blocks.
        out_channels (int): The number of output channels in conv blocks.
        embed_dims (int): The embedding dimension in transformer blocks.
        conv_stride (int): The stride of conv2d layers. Defaults to 1.
        groups (int): The groups of conv blocks. Defaults to 1.
        with_residual_conv (bool): Whether to add a conv-bn layer to the
            identity connection in the conv block. Defaults to False.
        down_stride (int): The stride of the downsample pooling layer.
            Defaults to 4.
        num_heads (int): The number of heads in transformer attention layers.
            Defaults to 12.
        mlp_ratio (float): The expansion ratio in transformer FFN module.
            Defaults to 4.
        qkv_bias (bool): Enable bias for qkv if True. Defaults to False.
        with_cls_token (bool): Whether to use a class token.
            Defaults to True.
        drop_rate (float): The dropout rate of the output projection and
            FFN in the transformer block. Defaults to 0.
        attn_drop_rate (float): The dropout rate after the attention
            calculation in the transformer block. Defaults to 0.
        drop_path_rate (float): The drop path rate in both the conv block
            and the transformer block. Defaults to 0.
        last_fusion (bool): Whether this block is the last stage. If so,
            downsample the fusion feature map.
        init_cfg (dict, optional): The extra config to initialize the module.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 embed_dims,
                 conv_stride=1,
                 groups=1,
                 with_residual_conv=False,
                 down_stride=4,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 with_cls_token=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 last_fusion=False,
                 init_cfg=None):
        super(ConvTransBlock, self).__init__(init_cfg=init_cfg)
        expansion = 4
        self.cnn_block = ConvBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            with_residual_conv=with_residual_conv,
            stride=conv_stride,
            groups=groups)

        if last_fusion:
            self.fusion_block = ConvBlock(
                in_channels=out_channels,
                out_channels=out_channels,
                stride=2,
                with_residual_conv=True,
                groups=groups,
                drop_path_rate=drop_path_rate)
        else:
            self.fusion_block = ConvBlock(
                in_channels=out_channels,
                out_channels=out_channels,
                groups=groups,
                drop_path_rate=drop_path_rate)

        self.squeeze_block = FCUDown(
            in_channels=out_channels // expansion,
            out_channels=embed_dims,
            down_stride=down_stride,
            with_cls_token=with_cls_token)

        self.expand_block = FCUUp(
            in_channels=embed_dims,
            out_channels=out_channels // expansion,
            up_stride=down_stride,
            with_cls_token=with_cls_token)

        self.trans_block = TransformerEncoderLayer(
            embed_dims=embed_dims,
            num_heads=num_heads,
            feedforward_channels=int(embed_dims * mlp_ratio),
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            attn_drop_rate=attn_drop_rate,
            qkv_bias=qkv_bias,
            norm_cfg=dict(type='LN', eps=1e-6))

        self.down_stride = down_stride
        self.embed_dim = embed_dims
        self.last_fusion = last_fusion

    def forward(self, cnn_input, trans_input):
        x, x_conv2 = self.cnn_block(cnn_input, out_conv2=True)

        _, _, H, W = x_conv2.shape

        # Convert the feature map of conv2 to transformer embedding
        # and concat with class token.
        conv2_embedding = self.squeeze_block(x_conv2, trans_input)

        trans_output = self.trans_block(conv2_embedding + trans_input)

        # Convert the transformer output embedding to feature map
        trans_features = self.expand_block(trans_output, H // self.down_stride,
                                           W // self.down_stride)
        x = self.fusion_block(
            x, fusion_features=trans_features, out_conv2=False)

        return x, trans_output

@BACKBONES.register_module()
class Conformer(BaseBackbone):
    """Conformer backbone.

    A PyTorch implementation of: `Conformer: Local Features Coupling Global
    Representations for Visual Recognition
    <https://arxiv.org/abs/2105.03889>`_

    Args:
        arch (str | dict): Conformer architecture. Defaults to 'tiny'.
        patch_size (int): The patch size. Defaults to 16.
        base_channels (int): The base number of channels in the CNN network.
            Defaults to 64.
        mlp_ratio (float): The expansion ratio of the FFN network in the
            transformer block. Defaults to 4.
        with_cls_token (bool): Whether to use a class token.
            Defaults to True.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, which means the last stage.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """
    arch_zoo = {
        **dict.fromkeys(['t', 'tiny'],
                        {'embed_dims': 384,
                         'channel_ratio': 1,
                         'num_heads': 6,
                         'depths': 12}),
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': 384,
                         'channel_ratio': 4,
                         'num_heads': 6,
                         'depths': 12}),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': 576,
                         'channel_ratio': 6,
                         'num_heads': 9,
                         'depths': 12}),
    }  # yapf: disable

    _version = 1

    def __init__(self,
                 arch='tiny',
                 patch_size=16,
                 base_channels=64,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 with_cls_token=True,
                 drop_path_rate=0.,
                 norm_eval=True,
                 frozen_stages=0,
                 out_indices=-1,
                 init_cfg=None):

        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'depths', 'num_heads', 'channel_ratio'
            }
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.num_features = self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.num_heads = self.arch_settings['num_heads']
        self.channel_ratio = self.arch_settings['channel_ratio']

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.depths + index + 1
            assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.norm_eval = norm_eval
        self.frozen_stages = frozen_stages

        self.with_cls_token = with_cls_token
        if self.with_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

        # stochastic depth decay rule
        self.trans_dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, self.depths)
        ]

        # Stem stage: get the feature maps by conv block
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3,
            bias=False)  # 1 / 2 [112, 112]
        self.bn1 = nn.BatchNorm2d(64)
        self.act1 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1)  # 1 / 4 [56, 56]

        # stage 1
        stage1_channels = int(base_channels * self.channel_ratio)
        trans_down_stride = patch_size // 4
        self.conv_1 = ConvBlock(
            in_channels=64,
            out_channels=stage1_channels,
            with_residual_conv=True,
            stride=1)
        self.trans_patch_conv = nn.Conv2d(
            64,
            self.embed_dims,
            kernel_size=trans_down_stride,
            stride=trans_down_stride,
            padding=0)

        self.trans_1 = TransformerEncoderLayer(
            embed_dims=self.embed_dims,
            num_heads=self.num_heads,
            feedforward_channels=int(self.embed_dims * mlp_ratio),
            drop_path_rate=self.trans_dpr[0],
            qkv_bias=qkv_bias,
            norm_cfg=dict(type='LN', eps=1e-6))

        # stages 2~4
        init_stage = 2
        fin_stage = self.depths // 3 + 1
        for i in range(init_stage, fin_stage):
            self.add_module(
                f'conv_trans_{i}',
                ConvTransBlock(
                    in_channels=stage1_channels,
                    out_channels=stage1_channels,
                    embed_dims=self.embed_dims,
                    conv_stride=1,
                    with_residual_conv=False,
                    down_stride=trans_down_stride,
                    num_heads=self.num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop_path_rate=self.trans_dpr[i - 1],
                    with_cls_token=self.with_cls_token))

        stage2_channels = int(base_channels * self.channel_ratio * 2)
        # stages 5~8
        init_stage = fin_stage  # 5
        fin_stage = fin_stage + self.depths // 3  # 9
        for i in range(init_stage, fin_stage):
            if i == init_stage:
                conv_stride = 2
                in_channels = stage1_channels
            else:
                conv_stride = 1
                in_channels = stage2_channels

            with_residual_conv = True if i == init_stage else False
            self.add_module(
                f'conv_trans_{i}',
                ConvTransBlock(
                    in_channels=in_channels,
                    out_channels=stage2_channels,
                    embed_dims=self.embed_dims,
                    conv_stride=conv_stride,
                    with_residual_conv=with_residual_conv,
                    down_stride=trans_down_stride // 2,
                    num_heads=self.num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop_path_rate=self.trans_dpr[i - 1],
                    with_cls_token=self.with_cls_token))

        stage3_channels = int(base_channels * self.channel_ratio * 2 * 2)
        # stages 9~12
        init_stage = fin_stage  # 9
        fin_stage = fin_stage + self.depths // 3  # 13
        for i in range(init_stage, fin_stage):
            if i == init_stage:
                conv_stride = 2
                in_channels = stage2_channels
                with_residual_conv = True
            else:
                conv_stride = 1
                in_channels = stage3_channels
                with_residual_conv = False

            last_fusion = (i == self.depths)

            self.add_module(
                f'conv_trans_{i}',
                ConvTransBlock(
                    in_channels=in_channels,
                    out_channels=stage3_channels,
                    embed_dims=self.embed_dims,
                    conv_stride=conv_stride,
                    with_residual_conv=with_residual_conv,
                    down_stride=trans_down_stride // 4,
                    num_heads=self.num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop_path_rate=self.trans_dpr[i - 1],
                    with_cls_token=self.with_cls_token,
                    last_fusion=last_fusion))
        self.fin_stage = fin_stage

        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.trans_norm = nn.LayerNorm(self.embed_dims)

        if self.with_cls_token:
            trunc_normal_(self.cls_token, std=.02)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(
                m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1.)
            nn.init.constant_(m.bias, 0.)

        if hasattr(m, 'zero_init_last_bn'):
            m.zero_init_last_bn()

    def init_weights(self):
        super(Conformer, self).init_weights()
        logger = get_root_logger()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if use pretrained model.
            return
        else:
            logger.info(f'No pre-trained weights for '
                        f'{self.__class__.__name__}, '
                        f'training starts from scratch')
            self.apply(self._init_weights)

    def forward(self, x):
        output = []
        B = x.shape[0]
        if self.with_cls_token:
            cls_tokens = self.cls_token.expand(B, -1, -1)

        # stem
        x_base = self.maxpool(self.act1(self.bn1(self.conv1(x))))

        # stage 1 [N, 64, 56, 56] -> [N, 128, 56, 56]
        x = self.conv_1(x_base, out_conv2=False)
        x_t = self.trans_patch_conv(x_base).flatten(2).transpose(1, 2)
        if self.with_cls_token:
            x_t = torch.cat([cls_tokens, x_t], dim=1)
        x_t = self.trans_1(x_t)

        # stage 2 ~ final
        for i in range(2, self.fin_stage):
            stage = getattr(self, f'conv_trans_{i}')
            x, x_t = stage(x, x_t)
            if i in self.out_indices:
                if self.with_cls_token:
                    output.append([
                        self.pooling(x).flatten(1),
                        self.trans_norm(x_t)[:, 0]
                    ])
                else:
                    # if no class token, use the mean patch token
                    # as the transformer feature.
                    output.append([
                        self.pooling(x).flatten(1),
                        self.trans_norm(x_t).mean(dim=1)
                    ])

        return tuple(output)
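As a usage note for the backbone above: `forward` returns a tuple with one entry per requested stage, and each entry is a `[conv_feature, transformer_feature]` pair (globally pooled CNN feature plus the cls-token or mean-token feature). A hedged sketch, assuming the backbone is registered with `BACKBONES` as in this diff:

```python
# Output-format sketch for the Conformer backbone defined above.
import torch
from mmcls.models import build_backbone

backbone = build_backbone(dict(type='Conformer', arch='tiny'))
backbone.init_weights()
feats = backbone(torch.randn(1, 3, 224, 224))
conv_feat, trans_feat = feats[-1]         # last stage by default
print(conv_feat.shape, trans_feat.shape)  # (1, 256) and (1, 384) for 'tiny'
```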
mmcls/models/heads/__init__.py
@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .cls_head import ClsHead
+from .conformer_head import ConformerHead
 from .linear_head import LinearClsHead
 from .multi_label_head import MultiLabelClsHead
 from .multi_label_linear_head import MultiLabelLinearClsHead
@ -8,5 +9,5 @@ from .vision_transformer_head import VisionTransformerClsHead

 __all__ = [
     'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead',
-    'MultiLabelLinearClsHead', 'VisionTransformerClsHead'
+    'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'ConformerHead'
 ]
mmcls/models/heads/conformer_head.py
@ -0,0 +1,103 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.utils.weight_init import trunc_normal_

from ..builder import HEADS
from .cls_head import ClsHead


@HEADS.register_module()
class ConformerHead(ClsHead):
    """Linear classifier head for the Conformer model.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (Sequence[int]): Number of channels of the two input
            feature maps, ``[conv_dim, trans_dim]``.
        init_cfg (dict, optional): The extra init config of layers.
            Defaults to ``dict(type='Normal', layer='Linear', std=0.01)``.
    """

    def __init__(
            self,
            num_classes,
            in_channels,  # [conv_dim, trans_dim]
            init_cfg=dict(type='Normal', layer='Linear', std=0.01),
            *args,
            **kwargs):
        super(ConformerHead, self).__init__(init_cfg=None, *args, **kwargs)

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.init_cfg = init_cfg

        if self.num_classes <= 0:
            raise ValueError(
                f'num_classes={num_classes} must be a positive integer')

        self.conv_cls_head = nn.Linear(self.in_channels[0], num_classes)
        self.trans_cls_head = nn.Linear(self.in_channels[1], num_classes)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def init_weights(self):
        super(ConformerHead, self).init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if use pretrained model.
            return
        else:
            self.apply(self._init_weights)

    def simple_test(self, x):
        """Test without augmentation."""
        if isinstance(x, tuple):
            x = x[-1]
        # There are two outputs in the Conformer model.
        assert isinstance(x, list)

        conv_cls_score = self.conv_cls_head(x[0])
        tran_cls_score = self.trans_cls_head(x[1])

        cls_score = conv_cls_score + tran_cls_score

        pred = F.softmax(cls_score, dim=1) if cls_score is not None else None

        return self.post_process(pred)

    def forward_train(self, x, gt_label):
        if isinstance(x, tuple):
            x = x[-1]
        assert isinstance(x, list) and len(x) == 2, \
            'There should be two outputs in the Conformer model'

        conv_cls_score = self.conv_cls_head(x[0])
        tran_cls_score = self.trans_cls_head(x[1])

        losses = self.loss([conv_cls_score, tran_cls_score], gt_label)
        return losses

    def loss(self, cls_score, gt_label):
        num_samples = len(cls_score[0])
        losses = dict()
        # compute loss, averaged over the two branch scores
        loss = sum([
            self.compute_loss(score, gt_label, avg_factor=num_samples) /
            len(cls_score) for score in cls_score
        ])
        if self.cal_acc:
            # compute accuracy on the summed branch scores
            acc = self.compute_accuracy(cls_score[0] + cls_score[1], gt_label)
            assert len(acc) == len(self.topk)
            losses['accuracy'] = {
                f'top-{k}': a
                for k, a in zip(self.topk, acc)
            }
        losses['loss'] = loss
        return losses
model-index.yml
@ -13,3 +13,4 @@ Import:
 - configs/vision_transformer/metafile.yml
 - configs/t2t_vit/metafile.yml
 - configs/mlp_mixer/metafile.yml
+- configs/conformer/metafile.yml
@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy

import pytest
import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm

from mmcls.models.backbones import Conformer


def is_norm(modules):
    """Check if is one of the norms."""
    if isinstance(modules, (GroupNorm, _BatchNorm)):
        return True
    return False


def check_norm_state(modules, train_state):
    """Check if norm layer is in correct train state."""
    for mod in modules:
        if isinstance(mod, _BatchNorm):
            if mod.training != train_state:
                return False
    return True


def test_conformer_backbone():

    cfg_ori = dict(
        arch='T',
        drop_path_rate=0.1,
    )

    with pytest.raises(AssertionError):
        # test invalid arch
        cfg = deepcopy(cfg_ori)
        cfg['arch'] = 'unknown'
        Conformer(**cfg)

    with pytest.raises(AssertionError):
        # test arch without essential keys
        cfg = deepcopy(cfg_ori)
        cfg['arch'] = {'embed_dims': 24, 'channel_ratio': 6, 'num_heads': 9}
        Conformer(**cfg)

    # Test Conformer tiny model with patch size of 16
    model = Conformer(**cfg_ori)
    model.init_weights()
    model.train()

    assert check_norm_state(model.modules(), True)

    imgs = torch.randn(3, 3, 224, 224)
    conv_feature, transformer_feature = model(imgs)[-1]
    # base_channels * channel_ratio * 4
    assert conv_feature.shape == (3, 64 * 1 * 4)
    assert transformer_feature.shape == (3, 384)

    # Test custom arch Conformer without output cls token
    cfg = deepcopy(cfg_ori)
    cfg['arch'] = {
        'embed_dims': 128,
        'depths': 15,
        'num_heads': 16,
        'channel_ratio': 3,
    }
    cfg['with_cls_token'] = False
    cfg['base_channels'] = 32
    model = Conformer(**cfg)
    conv_feature, transformer_feature = model(imgs)[-1]
    assert conv_feature.shape == (3, 32 * 3 * 4)
    assert transformer_feature.shape == (3, 128)

    # Test Conformer with multiple out indices
    cfg = deepcopy(cfg_ori)
    cfg['out_indices'] = [4, 8, 12]
    model = Conformer(**cfg)
    outs = model(imgs)
    assert len(outs) == 3
    # stage 1
    conv_feature, transformer_feature = outs[0]
    assert conv_feature.shape == (3, 64 * 1)
    assert transformer_feature.shape == (3, 384)
    # stage 2
    conv_feature, transformer_feature = outs[1]
    assert conv_feature.shape == (3, 64 * 1 * 2)
    assert transformer_feature.shape == (3, 384)
    # stage 3
    conv_feature, transformer_feature = outs[2]
    assert conv_feature.shape == (3, 64 * 1 * 4)
    assert transformer_feature.shape == (3, 384)