Add GroupFisher pruning algorithm. (#459)

* init

* support expand dwconv

* add tools

* init

* add import

* add configs

* add ut and fix bug

* update

* update finetune config

* update impl imports

* add deploy configs and result

* add _train_step

* detla_type -> normalization_type

* change img link

* add prune to config

* add json dump when GroupFisherSubModel init

* update prune config

* update finetune config

* update deploy config

* update prune config

* update readme

* mutable_cfg -> fix_subnet

* update readme

* impl -> implementations

* update script.sh

* rm gen_fake_cfg

* add Implementation to readme

* update docstring

* add finetune_lr to config

* update readme

* fix error in config

* update links

* update configs

* refine

* fix spell error

* add test to readme

* update README

* update readme

* update readme

* update cite format

* fix for ci

* update to pass ci

* update readme

---------

Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: Your Name <you@example.com>
pull/292/merge
LKJacky 2023-02-20 14:29:42 +08:00 committed by GitHub
parent 18754f3599
commit 7acc046678
71 changed files with 3087 additions and 22 deletions

2
.gitignore vendored

@ -121,3 +121,5 @@ venv.bak/
# Srun
*.out
batchscript-*
work_dir
mmdeploy


@ -68,4 +68,5 @@ repos:
^test
| ^docs
| ^configs
| ^.*/configs*
)


@ -0,0 +1,214 @@
# Group_fisher pruning
> [Group Fisher Pruning for Practical Network Compression.](https://arxiv.org/pdf/2108.00708.pdf)
## Abstract
Network compression has been widely studied since it is able to reduce the memory and computation cost during inference. However, previous methods seldom deal with complicated structures like residual connections, group/depthwise convolution and feature pyramid network, where channels of multiple layers are coupled and need to be pruned simultaneously. In this paper, we present a general channel pruning approach that can be applied to various complicated structures. Particularly, we propose a layer grouping algorithm to find coupled channels automatically. Then we derive a unified metric based on Fisher information to evaluate the importance of a single channel and coupled channels. Moreover, we find that inference speedup on GPUs is more correlated with the reduction of memory rather than FLOPs, and thus we employ the memory reduction of each channel to normalize the importance. Our method can be used to prune any structures including those with coupled channels. We conduct extensive experiments on various backbones, including the classic ResNet and ResNeXt, mobile-friendly MobileNetV2, and the NAS-based RegNet, both on image classification and object detection which is under-explored. Experimental results validate that our method can effectively prune sophisticated networks, boosting inference speed without sacrificing accuracy.
![pipeline](https://github.com/jshilong/FisherPruning/blob/main/resources/structures.png?raw=true)
## Results and models
### Classification on ImageNet
| Model | Top-1 | Gap | Flop(G) | Remain(%) | Parameters(M) | Remain(%) | Config | Download |
| ------------------------ | ----- | ----- | ------- | --------- | ------------- | --------- | ------------------------------------- | ----------------------------------------------------- |
| ResNet50 | 76.55 | - | 4.11 | - | 25.6 | - | [mmcls][cls_r50_c] | [model][cls_r50_m] |
| ResNet50_pruned_act | 75.22 | -1.33 | 2.06 | 50.1% | 16.3 | 63.7% | [prune][r_a_pc] \| [finetune][r_a_fc] | [pruned][r_a_p] \| [finetuned][r_a_f] \| [log][r_a_l] |
| ResNet50_pruned_flops | 75.61 | -0.94 | 2.06 | 50.1% | 16.3 | 63.7% | [prune][r_f_pc] \| [finetune][r_f_fc] | [pruned][r_f_p] \| [finetuned][r_f_f] \| [log][r_f_l] |
| MobileNetV2 | 71.86 | - | 0.313 | - | 3.51 | - | [mmcls][cls_m_c] | [model][cls_m_m] |
| MobileNetV2_pruned_act | 70.82 | -1.04 | 0.207 | 66.1% | 3.18 | 90.6% | [prune][m_a_pc] \| [finetune][m_a_fc] | [pruned][m_a_p] \| [finetuned][m_a_f] \| [log][m_a_l] |
| MobileNetV2_pruned_flops | 70.87 | -0.99 | 0.207 | 66.1% | 2.82 | 88.7% | [prune][m_f_pc] \| [finetune][m_f_fc] | [pruned][m_f_p] \| [finetuned][m_f_f] \| [log][m_f_l] |
### Detection on COCO
| Model(Detector-Backbone) | AP | Gap | Flop(G) | Remain(%) | Parameters(M) | Remain(%) | Config | Download |
| ------------------------------ | ---- | ---- | ------- | --------- | ------------- | --------- | --------------------------------------- | -------------------------------------------------------- |
| RetinaNet-R50-FPN | 36.5 | - | 250 | - | 63.8 | - | [mmdet][det_rt_c] | [model][det_rt_m] |
| RetinaNet-R50-FPN_pruned_act | 36.5 | 0.0 | 126 | 50.4% | 34.6 | 54.2% | [prune][rt_a_pc] \| [finetune][rt_a_fc] | [pruned][rt_a_p] \| [finetuned][rt_a_f] \| [log][rt_a_l] |
| RetinaNet-R50-FPN_pruned_flops | 36.6 | +0.1 | 126 | 50.4% | 34.9 | 54.7% | [prune][rt_f_pc] \| [finetune][rt_f_fc] | [pruned][rt_f_p] \| [finetuned][rt_f_f] \| [log][rt_f_l] |
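The `Gap` and `Remain(%)` columns in the tables above are simple derived quantities. The following is a minimal sketch of that arithmetic, checked against a few rows; the helper names `remain_percent` and `gap` are illustrative and not part of mmrazor.

```python
def remain_percent(pruned: float, baseline: float) -> float:
    """Percentage of a resource (FLOPs or parameters) left after pruning."""
    return round(100.0 * pruned / baseline, 1)


def gap(pruned_metric: float, baseline_metric: float) -> float:
    """Accuracy (or AP) difference between the pruned and the original model."""
    return round(pruned_metric - baseline_metric, 2)


# ResNet50_pruned_act row: 2.06 of 4.11 GFLOPs, 16.3 of 25.6 M parameters.
print(remain_percent(2.06, 4.11))   # FLOPs remaining
print(remain_percent(16.3, 25.6))   # parameters remaining
print(gap(75.22, 76.55))            # Top-1 gap
```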
**Note**
- Because the pruning papers use different pretraining and finetuning settings, it is hard to compare them fairly. As a result, we prefer to apply the algorithms under the OpenMMLab settings.
- Consequently, the experimental results may differ from those reported in the original papers.
## Get Started
Applying GroupFisher to your model takes three steps: Prune, Finetune, and Deploy.
Note: please use torch>=1.12, as we need fxtracer to parse the models automatically.
### Prune
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_train.sh \
{config_folder}/group_fisher_{normalization_type}_prune_{model_name}.py 8 \
--work-dir $WORK_DIR
```
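The placeholders in the command above follow the file naming used by the configs in this commit. As a rough sketch, the name can be assembled as follows; the helper `group_fisher_config_name` is hypothetical and only illustrates the pattern.

```python
def group_fisher_config_name(normalization_type: str, stage: str,
                             model_name: str) -> str:
    """Build a config file name following the pattern used in this README.

    This helper is illustrative only; it is not part of mmrazor.
    """
    if normalization_type not in ('act', 'flops'):
        raise ValueError("normalization_type must be 'act' or 'flops'")
    if stage not in ('prune', 'finetune', 'deploy'):
        raise ValueError("stage must be 'prune', 'finetune' or 'deploy'")
    return f'group_fisher_{normalization_type}_{stage}_{model_name}.py'


print(group_fisher_config_name('act', 'prune', 'mobilenet-v2_8xb32_in1k'))
```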
In the pruning config file, you have to fill in the following args.
```python
"""
_base_ (str): The path to your pretrain config file.
pretrained_path (str): The path to your pretrained model checkpoint.
interval (int): Interval between pruning two channels. You should ensure you
    can reach your target pruning ratio when the training ends.
normalization_type (str): GroupFisher uses two methods to normalize the channel
    importance, including ['flops','act']. The former uses flops, while the
    latter uses the memory occupation of activation feature maps.
lr_ratio (float): Ratio to decrease the learning rate. As the pruning process
    is unstable, you need to decrease the original learning rate until the
    pruning training works steadily without getting nan.
target_flop_ratio (float): The target flop ratio to prune your model.
input_shape (Tuple): Input shape to measure the flops.
"""
```
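The prune configs in this commit set the optimizer learning rate to the base learning rate multiplied by `lr_ratio`. A minimal sketch of that scaling is shown below. The base learning rate of 0.1 for ResNet50 is an assumption, taken from the `finetune_lr` of the ResNet50 finetune config on the basis that finetuning usually reuses the pretrain learning rate; `pruning_lr` is an illustrative helper, not an mmrazor function.

```python
def pruning_lr(base_lr: float, lr_ratio: float) -> float:
    """Learning rate used during pruning: the base lr scaled down by lr_ratio."""
    if not 0.0 < lr_ratio <= 1.0:
        raise ValueError('lr_ratio is expected to be in (0, 1]')
    return base_lr * lr_ratio


# ResNet50 act-mode prune config uses lr_ratio = 0.04.
# base_lr = 0.1 is an assumed value (see the note above).
print(pruning_lr(base_lr=0.1, lr_ratio=0.04))
```

If pruning still produces nan losses, lowering `lr_ratio` further is the knob the docstring above points to.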
After the pruning process, you will get a checkpoint of the pruned model named flops\_{target_flop_ratio}.pth in your workdir.
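The `interval` argument bounds how many channels can be pruned before training ends, so it is worth a quick sanity check before launching a run. The sketch below assumes that exactly one channel is pruned every `interval` iterations, which is how the docstring describes it; the iteration and channel counts are hypothetical numbers, not measured values.

```python
def max_pruned_channels(total_iters: int, interval: int) -> int:
    """Upper bound on pruned channels, assuming one channel per interval."""
    if interval <= 0:
        raise ValueError('interval must be positive')
    return total_iters // interval


def interval_is_feasible(total_iters: int, interval: int,
                         channels_to_prune: int) -> bool:
    """Check whether the schedule is long enough to prune the needed channels."""
    return max_pruned_channels(total_iters, interval) >= channels_to_prune


# Hypothetical schedule: 100000 iterations with interval = 25.
print(max_pruned_channels(100000, 25))
print(interval_is_feasible(100000, 25, channels_to_prune=5000))
```

When the check fails, either shorten `interval` or train for more iterations so that the target flop ratio can be reached.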
### Finetune
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_train.sh \
{config_folder}/group_fisher_{normalization_type}_finetune_{model_name}.py 8 \
--work-dir $WORK_DIR
```
The finetune config file also has some args for you to fill in, as below.
```python
"""
_base_ (str): The path to your pruning config file.
pruned_path (str): The path to the checkpoint of the pruned model.
finetune_lr (float): The learning rate for finetuning. Usually, we reuse the
    learning rate of pretraining.
"""
```
After finetuning, besides a checkpoint of the best model, you also get a fix_subnet.json, which records the pruned model structure. It is used when deploying.
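To get a feel for the recorded structure, the sketch below computes the fraction of channels kept per unit. It assumes the JSON holds the same flat mapping as the `fix_subnet` dicts in the deploy configs of this commit, where each key ends with the original channel count (for example `_32`) and each value is the number of channels kept. That reading of the key format is an assumption, and `remaining_ratios` is an illustrative helper rather than an mmrazor API.

```python
import json


def remaining_ratios(fix_subnet: dict) -> dict:
    """Map each unit name to the fraction of its channels that is kept.

    Assumes keys look like '<layer>_(0, N)_N' and values are kept channels.
    """
    ratios = {}
    for unit_name, remaining in fix_subnet.items():
        original = int(unit_name.rsplit('_', 1)[1])  # trailing channel count
        ratios[unit_name] = remaining / original
    return ratios


# Two entries copied from the MobileNetV2 act-mode deploy config.
text = ('{"backbone.conv1.conv_(0, 32)_32": 21, '
        '"backbone.layer2.0.conv.2.conv_(0, 24)_24": 24}')
for name, ratio in remaining_ratios(json.loads(text)).items():
    print(f'{name}: {ratio:.3f}')
```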
### Test
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_test.sh \
{config_folder}/group_fisher_{normalization_type}_finetune_{model_name}.py {checkpoint_path} 8
```
### Deploy
First, we assume you are familiar with mmdeploy. For a pruned model, you only need to use the pruning deploy config in place of the pretrain config to deploy the pruned version of your model.
```bash
python {mmdeploy}/tools/deploy.py \
{mmdeploy}/{mmdeploy_config}.py \
{config_folder}/group_fisher_{normalization_type}_deploy_{model_name}.py \
{path_to_finetuned_checkpoint}.pth \
{mmdeploy}/tests/data/tiger.jpeg
```
The deploy config has some args as below:
```python
"""
_base_ (str): The path to your pretrain config file.
fix_subnet (Union[dict,str]): The dict that stores the pruning structure, or
    the json file containing it.
divisor (int): The divisor used to make the channel numbers divisible.
"""
```
The divisor is important for the actual inference speed, and we suggest testing the values in \[1,2,4,8,16,32\] to find the fastest one.
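To illustrate what a divisor does to channel counts, here is a minimal sketch that rounds a channel count up to the next multiple of the divisor. Rounding up is an assumption made for illustration; the exact rounding rule used by `GroupFisherDeploySubModel` is not stated in this README and may differ.

```python
def round_to_divisor(channels: int, divisor: int) -> int:
    """Round channels up to the nearest multiple of divisor (assumed rule)."""
    if channels <= 0 or divisor <= 0:
        raise ValueError('channels and divisor must be positive')
    # -(-a // b) is integer ceiling division.
    return -(-channels // divisor) * divisor


# 21 kept channels (first unit of the MobileNetV2 act-mode deploy config).
for divisor in [1, 2, 4, 8, 16, 32]:
    print(divisor, round_to_divisor(21, divisor))
```

Larger divisors give more hardware-friendly shapes but add back more channels, which is why trying several values is suggested.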
## Implementation
All the modules of GroupFisher are placed in mmrazor/implementations/pruning/group_fisher/.
| File | Module | Feature |
| -------------------- | -------------------------------------------------------------------- | --------------------------------------------------------------------------------------- |
| algorithm.py | GroupFisherAlgorithm | Decide when to prune a channel according to the interval and the current iteration. |
| mutator.py | GroupFisherChannelMutator | Select the unit containing the channel with the minimal importance and prune it. |
| unit.py | GroupFisherChannelUnit | Compute fisher info. |
| ops.py <br> counters | GroupFisherConv2d <br> GroupFisherLinear <br> corresponding counters | Collect model info to compute fisher info, including activation, grad and tensor shape. |
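As a rough illustration of how the collected activations and gradients can turn into a channel importance, the sketch below follows the idea described in the abstract: a Fisher-style score per channel, normalized by a per-channel cost such as memory or FLOPs. It is a simplified pure-Python sketch up to constant factors, not the code in `unit.py`; the data layout `[sample][channel][position]`, the function names, and the toy numbers are all assumptions.

```python
def fisher_importance(activations, grads):
    """Per-channel score: sum over samples of (sum over positions of a * g)^2.

    Both inputs are nested lists indexed as [sample][channel][position].
    """
    num_channels = len(activations[0])
    importance = [0.0] * num_channels
    for act_n, grad_n in zip(activations, grads):
        for c in range(num_channels):
            dot = sum(a * g for a, g in zip(act_n[c], grad_n[c]))
            importance[c] += dot * dot
    return importance


def normalize_importance(importance, cost_per_channel):
    """Divide each score by the resource saved when pruning that channel."""
    return [s / cost for s, cost in zip(importance, cost_per_channel)]


# Toy data: 2 samples, 2 channels, 2 spatial positions each.
acts = [[[1.0, 2.0], [0.5, 0.5]], [[0.0, 1.0], [1.0, 1.0]]]
grads = [[[0.5, 0.5], [2.0, 0.0]], [[1.0, 1.0], [0.0, 1.0]]]
raw = fisher_importance(acts, grads)
normalized = normalize_importance(raw, cost_per_channel=[2.0, 1.0])
# The channel with the smallest normalized score is the pruning candidate.
print(raw, normalized, normalized.index(min(normalized)))
```

Note how normalization changes the decision here: channel 1 has the lower raw score, but channel 0 becomes the candidate once its larger cost saving is taken into account.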
There are also some modules to support GroupFisher. These modules may be refactored and moved to other folders as common modules for all pruning algorithms.
| File | Module | Feature |
| ------------------------- | ---------------------------------------- | ------------------------------------------------------------------- |
| hook.py | PruningStructureHook<br>ResourceInfoHook | Display the pruning structure iteratively. |
| prune_sub_model.py | GroupFisherSubModel | Convert a pruning algorithm(architecture) to a pruned static model. |
| prune_deploy_sub_model.py | GroupFisherDeploySubModel | Init a pruned static model for mmdeploy. |
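The two hooks in `hook.py` are appended to `custom_hooks` by the prune configs and dropped again by the finetune configs via `custom_hooks[:-2]`. The sketch below reproduces that list handling with plain dicts; `SomeExistingHook` is a hypothetical placeholder for whatever hooks the base config already defines.

```python
# Hooks inherited from the base config (hypothetical placeholder entry).
base_hooks = [dict(type='SomeExistingHook')]

# Prune config: append the two pruning related hooks at the end.
prune_hooks = base_hooks + [
    dict(type='mmrazor.PruningStructureHook'),
    dict(type='mmrazor.ResourceInfoHook'),
]

# Finetune config: remove the last two entries to restore the base hooks.
finetune_hooks = prune_hooks[:-2]

print(finetune_hooks == base_hooks)
```

The slice only works because the pruning hooks are always the last two entries, so any extra hook added after them in a custom prune config would need a different removal rule.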
## Citation
```latex
@InProceedings{Liu:2021,
TITLE = {Group Fisher Pruning for Practical Network Compression},
AUTHOR = {Liu, Liyang
AND Zhang, Shilong
AND Kuang, Zhanghui
AND Zhou, Aojun
AND Xue, Jing-hao
AND Wang, Xinjiang
AND Chen, Yimin
AND Yang, Wenming
AND Liao, Qingmin
AND Zhang, Wayne},
BOOKTITLE = {Proceedings of the 38th International Conference on Machine Learning},
YEAR = {2021},
SERIES = {Proceedings of Machine Learning Research},
MONTH = {18--24 Jul},
PUBLISHER = {PMLR},
}
```
<!-- model links
{model}_{prune_mode}_{file type}
model: r: resnet50, m: mobilenetv2, rt:retinanet
prune_mode: a: act, f: flops
file_type: p: pruned model, f:finetuned_model, l: log, pc: prune config, fc: finetune config.
repo link
{repo}_{model}_{file type}
-->
[cls_m_c]: https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/mobilenet_v2/mobilenet-v2_8xb32_in1k.py
[cls_m_m]: https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth
[cls_r50_c]: https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/resnet/resnet50_8xb32_in1k.py
[cls_r50_m]: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth
[det_rt_c]: https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_r50_fpn_1x_coco.py
[det_rt_m]: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth
[m_a_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/act/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.pth
[m_a_fc]: ../../mmcls/group_fisher/mobilenet/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.py
[m_a_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/mobilenet/act/20230130_203443.json
[m_a_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/act/group_fisher_act_prune_mobilenet-v2_8xb32_in1k.pth
[m_a_pc]: ../../mmcls/group_fisher/mobilenet/group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py
[m_f_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.pth
[m_f_fc]: ../../mmcls/group_fisher/mobilenet/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.py
[m_f_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/20230201_211550.json
[m_f_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.pth
[m_f_pc]: ../../mmcls/group_fisher/mobilenet/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.py
[rt_a_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/act/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.pth
[rt_a_fc]: ../../mmdet/group_fisher/retinanet/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.py
[rt_a_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/retinanet/act/20230113_231904.json
[rt_a_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/act/group_fisher_act_prune_retinanet_r50_fpn_1x_coco.pth
[rt_a_pc]: ../../mmdet/group_fisher/retinanet/group_fisher_act_prune_retinanet_r50_fpn_1x_coco.py
[rt_f_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/flops/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.pth
[rt_f_fc]: ../../mmdet/group_fisher/retinanet/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.py
[rt_f_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/retinanet/flops/20230129_101502.json
[rt_f_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/flops/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.pth
[rt_f_pc]: ../../mmdet/group_fisher/retinanet/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.py
[r_a_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/act/group_fisher_act_finetune_resnet50_8xb32_in1k.pth
[r_a_fc]: ../../mmcls/group_fisher/resnet50/group_fisher_act_finetune_resnet50_8xb32_in1k.py
[r_a_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/resnet50/act/20230130_175426.json
[r_a_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/act/group_fisher_act_prune_resnet50_8xb32_in1k.pth
[r_a_pc]: ../../mmcls/group_fisher/resnet50/group_fisher_act_prune_resnet50_8xb32_in1k.py
[r_f_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/group_fisher_flops_finetune_resnet50_8xb32_in1k.pth
[r_f_fc]: ../../mmcls/group_fisher/resnet50/group_fisher_flops_finetune_resnet50_8xb32_in1k.py
[r_f_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/20230129_190931.json
[r_f_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/group_fisher_flops_prune_resnet50_8xb32_in1k.pth
[r_f_pc]: ../../mmcls/group_fisher/resnet50/group_fisher_flops_prune_resnet50_8xb32_in1k.py


@ -0,0 +1,24 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pretrain config file.
fix_subnet (Union[dict,str]): The dict that stores the pruning structure, or
    the json file containing it.
divisor (int): The divisor used to make the channel numbers divisible.
"""
_base_ = ''
fix_subnet = {}
divisor = 8
##############################################################################
architecture = _base_.model
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherDeploySubModel',
architecture=architecture,
fix_subnet=fix_subnet,
divisor=divisor,
)


@ -0,0 +1,32 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pruning config file.
pruned_path (str): The path to the checkpoint of the pruned model.
finetune_lr (float): The learning rate for finetuning. Usually, we reuse the
    learning rate of pretraining.
"""
_base_ = ''
pruned_path = ''
finetune_lr = 0.1
##############################################################################
algorithm = _base_.model
algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherSubModel',
algorithm=algorithm,
)
# restore lr
optim_wrapper = dict(optimizer=dict(lr=finetune_lr))
# remove pruning related hooks
custom_hooks = _base_.custom_hooks[:-2]
# delete ddp
model_wrapper_cfg = None


@ -0,0 +1,75 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pretrain config file.
pretrained_path (str): The path to your pretrained model checkpoint.
interval (int): Interval between pruning two channels. You should ensure you
    can reach your target pruning ratio when the training ends.
normalization_type (str): GroupFisher uses two methods to normalize the channel
    importance, including ['flops','act']. The former uses flops, while the
    latter uses the memory occupation of activation feature maps.
lr_ratio (float): Ratio to decrease the learning rate. As the pruning process
    is unstable, you need to decrease the original learning rate until the
    pruning training works steadily without getting nan.
target_flop_ratio (float): The target flop ratio to prune your model.
input_shape (Tuple): Input shape to measure the flops.
"""
_base_ = ''
pretrained_path = ''
interval = 10
normalization_type = 'act'
lr_ratio = 0.1
target_flop_ratio = 0.5
input_shape = (1, 3, 224, 224)
##############################################################################
architecture = _base_.model
if hasattr(_base_, 'data_preprocessor'):
    architecture.update({'data_preprocessor': _base_.data_preprocessor})
    data_preprocessor = None
architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path)
architecture['_scope_'] = _base_.default_scope
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherAlgorithm',
architecture=architecture,
interval=interval,
mutator=dict(
type='GroupFisherChannelMutator',
parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'),
channel_unit_cfg=dict(
type='GroupFisherChannelUnit',
default_args=dict(normalization_type=normalization_type, ),
),
),
)
model_wrapper_cfg = dict(
type='mmrazor.GroupFisherDDP',
broadcast_buffers=False,
)
optim_wrapper = dict(
optimizer=dict(lr=_base_.optim_wrapper.optimizer.lr * lr_ratio))
custom_hooks = getattr(_base_, 'custom_hooks', []) + [
dict(type='mmrazor.PruningStructureHook'),
dict(
type='mmrazor.ResourceInfoHook',
interval=interval,
demo_input=dict(
type='mmrazor.DefaultDemoInput',
input_shape=input_shape,
),
save_ckpt_thr=[target_flop_ratio],
),
]


@ -0,0 +1,11 @@
# Group_fisher pruning
> [Group Fisher Pruning for Practical Network Compression.](https://arxiv.org/pdf/2108.00708.pdf)
## Abstract
Network compression has been widely studied since it is able to reduce the memory and computation cost during inference. However, previous methods seldom deal with complicated structures like residual connections, group/depthwise convolution and feature pyramid network, where channels of multiple layers are coupled and need to be pruned simultaneously. In this paper, we present a general channel pruning approach that can be applied to various complicated structures. Particularly, we propose a layer grouping algorithm to find coupled channels automatically. Then we derive a unified metric based on Fisher information to evaluate the importance of a single channel and coupled channels. Moreover, we find that inference speedup on GPUs is more correlated with the reduction of memory rather than FLOPs, and thus we employ the memory reduction of each channel to normalize the importance. Our method can be used to prune any structures including those with coupled channels. We conduct extensive experiments on various backbones, including the classic ResNet and ResNeXt, mobile-friendly MobileNetV2, and the NAS-based RegNet, both on image classification and object detection which is under-explored. Experimental results validate that our method can effectively prune sophisticated networks, boosting inference speed without sacrificing accuracy.
![pipeline](https://github.com/jshilong/FisherPruning/blob/main/resources/structures.png?raw=true)
**Please refer to the [full README](../../base/group_fisher/README.md) for more details.**


@ -0,0 +1,50 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pretrain config file.
fix_subnet (Union[dict,str]): The dict that stores the pruning structure, or
    the json file containing it.
divisor (int): The divisor used to make the channel numbers divisible.
"""
_base_ = 'mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py'
fix_subnet = {
'backbone.conv1.conv_(0, 32)_32': 21,
'backbone.layer1.0.conv.1.conv_(0, 16)_16': 10,
'backbone.layer2.0.conv.0.conv_(0, 96)_96': 45,
'backbone.layer2.0.conv.2.conv_(0, 24)_24': 24,
'backbone.layer2.1.conv.0.conv_(0, 144)_144': 73,
'backbone.layer3.0.conv.0.conv_(0, 144)_144': 85,
'backbone.layer3.0.conv.2.conv_(0, 32)_32': 32,
'backbone.layer3.1.conv.0.conv_(0, 192)_192': 95,
'backbone.layer3.2.conv.0.conv_(0, 192)_192': 76,
'backbone.layer4.0.conv.0.conv_(0, 192)_192': 160,
'backbone.layer4.0.conv.2.conv_(0, 64)_64': 64,
'backbone.layer4.1.conv.0.conv_(0, 384)_384': 204,
'backbone.layer4.2.conv.0.conv_(0, 384)_384': 200,
'backbone.layer4.3.conv.0.conv_(0, 384)_384': 217,
'backbone.layer5.0.conv.0.conv_(0, 384)_384': 344,
'backbone.layer5.0.conv.2.conv_(0, 96)_96': 96,
'backbone.layer5.1.conv.0.conv_(0, 576)_576': 348,
'backbone.layer5.2.conv.0.conv_(0, 576)_576': 338,
'backbone.layer6.0.conv.0.conv_(0, 576)_576': 543,
'backbone.layer6.0.conv.2.conv_(0, 160)_160': 160,
'backbone.layer6.1.conv.0.conv_(0, 960)_960': 810,
'backbone.layer6.2.conv.0.conv_(0, 960)_960': 803,
'backbone.layer7.0.conv.0.conv_(0, 960)_960': 944,
'backbone.layer7.0.conv.2.conv_(0, 320)_320': 320
}
divisor = 8
##############################################################################
architecture = _base_.model
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherDeploySubModel',
architecture=architecture,
fix_subnet=fix_subnet,
divisor=divisor,
)


@ -0,0 +1,31 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pruning config file.
pruned_path (str): The path to the checkpoint of the pruned model.
finetune_lr (float): The learning rate for finetuning. Usually, we reuse the
    learning rate of pretraining.
"""
_base_ = './group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py'
pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/act/group_fisher_act_prune_mobilenet-v2_8xb32_in1k.pth' # noqa
finetune_lr = 0.045
##############################################################################
algorithm = _base_.model
algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherSubModel',
algorithm=algorithm,
)
# restore lr
optim_wrapper = dict(optimizer=dict(lr=finetune_lr))
# remove pruning related hooks
custom_hooks = _base_.custom_hooks[:-2]
# delete ddp
model_wrapper_cfg = None


@ -0,0 +1,75 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pretrain config file.
pretrained_path (str): The path to your pretrained model checkpoint.
interval (int): Interval between pruning two channels. You should ensure you
    can reach your target pruning ratio when the training ends.
normalization_type (str): GroupFisher uses two methods to normalize the channel
    importance, including ['flops','act']. The former uses flops, while the
    latter uses the memory occupation of activation feature maps.
lr_ratio (float): Ratio to decrease the learning rate. As the pruning process
    is unstable, you need to decrease the original learning rate until the
    pruning training works steadily without getting nan.
target_flop_ratio (float): The target flop ratio to prune your model.
input_shape (Tuple): Input shape to measure the flops.
"""
_base_ = 'mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py'
pretrained_path = 'https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth' # noqa
interval = 25
normalization_type = 'act'
lr_ratio = 0.1125
target_flop_ratio = 0.65
input_shape = (1, 3, 224, 224)
##############################################################################
architecture = _base_.model
if hasattr(_base_, 'data_preprocessor'):
    architecture.update({'data_preprocessor': _base_.data_preprocessor})
    data_preprocessor = None
architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path)
architecture['_scope_'] = _base_.default_scope
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherAlgorithm',
architecture=architecture,
interval=interval,
mutator=dict(
type='GroupFisherChannelMutator',
parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'),
channel_unit_cfg=dict(
type='GroupFisherChannelUnit',
default_args=dict(normalization_type=normalization_type, ),
),
),
)
model_wrapper_cfg = dict(
type='mmrazor.GroupFisherDDP',
broadcast_buffers=False,
)
optim_wrapper = dict(
optimizer=dict(lr=_base_.optim_wrapper.optimizer.lr * lr_ratio))
custom_hooks = getattr(_base_, 'custom_hooks', []) + [
dict(type='mmrazor.PruningStructureHook'),
dict(
type='mmrazor.ResourceInfoHook',
interval=interval,
demo_input=dict(
type='mmrazor.DefaultDemoInput',
input_shape=input_shape,
),
save_ckpt_thr=[target_flop_ratio],
),
]


@ -0,0 +1,49 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pretrain config file.
fix_subnet (Union[dict,str]): The dict that stores the pruning structure, or
    the json file containing it.
divisor (int): The divisor used to make the channel numbers divisible.
"""
_base_ = 'mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py'
fix_subnet = {
'backbone.conv1.conv_(0, 32)_32': 27,
'backbone.layer1.0.conv.1.conv_(0, 16)_16': 16,
'backbone.layer2.0.conv.0.conv_(0, 96)_96': 77,
'backbone.layer2.0.conv.2.conv_(0, 24)_24': 24,
'backbone.layer2.1.conv.0.conv_(0, 144)_144': 85,
'backbone.layer3.0.conv.0.conv_(0, 144)_144': 115,
'backbone.layer3.0.conv.2.conv_(0, 32)_32': 32,
'backbone.layer3.1.conv.0.conv_(0, 192)_192': 102,
'backbone.layer3.2.conv.0.conv_(0, 192)_192': 95,
'backbone.layer4.0.conv.0.conv_(0, 192)_192': 181,
'backbone.layer4.0.conv.2.conv_(0, 64)_64': 64,
'backbone.layer4.1.conv.0.conv_(0, 384)_384': 169,
'backbone.layer4.2.conv.0.conv_(0, 384)_384': 176,
'backbone.layer4.3.conv.0.conv_(0, 384)_384': 180,
'backbone.layer5.0.conv.0.conv_(0, 384)_384': 308,
'backbone.layer5.0.conv.2.conv_(0, 96)_96': 96,
'backbone.layer5.1.conv.0.conv_(0, 576)_576': 223,
'backbone.layer5.2.conv.0.conv_(0, 576)_576': 241,
'backbone.layer6.0.conv.0.conv_(0, 576)_576': 511,
'backbone.layer6.0.conv.2.conv_(0, 160)_160': 160,
'backbone.layer6.1.conv.0.conv_(0, 960)_960': 467,
'backbone.layer6.2.conv.0.conv_(0, 960)_960': 510,
'backbone.layer7.0.conv.0.conv_(0, 960)_960': 771,
'backbone.layer7.0.conv.2.conv_(0, 320)_320': 320
}
divisor = 8
##############################################################################
architecture = _base_.model
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherDeploySubModel',
architecture=architecture,
fix_subnet=fix_subnet,
divisor=divisor,
)


@ -0,0 +1,32 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pruning config file.
pruned_path (str): The path to the checkpoint of the pruned model.
finetune_lr (float): The learning rate for finetuning. Usually, we reuse the
    learning rate of pretraining.
"""
_base_ = './group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.py'
pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.pth' # noqa
finetune_lr = 0.045
##############################################################################
algorithm = _base_.model
algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherSubModel',
algorithm=algorithm,
)
# restore lr
optim_wrapper = dict(optimizer=dict(lr=finetune_lr))
# remove pruning related hooks
custom_hooks = _base_.custom_hooks[:-2]
# delete ddp
model_wrapper_cfg = None


@ -0,0 +1,5 @@
_base_ = './group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py'
model = dict(
mutator=dict(
channel_unit_cfg=dict(
default_args=dict(normalization_type='flops', ), ), ), )


@ -0,0 +1,7 @@
# act mode
bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py 8
bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.py 8
# flops mode
bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.py 8
bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.py 8


@ -0,0 +1,61 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pretrain config file.
fix_subnet (Union[dict,str]): The dict that stores the pruning structure, or
    the json file containing it.
divisor (int): The divisor used to make the channel numbers divisible.
"""
_base_ = 'mmcls::resnet/resnet50_8xb32_in1k.py'
fix_subnet = {
'backbone.conv1_(0, 64)_64': 61,
'backbone.layer1.0.conv1_(0, 64)_64': 27,
'backbone.layer1.0.conv2_(0, 64)_64': 35,
'backbone.layer1.0.conv3_(0, 256)_256': 241,
'backbone.layer1.1.conv1_(0, 64)_64': 32,
'backbone.layer1.1.conv2_(0, 64)_64': 29,
'backbone.layer1.2.conv1_(0, 64)_64': 27,
'backbone.layer1.2.conv2_(0, 64)_64': 42,
'backbone.layer2.0.conv1_(0, 128)_128': 87,
'backbone.layer2.0.conv2_(0, 128)_128': 107,
'backbone.layer2.0.conv3_(0, 512)_512': 512,
'backbone.layer2.1.conv1_(0, 128)_128': 44,
'backbone.layer2.1.conv2_(0, 128)_128': 50,
'backbone.layer2.2.conv1_(0, 128)_128': 52,
'backbone.layer2.2.conv2_(0, 128)_128': 81,
'backbone.layer2.3.conv1_(0, 128)_128': 47,
'backbone.layer2.3.conv2_(0, 128)_128': 50,
'backbone.layer3.0.conv1_(0, 256)_256': 210,
'backbone.layer3.0.conv2_(0, 256)_256': 206,
'backbone.layer3.0.conv3_(0, 1024)_1024': 1024,
'backbone.layer3.1.conv1_(0, 256)_256': 107,
'backbone.layer3.1.conv2_(0, 256)_256': 108,
'backbone.layer3.2.conv1_(0, 256)_256': 86,
'backbone.layer3.2.conv2_(0, 256)_256': 126,
'backbone.layer3.3.conv1_(0, 256)_256': 91,
'backbone.layer3.3.conv2_(0, 256)_256': 112,
'backbone.layer3.4.conv1_(0, 256)_256': 98,
'backbone.layer3.4.conv2_(0, 256)_256': 110,
'backbone.layer3.5.conv1_(0, 256)_256': 112,
'backbone.layer3.5.conv2_(0, 256)_256': 115,
'backbone.layer4.0.conv1_(0, 512)_512': 397,
'backbone.layer4.0.conv2_(0, 512)_512': 427,
'backbone.layer4.1.conv1_(0, 512)_512': 373,
'backbone.layer4.1.conv2_(0, 512)_512': 348,
'backbone.layer4.2.conv1_(0, 512)_512': 433,
'backbone.layer4.2.conv2_(0, 512)_512': 384
}
divisor = 8
##############################################################################
architecture = _base_.model
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherDeploySubModel',
architecture=architecture,
fix_subnet=fix_subnet,
divisor=divisor,
)


@ -0,0 +1,31 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pruning config file.
pruned_path (str): The path to the checkpoint of the pruned model.
finetune_lr (float): The learning rate for finetuning. Usually, we reuse the
    learning rate of pretraining.
"""
_base_ = './group_fisher_act_prune_resnet50_8xb32_in1k.py'
pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/act/group_fisher_act_prune_resnet50_8xb32_in1k.pth' # noqa
finetune_lr = 0.1
##############################################################################
algorithm = _base_.model
algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherSubModel',
algorithm=algorithm,
)
# restore lr
optim_wrapper = dict(optimizer=dict(lr=finetune_lr))
# remove pruning related hooks
custom_hooks = _base_.custom_hooks[:-2]
# delete ddp
model_wrapper_cfg = None

@ -0,0 +1,75 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pretrain config file.
pretrained_path (str): The path to your pretrained model checkpoint.
interval (int): The number of iterations between pruning two channels. You
    should ensure you can reach your target pruning ratio when the training
    ends.
normalization_type (str): GroupFisher uses two methods to normalize the
    channel importance, including ['flops','act']. The former uses flops,
    while the latter uses the memory occupation of activation feature maps.
lr_ratio (float): Ratio to decrease the learning rate. As the pruning
    process is unstable, you need to decrease the original learning rate
    until the pruning training works steadily without getting nan.
target_flop_ratio (float): The target flop ratio to prune your model.
input_shape (Tuple): Input shape to measure the flops.
"""
_base_ = 'mmcls::resnet/resnet50_8xb32_in1k.py'
pretrained_path = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa
interval = 25
normalization_type = 'act'
lr_ratio = 0.04
target_flop_ratio = 0.5
input_shape = [1, 3, 224, 224]
##############################################################################
architecture = _base_.model
if hasattr(_base_, 'data_preprocessor'):
    architecture.update({'data_preprocessor': _base_.data_preprocessor})
    data_preprocessor = None
architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path)
architecture['_scope_'] = _base_.default_scope
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherAlgorithm',
architecture=architecture,
interval=interval,
mutator=dict(
type='GroupFisherChannelMutator',
parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'),
channel_unit_cfg=dict(
type='GroupFisherChannelUnit',
default_args=dict(normalization_type=normalization_type, ),
),
),
)
model_wrapper_cfg = dict(
type='mmrazor.GroupFisherDDP',
broadcast_buffers=False,
)
optim_wrapper = dict(
optimizer=dict(lr=_base_.optim_wrapper.optimizer.lr * lr_ratio))
custom_hooks = getattr(_base_, 'custom_hooks', []) + [
dict(type='mmrazor.PruningStructureHook'),
dict(
type='mmrazor.ResourceInfoHook',
interval=interval,
demo_input=dict(
type='mmrazor.DefaultDemoInput',
input_shape=input_shape,
),
save_ckpt_thr=[target_flop_ratio],
),
]
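The note on `interval` in the docstring of this config can be checked with simple arithmetic: at most one channel is pruned every `interval` iterations, so the length of the training schedule bounds how many channels can be removed. The sketch below is only an illustration; `max_prunable_channels` is a hypothetical helper and the schedule numbers are made up, neither comes from mmrazor.

```python
def max_prunable_channels(total_iters: int, interval: int) -> int:
    """Rough upper bound: at most one channel per `interval` iterations."""
    if interval <= 0:
        raise ValueError('interval must be positive')
    return total_iters // interval


# Illustrative schedule only: 100 epochs of 5000 iterations each.
total_iters = 100 * 5000
print(max_prunable_channels(total_iters, interval=25))
```

If this bound is smaller than the number of channels that must go to reach `target_flop_ratio`, a smaller `interval` or a longer schedule is needed.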

@ -0,0 +1,61 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pretrain config file.
fix_subnet (Union[dict,str]): The dict storing the pruning structure, or the
    path to the json file containing it.
divisor (int): The divisor used to make the channel numbers divisible.
"""
_base_ = 'mmcls::resnet/resnet50_8xb32_in1k.py'
fix_subnet = {
'backbone.conv1_(0, 64)_64': 61,
'backbone.layer1.0.conv1_(0, 64)_64': 28,
'backbone.layer1.0.conv2_(0, 64)_64': 35,
'backbone.layer1.0.conv3_(0, 256)_256': 242,
'backbone.layer1.1.conv1_(0, 64)_64': 31,
'backbone.layer1.1.conv2_(0, 64)_64': 28,
'backbone.layer1.2.conv1_(0, 64)_64': 26,
'backbone.layer1.2.conv2_(0, 64)_64': 41,
'backbone.layer2.0.conv1_(0, 128)_128': 90,
'backbone.layer2.0.conv2_(0, 128)_128': 107,
'backbone.layer2.0.conv3_(0, 512)_512': 509,
'backbone.layer2.1.conv1_(0, 128)_128': 42,
'backbone.layer2.1.conv2_(0, 128)_128': 50,
'backbone.layer2.2.conv1_(0, 128)_128': 51,
'backbone.layer2.2.conv2_(0, 128)_128': 84,
'backbone.layer2.3.conv1_(0, 128)_128': 49,
'backbone.layer2.3.conv2_(0, 128)_128': 51,
'backbone.layer3.0.conv1_(0, 256)_256': 210,
'backbone.layer3.0.conv2_(0, 256)_256': 207,
'backbone.layer3.0.conv3_(0, 1024)_1024': 1024,
'backbone.layer3.1.conv1_(0, 256)_256': 103,
'backbone.layer3.1.conv2_(0, 256)_256': 108,
'backbone.layer3.2.conv1_(0, 256)_256': 90,
'backbone.layer3.2.conv2_(0, 256)_256': 124,
'backbone.layer3.3.conv1_(0, 256)_256': 94,
'backbone.layer3.3.conv2_(0, 256)_256': 114,
'backbone.layer3.4.conv1_(0, 256)_256': 99,
'backbone.layer3.4.conv2_(0, 256)_256': 111,
'backbone.layer3.5.conv1_(0, 256)_256': 108,
'backbone.layer3.5.conv2_(0, 256)_256': 111,
'backbone.layer4.0.conv1_(0, 512)_512': 400,
'backbone.layer4.0.conv2_(0, 512)_512': 421,
'backbone.layer4.1.conv1_(0, 512)_512': 377,
'backbone.layer4.1.conv2_(0, 512)_512': 347,
'backbone.layer4.2.conv1_(0, 512)_512': 443,
'backbone.layer4.2.conv2_(0, 512)_512': 376
}
divisor = 8
##############################################################################
architecture = _base_.model
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherDeploySubModel',
architecture=architecture,
fix_subnet=fix_subnet,
divisor=divisor,
)
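`divisor` asks for channel numbers that are divisible by the given value. How `GroupFisherDeploySubModel` rounds them is not visible in this config; the sketch below assumes a common "round to the nearest multiple, but never below one multiple" rule. `round_to_divisor` is a hypothetical helper written for illustration, not an mmrazor function.

```python
def round_to_divisor(channels: int, divisor: int = 8) -> int:
    """Round `channels` to the nearest multiple of `divisor`.

    Assumed rule for illustration; the result is never smaller than
    `divisor` itself.
    """
    rounded = int(channels + divisor / 2) // divisor * divisor
    return max(divisor, rounded)


# A few channel counts taken from the fix_subnet dict of this config.
for name, num in [('backbone.conv1', 61), ('backbone.layer1.0.conv1', 28)]:
    print(name, num, '->', round_to_divisor(num))
```

Under this assumption the deployed layer widths differ slightly from the raw numbers in `fix_subnet`, so measured FLOPs can deviate a little from the pruning target.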

@ -0,0 +1,31 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pruning config file.
pruned_path (str): The path to the checkpoint of the pruned model.
finetune_lr (float): The learning rate for finetuning. Usually, we directly
    use the learning rate of the pretraining.
"""
_base_ = './group_fisher_flops_prune_resnet50_8xb32_in1k.py'
pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/group_fisher_flops_prune_resnet50_8xb32_in1k.pth' # noqa
finetune_lr = 0.1
##############################################################################
algorithm = _base_.model
algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherSubModel',
algorithm=algorithm,
)
# restore lr
optim_wrapper = dict(optimizer=dict(lr=finetune_lr))
# remove pruning related hooks
custom_hooks = _base_.custom_hooks[:-2]
# delete ddp
model_wrapper_cfg = None

@ -0,0 +1,5 @@
_base_ = './group_fisher_act_prune_resnet50_8xb32_in1k.py'
model = dict(
mutator=dict(
channel_unit_cfg=dict(
default_args=dict(normalization_type='flops', ), ), ), )

@ -0,0 +1,7 @@
# act mode
bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_prune_resnet50_8xb32_in1k.py 8
bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_finetune_resnet50_8xb32_in1k.py 8
# flops mode
bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_prune_resnet50_8xb32_in1k.py 8
bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_finetune_resnet50_8xb32_in1k.py 8

@ -0,0 +1,11 @@
# Group_fisher pruning
> [Group Fisher Pruning for Practical Network Compression.](https://arxiv.org/pdf/2108.00708.pdf)
## Abstract
Network compression has been widely studied since it is able to reduce the memory and computation cost during inference. However, previous methods seldom deal with complicated structures like residual connections, group/depthwise convolution and feature pyramid network, where channels of multiple layers are coupled and need to be pruned simultaneously. In this paper, we present a general channel pruning approach that can be applied to various complicated structures. Particularly, we propose a layer grouping algorithm to find coupled channels automatically. Then we derive a unified metric based on Fisher information to evaluate the importance of a single channel and coupled channels. Moreover, we find that inference speedup on GPUs is more correlated with the reduction of memory rather than FLOPs, and thus we employ the memory reduction of each channel to normalize the importance. Our method can be used to prune any structures including those with coupled channels. We conduct extensive experiments on various backbones, including the classic ResNet and ResNeXt, mobile-friendly MobileNetV2, and the NAS-based RegNet, both on image classification and object detection which is under-explored. Experimental results validate that our method can effectively prune sophisticated networks, boosting inference speed without sacrificing accuracy.
![pipeline](https://github.com/jshilong/FisherPruning/blob/main/resources/structures.png?raw=true)
**Please refer to the [full README](../../base/group_fisher/README.md) for more details.**
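
As a rough illustration of the importance metric described in the abstract, the sketch below scores one channel by summing, over samples, the squared gradient of the loss with respect to that channel's mask, and then divides by the cost saved when the channel is pruned. It is a simplified reading of the paper for a single uncoupled channel, not the mmrazor implementation; `channel_importance` and the flattened toy data layout are assumptions made for this example.

```python
from typing import Sequence


def channel_importance(acts: Sequence[Sequence[float]],
                       grads: Sequence[Sequence[float]],
                       cost_reduction: float) -> float:
    """Fisher-style importance of one channel (toy version).

    acts[n][k] and grads[n][k] are the activation and the loss gradient of
    sample n at spatial position k of this channel. cost_reduction is the
    memory (or FLOPs) saved by pruning the channel.
    """
    total = 0.0
    for a_n, g_n in zip(acts, grads):
        # Gradient w.r.t. the channel mask: sum of activation * gradient.
        mask_grad = sum(a * g for a, g in zip(a_n, g_n))
        total += mask_grad ** 2
    return total / cost_reduction


# Two samples, two spatial positions each.
score = channel_importance(
    acts=[[1.0, 2.0], [0.0, 1.0]],
    grads=[[1.0, 1.0], [2.0, 3.0]],
    cost_reduction=2.0)
print(score)
```

Dividing by the cost reduction is what distinguishes the `act` and `flops` modes: the numerator stays the same and only the normalizer changes.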

@ -0,0 +1,73 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pretrain config file.
fix_subnet (Union[dict,str]): The dict storing the pruning structure, or the
    path to the json file containing it.
divisor (int): The divisor used to make the channel numbers divisible.
"""
_base_ = 'mmdet::retinanet/retinanet_r50_fpn_1x_coco.py'
fix_subnet = {
'backbone.conv1_(0, 64)_64': 60,
'backbone.layer1.0.conv1_(0, 64)_64': 48,
'backbone.layer1.0.conv2_(0, 64)_64': 44,
'backbone.layer1.0.conv3_(0, 256)_256': 250,
'backbone.layer1.1.conv1_(0, 64)_64': 40,
'backbone.layer1.1.conv2_(0, 64)_64': 41,
'backbone.layer1.2.conv1_(0, 64)_64': 48,
'backbone.layer1.2.conv2_(0, 64)_64': 62,
'backbone.layer2.0.conv1_(0, 128)_128': 115,
'backbone.layer2.0.conv2_(0, 128)_128': 127,
'backbone.layer2.0.conv3_(0, 512)_512': 511,
'backbone.layer2.1.conv1_(0, 128)_128': 69,
'backbone.layer2.1.conv2_(0, 128)_128': 83,
'backbone.layer2.2.conv1_(0, 128)_128': 111,
'backbone.layer2.2.conv2_(0, 128)_128': 121,
'backbone.layer2.3.conv1_(0, 128)_128': 122,
'backbone.layer2.3.conv2_(0, 128)_128': 128,
'backbone.layer3.0.conv1_(0, 256)_256': 255,
'backbone.layer3.0.conv2_(0, 256)_256': 256,
'backbone.layer3.0.conv3_(0, 1024)_1024': 1024,
'backbone.layer3.1.conv1_(0, 256)_256': 216,
'backbone.layer3.1.conv2_(0, 256)_256': 223,
'backbone.layer3.2.conv1_(0, 256)_256': 229,
'backbone.layer3.2.conv2_(0, 256)_256': 247,
'backbone.layer3.3.conv1_(0, 256)_256': 239,
'backbone.layer3.3.conv2_(0, 256)_256': 246,
'backbone.layer3.4.conv1_(0, 256)_256': 237,
'backbone.layer3.4.conv2_(0, 256)_256': 239,
'backbone.layer3.5.conv1_(0, 256)_256': 233,
'backbone.layer3.5.conv2_(0, 256)_256': 221,
'backbone.layer4.0.conv1_(0, 512)_512': 499,
'backbone.layer4.0.conv2_(0, 512)_512': 494,
'backbone.layer4.0.conv3_(0, 2048)_2048': 2031,
'backbone.layer4.1.conv1_(0, 512)_512': 451,
'backbone.layer4.1.conv2_(0, 512)_512': 401,
'backbone.layer4.2.conv1_(0, 512)_512': 396,
'backbone.layer4.2.conv2_(0, 512)_512': 237,
'neck.lateral_convs.0.conv_(0, 256)_256': 237,
'neck.fpn_convs.0.conv_(0, 256)_256': 241,
'bbox_head.cls_convs.0.conv_(0, 256)_256': 133,
'bbox_head.cls_convs.1.conv_(0, 256)_256': 134,
'bbox_head.cls_convs.2.conv_(0, 256)_256': 139,
'bbox_head.cls_convs.3.conv_(0, 256)_256': 79,
'bbox_head.reg_convs.0.conv_(0, 256)_256': 89,
'bbox_head.reg_convs.1.conv_(0, 256)_256': 92,
'bbox_head.reg_convs.2.conv_(0, 256)_256': 82,
'bbox_head.reg_convs.3.conv_(0, 256)_256': 117
}
divisor = 8
##############################################################################
architecture = _base_.model
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherDeploySubModel',
architecture=architecture,
fix_subnet=fix_subnet,
divisor=divisor,
)

@ -0,0 +1,31 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pruning config file.
pruned_path (str): The path to the checkpoint of the pruned model.
finetune_lr (float): The learning rate for finetuning. Usually, we directly
    use the learning rate of the pretraining.
"""
_base_ = './group_fisher_act_prune_retinanet_r50_fpn_1x_coco.py'
pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/act/group_fisher_act_prune_retinanet_r50_fpn_1x_coco.pth' # noqa
finetune_lr = 0.005
##############################################################################
algorithm = _base_.model
algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherSubModel',
algorithm=algorithm,
)
# restore lr
optim_wrapper = dict(optimizer=dict(lr=finetune_lr))
# remove pruning related hooks
custom_hooks = _base_.custom_hooks[:-2]
# delete ddp
model_wrapper_cfg = None

@ -0,0 +1,75 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pretrain config file.
pretrained_path (str): The path to your pretrained model checkpoint.
interval (int): The number of iterations between pruning two channels. You
    should ensure you can reach your target pruning ratio when the training
    ends.
normalization_type (str): GroupFisher uses two methods to normalize the
    channel importance, including ['flops','act']. The former uses flops,
    while the latter uses the memory occupation of activation feature maps.
lr_ratio (float): Ratio to decrease the learning rate. As the pruning
    process is unstable, you need to decrease the original learning rate
    until the pruning training works steadily without getting nan.
target_flop_ratio (float): The target flop ratio to prune your model.
input_shape (Tuple): Input shape to measure the flops.
"""
_base_ = 'mmdet::retinanet/retinanet_r50_fpn_1x_coco.py'
pretrained_path = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth' # noqa
interval = 10
normalization_type = 'act'
lr_ratio = 0.1
target_flop_ratio = 0.5
input_shape = (1, 3, 1333, 800)
##############################################################################
architecture = _base_.model
if hasattr(_base_, 'data_preprocessor'):
    architecture.update({'data_preprocessor': _base_.data_preprocessor})
    data_preprocessor = None
architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path)
architecture['_scope_'] = _base_.default_scope
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherAlgorithm',
architecture=architecture,
interval=interval,
mutator=dict(
type='GroupFisherChannelMutator',
parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'),
channel_unit_cfg=dict(
type='GroupFisherChannelUnit',
default_args=dict(normalization_type=normalization_type, ),
),
),
)
model_wrapper_cfg = dict(
type='mmrazor.GroupFisherDDP',
broadcast_buffers=False,
)
optim_wrapper = dict(
optimizer=dict(lr=_base_.optim_wrapper.optimizer.lr * lr_ratio))
custom_hooks = getattr(_base_, 'custom_hooks', []) + [
dict(type='mmrazor.PruningStructureHook'),
dict(
type='mmrazor.ResourceInfoHook',
interval=interval,
demo_input=dict(
type='mmrazor.DefaultDemoInput',
input_shape=input_shape,
),
save_ckpt_thr=[target_flop_ratio],
),
]

@ -0,0 +1,73 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pretrain config file.
fix_subnet (Union[dict,str]): The dict storing the pruning structure, or the
    path to the json file containing it.
divisor (int): The divisor used to make the channel numbers divisible.
"""
_base_ = 'mmdet::retinanet/retinanet_r50_fpn_1x_coco.py'
fix_subnet = {
'backbone.conv1_(0, 64)_64': 60,
'backbone.layer1.0.conv1_(0, 64)_64': 47,
'backbone.layer1.0.conv2_(0, 64)_64': 44,
'backbone.layer1.0.conv3_(0, 256)_256': 249,
'backbone.layer1.1.conv1_(0, 64)_64': 37,
'backbone.layer1.1.conv2_(0, 64)_64': 37,
'backbone.layer1.2.conv1_(0, 64)_64': 44,
'backbone.layer1.2.conv2_(0, 64)_64': 62,
'backbone.layer2.0.conv1_(0, 128)_128': 114,
'backbone.layer2.0.conv2_(0, 128)_128': 127,
'backbone.layer2.0.conv3_(0, 512)_512': 511,
'backbone.layer2.1.conv1_(0, 128)_128': 65,
'backbone.layer2.1.conv2_(0, 128)_128': 83,
'backbone.layer2.2.conv1_(0, 128)_128': 106,
'backbone.layer2.2.conv2_(0, 128)_128': 118,
'backbone.layer2.3.conv1_(0, 128)_128': 118,
'backbone.layer2.3.conv2_(0, 128)_128': 127,
'backbone.layer3.0.conv1_(0, 256)_256': 255,
'backbone.layer3.0.conv2_(0, 256)_256': 256,
'backbone.layer3.0.conv3_(0, 1024)_1024': 1024,
'backbone.layer3.1.conv1_(0, 256)_256': 214,
'backbone.layer3.1.conv2_(0, 256)_256': 232,
'backbone.layer3.2.conv1_(0, 256)_256': 224,
'backbone.layer3.2.conv2_(0, 256)_256': 247,
'backbone.layer3.3.conv1_(0, 256)_256': 240,
'backbone.layer3.3.conv2_(0, 256)_256': 246,
'backbone.layer3.4.conv1_(0, 256)_256': 240,
'backbone.layer3.4.conv2_(0, 256)_256': 243,
'backbone.layer3.5.conv1_(0, 256)_256': 238,
'backbone.layer3.5.conv2_(0, 256)_256': 232,
'backbone.layer4.0.conv1_(0, 512)_512': 503,
'backbone.layer4.0.conv2_(0, 512)_512': 500,
'backbone.layer4.0.conv3_(0, 2048)_2048': 2041,
'backbone.layer4.1.conv1_(0, 512)_512': 466,
'backbone.layer4.1.conv2_(0, 512)_512': 430,
'backbone.layer4.2.conv1_(0, 512)_512': 406,
'backbone.layer4.2.conv2_(0, 512)_512': 274,
'neck.lateral_convs.0.conv_(0, 256)_256': 236,
'neck.fpn_convs.0.conv_(0, 256)_256': 225,
'bbox_head.cls_convs.0.conv_(0, 256)_256': 140,
'bbox_head.cls_convs.1.conv_(0, 256)_256': 133,
'bbox_head.cls_convs.2.conv_(0, 256)_256': 139,
'bbox_head.cls_convs.3.conv_(0, 256)_256': 86,
'bbox_head.reg_convs.0.conv_(0, 256)_256': 89,
'bbox_head.reg_convs.1.conv_(0, 256)_256': 89,
'bbox_head.reg_convs.2.conv_(0, 256)_256': 76,
'bbox_head.reg_convs.3.conv_(0, 256)_256': 122,
}
divisor = 8
##############################################################################
architecture = _base_.model
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherDeploySubModel',
architecture=architecture,
fix_subnet=fix_subnet,
divisor=divisor,
)

@ -0,0 +1,31 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pruning config file.
pruned_path (str): The path to the checkpoint of the pruned model.
finetune_lr (float): The learning rate for finetuning. Usually, we directly
    use the learning rate of the pretraining.
"""
_base_ = './group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.py'
pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/flops/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.pth' # noqa
finetune_lr = 0.005
##############################################################################
algorithm = _base_.model
algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherSubModel',
algorithm=algorithm,
)
# restore lr
optim_wrapper = dict(optimizer=dict(lr=finetune_lr))
# remove pruning related hooks
custom_hooks = _base_.custom_hooks[:-2]
# delete ddp
model_wrapper_cfg = None

@ -0,0 +1,5 @@
_base_ = './group_fisher_act_prune_retinanet_r50_fpn_1x_coco.py'
model = dict(
mutator=dict(
channel_unit_cfg=dict(
default_args=dict(normalization_type='flops', ), ), ), )

@ -0,0 +1,7 @@
# act mode
bash ./tools/dist_train.sh configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_prune_retinanet_r50_fpn_1x_coco.py 8
bash ./tools/dist_train.sh configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.py 8
# flops mode
bash ./tools/dist_train.sh configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.py 8
bash ./tools/dist_train.sh configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.py 8

@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""This file includes the modules in the implementations folder.
As it only records modules from the implementations folder, it is not
initialized automatically.
"""
from mmrazor.implementations.pruning.group_fisher import \
PruningStructureHook # noqa
from mmrazor.implementations.pruning.group_fisher import \
ResourceInfoHook # noqa

@ -0,0 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""The implementations folder is an experimental file structure to store
algorithm implementations.
The previous file structure splits the files of an algorithm into different
folders according to the types of these files, which can make an algorithm
hard to understand. So we add the implementations folder, where all files of
an algorithm are stored in one folder. As this structure is experimental, it
may change rapidly.
"""
from . import pruning # noqa
__all__ = ['pruning']

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import group_fisher
__all__ = ['group_fisher']

@ -0,0 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .algorithm import GroupFisherAlgorithm
from .counters import GroupFisherConv2dCounter, GroupFisherLinearCounter
from .hook import PruningStructureHook, ResourceInfoHook
from .mutator import GroupFisherChannelMutator
from .ops import GroupFisherConv2d, GroupFisherLinear, GroupFisherMixin
from .prune_deploy_sub_model import GroupFisherDeploySubModel
from .prune_sub_model import GroupFisherSubModel
from .unit import GroupFisherChannelUnit
__all__ = [
'GroupFisherDeploySubModel',
'GroupFisherSubModel',
'GroupFisherAlgorithm',
'GroupFisherConv2dCounter',
'GroupFisherLinearCounter',
'PruningStructureHook',
'ResourceInfoHook',
'GroupFisherChannelMutator',
'GroupFisherChannelUnit',
'GroupFisherConv2d',
'GroupFisherLinear',
'GroupFisherMixin',
]

@ -0,0 +1,86 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from mmengine.logging import print_log
from mmengine.model import BaseModel, MMDistributedDataParallel

from mmrazor.models.algorithms.base import BaseAlgorithm
from mmrazor.registry import MODEL_WRAPPERS, MODELS
from mmrazor.utils import RuntimeInfo
from .mutator import GroupFisherChannelMutator


@MODELS.register_module()
class GroupFisherAlgorithm(BaseAlgorithm):
    """`Group Fisher Pruning for Practical Network Compression`.
    https://arxiv.org/pdf/2108.00708.pdf.

    Args:
        architecture (Union[BaseModel, Dict]): The model to be pruned.
        mutator (Union[Dict, ChannelMutator], optional): The config
            of a mutator. Defaults to dict(type='GroupFisherChannelMutator',
            channel_unit_cfg=dict(type='GroupFisherChannelUnit')).
        interval (int): The number of iterations between pruning two
            channels. Defaults to 10.
        data_preprocessor (Optional[Union[Dict, nn.Module]], optional):
            Defaults to None.
        init_cfg (Optional[Dict], optional): init config for the model.
            Defaults to None.
    """

    def __init__(self,
                 architecture: Union[BaseModel, Dict],
                 mutator: Union[Dict, GroupFisherChannelMutator] = dict(
                     type='GroupFisherChannelMutator',
                     channel_unit_cfg=dict(type='GroupFisherChannelUnit')),
                 interval: int = 10,
                 data_preprocessor: Optional[Union[Dict, nn.Module]] = None,
                 init_cfg: Optional[Dict] = None) -> None:
        super().__init__(architecture, data_preprocessor, init_cfg)
        self.interval = interval

        # using sync bn or normal bn
        if dist.is_initialized():
            print_log('Convert Bn to SyncBn.')
            self.architecture = nn.SyncBatchNorm.convert_sync_batchnorm(
                self.architecture)
        else:
            from mmengine.model import revert_sync_batchnorm
            self.architecture = revert_sync_batchnorm(self.architecture)

        # mutator
        self.mutator: GroupFisherChannelMutator = MODELS.build(mutator)
        self.mutator.prepare_from_supernet(self.architecture)

    def train_step(self, data: Union[dict, tuple, list],
                   optim_wrapper) -> Dict[str, torch.Tensor]:
        return self._train_step(data, optim_wrapper)

    def _train_step(self, data: Union[dict, tuple, list], optim_wrapper):
        """Train step function for GroupFisherAlgorithm and GroupFisherDDP."""
        # record inputs and gradients during this forward/backward pass
        self.mutator.start_record_info()
        res = super().train_step(data, optim_wrapper)
        self.mutator.end_record_info()

        # accumulate the importance, then drop the raw recordings
        self.mutator.update_imp()
        self.mutator.reset_recorded_info()

        # every `interval` iterations, prune one channel and start over
        if RuntimeInfo.iter() % self.interval == 0:
            self.mutator.try_prune()
            self.mutator.reset_imp()

        return res


@MODEL_WRAPPERS.register_module()
class GroupFisherDDP(MMDistributedDataParallel):
    """Train step for group fisher."""

    def train_step(self, data: Union[dict, tuple, list],
                   optim_wrapper) -> Dict[str, torch.Tensor]:
        algorithm = self.module
        return GroupFisherAlgorithm._train_step(algorithm, data, optim_wrapper)
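
The control flow of `_train_step` can be illustrated without any framework: importance is accumulated on every iteration, and every `interval` iterations one pruning attempt is made and the accumulator is reset. In the toy sketch below, `ToyMutator` and `run` are hypothetical stand-ins that only count calls, and the iteration counter is assumed to start at 0.

```python
class ToyMutator:
    """Hypothetical stand-in for the mutator; it only counts calls."""

    def __init__(self) -> None:
        self.accumulated = 0  # iterations folded into the importance so far
        self.prune_calls = 0  # number of pruning attempts

    def update_imp(self) -> None:
        self.accumulated += 1

    def try_prune(self) -> None:
        self.prune_calls += 1

    def reset_imp(self) -> None:
        self.accumulated = 0


def run(num_iters: int, interval: int) -> ToyMutator:
    mutator = ToyMutator()
    for it in range(num_iters):
        mutator.update_imp()  # every iteration: accumulate importance
        if it % interval == 0:  # every `interval` iterations: prune, reset
            mutator.try_prune()
            mutator.reset_imp()
    return mutator


mutator = run(num_iters=25, interval=10)
print(mutator.prune_calls, mutator.accumulated)
```

A larger `interval` therefore means each pruning decision is based on importance accumulated over more batches, at the price of fewer pruning attempts per epoch.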

@ -0,0 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmrazor.models.task_modules.estimators.counters.op_counters.dynamic_op_counters import (  # noqa
    DynamicConv2dCounter, DynamicLinearCounter)
from mmrazor.registry import TASK_UTILS


@TASK_UTILS.register_module()
class GroupFisherConv2dCounter(DynamicConv2dCounter):
    """Counter of GroupFisherConv2d."""
    pass


@TASK_UTILS.register_module()
class GroupFisherLinearCounter(DynamicLinearCounter):
    """Counter of GroupFisherLinear."""
    pass

@ -0,0 +1,183 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.dist import master_only
from mmengine.hooks import Hook
from mmengine.runner import Runner, save_checkpoint
from torch import distributed as torch_dist
from mmrazor.models.algorithms import BaseAlgorithm
from mmrazor.models.mutators.channel_mutator.channel_mutator import \
ChannelMutator
from mmrazor.models.task_modules.demo_inputs import DefaultDemoInput
from mmrazor.models.task_modules.estimators import ResourceEstimator
from mmrazor.registry import HOOKS, TASK_UTILS
from mmrazor.utils import RuntimeInfo, print_log


def get_model_from_runner(runner):
    """Get the model from a runner."""
    if torch_dist.is_initialized():
        return runner.model.module
    else:
        return runner.model


def is_pruning_algorithm(algorithm):
    """Check whether a model is a pruning algorithm."""
    return isinstance(algorithm, BaseAlgorithm) \
        and isinstance(getattr(algorithm, 'mutator', None), ChannelMutator)  # noqa


@HOOKS.register_module()
class PruningStructureHook(Hook):
    """This hook is used to display the structure information during pruning.

    Args:
        by_epoch (bool, optional): Whether to display structure information
            iteratively by epoch. Defaults to True.
        interval (int, optional): The interval between two structure
            information display.
    """

    def __init__(self, by_epoch=True, interval=1) -> None:
        super().__init__()
        self.by_epoch = by_epoch
        self.interval = interval

    def show_unit_info(self, algorithm):
        """Show unit information of an algorithm."""
        if is_pruning_algorithm(algorithm):
            choices = algorithm.mutator.choice_template
            import json
            print_log(json.dumps(choices, indent=4))
            # print the min and max channel importance of each unit
            for unit in algorithm.mutator.mutable_units:
                if hasattr(unit, 'importance'):
                    imp = unit.importance()
                    print_log(
                        f'{unit.name}: \t{imp.min().item()}\t{imp.max().item()}'  # noqa
                    )

    @master_only
    def show(self, runner):
        """Show pruning algorithm information of a runner."""
        algorithm = get_model_from_runner(runner)
        if is_pruning_algorithm(algorithm):
            self.show_unit_info(algorithm)

    # hook points

    def after_train_epoch(self, runner) -> None:
        if self.by_epoch and RuntimeInfo.epoch() % self.interval == 0:
            self.show(runner)

    def after_train_iter(self, runner, batch_idx: int, data_batch,
                         outputs) -> None:
        if not self.by_epoch and RuntimeInfo.iter() % self.interval == 0:
            self.show(runner)


@HOOKS.register_module()
class ResourceInfoHook(Hook):
    """This hook is used to display the resource related information and save
    the checkpoint according to a threshold during pruning.

    Args:
        demo_input (dict, optional): the demo input for ResourceEstimator.
            Defaults to DefaultDemoInput([1, 3, 224, 224]).
        interval (int, optional): the interval to check the resource. Defaults
            to 10.
        resource_type (str, optional): the type of resource to check.
            Defaults to 'flops'.
        save_ckpt_thr (list, optional): the threshold to save checkpoint.
            Defaults to [0.5].
        early_stop (bool, optional): whether to stop when all checkpoints have
            been saved according to save_ckpt_thr. Defaults to True.
    """

    def __init__(self,
                 demo_input=DefaultDemoInput([1, 3, 224, 224]),
                 interval=10,
                 resource_type='flops',
                 save_ckpt_thr=[0.5],
                 early_stop=True) -> None:
        super().__init__()
        if isinstance(demo_input, dict):
            demo_input = TASK_UTILS.build(demo_input)

        self.demo_input = demo_input
        self.save_ckpt_thr = sorted(
            save_ckpt_thr, reverse=True)  # big to small
        self.resource_type = resource_type
        self.early_stop = early_stop
        self.estimator: ResourceEstimator = TASK_UTILS.build(
            dict(
                _scope_='mmrazor',
                type='ResourceEstimator',
                flops_params_cfg=dict(
                    input_shape=tuple(demo_input.input_shape), )))
        self.interval = interval
        self.origin_delta = None

    def before_run(self, runner) -> None:
        """Init original_resource."""
        model = get_model_from_runner(runner)
        original_resource = self._evaluate(model)
        print_log(f'get original resource: {original_resource}')

        self.origin_delta = original_resource[self.resource_type]

    # save checkpoint

    def after_train_iter(self,
                         runner: Runner,
                         batch_idx: int,
                         data_batch=None,
                         outputs=None) -> None:
        """Check resource after train iteration."""
        if RuntimeInfo.iter() % self.interval == 0 and len(
                self.save_ckpt_thr) > 0:
            model = get_model_from_runner(runner)
            current_delta = self._evaluate(model)[self.resource_type]
            percent = current_delta / self.origin_delta
            # save once the remaining resource drops below the next threshold
            if percent < self.save_ckpt_thr[0]:
                self._save_checkpoint(model, runner.work_dir,
                                      self.save_ckpt_thr.pop(0))
        # stop training once every threshold has produced a checkpoint
        if self.early_stop and len(self.save_ckpt_thr) == 0:
            exit()

    # show info

    @master_only
    def after_train_epoch(self, runner) -> None:
        """Check resource after train epoch."""
        model = get_model_from_runner(runner)
        current_delta = self._evaluate(model)[self.resource_type]
        print_log(
            f'current {self.resource_type}: {current_delta} / {self.origin_delta}'  # noqa
        )

    def _evaluate(self, model: nn.Module):
        """Evaluate the resource required by a model."""
        with torch.no_grad():
            training = model.training
            model.eval()
            res = self.estimator.estimate(model)
            if training:
                model.train()
            return res

    @master_only
    def _save_checkpoint(self, model, path, delta_percent):
        """Save the checkpoint of a model."""
        ckpt = {'state_dict': model.state_dict()}
        save_path = f'{path}/{self.resource_type}_{delta_percent:.2f}.pth'
        save_checkpoint(ckpt, save_path)
        print_log(
            f'Save checkpoint to {save_path} with {self._evaluate(model)}'  # noqa
        )
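
The threshold handling in `ResourceInfoHook.after_train_iter` can be sketched in isolation: thresholds are sorted from large to small, and each resource check consumes at most one of them. `thresholds_crossed` below is a hypothetical helper for illustration; the `ratios` list stands in for the successive values of `current_delta / origin_delta` and is made up.

```python
from typing import List


def thresholds_crossed(ratios: List[float],
                       thresholds: List[float]) -> List[float]:
    """Return the thresholds in the order a checkpoint would be saved."""
    pending = sorted(thresholds, reverse=True)  # big to small
    saved = []
    for ratio in ratios:
        # At most one threshold is consumed per check, as in the hook.
        if pending and ratio < pending[0]:
            saved.append(pending.pop(0))
    return saved


print(thresholds_crossed([0.9, 0.74, 0.6, 0.49, 0.3], [0.5, 0.75]))
```

Because only one threshold is popped per check, a ratio that jumps below several thresholds at once yields one checkpoint per subsequent check rather than all of them immediately.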

@ -0,0 +1,87 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Type, Union
from mmengine.dist import dist
from mmrazor.models.mutators.channel_mutator.channel_mutator import \
ChannelMutator
from mmrazor.registry import MODELS
from mmrazor.utils import print_log
from .unit import GroupFisherChannelUnit
@MODELS.register_module()
class GroupFisherChannelMutator(ChannelMutator[GroupFisherChannelUnit]):
"""Channel mutator for GroupFisher Pruning Algorithm.
Args:
channel_unit_cfg (Union[dict, Type[ChannelUnitType]], optional):
Config of MutableChannelUnits. Defaults to
dict(type='GroupFisherChannelUnit',
default_args=dict(choice_mode='ratio')).
parse_cfg (Dict): The config of the tracer to parse the model.
Defaults to dict(type='ChannelAnalyzer',
demo_input=(1, 3, 224, 224),
tracer_type='FxTracer').
"""
def __init__(self,
channel_unit_cfg: Union[dict,
Type[GroupFisherChannelUnit]] = dict(
type='GroupFisherChannelUnit'),
parse_cfg: Dict = dict(
type='ChannelAnalyzer',
demo_input=(1, 3, 224, 224),
tracer_type='FxTracer'),
**kwargs) -> None:
super().__init__(channel_unit_cfg, parse_cfg, **kwargs)
self.mutable_units: List[GroupFisherChannelUnit]
def start_record_info(self) -> None:
"""Start recording the related information."""
for unit in self.mutable_units:
unit.start_record_fisher_info()
def end_record_info(self) -> None:
"""Stop recording the related information."""
for unit in self.mutable_units:
unit.end_record_fisher_info()
def reset_recorded_info(self) -> None:
"""Reset the related information."""
for unit in self.mutable_units:
unit.reset_recorded()
def try_prune(self) -> None:
"""Prune the channel with the minimum fisher unless it is the last
channel of the current layer."""
min_imp = 1e5
min_unit = self.mutable_units[0]
for unit in self.mutable_units:
if unit.mutable_channel.activated_channels > 1:
imp = unit.importance()
if imp.isnan().any():
if dist.get_rank() == 0:
print_log(
f'{unit.name} detected NaN in importance; skipping this pruning step.' # noqa
)
return
if imp.min() < min_imp:
min_imp = imp.min().item()
min_unit = unit
if min_unit.try_to_prune_min_channel():
if dist.get_rank() == 0:
print_log(
f'{min_unit.name} prunes a channel with min imp = {min_imp}' # noqa
)
def update_imp(self) -> None:
"""Update the fisher information of each unit."""
for unit in self.mutable_units:
unit.update_fisher_info()
def reset_imp(self) -> None:
"""Reset the fisher information of each unit."""
for unit in self.mutable_units:
unit.reset_fisher_info()
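
The selection rule in `try_prune` can be illustrated without the mutator machinery. The following is a minimal sketch, not the mmrazor implementation: `pick_channel_to_prune` and its dictionary inputs are hypothetical names, and it starts from `float('inf')` rather than the `1e5` threshold used above.

```python
from typing import Dict, List, Optional, Tuple


def pick_channel_to_prune(
        importances: Dict[str, List[float]],
        active: Dict[str, List[bool]]) -> Optional[Tuple[str, int]]:
    """Return (unit_name, channel_index) of the least important channel.

    Units with a single active channel are skipped so that no unit loses
    its last channel. Returns None when nothing can be pruned.
    """
    best: Optional[Tuple[str, int]] = None
    best_imp = float('inf')
    for name, imps in importances.items():
        mask = active[name]
        if sum(mask) <= 1:
            continue  # keep the last channel of this unit
        for idx, (imp, is_active) in enumerate(zip(imps, mask)):
            # already pruned channels are never candidates
            if is_active and imp < best_imp:
                best_imp, best = imp, (name, idx)
    return best


units = {'conv1': [0.9, 0.2, 0.5], 'conv2': [0.05]}
masks = {'conv1': [True, True, True], 'conv2': [True]}
print(pick_channel_to_prune(units, masks))  # ('conv1', 1)
```

`conv2` holds the globally smallest importance, but it is skipped because it has only one active channel left, which mirrors the `activated_channels > 1` guard.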

@ -0,0 +1,150 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
from mmrazor.models.architectures.dynamic_ops.bricks.dynamic_conv import \
DynamicConv2d
from mmrazor.models.architectures.dynamic_ops.bricks.dynamic_linear import \
DynamicLinear
class GroupFisherMixin:
"""The mixin class for GroupFisher ops."""
def _init(self) -> None:
self.handlers: list = []
self.recorded_input: List = []
self.recorded_grad: List = []
self.recorded_out_shape: List = []
def forward_hook_wrapper(self):
"""Wrap the hook used in forward."""
def forward_hook(module: GroupFisherMixin, input, output):
module.recorded_out_shape.append(output.shape)
module.recorded_input.append(input[0])
return forward_hook
def backward_hook_wrapper(self):
"""Wrap the hook used in backward."""
def backward_hook(module: GroupFisherMixin, grad_in, grad_out):
module.recorded_grad.insert(0, grad_in[0])
return backward_hook
def start_record(self: torch.nn.Module) -> None:
"""Start recording information during forward and backward."""
self.end_record() # remove existing hooks so they are registered only once
self.handlers.append(
self.register_forward_hook(self.forward_hook_wrapper()))
self.handlers.append(
self.register_backward_hook(self.backward_hook_wrapper()))
def end_record(self):
"""Stop recording information during forward and backward."""
for handle in self.handlers:
handle.remove()
self.handlers = []
def reset_recorded(self):
"""Reset the recorded information."""
self.recorded_input = []
self.recorded_grad = []
self.recorded_out_shape = []
@property
def delta_flop_of_a_out_channel(self):
raise NotImplementedError()
@property
def delta_flop_of_a_in_channel(self):
raise NotImplementedError()
@property
def delta_memory_of_a_out_channel(self):
raise NotImplementedError()
class GroupFisherConv2d(DynamicConv2d, GroupFisherMixin):
"""The Dynamic Conv2d operation used in GroupFisher Algorithm."""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._init()
@property
def delta_flop_of_a_out_channel(self) -> torch.Tensor:
"""Calculate the total FLOPs removed when pruning an out_channel."""
delta_flop_sum = 0
for shape in self.recorded_out_shape:
_, _, h, w = shape
in_c = int(self.mutable_attrs['in_channels'].current_mask.float().
sum().item())
# normal conv
if self.groups == 1:
delta_flop = h * w * self.kernel_size[0] * self.kernel_size[
1] * in_c
# dwconv
elif self.groups == self.in_channels == self.out_channels:
delta_flop = h * w * self.kernel_size[0] * self.kernel_size[1]
# groupwise conv
else:
raise NotImplementedError()
delta_flop_sum += delta_flop
return delta_flop_sum
@property
def delta_flop_of_a_in_channel(self):
"""Calculate the total FLOPs removed when pruning an in_channel."""
delta_flop_sum = 0
for shape in self.recorded_out_shape:
_, out_c, h, w = shape
# normal conv
if self.groups == 1:
delta_flop = h * w * self.kernel_size[0] * self.kernel_size[
1] * out_c
# dwconv
elif self.groups == self.in_channels == self.out_channels:
delta_flop = h * w * self.kernel_size[0] * self.kernel_size[1]
# groupwise conv
else:
raise NotImplementedError()
delta_flop_sum += delta_flop
return delta_flop_sum
@property
def delta_memory_of_a_out_channel(self):
"""Calculate the total activation memory freed by pruning a channel."""
delta_memory_sum = 0
for shape in self.recorded_out_shape:
_, _, h, w = shape
delta_memory_sum += h * w
return delta_memory_sum
class GroupFisherLinear(DynamicLinear, GroupFisherMixin):
"""The Dynamic Linear operation used in GroupFisher Algorithm."""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._init()
@property
def delta_flop_of_a_out_channel(self):
"""Calculate the total FLOPs removed when pruning an out_channel."""
in_c = self.mutable_attrs['in_channels'].current_mask.float().sum()
return in_c * len(self.recorded_out_shape)
@property
def delta_flop_of_a_in_channel(self):
"""Calculate the total FLOPs removed when pruning an in_channel."""
out_c = self.mutable_attrs['out_channels'].current_mask.float().sum()
return out_c * len(self.recorded_out_shape)
@property
def delta_memory_of_a_out_channel(self):
"""Calculate the total activation memory freed by pruning a channel."""
return 1 * len(self.recorded_out_shape)
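
As a worked example of the per-channel deltas computed by `GroupFisherConv2d`, the sketch below reproduces the arithmetic for a single recorded output shape. The function names are hypothetical and only the plain-conv and depthwise cases are covered, as in the code above.

```python
from typing import Tuple


def conv_delta_flops_out_channel(h: int,
                                 w: int,
                                 kernel_size: Tuple[int, int],
                                 active_in_channels: int,
                                 depthwise: bool = False) -> int:
    """FLOPs saved by removing one output channel of a conv layer."""
    kh, kw = kernel_size
    # a depthwise filter sees one input channel; a plain conv sees all
    per_position = kh * kw if depthwise else kh * kw * active_in_channels
    return h * w * per_position


def conv_delta_memory_out_channel(h: int, w: int) -> int:
    """Activation elements saved by removing one output channel."""
    return h * w


# 56x56 output map, 3x3 kernel, 64 active input channels
print(conv_delta_flops_out_channel(56, 56, (3, 3), 64))        # 1806336
print(conv_delta_flops_out_channel(56, 56, (3, 3), 64, True))  # 28224
print(conv_delta_memory_out_channel(56, 56))                   # 3136
```

In the ops above these values are summed over every shape in `recorded_out_shape`, so a layer that ran several forward passes contributes once per pass.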

@ -0,0 +1,65 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
from typing import Union
import torch.nn as nn
from mmengine import fileio
from mmrazor.registry import MODELS
from mmrazor.structures.subnet.fix_subnet import (export_fix_subnet,
load_fix_subnet)
from mmrazor.utils import print_log
@MODELS.register_module()
def GroupFisherDeploySubModel(architecture,
fix_subnet: Union[dict, str] = {},
divisor=1,
parse_cfg=dict(
_scope_='mmrazor',
type='ChannelAnalyzer',
demo_input=(1, 3, 224, 224),
tracer_type='FxTracer'),
**kwargs):
"""Convert an architecture to a pruned static architecture for mmdeploy.
Args:
architecture (Union[nn.Module, dict]): the model to be pruned.
fix_subnet (Union[dict, str]): the channel remaining ratio for each
unit, or the path of a file including this info. Defaults to {}.
divisor (int, optional): The divisor to make the channel number
divisible. Defaults to 1.
parse_cfg (dict, optional): The args for channel mutator.
Returns:
BaseModel: a BaseModel of mmengine.
"""
# import here to avoid circular imports
from mmrazor.models.mutables import SequentialMutableChannelUnit
from mmrazor.models.mutators import ChannelMutator
from mmrazor.models.utils.expandable_utils.unit import ExpandableUnit
# build architecture
if isinstance(architecture, dict):
architecture = MODELS.build(architecture)
assert isinstance(architecture, nn.Module)
# to dynamic model
mutator = ChannelMutator[ExpandableUnit](
channel_unit_cfg=SequentialMutableChannelUnit, parse_cfg=parse_cfg)
mutator.prepare_from_supernet(architecture)
if isinstance(fix_subnet, str):
fix_subnet = fileio.load(fix_subnet)
assert isinstance(fix_subnet, dict)
mutator.set_choices(fix_subnet)
print_log(json.dumps(mutator.current_choices, indent=4))
fix_subnet = export_fix_subnet(architecture)[0]
load_fix_subnet(architecture, fix_subnet)
# cooperate with mmdeploy to make the channels divisible after loading
# the checkpoint.
if divisor != 1:
setattr(architecture, '_razor_divisor', divisor)
return architecture

@ -0,0 +1,105 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import types
import torch.nn as nn
from mmengine import dist, fileio
from mmengine.model import BaseModel, BaseModule
from mmrazor.models.algorithms import BaseAlgorithm
from mmrazor.models.utils.expandable_utils import make_channel_divisible
from mmrazor.registry import MODELS
from mmrazor.structures.subnet.fix_subnet import (export_fix_subnet,
load_fix_subnet)
from mmrazor.utils import RuntimeInfo, print_log
def clean_params_init_info(model: nn.Module):
"""Clean param init info."""
if hasattr(model, '_params_init_info'):
delattr(model, '_params_init_info')
for module in model.modules():
if hasattr(module, '_params_init_info'):
delattr(module, '_params_init_info')
def clean_init_cfg(model: BaseModule):
"""Clean init cfg."""
for module in model.modules():
if module is model:
continue
if isinstance(module, BaseModule):
module.init_cfg = {}
def hacky_init_weights_wrapper(fix_subnet):
"""Return an init_weights replacement that prevents the model from being
initialized again after it is built.
It also saves fix_subnet.json once RuntimeInfo is ready.
"""
def hacky_init_weights(model):
if dist.get_rank() == 0:
try:
work_dir = RuntimeInfo.work_dir()
fileio.dump(
fix_subnet, work_dir + '/fix_subnet.json', indent=4)
print_log(
f'save pruning structure in {work_dir}/fix_subnet.json')
except Exception:
pass
return hacky_init_weights
@MODELS.register_module()
def GroupFisherSubModel(
algorithm,
divisor=1,
**kargs,
):
"""Convert an algorithm (with an architecture) to a static pruned
architecture.
Args:
algorithm (Union[BaseAlgorithm, dict]): The pruning algorithm to
finetune.
divisor (int): The divisor to make the channel number
divisible. Defaults to 1.
Returns:
nn.Module: a static model.
"""
# init algorithm
if isinstance(algorithm, dict):
algorithm = MODELS.build(algorithm) # type: ignore
assert isinstance(algorithm, BaseAlgorithm)
algorithm.init_weights()
clean_params_init_info(algorithm)
pruning_structure = algorithm.mutator.choice_template
print_log('PruneSubModel get pruning structure:')
print_log(json.dumps(pruning_structure, indent=4))
# to static model
fix_mutable = export_fix_subnet(algorithm.architecture)[0]
load_fix_subnet(algorithm.architecture, fix_mutable)
model = algorithm.architecture
# make channel divisible
if divisor != 1:
divisible_structure = make_channel_divisible(
model, divisor=divisor, zero_weight=False)
print_log('PruneSubModel get divisible pruning structure:')
print_log(json.dumps(divisible_structure, indent=4))
pruning_structure = divisible_structure
# refine model
model.data_preprocessor = algorithm.data_preprocessor
if isinstance(model, BaseModel):
model.init_cfg = None
model.init_weights = types.MethodType(
hacky_init_weights_wrapper(pruning_structure), model)
return model
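
The `types.MethodType` trick used for `model.init_weights` can be shown in isolation. This is a minimal sketch with stand-in names (`Model`, `make_noop_init`, `saved`): it assumes nothing about mmengine and only demonstrates how a closure replaces a bound method on a single instance.

```python
import types


class Model:
    """Stand-in for a built model with a normal init_weights method."""

    def init_weights(self):
        return 're-initialized'


def make_noop_init(structure, saved):
    """Build a replacement that records `structure` instead of
    re-initializing any weights."""

    def noop_init(model):
        saved.append(structure)

    return noop_init


model = Model()
saved = []
# bind the closure to this instance only; the class is left untouched
model.init_weights = types.MethodType(
    make_noop_init({'unit': 16}, saved), model)

print(model.init_weights())    # None: the weights are not touched
print(saved)                   # [{'unit': 16}]
print(Model().init_weights())  # 're-initialized'
```

Because the override lives on the instance, other instances of the same class keep the original behavior.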

@ -0,0 +1,230 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
import torch.nn as nn
from mmengine.model.utils import _BatchNormXd
from mmengine.utils.dl_utils.parrots_wrapper import \
SyncBatchNorm as EngineSyncBatchNorm
from torch import distributed as dist
import mmrazor.models.architectures.dynamic_ops as dynamic_ops
from mmrazor.models.mutables.mutable_channel.mutable_channel_container import \
MutableChannelContainer
from mmrazor.models.mutables.mutable_channel.units.l1_mutable_channel_unit import \
L1MutableChannelUnit # noqa
from mmrazor.registry import MODELS
from .ops import GroupFisherConv2d, GroupFisherLinear, GroupFisherMixin
@MODELS.register_module()
class GroupFisherChannelUnit(L1MutableChannelUnit):
"""ChannelUnit for GroupFisher Pruning Algorithm.
Args:
num_channels (int): Number of channels.
normalization_type (str): Type of normalization. It can be one of
['flops', 'act', 'none']. Defaults to 'flops'.
mutate_linear (bool): Whether to prune linear layers.
"""
def __init__(self,
num_channels: int,
normalization_type: str = 'flops',
mutate_linear=False,
*args) -> None:
super().__init__(num_channels, *args)
normalized_fisher_info = torch.zeros([self.num_channels])
self.register_buffer('normalized_fisher_info', normalized_fisher_info)
self.normalized_fisher_info: torch.Tensor
self.hook_handles: List = []
assert normalization_type in ['flops', 'act', 'none']
self.delta_type = normalization_type
self.mutate_linear = mutate_linear
def prepare_for_pruning(self, model: nn.Module) -> None:
"""Prepare for pruning, including register mutable channels.
Args:
model (nn.Module): The model to be pruned.
"""
# register MutableMask
self._replace_with_dynamic_ops(
model, {
nn.Conv2d: GroupFisherConv2d,
nn.BatchNorm2d: dynamic_ops.DynamicBatchNorm2d,
nn.Linear: GroupFisherLinear,
nn.SyncBatchNorm: dynamic_ops.DynamicSyncBatchNorm,
EngineSyncBatchNorm: dynamic_ops.DynamicSyncBatchNorm,
_BatchNormXd: dynamic_ops.DynamicBatchNormXd,
})
self._register_channel_container(model, MutableChannelContainer)
self._register_mutable_channel(self.mutable_channel)
# prune
def try_to_prune_min_channel(self) -> bool:
"""Prune the channel with the minimum value of fisher information."""
if self.mutable_channel.activated_channels > 1:
imp = self.importance()
index = imp.argmin()
self.mutable_channel.mask.scatter_(0, index, 0.0)
return True
else:
return False
@property
def is_mutable(self) -> bool:
"""Whether the unit is mutable."""
mutable = super().is_mutable
if self.mutate_linear:
return mutable
else:
has_linear = False
for layer in self.input_related:
if isinstance(layer.module, nn.Linear):
has_linear = True
return mutable and (not has_linear)
@property
def input_related_dynamic_ops(self):
for channel in self.input_related:
if isinstance(channel.module, GroupFisherMixin):
yield channel.module
@property
def output_related_dynamic_ops(self):
for channel in self.output_related:
if isinstance(channel.module, GroupFisherMixin):
yield channel.module
@property
def dynamic_ops(self):
for module in self.input_related_dynamic_ops:
yield module
for module in self.output_related_dynamic_ops:
yield module
# fisher information recorded
def start_record_fisher_info(self) -> None:
"""Start recording the related fisher info of each channel."""
for module in self.dynamic_ops:
module.start_record()
def end_record_fisher_info(self) -> None:
"""Stop recording the related fisher info of each channel."""
for module in self.dynamic_ops:
module.end_record()
def reset_recorded(self) -> None:
"""Reset the recorded info of each channel."""
for module in self.dynamic_ops:
module.reset_recorded()
# fisher related computation
def importance(self):
"""The importance of each channel."""
fisher = self.normalized_fisher_info.clone()
mask = self.mutable_channel.current_mask
n_mask = (1 - mask.float()).bool()
fisher.masked_fill_(n_mask, fisher.max() + 1)
return fisher
def reset_fisher_info(self) -> None:
"""Reset the related fisher info."""
self.normalized_fisher_info.zero_()
@torch.no_grad()
def update_fisher_info(self) -> None:
"""Update the fisher info of each channel."""
batch_fisher_sum = self.current_batch_fisher
assert isinstance(batch_fisher_sum, torch.Tensor)
if dist.is_initialized():
dist.all_reduce(batch_fisher_sum)
batch_fisher_sum = self._get_normalized_fisher_info(
batch_fisher_sum, self.delta_type)
self.normalized_fisher_info = self.normalized_fisher_info + batch_fisher_sum # noqa
@property
def current_batch_fisher(self) -> torch.Tensor:
"""Accumulate the unit's fisher info of this batch."""
with torch.no_grad():
fisher: torch.Tensor = 0
for module in self.input_related_dynamic_ops:
fisher = fisher + self._fisher_of_a_module(module)
return (fisher**2).sum(0) # shape: [C]
@torch.no_grad()
def _fisher_of_a_module(self, module: GroupFisherMixin) -> torch.Tensor:
"""Calculate the fisher info of one module.
Args:
module (GroupFisherMixin): A GroupFisher op, e.g. `GroupFisherConv2d`.
Returns:
torch.Tensor: A tensor of shape [B, C].
"""
assert len(module.recorded_input) > 0 and \
len(module.recorded_input) == len(module.recorded_grad)
fisher_sum: torch.Tensor = 0
for input, grad_input in zip(module.recorded_input,
module.recorded_grad):
fisher: torch.Tensor = input * grad_input
if len(fisher.shape) == 4:
fisher = fisher.sum(dim=[2, 3])
assert len(fisher.shape) == 2 # B C
fisher_sum = fisher_sum + fisher
assert isinstance(fisher_sum, torch.Tensor)
# expand to full num_channel
batch_size = fisher_sum.shape[0]
mask = self.mutable_channel.current_mask.unsqueeze(0).expand(
[batch_size, self.num_channels])
zeros = fisher_sum.new_zeros([batch_size, self.num_channels])
fisher_sum = zeros.masked_scatter_(mask, fisher_sum)
return fisher_sum
@torch.no_grad()
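def _get_normalized_fisher_info(self,
fisher_info,
delta_type='flops') -> torch.Tensor:
"""Get the normalized fisher info.
Args:
fisher_info (torch.Tensor): The accumulated fisher info of the unit.
delta_type (str): Type of delta; one of 'flops', 'act' or 'none'.
Defaults to 'flops'.
"""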
fisher = fisher_info.double()
if delta_type == 'flops':
delta_flop = self._delta_flop_of_a_channel
assert delta_flop > 0
fisher = fisher / (float(delta_flop) / 1e9)
elif delta_type == 'act':
delta_memory = self._delta_memory_of_a_channel
assert delta_memory > 0
fisher = fisher / (float(delta_memory) / 1e6)
elif delta_type == 'none':
pass
else:
raise NotImplementedError(delta_type)
return fisher
@property
def _delta_flop_of_a_channel(self) -> torch.Tensor:
"""Calculate the flops of a channel."""
delta_flop = 0
for module in self.output_related_dynamic_ops:
delta_flop += module.delta_flop_of_a_out_channel
for module in self.input_related_dynamic_ops:
delta_flop += module.delta_flop_of_a_in_channel
return delta_flop
@property
def _delta_memory_of_a_channel(self) -> torch.Tensor:
"""Calculate the memory of a channel."""
delta_memory = 0
for module in self.output_related_dynamic_ops:
delta_memory += module.delta_memory_of_a_out_channel
return delta_memory
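
The importance computation in `GroupFisherChannelUnit` can be followed on plain lists. The sketch below is an illustration for a single module, assuming inputs and gradients are given as `[B][C][H*W]` nested lists; `batch_fisher` and `normalize_by_flops` are hypothetical names. In the unit above, the per-sample products of all input-related modules are added before squaring.

```python
from typing import List

Tensor3 = List[List[List[float]]]  # [batch][channel][spatial]


def batch_fisher(inputs: Tensor3, grads: Tensor3) -> List[float]:
    """Per-channel Fisher information of one batch.

    For every sample and channel, sum input * grad over the spatial
    positions, square the result, then sum over the batch.
    """
    num_channels = len(inputs[0])
    fisher = [0.0] * num_channels
    for x_sample, g_sample in zip(inputs, grads):
        for c in range(num_channels):
            s = sum(x * g for x, g in zip(x_sample[c], g_sample[c]))
            fisher[c] += s**2
    return fisher


def normalize_by_flops(fisher: List[float],
                       delta_flops: float) -> List[float]:
    """Divide by the FLOPs saved per channel, expressed in GFLOPs."""
    assert delta_flops > 0
    return [f / (delta_flops / 1e9) for f in fisher]


inputs = [[[1.0, 2.0], [0.5, 0.5]], [[1.0, 0.0], [2.0, 2.0]]]
grads = [[[0.1, 0.1], [1.0, -1.0]], [[0.2, 0.3], [0.5, 0.5]]]
fisher = batch_fisher(inputs, grads)  # approximately [0.13, 4.0]
print(normalize_by_flops(fisher, delta_flops=2e9))
```

Dividing by the saved FLOPs (or by the saved activation memory for `'act'`) lowers the score of channels that are expensive to keep, so they are pruned earlier than cheap channels with the same raw Fisher value.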

@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""This file includes the modules in the impl folder.
As it only records impl modules, it is not initialized automatically.
"""
from mmrazor.implementations.pruning.group_fisher import \
GroupFisherAlgorithm # noqa

@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""This file includes the modules in the impl folder.
As it only records impl modules, it is not initialized automatically.
"""
from mmrazor.implementations.pruning.group_fisher import \
GroupFisherConv2d # noqa
from mmrazor.implementations.pruning.group_fisher import \
GroupFisherLinear # noqa
from mmrazor.implementations.pruning.group_fisher import \
GroupFisherMixin # noqa

@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""This file includes the modules in the impl folder.
As it only records impl modules, it is not initialized automatically.
"""
from mmrazor.implementations.pruning.group_fisher import \
GroupFisherChannelUnit # noqa

@ -10,7 +10,6 @@ from mmengine.utils.dl_utils.parrots_wrapper import \
SyncBatchNorm as EngineSyncBatchNorm
from mmrazor.models.architectures import dynamic_ops
from mmrazor.models.utils import make_divisible
from mmrazor.registry import MODELS
from ..mutable_channel_container import MutableChannelContainer
from ..sequential_mutable_channel import SquentialMutableChannel
@ -134,6 +133,7 @@ class SequentialMutableChannelUnit(MutableChannelUnit):
def _make_divisible(self, choice_int: int):
"""Make the choice divisible."""
from mmrazor.models.utils import make_divisible
return make_divisible(choice_int, self.divisor, self.min_value,
self.min_ratio)

@ -66,6 +66,7 @@ class ChannelMutator(BaseMutator, Generic[ChannelUnitType]):
dict,
Type[MutableChannelUnit]] = SequentialMutableChannelUnit,
parse_cfg: Dict = dict(
_scope_='mmrazor',
type='ChannelAnalyzer',
demo_input=(1, 3, 224, 224),
tracer_type='BackwardTracer'),

@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""This file includes the modules in the impl folder.
As it only records impl modules, it is not initialized automatically.
"""
from mmrazor.implementations.pruning.group_fisher import \
GroupFisherChannelMutator # noqa

@ -6,6 +6,7 @@ from mmengine.model import BaseModel
from mmrazor.registry import TASK_UTILS
from mmrazor.utils import get_placeholder
from ...algorithms.base import BaseAlgorithm
from .demo_inputs import (BaseDemoInput, DefaultMMClsDemoInput,
DefaultMMDemoInput, DefaultMMDetDemoInput,
DefaultMMPoseDemoInput, DefaultMMRotateDemoInput,
@ -70,8 +71,12 @@ def get_default_demo_input_class(model, scope):
def defaul_demo_inputs(model, input_shape, training=False, scope=None):
"""Get demo input according to a model and scope."""
demo_input = get_default_demo_input_class(model, scope)
return demo_input().get_data(model, input_shape, training)
if isinstance(model, BaseAlgorithm):
return defaul_demo_inputs(model.architecture, input_shape, training,
scope)
else:
demo_input = get_default_demo_input_class(model, scope)
return demo_input().get_data(model, input_shape, training)
@TASK_UTILS.register_module()

@ -51,7 +51,9 @@ class DefaultMMDemoInput(BaseDemoInput):
return data
def _get_mm_data(self, model, input_shape, training=False):
return {'inputs': torch.rand(input_shape), 'data_samples': None}
data = {'inputs': torch.rand(input_shape), 'data_samples': None}
data = model.data_preprocessor(data, training)
return data
@TASK_UTILS.register_module()
@ -84,7 +86,7 @@ class DefaultMMDetDemoInput(DefaultMMDemoInput):
"""Helper for get_data, including core logic to generate demo input."""
from mmdet.models import BaseDetector
from mmdet.testing._utils import demo_mm_inputs
assert isinstance(model, BaseDetector)
assert isinstance(model, BaseDetector), f'{type(model)}'
data = demo_mm_inputs(1, [input_shape[1:]], with_mask=True)
data = model.data_preprocessor(data, training)
@ -132,7 +134,7 @@ class DefaultMMPoseDemoInput(DefaultMMDemoInput):
from mmpose.models import TopdownPoseEstimator
from .mmpose_demo_input import demo_mmpose_inputs
assert isinstance(model, TopdownPoseEstimator)
assert isinstance(model, TopdownPoseEstimator), f'{type(model)}'
data = demo_mmpose_inputs(model, input_shape)
return data

@ -13,10 +13,24 @@ from .pooling_layer_counter import * # noqa: F403, F405, F401
from .upsample_layer_counter import UpsampleCounter
__all__ = [
'ReLUCounter', 'PReLUCounter', 'ELUCounter', 'LeakyReLUCounter',
'ReLU6Counter', 'BatchNorm1dCounter', 'BatchNorm2dCounter',
'BatchNorm3dCounter', 'Conv1dCounter', 'Conv2dCounter', 'Conv3dCounter',
'ConvTranspose2dCounter', 'UpsampleCounter', 'LinearCounter',
'GroupNormCounter', 'InstanceNorm1dCounter', 'InstanceNorm2dCounter',
'InstanceNorm3dCounter', 'LayerNormCounter', 'BaseCounter'
'ReLUCounter',
'PReLUCounter',
'ELUCounter',
'LeakyReLUCounter',
'ReLU6Counter',
'BatchNorm1dCounter',
'BatchNorm2dCounter',
'BatchNorm3dCounter',
'Conv1dCounter',
'Conv2dCounter',
'Conv3dCounter',
'ConvTranspose2dCounter',
'UpsampleCounter',
'LinearCounter',
'GroupNormCounter',
'InstanceNorm1dCounter',
'InstanceNorm2dCounter',
'InstanceNorm3dCounter',
'LayerNormCounter',
'BaseCounter',
]

@ -0,0 +1,60 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
from mmrazor.registry import TASK_UTILS
from .conv_layer_counter import Conv2dCounter
from .linear_layer_counter import LinearCounter
@TASK_UTILS.register_module()
class DynamicConv2dCounter(Conv2dCounter):
"""FLOP counter for DynamicConv2d."""
@staticmethod
def add_count_hook(module: nn.Conv2d, input: Tuple[torch.Tensor],
output: torch.Tensor) -> None:
"""Count the flops and params of a DynamicConv2d.
Args:
module (nn.Conv2d): A Conv2d module.
input (Tuple[torch.Tensor]): Input of this module.
output (torch.Tensor): Output of this module.
"""
batch_size = input[0].shape[0]
output_dims = list(output.shape[2:])
kernel_dims = list(module.kernel_size)
out_channels = module.mutable_attrs['out_channels'].activated_channels
in_channels = module.mutable_attrs['in_channels'].activated_channels
groups = module.groups
filters_per_channel = out_channels / groups
conv_per_position_flops = int(
np.prod(kernel_dims)) * in_channels * filters_per_channel
active_elements_count = batch_size * int(np.prod(output_dims))
overall_conv_flops = conv_per_position_flops * active_elements_count
bias_flops = 0
overall_params = conv_per_position_flops
if module.bias is not None:
bias_flops = out_channels * active_elements_count
overall_params += out_channels
overall_flops = overall_conv_flops + bias_flops
module.__flops__ += overall_flops
module.__params__ += int(overall_params)
@TASK_UTILS.register_module()
class DynamicLinearCounter(LinearCounter):
pass
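
The arithmetic in `DynamicConv2dCounter.add_count_hook` can be checked by hand. The following sketch repeats the same formulas on plain integers; `conv2d_flops_and_params` is a hypothetical helper, and `in_channels` and `out_channels` stand for the activated channel counts read from the mutable attributes.

```python
import math
from typing import Tuple


def conv2d_flops_and_params(batch_size: int,
                            output_hw: Tuple[int, int],
                            kernel_size: Tuple[int, int],
                            in_channels: int,
                            out_channels: int,
                            groups: int = 1,
                            bias: bool = True) -> Tuple[int, int]:
    """Return (flops, params) following the counter's formulas."""
    filters_per_channel = out_channels / groups
    per_position = math.prod(kernel_size) * in_channels * filters_per_channel
    active_elements = batch_size * math.prod(output_hw)
    flops = per_position * active_elements
    params = per_position
    if bias:
        flops += out_channels * active_elements
        params += out_channels
    return int(flops), int(params)


# 3x3 conv, 16 -> 32 active channels, 8x8 output map, batch size 1
print(conv2d_flops_and_params(1, (8, 8), (3, 3), 16, 32))  # (296960, 4640)
```

Since the counter reads activated rather than declared channel numbers, the reported FLOPs shrink as channels are masked out during pruning.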

@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""This file includes the modules in the impl folder.
As it only records impl modules, it is not initialized automatically.
"""
from mmrazor.implementations.pruning.group_fisher import ( # noqa
GroupFisherConv2dCounter, GroupFisherLinearCounter)

@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""This module is used to expand the channels of a supernet.
We only expose some tool functions, rather than all DynamicOps and
MutableChannelUnits, as they use a few hacky operations.
"""
from .tools import (expand_expandable_dynamic_model, expand_static_model,
make_channel_divisible, to_expandable_model)
__all__ = [
'make_channel_divisible',
'to_expandable_model',
'expand_expandable_dynamic_model',
'expand_static_model',
]

@ -0,0 +1,237 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmrazor.models.architectures import dynamic_ops
from mmrazor.models.mutables import MutableChannelContainer
class ExpandableMixin:
"""This mixin cooperates with dynamic ops.
It defines interfaces to expand the channels of ops. With it, we can get a
wider network than the original supernet.
"""
def expand(self, zero=False):
"""Expand the op.
Args:
zero (bool, optional): whether to set new weights to zero. Defaults
to False.
"""
return self.get_expand_op(
self.expanded_in_channel,
self.expanded_out_channel,
zero=zero,
)
def get_expand_op(self, in_c, out_c, zero=False):
"""Get an expanded op.
Args:
in_c (int): New input channels
out_c (int): New output channels
zero (bool, optional): Whether to zero new weights. Defaults to
False.
"""
pass
@property
def _original_in_channel(self):
"""Return original in channel."""
raise NotImplementedError()
@property
def _original_out_channel(self):
"""Return original out channel."""
raise NotImplementedError()
@property
def expanded_in_channel(self):
"""Return expanded in channel number."""
if self.in_mutable is not None:
return self.in_mutable.current_mask.numel()
else:
return self._original_in_channel
@property
def expanded_out_channel(self):
"""Return expanded out channel number."""
if self.out_mutable is not None:
return self.out_mutable.current_mask.numel()
else:
return self._original_out_channel
@property
def mutable_in_mask(self):
"""Return the mutable in mask."""
if self.in_mutable is not None:
return self.in_mutable.current_mask
else:
if hasattr(self, 'weight'):
return self.weight.new_ones([self.expanded_in_channel])
else:
return torch.ones([self.expanded_in_channel])
@property
def mutable_out_mask(self):
"""Return the mutable out mask."""
if self.out_mutable is not None:
return self.out_mutable.current_mask
else:
if hasattr(self, 'weight'):
return self.weight.new_ones([self.expanded_out_channel])
else:
return torch.ones([self.expanded_out_channel])
@property
def in_mutable(self) -> MutableChannelContainer:
"""In channel mask."""
return self.get_mutable_attr('in_channels') # type: ignore
@property
def out_mutable(self) -> MutableChannelContainer:
"""Out channel mask."""
return self.get_mutable_attr('out_channels') # type: ignore
def zero_weight_(self: nn.Module):
"""Zero all weights."""
for p in self.parameters():
p.data.zero_()
@torch.no_grad()
def expand_matrix(self, weight: torch.Tensor, old_weight: torch.Tensor):
"""Expand weight matrix."""
assert len(weight.shape) == 3 # out in c
assert len(old_weight.shape) == 3 # out in c
mask = self.mutable_out_mask.float().unsqueeze(
-1) * self.mutable_in_mask.float().unsqueeze(0)
mask = mask.unsqueeze(-1).expand(*weight.shape)
weight.data.masked_scatter_(mask.bool(), old_weight)
return weight
@torch.no_grad()
def expand_vector(self, weight: torch.Tensor, old_weight: torch.Tensor):
"""Expand weight vector which has the shape of [out, c]."""
assert len(weight.shape) == 2 # out c
assert len(old_weight.shape) == 2 # out c
mask = self.mutable_out_mask
mask = mask.unsqueeze(-1).expand(*weight.shape)
weight.data.masked_scatter_(mask.bool(), old_weight)
return weight
@torch.no_grad()
def expand_bias(self, bias: torch.Tensor, old_bias: torch.Tensor):
"""Expand bias."""
assert len(bias.shape) == 1 # out c
assert len(old_bias.shape) == 1 # out c
return self.expand_vector(bias.unsqueeze(-1),
old_bias.unsqueeze(-1)).squeeze(1)
class ExpandableConv2d(dynamic_ops.DynamicConv2d, ExpandableMixin):
@property
def _original_in_channel(self):
return self.in_channels
@property
def _original_out_channel(self):
return self.out_channels
def get_expand_op(self, in_c, out_c, zero=False):
if self.groups == 1:
return self._get_expand_op_normal_conv(in_c, out_c, zero=zero)
elif self.in_channels == self.out_channels == self.groups:
return self._get_expand_op_dw_conv(in_c, out_c, zero=zero)
else:
raise NotImplementedError('Groupwise conv is not supported yet.')
def _get_expand_op_normal_conv(self, in_c, out_c, zero=False):
module = nn.Conv2d(in_c, out_c, self.kernel_size, self.stride,
self.padding, self.dilation, self.groups, self.bias
is not None, self.padding_mode)
if zero:
ExpandableMixin.zero_weight_(module)
weight = self.expand_matrix(
module.weight.flatten(2), self.weight.flatten(2))
module.weight.data = weight.reshape(module.weight.shape)
if module.bias is not None and self.bias is not None:
bias = self.expand_vector(
module.bias.unsqueeze(-1), self.bias.unsqueeze(-1))
module.bias.data = bias.reshape(module.bias.shape)
return module
def _get_expand_op_dw_conv(self, in_c, out_c, zero=False):
assert in_c == out_c
module = nn.Conv2d(in_c, out_c, self.kernel_size, self.stride,
self.padding, self.dilation, in_c, self.bias
is not None, self.padding_mode)
if zero:
ExpandableMixin.zero_weight_(module)
weight = self.expand_vector(
module.weight.flatten(1), self.weight.flatten(1))
module.weight.data = weight.reshape(module.weight.shape)
if module.bias is not None and self.bias is not None:
bias = self.expand_vector(
module.bias.unsqueeze(-1), self.bias.unsqueeze(-1))
module.bias.data = bias.reshape(module.bias.shape)
return module
class ExpandLinear(dynamic_ops.DynamicLinear, ExpandableMixin):
@property
def _original_in_channel(self):
return self.in_features
@property
def _original_out_channel(self):
return self.out_features
def get_expand_op(self, in_c, out_c, zero=False):
module = nn.Linear(in_c, out_c, self.bias is not None)
if zero:
ExpandableMixin.zero_weight_(module)
weight = self.expand_matrix(
module.weight.unsqueeze(-1), self.weight.unsqueeze(-1))
module.weight.data = weight.reshape(module.weight.shape)
if module.bias is not None:
bias = self.expand_vector(
module.bias.unsqueeze(-1), self.bias.unsqueeze(-1))
module.bias.data = bias.reshape(module.bias.shape)
return module
class ExpandableBatchNorm2d(dynamic_ops.DynamicBatchNorm2d, ExpandableMixin):
@property
def _original_in_channel(self):
return self.num_features
@property
def _original_out_channel(self):
return self.num_features
def get_expand_op(self, in_c, out_c, zero=False):
assert in_c == out_c
module = nn.BatchNorm2d(in_c, self.eps, self.momentum, self.affine,
self.track_running_stats)
if zero:
ExpandableMixin.zero_weight_(module)
if module.running_mean is not None:
module.running_mean.data = self.expand_bias(
module.running_mean, self.running_mean)
if module.running_var is not None:
module.running_var.data = self.expand_bias(module.running_var,
self.running_var)
module.weight.data = self.expand_bias(module.weight, self.weight)
module.bias.data = self.expand_bias(module.bias, self.bias)
return module
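
The `expand_vector` and `expand_bias` helpers rely on `masked_scatter_` to copy old values into the positions where the channel mask is set. A minimal list-based sketch of that idea follows; `expand_by_mask` is a hypothetical name and `fill` stands in for whatever the freshly created module holds at the new positions (zeros when `zero=True`).

```python
from typing import List


def expand_by_mask(old_values: List[float],
                   mask: List[bool],
                   fill: float = 0.0) -> List[float]:
    """Place old values, in order, where mask is True; use fill elsewhere."""
    # one old value is needed for every kept position
    assert sum(mask) == len(old_values)
    it = iter(old_values)
    return [next(it) if keep else fill for keep in mask]


old_bias = [1.0, 2.0]
mask = [True, False, True, False]  # two original channels, two new ones
print(expand_by_mask(old_bias, mask))  # [1.0, 0.0, 2.0, 0.0]
```

Positions where the mask is unset keep the new module's own values, so the original channels behave as before once the extra channels contribute nothing.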

@ -0,0 +1,84 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict
import torch.nn as nn
from mmrazor.models.mutators import ChannelMutator
from .ops import ExpandableMixin
from .unit import ExpandableUnit
def to_expandable_model(model: nn.Module) -> ChannelMutator[ExpandableUnit]:
"""Convert a static model to an expandable model."""
state_dict = model.state_dict()
mutator = ChannelMutator[ExpandableUnit](channel_unit_cfg=ExpandableUnit)
mutator.prepare_from_supernet(model)
model.load_state_dict(state_dict)
return mutator
def expand_expandable_dynamic_model(model: nn.Module, zero=False) -> nn.Module:
"""Expand an expandable model and return an expanded static model.
Args:
model (nn.Module): The model to be expanded.
zero (bool, optional): Whether to zero expanded weight. Defaults to
False.
"""
def traverse_children(module: nn.Module) -> None:
for name, mutable in module.items():
if isinstance(mutable, ExpandableMixin):
module[name] = mutable.expand(zero=zero)
if hasattr(mutable, '_modules'):
traverse_children(mutable._modules)
if isinstance(model, ExpandableMixin):
raise RuntimeError('Root model cannot be a dynamic op.')
if hasattr(model, '_modules'):
traverse_children(model._modules)
return model
def expand_static_model(model: nn.Module, structure: Dict, zero_weight=True):
"""Expand the channels of a model.
Args:
model (nn.Module): the model to be expanded.
structure (Dict): the channel structure for the model.
zero_weight (bool, optional): whether to zero the expanded weights.
Defaults to True.
"""
mutator = to_expandable_model(model)
for key, value in structure.items():
mutator._name2unit[key].expand_to(value)
expand_expandable_dynamic_model(model, zero=zero_weight)
return model
def make_channel_divisible(model: nn.Module, divisor, zero_weight=True):
"""Expand the channels of a model and return the new divisible channel
structure.
Args:
model (nn.Module): the model to be expanded.
divisor (int): the divisor to make the channels divisible.
zero_weight (bool, optional): whether to zero the expanded weights.
Defaults to True.
"""
# to sta
mutator = to_expandable_model(model)
structure = mutator.choice_template
for key, num in structure.items():
unit = mutator._name2unit[key]
if num % divisor == 0:
continue
else:
num = (num // divisor + 1) * divisor
num = max(num, unit.num_channels)
unit.expand_to(num)
model = expand_expandable_dynamic_model(model, zero=zero_weight)
mutator = to_expandable_model(copy.deepcopy(model))
return mutator.choice_template
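
The rounding step inside `make_channel_divisible` can be looked at in isolation. The following is a minimal sketch in plain Python, not the mmrazor implementation: `round_up_to_divisor` and the unit names in `structure` are hypothetical and exist only for illustration.

```python
def round_up_to_divisor(num: int, divisor: int) -> int:
    """Return the smallest multiple of `divisor` that is >= `num`."""
    if num % divisor == 0:
        # Already divisible, so the channel count is left untouched.
        return num
    # Same arithmetic as in make_channel_divisible: go to the next multiple.
    return (num // divisor + 1) * divisor


# Hypothetical channel structure: unit name -> current channel count.
structure = {'unit_a': 30, 'unit_b': 64, 'unit_c': 7}
divisible = {name: round_up_to_divisor(n, 8) for name, n in structure.items()}
print(divisible)  # {'unit_a': 32, 'unit_b': 64, 'unit_c': 8}
```

Because the result is never smaller than the input, the channel counts only grow, which is consistent with the function expanding units rather than pruning them.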

View File

@ -0,0 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmrazor.models.mutables import (L1MutableChannelUnit,
MutableChannelContainer)
from .ops import ExpandableBatchNorm2d, ExpandableConv2d, ExpandLinear
class ExpandableUnit(L1MutableChannelUnit):
"""The unit that replaces modules in place with expandable dynamic ops."""
def prepare_for_pruning(self, model: nn.Module):
self._replace_with_dynamic_ops(
model, {
nn.Conv2d: ExpandableConv2d,
nn.BatchNorm2d: ExpandableBatchNorm2d,
nn.Linear: ExpandLinear,
})
self._register_channel_container(model, MutableChannelContainer)
self._register_mutable_channel(self.mutable_channel)
def expand(self, num):
expand_mask = self.mutable_channel.mask.new_zeros([num])
mask = torch.cat([self.mutable_channel.mask, expand_mask])
self.mutable_channel.mask = mask
def expand_to(self, num):
self.expand(num - self.num_channels)
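
The mask arithmetic in `expand` and `expand_to` above can be sketched without torch. This is only an illustration under stated assumptions: plain Python lists stand in for the mask tensor, `len(mask)` stands in for `self.num_channels`, and the `ValueError` guard is an addition of this sketch rather than behavior taken from the source.

```python
from typing import List


def expand(mask: List[float], num: int) -> List[float]:
    """Append `num` zero entries, mirroring torch.cat([mask, zeros])."""
    if num < 0:
        # Guard added for this sketch only; it is not in the original code.
        raise ValueError('this sketch only grows a mask')
    return mask + [0.0] * num


def expand_to(mask: List[float], target: int) -> List[float]:
    """Grow the mask until it holds `target` entries."""
    return expand(mask, target - len(mask))


mask = [1.0, 1.0, 1.0]
print(expand_to(mask, 5))  # [1.0, 1.0, 1.0, 0.0, 0.0]
```

The appended entries are zeros, so the existing channels keep their mask values and only the mask length changes.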

View File

@ -3,6 +3,7 @@ from .index_dict import IndexDict
from .log_tools import get_level, print_log
from .misc import find_latest_checkpoint
from .placeholder import get_placeholder
from .runtime_info import RuntimeInfo
from .setup_env import register_all_modules, setup_multi_processes
from .typing import (FixMutable, MultiMutatorsRandomSubnet,
SingleMutatorRandomSubnet, SupportRandomSubnet,
@ -12,5 +13,5 @@ __all__ = [
'find_latest_checkpoint', 'setup_multi_processes', 'register_all_modules',
'FixMutable', 'ValidFixMutable', 'SingleMutatorRandomSubnet',
'MultiMutatorsRandomSubnet', 'SupportRandomSubnet', 'get_placeholder',
'IndexDict', 'get_level', 'print_log'
'IndexDict', 'get_level', 'print_log', 'RuntimeInfo'
]

View File

@ -0,0 +1,58 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from mmengine import Config, MessageHub
class RuntimeInfo():
"""A tool to get runtime info from MessageHub."""
@classmethod
def info(cls):
hub = MessageHub.get_current_instance()
return hub.runtime_info
@classmethod
def get_info(cls, key):
info = cls.info()
if key in info:
return info[key]
else:
raise KeyError(key)
@classmethod
def epoch(cls):
return cls.get_info('epoch')
@classmethod
def max_epochs(cls):
return cls.get_info('max_epochs')
@classmethod
def iter(cls):
return cls.get_info('iter')
@classmethod
def max_iters(cls):
return cls.get_info('max_iters')
@classmethod
def iter_by_epoch(cls):
iter_per_epoch = math.ceil(cls.max_iters() / cls.max_epochs())
return cls.iter() % iter_per_epoch
@classmethod
def iter_pre_epoch(cls):
iter_per_epoch = math.ceil(cls.max_iters() / cls.max_epochs())
return iter_per_epoch
@classmethod
def config(cls):
cfg: str = cls.get_info('cfg')
config = Config.fromstring(cfg, '.py')
return config
@classmethod
def work_dir(cls):
config = cls.config()
return config['work_dir']
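
The arithmetic behind `iter_by_epoch` and `iter_pre_epoch` can be checked without a `MessageHub`. The sketch below is a standalone illustration, assuming the three values are passed in directly; the function names `iters_per_epoch` and `iter_in_epoch` are hypothetical and not part of `RuntimeInfo`.

```python
import math


def iters_per_epoch(max_iters: int, max_epochs: int) -> int:
    """Iterations in one epoch, rounded up as in RuntimeInfo."""
    return math.ceil(max_iters / max_epochs)


def iter_in_epoch(cur_iter: int, max_iters: int, max_epochs: int) -> int:
    """Position of the global iteration counter inside its epoch."""
    return cur_iter % iters_per_epoch(max_iters, max_epochs)


# Two epochs of ten iterations each; global iteration 13 is the
# fourth iteration (index 3) of the second epoch.
print(iters_per_epoch(20, 2))    # 10
print(iter_in_epoch(13, 20, 2))  # 3
```

Note that the ceiling means the per-epoch count is rounded up when `max_iters` is not an exact multiple of `max_epochs`.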

View File

@ -63,6 +63,7 @@ def register_all_modules(init_default_scope: bool = True) -> None:
import mmrazor.datasets # noqa: F401,F403
import mmrazor.engine # noqa: F401,F403
import mmrazor.implementations # noqa: F401,F403
import mmrazor.models # noqa: F401,F403
import mmrazor.structures # noqa: F401,F403
if init_default_scope:

View File

@ -78,6 +78,7 @@ class ModuleWithUntracableMethod(nn.Module):
x = x * -2
return x
@MODELS.register_module()
class UntracableBackBone(nn.Module):
@ -106,7 +107,6 @@ class UntracableModel(nn.Module):
return self.head(self.backbone(x))
class ConvAttnModel(Module):
def __init__(self) -> None:
@ -123,6 +123,7 @@ class ConvAttnModel(Module):
x_last = self.conv2(x_attn)
return self.head(x_last)
@MODELS.register_module()
class LinearHeadForTest(Module):
@ -623,6 +624,27 @@ class SelfAttention(nn.Module):
return self.proj(y)
def MMClsResNet18() -> BaseModel:
model_cfg = dict(
_scope_='mmcls',
type='ImageClassifier',
backbone=dict(
type='ResNet',
depth=18,
num_stages=4,
out_indices=(3, ),
style='pytorch'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=512,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))
return MODELS.build(model_cfg)
# models with dynamicop
@ -682,7 +704,7 @@ class SampleExpandDerivedMutable(BaseMutable):
def current_choice(self, choice):
super().current_choice(choice)
class DynamicLinearModel(nn.Module):
"""
x
@ -843,7 +865,7 @@ class DynamicMMBlock(nn.Module):
[4, 6, 1],
[4, 6, 1],
[6, 6, 1],
[6, 6, 1]
[6, 6, 1],
],
num_out_channels=[ # [min_channel, max_channel, step]
[16, 24, 8],
@ -852,11 +874,11 @@ class DynamicMMBlock(nn.Module):
[64, 72, 8],
[112, 128, 8],
[192, 216, 8],
[216, 224, 8]
[216, 224, 8],
])
def __init__(
self,
self,
conv_cfg: Dict = dict(type='mmrazor.BigNasConv2d'),
norm_cfg: Dict = dict(type='mmrazor.DynamicBatchNorm2d'),
fine_grained_mode: bool = False,
@ -936,12 +958,11 @@ class DynamicMMBlock(nn.Module):
act_cfg=dict(type='Swish')))]))
self.add_module('last_conv', last_layers)
self.layers.append(last_layers)
self.register_mutables()
def _make_single_layer(self, out_channels, num_blocks,
kernel_sizes, expand_ratios,
act_cfg, stride, use_se):
def _make_single_layer(self, out_channels, num_blocks, kernel_sizes,
expand_ratios, act_cfg, stride, use_se):
_layers = []
for i in range(max(num_blocks)):
if i >= 1:

View File

@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.

View File

@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.

View File

@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.

View File

@ -0,0 +1,68 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmcls.structures import ClsDataSample
from mmengine import MessageHub
from mmrazor.implementations.pruning.group_fisher.algorithm import \
GroupFisherAlgorithm
from mmrazor.implementations.pruning.group_fisher.ops import GroupFisherConv2d
from ....data.models import MMClsResNet18
if torch.cuda.is_available():
DEVICE = torch.device('cuda:0')
else:
DEVICE = torch.device('cpu')
class TestGroupFisherPruneAlgorithm(TestCase):
def fake_cifar_data(self):
imgs = torch.randn(16, 3, 32, 32).to(DEVICE)
data_samples = [
ClsDataSample().set_gt_label(torch.randint(0, 10,
(16, ))).to(DEVICE)
]
return {'inputs': imgs, 'data_samples': data_samples}
def test_group_fisher_prune(self):
data = self.fake_cifar_data()
MUTATOR_CONFIG = dict(
type='GroupFisherChannelMutator',
parse_cfg=dict(
type='ChannelAnalyzer', tracer_type='BackwardTracer'),
channel_unit_cfg=dict(type='GroupFisherChannelUnit'))
epoch = 2
interval = 1
algorithm = GroupFisherAlgorithm(
MMClsResNet18(), mutator=MUTATOR_CONFIG,
interval=interval).to(DEVICE)
mutator = algorithm.mutator
for e in range(epoch):
for ite in range(10):
self._set_epoch_ite(e, ite, epoch)
algorithm.forward(
data['inputs'], data['data_samples'], mode='loss')
self.gen_fake_grad(mutator)
self.assertEqual(interval, algorithm.interval)
def gen_fake_grad(self, mutator):
for unit in mutator.mutable_units:
for channel in unit.input_related:
module = channel.module
if isinstance(module, GroupFisherConv2d):
module.recorded_grad = module.recorded_input
def _set_epoch_ite(self, epoch, ite, max_epoch):
iter_per_epoch = 10
message_hub = MessageHub.get_current_instance()
message_hub.update_info('epoch', epoch)
message_hub.update_info('max_epochs', max_epoch)
message_hub.update_info('max_iters', max_epoch * 10)
message_hub.update_info('iter', ite + iter_per_epoch * epoch)

View File

@ -0,0 +1,49 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
from unittest import TestCase
from mmengine import fileio
from mmrazor.implementations.pruning.group_fisher.prune_deploy_sub_model import \
GroupFisherDeploySubModel # noqa
from ....data.models import MMClsResNet18
from .test_prune_sub_model import PruneAlgorithm, get_model_structure
class TestPruneDeploySubModel(TestCase):
def test_build_sub_model(self):
model = MMClsResNet18()
parse_cfg = dict(
_scope_='mmrazor',
type='ChannelAnalyzer',
demo_input=(1, 3, 224, 224),
tracer_type='BackwardTracer')
# get structure
algorithm = PruneAlgorithm(copy.deepcopy(model))
algorithm.random_prune()
structure = algorithm.mutator.current_choices
# test divisor
wrapper = GroupFisherDeploySubModel(
copy.deepcopy(model), structure, divisor=1, parse_cfg=parse_cfg)
self.assertSequenceEqual(
list(structure.values()),
list(get_model_structure(wrapper).values()))
wrapper = GroupFisherDeploySubModel(
copy.deepcopy(model), structure, divisor=8, parse_cfg=parse_cfg)
self.assertSequenceEqual(
list(structure.values()),
list(get_model_structure(wrapper).values()))
mutable_path = os.path.dirname(__file__) + '/mutable.json'
fileio.dump(algorithm.mutator.current_choices, mutable_path)
GroupFisherDeploySubModel(
copy.deepcopy(model),
divisor=1,
mutable_cfg=mutable_path,
parse_cfg=parse_cfg)
os.remove(mutable_path)

View File

@ -0,0 +1,64 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, Union
from unittest import TestCase
import torch
from mmrazor.implementations.pruning.group_fisher.prune_sub_model import \
GroupFisherSubModel
from mmrazor.models import BaseAlgorithm
from mmrazor.models.mutators import ChannelMutator
from mmrazor.registry import MODELS
from ....data.models import MMClsResNet18
class PruneAlgorithm(BaseAlgorithm):
def __init__(self,
architecture,
mutator: Union[Dict, ChannelMutator] = dict(
type='ChannelMutator',
channel_unit_cfg=dict(
type='SequentialMutableChannelUnit')),
data_preprocessor=None,
init_cfg=None) -> None:
super().__init__(
architecture, data_preprocessor, init_cfg, module_inplace=False)
if isinstance(mutator, dict):
mutator = MODELS.build(mutator)
assert isinstance(mutator, ChannelMutator)
self.mutator = mutator
mutator.prepare_from_supernet(self.architecture)
def random_prune(self):
choices = self.mutator.sample_choices()
self.mutator.set_choices(choices)
def get_model_structure(model):
algorithm = PruneAlgorithm(copy.deepcopy(model))
return algorithm.mutator.current_choices
class TestPruneSubModel(TestCase):
def test_build_sub_model(self):
x = torch.rand([1, 3, 224, 224])
model = MMClsResNet18()
algorithm = PruneAlgorithm(model)
algorithm.random_prune()
# test divisor
static_model1 = GroupFisherSubModel(algorithm, divisor=1)
self.assertSequenceEqual(
list(algorithm.mutator.current_choices.values()),
list(get_model_structure(static_model1).values()))
static_model2 = GroupFisherSubModel(algorithm, divisor=8)
for value in get_model_structure(static_model2).values():
self.assertTrue(value % 8 == 0)
y1 = static_model1(x)
y2 = static_model2(x)
self.assertTrue((y1 - y2).abs().max() < 1e-3)

View File

@ -0,0 +1,44 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import torch
from mmrazor.implementations.pruning.group_fisher import \
GroupFisherChannelMutator
from ....data.models import MMClsResNet18
class TestGroupFisherChannelUnit(unittest.TestCase):
def test_init(self):
model = MMClsResNet18()
mutator = GroupFisherChannelMutator(
parse_cfg=dict(
type='ChannelAnalyzer',
demo_input=(1, 3, 224, 224),
tracer_type='BackwardTracer'))
mutator.prepare_from_supernet(model)
x = torch.rand([1, 3, 224, 224])
mutator.start_record_info()
for i in range(2):
model.train()
loss = model(x).sum()
loss.backward()
mutator.end_record_info()
for unit in mutator.mutable_units:
for module in unit.input_related_dynamic_ops:
self.assertEqual(len(module.recorded_input), 2)
self.assertEqual(len(module.recorded_grad), 2)
self.assertIsInstance(module.recorded_grad[0], torch.Tensor)
unit = mutator.mutable_units[0]
fisher = unit._fisher_of_a_module(next(unit.input_related_dynamic_ops))
self.assertEqual(list(fisher.shape), [1, unit.num_channels])
fisher = unit.current_batch_fisher
self.assertEqual(list(fisher.shape), [unit.num_channels])
fisher = unit._get_normalized_fisher_info(fisher, unit.delta_type)
unit.update_fisher_info()

View File

@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.

View File

@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.

View File

@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import torch
from mmrazor.models.mutables import SimpleMutableChannel
from mmrazor.models.utils.expandable_utils import (
expand_expandable_dynamic_model, make_channel_divisible,
to_expandable_model)
from mmrazor.models.utils.expandable_utils.ops import ExpandLinear
from ....data.models import DwConvModel, MultiConcatModel, SingleLineModel
class TestExpand(unittest.TestCase):
def test_expand(self):
for Model in [MultiConcatModel, DwConvModel]:
x = torch.rand([1, 3, 224, 224])
model = Model()
print(model)
mutator = to_expandable_model(model)
print(mutator.choice_template)
print(model)
y1 = model(x)
for unit in mutator.mutable_units:
unit.expand(10)
print(unit.mutable_channel.mask.shape)
expand_expandable_dynamic_model(model, zero=True)
print(model)
y2 = model(x)
self.assertTrue((y1 - y2).abs().max() < 1e-3)
def test_expand_static_model(self):
x = torch.rand([1, 3, 224, 224])
model = SingleLineModel()
y1 = model(x)
make_channel_divisible(model, divisor=4)
y2 = model(x)
print(y1.reshape([-1])[:5])
print(y2.reshape([-1])[:5])
self.assertTrue((y1 - y2).abs().max() < 1e-3)
def test_ExpandLinear(self):
linear = ExpandLinear(3, 3)
mutable_in = SimpleMutableChannel(3)
mutable_out = SimpleMutableChannel(3)
linear.register_mutable_attr('in_channels', mutable_in)
linear.register_mutable_attr('out_channels', mutable_out)
print(linear.weight)
mutable_in.mask = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0])
mutable_out.mask = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0])
linear_ex = linear.expand(zero=True)
print(linear_ex.weight)
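
The tests above assert that outputs before and after a zero-weight expansion differ by less than `1e-3`. The reason this can hold is sketched below with two bias-free linear layers written in plain Python. This is an assumption-laden illustration, not the mmrazor code: `matvec`, `expand_out` and `expand_in` are hypothetical helpers, and nested lists stand in for weight tensors.

```python
def matvec(w, x):
    """Multiply a weight matrix (list of rows) by a vector."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def expand_out(w, extra):
    """Add `extra` output channels as all-zero rows."""
    in_c = len(w[0])
    return [row[:] for row in w] + [[0.0] * in_c for _ in range(extra)]


def expand_in(w, extra):
    """Add `extra` input channels as all-zero columns."""
    return [row + [0.0] * extra for row in w]


w1 = [[1.0, 2.0], [3.0, 4.0]]  # first layer: 2 -> 2 channels
w2 = [[0.5, -1.0]]             # second layer: 2 -> 1 channel
x = [1.0, 1.0]

y_before = matvec(w2, matvec(w1, x))

# Widen the hidden dimension from 2 to 4 with zero weights on both sides.
w1_big = expand_out(w1, 2)
w2_big = expand_in(w2, 2)
y_after = matvec(w2_big, matvec(w1_big, x))

print(y_before, y_after)  # [-5.5] [-5.5]
```

The new hidden channels produce zeros and are multiplied by zeros in the next layer, so they contribute nothing to the final output.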

View File

@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from mmengine import Config
from mmrazor.models.algorithms import ItePruneAlgorithm
from mmrazor.models.task_modules import ResourceEstimator
from mmrazor.models.task_modules.demo_inputs import DefaultDemoInput
from mmrazor.registry import MODELS
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('config')
parser.add_argument('-H', default=224, type=int)
parser.add_argument('-W', default=224, type=int)
args = parser.parse_args()
return args
def input_generator_wrapper(model, shape, training, scope=None):
def input_generator(input_shape):
inputs = DefaultDemoInput(scope=scope).get_data(
model, input_shape=input_shape, training=training)
if isinstance(inputs, dict) and 'mode' in inputs:
inputs['mode'] = 'tensor'
return inputs
return input_generator
if __name__ == '__main__':
args = parse_args()
config = Config.fromfile(args.config)
H = args.H
W = args.W
default_scope = config['default_scope']
model_config = config['model']
# model_config['_scope_'] = default_scope
model: ItePruneAlgorithm = MODELS.build(model_config)
estimator = ResourceEstimator(
flops_params_cfg=dict(
input_shape=(1, 3, H, W),
print_per_layer_stat=False,
input_constructor=input_generator_wrapper(
model,
(1, 3, H, W),
training=False,
scope=default_scope,
)))
result = estimator.estimate(model)
print(result)