[Feature] Add Dsnas Algorithm (#226)
* [tmp] Update Dsnas
* [tmp] refactor arch_loss & flops_loss
* Update Dsnas & MMRAZOR_EVALUATOR: 1. finalized compute_loss & handle_grads in algorithm; 2. add MMRAZOR_EVALUATOR; 3. fix bugs.
* Update lr scheduler & fix a bug: 1. update param_scheduler & lr_scheduler for dsnas; 2. fix a bug of switching to finetune stage.
* remove old evaluators
* remove old evaluators
* update param_scheduler config
* merge dev-1.x into gy/estimator
* add flops_loss in Dsnas using ResourcesEstimator
* get resources before mutator.prepare_from_supernet
* delete unness broadcast api from gml
* broadcast spec_modules_resources when estimating
* update early fix mechanism for Dsnas
* fix merge
* update units in estimator
* minor change
* fix data_preprocessor api
* add flops_loss_coef
* remove DsnasOptimWrapper
* fix bn eps and data_preprocessor
* fix bn weight decay bug
* add betas for mutator optimizer
* set diff_rank_seed=True for dsnas
* fix start_factor of lr when warm up
* remove .module in non-ddp mode
* add GlobalAveragePoolingWithDropout
* add UT for dsnas
* remove unness channel adjustment for shufflenetv2
* update supernet configs
* delete unness dropout
* delete unness part with minor change on dsnas
* minor change on the flag of search stage
* update README and subnet configs
* add UT for OneHotMutableOP
parent: d07dee9887
commit: 8d603d917e
@@ -0,0 +1,28 @@
norm_cfg = dict(type='BN', eps=0.01)

_STAGE_MUTABLE = dict(
    type='mmrazor.OneHotMutableOP',
    fix_threshold=0.3,
    candidates=dict(
        shuffle_3x3=dict(
            type='ShuffleBlock', kernel_size=3, norm_cfg=norm_cfg),
        shuffle_5x5=dict(
            type='ShuffleBlock', kernel_size=5, norm_cfg=norm_cfg),
        shuffle_7x7=dict(
            type='ShuffleBlock', kernel_size=7, norm_cfg=norm_cfg),
        shuffle_xception=dict(type='ShuffleXception', norm_cfg=norm_cfg)))

arch_setting = [
    # Parameters to build layers. 3 parameters are needed to construct a
    # layer, from left to right: channel, num_blocks, mutable_cfg.
    [64, 4, _STAGE_MUTABLE],
    [160, 4, _STAGE_MUTABLE],
    [320, 8, _STAGE_MUTABLE],
    [640, 4, _STAGE_MUTABLE]
]

nas_backbone = dict(
    type='mmrazor.SearchableShuffleNetV2',
    widen_factor=1.0,
    arch_setting=arch_setting,
    norm_cfg=norm_cfg)
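Each entry of `arch_setting` expands into a stage of `num_blocks` searchable blocks, and every block is a `OneHotMutableOP` choosing among the four ShuffleNet candidates. A minimal sketch of what that implies, assuming MMRazor is installed and its registry initialized (variable names below are illustrative only):

```python
# Sketch (assumption): build the searchable backbone from the config above and
# count its mutable blocks -- one OneHotMutableOP per block, 4+4+8+4 = 20 in total.
from mmrazor.registry import MODELS
from mmrazor.models.mutables.base_mutable import BaseMutable

backbone = MODELS.build(nas_backbone)
num_mutables = sum(isinstance(m, BaseMutable) for m in backbone.modules())
print(num_mutables)  # expected: 20 searchable positions, each with 4 candidate ops
```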
@@ -0,0 +1,102 @@
# dataset settings
dataset_type = 'mmcls.ImageNet'
data_preprocessor = dict(
    type='mmcls.ClsDataPreprocessor',
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)

train_pipeline = [
    dict(type='mmcls.LoadImageFromFile'),
    dict(type='mmcls.RandomResizedCrop', scale=224),
    dict(type='mmcls.RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='mmcls.PackClsInputs'),
]

test_pipeline = [
    dict(type='mmcls.LoadImageFromFile'),
    dict(type='mmcls.ResizeEdge', scale=256, edge='short'),
    dict(type='mmcls.CenterCrop', crop_size=224),
    dict(type='mmcls.PackClsInputs'),
]

train_dataloader = dict(
    batch_size=128,
    num_workers=4,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/train.txt',
        data_prefix='train',
        pipeline=train_pipeline),
    sampler=dict(type='mmcls.DefaultSampler', shuffle=True),
    persistent_workers=True,
)

val_dataloader = dict(
    batch_size=128,
    num_workers=4,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/val.txt',
        data_prefix='val',
        pipeline=test_pipeline),
    sampler=dict(type='mmcls.DefaultSampler', shuffle=False),
    persistent_workers=True,
)
val_evaluator = dict(type='mmcls.Accuracy', topk=(1, 5))

# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator

# optimizer
paramwise_cfg = dict(bias_decay_mult=0.0, norm_decay_mult=0.0)

optim_wrapper = dict(
    constructor='mmrazor.SeparateOptimWrapperConstructor',
    architecture=dict(
        optimizer=dict(
            type='mmcls.SGD', lr=0.5, momentum=0.9, weight_decay=4e-5),
        paramwise_cfg=paramwise_cfg),
    mutator=dict(
        optimizer=dict(
            type='mmcls.Adam', lr=0.001, weight_decay=0.0, betas=(0.5,
                                                                  0.999))))

search_epochs = 85
# learning policy
param_scheduler = dict(
    architecture=[
        dict(
            type='mmcls.LinearLR',
            end=5,
            start_factor=0.2,
            by_epoch=True,
            convert_to_iter_based=True),
        dict(
            type='mmcls.CosineAnnealingLR',
            T_max=240,
            begin=5,
            end=search_epochs,
            by_epoch=True,
            convert_to_iter_based=True),
        dict(
            type='mmcls.CosineAnnealingLR',
            T_max=160,
            begin=search_epochs,
            end=240,
            eta_min=0.0,
            by_epoch=True,
            convert_to_iter_based=True)
    ],
    mutator=[])

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=240)
val_cfg = dict()
test_cfg = dict()
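The `SeparateOptimWrapperConstructor` builds one optimizer per top-level key of `optim_wrapper` above: SGD for the supernet weights and Adam for the mutator's architecture parameters. A minimal sketch of what `Dsnas.train_step` ultimately receives, assuming already-built `supernet` and `mutator` modules (illustrative names only); it mirrors the `OptimWrapperDict` that the unit tests construct by hand:

```python
# Sketch (assumption): rough equivalent of the optim_wrapper config above.
from mmengine.optim import OptimWrapper, OptimWrapperDict
from torch.optim import SGD, Adam

optim_wrapper = OptimWrapperDict(
    architecture=OptimWrapper(
        SGD(supernet.parameters(), lr=0.5, momentum=0.9, weight_decay=4e-5)),
    mutator=OptimWrapper(
        Adam(mutator.parameters(), lr=1e-3, weight_decay=0.0, betas=(0.5, 0.999))))
# Dsnas.train_step indexes optim_wrapper['architecture'] and optim_wrapper['mutator'].
```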
@@ -0,0 +1,20 @@
backbone.layers.0.0: shuffle_3x3
backbone.layers.0.1: shuffle_3x3
backbone.layers.0.2: shuffle_xception
backbone.layers.0.3: shuffle_3x3
backbone.layers.1.0: shuffle_xception
backbone.layers.1.1: shuffle_7x7
backbone.layers.1.2: shuffle_3x3
backbone.layers.1.3: shuffle_3x3
backbone.layers.2.0: shuffle_xception
backbone.layers.2.1: shuffle_xception
backbone.layers.2.2: shuffle_7x7
backbone.layers.2.3: shuffle_xception
backbone.layers.2.4: shuffle_xception
backbone.layers.2.5: shuffle_xception
backbone.layers.2.6: shuffle_7x7
backbone.layers.2.7: shuffle_3x3
backbone.layers.3.0: shuffle_3x3
backbone.layers.3.1: shuffle_xception
backbone.layers.3.2: shuffle_xception
backbone.layers.3.3: shuffle_3x3
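This searched-subnet file is a flat YAML mapping from mutable module name to the chosen candidate. A minimal sketch of applying it to a built supernet, assuming a `Dsnas` instance named `algorithm` and a local copy of the file (the filename below is only a placeholder):

```python
# Sketch (assumption): load a searched subnet yaml and fix the supernet to it.
import yaml
from mmrazor.structures import load_fix_subnet

with open('dsnas_searched_subnet.yaml') as f:  # placeholder path
    fix_subnet = yaml.safe_load(f)
load_fix_subnet(algorithm.architecture, fix_subnet)  # drops the unchosen candidates
```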
@@ -0,0 +1,43 @@
# DSNAS

> [DSNAS: Direct Neural Architecture Search without Parameter Retraining](https://arxiv.org/abs/2002.09128.pdf)

<!-- [ALGORITHM] -->

## Abstract

Most existing NAS methods require two-stage parameter optimization.
However, the performance of the same architecture in the two stages correlates poorly.
Based on this observation, DSNAS proposes a task-specific end-to-end differentiable NAS framework that simultaneously optimizes architecture and parameters with a low-biased Monte Carlo estimate. Child networks derived from DSNAS can be deployed directly without parameter retraining.



## Results and models

### Supernet

| Dataset  | Params (M) | FLOPs (G) | Top-1 Acc (%) | Top-5 Acc (%) |                  Config                   |         Download         |     Remarks      |
| :------: | :--------: | :-------: | :-----------: | :-----------: | :---------------------------------------: | :----------------------: | :--------------: |
| ImageNet |    3.33    |   0.299   |     73.56     |     91.24     | [config](./dsnas_supernet_8xb128_in1k.py) | [model](<>) \| [log](<>) | MMRazor searched |

**Note**:

1. There **might be (not in all cases)** some small differences in our experiments in order to be consistent with other repos in OpenMMLab. For example,
   images are normalized in data preprocessing; resizing is done by cv2 rather than PIL in training; dropout is not used in the network. **Please refer to the corresponding config for details.**
2. We converted the official searched checkpoint `DSNASsearch240.pth` into MMRazor style and evaluated it with pytorch1.8_cuda11.0; Top-1 is 74.1 and Top-5 is 91.51.
3. The implementation of ShuffleNetV2 in the official DSNAS repo differs from OpenMMLab's, and we follow the structure design in OpenMMLab. Note that with the
   original ShuffleNetV2 design of official DSNAS, Top-1 is 73.92 and Top-5 is 91.59.
4. The finetune stage in our implementation refers to the 'search-from-search' stage mentioned in the official DSNAS repo.
5. We obtain params and FLOPs using `mmrazor.ResourceEstimator`, which may differ from the original repo (see the sketch below).

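The params/FLOPs in the table come from the same resource estimator that the algorithm uses internally. A minimal sketch, assuming a built subnet bound to the illustrative name `model` (the result keys other than `'flops'` are an assumption):

```python
# Sketch (assumption): estimate resources the same way the Dsnas algorithm does.
from mmrazor.registry import TASK_UTILS

estimator = TASK_UTILS.build(dict(type='mmrazor.ResourceEstimator'))
results = estimator.estimate(model)  # `model` is an assumed built (sub)net
print(results['flops'])              # 'flops' is the key Dsnas itself reads
```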
## Citation

```latex
@inproceedings{hu2020dsnas,
    title={Dsnas: Direct neural architecture search without parameter retraining},
    author={Hu, Shoukang and Xie, Sirui and Zheng, Hehui and Liu, Chunxiao and Shi, Jianping and Liu, Xunying and Lin, Dahua},
    booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
    pages={12084--12092},
    year={2020}
}
```
@@ -0,0 +1,29 @@
_base_ = ['./dsnas_supernet_8xb128_in1k.py']

# NOTE: Replace this with the mutable_cfg searched by yourself.
fix_subnet = {
    'backbone.layers.0.0': 'shuffle_3x3',
    'backbone.layers.0.1': 'shuffle_7x7',
    'backbone.layers.0.2': 'shuffle_3x3',
    'backbone.layers.0.3': 'shuffle_5x5',
    'backbone.layers.1.0': 'shuffle_3x3',
    'backbone.layers.1.1': 'shuffle_3x3',
    'backbone.layers.1.2': 'shuffle_3x3',
    'backbone.layers.1.3': 'shuffle_7x7',
    'backbone.layers.2.0': 'shuffle_xception',
    'backbone.layers.2.1': 'shuffle_3x3',
    'backbone.layers.2.2': 'shuffle_3x3',
    'backbone.layers.2.3': 'shuffle_5x5',
    'backbone.layers.2.4': 'shuffle_3x3',
    'backbone.layers.2.5': 'shuffle_5x5',
    'backbone.layers.2.6': 'shuffle_7x7',
    'backbone.layers.2.7': 'shuffle_7x7',
    'backbone.layers.3.0': 'shuffle_xception',
    'backbone.layers.3.1': 'shuffle_3x3',
    'backbone.layers.3.2': 'shuffle_7x7',
    'backbone.layers.3.3': 'shuffle_3x3',
}

model = dict(fix_subnet=fix_subnet)

find_unused_parameters = False
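The `fix_subnet` dict above is exactly what a trained supernet exports. A minimal sketch of producing one, assuming a trained `Dsnas` instance named `algorithm` (illustrative name); it uses the algorithm's own `search_subnet()`, which wraps `mmrazor.structures.export_fix_subnet`:

```python
# Sketch (assumption): derive a fix_subnet dict from a trained Dsnas supernet.
searched_subnet = algorithm.search_subnet()
# e.g. {'backbone.layers.0.0': 'shuffle_3x3', ..., 'backbone.layers.3.3': 'shuffle_3x3'}
print(searched_subnet)
```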
@@ -0,0 +1,36 @@
_base_ = [
    'mmrazor::_base_/settings/imagenet_bs1024_dsnas.py',
    'mmrazor::_base_/nas_backbones/dsnas_shufflenet_supernet.py',
    'mmcls::_base_/default_runtime.py',
]

# model
model = dict(
    type='mmrazor.Dsnas',
    architecture=dict(
        type='ImageClassifier',
        data_preprocessor=_base_.data_preprocessor,
        backbone=_base_.nas_backbone,
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='LinearClsHead',
            num_classes=1000,
            in_channels=1024,
            loss=dict(
                type='LabelSmoothLoss',
                num_classes=1000,
                label_smooth_val=0.1,
                mode='original',
                loss_weight=1.0),
            topk=(1, 5))),
    mutator=dict(type='mmrazor.DiffModuleMutator'),
    pretrain_epochs=15,
    finetune_epochs=_base_.search_epochs,
)

model_wrapper_cfg = dict(
    type='mmrazor.DsnasDDP',
    broadcast_buffers=False,
    find_unused_parameters=True)

randomness = dict(seed=48, diff_rank_seed=True)
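With `pretrain_epochs=15`, `finetune_epochs=search_epochs=85` and `max_epochs=240`, a run goes through the three stages described in the `Dsnas` docstring below. A minimal sketch of that schedule, mirroring `Dsnas.need_update_mutator` (the helper function is illustrative only):

```python
# Sketch (assumption): which DSNAS training stage the current epoch falls into.
def dsnas_stage(cur_epoch, pretrain_epochs=15, finetune_epochs=85):
    if cur_epoch < pretrain_epochs:
        return 'supernet pretraining (weights only, mutator frozen)'
    if cur_epoch < finetune_epochs:
        return 'search (weights and mutator arch params updated together)'
    return 'subnet finetuning (subnet fixed, arch params no longer updated)'

print(dsnas_stage(0), dsnas_stage(40), dsnas_stage(120), sep='\n')
```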
@@ -10,6 +10,6 @@ __all__ = [
     'SeparateOptimWrapperConstructor', 'DumpSubnetHook',
     'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
     'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
-    'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop',
-    'EstimateResourcesHook'
+    'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'EstimateResourcesHook',
+    'SelfDistillValLoop'
 ]
@@ -3,12 +3,13 @@ from .base import BaseAlgorithm
 from .distill import (DAFLDataFreeDistillation, DataFreeDistillation,
                       FpnTeacherDistill, OverhaulFeatureDistillation,
                       SelfDistill, SingleTeacherDistill)
-from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP
+from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP, Dsnas, DsnasDDP
 from .pruning import SlimmableNetwork, SlimmableNetworkDDP

 __all__ = [
     'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS',
     'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP',
     'Darts', 'DartsDDP', 'SelfDistill', 'DataFreeDistillation',
-    'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation'
+    'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', 'Dsnas',
+    'DsnasDDP'
 ]
@@ -1,6 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .autoslim import AutoSlim, AutoSlimDDP
 from .darts import Darts, DartsDDP
+from .dsnas import Dsnas, DsnasDDP
 from .spos import SPOS

-__all__ = ['SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP']
+__all__ = [
+    'SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'Dsnas', 'DsnasDDP'
+]
@ -0,0 +1,347 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from mmengine.dist import get_dist_info
|
||||
from mmengine.logging import MessageHub
|
||||
from mmengine.model import BaseModel, MMDistributedDataParallel
|
||||
from mmengine.optim import OptimWrapper, OptimWrapperDict
|
||||
from torch import nn
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmrazor.models.mutables.base_mutable import BaseMutable
|
||||
from mmrazor.models.mutators import DiffModuleMutator
|
||||
from mmrazor.models.utils import add_prefix
|
||||
from mmrazor.registry import MODEL_WRAPPERS, MODELS, TASK_UTILS
|
||||
from mmrazor.structures import load_fix_subnet
|
||||
from mmrazor.utils import FixMutable
|
||||
from ..base import BaseAlgorithm
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class Dsnas(BaseAlgorithm):
|
||||
"""Implementation of `DSNAS <https://arxiv.org/abs/2002.09128>`_
|
||||
|
||||
Args:
|
||||
architecture (dict|:obj:`BaseModel`): The config of :class:`BaseModel`
|
||||
or built model. Corresponding to supernet in NAS algorithm.
|
||||
mutator (dict|:obj:`DiffModuleMutator`): The config of
|
||||
:class:`DiffModuleMutator` or built mutator.
|
||||
fix_subnet (str | dict | :obj:`FixSubnet`): The path of yaml file or
|
||||
loaded dict or built :obj:`FixSubnet`.
|
||||
pretrain_epochs (int): Num of epochs for supernet pretraining.
|
||||
finetune_epochs (int): Num of epochs for subnet finetuning.
|
||||
flops_constraints (float): Flops constraints for judging whether to
|
||||
back-propagate the flops loss or not. Defaults to 300.0 (M).
|
||||
estimator_cfg (Dict[str, Any]): Used for building a resource estimator.
|
||||
Default to None.
|
||||
norm_training (bool): Whether to set norm layers to training mode,
|
||||
namely, not freeze running stats (mean and var). Note: Effect on
|
||||
Batch Norm and its variants only. Defaults to False.
|
||||
data_preprocessor (dict, optional): The pre-process config of
|
||||
:class:`BaseDataPreprocessor`. Defaults to None.
|
||||
init_cfg (dict): Init config for ``BaseModule``.
|
||||
|
||||
Note:
|
||||
Dsnas doesn't require retraining. It has 3 stages in searching:
|
||||
1. `cur_epoch` < `pretrain_epochs` refers to supernet pretraining.
|
||||
2. `pretrain_epochs` <= `cur_epoch` < `finetune_epochs` refers to
|
||||
normal supernet training while mutator is updated.
|
||||
3. `cur_epoch` >= `finetune_epochs` refers to subnet finetuning.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
architecture: Union[BaseModel, Dict],
|
||||
mutator: Optional[Union[DiffModuleMutator, Dict]] = None,
|
||||
fix_subnet: Optional[FixMutable] = None,
|
||||
pretrain_epochs: int = 0,
|
||||
finetune_epochs: int = 80,
|
||||
flops_constraints: float = 300.0,
|
||||
estimator_cfg: Dict[str, Any] = None,
|
||||
norm_training: bool = False,
|
||||
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
|
||||
init_cfg: Optional[dict] = None,
|
||||
**kwargs):
|
||||
super().__init__(architecture, data_preprocessor, **kwargs)
|
||||
|
||||
if estimator_cfg is None:
|
||||
estimator_cfg = dict(type='mmrazor.ResourceEstimator')
|
||||
self.estimator = TASK_UTILS.build(estimator_cfg)
|
||||
if fix_subnet:
|
||||
# Avoid circular import
|
||||
from mmrazor.structures import load_fix_subnet
|
||||
|
||||
# According to fix_subnet, delete the unchosen part of supernet
|
||||
load_fix_subnet(self.architecture, fix_subnet)
|
||||
self.is_supernet = False
|
||||
else:
|
||||
assert mutator is not None, \
|
||||
'mutator cannot be None when fix_subnet is None.'
|
||||
if isinstance(mutator, DiffModuleMutator):
|
||||
self.mutator = mutator
|
||||
elif isinstance(mutator, dict):
|
||||
self.mutator = MODELS.build(mutator)
|
||||
else:
|
||||
raise TypeError('mutator should be a `dict` or '
|
||||
f'`DiffModuleMutator` instance, but got '
|
||||
f'{type(mutator)}')
|
||||
|
||||
self.mutable_module_resources = self._get_module_resources()
|
||||
# Mutator is an essential component of the NAS algorithm. It
|
||||
# provides some APIs commonly used by NAS.
|
||||
# Before using it, you must do some preparations according to
|
||||
# the supernet.
|
||||
self.mutator.prepare_from_supernet(self.architecture)
|
||||
self.is_supernet = True
|
||||
self.search_space_name_list = list(
|
||||
self.mutator.name2mutable.keys())
|
||||
|
||||
self.norm_training = norm_training
|
||||
self.pretrain_epochs = pretrain_epochs
|
||||
self.finetune_epochs = finetune_epochs
|
||||
if pretrain_epochs >= finetune_epochs:
|
||||
raise ValueError(f'Pretrain stage (optional) must be done before '
|
||||
f'finetuning stage. Got `{pretrain_epochs}` >= '
|
||||
f'`{finetune_epochs}`.')
|
||||
|
||||
self.flops_loss_coef = 1e-2
|
||||
self.flops_constraints = flops_constraints
|
||||
_, self.world_size = get_dist_info()
|
||||
|
||||
def search_subnet(self):
|
||||
"""Search subnet by mutator."""
|
||||
|
||||
# Avoid circular import
|
||||
from mmrazor.structures import export_fix_subnet
|
||||
|
||||
subnet = self.mutator.sample_choices()
|
||||
self.mutator.set_choices(subnet)
|
||||
return export_fix_subnet(self)
|
||||
|
||||
def fix_subnet(self):
|
||||
"""Fix subnet when finetuning."""
|
||||
subnet = self.mutator.sample_choices()
|
||||
self.mutator.set_choices(subnet)
|
||||
for module in self.architecture.modules():
|
||||
if isinstance(module, BaseMutable):
|
||||
if not module.is_fixed:
|
||||
module.fix_chosen(module.current_choice)
|
||||
self.is_supernet = False
|
||||
|
||||
def train(self, mode=True):
|
||||
"""Convert the model into eval mode while keep normalization layer
|
||||
unfreezed."""
|
||||
|
||||
super().train(mode)
|
||||
if self.norm_training and not mode:
|
||||
for module in self.architecture.modules():
|
||||
if isinstance(module, _BatchNorm):
|
||||
module.training = True
|
||||
|
||||
def train_step(self, data: List[dict],
|
||||
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
|
||||
"""The iteration step during training.
|
||||
|
||||
Args:
|
||||
data (List[dict]): The output of dataloader.
|
||||
optim_wrapper (:obj:`OptimWrapper` | :obj:`OptimWrapperDict`): The
|
||||
optimizer wrapper of the runner, passed to ``train_step()``.
|
||||
"""
|
||||
if isinstance(optim_wrapper, OptimWrapperDict):
|
||||
log_vars = dict()
|
||||
self.message_hub = MessageHub.get_current_instance()
|
||||
cur_epoch = self.message_hub.get_info('epoch')
|
||||
need_update_mutator = self.need_update_mutator(cur_epoch)
|
||||
|
||||
# TODO process the input
|
||||
if cur_epoch == self.finetune_epochs and self.is_supernet:
|
||||
# synchronize arch params to start the finetune stage.
|
||||
for k, v in self.mutator.arch_params.items():
|
||||
dist.broadcast(v, src=0)
|
||||
self.fix_subnet()
|
||||
|
||||
# 1. update architecture
|
||||
with optim_wrapper['architecture'].optim_context(self):
|
||||
pseudo_data = self.data_preprocessor(data, True)
|
||||
supernet_batch_inputs = pseudo_data['inputs']
|
||||
supernet_data_samples = pseudo_data['data_samples']
|
||||
supernet_loss = self(
|
||||
supernet_batch_inputs, supernet_data_samples, mode='loss')
|
||||
|
||||
supernet_losses, supernet_log_vars = self.parse_losses(
|
||||
supernet_loss)
|
||||
optim_wrapper['architecture'].backward(
|
||||
supernet_losses, retain_graph=need_update_mutator)
|
||||
optim_wrapper['architecture'].step()
|
||||
optim_wrapper['architecture'].zero_grad()
|
||||
log_vars.update(add_prefix(supernet_log_vars, 'supernet'))
|
||||
|
||||
# 2. update mutator
|
||||
if need_update_mutator:
|
||||
with optim_wrapper['mutator'].optim_context(self):
|
||||
mutator_loss = self.compute_mutator_loss()
|
||||
mutator_losses, mutator_log_vars = \
|
||||
self.parse_losses(mutator_loss)
|
||||
optim_wrapper['mutator'].update_params(mutator_losses)
|
||||
log_vars.update(add_prefix(mutator_log_vars, 'mutator'))
|
||||
# handle the grad of arch params & weights
|
||||
self.handle_grads()
|
||||
|
||||
else:
|
||||
# Enable automatic mixed precision training context.
|
||||
with optim_wrapper.optim_context(self):
|
||||
pseudo_data = self.data_preprocessor(data, True)
|
||||
batch_inputs = pseudo_data['inputs']
|
||||
data_samples = pseudo_data['data_samples']
|
||||
losses = self(batch_inputs, data_samples, mode='loss')
|
||||
parsed_losses, log_vars = self.parse_losses(losses)
|
||||
optim_wrapper.update_params(parsed_losses)
|
||||
|
||||
return log_vars
|
||||
|
||||
def _get_module_resources(self):
|
||||
"""Get resources of spec modules."""
|
||||
|
||||
spec_modules = []
|
||||
for name, module in self.architecture.named_modules():
|
||||
if isinstance(module, BaseMutable):
|
||||
for choice in module.choices:
|
||||
spec_modules.append(name + '._candidates.' + choice)
|
||||
|
||||
mutable_module_resources = self.estimator.estimate_separation_modules(
|
||||
self.architecture, dict(spec_modules=spec_modules))
|
||||
|
||||
return mutable_module_resources
|
||||
|
||||
def need_update_mutator(self, cur_epoch: int) -> bool:
|
||||
"""Whether to update mutator."""
|
||||
if cur_epoch >= self.pretrain_epochs and \
|
||||
cur_epoch < self.finetune_epochs:
|
||||
return True
|
||||
return False
|
||||
|
||||
def compute_mutator_loss(self) -> Dict[str, torch.Tensor]:
|
||||
"""Compute mutator loss.
|
||||
|
||||
In this method, arch_loss & flops_loss[optional] are computed
|
||||
by traversing arch_weights & probs in search groups.
|
||||
|
||||
Returns:
|
||||
Dict: Loss of the mutator.
|
||||
"""
|
||||
arch_loss = 0.0
|
||||
flops_loss = 0.0
|
||||
for name, module in self.architecture.named_modules():
|
||||
if isinstance(module, BaseMutable):
|
||||
k = str(self.search_space_name_list.index(name))
|
||||
probs = F.softmax(self.mutator.arch_params[k], -1)
|
||||
arch_loss += torch.log(
|
||||
(module.arch_weights * probs).sum(-1)).sum()
|
||||
|
||||
# get the index of op with max arch weights.
|
||||
index = (module.arch_weights == 1).nonzero().item()
|
||||
_module_key = name + '._candidates.' + module.choices[index]
|
||||
flops_loss += probs[index] * \
|
||||
self.mutable_module_resources[_module_key]['flops']
|
||||
|
||||
mutator_loss = dict(arch_loss=arch_loss / self.world_size)
|
||||
|
||||
copied_model = copy.deepcopy(self)
|
||||
fix_mutable = copied_model.search_subnet()
|
||||
load_fix_subnet(copied_model, fix_mutable)
|
||||
|
||||
subnet_flops = self.estimator.estimate(copied_model)['flops']
|
||||
if subnet_flops >= self.flops_constraints:
|
||||
mutator_loss['flops_loss'] = \
|
||||
(flops_loss * self.flops_loss_coef) / self.world_size
|
||||
|
||||
return mutator_loss
|
||||
|
||||
def handle_grads(self):
|
||||
"""Handle grads of arch params & arch weights."""
|
||||
for name, module in self.architecture.named_modules():
|
||||
if isinstance(module, BaseMutable):
|
||||
k = str(self.search_space_name_list.index(name))
|
||||
self.mutator.arch_params[k].grad.data.mul_(
|
||||
module.arch_weights.grad.data.sum())
|
||||
module.arch_weights.grad.zero_()
|
||||
|
||||
|
||||
@MODEL_WRAPPERS.register_module()
|
||||
class DsnasDDP(MMDistributedDataParallel):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
device_ids: Optional[Union[List, int, torch.device]] = None,
|
||||
**kwargs) -> None:
|
||||
if device_ids is None:
|
||||
if os.environ.get('LOCAL_RANK') is not None:
|
||||
device_ids = [int(os.environ['LOCAL_RANK'])]
|
||||
super().__init__(device_ids=device_ids, **kwargs)
|
||||
|
||||
def train_step(self, data: List[dict],
|
||||
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
|
||||
"""The iteration step during training.
|
||||
|
||||
This method defines an iteration step during training, except for the
|
||||
back propagation and optimizer updating, which are done in an optimizer
|
||||
hook. Note that in some complicated cases or models, the whole process
|
||||
including back propagation and optimizer updating are also defined in
|
||||
this method, such as GAN.
|
||||
"""
|
||||
if isinstance(optim_wrapper, OptimWrapperDict):
|
||||
log_vars = dict()
|
||||
self.message_hub = MessageHub.get_current_instance()
|
||||
cur_epoch = self.message_hub.get_info('epoch')
|
||||
need_update_mutator = self.module.need_update_mutator(cur_epoch)
|
||||
|
||||
# TODO process the input
|
||||
if cur_epoch == self.module.finetune_epochs and \
|
||||
self.module.is_supernet:
|
||||
# synchronize arch params to start the finetune stage.
|
||||
for k, v in self.module.mutator.arch_params.items():
|
||||
dist.broadcast(v, src=0)
|
||||
self.module.fix_subnet()
|
||||
|
||||
# 1. update architecture
|
||||
with optim_wrapper['architecture'].optim_context(self):
|
||||
pseudo_data = self.module.data_preprocessor(data, True)
|
||||
supernet_batch_inputs = pseudo_data['inputs']
|
||||
supernet_data_samples = pseudo_data['data_samples']
|
||||
supernet_loss = self(
|
||||
supernet_batch_inputs, supernet_data_samples, mode='loss')
|
||||
|
||||
supernet_losses, supernet_log_vars = self.module.parse_losses(
|
||||
supernet_loss)
|
||||
optim_wrapper['architecture'].backward(
|
||||
supernet_losses, retain_graph=need_update_mutator)
|
||||
optim_wrapper['architecture'].step()
|
||||
optim_wrapper['architecture'].zero_grad()
|
||||
log_vars.update(add_prefix(supernet_log_vars, 'supernet'))
|
||||
|
||||
# 2. update mutator
|
||||
if need_update_mutator:
|
||||
with optim_wrapper['mutator'].optim_context(self):
|
||||
mutator_loss = self.module.compute_mutator_loss()
|
||||
mutator_losses, mutator_log_vars = \
|
||||
self.module.parse_losses(mutator_loss)
|
||||
optim_wrapper['mutator'].update_params(mutator_losses)
|
||||
log_vars.update(add_prefix(mutator_log_vars, 'mutator'))
|
||||
# handle the grad of arch params & weights
|
||||
self.module.handle_grads()
|
||||
|
||||
else:
|
||||
# Enable automatic mixed precision training context.
|
||||
with optim_wrapper.optim_context(self):
|
||||
pseudo_data = self.module.data_preprocessor(data, True)
|
||||
batch_inputs = pseudo_data['inputs']
|
||||
data_samples = pseudo_data['data_samples']
|
||||
losses = self(batch_inputs, data_samples, mode='loss')
|
||||
parsed_losses, log_vars = self.module.parse_losses(losses)
|
||||
optim_wrapper.update_params(parsed_losses)
|
||||
|
||||
return log_vars
|
|
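The interplay of `compute_mutator_loss` and `handle_grads` above is the DSNAS single-path gradient estimator: the architecture backward (run with `retain_graph=True`) leaves the task-loss gradient on the one-hot `arch_weights`, and the mutator step then scales the gradient of `log p(chosen)` by that quantity. A minimal standalone sketch of the idea, simplified to a single mutable with 4 candidates (everything below is illustrative, not the algorithm's API):

```python
# Sketch (assumption): DSNAS-style arch-param gradient for one mutable.
import torch
import torch.nn.functional as F

alpha = torch.zeros(4, requires_grad=True)           # arch params of one mutable
probs = F.softmax(alpha, -1)
z = torch.distributions.OneHotCategorical(probs=probs).sample()
z.requires_grad_()                                    # will receive d(task_loss)/dz

# In the real algorithm, the supernet forward weights the chosen op by `z` and
# task_loss.backward(retain_graph=True) fills z.grad; a random stand-in here.
z.grad = torch.randn_like(z)

arch_loss = torch.log((z.detach() * probs).sum(-1))   # log-prob of the sampled op
arch_loss.backward()                                  # fills alpha.grad with d(log p)/d(alpha)
alpha.grad.mul_(z.grad.sum())                         # scale by the grad that reached z
```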
@ -3,12 +3,13 @@ from .derived_mutable import DerivedMutable
|
|||
from .mutable_channel import (MutableChannel, OneShotMutableChannel,
|
||||
SlimmableMutableChannel)
|
||||
from .mutable_module import (DiffChoiceRoute, DiffMutableModule, DiffMutableOP,
|
||||
OneShotMutableModule, OneShotMutableOP)
|
||||
OneHotMutableOP, OneShotMutableModule,
|
||||
OneShotMutableOP)
|
||||
from .mutable_value import MutableValue, OneShotMutableValue
|
||||
|
||||
__all__ = [
|
||||
'OneShotMutableOP', 'OneShotMutableModule', 'DiffMutableOP',
|
||||
'DiffChoiceRoute', 'DiffMutableModule', 'OneShotMutableChannel',
|
||||
'SlimmableMutableChannel', 'MutableChannel', 'DerivedMutable',
|
||||
'MutableValue', 'OneShotMutableValue'
|
||||
'MutableValue', 'OneShotMutableValue', 'OneHotMutableOP'
|
||||
]
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .diff_mutable_module import (DiffChoiceRoute, DiffMutableModule,
|
||||
DiffMutableOP)
|
||||
DiffMutableOP, OneHotMutableOP)
|
||||
from .mutable_module import MutableModule
|
||||
from .one_shot_mutable_module import OneShotMutableModule, OneShotMutableOP
|
||||
|
||||
__all__ = [
|
||||
'DiffMutableModule', 'DiffMutableOP', 'DiffChoiceRoute',
|
||||
'OneShotMutableOP', 'OneShotMutableModule', 'MutableModule'
|
||||
'OneShotMutableOP', 'OneShotMutableModule', 'MutableModule',
|
||||
'OneHotMutableOP'
|
||||
]
|
||||
|
|
|
@ -37,8 +37,9 @@ class DiffMutableModule(MutableModule[CHOICE_TYPE, CHOSEN_TYPE]):
|
|||
def forward(self,
|
||||
x: Any,
|
||||
arch_param: Optional[nn.Parameter] = None) -> Any:
|
||||
"""Calls either :func:`forward_fixed` or :func:`forward_choice`
|
||||
depending on whether :func:`is_fixed` is ``True``.
|
||||
"""Calls either :func:`forward_fixed` or :func:`forward_arch_param`
|
||||
depending on whether :func:`is_fixed` is ``True`` and whether
|
||||
:func:`arch_param` is None.
|
||||
|
||||
To reduce the coupling between `Mutable` and `Mutator`, the
|
||||
`arch_param` is generated by the `Mutator` and is passed to the
|
||||
|
@ -52,6 +53,9 @@ class DiffMutableModule(MutableModule[CHOICE_TYPE, CHOSEN_TYPE]):
|
|||
x (Any): input data for forward computation.
|
||||
arch_param (nn.Parameter, optional): the architecture parameters
|
||||
for ``DiffMutableModule``.
|
||||
|
||||
Returns:
|
||||
Any: the result of forward
|
||||
"""
|
||||
if self.is_fixed:
|
||||
return self.forward_fixed(x)
|
||||
|
@ -97,6 +101,10 @@ class DiffMutableOP(DiffMutableModule[str, str]):
|
|||
Args:
|
||||
candidates (dict[str, dict]): the configs for the candidate
|
||||
operations.
|
||||
fix_threshold (float): The threshold that determines whether to fix
|
||||
the choice of current module as the op with the maximum `probs`.
|
||||
It happens when the maximum prob is at least `fix_threshold` higher
|
||||
than all the other probs. Defaults to 1.0.
|
||||
module_kwargs (dict[str, dict], optional): Module initialization named
|
||||
arguments. Defaults to None.
|
||||
alias (str, optional): alias of the `MUTABLE`.
|
||||
|
@ -109,6 +117,7 @@ class DiffMutableOP(DiffMutableModule[str, str]):
|
|||
def __init__(
|
||||
self,
|
||||
candidates: Dict[str, Dict],
|
||||
fix_threshold: float = 1.0,
|
||||
module_kwargs: Optional[Dict[str, Dict]] = None,
|
||||
alias: Optional[str] = None,
|
||||
init_cfg: Optional[Dict] = None,
|
||||
|
@ -120,6 +129,10 @@ class DiffMutableOP(DiffMutableModule[str, str]):
|
|||
f'but got: {len(candidates)}'
|
||||
|
||||
self._is_fixed = False
|
||||
if fix_threshold < 0 or fix_threshold > 1.0:
|
||||
raise ValueError(
|
||||
f'The fix_threshold should be in [0, 1]. Got {fix_threshold}.')
|
||||
self.fix_threshold = fix_threshold
|
||||
self._candidates = self._build_ops(candidates, self.module_kwargs)
|
||||
|
||||
@staticmethod
|
||||
|
@ -242,6 +255,94 @@ class DiffMutableOP(DiffMutableModule[str, str]):
|
|||
return list(self._candidates.keys())
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class OneHotMutableOP(DiffMutableOP):
|
||||
"""A type of ``MUTABLES`` for one-hot sample based architecture search,
|
||||
such as DSNAS. Search the best module by learnable parameters `arch_param`.
|
||||
|
||||
Args:
|
||||
candidates (dict[str, dict]): the configs for the candidate
|
||||
operations.
|
||||
module_kwargs (dict[str, dict], optional): Module initialization named
|
||||
arguments. Defaults to None.
|
||||
alias (str, optional): alias of the `MUTABLE`.
|
||||
init_cfg (dict, optional): initialization configuration dict for
|
||||
``BaseModule``. OpenMMLab has implemented 6 initializers, including
|
||||
`Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
|
||||
and `Pretrained`.
|
||||
"""
|
||||
|
||||
def sample_weights(self,
|
||||
arch_param: nn.Parameter,
|
||||
probs: torch.Tensor,
|
||||
random_sample: bool = False) -> Tensor:
|
||||
"""Use one-hot distributions to sample the arch weights based on the
|
||||
arch params.
|
||||
|
||||
Args:
|
||||
arch_param (nn.Parameter): architecture parameters for
|
||||
`DiffMutableModule`.
|
||||
probs (Tensor): the probs of choice.
|
||||
random_sample (bool): Whether to randomly sample arch weights or not.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
Tensor: Sampled one-hot arch weights.
|
||||
"""
|
||||
import torch.distributions as D
|
||||
if random_sample:
|
||||
uni = torch.ones_like(arch_param)
|
||||
m = D.one_hot_categorical.OneHotCategorical(uni)
|
||||
else:
|
||||
m = D.one_hot_categorical.OneHotCategorical(probs=probs)
|
||||
return m.sample()
|
||||
|
||||
def forward_arch_param(self,
|
||||
x: Any,
|
||||
arch_param: Optional[nn.Parameter] = None
|
||||
) -> Tensor:
|
||||
"""Forward with architecture parameters.
|
||||
|
||||
Args:
|
||||
x (Any): x could be a torch.Tensor or a tuple of
|
||||
torch.Tensor, containing input data for forward computation.
|
||||
arch_param (str, optional): architecture parameters for
|
||||
`DiffMutableModule`.
|
||||
|
||||
Returns:
|
||||
Tensor: the result of forward with ``arch_param``.
|
||||
"""
|
||||
if arch_param is None:
|
||||
return self.forward_all(x)
|
||||
else:
|
||||
# compute the probs of choice
|
||||
probs = self.compute_arch_probs(arch_param=arch_param)
|
||||
|
||||
if not self.is_fixed:
|
||||
self.arch_weights = self.sample_weights(arch_param, probs)
|
||||
sorted_param = torch.topk(probs, 2)
|
||||
index = (
|
||||
sorted_param[0][0] - sorted_param[0][1] >=
|
||||
self.fix_threshold)
|
||||
if index:
|
||||
self.fix_chosen(self.choices[index])
|
||||
|
||||
if self.is_fixed:
|
||||
index = self.choices.index(self._chosen[0])
|
||||
self.arch_weights.data.zero_()
|
||||
self.arch_weights.data[index].fill_(1.0)
|
||||
self.arch_weights.requires_grad_()
|
||||
|
||||
# forward based on self.arch_weights
|
||||
outputs = list()
|
||||
for prob, module in zip(self.arch_weights,
|
||||
self._candidates.values()):
|
||||
if prob > 0.:
|
||||
outputs.append(prob * module(x))
|
||||
|
||||
return sum(outputs)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DiffChoiceRoute(DiffMutableModule[str, List[str]]):
|
||||
"""A type of ``MUTABLES`` for Neural Architecture Search, which can select
|
||||
|
|
|
@ -88,8 +88,8 @@ class DiffModuleMutator(ModuleMutator):
|
|||
|
||||
choices = dict()
|
||||
for group_id, mutables in self.search_groups.items():
|
||||
arch_parm = self.arch_params[str(group_id)]
|
||||
choice = mutables[0].sample_choice(arch_parm)
|
||||
arch_param = self.arch_params[str(group_id)]
|
||||
choice = mutables[0].sample_choice(arch_param)
|
||||
choices[group_id] = choice
|
||||
return choices
|
||||
|
||||
|
|
|
@ -54,6 +54,24 @@ class ModuleMutator(BaseMutator[MUTABLE_TYPE]):
|
|||
"""
|
||||
self._build_search_groups(supernet)
|
||||
|
||||
@property
|
||||
def name2mutable(self) -> Dict[str, MUTABLE_TYPE]:
|
||||
"""Search space of supernet.
|
||||
|
||||
Note:
|
||||
To get the mapping: module name to mutable.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Called before search space has been parsed.
|
||||
|
||||
Returns:
|
||||
Dict[str, MUTABLE_TYPE]: The name2mutable dict.
|
||||
"""
|
||||
if self._name2mutable is None:
|
||||
raise RuntimeError(
|
||||
'Call `prepare_from_supernet` before access name2mutable!')
|
||||
return self._name2mutable
|
||||
|
||||
@property
|
||||
def search_groups(self) -> Dict[int, List[MUTABLE_TYPE]]:
|
||||
"""Search group of supernet.
|
||||
|
@ -80,6 +98,8 @@ class ModuleMutator(BaseMutator[MUTABLE_TYPE]):
|
|||
for name, module in supernet.named_modules():
|
||||
if isinstance(module, self.mutable_class_type):
|
||||
name2mutable[name] = module
|
||||
self._name2mutable = name2mutable
|
||||
|
||||
return name2mutable
|
||||
|
||||
def _build_alias_names_mapping(self,
|
||||
|
@ -121,7 +141,7 @@ class ModuleMutator(BaseMutator[MUTABLE_TYPE]):
|
|||
>>> import torch
|
||||
>>> from mmrazor.models.mutables.diff_mutable import DiffMutableOP
|
||||
|
||||
>>> # Assume that a toy model consists of three mutabels
|
||||
>>> # Assume that a toy model consists of three mutables
|
||||
>>> # whose name are op1,op2,op3. The corresponding
|
||||
>>> # alias names of the three mutables are a1, a1, a2.
|
||||
>>> model = ToyModel()
|
||||
|
|
|
@ -50,21 +50,22 @@ def load_fix_subnet(model: nn.Module,
|
|||
# In the corresponding mutable, it will check whether the `chosen`
|
||||
# format is correct.
|
||||
if isinstance(module, BaseMutable):
|
||||
if getattr(module, 'alias', None):
|
||||
alias = module.alias
|
||||
assert alias in fix_mutable, \
|
||||
f'The alias {alias} is not in fix_modules, ' \
|
||||
'please check your `fix_mutable`.'
|
||||
chosen = fix_mutable.get(alias, None)
|
||||
else:
|
||||
mutable_name = name.lstrip(prefix)
|
||||
if mutable_name not in fix_mutable and \
|
||||
not isinstance(module, DerivedMutable):
|
||||
raise RuntimeError(
|
||||
f'The module name {mutable_name} is not in '
|
||||
'fix_mutable, please check your `fix_mutable`.')
|
||||
chosen = fix_mutable.get(mutable_name, None)
|
||||
module.fix_chosen(chosen)
|
||||
if not module.is_fixed:
|
||||
if getattr(module, 'alias', None):
|
||||
alias = module.alias
|
||||
assert alias in fix_mutable, \
|
||||
f'The alias {alias} is not in fix_modules, ' \
|
||||
'please check your `fix_mutable`.'
|
||||
chosen = fix_mutable.get(alias, None)
|
||||
else:
|
||||
mutable_name = name.lstrip(prefix)
|
||||
if mutable_name not in fix_mutable and \
|
||||
not isinstance(module, DerivedMutable):
|
||||
raise RuntimeError(
|
||||
f'The module name {mutable_name} is not in '
|
||||
'fix_mutable, please check your `fix_mutable`.')
|
||||
chosen = fix_mutable.get(mutable_name, None)
|
||||
module.fix_chosen(chosen)
|
||||
|
||||
# convert dynamic op to static op
|
||||
_dynamic_to_static(model)
|
||||
|
@ -89,7 +90,6 @@ def export_fix_subnet(model: nn.Module,
|
|||
if isinstance(module, DerivedMutable) and not dump_derived_mutable:
|
||||
continue
|
||||
|
||||
assert not module.is_fixed
|
||||
if module.alias:
|
||||
fix_subnet[module.alias] = module.dump_chosen()
|
||||
else:
|
||||
|
|
|
@ -0,0 +1,222 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from mmcls.structures import ClsDataSample
|
||||
from mmengine.model import BaseModel
|
||||
from mmengine.optim import build_optim_wrapper
|
||||
from mmengine.optim.optimizer import OptimWrapper, OptimWrapperDict
|
||||
from torch import Tensor
|
||||
from torch.optim import SGD
|
||||
|
||||
from mmrazor.models import DiffModuleMutator, Dsnas, OneHotMutableOP
|
||||
from mmrazor.models.algorithms.nas.dsnas import DsnasDDP
|
||||
from mmrazor.registry import MODELS
|
||||
|
||||
MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)
|
||||
MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True)
|
||||
MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ToyDiffModule(BaseModel):
|
||||
|
||||
def __init__(self, data_preprocessor=None):
|
||||
super().__init__(data_preprocessor=data_preprocessor, init_cfg=None)
|
||||
self.candidates = dict(
|
||||
torch_conv2d_3x3=dict(
|
||||
type='torchConv2d',
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
),
|
||||
torch_conv2d_5x5=dict(
|
||||
type='torchConv2d',
|
||||
kernel_size=5,
|
||||
padding=2,
|
||||
),
|
||||
torch_conv2d_7x7=dict(
|
||||
type='torchConv2d',
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
),
|
||||
)
|
||||
module_kwargs = dict(in_channels=3, out_channels=8, stride=1)
|
||||
|
||||
self.mutable = OneHotMutableOP(
|
||||
candidates=self.candidates, module_kwargs=module_kwargs)
|
||||
self.bn = nn.BatchNorm2d(8)
|
||||
|
||||
def forward(self, batch_inputs, data_samples=None, mode='tensor'):
|
||||
if mode == 'loss':
|
||||
out = self.bn(self.mutable(batch_inputs))
|
||||
return dict(loss=out)
|
||||
elif mode == 'predict':
|
||||
out = self.bn(self.mutable(batch_inputs)) + 1
|
||||
return out
|
||||
elif mode == 'tensor':
|
||||
out = self.bn(self.mutable(batch_inputs)) + 2
|
||||
return out
|
||||
|
||||
|
||||
class TestDsnas(TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.device: str = 'cpu'
|
||||
|
||||
OPTIMIZER_CFG = dict(
|
||||
type='SGD',
|
||||
lr=0.5,
|
||||
momentum=0.9,
|
||||
nesterov=True,
|
||||
weight_decay=0.0001)
|
||||
|
||||
self.OPTIM_WRAPPER_CFG = dict(optimizer=OPTIMIZER_CFG)
|
||||
|
||||
def test_init(self) -> None:
|
||||
# initiate dsnas when `norm_training` is True.
|
||||
model = ToyDiffModule()
|
||||
mutator = DiffModuleMutator()
|
||||
algo = Dsnas(architecture=model, mutator=mutator, norm_training=True)
|
||||
algo.eval()
|
||||
self.assertTrue(model.bn.training)
|
||||
|
||||
# initiate Dsnas with built mutator
|
||||
model = ToyDiffModule()
|
||||
mutator = DiffModuleMutator()
|
||||
algo = Dsnas(model, mutator)
|
||||
self.assertIs(algo.mutator, mutator)
|
||||
|
||||
# initiate Dsnas with unbuilt mutator
|
||||
mutator = dict(type='DiffModuleMutator')
|
||||
algo = Dsnas(model, mutator)
|
||||
self.assertIsInstance(algo.mutator, DiffModuleMutator)
|
||||
|
||||
# initiate Dsnas when `fix_subnet` is not None
|
||||
fix_subnet = {'mutable': 'torch_conv2d_5x5'}
|
||||
algo = Dsnas(model, mutator, fix_subnet=fix_subnet)
|
||||
self.assertEqual(algo.architecture.mutable.num_choices, 1)
|
||||
|
||||
# initiate Dsnas with error type `mutator`
|
||||
with self.assertRaisesRegex(TypeError, 'mutator should be'):
|
||||
Dsnas(model, model)
|
||||
|
||||
def test_forward_loss(self) -> None:
|
||||
inputs = torch.randn(1, 3, 8, 8)
|
||||
model = ToyDiffModule()
|
||||
|
||||
# supernet
|
||||
mutator = DiffModuleMutator()
|
||||
mutator.prepare_from_supernet(model)
|
||||
algo = Dsnas(model, mutator)
|
||||
loss = algo(inputs, mode='loss')
|
||||
self.assertIsInstance(loss, dict)
|
||||
|
||||
# subnet
|
||||
fix_subnet = {'mutable': 'torch_conv2d_5x5'}
|
||||
algo = Dsnas(model, fix_subnet=fix_subnet)
|
||||
loss = algo(inputs, mode='loss')
|
||||
self.assertIsInstance(loss, dict)
|
||||
|
||||
def _prepare_fake_data(self):
|
||||
imgs = torch.randn(16, 3, 224, 224).to(self.device)
|
||||
data_samples = [
|
||||
ClsDataSample().set_gt_label(torch.randint(0, 1000,
|
||||
(16, ))).to(self.device)
|
||||
]
|
||||
return {'inputs': imgs, 'data_samples': data_samples}
|
||||
|
||||
def test_search_subnet(self) -> None:
|
||||
model = ToyDiffModule()
|
||||
|
||||
mutator = DiffModuleMutator()
|
||||
mutator.prepare_from_supernet(model)
|
||||
algo = Dsnas(model, mutator)
|
||||
subnet = algo.search_subnet()
|
||||
self.assertIsInstance(subnet, dict)
|
||||
|
||||
@patch('mmengine.logging.message_hub.MessageHub.get_info')
|
||||
def test_dsnas_train_step(self, mock_get_info) -> None:
|
||||
model = ToyDiffModule()
|
||||
mutator = DiffModuleMutator()
|
||||
mutator.prepare_from_supernet(model)
|
||||
mock_get_info.return_value = 2
|
||||
|
||||
algo = Dsnas(model, mutator)
|
||||
data = self._prepare_fake_data()
|
||||
optim_wrapper = build_optim_wrapper(algo, self.OPTIM_WRAPPER_CFG)
|
||||
loss = algo.train_step(data, optim_wrapper)
|
||||
|
||||
self.assertTrue(isinstance(loss['loss'], Tensor))
|
||||
|
||||
algo = Dsnas(model, mutator)
|
||||
optim_wrapper_dict = OptimWrapperDict(
|
||||
architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)),
|
||||
mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
|
||||
loss = algo.train_step(data, optim_wrapper_dict)
|
||||
|
||||
self.assertIsNotNone(loss)
|
||||
|
||||
|
||||
class TestDsnasDDP(TestDsnas):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
os.environ['MASTER_ADDR'] = 'localhost'
|
||||
os.environ['MASTER_PORT'] = '12345'
|
||||
|
||||
# initialize the process group
|
||||
if torch.cuda.is_available():
|
||||
backend = 'nccl'
|
||||
cls.device = 'cuda'
|
||||
else:
|
||||
backend = 'gloo'
|
||||
dist.init_process_group(backend, rank=0, world_size=1)
|
||||
|
||||
def prepare_model(self, device_ids=None) -> Dsnas:
|
||||
model = ToyDiffModule().to(self.device)
|
||||
mutator = DiffModuleMutator().to(self.device)
|
||||
mutator.prepare_from_supernet(model)
|
||||
|
||||
algo = Dsnas(model, mutator)
|
||||
|
||||
return DsnasDDP(
|
||||
module=algo, find_unused_parameters=True, device_ids=device_ids)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
dist.destroy_process_group()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='cuda device is not available')
|
||||
def test_init(self) -> None:
|
||||
ddp_model = self.prepare_model()
|
||||
self.assertIsInstance(ddp_model, DsnasDDP)
|
||||
|
||||
@patch('mmengine.logging.message_hub.MessageHub.get_info')
|
||||
def test_dsnasddp_train_step(self, mock_get_info) -> None:
|
||||
model = ToyDiffModule()
|
||||
mutator = DiffModuleMutator()
|
||||
mutator.prepare_from_supernet(model)
|
||||
mock_get_info.return_value = 2
|
||||
|
||||
algo = Dsnas(model, mutator)
|
||||
ddp_model = DsnasDDP(module=algo, find_unused_parameters=True)
|
||||
data = self._prepare_fake_data()
|
||||
optim_wrapper = build_optim_wrapper(ddp_model, self.OPTIM_WRAPPER_CFG)
|
||||
loss = ddp_model.train_step(data, optim_wrapper)
|
||||
|
||||
self.assertIsNotNone(loss)
|
||||
|
||||
algo = Dsnas(model, mutator)
|
||||
ddp_model = DsnasDDP(module=algo, find_unused_parameters=True)
|
||||
optim_wrapper_dict = OptimWrapperDict(
|
||||
architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)),
|
||||
mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
|
||||
loss = ddp_model.train_step(data, optim_wrapper_dict)
|
||||
|
||||
self.assertIsNotNone(loss)
|
|
@ -0,0 +1,203 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmrazor.models import * # noqa:F403,F401
|
||||
from mmrazor.registry import MODELS
|
||||
|
||||
MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)
|
||||
MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True)
|
||||
MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True)
|
||||
|
||||
|
||||
class TestOneHotOP(TestCase):
|
||||
|
||||
def test_forward_arch_param(self):
|
||||
op_cfg = dict(
|
||||
type='mmrazor.OneHotMutableOP',
|
||||
candidates=dict(
|
||||
torch_conv2d_3x3=dict(
|
||||
type='torchConv2d',
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
),
|
||||
torch_conv2d_5x5=dict(
|
||||
type='torchConv2d',
|
||||
kernel_size=5,
|
||||
padding=2,
|
||||
),
|
||||
torch_conv2d_7x7=dict(
|
||||
type='torchConv2d',
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
),
|
||||
),
|
||||
module_kwargs=dict(in_channels=32, out_channels=32, stride=1))
|
||||
|
||||
op = MODELS.build(op_cfg)
|
||||
input = torch.randn(4, 32, 64, 64)
|
||||
|
||||
arch_param = nn.Parameter(torch.randn(len(op_cfg['candidates'])))
|
||||
output = op.forward_arch_param(input, arch_param=arch_param)
|
||||
assert output is not None
|
||||
|
||||
output = op.forward_arch_param(input, arch_param=None)
|
||||
assert output is not None
|
||||
|
||||
# test when some element of arch_param is 0
|
||||
arch_param = nn.Parameter(torch.ones(op.num_choices))
|
||||
output = op.forward_arch_param(input, arch_param=arch_param)
|
||||
assert output is not None
|
||||
|
||||
def test_forward_fixed(self):
|
||||
op_cfg = dict(
|
||||
type='mmrazor.OneHotMutableOP',
|
||||
candidates=dict(
|
||||
torch_conv2d_3x3=dict(
|
||||
type='torchConv2d',
|
||||
kernel_size=3,
|
||||
),
|
||||
torch_conv2d_5x5=dict(
|
||||
type='torchConv2d',
|
||||
kernel_size=5,
|
||||
),
|
||||
torch_conv2d_7x7=dict(
|
||||
type='torchConv2d',
|
||||
kernel_size=7,
|
||||
),
|
||||
),
|
||||
module_kwargs=dict(in_channels=32, out_channels=32, stride=1))
|
||||
|
||||
op = MODELS.build(op_cfg)
|
||||
input = torch.randn(4, 32, 64, 64)
|
||||
|
||||
op.fix_chosen('torch_conv2d_7x7')
|
||||
output = op.forward_fixed(input)
|
||||
|
||||
assert output is not None
|
||||
assert op.is_fixed is True
|
||||
|
||||
def test_forward(self):
|
||||
op_cfg = dict(
|
||||
type='mmrazor.OneHotMutableOP',
|
||||
candidates=dict(
|
||||
torch_conv2d_3x3=dict(
|
||||
type='torchConv2d',
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
),
|
||||
torch_conv2d_5x5=dict(
|
||||
type='torchConv2d',
|
||||
kernel_size=5,
|
||||
padding=2,
|
||||
),
|
||||
torch_conv2d_7x7=dict(
|
||||
type='torchConv2d',
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
),
|
||||
),
|
||||
module_kwargs=dict(in_channels=32, out_channels=32, stride=1))
|
||||
|
||||
op = MODELS.build(op_cfg)
|
||||
input = torch.randn(4, 32, 64, 64)
|
||||
|
||||
# test set_forward_args
|
||||
arch_param = nn.Parameter(torch.randn(len(op_cfg['candidates'])))
|
||||
op.set_forward_args(arch_param=arch_param)
|
||||
output = op.forward(input)
|
||||
assert output is not None
|
||||
|
||||
# test dump_chosen
|
||||
with pytest.raises(AssertionError):
|
||||
op.dump_chosen()
|
||||
|
||||
# test forward when is_fixed is True
|
||||
op.fix_chosen('torch_conv2d_7x7')
|
||||
output = op.forward(input)
|
||||
|
||||
def test_property(self):
|
||||
op_cfg = dict(
|
||||
type='mmrazor.OneHotMutableOP',
|
||||
candidates=dict(
|
||||
torch_conv2d_3x3=dict(
|
||||
type='torchConv2d',
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
),
|
||||
torch_conv2d_5x5=dict(
|
||||
type='torchConv2d',
|
||||
kernel_size=5,
|
||||
padding=2,
|
||||
),
|
||||
torch_conv2d_7x7=dict(
|
||||
type='torchConv2d',
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
),
|
||||
),
|
||||
module_kwargs=dict(in_channels=32, out_channels=32, stride=1))
|
||||
|
||||
op = MODELS.build(op_cfg)
|
||||
|
||||
assert len(op.choices) == 3
|
||||
|
||||
# test is_fixed property
|
||||
assert op.is_fixed is False
|
||||
|
||||
# test is_fixed setting
|
||||
op.fix_chosen('torch_conv2d_5x5')
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
op.is_fixed = True
|
||||
|
||||
# test fix choice when is_fixed is True
|
||||
with pytest.raises(AttributeError):
|
||||
op.fix_chosen('torch_conv2d_3x3')
|
||||
|
||||
def test_module_kwargs(self):
|
||||
op_cfg = dict(
|
||||
type='mmrazor.OneHotMutableOP',
|
||||
candidates=dict(
|
||||
torch_conv2d_3x3=dict(
|
||||
type='torchConv2d',
|
||||
kernel_size=3,
|
||||
in_channels=32,
|
||||
out_channels=32,
|
||||
stride=1,
|
||||
),
|
||||
torch_conv2d_5x5=dict(
|
||||
type='torchConv2d',
|
||||
kernel_size=5,
|
||||
in_channels=32,
|
||||
out_channels=32,
|
||||
stride=1,
|
||||
),
|
||||
torch_conv2d_7x7=dict(
|
||||
type='torchConv2d',
|
||||
kernel_size=7,
|
||||
in_channels=32,
|
||||
out_channels=32,
|
||||
stride=1,
|
||||
),
|
||||
torch_maxpool_3x3=dict(
|
||||
type='torchMaxPool2d',
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
),
|
||||
torch_avgpool_3x3=dict(
|
||||
type='torchAvgPool2d',
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
),
|
||||
),
|
||||
)
|
||||
op = MODELS.build(op_cfg)
|
||||
input = torch.randn(4, 32, 64, 64)
|
||||
|
||||
op.fix_chosen('torch_avgpool_3x3')
|
||||
output = op.forward(input)
|
||||
assert output is not None
|