diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 8ada2b3b7..b6a770637 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -30,15 +30,18 @@ from ppocr.utils.logging import get_logger def str2bool(v): return v.lower() in ("true", "yes", "t", "y", "1") + def str2int_tuple(v): return tuple([int(i.strip()) for i in v.split(",")]) + def init_args(): parser = argparse.ArgumentParser() # params for prediction engine parser.add_argument("--use_gpu", type=str2bool, default=True) parser.add_argument("--use_xpu", type=str2bool, default=False) parser.add_argument("--use_npu", type=str2bool, default=False) + parser.add_argument("--use_mlu", type=str2bool, default=False) parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--min_subgraph_size", type=int, default=15) @@ -249,6 +252,8 @@ def create_predictor(args, mode, logger): elif args.use_npu: config.enable_custom_device("npu") + elif args.use_mlu: + config.enable_custom_device("mlu") elif args.use_xpu: config.enable_xpu(10 * 1024 * 1024) else: