diff --git a/mmrazor/models/architectures/__init__.py b/mmrazor/models/architectures/__init__.py
index 89fc51f5..162be884 100644
--- a/mmrazor/models/architectures/__init__.py
+++ b/mmrazor/models/architectures/__init__.py
@@ -6,4 +6,8 @@ from .mmdet import MMDetArchitecture
 from .mmseg import MMSegArchitecture
 from .utils import *  # noqa: F401,F403
 
-__all__ = ['MMClsArchitecture', 'MMDetArchitecture', 'MMSegArchitecture']
+__all__ = [
+    'MMClsArchitecture',
+    'MMDetArchitecture',
+    'MMSegArchitecture',
+]
diff --git a/mmrazor/models/architectures/backbones/__init__.py b/mmrazor/models/architectures/backbones/__init__.py
index 9ceb3265..e0cfb671 100644
--- a/mmrazor/models/architectures/backbones/__init__.py
+++ b/mmrazor/models/architectures/backbones/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .darts_backbone import DartsBackbone
 from .searchable_mobilenet import SearchableMobileNet
 from .searchable_shufflenet_v2 import SearchableShuffleNetV2
 
-__all__ = ['SearchableMobileNet', 'SearchableShuffleNetV2']
+__all__ = ['SearchableMobileNet', 'SearchableShuffleNetV2', 'DartsBackbone']
diff --git a/mmrazor/models/architectures/components/backbones/darts_backbone.py b/mmrazor/models/architectures/backbones/darts_backbone.py
similarity index 51%
rename from mmrazor/models/architectures/components/backbones/darts_backbone.py
rename to mmrazor/models/architectures/backbones/darts_backbone.py
index 0ecb2a58..7c1ffd50 100644
--- a/mmrazor/models/architectures/components/backbones/darts_backbone.py
+++ b/mmrazor/models/architectures/backbones/darts_backbone.py
@@ -1,22 +1,32 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 from mmcv.cnn import build_activation_layer, build_norm_layer
+from torch import Tensor
 
 from mmrazor.registry import MODELS
-from ...utils import Placeholder
 
 
 class FactorizedReduce(nn.Module):
-    """Reduce feature map size by factorized pointwise (stride=2)."""
+    """Reduce feature map size by factorized pointwise (stride=2).
 
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 act_cfg=dict(type='ReLU'),
-                 norm_cfg=dict(type='BN')):
+    Args:
+        in_channels (int): number of channels of input tensor.
+        out_channels (int): number of channels of output tensor.
+        act_cfg (Dict): config to build activation layer.
+        norm_cfg (Dict): config to build normalization layer.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        act_cfg: Dict = dict(type='ReLU'),
+        norm_cfg: Dict = dict(type='BN')
+    ) -> None:
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -39,7 +49,8 @@ class FactorizedReduce(nn.Module):
             bias=False)
         self.bn = build_norm_layer(self.norm_cfg, self.out_channels)[1]
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward with factorized reduce."""
         x = self.relu(x)
         out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
         out = self.bn(out)
@@ -47,18 +58,31 @@
 
 
 class StandardConv(nn.Module):
-    """
-    Standard conv: ReLU - Conv - BN
+    """Standard Convolution in Darts. Basic structure is ReLU-Conv-BN.
+
+    Args:
+        in_channels (int): number of channels of input tensor.
+        out_channels (int): number of channels of output tensor.
+        kernel_size (Union[int, Tuple]): size of the convolving kernel.
+        stride (Union[int, Tuple]): controls the stride for the
+            cross-correlation, a single number or a one-element tuple.
+            Defaults to 1.
+        padding (Union[str, int, Tuple]): Padding added to both sides
+            of the input. Defaults to 0.
+        act_cfg (Dict): config to build activation layer.
+        norm_cfg (Dict): config to build normalization layer.
     """
 
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride,
-                 padding,
-                 act_cfg=dict(type='ReLU'),
-                 norm_cfg=dict(type='BN')):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple],
+        stride: Union[int, Tuple] = 1,
+        padding: Union[str, int, Tuple] = 0,
+        act_cfg: Dict = dict(type='ReLU'),
+        norm_cfg: Dict = dict(type='BN')
+    ) -> None:
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -78,14 +102,26 @@ class StandardConv(nn.Module):
                 bias=False),
             build_norm_layer(self.norm_cfg, self.out_channels)[1])
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward the standard convolution."""
         return self.net(x)
 
 
 class Node(nn.Module):
+    """Node structure of DARTS.
 
-    def __init__(self, node_id, num_prev_nodes, channels,
-                 num_downsample_nodes):
+    Args:
+        node_id (str): key of the node.
+        num_prev_nodes (int): number of previous nodes.
+        channels (int): number of channels of current node.
+        num_downsample_nodes (int): number of downsample nodes.
+        mutable_cfg (Dict): config of `DiffMutable`.
+        route_cfg (Dict): config of `DiffChoiceRoute`.
+    """
+
+    def __init__(self, node_id: str, num_prev_nodes: int, channels: int,
+                 num_downsample_nodes: int, mutable_cfg: Dict,
+                 route_cfg: Dict) -> None:
         super().__init__()
         edges = nn.ModuleDict()
         for i in range(num_prev_nodes):
@@ -95,35 +131,58 @@ class Node(nn.Module):
                 stride = 1
             edge_id = '{}_p{}'.format(node_id, i)
-            edges.add_module(
-                edge_id,
-                nn.Sequential(
-                    Placeholder(
-                        group='node',
-                        space_id=edge_id,
-                        choice_args=dict(
-                            stride=stride,
-                            in_channels=channels,
-                            out_channels=channels)), ))
-        self.edges = Placeholder(
-            group='node_edge', space_id=node_id, choices=edges)
+            module_kwargs = dict(
+                in_channels=channels,
+                out_channels=channels,
+                stride=stride,
+            )
 
-    def forward(self, prev_nodes):
+            mutable_cfg.update(module_kwargs=module_kwargs)
+            mutable_cfg.update(alias=edge_id)
+            edges.add_module(edge_id, MODELS.build(mutable_cfg))
+
+        route_cfg.update(edges=edges)
+        self.edges = MODELS.build(route_cfg)
+
+    def forward(self, prev_nodes: Union[List[Tensor],
+                                        Tuple[Tensor]]) -> Tensor:
+        """Forward with the previous nodes list."""
         return self.edges(prev_nodes)
 
 
 class Cell(nn.Module):
+    """Darts cell structure.
+
+    Args:
+        num_nodes (int): number of nodes.
+        channels (int): number of channels of current cell.
+        prev_channels (int): number of channels of the previous input.
+        prev_prev_channels (int): number of channels of the previous
+            previous input.
+        reduction (bool): whether to reduce the feature map size.
+        prev_reduction (bool): whether to reduce the previous feature
+            map size.
+        mutable_cfg (Dict): config of `DiffMutable`.
+        route_cfg (Dict): config of `DiffChoiceRoute`.
+        act_cfg (Dict): config to build activation layer.
+            Defaults to dict(type='ReLU').
+        norm_cfg (Dict): config to build normalization layer.
+            Defaults to dict(type='BN').
+    """
+
+    def __init__(
+        self,
+        num_nodes: int,
+        channels: int,
+        prev_channels: int,
+        prev_prev_channels: int,
+        reduction: bool,
+        prev_reduction: bool,
+        mutable_cfg: Dict,
+        route_cfg: Dict,
+        act_cfg: Dict = dict(type='ReLU'),
+        norm_cfg: Dict = dict(type='BN'),
+    ) -> None:
-    def __init__(self,
-                 num_nodes,
-                 channels,
-                 prev_channels,
-                 prev_prev_channels,
-                 reduction,
-                 prev_reduction,
-                 act_cfg=dict(type='ReLU'),
-                 norm_cfg=dict(type='BN')):
         super().__init__()
         self.act_cfg = act_cfg
         self.norm_cfg = norm_cfg
@@ -152,11 +211,12 @@ class Cell(nn.Module):
                 node_id = f'normal_n{depth}'
                 num_downsample_nodes = 0
             self.nodes.append(
-                Node(node_id, depth, channels, num_downsample_nodes))
+                Node(node_id, depth, channels, num_downsample_nodes,
+                     mutable_cfg, route_cfg))
 
-    def forward(self, s0, s1):
-        # s0, s1 are the outputs of previous previous cell and previous cell,
-        # respectively.
+    def forward(self, s0: Tensor, s1: Tensor) -> Tensor:
+        """Forward with the outputs of previous previous cell and previous
+        cell."""
         tensors = [self.preproc0(s0), self.preproc1(s1)]
         for node in self.nodes:
             cur_tensor = node(tensors)
@@ -167,14 +227,21 @@
 
 
 class AuxiliaryModule(nn.Module):
-    """Auxiliary head in 2/3 place of network to let the gradient flow well."""
+    """Auxiliary head in 2/3 place of network to let the gradient flow well.
+
+    Args:
+        in_channels (int): number of channels of inputs.
+        base_channels (int): number of middle channels of the
+            auxiliary module.
+        out_channels (int): number of channels of outputs.
+        norm_cfg (Dict): config to build normalization layer.
+            Defaults to dict(type='BN').
+    """
 
     def __init__(self,
-                 in_channels,
-                 base_channels,
-                 out_channels,
-                 norm_cfg=dict(type='BN')):
-
+                 in_channels: int,
+                 base_channels: int,
+                 out_channels: int,
+                 norm_cfg: Dict = dict(type='BN')) -> None:
         super().__init__()
         self.norm_cfg = norm_cfg
         self.net = nn.Sequential(
@@ -189,25 +256,56 @@ class AuxiliaryModule(nn.Module):
             build_norm_layer(self.norm_cfg, out_channels)[1],
             nn.ReLU(inplace=True))
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward the auxiliary module."""
         return self.net(x)
 
 
 @MODELS.register_module()
 class DartsBackbone(nn.Module):
+    """Backbone of Differentiable Architecture Search (DARTS).
 
-    def __init__(self,
-                 in_channels,
-                 base_channels,
-                 num_layers=8,
-                 num_nodes=4,
-                 stem_multiplier=3,
-                 out_indices=(7, ),
-                 auxliary=False,
-                 aux_channels=None,
-                 aux_out_channels=None,
-                 act_cfg=dict(type='ReLU'),
-                 norm_cfg=dict(type='BN')):
+    Args:
+        in_channels (int): number of channels of input tensor.
+        base_channels (int): number of middle channels.
+        mutable_cfg (Dict): config of `DiffMutable`.
+        route_cfg (Dict): config of `DiffChoiceRoute`.
+        num_layers (int): number of layers.
+            Defaults to 8.
+        num_nodes (int): number of nodes.
+            Defaults to 4.
+        stem_multiplier (int): multiplier for stem.
+            Defaults to 3.
+        out_indices (tuple, optional): output indices of the backbone.
+            Defaults to (7, ).
+        auxliary (bool, optional): whether to use the auxiliary module.
+            Defaults to False.
+        aux_channels (Optional[int]): number of middle channels of the
+            auxiliary module. Defaults to None.
+        aux_out_channels (Optional[int]): number of output channels of the
+            auxiliary module. Defaults to None.
+        act_cfg (Dict): config to build activation layer.
+            Defaults to dict(type='ReLU').
+        norm_cfg (Dict): config to build normalization layer.
+            Defaults to dict(type='BN').
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        base_channels: int,
+        mutable_cfg: Dict,
+        route_cfg: Dict,
+        num_layers: int = 8,
+        num_nodes: int = 4,
+        stem_multiplier: int = 3,
+        out_indices: Union[Tuple, List] = (7, ),
+        auxliary: bool = False,
+        aux_channels: Optional[int] = None,
+        aux_out_channels: Optional[int] = None,
+        act_cfg: Dict = dict(type='ReLU'),
+        norm_cfg: Dict = dict(type='BN'),
+    ) -> None:
         super().__init__()
 
         self.in_channels = in_channels
@@ -237,7 +335,7 @@ class DartsBackbone(nn.Module):
             build_norm_layer(self.norm_cfg, self.out_channels)[1])
 
         # for the first cell, stem is used for both s0 and s1
-        # [!] prev_prev_channels and prev_channels is output channel size,
+        # prev_prev_channels and prev_channels are output channel size,
         # but c_cur is input channel size.
         prev_prev_channels = self.out_channels
         prev_channels = self.out_channels
@@ -255,7 +353,7 @@ class DartsBackbone(nn.Module):
 
             cell = Cell(self.num_nodes, self.out_channels, prev_channels,
                         prev_prev_channels, reduction, prev_reduction,
-                        self.act_cfg, self.norm_cfg)
+                        mutable_cfg, route_cfg, self.act_cfg, self.norm_cfg)
             self.cells.append(cell)
 
             prev_prev_channels = prev_channels
@@ -267,7 +365,8 @@ class DartsBackbone(nn.Module):
                                                   self.aux_out_channels,
                                                   self.norm_cfg)
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward the darts backbone."""
         outs = []
         s0 = s1 = self.stem(x)
         for i, cell in enumerate(self.cells):
@@ -276,7 +375,6 @@ class DartsBackbone(nn.Module):
                 outs.append(s1)
             if i == self.auxliary_indice and self.training:
                 aux_feature = self.auxliary_module(s1)
-                outs.insert(0, aux_feature)
 
         return tuple(outs)
diff --git a/mmrazor/models/architectures/components/backbones/__init__.py b/mmrazor/models/architectures/components/backbones/__init__.py
index 6fd18ba8..4e3bdb15 100644
--- a/mmrazor/models/architectures/components/backbones/__init__.py
+++ b/mmrazor/models/architectures/components/backbones/__init__.py
@@ -1,6 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .darts_backbone import DartsBackbone
 from .searchable_mobilenet import SearchableMobileNet
 from .searchable_shufflenet_v2 import SearchableShuffleNetV2
 
-__all__ = ['DartsBackbone', 'SearchableShuffleNetV2', 'SearchableMobileNet']
+__all__ = ['SearchableShuffleNetV2', 'SearchableMobileNet']
diff --git a/mmrazor/models/mutables/base_mutable.py b/mmrazor/models/mutables/base_mutable.py
index bf3bf438..32b4393c 100644
--- a/mmrazor/models/mutables/base_mutable.py
+++ b/mmrazor/models/mutables/base_mutable.py
@@ -25,6 +25,7 @@ class BaseMutable(BaseModule, ABC, Generic[CHOICE_TYPE, CHOSEN_TYPE]):
     Args:
         module_kwargs (dict[str, dict], optional): Module initialization
             named arguments. Defaults to None.
+        alias (str, optional): alias of the `MUTABLE`.
         init_cfg (dict, optional): initialization configuration dict for
             ``BaseModule``. OpenMMLab has implement 5 initializer including
             `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
@@ -33,10 +34,12 @@ class BaseMutable(BaseModule, ABC, Generic[CHOICE_TYPE, CHOSEN_TYPE]):
 
     def __init__(self,
                  module_kwargs: Optional[Dict[str, Dict]] = None,
+                 alias: Optional[str] = None,
                  init_cfg: Optional[Dict] = None) -> None:
         super().__init__(init_cfg=init_cfg)
 
         self.module_kwargs = module_kwargs
+        self.alias = alias
         self._is_fixed = False
         self._current_choice: Optional[CHOICE_TYPE] = None
diff --git a/mmrazor/models/mutables/diff_mutable.py b/mmrazor/models/mutables/diff_mutable.py
index 69cb6118..74e8add8 100644
--- a/mmrazor/models/mutables/diff_mutable.py
+++ b/mmrazor/models/mutables/diff_mutable.py
@@ -20,6 +20,7 @@ class DiffMutable(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE]):
     Args:
         module_kwargs (dict[str, dict], optional): Module initialization
             named arguments. Defaults to None.
+        alias (str, optional): alias of the `MUTABLE`.
         init_cfg (dict, optional): initialization configuration dict for
             ``BaseModule``. OpenMMLab has implement 5 initializer including
             `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
@@ -31,8 +32,10 @@
 
     def __init__(self,
                  module_kwargs: Optional[Dict[str, Dict]] = None,
+                 alias: Optional[str] = None,
                  init_cfg: Optional[Dict] = None) -> None:
-        super().__init__(module_kwargs=module_kwargs, init_cfg=init_cfg)
+        super().__init__(
+            module_kwargs=module_kwargs, alias=alias, init_cfg=init_cfg)
 
     def forward(self,
                 x: Any,
@@ -103,6 +106,7 @@ class DiffOP(DiffMutable[str, str]):
             operations.
         module_kwargs (dict[str, dict], optional): Module initialization
             named arguments. Defaults to None.
+        alias (str, optional): alias of the `MUTABLE`.
         init_cfg (dict, optional): initialization configuration dict for
             ``BaseModule``. OpenMMLab has implement 5 initializer including
             `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
@@ -113,9 +117,11 @@
         self,
         candidate_ops: Dict[str, Dict],
         module_kwargs: Optional[Dict[str, Dict]] = None,
+        alias: Optional[str] = None,
         init_cfg: Optional[Dict] = None,
     ) -> None:
-        super().__init__(module_kwargs=module_kwargs, init_cfg=init_cfg)
+        super().__init__(
+            module_kwargs=module_kwargs, alias=alias, init_cfg=init_cfg)
         assert len(candidate_ops) >= 1, \
             f'Number of candidate op must greater than or equal to 1, ' \
             f'but got: {len(candidate_ops)}'
diff --git a/mmrazor/models/mutables/oneshot_mutable.py b/mmrazor/models/mutables/oneshot_mutable.py
index 0057c267..903a1d9c 100644
--- a/mmrazor/models/mutables/oneshot_mutable.py
+++ b/mmrazor/models/mutables/oneshot_mutable.py
@@ -24,6 +24,7 @@ class OneShotMutable(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE]):
     Args:
         module_kwargs (dict[str, dict], optional): Module initialization
             named arguments. Defaults to None.
+        alias (str, optional): alias of the `MUTABLE`.
         init_cfg (dict, optional): initialization configuration dict for
             ``BaseModule``. OpenMMLab has implement 5 initializer including
             `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
@@ -35,8 +36,10 @@ class OneShotMutable(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE]):
 
     def __init__(self,
                  module_kwargs: Optional[Dict[str, Dict]] = None,
+                 alias: Optional[str] = None,
                  init_cfg: Optional[Dict] = None) -> None:
-        super().__init__(module_kwargs=module_kwargs, init_cfg=init_cfg)
+        super().__init__(
+            module_kwargs=module_kwargs, alias=alias, init_cfg=init_cfg)
 
     def forward(self, x: Any) -> Any:
         """Calls either :func:`forward_fixed` or :func:`forward_choice`
@@ -107,6 +110,7 @@ class OneShotOP(OneShotMutable[str, str]):
             operations.
         module_kwargs (dict[str, dict], optional): Module initialization
             named arguments. Defaults to None.
+        alias (str, optional): alias of the `MUTABLE`.
         init_cfg (dict, optional): initialization configuration dict for
             ``BaseModule``. OpenMMLab has implement 5 initializer including
             `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
@@ -153,9 +157,14 @@
         self,
         candidate_ops: Union[Dict[str, Dict], nn.ModuleDict],
         module_kwargs: Optional[Dict[str, Dict]] = None,
+        alias: Optional[str] = None,
         init_cfg: Optional[Dict] = None,
    ) -> None:
-        super().__init__(module_kwargs=module_kwargs, init_cfg=init_cfg)
+        super().__init__(
+            module_kwargs=module_kwargs, alias=alias, init_cfg=init_cfg)
+        assert len(candidate_ops) >= 1, \
+            f'Number of candidate op must be greater than or equal to 1, ' \
+            f'but got: {len(candidate_ops)}'
 
         self._is_fixed = False
         self._chosen: Optional[str] = None
@@ -284,6 +293,7 @@ class OneShotProbOP(OneShotOP):
             candidate operation.
         module_kwargs (dict[str, dict], optional): Module initialization
             named arguments. Defaults to None.
+        alias (str, optional): alias of the `MUTABLE`.
         init_cfg (dict, optional): initialization configuration dict for
             ``BaseModule``. OpenMMLab has implement 5 initializer including
             `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
@@ -294,10 +304,12 @@ class OneShotProbOP(OneShotOP):
                  candidate_ops: Dict[str, Dict],
                  choice_probs: list = None,
                  module_kwargs: Optional[Dict[str, Dict]] = None,
+                 alias: Optional[str] = None,
                  init_cfg: Optional[Dict] = None) -> None:
         super().__init__(
             candidate_ops=candidate_ops,
             module_kwargs=module_kwargs,
+            alias=alias,
             init_cfg=init_cfg)
         assert choice_probs is not None
         assert sum(choice_probs) - 1 < np.finfo(np.float64).eps, \
diff --git a/mmrazor/models/mutators/base_mutator.py b/mmrazor/models/mutators/base_mutator.py
index 31beaae7..b87ec691 100644
--- a/mmrazor/models/mutators/base_mutator.py
+++ b/mmrazor/models/mutators/base_mutator.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABC, abstractmethod
+from collections import Counter
 from typing import Dict, Generic, List, Optional, Type, TypeVar
 
 from mmcv.runner import BaseModule
@@ -98,7 +99,7 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]):
             supernet (:obj:`torch.nn.Module`): The supernet to be searched
                 in your algorithm.
""" - return self._build_search_group(supernet) + self._build_search_group(supernet) @property def search_group(self) -> Dict[int, List[MUTABLE_TYPE]]: @@ -119,43 +120,193 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]): 'Call `prepare_from_supernet` before access search group!') return self._search_group + def _build_name_mutable_mapping( + self, supernet: Module) -> Dict[str, MUTABLE_TYPE]: + """Mapping module name to mutable.""" + name2mutable: Dict[str, MUTABLE_TYPE] = dict() + for name, module in supernet.named_modules(): + if isinstance(module, self.mutable_class_type): + name2mutable[name] = module + return name2mutable + + def _build_alias_names_mapping(self, + supernet: Module) -> Dict[str, List[str]]: + """Mapping alias to module names.""" + alias2mutable_names: Dict[str, List[str]] = dict() + for name, module in supernet.named_modules(): + if isinstance(module, self.mutable_class_type): + if module.alias is not None: + if module.alias not in alias2mutable_names: + alias2mutable_names[module.alias] = [name] + else: + alias2mutable_names[module.alias].append(name) + + return alias2mutable_names + def _build_search_group(self, supernet: Module) -> None: - """Build search group with ``supernet`` and ``custom_group``. + """Build search group with ``custom_group`` and ``alias``(see more + information in :class:`BaseMutable`). Grouping by alias and module name + are both supported. Note: Apart from user-defined search group, all other searchable modules(mutable) will be grouped separately. + The main difference between using alias and module name for + grouping is that the alias is One-to-Many while the module + name is One-to-One. + + When using both alias and module name in `custom_group`, the + priority of alias is higher than that of module name. + + If alias is set in `custom_group`, then its corresponding module + name should not be in the `custom_group`. + + Moreover, there should be no duplicate keys in the `custom_group`. + + Example: + >>> import torch + >>> from mmrazor.models.mutables.diff_mutable import DiffOP + + >>> # Assume that a toy model consists of three mutabels + >>> # whose name are op1,op2,op3. The corresponding + >>> # alias names of the three mutables are a1, a1, a2. + >>> model = ToyModel() + + >>> # Using alias for grouping + >>> mutator = DiffOP(custom_group=[['a1'], ['a2']]) + >>> mutator.prepare_from_supernet(model) + >>> mutator.search_group + {0: [op1, op2], 1: [op3]} + + >>> # Using module name for grouping + >>> mutator = DiffOP(custom_group=[['op1', 'op2'], ['op3']]) + >>> mutator.prepare_from_supernet(model) + >>> mutator.search_group + {0: [op1, op2], 1: [op3]} + + >>> # Using both alias and module name for grouping + >>> mutator = DiffOP(custom_group=[['a2'], ['op2']]) + >>> mutator.prepare_from_supernet(model) + >>> # The last operation would be grouped + >>> mutator.search_group + {0: [op3], 1: [op2], 2: [op1]} + + Args: supernet (:obj:`torch.nn.Module`): The supernet to be searched in your algorithm. 
""" - module_name2module: Dict[str, Module] = dict() - for name, module in supernet.named_modules(): - module_name2module[name] = module + name2mutable = self._build_name_mutable_mapping(supernet) + alias2mutable_names = self._build_alias_names_mapping(supernet) - # Map module to group id for user-defined group - self.module_name2group_id: Dict[str, int] = dict() - for idx, group in enumerate(self._custom_group): - for module_name in group: - assert module_name in module_name2module, \ - f'`{module_name}` is not a module name of supernet, ' \ - f'expected module names: {module_name2module.keys()}' - self.module_name2group_id[module_name] = idx + # Check whether the custom group is valid + if len(self._custom_group) > 0: + self._check_valid_groups(alias2mutable_names, name2mutable, + self._custom_group) - search_group: Dict[int, List[MUTABLE_TYPE]] = dict() - current_group_nums = len(self._custom_group) + # Construct search_groups based on user-defined group + search_groups: Dict[int, List[MUTABLE_TYPE]] = dict() - for name, module in module_name2module.items(): - if isinstance(module, self.mutable_class_type): - group_id = self.module_name2group_id.get(name) - if group_id is None: - group_id = current_group_nums + current_group_nums = 0 + grouped_mutable_names: List[str] = list() + grouped_alias: List[str] = list() + for group in self._custom_group: + group_mutables = list() + for item in group: + if item in alias2mutable_names: + # if the item is from alias name + mutable_names: List[str] = alias2mutable_names[item] + grouped_alias.append(item) + group_mutables.extend( + [name2mutable[n] for n in mutable_names]) + grouped_mutable_names.extend(mutable_names) + else: + # if the item is in name2mutable + group_mutables.append(name2mutable[item]) + grouped_mutable_names.append(item) + + search_groups[current_group_nums] = group_mutables + current_group_nums += 1 + + # Construct search_groups based on alias + for alias, mutable_names in alias2mutable_names.items(): + if alias not in grouped_alias: + # Check whether all current names are already grouped + flag_all_grouped = True + for mutable_name in mutable_names: + if mutable_name not in grouped_mutable_names: + flag_all_grouped = False + + # If not all mutables are already grouped + if not flag_all_grouped: + search_groups[current_group_nums] = [] + for mutable_name in mutable_names: + if mutable_name not in grouped_mutable_names: + search_groups[current_group_nums].append( + name2mutable[mutable_name]) + grouped_mutable_names.append(mutable_name) current_group_nums += 1 - try: - search_group[group_id].append(module) - except KeyError: - search_group[group_id] = [module] - self.module_name2group_id[name] = group_id - self._search_group = search_group + # check whether all the mutable objects are in the search_groups + for name, module in supernet.named_modules(): + if isinstance(module, self.mutable_class_type): + if name in grouped_mutable_names: + continue + else: + search_groups[current_group_nums] = [module] + current_group_nums += 1 + + grouped_counter = Counter(grouped_mutable_names) + + # find duplicate keys + duplicate_keys = list() + for key, count in grouped_counter.items(): + if count > 1: + duplicate_keys.append(key) + + assert len(grouped_mutable_names) == len( + list(set(grouped_mutable_names))), \ + 'There are duplicate keys in grouped mutable names. ' \ + f'The duplicate keys are {duplicate_keys}. ' \ + 'Please check if there are duplicate keys in the `custom_group`.' 
+
+        self._search_group = search_groups
+
+    def _check_valid_groups(self, alias2mutable_names: Dict[str, List[str]],
+                            name2mutable: Dict[str, MUTABLE_TYPE],
+                            custom_group: List[List[str]]) -> None:
+
+        aliases = [*alias2mutable_names.keys()]
+        module_names = [*name2mutable.keys()]
+
+        # check if all keys are legal
+        expanded_custom_group: List[str] = [
+            _ for group in custom_group for _ in group
+        ]
+        legal_keys: List[str] = [*aliases, *module_names]
+
+        for key in expanded_custom_group:
+            if key not in legal_keys:
+                raise AssertionError(
+                    f'The key: {key} in `custom_group` is not legal. '
+                    f'Legal keys are: {legal_keys}. '
+                    'Make sure that the keys are either alias or mutable name')
+
+        # when the mutable has alias attribute, the corresponding module
+        # name should not be used in `custom_group`.
+        used_aliases = list()
+        for group in custom_group:
+            for key in group:
+                if key in aliases:
+                    used_aliases.append(key)
+
+        for alias_key in used_aliases:
+            mutable_names: List = alias2mutable_names[alias_key]
+            # check whether module name is in custom group
+            for mutable_name in mutable_names:
+                if mutable_name in expanded_custom_group:
+                    raise AssertionError(
+                        f'When a mutable is set with alias attribute: '
+                        f'{alias_key}, the corresponding module name '
+                        f'{mutable_name} should not be used in '
+                        f'`custom_group` {custom_group}.')
diff --git a/mmrazor/models/mutators/diff_mutator.py b/mmrazor/models/mutators/diff_mutator.py
index 802fcf58..40542870 100644
--- a/mmrazor/models/mutators/diff_mutator.py
+++ b/mmrazor/models/mutators/diff_mutator.py
@@ -12,6 +12,10 @@ from .base_mutator import ArchitectureMutator
 class DiffMutator(ArchitectureMutator[DiffMutable]):
     """Differentiable mutable based mutator.
 
+    `DiffMutator` is responsible for mutating `DiffMutable`, `DiffOP`,
+    `DiffChoiceRoute` and `GumbelChoiceRoute`. The architecture
+    parameters of the mutables are retained in `DiffMutator`.
+
     Args:
         custom_group (list[list[str]], optional): User-defined search groups.
             All searchable modules that are not in ``custom_group`` will be
@@ -21,7 +25,7 @@ class DiffMutator(ArchitectureMutator[DiffMutable]):
     def __init__(self,
                  custom_group: Optional[List[List[str]]] = None,
                  init_cfg: Optional[Dict] = None) -> None:
-        super().__init__(custom_group, init_cfg)
+        super().__init__(custom_group=custom_group, init_cfg=init_cfg)
 
     def prepare_from_supernet(self, supernet: nn.Module) -> None:
         """Inherit from ``BaseMutator``'s, generate `arch_params` in DARTS.
@@ -32,31 +36,23 @@ class DiffMutator(ArchitectureMutator[DiffMutable]):
         """
         super().prepare_from_supernet(supernet)
-        self.arch_params = self.build_arch_params(supernet)
+        self.arch_params = self.build_arch_params()
 
-    def build_arch_params(self, supernet):
+    def build_arch_params(self):
         """This function will build many arch params, which are generally
         used in differentiable search algorithms, such as Darts' series.
         Each group_id corresponds to an arch param, so the Mutables with
         the same group_id share the same arch param.
 
-        Args:
-            supernet (:obj:`torch.nn.Module`): The architecture to be used
-                in your algorithm.
         Returns:
-            torch.nn.ParameterDict: the arch params are got after traversing
-                the supernet.
+            torch.nn.ParameterDict: the arch params built from
+                `search_group`.
""" arch_params: Dict[int, nn.Parameter] = dict() - for module_name, module in supernet.named_modules(): - if isinstance(module, self.mutable_class_type): - group_id = self.module_name2group_id[module_name] - if group_id not in arch_params: - group_arch_param = module.build_arch_param() - if group_arch_param is not None: - arch_params[group_id] = group_arch_param + for group_id, modules in self.search_group.items(): + group_arch_param = modules[0].build_arch_param() + arch_params[group_id] = group_arch_param return arch_params @@ -70,10 +66,8 @@ class DiffMutator(ArchitectureMutator[DiffMutable]): """ for group_id, modules in self.search_group.items(): - if group_id in self.arch_params.keys(): - for module in modules: - module.set_forward_args( - arch_param=self.arch_params[group_id]) + for module in modules: + module.set_forward_args(arch_param=self.arch_params[group_id]) @property def mutable_class_type(self) -> Type[DiffMutable]: diff --git a/tests/test_models/test_architectures/test_backbones/test_dartsbackbone.py b/tests/test_models/test_architectures/test_backbones/test_dartsbackbone.py new file mode 100644 index 00000000..29b49af1 --- /dev/null +++ b/tests/test_models/test_architectures/test_backbones/test_dartsbackbone.py @@ -0,0 +1,118 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import unittest +from unittest import TestCase + +import torch +import torch.nn as nn +from mmcls.models import * # noqa:F403,F401 + +from mmrazor.models import * # noqa:F403,F401 +from mmrazor.models.architectures.components import * # noqa:F403,F401 +from mmrazor.registry import MODELS + +MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True) +MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True) +MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True) + + +class TestDartsBackbone(TestCase): + + def setUp(self) -> None: + self.mutable_cfg = dict( + type='DiffOP', + candidate_ops=dict( + torch_conv2d_3x3=dict( + type='torchConv2d', + kernel_size=3, + padding=1, + ), + torch_conv2d_5x5=dict( + type='torchConv2d', + kernel_size=5, + padding=2, + ), + torch_conv2d_7x7=dict( + type='torchConv2d', + kernel_size=7, + padding=3, + ), + )) + + self.route_cfg = dict( + type='DiffChoiceRoute', + with_arch_param=True, + ) + + self.backbone_cfg = dict( + type='mmrazor.DartsBackbone', + in_channels=3, + base_channels=16, + num_layers=8, + num_nodes=4, + stem_multiplier=3, + out_indices=(7, ), + mutable_cfg=self.mutable_cfg, + route_cfg=self.route_cfg) + + self.mutator_cfg = dict( + type='DiffMutator', + custom_group=None, + ) + + def test_darts_backbone(self): + model = MODELS.build(self.backbone_cfg) + custom_group = self.generate_key(model) + + assert model is not None + self.mutable_cfg.update(custom_group=custom_group) + mutator = MODELS.build(self.mutator_cfg) + assert mutator is not None + + mutator.prepare_from_supernet(model) + mutator.modify_supernet_forward() + + inputs = torch.randn(4, 3, 224, 224) + outputs = model(inputs) + assert outputs is not None + + def test_darts_backbone_with_auxliary(self): + self.backbone_cfg.update( + auxliary=True, aux_channels=256, aux_out_channels=512) + model = MODELS.build(self.backbone_cfg) + custom_group = self.generate_key(model) + + assert model is not None + self.mutable_cfg.update(custom_group=custom_group) + mutator = MODELS.build(self.mutator_cfg) + assert mutator is not None + mutator.prepare_from_supernet(model) + mutator.modify_supernet_forward() + + inputs = torch.randn(4, 3, 224, 224) + 
+        outputs = model(inputs)
+        assert outputs is not None
+
+    def generate_key(self, model):
+        """Auto-generate the custom group for DARTS."""
+        tmp_dict = dict()
+
+        for key, _ in model.named_modules():
+            node_type = key.split('._candidate_ops')[0].split('.')[-1].split(
+                '_')[0]
+            if node_type not in ['normal', 'reduce']:
+                # not supported type
+                continue
+
+            node_name = key.split('._candidate_ops')[0].split('.')[-1]
+            if node_name not in tmp_dict.keys():
+                tmp_dict[node_name] = [key.split('._candidate_ops')[0]]
+            else:
+                current_key = key.split('._candidate_ops')[0]
+                if current_key not in tmp_dict[node_name]:
+                    tmp_dict[node_name].append(current_key)
+
+        return list(tmp_dict.values())
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_models/test_architectures/test_backbones/test_searchable_shufflenet_v2.py b/tests/test_models/test_architectures/test_backbones/test_searchable_shufflenet_v2.py
index 0c600d9b..7f7d4e12 100644
--- a/tests/test_models/test_architectures/test_backbones/test_searchable_shufflenet_v2.py
+++ b/tests/test_models/test_architectures/test_backbones/test_searchable_shufflenet_v2.py
@@ -124,12 +124,14 @@ def test_searchable_shufflenet_v2_init_weights() -> None:
             bias_tensor *= 0.0001
             assert torch.equal(bias_tensor, m.bias)
 
-    checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')
+    temp_dir = tempfile.mkdtemp()
+    checkpoint_path = os.path.join(temp_dir, 'checkpoint.pth')
     backbone_cfg = copy.deepcopy(BACKBONE_CFG)
+    backbone = MODELS.build(backbone_cfg)
+    torch.save(backbone.state_dict(), checkpoint_path)
     backbone_cfg['init_cfg'] = dict(
         type='Pretrained', checkpoint=checkpoint_path)
     backbone = MODELS.build(backbone_cfg)
-    torch.save(backbone.state_dict(), checkpoint_path)
 
     name2weight = dict()
     for name, m in backbone.named_modules():
diff --git a/tests/test_models/test_mutators/test_diff_mutator.py b/tests/test_models/test_mutators/test_diff_mutator.py
index be85547f..b940b8a4 100644
--- a/tests/test_models/test_mutators/test_diff_mutator.py
+++ b/tests/test_models/test_mutators/test_diff_mutator.py
@@ -3,7 +3,6 @@ from unittest import TestCase
 
 import pytest
 import torch.nn as nn
-from mmcls.models import *  # noqa: F401,F403
 
 from mmrazor.models import *  # noqa: F401,F403
 from mmrazor.models.mutables import DiffMutable, DiffOP
@@ -42,6 +41,37 @@ class SearchableModel(nn.Module):
         return self.slayer3(x)
 
 
+class SearchableLayerAlias(nn.Module):
+
+    def __init__(self, mutable_cfg: dict) -> None:
+        super().__init__()
+        mutable_cfg.update(alias='op1')
+        self.op1 = MODELS.build(mutable_cfg)
+        mutable_cfg.update(alias='op2')
+        self.op2 = MODELS.build(mutable_cfg)
+        mutable_cfg.update(alias='op3')
+        self.op3 = MODELS.build(mutable_cfg)
+
+    def forward(self, x):
+        x = self.op1(x)
+        x = self.op2(x)
+        return self.op3(x)
+
+
+class SearchableModelAlias(nn.Module):
+
+    def __init__(self, mutable_cfg: dict) -> None:
+        super().__init__()
+        self.slayer1 = SearchableLayerAlias(mutable_cfg)
+        self.slayer2 = SearchableLayerAlias(mutable_cfg)
+        self.slayer3 = SearchableLayerAlias(mutable_cfg)
+
+    def forward(self, x):
+        x = self.slayer1(x)
+        x = self.slayer2(x)
+        return self.slayer3(x)
+
+
 class TestDiffMutator(TestCase):
 
     def setUp(self):
@@ -107,6 +137,81 @@ class TestDiffMutator(TestCase):
         with pytest.raises(AssertionError):
             mutator.prepare_from_supernet(model)
 
+    def test_diff_mutator_diffop_alias(self) -> None:
+        model = SearchableModelAlias(self.MUTABLE_CFG)
+
+        mutator_cfg = self.MUTATOR_CFG.copy()
+        mutator_cfg['custom_group'] = [['op1'], ['op2'],
+                                       ['op3']]
+        mutator: DiffOP = MODELS.build(mutator_cfg)
+
+        mutator.prepare_from_supernet(model)
+
+        assert list(mutator.search_group.keys()) == [0, 1, 2]
+
+        mutator.modify_supernet_forward()
+        assert mutator.mutable_class_type == DiffMutable
+
+    def test_diff_mutator_alias_module_name(self) -> None:
+        """Using both alias and module name for grouping."""
+        model = SearchableModelAlias(self.MUTABLE_CFG)
+
+        mutator_cfg = self.MUTATOR_CFG.copy()
+        mutator_cfg['custom_group'] = [['op1'],
+                                       [
+                                           'slayer1.op2', 'slayer2.op2',
+                                           'slayer3.op2'
+                                       ], ['slayer1.op3', 'slayer2.op3']]
+        mutator: DiffOP = MODELS.build(mutator_cfg)
+
+        mutator.prepare_from_supernet(model)
+
+        assert list(mutator.search_group.keys()) == [0, 1, 2, 3]
+
+        mutator.modify_supernet_forward()
+        assert mutator.mutable_class_type == DiffMutable
+
+    def test_diff_mutator_duplicate_keys(self) -> None:
+        model = SearchableModel(self.MUTABLE_CFG)
+
+        mutator_cfg = self.MUTATOR_CFG.copy()
+        mutator_cfg['custom_group'] = [
+            ['slayer1.op1', 'slayer2.op1', 'slayer3.op1'],
+            ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
+            ['slayer1.op3', 'slayer2.op3', 'slayer2.op3'],
+        ]
+        mutator: DiffOP = MODELS.build(mutator_cfg)
+
+        with pytest.raises(AssertionError):
+            mutator.prepare_from_supernet(model)
+
+    def test_diff_mutator_duplicate_key_alias(self) -> None:
+        model = SearchableModelAlias(self.MUTABLE_CFG)
+
+        mutator_cfg = self.MUTATOR_CFG.copy()
+        mutator_cfg['custom_group'] = [
+            ['op1', 'slayer1.op1', 'slayer2.op1', 'slayer3.op1'],
+            ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
+            ['slayer1.op3', 'slayer2.op3', 'slayer3.op3'],
+        ]
+        mutator: DiffOP = MODELS.build(mutator_cfg)
+
+        with pytest.raises(AssertionError):
+            mutator.prepare_from_supernet(model)
+
+    def test_diff_mutator_illegal_key(self) -> None:
+        model = SearchableModel(self.MUTABLE_CFG)
+
+        mutator_cfg = self.MUTATOR_CFG.copy()
+        mutator_cfg['custom_group'] = [
+            ['illegal_key', 'slayer1.op1', 'slayer2.op1', 'slayer3.op1'],
+            ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
+            ['slayer1.op3', 'slayer2.op3', 'slayer3.op3'],
+        ]
+        mutator: DiffOP = MODELS.build(mutator_cfg)
+
+        with pytest.raises(AssertionError):
+            mutator.prepare_from_supernet(model)
+
 
 if __name__ == '__main__':
     import unittest
diff --git a/tests/test_models/test_mutators/test_one_shot_mutator.py b/tests/test_models/test_mutators/test_one_shot_mutator.py
index be565577..b3f0db8a 100644
--- a/tests/test_models/test_mutators/test_one_shot_mutator.py
+++ b/tests/test_models/test_mutators/test_one_shot_mutator.py
@@ -13,17 +13,17 @@ from mmrazor.registry import MODELS
 MODEL_CFG = dict(
     type='mmcls.ImageClassifier',
     backbone=dict(
-        type='ResNet',
+        type='mmcls.ResNet',
         depth=50,
         num_stages=4,
         out_indices=(3, ),
         style='pytorch'),
-    neck=dict(type='GlobalAveragePooling'),
+    neck=dict(type='mmcls.GlobalAveragePooling'),
     head=dict(
-        type='LinearClsHead',
+        type='mmcls.LinearClsHead',
         num_classes=1000,
         in_channels=2048,
-        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        loss=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0),
         topk=(1, 5),
     ))
 
@@ -116,3 +116,7 @@
     mutator = MODELS.build(mutator_cfg)
     with pytest.raises(AssertionError):
         mutator.prepare_from_supernet(model)
+
+
+if __name__ == '__main__':
+    pytest.main()
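
Reviewer note: a minimal usage sketch of the alias-based grouping introduced above, pieced together from this PR's own docstring example and test fixtures. `ToyLayer`, the candidate-op names and the channel sizes are illustrative placeholders, not part of the patch; the `DiffOP`/`DiffMutator` registry names and the `alias`/`custom_group` fields are taken directly from the diff.

# A minimal sketch, assuming the DiffOP/DiffMutator classes from this PR
# are registered in MODELS; ToyLayer and the configs below are made up.
import torch.nn as nn

from mmrazor.registry import MODELS

# nn.Conv2d registered under a custom name, as done in test_dartsbackbone.py.
MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)

mutable_cfg = dict(
    type='DiffOP',
    # Per-candidate kwargs; module_kwargs supplies the shared ones.
    candidate_ops=dict(
        conv3x3=dict(type='torchConv2d', kernel_size=3, padding=1),
        conv5x5=dict(type='torchConv2d', kernel_size=5, padding=2)),
    module_kwargs=dict(in_channels=8, out_channels=8))


class ToyLayer(nn.Module):
    """Three mutables: op1/op2 share alias 'a1', op3 uses alias 'a2'."""

    def __init__(self):
        super().__init__()
        mutable_cfg.update(alias='a1')
        self.op1 = MODELS.build(mutable_cfg)
        self.op2 = MODELS.build(mutable_cfg)
        mutable_cfg.update(alias='a2')
        self.op3 = MODELS.build(mutable_cfg)

    def forward(self, x):
        return self.op3(self.op2(self.op1(x)))


model = ToyLayer()
mutator = MODELS.build(dict(type='DiffMutator', custom_group=[['a1'], ['a2']]))
mutator.prepare_from_supernet(model)
# op1 and op2 share alias 'a1', so they land in one search group and share
# a single architecture parameter; op3 gets its own group:
# {0: [op1, op2], 1: [op3]}
print(mutator.search_group)
mutator.modify_supernet_forward()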
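
And a condensed end-to-end sketch of driving the relocated `DartsBackbone` with a `DiffMutator`, following `test_dartsbackbone.py` above. Two candidate ops are used instead of the test's three; relying on `custom_group=None` assumes the per-edge aliases that `Node` now sets make same-named edges share one group, which the tests exercise only implicitly.

import torch
import torch.nn as nn

from mmrazor.registry import MODELS

MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)

backbone = MODELS.build(
    dict(
        type='mmrazor.DartsBackbone',
        in_channels=3,
        base_channels=16,
        num_layers=8,
        num_nodes=4,
        stem_multiplier=3,
        out_indices=(7, ),
        # Every edge becomes a DiffOP over these candidates ...
        mutable_cfg=dict(
            type='DiffOP',
            candidate_ops=dict(
                conv3x3=dict(type='torchConv2d', kernel_size=3, padding=1),
                conv5x5=dict(type='torchConv2d', kernel_size=5, padding=2))),
        # ... and every node routes its input edges via DiffChoiceRoute.
        route_cfg=dict(type='DiffChoiceRoute', with_arch_param=True)))

# With custom_group=None, grouping falls back to the edge aliases that
# DartsBackbone sets per edge_id, so the same edge in different cells
# shares one architecture parameter.
mutator = MODELS.build(dict(type='DiffMutator', custom_group=None))
mutator.prepare_from_supernet(backbone)
mutator.modify_supernet_forward()

outs = backbone(torch.randn(4, 3, 224, 224))
print([o.shape for o in outs])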