[Feature] Add mutator (#3)
parent
acc8a9c913
commit
3da9b4c8a1
|
@ -0,0 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .darts_mutator import DartsMutator
|
||||
from .differentiable_mutator import DifferentiableMutator
|
||||
from .one_shot_mutator import OneShotMutator
|
||||
|
||||
__all__ = ['DifferentiableMutator', 'DartsMutator', 'OneShotMutator']
|
|
@ -0,0 +1,125 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta
|
||||
|
||||
from mmcv.runner import BaseModule
|
||||
|
||||
from mmrazor.models.architectures import Placeholder
|
||||
from mmrazor.models.builder import MUTABLES, MUTATORS
|
||||
from mmrazor.models.mutables import MutableModule
|
||||
|
||||
|
||||
@MUTATORS.register_module()
|
||||
class BaseMutator(BaseModule, metaclass=ABCMeta):
|
||||
"""Base class for mutators."""
|
||||
|
||||
def __init__(self, placeholder_mapping=None, init_cfg=None):
|
||||
super(BaseMutator, self).__init__(init_cfg=init_cfg)
|
||||
self.placeholder_mapping = placeholder_mapping
|
||||
|
||||
def prepare_from_supernet(self, supernet):
|
||||
"""Implement some preparatory work based on supernet, including
|
||||
``convert_placeholder`` and ``build_search_spaces``.
|
||||
|
||||
Args:
|
||||
supernet (:obj:`torch.nn.Module`): The architecture to be used
|
||||
in your algorithm.
|
||||
"""
|
||||
if self.placeholder_mapping is not None:
|
||||
self.convert_placeholder(supernet, self.placeholder_mapping)
|
||||
|
||||
self.search_spaces = self.build_search_spaces(supernet)
|
||||
|
||||
def build_search_spaces(self, supernet):
|
||||
"""Build a search space from the supernet.
|
||||
|
||||
Args:
|
||||
supernet (:obj:`torch.nn.Module`): The architecture to be used
|
||||
in your algorithm.
|
||||
|
||||
Returns:
|
||||
dict: To collect some information about ``MutableModule`` in the
|
||||
supernet.
|
||||
"""
|
||||
search_spaces = dict()
|
||||
|
||||
def traverse(module):
|
||||
for child in module.children():
|
||||
if isinstance(child, MutableModule):
|
||||
if child.space_id not in search_spaces.keys():
|
||||
search_spaces[child.space_id] = dict(
|
||||
modules=[child],
|
||||
choice_names=child.choice_names,
|
||||
num_chosen=child.num_chosen,
|
||||
space_mask=child.build_space_mask())
|
||||
else:
|
||||
search_spaces[child.space_id]['modules'].append(child)
|
||||
|
||||
traverse(child)
|
||||
|
||||
traverse(supernet)
|
||||
return search_spaces
|
||||
|
||||
def convert_placeholder(self, supernet, placeholder_mapping):
|
||||
"""Replace all placeholders in the model.
|
||||
|
||||
Args:
|
||||
supernet (:obj:`torch.nn.Module`): The architecture to be used in
|
||||
your algorithm.
|
||||
placeholder_mapping (dict): Record which placeholders need to be
|
||||
replaced by which ops,
|
||||
its keys are the properties ``placeholder_group`` of
|
||||
placeholders used in the searchable architecture,
|
||||
its values are the registered ``OPS``.
|
||||
"""
|
||||
|
||||
def traverse(module):
|
||||
|
||||
for name, child in module.named_children():
|
||||
if isinstance(child, Placeholder):
|
||||
|
||||
mutable_cfg = placeholder_mapping[
|
||||
child.placeholder_group].copy()
|
||||
assert 'type' in mutable_cfg, f'{mutable_cfg}'
|
||||
mutable_type = mutable_cfg.pop('type')
|
||||
assert mutable_type in MUTABLES, \
|
||||
f'{mutable_type} not in MUTABLES.'
|
||||
mutable_constructor = MUTABLES.get(mutable_type)
|
||||
mutable_kwargs = child.placeholder_kwargs
|
||||
mutable_kwargs.update(mutable_cfg)
|
||||
mutable_module = mutable_constructor(**mutable_kwargs)
|
||||
setattr(module, name, mutable_module)
|
||||
|
||||
# setattr(module, name, choice_module)
|
||||
# If the new MUTABLE is MutableEdge, it may have MutableOP,
|
||||
# so here we need to traverse the new MUTABLES.
|
||||
traverse(mutable_module)
|
||||
else:
|
||||
traverse(child)
|
||||
|
||||
traverse(supernet)
|
||||
|
||||
def deploy_subnet(self, supernet, subnet_dict):
|
||||
"""Export the subnet from the supernet based on the specified
|
||||
subnet_dict.
|
||||
|
||||
Args:
|
||||
supernet (:obj:`torch.nn.Module`): The architecture to be used in
|
||||
your algorithm.
|
||||
subnet_dict (dict): Record the information to build the subnet from
|
||||
the supernet,
|
||||
its keys are the properties ``space_id`` of placeholders in the
|
||||
mutator's search spaces,
|
||||
its values are dicts: {'chosen': ['chosen name1',
|
||||
'chosen name2', ...]}
|
||||
"""
|
||||
|
||||
def traverse(module):
|
||||
for name, child in module.named_children():
|
||||
if isinstance(child, MutableModule):
|
||||
space_id = child.space_id
|
||||
chosen = subnet_dict[space_id]['chosen']
|
||||
child.export(chosen)
|
||||
|
||||
traverse(child)
|
||||
|
||||
traverse(supernet)
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from mmrazor.models.builder import MUTATORS
|
||||
from .differentiable_mutator import DifferentiableMutator
|
||||
|
||||
|
||||
@MUTATORS.register_module()
|
||||
class DartsMutator(DifferentiableMutator):
|
||||
|
||||
def __init__(self, ignore_choices=('zero', ), **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.ignore_choices = ignore_choices
|
||||
|
||||
def search_subnet(self):
|
||||
|
||||
subnet_dict = dict()
|
||||
for space_id, sub_space in self.search_spaces.items():
|
||||
if space_id in self.arch_params:
|
||||
space_arch_param = self.arch_params[space_id]
|
||||
arch_probs = F.softmax(space_arch_param, dim=-1)
|
||||
choice_names = sub_space['choice_names']
|
||||
keep_idx = [
|
||||
i for i, name in enumerate(choice_names)
|
||||
if name not in self.ignore_choices
|
||||
]
|
||||
best_choice_prob, best_choice_idx = torch.max(
|
||||
arch_probs[keep_idx], 0)
|
||||
best_choice_idx = keep_idx[best_choice_idx.item()]
|
||||
best_choice_name = choice_names[best_choice_idx]
|
||||
|
||||
subnet_dict[space_id] = dict(
|
||||
chosen=[best_choice_name],
|
||||
chosen_probs=[best_choice_prob.item()])
|
||||
|
||||
def sort_key(x):
|
||||
return subnet_dict[x]['chosen_probs'][0]
|
||||
|
||||
for space_id, sub_space in self.search_spaces.items():
|
||||
if space_id not in self.arch_params:
|
||||
num_chosen = sub_space['num_chosen']
|
||||
choice_names = sub_space['choice_names']
|
||||
sorted_edges = list(
|
||||
sorted(choice_names, key=sort_key, reverse=True))
|
||||
chosen = sorted_edges[:num_chosen]
|
||||
subnet_dict[space_id] = dict(chosen=chosen)
|
||||
|
||||
for not_chosen in sorted_edges[num_chosen:]:
|
||||
subnet_dict.pop(not_chosen)
|
||||
|
||||
return subnet_dict
|
|
@ -0,0 +1,95 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
|
||||
from torch import nn
|
||||
|
||||
from mmrazor.models.builder import MUTATORS
|
||||
from mmrazor.models.mutables import MutableModule
|
||||
from .base import BaseMutator
|
||||
|
||||
|
||||
@MUTATORS.register_module()
|
||||
class DifferentiableMutator(BaseMutator):
|
||||
"""A mutator for the differentiable NAS, which mainly provide some core
|
||||
functions of changing the structure of ``ARCHITECTURES``."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def prepare_from_supernet(self, supernet):
|
||||
"""Inherit from ``BaseMutator``'s, execute some customized functions
|
||||
exclude implementing origin ``prepare_from_supernet``.
|
||||
|
||||
Args:
|
||||
supernet (:obj:`torch.nn.Module`): The architecture to be used
|
||||
in your algorithm.
|
||||
"""
|
||||
super().prepare_from_supernet(supernet)
|
||||
self.arch_params = self.build_arch_params(supernet)
|
||||
self.modify_supernet_forward(supernet)
|
||||
|
||||
def build_arch_params(self, supernet):
|
||||
"""This function will build many arch params, which are generally used
|
||||
in diffirentiale search algorithms, such as Darts' series. Each
|
||||
space_id corresponds to an arch param, so the Mutable with the same
|
||||
space_id share the same arch param.
|
||||
|
||||
Args:
|
||||
supernet (:obj:`torch.nn.Module`): The architecture to be used
|
||||
in your algorithm.
|
||||
|
||||
Returns:
|
||||
torch.nn.ParameterDict: the arch params are got after traversing
|
||||
the supernet.
|
||||
"""
|
||||
|
||||
arch_params = nn.ParameterDict()
|
||||
|
||||
# Traverse all the child modules of the model. If a child module is an
|
||||
# Space instance and its space_id is not recorded, call its
|
||||
# :func:'build_space_architecture' and record the return value. If not,
|
||||
# pass.
|
||||
def traverse(module):
|
||||
for name, child in module.named_children():
|
||||
|
||||
if isinstance(child, MutableModule):
|
||||
space_id = child.space_id
|
||||
if space_id not in arch_params:
|
||||
space_arch_param = child.build_arch_param()
|
||||
if space_arch_param is not None:
|
||||
arch_params[space_id] = space_arch_param
|
||||
|
||||
traverse(child)
|
||||
|
||||
traverse(supernet)
|
||||
|
||||
return arch_params
|
||||
|
||||
def modify_supernet_forward(self, supernet):
|
||||
"""Modify the supernet's default value in forward. Traverse all child
|
||||
modules of the model, modify the supernet's default value in
|
||||
:func:'forward' of each Space.
|
||||
|
||||
Args:
|
||||
supernet (:obj:`torch.nn.Module`): The architecture to be used
|
||||
in your algorithm.
|
||||
"""
|
||||
|
||||
def traverse(module):
|
||||
for name, child in module.named_children():
|
||||
|
||||
if isinstance(child, MutableModule):
|
||||
if child.space_id in self.arch_params.keys():
|
||||
space_id = child.space_id
|
||||
space_arch_param = self.arch_params[space_id]
|
||||
child.forward = partial(
|
||||
child.forward, arch_param=space_arch_param)
|
||||
|
||||
traverse(child)
|
||||
|
||||
traverse(supernet)
|
||||
|
||||
@abstractmethod
|
||||
def search_subnet(self):
|
||||
pass
|
|
@ -0,0 +1,155 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from mmrazor.models.builder import MUTATORS
|
||||
from .base import BaseMutator
|
||||
|
||||
|
||||
@MUTATORS.register_module()
|
||||
class OneShotMutator(BaseMutator):
|
||||
"""A mutator for the one-shot NAS, which mainly provide some core functions
|
||||
of changing the structure of ``ARCHITECTURES``."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def get_random_mask(space_info, searching):
|
||||
"""Generate random mask for randomly sampling.
|
||||
|
||||
Args:
|
||||
space_info (dict): Record the information of the space need
|
||||
to sample.
|
||||
searching (bool): Whether is in search stage.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Random mask generated.
|
||||
"""
|
||||
space_mask = space_info['space_mask']
|
||||
num_chosen = space_info['num_chosen']
|
||||
assert num_chosen <= space_mask.size()[0]
|
||||
choice_idx = torch.multinomial(space_mask, num_chosen)
|
||||
choice_mask = torch.zeros_like(space_mask)
|
||||
choice_mask[choice_idx] = 1
|
||||
if dist.is_available() and dist.is_initialized() and not searching:
|
||||
dist.broadcast(choice_mask, src=0)
|
||||
return choice_mask
|
||||
|
||||
def sample_subnet(self, searching=False):
|
||||
"""Random sample subnet by random mask.
|
||||
|
||||
Args:
|
||||
searching (bool): Whether is in search stage.
|
||||
|
||||
Returns:
|
||||
dict: Record the information to build the subnet from the supernet,
|
||||
its keys are the properties ``space_id`` of placeholders in the
|
||||
mutator's search spaces,
|
||||
its values are random mask generated.
|
||||
"""
|
||||
subnet_dict = dict()
|
||||
for space_id, space_info in self.search_spaces.items():
|
||||
subnet_dict[space_id] = self.get_random_mask(space_info, searching)
|
||||
return subnet_dict
|
||||
|
||||
def set_subnet(self, subnet_dict):
|
||||
"""Setting subnet in the supernet based on the result of
|
||||
``sample_subnet`` by changing the flag: ``in_subnet``, which is easy to
|
||||
implement some operations for subnet, such as ``forward``, calculate
|
||||
flops and so on.
|
||||
|
||||
Args:
|
||||
subnet_dict (dict): Record the information to build the subnet
|
||||
from the supernet,
|
||||
its keys are the properties ``space_id`` of placeholders in the
|
||||
mutator's search spaces,
|
||||
its values are masks.
|
||||
"""
|
||||
for space_id, space_info in self.search_spaces.items():
|
||||
choice_mask = subnet_dict[space_id]
|
||||
for module in space_info['modules']:
|
||||
module.choice_mask = choice_mask
|
||||
for i, choice in enumerate(module.choices.values()):
|
||||
if choice_mask[i]:
|
||||
choice.apply(
|
||||
partial(self.reset_in_subnet, in_subnet=True))
|
||||
else:
|
||||
choice.apply(
|
||||
partial(self.reset_in_subnet, in_subnet=False))
|
||||
|
||||
@staticmethod
|
||||
def reset_in_subnet(m, in_subnet=True):
|
||||
"""Reset the module's attribution.
|
||||
|
||||
Args:
|
||||
m (:obj:`torch.nn.Module`): The module in the supernet.
|
||||
in_subnet (bool): If the module in subnet, set ``in_subnet`` to
|
||||
True, otherwise set to False.
|
||||
"""
|
||||
m.__in_subnet__ = in_subnet
|
||||
|
||||
def set_chosen_subnet(self, subnet_dict):
|
||||
"""Set chosen subnet in the search_spaces after searching stage.
|
||||
|
||||
Args:
|
||||
subnet_dict (dict): Record the information to build the subnet from
|
||||
the supernet,
|
||||
its keys are the properties ``space_id`` of placeholders in the
|
||||
mutator's search spaces,
|
||||
its values are masks.
|
||||
"""
|
||||
for space_id, mask in subnet_dict.items():
|
||||
idxs = [i for i, x in enumerate(mask.tolist()) if x == 1.0]
|
||||
self.search_spaces[space_id]['chosen'] = [
|
||||
self.search_spaces[space_id]['choice_names'][i] for i in idxs
|
||||
]
|
||||
|
||||
def mutation(self, subnet_dict, prob=0.1):
|
||||
"""Mutation used in evolution search.
|
||||
|
||||
Args:
|
||||
subnet_dict (dict): Record the information to build the subnet
|
||||
from the supernet, its keys are the properties ``space_id``
|
||||
of placeholders in the mutator's search spaces, its values
|
||||
are masks.
|
||||
prob (float): The probability of mutation.
|
||||
|
||||
Returns:
|
||||
dict: A new subnet_dict after mutation.
|
||||
"""
|
||||
mutation_subnet_dict = copy.deepcopy(subnet_dict)
|
||||
for name, mask in subnet_dict.items():
|
||||
if np.random.random_sample() < prob:
|
||||
mutation_subnet_dict[name] = self.get_random_mask(
|
||||
self.search_spaces[name], searching=True)
|
||||
return mutation_subnet_dict
|
||||
|
||||
@staticmethod
|
||||
def crossover(subnet_dict1, subnet_dict2):
|
||||
"""Crossover used in evolution search.
|
||||
|
||||
Args:
|
||||
subnet_dict1 (dict): Record the information to build the subnet
|
||||
from the supernet,
|
||||
its keys are the properties ``space_id`` of placeholders in the
|
||||
mutator's search spaces,
|
||||
its values are masks.
|
||||
subnet_dict2 (dict): Record the information to build the subnet
|
||||
from the supernet,
|
||||
its keys are the properties ``space_id`` of placeholders in the
|
||||
mutator's search spaces,
|
||||
its values are masks.
|
||||
|
||||
Returns:
|
||||
dict: A new subnet_dict after crossover.
|
||||
"""
|
||||
crossover_subnet_dict = copy.deepcopy(subnet_dict1)
|
||||
for name, mask in subnet_dict2.items():
|
||||
if np.random.random_sample() < 0.5:
|
||||
crossover_subnet_dict[name] = mask
|
||||
return crossover_subnet_dict
|
Loading…
Reference in New Issue