From ec96e3c7826a54be579be2fdbe643f2c5746ffa0 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Thu, 9 Jun 2022 08:20:58 +0800 Subject: [PATCH] fix attr pred (#2002) --- deploy/python/postprocess.py | 3 +-- deploy/python/predict_cls.py | 5 +---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/deploy/python/postprocess.py b/deploy/python/postprocess.py index 9fe15bea8..a7e7f7b2e 100644 --- a/deploy/python/postprocess.py +++ b/deploy/python/postprocess.py @@ -319,6 +319,5 @@ class VehicleAttribute(object): ] * 10 + [self.type_threshold] * 9 pred_res = (np.array(res) > np.array(threshold_list) ).astype(np.int8).tolist() - - batch_res.append([label_res, pred_res]) + batch_res.append({"attributes": label_res, "output": pred_res}) return batch_res diff --git a/deploy/python/predict_cls.py b/deploy/python/predict_cls.py index 440624e0c..90e14bcb3 100644 --- a/deploy/python/predict_cls.py +++ b/deploy/python/predict_cls.py @@ -142,10 +142,7 @@ def main(config): "PostProcess"] or "VehicleAttribute" in config[ "PostProcess"]: filename = batch_names[number] - attr_message = result_dict[0] - pred_res = result_dict[1] - print("{}:\t attributes: {}, \npredict output: {}".format( - filename, attr_message, pred_res)) + print("{}:\t {}".format(filename, result_dict)) else: filename = batch_names[number] clas_ids = result_dict["class_ids"]