[FEATURE] Mobilenet Series Search Space (#82)
* add mbv2 block and identity
* add mbv2 block and identity unittests
* expand_factor -> expand_ratio
* add searchable mobilenet v2
* fix lints
* add spos subnet retraining config
* fix mmcls slurm search
* add proxyless_gpu setting for mbv2
* use bn default
* add angelnas spos config
* update spos readme
* fix SELayer's usage
* add docstring
* rename mbv2 to mb
* add some unittests of mb
* rename mb to mobilenet
* add some rename-mb in configs
* update README of spos
* add rename-mb in unittest
* update test_mmcls

Co-authored-by: wutongshenqiu <690364065@qq.com>
Co-authored-by: humu789 <humu@pjlab.org.cn>

@@ -1,4 +1,5 @@
# SPOS

> [Single Path One-Shot Neural Architecture Search with Uniform Sampling](https://arxiv.org/abs/1904.00420)

<!-- [ALGORITHM] -->
@@ -10,16 +11,18 @@ Comprehensive experiments verify that our approach is flexible and effective. It
## Introduction

### Supernet pre-training on ImageNet

```bash
python ./tools/mmcls/train_mmcls.py \
    configs/nas/spos/spos_supernet_shufflenetv2_8xb128_in1k.py \
    --work-dir $WORK_DIR
```

### Search for subnet on the trained supernet

```bash
python ./tools/mmcls/search_mmcls.py \
    configs/nas/spos/spos_evolution_search_shufflenetv2_8xb2048_in1k.py \
@@ -28,6 +31,7 @@ python ./tools/mmcls/search_mmcls.py \
```

### Subnet retraining on ImageNet

```bash
python ./tools/mmcls/train_mmcls.py \
    configs/nas/spos/spos_subnet_shufflenetv2_8xb128_in1k.py \
@@ -36,14 +40,18 @@ python ./tools/mmcls/train_mmcls.py \
```

## Results and models

| Dataset | Supernet | Subnet | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | Remarks |
| :------: | :--------------------: | :------: | :-------: | :------: | :-------: | :-------: | :------: | :------: | :------: |
| ImageNet | ShuffleNetV2 | [mutable](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/spos/spos_shufflenetv2_subnet_8xb128_in1k/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-454627be_mutable_cfg.yaml?versionId=CAEQHxiBgICw5b6I7xciIGY5MjVmNWFhY2U5MjQzN2M4NDViYzI2YWRmYWE1YzQx) | 3.35 | 0.33 | 73.87 | 91.6 | [config](./spos_subnet_shufflenetv2_8xb128_in1k.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/spos/spos_shufflenetv2_subnet_8xb128_in1k/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d.pth?versionId=CAEQHxiBgIDK5b6I7xciIDM1YjIwZjQxN2UyMDRjYjA5YTM5NTBlMGNhMTdkNjI2) &#124; [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/spos/spos_shufflenetv2_subnet_8xb128_in1k/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d.log.json?versionId=CAEQHxiBgIDr9cuL7xciIDBmOTZiZGUyYjRiMDQ5NzhhZjY0NWUxYmUzNDlmNTg5) | MMRazor searched |
| ImageNet | MobileNet-ProxylessGPU | [mutable](https://download.openmmlab.com/mmrazor/v0.1/nas/spos/spos_mobilenet_subnet/spos_angelnas_flops_0.49G_acc_75.98_20220307-54f4698f_mutable_cfg.yaml) | 5.94 | 0.49* | 75.98 | 92.77 | [config](./spos_mobilenet_for_check_ckpt_from_anglenas.py) | | [AngleNAS](https://github.com/megvii-model/AngleNAS) searched |

**Note**:

1. There might be (though not in all cases) some small differences in our experiments, made in order to be consistent with other OpenMMLab repos. For example, we normalize images in data preprocessing, resize by cv2 rather than PIL in training, and do not use dropout in the network. **Please refer to the corresponding config for details.**
2. For *ShuffleNetV2*, we retrained the subnet reported in the paper with its official code; Top-1 is 73.6 and Top-5 is 91.6.
3. For the *MobileNet-ProxylessGPU* subnet searched by AngleNAS, we obtain params and FLOPs using [this script](/tools/misc/get_flops.py), which may differ from [AngleNAS](https://github.com/megvii-model/AngleNAS#searched-models-with-abs).
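
The asterisked FLOPs figure can be recomputed locally with the script referenced above; a typical invocation might look like the following (the `--shape` argument follows the usual OpenMMLab `get_flops` convention and is an assumption here, not taken from this PR):

```bash
# Hypothetical example: compute params/FLOPs of the subnet config at 224x224.
python tools/misc/get_flops.py \
    configs/nas/spos/spos_subnet_mobilenet_proxyless_gpu_8xb128_in1k.py \
    --shape 224 224
```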

## Citation

@@ -0,0 +1,66 @@
stage_0_block_0:
  chosen:
  - mb_k3e1
stage_1_block_0:
  chosen:
  - mb_k5e3
stage_1_block_1:
  chosen:
  - mb_k5e3
stage_1_block_2:
  chosen:
  - identity
stage_1_block_3:
  chosen:
  - mb_k3e3
stage_2_block_0:
  chosen:
  - mb_k3e3
stage_2_block_1:
  chosen:
  - identity
stage_2_block_2:
  chosen:
  - identity
stage_2_block_3:
  chosen:
  - mb_k3e3
stage_3_block_0:
  chosen:
  - mb_k7e6
stage_3_block_1:
  chosen:
  - identity
stage_3_block_2:
  chosen:
  - mb_k7e3
stage_3_block_3:
  chosen:
  - mb_k7e3
stage_4_block_0:
  chosen:
  - mb_k3e3
stage_4_block_1:
  chosen:
  - mb_k3e3
stage_4_block_2:
  chosen:
  - mb_k7e3
stage_4_block_3:
  chosen:
  - mb_k5e3
stage_5_block_0:
  chosen:
  - mb_k5e6
stage_5_block_1:
  chosen:
  - mb_k7e3
stage_5_block_2:
  chosen:
  - mb_k7e3
stage_5_block_3:
  chosen:
  - mb_k7e3
stage_6_block_0:
  chosen:
  - mb_k5e6
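This file is the artifact the evolutionary search emits: one entry per searchable position, keyed `stage_{i}_block_{j}` exactly as generated by `SearchableMobileNet.make_layer` below. A minimal sketch (not part of this PR; the local filename is hypothetical) for inspecting it:

```python
from collections import Counter

import yaml

# Load the searched mutable config shown above (path is hypothetical).
with open('spos_mobilenet_mutable_cfg.yaml') as f:
    mutable = yaml.safe_load(f)

# Tally how often each candidate op was chosen across all 22 positions.
counts = Counter(cfg['chosen'][0] for cfg in mutable.values())
print(counts.most_common())  # e.g. [('mb_k7e3', 6), ('mb_k3e3', 5), ...]
```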

@@ -0,0 +1,20 @@
_base_ = ['./spos_supernet_mobilenet_proxyless_gpu_8xb128_in1k.py']

data = dict(
    samples_per_gpu=512,
    workers_per_gpu=16,
)

algorithm = dict(bn_training_mode=True)

searcher = dict(
    type='EvolutionSearcher',
    candidate_pool_size=50,
    candidate_top_k=10,
    constraints=dict(flops=465 * 1e6),
    metrics='accuracy',
    score_key='accuracy_top-1',
    max_epoch=20,
    num_mutation=25,
    num_crossover=25,
    mutate_prob=0.1)
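For orientation, the searcher fields above parameterize a loop of roughly the following shape (a simplified sketch, not the actual `EvolutionSearcher` implementation; `sample_subnet`, `evaluate`, `flops_of`, `mutate`, and `crossover` are stand-in callables):

```python
import random


def evolution_search(sample_subnet, evaluate, flops_of, mutate, crossover,
                     pool_size=50, top_k=10, max_epoch=20, num_mutation=25,
                     num_crossover=25, mutate_prob=0.1, flops_limit=465e6):
    """Simplified evolutionary subnet search (illustrative only)."""

    def sample_valid(gen):
        # Rejection-sample until the FLOPs constraint is met.
        while True:
            cand = gen()
            if flops_of(cand) <= flops_limit:
                return cand

    # candidate_pool_size: initial population of constraint-satisfying subnets.
    pool = [sample_valid(sample_subnet) for _ in range(pool_size)]

    for _ in range(max_epoch):
        # candidate_top_k: elites kept by validation score (score_key).
        top = sorted(pool, key=evaluate, reverse=True)[:top_k]
        # num_mutation: children made by re-sampling each choice of an elite
        # with probability mutate_prob.
        mutants = [
            sample_valid(lambda: mutate(random.choice(top), mutate_prob))
            for _ in range(num_mutation)
        ]
        # num_crossover: children mixing the choices of two elite parents.
        crossed = [
            sample_valid(
                lambda: crossover(random.choice(top), random.choice(top)))
            for _ in range(num_crossover)
        ]
        pool = top + mutants + crossed

    return max(pool, key=evaluate)
```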

@@ -0,0 +1,27 @@
_base_ = [
    './spos_subnet_mobilenet_proxyless_gpu_8xb128_in1k.py',
]

img_norm_cfg = dict(mean=[0., 0., 0.], std=[1., 1., 1.], to_rgb=False)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))
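Note the effectively disabled normalization (`mean=[0,0,0]`, `std=[1,1,1]`, `to_rgb=False`): images stay as raw BGR values in [0, 255], presumably to match how the AngleNAS checkpoint was trained. A quick standalone check that this `Normalize` step is a no-op (NumPy, not part of this PR):

```python
import numpy as np

img = np.random.randint(0, 256, (224, 224, 3)).astype(np.float32)
mean, std = np.zeros(3), np.ones(3)

normalized = (img - mean) / std
assert np.allclose(normalized, img)  # normalization leaves pixels unchanged
```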

@@ -0,0 +1,10 @@
_base_ = [
    './spos_supernet_mobilenet_proxyless_gpu_8xb128_in1k.py',
]

algorithm = dict(retraining=True)
evaluation = dict(interval=10000, metric='accuracy')
checkpoint_config = dict(interval=30000)

runner = dict(max_iters=300000)
find_unused_parameters = False
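With `retraining=True`, the algorithm still needs to know which subnet to retrain; in this codebase that is typically supplied as a searched mutable config such as the YAML above. A hedged example — the `algorithm.mutable_cfg` key and the CLI override below are assumptions, not shown in this diff:

```bash
# Hypothetical: retrain the searched subnet by pointing the algorithm at the
# mutable YAML shown earlier in this PR.
python ./tools/mmcls/train_mmcls.py \
    configs/nas/spos/spos_subnet_mobilenet_proxyless_gpu_8xb128_in1k.py \
    --work-dir $WORK_DIR \
    --cfg-options algorithm.mutable_cfg=$MUTABLE_CFG_PATH
```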

@@ -0,0 +1,101 @@
_base_ = [
    '../../_base_/datasets/mmcls/imagenet_bs128_colorjittor.py',
    '../../_base_/schedules/mmcls/imagenet_bs1024_spos.py',
    '../../_base_/mmcls_runtime.py'
]
norm_cfg = dict(type='BN')
model = dict(
    type='mmcls.ImageClassifier',
    backbone=dict(
        type='SearchableMobileNet',
        first_channels=40,
        last_channels=1728,
        widen_factor=1.0,
        norm_cfg=norm_cfg,
        arch_setting_type='proxyless_gpu'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1728,
        loss=dict(
            type='LabelSmoothLoss',
            num_classes=1000,
            label_smooth_val=0.1,
            mode='original',
            loss_weight=1.0),
        topk=(1, 5),
    ),
)

mutator = dict(
    type='OneShotMutator',
    placeholder_mapping=dict(
        searchable_blocks=dict(
            type='OneShotOP',
            choices=dict(
                mb_k3e3=dict(
                    type='MBBlock',
                    kernel_size=3,
                    expand_ratio=3,
                    norm_cfg=norm_cfg,
                    act_cfg=dict(type='ReLU6')),
                mb_k5e3=dict(
                    type='MBBlock',
                    kernel_size=5,
                    expand_ratio=3,
                    norm_cfg=norm_cfg,
                    act_cfg=dict(type='ReLU6')),
                mb_k7e3=dict(
                    type='MBBlock',
                    kernel_size=7,
                    expand_ratio=3,
                    norm_cfg=norm_cfg,
                    act_cfg=dict(type='ReLU6')),
                mb_k3e6=dict(
                    type='MBBlock',
                    kernel_size=3,
                    expand_ratio=6,
                    norm_cfg=norm_cfg,
                    act_cfg=dict(type='ReLU6')),
                mb_k5e6=dict(
                    type='MBBlock',
                    kernel_size=5,
                    expand_ratio=6,
                    norm_cfg=norm_cfg,
                    act_cfg=dict(type='ReLU6')),
                mb_k7e6=dict(
                    type='MBBlock',
                    kernel_size=7,
                    expand_ratio=6,
                    norm_cfg=norm_cfg,
                    act_cfg=dict(type='ReLU6')),
                identity=dict(type='Identity'))),
        first_blocks=dict(
            type='OneShotOP',
            choices=dict(
                mb_k3e1=dict(
                    type='MBBlock',
                    kernel_size=3,
                    expand_ratio=1,
                    norm_cfg=norm_cfg,
                    act_cfg=dict(type='ReLU6')), ))))

algorithm = dict(
    type='SPOS',
    architecture=dict(
        type='MMClsArchitecture',
        model=model,
    ),
    mutator=mutator,
    distiller=None,
    retraining=False,
)

runner = dict(max_iters=150000)
evaluation = dict(interval=10000, metric='accuracy')

# checkpoint saving
checkpoint_config = dict(interval=30000)

find_unused_parameters = True
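The choice names encode kernel size and expand ratio (`mb_k3e6` = 3x3 depthwise kernel, expand ratio 6), matching the `chosen` values in the searched YAML. As a side note, the six searchable choices could equivalently be generated programmatically; a sketch, not how this config is actually written:

```python
from itertools import product

norm_cfg = dict(type='BN')

# Build the same six MBBlock choices as the config above.
choices = {
    f'mb_k{k}e{e}': dict(
        type='MBBlock',
        kernel_size=k,
        expand_ratio=e,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='ReLU6'))
    for k, e in product((3, 5, 7), (3, 6))
}
choices['identity'] = dict(type='Identity')
```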

@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .darts_backbone import DartsBackbone
from .searchable_mobilenet import SearchableMobileNet
from .searchable_shufflenet_v2 import SearchableShuffleNetV2

__all__ = ['DartsBackbone', 'SearchableShuffleNetV2', 'SearchableMobileNet']

@@ -0,0 +1,214 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcls.models.builder import BACKBONES
from mmcls.models.utils import make_divisible
from mmcv.cnn import ConvModule
from torch.nn.modules.batchnorm import _BatchNorm

from ...utils import Placeholder


@BACKBONES.register_module()
class SearchableMobileNet(BaseBackbone):
    """Searchable MobileNet backbone.

    Args:
        first_channels (int): Channel width of first ConvModule. Default: 32.
        last_channels (int): Channel width of last ConvModule. Default: 1280.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Default: 1.0.
        out_indices (None or Sequence[int]): Output from which stages.
            Default: (7, ).
        frozen_stages (int): Stages to be frozen (all params fixed).
            Default: -1, which means not freezing any parameters.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU6').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        arch_setting_type (str): Specify architecture setting.
            Default: 'original'.
        init_cfg (dict | list[dict]): initialization configuration dict to
            define initializer. OpenMMLab has implemented 6 initializers
            including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
            ``Kaiming``, and ``Pretrained``.
    """

    # Parameters to build layers. 3 parameters are needed to construct a
    # layer, from left to right: channel, num_blocks, stride.
    arch_settings_dict = {
        'original': [[16, 1, 1], [24, 2, 2], [32, 3, 2], [64, 4, 2],
                     [96, 3, 1], [160, 3, 2], [320, 1, 1]],
        'proxyless_gpu': [[24, 1, 1], [32, 4, 2], [56, 4, 2], [112, 4, 2],
                          [128, 4, 1], [256, 4, 2], [432, 1, 1]],
    }

    def __init__(self,
                 first_channels=32,
                 last_channels=1280,
                 widen_factor=1.,
                 out_indices=(7, ),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU6'),
                 norm_eval=False,
                 with_cp=False,
                 arch_setting_type='original',
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ]):
        super(SearchableMobileNet, self).__init__(init_cfg)

        arch_settings = self.arch_settings_dict.get(arch_setting_type)
        if arch_settings is None:
            raise ValueError(f'Expect `arch_setting_type`: '
                             f'{list(self.arch_settings_dict.keys())}, '
                             f'but got: {arch_setting_type}')
        self.arch_settings = arch_settings
        self.widen_factor = widen_factor
        self.out_indices = out_indices
        for index in out_indices:
            if index not in range(0, 8):
                raise ValueError('the item in out_indices must be in '
                                 f'range(0, 8). But received {index}')

        if frozen_stages not in range(-1, 8):
            raise ValueError('frozen_stages must be in range(-1, 8). '
                             f'But received {frozen_stages}')
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        self.in_channels = make_divisible(first_channels * widen_factor, 8)

        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.layers = []

        for i, layer_cfg in enumerate(self.arch_settings):
            channel, num_blocks, stride = layer_cfg
            out_channels = make_divisible(channel * widen_factor, 8)
            inverted_res_layer = self.make_layer(
                out_channels=out_channels,
                num_blocks=num_blocks,
                stride=stride,
                stage_idx=i)
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, inverted_res_layer)
            self.layers.append(layer_name)

        if widen_factor > 1.0:
            self.out_channel = int(last_channels * widen_factor)
        else:
            self.out_channel = last_channels

        layer = ConvModule(
            in_channels=self.in_channels,
            out_channels=self.out_channel,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.add_module('conv2', layer)
        self.layers.append('conv2')

    def make_layer(self, out_channels, num_blocks, stride, stage_idx):
        """Stack searchable blocks to build a layer for SearchableMobileNet.

        Args:
            out_channels (int): out_channels of block.
            num_blocks (int): number of blocks.
            stride (int): stride of the first block. Default: 1.
            stage_idx (int): index of the stage; used to name each block's
                search space as ``stage_{stage_idx}_block_{i}``.
        """
        layers = []
        for i in range(num_blocks):
            if i >= 1:
                stride = 1
            # HACK
            # do not search first block
            if stage_idx == 0:
                group = 'first_blocks'
            else:
                group = 'searchable_blocks'
            layers.append(
                Placeholder(
                    group=group,
                    space_id=f'stage_{stage_idx}_block_{i}',
                    choice_args=dict(
                        in_channels=self.in_channels,
                        out_channels=out_channels,
                        stride=stride,
                    )))
            self.in_channels = out_channels

        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward computation.

        Args:
            x (tensor | tuple[tensor]): x could be a Torch.tensor or a tuple
                of Torch.tensor, containing input data for forward
                computation.
        """
        x = self.conv1(x)

        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def _freeze_stages(self):
        """Freeze params not to update in the specified stages."""
        if self.frozen_stages >= 0:
            for param in self.conv1.parameters():
                param.requires_grad = False
            for i in range(1, self.frozen_stages + 1):
                layer = getattr(self, f'layer{i}')
                layer.eval()
                for param in layer.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        """Set module status before forward computation.

        Args:
            mode (bool): Whether it is train_mode or test_mode
        """
        super(SearchableMobileNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
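A quick way to see how `make_layer` names the search space: enumerating the `proxyless_gpu` arch setting reproduces exactly the 22 `stage_{i}_block_{j}` keys of the searched YAML earlier in this PR (standalone sketch, not part of the diff):

```python
# One Placeholder per block, 22 in total across 7 stages.
proxyless_gpu = [[24, 1, 1], [32, 4, 2], [56, 4, 2], [112, 4, 2],
                 [128, 4, 1], [256, 4, 2], [432, 1, 1]]

space_ids = [
    f'stage_{stage_idx}_block_{i}'
    for stage_idx, (_, num_blocks, _) in enumerate(proxyless_gpu)
    for i in range(num_blocks)
]
print(len(space_ids))  # 22, matching the entries in the mutable YAML
```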

@@ -1,9 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .common import Identity
from .darts_series import (DartsDilConv, DartsPoolBN, DartsSepConv,
                           DartsSkipConnect, DartsZero)
from .mobilenet_series import MBBlock
from .shufflenet_series import ShuffleBlock, ShuffleXception

__all__ = [
    'ShuffleBlock', 'ShuffleXception', 'DartsPoolBN', 'DartsDilConv',
    'DartsSepConv', 'DartsSkipConnect', 'DartsZero', 'MBBlock', 'Identity'
]

@@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import ConvModule

from ..builder import OPS
from .base import BaseOP


@OPS.register_module()
class Identity(BaseOP):
    """Identity operation for searchable architectures.

    When ``stride != 1`` or ``in_channels != out_channels``, a 1x1
    ``ConvModule`` is used to align the output shape; otherwise the input
    is passed through unchanged.

    Args:
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: None.
    """

    def __init__(self,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=None,
                 **kwargs):
        super(Identity, self).__init__(**kwargs)

        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        if self.stride != 1 or self.in_channels != self.out_channels:
            self.downsample = ConvModule(
                self.in_channels,
                self.out_channels,
                kernel_size=1,
                stride=self.stride,
                padding=0,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        else:
            self.downsample = None

    def forward(self, x):
        if self.downsample is not None:
            x = self.downsample(x)
        return x

@@ -0,0 +1,129 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcls.models.utils import SELayer
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import DropPath

from ..builder import OPS
from .base import BaseOP


@OPS.register_module()
class MBBlock(BaseOP):
    """MobileNet block for searchable backbones.

    Args:
        kernel_size (int): Size of the convolving kernel.
        expand_ratio (int): The input channels' expand factor of the
            depthwise convolution.
        se_cfg (dict, optional): Config dict for se layer. Defaults to None,
            which means no se layer.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self,
                 kernel_size,
                 expand_ratio,
                 se_cfg=None,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 drop_path_rate=0.,
                 with_cp=False,
                 **kwargs):

        super(MBBlock, self).__init__(**kwargs)
        self.with_res_shortcut = (
            self.stride == 1 and self.in_channels == self.out_channels)
        assert self.stride in [1, 2]
        self.kernel_size = kernel_size
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.with_cp = with_cp
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0 else nn.Identity()
        self.with_se = se_cfg is not None
        self.mid_channels = self.in_channels * expand_ratio
        self.with_expand_conv = (self.mid_channels != self.in_channels)

        if self.with_se:
            assert isinstance(se_cfg, dict)

        if self.with_expand_conv:
            self.expand_conv = ConvModule(
                in_channels=self.in_channels,
                out_channels=self.mid_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        self.depthwise_conv = ConvModule(
            in_channels=self.mid_channels,
            out_channels=self.mid_channels,
            kernel_size=kernel_size,
            stride=self.stride,
            padding=kernel_size // 2,
            groups=self.mid_channels,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if self.with_se:
            self.se = SELayer(self.mid_channels, **se_cfg)
        self.linear_conv = ConvModule(
            in_channels=self.mid_channels,
            out_channels=self.out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): The input tensor.
        Returns:
            torch.Tensor: The output tensor.
        """

        def _inner_forward(x):
            out = x

            if self.with_expand_conv:
                out = self.expand_conv(out)

            out = self.depthwise_conv(out)

            if self.with_se:
                out = self.se(out)

            out = self.linear_conv(out)

            if self.with_res_shortcut:
                return x + self.drop_path(out)
            else:
                return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
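To make the block's channel arithmetic concrete: with `expand_ratio=6` the 1x1 expand conv widens 16 to 96 channels, the depthwise conv runs at the block stride, and the linear 1x1 projects to `out_channels`; the residual is only added when `stride == 1` and the channel counts match. A hedged usage sketch (assumes `mmrazor` is importable; mirrors the unit test further below):

```python
import torch

from mmrazor.models.builder import OPS

# Build a single MBBlock: expand 16 -> 96 (ratio 6), 5x5 depthwise at
# stride 2, then project 96 -> 32; no residual since in != out channels.
op = OPS.build(
    dict(
        type='MBBlock',
        in_channels=16,
        out_channels=32,
        stride=2,
        kernel_size=5,
        expand_ratio=6))

out = op(torch.randn(2, 16, 32, 32))
print(out.shape)  # torch.Size([2, 32, 16, 16])
```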

@@ -236,6 +236,90 @@ def test_spos():
    assert flops_supernet > flops_subnet_spos > 0


def test_spos_mobilenet():

    model_cfg = dict(
        type='mmcls.ImageClassifier',
        backbone=dict(type='SearchableMobileNet', widen_factor=1.0),
        neck=dict(type='mmcls.GlobalAveragePooling'),
        head=dict(
            type='mmcls.LinearClsHead',
            num_classes=1000,
            in_channels=1280,
            loss=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0),
            topk=(1, 5),
        ),
    )

    architecture_cfg = dict(
        type='MMClsArchitecture',
        model=model_cfg,
    )

    mutator_cfg = dict(
        type='OneShotMutator',
        placeholder_mapping=dict(
            searchable_blocks=dict(
                type='OneShotOP',
                choices=dict(
                    mb_k3e3=dict(
                        type='MBBlock',
                        kernel_size=3,
                        expand_ratio=3,
                        act_cfg=dict(type='ReLU6')),
                    mb_k5e3=dict(
                        type='MBBlock',
                        kernel_size=5,
                        expand_ratio=3,
                        act_cfg=dict(type='ReLU6')),
                    mb_k7e3=dict(
                        type='MBBlock',
                        kernel_size=7,
                        expand_ratio=3,
                        act_cfg=dict(type='ReLU6')),
                    mb_k3e6=dict(
                        type='MBBlock',
                        kernel_size=3,
                        expand_ratio=6,
                        act_cfg=dict(type='ReLU6')),
                    mb_k5e6=dict(
                        type='MBBlock',
                        kernel_size=5,
                        expand_ratio=6,
                        act_cfg=dict(type='ReLU6')),
                    mb_k7e6=dict(
                        type='MBBlock',
                        kernel_size=7,
                        expand_ratio=6,
                        act_cfg=dict(type='ReLU6')),
                    identity=dict(type='Identity'))),
            first_blocks=dict(
                type='OneShotOP',
                choices=dict(
                    mb_k3e1=dict(
                        type='MBBlock',
                        kernel_size=3,
                        expand_ratio=1,
                        act_cfg=dict(type='ReLU6')), ))))

    algorithm_cfg = dict(
        type='SPOS',
        architecture=architecture_cfg,
        mutator=mutator_cfg,
        retraining=False,
    )

    imgs = torch.randn(16, 3, 224, 224)
    label = torch.randint(0, 1000, (16, ))

    algorithm_cfg_ = deepcopy(algorithm_cfg)
    algorithm = ALGORITHMS.build(algorithm_cfg_)

    # test forward
    losses = algorithm(imgs, return_loss=True, gt_label=label)
    assert losses['loss'].item() > 0


def test_detnas():
    config_path = os.path.join(
        dirname(dirname(dirname(__file__))),

@@ -4,6 +4,40 @@ import torch
from mmrazor.models.builder import OPS


def test_common_ops():
    tensor = torch.randn(16, 16, 32, 32)

    # test stride != 1
    identity_cfg = dict(
        type='Identity', in_channels=16, out_channels=16, stride=2)

    op = OPS.build(identity_cfg)

    # test forward
    outputs = op(tensor)
    assert outputs.size(1) == 16 and outputs.size(2) == 16

    # test stride == 1
    identity_cfg = dict(
        type='Identity', in_channels=16, out_channels=16, stride=1)

    op = OPS.build(identity_cfg)

    # test forward
    outputs = op(tensor)
    assert outputs.size(1) == 16 and outputs.size(2) == 32

    # test in_channels != out_channels
    identity_cfg = dict(
        type='Identity', in_channels=8, out_channels=16, stride=1)

    op = OPS.build(identity_cfg)

    # test forward
    outputs = op(tensor[:, :8])
    assert outputs.size(1) == 16 and outputs.size(2) == 32


def test_shuffle_series():

    tensor = torch.randn(16, 16, 32, 32)

@@ -61,6 +95,51 @@ def test_shuffle_series():
    assert outputs.size(1) == 16 and outputs.size(2) == 32


def test_mobilenet_series():

    tensor = torch.randn(16, 16, 32, 32)

    kernel_sizes = (3, 5, 7)
    expand_ratios = (3, 6)
    strides = (1, 2)
    se_cfg_1 = dict(
        ratio=4,
        act_cfg=(dict(type='HSwish'),
                 dict(
                     type='HSigmoid',
                     bias=3,
                     divisor=6,
                     min_value=0,
                     max_value=1)))
    se_cfgs = (None, se_cfg_1)
    drop_path_rates = (0, 0.2)
    with_cps = (True, False)

    for kernel_size in kernel_sizes:
        for expand_ratio in expand_ratios:
            for stride in strides:
                for se_cfg in se_cfgs:
                    for drop_path_rate in drop_path_rates:
                        for with_cp in with_cps:
                            op_cfg = dict(
                                type='MBBlock',
                                in_channels=16,
                                out_channels=16,
                                kernel_size=kernel_size,
                                expand_ratio=expand_ratio,
                                se_cfg=se_cfg,
                                drop_path_rate=drop_path_rate,
                                with_cp=with_cp,
                                stride=stride)

                            op = OPS.build(op_cfg)

                            # test forward
                            outputs = op(tensor)
                            assert outputs.size(1) == 16 and outputs.size(
                                2) == 32 // stride


def test_darts_series():

    tensor = torch.randn(16, 16, 32, 32)