From 1e11f254094305c593d4c734a5c4148f945accaa Mon Sep 17 00:00:00 2001
From: ToddBear <43341135+ToddBear@users.noreply.github.com>
Date: Wed, 2 Aug 2023 19:11:28 +0800
Subject: [PATCH] =?UTF-8?q?CV=E5=A5=97=E4=BB=B6=E5=BB=BA=E8=AE=BE=E4=B8=93?=
 =?UTF-8?q?=E9=A1=B9=E6=B4=BB=E5=8A=A8=20-=20=E6=96=87=E5=AD=97=E8=AF=86?=
 =?UTF-8?q?=E5=88=AB=E8=BF=94=E5=9B=9E=E5=8D=95=E5=AD=97=E8=AF=86=E5=88=AB?=
 =?UTF-8?q?=E5=9D=90=E6=A0=87=20(#10515)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* modification of return word box

* update_implements

* Update rec_postprocess.py

* Update utility.py
---
 ppocr/postprocess/rec_postprocess.py | 78 ++++++++++++++++++++++++++--
 ppstructure/predict_system.py        | 26 +++++++---
 ppstructure/utility.py               | 59 ++++++++++++++++++++-
 tools/infer/predict_rec.py           | 10 +++-
 tools/infer/predict_system.py        |  2 +-
 tools/infer/utility.py               |  4 ++
 6 files changed, 163 insertions(+), 16 deletions(-)

diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index fbf8b93e3..230f84d1b 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -67,7 +67,66 @@ class BaseRecLabelDecode(object):
     def add_special_char(self, dict_character):
         return dict_character
 
-    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+    def get_word_info(self, text, selection):
+        """
+        Group the decoded characters and record the corresponding decoded positions. 
+
+        Args:
+            text: the decoded text
+            selection: the bool array that identifies which columns of features are decoded as non-separated characters 
+        Returns:
+            word_list: list of the grouped words
+            word_col_list: list of decoding positions corresponding to each character in the grouped word
+            state_list: list of marker to identify the type of grouping words, including two types of grouping words: 
+                        - 'cn': continous chinese characters (e.g., 你好啊)
+                        - 'en&num': continous english characters (e.g., hello), number (e.g., 123, 1.123), or mixed of them connected by '-' (e.g., VGG-16)
+                        The remaining characters in text are treated as separators between groups (e.g., space, '(', ')', etc.).
+        """
+        state = None
+        word_content = []
+        word_col_content = []
+        word_list = []
+        word_col_list = []
+        state_list = []
+        valid_col = np.where(selection==True)[0]
+
+        for c_i, char in enumerate(text):
+            if '\u4e00' <= char <= '\u9fff':
+                c_state = 'cn'
+            elif bool(re.search('[a-zA-Z0-9]', char)):
+                c_state = 'en&num'
+            else:
+                c_state = 'splitter'
+            
+            if char == '.' and state == 'en&num' and c_i + 1 < len(text) and bool(re.search('[0-9]', text[c_i+1])): # grouping floting number
+                c_state = 'en&num'
+            if char == '-' and state == "en&num": # grouping word with '-', such as 'state-of-the-art'
+                c_state = 'en&num'
+            
+            if state == None:
+                state = c_state
+
+            if state != c_state:
+                if len(word_content) != 0:
+                    word_list.append(word_content)
+                    word_col_list.append(word_col_content)
+                    state_list.append(state)
+                    word_content = []
+                    word_col_content = []
+                state = c_state
+
+            if state != "splitter":
+                word_content.append(char)
+                word_col_content.append(valid_col[c_i])
+
+        if len(word_content) != 0: 
+            word_list.append(word_content)
+            word_col_list.append(word_col_content)
+            state_list.append(state)
+
+        return word_list, word_col_list, state_list
+
+    def decode(self, text_index, text_prob=None, is_remove_duplicate=False, return_word_box=False):
         """ convert text-index into text-label. """
         result_list = []
         ignored_tokens = self.get_ignored_tokens()
@@ -95,8 +154,12 @@ class BaseRecLabelDecode(object):
 
             if self.reverse:  # for arabic rec
                 text = self.pred_reverse(text)
-
-            result_list.append((text, np.mean(conf_list).tolist()))
+            
+            if return_word_box:
+                word_list, word_col_list, state_list = self.get_word_info(text, selection)
+                result_list.append((text, np.mean(conf_list).tolist(), [len(text_index[batch_idx]), word_list, word_col_list, state_list]))
+            else:
+                result_list.append((text, np.mean(conf_list).tolist()))
         return result_list
 
     def get_ignored_tokens(self):
@@ -111,14 +174,19 @@ class CTCLabelDecode(BaseRecLabelDecode):
         super(CTCLabelDecode, self).__init__(character_dict_path,
                                              use_space_char)
 
-    def __call__(self, preds, label=None, *args, **kwargs):
+    def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
         if isinstance(preds, tuple) or isinstance(preds, list):
             preds = preds[-1]
         if isinstance(preds, paddle.Tensor):
             preds = preds.numpy()
         preds_idx = preds.argmax(axis=2)
         preds_prob = preds.max(axis=2)
-        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
+        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True, return_word_box=return_word_box)
+        if return_word_box:
+            for rec_idx, rec in enumerate(text):
+                wh_ratio = kwargs['wh_ratio_list'][rec_idx]
+                max_wh_ratio = kwargs['max_wh_ratio']
+                rec[2][0] = rec[2][0]*(wh_ratio/max_wh_ratio)
         if label is None:
             return text
         label = self.decode(label)
diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py
index b32b70629..b8b871689 100644
--- a/ppstructure/predict_system.py
+++ b/ppstructure/predict_system.py
@@ -34,7 +34,7 @@ from ppocr.utils.visual import draw_ser_results, draw_re_results
 from tools.infer.predict_system import TextSystem
 from ppstructure.layout.predict_layout import LayoutPredictor
 from ppstructure.table.predict_table import TableSystem, to_excel
-from ppstructure.utility import parse_args, draw_structure_result
+from ppstructure.utility import parse_args, draw_structure_result, cal_ocr_word_box
 
 logger = get_logger()
 
@@ -79,6 +79,8 @@ class StructureSystem(object):
             from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor
             self.kie_predictor = SerRePredictor(args)
 
+        self.return_word_box = args.return_word_box
+
     def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
         time_dict = {
             'image_orientation': 0,
@@ -156,17 +158,27 @@ class StructureSystem(object):
                         ]
                         res = []
                         for box, rec_res in zip(filter_boxes, filter_rec_res):
-                            rec_str, rec_conf = rec_res
+                            rec_str, rec_conf = rec_res[0], rec_res[1]
                             for token in style_token:
                                 if token in rec_str:
                                     rec_str = rec_str.replace(token, '')
                             if not self.recovery:
                                 box += [x1, y1]
-                            res.append({
-                                'text': rec_str,
-                                'confidence': float(rec_conf),
-                                'text_region': box.tolist()
-                            })
+                            if self.return_word_box:
+                                word_box_content_list, word_box_list = cal_ocr_word_box(rec_str, box, rec_res[2])
+                                res.append({
+                                    'text': rec_str,
+                                    'confidence': float(rec_conf),
+                                    'text_region': box.tolist(),
+                                    'text_word': word_box_content_list,
+                                    'text_word_region': word_box_list
+                                })
+                            else:
+                                res.append({
+                                    'text': rec_str,
+                                    'confidence': float(rec_conf),
+                                    'text_region': box.tolist()
+                                })
                 res_list.append({
                     'type': region['label'].lower(),
                     'bbox': [x1, y1, x2, y2],
diff --git a/ppstructure/utility.py b/ppstructure/utility.py
index 182283a7f..320722d1f 100644
--- a/ppstructure/utility.py
+++ b/ppstructure/utility.py
@@ -16,7 +16,7 @@ import ast
 from PIL import Image, ImageDraw, ImageFont
 import numpy as np
 from tools.infer.utility import draw_ocr_box_txt, str2bool, str2int_tuple, init_args as infer_args
-
+import math
 
 def init_args():
     parser = infer_args()
@@ -166,6 +166,63 @@ def draw_structure_result(image, result, font_path):
                 txts.append(text_result['text'])
                 scores.append(text_result['confidence'])
 
+                if 'text_word_region' in text_result:
+                    for word_region in text_result['text_word_region']:
+                        char_box = word_region
+                        box_height = int(
+                            math.sqrt((char_box[0][0] - char_box[3][0])**2 + (char_box[0][1] - char_box[3][1])**2))
+                        box_width = int(
+                            math.sqrt((char_box[0][0] - char_box[1][0])**2 + (char_box[0][1] - char_box[1][1])**2))
+                        if box_height == 0 or box_width == 0:
+                            continue
+                        boxes.append(word_region)
+                        txts.append("")
+                        scores.append(1.0)
+
     im_show = draw_ocr_box_txt(
         img_layout, boxes, txts, scores, font_path=font_path, drop_score=0)
     return im_show
+
+def cal_ocr_word_box(rec_str, box, rec_word_info):
+    ''' Calculate the detection frame for each word based on the results of recognition and detection of ocr'''
+    
+    col_num, word_list, word_col_list, state_list = rec_word_info
+    box = box.tolist()
+    bbox_x_start = box[0][0]
+    bbox_x_end = box[1][0]
+    bbox_y_start = box[0][1]
+    bbox_y_end = box[2][1]
+
+    cell_width = (bbox_x_end - bbox_x_start)/col_num
+
+    word_box_list = []
+    word_box_content_list = []
+    cn_width_list = []
+    cn_col_list = []
+    for word, word_col, state in zip(word_list, word_col_list, state_list):
+        if state == 'cn':
+            if len(word_col) != 1:
+                char_seq_length = (word_col[-1] - word_col[0] + 1) * cell_width
+                char_width = char_seq_length/(len(word_col)-1)
+                cn_width_list.append(char_width)
+            cn_col_list += word_col
+            word_box_content_list += word
+        else:
+            cell_x_start = bbox_x_start + int(word_col[0] * cell_width)
+            cell_x_end = bbox_x_start + int((word_col[-1]+1) * cell_width)
+            cell = ((cell_x_start, bbox_y_start), (cell_x_end, bbox_y_start), (cell_x_end, bbox_y_end), (cell_x_start, bbox_y_end))
+            word_box_list.append(cell)
+            word_box_content_list.append("".join(word))
+    if len(cn_col_list) != 0:
+        if len(cn_width_list) != 0:
+            avg_char_width = np.mean(cn_width_list)
+        else:
+            avg_char_width = (bbox_x_end - bbox_x_start)/len(rec_str)
+        for center_idx in cn_col_list:
+            center_x = (center_idx+0.5)*cell_width
+            cell_x_start = max(int(center_x - avg_char_width/2), 0) + bbox_x_start
+            cell_x_end = min(int(center_x + avg_char_width/2), bbox_x_end-bbox_x_start) + bbox_x_start
+            cell = ((cell_x_start, bbox_y_start), (cell_x_end, bbox_y_start), (cell_x_end, bbox_y_end), (cell_x_start, bbox_y_end))
+            word_box_list.append(cell)
+    
+    return word_box_content_list, word_box_list
\ No newline at end of file
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index b3ef557c0..7f4a3863e 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -116,6 +116,7 @@ class TextRecognizer(object):
                 "use_space_char": args.use_space_char
             }
         self.postprocess_op = build_post_process(postprocess_params)
+        self.postprocess_params = postprocess_params
         self.predictor, self.input_tensor, self.output_tensors, self.config = \
             utility.create_predictor(args, 'rec', logger)
         self.benchmark = args.benchmark
@@ -139,6 +140,7 @@ class TextRecognizer(object):
                 ],
                 warmup=0,
                 logger=logger)
+        self.return_word_box = args.return_word_box
 
     def resize_norm_img(self, img, max_wh_ratio):
         imgC, imgH, imgW = self.rec_image_shape
@@ -407,11 +409,12 @@ class TextRecognizer(object):
                 valid_ratios = []
             imgC, imgH, imgW = self.rec_image_shape[:3]
             max_wh_ratio = imgW / imgH
-            # max_wh_ratio = 0
+            wh_ratio_list = []
             for ino in range(beg_img_no, end_img_no):
                 h, w = img_list[indices[ino]].shape[0:2]
                 wh_ratio = w * 1.0 / h
                 max_wh_ratio = max(max_wh_ratio, wh_ratio)
+                wh_ratio_list.append(wh_ratio)
             for ino in range(beg_img_no, end_img_no):
                 if self.rec_algorithm == "SAR":
                     norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
@@ -616,7 +619,10 @@ class TextRecognizer(object):
                         preds = outputs
                     else:
                         preds = outputs[0]
-            rec_result = self.postprocess_op(preds)
+            if self.postprocess_params['name'] == 'CTCLabelDecode':
+                rec_result = self.postprocess_op(preds, return_word_box=self.return_word_box, wh_ratio_list=wh_ratio_list, max_wh_ratio=max_wh_ratio)
+            else:
+                rec_result = self.postprocess_op(preds)
             for rno in range(len(rec_result)):
                 rec_res[indices[beg_img_no + rno]] = rec_result[rno]
             if self.benchmark:
diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py
index 95d87be61..8af45b4cf 100755
--- a/tools/infer/predict_system.py
+++ b/tools/infer/predict_system.py
@@ -111,7 +111,7 @@ class TextSystem(object):
                                    rec_res)
         filter_boxes, filter_rec_res = [], []
         for box, rec_result in zip(dt_boxes, rec_res):
-            text, score = rec_result
+            text, score = rec_result[0], rec_result[1]
             if score >= self.drop_score:
                 filter_boxes.append(box)
                 filter_rec_res.append(rec_result)
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index b6a770637..4883015b7 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -150,6 +150,10 @@ def init_args():
 
     parser.add_argument("--show_log", type=str2bool, default=True)
     parser.add_argument("--use_onnx", type=str2bool, default=False)
+
+    # extended function
+    parser.add_argument("--return_word_box", type=str2bool, default=False, help='Whether return the bbox of each word (split by space) or chinese character. Only used in ppstructure for layout recovery')
+
     return parser