diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a4e208de..fb1f746a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -67,4 +67,5 @@ repos: (?x)( ^test | ^docs + | ^configs ) diff --git a/configs/_base_/models/arch_settings/mobilenet/original.py b/configs/_base_/models/arch_settings/mobilenet/original.py new file mode 100644 index 00000000..bb2a97a5 --- /dev/null +++ b/configs/_base_/models/arch_settings/mobilenet/original.py @@ -0,0 +1,62 @@ +_FIRST_STAGE_MUTABLE = dict( + type='OneShotOP', + candidate_ops=dict( + mb_k3e1=dict( + type='MBBlock', + kernel_size=3, + expand_ratio=1, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6')))) + +_OTHER_STAGE_MUTABLE = dict( + type='OneShotOP', + candidate_ops=dict( + mb_k3e3=dict( + type='MBBlock', + kernel_size=3, + expand_ratio=3, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6')), + mb_k5e3=dict( + type='MBBlock', + kernel_size=5, + expand_ratio=3, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6')), + mb_k7e3=dict( + type='MBBlock', + kernel_size=7, + expand_ratio=3, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6')), + mb_k3e6=dict( + type='MBBlock', + kernel_size=3, + expand_ratio=6, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6')), + mb_k5e6=dict( + type='MBBlock', + kernel_size=5, + expand_ratio=6, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6')), + mb_k7e6=dict( + type='MBBlock', + kernel_size=7, + expand_ratio=6, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6')), + identity=dict(type='Identity'))) + +arch_setting = [ + # Parameters to build layers. 4 parameters are needed to construct a + # layer, from left to right: channel, num_blocks, stride, mutable cfg. + [16, 1, 1, _FIRST_STAGE_MUTABLE], + [24, 2, 2, _OTHER_STAGE_MUTABLE], + [32, 3, 2, _OTHER_STAGE_MUTABLE], + [64, 4, 2, _OTHER_STAGE_MUTABLE], + [96, 3, 1, _OTHER_STAGE_MUTABLE], + [160, 3, 2, _OTHER_STAGE_MUTABLE], + [320, 1, 1, _OTHER_STAGE_MUTABLE] +] diff --git a/configs/_base_/models/arch_settings/mobilenet/proxyless_gpu.py b/configs/_base_/models/arch_settings/mobilenet/proxyless_gpu.py new file mode 100644 index 00000000..25b5b704 --- /dev/null +++ b/configs/_base_/models/arch_settings/mobilenet/proxyless_gpu.py @@ -0,0 +1,62 @@ +_FIRST_STAGE_MUTABLE = dict( + type='OneShotOP', + candidate_ops=dict( + mb_k3e1=dict( + type='MBBlock', + kernel_size=3, + expand_ratio=1, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6')))) + +_OTHER_STAGE_MUTABLE = dict( + type='OneShotOP', + candidate_ops=dict( + mb_k3e3=dict( + type='MBBlock', + kernel_size=3, + expand_ratio=3, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6')), + mb_k5e3=dict( + type='MBBlock', + kernel_size=5, + expand_ratio=3, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6')), + mb_k7e3=dict( + type='MBBlock', + kernel_size=7, + expand_ratio=3, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6')), + mb_k3e6=dict( + type='MBBlock', + kernel_size=3, + expand_ratio=6, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6')), + mb_k5e6=dict( + type='MBBlock', + kernel_size=5, + expand_ratio=6, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6')), + mb_k7e6=dict( + type='MBBlock', + kernel_size=7, + expand_ratio=6, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6')), + identity=dict(type='Identity'))) + +arch_setting = [ + # Parameters to build layers. 4 parameters are needed to construct a + # layer, from left to right: channel, num_blocks, stride, mutable cfg. 
+    [24, 1, 1, _FIRST_STAGE_MUTABLE],
+    [32, 4, 2, _OTHER_STAGE_MUTABLE],
+    [56, 4, 2, _OTHER_STAGE_MUTABLE],
+    [112, 4, 2, _OTHER_STAGE_MUTABLE],
+    [128, 4, 1, _OTHER_STAGE_MUTABLE],
+    [256, 4, 2, _OTHER_STAGE_MUTABLE],
+    [432, 1, 1, _OTHER_STAGE_MUTABLE]
+]
diff --git a/configs/_base_/models/arch_settings/shufflenet_v2/original.py b/configs/_base_/models/arch_settings/shufflenet_v2/original.py
new file mode 100644
index 00000000..2f34525a
--- /dev/null
+++ b/configs/_base_/models/arch_settings/shufflenet_v2/original.py
@@ -0,0 +1,21 @@
+_STAGE_MUTABLE = dict(
+    type='OneShotOP',
+    candidate_ops=dict(
+        shuffle_3x3=dict(
+            type='ShuffleBlock', kernel_size=3, norm_cfg=dict(type='BN')),
+        shuffle_5x5=dict(
+            type='ShuffleBlock', kernel_size=5, norm_cfg=dict(type='BN')),
+        shuffle_7x7=dict(
+            type='ShuffleBlock', kernel_size=7, norm_cfg=dict(type='BN')),
+        shuffle_xception=dict(
+            type='ShuffleXception', norm_cfg=dict(type='BN')),
+    ))
+
+arch_setting = [
+    # Parameters to build layers. 3 parameters are needed to construct a
+    # layer, from left to right: channel, num_blocks, mutable_cfg.
+    [64, 4, _STAGE_MUTABLE],
+    [160, 4, _STAGE_MUTABLE],
+    [320, 8, _STAGE_MUTABLE],
+    [640, 4, _STAGE_MUTABLE],
+]
diff --git a/mmrazor/models/architectures/__init__.py b/mmrazor/models/architectures/__init__.py
index f3630633..89fc51f5 100644
--- a/mmrazor/models/architectures/__init__.py
+++ b/mmrazor/models/architectures/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .backbones import *  # noqa: F401,F403
 from .components import *  # noqa: F401,F403
 from .mmcls import MMClsArchitecture
 from .mmdet import MMDetArchitecture
diff --git a/mmrazor/models/architectures/backbones/__init__.py b/mmrazor/models/architectures/backbones/__init__.py
new file mode 100644
index 00000000..9ceb3265
--- /dev/null
+++ b/mmrazor/models/architectures/backbones/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .searchable_mobilenet import SearchableMobileNet
+from .searchable_shufflenet_v2 import SearchableShuffleNetV2
+
+__all__ = ['SearchableMobileNet', 'SearchableShuffleNetV2']
diff --git a/mmrazor/models/architectures/backbones/searchable_mobilenet.py b/mmrazor/models/architectures/backbones/searchable_mobilenet.py
new file mode 100644
index 00000000..40faedb2
--- /dev/null
+++ b/mmrazor/models/architectures/backbones/searchable_mobilenet.py
@@ -0,0 +1,222 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+
+from mmcls.models.backbones.base_backbone import BaseBackbone
+from mmcls.models.utils import make_divisible
+from mmcv.cnn import ConvModule
+from mmcv.runner import Sequential
+from torch import Tensor
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmrazor.registry import MODELS
+
+
+@MODELS.register_module()
+class SearchableMobileNet(BaseBackbone):
+    """Searchable MobileNet backbone.
+
+    Args:
+        arch_setting (list[list]): Architecture settings.
+        first_channels (int): Channel width of first ConvModule. Default: 32.
+        last_channels (int): Channel width of last ConvModule. Default: 1280.
+        widen_factor (float): Width multiplier - adjusts the number of
+            channels in each layer by this amount. Default: 1.0.
+        out_indices (Sequence[int]): Output from which stages.
+            Default: (7, ).
+        frozen_stages (int): Stages to be frozen (all param fixed).
+            Default: -1, which means not freezing any parameters.
+        conv_cfg (dict, optional): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU6').
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Default: False.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        init_cfg (dict | list[dict], optional): initialization configuration
+            dict to define initializer. OpenMMLab has implemented
+            6 initializers, including ``Constant``, ``Xavier``, ``Normal``,
+            ``Uniform``, ``Kaiming``, and ``Pretrained``.
+
+    Examples:
+        >>> mutable_cfg = dict(
+        ...     type='OneShotOP',
+        ...     candidate_ops=dict(
+        ...         mb_k3e1=dict(
+        ...             type='MBBlock',
+        ...             kernel_size=3,
+        ...             expand_ratio=1,
+        ...             norm_cfg=dict(type='BN'),
+        ...             act_cfg=dict(type='ReLU6'))))
+        >>> arch_setting = [
+        ...     # Parameters to build layers. 4 parameters are needed to
+        ...     # construct a layer, from left to right:
+        ...     # channel, num_blocks, stride, mutable cfg.
+        ...     [16, 1, 1, mutable_cfg],
+        ...     [24, 2, 2, mutable_cfg],
+        ...     [32, 3, 2, mutable_cfg],
+        ...     [64, 4, 2, mutable_cfg],
+        ...     [96, 3, 1, mutable_cfg],
+        ...     [160, 3, 2, mutable_cfg],
+        ...     [320, 1, 1, mutable_cfg]
+        ... ]
+        >>> model = SearchableMobileNet(arch_setting=arch_setting)
+    """
+
+    def __init__(
+        self,
+        arch_setting: List[List],
+        first_channels: int = 32,
+        last_channels: int = 1280,
+        widen_factor: float = 1.,
+        out_indices: Sequence[int] = (7, ),
+        frozen_stages: int = -1,
+        conv_cfg: Optional[Dict] = None,
+        norm_cfg: Dict = dict(type='BN'),
+        act_cfg: Dict = dict(type='ReLU6'),
+        norm_eval: bool = False,
+        with_cp: bool = False,
+        init_cfg: Optional[Union[Dict, List[Dict]]] = [
+            dict(type='Kaiming', layer=['Conv2d']),
+            dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
+        ]
+    ) -> None:
+        for index in out_indices:
+            if index not in range(0, 8):
+                raise ValueError('the item in out_indices must be in '
+                                 f'range(0, 8). But received {index}')
+
+        if frozen_stages not in range(-1, 8):
+            raise ValueError('frozen_stages must be in range(-1, 8). '
+                             f'But received {frozen_stages}')
+
+        super().__init__(init_cfg)
+
+        self.arch_setting = arch_setting
+        self.widen_factor = widen_factor
+        self.out_indices = out_indices
+        self.frozen_stages = frozen_stages
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        self.norm_eval = norm_eval
+        self.with_cp = with_cp
+
+        self.in_channels = make_divisible(first_channels * widen_factor, 8)
+
+        self.conv1 = ConvModule(
+            in_channels=3,
+            out_channels=self.in_channels,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+        self.layers = []
+
+        for i, layer_cfg in enumerate(arch_setting):
+            channel, num_blocks, stride, mutable_cfg = layer_cfg
+            out_channels = make_divisible(channel * widen_factor, 8)
+            inverted_res_layer = self._make_layer(
+                out_channels=out_channels,
+                num_blocks=num_blocks,
+                stride=stride,
+                mutable_cfg=copy.deepcopy(mutable_cfg))
+            layer_name = f'layer{i + 1}'
+            self.add_module(layer_name, inverted_res_layer)
+            self.layers.append(layer_name)
+
+        if widen_factor > 1.0:
+            self.out_channel = int(last_channels * widen_factor)
+        else:
+            self.out_channel = last_channels
+
+        layer = ConvModule(
+            in_channels=self.in_channels,
+            out_channels=self.out_channel,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        self.add_module('conv2', layer)
+        self.layers.append('conv2')
+
+    def _make_layer(self, out_channels: int, num_blocks: int, stride: int,
+                    mutable_cfg: Dict) -> Sequential:
+        """Stack mutable blocks to build a layer for SearchableMobileNet.
+
+        Note:
+            Here we use ``module_kwargs`` to pass dynamic parameters such as
+            ``in_channels``, ``out_channels`` and ``stride``
+            to build the mutable.
+
+        Args:
+            out_channels (int): out_channels of block.
+            num_blocks (int): number of blocks.
+            stride (int): stride of the first block.
+            mutable_cfg (dict): Config of mutable.
+
+        Returns:
+            mmcv.runner.Sequential: The layer made.
+        """
+        layers = []
+        for i in range(num_blocks):
+            if i >= 1:
+                stride = 1
+
+            mutable_cfg.update(
+                module_kwargs=dict(
+                    in_channels=self.in_channels,
+                    out_channels=out_channels,
+                    stride=stride))
+            layers.append(MODELS.build(mutable_cfg))
+
+            self.in_channels = out_channels
+
+        return Sequential(*layers)
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
+        """Forward computation.
+
+        Args:
+            x (Tensor): Input tensor for forward computation.
+        """
+        x = self.conv1(x)
+
+        outs = []
+        for i, layer_name in enumerate(self.layers):
+            layer = getattr(self, layer_name)
+            x = layer(x)
+            if i in self.out_indices:
+                outs.append(x)
+
+        return tuple(outs)
+
+    def _freeze_stages(self) -> None:
+        """Freeze params not to update in the specified stages."""
+        if self.frozen_stages >= 0:
+            for param in self.conv1.parameters():
+                param.requires_grad = False
+            for i in range(1, self.frozen_stages + 1):
+                layer = getattr(self, f'layer{i}')
+                layer.eval()
+                for param in layer.parameters():
+                    param.requires_grad = False
+
+    def train(self, mode: bool = True) -> None:
+        """Set module status before forward computation."""
+        super().train(mode)
+
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                if isinstance(m, _BatchNorm):
+                    m.eval()
diff --git a/mmrazor/models/architectures/backbones/searchable_shufflenet_v2.py b/mmrazor/models/architectures/backbones/searchable_shufflenet_v2.py
new file mode 100644
index 00000000..6515ddcb
--- /dev/null
+++ b/mmrazor/models/architectures/backbones/searchable_shufflenet_v2.py
@@ -0,0 +1,221 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+
+import torch.nn as nn
+from mmcls.models.backbones.base_backbone import BaseBackbone
+from mmcv.cnn import ConvModule, constant_init, normal_init
+from mmcv.runner import ModuleList, Sequential
+from torch import Tensor
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmrazor.registry import MODELS
+
+
+@MODELS.register_module()
+class SearchableShuffleNetV2(BaseBackbone):
+    """Searchable ShuffleNetV2 backbone.
+
+    Args:
+        arch_setting (list[list]): Architecture settings.
+        stem_multiplier (int): Stem multiplier - adjusts the number of
+            channels in the first layer. Default: 1.
+        widen_factor (float): Width multiplier - adjusts the number of
+            channels in each layer by this amount. Default: 1.0.
+        out_indices (Sequence[int]): Output from which stages.
+            Default: (4, ).
+        frozen_stages (int): Stages to be frozen (all param fixed).
+            Default: -1, which means not freezing any parameters.
+        with_last_layer (bool): Whether to add the last ConvModule.
+            Default: True. If False, a ``Placeholder`` is required instead.
+        conv_cfg (dict, optional): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU').
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Default: False.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        init_cfg (dict | list[dict], optional): initialization configuration
+            dict to define initializer. OpenMMLab has implemented
+            6 initializers, including ``Constant``, ``Xavier``, ``Normal``,
+            ``Uniform``, ``Kaiming``, and ``Pretrained``.
+
+    Examples:
+        >>> mutable_cfg = dict(
+        ...     type='OneShotOP',
+        ...     candidate_ops=dict(
+        ...         shuffle_3x3=dict(
+        ...             type='ShuffleBlock',
+        ...             kernel_size=3,
+        ...             norm_cfg=dict(type='BN'))))
+        >>> arch_setting = [
+        ...     # Parameters to build layers. 3 parameters are needed to
+        ...     # construct a layer, from left to right:
+        ...     # channel, num_blocks, mutable cfg.
+        ...     [64, 4, mutable_cfg],
+        ...     [160, 4, mutable_cfg],
+        ...     [320, 8, mutable_cfg],
+        ...     [640, 4, mutable_cfg]
+        ... ]
+        >>> model = SearchableShuffleNetV2(arch_setting=arch_setting)
+    """
+
+    def __init__(self,
+                 arch_setting: List[List],
+                 stem_multiplier: int = 1,
+                 widen_factor: float = 1.0,
+                 out_indices: Sequence[int] = (4, ),
+                 frozen_stages: int = -1,
+                 with_last_layer: bool = True,
+                 conv_cfg: Optional[Dict] = None,
+                 norm_cfg: Dict = dict(type='BN'),
+                 act_cfg: Dict = dict(type='ReLU'),
+                 norm_eval: bool = False,
+                 with_cp: bool = False,
+                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
+        layers_nums = 5 if with_last_layer else 4
+        for index in out_indices:
+            if index not in range(0, layers_nums):
+                raise ValueError('the item in out_indices must be in '
+                                 f'range(0, {layers_nums}). '
+                                 f'But received {index}')
+
+        self.frozen_stages = frozen_stages
+        if frozen_stages not in range(-1, layers_nums):
+            raise ValueError(
+                f'frozen_stages must be in range(-1, {layers_nums}). '
+                f'But received {frozen_stages}')
+
+        super().__init__(init_cfg)
+
+        self.arch_setting = arch_setting
+        self.widen_factor = widen_factor
+        self.out_indices = out_indices
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        self.norm_eval = norm_eval
+        self.with_cp = with_cp
+
+        last_channels = 1024
+        self.in_channels = 16 * stem_multiplier
+        self.conv1 = ConvModule(
+            in_channels=3,
+            out_channels=self.in_channels,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+
+        self.layers = ModuleList()
+        for channel, num_blocks, mutable_cfg in arch_setting:
+            out_channels = round(channel * widen_factor)
+            layer = self._make_layer(out_channels, num_blocks,
+                                     copy.deepcopy(mutable_cfg))
+            self.layers.append(layer)
+
+        if with_last_layer:
+            self.layers.append(
+                ConvModule(
+                    in_channels=self.in_channels,
+                    out_channels=last_channels,
+                    kernel_size=1,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+
+    def _make_layer(self, out_channels: int, num_blocks: int,
+                    mutable_cfg: Dict) -> Sequential:
+        """Stack mutable blocks to build a layer for ShuffleNet V2.
+
+        Note:
+            Here we use ``module_kwargs`` to pass dynamic parameters such as
+            ``in_channels``, ``out_channels`` and ``stride``
+            to build the mutable.
+
+        Args:
+            out_channels (int): out_channels of the block.
+            num_blocks (int): number of blocks.
+            mutable_cfg (dict): Config of mutable.
+
+        Returns:
+            mmcv.runner.Sequential: The layer made.
+        """
+        layers = []
+        for i in range(num_blocks):
+            stride = 2 if i == 0 else 1
+
+            mutable_cfg.update(
+                module_kwargs=dict(
+                    in_channels=self.in_channels,
+                    out_channels=out_channels,
+                    stride=stride))
+            layers.append(MODELS.build(mutable_cfg))
+            self.in_channels = out_channels
+
+        return Sequential(*layers)
+
+    def _freeze_stages(self) -> None:
+        """Freeze params not to update in the specified stages."""
+        if self.frozen_stages >= 0:
+            for param in self.conv1.parameters():
+                param.requires_grad = False
+
+            for i in range(self.frozen_stages):
+                m = self.layers[i]
+                m.eval()
+                for param in m.parameters():
+                    param.requires_grad = False
+
+    def init_weights(self) -> None:
+        """Init weights of ``SearchableShuffleNetV2``."""
+        super().init_weights()
+
+        if (isinstance(self.init_cfg, dict)
+                and self.init_cfg['type'] == 'Pretrained'):
+            # Suppress default init if a pretrained model is used.
+            return
+
+        for name, m in self.named_modules():
+            if isinstance(m, nn.Conv2d):
+                if 'conv1' in name:
+                    normal_init(m, mean=0, std=0.01)
+                else:
+                    normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
+            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                constant_init(m, val=1, bias=0.0001)
+                if isinstance(m, _BatchNorm):
+                    if m.running_mean is not None:
+                        nn.init.constant_(m.running_mean, 0)
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
+        """Forward computation.
+
+        Args:
+            x (Tensor): Input tensor for forward computation.
+        """
+        x = self.conv1(x)
+
+        outs = []
+        for i, layer in enumerate(self.layers):
+            x = layer(x)
+            if i in self.out_indices:
+                outs.append(x)
+
+        return tuple(outs)
+
+    def train(self, mode: bool = True) -> None:
+        """Set module status before forward computation."""
+        super().train(mode)
+
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                if isinstance(m, _BatchNorm):
+                    m.eval()
diff --git a/mmrazor/models/architectures/components/__init__.py b/mmrazor/models/architectures/components/__init__.py
index cfa1c5fc..4a2ee4a4 100644
--- a/mmrazor/models/architectures/components/__init__.py
+++ b/mmrazor/models/architectures/components/__init__.py
@@ -1,4 +1,3 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .backbones import *  # noqa: F401,F403
 from .heads import *  # noqa: F401,F403
 from .necks import *  # noqa: F401,F403
diff --git a/tests/test_models/test_architectures/test_backbones/__init__.py b/tests/test_models/test_architectures/test_backbones/__init__.py
new file mode 100644
index 00000000..ef101fec
--- /dev/null
+++ b/tests/test_models/test_architectures/test_backbones/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) OpenMMLab. All rights reserved.
diff --git a/tests/test_models/test_architectures/test_backbones/test_searchable_mobilenet.py b/tests/test_models/test_architectures/test_backbones/test_searchable_mobilenet.py
new file mode 100644
index 00000000..dc36c3a7
--- /dev/null
+++ b/tests/test_models/test_architectures/test_backbones/test_searchable_mobilenet.py
@@ -0,0 +1,113 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import pytest
+import torch
+from mmcls.models import *  # noqa: F401,F403
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmrazor.models import *  # noqa: F401,F403
+from mmrazor.models.mutables import *  # noqa: F401,F403
+from mmrazor.registry import MODELS
+from .utils import MockMutable
+
+_FIRST_STAGE_MUTABLE = dict(type='MockMutable', choices=['c1'])
+_OTHER_STAGE_MUTABLE = dict(
+    type='MockMutable', choices=['c1', 'c2', 'c3', 'c4'])
+ARCHSETTING_CFG = [
+    # Parameters to build layers. 4 parameters are needed to construct a
+    # layer, from left to right: channel, num_blocks, stride, mutable cfg.
+ [16, 1, 1, _FIRST_STAGE_MUTABLE], + [24, 2, 2, _OTHER_STAGE_MUTABLE], + [32, 3, 2, _OTHER_STAGE_MUTABLE], + [64, 4, 2, _OTHER_STAGE_MUTABLE], + [96, 3, 1, _OTHER_STAGE_MUTABLE], + [160, 3, 2, _OTHER_STAGE_MUTABLE], + [320, 1, 1, _OTHER_STAGE_MUTABLE] +] +NORM_CFG = dict(type='BN') +BACKBONE_CFG = dict( + type='mmrazor.SearchableMobileNet', + first_channels=32, + last_channels=1280, + widen_factor=1.0, + norm_cfg=NORM_CFG, + arch_setting=ARCHSETTING_CFG) + + +def test_searchable_mobilenet_mutable() -> None: + backbone = MODELS.build(BACKBONE_CFG) + + choices = ['c1', 'c2', 'c3', 'c4'] + mutable_nums = 0 + + for name, module in backbone.named_modules(): + if isinstance(module, MockMutable): + if 'layer1' in name: + assert module.choices == ['c1'] + else: + assert module.choices == choices + mutable_nums += 1 + + arch_setting = backbone.arch_setting + target_mutable_nums = 0 + for layer_cfg in arch_setting: + target_mutable_nums += layer_cfg[1] + assert mutable_nums == target_mutable_nums + + +def test_searchable_mobilenet_train() -> None: + backbone = MODELS.build(BACKBONE_CFG) + backbone.train(mode=True) + for m in backbone.modules(): + assert m.training + + backbone.norm_eval = True + backbone.train(mode=True) + for m in backbone.modules(): + if isinstance(m, _BatchNorm): + assert not m.training + else: + assert m.training + + x = torch.rand(10, 3, 224, 224) + assert len(backbone(x)) == 1 + + backbone_cfg = copy.deepcopy(BACKBONE_CFG) + backbone_cfg['frozen_stages'] = 5 + backbone = MODELS.build(backbone_cfg) + backbone.train() + + for param in backbone.conv1.parameters(): + assert not param.requires_grad + for i in range(1, 8): + layer = getattr(backbone, f'layer{i}') + for m in layer.modules(): + if i <= 5: + assert not m.training + else: + assert m.training + for param in layer.parameters(): + if i <= 5: + assert not param.requires_grad + else: + assert param.requires_grad + + +def test_searchable_mobilenet_init() -> None: + backbone_cfg = copy.deepcopy(BACKBONE_CFG) + backbone_cfg['out_indices'] = (10, ) + + with pytest.raises(ValueError): + MODELS.build(backbone_cfg) + + backbone_cfg = copy.deepcopy(BACKBONE_CFG) + backbone_cfg['frozen_stages'] = 8 + + with pytest.raises(ValueError): + MODELS.build(backbone_cfg) + + backbone_cfg = copy.deepcopy(BACKBONE_CFG) + backbone_cfg['widen_factor'] = 1.5 + backbone = MODELS.build(backbone_cfg) + assert backbone.out_channel == 1920 diff --git a/tests/test_models/test_architectures/test_backbones/test_searchable_shufflenet_v2.py b/tests/test_models/test_architectures/test_backbones/test_searchable_shufflenet_v2.py new file mode 100644 index 00000000..0c600d9b --- /dev/null +++ b/tests/test_models/test_architectures/test_backbones/test_searchable_shufflenet_v2.py @@ -0,0 +1,155 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os +import tempfile + +import pytest +import torch +from mmcls.models import * # noqa: F401,F403 +from torch.nn import GroupNorm +from torch.nn.modules.batchnorm import _BatchNorm + +from mmrazor.models import * # noqa: F401,F403 +from mmrazor.models.mutables import * # noqa: F401,F403 +from mmrazor.registry import MODELS +from .utils import MockMutable + +STAGE_MUTABLE = dict(type='MockMutable', choices=['c1', 'c2', 'c3', 'c4']) +ARCHSETTING_CFG = [ + # Parameters to build layers. 3 parameters are needed to construct a + # layer, from left to right: channel, num_blocks, mutable_cfg. 
+    [64, 4, STAGE_MUTABLE],
+    [160, 4, STAGE_MUTABLE],
+    [320, 8, STAGE_MUTABLE],
+    [640, 4, STAGE_MUTABLE],
+]
+
+NORM_CFG = dict(type='BN')
+BACKBONE_CFG = dict(
+    type='mmrazor.SearchableShuffleNetV2',
+    widen_factor=1.0,
+    norm_cfg=NORM_CFG,
+    arch_setting=ARCHSETTING_CFG)
+
+
+def test_searchable_shufflenet_v2_mutable() -> None:
+    backbone = MODELS.build(BACKBONE_CFG)
+
+    choices = ['c1', 'c2', 'c3', 'c4']
+    mutable_nums = 0
+
+    for module in backbone.modules():
+        if isinstance(module, MockMutable):
+            assert module.choices == choices
+            mutable_nums += 1
+
+    arch_setting = backbone.arch_setting
+    target_mutable_nums = 0
+    for layer_cfg in arch_setting:
+        target_mutable_nums += layer_cfg[1]
+    assert mutable_nums == target_mutable_nums
+
+
+def test_searchable_shufflenet_v2_train() -> None:
+    backbone = MODELS.build(BACKBONE_CFG)
+    backbone.train(mode=True)
+    for m in backbone.modules():
+        assert m.training
+
+    backbone.norm_eval = True
+    backbone.train(mode=True)
+    for m in backbone.modules():
+        if isinstance(m, _BatchNorm):
+            assert not m.training
+        else:
+            assert m.training
+
+    x = torch.rand(10, 3, 224, 224)
+    assert len(backbone(x)) == 1
+
+    backbone_cfg = copy.deepcopy(BACKBONE_CFG)
+    backbone_cfg['frozen_stages'] = 2
+    backbone = MODELS.build(backbone_cfg)
+    backbone.train()
+
+    for param in backbone.conv1.parameters():
+        assert not param.requires_grad
+    for i in range(len(backbone.layers)):
+        layer = backbone.layers[i]
+        for m in layer.modules():
+            if i < 2:
+                assert not m.training
+            else:
+                assert m.training
+        for param in layer.parameters():
+            if i < 2:
+                assert not param.requires_grad
+            else:
+                assert param.requires_grad
+
+
+def test_searchable_shufflenet_v2_init() -> None:
+    backbone_cfg = copy.deepcopy(BACKBONE_CFG)
+    backbone_cfg['out_indices'] = (5, )
+
+    with pytest.raises(ValueError):
+        MODELS.build(backbone_cfg)
+
+    backbone_cfg = copy.deepcopy(BACKBONE_CFG)
+    backbone_cfg['frozen_stages'] = 5
+
+    with pytest.raises(ValueError):
+        MODELS.build(backbone_cfg)
+
+    backbone_cfg = copy.deepcopy(BACKBONE_CFG)
+    backbone_cfg['with_last_layer'] = False
+    with pytest.raises(ValueError):
+        MODELS.build(backbone_cfg)
+
+    backbone_cfg['out_indices'] = (3, )
+    backbone = MODELS.build(backbone_cfg)
+    assert len(backbone.layers) == 4
+
+
+def test_searchable_shufflenet_v2_init_weights() -> None:
+    backbone = MODELS.build(BACKBONE_CFG)
+    backbone.init_weights()
+
+    for m in backbone.modules():
+        if isinstance(m, (_BatchNorm, GroupNorm)):
+            if hasattr(m, 'weight') and m.weight is not None:
+                assert torch.equal(m.weight, torch.ones_like(m.weight))
+            if hasattr(m, 'bias') and m.bias is not None:
+                bias_tensor = torch.ones_like(m.bias)
+                bias_tensor *= 0.0001
+                assert torch.equal(bias_tensor, m.bias)
+
+    checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')
+    backbone_cfg = copy.deepcopy(BACKBONE_CFG)
+    backbone_cfg['init_cfg'] = dict(
+        type='Pretrained', checkpoint=checkpoint_path)
+    backbone = MODELS.build(backbone_cfg)
+    torch.save(backbone.state_dict(), checkpoint_path)
+
+    name2weight = dict()
+    for name, m in backbone.named_modules():
+        if isinstance(m, (_BatchNorm, GroupNorm)):
+            if hasattr(m, 'weight') and m.weight is not None:
+                name2weight[name] = m.weight.clone()
+
+    backbone.init_weights()
+    for name, m in backbone.named_modules():
+        if isinstance(m, (_BatchNorm, GroupNorm)):
+            if hasattr(m, 'weight') and m.weight is not None:
+                if name in name2weight:
+                    assert torch.equal(name2weight[name], m.weight)
+
+    backbone_cfg = copy.deepcopy(BACKBONE_CFG)
+    backbone_cfg['norm_cfg'] = dict(type='BN', track_running_stats=False)
+    backbone = MODELS.build(backbone_cfg)
+    backbone.init_weights()
+
+    backbone_cfg = copy.deepcopy(BACKBONE_CFG)
+    backbone_cfg['norm_cfg'] = dict(type='GN', num_groups=1)
+    backbone = MODELS.build(backbone_cfg)
+    backbone.init_weights()
diff --git a/tests/test_models/test_architectures/test_backbones/utils.py b/tests/test_models/test_architectures/test_backbones/utils.py
new file mode 100644
index 00000000..593faa4a
--- /dev/null
+++ b/tests/test_models/test_architectures/test_backbones/utils.py
@@ -0,0 +1,21 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List
+
+from torch import Tensor
+from torch.nn import Conv2d, Module
+
+from mmrazor.registry import MODELS
+
+
+@MODELS.register_module()
+class MockMutable(Module):
+
+    def __init__(self, choices: List[str], module_kwargs: Dict) -> None:
+        super().__init__()
+
+        self.choices = choices
+        self.module_kwargs = module_kwargs
+        self.conv = Conv2d(**module_kwargs, kernel_size=3, padding=3 // 2)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.conv(x)
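Usage sketch: the snippet below shows how the searchable backbones added above are meant to be built through the registry, mirroring what the unit tests exercise. It reuses the ``MockMutable`` helper as a stand-in for a real mutable such as the ``OneShotOP`` entries in the configs, and assumes that helper has already been imported (and therefore registered) first.

```python
import torch

from mmrazor.registry import MODELS

# Stand-in mutable for illustration; a real run would use a registered
# mutable such as OneShotOP from the configs above. MockMutable must have
# been imported so that its @MODELS.register_module() has executed.
stage_mutable = dict(type='MockMutable', choices=['c1', 'c2', 'c3', 'c4'])

backbone = MODELS.build(
    dict(
        type='mmrazor.SearchableShuffleNetV2',
        widen_factor=1.0,
        arch_setting=[
            # channel, num_blocks, mutable_cfg
            [64, 4, stage_mutable],
            [160, 4, stage_mutable],
            [320, 8, stage_mutable],
            [640, 4, stage_mutable],
        ]))
backbone.init_weights()

# out_indices defaults to (4, ), i.e. only the output of the final 1x1
# ConvModule is returned.
feats = backbone(torch.rand(1, 3, 224, 224))
assert len(feats) == 1
```

The same pattern applies to ``SearchableMobileNet``; its ``arch_setting`` rows just carry an additional stride field between ``num_blocks`` and the mutable config.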