Refactor DartsBackbone

pull/198/head
PJDong 2022-06-17 04:05:41 +00:00 committed by pppppM
parent 577a3a2a94
commit 56afc69d85
13 changed files with 628 additions and 130 deletions

View File

@ -6,4 +6,8 @@ from .mmdet import MMDetArchitecture
from .mmseg import MMSegArchitecture
from .utils import * # noqa: F401,F403
__all__ = ['MMClsArchitecture', 'MMDetArchitecture', 'MMSegArchitecture']
__all__ = [
'MMClsArchitecture',
'MMDetArchitecture',
'MMSegArchitecture',
]

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .darts_backbone import DartsBackbone
from .searchable_mobilenet import SearchableMobileNet
from .searchable_shufflenet_v2 import SearchableShuffleNetV2
__all__ = ['SearchableMobileNet', 'SearchableShuffleNetV2']
__all__ = ['SearchableMobileNet', 'SearchableShuffleNetV2', 'DartsBackbone']

View File

@ -1,22 +1,32 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_norm_layer
from torch import Tensor
from mmrazor.registry import MODELS
from ...utils import Placeholder
class FactorizedReduce(nn.Module):
"""Reduce feature map size by factorized pointwise (stride=2)."""
"""Reduce feature map size by factorized pointwise (stride=2).
def __init__(self,
in_channels,
out_channels,
act_cfg=dict(type='ReLU'),
norm_cfg=dict(type='BN')):
Args:
in_channels (int): number of channels of input tensor.
out_channels (int): number of channels of output tensor.
act_cfg (Dict): config to build activation layer.
norm_cfg (Dict): config to build normalization layer.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
act_cfg: Dict = dict(type='ReLU'),
norm_cfg: Dict = dict(type='BN')
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@ -39,7 +49,8 @@ class FactorizedReduce(nn.Module):
bias=False)
self.bn = build_norm_layer(self.norm_cfg, self.out_channels)[1]
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
"""Forward with factorized reduce."""
x = self.relu(x)
out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
@ -47,18 +58,31 @@ class FactorizedReduce(nn.Module):
class StandardConv(nn.Module):
"""
Standard conv: ReLU - Conv - BN
"""Standard Convolution in Darts. Basic structure is ReLU-Conv-BN.
Args:
in_channels (int): number of channels of input tensor.
out_channels (int): number of channels of output tensor.
kernel_size (Union[int, Tuple]): size of the convolving kernel.
stride (Union[int, Tuple]): controls the stride for the
cross-correlation, a single number or a one-element tuple.
Defaults to 1.
padding (Union[str, int, Tuple]): Padding added to both sides
of the input. Defaults to 0.
act_cfg (Dict): config to build activation layer.
norm_cfg (Dict): config to build normalization layer.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
act_cfg=dict(type='ReLU'),
norm_cfg=dict(type='BN')):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple],
stride: Union[int, Tuple] = 1,
padding: Union[str, int, Tuple] = 0,
act_cfg: Dict = dict(type='ReLU'),
norm_cfg: Dict = dict(type='BN')
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@ -78,14 +102,26 @@ class StandardConv(nn.Module):
bias=False),
build_norm_layer(self.norm_cfg, self.out_channels)[1])
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
"""Forward the standard convolution."""
return self.net(x)
class Node(nn.Module):
"""Node structure of DARTS.
def __init__(self, node_id, num_prev_nodes, channels,
num_downsample_nodes):
Args:
node_id (str): key of the node.
num_prev_nodes (int): number of previous nodes.
channels (int): number of channels of current node.
num_downsample_nodes (int): number of downsample nodes.
mutable_cfg (Dict): config of `DiffMutable`.
route_cfg (Dict): config of `DiffChoiceRoute`.
"""
def __init__(self, node_id: str, num_prev_nodes: int, channels: int,
num_downsample_nodes: int, mutable_cfg: Dict,
route_cfg: Dict) -> None:
super().__init__()
edges = nn.ModuleDict()
for i in range(num_prev_nodes):
@ -95,35 +131,58 @@ class Node(nn.Module):
stride = 1
edge_id = '{}_p{}'.format(node_id, i)
edges.add_module(
edge_id,
nn.Sequential(
Placeholder(
group='node',
space_id=edge_id,
choice_args=dict(
stride=stride,
in_channels=channels,
out_channels=channels)), ))
self.edges = Placeholder(
group='node_edge', space_id=node_id, choices=edges)
module_kwargs = dict(
in_channels=channels,
out_channels=channels,
stride=stride,
)
def forward(self, prev_nodes):
mutable_cfg.update(module_kwargs=module_kwargs)
mutable_cfg.update(alias=edge_id)
edges.add_module(edge_id, MODELS.build(mutable_cfg))
route_cfg.update(edges=edges)
self.edges = MODELS.build(route_cfg)
def forward(self, prev_nodes: Union[List[Tensor],
Tuple[Tensor]]) -> Tensor:
"""Forward with the previous nodes list."""
return self.edges(prev_nodes)
class Cell(nn.Module):
"""Darts cell structure.
Args:
num_nodes (int): number of nodes.
channels (int): number of channels of current cell.
prev_channels (int): number of channels of the previous input.
prev_prev_channels (int): number of channels of the previous previous input.
reduction (bool): whether to reduce the feature map size.
prev_reduction (bool): whether to reduce the previous feature map size.
mutable_cfg (Dict): config of `DiffMutable`.
route_cfg (Dict): config of `DiffChoiceRoute`.
act_cfg (Dict): config to build activation layer.
Defaults to dict(type='ReLU').
norm_cfg (Dict): config to build normalization layer.
Defaults to dict(type='BN').
"""
def __init__(
self,
num_nodes: int,
channels: int,
prev_channels: int,
prev_prev_channels: int,
reduction: bool,
prev_reduction: bool,
mutable_cfg: Dict,
route_cfg: Dict,
act_cfg: Dict = dict(type='ReLU'),
norm_cfg: Dict = dict(type='BN'),
) -> None:
def __init__(self,
num_nodes,
channels,
prev_channels,
prev_prev_channels,
reduction,
prev_reduction,
act_cfg=dict(type='ReLU'),
norm_cfg=dict(type='BN')):
super().__init__()
self.act_cfg = act_cfg
self.norm_cfg = norm_cfg
@ -152,11 +211,12 @@ class Cell(nn.Module):
node_id = f'normal_n{depth}'
num_downsample_nodes = 0
self.nodes.append(
Node(node_id, depth, channels, num_downsample_nodes))
Node(node_id, depth, channels, num_downsample_nodes,
mutable_cfg, route_cfg))
def forward(self, s0, s1):
# s0, s1 are the outputs of previous previous cell and previous cell,
# respectively.
def forward(self, s0: Tensor, s1: Tensor) -> Tensor:
"""Forward with the outputs of previous previous cell and previous
cell."""
tensors = [self.preproc0(s0), self.preproc1(s1)]
for node in self.nodes:
cur_tensor = node(tensors)
@ -167,14 +227,21 @@ class Cell(nn.Module):
class AuxiliaryModule(nn.Module):
"""Auxiliary head in 2/3 place of network to let the gradient flow well."""
"""Auxiliary head in 2/3 place of network to let the gradient flow well.
Args:
in_channels (int): number of channels of inputs.
base_channels (int): number of middle channels of the auxiliary module.
out_channels (int): number of channels of outputs.
norm_cfg (Dict): config to build normalization layer.
Defaults to dict(type='BN').
"""
def __init__(self,
in_channels,
base_channels,
out_channels,
norm_cfg=dict(type='BN')):
in_channels: int,
base_channels: int,
out_channels: int,
norm_cfg: Dict = dict(type='BN')) -> None:
super().__init__()
self.norm_cfg = norm_cfg
self.net = nn.Sequential(
@ -189,25 +256,56 @@ class AuxiliaryModule(nn.Module):
build_norm_layer(self.norm_cfg, out_channels)[1],
nn.ReLU(inplace=True))
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
"""Forward the auxiliary module."""
return self.net(x)
@MODELS.register_module()
class DartsBackbone(nn.Module):
"""Backbone of Differentiable Architecture Search (DARTS).
def __init__(self,
in_channels,
base_channels,
num_layers=8,
num_nodes=4,
stem_multiplier=3,
out_indices=(7, ),
auxliary=False,
aux_channels=None,
aux_out_channels=None,
act_cfg=dict(type='ReLU'),
norm_cfg=dict(type='BN')):
Args:
in_channels (int): number of channels of input tensor.
base_channels (int): number of middle channels.
mutable_cfg (Dict): config of `DiffMutable`.
route_cfg (Dict): config of `DiffChoiceRoute`.
num_layers (int): number of layers (cells).
Defaults to 8.
num_nodes (int): number of nodes per cell.
Defaults to 4.
stem_multiplier (int): multiplier for stem channels.
Defaults to 3.
out_indices (tuple, optional): output indices of cells.
Defaults to (7, ).
auxliary (bool, optional): whether to use the auxiliary module.
Defaults to False.
aux_channels (Optional[int]): number of middle channels of
auxliary module. Defaults to None.
aux_out_channels (Optional[int]): number of output channels of
auxliary module. Defaults to None.
act_cfg (Dict): config to build activation layer.
Defaults to dict(type='ReLU').
norm_cfg (Dict): config to build normalization layer.
Defaults to dict(type='BN').
"""
def __init__(
self,
in_channels: int,
base_channels: int,
mutable_cfg: Dict,
route_cfg: Dict,
num_layers: int = 8,
num_nodes: int = 4,
stem_multiplier: int = 3,
out_indices: Union[Tuple, List] = (7, ),
auxliary: bool = False,
aux_channels: Optional[int] = None,
aux_out_channels: Optional[int] = None,
act_cfg: Dict = dict(type='ReLU'),
norm_cfg: Dict = dict(type='BN'),
) -> None:
super().__init__()
self.in_channels = in_channels
@ -237,7 +335,7 @@ class DartsBackbone(nn.Module):
build_norm_layer(self.norm_cfg, self.out_channels)[1])
# for the first cell, stem is used for both s0 and s1
# [!] prev_prev_channels and prev_channels is output channel size,
# prev_prev_channels and prev_channels is output channel size,
# but c_cur is input channel size.
prev_prev_channels = self.out_channels
prev_channels = self.out_channels
@ -255,7 +353,7 @@ class DartsBackbone(nn.Module):
cell = Cell(self.num_nodes, self.out_channels, prev_channels,
prev_prev_channels, reduction, prev_reduction,
self.act_cfg, self.norm_cfg)
mutable_cfg, route_cfg, self.act_cfg, self.norm_cfg)
self.cells.append(cell)
prev_prev_channels = prev_channels
@ -267,7 +365,8 @@ class DartsBackbone(nn.Module):
self.aux_out_channels,
self.norm_cfg)
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
"""Forward the darts backbone."""
outs = []
s0 = s1 = self.stem(x)
for i, cell in enumerate(self.cells):
@ -276,7 +375,6 @@ class DartsBackbone(nn.Module):
outs.append(s1)
if i == self.auxliary_indice and self.training:
aux_feature = self.auxliary_module(s1)
outs.insert(0, aux_feature)
return tuple(outs)
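
After this refactor each cell edge is a real `DiffOP` built from `mutable_cfg` (the Node injects `in_channels`/`out_channels`/`stride` via `module_kwargs` and uses the edge id as `alias`), and every node routes its inputs through a `DiffChoiceRoute` built from `route_cfg`. A minimal usage sketch, mirroring the unit test added later in this commit; the `torchConv2d` registration, candidate names and input shape are taken from that test rather than from the backbone itself:

import torch
import torch.nn as nn
from mmrazor.models import *  # noqa: F401,F403
from mmrazor.models.architectures.components import *  # noqa: F401,F403
from mmrazor.registry import MODELS

# expose a plain conv so that DiffOP candidates can be built from config
MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)

mutable_cfg = dict(
    type='DiffOP',
    candidate_ops=dict(
        conv_3x3=dict(type='torchConv2d', kernel_size=3, padding=1),
        conv_5x5=dict(type='torchConv2d', kernel_size=5, padding=2)))
route_cfg = dict(type='DiffChoiceRoute', with_arch_param=True)

backbone = MODELS.build(
    dict(
        type='mmrazor.DartsBackbone',
        in_channels=3,
        base_channels=16,
        num_layers=8,
        num_nodes=4,
        out_indices=(7, ),
        mutable_cfg=mutable_cfg,
        route_cfg=route_cfg))

# a DiffMutator groups the edges (by alias) and attaches the arch params
mutator = MODELS.build(dict(type='DiffMutator', custom_group=None))
mutator.prepare_from_supernet(backbone)
mutator.modify_supernet_forward()

outs = backbone(torch.randn(1, 3, 224, 224))  # tuple with one feature map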

View File

@ -1,6 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .darts_backbone import DartsBackbone
from .searchable_mobilenet import SearchableMobileNet
from .searchable_shufflenet_v2 import SearchableShuffleNetV2
__all__ = ['DartsBackbone', 'SearchableShuffleNetV2', 'SearchableMobileNet']
__all__ = ['SearchableShuffleNetV2', 'SearchableMobileNet']

View File

@ -25,6 +25,7 @@ class BaseMutable(BaseModule, ABC, Generic[CHOICE_TYPE, CHOSEN_TYPE]):
Args:
module_kwargs (dict[str, dict], optional): Module initialization named
arguments. Defaults to None.
alias (str, optional): alias of the `MUTABLE`.
init_cfg (dict, optional): initialization configuration dict for
``BaseModule``. OpenMMLab has implement 5 initializer including
`Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
@ -33,10 +34,12 @@ class BaseMutable(BaseModule, ABC, Generic[CHOICE_TYPE, CHOSEN_TYPE]):
def __init__(self,
module_kwargs: Optional[Dict[str, Dict]] = None,
alias: Optional[str] = None,
init_cfg: Optional[Dict] = None) -> None:
super().__init__(init_cfg=init_cfg)
self.module_kwargs = module_kwargs
self.alias = alias
self._is_fixed = False
self._current_choice: Optional[CHOICE_TYPE] = None

View File

@ -20,6 +20,7 @@ class DiffMutable(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE]):
Args:
module_kwargs (dict[str, dict], optional): Module initialization named
arguments. Defaults to None.
alias (str, optional): alias of the `MUTABLE`.
init_cfg (dict, optional): initialization configuration dict for
``BaseModule``. OpenMMLab has implement 5 initializer including
`Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
@ -31,8 +32,10 @@ class DiffMutable(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE]):
def __init__(self,
module_kwargs: Optional[Dict[str, Dict]] = None,
alias: Optional[str] = None,
init_cfg: Optional[Dict] = None) -> None:
super().__init__(module_kwargs=module_kwargs, init_cfg=init_cfg)
super().__init__(
module_kwargs=module_kwargs, alias=alias, init_cfg=init_cfg)
def forward(self,
x: Any,
@ -103,6 +106,7 @@ class DiffOP(DiffMutable[str, str]):
operations.
module_kwargs (dict[str, dict], optional): Module initialization named
arguments. Defaults to None.
alias (str, optional): alias of the `MUTABLE`.
init_cfg (dict, optional): initialization configuration dict for
``BaseModule``. OpenMMLab has implement 5 initializer including
`Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
@ -113,9 +117,11 @@ class DiffOP(DiffMutable[str, str]):
self,
candidate_ops: Dict[str, Dict],
module_kwargs: Optional[Dict[str, Dict]] = None,
alias: Optional[str] = None,
init_cfg: Optional[Dict] = None,
) -> None:
super().__init__(module_kwargs=module_kwargs, init_cfg=init_cfg)
super().__init__(
module_kwargs=module_kwargs, alias=alias, init_cfg=init_cfg)
assert len(candidate_ops) >= 1, \
f'Number of candidate op must greater than or equal to 1, ' \
f'but got: {len(candidate_ops)}'

View File

@ -24,6 +24,7 @@ class OneShotMutable(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE]):
Args:
module_kwargs (dict[str, dict], optional): Module initialization named
arguments. Defaults to None.
alias (str, optional): alias of the `MUTABLE`.
init_cfg (dict, optional): initialization configuration dict for
``BaseModule``. OpenMMLab has implement 5 initializer including
`Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
@ -35,8 +36,10 @@ class OneShotMutable(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE]):
def __init__(self,
module_kwargs: Optional[Dict[str, Dict]] = None,
alias: Optional[str] = None,
init_cfg: Optional[Dict] = None) -> None:
super().__init__(module_kwargs=module_kwargs, init_cfg=init_cfg)
super().__init__(
module_kwargs=module_kwargs, alias=alias, init_cfg=init_cfg)
def forward(self, x: Any) -> Any:
"""Calls either :func:`forward_fixed` or :func:`forward_choice`
@ -107,6 +110,7 @@ class OneShotOP(OneShotMutable[str, str]):
operations.
module_kwargs (dict[str, dict], optional): Module initialization named
arguments. Defaults to None.
alias (str, optional): alias of the `MUTABLE`.
init_cfg (dict, optional): initialization configuration dict for
``BaseModule``. OpenMMLab has implement 5 initializer including
`Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
@ -153,9 +157,14 @@ class OneShotOP(OneShotMutable[str, str]):
self,
candidate_ops: Union[Dict[str, Dict], nn.ModuleDict],
module_kwargs: Optional[Dict[str, Dict]] = None,
alias: Optional[str] = None,
init_cfg: Optional[Dict] = None,
) -> None:
super().__init__(module_kwargs=module_kwargs, init_cfg=init_cfg)
super().__init__(
module_kwargs=module_kwargs, alias=alias, init_cfg=init_cfg)
assert len(candidate_ops) >= 1, \
f'Number of candidate op must greater than 1, ' \
f'but got: {len(candidate_ops)}'
self._is_fixed = False
self._chosen: Optional[str] = None
@ -284,6 +293,7 @@ class OneShotProbOP(OneShotOP):
candidate operation.
module_kwargs (dict[str, dict], optional): Module initialization named
arguments. Defaults to None.
alias (str, optional): alias of the `MUTABLE`.
init_cfg (dict, optional): initialization configuration dict for
``BaseModule``. OpenMMLab has implement 5 initializer including
`Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
@ -294,10 +304,12 @@ class OneShotProbOP(OneShotOP):
candidate_ops: Dict[str, Dict],
choice_probs: list = None,
module_kwargs: Optional[Dict[str, Dict]] = None,
alias: Optional[str] = None,
init_cfg: Optional[Dict] = None) -> None:
super().__init__(
candidate_ops=candidate_ops,
module_kwargs=module_kwargs,
alias=alias,
init_cfg=init_cfg)
assert choice_probs is not None
assert sum(choice_probs) - 1 < np.finfo(np.float64).eps, \

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from collections import Counter
from typing import Dict, Generic, List, Optional, Type, TypeVar
from mmcv.runner import BaseModule
@ -98,7 +99,7 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]):
supernet (:obj:`torch.nn.Module`): The supernet to be searched
in your algorithm.
"""
return self._build_search_group(supernet)
self._build_search_group(supernet)
@property
def search_group(self) -> Dict[int, List[MUTABLE_TYPE]]:
@ -119,43 +120,193 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]):
'Call `prepare_from_supernet` before access search group!')
return self._search_group
def _build_name_mutable_mapping(
self, supernet: Module) -> Dict[str, MUTABLE_TYPE]:
"""Mapping module name to mutable."""
name2mutable: Dict[str, MUTABLE_TYPE] = dict()
for name, module in supernet.named_modules():
if isinstance(module, self.mutable_class_type):
name2mutable[name] = module
return name2mutable
def _build_alias_names_mapping(self,
supernet: Module) -> Dict[str, List[str]]:
"""Mapping alias to module names."""
alias2mutable_names: Dict[str, List[str]] = dict()
for name, module in supernet.named_modules():
if isinstance(module, self.mutable_class_type):
if module.alias is not None:
if module.alias not in alias2mutable_names:
alias2mutable_names[module.alias] = [name]
else:
alias2mutable_names[module.alias].append(name)
return alias2mutable_names
def _build_search_group(self, supernet: Module) -> None:
"""Build search group with ``supernet`` and ``custom_group``.
"""Build search group with ``custom_group`` and ``alias``(see more
information in :class:`BaseMutable`). Grouping by alias and module name
are both supported.
Note:
Apart from user-defined search group, all other searchable
modules(mutable) will be grouped separately.
The main difference between using alias and module name for
grouping is that the alias is One-to-Many while the module
name is One-to-One.
When using both alias and module name in `custom_group`, the
priority of alias is higher than that of module name.
If alias is set in `custom_group`, then its corresponding module
name should not be in the `custom_group`.
Moreover, there should be no duplicate keys in the `custom_group`.
Example:
>>> import torch
>>> from mmrazor.models.mutables.diff_mutable import DiffOP
>>> # Assume that a toy model consists of three mutables
>>> # whose names are op1, op2, op3. The corresponding
>>> # alias names of the three mutables are a1, a1, a2.
>>> model = ToyModel()
>>> # Using alias for grouping
>>> mutator = DiffMutator(custom_group=[['a1'], ['a2']])
>>> mutator.prepare_from_supernet(model)
>>> mutator.search_group
{0: [op1, op2], 1: [op3]}
>>> # Using module name for grouping
>>> mutator = DiffMutator(custom_group=[['op1', 'op2'], ['op3']])
>>> mutator.prepare_from_supernet(model)
>>> mutator.search_group
{0: [op1, op2], 1: [op3]}
>>> # Using both alias and module name for grouping
>>> mutator = DiffMutator(custom_group=[['a2'], ['op2']])
>>> mutator.prepare_from_supernet(model)
>>> # The last operation would be grouped
>>> mutator.search_group
{0: [op3], 1: [op2], 2: [op1]}
Args:
supernet (:obj:`torch.nn.Module`): The supernet to be searched
in your algorithm.
"""
module_name2module: Dict[str, Module] = dict()
for name, module in supernet.named_modules():
module_name2module[name] = module
name2mutable = self._build_name_mutable_mapping(supernet)
alias2mutable_names = self._build_alias_names_mapping(supernet)
# Map module to group id for user-defined group
self.module_name2group_id: Dict[str, int] = dict()
for idx, group in enumerate(self._custom_group):
for module_name in group:
assert module_name in module_name2module, \
f'`{module_name}` is not a module name of supernet, ' \
f'expected module names: {module_name2module.keys()}'
self.module_name2group_id[module_name] = idx
# Check whether the custom group is valid
if len(self._custom_group) > 0:
self._check_valid_groups(alias2mutable_names, name2mutable,
self._custom_group)
search_group: Dict[int, List[MUTABLE_TYPE]] = dict()
current_group_nums = len(self._custom_group)
# Construct search_groups based on user-defined group
search_groups: Dict[int, List[MUTABLE_TYPE]] = dict()
for name, module in module_name2module.items():
if isinstance(module, self.mutable_class_type):
group_id = self.module_name2group_id.get(name)
if group_id is None:
group_id = current_group_nums
current_group_nums = 0
grouped_mutable_names: List[str] = list()
grouped_alias: List[str] = list()
for group in self._custom_group:
group_mutables = list()
for item in group:
if item in alias2mutable_names:
# if the item is from alias name
mutable_names: List[str] = alias2mutable_names[item]
grouped_alias.append(item)
group_mutables.extend(
[name2mutable[n] for n in mutable_names])
grouped_mutable_names.extend(mutable_names)
else:
# if the item is in name2mutable
group_mutables.append(name2mutable[item])
grouped_mutable_names.append(item)
search_groups[current_group_nums] = group_mutables
current_group_nums += 1
# Construct search_groups based on alias
for alias, mutable_names in alias2mutable_names.items():
if alias not in grouped_alias:
# Check whether all current names are already grouped
flag_all_grouped = True
for mutable_name in mutable_names:
if mutable_name not in grouped_mutable_names:
flag_all_grouped = False
# If not all mutables are already grouped
if not flag_all_grouped:
search_groups[current_group_nums] = []
for mutable_name in mutable_names:
if mutable_name not in grouped_mutable_names:
search_groups[current_group_nums].append(
name2mutable[mutable_name])
grouped_mutable_names.append(mutable_name)
current_group_nums += 1
try:
search_group[group_id].append(module)
except KeyError:
search_group[group_id] = [module]
self.module_name2group_id[name] = group_id
self._search_group = search_group
# check whether all the mutable objects are in the search_groups
for name, module in supernet.named_modules():
if isinstance(module, self.mutable_class_type):
if name in grouped_mutable_names:
continue
else:
search_groups[current_group_nums] = [module]
current_group_nums += 1
grouped_counter = Counter(grouped_mutable_names)
# find duplicate keys
duplicate_keys = list()
for key, count in grouped_counter.items():
if count > 1:
duplicate_keys.append(key)
assert len(grouped_mutable_names) == len(
list(set(grouped_mutable_names))), \
'There are duplicate keys in grouped mutable names. ' \
f'The duplicate keys are {duplicate_keys}. ' \
'Please check if there are duplicate keys in the `custom_group`.'
self._search_group = search_groups
def _check_valid_groups(self, alias2mutable_names: Dict[str, List[str]],
name2mutable: Dict[str, MUTABLE_TYPE],
custom_group: List[List[str]]) -> None:
aliases = [*alias2mutable_names.keys()]
module_names = [*name2mutable.keys()]
# check if all keys are legal
expanded_custom_group: List[str] = [
_ for group in custom_group for _ in group
]
legal_keys: List[str] = [*aliases, *module_names]
for key in expanded_custom_group:
if key not in legal_keys:
raise AssertionError(
f'The key: {key} in `custom_group` is not legal. '
f'Legal keys are: {legal_keys}. '
'Make sure that the keys are either alias or mutable name')
# when the mutable has alias attribute, the corresponding module
# name should not be used in `custom_group`.
used_aliases = list()
for group in custom_group:
for key in group:
if key in aliases:
used_aliases.append(key)
for alias_key in used_aliases:
mutable_names: List = alias2mutable_names[alias_key]
# check whether module name is in custom group
for mutable_name in mutable_names:
if mutable_name in expanded_custom_group:
raise AssertionError(
f'When a mutable is set alias attribute :{alias_key},'
f'the corresponding module name {mutable_name} should '
f'not be used in `custom_group` {custom_group}.')
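
In short: keys in `custom_group` may be aliases (one-to-many) or module names (one-to-one), any mutable not covered by `custom_group` ends up in its own group, and duplicate or conflicting keys are rejected by `_check_valid_groups`. A toy restatement of that grouping priority on plain dicts, not the real mutator API (it omits the validation and the alias auto-grouping details):

from typing import Dict, List

def build_search_groups(custom_group: List[List[str]],
                        alias2names: Dict[str, List[str]],
                        all_names: List[str]) -> Dict[int, List[str]]:
    """Aliases expand one-to-many, names map one-to-one, leftovers get
    their own group."""
    groups: Dict[int, List[str]] = {}
    grouped: List[str] = []
    gid = 0
    for group in custom_group:
        members: List[str] = []
        for key in group:
            if key in alias2names:   # alias key: one-to-many
                members.extend(alias2names[key])
                grouped.extend(alias2names[key])
            else:                    # module name key: one-to-one
                members.append(key)
                grouped.append(key)
        groups[gid] = members
        gid += 1
    for name in all_names:           # ungrouped mutables
        if name not in grouped:
            groups[gid] = [name]
            gid += 1
    return groups

# op1/op2 share alias 'a1', op3 has alias 'a2' (as in the docstring example)
print(build_search_groups([['a1']],
                          {'a1': ['op1', 'op2'], 'a2': ['op3']},
                          ['op1', 'op2', 'op3']))
# -> {0: ['op1', 'op2'], 1: ['op3']}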

View File

@ -12,6 +12,10 @@ from .base_mutator import ArchitectureMutator
class DiffMutator(ArchitectureMutator[DiffMutable]):
"""Differentiable mutable based mutator.
`DiffMutator` is responsible for mutating `DiffMutable`, `DiffOP`,
`DiffChoiceRoute` and `GumbelChoiceRoute`. The architecture
parameters of the mutables are retained in `DiffMutator`.
Args:
custom_group (list[list[str]], optional): User-defined search groups.
All searchable modules that are not in ``custom_group`` will be
@ -21,7 +25,7 @@ class DiffMutator(ArchitectureMutator[DiffMutable]):
def __init__(self,
custom_group: Optional[List[List[str]]] = None,
init_cfg: Optional[Dict] = None) -> None:
super().__init__(custom_group, init_cfg)
super().__init__(custom_group=custom_group, init_cfg=init_cfg)
def prepare_from_supernet(self, supernet: nn.Module) -> None:
"""Inherit from ``BaseMutator``'s, generate `arch_params` in DARTS.
@ -32,31 +36,23 @@ class DiffMutator(ArchitectureMutator[DiffMutable]):
"""
super().prepare_from_supernet(supernet)
self.arch_params = self.build_arch_params(supernet)
self.arch_params = self.build_arch_params()
def build_arch_params(self, supernet):
def build_arch_params(self):
"""This function will build many arch params, which are generally used
in differentiable search algorithms, such as Darts' series. Each
group_id corresponds to an arch param, so the Mutables with the same
group_id share the same arch param.
Args:
supernet (:obj:`torch.nn.Module`): The architecture to be used
in your algorithm.
Returns:
torch.nn.ParameterDict: the arch params are got after traversing
the supernet.
torch.nn.ParameterDict: the arch params built according to `search_group`.
"""
arch_params: Dict[int, nn.Parameter] = dict()
for module_name, module in supernet.named_modules():
if isinstance(module, self.mutable_class_type):
group_id = self.module_name2group_id[module_name]
if group_id not in arch_params:
group_arch_param = module.build_arch_param()
if group_arch_param is not None:
arch_params[group_id] = group_arch_param
for group_id, modules in self.search_group.items():
group_arch_param = modules[0].build_arch_param()
arch_params[group_id] = group_arch_param
return arch_params
@ -70,10 +66,8 @@ class DiffMutator(ArchitectureMutator[DiffMutable]):
"""
for group_id, modules in self.search_group.items():
if group_id in self.arch_params.keys():
for module in modules:
module.set_forward_args(
arch_param=self.arch_params[group_id])
for module in modules:
module.set_forward_args(arch_param=self.arch_params[group_id])
@property
def mutable_class_type(self) -> Type[DiffMutable]:
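
Since `build_arch_params` now walks `search_group`, every mutable in a group shares the single architecture parameter built from the group's first member. A minimal sketch of that sharing, assuming the `DiffOP`/`DiffMutator` registry names from this commit; `TinySupernet`, `torchConv2d` and the `shared_edge` alias are made up for illustration:

import torch.nn as nn
from mmrazor.models import *  # noqa: F401,F403
from mmrazor.registry import MODELS

MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)

op_cfg = dict(
    type='DiffOP',
    alias='shared_edge',  # both instances share this alias -> one search group
    # flat kwargs merged into every candidate, mirroring how Node passes them
    module_kwargs=dict(in_channels=8, out_channels=8),
    candidate_ops=dict(
        conv3x3=dict(type='torchConv2d', kernel_size=3, padding=1),
        conv5x5=dict(type='torchConv2d', kernel_size=5, padding=2)))

class TinySupernet(nn.Module):

    def __init__(self):
        super().__init__()
        self.op1 = MODELS.build(op_cfg)
        self.op2 = MODELS.build(op_cfg)

    def forward(self, x):
        return self.op2(self.op1(x))

mutator = MODELS.build(dict(type='DiffMutator', custom_group=None))
supernet = TinySupernet()
mutator.prepare_from_supernet(supernet)  # one group for the shared alias
mutator.modify_supernet_forward()        # op1 and op2 now share one arch_param
assert len(mutator.arch_params) == len(mutator.search_group) == 1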

View File

@ -0,0 +1,118 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
from unittest import TestCase
import torch
import torch.nn as nn
from mmcls.models import * # noqa:F403,F401
from mmrazor.models import * # noqa:F403,F401
from mmrazor.models.architectures.components import * # noqa:F403,F401
from mmrazor.registry import MODELS
MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)
MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True)
MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True)
class TestDartsBackbone(TestCase):
def setUp(self) -> None:
self.mutable_cfg = dict(
type='DiffOP',
candidate_ops=dict(
torch_conv2d_3x3=dict(
type='torchConv2d',
kernel_size=3,
padding=1,
),
torch_conv2d_5x5=dict(
type='torchConv2d',
kernel_size=5,
padding=2,
),
torch_conv2d_7x7=dict(
type='torchConv2d',
kernel_size=7,
padding=3,
),
))
self.route_cfg = dict(
type='DiffChoiceRoute',
with_arch_param=True,
)
self.backbone_cfg = dict(
type='mmrazor.DartsBackbone',
in_channels=3,
base_channels=16,
num_layers=8,
num_nodes=4,
stem_multiplier=3,
out_indices=(7, ),
mutable_cfg=self.mutable_cfg,
route_cfg=self.route_cfg)
self.mutator_cfg = dict(
type='DiffMutator',
custom_group=None,
)
def test_darts_backbone(self):
model = MODELS.build(self.backbone_cfg)
custom_group = self.generate_key(model)
assert model is not None
self.mutable_cfg.update(custom_group=custom_group)
mutator = MODELS.build(self.mutator_cfg)
assert mutator is not None
mutator.prepare_from_supernet(model)
mutator.modify_supernet_forward()
inputs = torch.randn(4, 3, 224, 224)
outputs = model(inputs)
assert outputs is not None
def test_darts_backbone_with_auxliary(self):
self.backbone_cfg.update(
auxliary=True, aux_channels=256, aux_out_channels=512)
model = MODELS.build(self.backbone_cfg)
custom_group = self.generate_key(model)
assert model is not None
self.mutable_cfg.update(custom_group=custom_group)
mutator = MODELS.build(self.mutator_cfg)
assert mutator is not None
mutator.prepare_from_supernet(model)
mutator.modify_supernet_forward()
inputs = torch.randn(4, 3, 224, 224)
outputs = model(inputs)
assert outputs is not None
def generate_key(self, model):
"""auto generate custom group for darts."""
tmp_dict = dict()
for key, _ in model.named_modules():
node_type = key.split('._candidate_ops')[0].split('.')[-1].split(
'_')[0]
if node_type not in ['normal', 'reduce']:
# not supported type
continue
node_name = key.split('._candidate_ops')[0].split('.')[-1]
if node_name not in tmp_dict.keys():
tmp_dict[node_name] = [key.split('._candidate_ops')[0]]
else:
current_key = key.split('._candidate_ops')[0]
if current_key not in tmp_dict[node_name]:
tmp_dict[node_name].append(current_key)
return list(tmp_dict.values())
if __name__ == '__main__':
unittest.main()

View File

@ -124,12 +124,14 @@ def test_searchable_shufflenet_v2_init_weights() -> None:
bias_tensor *= 0.0001
assert torch.equal(bias_tensor, m.bias)
checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')
temp_dir = tempfile.mkdtemp()
checkpoint_path = os.path.join(temp_dir, 'checkpoint.pth')
backbone_cfg = copy.deepcopy(BACKBONE_CFG)
backbone = MODELS.build(backbone_cfg)
torch.save(backbone.state_dict(), checkpoint_path)
backbone_cfg['init_cfg'] = dict(
type='Pretrained', checkpoint=checkpoint_path)
backbone = MODELS.build(backbone_cfg)
torch.save(backbone.state_dict(), checkpoint_path)
name2weight = dict()
for name, m in backbone.named_modules():

View File

@ -3,7 +3,6 @@ from unittest import TestCase
import pytest
import torch.nn as nn
from mmcls.models import * # noqa: F401,F403
from mmrazor.models import * # noqa: F401,F403
from mmrazor.models.mutables import DiffMutable, DiffOP
@ -42,6 +41,37 @@ class SearchableModel(nn.Module):
return self.slayer3(x)
class SearchableLayerAlias(nn.Module):
def __init__(self, mutable_cfg: dict) -> None:
super().__init__()
mutable_cfg.update(alias='op1')
self.op1 = MODELS.build(mutable_cfg)
mutable_cfg.update(alias='op2')
self.op2 = MODELS.build(mutable_cfg)
mutable_cfg.update(alias='op3')
self.op3 = MODELS.build(mutable_cfg)
def forward(self, x):
x = self.op1(x)
x = self.op2(x)
return self.op3(x)
class SearchableModelAlias(nn.Module):
def __init__(self, mutable_cfg: dict) -> None:
super().__init__()
self.slayer1 = SearchableLayerAlias(mutable_cfg)
self.slayer2 = SearchableLayerAlias(mutable_cfg)
self.slayer3 = SearchableLayerAlias(mutable_cfg)
def forward(self, x):
x = self.slayer1(x)
x = self.slayer2(x)
return self.slayer3(x)
class TestDiffMutator(TestCase):
def setUp(self):
@ -107,6 +137,81 @@ class TestDiffMutator(TestCase):
with pytest.raises(AssertionError):
mutator.prepare_from_supernet(model)
def test_diff_mutator_diffop_alias(self) -> None:
model = SearchableModelAlias(self.MUTABLE_CFG)
mutator_cfg = self.MUTATOR_CFG.copy()
mutator_cfg['custom_group'] = [['op1'], ['op2'], ['op3']]
mutator: DiffOP = MODELS.build(mutator_cfg)
mutator.prepare_from_supernet(model)
assert list(mutator.search_group.keys()) == [0, 1, 2]
mutator.modify_supernet_forward()
assert mutator.mutable_class_type == DiffMutable
def test_diff_mutator_alias_module_name(self) -> None:
"""Using both alias and module name for grouping."""
model = SearchableModelAlias(self.MUTABLE_CFG)
mutator_cfg = self.MUTATOR_CFG.copy()
mutator_cfg['custom_group'] = [['op1'],
[
'slayer1.op2', 'slayer2.op2',
'slayer3.op2'
], ['slayer1.op3', 'slayer2.op3']]
mutator: DiffOP = MODELS.build(mutator_cfg)
mutator.prepare_from_supernet(model)
assert list(mutator.search_group.keys()) == [0, 1, 2, 3]
mutator.modify_supernet_forward()
assert mutator.mutable_class_type == DiffMutable
def test_diff_mutator_duplicate_keys(self) -> None:
model = SearchableModel(self.MUTABLE_CFG)
mutator_cfg = self.MUTATOR_CFG.copy()
mutator_cfg['custom_group'] = [
['slayer1.op1', 'slayer2.op1', 'slayer3.op1'],
['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
['slayer1.op3', 'slayer2.op3', 'slayer2.op3'],
]
mutator: DiffOP = MODELS.build(mutator_cfg)
with pytest.raises(AssertionError):
mutator.prepare_from_supernet(model)
def test_diff_mutator_duplicate_key_alias(self) -> None:
model = SearchableModelAlias(self.MUTABLE_CFG)
mutator_cfg = self.MUTATOR_CFG.copy()
mutator_cfg['custom_group'] = [
['op1', 'slayer1.op1', 'slayer2.op1', 'slayer3.op1'],
['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
['slayer1.op3', 'slayer2.op3', 'slayer3.op3'],
]
mutator: DiffOP = MODELS.build(mutator_cfg)
with pytest.raises(AssertionError):
mutator.prepare_from_supernet(model)
def test_diff_mutator_illegal_key(self) -> None:
model = SearchableModel(self.MUTABLE_CFG)
mutator_cfg = self.MUTATOR_CFG.copy()
mutator_cfg['custom_group'] = [
['illegal_key', 'slayer1.op1', 'slayer2.op1', 'slayer3.op1'],
['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
['slayer1.op3', 'slayer2.op3', 'slayer3.op3'],
]
mutator: DiffOP = MODELS.build(mutator_cfg)
with pytest.raises(AssertionError):
mutator.prepare_from_supernet(model)
if __name__ == '__main__':
import unittest

View File

@ -13,17 +13,17 @@ from mmrazor.registry import MODELS
MODEL_CFG = dict(
type='mmcls.ImageClassifier',
backbone=dict(
type='ResNet',
type='mmcls.ResNet',
depth=50,
num_stages=4,
out_indices=(3, ),
style='pytorch'),
neck=dict(type='GlobalAveragePooling'),
neck=dict(type='mmcls.GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
type='mmcls.LinearClsHead',
num_classes=1000,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
loss=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))
@ -87,6 +87,7 @@ def test_one_shot_mutator_mutable_model() -> None:
model = _SearchableModel()
mutator: OneShotMutator = MODELS.build(MUTATOR_CFG)
# import pdb; pdb.set_trace()
mutator.prepare_from_supernet(model)
assert list(mutator.search_group.keys()) == [0, 1, 2]
@ -116,3 +117,7 @@ def test_one_shot_mutator_mutable_model() -> None:
mutator = MODELS.build(mutator_cfg)
with pytest.raises(AssertionError):
mutator.prepare_from_supernet(model)
if __name__ == '__main__':
pytest.main()