Refactor backbone

pull/198/head
qiufeng 2022-06-01 07:12:52 +00:00 committed by pppppM
parent 910b131183
commit 99e7993376
13 changed files with 883 additions and 1 deletion

View File

@@ -67,4 +67,5 @@ repos:
(?x)(
^test
| ^docs
| ^configs
)

View File

@@ -0,0 +1,62 @@
_FIRST_STAGE_MUTABLE = dict(
type='OneShotOP',
candidate_ops=dict(
mb_k3e1=dict(
type='MBBlock',
kernel_size=3,
expand_ratio=1,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6'))))
_OTHER_STAGE_MUTABLE = dict(
type='OneShotOP',
candidate_ops=dict(
mb_k3e3=dict(
type='MBBlock',
kernel_size=3,
expand_ratio=3,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6')),
mb_k5e3=dict(
type='MBBlock',
kernel_size=5,
expand_ratio=3,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6')),
mb_k7e3=dict(
type='MBBlock',
kernel_size=7,
expand_ratio=3,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6')),
mb_k3e6=dict(
type='MBBlock',
kernel_size=3,
expand_ratio=6,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6')),
mb_k5e6=dict(
type='MBBlock',
kernel_size=5,
expand_ratio=6,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6')),
mb_k7e6=dict(
type='MBBlock',
kernel_size=7,
expand_ratio=6,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6')),
identity=dict(type='Identity')))
arch_setting = [
# Parameters to build layers. 4 parameters are needed to construct a
# layer, from left to right: channel, num_blocks, stride, mutable cfg.
[16, 1, 1, _FIRST_STAGE_MUTABLE],
[24, 2, 2, _OTHER_STAGE_MUTABLE],
[32, 3, 2, _OTHER_STAGE_MUTABLE],
[64, 4, 2, _OTHER_STAGE_MUTABLE],
[96, 3, 1, _OTHER_STAGE_MUTABLE],
[160, 3, 2, _OTHER_STAGE_MUTABLE],
[320, 1, 1, _OTHER_STAGE_MUTABLE]
]
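# Illustrative sketch (not part of this config file): ``arch_setting`` is meant
# to be consumed by the ``SearchableMobileNet`` backbone added in this commit.
# The wrapping keys below are only an assumed usage pattern, mirroring the
# unit-test ``BACKBONE_CFG``:
#
#     nas_backbone = dict(
#         type='mmrazor.SearchableMobileNet',
#         first_channels=32,
#         last_channels=1280,
#         widen_factor=1.0,
#         norm_cfg=dict(type='BN'),
#         arch_setting=arch_setting)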

View File

@@ -0,0 +1,62 @@
_FIRST_STAGE_MUTABLE = dict(
type='OneShotOP',
candidate_ops=dict(
mb_k3e1=dict(
type='MBBlock',
kernel_size=3,
expand_ratio=1,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6'))))
_OTHER_STAGE_MUTABLE = dict(
type='OneShotOP',
candidate_ops=dict(
mb_k3e3=dict(
type='MBBlock',
kernel_size=3,
expand_ratio=3,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6')),
mb_k5e3=dict(
type='MBBlock',
kernel_size=5,
expand_ratio=3,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6')),
mb_k7e3=dict(
type='MBBlock',
kernel_size=7,
expand_ratio=3,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6')),
mb_k3e6=dict(
type='MBBlock',
kernel_size=3,
expand_ratio=6,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6')),
mb_k5e6=dict(
type='MBBlock',
kernel_size=5,
expand_ratio=6,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6')),
mb_k7e6=dict(
type='MBBlock',
kernel_size=7,
expand_ratio=6,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6')),
identity=dict(type='Identity')))
arch_setting = [
# Parameters to build layers. 4 parameters are needed to construct a
# layer, from left to right: channel, num_blocks, stride, mutable cfg.
[24, 1, 1, _FIRST_STAGE_MUTABLE],
[32, 4, 2, _OTHER_STAGE_MUTABLE],
[56, 4, 2, _OTHER_STAGE_MUTABLE],
[112, 4, 2, _OTHER_STAGE_MUTABLE],
[128, 4, 1, _OTHER_STAGE_MUTABLE],
[256, 4, 2, _OTHER_STAGE_MUTABLE],
[432, 1, 1, _OTHER_STAGE_MUTABLE]
]
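# Rough size of this search space (a sketch, assuming one mutable per block):
# the num_blocks column sums to 1 + 4 + 4 + 4 + 4 + 4 + 1 = 22 searchable
# layers; the first layer has a single candidate and each of the remaining 21
# layers has 7 candidates (6 MBBlocks + identity), i.e. on the order of 7**21
# subnet paths.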

View File

@@ -0,0 +1,21 @@
_STAGE_MUTABLE = dict(
type='OneShotOP',
candidate_ops=dict(
shuffle_3x3=dict(
type='ShuffleBlock', kernel_size=3, norm_cfg=dict(type='BN')),
shuffle_5x5=dict(
type='ShuffleBlock', kernel_size=5, norm_cfg=dict(type='BN')),
shuffle_7x7=dict(
type='ShuffleBlock', kernel_size=7, norm_cfg=dict(type='BN')),
shuffle_xception=dict(
type='ShuffleXception', norm_cfg=dict(type='BN')),
))
arch_setting = [
# Parameters to build layers. 3 parameters are needed to construct a
# layer, from left to right: channel, num_blocks, mutable_cfg.
[64, 4, _STAGE_MUTABLE],
[160, 4, _STAGE_MUTABLE],
[320, 8, _STAGE_MUTABLE],
[640, 4, _STAGE_MUTABLE],
]
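# Illustrative sketch (not part of this config file): this defines
# 4 + 4 + 8 + 4 = 20 searchable blocks, each choosing among 4 candidate ops
# (roughly 4**20 subnet paths). An assumed usage pattern, mirroring the unit
# tests:
#
#     nas_backbone = dict(
#         type='mmrazor.SearchableShuffleNetV2',
#         widen_factor=1.0,
#         norm_cfg=dict(type='BN'),
#         arch_setting=arch_setting)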

View File

@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backbones import * # noqa: F401,F403
from .components import * # noqa: F401,F403
from .mmcls import MMClsArchitecture
from .mmdet import MMDetArchitecture

View File

@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .searchable_mobilenet import SearchableMobileNet
from .searchable_shufflenet_v2 import SearchableShuffleNetV2
__all__ = ['SearchableMobileNet', 'SearchableShuffleNetV2']

View File

@@ -0,0 +1,222 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional, Sequence, Tuple, Union
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcls.models.utils import make_divisible
from mmcv.cnn import ConvModule
from mmcv.runner import Sequential
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm
from mmrazor.registry import MODELS
@MODELS.register_module()
class SearchableMobileNet(BaseBackbone):
"""Searchable MobileNet backbone.
Args:
arch_setting (list[list]): Architecture settings.
first_channels (int): Channel width of first ConvModule. Default: 32.
last_channels (int): Channel width of last ConvModule. Default: 1280.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Default: 1.0.
out_indices (Sequence[int]): Output from which stages.
Default: (7, ).
frozen_stages (int): Stages to be frozen (all param fixed).
Default: -1, which means not freezing any parameters.
conv_cfg (dict, optional): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU6').
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
init_cfg (dict | list[dict], optional): initialization configuration
dict to define initializer. OpenMMLab has implemented
6 initializers, including ``Constant``, ``Xavier``, ``Normal``,
``Uniform``, ``Kaiming``, and ``Pretrained``.
Examples:
>>> mutable_cfg = dict(
... type='OneShotOP',
... candidate_ops=dict(
... mb_k3e1=dict(
... type='MBBlock',
... kernel_size=3,
... expand_ratio=1,
... norm_cfg=dict(type='BN'),
... act_cfg=dict(type='ReLU6'))))
>>> arch_setting = [
... # Parameters to build layers. 4 parameters are needed to
... # construct a layer, from left to right:
... # channel, num_blocks, stride, mutable cfg.
... [16, 1, 1, mutable_cfg],
... [24, 2, 2, mutable_cfg],
... [32, 3, 2, mutable_cfg],
... [64, 4, 2, mutable_cfg],
... [96, 3, 1, mutable_cfg],
... [160, 3, 2, mutable_cfg],
... [320, 1, 1, mutable_cfg]
... ]
>>> model = SearchableMobileNet(arch_setting=arch_setting)
"""
def __init__(
self,
arch_setting: List[List],
first_channels: int = 32,
last_channels: int = 1280,
widen_factor: float = 1.,
out_indices: Sequence[int] = (7, ),
frozen_stages: int = -1,
conv_cfg: Optional[Dict] = None,
norm_cfg: Dict = dict(type='BN'),
act_cfg: Dict = dict(type='ReLU6'),
norm_eval: bool = False,
with_cp: bool = False,
init_cfg: Optional[Union[Dict, List[Dict]]] = [
dict(type='Kaiming', layer=['Conv2d']),
dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
]
) -> None:
for index in out_indices:
if index not in range(0, 8):
raise ValueError('the item in out_indices must be in '
f'range(0, 8). But received {index}')
if frozen_stages not in range(-1, 8):
raise ValueError('frozen_stages must be in range(-1, 8). '
f'But received {frozen_stages}')
super().__init__(init_cfg)
self.arch_setting = arch_setting
self.widen_factor = widen_factor
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
self.in_channels = make_divisible(first_channels * widen_factor, 8)
self.conv1 = ConvModule(
in_channels=3,
out_channels=self.in_channels,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.layers = []
for i, layer_cfg in enumerate(arch_setting):
channel, num_blocks, stride, mutable_cfg = layer_cfg
out_channels = make_divisible(channel * widen_factor, 8)
inverted_res_layer = self._make_layer(
out_channels=out_channels,
num_blocks=num_blocks,
stride=stride,
mutable_cfg=copy.deepcopy(mutable_cfg))
layer_name = f'layer{i + 1}'
self.add_module(layer_name, inverted_res_layer)
self.layers.append(layer_name)
if widen_factor > 1.0:
self.out_channel = int(last_channels * widen_factor)
else:
self.out_channel = last_channels
layer = ConvModule(
in_channels=self.in_channels,
out_channels=self.out_channel,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.add_module('conv2', layer)
self.layers.append('conv2')
def _make_layer(self, out_channels: int, num_blocks: int, stride: int,
mutable_cfg: Dict) -> Sequential:
"""Stack mutable blocks to build a layer for SearchableMobileNet.
Note:
Here we use ``module_kwargs`` to pass dynamic parameters such as
``in_channels``, ``out_channels`` and ``stride``
to build the mutable.
Args:
out_channels (int): out_channels of block.
num_blocks (int): number of blocks.
stride (int): stride of the first block.
mutable_cfg (dict): Config of mutable.
Returns:
mmcv.runner.Sequential: The layer made.
"""
layers = []
for i in range(num_blocks):
if i >= 1:
stride = 1
mutable_cfg.update(
module_kwargs=dict(
in_channels=self.in_channels,
out_channels=out_channels,
stride=stride))
layers.append(MODELS.build(mutable_cfg))
self.in_channels = out_channels
return Sequential(*layers)
def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
"""Forward computation.
Args:
x (Tensor): Input tensor for forward computation.
"""
x = self.conv1(x)
outs = []
for i, layer_name in enumerate(self.layers):
layer = getattr(self, layer_name)
x = layer(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
def _freeze_stages(self) -> None:
"""Freeze params not to update in the specified stages."""
if self.frozen_stages >= 0:
for param in self.conv1.parameters():
param.requires_grad = False
for i in range(1, self.frozen_stages + 1):
layer = getattr(self, f'layer{i}')
layer.eval()
for param in layer.parameters():
param.requires_grad = False
def train(self, mode: bool = True) -> None:
"""Set module status before forward computation."""
super().train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()
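# Usage sketch (illustrative only, not part of this module). With a registered
# one-shot mutable such as ``OneShotOP`` and the MBBlock-based ``arch_setting``
# from the configs above, the default ``out_indices=(7, )`` returns a single
# feature map from the final ``conv2`` layer; for a 224x224 input that is
# expected to have shape (N, 1280, 7, 7), assuming each MBBlock only changes
# the spatial size through its stride:
#
#     import torch
#     from mmrazor.registry import MODELS
#
#     backbone = MODELS.build(
#         dict(type='mmrazor.SearchableMobileNet', arch_setting=arch_setting))
#     feats = backbone(torch.rand(1, 3, 224, 224))
#     assert len(feats) == 1  # only out_indices=(7, ) by default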

View File

@@ -0,0 +1,219 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch.nn as nn
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcv.cnn import ConvModule, constant_init, normal_init
from mmcv.runner import ModuleList, Sequential
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm
from mmrazor.registry import MODELS
@MODELS.register_module()
class SearchableShuffleNetV2(BaseBackbone):
"""Based on ShuffleNetV2 backbone.
Args:
arch_setting (list[list]): Architecture settings.
stem_multiplier (int): Stem multiplier - adjusts the number of
channels in the first layer. Default: 1.
widen_factor (float): Width multiplier - adjusts the number of
channels in each layer by this amount. Default: 1.0.
out_indices (Sequence[int]): Output from which stages.
Default: (4, ).
frozen_stages (int): Stages to be frozen (all param fixed).
Default: -1, which means not freezing any parameters.
with_last_layer (bool): Whether to add the final 1x1 ConvModule as the
last layer. Default: True, which means there is no need to add a
``Placeholder``.
conv_cfg (dict, optional): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
init_cfg (dict | list[dict], optional): initialization configuration
dict to define initializer. OpenMMLab has implemented
6 initializers, including ``Constant``, ``Xavier``, ``Normal``,
``Uniform``, ``Kaiming``, and ``Pretrained``.
Examples:
>>> mutable_cfg = dict(
... type='OneShotOP',
... candidate_ops=dict(
... shuffle_3x3=dict(
... type='ShuffleBlock',
... kernel_size=3,
... norm_cfg=dict(type='BN'))))
>>> arch_setting = [
... # Parameters to build layers. 3 parameters are needed to
... # construct a layer, from left to right:
... # channel, num_blocks, mutable cfg.
... [64, 4, mutable_cfg],
... [160, 4, mutable_cfg],
... [320, 8, mutable_cfg],
... [640, 4, mutable_cfg]
... ]
>>> model = SearchableShuffleNetV2(arch_setting=arch_setting)
"""
def __init__(self,
arch_setting: List[List],
stem_multiplier: int = 1,
widen_factor: float = 1.0,
out_indices: Sequence[int] = (4, ),
frozen_stages: int = -1,
with_last_layer: bool = True,
conv_cfg: Optional[Dict] = None,
norm_cfg: Dict = dict(type='BN'),
act_cfg: Dict = dict(type='ReLU'),
norm_eval: bool = False,
with_cp: bool = False,
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
layers_nums = 5 if with_last_layer else 4
for index in out_indices:
if index not in range(0, layers_nums):
raise ValueError('the item in out_indices must be in '
f'range(0, {layers_nums}). But received {index}')
self.frozen_stages = frozen_stages
if frozen_stages not in range(-1, layers_nums):
raise ValueError(f'frozen_stages must be in range(-1, {layers_nums}). '
f'But received {frozen_stages}')
super().__init__(init_cfg)
self.arch_setting = arch_setting
self.widen_factor = widen_factor
self.out_indices = out_indices
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
last_channels = 1024
self.in_channels = 16 * stem_multiplier
self.conv1 = ConvModule(
in_channels=3,
out_channels=self.in_channels,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.layers = ModuleList()
for channel, num_blocks, mutable_cfg in arch_setting:
out_channels = round(channel * widen_factor)
layer = self._make_layer(out_channels, num_blocks,
copy.deepcopy(mutable_cfg))
self.layers.append(layer)
if with_last_layer:
self.layers.append(
ConvModule(
in_channels=self.in_channels,
out_channels=last_channels,
kernel_size=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
def _make_layer(self, out_channels: int, num_blocks: int,
mutable_cfg: Dict) -> Sequential:
"""Stack mutable blocks to build a layer for ShuffleNet V2.
Note:
Here we use ``module_kwargs`` to pass dynamic parameters such as
``in_channels``, ``out_channels`` and ``stride``
to build the mutable.
Args:
out_channels (int): out_channels of the block.
num_blocks (int): number of blocks.
mutable_cfg (dict): Config of mutable.
Returns:
mmcv.runner.Sequential: The layer made.
"""
layers = []
for i in range(num_blocks):
stride = 2 if i == 0 else 1
mutable_cfg.update(
module_kwargs=dict(
in_channels=self.in_channels,
out_channels=out_channels,
stride=stride))
layers.append(MODELS.build(mutable_cfg))
self.in_channels = out_channels
return Sequential(*layers)
def _freeze_stages(self) -> None:
"""Freeze params not to update in the specified stages."""
if self.frozen_stages >= 0:
for param in self.conv1.parameters():
param.requires_grad = False
for i in range(self.frozen_stages):
m = self.layers[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
def init_weights(self) -> None:
"""Init weights of ``SearchableShuffleNetV2``."""
super().init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress default init if use pretrained model.
return
for name, m in self.named_modules():
if isinstance(m, nn.Conv2d):
if 'conv1' in name:
normal_init(m, mean=0, std=0.01)
else:
normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, val=1, bias=0.0001)
if isinstance(m, _BatchNorm):
if m.running_mean is not None:
nn.init.constant_(m.running_mean, 0)
def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
"""Forward computation.
Args:
x (Tensor): Input tensor for forward computation.
"""
x = self.conv1(x)
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
def train(self, mode: bool = True) -> None:
"""Set module status before forward computation."""
super().train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()
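# Usage sketch (illustrative only, not part of this module). With the
# ShuffleBlock-based ``arch_setting`` from the config above and the default
# ``out_indices=(4, )``, the single returned feature map comes from the last
# 1x1 ConvModule (1024 channels); the spatial size is halved by conv1 and by
# the first block of each of the 4 stages, so a 224x224 input is expected to
# give shape (N, 1024, 7, 7):
#
#     import torch
#     from mmrazor.registry import MODELS
#
#     backbone = MODELS.build(
#         dict(type='mmrazor.SearchableShuffleNetV2',
#              arch_setting=arch_setting))
#     feats = backbone(torch.rand(1, 3, 224, 224))
#     assert len(feats) == 1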

View File

@@ -1,4 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backbones import * # noqa: F401,F403
from .heads import * # noqa: F401,F403
from .necks import * # noqa: F401,F403

View File

@@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.

View File

@@ -0,0 +1,113 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import pytest
import torch
from mmcls.models import * # noqa: F401,F403
from torch.nn.modules.batchnorm import _BatchNorm
from mmrazor.models import * # noqa: F401,F403
from mmrazor.models.mutables import * # noqa: F401,F403
from mmrazor.registry import MODELS
from .utils import MockMutable
_FIRST_STAGE_MUTABLE = dict(type='MockMutable', choices=['c1'])
_OTHER_STAGE_MUTABLE = dict(
type='MockMutable', choices=['c1', 'c2', 'c3', 'c4'])
ARCHSETTING_CFG = [
# Parameters to build layers. 4 parameters are needed to construct a
# layer, from left to right: channel, num_blocks, stride, mutable cfg.
[16, 1, 1, _FIRST_STAGE_MUTABLE],
[24, 2, 2, _OTHER_STAGE_MUTABLE],
[32, 3, 2, _OTHER_STAGE_MUTABLE],
[64, 4, 2, _OTHER_STAGE_MUTABLE],
[96, 3, 1, _OTHER_STAGE_MUTABLE],
[160, 3, 2, _OTHER_STAGE_MUTABLE],
[320, 1, 1, _OTHER_STAGE_MUTABLE]
]
NORM_CFG = dict(type='BN')
BACKBONE_CFG = dict(
type='mmrazor.SearchableMobileNet',
first_channels=32,
last_channels=1280,
widen_factor=1.0,
norm_cfg=NORM_CFG,
arch_setting=ARCHSETTING_CFG)
def test_searchable_mobilenet_mutable() -> None:
backbone = MODELS.build(BACKBONE_CFG)
choices = ['c1', 'c2', 'c3', 'c4']
mutable_nums = 0
for name, module in backbone.named_modules():
if isinstance(module, MockMutable):
if 'layer1' in name:
assert module.choices == ['c1']
else:
assert module.choices == choices
mutable_nums += 1
arch_setting = backbone.arch_setting
target_mutable_nums = 0
for layer_cfg in arch_setting:
target_mutable_nums += layer_cfg[1]
assert mutable_nums == target_mutable_nums
def test_searchable_mobilenet_train() -> None:
backbone = MODELS.build(BACKBONE_CFG)
backbone.train(mode=True)
for m in backbone.modules():
assert m.training
backbone.norm_eval = True
backbone.train(mode=True)
for m in backbone.modules():
if isinstance(m, _BatchNorm):
assert not m.training
else:
assert m.training
x = torch.rand(10, 3, 224, 224)
assert len(backbone(x)) == 1
backbone_cfg = copy.deepcopy(BACKBONE_CFG)
backbone_cfg['frozen_stages'] = 5
backbone = MODELS.build(backbone_cfg)
backbone.train()
for param in backbone.conv1.parameters():
assert not param.requires_grad
for i in range(1, 8):
layer = getattr(backbone, f'layer{i}')
for m in layer.modules():
if i <= 5:
assert not m.training
else:
assert m.training
for param in layer.parameters():
if i <= 5:
assert not param.requires_grad
else:
assert param.requires_grad
def test_searchable_mobilenet_init() -> None:
backbone_cfg = copy.deepcopy(BACKBONE_CFG)
backbone_cfg['out_indices'] = (10, )
with pytest.raises(ValueError):
MODELS.build(backbone_cfg)
backbone_cfg = copy.deepcopy(BACKBONE_CFG)
backbone_cfg['frozen_stages'] = 8
with pytest.raises(ValueError):
MODELS.build(backbone_cfg)
backbone_cfg = copy.deepcopy(BACKBONE_CFG)
backbone_cfg['widen_factor'] = 1.5
backbone = MODELS.build(backbone_cfg)
assert backbone.out_channel == 1920

View File

@@ -0,0 +1,155 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import tempfile
import pytest
import torch
from mmcls.models import * # noqa: F401,F403
from torch.nn import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmrazor.models import * # noqa: F401,F403
from mmrazor.models.mutables import * # noqa: F401,F403
from mmrazor.registry import MODELS
from .utils import MockMutable
STAGE_MUTABLE = dict(type='MockMutable', choices=['c1', 'c2', 'c3', 'c4'])
ARCHSETTING_CFG = [
# Parameters to build layers. 3 parameters are needed to construct a
# layer, from left to right: channel, num_blocks, mutable_cfg.
[64, 4, STAGE_MUTABLE],
[160, 4, STAGE_MUTABLE],
[320, 8, STAGE_MUTABLE],
[640, 4, STAGE_MUTABLE],
]
NORM_CFG = dict(type='BN')
BACKBONE_CFG = dict(
type='mmrazor.SearchableShuffleNetV2',
widen_factor=1.0,
norm_cfg=NORM_CFG,
arch_setting=ARCHSETTING_CFG)
def test_searchable_shufflenet_v2_mutable() -> None:
backbone = MODELS.build(BACKBONE_CFG)
choices = ['c1', 'c2', 'c3', 'c4']
mutable_nums = 0
for module in backbone.modules():
if isinstance(module, MockMutable):
assert module.choices == choices
mutable_nums += 1
arch_setting = backbone.arch_setting
target_mutable_nums = 0
for layer_cfg in arch_setting:
target_mutable_nums += layer_cfg[1]
assert mutable_nums == target_mutable_nums
def test_searchable_shufflenet_v2_train() -> None:
backbone = MODELS.build(BACKBONE_CFG)
backbone.train(mode=True)
for m in backbone.modules():
assert m.training
backbone.norm_eval = True
backbone.train(mode=True)
for m in backbone.modules():
if isinstance(m, _BatchNorm):
assert not m.training
else:
assert m.training
x = torch.rand(10, 3, 224, 224)
assert len(backbone(x)) == 1
backbone_cfg = copy.deepcopy(BACKBONE_CFG)
backbone_cfg['frozen_stages'] = 2
backbone = MODELS.build(backbone_cfg)
backbone.train()
for param in backbone.conv1.parameters():
assert not param.requires_grad
for i in range(2):
layer = backbone.layers[i]
for m in layer.modules():
if i < 2:
assert not m.training
else:
assert m.training
for param in layer.parameters():
if i < 2:
assert not param.requires_grad
else:
assert param.requires_grad
def test_searchable_shufflenet_v2_init() -> None:
backbone_cfg = copy.deepcopy(BACKBONE_CFG)
backbone_cfg['out_indices'] = (5, )
with pytest.raises(ValueError):
MODELS.build(backbone_cfg)
backbone_cfg = copy.deepcopy(BACKBONE_CFG)
backbone_cfg['frozen_stages'] = 5
with pytest.raises(ValueError):
MODELS.build(backbone_cfg)
backbone_cfg = copy.deepcopy(BACKBONE_CFG)
backbone_cfg['with_last_layer'] = False
with pytest.raises(ValueError):
MODELS.build(backbone_cfg)
backbone_cfg['out_indices'] = (3, )
backbone = MODELS.build(backbone_cfg)
assert len(backbone.layers) == 4
def test_searchable_shufflenet_v2_init_weights() -> None:
backbone = MODELS.build(BACKBONE_CFG)
backbone.init_weights()
for m in backbone.modules():
if isinstance(m, (_BatchNorm, GroupNorm)):
if hasattr(m, 'weight') and m.weight is not None:
assert torch.equal(m.weight, torch.ones_like(m.weight))
if hasattr(m, 'bias') and m.bias is not None:
bias_tensor = torch.ones_like(m.bias)
bias_tensor *= 0.0001
assert torch.equal(bias_tensor, m.bias)
checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')
backbone_cfg = copy.deepcopy(BACKBONE_CFG)
backbone_cfg['init_cfg'] = dict(
type='Pretrained', checkpoint=checkpoint_path)
backbone = MODELS.build(backbone_cfg)
torch.save(backbone.state_dict(), checkpoint_path)
name2weight = dict()
for name, m in backbone.named_modules():
if isinstance(m, (_BatchNorm, GroupNorm)):
if hasattr(m, 'weight') and m.weight is not None:
name2weight[name] = m.weight.clone()
backbone.init_weights()
for name, m in backbone.named_modules():
if isinstance(m, (_BatchNorm, GroupNorm)):
if hasattr(m, 'weight') and m.weight is not None:
if name in name2weight:
assert torch.equal(name2weight[name], m.weight)
backbone_cfg = copy.deepcopy(BACKBONE_CFG)
backbone_cfg['norm_cfg'] = dict(type='BN', track_running_stats=False)
backbone = MODELS.build(backbone_cfg)
backbone.init_weights()
backbone_cfg = copy.deepcopy(BACKBONE_CFG)
backbone_cfg['norm_cfg'] = dict(type='GN', num_groups=1)
backbone = MODELS.build(backbone_cfg)
backbone.init_weights()

View File

@@ -0,0 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List
from torch import Tensor
from torch.nn import Conv2d, Module
from mmrazor.registry import MODELS
@MODELS.register_module()
class MockMutable(Module):
def __init__(self, choices: List[str], module_kwargs: Dict) -> None:
super().__init__()
self.choices = choices
self.module_kwargs = module_kwargs
self.conv = Conv2d(**module_kwargs, kernel_size=3, padding=3 // 2)
def forward(self, x: Tensor) -> Tensor:
return self.conv(x)
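# Usage sketch (illustrative only): the searchable backbones call
# ``mutable_cfg.update(module_kwargs=...)`` before ``MODELS.build``, so a
# MockMutable ends up being constructed roughly like this:
#
#     cfg = dict(
#         type='MockMutable',
#         choices=['c1', 'c2'],
#         module_kwargs=dict(in_channels=16, out_channels=24, stride=2))
#     mutable = MODELS.build(cfg)
#     # mutable.conv is then Conv2d(16, 24, kernel_size=3, stride=2, padding=1)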