Refactor backbone
parent 910b131183
commit 99e7993376
@@ -67,4 +67,5 @@ repos:
        (?x)(
            ^test
            | ^docs
            | ^configs
        )
@@ -0,0 +1,62 @@
_FIRST_STAGE_MUTABLE = dict(
    type='OneShotOP',
    candidate_ops=dict(
        mb_k3e1=dict(
            type='MBBlock',
            kernel_size=3,
            expand_ratio=1,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU6'))))

_OTHER_STAGE_MUTABLE = dict(
    type='OneShotOP',
    candidate_ops=dict(
        mb_k3e3=dict(
            type='MBBlock',
            kernel_size=3,
            expand_ratio=3,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU6')),
        mb_k5e3=dict(
            type='MBBlock',
            kernel_size=5,
            expand_ratio=3,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU6')),
        mb_k7e3=dict(
            type='MBBlock',
            kernel_size=7,
            expand_ratio=3,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU6')),
        mb_k3e6=dict(
            type='MBBlock',
            kernel_size=3,
            expand_ratio=6,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU6')),
        mb_k5e6=dict(
            type='MBBlock',
            kernel_size=5,
            expand_ratio=6,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU6')),
        mb_k7e6=dict(
            type='MBBlock',
            kernel_size=7,
            expand_ratio=6,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU6')),
        identity=dict(type='Identity')))

arch_setting = [
    # Parameters to build layers. 4 parameters are needed to construct a
    # layer, from left to right: channel, num_blocks, stride, mutable cfg.
    [16, 1, 1, _FIRST_STAGE_MUTABLE],
    [24, 2, 2, _OTHER_STAGE_MUTABLE],
    [32, 3, 2, _OTHER_STAGE_MUTABLE],
    [64, 4, 2, _OTHER_STAGE_MUTABLE],
    [96, 3, 1, _OTHER_STAGE_MUTABLE],
    [160, 3, 2, _OTHER_STAGE_MUTABLE],
    [320, 1, 1, _OTHER_STAGE_MUTABLE]
]
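For context, a minimal sketch of how an ``arch_setting`` like the one above is consumed. It mirrors the unit tests added later in this commit; it is not part of the diff, `supernet_cfg` is an illustrative name, and the 'OneShotOP'/'MBBlock' candidate ops must already be registered for the build to succeed.

# Illustrative sketch: build the supernet backbone from the arch_setting
# defined above and run a dummy forward pass.
import torch

from mmrazor.registry import MODELS

supernet_cfg = dict(
    type='mmrazor.SearchableMobileNet',
    arch_setting=arch_setting,  # the list defined in this config
    norm_cfg=dict(type='BN'),
    out_indices=(7, ))  # index 7 is the final 1x1 conv ('conv2')

backbone = MODELS.build(supernet_cfg)
feats = backbone(torch.rand(1, 3, 224, 224))
print(len(feats))  # 1 feature map, taken after the last layer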
@@ -0,0 +1,62 @@
_FIRST_STAGE_MUTABLE = dict(
    type='OneShotOP',
    candidate_ops=dict(
        mb_k3e1=dict(
            type='MBBlock',
            kernel_size=3,
            expand_ratio=1,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU6'))))

_OTHER_STAGE_MUTABLE = dict(
    type='OneShotOP',
    candidate_ops=dict(
        mb_k3e3=dict(
            type='MBBlock',
            kernel_size=3,
            expand_ratio=3,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU6')),
        mb_k5e3=dict(
            type='MBBlock',
            kernel_size=5,
            expand_ratio=3,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU6')),
        mb_k7e3=dict(
            type='MBBlock',
            kernel_size=7,
            expand_ratio=3,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU6')),
        mb_k3e6=dict(
            type='MBBlock',
            kernel_size=3,
            expand_ratio=6,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU6')),
        mb_k5e6=dict(
            type='MBBlock',
            kernel_size=5,
            expand_ratio=6,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU6')),
        mb_k7e6=dict(
            type='MBBlock',
            kernel_size=7,
            expand_ratio=6,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU6')),
        identity=dict(type='Identity')))

arch_setting = [
    # Parameters to build layers. 4 parameters are needed to construct a
    # layer, from left to right: channel, num_blocks, stride, mutable cfg.
    [24, 1, 1, _FIRST_STAGE_MUTABLE],
    [32, 4, 2, _OTHER_STAGE_MUTABLE],
    [56, 4, 2, _OTHER_STAGE_MUTABLE],
    [112, 4, 2, _OTHER_STAGE_MUTABLE],
    [128, 4, 1, _OTHER_STAGE_MUTABLE],
    [256, 4, 2, _OTHER_STAGE_MUTABLE],
    [432, 1, 1, _OTHER_STAGE_MUTABLE]
]
@@ -0,0 +1,21 @@
_STAGE_MUTABLE = dict(
    type='OneShotOP',
    candidate_ops=dict(
        shuffle_3x3=dict(
            type='ShuffleBlock', kernel_size=3, norm_cfg=dict(type='BN')),
        shuffle_5x5=dict(
            type='ShuffleBlock', kernel_size=5, norm_cfg=dict(type='BN')),
        shuffle_7x7=dict(
            type='ShuffleBlock', kernel_size=7, norm_cfg=dict(type='BN')),
        shuffle_xception=dict(
            type='ShuffleXception', norm_cfg=dict(type='BN')),
    ))

arch_setting = [
    # Parameters to build layers. 3 parameters are needed to construct a
    # layer, from left to right: channel, num_blocks, mutable_cfg.
    [64, 4, _STAGE_MUTABLE],
    [160, 4, _STAGE_MUTABLE],
    [320, 8, _STAGE_MUTABLE],
    [640, 4, _STAGE_MUTABLE],
]
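A similar sketch for the ShuffleNetV2 supernet config above, again following the pattern of the tests added in this commit. It is illustrative only; the 'ShuffleBlock'/'ShuffleXception' candidate ops must be registered for ``MODELS.build`` to resolve them.

# Illustrative sketch: the four rows above become four searchable stages,
# plus a final 1x1 ConvModule when with_last_layer=True (the default).
import torch

from mmrazor.registry import MODELS

backbone = MODELS.build(
    dict(
        type='mmrazor.SearchableShuffleNetV2',
        arch_setting=arch_setting,
        norm_cfg=dict(type='BN')))
print(len(backbone.layers))  # 5: four stages + the last conv layer
feats = backbone(torch.rand(1, 3, 224, 224))
print(feats[0].shape)  # output at the default out_indices=(4, )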
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backbones import *  # noqa: F401,F403
from .components import *  # noqa: F401,F403
from .mmcls import MMClsArchitecture
from .mmdet import MMDetArchitecture
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .searchable_mobilenet import SearchableMobileNet
from .searchable_shufflenet_v2 import SearchableShuffleNetV2

__all__ = ['SearchableMobileNet', 'SearchableShuffleNetV2']
@@ -0,0 +1,222 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional, Sequence, Tuple, Union

from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcls.models.utils import make_divisible
from mmcv.cnn import ConvModule
from mmcv.runner import Sequential
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm

from mmrazor.registry import MODELS


@MODELS.register_module()
class SearchableMobileNet(BaseBackbone):
    """Searchable MobileNet backbone.

    Args:
        arch_setting (list[list]): Architecture settings.
        first_channels (int): Channel width of first ConvModule. Default: 32.
        last_channels (int): Channel width of last ConvModule. Default: 1280.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Default: 1.0.
        out_indices (Sequence[int]): Output from which stages.
            Default: (7, ).
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: -1, which means not freezing any parameters.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU6').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        init_cfg (dict | list[dict], optional): initialization configuration
            dict to define initializer. OpenMMLab has implemented
            6 initializers, including ``Constant``, ``Xavier``, ``Normal``,
            ``Uniform``, ``Kaiming``, and ``Pretrained``.

    Examples:
        >>> mutable_cfg = dict(
        ...     type='OneShotOP',
        ...     candidate_ops=dict(
        ...         mb_k3e1=dict(
        ...             type='MBBlock',
        ...             kernel_size=3,
        ...             expand_ratio=1,
        ...             norm_cfg=dict(type='BN'),
        ...             act_cfg=dict(type='ReLU6'))))
        >>> arch_setting = [
        ...     # Parameters to build layers. 4 parameters are needed to
        ...     # construct a layer, from left to right:
        ...     # channel, num_blocks, stride, mutable cfg.
        ...     [16, 1, 1, mutable_cfg],
        ...     [24, 2, 2, mutable_cfg],
        ...     [32, 3, 2, mutable_cfg],
        ...     [64, 4, 2, mutable_cfg],
        ...     [96, 3, 1, mutable_cfg],
        ...     [160, 3, 2, mutable_cfg],
        ...     [320, 1, 1, mutable_cfg]
        ... ]
        >>> model = SearchableMobileNet(arch_setting=arch_setting)
    """

    def __init__(
        self,
        arch_setting: List[List],
        first_channels: int = 32,
        last_channels: int = 1280,
        widen_factor: float = 1.,
        out_indices: Sequence[int] = (7, ),
        frozen_stages: int = -1,
        conv_cfg: Optional[Dict] = None,
        norm_cfg: Dict = dict(type='BN'),
        act_cfg: Dict = dict(type='ReLU6'),
        norm_eval: bool = False,
        with_cp: bool = False,
        init_cfg: Optional[Union[Dict, List[Dict]]] = [
            dict(type='Kaiming', layer=['Conv2d']),
            dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
        ]
    ) -> None:
        for index in out_indices:
            if index not in range(0, 8):
                raise ValueError('the item in out_indices must be in '
                                 f'range(0, 8). But received {index}')

        if frozen_stages not in range(-1, 8):
            raise ValueError('frozen_stages must be in range(-1, 8). '
                             f'But received {frozen_stages}')

        super().__init__(init_cfg)

        self.arch_setting = arch_setting
        self.widen_factor = widen_factor
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        self.in_channels = make_divisible(first_channels * widen_factor, 8)

        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.layers = []

        for i, layer_cfg in enumerate(arch_setting):
            channel, num_blocks, stride, mutable_cfg = layer_cfg
            out_channels = make_divisible(channel * widen_factor, 8)
            inverted_res_layer = self._make_layer(
                out_channels=out_channels,
                num_blocks=num_blocks,
                stride=stride,
                mutable_cfg=copy.deepcopy(mutable_cfg))
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, inverted_res_layer)
            self.layers.append(layer_name)

        if widen_factor > 1.0:
            self.out_channel = int(last_channels * widen_factor)
        else:
            self.out_channel = last_channels

        layer = ConvModule(
            in_channels=self.in_channels,
            out_channels=self.out_channel,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.add_module('conv2', layer)
        self.layers.append('conv2')

    def _make_layer(self, out_channels: int, num_blocks: int, stride: int,
                    mutable_cfg: Dict) -> Sequential:
        """Stack mutable blocks to build a layer for SearchableMobileNet.

        Note:
            Here we use ``module_kwargs`` to pass dynamic parameters such as
            ``in_channels``, ``out_channels`` and ``stride``
            to build the mutable.

        Args:
            out_channels (int): out_channels of block.
            num_blocks (int): number of blocks.
            stride (int): stride of the first block.
            mutable_cfg (dict): Config of mutable.

        Returns:
            mmcv.runner.Sequential: The layer made.
        """
        layers = []
        for i in range(num_blocks):
            if i >= 1:
                stride = 1

            mutable_cfg.update(
                module_kwargs=dict(
                    in_channels=self.in_channels,
                    out_channels=out_channels,
                    stride=stride))
            layers.append(MODELS.build(mutable_cfg))

            self.in_channels = out_channels

        return Sequential(*layers)

    def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
        """Forward computation.

        Args:
            x (tensor): x contains input data for forward computation.
        """
        x = self.conv1(x)

        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def _freeze_stages(self) -> None:
        """Freeze params not to update in the specified stages."""
        if self.frozen_stages >= 0:
            for param in self.conv1.parameters():
                param.requires_grad = False
            for i in range(1, self.frozen_stages + 1):
                layer = getattr(self, f'layer{i}')
                layer.eval()
                for param in layer.parameters():
                    param.requires_grad = False

    def train(self, mode: bool = True) -> None:
        """Set module status before forward computation."""
        super().train(mode)

        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
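The channel arithmetic in ``__init__`` above can be checked in isolation. A small sketch, assuming mmcls's ``make_divisible`` rounds a scaled width to the nearest multiple of the divisor (without dropping below roughly 90% of the original value); the 1920 figure matches the unit test added later in this commit.

# Sketch of the width scaling used by SearchableMobileNet.__init__.
from mmcls.models.utils import make_divisible

widen_factor = 1.5
print(make_divisible(32 * widen_factor, 8))   # stem width: 48
print(make_divisible(320 * widen_factor, 8))  # last stage width: 480
# The final 1x1 conv is only widened when widen_factor > 1.0:
print(int(1280 * widen_factor))               # out_channel: 1920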
@@ -0,0 +1,219 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torch.nn as nn
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcv.cnn import ConvModule, constant_init, normal_init
from mmcv.runner import ModuleList, Sequential
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm

from mmrazor.registry import MODELS


@MODELS.register_module()
class SearchableShuffleNetV2(BaseBackbone):
    """Based on ShuffleNetV2 backbone.

    Args:
        arch_setting (list[list]): Architecture settings.
        stem_multiplier (int): Stem multiplier - adjusts the number of
            channels in the first layer. Default: 1.
        widen_factor (float): Width multiplier - adjusts the number of
            channels in each layer by this amount. Default: 1.0.
        out_indices (Sequence[int]): Output from which stages.
            Default: (4, ).
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: -1, which means not freezing any parameters.
        with_last_layer (bool): Whether to add the last conv layer.
            Default: True, which means no ``Placeholder`` needs to be added.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        init_cfg (dict | list[dict], optional): initialization configuration
            dict to define initializer. OpenMMLab has implemented
            6 initializers, including ``Constant``, ``Xavier``, ``Normal``,
            ``Uniform``, ``Kaiming``, and ``Pretrained``.

    Examples:
        >>> mutable_cfg = dict(
        ...     type='OneShotOP',
        ...     candidate_ops=dict(
        ...         shuffle_3x3=dict(
        ...             type='ShuffleBlock',
        ...             kernel_size=3,
        ...             norm_cfg=dict(type='BN'))))
        >>> arch_setting = [
        ...     # Parameters to build layers. 3 parameters are needed to
        ...     # construct a layer, from left to right:
        ...     # channel, num_blocks, mutable cfg.
        ...     [64, 4, mutable_cfg],
        ...     [160, 4, mutable_cfg],
        ...     [320, 8, mutable_cfg],
        ...     [640, 4, mutable_cfg]
        ... ]
        >>> model = SearchableShuffleNetV2(arch_setting=arch_setting)
    """

    def __init__(self,
                 arch_setting: List[List],
                 stem_multiplier: int = 1,
                 widen_factor: float = 1.0,
                 out_indices: Sequence[int] = (4, ),
                 frozen_stages: int = -1,
                 with_last_layer: bool = True,
                 conv_cfg: Optional[Dict] = None,
                 norm_cfg: Dict = dict(type='BN'),
                 act_cfg: Dict = dict(type='ReLU'),
                 norm_eval: bool = False,
                 with_cp: bool = False,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        layers_nums = 5 if with_last_layer else 4
        for index in out_indices:
            if index not in range(0, layers_nums):
                raise ValueError('the item in out_indices must be in '
                                 f'range(0, 5). But received {index}')

        self.frozen_stages = frozen_stages
        if frozen_stages not in range(-1, layers_nums):
            raise ValueError('frozen_stages must be in range(-1, 5). '
                             f'But received {frozen_stages}')

        super().__init__(init_cfg)

        self.arch_setting = arch_setting
        self.widen_factor = widen_factor
        self.out_indices = out_indices
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        last_channels = 1024
        self.in_channels = 16 * stem_multiplier
        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        self.layers = ModuleList()
        for channel, num_blocks, mutable_cfg in arch_setting:
            out_channels = round(channel * widen_factor)
            layer = self._make_layer(out_channels, num_blocks,
                                     copy.deepcopy(mutable_cfg))
            self.layers.append(layer)

        if with_last_layer:
            self.layers.append(
                ConvModule(
                    in_channels=self.in_channels,
                    out_channels=last_channels,
                    kernel_size=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

    def _make_layer(self, out_channels: int, num_blocks: int,
                    mutable_cfg: Dict) -> Sequential:
        """Stack mutable blocks to build a layer for ShuffleNet V2.

        Note:
            Here we use ``module_kwargs`` to pass dynamic parameters such as
            ``in_channels``, ``out_channels`` and ``stride``
            to build the mutable.

        Args:
            out_channels (int): out_channels of the block.
            num_blocks (int): number of blocks.
            mutable_cfg (dict): Config of mutable.

        Returns:
            mmcv.runner.Sequential: The layer made.
        """
        layers = []
        for i in range(num_blocks):
            stride = 2 if i == 0 else 1

            mutable_cfg.update(
                module_kwargs=dict(
                    in_channels=self.in_channels,
                    out_channels=out_channels,
                    stride=stride))
            layers.append(MODELS.build(mutable_cfg))
            self.in_channels = out_channels

        return Sequential(*layers)

    def _freeze_stages(self) -> None:
        """Freeze params not to update in the specified stages."""
        if self.frozen_stages >= 0:
            for param in self.conv1.parameters():
                param.requires_grad = False

            for i in range(self.frozen_stages):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def init_weights(self) -> None:
        """Init weights of ``SearchableShuffleNetV2``."""
        super().init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if use pretrained model.
            return

        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                if 'conv1' in name:
                    normal_init(m, mean=0, std=0.01)
                else:
                    normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, val=1, bias=0.0001)
                if isinstance(m, _BatchNorm):
                    if m.running_mean is not None:
                        nn.init.constant_(m.running_mean, 0)

    def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
        """Forward computation.

        Args:
            x (tensor): x contains input data for forward computation.
        """
        x = self.conv1(x)

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def train(self, mode: bool = True) -> None:
        """Set module status before forward computation."""
        super().train(mode)

        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
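In ``_make_layer`` above, only the first block of each stage downsamples, and ``self.in_channels`` is carried forward between blocks. A registry-free sketch of that expansion for a single arch_setting row such as ``[64, 4, mutable_cfg]``, starting from the default 16-channel stem:

# Plain-Python sketch of the per-stage expansion in _make_layer (no registry,
# no real mutables): 4 blocks, only the first with stride 2, channels chained.
in_channels, out_channels, num_blocks = 16, 64, 4
per_block_kwargs = []
for i in range(num_blocks):
    stride = 2 if i == 0 else 1
    per_block_kwargs.append(
        dict(in_channels=in_channels, out_channels=out_channels,
             stride=stride))
    in_channels = out_channels
print(per_block_kwargs[0])  # {'in_channels': 16, 'out_channels': 64, 'stride': 2}
print(per_block_kwargs[1])  # {'in_channels': 64, 'out_channels': 64, 'stride': 1}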
@@ -1,4 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backbones import *  # noqa: F401,F403
from .heads import *  # noqa: F401,F403
from .necks import *  # noqa: F401,F403
@@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.
@@ -0,0 +1,113 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import pytest
import torch
from mmcls.models import *  # noqa: F401,F403
from torch.nn.modules.batchnorm import _BatchNorm

from mmrazor.models import *  # noqa: F401,F403
from mmrazor.models.mutables import *  # noqa: F401,F403
from mmrazor.registry import MODELS
from .utils import MockMutable

_FIRST_STAGE_MUTABLE = dict(type='MockMutable', choices=['c1'])
_OTHER_STAGE_MUTABLE = dict(
    type='MockMutable', choices=['c1', 'c2', 'c3', 'c4'])
ARCHSETTING_CFG = [
    # Parameters to build layers. 4 parameters are needed to construct a
    # layer, from left to right: channel, num_blocks, stride, mutable cfg.
    [16, 1, 1, _FIRST_STAGE_MUTABLE],
    [24, 2, 2, _OTHER_STAGE_MUTABLE],
    [32, 3, 2, _OTHER_STAGE_MUTABLE],
    [64, 4, 2, _OTHER_STAGE_MUTABLE],
    [96, 3, 1, _OTHER_STAGE_MUTABLE],
    [160, 3, 2, _OTHER_STAGE_MUTABLE],
    [320, 1, 1, _OTHER_STAGE_MUTABLE]
]
NORM_CFG = dict(type='BN')
BACKBONE_CFG = dict(
    type='mmrazor.SearchableMobileNet',
    first_channels=32,
    last_channels=1280,
    widen_factor=1.0,
    norm_cfg=NORM_CFG,
    arch_setting=ARCHSETTING_CFG)


def test_searchable_mobilenet_mutable() -> None:
    backbone = MODELS.build(BACKBONE_CFG)

    choices = ['c1', 'c2', 'c3', 'c4']
    mutable_nums = 0

    for name, module in backbone.named_modules():
        if isinstance(module, MockMutable):
            if 'layer1' in name:
                assert module.choices == ['c1']
            else:
                assert module.choices == choices
            mutable_nums += 1

    arch_setting = backbone.arch_setting
    target_mutable_nums = 0
    for layer_cfg in arch_setting:
        target_mutable_nums += layer_cfg[1]
    assert mutable_nums == target_mutable_nums


def test_searchable_mobilenet_train() -> None:
    backbone = MODELS.build(BACKBONE_CFG)
    backbone.train(mode=True)
    for m in backbone.modules():
        assert m.training

    backbone.norm_eval = True
    backbone.train(mode=True)
    for m in backbone.modules():
        if isinstance(m, _BatchNorm):
            assert not m.training
        else:
            assert m.training

    x = torch.rand(10, 3, 224, 224)
    assert len(backbone(x)) == 1

    backbone_cfg = copy.deepcopy(BACKBONE_CFG)
    backbone_cfg['frozen_stages'] = 5
    backbone = MODELS.build(backbone_cfg)
    backbone.train()

    for param in backbone.conv1.parameters():
        assert not param.requires_grad
    for i in range(1, 8):
        layer = getattr(backbone, f'layer{i}')
        for m in layer.modules():
            if i <= 5:
                assert not m.training
            else:
                assert m.training
        for param in layer.parameters():
            if i <= 5:
                assert not param.requires_grad
            else:
                assert param.requires_grad


def test_searchable_mobilenet_init() -> None:
    backbone_cfg = copy.deepcopy(BACKBONE_CFG)
    backbone_cfg['out_indices'] = (10, )

    with pytest.raises(ValueError):
        MODELS.build(backbone_cfg)

    backbone_cfg = copy.deepcopy(BACKBONE_CFG)
    backbone_cfg['frozen_stages'] = 8

    with pytest.raises(ValueError):
        MODELS.build(backbone_cfg)

    backbone_cfg = copy.deepcopy(BACKBONE_CFG)
    backbone_cfg['widen_factor'] = 1.5
    backbone = MODELS.build(backbone_cfg)
    assert backbone.out_channel == 1920
@@ -0,0 +1,155 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import tempfile

import pytest
import torch
from mmcls.models import *  # noqa: F401,F403
from torch.nn import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm

from mmrazor.models import *  # noqa: F401,F403
from mmrazor.models.mutables import *  # noqa: F401,F403
from mmrazor.registry import MODELS
from .utils import MockMutable

STAGE_MUTABLE = dict(type='MockMutable', choices=['c1', 'c2', 'c3', 'c4'])
ARCHSETTING_CFG = [
    # Parameters to build layers. 3 parameters are needed to construct a
    # layer, from left to right: channel, num_blocks, mutable_cfg.
    [64, 4, STAGE_MUTABLE],
    [160, 4, STAGE_MUTABLE],
    [320, 8, STAGE_MUTABLE],
    [640, 4, STAGE_MUTABLE],
]

NORM_CFG = dict(type='BN')
BACKBONE_CFG = dict(
    type='mmrazor.SearchableShuffleNetV2',
    widen_factor=1.0,
    norm_cfg=NORM_CFG,
    arch_setting=ARCHSETTING_CFG)


def test_searchable_shufflenet_v2_mutable() -> None:
    backbone = MODELS.build(BACKBONE_CFG)

    choices = ['c1', 'c2', 'c3', 'c4']
    mutable_nums = 0

    for module in backbone.modules():
        if isinstance(module, MockMutable):
            assert module.choices == choices
            mutable_nums += 1

    arch_setting = backbone.arch_setting
    target_mutable_nums = 0
    for layer_cfg in arch_setting:
        target_mutable_nums += layer_cfg[1]
    assert mutable_nums == target_mutable_nums


def test_searchable_shufflenet_v2_train() -> None:
    backbone = MODELS.build(BACKBONE_CFG)
    backbone.train(mode=True)
    for m in backbone.modules():
        assert m.training

    backbone.norm_eval = True
    backbone.train(mode=True)
    for m in backbone.modules():
        if isinstance(m, _BatchNorm):
            assert not m.training
        else:
            assert m.training

    x = torch.rand(10, 3, 224, 224)
    assert len(backbone(x)) == 1

    backbone_cfg = copy.deepcopy(BACKBONE_CFG)
    backbone_cfg['frozen_stages'] = 2
    backbone = MODELS.build(backbone_cfg)
    backbone.train()

    for param in backbone.conv1.parameters():
        assert not param.requires_grad
    for i in range(2):
        layer = backbone.layers[i]
        for m in layer.modules():
            if i < 2:
                assert not m.training
            else:
                assert m.training
        for param in layer.parameters():
            if i < 2:
                assert not param.requires_grad
            else:
                assert param.requires_grad


def test_searchable_shufflenet_v2_init() -> None:
    backbone_cfg = copy.deepcopy(BACKBONE_CFG)
    backbone_cfg['out_indices'] = (5, )

    with pytest.raises(ValueError):
        MODELS.build(backbone_cfg)

    backbone_cfg = copy.deepcopy(BACKBONE_CFG)
    backbone_cfg['frozen_stages'] = 5

    with pytest.raises(ValueError):
        MODELS.build(backbone_cfg)

    backbone_cfg = copy.deepcopy(BACKBONE_CFG)
    backbone_cfg['with_last_layer'] = False
    with pytest.raises(ValueError):
        MODELS.build(backbone_cfg)

    backbone_cfg['out_indices'] = (3, )
    backbone = MODELS.build(backbone_cfg)
    assert len(backbone.layers) == 4


def test_searchable_shufflenet_v2_init_weights() -> None:
    backbone = MODELS.build(BACKBONE_CFG)
    backbone.init_weights()

    for m in backbone.modules():
        if isinstance(m, (_BatchNorm, GroupNorm)):
            if hasattr(m, 'weight') and m.weight is not None:
                assert torch.equal(m.weight, torch.ones_like(m.weight))
            if hasattr(m, 'bias') and m.bias is not None:
                bias_tensor = torch.ones_like(m.bias)
                bias_tensor *= 0.0001
                assert torch.equal(bias_tensor, m.bias)

    checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')
    backbone_cfg = copy.deepcopy(BACKBONE_CFG)
    backbone_cfg['init_cfg'] = dict(
        type='Pretrained', checkpoint=checkpoint_path)
    backbone = MODELS.build(backbone_cfg)
    torch.save(backbone.state_dict(), checkpoint_path)

    name2weight = dict()
    for name, m in backbone.named_modules():
        if isinstance(m, (_BatchNorm, GroupNorm)):
            if hasattr(m, 'weight') and m.weight is not None:
                name2weight[name] = m.weight.clone()

    backbone.init_weights()
    for name, m in backbone.named_modules():
        if isinstance(m, (_BatchNorm, GroupNorm)):
            if hasattr(m, 'weight') and m.weight is not None:
                if name in name2weight:
                    assert torch.equal(name2weight[name], m.weight)

    backbone_cfg = copy.deepcopy(BACKBONE_CFG)
    backbone_cfg['norm_cfg'] = dict(type='BN', track_running_stats=False)
    backbone = MODELS.build(backbone_cfg)
    backbone.init_weights()

    backbone_cfg = copy.deepcopy(BACKBONE_CFG)
    backbone_cfg['norm_cfg'] = dict(type='GN', num_groups=1)
    backbone = MODELS.build(backbone_cfg)
    backbone.init_weights()
@@ -0,0 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List

from torch import Tensor
from torch.nn import Conv2d, Module

from mmrazor.registry import MODELS


@MODELS.register_module()
class MockMutable(Module):

    def __init__(self, choices: List[str], module_kwargs: Dict) -> None:
        super().__init__()

        self.choices = choices
        self.module_kwargs = module_kwargs
        self.conv = Conv2d(**module_kwargs, kernel_size=3, padding=3 // 2)

    def forward(self, x: Tensor) -> Tensor:
        return self.conv(x)
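A short sketch of how the backbones exercise this mock in the tests above: ``_make_layer`` injects ``module_kwargs`` into the mutable cfg before building it from the registry, and the mock simply forwards through a single Conv2d. The concrete channel and stride values below are illustrative only.

# Illustrative usage, mirroring what _make_layer does with module_kwargs.
import torch

from mmrazor.registry import MODELS

mutable_cfg = dict(type='MockMutable', choices=['c1', 'c2'])
mutable_cfg.update(
    module_kwargs=dict(in_channels=16, out_channels=24, stride=2))
block = MODELS.build(mutable_cfg)
out = block(torch.rand(1, 16, 32, 32))
print(out.shape)  # torch.Size([1, 24, 16, 16])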