diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index b5fe3ba98..3f81e95d9 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -98,6 +98,7 @@ def parse_args():
     parser.add_argument("--cls_thresh", type=float, default=0.9)
     parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
+    parser.add_argument("--cpu_threads", type=int, default=10)
     parser.add_argument("--use_pdserving", type=str2bool, default=False)
     parser.add_argument("--use_mp", type=str2bool, default=False)
@@ -140,14 +141,12 @@ def create_predictor(args, mode, logger):
                 max_batch_size=args.max_batch_size)
     else:
         config.disable_gpu()
-        config.set_cpu_math_library_num_threads(6)
+        cpu_threads = args.cpu_threads if hasattr(args, "cpu_threads") else 10
+        config.set_cpu_math_library_num_threads(cpu_threads)
         if args.enable_mkldnn:
             # cache 10 different shapes for mkldnn to avoid memory leak
             config.set_mkldnn_cache_capacity(10)
             config.enable_mkldnn()
-            # TODO LDOUBLEV: fix mkldnn bug when bach_size > 1
-            #config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'})
-            args.rec_batch_num = 1
 
     # enable memory optim
     config.enable_memory_optim()
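
For reference, a minimal sketch of how the new flag is consumed on a CPU-only run. The entry-point wiring, the mode value "det", and the logger setup are assumptions for illustration; only parse_args, create_predictor, and --cpu_threads come from this patch.

import logging

from tools.infer.utility import create_predictor, parse_args

logger = logging.getLogger("ppocr")

# e.g. invoked as:
#   python your_script.py --use_gpu False --enable_mkldnn True --cpu_threads 8
args = parse_args()

# On CPU-only runs, create_predictor now reads args.cpu_threads (default 10)
# instead of the previously hard-coded 6 math-library threads.
predictor_bundle = create_predictor(args, mode="det", logger=logger)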