pull/7/head
shippingwang 2020-04-10 13:36:20 +00:00
parent 395b600506
commit 79f6b5599e
3 changed files with 116 additions and 99 deletions

View File

@ -18,7 +18,7 @@ Paddle 的模型保存有多种不同的形式,大体可分为两类:
└── res5c_branch2c_weights
```
2. inference 模型fluid.io.save_inference_model保存的模型
一般是模型训练完成后保存的固化模型用于预测部署。与persistable 模型相比inference 模型会额外保存模型的结构信息,用于配合权重文件构成完整的模型。如下所示,`model` 中保存的即为模型的结构信息。
一般是模型训练完成后保存的固化模型,用于预测部署。与 persistable 模型相比inference 模型会额外保存模型的结构信息,用于配合权重文件构成完整的模型。如下所示,`model` 中保存的即为模型的结构信息。
```
resnet50-vd-persistable/
├── bn2a_branch1_mean
@ -40,9 +40,10 @@ Paddle 的模型保存有多种不同的形式,大体可分为两类:
```
在 Paddle 中训练引擎和预测引擎都支持模型的预测推理只不过预测引擎不需要进行反向操作因此可以进行定制型的优化如层融合kernel 选择等),达到低时延、高吞吐的目的。训练引擎既可以支持 persistable 模型,也可以支持 inference 模型,而预测引擎只支持 inference 模型,因此也就衍生出了三种不同的预测方式:
1. 训练引擎 + persistable 模型
2. 训练引擎 + inference 模型
3. 预测引擎 + inference 模型
1. 预测引擎 + inference 模型
2. 训练引擎 + persistable 模型
3. 训练引擎 + inference 模型
不管是何种预测方式,基本都包含以下几个主要的步骤:
+ 构建引擎
@ -50,7 +51,7 @@ Paddle 的模型保存有多种不同的形式,大体可分为两类:
+ 执行预测
+ 预测结果解析
不同预测方式,主要有两方面不同:构建引擎和执行预测,以下的几个部分我们会具体介绍。
不同预测方式,主要有两方面不同:构建引擎和执行预测,以下的几个部分我们会具体介绍。
## 二、模型转换
@ -94,94 +95,7 @@ python tools/export_model.py \
--output_path=model和params保存路径
```
## 三、训练引擎 + persistable 模型预测
在模型库的 `tools/infer.py` 中提供了完整的示例,只需执行下述命令即可完成预测:
```python
python tools/infer.py \
--image_file=待预测的图片文件路径 \
--model=模型名称 \
--pretrained_model=persistable 模型路径 \
--use_gpu=True
```
训练引擎构建:
由于 persistable 模型不包含模型的结构信息,因此需要先构建出网络结构,然后 load 权重来构建训练引擎。
```python
import fluid
from ppcls.modeling.architectures.resnet_vd import ResNet50_vd
place = fluid.CPUPlace()
exe = fluid.Executor(place)
startup_prog = fluid.Program()
infer_prog = fluid.Program()
with fluid.program_guard(infer_prog, startup_prog):
with fluid.unique_name.guard():
image = create_input()
image = fluid.data(name='image', shape=[None, 3, 224, 224], dtype='float32')
out = ResNet50_vd.net(input=input, class_dim=1000)
infer_prog = infer_prog.clone(for_test=True)
fluid.load(program=infer_prog, model_path=persistable 模型路径, executor=exe)
```
执行预测:
```python
outputs = exe.run(infer_prog,
feed={image.name: data},
fetch_list=[out.name],
return_numpy=False)
```
上述执行预测时候的参数说明可以参考官网 [fluid.Executor](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/executor_cn/Executor_cn.html)
## 四、训练引擎 + inference 模型预测
在模型库的 `tools/py_infer.py` 中提供了完整的示例,只需执行下述命令即可完成预测:
```python
python tools/py_infer.py \
--image_file=图片路径 \
--model_dir=模型的存储路径 \
--model_file=保存的模型文件 \
--params_file=保存的参数文件 \
--use_gpu=True
```
训练引擎构建:
由于 inference 模型已包含模型的结构信息,因此不再需要提前构建模型结构,直接 load 模型结构和权重文件来构建训练引擎。
```python
import fluid
place = fluid.CPUPlace()
exe = fluid.Executor(place)
[program, feed_names, fetch_lists] = fluid.io.load_inference_model(
模型的存储路径,
exe,
model_filename=保存的模型文件,
params_filename=保存的参数文件)
compiled_program = fluid.compiler.CompiledProgram(program)
```
> `load_inference_model` 即支持零散的权重文件集合,也支持融合后的单个权重文件。
执行预测:
```python
outputs = exe.run(compiled_program,
feed={feed_names[0]: data},
fetch_list=fetch_lists,
return_numpy=False)
```
上述执行预测时候的参数说明可以参考官网 [fluid.Executor](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/executor_cn/Executor_cn.html)
## 五、预测引擎 + inference 模型预测
## 三、预测引擎 + inference 模型预测
在模型库的 `tools/predict.py` 中提供了完整的示例,只需执行下述命令即可完成预测:
@ -199,11 +113,11 @@ python ./predict.py \
+ `model_file`(简写 m):模型文件路径,如 `./resnet50-vd/model`
+ `params_file`(简写 p):权重文件路径,如 `./resnet50-vd/params`
+ `batch_size`(简写 b):批大小,如 `1`
+ `ir_optim`:是否使用 `IR` 优化
+ `use_tensorrt`:是否使用 TesorRT 预测引擎
+ `ir_optim`:是否使用 `IR` 优化默认值True
+ `use_tensorrt`:是否使用 TesorRT 预测引擎默认值True
+ `gpu_mem` 初始分配GPU显存以M单位
+ `use_gpu`:是否使用 GPU 预测
+ `enable_benchmark`是否启用benchmark
+ `use_gpu`:是否使用 GPU 预测默认值True
+ `enable_benchmark`是否启用benchmark默认值False
+ `model_name`:模型名字
注意:
@ -246,3 +160,103 @@ predictor.zero_copy_run()
默认情况下Paddle 的 wheel 包中是不包含 TensorRT 预测引擎的,如果需要使用 TensorRT 进行预测优化,需要自己编译对应的 wheel 包,编译方式可以参考 Paddle 的编译指南 [Paddle 编译](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/compile/fromsource.html)。
## 四、训练引擎 + persistable 模型预测
在模型库的 `tools/infer.py` 中提供了完整的示例,只需执行下述命令即可完成预测:
```python
python tools/infer.py \
--i=待预测的图片文件路径 \
--m=模型名称 \
--p=persistable 模型路径 \
--use_gpu=True
```
参数说明:
+ `image_file`(简写 i):待预测的图片文件路径,如 `./test.jpeg`
+ `model_file`(简写 m):模型文件路径,如 `./resnet50-vd/model`
+ `params_file`(简写 p):权重文件路径,如 `./resnet50-vd/params`
+ `use_gpu` : 是否开启GPU训练默认值True
训练引擎构建:
由于 persistable 模型不包含模型的结构信息,因此需要先构建出网络结构,然后 load 权重来构建训练引擎。
```python
import fluid
from ppcls.modeling.architectures.resnet_vd import ResNet50_vd
place = fluid.CPUPlace()
exe = fluid.Executor(place)
startup_prog = fluid.Program()
infer_prog = fluid.Program()
with fluid.program_guard(infer_prog, startup_prog):
with fluid.unique_name.guard():
image = create_input()
image = fluid.data(name='image', shape=[None, 3, 224, 224], dtype='float32')
out = ResNet50_vd.net(input=input, class_dim=1000)
infer_prog = infer_prog.clone(for_test=True)
fluid.load(program=infer_prog, model_path=persistable 模型路径, executor=exe)
```
执行预测:
```python
outputs = exe.run(infer_prog,
feed={image.name: data},
fetch_list=[out.name],
return_numpy=False)
```
上述执行预测时候的参数说明可以参考官网 [fluid.Executor](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/executor_cn/Executor_cn.html)
## 五、训练引擎 + inference 模型预测
在模型库的 `tools/py_infer.py` 中提供了完整的示例,只需执行下述命令即可完成预测:
```python
python tools/py_infer.py \
--i=图片路径 \
--d=模型的存储路径 \
--m=保存的模型文件 \
--p=保存的参数文件 \
--use_gpu=True
```
+ `image_file`(简写 i):待预测的图片文件路径,如 `./test.jpeg`
+ `model_file`(简写 m):模型文件路径,如 `./resnet50_vd/model`
+ `params_file`(简写 p):权重文件路径,如 `./resnet50_vd/params`
+ `model_dir`(简写d):模型路径,如`./resent50_vd`
+ `use_gpu`是否开启GPU默认值True
训练引擎构建:
由于 inference 模型已包含模型的结构信息,因此不再需要提前构建模型结构,直接 load 模型结构和权重文件来构建训练引擎。
```python
import fluid
place = fluid.CPUPlace()
exe = fluid.Executor(place)
[program, feed_names, fetch_lists] = fluid.io.load_inference_model(
模型的存储路径,
exe,
model_filename=保存的模型文件,
params_filename=保存的参数文件)
compiled_program = fluid.compiler.CompiledProgram(program)
```
> `load_inference_model` 既支持零散的权重文件集合,也支持融合后的单个权重文件。
执行预测:
```python
outputs = exe.run(compiled_program,
feed={feed_names[0]: data},
fetch_list=fetch_lists,
return_numpy=False)
```
上述执行预测时候的参数说明可以参考官网 [fluid.Executor](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/executor_cn/Executor_cn.html)

View File

@ -66,7 +66,7 @@ python eval.py \
## 3、模型推理
PaddlePaddle提供三种方式进行预测推理接下来介绍如何用预测引擎进行推理
首先,对预测模型进行导出
首先,对训练好的模型进行转换
```bash
python tools/export_model.py \
-model=模型名字 \
@ -83,4 +83,4 @@ python tools/predict.py \
--use_gpu=1 \
--use_tensorrt=True
```
更多推理方式和实验请参考[分类预测框架](../extension/paddle_inference.md)
更多使用方法和推理方式请参考[分类预测框架](../extension/paddle_inference.md)

View File

@ -103,6 +103,9 @@ def main():
assert args.use_gpu == True
assert args.model_name is not None
assert args.use_tensorrt == True
# HALF precission predict only work when using tensorrt
if args.use_fp16==True:
assert args.use_tensorrt == True
operators = create_operators()
predictor = create_predictor(args)