add more details

pull/1960/head
zhiboniu 2022-05-26 07:14:10 +00:00
parent a4e1da6610
commit 939a35d605
4 changed files with 25 additions and 9 deletions

View File

@ -27,7 +27,7 @@ PreProcess:
PostProcess:
main_indicator: Attribute
Attribute:
threshold: 0.5
glasses_threshold: 0.3
hold_threshold: 0.6
threshold: 0.5 #default threshold
glasses_threshold: 0.3 #threshold only for glasses
hold_threshold: 0.6 #threshold only for hold

View File

@ -64,9 +64,17 @@ class ThreshOutput(object):
for idx, probs in enumerate(x):
score = probs[1]
if score < self.threshold:
result = {"class_ids": [0], "scores": [1 - score], "label_names": [self.label_0]}
result = {
"class_ids": [0],
"scores": [1 - score],
"label_names": [self.label_0]
}
else:
result = {"class_ids": [1], "scores": [score], "label_names": [self.label_1]}
result = {
"class_ids": [1],
"scores": [score],
"label_names": [self.label_1]
}
if file_names is not None:
result["file_name"] = file_names[idx]
y.append(result)
@ -264,5 +272,11 @@ class Attribute(object):
shoe = 'Boots' if res[14] > self.threshold else 'No boots'
label_res.append(shoe)
batch_res.append(label_res)
threshold_list = [0.5] * len(res)
threshold_list[1] = self.glasses_threshold
threshold_list[18] = self.hold_threshold
pred_res = (np.array(res) > np.array(threshold_list)
).astype(np.int8).tolist()
batch_res.append([label_res, pred_res])
return batch_res

View File

@ -140,9 +140,10 @@ def main(config):
for number, result_dict in enumerate(batch_results):
if "Attribute" in config["PostProcess"]:
filename = batch_names[number]
attr_message = result_dict
print("{}:\tclass id(s): {}".format(filename,
attr_message))
attr_message = result_dict[0]
pred_res = result_dict[1]
print("{}:\t attributes: {}, \npredict output: {}".format(
filename, attr_message, pred_res))
else:
filename = batch_names[number]
clas_ids = result_dict["class_ids"]

View File

@ -391,6 +391,7 @@ class AccuracyScore(MultiLabelMetric):
def get_attr_metrics(gt_label, preds_probs, threshold):
"""
index: evaluated label index
adapted from "https://github.com/valencebond/Rethinking_of_PAR/blob/master/metrics/pedestrian_metrics.py"
"""
pred_label = (preds_probs > threshold).astype(int)