From 16d70fb74bbeded98c134df0ab69bfc4b29eff31 Mon Sep 17 00:00:00 2001
From: LDOUBLEV
Date: Wed, 26 May 2021 07:08:56 +0000
Subject: [PATCH 1/9] add cpu_math_library_num_threads params

---
 tools/infer/utility.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index b5fe3ba98..41ca445fa 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -21,6 +21,9 @@ import json
 from PIL import Image, ImageDraw, ImageFont
 import math
 from paddle import inference
+import time
+from ppocr.utils.logging import get_logger
+logger = get_logger()
 
 
 def parse_args():
@@ -98,6 +101,7 @@ def parse_args():
     parser.add_argument("--cls_thresh", type=float, default=0.9)
 
     parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
+    parser.add_argument("--cpu_threads", type=int, default=10)
     parser.add_argument("--use_pdserving", type=str2bool, default=False)
 
     parser.add_argument("--use_mp", type=str2bool, default=False)
@@ -140,14 +144,15 @@ def create_predictor(args, mode, logger):
                 max_batch_size=args.max_batch_size)
     else:
         config.disable_gpu()
-        config.set_cpu_math_library_num_threads(6)
+        if hasattr(args, "cpu_threads"):
+            config.set_cpu_math_library_num_threads(args.cpu_threads)
+        else:
+            config.set_cpu_math_library_num_threads(
+                10)  # default cpu threads as 10
         if args.enable_mkldnn:
             # cache 10 different shapes for mkldnn to avoid memory leak
             config.set_mkldnn_cache_capacity(10)
             config.enable_mkldnn()
-            # TODO LDOUBLEV: fix mkldnn bug when bach_size > 1
-            #config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'})
-            args.rec_batch_num = 1
 
     # enable memory optim
     config.enable_memory_optim()

From 9a68a6123a04bda96dcce6eed52fcc98b2ac06f0 Mon Sep 17 00:00:00 2001
From: LDOUBLEV
Date: Wed, 26 May 2021 07:09:52 +0000
Subject: [PATCH 2/9] fix hub server error

---
 tools/infer/utility.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 41ca445fa..f4e62b87a 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -215,7 +215,7 @@ def draw_ocr(image,
              txts=None,
              scores=None,
              drop_score=0.5,
-             font_path="./doc/simfang.ttf"):
+             font_path="./doc/fonts/simfang.ttf"):
     """
     Visualize the results of OCR detection and recognition
     args:

From c946b386fd8bc745bd4b8ac59f30f94067bc5483 Mon Sep 17 00:00:00 2001
From: LDOUBLEV
Date: Wed, 26 May 2021 07:18:40 +0000
Subject: [PATCH 3/9] add trt min max opt shape

---
 tools/infer/utility.py | 80 ++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 77 insertions(+), 3 deletions(-)

diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index f4e62b87a..8ad916bee 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -139,9 +139,83 @@ def create_predictor(args, mode, logger):
         config.enable_use_gpu(args.gpu_mem, 0)
         if args.use_tensorrt:
             config.enable_tensorrt_engine(
-                precision_mode=inference.PrecisionType.Half
-                if args.use_fp16 else inference.PrecisionType.Float32,
-                max_batch_size=args.max_batch_size)
+                precision_mode=inference.PrecisionType.Float32,
+                max_batch_size=args.max_batch_size,
+                min_subgraph_size=3)  # skip the minimum trt subgraph
+            if mode == "det" and "mobile" in model_file_path:
+                min_input_shape = {
+                    "x": [1, 3, 50, 50],
+                    "conv2d_92.tmp_0": [1, 96, 20, 20],
+                    "conv2d_91.tmp_0": [1, 96, 10, 10],
+                    "nearest_interp_v2_1.tmp_0": [1, 96, 10, 10],
+                    "nearest_interp_v2_2.tmp_0": [1, 96, 20, 20],
+                    "nearest_interp_v2_3.tmp_0": [1, 24, 20, 20],
"nearest_interp_v2_4.tmp_0": [1, 24, 20, 20], + "nearest_interp_v2_5.tmp_0": [1, 24, 20, 20], + "elementwise_add_7": [1, 56, 2, 2], + "nearest_interp_v2_0.tmp_0": [1, 96, 2, 2] + } + max_input_shape = { + "x": [1, 3, 2000, 2000], + "conv2d_92.tmp_0": [1, 96, 400, 400], + "conv2d_91.tmp_0": [1, 96, 200, 200], + "nearest_interp_v2_1.tmp_0": [1, 96, 200, 200], + "nearest_interp_v2_2.tmp_0": [1, 96, 400, 400], + "nearest_interp_v2_3.tmp_0": [1, 24, 400, 400], + "nearest_interp_v2_4.tmp_0": [1, 24, 400, 400], + "nearest_interp_v2_5.tmp_0": [1, 24, 400, 400], + "elementwise_add_7": [1, 56, 400, 400], + "nearest_interp_v2_0.tmp_0": [1, 96, 400, 400] + } + opt_input_shape = { + "x": [1, 3, 640, 640], + "conv2d_92.tmp_0": [1, 96, 160, 160], + "conv2d_91.tmp_0": [1, 96, 80, 80], + "nearest_interp_v2_1.tmp_0": [1, 96, 80, 80], + "nearest_interp_v2_2.tmp_0": [1, 96, 160, 160], + "nearest_interp_v2_3.tmp_0": [1, 24, 160, 160], + "nearest_interp_v2_4.tmp_0": [1, 24, 160, 160], + "nearest_interp_v2_5.tmp_0": [1, 24, 160, 160], + "elementwise_add_7": [1, 56, 40, 40], + "nearest_interp_v2_0.tmp_0": [1, 96, 40, 40] + } + if mode == "det" and "server" in model_file_path: + min_input_shape = { + "x": [1, 3, 50, 50], + "conv2d_59.tmp_0": [1, 96, 20, 20], + "nearest_interp_v2_2.tmp_0": [1, 96, 20, 20], + "nearest_interp_v2_3.tmp_0": [1, 24, 20, 20], + "nearest_interp_v2_4.tmp_0": [1, 24, 20, 20], + "nearest_interp_v2_5.tmp_0": [1, 24, 20, 20] + } + max_input_shape = { + "x": [1, 3, 2000, 2000], + "conv2d_59.tmp_0": [1, 96, 400, 400], + "nearest_interp_v2_2.tmp_0": [1, 96, 400, 400], + "nearest_interp_v2_3.tmp_0": [1, 24, 400, 400], + "nearest_interp_v2_4.tmp_0": [1, 24, 400, 400], + "nearest_interp_v2_5.tmp_0": [1, 24, 400, 400] + } + opt_input_shape = { + "x": [1, 3, 640, 640], + "conv2d_59.tmp_0": [1, 96, 160, 160], + "nearest_interp_v2_2.tmp_0": [1, 96, 160, 160], + "nearest_interp_v2_3.tmp_0": [1, 24, 160, 160], + "nearest_interp_v2_4.tmp_0": [1, 24, 160, 160], + "nearest_interp_v2_5.tmp_0": [1, 24, 160, 160] + } + elif mode == "rec": + min_input_shape = {"x": [args.rec_batch_num, 3, 32, 10]} + max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]} + opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]} + elif mode == "cls": + min_input_shape = {"x": [args.rec_batch_num, 3, 48, 10]} + max_input_shape = {"x": [args.rec_batch_num, 3, 48, 2000]} + opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]} + + config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, + opt_input_shape) + else: config.disable_gpu() if hasattr(args, "cpu_threads"): From 0707f743487f8a4707aa7d6dedd78bc3e0d3f0fe Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Wed, 26 May 2021 07:36:44 +0000 Subject: [PATCH 4/9] delete debug code --- tools/infer/utility.py | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 8ad916bee..ff0b60b9c 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -497,22 +497,4 @@ def draw_boxes(image, boxes, scores=None, drop_score=0.5): if __name__ == '__main__': - test_img = "./doc/test_v2" - predict_txt = "./doc/predict.txt" - f = open(predict_txt, 'r') - data = f.readlines() - img_path, anno = data[0].strip().split('\t') - img_name = os.path.basename(img_path) - img_path = os.path.join(test_img, img_name) - image = Image.open(img_path) - - data = json.loads(anno) - boxes, txts, scores = [], [], [] - for dic in data: - boxes.append(dic['points']) - txts.append(dic['transcription']) - 
From 5b3e7a33935bab2bfbca37cc7ada89a14922a28b Mon Sep 17 00:00:00 2001
From: LDOUBLEV
Date: Wed, 26 May 2021 07:56:06 +0000
Subject: [PATCH 5/9] support cpp trt predict

---
 deploy/cpp_infer/src/ocr_det.cpp       | 38 ++++++++++++++++++-
 deploy/cpp_infer/src/ocr_rec.cpp       |  9 +++++
 deploy/cpp_infer/src/preprocess_op.cpp | 51 ++++++++------------------
 3 files changed, 61 insertions(+), 37 deletions(-)
 mode change 100755 => 100644 deploy/cpp_infer/src/preprocess_op.cpp

diff --git a/deploy/cpp_infer/src/ocr_det.cpp b/deploy/cpp_infer/src/ocr_det.cpp
index 9bfee6138..33ad468a3 100644
--- a/deploy/cpp_infer/src/ocr_det.cpp
+++ b/deploy/cpp_infer/src/ocr_det.cpp
@@ -30,6 +30,42 @@ void DBDetector::LoadModel(const std::string &model_dir) {
           this->use_fp16_ ? paddle_infer::Config::Precision::kHalf
                           : paddle_infer::Config::Precision::kFloat32,
           false, false);
+      std::map<std::string, std::vector<int>> min_input_shape = {
+          {"x", {1, 3, 50, 50}},
+          {"conv2d_92.tmp_0", {1, 96, 20, 20}},
+          {"conv2d_91.tmp_0", {1, 96, 10, 10}},
+          {"nearest_interp_v2_1.tmp_0", {1, 96, 10, 10}},
+          {"nearest_interp_v2_2.tmp_0", {1, 96, 20, 20}},
+          {"nearest_interp_v2_3.tmp_0", {1, 24, 20, 20}},
+          {"nearest_interp_v2_4.tmp_0", {1, 24, 20, 20}},
+          {"nearest_interp_v2_5.tmp_0", {1, 24, 20, 20}},
+          {"elementwise_add_7", {1, 56, 2, 2}},
+          {"nearest_interp_v2_0.tmp_0", {1, 96, 2, 2}}};
+      std::map<std::string, std::vector<int>> max_input_shape = {
+          {"x", {1, 3, this->max_side_len_, this->max_side_len_}},
+          {"conv2d_92.tmp_0", {1, 96, 400, 400}},
+          {"conv2d_91.tmp_0", {1, 96, 200, 200}},
+          {"nearest_interp_v2_1.tmp_0", {1, 96, 200, 200}},
+          {"nearest_interp_v2_2.tmp_0", {1, 96, 400, 400}},
+          {"nearest_interp_v2_3.tmp_0", {1, 24, 400, 400}},
+          {"nearest_interp_v2_4.tmp_0", {1, 24, 400, 400}},
+          {"nearest_interp_v2_5.tmp_0", {1, 24, 400, 400}},
+          {"elementwise_add_7", {1, 56, 400, 400}},
+          {"nearest_interp_v2_0.tmp_0", {1, 96, 400, 400}}};
+      std::map<std::string, std::vector<int>> opt_input_shape = {
+          {"x", {1, 3, 640, 640}},
+          {"conv2d_92.tmp_0", {1, 96, 160, 160}},
+          {"conv2d_91.tmp_0", {1, 96, 80, 80}},
+          {"nearest_interp_v2_1.tmp_0", {1, 96, 80, 80}},
+          {"nearest_interp_v2_2.tmp_0", {1, 96, 160, 160}},
+          {"nearest_interp_v2_3.tmp_0", {1, 24, 160, 160}},
+          {"nearest_interp_v2_4.tmp_0", {1, 24, 160, 160}},
+          {"nearest_interp_v2_5.tmp_0", {1, 24, 160, 160}},
+          {"elementwise_add_7", {1, 56, 40, 40}},
+          {"nearest_interp_v2_0.tmp_0", {1, 96, 40, 40}}};
+
+      config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
+                                    opt_input_shape);
     }
   } else {
     config.DisableGpu();
@@ -48,7 +84,7 @@ void DBDetector::LoadModel(const std::string &model_dir) {
   config.SwitchIrOptim(true);
 
   config.EnableMemoryOptim();
-  config.DisableGlogInfo();
+  // config.DisableGlogInfo();
 
   this->predictor_ = CreatePredictor(config);
 }
diff --git a/deploy/cpp_infer/src/ocr_rec.cpp b/deploy/cpp_infer/src/ocr_rec.cpp
index 76873dad3..e3f880336 100644
--- a/deploy/cpp_infer/src/ocr_rec.cpp
+++ b/deploy/cpp_infer/src/ocr_rec.cpp
@@ -105,6 +105,15 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
           this->use_fp16_ ? paddle_infer::Config::Precision::kHalf
                          : paddle_infer::Config::Precision::kFloat32,
           false, false);
+      std::map<std::string, std::vector<int>> min_input_shape = {
+          {"x", {1, 3, 32, 10}}};
+      std::map<std::string, std::vector<int>> max_input_shape = {
+          {"x", {1, 3, 32, 2000}}};
+      std::map<std::string, std::vector<int>> opt_input_shape = {
+          {"x", {1, 3, 32, 320}}};
+
+      config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
+                                    opt_input_shape);
     }
   } else {
     config.DisableGpu();
diff --git a/deploy/cpp_infer/src/preprocess_op.cpp b/deploy/cpp_infer/src/preprocess_op.cpp
old mode 100755
new mode 100644
index fb7590e35..28590e185
--- a/deploy/cpp_infer/src/preprocess_op.cpp
+++ b/deploy/cpp_infer/src/preprocess_op.cpp
@@ -77,19 +77,13 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
 
   int resize_h = int(float(h) * ratio);
   int resize_w = int(float(w) * ratio);
-
+
   resize_h = max(int(round(float(resize_h) / 32) * 32), 32);
   resize_w = max(int(round(float(resize_w) / 32) * 32), 32);
 
-  if (!use_tensorrt) {
-    cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
-    ratio_h = float(resize_h) / float(h);
-    ratio_w = float(resize_w) / float(w);
-  } else {
-    cv::resize(img, resize_img, cv::Size(640, 640));
-    ratio_h = float(640) / float(h);
-    ratio_w = float(640) / float(w);
-  }
+  cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
+  ratio_h = float(resize_h) / float(h);
+  ratio_w = float(resize_w) / float(w);
 }
 void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
     resize_w = imgW;
   else
     resize_w = int(ceilf(imgH * ratio));
-  if (!use_tensorrt) {
-    cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
-               cv::INTER_LINEAR);
-    cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0,
-                       int(imgW - resize_img.cols), cv::BORDER_CONSTANT,
-                       {127, 127, 127});
-  } else {
-    int k = int(img.cols * 32 / img.rows);
-    if (k >= 100) {
-      cv::resize(img, resize_img, cv::Size(100, 32), 0.f, 0.f,
-                 cv::INTER_LINEAR);
-    } else {
-      cv::resize(img, resize_img, cv::Size(k, 32), 0.f, 0.f, cv::INTER_LINEAR);
-      cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, int(100 - k),
-                         cv::BORDER_CONSTANT, {127, 127, 127});
-    }
-  }
+
+  cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
+             cv::INTER_LINEAR);
+  cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0,
+                     int(imgW - resize_img.cols), cv::BORDER_CONSTANT,
+                     {127, 127, 127});
 }
 void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
   else
     resize_w = int(ceilf(imgH * ratio));
 
-  if (!use_tensorrt) {
-    cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
-               cv::INTER_LINEAR);
-    if (resize_w < imgW) {
-      cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w,
-                         cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
-    }
-  } else {
-    cv::resize(img, resize_img, cv::Size(100, 32), 0.f, 0.f, cv::INTER_LINEAR);
+  cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
+             cv::INTER_LINEAR);
+  if (resize_w < imgW) {
+    cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w,
+                       cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
   }
 }

From 3296edd7b72287c772eb8e915eacf77bdd1d18bf Mon Sep 17 00:00:00 2001
From: LDOUBLEV
Date: Wed, 26 May 2021 08:03:13 +0000
Subject: [PATCH 6/9] -DUSE to -DWITH

---
 deploy/cpp_infer/tools/build.sh | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/deploy/cpp_infer/tools/build.sh b/deploy/cpp_infer/tools/build.sh
index 606539487..3e9926701 100755
--- a/deploy/cpp_infer/tools/build.sh
+++ b/deploy/cpp_infer/tools/build.sh
@@ -1,7 +1,8 @@
-OPENCV_DIR=your_opencv_dir
-LIB_DIR=your_paddle_inference_dir
-CUDA_LIB_DIR=your_cuda_lib_dir
-CUDNN_LIB_DIR=your_cudnn_lib_dir
+OPENCV_DIR=/paddle/Paddle/opencv-3.4.7/opencv3
+LIB_DIR=/paddle/OCR/debug/paddle_inference
+CUDA_LIB_DIR=/usr/local/cuda/lib64
+CUDNN_LIB_DIR=/usr/lib/x86_64-linux-gnu
+TENSORRT_DIR=/paddle/Paddle/package/TensorRT/TensorRT-6.0.1.5/
 
 BUILD_DIR=build
 rm -rf ${BUILD_DIR}
@@ -12,7 +13,7 @@ cmake .. \
     -DWITH_MKL=ON \
     -DWITH_GPU=OFF \
     -DWITH_STATIC_LIB=OFF \
-    -DUSE_TENSORRT=OFF \
+    -DWITH_TENSORRT=OFF \
    -DOPENCV_DIR=${OPENCV_DIR} \
    -DCUDNN_LIB=${CUDNN_LIB_DIR} \
    -DCUDA_LIB=${CUDA_LIB_DIR} \

From b4dc3e9876abeab2ca307bb3f5828acd9638856a Mon Sep 17 00:00:00 2001
From: LDOUBLEV
Date: Wed, 26 May 2021 08:04:12 +0000
Subject: [PATCH 7/9] compile path retrieval

---
 deploy/cpp_infer/tools/build.sh | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/deploy/cpp_infer/tools/build.sh b/deploy/cpp_infer/tools/build.sh
index 3e9926701..b767ca7fd 100755
--- a/deploy/cpp_infer/tools/build.sh
+++ b/deploy/cpp_infer/tools/build.sh
@@ -1,8 +1,7 @@
-OPENCV_DIR=/paddle/Paddle/opencv-3.4.7/opencv3
-LIB_DIR=/paddle/OCR/debug/paddle_inference
-CUDA_LIB_DIR=/usr/local/cuda/lib64
-CUDNN_LIB_DIR=/usr/lib/x86_64-linux-gnu
-TENSORRT_DIR=/paddle/Paddle/package/TensorRT/TensorRT-6.0.1.5/
+OPENCV_DIR=your_opencv_dir
+LIB_DIR=your_paddle_inference_dir
+CUDA_LIB_DIR=your_cuda_lib_dir
+CUDNN_LIB_DIR=your_cudnn_lib_dir
 
 BUILD_DIR=build
 rm -rf ${BUILD_DIR}

From 0fd13fa41046d30d75c6f8a6dc9b542f9cee582f Mon Sep 17 00:00:00 2001
From: LDOUBLEV
Date: Wed, 26 May 2021 10:54:54 +0000
Subject: [PATCH 8/9] process else

---
 tools/infer/utility.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index ff0b60b9c..de3bb3db9 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -212,7 +212,10 @@ def create_predictor(args, mode, logger):
                 min_input_shape = {"x": [args.rec_batch_num, 3, 48, 10]}
                 max_input_shape = {"x": [args.rec_batch_num, 3, 48, 2000]}
                 opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]}
-
+            else:
+                min_input_shape = {"x": [1, 3, 10, 10]}
+                max_input_shape = {"x": [1, 3, 1000, 1000]}
+                opt_input_shape = {"x": [1, 3, 500, 500]}
             config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
                                               opt_input_shape)
 

From c8cc0fb4b77ec75c068c59d67956ec04cc8059f4 Mon Sep 17 00:00:00 2001
From: LDOUBLEV
Date: Wed, 26 May 2021 10:55:53 +0000
Subject: [PATCH 9/9] pre-commit

---
 tools/infer/utility.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index de3bb3db9..ff4a0276e 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -224,8 +224,8 @@ def create_predictor(args, mode, logger):
         if hasattr(args, "cpu_threads"):
             config.set_cpu_math_library_num_threads(args.cpu_threads)
         else:
-            config.set_cpu_math_library_num_threads(
-                10)  # default cpu threads as 10
+            # default cpu threads as 10
+            config.set_cpu_math_library_num_threads(10)
         if args.enable_mkldnn:
             # cache 10 different shapes for mkldnn to avoid memory leak
             config.set_mkldnn_cache_capacity(10)
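
With the series applied, an inference script picks the new options up through parse_args() alone; nothing else in the calling code has to change. A minimal sketch of the intended call pattern follows. It is illustrative only: the command line is an example, and since these diffs never show what create_predictor() returns, the three-value unpacking below is an assumption.

    # assumed to be run from the PaddleOCR repository root, for example via:
    #   python3 tools/infer/predict_det.py --image_dir=./doc/imgs/ \
    #       --det_model_dir=./inference/det/ --use_gpu=False --cpu_threads=4
    import tools.infer.utility as utility
    from ppocr.utils.logging import get_logger

    args = utility.parse_args()   # now exposes --cpu_threads (default 10)
    logger = get_logger()

    # For detection, create_predictor() picks its TensorRT shape table by whether
    # the model path contains "mobile" or "server"; "rec"/"cls" key on rec_batch_num.
    # The return-value unpacking here is assumed, not shown in the diffs above.
    predictor, input_tensor, output_tensors = utility.create_predictor(args, "det",
                                                                       logger)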