diff --git a/.gitignore b/.gitignore
index 3300be325..410be83f4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -32,3 +32,28 @@ paddleocr.egg-info/
 /deploy/android_demo/app/cache/
 test_tipc/web/models/
 test_tipc/web/node_modules/
+en_ppocr_mobile_v2.0_table_structure_infer/._inference.pdiparams
+en_ppocr_mobile_v2.0_table_structure_infer/._inference.pdiparams.info
+en_ppocr_mobile_v2.0_table_structure_infer/._inference.pdmodel
+en_ppocr_mobile_v2.0_table_structure_infer/inference.pdiparams
+en_ppocr_mobile_v2.0_table_structure_infer/inference.pdiparams.info
+en_ppocr_mobile_v2.0_table_structure_infer/inference.pdmodel
+ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/._inference.pdiparams
+ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/._inference.pdiparams.info
+ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/._inference.pdmodel
+ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/inference.pdiparams
+ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/inference.pdiparams.info
+ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/inference.pdmodel
+ppstructure/layout/en_ppocr_mobile_v2.0_table_rec_infer/inference.pdiparams
+ppstructure/layout/en_ppocr_mobile_v2.0_table_rec_infer/inference.pdiparams.info
+ppstructure/layout/en_ppocr_mobile_v2.0_table_rec_infer/inference.pdmodel
+ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape/infer_cfg.yml
+ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape/inference.pdiparams
+ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape/inference.pdiparams.info
+ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape/inference.pdmodel
+ppstructure/layout/table/inference.pdiparams
+ppstructure/layout/table/inference.pdiparams.info
+ppstructure/layout/table/inference.pdmodel
+ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape.tar
+._en_ppocr_mobile_v2.0_table_structure_infer
+en_ppocr_mobile_v2.0_table_structure_infer.tar
diff --git a/configs/table/SLANet.yml b/configs/table/SLANet.yml
new file mode 100644
index 000000000..ee2584d52
--- /dev/null
+++ b/configs/table/SLANet.yml
@@ -0,0 +1,141 @@
+Global:
+  use_gpu: true
+  epoch_num: 400
+  log_smooth_window: 20
+  print_batch_step: 20
+  save_model_dir: ./output/SLANet
+  save_epoch_step: 400
+  # evaluation is run every 1000 iterations after the 0th iteration
+  eval_batch_step: [0, 1000]
+  cal_metric_during_train: True
+  pretrained_model:
+  checkpoints:
+  save_inference_dir: ./output/SLANet/infer
+  use_visualdl: False
+  infer_img: doc/table/table.jpg
+  # for data or label process
+  character_dict_path: ppocr/utils/dict/table_structure_dict.txt
+  character_type: en
+  max_text_length: &max_text_length 500
+  box_format: &box_format 'xyxy' # 'xywh', 'xyxy', 'xyxyxyxy'
+  infer_mode: False
+  use_sync_bn: True
+  save_res_path: 'output/infer'
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  clip_norm: 5.0
+  lr:
+    # name: Piecewise
+    learning_rate: 0.001
+    # decay_epochs : [10, 20]
+    # values : [0.002, 0.0002, 0.0001]
+    # warmup_epoch: 0
+  regularizer:
+    name: 'L2'
+    factor: 0.00000
+
+Architecture:
+  model_type: table
+  algorithm: SLANet
+  Backbone:
+    name: PPLCNet
+    scale: 1.0
+    pretrained: true
+    use_ssld: true
+  Neck:
+    name: CSPPAN
+    out_channels: 96
+  Head:
+    name: SLAHead
+    hidden_size: 256
+    max_text_length: *max_text_length
+    loc_reg_num: &loc_reg_num 4
+
+Loss:
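+  # cross-entropy on the structure-token sequence plus a masked regression
+  # loss on cell coordinates; the weights below balance the two terms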
+  name: SLANetLoss
+  structure_weight: 1.0
+  loc_weight: 2.0
+  loc_loss: smooth_l1
+
+PostProcess:
+  name: TableLabelDecode
+
+Metric:
+  name: TableMetric
+  main_indicator: acc
+  compute_bbox_metric: False
+  loc_reg_num: *loc_reg_num
+  box_format: *box_format
+
+Train:
+  dataset:
+    name: PubTabDataSet
+    data_dir: train_data/table/pubtabnet/train/
+    label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl]
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - TableLabelEncode:
+          learn_empty_box: False
+          merge_no_span_structure: False
+          replace_empty_cell_token: False
+          loc_reg_num: *loc_reg_num
+          max_text_length: *max_text_length
+      - TableBoxEncode:
+          box_format: *box_format
+      - ResizeTableImage:
+          max_len: 488
+      - NormalizeImage:
+          scale: 1./255.
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+          order: 'hwc'
+      - PaddingTableImage:
+          size: [488, 488]
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
+  loader:
+    shuffle: True
+    batch_size_per_card: 48
+    drop_last: True
+    num_workers: 1
+
+Eval:
+  dataset:
+    name: PubTabDataSet
+    data_dir: train_data/table/pubtabnet/val/
+    label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl]
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - TableLabelEncode:
+          learn_empty_box: False
+          merge_no_span_structure: False
+          replace_empty_cell_token: False
+          loc_reg_num: *loc_reg_num
+          max_text_length: *max_text_length
+      - TableBoxEncode:
+          box_format: *box_format
+      - ResizeTableImage:
+          max_len: 488
+      - NormalizeImage:
+          scale: 1./255.
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+          order: 'hwc'
+      - PaddingTableImage:
+          size: [488, 488]
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 48
+    num_workers: 1
diff --git a/configs/table/table_master.yml b/configs/table/table_master.yml
index 1e6efe32d..8bed7d069 100755
--- a/configs/table/table_master.yml
+++ b/configs/table/table_master.yml
@@ -15,9 +15,8 @@ Global:
   save_res_path: ./output/table_master
   character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
   infer_mode: false
-  max_text_length: 500
-  process_total_num: 0
-  process_cut_num: 0
+  max_text_length: &max_text_length 500
+  box_format: &box_format 'xywh' # 'xywh', 'xyxy', 'xyxyxyxy'


 Optimizer:
@@ -52,7 +51,8 @@ Architecture:
       headers: 8
       dropout: 0
       d_ff: 2024
-      max_text_length: 500
+      max_text_length: *max_text_length
+      loc_reg_num: &loc_reg_num 4

 Loss:
   name: TableMasterLoss
@@ -66,6 +66,7 @@ Metric:
   name: TableMetric
   main_indicator: acc
   compute_bbox_metric: False
+  box_format: *box_format

 Train:
   dataset:
@@ -80,13 +81,15 @@ Train:
           learn_empty_box: False
           merge_no_span_structure: True
           replace_empty_cell_token: True
+          loc_reg_num: *loc_reg_num
+          max_text_length: *max_text_length
       - ResizeTableImage:
           max_len: 480
           resize_bboxes: True
       - PaddingTableImage:
           size: [480, 480]
       - TableBoxEncode:
-          use_xywh: True
+          box_format: *box_format
       - NormalizeImage:
           scale: 1./255.
           mean: [0.5, 0.5, 0.5]
@@ -114,13 +117,15 @@ Eval:
           learn_empty_box: False
           merge_no_span_structure: True
           replace_empty_cell_token: True
+          loc_reg_num: *loc_reg_num
+          max_text_length: *max_text_length
       - ResizeTableImage:
           max_len: 480
           resize_bboxes: True
       - PaddingTableImage:
           size: [480, 480]
       - TableBoxEncode:
-          use_xywh: True
+          box_format: *box_format
       - NormalizeImage:
           scale: 1./255.
           mean: [0.5, 0.5, 0.5]
diff --git a/configs/table/table_mv3.yml b/configs/table/table_mv3.yml
index 66c1c83e1..87cda7db2 100755
--- a/configs/table/table_mv3.yml
+++ b/configs/table/table_mv3.yml
@@ -17,10 +17,9 @@ Global:
   # for data or label process
   character_dict_path: ppocr/utils/dict/table_structure_dict.txt
   character_type: en
-  max_text_length: 800
+  max_text_length: &max_text_length 800
+  box_format: &box_format 'xyxy' # 'xywh', 'xyxy', 'xyxyxyxy'
   infer_mode: False
-  process_total_num: 0
-  process_cut_num: 0

 Optimizer:
   name: Adam
@@ -44,7 +43,8 @@ Architecture:
     name: TableAttentionHead
     hidden_size: 256
     loc_type: 2
-    max_text_length: 800
+    max_text_length: *max_text_length
+    loc_reg_num: &loc_reg_num 4

 Loss:
   name: TableAttentionLoss
@@ -72,6 +72,8 @@ Train:
           learn_empty_box: False
           merge_no_span_structure: False
           replace_empty_cell_token: False
+          loc_reg_num: *loc_reg_num
+          max_text_length: *max_text_length
       - TableBoxEncode:
       - ResizeTableImage:
           max_len: 488
@@ -104,6 +106,8 @@ Eval:
           learn_empty_box: False
           merge_no_span_structure: False
           replace_empty_cell_token: False
+          loc_reg_num: *loc_reg_num
+          max_text_length: *max_text_length
       - TableBoxEncode:
       - ResizeTableImage:
           max_len: 488
diff --git a/deploy/hubserving/ocr_system/module.py b/deploy/hubserving/ocr_system/module.py
index 71a19c6b7..dff3abb48 100644
--- a/deploy/hubserving/ocr_system/module.py
+++ b/deploy/hubserving/ocr_system/module.py
@@ -118,7 +118,7 @@ class OCRSystem(hub.Module):
                 all_results.append([])
                 continue
             starttime = time.time()
-            dt_boxes, rec_res = self.text_sys(img)
+            dt_boxes, rec_res, _ = self.text_sys(img)
             elapse = time.time() - starttime
             logger.info("Predict time: {}".format(elapse))
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 97539faf2..ad391046a 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -571,7 +571,7 @@ class TableLabelEncode(AttnLabelEncode):
                  replace_empty_cell_token=False,
                  merge_no_span_structure=False,
                  learn_empty_box=False,
-                 point_num=2,
+                 loc_reg_num=4,
                  **kwargs):
         self.max_text_len = max_text_length
         self.lower = False
@@ -593,7 +593,7 @@ class TableLabelEncode(AttnLabelEncode):
         self.idx2char = {v: k for k, v in self.dict.items()}

         self.character = dict_character
-        self.point_num = point_num
+        self.loc_reg_num = loc_reg_num
         self.pad_idx = self.dict[self.beg_str]
         self.start_idx = self.dict[self.beg_str]
         self.end_idx = self.dict[self.end_str]
@@ -649,7 +649,7 @@ class TableLabelEncode(AttnLabelEncode):

         # encode box
         bboxes = np.zeros(
-            (self._max_text_len, self.point_num * 2), dtype=np.float32)
+            (self._max_text_len, self.loc_reg_num), dtype=np.float32)
         bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32)

         bbox_idx = 0
@@ -714,11 +714,11 @@ class TableMasterLabelEncode(TableLabelEncode):
                  replace_empty_cell_token=False,
                  merge_no_span_structure=False,
                  learn_empty_box=False,
-                 point_num=2,
+                 loc_reg_num=4,
                  **kwargs):
         super(TableMasterLabelEncode, self).__init__(
             max_text_length, character_dict_path, replace_empty_cell_token,
-            merge_no_span_structure, learn_empty_box, point_num, **kwargs)
+            merge_no_span_structure, learn_empty_box, loc_reg_num, **kwargs)
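+        # loc_reg_num is the number of regression targets per cell box:
+        # 4 for 'xyxy'/'xywh', 8 for the four-corner 'xyxyxyxy' format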
         self.pad_idx = self.dict[self.pad_str]
         self.unknown_idx = self.dict[self.unknown_str]
@@ -739,13 +739,14 @@ class TableMasterLabelEncode(TableLabelEncode):

 class TableBoxEncode(object):
-    def __init__(self, use_xywh=False, **kwargs):
-        self.use_xywh = use_xywh
+    def __init__(self, box_format='xyxy', **kwargs):
+        assert box_format in ['xywh', 'xyxy', 'xyxyxyxy']
+        self.box_format = box_format

     def __call__(self, data):
         img_height, img_width = data['image'].shape[:2]
         bboxes = data['bboxes']
-        if self.use_xywh and bboxes.shape[1] == 4:
+        if self.box_format == 'xywh' and bboxes.shape[1] == 4:
             bboxes = self.xyxy2xywh(bboxes)
         bboxes[:, 0::2] /= img_width
         bboxes[:, 1::2] /= img_height
@@ -1217,6 +1218,7 @@ class ABINetLabelEncode(BaseRecLabelEncode):
             dict_character = [''] + dict_character
         return dict_character

+
 class SPINLabelEncode(AttnLabelEncode):
     """ Convert between text-label and text-index """

@@ -1229,6 +1231,7 @@ class SPINLabelEncode(AttnLabelEncode):
         super(SPINLabelEncode, self).__init__(
             max_text_length, character_dict_path, use_space_char)
         self.lower = lower
+
     def add_special_char(self, dict_character):
         self.beg_str = "sos"
         self.end_str = "eos"
@@ -1248,4 +1251,4 @@ class SPINLabelEncode(AttnLabelEncode):

         padded_text[:len(target)] = target
         data['label'] = np.array(padded_text)
-        return data
\ No newline at end of file
+        return data
diff --git a/ppocr/data/imaug/table_ops.py b/ppocr/data/imaug/table_ops.py
index 8d139190a..c2c2fb2be 100644
--- a/ppocr/data/imaug/table_ops.py
+++ b/ppocr/data/imaug/table_ops.py
@@ -206,7 +206,7 @@ class ResizeTableImage(object):
             data['bboxes'] = data['bboxes'] * ratio
         data['image'] = resize_img
         data['src_img'] = img
-        data['shape'] = np.array([resize_h, resize_w, ratio, ratio])
+        data['shape'] = np.array([height, width, ratio, ratio])
         data['max_len'] = self.max_len
         return data
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 30120ac56..4629f0fe4 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -51,7 +51,7 @@ from .basic_loss import DistanceLoss
 from .combined_loss import CombinedLoss

 # table loss
-from .table_att_loss import TableAttentionLoss
+from .table_att_loss import TableAttentionLoss, SLANetLoss
 from .table_master_loss import TableMasterLoss
 # vqa token loss
 from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
@@ -63,7 +63,7 @@ def build_loss(config):
         'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
         'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
         'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
-        'TableMasterLoss', 'SPINAttentionLoss'
+        'TableMasterLoss', 'SPINAttentionLoss', 'SLANetLoss'
     ]
     config = copy.deepcopy(config)
     module_name = config.pop('name')
diff --git a/ppocr/losses/table_att_loss.py b/ppocr/losses/table_att_loss.py
index 3496c9072..d97715d54 100644
--- a/ppocr/losses/table_att_loss.py
+++ b/ppocr/losses/table_att_loss.py
@@ -22,65 +22,11 @@ from paddle.nn import functional as F

 class TableAttentionLoss(nn.Layer):
-    def __init__(self,
-                 structure_weight,
-                 loc_weight,
-                 use_giou=False,
-                 giou_weight=1.0,
-                 **kwargs):
+    def __init__(self, structure_weight, loc_weight, **kwargs):
         super(TableAttentionLoss, self).__init__()
         self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
         self.structure_weight = structure_weight
         self.loc_weight = loc_weight
-        self.use_giou = use_giou
-        self.giou_weight = giou_weight
-
-    def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
-        '''
-        :param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
-        :param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
-        :return: loss
-        '''
-        ix1 = paddle.maximum(preds[:, 0], bbox[:, 0])
-        iy1 = paddle.maximum(preds[:, 1], bbox[:, 1])
-        ix2 = paddle.minimum(preds[:, 2], bbox[:, 2])
-        iy2 = paddle.minimum(preds[:, 3], bbox[:, 3])
-
-        iw = paddle.clip(ix2 - ix1 + 1e-3, 0., 1e10)
-        ih = paddle.clip(iy2 - iy1 + 1e-3, 0., 1e10)
-
-        # overlap
-        inters = iw * ih
-
-        # union
-        uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (
-            preds[:, 3] - preds[:, 1] + 1e-3) + (bbox[:, 2] - bbox[:, 0] + 1e-3
-                                                 ) * (bbox[:, 3] - bbox[:, 1] +
-                                                      1e-3) - inters + eps
-
-        # ious
-        ious = inters / uni
-
-        ex1 = paddle.minimum(preds[:, 0], bbox[:, 0])
-        ey1 = paddle.minimum(preds[:, 1], bbox[:, 1])
-        ex2 = paddle.maximum(preds[:, 2], bbox[:, 2])
-        ey2 = paddle.maximum(preds[:, 3], bbox[:, 3])
-        ew = paddle.clip(ex2 - ex1 + 1e-3, 0., 1e10)
-        eh = paddle.clip(ey2 - ey1 + 1e-3, 0., 1e10)
-
-        # enclose erea
-        enclose = ew * eh + eps
-        giou = ious - (enclose - uni) / enclose
-
-        loss = 1 - giou
-
-        if reduction == 'mean':
-            loss = paddle.mean(loss)
-        elif reduction == 'sum':
-            loss = paddle.sum(loss)
-        else:
-            raise NotImplementedError
-        return loss

     def forward(self, predicts, batch):
         structure_probs = predicts['structure_probs']
@@ -100,20 +46,48 @@ class TableAttentionLoss(nn.Layer):
         loc_targets_mask = loc_targets_mask[:, 1:, :]
         loc_loss = F.mse_loss(loc_preds * loc_targets_mask,
                               loc_targets) * self.loc_weight
-        if self.use_giou:
-            loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask,
-                                           loc_targets) * self.giou_weight
-            total_loss = structure_loss + loc_loss + loc_loss_giou
-            return {
-                'loss': total_loss,
-                "structure_loss": structure_loss,
-                "loc_loss": loc_loss,
-                "loc_loss_giou": loc_loss_giou
-            }
-        else:
-            total_loss = structure_loss + loc_loss
-            return {
-                'loss': total_loss,
-                "structure_loss": structure_loss,
-                "loc_loss": loc_loss
-            }
+
+        total_loss = structure_loss + loc_loss
+        return {
+            'loss': total_loss,
+            "structure_loss": structure_loss,
+            "loc_loss": loc_loss
+        }
+
+
+class SLANetLoss(nn.Layer):
+    def __init__(self, structure_weight, loc_weight, loc_loss='mse', **kwargs):
+        super(SLANetLoss, self).__init__()
+        self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='mean')
+        self.structure_weight = structure_weight
+        self.loc_weight = loc_weight
+        self.loc_loss = loc_loss
+        self.eps = 1e-12
+
+    def forward(self, predicts, batch):
+        structure_probs = predicts['structure_probs']
+        structure_targets = batch[1].astype("int64")
+        structure_targets = structure_targets[:, 1:]
+
+        structure_loss = self.loss_func(structure_probs, structure_targets)
+
+        structure_loss = paddle.mean(structure_loss) * self.structure_weight
+
+        loc_preds = predicts['loc_preds']
+        loc_targets = batch[2].astype("float32")
+        loc_targets_mask = batch[3].astype("float32")
+        loc_targets = loc_targets[:, 1:, :]
+        loc_targets_mask = loc_targets_mask[:, 1:, :]
+
+        loc_loss = F.smooth_l1_loss(
+            loc_preds * loc_targets_mask,
+            loc_targets * loc_targets_mask,
+            reduction='sum') * self.loc_weight
+
+        loc_loss = loc_loss / (loc_targets_mask.sum() + self.eps)
+        total_loss = structure_loss + loc_loss
+        return {
+            'loss': total_loss,
+            "structure_loss": structure_loss,
+            "loc_loss": loc_loss
+        }
diff --git a/ppocr/metrics/table_metric.py b/ppocr/metrics/table_metric.py
index fd2631e44..43dc1d761 100644
--- a/ppocr/metrics/table_metric.py
+++ b/ppocr/metrics/table_metric.py
@@ -59,7 +59,7 @@ class TableMetric(object):
     def __init__(self,
                  main_indicator='acc',
                  compute_bbox_metric=False,
-                 point_num=2,
+                 box_format='xyxy',
                  **kwargs):
         """

@@ -70,7 +70,7 @@ class TableMetric(object):
         self.structure_metric = TableStructureMetric()
         self.bbox_metric = DetMetric() if compute_bbox_metric else None
         self.main_indicator = main_indicator
-        self.point_num = point_num
+        self.box_format = box_format
         self.reset()

     def __call__(self, pred_label, batch=None, *args, **kwargs):
@@ -129,10 +129,14 @@ class TableMetric(object):
             self.bbox_metric.reset()

     def format_box(self, box):
-        if self.point_num == 2:
+        if self.box_format == 'xyxy':
             x1, y1, x2, y2 = box
             box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
-        elif self.point_num == 4:
+        elif self.box_format == 'xywh':
+            x, y, w, h = box
+            x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
+            box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
+        elif self.box_format == 'xyxyxyxy':
             x1, y1, x2, y2, x3, y3, x4, y4 = box
             box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
         return box
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index d4f5b15f5..f5d54150b 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -21,7 +21,10 @@ def build_backbone(config, model_type):
         from .det_resnet import ResNet
         from .det_resnet_vd import ResNet_vd
         from .det_resnet_vd_sast import ResNet_SAST
-        support_dict = ["MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST"]
+        from .det_pp_lcnet import PPLCNet
+        support_dict = [
+            "MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet"
+        ]
     if model_type == "table":
         from .table_master_resnet import TableResNetExtra
         support_dict.append('TableResNetExtra')
diff --git a/ppocr/modeling/backbones/det_pp_lcnet.py b/ppocr/modeling/backbones/det_pp_lcnet.py
new file mode 100644
index 000000000..3f719e92b
--- /dev/null
+++ b/ppocr/modeling/backbones/det_pp_lcnet.py
@@ -0,0 +1,271 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
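+
+# PP-LCNet is a lightweight, CPU-friendly backbone. This detection variant
+# returns the outputs of blocks3-blocks6 (strides 4/8/16/32) so that a
+# PAN/FPN-style neck such as CSPPAN can consume multi-scale features.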
+
+from __future__ import absolute_import, division, print_function
+
+import os
+import paddle
+import paddle.nn as nn
+from paddle import ParamAttr
+from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Dropout, Linear
+from paddle.regularizer import L2Decay
+from paddle.nn.initializer import KaimingNormal
+from paddle.utils.download import get_path_from_url
+
+MODEL_URLS = {
+    "PPLCNet_x0.25":
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_25_pretrained.pdparams",
+    "PPLCNet_x0.35":
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_35_pretrained.pdparams",
+    "PPLCNet_x0.5":
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_5_pretrained.pdparams",
+    "PPLCNet_x0.75":
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_75_pretrained.pdparams",
+    "PPLCNet_x1.0":
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x1_0_pretrained.pdparams",
+    "PPLCNet_x1.5":
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x1_5_pretrained.pdparams",
+    "PPLCNet_x2.0":
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_0_pretrained.pdparams",
+    "PPLCNet_x2.5":
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_5_pretrained.pdparams"
+}
+
+MODEL_STAGES_PATTERN = {
+    "PPLCNet": ["blocks2", "blocks3", "blocks4", "blocks5", "blocks6"]
+}
+
+__all__ = list(MODEL_URLS.keys())
+
+# Each element(list) represents a depthwise block, which is composed of k, in_c, out_c, s, use_se.
+# k: kernel_size
+# in_c: input channel number in depthwise block
+# out_c: output channel number in depthwise block
+# s: stride in depthwise block
+# use_se: whether to use SE block
+
+NET_CONFIG = {
+    "blocks2":
+    # k, in_c, out_c, s, use_se
+    [[3, 16, 32, 1, False]],
+    "blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
+    "blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
+    "blocks5":
+    [[3, 128, 256, 2, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False],
+     [5, 256, 256, 1, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False]],
+    "blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True]]
+}
+
+
+def make_divisible(v, divisor=8, min_value=None):
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+class ConvBNLayer(nn.Layer):
+    def __init__(self,
+                 num_channels,
+                 filter_size,
+                 num_filters,
+                 stride,
+                 num_groups=1):
+        super().__init__()
+
+        self.conv = Conv2D(
+            in_channels=num_channels,
+            out_channels=num_filters,
+            kernel_size=filter_size,
+            stride=stride,
+            padding=(filter_size - 1) // 2,
+            groups=num_groups,
+            weight_attr=ParamAttr(initializer=KaimingNormal()),
+            bias_attr=False)
+
+        self.bn = BatchNorm(
+            num_filters,
+            param_attr=ParamAttr(regularizer=L2Decay(0.0)),
+            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
+        self.hardswish = nn.Hardswish()
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.hardswish(x)
+        return x
+
+
+class DepthwiseSeparable(nn.Layer):
+    def __init__(self,
+                 num_channels,
+                 num_filters,
+                 stride,
+                 dw_size=3,
+                 use_se=False):
+        super().__init__()
+        self.use_se = use_se
+        self.dw_conv = ConvBNLayer(
+            num_channels=num_channels,
+            num_filters=num_channels,
+            filter_size=dw_size,
+            stride=stride,
+            num_groups=num_channels)
+        if use_se:
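+            # squeeze-and-excitation attention recalibrates the depthwise
+            # output channels (enabled only in the last stage, per NET_CONFIG)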
+            self.se = SEModule(num_channels)
+        self.pw_conv = ConvBNLayer(
+            num_channels=num_channels,
+            filter_size=1,
+            num_filters=num_filters,
+            stride=1)
+
+    def forward(self, x):
+        x = self.dw_conv(x)
+        if self.use_se:
+            x = self.se(x)
+        x = self.pw_conv(x)
+        return x
+
+
+class SEModule(nn.Layer):
+    def __init__(self, channel, reduction=4):
+        super().__init__()
+        self.avg_pool = AdaptiveAvgPool2D(1)
+        self.conv1 = Conv2D(
+            in_channels=channel,
+            out_channels=channel // reduction,
+            kernel_size=1,
+            stride=1,
+            padding=0)
+        self.relu = nn.ReLU()
+        self.conv2 = Conv2D(
+            in_channels=channel // reduction,
+            out_channels=channel,
+            kernel_size=1,
+            stride=1,
+            padding=0)
+        self.hardsigmoid = nn.Hardsigmoid()
+
+    def forward(self, x):
+        identity = x
+        x = self.avg_pool(x)
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.conv2(x)
+        x = self.hardsigmoid(x)
+        x = paddle.multiply(x=identity, y=x)
+        return x
+
+
+class PPLCNet(nn.Layer):
+    def __init__(self,
+                 in_channels=3,
+                 scale=1.0,
+                 pretrained=False,
+                 use_ssld=False):
+        super().__init__()
+        self.out_channels = [
+            int(NET_CONFIG["blocks3"][-1][2] * scale),
+            int(NET_CONFIG["blocks4"][-1][2] * scale),
+            int(NET_CONFIG["blocks5"][-1][2] * scale),
+            int(NET_CONFIG["blocks6"][-1][2] * scale)
+        ]
+        self.scale = scale
+
+        self.conv1 = ConvBNLayer(
+            num_channels=in_channels,
+            filter_size=3,
+            num_filters=make_divisible(16 * scale),
+            stride=2)
+
+        self.blocks2 = nn.Sequential(* [
+            DepthwiseSeparable(
+                num_channels=make_divisible(in_c * scale),
+                num_filters=make_divisible(out_c * scale),
+                dw_size=k,
+                stride=s,
+                use_se=se)
+            for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks2"])
+        ])
+
+        self.blocks3 = nn.Sequential(* [
+            DepthwiseSeparable(
+                num_channels=make_divisible(in_c * scale),
+                num_filters=make_divisible(out_c * scale),
+                dw_size=k,
+                stride=s,
+                use_se=se)
+            for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks3"])
+        ])
+
+        self.blocks4 = nn.Sequential(* [
+            DepthwiseSeparable(
+                num_channels=make_divisible(in_c * scale),
+                num_filters=make_divisible(out_c * scale),
+                dw_size=k,
+                stride=s,
+                use_se=se)
+            for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks4"])
+        ])
+
+        self.blocks5 = nn.Sequential(* [
+            DepthwiseSeparable(
+                num_channels=make_divisible(in_c * scale),
+                num_filters=make_divisible(out_c * scale),
+                dw_size=k,
+                stride=s,
+                use_se=se)
+            for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks5"])
+        ])
+
+        self.blocks6 = nn.Sequential(* [
+            DepthwiseSeparable(
+                num_channels=make_divisible(in_c * scale),
+                num_filters=make_divisible(out_c * scale),
+                dw_size=k,
+                stride=s,
+                use_se=se)
+            for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks6"])
+        ])
+
+        if pretrained:
+            self._load_pretrained(
+                MODEL_URLS['PPLCNet_x{}'.format(scale)], use_ssld=use_ssld)
+
+    def forward(self, x):
+        outs = []
+        x = self.conv1(x)
+        x = self.blocks2(x)
+        x = self.blocks3(x)
+        outs.append(x)
+        x = self.blocks4(x)
+        outs.append(x)
+        x = self.blocks5(x)
+        outs.append(x)
+        x = self.blocks6(x)
+        outs.append(x)
+        return outs
+
+    def _load_pretrained(self, pretrained_url, use_ssld=False):
+        if use_ssld:
+            pretrained_url = pretrained_url.replace("_pretrained",
+                                                    "_ssld_pretrained")
+        print(pretrained_url)
+        local_weight_path = get_path_from_url(
+            pretrained_url, os.path.expanduser("~/.paddleclas/weights"))
+        param_state_dict = paddle.load(local_weight_path)
+        self.set_dict(param_state_dict)
+        return
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index b4f18b372..d8289d458 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -42,14 +42,15 @@ def build_head(config):
     #kie head
     from .kie_sdmgr_head import SDMGRHead

-    from .table_att_head import TableAttentionHead
+    from .table_att_head import TableAttentionHead, SLAHead
     from .table_master_head import TableMasterHead

     support_dict = [
         'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
         'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
         'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
-        'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead'
+        'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
+        'SLAHead'
     ]

     #table head
diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py
index 4f39d6253..00b434105 100644
--- a/ppocr/modeling/heads/table_att_head.py
+++ b/ppocr/modeling/heads/table_att_head.py
@@ -18,12 +18,26 @@ from __future__ import print_function

+import math
 import paddle
 import paddle.nn as nn
+from paddle import ParamAttr
 import paddle.nn.functional as F
 import numpy as np

 from .rec_att_head import AttentionGRUCell


+def get_para_bias_attr(l2_decay, k):
+    if l2_decay > 0:
+        regularizer = paddle.regularizer.L2Decay(l2_decay)
+        stdv = 1.0 / math.sqrt(k * 1.0)
+        initializer = nn.initializer.Uniform(-stdv, stdv)
+    else:
+        regularizer = None
+        initializer = None
+    weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
+    bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
+    return [weight_attr, bias_attr]
+
+
 class TableAttentionHead(nn.Layer):
     def __init__(self,
                  in_channels,
@@ -32,7 +46,7 @@ class TableAttentionHead(nn.Layer):
                  in_max_len=488,
                  max_text_length=800,
                  out_channels=30,
-                 point_num=2,
+                 loc_reg_num=4,
                  **kwargs):
         super(TableAttentionHead, self).__init__()
         self.input_size = in_channels[-1]
@@ -56,7 +70,7 @@ class TableAttentionHead(nn.Layer):
         else:
             self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
         self.loc_generator = nn.Linear(self.input_size + hidden_size,
-                                       point_num * 2)
+                                       loc_reg_num)

     def _char_to_onehot(self, input_char, onehot_dim):
         input_ont_hot = F.one_hot(input_char, onehot_dim)
@@ -129,3 +143,121 @@ class TableAttentionHead(nn.Layer):
         loc_preds = self.loc_generator(loc_concat)
         loc_preds = F.sigmoid(loc_preds)
         return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
+
+
+class SLAHead(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 hidden_size,
+                 out_channels=30,
+                 max_text_length=500,
+                 loc_reg_num=4,
+                 fc_decay=0.0,
+                 **kwargs):
+        """
+        @param in_channels: input shape
+        @param hidden_size: hidden_size for RNN and Embedding
+        @param out_channels: num_classes to rec
+        @param max_text_length: max text pred
+        """
+        super().__init__()
+        in_channels = in_channels[-1]
+        self.hidden_size = hidden_size
+        self.max_text_length = max_text_length
+        self.emb = self._char_to_onehot
+        self.num_embeddings = out_channels
+
+        # structure
+        self.structure_attention_cell = AttentionGRUCell(
+            in_channels, hidden_size, self.num_embeddings)
+        weight_attr, bias_attr = get_para_bias_attr(
+            l2_decay=fc_decay, k=hidden_size)
+        weight_attr1_1, bias_attr1_1 = get_para_bias_attr(
+            l2_decay=fc_decay, k=hidden_size)
+        weight_attr1_2, bias_attr1_2 = get_para_bias_attr(
+            l2_decay=fc_decay, k=hidden_size)
+        self.structure_generator = nn.Sequential(
+            nn.Linear(
+                self.hidden_size,
+                self.hidden_size,
+                weight_attr=weight_attr1_2,
+                bias_attr=bias_attr1_2),
+            nn.Linear(
+                hidden_size,
+                out_channels,
+                weight_attr=weight_attr,
+                bias_attr=bias_attr))
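+        # the shared GRU output feeds two small heads: the stack above emits
+        # structure-token logits, the stack below regresses cell coordinates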
+        # loc
+        weight_attr1, bias_attr1 = get_para_bias_attr(
+            l2_decay=fc_decay, k=self.hidden_size)
+        weight_attr2, bias_attr2 = get_para_bias_attr(
+            l2_decay=fc_decay, k=self.hidden_size)
+        self.loc_generator = nn.Sequential(
+            nn.Linear(
+                self.hidden_size,
+                self.hidden_size,
+                weight_attr=weight_attr1,
+                bias_attr=bias_attr1),
+            nn.Linear(
+                self.hidden_size,
+                loc_reg_num,
+                weight_attr=weight_attr2,
+                bias_attr=bias_attr2),
+            nn.Sigmoid())
+
+    def forward(self, inputs, targets=None):
+        fea = inputs[-1]
+        batch_size = fea.shape[0]
+        # reshape
+        fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], -1])
+        fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
+
+        hidden = paddle.zeros((batch_size, self.hidden_size))
+        structure_preds = []
+        loc_preds = []
+        if self.training and targets is not None:
+            structure = targets[0]
+            for i in range(self.max_text_length + 1):
+                hidden, structure_step, loc_step = self._decode(structure[:, i],
+                                                                fea, hidden)
+                structure_preds.append(structure_step)
+                loc_preds.append(loc_step)
+        else:
+            pre_chars = paddle.zeros(shape=[batch_size], dtype="int32")
+            max_text_length = paddle.to_tensor(self.max_text_length)
+            # for export
+            loc_step, structure_step = None, None
+            for i in range(max_text_length + 1):
+                hidden, structure_step, loc_step = self._decode(pre_chars, fea,
+                                                                hidden)
+                pre_chars = structure_step.argmax(axis=1, dtype="int32")
+                structure_preds.append(structure_step)
+                loc_preds.append(loc_step)
+        structure_preds = paddle.stack(structure_preds, axis=1)
+        loc_preds = paddle.stack(loc_preds, axis=1)
+        if not self.training:
+            structure_preds = F.softmax(structure_preds)
+        return {'structure_probs': structure_preds, 'loc_preds': loc_preds}
+
+    def _decode(self, pre_chars, features, hidden):
+        """
+        Predict table label and coordinates for each step
+        @param pre_chars: Table label in previous step
+        @param features:
+        @param hidden: hidden status in previous step
+        @return:
+        """
+        emb_feature = self.emb(pre_chars)
+        # output shape is b * self.hidden_size
+        (output, hidden), alpha = self.structure_attention_cell(
+            hidden, features, emb_feature)
+
+        # structure
+        structure_step = self.structure_generator(output)
+        # loc
+        loc_step = self.loc_generator(output)
+        return hidden, structure_step, loc_step
+
+    def _char_to_onehot(self, input_char):
+        input_ont_hot = F.one_hot(input_char, self.num_embeddings)
+        return input_ont_hot
diff --git a/ppocr/modeling/heads/table_master_head.py b/ppocr/modeling/heads/table_master_head.py
index fddbcc63f..486f9cbea 100644
--- a/ppocr/modeling/heads/table_master_head.py
+++ b/ppocr/modeling/heads/table_master_head.py
@@ -37,7 +37,7 @@ class TableMasterHead(nn.Layer):
                  d_ff=2048,
                  dropout=0,
                  max_text_length=500,
-                 point_num=2,
+                 loc_reg_num=4,
                  **kwargs):
         super(TableMasterHead, self).__init__()
         hidden_size = in_channels[-1]
@@ -50,7 +50,7 @@ class TableMasterHead(nn.Layer):
         self.cls_fc = nn.Linear(hidden_size, out_channels)
         self.bbox_fc = nn.Sequential(
             # nn.Linear(hidden_size, hidden_size),
-            nn.Linear(hidden_size, point_num * 2),
+            nn.Linear(hidden_size, loc_reg_num),
             nn.Sigmoid())
         self.norm = nn.LayerNorm(hidden_size)
         self.embedding = Embeddings(d_model=hidden_size, vocab=out_channels)
@@ -59,7 +59,7 @@ class TableMasterHead(nn.Layer):
         self.SOS = out_channels - 3
         self.PAD = out_channels - 1
         self.out_channels = out_channels
-        self.point_num = point_num
+        self.loc_reg_num = loc_reg_num
         self.max_text_length = max_text_length

     def make_mask(self, tgt):
@@ -105,7 +105,7 @@ class TableMasterHead(nn.Layer):
         output = paddle.zeros(
             [input.shape[0], self.max_text_length + 1, self.out_channels])
         bbox_output = paddle.zeros(
-            [input.shape[0], self.max_text_length + 1, self.point_num * 2])
+            [input.shape[0], self.max_text_length + 1, self.loc_reg_num])
         max_text_length = paddle.to_tensor(self.max_text_length)
         for i in range(max_text_length + 1):
             target_mask = self.make_mask(input)
diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py
index e10b082d1..e3ae2d6ef 100644
--- a/ppocr/modeling/necks/__init__.py
+++ b/ppocr/modeling/necks/__init__.py
@@ -25,9 +25,10 @@ def build_neck(config):
     from .fpn import FPN
     from .fce_fpn import FCEFPN
     from .pren_fpn import PRENFPN
+    from .csp_pan import CSPPAN
     support_dict = [
        'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
-        'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN'
+        'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN'
     ]

     module_name = config.pop('name')
diff --git a/ppocr/modeling/necks/csp_pan.py b/ppocr/modeling/necks/csp_pan.py
new file mode 100755
index 000000000..625508e99
--- /dev/null
+++ b/ppocr/modeling/necks/csp_pan.py
@@ -0,0 +1,325 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# The code is based on:
+# https://github.com/PaddlePaddle/PaddleDetection/blob/release%2F2.3/ppdet/modeling/necks/csp_pan.py
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+
+__all__ = ['CSPPAN']
+
+
+class ConvBNLayer(nn.Layer):
+    def __init__(self,
+                 in_channel=96,
+                 out_channel=96,
+                 kernel_size=3,
+                 stride=1,
+                 groups=1,
+                 act='leaky_relu'):
+        super(ConvBNLayer, self).__init__()
+        initializer = nn.initializer.KaimingUniform()
+        self.act = act
+        assert self.act in ['leaky_relu', "hard_swish"]
+        self.conv = nn.Conv2D(
+            in_channels=in_channel,
+            out_channels=out_channel,
+            kernel_size=kernel_size,
+            groups=groups,
+            padding=(kernel_size - 1) // 2,
+            stride=stride,
+            weight_attr=ParamAttr(initializer=initializer),
+            bias_attr=False)
+        self.bn = nn.BatchNorm2D(out_channel)
+
+    def forward(self, x):
+        x = self.bn(self.conv(x))
+        if self.act == "leaky_relu":
+            x = F.leaky_relu(x)
+        elif self.act == "hard_swish":
+            x = F.hardswish(x)
+        return x
+
+
+class DPModule(nn.Layer):
+    """
+    Depth-wise and point-wise module.
+    Args:
+        in_channel (int): The input channels of this Module.
+        out_channel (int): The output channels of this Module.
+        kernel_size (int): The conv2d kernel size of this Module.
+        stride (int): The conv2d's stride of this Module.
+        act (str): The activation function of this Module,
+                   Now support `leaky_relu` and `hard_swish`.
+ """ + + def __init__(self, + in_channel=96, + out_channel=96, + kernel_size=3, + stride=1, + act='leaky_relu'): + super(DPModule, self).__init__() + initializer = nn.initializer.KaimingUniform() + self.act = act + self.dwconv = nn.Conv2D( + in_channels=in_channel, + out_channels=out_channel, + kernel_size=kernel_size, + groups=out_channel, + padding=(kernel_size - 1) // 2, + stride=stride, + weight_attr=ParamAttr(initializer=initializer), + bias_attr=False) + self.bn1 = nn.BatchNorm2D(out_channel) + self.pwconv = nn.Conv2D( + in_channels=out_channel, + out_channels=out_channel, + kernel_size=1, + groups=1, + padding=0, + weight_attr=ParamAttr(initializer=initializer), + bias_attr=False) + self.bn2 = nn.BatchNorm2D(out_channel) + + def act_func(self, x): + if self.act == "leaky_relu": + x = F.leaky_relu(x) + elif self.act == "hard_swish": + x = F.hardswish(x) + return x + + def forward(self, x): + x = self.act_func(self.bn1(self.dwconv(x))) + x = self.act_func(self.bn2(self.pwconv(x))) + return x + + +class DarknetBottleneck(nn.Layer): + """The basic bottleneck block used in Darknet. + Each Block consists of two ConvModules and the input is added to the + final output. Each ConvModule is composed of Conv, BN, and act. + The first convLayer has filter size of 1x1 and the second one has the + filter size of 3x3. + Args: + in_channels (int): The input channels of this Module. + out_channels (int): The output channels of this Module. + expansion (int): The kernel size of the convolution. Default: 0.5 + add_identity (bool): Whether to add identity to the out. + Default: True + use_depthwise (bool): Whether to use depthwise separable convolution. + Default: False + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + expansion=0.5, + add_identity=True, + use_depthwise=False, + act="leaky_relu"): + super(DarknetBottleneck, self).__init__() + hidden_channels = int(out_channels * expansion) + conv_func = DPModule if use_depthwise else ConvBNLayer + self.conv1 = ConvBNLayer( + in_channel=in_channels, + out_channel=hidden_channels, + kernel_size=1, + act=act) + self.conv2 = conv_func( + in_channel=hidden_channels, + out_channel=out_channels, + kernel_size=kernel_size, + stride=1, + act=act) + self.add_identity = \ + add_identity and in_channels == out_channels + + def forward(self, x): + identity = x + out = self.conv1(x) + out = self.conv2(out) + + if self.add_identity: + return out + identity + else: + return out + + +class CSPLayer(nn.Layer): + """Cross Stage Partial Layer. + Args: + in_channels (int): The input channels of the CSP layer. + out_channels (int): The output channels of the CSP layer. + expand_ratio (float): Ratio to adjust the number of channels of the + hidden layer. Default: 0.5 + num_blocks (int): Number of blocks. Default: 1 + add_identity (bool): Whether to add identity in blocks. + Default: True + use_depthwise (bool): Whether to depthwise separable convolution in + blocks. 
+            Default: False
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size=3,
+                 expand_ratio=0.5,
+                 num_blocks=1,
+                 add_identity=True,
+                 use_depthwise=False,
+                 act="leaky_relu"):
+        super().__init__()
+        mid_channels = int(out_channels * expand_ratio)
+        self.main_conv = ConvBNLayer(in_channels, mid_channels, 1, act=act)
+        self.short_conv = ConvBNLayer(in_channels, mid_channels, 1, act=act)
+        self.final_conv = ConvBNLayer(
+            2 * mid_channels, out_channels, 1, act=act)
+
+        self.blocks = nn.Sequential(* [
+            DarknetBottleneck(
+                mid_channels,
+                mid_channels,
+                kernel_size,
+                1.0,
+                add_identity,
+                use_depthwise,
+                act=act) for _ in range(num_blocks)
+        ])
+
+    def forward(self, x):
+        x_short = self.short_conv(x)
+
+        x_main = self.main_conv(x)
+        x_main = self.blocks(x_main)
+
+        x_final = paddle.concat((x_main, x_short), axis=1)
+        return self.final_conv(x_final)
+
+
+class Channel_T(nn.Layer):
+    def __init__(self,
+                 in_channels=[116, 232, 464],
+                 out_channels=96,
+                 act="leaky_relu"):
+        super(Channel_T, self).__init__()
+        self.convs = nn.LayerList()
+        for i in range(len(in_channels)):
+            self.convs.append(
+                ConvBNLayer(
+                    in_channels[i], out_channels, 1, act=act))
+
+    def forward(self, x):
+        outs = [self.convs[i](x[i]) for i in range(len(x))]
+        return outs
+
+
+class CSPPAN(nn.Layer):
+    """Path Aggregation Network with CSP module.
+    Args:
+        in_channels (List[int]): Number of input channels per scale.
+        out_channels (int): Number of output channels (used at each scale)
+        kernel_size (int): The conv2d kernel size of this Module.
+        num_csp_blocks (int): Number of bottlenecks in CSPLayer. Default: 1
+        use_depthwise (bool): Whether to use depthwise separable convolution
+            in blocks. Default: True
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size=5,
+                 num_csp_blocks=1,
+                 use_depthwise=True,
+                 act='hard_swish'):
+        super(CSPPAN, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = [out_channels] * len(in_channels)
+        conv_func = DPModule if use_depthwise else ConvBNLayer
+
+        self.conv_t = Channel_T(in_channels, out_channels, act=act)
+
+        # build top-down blocks
+        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
+        self.top_down_blocks = nn.LayerList()
+        for idx in range(len(in_channels) - 1, 0, -1):
+            self.top_down_blocks.append(
+                CSPLayer(
+                    out_channels * 2,
+                    out_channels,
+                    kernel_size=kernel_size,
+                    num_blocks=num_csp_blocks,
+                    add_identity=False,
+                    use_depthwise=use_depthwise,
+                    act=act))
+
+        # build bottom-up blocks
+        self.downsamples = nn.LayerList()
+        self.bottom_up_blocks = nn.LayerList()
+        for idx in range(len(in_channels) - 1):
+            self.downsamples.append(
+                conv_func(
+                    out_channels,
+                    out_channels,
+                    kernel_size=kernel_size,
+                    stride=2,
+                    act=act))
+            self.bottom_up_blocks.append(
+                CSPLayer(
+                    out_channels * 2,
+                    out_channels,
+                    kernel_size=kernel_size,
+                    num_blocks=num_csp_blocks,
+                    add_identity=False,
+                    use_depthwise=use_depthwise,
+                    act=act))
+
+    def forward(self, inputs):
+        """
+        Args:
+            inputs (tuple[Tensor]): input features.
+        Returns:
+            tuple[Tensor]: CSPPAN features.
+        """
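+        # two passes over the pyramid: top-down (upsample deeper features and
+        # fuse with shallower ones), then bottom-up (stride-2 conv back down),
+        # with every fusion going through a CSP block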
+ """ + assert len(inputs) == len(self.in_channels) + inputs = self.conv_t(inputs) + + # top-down path + inner_outs = [inputs[-1]] + for idx in range(len(self.in_channels) - 1, 0, -1): + feat_heigh = inner_outs[0] + feat_low = inputs[idx - 1] + + upsample_feat = F.upsample( + feat_heigh, size=feat_low.shape[2:4], mode="nearest") + + inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx]( + paddle.concat([upsample_feat, feat_low], 1)) + inner_outs.insert(0, inner_out) + + # bottom-up path + outs = [inner_outs[0]] + for idx in range(len(self.in_channels) - 1): + feat_low = outs[-1] + feat_height = inner_outs[idx + 1] + downsample_feat = self.downsamples[idx](feat_low) + out = self.bottom_up_blocks[idx](paddle.concat( + [downsample_feat, feat_height], 1)) + outs.append(out) + + return tuple(outs) diff --git a/ppocr/postprocess/table_postprocess.py b/ppocr/postprocess/table_postprocess.py index 4396ec4f7..ce254f314 100644 --- a/ppocr/postprocess/table_postprocess.py +++ b/ppocr/postprocess/table_postprocess.py @@ -23,7 +23,7 @@ class TableLabelDecode(AttnLabelDecode): def __init__(self, character_dict_path, **kwargs): super(TableLabelDecode, self).__init__(character_dict_path) - self.td_token = ['', '', ''] + self.td_token = ['', ''] def __call__(self, preds, batch=None): structure_probs = preds['structure_probs'] @@ -114,10 +114,8 @@ class TableLabelDecode(AttnLabelDecode): def _bbox_decode(self, bbox, shape): h, w, ratio_h, ratio_w, pad_h, pad_w = shape - src_h = h / ratio_h - src_w = w / ratio_w - bbox[0::2] *= src_w - bbox[1::2] *= src_h + bbox[0::2] *= w + bbox[1::2] *= h return bbox @@ -157,4 +155,7 @@ class TableMasterLabelDecode(TableLabelDecode): bbox[1::2] *= h bbox[0::2] /= ratio_w bbox[1::2] /= ratio_h + x, y, w, h = bbox + x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2 + bbox = np.array([x1, y1, x2, y2]) return bbox diff --git a/ppocr/utils/visual.py b/ppocr/utils/visual.py index e0fbf06ab..030d1c38d 100644 --- a/ppocr/utils/visual.py +++ b/ppocr/utils/visual.py @@ -113,14 +113,10 @@ def draw_re_results(image, return np.array(img_new) -def draw_rectangle(img_path, boxes, use_xywh=False): +def draw_rectangle(img_path, boxes): img = cv2.imread(img_path) img_show = img.copy() for box in boxes.astype(int): - if use_xywh: - x, y, w, h = box - x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2 - else: - x1, y1, x2, y2 = box + x1, y1, x2, y2 = box cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2) return img_show \ No newline at end of file diff --git a/ppstructure/layout/picodet_postprocess.py b/ppstructure/layout/picodet_postprocess.py new file mode 100644 index 000000000..7df13f827 --- /dev/null +++ b/ppstructure/layout/picodet_postprocess.py @@ -0,0 +1,227 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import numpy as np
+from scipy.special import softmax
+
+
+def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
+    """
+    Args:
+        box_scores (N, 5): boxes in corner-form and probabilities.
+        iou_threshold: intersection over union threshold.
+        top_k: keep top_k results. If k <= 0, keep all the results.
+        candidate_size: only consider the candidates with the highest scores.
+    Returns:
+        picked: a list of indexes of the kept boxes
+    """
+    scores = box_scores[:, -1]
+    boxes = box_scores[:, :-1]
+    picked = []
+    indexes = np.argsort(scores)
+    indexes = indexes[-candidate_size:]
+    while len(indexes) > 0:
+        current = indexes[-1]
+        picked.append(current)
+        if 0 < top_k == len(picked) or len(indexes) == 1:
+            break
+        current_box = boxes[current, :]
+        indexes = indexes[:-1]
+        rest_boxes = boxes[indexes, :]
+        iou = iou_of(
+            rest_boxes,
+            np.expand_dims(
+                current_box, axis=0), )
+        indexes = indexes[iou <= iou_threshold]
+
+    return box_scores[picked, :]
+
+
+def iou_of(boxes0, boxes1, eps=1e-5):
+    """Return intersection-over-union (Jaccard index) of boxes.
+    Args:
+        boxes0 (N, 4): ground truth boxes.
+        boxes1 (N or 1, 4): predicted boxes.
+        eps: a small number to avoid 0 as denominator.
+    Returns:
+        iou (N): IoU values.
+    """
+    overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
+    overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])
+
+    overlap_area = area_of(overlap_left_top, overlap_right_bottom)
+    area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
+    area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
+    return overlap_area / (area0 + area1 - overlap_area + eps)
+
+
+def area_of(left_top, right_bottom):
+    """Compute the areas of rectangles given two corners.
+    Args:
+        left_top (N, 2): left top corner.
+        right_bottom (N, 2): right bottom corner.
+    Returns:
+        area (N): return the area.
+    """
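+    # negative extents (non-overlapping boxes) clip to zero width/height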
+ """ + hw = np.clip(right_bottom - left_top, 0.0, None) + return hw[..., 0] * hw[..., 1] + + +class PicoDetPostProcess(object): + """ + Args: + input_shape (int): network input image size + ori_shape (int): ori image shape of before padding + scale_factor (float): scale factor of ori image + enable_mkldnn (bool): whether to open MKLDNN + """ + + def __init__(self, + input_shape, + ori_shape, + scale_factor, + strides=[8, 16, 32, 64], + score_threshold=0.4, + nms_threshold=0.5, + nms_top_k=1000, + keep_top_k=100): + self.ori_shape = ori_shape + self.input_shape = input_shape + self.scale_factor = scale_factor + self.strides = strides + self.score_threshold = score_threshold + self.nms_threshold = nms_threshold + self.nms_top_k = nms_top_k + self.keep_top_k = keep_top_k + + def warp_boxes(self, boxes, ori_shape): + """Apply transform to boxes + """ + width, height = ori_shape[1], ori_shape[0] + n = len(boxes) + if n: + # warp points + xy = np.ones((n * 4, 3)) + xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape( + n * 4, 2) # x1y1, x2y2, x1y2, x2y1 + # xy = xy @ M.T # transform + xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) # rescale + # create new boxes + x = xy[:, [0, 2, 4, 6]] + y = xy[:, [1, 3, 5, 7]] + xy = np.concatenate( + (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T + # clip boxes + xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width) + xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height) + return xy.astype(np.float32) + else: + return boxes + + def __call__(self, scores, raw_boxes): + batch_size = raw_boxes[0].shape[0] + reg_max = int(raw_boxes[0].shape[-1] / 4 - 1) + out_boxes_num = [] + out_boxes_list = [] + for batch_id in range(batch_size): + # generate centers + decode_boxes = [] + select_scores = [] + for stride, box_distribute, score in zip(self.strides, raw_boxes, + scores): + box_distribute = box_distribute[batch_id] + score = score[batch_id] + # centers + fm_h = self.input_shape[0] / stride + fm_w = self.input_shape[1] / stride + h_range = np.arange(fm_h) + w_range = np.arange(fm_w) + ww, hh = np.meshgrid(w_range, h_range) + ct_row = (hh.flatten() + 0.5) * stride + ct_col = (ww.flatten() + 0.5) * stride + center = np.stack((ct_col, ct_row, ct_col, ct_row), axis=1) + + # box distribution to distance + reg_range = np.arange(reg_max + 1) + box_distance = box_distribute.reshape((-1, reg_max + 1)) + box_distance = softmax(box_distance, axis=1) + box_distance = box_distance * np.expand_dims(reg_range, axis=0) + box_distance = np.sum(box_distance, axis=1).reshape((-1, 4)) + box_distance = box_distance * stride + + # top K candidate + topk_idx = np.argsort(score.max(axis=1))[::-1] + topk_idx = topk_idx[:self.nms_top_k] + center = center[topk_idx] + score = score[topk_idx] + box_distance = box_distance[topk_idx] + + # decode box + decode_box = center + [-1, -1, 1, 1] * box_distance + + select_scores.append(score) + decode_boxes.append(decode_box) + + # nms + bboxes = np.concatenate(decode_boxes, axis=0) + confidences = np.concatenate(select_scores, axis=0) + picked_box_probs = [] + picked_labels = [] + for class_index in range(0, confidences.shape[1]): + probs = confidences[:, class_index] + mask = probs > self.score_threshold + probs = probs[mask] + if probs.shape[0] == 0: + continue + subset_boxes = bboxes[mask, :] + box_probs = np.concatenate( + [subset_boxes, probs.reshape(-1, 1)], axis=1) + box_probs = hard_nms( + box_probs, + iou_threshold=self.nms_threshold, + top_k=self.keep_top_k, ) + picked_box_probs.append(box_probs) + picked_labels.extend([class_index] * 
+            if len(picked_box_probs) == 0:
+                out_boxes_list.append(np.empty((0, 4)))
+                out_boxes_num.append(0)
+
+            else:
+                picked_box_probs = np.concatenate(picked_box_probs)
+
+                # resize output boxes
+                picked_box_probs[:, :4] = self.warp_boxes(
+                    picked_box_probs[:, :4], self.ori_shape[batch_id])
+                im_scale = np.concatenate([
+                    self.scale_factor[batch_id][::-1],
+                    self.scale_factor[batch_id][::-1]
+                ])
+                picked_box_probs[:, :4] /= im_scale
+                # clas score box
+                out_boxes_list.append(
+                    np.concatenate(
+                        [
+                            np.expand_dims(
+                                np.array(picked_labels),
+                                axis=-1), np.expand_dims(
+                                    picked_box_probs[:, 4], axis=-1),
+                            picked_box_probs[:, :4]
+                        ],
+                        axis=1))
+                out_boxes_num.append(len(picked_labels))
+
+        out_boxes_list = np.concatenate(out_boxes_list, axis=0)
+        out_boxes_num = np.asarray(out_boxes_num).astype(np.int32)
+        return out_boxes_list, out_boxes_num
diff --git a/ppstructure/layout/predict_layout.py b/ppstructure/layout/predict_layout.py
new file mode 100644
index 000000000..2fb4b4623
--- /dev/null
+++ b/ppstructure/layout/predict_layout.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import numpy as np
+import time
+
+import tools.infer.utility as utility
+from ppocr.data import create_operators, transform
+from ppocr.utils.logging import get_logger
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+from ppstructure.utility import parse_args
+from picodet_postprocess import PicoDetPostProcess
+
+logger = get_logger()
+
+
+class LayoutPredictor(object):
+    def __init__(self, args):
+        pre_process_list = [{
+            'Resize': {
+                'size': [800, 608]
+            }
+        }, {
+            'NormalizeImage': {
+                'std': [0.229, 0.224, 0.225],
+                'mean': [0.485, 0.456, 0.406],
+                'scale': '1./255.',
+                'order': 'hwc'
+            }
+        }, {
+            'ToCHWImage': None
+        }, {
+            'KeepKeys': {
+                'keep_keys': ['image']
+            }
+        }]
+
+        self.preprocess_op = create_operators(pre_process_list)
+        self.predictor, self.input_tensor, self.output_tensors, self.config = \
+            utility.create_predictor(args, 'layout', logger)
+
+    def __call__(self, img):
+        ori_im = img.copy()
+        data = {'image': img}
+        data = transform(data, self.preprocess_op)
+        img = data[0]
+
+        if img is None:
+            return None, 0
+
+        img = np.expand_dims(img, axis=0)
+        img = img.copy()
+
+        starttime = time.time()
+
+        self.input_tensor.copy_from_cpu(img)
+        self.predictor.run()
+
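+        # the exported PicoDet model emits 2*N tensors: the first N are the
+        # per-stride classification score maps, the last N the box
+        # distribution maps consumed by PicoDetPostProcess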
+        np_score_list, np_boxes_list = [], []
+        output_names = self.predictor.get_output_names()
+        num_outs = int(len(output_names) / 2)
+        for out_idx in range(num_outs):
+            np_score_list.append(
+                self.predictor.get_output_handle(output_names[out_idx])
+                .copy_to_cpu())
+            np_boxes_list.append(
+                self.predictor.get_output_handle(output_names[
+                    out_idx + num_outs]).copy_to_cpu())
+        postprocessor = PicoDetPostProcess(
+            (800, 608), [[800., 608.]],
+            np.array([[1.010101, 0.99346405]]),
+            strides=[8, 16, 32, 64],
+            nms_threshold=0.5)
+        np_boxes, np_boxes_num = postprocessor(np_score_list, np_boxes_list)
+        result = dict(boxes=np_boxes, boxes_num=np_boxes_num)
+        im_bboxes_num = result['boxes_num'][0]
+
+        bboxs = result['boxes'][0:0 + im_bboxes_num, :]
+        threshold = 0.5
+        expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
+        np_boxes = np_boxes[expect_boxes, :]
+        preds = []
+
+        id2label = {1: 'text', 2: 'title', 3: 'list', 4: 'table', 5: 'figure'}
+        for dt in np_boxes:
+            clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
+            label = id2label[clsid + 1]
+            result_di = {'bbox': bbox, 'label': label}
+            preds.append(result_di)
+
+        elapse = time.time() - starttime
+        return preds, elapse
+
+
+def main(args):
+    image_file_list = get_image_file_list(args.image_dir)
+    layout_predictor = LayoutPredictor(args)
+    count = 0
+    total_time = 0
+
+    for image_file in image_file_list:
+        img, flag = check_and_read_gif(image_file)
+        if not flag:
+            img = cv2.imread(image_file)
+        if img is None:
+            logger.info("error in loading image:{}".format(image_file))
+            continue
+        layout_res, elapse = layout_predictor(img)
+
+        logger.info("result: {}".format(layout_res))
+
+        if count > 0:
+            total_time += elapse
+        count += 1
+        logger.info("Predict time of {}: {}".format(image_file, elapse))
+
+
+if __name__ == "__main__":
+    main(parse_args())
diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py
index d6f2e2424..075d91446 100644
--- a/ppstructure/predict_system.py
+++ b/ppstructure/predict_system.py
@@ -18,7 +18,7 @@ import subprocess

 __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../')))

 os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
 import cv2
@@ -32,6 +32,7 @@ from attrdict import AttrDict
 from ppocr.utils.utility import get_image_file_list, check_and_read_gif
 from ppocr.utils.logging import get_logger
 from tools.infer.predict_system import TextSystem
+from ppstructure.layout.predict_layout import LayoutPredictor
 from ppstructure.table.predict_table import TableSystem, to_excel
 from ppstructure.utility import parse_args, draw_structure_result
 from ppstructure.recovery.recovery_to_doc import convert_info_docx
@@ -51,28 +52,14 @@ class StructureSystem(object):
                     "When args.layout is false, args.ocr is automatically set to false"
                 )
             args.drop_score = 0
-            # init layout and ocr model
+            # init model
+            self.layout_predictor = None
             self.text_system = None
+            self.table_system = None
             if args.layout:
-                import layoutparser as lp
-                config_path = None
-                model_path = None
-                if os.path.isdir(args.layout_path_model):
-                    model_path = args.layout_path_model
-                else:
-                    config_path = args.layout_path_model
lp.PaddleDetectionLayoutModel( - config_path=config_path, - model_path=model_path, - label_map=args.layout_label_map, - threshold=0.5, - enable_mkldnn=args.enable_mkldnn, - enforce_cpu=not args.use_gpu, - thread_num=args.cpu_threads) + self.layout_predictor = LayoutPredictor(args) if args.ocr: self.text_system = TextSystem(args) - else: - self.table_layout = None if args.table: if self.text_system is not None: self.table_system = TableSystem( @@ -80,38 +67,59 @@ class StructureSystem(object): self.text_system.text_recognizer) else: self.table_system = TableSystem(args) - else: - self.table_system = None elif self.mode == 'vqa': raise NotImplementedError def __call__(self, img, return_ocr_result_in_table=False): + time_dict = { + 'layout': 0, + 'table': 0, + 'table_match': 0, + 'det': 0, + 'rec': 0, + 'vqa': 0, + 'all': 0 + } + start = time.time() if self.mode == 'structure': ori_im = img.copy() - if self.table_layout is not None: - layout_res = self.table_layout.detect(img[..., ::-1]) + if self.layout_predictor is not None: + layout_res, elapse = self.layout_predictor(img) + time_dict['layout'] += elapse else: h, w = ori_im.shape[:2] - layout_res = [AttrDict(coordinates=[0, 0, w, h], type='Table')] + layout_res = [dict(bbox=None, label='table')] res_list = [] for region in layout_res: res = '' - x1, y1, x2, y2 = region.coordinates - x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) - roi_img = ori_im[y1:y2, x1:x2, :] - if region.type == 'Table': + if region['bbox'] is not None: + x1, y1, x2, y2 = region['bbox'] + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + roi_img = ori_im[y1:y2, x1:x2, :] + else: + x1, y1, x2, y2 = 0, 0, w, h + roi_img = ori_im + if region['label'] == 'table': if self.table_system is not None: - res = self.table_system(roi_img, - return_ocr_result_in_table) + res, table_time_dict = self.table_system( + roi_img, return_ocr_result_in_table) + time_dict['table'] += table_time_dict['table'] + time_dict['table_match'] += table_time_dict['match'] + time_dict['det'] += table_time_dict['det'] + time_dict['rec'] += table_time_dict['rec'] else: if self.text_system is not None: if args.recovery: wht_im = np.ones(ori_im.shape, dtype=ori_im.dtype) wht_im[y1:y2, x1:x2, :] = roi_img - filter_boxes, filter_rec_res = self.text_system(wht_im) + filter_boxes, filter_rec_res, ocr_time_dict = self.text_system( + wht_im) else: - filter_boxes, filter_rec_res = self.text_system(roi_img) + filter_boxes, filter_rec_res, ocr_time_dict = self.text_system( + roi_img) + time_dict['det'] += ocr_time_dict['det'] + time_dict['rec'] += ocr_time_dict['rec'] # remove style char style_token = [ '', '', '', '', '', @@ -133,15 +141,17 @@ class StructureSystem(object): 'text_region': box.tolist() }) res_list.append({ - 'type': region.type, + 'type': region['label'].lower(), 'bbox': [x1, y1, x2, y2], 'img': roi_img, 'res': res }) - return res_list + end = time.time() + time_dict['all'] = end - start + return res_list, time_dict elif self.mode == 'vqa': raise NotImplementedError - return None + return None, None def save_structure_res(res, save_folder, img_name): @@ -156,12 +166,12 @@ def save_structure_res(res, save_folder, img_name): roi_img = region.pop('img') f.write('{}\n'.format(json.dumps(region))) - if region['type'] == 'Table' and len(region[ + if region['type'] == 'table' and len(region[ 'res']) > 0 and 'html' in region['res']: excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox'])) to_excel(region['res']['html'], excel_path) - elif region['type'] == 'Figure': + elif 
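A short usage sketch of the refactored StructureSystem, which now returns the region list together with a per-stage timing breakdown (the image path is a placeholder):

import cv2
from ppstructure.predict_system import StructureSystem, save_structure_res
from ppstructure.utility import parse_args

args = parse_args()
structure_sys = StructureSystem(args)
img = cv2.imread('doc/table/1.png')  # placeholder image
res, time_dict = structure_sys(img)
for region in res:
    print(region['type'], region['bbox'])  # e.g. table [x1, y1, x2, y2]
save_structure_res(res, './output', 'demo')
print('layout: {layout:.3f}s, total: {all:.3f}s'.format(**time_dict))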
region['type'] == 'figure': img_path = os.path.join(excel_save_folder, '{}.jpg'.format(region['bbox'])) cv2.imwrite(img_path, roi_img) @@ -188,7 +198,7 @@ def main(args): logger.error("error in loading image:{}".format(image_file)) continue starttime = time.time() - res = structure_sys(img) + res, time_dict = structure_sys(img) if structure_sys.mode == 'structure': save_structure_res(res, save_folder, img_name) @@ -201,7 +211,7 @@ def main(args): cv2.imwrite(img_save_path, draw_img) logger.info('result save to {}'.format(img_save_path)) if args.recovery: - convert_info_docx(img, res, save_folder, img_name) + convert_info_docx(img, res, save_folder, img_name) elapse = time.time() - starttime logger.info("Predict time : {:.3f}s".format(elapse)) diff --git a/ppstructure/table/eval_table.py b/ppstructure/table/eval_table.py index 87b44d3d9..435d69322 100755 --- a/ppstructure/table/eval_table.py +++ b/ppstructure/table/eval_table.py @@ -13,12 +13,14 @@ # limitations under the License. import os import sys + __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..'))) import cv2 -import json +import pickle +import paddle from tqdm import tqdm from ppstructure.table.table_metric import TEDS from ppstructure.table.predict_table import TableSystem @@ -33,40 +35,74 @@ def parse_args(): parser.add_argument("--gt_path", type=str) return parser.parse_args() -def main(gt_path, img_root, args): - teds = TEDS(n_jobs=16) +def load_txt(txt_path): + pred_html_dict = {} + if not os.path.exists(txt_path): + return pred_html_dict + with open(txt_path, encoding='utf-8') as f: + lines = f.readlines() + for line in lines: + line = line.strip().split('\t') + img_name, pred_html = line + pred_html_dict[img_name] = pred_html + return pred_html_dict + + +def load_result(path): + data = {} + if os.path.exists(path): + data = pickle.load(open(path, 'rb')) + return data + + +def save_result(path, data): + old_data = load_result(path) + old_data.update(data) + with open(path, 'wb') as f: + pickle.dump(old_data, f) + + +def main(gt_path, img_root, args): + os.makedirs(args.output, exist_ok=True) + # init TableSystem text_sys = TableSystem(args) - jsons_gt = json.load(open(gt_path)) # gt + # load gt and preds html result + gt_html_dict = load_txt(gt_path) + + ocr_result = load_result(os.path.join(args.output, 'ocr.pickle')) + structure_result = load_result( + os.path.join(args.output, 'structure.pickle')) + pred_htmls = [] gt_htmls = [] - for img_name in tqdm(jsons_gt): - # read image - img = cv2.imread(os.path.join(img_root,img_name)) - pred_html = text_sys(img) + for img_name, gt_html in tqdm(gt_html_dict.items()): + img = cv2.imread(os.path.join(img_root, img_name)) + # run ocr and save result + if img_name not in ocr_result: + dt_boxes, rec_res, _, _ = text_sys._ocr(img) + ocr_result[img_name] = [dt_boxes, rec_res] + save_result(os.path.join(args.output, 'ocr.pickle'), ocr_result) + # run structure and save result + if img_name not in structure_result: + structure_res, _ = text_sys._structure(img) + structure_result[img_name] = structure_res + save_result( + os.path.join(args.output, 'structure.pickle'), structure_result) + dt_boxes, rec_res = ocr_result[img_name] + structure_res = structure_result[img_name] + # match ocr and structure + pred_html = text_sys.match(structure_res, dt_boxes, rec_res) + pred_htmls.append(pred_html) - - gt_structures, gt_bboxes, gt_contents = 
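The reworked eval script replaces the old JSON ground truth with a tab-separated text file and caches OCR/structure outputs via pickle so repeated TEDS runs skip inference. A sketch of writing a file in the format load_txt() expects (content is illustrative only):

# one line per image: "<img_name>\t<html>", the format load_txt() parses
gt = {
    'table_001.png':
    '<html><body><table><tr><td>1</td></tr></table></body></html>',
}
with open('gt.txt', 'w', encoding='utf-8') as f:
    for name, html in gt.items():
        f.write('{}\t{}\n'.format(name, html))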
jsons_gt[img_name] - gt_html, gt = get_gt_html(gt_structures, gt_contents) gt_htmls.append(gt_html) + + # compute teds + teds = TEDS(n_jobs=16) scores = teds.batch_evaluate_html(gt_htmls, pred_htmls) - logger.info('teds:', sum(scores) / len(scores)) - - -def get_gt_html(gt_structures, gt_contents): - end_html = [] - td_index = 0 - for tag in gt_structures: - if '' in tag: - if gt_contents[td_index] != []: - end_html.extend(gt_contents[td_index]) - end_html.append(tag) - td_index += 1 - else: - end_html.append(tag) - return ''.join(end_html), end_html + logger.info('teds: {}'.format(sum(scores) / len(scores))) if __name__ == '__main__': args = parse_args() - main(args.gt_path,args.image_dir, args) + main(args.gt_path, args.image_dir, args) diff --git a/ppstructure/table/matcher.py b/ppstructure/table/matcher.py index c3b563844..d75e9abb3 100755 --- a/ppstructure/table/matcher.py +++ b/ppstructure/table/matcher.py @@ -1,11 +1,15 @@ import json +from ppstructure.table.table_master_match import deal_eb_token, deal_bb + + def distance(box_1, box_2): - x1, y1, x2, y2 = box_1 - x3, y3, x4, y4 = box_2 - dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2) - dis_2 = abs(x3 - x1) + abs(y3 - y1) - dis_3 = abs(x4- x2) + abs(y4 - y2) - return dis + min(dis_2, dis_3) + x1, y1, x2, y2 = box_1 + x3, y3, x4, y4 = box_2 + dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2) + dis_2 = abs(x3 - x1) + abs(y3 - y1) + dis_3 = abs(x4 - x2) + abs(y4 - y2) + return dis + min(dis_2, dis_3) + def compute_iou(rec1, rec2): """ @@ -18,23 +22,22 @@ def compute_iou(rec1, rec2): # computing area of each rectangles S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1]) S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1]) - + # computing the sum_area sum_area = S_rec1 + S_rec2 - + # find the each edge of intersect rectangle left_line = max(rec1[1], rec2[1]) right_line = min(rec1[3], rec2[3]) top_line = max(rec1[0], rec2[0]) bottom_line = min(rec1[2], rec2[2]) - + # judge if there is an intersect if left_line >= right_line or top_line >= bottom_line: return 0.0 else: intersect = (right_line - left_line) * (bottom_line - top_line) - return (intersect / (sum_area - intersect))*1.0 - + return (intersect / (sum_area - intersect)) * 1.0 def matcher_merge(ocr_bboxes, pred_bboxes): @@ -45,15 +48,18 @@ def matcher_merge(ocr_bboxes, pred_bboxes): distances = [] for j, pred_box in enumerate(pred_bboxes): # compute l1 distence and IOU between two boxes - distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box))) + distances.append((distance(gt_box, pred_box), + 1. 
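A toy check of the reformatted distance() above: a candidate cell is scored by the combined L1 gap of both corners plus the smaller single-corner gap, so the nearest cell wins (toy boxes, indices only):

def distance(box_1, box_2):
    x1, y1, x2, y2 = box_1
    x3, y3, x4, y4 = box_2
    dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
    dis_2 = abs(x3 - x1) + abs(y3 - y1)
    dis_3 = abs(x4 - x2) + abs(y4 - y2)
    return dis + min(dis_2, dis_3)

ocr_box = [10, 10, 50, 30]
cells = [[0, 0, 60, 40], [100, 10, 150, 30]]
print(min(range(len(cells)), key=lambda j: distance(ocr_box, cells[j])))  # -> 0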
- compute_iou(gt_box, pred_box))) + sorted_distances = distances.copy() # select nearest cell - sorted_distances = sorted(sorted_distances, key = lambda item: (item[1], item[0])) - if distances.index(sorted_distances[0]) not in matched.keys(): + sorted_distances = sorted( + sorted_distances, key=lambda item: (item[1], item[0])) + if distances.index(sorted_distances[0]) not in matched.keys(): matched[distances.index(sorted_distances[0])] = [i] else: matched[distances.index(sorted_distances[0])].append(i) - return matched#, sum(ious) / len(ious) + return matched #, sum(ious) / len(ious) + def complex_num(pred_bboxes): complex_nums = [] @@ -67,6 +73,7 @@ complex_nums.append(temp_ious[distances.index(min(distances))]) return sum(complex_nums) / len(complex_nums) + def get_rows(pred_bboxes): pre_bbox = pred_bboxes[0] res = [] @@ -81,7 +88,9 @@ for i in range(step): pred_bboxes.pop(0) return res, pred_bboxes -def refine_rows(pred_bboxes): # fine-tune the boxes in a row so they sit on one horizontal line + + +def refine_rows(pred_bboxes): # fine-tune the boxes in a row so they sit on one horizontal line ys_1 = [] ys_2 = [] for box in pred_bboxes: @@ -95,12 +104,14 @@ # fine-tune the boxes in a row so they sit on one horizontal line box[3] = min_y_2 re_boxes.append(box) return re_boxes - + + def matcher_refine_row(gt_bboxes, pred_bboxes): before_refine_pred_bboxes = pred_bboxes.copy() pred_bboxes = [] - while(len(before_refine_pred_bboxes) != 0): - row_bboxes, before_refine_pred_bboxes = get_rows(before_refine_pred_bboxes) + while (len(before_refine_pred_bboxes) != 0): + row_bboxes, before_refine_pred_bboxes = get_rows( + before_refine_pred_bboxes) print(row_bboxes) pred_bboxes.extend(refine_rows(row_bboxes)) all_dis = [] @@ -114,12 +125,11 @@ #temp_ious.append(compute_iou(gt_box, pred_box)) #all_dis.append(min(distances)) #ious.append(temp_ious[distances.index(min(distances))]) - if distances.index(min(distances)) not in matched.keys(): + if distances.index(min(distances)) not in matched.keys(): matched[distances.index(min(distances))] = [i] else: matched[distances.index(min(distances))].append(i) - return matched#, sum(ious) / len(ious) - + return matched #, sum(ious) / len(ious) # pick out one row first, then match @@ -128,29 +138,30 @@ def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes): delete_gt_bboxes = gt_bboxes.copy() match_bboxes_ready = [] matched = {} - while(len(delete_gt_bboxes) != 0): + while (len(delete_gt_bboxes) != 0): row_bboxes, delete_gt_bboxes = get_rows(delete_gt_bboxes) - row_bboxes = sorted(row_bboxes, key = lambda key: key[0]) + row_bboxes = sorted(row_bboxes, key=lambda key: key[0]) if len(pred_bboxes_rows) > 0: match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) print(row_bboxes) for i, gt_box in enumerate(row_bboxes): #print(gt_box) pred_distances = [] - distances = [] + distances = [] for pred_bbox in pred_bboxes: pred_distances.append(distance(gt_box, pred_bbox)) for j, pred_box in enumerate(match_bboxes_ready): distances.append(distance(gt_box, pred_box)) index = pred_distances.index(min(distances)) #print('index', index) - if index not in matched.keys(): + if index not in matched.keys(): matched[index] = [gt_box_index] else: matched[index].append(gt_box_index) gt_box_index += 1 return matched + def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes): ''' gt_bboxes: sorted @@ -161,7 +172,7 @@ match_bboxes_ready = [] match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) for i, gt_box in enumerate(gt_bboxes): - 
pred_distances = [] for pred_bbox in pred_bboxes: pred_distances.append(distance(gt_box, pred_bbox)) @@ -184,9 +195,143 @@ def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes): #print(gt_box, index) #match_bboxes_ready.pop(distances.index(min(distances))) print(gt_box, match_bboxes_ready[distances.index(min(distances))]) - if index not in matched.keys(): + if index not in matched.keys(): matched[index] = [i] else: matched[index].append(i) pre_bbox = gt_box return matched + + +class TableMatch: + def __init__(self, filter_ocr_result=False, use_master=False): + self.filter_ocr_result = filter_ocr_result + self.use_master = use_master + + def __call__(self, structure_res, dt_boxes, rec_res): + pred_structures, pred_bboxes = structure_res + if self.filter_ocr_result: + dt_boxes, rec_res = self.filter_ocr_result(pred_bboxes, dt_boxes, + rec_res) + matched_index = self.match_result(dt_boxes, pred_bboxes) + if self.use_master: + pred_html, pred = self.get_pred_html_master(pred_structures, + matched_index, rec_res) + else: + pred_html, pred = self.get_pred_html(pred_structures, matched_index, + rec_res) + return pred_html + + def match_result(self, dt_boxes, pred_bboxes): + matched = {} + for i, gt_box in enumerate(dt_boxes): + # gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])] + distances = [] + for j, pred_box in enumerate(pred_bboxes): + distances.append((distance(gt_box, pred_box), + 1. - compute_iou(gt_box, pred_box) + )) # 获取两两cell之间的L1距离和 1- IOU + sorted_distances = distances.copy() + # 根据距离和IOU挑选最"近"的cell + sorted_distances = sorted( + sorted_distances, key=lambda item: (item[1], item[0])) + if distances.index(sorted_distances[0]) not in matched.keys(): + matched[distances.index(sorted_distances[0])] = [i] + else: + matched[distances.index(sorted_distances[0])].append(i) + return matched + + def get_pred_html(self, pred_structures, matched_index, ocr_contents): + end_html = [] + td_index = 0 + for tag in pred_structures: + if '' in tag: + if '' == tag: + end_html.extend('') + if td_index in matched_index.keys(): + b_with = False + if '' in ocr_contents[matched_index[td_index][ + 0]] and len(matched_index[td_index]) > 1: + b_with = True + end_html.extend('') + for i, td_index_index in enumerate(matched_index[td_index]): + content = ocr_contents[td_index_index][0] + if len(matched_index[td_index]) > 1: + if len(content) == 0: + continue + if content[0] == ' ': + content = content[1:] + if '' in content: + content = content[3:] + if '' in content: + content = content[:-4] + if len(content) == 0: + continue + if i != len(matched_index[ + td_index]) - 1 and ' ' != content[-1]: + content += ' ' + end_html.extend(content) + if b_with: + end_html.extend('') + if '' == tag: + end_html.append('') + else: + end_html.append(tag) + td_index += 1 + else: + end_html.append(tag) + return ''.join(end_html), end_html + + def get_pred_html_master(self, pred_structures, matched_index, + ocr_contents): + end_html = [] + td_index = 0 + for token in pred_structures: + if '' in token: + txt = '' + b_with = False + if td_index in matched_index.keys(): + if '' in ocr_contents[matched_index[td_index][ + 0]] and len(matched_index[td_index]) > 1: + b_with = True + for i, td_index_index in enumerate(matched_index[td_index]): + content = ocr_contents[td_index_index][0] + if len(matched_index[td_index]) > 1: + if len(content) == 0: + continue + if content[0] == ' ': + content = content[1:] + if '' in content: + content = content[3:] + if '' in content: + content 
= content[:-4] + if len(content) == 0: + continue + if i != len(matched_index[ + td_index]) - 1 and ' ' != content[-1]: + content += ' ' + txt += content + if b_with: + txt = '{}'.format(txt) + if '' == token: + token = '{}'.format(txt) + else: + token = '{}'.format(txt) + td_index += 1 + token = deal_eb_token(token) + end_html.append(token) + html = ''.join(end_html) + html = deal_bb(html) + return html, end_html + + def filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res): + y1 = pred_bboxes[:, 1::2].min() + new_dt_boxes = [] + new_rec_res = [] + + for box, rec in zip(dt_boxes, rec_res): + if np.max(box[1::2]) < y1: + continue + new_dt_boxes.append(box) + new_rec_res.append(rec) + return new_dt_boxes, new_rec_res diff --git a/ppstructure/table/predict_structure.py b/ppstructure/table/predict_structure.py index 7a7d3169d..01d467594 100755 --- a/ppstructure/table/predict_structure.py +++ b/ppstructure/table/predict_structure.py @@ -16,7 +16,7 @@ import sys __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..'))) os.environ["FLAGS_allocator_strategy"] = 'auto_growth' @@ -87,6 +87,7 @@ class TableStructurer(object): utility.create_predictor(args, 'table', logger) def __call__(self, img): + starttime = time.time() ori_im = img.copy() data = {'image': img} data = transform(data, self.preprocess_op) @@ -95,7 +96,6 @@ class TableStructurer(object): return None, 0 img = np.expand_dims(img, axis=0) img = img.copy() - starttime = time.time() self.input_tensor.copy_from_cpu(img) self.predictor.run() @@ -126,7 +126,6 @@ def main(args): table_structurer = TableStructurer(args) count = 0 total_time = 0 - use_xywh = args.table_algorithm in ['TableMaster'] os.makedirs(args.output, exist_ok=True) with open( os.path.join(args.output, 'infer.txt'), mode='w', @@ -146,7 +145,7 @@ def main(args): f_w.write("result: {}, {}\n".format(structure_str_list, bbox_list_str)) - img = draw_rectangle(image_file, bbox_list, use_xywh) + img = draw_rectangle(image_file, bbox_list) img_save_path = os.path.join(args.output, os.path.basename(image_file)) cv2.imwrite(img_save_path, img) diff --git a/ppstructure/table/predict_table.py b/ppstructure/table/predict_table.py index becc6daef..6e7051235 100644 --- a/ppstructure/table/predict_table.py +++ b/ppstructure/table/predict_table.py @@ -18,20 +18,23 @@ import subprocess __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) -sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..'))) os.environ["FLAGS_allocator_strategy"] = 'auto_growth' import cv2 import copy +import logging import numpy as np import time import tools.infer.predict_rec as predict_rec import tools.infer.predict_det as predict_det import tools.infer.utility as utility +from tools.infer.predict_system import sorted_boxes from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.logging import get_logger -from ppstructure.table.matcher import distance, compute_iou +from ppstructure.table.matcher import TableMatch +from ppstructure.table.table_master_match import TableMasterMatcher from ppstructure.utility import parse_args import ppstructure.table.predict_structure as predict_strture @@ -55,11 +58,20 @@ 
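One caveat in the new TableMatch: __init__ stores the boolean flag under the same name as the filter_ocr_result method, so the instance attribute shadows the method and enabling the flag would raise "'bool' object is not callable"; the method also relies on numpy without matcher.py importing it. A sketch of the filter with the helper renamed to avoid the collision (the rename and the underscore name are my assumption, not the patch's):

import numpy as np

def _filter_ocr_result(pred_bboxes, dt_boxes, rec_res):
    # drop OCR lines that sit entirely above the topmost predicted cell
    y1 = pred_bboxes[:, 1::2].min()
    new_dt_boxes, new_rec_res = [], []
    for box, rec in zip(dt_boxes, rec_res):
        if np.max(box[1::2]) < y1:
            continue
        new_dt_boxes.append(box)
        new_rec_res.append(rec)
    return new_dt_boxes, new_rec_res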
def expand(pix, det_box, shape): class TableSystem(object): def __init__(self, args, text_detector=None, text_recognizer=None): + if not args.show_log: + logger.setLevel(logging.INFO) + self.text_detector = predict_det.TextDetector( args) if text_detector is None else text_detector self.text_recognizer = predict_rec.TextRecognizer( args) if text_recognizer is None else text_recognizer + self.table_structurer = predict_strture.TableStructurer(args) + if args.table_algorithm in ['TableMaster']: + self.match = TableMasterMatcher() + else: + self.match = TableMatch() + self.benchmark = args.benchmark self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor( args, 'table', logger) @@ -85,16 +97,47 @@ class TableSystem(object): def __call__(self, img, return_ocr_result_in_table=False): result = dict() - ori_im = img.copy() + time_dict = {'det': 0, 'rec': 0, 'table': 0, 'all': 0, 'match': 0} + start = time.time() + + structure_res, elapse = self._structure(copy.deepcopy(img)) + time_dict['table'] = elapse + + dt_boxes, rec_res, det_elapse, rec_elapse = self._ocr( + copy.deepcopy(img)) + time_dict['det'] = det_elapse + time_dict['rec'] = rec_elapse + + if return_ocr_result_in_table: + result['boxes'] = dt_boxes #[x.tolist() for x in dt_boxes] + result['rec_res'] = rec_res + + tic = time.time() + pred_html = self.match(structure_res, dt_boxes, rec_res) + toc = time.time() + time_dict['match'] = toc - tic + # pred_html = self.match(1, 1, 1,img_name) + result['html'] = pred_html + if self.benchmark: + self.autolog.times.end(stamp=True) + end = time.time() + time_dict['all'] = end - start + if self.benchmark: + self.autolog.times.stamp() + return result, time_dict + + def _structure(self, img): if self.benchmark: self.autolog.times.start() structure_res, elapse = self.table_structurer(copy.deepcopy(img)) + return structure_res, elapse + + def _ocr(self, img): if self.benchmark: self.autolog.times.stamp() - dt_boxes, elapse = self.text_detector(copy.deepcopy(img)) + dt_boxes, det_elapse = self.text_detector(copy.deepcopy(img)) dt_boxes = sorted_boxes(dt_boxes) - if return_ocr_result_in_table: - result['boxes'] = [x.tolist() for x in dt_boxes] + r_boxes = [] for box in dt_boxes: x_min = box[:, 0].min() - 1 @@ -105,125 +148,20 @@ class TableSystem(object): r_boxes.append(box) dt_boxes = np.array(r_boxes) logger.debug("dt_boxes num : {}, elapse : {}".format( - len(dt_boxes), elapse)) + len(dt_boxes), det_elapse)) if dt_boxes is None: return None, None + img_crop_list = [] for i in range(len(dt_boxes)): det_box = dt_boxes[i] - x0, y0, x1, y1 = expand(2, det_box, ori_im.shape) - text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :] + x0, y0, x1, y1 = expand(2, det_box, img.shape) + text_rect = img[int(y0):int(y1), int(x0):int(x1), :] img_crop_list.append(text_rect) - rec_res, elapse = self.text_recognizer(img_crop_list) + rec_res, rec_elapse = self.text_recognizer(img_crop_list) logger.debug("rec_res num : {}, elapse : {}".format( - len(rec_res), elapse)) - if self.benchmark: - self.autolog.times.stamp() - if return_ocr_result_in_table: - result['rec_res'] = rec_res - pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res) - result['html'] = pred_html - if self.benchmark: - self.autolog.times.end(stamp=True) - return result - - def rebuild_table(self, structure_res, dt_boxes, rec_res): - pred_structures, pred_bboxes = structure_res - dt_boxes, rec_res = self.filter_ocr_result(pred_bboxes,dt_boxes, rec_res) - matched_index = self.match_result(dt_boxes, 
pred_bboxes) - pred_html, pred = self.get_pred_html(pred_structures, matched_index, - rec_res) - return pred_html, pred - - def filter_ocr_result(self, pred_bboxes,dt_boxes, rec_res): - y1 = pred_bboxes[:,1::2].min() - new_dt_boxes = [] - new_rec_res = [] - - for box,rec in zip(dt_boxes, rec_res): - if np.max(box[1::2]) < y1: - continue - new_dt_boxes.append(box) - new_rec_res.append(rec) - return new_dt_boxes, new_rec_res - - - def match_result(self, dt_boxes, pred_bboxes): - matched = {} - for i, gt_box in enumerate(dt_boxes): - # gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])] - distances = [] - for j, pred_box in enumerate(pred_bboxes): - distances.append((distance(gt_box, pred_box), - 1. - compute_iou(gt_box, pred_box) - )) # 获取两两cell之间的L1距离和 1- IOU - sorted_distances = distances.copy() - # 根据距离和IOU挑选最"近"的cell - sorted_distances = sorted( - sorted_distances, key=lambda item: (item[1], item[0])) - if distances.index(sorted_distances[0]) not in matched.keys(): - matched[distances.index(sorted_distances[0])] = [i] - else: - matched[distances.index(sorted_distances[0])].append(i) - return matched - - def get_pred_html(self, pred_structures, matched_index, ocr_contents): - end_html = [] - td_index = 0 - for tag in pred_structures: - if '' in tag: - if td_index in matched_index.keys(): - b_with = False - if '' in ocr_contents[matched_index[td_index][ - 0]] and len(matched_index[td_index]) > 1: - b_with = True - end_html.extend('') - for i, td_index_index in enumerate(matched_index[td_index]): - content = ocr_contents[td_index_index][0] - if len(matched_index[td_index]) > 1: - if len(content) == 0: - continue - if content[0] == ' ': - content = content[1:] - if '' in content: - content = content[3:] - if '' in content: - content = content[:-4] - if len(content) == 0: - continue - if i != len(matched_index[ - td_index]) - 1 and ' ' != content[-1]: - content += ' ' - end_html.extend(content) - if b_with: - end_html.extend('') - - end_html.append(tag) - td_index += 1 - else: - end_html.append(tag) - return ''.join(end_html), end_html - - -def sorted_boxes(dt_boxes): - """ - Sort text boxes in order from top to bottom, left to right - args: - dt_boxes(array):detected text boxes with shape [4, 2] - return: - sorted boxes(array) with shape [4, 2] - """ - num_boxes = dt_boxes.shape[0] - sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) - _boxes = list(sorted_boxes) - - for i in range(num_boxes - 1): - if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \ - (_boxes[i + 1][0][0] < _boxes[i][0][0]): - tmp = _boxes[i] - _boxes[i] = _boxes[i + 1] - _boxes[i + 1] = tmp - return _boxes + len(rec_res), rec_elapse)) + return dt_boxes, rec_res, det_elapse, rec_elapse def to_excel(html_table, excel_path): @@ -249,7 +187,7 @@ def main(args): logger.error("error in loading image:{}".format(image_file)) continue starttime = time.time() - pred_res = text_sys(img) + pred_res, _ = text_sys(img) pred_html = pred_res['html'] logger.info(pred_html) to_excel(pred_html, excel_path) diff --git a/ppstructure/table/table_master_match.py b/ppstructure/table/table_master_match.py new file mode 100644 index 000000000..069d576bf --- /dev/null +++ b/ppstructure/table/table_master_match.py @@ -0,0 +1,1009 @@ +import os +import re +import cv2 +import glob +import copy +import math +import pickle +import numpy as np + +from shapely.geometry import Polygon, MultiPoint +""" +Useful function in matching. 
+""" + + +def remove_empty_bboxes(bboxes): + """ + remove [0., 0., 0., 0.] in structure master bboxes. + len(bboxes.shape) must be 2. + :param bboxes: + :return: + """ + new_bboxes = [] + for bbox in bboxes: + if sum(bbox) == 0.: + continue + new_bboxes.append(bbox) + return np.array(new_bboxes) + + +def xywh2xyxy(bboxes): + if len(bboxes.shape) == 1: + new_bboxes = np.empty_like(bboxes) + new_bboxes[0] = bboxes[0] - bboxes[2] / 2 + new_bboxes[1] = bboxes[1] - bboxes[3] / 2 + new_bboxes[2] = bboxes[0] + bboxes[2] / 2 + new_bboxes[3] = bboxes[1] + bboxes[3] / 2 + return new_bboxes + elif len(bboxes.shape) == 2: + new_bboxes = np.empty_like(bboxes) + new_bboxes[:, 0] = bboxes[:, 0] - bboxes[:, 2] / 2 + new_bboxes[:, 1] = bboxes[:, 1] - bboxes[:, 3] / 2 + new_bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2] / 2 + new_bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3] / 2 + return new_bboxes + else: + raise ValueError + + +def xyxy2xywh(bboxes): + if len(bboxes.shape) == 1: + new_bboxes = np.empty_like(bboxes) + new_bboxes[0] = bboxes[0] + (bboxes[2] - bboxes[0]) / 2 + new_bboxes[1] = bboxes[1] + (bboxes[3] - bboxes[1]) / 2 + new_bboxes[2] = bboxes[2] - bboxes[0] + new_bboxes[3] = bboxes[3] - bboxes[1] + return new_bboxes + elif len(bboxes.shape) == 2: + new_bboxes = np.empty_like(bboxes) + new_bboxes[:, 0] = bboxes[:, 0] + (bboxes[:, 2] - bboxes[:, 0]) / 2 + new_bboxes[:, 1] = bboxes[:, 1] + (bboxes[:, 3] - bboxes[:, 1]) / 2 + new_bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] + new_bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] + return new_bboxes + else: + raise ValueError + + +def pickle_load(path, prefix='end2end'): + if os.path.isfile(path): + data = pickle.load(open(path, 'rb')) + elif os.path.isdir(path): + data = dict() + search_path = os.path.join(path, '{}_*.pkl'.format(prefix)) + pkls = glob.glob(search_path) + for pkl in pkls: + this_data = pickle.load(open(pkl, 'rb')) + data.update(this_data) + else: + raise ValueError + return data + + +def convert_coord(xyxy): + """ + Convert two points format to four points format. + :param xyxy: + :return: + """ + new_bbox = np.zeros([4, 2], dtype=np.float32) + new_bbox[0, 0], new_bbox[0, 1] = xyxy[0], xyxy[1] + new_bbox[1, 0], new_bbox[1, 1] = xyxy[2], xyxy[1] + new_bbox[2, 0], new_bbox[2, 1] = xyxy[2], xyxy[3] + new_bbox[3, 0], new_bbox[3, 1] = xyxy[0], xyxy[3] + return new_bbox + + +def cal_iou(bbox1, bbox2): + bbox1_poly = Polygon(bbox1).convex_hull + bbox2_poly = Polygon(bbox2).convex_hull + union_poly = np.concatenate((bbox1, bbox2)) + + if not bbox1_poly.intersects(bbox2_poly): + iou = 0 + else: + inter_area = bbox1_poly.intersection(bbox2_poly).area + union_area = MultiPoint(union_poly).convex_hull.area + if union_area == 0: + iou = 0 + else: + iou = float(inter_area) / union_area + return iou + + +def cal_distance(p1, p2): + delta_x = p1[0] - p2[0] + delta_y = p1[1] - p2[1] + d = math.sqrt((delta_x**2) + (delta_y**2)) + return d + + +def is_inside(center_point, corner_point): + """ + Find if center_point inside the bbox(corner_point) or not. 
+ :param center_point: center point (x, y) + :param corner_point: corner point ((x1,y1),(x2,y2)) + :return: + """ + x_flag = False + y_flag = False + if (center_point[0] >= corner_point[0][0]) and ( + center_point[0] <= corner_point[1][0]): + x_flag = True + if (center_point[1] >= corner_point[0][1]) and ( + center_point[1] <= corner_point[1][1]): + y_flag = True + if x_flag and y_flag: + return True + else: + return False + + +def find_no_match(match_list, all_end2end_nums, type='end2end'): + """ + Find out no match end2end bbox in previous match list. + :param match_list: matching pairs. + :param all_end2end_nums: numbers of end2end_xywh + :param type: 'end2end' corresponding to idx 0, 'master' corresponding to idx 1. + :return: no match pse bbox index list + """ + if type == 'end2end': + idx = 0 + elif type == 'master': + idx = 1 + else: + raise ValueError + + no_match_indexs = [] + # m[0] is end2end index m[1] is master index + matched_bbox_indexs = [m[idx] for m in match_list] + for n in range(all_end2end_nums): + if n not in matched_bbox_indexs: + no_match_indexs.append(n) + return no_match_indexs + + +def is_abs_lower_than_threshold(this_bbox, target_bbox, threshold=3): + # only consider y axis, for grouping in row. + delta = abs(this_bbox[1] - target_bbox[1]) + if delta < threshold: + return True + else: + return False + + +def sort_line_bbox(g, bg): + """ + Sorted the bbox in the same line(group) + compare coord 'x' value, where 'y' value is closed in the same group. + :param g: index in the same group + :param bg: bbox in the same group + :return: + """ + + xs = [bg_item[0] for bg_item in bg] + xs_sorted = sorted(xs) + + g_sorted = [None] * len(xs_sorted) + bg_sorted = [None] * len(xs_sorted) + for g_item, bg_item in zip(g, bg): + idx = xs_sorted.index(bg_item[0]) + bg_sorted[idx] = bg_item + g_sorted[idx] = g_item + + return g_sorted, bg_sorted + + +def flatten(sorted_groups, sorted_bbox_groups): + idxs = [] + bboxes = [] + for group, bbox_group in zip(sorted_groups, sorted_bbox_groups): + for g, bg in zip(group, bbox_group): + idxs.append(g) + bboxes.append(bg) + return idxs, bboxes + + +def sort_bbox(end2end_xywh_bboxes, no_match_end2end_indexes): + """ + This function will group the render end2end bboxes in row. + :param end2end_xywh_bboxes: + :param no_match_end2end_indexes: + :return: + """ + groups = [] + bbox_groups = [] + for index, end2end_xywh_bbox in zip(no_match_end2end_indexes, + end2end_xywh_bboxes): + this_bbox = end2end_xywh_bbox + if len(groups) == 0: + groups.append([index]) + bbox_groups.append([this_bbox]) + else: + flag = False + for g, bg in zip(groups, bbox_groups): + # this_bbox is belong to bg's row or not + if is_abs_lower_than_threshold(this_bbox, bg[0]): + g.append(index) + bg.append(this_bbox) + flag = True + break + if not flag: + # this_bbox is not belong to bg's row, create a row. + groups.append([index]) + bbox_groups.append([this_bbox]) + + # sorted bboxes in a group + tmp_groups, tmp_bbox_groups = [], [] + for g, bg in zip(groups, bbox_groups): + g_sorted, bg_sorted = sort_line_bbox(g, bg) + tmp_groups.append(g_sorted) + tmp_bbox_groups.append(bg_sorted) + + # sorted groups, sort by coord y's value. 
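The row grouping in sort_bbox() reduces to: boxes whose top y coordinates differ by less than a threshold share a row, rows are ordered by y, and members within a row by x. A compact sketch of the same idea (the function name is mine):

def group_rows(bboxes, threshold=3):
    # bboxes are (x, y, w, h)-style tuples; group by near-equal y, sort by x
    rows = []
    for box in sorted(bboxes, key=lambda b: b[1]):
        if rows and abs(rows[-1][0][1] - box[1]) < threshold:
            rows[-1].append(box)
        else:
            rows.append([box])
    return [sorted(row, key=lambda b: b[0]) for row in rows]

print(group_rows([[50, 11, 90, 20], [10, 10, 40, 20], [10, 30, 40, 40]]))
# -> [[[10, 10, 40, 20], [50, 11, 90, 20]], [[10, 30, 40, 40]]]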
+ sorted_groups = [None] * len(tmp_groups) + sorted_bbox_groups = [None] * len(tmp_bbox_groups) + ys = [bg[0][1] for bg in tmp_bbox_groups] + sorted_ys = sorted(ys) + for g, bg in zip(tmp_groups, tmp_bbox_groups): + idx = sorted_ys.index(bg[0][1]) + sorted_groups[idx] = g + sorted_bbox_groups[idx] = bg + + # flatten, get final result + end2end_sorted_idx_list, end2end_sorted_bbox_list \ + = flatten(sorted_groups, sorted_bbox_groups) + + # check sorted + #img = cv2.imread('/data_0/yejiaquan/data/TableRecognization/singleVal/PMC3286376_004_00.png') + #img = drawBboxAfterSorted(img, sorted_groups, sorted_bbox_groups) + + return end2end_sorted_idx_list, end2end_sorted_bbox_list, sorted_groups, sorted_bbox_groups + + +def get_bboxes_list(end2end_result, structure_master_result): + """ + This function is use to convert end2end results and structure master results to + List of xyxy bbox format and List of xywh bbox format + :param end2end_result: bbox's format is xyxy + :param structure_master_result: bbox's format is xywh + :return: 4 kind list of bbox () + """ + # end2end + end2end_xyxy_list = [] + end2end_xywh_list = [] + for end2end_item in end2end_result: + src_bbox = end2end_item['bbox'] + end2end_xyxy_list.append(src_bbox) + xywh_bbox = xyxy2xywh(src_bbox) + end2end_xywh_list.append(xywh_bbox) + end2end_xyxy_bboxes = np.array(end2end_xyxy_list) + end2end_xywh_bboxes = np.array(end2end_xywh_list) + + # structure master + src_bboxes = structure_master_result['bbox'] + src_bboxes = remove_empty_bboxes(src_bboxes) + # structure_master_xywh_bboxes = src_bboxes + # xyxy_bboxes = xywh2xyxy(src_bboxes) + # structure_master_xyxy_bboxes = xyxy_bboxes + structure_master_xyxy_bboxes = src_bboxes + xywh_bbox = xyxy2xywh(src_bboxes) + structure_master_xywh_bboxes = xywh_bbox + + return end2end_xyxy_bboxes, end2end_xywh_bboxes, structure_master_xywh_bboxes, structure_master_xyxy_bboxes + + +def center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes): + """ + Judge end2end Bbox's center point is inside structure master Bbox or not, + if end2end Bbox's center is in structure master Bbox, get matching pair. + :param end2end_xywh_bboxes: + :param structure_master_xyxy_bboxes: + :return: match pairs list, e.g. [[0,1], [1,2], ...] + """ + match_pairs_list = [] + for i, end2end_xywh in enumerate(end2end_xywh_bboxes): + for j, master_xyxy in enumerate(structure_master_xyxy_bboxes): + x_end2end, y_end2end = end2end_xywh[0], end2end_xywh[1] + x_master1, y_master1, x_master2, y_master2 \ + = master_xyxy[0], master_xyxy[1], master_xyxy[2], master_xyxy[3] + center_point_end2end = (x_end2end, y_end2end) + corner_point_master = ((x_master1, y_master1), + (x_master2, y_master2)) + if is_inside(center_point_end2end, corner_point_master): + match_pairs_list.append([i, j]) + return match_pairs_list + + +def iou_rule_match(end2end_xyxy_bboxes, end2end_xyxy_indexes, + structure_master_xyxy_bboxes): + """ + Use iou to find matching list. + choose max iou value bbox as match pair. + :param end2end_xyxy_bboxes: + :param end2end_xyxy_indexes: original end2end indexes. + :param structure_master_xyxy_bboxes: + :return: match pairs list, e.g. [[0,1], [1,2], ...] 
+ """ + match_pair_list = [] + for end2end_xyxy_index, end2end_xyxy in zip(end2end_xyxy_indexes, + end2end_xyxy_bboxes): + max_iou = 0 + max_match = [None, None] + for j, master_xyxy in enumerate(structure_master_xyxy_bboxes): + end2end_4xy = convert_coord(end2end_xyxy) + master_4xy = convert_coord(master_xyxy) + iou = cal_iou(end2end_4xy, master_4xy) + if iou > max_iou: + max_match[0], max_match[1] = end2end_xyxy_index, j + max_iou = iou + + if max_match[0] is None: + # no match + continue + match_pair_list.append(max_match) + return match_pair_list + + +def distance_rule_match(end2end_indexes, end2end_bboxes, master_indexes, + master_bboxes): + """ + Get matching between no-match end2end bboxes and no-match master bboxes. + Use min distance to match. + This rule will only run (no-match end2end nums > 0) and (no-match master nums > 0) + It will Return master_bboxes_nums match-pairs. + :param end2end_indexes: + :param end2end_bboxes: + :param master_indexes: + :param master_bboxes: + :return: match_pairs list, e.g. [[0,1], [1,2], ...] + """ + min_match_list = [] + for j, master_bbox in zip(master_indexes, master_bboxes): + min_distance = np.inf + min_match = [0, 0] # i, j + for i, end2end_bbox in zip(end2end_indexes, end2end_bboxes): + x_end2end, y_end2end = end2end_bbox[0], end2end_bbox[1] + x_master, y_master = master_bbox[0], master_bbox[1] + end2end_point = (x_end2end, y_end2end) + master_point = (x_master, y_master) + dist = cal_distance(master_point, end2end_point) + if dist < min_distance: + min_match[0], min_match[1] = i, j + min_distance = dist + min_match_list.append(min_match) + return min_match_list + + +def extra_match(no_match_end2end_indexes, master_bbox_nums): + """ + This function will create some virtual master bboxes, + and get match with the no match end2end indexes. + :param no_match_end2end_indexes: + :param master_bbox_nums: + :return: + """ + end_nums = len(no_match_end2end_indexes) + master_bbox_nums + extra_match_list = [] + for i in range(master_bbox_nums, end_nums): + end2end_index = no_match_end2end_indexes[i - master_bbox_nums] + extra_match_list.append([end2end_index, i]) + return extra_match_list + + +def match_visual(file_name, + match_list, + end2end_xyxy, + master_xyxy, + prex='ordinary_match'): + """ + Show the match result by xyxy coord style. + :param file_name: + :param match_list: + :param end2end_xyxy: + :param master_xyxy: + :param prex: + :return: + """ + folder = '' + save_folder = '/data_0/cache' + file_path = os.path.join(folder, file_name) + img_end2end = cv2.imread(file_path) + img_master = copy.deepcopy(img_end2end) + text_color = (0, 0, 255) + bbox_color = (255, 0, 0) + master_nums = len(master_xyxy) + + for idx, match_group in enumerate(match_list): + end2end_idx, master_index = match_group[0], match_group[1] + + # master_index larger than master_nums, did not draw master bbox. 
+ if master_index < master_nums: + # draw master + master_bbox = master_xyxy[master_index] + img_master = cv2.rectangle( + img_master, (int(master_bbox[0]), int(master_bbox[1])), + (int(master_bbox[2]), int(master_bbox[3])), + bbox_color, + thickness=1) + master_text_coord = (int(master_bbox[0]) - 4, int(master_bbox[1])) + img_master = cv2.putText(img_master, + str(master_index), master_text_coord, 1, 1, + text_color, 2) + + # draw end2end + end2end_bbox = end2end_xyxy[end2end_idx] + img_end2end = cv2.rectangle( + img_end2end, (int(end2end_bbox[0]), int(end2end_bbox[1])), + (int(end2end_bbox[2]), int(end2end_bbox[3])), + bbox_color, + thickness=1) + end2end_text_coord = (int(end2end_bbox[0]) - 4, int(end2end_bbox[1])) + # write end2end bbox matching master bbox's index + img_end2end = cv2.putText(img_end2end, + str(master_index), end2end_text_coord, 1, 1, + text_color, 2) + + img = np.hstack([img_end2end, img_master]) + save_path = os.path.join(save_folder, '{}_matchShow.png'.format(prex)) + cv2.imwrite(save_path, img) + + +def get_match_dict(match_list): + """ + Convert match_list to a dict, where key is master bbox's index, value is end2end bbox index. + :param match_list: + :return: + """ + match_dict = dict() + for match_pair in match_list: + end2end_index, master_index = match_pair[0], match_pair[1] + if master_index not in match_dict.keys(): + match_dict[master_index] = [end2end_index] + else: + match_dict[master_index].append(end2end_index) + return match_dict + + +def deal_successive_space(text): + """ + deal successive space character for text + 1. Replace ' '*3 with '' which is real space is text + 2. Remove ' ', which is split token, not true space + 3. Replace '' with ' ', to get real text + :param text: + :return: + """ + text = text.replace(' ' * 3, '') + text = text.replace(' ', '') + text = text.replace('', ' ') + return text + + +def reduce_repeat_bb(text_list, break_token): + """ + convert ['Local', 'government', 'unit'] to ['Local government unit'] + PS: maybe style Local is also exist, too. it can be processed like this. + :param text_list: + :param break_token: + :return: + """ + count = 0 + for text in text_list: + if text.startswith(''): + count += 1 + if count == len(text_list): + new_text_list = [] + for text in text_list: + text = text.replace('', '').replace('', '') + new_text_list.append(text) + return ['' + break_token.join(new_text_list) + ''] + else: + return text_list + + +def get_match_text_dict(match_dict, end2end_info, break_token=' '): + match_text_dict = dict() + for master_index, end2end_index_list in match_dict.items(): + text_list = [ + end2end_info[end2end_index]['text'] + for end2end_index in end2end_index_list + ] + text_list = reduce_repeat_bb(text_list, break_token) + text = break_token.join(text_list) + match_text_dict[master_index] = text + return match_text_dict + + +def merge_span_token(master_token_list): + """ + Merge the span style token (row span or col span). 
+ :param master_token_list: + :return: + """ + new_master_token_list = [] + pointer = 0 + if master_token_list[-1] != '': + master_token_list.append('') + while master_token_list[pointer] != '': + try: + if master_token_list[pointer] == '' + '' + """ + # tmp = master_token_list[pointer] + master_token_list[pointer+1] + master_token_list[pointer+2] + \ + # master_token_list[pointer+3] + tmp = ''.join(master_token_list[pointer:pointer + 3 + 1]) + pointer += 4 + new_master_token_list.append(tmp) + + elif master_token_list[pointer + 2].startswith( + ' colspan=') or master_token_list[ + pointer + 2].startswith(' rowspan='): + """ + example: + pattern + '' + '' + """ + # tmp = master_token_list[pointer] + master_token_list[pointer+1] + \ + # master_token_list[pointer+2] + master_token_list[pointer+3] + master_token_list[pointer+4] + tmp = ''.join(master_token_list[pointer:pointer + 4 + 1]) + pointer += 5 + new_master_token_list.append(tmp) + + else: + new_master_token_list.append(master_token_list[pointer]) + pointer += 1 + else: + new_master_token_list.append(master_token_list[pointer]) + pointer += 1 + except: + print("Break in merge...") + break + new_master_token_list.append('') + + return new_master_token_list + + +def deal_eb_token(master_token): + """ + post process with , , ... + emptyBboxTokenDict = { + "[]": '', + "[' ']": '', + "['', ' ', '']": '', + "['\\u2028', '\\u2028']": '', + "['', ' ', '']": '', + "['', '']": '', + "['', ' ', '']": '', + "['', '', '', '']": '', + "['', '', ' ', '', '']": '', + "['', '']": '', + "['', ' ', '\\u2028', ' ', '\\u2028', ' ', '']": '', + } + :param master_token: + :return: + """ + master_token = master_token.replace('', '') + master_token = master_token.replace('', ' ') + master_token = master_token.replace('', ' ') + master_token = master_token.replace('', '\u2028\u2028') + master_token = master_token.replace('', ' ') + master_token = master_token.replace('', '') + master_token = master_token.replace('', ' ') + master_token = master_token.replace('', + '') + master_token = master_token.replace('', + ' ') + master_token = master_token.replace('', '') + master_token = master_token.replace('', + ' \u2028 \u2028 ') + return master_token + + +def insert_text_to_token(master_token_list, match_text_dict): + """ + Insert OCR text result to structure token. + :param master_token_list: + :param match_text_dict: + :return: + """ + master_token_list = merge_span_token(master_token_list) + merged_result_list = [] + text_count = 0 + for master_token in master_token_list: + if master_token.startswith(' len(match_text_dict) - 1: + text_count += 1 + continue + elif text_count not in match_text_dict.keys(): + text_count += 1 + continue + else: + master_token = master_token.replace( + '><', '>{}<'.format(match_text_dict[text_count])) + text_count += 1 + master_token = deal_eb_token(master_token) + merged_result_list.append(master_token) + + return ''.join(merged_result_list) + + +def deal_isolate_span(thead_part): + """ + Deal with isolate span cases in this function. + It causes by wrong prediction in structure recognition model. + eg. predict to rowspan="2">. + :param thead_part: + :return: + """ + # 1. find out isolate span tokens. + isolate_pattern = " rowspan=\"(\d)+\" colspan=\"(\d)+\">|" \ + " colspan=\"(\d)+\" rowspan=\"(\d)+\">|" \ + " rowspan=\"(\d)+\">|" \ + " colspan=\"(\d)+\">" + isolate_iter = re.finditer(isolate_pattern, thead_part) + isolate_list = [i.group() for i in isolate_iter] + + # 2. find out span number, by step 1 results. 
+ span_pattern = " rowspan=\"(\d)+\" colspan=\"(\d)+\"|" \ + " colspan=\"(\d)+\" rowspan=\"(\d)+\"|" \ + " rowspan=\"(\d)+\"|" \ + " colspan=\"(\d)+\"" + corrected_list = [] + for isolate_item in isolate_list: + span_part = re.search(span_pattern, isolate_item) + spanStr_in_isolateItem = span_part.group() + # 3. merge the span number into the span token format string. + if spanStr_in_isolateItem is not None: + corrected_item = ''.format(spanStr_in_isolateItem) + corrected_list.append(corrected_item) + else: + corrected_list.append(None) + + # 4. replace original isolated token. + for corrected_item, isolate_item in zip(corrected_list, isolate_list): + if corrected_item is not None: + thead_part = thead_part.replace(isolate_item, corrected_item) + else: + pass + return thead_part + + +def deal_duplicate_bb(thead_part): + """ + Deal duplicate or after replace. + Keep one in a token. + :param thead_part: + :return: + """ + # 1. find out in . + td_pattern = "(.+?)|" \ + "(.+?)|" \ + "(.+?)|" \ + "(.+?)|" \ + "(.*?)" + td_iter = re.finditer(td_pattern, thead_part) + td_list = [t.group() for t in td_iter] + + # 2. is multiply in or not? + new_td_list = [] + for td_item in td_list: + if td_item.count('') > 1 or td_item.count('') > 1: + # multiply in case. + # 1. remove all + td_item = td_item.replace('', '').replace('', '') + # 2. replace -> , -> . + td_item = td_item.replace('', '').replace('', + '') + new_td_list.append(td_item) + else: + new_td_list.append(td_item) + + # 3. replace original thead part. + for td_item, new_td_item in zip(td_list, new_td_list): + thead_part = thead_part.replace(td_item, new_td_item) + return thead_part + + +def deal_bb(result_token): + """ + In our opinion, always occurs in text's context. + This function will find out all tokens in and insert by manual. + :param result_token: + :return: + """ + # find out parts. + thead_pattern = '(.*?)' + if re.search(thead_pattern, result_token) is None: + return result_token + thead_part = re.search(thead_pattern, result_token).group() + origin_thead_part = copy.deepcopy(thead_part) + + # check "rowspan" or "colspan" occur in parts or not . + span_pattern = "|||" + span_iter = re.finditer(span_pattern, thead_part) + span_list = [s.group() for s in span_iter] + has_span_in_head = True if len(span_list) > 0 else False + + if not has_span_in_head: + # not include "rowspan" or "colspan" branch 1. + # 1. replace to , and to + # 2. it is possible to predict text include or by Text-line recognition, + # so we replace to , and to + thead_part = thead_part.replace('', '')\ + .replace('', '')\ + .replace('', '')\ + .replace('', '') + else: + # include "rowspan" or "colspan" branch 2. + # Firstly, we deal rowspan or colspan cases. + # 1. replace > to > + # 2. replace to + # 3. 
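In this copy of the patch the HTML literals inside deal_isolate_span() and the surrounding matcher code were stripped during extraction; restoring them as '<td ...>' tokens is my reconstruction from the regex patterns, not verbatim patch text. A self-contained demo of the repair under that assumption:

import re

# broken prediction: a span attribute without its opening '<td'
thead = '<thead><tr> rowspan="2">x</td></tr></thead>'
isolate_pattern = r' rowspan=\"(\d)+\">'
span_pattern = r' rowspan=\"(\d)+\"'
for m in re.finditer(isolate_pattern, thead):
    isolate_item = m.group()
    span_part = re.search(span_pattern, isolate_item).group().strip()
    # re-wrap the orphaned span attribute in a proper opening tag
    thead = thead.replace(isolate_item, '<td {}>'.format(span_part))
print(thead)  # -> <thead><tr><td rowspan="2">x</td></tr></thead>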
it is possible to predict text include or by Text-line recognition, + # so we replace to , and to + + # Secondly, deal ordinary cases like branch 1 + + # replace ">" to "" + replaced_span_list = [] + for sp in span_list: + replaced_span_list.append(sp.replace('>', '>')) + for sp, rsp in zip(span_list, replaced_span_list): + thead_part = thead_part.replace(sp, rsp) + + # replace "" to "" + thead_part = thead_part.replace('', '') + + # remove duplicated by re.sub + mb_pattern = "()+" + single_b_string = "" + thead_part = re.sub(mb_pattern, single_b_string, thead_part) + + mgb_pattern = "()+" + single_gb_string = "" + thead_part = re.sub(mgb_pattern, single_gb_string, thead_part) + + # ordinary cases like branch 1 + thead_part = thead_part.replace('', '').replace('', + '') + + # convert back to , empty cell has no . + # but space cell( ) is suitable for + thead_part = thead_part.replace('', '') + # deal with duplicated + thead_part = deal_duplicate_bb(thead_part) + # deal with isolate span tokens, which causes by wrong predict by structure prediction. + # eg.PMC5994107_011_00.png + thead_part = deal_isolate_span(thead_part) + # replace original result with new thead part. + result_token = result_token.replace(origin_thead_part, thead_part) + return result_token + + +class Matcher: + def __init__(self, end2end_file, structure_master_file): + """ + This class process the end2end results and structure recognition results. + :param end2end_file: end2end results predict by end2end inference. + :param structure_master_file: structure recognition results predict by structure master inference. + """ + self.end2end_file = end2end_file + self.structure_master_file = structure_master_file + self.end2end_results = pickle_load(end2end_file, prefix='end2end') + self.structure_master_results = pickle_load( + structure_master_file, prefix='structure') + + def match(self): + """ + Match process: + pre-process : convert end2end and structure master results to xyxy, xywh ndnarray format. + 1. Use pseBbox is inside masterBbox judge rule + 2. Use iou between pseBbox and masterBbox rule + 3. Use min distance of center point rule + :return: + """ + match_results = dict() + for idx, (file_name, + end2end_result) in enumerate(self.end2end_results.items()): + match_list = [] + if file_name not in self.structure_master_results: + continue + structure_master_result = self.structure_master_results[file_name] + end2end_xyxy_bboxes, end2end_xywh_bboxes, structure_master_xywh_bboxes, structure_master_xyxy_bboxes = \ + get_bboxes_list(end2end_result, structure_master_result) + + # rule 1: center rule + center_rule_match_list = \ + center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes) + match_list.extend(center_rule_match_list) + + # rule 2: iou rule + # firstly, find not match index in previous step. + center_no_match_end2end_indexs = \ + find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end') + if len(center_no_match_end2end_indexs) > 0: + center_no_match_end2end_xyxy = end2end_xyxy_bboxes[ + center_no_match_end2end_indexs] + # secondly, iou rule match + iou_rule_match_list = \ + iou_rule_match(center_no_match_end2end_xyxy, center_no_match_end2end_indexs, structure_master_xyxy_bboxes) + match_list.extend(iou_rule_match_list) + + # rule 3: distance rule + # match between no-match end2end bboxes and no-match master bboxes. + # it will return master_bboxes_nums match-pairs. + # firstly, find not match index in previous step. 
+ centerIou_no_match_end2end_indexs = \ + find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end') + centerIou_no_match_master_indexs = \ + find_no_match(match_list, len(structure_master_xywh_bboxes), type='master') + if len(centerIou_no_match_master_indexs) > 0 and len( + centerIou_no_match_end2end_indexs) > 0: + centerIou_no_match_end2end_xywh = end2end_xywh_bboxes[ + centerIou_no_match_end2end_indexs] + centerIou_no_match_master_xywh = structure_master_xywh_bboxes[ + centerIou_no_match_master_indexs] + distance_match_list = distance_rule_match( + centerIou_no_match_end2end_indexs, + centerIou_no_match_end2end_xywh, + centerIou_no_match_master_indexs, + centerIou_no_match_master_xywh) + match_list.extend(distance_match_list) + + # TODO: + # The render no-match pseBbox, insert the last + # After step3 distance rule, a master bbox at least match one end2end bbox. + # But end2end bbox maybe overmuch, because numbers of master bbox will cut by max length. + # For these render end2end bboxes, we will make some virtual master bboxes, and get matching. + # The above extra insert bboxes will be further processed in "formatOutput" function. + # After this operation, it will increase TEDS score. + no_match_end2end_indexes = \ + find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end') + if len(no_match_end2end_indexes) > 0: + no_match_end2end_xywh = end2end_xywh_bboxes[ + no_match_end2end_indexes] + # sort the render no-match end2end bbox in row + end2end_sorted_indexes_list, end2end_sorted_bboxes_list, sorted_groups, sorted_bboxes_groups = \ + sort_bbox(no_match_end2end_xywh, no_match_end2end_indexes) + # make virtual master bboxes, and get matching with the no-match end2end bboxes. + extra_match_list = extra_match( + end2end_sorted_indexes_list, + len(structure_master_xywh_bboxes)) + match_list_add_extra_match = copy.deepcopy(match_list) + match_list_add_extra_match.extend(extra_match_list) + else: + # no no-match end2end bboxes + match_list_add_extra_match = copy.deepcopy(match_list) + sorted_groups = [] + sorted_bboxes_groups = [] + + match_result_dict = { + 'match_list': match_list, + 'match_list_add_extra_match': match_list_add_extra_match, + 'sorted_groups': sorted_groups, + 'sorted_bboxes_groups': sorted_bboxes_groups + } + + # ordinary match show + # match_visual(file_name, match_list, end2end_xyxy_bboxes, structure_master_xyxy_bboxes, prex='ordinary_match') + # extra match show + # match_visual(file_name, match_list_add_extra_match, end2end_xyxy_bboxes, structure_master_xyxy_bboxes, prex='extra_match') + + # format output + match_result_dict = self._format(match_result_dict, file_name) + + match_results[file_name] = match_result_dict + + return match_results + + def _format(self, match_result, file_name): + """ + Extend the master token(insert virtual master token), and format matching result. 
+ :param match_result: + :param file_name: + :return: + """ + end2end_info = self.end2end_results[file_name] + master_info = self.structure_master_results[file_name] + master_token = master_info['text'] + sorted_groups = match_result['sorted_groups'] + + # creat virtual master token + virtual_master_token_list = [] + for line_group in sorted_groups: + tmp_list = [''] + item_nums = len(line_group) + for _ in range(item_nums): + tmp_list.append('') + tmp_list.append('') + virtual_master_token_list.extend(tmp_list) + + # insert virtual master token + master_token_list = master_token.split(',') + if master_token_list[-1] == '': + # complete predict(no cut by max length) + # This situation insert virtual master token will drop TEDs score in val set. + # So we will not extend virtual token in this situation. + + # fake extend virtual + master_token_list[:-1].extend(virtual_master_token_list) + + # real extend virtual + # master_token_list = master_token_list[:-1] + # master_token_list.extend(virtual_master_token_list) + # master_token_list.append('') + + elif master_token_list[-1] == '': + master_token_list.append('') + master_token_list.extend(virtual_master_token_list) + master_token_list.append('') + else: + master_token_list.extend(virtual_master_token_list) + master_token_list.append('') + + # format output + match_result.setdefault('matched_master_token_list', master_token_list) + return match_result + + def get_merge_result(self, match_results): + """ + Merge the OCR result into structure token to get final results. + :param match_results: + :return: + """ + merged_results = dict() + + # break_token is linefeed token, when one master bbox has multiply end2end bboxes. + break_token = ' ' + + for idx, (file_name, match_info) in enumerate(match_results.items()): + end2end_info = self.end2end_results[file_name] + master_token_list = match_info['matched_master_token_list'] + match_list = match_info['match_list_add_extra_match'] + + match_dict = get_match_dict(match_list) + match_text_dict = get_match_text_dict(match_dict, end2end_info, + break_token) + merged_result = insert_text_to_token(master_token_list, + match_text_dict) + merged_result = deal_bb(merged_result) + + merged_results[file_name] = merged_result + + return merged_results + + +class TableMasterMatcher(Matcher): + def __init__(self): + pass + + def __call__(self, structure_res, dt_boxes, rec_res, img_name=1): + end2end_results = {img_name: []} + for dt_box, res in zip(dt_boxes, rec_res): + d = dict( + bbox=np.array(dt_box), + text=res[0], ) + end2end_results[img_name].append(d) + + self.end2end_results = end2end_results + + structure_master_result_dict = {img_name: {}} + pred_structures, pred_bboxes = structure_res + pred_structures = ','.join(pred_structures[3:-3]) + structure_master_result_dict[img_name]['text'] = pred_structures + structure_master_result_dict[img_name]['bbox'] = pred_bboxes + self.structure_master_results = structure_master_result_dict + + # match + match_results = self.match() + merged_results = self.get_merge_result(match_results) + pred_html = merged_results[img_name] + # pred_html = '' + pred_html + '
diff --git a/ppstructure/utility.py b/ppstructure/utility.py
index af0616239..f5388fabf 100644
--- a/ppstructure/utility.py
+++ b/ppstructure/utility.py
@@ -32,6 +32,7 @@ def init_args():
         type=str,
         default="../ppocr/utils/dict/table_structure_dict.txt")
     # params for layout
+    parser.add_argument("--layout_model_dir", type=str)
     parser.add_argument(
         "--layout_path_model",
         type=str,
@@ -87,7 +88,7 @@ def draw_structure_result(image, result, font_path):
     image = Image.fromarray(image)
     boxes, txts, scores = [], [], []
     for region in result:
-        if region['type'] == 'Table':
+        if region['type'] == 'table':
             pass
         else:
             for text_result in region['res']:
diff --git a/test_tipc/configs/en_table_structure/table_mv3.yml b/test_tipc/configs/en_table_structure/table_mv3.yml
index 5d8e84c95..6ff31fc26 100755
--- a/test_tipc/configs/en_table_structure/table_mv3.yml
+++ b/test_tipc/configs/en_table_structure/table_mv3.yml
@@ -19,8 +19,6 @@ Global:
   character_type: en
   max_text_length: 800
   infer_mode: False
-  process_total_num: 0
-  process_cut_num: 0

 Optimizer:
   name: Adam
diff --git a/test_tipc/configs/table_master/table_master.yml b/test_tipc/configs/table_master/table_master.yml
index c519b5b8f..27f81683b 100644
--- a/test_tipc/configs/table_master/table_master.yml
+++ b/test_tipc/configs/table_master/table_master.yml
@@ -16,8 +16,6 @@ Global:
   character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
   infer_mode: false
   max_text_length: 500
-  process_total_num: 0
-  process_cut_num: 0


 Optimizer:
@@ -86,7 +84,7 @@ Train:
       - PaddingTableImage:
           size: [480, 480]
       - TableBoxEncode:
-          use_xywh: True
+          box_format: 'xywh'
       - NormalizeImage:
           scale: 1./255.
           mean: [0.5, 0.5, 0.5]
@@ -120,7 +118,7 @@ Eval:
       - PaddingTableImage:
           size: [480, 480]
       - TableBoxEncode:
-          use_xywh: True
+          box_format: 'xywh'
       - NormalizeImage:
           scale: 1./255.
           mean: [0.5, 0.5, 0.5]
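# [Editor's sketch] box_format replaces the old use_xywh flag in these configs.
# A minimal xyxy -> xywh conversion as implied here, assuming 'xywh' denotes
# (center_x, center_y, width, height) as used for TableMaster; see
# TableBoxEncode in ppocr/data/imaug for the authoritative behavior.
import numpy as np

def xyxy2xywh_sketch(bboxes):
    """Convert an N x 4 array from (x1, y1, x2, y2) to (cx, cy, w, h)."""
    bboxes = np.asarray(bboxes, dtype=np.float32)
    out = np.empty_like(bboxes)
    out[:, 0] = (bboxes[:, 0] + bboxes[:, 2]) / 2  # center x
    out[:, 1] = (bboxes[:, 1] + bboxes[:, 3]) / 2  # center y
    out[:, 2] = bboxes[:, 2] - bboxes[:, 0]        # width
    out[:, 3] = bboxes[:, 3] - bboxes[:, 1]        # height
    return out

print(xyxy2xywh_sketch([[10, 20, 50, 40]]))  # [[30. 30. 40. 20.]]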
diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py
index 625d365f4..73b7155ba 100755
--- a/tools/infer/predict_system.py
+++ b/tools/infer/predict_system.py
@@ -65,9 +65,11 @@ class TextSystem(object):
         self.crop_image_res_index += bbox_num

     def __call__(self, img, cls=True):
+        time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
+        start = time.time()
         ori_im = img.copy()
         dt_boxes, elapse = self.text_detector(img)
-
+        time_dict['det'] = elapse
         logger.debug("dt_boxes num : {}, elapse : {}".format(
             len(dt_boxes), elapse))
         if dt_boxes is None:
@@ -83,10 +85,12 @@ class TextSystem(object):
         if self.use_angle_cls and cls:
             img_crop_list, angle_list, elapse = self.text_classifier(
                 img_crop_list)
+            time_dict['cls'] = elapse
             logger.debug("cls num : {}, elapse : {}".format(
                 len(img_crop_list), elapse))

         rec_res, elapse = self.text_recognizer(img_crop_list)
+        time_dict['rec'] = elapse
         logger.debug("rec_res num : {}, elapse : {}".format(
             len(rec_res), elapse))
         if self.args.save_crop_res:
@@ -98,7 +102,9 @@ class TextSystem(object):
             if score >= self.drop_score:
                 filter_boxes.append(box)
                 filter_rec_res.append(rec_result)
-        return filter_boxes, filter_rec_res
+        end = time.time()
+        time_dict['all'] = end - start
+        return filter_boxes, filter_rec_res, time_dict


 def sorted_boxes(dt_boxes):
@@ -133,9 +139,11 @@ def main(args):
     os.makedirs(draw_img_save_dir, exist_ok=True)
     save_results = []

-    logger.info("In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
-                "if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320")
-
+    logger.info(
+        "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
+        "if you are using a recognition model from PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320'"
+    )
+
     # warm up 10 times
     if args.warmup:
         img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
@@ -155,7 +163,7 @@ def main(args):
             logger.debug("error in loading image:{}".format(image_file))
             continue
         starttime = time.time()
-        dt_boxes, rec_res = text_sys(img)
+        dt_boxes, rec_res, time_dict = text_sys(img)
         elapse = time.time() - starttime
         total_time += elapse
@@ -198,7 +206,10 @@ def main(args):
         text_sys.text_detector.autolog.report()
         text_sys.text_recognizer.autolog.report()

-    with open(os.path.join(draw_img_save_dir, "system_results.txt"), 'w', encoding='utf-8') as f:
+    with open(
+            os.path.join(draw_img_save_dir, "system_results.txt"),
+            'w',
+            encoding='utf-8') as f:
         f.writelines(save_results)
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 7eb77dec7..6ad770e28 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -155,6 +155,8 @@ def create_predictor(args, mode, logger):
         model_dir = args.table_model_dir
     elif mode == 'ser':
         model_dir = args.ser_model_dir
+    elif mode == 'layout':
+        model_dir = args.layout_model_dir
     else:
         model_dir = args.e2e_model_dir
diff --git a/tools/infer_table.py b/tools/infer_table.py
index 6c02dd864..70dc6205d 100644
--- a/tools/infer_table.py
+++ b/tools/infer_table.py
@@ -56,7 +56,6 @@ def main(config, device, logger, vdl_writer):
     model = build_model(config['Architecture'])

     algorithm = config['Architecture']['algorithm']
-    use_xywh = algorithm in ['TableMaster']

     load_model(config, model)
@@ -106,7 +105,7 @@ def main(config, device, logger, vdl_writer):
                 f_w.write("result: {}, {}\n".format(structure_str_list,
                                                     bbox_list_str))

-            img = draw_rectangle(file, bbox_list, use_xywh)
+            img = draw_rectangle(file, bbox_list)
             cv2.imwrite(
                 os.path.join(save_res_path, os.path.basename(file)), img)
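# [Editor's sketch] The predict_system.py change above threads a per-stage
# timing dict through TextSystem.__call__. The same pattern in isolation, with
# stand-in stage functions (the real stages are the det/cls/rec predictors):
import time

def timed_pipeline(x, stages):
    """Run named stages in order, recording each stage's latency and the total."""
    time_dict = {name: 0.0 for name, _ in stages}
    start = time.time()
    for name, fn in stages:
        t0 = time.time()
        x = fn(x)
        time_dict[name] = time.time() - t0
    time_dict['all'] = time.time() - start
    return x, time_dict

_, times = timed_pipeline('img', [('det', lambda v: v),
                                  ('cls', lambda v: v),
                                  ('rec', lambda v: v)])
print(times)  # {'det': ..., 'cls': ..., 'rec': ..., 'all': ...}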
    logger.info("success!")
diff --git a/tools/program.py b/tools/program.py
index 0fa0e609b..1802e8529 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -154,6 +154,7 @@ def check_xpu(use_xpu):
         except Exception as e:
             pass

+
 def to_float32(preds):
     if isinstance(preds, dict):
         for k in preds:
@@ -173,6 +174,7 @@ def to_float32(preds):
         preds = preds.astype(paddle.float32)
     return preds

+
 def train(config,
           train_dataloader,
           valid_dataloader,
@@ -596,7 +598,7 @@ def preprocess(is_train=False):
         'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
         'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
         'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
-        'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN'
+        'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'SLANet'
     ]

     if use_xpu:
diff --git a/tools/train.py b/tools/train.py
index 309d4bb9e..0e45b5b70 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -119,6 +119,10 @@ def main(config, device, logger, vdl_writer):
             config['Loss']['ignore_index'] = char_num - 1

     model = build_model(config['Architecture'])
+    use_sync_bn = config["Global"].get("use_sync_bn", False)
+    if use_sync_bn:
+        model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+        logger.info('convert_sync_batchnorm')

     if config['Global']['distributed']:
         model = paddle.DataParallel(model)
@@ -157,7 +161,8 @@ def main(config, device, logger, vdl_writer):
         scaler = paddle.amp.GradScaler(
             init_loss_scaling=scale_loss,
             use_dynamic_loss_scaling=use_dynamic_loss_scaling)
-        model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer, level='O2', master_weight=True)
+        model, optimizer = paddle.amp.decorate(
+            models=model, optimizers=optimizer, level='O2', master_weight=True)
     else:
         scaler = None
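# [Editor's sketch] The use_sync_bn option added to tools/train.py converts
# BatchNorm layers to synchronized BatchNorm for multi-GPU training. A minimal
# demonstration of the conversion call on a toy model (requires PaddlePaddle):
import paddle

model = paddle.nn.Sequential(
    paddle.nn.Conv2D(3, 8, kernel_size=3),
    paddle.nn.BatchNorm2D(8),
    paddle.nn.ReLU())
model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(model)  # the BatchNorm2D layer is now SyncBatchNorm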