fix static train in formula ()

pull/14851/head
liuhongen1234567 2025-03-08 00:24:37 +08:00 committed by GitHub
parent 28657d428b
commit 1ccf688ca2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 19 additions and 3 deletions
ppocr/modeling/architectures

View File

@ -18,9 +18,10 @@ Global:
rec_char_dict_path: &rec_char_dict_path ppocr/utils/dict/unimernet_tokenizer
max_new_tokens: &max_new_tokens 1024
input_size: &input_size [768, 768]
save_res_path: ./output/rec/predicts_unimernet_latexocr.txt
save_res_path: ./output/rec/predicts_pp_formulanet_l.txt
allow_resize_largeImg: False
start_ema: True
d2s_train_image_shape: [1,768,768]
Optimizer:
name: AdamW

View File

@ -18,9 +18,10 @@ Global:
rec_char_dict_path: &rec_char_dict_path ppocr/utils/dict/unimernet_tokenizer
max_new_tokens: &max_new_tokens 1024
input_size: &input_size [384, 384]
save_res_path: ./output/rec/predicts_unimernet_latexocr.txt
save_res_path: ./output/rec/predicts_pp_formulanet_s.txt
allow_resize_largeImg: False
start_ema: True
d2s_train_image_shape: [1,384,384]
Optimizer:
name: AdamW

View File

@ -18,8 +18,9 @@ Global:
rec_char_dict_path: &rec_char_dict_path ppocr/utils/dict/unimernet_tokenizer
input_size: &input_size [192, 672]
max_seq_len: &max_seq_len 1024
save_res_path: ./output/rec/predicts_unimernet_plus_config_latexocr.txt
save_res_path: ./output/rec/predicts_unimernet.txt
allow_resize_largeImg: False
d2s_train_image_shape: [1,192,672]
Optimizer:
name: AdamW

View File

@ -50,6 +50,9 @@ def apply_to_static(model, config, logger):
"SVTR",
"SVTR_HGNet",
"LaTeXOCR",
"UniMERNet",
"PP-FormulaNet-S",
"PP-FormulaNet-L",
]
if config["Architecture"]["algorithm"] in ["Distillation"]:
algo = list(config["Architecture"]["Models"].values())[0]["algorithm"]
@ -127,6 +130,16 @@ def apply_to_static(model, config, logger):
InputSpec(shape=[None, None], dtype="float32"),
]
]
elif algo in ["UniMERNet", "PP-FormulaNet-S", "PP-FormulaNet-L"]:
specs = [
[
InputSpec(
[None] + config["Global"]["d2s_train_image_shape"], dtype="float32"
),
InputSpec(shape=[None, None], dtype="float32"),
InputSpec(shape=[None, None], dtype="float32"),
]
]
model = to_static(model, input_spec=specs)
logger.info("Successfully to apply @to_static with specs: {}".format(specs))
return model