Add pp formulanet (#14429)
* add ppformulanet * rename loss * modify doc * add export code * modify yaml for global refpull/14441/head
parent
0697d248f8
commit
d523388ed1
|
@ -0,0 +1,117 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: True
|
||||||
|
epoch_num: 10
|
||||||
|
log_smooth_window: 10
|
||||||
|
print_batch_step: 10
|
||||||
|
save_model_dir: ./output/rec/pp_formulanet_l/
|
||||||
|
save_epoch_step: 2
|
||||||
|
# evaluation is run every 417 iterations (1 epoch)(batch_size = 24) # max_seq_len: 1024
|
||||||
|
eval_batch_step: [0, 417]
|
||||||
|
cal_metric_during_train: True
|
||||||
|
pretrained_model:
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img: doc/datasets/pme_demo/0000013.png
|
||||||
|
infer_mode: False
|
||||||
|
use_space_char: False
|
||||||
|
rec_char_dict_path: &rec_char_dict_path ppocr/utils/dict/unimernet_tokenizer
|
||||||
|
max_new_tokens: &max_new_tokens 1024
|
||||||
|
input_size: &input_size [768, 768]
|
||||||
|
save_res_path: ./output/rec/predicts_unimernet_latexocr.txt
|
||||||
|
allow_resize_largeImg: False
|
||||||
|
start_ema: True
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: AdamW
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
weight_decay: 0.05
|
||||||
|
lr:
|
||||||
|
name: LinearWarmupCosine
|
||||||
|
learning_rate: 0.0001
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: rec
|
||||||
|
algorithm: PP-FormulaNet-L
|
||||||
|
in_channels: 3
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: Vary_VIT_B_Formula
|
||||||
|
image_size: 768
|
||||||
|
encoder_embed_dim: 768
|
||||||
|
encoder_depth: 12
|
||||||
|
encoder_num_heads: 12
|
||||||
|
encoder_global_attn_indexes: [2, 5, 8, 11]
|
||||||
|
Head:
|
||||||
|
name: PPFormulaNet_Head
|
||||||
|
max_new_tokens: *max_new_tokens
|
||||||
|
decoder_start_token_id: 0
|
||||||
|
decoder_ffn_dim: 2048
|
||||||
|
decoder_hidden_size: 512
|
||||||
|
decoder_layers: 8
|
||||||
|
temperature: 0.2
|
||||||
|
do_sample: False
|
||||||
|
top_p: 0.95
|
||||||
|
encoder_hidden_size: 1024
|
||||||
|
is_export: False
|
||||||
|
length_aware: False
|
||||||
|
use_parallel: False
|
||||||
|
parallel_step: 0
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: PPFormulaNet_L_Loss
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: UniMERNetDecode
|
||||||
|
rec_char_dict_path: *rec_char_dict_path
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: LaTeXOCRMetric
|
||||||
|
main_indicator: exp_rate
|
||||||
|
cal_blue_score: False
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./ocr_rec_latexocr_dataset_example
|
||||||
|
label_file_list: ["./ocr_rec_latexocr_dataset_example/train.txt"]
|
||||||
|
transforms:
|
||||||
|
- UniMERNetImgDecode:
|
||||||
|
input_size: *input_size
|
||||||
|
- UniMERNetTrainTransform:
|
||||||
|
- LatexImageFormat:
|
||||||
|
- UniMERNetLabelEncode:
|
||||||
|
rec_char_dict_path: *rec_char_dict_path
|
||||||
|
max_seq_len: *max_new_tokens
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'attention_mask']
|
||||||
|
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 6
|
||||||
|
num_workers: 0
|
||||||
|
collate_fn: UniMERNetCollator
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./ocr_rec_latexocr_dataset_example
|
||||||
|
label_file_list: ["./ocr_rec_latexocr_dataset_example/val.txt"]
|
||||||
|
transforms:
|
||||||
|
- UniMERNetImgDecode:
|
||||||
|
input_size: *input_size
|
||||||
|
- UniMERNetTestTransform:
|
||||||
|
- LatexImageFormat:
|
||||||
|
- UniMERNetLabelEncode:
|
||||||
|
max_seq_len: *max_new_tokens
|
||||||
|
rec_char_dict_path: *rec_char_dict_path
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'attention_mask', 'filename']
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 10
|
||||||
|
num_workers: 0
|
||||||
|
collate_fn: UniMERNetCollator
|
|
@ -0,0 +1,115 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: True
|
||||||
|
epoch_num: 20
|
||||||
|
log_smooth_window: 10
|
||||||
|
print_batch_step: 10
|
||||||
|
save_model_dir: ./output/rec/pp_formulanet_s/
|
||||||
|
save_epoch_step: 2
|
||||||
|
# evaluation is run every 179 iterations (1 epoch)(batch_size = 56) # max_seq_len: 1024
|
||||||
|
eval_batch_step: [0, 179]
|
||||||
|
cal_metric_during_train: True
|
||||||
|
pretrained_model:
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img: doc/datasets/pme_demo/0000013.png
|
||||||
|
infer_mode: False
|
||||||
|
use_space_char: False
|
||||||
|
rec_char_dict_path: &rec_char_dict_path ppocr/utils/dict/unimernet_tokenizer
|
||||||
|
max_new_tokens: &max_new_tokens 1024
|
||||||
|
input_size: &input_size [384, 384]
|
||||||
|
save_res_path: ./output/rec/predicts_unimernet_latexocr.txt
|
||||||
|
allow_resize_largeImg: False
|
||||||
|
start_ema: True
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: AdamW
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
weight_decay: 0.05
|
||||||
|
lr:
|
||||||
|
name: LinearWarmupCosine
|
||||||
|
learning_rate: 0.0001
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: rec
|
||||||
|
algorithm: PP-FormulaNet-S
|
||||||
|
in_channels: 3
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: PPHGNetV2_B4
|
||||||
|
class_num: 1024
|
||||||
|
|
||||||
|
Head:
|
||||||
|
name: PPFormulaNet_Head
|
||||||
|
max_new_tokens: *max_new_tokens
|
||||||
|
decoder_start_token_id: 0
|
||||||
|
decoder_ffn_dim: 1536
|
||||||
|
decoder_hidden_size: 384
|
||||||
|
decoder_layers: 2
|
||||||
|
temperature: 0.2
|
||||||
|
do_sample: False
|
||||||
|
top_p: 0.95
|
||||||
|
encoder_hidden_size: 2048
|
||||||
|
is_export: False
|
||||||
|
length_aware: True
|
||||||
|
use_parallel: True
|
||||||
|
parallel_step: 3
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: PPFormulaNet_S_Loss
|
||||||
|
parallel_step: 3
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: UniMERNetDecode
|
||||||
|
rec_char_dict_path: *rec_char_dict_path
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: LaTeXOCRMetric
|
||||||
|
main_indicator: exp_rate
|
||||||
|
cal_blue_score: False
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./ocr_rec_latexocr_dataset_example
|
||||||
|
label_file_list: ["./ocr_rec_latexocr_dataset_example/train.txt"]
|
||||||
|
transforms:
|
||||||
|
- UniMERNetImgDecode:
|
||||||
|
input_size: *input_size
|
||||||
|
- UniMERNetTrainTransform:
|
||||||
|
- LatexImageFormat:
|
||||||
|
- UniMERNetLabelEncode:
|
||||||
|
rec_char_dict_path: *rec_char_dict_path
|
||||||
|
max_seq_len: *max_new_tokens
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'attention_mask']
|
||||||
|
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 14
|
||||||
|
num_workers: 0
|
||||||
|
collate_fn: UniMERNetCollator
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./ocr_rec_latexocr_dataset_example
|
||||||
|
label_file_list: ["./ocr_rec_latexocr_dataset_example/val.txt"]
|
||||||
|
transforms:
|
||||||
|
- UniMERNetImgDecode:
|
||||||
|
input_size: *input_size
|
||||||
|
- UniMERNetTestTransform:
|
||||||
|
- LatexImageFormat:
|
||||||
|
- UniMERNetLabelEncode:
|
||||||
|
max_seq_len: *max_new_tokens
|
||||||
|
rec_char_dict_path: *rec_char_dict_path
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'attention_mask', 'filename']
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 30
|
||||||
|
num_workers: 0
|
||||||
|
collate_fn: UniMERNetCollator
|
|
@ -15,7 +15,9 @@ Global:
|
||||||
infer_img: doc/datasets/pme_demo/0000013.png
|
infer_img: doc/datasets/pme_demo/0000013.png
|
||||||
infer_mode: False
|
infer_mode: False
|
||||||
use_space_char: False
|
use_space_char: False
|
||||||
rec_char_dict_path: ppocr/utils/dict/unimernet_tokenizer
|
rec_char_dict_path: &rec_char_dict_path ppocr/utils/dict/unimernet_tokenizer
|
||||||
|
input_size: &input_size [192, 672]
|
||||||
|
max_seq_len: &max_seq_len 1024
|
||||||
save_res_path: ./output/rec/predicts_unimernet_plus_config_latexocr.txt
|
save_res_path: ./output/rec/predicts_unimernet_plus_config_latexocr.txt
|
||||||
allow_resize_largeImg: False
|
allow_resize_largeImg: False
|
||||||
|
|
||||||
|
@ -59,7 +61,7 @@ Loss:
|
||||||
|
|
||||||
PostProcess:
|
PostProcess:
|
||||||
name: UniMERNetDecode
|
name: UniMERNetDecode
|
||||||
rec_char_dict_path: ppocr/utils/dict/unimernet_tokenizer
|
rec_char_dict_path: *rec_char_dict_path
|
||||||
|
|
||||||
Metric:
|
Metric:
|
||||||
name: LaTeXOCRMetric
|
name: LaTeXOCRMetric
|
||||||
|
@ -73,12 +75,12 @@ Train:
|
||||||
label_file_list: ["./train_data/UniMERNet/train_unimernet_1M.txt"]
|
label_file_list: ["./train_data/UniMERNet/train_unimernet_1M.txt"]
|
||||||
transforms:
|
transforms:
|
||||||
- UniMERNetImgDecode:
|
- UniMERNetImgDecode:
|
||||||
input_size: [192, 672]
|
input_size: *input_size
|
||||||
- UniMERNetTrainTransform:
|
- UniMERNetTrainTransform:
|
||||||
- UniMERNetImageFormat:
|
- UniMERNetImageFormat:
|
||||||
- UniMERNetLabelEncode:
|
- UniMERNetLabelEncode:
|
||||||
rec_char_dict_path: ppocr/utils/dict/unimernet_tokenizer
|
rec_char_dict_path: *rec_char_dict_path
|
||||||
max_seq_len: 1024
|
max_seq_len: *max_seq_len
|
||||||
- KeepKeys:
|
- KeepKeys:
|
||||||
keep_keys: ['image', 'label', 'attention_mask']
|
keep_keys: ['image', 'label', 'attention_mask']
|
||||||
loader:
|
loader:
|
||||||
|
@ -95,12 +97,12 @@ Eval:
|
||||||
label_file_list: ["./train_data/UniMERNet/test_unimernet_cpe.txt"]
|
label_file_list: ["./train_data/UniMERNet/test_unimernet_cpe.txt"]
|
||||||
transforms:
|
transforms:
|
||||||
- UniMERNetImgDecode:
|
- UniMERNetImgDecode:
|
||||||
input_size: [192, 672]
|
input_size: *input_size
|
||||||
- UniMERNetTestTransform:
|
- UniMERNetTestTransform:
|
||||||
- UniMERNetImageFormat:
|
- UniMERNetImageFormat:
|
||||||
- UniMERNetLabelEncode:
|
- UniMERNetLabelEncode:
|
||||||
max_seq_len: 1024
|
max_seq_len: *max_seq_len
|
||||||
rec_char_dict_path: ppocr/utils/dict/unimernet_tokenizer
|
rec_char_dict_path: *rec_char_dict_path
|
||||||
- KeepKeys:
|
- KeepKeys:
|
||||||
keep_keys: ['image', 'label', 'attention_mask']
|
keep_keys: ['image', 'label', 'attention_mask']
|
||||||
loader:
|
loader:
|
||||||
|
|
|
@ -0,0 +1,94 @@
|
||||||
|
# 印刷数学公式识别算法-PP-FormulaNet
|
||||||
|
|
||||||
|
## 1. 算法简介
|
||||||
|
|
||||||
|
`PP-FormulaNet` 是百度飞桨自研的公式识别模型,采用 PaddleX 内部自建的 5百万数据集进行训练,在对应测试集上的精度如下:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
| 模型 | 骨干网络 | 配置文件 | SPE-<br/>BLEU↑ | CPE-<br/>BLEU↑ | Easy-<br/>BLEU↑ | Middle-<br/>BLEU↑ | Hard-<br/>BLEU↑| Avg-<br/>BLEU↑ | 下载链接 |
|
||||||
|
|-----------|------------|------------------|:--------------:|:---------:|:----------:|:----------------:|:---------:|:-----------------:|:-----------------:|
|
||||||
|
| UniMERNet | Donut Swin | [rec_unimernet.yml](../../../configs/rec/rec_unimernet.yml) | 0.9187 | 0.9252 | 0.8658 | 0.8228 | 0.7740 | 0.8613 |[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_unimernet_train.tar)|
|
||||||
|
| PP-FormulaNet-S | PPHGNetV2_B4 | [rec_pp_formulanet_s.yml](../../../configs/rec/PP-FormuaNet/rec_pp_formulanet_s.yml) | 0.8694 | 0.8071 | 0.9294 | 0.9112 | 0.8391 | 0.8712 |[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_ppformulanet_s_train.tar)|
|
||||||
|
| PP-FormulaNet-L | Vary_VIT_B | [rec_pp_formulanet_l.yml](../../../configs/rec/PP-FormuaNet/rec_pp_formulanet_l.yml) | 0.9055 | 0.9206 | 0.9392 | 0.9273 | 0.9141 | 0.9213 |[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_ppformulanet_l_train.tar)|
|
||||||
|
|
||||||
|
其中,SPE、CPE为UniMERNet的简单公式数据集和复杂公式数据集;Easy、Middle、Hard为PaddleX内部自建的简单公式数据集(LaTeX 代码长度 0-64)、中等公式数据集(LaTeX 代码长度 64-256)和复杂公式数据集(LaTeX 代码长度 256+)。
|
||||||
|
|
||||||
|
## 2. 环境配置
|
||||||
|
请先参考[《运行环境准备》](../../ppocr/environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](../../ppocr/blog/clone.md)克隆项目代码。
|
||||||
|
|
||||||
|
此外,需要安装额外的依赖:
|
||||||
|
```shell
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install libmagickwand-dev
|
||||||
|
pip install -r docs/algorithm/formula_recognition/requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## 3. 模型训练、评估、预测
|
||||||
|
|
||||||
|
### 3.1 准备数据集
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# 下载 PaddleX 官方示例数据集
|
||||||
|
wget https://paddle-model-ecology.bj.bcebos.com/paddlex/data/ocr_rec_latexocr_dataset_example.tar
|
||||||
|
tar -xf ocr_rec_latexocr_dataset_example.tar
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### 3.2 下载预训练模型
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# 下载 PP-FormulaNet-S 预训练模型
|
||||||
|
wget https://paddleocr.bj.bcebos.com/contribution/rec_ppformulanet_s_train.tar
|
||||||
|
tar -xf rec_ppformulanet_s_train.tar
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### 3.3 模型训练
|
||||||
|
|
||||||
|
请参考[文本识别训练教程](../../ppocr/model_train/recognition.md)。PaddleOCR对代码进行了模块化,训练 `PP-FormulaNet-S` 识别模型时需要**更换配置文件**为 `PP-FormulaNet-S` 的[配置文件](https://github.com/PaddlePaddle/PaddleOCR/blob/main/configs/rec/PP-FormuaNet/rec_pp_formulanet_s.yml)。
|
||||||
|
|
||||||
|
#### 启动训练
|
||||||
|
|
||||||
|
具体地,在完成数据准备后,便可以启动训练,训练命令如下:
|
||||||
|
```shell
|
||||||
|
#单卡训练 (默认训练方式)
|
||||||
|
python3 tools/train.py -c configs/rec/PP-FormuaNet/rec_pp_formulanet_s.yml \
|
||||||
|
-o Global.pretrained_model=./rec_ppformulanet_s_train/best_accuracy.pdparams
|
||||||
|
#多卡训练,通过--gpus参数指定卡号
|
||||||
|
python3 -m paddle.distributed.launch --gpus '0,1,2,3' --ips=127.0.0.1 tools/train.py -c configs/rec/PP-FormuaNet/rec_pp_formulanet_s.yml \
|
||||||
|
-o Global.pretrained_model=./rec_ppformulanet_s_train/best_accuracy.pdparams
|
||||||
|
```
|
||||||
|
|
||||||
|
**注意:**
|
||||||
|
|
||||||
|
- 默认每训练 1个epoch(179 次iteration)进行1次评估,若您更改训练的batch_size,或更换数据集,请在训练时作出如下修改
|
||||||
|
```
|
||||||
|
python3 -m paddle.distributed.launch --gpus '0,1,2,3' --ips=127.0.0.1 tools/train.py -c configs/rec/PP-FormuaNet/rec_pp_formulanet_s.yml \
|
||||||
|
-o Global.eval_batch_step=[0,{length_of_dataset//batch_size//4}] \
|
||||||
|
Global.pretrained_model=./rec_ppformulanet_s_train/best_accuracy.pdparams
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.4 评估
|
||||||
|
|
||||||
|
可下载已训练完成的[模型文件](https://paddleocr.bj.bcebos.com/contribution/rec_ppformulanet_s_train.tar),使用如下命令进行评估:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# 注意将pretrained_model的路径设置为本地路径。若使用自行训练保存的模型,请注意修改路径和文件名为{path/to/weights}/{model_name}。
|
||||||
|
# demo 测试集评估
|
||||||
|
python3 tools/eval.py -c configs/rec/PP-FormuaNet/rec_pp_formulanet_s.yml -o \
|
||||||
|
Global.pretrained_model=./rec_ppformulanet_s_train/best_accuracy.pdparams
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.5 预测
|
||||||
|
|
||||||
|
使用如下命令进行单张图片预测:
|
||||||
|
```shell
|
||||||
|
# 注意将pretrained_model的路径设置为本地路径。
|
||||||
|
python3 tools/infer_rec.py -c configs/rec/PP-FormuaNet/rec_pp_formulanet_s.yml \
|
||||||
|
-o Global.infer_img='./docs/datasets/images/pme_demo/0000099.png' \
|
||||||
|
Global.pretrained_model=./rec_ppformulanet_s_train/best_accuracy.pdparams
|
||||||
|
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./docs/datasets/images/pme_demo/'。
|
||||||
|
```
|
||||||
|
|
||||||
|
## 4. FAQ
|
|
@ -0,0 +1,78 @@
|
||||||
|
# PP-FormulaNet
|
||||||
|
|
||||||
|
## 1. Introduction
|
||||||
|
|
||||||
|
|
||||||
|
PP-FormulaNet is a formula recognition model independently developed by Baidu PaddlePaddle. It is trained on a self-built dataset of 5 million samples within PaddleX, achieving the following accuracy on the corresponding test set:
|
||||||
|
|
||||||
|
| Model | Backbone | config |SPE-<br/>BLEU↑ | CPE-<br/>BLEU↑ | Easy-<br/>BLEU↑ | Middle-<br/>BLEU↑ | Hard-<br/>BLEU↑| Avg-<br/>BLEU↑ | Download link |
|
||||||
|
|-----------|--------|---------------------------------------------------|:--------------:|:-----------------:|:----------:|:----------------:|:---------:|:-----------------:|:--------------:|
|
||||||
|
| UniMERNet | Donut Swin | [rec_unimernet.yml](../../../configs/rec/rec_unimernet.yml) | 0.9187 | 0.9252 | 0.8658 | 0.8228 | 0.7740 | 0.8613 |[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_unimernet_train.tar)|
|
||||||
|
| PP-FormulaNet-S | PPHGNetV2_B4 | [rec_pp_formulanet_s.yml](../../../configs/rec/PP-FormuaNet/rec_pp_formulanet_s.yml) | 0.8694 | 0.8071 | 0.9294 | 0.9112 | 0.8391 | 0.8712 |[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_ppformulanet_s_train.tar)|
|
||||||
|
| PP-FormulaNet-L | Vary_VIT_B | [rec_pp_formulanet_l.yml](../../../configs/rec/PP-FormuaNet/rec_pp_formulanet_l.yml) | 0.9055 | 0.9206 | 0.9392 | 0.9273 | 0.9141 | 0.9213 |[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_ppformulanet_l_train.tar)|
|
||||||
|
|
||||||
|
Among them, SPE and CPE refer to the simple and complex formula datasets of UniMERNet, respectively. Easy, Middle, and Hard are simple (LaTeX code length 0-64), medium (LaTeX code length 64-256), and complex formula datasets (LaTeX code length 256+) built internally by PaddleX.
|
||||||
|
|
||||||
|
|
||||||
|
## 2. Environment
|
||||||
|
Please refer to ["Environment Preparation"](../../ppocr/environment.en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](../../ppocr/blog/clone.en.md) to clone the project code.
|
||||||
|
|
||||||
|
Furthermore, additional dependencies need to be installed:
|
||||||
|
```shell
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install libmagickwand-dev
|
||||||
|
pip install -r docs/algorithm/formula_recognition/requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## 3. Model Training / Evaluation / Prediction
|
||||||
|
|
||||||
|
Please refer to [Text Recognition Tutorial](../../ppocr/model_train/recognition.en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
|
||||||
|
|
||||||
|
|
||||||
|
Dataset Preparation:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# download PaddleX official example dataset
|
||||||
|
wget https://paddle-model-ecology.bj.bcebos.com/paddlex/data/ocr_rec_latexocr_dataset_example.tar
|
||||||
|
tar -xf ocr_rec_latexocr_dataset_example.tar
|
||||||
|
```
|
||||||
|
|
||||||
|
Download the Pre-trained Model:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# download the PP-FormulaNet-S pretrained model
|
||||||
|
wget https://paddleocr.bj.bcebos.com/contribution/rec_ppformulanet_s_train.tar
|
||||||
|
tar -xf rec_ppformulanet_s_train.tar
|
||||||
|
```
|
||||||
|
|
||||||
|
Training:
|
||||||
|
|
||||||
|
Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
#Single GPU training
|
||||||
|
python3 tools/train.py -c configs/rec/PP-FormuaNet/rec_pp_formulanet_s.yml \
|
||||||
|
-o Global.pretrained_model=./rec_ppformulanet_s_train/best_accuracy.pdparams
|
||||||
|
#Multi GPU training, specify the gpu number through the --gpus parameter
|
||||||
|
python3 -m paddle.distributed.launch --gpus '0,1,2,3' --ips=127.0.0.1 tools/train.py -c configs/rec/PP-FormuaNet/rec_pp_formulanet_s.yml \
|
||||||
|
-o Global.pretrained_model=./rec_ppformulanet_s_train/best_accuracy.pdparams
|
||||||
|
```
|
||||||
|
|
||||||
|
Evaluation:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# GPU evaluation
|
||||||
|
python3 tools/eval.py -c configs/rec/PP-FormuaNet/rec_pp_formulanet_s.yml -o \
|
||||||
|
Global.pretrained_model=./rec_ppformulanet_s_train/best_accuracy.pdparams
|
||||||
|
```
|
||||||
|
|
||||||
|
Prediction:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# The configuration file used for prediction must match the training
|
||||||
|
python3 tools/infer_rec.py -c configs/rec/PP-FormuaNet/rec_pp_formulanet_s.yml \
|
||||||
|
-o Global.infer_img='./docs/datasets/images/pme_demo/0000099.png' \
|
||||||
|
Global.pretrained_model=./rec_ppformulanet_s_train/best_accuracy.pdparams
|
||||||
|
```
|
||||||
|
|
||||||
|
## 4. FAQ
|
|
@ -20,7 +20,6 @@
|
||||||
|
|
||||||
此外,需要安装额外的依赖:
|
此外,需要安装额外的依赖:
|
||||||
```shell
|
```shell
|
||||||
apt-get install sudo
|
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install libmagickwand-dev
|
sudo apt-get install libmagickwand-dev
|
||||||
pip install -r docs/algorithm/formula_recognition/requirements.txt
|
pip install -r docs/algorithm/formula_recognition/requirements.txt
|
||||||
|
@ -107,7 +106,7 @@ python3 -m paddle.distributed.launch --gpus '0,1,2,3' --ips=127.0.0.1 tools/tr
|
||||||
```
|
```
|
||||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' --ips=127.0.0.1 tools/train.py -c configs/rec/rec_unimernet.yml \
|
python3 -m paddle.distributed.launch --gpus '0,1,2,3' --ips=127.0.0.1 tools/train.py -c configs/rec/rec_unimernet.yml \
|
||||||
-o Global.eval_batch_step=[0,{length_of_dataset//batch_size//4}] \
|
-o Global.eval_batch_step=[0,{length_of_dataset//batch_size//4}] \
|
||||||
-o Global.pretrained_model=./pretrain_models/texify.pdparams
|
Global.pretrained_model=./pretrain_models/texify.pdparams
|
||||||
```
|
```
|
||||||
|
|
||||||
### 3.4 评估
|
### 3.4 评估
|
||||||
|
|
|
@ -21,7 +21,6 @@ Please refer to ["Environment Preparation"](../../ppocr/environment.en.md) to co
|
||||||
|
|
||||||
Furthermore, additional dependencies need to be installed:
|
Furthermore, additional dependencies need to be installed:
|
||||||
```shell
|
```shell
|
||||||
apt-get install sudo
|
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install libmagickwand-dev
|
sudo apt-get install libmagickwand-dev
|
||||||
pip install -r docs/algorithm/formula_recognition/requirements.txt
|
pip install -r docs/algorithm/formula_recognition/requirements.txt
|
||||||
|
|
|
@ -372,6 +372,8 @@ nav:
|
||||||
- 公式识别算法:
|
- 公式识别算法:
|
||||||
- CAN: algorithm/formula_recognition/algorithm_rec_can.md
|
- CAN: algorithm/formula_recognition/algorithm_rec_can.md
|
||||||
- LaTeX-OCR: algorithm/formula_recognition/algorithm_rec_latex_ocr.md
|
- LaTeX-OCR: algorithm/formula_recognition/algorithm_rec_latex_ocr.md
|
||||||
|
- UniMERNet: algorithm/formula_recognition/algorithm_rec_unimernet.md
|
||||||
|
- PP-FormulaNet: algorithm/formula_recognition/algorithm_rec_ppformulanet.md
|
||||||
- 端到端OCR算法:
|
- 端到端OCR算法:
|
||||||
- PGNet: algorithm/end_to_end/algorithm_e2e_pgnet.md
|
- PGNet: algorithm/end_to_end/algorithm_e2e_pgnet.md
|
||||||
- 表格识别算法:
|
- 表格识别算法:
|
||||||
|
|
|
@ -47,6 +47,7 @@ from .rec_parseq_loss import ParseQLoss
|
||||||
from .rec_cppd_loss import CPPDLoss
|
from .rec_cppd_loss import CPPDLoss
|
||||||
from .rec_latexocr_loss import LaTeXOCRLoss
|
from .rec_latexocr_loss import LaTeXOCRLoss
|
||||||
from .rec_unimernet_loss import UniMERNetLoss
|
from .rec_unimernet_loss import UniMERNetLoss
|
||||||
|
from .rec_ppformulanet_loss import PPFormulaNet_S_Loss, PPFormulaNet_L_Loss
|
||||||
|
|
||||||
# cls loss
|
# cls loss
|
||||||
from .cls_loss import ClsLoss
|
from .cls_loss import ClsLoss
|
||||||
|
@ -111,6 +112,8 @@ def build_loss(config):
|
||||||
"CPPDLoss",
|
"CPPDLoss",
|
||||||
"LaTeXOCRLoss",
|
"LaTeXOCRLoss",
|
||||||
"UniMERNetLoss",
|
"UniMERNetLoss",
|
||||||
|
"PPFormulaNet_S_Loss",
|
||||||
|
"PPFormulaNet_L_Loss",
|
||||||
]
|
]
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
module_name = config.pop("name")
|
module_name = config.pop("name")
|
||||||
|
|
|
@ -0,0 +1,74 @@
|
||||||
|
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class PPFormulaNet_S_Loss(nn.Layer):
    """
    Cross-entropy training loss for PP-FormulaNet-S.

    Args:
        vocab_size (int): Stored on the instance; `forward` does not read it.
        parallel_step (int): Number of leading label positions that are
            sliced off before the labels are compared with the logits.
    """

    def __init__(self, vocab_size=50000, parallel_step=1):
        super().__init__()
        self.vocab_size = vocab_size
        self.parallel_step = int(parallel_step)
        self.pad_token_id = 1
        self.ignore_index = -100
        # targets equal to ignore_index are left out of the averaged loss
        self.cross = nn.CrossEntropyLoss(
            ignore_index=self.ignore_index, reduction="mean"
        )

    def forward(self, preds, batch):
        """
        Compute the token-level cross-entropy loss.

        Args:
            preds: Pair `(logits, masked_label)` produced by the model.
            batch: Unused; accepted so the call signature matches the
                other loss layers.

        Returns:
            dict: `"loss"` and `"word_loss"`, both holding the same value.
        """
        logits, masked_label = preds

        # collapse every leading axis so each row is one token prediction
        flat_logits = paddle.reshape(logits, [-1, logits.shape[-1]])
        # skip the first `parallel_step` label positions, then flatten
        flat_targets = paddle.reshape(
            masked_label[:, self.parallel_step :], [-1]
        )

        word_loss = self.cross(flat_logits, flat_targets)
        return {"loss": word_loss, "word_loss": word_loss}
|
||||||
|
|
||||||
|
|
||||||
|
class PPFormulaNet_L_Loss(nn.Layer):
    """
    Cross-entropy training loss for PP-FormulaNet-L.

    Args:
        vocab_size (int): Stored on the instance; `forward` does not read it.
    """

    def __init__(self, vocab_size=50000):
        super().__init__()
        self.vocab_size = vocab_size
        self.pad_token_id = 1
        self.ignore_index = -100
        # targets equal to ignore_index are left out of the averaged loss
        self.cross = nn.CrossEntropyLoss(
            ignore_index=self.ignore_index, reduction="mean"
        )

    def forward(self, preds, batch):
        """
        Compute the token-level cross-entropy loss.

        Args:
            preds: Pair `(logits, masked_label)` produced by the model.
            batch: Unused; accepted so the call signature matches the
                other loss layers.

        Returns:
            dict: `"loss"` and `"word_loss"`, both holding the same value.
        """
        logits, masked_label = preds

        # collapse every leading axis so each row is one token prediction
        flat_logits = paddle.reshape(logits, [-1, logits.shape[-1]])
        # drop the first label position, then flatten to match the logits
        flat_targets = paddle.reshape(masked_label[:, 1:], [-1])

        word_loss = self.cross(flat_logits, flat_targets)
        return {"loss": word_loss, "word_loss": word_loss}
|
|
@ -70,7 +70,8 @@ def build_backbone(config, model_type):
|
||||||
from .rec_vit_parseq import ViTParseQ
|
from .rec_vit_parseq import ViTParseQ
|
||||||
from .rec_repvit import RepSVTR
|
from .rec_repvit import RepSVTR
|
||||||
from .rec_svtrv2 import SVTRv2
|
from .rec_svtrv2 import SVTRv2
|
||||||
from .rec_vary_vit import Vary_VIT_B
|
from .rec_vary_vit import Vary_VIT_B, Vary_VIT_B_Formula
|
||||||
|
from .rec_pphgnetv2 import PPHGNetV2_B4
|
||||||
|
|
||||||
support_dict = [
|
support_dict = [
|
||||||
"MobileNetV1Enhance",
|
"MobileNetV1Enhance",
|
||||||
|
@ -99,6 +100,8 @@ def build_backbone(config, model_type):
|
||||||
"HybridTransformer",
|
"HybridTransformer",
|
||||||
"DonutSwinModel",
|
"DonutSwinModel",
|
||||||
"Vary_VIT_B",
|
"Vary_VIT_B",
|
||||||
|
"PPHGNetV2_B4",
|
||||||
|
"Vary_VIT_B_Formula",
|
||||||
]
|
]
|
||||||
elif model_type == "e2e":
|
elif model_type == "e2e":
|
||||||
from .e2e_resnet_vd_pg import ResNet
|
from .e2e_resnet_vd_pg import ResNet
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -27,6 +27,7 @@ from paddle.nn.initializer import (
|
||||||
TruncatedNormal,
|
TruncatedNormal,
|
||||||
XavierUniform,
|
XavierUniform,
|
||||||
)
|
)
|
||||||
|
from ppocr.modeling.backbones.rec_donut_swin import DonutSwinModelOutput
|
||||||
|
|
||||||
zeros_ = Constant(value=0.0)
|
zeros_ = Constant(value=0.0)
|
||||||
ones_ = Constant(value=1.0)
|
ones_ = Constant(value=1.0)
|
||||||
|
@ -90,6 +91,7 @@ class ImageEncoderViT(nn.Layer):
|
||||||
rel_pos_zero_init: bool = True,
|
rel_pos_zero_init: bool = True,
|
||||||
window_size: int = 0,
|
window_size: int = 0,
|
||||||
global_attn_indexes: Tuple[int, ...] = (),
|
global_attn_indexes: Tuple[int, ...] = (),
|
||||||
|
is_formula: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -168,6 +170,7 @@ class ImageEncoderViT(nn.Layer):
|
||||||
self.net_3 = nn.Conv2D(
|
self.net_3 = nn.Conv2D(
|
||||||
512, 1024, kernel_size=3, stride=2, padding=1, bias_attr=False
|
512, 1024, kernel_size=3, stride=2, padding=1, bias_attr=False
|
||||||
)
|
)
|
||||||
|
self.is_formula = is_formula
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.patch_embed(x)
|
x = self.patch_embed(x)
|
||||||
|
@ -177,6 +180,8 @@ class ImageEncoderViT(nn.Layer):
|
||||||
x = blk(x)
|
x = blk(x)
|
||||||
x = self.neck(x.transpose([0, 3, 1, 2]))
|
x = self.neck(x.transpose([0, 3, 1, 2]))
|
||||||
x = self.net_2(x)
|
x = self.net_2(x)
|
||||||
|
if self.is_formula:
|
||||||
|
x = self.net_3(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@ -492,6 +497,7 @@ def _build_vary(
|
||||||
encoder_num_heads,
|
encoder_num_heads,
|
||||||
encoder_global_attn_indexes,
|
encoder_global_attn_indexes,
|
||||||
image_size,
|
image_size,
|
||||||
|
is_formula=False,
|
||||||
):
|
):
|
||||||
prompt_embed_dim = 256
|
prompt_embed_dim = 256
|
||||||
vit_patch_size = 16
|
vit_patch_size = 16
|
||||||
|
@ -509,6 +515,7 @@ def _build_vary(
|
||||||
global_attn_indexes=encoder_global_attn_indexes,
|
global_attn_indexes=encoder_global_attn_indexes,
|
||||||
window_size=14,
|
window_size=14,
|
||||||
out_chans=prompt_embed_dim,
|
out_chans=prompt_embed_dim,
|
||||||
|
is_formula=is_formula,
|
||||||
)
|
)
|
||||||
return image_encoder
|
return image_encoder
|
||||||
|
|
||||||
|
@ -543,3 +550,67 @@ class Vary_VIT_B(nn.Layer):
|
||||||
cnn_feature = self.vision_tower_high(pixel_values)
|
cnn_feature = self.vision_tower_high(pixel_values)
|
||||||
cnn_feature = cnn_feature.flatten(2).transpose([0, 2, 1])
|
cnn_feature = cnn_feature.flatten(2).transpose([0, 2, 1])
|
||||||
return cnn_feature
|
return cnn_feature
|
||||||
|
|
||||||
|
|
||||||
|
class Vary_VIT_B_Formula(nn.Layer):
    """
    Vary ViT-B image encoder configured for formula recognition.

    Builds the encoder through `_build_vary(..., is_formula=True)`, projects
    the flattened features with a 1024 -> 1024 linear layer and wraps the
    result in a `DonutSwinModelOutput`.
    """

    def __init__(
        self,
        in_channels=3,
        image_size=768,
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
    ):
        """
        Args:
            in_channels (int): Number of input channels. Default is 3.
                Not read by this constructor.
            image_size (int): Size of the input image. Default is 768.
            encoder_embed_dim (int): Dimension of the encoder's embedding.
                Default is 768.
            encoder_depth (int): Number of layers in the encoder.
                Default is 12.
            encoder_num_heads (int): Number of attention heads in the
                encoder. Default is 12.
            encoder_global_attn_indexes (list): Indices of the encoder
                layers that use global attention. Default is [2, 5, 8, 11].
        """
        super(Vary_VIT_B_Formula, self).__init__()

        self.vision_tower_high = _build_vary(
            encoder_embed_dim=encoder_embed_dim,
            encoder_depth=encoder_depth,
            encoder_num_heads=encoder_num_heads,
            # Forward the caller's value: it was previously hard-coded to
            # [2, 5, 8, 11], so a configured value was silently ignored.
            encoder_global_attn_indexes=encoder_global_attn_indexes,
            image_size=image_size,
            is_formula=True,
        )
        self.mm_projector_vary = nn.Linear(1024, 1024)
        self.out_channels = 1024

    def forward(self, input_data):
        """
        Encode a batch of images.

        Args:
            input_data: In training mode, a 3-item sequence
                `(pixel_values, label, attention_mask)`. Otherwise either
                the image tensor itself or a list whose first element is
                the image tensor.

        Returns:
            In training mode, `(DonutSwinModelOutput, label,
            attention_mask)`; otherwise only the `DonutSwinModelOutput`.
            Its `last_hidden_state` holds the projected features and all
            other fields are None.
        """
        if self.training:
            pixel_values, label, attention_mask = input_data
        else:
            if isinstance(input_data, list):
                pixel_values = input_data[0]
            else:
                pixel_values = input_data

        # single-channel input is repeated along axis 1 to get 3 channels
        num_channels = pixel_values.shape[1]
        if num_channels == 1:
            pixel_values = paddle.repeat_interleave(pixel_values, repeats=3, axis=1)

        cnn_feature = self.vision_tower_high(pixel_values)
        # merge everything from axis 2 onward, then swap the last two axes
        # so the feature axis comes last for the linear projection
        cnn_feature = cnn_feature.flatten(2).transpose([0, 2, 1])

        cnn_feature = self.mm_projector_vary(cnn_feature)
        donut_swin_output = DonutSwinModelOutput(
            last_hidden_state=cnn_feature,
            pooler_output=None,
            hidden_states=None,
            attentions=None,
            reshaped_hidden_states=None,
        )
        if self.training:
            return donut_swin_output, label, attention_mask
        else:
            return donut_swin_output
|
||||||
|
|
|
@ -45,6 +45,7 @@ def build_head(config):
|
||||||
from .rec_parseq_head import ParseQHead
|
from .rec_parseq_head import ParseQHead
|
||||||
from .rec_cppd_head import CPPDHead
|
from .rec_cppd_head import CPPDHead
|
||||||
from .rec_unimernet_head import UniMERNetHead
|
from .rec_unimernet_head import UniMERNetHead
|
||||||
|
from .rec_ppformulanet_head import PPFormulaNet_Head
|
||||||
|
|
||||||
# cls head
|
# cls head
|
||||||
from .cls_head import ClsHead
|
from .cls_head import ClsHead
|
||||||
|
@ -89,6 +90,7 @@ def build_head(config):
|
||||||
"ParseQHead",
|
"ParseQHead",
|
||||||
"CPPDHead",
|
"CPPDHead",
|
||||||
"UniMERNetHead",
|
"UniMERNetHead",
|
||||||
|
"PPFormulaNet_Head",
|
||||||
]
|
]
|
||||||
|
|
||||||
if config["name"] == "DRRGHead":
|
if config["name"] == "DRRGHead":
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -217,6 +217,8 @@ class MBartConfig(object):
|
||||||
forced_eos_token_id=2,
|
forced_eos_token_id=2,
|
||||||
_attn_implementation="eager",
|
_attn_implementation="eager",
|
||||||
hidden_size=1024,
|
hidden_size=1024,
|
||||||
|
use_parallel=False,
|
||||||
|
parallel_step=2,
|
||||||
is_export=False,
|
is_export=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
@ -251,6 +253,8 @@ class MBartConfig(object):
|
||||||
self.is_encoder_decoder = is_encoder_decoder
|
self.is_encoder_decoder = is_encoder_decoder
|
||||||
self.forced_eos_token_id = forced_eos_token_id
|
self.forced_eos_token_id = forced_eos_token_id
|
||||||
self._attn_implementation = _attn_implementation
|
self._attn_implementation = _attn_implementation
|
||||||
|
self.use_parallel = use_parallel
|
||||||
|
self.parallel_step = parallel_step
|
||||||
self.is_export = is_export
|
self.is_export = is_export
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -310,6 +314,22 @@ class AttentionMaskConverter:
|
||||||
[bsz, 1, tgt_len, tgt_len + past_key_values_length]
|
[bsz, 1, tgt_len, tgt_len + past_key_values_length]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to_4d_export(
    self,
    attention_mask_2d,
    query_length,
    dtype,
    key_value_length,
    is_export=False,
):
    """Expand a 2D attention mask to 4D for the export path.

    Unlike `to_4d`, no causal mask is built here: the result is exactly
    what `self._expand_mask` returns for a target length of `query_length`.

    Args:
        attention_mask_2d: The 2D attention mask to expand.
        query_length: Target length forwarded to `_expand_mask` as `tgt_len`.
        dtype: Data type forwarded to `_expand_mask`.
        key_value_length: Accepted for signature parity with `to_4d`; unused.
        is_export: Accepted for signature parity; unused.

    Returns:
        The expanded mask produced by `self._expand_mask`.
    """
    return self._expand_mask(attention_mask_2d, dtype, tgt_len=query_length)
|
||||||
|
|
||||||
def to_4d(
|
def to_4d(
|
||||||
self,
|
self,
|
||||||
attention_mask_2d,
|
attention_mask_2d,
|
||||||
|
@ -321,7 +341,6 @@ class AttentionMaskConverter:
|
||||||
|
|
||||||
input_shape = (attention_mask_2d.shape[0], query_length)
|
input_shape = (attention_mask_2d.shape[0], query_length)
|
||||||
causal_4d_mask = None
|
causal_4d_mask = None
|
||||||
|
|
||||||
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
|
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
|
||||||
if key_value_length is None:
|
if key_value_length is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -375,6 +394,33 @@ def _prepare_4d_attention_mask(mask, dtype, tgt_len=None):
|
||||||
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_4d_causal_attention_mask_export(
    attention_mask,
    input_shape,
    inputs_embeds,
    past_key_values_length,
    sliding_window=None,
    is_export=False,
):
    """Build the 4D decoder attention mask used when exporting the model.

    Export-path counterpart of `_prepare_4d_causal_attention_mask`: it
    creates a causal `AttentionMaskConverter` and delegates to its
    `to_4d_export` method.

    Args:
        attention_mask: 2D attention mask to expand.
        input_shape: Shape of the decoder input; only its last entry (the
            query length) is read.
        inputs_embeds: Embedded inputs; only `.dtype` is read, to pick the
            mask's data type.
        past_key_values_length: Length of the cached key/value states, added
            to the query length to form `key_value_length`.
        sliding_window: Forwarded to `AttentionMaskConverter`.
        is_export: Forwarded to `to_4d_export`.

    Returns:
        The mask returned by `AttentionMaskConverter.to_4d_export`.
    """
    attn_mask_converter = AttentionMaskConverter(
        is_causal=True, sliding_window=sliding_window
    )
    query_length = input_shape[-1]
    key_value_length = query_length + past_key_values_length

    return attn_mask_converter.to_4d_export(
        attention_mask,
        query_length,
        key_value_length=key_value_length,
        dtype=inputs_embeds.dtype,
        is_export=is_export,
    )
|
||||||
|
|
||||||
|
|
||||||
def _prepare_4d_causal_attention_mask(
|
def _prepare_4d_causal_attention_mask(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
input_shape,
|
input_shape,
|
||||||
|
@ -1681,7 +1727,7 @@ class CustomMBartDecoder(MBartDecoder):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if self.is_export:
|
if self.is_export:
|
||||||
attention_mask = _prepare_4d_causal_attention_mask(
|
attention_mask = _prepare_4d_causal_attention_mask_export(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
input_shape,
|
input_shape,
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
|
@ -1721,6 +1767,7 @@ class CustomMBartDecoder(MBartDecoder):
|
||||||
hidden_states = nn.functional.dropout(
|
hidden_states = nn.functional.dropout(
|
||||||
hidden_states, p=self.dropout, training=self.training
|
hidden_states, p=self.dropout, training=self.training
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
if use_cache:
|
if use_cache:
|
||||||
print(
|
print(
|
||||||
|
@ -1828,7 +1875,6 @@ class CustomMBartDecoder(MBartDecoder):
|
||||||
]
|
]
|
||||||
if v is not None
|
if v is not None
|
||||||
)
|
)
|
||||||
|
|
||||||
return BaseModelOutputWithPastAndCrossAttentions(
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=next_cache,
|
past_key_values=next_cache,
|
||||||
|
@ -2237,6 +2283,21 @@ class UniMERNetHead(nn.Layer):
|
||||||
}
|
}
|
||||||
return input_dict
|
return input_dict
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation_export(
    self,
    past_key_values=None,
    attention_mask=None,
    use_cache=None,
    encoder_outputs=None,
    **kwargs,
):
    """Assemble the per-step decoder inputs for the export generation loop.

    Only `use_cache` is propagated; every other argument (including any
    extra keyword arguments) is accepted and ignored. The decoder attention
    mask entry is always `None`.

    Returns:
        dict: `{"decoder_attention_mask": None, "use_cache": use_cache}`.
    """
    return dict(decoder_attention_mask=None, use_cache=use_cache)
|
||||||
|
|
||||||
def _extract_past_from_model_output(
|
def _extract_past_from_model_output(
|
||||||
self, outputs: ModelOutput, standardize_cache_format: bool = False
|
self, outputs: ModelOutput, standardize_cache_format: bool = False
|
||||||
):
|
):
|
||||||
|
@ -2434,9 +2495,10 @@ class UniMERNetHead(nn.Layer):
|
||||||
@paddle.no_grad()
|
@paddle.no_grad()
|
||||||
def generate_export(
|
def generate_export(
|
||||||
self,
|
self,
|
||||||
|
encoder_outputs,
|
||||||
model_kwargs,
|
model_kwargs,
|
||||||
):
|
):
|
||||||
batch_size = model_kwargs["encoder_outputs"]["last_hidden_state"].shape[0]
|
batch_size = encoder_outputs["last_hidden_state"].shape[0]
|
||||||
generation_config = {
|
generation_config = {
|
||||||
"decoder_start_token_id": 0,
|
"decoder_start_token_id": 0,
|
||||||
"bos_token_id": 0,
|
"bos_token_id": 0,
|
||||||
|
@ -2447,26 +2509,33 @@ class UniMERNetHead(nn.Layer):
|
||||||
decoder_start_token_id=generation_config["decoder_start_token_id"],
|
decoder_start_token_id=generation_config["decoder_start_token_id"],
|
||||||
bos_token_id=generation_config["bos_token_id"],
|
bos_token_id=generation_config["bos_token_id"],
|
||||||
)
|
)
|
||||||
|
input_ids = input_ids.reshape([-1, 1])
|
||||||
|
decoder_input_ids = input_ids
|
||||||
model_kwargs["key use_cache"] = True
|
model_kwargs["key use_cache"] = True
|
||||||
batch_size, cur_len = input_ids.shape
|
batch_size, cur_len = input_ids.shape
|
||||||
|
|
||||||
if "inputs_embeds" in model_kwargs:
|
if "inputs_embeds" in model_kwargs:
|
||||||
cur_len = model_kwargs["inputs_embeds"].shape[1]
|
cur_len = model_kwargs["inputs_embeds"].shape[1]
|
||||||
model_kwargs["cache_position"] = paddle.arange(cur_len)
|
cache_position = paddle.arange(cur_len)
|
||||||
pad_token_id = self.pad_token_id
|
pad_token_id = self.pad_token_id
|
||||||
eos_token_id = [self.eos_token_id]
|
eos_token_id = [self.eos_token_id]
|
||||||
eos_token = self.eos_token_id
|
eos_token = self.eos_token_id
|
||||||
unfinished_sequences = paddle.ones([batch_size], dtype=paddle.int64)
|
unfinished_sequences = paddle.ones([batch_size], dtype=paddle.int64)
|
||||||
i_idx = paddle.full([], 0)
|
i_idx = paddle.full([], 0)
|
||||||
|
past_key_values = []
|
||||||
|
for i in range(8):
|
||||||
|
init_arr = paddle.zeros([batch_size, 16, 0, 64])
|
||||||
|
paddle.jit.api.set_dynamic_shape(init_arr, [-1, -1, -1, -1])
|
||||||
|
cache = (init_arr, init_arr, init_arr, init_arr)
|
||||||
|
past_key_values.append(cache)
|
||||||
|
idx = 0
|
||||||
while i_idx < paddle.to_tensor(self.max_seq_len):
|
while i_idx < paddle.to_tensor(self.max_seq_len):
|
||||||
|
|
||||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
model_inputs = self.prepare_inputs_for_generation_export(
|
||||||
decoder_input_ids = model_inputs["decoder_input_ids"]
|
past_key_values=past_key_values, **model_kwargs
|
||||||
|
)
|
||||||
decoder_attention_mask = model_inputs["decoder_attention_mask"]
|
decoder_attention_mask = model_inputs["decoder_attention_mask"]
|
||||||
encoder_outputs = model_inputs["encoder_outputs"]
|
decoder_attention_mask = paddle.ones(input_ids.shape)
|
||||||
past_key_values = model_inputs["past_key_values"]
|
|
||||||
|
|
||||||
paddle.jit.api.set_dynamic_shape(decoder_input_ids, [-1, -1])
|
paddle.jit.api.set_dynamic_shape(decoder_input_ids, [-1, -1])
|
||||||
paddle.jit.api.set_dynamic_shape(decoder_attention_mask, [-1, -1])
|
paddle.jit.api.set_dynamic_shape(decoder_attention_mask, [-1, -1])
|
||||||
|
|
||||||
|
@ -2489,6 +2558,10 @@ class UniMERNetHead(nn.Layer):
|
||||||
1 - unfinished_sequences
|
1 - unfinished_sequences
|
||||||
)
|
)
|
||||||
input_ids = paddle.concat([input_ids, next_tokens.unsqueeze(1)], axis=-1)
|
input_ids = paddle.concat([input_ids, next_tokens.unsqueeze(1)], axis=-1)
|
||||||
|
past_length = past_key_values[0][0].shape[2]
|
||||||
|
decoder_input_ids = next_tokens.unsqueeze(1)
|
||||||
|
past_key_values = outputs.past_key_values
|
||||||
|
cache_position = cache_position[-1:] + 1
|
||||||
unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
|
unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
|
||||||
input_ids
|
input_ids
|
||||||
).cast(paddle.int64)
|
).cast(paddle.int64)
|
||||||
|
@ -2500,6 +2573,7 @@ class UniMERNetHead(nn.Layer):
|
||||||
).all()
|
).all()
|
||||||
):
|
):
|
||||||
break
|
break
|
||||||
|
|
||||||
i_idx += 1
|
i_idx += 1
|
||||||
return input_ids
|
return input_ids
|
||||||
|
|
||||||
|
@ -2578,15 +2652,20 @@ class UniMERNetHead(nn.Layer):
|
||||||
"""
|
"""
|
||||||
if not self.training:
|
if not self.training:
|
||||||
encoder_outputs = inputs
|
encoder_outputs = inputs
|
||||||
|
if self.is_export:
|
||||||
|
model_kwargs = {
|
||||||
|
"output_attentions": False,
|
||||||
|
"output_hidden_states": False,
|
||||||
|
"use_cache": True,
|
||||||
|
}
|
||||||
|
word_pred = self.generate_export(encoder_outputs, model_kwargs)
|
||||||
|
else:
|
||||||
model_kwargs = {
|
model_kwargs = {
|
||||||
"output_attentions": False,
|
"output_attentions": False,
|
||||||
"output_hidden_states": False,
|
"output_hidden_states": False,
|
||||||
"use_cache": True,
|
"use_cache": True,
|
||||||
"encoder_outputs": encoder_outputs,
|
"encoder_outputs": encoder_outputs,
|
||||||
}
|
}
|
||||||
if self.is_export:
|
|
||||||
word_pred = self.generate_export(model_kwargs)
|
|
||||||
else:
|
|
||||||
word_pred = self.generate(model_kwargs)
|
word_pred = self.generate(model_kwargs)
|
||||||
|
|
||||||
return word_pred
|
return word_pred
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -63,6 +63,12 @@ def dump_infer_config(config, path, logger):
|
||||||
common_dynamic_shapes = {
|
common_dynamic_shapes = {
|
||||||
"x": [[1, 3, 224, 224], [1, 3, 448, 448], [8, 3, 1280, 1280]]
|
"x": [[1, 3, 224, 224], [1, 3, 448, 448], [8, 3, 1280, 1280]]
|
||||||
}
|
}
|
||||||
|
elif arch_config["algorithm"] == "UniMERNet":
|
||||||
|
common_dynamic_shapes = {"x": [[1, 3, 192, 672]]}
|
||||||
|
elif arch_config["algorithm"] == "PP-FormulaNet-L":
|
||||||
|
common_dynamic_shapes = {"x": [[1, 3, 768, 768]]}
|
||||||
|
elif arch_config["algorithm"] == "PP-FormulaNet-S":
|
||||||
|
common_dynamic_shapes = {"x": [[1, 3, 384, 384]]}
|
||||||
else:
|
else:
|
||||||
common_dynamic_shapes = None
|
common_dynamic_shapes = None
|
||||||
|
|
||||||
|
@ -91,6 +97,25 @@ def dump_infer_config(config, path, logger):
|
||||||
with open(tokenizer_file, encoding="utf-8") as tokenizer_config_handle:
|
with open(tokenizer_file, encoding="utf-8") as tokenizer_config_handle:
|
||||||
character_dict = json.load(tokenizer_config_handle)
|
character_dict = json.load(tokenizer_config_handle)
|
||||||
postprocess["character_dict"] = character_dict
|
postprocess["character_dict"] = character_dict
|
||||||
|
elif config["Architecture"].get("algorithm") in [
|
||||||
|
"UniMERNet",
|
||||||
|
"PP-FormulaNet-L",
|
||||||
|
"PP-FormulaNet-S",
|
||||||
|
]:
|
||||||
|
tokenizer_file = config["Global"].get("rec_char_dict_path")
|
||||||
|
fast_tokenizer_file = os.path.join(tokenizer_file, "tokenizer.json")
|
||||||
|
tokenizer_config_file = os.path.join(tokenizer_file, "tokenizer_config.json")
|
||||||
|
postprocess["character_dict"] = {}
|
||||||
|
if fast_tokenizer_file is not None:
|
||||||
|
with open(fast_tokenizer_file, encoding="utf-8") as tokenizer_config_handle:
|
||||||
|
character_dict = json.load(tokenizer_config_handle)
|
||||||
|
postprocess["character_dict"]["fast_tokenizer_file"] = character_dict
|
||||||
|
if tokenizer_config_file is not None:
|
||||||
|
with open(
|
||||||
|
tokenizer_config_file, encoding="utf-8"
|
||||||
|
) as tokenizer_config_handle:
|
||||||
|
character_dict = json.load(tokenizer_config_handle)
|
||||||
|
postprocess["character_dict"]["tokenizer_config_file"] = character_dict
|
||||||
else:
|
else:
|
||||||
if config["Global"].get("character_dict_path") is not None:
|
if config["Global"].get("character_dict_path") is not None:
|
||||||
with open(config["Global"]["character_dict_path"], encoding="utf-8") as f:
|
with open(config["Global"]["character_dict_path"], encoding="utf-8") as f:
|
||||||
|
@ -208,6 +233,31 @@ def dynamic_to_static(model, arch_config, logger, input_shape=None):
|
||||||
paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
|
paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
|
||||||
]
|
]
|
||||||
model = to_static(model, input_spec=other_shape)
|
model = to_static(model, input_spec=other_shape)
|
||||||
|
elif arch_config["algorithm"] == "UniMERNet":
|
||||||
|
model = paddle.jit.to_static(
|
||||||
|
model,
|
||||||
|
input_spec=[
|
||||||
|
paddle.static.InputSpec(shape=[-1, 1, 192, 672], dtype="float32")
|
||||||
|
],
|
||||||
|
full_graph=True,
|
||||||
|
)
|
||||||
|
elif arch_config["algorithm"] == "PP-FormulaNet-L":
|
||||||
|
model = paddle.jit.to_static(
|
||||||
|
model,
|
||||||
|
input_spec=[
|
||||||
|
paddle.static.InputSpec(shape=[-1, 1, 768, 768], dtype="float32")
|
||||||
|
],
|
||||||
|
full_graph=True,
|
||||||
|
)
|
||||||
|
elif arch_config["algorithm"] == "PP-FormulaNet-S":
|
||||||
|
model = paddle.jit.to_static(
|
||||||
|
model,
|
||||||
|
input_spec=[
|
||||||
|
paddle.static.InputSpec(shape=[-1, 1, 384, 384], dtype="float32")
|
||||||
|
],
|
||||||
|
full_graph=True,
|
||||||
|
)
|
||||||
|
|
||||||
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
|
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
|
||||||
input_spec = [
|
input_spec = [
|
||||||
paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # input_ids
|
paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # input_ids
|
||||||
|
@ -368,6 +418,14 @@ def export(config, base_model=None, save_path=None):
|
||||||
config["Architecture"]["Backbone"]["is_predict"] = True
|
config["Architecture"]["Backbone"]["is_predict"] = True
|
||||||
config["Architecture"]["Backbone"]["is_export"] = True
|
config["Architecture"]["Backbone"]["is_export"] = True
|
||||||
config["Architecture"]["Head"]["is_export"] = True
|
config["Architecture"]["Head"]["is_export"] = True
|
||||||
|
if config["Architecture"].get("algorithm") in ["UniMERNet"]:
|
||||||
|
config["Architecture"]["Backbone"]["is_export"] = True
|
||||||
|
config["Architecture"]["Head"]["is_export"] = True
|
||||||
|
if config["Architecture"].get("algorithm") in [
|
||||||
|
"PP-FormulaNet-S",
|
||||||
|
"PP-FormulaNet-L",
|
||||||
|
]:
|
||||||
|
config["Architecture"]["Head"]["is_export"] = True
|
||||||
if base_model is not None:
|
if base_model is not None:
|
||||||
model = base_model
|
model = base_model
|
||||||
if isinstance(model, paddle.DataParallel):
|
if isinstance(model, paddle.DataParallel):
|
||||||
|
|
|
@ -9,7 +9,10 @@ import pytest
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
sys.path.append(os.path.abspath(os.path.join(current_dir, "..")))
|
sys.path.append(os.path.abspath(os.path.join(current_dir, "..")))
|
||||||
from ppocr.modeling.backbones.rec_donut_swin import DonutSwinModel, DonutSwinModelOutput
|
from ppocr.modeling.backbones.rec_donut_swin import DonutSwinModel, DonutSwinModelOutput
|
||||||
|
from ppocr.modeling.backbones.rec_pphgnetv2 import PPHGNetV2_B4
|
||||||
|
from ppocr.modeling.backbones.rec_vary_vit import Vary_VIT_B_Formula
|
||||||
from ppocr.modeling.heads.rec_unimernet_head import UniMERNetHead
|
from ppocr.modeling.heads.rec_unimernet_head import UniMERNetHead
|
||||||
|
from ppocr.modeling.heads.rec_ppformulanet_head import PPFormulaNet_Head
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -17,6 +20,16 @@ def sample_image():
|
||||||
return paddle.randn([1, 1, 192, 672])
|
return paddle.randn([1, 1, 192, 672])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sample_image_ppformulanet_s():
    """Random tensor of shape [1, 1, 384, 384] used as PP-FormulaNet-S input."""
    image_shape = [1, 1, 384, 384]
    return paddle.randn(image_shape)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sample_image_ppformulanet_l():
    """Random tensor of shape [1, 1, 768, 768] used as PP-FormulaNet-L input."""
    image_shape = [1, 1, 768, 768]
    return paddle.randn(image_shape)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def encoder_feat():
|
def encoder_feat():
|
||||||
encoded_feat = paddle.randn([1, 126, 1024])
|
encoded_feat = paddle.randn([1, 126, 1024])
|
||||||
|
@ -25,6 +38,22 @@ def encoder_feat():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def encoder_feat_ppformulanet_s():
    """Random [1, 144, 2048] feature wrapped as a DonutSwinModelOutput.

    Stands in for the PP-FormulaNet-S backbone output when testing the head.
    """
    return DonutSwinModelOutput(last_hidden_state=paddle.randn([1, 144, 2048]))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def encoder_feat_ppformulanet_l():
    """Random [1, 144, 1024] feature wrapped as a DonutSwinModelOutput.

    Stands in for the PP-FormulaNet-L backbone output when testing the head.
    """
    return DonutSwinModelOutput(last_hidden_state=paddle.randn([1, 144, 1024]))
|
||||||
|
|
||||||
|
|
||||||
def test_unimernet_backbone(sample_image):
|
def test_unimernet_backbone(sample_image):
|
||||||
"""
|
"""
|
||||||
Test UniMERNet backbone.
|
Test UniMERNet backbone.
|
||||||
|
@ -68,3 +97,99 @@ def test_unimernet_head(encoder_feat):
|
||||||
with paddle.no_grad():
|
with paddle.no_grad():
|
||||||
result = head(encoder_feat)
|
result = head(encoder_feat)
|
||||||
assert result.shape == [1, 6]
|
assert result.shape == [1, 6]
|
||||||
|
|
||||||
|
|
||||||
|
def test_ppformulanet_s_backbone(sample_image_ppformulanet_s):
    """
    Check the output shape of the PP-FormulaNet-S backbone (PPHGNetV2_B4).

    Args:
        sample_image_ppformulanet_s: sample image to be processed.
    """
    model = PPHGNetV2_B4(class_num=1024)
    model.eval()
    with paddle.no_grad():
        outputs = model(sample_image_ppformulanet_s)
    # The first item of the backbone output is the encoder feature.
    features = outputs[0]
    assert features.shape == [1, 144, 2048]
|
||||||
|
|
||||||
|
|
||||||
|
def test_ppformulanet_s_head(encoder_feat_ppformulanet_s):
    """
    Check the prediction shape of the PP-FormulaNet-S head in eval mode.

    Args:
        encoder_feat_ppformulanet_s: encoder feature from PP-FormulaNet-S backbone.
    """
    head_kwargs = {
        "max_new_tokens": 6,
        "decoder_start_token_id": 0,
        "decoder_ffn_dim": 1536,
        "decoder_hidden_size": 384,
        "decoder_layers": 2,
        "temperature": 0.2,
        "do_sample": False,
        "top_p": 0.95,
        "encoder_hidden_size": 2048,
        "is_export": False,
        "length_aware": True,
        "use_parallel": True,
        "parallel_step": 3,
    }
    model = PPFormulaNet_Head(**head_kwargs)

    model.eval()
    with paddle.no_grad():
        prediction = model(encoder_feat_ppformulanet_s)
    assert prediction.shape == [1, 9]
|
||||||
|
|
||||||
|
|
||||||
|
def test_ppformulanet_l_backbone(sample_image_ppformulanet_l):
    """
    Check the output shape of the PP-FormulaNet-L backbone (Vary_VIT_B_Formula).

    Args:
        sample_image_ppformulanet_l: sample image to be processed.
    """
    backbone_kwargs = {
        "image_size": 768,
        "encoder_embed_dim": 768,
        "encoder_depth": 12,
        "encoder_num_heads": 12,
        "encoder_global_attn_indexes": [2, 5, 8, 11],
    }
    model = Vary_VIT_B_Formula(**backbone_kwargs)
    model.eval()
    with paddle.no_grad():
        outputs = model(sample_image_ppformulanet_l)
    # The first item of the backbone output is the encoder feature.
    features = outputs[0]
    assert features.shape == [1, 144, 1024]
|
||||||
|
|
||||||
|
|
||||||
|
def test_ppformulanet_l_head(encoder_feat_ppformulanet_l):
    """
    Check the prediction shape of the PP-FormulaNet-L head in eval mode.

    Args:
        encoder_feat_ppformulanet_l: encoder feature shaped like the
            PP-FormulaNet-L backbone output.
    """
    head_kwargs = {
        "max_new_tokens": 6,
        "decoder_start_token_id": 0,
        "decoder_ffn_dim": 2048,
        "decoder_hidden_size": 512,
        "decoder_layers": 8,
        "temperature": 0.2,
        "do_sample": False,
        "top_p": 0.95,
        "encoder_hidden_size": 1024,
        "is_export": False,
        "length_aware": False,
        "use_parallel": False,
        "parallel_step": 0,
    }
    model = PPFormulaNet_Head(**head_kwargs)

    model.eval()
    with paddle.no_grad():
        prediction = model(encoder_feat_ppformulanet_l)
    assert prediction.shape == [1, 7]
|
||||||
|
|
|
@ -111,6 +111,12 @@ def main():
|
||||||
elif config["Architecture"]["algorithm"] == "UniMERNet":
|
elif config["Architecture"]["algorithm"] == "UniMERNet":
|
||||||
model_type = "unimernet"
|
model_type = "unimernet"
|
||||||
config["Metric"]["cal_blue_score"] = True
|
config["Metric"]["cal_blue_score"] = True
|
||||||
|
elif config["Architecture"]["algorithm"] in [
|
||||||
|
"PP-FormulaNet-S",
|
||||||
|
"PP-FormulaNet-L",
|
||||||
|
]:
|
||||||
|
model_type = "pp_formulanet"
|
||||||
|
config["Metric"]["cal_blue_score"] = True
|
||||||
else:
|
else:
|
||||||
model_type = config["Architecture"]["model_type"]
|
model_type = config["Architecture"]["model_type"]
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -134,7 +134,11 @@ def main():
|
||||||
logger.info("infer_img: {}".format(file))
|
logger.info("infer_img: {}".format(file))
|
||||||
with open(file, "rb") as f:
|
with open(file, "rb") as f:
|
||||||
img = f.read()
|
img = f.read()
|
||||||
if config["Architecture"]["algorithm"] in ["UniMERNet"]:
|
if config["Architecture"]["algorithm"] in [
|
||||||
|
"UniMERNet",
|
||||||
|
"PP-FormulaNet-S",
|
||||||
|
"PP-FormulaNet-L",
|
||||||
|
]:
|
||||||
data = {"image": img, "filename": file}
|
data = {"image": img, "filename": file}
|
||||||
else:
|
else:
|
||||||
data = {"image": img}
|
data = {"image": img}
|
||||||
|
@ -192,7 +196,12 @@ def main():
|
||||||
elif isinstance(post_result, list) and isinstance(post_result[0], int):
|
elif isinstance(post_result, list) and isinstance(post_result[0], int):
|
||||||
# for RFLearning CNT branch
|
# for RFLearning CNT branch
|
||||||
info = str(post_result[0])
|
info = str(post_result[0])
|
||||||
elif config["Architecture"]["algorithm"] in ["LaTeXOCR", "UniMERNet"]:
|
elif config["Architecture"]["algorithm"] in [
|
||||||
|
"LaTeXOCR",
|
||||||
|
"UniMERNet",
|
||||||
|
"PP-FormulaNet-S",
|
||||||
|
"PP-FormulaNet-L",
|
||||||
|
]:
|
||||||
info = str(post_result[0])
|
info = str(post_result[0])
|
||||||
else:
|
else:
|
||||||
if len(post_result[0]) >= 2:
|
if len(post_result[0]) >= 2:
|
||||||
|
|
|
@ -333,7 +333,12 @@ def train(
|
||||||
preds = model(batch)
|
preds = model(batch)
|
||||||
elif algorithm in ["CAN"]:
|
elif algorithm in ["CAN"]:
|
||||||
preds = model(batch[:3])
|
preds = model(batch[:3])
|
||||||
elif algorithm in ["LaTeXOCR", "UniMERNet"]:
|
elif algorithm in [
|
||||||
|
"LaTeXOCR",
|
||||||
|
"UniMERNet",
|
||||||
|
"PP-FormulaNet-S",
|
||||||
|
"PP-FormulaNet-L",
|
||||||
|
]:
|
||||||
preds = model(batch)
|
preds = model(batch)
|
||||||
else:
|
else:
|
||||||
preds = model(images)
|
preds = model(images)
|
||||||
|
@ -350,7 +355,12 @@ def train(
|
||||||
preds = model(batch)
|
preds = model(batch)
|
||||||
elif algorithm in ["CAN"]:
|
elif algorithm in ["CAN"]:
|
||||||
preds = model(batch[:3])
|
preds = model(batch[:3])
|
||||||
elif algorithm in ["LaTeXOCR", "UniMERNet"]:
|
elif algorithm in [
|
||||||
|
"LaTeXOCR",
|
||||||
|
"UniMERNet",
|
||||||
|
"PP-FormulaNet-S",
|
||||||
|
"PP-FormulaNet-L",
|
||||||
|
]:
|
||||||
preds = model(batch)
|
preds = model(batch)
|
||||||
else:
|
else:
|
||||||
preds = model(images)
|
preds = model(images)
|
||||||
|
@ -381,6 +391,10 @@ def train(
|
||||||
model_type = "unimernet"
|
model_type = "unimernet"
|
||||||
post_result = post_process_class(preds[0], batch[1], mode="train")
|
post_result = post_process_class(preds[0], batch[1], mode="train")
|
||||||
eval_class(post_result[0], post_result[1], epoch_reset=(idx == 0))
|
eval_class(post_result[0], post_result[1], epoch_reset=(idx == 0))
|
||||||
|
elif algorithm in ["PP-FormulaNet-S", "PP-FormulaNet-L"]:
|
||||||
|
model_type = "pp_formulanet"
|
||||||
|
post_result = post_process_class(preds[0], batch[1], mode="train")
|
||||||
|
eval_class(post_result[0], post_result[1], epoch_reset=(idx == 0))
|
||||||
else:
|
else:
|
||||||
if config["Loss"]["name"] in [
|
if config["Loss"]["name"] in [
|
||||||
"MultiLoss",
|
"MultiLoss",
|
||||||
|
@ -677,7 +691,7 @@ def eval(
|
||||||
preds = model(batch)
|
preds = model(batch)
|
||||||
elif model_type in ["can"]:
|
elif model_type in ["can"]:
|
||||||
preds = model(batch[:3])
|
preds = model(batch[:3])
|
||||||
elif model_type in ["latexocr", "unimernet"]:
|
elif model_type in ["latexocr", "unimernet", "pp_formulanet"]:
|
||||||
preds = model(batch)
|
preds = model(batch)
|
||||||
elif model_type in ["sr"]:
|
elif model_type in ["sr"]:
|
||||||
preds = model(batch)
|
preds = model(batch)
|
||||||
|
@ -705,7 +719,7 @@ def eval(
|
||||||
eval_class(preds, batch_numpy)
|
eval_class(preds, batch_numpy)
|
||||||
elif model_type in ["can"]:
|
elif model_type in ["can"]:
|
||||||
eval_class(preds[0], batch_numpy[2:], epoch_reset=(idx == 0))
|
eval_class(preds[0], batch_numpy[2:], epoch_reset=(idx == 0))
|
||||||
elif model_type in ["latexocr", "unimernet"]:
|
elif model_type in ["latexocr", "unimernet", "pp_formulanet"]:
|
||||||
post_result = post_process_class(preds, batch[1], "eval")
|
post_result = post_process_class(preds, batch[1], "eval")
|
||||||
eval_class(post_result[0], post_result[1], epoch_reset=(idx == 0))
|
eval_class(post_result[0], post_result[1], epoch_reset=(idx == 0))
|
||||||
else:
|
else:
|
||||||
|
@ -855,6 +869,8 @@ def preprocess(is_train=False):
|
||||||
"LaTeXOCR",
|
"LaTeXOCR",
|
||||||
"UniMERNet",
|
"UniMERNet",
|
||||||
"SLANeXt",
|
"SLANeXt",
|
||||||
|
"PP-FormulaNet-S",
|
||||||
|
"PP-FormulaNet-L",
|
||||||
]
|
]
|
||||||
|
|
||||||
if use_xpu:
|
if use_xpu:
|
||||||
|
|
Loading…
Reference in New Issue