diff --git a/.gitignore b/.gitignore
index 3300be325..410be83f4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -32,3 +32,28 @@ paddleocr.egg-info/
/deploy/android_demo/app/cache/
test_tipc/web/models/
test_tipc/web/node_modules/
+en_ppocr_mobile_v2.0_table_structure_infer/._inference.pdiparams
+en_ppocr_mobile_v2.0_table_structure_infer/._inference.pdiparams.info
+en_ppocr_mobile_v2.0_table_structure_infer/._inference.pdmodel
+en_ppocr_mobile_v2.0_table_structure_infer/inference.pdiparams
+en_ppocr_mobile_v2.0_table_structure_infer/inference.pdiparams.info
+en_ppocr_mobile_v2.0_table_structure_infer/inference.pdmodel
+ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/._inference.pdiparams
+ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/._inference.pdiparams.info
+ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/._inference.pdmodel
+ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/inference.pdiparams
+ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/inference.pdiparams.info
+ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/inference.pdmodel
+ppstructure/layout/en_ppocr_mobile_v2.0_table_rec_infer/inference.pdiparams
+ppstructure/layout/en_ppocr_mobile_v2.0_table_rec_infer/inference.pdiparams.info
+ppstructure/layout/en_ppocr_mobile_v2.0_table_rec_infer/inference.pdmodel
+ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape/infer_cfg.yml
+ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape/inference.pdiparams
+ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape/inference.pdiparams.info
+ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape/inference.pdmodel
+ppstructure/layout/table/inference.pdiparams
+ppstructure/layout/table/inference.pdiparams.info
+ppstructure/layout/table/inference.pdmodel
+ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape.tar
+._en_ppocr_mobile_v2.0_table_structure_infer
+en_ppocr_mobile_v2.0_table_structure_infer.tar
diff --git a/configs/table/SLANet.yml b/configs/table/SLANet.yml
new file mode 100644
index 000000000..ee2584d52
--- /dev/null
+++ b/configs/table/SLANet.yml
@@ -0,0 +1,141 @@
+Global:
+ use_gpu: true
+ epoch_num: 400
+ log_smooth_window: 20
+ print_batch_step: 20
+ save_model_dir: ./output/SLANet
+ save_epoch_step: 400
+ # evaluation is run every 1000 iterations after the 0th iteration
+ eval_batch_step: [0, 1000]
+ cal_metric_during_train: True
+ pretrained_model:
+  checkpoints:
+ save_inference_dir: ./output/SLANet/infer
+ use_visualdl: False
+ infer_img: doc/table/table.jpg
+ # for data or label process
+ character_dict_path: ppocr/utils/dict/table_structure_dict.txt
+ character_type: en
+ max_text_length: &max_text_length 500
+ box_format: &box_format 'xyxy' # 'xywh', 'xyxy', 'xyxyxyxy'
+ infer_mode: False
+ use_sync_bn: True
+ save_res_path: 'output/infer'
+
+Optimizer:
+ name: Adam
+ beta1: 0.9
+ beta2: 0.999
+ clip_norm: 5.0
+ lr:
+ # name: Piecewise
+ learning_rate: 0.001
+ # decay_epochs : [10, 20]
+ # values : [0.002, 0.0002, 0.0001]
+ # warmup_epoch: 0
+ regularizer:
+ name: 'L2'
+ factor: 0.00000
+
+Architecture:
+ model_type: table
+ algorithm: SLANet
+ Backbone:
+ name: PPLCNet
+ scale: 1.0
+ pretrained: true
+ use_ssld: true
+ Neck:
+ name: CSPPAN
+ out_channels: 96
+ Head:
+ name: SLAHead
+ hidden_size: 256
+ max_text_length: *max_text_length
+ loc_reg_num: &loc_reg_num 4
+
+Loss:
+ name: SLANetLoss
+ structure_weight: 1.0
+ loc_weight: 2.0
+ loc_loss: smooth_l1
+
+PostProcess:
+ name: TableLabelDecode
+
+Metric:
+ name: TableMetric
+ main_indicator: acc
+ compute_bbox_metric: False
+ loc_reg_num: *loc_reg_num
+ box_format: *box_format
+
+Train:
+ dataset:
+ name: PubTabDataSet
+ data_dir: /home/zhoujun20/table/PubTabNe/pubtabnet/train/
+ label_file_list: [/home/zhoujun20/table/PubTabNe/pubtabnet/PubTabNet_2.0.0_train.jsonl]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - TableLabelEncode:
+ learn_empty_box: False
+ merge_no_span_structure: False
+ replace_empty_cell_token: False
+ loc_reg_num: *loc_reg_num
+ max_text_length: *max_text_length
+ - TableBoxEncode:
+ box_format: *box_format
+ - ResizeTableImage:
+ max_len: 488
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: 'hwc'
+ - PaddingTableImage:
+ size: [488, 488]
+ - ToCHWImage:
+ - KeepKeys:
+ keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
+ loader:
+ shuffle: True
+ batch_size_per_card: 48
+ drop_last: True
+ num_workers: 1
+
+Eval:
+ dataset:
+ name: PubTabDataSet
+ data_dir: /home/zhoujun20/table/PubTabNe/pubtabnet/val/
+ label_file_list: [/home/zhoujun20/table/PubTabNe/pubtabnet/PubTabNet_2.0.0_val.jsonl]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - TableLabelEncode:
+ learn_empty_box: False
+ merge_no_span_structure: False
+ replace_empty_cell_token: False
+ loc_reg_num: *loc_reg_num
+ max_text_length: *max_text_length
+ - TableBoxEncode:
+ box_format: *box_format
+ - ResizeTableImage:
+ max_len: 488
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: 'hwc'
+ - PaddingTableImage:
+ size: [488, 488]
+ - ToCHWImage:
+ - KeepKeys:
+ keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 48
+ num_workers: 1
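
Note: the &max_text_length, &box_format and &loc_reg_num anchors in this config keep Global, the head, the metric and the data transforms in sync, so each value only has to change in one place. A minimal sketch of how the aliases resolve, using plain PyYAML from the repo root (independent of PaddleOCR's own config loader):

    import yaml

    with open('configs/table/SLANet.yml') as f:
        cfg = yaml.safe_load(f)
    # aliases like *max_text_length resolve to the anchored values
    assert cfg['Architecture']['Head']['max_text_length'] == 500
    assert cfg['Metric']['loc_reg_num'] == 4
    assert cfg['Train']['dataset']['transforms'][2]['TableBoxEncode']['box_format'] == 'xyxy'
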
diff --git a/configs/table/table_master.yml b/configs/table/table_master.yml
index 1e6efe32d..8bed7d069 100755
--- a/configs/table/table_master.yml
+++ b/configs/table/table_master.yml
@@ -15,9 +15,8 @@ Global:
save_res_path: ./output/table_master
character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
infer_mode: false
- max_text_length: 500
- process_total_num: 0
- process_cut_num: 0
+ max_text_length: &max_text_length 500
+ box_format: &box_format 'xywh' # 'xywh', 'xyxy', 'xyxyxyxy'
Optimizer:
@@ -52,7 +51,8 @@ Architecture:
headers: 8
dropout: 0
d_ff: 2024
- max_text_length: 500
+ max_text_length: *max_text_length
+ loc_reg_num: &loc_reg_num 4
Loss:
name: TableMasterLoss
@@ -66,6 +66,7 @@ Metric:
name: TableMetric
main_indicator: acc
compute_bbox_metric: False
+ box_format: *box_format
Train:
dataset:
@@ -80,13 +81,15 @@ Train:
learn_empty_box: False
merge_no_span_structure: True
replace_empty_cell_token: True
+ loc_reg_num: *loc_reg_num
+ max_text_length: *max_text_length
- ResizeTableImage:
max_len: 480
resize_bboxes: True
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
- use_xywh: True
+ box_format: *box_format
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
@@ -114,13 +117,15 @@ Eval:
learn_empty_box: False
merge_no_span_structure: True
replace_empty_cell_token: True
+ loc_reg_num: *loc_reg_num
+ max_text_length: *max_text_length
- ResizeTableImage:
max_len: 480
resize_bboxes: True
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
- use_xywh: True
+ box_format: *box_format
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
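
Note: table_master.yml keeps box_format 'xywh' (center x, center y, width, height), while SLANet uses corner-form 'xyxy'. TableBoxEncode's xyxy2xywh helper is not shown in this diff; the conventional conversion it presumably implements is:

    import numpy as np

    def xyxy2xywh(boxes):
        # corner form [x1, y1, x2, y2] -> center form [cx, cy, w, h]
        boxes = np.asarray(boxes, dtype=np.float32)
        out = np.empty_like(boxes)
        out[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2  # cx
        out[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2  # cy
        out[:, 2] = boxes[:, 2] - boxes[:, 0]        # w
        out[:, 3] = boxes[:, 3] - boxes[:, 1]        # h
        return out

    print(xyxy2xywh([[10, 20, 110, 70]]))  # [[ 60.  45. 100.  50.]]
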
diff --git a/configs/table/table_mv3.yml b/configs/table/table_mv3.yml
index 66c1c83e1..87cda7db2 100755
--- a/configs/table/table_mv3.yml
+++ b/configs/table/table_mv3.yml
@@ -17,10 +17,9 @@ Global:
# for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en
- max_text_length: 800
+ max_text_length: &max_text_length 800
+ box_format: &box_format 'xyxy' # 'xywh', 'xyxy', 'xyxyxyxy'
infer_mode: False
- process_total_num: 0
- process_cut_num: 0
Optimizer:
name: Adam
@@ -44,7 +43,8 @@ Architecture:
name: TableAttentionHead
hidden_size: 256
loc_type: 2
- max_text_length: 800
+ max_text_length: *max_text_length
+ loc_reg_num: &loc_reg_num 4
Loss:
name: TableAttentionLoss
@@ -72,6 +72,8 @@ Train:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
+ loc_reg_num: *loc_reg_num
+ max_text_length: *max_text_length
- TableBoxEncode:
- ResizeTableImage:
max_len: 488
@@ -104,6 +106,8 @@ Eval:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
+ loc_reg_num: *loc_reg_num
+ max_text_length: *max_text_length
- TableBoxEncode:
- ResizeTableImage:
max_len: 488
diff --git a/deploy/hubserving/ocr_system/module.py b/deploy/hubserving/ocr_system/module.py
index 71a19c6b7..dff3abb48 100644
--- a/deploy/hubserving/ocr_system/module.py
+++ b/deploy/hubserving/ocr_system/module.py
@@ -118,7 +118,7 @@ class OCRSystem(hub.Module):
all_results.append([])
continue
starttime = time.time()
- dt_boxes, rec_res = self.text_sys(img)
+ dt_boxes, rec_res, _ = self.text_sys(img)
elapse = time.time() - starttime
logger.info("Predict time: {}".format(elapse))
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 97539faf2..ad391046a 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -571,7 +571,7 @@ class TableLabelEncode(AttnLabelEncode):
replace_empty_cell_token=False,
merge_no_span_structure=False,
learn_empty_box=False,
- point_num=2,
+ loc_reg_num=4,
**kwargs):
self.max_text_len = max_text_length
self.lower = False
@@ -593,7 +593,7 @@ class TableLabelEncode(AttnLabelEncode):
self.idx2char = {v: k for k, v in self.dict.items()}
self.character = dict_character
- self.point_num = point_num
+ self.loc_reg_num = loc_reg_num
self.pad_idx = self.dict[self.beg_str]
self.start_idx = self.dict[self.beg_str]
self.end_idx = self.dict[self.end_str]
@@ -649,7 +649,7 @@ class TableLabelEncode(AttnLabelEncode):
# encode box
bboxes = np.zeros(
- (self._max_text_len, self.point_num * 2), dtype=np.float32)
+ (self._max_text_len, self.loc_reg_num), dtype=np.float32)
bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32)
bbox_idx = 0
@@ -714,11 +714,11 @@ class TableMasterLabelEncode(TableLabelEncode):
replace_empty_cell_token=False,
merge_no_span_structure=False,
learn_empty_box=False,
- point_num=2,
+ loc_reg_num=4,
**kwargs):
super(TableMasterLabelEncode, self).__init__(
max_text_length, character_dict_path, replace_empty_cell_token,
- merge_no_span_structure, learn_empty_box, point_num, **kwargs)
+ merge_no_span_structure, learn_empty_box, loc_reg_num, **kwargs)
self.pad_idx = self.dict[self.pad_str]
self.unknown_idx = self.dict[self.unknown_str]
@@ -739,13 +739,14 @@ class TableMasterLabelEncode(TableLabelEncode):
class TableBoxEncode(object):
- def __init__(self, use_xywh=False, **kwargs):
- self.use_xywh = use_xywh
+ def __init__(self, box_format='xyxy', **kwargs):
+ assert box_format in ['xywh', 'xyxy', 'xyxyxyxy']
+ self.box_format = box_format
def __call__(self, data):
img_height, img_width = data['image'].shape[:2]
bboxes = data['bboxes']
- if self.use_xywh and bboxes.shape[1] == 4:
+ if self.box_format == 'xywh' and bboxes.shape[1] == 4:
bboxes = self.xyxy2xywh(bboxes)
bboxes[:, 0::2] /= img_width
bboxes[:, 1::2] /= img_height
@@ -1217,6 +1218,7 @@ class ABINetLabelEncode(BaseRecLabelEncode):
dict_character = [''] + dict_character
return dict_character
+
class SPINLabelEncode(AttnLabelEncode):
""" Convert between text-label and text-index """
@@ -1229,6 +1231,7 @@ class SPINLabelEncode(AttnLabelEncode):
super(SPINLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
self.lower = lower
+
def add_special_char(self, dict_character):
self.beg_str = "sos"
self.end_str = "eos"
@@ -1248,4 +1251,4 @@ class SPINLabelEncode(AttnLabelEncode):
padded_text[:len(target)] = target
data['label'] = np.array(padded_text)
- return data
\ No newline at end of file
+ return data
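
Note: after this change TableBoxEncode only converts to center form when box_format is 'xywh'; in every case it normalizes coordinates by the image size. A small sketch, assuming a data dict shaped like the pipeline's and that __call__ writes the normalized boxes back into data (as the full file does):

    import numpy as np
    from ppocr.data.imaug.label_ops import TableBoxEncode

    data = {
        'image': np.zeros((488, 488, 3), dtype=np.uint8),
        'bboxes': np.array([[10., 20., 110., 70.]], dtype=np.float32),
    }
    out = TableBoxEncode(box_format='xyxy')(data)
    print(out['bboxes'])  # approx [[0.0205 0.0410 0.2254 0.1434]], normalized to [0, 1]
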
diff --git a/ppocr/data/imaug/table_ops.py b/ppocr/data/imaug/table_ops.py
index 8d139190a..c2c2fb2be 100644
--- a/ppocr/data/imaug/table_ops.py
+++ b/ppocr/data/imaug/table_ops.py
@@ -206,7 +206,7 @@ class ResizeTableImage(object):
data['bboxes'] = data['bboxes'] * ratio
data['image'] = resize_img
data['src_img'] = img
- data['shape'] = np.array([resize_h, resize_w, ratio, ratio])
+ data['shape'] = np.array([height, width, ratio, ratio])
data['max_len'] = self.max_len
return data
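
Note: this one-line fix makes data['shape'] carry the original image size instead of the resized one. It matters because TableLabelDecode._bbox_decode (patched later in this diff) multiplies the normalized predictions by shape[0]/shape[1] directly, so boxes now land in original-image pixels. With illustrative numbers:

    import numpy as np

    ori_h, ori_w = 1000, 800               # image size before ResizeTableImage
    bbox = np.array([.10, .20, .50, .40])  # normalized xyxy prediction
    bbox[0::2] *= ori_w                    # x coords -> 80, 400 px in the source image
    bbox[1::2] *= ori_h                    # y coords -> 200, 400 px
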
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 30120ac56..4629f0fe4 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -51,7 +51,7 @@ from .basic_loss import DistanceLoss
from .combined_loss import CombinedLoss
# table loss
-from .table_att_loss import TableAttentionLoss
+from .table_att_loss import TableAttentionLoss, SLANetLoss
from .table_master_loss import TableMasterLoss
# vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
@@ -63,7 +63,7 @@ def build_loss(config):
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
- 'TableMasterLoss', 'SPINAttentionLoss'
+ 'TableMasterLoss', 'SPINAttentionLoss', 'SLANetLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/table_att_loss.py b/ppocr/losses/table_att_loss.py
index 3496c9072..d97715d54 100644
--- a/ppocr/losses/table_att_loss.py
+++ b/ppocr/losses/table_att_loss.py
@@ -22,65 +22,11 @@ from paddle.nn import functional as F
class TableAttentionLoss(nn.Layer):
- def __init__(self,
- structure_weight,
- loc_weight,
- use_giou=False,
- giou_weight=1.0,
- **kwargs):
+ def __init__(self, structure_weight, loc_weight, **kwargs):
super(TableAttentionLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
self.structure_weight = structure_weight
self.loc_weight = loc_weight
- self.use_giou = use_giou
- self.giou_weight = giou_weight
-
- def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
- '''
- :param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
- :param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
- :return: loss
- '''
- ix1 = paddle.maximum(preds[:, 0], bbox[:, 0])
- iy1 = paddle.maximum(preds[:, 1], bbox[:, 1])
- ix2 = paddle.minimum(preds[:, 2], bbox[:, 2])
- iy2 = paddle.minimum(preds[:, 3], bbox[:, 3])
-
- iw = paddle.clip(ix2 - ix1 + 1e-3, 0., 1e10)
- ih = paddle.clip(iy2 - iy1 + 1e-3, 0., 1e10)
-
- # overlap
- inters = iw * ih
-
- # union
- uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (
- preds[:, 3] - preds[:, 1] + 1e-3) + (bbox[:, 2] - bbox[:, 0] + 1e-3
- ) * (bbox[:, 3] - bbox[:, 1] +
- 1e-3) - inters + eps
-
- # ious
- ious = inters / uni
-
- ex1 = paddle.minimum(preds[:, 0], bbox[:, 0])
- ey1 = paddle.minimum(preds[:, 1], bbox[:, 1])
- ex2 = paddle.maximum(preds[:, 2], bbox[:, 2])
- ey2 = paddle.maximum(preds[:, 3], bbox[:, 3])
- ew = paddle.clip(ex2 - ex1 + 1e-3, 0., 1e10)
- eh = paddle.clip(ey2 - ey1 + 1e-3, 0., 1e10)
-
- # enclose erea
- enclose = ew * eh + eps
- giou = ious - (enclose - uni) / enclose
-
- loss = 1 - giou
-
- if reduction == 'mean':
- loss = paddle.mean(loss)
- elif reduction == 'sum':
- loss = paddle.sum(loss)
- else:
- raise NotImplementedError
- return loss
def forward(self, predicts, batch):
structure_probs = predicts['structure_probs']
@@ -100,20 +46,48 @@ class TableAttentionLoss(nn.Layer):
loc_targets_mask = loc_targets_mask[:, 1:, :]
loc_loss = F.mse_loss(loc_preds * loc_targets_mask,
loc_targets) * self.loc_weight
- if self.use_giou:
- loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask,
- loc_targets) * self.giou_weight
- total_loss = structure_loss + loc_loss + loc_loss_giou
- return {
- 'loss': total_loss,
- "structure_loss": structure_loss,
- "loc_loss": loc_loss,
- "loc_loss_giou": loc_loss_giou
- }
- else:
- total_loss = structure_loss + loc_loss
- return {
- 'loss': total_loss,
- "structure_loss": structure_loss,
- "loc_loss": loc_loss
- }
+
+ total_loss = structure_loss + loc_loss
+ return {
+ 'loss': total_loss,
+ "structure_loss": structure_loss,
+ "loc_loss": loc_loss
+ }
+
+
+class SLANetLoss(nn.Layer):
+ def __init__(self, structure_weight, loc_weight, loc_loss='mse', **kwargs):
+ super(SLANetLoss, self).__init__()
+ self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='mean')
+ self.structure_weight = structure_weight
+ self.loc_weight = loc_weight
+ self.loc_loss = loc_loss
+ self.eps = 1e-12
+
+ def forward(self, predicts, batch):
+ structure_probs = predicts['structure_probs']
+ structure_targets = batch[1].astype("int64")
+ structure_targets = structure_targets[:, 1:]
+
+ structure_loss = self.loss_func(structure_probs, structure_targets)
+
+ structure_loss = paddle.mean(structure_loss) * self.structure_weight
+
+ loc_preds = predicts['loc_preds']
+ loc_targets = batch[2].astype("float32")
+ loc_targets_mask = batch[3].astype("float32")
+ loc_targets = loc_targets[:, 1:, :]
+ loc_targets_mask = loc_targets_mask[:, 1:, :]
+
+ loc_loss = F.smooth_l1_loss(
+ loc_preds * loc_targets_mask,
+ loc_targets * loc_targets_mask,
+ reduction='sum') * self.loc_weight
+
+ loc_loss = loc_loss / (loc_targets_mask.sum() + self.eps)
+ total_loss = structure_loss + loc_loss
+ return {
+ 'loss': total_loss,
+ "structure_loss": structure_loss,
+ "loc_loss": loc_loss
+ }
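
Note: SLANetLoss sums the masked smooth-L1 loss and divides by the number of real (unmasked) cells instead of using reduction='mean', so padded decode steps do not dilute the location loss. A tiny numeric sketch of the same masking scheme:

    import paddle
    import paddle.nn.functional as F

    preds = paddle.rand([2, 5, 4])     # 2 samples, 5 decode steps, 4 box values
    targets = paddle.rand([2, 5, 4])
    mask = paddle.to_tensor([[[1.]] * 3 + [[0.]] * 2] * 2)  # last 2 steps are padding

    loss = F.smooth_l1_loss(preds * mask, targets * mask, reduction='sum')
    loss = loss / (mask.sum() + 1e-12)  # average over real cells only
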
diff --git a/ppocr/metrics/table_metric.py b/ppocr/metrics/table_metric.py
index fd2631e44..43dc1d761 100644
--- a/ppocr/metrics/table_metric.py
+++ b/ppocr/metrics/table_metric.py
@@ -59,7 +59,7 @@ class TableMetric(object):
def __init__(self,
main_indicator='acc',
compute_bbox_metric=False,
- point_num=2,
+ box_format='xyxy',
**kwargs):
"""
@@ -70,7 +70,7 @@ class TableMetric(object):
self.structure_metric = TableStructureMetric()
self.bbox_metric = DetMetric() if compute_bbox_metric else None
self.main_indicator = main_indicator
- self.point_num = point_num
+ self.box_format = box_format
self.reset()
def __call__(self, pred_label, batch=None, *args, **kwargs):
@@ -129,10 +129,14 @@ class TableMetric(object):
self.bbox_metric.reset()
def format_box(self, box):
- if self.point_num == 2:
+ if self.box_format == 'xyxy':
x1, y1, x2, y2 = box
box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
- elif self.point_num == 4:
+ elif self.box_format == 'xywh':
+ x, y, w, h = box
+ x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
+ box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
+ elif self.box_format == 'xyxyxyxy':
x1, y1, x2, y2, x3, y3, x4, y4 = box
box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
return box
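
Note: format_box maps every supported box format onto the 4-point polygon that DetMetric consumes; for 'xywh' that is center plus or minus half the size:

    x, y, w, h = 60, 45, 100, 50
    x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
    print([[x1, y1], [x2, y1], [x2, y2], [x1, y2]])
    # [[10, 20], [110, 20], [110, 70], [10, 70]], the same polygon as 'xyxy' [10, 20, 110, 70]
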
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index d4f5b15f5..f5d54150b 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -21,7 +21,10 @@ def build_backbone(config, model_type):
from .det_resnet import ResNet
from .det_resnet_vd import ResNet_vd
from .det_resnet_vd_sast import ResNet_SAST
- support_dict = ["MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST"]
+ from .det_pp_lcnet import PPLCNet
+ support_dict = [
+ "MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet"
+ ]
if model_type == "table":
from .table_master_resnet import TableResNetExtra
support_dict.append('TableResNetExtra')
diff --git a/ppocr/modeling/backbones/det_pp_lcnet.py b/ppocr/modeling/backbones/det_pp_lcnet.py
new file mode 100644
index 000000000..3f719e92b
--- /dev/null
+++ b/ppocr/modeling/backbones/det_pp_lcnet.py
@@ -0,0 +1,270 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import, division, print_function
+
+import os
+import paddle
+import paddle.nn as nn
+from paddle import ParamAttr
+from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Dropout, Linear
+from paddle.regularizer import L2Decay
+from paddle.nn.initializer import KaimingNormal
+from paddle.utils.download import get_path_from_url
+
+MODEL_URLS = {
+ "PPLCNet_x0.25":
+ "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_25_pretrained.pdparams",
+ "PPLCNet_x0.35":
+ "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_35_pretrained.pdparams",
+ "PPLCNet_x0.5":
+ "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_5_pretrained.pdparams",
+ "PPLCNet_x0.75":
+ "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_75_pretrained.pdparams",
+ "PPLCNet_x1.0":
+ "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x1_0_pretrained.pdparams",
+ "PPLCNet_x1.5":
+ "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x1_5_pretrained.pdparams",
+ "PPLCNet_x2.0":
+ "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_0_pretrained.pdparams",
+ "PPLCNet_x2.5":
+ "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_5_pretrained.pdparams"
+}
+
+MODEL_STAGES_PATTERN = {
+ "PPLCNet": ["blocks2", "blocks3", "blocks4", "blocks5", "blocks6"]
+}
+
+__all__ = list(MODEL_URLS.keys())
+
+# Each element (list) represents a depthwise block, composed of k, in_c, out_c, s, use_se.
+# k: kernel_size
+# in_c: input channel number in depthwise block
+# out_c: output channel number in depthwise block
+# s: stride in depthwise block
+# use_se: whether to use SE block
+
+NET_CONFIG = {
+ "blocks2":
+ # k, in_c, out_c, s, use_se
+ [[3, 16, 32, 1, False]],
+ "blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
+ "blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
+ "blocks5":
+ [[3, 128, 256, 2, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False],
+ [5, 256, 256, 1, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False]],
+ "blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True]]
+}
+
+
+def make_divisible(v, divisor=8, min_value=None):
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(self,
+ num_channels,
+ filter_size,
+ num_filters,
+ stride,
+ num_groups=1):
+ super().__init__()
+
+ self.conv = Conv2D(
+ in_channels=num_channels,
+ out_channels=num_filters,
+ kernel_size=filter_size,
+ stride=stride,
+ padding=(filter_size - 1) // 2,
+ groups=num_groups,
+ weight_attr=ParamAttr(initializer=KaimingNormal()),
+ bias_attr=False)
+
+ self.bn = BatchNorm(
+ num_filters,
+ param_attr=ParamAttr(regularizer=L2Decay(0.0)),
+ bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
+ self.hardswish = nn.Hardswish()
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ x = self.hardswish(x)
+ return x
+
+
+class DepthwiseSeparable(nn.Layer):
+ def __init__(self,
+ num_channels,
+ num_filters,
+ stride,
+ dw_size=3,
+ use_se=False):
+ super().__init__()
+ self.use_se = use_se
+ self.dw_conv = ConvBNLayer(
+ num_channels=num_channels,
+ num_filters=num_channels,
+ filter_size=dw_size,
+ stride=stride,
+ num_groups=num_channels)
+ if use_se:
+ self.se = SEModule(num_channels)
+ self.pw_conv = ConvBNLayer(
+ num_channels=num_channels,
+ filter_size=1,
+ num_filters=num_filters,
+ stride=1)
+
+ def forward(self, x):
+ x = self.dw_conv(x)
+ if self.use_se:
+ x = self.se(x)
+ x = self.pw_conv(x)
+ return x
+
+
+class SEModule(nn.Layer):
+ def __init__(self, channel, reduction=4):
+ super().__init__()
+ self.avg_pool = AdaptiveAvgPool2D(1)
+ self.conv1 = Conv2D(
+ in_channels=channel,
+ out_channels=channel // reduction,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.relu = nn.ReLU()
+ self.conv2 = Conv2D(
+ in_channels=channel // reduction,
+ out_channels=channel,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.hardsigmoid = nn.Hardsigmoid()
+
+ def forward(self, x):
+ identity = x
+ x = self.avg_pool(x)
+ x = self.conv1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.hardsigmoid(x)
+ x = paddle.multiply(x=identity, y=x)
+ return x
+
+
+class PPLCNet(nn.Layer):
+ def __init__(self,
+ in_channels=3,
+ scale=1.0,
+ pretrained=False,
+ use_ssld=False):
+ super().__init__()
+ self.out_channels = [
+ int(NET_CONFIG["blocks3"][-1][2] * scale),
+ int(NET_CONFIG["blocks4"][-1][2] * scale),
+ int(NET_CONFIG["blocks5"][-1][2] * scale),
+ int(NET_CONFIG["blocks6"][-1][2] * scale)
+ ]
+ self.scale = scale
+
+ self.conv1 = ConvBNLayer(
+ num_channels=in_channels,
+ filter_size=3,
+ num_filters=make_divisible(16 * scale),
+ stride=2)
+
+ self.blocks2 = nn.Sequential(* [
+ DepthwiseSeparable(
+ num_channels=make_divisible(in_c * scale),
+ num_filters=make_divisible(out_c * scale),
+ dw_size=k,
+ stride=s,
+ use_se=se)
+ for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks2"])
+ ])
+
+ self.blocks3 = nn.Sequential(* [
+ DepthwiseSeparable(
+ num_channels=make_divisible(in_c * scale),
+ num_filters=make_divisible(out_c * scale),
+ dw_size=k,
+ stride=s,
+ use_se=se)
+ for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks3"])
+ ])
+
+ self.blocks4 = nn.Sequential(* [
+ DepthwiseSeparable(
+ num_channels=make_divisible(in_c * scale),
+ num_filters=make_divisible(out_c * scale),
+ dw_size=k,
+ stride=s,
+ use_se=se)
+ for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks4"])
+ ])
+
+ self.blocks5 = nn.Sequential(* [
+ DepthwiseSeparable(
+ num_channels=make_divisible(in_c * scale),
+ num_filters=make_divisible(out_c * scale),
+ dw_size=k,
+ stride=s,
+ use_se=se)
+ for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks5"])
+ ])
+
+ self.blocks6 = nn.Sequential(* [
+ DepthwiseSeparable(
+ num_channels=make_divisible(in_c * scale),
+ num_filters=make_divisible(out_c * scale),
+ dw_size=k,
+ stride=s,
+ use_se=se)
+ for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks6"])
+ ])
+
+ if pretrained:
+ self._load_pretrained(
+ MODEL_URLS['PPLCNet_x{}'.format(scale)], use_ssld=use_ssld)
+
+ def forward(self, x):
+ outs = []
+ x = self.conv1(x)
+ x = self.blocks2(x)
+ x = self.blocks3(x)
+ outs.append(x)
+ x = self.blocks4(x)
+ outs.append(x)
+ x = self.blocks5(x)
+ outs.append(x)
+ x = self.blocks6(x)
+ outs.append(x)
+ return outs
+
+ def _load_pretrained(self, pretrained_url, use_ssld=False):
+ if use_ssld:
+ pretrained_url = pretrained_url.replace("_pretrained",
+ "_ssld_pretrained")
+ local_weight_path = get_path_from_url(
+ pretrained_url, os.path.expanduser("~/.paddleclas/weights"))
+ param_state_dict = paddle.load(local_weight_path)
+ self.set_dict(param_state_dict)
+ return
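
Note: PPLCNet returns the outputs of blocks3 through blocks6, i.e. four feature maps at strides 4/8/16/32, which is what the CSPPAN neck below expects. A quick shape check, a sketch only (pretrained=False avoids the weight download):

    import paddle

    net = PPLCNet(in_channels=3, scale=1.0, pretrained=False)
    feats = net(paddle.rand([1, 3, 488, 488]))
    print([tuple(f.shape) for f in feats])
    # [(1, 64, 122, 122), (1, 128, 61, 61), (1, 256, 31, 31), (1, 512, 16, 16)]
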
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index b4f18b372..d8289d458 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -42,14 +42,15 @@ def build_head(config):
#kie head
from .kie_sdmgr_head import SDMGRHead
- from .table_att_head import TableAttentionHead
+ from .table_att_head import TableAttentionHead, SLAHead
from .table_master_head import TableMasterHead
support_dict = [
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
- 'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead'
+ 'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
+ 'SLAHead'
]
#table head
diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py
index 4f39d6253..00b434105 100644
--- a/ppocr/modeling/heads/table_att_head.py
+++ b/ppocr/modeling/heads/table_att_head.py
@@ -18,12 +18,27 @@ from __future__ import print_function
+import math
import paddle
import paddle.nn as nn
+from paddle import ParamAttr
import paddle.nn.functional as F
import numpy as np
from .rec_att_head import AttentionGRUCell
+def get_para_bias_attr(l2_decay, k):
+ if l2_decay > 0:
+ regularizer = paddle.regularizer.L2Decay(l2_decay)
+ stdv = 1.0 / math.sqrt(k * 1.0)
+ initializer = nn.initializer.Uniform(-stdv, stdv)
+ else:
+ regularizer = None
+ initializer = None
+ weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
+ bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
+ return [weight_attr, bias_attr]
+
+
class TableAttentionHead(nn.Layer):
def __init__(self,
in_channels,
@@ -32,7 +46,7 @@ class TableAttentionHead(nn.Layer):
in_max_len=488,
max_text_length=800,
out_channels=30,
- point_num=2,
+ loc_reg_num=4,
**kwargs):
super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1]
@@ -56,7 +70,7 @@ class TableAttentionHead(nn.Layer):
else:
self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
self.loc_generator = nn.Linear(self.input_size + hidden_size,
- point_num * 2)
+ loc_reg_num)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
@@ -129,3 +143,121 @@ class TableAttentionHead(nn.Layer):
loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds)
return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
+
+
+class SLAHead(nn.Layer):
+ def __init__(self,
+ in_channels,
+ hidden_size,
+ out_channels=30,
+ max_text_length=500,
+ loc_reg_num=4,
+ fc_decay=0.0,
+ **kwargs):
+ """
+        @param in_channels: input feature channels; the last element is used
+        @param hidden_size: hidden size of the attention GRU cell and embedding
+        @param out_channels: number of structure token classes to predict
+        @param max_text_length: maximum length of the predicted structure sequence
+ """
+ super().__init__()
+ in_channels = in_channels[-1]
+ self.hidden_size = hidden_size
+ self.max_text_length = max_text_length
+ self.emb = self._char_to_onehot
+ self.num_embeddings = out_channels
+
+ # structure
+ self.structure_attention_cell = AttentionGRUCell(
+ in_channels, hidden_size, self.num_embeddings)
+ weight_attr, bias_attr = get_para_bias_attr(
+ l2_decay=fc_decay, k=hidden_size)
+ weight_attr1_1, bias_attr1_1 = get_para_bias_attr(
+ l2_decay=fc_decay, k=hidden_size)
+ weight_attr1_2, bias_attr1_2 = get_para_bias_attr(
+ l2_decay=fc_decay, k=hidden_size)
+ self.structure_generator = nn.Sequential(
+ nn.Linear(
+ self.hidden_size,
+ self.hidden_size,
+ weight_attr=weight_attr1_2,
+ bias_attr=bias_attr1_2),
+ nn.Linear(
+ hidden_size,
+ out_channels,
+ weight_attr=weight_attr,
+ bias_attr=bias_attr))
+ # loc
+ weight_attr1, bias_attr1 = get_para_bias_attr(
+ l2_decay=fc_decay, k=self.hidden_size)
+ weight_attr2, bias_attr2 = get_para_bias_attr(
+ l2_decay=fc_decay, k=self.hidden_size)
+ self.loc_generator = nn.Sequential(
+ nn.Linear(
+ self.hidden_size,
+ self.hidden_size,
+ weight_attr=weight_attr1,
+ bias_attr=bias_attr1),
+ nn.Linear(
+ self.hidden_size,
+ loc_reg_num,
+ weight_attr=weight_attr2,
+ bias_attr=bias_attr2),
+ nn.Sigmoid())
+
+ def forward(self, inputs, targets=None):
+ fea = inputs[-1]
+ batch_size = fea.shape[0]
+ # reshape
+ fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], -1])
+ fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
+
+ hidden = paddle.zeros((batch_size, self.hidden_size))
+ structure_preds = []
+ loc_preds = []
+ if self.training and targets is not None:
+ structure = targets[0]
+ for i in range(self.max_text_length + 1):
+ hidden, structure_step, loc_step = self._decode(structure[:, i],
+ fea, hidden)
+ structure_preds.append(structure_step)
+ loc_preds.append(loc_step)
+ else:
+ pre_chars = paddle.zeros(shape=[batch_size], dtype="int32")
+ max_text_length = paddle.to_tensor(self.max_text_length)
+ # for export
+ loc_step, structure_step = None, None
+ for i in range(max_text_length + 1):
+ hidden, structure_step, loc_step = self._decode(pre_chars, fea,
+ hidden)
+ pre_chars = structure_step.argmax(axis=1, dtype="int32")
+ structure_preds.append(structure_step)
+ loc_preds.append(loc_step)
+ structure_preds = paddle.stack(structure_preds, axis=1)
+ loc_preds = paddle.stack(loc_preds, axis=1)
+ if not self.training:
+ structure_preds = F.softmax(structure_preds)
+ return {'structure_probs': structure_preds, 'loc_preds': loc_preds}
+
+ def _decode(self, pre_chars, features, hidden):
+ """
+ Predict table label and coordinates for each step
+ @param pre_chars: Table label in previous step
+ @param features:
+ @param hidden: hidden status in previous step
+ @return:
+ """
+ emb_feature = self.emb(pre_chars)
+ # output shape is b * self.hidden_size
+ (output, hidden), alpha = self.structure_attention_cell(
+ hidden, features, emb_feature)
+
+ # structure
+ structure_step = self.structure_generator(output)
+ # loc
+ loc_step = self.loc_generator(output)
+ return hidden, structure_step, loc_step
+
+ def _char_to_onehot(self, input_char):
+        input_one_hot = F.one_hot(input_char, self.num_embeddings)
+        return input_one_hot
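
Note: SLAHead decodes autoregressively: at each step the attention GRU cell consumes the one-hot of the previous structure token and emits both a token distribution and a box. A small smoke test of the inference branch, a sketch with max_text_length shortened so the loop stays cheap:

    import paddle

    head = SLAHead(in_channels=[96], hidden_size=256, out_channels=30,
                   max_text_length=8, loc_reg_num=4)
    head.eval()
    out = head([paddle.rand([1, 96, 16, 16])])   # e.g. the last CSPPAN map
    print(out['structure_probs'].shape)  # [1, 9, 30]: max_text_length + 1 steps
    print(out['loc_preds'].shape)        # [1, 9, 4]: one box per structure token
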
diff --git a/ppocr/modeling/heads/table_master_head.py b/ppocr/modeling/heads/table_master_head.py
index fddbcc63f..486f9cbea 100644
--- a/ppocr/modeling/heads/table_master_head.py
+++ b/ppocr/modeling/heads/table_master_head.py
@@ -37,7 +37,7 @@ class TableMasterHead(nn.Layer):
d_ff=2048,
dropout=0,
max_text_length=500,
- point_num=2,
+ loc_reg_num=4,
**kwargs):
super(TableMasterHead, self).__init__()
hidden_size = in_channels[-1]
@@ -50,7 +50,7 @@ class TableMasterHead(nn.Layer):
self.cls_fc = nn.Linear(hidden_size, out_channels)
self.bbox_fc = nn.Sequential(
# nn.Linear(hidden_size, hidden_size),
- nn.Linear(hidden_size, point_num * 2),
+ nn.Linear(hidden_size, loc_reg_num),
nn.Sigmoid())
self.norm = nn.LayerNorm(hidden_size)
self.embedding = Embeddings(d_model=hidden_size, vocab=out_channels)
@@ -59,7 +59,7 @@ class TableMasterHead(nn.Layer):
self.SOS = out_channels - 3
self.PAD = out_channels - 1
self.out_channels = out_channels
- self.point_num = point_num
+ self.loc_reg_num = loc_reg_num
self.max_text_length = max_text_length
def make_mask(self, tgt):
@@ -105,7 +105,7 @@ class TableMasterHead(nn.Layer):
output = paddle.zeros(
[input.shape[0], self.max_text_length + 1, self.out_channels])
bbox_output = paddle.zeros(
- [input.shape[0], self.max_text_length + 1, self.point_num * 2])
+ [input.shape[0], self.max_text_length + 1, self.loc_reg_num])
max_text_length = paddle.to_tensor(self.max_text_length)
for i in range(max_text_length + 1):
target_mask = self.make_mask(input)
diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py
index e10b082d1..e3ae2d6ef 100644
--- a/ppocr/modeling/necks/__init__.py
+++ b/ppocr/modeling/necks/__init__.py
@@ -25,9 +25,10 @@ def build_neck(config):
from .fpn import FPN
from .fce_fpn import FCEFPN
from .pren_fpn import PRENFPN
+ from .csp_pan import CSPPAN
support_dict = [
'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
- 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN'
+ 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN'
]
module_name = config.pop('name')
diff --git a/ppocr/modeling/necks/csp_pan.py b/ppocr/modeling/necks/csp_pan.py
new file mode 100755
index 000000000..625508e99
--- /dev/null
+++ b/ppocr/modeling/necks/csp_pan.py
@@ -0,0 +1,325 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# The code is based on:
+# https://github.com/PaddlePaddle/PaddleDetection/blob/release%2F2.3/ppdet/modeling/necks/csp_pan.py
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+
+__all__ = ['CSPPAN']
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(self,
+ in_channel=96,
+ out_channel=96,
+ kernel_size=3,
+ stride=1,
+ groups=1,
+ act='leaky_relu'):
+ super(ConvBNLayer, self).__init__()
+ initializer = nn.initializer.KaimingUniform()
+ self.act = act
+ assert self.act in ['leaky_relu', "hard_swish"]
+ self.conv = nn.Conv2D(
+ in_channels=in_channel,
+ out_channels=out_channel,
+ kernel_size=kernel_size,
+ groups=groups,
+ padding=(kernel_size - 1) // 2,
+ stride=stride,
+ weight_attr=ParamAttr(initializer=initializer),
+ bias_attr=False)
+ self.bn = nn.BatchNorm2D(out_channel)
+
+ def forward(self, x):
+ x = self.bn(self.conv(x))
+ if self.act == "leaky_relu":
+ x = F.leaky_relu(x)
+ elif self.act == "hard_swish":
+ x = F.hardswish(x)
+ return x
+
+
+class DPModule(nn.Layer):
+ """
+ Depth-wise and point-wise module.
+ Args:
+ in_channel (int): The input channels of this Module.
+ out_channel (int): The output channels of this Module.
+ kernel_size (int): The conv2d kernel size of this Module.
+ stride (int): The conv2d's stride of this Module.
+ act (str): The activation function of this Module,
+ Now support `leaky_relu` and `hard_swish`.
+ """
+
+ def __init__(self,
+ in_channel=96,
+ out_channel=96,
+ kernel_size=3,
+ stride=1,
+ act='leaky_relu'):
+ super(DPModule, self).__init__()
+ initializer = nn.initializer.KaimingUniform()
+ self.act = act
+ self.dwconv = nn.Conv2D(
+ in_channels=in_channel,
+ out_channels=out_channel,
+ kernel_size=kernel_size,
+ groups=out_channel,
+ padding=(kernel_size - 1) // 2,
+ stride=stride,
+ weight_attr=ParamAttr(initializer=initializer),
+ bias_attr=False)
+ self.bn1 = nn.BatchNorm2D(out_channel)
+ self.pwconv = nn.Conv2D(
+ in_channels=out_channel,
+ out_channels=out_channel,
+ kernel_size=1,
+ groups=1,
+ padding=0,
+ weight_attr=ParamAttr(initializer=initializer),
+ bias_attr=False)
+ self.bn2 = nn.BatchNorm2D(out_channel)
+
+ def act_func(self, x):
+ if self.act == "leaky_relu":
+ x = F.leaky_relu(x)
+ elif self.act == "hard_swish":
+ x = F.hardswish(x)
+ return x
+
+ def forward(self, x):
+ x = self.act_func(self.bn1(self.dwconv(x)))
+ x = self.act_func(self.bn2(self.pwconv(x)))
+ return x
+
+
+class DarknetBottleneck(nn.Layer):
+ """The basic bottleneck block used in Darknet.
+ Each Block consists of two ConvModules and the input is added to the
+ final output. Each ConvModule is composed of Conv, BN, and act.
+ The first convLayer has filter size of 1x1 and the second one has the
+ filter size of 3x3.
+ Args:
+ in_channels (int): The input channels of this Module.
+ out_channels (int): The output channels of this Module.
+        expansion (float): Expansion ratio of the hidden channels. Default: 0.5
+ add_identity (bool): Whether to add identity to the out.
+ Default: True
+ use_depthwise (bool): Whether to use depthwise separable convolution.
+ Default: False
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ expansion=0.5,
+ add_identity=True,
+ use_depthwise=False,
+ act="leaky_relu"):
+ super(DarknetBottleneck, self).__init__()
+ hidden_channels = int(out_channels * expansion)
+ conv_func = DPModule if use_depthwise else ConvBNLayer
+ self.conv1 = ConvBNLayer(
+ in_channel=in_channels,
+ out_channel=hidden_channels,
+ kernel_size=1,
+ act=act)
+ self.conv2 = conv_func(
+ in_channel=hidden_channels,
+ out_channel=out_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ act=act)
+ self.add_identity = \
+ add_identity and in_channels == out_channels
+
+ def forward(self, x):
+ identity = x
+ out = self.conv1(x)
+ out = self.conv2(out)
+
+ if self.add_identity:
+ return out + identity
+ else:
+ return out
+
+
+class CSPLayer(nn.Layer):
+ """Cross Stage Partial Layer.
+ Args:
+ in_channels (int): The input channels of the CSP layer.
+ out_channels (int): The output channels of the CSP layer.
+ expand_ratio (float): Ratio to adjust the number of channels of the
+ hidden layer. Default: 0.5
+ num_blocks (int): Number of blocks. Default: 1
+ add_identity (bool): Whether to add identity in blocks.
+ Default: True
+        use_depthwise (bool): Whether to use depthwise separable convolution
+            in blocks. Default: False
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ expand_ratio=0.5,
+ num_blocks=1,
+ add_identity=True,
+ use_depthwise=False,
+ act="leaky_relu"):
+ super().__init__()
+ mid_channels = int(out_channels * expand_ratio)
+ self.main_conv = ConvBNLayer(in_channels, mid_channels, 1, act=act)
+ self.short_conv = ConvBNLayer(in_channels, mid_channels, 1, act=act)
+ self.final_conv = ConvBNLayer(
+ 2 * mid_channels, out_channels, 1, act=act)
+
+ self.blocks = nn.Sequential(* [
+ DarknetBottleneck(
+ mid_channels,
+ mid_channels,
+ kernel_size,
+ 1.0,
+ add_identity,
+ use_depthwise,
+ act=act) for _ in range(num_blocks)
+ ])
+
+ def forward(self, x):
+ x_short = self.short_conv(x)
+
+ x_main = self.main_conv(x)
+ x_main = self.blocks(x_main)
+
+ x_final = paddle.concat((x_main, x_short), axis=1)
+ return self.final_conv(x_final)
+
+
+class Channel_T(nn.Layer):
+ def __init__(self,
+ in_channels=[116, 232, 464],
+ out_channels=96,
+ act="leaky_relu"):
+ super(Channel_T, self).__init__()
+ self.convs = nn.LayerList()
+ for i in range(len(in_channels)):
+ self.convs.append(
+ ConvBNLayer(
+ in_channels[i], out_channels, 1, act=act))
+
+ def forward(self, x):
+ outs = [self.convs[i](x[i]) for i in range(len(x))]
+ return outs
+
+
+class CSPPAN(nn.Layer):
+ """Path Aggregation Network with CSP module.
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale)
+ kernel_size (int): The conv2d kernel size of this Module.
+ num_csp_blocks (int): Number of bottlenecks in CSPLayer. Default: 1
+        use_depthwise (bool): Whether to use depthwise separable convolution
+            in blocks. Default: True
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=5,
+ num_csp_blocks=1,
+ use_depthwise=True,
+ act='hard_swish'):
+ super(CSPPAN, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = [out_channels] * len(in_channels)
+ conv_func = DPModule if use_depthwise else ConvBNLayer
+
+ self.conv_t = Channel_T(in_channels, out_channels, act=act)
+
+ # build top-down blocks
+ self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
+ self.top_down_blocks = nn.LayerList()
+ for idx in range(len(in_channels) - 1, 0, -1):
+ self.top_down_blocks.append(
+ CSPLayer(
+ out_channels * 2,
+ out_channels,
+ kernel_size=kernel_size,
+ num_blocks=num_csp_blocks,
+ add_identity=False,
+ use_depthwise=use_depthwise,
+ act=act))
+
+ # build bottom-up blocks
+ self.downsamples = nn.LayerList()
+ self.bottom_up_blocks = nn.LayerList()
+ for idx in range(len(in_channels) - 1):
+ self.downsamples.append(
+ conv_func(
+ out_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=2,
+ act=act))
+ self.bottom_up_blocks.append(
+ CSPLayer(
+ out_channels * 2,
+ out_channels,
+ kernel_size=kernel_size,
+ num_blocks=num_csp_blocks,
+ add_identity=False,
+ use_depthwise=use_depthwise,
+ act=act))
+
+ def forward(self, inputs):
+ """
+ Args:
+ inputs (tuple[Tensor]): input features.
+ Returns:
+ tuple[Tensor]: CSPPAN features.
+ """
+ assert len(inputs) == len(self.in_channels)
+ inputs = self.conv_t(inputs)
+
+ # top-down path
+ inner_outs = [inputs[-1]]
+ for idx in range(len(self.in_channels) - 1, 0, -1):
+            feat_high = inner_outs[0]
+ feat_low = inputs[idx - 1]
+
+ upsample_feat = F.upsample(
+                feat_high, size=feat_low.shape[2:4], mode="nearest")
+
+ inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx](
+ paddle.concat([upsample_feat, feat_low], 1))
+ inner_outs.insert(0, inner_out)
+
+ # bottom-up path
+ outs = [inner_outs[0]]
+ for idx in range(len(self.in_channels) - 1):
+ feat_low = outs[-1]
+ feat_height = inner_outs[idx + 1]
+ downsample_feat = self.downsamples[idx](feat_low)
+ out = self.bottom_up_blocks[idx](paddle.concat(
+ [downsample_feat, feat_height], 1))
+ outs.append(out)
+
+ return tuple(outs)
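
Note: CSPPAN first projects every input level to out_channels (Channel_T), then runs the usual top-down and bottom-up PAN passes; all outputs share one channel width, which is why the SLANet config only states out_channels: 96. A sketch wiring it to the PPLCNet x1.0 shapes from above:

    import paddle

    neck = CSPPAN(in_channels=[64, 128, 256, 512], out_channels=96)
    feats = [paddle.rand([1, c, s, s])
             for c, s in zip([64, 128, 256, 512], [122, 61, 31, 16])]
    outs = neck(feats)
    print([tuple(o.shape) for o in outs])  # four maps, each with 96 channels
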
diff --git a/ppocr/postprocess/table_postprocess.py b/ppocr/postprocess/table_postprocess.py
index 4396ec4f7..ce254f314 100644
--- a/ppocr/postprocess/table_postprocess.py
+++ b/ppocr/postprocess/table_postprocess.py
@@ -23,7 +23,7 @@ class TableLabelDecode(AttnLabelDecode):
def __init__(self, character_dict_path, **kwargs):
super(TableLabelDecode, self).__init__(character_dict_path)
-        self.td_token = ['<td>', '<td', '<eb></eb>', '<td></td>']
+        self.td_token = ['<td>', '<td', '<td></td>']
def __call__(self, preds, batch=None):
structure_probs = preds['structure_probs']
@@ -114,10 +114,8 @@ class TableLabelDecode(AttnLabelDecode):
def _bbox_decode(self, bbox, shape):
h, w, ratio_h, ratio_w, pad_h, pad_w = shape
- src_h = h / ratio_h
- src_w = w / ratio_w
- bbox[0::2] *= src_w
- bbox[1::2] *= src_h
+ bbox[0::2] *= w
+ bbox[1::2] *= h
return bbox
@@ -157,4 +155,7 @@ class TableMasterLabelDecode(TableLabelDecode):
bbox[1::2] *= h
bbox[0::2] /= ratio_w
bbox[1::2] /= ratio_h
+ x, y, w, h = bbox
+ x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
+ bbox = np.array([x1, y1, x2, y2])
return bbox
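
Note: TableMasterLabelDecode now finishes by converting its native center-form [cx, cy, w, h] output to corner form, which is what lets draw_rectangle (below) drop its use_xywh branch. The conversion on one box:

    x, y, w, h = 240., 120., 80., 40.   # un-normalized center-form box
    x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
    print([x1, y1, x2, y2])  # [200.0, 100.0, 280.0, 140.0]
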
diff --git a/ppocr/utils/visual.py b/ppocr/utils/visual.py
index e0fbf06ab..030d1c38d 100644
--- a/ppocr/utils/visual.py
+++ b/ppocr/utils/visual.py
@@ -113,14 +113,10 @@ def draw_re_results(image,
return np.array(img_new)
-def draw_rectangle(img_path, boxes, use_xywh=False):
+def draw_rectangle(img_path, boxes):
img = cv2.imread(img_path)
img_show = img.copy()
for box in boxes.astype(int):
- if use_xywh:
- x, y, w, h = box
- x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
- else:
- x1, y1, x2, y2 = box
+ x1, y1, x2, y2 = box
cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
return img_show
\ No newline at end of file
diff --git a/ppstructure/layout/picodet_postprocess.py b/ppstructure/layout/picodet_postprocess.py
new file mode 100644
index 000000000..7df13f827
--- /dev/null
+++ b/ppstructure/layout/picodet_postprocess.py
@@ -0,0 +1,227 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from scipy.special import softmax
+
+
+def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
+ """
+ Args:
+ box_scores (N, 5): boxes in corner-form and probabilities.
+ iou_threshold: intersection over union threshold.
+ top_k: keep top_k results. If k <= 0, keep all the results.
+ candidate_size: only consider the candidates with the highest scores.
+ Returns:
+ picked: a list of indexes of the kept boxes
+ """
+ scores = box_scores[:, -1]
+ boxes = box_scores[:, :-1]
+ picked = []
+ indexes = np.argsort(scores)
+ indexes = indexes[-candidate_size:]
+ while len(indexes) > 0:
+ current = indexes[-1]
+ picked.append(current)
+ if 0 < top_k == len(picked) or len(indexes) == 1:
+ break
+ current_box = boxes[current, :]
+ indexes = indexes[:-1]
+ rest_boxes = boxes[indexes, :]
+ iou = iou_of(
+ rest_boxes,
+ np.expand_dims(
+ current_box, axis=0), )
+ indexes = indexes[iou <= iou_threshold]
+
+ return box_scores[picked, :]
+
+
+def iou_of(boxes0, boxes1, eps=1e-5):
+ """Return intersection-over-union (Jaccard index) of boxes.
+ Args:
+ boxes0 (N, 4): ground truth boxes.
+ boxes1 (N or 1, 4): predicted boxes.
+ eps: a small number to avoid 0 as denominator.
+ Returns:
+ iou (N): IoU values.
+ """
+ overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
+ overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])
+
+ overlap_area = area_of(overlap_left_top, overlap_right_bottom)
+ area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
+ area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
+ return overlap_area / (area0 + area1 - overlap_area + eps)
+
+
+def area_of(left_top, right_bottom):
+ """Compute the areas of rectangles given two corners.
+ Args:
+ left_top (N, 2): left top corner.
+ right_bottom (N, 2): right bottom corner.
+ Returns:
+ area (N): return the area.
+ """
+ hw = np.clip(right_bottom - left_top, 0.0, None)
+ return hw[..., 0] * hw[..., 1]
+
+
+class PicoDetPostProcess(object):
+ """
+ Args:
+        input_shape (tuple): network input size as (h, w)
+        ori_shape (list): original image shape before padding, per batch item
+        scale_factor (ndarray): scale factors between original image and network input
+        strides (list): feature map strides of the detection heads
+ """
+
+ def __init__(self,
+ input_shape,
+ ori_shape,
+ scale_factor,
+ strides=[8, 16, 32, 64],
+ score_threshold=0.4,
+ nms_threshold=0.5,
+ nms_top_k=1000,
+ keep_top_k=100):
+ self.ori_shape = ori_shape
+ self.input_shape = input_shape
+ self.scale_factor = scale_factor
+ self.strides = strides
+ self.score_threshold = score_threshold
+ self.nms_threshold = nms_threshold
+ self.nms_top_k = nms_top_k
+ self.keep_top_k = keep_top_k
+
+ def warp_boxes(self, boxes, ori_shape):
+ """Apply transform to boxes
+ """
+ width, height = ori_shape[1], ori_shape[0]
+ n = len(boxes)
+ if n:
+ # warp points
+ xy = np.ones((n * 4, 3))
+ xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
+ n * 4, 2) # x1y1, x2y2, x1y2, x2y1
+ # xy = xy @ M.T # transform
+ xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) # rescale
+ # create new boxes
+ x = xy[:, [0, 2, 4, 6]]
+ y = xy[:, [1, 3, 5, 7]]
+ xy = np.concatenate(
+ (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
+ # clip boxes
+ xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
+ xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
+ return xy.astype(np.float32)
+ else:
+ return boxes
+
+ def __call__(self, scores, raw_boxes):
+ batch_size = raw_boxes[0].shape[0]
+ reg_max = int(raw_boxes[0].shape[-1] / 4 - 1)
+ out_boxes_num = []
+ out_boxes_list = []
+ for batch_id in range(batch_size):
+ # generate centers
+ decode_boxes = []
+ select_scores = []
+ for stride, box_distribute, score in zip(self.strides, raw_boxes,
+ scores):
+ box_distribute = box_distribute[batch_id]
+ score = score[batch_id]
+ # centers
+ fm_h = self.input_shape[0] / stride
+ fm_w = self.input_shape[1] / stride
+ h_range = np.arange(fm_h)
+ w_range = np.arange(fm_w)
+ ww, hh = np.meshgrid(w_range, h_range)
+ ct_row = (hh.flatten() + 0.5) * stride
+ ct_col = (ww.flatten() + 0.5) * stride
+ center = np.stack((ct_col, ct_row, ct_col, ct_row), axis=1)
+
+ # box distribution to distance
+ reg_range = np.arange(reg_max + 1)
+ box_distance = box_distribute.reshape((-1, reg_max + 1))
+ box_distance = softmax(box_distance, axis=1)
+ box_distance = box_distance * np.expand_dims(reg_range, axis=0)
+ box_distance = np.sum(box_distance, axis=1).reshape((-1, 4))
+ box_distance = box_distance * stride
+
+ # top K candidate
+ topk_idx = np.argsort(score.max(axis=1))[::-1]
+ topk_idx = topk_idx[:self.nms_top_k]
+ center = center[topk_idx]
+ score = score[topk_idx]
+ box_distance = box_distance[topk_idx]
+
+ # decode box
+ decode_box = center + [-1, -1, 1, 1] * box_distance
+
+ select_scores.append(score)
+ decode_boxes.append(decode_box)
+
+ # nms
+ bboxes = np.concatenate(decode_boxes, axis=0)
+ confidences = np.concatenate(select_scores, axis=0)
+ picked_box_probs = []
+ picked_labels = []
+ for class_index in range(0, confidences.shape[1]):
+ probs = confidences[:, class_index]
+ mask = probs > self.score_threshold
+ probs = probs[mask]
+ if probs.shape[0] == 0:
+ continue
+ subset_boxes = bboxes[mask, :]
+ box_probs = np.concatenate(
+ [subset_boxes, probs.reshape(-1, 1)], axis=1)
+ box_probs = hard_nms(
+ box_probs,
+ iou_threshold=self.nms_threshold,
+ top_k=self.keep_top_k, )
+ picked_box_probs.append(box_probs)
+ picked_labels.extend([class_index] * box_probs.shape[0])
+
+ if len(picked_box_probs) == 0:
+                out_boxes_list.append(np.empty((0, 6)))
+ out_boxes_num.append(0)
+
+ else:
+ picked_box_probs = np.concatenate(picked_box_probs)
+
+ # resize output boxes
+ picked_box_probs[:, :4] = self.warp_boxes(
+ picked_box_probs[:, :4], self.ori_shape[batch_id])
+ im_scale = np.concatenate([
+ self.scale_factor[batch_id][::-1],
+ self.scale_factor[batch_id][::-1]
+ ])
+ picked_box_probs[:, :4] /= im_scale
+ # clas score box
+ out_boxes_list.append(
+ np.concatenate(
+ [
+ np.expand_dims(
+ np.array(picked_labels),
+ axis=-1), np.expand_dims(
+ picked_box_probs[:, 4], axis=-1),
+ picked_box_probs[:, :4]
+ ],
+ axis=1))
+ out_boxes_num.append(len(picked_labels))
+
+ out_boxes_list = np.concatenate(out_boxes_list, axis=0)
+ out_boxes_num = np.asarray(out_boxes_num).astype(np.int32)
+ return out_boxes_list, out_boxes_num
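
Note: hard_nms keeps the highest-scoring box, drops every remaining box whose IoU with it exceeds the threshold, and repeats until no candidates are left. A minimal check:

    import numpy as np

    box_scores = np.array([
        [10., 10., 60., 60., 0.9],
        [12., 12., 62., 62., 0.8],     # IoU ~0.85 with the first box: suppressed
        [100., 100., 150., 150., 0.7],
    ])
    kept = hard_nms(box_scores, iou_threshold=0.5)
    print(kept[:, -1])  # [0.9 0.7]
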
diff --git a/ppstructure/layout/predict_layout.py b/ppstructure/layout/predict_layout.py
new file mode 100644
index 000000000..2fb4b4623
--- /dev/null
+++ b/ppstructure/layout/predict_layout.py
@@ -0,0 +1,142 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import numpy as np
+import time
+
+import tools.infer.utility as utility
+from ppocr.data import create_operators, transform
+from ppocr.utils.logging import get_logger
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+from ppstructure.utility import parse_args
+from picodet_postprocess import PicoDetPostProcess
+
+logger = get_logger()
+
+
+class LayoutPredictor(object):
+ def __init__(self, args):
+ pre_process_list = [{
+ 'Resize': {
+ 'size': [800, 608]
+ }
+ }, {
+ 'NormalizeImage': {
+ 'std': [0.229, 0.224, 0.225],
+ 'mean': [0.485, 0.456, 0.406],
+ 'scale': '1./255.',
+ 'order': 'hwc'
+ }
+ }, {
+ 'ToCHWImage': None
+ }, {
+ 'KeepKeys': {
+ 'keep_keys': ['image']
+ }
+ }]
+
+        self.preprocess_op = create_operators(pre_process_list)
+ self.predictor, self.input_tensor, self.output_tensors, self.config = \
+ utility.create_predictor(args, 'layout', logger)
+
+ def __call__(self, img):
+ ori_im = img.copy()
+ data = {'image': img}
+ data = transform(data, self.preprocess_op)
+ img = data[0]
+
+ if img is None:
+ return None, 0
+
+ img = np.expand_dims(img, axis=0)
+ img = img.copy()
+
+        starttime = time.time()
+
+ self.input_tensor.copy_from_cpu(img)
+ self.predictor.run()
+
+ np_score_list, np_boxes_list = [], []
+ output_names = self.predictor.get_output_names()
+ num_outs = int(len(output_names) / 2)
+ for out_idx in range(num_outs):
+ np_score_list.append(
+ self.predictor.get_output_handle(output_names[out_idx])
+ .copy_to_cpu())
+ np_boxes_list.append(
+ self.predictor.get_output_handle(output_names[
+ out_idx + num_outs]).copy_to_cpu())
+        # the post-processor decodes PicoDet head outputs back to coordinates
+        # in the original image; the scale factors were hard-coded for one
+        # sample image, so derive them from the input image size instead
+        ori_h, ori_w = ori_im.shape[:2]
+        postprocessor = PicoDetPostProcess(
+            (800, 608), [[800., 608.]],
+            np.array([[800. / ori_h, 608. / ori_w]]),
+            strides=[8, 16, 32, 64],
+            nms_threshold=0.5)
+        np_boxes, np_boxes_num = postprocessor(np_score_list, np_boxes_list)
+
+        # keep this image's detections and filter them by score
+        im_bboxes_num = np_boxes_num[0]
+        np_boxes = np_boxes[0:im_bboxes_num, :]
+        threshold = 0.5
+        expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
+        np_boxes = np_boxes[expect_boxes, :]
+        preds = []
+
+ id2label = {1: 'text', 2: 'title', 3: 'list', 4: 'table', 5: 'figure'}
+ for dt in np_boxes:
+ clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
+ label = id2label[clsid + 1]
+ result_di = {'bbox': bbox, 'label': label}
+ preds.append(result_di)
+
+ elapse = time.time() - starttime
+ return preds, elapse
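+
+# Usage sketch (assuming `args` comes from ppstructure.utility.parse_args and
+# points to an exported layout detection model; the image path is a placeholder):
+#
+#     layout_predictor = LayoutPredictor(args)
+#     layout_res, elapse = layout_predictor(cv2.imread('path/to/image.jpg'))
+#     # -> [{'bbox': array([x1, y1, x2, y2]), 'label': 'table'}, ...]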
+
+
+def main(args):
+ image_file_list = get_image_file_list(args.image_dir)
+ layout_predictor = LayoutPredictor(args)
+ count = 0
+ total_time = 0
+
+ for image_file in image_file_list:
+ img, flag = check_and_read_gif(image_file)
+ if not flag:
+ img = cv2.imread(image_file)
+ if img is None:
+ logger.info("error in loading image:{}".format(image_file))
+ continue
+ layout_res, elapse = layout_predictor(img)
+
+ logger.info("result: {}".format(layout_res))
+
+ if count > 0:
+ total_time += elapse
+ count += 1
+ logger.info("Predict time of {}: {}".format(image_file, elapse))
+
+
+if __name__ == "__main__":
+ main(parse_args())
diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py
index d6f2e2424..075d91446 100644
--- a/ppstructure/predict_system.py
+++ b/ppstructure/predict_system.py
@@ -18,7 +18,7 @@ import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
@@ -32,6 +32,7 @@ from attrdict import AttrDict
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem
+from ppstructure.layout.predict_layout import LayoutPredictor
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args, draw_structure_result
from ppstructure.recovery.recovery_to_doc import convert_info_docx
@@ -51,28 +52,14 @@ class StructureSystem(object):
"When args.layout is false, args.ocr is automatically set to false"
)
args.drop_score = 0
- # init layout and ocr model
+ # init model
+ self.layout_predictor = None
self.text_system = None
+ self.table_system = None
if args.layout:
- import layoutparser as lp
- config_path = None
- model_path = None
- if os.path.isdir(args.layout_path_model):
- model_path = args.layout_path_model
- else:
- config_path = args.layout_path_model
- self.table_layout = lp.PaddleDetectionLayoutModel(
- config_path=config_path,
- model_path=model_path,
- label_map=args.layout_label_map,
- threshold=0.5,
- enable_mkldnn=args.enable_mkldnn,
- enforce_cpu=not args.use_gpu,
- thread_num=args.cpu_threads)
+ self.layout_predictor = LayoutPredictor(args)
if args.ocr:
self.text_system = TextSystem(args)
- else:
- self.table_layout = None
if args.table:
if self.text_system is not None:
self.table_system = TableSystem(
@@ -80,38 +67,59 @@ class StructureSystem(object):
self.text_system.text_recognizer)
else:
self.table_system = TableSystem(args)
- else:
- self.table_system = None
elif self.mode == 'vqa':
raise NotImplementedError
def __call__(self, img, return_ocr_result_in_table=False):
+ time_dict = {
+ 'layout': 0,
+ 'table': 0,
+ 'table_match': 0,
+ 'det': 0,
+ 'rec': 0,
+ 'vqa': 0,
+ 'all': 0
+ }
+ start = time.time()
if self.mode == 'structure':
ori_im = img.copy()
- if self.table_layout is not None:
- layout_res = self.table_layout.detect(img[..., ::-1])
+ if self.layout_predictor is not None:
+ layout_res, elapse = self.layout_predictor(img)
+ time_dict['layout'] += elapse
else:
h, w = ori_im.shape[:2]
- layout_res = [AttrDict(coordinates=[0, 0, w, h], type='Table')]
+ layout_res = [dict(bbox=None, label='table')]
res_list = []
for region in layout_res:
res = ''
- x1, y1, x2, y2 = region.coordinates
- x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
- roi_img = ori_im[y1:y2, x1:x2, :]
- if region.type == 'Table':
+ if region['bbox'] is not None:
+ x1, y1, x2, y2 = region['bbox']
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+ roi_img = ori_im[y1:y2, x1:x2, :]
+ else:
+ x1, y1, x2, y2 = 0, 0, w, h
+ roi_img = ori_im
+ if region['label'] == 'table':
if self.table_system is not None:
- res = self.table_system(roi_img,
- return_ocr_result_in_table)
+ res, table_time_dict = self.table_system(
+ roi_img, return_ocr_result_in_table)
+ time_dict['table'] += table_time_dict['table']
+ time_dict['table_match'] += table_time_dict['match']
+ time_dict['det'] += table_time_dict['det']
+ time_dict['rec'] += table_time_dict['rec']
else:
if self.text_system is not None:
if args.recovery:
wht_im = np.ones(ori_im.shape, dtype=ori_im.dtype)
wht_im[y1:y2, x1:x2, :] = roi_img
- filter_boxes, filter_rec_res = self.text_system(wht_im)
+ filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
+ wht_im)
else:
- filter_boxes, filter_rec_res = self.text_system(roi_img)
+ filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
+ roi_img)
+ time_dict['det'] += ocr_time_dict['det']
+ time_dict['rec'] += ocr_time_dict['rec']
# remove style char
style_token = [
                '<strike>', '<strike>', '<sup>', '</sub>', '<b>',
@@ -133,15 +141,17 @@ class StructureSystem(object):
'text_region': box.tolist()
})
res_list.append({
- 'type': region.type,
+ 'type': region['label'].lower(),
'bbox': [x1, y1, x2, y2],
'img': roi_img,
'res': res
})
- return res_list
+ end = time.time()
+ time_dict['all'] = end - start
+ return res_list, time_dict
elif self.mode == 'vqa':
raise NotImplementedError
- return None
+ return None, None
def save_structure_res(res, save_folder, img_name):
@@ -156,12 +166,12 @@ def save_structure_res(res, save_folder, img_name):
roi_img = region.pop('img')
f.write('{}\n'.format(json.dumps(region)))
- if region['type'] == 'Table' and len(region[
+ if region['type'] == 'table' and len(region[
'res']) > 0 and 'html' in region['res']:
excel_path = os.path.join(excel_save_folder,
'{}.xlsx'.format(region['bbox']))
to_excel(region['res']['html'], excel_path)
- elif region['type'] == 'Figure':
+ elif region['type'] == 'figure':
img_path = os.path.join(excel_save_folder,
'{}.jpg'.format(region['bbox']))
cv2.imwrite(img_path, roi_img)
@@ -188,7 +198,7 @@ def main(args):
logger.error("error in loading image:{}".format(image_file))
continue
starttime = time.time()
- res = structure_sys(img)
+ res, time_dict = structure_sys(img)
if structure_sys.mode == 'structure':
save_structure_res(res, save_folder, img_name)
@@ -201,7 +211,7 @@ def main(args):
cv2.imwrite(img_save_path, draw_img)
logger.info('result save to {}'.format(img_save_path))
if args.recovery:
- convert_info_docx(img, res, save_folder, img_name)
+ convert_info_docx(img, res, save_folder, img_name)
elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse))
diff --git a/ppstructure/table/eval_table.py b/ppstructure/table/eval_table.py
index 87b44d3d9..435d69322 100755
--- a/ppstructure/table/eval_table.py
+++ b/ppstructure/table/eval_table.py
@@ -13,12 +13,14 @@
# limitations under the License.
import os
import sys
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
import cv2
-import json
+import pickle
+import paddle
from tqdm import tqdm
from ppstructure.table.table_metric import TEDS
from ppstructure.table.predict_table import TableSystem
@@ -33,40 +35,74 @@ def parse_args():
parser.add_argument("--gt_path", type=str)
return parser.parse_args()
-def main(gt_path, img_root, args):
- teds = TEDS(n_jobs=16)
+def load_txt(txt_path):
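+    # each line of the gt file: '<img_name>\t<gt html string>'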
+ pred_html_dict = {}
+ if not os.path.exists(txt_path):
+ return pred_html_dict
+ with open(txt_path, encoding='utf-8') as f:
+ lines = f.readlines()
+ for line in lines:
+ line = line.strip().split('\t')
+ img_name, pred_html = line
+ pred_html_dict[img_name] = pred_html
+ return pred_html_dict
+
+
+def load_result(path):
+ data = {}
+ if os.path.exists(path):
+ data = pickle.load(open(path, 'rb'))
+ return data
+
+
+def save_result(path, data):
+ old_data = load_result(path)
+ old_data.update(data)
+ with open(path, 'wb') as f:
+ pickle.dump(old_data, f)
+
+
+def main(gt_path, img_root, args):
+ os.makedirs(args.output, exist_ok=True)
+ # init TableSystem
text_sys = TableSystem(args)
- jsons_gt = json.load(open(gt_path)) # gt
+ # load gt and preds html result
+ gt_html_dict = load_txt(gt_path)
+
+ ocr_result = load_result(os.path.join(args.output, 'ocr.pickle'))
+ structure_result = load_result(
+ os.path.join(args.output, 'structure.pickle'))
+
pred_htmls = []
gt_htmls = []
- for img_name in tqdm(jsons_gt):
- # read image
- img = cv2.imread(os.path.join(img_root,img_name))
- pred_html = text_sys(img)
+ for img_name, gt_html in tqdm(gt_html_dict.items()):
+ img = cv2.imread(os.path.join(img_root, img_name))
+ # run ocr and save result
+ if img_name not in ocr_result:
+ dt_boxes, rec_res, _, _ = text_sys._ocr(img)
+ ocr_result[img_name] = [dt_boxes, rec_res]
+ save_result(os.path.join(args.output, 'ocr.pickle'), ocr_result)
+ # run structure and save result
+ if img_name not in structure_result:
+ structure_res, _ = text_sys._structure(img)
+ structure_result[img_name] = structure_res
+ save_result(
+ os.path.join(args.output, 'structure.pickle'), structure_result)
+ dt_boxes, rec_res = ocr_result[img_name]
+ structure_res = structure_result[img_name]
+ # match ocr and structure
+ pred_html = text_sys.match(structure_res, dt_boxes, rec_res)
+
pred_htmls.append(pred_html)
-
- gt_structures, gt_bboxes, gt_contents = jsons_gt[img_name]
- gt_html, gt = get_gt_html(gt_structures, gt_contents)
gt_htmls.append(gt_html)
+
+ # compute teds
+ teds = TEDS(n_jobs=16)
scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)
- logger.info('teds:', sum(scores) / len(scores))
-
-
-def get_gt_html(gt_structures, gt_contents):
- end_html = []
- td_index = 0
- for tag in gt_structures:
-        if '</td>' in tag:
- if gt_contents[td_index] != []:
- end_html.extend(gt_contents[td_index])
- end_html.append(tag)
- td_index += 1
- else:
- end_html.append(tag)
- return ''.join(end_html), end_html
+ logger.info('teds: {}'.format(sum(scores) / len(scores)))
if __name__ == '__main__':
args = parse_args()
- main(args.gt_path,args.image_dir, args)
+ main(args.gt_path, args.image_dir, args)
diff --git a/ppstructure/table/matcher.py b/ppstructure/table/matcher.py
index c3b563844..d75e9abb3 100755
--- a/ppstructure/table/matcher.py
+++ b/ppstructure/table/matcher.py
@@ -1,11 +1,15 @@
import json
+import numpy as np
+
+from ppstructure.table.table_master_match import deal_eb_token, deal_bb
+
+
def distance(box_1, box_2):
- x1, y1, x2, y2 = box_1
- x3, y3, x4, y4 = box_2
- dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2)
- dis_2 = abs(x3 - x1) + abs(y3 - y1)
- dis_3 = abs(x4- x2) + abs(y4 - y2)
- return dis + min(dis_2, dis_3)
+ x1, y1, x2, y2 = box_1
+ x3, y3, x4, y4 = box_2
+ dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
+ dis_2 = abs(x3 - x1) + abs(y3 - y1)
+ dis_3 = abs(x4 - x2) + abs(y4 - y2)
+ return dis + min(dis_2, dis_3)
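+# e.g. distance([0, 0, 10, 10], [1, 1, 12, 12]) -> 6 + min(2, 4) = 8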
+
def compute_iou(rec1, rec2):
"""
@@ -18,23 +22,22 @@ def compute_iou(rec1, rec2):
# computing area of each rectangles
S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
-
+
# computing the sum_area
sum_area = S_rec1 + S_rec2
-
+
# find the each edge of intersect rectangle
left_line = max(rec1[1], rec2[1])
right_line = min(rec1[3], rec2[3])
top_line = max(rec1[0], rec2[0])
bottom_line = min(rec1[2], rec2[2])
-
+
# judge if there is an intersect
if left_line >= right_line or top_line >= bottom_line:
return 0.0
else:
intersect = (right_line - left_line) * (bottom_line - top_line)
- return (intersect / (sum_area - intersect))*1.0
-
+ return (intersect / (sum_area - intersect)) * 1.0
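+# e.g. compute_iou([0, 0, 10, 10], [0, 0, 5, 10]) -> 50 / (100 + 50 - 50) = 0.5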
def matcher_merge(ocr_bboxes, pred_bboxes):
@@ -45,15 +48,18 @@ def matcher_merge(ocr_bboxes, pred_bboxes):
distances = []
for j, pred_box in enumerate(pred_bboxes):
# compute l1 distence and IOU between two boxes
- distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box)))
+ distances.append((distance(gt_box, pred_box),
+ 1. - compute_iou(gt_box, pred_box)))
sorted_distances = distances.copy()
# select nearest cell
- sorted_distances = sorted(sorted_distances, key = lambda item: (item[1], item[0]))
- if distances.index(sorted_distances[0]) not in matched.keys():
+ sorted_distances = sorted(
+ sorted_distances, key=lambda item: (item[1], item[0]))
+ if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i]
else:
matched[distances.index(sorted_distances[0])].append(i)
- return matched#, sum(ious) / len(ious)
+ return matched #, sum(ious) / len(ious)
+
def complex_num(pred_bboxes):
complex_nums = []
@@ -67,6 +73,7 @@ def complex_num(pred_bboxes):
complex_nums.append(temp_ious[distances.index(min(distances))])
return sum(complex_nums) / len(complex_nums)
+
def get_rows(pred_bboxes):
pre_bbox = pred_bboxes[0]
res = []
@@ -81,7 +88,9 @@ def get_rows(pred_bboxes):
for i in range(step):
pred_bboxes.pop(0)
return res, pred_bboxes
-def refine_rows(pred_bboxes): # 微调整行的框,使在一条水平线上
+
+
+def refine_rows(pred_bboxes):  # fine-tune one row's boxes so they lie on the same horizontal line
ys_1 = []
ys_2 = []
for box in pred_bboxes:
@@ -95,12 +104,14 @@ def refine_rows(pred_bboxes): # 微调整行的框,使在一条水平线上
box[3] = min_y_2
re_boxes.append(box)
return re_boxes
-
+
+
def matcher_refine_row(gt_bboxes, pred_bboxes):
before_refine_pred_bboxes = pred_bboxes.copy()
pred_bboxes = []
- while(len(before_refine_pred_bboxes) != 0):
- row_bboxes, before_refine_pred_bboxes = get_rows(before_refine_pred_bboxes)
+ while (len(before_refine_pred_bboxes) != 0):
+ row_bboxes, before_refine_pred_bboxes = get_rows(
+ before_refine_pred_bboxes)
print(row_bboxes)
pred_bboxes.extend(refine_rows(row_bboxes))
all_dis = []
@@ -114,12 +125,11 @@ def matcher_refine_row(gt_bboxes, pred_bboxes):
#temp_ious.append(compute_iou(gt_box, pred_box))
#all_dis.append(min(distances))
#ious.append(temp_ious[distances.index(min(distances))])
- if distances.index(min(distances)) not in matched.keys():
+ if distances.index(min(distances)) not in matched.keys():
matched[distances.index(min(distances))] = [i]
else:
matched[distances.index(min(distances))].append(i)
- return matched#, sum(ious) / len(ious)
-
+ return matched #, sum(ious) / len(ious)
# pick out one row first, then do the matching
@@ -128,29 +138,30 @@ def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes):
delete_gt_bboxes = gt_bboxes.copy()
match_bboxes_ready = []
matched = {}
- while(len(delete_gt_bboxes) != 0):
+ while (len(delete_gt_bboxes) != 0):
row_bboxes, delete_gt_bboxes = get_rows(delete_gt_bboxes)
- row_bboxes = sorted(row_bboxes, key = lambda key: key[0])
+ row_bboxes = sorted(row_bboxes, key=lambda key: key[0])
if len(pred_bboxes_rows) > 0:
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
print(row_bboxes)
for i, gt_box in enumerate(row_bboxes):
#print(gt_box)
pred_distances = []
- distances = []
+ distances = []
for pred_bbox in pred_bboxes:
pred_distances.append(distance(gt_box, pred_bbox))
for j, pred_box in enumerate(match_bboxes_ready):
distances.append(distance(gt_box, pred_box))
index = pred_distances.index(min(distances))
#print('index', index)
- if index not in matched.keys():
+ if index not in matched.keys():
matched[index] = [gt_box_index]
else:
matched[index].append(gt_box_index)
gt_box_index += 1
return matched
+
def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
'''
    gt_bboxes: sorted
@@ -161,7 +172,7 @@ def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
match_bboxes_ready = []
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
for i, gt_box in enumerate(gt_bboxes):
-
+
pred_distances = []
for pred_bbox in pred_bboxes:
pred_distances.append(distance(gt_box, pred_bbox))
@@ -184,9 +195,143 @@ def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
#print(gt_box, index)
#match_bboxes_ready.pop(distances.index(min(distances)))
print(gt_box, match_bboxes_ready[distances.index(min(distances))])
- if index not in matched.keys():
+ if index not in matched.keys():
matched[index] = [i]
else:
matched[index].append(i)
pre_bbox = gt_box
return matched
+
+
+class TableMatch:
+ def __init__(self, filter_ocr_result=False, use_master=False):
+ self.filter_ocr_result = filter_ocr_result
+ self.use_master = use_master
+
+ def __call__(self, structure_res, dt_boxes, rec_res):
+ pred_structures, pred_bboxes = structure_res
+        if self.filter_ocr_result:
+            # call the renamed helper; the boolean attribute set in __init__
+            # would otherwise shadow a method of the same name
+            dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes,
+                                                        rec_res)
+ matched_index = self.match_result(dt_boxes, pred_bboxes)
+ if self.use_master:
+ pred_html, pred = self.get_pred_html_master(pred_structures,
+ matched_index, rec_res)
+ else:
+ pred_html, pred = self.get_pred_html(pred_structures, matched_index,
+ rec_res)
+ return pred_html
+
+ def match_result(self, dt_boxes, pred_bboxes):
+ matched = {}
+ for i, gt_box in enumerate(dt_boxes):
+ # gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
+ distances = []
+ for j, pred_box in enumerate(pred_bboxes):
+ distances.append((distance(gt_box, pred_box),
+ 1. - compute_iou(gt_box, pred_box)
+                                  ))  # L1 distance and 1 - IoU between each pair of cells
+ sorted_distances = distances.copy()
+            # pick the "closest" cell by IoU first, then by distance
+ sorted_distances = sorted(
+ sorted_distances, key=lambda item: (item[1], item[0]))
+ if distances.index(sorted_distances[0]) not in matched.keys():
+ matched[distances.index(sorted_distances[0])] = [i]
+ else:
+ matched[distances.index(sorted_distances[0])].append(i)
+ return matched
+
+ def get_pred_html(self, pred_structures, matched_index, ocr_contents):
+ end_html = []
+ td_index = 0
+ for tag in pred_structures:
+            if '</td>' in tag:
+                if '<td></td>' == tag:
+                    end_html.extend('<td>')
+                if td_index in matched_index.keys():
+                    b_with = False
+                    if '<b>' in ocr_contents[matched_index[td_index][
+                            0]] and len(matched_index[td_index]) > 1:
+                        b_with = True
+                        end_html.extend('<b>')
+                    for i, td_index_index in enumerate(matched_index[td_index]):
+                        content = ocr_contents[td_index_index][0]
+                        if len(matched_index[td_index]) > 1:
+                            if len(content) == 0:
+                                continue
+                            if content[0] == ' ':
+                                content = content[1:]
+                            if '<b>' in content:
+                                content = content[3:]
+                            if '</b>' in content:
+                                content = content[:-4]
+                            if len(content) == 0:
+                                continue
+                            if i != len(matched_index[
+                                    td_index]) - 1 and ' ' != content[-1]:
+                                content += ' '
+                        end_html.extend(content)
+                    if b_with:
+                        end_html.extend('</b>')
+                if '<td></td>' == tag:
+                    end_html.append('</td>')
+ else:
+ end_html.append(tag)
+ td_index += 1
+ else:
+ end_html.append(tag)
+ return ''.join(end_html), end_html
+
+ def get_pred_html_master(self, pred_structures, matched_index,
+ ocr_contents):
+ end_html = []
+ td_index = 0
+ for token in pred_structures:
+            if '</td>' in token:
+                txt = ''
+                b_with = False
+                if td_index in matched_index.keys():
+                    if '<b>' in ocr_contents[matched_index[td_index][
+                            0]] and len(matched_index[td_index]) > 1:
+                        b_with = True
+                    for i, td_index_index in enumerate(matched_index[td_index]):
+                        content = ocr_contents[td_index_index][0]
+                        if len(matched_index[td_index]) > 1:
+                            if len(content) == 0:
+                                continue
+                            if content[0] == ' ':
+                                content = content[1:]
+                            if '<b>' in content:
+                                content = content[3:]
+                            if '</b>' in content:
+                                content = content[:-4]
+                            if len(content) == 0:
+                                continue
+                            if i != len(matched_index[
+                                    td_index]) - 1 and ' ' != content[-1]:
+                                content += ' '
+                        txt += content
+                if b_with:
+                    txt = '<b>{}</b>'.format(txt)
+                if '<td></td>' == token:
+                    token = '<td>{}</td>'.format(txt)
+                else:
+                    token = '{}</td>'.format(txt)
+ td_index += 1
+ token = deal_eb_token(token)
+ end_html.append(token)
+ html = ''.join(end_html)
+ html = deal_bb(html)
+ return html, end_html
+
+    def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
+ y1 = pred_bboxes[:, 1::2].min()
+ new_dt_boxes = []
+ new_rec_res = []
+
+ for box, rec in zip(dt_boxes, rec_res):
+ if np.max(box[1::2]) < y1:
+ continue
+ new_dt_boxes.append(box)
+ new_rec_res.append(rec)
+ return new_dt_boxes, new_rec_res
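+
+
+# Usage sketch (names are illustrative): `structure_res` is the
+# (pred_structures, pred_bboxes) pair from the table structure model, while
+# `dt_boxes` / `rec_res` come from text detection / recognition:
+#
+#     matcher = TableMatch(filter_ocr_result=False)
+#     pred_html = matcher(structure_res, dt_boxes, rec_res)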
diff --git a/ppstructure/table/predict_structure.py b/ppstructure/table/predict_structure.py
index 7a7d3169d..01d467594 100755
--- a/ppstructure/table/predict_structure.py
+++ b/ppstructure/table/predict_structure.py
@@ -16,7 +16,7 @@ import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
@@ -87,6 +87,7 @@ class TableStructurer(object):
utility.create_predictor(args, 'table', logger)
def __call__(self, img):
+ starttime = time.time()
ori_im = img.copy()
data = {'image': img}
data = transform(data, self.preprocess_op)
@@ -95,7 +96,6 @@ class TableStructurer(object):
return None, 0
img = np.expand_dims(img, axis=0)
img = img.copy()
- starttime = time.time()
self.input_tensor.copy_from_cpu(img)
self.predictor.run()
@@ -126,7 +126,6 @@ def main(args):
table_structurer = TableStructurer(args)
count = 0
total_time = 0
- use_xywh = args.table_algorithm in ['TableMaster']
os.makedirs(args.output, exist_ok=True)
with open(
os.path.join(args.output, 'infer.txt'), mode='w',
@@ -146,7 +145,7 @@ def main(args):
f_w.write("result: {}, {}\n".format(structure_str_list,
bbox_list_str))
- img = draw_rectangle(image_file, bbox_list, use_xywh)
+ img = draw_rectangle(image_file, bbox_list)
img_save_path = os.path.join(args.output,
os.path.basename(image_file))
cv2.imwrite(img_save_path, img)
diff --git a/ppstructure/table/predict_table.py b/ppstructure/table/predict_table.py
index becc6daef..6e7051235 100644
--- a/ppstructure/table/predict_table.py
+++ b/ppstructure/table/predict_table.py
@@ -18,20 +18,23 @@ import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
-sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import copy
+import logging
import numpy as np
import time
import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det
import tools.infer.utility as utility
+from tools.infer.predict_system import sorted_boxes
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
-from ppstructure.table.matcher import distance, compute_iou
+from ppstructure.table.matcher import TableMatch
+from ppstructure.table.table_master_match import TableMasterMatcher
from ppstructure.utility import parse_args
import ppstructure.table.predict_structure as predict_strture
@@ -55,11 +58,20 @@ def expand(pix, det_box, shape):
class TableSystem(object):
def __init__(self, args, text_detector=None, text_recognizer=None):
+ if not args.show_log:
+ logger.setLevel(logging.INFO)
+
self.text_detector = predict_det.TextDetector(
args) if text_detector is None else text_detector
self.text_recognizer = predict_rec.TextRecognizer(
args) if text_recognizer is None else text_recognizer
+
self.table_structurer = predict_strture.TableStructurer(args)
+ if args.table_algorithm in ['TableMaster']:
+ self.match = TableMasterMatcher()
+ else:
+ self.match = TableMatch()
+
self.benchmark = args.benchmark
self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
args, 'table', logger)
@@ -85,16 +97,47 @@ class TableSystem(object):
def __call__(self, img, return_ocr_result_in_table=False):
result = dict()
- ori_im = img.copy()
+ time_dict = {'det': 0, 'rec': 0, 'table': 0, 'all': 0, 'match': 0}
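+        # per-stage elapsed seconds; 'all' is filled just before returning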
+ start = time.time()
+
+ structure_res, elapse = self._structure(copy.deepcopy(img))
+ time_dict['table'] = elapse
+
+ dt_boxes, rec_res, det_elapse, rec_elapse = self._ocr(
+ copy.deepcopy(img))
+ time_dict['det'] = det_elapse
+ time_dict['rec'] = rec_elapse
+
+ if return_ocr_result_in_table:
+            result['boxes'] = dt_boxes
+ result['rec_res'] = rec_res
+
+ tic = time.time()
+ pred_html = self.match(structure_res, dt_boxes, rec_res)
+ toc = time.time()
+ time_dict['match'] = toc - tic
+ result['html'] = pred_html
+        if self.benchmark:
+            self.autolog.times.end(stamp=True)
+        end = time.time()
+        time_dict['all'] = end - start
+ return result, time_dict
+
+ def _structure(self, img):
if self.benchmark:
self.autolog.times.start()
structure_res, elapse = self.table_structurer(copy.deepcopy(img))
+ return structure_res, elapse
+
+ def _ocr(self, img):
if self.benchmark:
self.autolog.times.stamp()
- dt_boxes, elapse = self.text_detector(copy.deepcopy(img))
+ dt_boxes, det_elapse = self.text_detector(copy.deepcopy(img))
dt_boxes = sorted_boxes(dt_boxes)
- if return_ocr_result_in_table:
- result['boxes'] = [x.tolist() for x in dt_boxes]
+
r_boxes = []
for box in dt_boxes:
x_min = box[:, 0].min() - 1
@@ -105,125 +148,20 @@ class TableSystem(object):
r_boxes.append(box)
dt_boxes = np.array(r_boxes)
logger.debug("dt_boxes num : {}, elapse : {}".format(
- len(dt_boxes), elapse))
+ len(dt_boxes), det_elapse))
if dt_boxes is None:
            return None, None, det_elapse, 0
+
img_crop_list = []
for i in range(len(dt_boxes)):
det_box = dt_boxes[i]
- x0, y0, x1, y1 = expand(2, det_box, ori_im.shape)
- text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :]
+ x0, y0, x1, y1 = expand(2, det_box, img.shape)
+ text_rect = img[int(y0):int(y1), int(x0):int(x1), :]
img_crop_list.append(text_rect)
- rec_res, elapse = self.text_recognizer(img_crop_list)
+ rec_res, rec_elapse = self.text_recognizer(img_crop_list)
logger.debug("rec_res num : {}, elapse : {}".format(
- len(rec_res), elapse))
- if self.benchmark:
- self.autolog.times.stamp()
- if return_ocr_result_in_table:
- result['rec_res'] = rec_res
- pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res)
- result['html'] = pred_html
- if self.benchmark:
- self.autolog.times.end(stamp=True)
- return result
-
- def rebuild_table(self, structure_res, dt_boxes, rec_res):
- pred_structures, pred_bboxes = structure_res
- dt_boxes, rec_res = self.filter_ocr_result(pred_bboxes,dt_boxes, rec_res)
- matched_index = self.match_result(dt_boxes, pred_bboxes)
- pred_html, pred = self.get_pred_html(pred_structures, matched_index,
- rec_res)
- return pred_html, pred
-
- def filter_ocr_result(self, pred_bboxes,dt_boxes, rec_res):
- y1 = pred_bboxes[:,1::2].min()
- new_dt_boxes = []
- new_rec_res = []
-
- for box,rec in zip(dt_boxes, rec_res):
- if np.max(box[1::2]) < y1:
- continue
- new_dt_boxes.append(box)
- new_rec_res.append(rec)
- return new_dt_boxes, new_rec_res
-
-
- def match_result(self, dt_boxes, pred_bboxes):
- matched = {}
- for i, gt_box in enumerate(dt_boxes):
- # gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
- distances = []
- for j, pred_box in enumerate(pred_bboxes):
- distances.append((distance(gt_box, pred_box),
- 1. - compute_iou(gt_box, pred_box)
- )) # 获取两两cell之间的L1距离和 1- IOU
- sorted_distances = distances.copy()
- # 根据距离和IOU挑选最"近"的cell
- sorted_distances = sorted(
- sorted_distances, key=lambda item: (item[1], item[0]))
- if distances.index(sorted_distances[0]) not in matched.keys():
- matched[distances.index(sorted_distances[0])] = [i]
- else:
- matched[distances.index(sorted_distances[0])].append(i)
- return matched
-
- def get_pred_html(self, pred_structures, matched_index, ocr_contents):
- end_html = []
- td_index = 0
- for tag in pred_structures:
-            if '</td>' in tag:
-                if td_index in matched_index.keys():
-                    b_with = False
-                    if '<b>' in ocr_contents[matched_index[td_index][
-                            0]] and len(matched_index[td_index]) > 1:
-                        b_with = True
-                        end_html.extend('<b>')
-                    for i, td_index_index in enumerate(matched_index[td_index]):
-                        content = ocr_contents[td_index_index][0]
-                        if len(matched_index[td_index]) > 1:
-                            if len(content) == 0:
-                                continue
-                            if content[0] == ' ':
-                                content = content[1:]
-                            if '<b>' in content:
-                                content = content[3:]
-                            if '</b>' in content:
-                                content = content[:-4]
-                            if len(content) == 0:
-                                continue
-                            if i != len(matched_index[
-                                    td_index]) - 1 and ' ' != content[-1]:
-                                content += ' '
-                        end_html.extend(content)
-                    if b_with:
-                        end_html.extend('</b>')
-
- end_html.append(tag)
- td_index += 1
- else:
- end_html.append(tag)
- return ''.join(end_html), end_html
-
-
-def sorted_boxes(dt_boxes):
- """
- Sort text boxes in order from top to bottom, left to right
- args:
- dt_boxes(array):detected text boxes with shape [4, 2]
- return:
- sorted boxes(array) with shape [4, 2]
- """
- num_boxes = dt_boxes.shape[0]
- sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
- _boxes = list(sorted_boxes)
-
- for i in range(num_boxes - 1):
- if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
- (_boxes[i + 1][0][0] < _boxes[i][0][0]):
- tmp = _boxes[i]
- _boxes[i] = _boxes[i + 1]
- _boxes[i + 1] = tmp
- return _boxes
+ len(rec_res), rec_elapse))
+ return dt_boxes, rec_res, det_elapse, rec_elapse
def to_excel(html_table, excel_path):
@@ -249,7 +187,7 @@ def main(args):
logger.error("error in loading image:{}".format(image_file))
continue
starttime = time.time()
- pred_res = text_sys(img)
+ pred_res, _ = text_sys(img)
pred_html = pred_res['html']
logger.info(pred_html)
to_excel(pred_html, excel_path)
diff --git a/ppstructure/table/table_master_match.py b/ppstructure/table/table_master_match.py
new file mode 100644
index 000000000..069d576bf
--- /dev/null
+++ b/ppstructure/table/table_master_match.py
@@ -0,0 +1,1009 @@
+import os
+import re
+import cv2
+import glob
+import copy
+import math
+import pickle
+import numpy as np
+
+from shapely.geometry import Polygon, MultiPoint
+"""
+Useful functions for matching.
+"""
+
+
+def remove_empty_bboxes(bboxes):
+ """
+ remove [0., 0., 0., 0.] in structure master bboxes.
+ len(bboxes.shape) must be 2.
+ :param bboxes:
+ :return:
+ """
+ new_bboxes = []
+ for bbox in bboxes:
+ if sum(bbox) == 0.:
+ continue
+ new_bboxes.append(bbox)
+ return np.array(new_bboxes)
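+
+# e.g. remove_empty_bboxes(np.array([[0., 0., 0., 0.], [1., 2., 3., 4.]]))
+#      -> array([[1., 2., 3., 4.]])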
+
+
+def xywh2xyxy(bboxes):
+ if len(bboxes.shape) == 1:
+ new_bboxes = np.empty_like(bboxes)
+ new_bboxes[0] = bboxes[0] - bboxes[2] / 2
+ new_bboxes[1] = bboxes[1] - bboxes[3] / 2
+ new_bboxes[2] = bboxes[0] + bboxes[2] / 2
+ new_bboxes[3] = bboxes[1] + bboxes[3] / 2
+ return new_bboxes
+ elif len(bboxes.shape) == 2:
+ new_bboxes = np.empty_like(bboxes)
+ new_bboxes[:, 0] = bboxes[:, 0] - bboxes[:, 2] / 2
+ new_bboxes[:, 1] = bboxes[:, 1] - bboxes[:, 3] / 2
+ new_bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2] / 2
+ new_bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3] / 2
+ return new_bboxes
+ else:
+ raise ValueError
+
+
+def xyxy2xywh(bboxes):
+ if len(bboxes.shape) == 1:
+ new_bboxes = np.empty_like(bboxes)
+ new_bboxes[0] = bboxes[0] + (bboxes[2] - bboxes[0]) / 2
+ new_bboxes[1] = bboxes[1] + (bboxes[3] - bboxes[1]) / 2
+ new_bboxes[2] = bboxes[2] - bboxes[0]
+ new_bboxes[3] = bboxes[3] - bboxes[1]
+ return new_bboxes
+ elif len(bboxes.shape) == 2:
+ new_bboxes = np.empty_like(bboxes)
+ new_bboxes[:, 0] = bboxes[:, 0] + (bboxes[:, 2] - bboxes[:, 0]) / 2
+ new_bboxes[:, 1] = bboxes[:, 1] + (bboxes[:, 3] - bboxes[:, 1]) / 2
+ new_bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
+ new_bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
+ return new_bboxes
+ else:
+ raise ValueError
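+
+# e.g. xywh2xyxy(np.array([50., 50., 20., 10.])) -> array([40., 45., 60., 55.]);
+# xyxy2xywh inverts it: array([40., 45., 60., 55.]) -> array([50., 50., 20., 10.])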
+
+
+def pickle_load(path, prefix='end2end'):
+ if os.path.isfile(path):
+ data = pickle.load(open(path, 'rb'))
+ elif os.path.isdir(path):
+ data = dict()
+ search_path = os.path.join(path, '{}_*.pkl'.format(prefix))
+ pkls = glob.glob(search_path)
+ for pkl in pkls:
+ this_data = pickle.load(open(pkl, 'rb'))
+ data.update(this_data)
+ else:
+ raise ValueError
+ return data
+
+
+def convert_coord(xyxy):
+ """
+ Convert two points format to four points format.
+ :param xyxy:
+ :return:
+ """
+ new_bbox = np.zeros([4, 2], dtype=np.float32)
+ new_bbox[0, 0], new_bbox[0, 1] = xyxy[0], xyxy[1]
+ new_bbox[1, 0], new_bbox[1, 1] = xyxy[2], xyxy[1]
+ new_bbox[2, 0], new_bbox[2, 1] = xyxy[2], xyxy[3]
+ new_bbox[3, 0], new_bbox[3, 1] = xyxy[0], xyxy[3]
+ return new_bbox
+
+
+def cal_iou(bbox1, bbox2):
+ bbox1_poly = Polygon(bbox1).convex_hull
+ bbox2_poly = Polygon(bbox2).convex_hull
+ union_poly = np.concatenate((bbox1, bbox2))
+
+ if not bbox1_poly.intersects(bbox2_poly):
+ iou = 0
+ else:
+ inter_area = bbox1_poly.intersection(bbox2_poly).area
+ union_area = MultiPoint(union_poly).convex_hull.area
+ if union_area == 0:
+ iou = 0
+ else:
+ iou = float(inter_area) / union_area
+ return iou
+
+
+def cal_distance(p1, p2):
+ delta_x = p1[0] - p2[0]
+ delta_y = p1[1] - p2[1]
+ d = math.sqrt((delta_x**2) + (delta_y**2))
+ return d
+
+
+def is_inside(center_point, corner_point):
+ """
+ Find if center_point inside the bbox(corner_point) or not.
+ :param center_point: center point (x, y)
+ :param corner_point: corner point ((x1,y1),(x2,y2))
+ :return:
+ """
+ x_flag = False
+ y_flag = False
+ if (center_point[0] >= corner_point[0][0]) and (
+ center_point[0] <= corner_point[1][0]):
+ x_flag = True
+ if (center_point[1] >= corner_point[0][1]) and (
+ center_point[1] <= corner_point[1][1]):
+ y_flag = True
+ if x_flag and y_flag:
+ return True
+ else:
+ return False
+
+
+def find_no_match(match_list, all_end2end_nums, type='end2end'):
+ """
+ Find out no match end2end bbox in previous match list.
+ :param match_list: matching pairs.
+ :param all_end2end_nums: numbers of end2end_xywh
+ :param type: 'end2end' corresponding to idx 0, 'master' corresponding to idx 1.
+ :return: no match pse bbox index list
+ """
+ if type == 'end2end':
+ idx = 0
+ elif type == 'master':
+ idx = 1
+ else:
+ raise ValueError
+
+ no_match_indexs = []
+ # m[0] is end2end index m[1] is master index
+ matched_bbox_indexs = [m[idx] for m in match_list]
+ for n in range(all_end2end_nums):
+ if n not in matched_bbox_indexs:
+ no_match_indexs.append(n)
+ return no_match_indexs
+
+
+def is_abs_lower_than_threshold(this_bbox, target_bbox, threshold=3):
+ # only consider y axis, for grouping in row.
+ delta = abs(this_bbox[1] - target_bbox[1])
+ if delta < threshold:
+ return True
+ else:
+ return False
+
+
+def sort_line_bbox(g, bg):
+ """
+    Sort the bboxes in the same line (group):
+    compare the 'x' coords, since 'y' coords are close within one group.
+ :param g: index in the same group
+ :param bg: bbox in the same group
+ :return:
+ """
+
+ xs = [bg_item[0] for bg_item in bg]
+ xs_sorted = sorted(xs)
+
+ g_sorted = [None] * len(xs_sorted)
+ bg_sorted = [None] * len(xs_sorted)
+ for g_item, bg_item in zip(g, bg):
+ idx = xs_sorted.index(bg_item[0])
+ bg_sorted[idx] = bg_item
+ g_sorted[idx] = g_item
+
+ return g_sorted, bg_sorted
+
+
+def flatten(sorted_groups, sorted_bbox_groups):
+ idxs = []
+ bboxes = []
+ for group, bbox_group in zip(sorted_groups, sorted_bbox_groups):
+ for g, bg in zip(group, bbox_group):
+ idxs.append(g)
+ bboxes.append(bg)
+ return idxs, bboxes
+
+
+def sort_bbox(end2end_xywh_bboxes, no_match_end2end_indexes):
+ """
+    This function groups the remaining end2end bboxes by row.
+ :param end2end_xywh_bboxes:
+ :param no_match_end2end_indexes:
+ :return:
+ """
+ groups = []
+ bbox_groups = []
+ for index, end2end_xywh_bbox in zip(no_match_end2end_indexes,
+ end2end_xywh_bboxes):
+ this_bbox = end2end_xywh_bbox
+ if len(groups) == 0:
+ groups.append([index])
+ bbox_groups.append([this_bbox])
+ else:
+ flag = False
+ for g, bg in zip(groups, bbox_groups):
+            # whether this_bbox belongs to bg's row or not
+ if is_abs_lower_than_threshold(this_bbox, bg[0]):
+ g.append(index)
+ bg.append(this_bbox)
+ flag = True
+ break
+ if not flag:
+                # this_bbox does not belong to bg's row, create a new row.
+ groups.append([index])
+ bbox_groups.append([this_bbox])
+
+ # sorted bboxes in a group
+ tmp_groups, tmp_bbox_groups = [], []
+ for g, bg in zip(groups, bbox_groups):
+ g_sorted, bg_sorted = sort_line_bbox(g, bg)
+ tmp_groups.append(g_sorted)
+ tmp_bbox_groups.append(bg_sorted)
+
+ # sorted groups, sort by coord y's value.
+ sorted_groups = [None] * len(tmp_groups)
+ sorted_bbox_groups = [None] * len(tmp_bbox_groups)
+ ys = [bg[0][1] for bg in tmp_bbox_groups]
+ sorted_ys = sorted(ys)
+ for g, bg in zip(tmp_groups, tmp_bbox_groups):
+ idx = sorted_ys.index(bg[0][1])
+ sorted_groups[idx] = g
+ sorted_bbox_groups[idx] = bg
+
+ # flatten, get final result
+ end2end_sorted_idx_list, end2end_sorted_bbox_list \
+ = flatten(sorted_groups, sorted_bbox_groups)
+
+
+ return end2end_sorted_idx_list, end2end_sorted_bbox_list, sorted_groups, sorted_bbox_groups
+
+
+def get_bboxes_list(end2end_result, structure_master_result):
+ """
+    This function is used to convert end2end results and structure master results
+    to a List of xyxy bbox format and a List of xywh bbox format
+    :param end2end_result: bbox's format is xyxy
+    :param structure_master_result: bbox's format is xywh
+    :return: 4 kinds of bbox list
+ """
+ # end2end
+ end2end_xyxy_list = []
+ end2end_xywh_list = []
+ for end2end_item in end2end_result:
+ src_bbox = end2end_item['bbox']
+ end2end_xyxy_list.append(src_bbox)
+ xywh_bbox = xyxy2xywh(src_bbox)
+ end2end_xywh_list.append(xywh_bbox)
+ end2end_xyxy_bboxes = np.array(end2end_xyxy_list)
+ end2end_xywh_bboxes = np.array(end2end_xywh_list)
+
+ # structure master
+ src_bboxes = structure_master_result['bbox']
+ src_bboxes = remove_empty_bboxes(src_bboxes)
+ # structure_master_xywh_bboxes = src_bboxes
+ # xyxy_bboxes = xywh2xyxy(src_bboxes)
+ # structure_master_xyxy_bboxes = xyxy_bboxes
+ structure_master_xyxy_bboxes = src_bboxes
+ xywh_bbox = xyxy2xywh(src_bboxes)
+ structure_master_xywh_bboxes = xywh_bbox
+
+ return end2end_xyxy_bboxes, end2end_xywh_bboxes, structure_master_xywh_bboxes, structure_master_xyxy_bboxes
+
+
+def center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes):
+ """
+    Judge whether the end2end bbox's center point is inside the structure master bbox;
+    if so, record a matching pair.
+ :param end2end_xywh_bboxes:
+ :param structure_master_xyxy_bboxes:
+ :return: match pairs list, e.g. [[0,1], [1,2], ...]
+ """
+ match_pairs_list = []
+ for i, end2end_xywh in enumerate(end2end_xywh_bboxes):
+ for j, master_xyxy in enumerate(structure_master_xyxy_bboxes):
+ x_end2end, y_end2end = end2end_xywh[0], end2end_xywh[1]
+ x_master1, y_master1, x_master2, y_master2 \
+ = master_xyxy[0], master_xyxy[1], master_xyxy[2], master_xyxy[3]
+ center_point_end2end = (x_end2end, y_end2end)
+ corner_point_master = ((x_master1, y_master1),
+ (x_master2, y_master2))
+ if is_inside(center_point_end2end, corner_point_master):
+ match_pairs_list.append([i, j])
+ return match_pairs_list
+
+
+def iou_rule_match(end2end_xyxy_bboxes, end2end_xyxy_indexes,
+ structure_master_xyxy_bboxes):
+ """
+ Use iou to find matching list.
+ choose max iou value bbox as match pair.
+ :param end2end_xyxy_bboxes:
+ :param end2end_xyxy_indexes: original end2end indexes.
+ :param structure_master_xyxy_bboxes:
+ :return: match pairs list, e.g. [[0,1], [1,2], ...]
+ """
+ match_pair_list = []
+ for end2end_xyxy_index, end2end_xyxy in zip(end2end_xyxy_indexes,
+ end2end_xyxy_bboxes):
+ max_iou = 0
+ max_match = [None, None]
+ for j, master_xyxy in enumerate(structure_master_xyxy_bboxes):
+ end2end_4xy = convert_coord(end2end_xyxy)
+ master_4xy = convert_coord(master_xyxy)
+ iou = cal_iou(end2end_4xy, master_4xy)
+ if iou > max_iou:
+ max_match[0], max_match[1] = end2end_xyxy_index, j
+ max_iou = iou
+
+ if max_match[0] is None:
+ # no match
+ continue
+ match_pair_list.append(max_match)
+ return match_pair_list
+
+
+def distance_rule_match(end2end_indexes, end2end_bboxes, master_indexes,
+ master_bboxes):
+ """
+ Get matching between no-match end2end bboxes and no-match master bboxes.
+ Use min distance to match.
+    This rule only runs when (no-match end2end nums > 0) and (no-match master nums > 0).
+    It will return master_bboxes_nums match-pairs.
+ :param end2end_indexes:
+ :param end2end_bboxes:
+ :param master_indexes:
+ :param master_bboxes:
+ :return: match_pairs list, e.g. [[0,1], [1,2], ...]
+ """
+ min_match_list = []
+ for j, master_bbox in zip(master_indexes, master_bboxes):
+ min_distance = np.inf
+ min_match = [0, 0] # i, j
+ for i, end2end_bbox in zip(end2end_indexes, end2end_bboxes):
+ x_end2end, y_end2end = end2end_bbox[0], end2end_bbox[1]
+ x_master, y_master = master_bbox[0], master_bbox[1]
+ end2end_point = (x_end2end, y_end2end)
+ master_point = (x_master, y_master)
+ dist = cal_distance(master_point, end2end_point)
+ if dist < min_distance:
+ min_match[0], min_match[1] = i, j
+ min_distance = dist
+ min_match_list.append(min_match)
+ return min_match_list
+
+
+def extra_match(no_match_end2end_indexes, master_bbox_nums):
+ """
+ This function will create some virtual master bboxes,
+ and get match with the no match end2end indexes.
+ :param no_match_end2end_indexes:
+ :param master_bbox_nums:
+ :return:
+ """
+ end_nums = len(no_match_end2end_indexes) + master_bbox_nums
+ extra_match_list = []
+ for i in range(master_bbox_nums, end_nums):
+ end2end_index = no_match_end2end_indexes[i - master_bbox_nums]
+ extra_match_list.append([end2end_index, i])
+ return extra_match_list
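+
+# e.g. extra_match([5, 9], 3) -> [[5, 3], [9, 4]]
+# (the virtual master bboxes get indexes 3 and 4)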
+
+
+def match_visual(file_name,
+ match_list,
+ end2end_xyxy,
+ master_xyxy,
+ prex='ordinary_match'):
+ """
+ Show the match result by xyxy coord style.
+ :param file_name:
+ :param match_list:
+ :param end2end_xyxy:
+ :param master_xyxy:
+ :param prex:
+ :return:
+ """
+ folder = ''
+ save_folder = '/data_0/cache'
+ file_path = os.path.join(folder, file_name)
+ img_end2end = cv2.imread(file_path)
+ img_master = copy.deepcopy(img_end2end)
+ text_color = (0, 0, 255)
+ bbox_color = (255, 0, 0)
+ master_nums = len(master_xyxy)
+
+ for idx, match_group in enumerate(match_list):
+ end2end_idx, master_index = match_group[0], match_group[1]
+
+ # master_index larger than master_nums, did not draw master bbox.
+ if master_index < master_nums:
+ # draw master
+ master_bbox = master_xyxy[master_index]
+ img_master = cv2.rectangle(
+ img_master, (int(master_bbox[0]), int(master_bbox[1])),
+ (int(master_bbox[2]), int(master_bbox[3])),
+ bbox_color,
+ thickness=1)
+ master_text_coord = (int(master_bbox[0]) - 4, int(master_bbox[1]))
+ img_master = cv2.putText(img_master,
+ str(master_index), master_text_coord, 1, 1,
+ text_color, 2)
+
+ # draw end2end
+ end2end_bbox = end2end_xyxy[end2end_idx]
+ img_end2end = cv2.rectangle(
+ img_end2end, (int(end2end_bbox[0]), int(end2end_bbox[1])),
+ (int(end2end_bbox[2]), int(end2end_bbox[3])),
+ bbox_color,
+ thickness=1)
+ end2end_text_coord = (int(end2end_bbox[0]) - 4, int(end2end_bbox[1]))
+ # write end2end bbox matching master bbox's index
+ img_end2end = cv2.putText(img_end2end,
+ str(master_index), end2end_text_coord, 1, 1,
+ text_color, 2)
+
+ img = np.hstack([img_end2end, img_master])
+ save_path = os.path.join(save_folder, '{}_matchShow.png'.format(prex))
+ cv2.imwrite(save_path, img)
+
+
+def get_match_dict(match_list):
+ """
+ Convert match_list to a dict, where key is master bbox's index, value is end2end bbox index.
+ :param match_list:
+ :return:
+ """
+ match_dict = dict()
+ for match_pair in match_list:
+ end2end_index, master_index = match_pair[0], match_pair[1]
+ if master_index not in match_dict.keys():
+ match_dict[master_index] = [end2end_index]
+ else:
+ match_dict[master_index].append(end2end_index)
+ return match_dict
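+
+# e.g. get_match_dict([[0, 2], [1, 2], [3, 0]]) -> {2: [0, 1], 0: [3]}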
+
+
+def deal_successive_space(text):
+ """
+ deal successive space character for text
+    1. Replace ' '*3 with '<space>', which is a real space in the text
+    2. Remove ' ', which is a split token, not a true space
+    3. Replace '<space>' with ' ', to get the real text
+    :param text:
+    :return:
+    """
+    text = text.replace(' ' * 3, '<space>')
+    text = text.replace(' ', '')
+    text = text.replace('<space>', ' ')
+ return text
+
+
+def reduce_repeat_bb(text_list, break_token):
+ """
+    convert ['<b>Local</b>', '<b>government</b>', '<b>unit</b>'] to ['<b>Local government unit</b>']
+    PS: a style like <i>Local</i> may also exist; it can be processed the same way.
+    :param text_list:
+    :param break_token:
+    :return:
+    """
+    count = 0
+    for text in text_list:
+        if text.startswith('<b>'):
+            count += 1
+    if count == len(text_list):
+        new_text_list = []
+        for text in text_list:
+            text = text.replace('<b>', '').replace('</b>', '')
+            new_text_list.append(text)
+        return ['<b>' + break_token.join(new_text_list) + '</b>']
+ else:
+ return text_list
+
+
+def get_match_text_dict(match_dict, end2end_info, break_token=' '):
+ match_text_dict = dict()
+ for master_index, end2end_index_list in match_dict.items():
+ text_list = [
+ end2end_info[end2end_index]['text']
+ for end2end_index in end2end_index_list
+ ]
+ text_list = reduce_repeat_bb(text_list, break_token)
+ text = break_token.join(text_list)
+ match_text_dict[master_index] = text
+ return match_text_dict
+
+
+def merge_span_token(master_token_list):
+ """
+ Merge the span style token (row span or col span).
+ :param master_token_list:
+ :return:
+ """
+ new_master_token_list = []
+ pointer = 0
+    if master_token_list[-1] != '</tbody>':
+        master_token_list.append('</tbody>')
+    while master_token_list[pointer] != '</tbody>':
+        try:
+            if master_token_list[pointer] == '<td':
+                if master_token_list[pointer + 1].startswith(
+                        ' colspan=') or master_token_list[
+                            pointer + 1].startswith(' rowspan='):
+                    """
+                    example:
+                    pattern <td colspan="3">
+                    '<td' + ' colspan="3"' + '>' + '</td>'
+                    """
+                    tmp = ''.join(master_token_list[pointer:pointer + 3 + 1])
+                    pointer += 4
+                    new_master_token_list.append(tmp)
+
+                elif master_token_list[pointer + 2].startswith(
+                        ' colspan=') or master_token_list[
+                            pointer + 2].startswith(' rowspan='):
+                    """
+                    example:
+                    pattern <td rowspan="2" colspan="3">
+                    '<td' + ' rowspan="2"' + ' colspan="3"' + '>' + '</td>'
+                    """
+                    tmp = ''.join(master_token_list[pointer:pointer + 4 + 1])
+                    pointer += 5
+                    new_master_token_list.append(tmp)
+
+                else:
+                    new_master_token_list.append(master_token_list[pointer])
+                    pointer += 1
+            else:
+                new_master_token_list.append(master_token_list[pointer])
+                pointer += 1
+        except:
+            print("Break in merge...")
+            break
+    new_master_token_list.append('</tbody>')
+
+ return new_master_token_list
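+
+# e.g. merge_span_token(['<td', ' colspan="2"', '>', '</td>', '</tbody>'])
+#      -> ['<td colspan="2"></td>', '</tbody>']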
+
+
+def deal_eb_token(master_token):
+ """
+    post process with <eb></eb>, <eb1></eb1>, ...
+    emptyBboxTokenDict = {
+        "[]": '<eb></eb>',
+        "[' ']": '<eb1></eb1>',
+        "['<b>', ' ', '</b>']": '<eb2></eb2>',
+        "['\\u2028', '\\u2028']": '<eb3></eb3>',
+        "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
+        "['<b>', '</b>']": '<eb5></eb5>',
+        "['<i>', ' ', '</i>']": '<eb6></eb6>',
+        "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
+        "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
+        "['<i>', '</i>']": '<eb9></eb9>',
+        "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']": '<eb10></eb10>',
+    }
+    :param master_token:
+    :return:
+    """
+    master_token = master_token.replace('<eb></eb>', '<td></td>')
+    master_token = master_token.replace('<eb1></eb1>', '<td> </td>')
+    master_token = master_token.replace('<eb2></eb2>',
+                                        '<td><b> </b></td>')
+    master_token = master_token.replace('<eb3></eb3>',
+                                        '<td>\u2028\u2028</td>')
+    master_token = master_token.replace('<eb4></eb4>',
+                                        '<td><sup> </sup></td>')
+    master_token = master_token.replace('<eb5></eb5>', '<td><b></b></td>')
+    master_token = master_token.replace('<eb6></eb6>',
+                                        '<td><i> </i></td>')
+    master_token = master_token.replace('<eb7></eb7>',
+                                        '<td><b><i></i></b></td>')
+    master_token = master_token.replace('<eb8></eb8>',
+                                        '<td><b><i> </i></b></td>')
+    master_token = master_token.replace('<eb9></eb9>',
+                                        '<td><i></i></td>')
+    master_token = master_token.replace('<eb10></eb10>',
+                                        '<td><b> \u2028 \u2028 </b></td>')
+ return master_token
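+
+# e.g. deal_eb_token('<tr><eb></eb><eb1></eb1></tr>')
+#      -> '<tr><td></td><td> </td></tr>'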
+
+
+def insert_text_to_token(master_token_list, match_text_dict):
+ """
+ Insert OCR text result to structure token.
+ :param master_token_list:
+ :param match_text_dict:
+ :return:
+ """
+ master_token_list = merge_span_token(master_token_list)
+ merged_result_list = []
+ text_count = 0
+ for master_token in master_token_list:
+        if master_token.startswith('<td'):
+            if text_count > len(match_text_dict) - 1:
+                text_count += 1
+                continue
+            elif text_count not in match_text_dict.keys():
+                text_count += 1
+                continue
+            else:
+                master_token = master_token.replace(
+                    '><', '>{}<'.format(match_text_dict[text_count]))
+                text_count += 1
+ master_token = deal_eb_token(master_token)
+ merged_result_list.append(master_token)
+
+ return ''.join(merged_result_list)
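+
+# e.g. insert_text_to_token(['<td></td>'], {0: 'cell'}) -> '<td>cell</td></tbody>'
+# (merge_span_token closes the token list with '</tbody>')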
+
+
+def deal_isolate_span(thead_part):
+ """
+ Deal with isolate span cases in this function.
+    It is caused by wrong prediction in the structure recognition model.
+    eg. predict <td rowspan="2"></td> to <td></td> rowspan="2"></b></td>.
+ :param thead_part:
+ :return:
+ """
+ # 1. find out isolate span tokens.
+    isolate_pattern = "<td></td> rowspan=\"(\d)+\" colspan=\"(\d)+\"></b></td>|" \
+                      "<td></td> colspan=\"(\d)+\" rowspan=\"(\d)+\"></b></td>|" \
+                      "<td></td> rowspan=\"(\d)+\"></b></td>|" \
+                      "<td></td> colspan=\"(\d)+\"></b></td>"
+ isolate_iter = re.finditer(isolate_pattern, thead_part)
+ isolate_list = [i.group() for i in isolate_iter]
+
+ # 2. find out span number, by step 1 results.
+ span_pattern = " rowspan=\"(\d)+\" colspan=\"(\d)+\"|" \
+ " colspan=\"(\d)+\" rowspan=\"(\d)+\"|" \
+ " rowspan=\"(\d)+\"|" \
+ " colspan=\"(\d)+\""
+ corrected_list = []
+ for isolate_item in isolate_list:
+ span_part = re.search(span_pattern, isolate_item)
+ spanStr_in_isolateItem = span_part.group()
+ # 3. merge the span number into the span token format string.
+ if spanStr_in_isolateItem is not None:
+            corrected_item = '<td{}></td>'.format(spanStr_in_isolateItem)
+ corrected_list.append(corrected_item)
+ else:
+ corrected_list.append(None)
+
+ # 4. replace original isolated token.
+ for corrected_item, isolate_item in zip(corrected_list, isolate_list):
+ if corrected_item is not None:
+ thead_part = thead_part.replace(isolate_item, corrected_item)
+ else:
+ pass
+ return thead_part
+
+
+def deal_duplicate_bb(thead_part):
+ """
+    Deal duplicate <b> or </b> after replace.
+    Keep one <b></b> in a <td></td> token.
+    :param thead_part:
+    :return:
+    """
+    # 1. find out <td></td> in <thead></thead>.
+    td_pattern = "<td rowspan=\"(\d)+\" colspan=\"(\d)+\">(.+?)</td>|" \
+                 "<td colspan=\"(\d)+\" rowspan=\"(\d)+\">(.+?)</td>|" \
+                 "<td rowspan=\"(\d)+\">(.+?)</td>|" \
+                 "<td colspan=\"(\d)+\">(.+?)</td>|" \
+                 "<td>(.*?)</td>"
+    td_iter = re.finditer(td_pattern, thead_part)
+    td_list = [t.group() for t in td_iter]
+
+    # 2. is multiply <b></b> in <td></td> or not?
+    new_td_list = []
+    for td_item in td_list:
+        if td_item.count('<b>') > 1 or td_item.count('</b>') > 1:
+            # multiply <b></b> in <td></td> case.
+            # 1. remove all <b></b>
+            td_item = td_item.replace('<b>', '').replace('</b>', '')
+            # 2. replace <td> -> <td><b>, </td> -> </b></td>.
+            td_item = td_item.replace('<td>', '<td><b>').replace('</td>',
+                                                                 '</b></td>')
+ new_td_list.append(td_item)
+ else:
+ new_td_list.append(td_item)
+
+ # 3. replace original thead part.
+ for td_item, new_td_item in zip(td_list, new_td_list):
+ thead_part = thead_part.replace(td_item, new_td_item)
+ return thead_part
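+
+# e.g. deal_duplicate_bb('<thead><tr><td><b>a</b> <b>b</b></td></tr></thead>')
+#      -> '<thead><tr><td><b>a b</b></td></tr></thead>'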
+
+
+def deal_bb(result_token):
+ """
+    In our opinion, <b></b> always occurs in <thead></thead> text's context.
+    This function will find out all tokens in <thead></thead> and insert <b></b> by manual.
+    :param result_token:
+    :return:
+    """
+    # find out <thead></thead> parts.
+    thead_pattern = '<thead>(.*?)</thead>'
+    if re.search(thead_pattern, result_token) is None:
+        return result_token
+    thead_part = re.search(thead_pattern, result_token).group()
+    origin_thead_part = copy.deepcopy(thead_part)
+
+    # check "rowspan" or "colspan" occur in <thead></thead> parts or not .
+    span_pattern = "<td rowspan=\"(\d)+\" colspan=\"(\d)+\">|<td colspan=\"(\d)+\" rowspan=\"(\d)+\">|<td rowspan=\"(\d)+\">|<td colspan=\"(\d)+\">"
+    span_iter = re.finditer(span_pattern, thead_part)
+    span_list = [s.group() for s in span_iter]
+    has_span_in_head = True if len(span_list) > 0 else False
+
+    if not has_span_in_head:
+        # <thead></thead> not include "rowspan" or "colspan" branch 1.
+        # 1. replace <td> to <td><b>, and </td> to </b></td>
+        # 2. it is possible to predict text include <b> or </b> by Text-line recognition,
+        #    so we replace <b><b> to <b>, and </b></b> to </b>
+        thead_part = thead_part.replace('<td>', '<td><b>')\
+            .replace('</td>', '</b></td>')\
+            .replace('<b><b>', '<b>')\
+            .replace('</b></b>', '</b>')
+    else:
+        # <thead></thead> include "rowspan" or "colspan" branch 2.
+        # Firstly, we deal rowspan or colspan cases.
+        # 1. replace > to ><b>
+        # 2. replace </td> to </b></td>
+        # 3. it is possible to predict text include <b> or </b> by Text-line recognition,
+        #    so we replace <b><b> to <b>, and </b><b> to </b>
+
+        # Secondly, deal ordinary cases like branch 1
+
+        # replace ">" to "><b>"
+        replaced_span_list = []
+        for sp in span_list:
+            replaced_span_list.append(sp.replace('>', '><b>'))
+        for sp, rsp in zip(span_list, replaced_span_list):
+            thead_part = thead_part.replace(sp, rsp)
+
+        # replace "</td>" to "</b></td>"
+        thead_part = thead_part.replace('</td>', '</b></td>')
+
+        # remove duplicated <b> by re.sub
+        mb_pattern = "(<b>)+"
+        single_b_string = "<b>"
+        thead_part = re.sub(mb_pattern, single_b_string, thead_part)
+
+        mgb_pattern = "(</b>)+"
+        single_gb_string = "</b>"
+        thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)
+
+        # ordinary cases like branch 1
+        thead_part = thead_part.replace('<td>', '<td><b>').replace('<b><b>',
+                                                                   '<b>')
+
+    # convert <td><b></b></td> back to <td></td>, empty cell has no <b></b>.
+    # but space cell(<td> </td>)  is suitable for <td><b> </b></td>
+    thead_part = thead_part.replace('<td><b></b></td>', '<td></td>')
+    # deal with duplicated <b></b>
+ thead_part = deal_duplicate_bb(thead_part)
+ # deal with isolate span tokens, which causes by wrong predict by structure prediction.
+ # eg.PMC5994107_011_00.png
+ thead_part = deal_isolate_span(thead_part)
+ # replace original result with new thead part.
+ result_token = result_token.replace(origin_thead_part, thead_part)
+ return result_token
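+
+# e.g. deal_bb('<thead><tr><td>a</td></tr></thead>')
+#      -> '<thead><tr><td><b>a</b></td></tr></thead>'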
+
+
+class Matcher:
+ def __init__(self, end2end_file, structure_master_file):
+ """
+        This class processes the end2end results and structure recognition results.
+ :param end2end_file: end2end results predict by end2end inference.
+ :param structure_master_file: structure recognition results predict by structure master inference.
+ """
+ self.end2end_file = end2end_file
+ self.structure_master_file = structure_master_file
+ self.end2end_results = pickle_load(end2end_file, prefix='end2end')
+ self.structure_master_results = pickle_load(
+ structure_master_file, prefix='structure')
+
+ def match(self):
+ """
+ Match process:
+        pre-process : convert end2end and structure master results to xyxy, xywh ndarray format.
+ 1. Use pseBbox is inside masterBbox judge rule
+ 2. Use iou between pseBbox and masterBbox rule
+ 3. Use min distance of center point rule
+ :return:
+ """
+ match_results = dict()
+ for idx, (file_name,
+ end2end_result) in enumerate(self.end2end_results.items()):
+ match_list = []
+ if file_name not in self.structure_master_results:
+ continue
+ structure_master_result = self.structure_master_results[file_name]
+ end2end_xyxy_bboxes, end2end_xywh_bboxes, structure_master_xywh_bboxes, structure_master_xyxy_bboxes = \
+ get_bboxes_list(end2end_result, structure_master_result)
+
+ # rule 1: center rule
+ center_rule_match_list = \
+ center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes)
+ match_list.extend(center_rule_match_list)
+
+ # rule 2: iou rule
+ # firstly, find not match index in previous step.
+ center_no_match_end2end_indexs = \
+ find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end')
+ if len(center_no_match_end2end_indexs) > 0:
+ center_no_match_end2end_xyxy = end2end_xyxy_bboxes[
+ center_no_match_end2end_indexs]
+ # secondly, iou rule match
+ iou_rule_match_list = \
+ iou_rule_match(center_no_match_end2end_xyxy, center_no_match_end2end_indexs, structure_master_xyxy_bboxes)
+ match_list.extend(iou_rule_match_list)
+
+ # rule 3: distance rule
+ # match between no-match end2end bboxes and no-match master bboxes.
+ # it will return master_bboxes_nums match-pairs.
+ # firstly, find not match index in previous step.
+ centerIou_no_match_end2end_indexs = \
+ find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end')
+ centerIou_no_match_master_indexs = \
+ find_no_match(match_list, len(structure_master_xywh_bboxes), type='master')
+ if len(centerIou_no_match_master_indexs) > 0 and len(
+ centerIou_no_match_end2end_indexs) > 0:
+ centerIou_no_match_end2end_xywh = end2end_xywh_bboxes[
+ centerIou_no_match_end2end_indexs]
+ centerIou_no_match_master_xywh = structure_master_xywh_bboxes[
+ centerIou_no_match_master_indexs]
+ distance_match_list = distance_rule_match(
+ centerIou_no_match_end2end_indexs,
+ centerIou_no_match_end2end_xywh,
+ centerIou_no_match_master_indexs,
+ centerIou_no_match_master_xywh)
+ match_list.extend(distance_match_list)
+
+ # TODO:
+            # The remaining no-match pseBboxes are inserted at the end.
+            # After the step-3 distance rule, every master bbox matches at least one end2end bbox,
+            # but end2end bboxes may be left over, because the number of master bboxes is cut by max length.
+            # For these remaining end2end bboxes, we make some virtual master bboxes and match them.
+            # The extra inserted bboxes will be further processed in the "_format" function.
+            # After this operation, it will increase the TEDS score.
+ no_match_end2end_indexes = \
+ find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end')
+ if len(no_match_end2end_indexes) > 0:
+ no_match_end2end_xywh = end2end_xywh_bboxes[
+ no_match_end2end_indexes]
+                # sort the remaining no-match end2end bboxes in rows
+ end2end_sorted_indexes_list, end2end_sorted_bboxes_list, sorted_groups, sorted_bboxes_groups = \
+ sort_bbox(no_match_end2end_xywh, no_match_end2end_indexes)
+ # make virtual master bboxes, and get matching with the no-match end2end bboxes.
+ extra_match_list = extra_match(
+ end2end_sorted_indexes_list,
+ len(structure_master_xywh_bboxes))
+ match_list_add_extra_match = copy.deepcopy(match_list)
+ match_list_add_extra_match.extend(extra_match_list)
+ else:
+ # no no-match end2end bboxes
+ match_list_add_extra_match = copy.deepcopy(match_list)
+ sorted_groups = []
+ sorted_bboxes_groups = []
+
+ match_result_dict = {
+ 'match_list': match_list,
+ 'match_list_add_extra_match': match_list_add_extra_match,
+ 'sorted_groups': sorted_groups,
+ 'sorted_bboxes_groups': sorted_bboxes_groups
+ }
+
+ # ordinary match show
+ # match_visual(file_name, match_list, end2end_xyxy_bboxes, structure_master_xyxy_bboxes, prex='ordinary_match')
+ # extra match show
+ # match_visual(file_name, match_list_add_extra_match, end2end_xyxy_bboxes, structure_master_xyxy_bboxes, prex='extra_match')
+
+ # format output
+ match_result_dict = self._format(match_result_dict, file_name)
+
+ match_results[file_name] = match_result_dict
+
+ return match_results
+
+ def _format(self, match_result, file_name):
+ """
+ Extend the master tokens (insert virtual master tokens) and format the matching result.
+ :param match_result:
+ :param file_name:
+ :return:
+ """
+ end2end_info = self.end2end_results[file_name]
+ master_info = self.structure_master_results[file_name]
+ master_token = master_info['text']
+ sorted_groups = match_result['sorted_groups']
+
+ # create virtual master tokens
+ virtual_master_token_list = []
+ for line_group in sorted_groups:
+     tmp_list = ['<tr>']
+     item_nums = len(line_group)
+     for _ in range(item_nums):
+         tmp_list.append('<td></td>')
+     tmp_list.append('</tr>')
+     virtual_master_token_list.extend(tmp_list)
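+ # e.g. a sorted row group with two cells yields
+ # ['<tr>', '<td></td>', '<td></td>', '</tr>']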
+
+ # insert virtual master tokens
+ master_token_list = master_token.split(',')
+ if master_token_list[-1] == '</tbody>':
+     # complete prediction (not cut by max length).
+     # Inserting virtual master tokens here would drop the TEDS score on the val set,
+     # so we do not really extend the token list in this situation.
+
+     # fake extend virtual (note: [:-1] makes a copy, so this line is a deliberate no-op)
+     master_token_list[:-1].extend(virtual_master_token_list)
+
+     # real extend virtual
+     # master_token_list = master_token_list[:-1]
+     # master_token_list.extend(virtual_master_token_list)
+     # master_token_list.append('</tbody>')
+
+ elif master_token_list[-1] == '<td></td>':
+     master_token_list.append('</tr>')
+     master_token_list.extend(virtual_master_token_list)
+     master_token_list.append('</tbody>')
+ else:
+     master_token_list.extend(virtual_master_token_list)
+     master_token_list.append('</tbody>')
+
+ # format output
+ match_result.setdefault('matched_master_token_list', master_token_list)
+ return match_result
+
+ def get_merge_result(self, match_results):
+ """
+ Merge the OCR results into the structure tokens to get the final HTML result.
+ :param match_results:
+ :return:
+ """
+ merged_results = dict()
+
+ # break_token separates the texts when one master bbox matches multiple end2end bboxes.
+ break_token = ' '
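+ # e.g. texts 'Hello' and 'World' matched to one master cell merge into 'Hello World'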
+
+ for idx, (file_name, match_info) in enumerate(match_results.items()):
+ end2end_info = self.end2end_results[file_name]
+ master_token_list = match_info['matched_master_token_list']
+ match_list = match_info['match_list_add_extra_match']
+
+ match_dict = get_match_dict(match_list)
+ match_text_dict = get_match_text_dict(match_dict, end2end_info,
+ break_token)
+ merged_result = insert_text_to_token(master_token_list,
+ match_text_dict)
+ merged_result = deal_bb(merged_result)
+
+ merged_results[file_name] = merged_result
+
+ return merged_results
+
+
+class TableMasterMatcher(Matcher):
+ def __init__(self):
+ pass
+
+ def __call__(self, structure_res, dt_boxes, rec_res, img_name=1):
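+ # structure_res: (structure token list, cell bboxes) from the table structure model;
+ # dt_boxes / rec_res: detected boxes and (text, score) pairs from the OCR det/rec models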
+ end2end_results = {img_name: []}
+ for dt_box, res in zip(dt_boxes, rec_res):
+ d = dict(
+ bbox=np.array(dt_box),
+ text=res[0], )
+ end2end_results[img_name].append(d)
+
+ self.end2end_results = end2end_results
+
+ structure_master_result_dict = {img_name: {}}
+ pred_structures, pred_bboxes = structure_res
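+ # the first/last three tokens are assumed to be the '<html>', '<body>', '<table>'
+ # wrapper tokens added around the structure sequence; strip them before joining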
+ pred_structures = ','.join(pred_structures[3:-3])
+ structure_master_result_dict[img_name]['text'] = pred_structures
+ structure_master_result_dict[img_name]['bbox'] = pred_bboxes
+ self.structure_master_results = structure_master_result_dict
+
+ # match
+ match_results = self.match()
+ merged_results = self.get_merge_result(match_results)
+ pred_html = merged_results[img_name]
+ pred_html = '<html><body><table>' + pred_html + '</table></body></html>'
+ return pred_html
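+
+
+ # Minimal usage sketch (illustrative; structure_res, dt_boxes and rec_res are
+ # assumed to come from the table-structure predictor and the OCR det/rec models):
+ #   matcher = TableMasterMatcher()
+ #   pred_html = matcher(structure_res, dt_boxes, rec_res, img_name='demo.jpg')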
diff --git a/ppstructure/utility.py b/ppstructure/utility.py
index af0616239..f5388fabf 100644
--- a/ppstructure/utility.py
+++ b/ppstructure/utility.py
@@ -32,6 +32,7 @@ def init_args():
type=str,
default="../ppocr/utils/dict/table_structure_dict.txt")
# params for layout
+ parser.add_argument("--layout_model_dir", type=str)
parser.add_argument(
"--layout_path_model",
type=str,
@@ -87,7 +88,7 @@ def draw_structure_result(image, result, font_path):
image = Image.fromarray(image)
boxes, txts, scores = [], [], []
for region in result:
- if region['type'] == 'Table':
+ if region['type'] == 'table':
pass
else:
for text_result in region['res']:
diff --git a/test_tipc/configs/en_table_structure/table_mv3.yml b/test_tipc/configs/en_table_structure/table_mv3.yml
index 5d8e84c95..6ff31fc26 100755
--- a/test_tipc/configs/en_table_structure/table_mv3.yml
+++ b/test_tipc/configs/en_table_structure/table_mv3.yml
@@ -19,8 +19,6 @@ Global:
character_type: en
max_text_length: 800
infer_mode: False
- process_total_num: 0
- process_cut_num: 0
Optimizer:
name: Adam
diff --git a/test_tipc/configs/table_master/table_master.yml b/test_tipc/configs/table_master/table_master.yml
index c519b5b8f..27f81683b 100644
--- a/test_tipc/configs/table_master/table_master.yml
+++ b/test_tipc/configs/table_master/table_master.yml
@@ -16,8 +16,6 @@ Global:
character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
infer_mode: false
max_text_length: 500
- process_total_num: 0
- process_cut_num: 0
Optimizer:
@@ -86,7 +84,7 @@ Train:
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
- use_xywh: True
+ box_format: 'xywh'
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
@@ -120,7 +118,7 @@ Eval:
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
- use_xywh: True
+ box_format: 'xywh'
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py
index 625d365f4..73b7155ba 100755
--- a/tools/infer/predict_system.py
+++ b/tools/infer/predict_system.py
@@ -65,9 +65,11 @@ class TextSystem(object):
self.crop_image_res_index += bbox_num
def __call__(self, img, cls=True):
+ time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
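+ # per-stage elapsed times in seconds; 'all' is filled in just before returning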
+ start = time.time()
ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img)
-
+ time_dict['det'] = elapse
logger.debug("dt_boxes num : {}, elapse : {}".format(
len(dt_boxes), elapse))
if dt_boxes is None:
@@ -83,10 +85,12 @@ class TextSystem(object):
if self.use_angle_cls and cls:
img_crop_list, angle_list, elapse = self.text_classifier(
img_crop_list)
+ time_dict['cls'] = elapse
logger.debug("cls num : {}, elapse : {}".format(
len(img_crop_list), elapse))
rec_res, elapse = self.text_recognizer(img_crop_list)
+ time_dict['rec'] = elapse
logger.debug("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse))
if self.args.save_crop_res:
@@ -98,7 +102,9 @@ class TextSystem(object):
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_result)
- return filter_boxes, filter_rec_res
+ end = time.time()
+ time_dict['all'] = end - start
+ return filter_boxes, filter_rec_res, time_dict
def sorted_boxes(dt_boxes):
@@ -133,9 +139,11 @@ def main(args):
os.makedirs(draw_img_save_dir, exist_ok=True)
save_results = []
- logger.info("In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
- "if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320")
-
+ logger.info(
+ "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
+ "if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320"
+ )
+
# warm up 10 times
if args.warmup:
img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
@@ -155,7 +163,7 @@ def main(args):
logger.debug("error in loading image:{}".format(image_file))
continue
starttime = time.time()
- dt_boxes, rec_res = text_sys(img)
+ dt_boxes, rec_res, time_dict = text_sys(img)
elapse = time.time() - starttime
total_time += elapse
@@ -198,7 +206,10 @@ def main(args):
text_sys.text_detector.autolog.report()
text_sys.text_recognizer.autolog.report()
- with open(os.path.join(draw_img_save_dir, "system_results.txt"), 'w', encoding='utf-8') as f:
+ with open(
+ os.path.join(draw_img_save_dir, "system_results.txt"),
+ 'w',
+ encoding='utf-8') as f:
f.writelines(save_results)
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 7eb77dec7..6ad770e28 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -155,6 +155,8 @@ def create_predictor(args, mode, logger):
model_dir = args.table_model_dir
elif mode == 'ser':
model_dir = args.ser_model_dir
+ elif mode == 'layout':
+ model_dir = args.layout_model_dir
else:
model_dir = args.e2e_model_dir
diff --git a/tools/infer_table.py b/tools/infer_table.py
index 6c02dd864..70dc6205d 100644
--- a/tools/infer_table.py
+++ b/tools/infer_table.py
@@ -56,7 +56,6 @@ def main(config, device, logger, vdl_writer):
model = build_model(config['Architecture'])
algorithm = config['Architecture']['algorithm']
- use_xywh = algorithm in ['TableMaster']
load_model(config, model)
@@ -106,7 +105,7 @@ def main(config, device, logger, vdl_writer):
f_w.write("result: {}, {}\n".format(structure_str_list,
bbox_list_str))
- img = draw_rectangle(file, bbox_list, use_xywh)
+ img = draw_rectangle(file, bbox_list)
cv2.imwrite(
os.path.join(save_res_path, os.path.basename(file)), img)
logger.info("success!")
diff --git a/tools/program.py b/tools/program.py
index 0fa0e609b..1802e8529 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -154,6 +154,7 @@ def check_xpu(use_xpu):
except Exception as e:
pass
+
def to_float32(preds):
if isinstance(preds, dict):
for k in preds:
@@ -173,6 +174,7 @@ def to_float32(preds):
preds = preds.astype(paddle.float32)
return preds
+
def train(config,
train_dataloader,
valid_dataloader,
@@ -596,7 +598,7 @@ def preprocess(is_train=False):
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
- 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN'
+ 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'SLANet'
]
if use_xpu:
diff --git a/tools/train.py b/tools/train.py
index 309d4bb9e..0e45b5b70 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -119,6 +119,10 @@ def main(config, device, logger, vdl_writer):
config['Loss']['ignore_index'] = char_num - 1
model = build_model(config['Architecture'])
+ use_sync_bn = config["Global"].get("use_sync_bn", False)
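+ # SyncBatchNorm shares batch-norm statistics across GPUs during distributed
+ # training, which helps when the per-card batch size is small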
+ if use_sync_bn:
+ model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ logger.info('convert_sync_batchnorm')
if config['Global']['distributed']:
model = paddle.DataParallel(model)
@@ -157,7 +161,8 @@ def main(config, device, logger, vdl_writer):
scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss,
use_dynamic_loss_scaling=use_dynamic_loss_scaling)
- model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer, level='O2', master_weight=True)
+ model, optimizer = paddle.amp.decorate(
+ models=model, optimizers=optimizer, level='O2', master_weight=True)
else:
scaler = None