[Feature] Support MobileViT backbone. (#1068)

* init

* fix

* add config

* add meta

* add unittest

* fix for comments

* Imporvee docstring and support custom arch.

* Update README

* Update windows CI

Co-authored-by: mzr1996 <mzr1996@163.com>
pull/1125/head
Hubert 2022-10-18 17:05:59 +08:00 committed by GitHub
parent 29f066f7fb
commit bcca619066
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 760 additions and 3 deletions

View File

@ -71,7 +71,7 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
run: pip install pip --upgrade
run: python -m pip install pip --upgrade
- name: Install PyTorch
run: pip install torch==1.8.2+${{matrix.platform}} torchvision==0.9.2+${{matrix.platform}} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
- name: Install mmcls dependencies

View File

@ -148,6 +148,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
- [x] [MobileOne](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobileone)
- [x] [EfficientFormer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/efficientformer)
- [x] [MViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mvit)
- [x] [MobileViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilevit)
</details>

View File

@ -147,6 +147,7 @@ mim install -e .
- [x] [MobileOne](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobileone)
- [x] [EfficientFormer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/efficientformer)
- [x] [MViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mvit)
- [x] [MobileViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilevit)
</details>

View File

@ -0,0 +1,12 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(type='MobileViT', arch='small'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=640,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))

View File

@ -0,0 +1,12 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(type='MobileViT', arch='x_small'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=384,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))

View File

@ -0,0 +1,12 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(type='MobileViT', arch='xx_small'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=320,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))

View File

