[Feature] Add mutator (#3)

pull/255/head
humu789 2021-12-23 04:17:13 +08:00 committed by GitHub
parent acc8a9c913
commit 3da9b4c8a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 434 additions and 0 deletions

View File

@ -0,0 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .darts_mutator import DartsMutator
from .differentiable_mutator import DifferentiableMutator
from .one_shot_mutator import OneShotMutator
__all__ = ['DifferentiableMutator', 'DartsMutator', 'OneShotMutator']

View File

@ -0,0 +1,125 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta
from mmcv.runner import BaseModule
from mmrazor.models.architectures import Placeholder
from mmrazor.models.builder import MUTABLES, MUTATORS
from mmrazor.models.mutables import MutableModule
@MUTATORS.register_module()
class BaseMutator(BaseModule, metaclass=ABCMeta):
"""Base class for mutators."""
def __init__(self, placeholder_mapping=None, init_cfg=None):
super(BaseMutator, self).__init__(init_cfg=init_cfg)
self.placeholder_mapping = placeholder_mapping
def prepare_from_supernet(self, supernet):
"""Implement some preparatory work based on supernet, including
``convert_placeholder`` and ``build_search_spaces``.
Args:
supernet (:obj:`torch.nn.Module`): The architecture to be used
in your algorithm.
"""
if self.placeholder_mapping is not None:
self.convert_placeholder(supernet, self.placeholder_mapping)
self.search_spaces = self.build_search_spaces(supernet)
def build_search_spaces(self, supernet):
"""Build a search space from the supernet.
Args:
supernet (:obj:`torch.nn.Module`): The architecture to be used
in your algorithm.
Returns:
dict: To collect some information about ``MutableModule`` in the
supernet.
"""
search_spaces = dict()
def traverse(module):
for child in module.children():
if isinstance(child, MutableModule):
if child.space_id not in search_spaces.keys():
search_spaces[child.space_id] = dict(
modules=[child],
choice_names=child.choice_names,
num_chosen=child.num_chosen,
space_mask=child.build_space_mask())
else:
search_spaces[child.space_id]['modules'].append(child)
traverse(child)
traverse(supernet)
return search_spaces
def convert_placeholder(self, supernet, placeholder_mapping):
"""Replace all placeholders in the model.
Args:
supernet (:obj:`torch.nn.Module`): The architecture to be used in
your algorithm.
placeholder_mapping (dict): Record which placeholders need to be
replaced by which ops,
its keys are the properties ``placeholder_group`` of
placeholders used in the searchable architecture,
its values are the registered ``OPS``.
"""
def traverse(module):
for name, child in module.named_children():
if isinstance(child, Placeholder):
mutable_cfg = placeholder_mapping[
child.placeholder_group].copy()
assert 'type' in mutable_cfg, f'{mutable_cfg}'
mutable_type = mutable_cfg.pop('type')
assert mutable_type in MUTABLES, \
f'{mutable_type} not in MUTABLES.'
mutable_constructor = MUTABLES.get(mutable_type)
mutable_kwargs = child.placeholder_kwargs
mutable_kwargs.update(mutable_cfg)
mutable_module = mutable_constructor(**mutable_kwargs)
setattr(module, name, mutable_module)
# setattr(module, name, choice_module)
# If the new MUTABLE is MutableEdge, it may have MutableOP,
# so here we need to traverse the new MUTABLES.
traverse(mutable_module)
else:
traverse(child)
traverse(supernet)
def deploy_subnet(self, supernet, subnet_dict):
"""Export the subnet from the supernet based on the specified
subnet_dict.
Args:
supernet (:obj:`torch.nn.Module`): The architecture to be used in
your algorithm.
subnet_dict (dict): Record the information to build the subnet from
the supernet,
its keys are the properties ``space_id`` of placeholders in the
mutator's search spaces,
its values are dicts: {'chosen': ['chosen name1',
'chosen name2', ...]}
"""
def traverse(module):
for name, child in module.named_children():
if isinstance(child, MutableModule):
space_id = child.space_id
chosen = subnet_dict[space_id]['chosen']
child.export(chosen)
traverse(child)
traverse(supernet)

View File

@ -0,0 +1,53 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.nn import functional as F
from mmrazor.models.builder import MUTATORS
from .differentiable_mutator import DifferentiableMutator
@MUTATORS.register_module()
class DartsMutator(DifferentiableMutator):
def __init__(self, ignore_choices=('zero', ), **kwargs):
super().__init__(**kwargs)
self.ignore_choices = ignore_choices
def search_subnet(self):
subnet_dict = dict()
for space_id, sub_space in self.search_spaces.items():
if space_id in self.arch_params:
space_arch_param = self.arch_params[space_id]
arch_probs = F.softmax(space_arch_param, dim=-1)
choice_names = sub_space['choice_names']
keep_idx = [
i for i, name in enumerate(choice_names)
if name not in self.ignore_choices
]
best_choice_prob, best_choice_idx = torch.max(
arch_probs[keep_idx], 0)
best_choice_idx = keep_idx[best_choice_idx.item()]
best_choice_name = choice_names[best_choice_idx]
subnet_dict[space_id] = dict(
chosen=[best_choice_name],
chosen_probs=[best_choice_prob.item()])
def sort_key(x):
return subnet_dict[x]['chosen_probs'][0]
for space_id, sub_space in self.search_spaces.items():
if space_id not in self.arch_params:
num_chosen = sub_space['num_chosen']
choice_names = sub_space['choice_names']
sorted_edges = list(
sorted(choice_names, key=sort_key, reverse=True))
chosen = sorted_edges[:num_chosen]
subnet_dict[space_id] = dict(chosen=chosen)
for not_chosen in sorted_edges[num_chosen:]:
subnet_dict.pop(not_chosen)
return subnet_dict

View File

@ -0,0 +1,95 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from functools import partial
from torch import nn
from mmrazor.models.builder import MUTATORS
from mmrazor.models.mutables import MutableModule
from .base import BaseMutator
@MUTATORS.register_module()
class DifferentiableMutator(BaseMutator):
"""A mutator for the differentiable NAS, which mainly provide some core
functions of changing the structure of ``ARCHITECTURES``."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def prepare_from_supernet(self, supernet):
"""Inherit from ``BaseMutator``'s, execute some customized functions
exclude implementing origin ``prepare_from_supernet``.
Args:
supernet (:obj:`torch.nn.Module`): The architecture to be used
in your algorithm.
"""
super().prepare_from_supernet(supernet)
self.arch_params = self.build_arch_params(supernet)
self.modify_supernet_forward(supernet)
def build_arch_params(self, supernet):
"""This function will build many arch params, which are generally used
in diffirentiale search algorithms, such as Darts' series. Each
space_id corresponds to an arch param, so the Mutable with the same
space_id share the same arch param.
Args:
supernet (:obj:`torch.nn.Module`): The architecture to be used
in your algorithm.
Returns:
torch.nn.ParameterDict: the arch params are got after traversing
the supernet.
"""
arch_params = nn.ParameterDict()
# Traverse all the child modules of the model. If a child module is an
# Space instance and its space_id is not recorded, call its
# :func:'build_space_architecture' and record the return value. If not,
# pass.
def traverse(module):
for name, child in module.named_children():
if isinstance(child, MutableModule):
space_id = child.space_id
if space_id not in arch_params:
space_arch_param = child.build_arch_param()
if space_arch_param is not None:
arch_params[space_id] = space_arch_param
traverse(child)
traverse(supernet)
return arch_params
def modify_supernet_forward(self, supernet):
"""Modify the supernet's default value in forward. Traverse all child
modules of the model, modify the supernet's default value in
:func:'forward' of each Space.
Args:
supernet (:obj:`torch.nn.Module`): The architecture to be used
in your algorithm.
"""
def traverse(module):
for name, child in module.named_children():
if isinstance(child, MutableModule):
if child.space_id in self.arch_params.keys():
space_id = child.space_id
space_arch_param = self.arch_params[space_id]
child.forward = partial(
child.forward, arch_param=space_arch_param)
traverse(child)
traverse(supernet)
@abstractmethod
def search_subnet(self):
pass

View File

@ -0,0 +1,155 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from functools import partial
import numpy as np
import torch
import torch.distributed as dist
from mmrazor.models.builder import MUTATORS
from .base import BaseMutator
@MUTATORS.register_module()
class OneShotMutator(BaseMutator):
"""A mutator for the one-shot NAS, which mainly provide some core functions
of changing the structure of ``ARCHITECTURES``."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
@staticmethod
def get_random_mask(space_info, searching):
"""Generate random mask for randomly sampling.
Args:
space_info (dict): Record the information of the space need
to sample.
searching (bool): Whether is in search stage.
Returns:
torch.Tensor: Random mask generated.
"""
space_mask = space_info['space_mask']
num_chosen = space_info['num_chosen']
assert num_chosen <= space_mask.size()[0]
choice_idx = torch.multinomial(space_mask, num_chosen)
choice_mask = torch.zeros_like(space_mask)
choice_mask[choice_idx] = 1
if dist.is_available() and dist.is_initialized() and not searching:
dist.broadcast(choice_mask, src=0)
return choice_mask
def sample_subnet(self, searching=False):
"""Random sample subnet by random mask.
Args:
searching (bool): Whether is in search stage.
Returns:
dict: Record the information to build the subnet from the supernet,
its keys are the properties ``space_id`` of placeholders in the
mutator's search spaces,
its values are random mask generated.
"""
subnet_dict = dict()
for space_id, space_info in self.search_spaces.items():
subnet_dict[space_id] = self.get_random_mask(space_info, searching)
return subnet_dict
def set_subnet(self, subnet_dict):
"""Setting subnet in the supernet based on the result of
``sample_subnet`` by changing the flag: ``in_subnet``, which is easy to
implement some operations for subnet, such as ``forward``, calculate
flops and so on.
Args:
subnet_dict (dict): Record the information to build the subnet
from the supernet,
its keys are the properties ``space_id`` of placeholders in the
mutator's search spaces,
its values are masks.
"""
for space_id, space_info in self.search_spaces.items():
choice_mask = subnet_dict[space_id]
for module in space_info['modules']:
module.choice_mask = choice_mask
for i, choice in enumerate(module.choices.values()):
if choice_mask[i]:
choice.apply(
partial(self.reset_in_subnet, in_subnet=True))
else:
choice.apply(
partial(self.reset_in_subnet, in_subnet=False))
@staticmethod
def reset_in_subnet(m, in_subnet=True):
"""Reset the module's attribution.
Args:
m (:obj:`torch.nn.Module`): The module in the supernet.
in_subnet (bool): If the module in subnet, set ``in_subnet`` to
True, otherwise set to False.
"""
m.__in_subnet__ = in_subnet
def set_chosen_subnet(self, subnet_dict):
"""Set chosen subnet in the search_spaces after searching stage.
Args:
subnet_dict (dict): Record the information to build the subnet from
the supernet,
its keys are the properties ``space_id`` of placeholders in the
mutator's search spaces,
its values are masks.
"""
for space_id, mask in subnet_dict.items():
idxs = [i for i, x in enumerate(mask.tolist()) if x == 1.0]
self.search_spaces[space_id]['chosen'] = [
self.search_spaces[space_id]['choice_names'][i] for i in idxs
]
def mutation(self, subnet_dict, prob=0.1):
"""Mutation used in evolution search.
Args:
subnet_dict (dict): Record the information to build the subnet
from the supernet, its keys are the properties ``space_id``
of placeholders in the mutator's search spaces, its values
are masks.
prob (float): The probability of mutation.
Returns:
dict: A new subnet_dict after mutation.
"""
mutation_subnet_dict = copy.deepcopy(subnet_dict)
for name, mask in subnet_dict.items():
if np.random.random_sample() < prob:
mutation_subnet_dict[name] = self.get_random_mask(
self.search_spaces[name], searching=True)
return mutation_subnet_dict
@staticmethod
def crossover(subnet_dict1, subnet_dict2):
"""Crossover used in evolution search.
Args:
subnet_dict1 (dict): Record the information to build the subnet
from the supernet,
its keys are the properties ``space_id`` of placeholders in the
mutator's search spaces,
its values are masks.
subnet_dict2 (dict): Record the information to build the subnet
from the supernet,
its keys are the properties ``space_id`` of placeholders in the
mutator's search spaces,
its values are masks.
Returns:
dict: A new subnet_dict after crossover.
"""
crossover_subnet_dict = copy.deepcopy(subnet_dict1)
for name, mask in subnet_dict2.items():
if np.random.random_sample() < 0.5:
crossover_subnet_dict[name] = mask
return crossover_subnet_dict