Add get_prune_config and a demo config_pruning (#389)

* update tools and test

* add demo

* disable test doc

* add switch for test tools and test_doc

* fix bug

* update doc

* update tools name

* mv get_channel_units

Co-authored-by: liukai <your_email@abc.example>
LKJacky 2022-12-13 10:56:29 +08:00 committed by GitHub
parent c8e14e5489
commit f886821ba1
19 changed files with 991 additions and 20 deletions


@@ -47,10 +47,10 @@ The mainstream approach for filter pruning is usually either to force a hard-cod
### Generate channel_config file
Generate `resnet_cls.json` with `tools/get_channel_units.py`.
Generate `resnet_cls.json` with `tools/pruning/get_channel_units.py`.
```bash
python tools/get_channel_units.py
python tools/pruning/get_channel_units.py \
configs/pruning/mmcls/dcff/dcff_resnet50_8xb32_in1k.py \
-c -i --output-path=configs/pruning/mmcls/dcff/resnet_cls.json
```
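After the command finishes, you can sanity-check the generated file by loading it back; a minimal sketch, assuming the output is plain JSON:

```python
# Hypothetical sanity check of the generated channel-unit file.
import json

with open('configs/pruning/mmcls/dcff/resnet_cls.json') as f:
    channel_config = json.load(f)
print(list(channel_config)[:5])  # peek at the top-level keys
```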


@@ -10,7 +10,7 @@ stage_ratio_3 = 0.9
stage_ratio_4 = 0.7
# the config template of target_pruning_ratio can be got by
# python ./tools/get_channel_units.py {config_file} --choice
# python ./tools/pruning/get_channel_units.py {config_file} --choice
target_pruning_ratio = {
    'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1,
    'backbone.layer1.0.conv2_(0, 64)_64': stage_ratio_2,


@@ -6,7 +6,7 @@ stage_ratio_3 = 0.7
stage_ratio_4 = 1.0
# the config template of target_pruning_ratio can be got by
# python ./tools/get_channel_units.py {config_file} --choice
# python ./tools/pruning/get_channel_units.py {config_file} --choice
target_pruning_ratio = {
    'backbone.conv1_(0, 64)_64': stage_ratio_1,
    'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1,


@@ -47,10 +47,10 @@ The mainstream approach for filter pruning is usually either to force a hard-cod
### Generate channel_config file
Generate `resnet_det.json` with `tools/get_channel_units.py`.
Generate `resnet_det.json` with `tools/pruning/get_channel_units.py`.
```bash
python tools/get_channel_units.py
python tools/pruning/get_channel_units.py \
configs/pruning/mmdet/dcff/dcff_faster_rcnn_resnet50_8xb4_coco.py \
-c -i --output-path=configs/pruning/mmdet/dcff/resnet_det.json
```


@@ -11,7 +11,7 @@ stage_ratio_3 = 0.9
stage_ratio_4 = 0.7
# the config template of target_pruning_ratio can be got by
# python ./tools/get_channel_units.py {config_file} --choice
# python ./tools/pruning/get_channel_units.py {config_file} --choice
target_pruning_ratio = {
    'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1,
    'backbone.layer1.0.conv2_(0, 64)_64': stage_ratio_2,


@@ -47,10 +47,10 @@ The mainstream approach for filter pruning is usually either to force a hard-cod
### Generate channel_config file
Generate `resnet_pose.json` with `tools/get_channel_units.py`.
Generate `resnet_pose.json` with `tools/pruning/get_channel_units.py`.
```bash
python tools/get_channel_units.py
python tools/pruning/get_channel_units.py \
configs/pruning/mmpose/dcff/dcff_topdown_heatmap_resnet50.py \
-c -i --output-path=configs/pruning/mmpose/dcff/resnet_pose.json
```


@@ -61,7 +61,7 @@ stage_ratio_3 = 0.9
stage_ratio_4 = 0.85
# the config template of target_pruning_ratio can be got by
# python ./tools/get_channel_units.py {config_file} --choice
# python ./tools/pruning/get_channel_units.py {config_file} --choice
target_pruning_ratio = {
    'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1,
    'backbone.layer1.0.conv2_(0, 64)_64': stage_ratio_2,


@@ -47,10 +47,10 @@ The mainstream approach for filter pruning is usually either to force a hard-cod
### Generate channel_config file
Generate `resnet_seg.json` with `tools/get_channel_units.py`.
Generate `resnet_seg.json` with `tools/pruning/get_channel_units.py`.
```bash
python tools/get_channel_units.py
python tools/pruning/get_channel_units.py \
configs/pruning/mmseg/dcff/dcff_pointrend_resnet50_8xb2_cityscapes.py \
-c -i --output-path=configs/pruning/mmseg/dcff/resnet_seg.json
```


@@ -32,7 +32,7 @@ stage_ratio_3 = 0.9
stage_ratio_4 = 0.7
# the config template of target_pruning_ratio can be got by
# python ./tools/get_channel_units.py {config_file} --choice
# python ./tools/pruning/get_channel_units.py {config_file} --choice
target_pruning_ratio = {
    'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1,
    'backbone.layer1.0.conv2_(0, 64)_64': stage_ratio_2,


@@ -0,0 +1,689 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# please set cwd to the root of mmrazor repo."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 使用MMRazor对ResNet34进行剪枝"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"本教程主要介绍如何手动配置剪枝config。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 回顾MMCls"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### 跨库调用resnet34配置文件\n",
"\n",
"首先我们先跨库调用resnet34的配置文件。通过跨库调用我们可以继承原有配置文件的所有内容。"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# prepare work_dir\n",
"work_dir = './demo/tmp/'\n",
"if not os.path.exists(work_dir):\n",
" os.mkdir(work_dir)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from mmengine import Config\n",
"\n",
"\n",
"def write_config(config_str, filename):\n",
" with open(filename, 'w') as f:\n",
" f.write(config_str)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'type': 'ImageClassifier', 'backbone': {'type': 'ResNet', 'depth': 34, 'num_stages': 4, 'out_indices': (3,), 'style': 'pytorch'}, 'neck': {'type': 'GlobalAveragePooling'}, 'head': {'type': 'LinearClsHead', 'num_classes': 1000, 'in_channels': 512, 'loss': {'type': 'CrossEntropyLoss', 'loss_weight': 1.0}, 'topk': (1, 5)}, '_scope_': 'mmcls'}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/liukai/miniconda3/envs/lab2max/lib/python3.9/site-packages/mmengine/config/utils.py:51: UserWarning: There is not `Config` define in {'Name': 'convnext-base_3rdparty_in21k', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 15359124480, 'Parameters': 88591464}, 'In Collection': 'ConvNeXt', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_in21k_20220124-13b83eec.pth', 'Converted From': {'Weights': 'https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth', 'Code': 'https://github.com/facebookresearch/ConvNeXt'}}\n",
" warnings.warn(f'There is not `Config` define in {model_cfg}')\n",
"/home/liukai/miniconda3/envs/lab2max/lib/python3.9/site-packages/mmengine/config/utils.py:51: UserWarning: There is not `Config` define in {'Name': 'convnext-large_3rdparty_in21k', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 34368026112, 'Parameters': 197767336}, 'In Collection': 'ConvNeXt', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_in21k_20220124-41b5a79f.pth', 'Converted From': {'Weights': 'https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth', 'Code': 'https://github.com/facebookresearch/ConvNeXt'}}\n",
" warnings.warn(f'There is not `Config` define in {model_cfg}')\n",
"/home/liukai/miniconda3/envs/lab2max/lib/python3.9/site-packages/mmengine/config/utils.py:51: UserWarning: There is not `Config` define in {'Name': 'convnext-xlarge_3rdparty_in21k', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 60929820672, 'Parameters': 350196968}, 'In Collection': 'ConvNeXt', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_3rdparty_in21k_20220124-f909bad7.pth', 'Converted From': {'Weights': 'https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth', 'Code': 'https://github.com/facebookresearch/ConvNeXt'}}\n",
" warnings.warn(f'There is not `Config` define in {model_cfg}')\n",
"/home/liukai/miniconda3/envs/lab2max/lib/python3.9/site-packages/mmengine/config/utils.py:51: UserWarning: There is not `Config` define in {'Name': 'swinv2-base-w12_3rdparty_in21k-192px', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 8510000000, 'Parameters': 87920000}, 'In Collection': 'Swin-Transformer V2', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/swin-v2/pretrain/swinv2-base-w12_3rdparty_in21k-192px_20220803-f7dc9763.pth', 'Converted From': {'Weights': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth', 'Code': 'https://github.com/microsoft/Swin-Transformer'}}\n",
" warnings.warn(f'There is not `Config` define in {model_cfg}')\n",
"/home/liukai/miniconda3/envs/lab2max/lib/python3.9/site-packages/mmengine/config/utils.py:51: UserWarning: There is not `Config` define in {'Name': 'swinv2-large-w12_3rdparty_in21k-192px', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 19040000000, 'Parameters': 196740000}, 'In Collection': 'Swin-Transformer V2', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/swin-v2/pretrain/swinv2-large-w12_3rdparty_in21k-192px_20220803-d9073fee.pth', 'Converted From': {'Weights': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth', 'Code': 'https://github.com/microsoft/Swin-Transformer'}}\n",
" warnings.warn(f'There is not `Config` define in {model_cfg}')\n"
]
}
],
"source": [
"# Prepare pretrain config\n",
"pretrain_config_path = f'{work_dir}/pretrain.py'\n",
"config_string = \"\"\"\n",
"_base_ = ['mmcls::resnet/resnet34_8xb32_in1k.py']\n",
"\"\"\"\n",
"write_config(config_string, pretrain_config_path)\n",
"print(Config.fromfile(pretrain_config_path)['model'])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Run config\n",
"! timeout 2 python ./tools/train.py $prune_config_path"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 准备剪枝config"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"1. 增加pretrained参数\n",
"2. 将resnet34模型装入剪枝算法wrapper中\n",
"3. 配置剪枝比例\n",
"4. 运行"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. 增加预训练参数\n",
"我们将原有的model字段取出命名为architecture并且给archtecture增加init_cfg字段用来加载预训练模型参数。"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"checkpoint_path = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'\n",
"prune_config_path = work_dir + 'prune.py'\n",
"config_string += \"\"\"\\n\n",
"data_preprocessor = {'type': 'mmcls.ClsDataPreprocessor'}\n",
"architecture = _base_.model\n",
"architecture.update({\n",
" 'init_cfg': {\n",
" 'type':\n",
" 'Pretrained',\n",
" 'checkpoint':\n",
" 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth' # noqa\n",
" }\n",
"})\n",
"\"\"\"\n",
"write_config(config_string, prune_config_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. 将resnet34模型装入剪枝算法wrapper中\n",
"\n",
"我们将原有的model作为architecture放入到ItePruneAlgorithm算法中并且将ItePruneAlgorithm作为新的model字段。"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"config_string += \"\"\"\n",
"target_pruning_ratio={}\n",
"model = dict(\n",
" _delete_=True,\n",
" _scope_='mmrazor',\n",
" type='ItePruneAlgorithm',\n",
" architecture=architecture,\n",
" mutator_cfg=dict(\n",
" type='ChannelMutator',\n",
" channel_unit_cfg=dict(\n",
" type='L1MutableChannelUnit',\n",
" default_args=dict(choice_mode='ratio'))),\n",
" target_pruning_ratio=target_pruning_ratio,\n",
" step_freq=1,\n",
" prune_times=1,\n",
")\n",
"\"\"\"\n",
"write_config(config_string, prune_config_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"配置到这一步时我们的config文件已经能够运行了。但是因为我们没有配置target_pruning_ratio因此现在跑起来就和直接用原有config跑起来没有区别接下来我们会介绍如何配置剪枝比例"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"! timeout 2 python ./tools/train.py $prune_config_path"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. 配置剪枝比例"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"我们的模型使用tracer解析模型进而获得剪枝节点为了方便用户配置剪枝节点比例我们提供了一个获得剪枝节点剪枝比例配置的工具。通过该工具我们可以方便地对剪枝比例进行配置。"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"backbone.conv1_(0, 64)_64\":1.0,\n",
" \"backbone.layer1.0.conv1_(0, 64)_64\":1.0,\n",
" \"backbone.layer1.1.conv1_(0, 64)_64\":1.0,\n",
" \"backbone.layer1.2.conv1_(0, 64)_64\":1.0,\n",
" \"backbone.layer2.0.conv1_(0, 128)_128\":1.0,\n",
" \"backbone.layer2.0.conv2_(0, 128)_128\":1.0,\n",
" \"backbone.layer2.1.conv1_(0, 128)_128\":1.0,\n",
" \"backbone.layer2.2.conv1_(0, 128)_128\":1.0,\n",
" \"backbone.layer2.3.conv1_(0, 128)_128\":1.0,\n",
" \"backbone.layer3.0.conv1_(0, 256)_256\":1.0,\n",
" \"backbone.layer3.0.conv2_(0, 256)_256\":1.0,\n",
" \"backbone.layer3.1.conv1_(0, 256)_256\":1.0,\n",
" \"backbone.layer3.2.conv1_(0, 256)_256\":1.0,\n",
" \"backbone.layer3.3.conv1_(0, 256)_256\":1.0,\n",
" \"backbone.layer3.4.conv1_(0, 256)_256\":1.0,\n",
" \"backbone.layer3.5.conv1_(0, 256)_256\":1.0,\n",
" \"backbone.layer4.0.conv1_(0, 512)_512\":1.0,\n",
" \"backbone.layer4.0.conv2_(0, 512)_512\":1.0,\n",
" \"backbone.layer4.1.conv1_(0, 512)_512\":1.0,\n",
" \"backbone.layer4.2.conv1_(0, 512)_512\":1.0\n",
"}"
]
}
],
"source": [
"ratio_template_path=work_dir+'prune_ratio_template.json'\n",
"! python ./tools/pruning/get_channel_units.py $pretrain_config_path --choice -o $ratio_template_path &> /dev/null 2>&1\n",
"! cat $ratio_template_path\n",
"! rm $ratio_template_path"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"我们修改该配置模板如下,并且将替换到我们的剪枝配置文件中。\n",
"\n",
"该配置来源于Li, Hao, et al. \"Pruning filters for efficient convnets.\" arXiv preprint arXiv:1608.08710 (2016)."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"target_config = \"\"\"\n",
"un_prune = 1.0\n",
"stage_ratio_1 = 0.5\n",
"stage_ratio_2 = 0.4\n",
"stage_ratio_3 = 0.6\n",
"stage_ratio_4 = un_prune\n",
"\n",
"target_pruning_ratio = {\n",
" # stage 1\n",
" 'backbone.conv1_(0, 64)_64': un_prune, # short cut layers\n",
" 'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1,\n",
" 'backbone.layer1.1.conv1_(0, 64)_64': stage_ratio_1,\n",
" 'backbone.layer1.2.conv1_(0, 64)_64': un_prune,\n",
" # stage 2\n",
" 'backbone.layer2.0.conv1_(0, 128)_128': un_prune,\n",
" 'backbone.layer2.0.conv2_(0, 128)_128': un_prune, # short cut layers\n",
" 'backbone.layer2.1.conv1_(0, 128)_128': stage_ratio_2,\n",
" 'backbone.layer2.2.conv1_(0, 128)_128': stage_ratio_2,\n",
" 'backbone.layer2.3.conv1_(0, 128)_128': un_prune,\n",
" # stage 3\n",
" 'backbone.layer3.0.conv1_(0, 256)_256': un_prune,\n",
" 'backbone.layer3.0.conv2_(0, 256)_256': un_prune, # short cut layers\n",
" 'backbone.layer3.1.conv1_(0, 256)_256': stage_ratio_3,\n",
" 'backbone.layer3.2.conv1_(0, 256)_256': stage_ratio_3,\n",
" 'backbone.layer3.3.conv1_(0, 256)_256': stage_ratio_3,\n",
" 'backbone.layer3.4.conv1_(0, 256)_256': stage_ratio_3,\n",
" 'backbone.layer3.5.conv1_(0, 256)_256': un_prune,\n",
" # stage 4\n",
" 'backbone.layer4.0.conv1_(0, 512)_512': stage_ratio_4,\n",
" 'backbone.layer4.0.conv2_(0, 512)_512': un_prune, # short cut layers\n",
" 'backbone.layer4.1.conv1_(0, 512)_512': stage_ratio_4,\n",
" 'backbone.layer4.2.conv1_(0, 512)_512': stage_ratio_4\n",
"}\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"_base_ = ['mmcls::resnet/resnet34_8xb32_in1k.py']\n",
"\n",
"\n",
"data_preprocessor = {'type': 'mmcls.ClsDataPreprocessor'}\n",
"architecture = _base_.model\n",
"architecture.update({\n",
" 'init_cfg': {\n",
" 'type':\n",
" 'Pretrained',\n",
" 'checkpoint':\n",
" 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth' # noqa\n",
" }\n",
"})\n",
"\n",
"\n",
"un_prune = 1.0\n",
"stage_ratio_1 = 0.5\n",
"stage_ratio_2 = 0.4\n",
"stage_ratio_3 = 0.6\n",
"stage_ratio_4 = un_prune\n",
"\n",
"target_pruning_ratio = {\n",
" # stage 1\n",
" 'backbone.conv1_(0, 64)_64': un_prune, # short cut layers\n",
" 'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1,\n",
" 'backbone.layer1.1.conv1_(0, 64)_64': stage_ratio_1,\n",
" 'backbone.layer1.2.conv1_(0, 64)_64': un_prune,\n",
" # stage 2\n",
" 'backbone.layer2.0.conv1_(0, 128)_128': un_prune,\n",
" 'backbone.layer2.0.conv2_(0, 128)_128': un_prune, # short cut layers\n",
" 'backbone.layer2.1.conv1_(0, 128)_128': stage_ratio_2,\n",
" 'backbone.layer2.2.conv1_(0, 128)_128': stage_ratio_2,\n",
" 'backbone.layer2.3.conv1_(0, 128)_128': un_prune,\n",
" # stage 3\n",
" 'backbone.layer3.0.conv1_(0, 256)_256': un_prune,\n",
" 'backbone.layer3.0.conv2_(0, 256)_256': un_prune, # short cut layers\n",
" 'backbone.layer3.1.conv1_(0, 256)_256': stage_ratio_3,\n",
" 'backbone.layer3.2.conv1_(0, 256)_256': stage_ratio_3,\n",
" 'backbone.layer3.3.conv1_(0, 256)_256': stage_ratio_3,\n",
" 'backbone.layer3.4.conv1_(0, 256)_256': stage_ratio_3,\n",
" 'backbone.layer3.5.conv1_(0, 256)_256': un_prune,\n",
" # stage 4\n",
" 'backbone.layer4.0.conv1_(0, 512)_512': stage_ratio_4,\n",
" 'backbone.layer4.0.conv2_(0, 512)_512': un_prune, # short cut layers\n",
" 'backbone.layer4.1.conv1_(0, 512)_512': stage_ratio_4,\n",
" 'backbone.layer4.2.conv1_(0, 512)_512': stage_ratio_4\n",
"}\n",
"\n",
"model = dict(\n",
" _delete_=True,\n",
" _scope_='mmrazor',\n",
" type='ItePruneAlgorithm',\n",
" architecture=architecture,\n",
" mutator_cfg=dict(\n",
" type='ChannelMutator',\n",
" channel_unit_cfg=dict(\n",
" type='L1MutableChannelUnit',\n",
" default_args=dict(choice_mode='ratio'))),\n",
" target_pruning_ratio=target_pruning_ratio,\n",
" step_freq=1,\n",
" prune_times=1,\n",
")\n"
]
}
],
"source": [
"config_string=config_string.replace('target_pruning_ratio={}',target_config)\n",
"write_config(config_string,prune_config_path)\n",
"! cat $prune_config_path"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5. 运行"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"! timeout 2 python ./tools/train.py $prune_config_path"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# 自动生成剪枝Config"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"我们提供了一键生成剪枝config的工具get_prune_config.py"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"usage: get_l1_prune_config.py [-h] [--checkpoint CHECKPOINT] [--subnet SUBNET]\n",
" [-o O]\n",
" config\n",
"\n",
"Get the config to prune a model.\n",
"\n",
"positional arguments:\n",
" config config of the model\n",
"\n",
"optional arguments:\n",
" -h, --help show this help message and exit\n",
" --checkpoint CHECKPOINT\n",
" checkpoint path of the model\n",
" --subnet SUBNET pruning structure for the model\n",
" -o O output path to store the pruning config.\n"
]
}
],
"source": [
"! python ./tools/pruning/get_l1_prune_config.py -h"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"model = dict(\n",
" _scope_='mmrazor',\n",
" type='ItePruneAlgorithm',\n",
" architecture=dict(\n",
" type='ImageClassifier',\n",
" backbone=dict(\n",
" type='ResNet',\n",
" depth=34,\n",
" num_stages=4,\n",
" out_indices=(3, ),\n",
" style='pytorch'),\n",
" neck=dict(type='GlobalAveragePooling'),\n",
" head=dict(\n",
" type='LinearClsHead',\n",
" num_classes=1000,\n",
" in_channels=512,\n",
" loss=dict(type='CrossEntropyLoss', loss_weight=1.0),\n",
" topk=(1, 5)),\n",
" _scope_='mmcls',\n",
" init_cfg=dict(\n",
" type='Pretrained',\n",
" checkpoint=\n",
" 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'\n",
" ),\n",
" data_preprocessor=dict(\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" to_rgb=True)),\n",
" target_pruning_ratio=dict({\n",
" 'backbone.conv1_(0, 64)_64': 1.0,\n",
" 'backbone.layer1.0.conv1_(0, 64)_64': 1.0,\n",
" 'backbone.layer1.1.conv1_(0, 64)_64': 1.0,\n",
" 'backbone.layer1.2.conv1_(0, 64)_64': 1.0,\n",
" 'backbone.layer2.0.conv1_(0, 128)_128': 1.0,\n",
" 'backbone.layer2.0.conv2_(0, 128)_128': 1.0,\n",
" 'backbone.layer2.1.conv1_(0, 128)_128': 1.0,\n",
" 'backbone.layer2.2.conv1_(0, 128)_128': 1.0,\n",
" 'backbone.layer2.3.conv1_(0, 128)_128': 1.0,\n",
" 'backbone.layer3.0.conv1_(0, 256)_256': 1.0,\n",
" 'backbone.layer3.0.conv2_(0, 256)_256': 1.0,\n",
" 'backbone.layer3.1.conv1_(0, 256)_256': 1.0,\n",
" 'backbone.layer3.2.conv1_(0, 256)_256': 1.0,\n",
" 'backbone.layer3.3.conv1_(0, 256)_256': 1.0,\n",
" 'backbone.layer3.4.conv1_(0, 256)_256': 1.0,\n",
" 'backbone.layer3.5.conv1_(0, 256)_256': 1.0,\n",
" 'backbone.layer4.0.conv1_(0, 512)_512': 1.0,\n",
" 'backbone.layer4.0.conv2_(0, 512)_512': 1.0,\n",
" 'backbone.layer4.1.conv1_(0, 512)_512': 1.0,\n",
" 'backbone.layer4.2.conv1_(0, 512)_512': 1.0\n",
" }),\n",
" mutator_cfg=dict(\n",
" type='ChannelMutator',\n",
" channel_unit_cfg=dict(\n",
" type='L1MutableChannelUnit',\n",
" default_args=dict(choice_mode='ratio')),\n",
" parse_cfg=dict(\n",
" type='ChannelAnalyzer',\n",
" tracer_type='FxTracer',\n",
" demo_input=dict(type='DefaultDemoInput', scope='mmcls'))))\n",
"dataset_type = 'ImageNet'\n",
"data_preprocessor = None\n",
"train_pipeline = [\n",
" dict(type='LoadImageFromFile', _scope_='mmcls'),\n",
" dict(type='RandomResizedCrop', scale=224, _scope_='mmcls'),\n",
" dict(type='RandomFlip', prob=0.5, direction='horizontal', _scope_='mmcls'),\n",
" dict(type='PackClsInputs', _scope_='mmcls')\n",
"]\n",
"test_pipeline = [\n",
" dict(type='LoadImageFromFile', _scope_='mmcls'),\n",
" dict(type='ResizeEdge', scale=256, edge='short', _scope_='mmcls'),\n",
" dict(type='CenterCrop', crop_size=224, _scope_='mmcls'),\n",
" dict(type='PackClsInputs', _scope_='mmcls')\n",
"]\n",
"train_dataloader = dict(\n",
" batch_size=32,\n",
" num_workers=5,\n",
" dataset=dict(\n",
" type='ImageNet',\n",
" data_root='data/imagenet',\n",
" ann_file='meta/train.txt',\n",
" data_prefix='train',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='RandomResizedCrop', scale=224),\n",
" dict(type='RandomFlip', prob=0.5, direction='horizontal'),\n",
" dict(type='PackClsInputs')\n",
" ],\n",
" _scope_='mmcls'),\n",
" sampler=dict(type='DefaultSampler', shuffle=True, _scope_='mmcls'),\n",
" persistent_workers=True)\n",
"val_dataloader = dict(\n",
" batch_size=32,\n",
" num_workers=5,\n",
" dataset=dict(\n",
" type='ImageNet',\n",
" data_root='data/imagenet',\n",
" ann_file='meta/val.txt',\n",
" data_prefix='val',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='ResizeEdge', scale=256, edge='short'),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(type='PackClsInputs')\n",
" ],\n",
" _scope_='mmcls'),\n",
" sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n",
" persistent_workers=True)\n",
"val_evaluator = dict(type='Accuracy', topk=(1, 5), _scope_='mmcls')\n",
"test_dataloader = dict(\n",
" batch_size=32,\n",
" num_workers=5,\n",
" dataset=dict(\n",
" type='ImageNet',\n",
" data_root='data/imagenet',\n",
" ann_file='meta/val.txt',\n",
" data_prefix='val',\n",
" pipeline=[\n",
" dict(type='LoadImageFromFile'),\n",
" dict(type='ResizeEdge', scale=256, edge='short'),\n",
" dict(type='CenterCrop', crop_size=224),\n",
" dict(type='PackClsInputs')\n",
" ],\n",
" _scope_='mmcls'),\n",
" sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n",
" persistent_workers=True)\n",
"test_evaluator = dict(type='Accuracy', topk=(1, 5), _scope_='mmcls')\n",
"optim_wrapper = dict(\n",
" optimizer=dict(\n",
" type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001,\n",
" _scope_='mmcls'))\n",
"param_scheduler = dict(\n",
" type='MultiStepLR',\n",
" by_epoch=True,\n",
" milestones=[30, 60, 90],\n",
" gamma=0.1,\n",
" _scope_='mmcls')\n",
"train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)\n",
"val_cfg = dict()\n",
"test_cfg = dict()\n",
"auto_scale_lr = dict(base_batch_size=256)\n",
"default_scope = 'mmcls'\n",
"default_hooks = dict(\n",
" timer=dict(type='IterTimerHook', _scope_='mmcls'),\n",
" logger=dict(type='LoggerHook', interval=100, _scope_='mmcls'),\n",
" param_scheduler=dict(type='ParamSchedulerHook', _scope_='mmcls'),\n",
" checkpoint=dict(type='CheckpointHook', interval=1, _scope_='mmcls'),\n",
" sampler_seed=dict(type='DistSamplerSeedHook', _scope_='mmcls'),\n",
" visualization=dict(\n",
" type='VisualizationHook', enable=False, _scope_='mmcls'))\n",
"env_cfg = dict(\n",
" cudnn_benchmark=False,\n",
" mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),\n",
" dist_cfg=dict(backend='nccl'))\n",
"vis_backends = [dict(type='LocalVisBackend', _scope_='mmcls')]\n",
"visualizer = dict(\n",
" type='ClsVisualizer',\n",
" vis_backends=[dict(type='LocalVisBackend')],\n",
" _scope_='mmcls')\n",
"log_level = 'INFO'\n",
"load_from = None\n",
"resume = False\n"
]
}
],
"source": [
"! python ./tools/pruning/get_l1_prune_config.py $work_dir/pretrain.py --checkpoint $checkpoint_path -o $prune_config_path &> /dev/null\n",
"! cat $prune_config_path"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# 清理临时文件\n",
"! rm -r $work_dir"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.13 ('lab2max')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "e31a827d0913016ad78e01c7b97f787f4b9e53102dd62d238e8548bcd97ff875"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@@ -81,7 +81,7 @@ model = dict(
**Specific arguments**:
An algorithm may have its own specific arguments. You need to read its documentation to know how to configure them. Here, we only introduce the specific arguments of ItePruneAlgorithm.
- target_pruning_ratio: target_pruning_ratio is a dict that uses the name of units as keys and the choice values as values.. It indicates how many channels remain after pruning. You can use python ./tools/get_channel_units.py --choice {config_file} to get the choice template. Please refer to [How to Use our Config Tool for Pruning](./how_to_use_config_tool_of_pruning.md).
- target_pruning_ratio: a dict that uses unit names as keys and choice values as values. It indicates how many channels remain after pruning. You can use python ./tools/pruning/get_channel_units.py --choice {config_file} to get the choice template. Please refer to [How to Use our Config Tool for Pruning](./how_to_use_config_tool_of_pruning.md).
- step_epoch: the step between two pruning operations.
- prune_times: the number of times to prune before reaching the pruning target. Here, we prune resnet34 once, so we set it to 1.
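
For reference, a minimal sketch of wrapping a model with ItePruneAlgorithm, adapted from the demo notebook demo/config_pruning.ipynb (the ratio below is an illustrative placeholder; the notebook passes step_freq=1 and prune_times=1):

```python
# Minimal ItePruneAlgorithm wrapper config, adapted from demo/config_pruning.ipynb.
architecture = _base_.model  # the original model config, e.g. resnet34
model = dict(
    _delete_=True,
    _scope_='mmrazor',
    type='ItePruneAlgorithm',
    architecture=architecture,
    mutator_cfg=dict(
        type='ChannelMutator',
        channel_unit_cfg=dict(
            type='L1MutableChannelUnit',
            default_args=dict(choice_mode='ratio'))),
    # unit name -> fraction of channels kept after pruning (placeholder value)
    target_pruning_ratio={'backbone.layer1.0.conv1_(0, 64)_64': 0.5},
    step_freq=1,
    prune_times=1,
)
```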


@@ -121,9 +121,9 @@ mutator2.prepare_from_supernet(resnet34())
To make your development smoother, we provide a command-line tool that parses a model and returns the config template.
```shell
$ python ./tools/get_channel_units.py -h
$ python ./tools/pruning/get_channel_units.py -h
usage: get_channel_units.py [-h] [-c] [-i] [--choice] [-o OUTPUT_PATH] config
usage: pruning/get_channel_units.py [-h] [-c] [-i] [--choice] [-o OUTPUT_PATH] config
Get channel unit of a model.
@@ -142,7 +142,7 @@ optional arguments:
Take the algorithm Slimmable Network as an example.
```shell
python ./tools/get_channel_units.py ./configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py
python ./tools/pruning/get_channel_units.py ./configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py
# {
# "type":"SlimmableChannelMutator",
@@ -171,7 +171,7 @@ python ./tools/get_channel_units.py ./configs/pruning/mmcls/autoslim/autoslim_mb
The '-i' flag will return the config with the initialization arguments.
```shell
python ./tools/get_channel_units.py -i ./configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py
python ./tools/pruning/get_channel_units.py -i ./configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py
# {
# "type":"SlimmableChannelMutator",
@@ -207,7 +207,7 @@ python ./tools/get_channel_units.py -i ./configs/pruning/mmcls/autoslim/autoslim
With the "--choice" flag, it returns the choice template: a dict that uses the unit name as key and the choice value as value.
```shell
python ./tools/get_channel_units.py -i ./configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py --choice
python ./tools/pruning/get_channel_units.py -i ./configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py --choice
# {
# "backbone.conv1.conv_(0, 48)_48":32,


@@ -146,3 +146,5 @@ Please refer to the following documents for more details.
- [MutableChannel](../../../mmrazor/models/mutables/mutable_channel/MutableChannel.md)
- [ChannelMutator](../../../mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.ipynb)
- [MutableChannelUnit](../../../mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb)
- Demos
- [Config pruning](../../../demo/config_pruning.ipynb)


@@ -2,6 +2,8 @@ codecov
flake8
interrogate
isort==4.3.21
nbconvert
nbformat
pytest
xdoctest >= 0.10.0
yapf

tests/test_doc.py 100644

@@ -0,0 +1,32 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from unittest import TestCase

import nbformat
from nbconvert.preprocessors import ExecutePreprocessor

TEST_DOC = os.getenv('TEST_DOC') == 'true'

notebook_paths = [
    './mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb',
    './mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.ipynb',  # noqa
    './demo/config_pruning.ipynb'
]


class TestDocs(TestCase):

    def setUp(self) -> None:
        if not TEST_DOC:
            self.skipTest('disabled')

    def test_notebooks(self):
        for path in notebook_paths:
            with self.subTest(path=path):
                # read each notebook and execute it with a fresh kernel
                with open(path) as file:
                    nb_in = nbformat.read(file, nbformat.NO_CONVERT)
                ep = ExecutePreprocessor(timeout=600, kernel_name='python3')
                try:
                    _ = ep.preprocess(nb_in)
                except Exception:
                    self.fail()


@@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.


@@ -0,0 +1,101 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import shutil
import subprocess
from unittest import TestCase

import torch

from mmrazor import digit_version

TEST_TOOLS = os.getenv('TEST_TOOLS') == 'true'


class TestTools(TestCase):
    _config_path = None

    def setUp(self) -> None:
        if not TEST_TOOLS:
            self.skipTest('disabled')

    @property
    def config_path(self):
        if self._config_path is None:
            self._config_path = self._get_config_path()
        return self._config_path

    def _setUp(self) -> None:
        self.workdir = os.path.dirname(__file__) + '/tmp/'
        if not os.path.exists(self.workdir):
            os.mkdir(self.workdir)

    def save_to_config(self, name, content):
        with open(self.workdir + f'/{name}', 'w') as f:
            f.write(content)

    def test_get_channel_unit(self):
        if digit_version(torch.__version__) < digit_version('1.12.0'):
            self.skipTest('version of torch < 1.12.0')
        for path in self.config_path:
            with self.subTest(path=path):
                self._setUp()
                self.save_to_config('pretrain.py', f"""_base_=['{path}']""")
                try:
                    subprocess.run([
                        'python', './tools/pruning/get_channel_units.py',
                        f'{self.workdir}/pretrain.py', '-o',
                        f'{self.workdir}/unit.json'
                    ])
                except Exception as e:
                    self.fail(f'{e}')
                self.assertTrue(os.path.exists(f'{self.workdir}/unit.json'))
                self._tearDown()

    def test_get_prune_config(self):
        if digit_version(torch.__version__) < digit_version('1.12.0'):
            self.skipTest('version of torch < 1.12.0')
        for path in self.config_path:
            with self.subTest(path=path):
                self._setUp()
                self.save_to_config('pretrain.py', f"""_base_=['{path}']""")
                try:
                    subprocess.run([
                        'python',
                        './tools/pruning/get_l1_prune_config.py',
                        f'{self.workdir}/pretrain.py',
                        '-o',
                        f'{self.workdir}/prune.py',
                    ])
                except Exception as e:
                    self.fail(f'{e}')
                self.assertTrue(os.path.exists(f'{self.workdir}/prune.py'))
                self._tearDown()

    def _tearDown(self) -> None:
        shutil.rmtree(self.workdir)

    def _get_config_path(self):
        # collect configs only for the downstream repos that are installed
        config_paths = []
        paths = [
            ('mmcls', 'mmcls::resnet/resnet34_8xb32_in1k.py'),
            ('mmdet', 'mmdet::retinanet/retinanet_r18_fpn_1x_coco.py'),
            (
                'mmseg',
                'mmseg::deeplabv3plus/deeplabv3plus_r50-d8_4xb4-20k_voc12aug-512x512.py'  # noqa
            ),
            ('mmyolo',
             'mmyolo::yolov5/yolov5_m-p6-v62_syncbn_fast_8xb16-300e_coco.py')
        ]
        for repo_name, path in paths:
            try:
                __import__(repo_name)
                config_paths.append(path)
            except Exception:
                pass
        return config_paths


@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import json
import sys

import torch.nn as nn
from mmengine import MODELS
@@ -9,6 +10,8 @@ from mmengine.config import Config
from mmrazor.models import BaseAlgorithm
from mmrazor.models.mutators import ChannelMutator

sys.setrecursionlimit(int(pow(2, 20)))


def parse_args():
    parser = argparse.ArgumentParser(
@@ -40,11 +43,25 @@ def parse_args():
def main():
    args = parse_args()
    config = Config.fromfile(args.config)
    default_scope = config['default_scope']
    model = MODELS.build(config['model'])
    if isinstance(model, BaseAlgorithm):
        mutator = model.mutator
    elif isinstance(model, nn.Module):
        mutator = ChannelMutator()
        mutator: ChannelMutator = ChannelMutator(
            channel_unit_cfg=dict(
                type='L1MutableChannelUnit',
                default_args=dict(choice_mode='ratio'),
            ),
            parse_cfg={
                'type': 'ChannelAnalyzer',
                'demo_input': {
                    'type': 'DefaultDemoInput',
                    'scope': default_scope
                },
                'tracer_type': 'FxTracer'
            })
    mutator.prepare_from_supernet(model)

    if args.choice:
        config = mutator.choice_template
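
Since the choice template is a plain dict, it can also be edited and re-applied in code; a minimal sketch, assuming a ChannelMutator prepared via prepare_from_supernet() as in main() above:

```python
# Hypothetical round trip: dump the choice template, tweak one ratio,
# then load it back and re-apply it to the prepared mutator.
from mmengine import fileio

template = mutator.choice_template  # e.g. {'backbone.conv1_(0, 64)_64': 1.0, ...}
template['backbone.layer1.0.conv1_(0, 64)_64'] = 0.5  # keep half the channels
fileio.dump(template, 'choice.json')
mutator.set_choices(fileio.load('choice.json'))
```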


@@ -0,0 +1,127 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
from typing import Dict

from mmengine import Config, fileio

from mmrazor.models.mutators import ChannelMutator
from mmrazor.registry import MODELS


def parse_args():
    parser = argparse.ArgumentParser(
        description='Get the config to prune a model.')
    parser.add_argument('config', help='config of the model')
    parser.add_argument(
        '--checkpoint',
        default=None,
        type=str,
        help='checkpoint path of the model')
    parser.add_argument(
        '--subnet',
        default=None,
        type=str,
        help='pruning structure for the model')
    parser.add_argument(
        '-o',
        type=str,
        default='./prune.py',
        help='output path to store the pruning config.')
    args = parser.parse_args()
    return args


def wrap_prune_config(config: Config, prune_target: Dict,
                      checkpoint_path: str):
    config = copy.deepcopy(config)
    default_scope = config['default_scope']
    arch_config: Dict = config['model']

    # update checkpoint_path
    if checkpoint_path is not None:
        arch_config.update({
            'init_cfg': {
                'type': 'Pretrained',
                'checkpoint': checkpoint_path  # noqa
            },
        })

    # deal with data_preprocessor
    if 'data_preprocessor' in config:
        data_preprocessor = config['data_preprocessor']
        arch_config.update({'data_preprocessor': data_preprocessor})
        config['data_preprocessor'] = None
    else:
        data_preprocessor = None

    # prepare algorithm
    algorithm_config = dict(
        _scope_='mmrazor',
        type='ItePruneAlgorithm',
        architecture=arch_config,
        target_pruning_ratio=prune_target,
        mutator_cfg=dict(
            type='ChannelMutator',
            channel_unit_cfg=dict(
                type='L1MutableChannelUnit',
                default_args=dict(choice_mode='ratio')),
            parse_cfg=dict(
                type='ChannelAnalyzer',
                tracer_type='FxTracer',
                demo_input=dict(type='DefaultDemoInput',
                                scope=default_scope))))
    config['model'] = algorithm_config

    return config


def change_config(config):
    scope = config['default_scope']
    config['model']['_scope_'] = scope
    return config


if __name__ == '__main__':
    args = parse_args()

    config_path = args.config
    checkpoint_path = args.checkpoint
    target_path = args.o

    origin_config = Config.fromfile(config_path)
    origin_config = change_config(origin_config)
    default_scope = origin_config['default_scope']

    # get subnet config
    model = MODELS.build(copy.deepcopy(origin_config['model']))
    mutator: ChannelMutator = ChannelMutator(
        channel_unit_cfg=dict(
            type='L1MutableChannelUnit',
            default_args=dict(choice_mode='ratio'),
        ),
        parse_cfg={
            'type': 'ChannelAnalyzer',
            'demo_input': {
                'type': 'DefaultDemoInput',
                'scope': default_scope
            },
            'tracer_type': 'FxTracer'
        })
    mutator.prepare_from_supernet(model)

    if args.subnet is None:
        choice_template = mutator.choice_template
    else:
        input_choices = fileio.load(args.subnet)
        try:
            mutator.set_choices(input_choices)
            choice_template = input_choices
        except Exception as e:
            print(f'error when applying the input subnet: {e}')
            choice_template = mutator.choice_template

    # prune and finetune
    prune_config: Config = wrap_prune_config(origin_config, choice_template,
                                             checkpoint_path)
    prune_config.dump(target_path)
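
The file passed via --subnet is loaded with fileio.load and handed to mutator.set_choices, so it should hold a choice dict in the same unit-name to ratio format as mutator.choice_template; a hedged sketch of producing one (the unit names and ratios below are placeholders for a resnet34-like backbone):

```python
# Hypothetical subnet file for --subnet, mirroring the choice-template format.
from mmengine import fileio

subnet = {
    'backbone.layer1.0.conv1_(0, 64)_64': 0.5,  # keep half the channels
    'backbone.layer4.0.conv2_(0, 512)_512': 1.0,  # leave this unit untouched
}
fileio.dump(subnet, 'subnet.json')
# Then: python ./tools/pruning/get_l1_prune_config.py {config_file} --subnet subnet.json -o prune.py
```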