[Refactor] Refactor tracer and channel mutator

pull/198/head
whcao 2022-07-01 08:23:15 +00:00 committed by pppppM
parent 332f49ac6f
commit 42063ae4d3
42 changed files with 4187 additions and 227 deletions

View File

@ -1,2 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .delivers import * # noqa: F401,F403
from .tracer import * # noqa: F401,F403

View File

@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backward_tracer import BackwardTracer
from .loss_calculator import * # noqa: F401,F403
from .parsers import * # noqa: F401,F403
from .path import (ConcatNode, ConvNode, DepthWiseConvNode, LinearNode, Node,
NormNode, Path, PathList)
__all__ = [
'BackwardTracer', 'ConvNode', 'LinearNode', 'NormNode', 'ConcatNode',
'Path', 'PathList', 'Node', 'DepthWiseConvNode'
]

View File

@ -0,0 +1,201 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import re
from collections import OrderedDict
from mmcv import ConfigDict
from torch.nn import Conv2d, Linear
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _NormBase
from mmrazor.registry import TASK_UTILS
from .parsers import DEFAULT_BACKWARD_TRACER
from .path import Path, PathList
SUPPORT_MODULES = (Conv2d, Linear, _NormBase, GroupNorm)
@TASK_UTILS.register_module()
class BackwardTracer:
"""A topology tracer via backward.
Args:
loss_calculator (dict or Callable): Calculate the pseudo loss to trace
the topology of a model.
"""
def __init__(self, loss_calculator):
if isinstance(loss_calculator, (dict, ConfigDict)):
loss_calculator = TASK_UTILS.build(loss_calculator)
assert callable(
loss_calculator
), 'loss_calculator should be a dict, ConfigDict or ' \
'callable object'
self.loss_calculator = loss_calculator
@property
def backward_parser(self):
"""The mapping from the type of a backward op to the corresponding
parser."""
return DEFAULT_BACKWARD_TRACER
def backward_trace(self, grad_fn, module2name, param2module, cur_path,
result_paths, visited, shared_module):
"""Trace the topology of all the ``NON_PASS_MODULE``."""
grad_fn = grad_fn[0] if isinstance(grad_fn, (list, tuple)) else grad_fn
if grad_fn is not None:
name = type(grad_fn).__name__
# In the PyTorch graph, a backward op's name may carry a trailing '0'
# or '1' (e.g. ThnnConv2DBackward0). Strip the digits so the name
# matches the keys of the parser mapping.
name = re.sub(r'[0-1]+', '', name)
parse_module = self.backward_parser.get(name)
if parse_module is not None:
parse_module(self, grad_fn, module2name, param2module,
cur_path, result_paths, visited, shared_module)
else:
# If the op is AccumulateGrad, `next_functions` is an empty tuple.
parents = grad_fn.next_functions
if parents is not None:
for parent in parents:
self.backward_trace(parent, module2name, param2module,
cur_path, result_paths, visited,
shared_module)
else:
result_paths.append(copy.deepcopy(cur_path))
def _trace_shared_module_hook(self, module, inputs, outputs):
"""Trace shared modules. Modules such as the detection head in
RetinaNet which are visited more than once during :func:`forward` are
shared modules.
Args:
module (:obj:`torch.nn.Module`): The module to register hook.
inputs (tuple): The input of the module.
outputs (tuple): The output of the module.
"""
module._cnt += 1
def _build_mappings(self, model):
"""Build the mappings which are used during tracing."""
module2name = OrderedDict()
# build a mapping from the identity of a module's params
# to this module
param2module = OrderedDict()
# record which module names have been visited during path tracing
visited = dict()
def traverse(module, prefix=''):
for name, child in module.named_children():
full_name = f'{prefix}.{name}' if prefix else name
if isinstance(child, SUPPORT_MODULES):
module2name[child] = full_name
for param in child.parameters():
param2module[id(param)] = child
visited[full_name] = False
else:
traverse(child, full_name)
traverse(model)
return module2name, param2module, visited
def _register_share_module_hook(self, model):
"""Record shared modules which will be visited more than once during
forward such as shared detection head in RetinaNet.
If a module is not a shared module and it has been visited during
forward, its parent modules must have been traced already. However, a
shared module will be visited more than once during forward, so it is
still need to be traced even if it has been visited.
"""
self._shared_module_hook_handles = list()
for module in model.modules():
if hasattr(module, 'weight'):
# trace shared modules
module._cnt = 0
# the handle is only to remove the corresponding hook later
handle = module.register_forward_hook(
self._trace_shared_module_hook)
self._shared_module_hook_handles.append(handle)
def _remove_share_module_hook(self, model):
"""`_trace_shared_module_hook` and `_cnt` are only used to trace the
shared modules in a model and need to be remove later."""
for module in model.modules():
if hasattr(module, 'weight'):
del module._cnt
for handle in self._shared_module_hook_handles:
handle.remove()
del self._shared_module_hook_handles
def _set_all_requires_grad(self, model):
"""Set `requires_grad` of a parameter to True to trace the whole
architecture topology."""
self._param_requires_grad = dict()
for param in model.parameters():
self._param_requires_grad[id(param)] = param.requires_grad
param.requires_grad = True
def _restore_requires_grad(self, model):
"""We set requires_grad to True to trace the whole architecture
topology.
So it should be reset after that.
"""
for param in model.parameters():
param.requires_grad = self._param_requires_grad[id(param)]
del self._param_requires_grad
@staticmethod
def _find_share_modules(model):
"""Find shared modules which will be visited more than once during
forward such as shared detection head in RetinaNet."""
share_modules = list()
for name, module in model.named_modules():
if hasattr(module, 'weight'):
if module._cnt > 1:
share_modules.append(name)
return share_modules
@staticmethod
def _reset_norm_running_stats(model):
"""As we calculate the pseudo loss during tracing, we need to reset
states of parameters."""
for module in model.modules():
if isinstance(module, _NormBase):
module.reset_parameters()
def trace(self, model):
"""Trace trace the architecture topology of the input model."""
module2name, param2module, visited = self._build_mappings(model)
# Set requires_grad to True. If the `requires_grad` of a module's
# weight is False, we can not trace this module by parsing backward.
self._set_all_requires_grad(model)
self._register_share_module_hook(model)
pseudo_loss = self.loss_calculator(model)
share_modules = self._find_share_modules(model)
self._remove_share_module_hook(model)
self._restore_requires_grad(model)
module_path_list = PathList()
self.backward_trace(pseudo_loss.grad_fn, module2name, param2module,
Path(), module_path_list, visited, share_modules)
self._reset_norm_running_stats(model)
return module_path_list
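
A minimal usage sketch (illustrative, not part of this diff): the toy model and the lambda loss calculator below are hypothetical; any callable returning a scalar tensor whose backward graph covers the model will do.

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.bn1 = nn.BatchNorm2d(8)
        self.conv2 = nn.Conv2d(8, 16, 3)

    def forward(self, x):
        return self.conv2(self.bn1(self.conv1(x)))

# The loss calculator only needs to return a scalar whose backward graph
# covers the whole model; here a simple sum over the output is used.
tracer = BackwardTracer(
    loss_calculator=lambda model: model(torch.rand(1, 3, 32, 32)).sum())
# Note: which backward ops can be parsed depends on the PyTorch version,
# see the keys of ``DEFAULT_BACKWARD_TRACER``.
path_list = tracer.trace(ToyModel())
print(path_list)  # a PathList of ConvNode/NormNode paths, traced output-to-input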

View File

@ -0,0 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .image_classifier_loss_calculator import ImageClassifierPseudoLoss
from .single_stage_detector_loss_calculator import \
SingleStageDetectorPseudoLoss
__all__ = ['ImageClassifierPseudoLoss', 'SingleStageDetectorPseudoLoss']

View File

@ -0,0 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcls.models import ImageClassifier
from mmrazor.registry import TASK_UTILS
@TASK_UTILS.register_module()
class ImageClassifierPseudoLoss:
"""Calculate the pseudo loss to trace the topology of a `ImageClassifier`
in MMClassification with `BackwardTracer`."""
def __call__(self, model: ImageClassifier) -> torch.Tensor:
pseudo_img = torch.rand(1, 3, 224, 224)
pseudo_output = model(pseudo_img)
return sum(pseudo_output)
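
For reference, a hedged wiring sketch: both this class and `BackwardTracer` are registered in `TASK_UTILS`, so the loss calculator is typically passed to the tracer as a config dict.

from mmrazor.core.tracer import BackwardTracer

tracer = BackwardTracer(loss_calculator=dict(type='ImageClassifierPseudoLoss'))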

View File

@ -0,0 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.models import BaseDetector
from mmrazor.registry import TASK_UTILS
# todo: adapt to mmdet 2.0
@TASK_UTILS.register_module()
class SingleStageDetectorPseudoLoss:
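    """Calculate the pseudo loss to trace the topology of a single-stage
    detector in MMDetection with `BackwardTracer`."""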
def __call__(self, model: BaseDetector) -> torch.Tensor:
pseudo_img = torch.rand(1, 3, 224, 224)
pseudo_output = model.forward_dummy(pseudo_img)
out = torch.tensor(0.)
for levels in pseudo_output:
out += sum([level.sum() for level in levels])
return out

View File

@ -0,0 +1,189 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Callable, Dict
from .path import (ConcatNode, ConvNode, DepthWiseConvNode, LinearNode,
NormNode, Path, PathList)
def _is_leaf_grad_fn(grad_fn):
"""Determine whether the current node is a leaf node."""
if type(grad_fn).__name__ == 'AccumulateGrad':
return True
return False
def parse_conv(tracer, grad_fn, module2name, param2module, cur_path,
result_paths, visited, shared_module):
"""Parse the backward of a conv layer.
Example:
>>> conv = nn.Conv2d(3, 3, 3)
>>> pseudo_img = torch.rand(1, 3, 224, 224)
>>> out = conv(pseudo_img)
>>> out.grad_fn.next_functions
((None, 0), (<AccumulateGrad object at 0x0000020E405CBD88>, 0),
(<AccumulateGrad object at 0x0000020E405CB588>, 0))
>>> # op.next_functions[0][0] is None means this ThnnConv2DBackward
>>> # op has no parents
>>> # op.next_functions[1][0].variable is the weight of this Conv2d
>>> # module
>>> # op.next_functions[2][0].variable is the bias of this Conv2d
>>> # module
"""
leaf_grad_fn = grad_fn.next_functions[1][0]
while not _is_leaf_grad_fn(leaf_grad_fn):
leaf_grad_fn = leaf_grad_fn.next_functions[0][0]
variable = leaf_grad_fn.variable
param_id = id(variable)
module = param2module[param_id]
name = module2name[module]
parent = grad_fn.next_functions[0][0]
if module.in_channels == module.groups:
cur_path.append(DepthWiseConvNode(name))
else:
cur_path.append(ConvNode(name))
# If a module is not a shared module and it has been visited during
# forward, its parent modules must have been traced already.
# However, a shared module will be visited more than once during
# forward, so it still needs to be traced even if it has been
# visited.
if visited[name] and name not in shared_module:
result_paths.append(copy.deepcopy(cur_path))
else:
visited[name] = True
tracer.backward_trace(parent, module2name, param2module, cur_path,
result_paths, visited, shared_module)
cur_path.pop(-1)
# todo: support parsing `MultiheadAttention` and user-defined matrix
# multiplication
def parse_linear(tracer, grad_fn, module2name, param2module, cur_path,
result_paths, visited, shared_module):
"""Parse the backward of a conv layer.
Example:
>>> fc = nn.Linear(3, 3, bias=True)
>>> input = torch.rand(3, 3)
>>> out = fc(input)
>>> out.grad_fn.next_functions
((<AccumulateGrad object at 0x0000020E405F75C8>, 0), (None, 0),
(<TBackward object at 0x0000020E405F7D48>, 0))
>>> # op.next_functions[0][0].variable is the bias of this Linear
>>> # module
>>> # op.next_functions[1][0] is None means this AddmmBackward op
>>> # has no parents
>>> # op.next_functions[2][0] is the TBackward op, and
>>> # op.next_functions[2][0].next_functions[0][0].variable is
>>> # the transpose of the weight of this Linear module
"""
leaf_grad_fn = grad_fn.next_functions[-1][0].next_functions[0][0]
while not _is_leaf_grad_fn(leaf_grad_fn):
leaf_grad_fn = leaf_grad_fn.next_functions[0][0]
variable = leaf_grad_fn.variable
param_id = id(variable)
module = param2module[param_id]
name = module2name[module]
parent = grad_fn.next_functions[-2][0]
cur_path.append(LinearNode(name))
# If a module is not a shared module and it has been visited during
# forward, its parent modules must have been traced already.
# However, a shared module will be visited more than once during
# forward, so it still needs to be traced even if it has been
# visited.
if visited[name] and name not in shared_module:
result_paths.append(copy.deepcopy(cur_path))
else:
visited[name] = True
tracer.backward_trace(parent, module2name, param2module, cur_path,
result_paths, visited, shared_module)
def parse_cat(tracer, grad_fn, module2name, param2module, cur_path,
result_paths, visited, shared_module):
"""Parse the backward of a concat operation.
Example:
>>> conv = nn.Conv2d(3, 3, 3)
>>> pseudo_img = torch.rand(1, 3, 224, 224)
>>> out1 = conv(pseudo_img)
>>> out2 = conv(pseudo_img)
>>> out = torch.cat([out1, out2], dim=1)
>>> out.grad_fn.next_functions
((<ThnnConv2DBackward object at 0x0000020E405F24C8>, 0),
(<ThnnConv2DBackward object at 0x0000020E405F2648>, 0))
>>> # the length of ``out.grad_fn.next_functions`` is two means
>>> # ``out`` is obtained by concatenating two tensors
"""
parents = grad_fn.next_functions
concat_id = '_'.join([str(id(p)) for p in parents])
name = f'concat_{concat_id}'
# If a module is not a shared module and it has been visited during
# forward, its parent modules must have been traced already.
# However, a shared module will be visited more than once during
# forward, so it still needs to be traced even if it has been
# visited.
if (name in visited and visited[name] and name not in shared_module):
pass
else:
visited[name] = True
sub_path_lists = list()
for i, parent in enumerate(parents):
sub_path_list = PathList()
tracer.backward_trace(parent, module2name, param2module, Path(),
sub_path_list, visited, shared_module)
sub_path_lists.append(sub_path_list)
cur_path.append(ConcatNode(name, sub_path_lists))
result_paths.append(copy.deepcopy(cur_path))
cur_path.pop(-1)
def parse_norm(tracer, grad_fn, module2name, param2module, cur_path,
result_paths, visited, shared_module):
"""Parse the backward of a concat operation.
Example:
>>> conv = nn.Conv2d(3, 3, 3)
>>> pseudo_img = torch.rand(1, 3, 224, 224)
>>> out1 = conv(pseudo_img)
>>> out2 = conv(pseudo_img)
>>> out = torch.cat([out1, out2], dim=1)
>>> out.grad_fn.next_functions
((<ThnnConv2DBackward object at 0x0000020E405F24C8>, 0),
(<ThnnConv2DBackward object at 0x0000020E405F2648>, 0))
>>> # the length of ``out.grad_fn.next_functions`` is two means
>>> # ``out`` is obtained by concatenating two tensors
"""
leaf_grad_fn = grad_fn.next_functions[1][0]
while not _is_leaf_grad_fn(leaf_grad_fn):
leaf_grad_fn = leaf_grad_fn.next_functions[0][0]
variable = leaf_grad_fn.variable
param_id = id(variable)
module = param2module[param_id]
name = module2name[module]
parent = grad_fn.next_functions[0][0]
cur_path.append(NormNode(name))
visited[name] = True
tracer.backward_trace(parent, module2name, param2module, cur_path,
result_paths, visited, shared_module)
cur_path.pop(-1)
DEFAULT_BACKWARD_TRACER: Dict[str, Callable] = {
'ThnnConv2DBackward': parse_conv,
'CudnnConvolutionBackward': parse_conv,
'MkldnnConvolutionBackward': parse_conv,
'SlowConvDilated2DBackward': parse_conv,
'ThAddmmBackward': parse_linear,
'AddmmBackward': parse_linear,
'MmBackward': parse_linear,
'CatBackward': parse_cat,
'ThnnBatchNormBackward': parse_norm,
'CudnnBatchNormBackward': parse_norm,
'NativeBatchNormBackward': parse_norm,
'NativeGroupNormBackward': parse_norm
}
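
An illustrative lookup sketch mirroring `BackwardTracer.backward_trace`: the backward op's class name is normalized before it is looked up in `DEFAULT_BACKWARD_TRACER` (the exact op name depends on the PyTorch version and backend).

import re

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 3, 3)
out = conv(torch.rand(1, 3, 8, 8))
op_name = type(out.grad_fn).__name__      # e.g. 'ThnnConv2DBackward0'
op_name = re.sub(r'[0-1]+', '', op_name)  # strip the trailing digit
# Returns ``parse_conv`` if the op name is registered, otherwise ``None``
# and the tracer keeps walking up ``next_functions``.
parser = DEFAULT_BACKWARD_TRACER.get(op_name)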

View File

@ -0,0 +1,353 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union
def _addindent(s_, numSpaces):
s = s_.split('\n')
# don't do anything for single-line stuff
if len(s) == 1:
return s_
first = s.pop(0)
s = [(numSpaces * ' ') + line for line in s]
s = '\n'.join(s)
s = first + '\n' + s
return s
def _merge_node_parents(node2parents, _node2parents):
for node, parents in _node2parents.items():
if node in node2parents:
cur_parents = node2parents[node]
new_parents_set = set(cur_parents + parents)
new_parents = list(new_parents_set)
node2parents[node] = new_parents
else:
node2parents[node] = parents
class Node:
"""``Node`` is the data structure that represents individual instances
within a ``Path``. It corresponds to a module or an operation such as
concatenation in the model.
Args:
name (str): Unique identifier of a node.
"""
def __init__(self, name: str) -> None:
self._name = name
def get_module_names(self) -> List:
return [self.name]
@property
def name(self) -> str:
"""Get the name of current node."""
return self._name
def _get_class_name(self):
return self.__class__.__name__
def __eq__(self, other):
if isinstance(other, self.__class__):
return self.name == other.name
else:
return False
def __hash__(self):
return hash(self.name)
def __repr__(self):
return f'{self._get_class_name()}(\'{self.name}\')'
class ConvNode(Node):
"""A `ConvNode` corresponds to a Conv module in the original model."""
pass
class DepthWiseConvNode(Node):
"""A `DepthWiseConvNode` corresponds to a depth-wise conv module in the
original model."""
pass
class NormNode(Node):
"""A `NormNode` corresponds to a normalization module in the original
model."""
pass
class LinearNode(Node):
"""A `LinearNode` corresponds to a linear module in the original model."""
pass
class Path:
"""``Path`` is the data structure that represents a list of ``Node`` traced
by a tracer.
Args:
nodes(:obj:`Node` or List[:obj:`Node`], optional): Nodes in a path.
Default to None.
"""
def __init__(self, nodes: Optional[Union[Node, List[Node]]] = None):
self._nodes: List[Node] = list()
if nodes is not None:
if isinstance(nodes, Node):
nodes = [nodes]
assert isinstance(nodes, (list, tuple))
for node in nodes:
assert isinstance(node, Node)
self._nodes.append(node)
def get_root_names(self) -> List[str]:
"""Get the name of the first node in a path."""
return self._nodes[0].get_module_names()
def find_nodes_parents(self,
target_nodes: Tuple,
non_pass: Optional[Tuple] = None) -> Dict:
"""Find the parents of a specific node.
Args:
target_nodes (Tuple): Find the parents of nodes whose types
are one of `target_nodes`.
non_pass (Tuple, optional): Only ancestor nodes whose types are in
`non_pass` are counted as parents. Default to None.
"""
node2parents: Dict[str, List[Node]] = dict()
for i, node in enumerate(self._nodes):
if isinstance(node, ConcatNode):
_node2parents: Dict[str, List[Node]] = node.find_nodes_parents(
target_nodes, non_pass)
_merge_node_parents(node2parents, _node2parents)
continue
if isinstance(node, target_nodes):
parents = list()
for behind_node in self._nodes[i + 1:]:
if non_pass is None or isinstance(behind_node, non_pass):
parents.append(behind_node)
break
_node2parents = {node.name: parents}
_merge_node_parents(node2parents, _node2parents)
return node2parents
@property
def nodes(self) -> List:
"""Return a list of nodes in the current path."""
return self._nodes
def append(self, x: Node) -> None:
"""Add a node to the end of the current path."""
assert isinstance(x, Node)
self._nodes.append(x)
def pop(self, *args, **kwargs):
"""Temoves the node at the given index from the path and returns the
removed node."""
return self._nodes.pop(*args, **kwargs)
def __eq__(self, other):
if isinstance(other, self.__class__):
return self.nodes == other.nodes
else:
return False
def __len__(self):
return len(self._nodes)
def __getitem__(self, item):
return self._nodes[item]
def __iter__(self):
for node in self._nodes:
yield node
def _get_class_name(self) -> str:
"""Get the name of the current class."""
return self.__class__.__name__
def __repr__(self):
child_lines = []
for node in self._nodes:
node_str = repr(node)
node_str = _addindent(node_str, 2)
child_lines.append(node_str)
lines = child_lines
main_str = self._get_class_name() + '('
if lines:
main_str += '\n ' + ',\n '.join(lines) + '\n'
main_str += ')'
return main_str
class PathList:
"""``PathList`` is the data structure that represents a list of ``Path``
traced by a tracer.
Args:
paths(:obj:`Path` or List[:obj:`Path`], optional): A list of `Path`.
Default to None.
"""
def __init__(self, paths: Optional[Union[Path, List[Path]]] = None):
self._paths = list()
if paths is not None:
if isinstance(paths, Path):
paths = [paths]
assert isinstance(paths, (list, tuple))
for path in paths:
assert isinstance(path, Path)
self._paths.append(path)
def get_root_names(self) -> List[str]:
"""Get the root node of all the paths in `PathList`.
Notes:
Different paths in a PathList share the same root node.
"""
return self._paths[0].get_root_names()
def find_nodes_parents(self,
target_nodes: Tuple,
non_pass: Optional[Tuple] = None):
"""Find the parents of a specific node.
Args:
target_nodes (Tuple): Find the parents of nodes whose types
are one of `target_nodes`.
non_pass (Tuple, optional): Only ancestor nodes whose types are in
`non_pass` are counted as parents. Default to None.
"""
node2parents: Dict[str, List[Node]] = dict()
for p in self._paths:
_node2parents = p.find_nodes_parents(target_nodes, non_pass)
_merge_node_parents(node2parents, _node2parents)
return node2parents
def append(self, x: Path) -> None:
"""Add a path to the end of the current PathList."""
assert isinstance(x, Path)
self._paths.append(x)
@property
def paths(self):
"""Return all paths in the current PathList."""
return self._paths
def __eq__(self, other):
if isinstance(other, self.__class__):
return self.paths == other.paths
else:
return False
def __len__(self):
return len(self._paths)
def __getitem__(self, item):
return self._paths[item]
def __iter__(self):
for path in self._paths:
yield path
def _get_class_name(self) -> str:
"""Get the name of the current class."""
return self.__class__.__name__
def __repr__(self):
child_lines = []
for node in self._paths:
node_str = repr(node)
node_str = _addindent(node_str, 2)
child_lines.append(node_str)
lines = child_lines
main_str = self._get_class_name() + '('
if lines:
main_str += '\n ' + ',\n '.join(lines) + '\n'
main_str += ')'
return main_str
class ConcatNode(Node):
"""``ConcatNode`` is the data structure that represents the concatenation
operation in a model.
Args:
name (str): Unique identifier of a `ConcatNode`.
path_lists (List[PathList]): Several nodes are concatenated and each
node is the root node of all the paths in a `PathList`
(one of `path_lists`).
"""
def __init__(self, name: str, path_lists: List[PathList]):
super().__init__(name)
self._path_lists = list()
for path_list in path_lists:
assert isinstance(path_list, PathList)
self._path_lists.append(path_list)
def get_module_names(self) -> List[str]:
"""Several nodes are concatenated.
Get the names of these nodes.
"""
module_names = list()
for path_list in self._path_lists:
module_names.extend(path_list.get_root_names())
return module_names
def find_nodes_parents(self,
target_nodes: Tuple,
non_pass: Optional[Tuple] = None):
"""Find the parents of a specific node.
Args:
target_nodes (Tuple): Find the parents of nodes whose types
are one of `target_nodes`.
non_pass (Tuple, optional): Only ancestor nodes whose types are in
`non_pass` are counted as parents. Default to None.
"""
node2parents: Dict[str, List[Node]] = dict()
for p in self._path_lists:
_node2parents = p.find_nodes_parents(target_nodes, non_pass)
_merge_node_parents(node2parents, _node2parents)
return node2parents
@property
def path_lists(self) -> List[PathList]:
"""Return all the path_list."""
return self._path_lists
def __len__(self):
return len(self._path_lists)
def __getitem__(self, item):
return self._path_lists[item]
def __iter__(self):
for path_list in self._path_lists:
yield path_list
def _get_class_name(self) -> str:
"""Get the name of the current class."""
return self.__class__.__name__
def __repr__(self):
child_lines = []
for node in self._path_lists:
node_str = repr(node)
node_str = _addindent(node_str, 2)
child_lines.append(node_str)
lines = child_lines
main_str = self._get_class_name() + '('
if lines:
main_str += '\n ' + ',\n '.join(lines) + '\n'
main_str += ')'
return main_str
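
A small illustrative example (the module names are made up): a path is recorded from the pseudo loss back towards the input, so a node's parents are the nodes that come after it in the path.

path = Path([
    ConvNode('backbone.conv2'),
    NormNode('backbone.bn1'),
    ConvNode('backbone.conv1')
])
path_list = PathList(path)

print(path_list.get_root_names())
# ['backbone.conv2']
print(path.find_nodes_parents(
    target_nodes=(ConvNode, ), non_pass=(ConvNode, NormNode)))
# {'backbone.conv2': [NormNode('backbone.bn1')], 'backbone.conv1': []}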

View File

@ -6,5 +6,4 @@ from .losses import * # noqa: F401,F403
from .mutables import * # noqa: F401,F403
from .mutators import * # noqa: F401,F403
from .ops import * # noqa: F401,F403
from .pruners import * # noqa: F401,F403
from .subnet import * # noqa: F401,F403

View File

@ -0,0 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .default_dynamic_ops import * # noqa: F401,F403
from .slimmable_dynamic_ops import * # noqa: F401,F403

View File

@ -0,0 +1,408 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm
from mmrazor.models.mutables import MutableManagerMixIn
from mmrazor.registry import MODELS
class DynamicConv2d(nn.Conv2d, MutableManagerMixIn):
"""Applies a 2D convolution over an input signal composed of several input
planes according to the `mutable_in_channels` and `mutable_out_channels`
dynamically.
Args:
module_name (str): Name of this `DynamicConv2d`.
in_channels_cfg (Dict): Config related to `in_channels`.
out_channels_cfg (Dict): Config related to `out_channels`.
"""
def __init__(self, module_name, in_channels_cfg, out_channels_cfg, *args,
**kwargs):
super(DynamicConv2d, self).__init__(*args, **kwargs)
in_channels_cfg = copy.deepcopy(in_channels_cfg)
in_channels_cfg.update(
dict(
name=module_name,
num_channels=self.in_channels,
mask_type='in_mask'))
self.mutable_in_channels = MODELS.build(in_channels_cfg)
out_channels_cfg = copy.deepcopy(out_channels_cfg)
out_channels_cfg.update(
dict(
name=module_name,
num_channels=self.out_channels,
mask_type='out_mask'))
self.mutable_out_channels = MODELS.build(out_channels_cfg)
@property
def mutable_in(self):
"""Mutable `in_channels`."""
return self.mutable_in_channels
@property
def mutable_out(self):
"""Mutable `out_channels`."""
return self.mutable_out_channels
def forward(self, input: Tensor) -> Tensor:
"""Slice the parameters according to `mutable_in_channels` and
`mutable_out_channels`, and forward."""
in_mask = self.mutable_in_channels.mask
out_mask = self.mutable_out_channels.mask
if self.groups == 1:
weight = self.weight[out_mask][:, in_mask]
groups = 1
elif self.groups == self.in_channels == self.out_channels:
# depth-wise conv
weight = self.weight[out_mask]
groups = input.size(1)
else:
raise NotImplementedError(
    'Current `ChannelMutator` only supports pruning depth-wise '
    '`nn.Conv2d` modules or `nn.Conv2d` modules whose group number '
    f'equals one, but got {self.groups}.')
bias = self.bias[out_mask] if self.bias is not None else None
return F.conv2d(input, weight, bias, self.stride, self.padding,
self.dilation, groups)
class DynamicLinear(nn.Linear, MutableManagerMixIn):
"""Applies a linear transformation to the incoming data according to the
`mutable_in_features` and `mutable_out_features` dynamically.
Args:
module_name (str): Name of this `DynamicLinear`.
in_features_cfg (Dict): Config related to `in_features`.
out_features_cfg (Dict): Config related to `out_features`.
"""
def __init__(self, module_name, in_features_cfg, out_features_cfg, *args,
**kwargs):
super(DynamicLinear, self).__init__(*args, **kwargs)
in_features_cfg = copy.deepcopy(in_features_cfg)
in_features_cfg.update(
dict(
name=module_name,
num_channels=self.in_features,
mask_type='in_mask'))
self.mutable_in_features = MODELS.build(in_features_cfg)
out_features_cfg = copy.deepcopy(out_features_cfg)
out_features_cfg.update(
dict(
name=module_name,
num_channels=self.out_features,
mask_type='out_mask'))
self.mutable_out_features = MODELS.build(out_features_cfg)
@property
def mutable_in(self):
"""Mutable `in_features`."""
return self.mutable_in_features
@property
def mutable_out(self):
"""Mutable `out_features`."""
return self.mutable_out_features
def forward(self, input: Tensor) -> Tensor:
"""Slice the parameters according to `mutable_in_features` and
`mutable_out_features`, and forward."""
in_mask = self.mutable_in_features.mask
out_mask = self.mutable_out_features.mask
weight = self.weight[out_mask][:, in_mask]
bias = self.bias[out_mask] if self.bias is not None else None
return F.linear(input, weight, bias)
class DynamicBatchNorm(_BatchNorm, MutableManagerMixIn):
"""Applies Batch Normalization over an input according to the
`mutable_num_features` dynamically.
Args:
module_name (str): Name of this `DynamicBatchNorm`.
num_features_cfg (Dict): Config related to `num_features`.
"""
def __init__(self, module_name, num_features_cfg, *args, **kwargs):
super(DynamicBatchNorm, self).__init__(*args, **kwargs)
num_features_cfg = copy.deepcopy(num_features_cfg)
num_features_cfg.update(
dict(
name=module_name,
num_channels=self.num_features,
mask_type='out_mask'))
self.mutable_num_features = MODELS.build(num_features_cfg)
@property
def mutable_in(self):
"""Mutable `num_features`."""
return self.mutable_num_features
@property
def mutable_out(self):
"""Mutable `num_features`."""
return self.mutable_num_features
def forward(self, input: Tensor) -> Tensor:
"""Slice the parameters according to `mutable_num_features`, and
forward."""
if self.momentum is None:
exponential_average_factor = 0.0
else:
exponential_average_factor = self.momentum
if self.training and self.track_running_stats:
if self.num_batches_tracked is not None: # type: ignore
self.num_batches_tracked = \
self.num_batches_tracked + 1 # type: ignore
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(
self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
if self.training:
bn_training = True
else:
bn_training = (self.running_mean is None) and (self.running_var is
None)
if self.affine:
out_mask = self.mutable_num_features.mask
weight = self.weight[out_mask]
bias = self.bias[out_mask]
else:
weight, bias = self.weight, self.bias
if self.track_running_stats:
out_mask = self.mutable_num_features.mask
running_mean = self.running_mean[out_mask] \
if not self.training or self.track_running_stats else None
running_var = self.running_var[out_mask] \
if not self.training or self.track_running_stats else None
else:
running_mean, running_var = self.running_mean, self.running_var
return F.batch_norm(input, running_mean, running_var, weight, bias,
bn_training, exponential_average_factor, self.eps)
class DynamicInstanceNorm(_InstanceNorm, MutableManagerMixIn):
"""Applies Instance Normalization over an input according to the
`mutable_num_features` dynamically.
Args:
module_name (str): Name of this `DynamicInstanceNorm`.
num_features_cfg (Dict): Config related to `num_features`.
"""
def __init__(self, module_name, num_features_cfg, *args, **kwargs):
super(DynamicInstanceNorm, self).__init__(*args, **kwargs)
num_features_cfg = copy.deepcopy(num_features_cfg)
num_features_cfg.update(
dict(
name=module_name,
num_channels=self.num_features,
mask_type='out_mask'))
self.mutable_num_features = MODELS.build(num_features_cfg)
@property
def mutable_in(self):
"""Mutable `num_features`."""
return self.mutable_num_features
@property
def mutable_out(self):
"""Mutable `num_features`."""
return self.mutable_num_features
def forward(self, input: Tensor) -> Tensor:
"""Slice the parameters according to `mutable_num_features`, and
forward."""
if self.affine:
out_mask = self.mutable_num_features.mask
weight = self.weight[out_mask]
bias = self.bias[out_mask]
else:
weight, bias = self.weight, self.bias
if self.track_running_stats:
out_mask = self.mutable_num_features.mask
running_mean = self.running_mean[out_mask]
running_var = self.running_var[out_mask]
else:
running_mean, running_var = self.running_mean, self.running_var
return F.instance_norm(input, running_mean, running_var, weight, bias,
self.training or not self.track_running_stats,
self.momentum, self.eps)
class DynamicGroupNorm(GroupNorm, MutableManagerMixIn):
"""Applies Group Normalization over a mini-batch of inputs according to the
`mutable_num_channels` dynamically.
Args:
module_name (str): Name of this `DynamicGroupNorm`.
num_channels_cfg (Dict): Config related to `num_channels`.
"""
def __init__(self, module_name, num_channels_cfg, *args, **kwargs):
super(DynamicGroupNorm, self).__init__(*args, **kwargs)
num_channels_cfg = copy.deepcopy(num_channels_cfg)
num_channels_cfg.update(
dict(
name=module_name,
num_channels=self.num_channels,
mask_type='out_mask'))
self.mutable_num_channels = MODELS.build(num_channels_cfg)
@property
def mutable_in(self):
"""Mutable `num_channels`."""
return self.mutable_num_channels
@property
def mutable_out(self):
"""Mutable `num_channels`."""
return self.mutable_num_channels
def forward(self, input: Tensor) -> Tensor:
"""Slice the parameters according to `mutable_num_channels`, and
forward."""
if self.affine:
out_mask = self.mutable_num_channels.mask
weight = self.weight[out_mask]
bias = self.bias[out_mask]
else:
weight, bias = self.weight, self.bias
return F.group_norm(input, self.num_groups, weight, bias, self.eps)
def build_dynamic_conv2d(module: nn.Conv2d, module_name: str,
in_channels_cfg: Dict,
out_channels_cfg: Dict) -> DynamicConv2d:
"""Build DynamicConv2d.
Args:
module (:obj:`torch.nn.Conv2d`): The original Conv2d module.
module_name (str): Name of this `DynamicConv2d`.
in_channels_cfg (Dict): Config related to `in_channels`.
out_channels_cfg (Dict): Config related to `out_channels`.
"""
dynamic_conv = DynamicConv2d(
module_name=module_name,
in_channels_cfg=in_channels_cfg,
out_channels_cfg=out_channels_cfg,
in_channels=module.in_channels,
out_channels=module.out_channels,
kernel_size=module.kernel_size,
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
bias=True if module.bias is not None else False,
padding_mode=module.padding_mode)
return dynamic_conv
def build_dynamic_linear(module: nn.Linear, module_name: str,
in_features_cfg: Dict,
out_features_cfg: Dict) -> DynamicLinear:
"""Build DynamicLinear.
Args:
module (:obj:`torch.nn.Linear`): The original Linear module.
module_name (str): Name of this `DynamicLinear`.
in_features_cfg (Dict): Config related to `in_features`.
out_features_cfg (Dict): Config related to `out_features`.
"""
dynamic_linear = DynamicLinear(
module_name=module_name,
in_features_cfg=in_features_cfg,
out_features_cfg=out_features_cfg,
in_features=module.in_features,
out_features=module.out_features,
bias=True if module.bias is not None else False)
return dynamic_linear
def build_dynamic_bn(module: _BatchNorm, module_name: str,
num_features_cfg: Dict) -> DynamicBatchNorm:
"""Build DynamicBatchNorm.
Args:
module (:obj:`torch.nn._BatchNorm`): The original BatchNorm module.
module_name (str): Name of this `DynamicBatchNorm`.
num_features_cfg (Dict): Config related to `num_features`.
"""
dynamic_bn = DynamicBatchNorm(
module_name=module_name,
num_features_cfg=num_features_cfg,
num_features=module.num_features,
eps=module.eps,
momentum=module.momentum,
affine=module.affine,
track_running_stats=module.track_running_stats)
return dynamic_bn
def build_dynamic_in(module: _InstanceNorm, module_name: str,
num_features_cfg: Dict) -> DynamicInstanceNorm:
"""Build DynamicInstanceNorm.
Args:
module (:obj:`torch.nn._InstanceNorm`): The original InstanceNorm
module.
module_name (str): Name of this `DynamicInstanceNorm`.
num_features_cfg (Dict): Config related to `num_features`.
"""
dynamic_in = DynamicInstanceNorm(
module_name=module_name,
num_features_cfg=num_features_cfg,
num_features=module.num_features,
eps=module.eps,
momentum=module.momentum,
affine=module.affine,
track_running_stats=module.track_running_stats)
return dynamic_in
def build_dynamic_gn(module: GroupNorm, module_name: str,
num_channels_cfg: Dict) -> DynamicGroupNorm:
"""Build DynamicGroupNorm.
Args:
module (:obj:`torch.nn.GroupNorm`): The original GroupNorm module.
module_name (str): Name of this `DynamicGroupNorm`.
num_channels_cfg (Dict): Config related to `num_channels`.
"""
dynamic_gn = DynamicGroupNorm(
module_name=module_name,
num_channels_cfg=num_channels_cfg,
num_channels=module.num_channels,
num_groups=module.num_groups,
eps=module.eps,
affine=module.affine)
return dynamic_gn
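
A hedged conversion sketch: the mutable config only needs a registered type and its candidate choices, since `DynamicConv2d` fills in `name`, `num_channels` and `mask_type` itself. `RatioChannelMutable` is one of the mutables added in this PR; the module name below is made up.

import torch
import torch.nn as nn

conv = nn.Conv2d(16, 32, 3)
mutable_cfg = dict(type='RatioChannelMutable', candidate_choices=[0.5, 1.0])

dynamic_conv = build_dynamic_conv2d(
    module=conv,
    module_name='backbone.conv1',
    in_channels_cfg=mutable_cfg,
    out_channels_cfg=mutable_cfg)

# Keep only half of the output channels and run a forward pass.
dynamic_conv.mutable_out.current_choice = 16
out = dynamic_conv(torch.rand(1, 16, 8, 8))
print(out.shape)  # torch.Size([1, 16, 6, 6])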

View File

@ -0,0 +1,105 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Callable, Dict
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm
from mmrazor.models.mutables import MutableManagerMixIn
from mmrazor.registry import MODELS
from .default_dynamic_ops import build_dynamic_conv2d, build_dynamic_linear
class SwitchableBatchNorm2d(nn.Module, MutableManagerMixIn):
"""Employs independent batch normalization for different switches in a
slimmable network.
To train slimmable networks, ``SwitchableBatchNorm2d`` privatizes all
batch normalization layers for each switch in a slimmable network.
Compared with the naive training approach, it solves the problem of feature
aggregation inconsistency between different switches by independently
normalizing the feature mean and variance during testing.
Args:
module_name (str): Name of this `SwitchableBatchNorm2d`.
num_features_cfg (Dict): Config related to `num_features`.
eps (float): A value added to the denominator for numerical stability.
Same as that in :obj:`torch.nn._BatchNorm`. Default: 1e-5
momentum (float): The value used for the running_mean and running_var
computation. Can be set to None for cumulative moving average
(i.e. simple average). Same as that in :obj:`torch.nn._BatchNorm`.
Default: 0.1
affine (bool): A boolean value that when set to True, this module has
learnable affine parameters. Same as that in
:obj:`torch.nn._BatchNorm`. Default: True
track_running_stats (bool): A boolean value that when set to True, this
module tracks the running mean and variance, and when set to False,
this module does not track such statistics, and initializes
statistics buffers running_mean and running_var as None. When these
buffers are None, this module always uses batch statistics in both
training and eval modes. Same as that in
:obj:`torch.nn._BatchNorm`. Default: True
"""
def __init__(self,
module_name: str,
num_features_cfg: Dict,
eps: float = 1e-5,
momentum: float = 0.1,
affine: bool = True,
track_running_stats: bool = True):
super(SwitchableBatchNorm2d, self).__init__()
num_features_cfg = copy.deepcopy(num_features_cfg)
candidate_choices = num_features_cfg.pop('candidate_choices')
num_features_cfg.update(
dict(
name=module_name,
num_channels=max(candidate_choices),
mask_type='out_mask'))
bns = [
nn.BatchNorm2d(num_features, eps, momentum, affine,
track_running_stats)
for num_features in candidate_choices
]
self.bns = nn.ModuleList(bns)
self.mutable_num_features = MODELS.build(num_features_cfg)
@property
def mutable_out(self):
"""Mutable `num_features`."""
return self.mutable_num_features
def forward(self, input):
"""Forward computation according to the current switch of the slimmable
networks."""
idx = self.mutable_num_features.current_choice
return self.bns[idx](input)
def build_switchable_bn(module: _BatchNorm, module_name: str,
num_features_cfg: Dict) -> SwitchableBatchNorm2d:
"""Build SwitchableBatchNorm2d.
Args:
module (:obj:`torch.nn._BatchNorm`): The original BatchNorm module.
module_name (str): Name of this `SwitchableBatchNorm2d`.
num_features_cfg (Dict): Config related to `num_features`.
"""
switchable_bn = SwitchableBatchNorm2d(
module_name=module_name,
num_features_cfg=num_features_cfg,
eps=module.eps,
momentum=module.momentum,
affine=module.affine,
track_running_stats=module.track_running_stats)
return switchable_bn
SLIMMABLE_DYNAMIC_LAYER: Dict[Callable, Callable] = {
nn.Conv2d: build_dynamic_conv2d,
nn.Linear: build_dynamic_linear,
nn.BatchNorm2d: build_switchable_bn
}
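
A hedged usage sketch for the switchable BN: as noted in the mutable's docstring, `candidate_choices` of a `SlimmableChannelMutable` is registered after instantiation (normally by the slimmable channel mutator); here it is set manually for illustration, and the module name is made up.

import torch
from torch.nn import BatchNorm2d

bn = BatchNorm2d(32)
switchable_bn = build_switchable_bn(
    module=bn,
    module_name='backbone.bn1',
    num_features_cfg=dict(
        type='SlimmableChannelMutable', candidate_choices=[16, 32]))
print(len(switchable_bn.bns))  # 2: one independent BatchNorm2d per width

switchable_bn.mutable_out.candidate_choices = [16, 32]
switchable_bn.mutable_out.current_choice = 1  # select the 32-channel branch
out = switchable_bn(torch.rand(1, 32, 8, 8))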

View File

@ -1,9 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .diff_mutable import (DiffChoiceRoute, DiffMutable, DiffOP,
GumbelChoiceRoute)
from .mutable_channel import (OneShotChannelMutable, OrderChannelMutable,
RatioChannelMutable, SlimmableChannelMutable)
from .mutable_manager_mixin import MutableManagerMixIn
from .oneshot_mutable import OneShotMutable, OneShotOP
__all__ = [
'OneShotOP', 'OneShotMutable', 'DiffOP', 'DiffChoiceRoute',
'GumbelChoiceRoute', 'DiffMutable'
'OneShotOP', 'OneShotMutable', 'OneShotChannelMutable',
'RatioChannelMutable', 'OrderChannelMutable', 'DiffOP', 'DiffChoiceRoute',
'GumbelChoiceRoute', 'DiffMutable', 'MutableManagerMixIn',
'SlimmableChannelMutable'
]

View File

@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .one_shot_channel_mutable import OneShotChannelMutable
from .order_channel_mutable import OrderChannelMutable
from .ratio_channel_mutable import RatioChannelMutable
from .slimmable_channel_mutable import SlimmableChannelMutable
__all__ = [
'OneShotChannelMutable', 'OrderChannelMutable', 'RatioChannelMutable',
'SlimmableChannelMutable'
]

View File

@ -0,0 +1,157 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
import numpy as np
import torch
from mmcv.runner import BaseModule
class OneShotChannelMutable(BaseModule, ABC):
"""A type of ``MUTABLES`` for single path supernet such as AutoSlim. In
single path supernet, each module only has one choice invoked at the same
time. A path is obtained by sampling all the available choices. It is the
base class for one shot channel mutables.
Args:
name (str): Mutable name.
mask_type (str): One of 'in_mask' or 'out_mask'.
num_channels (int): The raw number of channels.
init_cfg (dict, optional): initialization configuration dict for
``BaseModule``. OpenMMLab has implemented six initializers including
`Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
and `Pretrained`.
"""
def __init__(self,
name: str,
mask_type: str,
num_channels: int,
init_cfg: Optional[Dict] = None):
super(OneShotChannelMutable, self).__init__(init_cfg=init_cfg)
# If the input of a module is a concatenation of several modules'
# outputs, we add the out_mutable (mask_type == 'out_mask') of
# these modules to the `concat_mutables` of this module.
self.concat_mutables: List[OneShotChannelMutable] = list()
self.name = name
assert mask_type in ('in_mask', 'out_mask')
self.mask_type = mask_type
self.num_channels = num_channels
self.register_buffer('_mask', torch.ones((num_channels, )).bool())
self._current_choice = num_channels
self._same_mutables: List[OneShotChannelMutable] = list()
@property
def same_mutables(self):
"""Mutables in `same_mutables` and the current mutable should change
synchronously."""
return self._same_mutables
def register_same_mutable(self, mutable):
"""Register the input mutable in `same_mutables`."""
if isinstance(mutable, list):
# Add a concatenation of mutables to `concat_mutables`.
assert self.mask_type == 'in_mask'
assert all([
cur_mutable.mask_type == 'out_mask' for cur_mutable in mutable
])
self.concat_mutables = mutable
return
if self == mutable:
return
if mutable in self._same_mutables:
return
self._same_mutables.append(mutable)
for s_mutable in self._same_mutables:
s_mutable.register_same_mutable(mutable)
mutable.register_same_mutable(s_mutable)
def sample_choice(self) -> int:
"""Sample an arbitrary selection from candidate choices.
Returns:
int: The chosen number of channels.
"""
assert len(self.concat_mutables) == 0
num_channels = np.random.choice(self.choices)
assert num_channels > 0, \
f'Sampled number of channels in `Mutable` {self.name}' \
f' should be a positive integer.'
return num_channels
@property
def min_choice(self) -> int:
"""Minimum number of channels."""
assert len(self.concat_mutables) == 0
min_channels = min(self.choices)
assert min_channels > 0, \
f'Minimum number of channels in `Mutable` {self.name}' \
f' should be a positive integer.'
return min_channels
@property
def max_choice(self) -> int:
"""Maximum number of channels."""
return max(self.choices)
def get_choice(self, idx: int) -> int:
"""Get the `idx`-th choice from candidate choices."""
assert len(self.concat_mutables) == 0
num_channels = self.choices[idx]
assert num_channels > 0, \
f'Number of channels in `Mutable` {self.name}' \
f' should be a positive integer.'
return num_channels
@property
def current_choice(self):
"""The current choice of the mutable."""
if len(self.concat_mutables) > 0:
return sum(
[mutable.current_choice for mutable in self.concat_mutables])
else:
return self._current_choice
@current_choice.setter
def current_choice(self, choice: int):
"""Set the current choice of the mutable."""
assert choice in self.choices
self._current_choice = choice
@property
@abstractmethod
def choices(self) -> List[int]:
"""list: all choices. """
@property
def mask(self):
"""The current mask.
We slice the registered parameters and buffers of a ``nn.Module``
according to the mask of the corresponding channel mutable.
"""
if len(self.concat_mutables) > 0:
# If the input of a module is a concatenation of several modules'
# outputs, the in_mask of this module is the concatenation of
# these modules' out_mask.
return torch.cat(
[mutable.mask for mutable in self.concat_mutables])
else:
num_channels = self.current_choice
mask = torch.zeros_like(self._mask).bool()
mask[:num_channels] = True
return mask
def __repr__(self):
concat_mutable_name = [
mutable.name for mutable in self.concat_mutables
]
repr_str = self.__class__.__name__
repr_str += f'(name={self.name}, '
repr_str += f'mask_type={self.mask_type}, '
repr_str += f'num_channels={self.num_channels}, '
repr_str += f'concat_mutable_name={concat_mutable_name})'
return repr_str

View File

@ -0,0 +1,50 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union
from mmrazor.registry import MODELS
from .one_shot_channel_mutable import OneShotChannelMutable
@MODELS.register_module()
class OrderChannelMutable(OneShotChannelMutable):
"""A type of ``OneShotChannelMutable``. The input candidate choices are
candidate channel numbers.
Args:
name (str): Mutable name.
mask_type (str): One of 'in_mask' or 'out_mask'.
num_channels (int): The raw number of channels.
candidate_choices (list | tuple): A list or tuple of candidate
channel numbers.
init_cfg (dict, optional): initialization configuration dict for
``BaseModule``. OpenMMLab has implemented six initializers including
`Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
and `Pretrained`.
"""
def __init__(self,
name: str,
mask_type: str,
num_channels: int,
candidate_choices: Union[List, Tuple],
init_cfg: Optional[Dict] = None):
super(OrderChannelMutable, self).__init__(
name, mask_type, num_channels, init_cfg=init_cfg)
assert len(candidate_choices) > 0, \
f'Number of candidate choices must be greater than 0, ' \
f'but got: {len(candidate_choices)}'
self._candidate_choices = list(candidate_choices)
assert all([num > 0 and num <= self.num_channels
for num in self._candidate_choices]), \
f'The candidate channel numbers should be in ' \
f'range(0, {self.num_channels}].'
assert all([isinstance(num, int)
for num in self._candidate_choices]),\
'Type of `candidate_choices` should be int.'
@property
def choices(self) -> List[int]:
"""list: all choices. """
return self._candidate_choices
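
A short illustrative example (numbers and module name are made up): the candidate choices are used as channel numbers directly.

mutable = OrderChannelMutable(
    name='backbone.conv1', mask_type='out_mask',
    num_channels=32, candidate_choices=[8, 16, 32])
print(mutable.choices)     # [8, 16, 32]
print(mutable.max_choice)  # 32
mutable.current_choice = 8
print(mutable.mask.sum())  # tensor(8): only the first 8 channels are kept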

View File

@ -0,0 +1,59 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union
from mmrazor.registry import MODELS
from .one_shot_channel_mutable import OneShotChannelMutable
@MODELS.register_module()
class RatioChannelMutable(OneShotChannelMutable):
"""A type of ``OneShotChannelMutable``. The input candidate choices are
candidate width ratios.
Notes:
We first calculate the candidate channel numbers according to
the input candidate ratios (`candidate_choices`) and regard them as
available choices.
Args:
name (str): Mutable name.
mask_type (str): One of 'in_mask' or 'out_mask'.
num_channels (int): The raw number of channels.
candidate_choices (list | tuple): A list or tuple of candidate width
ratios. The width ratio is the ratio between the number of reserved
channels and that of all channels in a layer.
For example, if `candidate_choices` is [0.25, 0.5], there are 2 cases
for us to choose from when we sample from a layer with 12 channels.
One is sampling the very first 3 channels in this layer, another is
sampling the very first 6 channels in this layer.
init_cfg (dict, optional): initialization configuration dict for
``BaseModule``. OpenMMLab has implemented six initializers including
`Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
and `Pretrained`.
"""
def __init__(self,
name: str,
mask_type: str,
num_channels: int,
candidate_choices: Union[List, Tuple],
init_cfg: Optional[Dict] = None):
super(RatioChannelMutable, self).__init__(
name, mask_type, num_channels, init_cfg=init_cfg)
assert len(candidate_choices) > 0, \
f'Number of candidate choices must be greater than 0, ' \
f'but got: {len(candidate_choices)}'
self._candidate_choices = candidate_choices
assert all([
ratio > 0 and ratio <= 1 for ratio in self._candidate_choices
]), 'The candidate ratio should be in range(0, 1].'
@property
def choices(self) -> List[int]:
"""list: all choices. """
return [
round(ratio * self.num_channels)
for ratio in self._candidate_choices
]
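
An illustrative example matching the 12-channel case in the docstring above: the ratios are converted into channel numbers.

mutable = RatioChannelMutable(
    name='backbone.conv2', mask_type='out_mask',
    num_channels=12, candidate_choices=[0.25, 0.5])
print(mutable.choices)  # [3, 6]
mutable.current_choice = mutable.sample_choice()  # randomly 3 or 6
print(mutable.mask.sum())  # tensor(3) or tensor(6)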

View File

@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional
import torch
from mmcv.runner import BaseModule
from mmrazor.registry import MODELS
@MODELS.register_module()
class SlimmableChannelMutable(BaseModule):
"""A type of ``MUTABLES`` to train several subnet together, such as the
retraining stage in AutoSlim.
Notes:
`candidate_choices` needs to be set manually after a
`SlimmableChannelMutable` is instantiated.
Args:
name (str): Mutable name.
mask_type (str): One of 'in_mask' or 'out_mask'.
num_channels (int): The raw number of channels.
init_cfg (dict, optional): initialization configuration dict for
``BaseModule``. OpenMMLab has implemented six initializers including
`Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
and `Pretrained`.
"""
def __init__(self,
name: str,
mask_type: str,
num_channels: int,
init_cfg: Optional[Dict] = None):
super(SlimmableChannelMutable, self).__init__(init_cfg=init_cfg)
self.name = name
assert mask_type in ('in_mask', 'out_mask')
self.mask_type = mask_type
self.num_channels = num_channels
self.register_buffer('_mask', torch.ones((num_channels, )).bool())
self._current_choice = 0
@property
def candidate_choices(self) -> List:
"""A list of candidate channel numbers."""
return self._candidate_choices
@candidate_choices.setter
def candidate_choices(self, choices):
"""Set the candidate channel numbers."""
assert getattr(self, '_candidate_choices', None) is None, \
f'candidate_choices can be set only when candidate_choices is ' \
f'None, got: candidate_choices = {self._candidate_choices}'
assert all([num > 0 and num <= self.num_channels
for num in choices]), \
f'The candidate channel numbers should be in ' \
f'range(0, {self.num_channels}].'
assert all([isinstance(num, int) for num in choices]), \
'Type of `candidate_choices` should be int.'
self._candidate_choices = list(choices)
@property
def choices(self) -> List:
"""Return all subnet indexes."""
assert self._candidate_choices is not None
return list(range(len(self.candidate_choices)))
@property
def current_choice(self) -> int:
"""The current choice of the mutable."""
return self._current_choice
@current_choice.setter
def current_choice(self, choice: int):
"""Set the current choice of the mutable."""
assert choice in self.choices
self._current_choice = choice
@property
def mask(self):
"""The current mask.
We slice the registered parameters and buffers of a ``nn.Module``
according to the mask of the corresponding channel mutable.
"""
idx = self.current_choice
num_channels = self.candidate_choices[idx]
mask = torch.zeros_like(self._mask).bool()
mask[:num_channels] = True
return mask
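
An illustrative example (numbers and module name are made up): the current choice is an index into `candidate_choices`, which has to be registered after instantiation.

mutable = SlimmableChannelMutable(
    name='backbone.bn1', mask_type='out_mask', num_channels=32)
mutable.candidate_choices = [16, 24, 32]
print(mutable.choices)  # [0, 1, 2]: subnet indexes
mutable.current_choice = 1
print(mutable.mask.sum())  # tensor(24): the first 24 channels are kept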

View File

@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
class MutableManagerMixIn:
"""Mixin class for determining whether an object is a dynamic layer.
Note that a dynamic layer manages one or several mutables.
"""
pass

View File

@ -1,5 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .channel_mutator.one_shot_channel_mutator import OneShotChannelMutator
from .channel_mutator.slimmable_channel_mutator import SlimmableChannelMutator
from .diff_mutator import DiffMutator
from .one_shot_mutator import OneShotMutator
__all__ = ['OneShotMutator', 'DiffMutator']
__all__ = [
'OneShotMutator', 'OneShotChannelMutator', 'SlimmableChannelMutator',
'DiffMutator'
]

View File

@ -38,7 +38,7 @@ class BaseMutator(ABC, BaseModule):
@property
@abstractmethod
def search_group(self) -> Dict:
def search_groups(self) -> Dict:
"""Search group of the supernet.
Note:
@ -76,7 +76,7 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]):
if custom_group is None:
custom_group = []
self._custom_group = custom_group
self._search_group: Optional[Dict[int, List[MUTABLE_TYPE]]] = None
self._search_groups: Optional[Dict[int, List[MUTABLE_TYPE]]] = None
# TODO
# should be a class property
@ -99,10 +99,10 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]):
supernet (:obj:`torch.nn.Module`): The supernet to be searched
in your algorithm.
"""
self._build_search_group(supernet)
self._build_search_groups(supernet)
@property
def search_group(self) -> Dict[int, List[MUTABLE_TYPE]]:
def search_groups(self) -> Dict[int, List[MUTABLE_TYPE]]:
"""Search group of supernet.
Note:
@ -115,10 +115,10 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]):
Returns:
Dict[int, List[MUTABLE_TYPE]]: Search group.
"""
if self._search_group is None:
if self._search_groups is None:
raise RuntimeError(
'Call `prepare_from_supernet` before access search group!')
return self._search_group
return self._search_groups
def _build_name_mutable_mapping(
self, supernet: Module) -> Dict[str, MUTABLE_TYPE]:
@ -143,7 +143,7 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]):
return alias2mutable_names
def _build_search_group(self, supernet: Module) -> None:
def _build_search_groups(self, supernet: Module) -> None:
"""Build search group with ``custom_group`` and ``alias``(see more
information in :class:`BaseMutable`). Grouping by alias and module name
are both supported.
@ -176,20 +176,20 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]):
>>> # Using alias for grouping
>>> mutator = DiffOP(custom_group=[['a1'], ['a2']])
>>> mutator.prepare_from_supernet(model)
>>> mutator.search_group
>>> mutator.search_groups
{0: [op1, op2], 1: [op3]}
>>> # Using module name for grouping
>>> mutator = DiffOP(custom_group=[['op1', 'op2'], ['op3']])
>>> mutator.prepare_from_supernet(model)
>>> mutator.search_group
>>> mutator.search_groups
{0: [op1, op2], 1: [op3]}
>>> # Using both alias and module name for grouping
>>> mutator = DiffOP(custom_group=[['a2'], ['op2']])
>>> mutator.prepare_from_supernet(model)
>>> # The last operation would be grouped
>>> mutator.search_group
>>> mutator.search_groups
{0: [op3], 1: [op2], 2: [op1]}
@ -271,7 +271,7 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]):
f'The duplicate keys are {duplicate_keys}. ' \
'Please check if there are duplicate keys in the `custom_group`.'
self._search_group = search_groups
self._search_groups = search_groups
def _check_valid_groups(self, alias2mutable_names: Dict[str, List[str]],
name2mutable: Dict[str, MUTABLE_TYPE],

View File

@ -0,0 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .channel_mutator import ChannelMutator
from .one_shot_channel_mutator import OneShotChannelMutator
from .slimmable_channel_mutator import SlimmableChannelMutator
__all__ = [
'ChannelMutator', 'OneShotChannelMutator', 'SlimmableChannelMutator'
]

View File

@ -0,0 +1,277 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from abc import abstractmethod
from typing import Callable, Dict, List, Optional
import torch.nn as nn
from torch.nn import Module
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import (BatchNorm1d, BatchNorm2d, BatchNorm3d,
_BatchNorm)
from torch.nn.modules.instancenorm import (InstanceNorm1d, InstanceNorm2d,
InstanceNorm3d, _InstanceNorm)
from mmrazor.core.tracer import (ConcatNode, ConvNode, DepthWiseConvNode,
LinearNode, NormNode, PathList)
from mmrazor.models.architectures.dynamic_op import (build_dynamic_bn,
build_dynamic_conv2d,
build_dynamic_gn,
build_dynamic_in,
build_dynamic_linear)
from mmrazor.models.mutables import OneShotChannelMutable
from mmrazor.registry import MODELS, TASK_UTILS
from ..base_mutator import BaseMutator
NONPASS_NODES = (ConvNode, LinearNode, ConcatNode)
PASS_NODES = (NormNode, DepthWiseConvNode)
NONPASS_MODULES = (nn.Conv2d, nn.Linear)
PASS_MODULES = (_BatchNorm, _InstanceNorm, GroupNorm)
@MODELS.register_module()
class ChannelMutator(BaseMutator):
"""Base class for channel-based mutators.
Args:
mutable_cfg (dict): The config for the channel mutable.
tracer_cfg (dict, optional): The config for the model tracer.
We trace the topology of a given model with the tracer.
skip_prefixes (List[str], optional): Modules whose names start with
a string in `skip_prefixes` will not be pruned.
init_cfg (dict, optional): The config to control the initialization.
Attributes:
search_groups (Dict[int, List]): Search group of supernet. Note that
the search group of a mutable based channel mutator is composed of
corresponding mutables. Mutables in the same search group should
be pruned together.
name2module (Dict[str, :obj:`torch.nn.Module`]): The mapping from
a module name to the module.
Notes:
To avoid ambiguity, we only allow the following two cases:
1. None of the parent nodes of a node is a ``ConcatNode``.
2. A node has only one parent node, which is a ``ConcatNode``.
"""
def __init__(
self,
mutable_cfg: Dict,
tracer_cfg: Optional[Dict] = None,
skip_prefixes: Optional[List[str]] = None,
init_cfg: Optional[Dict] = None,
) -> None:
super().__init__(init_cfg)
self.mutable_cfg = mutable_cfg
if tracer_cfg:
self.tracer = TASK_UTILS.build(tracer_cfg)
else:
self.tracer = None
self.skip_prefixes = skip_prefixes
self._search_groups: Optional[Dict[int, List[Module]]] = None
def add_link(self, path_list: PathList) -> None:
"""Establish the relationship between the current nodes and their
parents."""
for path in path_list:
pre_node = None
for node in path:
if isinstance(node, DepthWiseConvNode):
module = self.name2module[node.name]
# The in_channels and out_channels of a depth-wise conv
# should be the same
module.mutable_out.register_same_mutable(module.mutable_in)
module.mutable_in.register_same_mutable(module.mutable_out)
if isinstance(node, ConcatNode):
if pre_node is not None:
module_names = node.get_module_names()
concat_modules = [
self.name2module[name] for name in module_names
]
concat_mutables = [
module.mutable_out for module in concat_modules
]
pre_module = self.name2module[pre_node.name]
pre_module.mutable_in.register_same_mutable(
concat_mutables)
for cur_path_list in node:
self.add_link(cur_path_list)
# ConcatNode is the last node in a path
break
if pre_node is None:
pre_node = node
continue
pre_module = self.name2module[pre_node.name]
cur_module = self.name2module[node.name]
pre_module.mutable_in.register_same_mutable(
cur_module.mutable_out)
cur_module.mutable_out.register_same_mutable(
pre_module.mutable_in)
pre_node = node
def prepare_from_supernet(self, supernet: Module) -> None:
"""Do some necessary preparations with supernet.
We support the following two cases:
Case 1: The input is the original nn.Module. We first replace the
conv/linear/norm modules in the input supernet with dynamic ops, and
then trace the topology of the supernet. Finally, `search_groups` can
be built based on the topology.
Case 2: The input supernet is already made up of dynamic ops. In this
case, the relationships between nodes and their parents have been
established and the topology of the supernet is available, so
`search_groups` can be built based on it directly.
Args:
supernet (:obj:`torch.nn.Module`): The supernet to be searched
in your algorithm.
"""
if self.tracer is not None:
self.convert_dynamic_module(supernet, self.dynamic_layer)
# The mapping from a module name to the module
self._name2module = dict(supernet.named_modules())
assert self.tracer is not None
module_path_list: PathList = self.tracer.trace(supernet)
self.add_link(module_path_list)
else:
self._name2module = dict(supernet.named_modules())
self._search_groups = self.build_search_groups(supernet)
@staticmethod
def find_same_mutables(supernet) -> Dict:
"""The mutables in the same group should be pruned together."""
visited = []
groups = {}
group_idx = 0
for name, module in supernet.named_modules():
if isinstance(module, OneShotChannelMutable):
same_mutables = module.same_mutables
if module not in visited and len(same_mutables) > 0:
groups[group_idx] = [module] + same_mutables
visited.extend(groups[group_idx])
group_idx += 1
return groups
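# For example (names are hypothetical): if `op1.mutable_out`,
# `op2.mutable_in` and `bn1.mutable_num_features` have been registered as
# the same mutable, the returned dict looks like
#   {0: [op1.mutable_out, op2.mutable_in, bn1.mutable_num_features]}
# i.e. every value is a list of mutables that must share one choice.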
def convert_dynamic_module(self, supernet: Module, dynamic_layer: Dict):
"""Replace the conv/linear/norm modules in the input supernet with
dynamic ops.
Args:
supernet (:obj:`torch.nn.Module`): The architecture to be converted
in your algorithm.
dynamic_layer (Dict): The mapping from the module type to the
corresponding dynamic layer.
"""
def traverse(module, prefix):
for name, child in module.named_children():
module_name = prefix + name
if isinstance(child, NONPASS_MODULES):
mutable_cfg = copy.deepcopy(self.mutable_cfg)
# mutable_cfg.update(dict(name=module_name))
layer = dynamic_layer[type(child)](child, module_name,
mutable_cfg,
mutable_cfg)
setattr(module, name, layer)
elif isinstance(child, PASS_MODULES):
mutable_cfg = copy.deepcopy(self.mutable_cfg)
# mutable_cfg.update(dict(name=module_name))
layer = dynamic_layer[type(child)](child, module_name,
mutable_cfg)
setattr(module, name, layer)
else:
traverse(child, module_name + '.')
traverse(supernet, '')
@abstractmethod
def build_search_groups(self, supernet: Module):
"""Build `search_groups`.
The mutables in the same group should be pruned together.
"""
@property
def search_groups(self) -> Dict[int, List]:
"""Search group of supernet.
Note:
For mutable based mutator, the search group is composed of
corresponding mutables.
Raises:
RuntimeError: Called before the search groups have been built.
Returns:
Dict[int, List[MUTABLE_TYPE]]: Search group.
"""
if self._search_groups is None:
raise RuntimeError(
'Call `build_search_groups` before accessing `search_groups`!')
return self._search_groups
@property
def name2module(self):
"""The mapping from a module name to the module.
Returns:
dict: The name to module mapping.
"""
if hasattr(self, '_name2module'):
return self._name2module
else:
raise RuntimeError(
'Call `prepare_from_supernet` before accessing `name2module`!')
@property
def dynamic_layer(self) -> Dict:
"""The mapping from a type to the corresponding dynamic layer. It is
called in `prepare_from_supernet`.
Returns:
dict: The mapping dict.
"""
dynamic_layer: Dict[Callable, Callable] = {
nn.Conv2d: build_dynamic_conv2d,
nn.Linear: build_dynamic_linear,
BatchNorm1d: build_dynamic_bn,
BatchNorm2d: build_dynamic_bn,
BatchNorm3d: build_dynamic_bn,
InstanceNorm1d: build_dynamic_in,
InstanceNorm2d: build_dynamic_in,
InstanceNorm3d: build_dynamic_in,
GroupNorm: build_dynamic_gn
}
return dynamic_layer
def is_skip_pruning(self, module_name: str,
skip_prefixes: Optional[List[str]]) -> bool:
"""Judge if the module with the input `module_name` should not be
pruned.
Args:
module_name (str): Module name.
skip_prefixes (list or None): Modules whose names start with a
string in ``skip_prefixes`` will not be pruned.
"""
skip_pruning = False
if skip_prefixes:
for prefix in skip_prefixes:
if module_name.startswith(prefix):
skip_pruning = True
break
return skip_pruning
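Since `build_search_groups` above is left abstract, a concrete mutator only has to turn the groups returned by `find_same_mutables` into the `search_groups` dict. A minimal sketch, assuming a hypothetical subclass named `ToyChannelMutator` (not part of this change):

from mmrazor.registry import MODELS
from .channel_mutator import ChannelMutator


@MODELS.register_module()
class ToyChannelMutator(ChannelMutator):
    """Hypothetical mutator: one search group per set of linked mutables."""

    def build_search_groups(self, supernet):
        # `find_same_mutables` returns {idx: [a mutable and its same_mutables]}.
        groups = self.find_same_mutables(supernet)
        # Re-index so that group indices stay contiguous.
        return {idx: group for idx, group in enumerate(groups.values())}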

View File

@ -0,0 +1,135 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Any, Dict, List, Optional
from torch.nn import Module
from mmrazor.core.tracer import (ConcatNode, ConvNode, DepthWiseConvNode,
LinearNode, NormNode)
from mmrazor.registry import MODELS
from .channel_mutator import ChannelMutator
NONPASS_NODES = (ConvNode, LinearNode, ConcatNode)
PASS_NODES = (NormNode, DepthWiseConvNode)
PRUNING_NODES = (ConvNode, LinearNode)
@MODELS.register_module()
class OneShotChannelMutator(ChannelMutator):
"""One-shot channel mutable based channel mutator.
Args:
mutable_cfg (dict): The config for the channel mutable.
tracer_cfg (dict): The config for the model tracer, which is used to
trace the topology of a given model.
skip_prefixes (List[str], optional): Modules whose names start with
a string in ``skip_prefixes`` will not be pruned.
init_cfg (dict, optional): The config to control the initialization.
"""
def __init__(self,
mutable_cfg: Dict,
tracer_cfg: Optional[Dict] = None,
skip_prefixes: Optional[List[str]] = None,
init_cfg: Optional[Dict] = None) -> None:
super().__init__(mutable_cfg, tracer_cfg, skip_prefixes, init_cfg)
def sample_choices(self):
"""Sample a choice that records a selection from the search space.
Returns:
dict: Record the information to build the subnet from the supernet.
Its keys are the group indices in the channel mutator's
``search_groups``, and its values are the sampled choices.
"""
choice_dict = dict()
for group_idx, mutables in self.search_groups.items():
choice_dict[group_idx] = mutables[0].sample_choice()
return choice_dict
def set_choices(self, choice_dict: Dict[int, Any]) -> None:
"""Set current subnet according to ``choice_dict``.
Args:
choice_dict (Dict[int, Any]): Choice dict.
"""
for group_idx, choice in choice_dict.items():
mutables = self.search_groups[group_idx]
for mutable in mutables:
mutable.current_choice = choice
def set_max_choices(self) -> None:
"""Set the channel numbers of each layer to maximum."""
for mutables in self.search_groups.values():
for mutable in mutables:
mutable.current_choice = mutable.max_choice
def set_min_choices(self) -> None:
"""Set the channel numbers of each layer to minimum."""
for mutables in self.search_groups.values():
for mutable in mutables:
mutable.current_choice = mutable.min_choice
# TODO: check search groups
def build_search_groups(self, supernet: Module):
"""Build `search_groups`. The mutables in the same group should be
pruned together.
Examples:
>>> class ResBlock(nn.Module):
... def __init__(self) -> None:
... super().__init__()
...
... self.op1 = nn.Conv2d(3, 8, 1)
... self.bn1 = nn.BatchNorm2d(8)
... self.op2 = nn.Conv2d(8, 8, 1)
... self.bn2 = nn.BatchNorm2d(8)
... self.op3 = nn.Conv2d(8, 8, 1)
...
... def forward(self, x):
... x1 = self.bn1(self.op1(x))
... x2 = self.bn2(self.op2(x1))
... x3 = self.op3(x2 + x1)
... return x3
>>> class ToyPseudoLoss:
...
... def __call__(self, model):
... pseudo_img = torch.rand(2, 3, 16, 16)
... pseudo_output = model(pseudo_img)
... return pseudo_output.sum()
>>> mutator = OneShotChannelMutator(
... tracer_cfg=dict(type='BackwardTracer',
... loss_calculator=ToyPseudoLoss()),
... mutable_cfg=dict(type='RatioChannelMutable',
... candidate_choices=[4 / 8, 1.0]))
>>> model = ResBlock()
>>> mutator.prepare_from_supernet(model)
>>> mutator.search_groups
{0: [RatioChannelMutable(name=op2, mask_type=out_mask, ...),
RatioChannelMutable(name=op1, mask_type=out_mask, ...),
RatioChannelMutable(name=op3, mask_type=in_mask, ...),
RatioChannelMutable(name=op2, mask_type=in_mask, ...),
RatioChannelMutable(name=bn2, mask_type=out_mask, ...),
RatioChannelMutable(name=bn1, mask_type=out_mask, ...)]}
"""
groups = self.find_same_mutables(supernet)
search_groups = dict()
group_idx = 0
for group in groups.values():
is_skip = False
for mutable in group:
if self.is_skip_pruning(mutable.name, self.skip_prefixes):
warnings.warn(f'Group {group} is not searchable due to'
f' skip_prefixes: {self.skip_prefixes}')
is_skip = True
break
if not is_skip:
search_groups[group_idx] = group
group_idx += 1
return search_groups
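A minimal sketch of how the sampling API above can drive one supernet training iteration in a sandwich-like fashion; `model`, `imgs`, `labels`, `optimizer` and `criterion` are placeholders, and `ToyPseudoLoss` is the toy loss from the docstring example above (none of this is part of the change):

mutator = OneShotChannelMutator(
    mutable_cfg=dict(
        type='RatioChannelMutable', candidate_choices=[4 / 8, 6 / 8, 1.0]),
    tracer_cfg=dict(
        type='BackwardTracer', loss_calculator=ToyPseudoLoss()))
mutator.prepare_from_supernet(model)


def train_step(imgs, labels):
    optimizer.zero_grad()
    # Largest subnet.
    mutator.set_max_choices()
    criterion(model(imgs), labels).backward()
    # A randomly sampled subnet.
    mutator.set_choices(mutator.sample_choices())
    criterion(model(imgs), labels).backward()
    # Smallest subnet.
    mutator.set_min_choices()
    criterion(model(imgs), labels).backward()
    optimizer.step()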

View File

@ -0,0 +1,148 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional
import torch.nn as nn
from torch.nn import Module
from torch.nn.modules.batchnorm import _BatchNorm
from mmrazor.models.architectures.dynamic_op import (DynamicBatchNorm,
build_switchable_bn)
from mmrazor.models.mutables import SlimmableChannelMutable
from mmrazor.registry import MODELS
from .channel_mutator import ChannelMutator
NONPASS_MODULES = (nn.Conv2d, nn.Linear)
PASS_MODULES = (_BatchNorm, )
@MODELS.register_module()
class SlimmableChannelMutator(ChannelMutator):
"""Slimmable channel mutable based channel mutator.
Args:
channel_cfgs (list[Dict]): A list of candidate channel configs.
mutable_cfg (dict): The config for the channel mutable.
skip_prefixes (List[str], optional): Modules whose names start with a
string in ``skip_prefixes`` will not be pruned.
init_cfg (dict, optional): The config to control the initialization.
"""
def __init__(self,
channel_cfgs: List[Dict],
mutable_cfg: Dict,
skip_prefixes: Optional[List[str]] = None,
init_cfg: Optional[Dict] = None):
super(SlimmableChannelMutator, self).__init__(
mutable_cfg=mutable_cfg,
skip_prefixes=skip_prefixes,
init_cfg=init_cfg)
self.channel_cfgs = self._merge_channel_cfgs(channel_cfgs)
def _merge_channel_cfgs(self, channel_cfgs: List[Dict]):
"""Merge several channel configs.
Args:
channel_cfgs (List[Dict]): A list of candidate channel configs to be merged.
"""
merged_channel_cfg = dict()
num_subnet = len(channel_cfgs)
for module_name in channel_cfgs[0].keys():
channels_per_layer = [
channel_cfgs[idx][module_name] for idx in range(num_subnet)
]
merged_channels_per_layer = dict()
for key in channels_per_layer[0].keys():
merged_channels = [
channels_per_layer[idx][key] for idx in range(num_subnet)
]
merged_channels_per_layer[key] = merged_channels
merged_channel_cfg[module_name] = merged_channels_per_layer
return merged_channel_cfg
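# A worked example of the merge above (names are illustrative): given
#   cfg0 = {'op1': {'current_choice': 8, 'origin_channels': 16}}
#   cfg1 = {'op1': {'current_choice': 16, 'origin_channels': 16}}
# `_merge_channel_cfgs([cfg0, cfg1])` returns
#   {'op1': {'current_choice': [8, 16], 'origin_channels': [16, 16]}}
# i.e. every leaf value becomes a list indexed by subnet.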
def prepare_from_supernet(self, supernet: Module) -> None:
"""Do some necessary preparations with supernet.
Note:
Different from `ChannelMutator`, only Case 1 in `ChannelMutator` is
supported: the input supernet should be made up of original
nn.Modules. We first replace the conv/linear/bn modules in the input
supernet with dynamic ops, then convert each ``DynamicBatchNorm`` in
the supernet to a ``SwitchableBatchNorm2d``, and finally set the
candidate channel numbers of the corresponding
`SlimmableChannelMutable`.
Args:
supernet (:obj:`torch.nn.Module`): The supernet to be searched
in your algorithm.
"""
self.convert_dynamic_module(supernet, self.dynamic_layer)
self.convert_switchable_bn(supernet)
self.set_candidate_choices(supernet)
# The mapping from a module name to the module
self._name2module = dict(supernet.named_modules())
def set_candidate_choices(self, supernet):
"""Set the ``candidate_choices`` of each ``SlimmableChannelMutable``.
Notes:
Different from ``OneShotChannelMutable``, ``candidate_choices`` is
optional when instantiating a ``SlimmableChannelMutable``.
"""
for name, module in supernet.named_modules():
if isinstance(module, SlimmableChannelMutable):
candidate_choices = self.channel_cfgs[name]['current_choice']
module.candidate_choices = candidate_choices
def convert_switchable_bn(self, supernet):
"""Replace ``DynamicBatchNorm`` in supernet with
``SwitchableBatchNorm2d``.
Args:
supernet (:obj:`torch.nn.Module`): The architecture to be converted
in your algorithm.
"""
def traverse(module, prefix):
for name, child in module.named_children():
module_name = prefix + name
if isinstance(child, DynamicBatchNorm):
mutable_cfg = copy.deepcopy(self.mutable_cfg)
key = module_name + '.mutable_num_features'
candidate_choices = self.channel_cfgs[key][
'current_choice']
mutable_cfg.update(
dict(candidate_choices=candidate_choices))
sbn = build_switchable_bn(child, module_name, mutable_cfg)
setattr(module, name, sbn)
else:
traverse(child, module_name + '.')
traverse(supernet, '')
def switch_choices(self, idx):
"""Switch the channel config of the supernet according to input `idx`.
If we train more than one subnet together, we need to switch the
channel_cfg from one to another during one training iteration.
Args:
idx (int): The index of the current subnet.
"""
for name, module in self.name2module.items():
if hasattr(module, 'mutable_out'):
module.mutable_out.current_choice = idx
if hasattr(module, 'mutable_in'):
module.mutable_in.current_choice = idx
def build_search_groups(self, supernet: Module):
"""Build `search_groups`.
The mutables in the same group should be pruned together.
"""
pass
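A minimal sketch of how `switch_choices` is meant to be used when several subnets are trained in one iteration; `channel_cfgs`, `model`, `imgs`, `labels`, `optimizer` and `criterion` are placeholders, and the mutable type is assumed to be `SlimmableChannelMutable` (none of this is part of the change):

mutator = SlimmableChannelMutator(
    channel_cfgs=channel_cfgs,  # one dict per subnet, e.g. loaded from yaml
    mutable_cfg=dict(type='SlimmableChannelMutable'))
mutator.prepare_from_supernet(model)


def train_step(imgs, labels):
    optimizer.zero_grad()
    # Train every subnet once by switching the channel config index.
    for subnet_idx in range(len(channel_cfgs)):
        mutator.switch_choices(subnet_idx)
        criterion(model(imgs), labels).backward()
    optimizer.step()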

View File

@ -45,12 +45,12 @@ class DiffMutator(ArchitectureMutator[DiffMutable]):
group_id share the same arch param.
Returns:
torch.nn.ParameterDict: the arch params are got by `search_group`.
torch.nn.ParameterDict: the arch params built from `search_groups`.
"""
arch_params: Dict[int, nn.Parameter] = dict()
for group_id, modules in self.search_group.items():
for group_id, modules in self.search_groups.items():
group_arch_param = modules[0].build_arch_param()
arch_params[group_id] = group_arch_param
@ -65,7 +65,7 @@ class DiffMutator(ArchitectureMutator[DiffMutable]):
`DiffMutable`.
"""
for group_id, modules in self.search_group.items():
for group_id, modules in self.search_groups.items():
for module in modules:
module.set_forward_args(arch_param=self.arch_params[group_id])

View File

@ -25,7 +25,7 @@ class OneShotMutator(ArchitectureMutator[OneShotMutable]):
['op1', 'op2', 'op3']
>>> mutator.prepare_from_supernet(supernet)
>>> mutator.search_group
>>> mutator.search_groups
{0: [op1], 1: [op2], 2: [op3]}
>>> random_choices = mutator.sample_choices()
@ -61,7 +61,7 @@ class OneShotMutator(ArchitectureMutator[OneShotMutable]):
Dict[int, Any]: Random choices dict.
"""
random_choices = dict()
for group_id, modules in self.search_group.items():
for group_id, modules in self.search_groups.items():
random_choices[group_id] = modules[0].sample_choice()
return random_choices
@ -75,7 +75,7 @@ class OneShotMutator(ArchitectureMutator[OneShotMutable]):
search groups, and the value is the sampling results
corresponding to this group.
"""
for group_id, modules in self.search_group.items():
for group_id, modules in self.search_groups.items():
choice = choices[group_id]
for module in modules:
module.current_choice = choice

View File

@ -1,6 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ratio_pruning import RatioPruner
from .structure_pruning import StructurePruner
from .utils import * # noqa: F401,F403
__all__ = ['RatioPruner', 'StructurePruner']

View File

@ -1,150 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules import GroupNorm
from mmrazor.registry import MODELS
from .structure_pruning import StructurePruner
from .utils import SwitchableBatchNorm2d
@MODELS.register_module()
class RatioPruner(StructurePruner):
"""A random ratio pruner.
Each layer can adjust its own width ratio randomly and independently.
Args:
ratios (list | tuple): Width ratio of each layer can be
chosen from `ratios` randomly. The width ratio is the ratio between
the number of reserved channels and that of all channels in a
layer. For example, if `ratios` is [0.25, 0.5], there are 2 cases
for us to choose from when we sample from a layer with 12 channels.
One is sampling the very first 3 channels in this layer, another is
sampling the very first 6 channels in this layer. Default to None.
"""
def __init__(self, ratios, **kwargs):
super(RatioPruner, self).__init__(**kwargs)
ratios = list(ratios)
ratios.sort()
self.ratios = ratios
self.min_ratio = ratios[0]
def _check_pruner(self, supernet):
for module in supernet.model.modules():
if isinstance(module, GroupNorm):
num_channels = module.num_channels
num_groups = module.num_groups
for ratio in self.ratios:
new_channels = int(round(num_channels * ratio))
assert (num_channels * ratio) % num_groups == 0, \
f'Expected number of channels in input of GroupNorm ' \
f'to be divisible by num_groups, but number of ' \
f'channels may be {new_channels} according to ' \
f'ratio {ratio} and num_groups={num_groups}'
def prepare_from_supernet(self, supernet):
super(RatioPruner, self).prepare_from_supernet(supernet)
def get_channel_mask(self, out_mask):
"""Randomly choose a width ratio of a layer from ``ratios``"""
out_channels = out_mask.size(1)
random_ratio = np.random.choice(self.ratios)
new_channels = int(round(out_channels * random_ratio))
assert new_channels > 0, \
'Output channels should be a positive integer.'
new_out_mask = torch.zeros_like(out_mask)
new_out_mask[:, :new_channels] = 1
return new_out_mask
def sample_subnet(self):
"""Random sample subnet by random mask.
Returns:
dict: Record the information to build the subnet from the supernet,
its keys are the properties ``space_id`` in the pruner's search
spaces, and its values are corresponding sampled out_mask.
"""
subnet_dict = dict()
for space_id, out_mask in self.channel_spaces.items():
subnet_dict[space_id] = self.get_channel_mask(out_mask)
return subnet_dict
def set_min_channel(self):
"""Set the number of channels each layer to minimum."""
subnet_dict = dict()
for space_id, out_mask in self.channel_spaces.items():
out_channels = out_mask.size(1)
random_ratio = self.min_ratio
new_channels = int(round(out_channels * random_ratio))
assert new_channels > 0, \
'Output channels should be a positive integer.'
new_out_mask = torch.zeros_like(out_mask)
new_out_mask[:, :new_channels] = 1
subnet_dict[space_id] = new_out_mask
self.set_subnet(subnet_dict)
def switch_subnet(self, channel_cfg, subnet_ind=None):
"""Switch the channel config of the supernet according to channel_cfg.
If we train more than one subnet together, we need to switch the
channel_cfg from one to another during one training iteration.
Args:
channel_cfg (dict): The channel config of a subnet. Key is space_id
and value is a dict which includes out_channels (and
in_channels if exists).
subnet_ind (int, optional): The index of the current subnet. If
we replace normal BatchNorm2d with ``SwitchableBatchNorm2d``,
we should switch the index of ``SwitchableBatchNorm2d`` when
switch subnet. Defaults to None.
"""
subnet_dict = dict()
for name, channels_per_layer in channel_cfg.items():
module = self.name2module[name]
if (isinstance(module, SwitchableBatchNorm2d)
and subnet_ind is not None):
# When switching bn we should switch index simultaneously
module.index = subnet_ind
continue
out_channels = channels_per_layer['out_channels']
out_mask = torch.zeros_like(module.out_mask)
out_mask[:, :out_channels] = 1
space_id = self.get_space_id(name)
if space_id in subnet_dict:
assert torch.equal(subnet_dict[space_id], out_mask)
elif space_id is not None:
subnet_dict[space_id] = out_mask
self.set_subnet(subnet_dict)
def convert_switchable_bn(self, module, num_bns):
"""Convert normal ``nn.BatchNorm2d`` to ``SwitchableBatchNorm2d``.
Args:
module (:obj:`torch.nn.Module`): The module to be converted.
num_bns (int): The number of ``nn.BatchNorm2d`` in a
``SwitchableBatchNorm2d``.
Return:
:obj:`torch.nn.Module`: The converted module. Each
``nn.BatchNorm2d`` in this module has been converted to a
``SwitchableBatchNorm2d``.
"""
module_output = module
if isinstance(module, nn.modules.batchnorm._BatchNorm):
module_output = SwitchableBatchNorm2d(module.num_features, num_bns)
for name, child in module.named_children():
module_output.add_module(
name, self.convert_switchable_bn(child, num_bns))
del module
return module_output

View File

@ -1,4 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .switchable_bn import SwitchableBatchNorm2d
__all__ = ['SwitchableBatchNorm2d']

View File

@ -1,38 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
class SwitchableBatchNorm2d(nn.Module):
"""Employs independent batch normalization for different switches in a
slimmable network.
To train slimmable networks, ``SwitchableBatchNorm2d`` privatizes all
batch normalization layers for each switch in a slimmable network.
Compared with the naive training approach, it solves the problem of feature
aggregation inconsistency between different switches by independently
normalizing the feature mean and variance during testing.
Args:
max_num_features (int): The maximum ``num_features`` among BatchNorm2d
in all the switches.
num_bns (int): The number of different switches in the slimmable
networks.
"""
def __init__(self, max_num_features, num_bns):
super(SwitchableBatchNorm2d, self).__init__()
self.max_num_features = max_num_features
# number of BatchNorm2d in a SwitchableBatchNorm2d
self.num_bns = num_bns
bns = []
for _ in range(num_bns):
bns.append(nn.BatchNorm2d(max_num_features))
self.bns = nn.ModuleList(bns)
# When switching bn we should switch index simultaneously
self.index = 0
def forward(self, input):
"""Forward computation according to the current switch of the slimmable
networks."""
return self.bns[self.index](input)

View File

@ -0,0 +1,474 @@
backbone.conv1.bn.mutable_num_features:
current_choice: 8
origin_channels: 48
backbone.conv1.conv.mutable_in_channels:
current_choice: 3
origin_channels: 3
backbone.conv1.conv.mutable_out_channels:
current_choice: 8
origin_channels: 48
backbone.conv2.bn.mutable_num_features:
current_choice: 1920
origin_channels: 1920
backbone.conv2.conv.mutable_in_channels:
current_choice: 280
origin_channels: 480
backbone.conv2.conv.mutable_out_channels:
current_choice: 1920
origin_channels: 1920
backbone.layer1.0.conv.0.bn.mutable_num_features:
current_choice: 8
origin_channels: 48
backbone.layer1.0.conv.0.conv.mutable_in_channels:
current_choice: 8
origin_channels: 48
backbone.layer1.0.conv.0.conv.mutable_out_channels:
current_choice: 8
origin_channels: 48
backbone.layer1.0.conv.1.bn.mutable_num_features:
current_choice: 8
origin_channels: 24
backbone.layer1.0.conv.1.conv.mutable_in_channels:
current_choice: 8
origin_channels: 48
backbone.layer1.0.conv.1.conv.mutable_out_channels:
current_choice: 8
origin_channels: 24
backbone.layer2.0.conv.0.bn.mutable_num_features:
current_choice: 96
origin_channels: 144
backbone.layer2.0.conv.0.conv.mutable_in_channels:
current_choice: 8
origin_channels: 24
backbone.layer2.0.conv.0.conv.mutable_out_channels:
current_choice: 96
origin_channels: 144
backbone.layer2.0.conv.1.bn.mutable_num_features:
current_choice: 96
origin_channels: 144
backbone.layer2.0.conv.1.conv.mutable_in_channels:
current_choice: 96
origin_channels: 144
backbone.layer2.0.conv.1.conv.mutable_out_channels:
current_choice: 96
origin_channels: 144
backbone.layer2.0.conv.2.bn.mutable_num_features:
current_choice: 16
origin_channels: 40
backbone.layer2.0.conv.2.conv.mutable_in_channels:
current_choice: 96
origin_channels: 144
backbone.layer2.0.conv.2.conv.mutable_out_channels:
current_choice: 16
origin_channels: 40
backbone.layer2.1.conv.0.bn.mutable_num_features:
current_choice: 96
origin_channels: 240
backbone.layer2.1.conv.0.conv.mutable_in_channels:
current_choice: 16
origin_channels: 40
backbone.layer2.1.conv.0.conv.mutable_out_channels:
current_choice: 96
origin_channels: 240
backbone.layer2.1.conv.1.bn.mutable_num_features:
current_choice: 96
origin_channels: 240
backbone.layer2.1.conv.1.conv.mutable_in_channels:
current_choice: 96
origin_channels: 240
backbone.layer2.1.conv.1.conv.mutable_out_channels:
current_choice: 96
origin_channels: 240
backbone.layer2.1.conv.2.bn.mutable_num_features:
current_choice: 16
origin_channels: 40
backbone.layer2.1.conv.2.conv.mutable_in_channels:
current_choice: 96
origin_channels: 240
backbone.layer2.1.conv.2.conv.mutable_out_channels:
current_choice: 16
origin_channels: 40
backbone.layer3.0.conv.0.bn.mutable_num_features:
current_choice: 96
origin_channels: 240
backbone.layer3.0.conv.0.conv.mutable_in_channels:
current_choice: 16
origin_channels: 40
backbone.layer3.0.conv.0.conv.mutable_out_channels:
current_choice: 96
origin_channels: 240
backbone.layer3.0.conv.1.bn.mutable_num_features:
current_choice: 96
origin_channels: 240
backbone.layer3.0.conv.1.conv.mutable_in_channels:
current_choice: 96
origin_channels: 240
backbone.layer3.0.conv.1.conv.mutable_out_channels:
current_choice: 96
origin_channels: 240
backbone.layer3.0.conv.2.bn.mutable_num_features:
current_choice: 24
origin_channels: 48
backbone.layer3.0.conv.2.conv.mutable_in_channels:
current_choice: 96
origin_channels: 240
backbone.layer3.0.conv.2.conv.mutable_out_channels:
current_choice: 24
origin_channels: 48
backbone.layer3.1.conv.0.bn.mutable_num_features:
current_choice: 144
origin_channels: 288
backbone.layer3.1.conv.0.conv.mutable_in_channels:
current_choice: 24
origin_channels: 48
backbone.layer3.1.conv.0.conv.mutable_out_channels:
current_choice: 144
origin_channels: 288
backbone.layer3.1.conv.1.bn.mutable_num_features:
current_choice: 144
origin_channels: 288
backbone.layer3.1.conv.1.conv.mutable_in_channels:
current_choice: 144
origin_channels: 288
backbone.layer3.1.conv.1.conv.mutable_out_channels:
current_choice: 144
origin_channels: 288
backbone.layer3.1.conv.2.bn.mutable_num_features:
current_choice: 24
origin_channels: 48
backbone.layer3.1.conv.2.conv.mutable_in_channels:
current_choice: 144
origin_channels: 288
backbone.layer3.1.conv.2.conv.mutable_out_channels:
current_choice: 24
origin_channels: 48
backbone.layer3.2.conv.0.bn.mutable_num_features:
current_choice: 144
origin_channels: 288
backbone.layer3.2.conv.0.conv.mutable_in_channels:
current_choice: 24
origin_channels: 48
backbone.layer3.2.conv.0.conv.mutable_out_channels:
current_choice: 144
origin_channels: 288
backbone.layer3.2.conv.1.bn.mutable_num_features:
current_choice: 144
origin_channels: 288
backbone.layer3.2.conv.1.conv.mutable_in_channels:
current_choice: 144
origin_channels: 288
backbone.layer3.2.conv.1.conv.mutable_out_channels:
current_choice: 144
origin_channels: 288
backbone.layer3.2.conv.2.bn.mutable_num_features:
current_choice: 24
origin_channels: 48
backbone.layer3.2.conv.2.conv.mutable_in_channels:
current_choice: 144
origin_channels: 288
backbone.layer3.2.conv.2.conv.mutable_out_channels:
current_choice: 24
origin_channels: 48
backbone.layer4.0.conv.0.bn.mutable_num_features:
current_choice: 144
origin_channels: 288
backbone.layer4.0.conv.0.conv.mutable_in_channels:
current_choice: 24
origin_channels: 48
backbone.layer4.0.conv.0.conv.mutable_out_channels:
current_choice: 144
origin_channels: 288
backbone.layer4.0.conv.1.bn.mutable_num_features:
current_choice: 144
origin_channels: 288
backbone.layer4.0.conv.1.conv.mutable_in_channels:
current_choice: 144
origin_channels: 288
backbone.layer4.0.conv.1.conv.mutable_out_channels:
current_choice: 144
origin_channels: 288
backbone.layer4.0.conv.2.bn.mutable_num_features:
current_choice: 48
origin_channels: 96
backbone.layer4.0.conv.2.conv.mutable_in_channels:
current_choice: 144
origin_channels: 288
backbone.layer4.0.conv.2.conv.mutable_out_channels:
current_choice: 48
origin_channels: 96
backbone.layer4.1.conv.0.bn.mutable_num_features:
current_choice: 288
origin_channels: 576
backbone.layer4.1.conv.0.conv.mutable_in_channels:
current_choice: 48
origin_channels: 96
backbone.layer4.1.conv.0.conv.mutable_out_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.1.conv.1.bn.mutable_num_features:
current_choice: 288
origin_channels: 576
backbone.layer4.1.conv.1.conv.mutable_in_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.1.conv.1.conv.mutable_out_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.1.conv.2.bn.mutable_num_features:
current_choice: 48
origin_channels: 96
backbone.layer4.1.conv.2.conv.mutable_in_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.1.conv.2.conv.mutable_out_channels:
current_choice: 48
origin_channels: 96
backbone.layer4.2.conv.0.bn.mutable_num_features:
current_choice: 288
origin_channels: 576
backbone.layer4.2.conv.0.conv.mutable_in_channels:
current_choice: 48
origin_channels: 96
backbone.layer4.2.conv.0.conv.mutable_out_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.2.conv.1.bn.mutable_num_features:
current_choice: 288
origin_channels: 576
backbone.layer4.2.conv.1.conv.mutable_in_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.2.conv.1.conv.mutable_out_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.2.conv.2.bn.mutable_num_features:
current_choice: 48
origin_channels: 96
backbone.layer4.2.conv.2.conv.mutable_in_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.2.conv.2.conv.mutable_out_channels:
current_choice: 48
origin_channels: 96
backbone.layer4.3.conv.0.bn.mutable_num_features:
current_choice: 288
origin_channels: 576
backbone.layer4.3.conv.0.conv.mutable_in_channels:
current_choice: 48
origin_channels: 96
backbone.layer4.3.conv.0.conv.mutable_out_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.3.conv.1.bn.mutable_num_features:
current_choice: 288
origin_channels: 576
backbone.layer4.3.conv.1.conv.mutable_in_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.3.conv.1.conv.mutable_out_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.3.conv.2.bn.mutable_num_features:
current_choice: 48
origin_channels: 96
backbone.layer4.3.conv.2.conv.mutable_in_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.3.conv.2.conv.mutable_out_channels:
current_choice: 48
origin_channels: 96
backbone.layer5.0.conv.0.bn.mutable_num_features:
current_choice: 288
origin_channels: 576
backbone.layer5.0.conv.0.conv.mutable_in_channels:
current_choice: 48
origin_channels: 96
backbone.layer5.0.conv.0.conv.mutable_out_channels:
current_choice: 288
origin_channels: 576
backbone.layer5.0.conv.1.bn.mutable_num_features:
current_choice: 288
origin_channels: 576
backbone.layer5.0.conv.1.conv.mutable_in_channels:
current_choice: 288
origin_channels: 576
backbone.layer5.0.conv.1.conv.mutable_out_channels:
current_choice: 288
origin_channels: 576
backbone.layer5.0.conv.2.bn.mutable_num_features:
current_choice: 64
origin_channels: 144
backbone.layer5.0.conv.2.conv.mutable_in_channels:
current_choice: 288
origin_channels: 576
backbone.layer5.0.conv.2.conv.mutable_out_channels:
current_choice: 64
origin_channels: 144
backbone.layer5.1.conv.0.bn.mutable_num_features:
current_choice: 432
origin_channels: 864
backbone.layer5.1.conv.0.conv.mutable_in_channels:
current_choice: 64
origin_channels: 144
backbone.layer5.1.conv.0.conv.mutable_out_channels:
current_choice: 432
origin_channels: 864
backbone.layer5.1.conv.1.bn.mutable_num_features:
current_choice: 432
origin_channels: 864
backbone.layer5.1.conv.1.conv.mutable_in_channels:
current_choice: 432
origin_channels: 864
backbone.layer5.1.conv.1.conv.mutable_out_channels:
current_choice: 432
origin_channels: 864
backbone.layer5.1.conv.2.bn.mutable_num_features:
current_choice: 64
origin_channels: 144
backbone.layer5.1.conv.2.conv.mutable_in_channels:
current_choice: 432
origin_channels: 864
backbone.layer5.1.conv.2.conv.mutable_out_channels:
current_choice: 64
origin_channels: 144
backbone.layer5.2.conv.0.bn.mutable_num_features:
current_choice: 432
origin_channels: 864
backbone.layer5.2.conv.0.conv.mutable_in_channels:
current_choice: 64
origin_channels: 144
backbone.layer5.2.conv.0.conv.mutable_out_channels:
current_choice: 432
origin_channels: 864
backbone.layer5.2.conv.1.bn.mutable_num_features:
current_choice: 432
origin_channels: 864
backbone.layer5.2.conv.1.conv.mutable_in_channels:
current_choice: 432
origin_channels: 864
backbone.layer5.2.conv.1.conv.mutable_out_channels:
current_choice: 432
origin_channels: 864
backbone.layer5.2.conv.2.bn.mutable_num_features:
current_choice: 64
origin_channels: 144
backbone.layer5.2.conv.2.conv.mutable_in_channels:
current_choice: 432
origin_channels: 864
backbone.layer5.2.conv.2.conv.mutable_out_channels:
current_choice: 64
origin_channels: 144
backbone.layer6.0.conv.0.bn.mutable_num_features:
current_choice: 648
origin_channels: 864
backbone.layer6.0.conv.0.conv.mutable_in_channels:
current_choice: 64
origin_channels: 144
backbone.layer6.0.conv.0.conv.mutable_out_channels:
current_choice: 648
origin_channels: 864
backbone.layer6.0.conv.1.bn.mutable_num_features:
current_choice: 648
origin_channels: 864
backbone.layer6.0.conv.1.conv.mutable_in_channels:
current_choice: 648
origin_channels: 864
backbone.layer6.0.conv.1.conv.mutable_out_channels:
current_choice: 648
origin_channels: 864
backbone.layer6.0.conv.2.bn.mutable_num_features:
current_choice: 176
origin_channels: 240
backbone.layer6.0.conv.2.conv.mutable_in_channels:
current_choice: 648
origin_channels: 864
backbone.layer6.0.conv.2.conv.mutable_out_channels:
current_choice: 176
origin_channels: 240
backbone.layer6.1.conv.0.bn.mutable_num_features:
current_choice: 720
origin_channels: 1440
backbone.layer6.1.conv.0.conv.mutable_in_channels:
current_choice: 176
origin_channels: 240
backbone.layer6.1.conv.0.conv.mutable_out_channels:
current_choice: 720
origin_channels: 1440
backbone.layer6.1.conv.1.bn.mutable_num_features:
current_choice: 720
origin_channels: 1440
backbone.layer6.1.conv.1.conv.mutable_in_channels:
current_choice: 720
origin_channels: 1440
backbone.layer6.1.conv.1.conv.mutable_out_channels:
current_choice: 720
origin_channels: 1440
backbone.layer6.1.conv.2.bn.mutable_num_features:
current_choice: 176
origin_channels: 240
backbone.layer6.1.conv.2.conv.mutable_in_channels:
current_choice: 720
origin_channels: 1440
backbone.layer6.1.conv.2.conv.mutable_out_channels:
current_choice: 176
origin_channels: 240
backbone.layer6.2.conv.0.bn.mutable_num_features:
current_choice: 720
origin_channels: 1440
backbone.layer6.2.conv.0.conv.mutable_in_channels:
current_choice: 176
origin_channels: 240
backbone.layer6.2.conv.0.conv.mutable_out_channels:
current_choice: 720
origin_channels: 1440
backbone.layer6.2.conv.1.bn.mutable_num_features:
current_choice: 720
origin_channels: 1440
backbone.layer6.2.conv.1.conv.mutable_in_channels:
current_choice: 720
origin_channels: 1440
backbone.layer6.2.conv.1.conv.mutable_out_channels:
current_choice: 720
origin_channels: 1440
backbone.layer6.2.conv.2.bn.mutable_num_features:
current_choice: 176
origin_channels: 240
backbone.layer6.2.conv.2.conv.mutable_in_channels:
current_choice: 720
origin_channels: 1440
backbone.layer6.2.conv.2.conv.mutable_out_channels:
current_choice: 176
origin_channels: 240
backbone.layer7.0.conv.0.bn.mutable_num_features:
current_choice: 1440
origin_channels: 1440
backbone.layer7.0.conv.0.conv.mutable_in_channels:
current_choice: 176
origin_channels: 240
backbone.layer7.0.conv.0.conv.mutable_out_channels:
current_choice: 1440
origin_channels: 1440
backbone.layer7.0.conv.1.bn.mutable_num_features:
current_choice: 1440
origin_channels: 1440
backbone.layer7.0.conv.1.conv.mutable_in_channels:
current_choice: 1440
origin_channels: 1440
backbone.layer7.0.conv.1.conv.mutable_out_channels:
current_choice: 1440
origin_channels: 1440
backbone.layer7.0.conv.2.bn.mutable_num_features:
current_choice: 280
origin_channels: 480
backbone.layer7.0.conv.2.conv.mutable_in_channels:
current_choice: 1440
origin_channels: 1440
backbone.layer7.0.conv.2.conv.mutable_out_channels:
current_choice: 280
origin_channels: 480
head.fc.mutable_in_features:
current_choice: 1920
origin_channels: 1920
head.fc.mutable_out_features:
current_choice: 1000
origin_channels: 1000

View File

@ -0,0 +1,474 @@
backbone.conv1.bn.mutable_num_features:
current_choice: 8
origin_channels: 48
backbone.conv1.conv.mutable_in_channels:
current_choice: 3
origin_channels: 3
backbone.conv1.conv.mutable_out_channels:
current_choice: 8
origin_channels: 48
backbone.conv2.bn.mutable_num_features:
current_choice: 1920
origin_channels: 1920
backbone.conv2.conv.mutable_in_channels:
current_choice: 480
origin_channels: 480
backbone.conv2.conv.mutable_out_channels:
current_choice: 1920
origin_channels: 1920
backbone.layer1.0.conv.0.bn.mutable_num_features:
current_choice: 8
origin_channels: 48
backbone.layer1.0.conv.0.conv.mutable_in_channels:
current_choice: 8
origin_channels: 48
backbone.layer1.0.conv.0.conv.mutable_out_channels:
current_choice: 8
origin_channels: 48
backbone.layer1.0.conv.1.bn.mutable_num_features:
current_choice: 8
origin_channels: 24
backbone.layer1.0.conv.1.conv.mutable_in_channels:
current_choice: 8
origin_channels: 48
backbone.layer1.0.conv.1.conv.mutable_out_channels:
current_choice: 8
origin_channels: 24
backbone.layer2.0.conv.0.bn.mutable_num_features:
current_choice: 96
origin_channels: 144
backbone.layer2.0.conv.0.conv.mutable_in_channels:
current_choice: 8
origin_channels: 24
backbone.layer2.0.conv.0.conv.mutable_out_channels:
current_choice: 96
origin_channels: 144
backbone.layer2.0.conv.1.bn.mutable_num_features:
current_choice: 96
origin_channels: 144
backbone.layer2.0.conv.1.conv.mutable_in_channels:
current_choice: 96
origin_channels: 144
backbone.layer2.0.conv.1.conv.mutable_out_channels:
current_choice: 96
origin_channels: 144
backbone.layer2.0.conv.2.bn.mutable_num_features:
current_choice: 16
origin_channels: 40
backbone.layer2.0.conv.2.conv.mutable_in_channels:
current_choice: 96
origin_channels: 144
backbone.layer2.0.conv.2.conv.mutable_out_channels:
current_choice: 16
origin_channels: 40
backbone.layer2.1.conv.0.bn.mutable_num_features:
current_choice: 96
origin_channels: 240
backbone.layer2.1.conv.0.conv.mutable_in_channels:
current_choice: 16
origin_channels: 40
backbone.layer2.1.conv.0.conv.mutable_out_channels:
current_choice: 96
origin_channels: 240
backbone.layer2.1.conv.1.bn.mutable_num_features:
current_choice: 96
origin_channels: 240
backbone.layer2.1.conv.1.conv.mutable_in_channels:
current_choice: 96
origin_channels: 240
backbone.layer2.1.conv.1.conv.mutable_out_channels:
current_choice: 96
origin_channels: 240
backbone.layer2.1.conv.2.bn.mutable_num_features:
current_choice: 16
origin_channels: 40
backbone.layer2.1.conv.2.conv.mutable_in_channels:
current_choice: 96
origin_channels: 240
backbone.layer2.1.conv.2.conv.mutable_out_channels:
current_choice: 16
origin_channels: 40
backbone.layer3.0.conv.0.bn.mutable_num_features:
current_choice: 96
origin_channels: 240
backbone.layer3.0.conv.0.conv.mutable_in_channels:
current_choice: 16
origin_channels: 40
backbone.layer3.0.conv.0.conv.mutable_out_channels:
current_choice: 96
origin_channels: 240
backbone.layer3.0.conv.1.bn.mutable_num_features:
current_choice: 96
origin_channels: 240
backbone.layer3.0.conv.1.conv.mutable_in_channels:
current_choice: 96
origin_channels: 240
backbone.layer3.0.conv.1.conv.mutable_out_channels:
current_choice: 96
origin_channels: 240
backbone.layer3.0.conv.2.bn.mutable_num_features:
current_choice: 24
origin_channels: 48
backbone.layer3.0.conv.2.conv.mutable_in_channels:
current_choice: 96
origin_channels: 240
backbone.layer3.0.conv.2.conv.mutable_out_channels:
current_choice: 24
origin_channels: 48
backbone.layer3.1.conv.0.bn.mutable_num_features:
current_choice: 144
origin_channels: 288
backbone.layer3.1.conv.0.conv.mutable_in_channels:
current_choice: 24
origin_channels: 48
backbone.layer3.1.conv.0.conv.mutable_out_channels:
current_choice: 144
origin_channels: 288
backbone.layer3.1.conv.1.bn.mutable_num_features:
current_choice: 144
origin_channels: 288
backbone.layer3.1.conv.1.conv.mutable_in_channels:
current_choice: 144
origin_channels: 288
backbone.layer3.1.conv.1.conv.mutable_out_channels:
current_choice: 144
origin_channels: 288
backbone.layer3.1.conv.2.bn.mutable_num_features:
current_choice: 24
origin_channels: 48
backbone.layer3.1.conv.2.conv.mutable_in_channels:
current_choice: 144
origin_channels: 288
backbone.layer3.1.conv.2.conv.mutable_out_channels:
current_choice: 24
origin_channels: 48
backbone.layer3.2.conv.0.bn.mutable_num_features:
current_choice: 144
origin_channels: 288
backbone.layer3.2.conv.0.conv.mutable_in_channels:
current_choice: 24
origin_channels: 48
backbone.layer3.2.conv.0.conv.mutable_out_channels:
current_choice: 144
origin_channels: 288
backbone.layer3.2.conv.1.bn.mutable_num_features:
current_choice: 144
origin_channels: 288
backbone.layer3.2.conv.1.conv.mutable_in_channels:
current_choice: 144
origin_channels: 288
backbone.layer3.2.conv.1.conv.mutable_out_channels:
current_choice: 144
origin_channels: 288
backbone.layer3.2.conv.2.bn.mutable_num_features:
current_choice: 24
origin_channels: 48
backbone.layer3.2.conv.2.conv.mutable_in_channels:
current_choice: 144
origin_channels: 288
backbone.layer3.2.conv.2.conv.mutable_out_channels:
current_choice: 24
origin_channels: 48
backbone.layer4.0.conv.0.bn.mutable_num_features:
current_choice: 144
origin_channels: 288
backbone.layer4.0.conv.0.conv.mutable_in_channels:
current_choice: 24
origin_channels: 48
backbone.layer4.0.conv.0.conv.mutable_out_channels:
current_choice: 144
origin_channels: 288
backbone.layer4.0.conv.1.bn.mutable_num_features:
current_choice: 144
origin_channels: 288
backbone.layer4.0.conv.1.conv.mutable_in_channels:
current_choice: 144
origin_channels: 288
backbone.layer4.0.conv.1.conv.mutable_out_channels:
current_choice: 144
origin_channels: 288
backbone.layer4.0.conv.2.bn.mutable_num_features:
current_choice: 56
origin_channels: 96
backbone.layer4.0.conv.2.conv.mutable_in_channels:
current_choice: 144
origin_channels: 288
backbone.layer4.0.conv.2.conv.mutable_out_channels:
current_choice: 56
origin_channels: 96
backbone.layer4.1.conv.0.bn.mutable_num_features:
current_choice: 288
origin_channels: 576
backbone.layer4.1.conv.0.conv.mutable_in_channels:
current_choice: 56
origin_channels: 96
backbone.layer4.1.conv.0.conv.mutable_out_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.1.conv.1.bn.mutable_num_features:
current_choice: 288
origin_channels: 576
backbone.layer4.1.conv.1.conv.mutable_in_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.1.conv.1.conv.mutable_out_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.1.conv.2.bn.mutable_num_features:
current_choice: 56
origin_channels: 96
backbone.layer4.1.conv.2.conv.mutable_in_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.1.conv.2.conv.mutable_out_channels:
current_choice: 56
origin_channels: 96
backbone.layer4.2.conv.0.bn.mutable_num_features:
current_choice: 288
origin_channels: 576
backbone.layer4.2.conv.0.conv.mutable_in_channels:
current_choice: 56
origin_channels: 96
backbone.layer4.2.conv.0.conv.mutable_out_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.2.conv.1.bn.mutable_num_features:
current_choice: 288
origin_channels: 576
backbone.layer4.2.conv.1.conv.mutable_in_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.2.conv.1.conv.mutable_out_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.2.conv.2.bn.mutable_num_features:
current_choice: 56
origin_channels: 96
backbone.layer4.2.conv.2.conv.mutable_in_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.2.conv.2.conv.mutable_out_channels:
current_choice: 56
origin_channels: 96
backbone.layer4.3.conv.0.bn.mutable_num_features:
current_choice: 288
origin_channels: 576
backbone.layer4.3.conv.0.conv.mutable_in_channels:
current_choice: 56
origin_channels: 96
backbone.layer4.3.conv.0.conv.mutable_out_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.3.conv.1.bn.mutable_num_features:
current_choice: 288
origin_channels: 576
backbone.layer4.3.conv.1.conv.mutable_in_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.3.conv.1.conv.mutable_out_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.3.conv.2.bn.mutable_num_features:
current_choice: 56
origin_channels: 96
backbone.layer4.3.conv.2.conv.mutable_in_channels:
current_choice: 288
origin_channels: 576
backbone.layer4.3.conv.2.conv.mutable_out_channels:
current_choice: 56
origin_channels: 96
backbone.layer5.0.conv.0.bn.mutable_num_features:
current_choice: 288
origin_channels: 576
backbone.layer5.0.conv.0.conv.mutable_in_channels:
current_choice: 56
origin_channels: 96
backbone.layer5.0.conv.0.conv.mutable_out_channels:
current_choice: 288
origin_channels: 576
backbone.layer5.0.conv.1.bn.mutable_num_features:
current_choice: 288
origin_channels: 576
backbone.layer5.0.conv.1.conv.mutable_in_channels:
current_choice: 288
origin_channels: 576
backbone.layer5.0.conv.1.conv.mutable_out_channels:
current_choice: 288
origin_channels: 576
backbone.layer5.0.conv.2.bn.mutable_num_features:
current_choice: 96
origin_channels: 144
backbone.layer5.0.conv.2.conv.mutable_in_channels:
current_choice: 288
origin_channels: 576
backbone.layer5.0.conv.2.conv.mutable_out_channels:
current_choice: 96
origin_channels: 144
backbone.layer5.1.conv.0.bn.mutable_num_features:
current_choice: 432
origin_channels: 864
backbone.layer5.1.conv.0.conv.mutable_in_channels:
current_choice: 96
origin_channels: 144
backbone.layer5.1.conv.0.conv.mutable_out_channels:
current_choice: 432
origin_channels: 864
backbone.layer5.1.conv.1.bn.mutable_num_features:
current_choice: 432
origin_channels: 864
backbone.layer5.1.conv.1.conv.mutable_in_channels:
current_choice: 432
origin_channels: 864
backbone.layer5.1.conv.1.conv.mutable_out_channels:
current_choice: 432
origin_channels: 864
backbone.layer5.1.conv.2.bn.mutable_num_features:
current_choice: 96
origin_channels: 144
backbone.layer5.1.conv.2.conv.mutable_in_channels:
current_choice: 432
origin_channels: 864
backbone.layer5.1.conv.2.conv.mutable_out_channels:
current_choice: 96
origin_channels: 144
backbone.layer5.2.conv.0.bn.mutable_num_features:
current_choice: 432
origin_channels: 864
backbone.layer5.2.conv.0.conv.mutable_in_channels:
current_choice: 96
origin_channels: 144
backbone.layer5.2.conv.0.conv.mutable_out_channels:
current_choice: 432
origin_channels: 864
backbone.layer5.2.conv.1.bn.mutable_num_features:
current_choice: 432
origin_channels: 864
backbone.layer5.2.conv.1.conv.mutable_in_channels:
current_choice: 432
origin_channels: 864
backbone.layer5.2.conv.1.conv.mutable_out_channels:
current_choice: 432
origin_channels: 864
backbone.layer5.2.conv.2.bn.mutable_num_features:
current_choice: 96
origin_channels: 144
backbone.layer5.2.conv.2.conv.mutable_in_channels:
current_choice: 432
origin_channels: 864
backbone.layer5.2.conv.2.conv.mutable_out_channels:
current_choice: 96
origin_channels: 144
backbone.layer6.0.conv.0.bn.mutable_num_features:
current_choice: 864
origin_channels: 864
backbone.layer6.0.conv.0.conv.mutable_in_channels:
current_choice: 96
origin_channels: 144
backbone.layer6.0.conv.0.conv.mutable_out_channels:
current_choice: 864
origin_channels: 864
backbone.layer6.0.conv.1.bn.mutable_num_features:
current_choice: 864
origin_channels: 864
backbone.layer6.0.conv.1.conv.mutable_in_channels:
current_choice: 864
origin_channels: 864
backbone.layer6.0.conv.1.conv.mutable_out_channels:
current_choice: 864
origin_channels: 864
backbone.layer6.0.conv.2.bn.mutable_num_features:
current_choice: 240
origin_channels: 240
backbone.layer6.0.conv.2.conv.mutable_in_channels:
current_choice: 864
origin_channels: 864
backbone.layer6.0.conv.2.conv.mutable_out_channels:
current_choice: 240
origin_channels: 240
backbone.layer6.1.conv.0.bn.mutable_num_features:
current_choice: 1440
origin_channels: 1440
backbone.layer6.1.conv.0.conv.mutable_in_channels:
current_choice: 240
origin_channels: 240
backbone.layer6.1.conv.0.conv.mutable_out_channels:
current_choice: 1440
origin_channels: 1440
backbone.layer6.1.conv.1.bn.mutable_num_features:
current_choice: 1440
origin_channels: 1440
backbone.layer6.1.conv.1.conv.mutable_in_channels:
current_choice: 1440
origin_channels: 1440
backbone.layer6.1.conv.1.conv.mutable_out_channels:
current_choice: 1440
origin_channels: 1440
backbone.layer6.1.conv.2.bn.mutable_num_features:
current_choice: 240
origin_channels: 240
backbone.layer6.1.conv.2.conv.mutable_in_channels:
current_choice: 1440
origin_channels: 1440
backbone.layer6.1.conv.2.conv.mutable_out_channels:
current_choice: 240
origin_channels: 240
backbone.layer6.2.conv.0.bn.mutable_num_features:
current_choice: 960
origin_channels: 1440
backbone.layer6.2.conv.0.conv.mutable_in_channels:
current_choice: 240
origin_channels: 240
backbone.layer6.2.conv.0.conv.mutable_out_channels:
current_choice: 960
origin_channels: 1440
backbone.layer6.2.conv.1.bn.mutable_num_features:
current_choice: 960
origin_channels: 1440
backbone.layer6.2.conv.1.conv.mutable_in_channels:
current_choice: 960
origin_channels: 1440
backbone.layer6.2.conv.1.conv.mutable_out_channels:
current_choice: 960
origin_channels: 1440
backbone.layer6.2.conv.2.bn.mutable_num_features:
current_choice: 240
origin_channels: 240
backbone.layer6.2.conv.2.conv.mutable_in_channels:
current_choice: 960
origin_channels: 1440
backbone.layer6.2.conv.2.conv.mutable_out_channels:
current_choice: 240
origin_channels: 240
backbone.layer7.0.conv.0.bn.mutable_num_features:
current_choice: 1440
origin_channels: 1440
backbone.layer7.0.conv.0.conv.mutable_in_channels:
current_choice: 240
origin_channels: 240
backbone.layer7.0.conv.0.conv.mutable_out_channels:
current_choice: 1440
origin_channels: 1440
backbone.layer7.0.conv.1.bn.mutable_num_features:
current_choice: 1440
origin_channels: 1440
backbone.layer7.0.conv.1.conv.mutable_in_channels:
current_choice: 1440
origin_channels: 1440
backbone.layer7.0.conv.1.conv.mutable_out_channels:
current_choice: 1440
origin_channels: 1440
backbone.layer7.0.conv.2.bn.mutable_num_features:
current_choice: 480
origin_channels: 480
backbone.layer7.0.conv.2.conv.mutable_in_channels:
current_choice: 1440
origin_channels: 1440
backbone.layer7.0.conv.2.conv.mutable_out_channels:
current_choice: 480
origin_channels: 480
head.fc.mutable_in_features:
current_choice: 1920
origin_channels: 1920
head.fc.mutable_out_features:
current_choice: 1000
origin_channels: 1000

View File

@ -0,0 +1,24 @@
op1.mutable_in_channels:
current_choice: 3
origin_channels: 3
op1.mutable_out_channels:
current_choice: 4
origin_channels: 8
bn1.mutable_num_features:
current_choice: 4
origin_channels: 8
op2.mutable_in_channels:
current_choice: 4
origin_channels: 8
op2.mutable_out_channels:
current_choice: 4
origin_channels: 8
bn2.mutable_num_features:
current_choice: 4
origin_channels: 8
op3.mutable_in_channels:
current_choice: 4
origin_channels: 8
op3.mutable_out_channels:
current_choice: 8
origin_channels: 8

View File

@ -0,0 +1,24 @@
op1.mutable_in_channels:
current_choice: 3
origin_channels: 3
op1.mutable_out_channels:
current_choice: 8
origin_channels: 8
bn1.mutable_num_features:
current_choice: 8
origin_channels: 8
op2.mutable_in_channels:
current_choice: 8
origin_channels: 8
op2.mutable_out_channels:
current_choice: 8
origin_channels: 8
bn2.mutable_num_features:
current_choice: 8
origin_channels: 8
op3.mutable_in_channels:
current_choice: 8
origin_channels: 8
op3.mutable_out_channels:
current_choice: 8
origin_channels: 8

View File

@ -0,0 +1,260 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import pytest
import torch
from torch import Tensor, nn
from torch.nn import Module
from mmrazor.core import (BackwardTracer, ConcatNode, ConvNode,
DepthWiseConvNode, LinearNode, NormNode, Path,
PathList)
NONPASS_NODES = (ConvNode, LinearNode, ConcatNode)
PASS_NODES = (NormNode, DepthWiseConvNode)
class MultiConcatModel(Module):
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 1)
self.op2 = nn.Conv2d(3, 8, 1)
self.op3 = nn.Conv2d(16, 8, 1)
self.op4 = nn.Conv2d(3, 8, 1)
def forward(self, x: Tensor) -> Tensor:
x1 = self.op1(x)
x2 = self.op2(x)
cat1 = torch.cat([x1, x2], dim=1)
x3 = self.op3(cat1)
x4 = self.op4(x)
output = torch.cat([x3, x4], dim=1)
return output
class MultiConcatModel2(Module):
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 1)
self.op2 = nn.Conv2d(3, 8, 1)
self.op3 = nn.Conv2d(3, 8, 1)
self.op4 = nn.Conv2d(24, 8, 1)
def forward(self, x: Tensor) -> Tensor:
x1 = self.op1(x)
x2 = self.op2(x)
x3 = self.op3(x)
cat1 = torch.cat([x1, x2], dim=1)
cat2 = torch.cat([cat1, x3], dim=1)
output = self.op4(cat2)
return output
class ResBlock(Module):
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 1)
self.bn1 = nn.BatchNorm2d(8)
self.op2 = nn.Conv2d(8, 8, 1)
self.bn2 = nn.BatchNorm2d(8)
self.op3 = nn.Conv2d(8, 8, 1)
def forward(self, x: Tensor) -> Tensor:
x1 = self.bn1(self.op1(x))
x2 = self.bn2(self.op2(x1))
x3 = self.op3(x2 + x1)
return x3
class ToyCNNPseudoLoss:
def __call__(self, model):
pseudo_img = torch.rand(2, 3, 16, 16)
pseudo_output = model(pseudo_img)
return pseudo_output.sum()
class TestBackwardTracer(TestCase):
def test_trace_resblock(self) -> None:
model = ResBlock()
loss_calculator = ToyCNNPseudoLoss()
tracer = BackwardTracer(loss_calculator=loss_calculator)
path_list = tracer.trace(model)
# test tracer and parser
assert len(path_list) == 2
assert len(path_list[0]) == 5
# test path_list
nonpass2parents = path_list.find_nodes_parents(NONPASS_NODES)
assert len(nonpass2parents) == 3
assert nonpass2parents['op1'] == list()
assert nonpass2parents['op2'] == list({NormNode('bn1')})
assert nonpass2parents['op3'] == list(
{NormNode('bn2'), NormNode('bn1')})
nonpass2nonpassparents = path_list.find_nodes_parents(
NONPASS_NODES, non_pass=NONPASS_NODES)
assert len(nonpass2nonpassparents) == 3
assert nonpass2nonpassparents['op1'] == list()
assert nonpass2nonpassparents['op2'] == list({ConvNode('op1')})
assert nonpass2nonpassparents['op3'] == list(
{ConvNode('op2'), ConvNode('op1')})
pass2nonpassparents = path_list.find_nodes_parents(
PASS_NODES, non_pass=NONPASS_NODES)
assert len(pass2nonpassparents) == 2
assert pass2nonpassparents['bn1'] == list({ConvNode('op1')})
assert pass2nonpassparents['bn2'] == list({ConvNode('op2')})
def test_trace_multi_cat(self) -> None:
loss_calculator = ToyCNNPseudoLoss()
model = MultiConcatModel()
tracer = BackwardTracer(loss_calculator=loss_calculator)
path_list = tracer.trace(model)
assert len(path_list) == 1
nonpass2parents = path_list.find_nodes_parents(NONPASS_NODES)
assert len(nonpass2parents) == 4
assert nonpass2parents['op1'] == list()
assert nonpass2parents['op2'] == list()
path_list1 = PathList(Path(ConvNode('op1')))
path_list2 = PathList(Path(ConvNode('op2')))
# only one parent
assert len(nonpass2parents['op3']) == 1
assert isinstance(nonpass2parents['op3'][0], ConcatNode)
assert len(nonpass2parents['op3'][0]) == 2
assert nonpass2parents['op3'][0].get_module_names() == ['op1', 'op2']
assert nonpass2parents['op3'][0].path_lists == [path_list1, path_list2]
assert nonpass2parents['op3'][0][0] == path_list1
assert nonpass2parents['op4'] == list()
model = MultiConcatModel2()
tracer = BackwardTracer(loss_calculator=loss_calculator)
path_list = tracer.trace(model)
assert len(path_list) == 1
nonpass2parents = path_list.find_nodes_parents(NONPASS_NODES)
assert len(nonpass2parents) == 4
assert nonpass2parents['op1'] == list()
assert nonpass2parents['op2'] == list()
assert nonpass2parents['op3'] == list()
# only one parent
assert len(nonpass2parents['op4']) == 1
assert isinstance(nonpass2parents['op4'][0], ConcatNode)
assert nonpass2parents['op4'][0].get_module_names() == [
'op1', 'op2', 'op3'
]
def test_repr(self):
toy_node = ConvNode('op1')
assert repr(toy_node) == 'ConvNode(\'op1\')'
toy_path = Path([ConvNode('op1'), ConvNode('op2')])
assert repr(
toy_path) == 'Path(\n ConvNode(\'op1\'),\n ConvNode(\'op2\')\n)'
toy_path_list = PathList(Path(ConvNode('op1')))
assert repr(toy_path_list
) == 'PathList(\n Path(\n ConvNode(\'op1\')\n )\n)'
path_list1 = PathList(Path(ConvNode('op1')))
path_list2 = PathList(Path(ConvNode('op2')))
toy_concat_node = ConcatNode('op3', [path_list1, path_list2])
assert repr(
toy_concat_node
) == 'ConcatNode(\n PathList(\n Path(\n ConvNode(\'op1\')\n )\n ),\n PathList(\n Path(\n ConvNode(\'op2\')\n )\n )\n)' # noqa: E501
def test_reset_bn_running_stats(self):
_test_reset_bn_running_stats(False)
with pytest.raises(AssertionError):
_test_reset_bn_running_stats(True)
def test_node(self):
node1 = ConvNode('conv1')
node2 = ConvNode('conv2')
assert node1 != node2
node1 = ConvNode('conv1')
node2 = ConvNode('conv1')
assert node1 == node2
def test_path(self):
node1 = ConvNode('conv1')
node2 = ConvNode('conv2')
path1 = Path([node1])
path2 = Path([node2])
assert path1 != path2
path1 = Path([node1])
path2 = Path([node1])
assert path1 == path2
assert path1[0] == node1
def test_path_list(self):
node1 = ConvNode('conv1')
node2 = ConvNode('conv2')
path1 = Path([node1])
path2 = Path([node2])
assert PathList(path1) == PathList([path1])
assert PathList(path1) != PathList(path2)
with self.assertRaisesRegex(AssertionError, ''):
_ = PathList({})
def _test_reset_bn_running_stats(should_fail):
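    """Check that tracing does not corrupt BatchNorm running statistics.

    Tracing runs a forward/backward pass on a pseudo loss, which updates the
    BN running stats; the tracer is expected to reset them afterwards. Two
    models built from the same RNG state should therefore produce identical
    eval outputs no matter which seed was active during tracing. Disabling
    the reset (the ``should_fail`` branch) breaks this equality.
    """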
import os
import random
import numpy as np
def set_seed(seed: int) -> None:
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
set_seed(1024)
imgs = torch.randn(2, 3, 4, 4)
loss_calculator = ToyCNNPseudoLoss()
tracer = BackwardTracer(loss_calculator=loss_calculator)
if should_fail:
tracer._reset_norm_running_stats = lambda *_: None
torch_rng_state = torch.get_rng_state()
np_rng_state = np.random.get_state()
random_rng_state = random.getstate()
model1 = ResBlock()
set_seed(1)
tracer.trace(model1)
model1.eval()
output1 = model1(imgs)
set_seed(1024)
torch.set_rng_state(torch_rng_state)
np.random.set_state(np_rng_state)
random.setstate(random_rng_state)
model2 = ResBlock()
set_seed(2)
tracer.trace(model2)
model2.eval()
output2 = model2(imgs)
assert torch.equal(output1.norm(p='fro'), output2.norm(p='fro'))

View File

@ -0,0 +1,164 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest import TestCase
import pytest
import torch
from mmrazor.models import OrderChannelMutable, RatioChannelMutable
class TestChannelMutables(TestCase):
def test_ratio_channel_mutable(self):
with pytest.raises(AssertionError):
# Test invalid `mask_type`
RatioChannelMutable(
name='op',
mask_type='xxx',
num_channels=8,
candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0])
with pytest.raises(AssertionError):
# Number of candidate choices must be greater than 0
RatioChannelMutable(
name='op',
mask_type='out_mask',
num_channels=8,
candidate_choices=list())
with pytest.raises(AssertionError):
# The candidate ratio should be in range(0, 1].
RatioChannelMutable(
name='op',
mask_type='out_mask',
num_channels=8,
candidate_choices=[0., 1 / 4, 2 / 4, 3 / 4, 1.0])
with pytest.raises(AssertionError):
# Minimum number of channels should be a positive integer.
out_mutable = RatioChannelMutable(
name='op',
mask_type='out_mask',
num_channels=8,
candidate_choices=[0.01, 1 / 4, 2 / 4, 3 / 4, 1.0])
_ = out_mutable.min_choice
with pytest.raises(AssertionError):
# Minimum number of channels should be a positive integer.
out_mutable = RatioChannelMutable(
name='op',
mask_type='out_mask',
num_channels=8,
candidate_choices=[0.01, 1 / 4, 2 / 4, 3 / 4, 1.0])
out_mutable.get_choice(0)
# Test out_mutable (mask_type == 'out_mask')
out_mutable = RatioChannelMutable(
name='op',
mask_type='out_mask',
num_channels=8,
candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0])
random_choice = out_mutable.sample_choice()
assert random_choice in [2, 4, 6, 8]
choice = out_mutable.get_choice(0)
assert choice == 2
max_choice = out_mutable.max_choice
assert max_choice == 8
out_mutable.current_choice = max_choice
assert torch.equal(out_mutable.mask,
torch.ones_like(out_mutable.mask).bool())
min_choice = out_mutable.min_choice
assert min_choice == 2
out_mutable.current_choice = min_choice
min_mask = torch.zeros_like(out_mutable.mask).bool()
min_mask[:2] = True
assert torch.equal(out_mutable.mask, min_mask)
with pytest.raises(AssertionError):
# Only mutables with mask_type == 'in_mask' (named in_mutable) can
# add `concat_mutables`
concat_mutables = [copy.deepcopy(out_mutable)] * 2
out_mutable.register_same_mutable(concat_mutables)
# Test in_mutable (mask_type == 'in_mask') with concat_mutable
in_mutable = RatioChannelMutable(
name='op',
mask_type='in_mask',
num_channels=16,
candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0])
out_mutable1 = copy.deepcopy(out_mutable)
out_mutable2 = copy.deepcopy(out_mutable)
in_mutable.register_same_mutable([out_mutable1, out_mutable2])
choice1 = out_mutable1.sample_choice()
out_mutable1.current_choice = choice1
choice2 = out_mutable2.sample_choice()
out_mutable2.current_choice = choice2
assert in_mutable.current_choice == choice1 + choice2
assert torch.equal(in_mutable.mask,
torch.cat([out_mutable1.mask, out_mutable2.mask]))
with pytest.raises(AssertionError):
# The mask of this in_mutable depends on the out mask of its
# `concat_mutables`, so the `sample_choice` method should not
# be called
in_mutable.sample_choice()
with pytest.raises(AssertionError):
# The mask of this in_mutable depends on the out mask of its
# `concat_mutables`, so the `min_choice` property should not
# be called
_ = in_mutable.min_choice
with pytest.raises(AssertionError):
# The mask of this in_mutable depends on the out mask of its
# `concat_mutables`, so the `get_choice` method should not
# be called
in_mutable.get_choice(0)
def test_order_channel_mutable(self):
with pytest.raises(AssertionError):
            # The candidate choices should be in range(0, `num_channels`].
OrderChannelMutable(
name='op',
mask_type='out_mask',
num_channels=8,
candidate_choices=[0, 2, 4, 6, 8])
with pytest.raises(AssertionError):
            # Each element of `candidate_choices` should be an int.
OrderChannelMutable(
name='op',
mask_type='out_mask',
num_channels=8,
candidate_choices=[0., 2, 4, 6, 8])
# Test out_mutable (mask_type == 'out_mask')
out_mutable = OrderChannelMutable(
name='op',
mask_type='out_mask',
num_channels=8,
candidate_choices=[2, 4, 6, 8])
random_choice = out_mutable.sample_choice()
assert random_choice in [2, 4, 6, 8]
choice = out_mutable.get_choice(0)
assert choice == 2
max_choice = out_mutable.max_choice
assert max_choice == 8
out_mutable.current_choice = max_choice
assert torch.equal(out_mutable.mask,
torch.ones_like(out_mutable.mask).bool())
min_choice = out_mutable.min_choice
assert min_choice == 2
out_mutable.current_choice = min_choice
min_mask = torch.zeros_like(out_mutable.mask).bool()
min_mask[:2] = True
assert torch.equal(out_mutable.mask, min_mask)

View File

@ -0,0 +1,136 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from torch import nn
from mmrazor.models.architectures.dynamic_op import (build_dynamic_bn,
build_dynamic_conv2d,
build_dynamic_gn,
build_dynamic_in,
build_dynamic_linear)
class TestDynamicLayer(TestCase):
def test_dynamic_conv(self):
imgs = torch.rand(2, 8, 16, 16)
in_channels_cfg = dict(
type='RatioChannelMutable',
candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0])
out_channels_cfg = dict(
type='RatioChannelMutable',
candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0])
conv = nn.Conv2d(8, 8, 1)
dynamic_conv = build_dynamic_conv2d(conv, 'op', in_channels_cfg,
out_channels_cfg)
# test forward
dynamic_conv(imgs)
conv = nn.Conv2d(8, 8, 1, groups=8)
dynamic_conv = build_dynamic_conv2d(conv, 'op', in_channels_cfg,
out_channels_cfg)
# test forward
dynamic_conv(imgs)
conv = nn.Conv2d(8, 8, 1, groups=4)
dynamic_conv = build_dynamic_conv2d(conv, 'op', in_channels_cfg,
out_channels_cfg)
# test forward
with self.assertRaisesRegex(NotImplementedError,
'only support pruning the depth-wise'):
dynamic_conv(imgs)
def test_dynamic_linear(self):
imgs = torch.rand(2, 8)
in_features_cfg = dict(
type='RatioChannelMutable',
candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0])
out_features_cfg = dict(
type='RatioChannelMutable',
candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0])
linear = nn.Linear(8, 8)
dynamic_linear = build_dynamic_linear(linear, 'op', in_features_cfg,
out_features_cfg)
# test forward
dynamic_linear(imgs)
def test_dynamic_batchnorm(self):
imgs = torch.rand(2, 8, 16, 16)
num_features_cfg = dict(
type='RatioChannelMutable',
candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0])
bn = nn.BatchNorm2d(8)
dynamic_bn = build_dynamic_bn(bn, 'bn', num_features_cfg)
# test forward
dynamic_bn(imgs)
bn = nn.BatchNorm2d(8, momentum=0)
dynamic_bn = build_dynamic_bn(bn, 'bn', num_features_cfg)
# test forward
dynamic_bn(imgs)
bn = nn.BatchNorm2d(8)
bn.train()
dynamic_bn = build_dynamic_bn(bn, 'bn', num_features_cfg)
# test forward
dynamic_bn(imgs)
# test num_batches_tracked is not None
dynamic_bn(imgs)
bn = nn.BatchNorm2d(8, affine=False)
dynamic_bn = build_dynamic_bn(bn, 'bn', num_features_cfg)
# test forward
dynamic_bn(imgs)
bn = nn.BatchNorm2d(8, track_running_stats=False)
dynamic_bn = build_dynamic_bn(bn, 'bn', num_features_cfg)
# test forward
dynamic_bn(imgs)
def test_dynamic_instancenorm(self):
imgs = torch.rand(2, 8, 16, 16)
num_features_cfg = dict(
type='RatioChannelMutable',
candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0])
instance_norm = nn.InstanceNorm2d(8)
dynamic_in = build_dynamic_in(instance_norm, 'in', num_features_cfg)
# test forward
dynamic_in(imgs)
instance_norm = nn.InstanceNorm2d(8, affine=False)
dynamic_in = build_dynamic_in(instance_norm, 'in', num_features_cfg)
# test forward
dynamic_in(imgs)
instance_norm = nn.InstanceNorm2d(8, track_running_stats=False)
dynamic_in = build_dynamic_in(instance_norm, 'in', num_features_cfg)
# test forward
dynamic_in(imgs)
def test_dynamic_groupnorm(self):
imgs = torch.rand(2, 8, 16, 16)
num_channels_cfg = dict(
type='RatioChannelMutable',
candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0])
gn = nn.GroupNorm(num_groups=4, num_channels=8)
dynamic_gn = build_dynamic_gn(gn, 'gn', num_channels_cfg)
# test forward
dynamic_gn(imgs)
gn = nn.GroupNorm(num_groups=4, num_channels=8, affine=False)
dynamic_gn = build_dynamic_gn(gn, 'gn', num_channels_cfg)
# test forward
dynamic_gn(imgs)

View File

@ -0,0 +1,229 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import unittest
from os.path import dirname
import mmcv.fileio
import pytest
import torch
from mmcls.models import * # noqa: F401,F403
from torch import Tensor, nn
from torch.nn import Module
from mmrazor import digit_version
from mmrazor.models.architectures.dynamic_op import (build_dynamic_bn,
build_dynamic_conv2d)
from mmrazor.models.mutables import SlimmableChannelMutable
from mmrazor.models.mutators import (OneShotChannelMutator,
SlimmableChannelMutator)
from mmrazor.registry import MODELS
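# Two mutator configs: the first uses a BackwardTracer to discover prunable
# channels from the model graph; the second omits the tracer and assumes the
# supernet already contains dynamic ops whose mutables are linked by hand
# (see DynamicResBlock below).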
ONESHOT_MUTATOR_CFG = dict(
type='OneShotChannelMutator',
tracer_cfg=dict(
type='BackwardTracer',
loss_calculator=dict(type='ImageClassifierPseudoLoss')),
mutable_cfg=dict(
type='RatioChannelMutable',
candidate_choices=[
1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0
]))
ONESHOT_MUTATOR_CFG_WITHOUT_TRACER = dict(
type='OneShotChannelMutator',
mutable_cfg=dict(
type='RatioChannelMutable',
candidate_choices=[
1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0
]))
class MultiConcatModel(Module):
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 1)
self.op2 = nn.Conv2d(3, 8, 1)
self.op3 = nn.Conv2d(16, 8, 1)
self.op4 = nn.Conv2d(3, 8, 1)
def forward(self, x: Tensor) -> Tensor:
x1 = self.op1(x)
x2 = self.op2(x)
cat1 = torch.cat([x1, x2], dim=1)
x3 = self.op3(cat1)
x4 = self.op4(x)
output = torch.cat([x3, x4], dim=1)
return output
class MultiConcatModel2(Module):
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 1)
self.op2 = nn.Conv2d(3, 8, 1)
self.op3 = nn.Conv2d(3, 8, 1)
self.op4 = nn.Conv2d(24, 8, 1)
def forward(self, x: Tensor) -> Tensor:
x1 = self.op1(x)
x2 = self.op2(x)
x3 = self.op3(x)
cat1 = torch.cat([x1, x2], dim=1)
cat2 = torch.cat([cat1, x3], dim=1)
output = self.op4(cat2)
return output
class ResBlock(Module):
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 1)
self.bn1 = nn.BatchNorm2d(8)
self.op2 = nn.Conv2d(8, 8, 1)
self.bn2 = nn.BatchNorm2d(8)
self.op3 = nn.Conv2d(8, 8, 1)
def forward(self, x: Tensor) -> Tensor:
x1 = self.bn1(self.op1(x))
x2 = self.bn2(self.op2(x1))
x3 = self.op3(x2 + x1)
return x3
class DynamicResBlock(Module):
def __init__(self, mutable_cfg) -> None:
super().__init__()
self.dynamic_op1 = build_dynamic_conv2d(
nn.Conv2d(3, 8, 1), 'dynamic_op1', mutable_cfg, mutable_cfg)
self.dynamic_bn1 = build_dynamic_bn(
nn.BatchNorm2d(8), 'dynamic_bn1', mutable_cfg)
self.dynamic_op2 = build_dynamic_conv2d(
nn.Conv2d(8, 8, 1), 'dynamic_op2', mutable_cfg, mutable_cfg)
self.dynamic_bn2 = build_dynamic_bn(
nn.BatchNorm2d(8), 'dynamic_bn2', mutable_cfg)
self.dynamic_op3 = build_dynamic_conv2d(
nn.Conv2d(8, 8, 1), 'dynamic_op3', mutable_cfg, mutable_cfg)
self._add_link()
def _add_link(self):
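        # Mutables that must keep the same number of channels are registered
        # with each other: each conv's out mutable with its BN, and both
        # branches of the residual add (bn1/bn2 outputs with op3's input), so
        # the mutator can later collect them into one search group.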
op1_mutable_out = self.dynamic_op1.mutable_out
bn1_mutable_out = self.dynamic_bn1.mutable_out
op2_mutable_in = self.dynamic_op2.mutable_in
op2_mutable_out = self.dynamic_op2.mutable_out
bn2_mutable_out = self.dynamic_bn2.mutable_out
op3_mutable_in = self.dynamic_op3.mutable_in
bn1_mutable_out.register_same_mutable(op1_mutable_out)
op1_mutable_out.register_same_mutable(bn1_mutable_out)
op2_mutable_in.register_same_mutable(bn1_mutable_out)
bn1_mutable_out.register_same_mutable(op2_mutable_in)
bn2_mutable_out.register_same_mutable(op2_mutable_out)
op2_mutable_out.register_same_mutable(bn2_mutable_out)
op3_mutable_in.register_same_mutable(bn1_mutable_out)
bn1_mutable_out.register_same_mutable(op3_mutable_in)
op3_mutable_in.register_same_mutable(bn2_mutable_out)
bn2_mutable_out.register_same_mutable(op3_mutable_in)
def forward(self, x: Tensor) -> Tensor:
x1 = self.dynamic_bn1(self.dynamic_op1(x))
x2 = self.dynamic_bn2(self.dynamic_op2(x1))
x3 = self.dynamic_op3(x2 + x1)
return x3
@unittest.skipIf(
digit_version(torch.__version__) == digit_version('1.8.1'),
'PyTorch version 1.8.1 is not supported by the Backward Tracer.')
def test_oneshot_channel_mutator() -> None:
imgs = torch.randn(16, 3, 224, 224)
def _test(model):
mutator.prepare_from_supernet(model)
assert hasattr(mutator, 'name2module')
# test set_min_choices
mutator.set_min_choices()
for mutables in mutator.search_groups.values():
for mutable in mutables:
# 1 / 8 is the minimum candidate ratio
assert mutable.current_choice == round(1 / 8 *
mutable.num_channels)
        # test set_max_choices
mutator.set_max_choices()
for mutables in mutator.search_groups.values():
for mutable in mutables:
# 1.0 is the maximum candidate ratio
assert mutable.current_choice == round(1. *
mutable.num_channels)
        # test the group-making logic
choice_dict = mutator.sample_choices()
assert isinstance(choice_dict, dict)
mutator.set_choices(choice_dict)
model(imgs)
mutator: OneShotChannelMutator = MODELS.build(ONESHOT_MUTATOR_CFG)
with pytest.raises(RuntimeError):
_ = mutator.search_groups
with pytest.raises(RuntimeError):
_ = mutator.name2module
_test(ResBlock())
_test(MultiConcatModel())
_test(MultiConcatModel2())
_test(nn.Sequential(nn.BatchNorm2d(3)))
mutator: OneShotChannelMutator = MODELS.build(
ONESHOT_MUTATOR_CFG_WITHOUT_TRACER)
dynamic_model = DynamicResBlock(
ONESHOT_MUTATOR_CFG_WITHOUT_TRACER['mutable_cfg'])
_test(dynamic_model)
def test_slimmable_channel_mutator() -> None:
imgs = torch.randn(16, 3, 224, 224)
root_path = dirname(dirname(dirname(__file__)))
channel_cfgs = [
os.path.join(root_path, 'data/subnet1.yaml'),
os.path.join(root_path, 'data/subnet2.yaml')
]
channel_cfgs = [mmcv.fileio.load(path) for path in channel_cfgs]
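    # Each yaml presumably describes one fixed-width subnet; switch_choices(i)
    # below selects the i-th subnet for every SlimmableChannelMutable.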
mutator = SlimmableChannelMutator(
mutable_cfg=dict(type='SlimmableChannelMutable'),
channel_cfgs=channel_cfgs)
model = ResBlock()
mutator.prepare_from_supernet(model)
mutator.switch_choices(0)
for name, module in model.named_modules():
if isinstance(module, SlimmableChannelMutable):
assert module.current_choice == 0
_ = model(imgs)
mutator.switch_choices(1)
for name, module in model.named_modules():
if isinstance(module, SlimmableChannelMutable):
assert module.current_choice == 1
_ = model(imgs)

View File

@ -0,0 +1,107 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import unittest
from os.path import dirname
import mmcv.fileio
import torch
from mmcls.core import ClsDataSample
from mmcls.models import * # noqa: F401,F403
from mmrazor import digit_version
from mmrazor.models.mutables import SlimmableChannelMutable
from mmrazor.models.mutators import (OneShotChannelMutator,
SlimmableChannelMutator)
from mmrazor.registry import MODELS
MODEL_CFG = dict(
type='mmcls.ImageClassifier',
backbone=dict(type='MobileNetV2', widen_factor=1.5),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=1920,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5)))
ONESHOT_MUTATOR_CFG = dict(
type='OneShotChannelMutator',
skip_prefixes=['head.fc'],
tracer_cfg=dict(
type='BackwardTracer',
loss_calculator=dict(type='ImageClassifierPseudoLoss')),
mutable_cfg=dict(
type='RatioChannelMutable',
candidate_choices=[
1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0
]))
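# skip_prefixes=['head.fc'] presumably exempts the classifier's final fc layer
# from pruning, keeping its 1000-way output intact.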
@unittest.skipIf(
digit_version(torch.__version__) == digit_version('1.8.1'),
'PyTorch version 1.8.1 is not supported by the Backward Tracer.')
def test_oneshot_channel_mutator() -> None:
imgs = torch.randn(16, 3, 224, 224)
data_samples = [
ClsDataSample().set_gt_label(torch.randint(0, 1000, (16, )))
]
model = MODELS.build(MODEL_CFG)
mutator: OneShotChannelMutator = MODELS.build(ONESHOT_MUTATOR_CFG)
mutator.prepare_from_supernet(model)
assert hasattr(mutator, 'name2module')
# test set_min_choices
mutator.set_min_choices()
for mutables in mutator.search_groups.values():
for mutable in mutables:
# 1 / 8 is the minimum candidate ratio
assert mutable.current_choice == round(1 / 8 *
mutable.num_channels)
    # test set_max_choices
mutator.set_max_choices()
for mutables in mutator.search_groups.values():
for mutable in mutables:
# 1.0 is the maximum candidate ratio
assert mutable.current_choice == round(1. * mutable.num_channels)
    # test the group-making logic
choice_dict = mutator.sample_choices()
assert isinstance(choice_dict, dict)
mutator.set_choices(choice_dict)
model(imgs, data_samples=data_samples, mode='loss')
def test_slimmable_channel_mutator() -> None:
imgs = torch.randn(16, 3, 224, 224)
data_samples = [
ClsDataSample().set_gt_label(torch.randint(0, 1000, (16, )))
]
root_path = dirname(dirname(dirname(dirname(__file__))))
channel_cfgs = [
os.path.join(root_path, 'data/MBV2_320M.yaml'),
os.path.join(root_path, 'data/MBV2_220M.yaml')
]
channel_cfgs = [mmcv.fileio.load(path) for path in channel_cfgs]
mutator = SlimmableChannelMutator(
mutable_cfg=dict(type='SlimmableChannelMutable'),
channel_cfgs=channel_cfgs)
model = MODELS.build(MODEL_CFG)
mutator.prepare_from_supernet(model)
mutator.switch_choices(0)
for name, module in model.named_modules():
if isinstance(module, SlimmableChannelMutable):
assert module.current_choice == 0
model(imgs, data_samples=data_samples, mode='loss')
mutator.switch_choices(1)
for name, module in model.named_modules():
if isinstance(module, SlimmableChannelMutable):
assert module.current_choice == 1
model(imgs, data_samples=data_samples, mode='loss')

View File

@ -104,7 +104,7 @@ class TestDiffMutator(TestCase):
mutator: DiffOP = MODELS.build(self.MUTATOR_CFG)
mutator.prepare_from_supernet(model)
assert list(mutator.search_group.keys()) == [0, 1, 2]
assert list(mutator.search_groups.keys()) == [0, 1, 2]
def test_diff_mutator_diffop_model(self) -> None:
model = SearchableModel(self.MUTABLE_CFG)
@ -118,7 +118,7 @@ class TestDiffMutator(TestCase):
mutator: DiffOP = MODELS.build(mutator_cfg)
mutator.prepare_from_supernet(model)
assert list(mutator.search_group.keys()) == [0, 1, 2]
assert list(mutator.search_groups.keys()) == [0, 1, 2]
mutator.modify_supernet_forward()
assert mutator.mutable_class_type == DiffMutable
@ -146,7 +146,7 @@ class TestDiffMutator(TestCase):
mutator.prepare_from_supernet(model)
assert list(mutator.search_group.keys()) == [0, 1, 2]
assert list(mutator.search_groups.keys()) == [0, 1, 2]
mutator.modify_supernet_forward()
assert mutator.mutable_class_type == DiffMutable
@ -165,7 +165,7 @@ class TestDiffMutator(TestCase):
mutator.prepare_from_supernet(model)
assert list(mutator.search_group.keys()) == [0, 1, 2, 3]
assert list(mutator.search_groups.keys()) == [0, 1, 2, 3]
mutator.modify_supernet_forward()
assert mutator.mutable_class_type == DiffMutable

View File

@ -59,10 +59,10 @@ def test_one_shot_mutator_normal_model() -> None:
assert mutator.mutable_class_type == OneShotMutable
with pytest.raises(RuntimeError):
_ = mutator.search_group
_ = mutator.search_groups
mutator.prepare_from_supernet(model)
assert len(mutator.search_group) == 0
assert len(mutator.search_groups) == 0
assert len(mutator.sample_choices()) == 0
@ -89,7 +89,7 @@ def test_one_shot_mutator_mutable_model() -> None:
# import pdb; pdb.set_trace()
mutator.prepare_from_supernet(model)
assert list(mutator.search_group.keys()) == [0, 1, 2]
assert list(mutator.search_groups.keys()) == [0, 1, 2]
random_choices = mutator.sample_choices()
assert list(random_choices.keys()) == [0, 1, 2]
@ -102,7 +102,7 @@ def test_one_shot_mutator_mutable_model() -> None:
mutator = MODELS.build(mutator_cfg)
mutator.prepare_from_supernet(model)
assert list(mutator.search_group.keys()) == [0, 1]
assert list(mutator.search_groups.keys()) == [0, 1]
random_choices = mutator.sample_choices()
assert list(random_choices.keys()) == [0, 1]