[Feature] Add Autoformer algorithm ()

* update candidates

* update subnet_sampler_loop

* update candidate

* add readme

* rename variable

* rename variable

* clean

* update

* add doc string

* Revert "[Improvement] Support for candidate multiple dimensional search constraints."

* [Improvement] Update Candidate with multi-dim search constraints. ()

* update doc

* add support type

* clean code

* update candidates

* clean

* xx

* set_resource -> set_score

* fix ci bug

* py36 lint

* fix bug

* fix check constraint

* py36 ci

* redesign candidate

* fix pre-commit

* update cfg

* add build_resource_estimator

* fix ci bug

* remove runner.epoch in testcase

* [Feature] Autoformer architecture and dynamicOPs ()

* add DynamicSequential

* dynamiclayernorm

* add dynamic_patchembed

* add DynamicMultiheadAttention and DynamicRelativePosition2D

* add channel-level dynamicOP

* add autoformer algo

* clean notes

* adapt channel_mutator

* vit fly

* fix import

* mutable init

* remove annotation

* add DynamicInputResizer

* add unittest for mutables

* add OneShotMutableChannelUnit_VIT

* clean code

* reset unit for vit

* remove attr

* add autoformer backbone UT

* add valuemutator UT

* clean code

* add autoformer algo UT

* update classifier UT

* fix test error

* ignore

* make lint

* update

* fix lint

* mutable_attrs

* fix test

* fix error

* remove DynamicInputResizer

* fix test ci

* remove InputResizer

* rename variables

* modify type

* Continued improvements of ChannelUnit

* fix lint

* fix lint

* remove OneShotMutableChannelUnit

* adjust derived type

* combination mixins

* clean code

* fix sample subnet

* search loop fly

* more annotations

* avoid counter warning and modify batch_augment cfg by gy

* restore

* source_value_mutables restriction

* simply arch_setting api

* update

* clean

* fix ut
Yue Sun 2022-11-14 13:01:04 +08:00 committed by GitHub
parent 9c567e4d40
commit fb42405af8
68 changed files with 3598 additions and 260 deletions


@ -0,0 +1,180 @@
# dataset settings
dataset_type = 'mmcls.ImageNet'
preprocess_cfg = dict(
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
bgr_mean = preprocess_cfg['mean'][::-1]
bgr_std = preprocess_cfg['std'][::-1]
# Refers to `_RAND_INCREASING_TRANSFORMS` in pytorch-image-models
rand_increasing_policies = [
dict(type='mmcls.AutoContrast'),
dict(type='mmcls.Equalize'),
dict(type='mmcls.Invert'),
dict(type='mmcls.Rotate', magnitude_key='angle', magnitude_range=(0, 30)),
dict(type='mmcls.Posterize', magnitude_key='bits', magnitude_range=(4, 0)),
dict(type='mmcls.Solarize', magnitude_key='thr', magnitude_range=(256, 0)),
dict(
type='mmcls.SolarizeAdd',
magnitude_key='magnitude',
magnitude_range=(0, 110)),
dict(
type='mmcls.ColorTransform',
magnitude_key='magnitude',
magnitude_range=(0, 0.9)),
dict(
type='mmcls.Contrast',
magnitude_key='magnitude',
magnitude_range=(0, 0.9)),
dict(
type='mmcls.Brightness',
magnitude_key='magnitude',
magnitude_range=(0, 0.9)),
dict(
type='mmcls.Sharpness',
magnitude_key='magnitude',
magnitude_range=(0, 0.9)),
dict(
type='mmcls.Shear',
magnitude_key='magnitude',
magnitude_range=(0, 0.3),
direction='horizontal'),
dict(
type='mmcls.Shear',
magnitude_key='magnitude',
magnitude_range=(0, 0.3),
direction='vertical'),
dict(
type='mmcls.Translate',
magnitude_key='magnitude',
magnitude_range=(0, 0.45),
direction='horizontal'),
dict(
type='mmcls.Translate',
magnitude_key='magnitude',
magnitude_range=(0, 0.45),
direction='vertical')
]
train_pipeline = [
dict(type='mmcls.LoadImageFromFile'),
dict(
type='mmcls.RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='mmcls.RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='mmcls.RandAugment',
policies=rand_increasing_policies,
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
dict(
type='mmcls.RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='mmcls.PackClsInputs'),
]
test_pipeline = [
dict(type='mmcls.LoadImageFromFile'),
dict(
type='mmcls.ResizeEdge',
scale=248,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='mmcls.CenterCrop', crop_size=224),
dict(type='mmcls.PackClsInputs')
]
train_dataloader = dict(
batch_size=64,
num_workers=6,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
pipeline=train_pipeline),
sampler=dict(type='mmcls.RepeatAugSampler'),
persistent_workers=True,
)
val_dataloader = dict(
batch_size=256,
num_workers=6,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
pipeline=test_pipeline),
sampler=dict(type='mmcls.DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='mmcls.Accuracy', topk=(1, 5))
# If you want a standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
# optimizer
paramwise_cfg = dict(
bias_decay_mult=0.0, norm_decay_mult=0.0, dwconv_decay_mult=0.0)
optim_wrapper = dict(
optimizer=dict(
type='AdamW',
lr=0.002,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999)),
# specific to vit pretrain
paramwise_cfg=dict(custom_keys={
'.cls_token': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0)
}))
# learning policy
param_scheduler = [
# warm up learning rate scheduler
dict(
type='LinearLR',
start_factor=1e-3,
by_epoch=True,
begin=0,
# about 10000 iterations for ImageNet-1k
end=20,
# update by iter
convert_to_iter_based=True),
# main learning rate scheduler
dict(
type='CosineAnnealingLR',
T_max=500,
eta_min=1e-5,
by_epoch=True,
begin=20,
end=500,
convert_to_iter_based=True),
]
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=500)
val_cfg = dict()
test_cfg = dict()
auto_scale_lr = dict(base_batch_size=2048)


@ -0,0 +1,66 @@
# AutoFormer
> [Searching Transformers for Visual Recognition](https://arxiv.org/abs/2107.00651)
<!-- [ALGORITHM] -->
## Abstract
Recently, pure transformer-based models have shown
great potentials for vision tasks such as image classification and detection. However, the design of transformer networks is challenging. It has been observed that the depth,
embedding dimension, and number of heads can largely affect the performance of vision transformers. Previous models configure these dimensions based upon manual crafting. In this work, we propose a new one-shot architecture
search framework, namely AutoFormer, dedicated to vision
transformer search. AutoFormer entangles the weights of
different blocks in the same layers during supernet training. Benefiting from the strategy, the trained supernet allows thousands of subnets to be very well-trained. Specifically, the performance of these subnets with weights inherited from the supernet is comparable to those retrained
from scratch. Besides, the searched models, which we refer to as AutoFormers, surpass the recent state-of-the-arts such
as ViT and DeiT. In particular, AutoFormer-tiny/small/base
achieve 74.7%/81.7%/82.4% top-1 accuracy on ImageNet
with 5.7M/22.9M/53.7M parameters, respectively. Lastly,
we verify the transferability of AutoFormer by providing
the performance on downstream benchmarks and distillation experiments.
![pipeline](/docs/en/imgs/model_zoo/autoformer/pipeline.png)
## Introduction
### Supernet pre-training on ImageNet
```bash
python ./tools/train.py \
configs/nas/mmcls/autoformer/autoformer_supernet_32xb256_in1k.py \
--work-dir $WORK_DIR
```
### Search for subnet on the trained supernet
```bash
python ./tools/train.py \
configs/nas/mmcls/autoformer/autoformer_search_8xb128_in1k.py \
--cfg-options load_from=$STEP1_CKPT \
--work-dir $WORK_DIR
```
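The supernet and search configs are named after multi-GPU setups (32x256 and 8x128 batches), so in practice these jobs are usually launched through a distributed wrapper. A minimal sketch, assuming the standard OpenMMLab `tools/dist_train.sh` helper is present in this repo:
```bash
# Hypothetical distributed launch of the supernet pre-training on 8 GPUs.
bash ./tools/dist_train.sh \
    configs/nas/mmcls/autoformer/autoformer_supernet_32xb256_in1k.py 8 \
    --work-dir $WORK_DIR
```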
## Results and models
| Dataset | Supernet | Subnet | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | Remarks |
| :------: | :------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------: | :------: | :-------: | :-------: | :---------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :--------------: |
| ImageNet | vit | [mutable](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/spos/spos_shufflenetv2_subnet_8xb128_in1k/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-454627be_mutable_cfg.yaml?versionId=CAEQHxiBgICw5b6I7xciIGY5MjVmNWFhY2U5MjQzN2M4NDViYzI2YWRmYWE1YzQx) | 52.472 | 10.2 | 82.48 | 95.99 | [config](./autoformer_supernet_32xb256_in1k.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/x.pth) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/spos/x.log.json) | MMRazor searched |
**Note**:
1. There are some small differences from the original repo in order to stay consistent with the mmrazor codebase. For example, we set the maximum value of embed_channels to 624, while the original repo sets it to 640. However, the original repo only searches over 528, 576 and 624 embed_channels, so setting the maximum to 624 still reproduces the result of the original paper (the full search space is shown below).
2. The original paper reports 82.4 top-1 accuracy with 53.7M parameters, while we obtain 82.48 top-1 accuracy with 52.47M parameters.
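For reference, the search space defined in the supernet config added in this PR (`autoformer_supernet_32xb256_in1k.py`) is:
```python
# Search space of the AutoFormer supernet, copied from the config in this PR.
arch_setting = dict(
    mlp_ratios=[3.0, 3.5, 4.0],
    num_heads=[8, 9, 10],
    depth=[14, 15, 16],
    embed_dims=[528, 576, 624])
```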
## Citation
```latex
@inproceedings{chen2021autoformer,
  title={AutoFormer: Searching Transformers for Visual Recognition},
  author={Chen, Minghao and Peng, Houwen and Fu, Jianlong and Ling, Haibin},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  year={2021}
}
```


@ -0,0 +1,17 @@
_base_ = ['./autoformer_supernet_32xb256_in1k.py']
custom_hooks = None
train_cfg = dict(
_delete_=True,
type='mmrazor.EvolutionSearchLoop',
dataloader=_base_.val_dataloader,
evaluator=_base_.val_evaluator,
max_epochs=20,
num_candidates=20,
top_k=10,
num_mutation=5,
num_crossover=5,
mutate_prob=0.2,
constraints_range=dict(params=(0, 55)),
score_key='accuracy/top1')
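A note on `constraints_range` above: each entry maps a resource key reported by the estimator (e.g. `flops`, `params`, `latency`) to a budget. Judging from the `check_subnet_resources` helper added later in this diff, a `(low, high)` pair is compared directly, while a single scalar `v` is treated as `(0, v)`:
```python
# Both forms are accepted by check_subnet_resources (see the engine utils below);
# a scalar upper bound is normalized to the interval (0, v).
constraints_range = dict(params=(0, 55))  # params must fall within [0, 55]
constraints_range = dict(flops=330)       # equivalent to flops within (0, 330)
```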


@ -0,0 +1,79 @@
_base_ = [
'mmrazor::_base_/settings/imagenet_bs2048_AdamW.py',
'mmcls::_base_/default_runtime.py',
]
# data preprocessor
data_preprocessor = dict(
_scope_='mmcls',
type='ClsDataPreprocessor',
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
num_classes=1000,
batch_augments=dict(
augments=[
dict(type='Mixup', alpha=0.2),
dict(type='CutMix', alpha=1.0)
],
probs=[0.5, 0.5]))
arch_setting = dict(
mlp_ratios=[3.0, 3.5, 4.0],
num_heads=[8, 9, 10],
depth=[14, 15, 16],
embed_dims=[528, 576, 624])
supernet = dict(
_scope_='mmrazor',
type='SearchableImageClassifier',
data_preprocessor=data_preprocessor,
backbone=dict(
_scope_='mmrazor',
type='AutoformerBackbone',
arch_setting=arch_setting),
neck=None,
head=dict(
type='DynamicLinearClsHead',
num_classes=1000,
in_channels=624,
loss=dict(
type='mmcls.LabelSmoothLoss',
mode='original',
num_classes=1000,
label_smooth_val=0.1,
loss_weight=1.0),
topk=(1, 5)),
connect_head=dict(connect_with_backbone='backbone.last_mutable'),
)
model = dict(
type='mmrazor.Autoformer',
architecture=supernet,
fix_subnet=None,
mutators=dict(
channel_mutator=dict(
type='mmrazor.OneShotChannelMutator',
channel_unit_cfg={
'type': 'OneShotMutableChannelUnit',
'default_args': {
'unit_predefined': True
}
},
parse_cfg={'type': 'Predefined'}),
value_mutator=dict(type='mmrazor.DynamicValueMutator')))
# runtime setting
custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
# checkpoint saving
_base_.default_hooks.checkpoint = dict(
type='CheckpointHook',
interval=2,
by_epoch=True,
save_best='accuracy/top1',
max_keep_ckpts=3)
find_unused_parameters = True
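The `Autoformer` algorithm added in this PR also accepts a `fix_subnet` argument (a yaml path or a loaded dict) for subnet retraining. A minimal, hypothetical retraining config sketch that reuses the supernet config above (the yaml path is a placeholder for the subnet exported by your own search run):
```python
# Hypothetical subnet-retraining config: inherit the supernet config and pin
# the architecture to a searched subnet.
_base_ = ['./autoformer_supernet_32xb256_in1k.py']

model = dict(
    # Placeholder path; point this at the subnet yaml exported by the search loop.
    fix_subnet='work_dirs/autoformer_search/best_subnet.yaml')
```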


@ -13,5 +13,5 @@ train_cfg = dict(
num_mutation=25,
num_crossover=25,
mutate_prob=0.1,
flops_range=(0., 465.),
constraints_range=dict(flops=(0., 465.)),
score_key='accuracy/top1')


@ -13,5 +13,5 @@ train_cfg = dict(
num_mutation=25,
num_crossover=25,
mutate_prob=0.1,
flops_range=(0., 330.),
constraints_range=dict(flops=(0, 330)),
score_key='accuracy/top1')


@ -13,5 +13,5 @@ train_cfg = dict(
num_mutation=20,
num_crossover=20,
mutate_prob=0.1,
flops_range=(0., 300.),
constraints_range=dict(flops=(0, 330)),
score_key='coco/bbox_mAP')


@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import os.path as osp
import random
import warnings
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from mmengine import fileio
@ -14,10 +15,10 @@ from mmengine.utils import is_list_of
from torch.utils.data import DataLoader
from mmrazor.models.task_modules import ResourceEstimator
from mmrazor.registry import LOOPS
from mmrazor.registry import LOOPS, TASK_UTILS
from mmrazor.structures import Candidates, export_fix_subnet
from mmrazor.utils import SupportRandomSubnet
from .utils import check_subnet_flops, crossover
from .utils import check_subnet_resources, crossover
@LOOPS.register_module()
@ -41,10 +42,11 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
num_crossover (int): The number of candidates got by crossover.
Defaults to 25.
mutate_prob (float): The probability of mutation. Defaults to 0.1.
flops_range (tuple, optional): It is used for screening candidates.
resource_estimator_cfg (dict): The config for building estimator, which
is be used to estimate the flops of sampled subnet. Defaults to
None, which means default config is used.
crossover_prob (float): The probability of crossover. Defaults to 0.5.
constraints_range (Dict[str, Any]): Constraints to be used for
screening candidates. Defaults to dict(flops=(0, 330)).
resource_estimator_cfg (dict, Optional): Used for building a
resource estimator. Defaults to None.
score_key (str): Specify one metric in evaluation results to score
candidates. Defaults to 'accuracy_top-1'.
init_candidates (str, optional): The candidates file path, which is
@ -64,8 +66,9 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
num_mutation: int = 25,
num_crossover: int = 25,
mutate_prob: float = 0.1,
flops_range: Optional[Tuple[float, float]] = (0., 330.),
resource_estimator_cfg: Optional[dict] = None,
crossover_prob: float = 0.5,
constraints_range: Dict[str, Any] = dict(flops=(0., 330.)),
resource_estimator_cfg: Optional[Dict] = None,
score_key: str = 'accuracy/top1',
init_candidates: Optional[str] = None) -> None:
super().__init__(runner, dataloader, max_epochs)
@ -83,11 +86,12 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
self.num_candidates = num_candidates
self.top_k = top_k
self.flops_range = flops_range
self.constraints_range = constraints_range
self.score_key = score_key
self.num_mutation = num_mutation
self.num_crossover = num_crossover
self.mutate_prob = mutate_prob
self.crossover_prob = crossover_prob
self.max_keep_ckpts = max_keep_ckpts
self.resume_from = resume_from
@ -99,16 +103,58 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
correct init candidates file'
self.top_k_candidates = Candidates()
if resource_estimator_cfg is None:
self.estimator = ResourceEstimator()
else:
self.estimator = ResourceEstimator(**resource_estimator_cfg)
if self.runner.distributed:
self.model = runner.model.module
else:
self.model = runner.model
# Build resource estimator.
resource_estimator_cfg = dict(
) if resource_estimator_cfg is None else resource_estimator_cfg
self.estimator = self.build_resource_estimator(resource_estimator_cfg)
def build_resource_estimator(
self, resource_estimator: Union[ResourceEstimator,
Dict]) -> ResourceEstimator:
"""Build resource estimator for search loop.
Examples of ``resource_estimator``:
# `ResourceEstimator` will be used
resource_estimator = dict()
# custom resource_estimator
resource_estimator = dict(type='mmrazor.ResourceEstimator')
Args:
resource_estimator (ResourceEstimator or dict): A
resource_estimator or a dict to build resource estimator.
If ``resource_estimator`` is a resource estimator object,
just returns itself.
Returns:
:obj:`ResourceEstimator`: Resource estimator object built from
``resource_estimator``.
"""
if isinstance(resource_estimator, ResourceEstimator):
return resource_estimator
elif not isinstance(resource_estimator, dict):
raise TypeError(
'resource estimator should be a ResourceEstimator object or'
f'dict, but got {resource_estimator}')
resource_estimator_cfg = copy.deepcopy(
resource_estimator) # type: ignore
if 'type' in resource_estimator_cfg:
estimator = TASK_UTILS.build(resource_estimator_cfg)
else:
estimator = ResourceEstimator(
**resource_estimator_cfg) # type: ignore
return estimator # type: ignore
def run(self) -> None:
"""Launch searching."""
self.runner.call_hook('before_train')
@ -144,31 +190,48 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
f'{scores_before}')
self.candidates.extend(self.top_k_candidates)
self.candidates.sort(key=lambda x: x[1], reverse=True)
self.top_k_candidates = Candidates(self.candidates[:self.top_k])
self.candidates.sort_by(key_indicator='score', reverse=True)
self.top_k_candidates = Candidates(self.candidates.data[:self.top_k])
scores_after = self.top_k_candidates.scores
self.runner.logger.info(f'top k scores after update: '
f'{scores_after}')
mutation_candidates = self.gen_mutation_candidates()
self.candidates_mutator_crossover = Candidates(mutation_candidates)
crossover_candidates = self.gen_crossover_candidates()
candidates = mutation_candidates + crossover_candidates
assert len(candidates) <= self.num_candidates, 'Total of mutation and \
crossover should be no more than the number of candidates.'
self.candidates_mutator_crossover.extend(crossover_candidates)
self.candidates = Candidates(candidates)
assert len(self.candidates_mutator_crossover
) <= self.num_candidates, 'Total of mutation and \
crossover should be no more than the number of candidates.'
self.candidates = self.candidates_mutator_crossover
self._epoch += 1
def sample_candidates(self) -> None:
"""Update candidate pool contains specified number of candicates."""
candidates_resources = []
init_candidates = len(self.candidates)
if self.runner.rank == 0:
while len(self.candidates) < self.num_candidates:
candidate = self.model.sample_subnet()
if self._check_constraints(random_subnet=candidate):
is_pass, result = self._check_constraints(
random_subnet=candidate)
if is_pass:
self.candidates.append(candidate)
candidates_resources.append(result)
self.candidates = Candidates(self.candidates.data)
else:
self.candidates = Candidates([None] * self.num_candidates)
self.candidates = Candidates([dict(a=0)] * self.num_candidates)
if len(candidates_resources) > 0:
self.candidates.update_resources(
candidates_resources,
start=len(self.candidates.data) - len(candidates_resources))
assert init_candidates + len(
candidates_resources) == self.num_candidates
# broadcast candidates to val with multi-GPUs.
broadcast_object_list(self.candidates.data)
@ -180,14 +243,18 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
metrics = self._val_candidate()
score = metrics[self.score_key] \
if len(metrics) != 0 else 0.
self.candidates.set_score(i, score)
self.candidates.set_resource(i, score, 'score')
self.runner.logger.info(
f'Epoch:[{self._epoch}/{self._max_epochs}] '
f'Candidate:[{i + 1}/{self.num_candidates}] '
f'Score:{score}')
f'Flops: {self.candidates.resources("flops")[i]} '
f'Params: {self.candidates.resources("params")[i]} '
f'Latency: {self.candidates.resources("latency")[i]} '
f'Score: {self.candidates.scores} ')
def gen_mutation_candidates(self) -> List:
def gen_mutation_candidates(self):
"""Generate specified number of mutation candicates."""
mutation_resources = []
mutation_candidates: List = []
max_mutate_iters = self.num_mutation * 10
mutate_iter = 0
@ -198,12 +265,20 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
mutation_candidate = self._mutation()
if self._check_constraints(random_subnet=mutation_candidate):
is_pass, result = self._check_constraints(
random_subnet=mutation_candidate)
if is_pass:
mutation_candidates.append(mutation_candidate)
mutation_resources.append(result)
mutation_candidates = Candidates(mutation_candidates)
mutation_candidates.update_resources(mutation_resources)
return mutation_candidates
def gen_crossover_candidates(self) -> List:
def gen_crossover_candidates(self):
"""Generate specofied number of crossover candicates."""
crossover_resources = []
crossover_candidates: List = []
crossover_iter = 0
max_crossover_iters = self.num_crossover * 10
@ -214,8 +289,15 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
crossover_candidate = self._crossover()
if self._check_constraints(random_subnet=crossover_candidate):
is_pass, result = self._check_constraints(
random_subnet=crossover_candidate)
if is_pass:
crossover_candidates.append(crossover_candidate)
crossover_resources.append(result)
crossover_candidates = Candidates(crossover_candidates)
crossover_candidates.update_resources(crossover_resources)
return crossover_candidates
def _mutation(self) -> SupportRandomSubnet:
@ -229,7 +311,7 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
"""Crossover."""
candidate1 = random.choice(self.top_k_candidates.subnets)
candidate2 = random.choice(self.top_k_candidates.subnets)
candidate = crossover(candidate1, candidate2)
candidate = crossover(candidate1, candidate2, prob=self.crossover_prob)
return candidate
def _resume(self):
@ -263,7 +345,7 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
self.runner.model.eval()
for data_batch in self.dataloader:
outputs = self.runner.model.val_step(data_batch)
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.evaluator.process(outputs, data_batch)
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
return metrics
@ -295,16 +377,17 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
if osp.isfile(ckpt_path):
os.remove(ckpt_path)
def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool:
def _check_constraints(
self, random_subnet: SupportRandomSubnet) -> Tuple[bool, Dict]:
"""Check whether is beyond constraints.
Returns:
bool: The result of checking.
bool, result: The result of checking.
"""
is_pass = check_subnet_flops(
is_pass, results = check_subnet_resources(
model=self.model,
subnet=random_subnet,
estimator=self.estimator,
flops_range=self.flops_range)
constraints_range=self.constraints_range)
return is_pass
return is_pass, results


@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
import os
import random
from abc import abstractmethod
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import torch
from mmengine import fileio
@ -13,10 +14,10 @@ from mmengine.utils import is_list_of
from torch.utils.data import DataLoader
from mmrazor.models.task_modules import ResourceEstimator
from mmrazor.registry import LOOPS
from mmrazor.registry import LOOPS, TASK_UTILS
from mmrazor.structures import Candidates
from mmrazor.utils import SupportRandomSubnet
from .utils import check_subnet_flops
from .utils import check_subnet_resources
class BaseSamplerTrainLoop(IterBasedTrainLoop):
@ -77,18 +78,15 @@ class BaseSamplerTrainLoop(IterBasedTrainLoop):
@LOOPS.register_module()
class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
"""IterBasedTrainLoop for greedy sampler.
In GreedySamplerTrainLoop, `Greedy` means that only some of the top
sampled candidates are used to train the supernet. So GreedySamplerTrainLoop
mainly picks the top candidates based on their val scores, then uses them to
train the supernet one by one.
Steps:
1. Sample from the supernet and the candidates.
2. Validate these sampled candidates to get each candidate's score.
3. Get top-k candidates based on their scores, then use them to train
the supernet one by one.
Args:
runner (Runner): A reference of runner.
dataloader (Dataloader or dict): A dataloader object or a dict to
@ -102,10 +100,10 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
val_interval (int): Validation interval. Defaults to 1000.
score_key (str): Specify one metric in evaluation results to score
candidates. Defaults to 'accuracy_top-1'.
flops_range (dict): Constraints to be used for screening candidates.
resource_estimator_cfg (dict): The config for building estimator, which
is be used to estimate the flops of sampled subnet. Defaults to
None, which means default config is used.
constraints_range (Dict[str, Any]): Constraints to be used for
screening candidates. Defaults to dict(flops=(0, 330)).
resource_estimator_cfg (dict, Optional): Used for building a
resource estimator. Defaults to None.
num_candidates (int): The number of the candidates consist of samples
from supernet and itself. Defaults to 1000.
num_samples (int): The number of sample in each sampling subnet.
@ -139,8 +137,8 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
val_begin: int = 1,
val_interval: int = 1000,
score_key: str = 'accuracy/top1',
flops_range: Optional[Tuple[float, float]] = (0., 330),
resource_estimator_cfg: Optional[dict] = None,
constraints_range: Dict[str, Any] = dict(flops=(0, 330)),
resource_estimator_cfg: Optional[Dict] = None,
num_candidates: int = 1000,
num_samples: int = 10,
top_k: int = 5,
@ -163,7 +161,7 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
self.evaluator = evaluator
self.score_key = score_key
self.flops_range = flops_range
self.constraints_range = constraints_range
self.num_candidates = num_candidates
self.num_samples = num_samples
self.top_k = top_k
@ -177,10 +175,52 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
self.candidates = Candidates()
self.top_k_candidates = Candidates()
if resource_estimator_cfg is None:
self.estimator = ResourceEstimator()
# Build resource estimator.
resource_estimator_cfg = dict(
) if resource_estimator_cfg is None else resource_estimator_cfg
self.estimator = self.build_resource_estimator(resource_estimator_cfg)
def build_resource_estimator(
self, resource_estimator: Union[ResourceEstimator,
Dict]) -> ResourceEstimator:
"""Build resource estimator for search loop.
Examples of ``resource_estimator``:
# `ResourceEstimator` will be used
resource_estimator = dict()
# custom resource_estimator
resource_estimator = dict(type='mmrazor.ResourceEstimator')
Args:
resource_estimator (ResourceEstimator or dict):
A resource_estimator or a dict to build resource estimator.
If ``resource_estimator`` is a resource estimator object,
just returns itself.
Returns:
:obj:`ResourceEstimator`: Resource estimator object built from
``resource_estimator``.
"""
if isinstance(resource_estimator, ResourceEstimator):
return resource_estimator
elif not isinstance(resource_estimator, dict):
raise TypeError(
'resource estimator should be a ResourceEstimator object or'
f'dict, but got {resource_estimator}')
resource_estimator_cfg = copy.deepcopy(
resource_estimator) # type: ignore
if 'type' in resource_estimator_cfg:
estimator = TASK_UTILS.build(resource_estimator_cfg)
else:
self.estimator = ResourceEstimator(**resource_estimator_cfg)
estimator = ResourceEstimator(
**resource_estimator_cfg) # type: ignore
return estimator # type: ignore
def run(self) -> None:
"""Launch training."""
@ -230,9 +270,11 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
self.update_candidates_scores()
self.candidates.sort(key=lambda x: x[1], reverse=True)
self.candidates = Candidates(self.candidates[:self.num_candidates])
self.top_k_candidates = Candidates(self.candidates[:self.top_k])
self.candidates.sort_by(key_indicator='score', reverse=True)
self.candidates = Candidates(
self.candidates.data[:self.num_candidates])
self.top_k_candidates = Candidates(
self.candidates.data[:self.top_k])
top1_score = self.top_k_candidates.scores[0]
if (self._iter % self.val_interval) < self.top_k:
@ -243,7 +285,7 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
f'{num_sample_from_supernet}/{self.num_samples} '
f'top1_score {top1_score:.3f} '
f'cur_num_candidates: {len(self.candidates)}')
return self.top_k_candidates.pop(0)[0]
return self.top_k_candidates.subnets[0]
def update_cur_prob(self, cur_iter: int) -> None:
"""update current probablity of sampling from the candidates, which is
@ -278,7 +320,8 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
for _ in range(num_samples):
if random.random() >= self.cur_prob or len(self.candidates) == 0:
subnet = self._sample_from_supernet()
if self._check_constraints(subnet):
is_pass, _ = self._check_constraints(subnet)
if is_pass:
sampled_candidates.append(subnet)
num_sample_from_supernet += 1
else:
@ -292,7 +335,7 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
self.model.set_subnet(candidate)
metrics = self._val_candidate()
score = metrics[self.score_key] if len(metrics) != 0 else 0.
self.candidates.set_score(i, score)
self.candidates.set_resource(i, score, 'score')
@torch.no_grad()
def _val_candidate(self) -> Dict:
@ -312,22 +355,22 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
def _sample_from_candidates(self) -> SupportRandomSubnet:
"""Sample from the candidates."""
assert len(self.candidates) > 0
subnet = random.choice(self.candidates)
subnet = random.choice(self.candidates.data)
return subnet
def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool:
def _check_constraints(self, random_subnet: SupportRandomSubnet):
"""Check whether is beyond constraints.
Returns:
bool: The result of checking.
bool, result: The result of checking.
"""
is_pass = check_subnet_flops(
is_pass, results = check_subnet_resources(
model=self.model,
subnet=random_subnet,
estimator=self.estimator,
flops_range=self.flops_range)
constraints_range=self.constraints_range)
return is_pass
return is_pass, results
def _save_candidates(self) -> None:
"""Save the candidates to init the next searching."""


@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .check import check_subnet_flops
from .check import check_subnet_resources
from .genetic import crossover
__all__ = ['crossover', 'check_subnet_flops']
__all__ = ['crossover', 'check_subnet_resources']


@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Optional, Tuple
from typing import Any, Dict, Tuple
import torch.nn as nn
import torch
from mmrazor.models import ResourceEstimator
from mmrazor.structures import export_fix_subnet, load_fix_subnet
@ -15,18 +15,20 @@ except ImportError:
BaseDetector = get_placeholder('mmdet')
def check_subnet_flops(
model: nn.Module,
subnet: SupportRandomSubnet,
estimator: ResourceEstimator,
flops_range: Optional[Tuple[float, float]] = None) -> bool:
"""Check whether is beyond flops constraints.
@torch.no_grad()
def check_subnet_resources(
model,
subnet: SupportRandomSubnet,
estimator: ResourceEstimator,
constraints_range: Dict[str, Any] = dict(flops=(0, 330))
) -> Tuple[bool, Dict]:
"""Check whether is beyond resources constraints.
Returns:
bool: The result of checking.
bool, result: The result of checking.
"""
if flops_range is None:
return True
if constraints_range is None:
return True, dict()
assert hasattr(model, 'set_subnet') and hasattr(model, 'architecture')
model.set_subnet(subnet)
@ -40,9 +42,10 @@ def check_subnet_flops(
else:
results = estimator.estimate(model=model_to_check)
flops = results['flops']
flops_mix, flops_max = flops_range
if flops_mix <= flops <= flops_max: # type: ignore
return True
else:
return False
for k, v in constraints_range.items():
if not isinstance(v, (list, tuple)):
v = (0, v)
if results[k] < v[0] or results[k] > v[1]:
return False, results
return True, results
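A small usage sketch of the new `check_subnet_resources` helper; the `algorithm` object here is assumed to be a built NAS algorithm that exposes `sample_subnet`/`set_subnet` and an `architecture` attribute, e.g. the `Autoformer` class added in this PR:
```python
# Hedged sketch: screen one randomly sampled subnet against flops/params budgets.
from mmrazor.models import ResourceEstimator

estimator = ResourceEstimator()
subnet = algorithm.sample_subnet()
is_pass, results = check_subnet_resources(
    model=algorithm,
    subnet=subnet,
    estimator=estimator,
    constraints_range=dict(flops=(0., 330.), params=(0, 55)))
if is_pass:
    print(f"flops={results['flops']}, params={results['params']}")
```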


@ -3,7 +3,8 @@ from .base import BaseAlgorithm
from .distill import (DAFLDataFreeDistillation, DataFreeDistillation,
FpnTeacherDistill, OverhaulFeatureDistillation,
SelfDistill, SingleTeacherDistill)
from .nas import DSNAS, DSNASDDP, SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP
from .nas import (DSNAS, DSNASDDP, SPOS, Autoformer, AutoSlim, AutoSlimDDP,
Darts, DartsDDP)
from .pruning import SlimmableNetwork, SlimmableNetworkDDP
from .pruning.ite_prune_algorithm import ItePruneAlgorithm
@ -25,4 +26,5 @@ __all__ = [
'ItePruneAlgorithm',
'DSNAS',
'DSNASDDP',
'Autoformer',
]


@ -1,9 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .autoformer import Autoformer
from .autoslim import AutoSlim, AutoSlimDDP
from .darts import Darts, DartsDDP
from .dsnas import DSNAS, DSNASDDP
from .spos import SPOS
__all__ = [
'SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'DSNAS', 'DSNASDDP'
'SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'DSNAS',
'DSNASDDP', 'Autoformer'
]


@ -0,0 +1,115 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
import torch
from mmengine.model import BaseModel
from mmengine.structures import BaseDataElement
from torch import nn
from mmrazor.registry import MODELS
from mmrazor.utils import ValidFixMutable
from ..base import BaseAlgorithm, LossResults
@MODELS.register_module()
class Autoformer(BaseAlgorithm):
"""Implementation of `Autoformer <https://arxiv.org/abs/2107.00651>`_
AutoFormer is dedicated to vision transformer search. AutoFormer
entangles the weights of different blocks in the same layers during
supernet training.
The logic of the search part is implemented in
:class:`mmrazor.engine.EvolutionSearchLoop`
Args:
architecture (dict|:obj:`BaseModel`): The config of :class:`BaseModel`
or built model. Corresponding to supernet in NAS algorithm.
mutators (Optional[dict]): The dict of different Mutators config.
Defaults to None.
fix_subnet (str | dict | :obj:`FixSubnet`): The path of yaml file or
loaded dict or built :obj:`FixSubnet`. Defaults to None.
data_preprocessor (Optional[Union[dict, nn.Module]]): The pre-process
config of :class:`BaseDataPreprocessor`. Defaults to None.
init_cfg (Optional[dict]): Init config for ``BaseModule``.
Defaults to None.
Note:
Autoformer uses two mutators, ``DynamicValueMutator`` and
``ChannelMutator``. ``DynamicValueMutator`` handles the mutable objects
of type ``OneShotMutableValue`` in Autoformer, while ``ChannelMutator``
handles the mutable objects of type ``OneShotMutableChannel``.
"""
def __init__(self,
architecture: Union[BaseModel, Dict],
mutators: Optional[Dict] = None,
fix_subnet: Optional[ValidFixMutable] = None,
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
init_cfg: Optional[dict] = None):
super().__init__(architecture, data_preprocessor, init_cfg)
# Autoformer supports supernet training and subnet retraining.
# If fix_subnet is not None, it means subnet retraining.
if fix_subnet:
# Avoid circular import
from mmrazor.structures import load_fix_subnet
# According to fix_subnet, delete the unchosen part of supernet
load_fix_subnet(self.architecture, fix_subnet)
self.is_supernet = False
else:
assert mutators is not None, \
'mutator cannot be None when fix_subnet is None.'
if isinstance(mutators, dict):
built_mutators: Dict = dict()
for name, mutator_cfg in mutators.items():
if 'parse_cfg' in mutator_cfg and isinstance(
mutator_cfg['parse_cfg'], dict):
assert mutator_cfg['parse_cfg'][
'type'] == 'Predefined', \
'autoformer only support predefined.'
mutator = MODELS.build(mutator_cfg)
built_mutators[name] = mutator
mutator.prepare_from_supernet(self.architecture)
self.mutators = built_mutators
else:
raise TypeError('mutators should be a `dict` but got '
f'{type(mutators)}')
self.is_supernet = True
def sample_subnet(self) -> Dict:
"""Random sample subnet by mutator."""
subnet_dict = dict()
for name, mutator in self.mutators.items():
if name == 'value_mutator':
subnet_dict.update(
dict((str(group_id), value) for group_id, value in
mutator.sample_choices().items()))
else:
subnet_dict.update(mutator.sample_choices())
return subnet_dict
def set_subnet(self, subnet_dict: Dict) -> None:
"""Set the subnet sampled by :meth:sample_subnet."""
for name, mutator in self.mutators.items():
if name == 'value_mutator':
value_subnet = dict((int(group_id), value)
for group_id, value in subnet_dict.items()
if isinstance(group_id, str))
mutator.set_choices(value_subnet)
else:
channel_subnet = dict(
(group_id, value)
for group_id, value in subnet_dict.items()
if isinstance(group_id, int))
mutator.set_choices(channel_subnet)
def loss(
self,
batch_inputs: torch.Tensor,
data_samples: Optional[List[BaseDataElement]] = None,
) -> LossResults:
"""Calculate losses from a batch of inputs and data samples."""
if self.is_supernet:
random_subnet = self.sample_subnet()
self.set_subnet(random_subnet)
return self.architecture(batch_inputs, data_samples, mode='loss')
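To make the `loss` path above concrete, a hedged sketch of what one supernet training step boils down to (`model` refers to the supernet config dict defined earlier in this diff):
```python
# Hedged sketch: build the Autoformer algorithm and do one sample/set round trip,
# which is what `loss()` performs internally when `is_supernet` is True.
from mmrazor.registry import MODELS

algo = MODELS.build(model)     # `model` is the config dict from the supernet config
subnet = algo.sample_subnet()  # str keys: value mutator groups; int keys: channel units
algo.set_subnet(subnet)        # apply the sampled subnet before the forward pass
```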


@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backbones import * # noqa: F401,F403
from .classifiers import * # noqa: F401,F403
from .connectors import * # noqa: F401,F403
from .dynamic_ops import * # noqa: F401,F403
from .generators import * # noqa: F401,F403


@ -1,10 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .darts_backbone import DartsBackbone
from .searchable_autoformer import AutoformerBackbone
from .searchable_mobilenet import SearchableMobileNet
from .searchable_shufflenet_v2 import SearchableShuffleNetV2
from .wideresnet import WideResNet
__all__ = [
'SearchableMobileNet', 'SearchableShuffleNetV2', 'DartsBackbone',
'WideResNet'
'WideResNet', 'AutoformerBackbone'
]


@ -0,0 +1,374 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmrazor.models.architectures.dynamic_ops.bricks import (
DynamicLinear, DynamicMultiheadAttention, DynamicPatchEmbed,
DynamicSequential)
from mmrazor.models.mutables import (BaseMutable, BaseMutableChannel,
MutableChannelContainer,
OneShotMutableChannel,
OneShotMutableValue)
from mmrazor.models.mutables.mutable_channel import OneShotMutableChannelUnit
from mmrazor.registry import MODELS
try:
from mmcls.models.backbones.base_backbone import BaseBackbone
except ImportError:
from mmrazor.utils import get_placeholder
BaseBackbone = get_placeholder('mmcls')
class TransformerEncoderLayer(BaseBackbone):
"""Autoformer block.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
mlp_ratio (float): The ratio of the hidden dimension of the FFN to
``embed_dims``.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.
attn_drop_rate (float): The dropout rate of the attention weights.
Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
qkv_bias (bool, optional): Whether to keep bias of qkv.
Defaults to True.
act_cfg (Dict, optional): The config for the activation function.
Defaults to dict(type='GELU').
norm_cfg (Dict, optional): The config for normalization.
Defaults to dict(type='mmrazor.DynamicLayerNorm').
init_cfg (Dict, optional): The config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims: int,
num_heads: int,
mlp_ratio: float,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
qkv_bias: bool = True,
act_cfg: Dict = dict(type='GELU'),
norm_cfg: Dict = dict(type='mmrazor.DynamicLayerNorm'),
init_cfg: Dict = None) -> None:
super().__init__(init_cfg)
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self.attn = DynamicMultiheadAttention(
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop_rate=attn_drop_rate,
proj_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
qkv_bias=qkv_bias)
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, embed_dims, postfix=2)
self.add_module(self.norm2_name, norm2)
middle_channels = int(embed_dims * mlp_ratio)
self.fc1 = DynamicLinear(embed_dims, middle_channels)
self.fc2 = DynamicLinear(middle_channels, embed_dims)
self.act = build_activation_layer(act_cfg)
@property
def norm1(self):
"""The first normalization."""
return getattr(self, self.norm1_name)
@property
def norm2(self):
"""The second normalization."""
return getattr(self, self.norm2_name)
def register_mutables(self, mutable_num_heads: BaseMutable,
mutable_mlp_ratios: BaseMutable,
mutable_q_embed_dims: BaseMutable,
mutable_head_dims: BaseMutable,
mutable_embed_dims: BaseMutable):
"""Mutate the mutables of encoder layer."""
# record the mutables
self.mutable_num_heads = mutable_num_heads
self.mutable_mlp_ratios = mutable_mlp_ratios
self.mutable_q_embed_dims = mutable_q_embed_dims
self.mutable_embed_dims = mutable_embed_dims
self.mutable_head_dims = mutable_head_dims
# handle the mutable of FFN
self.middle_channels = mutable_mlp_ratios * mutable_embed_dims
self.attn.register_mutable_attr('num_heads', mutable_num_heads)
# handle the mutable of the first dynamic LN
MutableChannelContainer.register_mutable_channel_to_module(
self.norm1, self.mutable_embed_dims, True)
# handle the mutable of the second dynamic LN
MutableChannelContainer.register_mutable_channel_to_module(
self.norm2, self.mutable_embed_dims, True)
# handle the mutable of attn
MutableChannelContainer.register_mutable_channel_to_module(
self.attn, self.mutable_embed_dims, False)
MutableChannelContainer.register_mutable_channel_to_module(
self.attn,
self.mutable_q_embed_dims,
True,
end=self.mutable_q_embed_dims.current_choice)
MutableChannelContainer.register_mutable_channel_to_module(
self.attn.rel_pos_embed_k, self.mutable_head_dims, False)
MutableChannelContainer.register_mutable_channel_to_module(
self.attn.rel_pos_embed_v, self.mutable_head_dims, False)
# handle the mutable of fc
MutableChannelContainer.register_mutable_channel_to_module(
self.fc1, mutable_embed_dims, False)
MutableChannelContainer.register_mutable_channel_to_module(
self.fc1,
self.middle_channels,
True,
start=0,
end=self.middle_channels.current_choice)
MutableChannelContainer.register_mutable_channel_to_module(
self.fc2,
self.middle_channels,
False,
start=0,
end=self.middle_channels.current_choice)
MutableChannelContainer.register_mutable_channel_to_module(
self.fc2, mutable_embed_dims, True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward of Transformer Encode Layer."""
residual = x
x = self.norm1(x)
x = self.attn(x)
x = residual + x
residual = x
x = self.norm2(x)
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return residual + x
@MODELS.register_module()
class AutoformerBackbone(BaseBackbone):
"""Autoformer backbone.
A PyTorch implementation of Autoformer introduced by:
`AutoFormer: Searching Transformers for Visual Recognition
<https://arxiv.org/abs/2107.00651>`_
Modified from the `official repo
<https://github.com/microsoft/Cream/blob/main/AutoFormer/>`_.
Args:
arch_setting (Dict[str, List]): Architecture settings.
img_size (int, optional): The image size of input.
Defaults to 224.
patch_size (int, optional): The patch size of autoformer.
Defaults to 16.
in_channels (int, optional): The input channel dimension.
Defaults to 3.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
qkv_bias (bool, optional): Whether to keep bias of qkv.
Defaults to True.
norm_cfg (Dict, optional): The config of normalization.
Defaults to dict(type='mmrazor.DynamicLayerNorm').
act_cfg (Dict, optional): The config of activation functions.
Defaults to dict(type='GELU').
use_final_norm (bool, optional): Whether use final normalization.
Defaults to True.
init_cfg (Dict, optional): The config for initialization.
Defaults to None.
Examples:
>>> arch_setting = dict(
... mlp_ratios=[3.0, 3.5, 4.0],
... num_heads=[8, 9, 10],
... depth=[14, 15, 16],
... embed_dims=[528, 576, 624]
... )
>>> model = AutoformerBackbone(arch_setting=arch_setting)
"""
def __init__(self,
arch_setting: Dict[str, List],
img_size: int = 224,
patch_size: int = 16,
in_channels: int = 3,
drop_rate: float = 0.,
drop_path_rate: float = 0.,
qkv_bias: bool = True,
norm_cfg: Dict = dict(type='mmrazor.DynamicLayerNorm'),
act_cfg: Dict = dict(type='GELU'),
use_final_norm: bool = True,
init_cfg: Dict = None) -> None:
super().__init__(init_cfg)
self.arch_setting = arch_setting
self.img_size = img_size
self.patch_size = patch_size
self.qkv_bias = qkv_bias
self.in_channels = in_channels
self.drop_rate = drop_rate
self.use_final_norm = use_final_norm
self.act_cfg = act_cfg
# adapt mutable settings
self.mlp_ratio_range: List = self.arch_setting['mlp_ratios']
self.num_head_range: List = self.arch_setting['num_heads']
self.depth_range: List = self.arch_setting['depth']
self.embed_dim_range: List = self.arch_setting['embed_dims']
# mutable variables of autoformer
self.mutable_depth = OneShotMutableValue(
value_list=self.depth_range, default_value=self.depth_range[-1])
self.mutable_embed_dims = OneShotMutableChannel(
num_channels=self.embed_dim_range[-1],
candidate_choices=self.embed_dim_range)
# handle the mutable in multihead attention
self.base_embed_dims = OneShotMutableChannel(
num_channels=64, candidate_choices=[64])
self.mutable_num_heads = [
OneShotMutableValue(
value_list=self.num_head_range,
default_value=self.num_head_range[-1])
for _ in range(self.depth_range[-1])
]
self.mutable_mlp_ratios = [
OneShotMutableValue(
value_list=self.mlp_ratio_range,
default_value=self.mlp_ratio_range[-1])
for _ in range(self.depth_range[-1])
]
self.mutable_q_embed_dims = [
i * self.base_embed_dims for i in self.mutable_num_heads
]
# patch embeddings
self.patch_embed = DynamicPatchEmbed(
img_size=self.img_size,
in_channels=self.in_channels,
embed_dims=self.mutable_embed_dims.num_channels)
# num of patches
self.patch_resolution = [
img_size // patch_size, img_size // patch_size
]
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
# cls token and pos embed
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1,
self.mutable_embed_dims.num_channels))
self.cls_token = nn.Parameter(
torch.zeros(1, 1, self.mutable_embed_dims.num_channels))
self.drop_after_pos = nn.Dropout(p=drop_rate)
# stochastic depth decay rule
self.dpr = np.linspace(0, drop_path_rate,
self.mutable_depth.max_choice)
# main body
self.blocks = self.make_layers(
embed_dims=self.mutable_embed_dims.num_channels,
depth=self.mutable_depth.max_choice)
# final norm
if self.use_final_norm:
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, self.mutable_embed_dims.num_channels)
self.add_module(self.norm1_name, norm1)
self.last_mutable = self.mutable_embed_dims
self.register_mutables()
@property
def norm1(self):
"""The first normalization."""
return getattr(self, self.norm1_name)
def make_layers(self, embed_dims, depth):
"""Build multiple TransformerEncoderLayers."""
layers = []
for i in range(depth):
layer = TransformerEncoderLayer(
embed_dims=embed_dims,
num_heads=self.mutable_num_heads[i].max_choice,
mlp_ratio=self.mutable_mlp_ratios[i].max_choice,
drop_rate=self.drop_rate,
drop_path_rate=self.dpr[i],
qkv_bias=self.qkv_bias,
act_cfg=self.act_cfg)
layers.append(layer)
return DynamicSequential(*layers)
def register_mutables(self):
"""Mutate the autoformer."""
OneShotMutableChannelUnit._register_channel_container(
self, MutableChannelContainer)
# handle the mutation of depth
self.blocks.register_mutable_attr('depth', self.mutable_depth)
# handle the mutation of patch embed
MutableChannelContainer.register_mutable_channel_to_module(
self.patch_embed, self.mutable_embed_dims, True)
# handle the dependencies of TransformerEncoderLayers
for i in range(self.mutable_depth.max_choice): # max depth here
layer = self.blocks[i]
layer.register_mutables(
mutable_num_heads=self.mutable_num_heads[i],
mutable_mlp_ratios=self.mutable_mlp_ratios[i],
mutable_q_embed_dims=self.mutable_q_embed_dims[i],
mutable_head_dims=self.base_embed_dims,
mutable_embed_dims=self.last_mutable)
# handle the mutable of final norm
if self.use_final_norm:
MutableChannelContainer.register_mutable_channel_to_module(
self.norm1, self.last_mutable, True)
def forward(self, x: torch.Tensor):
"""Forward of Autoformer."""
B = x.shape[0]
x = self.patch_embed(x)
embed_dims = int(self.mutable_embed_dims.current_choice) if isinstance(
self.mutable_embed_dims,
BaseMutableChannel) else self.embed_dim_range[-1]
# cls token
cls_tokens = self.cls_token[..., :embed_dims].expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# pos embed
x = x + self.pos_embed[..., :embed_dims]
x = self.drop_after_pos(x)
# dynamic depth
x = self.blocks(x)
if self.use_final_norm:
x = self.norm1(x)
return (torch.mean(x[:, 1:], dim=1), )


@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .image import SearchableImageClassifier
__all__ = ['SearchableImageClassifier']


@ -0,0 +1,53 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
from mmrazor.registry import MODELS
try:
from mmcls.models import ImageClassifier
except ImportError:
from mmrazor.utils import get_placeholder
ImageClassifier = get_placeholder('mmcls')
@MODELS.register_module()
class SearchableImageClassifier(ImageClassifier):
"""SearchableImageClassifier for sliceable networks.
Args:
backbone (dict): The same as ImageClassifier.
neck (dict, optional): The same as ImageClassifier. Defaults to None.
head (dict, optional): The same as ImageClassifier. Defaults to None.
pretrained (dict, optional): The same as ImageClassifier. Defaults to
None.
train_cfg (dict, optional): The same as ImageClassifier. Defaults to
None.
data_preprocessor (dict, optional): The same as ImageClassifier.
Defaults to None.
init_cfg (dict, optional): The same as ImageClassifier. Defaults to
None.
connect_head (dict, optional): Dimensions that need to be aligned in
the head will be substituted with the attribute referenced by the
`str`-type value, so that the search space of the preceding component
can be connected to the next one. E.g.
{'connect_with_backbone': 'backbone.last_mutable'} means that the
head's `connect_with_backbone` argument will be substituted with the
backbone's `last_mutable`. Defaults to None.
"""
def __init__(self,
backbone: dict,
neck: Optional[dict] = None,
head: Optional[dict] = None,
pretrained: Optional[str] = None,
train_cfg: Optional[dict] = None,
data_preprocessor: Optional[dict] = None,
init_cfg: Optional[dict] = None,
connect_head: Optional[dict] = None):
super().__init__(backbone, neck, head, pretrained, train_cfg,
data_preprocessor, init_cfg)
if self.with_head and connect_head is not None:
for kh, vh in connect_head.items():
component, attr = vh.split('.')
value = getattr(getattr(self, component), attr)
getattr(self.head, kh)(value)


@ -1,15 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bricks.dynamic_conv import BigNasConv2d, DynamicConv2d, OFAConv2d
from .bricks.dynamic_linear import DynamicLinear
from .bricks.dynamic_norm import (DynamicBatchNorm1d, DynamicBatchNorm2d,
DynamicBatchNorm3d, SwitchableBatchNorm2d)
from .mixins.dynamic_conv_mixins import DynamicConvMixin
from .mixins.dynamic_mixins import (DynamicBatchNormMixin, DynamicChannelMixin,
DynamicLinearMixin, DynamicMixin)
__all__ = [
'BigNasConv2d', 'DynamicConv2d', 'OFAConv2d', 'DynamicLinear',
'DynamicBatchNorm1d', 'DynamicBatchNorm2d', 'DynamicBatchNorm3d',
'DynamicMixin', 'DynamicChannelMixin', 'DynamicBatchNormMixin',
'DynamicLinearMixin', 'SwitchableBatchNorm2d', 'DynamicConvMixin'
]
from .bricks import * # noqa: F401,F403
from .head import * # noqa: F401,F403
from .mixins import * # noqa: F401,F403


@ -1 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dynamic_container import DynamicSequential
from .dynamic_conv import BigNasConv2d, DynamicConv2d, OFAConv2d
from .dynamic_embed import DynamicPatchEmbed
from .dynamic_linear import DynamicLinear
from .dynamic_multi_head_attention import DynamicMultiheadAttention
from .dynamic_norm import (DynamicBatchNorm1d, DynamicBatchNorm2d,
DynamicBatchNorm3d, DynamicLayerNorm,
SwitchableBatchNorm2d)
from .dynamic_relative_position import DynamicRelativePosition2D
__all__ = [
'BigNasConv2d', 'DynamicConv2d', 'OFAConv2d', 'DynamicLinear',
'DynamicBatchNorm1d', 'DynamicBatchNorm2d', 'DynamicBatchNorm3d',
'SwitchableBatchNorm2d', 'DynamicSequential', 'DynamicPatchEmbed',
'DynamicLayerNorm', 'DynamicRelativePosition2D',
'DynamicMultiheadAttention'
]


@ -0,0 +1,109 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Iterator, Optional, Set
import torch.nn as nn
from mmengine.model import Sequential
from torch import Tensor
from torch.nn import Module
from mmrazor.models.mutables import DerivedMutable, MutableValue
from mmrazor.models.mutables.base_mutable import BaseMutable
from ..mixins import DynamicMixin
class DynamicSequential(Sequential, DynamicMixin):
"""Dynamic Sequential Container."""
mutable_attrs: nn.ModuleDict
accepted_mutable_attrs: Set[str] = {'depth'}
forward_ignored_module = (MutableValue, DerivedMutable, nn.ModuleDict)
def __init__(self, *args, init_cfg: Optional[dict] = None):
super().__init__(*args, init_cfg=init_cfg)
self.mutable_attrs: Dict[str, BaseMutable] = nn.ModuleDict()
@property
def mutable_depth(self):
"""Mutable depth."""
assert hasattr(self, 'mutable_attrs')
return self.mutable_attrs['depth']
def register_mutable_attr(self: Sequential, attr: str,
mutable: BaseMutable):
"""Register attribute of mutable."""
if attr == 'depth':
self._register_mutable_depth(mutable)
else:
raise NotImplementedError
def _register_mutable_depth(self: Sequential, mutable_depth: MutableValue):
"""Register mutable depth."""
assert hasattr(self, 'mutable_attrs')
assert mutable_depth.current_choice is not None
current_depth = mutable_depth.current_choice
if current_depth > len(self._modules):
raise ValueError(f'Expect the depth of the mutable to be no more '
f'than {len(self._modules)} as `depth`, '
f'but got: {current_depth}.')
self.mutable_attrs['depth'] = mutable_depth
@property
def static_op_factory(self):
"""Corresponding Pytorch OP."""
return Sequential
def to_static_op(self: Sequential) -> Sequential:
"""Convert dynamic Sequential to static one."""
self.check_if_mutables_fixed()
if self.mutable_depth is None:
fixed_depth = len(self)
else:
fixed_depth = self.get_current_choice(self.mutable_depth)
modules = []
passed_module_nums = 0
for module in self:
if isinstance(module, self.forward_ignored_module):
continue
else:
passed_module_nums += 1
if passed_module_nums > fixed_depth:
break
modules.append(module)
return Sequential(*modules)
def forward(self, x: Tensor) -> Tensor:
"""Forward of Dynamic Sequential."""
if self.mutable_depth is None:
# Fall back to the static Sequential forward when no depth mutable is set;
# calling `self(x)` here would recurse indefinitely.
return super().forward(x)
current_depth = self.get_current_choice(self.mutable_depth)
passed_module_nums = 0
for module in self.pure_modules():
passed_module_nums += 1
if passed_module_nums > current_depth:
break
x = module(x)
return x
@property
def pure_module_nums(self) -> int:
"""Number of pure module."""
return sum(1 for _ in self.pure_modules())
def pure_modules(self) -> Iterator[Module]:
"""nn.Module would influence the forward of Sequential."""
for module in self._modules.values():
if isinstance(module, self.forward_ignored_module):
continue
yield module
@classmethod
def convert_from(cls, module: Sequential):
"""Convert the static Sequential to dynamic one."""
dynamic_m = cls(module._modules)
return dynamic_m
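A hedged sketch of how `DynamicSequential` can be used on its own, mirroring what `AutoformerBackbone.make_layers` and `register_mutables` do earlier in this diff (the direct `current_choice` assignment assumes the setter exposed by `OneShotMutableValue`):
```python
# Hedged sketch: make a plain Sequential depth-searchable with a mutable depth.
import torch
import torch.nn as nn
from mmengine.model import Sequential
from mmrazor.models.architectures.dynamic_ops.bricks import DynamicSequential
from mmrazor.models.mutables import OneShotMutableValue

seq = Sequential(*[nn.Linear(8, 8) for _ in range(4)])
dyn_seq = DynamicSequential.convert_from(seq)

mutable_depth = OneShotMutableValue(value_list=[2, 3, 4], default_value=4)
dyn_seq.register_mutable_attr('depth', mutable_depth)

mutable_depth.current_choice = 2      # only the first two layers take part in forward
out = dyn_seq(torch.rand(1, 8))
```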


@ -0,0 +1,142 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Dict, Set, Tuple
import torch.nn as nn
import torch.nn.functional as F
from mmcls.models.utils import PatchEmbed
from mmengine import print_log
from torch import Tensor
from mmrazor.models.mutables.base_mutable import BaseMutable
from mmrazor.registry import MODELS
from ..mixins import DynamicChannelMixin
@MODELS.register_module()
class DynamicPatchEmbed(PatchEmbed, DynamicChannelMixin):
"""Dynamic Patch Embedding.
Note:
Arguments for ``__init__`` of ``DynamicPatchEmbed`` are exactly the same as
:obj:`mmcls.models.utils.PatchEmbed`.
Attributes:
mutable_attrs (ModuleDict[str, BaseMutable]): Mutable attributes,
such as `embed_dims`. The keys of the dict must be in
``accepted_mutable_attrs``.
"""
mutable_attrs: nn.ModuleDict
accepted_mutable_attrs: Set[str] = {'embed_dims'}
attr_mappings: Dict[str, str] = {
'in_channels': 'embed_dims',
'out_channels': 'embed_dims'
}
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.mutable_attrs: Dict[str, BaseMutable] = nn.ModuleDict()
@property
def mutable_embed_dims(self):
"""Mutable embedding dimension."""
assert hasattr(self, 'mutable_attrs')
return self.mutable_attrs['embed_dims']
def register_mutable_attr(self: PatchEmbed, attr: str,
mutable: BaseMutable):
"""Register attribute of mutable."""
self.check_mutable_attr_valid(attr)
if attr in self.attr_mappings:
attr_map = self.attr_mappings[attr]
assert attr_map in self.accepted_mutable_attrs
if attr_map in self.mutable_attrs:
print_log(
f'{attr_map}({attr}) is already in `mutable_attrs`',
level=logging.WARNING)
else:
self._register_mutable_attr(attr_map, mutable)
elif attr in self.accepted_mutable_attrs:
self._register_mutable_attr(attr, mutable)
else:
raise NotImplementedError
def _register_mutable_attr(self, attr, mutable):
"""Register `embed_dims`."""
if attr == 'embed_dims':
self._register_embed_dims(mutable)
else:
raise NotImplementedError
def _register_embed_dims(self: PatchEmbed,
mutable_patch_embedding: BaseMutable) -> None:
"""Register mutable embedding dimension."""
mask_size = mutable_patch_embedding.current_mask.size(0)
if mask_size != self.embed_dims:
raise ValueError(
f'Expect mask size of mutable to be {self.embed_dims} as '
f'`embed_dims`, but got: {mask_size}.')
self.mutable_attrs['embed_dims'] = mutable_patch_embedding
def _get_dynamic_params(self: PatchEmbed) -> Tuple[Tensor, Tensor]:
"""Get mask of ``embed_dims``"""
if 'embed_dims' not in self.mutable_attrs:
return self.projection.weight, self.projection.bias
else:
out_mask = self.mutable_embed_dims.current_mask.to(
self.projection.weight.device)
weight = self.projection.weight[out_mask][:]
bias = self.projection.bias[
out_mask] if self.projection.bias is not None else None # noqa: E501
return weight, bias
def to_static_op(self: PatchEmbed) -> nn.Module:
"""Convert dynamic PatchEmbed to static PatchEmbed."""
self.check_if_mutables_fixed()
assert self.mutable_embed_dims is not None
weight, bias = self._get_dynamic_params()
static_patch_embed = self.static_op_factory(
img_size=self.img_size,
in_channels=3,
embed_dims=self.mutable_embed_dims.activated_channels)
static_patch_embed.projection.weight = nn.Parameter(weight.clone())
static_patch_embed.projection.bias = nn.Parameter(bias.clone())
return static_patch_embed
@property
def static_op_factory(self):
"""Corresponding Pytorch OP."""
return PatchEmbed
@classmethod
def convert_from(cls, module) -> nn.Module:
"""Convert a PatchEmbed to a DynamicPatchEmbed."""
dynamic_patch_embed = cls(
img_size=module.img_size,
in_channels=3,
embed_dims=module.embed_dims,
norm_cfg=None,
conv_cfg=None,
init_cfg=None)
return dynamic_patch_embed
def forward(self, x: Tensor) -> Tensor:
"""Forward of dynamic patch embed."""
weight, bias = self._get_dynamic_params()
x = F.conv2d(
x,
weight,
bias,
stride=16,
padding=self.projection.padding,
dilation=self.projection.dilation).flatten(2).transpose(1, 2)
return x
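# A rough sketch of how `_get_dynamic_params` slices the patch-embedding
# projection: only the output channels selected by the boolean mask are kept.
# The mask below is hypothetical; in the module above it comes from
# `mutable_embed_dims.current_mask`.
import torch
import torch.nn.functional as F

weight = torch.randn(624, 3, 16, 16)    # full projection weight (embed_dims=624)
bias = torch.randn(624)
out_mask = torch.zeros(624, dtype=torch.bool)
out_mask[:576] = True                   # e.g. an activated embed_dims of 576
x = torch.randn(1, 3, 224, 224)
patches = F.conv2d(x, weight[out_mask], bias[out_mask], stride=16)
tokens = patches.flatten(2).transpose(1, 2)
assert tokens.shape == (1, 14 * 14, 576)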

View File

@ -0,0 +1,280 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Dict, Set, Tuple
import torch.nn as nn
import torch.nn.functional as F
from mmengine import print_log
from torch import Tensor
from mmrazor.models.architectures.ops import MultiheadAttention
from mmrazor.models.mutables.base_mutable import BaseMutable
from ..mixins import DynamicChannelMixin
from .dynamic_relative_position import DynamicRelativePosition2D # noqa: E501
class DynamicMultiheadAttention(MultiheadAttention, DynamicChannelMixin):
"""Dynamic Multihead Attention with iRPE..
Note:
Arguments for ``__init__`` of ``DynamicMultiheadAttention`` is
totally same as
:obj:`mmrazor.models.architectures.MultiheadAttention`.
Attributes:
mutable_attrs (ModuleDict[str, BaseMutable]): Mutable attributes,
such as `num_heads` `embed_dims` `q_embed_dims`.
The key of the dict must in ``accepted_mutable_attrs``.
"""
mutable_attrs: nn.ModuleDict
relative_position: bool
max_relative_position: int
w_qs: nn.Linear
w_ks: nn.Linear
w_vs: nn.Linear
embed_dims: int
q_embed_dims: int
proj: nn.Linear
attn_drop_rate: float
accepted_mutable_attrs: Set[str] = {
'num_heads', 'embed_dims', 'q_embed_dims'
}
attr_mappings: Dict[str, str] = {
'in_channels': 'embed_dims',
'out_channels': 'q_embed_dims',
}
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.mutable_attrs: Dict[str, BaseMutable] = nn.ModuleDict()
# dynamic image relative position encoding
if self.relative_position:
self.rel_pos_embed_k = DynamicRelativePosition2D(
self.head_dims, self.max_relative_position)
self.rel_pos_embed_v = DynamicRelativePosition2D(
self.head_dims, self.max_relative_position)
@property
def mutable_num_heads(self):
"""Mutable number of heads."""
assert hasattr(self, 'mutable_attrs')
return self.mutable_attrs['num_heads']
@property
def mutable_embed_dims(self):
"""Mutable embedding dimension."""
assert hasattr(self, 'mutable_attrs')
return self.mutable_attrs['embed_dims']
@property
def mutable_q_embed_dims(self):
"""Mutable intermediate embedding dimension."""
assert hasattr(self, 'mutable_attrs')
return self.mutable_attrs['q_embed_dims']
def register_mutable_attr(self, attr: str, mutable: BaseMutable):
"""Register attribute of mutable."""
self.check_mutable_attr_valid(attr)
if attr in self.attr_mappings:
attr_map = self.attr_mappings[attr]
assert attr_map in self.accepted_mutable_attrs
# if hasattr(self, 'mutable_attrs'):
if attr_map in self.mutable_attrs:
print_log(
f'{attr_map}({attr}) is already in `mutable_attrs`',
level=logging.WARNING)
else:
self._register_mutable_attr(attr_map, mutable)
elif attr in self.accepted_mutable_attrs:
self._register_mutable_attr(attr, mutable)
else:
raise NotImplementedError
def _register_mutable_attr(self, attr: str, mutable: BaseMutable):
"""Register `embed_dims` `q_embed_dims` `num_heads`"""
if attr == 'num_heads':
self._register_mutable_num_heads(mutable)
elif attr == 'embed_dims':
self._register_mutable_embed_dims(mutable)
elif attr == 'q_embed_dims':
self._register_mutable_q_embed_dims(mutable)
else:
raise NotImplementedError
def _register_mutable_num_heads(self, mutable_num_heads):
"""Register the mutable number of heads."""
assert hasattr(self, 'mutable_attrs')
current_choice = mutable_num_heads.current_choice
if current_choice > self.num_heads:
raise ValueError(
f'Expect value of mutable to be smaller than or equal to '
f'{self.num_heads} as `num_heads`, but got: {current_choice}.')
self.mutable_attrs['num_heads'] = mutable_num_heads
def _register_mutable_embed_dims(self, mutable_embed_dims):
"""Register mutable embedding dimension."""
assert hasattr(self, 'mutable_attrs')
mask_size = mutable_embed_dims.current_mask.size(0)
if mask_size != self.embed_dims:
raise ValueError(
f'Expect mask size of mutable to be {self.embed_dims} as '
f'`embed_dims`, but got: {mask_size}.')
self.mutable_attrs['embed_dims'] = mutable_embed_dims
def _register_mutable_q_embed_dims(self, mutable_q_embed_dims):
"""Register intermediate mutable embedding dimension."""
assert hasattr(self, 'mutable_attrs')
self.mutable_attrs['q_embed_dims'] = mutable_q_embed_dims
def _get_dynamic_proj_params(self, w: nn.Linear) -> Tuple[Tensor, Tensor]:
"""Get parameters of dynamic projection.
Note:
The input dimension is decided by `mutable_q_embed_dims`.
The output dimension is decided by `mutable_embed_dims`.
"""
# TODO support mask
if self.mutable_embed_dims is None and \
self.mutable_q_embed_dims is None:
return w.weight, w.bias
if self.mutable_q_embed_dims is not None:
in_features = self.mutable_q_embed_dims.activated_channels
else:
in_features = self.embed_dims
if self.mutable_embed_dims is not None:
out_features = self.mutable_embed_dims.activated_channels
else:
out_features = self.embed_dims
weight = w.weight[:out_features, :in_features]
bias = w.bias[:out_features] if w.bias is not None else None
return weight, bias
def _get_dynamic_qkv_params(self, w: nn.Linear) -> Tuple[Tensor, Tensor]:
"""Get parameters of dynamic QKV.
Note:
The output dimension is decided by `mutable_q_embed_dims`.
The input dimension is decided by `mutable_embed_dims`.
"""
# TODO support mask later
if self.mutable_q_embed_dims is None and \
self.mutable_embed_dims is None:
return w.weight, w.bias
if self.mutable_embed_dims is not None:
in_features = self.mutable_embed_dims.activated_channels
else:
in_features = self.embed_dims
if self.mutable_q_embed_dims is not None:
out_features = self.mutable_q_embed_dims.activated_channels
else:
out_features = self.q_embed_dims
weight = w.weight[:out_features, :in_features]
bias = w.bias[:out_features] if w.bias is not None else None
return weight, bias
def to_static_op(self) -> MultiheadAttention:
"""Convert dynamic MultiheadAttention to static one."""
self.check_if_mutables_fixed()
embed_dims = self.mutable_embed_dims.activated_channels
num_heads = self.mutable_num_heads.current_choice
q_w, q_b = self._get_dynamic_qkv_params(self.w_qs)
k_w, k_b = self._get_dynamic_qkv_params(self.w_ks)
v_w, v_b = self._get_dynamic_qkv_params(self.w_vs)
proj_w, proj_b = self._get_dynamic_proj_params(self.proj)
static_mha = MultiheadAttention(
embed_dims=embed_dims,
num_heads=num_heads,
input_dims=None,
attn_drop_rate=self.attn_drop_rate,
relative_position=self.relative_position,
max_relative_position=self.max_relative_position)
static_mha.w_qs.weight = nn.Parameter(q_w.clone())
static_mha.w_qs.bias = nn.Parameter(q_b.clone())
static_mha.w_ks.weight = nn.Parameter(k_w.clone())
static_mha.w_ks.bias = nn.Parameter(k_b.clone())
static_mha.w_vs.weight = nn.Parameter(v_w.clone())
static_mha.w_vs.bias = nn.Parameter(v_b.clone())
static_mha.proj.weight = nn.Parameter(proj_w.clone())
static_mha.proj.bias = nn.Parameter(proj_b.clone())
if self.relative_position:
static_mha.rel_pos_embed_k = self.rel_pos_embed_k.to_static_op()
static_mha.rel_pos_embed_v = self.rel_pos_embed_v.to_static_op()
return static_mha
@classmethod
def convert_from(cls, module):
"""Convert the static module to dynamic one."""
dynamic_mha = cls(
embed_dims=module.embed_dims,
num_heads=module.num_heads,
)
return dynamic_mha
@property
def static_op_factory(self):
"""Corresponding Pytorch OP."""
return MultiheadAttention
def forward(self, x: Tensor) -> Tensor:
"""Forward of dynamic MultiheadAttention."""
B, N = x.shape[0], x.shape[1]
q_w, q_b = self._get_dynamic_qkv_params(self.w_qs)
k_w, k_b = self._get_dynamic_qkv_params(self.w_ks)
v_w, v_b = self._get_dynamic_qkv_params(self.w_vs)
q_embed_dims = self.mutable_q_embed_dims.activated_channels
num_heads = self.mutable_num_heads.current_choice
q = F.linear(x, q_w, q_b).view(B, N, num_heads,
q_embed_dims // num_heads)
k = F.linear(x, k_w, k_b).view(B, N, num_heads,
q_embed_dims // num_heads)
v = F.linear(x, v_w, v_b).view(B, N, num_heads,
q_embed_dims // num_heads)
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
attn = (q @ k.transpose(-2, -1)) * self.scale
if self.relative_position:
r_p_k = self.rel_pos_embed_k(N, N)
attn = attn + (q.permute(2, 0, 1, 3).reshape(N, num_heads * B, -1) # noqa: E501
@ r_p_k.transpose(2, 1)) \
.transpose(1, 0).reshape(B, num_heads, N, N) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
if self.relative_position:
r_p_v = self.rel_pos_embed_v(N, N)
attn_1 = attn.permute(2, 0, 1, 3).reshape(N, B * num_heads, -1)
x = x + (attn_1 @ r_p_v).transpose(1, 0).reshape(
B, num_heads, N, -1).transpose(2, 1).reshape(B, N, -1)
# proj
weight, bias = self._get_dynamic_proj_params(self.proj)
x = F.linear(x, weight, bias)
x = self.proj_drop(x)
return x
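# A simplified sketch of the weight slicing done in `_get_dynamic_qkv_params`
# above: output rows follow the activated q_embed_dims, input columns follow
# the activated embed_dims. The concrete sizes below are hypothetical.
import torch
import torch.nn as nn
import torch.nn.functional as F

w_qs = nn.Linear(624, 10 * 64)            # full supernet projection: 10 heads x 64 dims
in_features, out_features = 576, 8 * 64   # e.g. activated embed_dims / q_embed_dims
weight = w_qs.weight[:out_features, :in_features]
bias = w_qs.bias[:out_features]
x = torch.randn(2, 197, in_features)
q = F.linear(x, weight, bias).view(2, 197, 8, 64)
assert q.shape == (2, 197, 8, 64)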

View File

@ -4,11 +4,12 @@ from typing import Dict, List, Optional
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import LayerNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmrazor.models.mutables.base_mutable import BaseMutable
from mmrazor.registry import MODELS
from ..mixins.dynamic_mixins import DynamicBatchNormMixin
from ..mixins import DynamicBatchNormMixin, DynamicLayerNormMixin
class _DynamicBatchNorm(_BatchNorm, DynamicBatchNormMixin):
@ -91,6 +92,7 @@ class DynamicBatchNorm1d(_DynamicBatchNorm):
@property
def static_op_factory(self):
"""Corresponding Pytorch OP."""
return nn.BatchNorm1d
def _check_input_dim(self, input: Tensor) -> None:
@ -106,6 +108,7 @@ class DynamicBatchNorm2d(_DynamicBatchNorm):
@property
def static_op_factory(self):
"""Corresponding Pytorch OP."""
return nn.BatchNorm2d
def _check_input_dim(self, input: Tensor) -> None:
@ -121,6 +124,7 @@ class DynamicBatchNorm3d(_DynamicBatchNorm):
@property
def static_op_factory(self):
"""Corresponding Pytorch OP."""
return nn.BatchNorm3d
def _check_input_dim(self, input: Tensor) -> None:
@ -190,3 +194,61 @@ class SwitchableBatchNorm2d(DynamicBatchNorm2d):
def static_op_factory(self):
"""Return initializer of static op."""
return nn.BatchNorm2d
@MODELS.register_module()
class DynamicLayerNorm(LayerNorm, DynamicLayerNormMixin):
"""Applies Layer Normalization over a mini-batch of inputs according to the
`mutable_num_channels` dynamically.
Note:
Arguments for ``__init__`` of ``DynamicLayerNorm`` is totally same as
:obj:`torch.nn.LayerNorm`.
Attributes:
mutable_attrs (ModuleDict[str, BaseMutable]): Mutable attributes,
such as `num_features`. The key of the dict must in
``accepted_mutable_attrs``.
"""
accepted_mutable_attrs = {'num_features'}
def __init__(self, *args, **kwargs):
super(DynamicLayerNorm, self).__init__(*args, **kwargs)
self.mutable_attrs: Dict[str, Optional[BaseMutable]] = nn.ModuleDict()
@property
def static_op_factory(self):
"""Corresponding Pytorch OP."""
return LayerNorm
@classmethod
def convert_from(cls, module: LayerNorm):
"""Convert a _BatchNorm module to a DynamicBatchNorm.
Args:
module (:obj:`torch.nn._BatchNorm`): The original BatchNorm module.
"""
dynamic_ln = cls(
normalized_shape=module.normalized_shape,
eps=module.eps,
elementwise_affine=module.elementwise_affine)
return dynamic_ln
def forward(self, input: Tensor) -> Tensor:
"""Slice the parameters according to `mutable_num_channels`, and
forward."""
self._check_input_dim(input)
weight, bias = self.get_dynamic_params()
self.normalized_shape = (
self.mutable_num_features.activated_channels, )
return F.layer_norm(input, self.normalized_shape, weight, bias,
self.eps)
def _check_input_dim(self, input: Tensor) -> None:
"""Check if input dimension is valid."""
if input.dim() != 3:
raise ValueError('expected 3D input (got {}D input)'.format(
input.dim()))
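# A bare-bones sketch of the sliced layer norm performed in `forward` above:
# the normalized shape shrinks to the activated channels and the affine
# parameters are masked accordingly. All sizes here are hypothetical.
import torch
import torch.nn.functional as F

weight, bias = torch.ones(624), torch.zeros(624)
mask = torch.zeros(624, dtype=torch.bool)
mask[:576] = True                        # e.g. 576 activated features
x = torch.randn(2, 197, 576)
out = F.layer_norm(x, (576, ), weight[mask], bias[mask], 1e-6)
assert out.shape == x.shape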

View File

@ -0,0 +1,154 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Dict, Set
import torch
from mmengine import print_log
from torch import Tensor, nn
from mmrazor.models.architectures.ops import RelativePosition2D
from mmrazor.models.mutables.base_mutable import BaseMutable
from ..mixins import DynamicChannelMixin
class DynamicRelativePosition2D(RelativePosition2D, DynamicChannelMixin):
"""Searchable RelativePosition module.
Note:
Arguments for ``__init__`` of ``DynamicRelativePosition2D`` are exactly
the same as those of :obj:`mmrazor.models.architectures.RelativePosition2D`.
Attributes:
    mutable_attrs (ModuleDict[str, BaseMutable]): Mutable attributes,
        such as `head_dims`. The keys of the dict must be in
        ``accepted_mutable_attrs``.
"""
mutable_attrs: nn.ModuleDict
head_dims: int
max_relative_position: int
embeddings_table_v: nn.Parameter
embeddings_table_h: nn.Parameter
accepted_mutable_attrs: Set[str] = {'head_dims'}
attr_mappings: Dict[str, str] = {
'in_channels': 'head_dims',
'out_channels': 'head_dims',
}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.mutable_attrs: Dict[str, BaseMutable] = nn.ModuleDict()
@property
def mutable_head_dims(self):
"""Mutable head dimension."""
assert hasattr(self, 'mutable_attrs')
return self.mutable_attrs['head_dims']
def register_mutable_attr(self, attr: str, mutable: BaseMutable):
"""Register attribute of mutable."""
self.check_mutable_attr_valid(attr)
if attr in self.attr_mappings:
attr_map = self.attr_mappings[attr]
assert attr_map in self.accepted_mutable_attrs
if attr_map in self.mutable_attrs:
print_log(
f'{attr_map}({attr}) is already in `mutable_attrs`',
level=logging.WARNING)
else:
self._register_mutable_attr(attr_map, mutable)
elif attr in self.accepted_mutable_attrs:
self._register_mutable_attr(attr, mutable)
else:
raise NotImplementedError
def _register_mutable_attr(self, attr, mutable):
"""Register `head_dims`"""
if attr == 'head_dims':
self._register_mutable_head_dims(mutable)
else:
raise NotImplementedError
def _register_mutable_head_dims(self,
mutable_head_dims: BaseMutable) -> None:
"""Register head dimension."""
assert hasattr(self, 'mutable_attrs')
self.mutable_attrs['head_dims'] = mutable_head_dims
def to_static_op(self) -> nn.Module:
"""Convert dynamic RelativePosition2D to static One."""
self.check_if_mutables_fixed()
assert self.mutable_head_dims is not None
self.current_head_dim = self.mutable_head_dims.activated_channels
static_relative_position = self.static_op_factory(
self.current_head_dim)
static_relative_position.embeddings_table_v = \
nn.Parameter(
self.embeddings_table_v[:, :self.current_head_dim].clone())
static_relative_position.embeddings_table_h = \
nn.Parameter(
self.embeddings_table_h[:, :self.current_head_dim].clone())
return static_relative_position
@property
def static_op_factory(self):
"""Corresponding Pytorch OP."""
return RelativePosition2D
@classmethod
def convert_from(cls, module):
"""Convert a RP to a dynamic RP."""
dynamic_rp = cls(
head_dims=module.head_dims,
max_relative_position=module.max_relative_position)
return dynamic_rp
def forward(self, length_q, length_k) -> Tensor:
"""Forward of Dynamic Relative Position."""
if self.mutable_head_dims is None:
self.current_head_dim = self.head_dims
else:
self.current_head_dim = self.mutable_head_dims.activated_channels
self.sample_eb_table_h = self.embeddings_table_h[:, :self.
current_head_dim]
self.sample_eb_table_v = self.embeddings_table_v[:, :self.
current_head_dim]
# remove the first cls token distance computation
length_q = length_q - 1
length_k = length_k - 1
range_vec_q = torch.arange(length_q)
range_vec_k = torch.arange(length_k)
# compute the row and column distance
distance_mat_v = (
range_vec_k[None, :] // int(length_q**0.5) -
range_vec_q[:, None] // int(length_q**0.5))
distance_mat_h = (
range_vec_k[None, :] % int(length_q**0.5) -
range_vec_q[:, None] % int(length_q**0.5))
distance_mat_clipped_v = torch.clamp(distance_mat_v,
-self.max_relative_position,
self.max_relative_position)
distance_mat_clipped_h = torch.clamp(distance_mat_h,
-self.max_relative_position,
self.max_relative_position)
final_mat_v = distance_mat_clipped_v + self.max_relative_position + 1
final_mat_h = distance_mat_clipped_h + self.max_relative_position + 1
# pad the 0 which represent the cls token
final_mat_v = torch.nn.functional.pad(final_mat_v, (1, 0, 1, 0),
'constant', 0)
final_mat_h = torch.nn.functional.pad(final_mat_h, (1, 0, 1, 0),
'constant', 0)
final_mat_v = torch.LongTensor(final_mat_v)
final_mat_h = torch.LongTensor(final_mat_h)
# get the embeddings with the corresponding distance
embeddings = self.sample_eb_table_v[final_mat_v] + \
self.sample_eb_table_h[final_mat_h]
return embeddings
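# A tiny numeric illustration of the relative-position index construction used
# in `forward` above, for a 2x2 token grid (the cls-token row/column is padded
# in with index 0 afterwards). Values here are chosen only for the printout.
import torch
import torch.nn.functional as F

length = 4                                   # 2x2 patch tokens
rng = torch.arange(length)
side = int(length ** 0.5)
dist_v = rng[None, :] // side - rng[:, None] // side   # row distances
dist_h = rng[None, :] % side - rng[:, None] % side     # column distances
max_rel = 14
final_v = torch.clamp(dist_v, -max_rel, max_rel) + max_rel + 1
final_h = torch.clamp(dist_h, -max_rel, max_rel) + max_rel + 1
final_v = F.pad(final_v, (1, 0, 1, 0), 'constant', 0)  # index 0 = cls token
print(final_v)                               # 5x5 index matrix into the table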

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dynamic_linear_head import DynamicLinearClsHead # noqa: F401
__all__ = ['DynamicLinearClsHead']

View File

@ -0,0 +1,80 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import Optional, Tuple
import torch
from mmcls.models import ClsHead
from mmrazor.models.mutables.base_mutable import BaseMutable
from mmrazor.models.mutables.mutable_channel import MutableChannelContainer
from mmrazor.models.mutables.mutable_channel.units import \
OneShotMutableChannelUnit
from mmrazor.registry import MODELS
from ..bricks.dynamic_linear import DynamicLinear
class DynamicHead:
@abstractmethod
def connect_with_backbone(self,
backbone_output_mutable: BaseMutable) -> None:
"""Connect with Dynamic Backbone."""
...
@MODELS.register_module()
class DynamicLinearClsHead(ClsHead, DynamicHead):
"""Dynamic Linear classification head for Autoformer.
Args:
num_classes (int): Number of classes.
in_channels (int): Number of input channels.
init_cfg (Optional[dict], optional): Init config.
Defaults to dict(type='Normal',
layer='DynamicLinear', std=0.01).
"""
def __init__(self,
num_classes: int = 1000,
in_channels: int = 624,
init_cfg: Optional[dict] = dict(
type='Normal', layer='DynamicLinear', std=0.01),
**kwargs):
super().__init__(init_cfg=init_cfg, **kwargs)
self.in_channels = in_channels
self.num_classes = num_classes
if self.num_classes <= 0:
raise ValueError(
f'num_classes={num_classes} must be a positive integer')
self.fc = DynamicLinear(self.in_channels, self.num_classes)
def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
"""The process before the final classification head.
The input ``feats`` is a tuple of tensors, and each tensor is the
feature of a backbone stage. In ``DynamicLinearClsHead``, we just obtain
the feature of the last stage.
"""
# The DynamicLinearClsHead doesn't have any other modules, just return
# after unpacking.
return feats[-1]
def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
"""The forward process."""
pre_logits = self.pre_logits(feats)
# The final classification head.
cls_score = self.fc(pre_logits)
return cls_score
def connect_with_backbone(self,
backbone_output_mutable: BaseMutable) -> None:
"""Connect dynamic backbone."""
OneShotMutableChannelUnit._register_channel_container(
self, MutableChannelContainer)
MutableChannelContainer.register_mutable_channel_to_module(
self.fc, backbone_output_mutable, False)

View File

@ -1,9 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dynamic_conv_mixins import DynamicConvMixin
from .dynamic_layernorm_mixins import DynamicLayerNormMixin
from .dynamic_mixins import (DynamicBatchNormMixin, DynamicChannelMixin,
DynamicLinearMixin, DynamicMixin)
__all__ = [
'DynamicChannelMixin', 'DynamicBatchNormMixin', 'DynamicLinearMixin',
'DynamicMixin', 'DynamicConvMixin'
'DynamicChannelMixin',
'DynamicBatchNormMixin',
'DynamicLinearMixin',
'DynamicMixin',
'DynamicConvMixin',
'DynamicLayerNormMixin',
]

View File

@ -0,0 +1,147 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Dict, Optional, Set, Tuple
import torch
from mmengine import print_log
from torch import Tensor, nn
from torch.nn import LayerNorm
from mmrazor.models.mutables.base_mutable import BaseMutable
from .dynamic_mixins import DynamicChannelMixin
class DynamicLayerNormMixin(DynamicChannelMixin):
"""A mixin class for Pytorch LayerNorm, which can mutate
``num_features``."""
accepted_mutable_attrs: Set[str] = {'num_features'}
attr_mappings: Dict[str, str] = {
'in_channels': 'num_features',
'out_channels': 'num_features',
}
@property
def num_features(self):
return getattr(self, 'normalized_shape')[0]
@property
def mutable_num_features(self):
"""Mutable number of features."""
assert hasattr(self, 'mutable_attrs')
return self.mutable_attrs['num_features']
def register_mutable_attr(self, attr, mutable):
"""Register attribute of mutable."""
self.check_mutable_attr_valid(attr)
if attr in self.attr_mappings:
attr_map = self.attr_mappings[attr]
assert attr_map in self.accepted_mutable_attrs
if attr_map in self.mutable_attrs:
print_log(
f'{attr_map}({attr}) is already in `mutable_attrs`',
level=logging.WARNING)
else:
self._register_mutable_attr(attr_map, mutable)
elif attr in self.accepted_mutable_attrs:
self._register_mutable_attr(attr, mutable)
else:
raise NotImplementedError
def _register_mutable_attr(self, attr, mutable):
"""Register `num_features`."""
if attr == 'num_features':
self._register_mutable_num_features(mutable)
else:
raise NotImplementedError
def _register_mutable_num_features(
self: LayerNorm, mutable_num_features: BaseMutable) -> None:
"""Mutate ``num_features`` with given mutable.
Args:
mutable_num_features (BaseMutable): Mutable for controlling
``num_features``.
Raises:
RuntimeError: Error if ``elementwise_affine`` is False.
ValueError: Error if size of mask is not the same as ``num_features``.
"""
if not self.elementwise_affine:
raise RuntimeError(
'num_features can not be mutated if `elementwise_affine` '
'is False')
self.check_mutable_channels(mutable_num_features)
mask_size = mutable_num_features.current_mask.size(0)
# normalized_shape is a tuple
if mask_size != self.normalized_shape[0]:
raise ValueError(
f'Expect mask size of mutable to be {self.normalized_shape}'
f' as `normalized_shape`, but got: {mask_size}.')
self.mutable_attrs['num_features'] = mutable_num_features
def _get_num_features_mask(self: LayerNorm) -> Optional[torch.Tensor]:
"""Get mask of ``num_features``."""
if self.elementwise_affine:
refer_tensor = self.weight
else:
return None
if 'num_features' in self.mutable_attrs:
out_mask = self.mutable_num_features.current_mask.to(
refer_tensor.device)
else:
out_mask = torch.ones_like(refer_tensor).bool()
return out_mask
def get_dynamic_params(
self: LayerNorm) -> Tuple[Optional[Tensor], Optional[Tensor]]:
"""Get dynamic parameters that will be used in forward process.
Returns:
Tuple[Optional[Tensor], Optional[Tensor]]: Sliced weight and bias.
"""
out_mask = self._get_num_features_mask()
if self.elementwise_affine:
weight = self.weight[out_mask]
bias = self.bias[out_mask]
else:
weight, bias = self.weight, self.bias
return weight, bias
def to_static_op(self: LayerNorm) -> nn.Module:
"""Convert dynamic LayerNormxd to :obj:`torch.nn.LayerNormxd`.
Returns:
torch.nn.LayerNormxd: :obj:`torch.nn.LayerNormxd` with sliced
parameters.
"""
self.check_if_mutables_fixed()
weight, bias = self.get_dynamic_params()
if 'num_features' in self.mutable_attrs:
num_features = self.mutable_attrs['num_features'].current_mask.sum(
).item()
else:
num_features = self.num_features
static_ln = self.static_op_factory(
normalized_shape=num_features,
eps=self.eps,
elementwise_affine=self.elementwise_affine)
if weight is not None:
static_ln.weight = nn.Parameter(weight.clone())
if bias is not None:
static_ln.bias = nn.Parameter(bias.clone())
return static_ln

View File

@ -74,12 +74,16 @@ class DynamicMixin(ABC):
Raises:
RuntimeError: Error if an existing mutable is not fixed.
"""
from mmrazor.models.mutables import (DerivedMutable,
MutableChannelContainer)
def check_fixed(mutable: Optional[BaseMutable]) -> None:
if mutable is not None and not mutable.is_fixed:
raise RuntimeError(f'Mutable {type(mutable)} is not fixed.')
for mutable in self.mutable_attrs.values(): # type: ignore
if isinstance(mutable, (MutableChannelContainer, DerivedMutable)):
continue
check_fixed(mutable)
def check_mutable_attr_valid(self, attr):
@ -115,6 +119,11 @@ class DynamicChannelMixin(DynamicMixin):
``mutable_out_channels`` APIs.
"""
attr_mappings: Dict[str, str] = {
'in_channels': 'in_channels',
'out_channels': 'out_channels',
}
@staticmethod
def check_mutable_channels(mutable_channels: BaseMutable) -> None:
"""Check if mutable has `currnet_mask` attribute.

View File

@ -6,9 +6,11 @@ from .efficientnet_series import ConvBnAct, DepthwiseSeparableConv
from .gather_tensors import GatherTensors
from .mobilenet_series import MBBlock
from .shufflenet_series import ShuffleBlock, ShuffleXception
from .transformer_series import MultiheadAttention, RelativePosition2D
__all__ = [
'ShuffleBlock', 'ShuffleXception', 'DartsPoolBN', 'DartsDilConv',
'DartsSepConv', 'DartsSkipConnect', 'DartsZero', 'MBBlock', 'Identity',
'ConvBnAct', 'DepthwiseSeparableConv', 'GatherTensors'
'ConvBnAct', 'DepthwiseSeparableConv', 'GatherTensors',
'RelativePosition2D', 'MultiheadAttention'
]

View File

@ -0,0 +1,192 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional
import torch
import torch.nn as nn
from mmengine.model.weight_init import trunc_normal_
class RelativePosition2D(nn.Module):
"""Rethinking and Improving Relative Position Encoding for Vision
Transformer.
ICCV 2021. https://arxiv.org/pdf/2107.14222.pdf
Image RPE (iRPE for short) methods are new relative position encoding
methods dedicated to 2D images.
Args:
head_dims (int): embedding dims of relative position.
max_relative_position (int): The max relative position distance.
"""
def __init__(self, head_dims: int, max_relative_position: int = 14):
super().__init__()
self.head_dims = head_dims
self.max_relative_position = max_relative_position
# The first element in embeddings_table_v is the vertical embedding
# for the class
self.embeddings_table_v = nn.Parameter(
torch.randn(max_relative_position * 2 + 2, head_dims))
self.embeddings_table_h = nn.Parameter(
torch.randn(max_relative_position * 2 + 2, head_dims))
trunc_normal_(self.embeddings_table_v, std=.02)
trunc_normal_(self.embeddings_table_h, std=.02)
def forward(self, length_q, length_k):
# remove the first cls token distance computation
length_q = length_q - 1
length_k = length_k - 1
range_vec_q = torch.arange(length_q)
range_vec_k = torch.arange(length_k)
# compute the row and column distance
distance_mat_v = (
range_vec_k[None, :] // int(length_q**0.5) -
range_vec_q[:, None] // int(length_q**0.5))
distance_mat_h = (
range_vec_k[None, :] % int(length_q**0.5) -
range_vec_q[:, None] % int(length_q**0.5))
# clip the distance to the range of
# [-max_relative_position, max_relative_position]
distance_mat_clipped_v = torch.clamp(distance_mat_v,
-self.max_relative_position,
self.max_relative_position)
distance_mat_clipped_h = torch.clamp(distance_mat_h,
-self.max_relative_position,
self.max_relative_position)
# translate the distance into the range [1, 2 * max_relative_position + 1];
# 0 is reserved for the cls token
final_mat_v = distance_mat_clipped_v + self.max_relative_position + 1
final_mat_h = distance_mat_clipped_h + self.max_relative_position + 1
# pad the 0 which represent the cls token
final_mat_v = torch.nn.functional.pad(final_mat_v, (1, 0, 1, 0),
'constant', 0)
final_mat_h = torch.nn.functional.pad(final_mat_h, (1, 0, 1, 0),
'constant', 0)
final_mat_v = torch.LongTensor(final_mat_v)
final_mat_h = torch.LongTensor(final_mat_h)
# get the embeddings with the corresponding distance
embeddings = self.embeddings_table_v[
final_mat_v] + self.embeddings_table_h[final_mat_h]
return embeddings
class MultiheadAttention(nn.Module):
"""Multi-head Attention Module with iRPE.
This module implements multi-head attention that supports different input
dims and embed dims. And it also supports a shortcut from ``value``, which
is useful if input dims is not the same with embed dims.
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads.
input_dims (int, optional): The input dimension, and if None,
use ``embed_dims``. Defaults to None.
attn_drop_rate (float): Dropout rate of the dropout layer after the
    attention calculation of query and key. Defaults to 0.
proj_drop (float): Dropout rate of the dropout layer after the
output projection. Defaults to 0.
dropout_layer (dict): The dropout config before adding the shortcut.
Defaults to ``dict(type='Dropout', drop_prob=0.)``.
relative_position (bool, optional): Whether use relative position.
Defaults to True.
max_relative_position (int): The max relative position distance.
qkv_bias (bool): If True, add a learnable bias to q, k, v.
Defaults to True.
qk_scale (float, optional): Override default qk scale of
``head_dim ** -0.5`` if set. Defaults to None.
proj_bias (bool): If True, add a learnable bias to output projection.
Defaults to True.
v_shortcut (bool): Add a shortcut from value to output. It's usually
used if ``input_dims`` is different from ``embed_dims``.
Defaults to False.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims: int,
num_heads: int,
input_dims: Optional[int] = None,
attn_drop_rate: float = 0.,
proj_drop: float = 0.,
dropout_layer: Dict = dict(type='Dropout', drop_prob=0.),
relative_position: Optional[bool] = True,
max_relative_position: int = 14,
qkv_bias: bool = True,
qk_scale: Optional[float] = None,
proj_bias: bool = True,
v_shortcut: bool = False,
init_cfg: Optional[dict] = None):
super().__init__()
self.input_dims = input_dims or embed_dims
self.embed_dims = embed_dims
self.num_heads = num_heads
self.v_shortcut = v_shortcut
self.relative_position = relative_position
self.max_relative_position = max_relative_position
self.attn_drop_rate = attn_drop_rate
self.head_dims = 64 # unit
self.scale = qk_scale or self.head_dims**-0.5
self.q_embed_dims = num_heads * self.head_dims
self.w_qs = nn.Linear(
self.input_dims, num_heads * self.head_dims, bias=qkv_bias)
self.w_ks = nn.Linear(
self.input_dims, num_heads * self.head_dims, bias=qkv_bias)
self.w_vs = nn.Linear(
self.input_dims, num_heads * self.head_dims, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop_rate)
self.proj = nn.Linear(
num_heads * self.head_dims, embed_dims, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
self.out_drop = nn.Dropout(dropout_layer['drop_prob'])
# image relative position encoding
if self.relative_position:
self.rel_pos_embed_k = RelativePosition2D(
self.head_dims, self.max_relative_position)
self.rel_pos_embed_v = RelativePosition2D(
self.head_dims, self.max_relative_position)
def forward(self, x):
B, N, _ = x.shape
q = self.w_qs(x).view(B, N, self.num_heads, self.head_dims)
k = self.w_ks(x).view(B, N, self.num_heads, self.head_dims)
v = self.w_vs(x).view(B, N, self.num_heads, self.head_dims)
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
attn = (q @ k.transpose(-2, -1)) * self.scale
if self.relative_position:
r_p_k = self.rel_pos_embed_k(N, N)
attn = attn + (q.permute(2, 0, 1, 3).reshape(N, self.num_heads * B, -1) # noqa: E501
@ r_p_k.transpose(2, 1)) \
.transpose(1, 0).reshape(B, self.num_heads, N, N) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
if self.relative_position:
r_p_v = self.rel_pos_embed_v(N, N)
t_attn = attn.permute(2, 0, 1, 3).reshape(N, B * self.num_heads,
-1)
x = x + (t_attn @ r_p_v).transpose(1, 0).reshape(
B, self.num_heads, N, -1).transpose(2, 1).reshape(B, N, -1)
x = self.proj(x)
x = self.out_drop(self.proj_drop(x))
if self.v_shortcut:
x = v.squeeze(1) + x
return x
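# A brief usage sketch of the MultiheadAttention op above, driven by a
# ViT-style sequence of 197 tokens (196 patches + cls). The sizes are
# arbitrary; the import mirrors the one used elsewhere in this change.
import torch
from mmrazor.models.architectures.ops import MultiheadAttention

mha = MultiheadAttention(embed_dims=640, num_heads=10, relative_position=True)
x = torch.randn(2, 197, 640)
assert mha(x).shape == (2, 197, 640)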

View File

@ -61,8 +61,15 @@ def _expand_mask_fn(
def fn():
mask = mutable.current_mask
expand_num_channels = int(mask.size(0) * expand_ratio)
expand_choice = int(mutable.current_choice * expand_ratio)
if isinstance(expand_ratio, int):
expand_num_channels = mask.size(0) * expand_ratio
expand_choice = mutable.current_choice * expand_ratio
elif isinstance(expand_ratio, float):
expand_num_channels = int(mask.size(0) * expand_ratio)
expand_choice = int(mutable.current_choice * expand_ratio)
else:
raise NotImplementedError(
f'Not support type of expand_ratio: {type(expand_ratio)}')
expand_mask = torch.zeros(expand_num_channels).bool()
expand_mask[:expand_choice] = True
@ -136,25 +143,62 @@ class DerivedMethodMixin:
def derive_expand_mutable(
self: MutableProtocol,
expand_ratio: Union[int, float]) -> 'DerivedMutable':
expand_ratio: Union[int, BaseMutable, float]) -> 'DerivedMutable':
"""Derive expand mutable, usually used with `expand_ratio`."""
choice_fn = _expand_choice_fn(self, expand_ratio=expand_ratio)
# avoid circular import
if isinstance(expand_ratio, int):
choice_fn = _expand_choice_fn(self, expand_ratio=expand_ratio)
elif isinstance(expand_ratio, float):
choice_fn = _expand_choice_fn(self, expand_ratio=expand_ratio)
elif isinstance(expand_ratio, BaseMutable):
current_ratio = expand_ratio.current_choice
choice_fn = _expand_choice_fn(self, expand_ratio=current_ratio)
else:
raise NotImplementedError(
f'Not support type of ratio: {type(expand_ratio)}')
mask_fn: Optional[Callable] = None
if hasattr(self, 'current_mask'):
mask_fn = _expand_mask_fn(self, expand_ratio=expand_ratio)
if isinstance(expand_ratio, int):
mask_fn = _expand_mask_fn(self, expand_ratio=expand_ratio)
elif isinstance(expand_ratio, float):
mask_fn = _expand_mask_fn(self, expand_ratio=expand_ratio)
elif isinstance(expand_ratio, BaseMutable):
mask_fn = _expand_mask_fn(self, expand_ratio=current_ratio)
else:
raise NotImplementedError(
f'Not support type of ratio: {type(expand_ratio)}')
return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn)
def derive_divide_mutable(self: MutableProtocol,
ratio: int,
ratio: Union[int, float, BaseMutable],
divisor: int = 8) -> 'DerivedMutable':
"""Derive divide mutable, usually used with `make_divisable`."""
choice_fn = _divide_choice_fn(self, ratio=ratio, divisor=divisor)
from .mutable_channel import BaseMutableChannel
# avoid circular import
if isinstance(ratio, int):
choice_fn = _divide_choice_fn(self, ratio=ratio, divisor=divisor)
current_ratio = ratio
elif isinstance(ratio, float):
current_ratio = int(ratio)
choice_fn = _divide_choice_fn(self, ratio=current_ratio, divisor=1)
elif isinstance(ratio, BaseMutable):
current_ratio = int(ratio.current_choice)
choice_fn = _divide_choice_fn(self, ratio=current_ratio, divisor=1)
else:
raise NotImplementedError(
f'Not support type of ratio: {type(ratio)}')
mask_fn: Optional[Callable] = None
if hasattr(self, 'current_mask'):
mask_fn = _divide_mask_fn(self, ratio=ratio, divisor=divisor)
if isinstance(self, BaseMutableChannel) and hasattr(
self, 'current_mask'):
mask_fn = _divide_mask_fn(
self, ratio=current_ratio, divisor=divisor)
elif getattr(self, 'mask_fn', None): # OneShotMutableChannel
mask_fn = _divide_mask_fn(
self, ratio=current_ratio, divisor=divisor)
return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn)
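# A small numeric sketch of the expand-mask arithmetic handled above: with an
# integer ratio the sizes stay exact, with a float ratio they are truncated
# via int(). The base channel counts here are hypothetical.
import torch

num_channels, current_choice = 64, 48
for expand_ratio in (2, 1.5):
    expand_num_channels = int(num_channels * expand_ratio)
    expand_choice = int(current_choice * expand_ratio)
    expand_mask = torch.zeros(expand_num_channels, dtype=torch.bool)
    expand_mask[:expand_choice] = True
    print(expand_ratio, int(expand_mask.sum()), '/', expand_num_channels)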

View File

@ -3,9 +3,9 @@ import copy
import torch
from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin
from mmrazor.registry import MODELS
from mmrazor.utils import IndexDict
from ...architectures.dynamic_ops.mixins import DynamicChannelMixin
from .base_mutable_channel import BaseMutableChannel
from .simple_mutable_channel import SimpleMutableChannel
@ -66,7 +66,7 @@ class MutableChannelContainer(BaseMutableChannel):
def register_mutable(self, mutable_channel: BaseMutableChannel, start: int,
end: int):
"""Register/Store BaseMutableChannel in the MutableChannelContainer in
"""Register/Store BaseMutableChannel in the MutableChannelContainer in
the range [start,end)"""
self.mutable_channels[(start, end)] = mutable_channel

View File

@ -82,7 +82,7 @@ class SquentialMutableChannel(SimpleMutableChannel):
mutable2: OneShotMutableValue) -> Callable:
def fn():
return mutable1.current_choice * mutable2.current_choice
return int(mutable1.current_choice * mutable2.current_choice)
return fn
@ -93,9 +93,10 @@ class SquentialMutableChannel(SimpleMutableChannel):
mask = mutable1.current_mask
max_expand_ratio = mutable2.max_choice
current_expand_ratio = mutable2.current_choice
expand_num_channels = mask.size(0) * max_expand_ratio
expand_num_channels = int(mask.size(0) * max_expand_ratio)
expand_choice = mutable1.current_choice * current_expand_ratio
expand_choice = int(mutable1.current_choice *
current_expand_ratio)
expand_mask = torch.zeros(expand_num_channels).bool()
expand_mask[:expand_choice] = True
@ -113,10 +114,17 @@ class SquentialMutableChannel(SimpleMutableChannel):
def __floordiv__(self, other) -> DerivedMutable:
if isinstance(other, int):
return self.derive_divide_mutable(other)
elif isinstance(other, float):
return self.derive_divide_mutable(int(other))
if isinstance(other, tuple):
assert len(other) == 2
return self.derive_divide_mutable(*other)
from ..mutable_value import OneShotMutableValue
if isinstance(other, OneShotMutableValue):
ratio = other.current_choice
return self.derive_divide_mutable(ratio)
raise TypeError(f'Unsupported type {type(other)} for div!')
def _num2ratio(self, choice: Union[int, float]) -> float:

View File

@ -262,14 +262,16 @@ class ChannelUnit(BaseModule):
def add_ouptut_related(self, channel: Channel):
"""Add a Channel which is output related."""
assert channel.is_output_channel
assert self.num_channels == channel.num_channels
assert self.num_channels == \
int(channel.num_channels // channel.expand_ratio)
if channel not in self.output_related:
self.output_related.append(channel)
def add_input_related(self, channel: Channel):
"""Add a Channel which is input related."""
assert channel.is_output_channel is False
assert self.num_channels == channel.num_channels
assert self.num_channels == \
int(channel.num_channels // channel.expand_ratio)
if channel not in self.input_related:
self.input_related.append(channel)

View File

@ -6,26 +6,23 @@ from typing import Dict, List, Type, TypeVar
import torch.nn as nn
from mmrazor.models.architectures import dynamic_ops
from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin
from mmrazor.models.mutables import DerivedMutable
from mmrazor.models.mutables.mutable_channel import (BaseMutableChannel,
MutableChannelContainer)
from mmrazor.models.mutables.mutable_value import MutableValue
from .channel_unit import Channel, ChannelUnit
class MutableChannelUnit(ChannelUnit):
# init methods
def __init__(self, num_channels: int, **kwargs) -> None:
"""MutableChannelUnit inherits from ChannelUnit, which manages channels
with channel-dependency.
with channel-dependency. Compared with ChannelUnit, MutableChannelUnit
defines the core interfaces for pruning. By inheriting
MutableChannelUnit, we can implement a variant pruning and nas
algorithm. These apis includes.
Compared with ChannelUnit, MutableChannelUnit defines the core
interfaces for pruning. By inheriting MutableChannelUnit,
we can implement a variant pruning and nas algorithm.
These apis includes
- basic property
- name
- is_mutable
@ -60,6 +57,7 @@ class MutableChannelUnit(ChannelUnit):
mutable2units,
is_output=True):
for index, mutable in contanier.mutable_channels.items():
expand_ratio = 1
if isinstance(mutable, DerivedMutable):
source_mutables: Set = \
mutable._trace_source_mutables()
@ -72,6 +70,17 @@ class MutableChannelUnit(ChannelUnit):
'used in DerivedMutable')
mutable = list(source_channel_mutables)[0]
source_value_mutables = [
mutable for mutable in source_mutables
if isinstance(mutable, MutableValue)
]
assert len(source_value_mutables) <= 1, (
'only support one mutable value '
'used in DerivedMutable')
expand_ratio = int(
list(source_value_mutables)
[0].current_choice) if source_value_mutables else 1
if mutable not in mutable2units:
mutable2units[mutable] = cls.init_from_mutable_channel(
mutable)
@ -83,14 +92,16 @@ class MutableChannelUnit(ChannelUnit):
module_name,
module,
index,
is_output_channel=is_output))
is_output_channel=is_output,
expand_ratio=expand_ratio))
else:
unit.add_input_related(
Channel(
module_name,
module,
index,
is_output_channel=is_output))
is_output_channel=is_output,
expand_ratio=expand_ratio))
mutable2units: Dict = {}
for name, module in model.named_modules():
@ -121,7 +132,7 @@ class MutableChannelUnit(ChannelUnit):
if channel.is_mutable is False:
all_channel_prunable = False
break
if isinstance(channel.module, dynamic_ops.DynamicChannelMixin):
if isinstance(channel.module, DynamicChannelMixin):
has_dynamic_op = True
return has_dynamic_op, all_channel_prunable
@ -223,29 +234,16 @@ class MutableChannelUnit(ChannelUnit):
model: nn.Module, container_class: Type[MutableChannelContainer]):
"""register channel container for dynamic ops."""
for module in model.modules():
if isinstance(module, dynamic_ops.DynamicChannelMixin):
if isinstance(module, DynamicChannelMixin):
in_channels = getattr(module,
module.attr_mappings['in_channels'], 0)
if module.get_mutable_attr('in_channels') is None:
in_channels = 0
if isinstance(module, nn.Conv2d):
in_channels = module.in_channels
elif isinstance(module, nn.modules.batchnorm._BatchNorm):
in_channels = module.num_features
elif isinstance(module, nn.Linear):
in_channels = module.in_features
else:
raise NotImplementedError()
module.register_mutable_attr('in_channels',
container_class(in_channels))
out_channels = getattr(module,
module.attr_mappings['out_channels'], 0)
if module.get_mutable_attr('out_channels') is None:
out_channels = 0
if isinstance(module, nn.Conv2d):
out_channels = module.out_channels
elif isinstance(module, nn.modules.batchnorm._BatchNorm):
out_channels = module.num_features
elif isinstance(module, nn.Linear):
out_channels = module.out_features
else:
raise NotImplementedError()
module.register_mutable_attr('out_channels',
container_class(out_channels))
@ -253,7 +251,7 @@ class MutableChannelUnit(ChannelUnit):
# register mutable_channel
for channel in list(self.input_related) + list(self.output_related):
module = channel.module
if isinstance(module, dynamic_ops.DynamicChannelMixin):
if isinstance(module, DynamicChannelMixin):
container: MutableChannelContainer
if channel.is_output_channel and module.get_mutable_attr(
'out_channels') is not None:

View File

@ -39,6 +39,7 @@ class OneShotMutableChannelUnit(SequentialMutableChannelUnit):
min_ratio=0.9) -> None:
super().__init__(num_channels, choice_mode, divisor, min_value,
min_ratio)
candidate_choices = copy.copy(candidate_choices)
if candidate_choices == []:
candidate_choices.append(
@ -50,6 +51,8 @@ class OneShotMutableChannelUnit(SequentialMutableChannelUnit):
self.candidate_choices,
choice_mode)
self.unit_predefined = False
@classmethod
def init_from_mutable_channel(cls, mutable_channel: OneShotMutableChannel):
unit = cls(mutable_channel.num_channels,
@ -61,7 +64,8 @@ class OneShotMutableChannelUnit(SequentialMutableChannelUnit):
def prepare_for_pruning(self, model: nn.Module):
"""Prepare for pruning."""
super().prepare_for_pruning(model)
if not self.unit_predefined:
super().prepare_for_pruning(model)
self.current_choice = self.max_choice

View File

@ -99,7 +99,7 @@ class MutableValue(BaseMutable, DerivedMethodMixin):
return len(self.choices)
@property
def current_choice(self) -> Optional[Any]:
def current_choice(self) -> Value:
"""Current choice of mutable value."""
return self._current_choice
@ -116,7 +116,7 @@ class MutableValue(BaseMutable, DerivedMethodMixin):
"""Please refer to method :func:`__mul__`."""
return self * other
def __mul__(self, other: int) -> DerivedMutable:
def __mul__(self, other: Union[int, float]) -> DerivedMutable:
"""Overload `*` operator.
Args:
@ -127,7 +127,8 @@ class MutableValue(BaseMutable, DerivedMethodMixin):
"""
if isinstance(other, int):
return self.derive_expand_mutable(other)
elif isinstance(other, float):
return self.derive_expand_mutable(other)
raise TypeError(f'Unsupported type {type(other)} for mul!')
def __floordiv__(self, other: Union[int, Tuple[int,
@ -143,6 +144,8 @@ class MutableValue(BaseMutable, DerivedMethodMixin):
"""
if isinstance(other, int):
return self.derive_divide_mutable(other)
elif isinstance(other, float):
return self.derive_divide_mutable(int(other))
if isinstance(other, tuple):
assert len(other) == 2
return self.derive_divide_mutable(*other)

View File

@ -3,8 +3,10 @@ from .channel_mutator import (ChannelMutator, OneShotChannelMutator,
SlimmableChannelMutator)
from .module_mutator import (DiffModuleMutator, ModuleMutator,
OneShotModuleMutator)
from .value_mutator import DynamicValueMutator, ValueMutator
__all__ = [
'OneShotModuleMutator', 'DiffModuleMutator', 'ModuleMutator',
'ChannelMutator', 'OneShotChannelMutator', 'SlimmableChannelMutator'
'ChannelMutator', 'OneShotChannelMutator', 'SlimmableChannelMutator',
'ValueMutator', 'DynamicValueMutator'
]

View File

@ -358,4 +358,7 @@ class ChannelMutator(BaseMutator, Generic[ChannelUnitType], GroupMixin):
units = self.unit_class.init_from_predefined_model(model)
for unit in units:
unit.unit_predefined = self.unit_default_args.pop(
'unit_predefined', False)
return units

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import sys
from collections import Counter
from typing import Dict, List, Type
@ -7,6 +7,11 @@ from torch.nn import Module
from ..mutables import BaseMutable
if sys.version_info < (3, 8):
from typing_extensions import Protocol
else:
from typing import Protocol
class GroupMixin():
"""A mixin for :class:`BaseMutator`, which can group mutables by
@ -220,3 +225,49 @@ class GroupMixin():
f'When a mutable is set alias attribute :{alias_key},'
f'the corresponding module name {mutable_name} should '
f'not be used in `custom_group` {custom_group}.')
class MutatorProtocol(Protocol): # pragma: no cover
@property
def mutable_class_type(self) -> Type[BaseMutable]:
...
@property
def search_groups(self) -> Dict:
...
class OneShotSampleMixin:
def sample_choices(self: MutatorProtocol) -> Dict:
random_choices = dict()
for group_id, modules in self.search_groups.items():
random_choices[group_id] = modules[0].sample_choice()
return random_choices
def set_choices(self: MutatorProtocol, choices: Dict) -> None:
for group_id, modules in self.search_groups.items():
choice = choices[group_id]
for module in modules:
module.current_choice = choice
class DynamicSampleMixin(OneShotSampleMixin):
@property
def max_choices(self: MutatorProtocol) -> Dict:
max_choices = dict()
for group_id, modules in self.search_groups.items():
max_choices[group_id] = modules[0].max_choice
return max_choices
@property
def min_choices(self: MutatorProtocol) -> Dict:
min_choices = dict()
for group_id, modules in self.search_groups.items():
min_choices[group_id] = modules[0].min_choice
return min_choices
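# A toy, self-contained sketch of what the sampling mixins above do: each
# search group exposes a list of mutables, one choice is sampled per group
# and then broadcast to every mutable in that group. Everything here is a
# hypothetical stand-in, not the mmrazor classes.
import random

class ToyMutable:
    def __init__(self, choices):
        self.choices = choices
        self.current_choice = choices[0]

    def sample_choice(self):
        return random.choice(self.choices)

search_groups = {0: [ToyMutable([1, 2, 3])],
                 1: [ToyMutable([4, 8]), ToyMutable([4, 8])]}
random_choices = {gid: ms[0].sample_choice() for gid, ms in search_groups.items()}
for gid, ms in search_groups.items():
    for m in ms:
        m.current_choice = random_choices[gid]
print(random_choices)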

View File

@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dynamic_value_mutator import DynamicValueMutator
from .value_mutator import ValueMutator
__all__ = ['ValueMutator', 'DynamicValueMutator']

View File

@ -0,0 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmrazor.models.mutables import OneShotMutableValue
from mmrazor.registry import MODELS
from ..group_mixin import DynamicSampleMixin
from .value_mutator import ValueMutator
@MODELS.register_module()
class DynamicValueMutator(ValueMutator, DynamicSampleMixin):
@property
def mutable_class_type(self):
return OneShotMutableValue

View File

@ -0,0 +1,73 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Type
from torch.nn import Module
from mmrazor.models.mutables import MutableValue
from mmrazor.registry import MODELS
from ..base_mutator import BaseMutator
from ..group_mixin import GroupMixin
@MODELS.register_module()
class ValueMutator(BaseMutator[MutableValue], GroupMixin):
"""The base class for mutable based mutator. All subclass should implement
the following APIS:
- ``mutable_class_type``
Args:
custom_group (list[list[str]], optional): User-defined search groups.
All searchable modules that are not in ``custom_group`` will be
grouped separately.
"""
def __init__(self,
custom_group: Optional[List[List[str]]] = None,
init_cfg: Optional[Dict] = None) -> None:
super().__init__(init_cfg)
if custom_group is None:
custom_group = []
self._custom_group = custom_group
self._search_groups: Optional[Dict[int, List[MutableValue]]] = None
# TODO
# should be a class property
@property
def mutable_class_type(self) -> Type[MutableValue]:
"""Corresponding mutable class type.
Returns:
Type[MUTABLE_TYPE]: Mutable class type.
"""
return MutableValue
def prepare_from_supernet(self, supernet: Module) -> None:
"""Do some necessary preparations with supernet.
Note:
For mutable based mutator, we need to build search group first.
Args:
supernet (:obj:`torch.nn.Module`): The supernet to be searched
in your algorithm.
"""
self._search_groups = self.build_search_groups(supernet,
self.mutable_class_type,
self._custom_group)
@property
def search_groups(self) -> Dict[int, List[MutableValue]]:
"""Search group of supernet.
Note:
For mutable based mutator, the search group is composed of
corresponding mutables.
Raises:
RuntimeError: Called before search group has been built.
Returns:
Dict[int, List[MUTABLE_TYPE]]: Search group.
"""
if self._search_groups is None:
raise RuntimeError(
'Call `prepare_from_supernet` before access search group!')
return self._search_groups

View File

@ -498,6 +498,9 @@ def add_flops_params_counter_variable_or_reset(module):
module.__params__ = 0
counter_warning_list = []
def get_counter_type(module) -> str:
"""Get counter type of the module based on the module class name.
@ -515,10 +518,13 @@ def get_counter_type(module) -> str:
for base_cls in module.__class__.mro():
if base_cls in get_modules_list():
counter_type = base_cls.__name__ + 'Counter'
from mmengine import MMLogger
logger = MMLogger.get_current_instance()
logger.warning(f'`{old_counter_type}` not in op_counters. '
f'Using `{counter_type}` instead.')
global counter_warning_list
if old_counter_type not in counter_warning_list:
from mmengine import MMLogger
logger = MMLogger.get_current_instance()
logger.warning(f'`{old_counter_type}` not in op_counters. '
f'Using `{counter_type}` instead.')
counter_warning_list.append(old_counter_type)
break
return counter_type

View File

@ -1,35 +1,44 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import UserList
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Union
class Candidates(UserList):
"""The data structure of sampled candidate. The format is [(any, float),
(any, float), ...].
"""The data structure of sampled candidate. The format is Union[Dict[str,
Dict], List[Dict[str, Dict]]].
Examples:
>>> candidates = Candidates()
>>> subnet_1 = {'choice_1': 'layer_1', 'choice_2': 'layer_2'}
>>> subnet_1 = {'1': 'choice1', '2': 'choice2'}
>>> candidates.append(subnet_1)
>>> candidates
[({'choice_1': 'layer_1', 'choice_2': 'layer_2'}, 0.0)]
>>> candidates.set_score(0, 0.9)
[{"{'1': 'choice1', '2': 'choice2'}":
{'score': 0.0, 'flops': 0.0, 'params': 0.0, 'latency': 0.0}}]
>>> candidates.set_resources(0, 49.9, 'flops')
>>> candidates.set_score(0, 100.)
>>> candidates
[({'choice_1': 'layer_1', 'choice_2': 'layer_2'}, 0.9)]
[{"{'1': 'choice1', '2': 'choice2'}":
{'score': 100.0, 'flops': 49.9, 'params': 0.0, 'latency': 0.0}}]
>>> subnet_2 = {'choice_3': 'layer_3', 'choice_4': 'layer_4'}
>>> candidates.append((subnet_2, 0.5))
>>> candidates.append(subnet_2)
>>> candidates
[({'choice_1': 'layer_1', 'choice_2': 'layer_2'}, 0.9),
({'choice_3': 'layer_3', 'choice_4': 'layer_4'}, 0.5)]
[{"{'1': 'choice1', '2': 'choice2'}":
{'score': 100.0, 'flops': 49.9, 'params': 0.0, 'latency': 0.0}},
{"{'choice_3': 'layer_3', 'choice_4':'layer_4'}":
{'score': 0.0, 'flops': 0.0, 'params': 0.0, 'latency': 0.0}}]
>>> candidates.subnets
[{'choice_1': 'layer_1', 'choice_2': 'layer_2'},
[{'1': 'choice1', '2': 'choice2'},
{'choice_3': 'layer_3', 'choice_4': 'layer_4'}]
>>> candidates.resources('flops')
[49.9, 0.0]
>>> candidates.scores
[0.9, 0.5]
[100.0, 0.0]
"""
_format_return = Union[Tuple[Any, float], List[Tuple[Any, float]]]
_format_return = Union[Dict[str, Dict], List[Dict[str, Dict]]]
_format_input = Union[Dict, List[Dict], Dict[str, Dict], List[Dict[str,
Dict]]]
_indicators = ('score', 'flops', 'params', 'latency')
def __init__(self, initdata: Optional[Any] = None):
def __init__(self, initdata: Optional[_format_input] = None):
self.data = []
if initdata is not None:
initdata = self._format(initdata)
@ -41,23 +50,59 @@ class Candidates(UserList):
@property
def scores(self) -> List[float]:
"""The scores of candidates."""
return [item[1] for item in self.data]
return [
round(value.get('score', 0.), 2) for item in self.data
for _, value in item.items()
]
def resources(self, key_indicator: str = 'flops') -> List[float]:
"""The resources of candidates."""
assert key_indicator in ['flops', 'params', 'latency']
return [
value.get(key_indicator, 0.) for item in self.data
for _, value in item.items()
]
@property
def subnets(self) -> List[Dict]:
"""The subnets of candidates."""
return [item[0] for item in self.data]
return [eval(key) for item in self.data for key, _ in item.items()]
def _format(self, data: Any) -> _format_return:
"""Transform [any, ...] to [tuple(any, float), ...] Transform any to
tuple(any, float)."""
def _format(self, data: _format_input) -> _format_return:
"""Transform [Dict, ...] to Union[Dict[str, Dict], List[Dict[str,
Dict]]].
def _format_item(item: Any):
"""Transform any to tuple(any, float)."""
if isinstance(item, tuple):
return (item[0], float(item[1]))
Args:
data: Four types of input are supported:
1. Dict: only include network information.
2. List[Dict]: multiple candidates only include network
information.
3. Dict[str, Dict]: network information and the corresponding
resources.
4. List[Dict[str, Dict]]: multiple candidate information.
Returns:
Union[Dict[str, Dict], UserList[Dict[str, Dict]]]:
A dict or a list of dict that contains a pair of network
information and the corresponding Score | FLOPs | Params |
Latency results in each candidate.
Notes:
Score | FLOPs | Params | Latency:
1. a default resource value of -1 indicates that the candidate
    has not been estimated yet.
2. a default resource value of 0 indicates that some indicators
    have already been evaluated.
"""
def _format_item(
cond: Union[Dict, Dict[str, Dict]]) -> Dict[str, Dict]:
"""Transform Dict to Dict[str, Dict]."""
if isinstance(list(cond.values())[0], dict):
for value in list(cond.values()):
for key in list(self._indicators):
value.setdefault(key, 0.)
return cond
else:
return (item, 0.)
return {str(cond): {}.fromkeys(self._indicators, -1)}
if isinstance(data, UserList):
return [_format_item(i) for i in data.data]
@ -68,12 +113,15 @@ class Candidates(UserList):
else:
return _format_item(data)
def append(self, item: Any) -> None:
def append(self, item: _format_input) -> None:
"""Append operation."""
item = self._format(item)
self.data.append(item)
if isinstance(item, list):
self.data = self.data + item
else:
self.data.append(item)
def insert(self, i: int, item: Any) -> None:
def insert(self, i: int, item: _format_input) -> None:
"""Insert operation."""
item = self._format(item)
self.data.insert(i, item)
@ -88,4 +136,35 @@ class Candidates(UserList):
def set_score(self, i: int, score: float) -> None:
"""Set score to the specified subnet by index."""
self.data[i] = (self.data[i][0], float(score))
self.set_resource(i, score, 'score')
def set_resource(self,
i: int,
resources: float,
key_indicator: str = 'flops') -> None:
"""Set resources to the specified subnet by index."""
assert key_indicator in ['score', 'flops', 'params', 'latency']
for _, value in self.data[i].items():
value[key_indicator] = resources
def update_resources(self, resources: list, start: int = 0) -> None:
"""Update resources to the specified candidate."""
end = start + len(resources)
assert len(
self.data) >= end, 'Check the number of candidate resources.'
for i, item in enumerate(self.data[start:end]):
for _, value in item.items():
value.update(resources[i])
def sort_by(self,
key_indicator: str = 'score',
reverse: bool = True) -> None:
"""Sort by a specific indicator in descending order.
Args:
key_indicator (str): sort all candidates by key_indicator.
Defaults to 'score'.
reverse (bool): sort all candidates in descending order.
"""
self.data.sort(
key=lambda x: list(x.values())[0][key_indicator], reverse=reverse)
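Taken together, the methods above replace the old (subnet, score) tuples with a per-subnet dict of indicators. A minimal usage sketch of the new API follows; the import path is an assumption, the calls mirror the implementation above.
# Minimal usage sketch of the updated Candidates API (import path assumed).
from mmrazor.structures import Candidates

candidates = Candidates([{'op': 'conv3x3'}, {'op': 'conv5x5'}])  # two subnets
candidates.set_resource(0, 49.9, 'flops')       # fill a single indicator
candidates.set_resource(0, 72.5, 'score')
candidates.update_resources([{'params': 5.6}, {'latency': 1.2}], start=0)
candidates.sort_by(key_indicator='score', reverse=True)
print(candidates.scores)               # [72.5, -1]; unset indicators stay at -1
print(candidates.resources('flops'))   # [49.9, -1]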


@ -43,13 +43,15 @@ def load_fix_subnet(model: nn.Module,
raise RuntimeError('Root model can not be dynamic op.')
# Avoid circular import
from mmrazor.models.mutables import DerivedMutable
from mmrazor.models.mutables import DerivedMutable, MutableChannelContainer
from mmrazor.models.mutables.base_mutable import BaseMutable
for name, module in model.named_modules():
# The format of `chosen` is different for each type of mutable.
# In the corresponding mutable, it will check whether the `chosen`
# format is correct.
if isinstance(module, (MutableChannelContainer, DerivedMutable)):
continue
if isinstance(module, BaseMutable):
if not module.is_fixed:
if getattr(module, 'alias', None):
@ -61,8 +63,8 @@ def load_fix_subnet(model: nn.Module,
chosen = fix_mutable.get(alias, None)
else:
mutable_name = name.lstrip(prefix)
if mutable_name not in fix_mutable and \
not isinstance(module, DerivedMutable):
if mutable_name not in fix_mutable and not isinstance(
module, (DerivedMutable, MutableChannelContainer)):
raise RuntimeError(
f'The module name {mutable_name} is not in '
'fix_mutable, please check your `fix_mutable`.')
@ -87,13 +89,15 @@ def export_fix_subnet(model: nn.Module,
level=logging.WARNING)
# Avoid circular import
from mmrazor.models.mutables import DerivedMutable
from mmrazor.models.mutables import DerivedMutable, MutableChannelContainer
from mmrazor.models.mutables.base_mutable import BaseMutable
fix_subnet = dict()
for name, module in model.named_modules():
if isinstance(module, BaseMutable):
if isinstance(module, DerivedMutable) and not dump_derived_mutable:
if isinstance(module,
(MutableChannelContainer,
DerivedMutable)) and not dump_derived_mutable:
continue
if module.alias:
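With MutableChannelContainer now skipped alongside DerivedMutable, the export/load round trip stays unchanged for callers. A rough sketch, assuming both helpers are re-exported from mmrazor.structures and that the model already owns its mutables:
# Round-trip sketch for the utilities patched above; the import home and the
# one-argument export call are assumptions based on the signatures in this diff.
import torch.nn as nn
from mmrazor.structures import export_fix_subnet, load_fix_subnet

def freeze_current_subnet(supernet: nn.Module) -> nn.Module:
    """Export the current choices and re-apply them to fix every mutable."""
    fix_subnet = export_fix_subnet(supernet)  # containers/derived mutables skipped
    load_fix_subnet(supernet, fix_subnet)     # remaining mutables become fixed
    return supernet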


@ -3,7 +3,7 @@ from torch.nn import Module
from torch import Tensor
import torch.nn as nn
import torch
from mmrazor.models.architectures.dynamic_ops import DynamicBatchNorm2d, DynamicConv2d, DynamicLinear, DynamicChannelMixin
from mmrazor.models.architectures.dynamic_ops import DynamicBatchNorm2d, DynamicConv2d, DynamicLinear, DynamicChannelMixin, DynamicPatchEmbed, DynamicSequential
from mmrazor.models.mutables.mutable_channel import MutableChannelContainer
from mmrazor.models.mutables import MutableChannelUnit
from mmrazor.models.mutables import DerivedMutable
@ -13,6 +13,9 @@ from mmrazor.registry import MODELS
from mmengine.model import BaseModel
# this file includes models for testing.
from mmrazor.models.mutables import OneShotMutableValue
from mmrazor.models.architectures.backbones.searchable_autoformer import TransformerEncoderLayer
class LinearHead(Module):
@ -475,7 +478,7 @@ class DwConvModel(nn.Module):
def register_mutable(module: DynamicChannelMixin,
mutable: OneShotMutableChannelUnit,
mutable: MutableChannelUnit,
is_out=True,
start=0,
end=-1):
@ -581,6 +584,95 @@ class DynamicLinearModel(nn.Module):
self.linear, mutable2, False)
class DynamicAttention(nn.Module):
"""
x
|blocks: DynamicSequential(depth)
|(blocks)
x1
|fc (OneShotMutableChannel * OneShotMutableValue)
output
"""
def __init__(self) -> None:
super().__init__()
self.mutable_depth = OneShotMutableValue(
value_list=[1, 2], default_value=2)
self.mutable_embed_dims = OneShotMutableChannel(
num_channels=624, candidate_choices=[576, 624])
self.base_embed_dims = OneShotMutableChannel(
num_channels=64, candidate_choices=[64])
self.mutable_num_heads = [
OneShotMutableValue(
value_list=[8, 10],
default_value=10)
for _ in range(2)
]
self.mutable_mlp_ratios = [
OneShotMutableValue(
value_list=[3.0, 3.5, 4.0],
default_value=4.0)
for _ in range(2)
]
self.mutable_q_embed_dims = [
i * self.base_embed_dims for i in self.mutable_num_heads
]
self.patch_embed = DynamicPatchEmbed(
img_size=224,
in_channels=3,
embed_dims=self.mutable_embed_dims.num_channels)
# cls token and pos embed
self.pos_embed = nn.Parameter(
torch.zeros(1, 197,
self.mutable_embed_dims.num_channels))
self.cls_token = nn.Parameter(
torch.zeros(1, 1, self.mutable_embed_dims.num_channels))
layers = []
for i in range(self.mutable_depth.max_choice):
layer = TransformerEncoderLayer(
embed_dims=self.mutable_embed_dims.num_channels,
num_heads=self.mutable_num_heads[i].max_choice,
mlp_ratio=self.mutable_mlp_ratios[i].max_choice)
layers.append(layer)
self.blocks = DynamicSequential(*layers)
# OneShotMutableChannelUnit
OneShotMutableChannelUnit._register_channel_container(
self, MutableChannelContainer)
self.register_mutables()
def register_mutables(self):
# mutablevalue
self.blocks.register_mutable_attr('depth', self.mutable_depth)
# mutablechannel
MutableChannelContainer.register_mutable_channel_to_module(
self.patch_embed, self.mutable_embed_dims, True)
for i in range(self.mutable_depth.max_choice):
layer = self.blocks[i]
layer.register_mutables(
mutable_num_heads=self.mutable_num_heads[i],
mutable_mlp_ratios=self.mutable_mlp_ratios[i],
mutable_q_embed_dims=self.mutable_q_embed_dims[i],
mutable_head_dims=self.base_embed_dims,
mutable_embed_dims=self.mutable_embed_dims)
def forward(self, x: torch.Tensor):
B = x.shape[0]
x = self.patch_embed(x)
embed_dims = self.mutable_embed_dims.current_choice
cls_tokens = self.cls_token[..., :embed_dims].expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed[..., :embed_dims]
x = self.blocks(x)
return torch.mean(x[:, 1:], dim=1)
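A quick, untested sanity sketch of the helper above: with every mutable left at its default (maximum) choice, the mean-pooled output is expected to keep the widest embedding, which the mutator tests further down also assert.
# Sanity sketch for the DynamicAttention test model defined above.
import torch

model = DynamicAttention()
feats = model(torch.rand(2, 3, 224, 224))
print(feats.shape)  # expected torch.Size([2, 624]) with the default choices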
default_models = [
LineModel,
ResBlock,


@ -0,0 +1,116 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest import TestCase
import torch
from mmrazor.models import Autoformer
from mmrazor.registry import MODELS
arch_setting = dict(
mlp_ratios=[3.0, 3.5, 4.0],
num_heads=[8, 9, 10],
depth=[14, 15, 16],
embed_dims=[528, 576, 624])
MUTATOR_CFG = dict(
channel_mutator=dict(
type='mmrazor.OneShotChannelMutator',
channel_unit_cfg={
'type': 'OneShotMutableChannelUnit',
'default_args': {
'unit_predefined': True
}
},
parse_cfg={'type': 'Predefined'}),
value_mutator=dict(type='mmrazor.DynamicValueMutator'))
ARCHITECTURE_CFG = dict(
_scope_='mmrazor',
type='SearchableImageClassifier',
backbone=dict(
_scope_='mmrazor',
type='AutoformerBackbone',
arch_setting=arch_setting),
neck=None,
head=dict(
type='DynamicLinearClsHead',
num_classes=1000,
in_channels=624,
loss=dict(
type='mmcls.LabelSmoothLoss',
mode='original',
num_classes=1000,
label_smooth_val=0.1,
loss_weight=1.0),
topk=(1, 5)),
connect_head=dict(connect_with_backbone='backbone.last_mutable'),
)
ALGORITHM_CFG = dict(
type='mmrazor.Autoformer',
architecture=ARCHITECTURE_CFG,
fix_subnet=None,
mutators=dict(
channel_mutator=dict(
type='mmrazor.OneShotChannelMutator',
channel_unit_cfg={
'type': 'OneShotMutableChannelUnit',
'default_args': {
'unit_predefined': True
}
},
parse_cfg={'type': 'Predefined'}),
value_mutator=dict(type='mmrazor.DynamicValueMutator')))
class TestAUTOFORMER(TestCase):
def test_init(self):
ALGORITHM_CFG_SUPERNET = copy.deepcopy(ALGORITHM_CFG)
# initiate autoformer with built `algorithm`.
autoformer_algo = MODELS.build(ALGORITHM_CFG_SUPERNET)
self.assertIsInstance(autoformer_algo, Autoformer)
# autoformer mutators include channel_mutator and value_mutator
assert 'channel_mutator' in autoformer_algo.mutators
assert 'value_mutator' in autoformer_algo.mutators
# autoformer search_groups
random_subnet = autoformer_algo.sample_subnet()
self.assertIsInstance(random_subnet, dict)
# autoformer_algo supports training
self.assertTrue(autoformer_algo.is_supernet)
# initiate autoformer without any `mutator`.
ALGORITHM_CFG_SUPERNET.pop('type')
ALGORITHM_CFG_SUPERNET['mutators'] = None
with self.assertRaisesRegex(
AssertionError,
'mutator cannot be None when fix_subnet is None.'):
_ = Autoformer(**ALGORITHM_CFG_SUPERNET)
# initiate autoformer with error type `mutator`.
backwardtracer_cfg = dict(
type='OneShotChannelMutator',
channel_unit_cfg=dict(
type='OneShotMutableChannelUnit',
default_args=dict(
candidate_choices=list(i / 12 for i in range(2, 13)),
choice_mode='ratio')),
parse_cfg=dict(
type='BackwardTracer',
loss_calculator=dict(type='ImageClassifierPseudoLoss')))
ALGORITHM_CFG_SUPERNET['mutators'] = dict(
channel_mutator=backwardtracer_cfg,
value_mutator=dict(type='mmrazor.DynamicValueMutator'))
with self.assertRaisesRegex(AssertionError,
'autoformer only support predefined.'):
_ = Autoformer(**ALGORITHM_CFG_SUPERNET)
def test_loss(self):
# supernet
inputs = torch.randn(1, 3, 224, 224)
autoformer = MODELS.build(ALGORITHM_CFG)
loss = autoformer(inputs)
assert loss.size(1) == 1000
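For reference, the same config is consumed outside the unit test through the registry; this is only a usage sketch built from the calls exercised above.
# Usage sketch: build the algorithm from ALGORITHM_CFG and sample a subnet.
import torch
from mmrazor.registry import MODELS

autoformer = MODELS.build(ALGORITHM_CFG)
random_subnet = autoformer.sample_subnet()        # a plain dict of choices
out = autoformer(torch.randn(1, 3, 224, 224))     # (1, 1000), as in test_loss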


@ -0,0 +1,60 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmrazor.models.architectures.dynamic_ops import (
DynamicLinear, DynamicMultiheadAttention, DynamicPatchEmbed,
DynamicRelativePosition2D, DynamicSequential)
from mmrazor.models.mutables import MutableChannelContainer
from mmrazor.registry import MODELS
arch_setting = dict(
mlp_ratios=[3.0, 3.5, 4.0],
num_heads=[8, 9, 10],
depth=[14, 15, 16],
embed_dims=[528, 576, 624])
BACKBONE_CFG = dict(
type='mmrazor.AutoformerBackbone',
arch_setting=arch_setting,
img_size=224,
patch_size=16,
in_channels=3,
norm_cfg=dict(type='mmrazor.DynamicLayerNorm'),
act_cfg=dict(type='GELU'))
def test_searchable_autoformer_mutable() -> None:
backbone = MODELS.build(BACKBONE_CFG)
num_heads = backbone.arch_setting['num_heads']
mlp_ratios = backbone.arch_setting['mlp_ratios']
depth = backbone.arch_setting['depth']
embed_dims = backbone.arch_setting['embed_dims']
embed_dims_expansion = [i * j for i in mlp_ratios for j in embed_dims]
head_expansion = [i * 64 for i in num_heads]
for name, module in backbone.named_modules():
if isinstance(module, DynamicRelativePosition2D):
assert len(module.mutable_head_dims.current_choice) == 64
elif isinstance(module, DynamicMultiheadAttention):
assert len(
module.mutable_embed_dims.current_choice) == max(embed_dims)
assert len(module.mutable_q_embed_dims.current_choice) == max(
head_expansion)
assert module.mutable_num_heads.choices == num_heads
elif isinstance(module, DynamicLinear):
if 'fc1' in name:
assert module.mutable_attrs['in_features'].num_channels == max(
embed_dims)
assert module.mutable_attrs[
'out_features'].num_channels == max(embed_dims_expansion)
elif 'fc2' in name:
assert module.mutable_attrs['in_features'].num_channels == max(
embed_dims_expansion)
assert module.mutable_attrs[
'out_features'].num_channels == max(embed_dims)
elif isinstance(module, DynamicPatchEmbed):
assert type(module.mutable_embed_dims) == MutableChannelContainer
assert len(
module.mutable_embed_dims.current_choice) == max(embed_dims)
elif isinstance(module, DynamicSequential):
assert module.mutable_depth.choices == depth
assert backbone.last_mutable.num_channels == max(embed_dims)
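A compact way to poke at the same backbone outside the test, using only attributes the assertions above rely on:
# Build the searchable backbone from its config and inspect the search space.
from mmrazor.registry import MODELS

backbone = MODELS.build(BACKBONE_CFG)
print(backbone.last_mutable.num_channels)   # 624, the widest embed dim
print(backbone.arch_setting['depth'])       # [14, 15, 16]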


@ -0,0 +1,50 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmrazor.models.architectures.dynamic_ops import DynamicMultiheadAttention
from mmrazor.models.architectures.ops import MultiheadAttention
from mmrazor.models.mutables import (MutableChannelContainer,
OneShotMutableChannel,
OneShotMutableChannelUnit,
OneShotMutableValue)
class TestDynamicMHA(TestCase):
def setUp(self) -> None:
self.mutable_num_heads = OneShotMutableValue(
value_list=[2, 4, 8], default_value=8)
self.mutable_embed_dims = OneShotMutableChannel(num_channels=128)
self.base_embed_dims = OneShotMutableChannel(
num_channels=8, candidate_choices=[8])
self.mutable_q_embed_dims = self.mutable_num_heads * \
self.base_embed_dims
self.dynamic_m = DynamicMultiheadAttention(embed_dims=128, num_heads=8)
OneShotMutableChannelUnit._register_channel_container(
self.dynamic_m, MutableChannelContainer)
self.dynamic_m.register_mutable_attr('num_heads',
self.mutable_num_heads)
MutableChannelContainer.register_mutable_channel_to_module(
self.dynamic_m, self.mutable_embed_dims, False)
MutableChannelContainer.register_mutable_channel_to_module(
self.dynamic_m, self.mutable_q_embed_dims, True, end=64)
MutableChannelContainer.register_mutable_channel_to_module(
self.dynamic_m.rel_pos_embed_k, self.base_embed_dims, False)
MutableChannelContainer.register_mutable_channel_to_module(
self.dynamic_m.rel_pos_embed_v, self.base_embed_dims, False)
def test_forward(self) -> None:
x = torch.randn(8, 197, 128)
output = self.dynamic_m(x)
self.assertIsNotNone(output)
def test_convert(self) -> None:
static_m = MultiheadAttention(embed_dims=100, num_heads=10)
dynamic_m = DynamicMultiheadAttention.convert_from(static_m)
self.assertIsNotNone(dynamic_m)


@ -0,0 +1,46 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import pytest
import torch.nn as nn
from torch.nn import Sequential
from mmrazor.models.architectures.dynamic_ops import DynamicSequential
from mmrazor.models.mutables import OneShotMutableValue
class TestDynamicSequential(TestCase):
def setUp(self) -> None:
self.layers = [
nn.Linear(4, 5),
nn.Linear(5, 6),
nn.Linear(6, 7),
nn.Linear(7, 8),
]
self.dynamic_m = DynamicSequential(*self.layers)
mutable_depth = OneShotMutableValue(
value_list=[2, 3, 4], default_value=3)
self.dynamic_m.register_mutable_attr('depth', mutable_depth)
def test_init(self) -> None:
self.assertEqual(
self.dynamic_m.get_mutable_attr('depth').current_choice, 3)
def test_to_static_op(self) -> None:
with pytest.raises(RuntimeError):
self.dynamic_m.to_static_op()
current_mutable = self.dynamic_m.get_mutable_attr('depth')
current_mutable.fix_chosen(current_mutable.dump_chosen().chosen)
static_op = self.dynamic_m.to_static_op()
self.assertIsNotNone(static_op)
def test_convert_from(self) -> None:
static_m = Sequential(*self.layers)
dynamic_m = DynamicSequential.convert_from(static_m)
self.assertIsNotNone(dynamic_m)
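The fix-then-export pattern in test_to_static_op is shared by the other dynamic-op tests below; spelled out once as a standalone sketch (that the exported module keeps only the chosen number of layers is an assumption based on the op's name):
# Standalone sketch of freezing a DynamicSequential and exporting a static op.
import torch
import torch.nn as nn
from mmrazor.models.architectures.dynamic_ops import DynamicSequential
from mmrazor.models.mutables import OneShotMutableValue

seq = DynamicSequential(*[nn.Linear(8, 8) for _ in range(4)])
depth = OneShotMutableValue(value_list=[2, 3, 4], default_value=3)
seq.register_mutable_attr('depth', depth)
depth.fix_chosen(depth.dump_chosen().chosen)  # freeze the current choice (3)
static_seq = seq.to_static_op()               # static module, no mutables left
out = static_seq(torch.rand(2, 8))            # (2, 8)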


@ -0,0 +1,53 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import pytest
from mmcls.models.utils import PatchEmbed
from mmrazor.models.architectures.dynamic_ops import DynamicPatchEmbed
from mmrazor.models.mutables import SquentialMutableChannel
class TestPatchEmbed(TestCase):
def setUp(self):
self.dynamic_embed = DynamicPatchEmbed(
img_size=224, in_channels=3, embed_dims=100)
mutable_embed_dims = SquentialMutableChannel(num_channels=100)
mutable_embed_dims.current_choice = 50
self.dynamic_embed.register_mutable_attr('embed_dims',
mutable_embed_dims)
def test_patch_embed(self):
mutable = SquentialMutableChannel(num_channels=120)
with pytest.raises(ValueError):
self.dynamic_embed.register_mutable_attr('embed_dims', mutable)
self.assertTrue(
self.dynamic_embed.get_mutable_attr('embed_dims').current_choice ==
50)
def test_convert(self):
static_m = PatchEmbed(img_size=224, in_channels=3, embed_dims=768)
dynamic_m = DynamicPatchEmbed.convert_from(static_m)
self.assertIsNotNone(dynamic_m)
def test_to_static_op(self):
mutable_embed_dims = SquentialMutableChannel(num_channels=100)
mutable_embed_dims.current_choice = 10
with pytest.raises(RuntimeError):
self.dynamic_embed.to_static_op()
mutable_embed_dims.fix_chosen(mutable_embed_dims.dump_chosen().chosen)
self.dynamic_embed.register_mutable_attr('embed_dims',
mutable_embed_dims)
static_op = self.dynamic_embed.to_static_op()
self.assertIsNotNone(static_op)
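Condensed version of the conversion-and-shrink flow exercised above; the chosen width of 384 is arbitrary and only has to be a valid choice of the registered channel:
# Convert a static PatchEmbed, then register a channel mutable and shrink it.
from mmcls.models.utils import PatchEmbed
from mmrazor.models.architectures.dynamic_ops import DynamicPatchEmbed
from mmrazor.models.mutables import SquentialMutableChannel

static_pe = PatchEmbed(img_size=224, in_channels=3, embed_dims=768)
dynamic_pe = DynamicPatchEmbed.convert_from(static_pe)
embed_dims = SquentialMutableChannel(num_channels=768)  # must match embed_dims
embed_dims.current_choice = 384
dynamic_pe.register_mutable_attr('embed_dims', embed_dims)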


@ -0,0 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import pytest
from torch.nn import LayerNorm
from mmrazor.models.architectures.dynamic_ops import DynamicLayerNorm
from mmrazor.models.mutables import SquentialMutableChannel
class TestDynamicLayerNorm(TestCase):
def setUp(self) -> None:
self.dynamic_m = DynamicLayerNorm(100)
mutable_num_features = SquentialMutableChannel(num_channels=100)
mutable_num_features.current_choice = 50
self.dynamic_m.register_mutable_attr('num_features',
mutable_num_features)
def test_init(self) -> None:
mutable = SquentialMutableChannel(num_channels=100)
self.dynamic_m.register_mutable_attr('in_channels', mutable)
self.dynamic_m.register_mutable_attr('out_channels', mutable)
self.assertEqual(
self.dynamic_m.get_mutable_attr('num_features').current_choice, 50)
def test_to_static_op(self):
with pytest.raises(RuntimeError):
self.dynamic_m.to_static_op()
current_mutable = self.dynamic_m.get_mutable_attr('num_features')
current_mutable.fix_chosen(current_mutable.dump_chosen().chosen)
static_op = self.dynamic_m.to_static_op()
self.assertIsNotNone(static_op)
def test_convert(self) -> None:
static_m = LayerNorm(100)
dynamic_m = DynamicLayerNorm.convert_from(static_m)
self.assertIsNotNone(dynamic_m)


@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import pytest
import torch
from mmrazor.models.architectures.dynamic_ops import DynamicRelativePosition2D
from mmrazor.models.architectures.ops import RelativePosition2D
from mmrazor.models.mutables import SquentialMutableChannel
class TestDynamicRP(TestCase):
def setUp(self) -> None:
mutable_head_dims = SquentialMutableChannel(num_channels=8)
self.dynamic_rp = DynamicRelativePosition2D(
head_dims=8, max_relative_position=14)
mutable_head_dims.current_choice = 6
self.dynamic_rp.register_mutable_attr('head_dims', mutable_head_dims)
def test_mutable_attrs(self) -> None:
assert self.dynamic_rp.mutable_head_dims.current_choice == 6
embed = self.dynamic_rp.forward(14, 14)
self.assertIsNotNone(embed)
def test_convert(self):
static_model = RelativePosition2D(
head_dims=10, max_relative_position=14)
dynamic_model = DynamicRelativePosition2D.convert_from(static_model)
self.assertIsNotNone(dynamic_model)
def test_to_static_op(self):
with pytest.raises(RuntimeError):
static_m = self.dynamic_rp.to_static_op()
mutable = SquentialMutableChannel(num_channels=8)
mutable.current_choice = 4
mutable.fix_chosen(mutable.dump_chosen().chosen)
self.dynamic_rp.register_mutable_attr('head_dims', mutable)
static_m = self.dynamic_rp.to_static_op()
self.assertIsNotNone(static_m)
dynamic_output = self.dynamic_rp.forward(14, 14)
static_output = static_m.forward(14, 14)
self.assertTrue(torch.equal(dynamic_output, static_output))


@ -0,0 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from mmrazor.models import SearchableImageClassifier
class TestSearchableImageClassifier(TestCase):
def test_init(self):
arch_setting = dict(
mlp_ratios=[3.0, 3.5, 4.0],
num_heads=[8, 9, 10],
depth=[14, 15, 16],
embed_dims=[528, 576, 624])
supernet_kwargs = dict(
backbone=dict(
_scope_='mmrazor',
type='AutoformerBackbone',
arch_setting=arch_setting),
neck=None,
head=dict(
_scope_='mmrazor',
type='DynamicLinearClsHead',
num_classes=1000,
in_channels=624,
loss=dict(
type='mmcls.LabelSmoothLoss',
mode='original',
num_classes=1000,
label_smooth_val=0.1,
loss_weight=1.0),
topk=(1, 5)),
connect_head=dict(connect_with_backbone='backbone.last_mutable'),
)
supernet = SearchableImageClassifier(**supernet_kwargs)
# test connect_with_backbone
self.assertEqual(
supernet.backbone.last_mutable.activated_channels,
len(
supernet.head.fc.get_mutable_attr(
'in_channels').current_choice))


@ -123,6 +123,27 @@ class TestDerivedMutable(TestCase):
mv.current_choice == 120
assert mv_derived.current_choice == 16
mc_derived = mc // 8.0
assert mc_derived.source_mutables == {mc}
mc.current_choice = 128.
assert mc_derived.current_choice == 16
assert torch.equal(mc_derived.current_mask,
torch.ones(16, dtype=torch.bool))
mc.current_choice = 120.
assert mc_derived.current_choice == 16
assert torch.equal(mc_derived.current_mask,
torch.ones(16, dtype=torch.bool))
mv = OneShotMutableValue(value_list=[112, 120, 128])
mv_derived = mv // 8.0
assert mv_derived.source_mutables == {mv}
mv.current_choice == 128.
assert mv_derived.current_choice == 16
mv.current_choice == 120.
assert mv_derived.current_choice == 16
def test_source_mutables(self) -> None:
def useless_fn(x):
@ -207,6 +228,43 @@ class TestDerivedMutable(TestCase):
derived_e.current_mask,
torch.tensor([1, 0, 1, 1, 1, 1, 0], dtype=torch.bool))
def test_mutable_channel_value_calculation(self) -> None:
mc = SquentialMutableChannel(num_channels=10)
mv = OneShotMutableValue(value_list=[2.0, 2.5, 3.0, 3.5])
derived_mutable = mc * mv
assert derived_mutable.source_mutables == {mv, mc}
mc.current_choice = 6
mv.current_choice = 3.5
assert derived_mutable.current_choice == 21
mc.current_choice = 9
mv.current_choice = 3.5
assert derived_mutable.current_choice == 31
mc.current_choice = 7
mv.current_choice = 2.5
assert derived_mutable.current_choice == 17
assert isinstance(derived_mutable, BaseMutable)
assert isinstance(derived_mutable, DerivedMutable)
assert not derived_mutable.is_fixed
mc.current_choice = mc.num_channels
mv.current_choice = mv.min_choice
assert derived_mutable.current_choice == \
mv.current_choice * mc.num_channels
mv.current_choice = mv.max_choice
assert derived_mutable.current_choice == \
mv.current_choice * mc.current_choice
with pytest.raises(RuntimeError):
derived_mutable.is_fixed = True
mc.fix_chosen(mc.dump_chosen().chosen)
assert not derived_mutable.is_fixed
mv.fix_chosen(mv.dump_chosen().chosen)
assert derived_mutable.is_fixed
@pytest.mark.parametrize('expand_ratio', [1, 2, 3])
def test_derived_expand_mutable(expand_ratio: int) -> None:
@ -232,3 +290,29 @@ def test_derived_expand_mutable(expand_ratio: int) -> None:
mv.current_choice = 5
assert mv_derived.current_choice == 5 * expand_ratio
@pytest.mark.parametrize('expand_ratio', [1.5, 2.0, 2.5])
def test_derived_expand_mutable_float(expand_ratio: float) -> None:
mv = OneShotMutableValue(value_list=[3, 5, 7])
mv_derived = mv * expand_ratio
assert mv_derived.source_mutables == {mv}
assert isinstance(mv_derived, BaseMutable)
assert isinstance(mv_derived, DerivedMutable)
assert not mv_derived.is_fixed
assert mv_derived.num_choices == 1
mv.current_choice = mv.max_choice
assert mv_derived.current_choice == int(mv.current_choice * expand_ratio)
mv.current_choice = mv.min_choice
assert mv_derived.current_choice == int(mv.current_choice * expand_ratio)
with pytest.raises(RuntimeError):
mv_derived.current_choice = 123
with pytest.raises(RuntimeError):
_ = mv_derived.current_mask
mv.current_choice = 5
assert mv_derived.current_choice == int(5 * expand_ratio)
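The two new tests pin down the truncating int() semantics of derived products and float ratios; one more concrete combination, written out as a worked example with the same classes this file imports:
# Worked example of the derived-mutable arithmetic covered by the tests above.
from mmrazor.models.mutables import OneShotMutableValue, SquentialMutableChannel

mc = SquentialMutableChannel(num_channels=10)
mv = OneShotMutableValue(value_list=[2.0, 2.5, 3.0, 3.5])
hidden = mc * mv                      # e.g. MLP width = channels * mlp ratio
mc.current_choice = 8
mv.current_choice = 2.5
assert hidden.current_choice == 20    # int(8 * 2.5)

mv2 = OneShotMutableValue(value_list=[3, 5, 7])
scaled = mv2 * 2.5
mv2.current_choice = 7
assert scaled.current_choice == 17    # int(7 * 2.5)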


@ -3,7 +3,8 @@ from unittest import TestCase
import torch
from mmrazor.models.mutables import SquentialMutableChannel
from mmrazor.models.mutables import (OneShotMutableValue,
SquentialMutableChannel)
class TestSquentialMutableChannel(TestCase):
@ -41,3 +42,16 @@ class TestSquentialMutableChannel(TestCase):
channel = SquentialMutableChannel(10, choice_mode='ratio')
self._test_mutable(channel, 0.5, 0.5, 5, self._generate_mask(5, 10))
self._test_mutable(channel, 2, 0.2, 2, self._generate_mask(2, 10))
def test_mutable_channel_mul(self):
channel = SquentialMutableChannel(2)
self.assertEqual(channel.current_choice, 2)
mv = OneShotMutableValue(value_list=[1, 2, 3], default_value=3)
derived1 = channel * mv
derived2 = mv * channel
assert derived1.current_choice == 6
assert derived2.current_choice == 6
mv.current_choice = mv.min_choice
assert derived1.current_choice == 2
assert derived2.current_choice == 2
assert torch.equal(derived1.current_mask, derived2.current_mask)


@ -2,6 +2,8 @@
from unittest import TestCase
from mmrazor.models.mutables import OneShotMutableChannelUnit
from mmrazor.models.mutators.channel_mutator import ChannelMutator
from .....data.models import DynamicAttention
class TestSequentialMutableChannelUnit(TestCase):
@ -14,3 +16,20 @@ class TestSequentialMutableChannelUnit(TestCase):
unit = OneShotMutableChannelUnit(
48, [0.3, 0.5, 0.7], choice_mode='ratio', divisor=8)
self.assertSequenceEqual(unit.candidate_choices, [1 / 3, 0.5, 2 / 3])
def test_unit_predefined(self):
model = DynamicAttention()
mutator = ChannelMutator(
channel_unit_cfg={
'type': 'OneShotMutableChannelUnit',
'default_args': {
'unit_predefined': False
}
},
parse_cfg={'type': 'Predefined'})
mutator.prepare_from_supernet(model)
choices = mutator.sample_choices()
mutator.set_choices(choices)
self.assertSequenceEqual(mutator.units[0].candidate_choices,
[576, 624])
self.assertSequenceEqual(mutator.units[1].candidate_choices, [64])


@ -73,9 +73,6 @@ class TestMutableValue(TestCase):
assert mul_derived_mv.current_choice == 4
assert rmul_derived_mv.current_choice == 4
with pytest.raises(TypeError):
_ = mv * 1.2
mv = MutableValue(value_list=[1, 2, 3], default_value=3)
mc = SquentialMutableChannel(num_channels=4)
@ -114,9 +111,6 @@ class TestMutableValue(TestCase):
mv.current_choice = 136
assert derived_mv.current_choice == 18
with pytest.raises(TypeError):
_ = mv // 1.2
def test_repr(self) -> None:
value_list = [2, 4, 6]
mv = MutableValue(value_list=value_list)


@ -10,7 +10,7 @@ from mmrazor.models.mutables.mutable_channel import (
L1MutableChannelUnit, SequentialMutableChannelUnit)
from mmrazor.models.mutators.channel_mutator import ChannelMutator
from mmrazor.registry import MODELS
from ...data.models import DynamicLinearModel
from ...data.models import DynamicAttention, DynamicLinearModel
from ...test_core.test_graph.test_graph import TestGraph
sys.setrecursionlimit(2000)
@ -135,6 +135,30 @@ class TestChannelMutator(unittest.TestCase):
mutator.prepare_from_supernet(model)
self._test_a_mutator(mutator, model)
def test_models_with_predefined_dynamic_op_without_pruning(self):
for Model in [
DynamicAttention,
]:
with self.subTest(model=Model):
model = Model()
mutator = ChannelMutator(
channel_unit_cfg={
'type': 'OneShotMutableChannelUnit',
'default_args': {
'unit_predefined': True
}
},
parse_cfg={'type': 'Predefined'})
mutator.prepare_from_supernet(model)
choices = mutator.sample_choices()
mutator.set_choices(choices)
self.assertGreater(len(mutator.mutable_units), 0)
x = torch.rand([2, 3, 224, 224])
y = model(x)
self.assertEqual(
list(y.shape),
[2, list(mutator.current_choices.values())[0]])
def test_custom_group(self):
ARCHITECTURE_CFG = dict(
type='mmcls.ImageClassifier',


@ -0,0 +1,33 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import torch
from mmrazor.models.mutables import MutableValue
from mmrazor.models.mutators import DynamicValueMutator
from ...data.models import DynamicAttention
class TestValueMutator(unittest.TestCase):
def test_models_with_predefined_dynamic_op(self):
for Model in [
DynamicAttention,
]:
with self.subTest(model=Model):
model = Model()
value_mutator = DynamicValueMutator()
value_mutator.prepare_from_supernet(model)
value_choices = value_mutator.sample_choices()
value_mutator.set_choices(value_choices)
mutable_value_space = []
for mutable_value, module in model.named_modules():
if isinstance(module, MutableValue):
mutable_value_space.append(mutable_value)
assert len(
value_mutator.search_groups) == len(mutable_value_space)
x = torch.rand([2, 3, 224, 224])
y = model(x)
self.assertEqual(list(y.shape), [2, 624])
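In the Autoformer algorithm, the value mutator above is paired with a channel mutator on the same model (see ALGORITHM_CFG earlier in this PR); wiring the two by hand looks roughly like the sketch below. The manual order of preparation is an assumption.
# Sketch: drive DynamicAttention with both mutators, as the Autoformer config does.
import torch
from mmrazor.models.mutators import DynamicValueMutator
from mmrazor.models.mutators.channel_mutator import ChannelMutator

model = DynamicAttention()

value_mutator = DynamicValueMutator()
value_mutator.prepare_from_supernet(model)
value_mutator.set_choices(value_mutator.sample_choices())

channel_mutator = ChannelMutator(
    channel_unit_cfg={
        'type': 'OneShotMutableChannelUnit',
        'default_args': {'unit_predefined': True}
    },
    parse_cfg={'type': 'Predefined'})
channel_mutator.prepare_from_supernet(model)
channel_mutator.set_choices(channel_mutator.sample_choices())

y = model(torch.rand(2, 3, 224, 224))  # last dim equals the sampled embed dim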


@ -10,7 +10,27 @@ class TestCandidates(TestCase):
def setUp(self) -> None:
self.fake_subnet = {'1': 'choice1', '2': 'choice2'}
self.fake_subnet_with_score = (self.fake_subnet, 1.)
self.fake_subnet_with_resource = {
str(self.fake_subnet): {
'score': 0.,
'flops': 50.,
'params': 0.,
'latency': 0.
}
}
self.fake_subnet_with_score = {
str(self.fake_subnet): {
'score': 99.,
'flops': 0.,
'params': 0.,
'latency': 0.
}
}
self.has_flops_network = {
str(self.fake_subnet): {
'flops': 50.,
}
}
def test_init(self):
# initlist is None
@ -23,16 +43,25 @@ class TestCandidates(TestCase):
# initlist is UserList
data = UserList([self.fake_subnet] * 2)
self.assertEqual(len(candidates.data), 2)
self.assertEqual(candidates.resources('flops'), [-1, -1])
# initlist is list(Dict[str, Dict])
candidates = Candidates([self.has_flops_network] * 2)
self.assertEqual(candidates.resources('flops'), [50., 50.])
def test_scores(self):
# test property: scores
data = [self.fake_subnet_with_score] * 2
candidates = Candidates(data)
self.assertEqual(candidates.scores, [1., 1.])
self.assertEqual(candidates.scores, [99., 99.])
def test_resources(self):
data = [self.fake_subnet_with_resource] * 2
candidates = Candidates(data)
self.assertEqual(candidates.resources('flops'), [50., 50.])
def test_subnets(self):
# test property: subnets
data = [self.fake_subnet_with_score] * 2
data = [self.fake_subnet] * 2
candidates = Candidates(data)
self.assertEqual(candidates.subnets, [self.fake_subnet] * 2)
@ -41,17 +70,20 @@ class TestCandidates(TestCase):
candidates = Candidates()
candidates.append(self.fake_subnet)
self.assertEqual(len(candidates), 1)
# item is tuple
# item is List
candidates = Candidates()
candidates.append(self.fake_subnet_with_score)
self.assertEqual(len(candidates), 1)
candidates.append([self.fake_subnet_with_score])
# item is Candidates
candidates_2 = Candidates([self.fake_subnet_with_resource])
candidates.append(candidates_2)
self.assertEqual(len(candidates), 2)
def test_insert(self):
# item is dict
candidates = Candidates([self.fake_subnet_with_score])
candidates = Candidates(self.fake_subnet_with_score)
candidates.insert(1, self.fake_subnet)
self.assertEqual(len(candidates), 2)
# item is tuple
# item is List
candidates = Candidates([self.fake_subnet_with_score])
candidates.insert(1, self.fake_subnet_with_score)
self.assertEqual(len(candidates), 2)
@ -61,13 +93,60 @@ class TestCandidates(TestCase):
candidates = Candidates([self.fake_subnet_with_score])
candidates.extend([self.fake_subnet])
self.assertEqual(len(candidates), 2)
# other is UserList
# other is Candidates
candidates = Candidates([self.fake_subnet_with_score])
candidates.extend(UserList([self.fake_subnet_with_score]))
candidates_2 = Candidates([self.fake_subnet_with_resource])
candidates.extend(candidates_2)
self.assertEqual(len(candidates), 2)
def test_set_score(self):
# test set_score
def test_set_resource(self):
# test set_resource
candidates = Candidates([self.fake_subnet])
for kk in ['flops', 'params', 'latency']:
self.assertEqual(candidates.resources(kk)[0], -1)
candidates.set_resource(0, 49.9, kk)
self.assertEqual(candidates.resources(kk)[0], 49.9)
candidates.insert(0, self.fake_subnet_with_resource)
self.assertEqual(len(candidates), 2)
self.assertEqual(candidates.resources('flops'), [50., 49.9])
self.assertEqual(candidates.resources('latency'), [0., 49.9])
candidates = Candidates([self.fake_subnet_with_score])
candidates.set_score(0, 0.5)
self.assertEqual(candidates[0][1], 0.5)
candidates.set_resource(0, 100.0, 'score')
self.assertEqual(candidates.scores[0], 100.)
candidates = Candidates([self.fake_subnet_with_score])
candidates.set_resource(0, 100.0, 'score')
candidates.extend(UserList([self.fake_subnet_with_resource]))
candidates.set_resource(1, 99.9, 'score')
self.assertEqual(candidates.scores, [100., 99.9])
def test_update_resources(self):
# test update_resources
candidates = Candidates([self.fake_subnet])
candidates.append([self.fake_subnet_with_score])
candidates_2 = Candidates(self.fake_subnet_with_resource)
candidates.append(candidates_2)
self.assertEqual(len(candidates), 3)
self.assertEqual(candidates.resources('flops'), [-1, 0., 50.])
self.assertEqual(candidates.resources('latency'), [-1, 0., 0.])
resources = [{'flops': -2}, {'latency': 4.}]
candidates.update_resources(resources, start=1)
self.assertEqual(candidates.resources('flops'), [-1, -2, 50.])
self.assertEqual(candidates.resources('latency'), [-1, 0., 4])
candidates.update_resources(resources, start=0)
self.assertEqual(candidates.resources('flops'), [-2, -2, 50.])
self.assertEqual(candidates.resources('latency'), [-1, 4., 4.])
def test_sort(self):
# test set_sort
candidates = Candidates([self.fake_subnet_with_score])
candidates.extend(UserList([self.fake_subnet_with_resource]))
candidates.insert(0, self.fake_subnet)
candidates.set_resource(0, 100., 'score')
candidates.set_resource(2, 98., 'score')
self.assertEqual(candidates.scores, [100., 99., 98.])
candidates.sort_by(key_indicator='score', reverse=False)
self.assertEqual(candidates.scores, [98., 99., 100.])
candidates.sort_by(key_indicator='latency')
self.assertEqual(candidates.scores, [98., 99., 100.])
candidates.sort_by(key_indicator='flops', reverse=False)
self.assertEqual(candidates.scores, [100., 99., 98.])


@ -82,7 +82,7 @@ class TestEvolutionSearchLoop(TestCase):
num_mutation=2,
num_crossover=2,
mutate_prob=0.1,
flops_range=None,
constraints_range=dict(flops=(0, 330)),
score_key='coco/bbox_mAP')
self.train_cfg = Config(train_cfg)
self.runner = MagicMock(spec=ToyRunner)
@ -103,7 +103,7 @@ class TestEvolutionSearchLoop(TestCase):
# test init_candidates is not None
fake_subnet = {'1': 'choice1', '2': 'choice2'}
fake_candidates = Candidates((fake_subnet, 0.))
fake_candidates = Candidates(fake_subnet)
init_candidates_path = os.path.join(self.temp_dir, 'candidates.yaml')
fileio.dump(fake_candidates, init_candidates_path)
loop_cfg.init_candidates = init_candidates_path
@ -111,8 +111,12 @@ class TestEvolutionSearchLoop(TestCase):
self.assertIsInstance(loop, EvolutionSearchLoop)
self.assertEqual(loop.candidates, fake_candidates)
@patch('mmrazor.engine.runner.evolution_search_loop.export_fix_subnet')
def test_run_epoch(self, mock_export_fix_subnet):
@patch('mmrazor.engine.runner.utils.check.load_fix_subnet')
@patch('mmrazor.engine.runner.utils.check.export_fix_subnet')
@patch('mmrazor.models.task_modules.estimators.resource_estimator.'
'get_model_flops_params')
def test_run_epoch(self, flops_params, mock_export_fix_subnet,
load_status):
# test_run_epoch: distributed == False
loop_cfg = copy.deepcopy(self.train_cfg)
loop_cfg.runner = self.runner
@ -120,20 +124,20 @@ class TestEvolutionSearchLoop(TestCase):
loop_cfg.evaluator = self.evaluator
loop = LOOPS.build(loop_cfg)
self.runner.rank = 0
loop._epoch = 1
self.runner.distributed = False
self.runner.work_dir = self.temp_dir
fake_subnet = {'1': 'choice1', '2': 'choice2'}
self.runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
loop.model.sample_subnet = MagicMock(return_value=fake_subnet)
load_status.return_value = True
flops_params.return_value = 0, 0
loop.run_epoch()
self.assertEqual(len(loop.candidates), 4)
self.assertEqual(len(loop.top_k_candidates), 2)
self.assertEqual(loop._epoch, 2)
self.assertEqual(loop._epoch, 1)
# test_run_epoch: distributed == True
loop = LOOPS.build(loop_cfg)
self.runner.rank = 0
loop._epoch = 1
self.runner.distributed = True
self.runner.work_dir = self.temp_dir
fake_subnet = {'1': 'choice1', '2': 'choice2'}
@ -141,26 +145,27 @@ class TestEvolutionSearchLoop(TestCase):
loop.run_epoch()
self.assertEqual(len(loop.candidates), 4)
self.assertEqual(len(loop.top_k_candidates), 2)
self.assertEqual(loop._epoch, 2)
self.assertEqual(loop._epoch, 1)
# test_check_constraints
loop_cfg.flops_range = (0, 100)
loop_cfg.constraints_range = dict(params=(0, 100))
loop = LOOPS.build(loop_cfg)
self.runner.rank = 0
loop._epoch = 1
self.runner.distributed = True
self.runner.work_dir = self.temp_dir
fake_subnet = {'1': 'choice1', '2': 'choice2'}
loop.model.sample_subnet = MagicMock(return_value=fake_subnet)
loop._check_constraints = MagicMock(return_value=True)
flops_params.return_value = (50., 1)
mock_export_fix_subnet.return_value = fake_subnet
loop.run_epoch()
self.assertEqual(len(loop.candidates), 4)
self.assertEqual(len(loop.top_k_candidates), 2)
self.assertEqual(loop._epoch, 2)
self.assertEqual(loop._epoch, 1)
@patch('mmrazor.engine.runner.evolution_search_loop.export_fix_subnet')
def test_run(self, mock_export_fix_subnet):
@patch('mmrazor.engine.runner.utils.check.export_fix_subnet')
@patch('mmrazor.models.task_modules.estimators.resource_estimator.'
'get_model_flops_params')
def test_run_loop(self, mock_flops, mock_export_fix_subnet):
# test a new search: resume == None
loop_cfg = copy.deepcopy(self.train_cfg)
loop_cfg.runner = self.runner
@ -169,16 +174,26 @@ class TestEvolutionSearchLoop(TestCase):
loop = LOOPS.build(loop_cfg)
self.runner.rank = 0
loop._epoch = 1
fake_subnet = {'1': 'choice1', '2': 'choice2'}
self.runner.work_dir = self.temp_dir
loop.update_candidate_pool = MagicMock()
loop.val_candidate_pool = MagicMock()
mutation_candidates = Candidates([fake_subnet] * loop.num_mutation)
for i in range(loop.num_mutation):
mutation_candidates.set_resource(i, 0.1 + 0.1 * i, 'flops')
mutation_candidates.set_resource(i, 99 + i, 'score')
crossover_candidates = Candidates([fake_subnet] * loop.num_crossover)
for i in range(loop.num_crossover):
crossover_candidates.set_resource(i, 0.1 + 0.1 * i, 'flops')
crossover_candidates.set_resource(i, 99 + i, 'score')
loop.gen_mutation_candidates = \
MagicMock(return_value=[fake_subnet]*loop.num_mutation)
MagicMock(return_value=mutation_candidates)
loop.gen_crossover_candidates = \
MagicMock(return_value=[fake_subnet]*loop.num_crossover)
loop.top_k_candidates = Candidates([(fake_subnet, 1.0),
(fake_subnet, 0.9)])
MagicMock(return_value=crossover_candidates)
loop.candidates = Candidates([fake_subnet] * 4)
mock_flops.return_value = (0.5, 101)
mock_export_fix_subnet.return_value = fake_subnet
loop.run()
assert os.path.exists(


@ -119,7 +119,7 @@ class TestGreedySamplerTrainLoop(TestCase):
max_iters=12,
val_interval=2,
score_key='acc',
flops_range=None,
constraints_range=None,
num_candidates=4,
num_samples=2,
top_k=2,
@ -190,7 +190,7 @@ class TestGreedySamplerTrainLoop(TestCase):
loop._iter = loop.val_interval
subnet = loop.sample_subnet()
self.assertEqual(subnet, fake_subnet)
self.assertEqual(len(loop.top_k_candidates), loop.top_k - 1)
self.assertEqual(len(loop.top_k_candidates), loop.top_k)
def test_run(self):
# test run with _check_constraints
@ -200,7 +200,7 @@ class TestGreedySamplerTrainLoop(TestCase):
fake_subnet = {'1': 'choice1', '2': 'choice2'}
runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
loop = runner.build_train_loop(cfg.train_cfg)
loop._check_constraints = MagicMock(return_value=True)
loop._check_constraints = MagicMock(return_value=(True, dict()))
runner.train()
self.assertEqual(runner.iter, runner.max_iters)


@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import patch
from mmrazor.engine.runner.utils import check_subnet_flops
from mmrazor.engine.runner.utils import check_subnet_resources
try:
from mmdet.models.detectors import BaseDetector
@ -12,29 +12,33 @@ except ImportError:
@patch('mmrazor.models.ResourceEstimator')
@patch('mmrazor.models.SPOS')
def test_check_subnet_flops(mock_model, mock_estimator):
# flops_range = None
flops_range = None
def test_check_subnet_resources(mock_model, mock_estimator):
# constraints_range = dict()
constraints_range = dict()
fake_subnet = {'1': 'choice1', '2': 'choice2'}
result = check_subnet_flops(mock_model, fake_subnet, mock_estimator,
flops_range)
assert result is True
is_pass, _ = check_subnet_resources(mock_model, fake_subnet,
mock_estimator, constraints_range)
assert is_pass is True
# flops_range is not None
# constraints_range is not None
# architecture is BaseDetector
flops_range = (0., 100.)
constraints_range = dict(flops=(0, 330))
mock_model.architecture = BaseDetector
fake_results = {'flops': 50.}
mock_estimator.estimate.return_value = fake_results
result = check_subnet_flops(mock_model, fake_subnet, mock_estimator,
flops_range)
assert result is True
is_pass, _ = check_subnet_resources(
mock_model,
fake_subnet,
mock_estimator,
constraints_range,
)
assert is_pass is True
# flops_range is not None
# constraints_range is not None
# architecture is BaseDetector
flops_range = (0., 100.)
constraints_range = dict(flops=(0, 330))
fake_results = {'flops': -50.}
mock_estimator.estimate.return_value = fake_results
result = check_subnet_flops(mock_model, fake_subnet, mock_estimator,
flops_range)
assert result is False
is_pass, _ = check_subnet_resources(mock_model, fake_subnet,
mock_estimator, constraints_range)
assert is_pass is False
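For completeness, the renamed helper is only called positionally in these tests; a small wrapper shows the intended shape of a real call site. The (0, 330) FLOPs budget and the estimator argument are placeholders.
# Wrapper sketch around the renamed helper; arguments mirror the calls above.
def keep_if_within_budget(algorithm, subnet, estimator):
    """Return the subnet only if it satisfies a (0, 330) FLOPs constraint."""
    is_pass, _ = check_subnet_resources(
        algorithm, subnet, estimator, dict(flops=(0., 330.)))
    return subnet if is_pass else None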