Add ACT/FullQuant Demo
parent 688f44ea32
commit bf4d6998ed

@@ -0,0 +1,110 @@
# Automatic Compression Example for Image Classification Models

Contents:

- [1. Introduction](#1-introduction)
- [2. Benchmark](#2-benchmark)
- [3. Auto Compression Workflow](#3-auto-compression-workflow)
  - [3.1 Prepare the Environment](#31-prepare-the-environment)
  - [3.2 Prepare the Dataset](#32-prepare-the-dataset)
  - [3.3 Prepare the Inference Model](#33-prepare-the-inference-model)
  - [3.4 Run Auto Compression and Export the Model](#34-run-auto-compression-and-export-the-model)
- [4. Inference Deployment](#4-inference-deployment)
  - [4.1 Python Inference](#41-python-inference)
  - [4.2 PaddleLite On-Device Deployment](#42-paddlelite-on-device-deployment)
- [5. FAQ](#5-faq)

## 1. Introduction

This example uses the MobileNetV3 image classification model to show how to apply automatic compression to a PaddleClas inference model. The compression strategy used here combines quantization-aware training with distillation.

## 2. Benchmark

### PaddleClas models

| Model | Strategy | Top-1 Acc | GPU latency (ms) | ARM CPU latency (ms) | Config | Inference model |
|:------:|:------:|:------:|:------:|:------:|:------:|:------:|
| MobileNetV3_large_x1_0 | Baseline | 75.32 | - | 16.62 | - | [Model](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV3_large_x1_0_infer.tar) |
| MobileNetV3_large_x1_0 | Quantization + distillation | 74.04 | - | 9.85 | [Config](./mbv3_qat_dis.yaml) | [Model](https://paddle-slim-models.bj.bcebos.com/act/MobileNetV3_large_x1_0_QAT.tar) |

- ARM CPU test environment: `SDM865 (4xA77 + 4xA55)`
- NVIDIA GPU test environment:
  - Hardware: single NVIDIA Tesla T4
  - Software: CUDA 11.2, cuDNN 8.0, TensorRT 8.4
  - Test configuration: batch_size 1, image size 224

## 3. Auto Compression Workflow

#### 3.1 Prepare the Environment

- python >= 3.6
- PaddlePaddle >= 2.3 (install from the [PaddlePaddle website](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html))
- PaddleSlim >= 2.3

Install PaddlePaddle:

```shell
# CPU
pip install paddlepaddle
# GPU
pip install paddlepaddle-gpu
```

Install PaddleSlim:

```shell
pip install paddleslim
```
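
To confirm that everything is importable before moving on, a quick check along these lines can help (a minimal sketch; `paddle.utils.run_check()` only verifies the PaddlePaddle installation itself):

```python
# Optional: verify the installed versions and that PaddlePaddle runs correctly.
import paddle
import paddleslim

print("PaddlePaddle:", paddle.__version__)    # expected >= 2.3
print("PaddleSlim:", paddleslim.__version__)  # expected >= 2.3
paddle.utils.run_check()  # runs a tiny program to validate the install (and GPU, if available)
```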

#### 3.2 Prepare the Dataset

This example runs automatic compression on the ImageNet-1k dataset by default. If your dataset is not in ImageNet-1k format, see the [PaddleClas data preparation guide](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/zh_CN/data_preparation/classification_dataset.md). Place the downloaded dataset in the current directory as `./ILSVRC2012`. A quick way to verify the layout is shown below.
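
The sketch below only assumes the ImageNet-style list files referenced by the example config (`train_list.txt` and `val_list.txt` under the dataset root); adjust the path if your data lives elsewhere:

```python
# Optional: sanity-check the ImageNet-style dataset layout before launching.
import os

data_root = "./ILSVRC2012"  # change if the dataset is stored elsewhere
for name in ("train_list.txt", "val_list.txt"):
    path = os.path.join(data_root, name)
    assert os.path.isfile(path), f"missing label list: {path}"
    with open(path) as f:
        first = f.readline().split()
    # Each line is expected to look like "<relative image path> <label id>".
    assert len(first) == 2, f"unexpected line format in {name}: {first}"
    print(name, "looks fine; first entry:", first)
```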

#### 3.3 Prepare the Inference Model

An inference model consists of two files: `model.pdmodel` and `model.pdiparams`. The file with the `pdmodel` suffix is the model graph and the file with the `pdiparams` suffix holds the weights.

Note: legacy `__model__` and `__params__` files correspond to `model.pdmodel` and `model.pdiparams`, respectively.

Ready-made inference models are available in the [PaddleClas pretrained model zoo](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/zh_CN/algorithm_introduction/ImageNet_models.md). For example, to fetch the MobileNetV3 model:

```shell
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV3_large_x1_0_infer.tar
tar -xf MobileNetV3_large_x1_0_infer.tar
```

You can also export an inference model yourself by following the [PaddleClas export guide](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/zh_CN/inference_deployment/export_model.md).
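
As an optional sanity check, the downloaded inference model can be opened with the Paddle Inference Python API before compression; this is only a minimal sketch that confirms the two files form a loadable model:

```python
# Optional: make sure the inference model loads and inspect its input/output names.
from paddle.inference import Config, create_predictor

config = Config("MobileNetV3_large_x1_0_infer/inference.pdmodel",
                "MobileNetV3_large_x1_0_infer/inference.pdiparams")
predictor = create_predictor(config)
print("inputs:", predictor.get_input_names())    # should contain the name set as Global.input_name
print("outputs:", predictor.get_output_names())
```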

#### 3.4 Run Auto Compression and Export the Model

The quantization-plus-distillation example is launched through the run.py script, which calls the ```paddleslim.auto_compression.AutoCompression``` interface to run quantization-aware training with distillation. Set the model path, dataset path, and the distillation, quantization, and training parameters in the config file; once the config is complete, you can start the compression.

**Single-card launch**

```shell
export CUDA_VISIBLE_DEVICES=0
python run.py --save_dir='./save_quant_mobilev3/' --config_path='./configs/mbv3_qat_dis.yaml'
```

**Multi-card launch**

Image classification training usually involves a large amount of data. ImageNet-22k, for example, contains 14 million images, so training on a single card is very slow, while distributed training achieves a nearly linear speedup.

```shell
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch run.py --save_dir='./save_quant_mobilev3/' --config_path='./configs/mbv3_qat_dis.yaml'
```

Multi-card training splits the job across several training workers, each of which reads data, runs the forward pass, and computes gradients, then uploads its gradients to a server node. After receiving the gradients from all workers, the server aggregates them, updates the parameters, and sends the updated parameters back to the workers so the next round can begin. One multi-card step processes ```batch size * num gpus``` samples: with a single card and a ```batch size``` of 32, one step sees 32 samples, while four cards with a per-card ```batch size``` of 32 see 128.

Note that the ```learning rate``` scales linearly with the ```batch size```. Here a single-card ```batch size``` of 32 corresponds to a ```learning rate``` of 0.015; if the ```batch size``` is reduced 4x to 8, the ```learning rate``` must also be divided by 4, and with multiple cards and a per-card ```batch size``` of 32 the ```learning rate``` must be multiplied by the number of cards. Whenever you change the ```batch size``` or the number of cards, adjust the ```learning rate``` accordingly.
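
The rule is easy to apply mechanically; the helper below is purely illustrative (the function name and defaults are made up for this example) and just shows how the numbers combine:

```python
# Linear learning-rate scaling: lr follows the effective (global) batch size.
def scaled_lr(base_lr=0.015, base_batch_size=32, per_card_batch_size=32, num_gpus=1):
    """Return the learning rate for `num_gpus` cards with `per_card_batch_size` each."""
    effective_batch_size = per_card_batch_size * num_gpus
    return base_lr * effective_batch_size / base_batch_size

print(scaled_lr())                           # 1 card,  batch size 32 -> 0.015
print(scaled_lr(per_card_batch_size=8))      # 1 card,  batch size 8  -> 0.00375 (divided by 4)
print(scaled_lr(num_gpus=4))                 # 4 cards, batch size 32 -> 0.06    (multiplied by 4)
```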

## 4. Inference Deployment

#### 4.1 Python Inference

For Python inference, see:

- [Python deployment](https://github.com/PaddlePaddle/PaddleClas/blob/develop/docs/zh_CN/inference_deployment/python_deploy.md)
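
The guide above covers the full deployment flow. For a quick smoke test of the compressed model, a bare-bones Paddle Inference run looks roughly like the sketch below; the model path and the GPU setting are placeholders to adapt to your environment:

```python
# Minimal Paddle Inference run on the compressed model (a sketch, not the full deploy flow).
import numpy as np
from paddle.inference import Config, create_predictor

config = Config("./save_quant_mobilev3/model.pdmodel",
                "./save_quant_mobilev3/model.pdiparams")  # adjust to your --save_dir output
config.enable_use_gpu(100, 0)  # 100 MB initial GPU memory pool on device 0; remove for CPU
# To exercise the INT8 kernels with TensorRT, enable the TensorRT engine as described
# in the Paddle Inference documentation.

predictor = create_predictor(config)
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
input_handle.reshape([1, 3, 224, 224])
input_handle.copy_from_cpu(np.random.rand(1, 3, 224, 224).astype("float32"))

predictor.run()

output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
scores = output_handle.copy_to_cpu()
print("predicted class id:", int(scores.argmax(axis=1)[0]))
```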

#### 4.2 PaddleLite On-Device Deployment

For on-device deployment with Paddle Lite, see:

- [Paddle Lite deployment](https://github.com/PaddlePaddle/PaddleClas/blob/develop/docs/zh_CN/inference_deployment/paddle_lite_deploy.md)

## 5. FAQ
@@ -0,0 +1,94 @@
# global configs
Global:
  output_dir: ./output/
  device: gpu
  model_dir: ./MobileNetV3_large_x1_0_infer
  model_filename: inference.pdmodel
  params_filename: inference.pdiparams
  input_name: inputs

Distillation:
  alpha: 1.0
  loss: soft_label

Quantization:
  use_pact: true
  activation_bits: 8
  is_full_quantize: false
  onnx_format: true
  activation_quantize_type: moving_average_abs_max
  weight_quantize_type: channel_wise_abs_max
  not_quant_pattern:
  - skip_quant
  quantize_op_types:
  - conv2d
  - depthwise_conv2d
  weight_bits: 8

TrainConfig:
  epochs: 2
  eval_iter: 500
  learning_rate: 0.001
  optimizer_builder:
    optimizer:
      type: Momentum
    weight_decay: 0.000005
  # origin_metric: 0.7532

DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: /paddle/dataset/ILSVRC2012/
      cls_label_path: /paddle/dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - AutoAugment:
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: False
      shuffle: True
    loader:
      num_workers: 8
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: /paddle/dataset/ILSVRC2012/
      cls_label_path: /paddle/dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 32
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True
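
For reference, the settings above can be inspected with plain PyYAML (run.py itself parses the file through `ppcls.utils.config`; the path below assumes the config lives at `./configs/mbv3_qat_dis.yaml` as in the launch commands):

```python
# Peek at the compression settings with PyYAML (illustrative only).
import yaml

with open("./configs/mbv3_qat_dis.yaml") as f:
    cfg = yaml.safe_load(f)

print("quantized op types:", cfg["Quantization"]["quantize_op_types"])
print("weight quantize type:", cfg["Quantization"]["weight_quantize_type"])
print("train batch size:", cfg["DataLoader"]["Train"]["sampler"]["batch_size"])
```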
@@ -0,0 +1,147 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import, division, print_function

import os
import sys
import argparse
import functools
from functools import partial
import math
from tqdm import tqdm

import numpy as np
import paddle
import paddleslim
from paddle.jit import to_static
from paddleslim.analysis import dygraph_flops as flops

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
from paddleslim.auto_compression import AutoCompression

from ppcls.data import build_dataloader
from ppcls.utils import config as conf
from ppcls.utils.logger import init_logger

def reader_wrapper(reader, input_name):
    def gen():
        for i, (imgs, label) in enumerate(reader()):
            yield {input_name: imgs}

    return gen

def eval_function(exe, compiled_test_program, test_feed_names,
                  test_fetch_list):
    results = []
    with tqdm(
            total=len(val_loader),
            bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
            ncols=80) as t:
        for batch_id, (image, label) in enumerate(val_loader):

            # top1_acc, top5_acc
            if len(test_feed_names) == 1:
                image = np.array(image)
                label = np.array(label).astype('int64')
                pred = exe.run(compiled_test_program,
                               feed={test_feed_names[0]: image},
                               fetch_list=test_fetch_list)
                pred = np.array(pred[0])
                label = np.array(label).reshape((-1, 1))
                sort_array = pred.argsort(axis=1)
                top_1_pred = sort_array[:, -1:][:, ::-1]
                top_1 = np.mean(label == top_1_pred)
                top_5_pred = sort_array[:, -5:][:, ::-1]
                acc_num = 0
                for i in range(len(label)):
                    if label[i][0] in top_5_pred[i]:
                        acc_num += 1
                top_5 = float(acc_num) / len(label)
                results.append([top_1, top_5])
            else:
                # eval the "eval model", whose inputs are image and label and
                # whose outputs are top-1 and top-5 accuracy
                image = np.array(image)
                label = np.array(label).astype('int64')
                result = exe.run(compiled_test_program,
                                 feed={
                                     test_feed_names[0]: image,
                                     test_feed_names[1]: label
                                 },
                                 fetch_list=test_fetch_list)
                result = [np.mean(r) for r in result]
                results.append(result)
            t.update()
    result = np.mean(np.array(results), axis=0)
    return result[0]

def main():
    args = conf.parse_args()
    global config
    config = conf.get_config(args.config, overrides=args.override, show=False)

    assert os.path.exists(
        os.path.join(config["Global"]["model_dir"], 'inference.pdmodel')
    ) and os.path.exists(
        os.path.join(config["Global"]["model_dir"], 'inference.pdiparams'))
    if "Query" in config["DataLoader"]["Eval"]:
        config["DataLoader"]["Eval"] = config["DataLoader"]["Eval"]["Query"]

    init_logger()
    train_dataloader = build_dataloader(config["DataLoader"], "Train",
                                        config["Global"]['device'], False)
    if isinstance(config['TrainConfig']['learning_rate'], dict) and config[
            'TrainConfig']['learning_rate']['type'] == 'CosineAnnealingDecay':
        gpu_num = paddle.distributed.get_world_size()
        step = len(train_dataloader)
        config['TrainConfig']['learning_rate']['T_max'] = step
        print('total training steps:', step)

    global val_loader
    val_loader = build_dataloader(config["DataLoader"], "Eval",
                                  config["Global"]['device'], False)

    if config["Global"]['device'] == 'gpu':
        rank_id = paddle.distributed.get_rank()
        place = paddle.CUDAPlace(rank_id)
        paddle.set_device('gpu')
    else:
        # rank_id is used below when choosing the eval callback; default to 0 on CPU.
        rank_id = 0
        place = paddle.CPUPlace()
        paddle.set_device('cpu')

    ac = AutoCompression(
        model_dir=config["Global"]["model_dir"],
        model_filename=config["Global"]["model_filename"],
        params_filename=config["Global"]["params_filename"],
        save_dir=config["Global"]['output_dir'],
        config=config,
        train_dataloader=reader_wrapper(
            train_dataloader, input_name=config['Global']['input_name']),
        eval_callback=eval_function if rank_id == 0 else None,
        eval_dataloader=reader_wrapper(
            val_loader, input_name=config['Global']['input_name']))

    ac.compress()

if __name__ == '__main__':
    paddle.enable_static()
    main()