From 6d5f998fe1d5e2ffa71540500f3e1d318c9af4d4 Mon Sep 17 00:00:00 2001
From: CaiRan <15830921501@163.com>
Date: Fri, 21 Jun 2024 17:20:40 +0800
Subject: [PATCH] solve the onnxruntime inference issue (#13154)

---
 tools/infer/utility.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 4a734683d..6f6b8c5dd 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -197,10 +197,18 @@ def create_predictor(args, mode, logger):
             raise ValueError("not find model file path {}".format(model_file_path))
         if args.use_gpu:
             sess = ort.InferenceSession(
-                model_file_path, providers=["CUDAExecutionProvider"]
+                model_file_path,
+                providers=[
+                    (
+                        "CUDAExecutionProvider",
+                        {"device_id": args.gpu_id, "cudnn_conv_algo_search": "DEFAULT"},
+                    )
+                ],
             )
         else:
-            sess = ort.InferenceSession(model_file_path)
+            sess = ort.InferenceSession(
+                model_file_path, providers=["CPUExecutionProvider"]
+            )
         return sess, sess.get_inputs()[0], None, None

     else:
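A minimal standalone sketch of the provider configuration this patch adopts,
using the (name, options) tuple form of onnxruntime's providers argument. The
model path and device_id below are illustrative placeholders, not values from
the patch:

import onnxruntime as ort

model_file_path = "inference/det_onnx/model.onnx"  # hypothetical path

if "CUDAExecutionProvider" in ort.get_available_providers():
    providers = [
        (
            "CUDAExecutionProvider",
            # "DEFAULT" picks cuDNN's default conv algorithm instead of the
            # exhaustive per-shape search, avoiding long warm-up at session
            # creation. device_id=0 is an assumption for this sketch.
            {"device_id": 0, "cudnn_conv_algo_search": "DEFAULT"},
        )
    ]
else:
    # Name the CPU provider explicitly; recent onnxruntime releases require
    # an explicit providers list when constructing a session.
    providers = ["CPUExecutionProvider"]

sess = ort.InferenceSession(model_file_path, providers=providers)
print([inp.name for inp in sess.get_inputs()])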