CV套件建设专项活动 - 文字识别返回单字识别坐标 (#10515)
* modification of return word box * update_implements * Update rec_postprocess.py * Update utility.pypull/10560/head
parent
2b7b9dc2cf
commit
1e11f25409
|
@ -67,7 +67,66 @@ class BaseRecLabelDecode(object):
|
||||||
def add_special_char(self, dict_character):
|
def add_special_char(self, dict_character):
|
||||||
return dict_character
|
return dict_character
|
||||||
|
|
||||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
def get_word_info(self, text, selection):
|
||||||
|
"""
|
||||||
|
Group the decoded characters and record the corresponding decoded positions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: the decoded text
|
||||||
|
selection: the bool array that identifies which columns of features are decoded as non-separated characters
|
||||||
|
Returns:
|
||||||
|
word_list: list of the grouped words
|
||||||
|
word_col_list: list of decoding positions corresponding to each character in the grouped word
|
||||||
|
state_list: list of marker to identify the type of grouping words, including two types of grouping words:
|
||||||
|
- 'cn': continous chinese characters (e.g., 你好啊)
|
||||||
|
- 'en&num': continous english characters (e.g., hello), number (e.g., 123, 1.123), or mixed of them connected by '-' (e.g., VGG-16)
|
||||||
|
The remaining characters in text are treated as separators between groups (e.g., space, '(', ')', etc.).
|
||||||
|
"""
|
||||||
|
state = None
|
||||||
|
word_content = []
|
||||||
|
word_col_content = []
|
||||||
|
word_list = []
|
||||||
|
word_col_list = []
|
||||||
|
state_list = []
|
||||||
|
valid_col = np.where(selection==True)[0]
|
||||||
|
|
||||||
|
for c_i, char in enumerate(text):
|
||||||
|
if '\u4e00' <= char <= '\u9fff':
|
||||||
|
c_state = 'cn'
|
||||||
|
elif bool(re.search('[a-zA-Z0-9]', char)):
|
||||||
|
c_state = 'en&num'
|
||||||
|
else:
|
||||||
|
c_state = 'splitter'
|
||||||
|
|
||||||
|
if char == '.' and state == 'en&num' and c_i + 1 < len(text) and bool(re.search('[0-9]', text[c_i+1])): # grouping floting number
|
||||||
|
c_state = 'en&num'
|
||||||
|
if char == '-' and state == "en&num": # grouping word with '-', such as 'state-of-the-art'
|
||||||
|
c_state = 'en&num'
|
||||||
|
|
||||||
|
if state == None:
|
||||||
|
state = c_state
|
||||||
|
|
||||||
|
if state != c_state:
|
||||||
|
if len(word_content) != 0:
|
||||||
|
word_list.append(word_content)
|
||||||
|
word_col_list.append(word_col_content)
|
||||||
|
state_list.append(state)
|
||||||
|
word_content = []
|
||||||
|
word_col_content = []
|
||||||
|
state = c_state
|
||||||
|
|
||||||
|
if state != "splitter":
|
||||||
|
word_content.append(char)
|
||||||
|
word_col_content.append(valid_col[c_i])
|
||||||
|
|
||||||
|
if len(word_content) != 0:
|
||||||
|
word_list.append(word_content)
|
||||||
|
word_col_list.append(word_col_content)
|
||||||
|
state_list.append(state)
|
||||||
|
|
||||||
|
return word_list, word_col_list, state_list
|
||||||
|
|
||||||
|
def decode(self, text_index, text_prob=None, is_remove_duplicate=False, return_word_box=False):
|
||||||
""" convert text-index into text-label. """
|
""" convert text-index into text-label. """
|
||||||
result_list = []
|
result_list = []
|
||||||
ignored_tokens = self.get_ignored_tokens()
|
ignored_tokens = self.get_ignored_tokens()
|
||||||
|
@ -95,8 +154,12 @@ class BaseRecLabelDecode(object):
|
||||||
|
|
||||||
if self.reverse: # for arabic rec
|
if self.reverse: # for arabic rec
|
||||||
text = self.pred_reverse(text)
|
text = self.pred_reverse(text)
|
||||||
|
|
||||||
result_list.append((text, np.mean(conf_list).tolist()))
|
if return_word_box:
|
||||||
|
word_list, word_col_list, state_list = self.get_word_info(text, selection)
|
||||||
|
result_list.append((text, np.mean(conf_list).tolist(), [len(text_index[batch_idx]), word_list, word_col_list, state_list]))
|
||||||
|
else:
|
||||||
|
result_list.append((text, np.mean(conf_list).tolist()))
|
||||||
return result_list
|
return result_list
|
||||||
|
|
||||||
def get_ignored_tokens(self):
|
def get_ignored_tokens(self):
|
||||||
|
@ -111,14 +174,19 @@ class CTCLabelDecode(BaseRecLabelDecode):
|
||||||
super(CTCLabelDecode, self).__init__(character_dict_path,
|
super(CTCLabelDecode, self).__init__(character_dict_path,
|
||||||
use_space_char)
|
use_space_char)
|
||||||
|
|
||||||
def __call__(self, preds, label=None, *args, **kwargs):
|
def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
|
||||||
if isinstance(preds, tuple) or isinstance(preds, list):
|
if isinstance(preds, tuple) or isinstance(preds, list):
|
||||||
preds = preds[-1]
|
preds = preds[-1]
|
||||||
if isinstance(preds, paddle.Tensor):
|
if isinstance(preds, paddle.Tensor):
|
||||||
preds = preds.numpy()
|
preds = preds.numpy()
|
||||||
preds_idx = preds.argmax(axis=2)
|
preds_idx = preds.argmax(axis=2)
|
||||||
preds_prob = preds.max(axis=2)
|
preds_prob = preds.max(axis=2)
|
||||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True, return_word_box=return_word_box)
|
||||||
|
if return_word_box:
|
||||||
|
for rec_idx, rec in enumerate(text):
|
||||||
|
wh_ratio = kwargs['wh_ratio_list'][rec_idx]
|
||||||
|
max_wh_ratio = kwargs['max_wh_ratio']
|
||||||
|
rec[2][0] = rec[2][0]*(wh_ratio/max_wh_ratio)
|
||||||
if label is None:
|
if label is None:
|
||||||
return text
|
return text
|
||||||
label = self.decode(label)
|
label = self.decode(label)
|
||||||
|
|
|
@ -34,7 +34,7 @@ from ppocr.utils.visual import draw_ser_results, draw_re_results
|
||||||
from tools.infer.predict_system import TextSystem
|
from tools.infer.predict_system import TextSystem
|
||||||
from ppstructure.layout.predict_layout import LayoutPredictor
|
from ppstructure.layout.predict_layout import LayoutPredictor
|
||||||
from ppstructure.table.predict_table import TableSystem, to_excel
|
from ppstructure.table.predict_table import TableSystem, to_excel
|
||||||
from ppstructure.utility import parse_args, draw_structure_result
|
from ppstructure.utility import parse_args, draw_structure_result, cal_ocr_word_box
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
@ -79,6 +79,8 @@ class StructureSystem(object):
|
||||||
from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor
|
from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor
|
||||||
self.kie_predictor = SerRePredictor(args)
|
self.kie_predictor = SerRePredictor(args)
|
||||||
|
|
||||||
|
self.return_word_box = args.return_word_box
|
||||||
|
|
||||||
def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
|
def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
|
||||||
time_dict = {
|
time_dict = {
|
||||||
'image_orientation': 0,
|
'image_orientation': 0,
|
||||||
|
@ -156,17 +158,27 @@ class StructureSystem(object):
|
||||||
]
|
]
|
||||||
res = []
|
res = []
|
||||||
for box, rec_res in zip(filter_boxes, filter_rec_res):
|
for box, rec_res in zip(filter_boxes, filter_rec_res):
|
||||||
rec_str, rec_conf = rec_res
|
rec_str, rec_conf = rec_res[0], rec_res[1]
|
||||||
for token in style_token:
|
for token in style_token:
|
||||||
if token in rec_str:
|
if token in rec_str:
|
||||||
rec_str = rec_str.replace(token, '')
|
rec_str = rec_str.replace(token, '')
|
||||||
if not self.recovery:
|
if not self.recovery:
|
||||||
box += [x1, y1]
|
box += [x1, y1]
|
||||||
res.append({
|
if self.return_word_box:
|
||||||
'text': rec_str,
|
word_box_content_list, word_box_list = cal_ocr_word_box(rec_str, box, rec_res[2])
|
||||||
'confidence': float(rec_conf),
|
res.append({
|
||||||
'text_region': box.tolist()
|
'text': rec_str,
|
||||||
})
|
'confidence': float(rec_conf),
|
||||||
|
'text_region': box.tolist(),
|
||||||
|
'text_word': word_box_content_list,
|
||||||
|
'text_word_region': word_box_list
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
res.append({
|
||||||
|
'text': rec_str,
|
||||||
|
'confidence': float(rec_conf),
|
||||||
|
'text_region': box.tolist()
|
||||||
|
})
|
||||||
res_list.append({
|
res_list.append({
|
||||||
'type': region['label'].lower(),
|
'type': region['label'].lower(),
|
||||||
'bbox': [x1, y1, x2, y2],
|
'bbox': [x1, y1, x2, y2],
|
||||||
|
|
|
@ -16,7 +16,7 @@ import ast
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tools.infer.utility import draw_ocr_box_txt, str2bool, str2int_tuple, init_args as infer_args
|
from tools.infer.utility import draw_ocr_box_txt, str2bool, str2int_tuple, init_args as infer_args
|
||||||
|
import math
|
||||||
|
|
||||||
def init_args():
|
def init_args():
|
||||||
parser = infer_args()
|
parser = infer_args()
|
||||||
|
@ -166,6 +166,63 @@ def draw_structure_result(image, result, font_path):
|
||||||
txts.append(text_result['text'])
|
txts.append(text_result['text'])
|
||||||
scores.append(text_result['confidence'])
|
scores.append(text_result['confidence'])
|
||||||
|
|
||||||
|
if 'text_word_region' in text_result:
|
||||||
|
for word_region in text_result['text_word_region']:
|
||||||
|
char_box = word_region
|
||||||
|
box_height = int(
|
||||||
|
math.sqrt((char_box[0][0] - char_box[3][0])**2 + (char_box[0][1] - char_box[3][1])**2))
|
||||||
|
box_width = int(
|
||||||
|
math.sqrt((char_box[0][0] - char_box[1][0])**2 + (char_box[0][1] - char_box[1][1])**2))
|
||||||
|
if box_height == 0 or box_width == 0:
|
||||||
|
continue
|
||||||
|
boxes.append(word_region)
|
||||||
|
txts.append("")
|
||||||
|
scores.append(1.0)
|
||||||
|
|
||||||
im_show = draw_ocr_box_txt(
|
im_show = draw_ocr_box_txt(
|
||||||
img_layout, boxes, txts, scores, font_path=font_path, drop_score=0)
|
img_layout, boxes, txts, scores, font_path=font_path, drop_score=0)
|
||||||
return im_show
|
return im_show
|
||||||
|
|
||||||
|
def cal_ocr_word_box(rec_str, box, rec_word_info):
|
||||||
|
''' Calculate the detection frame for each word based on the results of recognition and detection of ocr'''
|
||||||
|
|
||||||
|
col_num, word_list, word_col_list, state_list = rec_word_info
|
||||||
|
box = box.tolist()
|
||||||
|
bbox_x_start = box[0][0]
|
||||||
|
bbox_x_end = box[1][0]
|
||||||
|
bbox_y_start = box[0][1]
|
||||||
|
bbox_y_end = box[2][1]
|
||||||
|
|
||||||
|
cell_width = (bbox_x_end - bbox_x_start)/col_num
|
||||||
|
|
||||||
|
word_box_list = []
|
||||||
|
word_box_content_list = []
|
||||||
|
cn_width_list = []
|
||||||
|
cn_col_list = []
|
||||||
|
for word, word_col, state in zip(word_list, word_col_list, state_list):
|
||||||
|
if state == 'cn':
|
||||||
|
if len(word_col) != 1:
|
||||||
|
char_seq_length = (word_col[-1] - word_col[0] + 1) * cell_width
|
||||||
|
char_width = char_seq_length/(len(word_col)-1)
|
||||||
|
cn_width_list.append(char_width)
|
||||||
|
cn_col_list += word_col
|
||||||
|
word_box_content_list += word
|
||||||
|
else:
|
||||||
|
cell_x_start = bbox_x_start + int(word_col[0] * cell_width)
|
||||||
|
cell_x_end = bbox_x_start + int((word_col[-1]+1) * cell_width)
|
||||||
|
cell = ((cell_x_start, bbox_y_start), (cell_x_end, bbox_y_start), (cell_x_end, bbox_y_end), (cell_x_start, bbox_y_end))
|
||||||
|
word_box_list.append(cell)
|
||||||
|
word_box_content_list.append("".join(word))
|
||||||
|
if len(cn_col_list) != 0:
|
||||||
|
if len(cn_width_list) != 0:
|
||||||
|
avg_char_width = np.mean(cn_width_list)
|
||||||
|
else:
|
||||||
|
avg_char_width = (bbox_x_end - bbox_x_start)/len(rec_str)
|
||||||
|
for center_idx in cn_col_list:
|
||||||
|
center_x = (center_idx+0.5)*cell_width
|
||||||
|
cell_x_start = max(int(center_x - avg_char_width/2), 0) + bbox_x_start
|
||||||
|
cell_x_end = min(int(center_x + avg_char_width/2), bbox_x_end-bbox_x_start) + bbox_x_start
|
||||||
|
cell = ((cell_x_start, bbox_y_start), (cell_x_end, bbox_y_start), (cell_x_end, bbox_y_end), (cell_x_start, bbox_y_end))
|
||||||
|
word_box_list.append(cell)
|
||||||
|
|
||||||
|
return word_box_content_list, word_box_list
|
|
@ -116,6 +116,7 @@ class TextRecognizer(object):
|
||||||
"use_space_char": args.use_space_char
|
"use_space_char": args.use_space_char
|
||||||
}
|
}
|
||||||
self.postprocess_op = build_post_process(postprocess_params)
|
self.postprocess_op = build_post_process(postprocess_params)
|
||||||
|
self.postprocess_params = postprocess_params
|
||||||
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
||||||
utility.create_predictor(args, 'rec', logger)
|
utility.create_predictor(args, 'rec', logger)
|
||||||
self.benchmark = args.benchmark
|
self.benchmark = args.benchmark
|
||||||
|
@ -139,6 +140,7 @@ class TextRecognizer(object):
|
||||||
],
|
],
|
||||||
warmup=0,
|
warmup=0,
|
||||||
logger=logger)
|
logger=logger)
|
||||||
|
self.return_word_box = args.return_word_box
|
||||||
|
|
||||||
def resize_norm_img(self, img, max_wh_ratio):
|
def resize_norm_img(self, img, max_wh_ratio):
|
||||||
imgC, imgH, imgW = self.rec_image_shape
|
imgC, imgH, imgW = self.rec_image_shape
|
||||||
|
@ -407,11 +409,12 @@ class TextRecognizer(object):
|
||||||
valid_ratios = []
|
valid_ratios = []
|
||||||
imgC, imgH, imgW = self.rec_image_shape[:3]
|
imgC, imgH, imgW = self.rec_image_shape[:3]
|
||||||
max_wh_ratio = imgW / imgH
|
max_wh_ratio = imgW / imgH
|
||||||
# max_wh_ratio = 0
|
wh_ratio_list = []
|
||||||
for ino in range(beg_img_no, end_img_no):
|
for ino in range(beg_img_no, end_img_no):
|
||||||
h, w = img_list[indices[ino]].shape[0:2]
|
h, w = img_list[indices[ino]].shape[0:2]
|
||||||
wh_ratio = w * 1.0 / h
|
wh_ratio = w * 1.0 / h
|
||||||
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
||||||
|
wh_ratio_list.append(wh_ratio)
|
||||||
for ino in range(beg_img_no, end_img_no):
|
for ino in range(beg_img_no, end_img_no):
|
||||||
if self.rec_algorithm == "SAR":
|
if self.rec_algorithm == "SAR":
|
||||||
norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
|
norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
|
||||||
|
@ -616,7 +619,10 @@ class TextRecognizer(object):
|
||||||
preds = outputs
|
preds = outputs
|
||||||
else:
|
else:
|
||||||
preds = outputs[0]
|
preds = outputs[0]
|
||||||
rec_result = self.postprocess_op(preds)
|
if self.postprocess_params['name'] == 'CTCLabelDecode':
|
||||||
|
rec_result = self.postprocess_op(preds, return_word_box=self.return_word_box, wh_ratio_list=wh_ratio_list, max_wh_ratio=max_wh_ratio)
|
||||||
|
else:
|
||||||
|
rec_result = self.postprocess_op(preds)
|
||||||
for rno in range(len(rec_result)):
|
for rno in range(len(rec_result)):
|
||||||
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
|
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
|
||||||
if self.benchmark:
|
if self.benchmark:
|
||||||
|
|
|
@ -111,7 +111,7 @@ class TextSystem(object):
|
||||||
rec_res)
|
rec_res)
|
||||||
filter_boxes, filter_rec_res = [], []
|
filter_boxes, filter_rec_res = [], []
|
||||||
for box, rec_result in zip(dt_boxes, rec_res):
|
for box, rec_result in zip(dt_boxes, rec_res):
|
||||||
text, score = rec_result
|
text, score = rec_result[0], rec_result[1]
|
||||||
if score >= self.drop_score:
|
if score >= self.drop_score:
|
||||||
filter_boxes.append(box)
|
filter_boxes.append(box)
|
||||||
filter_rec_res.append(rec_result)
|
filter_rec_res.append(rec_result)
|
||||||
|
|
|
@ -150,6 +150,10 @@ def init_args():
|
||||||
|
|
||||||
parser.add_argument("--show_log", type=str2bool, default=True)
|
parser.add_argument("--show_log", type=str2bool, default=True)
|
||||||
parser.add_argument("--use_onnx", type=str2bool, default=False)
|
parser.add_argument("--use_onnx", type=str2bool, default=False)
|
||||||
|
|
||||||
|
# extended function
|
||||||
|
parser.add_argument("--return_word_box", type=str2bool, default=False, help='Whether return the bbox of each word (split by space) or chinese character. Only used in ppstructure for layout recovery')
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue