[Refactor] Refactor tracer and channel mutator
parent 332f49ac6f
commit 42063ae4d3
@ -1,2 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .delivers import *  # noqa: F401,F403
from .tracer import *  # noqa: F401,F403

@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backward_tracer import BackwardTracer
from .loss_calculator import *  # noqa: F401,F403
from .parsers import *  # noqa: F401,F403
from .path import (ConcatNode, ConvNode, DepthWiseConvNode, LinearNode, Node,
                   NormNode, Path, PathList)

__all__ = [
    'BackwardTracer', 'ConvNode', 'LinearNode', 'NormNode', 'ConcatNode',
    'Path', 'PathList', 'Node', 'DepthWiseConvNode'
]

@ -0,0 +1,201 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
|
||||
from mmcv import ConfigDict
|
||||
from torch.nn import Conv2d, Linear
|
||||
from torch.nn.modules import GroupNorm
|
||||
from torch.nn.modules.batchnorm import _NormBase
|
||||
|
||||
from mmrazor.registry import TASK_UTILS
|
||||
from .parsers import DEFAULT_BACKWARD_TRACER
|
||||
from .path import Path, PathList
|
||||
|
||||
SUPPORT_MODULES = (Conv2d, Linear, _NormBase, GroupNorm)
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class BackwardTracer:
|
||||
"""A topology tracer via backward.
|
||||
|
||||
Args:
|
||||
loss_calculator (dict or Callable): Calculate the pseudo loss to trace
|
||||
the topology of a model.
|
||||
"""
|
||||
|
||||
def __init__(self, loss_calculator):
|
||||
if isinstance(loss_calculator, (dict, ConfigDict)):
|
||||
loss_calculator = TASK_UTILS.build(loss_calculator)
|
||||
|
||||
assert callable(
|
||||
loss_calculator
|
||||
), 'loss_calculator should be a dict, ConfigDict or ' \
|
||||
'callable object'
|
||||
self.loss_calculator = loss_calculator
|
||||
|
||||
@property
|
||||
def backward_parser(self):
|
||||
"""The mapping from the type of a backward op to the corresponding
|
||||
parser."""
|
||||
return DEFAULT_BACKWARD_TRACER
|
||||
|
||||
def backward_trace(self, grad_fn, module2name, param2module, cur_path,
|
||||
result_paths, visited, shared_module):
|
||||
"""Trace the topology of all the ``NON_PASS_MODULE``."""
|
||||
grad_fn = grad_fn[0] if isinstance(grad_fn, (list, tuple)) else grad_fn
|
||||
|
||||
if grad_fn is not None:
|
||||
name = type(grad_fn).__name__
|
||||
# In the PyTorch graph, a backward op may carry a trailing '0' or '1'
# (e.g. ThnnConv2DBackward0). Strip the digits before looking up the
# corresponding parser.
|
||||
name = re.sub(r'[0-1]+', '', name)
|
||||
parse_module = self.backward_parser.get(name)
|
||||
|
||||
if parse_module is not None:
|
||||
parse_module(self, grad_fn, module2name, param2module,
|
||||
cur_path, result_paths, visited, shared_module)
|
||||
else:
|
||||
# If the op is AccumulateGrad, `parents` is an empty tuple.
|
||||
parents = grad_fn.next_functions
|
||||
if parents is not None:
|
||||
for parent in parents:
|
||||
self.backward_trace(parent, module2name, param2module,
|
||||
cur_path, result_paths, visited,
|
||||
shared_module)
|
||||
else:
|
||||
result_paths.append(copy.deepcopy(cur_path))
|
||||
|
||||
def _trace_shared_module_hook(self, module, inputs, outputs):
|
||||
"""Trace shared modules. Modules such as the detection head in
|
||||
RetinaNet which are visited more than once during :func:`forward` are
|
||||
shared modules.
|
||||
|
||||
Args:
|
||||
module (:obj:`torch.nn.Module`): The module to register hook.
|
||||
inputs (tuple): The input of the module.
|
||||
outputs (tuple): The output of the module.
|
||||
"""
|
||||
module._cnt += 1
|
||||
|
||||
def _build_mappings(self, model):
|
||||
"""Build the mappings which are used during tracing."""
|
||||
|
||||
module2name = OrderedDict()
|
||||
# build a mapping from the identity of a module's params
|
||||
# to this module
|
||||
param2module = OrderedDict()
|
||||
# record the visited module name during trace path
|
||||
visited = dict()
|
||||
|
||||
def traverse(module, prefix=''):
|
||||
for name, child in module.named_children():
|
||||
full_name = f'{prefix}.{name}' if prefix else name
|
||||
|
||||
if isinstance(child, SUPPORT_MODULES):
|
||||
module2name[child] = full_name
|
||||
for param in child.parameters():
|
||||
param2module[id(param)] = child
|
||||
visited[full_name] = False
|
||||
else:
|
||||
traverse(child, full_name)
|
||||
|
||||
traverse(model)
|
||||
|
||||
return module2name, param2module, visited
|
||||
|
||||
def _register_share_module_hook(self, model):
|
||||
"""Record shared modules which will be visited more than once during
|
||||
forward such as shared detection head in RetinaNet.
|
||||
|
||||
If a module is not a shared module and it has been visited during
|
||||
forward, its parent modules must have been traced already. However, a
|
||||
shared module will be visited more than once during forward, so it
still needs to be traced even if it has been visited.
|
||||
"""
|
||||
self._shared_module_hook_handles = list()
|
||||
for module in model.modules():
|
||||
if hasattr(module, 'weight'):
|
||||
# trace shared modules
|
||||
module._cnt = 0
|
||||
# the handle is only to remove the corresponding hook later
|
||||
handle = module.register_forward_hook(
|
||||
self._trace_shared_module_hook)
|
||||
self._shared_module_hook_handles.append(handle)
|
||||
|
||||
def _remove_share_module_hook(self, model):
|
||||
"""`_trace_shared_module_hook` and `_cnt` are only used to trace the
|
||||
shared modules in a model and need to be removed later."""
|
||||
for module in model.modules():
|
||||
if hasattr(module, 'weight'):
|
||||
del module._cnt
|
||||
|
||||
for handle in self._shared_module_hook_handles:
|
||||
handle.remove()
|
||||
|
||||
del self._shared_module_hook_handles
|
||||
|
||||
def _set_all_requires_grad(self, model):
|
||||
"""Set `requires_grad` of a parameter to True to trace the whole
|
||||
architecture topology."""
|
||||
self._param_requires_grad = dict()
|
||||
for param in model.parameters():
|
||||
self._param_requires_grad[id(param)] = param.requires_grad
|
||||
param.requires_grad = True
|
||||
|
||||
def _restore_requires_grad(self, model):
|
||||
"""We set requires_grad to True to trace the whole architecture
|
||||
topology.
|
||||
|
||||
So it should be reset after that.
|
||||
"""
|
||||
for param in model.parameters():
|
||||
param.requires_grad = self._param_requires_grad[id(param)]
|
||||
del self._param_requires_grad
|
||||
|
||||
@staticmethod
|
||||
def _find_share_modules(model):
|
||||
"""Find shared modules which will be visited more than once during
|
||||
forward such as shared detection head in RetinaNet."""
|
||||
share_modules = list()
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, 'weight'):
|
||||
if module._cnt > 1:
|
||||
share_modules.append(name)
|
||||
|
||||
return share_modules
|
||||
|
||||
@staticmethod
|
||||
def _reset_norm_running_stats(model):
|
||||
"""As we calculate the pseudo loss during tracing, we need to reset
|
||||
states of parameters."""
|
||||
for module in model.modules():
|
||||
if isinstance(module, _NormBase):
|
||||
module.reset_parameters()
|
||||
|
||||
def trace(self, model):
|
||||
"""Trace trace the architecture topology of the input model."""
|
||||
module2name, param2module, visited = self._build_mappings(model)
|
||||
|
||||
# Set requires_grad to True. If the `requires_grad` of a module's
|
||||
# weight is False, we can not trace this module by parsing backward.
|
||||
self._set_all_requires_grad(model)
|
||||
|
||||
self._register_share_module_hook(model)
|
||||
|
||||
pseudo_loss = self.loss_calculator(model)
|
||||
|
||||
share_modules = self._find_share_modules(model)
|
||||
|
||||
self._remove_share_module_hook(model)
|
||||
self._restore_requires_grad(model)
|
||||
|
||||
module_path_list = PathList()
|
||||
|
||||
self.backward_trace(pseudo_loss.grad_fn, module2name, param2module,
|
||||
Path(), module_path_list, visited, share_modules)
|
||||
|
||||
self._reset_norm_running_stats(model)
|
||||
|
||||
return module_path_list
|
|
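A minimal usage sketch of the tracer added above (not part of this commit): any callable that maps a model to a scalar tensor can serve as `loss_calculator`, so a plain lambda stands in for the registered pseudo-loss classes here. The `mmrazor.models.task_modules.tracer` import path is an assumption, and the parsers only recognize the backward-op names listed in `DEFAULT_BACKWARD_TRACER`, so the result depends on the PyTorch version.

import torch
import torch.nn as nn

from mmrazor.models.task_modules.tracer import BackwardTracer  # assumed path

model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.Conv2d(8, 16, 3, padding=1))

# For an mmcls model one would instead pass
# loss_calculator=dict(type='ImageClassifierPseudoLoss').
tracer = BackwardTracer(
    loss_calculator=lambda m: m(torch.rand(1, 3, 32, 32)).sum())

path_list = tracer.trace(model)
# On a PyTorch version whose backward-op names match the parser table,
# the PathList is rooted at the last Conv2d, e.g. ['2'].
print(path_list.get_root_names())
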
@ -0,0 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .image_classifier_loss_calculator import ImageClassifierPseudoLoss
from .single_stage_detector_loss_calculator import \
    SingleStageDetectorPseudoLoss

__all__ = ['ImageClassifierPseudoLoss', 'SingleStageDetectorPseudoLoss']

@ -0,0 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcls.models import ImageClassifier

from mmrazor.registry import TASK_UTILS


@TASK_UTILS.register_module()
class ImageClassifierPseudoLoss:
    """Calculate the pseudo loss to trace the topology of an `ImageClassifier`
    in MMClassification with `BackwardTracer`."""

    def __call__(self, model: ImageClassifier) -> torch.Tensor:
        pseudo_img = torch.rand(1, 3, 224, 224)
        pseudo_output = model(pseudo_img)
        return sum(pseudo_output)

@ -0,0 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.models import BaseDetector

from mmrazor.registry import TASK_UTILS


# todo: adapt to mmdet 2.0
@TASK_UTILS.register_module()
class SingleStageDetectorPseudoLoss:

    def __call__(self, model: BaseDetector) -> torch.Tensor:
        pseudo_img = torch.rand(1, 3, 224, 224)
        pseudo_output = model.forward_dummy(pseudo_img)
        out = torch.tensor(0.)
        for levels in pseudo_output:
            out += sum([level.sum() for level in levels])

        return out

@ -0,0 +1,189 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from typing import Callable, Dict
|
||||
|
||||
from .path import (ConcatNode, ConvNode, DepthWiseConvNode, LinearNode,
|
||||
NormNode, Path, PathList)
|
||||
|
||||
|
||||
def _is_leaf_grad_fn(grad_fn):
|
||||
"""Determine whether the current node is a leaf node."""
|
||||
if type(grad_fn).__name__ == 'AccumulateGrad':
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def parse_conv(tracer, grad_fn, module2name, param2module, cur_path,
|
||||
result_paths, visited, shared_module):
|
||||
"""Parse the backward of a conv layer.
|
||||
|
||||
Example:
|
||||
>>> conv = nn.Conv2d(3, 3, 3)
|
||||
>>> pseudo_img = torch.rand(1, 3, 224, 224)
|
||||
>>> out = conv(pseudo_img)
|
||||
>>> out.grad_fn.next_functions
|
||||
((None, 0), (<AccumulateGrad object at 0x0000020E405CBD88>, 0),
|
||||
(<AccumulateGrad object at 0x0000020E405CB588>, 0))
|
||||
>>> # op.next_functions[0][0] being None means this ThnnConv2DBackward
|
||||
>>> # op has no parents
|
||||
>>> # op.next_functions[1][0].variable is the weight of this Conv2d
|
||||
>>> # module
|
||||
>>> # op.next_functions[2][0].variable is the bias of this Conv2d
|
||||
>>> # module
|
||||
"""
|
||||
leaf_grad_fn = grad_fn.next_functions[1][0]
|
||||
while not _is_leaf_grad_fn(leaf_grad_fn):
|
||||
leaf_grad_fn = leaf_grad_fn.next_functions[0][0]
|
||||
variable = leaf_grad_fn.variable
|
||||
param_id = id(variable)
|
||||
module = param2module[param_id]
|
||||
name = module2name[module]
|
||||
parent = grad_fn.next_functions[0][0]
|
||||
if module.in_channels == module.groups:
|
||||
cur_path.append(DepthWiseConvNode(name))
|
||||
else:
|
||||
cur_path.append(ConvNode(name))
|
||||
# If a module is not a shared module and it has been visited during
|
||||
# forward, its parent modules must have been traced already.
|
||||
# However, a shared module will be visited more than once during
|
||||
# forward, so it still needs to be traced even if it has been
|
||||
# visited.
|
||||
if visited[name] and name not in shared_module:
|
||||
result_paths.append(copy.deepcopy(cur_path))
|
||||
else:
|
||||
visited[name] = True
|
||||
tracer.backward_trace(parent, module2name, param2module, cur_path,
|
||||
result_paths, visited, shared_module)
|
||||
cur_path.pop(-1)
|
||||
|
||||
|
||||
# todo: support parsing `MultiheadAttention` and user-defined matrix
|
||||
# multiplication
|
||||
def parse_linear(tracer, grad_fn, module2name, param2module, cur_path,
|
||||
result_paths, visited, shared_module):
|
||||
"""Parse the backward of a conv layer.
|
||||
|
||||
Example:
|
||||
>>> fc = nn.Linear(3, 3, bias=True)
|
||||
>>> input = torch.rand(3, 3)
|
||||
>>> out = fc(input)
|
||||
>>> out.grad_fn.next_functions
|
||||
((<AccumulateGrad object at 0x0000020E405F75C8>, 0), (None, 0),
|
||||
(<TBackward object at 0x0000020E405F7D48>, 0))
|
||||
>>> # op.next_functions[0][0].variable is the bias of this Linear
|
||||
>>> # module
|
||||
>>> # op.next_functions[1][0] being None means this AddmmBackward op
|
||||
>>> # has no parents
|
||||
>>> # op.next_functions[2][0] is the TBackward op, and
|
||||
>>> # op.next_functions[2][0].next_functions[0][0].variable is
|
||||
>>> # the transpose of the weight of this Linear module
|
||||
"""
|
||||
leaf_grad_fn = grad_fn.next_functions[-1][0].next_functions[0][0]
|
||||
while not _is_leaf_grad_fn(leaf_grad_fn):
|
||||
leaf_grad_fn = leaf_grad_fn.next_functions[0][0]
|
||||
variable = leaf_grad_fn.variable
|
||||
param_id = id(variable)
|
||||
module = param2module[param_id]
|
||||
name = module2name[module]
|
||||
parent = grad_fn.next_functions[-2][0]
|
||||
|
||||
cur_path.append(LinearNode(name))
|
||||
# If a module is not a shared module and it has been visited during
|
||||
# forward, its parent modules must have been traced already.
|
||||
# However, a shared module will be visited more than once during
|
||||
# forward, so it still needs to be traced even if it has been
|
||||
# visited.
|
||||
if visited[name] and name not in shared_module:
|
||||
result_paths.append(copy.deepcopy(cur_path))
|
||||
else:
|
||||
visited[name] = True
|
||||
tracer.backward_trace(parent, module2name, param2module, cur_path,
|
||||
result_paths, visited, shared_module)
|
||||
|
||||
|
||||
def parse_cat(tracer, grad_fn, module2name, param2module, cur_path,
|
||||
result_paths, visited, shared_module):
|
||||
"""Parse the backward of a concat operation.
|
||||
|
||||
Example:
|
||||
>>> conv = nn.Conv2d(3, 3, 3)
|
||||
>>> pseudo_img = torch.rand(1, 3, 224, 224)
|
||||
>>> out1 = conv(pseudo_img)
|
||||
>>> out2 = conv(pseudo_img)
|
||||
>>> out = torch.cat([out1, out2], dim=1)
|
||||
>>> out.grad_fn.next_functions
|
||||
((<ThnnConv2DBackward object at 0x0000020E405F24C8>, 0),
|
||||
(<ThnnConv2DBackward object at 0x0000020E405F2648>, 0))
|
||||
>>> # the length of ``out.grad_fn.next_functions`` being two means
|
||||
>>> # ``out`` is obtained by concatenating two tensors
|
||||
"""
|
||||
parents = grad_fn.next_functions
|
||||
concat_id = '_'.join([str(id(p)) for p in parents])
|
||||
name = f'concat_{concat_id}'
|
||||
# If a module is not a shared module and it has been visited during
|
||||
# forward, its parent modules must have been traced already.
|
||||
# However, a shared module will be visited more than once during
|
||||
# forward, so it still needs to be traced even if it has been
|
||||
# visited.
|
||||
if (name in visited and visited[name] and name not in shared_module):
|
||||
pass
|
||||
else:
|
||||
visited[name] = True
|
||||
sub_path_lists = list()
|
||||
for i, parent in enumerate(parents):
|
||||
sub_path_list = PathList()
|
||||
tracer.backward_trace(parent, module2name, param2module, Path(),
|
||||
sub_path_list, visited, shared_module)
|
||||
sub_path_lists.append(sub_path_list)
|
||||
cur_path.append(ConcatNode(name, sub_path_lists))
|
||||
|
||||
result_paths.append(copy.deepcopy(cur_path))
|
||||
cur_path.pop(-1)
|
||||
|
||||
|
||||
def parse_norm(tracer, grad_fn, module2name, param2module, cur_path,
|
||||
result_paths, visited, shared_module):
|
||||
"""Parse the backward of a concat operation.
|
||||
|
||||
Example:
|
||||
>>> conv = nn.Conv2d(3, 3, 3)
|
||||
>>> pseudo_img = torch.rand(1, 3, 224, 224)
|
||||
>>> out1 = conv(pseudo_img)
|
||||
>>> out2 = conv(pseudo_img)
|
||||
>>> out = torch.cat([out1, out2], dim=1)
|
||||
>>> out.grad_fn.next_functions
|
||||
((<ThnnConv2DBackward object at 0x0000020E405F24C8>, 0),
|
||||
(<ThnnConv2DBackward object at 0x0000020E405F2648>, 0))
|
||||
>>> # the length of ``out.grad_fn.next_functions`` is two means
|
||||
>>> # ``out`` is obtained by concatenating two tensors
|
||||
"""
|
||||
leaf_grad_fn = grad_fn.next_functions[1][0]
|
||||
while not _is_leaf_grad_fn(leaf_grad_fn):
|
||||
leaf_grad_fn = leaf_grad_fn.next_functions[0][0]
|
||||
variable = leaf_grad_fn.variable
|
||||
param_id = id(variable)
|
||||
module = param2module[param_id]
|
||||
name = module2name[module]
|
||||
parent = grad_fn.next_functions[0][0]
|
||||
cur_path.append(NormNode(name))
|
||||
|
||||
visited[name] = True
|
||||
tracer.backward_trace(parent, module2name, param2module, cur_path,
|
||||
result_paths, visited, shared_module)
|
||||
cur_path.pop(-1)
|
||||
|
||||
|
||||
DEFAULT_BACKWARD_TRACER: Dict[str, Callable] = {
|
||||
'ThnnConv2DBackward': parse_conv,
|
||||
'CudnnConvolutionBackward': parse_conv,
|
||||
'MkldnnConvolutionBackward': parse_conv,
|
||||
'SlowConvDilated2DBackward': parse_conv,
|
||||
'ThAddmmBackward': parse_linear,
|
||||
'AddmmBackward': parse_linear,
|
||||
'MmBackward': parse_linear,
|
||||
'CatBackward': parse_cat,
|
||||
'ThnnBatchNormBackward': parse_norm,
|
||||
'CudnnBatchNormBackward': parse_norm,
|
||||
'NativeBatchNormBackward': parse_norm,
|
||||
'NativeGroupNormBackward': parse_norm
|
||||
}
|
|
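A quick illustration (not from this diff) of how the tracer resolves these parsers: the trailing overload digit of a backward-op name is stripped with `re.sub` and the result is looked up in `DEFAULT_BACKWARD_TRACER`; unregistered ops fall through and the tracer simply recurses into their parents. The op names below are examples, and the exact set of names varies across PyTorch releases; the sketch assumes the names defined in the module above.

import re

def resolve_parser(grad_fn_type_name):
    # Same normalization as BackwardTracer.backward_trace.
    name = re.sub(r'[0-1]+', '', grad_fn_type_name)
    return DEFAULT_BACKWARD_TRACER.get(name)

assert resolve_parser('AddmmBackward0') is parse_linear
assert resolve_parser('CatBackward0') is parse_cat
assert resolve_parser('ReluBackward0') is None  # no parser: the tracer keeps walking
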
@ -0,0 +1,353 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
|
||||
def _addindent(s_, numSpaces):
|
||||
s = s_.split('\n')
|
||||
# don't do anything for single-line stuff
|
||||
if len(s) == 1:
|
||||
return s_
|
||||
first = s.pop(0)
|
||||
s = [(numSpaces * ' ') + line for line in s]
|
||||
s = '\n'.join(s)
|
||||
s = first + '\n' + s
|
||||
return s
|
||||
|
||||
|
||||
def _merge_node_parents(node2parents, _node2parents):
|
||||
for node, parents in _node2parents.items():
|
||||
if node in node2parents:
|
||||
cur_parents = node2parents[node]
|
||||
new_parents_set = set(cur_parents + parents)
|
||||
new_parents = list(new_parents_set)
|
||||
node2parents[node] = new_parents
|
||||
else:
|
||||
node2parents[node] = parents
|
||||
|
||||
|
||||
class Node:
|
||||
"""``Node`` is the data structure that represents individual instances
|
||||
within a ``Path``. It corresponds to a module or an operation such as
|
||||
concatenation in the model.
|
||||
|
||||
Args:
|
||||
name (str): Unique identifier of a node.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self._name = name
|
||||
|
||||
def get_module_names(self) -> List:
|
||||
return [self.name]
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Get the name of current node."""
|
||||
return self._name
|
||||
|
||||
def _get_class_name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, self.__class__):
|
||||
return self.name == other.name
|
||||
else:
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.name)
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self._get_class_name()}(\'{self.name}\')'
|
||||
|
||||
|
||||
class ConvNode(Node):
|
||||
"""A `ConvNode` corresponds to a Conv module in the original model."""
|
||||
pass
|
||||
|
||||
|
||||
class DepthWiseConvNode(Node):
|
||||
"""A `DepthWiseConvNode` corresponds to a depth-wise conv module in the
|
||||
original model."""
|
||||
pass
|
||||
|
||||
|
||||
class NormNode(Node):
|
||||
"""A `NormNode` corresponds to a normalization module in the original
|
||||
model."""
|
||||
pass
|
||||
|
||||
|
||||
class LinearNode(Node):
|
||||
"""A `LinearNode` corresponds to a linear module in the original model."""
|
||||
pass
|
||||
|
||||
|
||||
class Path:
|
||||
"""``Path`` is the data structure that represents a list of ``Node`` traced
|
||||
by a tracer.
|
||||
|
||||
Args:
|
||||
nodes (:obj:`Node` or List[:obj:`Node`], optional): Nodes in a path.
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, nodes: Optional[Union[Node, List[Node]]] = None):
|
||||
self._nodes: List[Node] = list()
|
||||
if nodes is not None:
|
||||
if isinstance(nodes, Node):
|
||||
nodes = [nodes]
|
||||
assert isinstance(nodes, (list, tuple))
|
||||
for node in nodes:
|
||||
assert isinstance(node, Node)
|
||||
self._nodes.append(node)
|
||||
|
||||
def get_root_names(self) -> List[str]:
|
||||
"""Get the name of the first node in a path."""
|
||||
return self._nodes[0].get_module_names()
|
||||
|
||||
def find_nodes_parents(self,
|
||||
target_nodes: Tuple,
|
||||
non_pass: Optional[Tuple] = None) -> Dict:
|
||||
"""Find the parents of a specific node.
|
||||
|
||||
Args:
|
||||
target_nodes (Tuple): Find the parents of nodes whose types
|
||||
are one of `target_nodes`.
|
||||
non_pass (Tuple, optional): Only ancestor nodes whose types are one of
`non_pass` are regarded as parents. Defaults to None.
|
||||
"""
|
||||
node2parents: Dict[str, List[Node]] = dict()
|
||||
for i, node in enumerate(self._nodes):
|
||||
if isinstance(node, ConcatNode):
|
||||
_node2parents: Dict[str, List[Node]] = node.find_nodes_parents(
|
||||
target_nodes, non_pass)
|
||||
_merge_node_parents(node2parents, _node2parents)
|
||||
continue
|
||||
|
||||
if isinstance(node, target_nodes):
|
||||
parents = list()
|
||||
for behind_node in self._nodes[i + 1:]:
|
||||
if non_pass is None or isinstance(behind_node, non_pass):
|
||||
parents.append(behind_node)
|
||||
break
|
||||
_node2parents = {node.name: parents}
|
||||
_merge_node_parents(node2parents, _node2parents)
|
||||
return node2parents
|
||||
|
||||
@property
|
||||
def nodes(self) -> List:
|
||||
"""Return a list of nodes in the current path."""
|
||||
return self._nodes
|
||||
|
||||
def append(self, x: Node) -> None:
|
||||
"""Add a node to the end of the current path."""
|
||||
assert isinstance(x, Node)
|
||||
self._nodes.append(x)
|
||||
|
||||
def pop(self, *args, **kwargs):
|
||||
"""Temoves the node at the given index from the path and returns the
|
||||
removed node."""
|
||||
return self._nodes.pop(*args, **kwargs)
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, self.__class__):
|
||||
return self.nodes == other.nodes
|
||||
else:
|
||||
return False
|
||||
|
||||
def __len__(self):
|
||||
return len(self._nodes)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self._nodes[item]
|
||||
|
||||
def __iter__(self):
|
||||
for node in self._nodes:
|
||||
yield node
|
||||
|
||||
def _get_class_name(self) -> str:
|
||||
"""Get the name of the current class."""
|
||||
return self.__class__.__name__
|
||||
|
||||
def __repr__(self):
|
||||
child_lines = []
|
||||
for node in self._nodes:
|
||||
node_str = repr(node)
|
||||
node_str = _addindent(node_str, 2)
|
||||
child_lines.append(node_str)
|
||||
lines = child_lines
|
||||
|
||||
main_str = self._get_class_name() + '('
|
||||
if lines:
|
||||
main_str += '\n ' + ',\n '.join(lines) + '\n'
|
||||
main_str += ')'
|
||||
return main_str
|
||||
|
||||
|
||||
class PathList:
|
||||
"""``PathList`` is the data structure that represents a list of ``Path``
|
||||
traced by a tracer.
|
||||
|
||||
Args:
|
||||
paths (:obj:`Path` or List[:obj:`Path`], optional): A list of `Path`.
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, paths: Optional[Union[Path, List[Path]]] = None):
|
||||
self._paths = list()
|
||||
if paths is not None:
|
||||
if isinstance(paths, Path):
|
||||
paths = [paths]
|
||||
assert isinstance(paths, (list, tuple))
|
||||
for path in paths:
|
||||
assert isinstance(path, Path)
|
||||
self._paths.append(path)
|
||||
|
||||
def get_root_names(self) -> List[str]:
|
||||
"""Get the root node of all the paths in `PathList`.
|
||||
|
||||
Notes:
|
||||
Different paths in a PathList share the same root node.
|
||||
"""
|
||||
return self._paths[0].get_root_names()
|
||||
|
||||
def find_nodes_parents(self,
|
||||
target_nodes: Tuple,
|
||||
non_pass: Optional[Tuple] = None):
|
||||
"""Find the parents of a specific node.
|
||||
|
||||
Args:
|
||||
target_nodes (Tuple): Find the parents of nodes whose types
|
||||
are one of `target_nodes`.
|
||||
non_pass (Tuple, optional): Only ancestor nodes whose types are one of
`non_pass` are regarded as parents. Defaults to None.
|
||||
"""
|
||||
node2parents: Dict[str, List[Node]] = dict()
|
||||
for p in self._paths:
|
||||
_node2parents = p.find_nodes_parents(target_nodes, non_pass)
|
||||
_merge_node_parents(node2parents, _node2parents)
|
||||
return node2parents
|
||||
|
||||
def append(self, x: Path) -> None:
|
||||
"""Add a path to the end of the current PathList."""
|
||||
assert isinstance(x, Path)
|
||||
self._paths.append(x)
|
||||
|
||||
@property
|
||||
def paths(self):
|
||||
"""Return all paths in the current PathList."""
|
||||
return self._paths
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, self.__class__):
|
||||
return self.paths == other.paths
|
||||
else:
|
||||
return False
|
||||
|
||||
def __len__(self):
|
||||
return len(self._paths)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self._paths[item]
|
||||
|
||||
def __iter__(self):
|
||||
for path in self._paths:
|
||||
yield path
|
||||
|
||||
def _get_class_name(self) -> str:
|
||||
"""Get the name of the current class."""
|
||||
return self.__class__.__name__
|
||||
|
||||
def __repr__(self):
|
||||
child_lines = []
|
||||
for node in self._paths:
|
||||
node_str = repr(node)
|
||||
node_str = _addindent(node_str, 2)
|
||||
child_lines.append(node_str)
|
||||
lines = child_lines
|
||||
|
||||
main_str = self._get_class_name() + '('
|
||||
if lines:
|
||||
main_str += '\n ' + ',\n '.join(lines) + '\n'
|
||||
main_str += ')'
|
||||
return main_str
|
||||
|
||||
|
||||
class ConcatNode(Node):
|
||||
"""``ConcatNode`` is the data structure that represents the concatenation
|
||||
operation in a model.
|
||||
|
||||
Args:
|
||||
name (str): Unique identifier of a `ConcatNode`.
|
||||
path_lists (List[PathList]): Several nodes are concatenated and each
|
||||
node is the root node of all the paths in a `PathList`
|
||||
(one of `path_lists`).
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, path_lists: List[PathList]):
|
||||
super().__init__(name)
|
||||
self._path_lists = list()
|
||||
for path_list in path_lists:
|
||||
assert isinstance(path_list, PathList)
|
||||
self._path_lists.append(path_list)
|
||||
|
||||
def get_module_names(self) -> List[str]:
|
||||
"""Several nodes are concatenated.
|
||||
|
||||
Get the names of these nodes.
|
||||
"""
|
||||
module_names = list()
|
||||
for path_list in self._path_lists:
|
||||
module_names.extend(path_list.get_root_names())
|
||||
return module_names
|
||||
|
||||
def find_nodes_parents(self,
|
||||
target_nodes: Tuple,
|
||||
non_pass: Optional[Tuple] = None):
|
||||
"""Find the parents of a specific node.
|
||||
|
||||
Args:
|
||||
target_nodes (Tuple): Find the parents of nodes whose types
|
||||
are one of `target_nodes`.
|
||||
non_pass (Tuple, optional): Only ancestor nodes whose types are one of
`non_pass` are regarded as parents. Defaults to None.
|
||||
"""
|
||||
node2parents: Dict[str, List[Node]] = dict()
|
||||
for p in self._path_lists:
|
||||
_node2parents = p.find_nodes_parents(target_nodes, non_pass)
|
||||
_merge_node_parents(node2parents, _node2parents)
|
||||
return node2parents
|
||||
|
||||
@property
|
||||
def path_lists(self) -> List[PathList]:
|
||||
"""Return all the path_list."""
|
||||
return self._path_lists
|
||||
|
||||
def __len__(self):
|
||||
return len(self._path_lists)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self._path_lists[item]
|
||||
|
||||
def __iter__(self):
|
||||
for path_list in self._path_lists:
|
||||
yield path_list
|
||||
|
||||
def _get_class_name(self) -> str:
|
||||
"""Get the name of the current class."""
|
||||
return self.__class__.__name__
|
||||
|
||||
def __repr__(self):
|
||||
child_lines = []
|
||||
for node in self._path_lists:
|
||||
node_str = repr(node)
|
||||
node_str = _addindent(node_str, 2)
|
||||
child_lines.append(node_str)
|
||||
lines = child_lines
|
||||
|
||||
main_str = self._get_class_name() + '('
|
||||
if lines:
|
||||
main_str += '\n ' + ',\n '.join(lines) + '\n'
|
||||
main_str += ')'
|
||||
return main_str
|
|
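A hand-built example of these data structures (not from this diff; the node names are made up): a `Path` is ordered from the output side back towards the input, so in `find_nodes_parents` the first matching node behind a target in the list is that target's parent.

path = Path([ConvNode('layer2.conv'), NormNode('layer1.bn'), ConvNode('layer1.conv')])
path_list = PathList(path)

print(path_list.get_root_names())  # ['layer2.conv']
print(path.find_nodes_parents(target_nodes=(ConvNode, ), non_pass=(ConvNode, )))
# {'layer2.conv': [ConvNode('layer1.conv')], 'layer1.conv': []}
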
@ -6,5 +6,4 @@ from .losses import *  # noqa: F401,F403
from .mutables import *  # noqa: F401,F403
from .mutators import *  # noqa: F401,F403
from .ops import *  # noqa: F401,F403
from .pruners import *  # noqa: F401,F403
from .subnet import *  # noqa: F401,F403

@ -0,0 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .default_dynamic_ops import *  # noqa: F401,F403
from .slimmable_dynamic_ops import *  # noqa: F401,F403

@ -0,0 +1,408 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from typing import Dict
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.nn.modules import GroupNorm
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from torch.nn.modules.instancenorm import _InstanceNorm
|
||||
|
||||
from mmrazor.models.mutables import MutableManagerMixIn
|
||||
from mmrazor.registry import MODELS
|
||||
|
||||
|
||||
class DynamicConv2d(nn.Conv2d, MutableManagerMixIn):
|
||||
"""Applies a 2D convolution over an input signal composed of several input
|
||||
planes according to the `mutable_in_channels` and `mutable_out_channels`
|
||||
dynamically.
|
||||
|
||||
Args:
|
||||
module_name (str): Name of this `DynamicConv2d`.
|
||||
in_channels_cfg (Dict): Config related to `in_channels`.
|
||||
out_channels_cfg (Dict): Config related to `out_channels`.
|
||||
"""
|
||||
|
||||
def __init__(self, module_name, in_channels_cfg, out_channels_cfg, *args,
|
||||
**kwargs):
|
||||
super(DynamicConv2d, self).__init__(*args, **kwargs)
|
||||
|
||||
in_channels_cfg = copy.deepcopy(in_channels_cfg)
|
||||
in_channels_cfg.update(
|
||||
dict(
|
||||
name=module_name,
|
||||
num_channels=self.in_channels,
|
||||
mask_type='in_mask'))
|
||||
self.mutable_in_channels = MODELS.build(in_channels_cfg)
|
||||
|
||||
out_channels_cfg = copy.deepcopy(out_channels_cfg)
|
||||
out_channels_cfg.update(
|
||||
dict(
|
||||
name=module_name,
|
||||
num_channels=self.out_channels,
|
||||
mask_type='out_mask'))
|
||||
self.mutable_out_channels = MODELS.build(out_channels_cfg)
|
||||
|
||||
@property
|
||||
def mutable_in(self):
|
||||
"""Mutable `in_channels`."""
|
||||
return self.mutable_in_channels
|
||||
|
||||
@property
|
||||
def mutable_out(self):
|
||||
"""Mutable `out_channels`."""
|
||||
return self.mutable_out_channels
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
"""Slice the parameters according to `mutable_in_channels` and
|
||||
`mutable_out_channels`, and forward."""
|
||||
in_mask = self.mutable_in_channels.mask
|
||||
out_mask = self.mutable_out_channels.mask
|
||||
|
||||
if self.groups == 1:
|
||||
weight = self.weight[out_mask][:, in_mask]
|
||||
groups = 1
|
||||
elif self.groups == self.in_channels == self.out_channels:
|
||||
# depth-wise conv
|
||||
weight = self.weight[out_mask]
|
||||
groups = input.size(1)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'Current `ChannelMutator` only supports pruning depth-wise '
'`nn.Conv2d` modules or `nn.Conv2d` modules whose group number '
f'equals one, but got {self.groups}.')
|
||||
|
||||
bias = self.bias[out_mask] if self.bias is not None else None
|
||||
|
||||
return F.conv2d(input, weight, bias, self.stride, self.padding,
|
||||
self.dilation, groups)
|
||||
|
||||
|
||||
class DynamicLinear(nn.Linear, MutableManagerMixIn):
|
||||
"""Applies a linear transformation to the incoming data according to the
|
||||
`mutable_in_features` and `mutable_out_features` dynamically.
|
||||
|
||||
Args:
|
||||
module_name (str): Name of this `DynamicLinear`.
|
||||
in_features_cfg (Dict): Config related to `in_features`.
|
||||
out_features_cfg (Dict): Config related to `out_features`.
|
||||
"""
|
||||
|
||||
def __init__(self, module_name, in_features_cfg, out_features_cfg, *args,
|
||||
**kwargs):
|
||||
super(DynamicLinear, self).__init__(*args, **kwargs)
|
||||
|
||||
in_features_cfg = copy.deepcopy(in_features_cfg)
|
||||
in_features_cfg.update(
|
||||
dict(
|
||||
name=module_name,
|
||||
num_channels=self.in_features,
|
||||
mask_type='in_mask'))
|
||||
self.mutable_in_features = MODELS.build(in_features_cfg)
|
||||
|
||||
out_features_cfg = copy.deepcopy(out_features_cfg)
|
||||
out_features_cfg.update(
|
||||
dict(
|
||||
name=module_name,
|
||||
num_channels=self.out_features,
|
||||
mask_type='out_mask'))
|
||||
self.mutable_out_features = MODELS.build(out_features_cfg)
|
||||
|
||||
@property
|
||||
def mutable_in(self):
|
||||
"""Mutable `in_features`."""
|
||||
return self.mutable_in_features
|
||||
|
||||
@property
|
||||
def mutable_out(self):
|
||||
"""Mutable `out_features`."""
|
||||
return self.mutable_out_features
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
"""Slice the parameters according to `mutable_in_features` and
|
||||
`mutable_out_features`, and forward."""
|
||||
in_mask = self.mutable_in_features.mask
|
||||
out_mask = self.mutable_out_features.mask
|
||||
|
||||
weight = self.weight[out_mask][:, in_mask]
|
||||
bias = self.bias[out_mask] if self.bias is not None else None
|
||||
|
||||
return F.linear(input, weight, bias)
|
||||
|
||||
|
||||
class DynamicBatchNorm(_BatchNorm, MutableManagerMixIn):
|
||||
"""Applies Batch Normalization over an input according to the
|
||||
`mutable_num_features` dynamically.
|
||||
|
||||
Args:
|
||||
module_name (str): Name of this `DynamicBatchNorm`.
|
||||
num_features_cfg (Dict): Config related to `num_features`.
|
||||
"""
|
||||
|
||||
def __init__(self, module_name, num_features_cfg, *args, **kwargs):
|
||||
super(DynamicBatchNorm, self).__init__(*args, **kwargs)
|
||||
|
||||
num_features_cfg = copy.deepcopy(num_features_cfg)
|
||||
num_features_cfg.update(
|
||||
dict(
|
||||
name=module_name,
|
||||
num_channels=self.num_features,
|
||||
mask_type='out_mask'))
|
||||
self.mutable_num_features = MODELS.build(num_features_cfg)
|
||||
|
||||
@property
|
||||
def mutable_in(self):
|
||||
"""Mutable `num_features`."""
|
||||
return self.mutable_num_features
|
||||
|
||||
@property
|
||||
def mutable_out(self):
|
||||
"""Mutable `num_features`."""
|
||||
return self.mutable_num_features
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
"""Slice the parameters according to `mutable_num_features`, and
|
||||
forward."""
|
||||
if self.momentum is None:
|
||||
exponential_average_factor = 0.0
|
||||
else:
|
||||
exponential_average_factor = self.momentum
|
||||
|
||||
if self.training and self.track_running_stats:
|
||||
if self.num_batches_tracked is not None: # type: ignore
|
||||
self.num_batches_tracked = \
|
||||
self.num_batches_tracked + 1 # type: ignore
|
||||
if self.momentum is None: # use cumulative moving average
|
||||
exponential_average_factor = 1.0 / float(
|
||||
self.num_batches_tracked)
|
||||
else: # use exponential moving average
|
||||
exponential_average_factor = self.momentum
|
||||
|
||||
if self.training:
|
||||
bn_training = True
|
||||
else:
|
||||
bn_training = (self.running_mean is None) and (self.running_var is
|
||||
None)
|
||||
|
||||
if self.affine:
|
||||
out_mask = self.mutable_num_features.mask
|
||||
weight = self.weight[out_mask]
|
||||
bias = self.bias[out_mask]
|
||||
else:
|
||||
weight, bias = self.weight, self.bias
|
||||
|
||||
if self.track_running_stats:
|
||||
out_mask = self.mutable_num_features.mask
|
||||
running_mean = self.running_mean[out_mask] \
|
||||
if not self.training or self.track_running_stats else None
|
||||
running_var = self.running_var[out_mask] \
|
||||
if not self.training or self.track_running_stats else None
|
||||
else:
|
||||
running_mean, running_var = self.running_mean, self.running_var
|
||||
|
||||
return F.batch_norm(input, running_mean, running_var, weight, bias,
|
||||
bn_training, exponential_average_factor, self.eps)
|
||||
|
||||
|
||||
class DynamicInstanceNorm(_InstanceNorm, MutableManagerMixIn):
|
||||
"""Applies Instance Normalization over an input according to the
|
||||
`mutable_num_features` dynamically.
|
||||
|
||||
Args:
|
||||
module_name (str): Name of this `DynamicInstanceNorm`.
|
||||
num_features_cfg (Dict): Config related to `num_features`.
|
||||
"""
|
||||
|
||||
def __init__(self, module_name, num_features_cfg, *args, **kwargs):
|
||||
super(DynamicInstanceNorm, self).__init__(*args, **kwargs)
|
||||
|
||||
num_features_cfg = copy.deepcopy(num_features_cfg)
|
||||
num_features_cfg.update(
|
||||
dict(
|
||||
name=module_name,
|
||||
num_channels=self.num_features,
|
||||
mask_type='out_mask'))
|
||||
self.mutable_num_features = MODELS.build(num_features_cfg)
|
||||
|
||||
@property
|
||||
def mutable_in(self):
|
||||
"""Mutable `num_features`."""
|
||||
return self.mutable_num_features
|
||||
|
||||
@property
|
||||
def mutable_out(self):
|
||||
"""Mutable `num_features`."""
|
||||
return self.mutable_num_features
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
"""Slice the parameters according to `mutable_num_features`, and
|
||||
forward."""
|
||||
if self.affine:
|
||||
out_mask = self.mutable_num_features.mask
|
||||
weight = self.weight[out_mask]
|
||||
bias = self.bias[out_mask]
|
||||
else:
|
||||
weight, bias = self.weight, self.bias
|
||||
|
||||
if self.track_running_stats:
|
||||
out_mask = self.mutable_num_features.mask
|
||||
running_mean = self.running_mean[out_mask]
|
||||
running_var = self.running_var[out_mask]
|
||||
else:
|
||||
running_mean, running_var = self.running_mean, self.running_var
|
||||
|
||||
return F.instance_norm(input, running_mean, running_var, weight, bias,
|
||||
self.training or not self.track_running_stats,
|
||||
self.momentum, self.eps)
|
||||
|
||||
|
||||
class DynamicGroupNorm(GroupNorm, MutableManagerMixIn):
|
||||
"""Applies Group Normalization over a mini-batch of inputs according to the
|
||||
`mutable_num_channels` dynamically.
|
||||
|
||||
Args:
|
||||
module_name (str): Name of this `DynamicGroupNorm`.
|
||||
num_channels_cfg (Dict): Config related to `num_channels`.
|
||||
"""
|
||||
|
||||
def __init__(self, module_name, num_channels_cfg, *args, **kwargs):
|
||||
super(DynamicGroupNorm, self).__init__(*args, **kwargs)
|
||||
|
||||
num_channels_cfg = copy.deepcopy(num_channels_cfg)
|
||||
num_channels_cfg.update(
|
||||
dict(
|
||||
name=module_name,
|
||||
num_channels=self.num_channels,
|
||||
mask_type='out_mask'))
|
||||
self.mutable_num_channels = MODELS.build(num_channels_cfg)
|
||||
|
||||
@property
|
||||
def mutable_in(self):
|
||||
"""Mutable `num_channels`."""
|
||||
return self.mutable_num_channels
|
||||
|
||||
@property
|
||||
def mutable_out(self):
|
||||
"""Mutable `num_channels`."""
|
||||
return self.mutable_num_channels
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
"""Slice the parameters according to `mutable_num_channels`, and
|
||||
forward."""
|
||||
if self.affine:
|
||||
out_mask = self.mutable_num_channels.mask
|
||||
weight = self.weight[out_mask]
|
||||
bias = self.bias[out_mask]
|
||||
else:
|
||||
weight, bias = self.weight, self.bias
|
||||
|
||||
return F.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||
|
||||
|
||||
def build_dynamic_conv2d(module: nn.Conv2d, module_name: str,
|
||||
in_channels_cfg: Dict,
|
||||
out_channels_cfg: Dict) -> DynamicConv2d:
|
||||
"""Build DynamicConv2d.
|
||||
|
||||
Args:
|
||||
module (:obj:`torch.nn.Conv2d`): The original Conv2d module.
|
||||
module_name (str): Name of this `DynamicConv2d`.
|
||||
in_channels_cfg (Dict): Config related to `in_channels`.
|
||||
out_channels_cfg (Dict): Config related to `out_channels`.
|
||||
"""
|
||||
dynamic_conv = DynamicConv2d(
|
||||
module_name=module_name,
|
||||
in_channels_cfg=in_channels_cfg,
|
||||
out_channels_cfg=out_channels_cfg,
|
||||
in_channels=module.in_channels,
|
||||
out_channels=module.out_channels,
|
||||
kernel_size=module.kernel_size,
|
||||
stride=module.stride,
|
||||
padding=module.padding,
|
||||
dilation=module.dilation,
|
||||
groups=module.groups,
|
||||
bias=True if module.bias is not None else False,
|
||||
padding_mode=module.padding_mode)
|
||||
return dynamic_conv
|
||||
|
||||
|
||||
def build_dynamic_linear(module: nn.Linear, module_name: str,
|
||||
in_features_cfg: Dict,
|
||||
out_features_cfg: Dict) -> DynamicLinear:
|
||||
"""Build DynamicLinear.
|
||||
|
||||
Args:
|
||||
module (:obj:`torch.nn.Linear`): The original Linear module.
|
||||
module_name (str): Name of this `DynamicLinear`.
|
||||
in_features_cfg (Dict): Config related to `in_features`.
|
||||
out_features_cfg (Dict): Config related to `out_features`.
|
||||
"""
|
||||
dynamic_linear = DynamicLinear(
|
||||
module_name=module_name,
|
||||
in_features_cfg=in_features_cfg,
|
||||
out_features_cfg=out_features_cfg,
|
||||
in_features=module.in_features,
|
||||
out_features=module.out_features,
|
||||
bias=True if module.bias is not None else False)
|
||||
return dynamic_linear
|
||||
|
||||
|
||||
def build_dynamic_bn(module: _BatchNorm, module_name: str,
|
||||
num_features_cfg: Dict) -> DynamicBatchNorm:
|
||||
"""Build DynamicBatchNorm.
|
||||
|
||||
Args:
|
||||
module (:obj:`torch.nn._BatchNorm`): The original BatchNorm module.
|
||||
module_name (str): Name of this `DynamicBatchNorm`.
|
||||
num_features_cfg (Dict): Config related to `num_features`.
|
||||
"""
|
||||
dynamic_bn = DynamicBatchNorm(
|
||||
module_name=module_name,
|
||||
num_features_cfg=num_features_cfg,
|
||||
num_features=module.num_features,
|
||||
eps=module.eps,
|
||||
momentum=module.momentum,
|
||||
affine=module.affine,
|
||||
track_running_stats=module.track_running_stats)
|
||||
return dynamic_bn
|
||||
|
||||
|
||||
def build_dynamic_in(module: _InstanceNorm, module_name: str,
|
||||
num_features_cfg: Dict) -> DynamicInstanceNorm:
|
||||
"""Build DynamicInstanceNorm.
|
||||
|
||||
Args:
|
||||
module (:obj:`torch.nn._InstanceNorm`): The original InstanceNorm
|
||||
module.
|
||||
module_name (str): Name of this `DynamicInstanceNorm`.
|
||||
num_features_cfg (Dict): Config related to `num_features`.
|
||||
"""
|
||||
dynamic_in = DynamicInstanceNorm(
|
||||
module_name=module_name,
|
||||
num_features_cfg=num_features_cfg,
|
||||
num_features=module.num_features,
|
||||
eps=module.eps,
|
||||
momentum=module.momentum,
|
||||
affine=module.affine,
|
||||
track_running_stats=module.track_running_stats)
|
||||
return dynamic_in
|
||||
|
||||
|
||||
def build_dynamic_gn(module: GroupNorm, module_name: str,
|
||||
num_channels_cfg: Dict) -> DynamicGroupNorm:
|
||||
"""Build DynamicGroupNorm.
|
||||
|
||||
Args:
|
||||
module (:obj:`torch.nn.GroupNorm`): The original GroupNorm module.
|
||||
module_name (str): Name of this `DynamicGroupNorm`.
|
||||
num_channels_cfg (Dict): Config related to `num_channels`.
|
||||
"""
|
||||
dynamic_gn = DynamicGroupNorm(
|
||||
module_name=module_name,
|
||||
num_channels_cfg=num_channels_cfg,
|
||||
num_channels=module.num_channels,
|
||||
num_groups=module.num_groups,
|
||||
eps=module.eps,
|
||||
affine=module.affine)
|
||||
return dynamic_gn
|
|
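A stand-alone illustration (not from this diff) of the slicing that `DynamicConv2d.forward` performs: boolean channel masks pick the active output and input channels out of the full-sized weight before calling `F.conv2d`, so one set of parameters serves every sub-width.

import torch
import torch.nn.functional as F

weight = torch.randn(8, 4, 3, 3)                   # full conv weight (out, in, kh, kw)
out_mask = torch.tensor([True] * 6 + [False] * 2)  # keep the first 6 out channels
in_mask = torch.tensor([True] * 3 + [False] * 1)   # keep the first 3 in channels

sliced = weight[out_mask][:, in_mask]              # shape: (6, 3, 3, 3)
x = torch.randn(2, 3, 16, 16)
y = F.conv2d(x, sliced, bias=None, padding=1)
print(y.shape)                                     # torch.Size([2, 6, 16, 16])
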
@ -0,0 +1,105 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from typing import Callable, Dict
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmrazor.models.mutables import MutableManagerMixIn
|
||||
from mmrazor.registry import MODELS
|
||||
from .default_dynamic_ops import build_dynamic_conv2d, build_dynamic_linear
|
||||
|
||||
|
||||
class SwitchableBatchNorm2d(nn.Module, MutableManagerMixIn):
|
||||
"""Employs independent batch normalization for different switches in a
|
||||
slimmable network.
|
||||
|
||||
To train slimmable networks, ``SwitchableBatchNorm2d`` privatizes all
|
||||
batch normalization layers for each switch in a slimmable network.
|
||||
Compared with the naive training approach, it solves the problem of feature
|
||||
aggregation inconsistency between different switches by independently
|
||||
normalizing the feature mean and variance during testing.
|
||||
|
||||
Args:
|
||||
module_name (str): Name of this `SwitchableBatchNorm2d`.
|
||||
num_features_cfg (Dict): Config related to `num_features`.
|
||||
eps (float): A value added to the denominator for numerical stability.
|
||||
Same as that in :obj:`torch.nn._BatchNorm`. Default: 1e-5
|
||||
momentum (float): The value used for the running_mean and running_var
|
||||
computation. Can be set to None for cumulative moving average
|
||||
(i.e. simple average). Same as that in :obj:`torch.nn._BatchNorm`.
|
||||
Default: 0.1
|
||||
affine (bool): A boolean value that when set to True, this module has
|
||||
learnable affine parameters. Same as that in
|
||||
:obj:`torch.nn._BatchNorm`. Default: True
|
||||
track_running_stats (bool): A boolean value that when set to True, this
|
||||
module tracks the running mean and variance, and when set to False,
|
||||
this module does not track such statistics, and initializes
|
||||
statistics buffers running_mean and running_var as None. When these
|
||||
buffers are None, this module always uses batch statistics in both
|
||||
training and eval modes. Same as that in
|
||||
:obj:`torch.nn._BatchNorm`. Default: True
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
module_name: str,
|
||||
num_features_cfg: Dict,
|
||||
eps: float = 1e-5,
|
||||
momentum: float = 0.1,
|
||||
affine: bool = True,
|
||||
track_running_stats: bool = True):
|
||||
super(SwitchableBatchNorm2d, self).__init__()
|
||||
|
||||
num_features_cfg = copy.deepcopy(num_features_cfg)
|
||||
candidate_choices = num_features_cfg.pop('candidate_choices')
|
||||
num_features_cfg.update(
|
||||
dict(
|
||||
name=module_name,
|
||||
num_channels=max(candidate_choices),
|
||||
mask_type='out_mask'))
|
||||
|
||||
bns = [
|
||||
nn.BatchNorm2d(num_features, eps, momentum, affine,
|
||||
track_running_stats)
|
||||
for num_features in candidate_choices
|
||||
]
|
||||
self.bns = nn.ModuleList(bns)
|
||||
|
||||
self.mutable_num_features = MODELS.build(num_features_cfg)
|
||||
|
||||
@property
|
||||
def mutable_out(self):
|
||||
"""Mutable `num_features`."""
|
||||
return self.mutable_num_features
|
||||
|
||||
def forward(self, input):
|
||||
"""Forward computation according to the current switch of the slimmable
|
||||
networks."""
|
||||
idx = self.mutable_num_features.current_choice
|
||||
return self.bns[idx](input)
|
||||
|
||||
|
||||
def build_switchable_bn(module: _BatchNorm, module_name: str,
|
||||
num_features_cfg: Dict) -> SwitchableBatchNorm2d:
|
||||
"""Build SwitchableBatchNorm2d.
|
||||
|
||||
Args:
|
||||
module (:obj:`torch.nn._BatchNorm`): The original BatchNorm module.
module_name (str): Name of this `SwitchableBatchNorm2d`.
num_features_cfg (Dict): Config related to `num_features`.
|
||||
"""
|
||||
switchable_bn = SwitchableBatchNorm2d(
|
||||
module_name=module_name,
|
||||
num_features_cfg=num_features_cfg,
|
||||
eps=module.eps,
|
||||
momentum=module.momentum,
|
||||
affine=module.affine,
|
||||
track_running_stats=module.track_running_stats)
|
||||
return switchable_bn
|
||||
|
||||
|
||||
SLIMMABLE_DYNAMIC_LAYER: Dict[Callable, Callable] = {
|
||||
nn.Conv2d: build_dynamic_conv2d,
|
||||
nn.Linear: build_dynamic_linear,
|
||||
nn.BatchNorm2d: build_switchable_bn
|
||||
}
|
|
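The core idea of `SwitchableBatchNorm2d`, reduced to plain PyTorch as a sketch (not from this diff; no registry involved, and the class and attribute names below are made up): one private `BatchNorm2d` per candidate width, selected by an index at forward time.

import torch
import torch.nn as nn

class TinySwitchableBN(nn.Module):
    def __init__(self, candidate_channels=(8, 16, 32)):
        super().__init__()
        self.bns = nn.ModuleList(
            [nn.BatchNorm2d(c) for c in candidate_channels])
        self.idx = len(candidate_channels) - 1  # start at the widest switch

    def forward(self, x):
        # The caller must feed a tensor whose channel count matches the
        # currently selected switch.
        return self.bns[self.idx](x)

bn = TinySwitchableBN()
bn.idx = 0
print(bn(torch.randn(2, 8, 4, 4)).shape)  # torch.Size([2, 8, 4, 4])
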
@ -1,9 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .diff_mutable import (DiffChoiceRoute, DiffMutable, DiffOP,
                           GumbelChoiceRoute)
from .mutable_channel import (OneShotChannelMutable, OrderChannelMutable,
                              RatioChannelMutable, SlimmableChannelMutable)
from .mutable_manager_mixin import MutableManagerMixIn
from .oneshot_mutable import OneShotMutable, OneShotOP

__all__ = [
    'OneShotOP', 'OneShotMutable', 'DiffOP', 'DiffChoiceRoute',
    'GumbelChoiceRoute', 'DiffMutable'
    'OneShotOP', 'OneShotMutable', 'OneShotChannelMutable',
    'RatioChannelMutable', 'OrderChannelMutable', 'DiffOP', 'DiffChoiceRoute',
    'GumbelChoiceRoute', 'DiffMutable', 'MutableManagerMixIn',
    'SlimmableChannelMutable'
]

@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .one_shot_channel_mutable import OneShotChannelMutable
from .order_channel_mutable import OrderChannelMutable
from .ratio_channel_mutable import RatioChannelMutable
from .slimmable_channel_mutable import SlimmableChannelMutable

__all__ = [
    'OneShotChannelMutable', 'OrderChannelMutable', 'RatioChannelMutable',
    'SlimmableChannelMutable'
]

@ -0,0 +1,157 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.runner import BaseModule
|
||||
|
||||
|
||||
class OneShotChannelMutable(BaseModule, ABC):
|
||||
"""A type of ``MUTABLES`` for single path supernet such as AutoSlim. In
|
||||
single path supernet, each module only has one choice invoked at the same
|
||||
time. A path is obtained by sampling all the available choices. It is the
|
||||
base class for one shot channel mutables.
|
||||
|
||||
Args:
|
||||
name (str): Mutable name.
|
||||
mask_type (str): One of 'in_mask' or 'out_mask'.
|
||||
num_channels (int): The raw number of channels.
|
||||
init_cfg (dict, optional): initialization configuration dict for
|
||||
``BaseModule``. OpenMMLab has implemented several initializers, including
|
||||
`Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
|
||||
and `Pretrained`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
name: str,
|
||||
mask_type: str,
|
||||
num_channels: int,
|
||||
init_cfg: Optional[Dict] = None):
|
||||
super(OneShotChannelMutable, self).__init__(init_cfg=init_cfg)
|
||||
# If the input of a module is a concatenation of several modules'
|
||||
# outputs, we add the out_mutable (mask_type == 'out_mask') of
|
||||
# these modules to the `concat_mutables` of this module.
|
||||
self.concat_mutables: List[OneShotChannelMutable] = list()
|
||||
self.name = name
|
||||
assert mask_type in ('in_mask', 'out_mask')
|
||||
self.mask_type = mask_type
|
||||
self.num_channels = num_channels
|
||||
self.register_buffer('_mask', torch.ones((num_channels, )).bool())
|
||||
self._current_choice = num_channels
|
||||
|
||||
self._same_mutables: List[OneShotChannelMutable] = list()
|
||||
|
||||
@property
|
||||
def same_mutables(self):
|
||||
"""Mutables in `same_mutables` and the current mutable should change
|
||||
synchronously."""
|
||||
return self._same_mutables
|
||||
|
||||
def register_same_mutable(self, mutable):
|
||||
"""Register the input mutable in `same_mutables`."""
|
||||
if isinstance(mutable, list):
|
||||
# Add a concatenation of mutables to `concat_mutables`.
|
||||
assert self.mask_type == 'in_mask'
|
||||
assert all([
|
||||
cur_mutable.mask_type == 'out_mask' for cur_mutable in mutable
|
||||
])
|
||||
self.concat_mutables = mutable
|
||||
return
|
||||
|
||||
if self == mutable:
|
||||
return
|
||||
if mutable in self._same_mutables:
|
||||
return
|
||||
|
||||
self._same_mutables.append(mutable)
|
||||
for s_mutable in self._same_mutables:
|
||||
s_mutable.register_same_mutable(mutable)
|
||||
mutable.register_same_mutable(s_mutable)
|
||||
|
||||
def sample_choice(self) -> int:
|
||||
"""Sample an arbitrary selection from candidate choices.
|
||||
|
||||
Returns:
|
||||
int: The chosen number of channels.
|
||||
"""
|
||||
assert len(self.concat_mutables) == 0
|
||||
num_channels = np.random.choice(self.choices)
|
||||
assert num_channels > 0, \
|
||||
f'Sampled number of channels in `Mutable` {self.name}' \
|
||||
f' should be a positive integer.'
|
||||
return num_channels
|
||||
|
||||
@property
|
||||
def min_choice(self) -> int:
|
||||
"""Minimum number of channels."""
|
||||
assert len(self.concat_mutables) == 0
|
||||
min_channels = min(self.choices)
|
||||
assert min_channels > 0, \
|
||||
f'Minimum number of channels in `Mutable` {self.name}' \
|
||||
f' should be a positive integer.'
|
||||
return min_channels
|
||||
|
||||
@property
|
||||
def max_choice(self) -> int:
|
||||
"""Maximum number of channels."""
|
||||
return max(self.choices)
|
||||
|
||||
def get_choice(self, idx: int) -> int:
|
||||
"""Get the `idx`-th choice from candidate choices."""
|
||||
assert len(self.concat_mutables) == 0
|
||||
num_channels = self.choices[idx]
|
||||
assert num_channels > 0, \
|
||||
f'Number of channels in `Mutable` {self.name}' \
|
||||
f' should be a positive integer.'
|
||||
return num_channels
|
||||
|
||||
@property
|
||||
def current_choice(self):
|
||||
"""The current choice of the mutable."""
|
||||
if len(self.concat_mutables) > 0:
|
||||
return sum(
|
||||
[mutable.current_choice for mutable in self.concat_mutables])
|
||||
else:
|
||||
return self._current_choice
|
||||
|
||||
@current_choice.setter
|
||||
def current_choice(self, choice: int):
|
||||
"""Set the current choice of the mutable."""
|
||||
assert choice in self.choices
|
||||
self._current_choice = choice
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def choices(self) -> List[int]:
|
||||
"""list: all choices. """
|
||||
|
||||
@property
|
||||
def mask(self):
|
||||
"""The current mask.
|
||||
|
||||
We slice the registered parameters and buffers of a ``nn.Module``
|
||||
according to the mask of the corresponding channel mutable.
|
||||
"""
|
||||
if len(self.concat_mutables) > 0:
|
||||
# If the input of a module is a concatenation of several modules'
|
||||
# outputs, the in_mask of this module is the concatenation of
|
||||
# these modules' out_mask.
|
||||
return torch.cat(
|
||||
[mutable.mask for mutable in self.concat_mutables])
|
||||
else:
|
||||
num_channels = self.current_choice
|
||||
mask = torch.zeros_like(self._mask).bool()
|
||||
mask[:num_channels] = True
|
||||
return mask
|
||||
|
||||
def __repr__(self):
|
||||
concat_mutable_name = [
|
||||
mutable.name for mutable in self.concat_mutables
|
||||
]
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(name={self.name}, '
|
||||
repr_str += f'mask_type={self.mask_type}, '
|
||||
repr_str += f'num_channels={self.num_channels}, '
|
||||
repr_str += f'concat_mutable_name={concat_mutable_name})'
|
||||
return repr_str
|
|
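A small illustration (not part of this diff) of the mask convention used by the channel mutables: choosing k channels keeps the first k entries of the boolean mask, and the `in_mask` of a module fed by a concatenation is just the concatenation of the producers' `out_mask`s.

import torch

num_channels, current_choice = 8, 3
mask = torch.zeros(num_channels).bool()
mask[:current_choice] = True
print(mask)  # tensor([ True,  True,  True, False, False, False, False, False])

out_mask_a = torch.ones(4).bool()
out_mask_b = torch.zeros(2).bool()
in_mask = torch.cat([out_mask_a, out_mask_b])  # in_mask of the module fed by cat([a, b])
print(in_mask.shape)  # torch.Size([6])
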
@ -0,0 +1,50 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from mmrazor.registry import MODELS
|
||||
from .one_shot_channel_mutable import OneShotChannelMutable
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class OrderChannelMutable(OneShotChannelMutable):
|
||||
"""A type of ``OneShotChannelMutable``. The input candidate choices are
|
||||
candidate channel numbers.
|
||||
|
||||
Args:
|
||||
name (str): Mutable name.
|
||||
mask_type (str): One of 'in_mask' or 'out_mask'.
|
||||
num_channels (int): The raw number of channels.
|
||||
candidate_choices (list | tuple): A list or tuple of candidate
|
||||
channel numbers.
|
||||
init_cfg (dict, optional): initialization configuration dict for
|
||||
``BaseModule``. OpenMMLab has implemented several initializers, including
|
||||
`Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
|
||||
and `Pretrained`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
name: str,
|
||||
mask_type: str,
|
||||
num_channels: int,
|
||||
candidate_choices: Union[List, Tuple],
|
||||
init_cfg: Optional[Dict] = None):
|
||||
super(OrderChannelMutable, self).__init__(
|
||||
name, mask_type, num_channels, init_cfg=init_cfg)
|
||||
|
||||
assert len(candidate_choices) > 0, \
|
||||
f'Number of candidate choices must be greater than 0, ' \
|
||||
f'but got: {len(candidate_choices)}'
|
||||
self._candidate_choices = list(candidate_choices)
|
||||
|
||||
assert all([num > 0 and num <= self.num_channels
|
||||
for num in self._candidate_choices]), \
|
||||
f'The candidate channel numbers should be in ' \
|
||||
f'range(0, {self.num_channels}].'
|
||||
assert all([isinstance(num, int)
|
||||
for num in self._candidate_choices]),\
|
||||
'Each element of `candidate_choices` should be an int.'
|
||||
|
||||
@property
|
||||
def choices(self) -> List[int]:
|
||||
"""list: all choices. """
|
||||
return self._candidate_choices
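# Usage sketch, assuming `OrderChannelMutable` is exported from
# `mmrazor.models.mutables` (the import path is an assumption):
# mutable = OrderChannelMutable(name='backbone.conv1', mask_type='out_mask',
#                               num_channels=8, candidate_choices=[2, 4, 8])
# mutable.choices              # -> [2, 4, 8]
# mutable.current_choice = 4   # keep the first 4 channels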
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from mmrazor.registry import MODELS
|
||||
from .one_shot_channel_mutable import OneShotChannelMutable
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class RatioChannelMutable(OneShotChannelMutable):
|
||||
"""A type of ``OneShotChannelMutable``. The input candidate choices are
|
||||
candidate width ratios.
|
||||
|
||||
Notes:
|
||||
We first calculate the candidate channel numbers according to
|
||||
the input candidate ratios (`candidate_choices`) and regard them as
|
||||
available choices.
|
||||
|
||||
Args:
|
||||
name (str): Mutable name.
|
||||
mask_type (str): One of 'in_mask' or 'out_mask'.
|
||||
num_channels (int): The raw number of channels.
|
||||
candidate_choices (list | tuple): A list or tuple of candidate width
|
||||
ratios. The width ratio is the ratio between the number of reserved
|
||||
channels and that of all channels in a layer.
|
||||
For example, if `ratios` is [0.25, 0.5], there are 2 cases
|
||||
for us to choose from when we sample from a layer with 12 channels.
|
||||
One is sampling the very first 3 channels in this layer, another is
|
||||
sampling the very first 6 channels in this layer.
|
||||
init_cfg (dict, optional): initialization configuration dict for
|
||||
``BaseModule``. OpenMMLab has implemented 6 initializers, including
|
||||
`Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
|
||||
and `Pretrained`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
name: str,
|
||||
mask_type: str,
|
||||
num_channels: int,
|
||||
candidate_choices: Union[List, Tuple],
|
||||
init_cfg: Optional[Dict] = None):
|
||||
super(RatioChannelMutable, self).__init__(
|
||||
name, mask_type, num_channels, init_cfg=init_cfg)
|
||||
|
||||
assert len(candidate_choices) > 0, \
|
||||
f'Number of candidate choices must be greater than 0, ' \
|
||||
f'but got: {len(candidate_choices)}'
|
||||
self._candidate_choices = candidate_choices
|
||||
|
||||
assert all([
|
||||
ratio > 0 and ratio <= 1 for ratio in self._candidate_choices
|
||||
]), 'The candidate ratio should be in range(0, 1].'
|
||||
|
||||
@property
|
||||
def choices(self) -> List[int]:
|
||||
"""list: all choices. """
|
||||
return [
|
||||
round(ratio * self.num_channels)
|
||||
for ratio in self._candidate_choices
|
||||
]
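# A minimal sketch of the ratio-to-channel conversion performed by `choices`,
# assuming 12 raw channels and candidate ratios [0.25, 0.5, 1.0]:
num_channels = 12
candidate_ratios = [0.25, 0.5, 1.0]
choices = [round(ratio * num_channels) for ratio in candidate_ratios]
# choices -> [3, 6, 12]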
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from mmcv.runner import BaseModule
|
||||
|
||||
from mmrazor.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class SlimmableChannelMutable(BaseModule):
|
||||
"""A type of ``MUTABLES`` to train several subnet together, such as the
|
||||
retraining stage in AutoSlim.
|
||||
|
||||
Notes:
|
||||
We need to set `candidate_choices` manually after instantiating a
`SlimmableChannelMutable`.
|
||||
|
||||
Args:
|
||||
name (str): Mutable name.
|
||||
mask_type (str): One of 'in_mask' or 'out_mask'.
|
||||
num_channels (int): The raw number of channels.
|
||||
init_cfg (dict, optional): initialization configuration dict for
|
||||
``BaseModule``. OpenMMLab has implemented 6 initializers, including
|
||||
`Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
|
||||
and `Pretrained`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
name: str,
|
||||
mask_type: str,
|
||||
num_channels: int,
|
||||
init_cfg: Optional[Dict] = None):
|
||||
super(SlimmableChannelMutable, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
self.name = name
|
||||
assert mask_type in ('in_mask', 'out_mask')
|
||||
self.mask_type = mask_type
|
||||
self.num_channels = num_channels
|
||||
self.register_buffer('_mask', torch.ones((num_channels, )).bool())
|
||||
self._current_choice = 0
|
||||
|
||||
@property
|
||||
def candidate_choices(self) -> List:
|
||||
"""A list of candidate channel numbers."""
|
||||
return self._candidate_choices
|
||||
|
||||
@candidate_choices.setter
|
||||
def candidate_choices(self, choices):
|
||||
"""Set the candidate channel numbers."""
|
||||
assert getattr(self, '_candidate_choices', None) is None, \
|
||||
f'candidate_choices can be set only when candidate_choices is ' \
|
||||
f'None, got: candidate_choices = {self._candidate_choices}'
|
||||
|
||||
assert all([num > 0 and num <= self.num_channels
|
||||
for num in choices]), \
|
||||
f'The candidate channel numbers should be in ' \
|
||||
f'range(0, {self.num_channels}].'
|
||||
assert all([isinstance(num, int) for num in choices]), \
|
||||
'Each element of `candidate_choices` should be an int.'
|
||||
|
||||
self._candidate_choices = list(choices)
|
||||
|
||||
@property
|
||||
def choices(self) -> List:
|
||||
"""Return all subnet indexes."""
|
||||
assert getattr(self, '_candidate_choices', None) is not None
|
||||
return list(range(len(self.candidate_choices)))
|
||||
|
||||
@property
|
||||
def current_choice(self) -> int:
|
||||
"""The current choice of the mutable."""
|
||||
return self._current_choice
|
||||
|
||||
@current_choice.setter
|
||||
def current_choice(self, choice: int):
|
||||
"""Set the current choice of the mutable."""
|
||||
assert choice in self.choices
|
||||
self._current_choice = choice
|
||||
|
||||
@property
|
||||
def mask(self):
|
||||
"""The current mask.
|
||||
|
||||
We slice the registered parameters and buffers of a ``nn.Module``
|
||||
according to the mask of the corresponding channel mutable.
|
||||
"""
|
||||
idx = self.current_choice
|
||||
num_channels = self.candidate_choices[idx]
|
||||
mask = torch.zeros_like(self._mask).bool()
|
||||
mask[:num_channels] = True
|
||||
return mask
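# Sketch: here `current_choice` is a subnet index into `candidate_choices`,
# not a channel number. Assuming candidate_choices = [8, 16, 32] and
# current_choice = 1:
import torch

candidate_choices = [8, 16, 32]
current_choice = 1
raw_mask = torch.ones((32, )).bool()      # plays the role of `self._mask`
mask = torch.zeros_like(raw_mask).bool()
mask[:candidate_choices[current_choice]] = True
# the first 16 entries of `mask` are True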
|
|
@ -0,0 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
|
||||
class MutableManagerMixIn:
|
||||
"""Mixin class for determining whether an object is a dynamic layer.
|
||||
|
||||
Note that a dynamic layer manages one or several mutables.
|
||||
"""
|
||||
pass
|
|
@ -1,5 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .channel_mutator.one_shot_channel_mutator import OneShotChannelMutator
|
||||
from .channel_mutator.slimmable_channel_mutator import SlimmableChannelMutator
|
||||
from .diff_mutator import DiffMutator
|
||||
from .one_shot_mutator import OneShotMutator
|
||||
|
||||
__all__ = ['OneShotMutator', 'DiffMutator']
|
||||
__all__ = [
|
||||
'OneShotMutator', 'OneShotChannelMutator', 'SlimmableChannelMutator',
|
||||
'DiffMutator'
|
||||
]
|
||||
|
|
|
@ -38,7 +38,7 @@ class BaseMutator(ABC, BaseModule):
|
|||
|
||||
@property
|
||||
@abstractmethod
|
||||
def search_group(self) -> Dict:
|
||||
def search_groups(self) -> Dict:
|
||||
"""Search group of the supernet.
|
||||
|
||||
Note:
|
||||
|
@ -76,7 +76,7 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]):
|
|||
if custom_group is None:
|
||||
custom_group = []
|
||||
self._custom_group = custom_group
|
||||
self._search_group: Optional[Dict[int, List[MUTABLE_TYPE]]] = None
|
||||
self._search_groups: Optional[Dict[int, List[MUTABLE_TYPE]]] = None
|
||||
|
||||
# TODO
|
||||
# should be a class property
|
||||
|
@ -99,10 +99,10 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]):
|
|||
supernet (:obj:`torch.nn.Module`): The supernet to be searched
|
||||
in your algorithm.
|
||||
"""
|
||||
self._build_search_group(supernet)
|
||||
self._build_search_groups(supernet)
|
||||
|
||||
@property
|
||||
def search_group(self) -> Dict[int, List[MUTABLE_TYPE]]:
|
||||
def search_groups(self) -> Dict[int, List[MUTABLE_TYPE]]:
|
||||
"""Search group of supernet.
|
||||
|
||||
Note:
|
||||
|
@ -115,10 +115,10 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]):
|
|||
Returns:
|
||||
Dict[int, List[MUTABLE_TYPE]]: Search group.
|
||||
"""
|
||||
if self._search_group is None:
|
||||
if self._search_groups is None:
|
||||
raise RuntimeError(
|
||||
'Call `prepare_from_supernet` before access search group!')
|
||||
return self._search_group
|
||||
return self._search_groups
|
||||
|
||||
def _build_name_mutable_mapping(
|
||||
self, supernet: Module) -> Dict[str, MUTABLE_TYPE]:
|
||||
|
@ -143,7 +143,7 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]):
|
|||
|
||||
return alias2mutable_names
|
||||
|
||||
def _build_search_group(self, supernet: Module) -> None:
|
||||
def _build_search_groups(self, supernet: Module) -> None:
|
||||
"""Build search group with ``custom_group`` and ``alias``(see more
|
||||
information in :class:`BaseMutable`). Grouping by alias and module name
|
||||
are both supported.
|
||||
|
@ -176,20 +176,20 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]):
|
|||
>>> # Using alias for grouping
|
||||
>>> mutator = DiffOP(custom_group=[['a1'], ['a2']])
|
||||
>>> mutator.prepare_from_supernet(model)
|
||||
>>> mutator.search_group
|
||||
>>> mutator.search_groups
|
||||
{0: [op1, op2], 1: [op3]}
|
||||
|
||||
>>> # Using module name for grouping
|
||||
>>> mutator = DiffOP(custom_group=[['op1', 'op2'], ['op3']])
|
||||
>>> mutator.prepare_from_supernet(model)
|
||||
>>> mutator.search_group
|
||||
>>> mutator.search_groups
|
||||
{0: [op1, op2], 1: [op3]}
|
||||
|
||||
>>> # Using both alias and module name for grouping
|
||||
>>> mutator = DiffOP(custom_group=[['a2'], ['op2']])
|
||||
>>> mutator.prepare_from_supernet(model)
|
||||
>>> # The last operation would be grouped
|
||||
>>> mutator.search_group
|
||||
>>> mutator.search_groups
|
||||
{0: [op3], 1: [op2], 2: [op1]}
|
||||
|
||||
|
||||
|
@ -271,7 +271,7 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]):
|
|||
f'The duplicate keys are {duplicate_keys}. ' \
|
||||
'Please check if there are duplicate keys in the `custom_group`.'
|
||||
|
||||
self._search_group = search_groups
|
||||
self._search_groups = search_groups
|
||||
|
||||
def _check_valid_groups(self, alias2mutable_names: Dict[str, List[str]],
|
||||
name2mutable: Dict[str, MUTABLE_TYPE],
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .channel_mutator import ChannelMutator
|
||||
from .one_shot_channel_mutator import OneShotChannelMutator
|
||||
from .slimmable_channel_mutator import SlimmableChannelMutator
|
||||
|
||||
__all__ = [
|
||||
'ChannelMutator', 'OneShotChannelMutator', 'SlimmableChannelMutator'
|
||||
]
|
|
@ -0,0 +1,277 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn import Module
|
||||
from torch.nn.modules import GroupNorm
|
||||
from torch.nn.modules.batchnorm import (BatchNorm1d, BatchNorm2d, BatchNorm3d,
|
||||
_BatchNorm)
|
||||
from torch.nn.modules.instancenorm import (InstanceNorm1d, InstanceNorm2d,
|
||||
InstanceNorm3d, _InstanceNorm)
|
||||
|
||||
from mmrazor.core.tracer import (ConcatNode, ConvNode, DepthWiseConvNode,
|
||||
LinearNode, NormNode, PathList)
|
||||
from mmrazor.models.architectures.dynamic_op import (build_dynamic_bn,
|
||||
build_dynamic_conv2d,
|
||||
build_dynamic_gn,
|
||||
build_dynamic_in,
|
||||
build_dynamic_linear)
|
||||
from mmrazor.models.mutables import OneShotChannelMutable
|
||||
from mmrazor.registry import MODELS, TASK_UTILS
|
||||
from ..base_mutator import BaseMutator
|
||||
|
||||
NONPASS_NODES = (ConvNode, LinearNode, ConcatNode)
|
||||
PASS_NODES = (NormNode, DepthWiseConvNode)
|
||||
|
||||
NONPASS_MODULES = (nn.Conv2d, nn.Linear)
|
||||
PASS_MODULES = (_BatchNorm, _InstanceNorm, GroupNorm)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ChannelMutator(BaseMutator):
|
||||
"""Base class for channel-based mutators.
|
||||
|
||||
Args:
|
||||
mutable_cfg (dict): The config for the channel mutable.
|
||||
tracer_cfg (dict, optional): The config for the model tracer.
|
||||
We trace the topology of a given model with the tracer.
|
||||
skip_prefixes (List[str], optional): Modules whose names start with
|
||||
a string in skip_prefixes will not be pruned.
|
||||
init_cfg (dict, optional): The config to control the initialization.
|
||||
|
||||
Attributes:
|
||||
search_groups (Dict[int, List]): Search group of supernet. Note that
|
||||
the search group of a mutable based channel mutator is composed of
|
||||
corresponding mutables. Mutables in the same search group should
|
||||
be pruned together.
|
||||
name2module (Dict[str, :obj:`torch.nn.Module`]): The mapping from
|
||||
a module name to the module.
|
||||
|
||||
Notes:
|
||||
To avoid ambiguity, we only allow the following two cases:
|
||||
1. None of the parent nodes of a node is a `ConcatNode`.
|
||||
2. A node has only one parent node which is a `ConcatNode`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mutable_cfg: Dict,
|
||||
tracer_cfg: Optional[Dict] = None,
|
||||
skip_prefixes: Optional[List[str]] = None,
|
||||
init_cfg: Optional[Dict] = None,
|
||||
) -> None:
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.mutable_cfg = mutable_cfg
|
||||
if tracer_cfg:
|
||||
self.tracer = TASK_UTILS.build(tracer_cfg)
|
||||
else:
|
||||
self.tracer = None
|
||||
self.skip_prefixes = skip_prefixes
|
||||
self._search_groups: Optional[Dict[int, List[Module]]] = None
|
||||
|
||||
def add_link(self, path_list: PathList) -> None:
|
||||
"""Establish the relationship between the current nodes and their
|
||||
parents."""
|
||||
for path in path_list:
|
||||
pre_node = None
|
||||
for node in path:
|
||||
if isinstance(node, DepthWiseConvNode):
|
||||
module = self.name2module[node.name]
|
||||
# The in_channels and out_channels of a depth-wise conv
|
||||
# should be the same
|
||||
module.mutable_out.register_same_mutable(module.mutable_in)
|
||||
module.mutable_in.register_same_mutable(module.mutable_out)
|
||||
|
||||
if isinstance(node, ConcatNode):
|
||||
if pre_node is not None:
|
||||
module_names = node.get_module_names()
|
||||
concat_modules = [
|
||||
self.name2module[name] for name in module_names
|
||||
]
|
||||
concat_mutables = [
|
||||
module.mutable_out for module in concat_modules
|
||||
]
|
||||
pre_module = self.name2module[pre_node.name]
|
||||
pre_module.mutable_in.register_same_mutable(
|
||||
concat_mutables)
|
||||
|
||||
for cur_path_list in node:
|
||||
self.add_link(cur_path_list)
|
||||
|
||||
# ConcatNode is the last node in a path
|
||||
break
|
||||
|
||||
if pre_node is None:
|
||||
pre_node = node
|
||||
continue
|
||||
|
||||
pre_module = self.name2module[pre_node.name]
|
||||
cur_module = self.name2module[node.name]
|
||||
pre_module.mutable_in.register_same_mutable(
|
||||
cur_module.mutable_out)
|
||||
cur_module.mutable_out.register_same_mutable(
|
||||
pre_module.mutable_in)
|
||||
|
||||
pre_node = node
|
||||
|
||||
def prepare_from_supernet(self, supernet: Module) -> None:
|
||||
"""Do some necessary preparations with supernet.
|
||||
|
||||
We support the following two cases:
|
||||
|
||||
Case 1: The input is the original nn.Module. We first replace the
|
||||
conv/linear/norm modules in the input supernet with dynamic ops.
|
||||
And trace the topology of the supernet. Finally, `search_groups` can be
|
||||
built based on the topology.
|
||||
|
||||
Case 2: The input supernet is made up of dynamic ops. In this case,
|
||||
relationship between nodes and their parents must have been
|
||||
established and topology of the supernet is available for us. Then
|
||||
`search_groups` can be built based on the topology.
|
||||
|
||||
Args:
|
||||
supernet (:obj:`torch.nn.Module`): The supernet to be searched
|
||||
in your algorithm.
|
||||
"""
|
||||
if self.tracer is not None:
|
||||
self.convert_dynamic_module(supernet, self.dynamic_layer)
|
||||
# The mapping from a module name to the module
|
||||
self._name2module = dict(supernet.named_modules())
|
||||
|
||||
assert self.tracer is not None
|
||||
module_path_list: PathList = self.tracer.trace(supernet)
|
||||
|
||||
self.add_link(module_path_list)
|
||||
else:
|
||||
self._name2module = dict(supernet.named_modules())
|
||||
|
||||
self._search_groups = self.build_search_groups(supernet)
|
||||
|
||||
@staticmethod
|
||||
def find_same_mutables(supernet) -> Dict:
|
||||
"""The mutables in the same group should be pruned together."""
|
||||
visited = []
|
||||
groups = {}
|
||||
group_idx = 0
|
||||
for name, module in supernet.named_modules():
|
||||
if isinstance(module, OneShotChannelMutable):
|
||||
same_mutables = module.same_mutables
|
||||
if module not in visited and len(same_mutables) > 0:
|
||||
groups[group_idx] = [module] + same_mutables
|
||||
visited.extend(groups[group_idx])
|
||||
group_idx += 1
|
||||
return groups
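# Illustration: if conv1.mutable_out, bn1.mutable_out and conv2.mutable_in
# were registered as `same_mutables` of one another during `add_link`, the
# loop above collects them into one group, e.g.
# {0: [conv1.mutable_out, bn1.mutable_out, conv2.mutable_in]}
# (module names are illustrative).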
|
||||
|
||||
def convert_dynamic_module(self, supernet: Module, dynamic_layer: Dict):
|
||||
"""Replace the conv/linear/norm modules in the input supernet with
|
||||
dynamic ops.
|
||||
|
||||
Args:
|
||||
supernet (:obj:`torch.nn.Module`): The architecture to be converted
|
||||
in your algorithm.
|
||||
dynamic_layer (Dict): The mapping from the module type to the
|
||||
corresponding dynamic layer.
|
||||
"""
|
||||
|
||||
def traverse(module, prefix):
|
||||
for name, child in module.named_children():
|
||||
module_name = prefix + name
|
||||
if isinstance(child, NONPASS_MODULES):
|
||||
mutable_cfg = copy.deepcopy(self.mutable_cfg)
|
||||
# mutable_cfg.update(dict(name=module_name))
|
||||
layer = dynamic_layer[type(child)](child, module_name,
|
||||
mutable_cfg,
|
||||
mutable_cfg)
|
||||
setattr(module, name, layer)
|
||||
elif isinstance(child, PASS_MODULES):
|
||||
mutable_cfg = copy.deepcopy(self.mutable_cfg)
|
||||
# mutable_cfg.update(dict(name=module_name))
|
||||
layer = dynamic_layer[type(child)](child, module_name,
|
||||
mutable_cfg)
|
||||
setattr(module, name, layer)
|
||||
else:
|
||||
traverse(child, module_name + '.')
|
||||
|
||||
traverse(supernet, '')
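# Illustration of the conversion above: after `traverse`, a plain `nn.Conv2d`
# child has been replaced by a dynamic conv built from `self.mutable_cfg`
# that owns `mutable_in` / `mutable_out` channel mutables, so pruning
# decisions can later be applied through their masks (a sketch of the intent,
# not additional behaviour).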
|
||||
|
||||
@abstractmethod
|
||||
def build_search_groups(self, supernet: Module):
|
||||
"""Build `search_groups`.
|
||||
|
||||
The mutables in the same group should be pruned together.
|
||||
"""
|
||||
|
||||
@property
|
||||
def search_groups(self) -> Dict[int, List]:
|
||||
"""Search group of supernet.
|
||||
|
||||
Note:
|
||||
For mutable based mutator, the search group is composed of
|
||||
corresponding mutables.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Called before search group has been built.
|
||||
|
||||
Returns:
|
||||
Dict[int, List[MUTABLE_TYPE]]: Search group.
|
||||
"""
|
||||
if self._search_groups is None:
|
||||
raise RuntimeError(
|
||||
'Call `prepare_from_supernet` before accessing `search_groups`!')
|
||||
return self._search_groups
|
||||
|
||||
@property
|
||||
def name2module(self):
|
||||
"""The mapping from a module name to the module.
|
||||
|
||||
Returns:
|
||||
dict: The name to module mapping.
|
||||
"""
|
||||
if hasattr(self, '_name2module'):
|
||||
return self._name2module
|
||||
else:
|
||||
raise RuntimeError('Call `prepare_from_supernet` before accessing `name2module`!')
|
||||
|
||||
@property
|
||||
def dynamic_layer(self) -> Dict:
|
||||
"""The mapping from a type to the corresponding dynamic layer. It is
|
||||
called in `prepare_from_supernet`.
|
||||
|
||||
Returns:
|
||||
dict: The mapping dict.
|
||||
"""
|
||||
|
||||
dynamic_layer: Dict[Callable, Callable] = {
|
||||
nn.Conv2d: build_dynamic_conv2d,
|
||||
nn.Linear: build_dynamic_linear,
|
||||
BatchNorm1d: build_dynamic_bn,
|
||||
BatchNorm2d: build_dynamic_bn,
|
||||
BatchNorm3d: build_dynamic_bn,
|
||||
InstanceNorm1d: build_dynamic_in,
|
||||
InstanceNorm2d: build_dynamic_in,
|
||||
InstanceNorm3d: build_dynamic_in,
|
||||
GroupNorm: build_dynamic_gn
|
||||
}
|
||||
|
||||
return dynamic_layer
|
||||
|
||||
def is_skip_pruning(self, module_name: str,
|
||||
skip_prefixes: Optional[List[str]]) -> bool:
|
||||
"""Judge if the module with the input `module_name` should not be
|
||||
pruned.
|
||||
|
||||
Args:
|
||||
module_name (str): Module name.
|
||||
skip_prefixes (list or None): Modules whose names start with
|
||||
a string in `skip_prefixes` will not be pruned.
|
||||
"""
|
||||
skip_pruning = False
|
||||
if skip_prefixes:
|
||||
for prefix in skip_prefixes:
|
||||
if module_name.startswith(prefix):
|
||||
skip_pruning = True
|
||||
break
|
||||
return skip_pruning
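# Quick check of the prefix rule above (illustrative values):
# is_skip_pruning('head.fc', ['head'])           -> True
# is_skip_pruning('backbone.conv1', ['head'])    -> False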
|
|
@ -0,0 +1,135 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from torch.nn import Module
|
||||
|
||||
from mmrazor.core.tracer import (ConcatNode, ConvNode, DepthWiseConvNode,
|
||||
LinearNode, NormNode)
|
||||
from mmrazor.registry import MODELS
|
||||
from .channel_mutator import ChannelMutator
|
||||
|
||||
NONPASS_NODES = (ConvNode, LinearNode, ConcatNode)
|
||||
PASS_NODES = (NormNode, DepthWiseConvNode)
|
||||
PRUNING_NODES = (ConvNode, LinearNode)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class OneShotChannelMutator(ChannelMutator):
|
||||
"""One-shot channel mutable based channel mutator.
|
||||
|
||||
Args:
|
||||
mutable_cfg (dict): The config for the channel mutable.
|
||||
tracer_cfg (dict): The config for the model tracer. We trace the
|
||||
topology of a given model with the tracer.
|
||||
skip_prefixes (List[str], optional): Modules whose names start with
|
||||
a string in skip_prefixes will not be pruned.
|
||||
init_cfg (dict, optional): The config to control the initialization.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
mutable_cfg: Dict,
|
||||
tracer_cfg: Optional[Dict] = None,
|
||||
skip_prefixes: Optional[List[str]] = None,
|
||||
init_cfg: Optional[Dict] = None) -> None:
|
||||
super().__init__(mutable_cfg, tracer_cfg, skip_prefixes, init_cfg)
|
||||
|
||||
def sample_choices(self):
|
||||
"""Sample a choice that records a selection from the search space.
|
||||
|
||||
Returns:
|
||||
dict: Record the information to build the subnet from the supernet.
|
||||
Its keys are the properties ``group_idx`` in the channel
|
||||
mutator's ``search_groups``, and its values are the sampled
|
||||
choice.
|
||||
"""
|
||||
choice_dict = dict()
|
||||
for group_idx, mutables in self.search_groups.items():
|
||||
choice_dict[group_idx] = mutables[0].sample_choice()
|
||||
return choice_dict
|
||||
|
||||
def set_choices(self, choice_dict: Dict[int, Any]) -> None:
|
||||
"""Set current subnet according to ``choice_dict``.
|
||||
|
||||
Args:
|
||||
choice_dict (Dict[int, Any]): Choice dict.
|
||||
"""
|
||||
for group_idx, choice in choice_dict.items():
|
||||
mutables = self.search_groups[group_idx]
|
||||
for mutable in mutables:
|
||||
mutable.current_choice = choice
|
||||
|
||||
def set_max_choices(self) -> None:
|
||||
"""Set the channel numbers of each layer to maximum."""
|
||||
for mutables in self.search_groups.values():
|
||||
for mutable in mutables:
|
||||
mutable.current_choice = mutable.max_choice
|
||||
|
||||
def set_min_choices(self) -> None:
|
||||
"""Set the channel numbers of each layer to minimum."""
|
||||
for mutables in self.search_groups.values():
|
||||
for mutable in mutables:
|
||||
mutable.current_choice = mutable.min_choice
|
||||
|
||||
# TODO: check search groups
|
||||
def build_search_groups(self, supernet: Module):
|
||||
"""Build `search_groups`. The mutables in the same group should be
|
||||
pruned together.
|
||||
|
||||
Examples:
|
||||
>>> class ResBlock(nn.Module):
|
||||
... def __init__(self) -> None:
|
||||
... super().__init__()
|
||||
...
|
||||
... self.op1 = nn.Conv2d(3, 8, 1)
|
||||
... self.bn1 = nn.BatchNorm2d(8)
|
||||
... self.op2 = nn.Conv2d(8, 8, 1)
|
||||
... self.bn2 = nn.BatchNorm2d(8)
|
||||
... self.op3 = nn.Conv2d(8, 8, 1)
|
||||
...
|
||||
... def forward(self, x):
|
||||
... x1 = self.bn1(self.op1(x))
|
||||
... x2 = self.bn2(self.op2(x1))
|
||||
... x3 = self.op3(x2 + x1)
|
||||
... return x3
|
||||
|
||||
>>> class ToyPseudoLoss:
|
||||
...
|
||||
... def __call__(self, model):
|
||||
... pseudo_img = torch.rand(2, 3, 16, 16)
|
||||
... pseudo_output = model(pseudo_img)
|
||||
... return pseudo_output.sum()
|
||||
|
||||
>>> mutator = OneShotChannelMutator(
|
||||
... tracer_cfg=dict(type='BackwardTracer',
|
||||
... loss_calculator=ToyPseudoLoss()),
|
||||
... mutable_cfg=dict(type='RatioChannelMutable',
|
||||
... candidate_choices=[4 / 8, 1.0]))
|
||||
|
||||
>>> model = ResBlock()
|
||||
>>> mutator.prepare_from_supernet(model)
|
||||
>>> mutator.search_groups
|
||||
{0: [RatioChannelMutable(name=op2, mask_type=out_mask, ...),
|
||||
RatioChannelMutable(name=op1, mask_type=out_mask, ...),
|
||||
RatioChannelMutable(name=op3, mask_type=in_mask, ...),
|
||||
RatioChannelMutable(name=op2, mask_type=in_mask, ...),
|
||||
RatioChannelMutable(name=bn2, mask_type=out_mask, ...),
|
||||
RatioChannelMutable(name=bn1, mask_type=out_mask, ...)]}
|
||||
"""
|
||||
groups = self.find_same_mutables(supernet)
|
||||
|
||||
search_groups = dict()
|
||||
group_idx = 0
|
||||
for group in groups.values():
|
||||
is_skip = False
|
||||
for mutable in group:
|
||||
if self.is_skip_pruning(mutable.name, self.skip_prefixes):
|
||||
warnings.warn(f'Group {group} is not searchable due to'
|
||||
f' skip_prefixes: {self.skip_prefixes}')
|
||||
is_skip = True
|
||||
break
|
||||
if not is_skip:
|
||||
search_groups[group_idx] = group
|
||||
group_idx += 1
|
||||
|
||||
return search_groups
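# Sketch of a typical subnet sampling cycle, assuming `mutator` and `model`
# are constructed as in the Examples docstring above:
# mutator.prepare_from_supernet(model)
# choices = mutator.sample_choices()   # {group_idx: sampled channel choice}
# mutator.set_choices(choices)         # apply the choice to every mutable
# mutator.set_max_choices()            # or switch back to the widest subnet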
|
|
@ -0,0 +1,148 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn import Module
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmrazor.models.architectures.dynamic_op import (DynamicBatchNorm,
|
||||
build_switchable_bn)
|
||||
from mmrazor.models.mutables import SlimmableChannelMutable
|
||||
from mmrazor.registry import MODELS
|
||||
from .channel_mutator import ChannelMutator
|
||||
|
||||
NONPASS_MODULES = (nn.Conv2d, nn.Linear)
|
||||
PASS_MODULES = (_BatchNorm, )
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class SlimmableChannelMutator(ChannelMutator):
|
||||
"""Slimmable channel mutable based channel mutator.
|
||||
|
||||
Args:
|
||||
channel_cfgs (list[Dict]): A list of candidate channel configs.
|
||||
mutable_cfg (dict): The config for the channel mutable.
|
||||
skip_prefixes (List[str], optional): Modules whose names start with
|
||||
a string in skip_prefixes will not be pruned.
|
||||
init_cfg (dict, optional): The config to control the initialization.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
channel_cfgs: List[Dict],
|
||||
mutable_cfg: Dict,
|
||||
skip_prefixes: Optional[List[str]] = None,
|
||||
init_cfg: Optional[Dict] = None):
|
||||
super(SlimmableChannelMutator, self).__init__(
|
||||
mutable_cfg=mutable_cfg,
|
||||
skip_prefixes=skip_prefixes,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
self.channel_cfgs = self._merge_channel_cfgs(channel_cfgs)
|
||||
|
||||
def _merge_channel_cfgs(self, channel_cfgs: List[Dict]):
|
||||
"""Merge several channel configs.
|
||||
|
||||
Args:
|
||||
channel_cfgs (List[Dict]): A list of candidate channel configs, one per subnet.
|
||||
"""
|
||||
merged_channel_cfg = dict()
|
||||
num_subnet = len(channel_cfgs)
|
||||
|
||||
for module_name in channel_cfgs[0].keys():
|
||||
channels_per_layer = [
|
||||
channel_cfgs[idx][module_name] for idx in range(num_subnet)
|
||||
]
|
||||
merged_channels_per_layer = dict()
|
||||
for key in channels_per_layer[0].keys():
|
||||
merged_channels = [
|
||||
channels_per_layer[idx][key] for idx in range(num_subnet)
|
||||
]
|
||||
merged_channels_per_layer[key] = merged_channels
|
||||
merged_channel_cfg[module_name] = merged_channels_per_layer
|
||||
|
||||
return merged_channel_cfg
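# Sketch of the merge above with two subnets (values are illustrative):
# channel_cfgs = [
#     {'layer1': {'current_choice': 8,  'origin_channels': 16}},
#     {'layer1': {'current_choice': 16, 'origin_channels': 16}},
# ]
# _merge_channel_cfgs(channel_cfgs)
# -> {'layer1': {'current_choice': [8, 16], 'origin_channels': [16, 16]}}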
|
||||
|
||||
def prepare_from_supernet(self, supernet: Module) -> None:
|
||||
"""Do some necessary preparations with supernet.
|
||||
|
||||
Note:
|
||||
Different from `ChannelMutator`, we only support Case 1 in
|
||||
`ChannelMutator`. The input supernet should be made up of original
|
||||
nn.Module. And we replace the conv/linear/bn modules in the input
|
||||
supernet with dynamic ops first. Then we convert the
|
||||
``DynamicBatchNorm`` in supernet with ``SwitchableBatchNorm2d``.
|
||||
Finally, we set the candidate channel numbers to the corresponding
|
||||
`SlimmableChannelMutable`.
|
||||
|
||||
Args:
|
||||
supernet (:obj:`torch.nn.Module`): The supernet to be searched
|
||||
in your algorithm.
|
||||
"""
|
||||
self.convert_dynamic_module(supernet, self.dynamic_layer)
|
||||
|
||||
self.convert_switchable_bn(supernet)
|
||||
self.set_candidate_choices(supernet)
|
||||
# The mapping from a module name to the module
|
||||
self._name2module = dict(supernet.named_modules())
|
||||
|
||||
def set_candidate_choices(self, supernet):
|
||||
"""Set the ``candidate_choices`` of each ``SlimmableChannelMutable``.
|
||||
|
||||
Notes:
|
||||
Different from other ``OneShotChannelMutable``,
|
||||
``candidate_choices`` is optional when instantiating a
|
||||
``SlimmableChannelMutable``
|
||||
"""
|
||||
for name, module in supernet.named_modules():
|
||||
if isinstance(module, SlimmableChannelMutable):
|
||||
candidate_choices = self.channel_cfgs[name]['current_choice']
|
||||
module.candidate_choices = candidate_choices
|
||||
|
||||
def convert_switchable_bn(self, supernet):
|
||||
"""Replace ``DynamicBatchNorm`` in supernet with
|
||||
``SwitchableBatchNorm2d``.
|
||||
|
||||
Args:
|
||||
supernet (:obj:`torch.nn.Module`): The architecture to be converted
|
||||
in your algorithm.
|
||||
"""
|
||||
|
||||
def traverse(module, prefix):
|
||||
for name, child in module.named_children():
|
||||
module_name = prefix + name
|
||||
if isinstance(child, DynamicBatchNorm):
|
||||
mutable_cfg = copy.deepcopy(self.mutable_cfg)
|
||||
key = module_name + '.mutable_num_features'
|
||||
candidate_choices = self.channel_cfgs[key][
|
||||
'current_choice']
|
||||
mutable_cfg.update(
|
||||
dict(candidate_choices=candidate_choices))
|
||||
sbn = build_switchable_bn(child, module_name, mutable_cfg)
|
||||
setattr(module, name, sbn)
|
||||
else:
|
||||
traverse(child, module_name + '.')
|
||||
|
||||
traverse(supernet, '')
|
||||
|
||||
def switch_choices(self, idx):
|
||||
"""Switch the channel config of the supernet according to input `idx`.
|
||||
|
||||
If we train more than one subnet together, we need to switch the
|
||||
channel_cfg from one to another during one training iteration.
|
||||
|
||||
Args:
|
||||
idx (int): The index of the current subnet.
|
||||
"""
|
||||
for name, module in self.name2module.items():
|
||||
if hasattr(module, 'mutable_out'):
|
||||
module.mutable_out.current_choice = idx
|
||||
if hasattr(module, 'mutable_in'):
|
||||
module.mutable_in.current_choice = idx
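# Usage sketch: when several subnets are trained together, switch every layer
# to the subnet with index 1 before its forward pass (illustrative):
# mutator.switch_choices(1)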
|
||||
|
||||
def build_search_groups(self, supernet: Module):
|
||||
"""Build `search_groups`.
|
||||
|
||||
The mutables in the same group should be pruned together.
|
||||
"""
|
||||
pass
|
|
@ -45,12 +45,12 @@ class DiffMutator(ArchitectureMutator[DiffMutable]):
|
|||
group_id share the same arch param.
|
||||
|
||||
Returns:
|
||||
torch.nn.ParameterDict: the arch params are got by `search_group`.
|
||||
torch.nn.ParameterDict: the arch params are got by `search_groups`.
|
||||
"""
|
||||
|
||||
arch_params: Dict[int, nn.Parameter] = dict()
|
||||
|
||||
for group_id, modules in self.search_group.items():
|
||||
for group_id, modules in self.search_groups.items():
|
||||
group_arch_param = modules[0].build_arch_param()
|
||||
arch_params[group_id] = group_arch_param
|
||||
|
||||
|
@ -65,7 +65,7 @@ class DiffMutator(ArchitectureMutator[DiffMutable]):
|
|||
`DiffMutable`.
|
||||
"""
|
||||
|
||||
for group_id, modules in self.search_group.items():
|
||||
for group_id, modules in self.search_groups.items():
|
||||
for module in modules:
|
||||
module.set_forward_args(arch_param=self.arch_params[group_id])
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ class OneShotMutator(ArchitectureMutator[OneShotMutable]):
|
|||
['op1', 'op2', 'op3']
|
||||
|
||||
>>> mutator.prepare_from_supernet(supernet)
|
||||
>>> mutator.search_group
|
||||
>>> mutator.search_groups
|
||||
{0: [op1], 1: [op2], 2: [op3]}
|
||||
|
||||
>>> random_choices = mutator.sample_choices()
|
||||
|
@ -61,7 +61,7 @@ class OneShotMutator(ArchitectureMutator[OneShotMutable]):
|
|||
Dict[int, Any]: Random choices dict.
|
||||
"""
|
||||
random_choices = dict()
|
||||
for group_id, modules in self.search_group.items():
|
||||
for group_id, modules in self.search_groups.items():
|
||||
random_choices[group_id] = modules[0].sample_choice()
|
||||
|
||||
return random_choices
|
||||
|
@ -75,7 +75,7 @@ class OneShotMutator(ArchitectureMutator[OneShotMutable]):
|
|||
search groups, and the value is the sampling results
|
||||
corresponding to this group.
|
||||
"""
|
||||
for group_id, modules in self.search_group.items():
|
||||
for group_id, modules in self.search_groups.items():
|
||||
choice = choices[group_id]
|
||||
for module in modules:
|
||||
module.current_choice = choice
|
||||
|
|
|
@ -1,6 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .ratio_pruning import RatioPruner
|
||||
from .structure_pruning import StructurePruner
|
||||
from .utils import * # noqa: F401,F403
|
||||
|
||||
__all__ = ['RatioPruner', 'StructurePruner']
|
|
@ -1,150 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.modules import GroupNorm
|
||||
|
||||
from mmrazor.registry import MODELS
|
||||
from .structure_pruning import StructurePruner
|
||||
from .utils import SwitchableBatchNorm2d
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class RatioPruner(StructurePruner):
|
||||
"""A random ratio pruner.
|
||||
|
||||
Each layer can adjust its own width ratio randomly and independently.
|
||||
|
||||
Args:
|
||||
ratios (list | tuple): Width ratio of each layer can be
|
||||
chosen from `ratios` randomly. The width ratio is the ratio between
|
||||
the number of reserved channels and that of all channels in a
|
||||
layer. For example, if `ratios` is [0.25, 0.5], there are 2 cases
|
||||
for us to choose from when we sample from a layer with 12 channels.
|
||||
One is sampling the very first 3 channels in this layer, another is
|
||||
sampling the very first 6 channels in this layer. Default to None.
|
||||
"""
|
||||
|
||||
def __init__(self, ratios, **kwargs):
|
||||
super(RatioPruner, self).__init__(**kwargs)
|
||||
ratios = list(ratios)
|
||||
ratios.sort()
|
||||
self.ratios = ratios
|
||||
self.min_ratio = ratios[0]
|
||||
|
||||
def _check_pruner(self, supernet):
|
||||
for module in supernet.model.modules():
|
||||
if isinstance(module, GroupNorm):
|
||||
num_channels = module.num_channels
|
||||
num_groups = module.num_groups
|
||||
for ratio in self.ratios:
|
||||
new_channels = int(round(num_channels * ratio))
|
||||
assert (num_channels * ratio) % num_groups == 0, \
|
||||
f'Expected number of channels in input of GroupNorm ' \
|
||||
f'to be divisible by num_groups, but number of ' \
|
||||
f'channels may be {new_channels} according to ' \
|
||||
f'ratio {ratio} and num_groups={num_groups}'
|
||||
|
||||
def prepare_from_supernet(self, supernet):
|
||||
super(RatioPruner, self).prepare_from_supernet(supernet)
|
||||
|
||||
def get_channel_mask(self, out_mask):
|
||||
"""Randomly choose a width ratio of a layer from ``ratios``"""
|
||||
out_channels = out_mask.size(1)
|
||||
random_ratio = np.random.choice(self.ratios)
|
||||
new_channels = int(round(out_channels * random_ratio))
|
||||
assert new_channels > 0, \
|
||||
'Output channels should be a positive integer.'
|
||||
new_out_mask = torch.zeros_like(out_mask)
|
||||
new_out_mask[:, :new_channels] = 1
|
||||
|
||||
return new_out_mask
|
||||
|
||||
def sample_subnet(self):
|
||||
"""Random sample subnet by random mask.
|
||||
|
||||
Returns:
|
||||
dict: Record the information to build the subnet from the supernet,
|
||||
its keys are the properties ``space_id`` in the pruner's search
|
||||
spaces, and its values are corresponding sampled out_mask.
|
||||
"""
|
||||
subnet_dict = dict()
|
||||
for space_id, out_mask in self.channel_spaces.items():
|
||||
subnet_dict[space_id] = self.get_channel_mask(out_mask)
|
||||
return subnet_dict
|
||||
|
||||
def set_min_channel(self):
|
||||
"""Set the number of channels each layer to minimum."""
|
||||
subnet_dict = dict()
|
||||
for space_id, out_mask in self.channel_spaces.items():
|
||||
out_channels = out_mask.size(1)
|
||||
random_ratio = self.min_ratio
|
||||
new_channels = int(round(out_channels * random_ratio))
|
||||
assert new_channels > 0, \
|
||||
'Output channels should be a positive integer.'
|
||||
new_out_mask = torch.zeros_like(out_mask)
|
||||
new_out_mask[:, :new_channels] = 1
|
||||
|
||||
subnet_dict[space_id] = new_out_mask
|
||||
|
||||
self.set_subnet(subnet_dict)
|
||||
|
||||
def switch_subnet(self, channel_cfg, subnet_ind=None):
|
||||
"""Switch the channel config of the supernet according to channel_cfg.
|
||||
|
||||
If we train more than one subnet together, we need to switch the
|
||||
channel_cfg from one to another during one training iteration.
|
||||
|
||||
Args:
|
||||
channel_cfg (dict): The channel config of a subnet. Key is space_id
|
||||
and value is a dict which includes out_channels (and
|
||||
in_channels if exists).
|
||||
subnet_ind (int, optional): The index of the current subnet. If
|
||||
we replace normal BatchNorm2d with ``SwitchableBatchNorm2d``,
|
||||
we should switch the index of ``SwitchableBatchNorm2d`` when
|
||||
switch subnet. Defaults to None.
|
||||
"""
|
||||
subnet_dict = dict()
|
||||
for name, channels_per_layer in channel_cfg.items():
|
||||
module = self.name2module[name]
|
||||
if (isinstance(module, SwitchableBatchNorm2d)
|
||||
and subnet_ind is not None):
|
||||
# When switching bn we should switch index simultaneously
|
||||
module.index = subnet_ind
|
||||
continue
|
||||
|
||||
out_channels = channels_per_layer['out_channels']
|
||||
out_mask = torch.zeros_like(module.out_mask)
|
||||
out_mask[:, :out_channels] = 1
|
||||
|
||||
space_id = self.get_space_id(name)
|
||||
if space_id in subnet_dict:
|
||||
assert torch.equal(subnet_dict[space_id], out_mask)
|
||||
elif space_id is not None:
|
||||
subnet_dict[space_id] = out_mask
|
||||
|
||||
self.set_subnet(subnet_dict)
|
||||
|
||||
def convert_switchable_bn(self, module, num_bns):
|
||||
"""Convert normal ``nn.BatchNorm2d`` to ``SwitchableBatchNorm2d``.
|
||||
|
||||
Args:
|
||||
module (:obj:`torch.nn.Module`): The module to be converted.
|
||||
num_bns (int): The number of ``nn.BatchNorm2d`` in a
|
||||
``SwitchableBatchNorm2d``.
|
||||
|
||||
Return:
|
||||
:obj:`torch.nn.Module`: The converted module. Each
|
||||
``nn.BatchNorm2d`` in this module has been converted to a
|
||||
``SwitchableBatchNorm2d``.
|
||||
"""
|
||||
module_output = module
|
||||
if isinstance(module, nn.modules.batchnorm._BatchNorm):
|
||||
module_output = SwitchableBatchNorm2d(module.num_features, num_bns)
|
||||
|
||||
for name, child in module.named_children():
|
||||
module_output.add_module(
|
||||
name, self.convert_switchable_bn(child, num_bns))
|
||||
|
||||
del module
|
||||
return module_output
|
|
@ -1,4 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .switchable_bn import SwitchableBatchNorm2d
|
||||
|
||||
__all__ = ['SwitchableBatchNorm2d']
|
|
@ -1,38 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class SwitchableBatchNorm2d(nn.Module):
|
||||
"""Employs independent batch normalization for different switches in a
|
||||
slimmable network.
|
||||
|
||||
To train slimmable networks, ``SwitchableBatchNorm2d`` privatizes all
|
||||
batch normalization layers for each switch in a slimmable network.
|
||||
Compared with the naive training approach, it solves the problem of feature
|
||||
aggregation inconsistency between different switches by independently
|
||||
normalizing the feature mean and variance during testing.
|
||||
|
||||
Args:
|
||||
max_num_features (int): The maximum ``num_features`` among BatchNorm2d
|
||||
in all the switches.
|
||||
num_bns (int): The number of different switches in the slimmable
|
||||
networks.
|
||||
"""
|
||||
|
||||
def __init__(self, max_num_features, num_bns):
|
||||
super(SwitchableBatchNorm2d, self).__init__()
|
||||
|
||||
self.max_num_features = max_num_features
|
||||
# number of BatchNorm2d in a SwitchableBatchNorm2d
|
||||
self.num_bns = num_bns
|
||||
bns = []
|
||||
for _ in range(num_bns):
|
||||
bns.append(nn.BatchNorm2d(max_num_features))
|
||||
self.bns = nn.ModuleList(bns)
|
||||
# When switching bn we should switch index simultaneously
|
||||
self.index = 0
|
||||
|
||||
def forward(self, input):
|
||||
"""Forward computation according to the current switch of the slimmable
|
||||
networks."""
|
||||
return self.bns[self.index](input)
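# Sketch of the idea: one BatchNorm2d per switch, selected by `index`, so each
# subnet keeps its own running statistics (standalone, illustrative example):
import torch
import torch.nn as nn

bns = nn.ModuleList([nn.BatchNorm2d(8) for _ in range(2)])
index = 1                           # switch to the second subnet
x = torch.rand(2, 8, 4, 4)
y = bns[index](x)                   # only bns[1]'s running stats are updated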
|
|
@ -0,0 +1,474 @@
|
|||
backbone.conv1.bn.mutable_num_features:
|
||||
current_choice: 8
|
||||
origin_channels: 48
|
||||
backbone.conv1.conv.mutable_in_channels:
|
||||
current_choice: 3
|
||||
origin_channels: 3
|
||||
backbone.conv1.conv.mutable_out_channels:
|
||||
current_choice: 8
|
||||
origin_channels: 48
|
||||
backbone.conv2.bn.mutable_num_features:
|
||||
current_choice: 1920
|
||||
origin_channels: 1920
|
||||
backbone.conv2.conv.mutable_in_channels:
|
||||
current_choice: 280
|
||||
origin_channels: 480
|
||||
backbone.conv2.conv.mutable_out_channels:
|
||||
current_choice: 1920
|
||||
origin_channels: 1920
|
||||
backbone.layer1.0.conv.0.bn.mutable_num_features:
|
||||
current_choice: 8
|
||||
origin_channels: 48
|
||||
backbone.layer1.0.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 8
|
||||
origin_channels: 48
|
||||
backbone.layer1.0.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 8
|
||||
origin_channels: 48
|
||||
backbone.layer1.0.conv.1.bn.mutable_num_features:
|
||||
current_choice: 8
|
||||
origin_channels: 24
|
||||
backbone.layer1.0.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 8
|
||||
origin_channels: 48
|
||||
backbone.layer1.0.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 8
|
||||
origin_channels: 24
|
||||
backbone.layer2.0.conv.0.bn.mutable_num_features:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer2.0.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 8
|
||||
origin_channels: 24
|
||||
backbone.layer2.0.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer2.0.conv.1.bn.mutable_num_features:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer2.0.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer2.0.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer2.0.conv.2.bn.mutable_num_features:
|
||||
current_choice: 16
|
||||
origin_channels: 40
|
||||
backbone.layer2.0.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer2.0.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 16
|
||||
origin_channels: 40
|
||||
backbone.layer2.1.conv.0.bn.mutable_num_features:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer2.1.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 16
|
||||
origin_channels: 40
|
||||
backbone.layer2.1.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer2.1.conv.1.bn.mutable_num_features:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer2.1.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer2.1.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer2.1.conv.2.bn.mutable_num_features:
|
||||
current_choice: 16
|
||||
origin_channels: 40
|
||||
backbone.layer2.1.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer2.1.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 16
|
||||
origin_channels: 40
|
||||
backbone.layer3.0.conv.0.bn.mutable_num_features:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer3.0.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 16
|
||||
origin_channels: 40
|
||||
backbone.layer3.0.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer3.0.conv.1.bn.mutable_num_features:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer3.0.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer3.0.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer3.0.conv.2.bn.mutable_num_features:
|
||||
current_choice: 24
|
||||
origin_channels: 48
|
||||
backbone.layer3.0.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer3.0.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 24
|
||||
origin_channels: 48
|
||||
backbone.layer3.1.conv.0.bn.mutable_num_features:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.1.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 24
|
||||
origin_channels: 48
|
||||
backbone.layer3.1.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.1.conv.1.bn.mutable_num_features:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.1.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.1.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.1.conv.2.bn.mutable_num_features:
|
||||
current_choice: 24
|
||||
origin_channels: 48
|
||||
backbone.layer3.1.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.1.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 24
|
||||
origin_channels: 48
|
||||
backbone.layer3.2.conv.0.bn.mutable_num_features:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.2.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 24
|
||||
origin_channels: 48
|
||||
backbone.layer3.2.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.2.conv.1.bn.mutable_num_features:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.2.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.2.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.2.conv.2.bn.mutable_num_features:
|
||||
current_choice: 24
|
||||
origin_channels: 48
|
||||
backbone.layer3.2.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.2.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 24
|
||||
origin_channels: 48
|
||||
backbone.layer4.0.conv.0.bn.mutable_num_features:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer4.0.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 24
|
||||
origin_channels: 48
|
||||
backbone.layer4.0.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer4.0.conv.1.bn.mutable_num_features:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer4.0.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer4.0.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer4.0.conv.2.bn.mutable_num_features:
|
||||
current_choice: 48
|
||||
origin_channels: 96
|
||||
backbone.layer4.0.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer4.0.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 48
|
||||
origin_channels: 96
|
||||
backbone.layer4.1.conv.0.bn.mutable_num_features:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.1.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 48
|
||||
origin_channels: 96
|
||||
backbone.layer4.1.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.1.conv.1.bn.mutable_num_features:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.1.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.1.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.1.conv.2.bn.mutable_num_features:
|
||||
current_choice: 48
|
||||
origin_channels: 96
|
||||
backbone.layer4.1.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.1.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 48
|
||||
origin_channels: 96
|
||||
backbone.layer4.2.conv.0.bn.mutable_num_features:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.2.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 48
|
||||
origin_channels: 96
|
||||
backbone.layer4.2.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.2.conv.1.bn.mutable_num_features:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.2.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.2.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.2.conv.2.bn.mutable_num_features:
|
||||
current_choice: 48
|
||||
origin_channels: 96
|
||||
backbone.layer4.2.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.2.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 48
|
||||
origin_channels: 96
|
||||
backbone.layer4.3.conv.0.bn.mutable_num_features:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.3.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 48
|
||||
origin_channels: 96
|
||||
backbone.layer4.3.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.3.conv.1.bn.mutable_num_features:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.3.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.3.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.3.conv.2.bn.mutable_num_features:
|
||||
current_choice: 48
|
||||
origin_channels: 96
|
||||
backbone.layer4.3.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.3.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 48
|
||||
origin_channels: 96
|
||||
backbone.layer5.0.conv.0.bn.mutable_num_features:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer5.0.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 48
|
||||
origin_channels: 96
|
||||
backbone.layer5.0.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer5.0.conv.1.bn.mutable_num_features:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer5.0.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer5.0.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer5.0.conv.2.bn.mutable_num_features:
|
||||
current_choice: 64
|
||||
origin_channels: 144
|
||||
backbone.layer5.0.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer5.0.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 64
|
||||
origin_channels: 144
|
||||
backbone.layer5.1.conv.0.bn.mutable_num_features:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.1.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 64
|
||||
origin_channels: 144
|
||||
backbone.layer5.1.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.1.conv.1.bn.mutable_num_features:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.1.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.1.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.1.conv.2.bn.mutable_num_features:
|
||||
current_choice: 64
|
||||
origin_channels: 144
|
||||
backbone.layer5.1.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.1.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 64
|
||||
origin_channels: 144
|
||||
backbone.layer5.2.conv.0.bn.mutable_num_features:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.2.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 64
|
||||
origin_channels: 144
|
||||
backbone.layer5.2.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.2.conv.1.bn.mutable_num_features:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.2.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.2.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.2.conv.2.bn.mutable_num_features:
|
||||
current_choice: 64
|
||||
origin_channels: 144
|
||||
backbone.layer5.2.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.2.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 64
|
||||
origin_channels: 144
|
||||
backbone.layer6.0.conv.0.bn.mutable_num_features:
|
||||
current_choice: 648
|
||||
origin_channels: 864
|
||||
backbone.layer6.0.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 64
|
||||
origin_channels: 144
|
||||
backbone.layer6.0.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 648
|
||||
origin_channels: 864
|
||||
backbone.layer6.0.conv.1.bn.mutable_num_features:
|
||||
current_choice: 648
|
||||
origin_channels: 864
|
||||
backbone.layer6.0.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 648
|
||||
origin_channels: 864
|
||||
backbone.layer6.0.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 648
|
||||
origin_channels: 864
|
||||
backbone.layer6.0.conv.2.bn.mutable_num_features:
|
||||
current_choice: 176
|
||||
origin_channels: 240
|
||||
backbone.layer6.0.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 648
|
||||
origin_channels: 864
|
||||
backbone.layer6.0.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 176
|
||||
origin_channels: 240
|
||||
backbone.layer6.1.conv.0.bn.mutable_num_features:
|
||||
current_choice: 720
|
||||
origin_channels: 1440
|
||||
backbone.layer6.1.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 176
|
||||
origin_channels: 240
|
||||
backbone.layer6.1.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 720
|
||||
origin_channels: 1440
|
||||
backbone.layer6.1.conv.1.bn.mutable_num_features:
|
||||
current_choice: 720
|
||||
origin_channels: 1440
|
||||
backbone.layer6.1.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 720
|
||||
origin_channels: 1440
|
||||
backbone.layer6.1.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 720
|
||||
origin_channels: 1440
|
||||
backbone.layer6.1.conv.2.bn.mutable_num_features:
|
||||
current_choice: 176
|
||||
origin_channels: 240
|
||||
backbone.layer6.1.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 720
|
||||
origin_channels: 1440
|
||||
backbone.layer6.1.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 176
|
||||
origin_channels: 240
|
||||
backbone.layer6.2.conv.0.bn.mutable_num_features:
|
||||
current_choice: 720
|
||||
origin_channels: 1440
|
||||
backbone.layer6.2.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 176
|
||||
origin_channels: 240
|
||||
backbone.layer6.2.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 720
|
||||
origin_channels: 1440
|
||||
backbone.layer6.2.conv.1.bn.mutable_num_features:
|
||||
current_choice: 720
|
||||
origin_channels: 1440
|
||||
backbone.layer6.2.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 720
|
||||
origin_channels: 1440
|
||||
backbone.layer6.2.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 720
|
||||
origin_channels: 1440
|
||||
backbone.layer6.2.conv.2.bn.mutable_num_features:
|
||||
current_choice: 176
|
||||
origin_channels: 240
|
||||
backbone.layer6.2.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 720
|
||||
origin_channels: 1440
|
||||
backbone.layer6.2.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 176
|
||||
origin_channels: 240
|
||||
backbone.layer7.0.conv.0.bn.mutable_num_features:
|
||||
current_choice: 1440
|
||||
origin_channels: 1440
|
||||
backbone.layer7.0.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 176
|
||||
origin_channels: 240
|
||||
backbone.layer7.0.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 1440
|
||||
origin_channels: 1440
|
||||
backbone.layer7.0.conv.1.bn.mutable_num_features:
|
||||
current_choice: 1440
|
||||
origin_channels: 1440
|
||||
backbone.layer7.0.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 1440
|
||||
origin_channels: 1440
|
||||
backbone.layer7.0.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 1440
|
||||
origin_channels: 1440
|
||||
backbone.layer7.0.conv.2.bn.mutable_num_features:
|
||||
current_choice: 280
|
||||
origin_channels: 480
|
||||
backbone.layer7.0.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 1440
|
||||
origin_channels: 1440
|
||||
backbone.layer7.0.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 280
|
||||
origin_channels: 480
|
||||
head.fc.mutable_in_features:
|
||||
current_choice: 1920
|
||||
origin_channels: 1920
|
||||
head.fc.mutable_out_features:
|
||||
current_choice: 1000
|
||||
origin_channels: 1000
|
|
@ -0,0 +1,474 @@
|
|||
backbone.conv1.bn.mutable_num_features:
|
||||
current_choice: 8
|
||||
origin_channels: 48
|
||||
backbone.conv1.conv.mutable_in_channels:
|
||||
current_choice: 3
|
||||
origin_channels: 3
|
||||
backbone.conv1.conv.mutable_out_channels:
|
||||
current_choice: 8
|
||||
origin_channels: 48
|
||||
backbone.conv2.bn.mutable_num_features:
|
||||
current_choice: 1920
|
||||
origin_channels: 1920
|
||||
backbone.conv2.conv.mutable_in_channels:
|
||||
current_choice: 480
|
||||
origin_channels: 480
|
||||
backbone.conv2.conv.mutable_out_channels:
|
||||
current_choice: 1920
|
||||
origin_channels: 1920
|
||||
backbone.layer1.0.conv.0.bn.mutable_num_features:
|
||||
current_choice: 8
|
||||
origin_channels: 48
|
||||
backbone.layer1.0.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 8
|
||||
origin_channels: 48
|
||||
backbone.layer1.0.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 8
|
||||
origin_channels: 48
|
||||
backbone.layer1.0.conv.1.bn.mutable_num_features:
|
||||
current_choice: 8
|
||||
origin_channels: 24
|
||||
backbone.layer1.0.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 8
|
||||
origin_channels: 48
|
||||
backbone.layer1.0.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 8
|
||||
origin_channels: 24
|
||||
backbone.layer2.0.conv.0.bn.mutable_num_features:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer2.0.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 8
|
||||
origin_channels: 24
|
||||
backbone.layer2.0.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer2.0.conv.1.bn.mutable_num_features:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer2.0.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer2.0.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer2.0.conv.2.bn.mutable_num_features:
|
||||
current_choice: 16
|
||||
origin_channels: 40
|
||||
backbone.layer2.0.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer2.0.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 16
|
||||
origin_channels: 40
|
||||
backbone.layer2.1.conv.0.bn.mutable_num_features:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer2.1.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 16
|
||||
origin_channels: 40
|
||||
backbone.layer2.1.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer2.1.conv.1.bn.mutable_num_features:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer2.1.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer2.1.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer2.1.conv.2.bn.mutable_num_features:
|
||||
current_choice: 16
|
||||
origin_channels: 40
|
||||
backbone.layer2.1.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer2.1.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 16
|
||||
origin_channels: 40
|
||||
backbone.layer3.0.conv.0.bn.mutable_num_features:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer3.0.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 16
|
||||
origin_channels: 40
|
||||
backbone.layer3.0.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer3.0.conv.1.bn.mutable_num_features:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer3.0.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer3.0.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer3.0.conv.2.bn.mutable_num_features:
|
||||
current_choice: 24
|
||||
origin_channels: 48
|
||||
backbone.layer3.0.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 240
|
||||
backbone.layer3.0.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 24
|
||||
origin_channels: 48
|
||||
backbone.layer3.1.conv.0.bn.mutable_num_features:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.1.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 24
|
||||
origin_channels: 48
|
||||
backbone.layer3.1.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.1.conv.1.bn.mutable_num_features:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.1.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.1.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.1.conv.2.bn.mutable_num_features:
|
||||
current_choice: 24
|
||||
origin_channels: 48
|
||||
backbone.layer3.1.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.1.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 24
|
||||
origin_channels: 48
|
||||
backbone.layer3.2.conv.0.bn.mutable_num_features:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.2.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 24
|
||||
origin_channels: 48
|
||||
backbone.layer3.2.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.2.conv.1.bn.mutable_num_features:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.2.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.2.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.2.conv.2.bn.mutable_num_features:
|
||||
current_choice: 24
|
||||
origin_channels: 48
|
||||
backbone.layer3.2.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer3.2.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 24
|
||||
origin_channels: 48
|
||||
backbone.layer4.0.conv.0.bn.mutable_num_features:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer4.0.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 24
|
||||
origin_channels: 48
|
||||
backbone.layer4.0.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer4.0.conv.1.bn.mutable_num_features:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer4.0.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer4.0.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer4.0.conv.2.bn.mutable_num_features:
|
||||
current_choice: 56
|
||||
origin_channels: 96
|
||||
backbone.layer4.0.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 144
|
||||
origin_channels: 288
|
||||
backbone.layer4.0.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 56
|
||||
origin_channels: 96
|
||||
backbone.layer4.1.conv.0.bn.mutable_num_features:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.1.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 56
|
||||
origin_channels: 96
|
||||
backbone.layer4.1.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.1.conv.1.bn.mutable_num_features:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.1.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.1.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.1.conv.2.bn.mutable_num_features:
|
||||
current_choice: 56
|
||||
origin_channels: 96
|
||||
backbone.layer4.1.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.1.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 56
|
||||
origin_channels: 96
|
||||
backbone.layer4.2.conv.0.bn.mutable_num_features:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.2.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 56
|
||||
origin_channels: 96
|
||||
backbone.layer4.2.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.2.conv.1.bn.mutable_num_features:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.2.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.2.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.2.conv.2.bn.mutable_num_features:
|
||||
current_choice: 56
|
||||
origin_channels: 96
|
||||
backbone.layer4.2.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.2.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 56
|
||||
origin_channels: 96
|
||||
backbone.layer4.3.conv.0.bn.mutable_num_features:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.3.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 56
|
||||
origin_channels: 96
|
||||
backbone.layer4.3.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.3.conv.1.bn.mutable_num_features:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.3.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.3.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.3.conv.2.bn.mutable_num_features:
|
||||
current_choice: 56
|
||||
origin_channels: 96
|
||||
backbone.layer4.3.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer4.3.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 56
|
||||
origin_channels: 96
|
||||
backbone.layer5.0.conv.0.bn.mutable_num_features:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer5.0.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 56
|
||||
origin_channels: 96
|
||||
backbone.layer5.0.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer5.0.conv.1.bn.mutable_num_features:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer5.0.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer5.0.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer5.0.conv.2.bn.mutable_num_features:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer5.0.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 288
|
||||
origin_channels: 576
|
||||
backbone.layer5.0.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer5.1.conv.0.bn.mutable_num_features:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.1.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer5.1.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.1.conv.1.bn.mutable_num_features:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.1.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.1.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.1.conv.2.bn.mutable_num_features:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer5.1.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.1.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer5.2.conv.0.bn.mutable_num_features:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.2.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer5.2.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.2.conv.1.bn.mutable_num_features:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.2.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.2.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.2.conv.2.bn.mutable_num_features:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer5.2.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 432
|
||||
origin_channels: 864
|
||||
backbone.layer5.2.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer6.0.conv.0.bn.mutable_num_features:
|
||||
current_choice: 864
|
||||
origin_channels: 864
|
||||
backbone.layer6.0.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 96
|
||||
origin_channels: 144
|
||||
backbone.layer6.0.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 864
|
||||
origin_channels: 864
|
||||
backbone.layer6.0.conv.1.bn.mutable_num_features:
|
||||
current_choice: 864
|
||||
origin_channels: 864
|
||||
backbone.layer6.0.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 864
|
||||
origin_channels: 864
|
||||
backbone.layer6.0.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 864
|
||||
origin_channels: 864
|
||||
backbone.layer6.0.conv.2.bn.mutable_num_features:
|
||||
current_choice: 240
|
||||
origin_channels: 240
|
||||
backbone.layer6.0.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 864
|
||||
origin_channels: 864
|
||||
backbone.layer6.0.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 240
|
||||
origin_channels: 240
|
||||
backbone.layer6.1.conv.0.bn.mutable_num_features:
|
||||
current_choice: 1440
|
||||
origin_channels: 1440
|
||||
backbone.layer6.1.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 240
|
||||
origin_channels: 240
|
||||
backbone.layer6.1.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 1440
|
||||
origin_channels: 1440
|
||||
backbone.layer6.1.conv.1.bn.mutable_num_features:
|
||||
current_choice: 1440
|
||||
origin_channels: 1440
|
||||
backbone.layer6.1.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 1440
|
||||
origin_channels: 1440
|
||||
backbone.layer6.1.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 1440
|
||||
origin_channels: 1440
|
||||
backbone.layer6.1.conv.2.bn.mutable_num_features:
|
||||
current_choice: 240
|
||||
origin_channels: 240
|
||||
backbone.layer6.1.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 1440
|
||||
origin_channels: 1440
|
||||
backbone.layer6.1.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 240
|
||||
origin_channels: 240
|
||||
backbone.layer6.2.conv.0.bn.mutable_num_features:
|
||||
current_choice: 960
|
||||
origin_channels: 1440
|
||||
backbone.layer6.2.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 240
|
||||
origin_channels: 240
|
||||
backbone.layer6.2.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 960
|
||||
origin_channels: 1440
|
||||
backbone.layer6.2.conv.1.bn.mutable_num_features:
|
||||
current_choice: 960
|
||||
origin_channels: 1440
|
||||
backbone.layer6.2.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 960
|
||||
origin_channels: 1440
|
||||
backbone.layer6.2.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 960
|
||||
origin_channels: 1440
|
||||
backbone.layer6.2.conv.2.bn.mutable_num_features:
|
||||
current_choice: 240
|
||||
origin_channels: 240
|
||||
backbone.layer6.2.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 960
|
||||
origin_channels: 1440
|
||||
backbone.layer6.2.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 240
|
||||
origin_channels: 240
|
||||
backbone.layer7.0.conv.0.bn.mutable_num_features:
|
||||
current_choice: 1440
|
||||
origin_channels: 1440
|
||||
backbone.layer7.0.conv.0.conv.mutable_in_channels:
|
||||
current_choice: 240
|
||||
origin_channels: 240
|
||||
backbone.layer7.0.conv.0.conv.mutable_out_channels:
|
||||
current_choice: 1440
|
||||
origin_channels: 1440
|
||||
backbone.layer7.0.conv.1.bn.mutable_num_features:
|
||||
current_choice: 1440
|
||||
origin_channels: 1440
|
||||
backbone.layer7.0.conv.1.conv.mutable_in_channels:
|
||||
current_choice: 1440
|
||||
origin_channels: 1440
|
||||
backbone.layer7.0.conv.1.conv.mutable_out_channels:
|
||||
current_choice: 1440
|
||||
origin_channels: 1440
|
||||
backbone.layer7.0.conv.2.bn.mutable_num_features:
|
||||
current_choice: 480
|
||||
origin_channels: 480
|
||||
backbone.layer7.0.conv.2.conv.mutable_in_channels:
|
||||
current_choice: 1440
|
||||
origin_channels: 1440
|
||||
backbone.layer7.0.conv.2.conv.mutable_out_channels:
|
||||
current_choice: 480
|
||||
origin_channels: 480
|
||||
head.fc.mutable_in_features:
|
||||
current_choice: 1920
|
||||
origin_channels: 1920
|
||||
head.fc.mutable_out_features:
|
||||
current_choice: 1000
|
||||
origin_channels: 1000
|
|
@ -0,0 +1,24 @@
op1.mutable_in_channels:
  current_choice: 3
  origin_channels: 3
op1.mutable_out_channels:
  current_choice: 4
  origin_channels: 8
bn1.mutable_num_features:
  current_choice: 4
  origin_channels: 8
op2.mutable_in_channels:
  current_choice: 4
  origin_channels: 8
op2.mutable_out_channels:
  current_choice: 4
  origin_channels: 8
bn2.mutable_num_features:
  current_choice: 4
  origin_channels: 8
op3.mutable_in_channels:
  current_choice: 4
  origin_channels: 8
op3.mutable_out_channels:
  current_choice: 8
  origin_channels: 8

@ -0,0 +1,24 @@
op1.mutable_in_channels:
  current_choice: 3
  origin_channels: 3
op1.mutable_out_channels:
  current_choice: 8
  origin_channels: 8
bn1.mutable_num_features:
  current_choice: 8
  origin_channels: 8
op2.mutable_in_channels:
  current_choice: 8
  origin_channels: 8
op2.mutable_out_channels:
  current_choice: 8
  origin_channels: 8
bn2.mutable_num_features:
  current_choice: 8
  origin_channels: 8
op3.mutable_in_channels:
  current_choice: 8
  origin_channels: 8
op3.mutable_out_channels:
  current_choice: 8
  origin_channels: 8

@ -0,0 +1,260 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import pytest
import torch
from torch import Tensor, nn
from torch.nn import Module

from mmrazor.core import (BackwardTracer, ConcatNode, ConvNode,
                          DepthWiseConvNode, LinearNode, NormNode, Path,
                          PathList)

NONPASS_NODES = (ConvNode, LinearNode, ConcatNode)
PASS_NODES = (NormNode, DepthWiseConvNode)


class MultiConcatModel(Module):

    def __init__(self) -> None:
        super().__init__()

        self.op1 = nn.Conv2d(3, 8, 1)
        self.op2 = nn.Conv2d(3, 8, 1)
        self.op3 = nn.Conv2d(16, 8, 1)
        self.op4 = nn.Conv2d(3, 8, 1)

    def forward(self, x: Tensor) -> Tensor:
        x1 = self.op1(x)
        x2 = self.op2(x)
        cat1 = torch.cat([x1, x2], dim=1)
        x3 = self.op3(cat1)
        x4 = self.op4(x)
        output = torch.cat([x3, x4], dim=1)

        return output


class MultiConcatModel2(Module):

    def __init__(self) -> None:
        super().__init__()

        self.op1 = nn.Conv2d(3, 8, 1)
        self.op2 = nn.Conv2d(3, 8, 1)
        self.op3 = nn.Conv2d(3, 8, 1)
        self.op4 = nn.Conv2d(24, 8, 1)

    def forward(self, x: Tensor) -> Tensor:
        x1 = self.op1(x)
        x2 = self.op2(x)
        x3 = self.op3(x)
        cat1 = torch.cat([x1, x2], dim=1)
        cat2 = torch.cat([cat1, x3], dim=1)
        output = self.op4(cat2)

        return output


class ResBlock(Module):

    def __init__(self) -> None:
        super().__init__()

        self.op1 = nn.Conv2d(3, 8, 1)
        self.bn1 = nn.BatchNorm2d(8)
        self.op2 = nn.Conv2d(8, 8, 1)
        self.bn2 = nn.BatchNorm2d(8)
        self.op3 = nn.Conv2d(8, 8, 1)

    def forward(self, x: Tensor) -> Tensor:
        x1 = self.bn1(self.op1(x))
        x2 = self.bn2(self.op2(x1))
        x3 = self.op3(x2 + x1)
        return x3


class ToyCNNPseudoLoss:

    def __call__(self, model):
        pseudo_img = torch.rand(2, 3, 16, 16)
        pseudo_output = model(pseudo_img)
        return pseudo_output.sum()


class TestBackwardTracer(TestCase):

    def test_trace_resblock(self) -> None:
        model = ResBlock()
        loss_calculator = ToyCNNPseudoLoss()
        tracer = BackwardTracer(loss_calculator=loss_calculator)
        path_list = tracer.trace(model)

        # test tracer and parser
        assert len(path_list) == 2
        assert len(path_list[0]) == 5

        # test path_list
        nonpass2parents = path_list.find_nodes_parents(NONPASS_NODES)
        assert len(nonpass2parents) == 3
        assert nonpass2parents['op1'] == list()
        assert nonpass2parents['op2'] == list({NormNode('bn1')})
        assert nonpass2parents['op3'] == list(
            {NormNode('bn2'), NormNode('bn1')})

        nonpass2nonpassparents = path_list.find_nodes_parents(
            NONPASS_NODES, non_pass=NONPASS_NODES)
        assert len(nonpass2parents) == 3
        assert nonpass2nonpassparents['op1'] == list()
        assert nonpass2nonpassparents['op2'] == list({ConvNode('op1')})
        assert nonpass2nonpassparents['op3'] == list(
            {ConvNode('op2'), ConvNode('op1')})

        pass2nonpassparents = path_list.find_nodes_parents(
            PASS_NODES, non_pass=NONPASS_NODES)
        assert len(pass2nonpassparents) == 2
        assert pass2nonpassparents['bn1'] == list({ConvNode('op1')})
        assert pass2nonpassparents['bn2'] == list({ConvNode('op2')})

    def test_trace_multi_cat(self) -> None:
        loss_calculator = ToyCNNPseudoLoss()

        model = MultiConcatModel()
        tracer = BackwardTracer(loss_calculator=loss_calculator)
        path_list = tracer.trace(model)

        assert len(path_list) == 1

        nonpass2parents = path_list.find_nodes_parents(NONPASS_NODES)
        assert len(nonpass2parents) == 4
        assert nonpass2parents['op1'] == list()
        assert nonpass2parents['op2'] == list()
        path_list1 = PathList(Path(ConvNode('op1')))
        path_list2 = PathList(Path(ConvNode('op2')))
        # only one parent
        assert len(nonpass2parents['op3']) == 1
        assert isinstance(nonpass2parents['op3'][0], ConcatNode)
        assert len(nonpass2parents['op3'][0]) == 2
        assert nonpass2parents['op3'][0].get_module_names() == ['op1', 'op2']
        assert nonpass2parents['op3'][0].path_lists == [path_list1, path_list2]
        assert nonpass2parents['op3'][0][0] == path_list1
        assert nonpass2parents['op4'] == list()

        model = MultiConcatModel2()
        tracer = BackwardTracer(loss_calculator=loss_calculator)
        path_list = tracer.trace(model)
        assert len(path_list) == 1

        nonpass2parents = path_list.find_nodes_parents(NONPASS_NODES)
        assert len(nonpass2parents) == 4
        assert nonpass2parents['op1'] == list()
        assert nonpass2parents['op2'] == list()
        assert nonpass2parents['op3'] == list()
        # only one parent
        assert len(nonpass2parents['op4']) == 1
        assert isinstance(nonpass2parents['op4'][0], ConcatNode)
        assert nonpass2parents['op4'][0].get_module_names() == [
            'op1', 'op2', 'op3'
        ]

    def test_repr(self):
        toy_node = ConvNode('op1')
        assert repr(toy_node) == 'ConvNode(\'op1\')'

        toy_path = Path([ConvNode('op1'), ConvNode('op2')])
        assert repr(
            toy_path) == 'Path(\n ConvNode(\'op1\'),\n ConvNode(\'op2\')\n)'

        toy_path_list = PathList(Path(ConvNode('op1')))
        assert repr(toy_path_list
                    ) == 'PathList(\n Path(\n ConvNode(\'op1\')\n )\n)'

        path_list1 = PathList(Path(ConvNode('op1')))
        path_list2 = PathList(Path(ConvNode('op2')))
        toy_concat_node = ConcatNode('op3', [path_list1, path_list2])
        assert repr(
            toy_concat_node
        ) == 'ConcatNode(\n PathList(\n Path(\n ConvNode(\'op1\')\n )\n ),\n PathList(\n Path(\n ConvNode(\'op2\')\n )\n )\n)'  # noqa: E501

    def test_reset_bn_running_stats(self):
        _test_reset_bn_running_stats(False)
        with pytest.raises(AssertionError):
            _test_reset_bn_running_stats(True)

    def test_node(self):
        node1 = ConvNode('conv1')
        node2 = ConvNode('conv2')
        assert node1 != node2

        node1 = ConvNode('conv1')
        node2 = ConvNode('conv1')
        assert node1 == node2

    def test_path(self):
        node1 = ConvNode('conv1')
        node2 = ConvNode('conv2')

        path1 = Path([node1])
        path2 = Path([node2])
        assert path1 != path2

        path1 = Path([node1])
        path2 = Path([node1])
        assert path1 == path2

        assert path1[0] == node1

    def test_path_list(self):
        node1 = ConvNode('conv1')
        node2 = ConvNode('conv2')

        path1 = Path([node1])
        path2 = Path([node2])
        assert PathList(path1) == PathList([path1])
        assert PathList(path1) != PathList(path2)

        with self.assertRaisesRegex(AssertionError, ''):
            _ = PathList({})


def _test_reset_bn_running_stats(should_fail):
    import os
    import random

    import numpy as np

    def set_seed(seed: int) -> None:
        random.seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    set_seed(1024)
    imgs = torch.randn(2, 3, 4, 4)
    loss_calculator = ToyCNNPseudoLoss()
    tracer = BackwardTracer(loss_calculator=loss_calculator)
    if should_fail:
        tracer._reset_norm_running_stats = lambda *_: None

    torch_rng_state = torch.get_rng_state()
    np_rng_state = np.random.get_state()
    random_rng_state = random.getstate()

    model1 = ResBlock()
    set_seed(1)
    tracer.trace(model1)
    model1.eval()
    output1 = model1(imgs)

    set_seed(1024)
    torch.set_rng_state(torch_rng_state)
    np.random.set_state(np_rng_state)
    random.setstate(random_rng_state)

    model2 = ResBlock()
    set_seed(2)
    tracer.trace(model2)
    model2.eval()
    output2 = model2(imgs)

    assert torch.equal(output1.norm(p='fro'), output2.norm(p='fro'))

@ -0,0 +1,164 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from unittest import TestCase
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmrazor.models import OrderChannelMutable, RatioChannelMutable
|
||||
|
||||
|
||||
class TestChannelMutables(TestCase):
|
||||
|
||||
def test_ratio_channel_mutable(self):
|
||||
with pytest.raises(AssertionError):
|
||||
# Test invalid `mask_type`
|
||||
RatioChannelMutable(
|
||||
name='op',
|
||||
mask_type='xxx',
|
||||
num_channels=8,
|
||||
candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0])
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Number of candidate choices must be greater than 0
|
||||
RatioChannelMutable(
|
||||
name='op',
|
||||
mask_type='out_mask',
|
||||
num_channels=8,
|
||||
candidate_choices=list())
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# The candidate ratio should be in range(0, 1].
|
||||
RatioChannelMutable(
|
||||
name='op',
|
||||
mask_type='out_mask',
|
||||
num_channels=8,
|
||||
candidate_choices=[0., 1 / 4, 2 / 4, 3 / 4, 1.0])
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Minimum number of channels should be a positive integer.
|
||||
out_mutable = RatioChannelMutable(
|
||||
name='op',
|
||||
mask_type='out_mask',
|
||||
num_channels=8,
|
||||
candidate_choices=[0.01, 1 / 4, 2 / 4, 3 / 4, 1.0])
|
||||
_ = out_mutable.min_choice
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Minimum number of channels should be a positive integer.
|
||||
out_mutable = RatioChannelMutable(
|
||||
name='op',
|
||||
mask_type='out_mask',
|
||||
num_channels=8,
|
||||
candidate_choices=[0.01, 1 / 4, 2 / 4, 3 / 4, 1.0])
|
||||
out_mutable.get_choice(0)
|
||||
|
||||
# Test out_mutable (mask_type == 'out_mask')
|
||||
out_mutable = RatioChannelMutable(
|
||||
name='op',
|
||||
mask_type='out_mask',
|
||||
num_channels=8,
|
||||
candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0])
|
||||
|
||||
random_choice = out_mutable.sample_choice()
|
||||
assert random_choice in [2, 4, 6, 8]
|
||||
|
||||
choice = out_mutable.get_choice(0)
|
||||
assert choice == 2
|
||||
|
||||
max_choice = out_mutable.max_choice
|
||||
assert max_choice == 8
|
||||
out_mutable.current_choice = max_choice
|
||||
assert torch.equal(out_mutable.mask,
|
||||
torch.ones_like(out_mutable.mask).bool())
|
||||
|
||||
min_choice = out_mutable.min_choice
|
||||
assert min_choice == 2
|
||||
out_mutable.current_choice = min_choice
|
||||
min_mask = torch.zeros_like(out_mutable.mask).bool()
|
||||
min_mask[:2] = True
|
||||
assert torch.equal(out_mutable.mask, min_mask)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Only mutables with mask_type == 'in_mask' (named in_mutable) can
|
||||
# add `concat_mutables`
|
||||
concat_mutables = [copy.deepcopy(out_mutable)] * 2
|
||||
out_mutable.register_same_mutable(concat_mutables)
|
||||
|
||||
# Test in_mutable (mask_type == 'in_mask') with concat_mutable
|
||||
in_mutable = RatioChannelMutable(
|
||||
name='op',
|
||||
mask_type='in_mask',
|
||||
num_channels=16,
|
||||
candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0])
|
||||
out_mutable1 = copy.deepcopy(out_mutable)
|
||||
out_mutable2 = copy.deepcopy(out_mutable)
|
||||
in_mutable.register_same_mutable([out_mutable1, out_mutable2])
|
||||
choice1 = out_mutable1.sample_choice()
|
||||
out_mutable1.current_choice = choice1
|
||||
choice2 = out_mutable2.sample_choice()
|
||||
out_mutable2.current_choice = choice2
|
||||
assert in_mutable.current_choice == choice1 + choice2
|
||||
assert torch.equal(in_mutable.mask,
|
||||
torch.cat([out_mutable1.mask, out_mutable2.mask]))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# The mask of this in_mutable depends on the out mask of its
|
||||
# `concat_mutables`, so the `sample_choice` method should not
|
||||
# be called
|
||||
in_mutable.sample_choice()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# The mask of this in_mutable depends on the out mask of its
|
||||
# `concat_mutables`, so the `min_choice` property should not
|
||||
# be called
|
||||
_ = in_mutable.min_choice
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# The mask of this in_mutable depends on the out mask of its
|
||||
# `concat_mutables`, so the `get_choice` method should not
|
||||
# be called
|
||||
in_mutable.get_choice(0)
|
||||
|
||||
def test_order_channel_mutable(self):
|
||||
with pytest.raises(AssertionError):
|
||||
# The candidate ratio should be in range(0, `num_channels`].
|
||||
OrderChannelMutable(
|
||||
name='op',
|
||||
mask_type='out_mask',
|
||||
num_channels=8,
|
||||
candidate_choices=[0, 2, 4, 6, 8])
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Type of `candidate_choices` should be int.
|
||||
OrderChannelMutable(
|
||||
name='op',
|
||||
mask_type='out_mask',
|
||||
num_channels=8,
|
||||
candidate_choices=[0., 2, 4, 6, 8])
|
||||
|
||||
# Test out_mutable (mask_type == 'out_mask')
|
||||
out_mutable = OrderChannelMutable(
|
||||
name='op',
|
||||
mask_type='out_mask',
|
||||
num_channels=8,
|
||||
candidate_choices=[2, 4, 6, 8])
|
||||
|
||||
random_choice = out_mutable.sample_choice()
|
||||
assert random_choice in [2, 4, 6, 8]
|
||||
|
||||
choice = out_mutable.get_choice(0)
|
||||
assert choice == 2
|
||||
|
||||
max_choice = out_mutable.max_choice
|
||||
assert max_choice == 8
|
||||
out_mutable.current_choice = max_choice
|
||||
assert torch.equal(out_mutable.mask,
|
||||
torch.ones_like(out_mutable.mask).bool())
|
||||
|
||||
min_choice = out_mutable.min_choice
|
||||
assert min_choice == 2
|
||||
out_mutable.current_choice = min_choice
|
||||
min_mask = torch.zeros_like(out_mutable.mask).bool()
|
||||
min_mask[:2] = True
|
||||
assert torch.equal(out_mutable.mask, min_mask)
|
|
@ -0,0 +1,136 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
from torch import nn

from mmrazor.models.architectures.dynamic_op import (build_dynamic_bn,
                                                     build_dynamic_conv2d,
                                                     build_dynamic_gn,
                                                     build_dynamic_in,
                                                     build_dynamic_linear)


class TestDynamicLayer(TestCase):

    def test_dynamic_conv(self):
        imgs = torch.rand(2, 8, 16, 16)

        in_channels_cfg = dict(
            type='RatioChannelMutable',
            candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0])

        out_channels_cfg = dict(
            type='RatioChannelMutable',
            candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0])

        conv = nn.Conv2d(8, 8, 1)
        dynamic_conv = build_dynamic_conv2d(conv, 'op', in_channels_cfg,
                                            out_channels_cfg)
        # test forward
        dynamic_conv(imgs)

        conv = nn.Conv2d(8, 8, 1, groups=8)
        dynamic_conv = build_dynamic_conv2d(conv, 'op', in_channels_cfg,
                                            out_channels_cfg)
        # test forward
        dynamic_conv(imgs)

        conv = nn.Conv2d(8, 8, 1, groups=4)
        dynamic_conv = build_dynamic_conv2d(conv, 'op', in_channels_cfg,
                                            out_channels_cfg)
        # test forward
        with self.assertRaisesRegex(NotImplementedError,
                                    'only support pruning the depth-wise'):
            dynamic_conv(imgs)

    def test_dynamic_linear(self):
        imgs = torch.rand(2, 8)

        in_features_cfg = dict(
            type='RatioChannelMutable',
            candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0])

        out_features_cfg = dict(
            type='RatioChannelMutable',
            candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0])

        linear = nn.Linear(8, 8)
        dynamic_linear = build_dynamic_linear(linear, 'op', in_features_cfg,
                                              out_features_cfg)
        # test forward
        dynamic_linear(imgs)

    def test_dynamic_batchnorm(self):
        imgs = torch.rand(2, 8, 16, 16)

        num_features_cfg = dict(
            type='RatioChannelMutable',
            candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0])

        bn = nn.BatchNorm2d(8)
        dynamic_bn = build_dynamic_bn(bn, 'bn', num_features_cfg)
        # test forward
        dynamic_bn(imgs)

        bn = nn.BatchNorm2d(8, momentum=0)
        dynamic_bn = build_dynamic_bn(bn, 'bn', num_features_cfg)
        # test forward
        dynamic_bn(imgs)

        bn = nn.BatchNorm2d(8)
        bn.train()
        dynamic_bn = build_dynamic_bn(bn, 'bn', num_features_cfg)
        # test forward
        dynamic_bn(imgs)
        # test num_batches_tracked is not None
        dynamic_bn(imgs)

        bn = nn.BatchNorm2d(8, affine=False)
        dynamic_bn = build_dynamic_bn(bn, 'bn', num_features_cfg)
        # test forward
        dynamic_bn(imgs)

        bn = nn.BatchNorm2d(8, track_running_stats=False)
        dynamic_bn = build_dynamic_bn(bn, 'bn', num_features_cfg)
        # test forward
        dynamic_bn(imgs)

    def test_dynamic_instancenorm(self):
        imgs = torch.rand(2, 8, 16, 16)

        num_features_cfg = dict(
            type='RatioChannelMutable',
            candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0])

        instance_norm = nn.InstanceNorm2d(8)
        dynamic_in = build_dynamic_in(instance_norm, 'in', num_features_cfg)
        # test forward
        dynamic_in(imgs)

        instance_norm = nn.InstanceNorm2d(8, affine=False)
        dynamic_in = build_dynamic_in(instance_norm, 'in', num_features_cfg)
        # test forward
        dynamic_in(imgs)

        instance_norm = nn.InstanceNorm2d(8, track_running_stats=False)
        dynamic_in = build_dynamic_in(instance_norm, 'in', num_features_cfg)
        # test forward
        dynamic_in(imgs)

    def test_dynamic_groupnorm(self):
        imgs = torch.rand(2, 8, 16, 16)

        num_channels_cfg = dict(
            type='RatioChannelMutable',
            candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0])

        gn = nn.GroupNorm(num_groups=4, num_channels=8)
        dynamic_gn = build_dynamic_gn(gn, 'gn', num_channels_cfg)
        # test forward
        dynamic_gn(imgs)

        gn = nn.GroupNorm(num_groups=4, num_channels=8, affine=False)
        dynamic_gn = build_dynamic_gn(gn, 'gn', num_channels_cfg)
        # test forward
        dynamic_gn(imgs)

@ -0,0 +1,229 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import unittest
|
||||
from os.path import dirname
|
||||
|
||||
import mmcv.fileio
|
||||
import pytest
|
||||
import torch
|
||||
from mmcls.models import * # noqa: F401,F403
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import Module
|
||||
|
||||
from mmrazor import digit_version
|
||||
from mmrazor.models.architectures.dynamic_op import (build_dynamic_bn,
|
||||
build_dynamic_conv2d)
|
||||
from mmrazor.models.mutables import SlimmableChannelMutable
|
||||
from mmrazor.models.mutators import (OneShotChannelMutator,
|
||||
SlimmableChannelMutator)
|
||||
from mmrazor.registry import MODELS
|
||||
|
||||
ONESHOT_MUTATOR_CFG = dict(
|
||||
type='OneShotChannelMutator',
|
||||
tracer_cfg=dict(
|
||||
type='BackwardTracer',
|
||||
loss_calculator=dict(type='ImageClassifierPseudoLoss')),
|
||||
mutable_cfg=dict(
|
||||
type='RatioChannelMutable',
|
||||
candidate_choices=[
|
||||
1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0
|
||||
]))
|
||||
|
||||
ONESHOT_MUTATOR_CFG_WITHOUT_TRACER = dict(
|
||||
type='OneShotChannelMutator',
|
||||
mutable_cfg=dict(
|
||||
type='RatioChannelMutable',
|
||||
candidate_choices=[
|
||||
1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0
|
||||
]))
|
||||
|
||||
|
||||
class MultiConcatModel(Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.op1 = nn.Conv2d(3, 8, 1)
|
||||
self.op2 = nn.Conv2d(3, 8, 1)
|
||||
self.op3 = nn.Conv2d(16, 8, 1)
|
||||
self.op4 = nn.Conv2d(3, 8, 1)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x1 = self.op1(x)
|
||||
x2 = self.op2(x)
|
||||
cat1 = torch.cat([x1, x2], dim=1)
|
||||
x3 = self.op3(cat1)
|
||||
x4 = self.op4(x)
|
||||
output = torch.cat([x3, x4], dim=1)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class MultiConcatModel2(Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.op1 = nn.Conv2d(3, 8, 1)
|
||||
self.op2 = nn.Conv2d(3, 8, 1)
|
||||
self.op3 = nn.Conv2d(3, 8, 1)
|
||||
self.op4 = nn.Conv2d(24, 8, 1)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x1 = self.op1(x)
|
||||
x2 = self.op2(x)
|
||||
x3 = self.op3(x)
|
||||
cat1 = torch.cat([x1, x2], dim=1)
|
||||
cat2 = torch.cat([cat1, x3], dim=1)
|
||||
output = self.op4(cat2)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ResBlock(Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.op1 = nn.Conv2d(3, 8, 1)
|
||||
self.bn1 = nn.BatchNorm2d(8)
|
||||
self.op2 = nn.Conv2d(8, 8, 1)
|
||||
self.bn2 = nn.BatchNorm2d(8)
|
||||
self.op3 = nn.Conv2d(8, 8, 1)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x1 = self.bn1(self.op1(x))
|
||||
x2 = self.bn2(self.op2(x1))
|
||||
x3 = self.op3(x2 + x1)
|
||||
return x3
|
||||
|
||||
|
||||
class DynamicResBlock(Module):
|
||||
|
||||
def __init__(self, mutable_cfg) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dynamic_op1 = build_dynamic_conv2d(
|
||||
nn.Conv2d(3, 8, 1), 'dynamic_op1', mutable_cfg, mutable_cfg)
|
||||
self.dynamic_bn1 = build_dynamic_bn(
|
||||
nn.BatchNorm2d(8), 'dynamic_bn1', mutable_cfg)
|
||||
self.dynamic_op2 = build_dynamic_conv2d(
|
||||
nn.Conv2d(8, 8, 1), 'dynamic_op2', mutable_cfg, mutable_cfg)
|
||||
self.dynamic_bn2 = build_dynamic_bn(
|
||||
nn.BatchNorm2d(8), 'dynamic_bn2', mutable_cfg)
|
||||
self.dynamic_op3 = build_dynamic_conv2d(
|
||||
nn.Conv2d(8, 8, 1), 'dynamic_op3', mutable_cfg, mutable_cfg)
|
||||
self._add_link()
|
||||
|
||||
def _add_link(self):
|
||||
op1_mutable_out = self.dynamic_op1.mutable_out
|
||||
bn1_mutable_out = self.dynamic_bn1.mutable_out
|
||||
|
||||
op2_mutable_in = self.dynamic_op2.mutable_in
|
||||
op2_mutable_out = self.dynamic_op2.mutable_out
|
||||
bn2_mutable_out = self.dynamic_bn2.mutable_out
|
||||
|
||||
op3_mutable_in = self.dynamic_op3.mutable_in
|
||||
|
||||
bn1_mutable_out.register_same_mutable(op1_mutable_out)
|
||||
op1_mutable_out.register_same_mutable(bn1_mutable_out)
|
||||
|
||||
op2_mutable_in.register_same_mutable(bn1_mutable_out)
|
||||
bn1_mutable_out.register_same_mutable(op2_mutable_in)
|
||||
|
||||
bn2_mutable_out.register_same_mutable(op2_mutable_out)
|
||||
op2_mutable_out.register_same_mutable(bn2_mutable_out)
|
||||
|
||||
op3_mutable_in.register_same_mutable(bn1_mutable_out)
|
||||
bn1_mutable_out.register_same_mutable(op3_mutable_in)
|
||||
|
||||
op3_mutable_in.register_same_mutable(bn2_mutable_out)
|
||||
bn2_mutable_out.register_same_mutable(op3_mutable_in)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x1 = self.dynamic_bn1(self.dynamic_op1(x))
|
||||
x2 = self.dynamic_bn2(self.dynamic_op2(x1))
|
||||
x3 = self.dynamic_op3(x2 + x1)
|
||||
return x3
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
digit_version(torch.__version__) == digit_version('1.8.1'),
|
||||
'PyTorch version 1.8.1 is not supported by the Backward Tracer.')
|
||||
def test_oneshot_channel_mutator() -> None:
|
||||
imgs = torch.randn(16, 3, 224, 224)
|
||||
|
||||
def _test(model):
|
||||
mutator.prepare_from_supernet(model)
|
||||
for key, val in mutator.search_groups.items():
|
||||
print(key, val)
|
||||
# print(mutator.search_groups)
|
||||
assert hasattr(mutator, 'name2module')
|
||||
|
||||
# test set_min_choices
|
||||
mutator.set_min_choices()
|
||||
for mutables in mutator.search_groups.values():
|
||||
for mutable in mutables:
|
||||
# 1 / 8 is the minimum candidate ratio
|
||||
assert mutable.current_choice == round(1 / 8 *
|
||||
mutable.num_channels)
|
||||
|
||||
# test set_max_channel
|
||||
mutator.set_max_choices()
|
||||
for mutables in mutator.search_groups.values():
|
||||
for mutable in mutables:
|
||||
# 1.0 is the maximum candidate ratio
|
||||
assert mutable.current_choice == round(1. *
|
||||
mutable.num_channels)
|
||||
|
||||
# test making groups logic
|
||||
choice_dict = mutator.sample_choices()
|
||||
assert isinstance(choice_dict, dict)
|
||||
mutator.set_choices(choice_dict)
|
||||
model(imgs)
|
||||
|
||||
mutator: OneShotChannelMutator = MODELS.build(ONESHOT_MUTATOR_CFG)
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = mutator.search_groups
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = mutator.name2module
|
||||
|
||||
_test(ResBlock())
|
||||
_test(MultiConcatModel())
|
||||
_test(MultiConcatModel2())
|
||||
_test(nn.Sequential(nn.BatchNorm2d(3)))
|
||||
|
||||
mutator: OneShotChannelMutator = MODELS.build(
|
||||
ONESHOT_MUTATOR_CFG_WITHOUT_TRACER)
|
||||
dynamic_model = DynamicResBlock(
|
||||
ONESHOT_MUTATOR_CFG_WITHOUT_TRACER['mutable_cfg'])
|
||||
_test(dynamic_model)
|
||||
|
||||
|
||||
def test_slimmable_channel_mutator() -> None:
|
||||
imgs = torch.randn(16, 3, 224, 224)
|
||||
|
||||
root_path = dirname(dirname(dirname(__file__)))
|
||||
channel_cfgs = [
|
||||
os.path.join(root_path, 'data/subnet1.yaml'),
|
||||
os.path.join(root_path, 'data/subnet2.yaml')
|
||||
]
|
||||
channel_cfgs = [mmcv.fileio.load(path) for path in channel_cfgs]
|
||||
|
||||
mutator = SlimmableChannelMutator(
|
||||
mutable_cfg=dict(type='SlimmableChannelMutable'),
|
||||
channel_cfgs=channel_cfgs)
|
||||
|
||||
model = ResBlock()
|
||||
mutator.prepare_from_supernet(model)
|
||||
mutator.switch_choices(0)
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, SlimmableChannelMutable):
|
||||
assert module.current_choice == 0
|
||||
_ = model(imgs)
|
||||
|
||||
mutator.switch_choices(1)
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, SlimmableChannelMutable):
|
||||
assert module.current_choice == 1
|
||||
_ = model(imgs)
|
|
@ -0,0 +1,107 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import unittest
from os.path import dirname

import mmcv.fileio
import torch
from mmcls.core import ClsDataSample
from mmcls.models import *  # noqa: F401,F403

from mmrazor import digit_version
from mmrazor.models.mutables import SlimmableChannelMutable
from mmrazor.models.mutators import (OneShotChannelMutator,
                                     SlimmableChannelMutator)
from mmrazor.registry import MODELS

MODEL_CFG = dict(
    type='mmcls.ImageClassifier',
    backbone=dict(type='MobileNetV2', widen_factor=1.5),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1920,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5)))

ONESHOT_MUTATOR_CFG = dict(
    type='OneShotChannelMutator',
    skip_prefixes=['head.fc'],
    tracer_cfg=dict(
        type='BackwardTracer',
        loss_calculator=dict(type='ImageClassifierPseudoLoss')),
    mutable_cfg=dict(
        type='RatioChannelMutable',
        candidate_choices=[
            1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0
        ]))


@unittest.skipIf(
    digit_version(torch.__version__) == digit_version('1.8.1'),
    'PyTorch version 1.8.1 is not supported by the Backward Tracer.')
def test_oneshot_channel_mutator() -> None:
    imgs = torch.randn(16, 3, 224, 224)
    data_samples = [
        ClsDataSample().set_gt_label(torch.randint(0, 1000, (16, )))
    ]

    model = MODELS.build(MODEL_CFG)
    mutator: OneShotChannelMutator = MODELS.build(ONESHOT_MUTATOR_CFG)

    mutator.prepare_from_supernet(model)
    assert hasattr(mutator, 'name2module')

    # test set_min_choices
    mutator.set_min_choices()
    for mutables in mutator.search_groups.values():
        for mutable in mutables:
            # 1 / 8 is the minimum candidate ratio
            assert mutable.current_choice == round(1 / 8 *
                                                   mutable.num_channels)

    # test set_max_channel
    mutator.set_max_choices()
    for mutables in mutator.search_groups.values():
        for mutable in mutables:
            # 1.0 is the maximum candidate ratio
            assert mutable.current_choice == round(1. * mutable.num_channels)

    # test making groups logic
    choice_dict = mutator.sample_choices()
    assert isinstance(choice_dict, dict)
    mutator.set_choices(choice_dict)
    model(imgs, data_samples=data_samples, mode='loss')


def test_slimmable_channel_mutator() -> None:
    imgs = torch.randn(16, 3, 224, 224)
    data_samples = [
        ClsDataSample().set_gt_label(torch.randint(0, 1000, (16, )))
    ]

    root_path = dirname(dirname(dirname(dirname(__file__))))
    channel_cfgs = [
        os.path.join(root_path, 'data/MBV2_320M.yaml'),
        os.path.join(root_path, 'data/MBV2_220M.yaml')
    ]
    channel_cfgs = [mmcv.fileio.load(path) for path in channel_cfgs]

    mutator = SlimmableChannelMutator(
        mutable_cfg=dict(type='SlimmableChannelMutable'),
        channel_cfgs=channel_cfgs)

    model = MODELS.build(MODEL_CFG)
    mutator.prepare_from_supernet(model)
    mutator.switch_choices(0)
    for name, module in model.named_modules():
        if isinstance(module, SlimmableChannelMutable):
            assert module.current_choice == 0
    model(imgs, data_samples=data_samples, mode='loss')

    mutator.switch_choices(1)
    for name, module in model.named_modules():
        if isinstance(module, SlimmableChannelMutable):
            assert module.current_choice == 1
    model(imgs, data_samples=data_samples, mode='loss')

@ -104,7 +104,7 @@ class TestDiffMutator(TestCase):
         mutator: DiffOP = MODELS.build(self.MUTATOR_CFG)

         mutator.prepare_from_supernet(model)
-        assert list(mutator.search_group.keys()) == [0, 1, 2]
+        assert list(mutator.search_groups.keys()) == [0, 1, 2]

     def test_diff_mutator_diffop_model(self) -> None:
         model = SearchableModel(self.MUTABLE_CFG)

@ -118,7 +118,7 @@ class TestDiffMutator(TestCase):
         mutator: DiffOP = MODELS.build(mutator_cfg)

         mutator.prepare_from_supernet(model)
-        assert list(mutator.search_group.keys()) == [0, 1, 2]
+        assert list(mutator.search_groups.keys()) == [0, 1, 2]

         mutator.modify_supernet_forward()
         assert mutator.mutable_class_type == DiffMutable

@ -146,7 +146,7 @@ class TestDiffMutator(TestCase):

         mutator.prepare_from_supernet(model)

-        assert list(mutator.search_group.keys()) == [0, 1, 2]
+        assert list(mutator.search_groups.keys()) == [0, 1, 2]

         mutator.modify_supernet_forward()
         assert mutator.mutable_class_type == DiffMutable

@ -165,7 +165,7 @@ class TestDiffMutator(TestCase):

         mutator.prepare_from_supernet(model)

-        assert list(mutator.search_group.keys()) == [0, 1, 2, 3]
+        assert list(mutator.search_groups.keys()) == [0, 1, 2, 3]

         mutator.modify_supernet_forward()
         assert mutator.mutable_class_type == DiffMutable

@ -59,10 +59,10 @@ def test_one_shot_mutator_normal_model() -> None:
     assert mutator.mutable_class_type == OneShotMutable

     with pytest.raises(RuntimeError):
-        _ = mutator.search_group
+        _ = mutator.search_groups

     mutator.prepare_from_supernet(model)
-    assert len(mutator.search_group) == 0
+    assert len(mutator.search_groups) == 0
     assert len(mutator.sample_choices()) == 0

@ -89,7 +89,7 @@ def test_one_shot_mutator_mutable_model() -> None:

     # import pdb; pdb.set_trace()
     mutator.prepare_from_supernet(model)
-    assert list(mutator.search_group.keys()) == [0, 1, 2]
+    assert list(mutator.search_groups.keys()) == [0, 1, 2]

     random_choices = mutator.sample_choices()
     assert list(random_choices.keys()) == [0, 1, 2]

@ -102,7 +102,7 @@ def test_one_shot_mutator_mutable_model() -> None:
     mutator = MODELS.build(mutator_cfg)

     mutator.prepare_from_supernet(model)
-    assert list(mutator.search_group.keys()) == [0, 1]
+    assert list(mutator.search_groups.keys()) == [0, 1]

     random_choices = mutator.sample_choices()
     assert list(random_choices.keys()) == [0, 1]