@ -0,0 +1,36 @@
# MobileVit
> [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer](https://arxiv.org/abs/2110.02178)
<!-- [ALGORITHM] -->
## Abstract
Light-weight convolutional neural networks (CNNs) are the de-facto for mobile vision tasks. Their spatial inductive biases allow them to learn representations with fewer parameters across different vision tasks. However, these networks are spatially local. To learn global representations, self-attention-based vision trans-formers (ViTs) have been adopted. Unlike CNNs, ViTs are heavy-weight. In this paper, we ask the following question: is it possible to combine the strengths of CNNs and ViTs to build a light-weight and low latency network for mobile vision tasks? Towards this end, we introduce MobileViT, a light-weight and general-purpose vision transformer for mobile devices. MobileViT presents a different perspective for the global processing of information with transformers, i.e., transformers as convolutions. Our results show that MobileViT significantly outperforms CNN- and ViT-based networks across different tasks and datasets. On the ImageNet-1k dataset, MobileViT achieves top-1 accuracy of 78.4% with about 6 million parameters, which is 3.2% and 6.2% more accurate than MobileNetv3 (CNN-based) and DeIT (ViT-based) for a similar number of parameters. On the MS-COCO object detection task, MobileViT is 5.7% more accurate than MobileNetv3 for a similar number of parameters.
<div align=center>
<img src="https://user-images.githubusercontent.com/42952108/193229983-822bf025-89a6-4d95-b6be-76b7f1a62f2c.png" width="70%"/>
</div>
## Results and models
### ImageNet-1k
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
| :-----------------: | :-------: | :------: | :-------: | :-------: | :------------------------------------------: | :----------------------------------------------------------------------------------------------------: |
| MobileViT-XXSmall\* | 1.27 | 0.42 | 69.02 | 88.91 | [config](./mobilevit-xxsmall_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobilevit/mobilevit-xxsmall_3rdparty_in1k_20221018-77835605.pth) |
| MobileViT-XSmall\* | 2.32 | 1.05 | 74.75 | 92.32 | [config](./mobilevit-xsmall_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobilevit/mobilevit-xsmall_3rdparty_in1k_20221018-be39a6e7.pth) |
| MobileViT-Small\* | 5.58 | 2.03 | 78.25 | 94.09 | [config](./mobilevit-small_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobilevit/mobilevit-small_3rdparty_in1k_20221018-cb4f741c.pth) |
*Models with * are converted from [ml-cvnets](https://github.com/apple/ml-cvnets). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
## Citation
```
@article{mehta2021mobilevit,
title={MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer},
author={Mehta, Sachin and Rastegari, Mohammad},
journal={arXiv preprint arXiv:2110.02178},
year={2021}
}
```

View File

@ -0,0 +1,60 @@
Collections:
- Name: MobileViT
Metadata:
Training Data: ImageNet-1k
Architecture:
- MobileViT Block
Paper:
URL: https://arxiv.org/abs/2110.02178
Title: MobileViT Light-weight, General-purpose, and Mobile-friendly Vision Transformer
README: configs/mobilevit/README.md
Models:
- Name: mobilevit-small_3rdparty_in1k
Metadata:
FLOPs: 2030000000
Parameters: 5580000
In Collection: MobileViT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 78.25
Top 5 Accuracy: 94.09
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/mobilevit/mobilevit-small_3rdparty_in1k_20221018-cb4f741c.pth
Config: configs/mobilevit/mobilevit-small_8xb128_in1k.py
Converted From:
Weights: https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_s.pt
Code: https://github.com/apple/ml-cvnets
- Name: mobilevit-xsmall_3rdparty_in1k
Metadata:
FLOPs: 1050000000
Parameters: 2320000
In Collection: MobileViT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 74.75
Top 5 Accuracy: 92.32
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/mobilevit/mobilevit-xsmall_3rdparty_in1k_20221018-be39a6e7.pth
Config: configs/mobilevit/mobilevit-xsmall_8xb128_in1k.py
Converted From:
Weights: https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xs.pt
Code: https://github.com/apple/ml-cvnets
- Name: mobilevit-xxsmall_3rdparty_in1k
Metadata:
FLOPs: 420000000
Parameters: 1270000
In Collection: MobileViT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 69.02
Top 5 Accuracy: 88.91
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/mobilevit/mobilevit-xxsmall_3rdparty_in1k_20221018-77835605.pth
Config: configs/mobilevit/mobilevit-xxsmall_8xb128_in1k.py
Converted From:
Weights: https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xxs.pt
Code: https://github.com/apple/ml-cvnets

View File

@ -0,0 +1,30 @@
_base_ = [
'../_base_/models/mobilevit/mobilevit_s.py',
'../_base_/datasets/imagenet_bs32.py',
'../_base_/default_runtime.py',
'../_base_/schedules/imagenet_bs256.py',
]
# no normalize for original implements
data_preprocessor = dict(
# RGB format normalization parameters
mean=[0, 0, 0],
std=[255, 255, 255],
# use bgr directly
to_rgb=False,
)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeEdge', scale=288, edge='short'),
dict(type='CenterCrop', crop_size=256),
dict(type='PackClsInputs'),
]
train_dataloader = dict(batch_size=128)
val_dataloader = dict(
batch_size=128,
dataset=dict(pipeline=test_pipeline),
)
test_dataloader = val_dataloader

View File

@ -0,0 +1,30 @@
_base_ = [
'../_base_/models/mobilevit/mobilevit_xs.py',
'../_base_/datasets/imagenet_bs32.py',
'../_base_/default_runtime.py',
'../_base_/schedules/imagenet_bs256.py',
]
# no normalize for original implements
data_preprocessor = dict(
# RGB format normalization parameters
mean=[0, 0, 0],
std=[255, 255, 255],
# use bgr directly
to_rgb=False,
)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeEdge', scale=288, edge='short'),
dict(type='CenterCrop', crop_size=256),
dict(type='PackClsInputs'),
]
train_dataloader = dict(batch_size=128)
val_dataloader = dict(
batch_size=128,
dataset=dict(pipeline=test_pipeline),
)
test_dataloader = val_dataloader

View File

@ -0,0 +1,30 @@
_base_ = [
'../_base_/models/mobilevit/mobilevit_xxs.py',
'../_base_/datasets/imagenet_bs32.py',
'../_base_/default_runtime.py',
'../_base_/schedules/imagenet_bs256.py',
]
# no normalize for original implements
data_preprocessor = dict(
# RGB format normalization parameters
mean=[0, 0, 0],
std=[255, 255, 255],
# use bgr directly
to_rgb=False,
)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeEdge', scale=288, edge='short'),
dict(type='CenterCrop', crop_size=256),
dict(type='PackClsInputs'),
]
train_dataloader = dict(batch_size=128)
val_dataloader = dict(
batch_size=128,
dataset=dict(pipeline=test_pipeline),
)
test_dataloader = val_dataloader

