mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
Merge pull request #6824 from ChenNima/release/2.5-kie-save-res
[kie]add write_kie_result to kie infer tool
This commit is contained in:
parent
b7d99acd2e
commit
d15fca53d2
@ -88,6 +88,29 @@ def draw_kie_result(batch, node, idx_to_cls, count):
|
|||||||
cv2.imwrite(save_path, vis_img)
|
cv2.imwrite(save_path, vis_img)
|
||||||
logger.info("The Kie Image saved in {}".format(save_path))
|
logger.info("The Kie Image saved in {}".format(save_path))
|
||||||
|
|
||||||
|
def write_kie_result(fout, node, data):
|
||||||
|
"""
|
||||||
|
Write infer result to output file, sorted by the predict label of each line.
|
||||||
|
The format keeps the same as the input with additional score attribute.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
label = data['label']
|
||||||
|
annotations = json.loads(label)
|
||||||
|
max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1)
|
||||||
|
node_pred_label = max_idx.numpy().tolist()
|
||||||
|
node_pred_score = max_value.numpy().tolist()
|
||||||
|
res = []
|
||||||
|
for i, label in enumerate(node_pred_label):
|
||||||
|
pred_score = '{:.2f}'.format(node_pred_score[i])
|
||||||
|
pred_res = {
|
||||||
|
'label': label,
|
||||||
|
'transcription': annotations[i]['transcription'],
|
||||||
|
'score': pred_score,
|
||||||
|
'points': annotations[i]['points'],
|
||||||
|
}
|
||||||
|
res.append(pred_res)
|
||||||
|
res.sort(key=lambda x: x['label'])
|
||||||
|
fout.writelines([json.dumps(res, ensure_ascii=False) + '\n'])
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
global_config = config['Global']
|
global_config = config['Global']
|
||||||
@ -114,7 +137,7 @@ def main():
|
|||||||
|
|
||||||
warmup_times = 0
|
warmup_times = 0
|
||||||
count_t = []
|
count_t = []
|
||||||
with open(save_res_path, "wb") as fout:
|
with open(save_res_path, "w") as fout:
|
||||||
with open(config['Global']['infer_img'], "rb") as f:
|
with open(config['Global']['infer_img'], "rb") as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
for index, data_line in enumerate(lines):
|
for index, data_line in enumerate(lines):
|
||||||
@ -139,6 +162,8 @@ def main():
|
|||||||
node = F.softmax(node, -1)
|
node = F.softmax(node, -1)
|
||||||
count_t.append(time.time() - st)
|
count_t.append(time.time() - st)
|
||||||
draw_kie_result(batch, node, idx_to_cls, index)
|
draw_kie_result(batch, node, idx_to_cls, index)
|
||||||
|
write_kie_result(fout, node, data)
|
||||||
|
fout.close()
|
||||||
logger.info("success!")
|
logger.info("success!")
|
||||||
logger.info("It took {} s for predict {} images.".format(
|
logger.info("It took {} s for predict {} images.".format(
|
||||||
np.sum(count_t), len(count_t)))
|
np.sum(count_t), len(count_t)))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user