[Feature] Add Autoformer algorithm (#315)
* update candidates
* update subnet_sampler_loop
* update candidate
* add readme
* rename variable
* rename variable
* clean
* update
* add doc string
* Revert "[Improvement] Support for candidate multiple dimensional search constraints."
* [Improvement] Update Candidate with multi-dim search constraints. (#322)
* update doc
* add support type
* clean code
* update candidates
* clean
* xx
* set_resource -> set_score
* fix ci bug
* py36 lint
* fix bug
* fix check constrain
* py36 ci
* redesign candidate
* fix pre-commit
* update cfg
* add build_resource_estimator
* fix ci bug
* remove runner.epoch in testcase
* [Feature] Autoformer architecture and dynamicOPs (#327)
* add DynamicSequential
* dynamiclayernorm
* add dynamic_pathchembed
* add DynamicMultiheadAttention and DynamicRelativePosition2D
* add channel-level dynamicOP
* add autoformer algo
* clean notes
* adapt channel_mutator
* vit fly
* fix import
* mutable init
* remove annotation
* add DynamicInputResizer
* add unittest for mutables
* add OneShotMutableChannelUnit_VIT
* clean code
* reset unit for vit
* remove attr
* add autoformer backbone UT
* add valuemutator UT
* clean code
* add autoformer algo UT
* update classifier UT
* fix test error
* ignore
* make lint
* update
* fix lint
* mutable_attrs
* fix test
* fix error
* remove DynamicInputResizer
* fix test ci
* remove InputResizer
* rename variables
* modify type
* Continued improvements of ChannelUnit
* fix lint
* fix lint
* remove OneShotMutableChannelUnit
* adjust derived type
* combination mixins
* clean code
* fix sample subnet
* search loop fly
* more annotations
* avoid counter warning and modify batch_augment cfg by gy
* restore
* source_value_mutables restriction
* simply arch_setting api
* update
* clean
* fix ut
parent 9c567e4d40
commit fb42405af8
configs
mmrazor
engine/runner
models
algorithms
architectures
backbones
classifiers
mutables
mutators
channel_mutator
value_mutator
task_modules/estimators/counters
structures/subnet
tests
data
test_models
test_algorithms
test_architectures
test_backbones
test_classifier
test_mutables
test_mutable_channel
test_mutators
test_subnet
test_runners
@@ -0,0 +1,180 @@
# dataset settings
dataset_type = 'mmcls.ImageNet'
preprocess_cfg = dict(
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)

bgr_mean = preprocess_cfg['mean'][::-1]
bgr_std = preprocess_cfg['std'][::-1]

# Refers to `_RAND_INCREASING_TRANSFORMS` in pytorch-image-models
rand_increasing_policies = [
    dict(type='mmcls.AutoContrast'),
    dict(type='mmcls.Equalize'),
    dict(type='mmcls.Invert'),
    dict(type='mmcls.Rotate', magnitude_key='angle', magnitude_range=(0, 30)),
    dict(type='mmcls.Posterize', magnitude_key='bits', magnitude_range=(4, 0)),
    dict(type='mmcls.Solarize', magnitude_key='thr', magnitude_range=(256, 0)),
    dict(
        type='mmcls.SolarizeAdd',
        magnitude_key='magnitude',
        magnitude_range=(0, 110)),
    dict(
        type='mmcls.ColorTransform',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.9)),
    dict(
        type='mmcls.Contrast',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.9)),
    dict(
        type='mmcls.Brightness',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.9)),
    dict(
        type='mmcls.Sharpness',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.9)),
    dict(
        type='mmcls.Shear',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.3),
        direction='horizontal'),
    dict(
        type='mmcls.Shear',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.3),
        direction='vertical'),
    dict(
        type='mmcls.Translate',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.45),
        direction='horizontal'),
    dict(
        type='mmcls.Translate',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.45),
        direction='vertical')
]

train_pipeline = [
    dict(type='mmcls.LoadImageFromFile'),
    dict(
        type='mmcls.RandomResizedCrop',
        scale=224,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='mmcls.RandomFlip', prob=0.5, direction='horizontal'),
    dict(
        type='mmcls.RandAugment',
        policies=rand_increasing_policies,
        num_policies=2,
        total_level=10,
        magnitude_level=9,
        magnitude_std=0.5,
        hparams=dict(
            pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
    dict(
        type='mmcls.RandomErasing',
        erase_prob=0.25,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=1 / 3,
        fill_color=bgr_mean,
        fill_std=bgr_std),
    dict(type='mmcls.PackClsInputs'),
]

test_pipeline = [
    dict(type='mmcls.LoadImageFromFile'),
    dict(
        type='mmcls.ResizeEdge',
        scale=248,
        edge='short',
        backend='pillow',
        interpolation='bicubic'),
    dict(type='mmcls.CenterCrop', crop_size=224),
    dict(type='mmcls.PackClsInputs')
]

train_dataloader = dict(
    batch_size=64,
    num_workers=6,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/train.txt',
        data_prefix='train',
        pipeline=train_pipeline),
    sampler=dict(type='mmcls.RepeatAugSampler'),
    persistent_workers=True,
)

val_dataloader = dict(
    batch_size=256,
    num_workers=6,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/val.txt',
        data_prefix='val',
        pipeline=test_pipeline),
    sampler=dict(type='mmcls.DefaultSampler', shuffle=False),
    persistent_workers=True,
)
val_evaluator = dict(type='mmcls.Accuracy', topk=(1, 5))

# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator

# optimizer
paramwise_cfg = dict(
    bias_decay_mult=0.0, norm_decay_mult=0.0, dwconv_decay_mult=0.0)

optim_wrapper = dict(
    optimizer=dict(
        type='AdamW',
        lr=0.002,
        weight_decay=0.05,
        eps=1e-8,
        betas=(0.9, 0.999)),
    # specific to vit pretrain
    paramwise_cfg=dict(custom_keys={
        '.cls_token': dict(decay_mult=0.0),
        '.pos_embed': dict(decay_mult=0.0)
    }))

# learning policy
param_scheduler = [
    # warm up learning rate scheduler
    dict(
        type='LinearLR',
        start_factor=1e-3,
        by_epoch=True,
        begin=0,
        # about 10000 iterations for ImageNet-1k
        end=20,
        # update by iter
        convert_to_iter_based=True),
    # main learning rate scheduler
    dict(
        type='CosineAnnealingLR',
        T_max=500,
        eta_min=1e-5,
        by_epoch=True,
        begin=20,
        end=500,
        convert_to_iter_based=True),
]

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=500)
val_cfg = dict()
test_cfg = dict()

auto_scale_lr = dict(base_batch_size=2048)
@@ -0,0 +1,66 @@
# AutoFormer

> [AutoFormer: Searching Transformers for Visual Recognition](https://arxiv.org/abs/2107.00651)

<!-- [ALGORITHM] -->

## Abstract

Recently, pure transformer-based models have shown great potentials for vision tasks such as image classification and detection. However, the design of transformer networks is challenging. It has been observed that the depth, embedding dimension, and number of heads can largely affect the performance of vision transformers. Previous models configure these dimensions based upon manual crafting. In this work, we propose a new one-shot architecture search framework, namely AutoFormer, dedicated to vision transformer search. AutoFormer entangles the weights of different blocks in the same layers during supernet training. Benefiting from the strategy, the trained supernet allows thousands of subnets to be very well-trained. Specifically, the performance of these subnets with weights inherited from the supernet is comparable to those retrained from scratch. Besides, the searched models, which we refer to as AutoFormers, surpass the recent state-of-the-arts such as ViT and DeiT. In particular, AutoFormer-tiny/small/base achieve 74.7%/81.7%/82.4% top-1 accuracy on ImageNet with 5.7M/22.9M/53.7M parameters, respectively. Lastly, we verify the transferability of AutoFormer by providing the performance on downstream benchmarks and distillation experiments.

![pipeline](/docs/en/imgs/model_zoo/autoformer/pipeline.png)

## Introduction

### Supernet pre-training on ImageNet

```bash
python ./tools/train.py \
  configs/nas/mmcls/autoformer/autoformer_supernet_32xb256_in1k.py \
  --work-dir $WORK_DIR
```

### Search for a subnet on the trained supernet

```bash
python ./tools/train.py \
  configs/nas/mmcls/autoformer/autoformer_search_8xb128_in1k.py \
  --cfg-options load_from=$STEP1_CKPT \
  --work-dir $WORK_DIR
```

## Results and models

| Dataset  | Supernet | Subnet | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | Remarks |
| :------: | :------: | :----: | :-------: | :------: | :-------: | :-------: | :----: | :------: | :-----: |
| ImageNet | vit | [mutable](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/spos/spos_shufflenetv2_subnet_8xb128_in1k/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-454627be_mutable_cfg.yaml?versionId=CAEQHxiBgICw5b6I7xciIGY5MjVmNWFhY2U5MjQzN2M4NDViYzI2YWRmYWE1YzQx) | 52.472 | 10.2 | 82.48 | 95.99 | [config](./autoformer_supernet_32xb256_in1k.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/x.pth) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/spos/x.log.json) | MMRazor searched |

**Note**:

1. There are some small differences between our experiment and the original repo, made to stay consistent with the mmrazor repo. For example, we set the maximum value of `embed_channels` to 624, while the original repo sets it to 640. However, the original repo only searches over the `embed_channels` values 528, 576 and 624, so setting 624 still reproduces the result of the original paper.
2. The original paper reports 82.4% top-1 accuracy with 53.7M parameters, while we obtain 82.48% top-1 accuracy with 52.47M parameters.

## Citation

```latex
@inproceedings{chen2021autoformer,
  title={Autoformer: Searching transformers for visual recognition},
  author={Chen, Minghao and Peng, Houwen and Fu, Jianlong and Ling, Haibin},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={12270--12280},
  year={2021}
}
```
@@ -0,0 +1,17 @@
_base_ = ['./autoformer_supernet_32xb256_in1k.py']

custom_hooks = None

train_cfg = dict(
    _delete_=True,
    type='mmrazor.EvolutionSearchLoop',
    dataloader=_base_.val_dataloader,
    evaluator=_base_.val_evaluator,
    max_epochs=20,
    num_candidates=20,
    top_k=10,
    num_mutation=5,
    num_crossover=5,
    mutate_prob=0.2,
    constraints_range=dict(params=(0, 55)),
    score_key='accuracy/top1')
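`constraints_range` takes one `(min, max)` interval per resource indicator reported by the estimator, so several budgets can be screened at once. A hedged sketch; only the `params` bound above comes from this config, the `flops` entry is an illustrative assumption:

```python
# Illustrative only: combine parameter and FLOPs budgets in one dict.
# A scalar bound v is normalized to the interval (0, v) by the checker.
constraints_range = dict(
    params=(0, 55),    # from the config above: subnets under 55M params
    flops=(0., 330.),  # hypothetical FLOPs budget in the estimator's units
)
```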
@@ -0,0 +1,79 @@
_base_ = [
    'mmrazor::_base_/settings/imagenet_bs2048_AdamW.py',
    'mmcls::_base_/default_runtime.py',
]

# data preprocessor
data_preprocessor = dict(
    _scope_='mmcls',
    type='ClsDataPreprocessor',
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
    num_classes=1000,
    batch_augments=dict(
        augments=[
            dict(type='Mixup', alpha=0.2),
            dict(type='CutMix', alpha=1.0)
        ],
        probs=[0.5, 0.5]))

arch_setting = dict(
    mlp_ratios=[3.0, 3.5, 4.0],
    num_heads=[8, 9, 10],
    depth=[14, 15, 16],
    embed_dims=[528, 576, 624])

supernet = dict(
    _scope_='mmrazor',
    type='SearchableImageClassifier',
    data_preprocessor=data_preprocessor,
    backbone=dict(
        _scope_='mmrazor',
        type='AutoformerBackbone',
        arch_setting=arch_setting),
    neck=None,
    head=dict(
        type='DynamicLinearClsHead',
        num_classes=1000,
        in_channels=624,
        loss=dict(
            type='mmcls.LabelSmoothLoss',
            mode='original',
            num_classes=1000,
            label_smooth_val=0.1,
            loss_weight=1.0),
        topk=(1, 5)),
    connect_head=dict(connect_with_backbone='backbone.last_mutable'),
)

model = dict(
    type='mmrazor.Autoformer',
    architecture=supernet,
    fix_subnet=None,
    mutators=dict(
        channel_mutator=dict(
            type='mmrazor.OneShotChannelMutator',
            channel_unit_cfg={
                'type': 'OneShotMutableChannelUnit',
                'default_args': {
                    'unit_predefined': True
                }
            },
            parse_cfg={'type': 'Predefined'}),
        value_mutator=dict(type='mmrazor.DynamicValueMutator')))

# runtime setting
custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]

# checkpoint saving
_base_.default_hooks.checkpoint = dict(
    type='CheckpointHook',
    interval=2,
    by_epoch=True,
    save_best='accuracy/top1',
    max_keep_ckpts=3)

find_unused_parameters = True
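To sanity-check this supernet definition, one might build it and slice a random subnet. A minimal sketch assuming an mmrazor installation and the config path below:

```python
from mmengine import Config

from mmrazor.registry import MODELS

cfg = Config.fromfile(
    'configs/nas/mmcls/autoformer/autoformer_supernet_32xb256_in1k.py')
algorithm = MODELS.build(cfg.model)  # builds `mmrazor.Autoformer`

subnet = algorithm.sample_subnet()   # random choices from both mutators
algorithm.set_subnet(subnet)         # activate the sampled subnet
```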
@@ -13,5 +13,5 @@ train_cfg = dict(
     num_mutation=25,
     num_crossover=25,
     mutate_prob=0.1,
-    flops_range=(0., 465.),
+    constraints_range=dict(flops=(0., 465.)),
     score_key='accuracy/top1')
@@ -13,5 +13,5 @@ train_cfg = dict(
     num_mutation=25,
     num_crossover=25,
     mutate_prob=0.1,
-    flops_range=(0., 330.),
+    constraints_range=dict(flops=(0, 330)),
     score_key='accuracy/top1')
@@ -13,5 +13,5 @@ train_cfg = dict(
     num_mutation=20,
     num_crossover=20,
     mutate_prob=0.1,
-    flops_range=(0., 300.),
+    constraints_range=dict(flops=(0, 330)),
     score_key='coco/bbox_mAP')
@ -1,9 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os
|
||||
import os.path as osp
|
||||
import random
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from mmengine import fileio
|
||||
|
@ -14,10 +15,10 @@ from mmengine.utils import is_list_of
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from mmrazor.models.task_modules import ResourceEstimator
|
||||
from mmrazor.registry import LOOPS
|
||||
from mmrazor.registry import LOOPS, TASK_UTILS
|
||||
from mmrazor.structures import Candidates, export_fix_subnet
|
||||
from mmrazor.utils import SupportRandomSubnet
|
||||
from .utils import check_subnet_flops, crossover
|
||||
from .utils import check_subnet_resources, crossover
|
||||
|
||||
|
||||
@LOOPS.register_module()
|
||||
|
@ -41,10 +42,11 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
|
|||
num_crossover (int): The number of candidates got by crossover.
|
||||
Defaults to 25.
|
||||
mutate_prob (float): The probability of mutation. Defaults to 0.1.
|
||||
flops_range (tuple, optional): It is used for screening candidates.
|
||||
resource_estimator_cfg (dict): The config for building estimator, which
|
||||
is be used to estimate the flops of sampled subnet. Defaults to
|
||||
None, which means default config is used.
|
||||
crossover_prob (float): The probability of crossover. Defaults to 0.5.
|
||||
constraints_range (Dict[str, Any]): Constraints to be used for
|
||||
screening candidates. Defaults to dict(flops=(0, 330)).
|
||||
resource_estimator_cfg (dict, Optional): Used for building a
|
||||
resource estimator. Defaults to None.
|
||||
score_key (str): Specify one metric in evaluation results to score
|
||||
candidates. Defaults to 'accuracy_top-1'.
|
||||
init_candidates (str, optional): The candidates file path, which is
|
||||
|
@ -64,8 +66,9 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
|
|||
num_mutation: int = 25,
|
||||
num_crossover: int = 25,
|
||||
mutate_prob: float = 0.1,
|
||||
flops_range: Optional[Tuple[float, float]] = (0., 330.),
|
||||
resource_estimator_cfg: Optional[dict] = None,
|
||||
crossover_prob: float = 0.5,
|
||||
constraints_range: Dict[str, Any] = dict(flops=(0., 330.)),
|
||||
resource_estimator_cfg: Optional[Dict] = None,
|
||||
score_key: str = 'accuracy/top1',
|
||||
init_candidates: Optional[str] = None) -> None:
|
||||
super().__init__(runner, dataloader, max_epochs)
|
||||
|
@ -83,11 +86,12 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
|
|||
|
||||
self.num_candidates = num_candidates
|
||||
self.top_k = top_k
|
||||
self.flops_range = flops_range
|
||||
self.constraints_range = constraints_range
|
||||
self.score_key = score_key
|
||||
self.num_mutation = num_mutation
|
||||
self.num_crossover = num_crossover
|
||||
self.mutate_prob = mutate_prob
|
||||
self.crossover_prob = crossover_prob
|
||||
self.max_keep_ckpts = max_keep_ckpts
|
||||
self.resume_from = resume_from
|
||||
|
||||
|
@ -99,16 +103,58 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
|
|||
correct init candidates file'
|
||||
|
||||
self.top_k_candidates = Candidates()
|
||||
if resource_estimator_cfg is None:
|
||||
self.estimator = ResourceEstimator()
|
||||
else:
|
||||
self.estimator = ResourceEstimator(**resource_estimator_cfg)
|
||||
|
||||
if self.runner.distributed:
|
||||
self.model = runner.model.module
|
||||
else:
|
||||
self.model = runner.model
|
||||
|
||||
# Build resource estimator.
|
||||
resource_estimator_cfg = dict(
|
||||
) if resource_estimator_cfg is None else resource_estimator_cfg
|
||||
self.estimator = self.build_resource_estimator(resource_estimator_cfg)
|
||||
|
||||
def build_resource_estimator(
|
||||
self, resource_estimator: Union[ResourceEstimator,
|
||||
Dict]) -> ResourceEstimator:
|
||||
"""Build resource estimator for search loop.
|
||||
|
||||
Examples of ``resource_estimator``:
|
||||
|
||||
# `ResourceEstimator` will be used
|
||||
resource_estimator = dict()
|
||||
|
||||
# custom resource_estimator
|
||||
resource_estimator = dict(type='mmrazor.ResourceEstimator')
|
||||
|
||||
Args:
|
||||
resource_estimator (ResourceEstimator or dict): A
|
||||
resource_estimator or a dict to build resource estimator.
|
||||
If ``resource_estimator`` is a resource estimator object,
|
||||
just returns itself.
|
||||
|
||||
Returns:
|
||||
:obj:`ResourceEstimator`: Resource estimator object build from
|
||||
``resource_estimator``.
|
||||
"""
|
||||
if isinstance(resource_estimator, ResourceEstimator):
|
||||
return resource_estimator
|
||||
elif not isinstance(resource_estimator, dict):
|
||||
raise TypeError(
|
||||
'resource estimator should be a ResourceEstimator object or'
|
||||
f'dict, but got {resource_estimator}')
|
||||
|
||||
resource_estimator_cfg = copy.deepcopy(
|
||||
resource_estimator) # type: ignore
|
||||
|
||||
if 'type' in resource_estimator_cfg:
|
||||
estimator = TASK_UTILS.build(resource_estimator_cfg)
|
||||
else:
|
||||
estimator = ResourceEstimator(
|
||||
**resource_estimator_cfg) # type: ignore
|
||||
|
||||
return estimator # type: ignore
|
||||
|
||||
def run(self) -> None:
|
||||
"""Launch searching."""
|
||||
self.runner.call_hook('before_train')
|
||||
|
@ -144,31 +190,48 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
|
|||
f'{scores_before}')
|
||||
|
||||
self.candidates.extend(self.top_k_candidates)
|
||||
self.candidates.sort(key=lambda x: x[1], reverse=True)
|
||||
self.top_k_candidates = Candidates(self.candidates[:self.top_k])
|
||||
self.candidates.sort_by(key_indicator='score', reverse=True)
|
||||
self.top_k_candidates = Candidates(self.candidates.data[:self.top_k])
|
||||
|
||||
scores_after = self.top_k_candidates.scores
|
||||
self.runner.logger.info(f'top k scores after update: '
|
||||
f'{scores_after}')
|
||||
|
||||
mutation_candidates = self.gen_mutation_candidates()
|
||||
self.candidates_mutator_crossover = Candidates(mutation_candidates)
|
||||
crossover_candidates = self.gen_crossover_candidates()
|
||||
candidates = mutation_candidates + crossover_candidates
|
||||
assert len(candidates) <= self.num_candidates, 'Total of mutation and \
|
||||
crossover should be no more than the number of candidates.'
|
||||
self.candidates_mutator_crossover.extend(crossover_candidates)
|
||||
|
||||
self.candidates = Candidates(candidates)
|
||||
assert len(self.candidates_mutator_crossover
|
||||
) <= self.num_candidates, 'Total of mutation and \
|
||||
crossover should be less than the number of candidates.'
|
||||
|
||||
self.candidates = self.candidates_mutator_crossover
|
||||
self._epoch += 1
|
||||
|
||||
def sample_candidates(self) -> None:
|
||||
"""Update candidate pool contains specified number of candicates."""
|
||||
candidates_resources = []
|
||||
init_candidates = len(self.candidates)
|
||||
if self.runner.rank == 0:
|
||||
while len(self.candidates) < self.num_candidates:
|
||||
candidate = self.model.sample_subnet()
|
||||
if self._check_constraints(random_subnet=candidate):
|
||||
is_pass, result = self._check_constraints(
|
||||
random_subnet=candidate)
|
||||
if is_pass:
|
||||
self.candidates.append(candidate)
|
||||
candidates_resources.append(result)
|
||||
self.candidates = Candidates(self.candidates.data)
|
||||
else:
|
||||
self.candidates = Candidates([None] * self.num_candidates)
|
||||
self.candidates = Candidates([dict(a=0)] * self.num_candidates)
|
||||
|
||||
if len(candidates_resources) > 0:
|
||||
self.candidates.update_resources(
|
||||
candidates_resources,
|
||||
start=len(self.candidates.data) - len(candidates_resources))
|
||||
assert init_candidates + len(
|
||||
candidates_resources) == self.num_candidates
|
||||
|
||||
# broadcast candidates to val with multi-GPUs.
|
||||
broadcast_object_list(self.candidates.data)
|
||||
|
||||
|
@ -180,14 +243,18 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
|
|||
metrics = self._val_candidate()
|
||||
score = metrics[self.score_key] \
|
||||
if len(metrics) != 0 else 0.
|
||||
self.candidates.set_score(i, score)
|
||||
self.candidates.set_resource(i, score, 'score')
|
||||
self.runner.logger.info(
|
||||
f'Epoch:[{self._epoch}/{self._max_epochs}] '
|
||||
f'Candidate:[{i + 1}/{self.num_candidates}] '
|
||||
f'Score:{score}')
|
||||
f'Flops: {self.candidates.resources("flops")[i]} '
|
||||
f'Params: {self.candidates.resources("params")[i]} '
|
||||
f'Latency: {self.candidates.resources("latency")[i]} '
|
||||
f'Score: {self.candidates.scores} ')
|
||||
|
||||
def gen_mutation_candidates(self) -> List:
|
||||
def gen_mutation_candidates(self):
|
||||
"""Generate specified number of mutation candicates."""
|
||||
mutation_resources = []
|
||||
mutation_candidates: List = []
|
||||
max_mutate_iters = self.num_mutation * 10
|
||||
mutate_iter = 0
|
||||
|
@ -198,12 +265,20 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
|
|||
|
||||
mutation_candidate = self._mutation()
|
||||
|
||||
if self._check_constraints(random_subnet=mutation_candidate):
|
||||
is_pass, result = self._check_constraints(
|
||||
random_subnet=mutation_candidate)
|
||||
if is_pass:
|
||||
mutation_candidates.append(mutation_candidate)
|
||||
mutation_resources.append(result)
|
||||
|
||||
mutation_candidates = Candidates(mutation_candidates)
|
||||
mutation_candidates.update_resources(mutation_resources)
|
||||
|
||||
return mutation_candidates
|
||||
|
||||
def gen_crossover_candidates(self) -> List:
|
||||
def gen_crossover_candidates(self):
|
||||
"""Generate specofied number of crossover candicates."""
|
||||
crossover_resources = []
|
||||
crossover_candidates: List = []
|
||||
crossover_iter = 0
|
||||
max_crossover_iters = self.num_crossover * 10
|
||||
|
@ -214,8 +289,15 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
|
|||
|
||||
crossover_candidate = self._crossover()
|
||||
|
||||
if self._check_constraints(random_subnet=crossover_candidate):
|
||||
is_pass, result = self._check_constraints(
|
||||
random_subnet=crossover_candidate)
|
||||
if is_pass:
|
||||
crossover_candidates.append(crossover_candidate)
|
||||
crossover_resources.append(result)
|
||||
|
||||
crossover_candidates = Candidates(crossover_candidates)
|
||||
crossover_candidates.update_resources(crossover_resources)
|
||||
|
||||
return crossover_candidates
|
||||
|
||||
def _mutation(self) -> SupportRandomSubnet:
|
||||
|
@ -229,7 +311,7 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
|
|||
"""Crossover."""
|
||||
candidate1 = random.choice(self.top_k_candidates.subnets)
|
||||
candidate2 = random.choice(self.top_k_candidates.subnets)
|
||||
candidate = crossover(candidate1, candidate2)
|
||||
candidate = crossover(candidate1, candidate2, prob=self.crossover_prob)
|
||||
return candidate
|
||||
|
||||
def _resume(self):
|
||||
|
@ -263,7 +345,7 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
|
|||
self.runner.model.eval()
|
||||
for data_batch in self.dataloader:
|
||||
outputs = self.runner.model.val_step(data_batch)
|
||||
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
|
||||
self.evaluator.process(outputs, data_batch)
|
||||
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
|
||||
return metrics
|
||||
|
||||
|
@ -295,16 +377,17 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
|
|||
if osp.isfile(ckpt_path):
|
||||
os.remove(ckpt_path)
|
||||
|
||||
def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool:
|
||||
def _check_constraints(
|
||||
self, random_subnet: SupportRandomSubnet) -> Tuple[bool, Dict]:
|
||||
"""Check whether is beyond constraints.
|
||||
|
||||
Returns:
|
||||
bool: The result of checking.
|
||||
bool, result: The result of checking.
|
||||
"""
|
||||
is_pass = check_subnet_flops(
|
||||
is_pass, results = check_subnet_resources(
|
||||
model=self.model,
|
||||
subnet=random_subnet,
|
||||
estimator=self.estimator,
|
||||
flops_range=self.flops_range)
|
||||
constraints_range=self.constraints_range)
|
||||
|
||||
return is_pass
|
||||
return is_pass, results
|
||||
|
|
|
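The loop above relies on the reworked `Candidates` container, which stores per-subnet resources next to scores. A minimal sketch of that API as used here; the subnet dicts are placeholders, real ones come from `model.sample_subnet()`:

```python
from mmrazor.structures import Candidates

# Placeholder subnet dicts standing in for sampled architectures.
candidates = Candidates([dict(a=0), dict(a=1)])

# Attach estimator results, then score each candidate, as the loop does.
candidates.update_resources([dict(flops=300.0, params=50.0),
                             dict(flops=310.0, params=52.0)])
candidates.set_resource(0, 73.9, 'score')
candidates.set_resource(1, 72.1, 'score')

# Rank by score and inspect the stored indicators.
candidates.sort_by(key_indicator='score', reverse=True)
print(candidates.scores, candidates.resources('flops'), candidates.subnets)
```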
@@ -1,9 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import copy
 import math
 import os
 import random
 from abc import abstractmethod
-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

 import torch
 from mmengine import fileio
@@ -13,10 +14,10 @@ from mmengine.utils import is_list_of
 from torch.utils.data import DataLoader

 from mmrazor.models.task_modules import ResourceEstimator
-from mmrazor.registry import LOOPS
+from mmrazor.registry import LOOPS, TASK_UTILS
 from mmrazor.structures import Candidates
 from mmrazor.utils import SupportRandomSubnet
-from .utils import check_subnet_flops
+from .utils import check_subnet_resources


 class BaseSamplerTrainLoop(IterBasedTrainLoop):
@@ -77,18 +78,15 @@ class BaseSamplerTrainLoop(IterBasedTrainLoop):
 @LOOPS.register_module()
 class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
     """IterBasedTrainLoop for greedy sampler.

     In GreedySamplerTrainLoop, `Greedy` means that only some top sampled
     candidates are used to train the supernet. So GreedySamplerTrainLoop
     mainly picks the top candidates based on their val scores, then uses
     them to train the supernet one by one.

     Steps:
         1. Sample from the supernet and the candidates.
         2. Validate these sampled candidates to get each candidate's score.
         3. Get top-k candidates based on their scores, then use them to train
             the supernet one by one.

     Args:
         runner (Runner): A reference of runner.
         dataloader (Dataloader or dict): A dataloader object or a dict to
@@ -102,10 +100,10 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
         val_interval (int): Validation interval. Defaults to 1000.
         score_key (str): Specify one metric in evaluation results to score
             candidates. Defaults to 'accuracy_top-1'.
-        flops_range (dict): Constraints to be used for screening candidates.
-        resource_estimator_cfg (dict): The config for building estimator, which
-            is used to estimate the flops of sampled subnet. Defaults to
-            None, which means default config is used.
+        constraints_range (Dict[str, Any]): Constraints to be used for
+            screening candidates. Defaults to dict(flops=(0, 330)).
+        resource_estimator_cfg (dict, Optional): Used for building a
+            resource estimator. Defaults to None.
         num_candidates (int): The number of the candidates consist of samples
             from supernet and itself. Defaults to 1000.
         num_samples (int): The number of sample in each sampling subnet.
@@ -139,8 +137,8 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
                  val_begin: int = 1,
                  val_interval: int = 1000,
                  score_key: str = 'accuracy/top1',
-                 flops_range: Optional[Tuple[float, float]] = (0., 330),
-                 resource_estimator_cfg: Optional[dict] = None,
+                 constraints_range: Dict[str, Any] = dict(flops=(0, 330)),
+                 resource_estimator_cfg: Optional[Dict] = None,
                  num_candidates: int = 1000,
                  num_samples: int = 10,
                  top_k: int = 5,
@@ -163,7 +161,7 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
         self.evaluator = evaluator

         self.score_key = score_key
-        self.flops_range = flops_range
+        self.constraints_range = constraints_range
         self.num_candidates = num_candidates
         self.num_samples = num_samples
         self.top_k = top_k
@@ -177,10 +175,52 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):

         self.candidates = Candidates()
         self.top_k_candidates = Candidates()
-        if resource_estimator_cfg is None:
-            self.estimator = ResourceEstimator()
-        else:
-            self.estimator = ResourceEstimator(**resource_estimator_cfg)
+
+        # Build resource estimator.
+        resource_estimator_cfg = dict(
+        ) if resource_estimator_cfg is None else resource_estimator_cfg
+        self.estimator = self.build_resource_estimator(resource_estimator_cfg)
+
+    def build_resource_estimator(
+            self, resource_estimator: Union[ResourceEstimator,
+                                            Dict]) -> ResourceEstimator:
+        """Build resource estimator for search loop.
+
+        Examples of ``resource_estimator``:
+
+            # `ResourceEstimator` will be used
+            resource_estimator = dict()
+
+            # custom resource_estimator
+            resource_estimator = dict(type='mmrazor.ResourceEstimator')
+
+        Args:
+            resource_estimator (ResourceEstimator or dict):
+                A resource_estimator or a dict to build resource estimator.
+                If ``resource_estimator`` is a resource estimator object,
+                just returns itself.
+
+        Returns:
+            :obj:`ResourceEstimator`: Resource estimator object built from
+                ``resource_estimator``.
+        """
+        if isinstance(resource_estimator, ResourceEstimator):
+            return resource_estimator
+        elif not isinstance(resource_estimator, dict):
+            raise TypeError(
+                'resource estimator should be a ResourceEstimator object or'
+                f'dict, but got {resource_estimator}')
+
+        resource_estimator_cfg = copy.deepcopy(
+            resource_estimator)  # type: ignore
+
+        if 'type' in resource_estimator_cfg:
+            estimator = TASK_UTILS.build(resource_estimator_cfg)
+        else:
+            estimator = ResourceEstimator(
+                **resource_estimator_cfg)  # type: ignore
+
+        return estimator  # type: ignore

     def run(self) -> None:
         """Launch training."""
@@ -230,9 +270,11 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):

         self.update_candidates_scores()

-        self.candidates.sort(key=lambda x: x[1], reverse=True)
-        self.candidates = Candidates(self.candidates[:self.num_candidates])
-        self.top_k_candidates = Candidates(self.candidates[:self.top_k])
+        self.candidates.sort_by(key_indicator='score', reverse=True)
+        self.candidates = Candidates(
+            self.candidates.data[:self.num_candidates])
+        self.top_k_candidates = Candidates(
+            self.candidates.data[:self.top_k])

         top1_score = self.top_k_candidates.scores[0]
         if (self._iter % self.val_interval) < self.top_k:
@@ -243,7 +285,7 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
                 f'{num_sample_from_supernet}/{self.num_samples} '
                 f'top1_score {top1_score:.3f} '
                 f'cur_num_candidates: {len(self.candidates)}')
-        return self.top_k_candidates.pop(0)[0]
+        return self.top_k_candidates.subnets[0]

     def update_cur_prob(self, cur_iter: int) -> None:
         """update current probability of sampling from the candidates, which is
@@ -278,7 +320,8 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
         for _ in range(num_samples):
             if random.random() >= self.cur_prob or len(self.candidates) == 0:
                 subnet = self._sample_from_supernet()
-                if self._check_constraints(subnet):
+                is_pass, _ = self._check_constraints(subnet)
+                if is_pass:
                     sampled_candidates.append(subnet)
                     num_sample_from_supernet += 1
             else:
@@ -292,7 +335,7 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
             self.model.set_subnet(candidate)
             metrics = self._val_candidate()
             score = metrics[self.score_key] if len(metrics) != 0 else 0.
-            self.candidates.set_score(i, score)
+            self.candidates.set_resource(i, score, 'score')

     @torch.no_grad()
     def _val_candidate(self) -> Dict:
@@ -312,22 +355,22 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
     def _sample_from_candidates(self) -> SupportRandomSubnet:
         """Sample from the candidates."""
         assert len(self.candidates) > 0
-        subnet = random.choice(self.candidates)
+        subnet = random.choice(self.candidates.data)
         return subnet

-    def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool:
+    def _check_constraints(self, random_subnet: SupportRandomSubnet):
         """Check whether is beyond constraints.

         Returns:
-            bool: The result of checking.
+            bool, result: The result of checking.
         """
-        is_pass = check_subnet_flops(
+        is_pass, results = check_subnet_resources(
             model=self.model,
             subnet=random_subnet,
             estimator=self.estimator,
-            flops_range=self.flops_range)
+            constraints_range=self.constraints_range)

-        return is_pass
+        return is_pass, results

     def _save_candidates(self) -> None:
         """Save the candidates to init the next searching."""
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .check import check_subnet_flops
+from .check import check_subnet_resources
 from .genetic import crossover

-__all__ = ['crossover', 'check_subnet_flops']
+__all__ = ['crossover', 'check_subnet_resources']
@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
-from typing import Optional, Tuple
+from typing import Any, Dict, Tuple

-import torch.nn as nn
+import torch

 from mmrazor.models import ResourceEstimator
 from mmrazor.structures import export_fix_subnet, load_fix_subnet
@@ -15,18 +15,20 @@ except ImportError:
     BaseDetector = get_placeholder('mmdet')


-def check_subnet_flops(
-        model: nn.Module,
-        subnet: SupportRandomSubnet,
-        estimator: ResourceEstimator,
-        flops_range: Optional[Tuple[float, float]] = None) -> bool:
-    """Check whether is beyond flops constraints.
+@torch.no_grad()
+def check_subnet_resources(
+        model,
+        subnet: SupportRandomSubnet,
+        estimator: ResourceEstimator,
+        constraints_range: Dict[str, Any] = dict(flops=(0, 330))
+) -> Tuple[bool, Dict]:
+    """Check whether is beyond resources constraints.

     Returns:
-        bool: The result of checking.
+        bool, result: The result of checking.
     """
-    if flops_range is None:
-        return True
+    if constraints_range is None:
+        return True, dict()

     assert hasattr(model, 'set_subnet') and hasattr(model, 'architecture')
     model.set_subnet(subnet)
@@ -40,9 +42,10 @@ def check_subnet_resources(
     else:
         results = estimator.estimate(model=model_to_check)

-    flops = results['flops']
-    flops_mix, flops_max = flops_range
-    if flops_mix <= flops <= flops_max:  # type: ignore
-        return True
-    else:
-        return False
+    for k, v in constraints_range.items():
+        if not isinstance(v, (list, tuple)):
+            v = (0, v)
+        if results[k] < v[0] or results[k] > v[1]:
+            return False, results
+
+    return True, results
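For reference, a hedged usage sketch of `check_subnet_resources`; the `algorithm` and `estimator` objects are assumed to exist (e.g. built as in the loops above), and the import path is inferred from this package layout:

```python
from mmrazor.engine.runner.utils import check_subnet_resources  # assumed export

subnet = algorithm.sample_subnet()
is_pass, results = check_subnet_resources(
    model=algorithm,
    subnet=subnet,
    estimator=estimator,
    # A scalar such as flops=330 would be normalized to (0, 330).
    constraints_range=dict(flops=(0., 330.), params=(0, 55)))
if is_pass:
    print('within budget:', results['flops'], results['params'])
```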
@@ -3,7 +3,8 @@ from .base import BaseAlgorithm
 from .distill import (DAFLDataFreeDistillation, DataFreeDistillation,
                       FpnTeacherDistill, OverhaulFeatureDistillation,
                       SelfDistill, SingleTeacherDistill)
-from .nas import DSNAS, DSNASDDP, SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP
+from .nas import (DSNAS, DSNASDDP, SPOS, Autoformer, AutoSlim, AutoSlimDDP,
+                  Darts, DartsDDP)
 from .pruning import SlimmableNetwork, SlimmableNetworkDDP
 from .pruning.ite_prune_algorithm import ItePruneAlgorithm

@@ -25,4 +26,5 @@ __all__ = [
     'ItePruneAlgorithm',
     'DSNAS',
     'DSNASDDP',
+    'Autoformer',
 ]
@@ -1,9 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .autoformer import Autoformer
 from .autoslim import AutoSlim, AutoSlimDDP
 from .darts import Darts, DartsDDP
 from .dsnas import DSNAS, DSNASDDP
 from .spos import SPOS

 __all__ = [
-    'SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'DSNAS', 'DSNASDDP'
+    'SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'DSNAS',
+    'DSNASDDP', 'Autoformer'
 ]
@@ -0,0 +1,115 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union

import torch
from mmengine.model import BaseModel
from mmengine.structures import BaseDataElement
from torch import nn

from mmrazor.registry import MODELS
from mmrazor.utils import ValidFixMutable
from ..base import BaseAlgorithm, LossResults


@MODELS.register_module()
class Autoformer(BaseAlgorithm):
    """Implementation of `AutoFormer <https://arxiv.org/abs/2107.00651>`_

    AutoFormer is dedicated to vision transformer search. AutoFormer
    entangles the weights of different blocks in the same layers during
    supernet training.
    The logic of the search part is implemented in
    :class:`mmrazor.engine.EvolutionSearchLoop`

    Args:
        architecture (dict|:obj:`BaseModel`): The config of :class:`BaseModel`
            or built model. Corresponding to supernet in NAS algorithm.
        mutators (Optional[dict]): The dict of different Mutators config.
            Defaults to None.
        fix_subnet (str | dict | :obj:`FixSubnet`): The path of yaml file or
            loaded dict or built :obj:`FixSubnet`. Defaults to None.
        data_preprocessor (Optional[Union[dict, nn.Module]]): The pre-process
            config of :class:`BaseDataPreprocessor`. Defaults to None.
        init_cfg (Optional[dict]): Init config for ``BaseModule``.
            Defaults to None.

    Note:
        Autoformer uses two mutators: ``DynamicValueMutator`` and
        ``ChannelMutator``. ``DynamicValueMutator`` handles the mutable
        objects of type ``OneShotMutableValue`` in Autoformer, while
        ``ChannelMutator`` handles the mutable objects of type
        ``OneShotMutableChannel``.
    """

    def __init__(self,
                 architecture: Union[BaseModel, Dict],
                 mutators: Optional[Dict] = None,
                 fix_subnet: Optional[ValidFixMutable] = None,
                 data_preprocessor: Optional[Union[dict, nn.Module]] = None,
                 init_cfg: Optional[dict] = None):
        super().__init__(architecture, data_preprocessor, init_cfg)

        # Autoformer supports supernet training and subnet retraining.
        # If fix_subnet is not None, it means subnet retraining.
        if fix_subnet:
            # Avoid circular import
            from mmrazor.structures import load_fix_subnet

            # According to fix_subnet, delete the unchosen part of supernet
            load_fix_subnet(self.architecture, fix_subnet)
            self.is_supernet = False
        else:
            assert mutators is not None, \
                'mutator cannot be None when fix_subnet is None.'
            if isinstance(mutators, dict):
                built_mutators: Dict = dict()
                for name, mutator_cfg in mutators.items():
                    if 'parse_cfg' in mutator_cfg and isinstance(
                            mutator_cfg['parse_cfg'], dict):
                        assert mutator_cfg['parse_cfg'][
                            'type'] == 'Predefined', \
                            'autoformer only supports predefined.'
                    mutator = MODELS.build(mutator_cfg)
                    built_mutators[name] = mutator
                    mutator.prepare_from_supernet(self.architecture)
                self.mutators = built_mutators
            else:
                raise TypeError('mutators should be a `dict` but got '
                                f'{type(mutators)}')

            self.is_supernet = True

    def sample_subnet(self) -> Dict:
        """Randomly sample a subnet via the mutators."""
        subnet_dict = dict()
        for name, mutator in self.mutators.items():
            if name == 'value_mutator':
                subnet_dict.update(
                    dict((str(group_id), value) for group_id, value in
                         mutator.sample_choices().items()))
            else:
                subnet_dict.update(mutator.sample_choices())
        return subnet_dict

    def set_subnet(self, subnet_dict: Dict) -> None:
        """Set the subnet sampled by :meth:`sample_subnet`."""
        for name, mutator in self.mutators.items():
            if name == 'value_mutator':
                value_subnet = dict((int(group_id), value)
                                    for group_id, value in subnet_dict.items()
                                    if isinstance(group_id, str))
                mutator.set_choices(value_subnet)
            else:
                channel_subnet = dict(
                    (group_id, value)
                    for group_id, value in subnet_dict.items()
                    if isinstance(group_id, int))
                mutator.set_choices(channel_subnet)

    def loss(
        self,
        batch_inputs: torch.Tensor,
        data_samples: Optional[List[BaseDataElement]] = None,
    ) -> LossResults:
        """Calculate losses from a batch of inputs and data samples."""
        if self.is_supernet:
            random_subnet = self.sample_subnet()
            self.set_subnet(random_subnet)
        return self.architecture(batch_inputs, data_samples, mode='loss')
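Note the key convention shared by the two mutators in one flat `subnet_dict`: value-mutator group ids are stringified while channel-mutator group ids stay as `int`, which is how `set_subnet` routes each choice back. A hedged illustration; the ids and choices are made up:

```python
# Hypothetical dict of the shape produced by Autoformer.sample_subnet().
subnet_dict = {
    '0': 15,   # str key -> value mutator group (e.g. a sampled depth)
    '1': 3.5,  # str key -> value mutator group (e.g. a sampled mlp_ratio)
    2: 576,    # int key -> channel mutator group (e.g. sampled embed dims)
}
# Autoformer.set_subnet(subnet_dict) splits the dict by key type and
# dispatches each half to the matching mutator via `set_choices(...)`.
```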
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .backbones import *  # noqa: F401,F403
+from .classifiers import *  # noqa: F401,F403
 from .connectors import *  # noqa: F401,F403
 from .dynamic_ops import *  # noqa: F401,F403
 from .generators import *  # noqa: F401,F403
@@ -1,10 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .darts_backbone import DartsBackbone
+from .searchable_autoformer import AutoformerBackbone
 from .searchable_mobilenet import SearchableMobileNet
 from .searchable_shufflenet_v2 import SearchableShuffleNetV2
 from .wideresnet import WideResNet

 __all__ = [
     'SearchableMobileNet', 'SearchableShuffleNetV2', 'DartsBackbone',
-    'WideResNet'
+    'WideResNet', 'AutoformerBackbone'
 ]
@ -0,0 +1,374 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_activation_layer, build_norm_layer
|
||||
|
||||
from mmrazor.models.architectures.dynamic_ops.bricks import (
|
||||
DynamicLinear, DynamicMultiheadAttention, DynamicPatchEmbed,
|
||||
DynamicSequential)
|
||||
from mmrazor.models.mutables import (BaseMutable, BaseMutableChannel,
|
||||
MutableChannelContainer,
|
||||
OneShotMutableChannel,
|
||||
OneShotMutableValue)
|
||||
from mmrazor.models.mutables.mutable_channel import OneShotMutableChannelUnit
|
||||
from mmrazor.registry import MODELS
|
||||
|
||||
try:
|
||||
from mmcls.models.backbones.base_backbone import BaseBackbone
|
||||
except ImportError:
|
||||
from mmrazor.utils import get_placeholder
|
||||
BaseBackbone = get_placeholder('mmcls')
|
||||
|
||||
|
||||
class TransformerEncoderLayer(BaseBackbone):
|
||||
"""Autoformer block.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
mlp_ratio (List): Ratio of ffn.
|
||||
drop_rate (float): Probability of an element to be zeroed
|
||||
after the feed forward layer. Defaults to 0.
|
||||
attn_drop_rate (float): The drop path rate after attention.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
|
||||
qkv_bias (bool, optional): Whether to keep bias of qkv.
|
||||
Defaults to True.
|
||||
act_cfg (Dict, optional): The config for acitvation function.
|
||||
Defaults to dict(type='GELU').
|
||||
norm_cfg (Dict, optional): The config for normalization.
|
||||
Defaults to dict(type='mmrazor.DynamicLayerNorm').
|
||||
init_cfg (Dict, optional): The config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float,
|
||||
drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.,
|
||||
qkv_bias: bool = True,
|
||||
act_cfg: Dict = dict(type='GELU'),
|
||||
norm_cfg: Dict = dict(type='mmrazor.DynamicLayerNorm'),
|
||||
init_cfg: Dict = None) -> None:
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
norm_cfg, embed_dims, postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
|
||||
self.attn = DynamicMultiheadAttention(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
proj_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||
qkv_bias=qkv_bias)
|
||||
|
||||
self.norm2_name, norm2 = build_norm_layer(
|
||||
norm_cfg, embed_dims, postfix=2)
|
||||
self.add_module(self.norm2_name, norm2)
|
||||
|
||||
middle_channels = int(embed_dims * mlp_ratio)
|
||||
self.fc1 = DynamicLinear(embed_dims, middle_channels)
|
||||
self.fc2 = DynamicLinear(middle_channels, embed_dims)
|
||||
self.act = build_activation_layer(act_cfg)
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
"""The first normalization."""
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
@property
|
||||
def norm2(self):
|
||||
"""The second normalization."""
|
||||
return getattr(self, self.norm2_name)
|
||||
|
||||
def register_mutables(self, mutable_num_heads: BaseMutable,
|
||||
mutable_mlp_ratios: BaseMutable,
|
||||
mutable_q_embed_dims: BaseMutable,
|
||||
mutable_head_dims: BaseMutable,
|
||||
mutable_embed_dims: BaseMutable):
|
||||
"""Mutate the mutables of encoder layer."""
|
||||
# record the mutables
|
||||
self.mutable_num_heads = mutable_num_heads
|
||||
self.mutable_mlp_ratios = mutable_mlp_ratios
|
||||
self.mutable_q_embed_dims = mutable_q_embed_dims
|
||||
self.mutable_embed_dims = mutable_embed_dims
|
||||
self.mutable_head_dims = mutable_head_dims
|
||||
# handle the mutable of FFN
|
||||
self.middle_channels = mutable_mlp_ratios * mutable_embed_dims
|
||||
|
||||
self.attn.register_mutable_attr('num_heads', mutable_num_heads)
|
||||
|
||||
# handle the mutable of the first dynamic LN
|
||||
MutableChannelContainer.register_mutable_channel_to_module(
|
||||
self.norm1, self.mutable_embed_dims, True)
|
||||
# handle the mutable of the second dynamic LN
|
||||
MutableChannelContainer.register_mutable_channel_to_module(
|
||||
self.norm2, self.mutable_embed_dims, True)
|
||||
|
||||
# handle the mutable of attn
|
||||
MutableChannelContainer.register_mutable_channel_to_module(
|
||||
self.attn, self.mutable_embed_dims, False)
|
||||
MutableChannelContainer.register_mutable_channel_to_module(
|
||||
self.attn,
|
||||
self.mutable_q_embed_dims,
|
||||
True,
|
||||
end=self.mutable_q_embed_dims.current_choice)
|
||||
MutableChannelContainer.register_mutable_channel_to_module(
|
||||
self.attn.rel_pos_embed_k, self.mutable_head_dims, False)
|
||||
MutableChannelContainer.register_mutable_channel_to_module(
|
||||
self.attn.rel_pos_embed_v, self.mutable_head_dims, False)
|
||||
|
||||
# handle the mutable of fc
|
||||
MutableChannelContainer.register_mutable_channel_to_module(
|
||||
self.fc1, mutable_embed_dims, False)
|
||||
MutableChannelContainer.register_mutable_channel_to_module(
|
||||
self.fc1,
|
||||
self.middle_channels,
|
||||
True,
|
||||
start=0,
|
||||
end=self.middle_channels.current_choice)
|
||||
MutableChannelContainer.register_mutable_channel_to_module(
|
||||
self.fc2,
|
||||
self.middle_channels,
|
||||
False,
|
||||
start=0,
|
||||
end=self.middle_channels.current_choice)
|
||||
MutableChannelContainer.register_mutable_channel_to_module(
|
||||
self.fc2, mutable_embed_dims, True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward of Transformer Encode Layer."""
|
||||
residual = x
|
||||
x = self.norm1(x)
|
||||
x = self.attn(x)
|
||||
x = residual + x
|
||||
residual = x
|
||||
x = self.norm2(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.fc2(x)
|
||||
return residual + x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class AutoformerBackbone(BaseBackbone):
|
||||
"""Autoformer backbone.
|
||||
|
||||
A PyTorch implementation of Autoformer introduced by:
|
||||
`AutoFormer: Searching Transformers for Visual Recognition
|
||||
<https://arxiv.org/abs/2107.00651>`_
|
||||
|
||||
Modified from the `official repo
|
||||
<https://github.com/microsoft/Cream/blob/main/AutoFormer/>`.
|
||||
|
||||
Args:
|
||||
arch_setting (Dict[str, List]): Architecture settings.
|
||||
img_size (int, optional): The image size of input.
|
||||
Defaults to 224.
|
||||
patch_size (int, optional): The patch size of autoformer.
|
||||
Defaults to 16.
|
||||
in_channels (int, optional): The input channel dimension.
|
||||
Defaults to 3.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): stochastic depth rate. Defaults to 0.
|
||||
qkv_bias (bool, optional): Whether to keep bias of qkv.
|
||||
Defaults to True.
|
||||
norm_cfg (Dict, optional): The config of normalization.
|
||||
Defaults to dict(type='mmrazor.DynamicLayerNorm').
|
||||
act_cfg (Dict, optional): The config of activation functions.
|
||||
Defaults to dict(type='GELU').
|
||||
use_final_norm (bool, optional): Whether use final normalization.
|
||||
Defaults to True.
|
||||
init_cfg (Dict, optional): The config for initialization.
|
||||
Defaults to None.
|
||||
|
||||
Excamples:
|
||||
>>> arch_setting = dict(
|
||||
... mlp_ratios=[3.0, 3.5, 4.0],
|
||||
... num_heads=[8, 9, 10],
|
||||
... depth=[14, 15, 16],
|
||||
... embed_dims=[528, 576, 624]
|
||||
... )
|
||||
>>> model = AutoformerBackbone(arch_setting=arch_setting)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
arch_setting: Dict[str, List],
|
||||
img_size: int = 224,
|
||||
patch_size: int = 16,
|
||||
in_channels: int = 3,
|
||||
drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.,
|
||||
qkv_bias: bool = True,
|
||||
                 norm_cfg: Dict = dict(type='mmrazor.DynamicLayerNorm'),
                 act_cfg: Dict = dict(type='GELU'),
                 use_final_norm: bool = True,
                 init_cfg: Dict = None) -> None:
        super().__init__(init_cfg)

        self.arch_setting = arch_setting
        self.img_size = img_size
        self.patch_size = patch_size
        self.qkv_bias = qkv_bias
        self.in_channels = in_channels
        self.drop_rate = drop_rate
        self.use_final_norm = use_final_norm
        self.act_cfg = act_cfg

        # adapt mutable settings
        self.mlp_ratio_range: List = self.arch_setting['mlp_ratios']
        self.num_head_range: List = self.arch_setting['num_heads']
        self.depth_range: List = self.arch_setting['depth']
        self.embed_dim_range: List = self.arch_setting['embed_dims']

        # mutable variables of autoformer
        self.mutable_depth = OneShotMutableValue(
            value_list=self.depth_range, default_value=self.depth_range[-1])

        self.mutable_embed_dims = OneShotMutableChannel(
            num_channels=self.embed_dim_range[-1],
            candidate_choices=self.embed_dim_range)

        # handle the mutable in multihead attention
        self.base_embed_dims = OneShotMutableChannel(
            num_channels=64, candidate_choices=[64])

        self.mutable_num_heads = [
            OneShotMutableValue(
                value_list=self.num_head_range,
                default_value=self.num_head_range[-1])
            for _ in range(self.depth_range[-1])
        ]
        self.mutable_mlp_ratios = [
            OneShotMutableValue(
                value_list=self.mlp_ratio_range,
                default_value=self.mlp_ratio_range[-1])
            for _ in range(self.depth_range[-1])
        ]

        self.mutable_q_embed_dims = [
            i * self.base_embed_dims for i in self.mutable_num_heads
        ]

        # patch embeddings
        self.patch_embed = DynamicPatchEmbed(
            img_size=self.img_size,
            in_channels=self.in_channels,
            embed_dims=self.mutable_embed_dims.num_channels)

        # num of patches
        self.patch_resolution = [
            img_size // patch_size, img_size // patch_size
        ]
        num_patches = self.patch_resolution[0] * self.patch_resolution[1]

        # cls token and pos embed
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1,
                        self.mutable_embed_dims.num_channels))
        self.cls_token = nn.Parameter(
            torch.zeros(1, 1, self.mutable_embed_dims.num_channels))

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        self.dpr = np.linspace(0, drop_path_rate,
                               self.mutable_depth.max_choice)

        # main body
        self.blocks = self.make_layers(
            embed_dims=self.mutable_embed_dims.num_channels,
            depth=self.mutable_depth.max_choice)

        # final norm
        if self.use_final_norm:
            self.norm1_name, norm1 = build_norm_layer(
                norm_cfg, self.mutable_embed_dims.num_channels)
            self.add_module(self.norm1_name, norm1)

        self.last_mutable = self.mutable_embed_dims

        self.register_mutables()

    @property
    def norm1(self):
        """The first normalization."""
        return getattr(self, self.norm1_name)

    def make_layers(self, embed_dims, depth):
        """Build multiple TransformerEncoderLayers."""
        layers = []
        for i in range(depth):
            layer = TransformerEncoderLayer(
                embed_dims=embed_dims,
                num_heads=self.mutable_num_heads[i].max_choice,
                mlp_ratio=self.mutable_mlp_ratios[i].max_choice,
                drop_rate=self.drop_rate,
                drop_path_rate=self.dpr[i],
                qkv_bias=self.qkv_bias,
                act_cfg=self.act_cfg)
            layers.append(layer)
        return DynamicSequential(*layers)

    def register_mutables(self):
        """Mutate the autoformer."""
        OneShotMutableChannelUnit._register_channel_container(
            self, MutableChannelContainer)

        # handle the mutation of depth
        self.blocks.register_mutable_attr('depth', self.mutable_depth)

        # handle the mutation of patch embed
        MutableChannelContainer.register_mutable_channel_to_module(
            self.patch_embed, self.mutable_embed_dims, True)

        # handle the dependencies of TransformerEncoderLayers
        for i in range(self.mutable_depth.max_choice):  # max depth here
            layer = self.blocks[i]
            layer.register_mutables(
                mutable_num_heads=self.mutable_num_heads[i],
                mutable_mlp_ratios=self.mutable_mlp_ratios[i],
                mutable_q_embed_dims=self.mutable_q_embed_dims[i],
                mutable_head_dims=self.base_embed_dims,
                mutable_embed_dims=self.last_mutable)

        # handle the mutable of final norm
        if self.use_final_norm:
            MutableChannelContainer.register_mutable_channel_to_module(
                self.norm1, self.last_mutable, True)

    def forward(self, x: torch.Tensor):
        """Forward of Autoformer."""
        B = x.shape[0]
        x = self.patch_embed(x)

        embed_dims = int(self.mutable_embed_dims.current_choice) if isinstance(
            self.mutable_embed_dims,
            BaseMutableChannel) else self.embed_dim_range[-1]

        # cls token
        cls_tokens = self.cls_token[..., :embed_dims].expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # pos embed
        x = x + self.pos_embed[..., :embed_dims]
        x = self.drop_after_pos(x)

        # dynamic depth
        x = self.blocks(x)

        if self.use_final_norm:
            x = self.norm1(x)

        return (torch.mean(x[:, 1:], dim=1), )
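For reference, the whole search space above hangs off `arch_setting`; a minimal sketch of such a dict follows. The candidate values are illustrative, not necessarily the shipped defaults, and each list ends with its largest value because the code indexes `[-1]` for the supernet:

# Hypothetical Autoformer-style search space; every list is ordered so that
# its last element is the supernet (maximum) choice, as assumed above.
arch_setting = dict(
    mlp_ratios=[3.0, 3.5, 4.0],
    num_heads=[8, 9, 10],
    depth=[14, 15, 16],
    embed_dims=[528, 576, 624])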
@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .image import SearchableImageClassifier

__all__ = ['SearchableImageClassifier']
@ -0,0 +1,53 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

from mmrazor.registry import MODELS

try:
    from mmcls.models import ImageClassifier
except ImportError:
    from mmrazor.utils import get_placeholder
    ImageClassifier = get_placeholder('mmcls')


@MODELS.register_module()
class SearchableImageClassifier(ImageClassifier):
    """SearchableImageClassifier for sliceable networks.

    Args:
        backbone (dict): The same as ImageClassifier.
        neck (dict, optional): The same as ImageClassifier. Defaults to None.
        head (dict, optional): The same as ImageClassifier. Defaults to None.
        pretrained (dict, optional): The same as ImageClassifier. Defaults to
            None.
        train_cfg (dict, optional): The same as ImageClassifier. Defaults to
            None.
        data_preprocessor (dict, optional): The same as ImageClassifier.
            Defaults to None.
        init_cfg (dict, optional): The same as ImageClassifier. Defaults to
            None.
        connect_head (dict, optional): Maps a head method name to the dotted
            path of a mutable in another component, so that the search space
            of one component can be connected to the next. For example,
            ``{'connect_with_backbone': 'backbone.last_mutable'}`` means that
            :func:`connect_with_backbone` will be called with the backbone's
            ``last_mutable``. Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: Optional[dict] = None,
                 head: Optional[dict] = None,
                 pretrained: Optional[str] = None,
                 train_cfg: Optional[dict] = None,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[dict] = None,
                 connect_head: Optional[dict] = None):
        super().__init__(backbone, neck, head, pretrained, train_cfg,
                         data_preprocessor, init_cfg)

        if self.with_head and connect_head is not None:
            for kh, vh in connect_head.items():
                component, attr = vh.split('.')
                value = getattr(getattr(self, component), attr)
                getattr(self.head, kh)(value)
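A config could then wire the head's input to the backbone's last mutable like this (a sketch; the backbone and head type names are assumptions based on the modules added in this PR):

model = dict(
    type='mmrazor.SearchableImageClassifier',
    backbone=dict(type='mmrazor.AutoformerBackbone'),
    head=dict(
        type='mmrazor.DynamicLinearClsHead',
        num_classes=1000,
        in_channels=624),
    # `connect_with_backbone` is resolved to `self.backbone.last_mutable`
    # and called on the head in `__init__` above.
    connect_head=dict(connect_with_backbone='backbone.last_mutable'))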
@ -1,15 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bricks.dynamic_conv import BigNasConv2d, DynamicConv2d, OFAConv2d
from .bricks.dynamic_linear import DynamicLinear
from .bricks.dynamic_norm import (DynamicBatchNorm1d, DynamicBatchNorm2d,
                                  DynamicBatchNorm3d, SwitchableBatchNorm2d)
from .mixins.dynamic_conv_mixins import DynamicConvMixin
from .mixins.dynamic_mixins import (DynamicBatchNormMixin, DynamicChannelMixin,
                                    DynamicLinearMixin, DynamicMixin)

__all__ = [
    'BigNasConv2d', 'DynamicConv2d', 'OFAConv2d', 'DynamicLinear',
    'DynamicBatchNorm1d', 'DynamicBatchNorm2d', 'DynamicBatchNorm3d',
    'DynamicMixin', 'DynamicChannelMixin', 'DynamicBatchNormMixin',
    'DynamicLinearMixin', 'SwitchableBatchNorm2d', 'DynamicConvMixin'
]
from .bricks import *  # noqa: F401,F403
from .head import *  # noqa: F401,F403
from .mixins import *  # noqa: F401,F403
@ -1 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dynamic_container import DynamicSequential
from .dynamic_conv import BigNasConv2d, DynamicConv2d, OFAConv2d
from .dynamic_embed import DynamicPatchEmbed
from .dynamic_linear import DynamicLinear
from .dynamic_multi_head_attention import DynamicMultiheadAttention
from .dynamic_norm import (DynamicBatchNorm1d, DynamicBatchNorm2d,
                           DynamicBatchNorm3d, DynamicLayerNorm,
                           SwitchableBatchNorm2d)
from .dynamic_relative_position import DynamicRelativePosition2D

__all__ = [
    'BigNasConv2d', 'DynamicConv2d', 'OFAConv2d', 'DynamicLinear',
    'DynamicBatchNorm1d', 'DynamicBatchNorm2d', 'DynamicBatchNorm3d',
    'SwitchableBatchNorm2d', 'DynamicSequential', 'DynamicPatchEmbed',
    'DynamicLayerNorm', 'DynamicRelativePosition2D',
    'DynamicMultiheadAttention'
]
@ -0,0 +1,109 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Iterator, Optional, Set

import torch.nn as nn
from mmengine.model import Sequential
from torch import Tensor
from torch.nn import Module

from mmrazor.models.mutables import DerivedMutable, MutableValue
from mmrazor.models.mutables.base_mutable import BaseMutable
from ..mixins import DynamicMixin


class DynamicSequential(Sequential, DynamicMixin):
    """Dynamic Sequential Container."""
    mutable_attrs: nn.ModuleDict
    accepted_mutable_attrs: Set[str] = {'depth'}

    forward_ignored_module = (MutableValue, DerivedMutable, nn.ModuleDict)

    def __init__(self, *args, init_cfg: Optional[dict] = None):
        super().__init__(*args, init_cfg=init_cfg)

        self.mutable_attrs: Dict[str, BaseMutable] = nn.ModuleDict()

    @property
    def mutable_depth(self):
        """Mutable depth."""
        assert hasattr(self, 'mutable_attrs')
        return self.mutable_attrs['depth']

    def register_mutable_attr(self: Sequential, attr: str,
                              mutable: BaseMutable):
        """Register attribute of mutable."""
        if attr == 'depth':
            self._register_mutable_depth(mutable)
        else:
            raise NotImplementedError

    def _register_mutable_depth(self: Sequential, mutable_depth: MutableValue):
        """Register mutable depth."""
        assert hasattr(self, 'mutable_attrs')
        assert mutable_depth.current_choice is not None
        current_depth = mutable_depth.current_choice
        if current_depth > len(self._modules):
            raise ValueError(f'Expect depth of mutable to be smaller than or '
                             f'equal to {len(self._modules)} as `depth`, '
                             f'but got: {current_depth}.')
        self.mutable_attrs['depth'] = mutable_depth

    @property
    def static_op_factory(self):
        """Corresponding Pytorch OP."""
        return Sequential

    def to_static_op(self: Sequential) -> Sequential:
        """Convert dynamic Sequential to static one."""
        self.check_if_mutables_fixed()

        if self.mutable_depth is None:
            fixed_depth = len(self)
        else:
            fixed_depth = self.get_current_choice(self.mutable_depth)

        modules = []
        passed_module_nums = 0
        for module in self:
            if isinstance(module, self.forward_ignored_module):
                continue
            else:
                passed_module_nums += 1
            if passed_module_nums > fixed_depth:
                break

            modules.append(module)

        return Sequential(*modules)

    def forward(self, x: Tensor) -> Tensor:
        """Forward of Dynamic Sequential."""
        if self.mutable_depth is None:
            # No depth mutable registered: fall back to the static forward.
            return super().forward(x)

        current_depth = self.get_current_choice(self.mutable_depth)
        passed_module_nums = 0
        for module in self.pure_modules():
            passed_module_nums += 1
            if passed_module_nums > current_depth:
                break
            x = module(x)
        return x

    @property
    def pure_module_nums(self) -> int:
        """Number of pure modules."""
        return sum(1 for _ in self.pure_modules())

    def pure_modules(self) -> Iterator[Module]:
        """Modules that actually take part in forward; registered mutables
        (which also live in ``_modules``) would otherwise influence the
        forward of Sequential, so they are skipped."""
        for module in self._modules.values():
            if isinstance(module, self.forward_ignored_module):
                continue
            yield module

    @classmethod
    def convert_from(cls, module: Sequential):
        """Convert a static Sequential to a dynamic one."""
        dynamic_m = cls(module._modules)
        return dynamic_m
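A minimal sketch of how the depth mutation behaves (only names defined in this diff are used; the import path is assumed from this PR's layout):

import torch
import torch.nn as nn

from mmrazor.models.architectures.dynamic_ops.bricks import DynamicSequential
from mmrazor.models.mutables import OneShotMutableValue

# Build a supernet-depth container, then shrink it by mutating `depth`.
blocks = DynamicSequential(*[nn.Linear(8, 8) for _ in range(4)])
mutable_depth = OneShotMutableValue(value_list=[2, 3, 4], default_value=4)
blocks.register_mutable_attr('depth', mutable_depth)

mutable_depth.current_choice = 2
out = blocks(torch.rand(1, 8))  # only the first two linear layers run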
@ -0,0 +1,142 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Dict, Set, Tuple

import torch.nn as nn
import torch.nn.functional as F
from mmcls.models.utils import PatchEmbed
from mmengine import print_log
from torch import Tensor

from mmrazor.models.mutables.base_mutable import BaseMutable
from mmrazor.registry import MODELS
from ..mixins import DynamicChannelMixin


@MODELS.register_module()
class DynamicPatchEmbed(PatchEmbed, DynamicChannelMixin):
    """Dynamic Patch Embedding.

    Note:
        Arguments for ``__init__`` of ``DynamicPatchEmbed`` are exactly the
        same as :obj:`mmcls.models.utils.PatchEmbed`.
    Attributes:
        mutable_attrs (ModuleDict[str, BaseMutable]): Mutable attributes,
            such as `embed_dims`. The keys of the dict must be in
            ``accepted_mutable_attrs``.
    """

    mutable_attrs: nn.ModuleDict
    accepted_mutable_attrs: Set[str] = {'embed_dims'}
    attr_mappings: Dict[str, str] = {
        'in_channels': 'embed_dims',
        'out_channels': 'embed_dims'
    }

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.mutable_attrs: Dict[str, BaseMutable] = nn.ModuleDict()

    @property
    def mutable_embed_dims(self):
        """Mutable embedding dimension."""
        assert hasattr(self, 'mutable_attrs')
        return self.mutable_attrs['embed_dims']

    def register_mutable_attr(self: PatchEmbed, attr: str,
                              mutable: BaseMutable):
        """Register attribute of mutable."""
        self.check_mutable_attr_valid(attr)
        if attr in self.attr_mappings:
            attr_map = self.attr_mappings[attr]
            assert attr_map in self.accepted_mutable_attrs
            if attr_map in self.mutable_attrs:
                print_log(
                    f'{attr_map}({attr}) is already in `mutable_attrs`',
                    level=logging.WARNING)
            else:
                self._register_mutable_attr(attr_map, mutable)
        elif attr in self.accepted_mutable_attrs:
            self._register_mutable_attr(attr, mutable)
        else:
            raise NotImplementedError

    def _register_mutable_attr(self, attr, mutable):
        """Register `embed_dims`."""
        if attr == 'embed_dims':
            self._register_embed_dims(mutable)
        else:
            raise NotImplementedError

    def _register_embed_dims(self: PatchEmbed,
                             mutable_patch_embedding: BaseMutable) -> None:
        """Register mutable embedding dimension."""
        mask_size = mutable_patch_embedding.current_mask.size(0)

        if mask_size != self.embed_dims:
            raise ValueError(
                f'Expect mask size of mutable to be {self.embed_dims} as '
                f'`embed_dims`, but got: {mask_size}.')

        self.mutable_attrs['embed_dims'] = mutable_patch_embedding

    def _get_dynamic_params(self: PatchEmbed) -> Tuple[Tensor, Tensor]:
        """Get the projection weight and bias sliced by the mask of
        ``embed_dims``."""
        if 'embed_dims' not in self.mutable_attrs:
            return self.projection.weight, self.projection.bias
        else:
            out_mask = self.mutable_embed_dims.current_mask.to(
                self.projection.weight.device)
            weight = self.projection.weight[out_mask]
            bias = self.projection.bias[
                out_mask] if self.projection.bias is not None else None
            return weight, bias

    def to_static_op(self: PatchEmbed) -> nn.Module:
        """Convert dynamic PatchEmbed to static PatchEmbed."""
        self.check_if_mutables_fixed()
        assert self.mutable_embed_dims is not None

        weight, bias = self._get_dynamic_params()
        static_patch_embed = self.static_op_factory(
            img_size=self.img_size,
            in_channels=3,
            embed_dims=self.mutable_embed_dims.activated_channels)

        static_patch_embed.projection.weight = nn.Parameter(weight.clone())
        static_patch_embed.projection.bias = nn.Parameter(bias.clone())

        return static_patch_embed

    @property
    def static_op_factory(self):
        """Corresponding Pytorch OP."""
        return PatchEmbed

    @classmethod
    def convert_from(cls, module) -> nn.Module:
        """Convert a PatchEmbed to a DynamicPatchEmbed."""
        dynamic_patch_embed = cls(
            img_size=module.img_size,
            in_channels=3,
            embed_dims=module.embed_dims,
            norm_cfg=None,
            conv_cfg=None,
            init_cfg=None)

        return dynamic_patch_embed

    def forward(self, x: Tensor) -> Tensor:
        """Forward of dynamic patch embed."""
        weight, bias = self._get_dynamic_params()
        # NOTE: the stride is hard-coded to 16, i.e. the patch size used by
        # Autoformer.
        x = F.conv2d(
            x,
            weight,
            bias,
            stride=16,
            padding=self.projection.padding,
            dilation=self.projection.dilation).flatten(2).transpose(1, 2)

        return x
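A minimal sketch: slice the supernet patch embedding down to a narrower width (candidate values are illustrative):

from mmcls.models.utils import PatchEmbed

from mmrazor.models.architectures.dynamic_ops.bricks import DynamicPatchEmbed
from mmrazor.models.mutables import OneShotMutableChannel

patch_embed = DynamicPatchEmbed.convert_from(
    PatchEmbed(img_size=224, embed_dims=624))
mutable = OneShotMutableChannel(
    num_channels=624, candidate_choices=[528, 576, 624])
patch_embed.register_mutable_attr('embed_dims', mutable)

mutable.current_choice = 576  # the projection now emits 576-dim tokens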
@ -0,0 +1,280 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Dict, Set, Tuple

import torch.nn as nn
import torch.nn.functional as F
from mmengine import print_log
from torch import Tensor

from mmrazor.models.architectures.ops import MultiheadAttention
from mmrazor.models.mutables.base_mutable import BaseMutable
from ..mixins import DynamicChannelMixin
from .dynamic_relative_position import DynamicRelativePosition2D  # noqa: E501


class DynamicMultiheadAttention(MultiheadAttention, DynamicChannelMixin):
    """Dynamic Multihead Attention with iRPE.

    Note:
        Arguments for ``__init__`` of ``DynamicMultiheadAttention`` are
        exactly the same as
        :obj:`mmrazor.models.architectures.MultiheadAttention`.
    Attributes:
        mutable_attrs (ModuleDict[str, BaseMutable]): Mutable attributes,
            such as `num_heads`, `embed_dims` and `q_embed_dims`.
            The keys of the dict must be in ``accepted_mutable_attrs``.
    """

    mutable_attrs: nn.ModuleDict
    relative_position: bool
    max_relative_position: int
    w_qs: nn.Linear
    w_ks: nn.Linear
    w_vs: nn.Linear
    embed_dims: int
    q_embed_dims: int
    proj: nn.Linear
    attn_drop_rate: float
    accepted_mutable_attrs: Set[str] = {
        'num_heads', 'embed_dims', 'q_embed_dims'
    }
    attr_mappings: Dict[str, str] = {
        'in_channels': 'embed_dims',
        'out_channels': 'q_embed_dims',
    }

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.mutable_attrs: Dict[str, BaseMutable] = nn.ModuleDict()

        # dynamic image relative position encoding
        if self.relative_position:
            self.rel_pos_embed_k = DynamicRelativePosition2D(
                self.head_dims, self.max_relative_position)
            self.rel_pos_embed_v = DynamicRelativePosition2D(
                self.head_dims, self.max_relative_position)

    @property
    def mutable_num_heads(self):
        """Mutable number of heads."""
        assert hasattr(self, 'mutable_attrs')
        return self.mutable_attrs['num_heads']

    @property
    def mutable_embed_dims(self):
        """Mutable embedding dimension."""
        assert hasattr(self, 'mutable_attrs')
        return self.mutable_attrs['embed_dims']

    @property
    def mutable_q_embed_dims(self):
        """Mutable intermediate embedding dimension."""
        assert hasattr(self, 'mutable_attrs')
        return self.mutable_attrs['q_embed_dims']

    def register_mutable_attr(self, attr: str, mutable: BaseMutable):
        """Register attribute of mutable."""
        self.check_mutable_attr_valid(attr)
        if attr in self.attr_mappings:
            attr_map = self.attr_mappings[attr]
            assert attr_map in self.accepted_mutable_attrs
            if attr_map in self.mutable_attrs:
                print_log(
                    f'{attr_map}({attr}) is already in `mutable_attrs`',
                    level=logging.WARNING)
            else:
                self._register_mutable_attr(attr_map, mutable)
        elif attr in self.accepted_mutable_attrs:
            self._register_mutable_attr(attr, mutable)
        else:
            raise NotImplementedError

    def _register_mutable_attr(self, attr: str, mutable: BaseMutable):
        """Register `embed_dims`, `q_embed_dims` or `num_heads`."""
        if attr == 'num_heads':
            self._register_mutable_num_heads(mutable)
        elif attr == 'embed_dims':
            self._register_mutable_embed_dims(mutable)
        elif attr == 'q_embed_dims':
            self._register_mutable_q_embed_dims(mutable)
        else:
            raise NotImplementedError

    def _register_mutable_num_heads(self, mutable_num_heads):
        """Register the mutable number of heads."""
        assert hasattr(self, 'mutable_attrs')
        current_choice = mutable_num_heads.current_choice
        if current_choice > self.num_heads:
            raise ValueError(
                f'Expect value of mutable to be smaller or equal than '
                f'{self.num_heads} as `num_heads`, but got: {current_choice}.')

        self.mutable_attrs['num_heads'] = mutable_num_heads

    def _register_mutable_embed_dims(self, mutable_embed_dims):
        """Register mutable embedding dimension."""
        assert hasattr(self, 'mutable_attrs')
        mask_size = mutable_embed_dims.current_mask.size(0)
        if mask_size != self.embed_dims:
            raise ValueError(
                f'Expect mask size of mutable to be {self.embed_dims} as '
                f'`embed_dims`, but got: {mask_size}.')

        self.mutable_attrs['embed_dims'] = mutable_embed_dims

    def _register_mutable_q_embed_dims(self, mutable_q_embed_dims):
        """Register intermediate mutable embedding dimension."""
        assert hasattr(self, 'mutable_attrs')
        self.mutable_attrs['q_embed_dims'] = mutable_q_embed_dims

    def _get_dynamic_proj_params(self, w: nn.Linear) -> Tuple[Tensor, Tensor]:
        """Get parameters of dynamic projection.

        Note:
            The input dimension is decided by `mutable_q_embed_dims`.
            The output dimension is decided by `mutable_embed_dims`.
        """
        # TODO: support mask
        if self.mutable_embed_dims is None and \
                self.mutable_q_embed_dims is None:
            return w.weight, w.bias

        if self.mutable_q_embed_dims is not None:
            in_features = self.mutable_q_embed_dims.activated_channels
        else:
            in_features = self.embed_dims

        if self.mutable_embed_dims is not None:
            out_features = self.mutable_embed_dims.activated_channels
        else:
            out_features = self.embed_dims

        weight = w.weight[:out_features, :in_features]
        bias = w.bias[:out_features] if w.bias is not None else None

        return weight, bias

    def _get_dynamic_qkv_params(self, w: nn.Linear) -> Tuple[Tensor, Tensor]:
        """Get parameters of dynamic QKV.

        Note:
            The output dimension is decided by `mutable_q_embed_dims`.
            The input dimension is decided by `mutable_embed_dims`.
        """
        # TODO: support mask later
        if self.mutable_q_embed_dims is None and \
                self.mutable_embed_dims is None:
            return w.weight, w.bias

        if self.mutable_embed_dims is not None:
            in_features = self.mutable_embed_dims.activated_channels
        else:
            in_features = self.embed_dims

        if self.mutable_q_embed_dims is not None:
            out_features = self.mutable_q_embed_dims.activated_channels
        else:
            # fall back to the static intermediate dimension
            out_features = self.q_embed_dims

        weight = w.weight[:out_features, :in_features]
        bias = w.bias[:out_features] if w.bias is not None else None

        return weight, bias

    def to_static_op(self) -> MultiheadAttention:
        """Convert dynamic MultiheadAttention to static one."""
        self.check_if_mutables_fixed()

        embed_dims = self.mutable_embed_dims.activated_channels
        num_heads = self.mutable_num_heads.current_choice

        q_w, q_b = self._get_dynamic_qkv_params(self.w_qs)
        k_w, k_b = self._get_dynamic_qkv_params(self.w_ks)
        v_w, v_b = self._get_dynamic_qkv_params(self.w_vs)

        proj_w, proj_b = self._get_dynamic_proj_params(self.proj)

        static_mha = MultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            input_dims=None,
            attn_drop_rate=self.attn_drop_rate,
            relative_position=self.relative_position,
            max_relative_position=self.max_relative_position)

        static_mha.w_qs.weight = nn.Parameter(q_w.clone())
        static_mha.w_qs.bias = nn.Parameter(q_b.clone())

        static_mha.w_ks.weight = nn.Parameter(k_w.clone())
        static_mha.w_ks.bias = nn.Parameter(k_b.clone())

        static_mha.w_vs.weight = nn.Parameter(v_w.clone())
        static_mha.w_vs.bias = nn.Parameter(v_b.clone())

        static_mha.proj.weight = nn.Parameter(proj_w.clone())
        static_mha.proj.bias = nn.Parameter(proj_b.clone())

        if self.relative_position:
            static_mha.rel_pos_embed_k = self.rel_pos_embed_k.to_static_op()
            static_mha.rel_pos_embed_v = self.rel_pos_embed_v.to_static_op()

        return static_mha

    @classmethod
    def convert_from(cls, module):
        """Convert a static module to a dynamic one."""
        dynamic_mha = cls(
            embed_dims=module.embed_dims,
            num_heads=module.num_heads,
        )
        return dynamic_mha

    @property
    def static_op_factory(self):
        """Corresponding Pytorch OP."""
        return MultiheadAttention

    def forward(self, x: Tensor) -> Tensor:
        """Forward of dynamic MultiheadAttention."""
        B, N = x.shape[0], x.shape[1]
        q_w, q_b = self._get_dynamic_qkv_params(self.w_qs)
        k_w, k_b = self._get_dynamic_qkv_params(self.w_ks)
        v_w, v_b = self._get_dynamic_qkv_params(self.w_vs)

        q_embed_dims = self.mutable_q_embed_dims.activated_channels
        num_heads = self.mutable_num_heads.current_choice

        q = F.linear(x, q_w, q_b).view(B, N, num_heads,
                                       q_embed_dims // num_heads)
        k = F.linear(x, k_w, k_b).view(B, N, num_heads,
                                       q_embed_dims // num_heads)
        v = F.linear(x, v_w, v_b).view(B, N, num_heads,
                                       q_embed_dims // num_heads)

        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale

        if self.relative_position:
            r_p_k = self.rel_pos_embed_k(N, N)
            attn = attn + (q.permute(2, 0, 1, 3).reshape(N, num_heads * B, -1)  # noqa: E501
                           @ r_p_k.transpose(2, 1)) \
                .transpose(1, 0).reshape(B, num_heads, N, N) * self.scale

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)

        if self.relative_position:
            r_p_v = self.rel_pos_embed_v(N, N)
            attn_1 = attn.permute(2, 0, 1, 3).reshape(N, B * num_heads, -1)
            x = x + (attn_1 @ r_p_v).transpose(1, 0).reshape(
                B, num_heads, N, -1).transpose(2, 1).reshape(B, N, -1)

        # proj
        weight, bias = self._get_dynamic_proj_params(self.proj)
        x = F.linear(x, weight, bias)
        x = self.proj_drop(x)
        return x
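As a sketch, a static attention can be made dynamic and later sliced back down (the mutable registration and fixing steps are elided here):

from mmrazor.models.architectures.ops import MultiheadAttention
from mmrazor.models.architectures.dynamic_ops.bricks import \
    DynamicMultiheadAttention

static_mha = MultiheadAttention(embed_dims=624, num_heads=10)
dyn_mha = DynamicMultiheadAttention.convert_from(static_mha)
# ...register `embed_dims`/`num_heads`/`q_embed_dims` mutables, fix the
# chosen subnet, then slice the weights back out:
# static_small = dyn_mha.to_static_op()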
@ -4,11 +4,12 @@ from typing import Dict, List, Optional
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import LayerNorm
from torch.nn.modules.batchnorm import _BatchNorm

from mmrazor.models.mutables.base_mutable import BaseMutable
from mmrazor.registry import MODELS
from ..mixins.dynamic_mixins import DynamicBatchNormMixin
from ..mixins import DynamicBatchNormMixin, DynamicLayerNormMixin


class _DynamicBatchNorm(_BatchNorm, DynamicBatchNormMixin):
@ -91,6 +92,7 @@ class DynamicBatchNorm1d(_DynamicBatchNorm):
    @property
    def static_op_factory(self):
        """Corresponding Pytorch OP."""
        return nn.BatchNorm1d

    def _check_input_dim(self, input: Tensor) -> None:
@ -106,6 +108,7 @@ class DynamicBatchNorm2d(_DynamicBatchNorm):
    @property
    def static_op_factory(self):
        """Corresponding Pytorch OP."""
        return nn.BatchNorm2d

    def _check_input_dim(self, input: Tensor) -> None:
@ -121,6 +124,7 @@ class DynamicBatchNorm3d(_DynamicBatchNorm):
    @property
    def static_op_factory(self):
        """Corresponding Pytorch OP."""
        return nn.BatchNorm3d

    def _check_input_dim(self, input: Tensor) -> None:
@ -190,3 +194,61 @@ class SwitchableBatchNorm2d(DynamicBatchNorm2d):
    def static_op_factory(self):
        """Return initializer of static op."""
        return nn.BatchNorm2d


@MODELS.register_module()
class DynamicLayerNorm(LayerNorm, DynamicLayerNormMixin):
    """Applies Layer Normalization over a mini-batch of inputs, slicing its
    parameters dynamically according to ``mutable_num_features``.

    Note:
        Arguments for ``__init__`` of ``DynamicLayerNorm`` are exactly the
        same as :obj:`torch.nn.LayerNorm`.
    Attributes:
        mutable_attrs (ModuleDict[str, BaseMutable]): Mutable attributes,
            such as `num_features`. The keys of the dict must be in
            ``accepted_mutable_attrs``.
    """
    accepted_mutable_attrs = {'num_features'}

    def __init__(self, *args, **kwargs):
        super(DynamicLayerNorm, self).__init__(*args, **kwargs)

        self.mutable_attrs: Dict[str, Optional[BaseMutable]] = nn.ModuleDict()

    @property
    def static_op_factory(self):
        """Corresponding Pytorch OP."""
        return LayerNorm

    @classmethod
    def convert_from(cls, module: LayerNorm):
        """Convert a LayerNorm module to a DynamicLayerNorm.

        Args:
            module (:obj:`torch.nn.LayerNorm`): The original LayerNorm
                module.
        """
        dynamic_ln = cls(
            normalized_shape=module.normalized_shape,
            eps=module.eps,
            elementwise_affine=module.elementwise_affine)

        return dynamic_ln

    def forward(self, input: Tensor) -> Tensor:
        """Slice the parameters according to `mutable_num_features`, and
        forward."""
        self._check_input_dim(input)

        weight, bias = self.get_dynamic_params()
        self.normalized_shape = (
            self.mutable_num_features.activated_channels, )

        return F.layer_norm(input, self.normalized_shape, weight, bias,
                            self.eps)

    def _check_input_dim(self, input: Tensor) -> None:
        """Check if input dimension is valid."""
        if input.dim() != 3:
            raise ValueError('expected 3D input (got {}D input)'.format(
                input.dim()))
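A minimal sketch of a LayerNorm whose normalized width can shrink (widths are illustrative):

import torch
from torch.nn import LayerNorm

from mmrazor.models.architectures.dynamic_ops.bricks import DynamicLayerNorm
from mmrazor.models.mutables import OneShotMutableChannel

dyn_ln = DynamicLayerNorm.convert_from(LayerNorm(624))
mutable = OneShotMutableChannel(
    num_channels=624, candidate_choices=[528, 576, 624])
dyn_ln.register_mutable_attr('num_features', mutable)

mutable.current_choice = 528
out = dyn_ln(torch.rand(1, 197, 528))  # normalizes over 528 channels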
@ -0,0 +1,154 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Dict, Set

import torch
from mmengine import print_log
from torch import Tensor, nn

from mmrazor.models.architectures.ops import RelativePosition2D
from mmrazor.models.mutables.base_mutable import BaseMutable
from ..mixins import DynamicChannelMixin


class DynamicRelativePosition2D(RelativePosition2D, DynamicChannelMixin):
    """Searchable RelativePosition module.

    Note:
        Arguments for ``__init__`` of ``DynamicRelativePosition2D`` are
        exactly the same as
        :obj:`mmrazor.models.architectures.RelativePosition2D`.
    Attributes:
        mutable_attrs (ModuleDict[str, BaseMutable]): Mutable attributes,
            such as `head_dims`. The keys of the dict must be in
            ``accepted_mutable_attrs``.
    """

    mutable_attrs: nn.ModuleDict
    head_dims: int
    max_relative_position: int
    embeddings_table_v: nn.Parameter
    embeddings_table_h: nn.Parameter
    accepted_mutable_attrs: Set[str] = {'head_dims'}
    attr_mappings: Dict[str, str] = {
        'in_channels': 'head_dims',
        'out_channels': 'head_dims',
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.mutable_attrs: Dict[str, BaseMutable] = nn.ModuleDict()

    @property
    def mutable_head_dims(self):
        """Mutable head dimension."""
        assert hasattr(self, 'mutable_attrs')
        return self.mutable_attrs['head_dims']

    def register_mutable_attr(self, attr: str, mutable: BaseMutable):
        """Register attribute of mutable."""
        self.check_mutable_attr_valid(attr)
        if attr in self.attr_mappings:
            attr_map = self.attr_mappings[attr]
            assert attr_map in self.accepted_mutable_attrs
            if attr_map in self.mutable_attrs:
                print_log(
                    f'{attr_map}({attr}) is already in `mutable_attrs`',
                    level=logging.WARNING)
            else:
                self._register_mutable_attr(attr_map, mutable)
        elif attr in self.accepted_mutable_attrs:
            self._register_mutable_attr(attr, mutable)
        else:
            raise NotImplementedError

    def _register_mutable_attr(self, attr, mutable):
        """Register `head_dims`."""
        if attr == 'head_dims':
            self._register_mutable_head_dims(mutable)
        else:
            raise NotImplementedError

    def _register_mutable_head_dims(self,
                                    mutable_head_dims: BaseMutable) -> None:
        """Register mutable head dimension."""
        assert hasattr(self, 'mutable_attrs')
        self.mutable_attrs['head_dims'] = mutable_head_dims

    def to_static_op(self) -> nn.Module:
        """Convert dynamic RelativePosition2D to a static one."""
        self.check_if_mutables_fixed()
        assert self.mutable_head_dims is not None

        self.current_head_dim = self.mutable_head_dims.activated_channels
        static_relative_position = self.static_op_factory(
            self.current_head_dim)
        static_relative_position.embeddings_table_v = \
            nn.Parameter(
                self.embeddings_table_v[:, :self.current_head_dim].clone())
        static_relative_position.embeddings_table_h = \
            nn.Parameter(
                self.embeddings_table_h[:, :self.current_head_dim].clone())

        return static_relative_position

    @property
    def static_op_factory(self):
        """Corresponding Pytorch OP."""
        return RelativePosition2D

    @classmethod
    def convert_from(cls, module):
        """Convert a RelativePosition2D to a dynamic one."""
        dynamic_rp = cls(
            head_dims=module.head_dims,
            max_relative_position=module.max_relative_position)
        return dynamic_rp

    def forward(self, length_q, length_k) -> Tensor:
        """Forward of Dynamic Relative Position."""
        if self.mutable_head_dims is None:
            self.current_head_dim = self.head_dims
        else:
            self.current_head_dim = self.mutable_head_dims.activated_channels

        self.sample_eb_table_h = \
            self.embeddings_table_h[:, :self.current_head_dim]
        self.sample_eb_table_v = \
            self.embeddings_table_v[:, :self.current_head_dim]

        # remove the first cls token distance computation
        length_q = length_q - 1
        length_k = length_k - 1
        range_vec_q = torch.arange(length_q)
        range_vec_k = torch.arange(length_k)
        # compute the row and column distance
        distance_mat_v = (
            range_vec_k[None, :] // int(length_q**0.5) -
            range_vec_q[:, None] // int(length_q**0.5))
        distance_mat_h = (
            range_vec_k[None, :] % int(length_q**0.5) -
            range_vec_q[:, None] % int(length_q**0.5))
        distance_mat_clipped_v = torch.clamp(distance_mat_v,
                                             -self.max_relative_position,
                                             self.max_relative_position)
        distance_mat_clipped_h = torch.clamp(distance_mat_h,
                                             -self.max_relative_position,
                                             self.max_relative_position)

        final_mat_v = distance_mat_clipped_v + self.max_relative_position + 1
        final_mat_h = distance_mat_clipped_h + self.max_relative_position + 1
        # pad the 0 which represents the cls token
        final_mat_v = torch.nn.functional.pad(final_mat_v, (1, 0, 1, 0),
                                              'constant', 0)
        final_mat_h = torch.nn.functional.pad(final_mat_h, (1, 0, 1, 0),
                                              'constant', 0)

        final_mat_v = torch.LongTensor(final_mat_v)
        final_mat_h = torch.LongTensor(final_mat_h)
        # get the embeddings with the corresponding distance
        embeddings = self.sample_eb_table_v[final_mat_v] + \
            self.sample_eb_table_h[final_mat_h]

        return embeddings
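A quick shape check of the dynamic iRPE table (197 = 196 patch tokens + 1 cls token for a 224/16 ViT):

import torch

from mmrazor.models.architectures.dynamic_ops.bricks import \
    DynamicRelativePosition2D
from mmrazor.models.mutables import OneShotMutableChannel

rpe = DynamicRelativePosition2D(head_dims=64, max_relative_position=14)
mutable = OneShotMutableChannel(num_channels=64, candidate_choices=[64])
rpe.register_mutable_attr('head_dims', mutable)

emb = rpe(197, 197)
assert emb.shape == (197, 197, 64)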
@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dynamic_linear_head import DynamicLinearClsHead  # noqa: F401

__all__ = ['DynamicLinearClsHead']
@ -0,0 +1,80 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import Optional, Tuple

import torch
from mmcls.models import ClsHead

from mmrazor.models.mutables.base_mutable import BaseMutable
from mmrazor.models.mutables.mutable_channel import MutableChannelContainer
from mmrazor.models.mutables.mutable_channel.units import \
    OneShotMutableChannelUnit
from mmrazor.registry import MODELS
from ..bricks.dynamic_linear import DynamicLinear


class DynamicHead:

    @abstractmethod
    def connect_with_backbone(self,
                              backbone_output_mutable: BaseMutable) -> None:
        """Connect with a dynamic backbone."""
        ...


@MODELS.register_module()
class DynamicLinearClsHead(ClsHead, DynamicHead):
    """Dynamic linear classification head for Autoformer.

    Args:
        num_classes (int): Number of classes.
        in_channels (int): Number of input channels.
        init_cfg (Optional[dict], optional): Init config.
            Defaults to dict(type='Normal',
            layer='DynamicLinear', std=0.01).
    """

    def __init__(self,
                 num_classes: int = 1000,
                 in_channels: int = 624,
                 init_cfg: Optional[dict] = dict(
                     type='Normal', layer='DynamicLinear', std=0.01),
                 **kwargs):
        super().__init__(init_cfg=init_cfg, **kwargs)

        self.in_channels = in_channels
        self.num_classes = num_classes

        if self.num_classes <= 0:
            raise ValueError(
                f'num_classes={num_classes} must be a positive integer')

        self.fc = DynamicLinear(self.in_channels, self.num_classes)

    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The process before the final classification head.

        The input ``feats`` is a tuple of tensors, and each tensor is the
        feature of a backbone stage. In ``DynamicLinearClsHead``, we just
        obtain the feature of the last stage.
        """
        # The DynamicLinearClsHead doesn't have other modules, just return
        # after unpacking.
        return feats[-1]

    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The forward process."""
        pre_logits = self.pre_logits(feats)
        # The final classification head.
        cls_score = self.fc(pre_logits)
        return cls_score

    def connect_with_backbone(self,
                              backbone_output_mutable: BaseMutable) -> None:
        """Connect the head's input channels with the backbone's output
        mutable."""
        OneShotMutableChannelUnit._register_channel_container(
            self, MutableChannelContainer)

        MutableChannelContainer.register_mutable_channel_to_module(
            self.fc, backbone_output_mutable, False)
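In isolation, the connection amounts to the following (a sketch; `backbone` stands for any dynamic backbone exposing `last_mutable`):

head = DynamicLinearClsHead(num_classes=1000, in_channels=624)
# This is exactly what `SearchableImageClassifier`'s `connect_head` config
# triggers: the head's fc input now follows the backbone's width choice.
head.connect_with_backbone(backbone.last_mutable)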
@ -1,9 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dynamic_conv_mixins import DynamicConvMixin
from .dynamic_layernorm_mixins import DynamicLayerNormMixin
from .dynamic_mixins import (DynamicBatchNormMixin, DynamicChannelMixin,
                             DynamicLinearMixin, DynamicMixin)

__all__ = [
    'DynamicChannelMixin', 'DynamicBatchNormMixin', 'DynamicLinearMixin',
    'DynamicMixin', 'DynamicConvMixin'
    'DynamicChannelMixin',
    'DynamicBatchNormMixin',
    'DynamicLinearMixin',
    'DynamicMixin',
    'DynamicConvMixin',
    'DynamicLayerNormMixin',
]
@ -0,0 +1,147 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Dict, Optional, Set, Tuple

import torch
from mmengine import print_log
from torch import Tensor, nn
from torch.nn import LayerNorm

from mmrazor.models.mutables.base_mutable import BaseMutable
from .dynamic_mixins import DynamicChannelMixin


class DynamicLayerNormMixin(DynamicChannelMixin):
    """A mixin class for Pytorch LayerNorm, which can mutate
    ``num_features``."""
    accepted_mutable_attrs: Set[str] = {'num_features'}

    attr_mappings: Dict[str, str] = {
        'in_channels': 'num_features',
        'out_channels': 'num_features',
    }

    @property
    def num_features(self):
        """Number of features, i.e. ``normalized_shape[0]``."""
        return getattr(self, 'normalized_shape')[0]

    @property
    def mutable_num_features(self):
        """Mutable number of features."""
        assert hasattr(self, 'mutable_attrs')
        return self.mutable_attrs['num_features']

    def register_mutable_attr(self, attr, mutable):
        """Register attribute of mutable."""
        self.check_mutable_attr_valid(attr)
        if attr in self.attr_mappings:
            attr_map = self.attr_mappings[attr]
            assert attr_map in self.accepted_mutable_attrs
            if attr_map in self.mutable_attrs:
                print_log(
                    f'{attr_map}({attr}) is already in `mutable_attrs`',
                    level=logging.WARNING)
            else:
                self._register_mutable_attr(attr_map, mutable)
        elif attr in self.accepted_mutable_attrs:
            self._register_mutable_attr(attr, mutable)
        else:
            raise NotImplementedError

    def _register_mutable_attr(self, attr, mutable):
        """Register `num_features`."""
        if attr == 'num_features':
            self._register_mutable_num_features(mutable)
        else:
            raise NotImplementedError

    def _register_mutable_num_features(
            self: LayerNorm, mutable_num_features: BaseMutable) -> None:
        """Mutate ``num_features`` with the given mutable.

        Args:
            mutable_num_features (BaseMutable): Mutable for controlling
                ``num_features``.
        Raises:
            RuntimeError: Error if ``elementwise_affine`` is False.
            ValueError: Error if the size of the mask is not the same as
                ``num_features``.
        """
        if not self.elementwise_affine:
            raise RuntimeError(
                'num_features can not be mutated if '
                '`elementwise_affine` is False')

        self.check_mutable_channels(mutable_num_features)
        mask_size = mutable_num_features.current_mask.size(0)

        # normalized_shape is a tuple
        if mask_size != self.normalized_shape[0]:
            raise ValueError(
                f'Expect mask size of mutable to be {self.normalized_shape}'
                f' as `normalized_shape`, but got: {mask_size}.')

        self.mutable_attrs['num_features'] = mutable_num_features

    def _get_num_features_mask(self: LayerNorm) -> Optional[torch.Tensor]:
        """Get mask of ``num_features``."""
        if self.elementwise_affine:
            refer_tensor = self.weight
        else:
            return None

        if 'num_features' in self.mutable_attrs:
            out_mask = self.mutable_num_features.current_mask.to(
                refer_tensor.device)
        else:
            out_mask = torch.ones_like(refer_tensor).bool()

        return out_mask

    def get_dynamic_params(
            self: LayerNorm) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        """Get dynamic parameters that will be used in the forward process.

        Returns:
            Tuple[Optional[Tensor], Optional[Tensor]]: Sliced weight and
                bias.
        """
        out_mask = self._get_num_features_mask()

        if self.elementwise_affine:
            weight = self.weight[out_mask]
            bias = self.bias[out_mask]
        else:
            weight, bias = self.weight, self.bias

        return weight, bias

    def to_static_op(self: LayerNorm) -> nn.Module:
        """Convert dynamic LayerNorm to :obj:`torch.nn.LayerNorm`.

        Returns:
            torch.nn.LayerNorm: :obj:`torch.nn.LayerNorm` with sliced
                parameters.
        """
        self.check_if_mutables_fixed()

        weight, bias = self.get_dynamic_params()

        if 'num_features' in self.mutable_attrs:
            num_features = self.mutable_attrs['num_features'].current_mask.sum(
            ).item()
        else:
            num_features = self.num_features

        static_ln = self.static_op_factory(
            normalized_shape=num_features,
            eps=self.eps,
            elementwise_affine=self.elementwise_affine)

        if weight is not None:
            static_ln.weight = nn.Parameter(weight.clone())
        if bias is not None:
            static_ln.bias = nn.Parameter(bias.clone())

        return static_ln
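Continuing the DynamicLayerNorm sketch from earlier, once the choice is frozen the mixin's `to_static_op` slices the affine parameters back into a plain `nn.LayerNorm` (the exact `fix_chosen` call is an assumption about the mutable API):

mutable.fix_chosen(528)            # freeze the search decision
static_ln = dyn_ln.to_static_op()  # plain nn.LayerNorm over 528 features
assert static_ln.normalized_shape == (528, )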
@ -74,12 +74,16 @@ class DynamicMixin(ABC):
        Raises:
            RuntimeError: Error if an existing mutable is not fixed.
        """
        from mmrazor.models.mutables import (DerivedMutable,
                                             MutableChannelContainer)

        def check_fixed(mutable: Optional[BaseMutable]) -> None:
            if mutable is not None and not mutable.is_fixed:
                raise RuntimeError(f'Mutable {type(mutable)} is not fixed.')

        for mutable in self.mutable_attrs.values():  # type: ignore
            if isinstance(mutable, (MutableChannelContainer, DerivedMutable)):
                continue
            check_fixed(mutable)

    def check_mutable_attr_valid(self, attr):
@ -115,6 +119,11 @@ class DynamicChannelMixin(DynamicMixin):
    ``mutable_out_channels`` APIs.
    """

    attr_mappings: Dict[str, str] = {
        'in_channels': 'in_channels',
        'out_channels': 'out_channels',
    }

    @staticmethod
    def check_mutable_channels(mutable_channels: BaseMutable) -> None:
        """Check if the mutable has a `current_mask` attribute.
@ -6,9 +6,11 @@ from .efficientnet_series import ConvBnAct, DepthwiseSeparableConv
from .gather_tensors import GatherTensors
from .mobilenet_series import MBBlock
from .shufflenet_series import ShuffleBlock, ShuffleXception
from .transformer_series import MultiheadAttention, RelativePosition2D

__all__ = [
    'ShuffleBlock', 'ShuffleXception', 'DartsPoolBN', 'DartsDilConv',
    'DartsSepConv', 'DartsSkipConnect', 'DartsZero', 'MBBlock', 'Identity',
    'ConvBnAct', 'DepthwiseSeparableConv', 'GatherTensors'
    'ConvBnAct', 'DepthwiseSeparableConv', 'GatherTensors',
    'RelativePosition2D', 'MultiheadAttention'
]
@ -0,0 +1,192 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional

import torch
import torch.nn as nn
from mmengine.model.weight_init import trunc_normal_


class RelativePosition2D(nn.Module):
    """Rethinking and Improving Relative Position Encoding for Vision
    Transformer.

    ICCV 2021. https://arxiv.org/pdf/2107.14222.pdf
    Image RPE (iRPE for short) methods are new relative position encoding
    methods dedicated to 2D images.
    Args:
        head_dims (int): embedding dims of relative position.
        max_relative_position (int): The max relative position distance.
    """

    def __init__(self, head_dims: int, max_relative_position: int = 14):
        super().__init__()

        self.head_dims = head_dims
        self.max_relative_position = max_relative_position
        # The first element in embeddings_table_v is the vertical embedding
        # for the class
        self.embeddings_table_v = nn.Parameter(
            torch.randn(max_relative_position * 2 + 2, head_dims))
        self.embeddings_table_h = nn.Parameter(
            torch.randn(max_relative_position * 2 + 2, head_dims))

        trunc_normal_(self.embeddings_table_v, std=.02)
        trunc_normal_(self.embeddings_table_h, std=.02)

    def forward(self, length_q, length_k):
        # remove the first cls token distance computation
        length_q = length_q - 1
        length_k = length_k - 1
        range_vec_q = torch.arange(length_q)
        range_vec_k = torch.arange(length_k)
        # compute the row and column distance
        distance_mat_v = (
            range_vec_k[None, :] // int(length_q**0.5) -
            range_vec_q[:, None] // int(length_q**0.5))
        distance_mat_h = (
            range_vec_k[None, :] % int(length_q**0.5) -
            range_vec_q[:, None] % int(length_q**0.5))
        # clip the distance to the range of
        # [-max_relative_position, max_relative_position]
        distance_mat_clipped_v = torch.clamp(distance_mat_v,
                                             -self.max_relative_position,
                                             self.max_relative_position)
        distance_mat_clipped_h = torch.clamp(distance_mat_h,
                                             -self.max_relative_position,
                                             self.max_relative_position)

        # translate the distance into [1, 2 * max_relative_position + 1];
        # 0 is reserved for the cls token
        final_mat_v = distance_mat_clipped_v + self.max_relative_position + 1
        final_mat_h = distance_mat_clipped_h + self.max_relative_position + 1
        # pad the 0 which represents the cls token
        final_mat_v = torch.nn.functional.pad(final_mat_v, (1, 0, 1, 0),
                                              'constant', 0)
        final_mat_h = torch.nn.functional.pad(final_mat_h, (1, 0, 1, 0),
                                              'constant', 0)

        final_mat_v = torch.LongTensor(final_mat_v)
        final_mat_h = torch.LongTensor(final_mat_h)
        # get the embeddings with the corresponding distance
        embeddings = self.embeddings_table_v[
            final_mat_v] + self.embeddings_table_h[final_mat_h]

        return embeddings


class MultiheadAttention(nn.Module):
    """Multi-head Attention Module with iRPE.

    This module implements multi-head attention that supports different input
    dims and embed dims. And it also supports a shortcut from ``value``, which
    is useful if input dims is not the same with embed dims.
    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        input_dims (int, optional): The input dimension, and if None,
            use ``embed_dims``. Defaults to None.
        attn_drop_rate (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
        dropout_layer (dict): The dropout config before adding the shortcut.
            Defaults to ``dict(type='Dropout', drop_prob=0.)``.
        relative_position (bool, optional): Whether to use relative position.
            Defaults to True.
        max_relative_position (int): The max relative position distance.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        proj_bias (bool): If True, add a learnable bias to output projection.
            Defaults to True.
        v_shortcut (bool): Add a shortcut from value to output. It's usually
            used if ``input_dims`` is different from ``embed_dims``.
            Defaults to False.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims: int,
                 num_heads: int,
                 input_dims: Optional[int] = None,
                 attn_drop_rate: float = 0.,
                 proj_drop: float = 0.,
                 dropout_layer: Dict = dict(type='Dropout', drop_prob=0.),
                 relative_position: Optional[bool] = True,
                 max_relative_position: int = 14,
                 qkv_bias: bool = True,
                 qk_scale: Optional[float] = None,
                 proj_bias: bool = True,
                 v_shortcut: bool = False,
                 init_cfg: Optional[dict] = None):
        super().__init__()

        self.input_dims = input_dims or embed_dims
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.v_shortcut = v_shortcut
        self.relative_position = relative_position
        self.max_relative_position = max_relative_position
        self.attn_drop_rate = attn_drop_rate

        self.head_dims = 64  # unit
        self.scale = qk_scale or self.head_dims**-0.5

        self.q_embed_dims = num_heads * self.head_dims

        self.w_qs = nn.Linear(
            self.input_dims, num_heads * self.head_dims, bias=qkv_bias)
        self.w_ks = nn.Linear(
            self.input_dims, num_heads * self.head_dims, bias=qkv_bias)
        self.w_vs = nn.Linear(
            self.input_dims, num_heads * self.head_dims, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop_rate)
        self.proj = nn.Linear(
            num_heads * self.head_dims, embed_dims, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.out_drop = nn.Dropout(dropout_layer['drop_prob'])

        # image relative position encoding
        if self.relative_position:
            self.rel_pos_embed_k = RelativePosition2D(
                self.head_dims, self.max_relative_position)
            self.rel_pos_embed_v = RelativePosition2D(
                self.head_dims, self.max_relative_position)

    def forward(self, x):
        B, N, _ = x.shape

        q = self.w_qs(x).view(B, N, self.num_heads, self.head_dims)
        k = self.w_ks(x).view(B, N, self.num_heads, self.head_dims)
        v = self.w_vs(x).view(B, N, self.num_heads, self.head_dims)

        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale

        if self.relative_position:
            r_p_k = self.rel_pos_embed_k(N, N)
            attn = attn + (q.permute(2, 0, 1, 3).reshape(N, self.num_heads * B, -1)  # noqa: E501
                           @ r_p_k.transpose(2, 1)) \
                .transpose(1, 0).reshape(B, self.num_heads, N, N) * self.scale

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)

        if self.relative_position:
            r_p_v = self.rel_pos_embed_v(N, N)
            t_attn = attn.permute(2, 0, 1, 3).reshape(N, B * self.num_heads,
                                                      -1)
            x = x + (t_attn @ r_p_v).transpose(1, 0).reshape(
                B, self.num_heads, N, -1).transpose(2, 1).reshape(B, N, -1)

        x = self.proj(x)
        x = self.out_drop(self.proj_drop(x))

        if self.v_shortcut:
            x = v.squeeze(1) + x
        return x
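A quick smoke test of the static attention with iRPE (shape-level only; note the QKV width is `num_heads * 64 = 640` while the output is projected back to `embed_dims`):

import torch

from mmrazor.models.architectures.ops import MultiheadAttention

attn = MultiheadAttention(embed_dims=624, num_heads=10)
tokens = torch.rand(2, 197, 624)  # (batch, 196 patches + cls, embed_dims)
out = attn(tokens)
assert out.shape == (2, 197, 624)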
@ -61,8 +61,15 @@ def _expand_mask_fn(
    def fn():
        mask = mutable.current_mask
        expand_num_channels = int(mask.size(0) * expand_ratio)
        expand_choice = int(mutable.current_choice * expand_ratio)
        if isinstance(expand_ratio, int):
            expand_num_channels = mask.size(0) * expand_ratio
            expand_choice = mutable.current_choice * expand_ratio
        elif isinstance(expand_ratio, float):
            expand_num_channels = int(mask.size(0) * expand_ratio)
            expand_choice = int(mutable.current_choice * expand_ratio)
        else:
            raise NotImplementedError(
                f'Not support type of expand_ratio: {type(expand_ratio)}')
        expand_mask = torch.zeros(expand_num_channels).bool()
        expand_mask[:expand_choice] = True
@@ -136,25 +143,62 @@ class DerivedMethodMixin:

     def derive_expand_mutable(
             self: MutableProtocol,
-            expand_ratio: Union[int, float]) -> 'DerivedMutable':
+            expand_ratio: Union[int, BaseMutable, float]) -> 'DerivedMutable':
         """Derive expand mutable, usually used with `expand_ratio`."""
-        choice_fn = _expand_choice_fn(self, expand_ratio=expand_ratio)
+        # avoid circular import
+        if isinstance(expand_ratio, int):
+            choice_fn = _expand_choice_fn(self, expand_ratio=expand_ratio)
+        elif isinstance(expand_ratio, float):
+            choice_fn = _expand_choice_fn(self, expand_ratio=expand_ratio)
+        elif isinstance(expand_ratio, BaseMutable):
+            current_ratio = expand_ratio.current_choice
+            choice_fn = _expand_choice_fn(self, expand_ratio=current_ratio)
+        else:
+            raise NotImplementedError(
                f'Not support type of ratio: {type(expand_ratio)}')

         mask_fn: Optional[Callable] = None
         if hasattr(self, 'current_mask'):
-            mask_fn = _expand_mask_fn(self, expand_ratio=expand_ratio)
+            if isinstance(expand_ratio, int):
+                mask_fn = _expand_mask_fn(self, expand_ratio=expand_ratio)
+            elif isinstance(expand_ratio, float):
+                mask_fn = _expand_mask_fn(self, expand_ratio=expand_ratio)
+            elif isinstance(expand_ratio, BaseMutable):
+                mask_fn = _expand_mask_fn(self, expand_ratio=current_ratio)
+            else:
+                raise NotImplementedError(
+                    f'Not support type of ratio: {type(expand_ratio)}')

         return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn)

     def derive_divide_mutable(self: MutableProtocol,
-                              ratio: int,
+                              ratio: Union[int, float, BaseMutable],
                               divisor: int = 8) -> 'DerivedMutable':
         """Derive divide mutable, usually used with `make_divisible`."""
-        choice_fn = _divide_choice_fn(self, ratio=ratio, divisor=divisor)
+        from .mutable_channel import BaseMutableChannel
+
+        # avoid circular import
+        if isinstance(ratio, int):
+            choice_fn = _divide_choice_fn(self, ratio=ratio, divisor=divisor)
+            current_ratio = ratio
+        elif isinstance(ratio, float):
+            current_ratio = int(ratio)
+            choice_fn = _divide_choice_fn(self, ratio=current_ratio, divisor=1)
+        elif isinstance(ratio, BaseMutable):
+            current_ratio = int(ratio.current_choice)
+            choice_fn = _divide_choice_fn(self, ratio=current_ratio, divisor=1)
+        else:
+            raise NotImplementedError(
+                f'Not support type of ratio: {type(ratio)}')

         mask_fn: Optional[Callable] = None
-        if hasattr(self, 'current_mask'):
-            mask_fn = _divide_mask_fn(self, ratio=ratio, divisor=divisor)
+        if isinstance(self, BaseMutableChannel) and hasattr(
+                self, 'current_mask'):
+            mask_fn = _divide_mask_fn(
+                self, ratio=current_ratio, divisor=divisor)
+        elif getattr(self, 'mask_fn', None):  # OneShotMutableChannel
+            mask_fn = _divide_mask_fn(
+                self, ratio=current_ratio, divisor=divisor)

         return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn)
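
Together with the operator overloads further down, these derive methods are what make expressions like `mutable * ratio` and `mutable // ratio` work. A short sketch based on the unit tests added in this patch (the value lists are illustrative):

```python
from mmrazor.models.mutables import (OneShotMutableValue,
                                     SquentialMutableChannel)

mv = OneShotMutableValue(value_list=[3, 5, 7], default_value=3)

# Float ratios are now accepted and truncated with int().
mv_expand = mv * 2.5  # calls derive_expand_mutable(2.5)
mv.current_choice = 5
assert mv_expand.current_choice == 12  # int(5 * 2.5)

# Float divide ratios are coerced to int before dividing.
mc = SquentialMutableChannel(num_channels=128)
mc_div = mc // 8.0
mc.current_choice = 128
assert mc_div.current_choice == 16
```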
@@ -3,9 +3,9 @@ import copy
 import torch

-from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin
 from mmrazor.registry import MODELS
 from mmrazor.utils import IndexDict
+from ...architectures.dynamic_ops.mixins import DynamicChannelMixin
 from .base_mutable_channel import BaseMutableChannel
 from .simple_mutable_channel import SimpleMutableChannel
@@ -66,7 +66,7 @@ class MutableChannelContainer(BaseMutableChannel):

     def register_mutable(self, mutable_channel: BaseMutableChannel, start: int,
                          end: int):
         """Register/Store BaseMutableChannel in the MutableChannelContainer in
         the range [start,end)"""

         self.mutable_channels[(start, end)] = mutable_channel
@@ -82,7 +82,7 @@ class SquentialMutableChannel(SimpleMutableChannel):
                   mutable2: OneShotMutableValue) -> Callable:

         def fn():
-            return mutable1.current_choice * mutable2.current_choice
+            return int(mutable1.current_choice * mutable2.current_choice)

         return fn
@@ -93,9 +93,10 @@ class SquentialMutableChannel(SimpleMutableChannel):
         mask = mutable1.current_mask
         max_expand_ratio = mutable2.max_choice
         current_expand_ratio = mutable2.current_choice
-        expand_num_channels = mask.size(0) * max_expand_ratio
+        expand_num_channels = int(mask.size(0) * max_expand_ratio)

-        expand_choice = mutable1.current_choice * current_expand_ratio
+        expand_choice = int(mutable1.current_choice *
+                            current_expand_ratio)
         expand_mask = torch.zeros(expand_num_channels).bool()
         expand_mask[:expand_choice] = True
@@ -113,10 +114,17 @@ class SquentialMutableChannel(SimpleMutableChannel):
     def __floordiv__(self, other) -> DerivedMutable:
         if isinstance(other, int):
             return self.derive_divide_mutable(other)
+        elif isinstance(other, float):
+            return self.derive_divide_mutable(int(other))
         if isinstance(other, tuple):
             assert len(other) == 2
             return self.derive_divide_mutable(*other)

+        from ..mutable_value import OneShotMutableValue
+        if isinstance(other, OneShotMutableValue):
+            ratio = other.current_choice
+            return self.derive_divide_mutable(ratio)
+
         raise TypeError(f'Unsupported type {type(other)} for div!')

     def _num2ratio(self, choice: Union[int, float]) -> float:
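
The new `//` overloads in context, mirroring the tests added below (numbers illustrative):

```python
from mmrazor.models.mutables import (OneShotMutableValue,
                                     SquentialMutableChannel)

mc = SquentialMutableChannel(num_channels=64)

# Dividing by a OneShotMutableValue reads its current choice once,
# at derivation time, and divides by that ratio.
ratio = OneShotMutableValue(value_list=[2, 4], default_value=4)
derived = mc // ratio  # equivalent to mc.derive_divide_mutable(4)

# Float divisors are coerced with int() first.
derived_f = mc // 8.0  # equivalent to mc // 8
```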
@@ -262,14 +262,16 @@ class ChannelUnit(BaseModule):
     def add_ouptut_related(self, channel: Channel):
         """Add a Channel which is output related."""
         assert channel.is_output_channel
-        assert self.num_channels == channel.num_channels
+        assert self.num_channels == \
+            int(channel.num_channels // channel.expand_ratio)
         if channel not in self.output_related:
             self.output_related.append(channel)

     def add_input_related(self, channel: Channel):
         """Add a Channel which is input related."""
         assert channel.is_output_channel is False
-        assert self.num_channels == channel.num_channels
+        assert self.num_channels == \
+            int(channel.num_channels // channel.expand_ratio)
         if channel not in self.input_related:
             self.input_related.append(channel)
@@ -6,26 +6,23 @@ from typing import Dict, List, Type, TypeVar

 import torch.nn as nn

-from mmrazor.models.architectures import dynamic_ops
+from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin
 from mmrazor.models.mutables import DerivedMutable
 from mmrazor.models.mutables.mutable_channel import (BaseMutableChannel,
                                                      MutableChannelContainer)
+from mmrazor.models.mutables.mutable_value import MutableValue
 from .channel_unit import Channel, ChannelUnit


 class MutableChannelUnit(ChannelUnit):

     # init methods
     def __init__(self, num_channels: int, **kwargs) -> None:
         """MutableChannelUnit inherits from ChannelUnit, which manages channels
-        with channel-dependency. Compared with ChannelUnit, MutableChannelUnit
-        defines the core interfaces for pruning. By inheriting
-        MutableChannelUnit, we can implement a variant pruning and nas
-        algorithm. These apis includes.
+        with channel-dependency.
+
+        Compared with ChannelUnit, MutableChannelUnit defines the core
+        interfaces for pruning. By inheriting MutableChannelUnit,
+        we can implement a variant pruning and nas algorithm.
+
+        These apis includes
             - basic property
                 - name
                 - is_mutable
@@ -60,6 +57,7 @@ class MutableChannelUnit(ChannelUnit):
                                 mutable2units,
                                 is_output=True):
             for index, mutable in contanier.mutable_channels.items():
+                expand_ratio = 1
                 if isinstance(mutable, DerivedMutable):
                     source_mutables: Set = \
                         mutable._trace_source_mutables()
@@ -72,6 +70,17 @@ class MutableChannelUnit(ChannelUnit):
                         'used in DerivedMutable')
                     mutable = list(source_channel_mutables)[0]

+                    source_value_mutables = [
+                        mutable for mutable in source_mutables
+                        if isinstance(mutable, MutableValue)
+                    ]
+                    assert len(source_value_mutables) <= 1, (
+                        'only support one mutable value '
+                        'used in DerivedMutable')
+                    expand_ratio = int(
+                        list(source_value_mutables)
+                        [0].current_choice) if source_value_mutables else 1
+
                 if mutable not in mutable2units:
                     mutable2units[mutable] = cls.init_from_mutable_channel(
                         mutable)
@@ -83,14 +92,16 @@ class MutableChannelUnit(ChannelUnit):
                             module_name,
                             module,
                             index,
-                            is_output_channel=is_output))
+                            is_output_channel=is_output,
+                            expand_ratio=expand_ratio))
                 else:
                     unit.add_input_related(
                         Channel(
                             module_name,
                             module,
                             index,
-                            is_output_channel=is_output))
+                            is_output_channel=is_output,
+                            expand_ratio=expand_ratio))

         mutable2units: Dict = {}
         for name, module in model.named_modules():
@@ -121,7 +132,7 @@ class MutableChannelUnit(ChannelUnit):
             if channel.is_mutable is False:
                 all_channel_prunable = False
                 break
-            if isinstance(channel.module, dynamic_ops.DynamicChannelMixin):
+            if isinstance(channel.module, DynamicChannelMixin):
                 has_dynamic_op = True
         return has_dynamic_op, all_channel_prunable
@@ -223,29 +234,16 @@ class MutableChannelUnit(ChannelUnit):
             model: nn.Module, container_class: Type[MutableChannelContainer]):
         """register channel container for dynamic ops."""
         for module in model.modules():
-            if isinstance(module, dynamic_ops.DynamicChannelMixin):
+            if isinstance(module, DynamicChannelMixin):
+                in_channels = getattr(module,
+                                      module.attr_mappings['in_channels'], 0)
                 if module.get_mutable_attr('in_channels') is None:
-                    in_channels = 0
-                    if isinstance(module, nn.Conv2d):
-                        in_channels = module.in_channels
-                    elif isinstance(module, nn.modules.batchnorm._BatchNorm):
-                        in_channels = module.num_features
-                    elif isinstance(module, nn.Linear):
-                        in_channels = module.in_features
-                    else:
-                        raise NotImplementedError()
                     module.register_mutable_attr('in_channels',
                                                  container_class(in_channels))
+                out_channels = getattr(module,
+                                       module.attr_mappings['out_channels'], 0)
                 if module.get_mutable_attr('out_channels') is None:
-                    out_channels = 0
-                    if isinstance(module, nn.Conv2d):
-                        out_channels = module.out_channels
-                    elif isinstance(module, nn.modules.batchnorm._BatchNorm):
-                        out_channels = module.num_features
-                    elif isinstance(module, nn.Linear):
-                        out_channels = module.out_features
-                    else:
-                        raise NotImplementedError()
-
                     module.register_mutable_attr('out_channels',
                                                  container_class(out_channels))
@@ -253,7 +251,7 @@ class MutableChannelUnit(ChannelUnit):
         # register mutable_channel
         for channel in list(self.input_related) + list(self.output_related):
             module = channel.module
-            if isinstance(module, dynamic_ops.DynamicChannelMixin):
+            if isinstance(module, DynamicChannelMixin):
                 container: MutableChannelContainer
                 if channel.is_output_channel and module.get_mutable_attr(
                         'out_channels') is not None:
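
How containers and concrete mutable channels end up on a dynamic op, condensed from the new `TestDynamicMHA` setup later in this patch:

```python
from mmrazor.models.architectures.dynamic_ops import DynamicMultiheadAttention
from mmrazor.models.mutables import (MutableChannelContainer,
                                     OneShotMutableChannel,
                                     OneShotMutableChannelUnit)

attn = DynamicMultiheadAttention(embed_dims=128, num_heads=8)

# Step 1: attach an (empty) container to every dynamic op.
OneShotMutableChannelUnit._register_channel_container(
    attn, MutableChannelContainer)

# Step 2: register a concrete mutable channel into the container;
# the boolean selects the output (True) or input (False) side.
embed_dims = OneShotMutableChannel(num_channels=128)
MutableChannelContainer.register_mutable_channel_to_module(
    attn, embed_dims, False)
```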
@@ -39,6 +39,7 @@ class OneShotMutableChannelUnit(SequentialMutableChannelUnit):
                  min_ratio=0.9) -> None:
         super().__init__(num_channels, choice_mode, divisor, min_value,
                          min_ratio)
+
         candidate_choices = copy.copy(candidate_choices)
         if candidate_choices == []:
             candidate_choices.append(
@@ -50,6 +51,8 @@ class OneShotMutableChannelUnit(SequentialMutableChannelUnit):
                                                     self.candidate_choices,
                                                     choice_mode)

+        self.unit_predefined = False
+
     @classmethod
     def init_from_mutable_channel(cls, mutable_channel: OneShotMutableChannel):
         unit = cls(mutable_channel.num_channels,
@@ -61,7 +64,8 @@ class OneShotMutableChannelUnit(SequentialMutableChannelUnit):

     def prepare_for_pruning(self, model: nn.Module):
         """Prepare for pruning."""
-        super().prepare_for_pruning(model)
+        if not self.unit_predefined:
+            super().prepare_for_pruning(model)
         self.current_choice = self.max_choice

     # ~
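
The new flag matters for hand-built supernets such as Autoformer: their channel units are registered manually, so graph-based preparation must be skipped. This is the mutator config used by the new tests (values illustrative):

```python
# With unit_predefined=True, prepare_for_pruning() above skips the
# super() call and only applies the max choice.
mutator_cfg = dict(
    type='mmrazor.OneShotChannelMutator',
    channel_unit_cfg={
        'type': 'OneShotMutableChannelUnit',
        'default_args': {
            'unit_predefined': True
        }
    },
    parse_cfg={'type': 'Predefined'})
```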
@@ -99,7 +99,7 @@ class MutableValue(BaseMutable, DerivedMethodMixin):
         return len(self.choices)

     @property
-    def current_choice(self) -> Optional[Any]:
+    def current_choice(self) -> Value:
         """Current choice of mutable value."""
         return self._current_choice
@@ -116,7 +116,7 @@ class MutableValue(BaseMutable, DerivedMethodMixin):
         """Please refer to method :func:`__mul__`."""
         return self * other

-    def __mul__(self, other: int) -> DerivedMutable:
+    def __mul__(self, other: Union[int, float]) -> DerivedMutable:
         """Overload `*` operator.

         Args:
@@ -127,7 +127,8 @@ class MutableValue(BaseMutable, DerivedMethodMixin):
         """
         if isinstance(other, int):
             return self.derive_expand_mutable(other)
-
+        elif isinstance(other, float):
+            return self.derive_expand_mutable(other)
         raise TypeError(f'Unsupported type {type(other)} for mul!')

     def __floordiv__(self, other: Union[int, Tuple[int,
@@ -143,6 +144,8 @@ class MutableValue(BaseMutable, DerivedMethodMixin):
         """
         if isinstance(other, int):
             return self.derive_divide_mutable(other)
+        elif isinstance(other, float):
+            return self.derive_divide_mutable(int(other))
         if isinstance(other, tuple):
             assert len(other) == 2
             return self.derive_divide_mutable(*other)
@@ -3,8 +3,10 @@ from .channel_mutator import (ChannelMutator, OneShotChannelMutator,
                               SlimmableChannelMutator)
 from .module_mutator import (DiffModuleMutator, ModuleMutator,
                              OneShotModuleMutator)
+from .value_mutator import DynamicValueMutator, ValueMutator

 __all__ = [
     'OneShotModuleMutator', 'DiffModuleMutator', 'ModuleMutator',
-    'ChannelMutator', 'OneShotChannelMutator', 'SlimmableChannelMutator'
+    'ChannelMutator', 'OneShotChannelMutator', 'SlimmableChannelMutator',
+    'ValueMutator', 'DynamicValueMutator'
 ]
@@ -358,4 +358,7 @@ class ChannelMutator(BaseMutator, Generic[ChannelUnitType], GroupMixin):

         units = self.unit_class.init_from_predefined_model(model)

+        for unit in units:
+            unit.unit_predefined = self.unit_default_args.pop(
+                'unit_predefined', False)
         return units
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import sys
 from collections import Counter
 from typing import Dict, List, Type
@@ -7,6 +7,11 @@ from torch.nn import Module

 from ..mutables import BaseMutable

+if sys.version_info < (3, 8):
+    from typing_extensions import Protocol
+else:
+    from typing import Protocol
+

 class GroupMixin():
     """A mixin for :class:`BaseMutator`, which can group mutables by
@@ -220,3 +225,49 @@ class GroupMixin():
                 f'When a mutable is set alias attribute :{alias_key},'
                 f'the corresponding module name {mutable_name} should '
                 f'not be used in `custom_group` {custom_group}.')
+
+
+class MutatorProtocol(Protocol):  # pragma: no cover
+
+    @property
+    def mutable_class_type(self) -> Type[BaseMutable]:
+        ...
+
+    @property
+    def search_groups(self) -> Dict:
+        ...
+
+
+class OneShotSampleMixin:
+
+    def sample_choices(self: MutatorProtocol) -> Dict:
+        random_choices = dict()
+        for group_id, modules in self.search_groups.items():
+            random_choices[group_id] = modules[0].sample_choice()
+
+        return random_choices
+
+    def set_choices(self: MutatorProtocol, choices: Dict) -> None:
+        for group_id, modules in self.search_groups.items():
+            choice = choices[group_id]
+            for module in modules:
+                module.current_choice = choice
+
+
+class DynamicSampleMixin(OneShotSampleMixin):
+
+    @property
+    def max_choices(self: MutatorProtocol) -> Dict:
+        max_choices = dict()
+        for group_id, modules in self.search_groups.items():
+            max_choices[group_id] = modules[0].max_choice
+
+        return max_choices
+
+    @property
+    def min_choices(self: MutatorProtocol) -> Dict:
+        min_choices = dict()
+        for group_id, modules in self.search_groups.items():
+            min_choices[group_id] = modules[0].min_choice
+
+        return min_choices
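
Once mixed into a mutator, the two mixins give it a uniform sampling surface. A self-contained usage sketch (`ToyNet` is a hypothetical stand-in for a real supernet):

```python
import torch.nn as nn

from mmrazor.models.mutables import OneShotMutableValue
from mmrazor.models.mutators import DynamicValueMutator


class ToyNet(nn.Module):
    """Hypothetical supernet holding one searchable value."""

    def __init__(self) -> None:
        super().__init__()
        self.mutable_depth = OneShotMutableValue(
            value_list=[2, 3, 4], default_value=3)


mutator = DynamicValueMutator()
mutator.prepare_from_supernet(ToyNet())

choices = mutator.sample_choices()  # {group_id: sampled value}
mutator.set_choices(choices)

# DynamicSampleMixin additionally exposes the search-space extremes.
print(mutator.max_choices, mutator.min_choices)
```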
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .dynamic_value_mutator import DynamicValueMutator
+from .value_mutator import ValueMutator
+
+__all__ = ['ValueMutator', 'DynamicValueMutator']
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmrazor.models.mutables import OneShotMutableValue
+from mmrazor.registry import MODELS
+from ..group_mixin import DynamicSampleMixin
+from .value_mutator import ValueMutator
+
+
+@MODELS.register_module()
+class DynamicValueMutator(ValueMutator, DynamicSampleMixin):
+
+    @property
+    def mutable_class_type(self):
+        return OneShotMutableValue
@@ -0,0 +1,73 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Type
+
+from torch.nn import Module
+
+from mmrazor.models.mutables import MutableValue
+from mmrazor.registry import MODELS
+from ..base_mutator import BaseMutator
+from ..group_mixin import GroupMixin
+
+
+@MODELS.register_module()
+class ValueMutator(BaseMutator[MutableValue], GroupMixin):
+    """The base class for mutable based mutator. All subclasses should
+    implement the following APIs:
+
+    - ``mutable_class_type``
+
+    Args:
+        custom_group (list[list[str]], optional): User-defined search groups.
+            All searchable modules that are not in ``custom_group`` will be
+            grouped separately.
+    """
+
+    def __init__(self,
+                 custom_group: Optional[List[List[str]]] = None,
+                 init_cfg: Optional[Dict] = None) -> None:
+        super().__init__(init_cfg)
+
+        if custom_group is None:
+            custom_group = []
+        self._custom_group = custom_group
+        self._search_groups: Optional[Dict[int, List[MutableValue]]] = None
+
+    # TODO
+    # should be a class property
+    @property
+    def mutable_class_type(self) -> Type[MutableValue]:
+        """Corresponding mutable class type.
+
+        Returns:
+            Type[MUTABLE_TYPE]: Mutable class type.
+        """
+        return MutableValue
+
+    def prepare_from_supernet(self, supernet: Module) -> None:
+        """Do some necessary preparations with supernet.
+
+        Note:
+            For mutable based mutator, we need to build search group first.
+
+        Args:
+            supernet (:obj:`torch.nn.Module`): The supernet to be searched
+                in your algorithm.
+        """
+        self._search_groups = self.build_search_groups(supernet,
+                                                       self.mutable_class_type,
+                                                       self._custom_group)
+
+    @property
+    def search_groups(self) -> Dict[int, List[MutableValue]]:
+        """Search group of supernet.
+
+        Note:
+            For mutable based mutator, the search group is composed of
+            corresponding mutables.
+
+        Raises:
+            RuntimeError: Called before search group has been built.
+
+        Returns:
+            Dict[int, List[MUTABLE_TYPE]]: Search group.
+        """
+        if self._search_groups is None:
+            raise RuntimeError(
+                'Call `prepare_from_supernet` before access search group!')
+        return self._search_groups
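
A hedged sketch of the `custom_group` argument: mutables named together share one search group and therefore always receive the same sampled value (the module paths below are made up):

```python
from mmrazor.models.mutators import ValueMutator

# Force two (hypothetical) ratio mutables to be searched jointly;
# every other MutableValue falls into its own group.
mutator = ValueMutator(custom_group=[[
    'backbone.layers.0.mutable_mlp_ratios',
    'backbone.layers.1.mutable_mlp_ratios',
]])
```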
@@ -498,6 +498,9 @@ def add_flops_params_counter_variable_or_reset(module):
         module.__params__ = 0


+counter_warning_list = []
+
+
 def get_counter_type(module) -> str:
     """Get counter type of the module based on the module class name.
@@ -515,10 +518,13 @@ def get_counter_type(module) -> str:
     for base_cls in module.__class__.mro():
         if base_cls in get_modules_list():
             counter_type = base_cls.__name__ + 'Counter'
-            from mmengine import MMLogger
-            logger = MMLogger.get_current_instance()
-            logger.warning(f'`{old_counter_type}` not in op_counters. '
-                           f'Using `{counter_type}` instead.')
+            global counter_warning_list
+            if old_counter_type not in counter_warning_list:
+                from mmengine import MMLogger
+                logger = MMLogger.get_current_instance()
+                logger.warning(f'`{old_counter_type}` not in op_counters. '
+                               f'Using `{counter_type}` instead.')
+                counter_warning_list.append(old_counter_type)
             break
     return counter_type
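
The change above is a plain warn-once guard. The pattern in isolation (names illustrative):

```python
import logging

_warned: list = []  # process-wide memo of already-emitted messages


def warn_once(msg: str) -> None:
    """Log each distinct warning a single time."""
    if msg not in _warned:
        logging.getLogger(__name__).warning(msg)
        _warned.append(msg)


warn_once('`FakeOp` not in op_counters.')  # logged
warn_once('`FakeOp` not in op_counters.')  # suppressed
```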
@@ -1,35 +1,44 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from collections import UserList
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Union


 class Candidates(UserList):
-    """The data structure of sampled candidate. The format is [(any, float),
-    (any, float), ...].
+    """The data structure of sampled candidate. The format is Union[Dict[str,
+    Dict], List[Dict[str, Dict]]].

     Examples:
         >>> candidates = Candidates()
-        >>> subnet_1 = {'choice_1': 'layer_1', 'choice_2': 'layer_2'}
+        >>> subnet_1 = {'1': 'choice1', '2': 'choice2'}
         >>> candidates.append(subnet_1)
         >>> candidates
-        [({'choice_1': 'layer_1', 'choice_2': 'layer_2'}, 0.0)]
-        >>> candidates.set_score(0, 0.9)
+        [{"{'1': 'choice1', '2': 'choice2'}":
+            {'score': 0.0, 'flops': 0.0, 'params': 0.0, 'latency': 0.0}}]
+        >>> candidates.set_resource(0, 49.9, 'flops')
+        >>> candidates.set_score(0, 100.)
         >>> candidates
-        [({'choice_1': 'layer_1', 'choice_2': 'layer_2'}, 0.9)]
+        [{"{'1': 'choice1', '2': 'choice2'}":
+            {'score': 100.0, 'flops': 49.9, 'params': 0.0, 'latency': 0.0}}]
         >>> subnet_2 = {'choice_3': 'layer_3', 'choice_4': 'layer_4'}
-        >>> candidates.append((subnet_2, 0.5))
+        >>> candidates.append(subnet_2)
         >>> candidates
-        [({'choice_1': 'layer_1', 'choice_2': 'layer_2'}, 0.9),
-        ({'choice_3': 'layer_3', 'choice_4': 'layer_4'}, 0.5)]
+        [{"{'1': 'choice1', '2': 'choice2'}":
+            {'score': 100.0, 'flops': 49.9, 'params': 0.0, 'latency': 0.0}},
+        {"{'choice_3': 'layer_3', 'choice_4': 'layer_4'}":
+            {'score': 0.0, 'flops': 0.0, 'params': 0.0, 'latency': 0.0}}]
         >>> candidates.subnets
-        [{'choice_1': 'layer_1', 'choice_2': 'layer_2'},
+        [{'1': 'choice1', '2': 'choice2'},
         {'choice_3': 'layer_3', 'choice_4': 'layer_4'}]
+        >>> candidates.resources('flops')
+        [49.9, 0.0]
         >>> candidates.scores
-        [0.9, 0.5]
+        [100.0, 0.0]
     """
-    _format_return = Union[Tuple[Any, float], List[Tuple[Any, float]]]
+    _format_return = Union[Dict[str, Dict], List[Dict[str, Dict]]]
+    _format_input = Union[Dict, List[Dict], Dict[str, Dict], List[Dict[str,
+                                                                       Dict]]]
+    _indicators = ('score', 'flops', 'params', 'latency')

-    def __init__(self, initdata: Optional[Any] = None):
+    def __init__(self, initdata: Optional[_format_input] = None):
         self.data = []
         if initdata is not None:
             initdata = self._format(initdata)
@@ -41,23 +50,59 @@ class Candidates(UserList):
     @property
     def scores(self) -> List[float]:
         """The scores of candidates."""
-        return [item[1] for item in self.data]
+        return [
+            round(value.get('score', 0.), 2) for item in self.data
+            for _, value in item.items()
+        ]
+
+    def resources(self, key_indicator: str = 'flops') -> List[float]:
+        """The resources of candidates."""
+        assert key_indicator in ['flops', 'params', 'latency']
+        return [
+            value.get(key_indicator, 0.) for item in self.data
+            for _, value in item.items()
+        ]

     @property
     def subnets(self) -> List[Dict]:
         """The subnets of candidates."""
-        return [item[0] for item in self.data]
+        return [eval(key) for item in self.data for key, _ in item.items()]

-    def _format(self, data: Any) -> _format_return:
-        """Transform [any, ...] to [tuple(any, float), ...] Transform any to
-        tuple(any, float)."""
-
-        def _format_item(item: Any):
-            """Transform any to tuple(any, float)."""
-            if isinstance(item, tuple):
-                return (item[0], float(item[1]))
+    def _format(self, data: _format_input) -> _format_return:
+        """Transform [Dict, ...] to Union[Dict[str, Dict], List[Dict[str,
+        Dict]]].
+
+        Args:
+            data: Four types of input are supported:
+                1. Dict: only include network information.
+                2. List[Dict]: multiple candidates only include network
+                    information.
+                3. Dict[str, Dict]: network information and the corresponding
+                    resources.
+                4. List[Dict[str, Dict]]: multiple candidate information.
+
+        Returns:
+            Union[Dict[str, Dict], UserList[Dict[str, Dict]]]:
+                A dict or a list of dict that contains a pair of network
+                information and the corresponding Score | FLOPs | Params |
+                Latency results in each candidate.
+
+        Notes:
+            Score | FLOPs | Params | Latency:
+                1. a candidate whose resources default to -1 has not been
+                    estimated at all.
+                2. a candidate whose resources default to 0 has had some
+                    indicators evaluated already.
+        """
+
+        def _format_item(
+                cond: Union[Dict, Dict[str, Dict]]) -> Dict[str, Dict]:
+            """Transform Dict to Dict[str, Dict]."""
+            if isinstance(list(cond.values())[0], dict):
+                for value in list(cond.values()):
+                    for key in list(self._indicators):
+                        value.setdefault(key, 0.)
+                return cond
             else:
-                return (item, 0.)
+                return {str(cond): {}.fromkeys(self._indicators, -1)}

         if isinstance(data, UserList):
             return [_format_item(i) for i in data.data]
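
To make the four accepted input shapes concrete, a small sketch (assuming the package-level `Candidates` export; values illustrative):

```python
from mmrazor.structures import Candidates

subnet = {'1': 'choice1', '2': 'choice2'}

c1 = Candidates(subnet)            # 1. bare dict, resources default to -1
c2 = Candidates([subnet, subnet])  # 2. list of dicts
c3 = Candidates({str(subnet): {'flops': 50.}})    # 3. partial resources,
                                                  #    missing keys -> 0.
c4 = Candidates([{str(subnet): {'flops': 50.}}])  # 4. list of the above
```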
@@ -68,12 +113,15 @@ class Candidates(UserList):
         else:
             return _format_item(data)

-    def append(self, item: Any) -> None:
+    def append(self, item: _format_input) -> None:
         """Append operation."""
         item = self._format(item)
-        self.data.append(item)
+        if isinstance(item, list):
+            self.data = self.data + item
+        else:
+            self.data.append(item)

-    def insert(self, i: int, item: Any) -> None:
+    def insert(self, i: int, item: _format_input) -> None:
         """Insert operation."""
         item = self._format(item)
         self.data.insert(i, item)
@@ -88,4 +136,35 @@ class Candidates(UserList):

     def set_score(self, i: int, score: float) -> None:
         """Set score to the specified subnet by index."""
-        self.data[i] = (self.data[i][0], float(score))
+        self.set_resource(i, score, 'score')
+
+    def set_resource(self,
+                     i: int,
+                     resources: float,
+                     key_indicator: str = 'flops') -> None:
+        """Set resources to the specified subnet by index."""
+        assert key_indicator in ['score', 'flops', 'params', 'latency']
+        for _, value in self.data[i].items():
+            value[key_indicator] = resources
+
+    def update_resources(self, resources: list, start: int = 0) -> None:
+        """Update resources to the specified candidate."""
+        end = start + len(resources)
+        assert len(
+            self.data) >= end, 'Check the number of candidate resources.'
+        for i, item in enumerate(self.data[start:end]):
+            for _, value in item.items():
+                value.update(resources[i])
+
+    def sort_by(self,
+                key_indicator: str = 'score',
+                reverse: bool = True) -> None:
+        """Sort by a specific indicator in descending order.
+
+        Args:
+            key_indicator (str): sort all candidates by key_indicator.
+                Defaults to 'score'.
+            reverse (bool): sort all candidates in descending order.
+        """
+        self.data.sort(
+            key=lambda x: list(x.values())[0][key_indicator], reverse=reverse)
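
Typical bookkeeping flow with the new methods, mirroring the updated tests (numbers illustrative):

```python
from mmrazor.structures import Candidates

candidates = Candidates([{'1': 'choice1'}, {'1': 'choice2'}])

# Batch-write estimator results, one dict per candidate, from index 0.
candidates.update_resources([{'flops': 50.}, {'flops': 40.}], start=0)
candidates.set_score(0, 95.)  # shorthand for set_resource(0, 95., 'score')
candidates.set_score(1, 98.)

candidates.sort_by(key_indicator='score')  # descending by default
print(candidates.scores)  # [98.0, 95.0]
```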
@@ -43,13 +43,15 @@ def load_fix_subnet(model: nn.Module,
         raise RuntimeError('Root model can not be dynamic op.')

     # Avoid circular import
-    from mmrazor.models.mutables import DerivedMutable
+    from mmrazor.models.mutables import DerivedMutable, MutableChannelContainer
     from mmrazor.models.mutables.base_mutable import BaseMutable

     for name, module in model.named_modules():
         # The format of `chosen` is different for each type of mutable.
         # In the corresponding mutable, it will check whether the `chosen`
         # format is correct.
+        if isinstance(module, (MutableChannelContainer, DerivedMutable)):
+            continue
         if isinstance(module, BaseMutable):
             if not module.is_fixed:
                 if getattr(module, 'alias', None):
@@ -61,8 +63,8 @@ def load_fix_subnet(model: nn.Module,
                     chosen = fix_mutable.get(alias, None)
                 else:
                     mutable_name = name.lstrip(prefix)
-                    if mutable_name not in fix_mutable and \
-                            not isinstance(module, DerivedMutable):
+                    if mutable_name not in fix_mutable and not isinstance(
+                            module, (DerivedMutable, MutableChannelContainer)):
                         raise RuntimeError(
                             f'The module name {mutable_name} is not in '
                             'fix_mutable, please check your `fix_mutable`.')
@@ -87,13 +89,15 @@ def export_fix_subnet(model: nn.Module,
                 level=logging.WARNING)

     # Avoid circular import
-    from mmrazor.models.mutables import DerivedMutable
+    from mmrazor.models.mutables import DerivedMutable, MutableChannelContainer
     from mmrazor.models.mutables.base_mutable import BaseMutable

     fix_subnet = dict()
     for name, module in model.named_modules():
         if isinstance(module, BaseMutable):
-            if isinstance(module, DerivedMutable) and not dump_derived_mutable:
+            if isinstance(module,
+                          (MutableChannelContainer,
+                           DerivedMutable)) and not dump_derived_mutable:
                 continue

             if module.alias:
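
How the two helpers pair up in practice; a sketch wrapped as a function so the model arguments stay explicit (the function name is made up):

```python
import torch.nn as nn

from mmrazor.structures import export_fix_subnet, load_fix_subnet


def freeze_subnet(supernet: nn.Module, target: nn.Module) -> None:
    """Copy the currently active choices from `supernet` onto `target`.

    Containers and derived mutables are now skipped on both sides
    unless dump_derived_mutable is requested.
    """
    fix_subnet = export_fix_subnet(supernet)
    load_fix_subnet(target, fix_subnet)
```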
@@ -3,7 +3,7 @@ from torch.nn import Module
 from torch import Tensor
 import torch.nn as nn
 import torch
-from mmrazor.models.architectures.dynamic_ops import DynamicBatchNorm2d, DynamicConv2d, DynamicLinear, DynamicChannelMixin
+from mmrazor.models.architectures.dynamic_ops import DynamicBatchNorm2d, DynamicConv2d, DynamicLinear, DynamicChannelMixin, DynamicPatchEmbed, DynamicSequential
 from mmrazor.models.mutables.mutable_channel import MutableChannelContainer
 from mmrazor.models.mutables import MutableChannelUnit
 from mmrazor.models.mutables import DerivedMutable
@@ -13,6 +13,9 @@ from mmrazor.registry import MODELS
 from mmengine.model import BaseModel
 # this file includes models for testing.

+from mmrazor.models.mutables import OneShotMutableValue
+from mmrazor.models.architectures.backbones.searchable_autoformer import TransformerEncoderLayer
+

 class LinearHead(Module):
|
|||
|
||||
|
||||
def register_mutable(module: DynamicChannelMixin,
|
||||
mutable: OneShotMutableChannelUnit,
|
||||
mutable: MutableChannelUnit,
|
||||
is_out=True,
|
||||
start=0,
|
||||
end=-1):
|
||||
|
@ -581,6 +584,95 @@ class DynamicLinearModel(nn.Module):
|
|||
self.linear, mutable2, False)
|
||||
|
||||
|
||||
class DynamicAttention(nn.Module):
|
||||
"""
|
||||
x
|
||||
|blocks: DynamicSequential(depth)
|
||||
|(blocks)
|
||||
x1
|
||||
|fc (OneShotMutableChannel * OneShotMutableValue)
|
||||
output
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.mutable_depth = OneShotMutableValue(
|
||||
value_list=[1, 2], default_value=2)
|
||||
self.mutable_embed_dims = OneShotMutableChannel(
|
||||
num_channels=624, candidate_choices=[576, 624])
|
||||
self.base_embed_dims = OneShotMutableChannel(
|
||||
num_channels=64, candidate_choices=[64])
|
||||
self.mutable_num_heads = [
|
||||
OneShotMutableValue(
|
||||
value_list=[8, 10],
|
||||
default_value=10)
|
||||
for _ in range(2)
|
||||
]
|
||||
self.mutable_mlp_ratios = [
|
||||
OneShotMutableValue(
|
||||
value_list=[3.0, 3.5, 4.0],
|
||||
default_value=4.0)
|
||||
for _ in range(2)
|
||||
]
|
||||
self.mutable_q_embed_dims = [
|
||||
i * self.base_embed_dims for i in self.mutable_num_heads
|
||||
]
|
||||
|
||||
self.patch_embed = DynamicPatchEmbed(
|
||||
img_size=224,
|
||||
in_channels=3,
|
||||
embed_dims=self.mutable_embed_dims.num_channels)
|
||||
|
||||
# cls token and pos embed
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, 197,
|
||||
self.mutable_embed_dims.num_channels))
|
||||
self.cls_token = nn.Parameter(
|
||||
torch.zeros(1, 1, self.mutable_embed_dims.num_channels))
|
||||
|
||||
layers = []
|
||||
for i in range(self.mutable_depth.max_choice):
|
||||
layer = TransformerEncoderLayer(
|
||||
embed_dims=self.mutable_embed_dims.num_channels,
|
||||
num_heads=self.mutable_num_heads[i].max_choice,
|
||||
mlp_ratio=self.mutable_mlp_ratios[i].max_choice)
|
||||
layers.append(layer)
|
||||
self.blocks = DynamicSequential(*layers)
|
||||
|
||||
# OneShotMutableChannelUnit
|
||||
OneShotMutableChannelUnit._register_channel_container(
|
||||
self, MutableChannelContainer)
|
||||
|
||||
self.register_mutables()
|
||||
|
||||
def register_mutables(self):
|
||||
# mutablevalue
|
||||
self.blocks.register_mutable_attr('depth', self.mutable_depth)
|
||||
# mutablechannel
|
||||
MutableChannelContainer.register_mutable_channel_to_module(
|
||||
self.patch_embed, self.mutable_embed_dims, True)
|
||||
|
||||
for i in range(self.mutable_depth.max_choice):
|
||||
layer = self.blocks[i]
|
||||
layer.register_mutables(
|
||||
mutable_num_heads=self.mutable_num_heads[i],
|
||||
mutable_mlp_ratios=self.mutable_mlp_ratios[i],
|
||||
mutable_q_embed_dims=self.mutable_q_embed_dims[i],
|
||||
mutable_head_dims=self.base_embed_dims,
|
||||
mutable_embed_dims=self.mutable_embed_dims)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
B = x.shape[0]
|
||||
x = self.patch_embed(x)
|
||||
embed_dims = self.mutable_embed_dims.current_choice
|
||||
cls_tokens = self.cls_token[..., :embed_dims].expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x = x + self.pos_embed[..., :embed_dims]
|
||||
x = self.blocks(x)
|
||||
return torch.mean(x[:, 1:], dim=1)
|
||||
|
||||
|
||||
default_models = [
|
||||
LineModel,
|
||||
ResBlock,
|
||||
|
|
|
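
Usage of this fixture, condensed from the channel-mutator test added below:

```python
import torch

from mmrazor.models.mutators.channel_mutator import ChannelMutator

model = DynamicAttention()  # the fixture defined above
mutator = ChannelMutator(
    channel_unit_cfg={
        'type': 'OneShotMutableChannelUnit',
        'default_args': {
            'unit_predefined': True
        }
    },
    parse_cfg={'type': 'Predefined'})
mutator.prepare_from_supernet(model)
mutator.set_choices(mutator.sample_choices())

y = model(torch.rand(2, 3, 224, 224))  # width follows the sampled embed dims
```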
@@ -0,0 +1,116 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from unittest import TestCase
+
+import torch
+
+from mmrazor.models import Autoformer
+from mmrazor.registry import MODELS
+
+arch_setting = dict(
+    mlp_ratios=[3.0, 3.5, 4.0],
+    num_heads=[8, 9, 10],
+    depth=[14, 15, 16],
+    embed_dims=[528, 576, 624])
+
+MUTATOR_CFG = dict(
+    channel_mutator=dict(
+        type='mmrazor.OneShotChannelMutator',
+        channel_unit_cfg={
+            'type': 'OneShotMutableChannelUnit',
+            'default_args': {
+                'unit_predefined': True
+            }
+        },
+        parse_cfg={'type': 'Predefined'}),
+    value_mutator=dict(type='mmrazor.DynamicValueMutator'))
+
+ARCHITECTURE_CFG = dict(
+    _scope_='mmrazor',
+    type='SearchableImageClassifier',
+    backbone=dict(
+        _scope_='mmrazor',
+        type='AutoformerBackbone',
+        arch_setting=arch_setting),
+    neck=None,
+    head=dict(
+        type='DynamicLinearClsHead',
+        num_classes=1000,
+        in_channels=624,
+        loss=dict(
+            type='mmcls.LabelSmoothLoss',
+            mode='original',
+            num_classes=1000,
+            label_smooth_val=0.1,
+            loss_weight=1.0),
+        topk=(1, 5)),
+    connect_head=dict(connect_with_backbone='backbone.last_mutable'),
+)
+
+ALGORITHM_CFG = dict(
+    type='mmrazor.Autoformer',
+    architecture=ARCHITECTURE_CFG,
+    fix_subnet=None,
+    mutators=dict(
+        channel_mutator=dict(
+            type='mmrazor.OneShotChannelMutator',
+            channel_unit_cfg={
+                'type': 'OneShotMutableChannelUnit',
+                'default_args': {
+                    'unit_predefined': True
+                }
+            },
+            parse_cfg={'type': 'Predefined'}),
+        value_mutator=dict(type='mmrazor.DynamicValueMutator')))
+
+
+class TestAUTOFORMER(TestCase):
+
+    def test_init(self):
+        ALGORITHM_CFG_SUPERNET = copy.deepcopy(ALGORITHM_CFG)
+        # initiate autoformer with built `algorithm`.
+        autoformer_algo = MODELS.build(ALGORITHM_CFG_SUPERNET)
+        self.assertIsInstance(autoformer_algo, Autoformer)
+        # autoformer mutators include channel_mutator and value_mutator
+        assert 'channel_mutator' in autoformer_algo.mutators
+        assert 'value_mutator' in autoformer_algo.mutators
+
+        # autoformer search_groups
+        random_subnet = autoformer_algo.sample_subnet()
+        self.assertIsInstance(random_subnet, dict)
+
+        # autoformer_algo support training
+        self.assertTrue(autoformer_algo.is_supernet)
+
+        # initiate autoformer without any `mutator`.
+        ALGORITHM_CFG_SUPERNET.pop('type')
+        ALGORITHM_CFG_SUPERNET['mutators'] = None
+        with self.assertRaisesRegex(
+                AssertionError,
+                'mutator cannot be None when fix_subnet is None.'):
+            _ = Autoformer(**ALGORITHM_CFG_SUPERNET)
+
+        # initiate autoformer with error type `mutator`.
+        backwardtracer_cfg = dict(
+            type='OneShotChannelMutator',
+            channel_unit_cfg=dict(
+                type='OneShotMutableChannelUnit',
+                default_args=dict(
+                    candidate_choices=list(i / 12 for i in range(2, 13)),
+                    choice_mode='ratio')),
+            parse_cfg=dict(
+                type='BackwardTracer',
+                loss_calculator=dict(type='ImageClassifierPseudoLoss')))
+        ALGORITHM_CFG_SUPERNET['mutators'] = dict(
+            channel_mutator=backwardtracer_cfg,
+            value_mutator=dict(type='mmrazor.DynamicValueMutator'))
+        with self.assertRaisesRegex(AssertionError,
+                                    'autoformer only support predefined.'):
+            _ = Autoformer(**ALGORITHM_CFG_SUPERNET)
+
+    def test_loss(self):
+        # supernet
+        inputs = torch.randn(1, 3, 224, 224)
+        autoformer = MODELS.build(ALGORITHM_CFG)
+        loss = autoformer(inputs)
+        assert loss.size(1) == 1000
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmrazor.models.architectures.dynamic_ops import (
+    DynamicLinear, DynamicMultiheadAttention, DynamicPatchEmbed,
+    DynamicRelativePosition2D, DynamicSequential)
+from mmrazor.models.mutables import MutableChannelContainer
+from mmrazor.registry import MODELS
+
+arch_setting = dict(
+    mlp_ratios=[3.0, 3.5, 4.0],
+    num_heads=[8, 9, 10],
+    depth=[14, 15, 16],
+    embed_dims=[528, 576, 624])
+
+BACKBONE_CFG = dict(
+    type='mmrazor.AutoformerBackbone',
+    arch_setting=arch_setting,
+    img_size=224,
+    patch_size=16,
+    in_channels=3,
+    norm_cfg=dict(type='mmrazor.DynamicLayerNorm'),
+    act_cfg=dict(type='GELU'))
+
+
+def test_searchable_autoformer_mutable() -> None:
+    backbone = MODELS.build(BACKBONE_CFG)
+
+    num_heads = backbone.arch_setting['num_heads']
+    mlp_ratios = backbone.arch_setting['mlp_ratios']
+    depth = backbone.arch_setting['depth']
+    embed_dims = backbone.arch_setting['embed_dims']
+    embed_dims_expansion = [i * j for i in mlp_ratios for j in embed_dims]
+    head_expansion = [i * 64 for i in num_heads]
+
+    for name, module in backbone.named_modules():
+        if isinstance(module, DynamicRelativePosition2D):
+            assert len(module.mutable_head_dims.current_choice) == 64
+        elif isinstance(module, DynamicMultiheadAttention):
+            assert len(
+                module.mutable_embed_dims.current_choice) == max(embed_dims)
+            assert len(module.mutable_q_embed_dims.current_choice) == max(
+                head_expansion)
+            assert module.mutable_num_heads.choices == num_heads
+        elif isinstance(module, DynamicLinear):
+            if 'fc1' in name:
+                assert module.mutable_attrs['in_features'].num_channels == max(
+                    embed_dims)
+                assert module.mutable_attrs[
+                    'out_features'].num_channels == max(embed_dims_expansion)
+            elif 'fc2' in name:
+                assert module.mutable_attrs['in_features'].num_channels == max(
+                    embed_dims_expansion)
+                assert module.mutable_attrs[
+                    'out_features'].num_channels == max(embed_dims)
+        elif isinstance(module, DynamicPatchEmbed):
+            assert type(module.mutable_embed_dims) == MutableChannelContainer
+            assert len(
+                module.mutable_embed_dims.current_choice) == max(embed_dims)
+        elif isinstance(module, DynamicSequential):
+            assert module.mutable_depth.choices == depth
+    assert backbone.last_mutable.num_channels == max(embed_dims)
@@ -0,0 +1,50 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+
+from mmrazor.models.architectures.dynamic_ops import DynamicMultiheadAttention
+from mmrazor.models.architectures.ops import MultiheadAttention
+from mmrazor.models.mutables import (MutableChannelContainer,
+                                     OneShotMutableChannel,
+                                     OneShotMutableChannelUnit,
+                                     OneShotMutableValue)
+
+
+class TestDynamicMHA(TestCase):
+
+    def setUp(self) -> None:
+        self.mutable_num_heads = OneShotMutableValue(
+            value_list=[2, 4, 8], default_value=8)
+        self.mutable_embed_dims = OneShotMutableChannel(num_channels=128)
+        self.base_embed_dims = OneShotMutableChannel(
+            num_channels=8, candidate_choices=[8])
+        self.mutable_q_embed_dims = self.mutable_num_heads * \
+            self.base_embed_dims
+
+        self.dynamic_m = DynamicMultiheadAttention(embed_dims=128, num_heads=8)
+
+        OneShotMutableChannelUnit._register_channel_container(
+            self.dynamic_m, MutableChannelContainer)
+
+        self.dynamic_m.register_mutable_attr('num_heads',
+                                             self.mutable_num_heads)
+
+        MutableChannelContainer.register_mutable_channel_to_module(
+            self.dynamic_m, self.mutable_embed_dims, False)
+        MutableChannelContainer.register_mutable_channel_to_module(
+            self.dynamic_m, self.mutable_q_embed_dims, True, end=64)
+        MutableChannelContainer.register_mutable_channel_to_module(
+            self.dynamic_m.rel_pos_embed_k, self.base_embed_dims, False)
+        MutableChannelContainer.register_mutable_channel_to_module(
+            self.dynamic_m.rel_pos_embed_v, self.base_embed_dims, False)
+
+    def test_forward(self) -> None:
+        x = torch.randn(8, 197, 128)
+        output = self.dynamic_m(x)
+        self.assertIsNotNone(output)
+
+    def test_convert(self) -> None:
+        static_m = MultiheadAttention(embed_dims=100, num_heads=10)
+        dynamic_m = DynamicMultiheadAttention.convert_from(static_m)
+        self.assertIsNotNone(dynamic_m)
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import pytest
+import torch.nn as nn
+from torch.nn import Sequential
+
+from mmrazor.models.architectures.dynamic_ops import DynamicSequential
+from mmrazor.models.mutables import OneShotMutableValue
+
+
+class TestDynamicSequential(TestCase):
+
+    def setUp(self) -> None:
+        self.layers = [
+            nn.Linear(4, 5),
+            nn.Linear(5, 6),
+            nn.Linear(6, 7),
+            nn.Linear(7, 8),
+        ]
+        self.dynamic_m = DynamicSequential(*self.layers)
+        mutable_depth = OneShotMutableValue(
+            value_list=[2, 3, 4], default_value=3)
+
+        self.dynamic_m.register_mutable_attr('depth', mutable_depth)
+
+    def test_init(self) -> None:
+        self.assertEqual(
+            self.dynamic_m.get_mutable_attr('depth').current_choice, 3)
+
+    def test_to_static_op(self) -> None:
+        with pytest.raises(RuntimeError):
+            self.dynamic_m.to_static_op()
+
+        current_mutable = self.dynamic_m.get_mutable_attr('depth')
+        current_mutable.fix_chosen(current_mutable.dump_chosen().chosen)
+
+        static_op = self.dynamic_m.to_static_op()
+        self.assertIsNotNone(static_op)
+
+    def test_convert_from(self) -> None:
+        static_m = Sequential(*self.layers)
+
+        dynamic_m = DynamicSequential.convert_from(static_m)
+
+        self.assertIsNotNone(dynamic_m)
@@ -0,0 +1,53 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import pytest
+from mmcls.models.utils import PatchEmbed
+
+from mmrazor.models.architectures.dynamic_ops import DynamicPatchEmbed
+from mmrazor.models.mutables import SquentialMutableChannel
+
+
+class TestPatchEmbed(TestCase):
+
+    def setUp(self):
+        self.dynamic_embed = DynamicPatchEmbed(
+            img_size=224, in_channels=3, embed_dims=100)
+
+        mutable_embed_dims = SquentialMutableChannel(num_channels=100)
+
+        mutable_embed_dims.current_choice = 50
+        self.dynamic_embed.register_mutable_attr('embed_dims',
+                                                 mutable_embed_dims)
+
+    def test_patch_embed(self):
+        mutable = SquentialMutableChannel(num_channels=120)
+
+        with pytest.raises(ValueError):
+            self.dynamic_embed.register_mutable_attr('embed_dims', mutable)
+
+        self.assertTrue(
+            self.dynamic_embed.get_mutable_attr('embed_dims').current_choice
+            == 50)
+
+    def test_convert(self):
+        static_m = PatchEmbed(img_size=224, in_channels=3, embed_dims=768)
+
+        dynamic_m = DynamicPatchEmbed.convert_from(static_m)
+
+        self.assertIsNotNone(dynamic_m)
+
+    def test_to_static_op(self):
+        mutable_embed_dims = SquentialMutableChannel(num_channels=100)
+
+        mutable_embed_dims.current_choice = 10
+
+        with pytest.raises(RuntimeError):
+            self.dynamic_embed.to_static_op()
+
+        mutable_embed_dims.fix_chosen(mutable_embed_dims.dump_chosen().chosen)
+        self.dynamic_embed.register_mutable_attr('embed_dims',
+                                                 mutable_embed_dims)
+        static_op = self.dynamic_embed.to_static_op()
+
+        self.assertIsNotNone(static_op)
@@ -0,0 +1,45 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import pytest
+from torch.nn import LayerNorm
+
+from mmrazor.models.architectures.dynamic_ops import DynamicLayerNorm
+from mmrazor.models.mutables import SquentialMutableChannel
+
+
+class TestDynamicLayerNorm(TestCase):
+
+    def setUp(self) -> None:
+        self.dynamic_m = DynamicLayerNorm(100)
+
+        mutable_num_features = SquentialMutableChannel(num_channels=100)
+
+        mutable_num_features.current_choice = 50
+
+        self.dynamic_m.register_mutable_attr('num_features',
+                                             mutable_num_features)
+
+    def test_init(self) -> None:
+        mutable = SquentialMutableChannel(num_channels=100)
+        self.dynamic_m.register_mutable_attr('in_channels', mutable)
+        self.dynamic_m.register_mutable_attr('out_channels', mutable)
+
+        self.assertEqual(
+            self.dynamic_m.get_mutable_attr('num_features').current_choice,
+            50)
+
+    def test_to_static_op(self):
+        with pytest.raises(RuntimeError):
+            self.dynamic_m.to_static_op()
+
+        current_mutable = self.dynamic_m.get_mutable_attr('num_features')
+        current_mutable.fix_chosen(current_mutable.dump_chosen().chosen)
+        static_op = self.dynamic_m.to_static_op()
+
+        self.assertIsNotNone(static_op)
+
+    def test_convert(self) -> None:
+        static_m = LayerNorm(100)
+        dynamic_m = DynamicLayerNorm.convert_from(static_m)
+
+        self.assertIsNotNone(dynamic_m)
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import pytest
+import torch
+
+from mmrazor.models.architectures.dynamic_ops import DynamicRelativePosition2D
+from mmrazor.models.architectures.ops import RelativePosition2D
+from mmrazor.models.mutables import SquentialMutableChannel
+
+
+class TestDynamicRP(TestCase):
+
+    def setUp(self) -> None:
+        mutable_head_dims = SquentialMutableChannel(num_channels=8)
+
+        self.dynamic_rp = DynamicRelativePosition2D(
+            head_dims=8, max_relative_position=14)
+
+        mutable_head_dims.current_choice = 6
+        self.dynamic_rp.register_mutable_attr('head_dims', mutable_head_dims)
+
+    def test_mutable_attrs(self) -> None:
+
+        assert self.dynamic_rp.mutable_head_dims.current_choice == 6
+
+        embed = self.dynamic_rp.forward(14, 14)
+
+        self.assertIsNotNone(embed)
+
+    def test_convert(self):
+        static_model = RelativePosition2D(
+            head_dims=10, max_relative_position=14)
+
+        dynamic_model = DynamicRelativePosition2D.convert_from(static_model)
+
+        self.assertIsNotNone(dynamic_model)
+
+    def test_to_static_op(self):
+        with pytest.raises(RuntimeError):
+            static_m = self.dynamic_rp.to_static_op()
+
+        mutable = SquentialMutableChannel(num_channels=8)
+        mutable.current_choice = 4
+
+        mutable.fix_chosen(mutable.dump_chosen().chosen)
+
+        self.dynamic_rp.register_mutable_attr('head_dims', mutable)
+        static_m = self.dynamic_rp.to_static_op()
+
+        self.assertIsNotNone(static_m)
+
+        dynamic_output = self.dynamic_rp.forward(14, 14)
+        static_output = static_m.forward(14, 14)
+        self.assertTrue(torch.equal(dynamic_output, static_output))
@@ -0,0 +1,45 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+from mmrazor.models import SearchableImageClassifier
+
+
+class TestSearchableImageClassifier(TestCase):
+
+    def test_init(self):
+
+        arch_setting = dict(
+            mlp_ratios=[3.0, 3.5, 4.0],
+            num_heads=[8, 9, 10],
+            depth=[14, 15, 16],
+            embed_dims=[528, 576, 624])
+
+        supernet_kwargs = dict(
+            backbone=dict(
+                _scope_='mmrazor',
+                type='AutoformerBackbone',
+                arch_setting=arch_setting),
+            neck=None,
+            head=dict(
+                _scope_='mmrazor',
+                type='DynamicLinearClsHead',
+                num_classes=1000,
+                in_channels=624,
+                loss=dict(
+                    type='mmcls.LabelSmoothLoss',
+                    mode='original',
+                    num_classes=1000,
+                    label_smooth_val=0.1,
+                    loss_weight=1.0),
+                topk=(1, 5)),
+            connect_head=dict(connect_with_backbone='backbone.last_mutable'),
+        )
+
+        supernet = SearchableImageClassifier(**supernet_kwargs)
+
+        # test connect_with_backbone
+        self.assertEqual(
+            supernet.backbone.last_mutable.activated_channels,
+            len(
+                supernet.head.fc.get_mutable_attr(
+                    'in_channels').current_choice))
@@ -123,6 +123,27 @@ class TestDerivedMutable(TestCase):
         mv.current_choice = 120
         assert mv_derived.current_choice == 16

+        mc_derived = mc // 8.0
+        assert mc_derived.source_mutables == {mc}
+
+        mc.current_choice = 128.
+        assert mc_derived.current_choice == 16
+        assert torch.equal(mc_derived.current_mask,
+                           torch.ones(16, dtype=torch.bool))
+        mc.current_choice = 120.
+        assert mc_derived.current_choice == 16
+        assert torch.equal(mc_derived.current_mask,
+                           torch.ones(16, dtype=torch.bool))
+
+        mv = OneShotMutableValue(value_list=[112, 120, 128])
+        mv_derived = mv // 8.0
+        assert mv_derived.source_mutables == {mv}
+
+        mv.current_choice = 128.
+        assert mv_derived.current_choice == 16
+        mv.current_choice = 120.
+        assert mv_derived.current_choice == 16
+
     def test_source_mutables(self) -> None:

         def useless_fn(x):
@@ -207,6 +228,43 @@ class TestDerivedMutable(TestCase):
             derived_e.current_mask,
             torch.tensor([1, 0, 1, 1, 1, 1, 0], dtype=torch.bool))

+    def test_mutable_channel_value_calculation(self) -> None:
+        mc = SquentialMutableChannel(num_channels=10)
+        mv = OneShotMutableValue(value_list=[2.0, 2.5, 3.0, 3.5])
+        derived_mutable = mc * mv
+        assert derived_mutable.source_mutables == {mv, mc}
+
+        mc.current_choice = 6
+        mv.current_choice = 3.5
+        assert derived_mutable.current_choice == 21
+
+        mc.current_choice = 9
+        mv.current_choice = 3.5
+        assert derived_mutable.current_choice == 31
+
+        mc.current_choice = 7
+        mv.current_choice = 2.5
+        assert derived_mutable.current_choice == 17
+
+        assert isinstance(derived_mutable, BaseMutable)
+        assert isinstance(derived_mutable, DerivedMutable)
+        assert not derived_mutable.is_fixed
+
+        mc.current_choice = mc.num_channels
+        mv.current_choice = mv.min_choice
+        assert derived_mutable.current_choice == \
+            mv.current_choice * mc.num_channels
+        mv.current_choice = mv.max_choice
+        assert derived_mutable.current_choice == \
+            mv.current_choice * mc.current_choice
+
+        with pytest.raises(RuntimeError):
+            derived_mutable.is_fixed = True
+        mc.fix_chosen(mc.dump_chosen().chosen)
+        assert not derived_mutable.is_fixed
+        mv.fix_chosen(mv.dump_chosen().chosen)
+        assert derived_mutable.is_fixed
+
+
 @pytest.mark.parametrize('expand_ratio', [1, 2, 3])
 def test_derived_expand_mutable(expand_ratio: int) -> None:
@@ -232,3 +290,29 @@ def test_derived_expand_mutable(expand_ratio: int) -> None:

     mv.current_choice = 5
     assert mv_derived.current_choice == 5 * expand_ratio
+
+
+@pytest.mark.parametrize('expand_ratio', [1.5, 2.0, 2.5])
+def test_derived_expand_mutable_float(expand_ratio: float) -> None:
+    mv = OneShotMutableValue(value_list=[3, 5, 7])
+
+    mv_derived = mv * expand_ratio
+    assert mv_derived.source_mutables == {mv}
+
+    assert isinstance(mv_derived, BaseMutable)
+    assert isinstance(mv_derived, DerivedMutable)
+    assert not mv_derived.is_fixed
+    assert mv_derived.num_choices == 1
+
+    mv.current_choice = mv.max_choice
+    assert mv_derived.current_choice == int(mv.current_choice * expand_ratio)
+    mv.current_choice = mv.min_choice
+    assert mv_derived.current_choice == int(mv.current_choice * expand_ratio)
+
+    with pytest.raises(RuntimeError):
+        mv_derived.current_choice = 123
+    with pytest.raises(RuntimeError):
+        _ = mv_derived.current_mask
+
+    mv.current_choice = 5
+    assert mv_derived.current_choice == int(5 * expand_ratio)
@@ -3,7 +3,8 @@ from unittest import TestCase

 import torch

-from mmrazor.models.mutables import SquentialMutableChannel
+from mmrazor.models.mutables import (OneShotMutableValue,
+                                     SquentialMutableChannel)


 class TestSquentialMutableChannel(TestCase):
@@ -41,3 +42,16 @@ class TestSquentialMutableChannel(TestCase):
         channel = SquentialMutableChannel(10, choice_mode='ratio')
         self._test_mutable(channel, 0.5, 0.5, 5, self._generate_mask(5, 10))
         self._test_mutable(channel, 2, 0.2, 2, self._generate_mask(2, 10))
+
+    def test_mutable_channel_mul(self):
+        channel = SquentialMutableChannel(2)
+        self.assertEqual(channel.current_choice, 2)
+        mv = OneShotMutableValue(value_list=[1, 2, 3], default_value=3)
+        derived1 = channel * mv
+        derived2 = mv * channel
+        assert derived1.current_choice == 6
+        assert derived2.current_choice == 6
+        mv.current_choice = mv.min_choice
+        assert derived1.current_choice == 2
+        assert derived2.current_choice == 2
+        assert torch.equal(derived1.current_mask, derived2.current_mask)
@@ -2,6 +2,8 @@
 from unittest import TestCase

 from mmrazor.models.mutables import OneShotMutableChannelUnit
+from mmrazor.models.mutators.channel_mutator import ChannelMutator
+from .....data.models import DynamicAttention


 class TestSequentialMutableChannelUnit(TestCase):
@@ -14,3 +16,20 @@ class TestSequentialMutableChannelUnit(TestCase):
         unit = OneShotMutableChannelUnit(
             48, [0.3, 0.5, 0.7], choice_mode='ratio', divisor=8)
         self.assertSequenceEqual(unit.candidate_choices, [1 / 3, 0.5, 2 / 3])
+
+    def test_unit_predefined(self):
+        model = DynamicAttention()
+        mutator = ChannelMutator(
+            channel_unit_cfg={
+                'type': 'OneShotMutableChannelUnit',
+                'default_args': {
+                    'unit_predefined': False
+                }
+            },
+            parse_cfg={'type': 'Predefined'})
+        mutator.prepare_from_supernet(model)
+        choices = mutator.sample_choices()
+        mutator.set_choices(choices)
+        self.assertSequenceEqual(mutator.units[0].candidate_choices,
+                                 [576, 624])
+        self.assertSequenceEqual(mutator.units[1].candidate_choices, [64])
@@ -73,9 +73,6 @@ class TestMutableValue(TestCase):
         assert mul_derived_mv.current_choice == 4
         assert rmul_derived_mv.current_choice == 4

-        with pytest.raises(TypeError):
-            _ = mv * 1.2
-
         mv = MutableValue(value_list=[1, 2, 3], default_value=3)
         mc = SquentialMutableChannel(num_channels=4)
@@ -114,9 +111,6 @@ class TestMutableValue(TestCase):
         mv.current_choice = 136
         assert derived_mv.current_choice == 18

-        with pytest.raises(TypeError):
-            _ = mv // 1.2
-
     def test_repr(self) -> None:
         value_list = [2, 4, 6]
         mv = MutableValue(value_list=value_list)
@@ -10,7 +10,7 @@ from mmrazor.models.mutables.mutable_channel import (
     L1MutableChannelUnit, SequentialMutableChannelUnit)
 from mmrazor.models.mutators.channel_mutator import ChannelMutator
 from mmrazor.registry import MODELS
-from ...data.models import DynamicLinearModel
+from ...data.models import DynamicAttention, DynamicLinearModel
 from ...test_core.test_graph.test_graph import TestGraph

 sys.setrecursionlimit(2000)
@@ -135,6 +135,30 @@ class TestChannelMutator(unittest.TestCase):
         mutator.prepare_from_supernet(model)
         self._test_a_mutator(mutator, model)
 
+    def test_models_with_predefined_dynamic_op_without_pruning(self):
+        for Model in [
+                DynamicAttention,
+        ]:
+            with self.subTest(model=Model):
+                model = Model()
+                mutator = ChannelMutator(
+                    channel_unit_cfg={
+                        'type': 'OneShotMutableChannelUnit',
+                        'default_args': {
+                            'unit_predefined': True
+                        }
+                    },
+                    parse_cfg={'type': 'Predefined'})
+                mutator.prepare_from_supernet(model)
+                choices = mutator.sample_choices()
+                mutator.set_choices(choices)
+                self.assertGreater(len(mutator.mutable_units), 0)
+                x = torch.rand([2, 3, 224, 224])
+                y = model(x)
+                self.assertEqual(
+                    list(y.shape),
+                    [2, list(mutator.current_choices.values())[0]])
+
     def test_custom_group(self):
         ARCHITECTURE_CFG = dict(
             type='mmcls.ImageClassifier',
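Note: both new tests drive the ChannelMutator through the same predefined parsing path; only unit_predefined differs, and the shape assertion ties the classifier output width to whichever choice was sampled for the first channel unit. Condensed, the configuration under test is (values copied from the hunks; the comment is an inference):

# unit_predefined=True: take the channel units already registered on the
# dynamic ops; unit_predefined=False: build the units while parsing.
channel_unit_cfg = {
    'type': 'OneShotMutableChannelUnit',
    'default_args': {
        'unit_predefined': True  # False in test_unit_predefined above
    }
}
parse_cfg = {'type': 'Predefined'}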
@@ -0,0 +1,33 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import unittest
+
+import torch
+
+from mmrazor.models.mutables import MutableValue
+from mmrazor.models.mutators import DynamicValueMutator
+from ...data.models import DynamicAttention
+
+
+class TestValueMutator(unittest.TestCase):
+
+    def test_models_with_predefined_dynamic_op(self):
+        for Model in [
+                DynamicAttention,
+        ]:
+            with self.subTest(model=Model):
+                model = Model()
+                value_mutator = DynamicValueMutator()
+                value_mutator.prepare_from_supernet(model)
+                value_choices = value_mutator.sample_choices()
+                value_mutator.set_choices(value_choices)
+
+                mutable_value_space = []
+                for mutable_value, module in model.named_modules():
+                    if isinstance(module, MutableValue):
+                        mutable_value_space.append(mutable_value)
+                assert len(
+                    value_mutator.search_groups) == len(mutable_value_space)
+
+                x = torch.rand([2, 3, 224, 224])
+                y = model(x)
+                self.assertEqual(list(y.shape), [2, 624])
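Note: the invariant asserted above is one search group per MutableValue found in the supernet. Be aware that named_modules() yields (name, module) pairs, so the loop variable mutable_value is actually the module's name; the test only uses its count, so the assertion still holds. The same census, written directly (hypothetical helper for illustration):

import torch.nn as nn


def count_instances(model: nn.Module, cls) -> int:
    # named_modules() yields (name, module) pairs; count matching modules.
    return sum(1 for _, module in model.named_modules()
               if isinstance(module, cls))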
@@ -10,7 +10,27 @@ class TestCandidates(TestCase):
 
     def setUp(self) -> None:
         self.fake_subnet = {'1': 'choice1', '2': 'choice2'}
-        self.fake_subnet_with_score = (self.fake_subnet, 1.)
+        self.fake_subnet_with_resource = {
+            str(self.fake_subnet): {
+                'score': 0.,
+                'flops': 50.,
+                'params': 0.,
+                'latency': 0.
+            }
+        }
+        self.fake_subnet_with_score = {
+            str(self.fake_subnet): {
+                'score': 99.,
+                'flops': 0.,
+                'params': 0.,
+                'latency': 0.
+            }
+        }
+        self.has_flops_network = {
+            str(self.fake_subnet): {
+                'flops': 50.,
+            }
+        }
 
     def test_init(self):
         # initlist is None
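Note: a candidate is now stored as Dict[str, Dict]: the stringified subnet maps to one value per indicator ('score', 'flops', 'params', 'latency'). A sketch of per-indicator access over that layout (hypothetical helper, not the Candidates source; the -1 default matches the placeholder asserted in test_init below):

def resources(candidates, indicator='flops'):
    values = []
    for candidate in candidates:
        # Each candidate holds a single subnet -> indicator-dict entry.
        indicator_dict = next(iter(candidate.values()))
        values.append(indicator_dict.get(indicator, -1))
    return values


fake_subnet = {'1': 'choice1', '2': 'choice2'}
has_flops_network = {str(fake_subnet): {'flops': 50.}}
assert resources([has_flops_network] * 2) == [50., 50.]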
@@ -23,16 +43,25 @@ class TestCandidates(TestCase):
         # initlist is UserList
         data = UserList([self.fake_subnet] * 2)
         self.assertEqual(len(candidates.data), 2)
+        self.assertEqual(candidates.resources('flops'), [-1, -1])
+        # initlist is list(Dict[str, Dict])
+        candidates = Candidates([self.has_flops_network] * 2)
+        self.assertEqual(candidates.resources('flops'), [50., 50.])
 
     def test_scores(self):
         # test property: scores
         data = [self.fake_subnet_with_score] * 2
         candidates = Candidates(data)
-        self.assertEqual(candidates.scores, [1., 1.])
+        self.assertEqual(candidates.scores, [99., 99.])
+
+    def test_resources(self):
+        data = [self.fake_subnet_with_resource] * 2
+        candidates = Candidates(data)
+        self.assertEqual(candidates.resources('flops'), [50., 50.])
 
     def test_subnets(self):
         # test property: subnets
-        data = [self.fake_subnet_with_score] * 2
+        data = [self.fake_subnet] * 2
         candidates = Candidates(data)
         self.assertEqual(candidates.subnets, [self.fake_subnet] * 2)
 
@@ -41,17 +70,20 @@
         candidates = Candidates()
         candidates.append(self.fake_subnet)
         self.assertEqual(len(candidates), 1)
-        # item is tuple
+        # item is List
         candidates = Candidates()
-        candidates.append(self.fake_subnet_with_score)
-        self.assertEqual(len(candidates), 1)
+        candidates.append([self.fake_subnet_with_score])
+        # item is Candidates
+        candidates_2 = Candidates([self.fake_subnet_with_resource])
+        candidates.append(candidates_2)
+        self.assertEqual(len(candidates), 2)
 
     def test_insert(self):
         # item is dict
-        candidates = Candidates([self.fake_subnet_with_score])
+        candidates = Candidates(self.fake_subnet_with_score)
         candidates.insert(1, self.fake_subnet)
         self.assertEqual(len(candidates), 2)
-        # item is tuple
+        # item is List
         candidates = Candidates([self.fake_subnet_with_score])
         candidates.insert(1, self.fake_subnet_with_score)
         self.assertEqual(len(candidates), 2)
@@ -61,13 +93,60 @@
         candidates = Candidates([self.fake_subnet_with_score])
         candidates.extend([self.fake_subnet])
         self.assertEqual(len(candidates), 2)
-        # other is UserList
+        # other is Candidates
         candidates = Candidates([self.fake_subnet_with_score])
-        candidates.extend(UserList([self.fake_subnet_with_score]))
+        candidates_2 = Candidates([self.fake_subnet_with_resource])
+        candidates.extend(candidates_2)
         self.assertEqual(len(candidates), 2)
 
-    def test_set_score(self):
-        # test set_score
+    def test_set_resource(self):
+        # test set_resource
+        candidates = Candidates([self.fake_subnet])
+        for kk in ['flops', 'params', 'latency']:
+            self.assertEqual(candidates.resources(kk)[0], -1)
+            candidates.set_resource(0, 49.9, kk)
+            self.assertEqual(candidates.resources(kk)[0], 49.9)
+        candidates.insert(0, self.fake_subnet_with_resource)
+        self.assertEqual(len(candidates), 2)
+        self.assertEqual(candidates.resources('flops'), [50., 49.9])
+        self.assertEqual(candidates.resources('latency'), [0., 49.9])
         candidates = Candidates([self.fake_subnet_with_score])
-        candidates.set_score(0, 0.5)
-        self.assertEqual(candidates[0][1], 0.5)
+        candidates.set_resource(0, 100.0, 'score')
+        self.assertEqual(candidates.scores[0], 100.)
+        candidates = Candidates([self.fake_subnet_with_score])
+        candidates.set_resource(0, 100.0, 'score')
+        candidates.extend(UserList([self.fake_subnet_with_resource]))
+        candidates.set_resource(1, 99.9, 'score')
+        self.assertEqual(candidates.scores, [100., 99.9])
+
+    def test_update_resources(self):
+        # test update_resources
+        candidates = Candidates([self.fake_subnet])
+        candidates.append([self.fake_subnet_with_score])
+        candidates_2 = Candidates(self.fake_subnet_with_resource)
+        candidates.append(candidates_2)
+        self.assertEqual(len(candidates), 3)
+        self.assertEqual(candidates.resources('flops'), [-1, 0., 50.])
+        self.assertEqual(candidates.resources('latency'), [-1, 0., 0.])
+        resources = [{'flops': -2}, {'latency': 4.}]
+        candidates.update_resources(resources, start=1)
+        self.assertEqual(candidates.resources('flops'), [-1, -2, 50.])
+        self.assertEqual(candidates.resources('latency'), [-1, 0., 4])
+        candidates.update_resources(resources, start=0)
+        self.assertEqual(candidates.resources('flops'), [-2, -2, 50.])
+        self.assertEqual(candidates.resources('latency'), [-1, 4., 4.])
+
+    def test_sort(self):
+        # test set_sort
+        candidates = Candidates([self.fake_subnet_with_score])
+        candidates.extend(UserList([self.fake_subnet_with_resource]))
+        candidates.insert(0, self.fake_subnet)
+        candidates.set_resource(0, 100., 'score')
+        candidates.set_resource(2, 98., 'score')
+        self.assertEqual(candidates.scores, [100., 99., 98.])
+        candidates.sort_by(key_indicator='score', reverse=False)
+        self.assertEqual(candidates.scores, [98., 99., 100.])
+        candidates.sort_by(key_indicator='latency')
+        self.assertEqual(candidates.scores, [98., 99., 100.])
+        candidates.sort_by(key_indicator='flops', reverse=False)
+        self.assertEqual(candidates.scores, [100., 99., 98.])
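Note: sort_by generalizes the old score-only ordering: any indicator can serve as the key, and the test implies a descending default (reverse=False is passed to get ascending order). A rough equivalent over plain dicts (hypothetical, not the mmrazor API):

def sort_by(rows, key_indicator='score', reverse=True):
    rows.sort(key=lambda row: row[key_indicator], reverse=reverse)


rows = [{'score': 100.}, {'score': 99.}, {'score': 98.}]
sort_by(rows, key_indicator='score', reverse=False)
assert [row['score'] for row in rows] == [98., 99., 100.]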
@@ -82,7 +82,7 @@ class TestEvolutionSearchLoop(TestCase):
             num_mutation=2,
             num_crossover=2,
             mutate_prob=0.1,
-            flops_range=None,
+            constraints_range=dict(flops=(0, 330)),
             score_key='coco/bbox_mAP')
         self.train_cfg = Config(train_cfg)
         self.runner = MagicMock(spec=ToyRunner)
@@ -103,7 +103,7 @@
 
         # test init_candidates is not None
         fake_subnet = {'1': 'choice1', '2': 'choice2'}
-        fake_candidates = Candidates((fake_subnet, 0.))
+        fake_candidates = Candidates(fake_subnet)
        init_candidates_path = os.path.join(self.temp_dir, 'candidates.yaml')
         fileio.dump(fake_candidates, init_candidates_path)
         loop_cfg.init_candidates = init_candidates_path
@@ -111,8 +111,12 @@
         self.assertIsInstance(loop, EvolutionSearchLoop)
         self.assertEqual(loop.candidates, fake_candidates)
 
-    @patch('mmrazor.engine.runner.evolution_search_loop.export_fix_subnet')
-    def test_run_epoch(self, mock_export_fix_subnet):
+    @patch('mmrazor.engine.runner.utils.check.load_fix_subnet')
+    @patch('mmrazor.engine.runner.utils.check.export_fix_subnet')
+    @patch('mmrazor.models.task_modules.estimators.resource_estimator.'
+           'get_model_flops_params')
+    def test_run_epoch(self, flops_params, mock_export_fix_subnet,
+                       load_status):
         # test_run_epoch: distributed == False
         loop_cfg = copy.deepcopy(self.train_cfg)
         loop_cfg.runner = self.runner
@@ -120,20 +124,20 @@
         loop_cfg.evaluator = self.evaluator
         loop = LOOPS.build(loop_cfg)
         self.runner.rank = 0
-        loop._epoch = 1
         self.runner.distributed = False
         self.runner.work_dir = self.temp_dir
         fake_subnet = {'1': 'choice1', '2': 'choice2'}
-        self.runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
+        loop.model.sample_subnet = MagicMock(return_value=fake_subnet)
+        load_status.return_value = True
+        flops_params.return_value = 0, 0
         loop.run_epoch()
         self.assertEqual(len(loop.candidates), 4)
         self.assertEqual(len(loop.top_k_candidates), 2)
-        self.assertEqual(loop._epoch, 2)
+        self.assertEqual(loop._epoch, 1)
 
         # test_run_epoch: distributed == True
         loop = LOOPS.build(loop_cfg)
         self.runner.rank = 0
-        loop._epoch = 1
         self.runner.distributed = True
         self.runner.work_dir = self.temp_dir
         fake_subnet = {'1': 'choice1', '2': 'choice2'}
@@ -141,26 +145,27 @@
         loop.run_epoch()
         self.assertEqual(len(loop.candidates), 4)
         self.assertEqual(len(loop.top_k_candidates), 2)
-        self.assertEqual(loop._epoch, 2)
+        self.assertEqual(loop._epoch, 1)
 
         # test_check_constraints
-        loop_cfg.flops_range = (0, 100)
+        loop_cfg.constraints_range = dict(params=(0, 100))
         loop = LOOPS.build(loop_cfg)
         self.runner.rank = 0
-        loop._epoch = 1
         self.runner.distributed = True
         self.runner.work_dir = self.temp_dir
         fake_subnet = {'1': 'choice1', '2': 'choice2'}
         loop.model.sample_subnet = MagicMock(return_value=fake_subnet)
-        loop._check_constraints = MagicMock(return_value=True)
+        flops_params.return_value = (50., 1)
         mock_export_fix_subnet.return_value = fake_subnet
         loop.run_epoch()
         self.assertEqual(len(loop.candidates), 4)
         self.assertEqual(len(loop.top_k_candidates), 2)
-        self.assertEqual(loop._epoch, 2)
+        self.assertEqual(loop._epoch, 1)
 
-    @patch('mmrazor.engine.runner.evolution_search_loop.export_fix_subnet')
-    def test_run(self, mock_export_fix_subnet):
+    @patch('mmrazor.engine.runner.utils.check.export_fix_subnet')
+    @patch('mmrazor.models.task_modules.estimators.resource_estimator.'
+           'get_model_flops_params')
+    def test_run_loop(self, mock_flops, mock_export_fix_subnet):
         # test a new search: resume == None
         loop_cfg = copy.deepcopy(self.train_cfg)
         loop_cfg.runner = self.runner
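Note: the expected _epoch after a single run_epoch() drops from 2 to 1 because the rewritten tests no longer pre-seed loop._epoch = 1; a freshly built loop starts at zero. The constraint path is also exercised for real now: instead of stubbing _check_constraints, the tests patch get_model_flops_params so check_subnet_resources runs against deterministic numbers.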
@@ -169,16 +174,26 @@
         loop = LOOPS.build(loop_cfg)
         self.runner.rank = 0
         loop._epoch = 1
 
         fake_subnet = {'1': 'choice1', '2': 'choice2'}
         self.runner.work_dir = self.temp_dir
+        loop.update_candidate_pool = MagicMock()
+        loop.val_candidate_pool = MagicMock()
 
+        mutation_candidates = Candidates([fake_subnet] * loop.num_mutation)
+        for i in range(loop.num_mutation):
+            mutation_candidates.set_resource(i, 0.1 + 0.1 * i, 'flops')
+            mutation_candidates.set_resource(i, 99 + i, 'score')
+        crossover_candidates = Candidates([fake_subnet] * loop.num_crossover)
+        for i in range(loop.num_crossover):
+            crossover_candidates.set_resource(i, 0.1 + 0.1 * i, 'flops')
+            crossover_candidates.set_resource(i, 99 + i, 'score')
         loop.gen_mutation_candidates = \
-            MagicMock(return_value=[fake_subnet]*loop.num_mutation)
+            MagicMock(return_value=mutation_candidates)
         loop.gen_crossover_candidates = \
-            MagicMock(return_value=[fake_subnet]*loop.num_crossover)
-        loop.top_k_candidates = Candidates([(fake_subnet, 1.0),
-                                            (fake_subnet, 0.9)])
+            MagicMock(return_value=crossover_candidates)
+        loop.candidates = Candidates([fake_subnet] * 4)
+        mock_flops.return_value = (0.5, 101)
         mock_export_fix_subnet.return_value = fake_subnet
         loop.run()
         assert os.path.exists(
@@ -119,7 +119,7 @@ class TestGreedySamplerTrainLoop(TestCase):
             max_iters=12,
             val_interval=2,
             score_key='acc',
-            flops_range=None,
+            constraints_range=None,
             num_candidates=4,
             num_samples=2,
             top_k=2,
@@ -190,7 +190,7 @@
         loop._iter = loop.val_interval
         subnet = loop.sample_subnet()
         self.assertEqual(subnet, fake_subnet)
-        self.assertEqual(len(loop.top_k_candidates), loop.top_k - 1)
+        self.assertEqual(len(loop.top_k_candidates), loop.top_k)
 
     def test_run(self):
         # test run with _check_constraints
@@ -200,7 +200,7 @@
         fake_subnet = {'1': 'choice1', '2': 'choice2'}
         runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
         loop = runner.build_train_loop(cfg.train_cfg)
-        loop._check_constraints = MagicMock(return_value=True)
+        loop._check_constraints = MagicMock(return_value=(True, dict()))
         runner.train()
 
         self.assertEqual(runner.iter, runner.max_iters)
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from unittest.mock import patch
 
-from mmrazor.engine.runner.utils import check_subnet_flops
+from mmrazor.engine.runner.utils import check_subnet_resources
 
 try:
     from mmdet.models.detectors import BaseDetector
@@ -12,29 +12,33 @@ except ImportError:
 
 @patch('mmrazor.models.ResourceEstimator')
 @patch('mmrazor.models.SPOS')
-def test_check_subnet_flops(mock_model, mock_estimator):
-    # flops_range = None
-    flops_range = None
+def test_check_subnet_resources(mock_model, mock_estimator):
+    # constraints_range = dict()
+    constraints_range = dict()
     fake_subnet = {'1': 'choice1', '2': 'choice2'}
-    result = check_subnet_flops(mock_model, fake_subnet, mock_estimator,
-                                flops_range)
-    assert result is True
+    is_pass, _ = check_subnet_resources(mock_model, fake_subnet,
+                                        mock_estimator, constraints_range)
+    assert is_pass is True
 
-    # flops_range is not None
+    # constraints_range is not None
     # architecture is BaseDetector
-    flops_range = (0., 100.)
+    constraints_range = dict(flops=(0, 330))
     mock_model.architecture = BaseDetector
     fake_results = {'flops': 50.}
     mock_estimator.estimate.return_value = fake_results
-    result = check_subnet_flops(mock_model, fake_subnet, mock_estimator,
-                                flops_range)
-    assert result is True
+    is_pass, _ = check_subnet_resources(
+        mock_model,
+        fake_subnet,
+        mock_estimator,
+        constraints_range,
+    )
+    assert is_pass is True
 
-    # flops_range is not None
+    # constraints_range is not None
     # architecture is BaseDetector
-    flops_range = (0., 100.)
+    constraints_range = dict(flops=(0, 330))
     fake_results = {'flops': -50.}
     mock_estimator.estimate.return_value = fake_results
-    result = check_subnet_flops(mock_model, fake_subnet, mock_estimator,
-                                flops_range)
-    assert result is False
+    is_pass, _ = check_subnet_resources(mock_model, fake_subnet,
+                                        mock_estimator, constraints_range)
+    assert is_pass is False
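Note: the renamed utility reports an (is_pass, results) pair instead of a bare bool, so callers can log the measured resources of a rejected subnet. A sketch of the contract the assertions imply (illustrative only, with a hypothetical name):

def check_constraints(results, constraints_range):
    # Pass when every constrained indicator lies inside its (low, high)
    # window; an empty constraints dict accepts everything.
    for key, (low, high) in constraints_range.items():
        if not low <= results.get(key, 0.) <= high:
            return False, results
    return True, results


assert check_constraints({'flops': 50.}, {'flops': (0, 330)})[0] is True
assert check_constraints({'flops': -50.}, {'flops': (0, 330)})[0] is False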