View File

@ -0,0 +1,12 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(type='MobileViT', arch='small'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=640,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))

View File

@ -63,12 +63,12 @@ Backbones
Conformer
ConvMixer
ConvNeXt
DenseNet
DeiT3
DenseNet
DistilledVisionTransformer
EdgeNeXt
EfficientFormer
EfficientNet
EdgeNeXt
HRNet
InceptionV3
LeNet5
@ -77,6 +77,7 @@ Backbones
MobileNetV2
MobileNetV3
MobileOne
MobileViT
PCPVT
PoolFormer
RegNet

View File

@ -17,6 +17,7 @@ from .mlp_mixer import MlpMixer
from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
from .mobileone import MobileOne
from .mobilevit import MobileViT
from .mvit import MViT
from .poolformer import PoolFormer
from .regnet import RegNet
@ -89,4 +90,5 @@ __all__ = [
'SwinTransformerV2',
'MViT',
'DeiT3',
'MobileViT',
]

View File

@ -0,0 +1,431 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Callable, Optional, Sequence
import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_norm_layer
from mmengine.registry import MODELS
from torch import nn
from .base_backbone import BaseBackbone
from .mobilenet_v2 import InvertedResidual
from .vision_transformer import TransformerEncoderLayer
class MobileVitBlock(nn.Module):
"""MobileViT block.
According to the paper, the MobileViT block has a local representation.
a transformer-as-convolution layer which consists of a global
representation with unfolding and folding, and a final fusion layer.
Args:
in_channels (int): Number of input image channels.
transformer_dim (int): Number of transformer channels.
ffn_dim (int): Number of ffn channels in transformer block.
out_channels (int): Number of channels in output.
conv_ksize (int): Conv kernel size in local representation
and fusion. Defaults to 3.
conv_cfg (dict, optional): Config dict for convolution layer.
Defaults to None, which means using conv2d.
norm_cfg (dict, optional): Config dict for normalization layer.
Defaults to dict(type='BN').
act_cfg (dict, optional): Config dict for activation layer.
Defaults to dict(type='Swish').
num_transformer_blocks (int): Number of transformer blocks in
a MobileViT block. Defaults to 2.
patch_size (int): Patch size for unfolding and folding.
Defaults to 2.
num_heads (int): Number of heads in global representation.
Defaults to 4.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.
attn_drop_rate (float): The drop out rate for attention output weights.
Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
no_fusion (bool): Whether to remove the fusion layer.
Defaults to False.
transformer_norm_cfg (dict, optional): Config dict for normalization
layer in transformer. Defaults to dict(type='LN').
"""
def __init__(
self,
in_channels: int,
transformer_dim: int,
ffn_dim: int,
out_channels: int,
conv_ksize: int = 3,
conv_cfg: Optional[dict] = None,
norm_cfg: Optional[dict] = dict(type='BN'),
act_cfg: Optional[dict] = dict(type='Swish'),
num_transformer_blocks: int = 2,
patch_size: int = 2,
num_heads: int = 4,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
no_fusion: bool = False,
transformer_norm_cfg: Callable = dict(type='LN'),
):
super(MobileVitBlock, self).__init__()
self.local_rep = nn.Sequential(
ConvModule(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=conv_ksize,
padding=int((conv_ksize - 1) / 2),
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
in_channels=in_channels,
out_channels=transformer_dim,
kernel_size=1,
bias=False,
conv_cfg=conv_cfg,
norm_cfg=None,
act_cfg=None),
)
global_rep = [
TransformerEncoderLayer(
embed_dims=transformer_dim,
num_heads=num_heads,
feedforward_channels=ffn_dim,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
qkv_bias=True,
act_cfg=dict(type='Swish'),
norm_cfg=transformer_norm_cfg)
for _ in range(num_transformer_blocks)
]
global_rep.append(
build_norm_layer(transformer_norm_cfg, transformer_dim)[1])
self.global_rep = nn.Sequential(*global_rep)
self.conv_proj = ConvModule(
in_channels=transformer_dim,
out_channels=out_channels,
kernel_size=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
if no_fusion:
self.conv_fusion = None
else:
self.conv_fusion = ConvModule(
in_channels=in_channels + out_channels,
out_channels=out_channels,
kernel_size=conv_ksize,
padding=int((conv_ksize - 1) / 2),
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.patch_size = (patch_size, patch_size)
self.patch_area = self.patch_size[0] * self.patch_size[1]
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
# Local representation
x = self.local_rep(x)
# Unfold (feature map -> patches)
patch_h, patch_w = self.patch_size
B, C, H, W = x.shape
new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(
W / patch_w) * patch_w
num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w # n_h, n_w # noqa
num_patches = num_patch_h * num_patch_w # N
interpolate = False
if new_h != H or new_w != W:
# Note: Padding can be done, but then it needs to be handled in attention function. # noqa
x = F.interpolate(
x, size=(new_h, new_w), mode='bilinear', align_corners=False)
interpolate = True
# [B, C, H, W] --> [B * C * n_h, n_w, p_h, p_w]
x = x.reshape(B * C * num_patch_h, patch_h, num_patch_w,
patch_w).transpose(1, 2)
# [B * C * n_h, n_w, p_h, p_w] --> [BP, N, C] where P = p_h * p_w and N = n_h * n_w # noqa
x = x.reshape(B, C, num_patches,
self.patch_area).transpose(1, 3).reshape(
B * self.patch_area, num_patches, -1)
# Global representations
x = self.global_rep(x)
# Fold (patch -> feature map)
# [B, P, N, C] --> [B*C*n_h, n_w, p_h, p_w]
x = x.contiguous().view(B, self.patch_area, num_patches, -1)
x = x.transpose(1, 3).reshape(B * C * num_patch_h, num_patch_w,
patch_h, patch_w)
# [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W] # noqa
x = x.transpose(1, 2).reshape(B, C, num_patch_h * patch_h,
num_patch_w * patch_w)
if interpolate:
x = F.interpolate(
x, size=(H, W), mode='bilinear', align_corners=False)
x = self.conv_proj(x)
if self.conv_fusion is not None:
x = self.conv_fusion(torch.cat((shortcut, x), dim=1))
return x
@MODELS.register_module()
class MobileViT(BaseBackbone):
"""MobileViT backbone.
A PyTorch implementation of : `MobileViT: Light-weight, General-purpose,
and Mobile-friendly Vision Transformer <https://arxiv.org/pdf/2110.02178.pdf>`_
Modified from the `official repo
<https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/mobilevit.py>`_
and `timm
<https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/mobilevit.py>`_.
Args:
arch (str | List[list]): Architecture of MobileViT.
- If a string, choose from "small", "x_small" and "xx_small".
- If a list, every item should be also a list, and the first item
of the sub-list can be chosen from "moblienetv2" and "mobilevit",
which indicates the type of this layer sequence. If "mobilenetv2",
the other items are the arguments of :attr:`~MobileViT.make_mobilenetv2_layer`
(except ``in_channels``) and if "mobilevit", the other items are
the arguments of :attr:`~MobileViT.make_mobilevit_layer`
(except ``in_channels``).
Defaults to "small".
in_channels (int): Number of input image channels. Defaults to 3.
stem_channels (int): Channels of stem layer. Defaults to 16.
last_exp_factor (int): Channels expand factor of last layer.
Defaults to 4.
out_indices (Sequence[int]): Output from which stages.
Defaults to (4, ).
frozen_stages (int): Stages to be frozen (all param fixed).
Defaults to -1, which means not freezing any parameters.
conv_cfg (dict, optional): Config dict for convolution layer.
Defaults to None, which means using conv2d.
norm_cfg (dict, optional): Config dict for normalization layer.
Defaults to dict(type='BN').
act_cfg (dict, optional): Config dict for activation layer.
Defaults to dict(type='Swish').
init_cfg (dict, optional): Initialization config dict.
""" # noqa
# Parameters to build layers. The first param is the type of layer.
# For `mobilenetv2` layer, the rest params from left to right are:
# out channels, stride, num of blocks, expand_ratio.
# For `mobilevit` layer, the rest params from left to right are:
# out channels, stride, transformer_channels, ffn channels,
# num of transformer blocks, expand_ratio.
arch_settings = {
'small': [
['mobilenetv2', 32, 1, 1, 4],
['mobilenetv2', 64, 2, 3, 4],
['mobilevit', 96, 2, 144, 288, 2, 4],
['mobilevit', 128, 2, 192, 384, 4, 4],
['mobilevit', 160, 2, 240, 480, 3, 4],
],
'x_small': [
['mobilenetv2', 32, 1, 1, 4],
['mobilenetv2', 48, 2, 3, 4],
['mobilevit', 64, 2, 96, 192, 2, 4],
['mobilevit', 80, 2, 120, 240, 4, 4],
['mobilevit', 96, 2, 144, 288, 3, 4],
],
'xx_small': [
['mobilenetv2', 16, 1, 1, 2],
['mobilenetv2', 24, 2, 3, 2],
['mobilevit', 48, 2, 64, 128, 2, 2],
['mobilevit', 64, 2, 80, 160, 4, 2],
['mobilevit', 80, 2, 96, 192, 3, 2],
]
}
def __init__(self,
arch='small',
in_channels=3,
stem_channels=16,
last_exp_factor=4,
out_indices=(4, ),
frozen_stages=-1,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='Swish'),
init_cfg=[
dict(type='Kaiming', layer=['Conv2d']),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]):
super(MobileViT, self).__init__(init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in self.arch_settings, \
f'Unavailable arch, please choose from ' \
f'({set(self.arch_settings)}) or pass a list.'
arch = self.arch_settings[arch]
self.arch = arch
self.num_stages = len(arch)
# check out indices and frozen stages
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must by a sequence or int, ' \
f'get {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = self.num_stages + index
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
self.out_indices = out_indices
if frozen_stages not in range(-1, self.num_stages):
raise ValueError('frozen_stages must be in range(-1, '
f'{self.num_stages}). '
f'But received {frozen_stages}')
self.frozen_stages = frozen_stages
_make_layer_func = {
'mobilenetv2': self.make_mobilenetv2_layer,
'mobilevit': self.make_mobilevit_layer,
}
self.stem = ConvModule(
in_channels=in_channels,
out_channels=stem_channels,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
in_channels = stem_channels
layers = []
for i, layer_settings in enumerate(arch):
layer_type, settings = layer_settings[0], layer_settings[1:]
layer, out_channels = _make_layer_func[layer_type](in_channels,
*settings)
layers.append(layer)
in_channels = out_channels
self.layers = nn.Sequential(*layers)
self.conv_1x1_exp = ConvModule(
in_channels=in_channels,
out_channels=last_exp_factor * in_channels,
kernel_size=1,
stride=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
@staticmethod
def make_mobilevit_layer(in_channels,
out_channels,
stride,
transformer_dim,
ffn_dim,
num_transformer_blocks,
expand_ratio=4):
"""Build mobilevit layer, which consists of one InvertedResidual and
one MobileVitBlock.
Args:
in_channels (int): The input channels.
out_channels (int): The output channels.
stride (int): The stride of the first 3x3 convolution in the
``InvertedResidual`` layers.
transformer_dim (int): The channels of the transformer layers.
ffn_dim (int): The mid-channels of the feedforward network in
transformer layers.
num_transformer_blocks (int): The number of transformer blocks.
expand_ratio (int): adjusts number of channels of the hidden layer
in ``InvertedResidual`` by this amount. Defaults to 4.
"""
layer = []
layer.append(
InvertedResidual(
in_channels=in_channels,
out_channels=out_channels,
stride=stride,
expand_ratio=expand_ratio,
act_cfg=dict(type='Swish'),
))
layer.append(
MobileVitBlock(
in_channels=out_channels,
transformer_dim=transformer_dim,
ffn_dim=ffn_dim,
out_channels=out_channels,
num_transformer_blocks=num_transformer_blocks,
))
return nn.Sequential(*layer), out_channels
@staticmethod
def make_mobilenetv2_layer(in_channels,
out_channels,
stride,
num_blocks,
expand_ratio=4):
"""Build mobilenetv2 layer, which consists of several InvertedResidual
layers.
Args:
in_channels (int): The input channels.
out_channels (int): The output channels.
stride (int): The stride of the first 3x3 convolution in the
``InvertedResidual`` layers.
num_blocks (int): The number of ``InvertedResidual`` blocks.
expand_ratio (int): adjusts number of channels of the hidden layer
in ``InvertedResidual`` by this amount. Defaults to 4.
"""
layer = []
for i in range(num_blocks):
stride = stride if i == 0 else 1
layer.append(
InvertedResidual(
in_channels=in_channels,
out_channels=out_channels,
stride=stride,
expand_ratio=expand_ratio,
act_cfg=dict(type='Swish'),
))
in_channels = out_channels
return nn.Sequential(*layer), out_channels
def _freeze_stages(self):
for i in range(0, self.frozen_stages):
layer = self.layers[i]
layer.eval()
for param in layer.parameters():
param.requires_grad = False
def train(self, mode=True):
super(MobileViT, self).train(mode)
self._freeze_stages()
def forward(self, x):
x = self.stem(x)
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1:
x = self.conv_1x1_exp(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)

View File

@ -34,3 +34,4 @@ Import:
- configs/efficientformer/metafile.yml
- configs/swin_transformer_v2/metafile.yml
- configs/deit3/metafile.yml
- configs/mobilevit/metafile.yml

View File

@ -0,0 +1,86 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcls.models.backbones import MobileViT
def test_assertion():
with pytest.raises(AssertionError):
MobileViT(arch='unknown')
with pytest.raises(AssertionError):
# MobileViT out_indices should be valid depth.
MobileViT(out_indices=-100)
def test_mobilevit():
# Test forward
model = MobileViT(arch='small')
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 256, 256)
feat = model(imgs)
assert len(feat) == 1
assert feat[0].shape == torch.Size([1, 640, 8, 8])
# Test custom arch
model = MobileViT(arch=[
['mobilenetv2', 16, 1, 1, 2],
['mobilenetv2', 24, 2, 3, 2],
['mobilevit', 48, 2, 64, 128, 2, 2],
['mobilevit', 64, 2, 80, 160, 4, 2],
['mobilevit', 80, 2, 96, 192, 3, 2],
])
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 256, 256)
feat = model(imgs)
assert len(feat) == 1
assert feat[0].shape == torch.Size([1, 320, 8, 8])
# Test last_exp_factor
model = MobileViT(arch='small', last_exp_factor=8)
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 256, 256)
feat = model(imgs)
assert len(feat) == 1
assert feat[0].shape == torch.Size([1, 1280, 8, 8])
# Test stem_channels
model = MobileViT(arch='small', stem_channels=32)
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 256, 256)
feat = model(imgs)
assert len(feat) == 1
assert feat[0].shape == torch.Size([1, 640, 8, 8])
# Test forward with multiple outputs
model = MobileViT(arch='small', out_indices=range(5))
imgs = torch.randn(1, 3, 256, 256)
feat = model(imgs)
assert len(feat) == 5
assert feat[0].shape == torch.Size([1, 32, 128, 128])
assert feat[1].shape == torch.Size([1, 64, 64, 64])
assert feat[2].shape == torch.Size([1, 96, 32, 32])
assert feat[3].shape == torch.Size([1, 128, 16, 16])
assert feat[4].shape == torch.Size([1, 640, 8, 8])
# Test frozen_stages
model = MobileViT(arch='small', frozen_stages=2)
model.init_weights()
model.train()
for i in range(2):
assert not model.layers[i].training
for i in range(2, 5):
assert model.layers[i].training