From bb54e9b8ebb573822cddd684e435411a55b48599 Mon Sep 17 00:00:00 2001
From: LDOUBLEV <liuvv0203@outlook.com>
Date: Mon, 19 Sep 2022 14:03:15 +0800
Subject: [PATCH] fix

---
 tools/infer/utility.py | 26 ++++++++++++++------------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 793ff28a2..a8c59fac6 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -227,19 +227,21 @@ def create_predictor(args, mode, logger):
                     use_calib_mode=False)
 
                 # collect shape
-                model_name = os.path.basename(model_dir[:-1]) if model_dir.endswith("/") else os.path.basename(model_dir)
+                model_name = os.path.basename(
+                    model_dir[:-1]) if model_dir.endswith(
+                        "/") else os.path.basename(model_dir)
                 trt_shape_f = f"{mode}_{model_name}.txt"
-                if trt_shape_f is not None:
-                    if not os.path.exists(trt_shape_f):
-                        config.collect_shape_range_info(trt_shape_f)
-                        logger.info(
-                            f"collect dynamic shape info into : {trt_shape_f}"
-                        )
-                    else:
-                        logger.info(
-                            f"dynamic shape info file( {trt_shape_f} ) already exists, not need to generate again."
-                        )
-                    config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f, True)
+
+                if not os.path.exists(trt_shape_f):
+                    config.collect_shape_range_info(trt_shape_f)
+                    logger.info(
+                        f"collect dynamic shape info into : {trt_shape_f}")
+                try:
+                    config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f,
+                                                               True)
+                except Exception as E:
+                    logger.info(E)
+                    logger.info("Please keep your paddlepaddle-gpu >= 2.3.0!")
 
         elif args.use_npu:
             config.enable_npu()