mirror of
https://github.com/open-mmlab/mmpretrain.git
synced 2025-06-03 14:59:18 +08:00
Fix Mobilenetv3 structure and add pretrained model (#291)
* Refactor Mobilenetv3 structure and add ConvClsHead. * Change model's name from 'MobileNetv3' to 'MobileNetV3' * Modify configs for MobileNetV3 on CIFAR10. And add MobileNetV3 configs for imagenet * Fix activate setting bugs in MobileNetV3. And remove bias in SELayer. * Modify unittest * Remove useless config and file. * Fix mobilenetv3-large arch setting * Add dropout option in ConvClsHead * Fix MobilenetV3 structure according to torchvision version. 1. Remove with_expand_conv option in InvertedResidual, it should be decided by channels. 2. Revert activation function, should before SE layer. * Format code. * Rename MobilenetV3 arch "big" to "large". * Add mobilenetv3_small torchvision training recipe * Modify default `out_indices` of MobilenetV3, now it will change according to `arch` if not specified. * Add MobilenetV3 large config. * Add mobilenetv3 README * Modify InvertedResidual unit test. * Refactor ConvClsHead to StackedLinearClsHead, and add unit tests. * Add unit test for `simple_test` of `StackedLinearClsHead`. * Fix typo Co-authored-by: Yidi Shao <ydshao@smail.nju.edu.cn>
This commit is contained in:
parent
53c0df271f
commit
65410b05ad
14
configs/_base_/models/mobilenet_v3_large_imagenet.py
Normal file
14
configs/_base_/models/mobilenet_v3_large_imagenet.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# model settings
|
||||||
|
model = dict(
|
||||||
|
type='ImageClassifier',
|
||||||
|
backbone=dict(type='MobileNetV3', arch='large'),
|
||||||
|
neck=dict(type='GlobalAveragePooling'),
|
||||||
|
head=dict(
|
||||||
|
type='StackedLinearClsHead',
|
||||||
|
num_classes=1000,
|
||||||
|
in_channels=960,
|
||||||
|
mid_channels=[1280],
|
||||||
|
dropout_rate=0.2,
|
||||||
|
act_cfg=dict(type='HSwish'),
|
||||||
|
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||||
|
topk=(1, 5)))
|
13
configs/_base_/models/mobilenet_v3_small_cifar.py
Normal file
13
configs/_base_/models/mobilenet_v3_small_cifar.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# model settings
|
||||||
|
model = dict(
|
||||||
|
type='ImageClassifier',
|
||||||
|
backbone=dict(type='MobileNetV3', arch='small'),
|
||||||
|
neck=dict(type='GlobalAveragePooling'),
|
||||||
|
head=dict(
|
||||||
|
type='StackedLinearClsHead',
|
||||||
|
num_classes=10,
|
||||||
|
in_channels=576,
|
||||||
|
mid_channels=[1280],
|
||||||
|
act_cfg=dict(type='HSwish'),
|
||||||
|
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||||
|
topk=(1, 5)))
|
14
configs/_base_/models/mobilenet_v3_small_imagenet.py
Normal file
14
configs/_base_/models/mobilenet_v3_small_imagenet.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# model settings
|
||||||
|
model = dict(
|
||||||
|
type='ImageClassifier',
|
||||||
|
backbone=dict(type='MobileNetV3', arch='small'),
|
||||||
|
neck=dict(type='GlobalAveragePooling'),
|
||||||
|
head=dict(
|
||||||
|
type='StackedLinearClsHead',
|
||||||
|
num_classes=1000,
|
||||||
|
in_channels=576,
|
||||||
|
mid_channels=[1024],
|
||||||
|
dropout_rate=0.2,
|
||||||
|
act_cfg=dict(type='HSwish'),
|
||||||
|
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||||
|
topk=(1, 5)))
|
30
configs/mobilenet_v3/README.md
Normal file
30
configs/mobilenet_v3/README.md
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
# Searching for MobileNetV3
|
||||||
|
|
||||||
|
## Introduction
|
||||||
|
|
||||||
|
<!-- [ALGORITHM] -->
|
||||||
|
|
||||||
|
```latex
|
||||||
|
@inproceedings{Howard_2019_ICCV,
|
||||||
|
author = {Howard, Andrew and Sandler, Mark and Chu, Grace and Chen, Liang-Chieh and Chen, Bo and Tan, Mingxing and Wang, Weijun and Zhu, Yukun and Pang, Ruoming and Vasudevan, Vijay and Le, Quoc V. and Adam, Hartwig},
|
||||||
|
title = {Searching for MobileNetV3},
|
||||||
|
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
|
||||||
|
month = {October},
|
||||||
|
year = {2019}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pretrain model
|
||||||
|
|
||||||
|
The pre-trained modles are converted from [torchvision](https://pytorch.org/vision/stable/_modules/torchvision/models/mobilenetv3.html).
|
||||||
|
|
||||||
|
### ImageNet
|
||||||
|
|
||||||
|
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Download |
|
||||||
|
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:--------:|
|
||||||
|
| MobileNetV3-Large | 5.48 | 0.23 | 74.04 | 91.34 | [model](https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_large-3ea3c186.pth)|
|
||||||
|
| MobileNetV3-Small | 2.54 | 0.06 | 67.66 | 87.41 | [model](https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_small-8427ecf0.pth)|
|
||||||
|
|
||||||
|
## Results and models
|
||||||
|
|
||||||
|
Waiting for adding.
|
158
configs/mobilenet_v3/mobilenet_v3_large_imagenet.py
Normal file
158
configs/mobilenet_v3/mobilenet_v3_large_imagenet.py
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
# Refer to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification
|
||||||
|
# ----------------------------
|
||||||
|
# -[x] auto_augment='imagenet'
|
||||||
|
# -[x] batch_size=128 (per gpu)
|
||||||
|
# -[x] epochs=600
|
||||||
|
# -[x] opt='rmsprop'
|
||||||
|
# -[x] lr=0.064
|
||||||
|
# -[x] eps=0.0316
|
||||||
|
# -[x] alpha=0.9
|
||||||
|
# -[x] weight_decay=1e-05
|
||||||
|
# -[x] momentum=0.9
|
||||||
|
# -[x] lr_gamma=0.973
|
||||||
|
# -[x] lr_step_size=2
|
||||||
|
# -[x] nproc_per_node=8
|
||||||
|
# -[x] random_erase=0.2
|
||||||
|
# -[x] workers=16 (workers_per_gpu)
|
||||||
|
# - modify: RandomErasing use RE-M instead of RE-0
|
||||||
|
|
||||||
|
_base_ = [
|
||||||
|
'../_base_/models/mobilenet_v3_large_imagenet.py',
|
||||||
|
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||||
|
'../_base_/default_runtime.py'
|
||||||
|
]
|
||||||
|
|
||||||
|
img_norm_cfg = dict(
|
||||||
|
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||||
|
|
||||||
|
policies = [
|
||||||
|
[
|
||||||
|
dict(type='Posterize', bits=4, prob=0.4),
|
||||||
|
dict(type='Rotate', angle=30., prob=0.6)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Solarize', thr=256 / 9 * 4, prob=0.6),
|
||||||
|
dict(type='AutoContrast', prob=0.6)
|
||||||
|
],
|
||||||
|
[dict(type='Equalize', prob=0.8),
|
||||||
|
dict(type='Equalize', prob=0.6)],
|
||||||
|
[
|
||||||
|
dict(type='Posterize', bits=5, prob=0.6),
|
||||||
|
dict(type='Posterize', bits=5, prob=0.6)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Equalize', prob=0.4),
|
||||||
|
dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Equalize', prob=0.4),
|
||||||
|
dict(type='Rotate', angle=30 / 9 * 8, prob=0.8)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Solarize', thr=256 / 9 * 6, prob=0.6),
|
||||||
|
dict(type='Equalize', prob=0.6)
|
||||||
|
],
|
||||||
|
[dict(type='Posterize', bits=6, prob=0.8),
|
||||||
|
dict(type='Equalize', prob=1.)],
|
||||||
|
[
|
||||||
|
dict(type='Rotate', angle=10., prob=0.2),
|
||||||
|
dict(type='Solarize', thr=256 / 9, prob=0.6)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Equalize', prob=0.6),
|
||||||
|
dict(type='Posterize', bits=5, prob=0.4)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Rotate', angle=30 / 9 * 8, prob=0.8),
|
||||||
|
dict(type='ColorTransform', magnitude=0., prob=0.4)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Rotate', angle=30., prob=0.4),
|
||||||
|
dict(type='Equalize', prob=0.6)
|
||||||
|
],
|
||||||
|
[dict(type='Equalize', prob=0.0),
|
||||||
|
dict(type='Equalize', prob=0.8)],
|
||||||
|
[dict(type='Invert', prob=0.6),
|
||||||
|
dict(type='Equalize', prob=1.)],
|
||||||
|
[
|
||||||
|
dict(type='ColorTransform', magnitude=0.4, prob=0.6),
|
||||||
|
dict(type='Contrast', magnitude=0.8, prob=1.)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Rotate', angle=30 / 9 * 8, prob=0.8),
|
||||||
|
dict(type='ColorTransform', magnitude=0.2, prob=1.)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='ColorTransform', magnitude=0.8, prob=0.8),
|
||||||
|
dict(type='Solarize', thr=256 / 9 * 2, prob=0.8)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Sharpness', magnitude=0.7, prob=0.4),
|
||||||
|
dict(type='Invert', prob=0.6)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(
|
||||||
|
type='Shear',
|
||||||
|
magnitude=0.3 / 9 * 5,
|
||||||
|
prob=0.6,
|
||||||
|
direction='horizontal'),
|
||||||
|
dict(type='Equalize', prob=1.)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='ColorTransform', magnitude=0., prob=0.4),
|
||||||
|
dict(type='Equalize', prob=0.6)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Equalize', prob=0.4),
|
||||||
|
dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Solarize', thr=256 / 9 * 4, prob=0.6),
|
||||||
|
dict(type='AutoContrast', prob=0.6)
|
||||||
|
],
|
||||||
|
[dict(type='Invert', prob=0.6),
|
||||||
|
dict(type='Equalize', prob=1.)],
|
||||||
|
[
|
||||||
|
dict(type='ColorTransform', magnitude=0.4, prob=0.6),
|
||||||
|
dict(type='Contrast', magnitude=0.8, prob=1.)
|
||||||
|
],
|
||||||
|
[dict(type='Equalize', prob=0.8),
|
||||||
|
dict(type='Equalize', prob=0.6)],
|
||||||
|
]
|
||||||
|
|
||||||
|
train_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(type='RandomResizedCrop', size=224, backend='pillow'),
|
||||||
|
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||||
|
dict(type='AutoAugment', policies=policies),
|
||||||
|
dict(
|
||||||
|
type='RandomErasing',
|
||||||
|
erase_prob=0.2,
|
||||||
|
mode='const',
|
||||||
|
min_area_ratio=0.02,
|
||||||
|
max_area_ratio=1 / 3,
|
||||||
|
fill_color=img_norm_cfg['mean']),
|
||||||
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
dict(type='ToTensor', keys=['gt_label']),
|
||||||
|
dict(type='Collect', keys=['img', 'gt_label'])
|
||||||
|
]
|
||||||
|
|
||||||
|
data = dict(
|
||||||
|
samples_per_gpu=128,
|
||||||
|
workers_per_gpu=4,
|
||||||
|
train=dict(pipeline=train_pipeline))
|
||||||
|
evaluation = dict(interval=10, metric='accuracy')
|
||||||
|
|
||||||
|
# optimizer
|
||||||
|
optimizer = dict(
|
||||||
|
type='RMSprop',
|
||||||
|
lr=0.064,
|
||||||
|
alpha=0.9,
|
||||||
|
momentum=0.9,
|
||||||
|
eps=0.0316,
|
||||||
|
weight_decay=1e-5)
|
||||||
|
optimizer_config = dict(grad_clip=None)
|
||||||
|
# learning policy
|
||||||
|
lr_config = dict(policy='step', step=2, gamma=0.973, by_epoch=True)
|
||||||
|
runner = dict(type='EpochBasedRunner', max_epochs=600)
|
8
configs/mobilenet_v3/mobilenet_v3_small_cifar.py
Normal file
8
configs/mobilenet_v3/mobilenet_v3_small_cifar.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
_base_ = [
|
||||||
|
'../_base_/models/mobilenet_v3_small_cifar.py',
|
||||||
|
'../_base_/datasets/cifar10_bs16.py',
|
||||||
|
'../_base_/schedules/cifar10_bs128.py', '../_base_/default_runtime.py'
|
||||||
|
]
|
||||||
|
|
||||||
|
lr_config = dict(policy='step', step=[120, 170])
|
||||||
|
runner = dict(type='EpochBasedRunner', max_epochs=200)
|
158
configs/mobilenet_v3/mobilenet_v3_small_imagenet.py
Normal file
158
configs/mobilenet_v3/mobilenet_v3_small_imagenet.py
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
# Refer to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification
|
||||||
|
# ----------------------------
|
||||||
|
# -[x] auto_augment='imagenet'
|
||||||
|
# -[x] batch_size=128 (per gpu)
|
||||||
|
# -[x] epochs=600
|
||||||
|
# -[x] opt='rmsprop'
|
||||||
|
# -[x] lr=0.064
|
||||||
|
# -[x] eps=0.0316
|
||||||
|
# -[x] alpha=0.9
|
||||||
|
# -[x] weight_decay=1e-05
|
||||||
|
# -[x] momentum=0.9
|
||||||
|
# -[x] lr_gamma=0.973
|
||||||
|
# -[x] lr_step_size=2
|
||||||
|
# -[x] nproc_per_node=8
|
||||||
|
# -[x] random_erase=0.2
|
||||||
|
# -[x] workers=16 (workers_per_gpu)
|
||||||
|
# - modify: RandomErasing use RE-M instead of RE-0
|
||||||
|
|
||||||
|
_base_ = [
|
||||||
|
'../_base_/models/mobilenet_v3_small_imagenet.py',
|
||||||
|
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||||
|
'../_base_/default_runtime.py'
|
||||||
|
]
|
||||||
|
|
||||||
|
img_norm_cfg = dict(
|
||||||
|
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||||
|
|
||||||
|
policies = [
|
||||||
|
[
|
||||||
|
dict(type='Posterize', bits=4, prob=0.4),
|
||||||
|
dict(type='Rotate', angle=30., prob=0.6)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Solarize', thr=256 / 9 * 4, prob=0.6),
|
||||||
|
dict(type='AutoContrast', prob=0.6)
|
||||||
|
],
|
||||||
|
[dict(type='Equalize', prob=0.8),
|
||||||
|
dict(type='Equalize', prob=0.6)],
|
||||||
|
[
|
||||||
|
dict(type='Posterize', bits=5, prob=0.6),
|
||||||
|
dict(type='Posterize', bits=5, prob=0.6)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Equalize', prob=0.4),
|
||||||
|
dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Equalize', prob=0.4),
|
||||||
|
dict(type='Rotate', angle=30 / 9 * 8, prob=0.8)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Solarize', thr=256 / 9 * 6, prob=0.6),
|
||||||
|
dict(type='Equalize', prob=0.6)
|
||||||
|
],
|
||||||
|
[dict(type='Posterize', bits=6, prob=0.8),
|
||||||
|
dict(type='Equalize', prob=1.)],
|
||||||
|
[
|
||||||
|
dict(type='Rotate', angle=10., prob=0.2),
|
||||||
|
dict(type='Solarize', thr=256 / 9, prob=0.6)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Equalize', prob=0.6),
|
||||||
|
dict(type='Posterize', bits=5, prob=0.4)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Rotate', angle=30 / 9 * 8, prob=0.8),
|
||||||
|
dict(type='ColorTransform', magnitude=0., prob=0.4)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Rotate', angle=30., prob=0.4),
|
||||||
|
dict(type='Equalize', prob=0.6)
|
||||||
|
],
|
||||||
|
[dict(type='Equalize', prob=0.0),
|
||||||
|
dict(type='Equalize', prob=0.8)],
|
||||||
|
[dict(type='Invert', prob=0.6),
|
||||||
|
dict(type='Equalize', prob=1.)],
|
||||||
|
[
|
||||||
|
dict(type='ColorTransform', magnitude=0.4, prob=0.6),
|
||||||
|
dict(type='Contrast', magnitude=0.8, prob=1.)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Rotate', angle=30 / 9 * 8, prob=0.8),
|
||||||
|
dict(type='ColorTransform', magnitude=0.2, prob=1.)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='ColorTransform', magnitude=0.8, prob=0.8),
|
||||||
|
dict(type='Solarize', thr=256 / 9 * 2, prob=0.8)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Sharpness', magnitude=0.7, prob=0.4),
|
||||||
|
dict(type='Invert', prob=0.6)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(
|
||||||
|
type='Shear',
|
||||||
|
magnitude=0.3 / 9 * 5,
|
||||||
|
prob=0.6,
|
||||||
|
direction='horizontal'),
|
||||||
|
dict(type='Equalize', prob=1.)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='ColorTransform', magnitude=0., prob=0.4),
|
||||||
|
dict(type='Equalize', prob=0.6)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Equalize', prob=0.4),
|
||||||
|
dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='Solarize', thr=256 / 9 * 4, prob=0.6),
|
||||||
|
dict(type='AutoContrast', prob=0.6)
|
||||||
|
],
|
||||||
|
[dict(type='Invert', prob=0.6),
|
||||||
|
dict(type='Equalize', prob=1.)],
|
||||||
|
[
|
||||||
|
dict(type='ColorTransform', magnitude=0.4, prob=0.6),
|
||||||
|
dict(type='Contrast', magnitude=0.8, prob=1.)
|
||||||
|
],
|
||||||
|
[dict(type='Equalize', prob=0.8),
|
||||||
|
dict(type='Equalize', prob=0.6)],
|
||||||
|
]
|
||||||
|
|
||||||
|
train_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(type='RandomResizedCrop', size=224, backend='pillow'),
|
||||||
|
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||||
|
dict(type='AutoAugment', policies=policies),
|
||||||
|
dict(
|
||||||
|
type='RandomErasing',
|
||||||
|
erase_prob=0.2,
|
||||||
|
mode='const',
|
||||||
|
min_area_ratio=0.02,
|
||||||
|
max_area_ratio=1 / 3,
|
||||||
|
fill_color=img_norm_cfg['mean']),
|
||||||
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
dict(type='ToTensor', keys=['gt_label']),
|
||||||
|
dict(type='Collect', keys=['img', 'gt_label'])
|
||||||
|
]
|
||||||
|
|
||||||
|
data = dict(
|
||||||
|
samples_per_gpu=128,
|
||||||
|
workers_per_gpu=4,
|
||||||
|
train=dict(pipeline=train_pipeline))
|
||||||
|
evaluation = dict(interval=10, metric='accuracy')
|
||||||
|
|
||||||
|
# optimizer
|
||||||
|
optimizer = dict(
|
||||||
|
type='RMSprop',
|
||||||
|
lr=0.064,
|
||||||
|
alpha=0.9,
|
||||||
|
momentum=0.9,
|
||||||
|
eps=0.0316,
|
||||||
|
weight_decay=1e-5)
|
||||||
|
optimizer_config = dict(grad_clip=None)
|
||||||
|
# learning policy
|
||||||
|
lr_config = dict(policy='step', step=2, gamma=0.973, by_epoch=True)
|
||||||
|
runner = dict(type='EpochBasedRunner', max_epochs=600)
|
@ -1,7 +1,7 @@
|
|||||||
from .alexnet import AlexNet
|
from .alexnet import AlexNet
|
||||||
from .lenet import LeNet5
|
from .lenet import LeNet5
|
||||||
from .mobilenet_v2 import MobileNetV2
|
from .mobilenet_v2 import MobileNetV2
|
||||||
from .mobilenet_v3 import MobileNetv3
|
from .mobilenet_v3 import MobileNetV3
|
||||||
from .regnet import RegNet
|
from .regnet import RegNet
|
||||||
from .resnest import ResNeSt
|
from .resnest import ResNeSt
|
||||||
from .resnet import ResNet, ResNetV1d
|
from .resnet import ResNet, ResNetV1d
|
||||||
@ -17,5 +17,5 @@ from .vision_transformer import VisionTransformer
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
|
'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
|
||||||
'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
|
'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
|
||||||
'ShuffleNetV2', 'MobileNetV2', 'MobileNetv3', 'VisionTransformer'
|
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer'
|
||||||
]
|
]
|
||||||
|
@ -7,18 +7,18 @@ from .base_backbone import BaseBackbone
|
|||||||
|
|
||||||
|
|
||||||
@BACKBONES.register_module()
|
@BACKBONES.register_module()
|
||||||
class MobileNetv3(BaseBackbone):
|
class MobileNetV3(BaseBackbone):
|
||||||
"""MobileNetv3 backbone.
|
"""MobileNetV3 backbone.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
arch (str): Architechture of mobilnetv3, from {small, big}.
|
arch (str): Architechture of mobilnetv3, from {small, large}.
|
||||||
Default: small.
|
Default: small.
|
||||||
conv_cfg (dict, optional): Config dict for convolution layer.
|
conv_cfg (dict, optional): Config dict for convolution layer.
|
||||||
Default: None, which means using conv2d.
|
Default: None, which means using conv2d.
|
||||||
norm_cfg (dict): Config dict for normalization layer.
|
norm_cfg (dict): Config dict for normalization layer.
|
||||||
Default: dict(type='BN').
|
Default: dict(type='BN').
|
||||||
out_indices (None or Sequence[int]): Output from which stages.
|
out_indices (None or Sequence[int]): Output from which stages.
|
||||||
Default: (10, ), which means output tensors from final stage.
|
Default: None, which means output tensors from final stage.
|
||||||
frozen_stages (int): Stages to be frozen (all param fixed).
|
frozen_stages (int): Stages to be frozen (all param fixed).
|
||||||
Defualt: -1, which means not freezing any parameters.
|
Defualt: -1, which means not freezing any parameters.
|
||||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||||
@ -42,49 +42,54 @@ class MobileNetv3(BaseBackbone):
|
|||||||
[5, 288, 96, True, 'HSwish', 2],
|
[5, 288, 96, True, 'HSwish', 2],
|
||||||
[5, 576, 96, True, 'HSwish', 1],
|
[5, 576, 96, True, 'HSwish', 1],
|
||||||
[5, 576, 96, True, 'HSwish', 1]],
|
[5, 576, 96, True, 'HSwish', 1]],
|
||||||
'big': [[3, 16, 16, False, 'ReLU', 1],
|
'large': [[3, 16, 16, False, 'ReLU', 1],
|
||||||
[3, 64, 24, False, 'ReLU', 2],
|
[3, 64, 24, False, 'ReLU', 2],
|
||||||
[3, 72, 24, False, 'ReLU', 1],
|
[3, 72, 24, False, 'ReLU', 1],
|
||||||
[5, 72, 40, True, 'ReLU', 2],
|
[5, 72, 40, True, 'ReLU', 2],
|
||||||
[5, 120, 40, True, 'ReLU', 1],
|
[5, 120, 40, True, 'ReLU', 1],
|
||||||
[5, 120, 40, True, 'ReLU', 1],
|
[5, 120, 40, True, 'ReLU', 1],
|
||||||
[3, 240, 80, False, 'HSwish', 2],
|
[3, 240, 80, False, 'HSwish', 2],
|
||||||
[3, 200, 80, False, 'HSwish', 1],
|
[3, 200, 80, False, 'HSwish', 1],
|
||||||
[3, 184, 80, False, 'HSwish', 1],
|
[3, 184, 80, False, 'HSwish', 1],
|
||||||
[3, 184, 80, False, 'HSwish', 1],
|
[3, 184, 80, False, 'HSwish', 1],
|
||||||
[3, 480, 112, True, 'HSwish', 1],
|
[3, 480, 112, True, 'HSwish', 1],
|
||||||
[3, 672, 112, True, 'HSwish', 1],
|
[3, 672, 112, True, 'HSwish', 1],
|
||||||
[5, 672, 160, True, 'HSwish', 1],
|
[5, 672, 160, True, 'HSwish', 2],
|
||||||
[5, 672, 160, True, 'HSwish', 2],
|
[5, 960, 160, True, 'HSwish', 1],
|
||||||
[5, 960, 160, True, 'HSwish', 1]]
|
[5, 960, 160, True, 'HSwish', 1]]
|
||||||
} # yapf: disable
|
} # yapf: disable
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
arch='small',
|
arch='small',
|
||||||
conv_cfg=None,
|
conv_cfg=None,
|
||||||
norm_cfg=dict(type='BN'),
|
norm_cfg=dict(type='BN', eps=0.001, momentum=0.01),
|
||||||
out_indices=(10, ),
|
out_indices=None,
|
||||||
frozen_stages=-1,
|
frozen_stages=-1,
|
||||||
norm_eval=False,
|
norm_eval=False,
|
||||||
with_cp=False,
|
with_cp=False,
|
||||||
init_cfg=[
|
init_cfg=[
|
||||||
dict(type='Kaiming', layer=['Conv2d']),
|
dict(
|
||||||
dict(type='Constant', val=1, layer=['BatchNorm2d'])
|
type='Kaiming',
|
||||||
|
layer=['Conv2d'],
|
||||||
|
nonlinearity='leaky_relu'),
|
||||||
|
dict(type='Normal', layer=['Linear'], std=0.01),
|
||||||
|
dict(type='Constant', layer=['BatchNorm2d'], val=1)
|
||||||
]):
|
]):
|
||||||
super(MobileNetv3, self).__init__(init_cfg)
|
super(MobileNetV3, self).__init__(init_cfg)
|
||||||
assert arch in self.arch_settings
|
assert arch in self.arch_settings
|
||||||
for index in out_indices:
|
if out_indices is None:
|
||||||
if index not in range(0, len(self.arch_settings[arch])):
|
out_indices = (12, ) if arch == 'small' else (16, )
|
||||||
raise ValueError('the item in out_indices must in '
|
for order, index in enumerate(out_indices):
|
||||||
f'range(0, {len(self.arch_settings[arch])}). '
|
if index not in range(0, len(self.arch_settings[arch]) + 2):
|
||||||
f'But received {index}')
|
raise ValueError(
|
||||||
|
'the item in out_indices must in '
|
||||||
|
f'range(0, {len(self.arch_settings[arch]) + 2}). '
|
||||||
|
f'But received {index}')
|
||||||
|
|
||||||
if frozen_stages not in range(-1, len(self.arch_settings[arch])):
|
if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2):
|
||||||
raise ValueError('frozen_stages must be in range(-1, '
|
raise ValueError('frozen_stages must be in range(-1, '
|
||||||
f'{len(self.arch_settings[arch])}). '
|
f'{len(self.arch_settings[arch]) + 2}). '
|
||||||
f'But received {frozen_stages}')
|
f'But received {frozen_stages}')
|
||||||
self.out_indices = out_indices
|
|
||||||
self.frozen_stages = frozen_stages
|
|
||||||
self.arch = arch
|
self.arch = arch
|
||||||
self.conv_cfg = conv_cfg
|
self.conv_cfg = conv_cfg
|
||||||
self.norm_cfg = norm_cfg
|
self.norm_cfg = norm_cfg
|
||||||
@ -93,23 +98,26 @@ class MobileNetv3(BaseBackbone):
|
|||||||
self.norm_eval = norm_eval
|
self.norm_eval = norm_eval
|
||||||
self.with_cp = with_cp
|
self.with_cp = with_cp
|
||||||
|
|
||||||
self.in_channels = 16
|
|
||||||
self.conv1 = ConvModule(
|
|
||||||
in_channels=3,
|
|
||||||
out_channels=self.in_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=2,
|
|
||||||
padding=1,
|
|
||||||
conv_cfg=conv_cfg,
|
|
||||||
norm_cfg=norm_cfg,
|
|
||||||
act_cfg=dict(type='HSwish'))
|
|
||||||
|
|
||||||
self.layers = self._make_layer()
|
self.layers = self._make_layer()
|
||||||
self.feat_dim = self.arch_settings[arch][-1][2]
|
self.feat_dim = self.arch_settings[arch][-1][1]
|
||||||
|
|
||||||
def _make_layer(self):
|
def _make_layer(self):
|
||||||
layers = []
|
layers = []
|
||||||
layer_setting = self.arch_settings[self.arch]
|
layer_setting = self.arch_settings[self.arch]
|
||||||
|
in_channels = 16
|
||||||
|
|
||||||
|
layer = ConvModule(
|
||||||
|
in_channels=3,
|
||||||
|
out_channels=in_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=dict(type='HSwish'))
|
||||||
|
self.add_module('layer0', layer)
|
||||||
|
layers.append('layer0')
|
||||||
|
|
||||||
for i, params in enumerate(layer_setting):
|
for i, params in enumerate(layer_setting):
|
||||||
(kernel_size, mid_channels, out_channels, with_se, act,
|
(kernel_size, mid_channels, out_channels, with_se, act,
|
||||||
stride) = params
|
stride) = params
|
||||||
@ -117,31 +125,50 @@ class MobileNetv3(BaseBackbone):
|
|||||||
se_cfg = dict(
|
se_cfg = dict(
|
||||||
channels=mid_channels,
|
channels=mid_channels,
|
||||||
ratio=4,
|
ratio=4,
|
||||||
act_cfg=(dict(type='ReLU'), dict(type='HSigmoid')))
|
act_cfg=(dict(type='ReLU'),
|
||||||
|
dict(
|
||||||
|
type='HSigmoid',
|
||||||
|
bias=3,
|
||||||
|
divisor=6,
|
||||||
|
min_value=0,
|
||||||
|
max_value=1)))
|
||||||
else:
|
else:
|
||||||
se_cfg = None
|
se_cfg = None
|
||||||
|
|
||||||
layer = InvertedResidual(
|
layer = InvertedResidual(
|
||||||
in_channels=self.in_channels,
|
in_channels=in_channels,
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
mid_channels=mid_channels,
|
mid_channels=mid_channels,
|
||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
se_cfg=se_cfg,
|
se_cfg=se_cfg,
|
||||||
with_expand_conv=True,
|
|
||||||
conv_cfg=self.conv_cfg,
|
conv_cfg=self.conv_cfg,
|
||||||
norm_cfg=self.norm_cfg,
|
norm_cfg=self.norm_cfg,
|
||||||
act_cfg=dict(type=act),
|
act_cfg=dict(type=act),
|
||||||
with_cp=self.with_cp)
|
with_cp=self.with_cp)
|
||||||
self.in_channels = out_channels
|
in_channels = out_channels
|
||||||
layer_name = 'layer{}'.format(i + 1)
|
layer_name = 'layer{}'.format(i + 1)
|
||||||
self.add_module(layer_name, layer)
|
self.add_module(layer_name, layer)
|
||||||
layers.append(layer_name)
|
layers.append(layer_name)
|
||||||
|
|
||||||
|
# Build the last layer before pooling
|
||||||
|
# TODO: No dilation
|
||||||
|
layer = ConvModule(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=576 if self.arch == 'small' else 960,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=dict(type='HSwish'))
|
||||||
|
layer_name = 'layer{}'.format(len(layer_setting) + 1)
|
||||||
|
self.add_module(layer_name, layer)
|
||||||
|
layers.append(layer_name)
|
||||||
|
|
||||||
return layers
|
return layers
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.conv1(x)
|
|
||||||
|
|
||||||
outs = []
|
outs = []
|
||||||
for i, layer_name in enumerate(self.layers):
|
for i, layer_name in enumerate(self.layers):
|
||||||
layer = getattr(self, layer_name)
|
layer = getattr(self, layer_name)
|
||||||
@ -155,17 +182,14 @@ class MobileNetv3(BaseBackbone):
|
|||||||
return tuple(outs)
|
return tuple(outs)
|
||||||
|
|
||||||
def _freeze_stages(self):
|
def _freeze_stages(self):
|
||||||
if self.frozen_stages >= 0:
|
for i in range(0, self.frozen_stages + 1):
|
||||||
for param in self.conv1.parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
for i in range(1, self.frozen_stages + 1):
|
|
||||||
layer = getattr(self, f'layer{i}')
|
layer = getattr(self, f'layer{i}')
|
||||||
layer.eval()
|
layer.eval()
|
||||||
for param in layer.parameters():
|
for param in layer.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
def train(self, mode=True):
|
def train(self, mode=True):
|
||||||
super(MobileNetv3, self).train(mode)
|
super(MobileNetV3, self).train(mode)
|
||||||
self._freeze_stages()
|
self._freeze_stages()
|
||||||
if mode and self.norm_eval:
|
if mode and self.norm_eval:
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
|
@ -2,9 +2,10 @@ from .cls_head import ClsHead
|
|||||||
from .linear_head import LinearClsHead
|
from .linear_head import LinearClsHead
|
||||||
from .multi_label_head import MultiLabelClsHead
|
from .multi_label_head import MultiLabelClsHead
|
||||||
from .multi_label_linear_head import MultiLabelLinearClsHead
|
from .multi_label_linear_head import MultiLabelLinearClsHead
|
||||||
|
from .stacked_head import StackedLinearClsHead
|
||||||
from .vision_transformer_head import VisionTransformerClsHead
|
from .vision_transformer_head import VisionTransformerClsHead
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'ClsHead', 'LinearClsHead', 'MultiLabelClsHead', 'MultiLabelLinearClsHead',
|
'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead',
|
||||||
'VisionTransformerClsHead'
|
'MultiLabelLinearClsHead', 'VisionTransformerClsHead'
|
||||||
]
|
]
|
||||||
|
135
mmcls/models/heads/stacked_head.py
Normal file
135
mmcls/models/heads/stacked_head.py
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
from typing import Dict, Sequence
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from mmcv.cnn import build_activation_layer, build_norm_layer
|
||||||
|
from mmcv.runner import BaseModule, ModuleList
|
||||||
|
|
||||||
|
from ..builder import HEADS
|
||||||
|
from .cls_head import ClsHead
|
||||||
|
|
||||||
|
|
||||||
|
class LinearBlock(BaseModule):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
dropout_rate=0.,
|
||||||
|
norm_cfg=None,
|
||||||
|
act_cfg=None,
|
||||||
|
init_cfg=None):
|
||||||
|
super().__init__(init_cfg=init_cfg)
|
||||||
|
self.fc = nn.Linear(in_channels, out_channels)
|
||||||
|
|
||||||
|
self.norm = None
|
||||||
|
self.act = None
|
||||||
|
self.dropout = None
|
||||||
|
|
||||||
|
if norm_cfg is not None:
|
||||||
|
self.norm = build_norm_layer(norm_cfg, out_channels)[1]
|
||||||
|
if act_cfg is not None:
|
||||||
|
self.act = build_activation_layer(act_cfg)
|
||||||
|
if dropout_rate > 0:
|
||||||
|
self.dropout = nn.Dropout(p=dropout_rate)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc(x)
|
||||||
|
if self.norm is not None:
|
||||||
|
x = self.norm(x)
|
||||||
|
if self.act is not None:
|
||||||
|
x = self.act(x)
|
||||||
|
if self.dropout is not None:
|
||||||
|
x = self.dropout(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@HEADS.register_module()
|
||||||
|
class StackedLinearClsHead(ClsHead):
|
||||||
|
"""Classifier head with several hidden fc layer and a output fc layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_classes (int): Number of categories excluding the background
|
||||||
|
category.
|
||||||
|
in_channels (int): Number of channels in the input feature map.
|
||||||
|
mid_channels (Sequence): Number of channels in the hidden fc layers.
|
||||||
|
dropout_rate (float): Dropout rate after each hidden fc layer,
|
||||||
|
except the last layer. Defaults to 0.
|
||||||
|
norm_cfg (dict, optional): Config dict of normalization layer after
|
||||||
|
each hidden fc layer, except the last layer. Defaults to None.
|
||||||
|
act_cfg (dict, optional): Config dict of activation function after each
|
||||||
|
hidden layer, except the last layer. Defaults to use "ReLU".
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
num_classes: int,
|
||||||
|
in_channels: int,
|
||||||
|
mid_channels: Sequence,
|
||||||
|
dropout_rate: float = 0.,
|
||||||
|
norm_cfg: Dict = None,
|
||||||
|
act_cfg: Dict = dict(type='ReLU'),
|
||||||
|
**kwargs):
|
||||||
|
super(StackedLinearClsHead, self).__init__(**kwargs)
|
||||||
|
assert num_classes > 0, \
|
||||||
|
f'`num_classes` of StackedLinearClsHead must be a positive ' \
|
||||||
|
f'integer, got {num_classes} instead.'
|
||||||
|
self.num_classes = num_classes
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
assert isinstance(mid_channels, Sequence), \
|
||||||
|
f'`mid_channels` of StackedLinearClsHead should be a sequence, ' \
|
||||||
|
f'instead of {type(mid_channels)}'
|
||||||
|
self.mid_channels = mid_channels
|
||||||
|
|
||||||
|
self.dropout_rate = dropout_rate
|
||||||
|
self.norm_cfg = norm_cfg
|
||||||
|
self.act_cfg = act_cfg
|
||||||
|
|
||||||
|
self._init_layers()
|
||||||
|
|
||||||
|
def _init_layers(self):
|
||||||
|
self.layers = ModuleList(
|
||||||
|
init_cfg=dict(
|
||||||
|
type='Normal', layer='Linear', mean=0., std=0.01, bias=0.))
|
||||||
|
in_channels = self.in_channels
|
||||||
|
for hidden_channels in self.mid_channels:
|
||||||
|
self.layers.append(
|
||||||
|
LinearBlock(
|
||||||
|
in_channels,
|
||||||
|
hidden_channels,
|
||||||
|
dropout_rate=self.dropout_rate,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=self.act_cfg))
|
||||||
|
in_channels = hidden_channels
|
||||||
|
|
||||||
|
self.layers.append(
|
||||||
|
LinearBlock(
|
||||||
|
self.mid_channels[-1],
|
||||||
|
self.num_classes,
|
||||||
|
dropout_rate=0.,
|
||||||
|
norm_cfg=None,
|
||||||
|
act_cfg=None))
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
self.layers.init_weights()
|
||||||
|
|
||||||
|
def simple_test(self, img):
|
||||||
|
"""Test without augmentation."""
|
||||||
|
cls_score = img
|
||||||
|
for layer in self.layers:
|
||||||
|
cls_score = layer(cls_score)
|
||||||
|
if isinstance(cls_score, list):
|
||||||
|
cls_score = sum(cls_score) / float(len(cls_score))
|
||||||
|
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||||
|
if torch.onnx.is_in_onnx_export():
|
||||||
|
return pred
|
||||||
|
pred = list(pred.detach().cpu().numpy())
|
||||||
|
return pred
|
||||||
|
|
||||||
|
def forward_train(self, x, gt_label):
|
||||||
|
cls_score = x
|
||||||
|
for layer in self.layers:
|
||||||
|
cls_score = layer(cls_score)
|
||||||
|
losses = self.loss(cls_score, gt_label)
|
||||||
|
return losses
|
@ -18,9 +18,6 @@ class InvertedResidual(BaseModule):
|
|||||||
stride (int): The stride of the depthwise convolution. Default: 1.
|
stride (int): The stride of the depthwise convolution. Default: 1.
|
||||||
se_cfg (dict): Config dict for se layer. Defaul: None, which means no
|
se_cfg (dict): Config dict for se layer. Defaul: None, which means no
|
||||||
se layer.
|
se layer.
|
||||||
with_expand_conv (bool): Use expand conv or not. If set False,
|
|
||||||
mid_channels must be the same with in_channels.
|
|
||||||
Default: True.
|
|
||||||
conv_cfg (dict): Config dict for convolution layer. Default: None,
|
conv_cfg (dict): Config dict for convolution layer. Default: None,
|
||||||
which means using conv2d.
|
which means using conv2d.
|
||||||
norm_cfg (dict): Config dict for normalization layer.
|
norm_cfg (dict): Config dict for normalization layer.
|
||||||
@ -41,7 +38,6 @@ class InvertedResidual(BaseModule):
|
|||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
se_cfg=None,
|
se_cfg=None,
|
||||||
with_expand_conv=True,
|
|
||||||
conv_cfg=None,
|
conv_cfg=None,
|
||||||
norm_cfg=dict(type='BN'),
|
norm_cfg=dict(type='BN'),
|
||||||
act_cfg=dict(type='ReLU'),
|
act_cfg=dict(type='ReLU'),
|
||||||
@ -52,12 +48,10 @@ class InvertedResidual(BaseModule):
|
|||||||
assert stride in [1, 2]
|
assert stride in [1, 2]
|
||||||
self.with_cp = with_cp
|
self.with_cp = with_cp
|
||||||
self.with_se = se_cfg is not None
|
self.with_se = se_cfg is not None
|
||||||
self.with_expand_conv = with_expand_conv
|
self.with_expand_conv = (mid_channels != in_channels)
|
||||||
|
|
||||||
if self.with_se:
|
if self.with_se:
|
||||||
assert isinstance(se_cfg, dict)
|
assert isinstance(se_cfg, dict)
|
||||||
if not self.with_expand_conv:
|
|
||||||
assert mid_channels == in_channels
|
|
||||||
|
|
||||||
if self.with_expand_conv:
|
if self.with_expand_conv:
|
||||||
self.expand_conv = ConvModule(
|
self.expand_conv = ConvModule(
|
||||||
@ -89,7 +83,7 @@ class InvertedResidual(BaseModule):
|
|||||||
padding=0,
|
padding=0,
|
||||||
conv_cfg=conv_cfg,
|
conv_cfg=conv_cfg,
|
||||||
norm_cfg=norm_cfg,
|
norm_cfg=norm_cfg,
|
||||||
act_cfg=act_cfg)
|
act_cfg=None)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
||||||
|
@ -3,6 +3,8 @@ import torch.nn as nn
|
|||||||
from mmcv.cnn import ConvModule
|
from mmcv.cnn import ConvModule
|
||||||
from mmcv.runner import BaseModule
|
from mmcv.runner import BaseModule
|
||||||
|
|
||||||
|
from .make_divisible import make_divisible
|
||||||
|
|
||||||
|
|
||||||
# class SELayer(nn.Module):
|
# class SELayer(nn.Module):
|
||||||
class SELayer(BaseModule):
|
class SELayer(BaseModule):
|
||||||
@ -25,6 +27,7 @@ class SELayer(BaseModule):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
channels,
|
channels,
|
||||||
ratio=16,
|
ratio=16,
|
||||||
|
bias='auto',
|
||||||
conv_cfg=None,
|
conv_cfg=None,
|
||||||
act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')),
|
act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')),
|
||||||
init_cfg=None):
|
init_cfg=None):
|
||||||
@ -34,18 +37,21 @@ class SELayer(BaseModule):
|
|||||||
assert len(act_cfg) == 2
|
assert len(act_cfg) == 2
|
||||||
assert mmcv.is_tuple_of(act_cfg, dict)
|
assert mmcv.is_tuple_of(act_cfg, dict)
|
||||||
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
||||||
|
squeeze_channels = make_divisible(channels // ratio, 8)
|
||||||
self.conv1 = ConvModule(
|
self.conv1 = ConvModule(
|
||||||
in_channels=channels,
|
in_channels=channels,
|
||||||
out_channels=int(channels / ratio),
|
out_channels=squeeze_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
|
bias=bias,
|
||||||
conv_cfg=conv_cfg,
|
conv_cfg=conv_cfg,
|
||||||
act_cfg=act_cfg[0])
|
act_cfg=act_cfg[0])
|
||||||
self.conv2 = ConvModule(
|
self.conv2 = ConvModule(
|
||||||
in_channels=int(channels / ratio),
|
in_channels=squeeze_channels,
|
||||||
out_channels=channels,
|
out_channels=channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
|
bias=bias,
|
||||||
conv_cfg=conv_cfg,
|
conv_cfg=conv_cfg,
|
||||||
act_cfg=act_cfg[1])
|
act_cfg=act_cfg[1])
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ import torch
|
|||||||
from torch.nn.modules import GroupNorm
|
from torch.nn.modules import GroupNorm
|
||||||
from torch.nn.modules.batchnorm import _BatchNorm
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
from mmcls.models.backbones import MobileNetv3
|
from mmcls.models.backbones import MobileNetV3
|
||||||
from mmcls.models.utils import InvertedResidual
|
from mmcls.models.utils import InvertedResidual
|
||||||
|
|
||||||
|
|
||||||
@ -26,42 +26,40 @@ def check_norm_state(modules, train_state):
|
|||||||
def test_mobilenetv3_backbone():
|
def test_mobilenetv3_backbone():
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
# pretrained must be a string path
|
# pretrained must be a string path
|
||||||
model = MobileNetv3()
|
model = MobileNetV3()
|
||||||
model.init_weights(pretrained=0)
|
model.init_weights(pretrained=0)
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
# arch must in [small, big]
|
# arch must in [small, large]
|
||||||
MobileNetv3(arch='others')
|
MobileNetV3(arch='others')
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
# frozen_stages must less than 12 when arch is small
|
# frozen_stages must less than 13 when arch is small
|
||||||
MobileNetv3(arch='small', frozen_stages=12)
|
MobileNetV3(arch='small', frozen_stages=13)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
# frozen_stages must less than 16 when arch is big
|
# frozen_stages must less than 17 when arch is large
|
||||||
MobileNetv3(arch='big', frozen_stages=16)
|
MobileNetV3(arch='large', frozen_stages=17)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
# max out_indices must less than 11 when arch is small
|
# max out_indices must less than 13 when arch is small
|
||||||
MobileNetv3(arch='small', out_indices=(11, ))
|
MobileNetV3(arch='small', out_indices=(13, ))
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
# max out_indices must less than 15 when arch is big
|
# max out_indices must less than 17 when arch is large
|
||||||
MobileNetv3(arch='big', out_indices=(15, ))
|
MobileNetV3(arch='large', out_indices=(17, ))
|
||||||
|
|
||||||
# Test MobileNetv3
|
# Test MobileNetV3
|
||||||
model = MobileNetv3()
|
model = MobileNetV3()
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
# Test MobileNetv3 with first stage frozen
|
# Test MobileNetV3 with first stage frozen
|
||||||
frozen_stages = 1
|
frozen_stages = 1
|
||||||
model = MobileNetv3(frozen_stages=frozen_stages)
|
model = MobileNetV3(frozen_stages=frozen_stages)
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
for param in model.conv1.parameters():
|
for i in range(0, frozen_stages + 1):
|
||||||
assert param.requires_grad is False
|
|
||||||
for i in range(1, frozen_stages + 1):
|
|
||||||
layer = getattr(model, f'layer{i}')
|
layer = getattr(model, f'layer{i}')
|
||||||
for mod in layer.modules():
|
for mod in layer.modules():
|
||||||
if isinstance(mod, _BatchNorm):
|
if isinstance(mod, _BatchNorm):
|
||||||
@ -69,35 +67,37 @@ def test_mobilenetv3_backbone():
|
|||||||
for param in layer.parameters():
|
for param in layer.parameters():
|
||||||
assert param.requires_grad is False
|
assert param.requires_grad is False
|
||||||
|
|
||||||
# Test MobileNetv3 with norm eval
|
# Test MobileNetV3 with norm eval
|
||||||
model = MobileNetv3(norm_eval=True, out_indices=range(0, 11))
|
model = MobileNetV3(norm_eval=True, out_indices=range(0, 12))
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
assert check_norm_state(model.modules(), False)
|
assert check_norm_state(model.modules(), False)
|
||||||
|
|
||||||
# Test MobileNetv3 forward with small arch
|
# Test MobileNetV3 forward with small arch
|
||||||
model = MobileNetv3(out_indices=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
|
model = MobileNetV3(out_indices=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12))
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
imgs = torch.randn(1, 3, 224, 224)
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
feat = model(imgs)
|
feat = model(imgs)
|
||||||
assert len(feat) == 11
|
assert len(feat) == 13
|
||||||
assert feat[0].shape == torch.Size([1, 16, 56, 56])
|
assert feat[0].shape == torch.Size([1, 16, 112, 112])
|
||||||
assert feat[1].shape == torch.Size([1, 24, 28, 28])
|
assert feat[1].shape == torch.Size([1, 16, 56, 56])
|
||||||
assert feat[2].shape == torch.Size([1, 24, 28, 28])
|
assert feat[2].shape == torch.Size([1, 24, 28, 28])
|
||||||
assert feat[3].shape == torch.Size([1, 40, 14, 14])
|
assert feat[3].shape == torch.Size([1, 24, 28, 28])
|
||||||
assert feat[4].shape == torch.Size([1, 40, 14, 14])
|
assert feat[4].shape == torch.Size([1, 40, 14, 14])
|
||||||
assert feat[5].shape == torch.Size([1, 40, 14, 14])
|
assert feat[5].shape == torch.Size([1, 40, 14, 14])
|
||||||
assert feat[6].shape == torch.Size([1, 48, 14, 14])
|
assert feat[6].shape == torch.Size([1, 40, 14, 14])
|
||||||
assert feat[7].shape == torch.Size([1, 48, 14, 14])
|
assert feat[7].shape == torch.Size([1, 48, 14, 14])
|
||||||
assert feat[8].shape == torch.Size([1, 96, 7, 7])
|
assert feat[8].shape == torch.Size([1, 48, 14, 14])
|
||||||
assert feat[9].shape == torch.Size([1, 96, 7, 7])
|
assert feat[9].shape == torch.Size([1, 96, 7, 7])
|
||||||
assert feat[10].shape == torch.Size([1, 96, 7, 7])
|
assert feat[10].shape == torch.Size([1, 96, 7, 7])
|
||||||
|
assert feat[11].shape == torch.Size([1, 96, 7, 7])
|
||||||
|
assert feat[12].shape == torch.Size([1, 576, 7, 7])
|
||||||
|
|
||||||
# Test MobileNetv3 forward with small arch and GroupNorm
|
# Test MobileNetV3 forward with small arch and GroupNorm
|
||||||
model = MobileNetv3(
|
model = MobileNetV3(
|
||||||
out_indices=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
|
out_indices=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12),
|
||||||
norm_cfg=dict(type='GN', num_groups=2, requires_grad=True))
|
norm_cfg=dict(type='GN', num_groups=2, requires_grad=True))
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if is_norm(m):
|
if is_norm(m):
|
||||||
@ -107,47 +107,51 @@ def test_mobilenetv3_backbone():
|
|||||||
|
|
||||||
imgs = torch.randn(1, 3, 224, 224)
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
feat = model(imgs)
|
feat = model(imgs)
|
||||||
assert len(feat) == 11
|
assert len(feat) == 13
|
||||||
assert feat[0].shape == torch.Size([1, 16, 56, 56])
|
assert feat[0].shape == torch.Size([1, 16, 112, 112])
|
||||||
assert feat[1].shape == torch.Size([1, 24, 28, 28])
|
assert feat[1].shape == torch.Size([1, 16, 56, 56])
|
||||||
assert feat[2].shape == torch.Size([1, 24, 28, 28])
|
assert feat[2].shape == torch.Size([1, 24, 28, 28])
|
||||||
assert feat[3].shape == torch.Size([1, 40, 14, 14])
|
assert feat[3].shape == torch.Size([1, 24, 28, 28])
|
||||||
assert feat[4].shape == torch.Size([1, 40, 14, 14])
|
assert feat[4].shape == torch.Size([1, 40, 14, 14])
|
||||||
assert feat[5].shape == torch.Size([1, 40, 14, 14])
|
assert feat[5].shape == torch.Size([1, 40, 14, 14])
|
||||||
assert feat[6].shape == torch.Size([1, 48, 14, 14])
|
assert feat[6].shape == torch.Size([1, 40, 14, 14])
|
||||||
assert feat[7].shape == torch.Size([1, 48, 14, 14])
|
assert feat[7].shape == torch.Size([1, 48, 14, 14])
|
||||||
assert feat[8].shape == torch.Size([1, 96, 7, 7])
|
assert feat[8].shape == torch.Size([1, 48, 14, 14])
|
||||||
assert feat[9].shape == torch.Size([1, 96, 7, 7])
|
assert feat[9].shape == torch.Size([1, 96, 7, 7])
|
||||||
assert feat[10].shape == torch.Size([1, 96, 7, 7])
|
assert feat[10].shape == torch.Size([1, 96, 7, 7])
|
||||||
|
assert feat[11].shape == torch.Size([1, 96, 7, 7])
|
||||||
|
assert feat[12].shape == torch.Size([1, 576, 7, 7])
|
||||||
|
|
||||||
# Test MobileNetv3 forward with big arch
|
# Test MobileNetV3 forward with large arch
|
||||||
model = MobileNetv3(
|
model = MobileNetV3(
|
||||||
arch='big',
|
arch='large',
|
||||||
out_indices=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14))
|
out_indices=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16))
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
imgs = torch.randn(1, 3, 224, 224)
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
feat = model(imgs)
|
feat = model(imgs)
|
||||||
assert len(feat) == 15
|
assert len(feat) == 17
|
||||||
assert feat[0].shape == torch.Size([1, 16, 112, 112])
|
assert feat[0].shape == torch.Size([1, 16, 112, 112])
|
||||||
assert feat[1].shape == torch.Size([1, 24, 56, 56])
|
assert feat[1].shape == torch.Size([1, 16, 112, 112])
|
||||||
assert feat[2].shape == torch.Size([1, 24, 56, 56])
|
assert feat[2].shape == torch.Size([1, 24, 56, 56])
|
||||||
assert feat[3].shape == torch.Size([1, 40, 28, 28])
|
assert feat[3].shape == torch.Size([1, 24, 56, 56])
|
||||||
assert feat[4].shape == torch.Size([1, 40, 28, 28])
|
assert feat[4].shape == torch.Size([1, 40, 28, 28])
|
||||||
assert feat[5].shape == torch.Size([1, 40, 28, 28])
|
assert feat[5].shape == torch.Size([1, 40, 28, 28])
|
||||||
assert feat[6].shape == torch.Size([1, 80, 14, 14])
|
assert feat[6].shape == torch.Size([1, 40, 28, 28])
|
||||||
assert feat[7].shape == torch.Size([1, 80, 14, 14])
|
assert feat[7].shape == torch.Size([1, 80, 14, 14])
|
||||||
assert feat[8].shape == torch.Size([1, 80, 14, 14])
|
assert feat[8].shape == torch.Size([1, 80, 14, 14])
|
||||||
assert feat[9].shape == torch.Size([1, 80, 14, 14])
|
assert feat[9].shape == torch.Size([1, 80, 14, 14])
|
||||||
assert feat[10].shape == torch.Size([1, 112, 14, 14])
|
assert feat[10].shape == torch.Size([1, 80, 14, 14])
|
||||||
assert feat[11].shape == torch.Size([1, 112, 14, 14])
|
assert feat[11].shape == torch.Size([1, 112, 14, 14])
|
||||||
assert feat[12].shape == torch.Size([1, 160, 14, 14])
|
assert feat[12].shape == torch.Size([1, 112, 14, 14])
|
||||||
assert feat[13].shape == torch.Size([1, 160, 7, 7])
|
assert feat[13].shape == torch.Size([1, 160, 7, 7])
|
||||||
assert feat[14].shape == torch.Size([1, 160, 7, 7])
|
assert feat[14].shape == torch.Size([1, 160, 7, 7])
|
||||||
|
assert feat[15].shape == torch.Size([1, 160, 7, 7])
|
||||||
|
assert feat[16].shape == torch.Size([1, 960, 7, 7])
|
||||||
|
|
||||||
# Test MobileNetv3 forward with big arch
|
# Test MobileNetV3 forward with large arch
|
||||||
model = MobileNetv3(arch='big', out_indices=(0, ))
|
model = MobileNetV3(arch='large', out_indices=(0, ))
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
@ -155,8 +159,8 @@ def test_mobilenetv3_backbone():
|
|||||||
feat = model(imgs)
|
feat = model(imgs)
|
||||||
assert feat.shape == torch.Size([1, 16, 112, 112])
|
assert feat.shape == torch.Size([1, 16, 112, 112])
|
||||||
|
|
||||||
# Test MobileNetv3 with checkpoint forward
|
# Test MobileNetV3 with checkpoint forward
|
||||||
model = MobileNetv3(with_cp=True)
|
model = MobileNetV3(with_cp=True)
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if isinstance(m, InvertedResidual):
|
if isinstance(m, InvertedResidual):
|
||||||
assert m.with_cp
|
assert m.with_cp
|
||||||
@ -165,4 +169,4 @@ def test_mobilenetv3_backbone():
|
|||||||
|
|
||||||
imgs = torch.randn(1, 3, 224, 224)
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
feat = model(imgs)
|
feat = model(imgs)
|
||||||
assert feat.shape == torch.Size([1, 96, 7, 7])
|
assert feat.shape == torch.Size([1, 576, 7, 7])
|
||||||
|
@ -57,10 +57,9 @@ def test_inverted_residual():
|
|||||||
# se_cfg must be None or dict
|
# se_cfg must be None or dict
|
||||||
InvertedResidual(16, 16, 32, se_cfg=list())
|
InvertedResidual(16, 16, 32, se_cfg=list())
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
# Add expand conv if in_channels and mid_channels is not the same
|
||||||
# in_channeld and out_channels must be the same if
|
assert InvertedResidual(32, 16, 32).with_expand_conv is False
|
||||||
# with_expand_conv is False
|
assert InvertedResidual(16, 16, 32).with_expand_conv is True
|
||||||
InvertedResidual(16, 16, 32, with_expand_conv=False)
|
|
||||||
|
|
||||||
# Test InvertedResidual forward, stride=1
|
# Test InvertedResidual forward, stride=1
|
||||||
block = InvertedResidual(16, 16, 32, stride=1)
|
block = InvertedResidual(16, 16, 32, stride=1)
|
||||||
@ -85,8 +84,8 @@ def test_inverted_residual():
|
|||||||
assert isinstance(block.se, SELayer)
|
assert isinstance(block.se, SELayer)
|
||||||
assert x_out.shape == torch.Size((1, 16, 56, 56))
|
assert x_out.shape == torch.Size((1, 16, 56, 56))
|
||||||
|
|
||||||
# Test InvertedResidual forward, with_expand_conv=False
|
# Test InvertedResidual forward without expand conv
|
||||||
block = InvertedResidual(32, 16, 32, with_expand_conv=False)
|
block = InvertedResidual(32, 16, 32)
|
||||||
x = torch.randn(1, 32, 56, 56)
|
x = torch.randn(1, 32, 56, 56)
|
||||||
x_out = block(x)
|
x_out = block(x)
|
||||||
assert getattr(block, 'expand_conv', None) is None
|
assert getattr(block, 'expand_conv', None) is None
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from mmcls.models.heads import (ClsHead, LinearClsHead, MultiLabelClsHead,
|
from mmcls.models.heads import (ClsHead, LinearClsHead, MultiLabelClsHead,
|
||||||
MultiLabelLinearClsHead)
|
MultiLabelLinearClsHead, StackedLinearClsHead)
|
||||||
|
|
||||||
|
|
||||||
def test_cls_head():
|
def test_cls_head():
|
||||||
@ -48,3 +51,47 @@ def test_multilabel_linear_head():
|
|||||||
head.init_weights()
|
head.init_weights()
|
||||||
losses = head.loss(fake_cls_score, fake_gt_label)
|
losses = head.loss(fake_cls_score, fake_gt_label)
|
||||||
assert losses['loss'].item() > 0
|
assert losses['loss'].item() > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_stacked_linear_cls_head():
|
||||||
|
# test assertion
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
StackedLinearClsHead(num_classes=3, in_channels=5, mid_channels=10)
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
StackedLinearClsHead(num_classes=-1, in_channels=5, mid_channels=[10])
|
||||||
|
|
||||||
|
fake_img = torch.rand(4, 5) # B, channel
|
||||||
|
fake_gt_label = torch.randint(0, 2, (4, )) # B, num_classes
|
||||||
|
|
||||||
|
# test forward with default setting
|
||||||
|
head = StackedLinearClsHead(
|
||||||
|
num_classes=3, in_channels=5, mid_channels=[10])
|
||||||
|
head.init_weights()
|
||||||
|
|
||||||
|
losses = head.forward_train(fake_img, fake_gt_label)
|
||||||
|
assert losses['loss'].item() > 0
|
||||||
|
|
||||||
|
# test simple test
|
||||||
|
pred = head.simple_test(fake_img)
|
||||||
|
assert len(pred) == 4
|
||||||
|
|
||||||
|
# test simple test in tracing
|
||||||
|
p = patch('torch.onnx.is_in_onnx_export', lambda: True)
|
||||||
|
p.start()
|
||||||
|
pred = head.simple_test(fake_img)
|
||||||
|
assert pred.shape == torch.Size((4, 3))
|
||||||
|
p.stop()
|
||||||
|
|
||||||
|
# test forward with full function
|
||||||
|
head = StackedLinearClsHead(
|
||||||
|
num_classes=3,
|
||||||
|
in_channels=5,
|
||||||
|
mid_channels=[8, 10],
|
||||||
|
dropout_rate=0.2,
|
||||||
|
norm_cfg=dict(type='BN1d'),
|
||||||
|
act_cfg=dict(type='HSwish'))
|
||||||
|
head.init_weights()
|
||||||
|
|
||||||
|
losses = head.forward_train(fake_img, fake_gt_label)
|
||||||
|
assert losses['loss'].item() > 0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user