add thresoutput
parent
70c45dcdfd
commit
2abbb70441
|
@ -27,9 +27,10 @@ PreProcess:
|
|||
- ToCHWImage:
|
||||
|
||||
PostProcess:
|
||||
main_indicator: Topk
|
||||
Topk:
|
||||
topk: 5
|
||||
class_id_map_file: "../ppcls/utils/cls_demo/person_label_list.txt"
|
||||
main_indicator: ThreshOutput
|
||||
ThreshOutput:
|
||||
threshold: 0.9
|
||||
label_0: invalid
|
||||
label_1: valid
|
||||
SavePreLabel:
|
||||
save_dir: ./pre_label/
|
||||
|
|
|
@ -53,6 +53,26 @@ class PostProcesser(object):
|
|||
return rtn
|
||||
|
||||
|
||||
class ThreshOutput(object):
|
||||
def __init__(self, threshold, label_0="0", label_1="1"):
|
||||
self.threshold = threshold
|
||||
self.label_0 = label_0
|
||||
self.label_1 = label_1
|
||||
|
||||
def __call__(self, x, file_names=None):
|
||||
y = []
|
||||
for idx, probs in enumerate(x):
|
||||
score = probs[1]
|
||||
if score < self.threshold:
|
||||
result = {"class_ids": [0], "scores": [1 - score], "label_names": [self.label_0]}
|
||||
else:
|
||||
result = {"class_ids": [1], "scores": [score], "label_names": [self.label_1]}
|
||||
if file_names is not None:
|
||||
result["file_name"] = file_names[idx]
|
||||
y.append(result)
|
||||
return y
|
||||
|
||||
|
||||
class Topk(object):
|
||||
def __init__(self, topk=1, class_id_map_file=None):
|
||||
assert isinstance(topk, (int, ))
|
||||
|
|
Loading…
Reference in New Issue