Latexocr paddle (#13401)

* commit_test

* modified:   configs/rec/rec_latex_ocr.yml
	deleted:    ppocr/modeling/backbones/rec_resnetv2.py

* ntuple_solve

* style


* delete comment

* cla_email
liuhongen1234567 2024-07-22 11:50:23 +08:00 committed by GitHub
parent c556b9083e
commit cf26f2330e
34 changed files with 4442 additions and 1 deletion


@ -0,0 +1,126 @@
Global:
use_gpu: True
epoch_num: 500
log_smooth_window: 20
print_batch_step: 100
save_model_dir: ./output/rec/latex_ocr/
save_epoch_step: 5
max_seq_len: 512
# evaluation is run every 60000 iterations (22 epoch)(batch_size = 56)
eval_batch_step: [0, 60000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/datasets/pme_demo/0000013.png
infer_mode: False
use_space_char: False
rec_char_dict_path: ppocr/utils/dict/latex_ocr_tokenizer.json
save_res_path: ./output/rec/predicts_latexocr.txt
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Const
learning_rate: 0.0001
Architecture:
model_type: rec
algorithm: LaTeXOCR
in_channels: 1
Transform:
Backbone:
name: HybridTransformer
img_size: [192, 672]
patch_size: 16
num_classes: 0
embed_dim: 256
depth: 4
num_heads: 8
input_channel: 1
is_predict: False
is_export: False
Head:
name: LaTeXOCRHead
pad_value: 0
is_export: False
decoder_args:
attn_on_attn: True
cross_attend: True
ff_glu: True
rel_pos_bias: False
use_scalenorm: False
Loss:
name: LaTeXOCRLoss
PostProcess:
name: LaTeXOCRDecode
rec_char_dict_path: ppocr/utils/dict/latex_ocr_tokenizer.json
Metric:
name: LaTeXOCRMetric
main_indicator: exp_rate
cal_blue_score: False
Train:
dataset:
name: LaTeXOCRDataSet
data: ./train_data/LaTeXOCR/latexocr_train.pkl
min_dimensions: [32, 32]
max_dimensions: [672, 192]
batch_size_per_pair: 56
keep_smaller_batches: False
transforms:
- DecodeImage:
channel_first: False
- MinMaxResize:
min_dimensions: [32, 32]
max_dimensions: [672, 192]
- LatexTrainTransform:
bitmap_prob: .04
- NormalizeImage:
mean: [0.7931, 0.7931, 0.7931]
std: [0.1738, 0.1738, 0.1738]
order: 'hwc'
- LatexImageFormat:
- KeepKeys:
keep_keys: ['image']
loader:
shuffle: True
batch_size_per_card: 1
drop_last: False
num_workers: 0
collate_fn: LaTeXOCRCollator
Eval:
dataset:
name: LaTeXOCRDataSet
data: ./train_data/LaTeXOCR/latexocr_val.pkl
min_dimensions: [32, 32]
max_dimensions: [672, 192]
batch_size_per_pair: 10
keep_smaller_batches: True
transforms:
- DecodeImage:
channel_first: False
- MinMaxResize:
min_dimensions: [32, 32]
max_dimensions: [672, 192]
- LatexTestTransform:
- NormalizeImage:
mean: [0.7931, 0.7931, 0.7931]
std: [0.1738, 0.1738, 0.1738]
order: 'hwc'
- LatexImageFormat:
- KeepKeys:
keep_keys: ['image']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1
num_workers: 0
collate_fn: LaTeXOCRCollator
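The whole pipeline is driven by this file, so a quick way to sanity-check edits to it is to load it with PyYAML. A minimal sketch, assuming it is run from the repository root:

```python
# Minimal sketch: print a few of the fields the rest of this PR consumes.
import yaml

with open("configs/rec/rec_latex_ocr.yml") as f:
    cfg = yaml.safe_load(f)

print(cfg["Architecture"]["Backbone"]["name"])    # HybridTransformer
print(cfg["Train"]["dataset"]["max_dimensions"])  # [672, 192], i.e. max W x H
print(cfg["Global"]["eval_batch_step"])           # [0, 60000]
```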

Binary files not shown (three new image files added: 1.5 KiB, 2.3 KiB, 1.2 KiB).


@ -137,6 +137,7 @@ PaddleOCR will **continuously add** support for cutting-edge OCR algorithms and models; contributions from the community are welcome
Supported formula recognition algorithms (click a link for its tutorial):
- [x] [CAN](./algorithm_rec_can.md)
- [x] [LaTeX-OCR](./algorithm_rec_latex_ocr.md)
On the CROHME handwritten formula dataset, the algorithms perform as follows:
@ -144,6 +145,13 @@ PaddleOCR will **continuously add** support for cutting-edge OCR algorithms and models
| ----- | ----- | ----- | ----- | ----- |
|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_d28_can_train.tar)|
On the LaTeX-OCR printed formula dataset, the algorithm performs as follows:
| Model | Backbone | Config | BLEU score | Normed edit distance | ExpRate | Download link |
|-----------|------------| ----- |:-----------:|:---------------------:|:---------:| ----- |
| LaTeX-OCR | Hybrid ViT |[rec_latex_ocr.yml](../../configs/rec/rec_latex_ocr.yml)| 0.8821 | 0.0823 | 40.01% |[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar)|
<a name="2"></a>
## 2. End-to-end Algorithms

@ -0,0 +1,171 @@
# Printed Mathematical Expression Recognition Algorithm: LaTeX-OCR
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training, Evaluation, and Prediction](#3)
  - [3.1 Pickle Label File Generation](#3-1)
  - [3.2 Training](#3-2)
  - [3.3 Evaluation](#3-3)
  - [3.4 Prediction](#3-4)
- [4. Inference and Deployment](#4)
  - [4.1 Python Inference](#4-1)
  - [4.2 C++ Inference](#4-2)
  - [4.3 Serving](#4-3)
  - [4.4 More](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. Introduction
Original project:
> [https://github.com/lukas-blecher/LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR)
<a name="model"></a>
`LaTeX-OCR` is trained on the [LaTeX-OCR printed formula dataset](https://drive.google.com/drive/folders/13CA4vAmOmD_I_dSbvLp-Lf0s6KiaNfuO); its accuracy on the corresponding test set is as follows:
| Model | Backbone | Config | BLEU score | Normed edit distance | ExpRate | Download link |
|-----------|------------| ----- |:-----------:|:---------------------:|:---------:| ----- |
| LaTeX-OCR | Hybrid ViT |[rec_latex_ocr.yml](../../configs/rec/rec_latex_ocr.yml)| 0.8821 | 0.0823 | 40.01% |[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar)|
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment.md) to configure the PaddleOCR environment, and to ["Project Clone"](./clone.md) to clone the project code.
<a name="3"></a>
## 3. Model Training, Evaluation, and Prediction
<a name="3-1"></a>
### 3.1 Pickle Label File Generation
Download formulae.zip and math.txt from [Google Drive](https://drive.google.com/drive/folders/13CA4vAmOmD_I_dSbvLp-Lf0s6KiaNfuO), then generate the pickle label files with the following commands.
```shell
# Create the LaTeX-OCR dataset directory
mkdir -p train_data/LaTeXOCR
# Unzip formulae.zip and copy math.txt
unzip -d train_data/LaTeXOCR path/formulae.zip
cp path/math.txt train_data/LaTeXOCR
# Convert the original .txt file into .pkl files, grouping images by size
# Training set conversion
python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/train --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/
# Validation set conversion
python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/val --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/
# Test set conversion
python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/test --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/
```
<a name="3-2"></a>
### 3.2 Training
Please refer to the [text recognition training tutorial](./recognition.md). PaddleOCR modularizes the code; to train the `LaTeX-OCR` recognition model, **switch the configuration file** to the `LaTeX-OCR` [config file](../../configs/rec/rec_latex_ocr.yml).
#### Start Training
Specifically, once data preparation is complete, start training with the following commands:
```shell
# Single-GPU training (the default)
python3 tools/train.py -c configs/rec/rec_latex_ocr.yml
# Multi-GPU training; specify the GPU ids with the --gpus flag
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_latex_ocr.yml
```
**Note:**
- By default, the model is evaluated once every 22 training epochs (60,000 iterations). If you change the training batch_size or switch datasets, adjust the interval when launching training:
```
python3 tools/train.py -c configs/rec/rec_latex_ocr.yml -o Global.eval_batch_step=[0,{length_of_dataset//batch_size*22}]
```
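For example (the dataset size below is hypothetical, not the real LaTeX-OCR count), the new interval can be computed ahead of time:

```python
# Hypothetical numbers: 150,000 training formulas with batch_size 56.
length_of_dataset, batch_size = 150_000, 56
print(length_of_dataset // batch_size * 22)  # 58916 iterations between evaluations
```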
<a name="3-3"></a>
### 3.3 Evaluation
You can download the trained [model file](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar) and evaluate it with the following commands:
```shell
# Note: set pretrained_model to a local path. For a model you trained and saved yourself, change the path and file name to {path/to/weights}/{model_name}.
# Validation set evaluation
python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Metric.cal_blue_score=True
# Test set evaluation
python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Metric.cal_blue_score=True Eval.dataset.data=./train_data/LaTeXOCR/latexocr_test.pkl
```
<a name="3-4"></a>
### 3.4 Prediction
Predict a single image with the following command:
```shell
# Note: set pretrained_model to a local path.
python3 tools/infer_rec.py -c configs/rec/rec_latex_ocr.yml -o Architecture.Backbone.is_predict=True Architecture.Backbone.is_export=True Architecture.Head.is_export=True Global.infer_img='./doc/datasets/pme_demo/0000013.png' Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams
# To predict all images in a folder, set infer_img to a directory, e.g. Global.infer_img='./doc/datasets/pme_demo/'.
```
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, convert the best model saved during training into an inference model. Taking the fully trained model as an example ([download link](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar)), use the following command:
```shell
# Note: set pretrained_model to a local path.
python3 tools/export_model.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Global.save_inference_dir=./inference/rec_latex_ocr_infer/ Architecture.Backbone.is_predict=True Architecture.Backbone.is_export=True Architecture.Head.is_export=True
# The current static-graph model supports a maximum output length of 512.
```
**Note:**
- If you trained the model on your own dataset and adjusted the dictionary file, check that `rec_char_dict_path` in the config file points to the dictionary you need.
- [Download link for the converted model](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_infer.tar)
After a successful conversion, the directory contains three files:
```
/inference/rec_latex_ocr_infer/
├── inference.pdiparams # parameter file of the recognition inference model
├── inference.pdiparams.info # parameter info of the recognition inference model, can be ignored
└── inference.pdmodel # program file of the recognition inference model
```
Run model inference with the following command:
```shell
python3 tools/infer/predict_rec.py --image_dir='./doc/datasets/pme_demo/0000295.png' --rec_algorithm="LaTeXOCR" --rec_batch_num=1 --rec_model_dir="./inference/rec_latex_ocr_infer/" --rec_char_dict_path="./ppocr/utils/dict/latex_ocr_tokenizer.json"
# To predict all images in a folder, set image_dir to a directory, e.g. --image_dir='./doc/datasets/pme_demo/'.
```
&nbsp;
![Sample test image](../datasets/pme_demo/0000295.png)
After the command runs, the prediction for the image above (the recognized LaTeX text) is printed to the screen, for example:
```shell
Predicts of ./doc/datasets/pme_demo/0000295.png:\zeta_{0}(\nu)=-{\frac{\nu\varrho^{-2\nu}}{\pi}}\int_{\mu}^{\infty}d\omega\int_{C_{+}}d z{\frac{2z^{2}}{(z^{2}+\omega^{2})^{\nu+1}}}{\tilde{\Psi}}(\omega;z)e^{i\epsilon z}~~~,
```
**Note:**
- The input image must be **black-on-white**, i.e. the formula is black and the background is white.
- `rec_char_dict_path` must be set at inference time to specify the dictionary; if you changed the dictionary, point this parameter to your own dictionary file.
- If you changed the preprocessing, adapt the LaTeX-OCR preprocessing in `tools/infer/predict_rec.py` to your method.
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported yet, because the C++ pre- and post-processing for LaTeX-OCR is not yet available.
<a name="4-3"></a>
### 4.3 Serving
Not supported yet.
<a name="4-4"></a>
### 4.4 More
Not supported yet.
<a name="5"></a>
## 5. FAQ
1. The LaTeX-OCR dataset comes from the [original LaTeX-OCR repo](https://github.com/lukas-blecher/LaTeX-OCR).


@ -137,6 +137,8 @@ On the TextZoom public dataset, the effect of the algorithm is as follows:
Supported formula recognition algorithms (Click the link to get the tutorial):
- [x] [CAN](./algorithm_rec_can_en.md)
- [x] [LaTeX-OCR](./algorithm_rec_latex_ocr_en.md)
On the CROHME handwritten formula dataset, the effect of the algorithm is as follows:
@ -145,6 +147,13 @@ On the CROHME handwritten formula dataset, the effect of the algorithm is as fol
|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_d28_can_train.tar)|
On the LaTeX-OCR printed formula dataset, the effect of the algorithm is as follows:
| Model | Backbone |config| BLEU score | normed edit distance | ExpRate |Download link|
|-----------|----------| ---- |:-----------:|:---------------------:|:---------:| ----- |
| LaTeX-OCR | Hybrid ViT |[rec_latex_ocr.yml](../../configs/rec/rec_latex_ocr.yml)| 0.8821 | 0.0823 | 40.01% |[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar)|
<a name="2"></a>
## 2. End-to-end OCR Algorithms


@ -0,0 +1,127 @@
# LaTeX-OCR
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
- [3.1 Pickle File Generation](#3-1)
- [3.2 Training](#3-2)
- [3.3 Evaluation](#3-3)
- [3.4 Prediction](#3-4)
- [4. Inference and Deployment](#4)
- [4.1 Python Inference](#4-1)
- [4.2 C++ Inference](#4-2)
- [4.3 Serving](#4-3)
- [4.4 More](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. Introduction
Original Project:
> [https://github.com/lukas-blecher/LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR)
Trained on the LaTeX-OCR printed mathematical expression dataset and evaluated on its test set, the reproduced algorithm performs as follows:
| Model | Backbone |config| BLEU score | normed edit distance | ExpRate |Download link|
|-----------|----------| ---- |:-----------:|:---------------------:|:---------:| ----- |
| LaTeX-OCR | Hybrid ViT |[rec_latex_ocr.yml](../../configs/rec/rec_latex_ocr.yml)| 0.8821 | 0.0823 | 40.01% |[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar)|
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
Pickle File Generation:
Download formulae.zip and math.txt from [Google Drive](https://drive.google.com/drive/folders/13CA4vAmOmD_I_dSbvLp-Lf0s6KiaNfuO), then use the following commands to generate the pickle label files.
```shell
# Create a LaTeX-OCR dataset directory
mkdir -p train_data/LaTeXOCR
# Unzip formulae.zip and copy math.txt
unzip -d train_data/LaTeXOCR path/formulae.zip
cp path/math.txt train_data/LaTeXOCR
# Convert the original .txt file to a .pkl file to group images of different scales
# Training set conversion
python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/train --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/
# Validation set conversion
python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/val --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/
# Test set conversion
python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/test --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/
```
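The generated .pkl files are plain pickled dictionaries keyed by image size, matching what `ppocr/utils/formula_utils/math_txt2pkl.py` writes. A minimal sketch for inspecting one (the path assumes the commands above were run as-is):

```python
# Each key is a (width, height) bucket; each value is a list of (latex, image_path) pairs.
import pickle

with open("train_data/LaTeXOCR/latexocr_train.pkl", "rb") as f:
    data = pickle.load(f)

for (w, h), pairs in list(data.items())[:3]:
    print(f"{w}x{h}: {len(pairs)} samples, e.g. {pairs[0][1]}")
```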
Training:
Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
```
#Single GPU training (Default training method)
python3 tools/train.py -c configs/rec/rec_latex_ocr.yml
#Multi GPU training, specify the gpu number through the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_latex_ocr.yml
```
Evaluation:
```
# GPU evaluation
# Validation set evaluation
python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Metric.cal_blue_score=True
# Test set evaluation
python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Metric.cal_blue_score=True Eval.dataset.data=./train_data/LaTeXOCR/latexocr_test.pkl
```
Prediction:
```
# The configuration file used for prediction must match the training
python3 tools/infer_rec.py -c configs/rec/rec_latex_ocr.yml -o Architecture.Backbone.is_predict=True Architecture.Backbone.is_export=True Architecture.Head.is_export=True Global.infer_img='./doc/datasets/pme_demo/0000013.png' Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams
```
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, the model saved during the LaTeX-OCR printed mathematical expression recognition training process is converted into an inference model. You can use the following command to convert:
```
python3 tools/export_model.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Global.save_inference_dir=./inference/rec_latex_ocr_infer/ Architecture.Backbone.is_predict=True Architecture.Backbone.is_export=True Architecture.Head.is_export=True
# The default output max length of the model is 512.
```
For LaTeX-OCR printed mathematical expression recognition model inference, the following commands can be executed:
```
python3 tools/infer/predict_rec.py --image_dir='./doc/datasets/pme_demo/0000295.png' --rec_algorithm="LaTeXOCR" --rec_batch_num=1 --rec_model_dir="./inference/rec_latex_ocr_infer/" --rec_char_dict_path="./ppocr/utils/dict/latex_ocr_tokenizer.json"
```
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
1. The LaTeX-OCR dataset comes from the [original LaTeX-OCR repo](https://github.com/lukas-blecher/LaTeX-OCR).


@ -38,6 +38,7 @@ from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR, LMDBDataSetTable
from ppocr.data.pgnet_dataset import PGDataSet
from ppocr.data.pubtab_dataset import PubTabDataSet
from ppocr.data.multi_scale_sampler import MultiScaleSampler
from ppocr.data.latexocr_dataset import LaTeXOCRDataSet
# for PaddleX dataset_type
TextDetDataset = SimpleDataSet
@ -45,6 +46,7 @@ TextRecDataset = SimpleDataSet
MSTextRecDataset = MultiScaleDataSet
PubTabTableRecDataset = PubTabDataSet
KieDataset = SimpleDataSet
LaTeXOCRDataSet = LaTeXOCRDataSet
__all__ = ["build_dataloader", "transform", "create_operators", "set_signal_handlers"]
@ -94,6 +96,7 @@ def build_dataloader(config, mode, device, logger, seed=None):
"MSTextRecDataset",
"PubTabTableRecDataset",
"KieDataset",
"LaTeXOCRDataSet",
]
module_name = config[mode]["dataset"]["name"]
assert module_name in support_dict, Exception(


@ -116,3 +116,18 @@ class DyMaskCollator(object):
label_masks[i][:l] = 1
return images, image_masks, labels, label_masks
class LaTeXOCRCollator(object):
"""
batch: [
image [batch_size, channel, maxHinbatch, maxWinbatch]
label [batch_size, maxLabelLen]
label_mask [batch_size, maxLabelLen]
...
]
"""
def __call__(self, batch):
images, labels, attention_mask = batch[0]
return images, labels, attention_mask
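Why the collator can simply unwrap `batch[0]`: `LaTeXOCRDataSet.__getitem__` already returns a fully formed, size-bucketed batch, and the loader runs with `batch_size_per_card: 1`, so the list handed to the collator is always a singleton. A sketch (the array shapes are illustrative):

```python
import numpy as np

item = (
    np.zeros((56, 1, 192, 672), dtype="float32"),  # images for one size bucket
    np.zeros((56, 40), dtype="int64"),             # tokenized labels
    np.zeros((56, 40), dtype="int64"),             # attention mask
)
images, labels, attention_mask = LaTeXOCRCollator()([item])  # class defined above
```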


@ -61,6 +61,7 @@ from .fce_aug import *
from .fce_targets import FCENetTargets
from .ct_process import *
from .drrg_targets import DRRGTargets
from .latex_ocr_aug import *
def transform(data, ops=None):


@ -25,6 +25,8 @@ import json
import copy
import random
from random import sample
from collections import defaultdict
from tokenizers import Tokenizer as TokenizerFast
from ppocr.utils.logging import get_logger
from ppocr.data.imaug.vqa.augment import order_by_tbyx
@ -1770,3 +1772,106 @@ class CPPDLabelEncode(BaseRecLabelEncode):
if len(text_list) == 0:
return None, None, None
return text_list, text_node_index, text_node_num
class LatexOCRLabelEncode(object):
def __init__(
self,
rec_char_dict_path,
**kwargs,
):
self.tokenizer = TokenizerFast.from_file(rec_char_dict_path)
self.model_input_names = ["input_ids", "token_type_ids", "attention_mask"]
self.pad_token_id = 0
self.bos_token_id = 1
self.eos_token_id = 2
def _convert_encoding(
self,
encoding,
return_token_type_ids=None,
return_attention_mask=None,
return_overflowing_tokens=False,
return_special_tokens_mask=False,
return_offsets_mapping=False,
return_length=False,
verbose=True,
):
if return_token_type_ids is None:
return_token_type_ids = "token_type_ids" in self.model_input_names
if return_attention_mask is None:
return_attention_mask = "attention_mask" in self.model_input_names
if return_overflowing_tokens and encoding.overflowing is not None:
encodings = [encoding] + encoding.overflowing
else:
encodings = [encoding]
encoding_dict = defaultdict(list)
for e in encodings:
encoding_dict["input_ids"].append(e.ids)
if return_token_type_ids:
encoding_dict["token_type_ids"].append(e.type_ids)
if return_attention_mask:
encoding_dict["attention_mask"].append(e.attention_mask)
if return_special_tokens_mask:
encoding_dict["special_tokens_mask"].append(e.special_tokens_mask)
if return_offsets_mapping:
encoding_dict["offset_mapping"].append(e.offsets)
if return_length:
encoding_dict["length"].append(len(e.ids))
return encoding_dict, encodings
def encode(
self,
text,
text_pair=None,
return_token_type_ids=False,
add_special_tokens=True,
is_split_into_words=False,
):
batched_input = text
encodings = self.tokenizer.encode_batch(
batched_input,
add_special_tokens=add_special_tokens,
is_pretokenized=is_split_into_words,
)
tokens_and_encodings = [
self._convert_encoding(
encoding=encoding,
return_token_type_ids=False,
return_attention_mask=None,
return_overflowing_tokens=False,
return_special_tokens_mask=False,
return_offsets_mapping=False,
return_length=False,
verbose=True,
)
for encoding in encodings
]
sanitized_tokens = {}
for key in tokens_and_encodings[0][0].keys():
stack = [e for item, _ in tokens_and_encodings for e in item[key]]
sanitized_tokens[key] = stack
return sanitized_tokens
def __call__(self, eqs):
topk = self.encode(eqs)
for k, p in zip(topk, [[self.bos_token_id, self.eos_token_id], [1, 1]]):
process_seq = [[p[0]] + x + [p[1]] for x in topk[k]]
max_length = 0
for seq in process_seq:
max_length = max(max_length, len(seq))
labels = np.zeros((len(process_seq), max_length), dtype="int64")
for idx, seq in enumerate(process_seq):
l = len(seq)
labels[idx][:l] = seq
topk[k] = labels
return (
np.array(topk["input_ids"]).astype(np.int64),
np.array(topk["attention_mask"]).astype(np.int64),
max_length,
)
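A usage sketch for the encoder above (the tokenizer path is the one shipped in this PR; the formulas are made up):

```python
encode = LatexOCRLabelEncode("ppocr/utils/dict/latex_ocr_tokenizer.json")  # class above
labels, attention_mask, max_len = encode(["\\alpha+\\beta", "\\frac{a}{b}"])
# labels: int64 [2, max_len]; each row is [BOS=1, token ids..., EOS=2, 0-padding]
# attention_mask: 1 on real positions (including BOS/EOS), 0 on padding
print(labels.shape, attention_mask.shape, max_len)
```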


@ -0,0 +1,179 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/lukas-blecher/LaTeX-OCR/blob/main/pix2tex/dataset/transforms.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import math
import cv2
import numpy as np
import albumentations as A
from PIL import Image
class LatexTrainTransform:
def __init__(self, bitmap_prob=0.04, **kwargs):
# your init code
self.bitmap_prob = bitmap_prob
self.train_transform = A.Compose(
[
A.Compose(
[
A.ShiftScaleRotate(
shift_limit=0,
scale_limit=(-0.15, 0),
rotate_limit=1,
border_mode=0,
interpolation=3,
value=[255, 255, 255],
p=1,
),
A.GridDistortion(
distort_limit=0.1,
border_mode=0,
interpolation=3,
value=[255, 255, 255],
p=0.5,
),
],
p=0.15,
),
A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.3),
A.GaussNoise(10, p=0.2),
A.RandomBrightnessContrast(0.05, (-0.2, 0), True, p=0.2),
A.ImageCompression(95, p=0.3),
A.ToGray(always_apply=True),
]
)
def __call__(self, data):
img = data["image"]
if np.random.random() < self.bitmap_prob:
img[img != 255] = 0
img = self.train_transform(image=img)["image"]
data["image"] = img
return data
class LatexTestTransform:
def __init__(self, **kwargs):
# your init code
self.test_transform = A.Compose(
[
A.ToGray(always_apply=True),
]
)
def __call__(self, data):
img = data["image"]
img = self.test_transform(image=img)["image"]
data["image"] = img
return data
class MinMaxResize:
def __init__(self, min_dimensions=[32, 32], max_dimensions=[672, 192], **kwargs):
# your init code
self.min_dimensions = min_dimensions
self.max_dimensions = max_dimensions
# pass
def pad_(self, img, divable=32):
threshold = 128
data = np.array(img.convert("LA"))
if data[..., -1].var() == 0:
data = (data[..., 0]).astype(np.uint8)
else:
data = (255 - data[..., -1]).astype(np.uint8)
data = (data - data.min()) / (data.max() - data.min()) * 255
if data.mean() > threshold:
# To invert the text to white
gray = 255 * (data < threshold).astype(np.uint8)
else:
gray = 255 * (data > threshold).astype(np.uint8)
data = 255 - data
coords = cv2.findNonZero(gray) # Find all non-zero points (text)
a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
rect = data[b : b + h, a : a + w]
im = Image.fromarray(rect).convert("L")
dims = []
for x in [w, h]:
div, mod = divmod(x, divable)
dims.append(divable * (div + (1 if mod > 0 else 0)))
padded = Image.new("L", dims, 255)
padded.paste(im, (0, 0, im.size[0], im.size[1]))
return padded
def minmax_size_(self, img, max_dimensions, min_dimensions):
if max_dimensions is not None:
ratios = [a / b for a, b in zip(img.size, max_dimensions)]
if any([r > 1 for r in ratios]):
size = np.array(img.size) // max(ratios)
img = img.resize(tuple(size.astype(int)), Image.BILINEAR)
if min_dimensions is not None:
# hypothesis: there is a dim in img smaller than min_dimensions, and return a proper dim >= min_dimensions
padded_size = [
max(img_dim, min_dim)
for img_dim, min_dim in zip(img.size, min_dimensions)
]
if padded_size != list(img.size): # assert hypothesis
padded_im = Image.new("L", padded_size, 255)
padded_im.paste(img, img.getbbox())
img = padded_im
return img
def __call__(self, data):
img = data["image"]
h, w = img.shape[:2]
if (
self.min_dimensions[0] <= w <= self.max_dimensions[0]
and self.min_dimensions[1] <= h <= self.max_dimensions[1]
):
return data
else:
im = Image.fromarray(np.uint8(img))
im = self.minmax_size_(
self.pad_(im), self.max_dimensions, self.min_dimensions
)
im = np.array(im)
im = np.dstack((im, im, im))
data["image"] = im
return data
class LatexImageFormat:
def __init__(self, **kwargs):
# your init code
pass
def __call__(self, data):
img = data["image"]
im_h, im_w = img.shape[:2]
divide_h = math.ceil(im_h / 16) * 16
divide_w = math.ceil(im_w / 16) * 16
img = img[:, :, 0]
img = np.pad(
img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1)
)
img_expanded = img[:, :, np.newaxis].transpose(2, 0, 1)
data["image"] = img_expanded
return data
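A sketch of one synthetic image going through the eval-time chain defined above (`NormalizeImage` from the config is skipped here, so the dtype stays uint8):

```python
import numpy as np

data = {"image": np.full((64, 128, 3), 255, dtype=np.uint8)}  # blank white 128x64 canvas
data = MinMaxResize(min_dimensions=[32, 32], max_dimensions=[672, 192])(data)
data = LatexTestTransform()(data)  # grayscale via albumentations, still HWC
data = LatexImageFormat()(data)    # keep one channel, pad H/W to multiples of 16, HWC -> CHW
print(data["image"].shape)         # (1, 64, 128)
```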


@ -0,0 +1,172 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/lukas-blecher/LaTeX-OCR/blob/main/pix2tex/dataset/dataset.py
"""
import numpy as np
import cv2
import math
import os
import json
import pickle
import random
import traceback
import paddle
from paddle.io import Dataset
from .imaug.label_ops import LatexOCRLabelEncode
from .imaug import transform, create_operators
class LaTeXOCRDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
super(LaTeXOCRDataSet, self).__init__()
self.logger = logger
self.mode = mode.lower()
global_config = config["Global"]
dataset_config = config[mode]["dataset"]
loader_config = config[mode]["loader"]
pkl_path = dataset_config.pop("data")
self.min_dimensions = dataset_config.pop("min_dimensions")
self.max_dimensions = dataset_config.pop("max_dimensions")
self.batchsize = dataset_config.pop("batch_size_per_pair")
self.keep_smaller_batches = dataset_config.pop("keep_smaller_batches")
self.max_seq_len = global_config.pop("max_seq_len")
self.rec_char_dict_path = global_config.pop("rec_char_dict_path")
self.tokenizer = LatexOCRLabelEncode(self.rec_char_dict_path)
file = open(pkl_path, "rb")
data = pickle.load(file)
temp = {}
for k in data:
if (
self.min_dimensions[0] <= k[0] <= self.max_dimensions[0]
and self.min_dimensions[1] <= k[1] <= self.max_dimensions[1]
):
temp[k] = data[k]
self.data = temp
self.do_shuffle = loader_config["shuffle"]
self.seed = seed
if self.mode == "train" and self.do_shuffle:
random.seed(self.seed)
self.pairs = []
for k in self.data:
info = np.array(self.data[k], dtype=object)
p = (
paddle.randperm(len(info))
if self.mode == "train" and self.do_shuffle
else paddle.arange(len(info))
)
for i in range(0, len(info), self.batchsize):
batch = info[p[i : i + self.batchsize]]
if len(batch.shape) == 1:
batch = batch[None, :]
if len(batch) < self.batchsize and not self.keep_smaller_batches:
continue
self.pairs.append(batch)
if self.do_shuffle:
self.pairs = np.random.permutation(np.array(self.pairs, dtype=object))
else:
self.pairs = np.array(self.pairs, dtype=object)
self.size = len(self.pairs)
self.set_epoch_as_seed(self.seed, dataset_config)
self.ops = create_operators(dataset_config["transforms"], global_config)
self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx", 2)
self.need_reset = True
def set_epoch_as_seed(self, seed, dataset_config):
if self.mode == "train":
try:
border_map_id = [
index
for index, dictionary in enumerate(dataset_config["transforms"])
if "MakeBorderMap" in dictionary
][0]
shrink_map_id = [
index
for index, dictionary in enumerate(dataset_config["transforms"])
if "MakeShrinkMap" in dictionary
][0]
dataset_config["transforms"][border_map_id]["MakeBorderMap"][
"epoch"
] = (seed if seed is not None else 0)
dataset_config["transforms"][shrink_map_id]["MakeShrinkMap"][
"epoch"
] = (seed if seed is not None else 0)
except Exception as E:
print(E)
return
def shuffle_data_random(self):
random.seed(self.seed)
random.shuffle(self.data_lines)
return
def __getitem__(self, idx):
batch = self.pairs[idx]
eqs, ims = batch.T
try:
max_width, max_height, max_length = 0, 0, 0
images_transform = []
for img_path in ims:
data = {
"img_path": img_path,
}
with open(data["img_path"], "rb") as f:
img = f.read()
data["image"] = img
item = transform(data, self.ops)
images_transform.append(np.array(item[0]))
image_concat = np.concatenate(images_transform, axis=0)[:, np.newaxis, :, :]
images_transform = image_concat.astype(np.float32)
labels, attention_mask, max_length = self.tokenizer(list(eqs))
if self.max_seq_len < max_length:
rnd_idx = (
np.random.randint(self.__len__())
if self.mode == "train"
else (idx + 1) % self.__len__()
)
return self.__getitem__(rnd_idx)
return (images_transform, labels, attention_mask)
except:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
data["img_path"], traceback.format_exc()
)
)
outs = None
if outs is None:
# during evaluation, we should fix the idx to get same results for many times of evaluation.
rnd_idx = (
np.random.randint(self.__len__())
if self.mode == "train"
else (idx + 1) % self.__len__()
)
return self.__getitem__(rnd_idx)
return outs
def __len__(self):
return self.size


@ -45,6 +45,7 @@ from .rec_satrn_loss import SATRNLoss
from .rec_nrtr_loss import NRTRLoss
from .rec_parseq_loss import ParseQLoss
from .rec_cppd_loss import CPPDLoss
from .rec_latexocr_loss import LaTeXOCRLoss
# cls loss
from .cls_loss import ClsLoss
@ -107,6 +108,7 @@ def build_loss(config):
"NRTRLoss",
"ParseQLoss",
"CPPDLoss",
"LaTeXOCRLoss",
]
config = copy.deepcopy(config)
module_name = config.pop("name")


@ -0,0 +1,47 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/lucidrains/x-transformers/blob/main/x_transformers/autoregressive_wrapper.py
"""
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
class LaTeXOCRLoss(nn.Layer):
"""
LaTeXOCR adopt CrossEntropyLoss for network training.
"""
def __init__(self):
super(LaTeXOCRLoss, self).__init__()
self.ignore_index = -100
self.cross = nn.CrossEntropyLoss(
reduction="mean", ignore_index=self.ignore_index
)
def forward(self, preds, batch):
word_probs = preds
labels = batch[1][:, 1:]
word_loss = self.cross(
paddle.reshape(word_probs, [-1, word_probs.shape[-1]]),
paddle.reshape(labels, [-1]),
)
loss = word_loss
return {"loss": loss}


@ -22,7 +22,7 @@ import copy
__all__ = ["build_metric"]
from .det_metric import DetMetric, DetFCEMetric
from .rec_metric import RecMetric, CNTMetric, CANMetric
from .rec_metric import RecMetric, CNTMetric, CANMetric, LaTeXOCRMetric
from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric
@ -50,6 +50,7 @@ def build_metric(config):
"CTMetric",
"CNTMetric",
"CANMetric",
"LaTeXOCRMetric",
]
config = copy.deepcopy(config)

ppocr/metrics/bleu.py (new file, 240 lines)

@ -0,0 +1,240 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py
"""
import re
import math
import collections
from functools import lru_cache
def _get_ngrams(segment, max_order):
"""Extracts all n-grams upto a given maximum order from an input segment.
Args:
segment: text segment from which n-grams will be extracted.
max_order: maximum length in tokens of the n-grams returned by this
        method.
Returns:
      The Counter containing all n-grams up to max_order in segment
with a count of how many times each n-gram occurred.
"""
ngram_counts = collections.Counter()
for order in range(1, max_order + 1):
for i in range(0, len(segment) - order + 1):
ngram = tuple(segment[i : i + order])
ngram_counts[ngram] += 1
return ngram_counts
def compute_bleu(reference_corpus, translation_corpus, max_order=4, smooth=False):
"""Computes BLEU score of translated segments against one or more references.
Args:
reference_corpus: list of lists of references for each translation. Each
reference should be tokenized into a list of tokens.
translation_corpus: list of translations to score. Each translation
should be tokenized into a list of tokens.
max_order: Maximum n-gram order to use when computing BLEU score.
smooth: Whether or not to apply Lin et al. 2004 smoothing.
Returns:
3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram
precisions and brevity penalty.
"""
matches_by_order = [0] * max_order
possible_matches_by_order = [0] * max_order
reference_length = 0
translation_length = 0
for references, translation in zip(reference_corpus, translation_corpus):
reference_length += min(len(r) for r in references)
translation_length += len(translation)
merged_ref_ngram_counts = collections.Counter()
for reference in references:
merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
translation_ngram_counts = _get_ngrams(translation, max_order)
overlap = translation_ngram_counts & merged_ref_ngram_counts
for ngram in overlap:
matches_by_order[len(ngram) - 1] += overlap[ngram]
for order in range(1, max_order + 1):
possible_matches = len(translation) - order + 1
if possible_matches > 0:
possible_matches_by_order[order - 1] += possible_matches
precisions = [0] * max_order
for i in range(0, max_order):
if smooth:
precisions[i] = (matches_by_order[i] + 1.0) / (
possible_matches_by_order[i] + 1.0
)
else:
if possible_matches_by_order[i] > 0:
precisions[i] = (
float(matches_by_order[i]) / possible_matches_by_order[i]
)
else:
precisions[i] = 0.0
if min(precisions) > 0:
p_log_sum = sum((1.0 / max_order) * math.log(p) for p in precisions)
geo_mean = math.exp(p_log_sum)
else:
geo_mean = 0
ratio = float(translation_length) / reference_length
if ratio > 1.0:
bp = 1.0
else:
bp = math.exp(1 - 1.0 / ratio)
bleu = geo_mean * bp
return (bleu, precisions, bp, ratio, translation_length, reference_length)
class BaseTokenizer:
"""A base dummy tokenizer to derive from."""
def signature(self):
"""
Returns a signature for the tokenizer.
:return: signature string
"""
return "none"
def __call__(self, line):
"""
Tokenizes an input line with the tokenizer.
:param line: a segment to tokenize
:return: the tokenized line
"""
return line
class TokenizerRegexp(BaseTokenizer):
def signature(self):
return "re"
def __init__(self):
self._re = [
# language-dependent part (assuming Western languages)
(re.compile(r"([\{-\~\[-\` -\&\(-\+\:-\@\/])"), r" \1 "),
# tokenize period and comma unless preceded by a digit
(re.compile(r"([^0-9])([\.,])"), r"\1 \2 "),
# tokenize period and comma unless followed by a digit
(re.compile(r"([\.,])([^0-9])"), r" \1 \2"),
# tokenize dash when preceded by a digit
(re.compile(r"([0-9])(-)"), r"\1 \2 "),
# one space only between words
# NOTE: Doing this in Python (below) is faster
# (re.compile(r'\s+'), r' '),
]
@lru_cache(maxsize=2**16)
def __call__(self, line):
"""Common post-processing tokenizer for `13a` and `zh` tokenizers.
:param line: a segment to tokenize
:return: the tokenized line
"""
for _re, repl in self._re:
line = _re.sub(repl, line)
# no leading or trailing spaces, single space within words
# return ' '.join(line.split())
        # This line differs from the original tokenizer (seen above): it returns individual words
return line.split()
class Tokenizer13a(BaseTokenizer):
def signature(self):
return "13a"
def __init__(self):
self._post_tokenizer = TokenizerRegexp()
@lru_cache(maxsize=2**16)
def __call__(self, line):
"""Tokenizes an input line using a relatively minimal tokenization
that is however equivalent to mteval-v13a, used by WMT.
:param line: a segment to tokenize
:return: the tokenized line
"""
# language-independent part:
line = line.replace("<skipped>", "")
line = line.replace("-\n", "")
line = line.replace("\n", " ")
if "&" in line:
line = line.replace("&quot;", '"')
line = line.replace("&amp;", "&")
line = line.replace("&lt;", "<")
line = line.replace("&gt;", ">")
return self._post_tokenizer(f" {line} ")
def compute_blue_score(
predictions, references, tokenizer=Tokenizer13a(), max_order=4, smooth=False
):
# if only one reference is provided make sure we still use list of lists
if isinstance(references[0], str):
references = [[ref] for ref in references]
references = [[tokenizer(r) for r in ref] for ref in references]
predictions = [tokenizer(p) for p in predictions]
score = compute_bleu(
reference_corpus=references,
translation_corpus=predictions,
max_order=max_order,
smooth=smooth,
)
(bleu, precisions, bp, ratio, translation_length, reference_length) = score
return bleu
def cal_distance(word1, word2):
m = len(word1)
n = len(word2)
if m * n == 0:
return m + n
dp = [[0] * (n + 1) for _ in range(m + 1)]
for i in range(m + 1):
dp[i][0] = i
for j in range(n + 1):
dp[0][j] = j
for i in range(1, m + 1):
for j in range(1, n + 1):
a = dp[i - 1][j] + 1
b = dp[i][j - 1] + 1
c = dp[i - 1][j - 1]
if word1[i - 1] != word2[j - 1]:
c += 1
dp[i][j] = min(a, b, c)
return dp[m][n]
def compute_edit_distance(prediction, label):
prediction = prediction.strip().split(" ")
label = label.strip().split(" ")
distance = cal_distance(prediction, label)
return distance
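A few hand-checkable values for the helpers above (note that an exact BLEU of 1.0 needs at least `max_order` tokens, otherwise the 4-gram precision is zero):

```python
print(cal_distance("kitten", "sitting"))        # 3, the classic Levenshtein example
print(compute_edit_distance("a + b", "a - b"))  # splits on spaces, so distance is 1
print(compute_blue_score(["the cat sat down"], ["the cat sat down"]))  # 1.0, exact match
```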


@ -17,6 +17,7 @@ from difflib import SequenceMatcher
import numpy as np
import string
from .bleu import compute_blue_score, compute_edit_distance
class RecMetric(object):
@ -177,3 +178,121 @@ class CANMetric(object):
self.exp_right = []
self.word_total_length = 0
self.exp_total_num = 0
class LaTeXOCRMetric(object):
def __init__(self, main_indicator="exp_rate", cal_blue_score=False, **kwargs):
self.main_indicator = main_indicator
self.cal_blue_score = cal_blue_score
self.edit_right = []
self.exp_right = []
self.blue_right = []
self.e1_right = []
self.e2_right = []
self.e3_right = []
self.editdistance_total_length = 0
self.exp_total_num = 0
self.edit_dist = 0
self.exp_rate = 0
if self.cal_blue_score:
self.blue_score = 0
self.e1 = 0
self.e2 = 0
self.e3 = 0
self.reset()
self.epoch_reset()
def __call__(self, preds, batch, **kwargs):
for k, v in kwargs.items():
epoch_reset = v
if epoch_reset:
self.epoch_reset()
word_pred = preds
word_label = batch
line_right, e1, e2, e3 = 0, 0, 0, 0
lev_dist = []
for labels, prediction in zip(word_label, word_pred):
if prediction == labels:
line_right += 1
distance = compute_edit_distance(prediction, labels)
lev_dist.append(Levenshtein.normalized_distance(prediction, labels))
if distance <= 1:
e1 += 1
if distance <= 2:
e2 += 1
if distance <= 3:
e3 += 1
batch_size = len(lev_dist)
self.edit_dist = sum(lev_dist) # float
self.exp_rate = line_right # float
if self.cal_blue_score:
self.blue_score = compute_blue_score(word_pred, word_label)
self.e1 = e1
self.e2 = e2
self.e3 = e3
exp_length = len(word_label)
self.edit_right.append(self.edit_dist)
self.exp_right.append(self.exp_rate)
if self.cal_blue_score:
self.blue_right.append(self.blue_score * batch_size)
self.e1_right.append(self.e1)
self.e2_right.append(self.e2)
self.e3_right.append(self.e3)
self.editdistance_total_length = self.editdistance_total_length + exp_length
self.exp_total_num = self.exp_total_num + exp_length
def get_metric(self):
"""
return {
'edit distance': 0,
"blue_score": 0,
"exp_rate": 0,
}
"""
cur_edit_distance = sum(self.edit_right) / self.exp_total_num
cur_exp_rate = sum(self.exp_right) / self.exp_total_num
if self.cal_blue_score:
cur_blue_score = sum(self.blue_right) / self.editdistance_total_length
cur_exp_1 = sum(self.e1_right) / self.exp_total_num
cur_exp_2 = sum(self.e2_right) / self.exp_total_num
cur_exp_3 = sum(self.e3_right) / self.exp_total_num
self.reset()
if self.cal_blue_score:
return {
"blue_score ": cur_blue_score,
"edit distance ": cur_edit_distance,
"exp_rate ": cur_exp_rate,
"exp_rate<=1 ": cur_exp_1,
"exp_rate<=2 ": cur_exp_2,
"exp_rate<=3 ": cur_exp_3,
}
else:
return {
"edit distance": cur_edit_distance,
"exp_rate": cur_exp_rate,
"exp_rate<=1 ": cur_exp_1,
"exp_rate<=2 ": cur_exp_2,
"exp_rate<=3 ": cur_exp_3,
}
def reset(self):
self.edit_dist = 0
self.exp_rate = 0
if self.cal_blue_score:
self.blue_score = 0
self.e1 = 0
self.e2 = 0
self.e3 = 0
def epoch_reset(self):
self.edit_right = []
self.exp_right = []
if self.cal_blue_score:
self.blue_right = []
self.e1_right = []
self.e2_right = []
self.e3_right = []
self.editdistance_total_length = 0
self.exp_total_num = 0


@ -59,6 +59,8 @@ def build_backbone(config, model_type):
from .rec_vitstr import ViTSTR
from .rec_resnet_rfl import ResNetRFL
from .rec_densenet import DenseNet
from .rec_resnetv2 import ResNetV2
from .rec_hybridvit import HybridTransformer
from .rec_shallow_cnn import ShallowCNN
from .rec_lcnetv3 import PPLCNetV3
from .rec_hgnet import PPHGNet_small
@ -89,6 +91,8 @@ def build_backbone(config, model_type):
"ViT",
"RepSVTR",
"SVTRv2",
"ResNetV2",
"HybridTransformer",
]
elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet


@ -0,0 +1,529 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer_hybrid.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from itertools import repeat
import collections
import math
from functools import partial
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppocr.modeling.backbones.rec_resnetv2 import (
ResNetV2,
StdConv2dSame,
DropPath,
get_padding,
)
from paddle.nn.initializer import (
TruncatedNormal,
Constant,
Normal,
KaimingUniform,
XavierUniform,
)
normal_ = Normal(mean=0.0, std=1e-6)
zeros_ = Constant(value=0.0)
ones_ = Constant(value=1.0)
kaiming_normal_ = KaimingUniform(nonlinearity="relu")
trunc_normal_ = TruncatedNormal(std=0.02)
xavier_uniform_ = XavierUniform()
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
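# e.g. to_2tuple(16) -> (16, 16); an iterable such as (192, 672) passes through unchanged.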
class Conv2dAlign(nn.Conv2D):
"""Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.
Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
https://arxiv.org/abs/1903.10520v2
"""
def __init__(
self,
in_channel,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
eps=1e-6,
):
super().__init__(
in_channel,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias_attr=bias,
weight_attr=True,
)
self.eps = eps
def forward(self, x):
x = F.conv2d(
x,
self.weight,
self.bias,
self._stride,
self._padding,
self._dilation,
self._groups,
)
return x
class HybridEmbed(nn.Layer):
"""CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
def __init__(
self,
backbone,
img_size=224,
patch_size=1,
feature_size=None,
in_chans=3,
embed_dim=768,
):
super().__init__()
assert isinstance(backbone, nn.Layer)
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.backbone = backbone
feature_dim = 1024
feature_size = (42, 12)
patch_size = (1, 1)
assert (
feature_size[0] % patch_size[0] == 0
and feature_size[1] % patch_size[1] == 0
)
self.grid_size = (
feature_size[0] // patch_size[0],
feature_size[1] // patch_size[1],
)
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.proj = nn.Conv2D(
feature_dim,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
weight_attr=True,
bias_attr=True,
)
def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.proj(x).flatten(2).transpose([0, 2, 1])
return x
class myLinear(nn.Linear):
def __init__(self, in_channel, out_channels, weight_attr=True, bias_attr=True):
super().__init__(
in_channel, out_channels, weight_attr=weight_attr, bias_attr=bias_attr
)
def forward(self, x):
return paddle.matmul(x, self.weight, transpose_y=True) + self.bias
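# Note: paddle.nn.Linear stores its weight as [in_features, out_features] and computes
# x @ W + b, while myLinear multiplies with transpose_y=True, i.e. it expects the weight
# in the [out_features, in_features] (PyTorch/timm) layout -- presumably so pretrained
# ViT weights can be loaded without transposing.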
class Attention(nn.Layer):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = myLinear(dim, dim, weight_attr=True, bias_attr=True)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = (
self.qkv(x)
.reshape([B, N, 3, self.num_heads, C // self.num_heads])
.transpose([2, 0, 3, 1, 4])
)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose([0, 1, 3, 2])) * self.scale
attn = F.softmax(attn, axis=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose([0, 2, 1, 3]).reshape([B, N, C])
x = self.proj(x)
x = self.proj_drop(x)
return x
class Mlp(nn.Layer):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
drop_probs = to_2tuple(drop)
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class Block(nn.Layer):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class HybridTransformer(nn.Layer):
"""Implementation of HybridTransformer.
Args:
x: input images with shape [N, 1, H, W]
label: LaTeX-OCR labels with shape [N, L] , L is the max sequence length
attention_mask: LaTeX-OCR attention mask with shape [N, L] , L is the max sequence length
Returns:
The encoded features with shape [N, 1, H//16, W//16]
"""
def __init__(
self,
backbone_layers=[2, 3, 7],
input_channel=1,
is_predict=False,
is_export=False,
img_size=(224, 224),
patch_size=16,
num_classes=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=True,
representation_size=None,
distilled=False,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
embed_layer=None,
norm_layer=None,
act_layer=None,
weight_init="",
**kwargs,
):
super(HybridTransformer, self).__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = (
embed_dim # num_features for consistency with other models
)
self.num_tokens = 2 if distilled else 1
norm_layer = norm_layer or partial(nn.LayerNorm, epsilon=1e-6)
act_layer = act_layer or nn.GELU
self.height, self.width = img_size
self.patch_size = patch_size
backbone = ResNetV2(
layers=backbone_layers,
num_classes=0,
global_pool="",
in_chans=input_channel,
preact=False,
stem_type="same",
conv_layer=StdConv2dSame,
is_export=is_export,
)
min_patch_size = 2 ** (len(backbone_layers) + 1)
self.patch_embed = HybridEmbed(
img_size=img_size,
patch_size=patch_size // min_patch_size,
in_chans=input_channel,
embed_dim=embed_dim,
backbone=backbone,
)
num_patches = self.patch_embed.num_patches
self.cls_token = paddle.create_parameter([1, 1, embed_dim], dtype="float32")
self.dist_token = (
paddle.create_parameter(
[1, 1, embed_dim],
dtype="float32",
)
if distilled
else None
)
self.pos_embed = paddle.create_parameter(
[1, num_patches + self.num_tokens, embed_dim], dtype="float32"
)
self.pos_drop = nn.Dropout(p=drop_rate)
zeros_(self.cls_token)
if self.dist_token is not None:
zeros_(self.dist_token)
zeros_(self.pos_embed)
dpr = [
x.item() for x in paddle.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
self.blocks = nn.Sequential(
*[
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
)
for i in range(depth)
]
)
self.norm = norm_layer(embed_dim)
# Representation layer
if representation_size and not distilled:
self.num_features = representation_size
self.pre_logits = nn.Sequential(
("fc", nn.Linear(embed_dim, representation_size)), ("act", nn.Tanh())
)
else:
self.pre_logits = nn.Identity()
# Classifier head(s)
self.head = (
nn.Linear(self.num_features, num_classes)
if num_classes > 0
else nn.Identity()
)
self.head_dist = None
if distilled:
self.head_dist = (
nn.Linear(self.embed_dim, self.num_classes)
if num_classes > 0
else nn.Identity()
)
self.init_weights(weight_init)
self.out_channels = embed_dim
self.is_predict = is_predict
self.is_export = is_export
def init_weights(self, mode=""):
assert mode in ("jax", "jax_nlhb", "nlhb", "")
head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
trunc_normal_(self.pos_embed)
trunc_normal_(self.cls_token)
self.apply(_init_vit_weights)
def _init_weights(self, m):
# this fn left here for compat with downstream users
_init_vit_weights(m)
def load_pretrained(self, checkpoint_path, prefix=""):
raise NotImplementedError
def no_weight_decay(self):
return {"pos_embed", "cls_token", "dist_token"}
def get_classifier(self):
if self.dist_token is None:
return self.head
else:
return self.head, self.head_dist
def reset_classifier(self, num_classes, global_pool=""):
self.num_classes = num_classes
self.head = (
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
)
if self.num_tokens == 2:
self.head_dist = (
nn.Linear(self.embed_dim, self.num_classes)
if num_classes > 0
else nn.Identity()
)
def forward_features(self, x):
B, c, h, w = x.shape
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(
[B, -1, -1]
) # stole cls_tokens impl from Phil Wang, thanks
x = paddle.concat((cls_tokens, x), axis=1)
h, w = h // self.patch_size, w // self.patch_size
repeat_tensor = (
paddle.arange(h) * (self.width // self.patch_size - w)
).reshape([-1, 1])
repeat_tensor = paddle.repeat_interleave(
repeat_tensor, paddle.to_tensor(w), axis=1
).reshape([-1])
pos_emb_ind = repeat_tensor + paddle.arange(h * w)
pos_emb_ind = paddle.concat(
(paddle.zeros([1], dtype="int64"), pos_emb_ind + 1), axis=0
).cast(paddle.int64)
x += self.pos_embed[:, pos_emb_ind]
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x
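    # Note on pos_emb_ind above: pos_embed is laid out row-major for the full
    # (width // patch_size)-wide grid from the config. For an input of h x w patches,
    # repeat_tensor adds (width // patch_size - w) per row so every patch indexes the
    # embedding at its true 2-D location; the prepended zero keeps the cls token at index 0.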
def forward(self, input_data):
if self.training:
x, label, attention_mask = input_data
else:
if isinstance(input_data, list):
x = input_data[0]
else:
x = input_data
x = self.forward_features(x)
x = self.head(x)
if self.training:
return x, label, attention_mask
else:
return x
def _init_vit_weights(
module: nn.Layer, name: str = "", head_bias: float = 0.0, jax_impl: bool = False
):
"""ViT weight initialization
* When called without n, head_bias, jax_impl args it will behave exactly the same
as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
* When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
"""
if isinstance(module, nn.Linear):
if name.startswith("head"):
zeros_(module.weight)
constant_ = Constant(value=head_bias)
constant_(module.bias, head_bias)
elif name.startswith("pre_logits"):
zeros_(module.bias)
else:
if jax_impl:
xavier_uniform_(module.weight)
if module.bias is not None:
if "mlp" in name:
normal_(module.bias)
else:
zeros_(module.bias)
else:
trunc_normal_(module.weight)
if module.bias is not None:
zeros_(module.bias)
elif jax_impl and isinstance(module, nn.Conv2D):
# NOTE conv was left to pytorch default in my original init
if module.bias is not None:
zeros_(module.bias)
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2D)):
zeros_(module.bias)
ones_(module.weight)

File diff suppressed because it is too large


@ -40,6 +40,7 @@ def build_head(config):
from .rec_visionlan_head import VLHead
from .rec_rfl_head import RFLHead
from .rec_can_head import CANHead
from .rec_latexocr_head import LaTeXOCRHead
from .rec_satrn_head import SATRNHead
from .rec_parseq_head import ParseQHead
from .rec_cppd_head import CPPDHead
@ -81,6 +82,7 @@ def build_head(config):
"RFLHead",
"DRRGHead",
"CANHead",
"LaTeXOCRHead",
"SATRNHead",
"PFHeadLocal",
"ParseQHead",

File diff suppressed because it is too large


@ -42,6 +42,7 @@ from .rec_postprocess import (
SATRNLabelDecode,
ParseQLabelDecode,
CPPDLabelDecode,
LaTeXOCRDecode,
)
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
@ -96,6 +97,7 @@ def build_post_process(config, global_config=None):
"SATRNLabelDecode",
"ParseQLabelDecode",
"CPPDLabelDecode",
"LaTeXOCRDecode",
]
if config["name"] == "PSEPostProcess":


@ -15,6 +15,7 @@
import numpy as np
import paddle
from paddle.nn import functional as F
from tokenizers import Tokenizer as TokenizerFast
import re
@ -1210,3 +1211,53 @@ class CPPDLabelDecode(NRTRLabelDecode):
def add_special_char(self, dict_character):
dict_character = ["</s>"] + dict_character
return dict_character
class LaTeXOCRDecode(object):
"""Convert between latex-symbol and symbol-index"""
def __init__(self, rec_char_dict_path, **kwargs):
super(LaTeXOCRDecode, self).__init__()
self.tokenizer = TokenizerFast.from_file(rec_char_dict_path)
def post_process(self, s):
text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
letter = "[a-zA-Z]"
noletter = "[\W_^\d]"
names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
news = s
while True:
s = news
news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
if news == s:
break
return s
def decode(self, tokens):
if len(tokens.shape) == 1:
tokens = tokens[None, :]
dec = [self.tokenizer.decode(tok) for tok in tokens]
dec_str_list = [
"".join(detok.split(" "))
.replace("Ġ", " ")
.replace("[EOS]", "")
.replace("[BOS]", "")
.replace("[PAD]", "")
.strip()
for detok in dec
]
return [self.post_process(dec_str) for dec_str in dec_str_list]
def __call__(self, preds, label=None, mode="eval", *args, **kwargs):
if mode == "train":
preds_idx = np.array(preds.argmax(axis=2))
text = self.decode(preds_idx)
else:
text = self.decode(np.array(preds))
if label is None:
return text
label = self.decode(np.array(label))
return text, label
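A minimal usage sketch of the decoder (the tokenizer path and token ids are illustrative, not part of the diff):
import numpy as np
decoder = LaTeXOCRDecode(rec_char_dict_path="ppocr/utils/dict/latex_ocr_tokenizer.json")
token_ids = np.array([[1, 5, 17, 2]])  # hypothetical [BOS] ... [EOS] sequence
print(decoder(token_ids))  # default mode="eval": decode, strip special tokens, normalize
# post_process also collapses spurious spaces, e.g. r"\mathrm {x} + 1" -> r"\mathrm{x}+1"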

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,70 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from tqdm import tqdm
import os
import cv2
import imagesize
from collections import defaultdict
import glob
from os.path import join
import argparse
def txt2pickle(images, equations, save_dir):
save_p = os.path.join(save_dir, "latexocr_{}.pkl".format(os.path.basename(os.path.normpath(images))))
min_dimensions = (32, 32)
max_dimensions = (672, 192)
max_length = 512
data = defaultdict(lambda: [])
if images is not None and equations is not None:
images_list = [
path.replace("\\", "/") for path in glob.glob(join(images, "*.png"))
]
indices = [int(os.path.basename(img).split(".")[0]) for img in images_list]
with open(equations, "r") as f:
eqs = f.read().split("\n")
for i, im in tqdm(enumerate(images_list), total=len(images_list)):
width, height = imagesize.get(im)
if (
min_dimensions[0] <= width <= max_dimensions[0]
and min_dimensions[1] <= height <= max_dimensions[1]
):
data[(width, height)].append((eqs[indices[i]], im))
data = dict(data)
with open(save_p, "wb") as file:
pickle.dump(data, file)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--image_dir",
type=str,
default=".",
help="Input_label or input path to be converted",
)
parser.add_argument(
"--mathtxt_path",
type=str,
default=".",
help="Input_label or input path to be converted",
)
parser.add_argument(
"--output_dir", type=str, default="out_label.txt", help="Output file name"
)
args = parser.parse_args()
txt2pickle(args.image_dir, args.mathtxt_path, args.output_dir)
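A typical invocation might look like the following (the script name and paths are illustrative); with an image_dir ending in train, the output is latexocr_train.pkl, keyed by (width, height) buckets:
python gen_latexocr_pkl.py \
    --image_dir ./train_data/LaTeXOCR/train \
    --mathtxt_path ./train_data/LaTeXOCR/math.txt \
    --output_dir ./train_data/LaTeXOCR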

View File

@ -12,3 +12,6 @@ cython
Pillow
pyyaml
requests
albumentations==1.4.10
tokenizers==0.19.1
imagesize
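For an already provisioned environment, the three new dependencies can also be installed directly:
pip install albumentations==1.4.10 tokenizers==0.19.1 imagesize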

View File

@ -105,6 +105,8 @@ def main():
if "model_type" in config["Architecture"].keys():
if config["Architecture"]["algorithm"] == "CAN":
model_type = "can"
elif config["Architecture"]["algorithm"] == "LaTeXOCR":
model_type = "latexocr"
else:
model_type = config["Architecture"]["model_type"]
else:

View File

@ -131,6 +131,11 @@ def export_single_model(
]
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "LaTeXOCR":
other_shape = [
paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
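# H and W stay dynamic (None) because formula crops vary in size; only the
# single grayscale channel is fixed. An export command for this branch might
# look like (paths illustrative):
#   python tools/export_model.py -c configs/rec/rec_latex_ocr.yml \
#       -o Global.pretrained_model=./output/rec/latex_ocr/best_accuracy \
#          Global.save_inference_dir=./inference/rec_latex_ocr/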
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec = [
paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # input_ids

View File

@ -133,6 +133,11 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char,
}
elif self.rec_algorithm == "LaTeXOCR":
postprocess_params = {
"name": "LaTeXOCRDecode",
"rec_char_dict_path": args.rec_char_dict_path,
}
elif self.rec_algorithm == "ParseQ":
postprocess_params = {
"name": "ParseQLabelDecode",
@ -450,6 +455,90 @@ class TextRecognizer(object):
return img
def pad_(self, img, divable=32):
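# Binarize to find the text, crop to its bounding box, then pad each side up
# to the next multiple of `divable` on a white canvas.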
threshold = 128
data = np.array(img.convert("LA"))
if data[..., -1].var() == 0:
data = (data[..., 0]).astype(np.uint8)
else:
data = (255 - data[..., -1]).astype(np.uint8)
data = (data - data.min()) / (data.max() - data.min()) * 255
if data.mean() > threshold:
# To invert the text to white
gray = 255 * (data < threshold).astype(np.uint8)
else:
gray = 255 * (data > threshold).astype(np.uint8)
data = 255 - data
coords = cv2.findNonZero(gray) # Find all non-zero points (text)
a, b, w, h = cv2.boundingRect(coords)  # Find the minimal bounding box of the text
rect = data[b : b + h, a : a + w]
im = Image.fromarray(rect).convert("L")
dims = []
for x in [w, h]:
div, mod = divmod(x, divable)
dims.append(divable * (div + (1 if mod > 0 else 0)))
padded = Image.new("L", dims, 255)
padded.paste(im, (0, 0, im.size[0], im.size[1]))
return padded
def minmax_size_(
self,
img,
max_dimensions,
min_dimensions,
):
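# Downscale proportionally if either side exceeds max_dimensions, e.g. a
# 700x100 image with max (672, 192) is resized by 672/700 to 672x96; images
# under min_dimensions are pasted onto a white canvas instead.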
if max_dimensions is not None:
ratios = [a / b for a, b in zip(img.size, max_dimensions)]
if any([r > 1 for r in ratios]):
size = np.array(img.size) // max(ratios)
img = img.resize(tuple(size.astype(int)), Image.BILINEAR)
if min_dimensions is not None:
# hypothesis: some dimension of img is smaller than min_dimensions; pad it up to the minimum
padded_size = [
max(img_dim, min_dim)
for img_dim, min_dim in zip(img.size, min_dimensions)
]
if padded_size != list(img.size):  # hypothesis held; pad on a white canvas
padded_im = Image.new("L", padded_size, 255)
padded_im.paste(img, img.getbbox())
img = padded_im
return img
def norm_img_latexocr(self, img):
# LaTeX-OCR runs on single-channel (grayscale) images
shape = (1, 1, 3)
mean = [0.7931, 0.7931, 0.7931]
std = [0.1738, 0.1738, 0.1738]
scale = 1.0 / 255.0  # inputs are uint8 in [0, 255]; bring them to [0, 1] before mean/std
min_dimensions = [32, 32]
max_dimensions = [672, 192]
mean = np.array(mean).reshape(shape).astype("float32")
std = np.array(std).reshape(shape).astype("float32")
im_h, im_w = img.shape[:2]
if (
min_dimensions[0] <= im_w <= max_dimensions[0]
and min_dimensions[1] <= im_h <= max_dimensions[1]
):
pass
else:
img = Image.fromarray(np.uint8(img))
img = self.minmax_size_(self.pad_(img), max_dimensions, min_dimensions)
img = np.array(img)
im_h, im_w = img.shape[:2]
img = np.dstack([img, img, img])
img = (img.astype("float32") * scale - mean) / std
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
divide_h = math.ceil(im_h / 16) * 16
divide_w = math.ceil(im_w / 16) * 16
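# Pad H and W up to multiples of 16 (constant value 1, roughly white after
# normalization), presumably to keep the backbone's patch grid aligned.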
img = np.pad(
img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1)
)
img = img[:, :, np.newaxis].transpose(2, 0, 1)
img = img.astype("float32")
return img
def __call__(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
@ -552,6 +641,10 @@ class TextRecognizer(object):
word_label_list = []
norm_img_mask_batch.append(norm_image_mask)
word_label_list.append(word_label)
elif self.rec_algorithm == "LaTeXOCR":
norm_img = self.norm_img_latexocr(img_list[indices[ino]])
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
else:
norm_img = self.resize_norm_img(
img_list[indices[ino]], max_wh_ratio
@ -666,6 +759,29 @@ class TextRecognizer(object):
if self.benchmark:
self.autolog.times.stamp()
preds = outputs
elif self.rec_algorithm == "LaTeXOCR":
inputs = [norm_img_batch]
if self.use_onnx:
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
outputs = self.predictor.run(self.output_tensors, input_dict)
preds = outputs
else:
input_names = self.predictor.get_input_names()
input_tensor = []
for i in range(len(input_names)):
input_tensor_i = self.predictor.get_input_handle(input_names[i])
input_tensor_i.copy_from_cpu(inputs[i])
input_tensor.append(input_tensor_i)
self.input_tensor = input_tensor
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
if self.benchmark:
self.autolog.times.stamp()
preds = outputs
else:
if self.use_onnx:
input_dict = {}
@ -692,6 +808,9 @@ class TextRecognizer(object):
wh_ratio_list=wh_ratio_list,
max_wh_ratio=max_wh_ratio,
)
elif self.postprocess_params["name"] == "LaTeXOCRDecode":
preds = [p.reshape([-1]) for p in preds]
rec_result = self.postprocess_op(preds)
else:
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):

View File

@ -183,6 +183,8 @@ def main():
elif isinstance(post_result, list) and isinstance(post_result[0], int):
# for RFLearning CNT branch
info = str(post_result[0])
elif config["Architecture"]["algorithm"] == "LaTeXOCR":
info = str(post_result[0])
else:
if len(post_result[0]) >= 2:
info = post_result[0][0] + "\t" + str(post_result[0][1])

View File

@ -324,6 +324,8 @@ def train(
preds = model(batch)
elif algorithm in ["CAN"]:
preds = model(batch[:3])
elif algorithm in ["LaTeXOCR"]:
preds = model(batch)
else:
preds = model(images)
preds = to_float32(preds)
@ -339,6 +341,8 @@ def train(
preds = model(batch)
elif algorithm in ["CAN"]:
preds = model(batch[:3])
elif algorithm in ["LaTeXOCR"]:
preds = model(batch)
else:
preds = model(images)
loss = loss_class(preds, batch)
@ -360,6 +364,10 @@ def train(
elif algorithm in ["CAN"]:
model_type = "can"
eval_class(preds[0], batch[2:], epoch_reset=(idx == 0))
elif algorithm in ["LaTeXOCR"]:
model_type = "latexocr"
post_result = post_process_class(preds, batch[1], mode="train")
eval_class(post_result[0], post_result[1], epoch_reset=(idx == 0))
else:
if config["Loss"]["name"] in [
"MultiLoss",
@ -600,6 +608,8 @@ def eval(
preds = model(batch)
elif model_type in ["can"]:
preds = model(batch[:3])
elif model_type in ["latexocr"]:
preds = model(batch)
elif model_type in ["sr"]:
preds = model(batch)
sr_img = preds["sr_img"]
@ -614,6 +624,8 @@ def eval(
preds = model(batch)
elif model_type in ["can"]:
preds = model(batch[:3])
elif model_type in ["latexocr"]:
preds = model(batch)
elif model_type in ["sr"]:
preds = model(batch)
sr_img = preds["sr_img"]
@ -640,6 +652,9 @@ def eval(
eval_class(preds, batch_numpy)
elif model_type in ["can"]:
eval_class(preds[0], batch_numpy[2:], epoch_reset=(idx == 0))
elif model_type in ["latexocr"]:
post_result = post_process_class(preds, batch[1], "eval")
eval_class(post_result[0], post_result[1], epoch_reset=(idx == 0))
else:
post_result = post_process_class(preds, batch_numpy[1])
eval_class(post_result, batch_numpy)
@ -777,6 +792,7 @@ def preprocess(is_train=False):
"SVTR_HGNet",
"ParseQ",
"CPPD",
"LaTeXOCR",
]
if use_xpu: