mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
fix attr pred (#2002)
This commit is contained in:
parent
b457c393eb
commit
ec96e3c782
@ -319,6 +319,5 @@ class VehicleAttribute(object):
|
|||||||
] * 10 + [self.type_threshold] * 9
|
] * 10 + [self.type_threshold] * 9
|
||||||
pred_res = (np.array(res) > np.array(threshold_list)
|
pred_res = (np.array(res) > np.array(threshold_list)
|
||||||
).astype(np.int8).tolist()
|
).astype(np.int8).tolist()
|
||||||
|
batch_res.append({"attributes": label_res, "output": pred_res})
|
||||||
batch_res.append([label_res, pred_res])
|
|
||||||
return batch_res
|
return batch_res
|
||||||
|
@ -142,10 +142,7 @@ def main(config):
|
|||||||
"PostProcess"] or "VehicleAttribute" in config[
|
"PostProcess"] or "VehicleAttribute" in config[
|
||||||
"PostProcess"]:
|
"PostProcess"]:
|
||||||
filename = batch_names[number]
|
filename = batch_names[number]
|
||||||
attr_message = result_dict[0]
|
print("{}:\t {}".format(filename, result_dict))
|
||||||
pred_res = result_dict[1]
|
|
||||||
print("{}:\t attributes: {}, \npredict output: {}".format(
|
|
||||||
filename, attr_message, pred_res))
|
|
||||||
else:
|
else:
|
||||||
filename = batch_names[number]
|
filename = batch_names[number]
|
||||||
clas_ids = result_dict["class_ids"]
|
clas_ids = result_dict["class_ids"]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user