From 2960b4c28b879594371f96e4ef9b50d2a73cfac4 Mon Sep 17 00:00:00 2001 From: lubin10 Date: Thu, 20 Jan 2022 09:34:33 +0000 Subject: [PATCH 1/8] support onnx inference for cls; add readme --- deploy/paddle2onnx/readme.md | 63 ++++++++++++++++++++++++++++++++++++ deploy/python/predict_cls.py | 27 +++++++++++----- deploy/python/predict_rec.py | 26 ++++++++++----- deploy/utils/predictor.py | 25 ++++++++++++-- 4 files changed, 123 insertions(+), 18 deletions(-) create mode 100644 deploy/paddle2onnx/readme.md diff --git a/deploy/paddle2onnx/readme.md b/deploy/paddle2onnx/readme.md new file mode 100644 index 000000000..8c357bf74 --- /dev/null +++ b/deploy/paddle2onnx/readme.md @@ -0,0 +1,63 @@ +# paddle2onnx 模型转化与预测 + +本章节介绍 ResNet50_vd 模型如何转化为 ONNX 模型,并基于 ONNX 引擎预测。 + +## 1. 环境准备 + +需要准备 Paddle2ONNX 模型转化环境,和 ONNX 模型预测环境 + +### Paddle2ONNX + +Paddle2ONNX 支持将 PaddlePaddle 模型格式转化到 ONNX 模型格式,算子目前稳定支持导出 ONNX Opset 9~11,部分Paddle算子支持更低的ONNX Opset转换。 +更多细节可参考 [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX/blob/develop/README_zh.md) + +- 安装 Paddle2ONNX +``` +python3.7 -m pip install paddle2onnx +``` + +- 安装 ONNX 运行时 +``` +python3.7 -m pip install onnxruntime +``` + +## 2. 模型转换 + +- ResNet50_vd inference模型下载 + +``` +cd deploy +mkdir models && cd models +wget -nc https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet50_vd_infer.tar && tar xf ResNet50_vd_infer.tar +cd .. +``` + +- 模型转换 + +使用 Paddle2ONNX 将Paddle静态图模型转换为ONNX模型格式: +``` +paddle2onnx --model_dir=./models/ResNet50_vd_infer/ \ +--model_filename=inference.pdmodel \ +--params_filename=inference.pdiparams \ +--save_file=./models/ResNet50_vd_infer/inference.onnx \ +--opset_version=10 \ +--enable_onnx_checker=True +``` + +执行完毕后,ONNX 模型 `inference.onnx` 会被保存在 `./models/ResNet50_vd_infer/` 路径下 + +## 3. onnx 预测 + +执行如下命令: +``` +python3.7 python/predict_cls.py \ +-c configs/inference_cls.yaml \ +-o Global.use_onnx=True \ +-o Global.use_gpu=False \ +-o Global.inference_model_dir=./models/ResNet50_vd_infer \ +``` + +结果如下: +``` +ILSVRC2012_val_00000010.jpeg: class id(s): [153, 204, 229, 332, 155], score(s): [0.69, 0.10, 0.02, 0.01, 0.01], label_name(s): ['Maltese dog, Maltese terrier, Maltese', 'Lhasa, Lhasa apso', 'Old English sheepdog, bobtail', 'Angora, Angora rabbit', 'Shih-Tzu'] +``` diff --git a/deploy/python/predict_cls.py b/deploy/python/predict_cls.py index 7d6d86164..574caa3e7 100644 --- a/deploy/python/predict_cls.py +++ b/deploy/python/predict_cls.py @@ -67,12 +67,17 @@ class ClsPredictor(Predictor): warmup=2) def predict(self, images): - input_names = self.paddle_predictor.get_input_names() - input_tensor = self.paddle_predictor.get_input_handle(input_names[0]) + use_onnx = self.args.get("use_onnx", False) + if not use_onnx: + input_names = self.predictor.get_input_names() + input_tensor = self.predictor.get_input_handle(input_names[0]) + + output_names = self.predictor.get_output_names() + output_tensor = self.predictor.get_output_handle(output_names[0]) + else: + input_names = self.predictor.get_inputs()[0].name + output_names = self.predictor.get_outputs()[0].name - output_names = self.paddle_predictor.get_output_names() - output_tensor = self.paddle_predictor.get_output_handle(output_names[ - 0]) if self.benchmark: self.auto_logger.times.start() if not isinstance(images, (list, )): @@ -84,9 +89,15 @@ class ClsPredictor(Predictor): if self.benchmark: self.auto_logger.times.stamp() - input_tensor.copy_from_cpu(image) - self.paddle_predictor.run() - batch_output = output_tensor.copy_to_cpu() + if not use_onnx: + input_tensor.copy_from_cpu(image) + self.predictor.run() + batch_output = output_tensor.copy_to_cpu() + else: + batch_output = self.predictor.run( + output_names=[output_names], + input_feed={input_names: image})[0] + if self.benchmark: self.auto_logger.times.stamp() if self.postprocess is not None: diff --git a/deploy/python/predict_rec.py b/deploy/python/predict_rec.py index 0d0e74159..e24335810 100644 --- a/deploy/python/predict_rec.py +++ b/deploy/python/predict_rec.py @@ -58,12 +58,16 @@ class RecPredictor(Predictor): warmup=2) def predict(self, images, feature_normalize=True): - input_names = self.paddle_predictor.get_input_names() - input_tensor = self.paddle_predictor.get_input_handle(input_names[0]) + use_onnx = self.args.get("use_onnx", False) + if not use_onnx: + input_names = self.predictor.get_input_names() + input_tensor = self.predictor.get_input_handle(input_names[0]) - output_names = self.paddle_predictor.get_output_names() - output_tensor = self.paddle_predictor.get_output_handle(output_names[ - 0]) + output_names = self.predictor.get_output_names() + output_tensor = self.predictor.get_output_handle(output_names[0]) + else: + input_names = self.predictor.get_inputs()[0].name + output_names = self.predictor.get_outputs()[0].name if self.benchmark: self.auto_logger.times.start() @@ -76,9 +80,15 @@ class RecPredictor(Predictor): if self.benchmark: self.auto_logger.times.stamp() - input_tensor.copy_from_cpu(image) - self.paddle_predictor.run() - batch_output = output_tensor.copy_to_cpu() + if not use_onnx: + input_tensor.copy_from_cpu(image) + self.predictor.run() + batch_output = output_tensor.copy_to_cpu() + else: + batch_output = self.predictor.run( + output_names=[output_names], + input_feed={input_names: image})[0] + if self.benchmark: self.auto_logger.times.stamp() diff --git a/deploy/utils/predictor.py b/deploy/utils/predictor.py index 1c49c9dd5..d44ae03df 100644 --- a/deploy/utils/predictor.py +++ b/deploy/utils/predictor.py @@ -28,8 +28,12 @@ class Predictor(object): if args.use_fp16 is True: assert args.use_tensorrt is True self.args = args - self.paddle_predictor, self.config = self.create_paddle_predictor( - args, inference_model_dir) + if self.args.get("use_onnx", False): + self.predictor, self.config = self.create_onnx_predictor( + args, inference_model_dir) + else: + self.predictor, self.config = self.create_paddle_predictor( + args, inference_model_dir) def predict(self, image): raise NotImplementedError @@ -69,3 +73,20 @@ class Predictor(object): predictor = create_predictor(config) return predictor, config + + def create_onnx_predictor(self, args, inference_model_dir=None): + import onnxruntime as ort + if inference_model_dir is None: + inference_model_dir = args.inference_model_dir + model_file = os.path.join(inference_model_dir, "inference.onnx") + config = ort.SessionOptions() + if args.use_gpu: + raise ValueError( + "onnx inference now only supports cpu! please specify use_gpu false." + ) + else: + config.intra_op_num_threads = args.cpu_num_threads + if args.ir_optim: + config.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + predictor = ort.InferenceSession(model_file, sess_options=config) + return predictor, config From bc52b8009659420df22411fe8c8549331afd008e Mon Sep 17 00:00:00 2001 From: lubin10 Date: Thu, 20 Jan 2022 09:39:25 +0000 Subject: [PATCH 2/8] update readme.md --- deploy/paddle2onnx/readme.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deploy/paddle2onnx/readme.md b/deploy/paddle2onnx/readme.md index 8c357bf74..510d051f5 100644 --- a/deploy/paddle2onnx/readme.md +++ b/deploy/paddle2onnx/readme.md @@ -34,7 +34,7 @@ cd .. - 模型转换 -使用 Paddle2ONNX 将Paddle静态图模型转换为ONNX模型格式: +使用 Paddle2ONNX 将 Paddle 静态图模型转换为 ONNX 模型格式: ``` paddle2onnx --model_dir=./models/ResNet50_vd_infer/ \ --model_filename=inference.pdmodel \ From b8d562e3160e4f750f97edefbc8213236c41c408 Mon Sep 17 00:00:00 2001 From: Bin Lu Date: Thu, 20 Jan 2022 17:40:38 +0800 Subject: [PATCH 3/8] Update readme.md --- deploy/paddle2onnx/readme.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/deploy/paddle2onnx/readme.md b/deploy/paddle2onnx/readme.md index 510d051f5..3420968dc 100644 --- a/deploy/paddle2onnx/readme.md +++ b/deploy/paddle2onnx/readme.md @@ -4,11 +4,7 @@ ## 1. 环境准备 -需要准备 Paddle2ONNX 模型转化环境,和 ONNX 模型预测环境 - -### Paddle2ONNX - -Paddle2ONNX 支持将 PaddlePaddle 模型格式转化到 ONNX 模型格式,算子目前稳定支持导出 ONNX Opset 9~11,部分Paddle算子支持更低的ONNX Opset转换。 +需要准备 Paddle2ONNX 模型转化环境,和 ONNX 模型预测环境。Paddle2ONNX 支持将 PaddlePaddle 模型格式转化到 ONNX 模型格式,算子目前稳定支持导出 ONNX Opset 9~11,部分Paddle算子支持更低的ONNX Opset转换。 更多细节可参考 [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX/blob/develop/README_zh.md) - 安装 Paddle2ONNX From f7516d6a226f0fc74b0b3ea8da88f11102580825 Mon Sep 17 00:00:00 2001 From: Bin Lu Date: Thu, 20 Jan 2022 17:41:03 +0800 Subject: [PATCH 4/8] Update readme.md --- deploy/paddle2onnx/readme.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deploy/paddle2onnx/readme.md b/deploy/paddle2onnx/readme.md index 3420968dc..1bba79148 100644 --- a/deploy/paddle2onnx/readme.md +++ b/deploy/paddle2onnx/readme.md @@ -4,7 +4,8 @@ ## 1. 环境准备 -需要准备 Paddle2ONNX 模型转化环境,和 ONNX 模型预测环境。Paddle2ONNX 支持将 PaddlePaddle 模型格式转化到 ONNX 模型格式,算子目前稳定支持导出 ONNX Opset 9~11,部分Paddle算子支持更低的ONNX Opset转换。 +需要准备 Paddle2ONNX 模型转化环境,和 ONNX 模型预测环境。 +Paddle2ONNX 支持将 PaddlePaddle 模型格式转化到 ONNX 模型格式,算子目前稳定支持导出 ONNX Opset 9~11,部分Paddle算子支持更低的ONNX Opset转换。 更多细节可参考 [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX/blob/develop/README_zh.md) - 安装 Paddle2ONNX From 3f0caad78776f31a8ab8590952d1e64a41a40c43 Mon Sep 17 00:00:00 2001 From: Bin Lu Date: Thu, 20 Jan 2022 17:41:17 +0800 Subject: [PATCH 5/8] Update readme.md --- deploy/paddle2onnx/readme.md | 1 + 1 file changed, 1 insertion(+) diff --git a/deploy/paddle2onnx/readme.md b/deploy/paddle2onnx/readme.md index 1bba79148..5db45ad9c 100644 --- a/deploy/paddle2onnx/readme.md +++ b/deploy/paddle2onnx/readme.md @@ -5,6 +5,7 @@ ## 1. 环境准备 需要准备 Paddle2ONNX 模型转化环境,和 ONNX 模型预测环境。 + Paddle2ONNX 支持将 PaddlePaddle 模型格式转化到 ONNX 模型格式,算子目前稳定支持导出 ONNX Opset 9~11,部分Paddle算子支持更低的ONNX Opset转换。 更多细节可参考 [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX/blob/develop/README_zh.md) From 605f6a2c6b2f54b5cc8633e176aa8dbc43c9efb6 Mon Sep 17 00:00:00 2001 From: lubin10 Date: Thu, 20 Jan 2022 11:23:17 +0000 Subject: [PATCH 6/8] add onnx tipc for ResNet50_vd --- ...al_normal_paddle2onnx_python_linux_cpu.txt | 15 +++++++++++ test_tipc/prepare.sh | 12 +++++++++ test_tipc/test_paddle2onnx.sh | 27 ++++++++++--------- 3 files changed, 42 insertions(+), 12 deletions(-) create mode 100644 test_tipc/config/ResNet/ResNet50_vd_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt diff --git a/test_tipc/config/ResNet/ResNet50_vd_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/config/ResNet/ResNet50_vd_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt new file mode 100644 index 000000000..dba07d176 --- /dev/null +++ b/test_tipc/config/ResNet/ResNet50_vd_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt @@ -0,0 +1,15 @@ +===========================paddle2onnx_params=========================== +model_name:ResNet50_vd +python:python3.7 +2onnx: paddle2onnx +--model_dir:./deploy/models/ResNet50_vd_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--save_file:./deploy/models/ResNet50_vd_infer/inference.onnx +--opset_version:10 +--enable_onnx_checker:True +inference: python/predict_cls.py -c configs/inference_cls.yaml +Global.use_onnx:True +Global.inference_model_dir:models/ResNet50_vd_infer/ +Global.use_gpu:False +Global.infer_imgs:./images/ILSVRC2012_val_00000010.jpeg diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh index 646e7f4d8..89cdd505e 100644 --- a/test_tipc/prepare.sh +++ b/test_tipc/prepare.sh @@ -165,3 +165,15 @@ if [ ${MODE} = "serving_infer" ];then cd ./deploy/paddleserving wget -nc https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet50_vd_infer.tar && tar xf ResNet50_vd_infer.tar fi + +if [ ${MODE} = "paddle2onnx_infer" ];then + # prepare paddle2onnx env + python_name=$(func_parser_value "${lines[2]}") + ${python_name} -m pip install install paddle2onnx + ${python_name} -m pip install onnxruntime + + # wget model + cd deploy && mkdir models && cd models + wget -nc https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet50_vd_infer.tar && tar xf ResNet50_vd_infer.tar + cd ../../ +fi diff --git a/test_tipc/test_paddle2onnx.sh b/test_tipc/test_paddle2onnx.sh index 300c61770..7d8051fb8 100644 --- a/test_tipc/test_paddle2onnx.sh +++ b/test_tipc/test_paddle2onnx.sh @@ -11,7 +11,7 @@ python=$(func_parser_value "${lines[2]}") # parser params -dataline=$(awk 'NR==1, NR==12{print}' $FILENAME) +dataline=$(awk 'NR==1, NR==15{print}' $FILENAME) IFS=$'\n' lines=(${dataline}) @@ -33,12 +33,14 @@ enable_onnx_checker_key=$(func_parser_key "${lines[9]}") enable_onnx_checker_value=$(func_parser_value "${lines[9]}") # parser onnx inference inference_py=$(func_parser_value "${lines[10]}") -use_gpu_key=$(func_parser_key "${lines[11]}") -use_gpu_value=$(func_parser_value "${lines[11]}") -det_model_key=$(func_parser_key "${lines[12]}") -image_dir_key=$(func_parser_key "${lines[13]}") -image_dir_value=$(func_parser_value "${lines[13]}") - +use_onnx_key=$(func_parser_key "${lines[11]}") +use_onnx_value=$(func_parser_value "${lines[11]}") +inference_model_dir_key=$(func_parser_key "${lines[12]}") +inference_model_dir_value=$(func_parser_value "${lines[12]}") +inference_hardware_key=$(func_parser_key "${lines[13]}") +inference_hardware_value=$(func_parser_value "${lines[13]}") +inference_imgs_key=$(func_parser_key "${lines[14]}") +inference_imgs_value=$(func_parser_value "${lines[14]}") LOG_PATH="./test_tipc/output" mkdir -p ./test_tipc/output @@ -50,7 +52,7 @@ function func_paddle2onnx(){ _script=$1 # paddle2onnx - _save_log_path="${LOG_PATH}/paddle2onnx_infer_cpu.log" + _save_log_path=".${LOG_PATH}/paddle2onnx_infer_cpu.log" set_dirname=$(func_set_params "${infer_model_dir_key}" "${infer_model_dir_value}") set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}") set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}") @@ -62,10 +64,11 @@ function func_paddle2onnx(){ last_status=${PIPESTATUS[0]} status_check $last_status "${trans_model_cmd}" "${status_log}" # python inference - set_gpu=$(func_set_params "${use_gpu_key}" "${use_gpu_value}") - set_model_dir=$(func_set_params "${det_model_key}" "${save_file_value}") - set_img_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}") - infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 " + set_model_dir=$(func_set_params "${inference_model_dir_key}" "${inference_model_dir_value}") + set_use_onnx=$(func_set_params "${use_onnx_key}" "${use_onnx_value}") + set_hardware=$(func_set_params "${inference_hardware_key}" "${inference_hardware_value}") + set_infer_imgs=$(func_set_params "${inference_imgs_key}" "${inference_imgs_value}") + infer_model_cmd="cd deploy && ${python} ${inference_py} -o ${set_model_dir} -o ${set_use_onnx} -o ${set_hardware} -o ${set_infer_imgs} >${_save_log_path} 2>&1 && cd ../" eval $infer_model_cmd status_check $last_status "${infer_model_cmd}" "${status_log}" } From 8276dccd7d28352a6d6863166a9d8de1971878c9 Mon Sep 17 00:00:00 2001 From: lubin10 Date: Thu, 20 Jan 2022 11:30:19 +0000 Subject: [PATCH 7/8] rm useless config --- ...inux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt | 1 - test_tipc/test_paddle2onnx.sh | 7 ++----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/test_tipc/config/ResNet/ResNet50_vd_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/config/ResNet/ResNet50_vd_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt index dba07d176..163bb4842 100644 --- a/test_tipc/config/ResNet/ResNet50_vd_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt +++ b/test_tipc/config/ResNet/ResNet50_vd_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt @@ -12,4 +12,3 @@ inference: python/predict_cls.py -c configs/inference_cls.yaml Global.use_onnx:True Global.inference_model_dir:models/ResNet50_vd_infer/ Global.use_gpu:False -Global.infer_imgs:./images/ILSVRC2012_val_00000010.jpeg diff --git a/test_tipc/test_paddle2onnx.sh b/test_tipc/test_paddle2onnx.sh index 7d8051fb8..850fc9049 100644 --- a/test_tipc/test_paddle2onnx.sh +++ b/test_tipc/test_paddle2onnx.sh @@ -11,7 +11,7 @@ python=$(func_parser_value "${lines[2]}") # parser params -dataline=$(awk 'NR==1, NR==15{print}' $FILENAME) +dataline=$(awk 'NR==1, NR==14{print}' $FILENAME) IFS=$'\n' lines=(${dataline}) @@ -39,8 +39,6 @@ inference_model_dir_key=$(func_parser_key "${lines[12]}") inference_model_dir_value=$(func_parser_value "${lines[12]}") inference_hardware_key=$(func_parser_key "${lines[13]}") inference_hardware_value=$(func_parser_value "${lines[13]}") -inference_imgs_key=$(func_parser_key "${lines[14]}") -inference_imgs_value=$(func_parser_value "${lines[14]}") LOG_PATH="./test_tipc/output" mkdir -p ./test_tipc/output @@ -67,8 +65,7 @@ function func_paddle2onnx(){ set_model_dir=$(func_set_params "${inference_model_dir_key}" "${inference_model_dir_value}") set_use_onnx=$(func_set_params "${use_onnx_key}" "${use_onnx_value}") set_hardware=$(func_set_params "${inference_hardware_key}" "${inference_hardware_value}") - set_infer_imgs=$(func_set_params "${inference_imgs_key}" "${inference_imgs_value}") - infer_model_cmd="cd deploy && ${python} ${inference_py} -o ${set_model_dir} -o ${set_use_onnx} -o ${set_hardware} -o ${set_infer_imgs} >${_save_log_path} 2>&1 && cd ../" + infer_model_cmd="cd deploy && ${python} ${inference_py} -o ${set_model_dir} -o ${set_use_onnx} -o ${set_hardware} >${_save_log_path} 2>&1 && cd ../" eval $infer_model_cmd status_check $last_status "${infer_model_cmd}" "${status_log}" } From b2dbbc3ca40f06c5d53277957e9fbc425579e35a Mon Sep 17 00:00:00 2001 From: lubin10 Date: Fri, 21 Jan 2022 05:45:21 +0000 Subject: [PATCH 8/8] fix predict_det.py error --- deploy/python/predict_det.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/deploy/python/predict_det.py b/deploy/python/predict_det.py index 7b23e6220..e4e0a24a6 100644 --- a/deploy/python/predict_det.py +++ b/deploy/python/predict_det.py @@ -109,17 +109,16 @@ class DetPredictor(Predictor): ''' inputs = self.preprocess(image) np_boxes = None - input_names = self.paddle_predictor.get_input_names() + input_names = self.predictor.get_input_names() for i in range(len(input_names)): - input_tensor = self.paddle_predictor.get_input_handle(input_names[ - i]) + input_tensor = self.predictor.get_input_handle(input_names[i]) input_tensor.copy_from_cpu(inputs[input_names[i]]) t1 = time.time() - self.paddle_predictor.run() - output_names = self.paddle_predictor.get_output_names() - boxes_tensor = self.paddle_predictor.get_output_handle(output_names[0]) + self.predictor.run() + output_names = self.predictor.get_output_names() + boxes_tensor = self.predictor.get_output_handle(output_names[0]) np_boxes = boxes_tensor.copy_to_cpu() t2 = time.time()