Add get_prune_config and a demo config_pruning (#389)
* update tools and test * add demo * disable test doc * add switch for test tools and test_doc * fix bug * update doc * update tools name * mv get_channel_units Co-authored-by: liukai <your_email@abc.example>pull/398/head
parent
c8e14e5489
commit
f886821ba1
|
@ -47,10 +47,10 @@ The mainstream approach for filter pruning is usually either to force a hard-cod
|
|||
|
||||
### Generate channel_config file
|
||||
|
||||
Generate `resnet_cls.json` with `tools/get_channel_units.py`.
|
||||
Generate `resnet_cls.json` with `tools/pruning/get_channel_units.py`.
|
||||
|
||||
```bash
|
||||
python tools/get_channel_units.py
|
||||
python tools/pruning/get_channel_units.py
|
||||
configs/pruning/mmcls/dcff/dcff_resnet50_8xb32_in1k.py \
|
||||
-c -i --output-path=configs/pruning/mmcls/dcff/resnet_cls.json
|
||||
```
|
||||
|
|
|
@ -10,7 +10,7 @@ stage_ratio_3 = 0.9
|
|||
stage_ratio_4 = 0.7
|
||||
|
||||
# the config template of target_pruning_ratio can be got by
|
||||
# python ./tools/get_channel_units.py {config_file} --choice
|
||||
# python ./tools/pruning/get_channel_units.py {config_file} --choice
|
||||
target_pruning_ratio = {
|
||||
'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1,
|
||||
'backbone.layer1.0.conv2_(0, 64)_64': stage_ratio_2,
|
||||
|
|
|
@ -6,7 +6,7 @@ stage_ratio_3 = 0.7
|
|||
stage_ratio_4 = 1.0
|
||||
|
||||
# the config template of target_pruning_ratio can be got by
|
||||
# python ./tools/get_channel_units.py {config_file} --choice
|
||||
# python ./tools/pruning/get_channel_units.py {config_file} --choice
|
||||
target_pruning_ratio = {
|
||||
'backbone.conv1_(0, 64)_64': stage_ratio_1,
|
||||
'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1,
|
||||
|
|
|
@ -47,10 +47,10 @@ The mainstream approach for filter pruning is usually either to force a hard-cod
|
|||
|
||||
### Generate channel_config file
|
||||
|
||||
Generate `resnet_det.json` with `tools/get_channel_units.py`.
|
||||
Generate `resnet_det.json` with `tools/pruning/get_channel_units.py`.
|
||||
|
||||
```bash
|
||||
python tools/get_channel_units.py
|
||||
python tools/pruning/get_channel_units.py
|
||||
configs/pruning/mmdet/dcff/dcff_faster_rcnn_resnet50_8xb4_coco.py \
|
||||
-c -i --output-path=configs/pruning/mmcls/dcff/resnet_det.json
|
||||
```
|
||||
|
|
|
@ -11,7 +11,7 @@ stage_ratio_3 = 0.9
|
|||
stage_ratio_4 = 0.7
|
||||
|
||||
# the config template of target_pruning_ratio can be got by
|
||||
# python ./tools/get_channel_units.py {config_file} --choice
|
||||
# python ./tools/pruning/get_channel_units.py {config_file} --choice
|
||||
target_pruning_ratio = {
|
||||
'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1,
|
||||
'backbone.layer1.0.conv2_(0, 64)_64': stage_ratio_2,
|
||||
|
|
|
@ -47,10 +47,10 @@ The mainstream approach for filter pruning is usually either to force a hard-cod
|
|||
|
||||
### Generate channel_config file
|
||||
|
||||
Generate `resnet_pose.json` with `tools/get_channel_units.py`.
|
||||
Generate `resnet_pose.json` with `tools/pruning/get_channel_units.py`.
|
||||
|
||||
```bash
|
||||
python tools/get_channel_units.py
|
||||
python tools/pruning/get_channel_units.py
|
||||
configs/pruning/mmpose/dcff/dcff_topdown_heatmap_resnet50.py \
|
||||
-c -i --output-path=configs/pruning/mmpose/dcff/resnet_pose.json
|
||||
```
|
||||
|
|
|
@ -61,7 +61,7 @@ stage_ratio_3 = 0.9
|
|||
stage_ratio_4 = 0.85
|
||||
|
||||
# the config template of target_pruning_ratio can be got by
|
||||
# python ./tools/get_channel_units.py {config_file} --choice
|
||||
# python ./tools/pruning/get_channel_units.py {config_file} --choice
|
||||
target_pruning_ratio = {
|
||||
'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1,
|
||||
'backbone.layer1.0.conv2_(0, 64)_64': stage_ratio_2,
|
||||
|
|
|
@ -47,10 +47,10 @@ The mainstream approach for filter pruning is usually either to force a hard-cod
|
|||
|
||||
### Generate channel_config file
|
||||
|
||||
Generate `resnet_seg.json` with `tools/get_channel_units.py`.
|
||||
Generate `resnet_seg.json` with `tools/pruning/get_channel_units.py`.
|
||||
|
||||
```bash
|
||||
python tools/get_channel_units.py
|
||||
python tools/pruning/get_channel_units.py
|
||||
configs/pruning/mmseg/dcff/dcff_pointrend_resnet50_8xb2_cityscapes.py \
|
||||
-c -i --output-path=configs/pruning/mmseg/dcff/resnet_seg.json
|
||||
```
|
||||
|
|
|
@ -32,7 +32,7 @@ stage_ratio_3 = 0.9
|
|||
stage_ratio_4 = 0.7
|
||||
|
||||
# the config template of target_pruning_ratio can be got by
|
||||
# python ./tools/get_channel_units.py {config_file} --choice
|
||||
# python ./tools/pruning/get_channel_units.py {config_file} --choice
|
||||
target_pruning_ratio = {
|
||||
'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1,
|
||||
'backbone.layer1.0.conv2_(0, 64)_64': stage_ratio_2,
|
||||
|
|
|
@ -0,0 +1,689 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"# please set cwd to the root of mmrazor repo."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 使用MMRazor对ResNet34进行剪枝"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"本教程主要介绍如何手动配置剪枝config。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 回顾MMCls"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 跨库调用resnet34配置文件\n",
|
||||
"\n",
|
||||
"首先我们先跨库调用resnet34的配置文件。通过跨库调用,我们可以继承原有配置文件的所有内容。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# prepare work_dir\n",
|
||||
"work_dir = './demo/tmp/'\n",
|
||||
"if not os.path.exists(work_dir):\n",
|
||||
" os.mkdir(work_dir)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from mmengine import Config\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def write_config(config_str, filename):\n",
|
||||
" with open(filename, 'w') as f:\n",
|
||||
" f.write(config_str)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'type': 'ImageClassifier', 'backbone': {'type': 'ResNet', 'depth': 34, 'num_stages': 4, 'out_indices': (3,), 'style': 'pytorch'}, 'neck': {'type': 'GlobalAveragePooling'}, 'head': {'type': 'LinearClsHead', 'num_classes': 1000, 'in_channels': 512, 'loss': {'type': 'CrossEntropyLoss', 'loss_weight': 1.0}, 'topk': (1, 5)}, '_scope_': 'mmcls'}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/liukai/miniconda3/envs/lab2max/lib/python3.9/site-packages/mmengine/config/utils.py:51: UserWarning: There is not `Config` define in {'Name': 'convnext-base_3rdparty_in21k', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 15359124480, 'Parameters': 88591464}, 'In Collection': 'ConvNeXt', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_in21k_20220124-13b83eec.pth', 'Converted From': {'Weights': 'https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth', 'Code': 'https://github.com/facebookresearch/ConvNeXt'}}\n",
|
||||
" warnings.warn(f'There is not `Config` define in {model_cfg}')\n",
|
||||
"/home/liukai/miniconda3/envs/lab2max/lib/python3.9/site-packages/mmengine/config/utils.py:51: UserWarning: There is not `Config` define in {'Name': 'convnext-large_3rdparty_in21k', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 34368026112, 'Parameters': 197767336}, 'In Collection': 'ConvNeXt', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_in21k_20220124-41b5a79f.pth', 'Converted From': {'Weights': 'https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth', 'Code': 'https://github.com/facebookresearch/ConvNeXt'}}\n",
|
||||
" warnings.warn(f'There is not `Config` define in {model_cfg}')\n",
|
||||
"/home/liukai/miniconda3/envs/lab2max/lib/python3.9/site-packages/mmengine/config/utils.py:51: UserWarning: There is not `Config` define in {'Name': 'convnext-xlarge_3rdparty_in21k', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 60929820672, 'Parameters': 350196968}, 'In Collection': 'ConvNeXt', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_3rdparty_in21k_20220124-f909bad7.pth', 'Converted From': {'Weights': 'https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth', 'Code': 'https://github.com/facebookresearch/ConvNeXt'}}\n",
|
||||
" warnings.warn(f'There is not `Config` define in {model_cfg}')\n",
|
||||
"/home/liukai/miniconda3/envs/lab2max/lib/python3.9/site-packages/mmengine/config/utils.py:51: UserWarning: There is not `Config` define in {'Name': 'swinv2-base-w12_3rdparty_in21k-192px', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 8510000000, 'Parameters': 87920000}, 'In Collection': 'Swin-Transformer V2', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/swin-v2/pretrain/swinv2-base-w12_3rdparty_in21k-192px_20220803-f7dc9763.pth', 'Converted From': {'Weights': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth', 'Code': 'https://github.com/microsoft/Swin-Transformer'}}\n",
|
||||
" warnings.warn(f'There is not `Config` define in {model_cfg}')\n",
|
||||
"/home/liukai/miniconda3/envs/lab2max/lib/python3.9/site-packages/mmengine/config/utils.py:51: UserWarning: There is not `Config` define in {'Name': 'swinv2-large-w12_3rdparty_in21k-192px', 'Metadata': {'Training Data': 'ImageNet-21k', 'FLOPs': 19040000000, 'Parameters': 196740000}, 'In Collection': 'Swin-Transformer V2', 'Results': None, 'Weights': 'https://download.openmmlab.com/mmclassification/v0/swin-v2/pretrain/swinv2-large-w12_3rdparty_in21k-192px_20220803-d9073fee.pth', 'Converted From': {'Weights': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth', 'Code': 'https://github.com/microsoft/Swin-Transformer'}}\n",
|
||||
" warnings.warn(f'There is not `Config` define in {model_cfg}')\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Prepare pretrain config\n",
|
||||
"pretrain_config_path = f'{work_dir}/pretrain.py'\n",
|
||||
"config_string = \"\"\"\n",
|
||||
"_base_ = ['mmcls::resnet/resnet34_8xb32_in1k.py']\n",
|
||||
"\"\"\"\n",
|
||||
"write_config(config_string, pretrain_config_path)\n",
|
||||
"print(Config.fromfile(pretrain_config_path)['model'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Run config\n",
|
||||
"! timeout 2 python ./tools/train.py $prune_config_path"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 准备剪枝config"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"1. 增加pretrained参数\n",
|
||||
"2. 将resnet34模型装入剪枝算法wrapper中\n",
|
||||
"3. 配置剪枝比例\n",
|
||||
"4. 运行"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 1. 增加预训练参数\n",
|
||||
"我们将原有的’model‘字段取出,命名为architecture,并且给archtecture增加init_cfg字段用来加载预训练模型参数。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"checkpoint_path = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'\n",
|
||||
"prune_config_path = work_dir + 'prune.py'\n",
|
||||
"config_string += \"\"\"\\n\n",
|
||||
"data_preprocessor = {'type': 'mmcls.ClsDataPreprocessor'}\n",
|
||||
"architecture = _base_.model\n",
|
||||
"architecture.update({\n",
|
||||
" 'init_cfg': {\n",
|
||||
" 'type':\n",
|
||||
" 'Pretrained',\n",
|
||||
" 'checkpoint':\n",
|
||||
" 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth' # noqa\n",
|
||||
" }\n",
|
||||
"})\n",
|
||||
"\"\"\"\n",
|
||||
"write_config(config_string, prune_config_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 2. 将resnet34模型装入剪枝算法wrapper中\n",
|
||||
"\n",
|
||||
"我们将原有的model作为architecture放入到ItePruneAlgorithm算法中,并且将ItePruneAlgorithm作为新的model字段。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"config_string += \"\"\"\n",
|
||||
"target_pruning_ratio={}\n",
|
||||
"model = dict(\n",
|
||||
" _delete_=True,\n",
|
||||
" _scope_='mmrazor',\n",
|
||||
" type='ItePruneAlgorithm',\n",
|
||||
" architecture=architecture,\n",
|
||||
" mutator_cfg=dict(\n",
|
||||
" type='ChannelMutator',\n",
|
||||
" channel_unit_cfg=dict(\n",
|
||||
" type='L1MutableChannelUnit',\n",
|
||||
" default_args=dict(choice_mode='ratio'))),\n",
|
||||
" target_pruning_ratio=target_pruning_ratio,\n",
|
||||
" step_freq=1,\n",
|
||||
" prune_times=1,\n",
|
||||
")\n",
|
||||
"\"\"\"\n",
|
||||
"write_config(config_string, prune_config_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"配置到这一步时,我们的config文件已经能够运行了。但是因为我们没有配置target_pruning_ratio,因此现在跑起来就和直接用原有config跑起来没有区别,接下来我们会介绍如何配置剪枝比例"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! timeout 2 python ./tools/train.py $prune_config_path"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 4. 配置剪枝比例"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"我们的模型使用tracer解析模型,进而获得剪枝节点,为了方便用户配置剪枝节点比例,我们提供了一个获得剪枝节点剪枝比例配置的工具。通过该工具,我们可以方便地对剪枝比例进行配置。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{\n",
|
||||
" \"backbone.conv1_(0, 64)_64\":1.0,\n",
|
||||
" \"backbone.layer1.0.conv1_(0, 64)_64\":1.0,\n",
|
||||
" \"backbone.layer1.1.conv1_(0, 64)_64\":1.0,\n",
|
||||
" \"backbone.layer1.2.conv1_(0, 64)_64\":1.0,\n",
|
||||
" \"backbone.layer2.0.conv1_(0, 128)_128\":1.0,\n",
|
||||
" \"backbone.layer2.0.conv2_(0, 128)_128\":1.0,\n",
|
||||
" \"backbone.layer2.1.conv1_(0, 128)_128\":1.0,\n",
|
||||
" \"backbone.layer2.2.conv1_(0, 128)_128\":1.0,\n",
|
||||
" \"backbone.layer2.3.conv1_(0, 128)_128\":1.0,\n",
|
||||
" \"backbone.layer3.0.conv1_(0, 256)_256\":1.0,\n",
|
||||
" \"backbone.layer3.0.conv2_(0, 256)_256\":1.0,\n",
|
||||
" \"backbone.layer3.1.conv1_(0, 256)_256\":1.0,\n",
|
||||
" \"backbone.layer3.2.conv1_(0, 256)_256\":1.0,\n",
|
||||
" \"backbone.layer3.3.conv1_(0, 256)_256\":1.0,\n",
|
||||
" \"backbone.layer3.4.conv1_(0, 256)_256\":1.0,\n",
|
||||
" \"backbone.layer3.5.conv1_(0, 256)_256\":1.0,\n",
|
||||
" \"backbone.layer4.0.conv1_(0, 512)_512\":1.0,\n",
|
||||
" \"backbone.layer4.0.conv2_(0, 512)_512\":1.0,\n",
|
||||
" \"backbone.layer4.1.conv1_(0, 512)_512\":1.0,\n",
|
||||
" \"backbone.layer4.2.conv1_(0, 512)_512\":1.0\n",
|
||||
"}"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ratio_template_path=work_dir+'prune_ratio_template.json'\n",
|
||||
"! python ./tools/pruning/get_channel_units.py $pretrain_config_path --choice -o $ratio_template_path &> /dev/null 2>&1\n",
|
||||
"! cat $ratio_template_path\n",
|
||||
"! rm $ratio_template_path"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"我们修改该配置模板如下,并且将替换到我们的剪枝配置文件中。\n",
|
||||
"\n",
|
||||
"(该配置来源于:Li, Hao, et al. \"Pruning filters for efficient convnets.\" arXiv preprint arXiv:1608.08710 (2016).)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"target_config = \"\"\"\n",
|
||||
"un_prune = 1.0\n",
|
||||
"stage_ratio_1 = 0.5\n",
|
||||
"stage_ratio_2 = 0.4\n",
|
||||
"stage_ratio_3 = 0.6\n",
|
||||
"stage_ratio_4 = un_prune\n",
|
||||
"\n",
|
||||
"target_pruning_ratio = {\n",
|
||||
" # stage 1\n",
|
||||
" 'backbone.conv1_(0, 64)_64': un_prune, # short cut layers\n",
|
||||
" 'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1,\n",
|
||||
" 'backbone.layer1.1.conv1_(0, 64)_64': stage_ratio_1,\n",
|
||||
" 'backbone.layer1.2.conv1_(0, 64)_64': un_prune,\n",
|
||||
" # stage 2\n",
|
||||
" 'backbone.layer2.0.conv1_(0, 128)_128': un_prune,\n",
|
||||
" 'backbone.layer2.0.conv2_(0, 128)_128': un_prune, # short cut layers\n",
|
||||
" 'backbone.layer2.1.conv1_(0, 128)_128': stage_ratio_2,\n",
|
||||
" 'backbone.layer2.2.conv1_(0, 128)_128': stage_ratio_2,\n",
|
||||
" 'backbone.layer2.3.conv1_(0, 128)_128': un_prune,\n",
|
||||
" # stage 3\n",
|
||||
" 'backbone.layer3.0.conv1_(0, 256)_256': un_prune,\n",
|
||||
" 'backbone.layer3.0.conv2_(0, 256)_256': un_prune, # short cut layers\n",
|
||||
" 'backbone.layer3.1.conv1_(0, 256)_256': stage_ratio_3,\n",
|
||||
" 'backbone.layer3.2.conv1_(0, 256)_256': stage_ratio_3,\n",
|
||||
" 'backbone.layer3.3.conv1_(0, 256)_256': stage_ratio_3,\n",
|
||||
" 'backbone.layer3.4.conv1_(0, 256)_256': stage_ratio_3,\n",
|
||||
" 'backbone.layer3.5.conv1_(0, 256)_256': un_prune,\n",
|
||||
" # stage 4\n",
|
||||
" 'backbone.layer4.0.conv1_(0, 512)_512': stage_ratio_4,\n",
|
||||
" 'backbone.layer4.0.conv2_(0, 512)_512': un_prune, # short cut layers\n",
|
||||
" 'backbone.layer4.1.conv1_(0, 512)_512': stage_ratio_4,\n",
|
||||
" 'backbone.layer4.2.conv1_(0, 512)_512': stage_ratio_4\n",
|
||||
"}\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"_base_ = ['mmcls::resnet/resnet34_8xb32_in1k.py']\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"data_preprocessor = {'type': 'mmcls.ClsDataPreprocessor'}\n",
|
||||
"architecture = _base_.model\n",
|
||||
"architecture.update({\n",
|
||||
" 'init_cfg': {\n",
|
||||
" 'type':\n",
|
||||
" 'Pretrained',\n",
|
||||
" 'checkpoint':\n",
|
||||
" 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth' # noqa\n",
|
||||
" }\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"un_prune = 1.0\n",
|
||||
"stage_ratio_1 = 0.5\n",
|
||||
"stage_ratio_2 = 0.4\n",
|
||||
"stage_ratio_3 = 0.6\n",
|
||||
"stage_ratio_4 = un_prune\n",
|
||||
"\n",
|
||||
"target_pruning_ratio = {\n",
|
||||
" # stage 1\n",
|
||||
" 'backbone.conv1_(0, 64)_64': un_prune, # short cut layers\n",
|
||||
" 'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1,\n",
|
||||
" 'backbone.layer1.1.conv1_(0, 64)_64': stage_ratio_1,\n",
|
||||
" 'backbone.layer1.2.conv1_(0, 64)_64': un_prune,\n",
|
||||
" # stage 2\n",
|
||||
" 'backbone.layer2.0.conv1_(0, 128)_128': un_prune,\n",
|
||||
" 'backbone.layer2.0.conv2_(0, 128)_128': un_prune, # short cut layers\n",
|
||||
" 'backbone.layer2.1.conv1_(0, 128)_128': stage_ratio_2,\n",
|
||||
" 'backbone.layer2.2.conv1_(0, 128)_128': stage_ratio_2,\n",
|
||||
" 'backbone.layer2.3.conv1_(0, 128)_128': un_prune,\n",
|
||||
" # stage 3\n",
|
||||
" 'backbone.layer3.0.conv1_(0, 256)_256': un_prune,\n",
|
||||
" 'backbone.layer3.0.conv2_(0, 256)_256': un_prune, # short cut layers\n",
|
||||
" 'backbone.layer3.1.conv1_(0, 256)_256': stage_ratio_3,\n",
|
||||
" 'backbone.layer3.2.conv1_(0, 256)_256': stage_ratio_3,\n",
|
||||
" 'backbone.layer3.3.conv1_(0, 256)_256': stage_ratio_3,\n",
|
||||
" 'backbone.layer3.4.conv1_(0, 256)_256': stage_ratio_3,\n",
|
||||
" 'backbone.layer3.5.conv1_(0, 256)_256': un_prune,\n",
|
||||
" # stage 4\n",
|
||||
" 'backbone.layer4.0.conv1_(0, 512)_512': stage_ratio_4,\n",
|
||||
" 'backbone.layer4.0.conv2_(0, 512)_512': un_prune, # short cut layers\n",
|
||||
" 'backbone.layer4.1.conv1_(0, 512)_512': stage_ratio_4,\n",
|
||||
" 'backbone.layer4.2.conv1_(0, 512)_512': stage_ratio_4\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"model = dict(\n",
|
||||
" _delete_=True,\n",
|
||||
" _scope_='mmrazor',\n",
|
||||
" type='ItePruneAlgorithm',\n",
|
||||
" architecture=architecture,\n",
|
||||
" mutator_cfg=dict(\n",
|
||||
" type='ChannelMutator',\n",
|
||||
" channel_unit_cfg=dict(\n",
|
||||
" type='L1MutableChannelUnit',\n",
|
||||
" default_args=dict(choice_mode='ratio'))),\n",
|
||||
" target_pruning_ratio=target_pruning_ratio,\n",
|
||||
" step_freq=1,\n",
|
||||
" prune_times=1,\n",
|
||||
")\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"config_string=config_string.replace('target_pruning_ratio={}',target_config)\n",
|
||||
"write_config(config_string,prune_config_path)\n",
|
||||
"! cat $prune_config_path"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 5. 运行"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! timeout 2 python ./tools/train.py $prune_config_path"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 自动生成剪枝Config"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"我们提供了一键生成剪枝config的工具get_prune_config.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"usage: get_l1_prune_config.py [-h] [--checkpoint CHECKPOINT] [--subnet SUBNET]\n",
|
||||
" [-o O]\n",
|
||||
" config\n",
|
||||
"\n",
|
||||
"Get the config to prune a model.\n",
|
||||
"\n",
|
||||
"positional arguments:\n",
|
||||
" config config of the model\n",
|
||||
"\n",
|
||||
"optional arguments:\n",
|
||||
" -h, --help show this help message and exit\n",
|
||||
" --checkpoint CHECKPOINT\n",
|
||||
" checkpoint path of the model\n",
|
||||
" --subnet SUBNET pruning structure for the model\n",
|
||||
" -o O output path to store the pruning config.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"! python ./tools/pruning/get_l1_prune_config.py -h"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"model = dict(\n",
|
||||
" _scope_='mmrazor',\n",
|
||||
" type='ItePruneAlgorithm',\n",
|
||||
" architecture=dict(\n",
|
||||
" type='ImageClassifier',\n",
|
||||
" backbone=dict(\n",
|
||||
" type='ResNet',\n",
|
||||
" depth=34,\n",
|
||||
" num_stages=4,\n",
|
||||
" out_indices=(3, ),\n",
|
||||
" style='pytorch'),\n",
|
||||
" neck=dict(type='GlobalAveragePooling'),\n",
|
||||
" head=dict(\n",
|
||||
" type='LinearClsHead',\n",
|
||||
" num_classes=1000,\n",
|
||||
" in_channels=512,\n",
|
||||
" loss=dict(type='CrossEntropyLoss', loss_weight=1.0),\n",
|
||||
" topk=(1, 5)),\n",
|
||||
" _scope_='mmcls',\n",
|
||||
" init_cfg=dict(\n",
|
||||
" type='Pretrained',\n",
|
||||
" checkpoint=\n",
|
||||
" 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'\n",
|
||||
" ),\n",
|
||||
" data_preprocessor=dict(\n",
|
||||
" mean=[123.675, 116.28, 103.53],\n",
|
||||
" std=[58.395, 57.12, 57.375],\n",
|
||||
" to_rgb=True)),\n",
|
||||
" target_pruning_ratio=dict({\n",
|
||||
" 'backbone.conv1_(0, 64)_64': 1.0,\n",
|
||||
" 'backbone.layer1.0.conv1_(0, 64)_64': 1.0,\n",
|
||||
" 'backbone.layer1.1.conv1_(0, 64)_64': 1.0,\n",
|
||||
" 'backbone.layer1.2.conv1_(0, 64)_64': 1.0,\n",
|
||||
" 'backbone.layer2.0.conv1_(0, 128)_128': 1.0,\n",
|
||||
" 'backbone.layer2.0.conv2_(0, 128)_128': 1.0,\n",
|
||||
" 'backbone.layer2.1.conv1_(0, 128)_128': 1.0,\n",
|
||||
" 'backbone.layer2.2.conv1_(0, 128)_128': 1.0,\n",
|
||||
" 'backbone.layer2.3.conv1_(0, 128)_128': 1.0,\n",
|
||||
" 'backbone.layer3.0.conv1_(0, 256)_256': 1.0,\n",
|
||||
" 'backbone.layer3.0.conv2_(0, 256)_256': 1.0,\n",
|
||||
" 'backbone.layer3.1.conv1_(0, 256)_256': 1.0,\n",
|
||||
" 'backbone.layer3.2.conv1_(0, 256)_256': 1.0,\n",
|
||||
" 'backbone.layer3.3.conv1_(0, 256)_256': 1.0,\n",
|
||||
" 'backbone.layer3.4.conv1_(0, 256)_256': 1.0,\n",
|
||||
" 'backbone.layer3.5.conv1_(0, 256)_256': 1.0,\n",
|
||||
" 'backbone.layer4.0.conv1_(0, 512)_512': 1.0,\n",
|
||||
" 'backbone.layer4.0.conv2_(0, 512)_512': 1.0,\n",
|
||||
" 'backbone.layer4.1.conv1_(0, 512)_512': 1.0,\n",
|
||||
" 'backbone.layer4.2.conv1_(0, 512)_512': 1.0\n",
|
||||
" }),\n",
|
||||
" mutator_cfg=dict(\n",
|
||||
" type='ChannelMutator',\n",
|
||||
" channel_unit_cfg=dict(\n",
|
||||
" type='L1MutableChannelUnit',\n",
|
||||
" default_args=dict(choice_mode='ratio')),\n",
|
||||
" parse_cfg=dict(\n",
|
||||
" type='ChannelAnalyzer',\n",
|
||||
" tracer_type='FxTracer',\n",
|
||||
" demo_input=dict(type='DefaultDemoInput', scope='mmcls'))))\n",
|
||||
"dataset_type = 'ImageNet'\n",
|
||||
"data_preprocessor = None\n",
|
||||
"train_pipeline = [\n",
|
||||
" dict(type='LoadImageFromFile', _scope_='mmcls'),\n",
|
||||
" dict(type='RandomResizedCrop', scale=224, _scope_='mmcls'),\n",
|
||||
" dict(type='RandomFlip', prob=0.5, direction='horizontal', _scope_='mmcls'),\n",
|
||||
" dict(type='PackClsInputs', _scope_='mmcls')\n",
|
||||
"]\n",
|
||||
"test_pipeline = [\n",
|
||||
" dict(type='LoadImageFromFile', _scope_='mmcls'),\n",
|
||||
" dict(type='ResizeEdge', scale=256, edge='short', _scope_='mmcls'),\n",
|
||||
" dict(type='CenterCrop', crop_size=224, _scope_='mmcls'),\n",
|
||||
" dict(type='PackClsInputs', _scope_='mmcls')\n",
|
||||
"]\n",
|
||||
"train_dataloader = dict(\n",
|
||||
" batch_size=32,\n",
|
||||
" num_workers=5,\n",
|
||||
" dataset=dict(\n",
|
||||
" type='ImageNet',\n",
|
||||
" data_root='data/imagenet',\n",
|
||||
" ann_file='meta/train.txt',\n",
|
||||
" data_prefix='train',\n",
|
||||
" pipeline=[\n",
|
||||
" dict(type='LoadImageFromFile'),\n",
|
||||
" dict(type='RandomResizedCrop', scale=224),\n",
|
||||
" dict(type='RandomFlip', prob=0.5, direction='horizontal'),\n",
|
||||
" dict(type='PackClsInputs')\n",
|
||||
" ],\n",
|
||||
" _scope_='mmcls'),\n",
|
||||
" sampler=dict(type='DefaultSampler', shuffle=True, _scope_='mmcls'),\n",
|
||||
" persistent_workers=True)\n",
|
||||
"val_dataloader = dict(\n",
|
||||
" batch_size=32,\n",
|
||||
" num_workers=5,\n",
|
||||
" dataset=dict(\n",
|
||||
" type='ImageNet',\n",
|
||||
" data_root='data/imagenet',\n",
|
||||
" ann_file='meta/val.txt',\n",
|
||||
" data_prefix='val',\n",
|
||||
" pipeline=[\n",
|
||||
" dict(type='LoadImageFromFile'),\n",
|
||||
" dict(type='ResizeEdge', scale=256, edge='short'),\n",
|
||||
" dict(type='CenterCrop', crop_size=224),\n",
|
||||
" dict(type='PackClsInputs')\n",
|
||||
" ],\n",
|
||||
" _scope_='mmcls'),\n",
|
||||
" sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n",
|
||||
" persistent_workers=True)\n",
|
||||
"val_evaluator = dict(type='Accuracy', topk=(1, 5), _scope_='mmcls')\n",
|
||||
"test_dataloader = dict(\n",
|
||||
" batch_size=32,\n",
|
||||
" num_workers=5,\n",
|
||||
" dataset=dict(\n",
|
||||
" type='ImageNet',\n",
|
||||
" data_root='data/imagenet',\n",
|
||||
" ann_file='meta/val.txt',\n",
|
||||
" data_prefix='val',\n",
|
||||
" pipeline=[\n",
|
||||
" dict(type='LoadImageFromFile'),\n",
|
||||
" dict(type='ResizeEdge', scale=256, edge='short'),\n",
|
||||
" dict(type='CenterCrop', crop_size=224),\n",
|
||||
" dict(type='PackClsInputs')\n",
|
||||
" ],\n",
|
||||
" _scope_='mmcls'),\n",
|
||||
" sampler=dict(type='DefaultSampler', shuffle=False, _scope_='mmcls'),\n",
|
||||
" persistent_workers=True)\n",
|
||||
"test_evaluator = dict(type='Accuracy', topk=(1, 5), _scope_='mmcls')\n",
|
||||
"optim_wrapper = dict(\n",
|
||||
" optimizer=dict(\n",
|
||||
" type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001,\n",
|
||||
" _scope_='mmcls'))\n",
|
||||
"param_scheduler = dict(\n",
|
||||
" type='MultiStepLR',\n",
|
||||
" by_epoch=True,\n",
|
||||
" milestones=[30, 60, 90],\n",
|
||||
" gamma=0.1,\n",
|
||||
" _scope_='mmcls')\n",
|
||||
"train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)\n",
|
||||
"val_cfg = dict()\n",
|
||||
"test_cfg = dict()\n",
|
||||
"auto_scale_lr = dict(base_batch_size=256)\n",
|
||||
"default_scope = 'mmcls'\n",
|
||||
"default_hooks = dict(\n",
|
||||
" timer=dict(type='IterTimerHook', _scope_='mmcls'),\n",
|
||||
" logger=dict(type='LoggerHook', interval=100, _scope_='mmcls'),\n",
|
||||
" param_scheduler=dict(type='ParamSchedulerHook', _scope_='mmcls'),\n",
|
||||
" checkpoint=dict(type='CheckpointHook', interval=1, _scope_='mmcls'),\n",
|
||||
" sampler_seed=dict(type='DistSamplerSeedHook', _scope_='mmcls'),\n",
|
||||
" visualization=dict(\n",
|
||||
" type='VisualizationHook', enable=False, _scope_='mmcls'))\n",
|
||||
"env_cfg = dict(\n",
|
||||
" cudnn_benchmark=False,\n",
|
||||
" mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),\n",
|
||||
" dist_cfg=dict(backend='nccl'))\n",
|
||||
"vis_backends = [dict(type='LocalVisBackend', _scope_='mmcls')]\n",
|
||||
"visualizer = dict(\n",
|
||||
" type='ClsVisualizer',\n",
|
||||
" vis_backends=[dict(type='LocalVisBackend')],\n",
|
||||
" _scope_='mmcls')\n",
|
||||
"log_level = 'INFO'\n",
|
||||
"load_from = None\n",
|
||||
"resume = False\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"! python ./tools/pruning/get_l1_prune_config.py $work_dir/pretrain.py --checkpoint $checkpoint_path -o $prune_config_path &> /dev/null\n",
|
||||
"! cat $prune_config_path"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 清理临时文件\n",
|
||||
"! rm -r $work_dir"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.9.13 ('lab2max')",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.13"
|
||||
},
|
||||
"orig_nbformat": 4,
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "e31a827d0913016ad78e01c7b97f787f4b9e53102dd62d238e8548bcd97ff875"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -81,7 +81,7 @@ model = dict(
|
|||
**Specific arguments**:
|
||||
A algorithm may have its specific arguments. You need to read their documents to know how to config. Here, we only introduce the specific arguments of ItePruneAlgorithm.
|
||||
|
||||
- target_pruning_ratio: target_pruning_ratio is a dict that uses the name of units as keys and the choice values as values.. It indicates how many channels remain after pruning. You can use python ./tools/get_channel_units.py --choice {config_file} to get the choice template. Please refer to [How to Use our Config Tool for Pruning](./how_to_use_config_tool_of_pruning.md).
|
||||
- target_pruning_ratio: target_pruning_ratio is a dict that uses the name of units as keys and the choice values as values.. It indicates how many channels remain after pruning. You can use python ./tools/pruning/get_channel_units.py --choice {config_file} to get the choice template. Please refer to [How to Use our Config Tool for Pruning](./how_to_use_config_tool_of_pruning.md).
|
||||
- step_epoch: the step between two pruning operations.
|
||||
- prune_times: the times to prune to reach the pruning target. Here, we prune resnet34 once, so we set it to 1.
|
||||
|
||||
|
|
|
@ -121,9 +121,9 @@ mutator2.prepare_from_supernet(resnet34())
|
|||
To make your development more fluent, we provide a command tool to parse a model and return the config template.
|
||||
|
||||
```shell
|
||||
$ python ./tools/get_channel_units.py -h
|
||||
$ python ./tools/pruning/get_channel_units.py -h
|
||||
|
||||
usage: get_channel_units.py [-h] [-c] [-i] [--choice] [-o OUTPUT_PATH] config
|
||||
usage: pruning/get_channel_units.py [-h] [-c] [-i] [--choice] [-o OUTPUT_PATH] config
|
||||
|
||||
Get channel unit of a model.
|
||||
|
||||
|
@ -142,7 +142,7 @@ optional arguments:
|
|||
Take the algorithm Slimmable Network as an example.
|
||||
|
||||
```shell
|
||||
python ./tools/get_channel_units.py ./configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py
|
||||
python ./tools/pruning/get_channel_units.py ./configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py
|
||||
|
||||
# {
|
||||
# "type":"SlimmableChannelMutator",
|
||||
|
@ -171,7 +171,7 @@ python ./tools/get_channel_units.py ./configs/pruning/mmcls/autoslim/autoslim_mb
|
|||
The '-i' flag will return the config with the initialization arguments.
|
||||
|
||||
```shell
|
||||
python ./tools/get_channel_units.py -i ./configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py
|
||||
python ./tools/pruning/get_channel_units.py -i ./configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py
|
||||
|
||||
# {
|
||||
# "type":"SlimmableChannelMutator",
|
||||
|
@ -207,7 +207,7 @@ python ./tools/get_channel_units.py -i ./configs/pruning/mmcls/autoslim/autoslim
|
|||
With "--choice" flag, it will return the choice template, a dict which uses unit_name as key, and use the choice value as value.
|
||||
|
||||
```shell
|
||||
python ./tools/get_channel_units.py -i ./configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py --choice
|
||||
python ./tools/pruning/get_channel_units.py -i ./configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py --choice
|
||||
|
||||
# {
|
||||
# "backbone.conv1.conv_(0, 48)_48":32,
|
||||
|
|
|
@ -146,3 +146,5 @@ Please refer to the following documents for more details.
|
|||
- [MutableChannel](../../../mmrazor/models/mutables/mutable_channel/MutableChannel.md)
|
||||
- [ChannelMutator](../../../mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.ipynb)
|
||||
- [MutableChannelUnit](../../../mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb)
|
||||
- Demos
|
||||
- [Config pruning](../../../demo/config_pruning.ipynb)
|
||||
|
|
|
@ -2,6 +2,8 @@ codecov
|
|||
flake8
|
||||
interrogate
|
||||
isort==4.3.21
|
||||
nbconvert
|
||||
nbformat
|
||||
pytest
|
||||
xdoctest >= 0.10.0
|
||||
yapf
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
from unittest import TestCase
|
||||
|
||||
import nbformat
|
||||
from nbconvert.preprocessors import ExecutePreprocessor
|
||||
|
||||
TEST_DOC = os.getenv('TEST_DOC') == 'true'
|
||||
notebook_paths = [
|
||||
'./mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb',
|
||||
'./mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.ipynb', # noqa
|
||||
'./demo/config_pruning.ipynb'
|
||||
]
|
||||
|
||||
|
||||
class TestDocs(TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
if not TEST_DOC:
|
||||
self.skipTest('disabled')
|
||||
|
||||
def test_notebooks(self):
|
||||
for path in notebook_paths:
|
||||
with self.subTest(path=path):
|
||||
with open(path) as file:
|
||||
nb_in = nbformat.read(file, nbformat.NO_CONVERT)
|
||||
ep = ExecutePreprocessor(
|
||||
timeout=600, kernel_name='python3')
|
||||
try:
|
||||
_ = ep.preprocess(nb_in)
|
||||
except Exception:
|
||||
self.fail()
|
|
@ -0,0 +1 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
@ -0,0 +1,101 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
|
||||
from mmrazor import digit_version
|
||||
|
||||
TEST_TOOLS = os.getenv('TEST_TOOLS') == 'true'
|
||||
|
||||
|
||||
class TestTools(TestCase):
|
||||
_config_path = None
|
||||
|
||||
def setUp(self) -> None:
|
||||
if not TEST_TOOLS:
|
||||
self.skipTest('disabled')
|
||||
|
||||
@property
|
||||
def config_path(self):
|
||||
if self._config_path is None:
|
||||
self._config_path = self._get_config_path()
|
||||
return self._config_path
|
||||
|
||||
def _setUp(self) -> None:
|
||||
self.workdir = os.path.dirname(__file__) + '/tmp/'
|
||||
if not os.path.exists(self.workdir):
|
||||
os.mkdir(self.workdir)
|
||||
|
||||
def save_to_config(self, name, content):
|
||||
with open(self.workdir + f'/{name}', 'w') as f:
|
||||
f.write(content)
|
||||
|
||||
def test_get_channel_unit(self):
|
||||
if digit_version(torch.__version__) < digit_version('1.12.0'):
|
||||
self.skipTest('version of torch < 1.12.0')
|
||||
|
||||
for path in self.config_path:
|
||||
with self.subTest(path=path):
|
||||
self._setUp()
|
||||
self.save_to_config('pretrain.py', f"""_base_=['{path}']""")
|
||||
try:
|
||||
subprocess.run([
|
||||
'python', './tools/pruning/get_channel_units.py',
|
||||
f'{self.workdir}/pretrain.py', '-o',
|
||||
f'{self.workdir}/unit.json'
|
||||
])
|
||||
except Exception as e:
|
||||
self.fail(f'{e}')
|
||||
self.assertTrue(os.path.exists(f'{self.workdir}/unit.json'))
|
||||
|
||||
self._tearDown()
|
||||
|
||||
def test_get_prune_config(self):
|
||||
if digit_version(torch.__version__) < digit_version('1.12.0'):
|
||||
self.skipTest('version of torch < 1.12.0')
|
||||
for path in self.config_path:
|
||||
with self.subTest(path=path):
|
||||
self._setUp()
|
||||
self.save_to_config('pretrain.py', f"""_base_=['{path}']""")
|
||||
try:
|
||||
subprocess.run([
|
||||
'python',
|
||||
'./tools/pruning/get_l1_prune_config.py',
|
||||
f'{self.workdir}/pretrain.py',
|
||||
'-o',
|
||||
f'{self.workdir}/prune.py',
|
||||
])
|
||||
pass
|
||||
except Exception as e:
|
||||
self.fail(f'{e}')
|
||||
self.assertTrue(os.path.exists(f'{self.workdir}/prune.py'))
|
||||
|
||||
self._tearDown()
|
||||
|
||||
def _tearDown(self) -> None:
|
||||
print('delete')
|
||||
shutil.rmtree(self.workdir)
|
||||
pass
|
||||
|
||||
def _get_config_path(self):
|
||||
config_paths = []
|
||||
paths = [
|
||||
('mmcls', 'mmcls::resnet/resnet34_8xb32_in1k.py'),
|
||||
('mmdet', 'mmdet::retinanet/retinanet_r18_fpn_1x_coco.py'),
|
||||
(
|
||||
'mmseg',
|
||||
'mmseg::deeplabv3plus/deeplabv3plus_r50-d8_4xb4-20k_voc12aug-512x512.py' # noqa
|
||||
),
|
||||
('mmyolo',
|
||||
'mmyolo::yolov5/yolov5_m-p6-v62_syncbn_fast_8xb16-300e_coco.py')
|
||||
]
|
||||
for repo_name, path in paths:
|
||||
try:
|
||||
__import__(repo_name)
|
||||
config_paths.append(path)
|
||||
except Exception:
|
||||
pass
|
||||
return config_paths
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
|
||||
import torch.nn as nn
|
||||
from mmengine import MODELS
|
||||
|
@ -9,6 +10,8 @@ from mmengine.config import Config
|
|||
from mmrazor.models import BaseAlgorithm
|
||||
from mmrazor.models.mutators import ChannelMutator
|
||||
|
||||
sys.setrecursionlimit(int(pow(2, 20)))
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -40,11 +43,25 @@ def parse_args():
|
|||
def main():
|
||||
args = parse_args()
|
||||
config = Config.fromfile(args.config)
|
||||
default_scope = config['default_scope']
|
||||
|
||||
model = MODELS.build(config['model'])
|
||||
if isinstance(model, BaseAlgorithm):
|
||||
mutator = model.mutator
|
||||
elif isinstance(model, nn.Module):
|
||||
mutator = ChannelMutator()
|
||||
mutator: ChannelMutator = ChannelMutator(
|
||||
channel_unit_cfg=dict(
|
||||
type='L1MutableChannelUnit',
|
||||
default_args=dict(choice_mode='ratio'),
|
||||
),
|
||||
parse_cfg={
|
||||
'type': 'ChannelAnalyzer',
|
||||
'demo_input': {
|
||||
'type': 'DefaultDemoInput',
|
||||
'scope': default_scope
|
||||
},
|
||||
'tracer_type': 'FxTracer'
|
||||
})
|
||||
mutator.prepare_from_supernet(model)
|
||||
if args.choice:
|
||||
config = mutator.choice_template
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import copy
|
||||
from typing import Dict
|
||||
|
||||
from mmengine import Config, fileio
|
||||
|
||||
from mmrazor.models.mutators import ChannelMutator
|
||||
from mmrazor.registry import MODELS
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Get the config to prune a model.')
|
||||
parser.add_argument('config', help='config of the model')
|
||||
parser.add_argument(
|
||||
'--checkpoint',
|
||||
default=None,
|
||||
type=str,
|
||||
help='checkpoint path of the model')
|
||||
parser.add_argument(
|
||||
'--subnet',
|
||||
default=None,
|
||||
type=str,
|
||||
help='pruning structure for the model')
|
||||
parser.add_argument(
|
||||
'-o',
|
||||
type=str,
|
||||
default='./prune.py',
|
||||
help='output path to store the pruning config.')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def wrap_prune_config(config: Config, prune_target: Dict,
|
||||
checkpoint_path: str):
|
||||
config = copy.deepcopy(config)
|
||||
default_scope = config['default_scope']
|
||||
arch_config: Dict = config['model']
|
||||
|
||||
# update checkpoint_path
|
||||
if checkpoint_path is not None:
|
||||
arch_config.update({
|
||||
'init_cfg': {
|
||||
'type': 'Pretrained',
|
||||
'checkpoint': checkpoint_path # noqa
|
||||
},
|
||||
})
|
||||
|
||||
# deal with data_preprocessor
|
||||
if 'data_preprocessor' in config:
|
||||
data_preprocessor = config['data_preprocessor']
|
||||
arch_config.update({'data_preprocessor': data_preprocessor})
|
||||
config['data_preprocessor'] = None
|
||||
else:
|
||||
data_preprocessor = None
|
||||
|
||||
# prepare algorithm
|
||||
algorithm_config = dict(
|
||||
_scope_='mmrazor',
|
||||
type='ItePruneAlgorithm',
|
||||
architecture=arch_config,
|
||||
target_pruning_ratio=prune_target,
|
||||
mutator_cfg=dict(
|
||||
type='ChannelMutator',
|
||||
channel_unit_cfg=dict(
|
||||
type='L1MutableChannelUnit',
|
||||
default_args=dict(choice_mode='ratio')),
|
||||
parse_cfg=dict(
|
||||
type='ChannelAnalyzer',
|
||||
tracer_type='FxTracer',
|
||||
demo_input=dict(type='DefaultDemoInput',
|
||||
scope=default_scope))))
|
||||
config['model'] = algorithm_config
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def change_config(config):
|
||||
|
||||
scope = config['default_scope']
|
||||
config['model']['_scope_'] = scope
|
||||
return config
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
config_path = args.config
|
||||
checkpoint_path = args.checkpoint
|
||||
target_path = args.o
|
||||
|
||||
origin_config = Config.fromfile(config_path)
|
||||
origin_config = change_config(origin_config)
|
||||
default_scope = origin_config['default_scope']
|
||||
|
||||
# get subnet config
|
||||
model = MODELS.build(copy.deepcopy(origin_config['model']))
|
||||
mutator: ChannelMutator = ChannelMutator(
|
||||
channel_unit_cfg=dict(
|
||||
type='L1MutableChannelUnit',
|
||||
default_args=dict(choice_mode='ratio'),
|
||||
),
|
||||
parse_cfg={
|
||||
'type': 'ChannelAnalyzer',
|
||||
'demo_input': {
|
||||
'type': 'DefaultDemoInput',
|
||||
'scope': default_scope
|
||||
},
|
||||
'tracer_type': 'FxTracer'
|
||||
})
|
||||
mutator.prepare_from_supernet(model)
|
||||
if args.subnet is None:
|
||||
choice_template = mutator.choice_template
|
||||
else:
|
||||
input_choices = fileio.load(args.subnet)
|
||||
try:
|
||||
mutator.set_choices(input_choices)
|
||||
choice_template = input_choices
|
||||
except Exception as e:
|
||||
print(f'error when apply input subnet: {e}')
|
||||
choice_template = mutator.choice_template
|
||||
|
||||
# prune and finetune
|
||||
|
||||
prune_config: Config = wrap_prune_config(origin_config, choice_template,
|
||||
checkpoint_path)
|
||||
prune_config.dump(target_path)
|
Loading…
Reference in New Issue