[Feature]Complete the ppocrv4_act (#11345)

* ppocrv4_act

* update

* fix bugs when run act on ppocrv4_dedt_server

* modify act config files

* modify test code and update results

* 新增数据处理的脚本

* fix

* Add batch testing script

* fix

* fix

* fix

* update det_server inference on tesla v100

* update model urls

---------

Co-authored-by: tangshiyu <tangshiyu@baidu.com>
pull/11520/head
Ran chongzhi 2024-01-19 11:12:25 +08:00 committed by GitHub
parent 3b6f117c44
commit 448ee6bec1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 8139 additions and 0 deletions

View File

@ -0,0 +1,303 @@
# OCR模型自动压缩示例
目录:
- [OCR模型自动压缩示例](#ocr模型自动压缩示例)
- [1. 简介](#1-简介)
- [2. Benchmark](#2-benchmark)
- [PPOCRV4\_det](#ppocrv4_det)
- [PPOCRV4\_rec](#ppocrv4_rec)
- [3. 自动压缩流程](#3-自动压缩流程)
- [3.1 准备环境](#31-准备环境)
- [3.2 准备数据集](#32-准备数据集)
- [3.2.1 PPOCRV4\_det\_server数据集预处理](#321-ppocrv4_det_server数据集预处理)
- [3.3 准备预测模型](#33-准备预测模型)
- [4.预测部署](#4预测部署)
- [4.1 Paddle Inference 验证性能](#41-paddle-inference-验证性能)
- [4.1.1 使用测试脚本进行批量测试:](#411-使用测试脚本进行批量测试)
- [4.1.2 基于压缩模型进行基于GPU的批量测试](#412-基于压缩模型进行基于gpu的批量测试)
- [4.1.3 基于压缩前模型进行基于GPU的批量测试](#413-基于压缩前模型进行基于gpu的批量测试)
- [4.1.4 基于压缩模型进行基于CPU的批量测试](#414-基于压缩模型进行基于cpu的批量测试)
- [4.2 PaddleLite端侧部署](#42-paddlelite端侧部署)
- [5.FAQ](#5faq)
- [5.1 报错找不到模型文件或者数据集文件](#51-报错找不到模型文件或者数据集文件)
- [5.2 软件环境一致,硬件不同导致精度差异很大?](#52-软件环境一致硬件不同导致精度差异很大)
## 1. 简介
本示例将以图像分类模型PPOCRV3为例介绍如何使用PaddleOCR中Inference部署模型进行自动压缩。本示例使用的自动压缩策略为量化训练和蒸馏。
## 2. Benchmark
### PPOCRV4_det
| 模型 | 策略 | Metric(hmean) | GPU 耗时(ms) | ARM CPU 耗时(ms) | 配置文件 | Inference模型 |
|:------:|:------:|:------:|:------:|:------:|:------:|:------:|
| 中文PPOCRV4-det_mobile | Baseline | 72.71 | 5.7 | 92.0 | - | [Model](https://paddle-ocr-models.bj.bcebos.com/ppocrv4_qat/det/ch_PP-OCRv4_det_infer.tar) |
| 中文PPOCRV4-det_mobile | 量化+蒸馏 | 71.10 | 2.3 | 94.1 | [Config](./configs/ppocrv4/ppocrv4_det_qat_dist.yaml) | [Model](https://paddle-ocr-models.bj.bcebos.com/ppocrv4_qat/det/det_mobile_qat_3090.zip) |
| 中文PPOCRV4-det_server | Baseline | 79.82 | 32.6 | 844.7 | - | [Model](https://paddle-ocr-models.bj.bcebos.com/ppocrv4_qat/det/ch_PP-OCRv4_det_server_infer.tar) |
| 中文PPOCRV4-det_server | 量化+蒸馏 | 79.27 | 12.3 | 635.0 | [Config](./configs/ppocrv4/ppocrv4_rec_server_qat_dist.yaml) | [Model](https://paddle-ocr-models.bj.bcebos.com/ppocrv4_qat/det/det_server_qat_3090.zip) |
> - GPU测试环境RTX 3090, cuda11.7+tensorrt8.4.2.4+paddle2.5
> - CPU测试环境Intel(R) Xeon(R) Gold 6226R使用12线程测试
> - PPOCRV4-det_server在不完整的数据集上测试数据处理流程参考[ppocrv4_det_server数据集预处理](#321-ppocrv4_det_server数据集预处理),仅为了展示自动压缩效果,指标并不具有参考性,模型真实表现请参考[PPOCRV4介绍](../../../doc/doc_ch/PP-OCRv4_introduction.md)
| 模型 | 策略 | Metric(hmean) | GPU 耗时(ms) | ARM CPU 耗时(ms) | 配置文件 | Inference模型 |
|:------:|:------:|:------:|:------:|:------:|:------:|:------:|
| 中文PPOCRV4-det_mobile | Baseline | 72.71 | 4.7 | 198.4 | - | [Model](https://paddle-ocr-models.bj.bcebos.com/ppocrv4_qat/det/ch_PP-OCRv4_det_infer.tar) |
| 中文PPOCRV4-det_mobile | 量化+蒸馏 | 71.38 | 3.3 | 205.2 | [Config](./configs/ppocrv4/ppocrv4_det_qat_dist.yaml) | [Model](https://paddle-ocr-models.bj.bcebos.com/ppocrv4_qat/det/det_server_qat_v100.zip) |
| 中文PPOCRV4-det_server | Baseline | 79.77 | 50.0 | 2159.4 | - | [Model](https://paddle-ocr-models.bj.bcebos.com/ppocrv4_qat/det/ch_PP-OCRv4_det_server_infer.tar) |
| 中文PPOCRV4-det_server | 量化+蒸馏 | 79.81 | 42.4 | 1834.8 | [Config](./configs/ppocrv4/ppocrv4_rec_server_qat_dist.yaml) | [Model](https://paddle-ocr-models.bj.bcebos.com/ppocrv4_qat/det/det_server_qat_v100.zip) |
> - GPU测试环境Tesla V100, cuda11.7+tensorrt8.4.2.4+paddle2.5.2
> - CPU测试环境Intel(R) Xeon(R) Gold 6271C使用12线程测试
> - PPOCRV4-det_server在不完整的数据集上测试数据处理流程参考[ppocrv4_det_server数据集预处理](#321-ppocrv4_det_server数据集预处理),仅为了展示自动压缩效果,指标并不具有参考性,模型真实表现请参考[PPOCRV4介绍](../../../doc/doc_ch/PP-OCRv4_introduction.md)
### PPOCRV4_rec
| 模型 | 策略 | Metric(accuracy) | GPU 耗时(ms) | ARM CPU 耗时(ms) | 配置文件 | Inference模型 |
|:------:|:------:|:------:|:------:|:------:|:------:|:------:|
| 中文PPOCRV4-rec_mobile | Baseline | 78.92 | 1.7 | 33.3 | - | [Model](https://paddle-ocr-models.bj.bcebos.com/ppocrv4_qat/rec/ch_PP-OCRv4_rec_infer.tar.gz) |
| 中文PPOCRV4-rec_mobile | 量化+蒸馏 | 78.41 | 1.4 | 34.0 | [Config](./configs/ppocrv4/ppocrv4_rec_qat_dist.yaml) | [Model](https://paddle-ocr-models.bj.bcebos.com/ppocrv4_qat/rec/rec_mobile_qat.tar.gz) |
| 中文PPOCRV4-rec_server | Baseline | 81.62 | 4.0 | 62.5 | - | [Model](https://paddle-ocr-models.bj.bcebos.com/ppocrv4_qat/rec/ch_PP-OCRv4_rec_server_infer.tar.gz) |
| 中文PPOCRV4-rec_server | 量化+蒸馏 | 81.03 | 2.0 | 64.4 | [Config](./configs/ppocrv4/ppocrv4_rec_server_qat_dist.yaml) | [Model](https://paddle-ocr-models.bj.bcebos.com/ppocrv4_qat/rec/rec_server_qat.tar.gz) |
> - GPU测试环境Tesla V100, cuda11.2+tensorrt8.0.3.4+paddle2.5
> - CPU测试环境Intel(R) Xeon(R) Gold 6271C使用12线程测试
## 3. 自动压缩流程
### 3.1 准备环境
- PaddlePaddle == 2.5 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装)
- PaddleSlim == 2.5
- PaddleOCR == develop
安装paddlepaddle
```shell
# CPU
python -m pip install paddlepaddle==2.5.1 -i https://pypi.tuna.tsinghua.edu.cn/simple
# GPU 以Ubuntu、CUDA 10.2为例
python -m pip install paddlepaddle-gpu==2.5.1.post102 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
```
安装paddleslim 2.5
```shell
pip install paddleslim@git+https://gitee.com/paddlepaddle/PaddleSlim.git@release/2.5
```
安装其他依赖:
```shell
pip install scikit-image imgaug
```
下载PaddleOCR:
```shell
git clone -b release/2.7 https://github.com/PaddlePaddle/PaddleOCR.git
cd PaddleOCR/
pip install -r requirements.txt
```
### 3.2 准备数据集
公开数据集可参考[OCR数据集](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/doc/doc_ch/dataset/ocr_datasets.md),然后根据程序运行过程中提示放置到对应位置。
#### 3.2.1 PPOCRV4_det_server数据集预处理
PPOCRV4_det_server在使用原始数据集推理时默认将输入图像的最小边缩放到736然而原始数据集中存在一些长宽比很大的图像比如13:1此时再进行缩放就会导致长边的尺寸非常大在实验过程中发现最大的长边尺寸有10000+这导致在构建TensorRT子图的时候显存不足。
为了能顺利跑通自动压缩的流程,展示自动压缩的效果,因此需要对原始数据集进行预处理,将长宽比过大的图像进行剔除,处理脚本可见[ppocrv4_det_server_dataset_process.py](./ppocrv4_det_server_dataset_process.py)。
> 注意:使用不同的数据集需要修改配置文件中`dataset`中数据路径和数据处理部分。
### 3.3 准备预测模型
预测模型的格式为:`model.pdmodel` 和 `model.pdiparams`两个,带`pdmodel`的是模型文件,带`pdiparams`后缀的是权重文件。
> 注:其他像`__model__`和`__params__`分别对应`model.pdmodel` 和 `model.pdiparams`文件。
可在[PaddleOCR模型库](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/doc/doc_ch/models_list.md)中直接获取Inference模型具体可参考下方获取中文PPOCRV4模型示例
```shell
https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_rec_infer.tar
tar -xf ch_PP-OCRv4_rec_infer.tar
```
```shell
wget https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_infer.tar
tar -xf ch_PP-OCRv4_det_infer.tar
```
蒸馏量化自动压缩示例通过run.py脚本启动会使用接口 ```paddleslim.auto_compression.AutoCompression``` 对模型进行量化训练和蒸馏。配置config文件中模型路径、数据集路径、蒸馏、量化和训练等部分的参数配置完成后便可开始自动压缩。
**单卡启动**
```shell
export CUDA_VISIBLE_DEVICES=0
python run.py --save_dir='./save_quant_ppocrv4_det/' --config_path='./configs/ppocrv4/ppocrv4_det_qat_dist.yaml'
```
**多卡启动**
若训练任务中包含大量训练数据,如果使用单卡训练,会非常耗时,使用分布式训练可以达到几乎线性的加速比。
```shell
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch run.py --save_dir='./save_quant_ppocrv4_det/' --config_path='./configs/ppocrv4/ppocrv4_det_qat_dist.yaml'
```
多卡训练指的是将训练任务按照一定方法拆分到多个训练节点完成数据读取、前向计算、反向梯度计算等过程,并将计算出的梯度上传至服务节点。服务节点在收到所有训练节点传来的梯度后,会将梯度聚合并更新参数。最后将参数发送给训练节点,开始新一轮的训练。多卡训练一轮训练能训练```batch size * num gpus```的数据,比如单卡的```batch size```为32单轮训练的数据量即32而四卡训练的```batch size```为32单轮训练的数据量为128。
注意 ```learning rate``` 与 ```batch size``` 呈线性关系,这里单卡 ```batch size``` 8对应的 ```learning rate``` 为0.00005,那么如果 ```batch size``` 增大4倍改为32```learning rate``` 也需乘以4多卡时 ```batch size``` 为8```learning rate``` 需乘上卡数。所以改变 ```batch size``` 或改变训练卡数都需要对应修改 ```learning rate```。
**验证精度**
根据训练log可以看到模型验证的精度若需再次验证精度修改配置文件```./configs/ppocrv3_det_qat_dist.yaml```中所需验证模型的文件夹路径及模型和参数名称```model_dir, model_filename, params_filename```,然后使用以下命令进行验证:
```shell
export CUDA_VISIBLE_DEVICES=0
python eval.py --config_path='./configs/ppocrv3_det_qat_dist.yaml'
```
## 4.预测部署
#### 4.1 Paddle Inference 验证性能
输出的量化模型也是静态图模型静态图模型在GPU上可以使用TensorRT进行加速在CPU上可以使用MKLDNN进行加速。
TensorRT预测环境配置
1. 如果使用 TesorRT 预测引擎,需安装 ```WITH_TRT=ON``` 的Paddle上述paddle下载的2.5满足打开TensorRT编译的要求。
2. 使用TensorRT预测需要进一步安装TensorRT安装TensorRT的方式参考[TensorRT安装说明](../../../docs/deployment/installtrt.md)。
以下字段用于配置预测参数:
| 参数名 | 含义 |
|:------:|:------:|
| model_path | inference 模型文件所在目录,该目录下需要有文件 .pdmodel 和 .pdiparams 两个文件 |
| model_filename | inference_model_dir文件夹下的模型文件名称 |
| params_filename | inference_model_dir文件夹下的参数文件名称 |
| dataset_config | 数据集配置的config |
| image_file | 待测试单张图片的路径如果设置image_file则dataset_config将无效。 |
| device | 预测时的设备,可选:`CPU`, `GPU`。 |
| use_trt | 是否使用 TesorRT 预测引擎在device为```GPU```时生效。 |
| use_mkldnn | 是否启用```MKL-DNN```加速库,注意```use_mkldnn```在device为```CPU```时生效。 |
| cpu_threads | CPU预测时使用CPU线程数量默认10 |
| precision | 预测时精度,可选:`fp32`, `fp16`, `int8`。 |
准备好预测模型并且修改dataset_config中数据集路径为正确的路径后启动测试
##### 4.1.1 使用测试脚本进行批量测试:
我们提供两个脚本文件用于测试模型自动化压缩的效果,分别是[test_ocr_det.sh](./test_ocr_det.sh)和[test_ocr_rec.sh](./test_ocr_rec.sh),这两个脚本都接收一个`model_type`参数用于区分是测试mobile模型还是server模型可选参数为`mobile`和`server`,使用示例:
```shell
# 测试mobile模型
bash test_ocr_det.sh mobile
bash test_ocr_rec.sh mobile
# 测试server模型
bash test_ocr_det.sh server
bash test_ocr_rec.sh server
```
##### 4.1.2 基于压缩模型进行基于GPU的批量测试
```shell
cd deploy/slim/auto_compression
python test_ocr.py \
--model_path save_quant_ppocrv4_det \
--config_path configs/ppocrv4/ppocrv4_det_qat_dist.yaml \
--device GPU \
--use_trt True \
--precision int8
```
##### 4.1.3 基于压缩前模型进行基于GPU的批量测试
```shell
cd deploy/slim/auto_compression
python test_ocr.py \
--model_path ch_PP-OCRv4_det_infer \
--config_path configs/ppocrv4/ppocrv4_rec_det_dist.yaml \
--device GPU \
--use_trt True \
--precision int8
```
##### 4.1.4 基于压缩模型进行基于CPU的批量测试
- MKLDNN预测
```shell
cd deploy/slim/auto_compression
python test_ocr.py \
--model_path save_quant_ppocrv4_det \
--config_path configs/ppocrv4/ppocrv4_det_qat_dist.yaml \
--device GPU \
--use_trt True \
--use_mkldnn=True \
--precision=int8 \
--cpu_threads=10
```
### 4.2 PaddleLite端侧部署
PaddleLite端侧部署可参考
- [Paddle Lite部署](https://github.com/PaddlePaddle/PaddleOCR/tree/9cdab61d909eb595af849db885c257ca8c74cb57/deploy/lite)
## 5.FAQ
### 5.1 报错找不到模型文件或者数据集文件
如果在推理或者跑ACT时报错找不到模型文件或者数据集文件可以检查一下配置文件中的路径是否正确以det_mobile为例配置文件中的指定模型路径的配置信息如下
```yaml
Global:
model_dir: ./models/ch_PP-OCRv4_det_infer
model_filename: inference.pdmodel
params_filename: inference.pdiparams
```
指定训练集验证集路径的配置信息如下:
```yaml
Train:
dataset:
name: SimpleDataSet
data_dir: datasets/chinese
label_file_list:
- datasets/chinese/zhongce_training_fix_1.6k.txt
- datasets/chinese/label_train_all_f4_part2.txt
- datasets/chinese/label_train_all_f4_part3.txt
- datasets/chinese/label_train_all_f4_part4.txt
- datasets/chinese/label_train_all_f4_part5.txt
- datasets/chinese/synth_en_my_clip.txt
- datasets/chinese/synth_ch_my_clip.txt
- datasets/chinese/synth_en_my_largeword_clip.txt
Eval:
dataset:
name: SimpleDataSet
data_dir: datasets/v4_4_test_dataset
label_file_list:
- datasets/v4_4_test_dataset/label.txt
```
### 5.2 软件环境一致,硬件不同导致精度差异很大?
这种情况是正常的TensorRT针对不同的硬件设备有着不同的优化方法同一种优化策略在不同硬件上可能有着截然不同的表现以本实验的ppocrv4_det_server为举例。截取[test_ocr.py](./test_ocr.py)中的一部分代码如下所示:
```python
if args.precision == 'int8' and "ppocrv4_det_server_qat_dist.yaml" in args.config_path:
# Use the following settings only when the hardware is a Tesla V100. If you are using
# a RTX 3090, use the settings in the else branch.
pred_cfg.enable_tensorrt_engine(
workspace_size=1 << 30,
max_batch_size=1,
min_subgraph_size=30,
precision_mode=precision_map[args.precision],
use_static=True,
use_calib_mode=False, )
pred_cfg.exp_disable_tensorrt_ops(["elementwise_add"])
else:
pred_cfg.enable_tensorrt_engine(
workspace_size=1 << 30,
max_batch_size=1,
min_subgraph_size=4,
precision_mode=precision_map[args.precision],
use_static=True,
use_calib_mode=False, )
```
当硬件为RTX 3090的时候使用else分支中的策略即可得到正常的结果但是当硬件是Tesla V100的时候必须使用if分支中的策略才能保证量化后精度不下降具体结果参考[benchmark](#2-benchmark)。

View File

@ -0,0 +1,163 @@
Global:
model_type: det
model_dir: ./models/ch_PP-OCRv4_det_infer
model_filename: inference.pdmodel
params_filename: inference.pdiparams
algorithm: DB
Distillation:
alpha: 1.0
loss: l2
QuantAware:
use_pact: false
activation_bits: 8
is_full_quantize: false
onnx_format: false
activation_quantize_type: moving_average_abs_max
weight_quantize_type: channel_wise_abs_max
not_quant_pattern:
- skip_quant
quantize_op_types:
- conv2d
weight_bits: 8
TrainConfig:
epochs: 2
eval_iter: 200
learning_rate:
type: CosineAnnealingDecay
learning_rate: 0.000005
optimizer_builder:
optimizer:
type: Adam
weight_decay: 5.0e-05
PostProcess:
name: DBPostProcess
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DetMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: datasets/chinese
label_file_list:
- datasets/chinese/zhongce_training_fix_1.6k.txt
- datasets/chinese/label_train_all_f4_part2.txt
- datasets/chinese/label_train_all_f4_part3.txt
- datasets/chinese/label_train_all_f4_part4.txt
- datasets/chinese/label_train_all_f4_part5.txt
- datasets/chinese/synth_en_my_clip.txt
- datasets/chinese/synth_ch_my_clip.txt
- datasets/chinese/synth_en_my_largeword_clip.txt
ratio_list:
- 0.3
- 0.2
- 0.1
- 0.2
- 0.2
- 0.1
- 0.2
- 0.2
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- IaaAugment:
augmenter_args:
- type: Fliplr
args:
p: 0.5
- type: Affine
args:
rotate:
- -10
- 10
- type: Resize
args:
size:
- 0.5
- 3
- EastRandomCropData:
size:
- 960
- 960
max_tries: 50
keep_ratio: true
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- NormalizeImage:
scale: 1./255.
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- threshold_map
- threshold_mask
- shrink_map
- shrink_mask
loader:
shuffle: true
drop_last: false
batch_size_per_card: 4
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: datasets/v4_4_test_dataset
label_file_list:
- datasets/v4_4_test_dataset/label.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- DetResizeForTest:
limit_side_len: 960
limit_type: max
- NormalizeImage:
scale: 1./255.
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- shape
- polys
- ignore_tags
loader:
shuffle: false
drop_last: false
batch_size_per_card: 1
num_workers: 10

View File

@ -0,0 +1,161 @@
Global:
model_type: det
model_dir: ./models/ch_PP-OCRv4_det_server_infer
model_filename: inference.pdmodel
params_filename: inference.pdiparams
algorithm: DB
Distillation:
alpha: 1.0
loss: l2
QuantAware:
use_pact: false
activation_bits: 8
is_full_quantize: false
onnx_format: false
activation_quantize_type: moving_average_abs_max
weight_quantize_type: channel_wise_abs_max
not_quant_pattern:
- skip_quant
quantize_op_types:
- conv2d
weight_bits: 8
TrainConfig:
epochs: 1
eval_iter: 200
learning_rate:
type: CosineAnnealingDecay
learning_rate: 0.000005
optimizer_builder:
optimizer:
type: Adam
weight_decay: 5.0e-05
PostProcess:
name: DBPostProcess
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DetMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: datasets/chinese
label_file_list:
- datasets/chinese/zhongce_training_fix_1.6k.txt
- datasets/chinese/label_train_all_f4_part2.txt
- datasets/chinese/label_train_all_f4_part3.txt
- datasets/chinese/label_train_all_f4_part4.txt
- datasets/chinese/label_train_all_f4_part5.txt
- datasets/chinese/synth_en_my_clip.txt
- datasets/chinese/synth_ch_my_clip.txt
- datasets/chinese/synth_en_my_largeword_clip.txt
ratio_list:
- 0.3
- 0.2
- 0.1
- 0.2
- 0.2
- 0.1
- 0.2
- 0.2
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- IaaAugment:
augmenter_args:
- type: Fliplr
args:
p: 0.5
- type: Affine
args:
rotate:
- -10
- 10
- type: Resize
args:
size:
- 0.5
- 3
- EastRandomCropData:
size:
- 960
- 960
max_tries: 50
keep_ratio: true
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- NormalizeImage:
scale: 1./255.
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- threshold_map
- threshold_mask
- shrink_map
- shrink_mask
loader:
shuffle: true
drop_last: false
batch_size_per_card: 2
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: datasets/v4_4_test_dataset_small
label_file_list:
- datasets/v4_4_test_dataset_small/label.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- DetResizeForTest: null
- NormalizeImage:
scale: 1./255.
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- shape
- polys
- ignore_tags
loader:
shuffle: false
drop_last: false
batch_size_per_card: 1
num_workers: 2

View File

@ -0,0 +1,115 @@
Global:
model_dir: ./models/ch_PP-OCRv4_rec_infer
model_filename: inference.pdmodel
params_filename: inference.pdiparams
model_type: rec
algorithm: SVTR
character_dict_path: ./ppocr_keys_v1.txt
max_text_length: &max_text_length 25
use_space_char: true
Distillation:
alpha: [1.0, 1.0]
loss: ['skd', 'l2']
node:
- ['softmax_11.tmp_0']
- ['linear_170.tmp_1']
QuantAware:
use_pact: false
activation_bits: 8
is_full_quantize: false
onnx_format: false
activation_quantize_type: moving_average_abs_max
weight_quantize_type: channel_wise_abs_max
not_quant_pattern:
- skip_quant
quantize_op_types:
- conv2d
weight_bits: 8
TrainConfig:
epochs: 1
eval_iter: 1000
logging_iter: 100
learning_rate:
type: CosineAnnealingDecay
learning_rate: 0.00001
optimizer_builder:
optimizer:
type: Adam
weight_decay: 5.0e-05
PostProcess:
name: CTCLabelDecode
Metric:
name: RecMetric
main_indicator: acc
ignore_space: False
Train:
dataset:
name: MultiScaleDataSet
ds_width: false
data_dir: datasets/real_data/
label_file_list:
- datasets/real_data/train_list.txt
ext_op_transform_idx: 1
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecConAug:
prob: 0.5
ext_data_num: 2
image_shape: [48, 320, 3]
max_text_length: *max_text_length
- RecAug:
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
sampler:
name: MultiScaleSampler
scales: [[320, 32], [320, 48], [320, 64]]
first_bs: &bs 64
fix_bs: false
divided_factor: [8, 16] # w, h
is_training: True
loader:
shuffle: true
batch_size_per_card: *bs
drop_last: true
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: datasets/real_data/
label_file_list:
- datasets/real_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- RecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
loader:
shuffle: false
drop_last: false
batch_size_per_card: 1
num_workers: 4

View File

@ -0,0 +1,113 @@
Global:
model_dir: ./models/ch_PP-OCRv4_rec_server_infer
model_filename: inference.pdmodel
params_filename: inference.pdiparams
model_type: rec
algorithm: SVTR
character_dict_path: ./ppocr_keys_v1.txt
max_text_length: &max_text_length 25
use_space_char: true
Distillation:
alpha: 1.0
loss: 'l2'
QuantAware:
use_pact: false
activation_bits: 8
is_full_quantize: false
onnx_format: false
activation_quantize_type: moving_average_abs_max
weight_quantize_type: channel_wise_abs_max
not_quant_pattern:
- skip_quant
quantize_op_types:
- conv2d
weight_bits: 8
TrainConfig:
epochs: 1
eval_iter: 1000
logging_iter: 100
learning_rate:
type: CosineAnnealingDecay
learning_rate: 0.00001
optimizer_builder:
optimizer:
type: Adam
weight_decay: 5.0e-05
PostProcess:
name: CTCLabelDecode
Metric:
name: RecMetric
main_indicator: acc
ignore_space: False
Train:
dataset:
name: MultiScaleDataSet
ds_width: false
data_dir: datasets/real_data/
ext_op_transform_idx: 1
label_file_list:
- datasets/real_data/train_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecConAug:
prob: 0.5
ext_data_num: 2
image_shape: [48, 320, 3]
max_text_length: *max_text_length
- RecAug:
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
sampler:
name: MultiScaleSampler
scales: [[320, 32], [320, 48], [320, 64]]
first_bs: &bs 64
fix_bs: false
divided_factor: [8, 16] # w, h
is_training: True
loader:
shuffle: true
batch_size_per_card: *bs
drop_last: true
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: datasets/real_data/
label_file_list:
- datasets/real_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- RecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
loader:
shuffle: false
drop_last: false
batch_size_per_card: 1
num_workers: 4

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,33 @@
import os
import cv2
dataset_path = 'datasets/v4_4_test_dataset'
annotation_file = 'datasets/v4_4_test_dataset/label.txt'
small_images_path = 'datasets/v4_4_test_dataset_small'
new_annotation_file = 'datasets/v4_4_test_dataset_small/label.txt'
os.makedirs(small_images_path, exist_ok=True)
with open(annotation_file, 'r') as f:
lines = f.readlines()
for i, line in enumerate(lines):
image_name = line.split(" ")[0]
image_path = os.path.join(dataset_path, image_name)
try:
image = cv2.imread(image_path)
height, width, _ = image.shape
# 如果图像的宽度和高度都小于2000而且长宽比小于2将其复制到新的文件夹并保存其标注信息
if height < 2000 and width < 2000:
if max(height, width)/min(height,width) < 2:
print(i, height, width, image_path)
small_image_path = os.path.join(small_images_path, image_name)
cv2.imwrite(small_image_path, image)
with open(new_annotation_file, 'a') as f:
f.write(f'{line}')
except:
continue

View File

@ -0,0 +1,164 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from tqdm import tqdm
import numpy as np
import argparse
import paddle
from paddleslim.common import load_config as load_slim_config
from paddleslim.common import get_logger
from paddleslim.auto_compression import AutoCompression
from paddleslim.common.dataloader import get_feed_vars
import sys
sys.path.append('../../../')
from ppocr.data import build_dataloader
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
logger = get_logger(__name__, level=logging.INFO)
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--config_path',
type=str,
default=None,
help="path of compression strategy config.",
required=True)
parser.add_argument(
'--save_dir',
type=str,
default='output',
help="directory to save compressed model.")
parser.add_argument(
'--devices',
type=str,
default='gpu',
help="which device used to compress.")
return parser
def reader_wrapper(reader, input_name):
if isinstance(input_name, list) and len(input_name) == 1:
input_name = input_name[0]
def gen(): # 形成一个字典输入
for i, batch in enumerate(reader()):
yield {input_name: batch[0]}
return gen
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
post_process_class = build_post_process(all_config['PostProcess'],
global_config)
eval_class = build_metric(all_config['Metric'])
model_type = global_config['model_type']
with tqdm(
total=len(val_loader),
bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for batch_id, batch in enumerate(val_loader):
images = batch[0]
try:
preds, = exe.run(compiled_test_program,
feed={test_feed_names[0]: images},
fetch_list=test_fetch_list)
except:
preds, _ = exe.run(compiled_test_program,
feed={test_feed_names[0]: images},
fetch_list=test_fetch_list)
batch_numpy = []
for item in batch:
batch_numpy.append(np.array(item))
if model_type == 'det':
preds_map = {'maps': preds}
post_result = post_process_class(preds_map, batch_numpy[1])
eval_class(post_result, batch_numpy)
elif model_type == 'rec':
post_result = post_process_class(preds, batch_numpy[1])
eval_class(post_result, batch_numpy)
t.update()
metric = eval_class.get_metric()
logger.info('metric eval ***************')
for k, v in metric.items():
logger.info('{}:{}'.format(k, v))
if model_type == 'det':
return metric['hmean']
elif model_type == 'rec':
return metric['acc']
return metric
def main():
rank_id = paddle.distributed.get_rank()
if args.devices == 'gpu':
place = paddle.CUDAPlace(rank_id)
paddle.set_device('gpu')
else:
place = paddle.CPUPlace()
paddle.set_device('cpu')
global all_config, global_config
all_config = load_slim_config(args.config_path)
if "Global" not in all_config:
raise KeyError(f"Key 'Global' not found in config file. \n{all_config}")
global_config = all_config["Global"]
gpu_num = paddle.distributed.get_world_size()
train_dataloader = build_dataloader(all_config, 'Train', args.devices,
logger)
global val_loader
val_loader = build_dataloader(all_config, 'Eval', args.devices, logger)
if isinstance(all_config['TrainConfig']['learning_rate'],
dict) and all_config['TrainConfig']['learning_rate'][
'type'] == 'CosineAnnealingDecay':
steps = len(train_dataloader) * all_config['TrainConfig']['epochs']
all_config['TrainConfig']['learning_rate']['T_max'] = steps
print('total training steps:', steps)
global_config['input_name'] = get_feed_vars(
global_config['model_dir'], global_config['model_filename'],
global_config['params_filename'])
ac = AutoCompression(
model_dir=global_config['model_dir'],
model_filename=global_config['model_filename'],
params_filename=global_config['params_filename'],
save_dir=args.save_dir,
config=all_config,
train_dataloader=reader_wrapper(train_dataloader,
global_config['input_name']),
eval_callback=eval_function if rank_id == 0 else None,
eval_dataloader=reader_wrapper(val_loader, global_config['input_name']))
ac.compress()
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
args = parser.parse_args()
main()

View File

@ -0,0 +1,288 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import time
import os
import sys
import cv2
import numpy as np
import paddle
import logging
import numpy as np
import argparse
from tqdm import tqdm
import paddle
from paddleslim.common import load_config as load_slim_config
from paddleslim.common import get_logger
import sys
sys.path.append('../../../')
from ppocr.data import build_dataloader
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from paddle.inference import create_predictor, PrecisionType
from paddle.inference import Config as PredictConfig
logger = get_logger(__name__, level=logging.INFO)
def find_images_with_bounding_size(dataset: paddle.io.Dataset):
max_length_index = -1
max_width_index = -1
min_length_index = -1
min_width_index = -1
max_length = float('-inf')
max_width = float('-inf')
min_length = float('inf')
min_width = float('inf')
for idx, data in enumerate(dataset):
image = np.array(data[0])
h, w = image.shape[-2:]
if h > max_length:
max_length = h
max_length_index = idx
if w > max_width:
max_width = w
max_width_index = idx
if h < min_length:
min_length = h
min_length_index = idx
if w < min_width:
min_width = w
min_width_index = idx
print(f"Found max image length: {max_length}, index: {max_length_index}")
print(f"Found max image width: {max_width}, index: {max_width_index}")
print(f"Found min image length: {min_length}, index: {min_length_index}")
print(f"Found min image width: {min_width}, index: {min_width_index}")
return paddle.io.Subset(dataset, [max_width_index,max_length_index,
min_width_index, min_length_index])
def load_predictor(args):
"""
load predictor func
"""
rerun_flag = False
model_file = os.path.join(args.model_path, args.model_filename)
params_file = os.path.join(args.model_path, args.params_filename)
pred_cfg = PredictConfig(model_file, params_file)
pred_cfg.enable_memory_optim()
pred_cfg.switch_ir_optim(True)
if args.device == "GPU":
pred_cfg.enable_use_gpu(100, 0)
else:
pred_cfg.disable_gpu()
pred_cfg.set_cpu_math_library_num_threads(args.cpu_threads)
if args.use_mkldnn:
pred_cfg.enable_mkldnn()
if args.precision == "int8":
pred_cfg.enable_mkldnn_int8({"conv2d"})
if global_config['model_type']=="rec":
# delete pass which influence the accuracy, please refer to https://github.com/PaddlePaddle/Paddle/issues/55290
pred_cfg.delete_pass("fc_mkldnn_pass")
pred_cfg.delete_pass("fc_act_mkldnn_fuse_pass")
if args.use_trt:
# To collect the dynamic shapes of inputs for TensorRT engine
dynamic_shape_file = os.path.join(args.model_path, "dynamic_shape.txt")
if os.path.exists(dynamic_shape_file):
pred_cfg.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file,
True)
print("trt set dynamic shape done!")
precision_map = {
"fp16": PrecisionType.Half,
"fp32": PrecisionType.Float32,
"int8": PrecisionType.Int8
}
if args.precision == 'int8' and "ppocrv4_det_server_qat_dist.yaml" in args.config_path:
# Use the following settings only when the hardware is a Tesla V100. If you are using
# a RTX 3090, use the settings in the else branch.
pred_cfg.enable_tensorrt_engine(
workspace_size=1 << 30,
max_batch_size=1,
min_subgraph_size=30,
precision_mode=precision_map[args.precision],
use_static=True,
use_calib_mode=False, )
pred_cfg.exp_disable_tensorrt_ops(["elementwise_add"])
else:
pred_cfg.enable_tensorrt_engine(
workspace_size=1 << 30,
max_batch_size=1,
min_subgraph_size=4,
precision_mode=precision_map[args.precision],
use_static=True,
use_calib_mode=False, )
else:
# pred_cfg.disable_gpu()
# pred_cfg.set_cpu_math_library_num_threads(24)
pred_cfg.collect_shape_range_info(dynamic_shape_file)
print("Start collect dynamic shape...")
rerun_flag = True
predictor = create_predictor(pred_cfg)
return predictor, rerun_flag
def eval(args):
"""
eval mIoU func
"""
# DataLoader need run on cpu
paddle.set_device("cpu")
devices = paddle.device.get_device().split(':')[0]
val_loader = build_dataloader(all_config, 'Eval', devices, logger)
post_process_class = build_post_process(all_config['PostProcess'],
global_config)
eval_class = build_metric(all_config['Metric'])
model_type = global_config['model_type']
predictor, rerun_flag = load_predictor(args)
if rerun_flag:
eval_dataset = find_images_with_bounding_size(val_loader.dataset)
batch_sampler = paddle.io.BatchSampler(
eval_dataset, batch_size=1, shuffle=False, drop_last=False)
val_loader = paddle.io.DataLoader(
eval_dataset,
batch_sampler=batch_sampler,
num_workers=4,
return_list=True)
input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
sample_nums = len(val_loader)
predict_time = 0.0
time_min = float("inf")
time_max = float("-inf")
print("Start evaluating ( total_iters: {}).".format(sample_nums))
for batch_id, batch in enumerate(val_loader):
images = np.array(batch[0])
batch_numpy = []
for item in batch:
batch_numpy.append(np.array(item))
# ori_shape = np.array(batch_numpy).shape[-2:]
input_handle.reshape(images.shape)
input_handle.copy_from_cpu(images)
start_time = time.time()
predictor.run()
preds = output_handle.copy_to_cpu()
end_time = time.time()
timed = end_time - start_time
time_min = min(time_min, timed)
time_max = max(time_max, timed)
predict_time += timed
if model_type == 'det':
preds_map = {'maps': preds}
post_result = post_process_class(preds_map, batch_numpy[1])
eval_class(post_result, batch_numpy)
elif model_type == 'rec':
post_result = post_process_class(preds, batch_numpy[1])
eval_class(post_result, batch_numpy)
if rerun_flag:
if batch_id == 3:
print(
"***** Collect dynamic shape done, Please rerun the program to get correct results. *****"
)
return
if batch_id % 100 == 0:
print("Eval iter:", batch_id)
sys.stdout.flush()
metric = eval_class.get_metric()
time_avg = predict_time / sample_nums
print(
"[Benchmark] Inference time(ms): min={}, max={}, avg={}".
format(
round(time_min * 1000, 2),
round(time_max * 1000, 1), round(time_avg * 1000, 1)))
for k, v in metric.items():
print('{}:{}'.format(k, v))
sys.stdout.flush()
def main():
global all_config, global_config
all_config = load_slim_config(args.config_path)
global_config = all_config["Global"]
eval(args)
if __name__ == "__main__":
paddle.enable_static()
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_path", type=str, help="inference model filepath")
parser.add_argument(
"--config_path",
type=str,
default='./configs/ppocrv3_det_qat_dist.yaml',
help="path of compression strategy config.")
parser.add_argument(
"--model_filename",
type=str,
default="inference.pdmodel",
help="model file name")
parser.add_argument(
"--params_filename",
type=str,
default="inference.pdiparams",
help="params file name")
parser.add_argument(
"--device",
type=str,
default="GPU",
choices=["CPU", "GPU"],
help="Choose the device you want to run, it can be: CPU/GPU, default is GPU",
)
parser.add_argument(
"--precision",
type=str,
default="fp32",
choices=["fp32", "fp16", "int8"],
help="The precision of inference. It can be 'fp32', 'fp16' or 'int8'. Default is 'fp16'.",
)
parser.add_argument(
"--use_trt",
type=bool,
default=False,
help="Whether to use tensorrt engine or not.")
parser.add_argument(
"--use_mkldnn",
type=bool,
default=False,
help="Whether use mkldnn or not.")
parser.add_argument(
"--cpu_threads", type=int, default=10, help="Num of cpu threads.")
args = parser.parse_args()
main()

View File

@ -0,0 +1,88 @@
#!/bin/bash
# 本脚本用于测试PPOCRV4_det系列模型的自动压缩功能
## 运行脚本前,请确保处于以下环境:
## CUDA11.7+TensorRT8.4.2.4+Paddle2.5.2
model_type="$1"
if [ "$model_type" = "mobile" ]; then
echo "test ppocrv4_det_mobile model......"
## 启动自动化压缩训练
CUDA_VISIBLE_DEVICES=0 python run.py --save_dir ./models/det_mobile_qat --config_path configs/ppocrv4/ppocrv4_det_qat_dist.yaml
## GPU指标测试
### 量化前预期指标hmean:72.71%;time:4.7ms
python test_ocr.py --model_path ./models/ch_PP-OCRv4_det_infer --config ./configs/ppocrv4/ppocrv4_det_qat_dist.yaml --precision fp32 --use_trt True
### 量化后预期指标hmean:71.38%;time:3.3ms
python test_ocr.py --model_path ./models/det_mobile_qat --config ./configs/ppocrv4/ppocrv4_det_qat_dist.yaml --precision int8 --use_trt True
## CPU指标测试
### 量化前预期指标hmean:72.71%;time:198.4ms
python test_ocr.py --model_path ./models/ch_PP-OCRv4_det_infer --config ./configs/ppocrv4/ppocrv4_det_qat_dist.yaml --precision fp32 --use_mkldnn True --device CPU --cpu_threads 12
### 量化后预期指标hmean:72.30%;time:205.2ms
python test_ocr.py --model_path ./models/det_mobile_qat --config ./configs/ppocrv4/ppocrv4_det_qat_dist.yaml --precision int8 --use_mkldnn True --device CPU --cpu_threads 12
# 量化前模型推理
# GPU
python tools/infer/predict_det.py --det_model_dir deploy/slim/auto_compression/models/ch_PP-OCRv4_det_infer \
--benchmark True --image_dir deploy/slim/auto_compression/datasets/v4_4_test_dataset --use_gpu True \
--use_tensorrt True --warmup True --precision fp32
# CPU
python tools/infer/predict_det.py --det_model_dir deploy/slim/auto_compression/models/ch_PP-OCRv4_det_infer \
--benchmark True --image_dir deploy/slim/auto_compression/datasets/v4_4_test_dataset --use_gpu False \
--enable_mkldnn True --warmup True --precision fp32
# 量化后模型推理
# GPU
python tools/infer/predict_det.py --det_model_dir deploy/slim/auto_compression/models/det_mobile_qat \
--benchmark True --image_dir deploy/slim/auto_compression/datasets/v4_4_test_dataset --use_gpu True \
--use_tensorrt True --warmup True --precision int8
# CPU
python tools/infer/predict_det.py --det_model_dir deploy/slim/auto_compression/models/det_mobile_qat \
--benchmark True --image_dir deploy/slim/auto_compression/datasets/v4_4_test_dataset --use_gpu False \
--enable_mkldnn True --warmup True --precision int8
elif [ "$model_type" = "server" ]; then
echo "test ppocrv4_det_server model......"
## 启动自动化压缩训练
CUDA_VISIBLE_DEVICES=0 python run.py --save_dir ./models/det_server_qat --config_path configs/ppocrv4/ppocrv4_det_server_qat_dist.yaml
## GPU指标测试
### 量化前预期指标hmean:79.77%;time:50.0ms
python test_ocr.py --model_path ./models/ch_PP-OCRv4_det_server_infer --config ./configs/ppocrv4/ppocrv4_det_server_qat_dist.yaml --precision fp32 --use_trt True
### 量化后预期指标hmean:79.81%;time:42.4ms
python test_ocr.py --model_path ./models/det_server_qat --config ./configs/ppocrv4/ppocrv4_det_server_qat_dist.yaml --precision int8 --use_trt True
## CPU指标测试
### 量化前预期指标hmean:79.77%;time:2159.4ms
python test_ocr.py --model_path ./models/ch_PP-OCRv4_det_server_infer --config ./configs/ppocrv4/ppocrv4_det_server_qat_dist.yaml --precision fp32 --use_mkldnn True --device CPU --cpu_threads 12
### 量化后预期指标hmean:79.69%;time:1834.8ms
python test_ocr.py --model_path ./models/det_server_qat --config ./configs/ppocrv4/ppocrv4_det_server_qat_dist.yaml --precision int8 --use_mkldnn True --device CPU --cpu_threads 12
## 量化前模型推理
### GPU
python tools/infer/predict_det.py --det_model_dir deploy/slim/auto_compression/models/ch_PP-OCRv4_det_server_infer \
--benchmark True --image_dir deploy/slim/auto_compression/datasets/v4_4_test_dataset --use_gpu True \
--use_tensorrt True --warmup True --precision fp32
### CPU
python tools/infer/predict_det.py --det_model_dir deploy/slim/auto_compression/models/ch_PP-OCRv4_det_server_infer \
--benchmark True --image_dir deploy/slim/auto_compression/datasets/v4_4_test_dataset --use_gpu False \
--enable_mkldnn True --warmup True --precision fp32
## 量化后模型推理
### GPU
python tools/infer/predict_det.py --det_model_dir deploy/slim/auto_compression/models/det_server_qat \
--benchmark True --image_dir deploy/slim/auto_compression/datasets/v4_4_test_dataset --use_gpu True \
--use_tensorrt True --warmup True --precision int8
### CPU
python tools/infer/predict_det.py --det_model_dir deploy/slim/auto_compression/models/det_server_qat \
--benchmark True --image_dir deploy/slim/auto_compression/datasets/v4_4_test_dataset --use_gpu False \
--enable_mkldnn True --warmup True --precision int8
else
echo "unrecgnized model_type"
fi

View File

@ -0,0 +1,88 @@
#!/bin/bash
# 本脚本用于测试PPOCRV4_rec系列模型的自动压缩功能
## 运行脚本前,请确保处于以下环境:
## CUDA11.2+TensorRT8.0.3.4+Paddle2.5.2
model_type="$1"
if [ "$model_type" = "mobile" ]; then
echo "test ppocrv4_rec_mobile model......"
## 启动自动化压缩训练
CUDA_VISIBLE_DEVICES=0 python run.py --save_dir ./models/rec_mobile_qat --config_path configs/ppocrv4/ppocrv4_rec_qat_dist.yaml
## GPU指标测试
### 量化前预期指标accuracy:78.92%;time:1.7ms
python test_ocr.py --model_path ./models/ch_PP-OCRv4_rec_infer --config ./configs/ppocrv4/ppocrv4_rec_qat_dist.yaml --precision fp32 --use_trt True
### 量化后预期指标accuracy:78.41%;time:1.4ms
python test_ocr.py --model_path ./models/rec_mobile_qat --config ./configs/ppocrv4/ppocrv4_rec_qat_dist.yaml --precision int8 --use_trt True
## CPU指标测试
### 量化前预期指标accuracy:78.92%;time:33.3ms
python test_ocr.py --model_path ./models/ch_PP-OCRv4_rec_infer --config ./configs/ppocrv4/ppocrv4_rec_qat_dist.yaml --precision fp32 --use_mkldnn True --device CPU --cpu_threads 12
### 量化后预期指标accuracy:78.44%;time:34.0ms
python test_ocr.py --model_path ./models/rec_mobile_qat --config ./configs/ppocrv4/ppocrv4_rec_qat_dist.yaml --precision int8 --use_mkldnn True --device CPU --cpu_threads 12
# 量化前模型推理
# GPU
python tools/infer/predict_det.py --rec_model_dir deploy/slim/auto_compression/models/ch_PP-OCRv4_rec_infer \
--benchmark True --image_dir deploy/slim/auto_compression/datasets/v4_4_test_dataset --use_gpu True \
--use_tensorrt True --warmup True --precision fp32
# CPU
python tools/infer/predict_det.py --rec_model_dir deploy/slim/auto_compression/models/ch_PP-OCRv4_rec_infer \
--benchmark True --image_dir deploy/slim/auto_compression/datasets/v4_4_test_dataset --use_gpu False \
--enable_mkldnn True --warmup True --precision fp32
# 量化后模型推理
# GPU
python tools/infer/predict_det.py --rec_model_dir deploy/slim/auto_compression/models/rec_mobile_qat \
--benchmark True --image_dir deploy/slim/auto_compression/datasets/v4_4_test_dataset --use_gpu True \
--use_tensorrt True --warmup True --precision int8
# CPU
python tools/infer/predict_det.py --rec_model_dir deploy/slim/auto_compression/models/rec_mobile_qat \
--benchmark True --image_dir deploy/slim/auto_compression/datasets/v4_4_test_dataset --use_gpu False \
--enable_mkldnn True --warmup True --precision int8
elif [ "$model_type" = "server" ]; then
echo "test ppocrv4_rec_server model......"
## 启动自动化压缩训练
CUDA_VISIBLE_DEVICES=0 python run.py --save_dir ./models/rec_server_qat --config_path configs/ppocrv4/ppocrv4_rec_server_qat_dist.yaml
## GPU指标测试
### 量化前预期指标accuracy:81.62%;time:4.0ms
python test_ocr.py --model_path ./models/ch_PP-OCRv4_rec_server_infer --config ./configs/ppocrv4/ppocrv4_rec_server_qat_dist.yaml --precision fp32 --use_trt True
### 量化后预期指标accuracy:81.03%;time:2.0ms
python test_ocr.py --model_path ./models/rec_server_qat --config ./configs/ppocrv4/ppocrv4_rec_server_qat_dist.yaml --precision int8 --use_trt True
## CPU指标测试
### 量化前预期指标accuracy:81.62%;time:62.5ms
python test_ocr.py --model_path ./models/ch_PP-OCRv4_rec_server_infer --config ./configs/ppocrv4/ppocrv4_rec_server_qat_dist.yaml --precision fp32 --use_mkldnn True --device CPU --cpu_threads 12
### 量化后预期指标accuracy:81.00%;time:64.4ms
python test_ocr.py --model_path ./models/rec_server_qat --config ./configs/ppocrv4/ppocrv4_rec_server_qat_dist.yaml --precision int8 --use_mkldnn True --device CPU --cpu_threads 12
## 量化前模型推理
### GPU
python tools/infer/predict_det.py --rec_model_dir deploy/slim/auto_compression/models/ch_PP-OCRv4_rec_server_infer \
--benchmark True --image_dir deploy/slim/auto_compression/datasets/v4_4_test_dataset --use_gpu True \
--use_tensorrt True --warmup True --precision fp32
### CPU
python tools/infer/predict_det.py --rec_model_dir deploy/slim/auto_compression/models/ch_PP-OCRv4_rec_server_infer \
--benchmark True --image_dir deploy/slim/auto_compression/datasets/v4_4_test_dataset --use_gpu False \
--enable_mkldnn True --warmup True --precision fp32
## 量化后模型推理
### GPU
python tools/infer/predict_det.py --rec_model_dir deploy/slim/auto_compression/models/rec_server_qat \
--benchmark True --image_dir deploy/slim/auto_compression/datasets/v4_4_test_dataset --use_gpu True \
--use_tensorrt True --warmup True --precision int8
### CPU
python tools/infer/predict_det.py --rec_model_dir deploy/slim/auto_compression/models/rec_server_qat \
--benchmark True --image_dir deploy/slim/auto_compression/datasets/v4_4_test_dataset --use_gpu False \
--enable_mkldnn True --warmup True --precision int8
else
echo "unrecgnized model_type"
fi