diff --git a/tools/infer/infer.py b/tools/infer/infer.py index 9a3b8fff9..b06a51bd5 100644 --- a/tools/infer/infer.py +++ b/tools/infer/infer.py @@ -105,6 +105,7 @@ def main(): exe, program, feed_names, fetch_names = create_predictor(args) data = preprocess(args.image_file, operators) + data = np.expand_dims(data, axis=0) outputs = exe.run(program, feed={feed_names[0]: data}, fetch_list=fetch_names,