fix gap between table structure train model and inference model (#4565)

* add indent in pipeline_rpc_client.py

* fix gap in table structure train model and inference model
pull/4570/head^2
zhoujun 2021-11-10 20:18:48 +08:00 committed by GitHub
parent a8960021ed
commit b6a21419d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 27 deletions

View File

@ -1,10 +1,10 @@
Global:
use_gpu: true
epoch_num: 50
epoch_num: 400
log_smooth_window: 20
print_batch_step: 5
save_model_dir: ./output/table_mv3/
save_epoch_step: 5
save_epoch_step: 3
# evaluation is run every 400 iterations after the 0th iteration
eval_batch_step: [0, 400]
cal_metric_during_train: True
@ -12,18 +12,17 @@ Global:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg
infer_img: doc/table/table.jpg
# for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en
max_text_length: 100
max_elem_length: 500
max_elem_length: 800
max_cell_num: 500
infer_mode: False
process_total_num: 0
process_cut_num: 0
Optimizer:
name: Adam
beta1: 0.9
@ -41,13 +40,15 @@ Architecture:
Backbone:
name: MobileNetV3
scale: 1.0
model_name: small
disable_se: True
model_name: large
Head:
name: TableAttentionHead
hidden_size: 256
l2_decay: 0.00001
loc_type: 2
max_text_length: 100
max_elem_length: 800
max_cell_num: 500
Loss:
name: TableAttentionLoss

View File

@ -23,14 +23,22 @@ import numpy as np
class TableAttentionHead(nn.Layer):
def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
def __init__(self,
in_channels,
hidden_size,
loc_type,
in_max_len=488,
max_text_length=100,
max_elem_length=800,
max_cell_num=500,
**kwargs):
super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1]
self.hidden_size = hidden_size
self.elem_num = 30
self.max_text_length = 100
self.max_elem_length = 500
self.max_cell_num = 500
self.max_text_length = max_text_length
self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num
self.structure_attention_cell = AttentionGRUCell(
self.input_size, hidden_size, self.elem_num, use_gru=False)