From 42063ae4d38673cc89ea6fbc3e224c6f8bbeb5bb Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Fri, 1 Jul 2022 08:23:15 +0000 Subject: [PATCH] [Refactor]Refactor tracer and channel mutator --- mmrazor/core/__init__.py | 1 + mmrazor/core/tracer/__init__.py | 11 + mmrazor/core/tracer/backward_tracer.py | 201 ++++++++ .../core/tracer/loss_calculator/__init__.py | 6 + .../image_classifier_loss_calculator.py | 16 + .../single_stage_detector_loss_calculator.py | 19 + mmrazor/core/tracer/parsers.py | 189 +++++++ mmrazor/core/tracer/path.py | 353 +++++++++++++ mmrazor/models/__init__.py | 1 - .../architectures/dynamic_op/__init__.py | 3 + .../dynamic_op/default_dynamic_ops.py | 408 +++++++++++++++ .../dynamic_op/slimmable_dynamic_ops.py | 105 ++++ mmrazor/models/mutables/__init__.py | 9 +- .../mutables/mutable_channel/__init__.py | 10 + .../one_shot_channel_mutable.py | 157 ++++++ .../mutable_channel/order_channel_mutable.py | 50 ++ .../mutable_channel/ratio_channel_mutable.py | 59 +++ .../slimmable_channel_mutable.py | 92 ++++ .../models/mutables/mutable_manager_mixin.py | 9 + mmrazor/models/mutators/__init__.py | 7 +- mmrazor/models/mutators/base_mutator.py | 22 +- .../mutators/channel_mutator/__init__.py | 8 + .../channel_mutator/channel_mutator.py | 277 ++++++++++ .../one_shot_channel_mutator.py | 135 +++++ .../slimmable_channel_mutator.py | 148 ++++++ mmrazor/models/mutators/diff_mutator.py | 6 +- mmrazor/models/mutators/one_shot_mutator.py | 6 +- mmrazor/models/pruners/__init__.py | 6 - mmrazor/models/pruners/ratio_pruning.py | 150 ------ mmrazor/models/pruners/utils/__init__.py | 4 - mmrazor/models/pruners/utils/switchable_bn.py | 38 -- tests/data/MBV2_220M.yaml | 474 ++++++++++++++++++ tests/data/MBV2_320M.yaml | 474 ++++++++++++++++++ tests/data/subnet1.yaml | 24 + tests/data/subnet2.yaml | 24 + .../test_tracer/test_backward_tracer.py | 260 ++++++++++ .../test_mutables/test_channel_mutable.py | 164 ++++++ .../test_mutables/test_dynamic_layer.py | 136 +++++ .../test_mutators/test_channel_mutator.py | 229 +++++++++ .../test_mbv2_channel_mutator.py | 107 ++++ .../test_mutators/test_diff_mutator.py | 8 +- .../test_mutators/test_one_shot_mutator.py | 8 +- 42 files changed, 4187 insertions(+), 227 deletions(-) create mode 100644 mmrazor/core/tracer/__init__.py create mode 100644 mmrazor/core/tracer/backward_tracer.py create mode 100644 mmrazor/core/tracer/loss_calculator/__init__.py create mode 100644 mmrazor/core/tracer/loss_calculator/image_classifier_loss_calculator.py create mode 100644 mmrazor/core/tracer/loss_calculator/single_stage_detector_loss_calculator.py create mode 100644 mmrazor/core/tracer/parsers.py create mode 100644 mmrazor/core/tracer/path.py create mode 100644 mmrazor/models/architectures/dynamic_op/__init__.py create mode 100644 mmrazor/models/architectures/dynamic_op/default_dynamic_ops.py create mode 100644 mmrazor/models/architectures/dynamic_op/slimmable_dynamic_ops.py create mode 100644 mmrazor/models/mutables/mutable_channel/__init__.py create mode 100644 mmrazor/models/mutables/mutable_channel/one_shot_channel_mutable.py create mode 100644 mmrazor/models/mutables/mutable_channel/order_channel_mutable.py create mode 100644 mmrazor/models/mutables/mutable_channel/ratio_channel_mutable.py create mode 100644 mmrazor/models/mutables/mutable_channel/slimmable_channel_mutable.py create mode 100644 mmrazor/models/mutables/mutable_manager_mixin.py create mode 100644 mmrazor/models/mutators/channel_mutator/__init__.py create mode 100644 mmrazor/models/mutators/channel_mutator/channel_mutator.py create mode 100644 mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py create mode 100644 mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py delete mode 100644 mmrazor/models/pruners/__init__.py delete mode 100644 mmrazor/models/pruners/ratio_pruning.py delete mode 100644 mmrazor/models/pruners/utils/__init__.py delete mode 100644 mmrazor/models/pruners/utils/switchable_bn.py create mode 100644 tests/data/MBV2_220M.yaml create mode 100644 tests/data/MBV2_320M.yaml create mode 100644 tests/data/subnet1.yaml create mode 100644 tests/data/subnet2.yaml create mode 100644 tests/test_core/test_tracer/test_backward_tracer.py create mode 100644 tests/test_models/test_mutables/test_channel_mutable.py create mode 100644 tests/test_models/test_mutables/test_dynamic_layer.py create mode 100644 tests/test_models/test_mutators/test_channel_mutator.py create mode 100644 tests/test_models/test_mutators/test_classical_models/test_mbv2_channel_mutator.py diff --git a/mmrazor/core/__init__.py b/mmrazor/core/__init__.py index a5469d53..ad88a0c2 100644 --- a/mmrazor/core/__init__.py +++ b/mmrazor/core/__init__.py @@ -1,2 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. from .delivers import * # noqa: F401,F403 +from .tracer import * # noqa: F401,F403 diff --git a/mmrazor/core/tracer/__init__.py b/mmrazor/core/tracer/__init__.py new file mode 100644 index 00000000..4b2868cc --- /dev/null +++ b/mmrazor/core/tracer/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .backward_tracer import BackwardTracer +from .loss_calculator import * # noqa: F401,F403 +from .parsers import * # noqa: F401,F403 +from .path import (ConcatNode, ConvNode, DepthWiseConvNode, LinearNode, Node, + NormNode, Path, PathList) + +__all__ = [ + 'BackwardTracer', 'ConvNode', 'LinearNode', 'NormNode', 'ConcatNode', + 'Path', 'PathList', 'Node', 'DepthWiseConvNode' +] diff --git a/mmrazor/core/tracer/backward_tracer.py b/mmrazor/core/tracer/backward_tracer.py new file mode 100644 index 00000000..e26080df --- /dev/null +++ b/mmrazor/core/tracer/backward_tracer.py @@ -0,0 +1,201 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import re +from collections import OrderedDict + +from mmcv import ConfigDict +from torch.nn import Conv2d, Linear +from torch.nn.modules import GroupNorm +from torch.nn.modules.batchnorm import _NormBase + +from mmrazor.registry import TASK_UTILS +from .parsers import DEFAULT_BACKWARD_TRACER +from .path import Path, PathList + +SUPPORT_MODULES = (Conv2d, Linear, _NormBase, GroupNorm) + + +@TASK_UTILS.register_module() +class BackwardTracer: + """A topology tracer via backward. + + Args: + loss_calculator (dict or Callable): Calculate the pseudo loss to trace + the topology of a model. + """ + + def __init__(self, loss_calculator): + if isinstance(loss_calculator, (dict, ConfigDict)): + loss_calculator = TASK_UTILS.build(loss_calculator) + + assert callable( + loss_calculator + ), 'loss_calculator should be a dict, ConfigDict or ' \ + 'callable object' + self.loss_calculator = loss_calculator + + @property + def backward_parser(self): + """The mapping from the type of a backward op to the corresponding + parser.""" + return DEFAULT_BACKWARD_TRACER + + def backward_trace(self, grad_fn, module2name, param2module, cur_path, + result_paths, visited, shared_module): + """Trace the topology of all the ``NON_PASS_MODULE``.""" + grad_fn = grad_fn[0] if isinstance(grad_fn, (list, tuple)) else grad_fn + + if grad_fn is not None: + name = type(grad_fn).__name__ + # In pytorch graph, there may be an additional '0' or '1' + # (e.g. ThnnConv2DBackward0) after a backward op. Delete the + # digit numbers to build the corresponding parser. + name = re.sub(r'[0-1]+', '', name) + parse_module = self.backward_parser.get(name) + + if parse_module is not None: + parse_module(self, grad_fn, module2name, param2module, + cur_path, result_paths, visited, shared_module) + else: + # If the op is AccumulateGrad, parents is (), + parents = grad_fn.next_functions + if parents is not None: + for parent in parents: + self.backward_trace(parent, module2name, param2module, + cur_path, result_paths, visited, + shared_module) + else: + result_paths.append(copy.deepcopy(cur_path)) + + def _trace_shared_module_hook(self, module, inputs, outputs): + """Trace shared modules. Modules such as the detection head in + RetinaNet which are visited more than once during :func:`forward` are + shared modules. + + Args: + module (:obj:`torch.nn.Module`): The module to register hook. + inputs (tuple): The input of the module. + outputs (tuple): The output of the module. + """ + module._cnt += 1 + + def _build_mappings(self, model): + """Build the mappings which are used during tracing.""" + + module2name = OrderedDict() + # build a mapping from the identity of a module's params + # to this module + param2module = OrderedDict() + # record the visited module name during trace path + visited = dict() + + def traverse(module, prefix=''): + for name, child in module.named_children(): + full_name = f'{prefix}.{name}' if prefix else name + + if isinstance(child, SUPPORT_MODULES): + module2name[child] = full_name + for param in child.parameters(): + param2module[id(param)] = child + visited[full_name] = False + else: + traverse(child, full_name) + + traverse(model) + + return module2name, param2module, visited + + def _register_share_module_hook(self, model): + """Record shared modules which will be visited more than once during + forward such as shared detection head in RetinaNet. + + If a module is not a shared module and it has been visited during + forward, its parent modules must have been traced already. However, a + shared module will be visited more than once during forward, so it is + still need to be traced even if it has been visited. + """ + self._shared_module_hook_handles = list() + for module in model.modules(): + if hasattr(module, 'weight'): + # trace shared modules + module._cnt = 0 + # the handle is only to remove the corresponding hook later + handle = module.register_forward_hook( + self._trace_shared_module_hook) + self._shared_module_hook_handles.append(handle) + + def _remove_share_module_hook(self, model): + """`_trace_shared_module_hook` and `_cnt` are only used to trace the + shared modules in a model and need to be remove later.""" + for module in model.modules(): + if hasattr(module, 'weight'): + del module._cnt + + for handle in self._shared_module_hook_handles: + handle.remove() + + del self._shared_module_hook_handles + + def _set_all_requires_grad(self, model): + """Set `requires_grad` of a parameter to True to trace the whole + architecture topology.""" + self._param_requires_grad = dict() + for param in model.parameters(): + self._param_requires_grad[id(param)] = param.requires_grad + param.requires_grad = True + + def _restore_requires_grad(self, model): + """We set requires_grad to True to trace the whole architecture + topology. + + So it should be reset after that. + """ + for param in model.parameters(): + param.requires_grad = self._param_requires_grad[id(param)] + del self._param_requires_grad + + @staticmethod + def _find_share_modules(model): + """Find shared modules which will be visited more than once during + forward such as shared detection head in RetinaNet.""" + share_modules = list() + for name, module in model.named_modules(): + if hasattr(module, 'weight'): + if module._cnt > 1: + share_modules.append(name) + + return share_modules + + @staticmethod + def _reset_norm_running_stats(model): + """As we calculate the pseudo loss during tracing, we need to reset + states of parameters.""" + for module in model.modules(): + if isinstance(module, _NormBase): + module.reset_parameters() + + def trace(self, model): + """Trace trace the architecture topology of the input model.""" + module2name, param2module, visited = self._build_mappings(model) + + # Set requires_grad to True. If the `requires_grad` of a module's + # weight is False, we can not trace this module by parsing backward. + self._set_all_requires_grad(model) + + self._register_share_module_hook(model) + + pseudo_loss = self.loss_calculator(model) + + share_modules = self._find_share_modules(model) + + self._remove_share_module_hook(model) + self._restore_requires_grad(model) + + module_path_list = PathList() + + self.backward_trace(pseudo_loss.grad_fn, module2name, param2module, + Path(), module_path_list, visited, share_modules) + + self._reset_norm_running_stats(model) + + return module_path_list diff --git a/mmrazor/core/tracer/loss_calculator/__init__.py b/mmrazor/core/tracer/loss_calculator/__init__.py new file mode 100644 index 00000000..0371a713 --- /dev/null +++ b/mmrazor/core/tracer/loss_calculator/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .image_classifier_loss_calculator import ImageClassifierPseudoLoss +from .single_stage_detector_loss_calculator import \ + SingleStageDetectorPseudoLoss + +__all__ = ['ImageClassifierPseudoLoss', 'SingleStageDetectorPseudoLoss'] diff --git a/mmrazor/core/tracer/loss_calculator/image_classifier_loss_calculator.py b/mmrazor/core/tracer/loss_calculator/image_classifier_loss_calculator.py new file mode 100644 index 00000000..6579c317 --- /dev/null +++ b/mmrazor/core/tracer/loss_calculator/image_classifier_loss_calculator.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmcls.models import ImageClassifier + +from mmrazor.registry import TASK_UTILS + + +@TASK_UTILS.register_module() +class ImageClassifierPseudoLoss: + """Calculate the pseudo loss to trace the topology of a `ImageClassifier` + in MMClassification with `BackwardTracer`.""" + + def __call__(self, model: ImageClassifier) -> torch.Tensor: + pseudo_img = torch.rand(1, 3, 224, 224) + pseudo_output = model(pseudo_img) + return sum(pseudo_output) diff --git a/mmrazor/core/tracer/loss_calculator/single_stage_detector_loss_calculator.py b/mmrazor/core/tracer/loss_calculator/single_stage_detector_loss_calculator.py new file mode 100644 index 00000000..85f25eaf --- /dev/null +++ b/mmrazor/core/tracer/loss_calculator/single_stage_detector_loss_calculator.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmdet.models import BaseDetector + +from mmrazor.registry import TASK_UTILS + + +# todo: adapt to mmdet 2.0 +@TASK_UTILS.register_module() +class SingleStageDetectorPseudoLoss: + + def __call__(self, model: BaseDetector) -> torch.Tensor: + pseudo_img = torch.rand(1, 3, 224, 224) + pseudo_output = model.forward_dummy(pseudo_img) + out = torch.tensor(0.) + for levels in pseudo_output: + out += sum([level.sum() for level in levels]) + + return out diff --git a/mmrazor/core/tracer/parsers.py b/mmrazor/core/tracer/parsers.py new file mode 100644 index 00000000..55d0ec7f --- /dev/null +++ b/mmrazor/core/tracer/parsers.py @@ -0,0 +1,189 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Callable, Dict + +from .path import (ConcatNode, ConvNode, DepthWiseConvNode, LinearNode, + NormNode, Path, PathList) + + +def _is_leaf_grad_fn(grad_fn): + """Determine whether the current node is a leaf node.""" + if type(grad_fn).__name__ == 'AccumulateGrad': + return True + return False + + +def parse_conv(tracer, grad_fn, module2name, param2module, cur_path, + result_paths, visited, shared_module): + """Parse the backward of a conv layer. + + Example: + >>> conv = nn.Conv2d(3, 3, 3) + >>> pseudo_img = torch.rand(1, 3, 224, 224) + >>> out = conv(pseudo_img) + >>> out.grad_fn.next_functions + ((None, 0), (, 0), + (, 0)) + >>> # op.next_functions[0][0] is None means this ThnnConv2DBackward + >>> # op has no parents + >>> # op.next_functions[1][0].variable is the weight of this Conv2d + >>> # module + >>> # op.next_functions[2][0].variable is the bias of this Conv2d + >>> # module + """ + leaf_grad_fn = grad_fn.next_functions[1][0] + while not _is_leaf_grad_fn(leaf_grad_fn): + leaf_grad_fn = leaf_grad_fn.next_functions[0][0] + variable = leaf_grad_fn.variable + param_id = id(variable) + module = param2module[param_id] + name = module2name[module] + parent = grad_fn.next_functions[0][0] + if module.in_channels == module.groups: + cur_path.append(DepthWiseConvNode(name)) + else: + cur_path.append(ConvNode(name)) + # If a module is not a shared module and it has been visited during + # forward, its parent modules must have been traced already. + # However, a shared module will be visited more than once during + # forward, so it is still need to be traced even if it has been + # visited. + if visited[name] and name not in shared_module: + result_paths.append(copy.deepcopy(cur_path)) + else: + visited[name] = True + tracer.backward_trace(parent, module2name, param2module, cur_path, + result_paths, visited, shared_module) + cur_path.pop(-1) + + +# todo: support parsing `MultiheadAttention` and user-defined matrix +# multiplication +def parse_linear(tracer, grad_fn, module2name, param2module, cur_path, + result_paths, visited, shared_module): + """Parse the backward of a conv layer. + + Example: + >>> fc = nn.Linear(3, 3, bias=True) + >>> input = torch.rand(3, 3) + >>> out = fc(input) + >>> out.grad_fn.next_functions + ((, 0), (None, 0), + (, 0)) + >>> # op.next_functions[0][0].variable is the bias of this Linear + >>> # module + >>> # op.next_functions[1][0] is None means this AddmmBackward op + >>> # has no parents + >>> # op.next_functions[2][0] is the TBackward op, and + >>> # op.next_functions[2][0].next_functions[0][0].variable is + >>> # the transpose of the weight of this Linear module + """ + leaf_grad_fn = grad_fn.next_functions[-1][0].next_functions[0][0] + while not _is_leaf_grad_fn(leaf_grad_fn): + leaf_grad_fn = leaf_grad_fn.next_functions[0][0] + variable = leaf_grad_fn.variable + param_id = id(variable) + module = param2module[param_id] + name = module2name[module] + parent = grad_fn.next_functions[-2][0] + + cur_path.append(LinearNode(name)) + # If a module is not a shared module and it has been visited during + # forward, its parent modules must have been traced already. + # However, a shared module will be visited more than once during + # forward, so it is still need to be traced even if it has been + # visited. + if visited[name] and name not in shared_module: + result_paths.append(copy.deepcopy(cur_path)) + else: + visited[name] = True + tracer.backward_trace(parent, module2name, param2module, cur_path, + result_paths, visited, shared_module) + + +def parse_cat(tracer, grad_fn, module2name, param2module, cur_path, + result_paths, visited, shared_module): + """Parse the backward of a concat operation. + + Example: + >>> conv = nn.Conv2d(3, 3, 3) + >>> pseudo_img = torch.rand(1, 3, 224, 224) + >>> out1 = conv(pseudo_img) + >>> out2 = conv(pseudo_img) + >>> out = torch.cat([out1, out2], dim=1) + >>> out.grad_fn.next_functions + ((, 0), + (, 0)) + >>> # the length of ``out.grad_fn.next_functions`` is two means + >>> # ``out`` is obtained by concatenating two tensors + """ + parents = grad_fn.next_functions + concat_id = '_'.join([str(id(p)) for p in parents]) + name = f'concat_{concat_id}' + # If a module is not a shared module and it has been visited during + # forward, its parent modules must have been traced already. + # However, a shared module will be visited more than once during + # forward, so it is still need to be traced even if it has been + # visited. + if (name in visited and visited[name] and name not in shared_module): + pass + else: + visited[name] = True + sub_path_lists = list() + for i, parent in enumerate(parents): + sub_path_list = PathList() + tracer.backward_trace(parent, module2name, param2module, Path(), + sub_path_list, visited, shared_module) + sub_path_lists.append(sub_path_list) + cur_path.append(ConcatNode(name, sub_path_lists)) + + result_paths.append(copy.deepcopy(cur_path)) + cur_path.pop(-1) + + +def parse_norm(tracer, grad_fn, module2name, param2module, cur_path, + result_paths, visited, shared_module): + """Parse the backward of a concat operation. + + Example: + >>> conv = nn.Conv2d(3, 3, 3) + >>> pseudo_img = torch.rand(1, 3, 224, 224) + >>> out1 = conv(pseudo_img) + >>> out2 = conv(pseudo_img) + >>> out = torch.cat([out1, out2], dim=1) + >>> out.grad_fn.next_functions + ((, 0), + (, 0)) + >>> # the length of ``out.grad_fn.next_functions`` is two means + >>> # ``out`` is obtained by concatenating two tensors + """ + leaf_grad_fn = grad_fn.next_functions[1][0] + while not _is_leaf_grad_fn(leaf_grad_fn): + leaf_grad_fn = leaf_grad_fn.next_functions[0][0] + variable = leaf_grad_fn.variable + param_id = id(variable) + module = param2module[param_id] + name = module2name[module] + parent = grad_fn.next_functions[0][0] + cur_path.append(NormNode(name)) + + visited[name] = True + tracer.backward_trace(parent, module2name, param2module, cur_path, + result_paths, visited, shared_module) + cur_path.pop(-1) + + +DEFAULT_BACKWARD_TRACER: Dict[str, Callable] = { + 'ThnnConv2DBackward': parse_conv, + 'CudnnConvolutionBackward': parse_conv, + 'MkldnnConvolutionBackward': parse_conv, + 'SlowConvDilated2DBackward': parse_conv, + 'ThAddmmBackward': parse_linear, + 'AddmmBackward': parse_linear, + 'MmBackward': parse_linear, + 'CatBackward': parse_cat, + 'ThnnBatchNormBackward': parse_norm, + 'CudnnBatchNormBackward': parse_norm, + 'NativeBatchNormBackward': parse_norm, + 'NativeGroupNormBackward': parse_norm +} diff --git a/mmrazor/core/tracer/path.py b/mmrazor/core/tracer/path.py new file mode 100644 index 00000000..35731926 --- /dev/null +++ b/mmrazor/core/tracer/path.py @@ -0,0 +1,353 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple, Union + + +def _addindent(s_, numSpaces): + s = s_.split('\n') + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(numSpaces * ' ') + line for line in s] + s = '\n'.join(s) + s = first + '\n' + s + return s + + +def _merge_node_parents(node2parents, _node2parents): + for node, parents in _node2parents.items(): + if node in node2parents: + cur_parents = node2parents[node] + new_parents_set = set(cur_parents + parents) + new_parents = list(new_parents_set) + node2parents[node] = new_parents + else: + node2parents[node] = parents + + +class Node: + """``Node`` is the data structure that represents individual instances + within a ``Path``. It corresponds to a module or an operation such as + concatenation in the model. + + Args: + name (str): Unique identifier of a node. + """ + + def __init__(self, name: str) -> None: + self._name = name + + def get_module_names(self) -> List: + return [self.name] + + @property + def name(self) -> str: + """Get the name of current node.""" + return self._name + + def _get_class_name(self): + return self.__class__.__name__ + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.name == other.name + else: + return False + + def __hash__(self): + return hash(self.name) + + def __repr__(self): + return f'{self._get_class_name()}(\'{self.name}\')' + + +class ConvNode(Node): + """A `ConvNode` corresponds to a Conv module in the original model.""" + pass + + +class DepthWiseConvNode(Node): + """A `DepthWiseConvNode` corresponds to a depth-wise conv module in the + original model.""" + pass + + +class NormNode(Node): + """A `NormNode` corresponds to a normalization module in the original + model.""" + pass + + +class LinearNode(Node): + """A `LinearNode` corresponds to a linear module in the original model.""" + pass + + +class Path: + """``Path`` is the data structure that represents a list of ``Node`` traced + by a tracer. + + Args: + nodes(:obj:`Node` or List[:obj:`Node`], optional): Nodes in a path. + Default to None. + """ + + def __init__(self, nodes: Optional[Union[Node, List[Node]]] = None): + self._nodes: List[Node] = list() + if nodes is not None: + if isinstance(nodes, Node): + nodes = [nodes] + assert isinstance(nodes, (list, tuple)) + for node in nodes: + assert isinstance(node, Node) + self._nodes.append(node) + + def get_root_names(self) -> List[str]: + """Get the name of the first node in a path.""" + return self._nodes[0].get_module_names() + + def find_nodes_parents(self, + target_nodes: Tuple, + non_pass: Optional[Tuple] = None) -> Dict: + """Find the parents of a specific node. + + Args: + target_nodes (Tuple): Find the parents of nodes whose types + are one of `target_nodes`. + non_pass (Tuple): Ancestor nodes whose types are one of + `non_pass` are the parents of a specific node. Default to None. + """ + node2parents: Dict[str, List[Node]] = dict() + for i, node in enumerate(self._nodes): + if isinstance(node, ConcatNode): + _node2parents: Dict[str, List[Node]] = node.find_nodes_parents( + target_nodes, non_pass) + _merge_node_parents(node2parents, _node2parents) + continue + + if isinstance(node, target_nodes): + parents = list() + for behind_node in self._nodes[i + 1:]: + if non_pass is None or isinstance(behind_node, non_pass): + parents.append(behind_node) + break + _node2parents = {node.name: parents} + _merge_node_parents(node2parents, _node2parents) + return node2parents + + @property + def nodes(self) -> List: + """Return a list of nodes in the current path.""" + return self._nodes + + def append(self, x: Node) -> None: + """Add a node to the end of the current path.""" + assert isinstance(x, Node) + self._nodes.append(x) + + def pop(self, *args, **kwargs): + """Temoves the node at the given index from the path and returns the + removed node.""" + return self._nodes.pop(*args, **kwargs) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.nodes == other.nodes + else: + return False + + def __len__(self): + return len(self._nodes) + + def __getitem__(self, item): + return self._nodes[item] + + def __iter__(self): + for node in self._nodes: + yield node + + def _get_class_name(self) -> str: + """Get the name of the current class.""" + return self.__class__.__name__ + + def __repr__(self): + child_lines = [] + for node in self._nodes: + node_str = repr(node) + node_str = _addindent(node_str, 2) + child_lines.append(node_str) + lines = child_lines + + main_str = self._get_class_name() + '(' + if lines: + main_str += '\n ' + ',\n '.join(lines) + '\n' + main_str += ')' + return main_str + + +class PathList: + """``PathList`` is the data structure that represents a list of ``Path`` + traced by a tracer. + + Args: + paths(:obj:`Path` or List[:obj:`Path`], optional): A list of `Path`. + Default to None. + """ + + def __init__(self, paths: Optional[Union[Path, List[Path]]] = None): + self._paths = list() + if paths is not None: + if isinstance(paths, Path): + paths = [paths] + assert isinstance(paths, (list, tuple)) + for path in paths: + assert isinstance(path, Path) + self._paths.append(path) + + def get_root_names(self) -> List[str]: + """Get the root node of all the paths in `PathList`. + + Notes: + Different paths in a PathList share the same root node. + """ + return self._paths[0].get_root_names() + + def find_nodes_parents(self, + target_nodes: Tuple, + non_pass: Optional[Tuple] = None): + """Find the parents of a specific node. + + Args: + target_nodes (Tuple): Find the parents of nodes whose types + are one of `target_nodes`. + non_pass (Tuple): Ancestor nodes whose types are one of + `non_pass` are the parents of a specific node. Default to None. + """ + node2parents: Dict[str, List[Node]] = dict() + for p in self._paths: + _node2parents = p.find_nodes_parents(target_nodes, non_pass) + _merge_node_parents(node2parents, _node2parents) + return node2parents + + def append(self, x: Path) -> None: + """Add a path to the end of the current PathList.""" + assert isinstance(x, Path) + self._paths.append(x) + + @property + def paths(self): + """Return all paths in the current PathList.""" + return self._paths + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.paths == other.paths + else: + return False + + def __len__(self): + return len(self._paths) + + def __getitem__(self, item): + return self._paths[item] + + def __iter__(self): + for path in self._paths: + yield path + + def _get_class_name(self) -> str: + """Get the name of the current class.""" + return self.__class__.__name__ + + def __repr__(self): + child_lines = [] + for node in self._paths: + node_str = repr(node) + node_str = _addindent(node_str, 2) + child_lines.append(node_str) + lines = child_lines + + main_str = self._get_class_name() + '(' + if lines: + main_str += '\n ' + ',\n '.join(lines) + '\n' + main_str += ')' + return main_str + + +class ConcatNode(Node): + """``ConcatNode`` is the data structure that represents the concatenation + operation in a model. + + Args: + name (str): Unique identifier of a `ConcatNode`. + path_lists (List[PathList]): Several nodes are concatenated and each + node is the root node of all the paths in a `PathList` + (one of `path_lists`). + """ + + def __init__(self, name: str, path_lists: List[PathList]): + super().__init__(name) + self._path_lists = list() + for path_list in path_lists: + assert isinstance(path_list, PathList) + self._path_lists.append(path_list) + + def get_module_names(self) -> List[str]: + """Several nodes are concatenated. + + Get the names of these nodes. + """ + module_names = list() + for path_list in self._path_lists: + module_names.extend(path_list.get_root_names()) + return module_names + + def find_nodes_parents(self, + target_nodes: Tuple, + non_pass: Optional[Tuple] = None): + """Find the parents of a specific node. + + Args: + target_nodes (Tuple): Find the parents of nodes whose types + are one of `target_nodes`. + non_pass (Tuple): Ancestor nodes whose types are one of + `non_pass` are the parents of a specific node. Default to None. + """ + node2parents: Dict[str, List[Node]] = dict() + for p in self._path_lists: + _node2parents = p.find_nodes_parents(target_nodes, non_pass) + _merge_node_parents(node2parents, _node2parents) + return node2parents + + @property + def path_lists(self) -> List[PathList]: + """Return all the path_list.""" + return self._path_lists + + def __len__(self): + return len(self._path_lists) + + def __getitem__(self, item): + return self._path_lists[item] + + def __iter__(self): + for path_list in self._path_lists: + yield path_list + + def _get_class_name(self) -> str: + """Get the name of the current class.""" + return self.__class__.__name__ + + def __repr__(self): + child_lines = [] + for node in self._path_lists: + node_str = repr(node) + node_str = _addindent(node_str, 2) + child_lines.append(node_str) + lines = child_lines + + main_str = self._get_class_name() + '(' + if lines: + main_str += '\n ' + ',\n '.join(lines) + '\n' + main_str += ')' + return main_str diff --git a/mmrazor/models/__init__.py b/mmrazor/models/__init__.py index 38e66e89..a77f02a9 100644 --- a/mmrazor/models/__init__.py +++ b/mmrazor/models/__init__.py @@ -6,5 +6,4 @@ from .losses import * # noqa: F401,F403 from .mutables import * # noqa: F401,F403 from .mutators import * # noqa: F401,F403 from .ops import * # noqa: F401,F403 -from .pruners import * # noqa: F401,F403 from .subnet import * # noqa: F401,F403 diff --git a/mmrazor/models/architectures/dynamic_op/__init__.py b/mmrazor/models/architectures/dynamic_op/__init__.py new file mode 100644 index 00000000..b5b1f855 --- /dev/null +++ b/mmrazor/models/architectures/dynamic_op/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .default_dynamic_ops import * # noqa: F401,F403 +from .slimmable_dynamic_ops import * # noqa: F401,F403 diff --git a/mmrazor/models/architectures/dynamic_op/default_dynamic_ops.py b/mmrazor/models/architectures/dynamic_op/default_dynamic_ops.py new file mode 100644 index 00000000..285a9b00 --- /dev/null +++ b/mmrazor/models/architectures/dynamic_op/default_dynamic_ops.py @@ -0,0 +1,408 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Dict + +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn.modules import GroupNorm +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.modules.instancenorm import _InstanceNorm + +from mmrazor.models.mutables import MutableManagerMixIn +from mmrazor.registry import MODELS + + +class DynamicConv2d(nn.Conv2d, MutableManagerMixIn): + """Applies a 2D convolution over an input signal composed of several input + planes according to the `mutable_in_channels` and `mutable_out_channels` + dynamically. + + Args: + module_name (str): Name of this `DynamicConv2d`. + in_channels_cfg (Dict): Config related to `in_channels`. + out_channels_cfg (Dict): Config related to `out_channels`. + """ + + def __init__(self, module_name, in_channels_cfg, out_channels_cfg, *args, + **kwargs): + super(DynamicConv2d, self).__init__(*args, **kwargs) + + in_channels_cfg = copy.deepcopy(in_channels_cfg) + in_channels_cfg.update( + dict( + name=module_name, + num_channels=self.in_channels, + mask_type='in_mask')) + self.mutable_in_channels = MODELS.build(in_channels_cfg) + + out_channels_cfg = copy.deepcopy(out_channels_cfg) + out_channels_cfg.update( + dict( + name=module_name, + num_channels=self.out_channels, + mask_type='out_mask')) + self.mutable_out_channels = MODELS.build(out_channels_cfg) + + @property + def mutable_in(self): + """Mutable `in_channels`.""" + return self.mutable_in_channels + + @property + def mutable_out(self): + """Mutable `out_channels`.""" + return self.mutable_out_channels + + def forward(self, input: Tensor) -> Tensor: + """Slice the parameters according to `mutable_in_channels` and + `mutable_out_channels`, and forward.""" + in_mask = self.mutable_in_channels.mask + out_mask = self.mutable_out_channels.mask + + if self.groups == 1: + weight = self.weight[out_mask][:, in_mask] + groups = 1 + elif self.groups == self.in_channels == self.out_channels: + # depth-wise conv + weight = self.weight[out_mask] + groups = input.size(1) + else: + raise NotImplementedError( + 'Current `ChannelMutator` only support pruning the depth-wise ' + '`nn.Conv2d` or `nn.Conv2d` module whose group number equals ' + 'to one, but got {self.groups}.') + + bias = self.bias[out_mask] if self.bias is not None else None + + return F.conv2d(input, weight, bias, self.stride, self.padding, + self.dilation, groups) + + +class DynamicLinear(nn.Linear, MutableManagerMixIn): + """Applies a linear transformation to the incoming data according to the + `mutable_in_features` and `mutable_out_features` dynamically. + + Args: + module_name (str): Name of this `DynamicLinear`. + in_features_cfg (Dict): Config related to `in_features`. + out_features_cfg (Dict): Config related to `out_features`. + """ + + def __init__(self, module_name, in_features_cfg, out_features_cfg, *args, + **kwargs): + super(DynamicLinear, self).__init__(*args, **kwargs) + + in_features_cfg = copy.deepcopy(in_features_cfg) + in_features_cfg.update( + dict( + name=module_name, + num_channels=self.in_features, + mask_type='in_mask')) + self.mutable_in_features = MODELS.build(in_features_cfg) + + out_features_cfg = copy.deepcopy(out_features_cfg) + out_features_cfg.update( + dict( + name=module_name, + num_channels=self.out_features, + mask_type='out_mask')) + self.mutable_out_features = MODELS.build(out_features_cfg) + + @property + def mutable_in(self): + """Mutable `in_features`.""" + return self.mutable_in_features + + @property + def mutable_out(self): + """Mutable `out_features`.""" + return self.mutable_out_features + + def forward(self, input: Tensor) -> Tensor: + """Slice the parameters according to `mutable_in_features` and + `mutable_out_features`, and forward.""" + in_mask = self.mutable_in_features.mask + out_mask = self.mutable_out_features.mask + + weight = self.weight[out_mask][:, in_mask] + bias = self.bias[out_mask] if self.bias is not None else None + + return F.linear(input, weight, bias) + + +class DynamicBatchNorm(_BatchNorm, MutableManagerMixIn): + """Applies Batch Normalization over an input according to the + `mutable_num_features` dynamically. + + Args: + module_name (str): Name of this `DynamicBatchNorm`. + num_features_cfg (Dict): Config related to `num_features`. + """ + + def __init__(self, module_name, num_features_cfg, *args, **kwargs): + super(DynamicBatchNorm, self).__init__(*args, **kwargs) + + num_features_cfg = copy.deepcopy(num_features_cfg) + num_features_cfg.update( + dict( + name=module_name, + num_channels=self.num_features, + mask_type='out_mask')) + self.mutable_num_features = MODELS.build(num_features_cfg) + + @property + def mutable_in(self): + """Mutable `num_features`.""" + return self.mutable_num_features + + @property + def mutable_out(self): + """Mutable `num_features`.""" + return self.mutable_num_features + + def forward(self, input: Tensor) -> Tensor: + """Slice the parameters according to `mutable_num_features`, and + forward.""" + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + if self.num_batches_tracked is not None: # type: ignore + self.num_batches_tracked = \ + self.num_batches_tracked + 1 # type: ignore + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float( + self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is + None) + + if self.affine: + out_mask = self.mutable_num_features.mask + weight = self.weight[out_mask] + bias = self.bias[out_mask] + else: + weight, bias = self.weight, self.bias + + if self.track_running_stats: + out_mask = self.mutable_num_features.mask + running_mean = self.running_mean[out_mask] \ + if not self.training or self.track_running_stats else None + running_var = self.running_var[out_mask] \ + if not self.training or self.track_running_stats else None + else: + running_mean, running_var = self.running_mean, self.running_var + + return F.batch_norm(input, running_mean, running_var, weight, bias, + bn_training, exponential_average_factor, self.eps) + + +class DynamicInstanceNorm(_InstanceNorm, MutableManagerMixIn): + """Applies Instance Normalization over an input according to the + `mutable_num_features` dynamically. + + Args: + module_name (str): Name of this `DynamicInstanceNorm`. + num_features_cfg (Dict): Config related to `num_features`. + """ + + def __init__(self, module_name, num_features_cfg, *args, **kwargs): + super(DynamicInstanceNorm, self).__init__(*args, **kwargs) + + num_features_cfg = copy.deepcopy(num_features_cfg) + num_features_cfg.update( + dict( + name=module_name, + num_channels=self.num_features, + mask_type='out_mask')) + self.mutable_num_features = MODELS.build(num_features_cfg) + + @property + def mutable_in(self): + """Mutable `num_features`.""" + return self.mutable_num_features + + @property + def mutable_out(self): + """Mutable `num_features`.""" + return self.mutable_num_features + + def forward(self, input: Tensor) -> Tensor: + """Slice the parameters according to `mutable_num_features`, and + forward.""" + if self.affine: + out_mask = self.mutable_num_features.mask + weight = self.weight[out_mask] + bias = self.bias[out_mask] + else: + weight, bias = self.weight, self.bias + + if self.track_running_stats: + out_mask = self.mutable_num_features.mask + running_mean = self.running_mean[out_mask] + running_var = self.running_var[out_mask] + else: + running_mean, running_var = self.running_mean, self.running_var + + return F.instance_norm(input, running_mean, running_var, weight, bias, + self.training or not self.track_running_stats, + self.momentum, self.eps) + + +class DynamicGroupNorm(GroupNorm, MutableManagerMixIn): + """Applies Group Normalization over a mini-batch of inputs according to the + `mutable_num_channels` dynamically. + + Args: + module_name (str): Name of this `DynamicGroupNorm`. + num_channels_cfg (Dict): Config related to `num_channels`. + """ + + def __init__(self, module_name, num_channels_cfg, *args, **kwargs): + super(DynamicGroupNorm, self).__init__(*args, **kwargs) + + num_channels_cfg = copy.deepcopy(num_channels_cfg) + num_channels_cfg.update( + dict( + name=module_name, + num_channels=self.num_channels, + mask_type='out_mask')) + self.mutable_num_channels = MODELS.build(num_channels_cfg) + + @property + def mutable_in(self): + """Mutable `num_channels`.""" + return self.mutable_num_channels + + @property + def mutable_out(self): + """Mutable `num_channels`.""" + return self.mutable_num_channels + + def forward(self, input: Tensor) -> Tensor: + """Slice the parameters according to `mutable_num_channels`, and + forward.""" + if self.affine: + out_mask = self.mutable_num_channels.mask + weight = self.weight[out_mask] + bias = self.bias[out_mask] + else: + weight, bias = self.weight, self.bias + + return F.group_norm(input, self.num_groups, weight, bias, self.eps) + + +def build_dynamic_conv2d(module: nn.Conv2d, module_name: str, + in_channels_cfg: Dict, + out_channels_cfg: Dict) -> DynamicConv2d: + """Build DynamicConv2d. + + Args: + module (:obj:`torch.nn.Conv2d`): The original Conv2d module. + module_name (str): Name of this `DynamicConv2d`. + in_channels_cfg (Dict): Config related to `in_channels`. + out_channels_cfg (Dict): Config related to `out_channels`. + """ + dynamic_conv = DynamicConv2d( + module_name=module_name, + in_channels_cfg=in_channels_cfg, + out_channels_cfg=out_channels_cfg, + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + bias=True if module.bias is not None else False, + padding_mode=module.padding_mode) + return dynamic_conv + + +def build_dynamic_linear(module: nn.Linear, module_name: str, + in_features_cfg: Dict, + out_features_cfg: Dict) -> DynamicLinear: + """Build DynamicLinear. + + Args: + module (:obj:`torch.nn.Linear`): The original Linear module. + module_name (str): Name of this `DynamicLinear`. + in_features_cfg (Dict): Config related to `in_features`. + out_features_cfg (Dict): Config related to `out_features`. + """ + dynamic_linear = DynamicLinear( + module_name=module_name, + in_features_cfg=in_features_cfg, + out_features_cfg=out_features_cfg, + in_features=module.in_features, + out_features=module.out_features, + bias=True if module.bias is not None else False) + return dynamic_linear + + +def build_dynamic_bn(module: _BatchNorm, module_name: str, + num_features_cfg: Dict) -> DynamicBatchNorm: + """Build DynamicBatchNorm. + + Args: + module (:obj:`torch.nn._BatchNorm`): The original BatchNorm module. + module_name (str): Name of this `DynamicBatchNorm`. + num_features_cfg (Dict): Config related to `num_features`. + """ + dynamic_bn = DynamicBatchNorm( + module_name=module_name, + num_features_cfg=num_features_cfg, + num_features=module.num_features, + eps=module.eps, + momentum=module.momentum, + affine=module.affine, + track_running_stats=module.track_running_stats) + return dynamic_bn + + +def build_dynamic_in(module: _InstanceNorm, module_name: str, + num_features_cfg: Dict) -> DynamicInstanceNorm: + """Build DynamicInstanceNorm. + + Args: + module (:obj:`torch.nn._InstanceNorm`): The original InstanceNorm + module. + module_name (str): Name of this `DynamicInstanceNorm`. + num_features_cfg (Dict): Config related to `num_features`. + """ + dynamic_in = DynamicInstanceNorm( + module_name=module_name, + num_features_cfg=num_features_cfg, + num_features=module.num_features, + eps=module.eps, + momentum=module.momentum, + affine=module.affine, + track_running_stats=module.track_running_stats) + return dynamic_in + + +def build_dynamic_gn(module: GroupNorm, module_name: str, + num_channels_cfg: Dict) -> DynamicGroupNorm: + """Build DynamicGroupNorm. + + Args: + module (:obj:`torch.nn.GroupNorm`): The original GroupNorm module. + module_name (str): Name of this `DynamicGroupNorm`. + num_channels_cfg (Dict): Config related to `num_channels`. + """ + dynamic_gn = DynamicGroupNorm( + module_name=module_name, + num_channels_cfg=num_channels_cfg, + num_channels=module.num_channels, + num_groups=module.num_groups, + eps=module.eps, + affine=module.affine) + return dynamic_gn diff --git a/mmrazor/models/architectures/dynamic_op/slimmable_dynamic_ops.py b/mmrazor/models/architectures/dynamic_op/slimmable_dynamic_ops.py new file mode 100644 index 00000000..4445f503 --- /dev/null +++ b/mmrazor/models/architectures/dynamic_op/slimmable_dynamic_ops.py @@ -0,0 +1,105 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Callable, Dict + +import torch.nn as nn +from torch.nn.modules.batchnorm import _BatchNorm + +from mmrazor.models.mutables import MutableManagerMixIn +from mmrazor.registry import MODELS +from .default_dynamic_ops import build_dynamic_conv2d, build_dynamic_linear + + +class SwitchableBatchNorm2d(nn.Module, MutableManagerMixIn): + """Employs independent batch normalization for different switches in a + slimmable network. + + To train slimmable networks, ``SwitchableBatchNorm2d`` privatizes all + batch normalization layers for each switch in a slimmable network. + Compared with the naive training approach, it solves the problem of feature + aggregation inconsistency between different switches by independently + normalizing the feature mean and variance during testing. + + Args: + module_name (str): Name of this `SwitchableBatchNorm2d`. + num_features_cfg (Dict): Config related to `num_features`. + eps (float): A value added to the denominator for numerical stability. + Same as that in :obj:`torch.nn._BatchNorm`. Default: 1e-5 + momentum (float): The value used for the running_mean and running_var + computation. Can be set to None for cumulative moving average + (i.e. simple average). Same as that in :obj:`torch.nn._BatchNorm`. + Default: 0.1 + affine (bool): A boolean value that when set to True, this module has + learnable affine parameters. Same as that in + :obj:`torch.nn._BatchNorm`. Default: True + track_running_stats (bool): A boolean value that when set to True, this + module tracks the running mean and variance, and when set to False, + this module does not track such statistics, and initializes + statistics buffers running_mean and running_var as None. When these + buffers are None, this module always uses batch statistics. in both + training and eval modes. Same as that in + :obj:`torch.nn._BatchNorm`. Default: True + """ + + def __init__(self, + module_name: str, + num_features_cfg: Dict, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = True, + track_running_stats: bool = True): + super(SwitchableBatchNorm2d, self).__init__() + + num_features_cfg = copy.deepcopy(num_features_cfg) + candidate_choices = num_features_cfg.pop('candidate_choices') + num_features_cfg.update( + dict( + name=module_name, + num_channels=max(candidate_choices), + mask_type='out_mask')) + + bns = [ + nn.BatchNorm2d(num_features, eps, momentum, affine, + track_running_stats) + for num_features in candidate_choices + ] + self.bns = nn.ModuleList(bns) + + self.mutable_num_features = MODELS.build(num_features_cfg) + + @property + def mutable_out(self): + """Mutable `num_features`.""" + return self.mutable_num_features + + def forward(self, input): + """Forward computation according to the current switch of the slimmable + networks.""" + idx = self.mutable_num_features.current_choice + return self.bns[idx](input) + + +def build_switchable_bn(module: _BatchNorm, module_name: str, + num_features_cfg: Dict) -> SwitchableBatchNorm2d: + """Build SwitchableBatchNorm2d. + + Args: + module (:obj:`torch.nn.GroupNorm`): The original BatchNorm module. + module_name (str): Name of this `SwitchableBatchNorm2d`. + num_channels_cfg (Dict): Config related to `num_features`. + """ + switchable_bn = SwitchableBatchNorm2d( + module_name=module_name, + num_features_cfg=num_features_cfg, + eps=module.eps, + momentum=module.momentum, + affine=module.affine, + track_running_stats=module.track_running_stats) + return switchable_bn + + +SLIMMABLE_DYNAMIC_LAYER: Dict[Callable, Callable] = { + nn.Conv2d: build_dynamic_conv2d, + nn.Linear: build_dynamic_linear, + nn.BatchNorm2d: build_switchable_bn +} diff --git a/mmrazor/models/mutables/__init__.py b/mmrazor/models/mutables/__init__.py index 35875f4b..3c71d272 100644 --- a/mmrazor/models/mutables/__init__.py +++ b/mmrazor/models/mutables/__init__.py @@ -1,9 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. from .diff_mutable import (DiffChoiceRoute, DiffMutable, DiffOP, GumbelChoiceRoute) +from .mutable_channel import (OneShotChannelMutable, OrderChannelMutable, + RatioChannelMutable, SlimmableChannelMutable) +from .mutable_manager_mixin import MutableManagerMixIn from .oneshot_mutable import OneShotMutable, OneShotOP __all__ = [ - 'OneShotOP', 'OneShotMutable', 'DiffOP', 'DiffChoiceRoute', - 'GumbelChoiceRoute', 'DiffMutable' + 'OneShotOP', 'OneShotMutable', 'OneShotChannelMutable', + 'RatioChannelMutable', 'OrderChannelMutable', 'DiffOP', 'DiffChoiceRoute', + 'GumbelChoiceRoute', 'DiffMutable', 'MutableManagerMixIn', + 'SlimmableChannelMutable' ] diff --git a/mmrazor/models/mutables/mutable_channel/__init__.py b/mmrazor/models/mutables/mutable_channel/__init__.py new file mode 100644 index 00000000..9f5af764 --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .one_shot_channel_mutable import OneShotChannelMutable +from .order_channel_mutable import OrderChannelMutable +from .ratio_channel_mutable import RatioChannelMutable +from .slimmable_channel_mutable import SlimmableChannelMutable + +__all__ = [ + 'OneShotChannelMutable', 'OrderChannelMutable', 'RatioChannelMutable', + 'SlimmableChannelMutable' +] diff --git a/mmrazor/models/mutables/mutable_channel/one_shot_channel_mutable.py b/mmrazor/models/mutables/mutable_channel/one_shot_channel_mutable.py new file mode 100644 index 00000000..7fb313f5 --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/one_shot_channel_mutable.py @@ -0,0 +1,157 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod +from typing import Dict, List, Optional + +import numpy as np +import torch +from mmcv.runner import BaseModule + + +class OneShotChannelMutable(BaseModule, ABC): + """A type of ``MUTABLES`` for single path supernet such as AutoSlim. In + single path supernet, each module only has one choice invoked at the same + time. A path is obtained by sampling all the available choices. It is the + base class for one shot channel mutables. + + Args: + name (str): Mutable name. + mask_type (str): One of 'in_mask' or 'out_mask'. + num_channels (int): The raw number of channels. + init_cfg (dict, optional): initialization configuration dict for + ``BaseModule``. OpenMMLab has implement 5 initializer including + `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`, + and `Pretrained`. + """ + + def __init__(self, + name: str, + mask_type: str, + num_channels: int, + init_cfg: Optional[Dict] = None): + super(OneShotChannelMutable, self).__init__(init_cfg=init_cfg) + # If the input of a module is a concatenation of several modules' + # outputs, we add the out_mutable (mask_type == 'out_mask') of + # these modules to the `concat_mutables` of this module. + self.concat_mutables: List[OneShotChannelMutable] = list() + self.name = name + assert mask_type in ('in_mask', 'out_mask') + self.mask_type = mask_type + self.num_channels = num_channels + self.register_buffer('_mask', torch.ones((num_channels, )).bool()) + self._current_choice = num_channels + + self._same_mutables: List[OneShotChannelMutable] = list() + + @property + def same_mutables(self): + """Mutables in `same_mutables` and the current mutable should change + Synchronously.""" + return self._same_mutables + + def register_same_mutable(self, mutable): + """Register the input mutable in `same_mutables`.""" + if isinstance(mutable, list): + # Add a concatenation of mutables to `concat_mutables`. + assert self.mask_type == 'in_mask' + assert all([ + cur_mutable.mask_type == 'out_mask' for cur_mutable in mutable + ]) + self.concat_mutables = mutable + return + + if self == mutable: + return + if mutable in self._same_mutables: + return + + self._same_mutables.append(mutable) + for s_mutable in self._same_mutables: + s_mutable.register_same_mutable(mutable) + mutable.register_same_mutable(s_mutable) + + def sample_choice(self) -> int: + """Sample an arbitrary selection from candidate choices. + + Returns: + int: The chosen number of channels. + """ + assert len(self.concat_mutables) == 0 + num_channels = np.random.choice(self.choices) + assert num_channels > 0, \ + f'Sampled number of channels in `Mutable` {self.name}' \ + f' should be a positive integer.' + return num_channels + + @property + def min_choice(self) -> int: + """Minimum number of channels.""" + assert len(self.concat_mutables) == 0 + min_channels = min(self.choices) + assert min_channels > 0, \ + f'Minimum number of channels in `Mutable` {self.name}' \ + f' should be a positive integer.' + return min_channels + + @property + def max_choice(self) -> int: + """Maximum number of channels.""" + return max(self.choices) + + def get_choice(self, idx: int) -> int: + """Get the `idx`-th choice from candidate choices.""" + assert len(self.concat_mutables) == 0 + num_channels = self.choices[idx] + assert num_channels > 0, \ + f'Number of channels in `Mutable` {self.name}' \ + f' should be a positive integer.' + return num_channels + + @property + def current_choice(self): + """The current choice of the mutable.""" + if len(self.concat_mutables) > 0: + return sum( + [mutable.current_choice for mutable in self.concat_mutables]) + else: + return self._current_choice + + @current_choice.setter + def current_choice(self, choice: int): + """Set the current choice of the mutable.""" + assert choice in self.choices + self._current_choice = choice + + @property + @abstractmethod + def choices(self) -> List[int]: + """list: all choices. """ + + @property + def mask(self): + """The current mask. + + We slice the registered parameters and buffers of a ``nn.Module`` + according to the mask of the corresponding channel mutable. + """ + if len(self.concat_mutables) > 0: + # If the input of a module is a concatenation of several modules' + # outputs, the in_mask of this module is the concatenation of + # these modules' out_mask. + return torch.cat( + [mutable.mask for mutable in self.concat_mutables]) + else: + num_channels = self.current_choice + mask = torch.zeros_like(self._mask).bool() + mask[:num_channels] = True + return mask + + def __repr__(self): + concat_mutable_name = [ + mutable.name for mutable in self.concat_mutables + ] + repr_str = self.__class__.__name__ + repr_str += f'(name={self.name}, ' + repr_str += f'mask_type={self.mask_type}, ' + repr_str += f'num_channels={self.num_channels}, ' + repr_str += f'concat_mutable_name={concat_mutable_name})' + return repr_str diff --git a/mmrazor/models/mutables/mutable_channel/order_channel_mutable.py b/mmrazor/models/mutables/mutable_channel/order_channel_mutable.py new file mode 100644 index 00000000..1c1775d0 --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/order_channel_mutable.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple, Union + +from mmrazor.registry import MODELS +from .one_shot_channel_mutable import OneShotChannelMutable + + +@MODELS.register_module() +class OrderChannelMutable(OneShotChannelMutable): + """A type of ``OneShotChannelMutable``. The input candidate choices are + candidate channel numbers. + + Args: + name (str): Mutable name. + mask_type (str): One of 'in_mask' or 'out_mask'. + num_channels (int): The raw number of channels. + candidate_choices (list | tuple): A list or tuple of candidate + channel numbers. + init_cfg (dict, optional): initialization configuration dict for + ``BaseModule``. OpenMMLab has implement 5 initializer including + `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`, + and `Pretrained`. + """ + + def __init__(self, + name: str, + mask_type: str, + num_channels: int, + candidate_choices: Union[List, Tuple], + init_cfg: Optional[Dict] = None): + super(OrderChannelMutable, self).__init__( + name, mask_type, num_channels, init_cfg=init_cfg) + + assert len(candidate_choices) > 0, \ + f'Number of candidate choices must be greater than 0, ' \ + f'but got: {len(candidate_choices)}' + self._candidate_choices = list(candidate_choices) + + assert all([num > 0 and num <= self.num_channels + for num in self._candidate_choices]), \ + f'The candidate channel numbers should be in ' \ + f'range(0, {self.num_channels}].' + assert all([isinstance(num, int) + for num in self._candidate_choices]),\ + 'Type of `candidate_choices` should be int.' + + @property + def choices(self) -> List[int]: + """list: all choices. """ + return self._candidate_choices diff --git a/mmrazor/models/mutables/mutable_channel/ratio_channel_mutable.py b/mmrazor/models/mutables/mutable_channel/ratio_channel_mutable.py new file mode 100644 index 00000000..d80a8785 --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/ratio_channel_mutable.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple, Union + +from mmrazor.registry import MODELS +from .one_shot_channel_mutable import OneShotChannelMutable + + +@MODELS.register_module() +class RatioChannelMutable(OneShotChannelMutable): + """A type of ``OneShotChannelMutable``. The input candidate choices are + candidate width ratios. + + Notes: + We first calculate the candidate channel numbers according to + the input candidate ratios (`candidate_choices`) and regard them as + available choices. + + Args: + name (str): Mutable name. + mask_type (str): One of 'in_mask' or 'out_mask'. + num_channels (int): The raw number of channels. + candidate_choices (list | tuple): A list or tuple of candidate width + ratios. The width ratio is the ratio between the number of reserved + channels and that of all channels in a layer. + For example, if `ratios` is [0.25, 0.5], there are 2 cases + for us to choose from when we sample from a layer with 12 channels. + One is sampling the very first 3 channels in this layer, another is + sampling the very first 6 channels in this layer. + init_cfg (dict, optional): initialization configuration dict for + ``BaseModule``. OpenMMLab has implement 5 initializer including + `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`, + and `Pretrained`. + """ + + def __init__(self, + name: str, + mask_type: str, + num_channels: int, + candidate_choices: Union[List, Tuple], + init_cfg: Optional[Dict] = None): + super(RatioChannelMutable, self).__init__( + name, mask_type, num_channels, init_cfg=init_cfg) + + assert len(candidate_choices) > 0, \ + f'Number of candidate choices must be greater than 0, ' \ + f'but got: {len(candidate_choices)}' + self._candidate_choices = candidate_choices + + assert all([ + ratio > 0 and ratio <= 1 for ratio in self._candidate_choices + ]), 'The candidate ratio should be in range(0, 1].' + + @property + def choices(self) -> List[int]: + """list: all choices. """ + return [ + round(ratio * self.num_channels) + for ratio in self._candidate_choices + ] diff --git a/mmrazor/models/mutables/mutable_channel/slimmable_channel_mutable.py b/mmrazor/models/mutables/mutable_channel/slimmable_channel_mutable.py new file mode 100644 index 00000000..f5547ac0 --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/slimmable_channel_mutable.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional + +import torch +from mmcv.runner import BaseModule + +from mmrazor.registry import MODELS + + +@MODELS.register_module() +class SlimmableChannelMutable(BaseModule): + """A type of ``MUTABLES`` to train several subnet together, such as the + retraining stage in AutoSlim. + + Notes: + We need to set `candidate_choices` after the instantiation of a + `SlimmableChannelMutable` by ourselves. + + Args: + name (str): Mutable name. + mask_type (str): One of 'in_mask' or 'out_mask'. + num_channels (int): The raw number of channels. + init_cfg (dict, optional): initialization configuration dict for + ``BaseModule``. OpenMMLab has implement 5 initializer including + `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`, + and `Pretrained`. + """ + + def __init__(self, + name: str, + mask_type: str, + num_channels: int, + init_cfg: Optional[Dict] = None): + super(SlimmableChannelMutable, self).__init__(init_cfg=init_cfg) + + self.name = name + assert mask_type in ('in_mask', 'out_mask') + self.mask_type = mask_type + self.num_channels = num_channels + self.register_buffer('_mask', torch.ones((num_channels, )).bool()) + self._current_choice = 0 + + @property + def candidate_choices(self) -> List: + """A list of candidate channel numbers.""" + return self._candidate_choices + + @candidate_choices.setter + def candidate_choices(self, choices): + """Set the candidate channel numbers.""" + assert getattr(self, '_candidate_choices', None) is None, \ + f'candidate_choices can be set only when candidate_choices is ' \ + f'None, got: candidate_choices = {self._candidate_choices}' + + assert all([num > 0 and num <= self.num_channels + for num in choices]), \ + f'The candidate channel numbers should be in ' \ + f'range(0, {self.num_channels}].' + assert all([isinstance(num, int) for num in choices]), \ + 'Type of `candidate_choices` should be int.' + + self._candidate_choices = list(choices) + + @property + def choices(self) -> List: + """Return all subnet indexes.""" + assert self._candidate_choices is not None + return list(range(len(self.candidate_choices))) + + @property + def current_choice(self) -> int: + """The current choice of the mutable.""" + return self._current_choice + + @current_choice.setter + def current_choice(self, choice: int): + """Set the current choice of the mutable.""" + assert choice in self.choices + self._current_choice = choice + + @property + def mask(self): + """The current mask. + + We slice the registered parameters and buffers of a ``nn.Module`` + according to the mask of the corresponding channel mutable. + """ + idx = self.current_choice + num_channels = self.candidate_choices[idx] + mask = torch.zeros_like(self._mask).bool() + mask[:num_channels] = True + return mask diff --git a/mmrazor/models/mutables/mutable_manager_mixin.py b/mmrazor/models/mutables/mutable_manager_mixin.py new file mode 100644 index 00000000..19830ec3 --- /dev/null +++ b/mmrazor/models/mutables/mutable_manager_mixin.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. + + +class MutableManagerMixIn: + """Mixin class for determining whether an object is a dynamic layer. + + Note that a dynamic layer manage one or several mutables. + """ + pass diff --git a/mmrazor/models/mutators/__init__.py b/mmrazor/models/mutators/__init__.py index c7855ef2..26b4d0e0 100644 --- a/mmrazor/models/mutators/__init__.py +++ b/mmrazor/models/mutators/__init__.py @@ -1,5 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .channel_mutator.one_shot_channel_mutator import OneShotChannelMutator +from .channel_mutator.slimmable_channel_mutator import SlimmableChannelMutator from .diff_mutator import DiffMutator from .one_shot_mutator import OneShotMutator -__all__ = ['OneShotMutator', 'DiffMutator'] +__all__ = [ + 'OneShotMutator', 'OneShotChannelMutator', 'SlimmableChannelMutator', + 'DiffMutator' +] diff --git a/mmrazor/models/mutators/base_mutator.py b/mmrazor/models/mutators/base_mutator.py index b87ec691..adb53e37 100644 --- a/mmrazor/models/mutators/base_mutator.py +++ b/mmrazor/models/mutators/base_mutator.py @@ -38,7 +38,7 @@ class BaseMutator(ABC, BaseModule): @property @abstractmethod - def search_group(self) -> Dict: + def search_groups(self) -> Dict: """Search group of the supernet. Note: @@ -76,7 +76,7 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]): if custom_group is None: custom_group = [] self._custom_group = custom_group - self._search_group: Optional[Dict[int, List[MUTABLE_TYPE]]] = None + self._search_groups: Optional[Dict[int, List[MUTABLE_TYPE]]] = None # TODO # should be a class property @@ -99,10 +99,10 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]): supernet (:obj:`torch.nn.Module`): The supernet to be searched in your algorithm. """ - self._build_search_group(supernet) + self._build_search_groups(supernet) @property - def search_group(self) -> Dict[int, List[MUTABLE_TYPE]]: + def search_groups(self) -> Dict[int, List[MUTABLE_TYPE]]: """Search group of supernet. Note: @@ -115,10 +115,10 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]): Returns: Dict[int, List[MUTABLE_TYPE]]: Search group. """ - if self._search_group is None: + if self._search_groups is None: raise RuntimeError( 'Call `prepare_from_supernet` before access search group!') - return self._search_group + return self._search_groups def _build_name_mutable_mapping( self, supernet: Module) -> Dict[str, MUTABLE_TYPE]: @@ -143,7 +143,7 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]): return alias2mutable_names - def _build_search_group(self, supernet: Module) -> None: + def _build_search_groups(self, supernet: Module) -> None: """Build search group with ``custom_group`` and ``alias``(see more information in :class:`BaseMutable`). Grouping by alias and module name are both supported. @@ -176,20 +176,20 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]): >>> # Using alias for grouping >>> mutator = DiffOP(custom_group=[['a1'], ['a2']]) >>> mutator.prepare_from_supernet(model) - >>> mutator.search_group + >>> mutator.search_groups {0: [op1, op2], 1: [op3]} >>> # Using module name for grouping >>> mutator = DiffOP(custom_group=[['op1', 'op2'], ['op3']]) >>> mutator.prepare_from_supernet(model) - >>> mutator.search_group + >>> mutator.search_groups {0: [op1, op2], 1: [op3]} >>> # Using both alias and module name for grouping >>> mutator = DiffOP(custom_group=[['a2'], ['op2']]) >>> mutator.prepare_from_supernet(model) >>> # The last operation would be grouped - >>> mutator.search_group + >>> mutator.search_groups {0: [op3], 1: [op2], 2: [op1]} @@ -271,7 +271,7 @@ class ArchitectureMutator(BaseMutator, Generic[MUTABLE_TYPE]): f'The duplicate keys are {duplicate_keys}. ' \ 'Please check if there are duplicate keys in the `custom_group`.' - self._search_group = search_groups + self._search_groups = search_groups def _check_valid_groups(self, alias2mutable_names: Dict[str, List[str]], name2mutable: Dict[str, MUTABLE_TYPE], diff --git a/mmrazor/models/mutators/channel_mutator/__init__.py b/mmrazor/models/mutators/channel_mutator/__init__.py new file mode 100644 index 00000000..4eb08806 --- /dev/null +++ b/mmrazor/models/mutators/channel_mutator/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .channel_mutator import ChannelMutator +from .one_shot_channel_mutator import OneShotChannelMutator +from .slimmable_channel_mutator import SlimmableChannelMutable + +__all__ = [ + 'ChannelMutator', 'OneShotChannelMutator', 'SlimmableChannelMutable' +] diff --git a/mmrazor/models/mutators/channel_mutator/channel_mutator.py b/mmrazor/models/mutators/channel_mutator/channel_mutator.py new file mode 100644 index 00000000..2ef7e5e3 --- /dev/null +++ b/mmrazor/models/mutators/channel_mutator/channel_mutator.py @@ -0,0 +1,277 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from abc import abstractmethod +from typing import Callable, Dict, List, Optional + +import torch.nn as nn +from torch.nn import Module +from torch.nn.modules import GroupNorm +from torch.nn.modules.batchnorm import (BatchNorm1d, BatchNorm2d, BatchNorm3d, + _BatchNorm) +from torch.nn.modules.instancenorm import (InstanceNorm1d, InstanceNorm2d, + InstanceNorm3d, _InstanceNorm) + +from mmrazor.core.tracer import (ConcatNode, ConvNode, DepthWiseConvNode, + LinearNode, NormNode, PathList) +from mmrazor.models.architectures.dynamic_op import (build_dynamic_bn, + build_dynamic_conv2d, + build_dynamic_gn, + build_dynamic_in, + build_dynamic_linear) +from mmrazor.models.mutables import OneShotChannelMutable +from mmrazor.registry import MODELS, TASK_UTILS +from ..base_mutator import BaseMutator + +NONPASS_NODES = (ConvNode, LinearNode, ConcatNode) +PASS_NODES = (NormNode, DepthWiseConvNode) + +NONPASS_MODULES = (nn.Conv2d, nn.Linear) +PASS_MODULES = (_BatchNorm, _InstanceNorm, GroupNorm) + + +@MODELS.register_module() +class ChannelMutator(BaseMutator): + """Base class for channel-based mutators. + + Args: + mutable_cfg (dict): The config for the channel mutable. + tracer_cfg (dict | Optional): The config for the model tracer. + We Trace the topology of a given model with the tracer. + skip_prefixes (List[str] | Optional): The module whose name start with + a string in skip_prefixes will not be pruned. + init_cfg (dict, optional): The config to control the initialization. + + Attributes: + search_groups (Dict[int, List]): Search group of supernet. Note that + the search group of a mutable based channel mutator is composed of + corresponding mutables. Mutables in the same search group should + be pruned together. + name2module (Dict[str, :obj:`torch.nn.Module`]): The mapping from + a module name to the module. + + Notes: + # To avoid ambiguity, we only allow the following two cases: + # 1. None of the parent nodes of a node is a `ConcatNode` + # 2. A node has only one parent node which is a `ConcatNode` + """ + + def __init__( + self, + mutable_cfg: Dict, + tracer_cfg: Optional[Dict] = None, + skip_prefixes: Optional[List[str]] = None, + init_cfg: Optional[Dict] = None, + ) -> None: + super().__init__(init_cfg) + + self.mutable_cfg = mutable_cfg + if tracer_cfg: + self.tracer = TASK_UTILS.build(tracer_cfg) + else: + self.tracer = None + self.skip_prefixes = skip_prefixes + self._search_groups: Optional[Dict[int, List[Module]]] = None + + def add_link(self, path_list: PathList) -> None: + """Establish the relationship between the current nodes and their + parents.""" + for path in path_list: + pre_node = None + for node in path: + if isinstance(node, DepthWiseConvNode): + module = self.name2module[node.name] + # The in_channels and out_channels of a depth-wise conv + # should be the same + module.mutable_out.register_same_mutable(module.mutable_in) + module.mutable_in.register_same_mutable(module.mutable_out) + + if isinstance(node, ConcatNode): + if pre_node is not None: + module_names = node.get_module_names() + concat_modules = [ + self.name2module[name] for name in module_names + ] + concat_mutables = [ + module.mutable_out for module in concat_modules + ] + pre_module = self.name2module[pre_node.name] + pre_module.mutable_in.register_same_mutable( + concat_mutables) + + for cur_path_list in node: + self.add_link(cur_path_list) + + # ConcatNode is the last node in a path + break + + if pre_node is None: + pre_node = node + continue + + pre_module = self.name2module[pre_node.name] + cur_module = self.name2module[node.name] + pre_module.mutable_in.register_same_mutable( + cur_module.mutable_out) + cur_module.mutable_out.register_same_mutable( + pre_module.mutable_in) + + pre_node = node + + def prepare_from_supernet(self, supernet: Module) -> None: + """Do some necessary preparations with supernet. + + We support the following two cases: + + Case 1: The input is the original nn.Module. We first replace the + conv/linear/norm modules in the input supernet with dynamic ops. + And trace the topology of the supernet. Finally, `search_groups` can be + built based on the topology. + + Case 2: The input supernet is made up of dynamic ops. In this case, + relationship between nodes and their parents must have been + established and topology of the supernet is available for us. Then + `search_groups` can be built based on the topology. + + Args: + supernet (:obj:`torch.nn.Module`): The supernet to be searched + in your algorithm. + """ + if self.tracer is not None: + self.convert_dynamic_module(supernet, self.dynamic_layer) + # The mapping from a module name to the module + self._name2module = dict(supernet.named_modules()) + + assert self.tracer is not None + module_path_list: PathList = self.tracer.trace(supernet) + + self.add_link(module_path_list) + else: + self._name2module = dict(supernet.named_modules()) + + self._search_groups = self.build_search_groups(supernet) + + @staticmethod + def find_same_mutables(supernet) -> Dict: + """The mutables in the same group should be pruned together.""" + visited = [] + groups = {} + group_idx = 0 + for name, module in supernet.named_modules(): + if isinstance(module, OneShotChannelMutable): + same_mutables = module.same_mutables + if module not in visited and len(same_mutables) > 0: + groups[group_idx] = [module] + same_mutables + visited.extend(groups[group_idx]) + group_idx += 1 + return groups + + def convert_dynamic_module(self, supernet: Module, dynamic_layer: Dict): + """Replace the conv/linear/norm modules in the input supernet with + dynamic ops. + + Args: + supernet (:obj:`torch.nn.Module`): The architecture to be converted + in your algorithm. + dynamic_layer (Dict): The mapping from the module type to the + corresponding dynamic layer. + """ + + def traverse(module, prefix): + for name, child in module.named_children(): + module_name = prefix + name + if isinstance(child, NONPASS_MODULES): + mutable_cfg = copy.deepcopy(self.mutable_cfg) + # mutable_cfg.update(dict(name=module_name)) + layer = dynamic_layer[type(child)](child, module_name, + mutable_cfg, + mutable_cfg) + setattr(module, name, layer) + elif isinstance(child, PASS_MODULES): + mutable_cfg = copy.deepcopy(self.mutable_cfg) + # mutable_cfg.update(dict(name=module_name)) + layer = dynamic_layer[type(child)](child, module_name, + mutable_cfg) + setattr(module, name, layer) + else: + traverse(child, module_name + '.') + + traverse(supernet, '') + + @abstractmethod + def build_search_groups(self, supernet: Module): + """Build `search_groups`. + + The mutables in the same group should be pruned together. + """ + + @property + def search_groups(self) -> Dict[int, List]: + """Search group of supernet. + + Note: + For mutable based mutator, the search group is composed of + corresponding mutables. + + Raises: + RuntimeError: Called before search group has been built. + + Returns: + Dict[int, List[MUTABLE_TYPE]]: Search group. + """ + if self._search_groups is None: + raise RuntimeError( + 'Call `search_groups` before access `build_search_groups`!') + return self._search_groups + + @property + def name2module(self): + """The mapping from a module name to the module. + + Returns: + dict: The name to module mapping. + """ + if hasattr(self, '_name2module'): + return self._name2module + else: + raise RuntimeError('Called before access `prepare_from_supernet`!') + + @property + def dynamic_layer(self) -> Dict: + """The mapping from a type to the corresponding dynamic layer. It is + called in `prepare_from_supernet`. + + Returns: + dict: The mapping dict. + """ + + dynamic_layer: Dict[Callable, Callable] = { + nn.Conv2d: build_dynamic_conv2d, + nn.Linear: build_dynamic_linear, + BatchNorm1d: build_dynamic_bn, + BatchNorm2d: build_dynamic_bn, + BatchNorm3d: build_dynamic_bn, + InstanceNorm1d: build_dynamic_in, + InstanceNorm2d: build_dynamic_in, + InstanceNorm3d: build_dynamic_in, + GroupNorm: build_dynamic_gn + } + + return dynamic_layer + + def is_skip_pruning(self, module_name: str, + skip_prefixes: Optional[List[str]]) -> bool: + """Judge if the module with the input `module_name` should not be + pruned. + + Args: + module_name (str): Module name. + skip_prefixes (list or None): The module whose name start with + a string in skip_prefixes will not be prune. + """ + skip_pruning = False + if skip_prefixes: + for prefix in skip_prefixes: + if module_name.startswith(prefix): + skip_pruning = True + break + return skip_pruning diff --git a/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py b/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py new file mode 100644 index 00000000..4c08ff1e --- /dev/null +++ b/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py @@ -0,0 +1,135 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Any, Dict, List, Optional + +from torch.nn import Module + +from mmrazor.core.tracer import (ConcatNode, ConvNode, DepthWiseConvNode, + LinearNode, NormNode) +from mmrazor.registry import MODELS +from .channel_mutator import ChannelMutator + +NONPASS_NODES = (ConvNode, LinearNode, ConcatNode) +PASS_NODES = (NormNode, DepthWiseConvNode) +PRUNING_NODES = (ConvNode, LinearNode) + + +@MODELS.register_module() +class OneShotChannelMutator(ChannelMutator): + """One-shot channel mutable based channel mutator. + + Args: + mutable_cfg (dict): The config for the channel mutable. + tracer_cfg (dict): The config for the model tracer. We Trace the + topology of a given model with the tracer. + skip_prefixes (List[str] | Optional): The module whose name start with + a string in skip_prefixes will not be pruned. + init_cfg (dict, optional): The config to control the initialization. + """ + + def __init__(self, + mutable_cfg: Dict, + tracer_cfg: Optional[Dict] = None, + skip_prefixes: Optional[List[str]] = None, + init_cfg: Optional[Dict] = None) -> None: + super().__init__(mutable_cfg, tracer_cfg, skip_prefixes, init_cfg) + + def sample_choices(self): + """Sample a choice that records a selection from the search space. + + Returns: + dict: Record the information to build the subnet from the supernet. + Its keys are the properties ``group_idx`` in the channel + mutator's ``search_groups``, and its values are the sampled + choice. + """ + choice_dict = dict() + for group_idx, mutables in self.search_groups.items(): + choice_dict[group_idx] = mutables[0].sample_choice() + return choice_dict + + def set_choices(self, choice_dict: Dict[int, Any]) -> None: + """Set current subnet according to ``choice_dict``. + + Args: + choice_dict (Dict[int, Any]): Choice dict. + """ + for group_idx, choice in choice_dict.items(): + mutables = self.search_groups[group_idx] + for mutable in mutables: + mutable.current_choice = choice + + def set_max_choices(self) -> None: + """Set the channel numbers of each layer to maximum.""" + for mutables in self.search_groups.values(): + for mutable in mutables: + mutable.current_choice = mutable.max_choice + + def set_min_choices(self) -> None: + """Set the channel numbers of each layer to minimum.""" + for mutables in self.search_groups.values(): + for mutable in mutables: + mutable.current_choice = mutable.min_choice + + # todo: check search gorups + def build_search_groups(self, supernet: Module): + """Build `search_groups`. The mutables in the same group should be + pruned together. + + Examples: + >>> class ResBlock(nn.Module): + ... def __init__(self) -> None: + ... super().__init__() + ... + ... self.op1 = nn.Conv2d(3, 8, 1) + ... self.bn1 = nn.BatchNorm2d(8) + ... self.op2 = nn.Conv2d(8, 8, 1) + ... self.bn2 = nn.BatchNorm2d(8) + ... self.op3 = nn.Conv2d(8, 8, 1) + ... + ... def forward(self, x): + ... x1 = self.bn1(self.op1(x)) + ... x2 = self.bn2(self.op2(x1)) + ... x3 = self.op3(x2 + x1) + ... return x3 + + >>> class ToyPseudoLoss: + ... + ... def __call__(self, model): + ... pseudo_img = torch.rand(2, 3, 16, 16) + ... pseudo_output = model(pseudo_img) + ... return pseudo_output.sum() + + >>> mutator = OneShotChannelMutator( + ... tracer_cfg=dict(type='BackwardTracer', + ... loss_calculator=ToyPseudoLoss()), + ... mutable_cfg=dict(type='RatioChannelMutable', + ... candidate_choices=[4 / 8, 1.0]) + + >>> model = ResBlock() + >>> mutator.prepare_from_supernet(model) + >>> mutator.search_groups + {0: [RatioChannelMutable(name=op2, mask_type=out_mask, ...), + RatioChannelMutable(name=op1, mask_type=out_mask, ...), + RatioChannelMutable(name=op3, mask_type=in_mask, ...), + RatioChannelMutable(name=op2, mask_type=in_mask, ...), + RatioChannelMutable(name=bn2, mask_type=out_mask, ...), + RatioChannelMutable(name=bn1, mask_type=out_mask, ...)]} + """ + groups = self.find_same_mutables(supernet) + + search_groups = dict() + group_idx = 0 + for group in groups.values(): + is_skip = False + for mutable in group: + if self.is_skip_pruning(mutable.name, self.skip_prefixes): + warnings.warn(f'Group {group} is not searchable due to' + f' skip_prefixes: {self.skip_prefixes}') + is_skip = True + break + if not is_skip: + search_groups[group_idx] = group + group_idx += 1 + + return search_groups diff --git a/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py b/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py new file mode 100644 index 00000000..ec96d97c --- /dev/null +++ b/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Dict, List, Optional + +import torch.nn as nn +from torch.nn import Module +from torch.nn.modules.batchnorm import _BatchNorm + +from mmrazor.models.architectures.dynamic_op import (DynamicBatchNorm, + build_switchable_bn) +from mmrazor.models.mutables import SlimmableChannelMutable +from mmrazor.registry import MODELS +from .channel_mutator import ChannelMutator + +NONPASS_MODULES = (nn.Conv2d, nn.Linear) +PASS_MODULES = (_BatchNorm, ) + + +@MODELS.register_module() +class SlimmableChannelMutator(ChannelMutator): + """Slimmable channel mutable based channel mutator. + + Args: + channel_cfgs (list[Dict]): A list of candidate channel configs. + mutable_cfg (dict): The config for the channel mutable. + skip_prefixes (List[str] | Optional): The module whose name start with + a string in skip_prefixes will not be pruned. + init_cfg (dict, optional): The config to control the initialization. + """ + + def __init__(self, + channel_cfgs: List[Dict], + mutable_cfg: Dict, + skip_prefixes: Optional[List[str]] = None, + init_cfg: Optional[Dict] = None): + super(SlimmableChannelMutator, self).__init__( + mutable_cfg=mutable_cfg, + skip_prefixes=skip_prefixes, + init_cfg=init_cfg) + + self.channel_cfgs = self._merge_channel_cfgs(channel_cfgs) + + def _merge_channel_cfgs(self, channel_cfgs: List[Dict]): + """Merge several channel configs. + + Args: + channel_cfgs (List[Dict]) + """ + merged_channel_cfg = dict() + num_subnet = len(channel_cfgs) + + for module_name in channel_cfgs[0].keys(): + channels_per_layer = [ + channel_cfgs[idx][module_name] for idx in range(num_subnet) + ] + merged_channels_per_layer = dict() + for key in channels_per_layer[0].keys(): + merged_channels = [ + channels_per_layer[idx][key] for idx in range(num_subnet) + ] + merged_channels_per_layer[key] = merged_channels + merged_channel_cfg[module_name] = merged_channels_per_layer + + return merged_channel_cfg + + def prepare_from_supernet(self, supernet: Module) -> None: + """Do some necessary preparations with supernet. + + Note: + Different from `ChannelMutator`, we only support Case 1 in + `ChannelMutator`. The input supernet should be made up of original + nn.Module. And we replace the conv/linear/bn modules in the input + supernet with dynamic ops first. Then we convert the + ``DynamicBatchNorm`` in supernet with ``SwitchableBatchNorm2d``. + Finally, we set the candidate channel numbers to the corresponding + `SlimmableChannelMutable`. + + Args: + supernet (:obj:`torch.nn.Module`): The supernet to be searched + in your algorithm. + """ + self.convert_dynamic_module(supernet, self.dynamic_layer) + print(supernet) + self.convert_switchable_bn(supernet) + self.set_candidate_choices(supernet) + # The mapping from a module name to the module + self._name2module = dict(supernet.named_modules()) + + def set_candidate_choices(self, supernet): + """Set the ``candidate_choices`` of each ``SlimmableChannelMutable``. + + Notes: + Different from other ``OneShotChannelMutable``, + ``candidate_choices`` is optional when instantiating a + ``SlimmableChannelMutable`` + """ + for name, module in supernet.named_modules(): + if isinstance(module, SlimmableChannelMutable): + candidate_choices = self.channel_cfgs[name]['current_choice'] + module.candidate_choices = candidate_choices + + def convert_switchable_bn(self, supernet): + """Replace ``DynamicBatchNorm`` in supernet with + ``SwitchableBatchNorm2d``. + + Args: + supernet (:obj:`torch.nn.Module`): The architecture to be converted + in your algorithm. + """ + + def traverse(module, prefix): + for name, child in module.named_children(): + module_name = prefix + name + if isinstance(child, DynamicBatchNorm): + mutable_cfg = copy.deepcopy(self.mutable_cfg) + key = module_name + '.mutable_num_features' + candidate_choices = self.channel_cfgs[key][ + 'current_choice'] + mutable_cfg.update( + dict(candidate_choices=candidate_choices)) + sbn = build_switchable_bn(child, module_name, mutable_cfg) + setattr(module, name, sbn) + else: + traverse(child, module_name + '.') + + traverse(supernet, '') + + def switch_choices(self, idx): + """Switch the channel config of the supernet according to input `idx`. + + If we train more than one subnet together, we need to switch the + channel_cfg from one to another during one training iteration. + + Args: + idx (int): The index of the current subnet. + """ + for name, module in self.name2module.items(): + if hasattr(module, 'mutable_out'): + module.mutable_out.current_choice = idx + if hasattr(module, 'mutable_in'): + module.mutable_in.current_choice = idx + + def build_search_groups(self, supernet: Module): + """Build `search_groups`. + + The mutables in the same group should be pruned together. + """ + pass diff --git a/mmrazor/models/mutators/diff_mutator.py b/mmrazor/models/mutators/diff_mutator.py index 40542870..e645098b 100644 --- a/mmrazor/models/mutators/diff_mutator.py +++ b/mmrazor/models/mutators/diff_mutator.py @@ -45,12 +45,12 @@ class DiffMutator(ArchitectureMutator[DiffMutable]): group_id share the same arch param. Returns: - torch.nn.ParameterDict: the arch params are got by `search_group`. + torch.nn.ParameterDict: the arch params are got by `search_groups`. """ arch_params: Dict[int, nn.Parameter] = dict() - for group_id, modules in self.search_group.items(): + for group_id, modules in self.search_groups.items(): group_arch_param = modules[0].build_arch_param() arch_params[group_id] = group_arch_param @@ -65,7 +65,7 @@ class DiffMutator(ArchitectureMutator[DiffMutable]): `DiffMutable`. """ - for group_id, modules in self.search_group.items(): + for group_id, modules in self.search_groups.items(): for module in modules: module.set_forward_args(arch_param=self.arch_params[group_id]) diff --git a/mmrazor/models/mutators/one_shot_mutator.py b/mmrazor/models/mutators/one_shot_mutator.py index 6aabf527..8650d82b 100644 --- a/mmrazor/models/mutators/one_shot_mutator.py +++ b/mmrazor/models/mutators/one_shot_mutator.py @@ -25,7 +25,7 @@ class OneShotMutator(ArchitectureMutator[OneShotMutable]): ['op1', 'op2', 'op3'] >>> mutator.prepare_from_supernet(supernet) - >>> mutator.search_group + >>> mutator.search_groups {0: [op1], 1: [op2], 2: [op3]} >>> random_choices = mutator.sample_choices() @@ -61,7 +61,7 @@ class OneShotMutator(ArchitectureMutator[OneShotMutable]): Dict[int, Any]: Random choices dict. """ random_choices = dict() - for group_id, modules in self.search_group.items(): + for group_id, modules in self.search_groups.items(): random_choices[group_id] = modules[0].sample_choice() return random_choices @@ -75,7 +75,7 @@ class OneShotMutator(ArchitectureMutator[OneShotMutable]): search groups, and the value is the sampling results corresponding to this group. """ - for group_id, modules in self.search_group.items(): + for group_id, modules in self.search_groups.items(): choice = choices[group_id] for module in modules: module.current_choice = choice diff --git a/mmrazor/models/pruners/__init__.py b/mmrazor/models/pruners/__init__.py deleted file mode 100644 index cb0480d4..00000000 --- a/mmrazor/models/pruners/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .ratio_pruning import RatioPruner -from .structure_pruning import StructurePruner -from .utils import * # noqa: F401,F403 - -__all__ = ['RatioPruner', 'StructurePruner'] diff --git a/mmrazor/models/pruners/ratio_pruning.py b/mmrazor/models/pruners/ratio_pruning.py deleted file mode 100644 index 4486a889..00000000 --- a/mmrazor/models/pruners/ratio_pruning.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import numpy as np -import torch -import torch.nn as nn -from torch.nn.modules import GroupNorm - -from mmrazor.registry import MODELS -from .structure_pruning import StructurePruner -from .utils import SwitchableBatchNorm2d - - -@MODELS.register_module() -class RatioPruner(StructurePruner): - """A random ratio pruner. - - Each layer can adjust its own width ratio randomly and independently. - - Args: - ratios (list | tuple): Width ratio of each layer can be - chosen from `ratios` randomly. The width ratio is the ratio between - the number of reserved channels and that of all channels in a - layer. For example, if `ratios` is [0.25, 0.5], there are 2 cases - for us to choose from when we sample from a layer with 12 channels. - One is sampling the very first 3 channels in this layer, another is - sampling the very first 6 channels in this layer. Default to None. - """ - - def __init__(self, ratios, **kwargs): - super(RatioPruner, self).__init__(**kwargs) - ratios = list(ratios) - ratios.sort() - self.ratios = ratios - self.min_ratio = ratios[0] - - def _check_pruner(self, supernet): - for module in supernet.model.modules(): - if isinstance(module, GroupNorm): - num_channels = module.num_channels - num_groups = module.num_groups - for ratio in self.ratios: - new_channels = int(round(num_channels * ratio)) - assert (num_channels * ratio) % num_groups == 0, \ - f'Expected number of channels in input of GroupNorm ' \ - f'to be divisible by num_groups, but number of ' \ - f'channels may be {new_channels} according to ' \ - f'ratio {ratio} and num_groups={num_groups}' - - def prepare_from_supernet(self, supernet): - super(RatioPruner, self).prepare_from_supernet(supernet) - - def get_channel_mask(self, out_mask): - """Randomly choose a width ratio of a layer from ``ratios``""" - out_channels = out_mask.size(1) - random_ratio = np.random.choice(self.ratios) - new_channels = int(round(out_channels * random_ratio)) - assert new_channels > 0, \ - 'Output channels should be a positive integer.' - new_out_mask = torch.zeros_like(out_mask) - new_out_mask[:, :new_channels] = 1 - - return new_out_mask - - def sample_subnet(self): - """Random sample subnet by random mask. - - Returns: - dict: Record the information to build the subnet from the supernet, - its keys are the properties ``space_id`` in the pruner's search - spaces, and its values are corresponding sampled out_mask. - """ - subnet_dict = dict() - for space_id, out_mask in self.channel_spaces.items(): - subnet_dict[space_id] = self.get_channel_mask(out_mask) - return subnet_dict - - def set_min_channel(self): - """Set the number of channels each layer to minimum.""" - subnet_dict = dict() - for space_id, out_mask in self.channel_spaces.items(): - out_channels = out_mask.size(1) - random_ratio = self.min_ratio - new_channels = int(round(out_channels * random_ratio)) - assert new_channels > 0, \ - 'Output channels should be a positive integer.' - new_out_mask = torch.zeros_like(out_mask) - new_out_mask[:, :new_channels] = 1 - - subnet_dict[space_id] = new_out_mask - - self.set_subnet(subnet_dict) - - def switch_subnet(self, channel_cfg, subnet_ind=None): - """Switch the channel config of the supernet according to channel_cfg. - - If we train more than one subnet together, we need to switch the - channel_cfg from one to another during one training iteration. - - Args: - channel_cfg (dict): The channel config of a subnet. Key is space_id - and value is a dict which includes out_channels (and - in_channels if exists). - subnet_ind (int, optional): The index of the current subnet. If - we replace normal BatchNorm2d with ``SwitchableBatchNorm2d``, - we should switch the index of ``SwitchableBatchNorm2d`` when - switch subnet. Defaults to None. - """ - subnet_dict = dict() - for name, channels_per_layer in channel_cfg.items(): - module = self.name2module[name] - if (isinstance(module, SwitchableBatchNorm2d) - and subnet_ind is not None): - # When switching bn we should switch index simultaneously - module.index = subnet_ind - continue - - out_channels = channels_per_layer['out_channels'] - out_mask = torch.zeros_like(module.out_mask) - out_mask[:, :out_channels] = 1 - - space_id = self.get_space_id(name) - if space_id in subnet_dict: - assert torch.equal(subnet_dict[space_id], out_mask) - elif space_id is not None: - subnet_dict[space_id] = out_mask - - self.set_subnet(subnet_dict) - - def convert_switchable_bn(self, module, num_bns): - """Convert normal ``nn.BatchNorm2d`` to ``SwitchableBatchNorm2d``. - - Args: - module (:obj:`torch.nn.Module`): The module to be converted. - num_bns (int): The number of ``nn.BatchNorm2d`` in a - ``SwitchableBatchNorm2d``. - - Return: - :obj:`torch.nn.Module`: The converted module. Each - ``nn.BatchNorm2d`` in this module has been converted to a - ``SwitchableBatchNorm2d``. - """ - module_output = module - if isinstance(module, nn.modules.batchnorm._BatchNorm): - module_output = SwitchableBatchNorm2d(module.num_features, num_bns) - - for name, child in module.named_children(): - module_output.add_module( - name, self.convert_switchable_bn(child, num_bns)) - - del module - return module_output diff --git a/mmrazor/models/pruners/utils/__init__.py b/mmrazor/models/pruners/utils/__init__.py deleted file mode 100644 index 6da7c245..00000000 --- a/mmrazor/models/pruners/utils/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .switchable_bn import SwitchableBatchNorm2d - -__all__ = ['SwitchableBatchNorm2d'] diff --git a/mmrazor/models/pruners/utils/switchable_bn.py b/mmrazor/models/pruners/utils/switchable_bn.py deleted file mode 100644 index 686a04eb..00000000 --- a/mmrazor/models/pruners/utils/switchable_bn.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch.nn as nn - - -class SwitchableBatchNorm2d(nn.Module): - """Employs independent batch normalization for different switches in a - slimmable network. - - To train slimmable networks, ``SwitchableBatchNorm2d`` privatizes all - batch normalization layers for each switch in a slimmable network. - Compared with the naive training approach, it solves the problem of feature - aggregation inconsistency between different switches by independently - normalizing the feature mean and variance during testing. - - Args: - max_num_features (int): The maximum ``num_features`` among BatchNorm2d - in all the switches. - num_bns (int): The number of different switches in the slimmable - networks. - """ - - def __init__(self, max_num_features, num_bns): - super(SwitchableBatchNorm2d, self).__init__() - - self.max_num_features = max_num_features - # number of BatchNorm2d in a SwitchableBatchNorm2d - self.num_bns = num_bns - bns = [] - for _ in range(num_bns): - bns.append(nn.BatchNorm2d(max_num_features)) - self.bns = nn.ModuleList(bns) - # When switching bn we should switch index simultaneously - self.index = 0 - - def forward(self, input): - """Forward computation according to the current switch of the slimmable - networks.""" - return self.bns[self.index](input) diff --git a/tests/data/MBV2_220M.yaml b/tests/data/MBV2_220M.yaml new file mode 100644 index 00000000..2aa967c8 --- /dev/null +++ b/tests/data/MBV2_220M.yaml @@ -0,0 +1,474 @@ +backbone.conv1.bn.mutable_num_features: + current_choice: 8 + origin_channels: 48 +backbone.conv1.conv.mutable_in_channels: + current_choice: 3 + origin_channels: 3 +backbone.conv1.conv.mutable_out_channels: + current_choice: 8 + origin_channels: 48 +backbone.conv2.bn.mutable_num_features: + current_choice: 1920 + origin_channels: 1920 +backbone.conv2.conv.mutable_in_channels: + current_choice: 280 + origin_channels: 480 +backbone.conv2.conv.mutable_out_channels: + current_choice: 1920 + origin_channels: 1920 +backbone.layer1.0.conv.0.bn.mutable_num_features: + current_choice: 8 + origin_channels: 48 +backbone.layer1.0.conv.0.conv.mutable_in_channels: + current_choice: 8 + origin_channels: 48 +backbone.layer1.0.conv.0.conv.mutable_out_channels: + current_choice: 8 + origin_channels: 48 +backbone.layer1.0.conv.1.bn.mutable_num_features: + current_choice: 8 + origin_channels: 24 +backbone.layer1.0.conv.1.conv.mutable_in_channels: + current_choice: 8 + origin_channels: 48 +backbone.layer1.0.conv.1.conv.mutable_out_channels: + current_choice: 8 + origin_channels: 24 +backbone.layer2.0.conv.0.bn.mutable_num_features: + current_choice: 96 + origin_channels: 144 +backbone.layer2.0.conv.0.conv.mutable_in_channels: + current_choice: 8 + origin_channels: 24 +backbone.layer2.0.conv.0.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 144 +backbone.layer2.0.conv.1.bn.mutable_num_features: + current_choice: 96 + origin_channels: 144 +backbone.layer2.0.conv.1.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 144 +backbone.layer2.0.conv.1.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 144 +backbone.layer2.0.conv.2.bn.mutable_num_features: + current_choice: 16 + origin_channels: 40 +backbone.layer2.0.conv.2.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 144 +backbone.layer2.0.conv.2.conv.mutable_out_channels: + current_choice: 16 + origin_channels: 40 +backbone.layer2.1.conv.0.bn.mutable_num_features: + current_choice: 96 + origin_channels: 240 +backbone.layer2.1.conv.0.conv.mutable_in_channels: + current_choice: 16 + origin_channels: 40 +backbone.layer2.1.conv.0.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer2.1.conv.1.bn.mutable_num_features: + current_choice: 96 + origin_channels: 240 +backbone.layer2.1.conv.1.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer2.1.conv.1.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer2.1.conv.2.bn.mutable_num_features: + current_choice: 16 + origin_channels: 40 +backbone.layer2.1.conv.2.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer2.1.conv.2.conv.mutable_out_channels: + current_choice: 16 + origin_channels: 40 +backbone.layer3.0.conv.0.bn.mutable_num_features: + current_choice: 96 + origin_channels: 240 +backbone.layer3.0.conv.0.conv.mutable_in_channels: + current_choice: 16 + origin_channels: 40 +backbone.layer3.0.conv.0.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer3.0.conv.1.bn.mutable_num_features: + current_choice: 96 + origin_channels: 240 +backbone.layer3.0.conv.1.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer3.0.conv.1.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer3.0.conv.2.bn.mutable_num_features: + current_choice: 24 + origin_channels: 48 +backbone.layer3.0.conv.2.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer3.0.conv.2.conv.mutable_out_channels: + current_choice: 24 + origin_channels: 48 +backbone.layer3.1.conv.0.bn.mutable_num_features: + current_choice: 144 + origin_channels: 288 +backbone.layer3.1.conv.0.conv.mutable_in_channels: + current_choice: 24 + origin_channels: 48 +backbone.layer3.1.conv.0.conv.mutable_out_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.1.conv.1.bn.mutable_num_features: + current_choice: 144 + origin_channels: 288 +backbone.layer3.1.conv.1.conv.mutable_in_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.1.conv.1.conv.mutable_out_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.1.conv.2.bn.mutable_num_features: + current_choice: 24 + origin_channels: 48 +backbone.layer3.1.conv.2.conv.mutable_in_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.1.conv.2.conv.mutable_out_channels: + current_choice: 24 + origin_channels: 48 +backbone.layer3.2.conv.0.bn.mutable_num_features: + current_choice: 144 + origin_channels: 288 +backbone.layer3.2.conv.0.conv.mutable_in_channels: + current_choice: 24 + origin_channels: 48 +backbone.layer3.2.conv.0.conv.mutable_out_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.2.conv.1.bn.mutable_num_features: + current_choice: 144 + origin_channels: 288 +backbone.layer3.2.conv.1.conv.mutable_in_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.2.conv.1.conv.mutable_out_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.2.conv.2.bn.mutable_num_features: + current_choice: 24 + origin_channels: 48 +backbone.layer3.2.conv.2.conv.mutable_in_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.2.conv.2.conv.mutable_out_channels: + current_choice: 24 + origin_channels: 48 +backbone.layer4.0.conv.0.bn.mutable_num_features: + current_choice: 144 + origin_channels: 288 +backbone.layer4.0.conv.0.conv.mutable_in_channels: + current_choice: 24 + origin_channels: 48 +backbone.layer4.0.conv.0.conv.mutable_out_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer4.0.conv.1.bn.mutable_num_features: + current_choice: 144 + origin_channels: 288 +backbone.layer4.0.conv.1.conv.mutable_in_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer4.0.conv.1.conv.mutable_out_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer4.0.conv.2.bn.mutable_num_features: + current_choice: 48 + origin_channels: 96 +backbone.layer4.0.conv.2.conv.mutable_in_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer4.0.conv.2.conv.mutable_out_channels: + current_choice: 48 + origin_channels: 96 +backbone.layer4.1.conv.0.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer4.1.conv.0.conv.mutable_in_channels: + current_choice: 48 + origin_channels: 96 +backbone.layer4.1.conv.0.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.1.conv.1.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer4.1.conv.1.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.1.conv.1.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.1.conv.2.bn.mutable_num_features: + current_choice: 48 + origin_channels: 96 +backbone.layer4.1.conv.2.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.1.conv.2.conv.mutable_out_channels: + current_choice: 48 + origin_channels: 96 +backbone.layer4.2.conv.0.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer4.2.conv.0.conv.mutable_in_channels: + current_choice: 48 + origin_channels: 96 +backbone.layer4.2.conv.0.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.2.conv.1.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer4.2.conv.1.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.2.conv.1.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.2.conv.2.bn.mutable_num_features: + current_choice: 48 + origin_channels: 96 +backbone.layer4.2.conv.2.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.2.conv.2.conv.mutable_out_channels: + current_choice: 48 + origin_channels: 96 +backbone.layer4.3.conv.0.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer4.3.conv.0.conv.mutable_in_channels: + current_choice: 48 + origin_channels: 96 +backbone.layer4.3.conv.0.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.3.conv.1.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer4.3.conv.1.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.3.conv.1.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.3.conv.2.bn.mutable_num_features: + current_choice: 48 + origin_channels: 96 +backbone.layer4.3.conv.2.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.3.conv.2.conv.mutable_out_channels: + current_choice: 48 + origin_channels: 96 +backbone.layer5.0.conv.0.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer5.0.conv.0.conv.mutable_in_channels: + current_choice: 48 + origin_channels: 96 +backbone.layer5.0.conv.0.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer5.0.conv.1.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer5.0.conv.1.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer5.0.conv.1.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer5.0.conv.2.bn.mutable_num_features: + current_choice: 64 + origin_channels: 144 +backbone.layer5.0.conv.2.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer5.0.conv.2.conv.mutable_out_channels: + current_choice: 64 + origin_channels: 144 +backbone.layer5.1.conv.0.bn.mutable_num_features: + current_choice: 432 + origin_channels: 864 +backbone.layer5.1.conv.0.conv.mutable_in_channels: + current_choice: 64 + origin_channels: 144 +backbone.layer5.1.conv.0.conv.mutable_out_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.1.conv.1.bn.mutable_num_features: + current_choice: 432 + origin_channels: 864 +backbone.layer5.1.conv.1.conv.mutable_in_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.1.conv.1.conv.mutable_out_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.1.conv.2.bn.mutable_num_features: + current_choice: 64 + origin_channels: 144 +backbone.layer5.1.conv.2.conv.mutable_in_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.1.conv.2.conv.mutable_out_channels: + current_choice: 64 + origin_channels: 144 +backbone.layer5.2.conv.0.bn.mutable_num_features: + current_choice: 432 + origin_channels: 864 +backbone.layer5.2.conv.0.conv.mutable_in_channels: + current_choice: 64 + origin_channels: 144 +backbone.layer5.2.conv.0.conv.mutable_out_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.2.conv.1.bn.mutable_num_features: + current_choice: 432 + origin_channels: 864 +backbone.layer5.2.conv.1.conv.mutable_in_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.2.conv.1.conv.mutable_out_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.2.conv.2.bn.mutable_num_features: + current_choice: 64 + origin_channels: 144 +backbone.layer5.2.conv.2.conv.mutable_in_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.2.conv.2.conv.mutable_out_channels: + current_choice: 64 + origin_channels: 144 +backbone.layer6.0.conv.0.bn.mutable_num_features: + current_choice: 648 + origin_channels: 864 +backbone.layer6.0.conv.0.conv.mutable_in_channels: + current_choice: 64 + origin_channels: 144 +backbone.layer6.0.conv.0.conv.mutable_out_channels: + current_choice: 648 + origin_channels: 864 +backbone.layer6.0.conv.1.bn.mutable_num_features: + current_choice: 648 + origin_channels: 864 +backbone.layer6.0.conv.1.conv.mutable_in_channels: + current_choice: 648 + origin_channels: 864 +backbone.layer6.0.conv.1.conv.mutable_out_channels: + current_choice: 648 + origin_channels: 864 +backbone.layer6.0.conv.2.bn.mutable_num_features: + current_choice: 176 + origin_channels: 240 +backbone.layer6.0.conv.2.conv.mutable_in_channels: + current_choice: 648 + origin_channels: 864 +backbone.layer6.0.conv.2.conv.mutable_out_channels: + current_choice: 176 + origin_channels: 240 +backbone.layer6.1.conv.0.bn.mutable_num_features: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.1.conv.0.conv.mutable_in_channels: + current_choice: 176 + origin_channels: 240 +backbone.layer6.1.conv.0.conv.mutable_out_channels: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.1.conv.1.bn.mutable_num_features: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.1.conv.1.conv.mutable_in_channels: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.1.conv.1.conv.mutable_out_channels: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.1.conv.2.bn.mutable_num_features: + current_choice: 176 + origin_channels: 240 +backbone.layer6.1.conv.2.conv.mutable_in_channels: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.1.conv.2.conv.mutable_out_channels: + current_choice: 176 + origin_channels: 240 +backbone.layer6.2.conv.0.bn.mutable_num_features: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.2.conv.0.conv.mutable_in_channels: + current_choice: 176 + origin_channels: 240 +backbone.layer6.2.conv.0.conv.mutable_out_channels: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.2.conv.1.bn.mutable_num_features: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.2.conv.1.conv.mutable_in_channels: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.2.conv.1.conv.mutable_out_channels: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.2.conv.2.bn.mutable_num_features: + current_choice: 176 + origin_channels: 240 +backbone.layer6.2.conv.2.conv.mutable_in_channels: + current_choice: 720 + origin_channels: 1440 +backbone.layer6.2.conv.2.conv.mutable_out_channels: + current_choice: 176 + origin_channels: 240 +backbone.layer7.0.conv.0.bn.mutable_num_features: + current_choice: 1440 + origin_channels: 1440 +backbone.layer7.0.conv.0.conv.mutable_in_channels: + current_choice: 176 + origin_channels: 240 +backbone.layer7.0.conv.0.conv.mutable_out_channels: + current_choice: 1440 + origin_channels: 1440 +backbone.layer7.0.conv.1.bn.mutable_num_features: + current_choice: 1440 + origin_channels: 1440 +backbone.layer7.0.conv.1.conv.mutable_in_channels: + current_choice: 1440 + origin_channels: 1440 +backbone.layer7.0.conv.1.conv.mutable_out_channels: + current_choice: 1440 + origin_channels: 1440 +backbone.layer7.0.conv.2.bn.mutable_num_features: + current_choice: 280 + origin_channels: 480 +backbone.layer7.0.conv.2.conv.mutable_in_channels: + current_choice: 1440 + origin_channels: 1440 +backbone.layer7.0.conv.2.conv.mutable_out_channels: + current_choice: 280 + origin_channels: 480 +head.fc.mutable_in_features: + current_choice: 1920 + origin_channels: 1920 +head.fc.mutable_out_features: + current_choice: 1000 + origin_channels: 1000 diff --git a/tests/data/MBV2_320M.yaml b/tests/data/MBV2_320M.yaml new file mode 100644 index 00000000..2c63bcf7 --- /dev/null +++ b/tests/data/MBV2_320M.yaml @@ -0,0 +1,474 @@ +backbone.conv1.bn.mutable_num_features: + current_choice: 8 + origin_channels: 48 +backbone.conv1.conv.mutable_in_channels: + current_choice: 3 + origin_channels: 3 +backbone.conv1.conv.mutable_out_channels: + current_choice: 8 + origin_channels: 48 +backbone.conv2.bn.mutable_num_features: + current_choice: 1920 + origin_channels: 1920 +backbone.conv2.conv.mutable_in_channels: + current_choice: 480 + origin_channels: 480 +backbone.conv2.conv.mutable_out_channels: + current_choice: 1920 + origin_channels: 1920 +backbone.layer1.0.conv.0.bn.mutable_num_features: + current_choice: 8 + origin_channels: 48 +backbone.layer1.0.conv.0.conv.mutable_in_channels: + current_choice: 8 + origin_channels: 48 +backbone.layer1.0.conv.0.conv.mutable_out_channels: + current_choice: 8 + origin_channels: 48 +backbone.layer1.0.conv.1.bn.mutable_num_features: + current_choice: 8 + origin_channels: 24 +backbone.layer1.0.conv.1.conv.mutable_in_channels: + current_choice: 8 + origin_channels: 48 +backbone.layer1.0.conv.1.conv.mutable_out_channels: + current_choice: 8 + origin_channels: 24 +backbone.layer2.0.conv.0.bn.mutable_num_features: + current_choice: 96 + origin_channels: 144 +backbone.layer2.0.conv.0.conv.mutable_in_channels: + current_choice: 8 + origin_channels: 24 +backbone.layer2.0.conv.0.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 144 +backbone.layer2.0.conv.1.bn.mutable_num_features: + current_choice: 96 + origin_channels: 144 +backbone.layer2.0.conv.1.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 144 +backbone.layer2.0.conv.1.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 144 +backbone.layer2.0.conv.2.bn.mutable_num_features: + current_choice: 16 + origin_channels: 40 +backbone.layer2.0.conv.2.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 144 +backbone.layer2.0.conv.2.conv.mutable_out_channels: + current_choice: 16 + origin_channels: 40 +backbone.layer2.1.conv.0.bn.mutable_num_features: + current_choice: 96 + origin_channels: 240 +backbone.layer2.1.conv.0.conv.mutable_in_channels: + current_choice: 16 + origin_channels: 40 +backbone.layer2.1.conv.0.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer2.1.conv.1.bn.mutable_num_features: + current_choice: 96 + origin_channels: 240 +backbone.layer2.1.conv.1.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer2.1.conv.1.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer2.1.conv.2.bn.mutable_num_features: + current_choice: 16 + origin_channels: 40 +backbone.layer2.1.conv.2.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer2.1.conv.2.conv.mutable_out_channels: + current_choice: 16 + origin_channels: 40 +backbone.layer3.0.conv.0.bn.mutable_num_features: + current_choice: 96 + origin_channels: 240 +backbone.layer3.0.conv.0.conv.mutable_in_channels: + current_choice: 16 + origin_channels: 40 +backbone.layer3.0.conv.0.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer3.0.conv.1.bn.mutable_num_features: + current_choice: 96 + origin_channels: 240 +backbone.layer3.0.conv.1.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer3.0.conv.1.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer3.0.conv.2.bn.mutable_num_features: + current_choice: 24 + origin_channels: 48 +backbone.layer3.0.conv.2.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 240 +backbone.layer3.0.conv.2.conv.mutable_out_channels: + current_choice: 24 + origin_channels: 48 +backbone.layer3.1.conv.0.bn.mutable_num_features: + current_choice: 144 + origin_channels: 288 +backbone.layer3.1.conv.0.conv.mutable_in_channels: + current_choice: 24 + origin_channels: 48 +backbone.layer3.1.conv.0.conv.mutable_out_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.1.conv.1.bn.mutable_num_features: + current_choice: 144 + origin_channels: 288 +backbone.layer3.1.conv.1.conv.mutable_in_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.1.conv.1.conv.mutable_out_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.1.conv.2.bn.mutable_num_features: + current_choice: 24 + origin_channels: 48 +backbone.layer3.1.conv.2.conv.mutable_in_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.1.conv.2.conv.mutable_out_channels: + current_choice: 24 + origin_channels: 48 +backbone.layer3.2.conv.0.bn.mutable_num_features: + current_choice: 144 + origin_channels: 288 +backbone.layer3.2.conv.0.conv.mutable_in_channels: + current_choice: 24 + origin_channels: 48 +backbone.layer3.2.conv.0.conv.mutable_out_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.2.conv.1.bn.mutable_num_features: + current_choice: 144 + origin_channels: 288 +backbone.layer3.2.conv.1.conv.mutable_in_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.2.conv.1.conv.mutable_out_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.2.conv.2.bn.mutable_num_features: + current_choice: 24 + origin_channels: 48 +backbone.layer3.2.conv.2.conv.mutable_in_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer3.2.conv.2.conv.mutable_out_channels: + current_choice: 24 + origin_channels: 48 +backbone.layer4.0.conv.0.bn.mutable_num_features: + current_choice: 144 + origin_channels: 288 +backbone.layer4.0.conv.0.conv.mutable_in_channels: + current_choice: 24 + origin_channels: 48 +backbone.layer4.0.conv.0.conv.mutable_out_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer4.0.conv.1.bn.mutable_num_features: + current_choice: 144 + origin_channels: 288 +backbone.layer4.0.conv.1.conv.mutable_in_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer4.0.conv.1.conv.mutable_out_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer4.0.conv.2.bn.mutable_num_features: + current_choice: 56 + origin_channels: 96 +backbone.layer4.0.conv.2.conv.mutable_in_channels: + current_choice: 144 + origin_channels: 288 +backbone.layer4.0.conv.2.conv.mutable_out_channels: + current_choice: 56 + origin_channels: 96 +backbone.layer4.1.conv.0.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer4.1.conv.0.conv.mutable_in_channels: + current_choice: 56 + origin_channels: 96 +backbone.layer4.1.conv.0.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.1.conv.1.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer4.1.conv.1.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.1.conv.1.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.1.conv.2.bn.mutable_num_features: + current_choice: 56 + origin_channels: 96 +backbone.layer4.1.conv.2.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.1.conv.2.conv.mutable_out_channels: + current_choice: 56 + origin_channels: 96 +backbone.layer4.2.conv.0.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer4.2.conv.0.conv.mutable_in_channels: + current_choice: 56 + origin_channels: 96 +backbone.layer4.2.conv.0.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.2.conv.1.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer4.2.conv.1.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.2.conv.1.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.2.conv.2.bn.mutable_num_features: + current_choice: 56 + origin_channels: 96 +backbone.layer4.2.conv.2.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.2.conv.2.conv.mutable_out_channels: + current_choice: 56 + origin_channels: 96 +backbone.layer4.3.conv.0.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer4.3.conv.0.conv.mutable_in_channels: + current_choice: 56 + origin_channels: 96 +backbone.layer4.3.conv.0.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.3.conv.1.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer4.3.conv.1.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.3.conv.1.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.3.conv.2.bn.mutable_num_features: + current_choice: 56 + origin_channels: 96 +backbone.layer4.3.conv.2.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer4.3.conv.2.conv.mutable_out_channels: + current_choice: 56 + origin_channels: 96 +backbone.layer5.0.conv.0.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer5.0.conv.0.conv.mutable_in_channels: + current_choice: 56 + origin_channels: 96 +backbone.layer5.0.conv.0.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer5.0.conv.1.bn.mutable_num_features: + current_choice: 288 + origin_channels: 576 +backbone.layer5.0.conv.1.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer5.0.conv.1.conv.mutable_out_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer5.0.conv.2.bn.mutable_num_features: + current_choice: 96 + origin_channels: 144 +backbone.layer5.0.conv.2.conv.mutable_in_channels: + current_choice: 288 + origin_channels: 576 +backbone.layer5.0.conv.2.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 144 +backbone.layer5.1.conv.0.bn.mutable_num_features: + current_choice: 432 + origin_channels: 864 +backbone.layer5.1.conv.0.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 144 +backbone.layer5.1.conv.0.conv.mutable_out_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.1.conv.1.bn.mutable_num_features: + current_choice: 432 + origin_channels: 864 +backbone.layer5.1.conv.1.conv.mutable_in_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.1.conv.1.conv.mutable_out_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.1.conv.2.bn.mutable_num_features: + current_choice: 96 + origin_channels: 144 +backbone.layer5.1.conv.2.conv.mutable_in_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.1.conv.2.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 144 +backbone.layer5.2.conv.0.bn.mutable_num_features: + current_choice: 432 + origin_channels: 864 +backbone.layer5.2.conv.0.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 144 +backbone.layer5.2.conv.0.conv.mutable_out_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.2.conv.1.bn.mutable_num_features: + current_choice: 432 + origin_channels: 864 +backbone.layer5.2.conv.1.conv.mutable_in_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.2.conv.1.conv.mutable_out_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.2.conv.2.bn.mutable_num_features: + current_choice: 96 + origin_channels: 144 +backbone.layer5.2.conv.2.conv.mutable_in_channels: + current_choice: 432 + origin_channels: 864 +backbone.layer5.2.conv.2.conv.mutable_out_channels: + current_choice: 96 + origin_channels: 144 +backbone.layer6.0.conv.0.bn.mutable_num_features: + current_choice: 864 + origin_channels: 864 +backbone.layer6.0.conv.0.conv.mutable_in_channels: + current_choice: 96 + origin_channels: 144 +backbone.layer6.0.conv.0.conv.mutable_out_channels: + current_choice: 864 + origin_channels: 864 +backbone.layer6.0.conv.1.bn.mutable_num_features: + current_choice: 864 + origin_channels: 864 +backbone.layer6.0.conv.1.conv.mutable_in_channels: + current_choice: 864 + origin_channels: 864 +backbone.layer6.0.conv.1.conv.mutable_out_channels: + current_choice: 864 + origin_channels: 864 +backbone.layer6.0.conv.2.bn.mutable_num_features: + current_choice: 240 + origin_channels: 240 +backbone.layer6.0.conv.2.conv.mutable_in_channels: + current_choice: 864 + origin_channels: 864 +backbone.layer6.0.conv.2.conv.mutable_out_channels: + current_choice: 240 + origin_channels: 240 +backbone.layer6.1.conv.0.bn.mutable_num_features: + current_choice: 1440 + origin_channels: 1440 +backbone.layer6.1.conv.0.conv.mutable_in_channels: + current_choice: 240 + origin_channels: 240 +backbone.layer6.1.conv.0.conv.mutable_out_channels: + current_choice: 1440 + origin_channels: 1440 +backbone.layer6.1.conv.1.bn.mutable_num_features: + current_choice: 1440 + origin_channels: 1440 +backbone.layer6.1.conv.1.conv.mutable_in_channels: + current_choice: 1440 + origin_channels: 1440 +backbone.layer6.1.conv.1.conv.mutable_out_channels: + current_choice: 1440 + origin_channels: 1440 +backbone.layer6.1.conv.2.bn.mutable_num_features: + current_choice: 240 + origin_channels: 240 +backbone.layer6.1.conv.2.conv.mutable_in_channels: + current_choice: 1440 + origin_channels: 1440 +backbone.layer6.1.conv.2.conv.mutable_out_channels: + current_choice: 240 + origin_channels: 240 +backbone.layer6.2.conv.0.bn.mutable_num_features: + current_choice: 960 + origin_channels: 1440 +backbone.layer6.2.conv.0.conv.mutable_in_channels: + current_choice: 240 + origin_channels: 240 +backbone.layer6.2.conv.0.conv.mutable_out_channels: + current_choice: 960 + origin_channels: 1440 +backbone.layer6.2.conv.1.bn.mutable_num_features: + current_choice: 960 + origin_channels: 1440 +backbone.layer6.2.conv.1.conv.mutable_in_channels: + current_choice: 960 + origin_channels: 1440 +backbone.layer6.2.conv.1.conv.mutable_out_channels: + current_choice: 960 + origin_channels: 1440 +backbone.layer6.2.conv.2.bn.mutable_num_features: + current_choice: 240 + origin_channels: 240 +backbone.layer6.2.conv.2.conv.mutable_in_channels: + current_choice: 960 + origin_channels: 1440 +backbone.layer6.2.conv.2.conv.mutable_out_channels: + current_choice: 240 + origin_channels: 240 +backbone.layer7.0.conv.0.bn.mutable_num_features: + current_choice: 1440 + origin_channels: 1440 +backbone.layer7.0.conv.0.conv.mutable_in_channels: + current_choice: 240 + origin_channels: 240 +backbone.layer7.0.conv.0.conv.mutable_out_channels: + current_choice: 1440 + origin_channels: 1440 +backbone.layer7.0.conv.1.bn.mutable_num_features: + current_choice: 1440 + origin_channels: 1440 +backbone.layer7.0.conv.1.conv.mutable_in_channels: + current_choice: 1440 + origin_channels: 1440 +backbone.layer7.0.conv.1.conv.mutable_out_channels: + current_choice: 1440 + origin_channels: 1440 +backbone.layer7.0.conv.2.bn.mutable_num_features: + current_choice: 480 + origin_channels: 480 +backbone.layer7.0.conv.2.conv.mutable_in_channels: + current_choice: 1440 + origin_channels: 1440 +backbone.layer7.0.conv.2.conv.mutable_out_channels: + current_choice: 480 + origin_channels: 480 +head.fc.mutable_in_features: + current_choice: 1920 + origin_channels: 1920 +head.fc.mutable_out_features: + current_choice: 1000 + origin_channels: 1000 diff --git a/tests/data/subnet1.yaml b/tests/data/subnet1.yaml new file mode 100644 index 00000000..f7886351 --- /dev/null +++ b/tests/data/subnet1.yaml @@ -0,0 +1,24 @@ +op1.mutable_in_channels: + current_choice: 3 + origin_channels: 3 +op1.mutable_out_channels: + current_choice: 4 + origin_channels: 8 +bn1.mutable_num_features: + current_choice: 4 + origin_channels: 8 +op2.mutable_in_channels: + current_choice: 4 + origin_channels: 8 +op2.mutable_out_channels: + current_choice: 4 + origin_channels: 8 +bn2.mutable_num_features: + current_choice: 4 + origin_channels: 8 +op3.mutable_in_channels: + current_choice: 4 + origin_channels: 8 +op3.mutable_out_channels: + current_choice: 8 + origin_channels: 8 \ No newline at end of file diff --git a/tests/data/subnet2.yaml b/tests/data/subnet2.yaml new file mode 100644 index 00000000..bd49b2c7 --- /dev/null +++ b/tests/data/subnet2.yaml @@ -0,0 +1,24 @@ +op1.mutable_in_channels: + current_choice: 3 + origin_channels: 3 +op1.mutable_out_channels: + current_choice: 8 + origin_channels: 8 +bn1.mutable_num_features: + current_choice: 8 + origin_channels: 8 +op2.mutable_in_channels: + current_choice: 8 + origin_channels: 8 +op2.mutable_out_channels: + current_choice: 8 + origin_channels: 8 +bn2.mutable_num_features: + current_choice: 8 + origin_channels: 8 +op3.mutable_in_channels: + current_choice: 8 + origin_channels: 8 +op3.mutable_out_channels: + current_choice: 8 + origin_channels: 8 \ No newline at end of file diff --git a/tests/test_core/test_tracer/test_backward_tracer.py b/tests/test_core/test_tracer/test_backward_tracer.py new file mode 100644 index 00000000..445dd3e7 --- /dev/null +++ b/tests/test_core/test_tracer/test_backward_tracer.py @@ -0,0 +1,260 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import pytest +import torch +from torch import Tensor, nn +from torch.nn import Module + +from mmrazor.core import (BackwardTracer, ConcatNode, ConvNode, + DepthWiseConvNode, LinearNode, NormNode, Path, + PathList) + +NONPASS_NODES = (ConvNode, LinearNode, ConcatNode) +PASS_NODES = (NormNode, DepthWiseConvNode) + + +class MultiConcatModel(Module): + + def __init__(self) -> None: + super().__init__() + + self.op1 = nn.Conv2d(3, 8, 1) + self.op2 = nn.Conv2d(3, 8, 1) + self.op3 = nn.Conv2d(16, 8, 1) + self.op4 = nn.Conv2d(3, 8, 1) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.op1(x) + x2 = self.op2(x) + cat1 = torch.cat([x1, x2], dim=1) + x3 = self.op3(cat1) + x4 = self.op4(x) + output = torch.cat([x3, x4], dim=1) + + return output + + +class MultiConcatModel2(Module): + + def __init__(self) -> None: + super().__init__() + + self.op1 = nn.Conv2d(3, 8, 1) + self.op2 = nn.Conv2d(3, 8, 1) + self.op3 = nn.Conv2d(3, 8, 1) + self.op4 = nn.Conv2d(24, 8, 1) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.op1(x) + x2 = self.op2(x) + x3 = self.op3(x) + cat1 = torch.cat([x1, x2], dim=1) + cat2 = torch.cat([cat1, x3], dim=1) + output = self.op4(cat2) + + return output + + +class ResBlock(Module): + + def __init__(self) -> None: + super().__init__() + + self.op1 = nn.Conv2d(3, 8, 1) + self.bn1 = nn.BatchNorm2d(8) + self.op2 = nn.Conv2d(8, 8, 1) + self.bn2 = nn.BatchNorm2d(8) + self.op3 = nn.Conv2d(8, 8, 1) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.bn1(self.op1(x)) + x2 = self.bn2(self.op2(x1)) + x3 = self.op3(x2 + x1) + return x3 + + +class ToyCNNPseudoLoss: + + def __call__(self, model): + pseudo_img = torch.rand(2, 3, 16, 16) + pseudo_output = model(pseudo_img) + return pseudo_output.sum() + + +class TestBackwardTracer(TestCase): + + def test_trace_resblock(self) -> None: + model = ResBlock() + loss_calculator = ToyCNNPseudoLoss() + tracer = BackwardTracer(loss_calculator=loss_calculator) + path_list = tracer.trace(model) + + # test tracer and parser + assert len(path_list) == 2 + assert len(path_list[0]) == 5 + + # test path_list + nonpass2parents = path_list.find_nodes_parents(NONPASS_NODES) + assert len(nonpass2parents) == 3 + assert nonpass2parents['op1'] == list() + assert nonpass2parents['op2'] == list({NormNode('bn1')}) + assert nonpass2parents['op3'] == list( + {NormNode('bn2'), NormNode('bn1')}) + + nonpass2nonpassparents = path_list.find_nodes_parents( + NONPASS_NODES, non_pass=NONPASS_NODES) + assert len(nonpass2parents) == 3 + assert nonpass2nonpassparents['op1'] == list() + assert nonpass2nonpassparents['op2'] == list({ConvNode('op1')}) + assert nonpass2nonpassparents['op3'] == list( + {ConvNode('op2'), ConvNode('op1')}) + + pass2nonpassparents = path_list.find_nodes_parents( + PASS_NODES, non_pass=NONPASS_NODES) + assert len(pass2nonpassparents) == 2 + assert pass2nonpassparents['bn1'] == list({ConvNode('op1')}) + assert pass2nonpassparents['bn2'] == list({ConvNode('op2')}) + + def test_trace_multi_cat(self) -> None: + loss_calculator = ToyCNNPseudoLoss() + + model = MultiConcatModel() + tracer = BackwardTracer(loss_calculator=loss_calculator) + path_list = tracer.trace(model) + + assert len(path_list) == 1 + + nonpass2parents = path_list.find_nodes_parents(NONPASS_NODES) + assert len(nonpass2parents) == 4 + assert nonpass2parents['op1'] == list() + assert nonpass2parents['op2'] == list() + path_list1 = PathList(Path(ConvNode('op1'))) + path_list2 = PathList(Path(ConvNode('op2'))) + # only one parent + assert len(nonpass2parents['op3']) == 1 + assert isinstance(nonpass2parents['op3'][0], ConcatNode) + assert len(nonpass2parents['op3'][0]) == 2 + assert nonpass2parents['op3'][0].get_module_names() == ['op1', 'op2'] + assert nonpass2parents['op3'][0].path_lists == [path_list1, path_list2] + assert nonpass2parents['op3'][0][0] == path_list1 + assert nonpass2parents['op4'] == list() + + model = MultiConcatModel2() + tracer = BackwardTracer(loss_calculator=loss_calculator) + path_list = tracer.trace(model) + assert len(path_list) == 1 + + nonpass2parents = path_list.find_nodes_parents(NONPASS_NODES) + assert len(nonpass2parents) == 4 + assert nonpass2parents['op1'] == list() + assert nonpass2parents['op2'] == list() + assert nonpass2parents['op3'] == list() + # only one parent + assert len(nonpass2parents['op4']) == 1 + assert isinstance(nonpass2parents['op4'][0], ConcatNode) + assert nonpass2parents['op4'][0].get_module_names() == [ + 'op1', 'op2', 'op3' + ] + + def test_repr(self): + toy_node = ConvNode('op1') + assert repr(toy_node) == 'ConvNode(\'op1\')' + + toy_path = Path([ConvNode('op1'), ConvNode('op2')]) + assert repr( + toy_path) == 'Path(\n ConvNode(\'op1\'),\n ConvNode(\'op2\')\n)' + + toy_path_list = PathList(Path(ConvNode('op1'))) + assert repr(toy_path_list + ) == 'PathList(\n Path(\n ConvNode(\'op1\')\n )\n)' + + path_list1 = PathList(Path(ConvNode('op1'))) + path_list2 = PathList(Path(ConvNode('op2'))) + toy_concat_node = ConcatNode('op3', [path_list1, path_list2]) + assert repr( + toy_concat_node + ) == 'ConcatNode(\n PathList(\n Path(\n ConvNode(\'op1\')\n )\n ),\n PathList(\n Path(\n ConvNode(\'op2\')\n )\n )\n)' # noqa: E501 + + def test_reset_bn_running_stats(self): + _test_reset_bn_running_stats(False) + with pytest.raises(AssertionError): + _test_reset_bn_running_stats(True) + + def test_node(self): + node1 = ConvNode('conv1') + node2 = ConvNode('conv2') + assert node1 != node2 + + node1 = ConvNode('conv1') + node2 = ConvNode('conv1') + assert node1 == node2 + + def test_path(self): + node1 = ConvNode('conv1') + node2 = ConvNode('conv2') + + path1 = Path([node1]) + path2 = Path([node2]) + assert path1 != path2 + + path1 = Path([node1]) + path2 = Path([node1]) + assert path1 == path2 + + assert path1[0] == node1 + + def test_path_list(self): + node1 = ConvNode('conv1') + node2 = ConvNode('conv2') + + path1 = Path([node1]) + path2 = Path([node2]) + assert PathList(path1) == PathList([path1]) + assert PathList(path1) != PathList(path2) + + with self.assertRaisesRegex(AssertionError, ''): + _ = PathList({}) + + +def _test_reset_bn_running_stats(should_fail): + import os + import random + + import numpy as np + + def set_seed(seed: int) -> None: + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + set_seed(1024) + imgs = torch.randn(2, 3, 4, 4) + loss_calculator = ToyCNNPseudoLoss() + tracer = BackwardTracer(loss_calculator=loss_calculator) + if should_fail: + tracer._reset_norm_running_stats = lambda *_: None + + torch_rng_state = torch.get_rng_state() + np_rng_state = np.random.get_state() + random_rng_state = random.getstate() + + model1 = ResBlock() + set_seed(1) + tracer.trace(model1) + model1.eval() + output1 = model1(imgs) + + set_seed(1024) + torch.set_rng_state(torch_rng_state) + np.random.set_state(np_rng_state) + random.setstate(random_rng_state) + + model2 = ResBlock() + set_seed(2) + tracer.trace(model2) + model2.eval() + output2 = model2(imgs) + + assert torch.equal(output1.norm(p='fro'), output2.norm(p='fro')) diff --git a/tests/test_models/test_mutables/test_channel_mutable.py b/tests/test_models/test_mutables/test_channel_mutable.py new file mode 100644 index 00000000..7f5aaadb --- /dev/null +++ b/tests/test_models/test_mutables/test_channel_mutable.py @@ -0,0 +1,164 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from unittest import TestCase + +import pytest +import torch + +from mmrazor.models import OrderChannelMutable, RatioChannelMutable + + +class TestChannelMutables(TestCase): + + def test_ratio_channel_mutable(self): + with pytest.raises(AssertionError): + # Test invalid `mask_type` + RatioChannelMutable( + name='op', + mask_type='xxx', + num_channels=8, + candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0]) + + with pytest.raises(AssertionError): + # Number of candidate choices must be greater than 0 + RatioChannelMutable( + name='op', + mask_type='out_mask', + num_channels=8, + candidate_choices=list()) + + with pytest.raises(AssertionError): + # The candidate ratio should be in range(0, 1]. + RatioChannelMutable( + name='op', + mask_type='out_mask', + num_channels=8, + candidate_choices=[0., 1 / 4, 2 / 4, 3 / 4, 1.0]) + + with pytest.raises(AssertionError): + # Minimum number of channels should be a positive integer. + out_mutable = RatioChannelMutable( + name='op', + mask_type='out_mask', + num_channels=8, + candidate_choices=[0.01, 1 / 4, 2 / 4, 3 / 4, 1.0]) + _ = out_mutable.min_choice + + with pytest.raises(AssertionError): + # Minimum number of channels should be a positive integer. + out_mutable = RatioChannelMutable( + name='op', + mask_type='out_mask', + num_channels=8, + candidate_choices=[0.01, 1 / 4, 2 / 4, 3 / 4, 1.0]) + out_mutable.get_choice(0) + + # Test out_mutable (mask_type == 'out_mask') + out_mutable = RatioChannelMutable( + name='op', + mask_type='out_mask', + num_channels=8, + candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0]) + + random_choice = out_mutable.sample_choice() + assert random_choice in [2, 4, 6, 8] + + choice = out_mutable.get_choice(0) + assert choice == 2 + + max_choice = out_mutable.max_choice + assert max_choice == 8 + out_mutable.current_choice = max_choice + assert torch.equal(out_mutable.mask, + torch.ones_like(out_mutable.mask).bool()) + + min_choice = out_mutable.min_choice + assert min_choice == 2 + out_mutable.current_choice = min_choice + min_mask = torch.zeros_like(out_mutable.mask).bool() + min_mask[:2] = True + assert torch.equal(out_mutable.mask, min_mask) + + with pytest.raises(AssertionError): + # Only mutables with mask_type == 'in_mask' (named in_mutable) can + # add `concat_mutables` + concat_mutables = [copy.deepcopy(out_mutable)] * 2 + out_mutable.register_same_mutable(concat_mutables) + + # Test in_mutable (mask_type == 'in_mask') with concat_mutable + in_mutable = RatioChannelMutable( + name='op', + mask_type='in_mask', + num_channels=16, + candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0]) + out_mutable1 = copy.deepcopy(out_mutable) + out_mutable2 = copy.deepcopy(out_mutable) + in_mutable.register_same_mutable([out_mutable1, out_mutable2]) + choice1 = out_mutable1.sample_choice() + out_mutable1.current_choice = choice1 + choice2 = out_mutable2.sample_choice() + out_mutable2.current_choice = choice2 + assert in_mutable.current_choice == choice1 + choice2 + assert torch.equal(in_mutable.mask, + torch.cat([out_mutable1.mask, out_mutable2.mask])) + + with pytest.raises(AssertionError): + # The mask of this in_mutable depends on the out mask of its + # `concat_mutables`, so the `sample_choice` method should not + # be called + in_mutable.sample_choice() + + with pytest.raises(AssertionError): + # The mask of this in_mutable depends on the out mask of its + # `concat_mutables`, so the `min_choice` property should not + # be called + _ = in_mutable.min_choice + + with pytest.raises(AssertionError): + # The mask of this in_mutable depends on the out mask of its + # `concat_mutables`, so the `get_choice` method should not + # be called + in_mutable.get_choice(0) + + def test_order_channel_mutable(self): + with pytest.raises(AssertionError): + # The candidate ratio should be in range(0, `num_channels`]. + OrderChannelMutable( + name='op', + mask_type='out_mask', + num_channels=8, + candidate_choices=[0, 2, 4, 6, 8]) + + with pytest.raises(AssertionError): + # Type of `candidate_choices` should be int. + OrderChannelMutable( + name='op', + mask_type='out_mask', + num_channels=8, + candidate_choices=[0., 2, 4, 6, 8]) + + # Test out_mutable (mask_type == 'out_mask') + out_mutable = OrderChannelMutable( + name='op', + mask_type='out_mask', + num_channels=8, + candidate_choices=[2, 4, 6, 8]) + + random_choice = out_mutable.sample_choice() + assert random_choice in [2, 4, 6, 8] + + choice = out_mutable.get_choice(0) + assert choice == 2 + + max_choice = out_mutable.max_choice + assert max_choice == 8 + out_mutable.current_choice = max_choice + assert torch.equal(out_mutable.mask, + torch.ones_like(out_mutable.mask).bool()) + + min_choice = out_mutable.min_choice + assert min_choice == 2 + out_mutable.current_choice = min_choice + min_mask = torch.zeros_like(out_mutable.mask).bool() + min_mask[:2] = True + assert torch.equal(out_mutable.mask, min_mask) diff --git a/tests/test_models/test_mutables/test_dynamic_layer.py b/tests/test_models/test_mutables/test_dynamic_layer.py new file mode 100644 index 00000000..1454d03d --- /dev/null +++ b/tests/test_models/test_mutables/test_dynamic_layer.py @@ -0,0 +1,136 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch +from torch import nn + +from mmrazor.models.architectures.dynamic_op import (build_dynamic_bn, + build_dynamic_conv2d, + build_dynamic_gn, + build_dynamic_in, + build_dynamic_linear) + + +class TestDynamicLayer(TestCase): + + def test_dynamic_conv(self): + imgs = torch.rand(2, 8, 16, 16) + + in_channels_cfg = dict( + type='RatioChannelMutable', + candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0]) + + out_channels_cfg = dict( + type='RatioChannelMutable', + candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0]) + + conv = nn.Conv2d(8, 8, 1) + dynamic_conv = build_dynamic_conv2d(conv, 'op', in_channels_cfg, + out_channels_cfg) + # test forward + dynamic_conv(imgs) + + conv = nn.Conv2d(8, 8, 1, groups=8) + dynamic_conv = build_dynamic_conv2d(conv, 'op', in_channels_cfg, + out_channels_cfg) + # test forward + dynamic_conv(imgs) + + conv = nn.Conv2d(8, 8, 1, groups=4) + dynamic_conv = build_dynamic_conv2d(conv, 'op', in_channels_cfg, + out_channels_cfg) + # test forward + with self.assertRaisesRegex(NotImplementedError, + 'only support pruning the depth-wise'): + dynamic_conv(imgs) + + def test_dynamic_linear(self): + imgs = torch.rand(2, 8) + + in_features_cfg = dict( + type='RatioChannelMutable', + candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0]) + + out_features_cfg = dict( + type='RatioChannelMutable', + candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0]) + + linear = nn.Linear(8, 8) + dynamic_linear = build_dynamic_linear(linear, 'op', in_features_cfg, + out_features_cfg) + # test forward + dynamic_linear(imgs) + + def test_dynamic_batchnorm(self): + imgs = torch.rand(2, 8, 16, 16) + + num_features_cfg = dict( + type='RatioChannelMutable', + candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0]) + + bn = nn.BatchNorm2d(8) + dynamic_bn = build_dynamic_bn(bn, 'bn', num_features_cfg) + # test forward + dynamic_bn(imgs) + + bn = nn.BatchNorm2d(8, momentum=0) + dynamic_bn = build_dynamic_bn(bn, 'bn', num_features_cfg) + # test forward + dynamic_bn(imgs) + + bn = nn.BatchNorm2d(8) + bn.train() + dynamic_bn = build_dynamic_bn(bn, 'bn', num_features_cfg) + # test forward + dynamic_bn(imgs) + # test num_batches_tracked is not None + dynamic_bn(imgs) + + bn = nn.BatchNorm2d(8, affine=False) + dynamic_bn = build_dynamic_bn(bn, 'bn', num_features_cfg) + # test forward + dynamic_bn(imgs) + + bn = nn.BatchNorm2d(8, track_running_stats=False) + dynamic_bn = build_dynamic_bn(bn, 'bn', num_features_cfg) + # test forward + dynamic_bn(imgs) + + def test_dynamic_instancenorm(self): + imgs = torch.rand(2, 8, 16, 16) + + num_features_cfg = dict( + type='RatioChannelMutable', + candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0]) + + instance_norm = nn.InstanceNorm2d(8) + dynamic_in = build_dynamic_in(instance_norm, 'in', num_features_cfg) + # test forward + dynamic_in(imgs) + + instance_norm = nn.InstanceNorm2d(8, affine=False) + dynamic_in = build_dynamic_in(instance_norm, 'in', num_features_cfg) + # test forward + dynamic_in(imgs) + + instance_norm = nn.InstanceNorm2d(8, track_running_stats=False) + dynamic_in = build_dynamic_in(instance_norm, 'in', num_features_cfg) + # test forward + dynamic_in(imgs) + + def test_dynamic_groupnorm(self): + imgs = torch.rand(2, 8, 16, 16) + + num_channels_cfg = dict( + type='RatioChannelMutable', + candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0]) + + gn = nn.GroupNorm(num_groups=4, num_channels=8) + dynamic_gn = build_dynamic_gn(gn, 'gn', num_channels_cfg) + # test forward + dynamic_gn(imgs) + + gn = nn.GroupNorm(num_groups=4, num_channels=8, affine=False) + dynamic_gn = build_dynamic_gn(gn, 'gn', num_channels_cfg) + # test forward + dynamic_gn(imgs) diff --git a/tests/test_models/test_mutators/test_channel_mutator.py b/tests/test_models/test_mutators/test_channel_mutator.py new file mode 100644 index 00000000..1d5f4526 --- /dev/null +++ b/tests/test_models/test_mutators/test_channel_mutator.py @@ -0,0 +1,229 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import unittest +from os.path import dirname + +import mmcv.fileio +import pytest +import torch +from mmcls.models import * # noqa: F401,F403 +from torch import Tensor, nn +from torch.nn import Module + +from mmrazor import digit_version +from mmrazor.models.architectures.dynamic_op import (build_dynamic_bn, + build_dynamic_conv2d) +from mmrazor.models.mutables import SlimmableChannelMutable +from mmrazor.models.mutators import (OneShotChannelMutator, + SlimmableChannelMutator) +from mmrazor.registry import MODELS + +ONESHOT_MUTATOR_CFG = dict( + type='OneShotChannelMutator', + tracer_cfg=dict( + type='BackwardTracer', + loss_calculator=dict(type='ImageClassifierPseudoLoss')), + mutable_cfg=dict( + type='RatioChannelMutable', + candidate_choices=[ + 1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0 + ])) + +ONESHOT_MUTATOR_CFG_WITHOUT_TRACER = dict( + type='OneShotChannelMutator', + mutable_cfg=dict( + type='RatioChannelMutable', + candidate_choices=[ + 1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0 + ])) + + +class MultiConcatModel(Module): + + def __init__(self) -> None: + super().__init__() + + self.op1 = nn.Conv2d(3, 8, 1) + self.op2 = nn.Conv2d(3, 8, 1) + self.op3 = nn.Conv2d(16, 8, 1) + self.op4 = nn.Conv2d(3, 8, 1) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.op1(x) + x2 = self.op2(x) + cat1 = torch.cat([x1, x2], dim=1) + x3 = self.op3(cat1) + x4 = self.op4(x) + output = torch.cat([x3, x4], dim=1) + + return output + + +class MultiConcatModel2(Module): + + def __init__(self) -> None: + super().__init__() + + self.op1 = nn.Conv2d(3, 8, 1) + self.op2 = nn.Conv2d(3, 8, 1) + self.op3 = nn.Conv2d(3, 8, 1) + self.op4 = nn.Conv2d(24, 8, 1) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.op1(x) + x2 = self.op2(x) + x3 = self.op3(x) + cat1 = torch.cat([x1, x2], dim=1) + cat2 = torch.cat([cat1, x3], dim=1) + output = self.op4(cat2) + + return output + + +class ResBlock(Module): + + def __init__(self) -> None: + super().__init__() + + self.op1 = nn.Conv2d(3, 8, 1) + self.bn1 = nn.BatchNorm2d(8) + self.op2 = nn.Conv2d(8, 8, 1) + self.bn2 = nn.BatchNorm2d(8) + self.op3 = nn.Conv2d(8, 8, 1) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.bn1(self.op1(x)) + x2 = self.bn2(self.op2(x1)) + x3 = self.op3(x2 + x1) + return x3 + + +class DynamicResBlock(Module): + + def __init__(self, mutable_cfg) -> None: + super().__init__() + + self.dynamic_op1 = build_dynamic_conv2d( + nn.Conv2d(3, 8, 1), 'dynamic_op1', mutable_cfg, mutable_cfg) + self.dynamic_bn1 = build_dynamic_bn( + nn.BatchNorm2d(8), 'dynamic_bn1', mutable_cfg) + self.dynamic_op2 = build_dynamic_conv2d( + nn.Conv2d(8, 8, 1), 'dynamic_op2', mutable_cfg, mutable_cfg) + self.dynamic_bn2 = build_dynamic_bn( + nn.BatchNorm2d(8), 'dynamic_bn2', mutable_cfg) + self.dynamic_op3 = build_dynamic_conv2d( + nn.Conv2d(8, 8, 1), 'dynamic_op3', mutable_cfg, mutable_cfg) + self._add_link() + + def _add_link(self): + op1_mutable_out = self.dynamic_op1.mutable_out + bn1_mutable_out = self.dynamic_bn1.mutable_out + + op2_mutable_in = self.dynamic_op2.mutable_in + op2_mutable_out = self.dynamic_op2.mutable_out + bn2_mutable_out = self.dynamic_bn2.mutable_out + + op3_mutable_in = self.dynamic_op3.mutable_in + + bn1_mutable_out.register_same_mutable(op1_mutable_out) + op1_mutable_out.register_same_mutable(bn1_mutable_out) + + op2_mutable_in.register_same_mutable(bn1_mutable_out) + bn1_mutable_out.register_same_mutable(op2_mutable_in) + + bn2_mutable_out.register_same_mutable(op2_mutable_out) + op2_mutable_out.register_same_mutable(bn2_mutable_out) + + op3_mutable_in.register_same_mutable(bn1_mutable_out) + bn1_mutable_out.register_same_mutable(op3_mutable_in) + + op3_mutable_in.register_same_mutable(bn2_mutable_out) + bn2_mutable_out.register_same_mutable(op3_mutable_in) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.dynamic_bn1(self.dynamic_op1(x)) + x2 = self.dynamic_bn2(self.dynamic_op2(x1)) + x3 = self.dynamic_op3(x2 + x1) + return x3 + + +@unittest.skipIf( + digit_version(torch.__version__) == digit_version('1.8.1'), + 'PyTorch version 1.8.1 is not supported by the Backward Tracer.') +def test_oneshot_channel_mutator() -> None: + imgs = torch.randn(16, 3, 224, 224) + + def _test(model): + mutator.prepare_from_supernet(model) + for key, val in mutator.search_groups.items(): + print(key, val) + # print(mutator.search_groups) + assert hasattr(mutator, 'name2module') + + # test set_min_choices + mutator.set_min_choices() + for mutables in mutator.search_groups.values(): + for mutable in mutables: + # 1 / 8 is the minimum candidate ratio + assert mutable.current_choice == round(1 / 8 * + mutable.num_channels) + + # test set_max_channel + mutator.set_max_choices() + for mutables in mutator.search_groups.values(): + for mutable in mutables: + # 1.0 is the maximum candidate ratio + assert mutable.current_choice == round(1. * + mutable.num_channels) + + # test making groups logic + choice_dict = mutator.sample_choices() + assert isinstance(choice_dict, dict) + mutator.set_choices(choice_dict) + model(imgs) + + mutator: OneShotChannelMutator = MODELS.build(ONESHOT_MUTATOR_CFG) + with pytest.raises(RuntimeError): + _ = mutator.search_groups + with pytest.raises(RuntimeError): + _ = mutator.name2module + + _test(ResBlock()) + _test(MultiConcatModel()) + _test(MultiConcatModel2()) + _test(nn.Sequential(nn.BatchNorm2d(3))) + + mutator: OneShotChannelMutator = MODELS.build( + ONESHOT_MUTATOR_CFG_WITHOUT_TRACER) + dynamic_model = DynamicResBlock( + ONESHOT_MUTATOR_CFG_WITHOUT_TRACER['mutable_cfg']) + _test(dynamic_model) + + +def test_slimmable_channel_mutator() -> None: + imgs = torch.randn(16, 3, 224, 224) + + root_path = dirname(dirname(dirname(__file__))) + channel_cfgs = [ + os.path.join(root_path, 'data/subnet1.yaml'), + os.path.join(root_path, 'data/subnet2.yaml') + ] + channel_cfgs = [mmcv.fileio.load(path) for path in channel_cfgs] + + mutator = SlimmableChannelMutator( + mutable_cfg=dict(type='SlimmableChannelMutable'), + channel_cfgs=channel_cfgs) + + model = ResBlock() + mutator.prepare_from_supernet(model) + mutator.switch_choices(0) + for name, module in model.named_modules(): + if isinstance(module, SlimmableChannelMutable): + assert module.current_choice == 0 + _ = model(imgs) + + mutator.switch_choices(1) + for name, module in model.named_modules(): + if isinstance(module, SlimmableChannelMutable): + assert module.current_choice == 1 + _ = model(imgs) diff --git a/tests/test_models/test_mutators/test_classical_models/test_mbv2_channel_mutator.py b/tests/test_models/test_mutators/test_classical_models/test_mbv2_channel_mutator.py new file mode 100644 index 00000000..f015e307 --- /dev/null +++ b/tests/test_models/test_mutators/test_classical_models/test_mbv2_channel_mutator.py @@ -0,0 +1,107 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import unittest +from os.path import dirname + +import mmcv.fileio +import torch +from mmcls.core import ClsDataSample +from mmcls.models import * # noqa: F401,F403 + +from mmrazor import digit_version +from mmrazor.models.mutables import SlimmableChannelMutable +from mmrazor.models.mutators import (OneShotChannelMutator, + SlimmableChannelMutator) +from mmrazor.registry import MODELS + +MODEL_CFG = dict( + type='mmcls.ImageClassifier', + backbone=dict(type='MobileNetV2', widen_factor=1.5), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=1920, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + topk=(1, 5))) + +ONESHOT_MUTATOR_CFG = dict( + type='OneShotChannelMutator', + skip_prefixes=['head.fc'], + tracer_cfg=dict( + type='BackwardTracer', + loss_calculator=dict(type='ImageClassifierPseudoLoss')), + mutable_cfg=dict( + type='RatioChannelMutable', + candidate_choices=[ + 1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0 + ])) + + +@unittest.skipIf( + digit_version(torch.__version__) == digit_version('1.8.1'), + 'PyTorch version 1.8.1 is not supported by the Backward Tracer.') +def test_oneshot_channel_mutator() -> None: + imgs = torch.randn(16, 3, 224, 224) + data_samples = [ + ClsDataSample().set_gt_label(torch.randint(0, 1000, (16, ))) + ] + + model = MODELS.build(MODEL_CFG) + mutator: OneShotChannelMutator = MODELS.build(ONESHOT_MUTATOR_CFG) + + mutator.prepare_from_supernet(model) + assert hasattr(mutator, 'name2module') + + # test set_min_choices + mutator.set_min_choices() + for mutables in mutator.search_groups.values(): + for mutable in mutables: + # 1 / 8 is the minimum candidate ratio + assert mutable.current_choice == round(1 / 8 * + mutable.num_channels) + + # test set_max_channel + mutator.set_max_choices() + for mutables in mutator.search_groups.values(): + for mutable in mutables: + # 1.0 is the maximum candidate ratio + assert mutable.current_choice == round(1. * mutable.num_channels) + + # test making groups logic + choice_dict = mutator.sample_choices() + assert isinstance(choice_dict, dict) + mutator.set_choices(choice_dict) + model(imgs, data_samples=data_samples, mode='loss') + + +def test_slimmable_channel_mutator() -> None: + imgs = torch.randn(16, 3, 224, 224) + data_samples = [ + ClsDataSample().set_gt_label(torch.randint(0, 1000, (16, ))) + ] + + root_path = dirname(dirname(dirname(dirname(__file__)))) + channel_cfgs = [ + os.path.join(root_path, 'data/MBV2_320M.yaml'), + os.path.join(root_path, 'data/MBV2_220M.yaml') + ] + channel_cfgs = [mmcv.fileio.load(path) for path in channel_cfgs] + + mutator = SlimmableChannelMutator( + mutable_cfg=dict(type='SlimmableChannelMutable'), + channel_cfgs=channel_cfgs) + + model = MODELS.build(MODEL_CFG) + mutator.prepare_from_supernet(model) + mutator.switch_choices(0) + for name, module in model.named_modules(): + if isinstance(module, SlimmableChannelMutable): + assert module.current_choice == 0 + model(imgs, data_samples=data_samples, mode='loss') + + mutator.switch_choices(1) + for name, module in model.named_modules(): + if isinstance(module, SlimmableChannelMutable): + assert module.current_choice == 1 + model(imgs, data_samples=data_samples, mode='loss') diff --git a/tests/test_models/test_mutators/test_diff_mutator.py b/tests/test_models/test_mutators/test_diff_mutator.py index b940b8a4..546aee2e 100644 --- a/tests/test_models/test_mutators/test_diff_mutator.py +++ b/tests/test_models/test_mutators/test_diff_mutator.py @@ -104,7 +104,7 @@ class TestDiffMutator(TestCase): mutator: DiffOP = MODELS.build(self.MUTATOR_CFG) mutator.prepare_from_supernet(model) - assert list(mutator.search_group.keys()) == [0, 1, 2] + assert list(mutator.search_groups.keys()) == [0, 1, 2] def test_diff_mutator_diffop_model(self) -> None: model = SearchableModel(self.MUTABLE_CFG) @@ -118,7 +118,7 @@ class TestDiffMutator(TestCase): mutator: DiffOP = MODELS.build(mutator_cfg) mutator.prepare_from_supernet(model) - assert list(mutator.search_group.keys()) == [0, 1, 2] + assert list(mutator.search_groups.keys()) == [0, 1, 2] mutator.modify_supernet_forward() assert mutator.mutable_class_type == DiffMutable @@ -146,7 +146,7 @@ class TestDiffMutator(TestCase): mutator.prepare_from_supernet(model) - assert list(mutator.search_group.keys()) == [0, 1, 2] + assert list(mutator.search_groups.keys()) == [0, 1, 2] mutator.modify_supernet_forward() assert mutator.mutable_class_type == DiffMutable @@ -165,7 +165,7 @@ class TestDiffMutator(TestCase): mutator.prepare_from_supernet(model) - assert list(mutator.search_group.keys()) == [0, 1, 2, 3] + assert list(mutator.search_groups.keys()) == [0, 1, 2, 3] mutator.modify_supernet_forward() assert mutator.mutable_class_type == DiffMutable diff --git a/tests/test_models/test_mutators/test_one_shot_mutator.py b/tests/test_models/test_mutators/test_one_shot_mutator.py index b3f0db8a..b99062b7 100644 --- a/tests/test_models/test_mutators/test_one_shot_mutator.py +++ b/tests/test_models/test_mutators/test_one_shot_mutator.py @@ -59,10 +59,10 @@ def test_one_shot_mutator_normal_model() -> None: assert mutator.mutable_class_type == OneShotMutable with pytest.raises(RuntimeError): - _ = mutator.search_group + _ = mutator.search_groups mutator.prepare_from_supernet(model) - assert len(mutator.search_group) == 0 + assert len(mutator.search_groups) == 0 assert len(mutator.sample_choices()) == 0 @@ -89,7 +89,7 @@ def test_one_shot_mutator_mutable_model() -> None: # import pdb; pdb.set_trace() mutator.prepare_from_supernet(model) - assert list(mutator.search_group.keys()) == [0, 1, 2] + assert list(mutator.search_groups.keys()) == [0, 1, 2] random_choices = mutator.sample_choices() assert list(random_choices.keys()) == [0, 1, 2] @@ -102,7 +102,7 @@ def test_one_shot_mutator_mutable_model() -> None: mutator = MODELS.build(mutator_cfg) mutator.prepare_from_supernet(model) - assert list(mutator.search_group.keys()) == [0, 1] + assert list(mutator.search_groups.keys()) == [0, 1] random_choices = mutator.sample_choices() assert list(random_choices.keys()) == [0, 1]