fix gap between table structure train model and inference model (#4565)

* add indent in pipeline_rpc_client.py

* fix gap in table structure train model and inference model
This commit is contained in:
zhoujun 2021-11-10 20:18:48 +08:00 committed by GitHub
parent a8960021ed
commit b6a21419d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 27 deletions

View File

@ -1,10 +1,10 @@
Global: Global:
use_gpu: true use_gpu: true
epoch_num: 50 epoch_num: 400
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 5 print_batch_step: 5
save_model_dir: ./output/table_mv3/ save_model_dir: ./output/table_mv3/
save_epoch_step: 5 save_epoch_step: 3
# evaluation is run every 400 iterations after the 0th iteration # evaluation is run every 400 iterations after the 0th iteration
eval_batch_step: [0, 400] eval_batch_step: [0, 400]
cal_metric_during_train: True cal_metric_during_train: True
@ -12,18 +12,17 @@ Global:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg infer_img: doc/table/table.jpg
# for data or label process # for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en character_type: en
max_text_length: 100 max_text_length: 100
max_elem_length: 500 max_elem_length: 800
max_cell_num: 500 max_cell_num: 500
infer_mode: False infer_mode: False
process_total_num: 0 process_total_num: 0
process_cut_num: 0 process_cut_num: 0
Optimizer: Optimizer:
name: Adam name: Adam
beta1: 0.9 beta1: 0.9
@ -41,13 +40,15 @@ Architecture:
Backbone: Backbone:
name: MobileNetV3 name: MobileNetV3
scale: 1.0 scale: 1.0
model_name: small model_name: large
disable_se: True
Head: Head:
name: TableAttentionHead name: TableAttentionHead
hidden_size: 256 hidden_size: 256
l2_decay: 0.00001 l2_decay: 0.00001
loc_type: 2 loc_type: 2
max_text_length: 100
max_elem_length: 800
max_cell_num: 500
Loss: Loss:
name: TableAttentionLoss name: TableAttentionLoss

View File

@ -23,14 +23,22 @@ import numpy as np
class TableAttentionHead(nn.Layer): class TableAttentionHead(nn.Layer):
def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs): def __init__(self,
in_channels,
hidden_size,
loc_type,
in_max_len=488,
max_text_length=100,
max_elem_length=800,
max_cell_num=500,
**kwargs):
super(TableAttentionHead, self).__init__() super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1] self.input_size = in_channels[-1]
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.elem_num = 30 self.elem_num = 30
self.max_text_length = 100 self.max_text_length = max_text_length
self.max_elem_length = 500 self.max_elem_length = max_elem_length
self.max_cell_num = 500 self.max_cell_num = max_cell_num
self.structure_attention_cell = AttentionGRUCell( self.structure_attention_cell = AttentionGRUCell(
self.input_size, hidden_size, self.elem_num, use_gru=False) self.input_size, hidden_size, self.elem_num, use_gru=False)