add write_kie_result to kie infer tool
parent
3293645c1a
commit
514d47f3ae
|
@ -89,6 +89,29 @@ def draw_kie_result(batch, node, idx_to_cls, count):
|
|||
cv2.imwrite(save_path, vis_img)
|
||||
logger.info("The Kie Image saved in {}".format(save_path))
|
||||
|
||||
def write_kie_result(fout, node, data):
|
||||
"""
|
||||
Write infer result to output file, sorted by the predict label of each line.
|
||||
The format keeps the same as the input with additional score attribute.
|
||||
"""
|
||||
import json
|
||||
label = data['label']
|
||||
annotations = json.loads(label)
|
||||
max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1)
|
||||
node_pred_label = max_idx.numpy().tolist()
|
||||
node_pred_score = max_value.numpy().tolist()
|
||||
res = []
|
||||
for i, label in enumerate(node_pred_label):
|
||||
pred_score = '{:.2f}'.format(node_pred_score[i])
|
||||
pred_res = {
|
||||
'label': label,
|
||||
'transcription': annotations[i]['transcription'],
|
||||
'score': pred_score,
|
||||
'points': annotations[i]['points'],
|
||||
}
|
||||
res.append(pred_res)
|
||||
res.sort(key=lambda x: x['label'])
|
||||
fout.writelines([json.dumps(res, ensure_ascii=False) + '\n'])
|
||||
|
||||
def main():
|
||||
global_config = config['Global']
|
||||
|
@ -116,7 +139,7 @@ def main():
|
|||
|
||||
warmup_times = 0
|
||||
count_t = []
|
||||
with open(save_res_path, "wb") as fout:
|
||||
with open(save_res_path, "w") as fout:
|
||||
with open(config['Global']['infer_img'], "rb") as f:
|
||||
lines = f.readlines()
|
||||
for index, data_line in enumerate(lines):
|
||||
|
@ -141,6 +164,8 @@ def main():
|
|||
node = F.softmax(node, -1)
|
||||
count_t.append(time.time() - st)
|
||||
draw_kie_result(batch, node, idx_to_cls, index)
|
||||
write_kie_result(fout, node, data)
|
||||
fout.close()
|
||||
logger.info("success!")
|
||||
logger.info("It took {} s for predict {} images.".format(
|
||||
np.sum(count_t), len(count_t)))
|
||||
|
|
Loading…
Reference in New Issue