[Feature] Implement the Conformer backbone. (#494)

* Implement the Conformer backbone

* Format code style

* Format code style

* Reuse the TransformerEncoderLayer from vision_transformer.py

* Modify variable name

* Delete unused params

* Remove warning info in Conformer head since it already exists in
Conformer.

* Rename some variables

* Add unit tests

* Use `getattr` instead of `get_submodule`.

* Remove some useless layers

* Refactor conformer and add configs

* Update configs and add metafile.

* Fix unit tests

* Update README

Co-authored-by: mzr1996 <mzr1996@163.com>
Zhiliang Peng 2021-12-07 14:00:17 +08:00 committed by GitHub
parent 0aa789f3c3
commit 18f6bb0b10
19 changed files with 1095 additions and 3 deletions

@@ -0,0 +1,22 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='Conformer', arch='base', drop_path_rate=0.1, init_cfg=None),
neck=None,
head=dict(
type='ConformerHead',
num_classes=1000,
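# in_channels is [conv branch channels, transformer branch channels];
# for the 'base' arch: 64 * channel_ratio (6) * 4 = 1536 and embed_dims = 576.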
in_channels=[1536, 576],
init_cfg=None,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
cal_acc=False),
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))

@@ -0,0 +1,22 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='Conformer', arch='small', drop_path_rate=0.1, init_cfg=None),
neck=None,
head=dict(
type='ConformerHead',
num_classes=1000,
in_channels=[1024, 384],
init_cfg=None,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
cal_acc=False),
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))

@@ -0,0 +1,26 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='Conformer',
arch='small',
patch_size=32,
drop_path_rate=0.1,
init_cfg=None),
neck=None,
head=dict(
type='ConformerHead',
num_classes=1000,
in_channels=[1024, 384],
init_cfg=None,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
cal_acc=False),
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))

@@ -0,0 +1,22 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='Conformer', arch='tiny', drop_path_rate=0.1, init_cfg=None),
neck=None,
head=dict(
type='ConformerHead',
num_classes=1000,
in_channels=[256, 384],
init_cfg=None,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
cal_acc=False),
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))

@@ -0,0 +1,29 @@
paramwise_cfg = dict(
norm_decay_mult=0.0,
bias_decay_mult=0.0,
custom_keys={
'.cls_token': dict(decay_mult=0.0),
})
# The batch size is 128 per GPU, with 8 GPUs (total batch size 1024).
# lr = 5e-4 * 128 * 8 / 512 = 0.001
optimizer = dict(
type='AdamW',
lr=5e-4 * 128 * 8 / 512,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999),
paramwise_cfg=paramwise_cfg)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
policy='CosineAnnealing',
by_epoch=False,
min_lr_ratio=1e-2,
warmup='linear',
warmup_ratio=1e-3,
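# 5 warmup epochs; ImageNet-1k has roughly 1252 iterations per epoch
# at a total batch size of 1024.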
warmup_iters=5 * 1252,
warmup_by_epoch=False)
runner = dict(type='EpochBasedRunner', max_epochs=300)

@@ -0,0 +1,39 @@
# Conformer: Local Features Coupling Global Representations for Visual Recognition
<!-- {Conformer} -->
<!-- [ALGORITHM] -->
## Abstract
<!-- [ABSTRACT] -->
Within Convolutional Neural Networks (CNNs), the convolution operations are good at extracting local features but have difficulty capturing global representations. Within visual transformers, the cascaded self-attention modules can capture long-distance feature dependencies but unfortunately deteriorate local feature details. In this paper, we propose a hybrid network structure, termed Conformer, to take advantage of both convolutional operations and self-attention mechanisms for enhanced representation learning. Conformer is rooted in the Feature Coupling Unit (FCU), which fuses local features and global representations at different resolutions in an interactive fashion. Conformer adopts a concurrent structure so that local features and global representations are retained to the maximum extent. Experiments show that Conformer, under comparable parameter complexity, outperforms the visual transformer (DeiT-B) by 2.3% on ImageNet. On MSCOCO, it outperforms ResNet-101 by 3.7% and 3.6% mAP for object detection and instance segmentation, respectively, demonstrating its great potential to serve as a general backbone network.
<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/26739999/144957687-926390ed-6119-4e4c-beaa-9bc0017fe953.png" width="90%"/>
</div>
## Citation
```latex
@article{peng2021conformer,
title={Conformer: Local Features Coupling Global Representations for Visual Recognition},
author={Zhiliang Peng and Wei Huang and Shanzhi Gu and Lingxi Xie and Yaowei Wang and Jianbin Jiao and Qixiang Ye},
journal={arXiv preprint arXiv:2105.03889},
year={2021},
}
```
## Results and models
Some pre-trained models are converted from [official repo](https://github.com/pengzhiliang/Conformer).
### ImageNet-1k
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
| Conformer-tiny-p16\* | 23.52 | 4.90 | 81.31 | 95.60 | [config](configs/conformer/conformer-tiny-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth) |
| Conformer-small-p32 | 38.85 | 7.09 | 81.96 | 96.02 | [config](configs/conformer/conformer-small-p32_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth) |
| Conformer-small-p16\* | 37.67 | 10.31 | 83.32 | 96.46 | [config](configs/conformer/conformer-small-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth) |
| Conformer-base-p16\* | 83.29 | 22.89 | 83.82 | 96.59 | [config](configs/conformer/conformer-base-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth) |
*Models with \* are converted from other repos.*
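For a quick sanity check, the backbone added in this PR can be built through the model registry. The snippet below is a minimal sketch, assuming the standard `mmcls.models.build_backbone` helper; each output stage is a `[conv_feature, transformer_feature]` pair, as implemented in `conformer.py`.

```python
import torch

from mmcls.models import build_backbone

# Minimal sketch: build the 'tiny' Conformer backbone and run a dummy forward pass.
backbone = build_backbone(dict(type='Conformer', arch='tiny', drop_path_rate=0.1))
backbone.init_weights()
backbone.eval()

with torch.no_grad():
    # By default only the last stage is returned; each stage output is a
    # [conv_feature, transformer_feature] pair.
    conv_feat, trans_feat = backbone(torch.randn(1, 3, 224, 224))[-1]

print(conv_feat.shape)   # torch.Size([1, 256]): base_channels * channel_ratio * 4
print(trans_feat.shape)  # torch.Size([1, 384]): embed_dims of the 'tiny' arch
```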

@@ -0,0 +1,9 @@
_base_ = [
'../_base_/models/conformer/base-p16.py',
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_conformer.py',
'../_base_/default_runtime.py'
]
data = dict(samples_per_gpu=128)
evaluation = dict(interval=1, metric='accuracy')

@@ -0,0 +1,9 @@
_base_ = [
'../_base_/models/conformer/small-p16.py',
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_conformer.py',
'../_base_/default_runtime.py'
]
data = dict(samples_per_gpu=128)
evaluation = dict(interval=1, metric='accuracy')

@@ -0,0 +1,9 @@
_base_ = [
'../_base_/models/conformer/small-p32.py',
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_conformer.py',
'../_base_/default_runtime.py'
]
data = dict(samples_per_gpu=128)
evaluation = dict(interval=1, metric='accuracy')

@@ -0,0 +1,9 @@
_base_ = [
'../_base_/models/conformer/tiny-p16.py',
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_conformer.py',
'../_base_/default_runtime.py'
]
data = dict(samples_per_gpu=128)
evaluation = dict(interval=1, metric='accuracy')

@@ -0,0 +1,78 @@
Collections:
- Name: Conformer
Metadata:
Training Data: ImageNet-1k
Architecture:
- Layer Normalization
- Scaled Dot-Product Attention
- Dropout
Paper:
URL: https://arxiv.org/abs/2105.03889
Title: "Conformer: Local Features Coupling Global Representations for Visual Recognition"
README: configs/conformer/README.md
# Code:
# URL: # todo
# Version: # todo
Models:
- Name: conformer-tiny-p16_3rdparty_8xb128_in1k
In Collection: Conformer
Config: configs/conformer/conformer-tiny-p16_8xb128_in1k.py
Metadata:
FLOPs: 4899611328
Parameters: 23524704
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 81.31
Top 5 Accuracy: 95.60
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth
Converted From:
Weights: https://drive.google.com/file/d/19SxGhKcWOR5oQSxNUWUM2MGYiaWMrF1z/view?usp=sharing
Code: https://github.com/pengzhiliang/Conformer/blob/main/models.py#L65
- Name: conformer-small-p16_3rdparty_8xb128_in1k
In Collection: Conformer
Config: configs/conformer/conformer-small-p16_8xb128_in1k.py
Metadata:
FLOPs: 10311309312
Parameters: 37673424
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.32
Top 5 Accuracy: 96.46
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth
Converted From:
Weights: https://drive.google.com/file/d/1mpOlbLaVxOfEwV4-ha78j_1Ebqzj2B83/view?usp=sharing
Code: https://github.com/pengzhiliang/Conformer/blob/main/models.py#L73
- Name: conformer-small-p32_8xb128_in1k
In Collection: Conformer
Config: configs/conformer/conformer-small-p32_8xb128_in1k.py
Metadata:
FLOPs: 7087281792
Parameters: 38853072
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 81.96
Top 5 Accuracy: 96.02
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth
- Name: conformer-base-p16_3rdparty_8xb128_in1k
In Collection: Conformer
Config: configs/conformer/conformer-base-p16_8xb128_in1k.py
Metadata:
FLOPs: 22892078080
Parameters: 83289136
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.82
Top 5 Accuracy: 96.59
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth
Converted From:
Weights: https://drive.google.com/file/d/1oeQ9LSOGKEUaYGu7WTlUGl3KDsQIi0MA/view?usp=sharing
Code: https://github.com/pengzhiliang/Conformer/blob/main/models.py#L89

@@ -23,7 +23,7 @@ Transformers, which are popular for language modeling, have been explored for so
## Pretrain model
The pre-trained modles are converted from [official repo](https://github.com/yitu-opensource/T2T-ViT/tree/main#2-t2t-vit-models).
The pre-trained models are converted from [official repo](https://github.com/yitu-opensource/T2T-ViT/tree/main#2-t2t-vit-models).
### ImageNet-1k

@@ -63,6 +63,10 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
| T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth) &#124; [log]()|
| Mixer-B/16\* | 59.88 | 12.61 | 76.68 | 92.25 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth) &#124; [log]()|
| Mixer-L/16\* | 208.2 | 44.57 | 72.34 | 88.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth) &#124; [log]()|
| Conformer-tiny-p16\* | 23.52 | 4.90 | 81.31 | 95.60 | [config](configs/conformer/conformer-tiny-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth) &#124; [log]()|
| Conformer-small-p32 | 38.85 | 7.09 | 81.96 | 96.02 | [config](configs/conformer/conformer-small-p32_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth) &#124; [log]()|
| Conformer-small-p16\* | 37.67 | 10.31 | 83.32 | 96.46 | [config](configs/conformer/conformer-small-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth) &#124; [log]()|
| Conformer-base-p16\* | 83.29 | 22.89 | 83.82 | 96.59 | [config](configs/conformer/conformer-base-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth) &#124; [log]()|
Models with * are converted from other repos, others are trained by ourselves.

@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .alexnet import AlexNet
from .conformer import Conformer
from .lenet import LeNet5
from .mlp_mixer import MlpMixer
from .mobilenet_v2 import MobileNetV2
@@ -27,5 +28,5 @@ __all__ = [
'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
'MlpMixer'
'Conformer', 'MlpMixer'
]

@@ -0,0 +1,616 @@
from typing import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.drop import DropPath
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcls.utils import get_root_logger
from ..builder import BACKBONES
from .base_backbone import BaseBackbone, BaseModule
from .vision_transformer import TransformerEncoderLayer
class ConvBlock(BaseModule):
"""Basic convluation block used in Conformer.
This block includes three convluation modules, and supports three new
functions:
1. Returns the output of both the final layers and the second convluation
module.
2. Fuses the input of the second convluation module with an extra input
feature map.
3. Supports to add an extra convluation module to the identity connection.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
stride (int): The stride of the second convolution module.
Defaults to 1.
groups (int): The groups of the second convolution module.
Defaults to 1.
drop_path_rate (float): The rate of the DropPath layer. Defaults to 0.
with_residual_conv (bool): Whether to add an extra convolution module
to the identity connection. Defaults to False.
norm_cfg (dict): The config of normalization layers.
Defaults to ``dict(type='BN', eps=1e-6)``.
act_cfg (dict): The config of activation functions.
Defaults to ``dict(type='ReLU', inplace=True)``.
init_cfg (dict, optional): The extra config to initialize the module.
Defaults to None.
"""
def __init__(self,
in_channels,
out_channels,
stride=1,
groups=1,
drop_path_rate=0.,
with_residual_conv=False,
norm_cfg=dict(type='BN', eps=1e-6),
act_cfg=dict(type='ReLU', inplace=True),
init_cfg=None):
super(ConvBlock, self).__init__(init_cfg=init_cfg)
expansion = 4
mid_channels = out_channels // expansion
self.conv1 = nn.Conv2d(
in_channels,
mid_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False)
self.bn1 = build_norm_layer(norm_cfg, mid_channels)[1]
self.act1 = build_activation_layer(act_cfg)
self.conv2 = nn.Conv2d(
mid_channels,
mid_channels,
kernel_size=3,
stride=stride,
groups=groups,
padding=1,
bias=False)
self.bn2 = build_norm_layer(norm_cfg, mid_channels)[1]
self.act2 = build_activation_layer(act_cfg)
self.conv3 = nn.Conv2d(
mid_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False)
self.bn3 = build_norm_layer(norm_cfg, out_channels)[1]
self.act3 = build_activation_layer(act_cfg)
if with_residual_conv:
self.residual_conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=stride,
padding=0,
bias=False)
self.residual_bn = build_norm_layer(norm_cfg, out_channels)[1]
self.with_residual_conv = with_residual_conv
self.drop_path = DropPath(
drop_path_rate) if drop_path_rate > 0. else nn.Identity()
def zero_init_last_bn(self):
nn.init.zeros_(self.bn3.weight)
def forward(self, x, fusion_features=None, out_conv2=True):
identity = x
x = self.conv1(x)
x = self.bn1(x)
x = self.act1(x)
x = self.conv2(x) if fusion_features is None else self.conv2(
x + fusion_features)
x = self.bn2(x)
x2 = self.act2(x)
x = self.conv3(x2)
x = self.bn3(x)
if self.drop_path is not None:
x = self.drop_path(x)
if self.with_residual_conv:
identity = self.residual_conv(identity)
identity = self.residual_bn(identity)
x += identity
x = self.act3(x)
if out_conv2:
return x, x2
else:
return x
class FCUDown(BaseModule):
"""CNN feature maps -> Transformer patch embeddings."""
def __init__(self,
in_channels,
out_channels,
down_stride,
with_cls_token=True,
norm_cfg=dict(type='LN', eps=1e-6),
act_cfg=dict(type='GELU'),
init_cfg=None):
super(FCUDown, self).__init__(init_cfg=init_cfg)
self.down_stride = down_stride
self.with_cls_token = with_cls_token
self.conv_project = nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.sample_pooling = nn.AvgPool2d(
kernel_size=down_stride, stride=down_stride)
self.ln = build_norm_layer(norm_cfg, out_channels)[1]
self.act = build_activation_layer(act_cfg)
def forward(self, x, x_t):
x = self.conv_project(x) # [N, C, H, W]
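# avg-pool by down_stride, then flatten to patch tokens:
# [N, C, H, W] -> [N, (H/down_stride)*(W/down_stride), C]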
x = self.sample_pooling(x).flatten(2).transpose(1, 2)
x = self.ln(x)
x = self.act(x)
if self.with_cls_token:
x = torch.cat([x_t[:, 0][:, None, :], x], dim=1)
return x
class FCUUp(BaseModule):
"""Transformer patch embeddings -> CNN feature maps."""
def __init__(self,
in_channels,
out_channels,
up_stride,
with_cls_token=True,
norm_cfg=dict(type='BN', eps=1e-6),
act_cfg=dict(type='ReLU', inplace=True),
init_cfg=None):
super(FCUUp, self).__init__(init_cfg=init_cfg)
self.up_stride = up_stride
self.with_cls_token = with_cls_token
self.conv_project = nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.bn = build_norm_layer(norm_cfg, out_channels)[1]
self.act = build_activation_layer(act_cfg)
def forward(self, x, H, W):
B, _, C = x.shape
# [N, 197, 384] -> [N, 196, 384] -> [N, 384, 196] -> [N, 384, 14, 14]
if self.with_cls_token:
x_r = x[:, 1:].transpose(1, 2).reshape(B, C, H, W)
else:
x_r = x.transpose(1, 2).reshape(B, C, H, W)
x_r = self.act(self.bn(self.conv_project(x_r)))
return F.interpolate(
x_r, size=(H * self.up_stride, W * self.up_stride))
class ConvTransBlock(BaseModule):
"""Basic module for Conformer.
This module is a fusion of a CNN block and a transformer encoder block.
Args:
in_channels (int): The number of input channels in conv blocks.
out_channels (int): The number of output channels in conv blocks.
embed_dims (int): The embedding dimension in transformer blocks.
conv_stride (int): The stride of conv2d layers. Defaults to 1.
groups (int): The groups of conv blocks. Defaults to 1.
with_residual_conv (bool): Whether to add a conv-bn layer to the
identity connection in the conv block. Defaults to False.
down_stride (int): The stride of the downsample pooling layer.
Defaults to 4.
num_heads (int): The number of heads in transformer attention layers.
Defaults to 12.
mlp_ratio (float): The expansion ratio in transformer FFN module.
Defaults to 4.
qkv_bias (bool): Enable bias for qkv if True. Defaults to False.
with_cls_token (bool): Whether use class token or not.
Defaults to True.
drop_rate (float): The dropout rate of the output projection and
FFN in the transformer block. Defaults to 0.
attn_drop_rate (float): The dropout rate after the attention
calculation in the transformer block. Defaults to 0.
drop_path_rate (float): The drop path rate in both the conv block
and the transformer block. Defaults to 0.
last_fusion (bool): Whether this block is the last stage. If so,
the fusion feature map will be downsampled by 2. Defaults to False.
init_cfg (dict, optional): The extra config to initialize the module.
Defaults to None.
"""
def __init__(self,
in_channels,
out_channels,
embed_dims,
conv_stride=1,
groups=1,
with_residual_conv=False,
down_stride=4,
num_heads=12,
mlp_ratio=4.,
qkv_bias=False,
with_cls_token=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
last_fusion=False,
init_cfg=None):
super(ConvTransBlock, self).__init__(init_cfg=init_cfg)
expansion = 4
self.cnn_block = ConvBlock(
in_channels=in_channels,
out_channels=out_channels,
with_residual_conv=with_residual_conv,
stride=conv_stride,
groups=groups)
if last_fusion:
self.fusion_block = ConvBlock(
in_channels=out_channels,
out_channels=out_channels,
stride=2,
with_residual_conv=True,
groups=groups,
drop_path_rate=drop_path_rate)
else:
self.fusion_block = ConvBlock(
in_channels=out_channels,
out_channels=out_channels,
groups=groups,
drop_path_rate=drop_path_rate)
self.squeeze_block = FCUDown(
in_channels=out_channels // expansion,
out_channels=embed_dims,
down_stride=down_stride,
with_cls_token=with_cls_token)
self.expand_block = FCUUp(
in_channels=embed_dims,
out_channels=out_channels // expansion,
up_stride=down_stride,
with_cls_token=with_cls_token)
self.trans_block = TransformerEncoderLayer(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=int(embed_dims * mlp_ratio),
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
attn_drop_rate=attn_drop_rate,
qkv_bias=qkv_bias,
norm_cfg=dict(type='LN', eps=1e-6))
self.down_stride = down_stride
self.embed_dim = embed_dims
self.last_fusion = last_fusion
def forward(self, cnn_input, trans_input):
x, x_conv2 = self.cnn_block(cnn_input, out_conv2=True)
_, _, H, W = x_conv2.shape
# Convert the feature map of conv2 to transformer embedding
# and concat with class token.
conv2_embedding = self.squeeze_block(x_conv2, trans_input)
trans_output = self.trans_block(conv2_embedding + trans_input)
# Convert the transformer output embedding to feature map
trans_features = self.expand_block(trans_output, H // self.down_stride,
W // self.down_stride)
x = self.fusion_block(
x, fusion_features=trans_features, out_conv2=False)
return x, trans_output
@BACKBONES.register_module()
class Conformer(BaseBackbone):
"""Conformer backbone.
A PyTorch implementation of: `Conformer: Local Features Coupling Global
Representations for Visual Recognition <https://arxiv.org/abs/2105.03889>`_
Args:
arch (str | dict): Conformer architecture. Defaults to 'tiny'.
patch_size (int): The patch size. Defaults to 16.
base_channels (int): The base number of channels in CNN network.
Defaults to 64.
mlp_ratio (float): The expansion ratio of FFN network in transformer
block. Defaults to 4.
with_cls_token (bool): Whether use class token or not.
Defaults to True.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, which means the last stage.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
arch_zoo = {
**dict.fromkeys(['t', 'tiny'],
{'embed_dims': 384,
'channel_ratio': 1,
'num_heads': 6,
'depths': 12
}),
**dict.fromkeys(['s', 'small'],
{'embed_dims': 384,
'channel_ratio': 4,
'num_heads': 6,
'depths': 12
}),
**dict.fromkeys(['b', 'base'],
{'embed_dims': 576,
'channel_ratio': 6,
'num_heads': 9,
'depths': 12
}),
} # yapf: disable
_version = 1
def __init__(self,
arch='tiny',
patch_size=16,
base_channels=64,
mlp_ratio=4.,
qkv_bias=True,
with_cls_token=True,
drop_path_rate=0.,
norm_eval=True,
frozen_stages=0,
out_indices=-1,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {
'embed_dims', 'depths', 'num_heads', 'channel_ratio'
}
assert isinstance(arch, dict) and set(arch) == essential_keys, \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.num_features = self.embed_dims = self.arch_settings['embed_dims']
self.depths = self.arch_settings['depths']
self.num_heads = self.arch_settings['num_heads']
self.channel_ratio = self.arch_settings['channel_ratio']
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must by a sequence or int, ' \
f'get {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = self.depths + index + 1
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
self.out_indices = out_indices
self.norm_eval = norm_eval
self.frozen_stages = frozen_stages
self.with_cls_token = with_cls_token
if self.with_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
# stochastic depth decay rule
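# (the drop path rate grows linearly from 0 to drop_path_rate across blocks)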
self.trans_dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, self.depths)
]
# Stem stage: get the feature maps by conv block
self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3,
bias=False) # 1 / 2 [112, 112]
self.bn1 = nn.BatchNorm2d(64)
self.act1 = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(
kernel_size=3, stride=2, padding=1) # 1 / 4 [56, 56]
# 1 stage
stage1_channels = int(base_channels * self.channel_ratio)
trans_down_stride = patch_size // 4
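# Note: the stem already downsamples the input by 4, so a patch conv stride
# of patch_size // 4 gives patch embeddings at 1/patch_size of the input
# resolution (e.g. stride 4 for patch_size=16).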
self.conv_1 = ConvBlock(
in_channels=64,
out_channels=stage1_channels,
with_residual_conv=True,
stride=1)
self.trans_patch_conv = nn.Conv2d(
64,
self.embed_dims,
kernel_size=trans_down_stride,
stride=trans_down_stride,
padding=0)
self.trans_1 = TransformerEncoderLayer(
embed_dims=self.embed_dims,
num_heads=self.num_heads,
feedforward_channels=int(self.embed_dims * mlp_ratio),
drop_path_rate=self.trans_dpr[0],
qkv_bias=qkv_bias,
norm_cfg=dict(type='LN', eps=1e-6))
# 2~4 stage
init_stage = 2
fin_stage = self.depths // 3 + 1
for i in range(init_stage, fin_stage):
self.add_module(
f'conv_trans_{i}',
ConvTransBlock(
in_channels=stage1_channels,
out_channels=stage1_channels,
embed_dims=self.embed_dims,
conv_stride=1,
with_residual_conv=False,
down_stride=trans_down_stride,
num_heads=self.num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path_rate=self.trans_dpr[i - 1],
with_cls_token=self.with_cls_token))
stage2_channels = int(base_channels * self.channel_ratio * 2)
# 5~8 stage
init_stage = fin_stage # 5
fin_stage = fin_stage + self.depths // 3 # 9
for i in range(init_stage, fin_stage):
if i == init_stage:
conv_stride = 2
in_channels = stage1_channels
else:
conv_stride = 1
in_channels = stage2_channels
with_residual_conv = True if i == init_stage else False
self.add_module(
f'conv_trans_{i}',
ConvTransBlock(
in_channels=in_channels,
out_channels=stage2_channels,
embed_dims=self.embed_dims,
conv_stride=conv_stride,
with_residual_conv=with_residual_conv,
down_stride=trans_down_stride // 2,
num_heads=self.num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path_rate=self.trans_dpr[i - 1],
with_cls_token=self.with_cls_token))
stage3_channels = int(base_channels * self.channel_ratio * 2 * 2)
# 9~12 stage
init_stage = fin_stage # 9
fin_stage = fin_stage + self.depths // 3 # 13
for i in range(init_stage, fin_stage):
if i == init_stage:
conv_stride = 2
in_channels = stage2_channels
with_residual_conv = True
else:
conv_stride = 1
in_channels = stage3_channels
with_residual_conv = False
last_fusion = (i == self.depths)
self.add_module(
f'conv_trans_{i}',
ConvTransBlock(
in_channels=in_channels,
out_channels=stage3_channels,
embed_dims=self.embed_dims,
conv_stride=conv_stride,
with_residual_conv=with_residual_conv,
down_stride=trans_down_stride // 4,
num_heads=self.num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path_rate=self.trans_dpr[i - 1],
with_cls_token=self.with_cls_token,
last_fusion=last_fusion))
self.fin_stage = fin_stage
self.pooling = nn.AdaptiveAvgPool2d(1)
self.trans_norm = nn.LayerNorm(self.embed_dims)
if self.with_cls_token:
trunc_normal_(self.cls_token, std=.02)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.bias, 0.)
if hasattr(m, 'zero_init_last_bn'):
m.zero_init_last_bn()
def init_weights(self):
super(Conformer, self).init_weights()
logger = get_root_logger()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress default init if use pretrained model.
return
else:
logger.info(f'No pre-trained weights for '
f'{self.__class__.__name__}, '
f'training starts from scratch')
self.apply(self._init_weights)
def forward(self, x):
output = []
B = x.shape[0]
if self.with_cls_token:
cls_tokens = self.cls_token.expand(B, -1, -1)
# stem
x_base = self.maxpool(self.act1(self.bn1(self.conv1(x))))
# stage 1: [N, 64, 56, 56] -> [N, stage1_channels, 56, 56]
x = self.conv_1(x_base, out_conv2=False)
x_t = self.trans_patch_conv(x_base).flatten(2).transpose(1, 2)
if self.with_cls_token:
x_t = torch.cat([cls_tokens, x_t], dim=1)
x_t = self.trans_1(x_t)
# 2 ~ final
for i in range(2, self.fin_stage):
stage = getattr(self, f'conv_trans_{i}')
x, x_t = stage(x, x_t)
if i in self.out_indices:
if self.with_cls_token:
output.append([
self.pooling(x).flatten(1),
self.trans_norm(x_t)[:, 0]
])
else:
# if no class token, use the mean patch token
# as the transformer feature.
output.append([
self.pooling(x).flatten(1),
self.trans_norm(x_t).mean(dim=1)
])
return tuple(output)

@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cls_head import ClsHead
from .conformer_head import ConformerHead
from .linear_head import LinearClsHead
from .multi_label_head import MultiLabelClsHead
from .multi_label_linear_head import MultiLabelLinearClsHead
@@ -8,5 +9,5 @@ from .vision_transformer_head import VisionTransformerClsHead
__all__ = [
'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead',
'MultiLabelLinearClsHead', 'VisionTransformerClsHead'
'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'ConformerHead'
]

@@ -0,0 +1,103 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.utils.weight_init import trunc_normal_
from ..builder import HEADS
from .cls_head import ClsHead
@HEADS.register_module()
class ConformerHead(ClsHead):
"""Linear classifier head.
Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
init_cfg (dict | optional): The extra init config of layers.
Defaults to use dict(type='Normal', layer='Linear', std=0.01).
"""
def __init__(
self,
num_classes,
in_channels, # [conv_dim, trans_dim]
init_cfg=dict(type='Normal', layer='Linear', std=0.01),
*args,
**kwargs):
super(ConformerHead, self).__init__(init_cfg=None, *args, **kwargs)
self.in_channels = in_channels
self.num_classes = num_classes
self.init_cfg = init_cfg
if self.num_classes <= 0:
raise ValueError(
f'num_classes={num_classes} must be a positive integer')
self.conv_cls_head = nn.Linear(self.in_channels[0], num_classes)
self.trans_cls_head = nn.Linear(self.in_channels[1], num_classes)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
def init_weights(self):
super(ConformerHead, self).init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress default init if use pretrained model.
return
else:
self.apply(self._init_weights)
def simple_test(self, x):
"""Test without augmentation."""
if isinstance(x, tuple):
x = x[-1]
assert isinstance(x,
list) # There are two outputs in the Conformer model
conv_cls_score = self.conv_cls_head(x[0])
tran_cls_score = self.trans_cls_head(x[1])
cls_score = conv_cls_score + tran_cls_score
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
return self.post_process(pred)
def forward_train(self, x, gt_label):
if isinstance(x, tuple):
x = x[-1]
assert isinstance(x, list) and len(x) == 2, \
'There should be two outputs in the Conformer model'
conv_cls_score = self.conv_cls_head(x[0])
tran_cls_score = self.trans_cls_head(x[1])
losses = self.loss([conv_cls_score, tran_cls_score], gt_label)
return losses
def loss(self, cls_score, gt_label):
num_samples = len(cls_score[0])
losses = dict()
# average the classification loss over the two classifier heads
loss = sum([
self.compute_loss(score, gt_label, avg_factor=num_samples) /
len(cls_score) for score in cls_score
])
if self.cal_acc:
# compute accuracy
acc = self.compute_accuracy(cls_score[0] + cls_score[1], gt_label)
assert len(acc) == len(self.topk)
losses['accuracy'] = {
f'top-{k}': a
for k, a in zip(self.topk, acc)
}
losses['loss'] = loss
return losses

@@ -13,3 +13,4 @@ Import:
- configs/vision_transformer/metafile.yml
- configs/t2t_vit/metafile.yml
- configs/mlp_mixer/metafile.yml
- configs/conformer/metafile.yml

@@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import pytest
import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.backbones import Conformer
def is_norm(modules):
"""Check if is one of the norms."""
if isinstance(modules, (GroupNorm, _BatchNorm)):
return True
return False
def check_norm_state(modules, train_state):
"""Check if norm layer is in correct train state."""
for mod in modules:
if isinstance(mod, _BatchNorm):
if mod.training != train_state:
return False
return True
def test_conformer_backbone():
cfg_ori = dict(
arch='T',
drop_path_rate=0.1,
)
with pytest.raises(AssertionError):
# test invalid arch
cfg = deepcopy(cfg_ori)
cfg['arch'] = 'unknown'
Conformer(**cfg)
with pytest.raises(AssertionError):
# test arch without essential keys
cfg = deepcopy(cfg_ori)
cfg['arch'] = {'embed_dims': 24, 'channel_ratio': 6, 'num_heads': 9}
Conformer(**cfg)
# Test the Conformer tiny model with patch size of 16
model = Conformer(**cfg_ori)
model.init_weights()
model.train()
assert check_norm_state(model.modules(), True)
imgs = torch.randn(3, 3, 224, 224)
conv_feature, transformer_feature = model(imgs)[-1]
assert conv_feature.shape == (3, 64 * 1 * 4
) # base_channels * channel_ratio * 4
assert transformer_feature.shape == (3, 384)
# Test custom arch Conformer without output cls token
cfg = deepcopy(cfg_ori)
cfg['arch'] = {
'embed_dims': 128,
'depths': 15,
'num_heads': 16,
'channel_ratio': 3,
}
cfg['with_cls_token'] = False
cfg['base_channels'] = 32
model = Conformer(**cfg)
conv_feature, transformer_feature = model(imgs)[-1]
assert conv_feature.shape == (3, 32 * 3 * 4)
assert transformer_feature.shape == (3, 128)
# Test Conformer with multiple out indices
cfg = deepcopy(cfg_ori)
cfg['out_indices'] = [4, 8, 12]
model = Conformer(**cfg)
outs = model(imgs)
assert len(outs) == 3
# stage 1
conv_feature, transformer_feature = outs[0]
assert conv_feature.shape == (3, 64 * 1)
assert transformer_feature.shape == (3, 384)
# stage 2
conv_feature, transformer_feature = outs[1]
assert conv_feature.shape == (3, 64 * 1 * 2)
assert transformer_feature.shape == (3, 384)
# stage 3
conv_feature, transformer_feature = outs[2]
assert conv_feature.shape == (3, 64 * 1 * 4)
assert transformer_feature.shape == (3, 384)