From 37e177595bf76b5e41a8edb6f7734d0d0fbd66d0 Mon Sep 17 00:00:00 2001 From: Salmondx <9963444+Salmondx@users.noreply.github.com> Date: Fri, 25 Oct 2024 06:21:39 -0600 Subject: [PATCH] Allow `create_predictor` function to accept array of ONNX Execution Providers (#14078) * pass onnx execution providers to create_predictor function * added ability to provide onnxruntime SessionOptions * added argument parser for onnx_sess_options --------- Co-authored-by: ggolda --- tools/infer/utility.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index c40fc6b1b..57b585601 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -147,6 +147,8 @@ def init_args(): parser.add_argument("--show_log", type=str2bool, default=True) parser.add_argument("--use_onnx", type=str2bool, default=False) + parser.add_argument("--onnx_providers", nargs="+", type=str, default=False) + parser.add_argument("--onnx_sess_options", type=list, default=False) # extended function parser.add_argument( @@ -193,7 +195,16 @@ def create_predictor(args, mode, logger): model_file_path = model_dir if not os.path.exists(model_file_path): raise ValueError("not find model file path {}".format(model_file_path)) - if args.use_gpu: + + sess_options = args.onnx_sess_options or None + + if args.onnx_providers and len(args.onnx_providers) > 0: + sess = ort.InferenceSession( + model_file_path, + providers=args.onnx_providers, + sess_options=sess_options, + ) + elif args.use_gpu: sess = ort.InferenceSession( model_file_path, providers=[ @@ -202,10 +213,13 @@ def create_predictor(args, mode, logger): {"device_id": args.gpu_id, "cudnn_conv_algo_search": "DEFAULT"}, ) ], + sess_options=sess_options, ) else: sess = ort.InferenceSession( - model_file_path, providers=["CPUExecutionProvider"] + model_file_path, + providers=["CPUExecutionProvider"], + sess_options=sess_options, ) return sess, sess.get_inputs()[0], None, None