[Feature] Add efficientformer Backbone for MMCls 1.x. (#1031)

* rebase

* update filename

* update URL

* update UT

* fix lint

* update head

* add efficientformer

* update filename

* update UT

* fix lint

* update configs

* rebase

* fix unit tests

* Fix comments and docs.

Co-authored-by: mzr1996 <mzr1996@163.com>
Ezra-Yu 2022-09-20 14:56:45 +08:00 committed by GitHub
parent f1d2f50c21
commit aaf127c5e1
20 changed files with 1180 additions and 7 deletions

View File

@@ -133,6 +133,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
- [x] [PoolFormer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/poolformer)
- [x] [Inception V3](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/inception_v3)
- [x] [MobileOne](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobileone)
- [x] [EfficientFormer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/efficientformer)
</details>

View File

@@ -132,6 +132,7 @@ mim install -e .
- [x] [PoolFormer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/poolformer)
- [x] [Inception V3](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/inception_v3)
- [x] [MobileOne](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobileone)
- [x] [EfficientFormer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/efficientformer)
</details>

View File

@@ -0,0 +1,18 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='EfficientFormer',
arch='l1',
drop_path_rate=0,
init_cfg=[
dict(
type='TruncNormal',
layer=['Conv2d', 'Linear'],
std=.02,
bias=0.),
dict(type='Constant', layer=['GroupNorm'], val=1., bias=0.),
dict(type='Constant', layer=['LayerScale'], val=1e-5)
]),
neck=dict(type='GlobalAveragePooling', dim=1),
head=dict(
type='EfficientFormerClsHead', in_channels=448, num_classes=1000))
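# Illustrative usage (not part of the diff): assuming mmcls 1.x is installed
# and this file is saved as configs/_base_/models/efficientformer-l1.py, the
# model above can be built through the registry:
#
#     import torch
#     from mmengine import Config
#     from mmcls.utils import register_all_modules
#     from mmcls.registry import MODELS
#
#     register_all_modules()
#     cfg = Config.fromfile('configs/_base_/models/efficientformer-l1.py')
#     model = MODELS.build(cfg.model)
#     model.eval()
#     with torch.no_grad():
#         feats = model.extract_feat(torch.rand(1, 3, 224, 224))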

View File

@@ -0,0 +1,47 @@
# EfficientFormer
> [EfficientFormer: Vision Transformers at MobileNet Speed](https://arxiv.org/abs/2206.01191)
<!-- [ALGORITHM] -->
## Abstract
Vision Transformers (ViT) have shown rapid progress in computer vision tasks, achieving promising results on various benchmarks. However, due to the massive number of parameters and model design, e.g., attention mechanism, ViT-based models are generally times slower than lightweight convolutional networks. Therefore, the deployment of ViT for real-time applications is particularly challenging, especially on resource-constrained hardware such as mobile devices. Recent efforts try to reduce the computation complexity of ViT through network architecture search or hybrid design with MobileNet block, yet the inference speed is still unsatisfactory. This leads to an important question: can transformers run as fast as MobileNet while obtaining high performance? To answer this, we first revisit the network architecture and operators used in ViT-based models and identify inefficient designs. Then we introduce a dimension-consistent pure transformer (without MobileNet blocks) as a design paradigm. Finally, we perform latency-driven slimming to get a series of final models dubbed EfficientFormer. Extensive experiments show the superiority of EfficientFormer in performance and speed on mobile devices. Our fastest model, EfficientFormer-L1, achieves 79.2% top-1 accuracy on ImageNet-1K with only 1.6 ms inference latency on iPhone 12 (compiled with CoreML), which runs as fast as MobileNetV2×1.4 (1.6 ms, 74.7% top-1), and our largest model, EfficientFormer-L7, obtains 83.3% accuracy with only 7.0 ms latency. Our work proves that properly designed transformers can reach extremely low latency on mobile devices while maintaining high performance.
<div align=center>
<img src="https://user-images.githubusercontent.com/18586273/180713426-9d3d77e3-3584-42d8-9098-625b4170d796.png" width="100%"/>
</div>
## Results and models
### ImageNet-1k
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
| :------------------: | :-------: | :------: | :-------: | :-------: | :---------------------------------------------------------------------: | :------------------------------------------------------------------------: |
| EfficientFormer-l1\* | 12.19 | 1.30 | 80.46 | 94.99 | [config](https://github.com/open-mmlab/mmclassification/blob/1.x/configs/efficientformer/efficientformer-l1_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l1_3rdparty_in1k_20220915-cc3e1ac6.pth) |
| EfficientFormer-l3\* | 31.41 | 3.93 | 82.45 | 96.18 | [config](https://github.com/open-mmlab/mmclassification/blob/1.x/configs/efficientformer/efficientformer-l3_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l3_3rdparty_in1k_20220915-466793d6.pth) |
| EfficientFormer-l7\* | 82.23 | 10.16 | 83.40 | 96.60 | [config](https://github.com/open-mmlab/mmclassification/blob/1.x/configs/efficientformer/efficientformer-l7_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l7_3rdparty_in1k_20220915-185e30af.pth) |
*Models with \* are converted from the [official repo](https://github.com/snap-research/EfficientFormer). The config files of these models are only for inference. We don't guarantee the training accuracy of these configs and welcome you to contribute your reproduction results.*
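The checkpoints above can be used directly for inference. A minimal sketch, assuming mmcls 1.x is installed and the script runs from the repository root (the demo image path is an assumption):

```python
from mmcls.apis import inference_model, init_model

config = 'configs/efficientformer/efficientformer-l1_8xb128_in1k.py'
# checkpoint URL from the table above
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l1_3rdparty_in1k_20220915-cc3e1ac6.pth'  # noqa

model = init_model(config, checkpoint, device='cpu')
result = inference_model(model, 'demo/demo.JPEG')
```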
## Citation
```bibtex
@misc{https://doi.org/10.48550/arxiv.2206.01191,
doi = {10.48550/ARXIV.2206.01191},
url = {https://arxiv.org/abs/2206.01191},
author = {Li, Yanyu and Yuan, Geng and Wen, Yang and Hu, Eric and Evangelidis, Georgios and Tulyakov, Sergey and Wang, Yanzhi and Ren, Jian},
keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {EfficientFormer: Vision Transformers at MobileNet Speed},
publisher = {arXiv},
year = {2022},
copyright = {Creative Commons Attribution 4.0 International}
}
```

View File

@@ -0,0 +1,6 @@
_base_ = [
'../_base_/models/efficientformer-l1.py',
'../_base_/datasets/imagenet_bs128_poolformer_small_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
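# A sketch of how this config composes (assuming mmengine's `_base_`
# inheritance): loading the file merges the four base configs into one.
#
#     from mmengine import Config
#     cfg = Config.fromfile(
#         'configs/efficientformer/efficientformer-l1_8xb128_in1k.py')
#     print(cfg.model.backbone.arch)  # 'l1', from the model base config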

View File

@@ -0,0 +1,3 @@
_base_ = './efficientformer-l1_8xb128_in1k.py'
model = dict(backbone=dict(arch='l3'), head=dict(in_channels=512))

View File

@@ -0,0 +1,3 @@
_base_ = './efficientformer-l1_8xb128_in1k.py'
model = dict(backbone=dict(arch='l7'), head=dict(in_channels=768))

View File

@@ -0,0 +1,67 @@
Collections:
- Name: EfficientFormer
Metadata:
Training Data: ImageNet-1k
Architecture:
- Pooling
- 1x1 Convolution
- LayerScale
- MetaFormer
Paper:
URL: https://arxiv.org/pdf/2206.01191.pdf
Title: "EfficientFormer: Vision Transformers at MobileNet Speed"
README: configs/efficientformer/README.md
Code:
Version: v1.0.0rc1
URL: https://github.com/open-mmlab/mmclassification/blob/v1.0.0rc1/mmcls/models/backbones/efficientformer.py
Models:
- Name: efficientformer-l1_3rdparty_8xb128_in1k
Metadata:
FLOPs: 1304601088 # 1.3G
Parameters: 12278696 # 12M
In Collection: EfficientFormer
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 80.46
Top 5 Accuracy: 94.99
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l1_3rdparty_in1k_20220915-cc3e1ac6.pth
Config: configs/efficientformer/efficientformer-l1_8xb128_in1k.py
Converted From:
Weights: https://drive.google.com/file/d/11SbX-3cfqTOc247xKYubrAjBiUmr818y/view?usp=sharing
Code: https://github.com/snap-research/EfficientFormer
- Name: efficientformer-l3_3rdparty_8xb128_in1k
Metadata:
Training Data: ImageNet-1k
FLOPs: 3737045760 # 3.7G
Parameters: 31406000 # 31M
In Collection: EfficientFormer
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 82.45
Top 5 Accuracy: 96.18
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l3_3rdparty_in1k_20220915-466793d6.pth
Config: configs/efficientformer/efficientformer-l3_8xb128_in1k.py
Converted From:
Weights: https://drive.google.com/file/d/1OyyjKKxDyMj-BcfInp4GlDdwLu3hc30m/view?usp=sharing
Code: https://github.com/snap-research/EfficientFormer
- Name: efficientformer-l7_3rdparty_8xb128_in1k
Metadata:
FLOPs: 10163951616 # 10.2G
Parameters: 82229328 # 82M
In Collection: EfficientFormer
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.40
Top 5 Accuracy: 96.60
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l7_3rdparty_in1k_20220915-185e30af.pth
Config: configs/efficientformer/efficientformer-l7_8xb128_in1k.py
Converted From:
Weights: https://drive.google.com/file/d/1cVw-pctJwgvGafeouynqWWCwgkcoFMM5/view?usp=sharing
Code: https://github.com/snap-research/EfficientFormer

View File

@@ -65,12 +65,15 @@ Backbones
ConvNeXt
DenseNet
DistilledVisionTransformer
EfficientFormer
EfficientNet
HRNet
InceptionV3
LeNet5
MlpMixer
MobileNetV2
MobileNetV3
MobileOne
PCPVT
PoolFormer
RegNet
@@ -95,8 +98,6 @@ Backbones
VAN
VGG
VisionTransformer
- MobileOne
- InceptionV3
.. module:: mmcls.models.necks
@@ -126,6 +127,7 @@ Heads
LinearClsHead
StackedLinearClsHead
VisionTransformerClsHead
EfficientFormerClsHead
DeiTClsHead
ConformerHead
MultiLabelClsHead
@@ -172,6 +174,7 @@ Common Components
PatchEmbed
PatchMerging
HybridEmbed
LayerScale
.. _helpers:

View File

@@ -6,6 +6,7 @@ from .convnext import ConvNeXt
from .cspnet import CSPDarkNet, CSPNet, CSPResNet, CSPResNeXt
from .deit import DistilledVisionTransformer
from .densenet import DenseNet
from .efficientformer import EfficientFormer
from .efficientnet import EfficientNet
from .hrnet import HRNet
from .inception_v3 import InceptionV3
@@ -44,5 +45,6 @@ __all__ = [
'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT',
'EfficientNet', 'ConvNeXt', 'HRNet', 'ResNetV1c', 'ConvMixer',
'CSPDarkNet', 'CSPResNet', 'CSPResNeXt', 'CSPNet', 'RepMLPNet',
- 'PoolFormer', 'DenseNet', 'VAN', 'InceptionV3', 'MobileOne'
+ 'PoolFormer', 'DenseNet', 'VAN', 'InceptionV3', 'MobileOne',
+ 'EfficientFormer'
]

View File

@@ -0,0 +1,606 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import Optional, Sequence
import torch
import torch.nn as nn
from mmcv.cnn.bricks import (ConvModule, DropPath, build_activation_layer,
build_norm_layer)
from mmengine.model import BaseModule, ModuleList, Sequential
from mmcls.registry import MODELS
from ..utils import LayerScale
from .base_backbone import BaseBackbone
from .poolformer import Pooling
class AttentionWithBias(BaseModule):
"""Multi-head Attention Module with attention_bias.
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads. Defaults to 8.
key_dim (int): The dimension of q, k. Defaults to 32.
attn_ratio (float): The dimension of v equals
``key_dim * attn_ratio``. Defaults to 4.
resolution (int): The height and width of attention_bias.
Defaults to 7.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads=8,
key_dim=32,
attn_ratio=4.,
resolution=7,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.num_heads = num_heads
self.scale = key_dim**-0.5
self.attn_ratio = attn_ratio
self.key_dim = key_dim
self.nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
h = self.dh + self.nh_kd * 2
self.qkv = nn.Linear(embed_dims, h)
self.proj = nn.Linear(self.dh, embed_dims)
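# Build a learnable relative-position bias table: each ordered pair of
# positions on the `resolution x resolution` grid is mapped to a shared
# index keyed by its absolute (dy, dx) offset, so all pairs with the
# same offset share one bias per head.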
points = list(itertools.product(range(resolution), range(resolution)))
N = len(points)
attention_offsets = {}
idxs = []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = nn.Parameter(
torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs',
torch.LongTensor(idxs).view(N, N))
@torch.no_grad()
def train(self, mode=True):
"""change the mode of model."""
super().train(mode)
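# Cache the expanded bias tensor for inference and drop the cache when
# switching back to training, so gradients flow into `attention_biases`.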
if mode and hasattr(self, 'ab'):
del self.ab
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x):
"""forward function.
Args:
x (tensor): input features with shape of (B, N, C)
"""
B, N, _ = x.shape
qkv = self.qkv(x)
qkv = qkv.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
q, k, v = qkv.split([self.key_dim, self.key_dim, self.d], dim=-1)
attn = ((q @ k.transpose(-2, -1)) * self.scale +
(self.attention_biases[:, self.attention_bias_idxs]
if self.training else self.ab))
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
x = self.proj(x)
return x
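# Illustrative shape check (not part of the diff): with the default
# resolution of 7, AttentionWithBias maps (B, 49, C) to (B, 49, C).
#
#     attn = AttentionWithBias(embed_dims=448)
#     out = attn(torch.rand(1, 49, 448))
#     assert out.shape == (1, 49, 448)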
class Flat(nn.Module):
"""Flat the input from (B, C, H, W) to (B, H*W, C)."""
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor):
x = x.flatten(2).transpose(1, 2)
return x
class LinearMlp(BaseModule):
"""Mlp implemented with linear.
The shape of input and output tensor are (B, N, C).
Args:
in_features (int): Dimension of input features.
hidden_features (int): Dimension of hidden features.
out_features (int): Dimension of output features.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='BN')``.
act_cfg (dict): The config dict for activation between pointwise
convolution. Defaults to ``dict(type='GELU')``.
drop (float): Dropout rate. Defaults to 0.0.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_cfg=dict(type='GELU'),
drop=0.,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = build_activation_layer(act_cfg)
self.drop1 = nn.Dropout(drop)
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop2 = nn.Dropout(drop)
def forward(self, x):
"""
Args:
x (torch.Tensor): input tensor with shape (B, N, C).
Returns:
torch.Tensor: output tensor with shape (B, N, C).
"""
x = self.drop1(self.act(self.fc1(x)))
x = self.drop2(self.fc2(x))
return x
class ConvMlp(BaseModule):
"""Mlp implemented with 1*1 convolutions.
Args:
in_features (int): Dimension of input features.
hidden_features (int): Dimension of hidden features.
out_features (int): Dimension of output features.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='BN')``.
act_cfg (dict): The config dict for activation between pointwise
convolution. Defaults to ``dict(type='GELU')``.
drop (float): Dropout rate. Defaults to 0.0.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='GELU'),
drop=0.,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
self.act = build_activation_layer(act_cfg)
self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
self.norm1 = build_norm_layer(norm_cfg, hidden_features)[1]
self.norm2 = build_norm_layer(norm_cfg, out_features)[1]
self.drop = nn.Dropout(drop)
def forward(self, x):
"""
Args:
x (torch.Tensor): input tensor with shape (B, C, H, W).
Returns:
torch.Tensor: output tensor with shape (B, C, H, W).
"""
x = self.act(self.norm1(self.fc1(x)))
x = self.drop(x)
x = self.norm2(self.fc2(x))
x = self.drop(x)
return x
class Meta3D(BaseModule):
"""Meta Former block using 3 dimensions inputs, ``torch.Tensor`` with shape
(B, N, C)."""
def __init__(self,
dim,
mlp_ratio=4.,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
drop=0.,
drop_path=0.,
use_layer_scale=True,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.norm1 = build_norm_layer(norm_cfg, dim)[1]
self.token_mixer = AttentionWithBias(dim)
self.norm2 = build_norm_layer(norm_cfg, dim)[1]
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = LinearMlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_cfg=act_cfg,
drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
if use_layer_scale:
self.ls1 = LayerScale(dim)
self.ls2 = LayerScale(dim)
else:
self.ls1, self.ls2 = nn.Identity(), nn.Identity()
def forward(self, x):
x = x + self.drop_path(self.ls1(self.token_mixer(self.norm1(x))))
x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
return x
class Meta4D(BaseModule):
"""Meta Former block using 4 dimensions inputs, ``torch.Tensor`` with shape
(B, C, H, W)."""
def __init__(self,
dim,
pool_size=3,
mlp_ratio=4.,
act_cfg=dict(type='GELU'),
drop=0.,
drop_path=0.,
use_layer_scale=True,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.token_mixer = Pooling(pool_size=pool_size)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ConvMlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_cfg=act_cfg,
drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
if use_layer_scale:
self.ls1 = LayerScale(dim, data_format='channels_first')
self.ls2 = LayerScale(dim, data_format='channels_first')
else:
self.ls1, self.ls2 = nn.Identity(), nn.Identity()
def forward(self, x):
x = x + self.drop_path(self.ls1(self.token_mixer(x)))
x = x + self.drop_path(self.ls2(self.mlp(x)))
return x
def basic_blocks(in_channels,
out_channels,
index,
layers,
pool_size=3,
mlp_ratio=4.,
act_cfg=dict(type='GELU'),
drop_rate=.0,
drop_path_rate=0.,
use_layer_scale=True,
vit_num=1,
has_downsampler=False):
"""Generate EfficientFormer blocks for a stage."""
blocks = []
if has_downsampler:
blocks.append(
ConvModule(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=2,
padding=1,
bias=True,
norm_cfg=dict(type='BN'),
act_cfg=None))
if index == 3 and vit_num == layers[index]:
blocks.append(Flat())
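# Stochastic depth: the drop-path rate grows linearly with the block's
# global index, from 0 for the first block to `drop_path_rate` for the last.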
for block_idx in range(layers[index]):
block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (
sum(layers) - 1)
if index == 3 and layers[index] - block_idx <= vit_num:
blocks.append(
Meta3D(
out_channels,
mlp_ratio=mlp_ratio,
act_cfg=act_cfg,
drop=drop_rate,
drop_path=block_dpr,
use_layer_scale=use_layer_scale,
))
else:
blocks.append(
Meta4D(
out_channels,
pool_size=pool_size,
act_cfg=act_cfg,
drop=drop_rate,
drop_path=block_dpr,
use_layer_scale=use_layer_scale))
if index == 3 and layers[index] - block_idx - 1 == vit_num:
blocks.append(Flat())
blocks = nn.Sequential(*blocks)
return blocks
@MODELS.register_module()
class EfficientFormer(BaseBackbone):
"""EfficientFormer.
A PyTorch implementation of EfficientFormer introduced by:
`EfficientFormer: Vision Transformers at MobileNet Speed <https://arxiv.org/abs/2206.01191>`_
Modified from the `official repo
<https://github.com/snap-research/EfficientFormer>`_.
Args:
arch (str | dict): The model's architecture. If string, it should be
one of architecture in ``EfficientFormer.arch_settings``. And if dict,
it should include the following 4 keys:
- layers (list[int]): Number of blocks at each stage.
- embed_dims (list[int]): The number of channels at each stage.
- downsamples (list[bool]): Whether to downsample in each of the four stages.
- vit_num (int): The number of ViT blocks in the last stage.
Defaults to 'l1'.
in_channels (int): The number of input channels. Defaults to 3.
pool_size (int): The pooling size of ``Meta4D`` blocks. Defaults to 3.
mlp_ratios (int): The expansion ratio of the hidden dimension in the
MLP blocks. Defaults to 4.
reshape_last_feat (bool): Whether to reshape the feature map from
(B, N, C) to (B, C, H, W) in the last stage, when the ``vit_num``
in ``arch`` is not 0. Defaults to False. Usually set to True
in downstream tasks.
out_indices (Sequence[int]): Output from which stages.
Defaults to -1.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
act_cfg (dict): The config dict for activation between pointwise
convolution. Defaults to ``dict(type='GELU')``.
drop_rate (float): Dropout rate. Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
use_layer_scale (bool): Whether to use LayerScale in the MetaFormer
blocks. Defaults to True.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
Example:
>>> from mmcls.models import EfficientFormer
>>> import torch
>>> inputs = torch.rand((1, 3, 224, 224))
>>> # build EfficientFormer backbone for classification task
>>> model = EfficientFormer(arch="l1")
>>> model.eval()
>>> level_outputs = model(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 448, 49)
>>> # build EfficientFormer backbone for downstream task
>>> model = EfficientFormer(
>>> arch="l3",
>>> out_indices=(0, 1, 2, 3),
>>> reshape_last_feat=True)
>>> model.eval()
>>> level_outputs = model(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 64, 56, 56)
(1, 128, 28, 28)
(1, 320, 14, 14)
(1, 512, 7, 7)
""" # noqa: E501
# --layers: [x,x,x,x], numbers of layers for the four stages
# --embed_dims: [x,x,x,x], embedding dims for the four stages
# --downsamples: [x,x,x,x], has downsample or not in the four stages
# --vit_num(int), the num of vit blocks in the last stage
arch_settings = {
'l1': {
'layers': [3, 2, 6, 4],
'embed_dims': [48, 96, 224, 448],
'downsamples': [False, True, True, True],
'vit_num': 1,
},
'l3': {
'layers': [4, 4, 12, 6],
'embed_dims': [64, 128, 320, 512],
'downsamples': [False, True, True, True],
'vit_num': 4,
},
'l7': {
'layers': [6, 6, 18, 8],
'embed_dims': [96, 192, 384, 768],
'downsamples': [False, True, True, True],
'vit_num': 8,
},
}
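# A dict with the same four keys is also accepted as `arch`, e.g.
# (illustrative):
#     EfficientFormer(arch=dict(layers=[2, 2, 4, 2],
#                               embed_dims=[48, 96, 224, 448],
#                               downsamples=[False, True, True, True],
#                               vit_num=1))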
def __init__(self,
arch='l1',
in_channels=3,
pool_size=3,
mlp_ratios=4,
reshape_last_feat=False,
out_indices=-1,
frozen_stages=-1,
act_cfg=dict(type='GELU'),
drop_rate=0.,
drop_path_rate=0.,
use_layer_scale=True,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.num_extra_tokens = 0 # no cls_token, no dist_token
if isinstance(arch, str):
assert arch in self.arch_settings, \
f'Unavailable arch, please choose from ' \
f'({set(self.arch_settings)}) or pass a dict.'
arch = self.arch_settings[arch]
elif isinstance(arch, dict):
default_keys = set(self.arch_settings['l1'].keys())
assert set(arch.keys()) == default_keys, \
f'The arch dict must have {default_keys}, ' \
f'but got {list(arch.keys())}.'
self.layers = arch['layers']
self.embed_dims = arch['embed_dims']
self.downsamples = arch['downsamples']
assert isinstance(self.layers, list) and isinstance(
self.embed_dims, list) and isinstance(self.downsamples, list)
assert len(self.layers) == len(self.embed_dims) == len(
self.downsamples)
self.vit_num = arch['vit_num']
self.reshape_last_feat = reshape_last_feat
assert self.vit_num >= 0, "'vit_num' must be an integer " \
'greater than or equal to 0.'
assert self.vit_num <= self.layers[-1], (
"'vit_num' must be an integer smaller than layer number")
self._make_stem(in_channels, self.embed_dims[0])
# set the main block in network
network = []
for i in range(len(self.layers)):
if i != 0:
in_channels = self.embed_dims[i - 1]
else:
in_channels = self.embed_dims[i]
out_channels = self.embed_dims[i]
stage = basic_blocks(
in_channels,
out_channels,
i,
self.layers,
pool_size=pool_size,
mlp_ratio=mlp_ratios,
act_cfg=act_cfg,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
vit_num=self.vit_num,
use_layer_scale=use_layer_scale,
has_downsampler=self.downsamples[i])
network.append(stage)
self.network = ModuleList(network)
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must be a sequence or int, ' \
f'got {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = 4 + index
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
self.out_indices = out_indices
for i_layer in self.out_indices:
if not self.reshape_last_feat and \
i_layer == 3 and self.vit_num > 0:
layer = build_norm_layer(
dict(type='LN'), self.embed_dims[i_layer])[1]
else:
# use GN with 1 group as channel-first LN2D
layer = build_norm_layer(
dict(type='GN', num_groups=1), self.embed_dims[i_layer])[1]
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
self.frozen_stages = frozen_stages
self._freeze_stages()
def _make_stem(self, in_channels: int, stem_channels: int):
"""make 2-ConvBNReLu stem layer."""
self.patch_embed = Sequential(
ConvModule(
in_channels,
stem_channels // 2,
kernel_size=3,
stride=2,
padding=1,
bias=True,
conv_cfg=None,
norm_cfg=dict(type='BN'),
inplace=True),
ConvModule(
stem_channels // 2,
stem_channels,
kernel_size=3,
stride=2,
padding=1,
bias=True,
conv_cfg=None,
norm_cfg=dict(type='BN'),
inplace=True))
def forward_tokens(self, x):
outs = []
for idx, block in enumerate(self.network):
if idx == len(self.network) - 1:
N, _, H, W = x.shape
if self.downsamples[idx]:
H, W = H // 2, W // 2
x = block(x)
if idx in self.out_indices:
norm_layer = getattr(self, f'norm{idx}')
if idx == len(self.network) - 1 and x.dim() == 3:
# When ``vit_num`` > 0, the last stage outputs (B, N, C) features.
# If ``self.reshape_last_feat`` is True, reshape them to
# (B, C, H, W) before the final normalization; otherwise,
# normalize directly and permute the features to (B, C, N).
if self.reshape_last_feat:
x = x.permute((0, 2, 1)).reshape(N, -1, H, W)
x_out = norm_layer(x)
else:
x_out = norm_layer(x).permute((0, 2, 1))
else:
x_out = norm_layer(x)
outs.append(x_out.contiguous())
return tuple(outs)
def forward(self, x):
# input embedding
x = self.patch_embed(x)
# through stages
x = self.forward_tokens(x)
return x
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
for i in range(self.frozen_stages):
# Include both block and downsample layer.
module = self.network[i]
module.eval()
for param in module.parameters():
param.requires_grad = False
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
norm_layer.eval()
for param in norm_layer.parameters():
param.requires_grad = False
def train(self, mode=True):
super(EfficientFormer, self).train(mode)
self._freeze_stages()

View File

@@ -2,6 +2,7 @@
from .cls_head import ClsHead
from .conformer_head import ConformerHead
from .deit_head import DeiTClsHead
from .efficientformer_head import EfficientFormerClsHead
from .linear_head import LinearClsHead
from .multi_label_cls_head import MultiLabelClsHead
from .multi_label_linear_head import MultiLabelLinearClsHead
@@ -11,5 +12,5 @@ from .vision_transformer_head import VisionTransformerClsHead
__all__ = [
'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead',
'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'DeiTClsHead',
- 'ConformerHead'
+ 'ConformerHead', 'EfficientFormerClsHead'
]

View File

@@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import torch
import torch.nn as nn
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample
from .cls_head import ClsHead
@MODELS.register_module()
class EfficientFormerClsHead(ClsHead):
"""EfficientFormer classifier head.
Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
distillation (bool): Whether to use an additional distillation head.
Defaults to True.
init_cfg (dict): The extra initialization configs. Defaults to
``dict(type='Normal', layer='Linear', std=0.01)``.
"""
def __init__(self,
num_classes,
in_channels,
distillation=True,
init_cfg=dict(type='Normal', layer='Linear', std=0.01),
*args,
**kwargs):
super(EfficientFormerClsHead, self).__init__(
init_cfg=init_cfg, *args, **kwargs)
self.in_channels = in_channels
self.num_classes = num_classes
self.dist = distillation
if self.num_classes <= 0:
raise ValueError(
f'num_classes={num_classes} must be a positive integer')
self.head = nn.Linear(self.in_channels, self.num_classes)
if self.dist:
self.dist_head = nn.Linear(self.in_channels, self.num_classes)
def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
"""The forward process."""
pre_logits = self.pre_logits(feats)
# The final classification head.
cls_score = self.head(pre_logits)
if self.dist:
cls_score = (cls_score + self.dist_head(pre_logits)) / 2
return cls_score
def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
"""The process before the final classification head.
The input ``feats`` is a tuple of tensors, and each tensor is the
feature of a backbone stage. In :obj:`EfficientFormerClsHead`, we just
obtain the feature of the last stage.
"""
# The EfficientFormerClsHead doesn't have any extra module; just
# return the feature of the last stage.
return feats[-1]
def loss(self, feats: Tuple[torch.Tensor],
data_samples: List[ClsDataSample], **kwargs) -> dict:
"""Calculate losses from the classification score.
Args:
feats (tuple[Tensor]): The features extracted from the backbone.
Multiple stage inputs are acceptable but only the last stage
will be used to classify. The shape of every item should be
``(num_samples, num_classes)``.
data_samples (List[ClsDataSample]): The annotation data of
every sample.
**kwargs: Other keyword arguments to forward the loss module.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
if self.dist:
raise NotImplementedError(
"MMClassification doesn't support training the distilled"
' version of EfficientFormer.')
else:
return super().loss(feats, data_samples, **kwargs)
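# Illustrative usage (not part of the diff): the head consumes the feature
# of the last stage; with distillation=True, the logits of the two linear
# heads are averaged at inference time.
#
#     head = EfficientFormerClsHead(num_classes=1000, in_channels=448)
#     score = head((torch.rand(2, 448), ))
#     assert score.shape == (2, 1000)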

View File

@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
- # Copyrigforward_trainht (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple
import torch

View File

@@ -7,6 +7,7 @@ from .embed import (HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed,
resize_relative_position_bias_table)
from .helpers import is_tracing, to_2tuple, to_3tuple, to_4tuple, to_ntuple
from .inverted_residual import InvertedResidual
from .layer_scale import LayerScale
from .make_divisible import make_divisible
from .position_encoding import ConditionalPositionEncoding
from .se_layer import SELayer
@@ -17,5 +18,6 @@ __all__ = [
'PatchMerging', 'HybridEmbed', 'RandomBatchAugment', 'ShiftWindowMSA',
'is_tracing', 'MultiheadAttention', 'ConditionalPositionEncoding',
'resize_pos_embed', 'resize_relative_position_bias_table',
- 'ClsDataPreprocessor', 'Mixup', 'CutMix', 'ResizeMix', 'BEiTAttention'
+ 'ClsDataPreprocessor', 'Mixup', 'CutMix', 'ResizeMix', 'BEiTAttention',
+ 'LayerScale'
]

View File

@@ -0,0 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
class LayerScale(nn.Module):
"""LayerScale layer.
Args:
dim (int): Dimension of input features.
inplace (bool): Whether to do the operation in-place.
Defaults to False.
data_format (str): The input data format, either 'channels_last'
or 'channels_first', representing (B, N, C) and
(B, C, H, W) format data respectively. Defaults to 'channels_last'.
"""
def __init__(self,
dim: int,
inplace: bool = False,
data_format: str = 'channels_last'):
super().__init__()
assert data_format in ('channels_last', 'channels_first'), \
"'data_format' could only be channels_last or channels_first."
self.inplace = inplace
self.data_format = data_format
self.weight = nn.Parameter(torch.ones(dim) * 1e-5)
def forward(self, x):
if self.data_format == 'channels_first':
if self.inplace:
return x.mul_(self.weight.view(-1, 1, 1))
else:
return x * self.weight.view(-1, 1, 1)
return x.mul_(self.weight) if self.inplace else x * self.weight
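# Illustrative usage (not part of the diff): each channel is scaled by a
# learnable factor initialised to 1e-5.
#
#     ls = LayerScale(dim=256)                      # (B, N, C) inputs
#     y = ls(torch.rand(2, 49, 256))                # same shape out
#     ls2d = LayerScale(dim=256, data_format='channels_first')
#     y2d = ls2d(torch.rand(2, 256, 7, 7))          # (B, C, H, W) inputs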

View File

@@ -29,3 +29,4 @@ Import:
- configs/poolformer/metafile.yml
- configs/inception_v3/metafile.yml
- configs/mobileone/metafile.yml
- configs/efficientformer/metafile.yml

View File

@@ -0,0 +1,199 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from unittest import TestCase
import torch
from mmcv.cnn import ConvModule
from torch import nn
from mmcls.models.backbones import EfficientFormer
from mmcls.models.backbones.efficientformer import (AttentionWithBias, Flat,
Meta3D, Meta4D)
from mmcls.models.backbones.poolformer import Pooling
class TestEfficientFormer(TestCase):
def setUp(self):
self.cfg = dict(arch='l1', drop_path_rate=0.1)
self.arch = EfficientFormer.arch_settings['l1']
self.custom_arch = {
'layers': [1, 1, 1, 4],
'embed_dims': [48, 96, 224, 448],
'downsamples': [False, True, True, True],
'vit_num': 2,
}
self.custom_cfg = dict(arch=self.custom_arch)
def test_arch(self):
# Test invalid default arch
with self.assertRaisesRegex(AssertionError, 'Unavailable arch'):
cfg = deepcopy(self.cfg)
cfg['arch'] = 'unknown'
EfficientFormer(**cfg)
# Test invalid custom arch
with self.assertRaisesRegex(AssertionError, 'must have'):
cfg = deepcopy(self.custom_cfg)
cfg['arch'].pop('layers')
EfficientFormer(**cfg)
# Test vit_num < 0
with self.assertRaisesRegex(AssertionError, "'vit_num' must"):
cfg = deepcopy(self.custom_cfg)
cfg['arch']['vit_num'] = -1
EfficientFormer(**cfg)
# Test vit_num > last stage layers
with self.assertRaisesRegex(AssertionError, "'vit_num' must"):
cfg = deepcopy(self.custom_cfg)
cfg['arch']['vit_num'] = 10
EfficientFormer(**cfg)
# Test out_ind
with self.assertRaisesRegex(AssertionError, '"out_indices" must'):
cfg = deepcopy(self.custom_cfg)
cfg['out_indices'] = dict
EfficientFormer(**cfg)
# Test custom arch
cfg = deepcopy(self.custom_cfg)
model = EfficientFormer(**cfg)
self.assertEqual(len(model.patch_embed), 2)
layers = self.custom_arch['layers']
downsamples = self.custom_arch['downsamples']
vit_num = self.custom_arch['vit_num']
for i, stage in enumerate(model.network):
if downsamples[i]:
self.assertIsInstance(stage[0], ConvModule)
self.assertEqual(stage[0].conv.stride, (2, 2))
self.assertTrue(hasattr(stage[0].conv, 'bias'))
self.assertTrue(isinstance(stage[0].bn, nn.BatchNorm2d))
if i < len(model.network) - 1:
self.assertIsInstance(stage[-1], Meta4D)
self.assertIsInstance(stage[-1].token_mixer, Pooling)
self.assertEqual(len(stage) - downsamples[i], layers[i])
elif vit_num > 0:
self.assertIsInstance(stage[-1], Meta3D)
self.assertIsInstance(stage[-1].token_mixer, AttentionWithBias)
self.assertEqual(len(stage) - downsamples[i] - 1, layers[i])
flat_layer_idx = len(stage) - vit_num - downsamples[i]
self.assertIsInstance(stage[flat_layer_idx], Flat)
count = 0
for layer in stage:
if isinstance(layer, Meta3D):
count += 1
self.assertEqual(count, vit_num)
def test_init_weights(self):
# test weight init cfg
cfg = deepcopy(self.cfg)
cfg['init_cfg'] = [
dict(
type='Kaiming',
layer='Conv2d',
mode='fan_in',
nonlinearity='linear'),
dict(type='Constant', layer=['LayerScale'], val=1e-4)
]
model = EfficientFormer(**cfg)
ori_weight = model.patch_embed[0].conv.weight.clone().detach()
ori_ls_weight = model.network[0][-1].ls1.weight.clone().detach()
model.init_weights()
initialized_weight = model.patch_embed[0].conv.weight
initialized_ls_weight = model.network[0][-1].ls1.weight
self.assertFalse(torch.allclose(ori_weight, initialized_weight))
self.assertFalse(torch.allclose(ori_ls_weight, initialized_ls_weight))
def test_forward(self):
imgs = torch.randn(1, 3, 224, 224)
# test last stage output
cfg = deepcopy(self.cfg)
model = EfficientFormer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
feat = outs[-1]
self.assertEqual(feat.shape, (1, 448, 49))
assert hasattr(model, 'norm3')
assert isinstance(getattr(model, 'norm3'), nn.LayerNorm)
# test multiple output indices
cfg = deepcopy(self.cfg)
cfg['out_indices'] = (0, 1, 2, 3)
cfg['reshape_last_feat'] = True
model = EfficientFormer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 4)
# Test out features shape
for dim, stride, out in zip(self.arch['embed_dims'], [1, 2, 4, 8],
outs):
self.assertEqual(out.shape, (1, dim, 56 // stride, 56 // stride))
# Test norm layer
for i in range(4):
assert hasattr(model, f'norm{i}')
stage_norm = getattr(model, f'norm{i}')
assert isinstance(stage_norm, nn.GroupNorm)
assert stage_norm.num_groups == 1
# Test vit_num == 0
cfg = deepcopy(self.custom_cfg)
cfg['arch']['vit_num'] = 0
cfg['out_indices'] = (0, 1, 2, 3)
model = EfficientFormer(**cfg)
for i in range(4):
assert hasattr(model, f'norm{i}')
stage_norm = getattr(model, f'norm{i}')
assert isinstance(stage_norm, nn.GroupNorm)
assert stage_norm.num_groups == 1
def test_structure(self):
# test drop_path_rate decay
cfg = deepcopy(self.cfg)
cfg['drop_path_rate'] = 0.2
model = EfficientFormer(**cfg)
layers = self.arch['layers']
for i, block in enumerate(model.network):
expect_prob = 0.2 / (sum(layers) - 1) * i
if hasattr(block, 'drop_path'):
if expect_prob == 0:
self.assertIsInstance(block.drop_path, torch.nn.Identity)
else:
self.assertAlmostEqual(block.drop_path.drop_prob,
expect_prob)
# test with first stage frozen.
cfg = deepcopy(self.cfg)
frozen_stages = 1
cfg['frozen_stages'] = frozen_stages
cfg['out_indices'] = (0, 1, 2, 3)
model = EfficientFormer(**cfg)
model.init_weights()
model.train()
# the patch_embed and first stage should not require grad.
self.assertFalse(model.patch_embed.training)
for param in model.patch_embed.parameters():
self.assertFalse(param.requires_grad)
for i in range(frozen_stages):
module = model.network[i]
for param in module.parameters():
self.assertFalse(param.requires_grad)
for param in model.norm0.parameters():
self.assertFalse(param.requires_grad)
# the second stage should require grad.
for i in range(frozen_stages + 1, 4):
module = model.network[i]
for param in module.parameters():
self.assertTrue(param.requires_grad)
if hasattr(model, f'norm{i}'):
norm = getattr(model, f'norm{i}')
for param in norm.parameters():
self.assertTrue(param.requires_grad)

View File

@@ -24,6 +24,7 @@ def setup_seed(seed):
class TestClsHead(TestCase):
DEFAULT_ARGS = dict(type='ClsHead')
FAKE_FEATS = (torch.rand(4, 10), )
def test_pre_logits(self):
head = MODELS.build(self.DEFAULT_ARGS)
@@ -42,7 +43,7 @@ class TestClsHead(TestCase):
self.assertIs(outs, feats[-1])
def test_loss(self):
- feats = (torch.rand(4, 10), )
+ feats = self.FAKE_FEATS
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]
# with cal_acc = False
@@ -96,6 +97,7 @@ class TestClsHead(TestCase):
class TestLinearClsHead(TestCase):
DEFAULT_ARGS = dict(type='LinearClsHead', in_channels=10, num_classes=5)
FAKE_FEATS = (torch.rand(4, 10), )
def test_initialize(self):
with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
@@ -425,6 +427,47 @@ class TestMultiLabelClsHead(TestCase):
self.assertIn('score', pred.pred_label)
class TestEfficientFormerClsHead(TestClsHead):
DEFAULT_ARGS = dict(
type='EfficientFormerClsHead',
in_channels=10,
num_classes=10,
distillation=False)
FAKE_FEATS = (torch.rand(4, 10), )
def test_forward(self):
# test with distillation head
cfg = copy.deepcopy(self.DEFAULT_ARGS)
cfg['distillation'] = True
head = MODELS.build(cfg)
self.assertTrue(hasattr(head, 'dist_head'))
feats = (torch.rand(4, 10), torch.rand(4, 10))
outs = head(feats)
self.assertEqual(outs.shape, (4, 10))
# test without distillation head
cfg = copy.deepcopy(self.DEFAULT_ARGS)
head = MODELS.build(cfg)
self.assertFalse(hasattr(head, 'dist_head'))
feats = (torch.rand(4, 10), torch.rand(4, 10))
outs = head(feats)
self.assertEqual(outs.shape, (4, 10))
def test_loss(self):
feats = (torch.rand(4, 10), )
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]
# test with distillation head
cfg = copy.deepcopy(self.DEFAULT_ARGS)
cfg['distillation'] = True
head = MODELS.build(cfg)
with self.assertRaisesRegex(NotImplementedError, 'MMClassification '):
head.loss(feats, data_samples)
# test without distillation head
super().test_loss()
class TestMultiLabelLinearClsHead(TestMultiLabelClsHead):
DEFAULT_ARGS = dict(
type='MultiLabelLinearClsHead', num_classes=10, in_channels=10)

View File

@@ -0,0 +1,47 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmcls.models.utils import LayerScale
class TestLayerScale(TestCase):
def test_init(self):
with self.assertRaisesRegex(AssertionError, "'data_format' could"):
cfg = dict(
dim=10,
data_format='BNC',
)
LayerScale(**cfg)
cfg = dict(dim=10)
ls = LayerScale(**cfg)
assert torch.equal(ls.weight,
torch.ones(10, requires_grad=True) * 1e-5)
def test_forward(self):
# Test channels_last
cfg = dict(dim=256, inplace=False, data_format='channels_last')
ls_channels_last = LayerScale(**cfg)
x = torch.randn((4, 49, 256))
out = ls_channels_last(x)
self.assertEqual(tuple(out.size()), (4, 49, 256))
assert torch.equal(x * 1e-5, out)
# Test channels_first
cfg = dict(dim=256, inplace=False, data_format='channels_first')
ls_channels_first = LayerScale(**cfg)
x = torch.randn((4, 256, 7, 7))
out = ls_channels_first(x)
self.assertEqual(tuple(out.size()), (4, 256, 7, 7))
assert torch.equal(x * 1e-5, out)
# Test inplace True
cfg = dict(dim=256, inplace=True, data_format='channels_first')
ls_channels_first = LayerScale(**cfg)
x = torch.randn((4, 256, 7, 7))
out = ls_channels_first(x)
self.assertEqual(tuple(out.size()), (4, 256, 7, 7))
self.assertIs(x, out)