diff --git a/configs/_base_/models/conformer/base-p16.py b/configs/_base_/models/conformer/base-p16.py
new file mode 100644
index 00000000..157dcc98
--- /dev/null
+++ b/configs/_base_/models/conformer/base-p16.py
@@ -0,0 +1,22 @@
+# model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(
+ type='Conformer', arch='base', drop_path_rate=0.1, init_cfg=None),
+ neck=None,
+ head=dict(
+ type='ConformerHead',
+ num_classes=1000,
+ in_channels=[1536, 576],
+ init_cfg=None,
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ cal_acc=False),
+ init_cfg=[
+ dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
+ dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
+ ],
+ train_cfg=dict(augments=[
+ dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
+ dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
+ ]))
diff --git a/configs/_base_/models/conformer/small-p16.py b/configs/_base_/models/conformer/small-p16.py
new file mode 100644
index 00000000..17298089
--- /dev/null
+++ b/configs/_base_/models/conformer/small-p16.py
@@ -0,0 +1,22 @@
+# model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(
+ type='Conformer', arch='small', drop_path_rate=0.1, init_cfg=None),
+ neck=None,
+ head=dict(
+ type='ConformerHead',
+ num_classes=1000,
+ in_channels=[1024, 384],
+ init_cfg=None,
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ cal_acc=False),
+ init_cfg=[
+ dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
+ dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
+ ],
+ train_cfg=dict(augments=[
+ dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
+ dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
+ ]))
diff --git a/configs/_base_/models/conformer/small-p32.py b/configs/_base_/models/conformer/small-p32.py
new file mode 100644
index 00000000..593aba12
--- /dev/null
+++ b/configs/_base_/models/conformer/small-p32.py
@@ -0,0 +1,26 @@
+# model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(
+ type='Conformer',
+ arch='small',
+ patch_size=32,
+ drop_path_rate=0.1,
+ init_cfg=None),
+ neck=None,
+ head=dict(
+ type='ConformerHead',
+ num_classes=1000,
+ in_channels=[1024, 384],
+ init_cfg=None,
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ cal_acc=False),
+ init_cfg=[
+ dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
+ dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
+ ],
+ train_cfg=dict(augments=[
+ dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
+ dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
+ ]))
diff --git a/configs/_base_/models/conformer/tiny-p16.py b/configs/_base_/models/conformer/tiny-p16.py
new file mode 100644
index 00000000..dad8ecae
--- /dev/null
+++ b/configs/_base_/models/conformer/tiny-p16.py
@@ -0,0 +1,22 @@
+# model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(
+ type='Conformer', arch='tiny', drop_path_rate=0.1, init_cfg=None),
+ neck=None,
+ head=dict(
+ type='ConformerHead',
+ num_classes=1000,
+ in_channels=[256, 384],
+ init_cfg=None,
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ cal_acc=False),
+ init_cfg=[
+ dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
+ dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
+ ],
+ train_cfg=dict(augments=[
+ dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
+ dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
+ ]))
diff --git a/configs/_base_/schedules/imagenet_bs1024_adamw_conformer.py b/configs/_base_/schedules/imagenet_bs1024_adamw_conformer.py
new file mode 100644
index 00000000..92f18017
--- /dev/null
+++ b/configs/_base_/schedules/imagenet_bs1024_adamw_conformer.py
@@ -0,0 +1,29 @@
+paramwise_cfg = dict(
+ norm_decay_mult=0.0,
+ bias_decay_mult=0.0,
+ custom_keys={
+ '.cls_token': dict(decay_mult=0.0),
+ })
+
+# batch size per GPU is 128, with 8 GPUs in total
+# lr = 5e-4 * 128 * 8 / 512 = 0.001
+optimizer = dict(
+ type='AdamW',
+ lr=5e-4 * 128 * 8 / 512,
+ weight_decay=0.05,
+ eps=1e-8,
+ betas=(0.9, 0.999),
+ paramwise_cfg=paramwise_cfg)
+optimizer_config = dict(grad_clip=None)
+
+# learning policy
+lr_config = dict(
+ policy='CosineAnnealing',
+ by_epoch=False,
+ min_lr_ratio=1e-2,
+ warmup='linear',
+ warmup_ratio=1e-3,
+ warmup_iters=5 * 1252,
+ warmup_by_epoch=False)
+
+runner = dict(type='EpochBasedRunner', max_epochs=300)
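+
+# Note: the learning rate above follows the linear scaling rule
+# lr = base_lr * total_batch_size / 512 with base_lr = 5e-4. As a
+# hypothetical example, a downstream config training with 4 GPUs and
+# 128 images per GPU could override it as follows:
+#
+#     _base_ = ['../_base_/schedules/imagenet_bs1024_adamw_conformer.py']
+#     optimizer = dict(lr=5e-4 * 128 * 4 / 512)  # = 5e-4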
diff --git a/configs/conformer/README.md b/configs/conformer/README.md
new file mode 100644
index 00000000..45e79aaf
--- /dev/null
+++ b/configs/conformer/README.md
@@ -0,0 +1,39 @@
+# Conformer: Local Features Coupling Global Representations for Visual Recognition
+
+
+
+## Abstract
+
+
+Within Convolutional Neural Network (CNN), the convolution operations are good at extracting local features but experience difficulty in capturing global representations. Within visual transformer, the cascaded self-attention modules can capture long-distance feature dependencies but unfortunately deteriorate local feature details. In this paper, we propose a hybrid network structure, termed Conformer, to take advantage of convolutional operations and self-attention mechanisms for enhanced representation learning. Conformer is rooted in the Feature Coupling Unit (FCU), which fuses local features and global representations under different resolutions in an interactive fashion. Conformer adopts a concurrent structure so that local features and global representations are retained to the maximum extent. Experiments show that Conformer, under comparable parameter complexity, outperforms the visual transformer (DeiT-B) by 2.3% on ImageNet. On MSCOCO, it outperforms ResNet-101 by 3.7% and 3.6% mAPs for object detection and instance segmentation, respectively, demonstrating its great potential to be a general backbone network.
+
+
+
+
+
+
+## Citation
+
+```latex
+@article{peng2021conformer,
+ title={Conformer: Local Features Coupling Global Representations for Visual Recognition},
+ author={Zhiliang Peng and Wei Huang and Shanzhi Gu and Lingxi Xie and Yaowei Wang and Jianbin Jiao and Qixiang Ye},
+ journal={arXiv preprint arXiv:2105.03889},
+ year={2021},
+}
+```
+
+## Results and models
+
+Some pre-trained models are converted from [official repo](https://github.com/pengzhiliang/Conformer).
+
+### ImageNet-1k
+
+| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
+|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
+| Conformer-tiny-p16\* | 23.52 | 4.90 | 81.31 | 95.60 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-tiny-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth) |
+| Conformer-small-p32 | 38.85 | 7.09 | 81.96 | 96.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-small-p32_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth) |
+| Conformer-small-p16\* | 37.67 | 10.31 | 83.32 | 96.46 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-small-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth) |
+| Conformer-base-p16\* | 83.29 | 22.89 | 83.82 | 96.59 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-base-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth) |
+
+*Models with \* are converted from other repos.*
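+
+## How to use
+
+The snippet below is a minimal inference sketch using the high-level `mmcls.apis` helpers (`init_model` and `inference_model`). The config path and checkpoint URL come from the table above, and `demo.jpeg` is a placeholder image path.
+
+```python
+from mmcls.apis import inference_model, init_model
+
+config = 'configs/conformer/conformer-tiny-p16_8xb128_in1k.py'
+checkpoint = 'https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth'
+
+model = init_model(config, checkpoint, device='cpu')
+result = inference_model(model, 'demo.jpeg')
+print(result)  # a dict with the predicted label, score and class name
+```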
diff --git a/configs/conformer/conformer-base-p16_8xb128_in1k.py b/configs/conformer/conformer-base-p16_8xb128_in1k.py
new file mode 100644
index 00000000..29ed58be
--- /dev/null
+++ b/configs/conformer/conformer-base-p16_8xb128_in1k.py
@@ -0,0 +1,9 @@
+_base_ = [
+ '../_base_/models/conformer/base-p16.py',
+ '../_base_/datasets/imagenet_bs64_swin_224.py',
+ '../_base_/schedules/imagenet_bs1024_adamw_conformer.py',
+ '../_base_/default_runtime.py'
+]
+
+data = dict(samples_per_gpu=128)
+evaluation = dict(interval=1, metric='accuracy')
diff --git a/configs/conformer/conformer-small-p16_8xb128_in1k.py b/configs/conformer/conformer-small-p16_8xb128_in1k.py
new file mode 100644
index 00000000..c40ed041
--- /dev/null
+++ b/configs/conformer/conformer-small-p16_8xb128_in1k.py
@@ -0,0 +1,9 @@
+_base_ = [
+ '../_base_/models/conformer/small-p16.py',
+ '../_base_/datasets/imagenet_bs64_swin_224.py',
+ '../_base_/schedules/imagenet_bs1024_adamw_conformer.py',
+ '../_base_/default_runtime.py'
+]
+
+data = dict(samples_per_gpu=128)
+evaluation = dict(interval=1, metric='accuracy')
diff --git a/configs/conformer/conformer-small-p32_8xb128_in1k.py b/configs/conformer/conformer-small-p32_8xb128_in1k.py
new file mode 100644
index 00000000..aaa11895
--- /dev/null
+++ b/configs/conformer/conformer-small-p32_8xb128_in1k.py
@@ -0,0 +1,9 @@
+_base_ = [
+ '../_base_/models/conformer/small-p32.py',
+ '../_base_/datasets/imagenet_bs64_swin_224.py',
+ '../_base_/schedules/imagenet_bs1024_adamw_conformer.py',
+ '../_base_/default_runtime.py'
+]
+
+data = dict(samples_per_gpu=128)
+evaluation = dict(interval=1, metric='accuracy')
diff --git a/configs/conformer/conformer-tiny-p16_8xb128_in1k.py b/configs/conformer/conformer-tiny-p16_8xb128_in1k.py
new file mode 100644
index 00000000..76a264c6
--- /dev/null
+++ b/configs/conformer/conformer-tiny-p16_8xb128_in1k.py
@@ -0,0 +1,9 @@
+_base_ = [
+ '../_base_/models/conformer/tiny-p16.py',
+ '../_base_/datasets/imagenet_bs64_swin_224.py',
+ '../_base_/schedules/imagenet_bs1024_adamw_conformer.py',
+ '../_base_/default_runtime.py'
+]
+
+data = dict(samples_per_gpu=128)
+evaluation = dict(interval=1, metric='accuracy')
diff --git a/configs/conformer/metafile.yml b/configs/conformer/metafile.yml
new file mode 100644
index 00000000..31d28740
--- /dev/null
+++ b/configs/conformer/metafile.yml
@@ -0,0 +1,78 @@
+Collections:
+ - Name: Conformer
+ Metadata:
+ Training Data: ImageNet-1k
+ Architecture:
+ - Layer Normalization
+ - Scaled Dot-Product Attention
+ - Dropout
+ Paper:
+ URL: https://arxiv.org/abs/2105.03889
+ Title: "Conformer: Local Features Coupling Global Representations for Visual Recognition"
+ README: configs/conformer/README.md
+# Code:
+# URL: # todo
+# Version: # todo
+
+Models:
+ - Name: conformer-tiny-p16_3rdparty_8xb128_in1k
+ In Collection: Conformer
+ Config: configs/conformer/conformer-tiny-p16_8xb128_in1k.py
+ Metadata:
+ FLOPs: 4899611328
+ Parameters: 23524704
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 81.31
+ Top 5 Accuracy: 95.60
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth
+ Converted From:
+ Weights: https://drive.google.com/file/d/19SxGhKcWOR5oQSxNUWUM2MGYiaWMrF1z/view?usp=sharing
+ Code: https://github.com/pengzhiliang/Conformer/blob/main/models.py#L65
+ - Name: conformer-small-p16_3rdparty_8xb128_in1k
+ In Collection: Conformer
+ Config: configs/conformer/conformer-small-p16_8xb128_in1k.py
+ Metadata:
+ FLOPs: 10311309312
+ Parameters: 37673424
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 83.32
+ Top 5 Accuracy: 96.46
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth
+ Converted From:
+ Weights: https://drive.google.com/file/d/1mpOlbLaVxOfEwV4-ha78j_1Ebqzj2B83/view?usp=sharing
+ Code: https://github.com/pengzhiliang/Conformer/blob/main/models.py#L73
+ - Name: conformer-small-p32_8xb128_in1k
+ In Collection: Conformer
+ Config: configs/conformer/conformer-small-p32_8xb128_in1k.py
+ Metadata:
+ FLOPs: 7087281792
+ Parameters: 38853072
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 81.96
+ Top 5 Accuracy: 96.02
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth
+ - Name: conformer-base-p16_3rdparty_8xb128_in1k
+ In Collection: Conformer
+ Config: configs/conformer/conformer-base-p16_8xb128_in1k.py
+ Metadata:
+ FLOPs: 22892078080
+ Parameters: 83289136
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 83.82
+ Top 5 Accuracy: 96.59
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth
+ Converted From:
+ Weights: https://drive.google.com/file/d/1oeQ9LSOGKEUaYGu7WTlUGl3KDsQIi0MA/view?usp=sharing
+ Code: https://github.com/pengzhiliang/Conformer/blob/main/models.py#L89
diff --git a/configs/t2t_vit/README.md b/configs/t2t_vit/README.md
index fd4b8eb0..c4b7b092 100644
--- a/configs/t2t_vit/README.md
+++ b/configs/t2t_vit/README.md
@@ -23,7 +23,7 @@ Transformers, which are popular for language modeling, have been explored for so
## Pretrain model
-The pre-trained modles are converted from [official repo](https://github.com/yitu-opensource/T2T-ViT/tree/main#2-t2t-vit-models).
+The pre-trained models are converted from [official repo](https://github.com/yitu-opensource/T2T-ViT/tree/main#2-t2t-vit-models).
### ImageNet-1k
diff --git a/docs/model_zoo.md b/docs/model_zoo.md
index 1cc7f3fd..7f378e0d 100644
--- a/docs/model_zoo.md
+++ b/docs/model_zoo.md
@@ -63,6 +63,10 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
| T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth) | [log]()|
| Mixer-B/16\* | 59.88 | 12.61 | 76.68 | 92.25 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth) | [log]()|
| Mixer-L/16\* | 208.2 | 44.57 | 72.34 | 88.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth) | [log]()|
+| Conformer-tiny-p16\* | 23.52 | 4.90 | 81.31 | 95.60 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-tiny-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth) | [log]()|
+| Conformer-small-p32 | 38.85 | 7.09 | 81.96 | 96.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-small-p32_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth) | [log]()|
+| Conformer-small-p16\* | 37.67 | 10.31 | 83.32 | 96.46 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-small-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth) | [log]()|
+| Conformer-base-p16\* | 83.29 | 22.89 | 83.82 | 96.59 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-base-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth) | [log]()|
Models with * are converted from other repos, others are trained by ourselves.
diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py
index b37c8c37..f9dbf705 100644
--- a/mmcls/models/backbones/__init__.py
+++ b/mmcls/models/backbones/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .alexnet import AlexNet
+from .conformer import Conformer
from .lenet import LeNet5
from .mlp_mixer import MlpMixer
from .mobilenet_v2 import MobileNetV2
@@ -27,5 +28,5 @@ __all__ = [
'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
- 'MlpMixer'
+ 'Conformer', 'MlpMixer'
]
diff --git a/mmcls/models/backbones/conformer.py b/mmcls/models/backbones/conformer.py
new file mode 100644
index 00000000..0eab9c6a
--- /dev/null
+++ b/mmcls/models/backbones/conformer.py
@@ -0,0 +1,616 @@
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_activation_layer, build_norm_layer
+from mmcv.cnn.bricks.drop import DropPath
+from mmcv.cnn.utils.weight_init import trunc_normal_
+
+from mmcls.utils import get_root_logger
+from ..builder import BACKBONES
+from .base_backbone import BaseBackbone, BaseModule
+from .vision_transformer import TransformerEncoderLayer
+
+
+class ConvBlock(BaseModule):
+ """Basic convluation block used in Conformer.
+
+ This block includes three convluation modules, and supports three new
+ functions:
+ 1. Returns the output of both the final layers and the second convluation
+ module.
+ 2. Fuses the input of the second convluation module with an extra input
+ feature map.
+ 3. Supports to add an extra convluation module to the identity connection.
+
+ Args:
+ in_channels (int): The number of input channels.
+ out_channels (int): The number of output channels.
+ stride (int): The stride of the second convolution module.
+ Defaults to 1.
+ groups (int): The groups of the second convolution module.
+ Defaults to 1.
+ drop_path_rate (float): The rate of the DropPath layer. Defaults to 0.
+ with_residual_conv (bool): Whether to add an extra convolution module
+ to the identity connection. Defaults to False.
+ norm_cfg (dict): The config of normalization layers.
+ Defaults to ``dict(type='BN', eps=1e-6)``.
+ act_cfg (dict): The config of activation functions.
+ Defaults to ``dict(type='ReLU', inplace=True)``.
+ init_cfg (dict, optional): The extra config to initialize the module.
+ Defaults to None.
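+
+ Example:
+ >>> # A minimal forward sketch. By default the block returns both the
+ >>> # final output and the intermediate feature of the second conv
+ >>> # module, which is consumed by the FCU modules in Conformer.
+ >>> import torch
+ >>> from mmcls.models.backbones.conformer import ConvBlock
+ >>> block = ConvBlock(64, 256, with_residual_conv=True)
+ >>> x = torch.rand(1, 64, 56, 56)
+ >>> out, x2 = block(x)
+ >>> out.shape, x2.shape
+ (torch.Size([1, 256, 56, 56]), torch.Size([1, 64, 56, 56]))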
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride=1,
+ groups=1,
+ drop_path_rate=0.,
+ with_residual_conv=False,
+ norm_cfg=dict(type='BN', eps=1e-6),
+ act_cfg=dict(type='ReLU', inplace=True),
+ init_cfg=None):
+ super(ConvBlock, self).__init__(init_cfg=init_cfg)
+
+ expansion = 4
+ mid_channels = out_channels // expansion
+
+ self.conv1 = nn.Conv2d(
+ in_channels,
+ mid_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False)
+ self.bn1 = build_norm_layer(norm_cfg, mid_channels)[1]
+ self.act1 = build_activation_layer(act_cfg)
+
+ self.conv2 = nn.Conv2d(
+ mid_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=stride,
+ groups=groups,
+ padding=1,
+ bias=False)
+ self.bn2 = build_norm_layer(norm_cfg, mid_channels)[1]
+ self.act2 = build_activation_layer(act_cfg)
+
+ self.conv3 = nn.Conv2d(
+ mid_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False)
+ self.bn3 = build_norm_layer(norm_cfg, out_channels)[1]
+ self.act3 = build_activation_layer(act_cfg)
+
+ if with_residual_conv:
+ self.residual_conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=stride,
+ padding=0,
+ bias=False)
+ self.residual_bn = build_norm_layer(norm_cfg, out_channels)[1]
+
+ self.with_residual_conv = with_residual_conv
+ self.drop_path = DropPath(
+ drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+ def zero_init_last_bn(self):
+ nn.init.zeros_(self.bn3.weight)
+
+ def forward(self, x, fusion_features=None, out_conv2=True):
+ identity = x
+
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+
+ x = self.conv2(x) if fusion_features is None else self.conv2(
+ x + fusion_features)
+ x = self.bn2(x)
+ x2 = self.act2(x)
+
+ x = self.conv3(x2)
+ x = self.bn3(x)
+
+ if self.drop_path is not None:
+ x = self.drop_path(x)
+
+ if self.with_residual_conv:
+ identity = self.residual_conv(identity)
+ identity = self.residual_bn(identity)
+
+ x += identity
+ x = self.act3(x)
+
+ if out_conv2:
+ return x, x2
+ else:
+ return x
+
+
+class FCUDown(BaseModule):
+ """CNN feature maps -> Transformer patch embeddings."""
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ down_stride,
+ with_cls_token=True,
+ norm_cfg=dict(type='LN', eps=1e-6),
+ act_cfg=dict(type='GELU'),
+ init_cfg=None):
+ super(FCUDown, self).__init__(init_cfg=init_cfg)
+ self.down_stride = down_stride
+ self.with_cls_token = with_cls_token
+
+ self.conv_project = nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+ self.sample_pooling = nn.AvgPool2d(
+ kernel_size=down_stride, stride=down_stride)
+
+ self.ln = build_norm_layer(norm_cfg, out_channels)[1]
+ self.act = build_activation_layer(act_cfg)
+
+ def forward(self, x, x_t):
+ x = self.conv_project(x) # [N, C, H, W]
+
+ x = self.sample_pooling(x).flatten(2).transpose(1, 2)
+ x = self.ln(x)
+ x = self.act(x)
+
+ if self.with_cls_token:
+ x = torch.cat([x_t[:, 0][:, None, :], x], dim=1)
+
+ return x
+
+
+class FCUUp(BaseModule):
+ """Transformer patch embeddings -> CNN feature maps."""
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ up_stride,
+ with_cls_token=True,
+ norm_cfg=dict(type='BN', eps=1e-6),
+ act_cfg=dict(type='ReLU', inplace=True),
+ init_cfg=None):
+ super(FCUUp, self).__init__(init_cfg=init_cfg)
+
+ self.up_stride = up_stride
+ self.with_cls_token = with_cls_token
+
+ self.conv_project = nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+ self.bn = build_norm_layer(norm_cfg, out_channels)[1]
+ self.act = build_activation_layer(act_cfg)
+
+ def forward(self, x, H, W):
+ B, _, C = x.shape
+ # [N, 197, 384] -> [N, 196, 384] -> [N, 384, 196] -> [N, 384, 14, 14]
+ if self.with_cls_token:
+ x_r = x[:, 1:].transpose(1, 2).reshape(B, C, H, W)
+ else:
+ x_r = x.transpose(1, 2).reshape(B, C, H, W)
+
+ x_r = self.act(self.bn(self.conv_project(x_r)))
+
+ return F.interpolate(
+ x_r, size=(H * self.up_stride, W * self.up_stride))
+
+
+class ConvTransBlock(BaseModule):
+ """Basic module for Conformer.
+
+ This module is a fusion of a CNN block and a transformer encoder block.
+
+ Args:
+ in_channels (int): The number of input channels in conv blocks.
+ out_channels (int): The number of output channels in conv blocks.
+ embed_dims (int): The embedding dimension in transformer blocks.
+ conv_stride (int): The stride of conv2d layers. Defaults to 1.
+ groups (int): The groups of conv blocks. Defaults to 1.
+ with_residual_conv (bool): Whether to add a conv-bn layer to the
+ identity connect in the conv block. Defaults to False.
+ down_stride (int): The stride of the downsample pooling layer.
+ Defaults to 4.
+ num_heads (int): The number of heads in transformer attention layers.
+ Defaults to 12.
+ mlp_ratio (float): The expansion ratio in transformer FFN module.
+ Defaults to 4.
+ qkv_bias (bool): Enable bias for qkv if True. Defaults to False.
+ with_cls_token (bool): Whether to use the class token or not.
+ Defaults to True.
+ drop_rate (float): The dropout rate of the output projection and
+ FFN in the transformer block. Defaults to 0.
+ attn_drop_rate (float): The dropout rate after the attention
+ calculation in the transformer block. Defaults to 0.
+ drop_path_rate (float): The drop path rate in both the conv block
+ and the transformer block. Defaults to 0.
+ last_fusion (bool): Whether this block is the last stage. If so,
+ the fusion block downsamples the fused feature map.
+ Defaults to False.
+ init_cfg (dict, optional): The extra config to initialize the module.
+ Defaults to None.
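+
+ Example:
+ >>> # A stand-alone sketch for illustration only; inside Conformer the
+ >>> # two inputs come from the CNN branch and the transformer branch of
+ >>> # the previous stage.
+ >>> import torch
+ >>> from mmcls.models.backbones.conformer import ConvTransBlock
+ >>> block = ConvTransBlock(64, 64, embed_dims=384, num_heads=6)
+ >>> cnn_input = torch.rand(1, 64, 56, 56)
+ >>> trans_input = torch.rand(1, 197, 384)  # 196 patch tokens + cls token
+ >>> x, x_t = block(cnn_input, trans_input)
+ >>> x.shape, x_t.shape
+ (torch.Size([1, 64, 56, 56]), torch.Size([1, 197, 384]))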
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ embed_dims,
+ conv_stride=1,
+ groups=1,
+ with_residual_conv=False,
+ down_stride=4,
+ num_heads=12,
+ mlp_ratio=4.,
+ qkv_bias=False,
+ with_cls_token=True,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ last_fusion=False,
+ init_cfg=None):
+ super(ConvTransBlock, self).__init__(init_cfg=init_cfg)
+ expansion = 4
+ self.cnn_block = ConvBlock(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ with_residual_conv=with_residual_conv,
+ stride=conv_stride,
+ groups=groups)
+
+ if last_fusion:
+ self.fusion_block = ConvBlock(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ stride=2,
+ with_residual_conv=True,
+ groups=groups,
+ drop_path_rate=drop_path_rate)
+ else:
+ self.fusion_block = ConvBlock(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ groups=groups,
+ drop_path_rate=drop_path_rate)
+
+ self.squeeze_block = FCUDown(
+ in_channels=out_channels // expansion,
+ out_channels=embed_dims,
+ down_stride=down_stride,
+ with_cls_token=with_cls_token)
+
+ self.expand_block = FCUUp(
+ in_channels=embed_dims,
+ out_channels=out_channels // expansion,
+ up_stride=down_stride,
+ with_cls_token=with_cls_token)
+
+ self.trans_block = TransformerEncoderLayer(
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ feedforward_channels=int(embed_dims * mlp_ratio),
+ drop_rate=drop_rate,
+ drop_path_rate=drop_path_rate,
+ attn_drop_rate=attn_drop_rate,
+ qkv_bias=qkv_bias,
+ norm_cfg=dict(type='LN', eps=1e-6))
+
+ self.down_stride = down_stride
+ self.embed_dim = embed_dims
+ self.last_fusion = last_fusion
+
+ def forward(self, cnn_input, trans_input):
+ x, x_conv2 = self.cnn_block(cnn_input, out_conv2=True)
+
+ _, _, H, W = x_conv2.shape
+
+ # Convert the feature map of conv2 to transformer embedding
+ # and concat with class token.
+ conv2_embedding = self.squeeze_block(x_conv2, trans_input)
+
+ trans_output = self.trans_block(conv2_embedding + trans_input)
+
+ # Convert the transformer output embedding to feature map
+ trans_features = self.expand_block(trans_output, H // self.down_stride,
+ W // self.down_stride)
+ x = self.fusion_block(
+ x, fusion_features=trans_features, out_conv2=False)
+
+ return x, trans_output
+
+
+@BACKBONES.register_module()
+class Conformer(BaseBackbone):
+ """Conformer backbone.
+
+ A PyTorch implementation of: `Conformer: Local Features Coupling Global
+ Representations for Visual Recognition
+ <https://arxiv.org/abs/2105.03889>`_
+
+ Args:
+ arch (str | dict): Conformer architecture. Defaults to 'tiny'.
+ patch_size (int): The patch size. Defaults to 16.
+ base_channels (int): The base number of channels in CNN network.
+ Defaults to 64.
+ mlp_ratio (float): The expansion ratio of FFN network in transformer
+ block. Defaults to 4.
+ qkv_bias (bool): Whether to add bias for the qkv projection in the
+ attention modules. Defaults to True.
+ with_cls_token (bool): Whether to use the class token or not.
+ Defaults to True.
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+ out_indices (Sequence | int): Output from which stages.
+ Defaults to -1, which means the last stage.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
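+
+ Example:
+ >>> # A minimal forward sketch; the shapes follow the default 'tiny'
+ >>> # arch, and each output item is ``[conv_feature, trans_feature]``.
+ >>> import torch
+ >>> from mmcls.models.backbones import Conformer
+ >>> model = Conformer(arch='tiny')
+ >>> inputs = torch.rand(1, 3, 224, 224)
+ >>> conv_feat, trans_feat = model(inputs)[-1]
+ >>> conv_feat.shape, trans_feat.shape
+ (torch.Size([1, 256]), torch.Size([1, 384]))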
+ """
+ arch_zoo = {
+ **dict.fromkeys(['t', 'tiny'],
+ {'embed_dims': 384,
+ 'channel_ratio': 1,
+ 'num_heads': 6,
+ 'depths': 12
+ }),
+ **dict.fromkeys(['s', 'small'],
+ {'embed_dims': 384,
+ 'channel_ratio': 4,
+ 'num_heads': 6,
+ 'depths': 12
+ }),
+ **dict.fromkeys(['b', 'base'],
+ {'embed_dims': 576,
+ 'channel_ratio': 6,
+ 'num_heads': 9,
+ 'depths': 12
+ }),
+ } # yapf: disable
+
+ _version = 1
+
+ def __init__(self,
+ arch='tiny',
+ patch_size=16,
+ base_channels=64,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ with_cls_token=True,
+ drop_path_rate=0.,
+ norm_eval=True,
+ frozen_stages=0,
+ out_indices=-1,
+ init_cfg=None):
+
+ super().__init__(init_cfg=init_cfg)
+
+ if isinstance(arch, str):
+ arch = arch.lower()
+ assert arch in set(self.arch_zoo), \
+ f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+ self.arch_settings = self.arch_zoo[arch]
+ else:
+ essential_keys = {
+ 'embed_dims', 'depths', 'num_heads', 'channel_ratio'
+ }
+ assert isinstance(arch, dict) and set(arch) == essential_keys, \
+ f'Custom arch needs a dict with keys {essential_keys}'
+ self.arch_settings = arch
+
+ self.num_features = self.embed_dims = self.arch_settings['embed_dims']
+ self.depths = self.arch_settings['depths']
+ self.num_heads = self.arch_settings['num_heads']
+ self.channel_ratio = self.arch_settings['channel_ratio']
+
+ if isinstance(out_indices, int):
+ out_indices = [out_indices]
+ assert isinstance(out_indices, Sequence), \
+ f'"out_indices" must by a sequence or int, ' \
+ f'get {type(out_indices)} instead.'
+ for i, index in enumerate(out_indices):
+ if index < 0:
+ out_indices[i] = self.depths + index + 1
+ assert out_indices[i] >= 0, f'Invalid out_indices {index}'
+ self.out_indices = out_indices
+
+ self.norm_eval = norm_eval
+ self.frozen_stages = frozen_stages
+
+ self.with_cls_token = with_cls_token
+ if self.with_cls_token:
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+
+ # stochastic depth decay rule
+ self.trans_dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, self.depths)
+ ]
+
+ # Stem stage: get the feature maps by conv block
+ self.conv1 = nn.Conv2d(
+ 3, 64, kernel_size=7, stride=2, padding=3,
+ bias=False) # 1 / 2 [112, 112]
+ self.bn1 = nn.BatchNorm2d(64)
+ self.act1 = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(
+ kernel_size=3, stride=2, padding=1) # 1 / 4 [56, 56]
+
+ # 1 stage
+ stage1_channels = int(base_channels * self.channel_ratio)
+ trans_down_stride = patch_size // 4
+ self.conv_1 = ConvBlock(
+ in_channels=64,
+ out_channels=stage1_channels,
+ with_residual_conv=True,
+ stride=1)
+ self.trans_patch_conv = nn.Conv2d(
+ 64,
+ self.embed_dims,
+ kernel_size=trans_down_stride,
+ stride=trans_down_stride,
+ padding=0)
+
+ self.trans_1 = TransformerEncoderLayer(
+ embed_dims=self.embed_dims,
+ num_heads=self.num_heads,
+ feedforward_channels=int(self.embed_dims * mlp_ratio),
+ drop_path_rate=self.trans_dpr[0],
+ qkv_bias=qkv_bias,
+ norm_cfg=dict(type='LN', eps=1e-6))
+
+ # 2~4 stage
+ init_stage = 2
+ fin_stage = self.depths // 3 + 1
+ for i in range(init_stage, fin_stage):
+ self.add_module(
+ f'conv_trans_{i}',
+ ConvTransBlock(
+ in_channels=stage1_channels,
+ out_channels=stage1_channels,
+ embed_dims=self.embed_dims,
+ conv_stride=1,
+ with_residual_conv=False,
+ down_stride=trans_down_stride,
+ num_heads=self.num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop_path_rate=self.trans_dpr[i - 1],
+ with_cls_token=self.with_cls_token))
+
+ stage2_channels = int(base_channels * self.channel_ratio * 2)
+ # 5~8 stage
+ init_stage = fin_stage # 5
+ fin_stage = fin_stage + self.depths // 3 # 9
+ for i in range(init_stage, fin_stage):
+ if i == init_stage:
+ conv_stride = 2
+ in_channels = stage1_channels
+ else:
+ conv_stride = 1
+ in_channels = stage2_channels
+
+ with_residual_conv = True if i == init_stage else False
+ self.add_module(
+ f'conv_trans_{i}',
+ ConvTransBlock(
+ in_channels=in_channels,
+ out_channels=stage2_channels,
+ embed_dims=self.embed_dims,
+ conv_stride=conv_stride,
+ with_residual_conv=with_residual_conv,
+ down_stride=trans_down_stride // 2,
+ num_heads=self.num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop_path_rate=self.trans_dpr[i - 1],
+ with_cls_token=self.with_cls_token))
+
+ stage3_channels = int(base_channels * self.channel_ratio * 2 * 2)
+ # 9~12 stage
+ init_stage = fin_stage # 9
+ fin_stage = fin_stage + self.depths // 3 # 13
+ for i in range(init_stage, fin_stage):
+ if i == init_stage:
+ conv_stride = 2
+ in_channels = stage2_channels
+ with_residual_conv = True
+ else:
+ conv_stride = 1
+ in_channels = stage3_channels
+ with_residual_conv = False
+
+ last_fusion = (i == self.depths)
+
+ self.add_module(
+ f'conv_trans_{i}',
+ ConvTransBlock(
+ in_channels=in_channels,
+ out_channels=stage3_channels,
+ embed_dims=self.embed_dims,
+ conv_stride=conv_stride,
+ with_residual_conv=with_residual_conv,
+ down_stride=trans_down_stride // 4,
+ num_heads=self.num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop_path_rate=self.trans_dpr[i - 1],
+ with_cls_token=self.with_cls_token,
+ last_fusion=last_fusion))
+ self.fin_stage = fin_stage
+
+ self.pooling = nn.AdaptiveAvgPool2d(1)
+ self.trans_norm = nn.LayerNorm(self.embed_dims)
+
+ if self.with_cls_token:
+ trunc_normal_(self.cls_token, std=.02)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(
+ m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1.)
+ nn.init.constant_(m.bias, 0.)
+
+ if hasattr(m, 'zero_init_last_bn'):
+ m.zero_init_last_bn()
+
+ def init_weights(self):
+ super(Conformer, self).init_weights()
+ logger = get_root_logger()
+
+ if (isinstance(self.init_cfg, dict)
+ and self.init_cfg['type'] == 'Pretrained'):
+ # Suppress the default init if a pretrained model is used.
+ return
+ else:
+ logger.info(f'No pre-trained weights for '
+ f'{self.__class__.__name__}, '
+ f'training starts from scratch')
+ self.apply(self._init_weights)
+
+ def forward(self, x):
+ output = []
+ B = x.shape[0]
+ if self.with_cls_token:
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+
+ # stem
+ x_base = self.maxpool(self.act1(self.bn1(self.conv1(x))))
+
+ # 1 stage [N, 64, 56, 56] -> [N, 128, 56, 56]
+ x = self.conv_1(x_base, out_conv2=False)
+ x_t = self.trans_patch_conv(x_base).flatten(2).transpose(1, 2)
+ if self.with_cls_token:
+ x_t = torch.cat([cls_tokens, x_t], dim=1)
+ x_t = self.trans_1(x_t)
+
+ # 2 ~ final
+ for i in range(2, self.fin_stage):
+ stage = getattr(self, f'conv_trans_{i}')
+ x, x_t = stage(x, x_t)
+ if i in self.out_indices:
+ if self.with_cls_token:
+ output.append([
+ self.pooling(x).flatten(1),
+ self.trans_norm(x_t)[:, 0]
+ ])
+ else:
+ # if no class token, use the mean patch token
+ # as the transformer feature.
+ output.append([
+ self.pooling(x).flatten(1),
+ self.trans_norm(x_t).mean(dim=1)
+ ])
+
+ return tuple(output)
diff --git a/mmcls/models/heads/__init__.py b/mmcls/models/heads/__init__.py
index 7711272a..4be4daf8 100644
--- a/mmcls/models/heads/__init__.py
+++ b/mmcls/models/heads/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cls_head import ClsHead
+from .conformer_head import ConformerHead
from .linear_head import LinearClsHead
from .multi_label_head import MultiLabelClsHead
from .multi_label_linear_head import MultiLabelLinearClsHead
@@ -8,5 +9,5 @@ from .vision_transformer_head import VisionTransformerClsHead
__all__ = [
'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead',
- 'MultiLabelLinearClsHead', 'VisionTransformerClsHead'
+ 'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'ConformerHead'
]
diff --git a/mmcls/models/heads/conformer_head.py b/mmcls/models/heads/conformer_head.py
new file mode 100644
index 00000000..c913b657
--- /dev/null
+++ b/mmcls/models/heads/conformer_head.py
@@ -0,0 +1,103 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn.utils.weight_init import trunc_normal_
+
+from ..builder import HEADS
+from .cls_head import ClsHead
+
+
+@HEADS.register_module()
+class ConformerHead(ClsHead):
+ """Linear classifier head.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ init_cfg (dict | optional): The extra init config of layers.
+ Defaults to use dict(type='Normal', layer='Linear', std=0.01).
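+
+ Example:
+ >>> # A minimal scoring sketch; the two inputs stand for the pooled CNN
+ >>> # feature and the transformer class-token feature of a 'tiny'
+ >>> # Conformer backbone.
+ >>> import torch
+ >>> from mmcls.models.heads import ConformerHead
+ >>> head = ConformerHead(num_classes=1000, in_channels=[256, 384])
+ >>> conv_feat, trans_feat = torch.rand(2, 256), torch.rand(2, 384)
+ >>> score = head.conv_cls_head(conv_feat) + head.trans_cls_head(trans_feat)
+ >>> score.shape
+ torch.Size([2, 1000])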
+ """
+
+ def __init__(
+ self,
+ num_classes,
+ in_channels, # [conv_dim, trans_dim]
+ init_cfg=dict(type='Normal', layer='Linear', std=0.01),
+ *args,
+ **kwargs):
+ super(ConformerHead, self).__init__(init_cfg=None, *args, **kwargs)
+
+ self.in_channels = in_channels
+ self.num_classes = num_classes
+ self.init_cfg = init_cfg
+
+ if self.num_classes <= 0:
+ raise ValueError(
+ f'num_classes={num_classes} must be a positive integer')
+
+ self.conv_cls_head = nn.Linear(self.in_channels[0], num_classes)
+ self.trans_cls_head = nn.Linear(self.in_channels[1], num_classes)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def init_weights(self):
+ super(ConformerHead, self).init_weights()
+
+ if (isinstance(self.init_cfg, dict)
+ and self.init_cfg['type'] == 'Pretrained'):
+ # Suppress the default init if a pretrained model is used.
+ return
+ else:
+ self.apply(self._init_weights)
+
+ def simple_test(self, x):
+ """Test without augmentation."""
+ if isinstance(x, tuple):
+ x = x[-1]
+ assert isinstance(x,
+ list) # There are two outputs in the Conformer model
+
+ conv_cls_score = self.conv_cls_head(x[0])
+ tran_cls_score = self.trans_cls_head(x[1])
+
+ cls_score = conv_cls_score + tran_cls_score
+
+ pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
+
+ return self.post_process(pred)
+
+ def forward_train(self, x, gt_label):
+ if isinstance(x, tuple):
+ x = x[-1]
+ assert isinstance(x, list) and len(x) == 2, \
+ 'There should be two outputs in the Conformer model'
+
+ conv_cls_score = self.conv_cls_head(x[0])
+ tran_cls_score = self.trans_cls_head(x[1])
+
+ losses = self.loss([conv_cls_score, tran_cls_score], gt_label)
+ return losses
+
+ def loss(self, cls_score, gt_label):
+ num_samples = len(cls_score[0])
+ losses = dict()
+ # compute loss
+ loss = sum([
+ self.compute_loss(score, gt_label, avg_factor=num_samples) /
+ len(cls_score) for score in cls_score
+ ])
+ if self.cal_acc:
+ # compute accuracy
+ acc = self.compute_accuracy(cls_score[0] + cls_score[1], gt_label)
+ assert len(acc) == len(self.topk)
+ losses['accuracy'] = {
+ f'top-{k}': a
+ for k, a in zip(self.topk, acc)
+ }
+ losses['loss'] = loss
+ return losses
diff --git a/model-index.yml b/model-index.yml
index 4da36a3a..f8b8b4d5 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -13,3 +13,4 @@ Import:
- configs/vision_transformer/metafile.yml
- configs/t2t_vit/metafile.yml
- configs/mlp_mixer/metafile.yml
+ - configs/conformer/metafile.yml
diff --git a/tests/test_models/test_backbones/test_conformer.py b/tests/test_models/test_backbones/test_conformer.py
new file mode 100644
index 00000000..9264256f
--- /dev/null
+++ b/tests/test_models/test_backbones/test_conformer.py
@@ -0,0 +1,92 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
+
+import pytest
+import torch
+from torch.nn.modules import GroupNorm
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmcls.models.backbones import Conformer
+
+
+def is_norm(modules):
+ """Check if is one of the norms."""
+ if isinstance(modules, (GroupNorm, _BatchNorm)):
+ return True
+ return False
+
+
+def check_norm_state(modules, train_state):
+ """Check if norm layer is in correct train state."""
+ for mod in modules:
+ if isinstance(mod, _BatchNorm):
+ if mod.training != train_state:
+ return False
+ return True
+
+
+def test_conformer_backbone():
+
+ cfg_ori = dict(
+ arch='T',
+ drop_path_rate=0.1,
+ )
+
+ with pytest.raises(AssertionError):
+ # test invalid arch
+ cfg = deepcopy(cfg_ori)
+ cfg['arch'] = 'unknown'
+ Conformer(**cfg)
+
+ with pytest.raises(AssertionError):
+ # test arch without essential keys
+ cfg = deepcopy(cfg_ori)
+ cfg['arch'] = {'embed_dims': 24, 'channel_ratio': 6, 'num_heads': 9}
+ Conformer(**cfg)
+
+ # Test the Conformer-tiny model with a patch size of 16
+ model = Conformer(**cfg_ori)
+ model.init_weights()
+ model.train()
+
+ assert check_norm_state(model.modules(), True)
+
+ imgs = torch.randn(3, 3, 224, 224)
+ conv_feature, transformer_feature = model(imgs)[-1]
+ assert conv_feature.shape == (3, 64 * 1 * 4
+ ) # base_channels * channel_ratio * 4
+ assert transformer_feature.shape == (3, 384)
+
+ # Test custom arch Conformer without output cls token
+ cfg = deepcopy(cfg_ori)
+ cfg['arch'] = {
+ 'embed_dims': 128,
+ 'depths': 15,
+ 'num_heads': 16,
+ 'channel_ratio': 3,
+ }
+ cfg['with_cls_token'] = False
+ cfg['base_channels'] = 32
+ model = Conformer(**cfg)
+ conv_feature, transformer_feature = model(imgs)[-1]
+ assert conv_feature.shape == (3, 32 * 3 * 4)
+ assert transformer_feature.shape == (3, 128)
+
+ # Test Conformer with multiple out_indices
+ cfg = deepcopy(cfg_ori)
+ cfg['out_indices'] = [4, 8, 12]
+ model = Conformer(**cfg)
+ outs = model(imgs)
+ assert len(outs) == 3
+ # stage 1
+ conv_feature, transformer_feature = outs[0]
+ assert conv_feature.shape == (3, 64 * 1)
+ assert transformer_feature.shape == (3, 384)
+ # stage 2
+ conv_feature, transformer_feature = outs[1]
+ assert conv_feature.shape == (3, 64 * 1 * 2)
+ assert transformer_feature.shape == (3, 384)
+ # stage 3
+ conv_feature, transformer_feature = outs[2]
+ assert conv_feature.shape == (3, 64 * 1 * 4)
+ assert transformer_feature.shape == (3, 384)