[FEATURE] Mobilenet Series Search Space (#82)

* add mbv2 block and identity

* add mbv2 block and identity unittests

* expand_factor -> expand_ratio

* add searchable mobilenet v2

* fix lints

* add spos subnet retraining config

* fix mmcls slurm search

* add proxyless_gpu setting for mbv2

* use bn default

* add angelnas spos config

* update spos readme

* fix SELayer's usage

* add docstring

* rename mbv2 to mb

* add some unittest of mb

* rename mb to mobilenet

* add some rename-mb in configs

* update README of spos

* add rename-mb in unittest

* update test_mmcls

Co-authored-by: wutongshenqiu <690364065@qq.com>
Co-authored-by: humu789 <humu@pjlab.org.cn>
pppppM 2022-03-07 23:15:18 +08:00 committed by GitHub
parent f59e059cf7
commit ef1637e866
13 changed files with 799 additions and 10 deletions


@ -1,4 +1,5 @@
# SPOS
> [Single Path One-Shot Neural Architecture Search with Uniform Sampling](https://arxiv.org/abs/1904.00420)
<!-- [ALGORITHM] -->
@ -10,16 +11,18 @@ Comprehensive experiments verify that our approach is flexible and effective. It
![pipeline](/docs/en/imgs/model_zoo/spos/pipeline.jpg)
## Introduction
### Supernet pre-training on ImageNet
```bash
python ./tools/mmcls/train_mmcls.py \
configs/nas/spos/spos_supernet_shufflenetv2_8xb128_in1k.py \
--work-dir $WORK_DIR
```
### Search for subnet on the trained supernet
```bash
python ./tools/mmcls/search_mmcls.py \
configs/nas/spos/spos_evolution_search_shufflenetv2_8xb2048_in1k.py \
@ -28,6 +31,7 @@ python ./tools/mmcls/search_mmcls.py \
```
### Subnet retraining on ImageNet
```bash
python ./tools/mmcls/train_mmcls.py \
configs/nas/spos/spos_subnet_shufflenetv2_8xb128_in1k.py \
@ -36,14 +40,18 @@ python ./tools/mmcls/train_mmcls.py \
```
## Results and models
|Dataset| Supernet | Subnet | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | Remarks |
|:---------------------:|:---------------------:|:------:|:---------:|:--------:|:---------:|:---------:|:------:|:---------|:---------:|
|ImageNet| ShuffleNetV2 |[mutable](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/spos/spos_shufflenetv2_subnet_8xb128_in1k/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-454627be_mutable_cfg.yaml?versionId=CAEQHxiBgICw5b6I7xciIGY5MjVmNWFhY2U5MjQzN2M4NDViYzI2YWRmYWE1YzQx)| 3.35 | 0.33 | 73.87 | 91.6 |[config](./spos_subnet_shufflenetv2_8xb128_in1k.py)|[model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/spos/spos_shufflenetv2_subnet_8xb128_in1k/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d.pth?versionId=CAEQHxiBgIDK5b6I7xciIDM1YjIwZjQxN2UyMDRjYjA5YTM5NTBlMGNhMTdkNjI2) &#124; [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/spos/spos_shufflenetv2_subnet_8xb128_in1k/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d.log.json?versionId=CAEQHxiBgIDr9cuL7xciIDBmOTZiZGUyYjRiMDQ5NzhhZjY0NWUxYmUzNDlmNTg5)| MMRazor searched
| Dataset | Supernet | Subnet | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | Remarks |
| :------: |:----------------------:| :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------: | :------: | :-------: | :-------: | :----------------------------------------------: |:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-------------------------------------------------------------:|
| ImageNet | ShuffleNetV2 | [mutable](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/spos/spos_shufflenetv2_subnet_8xb128_in1k/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-454627be_mutable_cfg.yaml?versionId=CAEQHxiBgICw5b6I7xciIGY5MjVmNWFhY2U5MjQzN2M4NDViYzI2YWRmYWE1YzQx) | 3.35 | 0.33 | 73.87 | 91.6 | [config](./spos_subnet_shufflenetv2_8xb128_in1k.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/spos/spos_shufflenetv2_subnet_8xb128_in1k/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d.pth?versionId=CAEQHxiBgIDK5b6I7xciIDM1YjIwZjQxN2UyMDRjYjA5YTM5NTBlMGNhMTdkNjI2) &#124; [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/spos/spos_shufflenetv2_subnet_8xb128_in1k/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d.log.json?versionId=CAEQHxiBgIDr9cuL7xciIDBmOTZiZGUyYjRiMDQ5NzhhZjY0NWUxYmUzNDlmNTg5) | MMRazor searched |
| ImageNet | MobileNet-ProxylessGPU | [mutable](https://download.openmmlab.com/mmrazor/v0.1/nas/spos/spos_mobilenet_subnet/spos_angelnas_flops_0.49G_acc_75.98_20220307-54f4698f_mutable_cfg.yaml) | 5.94 | 0.49* | 75.98 | 92.77 | [config](./spos_mobilenet_for_check_ckpt_from_anglenas.py) | | [AngleNAS](https://github.com/megvii-model/AngleNAS) searched |
**Note**:
1. There are some small differences in our experiment in order to be consistent with other repos in OpenMMLab. For example,
normalize images in data preprocessing; resize by cv2 rather than PIL in training; dropout is not used in network.
2. We also retrain the subnet reported in paper with their official code, Top-1 is 73.6 and Top-5 is 91.6
1. There **might be (not in all cases)** some small differences in our experiments in order to be consistent with other repos in OpenMMLab. For example,
images are normalized in data preprocessing; resizing uses cv2 rather than PIL in training; dropout is not used in the network. **Please refer to the corresponding config for details.**
2. For *ShuffleNetV2*, we retrained the subnet reported in the paper with its official code; Top-1 is 73.6 and Top-5 is 91.6.
3. For the *MobileNet-ProxylessGPU* subnet searched by AngleNAS, params and FLOPs are obtained with [this script](/tools/misc/get_flops.py) and may differ from those reported by [AngleNAS](https://github.com/megvii-model/AngleNAS#searched-models-with-abs) (see the sketch below).
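A hedged sketch of how such numbers can be reproduced with mmcv's complexity counter, which scripts like `tools/misc/get_flops.py` typically wrap (the tiny model below is only a stand-in, not the searched subnet):

```python
# Minimal sketch, assuming mmcv 1.x; the toy model is only a placeholder for
# the real backbone, so the printed numbers themselves are meaningless.
import torch.nn as nn
from mmcv.cnn import get_model_complexity_info

toy_model = nn.Sequential(
    nn.Conv2d(3, 40, 3, stride=2, padding=1),  # 40 = first_channels of the proxyless_gpu setting
    nn.BatchNorm2d(40),
    nn.ReLU6(inplace=True))
flops, params = get_model_complexity_info(
    toy_model, (3, 224, 224), print_per_layer_stat=False, as_strings=True)
print(flops, params)
```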
## Citation


@ -0,0 +1,66 @@
stage_0_block_0:
chosen:
- mb_k3e1
stage_1_block_0:
chosen:
- mb_k5e3
stage_1_block_1:
chosen:
- mb_k5e3
stage_1_block_2:
chosen:
- identity
stage_1_block_3:
chosen:
- mb_k3e3
stage_2_block_0:
chosen:
- mb_k3e3
stage_2_block_1:
chosen:
- identity
stage_2_block_2:
chosen:
- identity
stage_2_block_3:
chosen:
- mb_k3e3
stage_3_block_0:
chosen:
- mb_k7e6
stage_3_block_1:
chosen:
- identity
stage_3_block_2:
chosen:
- mb_k7e3
stage_3_block_3:
chosen:
- mb_k7e3
stage_4_block_0:
chosen:
- mb_k3e3
stage_4_block_1:
chosen:
- mb_k3e3
stage_4_block_2:
chosen:
- mb_k7e3
stage_4_block_3:
chosen:
- mb_k5e3
stage_5_block_0:
chosen:
- mb_k5e6
stage_5_block_1:
chosen:
- mb_k7e3
stage_5_block_2:
chosen:
- mb_k7e3
stage_5_block_3:
chosen:
- mb_k7e3
stage_6_block_0:
chosen:
- mb_k5e6
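The YAML above records the AngleNAS-searched subnet as a mapping from each placeholder's `space_id` to its chosen op. A minimal sketch of inspecting it (the file name is a hypothetical local copy; assumes PyYAML):

```python
# Minimal sketch: print which op was chosen for every searchable block.
import yaml

with open('spos_mobilenet_proxyless_gpu_mutable_cfg.yaml') as f:  # hypothetical local path
    mutable_cfg = yaml.safe_load(f)

for space_id, record in sorted(mutable_cfg.items()):
    print(f"{space_id}: {record['chosen'][0]}")
# e.g. stage_1_block_2: identity, stage_3_block_0: mb_k7e6
```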


@ -0,0 +1,20 @@
_base_ = ['./spos_supernet_mobilenet_proxyless_gpu_8xb128_in1k.py']
data = dict(
samples_per_gpu=512,
workers_per_gpu=16,
)
algorithm = dict(bn_training_mode=True)
searcher = dict(
type='EvolutionSearcher',
candidate_pool_size=50,
candidate_top_k=10,
constraints=dict(flops=465 * 1e6),
metrics='accuracy',
score_key='accuracy_top-1',
max_epoch=20,
num_mutation=25,
num_crossover=25,
mutate_prob=0.1)
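To make the searcher options above concrete, here is a conceptual sketch (not the actual `EvolutionSearcher` implementation) of what `num_mutation`, `num_crossover` and `mutate_prob` control when a candidate subnet is represented as a `space_id -> op name` dict; the real searcher additionally filters candidates by the FLOPs constraint and keeps only the `candidate_top_k` best by `score_key` each epoch.

```python
# Conceptual sketch only, under the assumption that a candidate is a plain
# dict mapping space_id to an op name from the search space.
import random

CHOICES = ['mb_k3e3', 'mb_k5e3', 'mb_k7e3', 'mb_k3e6', 'mb_k5e6', 'mb_k7e6', 'identity']

def mutate(subnet, mutate_prob=0.1):
    """Resample each block independently with probability `mutate_prob`."""
    return {k: random.choice(CHOICES) if random.random() < mutate_prob else v
            for k, v in subnet.items()}

def crossover(parent_a, parent_b):
    """Per-block uniform crossover between two parent subnets."""
    return {k: random.choice((parent_a[k], parent_b[k])) for k in parent_a}
```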


@ -0,0 +1,27 @@
_base_ = [
'./spos_subnet_mobilenet_proxyless_gpu_8xb128_in1k.py',
]
img_norm_cfg = dict(mean=[0., 0., 0.], std=[1., 1., 1.], to_rgb=False)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', size=224),
dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', size=(256, -1)),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))


@ -0,0 +1,10 @@
_base_ = [
'./spos_supernet_mobilenet_proxyless_gpu_8xb128_in1k.py',
]
algorithm = dict(retraining=True)
evaluation = dict(interval=10000, metric='accuracy')
checkpoint_config = dict(interval=30000)
runner = dict(max_iters=300000)
find_unused_parameters = False


@ -0,0 +1,101 @@
_base_ = [
'../../_base_/datasets/mmcls/imagenet_bs128_colorjittor.py',
'../../_base_/schedules/mmcls/imagenet_bs1024_spos.py',
'../../_base_/mmcls_runtime.py'
]
norm_cfg = dict(type='BN')
model = dict(
type='mmcls.ImageClassifier',
backbone=dict(
type='SearchableMobileNet',
first_channels=40,
last_channels=1728,
widen_factor=1.0,
norm_cfg=norm_cfg,
arch_setting_type='proxyless_gpu'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=1728,
loss=dict(
type='LabelSmoothLoss',
num_classes=1000,
label_smooth_val=0.1,
mode='original',
loss_weight=1.0),
topk=(1, 5),
),
)
mutator = dict(
type='OneShotMutator',
placeholder_mapping=dict(
searchable_blocks=dict(
type='OneShotOP',
choices=dict(
mb_k3e3=dict(
type='MBBlock',
kernel_size=3,
expand_ratio=3,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU6')),
mb_k5e3=dict(
type='MBBlock',
kernel_size=5,
expand_ratio=3,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU6')),
mb_k7e3=dict(
type='MBBlock',
kernel_size=7,
expand_ratio=3,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU6')),
mb_k3e6=dict(
type='MBBlock',
kernel_size=3,
expand_ratio=6,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU6')),
mb_k5e6=dict(
type='MBBlock',
kernel_size=5,
expand_ratio=6,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU6')),
mb_k7e6=dict(
type='MBBlock',
kernel_size=7,
expand_ratio=6,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU6')),
identity=dict(type='Identity'))),
first_blocks=dict(
type='OneShotOP',
choices=dict(
mb_k3e1=dict(
type='MBBlock',
kernel_size=3,
expand_ratio=1,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU6')), ))))
algorithm = dict(
type='SPOS',
architecture=dict(
type='MMClsArchitecture',
model=model,
),
mutator=mutator,
distiller=None,
retraining=False,
)
runner = dict(max_iters=150000)
evaluation = dict(interval=10000, metric='accuracy')
# checkpoint saving
checkpoint_config = dict(interval=30000)
find_unused_parameters = True
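A minimal sketch of exercising this supernet config in Python, mirroring the `test_spos_mobilenet` unit test added later in this commit (the builder import path is assumed to be the same one the op tests use):

```python
# Sketch: build the SPOS algorithm from the `algorithm` dict above and run one
# training-mode forward pass on random data.
import torch
from mmrazor.models.builder import ALGORITHMS  # assumed import path

spos = ALGORITHMS.build(algorithm)  # uses the `model` and `mutator` dicts defined above
imgs = torch.randn(2, 3, 224, 224)
label = torch.randint(0, 1000, (2, ))
losses = spos(imgs, return_loss=True, gt_label=label)
print(losses['loss'])
```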


@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .darts_backbone import DartsBackbone
from .searchable_mobilenet import SearchableMobileNet
from .searchable_shufflenet_v2 import SearchableShuffleNetV2
__all__ = ['DartsBackbone', 'SearchableShuffleNetV2']
__all__ = ['DartsBackbone', 'SearchableShuffleNetV2', 'SearchableMobileNet']


@ -0,0 +1,214 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcls.models.builder import BACKBONES
from mmcls.models.utils import make_divisible
from mmcv.cnn import ConvModule
from torch.nn.modules.batchnorm import _BatchNorm
from ...utils import Placeholder
@BACKBONES.register_module()
class SearchableMobileNet(BaseBackbone):
"""Searchable MobileNet backbone.
Args:
first_channels (int): Channel width of first ConvModule. Default: 32.
last_channels (int): Channel width of last ConvModule. Default: 1280.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Default: 1.0.
out_indices (None or Sequence[int]): Output from which stages.
Default: (7, ).
frozen_stages (int): Stages to be frozen (all param fixed).
Default: -1, which means not freezing any parameters.
conv_cfg (dict, optional): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU6').
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
arch_setting_type (str): Which entry of ``arch_settings_dict`` to use,
'original' or 'proxyless_gpu'. Default: 'original'.
init_cfg (dict | list[dict]): initialization configuration dict to
define initializer. OpenMMLab has implemented 6 initializers
including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
``Kaiming``, and ``Pretrained``.
"""
# Parameters to build layers. 3 parameters are needed to construct a
# layer, from left to right: channel, num_blocks, stride.
arch_settings_dict = {
'original': [[16, 1, 1], [24, 2, 2], [32, 3, 2], [64, 4, 2],
[96, 3, 1], [160, 3, 2], [320, 1, 1]],
'proxyless_gpu': [[24, 1, 1], [32, 4, 2], [56, 4, 2], [112, 4, 2],
[128, 4, 1], [256, 4, 2], [432, 1, 1]],
}
def __init__(self,
first_channels=32,
last_channels=1280,
widen_factor=1.,
out_indices=(7, ),
frozen_stages=-1,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6'),
norm_eval=False,
with_cp=False,
arch_setting_type='original',
init_cfg=[
dict(type='Kaiming', layer=['Conv2d']),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]):
super(SearchableMobileNet, self).__init__(init_cfg)
arch_settings = self.arch_settings_dict.get(arch_setting_type)
if arch_settings is None:
raise ValueError('Expect `arch_setting_type` to be one of '
f'{list(self.arch_settings_dict.keys())}, '
f'but got: {arch_setting_type}')
self.arch_settings = arch_settings
self.widen_factor = widen_factor
for index in out_indices:
if index not in range(0, 8):
raise ValueError('the item in out_indices must be in '
f'range(0, 8). But received {index}')
if frozen_stages not in range(-1, 8):
raise ValueError('frozen_stages must be in range(-1, 8). '
f'But received {frozen_stages}')
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
self.in_channels = make_divisible(first_channels * widen_factor, 8)
self.conv1 = ConvModule(
in_channels=3,
out_channels=self.in_channels,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.layers = []
for i, layer_cfg in enumerate(self.arch_settings):
channel, num_blocks, stride = layer_cfg
out_channels = make_divisible(channel * widen_factor, 8)
inverted_res_layer = self.make_layer(
out_channels=out_channels,
num_blocks=num_blocks,
stride=stride,
stage_idx=i)
layer_name = f'layer{i + 1}'
self.add_module(layer_name, inverted_res_layer)
self.layers.append(layer_name)
if widen_factor > 1.0:
self.out_channel = int(last_channels * widen_factor)
else:
self.out_channel = last_channels
layer = ConvModule(
in_channels=self.in_channels,
out_channels=self.out_channel,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.add_module('conv2', layer)
self.layers.append('conv2')
def make_layer(self, out_channels, num_blocks, stride, stage_idx):
"""Stack InvertedResidual blocks to build a layer for MobileNetV2.
Args:
out_channels (int): out_channels of block.
num_blocks (int): number of blocks.
stride (int): stride of the first block. Default: 1
expand_ratio (int): Expand the number of channels of the
hidden layer in InvertedResidual by this ratio. Default: 6.
"""
layers = []
for i in range(num_blocks):
if i >= 1:
stride = 1
# HACK
# do not search first block
if stage_idx == 0:
group = 'first_blocks'
else:
group = 'searchable_blocks'
layers.append(
Placeholder(
group=group,
space_id=f'stage_{stage_idx}_block_{i}',
choice_args=dict(
in_channels=self.in_channels,
out_channels=out_channels,
stride=stride,
)))
self.in_channels = out_channels
return nn.Sequential(*layers)
def forward(self, x):
"""Forward computation.
Args:
x (torch.Tensor | tuple[torch.Tensor]): Input data for forward computation.
"""
x = self.conv1(x)
outs = []
for i, layer_name in enumerate(self.layers):
layer = getattr(self, layer_name)
x = layer(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
def _freeze_stages(self):
"""Freeze params not to update in the specified stages."""
if self.frozen_stages >= 0:
for param in self.conv1.parameters():
param.requires_grad = False
for i in range(1, self.frozen_stages + 1):
layer = getattr(self, f'layer{i}')
layer.eval()
for param in layer.parameters():
param.requires_grad = False
def train(self, mode=True):
"""Set module status before forward computation.
Args:
mode (bool): Whether it is train_mode or test_mode
"""
super(SearchableMobileNet, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()
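For reference, every width in `__init__`/`make_layer` is scaled by `widen_factor` and rounded to a multiple of 8 with mmcls's `make_divisible`; a small worked example:

```python
# Worked example of the channel rounding used by SearchableMobileNet.
from mmcls.models.utils import make_divisible

print(make_divisible(32 * 1.0, 8))   # first conv width of the 'original' setting -> 32
print(make_divisible(56 * 0.75, 8))  # a 'proxyless_gpu' stage width at widen_factor=0.75 -> 40
```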


@ -1,9 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .common import Identity
from .darts_series import (DartsDilConv, DartsPoolBN, DartsSepConv,
DartsSkipConnect, DartsZero)
from .mobilenet_series import MBBlock
from .shufflenet_series import ShuffleBlock, ShuffleXception
__all__ = [
'ShuffleBlock', 'ShuffleXception', 'DartsPoolBN', 'DartsDilConv',
'DartsSepConv', 'DartsSkipConnect', 'DartsZero'
'DartsSepConv', 'DartsSkipConnect', 'DartsZero', 'MBBlock', 'Identity'
]


@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import ConvModule
from ..builder import OPS
from .base import BaseOP
@OPS.register_module()
class Identity(BaseOP):
"""Base class for searchable operations.
Args:
conv_cfg (dict, optional): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: None.
"""
def __init__(self,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=None,
**kwargs):
super(Identity, self).__init__(**kwargs)
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
if self.stride != 1 or self.in_channels != self.out_channels:
self.downsample = ConvModule(
self.in_channels,
self.out_channels,
kernel_size=1,
stride=self.stride,
padding=0,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
else:
self.downsample = None
def forward(self, x):
if self.downsample is not None:
x = self.downsample(x)
return x
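A minimal usage sketch (mirroring the op unit tests in this commit): when the stride or channel counts differ, this `Identity` falls back to a 1x1 `ConvModule` so its output shape still matches the other candidate ops in the same block.

```python
import torch
from mmrazor.models.builder import OPS

# stride=2 inserts the 1x1 downsampling conv, halving the spatial size
op = OPS.build(dict(type='Identity', in_channels=16, out_channels=16, stride=2))
print(op(torch.randn(2, 16, 32, 32)).shape)  # torch.Size([2, 16, 16, 16])
```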


@ -0,0 +1,129 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcls.models.utils import SELayer
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import DropPath
from ..builder import OPS
from .base import BaseOP
@OPS.register_module()
class MBBlock(BaseOP):
"""Mobilenet block for Searchable backbone.
Args:
kernel_size (int): Size of the convolving kernel.
expand_ratio (int): Ratio by which the input channels are expanded
before the depthwise convolution.
se_cfg (dict, optional): Config dict for se layer. Defaults to None,
which means no se layer.
conv_cfg (dict, optional): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
drop_path_rate (float): stochastic depth rate. Defaults to 0.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
Returns:
Tensor: The output tensor.
"""
def __init__(self,
kernel_size,
expand_ratio,
se_cfg=None,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
drop_path_rate=0.,
with_cp=False,
**kwargs):
super(MBBlock, self).__init__(**kwargs)
self.with_res_shortcut = (
self.stride == 1 and self.in_channels == self.out_channels)
assert self.stride in [1, 2]
self.kernel_size = kernel_size
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.with_cp = with_cp
self.drop_path = DropPath(
drop_path_rate) if drop_path_rate > 0 else nn.Identity()
self.with_se = se_cfg is not None
self.mid_channels = self.in_channels * expand_ratio
self.with_expand_conv = (self.mid_channels != self.in_channels)
if self.with_se:
assert isinstance(se_cfg, dict)
if self.with_expand_conv:
self.expand_conv = ConvModule(
in_channels=self.in_channels,
out_channels=self.mid_channels,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.depthwise_conv = ConvModule(
in_channels=self.mid_channels,
out_channels=self.mid_channels,
kernel_size=kernel_size,
stride=self.stride,
padding=kernel_size // 2,
groups=self.mid_channels,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
if self.with_se:
self.se = SELayer(self.mid_channels, **se_cfg)
self.linear_conv = ConvModule(
in_channels=self.mid_channels,
out_channels=self.out_channels,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None)
def forward(self, x):
"""Forward function.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor.
"""
def _inner_forward(x):
out = x
if self.with_expand_conv:
out = self.expand_conv(out)
out = self.depthwise_conv(out)
if self.with_se:
out = self.se(out)
out = self.linear_conv(out)
if self.with_res_shortcut:
return x + self.drop_path(out)
else:
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
return out
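A minimal usage sketch of `MBBlock` (again mirroring the unit tests): with `expand_ratio=6` the hidden width is `16 * 6 = 96`, and with `stride=2` the residual shortcut is skipped and the spatial size is halved.

```python
import torch
from mmrazor.models.builder import OPS

op = OPS.build(dict(
    type='MBBlock', in_channels=16, out_channels=16,
    kernel_size=5, expand_ratio=6, stride=2,
    act_cfg=dict(type='ReLU6')))
print(op(torch.randn(2, 16, 32, 32)).shape)  # torch.Size([2, 16, 16, 16])
```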


@ -236,6 +236,90 @@ def test_spos():
assert flops_supernet > flops_subnet_spos > 0
def test_spos_mobilenet():
model_cfg = dict(
type='mmcls.ImageClassifier',
backbone=dict(type='SearchableMobileNet', widen_factor=1.0),
neck=dict(type='mmcls.GlobalAveragePooling'),
head=dict(
type='mmcls.LinearClsHead',
num_classes=1000,
in_channels=1280,
loss=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
),
)
architecture_cfg = dict(
type='MMClsArchitecture',
model=model_cfg,
)
mutator_cfg = dict(
type='OneShotMutator',
placeholder_mapping=dict(
searchable_blocks=dict(
type='OneShotOP',
choices=dict(
mb_k3e3=dict(
type='MBBlock',
kernel_size=3,
expand_ratio=3,
act_cfg=dict(type='ReLU6')),
mb_k5e3=dict(
type='MBBlock',
kernel_size=5,
expand_ratio=3,
act_cfg=dict(type='ReLU6')),
mb_k7e3=dict(
type='MBBlock',
kernel_size=7,
expand_ratio=3,
act_cfg=dict(type='ReLU6')),
mb_k3e6=dict(
type='MBBlock',
kernel_size=3,
expand_ratio=6,
act_cfg=dict(type='ReLU6')),
mb_k5e6=dict(
type='MBBlock',
kernel_size=5,
expand_ratio=6,
act_cfg=dict(type='ReLU6')),
mb_k7e6=dict(
type='MBBlock',
kernel_size=7,
expand_ratio=6,
act_cfg=dict(type='ReLU6')),
identity=dict(type='Identity'))),
first_blocks=dict(
type='OneShotOP',
choices=dict(
mb_k3e1=dict(
type='MBBlock',
kernel_size=3,
expand_ratio=1,
act_cfg=dict(type='ReLU6')), ))))
algorithm_cfg = dict(
type='SPOS',
architecture=architecture_cfg,
mutator=mutator_cfg,
retraining=False,
)
imgs = torch.randn(16, 3, 224, 224)
label = torch.randint(0, 1000, (16, ))
algorithm_cfg_ = deepcopy(algorithm_cfg)
algorithm = ALGORITHMS.build(algorithm_cfg_)
# test forward
losses = algorithm(imgs, return_loss=True, gt_label=label)
assert losses['loss'].item() > 0
def test_detnas():
config_path = os.path.join(
dirname(dirname(dirname(__file__))),


@ -4,6 +4,40 @@ import torch
from mmrazor.models.builder import OPS
def test_common_ops():
tensor = torch.randn(16, 16, 32, 32)
# test stride != 1
identity_cfg = dict(
type='Identity', in_channels=16, out_channels=16, stride=2)
op = OPS.build(identity_cfg)
# test forward
outputs = op(tensor)
assert outputs.size(1) == 16 and outputs.size(2) == 16
# test stride == 1
identity_cfg = dict(
type='Identity', in_channels=16, out_channels=16, stride=1)
op = OPS.build(identity_cfg)
# test forward
outputs = op(tensor)
assert outputs.size(1) == 16 and outputs.size(2) == 32
# test in_channels != out_channels
identity_cfg = dict(
type='Identity', in_channels=8, out_channels=16, stride=1)
op = OPS.build(identity_cfg)
# test forward
outputs = op(tensor[:, :8])
assert outputs.size(1) == 16 and outputs.size(2) == 32
def test_shuffle_series():
tensor = torch.randn(16, 16, 32, 32)
@ -61,6 +95,51 @@ def test_shuffle_series():
assert outputs.size(1) == 16 and outputs.size(2) == 32
def test_mobilenet_series():
tensor = torch.randn(16, 16, 32, 32)
kernel_sizes = (3, 5, 7)
expand_ratios = (3, 6)
strides = (1, 2)
se_cfg_1 = dict(
ratio=4,
act_cfg=(dict(type='HSwish'),
dict(
type='HSigmoid',
bias=3,
divisor=6,
min_value=0,
max_value=1)))
se_cfgs = (None, se_cfg_1)
drop_path_rates = (0, 0.2)
with_cps = (True, False)
for kernel_size in kernel_sizes:
for expand_ratio in expand_ratios:
for stride in strides:
for se_cfg in se_cfgs:
for drop_path_rate in drop_path_rates:
for with_cp in with_cps:
op_cfg = dict(
type='MBBlock',
in_channels=16,
out_channels=16,
kernel_size=kernel_size,
expand_ratio=expand_ratio,
se_cfg=se_cfg,
drop_path_rate=drop_path_rate,
with_cp=with_cp,
stride=stride)
op = OPS.build(op_cfg)
# test forward
outputs = op(tensor)
assert outputs.size(1) == 16 and outputs.size(
2) == 32 // stride
def test_darts_series():
tensor = torch.randn(16, 16, 32, 32)