add thresoutput

pull/1925/head
weisy11 2022-05-19 20:50:06 +08:00
parent 70c45dcdfd
commit 2abbb70441
2 changed files with 25 additions and 4 deletions

View File

@ -27,9 +27,10 @@ PreProcess:
- ToCHWImage:
PostProcess:
main_indicator: Topk
Topk:
topk: 5
class_id_map_file: "../ppcls/utils/cls_demo/person_label_list.txt"
main_indicator: ThreshOutput
ThreshOutput:
threshold: 0.9
label_0: invalid
label_1: valid
SavePreLabel:
save_dir: ./pre_label/

View File

@ -53,6 +53,26 @@ class PostProcesser(object):
return rtn
class ThreshOutput(object):
def __init__(self, threshold, label_0="0", label_1="1"):
self.threshold = threshold
self.label_0 = label_0
self.label_1 = label_1
def __call__(self, x, file_names=None):
y = []
for idx, probs in enumerate(x):
score = probs[1]
if score < self.threshold:
result = {"class_ids": [0], "scores": [1 - score], "label_names": [self.label_0]}
else:
result = {"class_ids": [1], "scores": [score], "label_names": [self.label_1]}
if file_names is not None:
result["file_name"] = file_names[idx]
y.append(result)
return y
class Topk(object):
def __init__(self, topk=1, class_id_map_file=None):
assert isinstance(topk, (int, ))