Inhence groupfisher (#474)

* update

* update

* add pose

* update

* fix bug in mmpose demo input

* add dist

* add config

* update

* update for deploy

* fix bug

* remove dist and make no positional warn only once

* fix bug

* update

* fix for ci

* update readme

---------

Co-authored-by: liukai <your_email@abc.example>
pull/479/head
LKJacky 2023-03-09 16:34:28 +08:00 committed by GitHub
parent 01f671c72d
commit 9446b301a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 772 additions and 113 deletions

View File

@ -12,28 +12,45 @@ Network compression has been widely studied since it is able to reduce the memor
### Classification on ImageNet
| Model | Top-1 | Gap | Flop(G) | Remain(%) | Parameters(M) | Remain(%) | Config | Download |
| ------------------------ | ----- | ----- | ------- | --------- | ------------- | --------- | ------------------------------------- | ----------------------------------------------------- |
| ResNet50 | 76.55 | - | 4.11 | - | 25.6 | - | [mmcls][cls_r50_c] | [model][cls_r50_m] |
| ResNet50_pruned_act | 75.22 | -1.33 | 2.06 | 50.1% | 16.3 | 63.7% | [prune][r_a_pc] \| [finetune][r_a_fc] | [pruned][r_a_p] \| [finetuned][r_a_f] \| [log][r_a_l] |
| ResNet50_pruned_flops | 75.61 | -0.94 | 2.06 | 50.1% | 16.3 | 63.7% | [prune][r_f_pc] \| [finetune][r_f_fc] | [pruned][r_f_p] \| [finetuned][r_f_f] \| [log][r_f_l] |
| MobileNetV2 | 71.86 | - | 0.313 | - | 3.51 | - | [mmcls][cls_m_c] | [model][cls_m_m] |
| MobileNetV2_pruned_act | 70.82 | -1.04 | 0.207 | 66.1% | 3.18 | 90.6% | [prune][m_a_pc] \| [finetune][m_a_fc] | [pruned][m_a_p] \| [finetuned][m_a_f] \| [log][m_a_l] |
| MobileNetV2_pruned_flops | 70.87 | -0.99 | 0.207 | 66.1% | 2.82 | 88.7% | [prune][m_f_pc] \| [finetune][m_f_fc] | [pruned][m_f_p] \| [finetuned][m_f_f] \| [log][m_f_l] |
### Detection on COCO
| Model(Detector-Backbone) | AP | Gap | Flop(G) | Remain(%) | Parameters(M) | Remain(%) | Config | Download |
| ------------------------------ | ---- | ---- | ------- | --------- | ------------- | --------- | --------------------------------------- | -------------------------------------------------------- |
| RetinaNet-R50-FPN | 36.5 | - | 250 | - | 63.8 | - | [mmdet][det_rt_c] | [model][det_rt_m] |
| RetinaNet-R50-FPN_pruned_act | 36.5 | 0.0 | 126 | 50.4% | 34.6 | 54.2% | [prune][rt_a_pc] \| [finetune][rt_a_fc] | [pruned][rt_a_p] \| [finetuned][rt_a_f] \| [log][rt_a_l] |
| RetinaNet-R50-FPN_pruned_flops | 36.6 | +0.1 | 126 | 50.4% | 34.9 | 54.7% | [prune][rt_f_pc] \| [finetune][rt_f_fc] | [pruned][rt_f_p] \| [finetuned][rt_f_f] \| [log][rt_f_l] |
| Model | Top-1 | Gap | Flop(G) | Remain(%) | Parameters(M) | Remain(%) | Config | Download | Onnx_cpu(FPS) |
| ----------------------------- | ----- | ----- | ------- | --------- | ------------- | --------- | ---------------------------------------- | ----------------------------------------------------------- | ------------- |
| ResNet50 | 76.55 | - | 4.11 | - | 25.6 | - | [mmcls][cls_r50_c] | [model][cls_r50_m] | 55.360 |
| ResNet50_pruned_act | 75.22 | -1.33 | 2.06 | 50.1% | 16.3 | 63.7% | [prune][r_a_pc] \| [finetune][r_a_fc] | [pruned][r_a_p] \| [finetuned][r_a_f] \| [log][r_a_l] | 80.671 |
| ResNet50_pruned_act + dist kd | 76.50 | -0.05 | 2.06 | 50.1% | 16.3 | 63.7% | [prune][r_a_pc] \| [finetune][r_a_fc_kd] | [pruned][r_a_p] \| [finetuned][r_a_f_kd] \| [log][r_a_l_kd] | 80.671 |
| ResNet50_pruned_flops | 75.61 | -0.94 | 2.06 | 50.1% | 16.3 | 63.7% | [prune][r_f_pc] \| [finetune][r_f_fc] | [pruned][r_f_p] \| [finetuned][r_f_f] \| [log][r_f_l] | 78.674 |
| MobileNetV2 | 71.86 | - | 0.313 | - | 3.51 | - | [mmcls][cls_m_c] | [model][cls_m_m] | 419.673 |
| MobileNetV2_pruned_act | 70.82 | -1.04 | 0.207 | 66.1% | 3.18 | 90.6% | [prune][m_a_pc] \| [finetune][m_a_fc] | [pruned][m_a_p] \| [finetuned][m_a_f] \| [log][m_a_l] | 576.118 |
| MobileNetV2_pruned_flops | 70.87 | -0.99 | 0.207 | 66.1% | 2.82 | 88.7% | [prune][m_f_pc] \| [finetune][m_f_fc] | [pruned][m_f_p] \| [finetuned][m_f_f] \| [log][m_f_l] | 540.105 |
**Note**
- Because the pruning papers use different pretraining and finetuning settings, It is hard to compare them fairly. As a result, we prefer to apply algorithms on the openmmlab settings.
- This may make the experiment results are different from that in the original papers.
### Detection on COCO
| Model(Detector-Backbone) | AP | Gap | Flop(G) | Remain(%) | Parameters(M) | Remain(%) | Config | Download | Onnx_cpu(FPS) |
| ------------------------------ | ---- | ---- | ------- | --------- | ------------- | --------- | --------------------------------------- | -------------------------------------------------------- | ------------- |
| RetinaNet-R50-FPN | 36.5 | - | 250 | - | 63.8 | - | [mmdet][det_rt_c] | [model][det_rt_m] | 1.095 |
| RetinaNet-R50-FPN_pruned_act | 36.5 | 0.0 | 126 | 50.4% | 34.6 | 54.2% | [prune][rt_a_pc] \| [finetune][rt_a_fc] | [pruned][rt_a_p] \| [finetuned][rt_a_f] \| [log][rt_a_l] | 1.608 |
| RetinaNet-R50-FPN_pruned_flops | 36.6 | +0.1 | 126 | 50.4% | 34.9 | 54.7% | [prune][rt_f_pc] \| [finetune][rt_f_fc] | [pruned][rt_f_p] \| [finetuned][rt_f_f] \| [log][rt_f_l] | 1.609 |
### Pose on COCO
| Model | AP | Gap | Flop(G) | Remain(%) | Parameters(M) | Remain(%) | Config | Download | Onnx_cpu(FPS) |
| -------------------- | ----- | ------ | ------- | --------- | ------------- | --------- | --------------------------------------- | ----------------------------------------------------------- | ------------- |
| rtmpose-s | 0.716 | - | 0.68 | - | 5.47 | - | [mmpose][pose_s_c] | [model][pose_s_m] | 196 |
| rtmpose-s_pruned_act | 0.691 | -0.025 | 0.34 | 50.0% | 3.42 | 62.5% | [prune][rp_a_pc] \| [finetune][rp_a_fc] | [pruned][rp_sc_p] \| [finetuned][rp_sc_f] \| [log][rp_sc_l] | 268 |
| rtmpose-t | 0.682 | - | 0.35 | - | 3.34 | - | [mmpose][pose_t_c] | [model][pose_t_m] | 279 |
| Model | AP | Gap | Flop(G) | Remain(%) | Parameters(M) | Remain(%) | Config | Download | Onnx_cpu(FPS) |
| ----------------------------- | ----- | ------ | ------- | --------- | ------------- | --------- | --------------------------------------- | ----------------------------------------------------------- | ------------- |
| rtmpose-s-aic-coco | 0.722 | - | 0.68 | - | 5.47 | - | [mmpose][pose_s_c] | [model][pose_s_m] | 196 |
| rtmpose-s-aic-coco_pruned_act | 0.694 | -0.028 | 0.35 | 51.5% | 3.43 | 62.7% | [prune][rp_a_pc] \| [finetune][rp_a_fc] | [pruned][rp_sa_p] \| [finetuned][rp_sa_f] \| [log][rp_sa_l] | 272 |
| rtmpose-t-aic-coco | 0.685 | - | 0.35 | - | 3.34 | - | [mmpose][pose_t_c] | [model][pose_t_m] | 279 |
- All FPS is test on the same machine with 11th Gen Intel(R) Core(TM) i7-11700 @ 2.50GHz.
## Get Started
We have three steps to apply GroupFisher to your model, including Prune, Finetune, Deploy.
@ -192,6 +209,18 @@ repo link
[m_f_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/20230201_211550.json
[m_f_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.pth
[m_f_pc]: ../../mmcls/group_fisher/mobilenet/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.py
[pose_s_c]: https://download.openmmlab.com/mmpose/v1/projects/rtmpose/rtmpose-s_simcc-coco_pt-aic-coco_420e-256x192-8edcf0d7_20230127.pth
[pose_s_m]: https://download.openmmlab.com/mmpose/v1/projects/rtmpose/rtmpose-s_simcc-coco_pt-aic-coco_420e-256x192-8edcf0d7_20230127.pth
[pose_t_c]: https://download.openmmlab.com/mmpose/v1/projects/rtmpose/rtmpose-tiny_simcc-coco_pt-aic-coco_420e-256x192-e613ba3f_20230127.pth
[pose_t_m]: https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth
[rp_a_fc]: ../../mmpose/group_fisher/group_fisher_finetune_rtmpose-s_8xb256-420e_coco-256x192.py
[rp_a_pc]: ../../mmpose/group_fisher/group_fisher_prune_rtmpose-s_8xb256-420e_coco-256x192.py
[rp_sa_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/rtmpose-s/group_fisher_finetune_rtmpose-s_8xb256-420e_aic-coco-256x192.pth
[rp_sa_l]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/rtmpose-s/group_fisher_finetune_rtmpose-s_8xb256-420e_aic-coco-256x192.json
[rp_sa_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/rtmpose-s/group_fisher_prune_rtmpose-s_8xb256-420e_aic-coco-256x192.pth
[rp_sc_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/rtmpose-s/group_fisher_finetune_rtmpose-s_8xb256-420e_coco-256x192.pth
[rp_sc_l]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/rtmpose-s/group_fisher_finetune_rtmpose-s_8xb256-420e_coco-256x192.json
[rp_sc_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/rtmpose-s/group_fisher_prune_rtmpose-s_8xb256-420e_coco-256x192.pth
[rt_a_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/act/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.pth
[rt_a_fc]: ../../mmdet/group_fisher/retinanet/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.py
[rt_a_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/retinanet/act/20230113_231904.json
@ -204,7 +233,10 @@ repo link
[rt_f_pc]: ../../mmdet/group_fisher/retinanet/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.py
[r_a_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/act/group_fisher_act_finetune_resnet50_8xb32_in1k.pth
[r_a_fc]: ../../mmcls/group_fisher/resnet50/group_fisher_act_finetune_resnet50_8xb32_in1k.py
[r_a_fc_kd]: ../../mmcls/group_fisher/resnet50/group_fisher_act_finetune_resnet50_8xb32_in1k_dist.py
[r_a_f_kd]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/group_fisher_act_finetune_resnet50_8xb32_in1k_dist.pth
[r_a_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/resnet50/act/20230130_175426.json
[r_a_l_kd]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/group_fisher_act_finetune_resnet50_8xb32_in1k_dist.json
[r_a_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/act/group_fisher_act_prune_resnet50_8xb32_in1k.pth
[r_a_pc]: ../../mmcls/group_fisher/resnet50/group_fisher_act_prune_resnet50_8xb32_in1k.py
[r_f_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/group_fisher_flops_finetune_resnet50_8xb32_in1k.pth

View File

@ -32,7 +32,7 @@ architecture = _base_.model
if hasattr(_base_, 'data_preprocessor'):
architecture.update({'data_preprocessor': _base_.data_preprocessor})
data_preprocessor = None
data_preprocessor = {}
architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path)
architecture['_scope_'] = _base_.default_scope

View File

@ -34,7 +34,7 @@ fix_subnet = {
'backbone.layer7.0.conv.0.conv_(0, 960)_960': 944,
'backbone.layer7.0.conv.2.conv_(0, 320)_320': 320
}
divisor = 8
divisor = 16
##############################################################################

View File

@ -32,7 +32,7 @@ architecture = _base_.model
if hasattr(_base_, 'data_preprocessor'):
architecture.update({'data_preprocessor': _base_.data_preprocessor})
data_preprocessor = None
data_preprocessor = {}
architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path)
architecture['_scope_'] = _base_.default_scope

View File

@ -34,7 +34,7 @@ fix_subnet = {
'backbone.layer7.0.conv.0.conv_(0, 960)_960': 771,
'backbone.layer7.0.conv.2.conv_(0, 320)_320': 320
}
divisor = 8
divisor = 16
##############################################################################
architecture = _base_.model

View File

@ -5,3 +5,43 @@ bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/mobilenet/group_fi
# flops mode
bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.py 8
bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.py 8
# deploy act mode
razor_config=configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_deploy_mobilenet-v2_8xb32_in1k.py
deploy_config=mmdeploy/configs/mmcls/classification_onnxruntime_dynamic.py
python mmdeploy/tools/deploy.py $deploy_config \
$razor_config \
https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/act/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.pth \
mmdeploy/tests/data/tiger.jpeg \
--work-dir ./work_dirs/mmdeploy
python mmdeploy/tools/profiler.py $deploy_config \
$razor_config \
mmdeploy/demo/resources \
--model ./work_dirs/mmdeploy/end2end.onnx \
--shape 224x224 \
--device cpu \
--num-iter 1000 \
--warmup 100
# deploy flop mode
razor_config=configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_deploy_mobilenet-v2_8xb32_in1k.py
deploy_config=mmdeploy/configs/mmcls/classification_onnxruntime_dynamic.py
python mmdeploy/tools/deploy.py $deploy_config \
$razor_config \
https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.pth \
mmdeploy/tests/data/tiger.jpeg \
--work-dir ./work_dirs/mmdeploy
python mmdeploy/tools/profiler.py $deploy_config \
$razor_config \
mmdeploy/demo/resources \
--model ./work_dirs/mmdeploy/end2end.onnx \
--shape 224x224 \
--device cpu \
--num-iter 1000 \
--warmup 100

View File

@ -0,0 +1,61 @@
#############################################################################
"""# You have to fill these args.
_base_(str): The path to your pruning config file.
pruned_path (str): The path to the checkpoint of the pruned model.
finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr
rate of the pretrain.
"""
_base_ = './group_fisher_act_prune_resnet50_8xb32_in1k.py'
pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/act/group_fisher_act_prune_resnet50_8xb32_in1k.pth' # noqa
finetune_lr = 0.1
##############################################################################
algorithm = _base_.model
algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)
teacher = algorithm.architecture
pruned = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherSubModel',
algorithm=algorithm,
)
model = dict(
_scope_='mmrazor',
_delete_=True,
type='SingleTeacherDistill',
data_preprocessor=None,
architecture=pruned,
teacher=teacher,
distiller=dict(
type='ConfigurableDistiller',
student_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc')),
teacher_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc')),
distill_losses=dict(
loss_kl=dict(type='DISTLoss', tau=1, loss_weight=1)),
loss_forward_mappings=dict(
loss_kl=dict(
preds_S=dict(from_student=True, recorder='fc'),
preds_T=dict(from_student=False, recorder='fc')))))
find_unused_parameters = True
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
# restore lr
optim_wrapper = dict(optimizer=dict(lr=finetune_lr))
# remove pruning related hooks
custom_hooks = _base_.custom_hooks[:-2]
# delete ddp
model_wrapper_cfg = None
_base_ = './resnet_group_fisher_prune.py'
# 76.3520

View File

@ -32,7 +32,7 @@ architecture = _base_.model
if hasattr(_base_, 'data_preprocessor'):
architecture.update({'data_preprocessor': _base_.data_preprocessor})
data_preprocessor = None
data_preprocessor = {}
architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path)
architecture['_scope_'] = _base_.default_scope

View File

@ -46,7 +46,7 @@ fix_subnet = {
'backbone.layer4.2.conv1_(0, 512)_512': 443,
'backbone.layer4.2.conv2_(0, 512)_512': 376
}
divisor = 8
divisor = 16
##############################################################################
architecture = _base_.model

View File

@ -5,3 +5,45 @@ bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/resnet50/group_fis
# flops mode
bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_prune_resnet50_8xb32_in1k.py.py 8
bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_finetune_resnet50_8xb32_in1k.py 8
# deploy act mode
razor_config=configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_deploy_resnet50_8xb32_in1k.py
deploy_config=mmdeploy/configs/mmcls/classification_onnxruntime_dynamic.py
python mmdeploy/tools/deploy.py $deploy_config \
$razor_config \
https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/act/group_fisher_act_finetune_resnet50_8xb32_in1k.pth \
mmdeploy/tests/data/tiger.jpeg \
--work-dir ./work_dirs/mmdeploy
python mmdeploy/tools/profiler.py $deploy_config \
$razor_config \
mmdeploy/demo/resources \
--model ./work_dirs/mmdeploy/end2end.onnx \
--shape 224x224 \
--device cpu \
--num-iter 1000 \
--warmup 100
# deploy flops mode
razor_config=configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_deploy_resnet50_8xb32_in1k.py
deploy_config=mmdeploy/configs/mmcls/classification_onnxruntime_dynamic.py
python mmdeploy/tools/deploy.py $deploy_config \
$razor_config \
https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/group_fisher_flops_finetune_resnet50_8xb32_in1k.pth \
mmdeploy/tests/data/tiger.jpeg \
--work-dir ./work_dirs/mmdeploy
python mmdeploy/tools/profiler.py $deploy_config \
$razor_config \
mmdeploy/demo/resources \
--model ./work_dirs/mmdeploy/end2end.onnx \
--shape 224x224 \
--device cpu \
--num-iter 1000 \
--warmup 100

View File

@ -57,7 +57,7 @@ fix_subnet = {
'bbox_head.reg_convs.2.conv_(0, 256)_256': 82,
'bbox_head.reg_convs.3.conv_(0, 256)_256': 117
}
divisor = 8
divisor = 16
##############################################################################

View File

@ -32,10 +32,11 @@ architecture = _base_.model
if hasattr(_base_, 'data_preprocessor'):
architecture.update({'data_preprocessor': _base_.data_preprocessor})
data_preprocessor = None
data_preprocessor = {}
architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path)
architecture['_scope_'] = _base_.default_scope
architecture.backbone.frozen_stages = -1
model = dict(
_delete_=True,

View File

@ -57,7 +57,7 @@ fix_subnet = {
'bbox_head.reg_convs.2.conv_(0, 256)_256': 76,
'bbox_head.reg_convs.3.conv_(0, 256)_256': 122,
}
divisor = 8
divisor = 16
##############################################################################

View File

@ -5,3 +5,45 @@ bash ./tools/dist_train.sh configs/pruning/mmdet/group_fisher/retinanet/group_fi
# flops mode
bash ./tools/dist_train.sh configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.py 8
bash ./tools/dist_train.sh configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.py 8
# deploy act mode
razor_config=configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_deploy_retinanet_r50_fpn_1x_coco.py
deploy_config=mmdeploy/configs/mmdet/detection/detection_onnxruntime_static.py
python mmdeploy/tools/deploy.py $deploy_config \
$razor_config \
https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/act/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.pth \
mmdeploy/tests/data/tiger.jpeg \
--work-dir ./work_dirs/mmdeploy
python mmdeploy/tools/profiler.py $deploy_config \
$razor_config \
mmdeploy/demo/resources \
--model ./work_dirs/mmdeploy/end2end.onnx \
--shape 800x1248 \
--device cpu \
--num-iter 1000 \
--warmup 100
# deploy flop mode
razor_config=configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_deploy_retinanet_r50_fpn_1x_coco.py
deploy_config=mmdeploy/configs/mmdet/detection/detection_onnxruntime_static.py
python mmdeploy/tools/deploy.py $deploy_config \
$razor_config \
https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/flops/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.pth \
mmdeploy/tests/data/tiger.jpeg \
--work-dir ./work_dirs/mmdeploy
python mmdeploy/tools/profiler.py $deploy_config \
$razor_config \
mmdeploy/demo/resources \
--model ./work_dirs/mmdeploy/end2end.onnx \
--shape 800x1248 \
--device cpu \
--num-iter 1000 \
--warmup 100

View File

@ -0,0 +1,53 @@
#############################################################################
"""You have to fill these args.
_base_(str): The path to your pretrain config file.
fix_subnet (Union[dict,str]): The dict store the pruning structure or the
json file including it.
divisor (int): The divisor the make the channel number divisible.
"""
_base_ = 'mmpose::body_2d_keypoint/rtmpose/coco/rtmpose-s_8xb256-420e_aic-coco-256x192.py' # noqa
fix_subnet = {
'backbone.stem.0.conv_(0, 16)_16': 8,
'backbone.stem.1.conv_(0, 16)_16': 9,
'backbone.stem.2.conv_(0, 32)_32': 9,
'backbone.stage1.0.conv_(0, 64)_64': 32,
'backbone.stage1.1.short_conv.conv_(0, 32)_32': 30,
'backbone.stage1.1.main_conv.conv_(0, 32)_32': 29,
'backbone.stage1.1.blocks.0.conv1.conv_(0, 32)_32': 24,
'backbone.stage1.1.final_conv.conv_(0, 64)_64': 27,
'backbone.stage2.0.conv_(0, 128)_128': 62,
'backbone.stage2.1.short_conv.conv_(0, 64)_64': 63,
'backbone.stage2.1.main_conv.conv_(0, 64)_64': 64,
'backbone.stage2.1.blocks.0.conv1.conv_(0, 64)_64': 56,
'backbone.stage2.1.blocks.1.conv1.conv_(0, 64)_64': 62,
'backbone.stage2.1.final_conv.conv_(0, 128)_128': 65,
'backbone.stage3.0.conv_(0, 256)_256': 167,
'backbone.stage3.1.short_conv.conv_(0, 128)_128': 127,
'backbone.stage3.1.main_conv.conv_(0, 128)_128': 128,
'backbone.stage3.1.blocks.0.conv1.conv_(0, 128)_128': 124,
'backbone.stage3.1.blocks.1.conv1.conv_(0, 128)_128': 123,
'backbone.stage3.1.final_conv.conv_(0, 256)_256': 172,
'backbone.stage4.0.conv_(0, 512)_512': 337,
'backbone.stage4.1.conv1.conv_(0, 256)_256': 256,
'backbone.stage4.1.conv2.conv_(0, 512)_512': 379,
'backbone.stage4.2.short_conv.conv_(0, 256)_256': 188,
'backbone.stage4.2.main_conv.conv_(0, 256)_256': 227,
'backbone.stage4.2.blocks.0.conv1.conv_(0, 256)_256': 238,
'backbone.stage4.2.blocks.0.conv2.pointwise_conv.conv_(0, 256)_256': 195,
'backbone.stage4.2.final_conv.conv_(0, 512)_512': 163
}
divisor = 8
##############################################################################
architecture = _base_.model
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherDeploySubModel',
architecture=architecture,
fix_subnet=fix_subnet,
divisor=divisor,
)

View File

@ -0,0 +1,53 @@
#############################################################################
"""You have to fill these args.
_base_(str): The path to your pretrain config file.
fix_subnet (Union[dict,str]): The dict store the pruning structure or the
json file including it.
divisor (int): The divisor the make the channel number divisible.
"""
_base_ = 'mmpose::body_2d_keypoint/rtmpose/coco/rtmpose-s_8xb256-420e_coco-256x192.py' # noqa
fix_subnet = {
'backbone.stem.0.conv_(0, 16)_16': 8,
'backbone.stem.1.conv_(0, 16)_16': 10,
'backbone.stem.2.conv_(0, 32)_32': 11,
'backbone.stage1.0.conv_(0, 64)_64': 32,
'backbone.stage1.1.short_conv.conv_(0, 32)_32': 32,
'backbone.stage1.1.main_conv.conv_(0, 32)_32': 23,
'backbone.stage1.1.blocks.0.conv1.conv_(0, 32)_32': 25,
'backbone.stage1.1.final_conv.conv_(0, 64)_64': 25,
'backbone.stage2.0.conv_(0, 128)_128': 71,
'backbone.stage2.1.short_conv.conv_(0, 64)_64': 61,
'backbone.stage2.1.main_conv.conv_(0, 64)_64': 62,
'backbone.stage2.1.blocks.0.conv1.conv_(0, 64)_64': 57,
'backbone.stage2.1.blocks.1.conv1.conv_(0, 64)_64': 59,
'backbone.stage2.1.final_conv.conv_(0, 128)_128': 69,
'backbone.stage3.0.conv_(0, 256)_256': 177,
'backbone.stage3.1.short_conv.conv_(0, 128)_128': 122,
'backbone.stage3.1.main_conv.conv_(0, 128)_128': 123,
'backbone.stage3.1.blocks.0.conv1.conv_(0, 128)_128': 125,
'backbone.stage3.1.blocks.1.conv1.conv_(0, 128)_128': 123,
'backbone.stage3.1.final_conv.conv_(0, 256)_256': 171,
'backbone.stage4.0.conv_(0, 512)_512': 351,
'backbone.stage4.1.conv1.conv_(0, 256)_256': 256,
'backbone.stage4.1.conv2.conv_(0, 512)_512': 367,
'backbone.stage4.2.short_conv.conv_(0, 256)_256': 183,
'backbone.stage4.2.main_conv.conv_(0, 256)_256': 216,
'backbone.stage4.2.blocks.0.conv1.conv_(0, 256)_256': 238,
'backbone.stage4.2.blocks.0.conv2.pointwise_conv.conv_(0, 256)_256': 195,
'backbone.stage4.2.final_conv.conv_(0, 512)_512': 187
}
divisor = 16
##############################################################################
architecture = _base_.model
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherDeploySubModel',
architecture=architecture,
fix_subnet=fix_subnet,
divisor=divisor,
)

View File

@ -0,0 +1,32 @@
#############################################################################
"""# You have to fill these args.
_base_(str): The path to your pruning config file.
pruned_path (str): The path to the checkpoint of the pruned model.
finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr
rate of the pretrain.
"""
_base_ = './group_fisher_prune_rtmpose-s_8xb256-420e_aic-coco-256x192.py' # noqa
pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/rtmpose-s/group_fisher_prune_rtmpose-s_8xb256-420e_aic-coco-256x192.pth' # noqa
finetune_lr = 4e-3
##############################################################################
algorithm = _base_.model
algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherSubModel',
algorithm=algorithm,
)
# restore lr
optim_wrapper = dict(optimizer=dict(lr=finetune_lr))
# remove pruning related hooks
custom_hooks = _base_.custom_hooks[:-2]
# delete ddp
model_wrapper_cfg = None

View File

@ -0,0 +1,33 @@
#############################################################################
"""# You have to fill these args.
_base_(str): The path to your pruning config file.
pruned_path (str): The path to the checkpoint of the pruned model.
finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr
rate of the pretrain.
"""
_base_ = './group_fisher_prune_rtmpose-s_8xb256-420e_coco-256x192.py'
pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/rtmpose-s/group_fisher_prune_rtmpose-s_8xb256-420e_coco-256x192.pth' # noqa
finetune_lr = 4e-3
##############################################################################
algorithm = _base_.model
algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)
# algorithm.update(dict(architecture=dict(test_cfg=dict(flip_test=False), ))) # disable flip test # noqa
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherSubModel',
algorithm=algorithm,
)
# restore lr
optim_wrapper = dict(optimizer=dict(lr=finetune_lr))
# remove pruning related hooks
custom_hooks = _base_.custom_hooks[:-2]
# delete ddp
model_wrapper_cfg = None

View File

@ -0,0 +1,75 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pretrained model checkpoint.
pretrained_path (str): The path to your pretrained model checkpoint.
interval (int): Interval between pruning two channels. You should ensure you
can reach your target pruning ratio when the training ends.
normalization_type (str): GroupFisher uses two methods to normlized the channel
importance, including ['flops','act']. The former uses flops, while the
latter uses the memory occupation of activation feature maps.
lr_ratio (float): Ratio to decrease lr rate. As pruning progress is unstable,
you need to decrease the original lr rate until the pruning training work
steadly without getting nan.
target_flop_ratio (float): The target flop ratio to prune your model.
input_shape (Tuple): input shape to measure the flops.
"""
_base_ = 'mmpose::body_2d_keypoint/rtmpose/coco/rtmpose-s_8xb256-420e_aic-coco-256x192.py' # noqa
pretrained_path = 'https://download.openmmlab.com/mmpose/v1/projects/rtmpose/rtmpose-s_simcc-aic-coco_pt-aic-coco_420e-256x192-fcb2599b_20230126.pth' # noqa
interval = 10
normalization_type = 'act'
lr_ratio = 0.1
target_flop_ratio = 0.51
input_shape = (1, 3, 256, 192)
##############################################################################
architecture = _base_.model
if hasattr(_base_, 'data_preprocessor'):
architecture.update({'data_preprocessor': _base_.data_preprocessor})
data_preprocessor = None
architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path)
architecture['_scope_'] = _base_.default_scope
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherAlgorithm',
architecture=architecture,
interval=interval,
mutator=dict(
type='GroupFisherChannelMutator',
parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'),
channel_unit_cfg=dict(
type='GroupFisherChannelUnit',
default_args=dict(normalization_type=normalization_type, ),
),
),
)
model_wrapper_cfg = dict(
type='mmrazor.GroupFisherDDP',
broadcast_buffers=False,
)
optim_wrapper = dict(
optimizer=dict(lr=_base_.optim_wrapper.optimizer.lr * lr_ratio))
custom_hooks = getattr(_base_, 'custom_hooks', []) + [
dict(type='mmrazor.PruningStructureHook'),
dict(
type='mmrazor.ResourceInfoHook',
interval=interval,
demo_input=dict(
type='mmrazor.DefaultDemoInput',
input_shape=input_shape,
),
save_ckpt_thr=[target_flop_ratio],
),
]

View File

@ -0,0 +1,75 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pretrained model checkpoint.
pretrained_path (str): The path to your pretrained model checkpoint.
interval (int): Interval between pruning two channels. You should ensure you
can reach your target pruning ratio when the training ends.
normalization_type (str): GroupFisher uses two methods to normlized the channel
importance, including ['flops','act']. The former uses flops, while the
latter uses the memory occupation of activation feature maps.
lr_ratio (float): Ratio to decrease lr rate. As pruning progress is unstable,
you need to decrease the original lr rate until the pruning training work
steadly without getting nan.
target_flop_ratio (float): The target flop ratio to prune your model.
input_shape (Tuple): input shape to measure the flops.
"""
_base_ = 'mmpose::body_2d_keypoint/rtmpose/coco/rtmpose-s_8xb256-420e_coco-256x192.py' # noqa
pretrained_path = 'https://download.openmmlab.com/mmpose/v1/projects/rtmpose/rtmpose-s_simcc-coco_pt-aic-coco_420e-256x192-8edcf0d7_20230127.pth' # noqa
interval = 10
normalization_type = 'act'
lr_ratio = 0.1
target_flop_ratio = 0.51
input_shape = (1, 3, 256, 192)
##############################################################################
architecture = _base_.model
if hasattr(_base_, 'data_preprocessor'):
architecture.update({'data_preprocessor': _base_.data_preprocessor})
data_preprocessor = None
architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path)
architecture['_scope_'] = _base_.default_scope
model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherAlgorithm',
architecture=architecture,
interval=interval,
mutator=dict(
type='GroupFisherChannelMutator',
parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'),
channel_unit_cfg=dict(
type='GroupFisherChannelUnit',
default_args=dict(normalization_type=normalization_type, ),
),
),
)
model_wrapper_cfg = dict(
type='mmrazor.GroupFisherDDP',
broadcast_buffers=False,
)
optim_wrapper = dict(
optimizer=dict(lr=_base_.optim_wrapper.optimizer.lr * lr_ratio))
custom_hooks = getattr(_base_, 'custom_hooks', []) + [
dict(type='mmrazor.PruningStructureHook'),
dict(
type='mmrazor.ResourceInfoHook',
interval=interval,
demo_input=dict(
type='mmrazor.DefaultDemoInput',
input_shape=input_shape,
),
save_ckpt_thr=[target_flop_ratio],
),
]

View File

@ -0,0 +1,39 @@
# deploy rtmpose-s_pruned_act
razor_config=configs/pruning/mmpose/group_fisher/group_fisher_deploy_rtmpose-s_8xb256-420e_coco-256x192.py
deploy_config=mmdeploy/configs/mmpose/pose-detection_simcc_onnxruntime_dynamic.py
python mmdeploy/tools/deploy.py $deploy_config \
$razor_config \
https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/rtmpose-s/group_fisher_finetune_rtmpose-s_8xb256-420e_coco-256x192.pth \
mmdeploy/tests/data/tiger.jpeg \
--work-dir ./work_dirs/mmdeploy
python mmdeploy/tools/profiler.py $deploy_config \
$razor_config \
mmdeploy/demo/resources \
--model ./work_dirs/mmdeploy/end2end.onnx \
--shape 256x192 \
--device cpu \
--num-iter 1000 \
--warmup 100
# deploy rtmpose-s-aic-coco_pruned_act
razor_config=configs/pruning/mmpose/group_fisher/group_fisher_deploy_rtmpose-s_8xb256-420e_aic-coco-256x192.py
deploy_config=mmdeploy/configs/mmpose/pose-detection_simcc_onnxruntime_dynamic.py
python mmdeploy/tools/deploy.py $deploy_config \
$razor_config \
https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/rtmpose-s/group_fisher_finetune_rtmpose-s_8xb256-420e_aic-coco-256x192.pth \
mmdeploy/tests/data/tiger.jpeg \
--work-dir ./work_dirs/mmdeploy
python mmdeploy/tools/profiler.py $deploy_config \
$razor_config \
mmdeploy/demo/resources \
--model ./work_dirs/mmdeploy/end2end.onnx \
--shape 256x192 \
--device cpu \
--num-iter 1000 \
--warmup 100

View File

@ -79,6 +79,15 @@ class PruningStructureHook(Hook):
self.show(runner)
def input_generator_wrapper(model, demp_input: DefaultDemoInput):
def input_generator(input_shape):
res = demp_input.get_data(model)
return res
return input_generator
@HOOKS.register_module()
class ResourceInfoHook(Hook):
"""This hook is used to display the resource related information and save
@ -167,7 +176,13 @@ class ResourceInfoHook(Hook):
with torch.no_grad():
training = model.training
model.eval()
res = self.estimator.estimate(model)
res = self.estimator.estimate(
model,
flops_params_cfg=dict(
input_constructor=input_generator_wrapper(
model,
self.demo_input,
)))
if training:
model.train()
return res

View File

@ -1,16 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import types
from typing import Union
import torch.nn as nn
from mmengine import fileio
from mmrazor.models.utils.expandable_utils import make_channel_divisible
from mmrazor.registry import MODELS
from mmrazor.structures.subnet.fix_subnet import (export_fix_subnet,
load_fix_subnet)
from mmrazor.utils import print_log
def post_process_for_mmdeploy_wrapper(divisor):
def post_process_for_mmdeploy(model: nn.Module):
s = make_channel_divisible(model, divisor=divisor)
print_log(f'structure after make divisible: {json.dumps(s,indent=4)}')
return post_process_for_mmdeploy
@MODELS.register_module()
def GroupFisherDeploySubModel(architecture,
fix_subnet: Union[dict, str] = {},
@ -60,6 +71,8 @@ def GroupFisherDeploySubModel(architecture,
# cooperate with mmdeploy to make the channel divisible after load
# the checkpoint.
if divisor != 1:
setattr(architecture, '_razor_divisor', divisor)
setattr(
architecture, 'post_process_for_mmdeploy',
types.MethodType(
post_process_for_mmdeploy_wrapper(divisor), architecture))
return architecture

View File

@ -89,15 +89,18 @@ class DefaultDemoInput(BaseDemoInput):
scope (str, optional): mm scope name. Defaults to None.
"""
def __init__(self,
input_shape=None,
training=False,
scope: str = None) -> None:
def __init__(
self,
input_shape=None,
training=False,
scope: str = None,
kwargs={},
) -> None:
default_demo_input_class = get_default_demo_input_class(None, scope)
if input_shape is None:
input_shape = default_demo_input_class.default_shape
super().__init__(input_shape, training)
super().__init__(input_shape, training, kwargs=kwargs)
self.scope = scope
def _get_data(self, model, input_shape, training):

View File

@ -11,13 +11,18 @@ class BaseDemoInput():
Args:
input_shape: Default input shape. Defaults to default_shape.
training (bool, optional): Default training mode. Defaults to None.
kwargs (dict): Other keyword args to update the generated inputs.
"""
default_shape = (1, 3, 224, 224)
def __init__(self, input_shape=default_shape, training=None) -> None:
def __init__(self,
input_shape=default_shape,
training=None,
kwargs={}) -> None:
self.input_shape = input_shape
self.training = training
self.kwargs = kwargs
def get_data(self, model, input_shape=None, training=None):
"""Api to generate demo input."""
@ -26,7 +31,10 @@ class BaseDemoInput():
if training is None:
training = self.training
return self._get_data(model, input_shape, training)
data = self._get_data(model, input_shape, training)
if isinstance(data, dict):
data.update(self.kwargs)
return data
def _get_data(self, model, input_shape, training):
"""Helper for get_data, including core logic to generate demo input."""

View File

@ -22,7 +22,6 @@ except ImportError:
def demo_mmpose_inputs(model, for_training=False, batch_size=1):
input_shape = (
1,
3,
@ -30,82 +29,80 @@ def demo_mmpose_inputs(model, for_training=False, batch_size=1):
imgs = torch.randn(*input_shape)
batch_data_samples = []
from mmpose.models.heads import RTMHead
if isinstance(model.head, HeatmapHead):
batch_data_samples = [
inputs['data_sample'] for inputs in get_packed_inputs(
batch_size,
num_keypoints=model.head.out_channels,
heatmap_size=model.head.decoder.heatmap_size[::-1])
]
batch_data_samples = get_packed_inputs(
batch_size,
num_keypoints=model.head.out_channels,
heatmap_size=model.head.decoder.heatmap_size[::-1])['data_samples']
elif isinstance(model.head, MSPNHead):
batch_data_samples = [
inputs['data_sample'] for inputs in get_packed_inputs(
batch_size=batch_size,
num_instances=1,
num_keypoints=model.head.out_channels,
heatmap_size=model.head.decoder.heatmap_size,
with_heatmap=True,
with_reg_label=False,
num_levels=model.head.num_stages * model.head.num_units)
]
batch_data_samples = get_packed_inputs(
batch_size=batch_size,
num_instances=1,
num_keypoints=model.head.out_channels,
heatmap_size=model.head.decoder.heatmap_size,
with_heatmap=True,
with_reg_label=False,
num_levels=model.head.num_stages *
model.head.num_units)['data_samples']
elif isinstance(model.head, CPMHead):
batch_data_samples = [
inputs['data_sample'] for inputs in get_packed_inputs(
batch_size=batch_size,
num_instances=1,
num_keypoints=model.head.out_channels,
heatmap_size=model.head.decoder.heatmap_size[::-1],
with_heatmap=True,
with_reg_label=False)
]
batch_data_samples = get_packed_inputs(
batch_size=batch_size,
num_instances=1,
num_keypoints=model.head.out_channels,
heatmap_size=model.head.decoder.heatmap_size[::-1],
with_heatmap=True,
with_reg_label=False)['data_samples']
elif isinstance(model.head, SimCCHead):
# bug
batch_data_samples = [
inputs['data_sample'] for inputs in get_packed_inputs(
batch_size,
num_keypoints=model.head.out_channels,
simcc_split_ratio=model.head.decoder.simcc_split_ratio,
input_size=model.head.decoder.input_size,
with_simcc_label=True)
]
batch_data_samples = get_packed_inputs(
batch_size,
num_keypoints=model.head.out_channels,
simcc_split_ratio=model.head.decoder.simcc_split_ratio,
input_size=model.head.decoder.input_size,
with_simcc_label=True)['data_samples']
elif isinstance(model.head, ViPNASHead):
batch_data_samples = [
inputs['data_sample'] for inputs in get_packed_inputs(
batch_size,
num_keypoints=model.head.out_channels,
)
]
batch_data_samples = get_packed_inputs(
batch_size,
num_keypoints=model.head.out_channels,
)['data_samples']
elif isinstance(model.head, DSNTHead):
batch_data_samples = [
inputs['data_sample'] for inputs in get_packed_inputs(
batch_size,
num_keypoints=model.head.num_joints,
with_reg_label=True)
]
batch_data_samples = get_packed_inputs(
batch_size,
num_keypoints=model.head.num_joints,
with_reg_label=True)['data_samples']
elif isinstance(model.head, IntegralRegressionHead):
batch_data_samples = [
inputs['data_sample'] for inputs in get_packed_inputs(
batch_size,
num_keypoints=model.head.num_joints,
with_reg_label=True)
]
batch_data_samples = get_packed_inputs(
batch_size,
num_keypoints=model.head.num_joints,
with_reg_label=True)['data_samples']
elif isinstance(model.head, RegressionHead):
batch_data_samples = [
inputs['data_sample'] for inputs in get_packed_inputs(
batch_size,
num_keypoints=model.head.num_joints,
with_reg_label=True)
]
batch_data_samples = get_packed_inputs(
batch_size,
num_keypoints=model.head.num_joints,
with_reg_label=True)['data_samples']
elif isinstance(model.head, RLEHead):
batch_data_samples = [
inputs['data_sample'] for inputs in get_packed_inputs(
batch_size,
num_keypoints=model.head.num_joints,
with_reg_label=True)
]
batch_data_samples = get_packed_inputs(
batch_size,
num_keypoints=model.head.num_joints,
with_reg_label=True)['data_samples']
elif isinstance(model.head, RTMHead):
batch_data_samples = get_packed_inputs(
batch_size,
num_keypoints=model.head.out_channels,
simcc_split_ratio=model.head.decoder.simcc_split_ratio,
input_size=model.head.decoder.input_size,
with_simcc_label=True)['data_samples']
else:
raise AssertionError('Head Type is Not Predefined')
raise AssertionError(f'Head Type {type(model.head)} is Not Predefined')
mm_inputs = {
'inputs': torch.FloatTensor(imgs),

View File

@ -9,6 +9,8 @@ import torch.nn as nn
from mmrazor.registry import TASK_UTILS
no_positional_input_warned = False
def get_model_flops_params(model,
input_shape=(1, 3, 224, 224),
@ -474,9 +476,13 @@ def batch_counter_hook(module, input, output):
input = input[0]
batch_size = len(input)
else:
pass
print('Warning! No positional inputs found for a module, '
'assuming batch size is 1.')
global no_positional_input_warned
if no_positional_input_warned:
pass
else:
print('Warning! No positional inputs found for a module, '
'assuming batch size is 1.')
no_positional_input_warned = True
module.__batch_counter__ += batch_size

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch.nn as nn
from mmrazor.models.architectures.dynamic_ops import DynamicConv2d
from mmrazor.registry import TASK_UTILS
from .base_counter import BaseCounter
@ -66,7 +66,7 @@ class Conv3dCounter(ConvCounter):
class DynamicConv2dCounter(ConvCounter):
@staticmethod
def add_count_hook(module: nn.Conv2d, input, output):
def add_count_hook(module: DynamicConv2d, input, output):
"""Calculate FLOPs and params based on the dynamic channels of conv
layers."""
input = input[0]
@ -76,12 +76,21 @@ class DynamicConv2dCounter(ConvCounter):
kernel_dims = list(module.kernel_size)
out_channels = module.mutable_attrs['out_channels'].activated_channels
mutable_channel = list(
module.mutable_attrs['out_channels'].mutable_channels.values())
if hasattr(mutable_channel[0], 'activated_tensor_channels'):
out_channels = mutable_channel[0].activated_tensor_channels
in_channels = module.mutable_attrs['in_channels'].activated_channels
if 'out_channels' in module.mutable_attrs:
out_channels = module.mutable_attrs[
'out_channels'].activated_channels
mutable_channel = list(
module.mutable_attrs['out_channels'].mutable_channels.values())
if len(mutable_channel) > 0 and hasattr(
mutable_channel[0], 'activated_tensor_channels'):
out_channels = mutable_channel[0].activated_tensor_channels
else:
out_channels = module.out_channels
if 'in_channels' in module.mutable_attrs:
in_channels = module.mutable_attrs[
'in_channels'].activated_channels
else:
in_channels = module.in_channels
groups = module.groups

View File

@ -12,7 +12,14 @@ from .unit import ExpandableUnit
def to_expandable_model(model: nn.Module) -> ChannelMutator[ExpandableUnit]:
"""Convert a static model to an expandable model."""
state_dict = model.state_dict()
mutator = ChannelMutator[ExpandableUnit](channel_unit_cfg=ExpandableUnit)
mutator = ChannelMutator[ExpandableUnit](
channel_unit_cfg=ExpandableUnit,
parse_cfg=dict(
_scope_='mmrazor',
type='ChannelAnalyzer',
demo_input=(1, 3, 224, 224),
tracer_type='FxTracer'),
)
mutator.prepare_from_supernet(model)
model.load_state_dict(state_dict)
return mutator

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.model.utils import _BatchNormXd
from mmrazor.models.mutables import (L1MutableChannelUnit,
MutableChannelContainer)
@ -15,6 +16,7 @@ class ExpandableUnit(L1MutableChannelUnit):
model, {
nn.Conv2d: ExpandableConv2d,
nn.BatchNorm2d: ExpandableBatchNorm2d,
_BatchNormXd: ExpandableBatchNorm2d,
nn.Linear: ExpandLinear,
})
self._register_channel_container(model, MutableChannelContainer)

View File

@ -3,8 +3,10 @@ import copy
import os
from unittest import TestCase
import torch
from mmengine import fileio
from mmrazor import digit_version
from mmrazor.implementations.pruning.group_fisher.prune_deploy_sub_model import \
GroupFisherDeploySubModel # noqa
from ....data.models import MMClsResNet18
@ -13,7 +15,12 @@ from .test_prune_sub_model import PruneAlgorithm, get_model_structure
class TestPruneDeploySubModel(TestCase):
def check_torch_version(self):
if digit_version(torch.__version__) < digit_version('1.12.0'):
self.skipTest('version of torch < 1.12.0')
def test_build_sub_model(self):
self.check_torch_version()
model = MMClsResNet18()
parse_cfg = dict(

View File

@ -5,6 +5,7 @@ from unittest import TestCase
import torch
from mmrazor import digit_version
from mmrazor.implementations.pruning.group_fisher.prune_sub_model import \
GroupFisherSubModel
from mmrazor.models import BaseAlgorithm
@ -43,7 +44,12 @@ def get_model_structure(model):
class TestPruneSubModel(TestCase):
def check_torch_version(self):
if digit_version(torch.__version__) < digit_version('1.12.0'):
self.skipTest('version of torch < 1.12.0')
def test_build_sub_model(self):
self.check_torch_version()
x = torch.rand([1, 3, 224, 224])
model = MMClsResNet18()
algorithm = PruneAlgorithm(model)

View File

@ -3,6 +3,7 @@ import unittest
import torch
from mmrazor import digit_version
from mmrazor.models.mutables import SimpleMutableChannel
from mmrazor.models.utils.expandable_utils import (
expand_expandable_dynamic_model, make_channel_divisible,
@ -13,7 +14,12 @@ from ....data.models import DwConvModel, MultiConcatModel, SingleLineModel
class TestExpand(unittest.TestCase):
def check_torch_version(self):
if digit_version(torch.__version__) < digit_version('1.12.0'):
self.skipTest('version of torch < 1.12.0')
def test_expand(self):
self.check_torch_version()
for Model in [MultiConcatModel, DwConvModel]:
x = torch.rand([1, 3, 224, 224])
model = Model()
@ -32,6 +38,7 @@ class TestExpand(unittest.TestCase):
self.assertTrue((y1 - y2).abs().max() < 1e-3)
def test_expand_static_model(self):
self.check_torch_version()
x = torch.rand([1, 3, 224, 224])
model = SingleLineModel()
y1 = model(x)
@ -42,6 +49,7 @@ class TestExpand(unittest.TestCase):
self.assertTrue((y1 - y2).abs().max() < 1e-3)
def test_ExpandConv2d(self):
self.check_torch_version()
linear = ExpandLinear(3, 3)
mutable_in = SimpleMutableChannel(3)
mutable_out = SimpleMutableChannel(3)