mirror of
https://github.com/open-mmlab/mmpretrain.git
synced 2025-06-03 14:59:18 +08:00
Fix Mobilenetv3 structure and add pretrained model (#291)
* Refactor Mobilenetv3 structure and add ConvClsHead. * Change model's name from 'MobileNetv3' to 'MobileNetV3' * Modify configs for MobileNetV3 on CIFAR10. And add MobileNetV3 configs for imagenet * Fix activate setting bugs in MobileNetV3. And remove bias in SELayer. * Modify unittest * Remove useless config and file. * Fix mobilenetv3-large arch setting * Add dropout option in ConvClsHead * Fix MobilenetV3 structure according to torchvision version. 1. Remove with_expand_conv option in InvertedResidual, it should be decided by channels. 2. Revert activation function, should before SE layer. * Format code. * Rename MobilenetV3 arch "big" to "large". * Add mobilenetv3_small torchvision training recipe * Modify default `out_indices` of MobilenetV3, now it will change according to `arch` if not specified. * Add MobilenetV3 large config. * Add mobilenetv3 README * Modify InvertedResidual unit test. * Refactor ConvClsHead to StackedLinearClsHead, and add unit tests. * Add unit test for `simple_test` of `StackedLinearClsHead`. * Fix typo Co-authored-by: Yidi Shao <ydshao@smail.nju.edu.cn>
This commit is contained in:
parent
53c0df271f
commit
65410b05ad
14
configs/_base_/models/mobilenet_v3_large_imagenet.py
Normal file
14
configs/_base_/models/mobilenet_v3_large_imagenet.py
Normal file
@ -0,0 +1,14 @@
|
||||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='MobileNetV3', arch='large'),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
type='StackedLinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=960,
|
||||
mid_channels=[1280],
|
||||
dropout_rate=0.2,
|
||||
act_cfg=dict(type='HSwish'),
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, 5)))
|
13
configs/_base_/models/mobilenet_v3_small_cifar.py
Normal file
13
configs/_base_/models/mobilenet_v3_small_cifar.py
Normal file
@ -0,0 +1,13 @@
|
||||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='MobileNetV3', arch='small'),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
type='StackedLinearClsHead',
|
||||
num_classes=10,
|
||||
in_channels=576,
|
||||
mid_channels=[1280],
|
||||
act_cfg=dict(type='HSwish'),
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, 5)))
|
14
configs/_base_/models/mobilenet_v3_small_imagenet.py
Normal file
14
configs/_base_/models/mobilenet_v3_small_imagenet.py
Normal file
@ -0,0 +1,14 @@
|
||||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='MobileNetV3', arch='small'),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
type='StackedLinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=576,
|
||||
mid_channels=[1024],
|
||||
dropout_rate=0.2,
|
||||
act_cfg=dict(type='HSwish'),
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, 5)))
|
30
configs/mobilenet_v3/README.md
Normal file
30
configs/mobilenet_v3/README.md
Normal file
@ -0,0 +1,30 @@
|
||||
# Searching for MobileNetV3
|
||||
|
||||
## Introduction
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
```latex
|
||||
@inproceedings{Howard_2019_ICCV,
|
||||
author = {Howard, Andrew and Sandler, Mark and Chu, Grace and Chen, Liang-Chieh and Chen, Bo and Tan, Mingxing and Wang, Weijun and Zhu, Yukun and Pang, Ruoming and Vasudevan, Vijay and Le, Quoc V. and Adam, Hartwig},
|
||||
title = {Searching for MobileNetV3},
|
||||
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
|
||||
month = {October},
|
||||
year = {2019}
|
||||
}
|
||||
```
|
||||
|
||||
## Pretrain model
|
||||
|
||||
The pre-trained modles are converted from [torchvision](https://pytorch.org/vision/stable/_modules/torchvision/models/mobilenetv3.html).
|
||||
|
||||
### ImageNet
|
||||
|
||||
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Download |
|
||||
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:--------:|
|
||||
| MobileNetV3-Large | 5.48 | 0.23 | 74.04 | 91.34 | [model](https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_large-3ea3c186.pth)|
|
||||
| MobileNetV3-Small | 2.54 | 0.06 | 67.66 | 87.41 | [model](https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_small-8427ecf0.pth)|
|
||||
|
||||
## Results and models
|
||||
|
||||
Waiting for adding.
|
158
configs/mobilenet_v3/mobilenet_v3_large_imagenet.py
Normal file
158
configs/mobilenet_v3/mobilenet_v3_large_imagenet.py
Normal file
@ -0,0 +1,158 @@
|
||||
# Refer to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification
|
||||
# ----------------------------
|
||||
# -[x] auto_augment='imagenet'
|
||||
# -[x] batch_size=128 (per gpu)
|
||||
# -[x] epochs=600
|
||||
# -[x] opt='rmsprop'
|
||||
# -[x] lr=0.064
|
||||
# -[x] eps=0.0316
|
||||
# -[x] alpha=0.9
|
||||
# -[x] weight_decay=1e-05
|
||||
# -[x] momentum=0.9
|
||||
# -[x] lr_gamma=0.973
|
||||
# -[x] lr_step_size=2
|
||||
# -[x] nproc_per_node=8
|
||||
# -[x] random_erase=0.2
|
||||
# -[x] workers=16 (workers_per_gpu)
|
||||
# - modify: RandomErasing use RE-M instead of RE-0
|
||||
|
||||
_base_ = [
|
||||
'../_base_/models/mobilenet_v3_large_imagenet.py',
|
||||
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
|
||||
policies = [
|
||||
[
|
||||
dict(type='Posterize', bits=4, prob=0.4),
|
||||
dict(type='Rotate', angle=30., prob=0.6)
|
||||
],
|
||||
[
|
||||
dict(type='Solarize', thr=256 / 9 * 4, prob=0.6),
|
||||
dict(type='AutoContrast', prob=0.6)
|
||||
],
|
||||
[dict(type='Equalize', prob=0.8),
|
||||
dict(type='Equalize', prob=0.6)],
|
||||
[
|
||||
dict(type='Posterize', bits=5, prob=0.6),
|
||||
dict(type='Posterize', bits=5, prob=0.6)
|
||||
],
|
||||
[
|
||||
dict(type='Equalize', prob=0.4),
|
||||
dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)
|
||||
],
|
||||
[
|
||||
dict(type='Equalize', prob=0.4),
|
||||
dict(type='Rotate', angle=30 / 9 * 8, prob=0.8)
|
||||
],
|
||||
[
|
||||
dict(type='Solarize', thr=256 / 9 * 6, prob=0.6),
|
||||
dict(type='Equalize', prob=0.6)
|
||||
],
|
||||
[dict(type='Posterize', bits=6, prob=0.8),
|
||||
dict(type='Equalize', prob=1.)],
|
||||
[
|
||||
dict(type='Rotate', angle=10., prob=0.2),
|
||||
dict(type='Solarize', thr=256 / 9, prob=0.6)
|
||||
],
|
||||
[
|
||||
dict(type='Equalize', prob=0.6),
|
||||
dict(type='Posterize', bits=5, prob=0.4)
|
||||
],
|
||||
[
|
||||
dict(type='Rotate', angle=30 / 9 * 8, prob=0.8),
|
||||
dict(type='ColorTransform', magnitude=0., prob=0.4)
|
||||
],
|
||||
[
|
||||
dict(type='Rotate', angle=30., prob=0.4),
|
||||
dict(type='Equalize', prob=0.6)
|
||||
],
|
||||
[dict(type='Equalize', prob=0.0),
|
||||
dict(type='Equalize', prob=0.8)],
|
||||
[dict(type='Invert', prob=0.6),
|
||||
dict(type='Equalize', prob=1.)],
|
||||
[
|
||||
dict(type='ColorTransform', magnitude=0.4, prob=0.6),
|
||||
dict(type='Contrast', magnitude=0.8, prob=1.)
|
||||
],
|
||||
[
|
||||
dict(type='Rotate', angle=30 / 9 * 8, prob=0.8),
|
||||
dict(type='ColorTransform', magnitude=0.2, prob=1.)
|
||||
],
|
||||
[
|
||||
dict(type='ColorTransform', magnitude=0.8, prob=0.8),
|
||||
dict(type='Solarize', thr=256 / 9 * 2, prob=0.8)
|
||||
],
|
||||
[
|
||||
dict(type='Sharpness', magnitude=0.7, prob=0.4),
|
||||
dict(type='Invert', prob=0.6)
|
||||
],
|
||||
[
|
||||
dict(
|
||||
type='Shear',
|
||||
magnitude=0.3 / 9 * 5,
|
||||
prob=0.6,
|
||||
direction='horizontal'),
|
||||
dict(type='Equalize', prob=1.)
|
||||
],
|
||||
[
|
||||
dict(type='ColorTransform', magnitude=0., prob=0.4),
|
||||
dict(type='Equalize', prob=0.6)
|
||||
],
|
||||
[
|
||||
dict(type='Equalize', prob=0.4),
|
||||
dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)
|
||||
],
|
||||
[
|
||||
dict(type='Solarize', thr=256 / 9 * 4, prob=0.6),
|
||||
dict(type='AutoContrast', prob=0.6)
|
||||
],
|
||||
[dict(type='Invert', prob=0.6),
|
||||
dict(type='Equalize', prob=1.)],
|
||||
[
|
||||
dict(type='ColorTransform', magnitude=0.4, prob=0.6),
|
||||
dict(type='Contrast', magnitude=0.8, prob=1.)
|
||||
],
|
||||
[dict(type='Equalize', prob=0.8),
|
||||
dict(type='Equalize', prob=0.6)],
|
||||
]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='RandomResizedCrop', size=224, backend='pillow'),
|
||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||
dict(type='AutoAugment', policies=policies),
|
||||
dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=0.2,
|
||||
mode='const',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=img_norm_cfg['mean']),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='ToTensor', keys=['gt_label']),
|
||||
dict(type='Collect', keys=['img', 'gt_label'])
|
||||
]
|
||||
|
||||
data = dict(
|
||||
samples_per_gpu=128,
|
||||
workers_per_gpu=4,
|
||||
train=dict(pipeline=train_pipeline))
|
||||
evaluation = dict(interval=10, metric='accuracy')
|
||||
|
||||
# optimizer
|
||||
optimizer = dict(
|
||||
type='RMSprop',
|
||||
lr=0.064,
|
||||
alpha=0.9,
|
||||
momentum=0.9,
|
||||
eps=0.0316,
|
||||
weight_decay=1e-5)
|
||||
optimizer_config = dict(grad_clip=None)
|
||||
# learning policy
|
||||
lr_config = dict(policy='step', step=2, gamma=0.973, by_epoch=True)
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=600)
|
8
configs/mobilenet_v3/mobilenet_v3_small_cifar.py
Normal file
8
configs/mobilenet_v3/mobilenet_v3_small_cifar.py
Normal file
@ -0,0 +1,8 @@
|
||||
_base_ = [
|
||||
'../_base_/models/mobilenet_v3_small_cifar.py',
|
||||
'../_base_/datasets/cifar10_bs16.py',
|
||||
'../_base_/schedules/cifar10_bs128.py', '../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
lr_config = dict(policy='step', step=[120, 170])
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=200)
|
158
configs/mobilenet_v3/mobilenet_v3_small_imagenet.py
Normal file
158
configs/mobilenet_v3/mobilenet_v3_small_imagenet.py
Normal file
@ -0,0 +1,158 @@
|
||||
# Refer to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification
|
||||
# ----------------------------
|
||||
# -[x] auto_augment='imagenet'
|
||||
# -[x] batch_size=128 (per gpu)
|
||||
# -[x] epochs=600
|
||||
# -[x] opt='rmsprop'
|
||||
# -[x] lr=0.064
|
||||
# -[x] eps=0.0316
|
||||
# -[x] alpha=0.9
|
||||
# -[x] weight_decay=1e-05
|
||||
# -[x] momentum=0.9
|
||||
# -[x] lr_gamma=0.973
|
||||
# -[x] lr_step_size=2
|
||||
# -[x] nproc_per_node=8
|
||||
# -[x] random_erase=0.2
|
||||
# -[x] workers=16 (workers_per_gpu)
|
||||
# - modify: RandomErasing use RE-M instead of RE-0
|
||||
|
||||
_base_ = [
|
||||
'../_base_/models/mobilenet_v3_small_imagenet.py',
|
||||
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
|
||||
policies = [
|
||||
[
|
||||
dict(type='Posterize', bits=4, prob=0.4),
|
||||
dict(type='Rotate', angle=30., prob=0.6)
|
||||
],
|
||||
[
|
||||
dict(type='Solarize', thr=256 / 9 * 4, prob=0.6),
|
||||
dict(type='AutoContrast', prob=0.6)
|
||||
],
|
||||
[dict(type='Equalize', prob=0.8),
|
||||
dict(type='Equalize', prob=0.6)],
|
||||
[
|
||||
dict(type='Posterize', bits=5, prob=0.6),
|
||||
dict(type='Posterize', bits=5, prob=0.6)
|
||||
],
|
||||
[
|
||||
dict(type='Equalize', prob=0.4),
|
||||
dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)
|
||||
],
|
||||
[
|
||||
dict(type='Equalize', prob=0.4),
|
||||
dict(type='Rotate', angle=30 / 9 * 8, prob=0.8)
|
||||
],
|
||||
[
|
||||
dict(type='Solarize', thr=256 / 9 * 6, prob=0.6),
|
||||
dict(type='Equalize', prob=0.6)
|
||||
],
|
||||
[dict(type='Posterize', bits=6, prob=0.8),
|
||||
dict(type='Equalize', prob=1.)],
|
||||
[
|
||||
dict(type='Rotate', angle=10., prob=0.2),
|
||||
dict(type='Solarize', thr=256 / 9, prob=0.6)
|
||||
],
|
||||
[
|
||||
dict(type='Equalize', prob=0.6),
|
||||
dict(type='Posterize', bits=5, prob=0.4)
|
||||
],
|
||||
[
|
||||
dict(type='Rotate', angle=30 / 9 * 8, prob=0.8),
|
||||
dict(type='ColorTransform', magnitude=0., prob=0.4)
|
||||
],
|
||||
[
|
||||
dict(type='Rotate', angle=30., prob=0.4),
|
||||
dict(type='Equalize', prob=0.6)
|
||||
],
|
||||
[dict(type='Equalize', prob=0.0),
|
||||
dict(type='Equalize', prob=0.8)],
|
||||
[dict(type='Invert', prob=0.6),
|
||||
dict(type='Equalize', prob=1.)],
|
||||
[
|
||||
dict(type='ColorTransform', magnitude=0.4, prob=0.6),
|
||||
dict(type='Contrast', magnitude=0.8, prob=1.)
|
||||
],
|
||||
[
|
||||
dict(type='Rotate', angle=30 / 9 * 8, prob=0.8),
|
||||
dict(type='ColorTransform', magnitude=0.2, prob=1.)
|
||||
],
|
||||
[
|
||||
dict(type='ColorTransform', magnitude=0.8, prob=0.8),
|
||||
dict(type='Solarize', thr=256 / 9 * 2, prob=0.8)
|
||||
],
|
||||
[
|
||||
dict(type='Sharpness', magnitude=0.7, prob=0.4),
|
||||
dict(type='Invert', prob=0.6)
|
||||
],
|
||||
[
|
||||
dict(
|
||||
type='Shear',
|
||||
magnitude=0.3 / 9 * 5,
|
||||
prob=0.6,
|
||||
direction='horizontal'),
|
||||
dict(type='Equalize', prob=1.)
|
||||
],
|
||||
[
|
||||
dict(type='ColorTransform', magnitude=0., prob=0.4),
|
||||
dict(type='Equalize', prob=0.6)
|
||||
],
|
||||
[
|
||||
dict(type='Equalize', prob=0.4),
|
||||
dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)
|
||||
],
|
||||
[
|
||||
dict(type='Solarize', thr=256 / 9 * 4, prob=0.6),
|
||||
dict(type='AutoContrast', prob=0.6)
|
||||
],
|
||||
[dict(type='Invert', prob=0.6),
|
||||
dict(type='Equalize', prob=1.)],
|
||||
[
|
||||
dict(type='ColorTransform', magnitude=0.4, prob=0.6),
|
||||
dict(type='Contrast', magnitude=0.8, prob=1.)
|
||||
],
|
||||
[dict(type='Equalize', prob=0.8),
|
||||
dict(type='Equalize', prob=0.6)],
|
||||
]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='RandomResizedCrop', size=224, backend='pillow'),
|
||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||
dict(type='AutoAugment', policies=policies),
|
||||
dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=0.2,
|
||||
mode='const',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=img_norm_cfg['mean']),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='ToTensor', keys=['gt_label']),
|
||||
dict(type='Collect', keys=['img', 'gt_label'])
|
||||
]
|
||||
|
||||
data = dict(
|
||||
samples_per_gpu=128,
|
||||
workers_per_gpu=4,
|
||||
train=dict(pipeline=train_pipeline))
|
||||
evaluation = dict(interval=10, metric='accuracy')
|
||||
|
||||
# optimizer
|
||||
optimizer = dict(
|
||||
type='RMSprop',
|
||||
lr=0.064,
|
||||
alpha=0.9,
|
||||
momentum=0.9,
|
||||
eps=0.0316,
|
||||
weight_decay=1e-5)
|
||||
optimizer_config = dict(grad_clip=None)
|
||||
# learning policy
|
||||
lr_config = dict(policy='step', step=2, gamma=0.973, by_epoch=True)
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=600)
|
@ -1,7 +1,7 @@
|
||||
from .alexnet import AlexNet
|
||||
from .lenet import LeNet5
|
||||
from .mobilenet_v2 import MobileNetV2
|
||||
from .mobilenet_v3 import MobileNetv3
|
||||
from .mobilenet_v3 import MobileNetV3
|
||||
from .regnet import RegNet
|
||||
from .resnest import ResNeSt
|
||||
from .resnet import ResNet, ResNetV1d
|
||||
@ -17,5 +17,5 @@ from .vision_transformer import VisionTransformer
|
||||
__all__ = [
|
||||
'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
|
||||
'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
|
||||
'ShuffleNetV2', 'MobileNetV2', 'MobileNetv3', 'VisionTransformer'
|
||||
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer'
|
||||
]
|
||||
|
@ -7,18 +7,18 @@ from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class MobileNetv3(BaseBackbone):
|
||||
"""MobileNetv3 backbone.
|
||||
class MobileNetV3(BaseBackbone):
|
||||
"""MobileNetV3 backbone.
|
||||
|
||||
Args:
|
||||
arch (str): Architechture of mobilnetv3, from {small, big}.
|
||||
arch (str): Architechture of mobilnetv3, from {small, large}.
|
||||
Default: small.
|
||||
conv_cfg (dict, optional): Config dict for convolution layer.
|
||||
Default: None, which means using conv2d.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
out_indices (None or Sequence[int]): Output from which stages.
|
||||
Default: (10, ), which means output tensors from final stage.
|
||||
Default: None, which means output tensors from final stage.
|
||||
frozen_stages (int): Stages to be frozen (all param fixed).
|
||||
Defualt: -1, which means not freezing any parameters.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
@ -42,49 +42,54 @@ class MobileNetv3(BaseBackbone):
|
||||
[5, 288, 96, True, 'HSwish', 2],
|
||||
[5, 576, 96, True, 'HSwish', 1],
|
||||
[5, 576, 96, True, 'HSwish', 1]],
|
||||
'big': [[3, 16, 16, False, 'ReLU', 1],
|
||||
[3, 64, 24, False, 'ReLU', 2],
|
||||
[3, 72, 24, False, 'ReLU', 1],
|
||||
[5, 72, 40, True, 'ReLU', 2],
|
||||
[5, 120, 40, True, 'ReLU', 1],
|
||||
[5, 120, 40, True, 'ReLU', 1],
|
||||
[3, 240, 80, False, 'HSwish', 2],
|
||||
[3, 200, 80, False, 'HSwish', 1],
|
||||
[3, 184, 80, False, 'HSwish', 1],
|
||||
[3, 184, 80, False, 'HSwish', 1],
|
||||
[3, 480, 112, True, 'HSwish', 1],
|
||||
[3, 672, 112, True, 'HSwish', 1],
|
||||
[5, 672, 160, True, 'HSwish', 1],
|
||||
[5, 672, 160, True, 'HSwish', 2],
|
||||
[5, 960, 160, True, 'HSwish', 1]]
|
||||
'large': [[3, 16, 16, False, 'ReLU', 1],
|
||||
[3, 64, 24, False, 'ReLU', 2],
|
||||
[3, 72, 24, False, 'ReLU', 1],
|
||||
[5, 72, 40, True, 'ReLU', 2],
|
||||
[5, 120, 40, True, 'ReLU', 1],
|
||||
[5, 120, 40, True, 'ReLU', 1],
|
||||
[3, 240, 80, False, 'HSwish', 2],
|
||||
[3, 200, 80, False, 'HSwish', 1],
|
||||
[3, 184, 80, False, 'HSwish', 1],
|
||||
[3, 184, 80, False, 'HSwish', 1],
|
||||
[3, 480, 112, True, 'HSwish', 1],
|
||||
[3, 672, 112, True, 'HSwish', 1],
|
||||
[5, 672, 160, True, 'HSwish', 2],
|
||||
[5, 960, 160, True, 'HSwish', 1],
|
||||
[5, 960, 160, True, 'HSwish', 1]]
|
||||
} # yapf: disable
|
||||
|
||||
def __init__(self,
|
||||
arch='small',
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
out_indices=(10, ),
|
||||
norm_cfg=dict(type='BN', eps=0.001, momentum=0.01),
|
||||
out_indices=None,
|
||||
frozen_stages=-1,
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
init_cfg=[
|
||||
dict(type='Kaiming', layer=['Conv2d']),
|
||||
dict(type='Constant', val=1, layer=['BatchNorm2d'])
|
||||
dict(
|
||||
type='Kaiming',
|
||||
layer=['Conv2d'],
|
||||
nonlinearity='leaky_relu'),
|
||||
dict(type='Normal', layer=['Linear'], std=0.01),
|
||||
dict(type='Constant', layer=['BatchNorm2d'], val=1)
|
||||
]):
|
||||
super(MobileNetv3, self).__init__(init_cfg)
|
||||
super(MobileNetV3, self).__init__(init_cfg)
|
||||
assert arch in self.arch_settings
|
||||
for index in out_indices:
|
||||
if index not in range(0, len(self.arch_settings[arch])):
|
||||
raise ValueError('the item in out_indices must in '
|
||||
f'range(0, {len(self.arch_settings[arch])}). '
|
||||
f'But received {index}')
|
||||
if out_indices is None:
|
||||
out_indices = (12, ) if arch == 'small' else (16, )
|
||||
for order, index in enumerate(out_indices):
|
||||
if index not in range(0, len(self.arch_settings[arch]) + 2):
|
||||
raise ValueError(
|
||||
'the item in out_indices must in '
|
||||
f'range(0, {len(self.arch_settings[arch]) + 2}). '
|
||||
f'But received {index}')
|
||||
|
||||
if frozen_stages not in range(-1, len(self.arch_settings[arch])):
|
||||
if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2):
|
||||
raise ValueError('frozen_stages must be in range(-1, '
|
||||
f'{len(self.arch_settings[arch])}). '
|
||||
f'{len(self.arch_settings[arch]) + 2}). '
|
||||
f'But received {frozen_stages}')
|
||||
self.out_indices = out_indices
|
||||
self.frozen_stages = frozen_stages
|
||||
self.arch = arch
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
@ -93,23 +98,26 @@ class MobileNetv3(BaseBackbone):
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
|
||||
self.in_channels = 16
|
||||
self.conv1 = ConvModule(
|
||||
in_channels=3,
|
||||
out_channels=self.in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=dict(type='HSwish'))
|
||||
|
||||
self.layers = self._make_layer()
|
||||
self.feat_dim = self.arch_settings[arch][-1][2]
|
||||
self.feat_dim = self.arch_settings[arch][-1][1]
|
||||
|
||||
def _make_layer(self):
|
||||
layers = []
|
||||
layer_setting = self.arch_settings[self.arch]
|
||||
in_channels = 16
|
||||
|
||||
layer = ConvModule(
|
||||
in_channels=3,
|
||||
out_channels=in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=dict(type='HSwish'))
|
||||
self.add_module('layer0', layer)
|
||||
layers.append('layer0')
|
||||
|
||||
for i, params in enumerate(layer_setting):
|
||||
(kernel_size, mid_channels, out_channels, with_se, act,
|
||||
stride) = params
|
||||
@ -117,31 +125,50 @@ class MobileNetv3(BaseBackbone):
|
||||
se_cfg = dict(
|
||||
channels=mid_channels,
|
||||
ratio=4,
|
||||
act_cfg=(dict(type='ReLU'), dict(type='HSigmoid')))
|
||||
act_cfg=(dict(type='ReLU'),
|
||||
dict(
|
||||
type='HSigmoid',
|
||||
bias=3,
|
||||
divisor=6,
|
||||
min_value=0,
|
||||
max_value=1)))
|
||||
else:
|
||||
se_cfg = None
|
||||
|
||||
layer = InvertedResidual(
|
||||
in_channels=self.in_channels,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
mid_channels=mid_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
se_cfg=se_cfg,
|
||||
with_expand_conv=True,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=dict(type=act),
|
||||
with_cp=self.with_cp)
|
||||
self.in_channels = out_channels
|
||||
in_channels = out_channels
|
||||
layer_name = 'layer{}'.format(i + 1)
|
||||
self.add_module(layer_name, layer)
|
||||
layers.append(layer_name)
|
||||
|
||||
# Build the last layer before pooling
|
||||
# TODO: No dilation
|
||||
layer = ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=576 if self.arch == 'small' else 960,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=dict(type='HSwish'))
|
||||
layer_name = 'layer{}'.format(len(layer_setting) + 1)
|
||||
self.add_module(layer_name, layer)
|
||||
layers.append(layer_name)
|
||||
|
||||
return layers
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
|
||||
outs = []
|
||||
for i, layer_name in enumerate(self.layers):
|
||||
layer = getattr(self, layer_name)
|
||||
@ -155,17 +182,14 @@ class MobileNetv3(BaseBackbone):
|
||||
return tuple(outs)
|
||||
|
||||
def _freeze_stages(self):
|
||||
if self.frozen_stages >= 0:
|
||||
for param in self.conv1.parameters():
|
||||
param.requires_grad = False
|
||||
for i in range(1, self.frozen_stages + 1):
|
||||
for i in range(0, self.frozen_stages + 1):
|
||||
layer = getattr(self, f'layer{i}')
|
||||
layer.eval()
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def train(self, mode=True):
|
||||
super(MobileNetv3, self).train(mode)
|
||||
super(MobileNetV3, self).train(mode)
|
||||
self._freeze_stages()
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
|
@ -2,9 +2,10 @@ from .cls_head import ClsHead
|
||||
from .linear_head import LinearClsHead
|
||||
from .multi_label_head import MultiLabelClsHead
|
||||
from .multi_label_linear_head import MultiLabelLinearClsHead
|
||||
from .stacked_head import StackedLinearClsHead
|
||||
from .vision_transformer_head import VisionTransformerClsHead
|
||||
|
||||
__all__ = [
|
||||
'ClsHead', 'LinearClsHead', 'MultiLabelClsHead', 'MultiLabelLinearClsHead',
|
||||
'VisionTransformerClsHead'
|
||||
'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead',
|
||||
'MultiLabelLinearClsHead', 'VisionTransformerClsHead'
|
||||
]
|
||||
|
135
mmcls/models/heads/stacked_head.py
Normal file
135
mmcls/models/heads/stacked_head.py
Normal file
@ -0,0 +1,135 @@
|
||||
from typing import Dict, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import build_activation_layer, build_norm_layer
|
||||
from mmcv.runner import BaseModule, ModuleList
|
||||
|
||||
from ..builder import HEADS
|
||||
from .cls_head import ClsHead
|
||||
|
||||
|
||||
class LinearBlock(BaseModule):
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
dropout_rate=0.,
|
||||
norm_cfg=None,
|
||||
act_cfg=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.fc = nn.Linear(in_channels, out_channels)
|
||||
|
||||
self.norm = None
|
||||
self.act = None
|
||||
self.dropout = None
|
||||
|
||||
if norm_cfg is not None:
|
||||
self.norm = build_norm_layer(norm_cfg, out_channels)[1]
|
||||
if act_cfg is not None:
|
||||
self.act = build_activation_layer(act_cfg)
|
||||
if dropout_rate > 0:
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc(x)
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
if self.act is not None:
|
||||
x = self.act(x)
|
||||
if self.dropout is not None:
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
class StackedLinearClsHead(ClsHead):
|
||||
"""Classifier head with several hidden fc layer and a output fc layer.
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of categories excluding the background
|
||||
category.
|
||||
in_channels (int): Number of channels in the input feature map.
|
||||
mid_channels (Sequence): Number of channels in the hidden fc layers.
|
||||
dropout_rate (float): Dropout rate after each hidden fc layer,
|
||||
except the last layer. Defaults to 0.
|
||||
norm_cfg (dict, optional): Config dict of normalization layer after
|
||||
each hidden fc layer, except the last layer. Defaults to None.
|
||||
act_cfg (dict, optional): Config dict of activation function after each
|
||||
hidden layer, except the last layer. Defaults to use "ReLU".
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_classes: int,
|
||||
in_channels: int,
|
||||
mid_channels: Sequence,
|
||||
dropout_rate: float = 0.,
|
||||
norm_cfg: Dict = None,
|
||||
act_cfg: Dict = dict(type='ReLU'),
|
||||
**kwargs):
|
||||
super(StackedLinearClsHead, self).__init__(**kwargs)
|
||||
assert num_classes > 0, \
|
||||
f'`num_classes` of StackedLinearClsHead must be a positive ' \
|
||||
f'integer, got {num_classes} instead.'
|
||||
self.num_classes = num_classes
|
||||
|
||||
self.in_channels = in_channels
|
||||
|
||||
assert isinstance(mid_channels, Sequence), \
|
||||
f'`mid_channels` of StackedLinearClsHead should be a sequence, ' \
|
||||
f'instead of {type(mid_channels)}'
|
||||
self.mid_channels = mid_channels
|
||||
|
||||
self.dropout_rate = dropout_rate
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
|
||||
self._init_layers()
|
||||
|
||||
def _init_layers(self):
|
||||
self.layers = ModuleList(
|
||||
init_cfg=dict(
|
||||
type='Normal', layer='Linear', mean=0., std=0.01, bias=0.))
|
||||
in_channels = self.in_channels
|
||||
for hidden_channels in self.mid_channels:
|
||||
self.layers.append(
|
||||
LinearBlock(
|
||||
in_channels,
|
||||
hidden_channels,
|
||||
dropout_rate=self.dropout_rate,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
in_channels = hidden_channels
|
||||
|
||||
self.layers.append(
|
||||
LinearBlock(
|
||||
self.mid_channels[-1],
|
||||
self.num_classes,
|
||||
dropout_rate=0.,
|
||||
norm_cfg=None,
|
||||
act_cfg=None))
|
||||
|
||||
def init_weights(self):
|
||||
self.layers.init_weights()
|
||||
|
||||
def simple_test(self, img):
|
||||
"""Test without augmentation."""
|
||||
cls_score = img
|
||||
for layer in self.layers:
|
||||
cls_score = layer(cls_score)
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||
if torch.onnx.is_in_onnx_export():
|
||||
return pred
|
||||
pred = list(pred.detach().cpu().numpy())
|
||||
return pred
|
||||
|
||||
def forward_train(self, x, gt_label):
|
||||
cls_score = x
|
||||
for layer in self.layers:
|
||||
cls_score = layer(cls_score)
|
||||
losses = self.loss(cls_score, gt_label)
|
||||
return losses
|
@ -18,9 +18,6 @@ class InvertedResidual(BaseModule):
|
||||
stride (int): The stride of the depthwise convolution. Default: 1.
|
||||
se_cfg (dict): Config dict for se layer. Defaul: None, which means no
|
||||
se layer.
|
||||
with_expand_conv (bool): Use expand conv or not. If set False,
|
||||
mid_channels must be the same with in_channels.
|
||||
Default: True.
|
||||
conv_cfg (dict): Config dict for convolution layer. Default: None,
|
||||
which means using conv2d.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
@ -41,7 +38,6 @@ class InvertedResidual(BaseModule):
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
se_cfg=None,
|
||||
with_expand_conv=True,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
@ -52,12 +48,10 @@ class InvertedResidual(BaseModule):
|
||||
assert stride in [1, 2]
|
||||
self.with_cp = with_cp
|
||||
self.with_se = se_cfg is not None
|
||||
self.with_expand_conv = with_expand_conv
|
||||
self.with_expand_conv = (mid_channels != in_channels)
|
||||
|
||||
if self.with_se:
|
||||
assert isinstance(se_cfg, dict)
|
||||
if not self.with_expand_conv:
|
||||
assert mid_channels == in_channels
|
||||
|
||||
if self.with_expand_conv:
|
||||
self.expand_conv = ConvModule(
|
||||
@ -89,7 +83,7 @@ class InvertedResidual(BaseModule):
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
act_cfg=None)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
|
@ -3,6 +3,8 @@ import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmcv.runner import BaseModule
|
||||
|
||||
from .make_divisible import make_divisible
|
||||
|
||||
|
||||
# class SELayer(nn.Module):
|
||||
class SELayer(BaseModule):
|
||||
@ -25,6 +27,7 @@ class SELayer(BaseModule):
|
||||
def __init__(self,
|
||||
channels,
|
||||
ratio=16,
|
||||
bias='auto',
|
||||
conv_cfg=None,
|
||||
act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')),
|
||||
init_cfg=None):
|
||||
@ -34,18 +37,21 @@ class SELayer(BaseModule):
|
||||
assert len(act_cfg) == 2
|
||||
assert mmcv.is_tuple_of(act_cfg, dict)
|
||||
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
||||
squeeze_channels = make_divisible(channels // ratio, 8)
|
||||
self.conv1 = ConvModule(
|
||||
in_channels=channels,
|
||||
out_channels=int(channels / ratio),
|
||||
out_channels=squeeze_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
bias=bias,
|
||||
conv_cfg=conv_cfg,
|
||||
act_cfg=act_cfg[0])
|
||||
self.conv2 = ConvModule(
|
||||
in_channels=int(channels / ratio),
|
||||
in_channels=squeeze_channels,
|
||||
out_channels=channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
bias=bias,
|
||||
conv_cfg=conv_cfg,
|
||||
act_cfg=act_cfg[1])
|
||||
|
||||
|
@ -3,7 +3,7 @@ import torch
|
||||
from torch.nn.modules import GroupNorm
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import MobileNetv3
|
||||
from mmcls.models.backbones import MobileNetV3
|
||||
from mmcls.models.utils import InvertedResidual
|
||||
|
||||
|
||||
@ -26,42 +26,40 @@ def check_norm_state(modules, train_state):
|
||||
def test_mobilenetv3_backbone():
|
||||
with pytest.raises(TypeError):
|
||||
# pretrained must be a string path
|
||||
model = MobileNetv3()
|
||||
model = MobileNetV3()
|
||||
model.init_weights(pretrained=0)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# arch must in [small, big]
|
||||
MobileNetv3(arch='others')
|
||||
# arch must in [small, large]
|
||||
MobileNetV3(arch='others')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# frozen_stages must less than 12 when arch is small
|
||||
MobileNetv3(arch='small', frozen_stages=12)
|
||||
# frozen_stages must less than 13 when arch is small
|
||||
MobileNetV3(arch='small', frozen_stages=13)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# frozen_stages must less than 16 when arch is big
|
||||
MobileNetv3(arch='big', frozen_stages=16)
|
||||
# frozen_stages must less than 17 when arch is large
|
||||
MobileNetV3(arch='large', frozen_stages=17)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# max out_indices must less than 11 when arch is small
|
||||
MobileNetv3(arch='small', out_indices=(11, ))
|
||||
# max out_indices must less than 13 when arch is small
|
||||
MobileNetV3(arch='small', out_indices=(13, ))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# max out_indices must less than 15 when arch is big
|
||||
MobileNetv3(arch='big', out_indices=(15, ))
|
||||
# max out_indices must less than 17 when arch is large
|
||||
MobileNetV3(arch='large', out_indices=(17, ))
|
||||
|
||||
# Test MobileNetv3
|
||||
model = MobileNetv3()
|
||||
# Test MobileNetV3
|
||||
model = MobileNetV3()
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
# Test MobileNetv3 with first stage frozen
|
||||
# Test MobileNetV3 with first stage frozen
|
||||
frozen_stages = 1
|
||||
model = MobileNetv3(frozen_stages=frozen_stages)
|
||||
model = MobileNetV3(frozen_stages=frozen_stages)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
for param in model.conv1.parameters():
|
||||
assert param.requires_grad is False
|
||||
for i in range(1, frozen_stages + 1):
|
||||
for i in range(0, frozen_stages + 1):
|
||||
layer = getattr(model, f'layer{i}')
|
||||
for mod in layer.modules():
|
||||
if isinstance(mod, _BatchNorm):
|
||||
@ -69,35 +67,37 @@ def test_mobilenetv3_backbone():
|
||||
for param in layer.parameters():
|
||||
assert param.requires_grad is False
|
||||
|
||||
# Test MobileNetv3 with norm eval
|
||||
model = MobileNetv3(norm_eval=True, out_indices=range(0, 11))
|
||||
# Test MobileNetV3 with norm eval
|
||||
model = MobileNetV3(norm_eval=True, out_indices=range(0, 12))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
assert check_norm_state(model.modules(), False)
|
||||
|
||||
# Test MobileNetv3 forward with small arch
|
||||
model = MobileNetv3(out_indices=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
|
||||
# Test MobileNetV3 forward with small arch
|
||||
model = MobileNetV3(out_indices=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 11
|
||||
assert feat[0].shape == torch.Size([1, 16, 56, 56])
|
||||
assert feat[1].shape == torch.Size([1, 24, 28, 28])
|
||||
assert len(feat) == 13
|
||||
assert feat[0].shape == torch.Size([1, 16, 112, 112])
|
||||
assert feat[1].shape == torch.Size([1, 16, 56, 56])
|
||||
assert feat[2].shape == torch.Size([1, 24, 28, 28])
|
||||
assert feat[3].shape == torch.Size([1, 40, 14, 14])
|
||||
assert feat[3].shape == torch.Size([1, 24, 28, 28])
|
||||
assert feat[4].shape == torch.Size([1, 40, 14, 14])
|
||||
assert feat[5].shape == torch.Size([1, 40, 14, 14])
|
||||
assert feat[6].shape == torch.Size([1, 48, 14, 14])
|
||||
assert feat[6].shape == torch.Size([1, 40, 14, 14])
|
||||
assert feat[7].shape == torch.Size([1, 48, 14, 14])
|
||||
assert feat[8].shape == torch.Size([1, 96, 7, 7])
|
||||
assert feat[8].shape == torch.Size([1, 48, 14, 14])
|
||||
assert feat[9].shape == torch.Size([1, 96, 7, 7])
|
||||
assert feat[10].shape == torch.Size([1, 96, 7, 7])
|
||||
assert feat[11].shape == torch.Size([1, 96, 7, 7])
|
||||
assert feat[12].shape == torch.Size([1, 576, 7, 7])
|
||||
|
||||
# Test MobileNetv3 forward with small arch and GroupNorm
|
||||
model = MobileNetv3(
|
||||
out_indices=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
|
||||
# Test MobileNetV3 forward with small arch and GroupNorm
|
||||
model = MobileNetV3(
|
||||
out_indices=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12),
|
||||
norm_cfg=dict(type='GN', num_groups=2, requires_grad=True))
|
||||
for m in model.modules():
|
||||
if is_norm(m):
|
||||
@ -107,47 +107,51 @@ def test_mobilenetv3_backbone():
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 11
|
||||
assert feat[0].shape == torch.Size([1, 16, 56, 56])
|
||||
assert feat[1].shape == torch.Size([1, 24, 28, 28])
|
||||
assert len(feat) == 13
|
||||
assert feat[0].shape == torch.Size([1, 16, 112, 112])
|
||||
assert feat[1].shape == torch.Size([1, 16, 56, 56])
|
||||
assert feat[2].shape == torch.Size([1, 24, 28, 28])
|
||||
assert feat[3].shape == torch.Size([1, 40, 14, 14])
|
||||
assert feat[3].shape == torch.Size([1, 24, 28, 28])
|
||||
assert feat[4].shape == torch.Size([1, 40, 14, 14])
|
||||
assert feat[5].shape == torch.Size([1, 40, 14, 14])
|
||||
assert feat[6].shape == torch.Size([1, 48, 14, 14])
|
||||
assert feat[6].shape == torch.Size([1, 40, 14, 14])
|
||||
assert feat[7].shape == torch.Size([1, 48, 14, 14])
|
||||
assert feat[8].shape == torch.Size([1, 96, 7, 7])
|
||||
assert feat[8].shape == torch.Size([1, 48, 14, 14])
|
||||
assert feat[9].shape == torch.Size([1, 96, 7, 7])
|
||||
assert feat[10].shape == torch.Size([1, 96, 7, 7])
|
||||
assert feat[11].shape == torch.Size([1, 96, 7, 7])
|
||||
assert feat[12].shape == torch.Size([1, 576, 7, 7])
|
||||
|
||||
# Test MobileNetv3 forward with big arch
|
||||
model = MobileNetv3(
|
||||
arch='big',
|
||||
out_indices=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14))
|
||||
# Test MobileNetV3 forward with large arch
|
||||
model = MobileNetV3(
|
||||
arch='large',
|
||||
out_indices=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 15
|
||||
assert len(feat) == 17
|
||||
assert feat[0].shape == torch.Size([1, 16, 112, 112])
|
||||
assert feat[1].shape == torch.Size([1, 24, 56, 56])
|
||||
assert feat[1].shape == torch.Size([1, 16, 112, 112])
|
||||
assert feat[2].shape == torch.Size([1, 24, 56, 56])
|
||||
assert feat[3].shape == torch.Size([1, 40, 28, 28])
|
||||
assert feat[3].shape == torch.Size([1, 24, 56, 56])
|
||||
assert feat[4].shape == torch.Size([1, 40, 28, 28])
|
||||
assert feat[5].shape == torch.Size([1, 40, 28, 28])
|
||||
assert feat[6].shape == torch.Size([1, 80, 14, 14])
|
||||
assert feat[6].shape == torch.Size([1, 40, 28, 28])
|
||||
assert feat[7].shape == torch.Size([1, 80, 14, 14])
|
||||
assert feat[8].shape == torch.Size([1, 80, 14, 14])
|
||||
assert feat[9].shape == torch.Size([1, 80, 14, 14])
|
||||
assert feat[10].shape == torch.Size([1, 112, 14, 14])
|
||||
assert feat[10].shape == torch.Size([1, 80, 14, 14])
|
||||
assert feat[11].shape == torch.Size([1, 112, 14, 14])
|
||||
assert feat[12].shape == torch.Size([1, 160, 14, 14])
|
||||
assert feat[12].shape == torch.Size([1, 112, 14, 14])
|
||||
assert feat[13].shape == torch.Size([1, 160, 7, 7])
|
||||
assert feat[14].shape == torch.Size([1, 160, 7, 7])
|
||||
assert feat[15].shape == torch.Size([1, 160, 7, 7])
|
||||
assert feat[16].shape == torch.Size([1, 960, 7, 7])
|
||||
|
||||
# Test MobileNetv3 forward with big arch
|
||||
model = MobileNetv3(arch='big', out_indices=(0, ))
|
||||
# Test MobileNetV3 forward with large arch
|
||||
model = MobileNetV3(arch='large', out_indices=(0, ))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
@ -155,8 +159,8 @@ def test_mobilenetv3_backbone():
|
||||
feat = model(imgs)
|
||||
assert feat.shape == torch.Size([1, 16, 112, 112])
|
||||
|
||||
# Test MobileNetv3 with checkpoint forward
|
||||
model = MobileNetv3(with_cp=True)
|
||||
# Test MobileNetV3 with checkpoint forward
|
||||
model = MobileNetV3(with_cp=True)
|
||||
for m in model.modules():
|
||||
if isinstance(m, InvertedResidual):
|
||||
assert m.with_cp
|
||||
@ -165,4 +169,4 @@ def test_mobilenetv3_backbone():
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat.shape == torch.Size([1, 96, 7, 7])
|
||||
assert feat.shape == torch.Size([1, 576, 7, 7])
|
||||
|
@ -57,10 +57,9 @@ def test_inverted_residual():
|
||||
# se_cfg must be None or dict
|
||||
InvertedResidual(16, 16, 32, se_cfg=list())
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# in_channeld and out_channels must be the same if
|
||||
# with_expand_conv is False
|
||||
InvertedResidual(16, 16, 32, with_expand_conv=False)
|
||||
# Add expand conv if in_channels and mid_channels is not the same
|
||||
assert InvertedResidual(32, 16, 32).with_expand_conv is False
|
||||
assert InvertedResidual(16, 16, 32).with_expand_conv is True
|
||||
|
||||
# Test InvertedResidual forward, stride=1
|
||||
block = InvertedResidual(16, 16, 32, stride=1)
|
||||
@ -85,8 +84,8 @@ def test_inverted_residual():
|
||||
assert isinstance(block.se, SELayer)
|
||||
assert x_out.shape == torch.Size((1, 16, 56, 56))
|
||||
|
||||
# Test InvertedResidual forward, with_expand_conv=False
|
||||
block = InvertedResidual(32, 16, 32, with_expand_conv=False)
|
||||
# Test InvertedResidual forward without expand conv
|
||||
block = InvertedResidual(32, 16, 32)
|
||||
x = torch.randn(1, 32, 56, 56)
|
||||
x_out = block(x)
|
||||
assert getattr(block, 'expand_conv', None) is None
|
||||
|
@ -1,7 +1,10 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcls.models.heads import (ClsHead, LinearClsHead, MultiLabelClsHead,
|
||||
MultiLabelLinearClsHead)
|
||||
MultiLabelLinearClsHead, StackedLinearClsHead)
|
||||
|
||||
|
||||
def test_cls_head():
|
||||
@ -48,3 +51,47 @@ def test_multilabel_linear_head():
|
||||
head.init_weights()
|
||||
losses = head.loss(fake_cls_score, fake_gt_label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
|
||||
def test_stacked_linear_cls_head():
|
||||
# test assertion
|
||||
with pytest.raises(AssertionError):
|
||||
StackedLinearClsHead(num_classes=3, in_channels=5, mid_channels=10)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
StackedLinearClsHead(num_classes=-1, in_channels=5, mid_channels=[10])
|
||||
|
||||
fake_img = torch.rand(4, 5) # B, channel
|
||||
fake_gt_label = torch.randint(0, 2, (4, )) # B, num_classes
|
||||
|
||||
# test forward with default setting
|
||||
head = StackedLinearClsHead(
|
||||
num_classes=3, in_channels=5, mid_channels=[10])
|
||||
head.init_weights()
|
||||
|
||||
losses = head.forward_train(fake_img, fake_gt_label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
# test simple test
|
||||
pred = head.simple_test(fake_img)
|
||||
assert len(pred) == 4
|
||||
|
||||
# test simple test in tracing
|
||||
p = patch('torch.onnx.is_in_onnx_export', lambda: True)
|
||||
p.start()
|
||||
pred = head.simple_test(fake_img)
|
||||
assert pred.shape == torch.Size((4, 3))
|
||||
p.stop()
|
||||
|
||||
# test forward with full function
|
||||
head = StackedLinearClsHead(
|
||||
num_classes=3,
|
||||
in_channels=5,
|
||||
mid_channels=[8, 10],
|
||||
dropout_rate=0.2,
|
||||
norm_cfg=dict(type='BN1d'),
|
||||
act_cfg=dict(type='HSwish'))
|
||||
head.init_weights()
|
||||
|
||||
losses = head.forward_train(fake_img, fake_gt_label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user