add write_kie_result to kie infer tool

pull/6824/head
Felix Chen 2022-07-07 17:49:48 +08:00
parent 3293645c1a
commit 514d47f3ae
1 changed files with 26 additions and 1 deletions

View File

@ -89,6 +89,29 @@ def draw_kie_result(batch, node, idx_to_cls, count):
cv2.imwrite(save_path, vis_img)
logger.info("The Kie Image saved in {}".format(save_path))
def write_kie_result(fout, node, data):
"""
Write infer result to output file, sorted by the predict label of each line.
The format keeps the same as the input with additional score attribute.
"""
import json
label = data['label']
annotations = json.loads(label)
max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1)
node_pred_label = max_idx.numpy().tolist()
node_pred_score = max_value.numpy().tolist()
res = []
for i, label in enumerate(node_pred_label):
pred_score = '{:.2f}'.format(node_pred_score[i])
pred_res = {
'label': label,
'transcription': annotations[i]['transcription'],
'score': pred_score,
'points': annotations[i]['points'],
}
res.append(pred_res)
res.sort(key=lambda x: x['label'])
fout.writelines([json.dumps(res, ensure_ascii=False) + '\n'])
def main():
global_config = config['Global']
@ -116,7 +139,7 @@ def main():
warmup_times = 0
count_t = []
with open(save_res_path, "wb") as fout:
with open(save_res_path, "w") as fout:
with open(config['Global']['infer_img'], "rb") as f:
lines = f.readlines()
for index, data_line in enumerate(lines):
@ -141,6 +164,8 @@ def main():
node = F.softmax(node, -1)
count_t.append(time.time() - st)
draw_kie_result(batch, node, idx_to_cls, index)
write_kie_result(fout, node, data)
fout.close()
logger.info("success!")
logger.info("It took {} s for predict {} images.".format(
np.sum(count_t), len(count_t)))