From 16c247ac465af18ac8616f5b5d04646f9b99414f Mon Sep 17 00:00:00 2001
From: MissPenguin <lichenxia1991@163.com>
Date: Mon, 21 Jun 2021 12:20:25 +0000
Subject: [PATCH] refine table recognition training config and clean up table code

Update table_mv3.yml hyper-parameters, remove unused code from the table
dataset, label ops and post-process, parameterize max_elem_length in
TableAttentionHead, drop hard-coded parameter names in TableFPN, add a table
branch to model export, and simplify table inference and evaluation calls.
---
 configs/table/table_mv3.yml                | 24 +++++++++---------
 ppocr/data/imaug/label_ops.py              | 20 ---------------
 ppocr/data/pubtab_dataset.py               | 22 ++--------------
 ppocr/modeling/architectures/base_model.py |  2 +-
 ppocr/modeling/heads/table_att_head.py     | 22 ++++++++--------
 ppocr/modeling/necks/table_fpn.py          | 29 ++++++++--------------
 ppocr/postprocess/rec_postprocess.py       | 12 ---------
 tools/export_model.py                      |  3 ++-
 tools/infer_table.py                       |  4 +--
 tools/program.py                           |  3 ++-
 10 files changed, 40 insertions(+), 101 deletions(-)

diff --git a/configs/table/table_mv3.yml b/configs/table/table_mv3.yml
index 32164fe30..a74e18d31 100755
--- a/configs/table/table_mv3.yml
+++ b/configs/table/table_mv3.yml
@@ -1,13 +1,12 @@
 Global:
   use_gpu: true
-  epoch_num: 40
+  epoch_num: 50
   log_smooth_window: 20
   print_batch_step: 5
   save_model_dir: ./output/table_mv3/
-  save_epoch_step: 3
-  # evaluation is run every 5000 iterations after the 4000th iteration
+  save_epoch_step: 5
+  # evaluation is run every 400 iterations after the 0th iteration
   eval_batch_step: [0, 400]
-  # if pretrained_model is saved in static mode, load_static_weights must set to True
   cal_metric_during_train: True
   pretrained_model: 
   checkpoints: 
@@ -18,19 +17,20 @@ Global:
   character_dict_path: ppocr/utils/dict/table_structure_dict.txt
   character_type: en
   max_text_length: 100
-  max_elem_length: 800
+  max_elem_length: 500
   max_cell_num: 500
   infer_mode: False
   process_total_num: 0
   process_cut_num: 0
 
+
 Optimizer:
   name: Adam
   beta1: 0.9
   beta2: 0.999
   clip_norm: 5.0
   lr:
-    learning_rate: 0.0001
+    learning_rate: 0.001
   regularizer:
     name: 'L2'
     factor: 0.00000
@@ -41,12 +41,12 @@ Architecture:
   Backbone:
     name: MobileNetV3
     scale: 1.0
-    model_name: large
+    model_name: small
+    disable_se: True
   Head:
-    name: TableAttentionHead  # AttentionHead
-    hidden_size: 256 #
+    name: TableAttentionHead
+    hidden_size: 256
     l2_decay: 0.00001
-#     loc_type: 1
     loc_type: 2
 
 Loss:
@@ -86,7 +86,7 @@ Train:
     shuffle: True
     batch_size_per_card: 32
     drop_last: True
-    num_workers: 4
+    num_workers: 1
 
 Eval:
   dataset:
@@ -113,4 +113,4 @@ Eval:
     shuffle: False
     drop_last: False
     batch_size_per_card: 16
-    num_workers: 4
+    num_workers: 1
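
For reference, the eval_batch_step: [0, 400] pair above is consumed by the training
loop in tools/program.py roughly as sketched below; the names and the exact condition
are an illustrative paraphrase, not code quoted from that file.

    # hedged sketch: how [start, step] from eval_batch_step is interpreted
    start_eval_step, eval_every = 0, 400

    def should_evaluate(global_step):
        # evaluate every `eval_every` iterations once `start_eval_step` is passed,
        # matching "evaluation is run every 400 iterations after the 0th iteration"
        return global_step > start_eval_step and \
            (global_step - start_eval_step) % eval_every == 0
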
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index cd883d1b4..e25cce79b 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -412,7 +412,6 @@ class TableLabelEncode(object):
             return None
         elem_num = len(structure)
         structure = [0] + structure + [len(self.dict_elem) - 1]
-#         structure = [0] + structure + [0]
         structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
         structure = np.array(structure)
         data['structure'] = structure
@@ -443,8 +442,6 @@ class TableLabelEncode(object):
                 if cand_span_idx < (self.max_elem_length + 2):
                     if structure[cand_span_idx] in span_idx_list:
                         structure_mask[cand_span_idx] = span_weight
-#                         structure_mask[td_idx] = self.span_weight
-#                         structure_mask[cand_span_idx] = self.span_weight
 
         data['bbox_list'] = bbox_list
         data['bbox_list_mask'] = bbox_list_mask
@@ -458,23 +455,6 @@ class TableLabelEncode(object):
             self.max_elem_length, self.max_cell_num, elem_num])
         return data
 
-        ########
-        # for char decode
-#         cell_list = []
-#         for cell in cells:
-#             char_list = cell['tokens']
-#             cell = self.encode(char_list, 'char')
-#             if cell is None:
-#                 return None
-#             cell = [0] + cell + [len(self.dict_character) - 1]
-#             cell = cell + [0] * (self.max_text_length + 2 - len(cell))
-#             cell_list.append(cell)
-#         cell_list_padding = np.zeros((self.max_cell_num, self.max_text_length + 2))
-#         cell_list = np.array(cell_list)
-#         cell_list_padding[0:cell_list.shape[0]] = cell_list
-#         data['cells'] = cell_list_padding
-#         return data
-
     def encode(self, text, char_or_elem):
         """convert text-label into text-index.
         """
diff --git a/ppocr/data/pubtab_dataset.py b/ppocr/data/pubtab_dataset.py
index a2c3eebf7..78b76c5af 100644
--- a/ppocr/data/pubtab_dataset.py
+++ b/ppocr/data/pubtab_dataset.py
@@ -1,4 +1,4 @@
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ import json
 
 from .imaug import transform, create_operators
 
+
 class PubTabDataSet(Dataset):
     def __init__(self, config, mode, logger, seed=None):
         super(PubTabDataSet, self).__init__()
@@ -57,23 +58,6 @@ class PubTabDataSet(Dataset):
             random.seed(self.seed)
             random.shuffle(self.data_lines)
         return
-    
-    def load_hard_select_prob(self):
-        label_path = "./pretrained_model/teds_score_exp5_st2_train.txt"
-        img_select_prob = {}
-        with open(label_path, "rb") as fin:
-            lines = fin.readlines()
-            for lno in range(len(lines)):
-                substr = lines[lno].decode('utf-8').strip("\n").split(" ")
-                img_name = substr[0].strip(":")
-                score = float(substr[1])
-                if score <= 0.8:
-                    img_select_prob[img_name] = self.hard_prob[0]
-                elif score <= 0.98:
-                    img_select_prob[img_name] = self.hard_prob[1]
-                else:
-                    img_select_prob[img_name] = self.hard_prob[2]
-        return img_select_prob
 
     def __getitem__(self, idx):
         try:
@@ -93,8 +77,6 @@ class PubTabDataSet(Dataset):
                 table_type = "simple"
                 if 'colspan' in structure_str or 'rowspan' in structure_str:
                     table_type = "complex"
-#                 if self.table_select_type != table_type:
-#                     select_flag = False
                 if table_type == "complex":
                     if self.table_select_prob < random.uniform(0, 1):
                         select_flag = False                    
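
With load_hard_select_prob and the table_select_type branch gone, the sampling rule
that remains in PubTabDataSet.__getitem__ is a single probabilistic filter on complex
tables; a standalone sketch (the function name and argument list are illustrative,
not part of the dataset class):

    import random

    def keep_sample(structure_str, table_select_prob):
        # complex tables (containing colspan/rowspan) are kept with
        # probability table_select_prob; simple tables are always kept
        is_complex = 'colspan' in structure_str or 'rowspan' in structure_str
        if is_complex and table_select_prob < random.uniform(0, 1):
            return False
        return True
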
diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py
index 49160b528..c1bdaaafb 100644
--- a/ppocr/modeling/architectures/base_model.py
+++ b/ppocr/modeling/architectures/base_model.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py
index 9e5c438a3..61dacd355 100644
--- a/ppocr/modeling/heads/table_att_head.py
+++ b/ppocr/modeling/heads/table_att_head.py
@@ -21,13 +21,16 @@ import paddle.nn as nn
 import paddle.nn.functional as F
 import numpy as np
 
+
 class TableAttentionHead(nn.Layer):
     def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
         super(TableAttentionHead, self).__init__()
         self.input_size = in_channels[-1]
         self.hidden_size = hidden_size
-        self.char_num = 280
         self.elem_num = 30
+        self.max_text_length = 100
+        self.max_elem_length = 500
+        self.max_cell_num = 500
 
         self.structure_attention_cell = AttentionGRUCell(
             self.input_size, hidden_size, self.elem_num, use_gru=False)
@@ -39,11 +42,11 @@ class TableAttentionHead(nn.Layer):
             self.loc_generator = nn.Linear(hidden_size, 4)
         else:
             if self.in_max_len == 640:
-                self.loc_fea_trans = nn.Linear(400, 801)
+                self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1)
             elif self.in_max_len == 800:
-                self.loc_fea_trans = nn.Linear(625, 801)
+                self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1)
             else:
-                self.loc_fea_trans = nn.Linear(256, 801)
+                self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1)
             self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
             
     def _char_to_onehot(self, input_char, onehot_dim):
@@ -61,18 +64,12 @@ class TableAttentionHead(nn.Layer):
             fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
             fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
         batch_size = fea.shape[0]
-        #sp_tokens = targets[2].numpy()
-        #char_beg_idx, char_end_idx = sp_tokens[0, 0:2]
-        #elem_beg_idx, elem_end_idx = sp_tokens[0, 2:4]
-        #elem_char_idx1, elem_char_idx2 = sp_tokens[0, 4:6]
-        #max_text_length, max_elem_length, max_cell_num = sp_tokens[0, 6:9]
-        max_text_length, max_elem_length, max_cell_num = 100, 800, 500
         
         hidden = paddle.zeros((batch_size, self.hidden_size))
         output_hiddens = []
         if mode == 'Train' and targets is not None:
             structure = targets[0]
-            for i in range(max_elem_length+1):
+            for i in range(self.max_elem_length+1):
                 elem_onehots = self._char_to_onehot(
                     structure[:, i], onehot_dim=self.elem_num)
                 (outputs, hidden), alpha = self.structure_attention_cell(
@@ -97,7 +94,7 @@ class TableAttentionHead(nn.Layer):
             elem_onehots = None
             outputs = None
             alpha = None
-            max_elem_length = paddle.to_tensor(max_elem_length)
+            max_elem_length = paddle.to_tensor(self.max_elem_length)
             i = 0
             while i < max_elem_length+1:
                 elem_onehots = self._char_to_onehot(
@@ -124,6 +121,7 @@ class TableAttentionHead(nn.Layer):
                 loc_preds = F.sigmoid(loc_preds)
         return {'structure_probs':structure_probs, 'loc_preds':loc_preds}
 
+
 class AttentionGRUCell(nn.Layer):
     def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
         super(AttentionGRUCell, self).__init__()
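
Replacing the hard-coded 801 and (100, 800, 500) constants with self.max_elem_length
keeps the loc branch and the structure decode loop in sync: loc_fea_trans maps the
flattened spatial positions to one slot per decoded structure step. A minimal shape
sketch (batch size, channel count and the 625 = 25 x 25 spatial size for
in_max_len == 800 are illustrative values):

    import paddle

    max_elem_length = 500                  # must match Global.max_elem_length in table_mv3.yml
    loc_fea = paddle.randn([8, 96, 625])   # (batch, channels, flattened H*W)
    loc_fea_trans = paddle.nn.Linear(625, max_elem_length + 1)
    out = loc_fea_trans(loc_fea)           # (batch, channels, max_elem_length + 1):
                                           # one location feature slot per structure step
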
diff --git a/ppocr/modeling/necks/table_fpn.py b/ppocr/modeling/necks/table_fpn.py
index d72bff4ff..734f15af6 100644
--- a/ppocr/modeling/necks/table_fpn.py
+++ b/ppocr/modeling/necks/table_fpn.py
@@ -1,4 +1,4 @@
-# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,70 +31,61 @@ class TableFPN(nn.Layer):
             in_channels=in_channels[0],
             out_channels=self.out_channels,
             kernel_size=1,
-            weight_attr=ParamAttr(
-                name='conv2d_51.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.in3_conv = nn.Conv2D(
             in_channels=in_channels[1],
             out_channels=self.out_channels,
             kernel_size=1,
             stride = 1,
-            weight_attr=ParamAttr(
-                name='conv2d_50.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.in4_conv = nn.Conv2D(
             in_channels=in_channels[2],
             out_channels=self.out_channels,
             kernel_size=1,
-            weight_attr=ParamAttr(
-                name='conv2d_49.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.in5_conv = nn.Conv2D(
             in_channels=in_channels[3],
             out_channels=self.out_channels,
             kernel_size=1,
-            weight_attr=ParamAttr(
-                name='conv2d_48.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.p5_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_52.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.p4_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_53.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.p3_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_54.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.p2_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_55.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.fuse_conv = nn.Conv2D(
             in_channels=self.out_channels * 4,
             out_channels=512,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_fuse.w_0', initializer=weight_attr), bias_attr=False)
+            weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False)
 
     def forward(self, x):
         c2, c3, c4, c5 = x
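
Dropping the explicit name= arguments lets Paddle auto-generate unique parameter
names, so the convolutions no longer risk colliding with fixed names such as
conv2d_51.w_0 assigned elsewhere. A small sketch of the behaviour this relies on
(layer sizes and the KaimingUniform initializer are illustrative):

    import paddle
    from paddle import ParamAttr

    init = paddle.nn.initializer.KaimingUniform()
    conv_a = paddle.nn.Conv2D(3, 8, 1, weight_attr=ParamAttr(initializer=init))
    conv_b = paddle.nn.Conv2D(3, 8, 1, weight_attr=ParamAttr(initializer=init))
    # with no explicit name=, each parameter gets its own auto-generated name
    print(conv_a.weight.name, conv_b.weight.name)
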
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 9429d6b47..912d9bbae 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -368,18 +368,6 @@ class TableLabelDecode(object):
         self.end_str = "eos"
         list_character = [self.beg_str] + list_character + [self.end_str]
         return list_character
-
-    def get_sp_tokens(self):
-        char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
-        char_end_idx = self.get_beg_end_flag_idx('end', 'char')
-        elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
-        elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
-        elem_char_idx1 = self.dict_elem['<td>']
-        elem_char_idx2 = self.dict_elem['<td']
-        sp_tokens = np.array([char_beg_idx, char_end_idx, elem_beg_idx, 
-            elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length, 
-            self.max_elem_length, self.max_cell_num])
-        return sp_tokens
     
     def __call__(self, preds):
         structure_probs = preds['structure_probs']
diff --git a/tools/export_model.py b/tools/export_model.py
index 625c82468..785aca10e 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -60,7 +60,8 @@ def export_single_model(model, arch_config, save_path, logger):
                     "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
                 )
                 infer_shape[-1] = 100
-
+        elif arch_config["model_type"] == "table":
+            infer_shape = [3, 488, 488]
         model = to_static(
             model,
             input_spec=[
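
The new table branch feeds a fixed 3 x 488 x 488 shape into the to_static call that
follows, matching the in_max_len=488 default of TableAttentionHead; the resulting
input spec is roughly:

    import paddle

    infer_shape = [3, 488, 488]   # C, H, W expected by the exported table model
    input_spec = [
        paddle.static.InputSpec(
            shape=[None] + infer_shape, dtype="float32")
    ]
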
diff --git a/tools/infer_table.py b/tools/infer_table.py
index 494f3936f..450a83dd0 100644
--- a/tools/infer_table.py
+++ b/tools/infer_table.py
@@ -79,11 +79,9 @@ def main(config, device, logger, vdl_writer):
             img = f.read()
             data = {'image': img}
         batch = transform(data, ops)
-        sp_tokens = post_process_class.get_sp_tokens()
-        targets = [[], [], paddle.to_tensor([sp_tokens])]
         images = np.expand_dims(batch[0], axis=0)
         images = paddle.to_tensor(images)
-        preds = model(images, data=targets, mode='Test')
+        preds = model(images, data=None, mode='Test')
         post_result = post_process_class(preds)
         res_html_code = post_result['res_html_code']
         res_loc = post_result['res_loc']
diff --git a/tools/program.py b/tools/program.py
index 06a8d7423..d68e123a8 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -276,6 +276,7 @@ def train(config,
                     valid_dataloader,
                     post_process_class,
                     eval_class,
+                    "table",
                     use_srn=use_srn)
                 cur_metric_str = 'cur metric, {}'.format(', '.join(
                     ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
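
The extra "table" argument threaded into eval() here is presumably a model-type
switch; the signature it is assumed to land on (parameter names are a guess, since
the eval() definition is not part of this patch) would look roughly like:

    # hypothetical signature; only the position of the new "table" argument
    # is implied by this patch
    def eval(model, valid_dataloader, post_process_class, eval_class,
             model_type, use_srn=False):
        ...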