add more details
parent
a4e1da6610
commit
939a35d605
|
@ -27,7 +27,7 @@ PreProcess:
|
|||
PostProcess:
|
||||
main_indicator: Attribute
|
||||
Attribute:
|
||||
threshold: 0.5
|
||||
glasses_threshold: 0.3
|
||||
hold_threshold: 0.6
|
||||
threshold: 0.5 #default threshold
|
||||
glasses_threshold: 0.3 #threshold only for glasses
|
||||
hold_threshold: 0.6 #threshold only for hold
|
||||
|
|
@ -64,9 +64,17 @@ class ThreshOutput(object):
|
|||
for idx, probs in enumerate(x):
|
||||
score = probs[1]
|
||||
if score < self.threshold:
|
||||
result = {"class_ids": [0], "scores": [1 - score], "label_names": [self.label_0]}
|
||||
result = {
|
||||
"class_ids": [0],
|
||||
"scores": [1 - score],
|
||||
"label_names": [self.label_0]
|
||||
}
|
||||
else:
|
||||
result = {"class_ids": [1], "scores": [score], "label_names": [self.label_1]}
|
||||
result = {
|
||||
"class_ids": [1],
|
||||
"scores": [score],
|
||||
"label_names": [self.label_1]
|
||||
}
|
||||
if file_names is not None:
|
||||
result["file_name"] = file_names[idx]
|
||||
y.append(result)
|
||||
|
@ -264,5 +272,11 @@ class Attribute(object):
|
|||
shoe = 'Boots' if res[14] > self.threshold else 'No boots'
|
||||
label_res.append(shoe)
|
||||
|
||||
batch_res.append(label_res)
|
||||
threshold_list = [0.5] * len(res)
|
||||
threshold_list[1] = self.glasses_threshold
|
||||
threshold_list[18] = self.hold_threshold
|
||||
pred_res = (np.array(res) > np.array(threshold_list)
|
||||
).astype(np.int8).tolist()
|
||||
|
||||
batch_res.append([label_res, pred_res])
|
||||
return batch_res
|
||||
|
|
|
@ -140,9 +140,10 @@ def main(config):
|
|||
for number, result_dict in enumerate(batch_results):
|
||||
if "Attribute" in config["PostProcess"]:
|
||||
filename = batch_names[number]
|
||||
attr_message = result_dict
|
||||
print("{}:\tclass id(s): {}".format(filename,
|
||||
attr_message))
|
||||
attr_message = result_dict[0]
|
||||
pred_res = result_dict[1]
|
||||
print("{}:\t attributes: {}, \npredict output: {}".format(
|
||||
filename, attr_message, pred_res))
|
||||
else:
|
||||
filename = batch_names[number]
|
||||
clas_ids = result_dict["class_ids"]
|
||||
|
|
|
@ -391,6 +391,7 @@ class AccuracyScore(MultiLabelMetric):
|
|||
def get_attr_metrics(gt_label, preds_probs, threshold):
|
||||
"""
|
||||
index: evaluated label index
|
||||
adapted from "https://github.com/valencebond/Rethinking_of_PAR/blob/master/metrics/pedestrian_metrics.py"
|
||||
"""
|
||||
pred_label = (preds_probs > threshold).astype(int)
|
||||
|
||||
|
|
Loading…
Reference in New Issue