Mirror of https://github.com/PaddlePaddle/PaddleOCR.git
fix gap between table structure train model and inference model (#4565)
* add indent in pipeline_rpc_client.py
* fix gap between table structure train model and inference model
Parent: a8960021ed
Commit: b6a21419d6
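Note on the substance of the fix: before this commit, TableAttentionHead hard-coded max_text_length = 100, max_elem_length = 500 and max_cell_num = 500 in its __init__, so the limits baked into the head could drift away from whatever the training YAML specified; after the commit they are constructor arguments, and the Head section of the config carries the matching keys. A minimal sketch of the idea, using the values from the updated config (the dict-to-kwargs forwarding below is illustrative, not PaddleOCR's exact model builder):

# Head section of the updated training config, as a plain dict.
head_config = {
    "name": "TableAttentionHead",
    "hidden_size": 256,
    "loc_type": 2,
    "max_text_length": 100,
    "max_elem_length": 800,
    "max_cell_num": 500,
}

# Everything except "name" can now be forwarded as keyword arguments to the
# widened __init__ signature, so train-time and export-time heads share them.
kwargs = {k: v for k, v in head_config.items() if k != "name"}
print(kwargs["max_elem_length"])  # 800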
@@ -1,29 +1,28 @@
 Global:
   use_gpu: true
-  epoch_num: 50
+  epoch_num: 400
   log_smooth_window: 20
   print_batch_step: 5
   save_model_dir: ./output/table_mv3/
-  save_epoch_step: 5
+  save_epoch_step: 3
   # evaluation is run every 400 iterations after the 0th iteration
   eval_batch_step: [0, 400]
   cal_metric_during_train: True
   pretrained_model:
   checkpoints:
   save_inference_dir:
   use_visualdl: False
-  infer_img: doc/imgs_words/ch/word_1.jpg
+  infer_img: doc/table/table.jpg
   # for data or label process
   character_dict_path: ppocr/utils/dict/table_structure_dict.txt
   character_type: en
   max_text_length: 100
-  max_elem_length: 500
+  max_elem_length: 800
   max_cell_num: 500
   infer_mode: False
   process_total_num: 0
   process_cut_num: 0
-
 
 Optimizer:
   name: Adam
   beta1: 0.9
@@ -41,13 +40,15 @@ Architecture:
   Backbone:
     name: MobileNetV3
     scale: 1.0
-    model_name: small
-    disable_se: True
+    model_name: large
   Head:
     name: TableAttentionHead
     hidden_size: 256
     l2_decay: 0.00001
     loc_type: 2
+    max_text_length: 100
+    max_elem_length: 800
+    max_cell_num: 500
 
 Loss:
   name: TableAttentionLoss
@@ -41,6 +41,6 @@ for img_file in os.listdir(test_img_dir):
         image_data = file.read()
     image = cv2_to_base64(image_data)
 
-for i in range(1):
-    ret = client.predict(feed_dict={"image": image}, fetch=["res"])
-    print(ret)
+    for i in range(1):
+        ret = client.predict(feed_dict={"image": image}, fetch=["res"])
+        print(ret)
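For reference, a standalone sketch of the control flow that the indentation fix in pipeline_rpc_client.py produces (the test_img_dir value "imgs/" is a placeholder, cv2_to_base64 mirrors the helper in the surrounding script, and predict_stub stands in for the client.predict call so the sketch runs without a serving endpoint): the request is issued inside the per-image loop, so every image in the directory is sent rather than only the last one read.

import base64
import os

def cv2_to_base64(image):
    # raw image bytes -> base64 string, as in the client script
    return base64.b64encode(image).decode("utf8")

def predict_stub(image_b64):
    # stand-in for client.predict(feed_dict={"image": image_b64}, fetch=["res"])
    return {"res": len(image_b64)}

test_img_dir = "imgs/"
for img_file in os.listdir(test_img_dir):
    with open(os.path.join(test_img_dir, img_file), "rb") as file:
        image_data = file.read()
    image = cv2_to_base64(image_data)

    # indented under the outer loop after the fix: one request per image
    for i in range(1):
        ret = predict_stub(image)
        print(ret)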
@@ -23,32 +23,40 @@ import numpy as np
 
 
 class TableAttentionHead(nn.Layer):
-    def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
+    def __init__(self,
+                 in_channels,
+                 hidden_size,
+                 loc_type,
+                 in_max_len=488,
+                 max_text_length=100,
+                 max_elem_length=800,
+                 max_cell_num=500,
+                 **kwargs):
         super(TableAttentionHead, self).__init__()
         self.input_size = in_channels[-1]
         self.hidden_size = hidden_size
         self.elem_num = 30
-        self.max_text_length = 100
-        self.max_elem_length = 500
-        self.max_cell_num = 500
+        self.max_text_length = max_text_length
+        self.max_elem_length = max_elem_length
+        self.max_cell_num = max_cell_num
 
         self.structure_attention_cell = AttentionGRUCell(
             self.input_size, hidden_size, self.elem_num, use_gru=False)
         self.structure_generator = nn.Linear(hidden_size, self.elem_num)
         self.loc_type = loc_type
         self.in_max_len = in_max_len
 
         if self.loc_type == 1:
             self.loc_generator = nn.Linear(hidden_size, 4)
         else:
             if self.in_max_len == 640:
-                self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1)
+                self.loc_fea_trans = nn.Linear(400, self.max_elem_length + 1)
             elif self.in_max_len == 800:
-                self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1)
+                self.loc_fea_trans = nn.Linear(625, self.max_elem_length + 1)
             else:
-                self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1)
+                self.loc_fea_trans = nn.Linear(256, self.max_elem_length + 1)
             self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
 
     def _char_to_onehot(self, input_char, onehot_dim):
         input_ont_hot = F.one_hot(input_char, onehot_dim)
         return input_ont_hot
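A short note on the recurring + 1 in this file: the structure decoder iterates max_elem_length + 1 times (the extra step presumably accommodating the end-of-sequence element), and loc_fea_trans projects onto a vector of the same length, so a single config value now fixes both the decode-loop length and the layer shape consistently. A trivial arithmetic sketch with the value from this commit:

max_elem_length = 800                # value set in the Head section of the config
decode_steps = max_elem_length + 1   # iterations of the structure decode loop
loc_fea_out = max_elem_length + 1    # output size of the loc_fea_trans linear layer
print(decode_steps, loc_fea_out)     # 801 801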
@@ -60,16 +68,16 @@ class TableAttentionHead(nn.Layer):
         if len(fea.shape) == 3:
             pass
         else:
             last_shape = int(np.prod(fea.shape[2:]))  # gry added
             fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
             fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
         batch_size = fea.shape[0]
 
         hidden = paddle.zeros((batch_size, self.hidden_size))
         output_hiddens = []
         if self.training and targets is not None:
             structure = targets[0]
-            for i in range(self.max_elem_length+1):
+            for i in range(self.max_elem_length + 1):
                 elem_onehots = self._char_to_onehot(
                     structure[:, i], onehot_dim=self.elem_num)
                 (outputs, hidden), alpha = self.structure_attention_cell(
@@ -96,7 +104,7 @@ class TableAttentionHead(nn.Layer):
             alpha = None
             max_elem_length = paddle.to_tensor(self.max_elem_length)
             i = 0
-            while i < max_elem_length+1:
+            while i < max_elem_length + 1:
                 elem_onehots = self._char_to_onehot(
                     temp_elem, onehot_dim=self.elem_num)
                 (outputs, hidden), alpha = self.structure_attention_cell(
@@ -105,7 +113,7 @@ class TableAttentionHead(nn.Layer):
                 structure_probs_step = self.structure_generator(outputs)
                 temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
                 i += 1
 
             output = paddle.concat(output_hiddens, axis=1)
             structure_probs = self.structure_generator(output)
             structure_probs = F.softmax(structure_probs)
@@ -119,9 +127,9 @@ class TableAttentionHead(nn.Layer):
             loc_concat = paddle.concat([output, loc_fea], axis=2)
             loc_preds = self.loc_generator(loc_concat)
             loc_preds = F.sigmoid(loc_preds)
-        return {'structure_probs':structure_probs, 'loc_preds':loc_preds}
+        return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
 
 
 class AttentionGRUCell(nn.Layer):
     def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
         super(AttentionGRUCell, self).__init__()