Refactor DartsBackbone

parent 577a3a2a94
commit 56afc69d85
@@ -6,4 +6,8 @@ from .mmdet import MMDetArchitecture
from .mmseg import MMSegArchitecture
from .utils import * # noqa: F401,F403

__all__ = ['MMClsArchitecture', 'MMDetArchitecture', 'MMSegArchitecture']
__all__ = [
    'MMClsArchitecture',
    'MMDetArchitecture',
    'MMSegArchitecture',
]
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .darts_backbone import DartsBackbone
from .searchable_mobilenet import SearchableMobileNet
from .searchable_shufflenet_v2 import SearchableShuffleNetV2

__all__ = ['SearchableMobileNet', 'SearchableShuffleNetV2']
__all__ = ['SearchableMobileNet', 'SearchableShuffleNetV2', 'DartsBackbone']
@@ -1,22 +1,32 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_norm_layer
from torch import Tensor

from mmrazor.registry import MODELS
from ...utils import Placeholder


class FactorizedReduce(nn.Module):
    """Reduce feature map size by factorized pointwise (stride=2)."""
    """Reduce feature map size by factorized pointwise (stride=2).

    def __init__(self,
                 in_channels,
                 out_channels,
                 act_cfg=dict(type='ReLU'),
                 norm_cfg=dict(type='BN')):
    Args:
        in_channels (int): number of channels of input tensor.
        out_channels (int): number of channels of output tensor.
        act_cfg (Dict): config to build activation layer.
        norm_cfg (Dict): config to build normalization layer.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        act_cfg: Dict = dict(type='ReLU'),
        norm_cfg: Dict = dict(type='BN')
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
@@ -39,7 +49,8 @@ class FactorizedReduce(nn.Module):
            bias=False)
        self.bn = build_norm_layer(self.norm_cfg, self.out_channels)[1]

    def forward(self, x):
    def forward(self, x: Tensor) -> Tensor:
        """Forward with factorized reduce."""
        x = self.relu(x)
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
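A standalone sketch of the factorized-reduce idea shown in this forward (plain torch, not the mmrazor class itself), to make the shape behaviour concrete:

# two stride-2 pointwise convs, one on the input and one on the input shifted
# by one pixel, concatenated along the channel dimension
import torch
import torch.nn as nn

conv1 = nn.Conv2d(16, 16, kernel_size=1, stride=2, bias=False)
conv2 = nn.Conv2d(16, 16, kernel_size=1, stride=2, bias=False)
x = torch.relu(torch.randn(1, 16, 8, 8))
out = torch.cat([conv1(x), conv2(x[:, :, 1:, 1:])], dim=1)
assert out.shape == (1, 32, 4, 4)  # spatial size halved, channels concatenated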
@@ -47,18 +58,31 @@ class FactorizedReduce(nn.Module):


class StandardConv(nn.Module):
    """
    Standard conv: ReLU - Conv - BN
    """Standard Convolution in Darts. Basic structure is ReLU-Conv-BN.

    Args:
        in_channels (int): number of channels of input tensor.
        out_channels (int): number of channels of output tensor.
        kernel_size (Union[int, Tuple]): size of the convolving kernel.
        stride (Union[int, Tuple]): controls the stride for the
            cross-correlation, a single number or a one-element tuple.
            Default to 1.
        padding (Union[str, int, Tuple]): Padding added to both sides
            of the input. Default to 0.
        act_cfg (Dict): config to build activation layer.
        norm_cfg (Dict): config to build normalization layer.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 act_cfg=dict(type='ReLU'),
                 norm_cfg=dict(type='BN')):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple],
        stride: Union[int, Tuple] = 1,
        padding: Union[str, int, Tuple] = 0,
        act_cfg: Dict = dict(type='ReLU'),
        norm_cfg: Dict = dict(type='BN')
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
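A minimal standalone equivalent of the ReLU-Conv-BN structure described above, using the default act/norm configs as an assumption:

import torch
import torch.nn as nn

# same ordering as StandardConv: activation, then conv, then norm
net = nn.Sequential(
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False),
    nn.BatchNorm2d(32))
out = net(torch.randn(2, 16, 8, 8))
assert out.shape == (2, 32, 8, 8)  # padding=1 keeps the spatial size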
@@ -78,14 +102,26 @@ class StandardConv(nn.Module):
                bias=False),
            build_norm_layer(self.norm_cfg, self.out_channels)[1])

    def forward(self, x):
    def forward(self, x: Tensor) -> Tensor:
        """Forward the standard convolution."""
        return self.net(x)


class Node(nn.Module):
    """Node structure of DARTS.

    def __init__(self, node_id, num_prev_nodes, channels,
                 num_downsample_nodes):
    Args:
        node_id (str): key of the node.
        num_prev_nodes (int): number of previous nodes.
        channels (int): number of channels of current node.
        num_downsample_nodes (int): index of downsample node.
        mutable_cfg (Dict): config of `DiffMutable`.
        route_cfg (Dict): config of `DiffChoiceRoute`.
    """

    def __init__(self, node_id: str, num_prev_nodes: int, channels: int,
                 num_downsample_nodes: int, mutable_cfg: Dict,
                 route_cfg: Dict) -> None:
        super().__init__()
        edges = nn.ModuleDict()
        for i in range(num_prev_nodes):
@@ -95,35 +131,58 @@ class Node(nn.Module):
                stride = 1

            edge_id = '{}_p{}'.format(node_id, i)
            edges.add_module(
                edge_id,
                nn.Sequential(
                    Placeholder(
                        group='node',
                        space_id=edge_id,
                        choice_args=dict(
                            stride=stride,
                            in_channels=channels,
                            out_channels=channels)), ))

        self.edges = Placeholder(
            group='node_edge', space_id=node_id, choices=edges)
            module_kwargs = dict(
                in_channels=channels,
                out_channels=channels,
                stride=stride,
            )

    def forward(self, prev_nodes):
            mutable_cfg.update(module_kwargs=module_kwargs)
            mutable_cfg.update(alias=edge_id)
            edges.add_module(edge_id, MODELS.build(mutable_cfg))

        route_cfg.update(edges=edges)
        self.edges = MODELS.build(route_cfg)

    def forward(self, prev_nodes: Union[List[Tensor],
                                        Tuple[Tensor]]) -> Tensor:
        """Forward with the previous nodes list."""
        return self.edges(prev_nodes)


class Cell(nn.Module):
    """Darts cell structure.

    Args:
        num_nodes (int): number of nodes.
        channels (int): number of channels of current cell.
        prev_channels (int): number of channel of previous input.
        prev_prev_channels (int): number of channel of previous previous input.
        reduction (bool): whether to reduce the feature map size.
        prev_reduction (bool): whether to reduce the previous feature map size.
        mutable_cfg (Optional[Dict]): config of `DiffMutable`.
        route_cfg (Optional[Dict]): config of `DiffChoiceRoute`.
        act_cfg (Dict): config to build activation layer.
            Defaults to dict(type='ReLU').
        norm_cfg (Dict): config to build normalization layer.
            Defaults to dict(type='BN').
    """

    def __init__(
        self,
        num_nodes: int,
        channels: int,
        prev_channels: int,
        prev_prev_channels: int,
        reduction: bool,
        prev_reduction: bool,
        mutable_cfg: Dict,
        route_cfg: Dict,
        act_cfg: Dict = dict(type='ReLU'),
        norm_cfg: Dict = dict(type='BN'),
    ) -> None:

    def __init__(self,
                 num_nodes,
                 channels,
                 prev_channels,
                 prev_prev_channels,
                 reduction,
                 prev_reduction,
                 act_cfg=dict(type='ReLU'),
                 norm_cfg=dict(type='BN')):
        super().__init__()
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
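The new Node code replaces the old Placeholder edges: each edge is a `DiffOP` built from `mutable_cfg` with per-edge `module_kwargs` and an `alias`, and the edges are then wrapped by a route built from `route_cfg`. A hedged sketch of what one such edge build amounts to (candidate ops and the `torchConv2d` registration mirror the unit test added in this commit; `normal_n2_p0` is an illustrative edge id):

import copy

import torch.nn as nn

from mmrazor.registry import MODELS

# registered this way in the new unit test below
MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)

mutable_cfg = dict(
    type='DiffOP',
    candidate_ops=dict(
        torch_conv2d_3x3=dict(type='torchConv2d', kernel_size=3, padding=1),
        torch_conv2d_5x5=dict(type='torchConv2d', kernel_size=5, padding=2)))

edge_cfg = copy.deepcopy(mutable_cfg)
edge_cfg.update(
    module_kwargs=dict(in_channels=16, out_channels=16, stride=1),
    alias='normal_n2_p0')  # illustrative edge id
edge = MODELS.build(edge_cfg)  # one searchable edge of the node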
@@ -152,11 +211,12 @@ class Cell(nn.Module):
                node_id = f'normal_n{depth}'
                num_downsample_nodes = 0
            self.nodes.append(
                Node(node_id, depth, channels, num_downsample_nodes))
                Node(node_id, depth, channels, num_downsample_nodes,
                     mutable_cfg, route_cfg))

    def forward(self, s0, s1):
        # s0, s1 are the outputs of previous previous cell and previous cell,
        # respectively.
    def forward(self, s0: Tensor, s1: Tensor) -> Tensor:
        """Forward with the outputs of previous previous cell and previous
        cell."""
        tensors = [self.preproc0(s0), self.preproc1(s1)]
        for node in self.nodes:
            cur_tensor = node(tensors)
@@ -167,14 +227,21 @@ class Cell(nn.Module):


class AuxiliaryModule(nn.Module):
    """Auxiliary head in 2/3 place of network to let the gradient flow well."""
    """Auxiliary head in 2/3 place of network to let the gradient flow well.

    Args:
        in_channels (int): number of channels of inputs.
        base_channels (int): number of middle channels of the auxiliary module.
        out_channels (int): number of channels of outputs.
        norm_cfg (Dict): config to build normalization layer.
            Defaults to dict(type='BN').
    """

    def __init__(self,
                 in_channels,
                 base_channels,
                 out_channels,
                 norm_cfg=dict(type='BN')):
                 in_channels: int,
                 base_channels: int,
                 out_channels: int,
                 norm_cfg: Dict = dict(type='BN')) -> None:
        super().__init__()
        self.norm_cfg = norm_cfg
        self.net = nn.Sequential(
@@ -189,25 +256,56 @@ class AuxiliaryModule(nn.Module):
            build_norm_layer(self.norm_cfg, out_channels)[1],
            nn.ReLU(inplace=True))

    def forward(self, x):
    def forward(self, x: Tensor) -> Tensor:
        """Forward the auxiliary module."""
        return self.net(x)


@MODELS.register_module()
class DartsBackbone(nn.Module):
    """Backbone of Differentiable Architecture Search (DARTS).

    def __init__(self,
                 in_channels,
                 base_channels,
                 num_layers=8,
                 num_nodes=4,
                 stem_multiplier=3,
                 out_indices=(7, ),
                 auxliary=False,
                 aux_channels=None,
                 aux_out_channels=None,
                 act_cfg=dict(type='ReLU'),
                 norm_cfg=dict(type='BN')):
    Args:
        in_channels (int): number of channels of input tensor.
        base_channels (int): number of middle channels.
        mutable_cfg (Optional[Dict]): config of `DiffMutable`.
        route_cfg (Optional[Dict]): config of `DiffChoiceRoute`.
        num_layers (Optional[int]): number of layers.
            Defaults to 8.
        num_nodes (Optional[int]): number of nodes.
            Defaults to 4.
        stem_multiplier (Optional[int]): multiplier for stem.
            Defaults to 3.
        out_indices (tuple, optional): output indices for auxliary module.
            Defaults to (7, ).
        auxliary (bool, optional): whether use auxliary module.
            Defaults to False.
        aux_channels (Optional[int]): number of middle channels of
            auxliary module. Defaults to None.
        aux_out_channels (Optional[int]): number of output channels of
            auxliary module. Defaults to None.
        act_cfg (Dict): config to build activation layer.
            Defaults to dict(type='ReLU').
        norm_cfg (Dict): config to build normalization layer.
            Defaults to dict(type='BN').
    """

    def __init__(
        self,
        in_channels: int,
        base_channels: int,
        mutable_cfg: Dict,
        route_cfg: Dict,
        num_layers: int = 8,
        num_nodes: int = 4,
        stem_multiplier: int = 3,
        out_indices: Union[Tuple, List] = (7, ),
        auxliary: bool = False,
        aux_channels: Optional[int] = None,
        aux_out_channels: Optional[int] = None,
        act_cfg: Dict = dict(type='ReLU'),
        norm_cfg: Dict = dict(type='BN'),
    ) -> None:
        super().__init__()

        self.in_channels = in_channels
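A hedged configuration sketch for the new signature; the values mirror `backbone_cfg` in the unit test added later in this commit, and it assumes `nn.Conv2d` has been registered as 'torchConv2d' exactly as that test does:

import torch.nn as nn

from mmrazor.registry import MODELS

MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)

mutable_cfg = dict(
    type='DiffOP',
    candidate_ops=dict(
        torch_conv2d_3x3=dict(type='torchConv2d', kernel_size=3, padding=1),
        torch_conv2d_5x5=dict(type='torchConv2d', kernel_size=5, padding=2)))
route_cfg = dict(type='DiffChoiceRoute', with_arch_param=True)

backbone_cfg = dict(
    type='mmrazor.DartsBackbone',
    in_channels=3,
    base_channels=16,
    num_layers=8,
    num_nodes=4,
    stem_multiplier=3,
    out_indices=(7, ),
    mutable_cfg=mutable_cfg,
    route_cfg=route_cfg)
backbone = MODELS.build(backbone_cfg)  # every edge becomes a DiffOP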
@@ -237,7 +335,7 @@ class DartsBackbone(nn.Module):
            build_norm_layer(self.norm_cfg, self.out_channels)[1])

        # for the first cell, stem is used for both s0 and s1
        # [!] prev_prev_channels and prev_channels is output channel size,
        # prev_prev_channels and prev_channels is output channel size,
        # but c_cur is input channel size.
        prev_prev_channels = self.out_channels
        prev_channels = self.out_channels
@@ -255,7 +353,7 @@ class DartsBackbone(nn.Module):

            cell = Cell(self.num_nodes, self.out_channels, prev_channels,
                        prev_prev_channels, reduction, prev_reduction,
                        self.act_cfg, self.norm_cfg)
                        mutable_cfg, route_cfg, self.act_cfg, self.norm_cfg)
            self.cells.append(cell)

            prev_prev_channels = prev_channels
@@ -267,7 +365,8 @@ class DartsBackbone(nn.Module):
                self.aux_out_channels,
                self.norm_cfg)

    def forward(self, x):
    def forward(self, x: Tensor) -> Tensor:
        """Forward the darts backbone."""
        outs = []
        s0 = s1 = self.stem(x)
        for i, cell in enumerate(self.cells):
@@ -276,7 +375,6 @@ class DartsBackbone(nn.Module):
                outs.append(s1)
            if i == self.auxliary_indice and self.training:
                aux_feature = self.auxliary_module(s1)

                outs.insert(0, aux_feature)

        return tuple(outs)
@@ -1,6 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .darts_backbone import DartsBackbone
from .searchable_mobilenet import SearchableMobileNet
from .searchable_shufflenet_v2 import SearchableShuffleNetV2

__all__ = ['DartsBackbone', 'SearchableShuffleNetV2', 'SearchableMobileNet']
__all__ = ['SearchableShuffleNetV2', 'SearchableMobileNet']
@@ -25,6 +25,7 @@ class BaseMutable(BaseModule, ABC, Generic[CHOICE_TYPE, CHOSEN_TYPE]):
    Args:
        module_kwargs (dict[str, dict], optional): Module initialization named
            arguments. Defaults to None.
        alias (str, optional): alias of the `MUTABLE`.
        init_cfg (dict, optional): initialization configuration dict for
            ``BaseModule``. OpenMMLab has implement 5 initializer including
            `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
@@ -33,10 +34,12 @@ class BaseMutable(BaseModule, ABC, Generic[CHOICE_TYPE, CHOSEN_TYPE]):

    def __init__(self,
                 module_kwargs: Optional[Dict[str, Dict]] = None,
                 alias: Optional[str] = None,
                 init_cfg: Optional[Dict] = None) -> None:
        super().__init__(init_cfg=init_cfg)

        self.module_kwargs = module_kwargs
        self.alias = alias
        self._is_fixed = False
        self._current_choice: Optional[CHOICE_TYPE] = None
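The new `alias` attribute is what lets several mutables share one search group. A hedged sketch modelled on `SearchableLayerAlias` in the tests further down (the `torchConv2d` registration and DiffOP config are taken from those tests; `op1` is an illustrative alias):

import torch.nn as nn

from mmrazor.registry import MODELS

MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)

mutable_cfg = dict(
    type='DiffOP',
    module_kwargs=dict(in_channels=8, out_channels=8),
    candidate_ops=dict(
        conv3x3=dict(type='torchConv2d', kernel_size=3, padding=1),
        conv5x5=dict(type='torchConv2d', kernel_size=5, padding=2)))


class ToyLayer(nn.Module):
    """Two mutables that share the alias 'op1' will land in one search group."""

    def __init__(self) -> None:
        super().__init__()
        self.op_a = MODELS.build(dict(mutable_cfg, alias='op1'))
        self.op_b = MODELS.build(dict(mutable_cfg, alias='op1'))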
@@ -20,6 +20,7 @@ class DiffMutable(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE]):
    Args:
        module_kwargs (dict[str, dict], optional): Module initialization named
            arguments. Defaults to None.
        alias (str, optional): alias of the `MUTABLE`.
        init_cfg (dict, optional): initialization configuration dict for
            ``BaseModule``. OpenMMLab has implement 5 initializer including
            `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
@@ -31,8 +32,10 @@ class DiffMutable(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE]):

    def __init__(self,
                 module_kwargs: Optional[Dict[str, Dict]] = None,
                 alias: Optional[str] = None,
                 init_cfg: Optional[Dict] = None) -> None:
        super().__init__(module_kwargs=module_kwargs, init_cfg=init_cfg)
        super().__init__(
            module_kwargs=module_kwargs, alias=alias, init_cfg=init_cfg)

    def forward(self,
                x: Any,
@@ -103,6 +106,7 @@ class DiffOP(DiffMutable[str, str]):
            operations.
        module_kwargs (dict[str, dict], optional): Module initialization named
            arguments. Defaults to None.
        alias (str, optional): alias of the `MUTABLE`.
        init_cfg (dict, optional): initialization configuration dict for
            ``BaseModule``. OpenMMLab has implement 5 initializer including
            `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
@@ -113,9 +117,11 @@ class DiffOP(DiffMutable[str, str]):
        self,
        candidate_ops: Dict[str, Dict],
        module_kwargs: Optional[Dict[str, Dict]] = None,
        alias: Optional[str] = None,
        init_cfg: Optional[Dict] = None,
    ) -> None:
        super().__init__(module_kwargs=module_kwargs, init_cfg=init_cfg)
        super().__init__(
            module_kwargs=module_kwargs, alias=alias, init_cfg=init_cfg)
        assert len(candidate_ops) >= 1, \
            f'Number of candidate op must greater than or equal to 1, ' \
            f'but got: {len(candidate_ops)}'
@@ -24,6 +24,7 @@ class OneShotMutable(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE]):
    Args:
        module_kwargs (dict[str, dict], optional): Module initialization named
            arguments. Defaults to None.
        alias (str, optional): alias of the `MUTABLE`.
        init_cfg (dict, optional): initialization configuration dict for
            ``BaseModule``. OpenMMLab has implement 5 initializer including
            `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
@@ -35,8 +36,10 @@ class OneShotMutable(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE]):

    def __init__(self,
                 module_kwargs: Optional[Dict[str, Dict]] = None,
                 alias: Optional[str] = None,
                 init_cfg: Optional[Dict] = None) -> None:
        super().__init__(module_kwargs=module_kwargs, init_cfg=init_cfg)
        super().__init__(
            module_kwargs=module_kwargs, alias=alias, init_cfg=init_cfg)

    def forward(self, x: Any) -> Any:
        """Calls either :func:`forward_fixed` or :func:`forward_choice`
@@ -107,6 +110,7 @@ class OneShotOP(OneShotMutable[str, str]):
            operations.
        module_kwargs (dict[str, dict], optional): Module initialization named
            arguments. Defaults to None.
        alias (str, optional): alias of the `MUTABLE`.
        init_cfg (dict, optional): initialization configuration dict for
            ``BaseModule``. OpenMMLab has implement 5 initializer including
            `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
@@ -153,9 +157,14 @@ class OneShotOP(OneShotMutable[str, str]):
        self,
        candidate_ops: Union[Dict[str, Dict], nn.ModuleDict],
        module_kwargs: Optional[Dict[str, Dict]] = None,
        alias: Optional[str] = None,
        init_cfg: Optional[Dict] = None,
    ) -> None:
        super().__init__(module_kwargs=module_kwargs, init_cfg=init_cfg)
        super().__init__(
            module_kwargs=module_kwargs, alias=alias, init_cfg=init_cfg)
        assert len(candidate_ops) >= 1, \
            f'Number of candidate op must greater than 1, ' \
            f'but got: {len(candidate_ops)}'

        self._is_fixed = False
        self._chosen: Optional[str] = None
@@ -284,6 +293,7 @@ class OneShotProbOP(OneShotOP):
            candidate operation.
        module_kwargs (dict[str, dict], optional): Module initialization named
            arguments. Defaults to None.
        alias (str, optional): alias of the `MUTABLE`.
        init_cfg (dict, optional): initialization configuration dict for
            ``BaseModule``. OpenMMLab has implement 5 initializer including
            `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
@@ -294,10 +304,12 @@ class OneShotProbOP(OneShotOP):
                 candidate_ops: Dict[str, Dict],
                 choice_probs: list = None,
                 module_kwargs: Optional[Dict[str, Dict]] = None,
                 alias: Optional[str] = None,
                 init_cfg: Optional[Dict] = None) -> None:
        super().__init__(
            candidate_ops=candidate_ops,
            module_kwargs=module_kwargs,
            alias=alias,
            init_cfg=init_cfg)
        assert choice_probs is not None
        assert sum(choice_probs) - 1 < np.finfo(np.float64).eps, \
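`OneShotProbOP` (touched here only to thread `alias` through to the base class) also takes per-candidate sampling probabilities that must sum to one, as the assertion above shows. A hedged config sketch; the 'OneShotProbOP'/'torchConv2d' registry names are assumptions consistent with the rest of this diff:

import numpy as np

choice_probs = [0.2, 0.3, 0.5]  # one probability per candidate op
assert sum(choice_probs) - 1 < np.finfo(np.float64).eps

op_cfg = dict(
    type='OneShotProbOP',  # assumed registry name, matching the class above
    choice_probs=choice_probs,
    module_kwargs=dict(in_channels=8, out_channels=8),
    candidate_ops=dict(
        conv3x3=dict(type='torchConv2d', kernel_size=3, padding=1),
        conv5x5=dict(type='torchConv2d', kernel_size=5, padding=2),
        conv7x7=dict(type='torchConv2d', kernel_size=7, padding=3)))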
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from collections import Counter
from typing import Dict, Generic, List, Optional, Type, TypeVar

from mmcv.runner import BaseModule
@@ -98,7 +99,7 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]):
            supernet (:obj:`torch.nn.Module`): The supernet to be searched
                in your algorithm.
        """
        return self._build_search_group(supernet)
        self._build_search_group(supernet)

    @property
    def search_group(self) -> Dict[int, List[MUTABLE_TYPE]]:
@@ -119,43 +120,193 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]):
                'Call `prepare_from_supernet` before access search group!')
        return self._search_group

    def _build_name_mutable_mapping(
            self, supernet: Module) -> Dict[str, MUTABLE_TYPE]:
        """Mapping module name to mutable."""
        name2mutable: Dict[str, MUTABLE_TYPE] = dict()
        for name, module in supernet.named_modules():
            if isinstance(module, self.mutable_class_type):
                name2mutable[name] = module
        return name2mutable

    def _build_alias_names_mapping(self,
                                   supernet: Module) -> Dict[str, List[str]]:
        """Mapping alias to module names."""
        alias2mutable_names: Dict[str, List[str]] = dict()
        for name, module in supernet.named_modules():
            if isinstance(module, self.mutable_class_type):
                if module.alias is not None:
                    if module.alias not in alias2mutable_names:
                        alias2mutable_names[module.alias] = [name]
                    else:
                        alias2mutable_names[module.alias].append(name)

        return alias2mutable_names

    def _build_search_group(self, supernet: Module) -> None:
        """Build search group with ``supernet`` and ``custom_group``.
        """Build search group with ``custom_group`` and ``alias``(see more
        information in :class:`BaseMutable`). Grouping by alias and module name
        are both supported.

        Note:
            Apart from user-defined search group, all other searchable
            modules(mutable) will be grouped separately.

            The main difference between using alias and module name for
            grouping is that the alias is One-to-Many while the module
            name is One-to-One.

            When using both alias and module name in `custom_group`, the
            priority of alias is higher than that of module name.

            If alias is set in `custom_group`, then its corresponding module
            name should not be in the `custom_group`.

            Moreover, there should be no duplicate keys in the `custom_group`.

        Example:
            >>> import torch
            >>> from mmrazor.models.mutables.diff_mutable import DiffOP

            >>> # Assume that a toy model consists of three mutabels
            >>> # whose name are op1,op2,op3. The corresponding
            >>> # alias names of the three mutables are a1, a1, a2.
            >>> model = ToyModel()

            >>> # Using alias for grouping
            >>> mutator = DiffOP(custom_group=[['a1'], ['a2']])
            >>> mutator.prepare_from_supernet(model)
            >>> mutator.search_group
            {0: [op1, op2], 1: [op3]}

            >>> # Using module name for grouping
            >>> mutator = DiffOP(custom_group=[['op1', 'op2'], ['op3']])
            >>> mutator.prepare_from_supernet(model)
            >>> mutator.search_group
            {0: [op1, op2], 1: [op3]}

            >>> # Using both alias and module name for grouping
            >>> mutator = DiffOP(custom_group=[['a2'], ['op2']])
            >>> mutator.prepare_from_supernet(model)
            >>> # The last operation would be grouped
            >>> mutator.search_group
            {0: [op3], 1: [op2], 2: [op1]}


        Args:
            supernet (:obj:`torch.nn.Module`): The supernet to be searched
                in your algorithm.
        """
        module_name2module: Dict[str, Module] = dict()
        for name, module in supernet.named_modules():
            module_name2module[name] = module
        name2mutable = self._build_name_mutable_mapping(supernet)
        alias2mutable_names = self._build_alias_names_mapping(supernet)

        # Map module to group id for user-defined group
        self.module_name2group_id: Dict[str, int] = dict()
        for idx, group in enumerate(self._custom_group):
            for module_name in group:
                assert module_name in module_name2module, \
                    f'`{module_name}` is not a module name of supernet, ' \
                    f'expected module names: {module_name2module.keys()}'
                self.module_name2group_id[module_name] = idx
        # Check whether the custom group is valid
        if len(self._custom_group) > 0:
            self._check_valid_groups(alias2mutable_names, name2mutable,
                                     self._custom_group)

        search_group: Dict[int, List[MUTABLE_TYPE]] = dict()
        current_group_nums = len(self._custom_group)
        # Construct search_groups based on user-defined group
        search_groups: Dict[int, List[MUTABLE_TYPE]] = dict()

        for name, module in module_name2module.items():
            if isinstance(module, self.mutable_class_type):
                group_id = self.module_name2group_id.get(name)
                if group_id is None:
                    group_id = current_group_nums
        current_group_nums = 0
        grouped_mutable_names: List[str] = list()
        grouped_alias: List[str] = list()
        for group in self._custom_group:
            group_mutables = list()
            for item in group:
                if item in alias2mutable_names:
                    # if the item is from alias name
                    mutable_names: List[str] = alias2mutable_names[item]
                    grouped_alias.append(item)
                    group_mutables.extend(
                        [name2mutable[n] for n in mutable_names])
                    grouped_mutable_names.extend(mutable_names)
                else:
                    # if the item is in name2mutable
                    group_mutables.append(name2mutable[item])
                    grouped_mutable_names.append(item)

            search_groups[current_group_nums] = group_mutables
            current_group_nums += 1

        # Construct search_groups based on alias
        for alias, mutable_names in alias2mutable_names.items():
            if alias not in grouped_alias:
                # Check whether all current names are already grouped
                flag_all_grouped = True
                for mutable_name in mutable_names:
                    if mutable_name not in grouped_mutable_names:
                        flag_all_grouped = False

                # If not all mutables are already grouped
                if not flag_all_grouped:
                    search_groups[current_group_nums] = []
                    for mutable_name in mutable_names:
                        if mutable_name not in grouped_mutable_names:
                            search_groups[current_group_nums].append(
                                name2mutable[mutable_name])
                            grouped_mutable_names.append(mutable_name)
                    current_group_nums += 1
                try:
                    search_group[group_id].append(module)
                except KeyError:
                    search_group[group_id] = [module]
                self.module_name2group_id[name] = group_id

        self._search_group = search_group
        # check whether all the mutable objects are in the search_groups
        for name, module in supernet.named_modules():
            if isinstance(module, self.mutable_class_type):
                if name in grouped_mutable_names:
                    continue
                else:
                    search_groups[current_group_nums] = [module]
                    current_group_nums += 1

        grouped_counter = Counter(grouped_mutable_names)

        # find duplicate keys
        duplicate_keys = list()
        for key, count in grouped_counter.items():
            if count > 1:
                duplicate_keys.append(key)

        assert len(grouped_mutable_names) == len(
            list(set(grouped_mutable_names))), \
            'There are duplicate keys in grouped mutable names. ' \
            f'The duplicate keys are {duplicate_keys}. ' \
            'Please check if there are duplicate keys in the `custom_group`.'

        self._search_group = search_groups

    def _check_valid_groups(self, alias2mutable_names: Dict[str, List[str]],
                            name2mutable: Dict[str, MUTABLE_TYPE],
                            custom_group: List[List[str]]) -> None:

        aliases = [*alias2mutable_names.keys()]
        module_names = [*name2mutable.keys()]

        # check if all keys are legal
        expanded_custom_group: List[str] = [
            _ for group in custom_group for _ in group
        ]
        legal_keys: List[str] = [*aliases, *module_names]

        for key in expanded_custom_group:
            if key not in legal_keys:
                raise AssertionError(
                    f'The key: {key} in `custom_group` is not legal. '
                    f'Legal keys are: {legal_keys}. '
                    'Make sure that the keys are either alias or mutable name')

        # when the mutable has alias attribute, the corresponding module
        # name should not be used in `custom_group`.
        used_aliases = list()
        for group in custom_group:
            for key in group:
                if key in aliases:
                    used_aliases.append(key)

        for alias_key in used_aliases:
            mutable_names: List = alias2mutable_names[alias_key]
            # check whether module name is in custom group
            for mutable_name in mutable_names:
                if mutable_name in expanded_custom_group:
                    raise AssertionError(
                        f'When a mutable is set alias attribute :{alias_key},'
                        f'the corresponding module name {mutable_name} should '
                        f'not be used in `custom_group` {custom_group}.')
@@ -12,6 +12,10 @@ from .base_mutator import ArchitectureMutator
class DiffMutator(ArchitectureMutator[DiffMutable]):
    """Differentiable mutable based mutator.

    `DiffMutator` is responsible for mutating `DiffMutable`, `DiffOP`,
    `DiffChoiceRoute` and `GumbelChoiceRoute`. The architecture
    parameters of the mutables are retained in `DiffMutator`.

    Args:
        custom_group (list[list[str]], optional): User-defined search groups.
            All searchable modules that are not in ``custom_group`` will be
@@ -21,7 +25,7 @@ class DiffMutator(ArchitectureMutator[DiffMutable]):
    def __init__(self,
                 custom_group: Optional[List[List[str]]] = None,
                 init_cfg: Optional[Dict] = None) -> None:
        super().__init__(custom_group, init_cfg)
        super().__init__(custom_group=custom_group, init_cfg=init_cfg)

    def prepare_from_supernet(self, supernet: nn.Module) -> None:
        """Inherit from ``BaseMutator``'s, generate `arch_params` in DARTS.
@@ -32,31 +36,23 @@ class DiffMutator(ArchitectureMutator[DiffMutable]):
        """

        super().prepare_from_supernet(supernet)
        self.arch_params = self.build_arch_params(supernet)
        self.arch_params = self.build_arch_params()

    def build_arch_params(self, supernet):
    def build_arch_params(self):
        """This function will build many arch params, which are generally used
        in differentiable search algorithms, such as Darts' series. Each
        group_id corresponds to an arch param, so the Mutables with the same
        group_id share the same arch param.

        Args:
            supernet (:obj:`torch.nn.Module`): The architecture to be used
                in your algorithm.
        Returns:
            torch.nn.ParameterDict: the arch params are got after traversing
                the supernet.
            torch.nn.ParameterDict: the arch params are got by `search_group`.
        """

        arch_params: Dict[int, nn.Parameter] = dict()

        for module_name, module in supernet.named_modules():
            if isinstance(module, self.mutable_class_type):
                group_id = self.module_name2group_id[module_name]
                if group_id not in arch_params:
                    group_arch_param = module.build_arch_param()
                    if group_arch_param is not None:
                        arch_params[group_id] = group_arch_param
        for group_id, modules in self.search_group.items():
            group_arch_param = modules[0].build_arch_param()
            arch_params[group_id] = group_arch_param

        return arch_params

@@ -70,10 +66,8 @@ class DiffMutator(ArchitectureMutator[DiffMutable]):
        """

        for group_id, modules in self.search_group.items():
            if group_id in self.arch_params.keys():
                for module in modules:
                    module.set_forward_args(
                        arch_param=self.arch_params[group_id])
            for module in modules:
                module.set_forward_args(arch_param=self.arch_params[group_id])

    @property
    def mutable_class_type(self) -> Type[DiffMutable]:
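Putting the backbone and mutator changes together, a hedged end-to-end sketch; it mirrors the flow of `test_darts_backbone` below (registry names and cfg values come from that test, everything else is an assumption about defaults):

import torch
import torch.nn as nn

from mmrazor.registry import MODELS

MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)

mutable_cfg = dict(
    type='DiffOP',
    candidate_ops=dict(
        torch_conv2d_3x3=dict(type='torchConv2d', kernel_size=3, padding=1),
        torch_conv2d_5x5=dict(type='torchConv2d', kernel_size=5, padding=2)))
route_cfg = dict(type='DiffChoiceRoute', with_arch_param=True)

backbone = MODELS.build(
    dict(
        type='mmrazor.DartsBackbone',
        in_channels=3,
        base_channels=16,
        out_indices=(7, ),
        mutable_cfg=mutable_cfg,
        route_cfg=route_cfg))

mutator = MODELS.build(dict(type='DiffMutator', custom_group=None))
mutator.prepare_from_supernet(backbone)  # builds search groups and arch params
mutator.modify_supernet_forward()        # attaches one arch param per group
outs = backbone(torch.randn(1, 3, 224, 224))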
@@ -0,0 +1,118 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
from unittest import TestCase

import torch
import torch.nn as nn
from mmcls.models import * # noqa:F403,F401

from mmrazor.models import * # noqa:F403,F401
from mmrazor.models.architectures.components import * # noqa:F403,F401
from mmrazor.registry import MODELS

MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)
MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True)
MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True)


class TestDartsBackbone(TestCase):

    def setUp(self) -> None:
        self.mutable_cfg = dict(
            type='DiffOP',
            candidate_ops=dict(
                torch_conv2d_3x3=dict(
                    type='torchConv2d',
                    kernel_size=3,
                    padding=1,
                ),
                torch_conv2d_5x5=dict(
                    type='torchConv2d',
                    kernel_size=5,
                    padding=2,
                ),
                torch_conv2d_7x7=dict(
                    type='torchConv2d',
                    kernel_size=7,
                    padding=3,
                ),
            ))

        self.route_cfg = dict(
            type='DiffChoiceRoute',
            with_arch_param=True,
        )

        self.backbone_cfg = dict(
            type='mmrazor.DartsBackbone',
            in_channels=3,
            base_channels=16,
            num_layers=8,
            num_nodes=4,
            stem_multiplier=3,
            out_indices=(7, ),
            mutable_cfg=self.mutable_cfg,
            route_cfg=self.route_cfg)

        self.mutator_cfg = dict(
            type='DiffMutator',
            custom_group=None,
        )

    def test_darts_backbone(self):
        model = MODELS.build(self.backbone_cfg)
        custom_group = self.generate_key(model)

        assert model is not None
        self.mutable_cfg.update(custom_group=custom_group)
        mutator = MODELS.build(self.mutator_cfg)
        assert mutator is not None

        mutator.prepare_from_supernet(model)
        mutator.modify_supernet_forward()

        inputs = torch.randn(4, 3, 224, 224)
        outputs = model(inputs)
        assert outputs is not None

    def test_darts_backbone_with_auxliary(self):
        self.backbone_cfg.update(
            auxliary=True, aux_channels=256, aux_out_channels=512)
        model = MODELS.build(self.backbone_cfg)
        custom_group = self.generate_key(model)

        assert model is not None
        self.mutable_cfg.update(custom_group=custom_group)
        mutator = MODELS.build(self.mutator_cfg)
        assert mutator is not None
        mutator.prepare_from_supernet(model)
        mutator.modify_supernet_forward()

        inputs = torch.randn(4, 3, 224, 224)
        outputs = model(inputs)
        assert outputs is not None

    def generate_key(self, model):
        """auto generate custom group for darts."""
        tmp_dict = dict()

        for key, _ in model.named_modules():
            node_type = key.split('._candidate_ops')[0].split('.')[-1].split(
                '_')[0]
            if node_type not in ['normal', 'reduce']:
                # not supported type
                continue

            node_name = key.split('._candidate_ops')[0].split('.')[-1]
            if node_name not in tmp_dict.keys():
                tmp_dict[node_name] = [key.split('._candidate_ops')[0]]
            else:
                current_key = key.split('._candidate_ops')[0]
                if current_key not in tmp_dict[node_name]:
                    tmp_dict[node_name].append(current_key)

        return list(tmp_dict.values())


if __name__ == '__main__':
    unittest.main()
@@ -124,12 +124,14 @@ def test_searchable_shufflenet_v2_init_weights() -> None:
            bias_tensor *= 0.0001
            assert torch.equal(bias_tensor, m.bias)

    checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')
    temp_dir = tempfile.mkdtemp()
    checkpoint_path = os.path.join(temp_dir, 'checkpoint.pth')
    backbone_cfg = copy.deepcopy(BACKBONE_CFG)
    backbone = MODELS.build(backbone_cfg)
    torch.save(backbone.state_dict(), checkpoint_path)
    backbone_cfg['init_cfg'] = dict(
        type='Pretrained', checkpoint=checkpoint_path)
    backbone = MODELS.build(backbone_cfg)
    torch.save(backbone.state_dict(), checkpoint_path)

    name2weight = dict()
    for name, m in backbone.named_modules():
@@ -3,7 +3,6 @@ from unittest import TestCase

import pytest
import torch.nn as nn
from mmcls.models import * # noqa: F401,F403

from mmrazor.models import * # noqa: F401,F403
from mmrazor.models.mutables import DiffMutable, DiffOP
@@ -42,6 +41,37 @@ class SearchableModel(nn.Module):
        return self.slayer3(x)


class SearchableLayerAlias(nn.Module):

    def __init__(self, mutable_cfg: dict) -> None:
        super().__init__()
        mutable_cfg.update(alias='op1')
        self.op1 = MODELS.build(mutable_cfg)
        mutable_cfg.update(alias='op2')
        self.op2 = MODELS.build(mutable_cfg)
        mutable_cfg.update(alias='op3')
        self.op3 = MODELS.build(mutable_cfg)

    def forward(self, x):
        x = self.op1(x)
        x = self.op2(x)
        return self.op3(x)


class SearchableModelAlias(nn.Module):

    def __init__(self, mutable_cfg: dict) -> None:
        super().__init__()
        self.slayer1 = SearchableLayerAlias(mutable_cfg)
        self.slayer2 = SearchableLayerAlias(mutable_cfg)
        self.slayer3 = SearchableLayerAlias(mutable_cfg)

    def forward(self, x):
        x = self.slayer1(x)
        x = self.slayer2(x)
        return self.slayer3(x)


class TestDiffMutator(TestCase):

    def setUp(self):
@@ -107,6 +137,81 @@ class TestDiffMutator(TestCase):
        with pytest.raises(AssertionError):
            mutator.prepare_from_supernet(model)

    def test_diff_mutator_diffop_alias(self) -> None:
        model = SearchableModelAlias(self.MUTABLE_CFG)

        mutator_cfg = self.MUTATOR_CFG.copy()
        mutator_cfg['custom_group'] = [['op1'], ['op2'], ['op3']]
        mutator: DiffOP = MODELS.build(mutator_cfg)

        mutator.prepare_from_supernet(model)

        assert list(mutator.search_group.keys()) == [0, 1, 2]

        mutator.modify_supernet_forward()
        assert mutator.mutable_class_type == DiffMutable

    def test_diff_mutator_alias_module_name(self) -> None:
        """Using both alias and module name for grouping."""
        model = SearchableModelAlias(self.MUTABLE_CFG)

        mutator_cfg = self.MUTATOR_CFG.copy()
        mutator_cfg['custom_group'] = [['op1'],
                                       [
                                           'slayer1.op2', 'slayer2.op2',
                                           'slayer3.op2'
                                       ], ['slayer1.op3', 'slayer2.op3']]
        mutator: DiffOP = MODELS.build(mutator_cfg)

        mutator.prepare_from_supernet(model)

        assert list(mutator.search_group.keys()) == [0, 1, 2, 3]

        mutator.modify_supernet_forward()
        assert mutator.mutable_class_type == DiffMutable

    def test_diff_mutator_duplicate_keys(self) -> None:
        model = SearchableModel(self.MUTABLE_CFG)

        mutator_cfg = self.MUTATOR_CFG.copy()
        mutator_cfg['custom_group'] = [
            ['slayer1.op1', 'slayer2.op1', 'slayer3.op1'],
            ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
            ['slayer1.op3', 'slayer2.op3', 'slayer2.op3'],
        ]
        mutator: DiffOP = MODELS.build(mutator_cfg)

        with pytest.raises(AssertionError):
            mutator.prepare_from_supernet(model)

    def test_diff_mutator_duplicate_key_alias(self) -> None:
        model = SearchableModelAlias(self.MUTABLE_CFG)

        mutator_cfg = self.MUTATOR_CFG.copy()
        mutator_cfg['custom_group'] = [
            ['op1', 'slayer1.op1', 'slayer2.op1', 'slayer3.op1'],
            ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
            ['slayer1.op3', 'slayer2.op3', 'slayer3.op3'],
        ]
        mutator: DiffOP = MODELS.build(mutator_cfg)

        with pytest.raises(AssertionError):
            mutator.prepare_from_supernet(model)

    def test_diff_mutator_illegal_key(self) -> None:
        model = SearchableModel(self.MUTABLE_CFG)

        mutator_cfg = self.MUTATOR_CFG.copy()
        mutator_cfg['custom_group'] = [
            ['illegal_key', 'slayer1.op1', 'slayer2.op1', 'slayer3.op1'],
            ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
            ['slayer1.op3', 'slayer2.op3', 'slayer3.op3'],
        ]
        mutator: DiffOP = MODELS.build(mutator_cfg)

        with pytest.raises(AssertionError):
            mutator.prepare_from_supernet(model)


if __name__ == '__main__':
    import unittest
@@ -13,17 +13,17 @@ from mmrazor.registry import MODELS
MODEL_CFG = dict(
    type='mmcls.ImageClassifier',
    backbone=dict(
        type='ResNet',
        type='mmcls.ResNet',
        depth=50,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    neck=dict(type='mmcls.GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        type='mmcls.LinearClsHead',
        num_classes=1000,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        loss=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))

@@ -87,6 +87,7 @@ def test_one_shot_mutator_mutable_model() -> None:
    model = _SearchableModel()
    mutator: OneShotMutator = MODELS.build(MUTATOR_CFG)

    # import pdb; pdb.set_trace()
    mutator.prepare_from_supernet(model)
    assert list(mutator.search_group.keys()) == [0, 1, 2]

@@ -116,3 +117,7 @@ def test_one_shot_mutator_mutable_model() -> None:
    mutator = MODELS.build(mutator_cfg)
    with pytest.raises(AssertionError):
        mutator.prepare_from_supernet(model)


if __name__ == '__main__':
    pytest.main()