diff --git a/configs/rec/rec_latex_ocr.yml b/configs/rec/rec_latex_ocr.yml new file mode 100644 index 0000000000..cde3449076 --- /dev/null +++ b/configs/rec/rec_latex_ocr.yml @@ -0,0 +1,126 @@ +Global: + use_gpu: True + epoch_num: 500 + log_smooth_window: 20 + print_batch_step: 100 + save_model_dir: ./output/rec/latex_ocr/ + save_epoch_step: 5 + max_seq_len: 512 + # evaluation is run every 60000 iterations (22 epoch)(batch_size = 56) + eval_batch_step: [0, 60000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/datasets/pme_demo/0000013.png + infer_mode: False + use_space_char: False + rec_char_dict_path: ppocr/utils/dict/latex_ocr_tokenizer.json + save_res_path: ./output/rec/predicts_latexocr.txt + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + lr: + name: Const + learning_rate: 0.0001 + +Architecture: + model_type: rec + algorithm: LaTeXOCR + in_channels: 1 + Transform: + Backbone: + name: HybridTransformer + img_size: [192, 672] + patch_size: 16 + num_classes: 0 + embed_dim: 256 + depth: 4 + num_heads: 8 + input_channel: 1 + is_predict: False + is_export: False + Head: + name: LaTeXOCRHead + pad_value: 0 + is_export: False + decoder_args: + attn_on_attn: True + cross_attend: True + ff_glu: True + rel_pos_bias: False + use_scalenorm: False + +Loss: + name: LaTeXOCRLoss + +PostProcess: + name: LaTeXOCRDecode + rec_char_dict_path: ppocr/utils/dict/latex_ocr_tokenizer.json + +Metric: + name: LaTeXOCRMetric + main_indicator: exp_rate + cal_blue_score: False + +Train: + dataset: + name: LaTeXOCRDataSet + data: ./train_data/LaTeXOCR/latexocr_train.pkl + min_dimensions: [32, 32] + max_dimensions: [672, 192] + batch_size_per_pair: 56 + keep_smaller_batches: False + transforms: + - DecodeImage: + channel_first: False + - MinMaxResize: + min_dimensions: [32, 32] + max_dimensions: [672, 192] + - LatexTrainTransform: + bitmap_prob: .04 + - NormalizeImage: + mean: [0.7931, 0.7931, 0.7931] + std: [0.1738, 0.1738, 0.1738] + order: 'hwc' + - LatexImageFormat: + - KeepKeys: + keep_keys: ['image'] + loader: + shuffle: True + batch_size_per_card: 1 + drop_last: False + num_workers: 0 + collate_fn: LaTeXOCRCollator + +Eval: + dataset: + name: LaTeXOCRDataSet + data: ./train_data/LaTeXOCR/latexocr_val.pkl + min_dimensions: [32, 32] + max_dimensions: [672, 192] + batch_size_per_pair: 10 + keep_smaller_batches: True + transforms: + - DecodeImage: + channel_first: False + - MinMaxResize: + min_dimensions: [32, 32] + max_dimensions: [672, 192] + - LatexTestTransform: + - NormalizeImage: + mean: [0.7931, 0.7931, 0.7931] + std: [0.1738, 0.1738, 0.1738] + order: 'hwc' + - LatexImageFormat: + - KeepKeys: + keep_keys: ['image'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 + num_workers: 0 + collate_fn: LaTeXOCRCollator diff --git a/doc/datasets/pme_demo/0000013.png b/doc/datasets/pme_demo/0000013.png new file mode 100644 index 0000000000..9f8b11b580 Binary files /dev/null and b/doc/datasets/pme_demo/0000013.png differ diff --git a/doc/datasets/pme_demo/0000295.png b/doc/datasets/pme_demo/0000295.png new file mode 100644 index 0000000000..26a271abf6 Binary files /dev/null and b/doc/datasets/pme_demo/0000295.png differ diff --git a/doc/datasets/pme_demo/0000562.png b/doc/datasets/pme_demo/0000562.png new file mode 100644 index 0000000000..121eda8e2a Binary files /dev/null and b/doc/datasets/pme_demo/0000562.png differ diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index 
6fb277a1f4..90b4b62013 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -137,6 +137,7 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广 已支持的公式识别算法列表(戳链接获取使用教程): - [x] [CAN](./algorithm_rec_can.md) +- [x] [LaTeX-OCR](./algorithm_rec_latex_ocr.md) 在CROHME手写公式数据集上,算法效果如下: @@ -144,6 +145,13 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广 | ----- | ----- | ----- | ----- | ----- | |CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72%|[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_d28_can_train.tar)| +在LaTeX-OCR印刷公式数据集上,算法效果如下: + +| 模型 | 骨干网络 |配置文件 | BLEU score | normed edit distance | ExpRate |下载链接| +|-----------|------------| ----- |:-----------:|:---------------------:|:---------:| ----- | +| LaTeX-OCR | Hybrid ViT |[rec_latex_ocr.yml](../../configs/rec/rec_latex_ocr.yml)| 0.8821 | 0.0823 | 40.01% |[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar)| + + ## 2. 端到端算法 diff --git a/doc/doc_ch/algorithm_rec_latex_ocr.md b/doc/doc_ch/algorithm_rec_latex_ocr.md new file mode 100644 index 0000000000..91a9f0ca2a --- /dev/null +++ b/doc/doc_ch/algorithm_rec_latex_ocr.md @@ -0,0 +1,171 @@ +# 印刷数学公式识别算法-LaTeX-OCR + +- [1. 算法简介](#1) +- [2. 环境配置](#2) +- [3. 模型训练、评估、预测](#3) + - [3.1 pickle 标签文件生成](#3-1) + - [3.2 训练](#3-2) + - [3.3 评估](#3-3) + - [3.4 预测](#3-4) +- [4. 推理部署](#4) + - [4.1 Python推理](#4-1) + - [4.2 C++推理](#4-2) + - [4.3 Serving服务化部署](#4-3) + - [4.4 更多推理部署](#4-4) +- [5. FAQ](#5) + + +## 1. 算法简介 + +原始项目: +> [https://github.com/lukas-blecher/LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR) + + + + +`LaTeX-OCR`使用[`LaTeX-OCR印刷公式数据集`](https://drive.google.com/drive/folders/13CA4vAmOmD_I_dSbvLp-Lf0s6KiaNfuO)进行训练,在对应测试集上的精度如下: + +| 模型 | 骨干网络 |配置文件 | BLEU score | normed edit distance | ExpRate |下载链接| +|-----------|------------| ----- |:-----------:|:---------------------:|:---------:| ----- | +| LaTeX-OCR | Hybrid ViT |[rec_latex_ocr.yml](../../configs/rec/rec_latex_ocr.yml)| 0.8821 | 0.0823 | 40.01% |[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar)| + + +## 2. 环境配置 +请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。 + + +## 3. 
模型训练、评估、预测 + + + +### 3.1 pickle 标签文件生成 +从[谷歌云盘](https://drive.google.com/drive/folders/13CA4vAmOmD_I_dSbvLp-Lf0s6KiaNfuO)中下载 formulae.zip 和 math.txt,之后,使用如下命令,生成 pickle 标签文件。 + +```shell +# 创建 LaTeX-OCR 数据集目录 +mkdir -p train_data/LaTeXOCR +# 解压formulae.zip ,并拷贝math.txt +unzip -d train_data/LaTeXOCR path/formulae.zip +cp path/math.txt train_data/LaTeXOCR +# 将原始的 .txt 文件转换为 .pkl 文件,从而对不同尺度的图像进行分组 +# 训练集转换 +python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/train --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/ +# 验证集转换 +python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/val --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/ +# 测试集转换 +python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/test --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/ +``` + +### 3.2 模型训练 + +请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`LaTeX-OCR`识别模型时需要**更换配置文件**为`LaTeX-OCR`的[配置文件](../../configs/rec/rec_latex_ocr.yml)。 + +#### 启动训练 + + +具体地,在完成数据准备后,便可以启动训练,训练命令如下: +```shell +#单卡训练 (默认训练方式) +python3 tools/train.py -c configs/rec/rec_latex_ocr.yml +#多卡训练,通过--gpus参数指定卡号 +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_latex_ocr.yml +``` + +**注意:** + +- 默认每训练22个epoch(60000次iteration)进行1次评估,若您更改训练的batch_size,或更换数据集,请在训练时作出如下修改 +``` +python3 tools/train.py -c configs/rec/rec_latex_ocr.yml -o Global.eval_batch_step=[0,{length_of_dataset//batch_size*22}] +``` + + +### 3.3 评估 + +可下载已训练完成的[模型文件](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar),使用如下命令进行评估: + +```shell +# 注意将pretrained_model的路径设置为本地路径。若使用自行训练保存的模型,请注意修改路径和文件名为{path/to/weights}/{model_name}。 +# 验证集评估 +python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Metric.cal_blue_score=True +# 测试集评估 +python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Metric.cal_blue_score=True Eval.dataset.data=./train_data/LaTeXOCR/latexocr_test.pkl +``` + + +### 3.4 预测 + +使用如下命令进行单张图片预测: +```shell +# 注意将pretrained_model的路径设置为本地路径。 +python3 tools/infer_rec.py -c configs/rec/rec_latex_ocr.yml -o Architecture.Backbone.is_predict=True Architecture.Backbone.is_export=True Architecture.Head.is_export=True Global.infer_img='./doc/datasets/pme_demo/0000013.png' Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams +# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/datasets/pme_demo/'。 +``` + + +## 4. 
推理部署 + + +### 4.1 Python推理 +首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar) ),可以使用如下命令进行转换: + +```shell +# 注意将pretrained_model的路径设置为本地路径。 +python3 tools/export_model.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Global.save_inference_dir=./inference/rec_latex_ocr_infer/ Architecture.Backbone.is_predict=True Architecture.Backbone.is_export=True Architecture.Head.is_export=True + +# 目前的静态图模型支持的最大输出长度为512 +``` +**注意:** +- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请检查配置文件中的`rec_char_dict_path`是否为所需要的字典文件。 +- [转换后模型下载地址](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_infer.tar) + +转换成功后,在目录下有三个文件: +``` +/inference/rec_latex_ocr_infer/ + ├── inference.pdiparams # 识别inference模型的参数文件 + ├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略 + └── inference.pdmodel # 识别inference模型的program文件 +``` + +执行如下命令进行模型推理: + +```shell +python3 tools/infer/predict_rec.py --image_dir='./doc/datasets/pme_demo/0000295.png' --rec_algorithm="LaTeXOCR" --rec_batch_num=1 --rec_model_dir="./inference/rec_latex_ocr_infer/" --rec_char_dict_path="./ppocr/utils/dict/latex_ocr_tokenizer.json" + +# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/datasets/pme_demo/'。 +``` +  + +![测试图片样例](../datasets/pme_demo/0000295.png) + +执行命令后,上面图像的预测结果(识别的文本)会打印到屏幕上,示例如下: +```shell +Predicts of ./doc/datasets/pme_demo/0000295.png:\zeta_{0}(\nu)=-{\frac{\nu\varrho^{-2\nu}}{\pi}}\int_{\mu}^{\infty}d\omega\int_{C_{+}}d z{\frac{2z^{2}}{(z^{2}+\omega^{2})^{\nu+1}}}{\tilde{\Psi}}(\omega;z)e^{i\epsilon z}~~~, +``` + + +**注意**: + +- 需要注意预测图像为**白底黑字**,即手写公式部分为黑色,背景为白色的图片。 +- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。 +- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中 LaTeX-OCR 的预处理为您的预处理方法。 + + + +### 4.2 C++推理部署 + +由于C++预处理后处理还未支持 LaTeX-OCR,所以暂未支持 + + +### 4.3 Serving服务化部署 + +暂不支持 + + +### 4.4 更多推理部署 + +暂不支持 + + +## 5. FAQ + +1. LaTeX-OCR 数据集来自于[LaTeXOCR源repo](https://github.com/lukas-blecher/LaTeX-OCR) 。 diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index 4cffcfd419..4c893ddcf4 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -137,6 +137,8 @@ On the TextZoom public dataset, the effect of the algorithm is as follows: Supported formula recognition algorithms (Click the link to get the tutorial): - [x] [CAN](./algorithm_rec_can_en.md) +- [x] [LaTeX-OCR](./algorithm_rec_latex_ocr_en.md) + On the CROHME handwritten formula dataset, the effect of the algorithm is as follows: @@ -145,6 +147,13 @@ On the CROHME handwritten formula dataset, the effect of the algorithm is as fol |CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_d28_can_train.tar)| +On the LaTeX-OCR printed formula dataset, the effect of the algorithm is as follows: + +| Model | Backbone |config| BLEU score | normed edit distance | ExpRate |Download link| +|-----------|----------| ---- |:-----------:|:---------------------:|:---------:| ----- | +| LaTeX-OCR | Hybrid ViT |[rec_latex_ocr.yml](../../configs/rec/rec_latex_ocr.yml)| 0.8821 | 0.0823 | 40.01% |[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar)| + + ## 2. 
End-to-end OCR Algorithms diff --git a/doc/doc_en/algorithm_rec_latex_ocr_en.md b/doc/doc_en/algorithm_rec_latex_ocr_en.md new file mode 100644 index 0000000000..087d40145a --- /dev/null +++ b/doc/doc_en/algorithm_rec_latex_ocr_en.md @@ -0,0 +1,127 @@ +# LaTeX-OCR + +- [1. Introduction](#1) +- [2. Environment](#2) +- [3. Model Training / Evaluation / Prediction](#3) + - [3.1 Pickle File Generation](#3-1) + - [3.2 Training](#3-2) + - [3.3 Evaluation](#3-3) + - [3.4 Prediction](#3-4) +- [4. Inference and Deployment](#4) + - [4.1 Python Inference](#4-1) + - [4.2 C++ Inference](#4-2) + - [4.3 Serving](#4-3) + - [4.4 More](#4-4) +- [5. FAQ](#5) + + +## 1. Introduction + +Original Project: +> [https://github.com/lukas-blecher/LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR) + + +Using LaTeX-OCR printed mathematical expression recognition datasets for training, and evaluating on its test sets, the algorithm reproduction effect is as follows: + +| Model | Backbone |config| BLEU score | normed edit distance | ExpRate |Download link| +|-----------|----------| ---- |:-----------:|:---------------------:|:---------:| ----- | +| LaTeX-OCR | Hybrid ViT |[rec_latex_ocr.yml](../../configs/rec/rec_latex_ocr.yml)| 0.8821 | 0.0823 | 40.01% |[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar)| + + +## 2. Environment +Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code. + + + +## 3. Model Training / Evaluation / Prediction + +Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**. + +Pickle File Generation: + +Download formulae.zip and math.txt in [Google Drive](https://drive.google.com/drive/folders/13CA4vAmOmD_I_dSbvLp-Lf0s6KiaNfuO), and then use the following command to generate the pickle file. + +```shell +# Create a LaTeX-OCR dataset directory +mkdir -p train_data/LaTeXOCR +# Unzip formulae.zip and copy math.txt +unzip -d train_data/LaTeXOCR path/formulae.zip +cp path/math.txt train_data/LaTeXOCR +# Convert the original .txt file to a .pkl file to group images of different scales +# Training set conversion +python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/train --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/ +# Validation set conversion +python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/val --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/ +# Test set conversion +python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/test --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/ +``` + + +Training: + +Specifically, after the data preparation is completed, the training can be started. 
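+Note that, per the provided config, evaluation runs every 60000 iterations (about every 22 epochs at batch_size 56). If you change the batch size or the dataset, override `Global.eval_batch_step` when launching training; the command below repeats the formula given in the Chinese documentation:
+
+```
+python3 tools/train.py -c configs/rec/rec_latex_ocr.yml -o Global.eval_batch_step=[0,{length_of_dataset//batch_size*22}]
+```
+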
The training command is as follows: + +``` +#Single GPU training (Default training method) +python3 tools/train.py -c configs/rec/rec_latex_ocr.yml + +#Multi GPU training, specify the gpu number through the --gpus parameter +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_latex_ocr.yml +``` + +Evaluation: + +``` +# GPU evaluation +# Validation set evaluation +python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Metric.cal_blue_score=True +# Test set evaluation +python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Metric.cal_blue_score=True Eval.dataset.data=./train_data/LaTeXOCR/latexocr_test.pkl +``` + +Prediction: + +``` +# The configuration file used for prediction must match the training +python3 tools/infer_rec.py -c configs/rec/rec_latex_ocr.yml -o Architecture.Backbone.is_predict=True Architecture.Backbone.is_export=True Architecture.Head.is_export=True Global.infer_img='./doc/datasets/pme_demo/0000013.png' Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams +``` + + +## 4. Inference and Deployment + + +### 4.1 Python Inference +First, the model saved during the LaTeX-OCR printed mathematical expression recognition training process is converted into an inference model. you can use the following command to convert: + +``` +python3 tools/export_model.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Global.save_inference_dir=./inference/rec_latex_ocr_infer/ Architecture.Backbone.is_predict=True Architecture.Backbone.is_export=True Architecture.Head.is_export=True + +# The default output max length of the model is 512. +``` + +For LaTeX-OCR printed mathematical expression recognition model inference, the following commands can be executed: + +``` +python3 tools/infer/predict_rec.py --image_dir='./doc/datasets/pme_demo/0000295.png' --rec_algorithm="LaTeXOCR" --rec_batch_num=1 --rec_model_dir="./inference/rec_latex_ocr_infer/" --rec_char_dict_path="./ppocr/utils/dict/latex_ocr_tokenizer.json" +``` + + +### 4.2 C++ Inference + +Not supported + + +### 4.3 Serving + +Not supported + + +### 4.4 More + +Not supported + + +## 5. 
FAQ + + +``` diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py index 27d74c89d8..5678aebec1 100644 --- a/ppocr/data/__init__.py +++ b/ppocr/data/__init__.py @@ -38,6 +38,7 @@ from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR, LMDBDataSetTable from ppocr.data.pgnet_dataset import PGDataSet from ppocr.data.pubtab_dataset import PubTabDataSet from ppocr.data.multi_scale_sampler import MultiScaleSampler +from ppocr.data.latexocr_dataset import LaTeXOCRDataSet # for PaddleX dataset_type TextDetDataset = SimpleDataSet @@ -45,6 +46,7 @@ TextRecDataset = SimpleDataSet MSTextRecDataset = MultiScaleDataSet PubTabTableRecDataset = PubTabDataSet KieDataset = SimpleDataSet +LaTeXOCRDataSet = LaTeXOCRDataSet __all__ = ["build_dataloader", "transform", "create_operators", "set_signal_handlers"] @@ -94,6 +96,7 @@ def build_dataloader(config, mode, device, logger, seed=None): "MSTextRecDataset", "PubTabTableRecDataset", "KieDataset", + "LaTeXOCRDataSet", ] module_name = config[mode]["dataset"]["name"] assert module_name in support_dict, Exception( diff --git a/ppocr/data/collate_fn.py b/ppocr/data/collate_fn.py index f1f317510b..29bb3f1aa4 100644 --- a/ppocr/data/collate_fn.py +++ b/ppocr/data/collate_fn.py @@ -116,3 +116,18 @@ class DyMaskCollator(object): label_masks[i][:l] = 1 return images, image_masks, labels, label_masks + + +class LaTeXOCRCollator(object): + """ + batch: [ + image [batch_size, channel, maxHinbatch, maxWinbatch] + label [batch_size, maxLabelLen] + label_mask [batch_size, maxLabelLen] + ... + ] + """ + + def __call__(self, batch): + images, labels, attention_mask = batch[0] + return images, labels, attention_mask diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 350887933b..d76a15555d 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -61,6 +61,7 @@ from .fce_aug import * from .fce_targets import FCENetTargets from .ct_process import * from .drrg_targets import DRRGTargets +from .latex_ocr_aug import * def transform(data, ops=None): diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 46cabaed8c..430e8ef80b 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -25,6 +25,8 @@ import json import copy import random from random import sample +from collections import defaultdict +from tokenizers import Tokenizer as TokenizerFast from ppocr.utils.logging import get_logger from ppocr.data.imaug.vqa.augment import order_by_tbyx @@ -1770,3 +1772,106 @@ class CPPDLabelEncode(BaseRecLabelEncode): if len(text_list) == 0: return None, None, None return text_list, text_node_index, text_node_num + + +class LatexOCRLabelEncode(object): + def __init__( + self, + rec_char_dict_path, + **kwargs, + ): + self.tokenizer = TokenizerFast.from_file(rec_char_dict_path) + self.model_input_names = ["input_ids", "token_type_ids", "attention_mask"] + self.pad_token_id = 0 + self.bos_token_id = 1 + self.eos_token_id = 2 + + def _convert_encoding( + self, + encoding, + return_token_type_ids=None, + return_attention_mask=None, + return_overflowing_tokens=False, + return_special_tokens_mask=False, + return_offsets_mapping=False, + return_length=False, + verbose=True, + ): + + if return_token_type_ids is None: + return_token_type_ids = "token_type_ids" in self.model_input_names + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + if return_overflowing_tokens and encoding.overflowing is not None: + encodings = [encoding] + 
encoding.overflowing + else: + encodings = [encoding] + + encoding_dict = defaultdict(list) + for e in encodings: + encoding_dict["input_ids"].append(e.ids) + + if return_token_type_ids: + encoding_dict["token_type_ids"].append(e.type_ids) + if return_attention_mask: + encoding_dict["attention_mask"].append(e.attention_mask) + if return_special_tokens_mask: + encoding_dict["special_tokens_mask"].append(e.special_tokens_mask) + if return_offsets_mapping: + encoding_dict["offset_mapping"].append(e.offsets) + if return_length: + encoding_dict["length"].append(len(e.ids)) + + return encoding_dict, encodings + + def encode( + self, + text, + text_pair=None, + return_token_type_ids=False, + add_special_tokens=True, + is_split_into_words=False, + ): + batched_input = text + encodings = self.tokenizer.encode_batch( + batched_input, + add_special_tokens=add_special_tokens, + is_pretokenized=is_split_into_words, + ) + tokens_and_encodings = [ + self._convert_encoding( + encoding=encoding, + return_token_type_ids=False, + return_attention_mask=None, + return_overflowing_tokens=False, + return_special_tokens_mask=False, + return_offsets_mapping=False, + return_length=False, + verbose=True, + ) + for encoding in encodings + ] + sanitized_tokens = {} + for key in tokens_and_encodings[0][0].keys(): + stack = [e for item, _ in tokens_and_encodings for e in item[key]] + sanitized_tokens[key] = stack + return sanitized_tokens + + def __call__(self, eqs): + topk = self.encode(eqs) + for k, p in zip(topk, [[self.bos_token_id, self.eos_token_id], [1, 1]]): + process_seq = [[p[0]] + x + [p[1]] for x in topk[k]] + max_length = 0 + for seq in process_seq: + max_length = max(max_length, len(seq)) + labels = np.zeros((len(process_seq), max_length), dtype="int64") + for idx, seq in enumerate(process_seq): + l = len(seq) + labels[idx][:l] = seq + topk[k] = labels + return ( + np.array(topk["input_ids"]).astype(np.int64), + np.array(topk["attention_mask"]).astype(np.int64), + max_length, + ) diff --git a/ppocr/data/imaug/latex_ocr_aug.py b/ppocr/data/imaug/latex_ocr_aug.py new file mode 100644 index 0000000000..db787f3459 --- /dev/null +++ b/ppocr/data/imaug/latex_ocr_aug.py @@ -0,0 +1,179 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
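+# Pipeline note: in configs/rec/rec_latex_ocr.yml these transforms run in the order
+# DecodeImage -> MinMaxResize -> LatexTrainTransform (train) / LatexTestTransform (eval)
+# -> NormalizeImage -> LatexImageFormat -> KeepKeys.
+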
+ +""" +This code is refer from: +https://github.com/lukas-blecher/LaTeX-OCR/blob/main/pix2tex/dataset/transforms.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import math +import cv2 +import numpy as np +import albumentations as A +from PIL import Image + + +class LatexTrainTransform: + def __init__(self, bitmap_prob=0.04, **kwargs): + # your init code + self.bitmap_prob = bitmap_prob + self.train_transform = A.Compose( + [ + A.Compose( + [ + A.ShiftScaleRotate( + shift_limit=0, + scale_limit=(-0.15, 0), + rotate_limit=1, + border_mode=0, + interpolation=3, + value=[255, 255, 255], + p=1, + ), + A.GridDistortion( + distort_limit=0.1, + border_mode=0, + interpolation=3, + value=[255, 255, 255], + p=0.5, + ), + ], + p=0.15, + ), + A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.3), + A.GaussNoise(10, p=0.2), + A.RandomBrightnessContrast(0.05, (-0.2, 0), True, p=0.2), + A.ImageCompression(95, p=0.3), + A.ToGray(always_apply=True), + ] + ) + + def __call__(self, data): + img = data["image"] + if np.random.random() < self.bitmap_prob: + img[img != 255] = 0 + img = self.train_transform(image=img)["image"] + data["image"] = img + return data + + +class LatexTestTransform: + def __init__(self, **kwargs): + # your init code + self.test_transform = A.Compose( + [ + A.ToGray(always_apply=True), + ] + ) + + def __call__(self, data): + img = data["image"] + img = self.test_transform(image=img)["image"] + data["image"] = img + return data + + +class MinMaxResize: + def __init__(self, min_dimensions=[32, 32], max_dimensions=[672, 192], **kwargs): + # your init code + self.min_dimensions = min_dimensions + self.max_dimensions = max_dimensions + # pass + + def pad_(self, img, divable=32): + threshold = 128 + data = np.array(img.convert("LA")) + if data[..., -1].var() == 0: + data = (data[..., 0]).astype(np.uint8) + else: + data = (255 - data[..., -1]).astype(np.uint8) + data = (data - data.min()) / (data.max() - data.min()) * 255 + if data.mean() > threshold: + # To invert the text to white + gray = 255 * (data < threshold).astype(np.uint8) + else: + gray = 255 * (data > threshold).astype(np.uint8) + data = 255 - data + + coords = cv2.findNonZero(gray) # Find all non-zero points (text) + a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box + rect = data[b : b + h, a : a + w] + im = Image.fromarray(rect).convert("L") + dims = [] + for x in [w, h]: + div, mod = divmod(x, divable) + dims.append(divable * (div + (1 if mod > 0 else 0))) + padded = Image.new("L", dims, 255) + padded.paste(im, (0, 0, im.size[0], im.size[1])) + return padded + + def minmax_size_(self, img, max_dimensions, min_dimensions): + if max_dimensions is not None: + ratios = [a / b for a, b in zip(img.size, max_dimensions)] + if any([r > 1 for r in ratios]): + size = np.array(img.size) // max(ratios) + img = img.resize(tuple(size.astype(int)), Image.BILINEAR) + if min_dimensions is not None: + # hypothesis: there is a dim in img smaller than min_dimensions, and return a proper dim >= min_dimensions + padded_size = [ + max(img_dim, min_dim) + for img_dim, min_dim in zip(img.size, min_dimensions) + ] + if padded_size != list(img.size): # assert hypothesis + padded_im = Image.new("L", padded_size, 255) + padded_im.paste(img, img.getbbox()) + img = padded_im + return img + + def __call__(self, data): + img = data["image"] + h, w = img.shape[:2] + if ( + self.min_dimensions[0] <= w <= 
self.max_dimensions[0] + and self.min_dimensions[1] <= h <= self.max_dimensions[1] + ): + return data + else: + im = Image.fromarray(np.uint8(img)) + im = self.minmax_size_( + self.pad_(im), self.max_dimensions, self.min_dimensions + ) + im = np.array(im) + im = np.dstack((im, im, im)) + data["image"] = im + return data + + +class LatexImageFormat: + def __init__(self, **kwargs): + # your init code + pass + + def __call__(self, data): + img = data["image"] + im_h, im_w = img.shape[:2] + divide_h = math.ceil(im_h / 16) * 16 + divide_w = math.ceil(im_w / 16) * 16 + img = img[:, :, 0] + img = np.pad( + img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1) + ) + img_expanded = img[:, :, np.newaxis].transpose(2, 0, 1) + data["image"] = img_expanded + return data diff --git a/ppocr/data/latexocr_dataset.py b/ppocr/data/latexocr_dataset.py new file mode 100644 index 0000000000..a1a747f040 --- /dev/null +++ b/ppocr/data/latexocr_dataset.py @@ -0,0 +1,172 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code is refer from: +https://github.com/lukas-blecher/LaTeX-OCR/blob/main/pix2tex/dataset/dataset.py +""" + +import numpy as np +import cv2 +import math +import os +import json +import pickle +import random +import traceback +import paddle +from paddle.io import Dataset +from .imaug.label_ops import LatexOCRLabelEncode +from .imaug import transform, create_operators + + +class LaTeXOCRDataSet(Dataset): + def __init__(self, config, mode, logger, seed=None): + super(LaTeXOCRDataSet, self).__init__() + self.logger = logger + self.mode = mode.lower() + + global_config = config["Global"] + dataset_config = config[mode]["dataset"] + loader_config = config[mode]["loader"] + + pkl_path = dataset_config.pop("data") + self.min_dimensions = dataset_config.pop("min_dimensions") + self.max_dimensions = dataset_config.pop("max_dimensions") + self.batchsize = dataset_config.pop("batch_size_per_pair") + self.keep_smaller_batches = dataset_config.pop("keep_smaller_batches") + self.max_seq_len = global_config.pop("max_seq_len") + self.rec_char_dict_path = global_config.pop("rec_char_dict_path") + self.tokenizer = LatexOCRLabelEncode(self.rec_char_dict_path) + + file = open(pkl_path, "rb") + data = pickle.load(file) + temp = {} + for k in data: + if ( + self.min_dimensions[0] <= k[0] <= self.max_dimensions[0] + and self.min_dimensions[1] <= k[1] <= self.max_dimensions[1] + ): + temp[k] = data[k] + self.data = temp + self.do_shuffle = loader_config["shuffle"] + self.seed = seed + + if self.mode == "train" and self.do_shuffle: + random.seed(self.seed) + self.pairs = [] + for k in self.data: + info = np.array(self.data[k], dtype=object) + p = ( + paddle.randperm(len(info)) + if self.mode == "train" and self.do_shuffle + else paddle.arange(len(info)) + ) + for i in range(0, len(info), self.batchsize): + batch = info[p[i : i + self.batchsize]] + if len(batch.shape) == 1: + batch = batch[None, :] + if len(batch) < self.batchsize 
and not self.keep_smaller_batches: + continue + self.pairs.append(batch) + if self.do_shuffle: + self.pairs = np.random.permutation(np.array(self.pairs, dtype=object)) + else: + self.pairs = np.array(self.pairs, dtype=object) + + self.size = len(self.pairs) + self.set_epoch_as_seed(self.seed, dataset_config) + + self.ops = create_operators(dataset_config["transforms"], global_config) + self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx", 2) + self.need_reset = True + + def set_epoch_as_seed(self, seed, dataset_config): + if self.mode == "train": + try: + border_map_id = [ + index + for index, dictionary in enumerate(dataset_config["transforms"]) + if "MakeBorderMap" in dictionary + ][0] + shrink_map_id = [ + index + for index, dictionary in enumerate(dataset_config["transforms"]) + if "MakeShrinkMap" in dictionary + ][0] + dataset_config["transforms"][border_map_id]["MakeBorderMap"][ + "epoch" + ] = (seed if seed is not None else 0) + dataset_config["transforms"][shrink_map_id]["MakeShrinkMap"][ + "epoch" + ] = (seed if seed is not None else 0) + except Exception as E: + print(E) + return + + def shuffle_data_random(self): + random.seed(self.seed) + random.shuffle(self.data_lines) + return + + def __getitem__(self, idx): + batch = self.pairs[idx] + eqs, ims = batch.T + try: + max_width, max_height, max_length = 0, 0, 0 + + images_transform = [] + + for img_path in ims: + data = { + "img_path": img_path, + } + with open(data["img_path"], "rb") as f: + img = f.read() + data["image"] = img + item = transform(data, self.ops) + images_transform.append(np.array(item[0])) + image_concat = np.concatenate(images_transform, axis=0)[:, np.newaxis, :, :] + images_transform = image_concat.astype(np.float32) + labels, attention_mask, max_length = self.tokenizer(list(eqs)) + if self.max_seq_len < max_length: + rnd_idx = ( + np.random.randint(self.__len__()) + if self.mode == "train" + else (idx + 1) % self.__len__() + ) + return self.__getitem__(rnd_idx) + return (images_transform, labels, attention_mask) + + except: + + self.logger.error( + "When parsing line {}, error happened with msg: {}".format( + data["img_path"], traceback.format_exc() + ) + ) + outs = None + + if outs is None: + # during evaluation, we should fix the idx to get same results for many times of evaluation. + rnd_idx = ( + np.random.randint(self.__len__()) + if self.mode == "train" + else (idx + 1) % self.__len__() + ) + return self.__getitem__(rnd_idx) + return outs + + def __len__(self): + return self.size diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index ed66e9837a..915a28d165 100644 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -45,6 +45,7 @@ from .rec_satrn_loss import SATRNLoss from .rec_nrtr_loss import NRTRLoss from .rec_parseq_loss import ParseQLoss from .rec_cppd_loss import CPPDLoss +from .rec_latexocr_loss import LaTeXOCRLoss # cls loss from .cls_loss import ClsLoss @@ -107,6 +108,7 @@ def build_loss(config): "NRTRLoss", "ParseQLoss", "CPPDLoss", + "LaTeXOCRLoss", ] config = copy.deepcopy(config) module_name = config.pop("name") diff --git a/ppocr/losses/rec_latexocr_loss.py b/ppocr/losses/rec_latexocr_loss.py new file mode 100644 index 0000000000..d209c04200 --- /dev/null +++ b/ppocr/losses/rec_latexocr_loss.py @@ -0,0 +1,47 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code is refer from: +https://github.com/lucidrains/x-transformers/blob/main/x_transformers/autoregressive_wrapper.py +""" + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import numpy as np + + +class LaTeXOCRLoss(nn.Layer): + """ + LaTeXOCR adopt CrossEntropyLoss for network training. + """ + + def __init__(self): + super(LaTeXOCRLoss, self).__init__() + self.ignore_index = -100 + self.cross = nn.CrossEntropyLoss( + reduction="mean", ignore_index=self.ignore_index + ) + + def forward(self, preds, batch): + word_probs = preds + labels = batch[1][:, 1:] + word_loss = self.cross( + paddle.reshape(word_probs, [-1, word_probs.shape[-1]]), + paddle.reshape(labels, [-1]), + ) + + loss = word_loss + return {"loss": loss} diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py index 9ab515fcb7..dd28d73538 100644 --- a/ppocr/metrics/__init__.py +++ b/ppocr/metrics/__init__.py @@ -22,7 +22,7 @@ import copy __all__ = ["build_metric"] from .det_metric import DetMetric, DetFCEMetric -from .rec_metric import RecMetric, CNTMetric, CANMetric +from .rec_metric import RecMetric, CNTMetric, CANMetric, LaTeXOCRMetric from .cls_metric import ClsMetric from .e2e_metric import E2EMetric from .distillation_metric import DistillationMetric @@ -50,6 +50,7 @@ def build_metric(config): "CTMetric", "CNTMetric", "CANMetric", + "LaTeXOCRMetric", ] config = copy.deepcopy(config) diff --git a/ppocr/metrics/bleu.py b/ppocr/metrics/bleu.py new file mode 100644 index 0000000000..672e7b4c03 --- /dev/null +++ b/ppocr/metrics/bleu.py @@ -0,0 +1,240 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code is refer from: +https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py +""" + +import re +import math +import collections +from functools import lru_cache + + +def _get_ngrams(segment, max_order): + """Extracts all n-grams upto a given maximum order from an input segment. + + Args: + segment: text segment from which n-grams will be extracted. + max_order: maximum length in tokens of the n-grams returned by this + methods. + + Returns: + The Counter containing all n-grams upto max_order in segment + with a count of how many times each n-gram occurred. 
+ """ + ngram_counts = collections.Counter() + for order in range(1, max_order + 1): + for i in range(0, len(segment) - order + 1): + ngram = tuple(segment[i : i + order]) + ngram_counts[ngram] += 1 + return ngram_counts + + +def compute_bleu(reference_corpus, translation_corpus, max_order=4, smooth=False): + """Computes BLEU score of translated segments against one or more references. + + Args: + reference_corpus: list of lists of references for each translation. Each + reference should be tokenized into a list of tokens. + translation_corpus: list of translations to score. Each translation + should be tokenized into a list of tokens. + max_order: Maximum n-gram order to use when computing BLEU score. + smooth: Whether or not to apply Lin et al. 2004 smoothing. + + Returns: + 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram + precisions and brevity penalty. + """ + matches_by_order = [0] * max_order + possible_matches_by_order = [0] * max_order + reference_length = 0 + translation_length = 0 + for references, translation in zip(reference_corpus, translation_corpus): + reference_length += min(len(r) for r in references) + translation_length += len(translation) + + merged_ref_ngram_counts = collections.Counter() + for reference in references: + merged_ref_ngram_counts |= _get_ngrams(reference, max_order) + translation_ngram_counts = _get_ngrams(translation, max_order) + overlap = translation_ngram_counts & merged_ref_ngram_counts + for ngram in overlap: + matches_by_order[len(ngram) - 1] += overlap[ngram] + for order in range(1, max_order + 1): + possible_matches = len(translation) - order + 1 + if possible_matches > 0: + possible_matches_by_order[order - 1] += possible_matches + + precisions = [0] * max_order + for i in range(0, max_order): + if smooth: + precisions[i] = (matches_by_order[i] + 1.0) / ( + possible_matches_by_order[i] + 1.0 + ) + else: + if possible_matches_by_order[i] > 0: + precisions[i] = ( + float(matches_by_order[i]) / possible_matches_by_order[i] + ) + else: + precisions[i] = 0.0 + + if min(precisions) > 0: + p_log_sum = sum((1.0 / max_order) * math.log(p) for p in precisions) + geo_mean = math.exp(p_log_sum) + else: + geo_mean = 0 + + ratio = float(translation_length) / reference_length + + if ratio > 1.0: + bp = 1.0 + else: + bp = math.exp(1 - 1.0 / ratio) + + bleu = geo_mean * bp + + return (bleu, precisions, bp, ratio, translation_length, reference_length) + + +class BaseTokenizer: + """A base dummy tokenizer to derive from.""" + + def signature(self): + """ + Returns a signature for the tokenizer. + :return: signature string + """ + return "none" + + def __call__(self, line): + """ + Tokenizes an input line with the tokenizer. 
+ :param line: a segment to tokenize + :return: the tokenized line + """ + return line + + +class TokenizerRegexp(BaseTokenizer): + def signature(self): + return "re" + + def __init__(self): + self._re = [ + # language-dependent part (assuming Western languages) + (re.compile(r"([\{-\~\[-\` -\&\(-\+\:-\@\/])"), r" \1 "), + # tokenize period and comma unless preceded by a digit + (re.compile(r"([^0-9])([\.,])"), r"\1 \2 "), + # tokenize period and comma unless followed by a digit + (re.compile(r"([\.,])([^0-9])"), r" \1 \2"), + # tokenize dash when preceded by a digit + (re.compile(r"([0-9])(-)"), r"\1 \2 "), + # one space only between words + # NOTE: Doing this in Python (below) is faster + # (re.compile(r'\s+'), r' '), + ] + + @lru_cache(maxsize=2**16) + def __call__(self, line): + """Common post-processing tokenizer for `13a` and `zh` tokenizers. + :param line: a segment to tokenize + :return: the tokenized line + """ + for _re, repl in self._re: + line = _re.sub(repl, line) + + # no leading or trailing spaces, single space within words + # return ' '.join(line.split()) + # This line is changed with regards to the original tokenizer (seen above) to return individual words + return line.split() + + +class Tokenizer13a(BaseTokenizer): + def signature(self): + return "13a" + + def __init__(self): + self._post_tokenizer = TokenizerRegexp() + + @lru_cache(maxsize=2**16) + def __call__(self, line): + """Tokenizes an input line using a relatively minimal tokenization + that is however equivalent to mteval-v13a, used by WMT. + + :param line: a segment to tokenize + :return: the tokenized line + """ + + # language-independent part: + line = line.replace("", "") + line = line.replace("-\n", "") + line = line.replace("\n", " ") + + if "&" in line: + line = line.replace(""", '"') + line = line.replace("&", "&") + line = line.replace("<", "<") + line = line.replace(">", ">") + + return self._post_tokenizer(f" {line} ") + + +def compute_blue_score( + predictions, references, tokenizer=Tokenizer13a(), max_order=4, smooth=False +): + # if only one reference is provided make sure we still use list of lists + if isinstance(references[0], str): + references = [[ref] for ref in references] + + references = [[tokenizer(r) for r in ref] for ref in references] + predictions = [tokenizer(p) for p in predictions] + score = compute_bleu( + reference_corpus=references, + translation_corpus=predictions, + max_order=max_order, + smooth=smooth, + ) + (bleu, precisions, bp, ratio, translation_length, reference_length) = score + return bleu + + +def cal_distance(word1, word2): + m = len(word1) + n = len(word2) + if m * n == 0: + return m + n + dp = [[0] * (n + 1) for _ in range(m + 1)] + for i in range(m + 1): + dp[i][0] = i + for j in range(n + 1): + dp[0][j] = j + for i in range(1, m + 1): + for j in range(1, n + 1): + a = dp[i - 1][j] + 1 + b = dp[i][j - 1] + 1 + c = dp[i - 1][j - 1] + if word1[i - 1] != word2[j - 1]: + c += 1 + dp[i][j] = min(a, b, c) + return dp[m][n] + + +def compute_edit_distance(prediction, label): + prediction = prediction.strip().split(" ") + label = label.strip().split(" ") + distance = cal_distance(prediction, label) + return distance diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py index e41dd36e09..dbb5ddeb76 100644 --- a/ppocr/metrics/rec_metric.py +++ b/ppocr/metrics/rec_metric.py @@ -17,6 +17,7 @@ from difflib import SequenceMatcher import numpy as np import string +from .bleu import compute_blue_score, compute_edit_distance class RecMetric(object): @@ -177,3 
+178,121 @@ class CANMetric(object): self.exp_right = [] self.word_total_length = 0 self.exp_total_num = 0 + + +class LaTeXOCRMetric(object): + def __init__(self, main_indicator="exp_rate", cal_blue_score=False, **kwargs): + self.main_indicator = main_indicator + self.cal_blue_score = cal_blue_score + self.edit_right = [] + self.exp_right = [] + self.blue_right = [] + self.e1_right = [] + self.e2_right = [] + self.e3_right = [] + self.editdistance_total_length = 0 + self.exp_total_num = 0 + self.edit_dist = 0 + self.exp_rate = 0 + if self.cal_blue_score: + self.blue_score = 0 + self.e1 = 0 + self.e2 = 0 + self.e3 = 0 + self.reset() + self.epoch_reset() + + def __call__(self, preds, batch, **kwargs): + for k, v in kwargs.items(): + epoch_reset = v + if epoch_reset: + self.epoch_reset() + word_pred = preds + word_label = batch + line_right, e1, e2, e3 = 0, 0, 0, 0 + lev_dist = [] + for labels, prediction in zip(word_label, word_pred): + if prediction == labels: + line_right += 1 + distance = compute_edit_distance(prediction, labels) + lev_dist.append(Levenshtein.normalized_distance(prediction, labels)) + if distance <= 1: + e1 += 1 + if distance <= 2: + e2 += 1 + if distance <= 3: + e3 += 1 + + batch_size = len(lev_dist) + + self.edit_dist = sum(lev_dist) # float + self.exp_rate = line_right # float + if self.cal_blue_score: + self.blue_score = compute_blue_score(word_pred, word_label) + self.e1 = e1 + self.e2 = e2 + self.e3 = e3 + exp_length = len(word_label) + self.edit_right.append(self.edit_dist) + self.exp_right.append(self.exp_rate) + if self.cal_blue_score: + self.blue_right.append(self.blue_score * batch_size) + self.e1_right.append(self.e1) + self.e2_right.append(self.e2) + self.e3_right.append(self.e3) + self.editdistance_total_length = self.editdistance_total_length + exp_length + self.exp_total_num = self.exp_total_num + exp_length + + def get_metric(self): + """ + return { + 'edit distance': 0, + "blue_score": 0, + "exp_rate": 0, + } + """ + cur_edit_distance = sum(self.edit_right) / self.exp_total_num + cur_exp_rate = sum(self.exp_right) / self.exp_total_num + if self.cal_blue_score: + cur_blue_score = sum(self.blue_right) / self.editdistance_total_length + cur_exp_1 = sum(self.e1_right) / self.exp_total_num + cur_exp_2 = sum(self.e2_right) / self.exp_total_num + cur_exp_3 = sum(self.e3_right) / self.exp_total_num + self.reset() + if self.cal_blue_score: + return { + "blue_score ": cur_blue_score, + "edit distance ": cur_edit_distance, + "exp_rate ": cur_exp_rate, + "exp_rate<=1 ": cur_exp_1, + "exp_rate<=2 ": cur_exp_2, + "exp_rate<=3 ": cur_exp_3, + } + else: + return { + "edit distance": cur_edit_distance, + "exp_rate": cur_exp_rate, + "exp_rate<=1 ": cur_exp_1, + "exp_rate<=2 ": cur_exp_2, + "exp_rate<=3 ": cur_exp_3, + } + + def reset(self): + self.edit_dist = 0 + self.exp_rate = 0 + if self.cal_blue_score: + self.blue_score = 0 + self.e1 = 0 + self.e2 = 0 + self.e3 = 0 + + def epoch_reset(self): + self.edit_right = [] + self.exp_right = [] + if self.cal_blue_score: + self.blue_right = [] + self.e1_right = [] + self.e2_right = [] + self.e3_right = [] + self.editdistance_total_length = 0 + self.exp_total_num = 0 diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index 81d107c293..2a18a51b4b 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -59,6 +59,8 @@ def build_backbone(config, model_type): from .rec_vitstr import ViTSTR from .rec_resnet_rfl import ResNetRFL from .rec_densenet import 
DenseNet + from .rec_resnetv2 import ResNetV2 + from .rec_hybridvit import HybridTransformer from .rec_shallow_cnn import ShallowCNN from .rec_lcnetv3 import PPLCNetV3 from .rec_hgnet import PPHGNet_small @@ -89,6 +91,8 @@ def build_backbone(config, model_type): "ViT", "RepSVTR", "SVTRv2", + "ResNetV2", + "HybridTransformer", ] elif model_type == "e2e": from .e2e_resnet_vd_pg import ResNet diff --git a/ppocr/modeling/backbones/rec_hybridvit.py b/ppocr/modeling/backbones/rec_hybridvit.py new file mode 100644 index 0000000000..e873a781b6 --- /dev/null +++ b/ppocr/modeling/backbones/rec_hybridvit.py @@ -0,0 +1,529 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code is refer from: +https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer_hybrid.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from itertools import repeat +import collections +import math +from functools import partial + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from ppocr.modeling.backbones.rec_resnetv2 import ( + ResNetV2, + StdConv2dSame, + DropPath, + get_padding, +) +from paddle.nn.initializer import ( + TruncatedNormal, + Constant, + Normal, + KaimingUniform, + XavierUniform, +) + +normal_ = Normal(mean=0.0, std=1e-6) +zeros_ = Constant(value=0.0) +ones_ = Constant(value=1.0) +kaiming_normal_ = KaimingUniform(nonlinearity="relu") +trunc_normal_ = TruncatedNormal(std=0.02) +xavier_uniform_ = XavierUniform() + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +class Conv2dAlign(nn.Conv2D): + """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models. + + Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` - + https://arxiv.org/abs/1903.10520v2 + """ + + def __init__( + self, + in_channel, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + eps=1e-6, + ): + + super().__init__( + in_channel, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias_attr=bias, + weight_attr=True, + ) + self.eps = eps + + def forward(self, x): + x = F.conv2d( + x, + self.weight, + self.bias, + self._stride, + self._padding, + self._dilation, + self._groups, + ) + return x + + +class HybridEmbed(nn.Layer): + """CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. 
+ """ + + def __init__( + self, + backbone, + img_size=224, + patch_size=1, + feature_size=None, + in_chans=3, + embed_dim=768, + ): + super().__init__() + assert isinstance(backbone, nn.Layer) + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.backbone = backbone + feature_dim = 1024 + feature_size = (42, 12) + patch_size = (1, 1) + assert ( + feature_size[0] % patch_size[0] == 0 + and feature_size[1] % patch_size[1] == 0 + ) + + self.grid_size = ( + feature_size[0] // patch_size[0], + feature_size[1] // patch_size[1], + ) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.proj = nn.Conv2D( + feature_dim, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + weight_attr=True, + bias_attr=True, + ) + + def forward(self, x): + + x = self.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + x = self.proj(x).flatten(2).transpose([0, 2, 1]) + + return x + + +class myLinear(nn.Linear): + def __init__(self, in_channel, out_channels, weight_attr=True, bias_attr=True): + super().__init__( + in_channel, out_channels, weight_attr=weight_attr, bias_attr=bias_attr + ) + + def forward(self, x): + return paddle.matmul(x, self.weight, transpose_y=True) + self.bias + + +class Attention(nn.Layer): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = myLinear(dim, dim, weight_attr=True, bias_attr=True) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape([B, N, 3, self.num_heads, C // self.num_heads]) + .transpose([2, 0, 3, 1, 4]) + ) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose([0, 1, 3, 2])) * self.scale + + attn = F.softmax(attn, axis=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose([0, 2, 1, 3]).reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Mlp(nn.Layer): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class Block(nn.Layer): + + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = 
norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + def forward(self, x): + + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class HybridTransformer(nn.Layer): + """Implementation of HybridTransformer. + + Args: + x: input images with shape [N, 1, H, W] + label: LaTeX-OCR labels with shape [N, L] , L is the max sequence length + attention_mask: LaTeX-OCR attention mask with shape [N, L] , L is the max sequence length + + Returns: + The encoded features with shape [N, 1, H//16, W//16] + """ + + def __init__( + self, + backbone_layers=[2, 3, 7], + input_channel=1, + is_predict=False, + is_export=False, + img_size=(224, 224), + patch_size=16, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + representation_size=None, + distilled=False, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + embed_layer=None, + norm_layer=None, + act_layer=None, + weight_init="", + **kwargs, + ): + super(HybridTransformer, self).__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + self.num_tokens = 2 if distilled else 1 + norm_layer = norm_layer or partial(nn.LayerNorm, epsilon=1e-6) + act_layer = act_layer or nn.GELU + self.height, self.width = img_size + self.patch_size = patch_size + backbone = ResNetV2( + layers=backbone_layers, + num_classes=0, + global_pool="", + in_chans=input_channel, + preact=False, + stem_type="same", + conv_layer=StdConv2dSame, + is_export=is_export, + ) + min_patch_size = 2 ** (len(backbone_layers) + 1) + self.patch_embed = HybridEmbed( + img_size=img_size, + patch_size=patch_size // min_patch_size, + in_chans=input_channel, + embed_dim=embed_dim, + backbone=backbone, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = paddle.create_parameter([1, 1, embed_dim], dtype="float32") + self.dist_token = ( + paddle.create_parameter( + [1, 1, embed_dim], + dtype="float32", + ) + if distilled + else None + ) + self.pos_embed = paddle.create_parameter( + [1, num_patches + self.num_tokens, embed_dim], dtype="float32" + ) + self.pos_drop = nn.Dropout(p=drop_rate) + zeros_(self.cls_token) + if self.dist_token is not None: + zeros_(self.dist_token) + zeros_(self.pos_embed) + + dpr = [ + x.item() for x in paddle.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.blocks = nn.Sequential( + *[ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ) + for i in range(depth) + ] + ) + self.norm = norm_layer(embed_dim) + + # Representation layer + if representation_size and not distilled: + self.num_features = representation_size + self.pre_logits = nn.Sequential( + ("fc", nn.Linear(embed_dim, representation_size)), ("act", nn.Tanh()) + ) + else: + self.pre_logits = nn.Identity() + + # Classifier head(s) + self.head = ( + nn.Linear(self.num_features, num_classes) + if num_classes > 0 + else nn.Identity() + ) + self.head_dist = None + if distilled: + self.head_dist = ( + nn.Linear(self.embed_dim, self.num_classes) + if num_classes > 0 + else nn.Identity() + ) + self.init_weights(weight_init) + self.out_channels = embed_dim + self.is_predict = is_predict + self.is_export = 
is_export + + def init_weights(self, mode=""): + assert mode in ("jax", "jax_nlhb", "nlhb", "") + head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0 + trunc_normal_(self.pos_embed) + trunc_normal_(self.cls_token) + self.apply(_init_vit_weights) + + def _init_weights(self, m): + # this fn left here for compat with downstream users + _init_vit_weights(m) + + def load_pretrained(self, checkpoint_path, prefix=""): + raise NotImplementedError + + def no_weight_decay(self): + return {"pos_embed", "cls_token", "dist_token"} + + def get_classifier(self): + if self.dist_token is None: + return self.head + else: + return self.head, self.head_dist + + def reset_classifier(self, num_classes, global_pool=""): + self.num_classes = num_classes + self.head = ( + nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) + if self.num_tokens == 2: + self.head_dist = ( + nn.Linear(self.embed_dim, self.num_classes) + if num_classes > 0 + else nn.Identity() + ) + + def forward_features(self, x): + B, c, h, w = x.shape + x = self.patch_embed(x) + cls_tokens = self.cls_token.expand( + [B, -1, -1] + ) # stole cls_tokens impl from Phil Wang, thanks + x = paddle.concat((cls_tokens, x), axis=1) + h, w = h // self.patch_size, w // self.patch_size + repeat_tensor = ( + paddle.arange(h) * (self.width // self.patch_size - w) + ).reshape([-1, 1]) + repeat_tensor = paddle.repeat_interleave( + repeat_tensor, paddle.to_tensor(w), axis=1 + ).reshape([-1]) + pos_emb_ind = repeat_tensor + paddle.arange(h * w) + pos_emb_ind = paddle.concat( + (paddle.zeros([1], dtype="int64"), pos_emb_ind + 1), axis=0 + ).cast(paddle.int64) + x += self.pos_embed[:, pos_emb_ind] + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x + + def forward(self, input_data): + + if self.training: + x, label, attention_mask = input_data + else: + if isinstance(input_data, list): + x = input_data[0] + else: + x = input_data + x = self.forward_features(x) + x = self.head(x) + if self.training: + return x, label, attention_mask + else: + return x + + +def _init_vit_weights( + module: nn.Layer, name: str = "", head_bias: float = 0.0, jax_impl: bool = False +): + """ViT weight initialization + * When called without n, head_bias, jax_impl args it will behave exactly the same + as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). + * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl + """ + if isinstance(module, nn.Linear): + if name.startswith("head"): + zeros_(module.weight) + constant_ = Constant(value=head_bias) + constant_(module.bias, head_bias) + elif name.startswith("pre_logits"): + zeros_(module.bias) + else: + if jax_impl: + xavier_uniform_(module.weight) + if module.bias is not None: + if "mlp" in name: + normal_(module.bias) + else: + zeros_(module.bias) + else: + trunc_normal_(module.weight) + if module.bias is not None: + zeros_(module.bias) + elif jax_impl and isinstance(module, nn.Conv2D): + # NOTE conv was left to pytorch default in my original init + if module.bias is not None: + zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2D)): + zeros_(module.bias) + ones_(module.weight) diff --git a/ppocr/modeling/backbones/rec_resnetv2.py b/ppocr/modeling/backbones/rec_resnetv2.py new file mode 100644 index 0000000000..083e08c7b9 --- /dev/null +++ b/ppocr/modeling/backbones/rec_resnetv2.py @@ -0,0 +1,1283 @@ +# copyright (c) 2024 PaddlePaddle Authors. 
All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code is refer from: +https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/resnetv2.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import collections.abc +from itertools import repeat +from collections import OrderedDict # pylint: disable=g-importing-member + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingUniform +from functools import partial +from typing import Union, Callable, Type, List, Tuple + +IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) +IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) +normal_ = Normal(mean=0.0, std=0.01) +zeros_ = Constant(value=0.0) +ones_ = Constant(value=1.0) +kaiming_normal_ = KaimingUniform(nonlinearity="relu") + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +class StdConv2dSame(nn.Conv2D): + def __init__( + self, + in_channel, + out_channels, + kernel_size, + stride=1, + padding="SAME", + dilation=1, + groups=1, + bias_attr=False, + eps=1e-6, + is_export=False, + ): + padding, is_dynamic = get_padding_value( + padding, kernel_size, stride=stride, dilation=dilation + ) + super().__init__( + in_channel, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias_attr=bias_attr, + ) + self.same_pad = is_dynamic + self.export = is_export + self.eps = eps + + def forward(self, x): + if self.same_pad: + if self.export: + x = pad_same_export(x, self._kernel_size, self._stride, self._dilation) + else: + x = pad_same(x, self._kernel_size, self._stride, self._dilation) + running_mean = paddle.to_tensor([0] * self._out_channels, dtype="float32") + running_variance = paddle.to_tensor([1] * self._out_channels, dtype="float32") + if self.export: + weight = paddle.reshape( + F.batch_norm( + self.weight.reshape([1, self._out_channels, -1]), + running_mean, + running_variance, + momentum=0.0, + epsilon=self.eps, + use_global_stats=False, + ), + self.weight.shape, + ) + else: + weight = paddle.reshape( + F.batch_norm( + self.weight.reshape([1, self._out_channels, -1]), + running_mean, + running_variance, + training=True, + momentum=0.0, + epsilon=self.eps, + ), + self.weight.shape, + ) + x = F.conv2d( + x, + weight, + self.bias, + self._stride, + self._padding, + self._dilation, + self._groups, + ) + return x + + +class StdConv2d(nn.Conv2D): + """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models. 
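+
+    In effect, each output filter's kernel is standardized before the
+    convolution, W_hat = (W - mean(W)) / sqrt(var(W) + eps); below this is
+    implemented by running batch_norm over the reshaped weight tensor.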
+ + Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` - + https://arxiv.org/abs/1903.10520v2 + """ + + def __init__( + self, + in_channel, + out_channels, + kernel_size, + stride=1, + padding=None, + dilation=1, + groups=1, + bias=False, + eps=1e-6, + ): + if padding is None: + padding = get_padding(kernel_size, stride, dilation) + super().__init__( + in_channel, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias_attr=bias, + ) + self.eps = eps + + def forward(self, x): + weight = F.batch_norm( + self.weight.reshape(1, self.out_channels, -1), + None, + None, + training=True, + momentum=0.0, + epsilon=self.eps, + ).reshape_as(self.weight) + x = F.conv2d( + x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups + ) + return x + + +class MaxPool2dSame(nn.MaxPool2D): + """Tensorflow like 'SAME' wrapper for 2D max pooling""" + + def __init__( + self, + kernel_size: int, + stride=None, + padding=0, + dilation=1, + ceil_mode=False, + is_export=False, + ): + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + self.export = is_export + super(MaxPool2dSame, self).__init__( + kernel_size, stride, (0, 0), dilation, ceil_mode + ) + + def forward(self, x): + if self.export: + x = pad_same_export(x, self.ksize, self.stride, value=-float("inf")) + else: + x = pad_same(x, self.ksize, self.stride, value=-float("inf")) + return F.max_pool2d(x, self.ksize, self.stride, (0, 0), self.ceil_mode) + + +def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: + dynamic = False + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == "same": + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = get_padding(kernel_size, **kwargs) + else: + # dynamic 'SAME' padding, has runtime/GPU memory overhead + padding = 0 + dynamic = True + elif padding == "valid": + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = get_padding(kernel_size, **kwargs) + return padding, dynamic + + +def create_pool2d(pool_type, kernel_size, stride=None, is_export=False, **kwargs): + stride = stride or kernel_size + padding = kwargs.pop("padding", "") + padding, is_dynamic = get_padding_value( + padding, kernel_size, stride=stride, **kwargs + ) + if is_dynamic: + if pool_type == "avg": + return AvgPool2dSame( + kernel_size, stride=stride, is_export=is_export, **kwargs + ) + elif pool_type == "max": + return MaxPool2dSame( + kernel_size, stride=stride, is_export=is_export, **kwargs + ) + else: + assert False, f"Unsupported pool type {pool_type}" + + +def get_same_padding(x, k, s, d): + return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) + + +def get_same_padding_export(x, k, s, d): + x = paddle.to_tensor(x) + k = paddle.to_tensor(k) + s = paddle.to_tensor(s) + d = paddle.to_tensor(d) + return paddle.max((paddle.ceil(x / s) - 1) * s + (k - 1) 
* d + 1 - x, 0) + + +def pad_same_export(x, k, s, d=(1, 1), value=0): + ih, iw = x.shape[-2:] + pad_h, pad_w = get_same_padding_export( + ih, k[0], s[0], d[0] + ), get_same_padding_export(iw, k[1], s[1], d[1]) + pad_h = pad_h.cast(paddle.int32) + pad_w = pad_w.cast(paddle.int32) + pad_list = paddle.to_tensor( + [ + (pad_w // 2), + (pad_w - pad_w // 2).cast(paddle.int32), + (pad_h // 2).cast(paddle.int32), + (pad_h - pad_h // 2).cast(paddle.int32), + ] + ) + + if pad_h > 0 or pad_w > 0: + if len(pad_list.shape) == 2: + pad_list = pad_list.squeeze(1) + x = F.pad(x, pad_list.cast(paddle.int32), value=value) + return x + + +def pad_same(x, k, s, d=(1, 1), value=0, pad_h=None, pad_w=None): + ih, iw = x.shape[-2:] + + pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding( + iw, k[1], s[1], d[1] + ) + if pad_h > 0 or pad_w > 0: + x = F.pad( + x, + [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], + value=value, + ) + return x + + +class AvgPool2dSame(nn.AvgPool2D): + """Tensorflow like 'SAME' wrapper for 2D average pooling""" + + def __init__( + self, + kernel_size: int, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + ): + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + super(AvgPool2dSame, self).__init__( + kernel_size, stride, (0, 0), ceil_mode, count_include_pad + ) + + def forward(self, x): + x = pad_same(x, self.kernel_size, self.stride) + return F.avg_pool2d( + x, + self.kernel_size, + self.stride, + self.padding, + self.ceil_mode, + self.count_include_pad, + ) + + +def drop_path( + x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True +): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + +def adaptive_pool_feat_mult(pool_type="avg"): + if pool_type == "catavgmax": + return 2 + else: + return 1 + + +class SelectAdaptivePool2d(nn.Layer): + """Selectable global pooling layer with dynamic input kernel size""" + + def __init__(self, output_size=1, pool_type="fast", flatten=False): + super(SelectAdaptivePool2d, self).__init__() + self.pool_type = ( + pool_type or "" + ) # convert other falsy values to empty string for consistent TS typing + self.flatten = nn.Flatten(1) if flatten else nn.Identity() + if pool_type == "": + self.pool = nn.Identity() # pass through + + def is_identity(self): + return not self.pool_type + + def forward(self, x): + x = self.pool(x) + x = self.flatten(x) + return x + + def feat_mult(self): + return adaptive_pool_feat_mult(self.pool_type) + + def __repr__(self): + return ( + self.__class__.__name__ + + " (" + + "pool_type=" + + self.pool_type + + ", flatten=" + + str(self.flatten) + + ")" + ) + + +def _create_pool(num_features, num_classes, pool_type="avg", use_conv=False): + flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling + if not pool_type: + assert ( + 
num_classes == 0 or use_conv + ), "Pooling can only be disabled if classifier is also removed or conv classifier is used" + flatten_in_pool = ( + False # disable flattening if pooling is pass-through (no pooling) + ) + global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool) + num_pooled_features = num_features * global_pool.feat_mult() + return global_pool, num_pooled_features + + +def _create_fc(num_features, num_classes, use_conv=False): + if num_classes <= 0: + fc = nn.Identity() # pass-through (no classifier) + elif use_conv: + fc = nn.Conv2D(num_features, num_classes, 1, bias_attr=True) + else: + fc = nn.Linear(num_features, num_classes, bias_attr=True) + return fc + + +class ClassifierHead(nn.Layer): + """Classifier head w/ configurable global pooling and dropout.""" + + def __init__( + self, in_chs, num_classes, pool_type="avg", drop_rate=0.0, use_conv=False + ): + super(ClassifierHead, self).__init__() + self.drop_rate = drop_rate + self.global_pool, num_pooled_features = _create_pool( + in_chs, num_classes, pool_type, use_conv=use_conv + ) + self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) + self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() + + def forward(self, x): + x = self.global_pool(x) + if self.drop_rate: + x = F.dropout(x, p=float(self.drop_rate), training=self.training) + x = self.fc(x) + x = self.flatten(x) + return x + + +class EvoNormBatch2d(nn.Layer): + def __init__( + self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None + ): + super(EvoNormBatch2d, self).__init__() + self.apply_act = apply_act # apply activation (non-linearity) + self.momentum = momentum + self.eps = eps + self.weight = paddle.create_parameter( + paddle.ones(num_features), dtype="float32" + ) + self.bias = paddle.create_parameter(paddle.zeros(num_features), dtype="float32") + self.v = ( + paddle.create_parameter(paddle.ones(num_features), dtype="float32") + if apply_act + else None + ) + self.register_buffer("running_var", paddle.ones([num_features])) + self.reset_parameters() + + def reset_parameters(self): + ones_(self.weight) + zeros_(self.bias) + if self.apply_act: + ones_(self.v) + + def forward(self, x): + x_type = x.dtype + if self.v is not None: + running_var = self.running_var.view(1, -1, 1, 1) + if self.training: + var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) + n = x.numel() / x.shape[1] + running_var = var.detach() * self.momentum * ( + n / (n - 1) + ) + running_var * (1 - self.momentum) + self.running_var.copy_(running_var.view(self.running_var.shape)) + else: + var = running_var + v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1) + d = x * v + ( + x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps + ).sqrt().to(dtype=x_type) + d = d.max((var + self.eps).sqrt().to(dtype=x_type)) + x = x / d + return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1) + + +class EvoNormSample2d(nn.Layer): + def __init__( + self, num_features, apply_act=True, groups=32, eps=1e-5, drop_block=None + ): + super(EvoNormSample2d, self).__init__() + self.apply_act = apply_act + self.groups = groups + self.eps = eps + self.weight = paddle.create_parameter( + paddle.ones(num_features), dtype="float32" + ) + self.bias = paddle.create_parameter(paddle.zeros(num_features), dtype="float32") + self.v = ( + paddle.create_parameter(paddle.ones(num_features), dtype="float32") + if apply_act + else None + ) + self.reset_parameters() + + def reset_parameters(self): + ones_(self.weight) + 
zeros_(self.bias) + if self.apply_act: + ones_(self.v) + + def forward(self, x): + B, C, H, W = x.shape + if self.v is not None: + n = x * (x * self.v.view(1, -1, 1, 1)).sigmoid() + x = x.reshape(B, self.groups, -1) + x = ( + n.reshape(B, self.groups, -1) + / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt() + ) + x = x.reshape(B, C, H, W) + return x * self.weight.reshape([1, -1, 1, 1]) + self.bias.reshape([1, -1, 1, 1]) + + +from paddle.common_ops_import import ( + LayerHelper, + check_type, + check_variable_and_dtype, +) + + +def group_norm( + input, + groups, + epsilon=1e-05, + weight=None, + bias=None, + act=None, + data_layout="NCHW", + name=None, +): + helper = LayerHelper("group_norm", **locals()) + dtype = helper.input_dtype() + check_variable_and_dtype( + input, + "input", + ["float16", "uint16", "float32", "float64"], + "group_norm", + ) + # create intput and parameters + inputs = {"X": input} + input_shape = input.shape + if len(input_shape) < 2: + raise ValueError( + f"The dimensions of Op(static.nn.group_norm)'s input should be more than 1. But received {len(input_shape)}" + ) + if data_layout != "NCHW" and data_layout != "NHWC": + raise ValueError( + "Param(data_layout) of Op(static.nn.group_norm) got wrong value: received " + + data_layout + + " but only NCHW or NHWC supported." + ) + channel_num = input_shape[1] if data_layout == "NCHW" else input_shape[-1] + param_shape = [channel_num] + inputs["Scale"] = weight + inputs["Bias"] = bias + # create output + mean_out = helper.create_variable(dtype=dtype, stop_gradient=True) + variance_out = helper.create_variable(dtype=dtype, stop_gradient=True) + group_norm_out = helper.create_variable(dtype=dtype) + + helper.append_op( + type="group_norm", + inputs=inputs, + outputs={ + "Y": group_norm_out, + "Mean": mean_out, + "Variance": variance_out, + }, + attrs={ + "epsilon": epsilon, + "groups": groups, + "data_layout": data_layout, + }, + ) + + return helper.append_activation(group_norm_out) + + +class GroupNormAct(nn.GroupNorm): + # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args + def __init__( + self, + num_channels, + num_groups=32, + eps=1e-5, + affine=True, + apply_act=True, + act_layer=nn.ReLU, + drop_block=None, + ): + super(GroupNormAct, self).__init__(num_groups, num_channels, epsilon=eps) + if affine: + self.weight = paddle.create_parameter([num_channels], dtype="float32") + self.bias = paddle.create_parameter([num_channels], dtype="float32") + ones_(self.weight) + zeros_(self.bias) + if act_layer is not None and apply_act: + act_args = {} + self.act = act_layer(**act_args) + else: + self.act = nn.Identity() + + def forward(self, x): + x = group_norm( + x, self._num_groups, self._epsilon, weight=self.weight, bias=self.bias + ) + x = self.act(x) + return x + + +class BatchNormAct2d(nn.BatchNorm2D): + def __init__( + self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + apply_act=True, + act_layer=nn.ReLU, + drop_block=None, + ): + super(BatchNormAct2d, self).__init__( + num_features, + epsilon=eps, + momentum=momentum, + use_global_stats=track_running_stats, + ) + if act_layer is not None and apply_act: + act_args = dict() + self.act = act_layer(**act_args) + else: + self.act = nn.Identity() + + def _forward_python(self, x): + return super(BatchNormAct2d, self).forward(x) + + def forward(self, x): + x = self._forward_python(x) + x = self.act(x) + return x + + +def adapt_input_conv(in_chans, conv_weight): + conv_type = 
conv_weight.dtype + conv_weight = ( + conv_weight.float() + ) # Some weights are in torch.half, ensure it's float for sum on CPU + O, I, J, K = conv_weight.shape + if in_chans == 1: + if I > 3: + assert conv_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) + conv_weight = conv_weight.sum(dim=2, keepdim=False) + else: + conv_weight = conv_weight.sum(dim=1, keepdim=True) + elif in_chans != 3: + if I != 3: + raise NotImplementedError("Weight format not supported by conversion.") + else: + # NOTE this strategy should be better than random init, but there could be other combinations of + # the original RGB input layer weights that'd work better for specific cases. + repeat = int(math.ceil(in_chans / 3)) + conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv_weight *= 3 / float(in_chans) + conv_weight = conv_weight.to(conv_type) + return conv_weight + + +def named_apply( + fn: Callable, module: nn.Layer, name="", depth_first=True, include_root=False +) -> nn.Layer: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply( + fn=fn, + module=child_module, + name=child_name, + depth_first=depth_first, + include_root=True, + ) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +def _cfg(url="", **kwargs): + return { + "url": url, + "num_classes": 1000, + "input_size": (3, 224, 224), + "pool_size": (7, 7), + "crop_pct": 0.875, + "interpolation": "bilinear", + "mean": IMAGENET_INCEPTION_MEAN, + "std": IMAGENET_INCEPTION_STD, + "first_conv": "stem.conv", + "classifier": "head.fc", + **kwargs, + } + + +def make_div(v, divisor=8): + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class PreActBottleneck(nn.Layer): + """Pre-activation (v2) bottleneck block. + + Follows the implementation of "Identity Mappings in Deep Residual Networks": + https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua + + Except it puts the stride on 3x3 conv when available. 
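+
+    Roughly, ignoring the optional projection shortcut and drop_path:
+
+        out = x + conv3(norm3(conv2(norm2(conv1(norm1(x))))))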
+ """ + + def __init__( + self, + in_chs, + out_chs=None, + bottle_ratio=0.25, + stride=1, + dilation=1, + first_dilation=None, + groups=1, + act_layer=None, + conv_layer=None, + norm_layer=None, + proj_layer=None, + drop_path_rate=0.0, + is_export=False, + ): + super().__init__() + first_dilation = first_dilation or dilation + conv_layer = conv_layer or StdConv2d + norm_layer = norm_layer or partial(GroupNormAct, num_groups=32) + out_chs = out_chs or in_chs + mid_chs = make_div(out_chs * bottle_ratio) + + if proj_layer is not None: + self.downsample = proj_layer( + in_chs, + out_chs, + stride=stride, + dilation=dilation, + first_dilation=first_dilation, + preact=True, + conv_layer=conv_layer, + norm_layer=norm_layer, + ) + else: + self.downsample = None + + self.norm1 = norm_layer(in_chs) + self.conv1 = conv_layer(in_chs, mid_chs, 1, is_export=is_export) + self.norm2 = norm_layer(mid_chs) + self.conv2 = conv_layer( + mid_chs, + mid_chs, + 3, + stride=stride, + dilation=first_dilation, + groups=groups, + is_export=is_export, + ) + self.norm3 = norm_layer(mid_chs) + self.conv3 = conv_layer(mid_chs, out_chs, 1, is_export=is_export) + self.drop_path = ( + DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + ) + + def zero_init_last(self): + zeros_(self.conv3.weight) + + def forward(self, x): + x_preact = self.norm1(x) + + # shortcut branch + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x_preact) + + # residual branch + x = self.conv1(x_preact) + x = self.conv2(self.norm2(x)) + x = self.conv3(self.norm3(x)) + x = self.drop_path(x) + return x + shortcut + + +class Bottleneck(nn.Layer): + """Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT.""" + + def __init__( + self, + in_chs, + out_chs=None, + bottle_ratio=0.25, + stride=1, + dilation=1, + first_dilation=None, + groups=1, + act_layer=None, + conv_layer=None, + norm_layer=None, + proj_layer=None, + drop_path_rate=0.0, + is_export=False, + ): + super().__init__() + first_dilation = first_dilation or dilation + act_layer = act_layer or nn.ReLU + conv_layer = conv_layer or StdConv2d + norm_layer = norm_layer or partial(GroupNormAct, num_groups=32) + out_chs = out_chs or in_chs + mid_chs = make_div(out_chs * bottle_ratio) + + if proj_layer is not None: + self.downsample = proj_layer( + in_chs, + out_chs, + stride=stride, + dilation=dilation, + preact=False, + conv_layer=conv_layer, + norm_layer=norm_layer, + is_export=is_export, + ) + else: + self.downsample = None + + self.conv1 = conv_layer(in_chs, mid_chs, 1, is_export=is_export) + self.norm1 = norm_layer(mid_chs) + self.conv2 = conv_layer( + mid_chs, + mid_chs, + 3, + stride=stride, + dilation=first_dilation, + groups=groups, + is_export=is_export, + ) + self.norm2 = norm_layer(mid_chs) + self.conv3 = conv_layer(mid_chs, out_chs, 1, is_export=is_export) + self.norm3 = norm_layer(out_chs, apply_act=False) + self.drop_path = ( + DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + ) + self.act3 = act_layer() + + def zero_init_last(self): + zeros_(self.norm3.weight) + + def forward(self, x): + # shortcut branch + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x) + + # residual + x = self.conv1(x) + x = self.norm1(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.conv3(x) + x = self.norm3(x) + x = self.drop_path(x) + x = self.act3(x + shortcut) + return x + + +class DownsampleConv(nn.Layer): + def __init__( + self, + in_chs, + out_chs, + stride=1, + dilation=1, + 
first_dilation=None, + preact=True, + conv_layer=None, + norm_layer=None, + is_export=False, + ): + super(DownsampleConv, self).__init__() + self.conv = conv_layer(in_chs, out_chs, 1, stride=stride, is_export=is_export) + self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False) + + def forward(self, x): + return self.norm(self.conv(x)) + + +class DownsampleAvg(nn.Layer): + def __init__( + self, + in_chs, + out_chs, + stride=1, + dilation=1, + first_dilation=None, + preact=True, + conv_layer=None, + norm_layer=None, + is_export=False, + ): + """AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment.""" + super(DownsampleAvg, self).__init__() + avg_stride = stride if dilation == 1 else 1 + if stride > 1 or dilation > 1: + avg_pool_fn = ( + AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2D + ) + self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, exclusive=False) + else: + self.pool = nn.Identity() + self.conv = conv_layer(in_chs, out_chs, 1, stride=1, is_export=is_export) + self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False) + + def forward(self, x): + return self.norm(self.conv(self.pool(x))) + + +class ResNetStage(nn.Layer): + """ResNet Stage.""" + + def __init__( + self, + in_chs, + out_chs, + stride, + dilation, + depth, + bottle_ratio=0.25, + groups=1, + avg_down=False, + block_dpr=None, + block_fn=PreActBottleneck, + is_export=False, + act_layer=None, + conv_layer=None, + norm_layer=None, + **block_kwargs, + ): + super(ResNetStage, self).__init__() + first_dilation = 1 if dilation in (1, 2) else 2 + layer_kwargs = dict( + act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer + ) + proj_layer = DownsampleAvg if avg_down else DownsampleConv + prev_chs = in_chs + self.blocks = nn.Sequential() + for block_idx in range(depth): + drop_path_rate = block_dpr[block_idx] if block_dpr else 0.0 + stride = stride if block_idx == 0 else 1 + self.blocks.add_sublayer( + str(block_idx), + block_fn( + prev_chs, + out_chs, + stride=stride, + dilation=dilation, + bottle_ratio=bottle_ratio, + groups=groups, + first_dilation=first_dilation, + proj_layer=proj_layer, + drop_path_rate=drop_path_rate, + is_export=is_export, + **layer_kwargs, + **block_kwargs, + ), + ) + prev_chs = out_chs + first_dilation = dilation + proj_layer = None + + def forward(self, x): + x = self.blocks(x) + return x + + +def is_stem_deep(stem_type): + return any([s in stem_type for s in ("deep", "tiered")]) + + +def create_resnetv2_stem( + in_chs, + out_chs=64, + stem_type="", + preact=True, + conv_layer=StdConv2d, + norm_layer=partial(GroupNormAct, num_groups=32), + is_export=False, +): + stem = OrderedDict() + assert stem_type in ( + "", + "fixed", + "same", + "deep", + "deep_fixed", + "deep_same", + "tiered", + ) + + # NOTE conv padding mode can be changed by overriding the conv_layer def + if is_stem_deep(stem_type): + # A 3 deep 3x3 conv stack as in ResNet V1D models + if "tiered" in stem_type: + stem_chs = (3 * out_chs // 8, out_chs // 2) # 'T' resnets in resnet.py + else: + stem_chs = (out_chs // 2, out_chs // 2) # 'D' ResNets + stem["conv1"] = conv_layer( + in_chs, stem_chs[0], kernel_size=3, stride=2, is_export=is_export + ) + stem["norm1"] = norm_layer(stem_chs[0]) + stem["conv2"] = conv_layer( + stem_chs[0], stem_chs[1], kernel_size=3, stride=1, is_export=is_export + ) + stem["norm2"] = norm_layer(stem_chs[1]) + stem["conv3"] = conv_layer( + stem_chs[1], out_chs, kernel_size=3, stride=1, 
is_export=is_export + ) + if not preact: + stem["norm3"] = norm_layer(out_chs) + else: + # The usual 7x7 stem conv + stem["conv"] = conv_layer( + in_chs, out_chs, kernel_size=7, stride=2, is_export=is_export + ) + if not preact: + stem["norm"] = norm_layer(out_chs) + + if "fixed" in stem_type: + # 'fixed' SAME padding approximation that is used in BiT models + stem["pad"] = paddle.nn.Pad2D( + 1, mode="constant", value=0.0, data_format="NCHW", name=None + ) + stem["pool"] = nn.MaxPool2D(kernel_size=3, stride=2, padding=0) + elif "same" in stem_type: + # full, input size based 'SAME' padding, used in ViT Hybrid model + stem["pool"] = create_pool2d( + "max", kernel_size=3, stride=2, padding="same", is_export=is_export + ) + else: + # the usual Pypaddle symmetric padding + stem["pool"] = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) + stem_seq = nn.Sequential() + for key, value in stem.items(): + stem_seq.add_sublayer(key, value) + + return stem_seq + + +class ResNetV2(nn.Layer): + """Implementation of Pre-activation (v2) ResNet mode. + + Args: + x: input images with shape [N, 1, H, W] + + Returns: + The extracted features [N, 1, H//16, W//16] + """ + + def __init__( + self, + layers, + channels=(256, 512, 1024, 2048), + num_classes=1000, + in_chans=3, + global_pool="avg", + output_stride=32, + width_factor=1, + stem_chs=64, + stem_type="", + avg_down=False, + preact=True, + act_layer=nn.ReLU, + conv_layer=StdConv2d, + norm_layer=partial(GroupNormAct, num_groups=32), + drop_rate=0.0, + drop_path_rate=0.0, + zero_init_last=False, + is_export=False, + ): + super().__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + self.is_export = is_export + wf = width_factor + self.feature_info = [] + stem_chs = make_div(stem_chs * wf) + self.stem = create_resnetv2_stem( + in_chans, + stem_chs, + stem_type, + preact, + conv_layer=conv_layer, + norm_layer=norm_layer, + is_export=is_export, + ) + stem_feat = ( + ("stem.conv3" if is_stem_deep(stem_type) else "stem.conv") + if preact + else "stem.norm" + ) + self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat)) + + prev_chs = stem_chs + curr_stride = 4 + dilation = 1 + block_dprs = [ + x.tolist() + for x in paddle.linspace(0, drop_path_rate, sum(layers)).split(layers) + ] + block_fn = PreActBottleneck if preact else Bottleneck + self.stages = nn.Sequential() + for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)): + out_chs = make_div(c * wf) + stride = 1 if stage_idx == 0 else 2 + if curr_stride >= output_stride: + dilation *= stride + stride = 1 + stage = ResNetStage( + prev_chs, + out_chs, + stride=stride, + dilation=dilation, + depth=d, + avg_down=avg_down, + act_layer=act_layer, + conv_layer=conv_layer, + norm_layer=norm_layer, + block_dpr=bdpr, + block_fn=block_fn, + is_export=is_export, + ) + prev_chs = out_chs + curr_stride *= stride + self.feature_info += [ + dict( + num_chs=prev_chs, + reduction=curr_stride, + module=f"stages.{stage_idx}", + ) + ] + self.stages.add_sublayer(str(stage_idx), stage) + + self.num_features = prev_chs + self.norm = norm_layer(self.num_features) if preact else nn.Identity() + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=self.drop_rate, + use_conv=True, + ) + + self.init_weights(zero_init_last=zero_init_last) + + def init_weights(self, zero_init_last=True): + named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) + + def load_pretrained(self, checkpoint_path, prefix="resnet/"): + 
_load_weights(self, checkpoint_path, prefix) + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool="avg"): + self.num_classes = num_classes + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=self.drop_rate, + use_conv=True, + ) + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.norm(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _init_weights(module: nn.Layer, name: str = "", zero_init_last=True): + if isinstance(module, nn.Linear) or ( + "head.fc" in name and isinstance(module, nn.Conv2D) + ): + normal_(module.weight) + zeros_(module.bias) + elif isinstance(module, nn.Conv2D): + kaiming_normal_(module.weight) + if module.bias is not None: + zeros_(module.bias) + elif isinstance(module, (nn.BatchNorm2D, nn.LayerNorm, nn.GroupNorm)): + ones_(module.weight) + zeros_(module.bias) + elif zero_init_last and hasattr(module, "zero_init_last"): + module.zero_init_last() + + +@paddle.no_grad() +def _load_weights(model: nn.Layer, checkpoint_path: str, prefix: str = "resnet/"): + import numpy as np + + def t2p(conv_weights): + """Possibly convert HWIO to OIHW.""" + if conv_weights.ndim == 4: + conv_weights = conv_weights.transpose([3, 2, 0, 1]) + return paddle.to_tensor(conv_weights) + + weights = np.load(checkpoint_path) + stem_conv_w = adapt_input_conv( + model.stem.conv.weight.shape[1], + t2p(weights[f"{prefix}root_block/standardized_conv2d/kernel"]), + ) + model.stem.conv.weight.copy_(stem_conv_w) + model.norm.weight.copy_(t2p(weights[f"{prefix}group_norm/gamma"])) + model.norm.bias.copy_(t2p(weights[f"{prefix}group_norm/beta"])) + if ( + isinstance(getattr(model.head, "fc", None), nn.Conv2D) + and model.head.fc.weight.shape[0] + == weights[f"{prefix}head/conv2d/kernel"].shape[-1] + ): + model.head.fc.weight.copy_(t2p(weights[f"{prefix}head/conv2d/kernel"])) + model.head.fc.bias.copy_(t2p(weights[f"{prefix}head/conv2d/bias"])) + for i, (sname, stage) in enumerate(model.stages.named_children()): + for j, (bname, block) in enumerate(stage.blocks.named_children()): + cname = "standardized_conv2d" + block_prefix = f"{prefix}block{i + 1}/unit{j + 1:02d}/" + block.conv1.weight.copy_(t2p(weights[f"{block_prefix}a/{cname}/kernel"])) + block.conv2.weight.copy_(t2p(weights[f"{block_prefix}b/{cname}/kernel"])) + block.conv3.weight.copy_(t2p(weights[f"{block_prefix}c/{cname}/kernel"])) + block.norm1.weight.copy_(t2p(weights[f"{block_prefix}a/group_norm/gamma"])) + block.norm2.weight.copy_(t2p(weights[f"{block_prefix}b/group_norm/gamma"])) + block.norm3.weight.copy_(t2p(weights[f"{block_prefix}c/group_norm/gamma"])) + block.norm1.bias.copy_(t2p(weights[f"{block_prefix}a/group_norm/beta"])) + block.norm2.bias.copy_(t2p(weights[f"{block_prefix}b/group_norm/beta"])) + block.norm3.bias.copy_(t2p(weights[f"{block_prefix}c/group_norm/beta"])) + if block.downsample is not None: + w = weights[f"{block_prefix}a/proj/{cname}/kernel"] + block.downsample.conv.weight.copy_(t2p(w)) diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index f9a9528eb0..bcf60e98a2 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -40,6 +40,7 @@ def build_head(config): from .rec_visionlan_head import VLHead from .rec_rfl_head import RFLHead from .rec_can_head import CANHead + from .rec_latexocr_head import LaTeXOCRHead from .rec_satrn_head import SATRNHead from .rec_parseq_head 
import ParseQHead from .rec_cppd_head import CPPDHead @@ -81,6 +82,7 @@ def build_head(config): "RFLHead", "DRRGHead", "CANHead", + "LaTeXOCRHead", "SATRNHead", "PFHeadLocal", "ParseQHead", diff --git a/ppocr/modeling/heads/rec_latexocr_head.py b/ppocr/modeling/heads/rec_latexocr_head.py new file mode 100644 index 0000000000..4e368da0dd --- /dev/null +++ b/ppocr/modeling/heads/rec_latexocr_head.py @@ -0,0 +1,1027 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code is refer from: +https://github.com/lukas-blecher/LaTeX-OCR/blob/main/pix2tex/models/transformer.py +""" + +import math +import paddle +from paddle import nn, einsum +import paddle.nn.functional as F +from functools import partial +from inspect import isfunction +from collections import namedtuple + +from paddle.nn.initializer import ( + TruncatedNormal, + Constant, + Normal, + KaimingUniform, + XavierUniform, +) + +zeros_ = Constant(value=0.0) +ones_ = Constant(value=1.0) +normal_ = Normal(std=0.02) +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"]) + +LayerIntermediates = namedtuple("Intermediates", ["hiddens", "attn_intermediates"]) + +# helpers + + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +class always: + def __init__(self, val): + self.val = val + + def __call__(self, *args, **kwargs): + return self.val + + +class not_equals: + def __init__(self, val): + self.val = val + + def __call__(self, x, *args, **kwargs): + return x != self.val + + +class equals: + def __init__(self, val): + self.val = val + + def __call__(self, x, *args, **kwargs): + return x == self.val + + +def max_neg_value(tensor): + return -paddle.finfo(tensor.dtype).max + + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key( + partial(string_begins_with, prefix), d + ) + kwargs_without_prefix = dict( + map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items())) + ) + return kwargs_without_prefix, kwargs + + +# positional embeddings + + +class DepthWiseConv1d(nn.Layer): + def __init__( + self, dim_in, dim_out, kernel_size, padding=0, stride=1, bias=True, groups=False + ): + super().__init__() + groups = default(groups, dim_in) + self.net = nn.Sequential( + nn.Conv1D( + dim_in, + dim_in, + kernel_size=kernel_size, + padding=padding, + groups=dim_in, + stride=stride, + bias_attr=bias, + ), 
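+            # 1x1 pointwise projection to dim_out (the conv above is depthwise, groups=dim_in)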
+ nn.Conv1D(dim_in, dim_out, 1), + ) + + def forward(self, x): + return self.net(x) + + +class AbsolutePositionalEmbedding(nn.Layer): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + + normal_(self.emb.weight) + + def forward(self, x): + n = paddle.arange(x.shape[1]) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Layer): + def __init__(self, dim): + super().__init__() + inv_freq = 1.0 / (10000 ** (paddle.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = ( + paddle.arange( + x.shape[seq_dim], + ).type_as(self.inv_freq) + + offset + ) + sinusoid_inp = paddle.einsum("i , j -> i j", t, self.inv_freq) + emb = paddle.concat((sinusoid_inp.sin(), sinusoid_inp.cos()), axis=-1) + return emb[None, :, :] + + +class Scale(nn.Layer): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Layer): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = paddle.create_parameter([1], dtype="float32") + zeros_(self.g) + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Layer): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim**-0.5 + self.eps = eps + self.g = paddle.create_parameter([1], dtype="float32") + ones_(self.g) + + def forward(self, x): + norm = paddle.norm(x, axis=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Layer): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim**-0.5 + self.eps = eps + self.g = paddle.create_parameter([dim]) + ones_(self.g) + + def forward(self, x): + norm = paddle.norm(x, axis=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Layer): + def forward(self, x, residual): + return x + residual + + +class GEGLU(nn.Layer): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, axis=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Layer): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +class Attention(nn.Layer): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + collab_heads=False, + collab_compression=0.3, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0.0, + on_attn=False, + gate_values=False, + is_export=False, + ): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + self.causal = causal + self.mask = mask + self.is_export = is_export + + qk_dim = v_dim = dim_head * heads + + # collaborative heads + self.collab_heads = collab_heads + if self.collab_heads: + qk_dim = int(collab_compression * qk_dim) + self.collab_mixing = nn.Parameter(paddle.randn(heads, qk_dim)) + + self.to_q = 
nn.Linear(dim, qk_dim, bias_attr=False) + self.to_k = nn.Linear(dim, qk_dim, bias_attr=False) + self.to_v = nn.Linear(dim, v_dim, bias_attr=False) + + self.dropout = nn.Dropout(dropout) + + # add GLU gating for aggregated values, from alphafold2 + self.to_v_gate = None + if gate_values: + self.to_v_gate = nn.Linear(dim, v_dim) + zeros_(self.to_v_gate.weight) + ones_(self.to_v_gate.bias) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(paddle.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(paddle.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(paddle.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(paddle.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = ( + nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) + if on_attn + else nn.Linear(v_dim, dim) + ) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + rotary_pos_emb=None, + prev_attn=None, + mem=None, + seq_len=0, + ): + b, n, _, h, talking_heads, collab_heads, has_context = ( + *x.shape, + self.heads, + self.talking_heads, + self.collab_heads, + exists(context), + ) + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = paddle.concat((mem, k_input), axis=-2) + v_input = paddle.concat((mem, v_input), axis=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + def rearrange_q_k_v(x, h, is_export): + if is_export: + b, n, h_d = paddle.shape(x) + else: + b, n, h_d = x.shape + d = h_d // h + return x.reshape([b, n, h, d]).transpose([0, 2, 1, 3]) + + q, k, v = map( + lambda t: rearrange_q_k_v(t, h, is_export=self.is_export), (q, k, v) + ) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default( + mask, + lambda: paddle.ones( + (b, n), + ).cast(paddle.bool), + ) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default( + k_mask, lambda: paddle.ones((b, k.shape[-2])).cast(paddle.bool) + ) + + q_mask = q_mask.reshape([q_mask.shape[0], 1, q_mask.shape[1], 1]) + k_mask = k_mask.reshape([k_mask.shape[0], 1, 1, k_mask.shape[1]]) + input_mask = q_mask * k_mask + + if collab_heads: + k = k.expand(-1, h, -1, -1) + dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale + + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots.clone() + + if talking_heads: + dots = einsum( + "b h i j, h k -> b k i j", dots, self.pre_softmax_proj + ).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + input_mask = input_mask.cast(paddle.bool) + if exists(input_mask): + + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = paddle.arange(i) + r_shape = r.shape[0] + mask = r.reshape([1, 1, r_shape, 1]) < r.reshape([1, 1, 1, r_shape]) + + if self.is_export: + pad_list = [ + paddle.to_tensor(0, dtype="int32"), + 
paddle.to_tensor(0, dtype="int32"), + paddle.to_tensor(j - i, dtype="int32"), + paddle.to_tensor(0, dtype="int32"), + ] + mask = F.pad( + mask.cast(paddle.int32), + paddle.to_tensor(pad_list).cast(paddle.int32), + value=False, + ).cast(paddle.bool) + dots = dots.masked_fill_(mask, mask_value) + else: + mask = F.pad(mask.cast(paddle.int32), (0, 0, j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, axis=-1) + post_softmax_attn = attn.clone() + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum( + "b h i j, h k -> b k i j", attn, self.post_softmax_proj + ).contiguous() + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + b, h, n, d = out.shape + out = out.transpose([0, 2, 1, 3]).reshape([b, n, h * d]) + + if exists(self.to_v_gate): + gates = self.gate_v(x) + out = out * gates.sigmoid() + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Layer): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_bias=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + rotary_pos_emb=False, + rotary_emb_dim=None, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + is_export=False, + **kwargs, + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs) + attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs) + + dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.LayerList([]) + + self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb + self.pia_pos_emb = ( + FixedPositionalEmbedding(dim) if position_infused_attn else None + ) + + assert ( + rel_pos_num_buckets <= rel_pos_max_distance + ), "number of relative position buckets must be less than the relative position max distance" + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + self.cross_attend = cross_attend + self.rel_pos = None + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ("a", "c", "f") + elif cross_attend and only_cross: + default_block = ("c", "f") + else: + default_block = ("a", "f") + if macaron: + default_block = ("f",) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, "par ratio out of range" + default_block = tuple(filter(not_equals("f"), default_block)) + par_attn = par_depth // par_ratio + depth_cut = ( + par_depth * 2 // 3 + ) # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert ( + 
len(default_block) <= par_width + ), "default block is too large for par_ratio" + par_block = default_block + ("f",) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ("f",) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert ( + sandwich_coef > 0 and sandwich_coef <= depth + ), "sandwich coefficient should be less than the depth" + layer_types = ( + ("a",) * sandwich_coef + + default_block * (depth - sandwich_coef) + + ("f",) * sandwich_coef + ) + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals("a"), layer_types))) + for layer_type in self.layer_types: + if layer_type == "a": + layer = Attention( + dim, heads=heads, causal=causal, is_export=is_export, **attn_kwargs + ) + elif layer_type == "c": + layer = Attention(dim, heads=heads, is_export=is_export, **attn_kwargs) + elif layer_type == "f": + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f"invalid layer type {layer_type}") + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + residual_fn = Residual() + self.layers.append(nn.LayerList([norm_fn(), layer, residual_fn])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + seq_len=0, + return_hiddens=False, + ): + assert not ( + self.cross_attend ^ exists(context) + ), "context must be passed in if cross_attend is set to True" + + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + rotary_pos_emb = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate( + zip(self.layer_types, self.layers) + ): + is_last = ind == (len(self.layers) - 1) + + if layer_type == "a": + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == "a": + out, inter = block( + x, + mask=mask, + sinusoidal_emb=self.pia_pos_emb, + rel_pos=self.rel_pos, + rotary_pos_emb=rotary_pos_emb, + prev_attn=prev_attn, + mem=layer_mem, + ) + elif layer_type == "c": + out, inter = block( + x, + context=context, + mask=mask, + context_mask=context_mask, + prev_attn=prev_cross_attn, + ) + elif layer_type == "f": + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ("a", "c"): + intermediates.append(inter) + + if layer_type == "a" and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == "c" and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, attn_intermediates=intermediates + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert "causal" not in kwargs, "cannot set causality on encoder" + super().__init__(causal=False, **kwargs) + + +class Decoder(AttentionLayers): + def __init__(self, **kwargs): + assert "causal" not in kwargs, "cannot set causality on decoder" + super().__init__(causal=True, **kwargs) + + +class CrossAttender(AttentionLayers): + def __init__(self, **kwargs): + super().__init__(cross_attend=True, only_cross=True, **kwargs) + + +def create_latex_parameter(shape): + return paddle.create_parameter( + shape=shape, + dtype="float32", + 
default_initializer=paddle.nn.initializer.Assign(paddle.randn(shape)), + ) + + +class TransformerDecoder(nn.Layer): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0.0, + emb_dropout=0.0, + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True, + is_export=False, + ): + super().__init__() + assert isinstance( + attn_layers, AttentionLayers + ), "attention layers must be one of Encoder or Decoder" + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = ( + AbsolutePositionalEmbedding(emb_dim, max_seq_len) + if (use_pos_emb and not attn_layers.has_pos_emb) + else always(0) + ) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + self.is_export = is_export + + self.init_() + + self.to_logits = ( + nn.Linear(dim, num_tokens) + if not tie_embedding + else lambda t: t @ self.token_emb.weight.t() + ) + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = create_latex_parameter([num_memory_tokens, dim]) + + # let funnel encoder know number of memory tokens, if specified + # TODO: think of a cleaner solution + if hasattr(attn_layers, "num_memory_tokens"): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + normal_(self.token_emb.weight) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + seq_len=0, + mems=None, + **kwargs, + ): + b, n, num_mem = *x.shape, self.num_memory_tokens + x = self.token_emb(x) + x = x + self.pos_emb(x) + + x = self.emb_dropout(x) + x = self.project_emb(x) + + x, intermediates = self.attn_layers( + x, mask=mask, mems=mems, return_hiddens=True, seq_len=seq_len, **kwargs + ) + x = self.norm(x) + mem, x = x[:, :num_mem], x[:, num_mem:] + out = self.to_logits(x) if not return_embeddings else x + if return_mems: + hiddens = intermediates.hiddens + new_mems = ( + list(map(lambda pair: paddle.concat(pair, axis=-2), zip(mems, hiddens))) + if exists(mems) + else hiddens + ) + new_mems = list( + map(lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems) + ) + return out, new_mems + + if return_attn: + attn_maps = list( + map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates) + ) + return out, attn_maps + + return out + + +def top_p(logits, thres=0.9): + sorted_logits, sorted_indices = paddle.sort(logits, descending=True) + cum_probs = paddle.cumsum(F.softmax(sorted_logits, axis=-1), axis=-1) + + sorted_indices_to_remove = cum_probs > (1 - thres) + sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() + sorted_indices_to_remove[:, 0] = 0 + + sorted_logits[sorted_indices_to_remove] = float("-inf") + return sorted_logits.scatter(1, sorted_indices, sorted_logits) + + +# topk + + +def top_k(logits, thres=0.9): + k = int((1 - thres) * logits.shape[-1]) + val, ind = paddle.topk(logits, k) + probs = paddle.full_like(logits, float("-inf")) + probs = paddle.put_along_axis(probs, ind, val, 1) + return probs + + +class LaTeXOCRHead(nn.Layer): + """Implementation of LaTeX OCR decoder. 
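+
+    At inference time the head starts from the <bos> token and decodes
+    autoregressively via generate() (top-k filtered sampling, temperature
+    0.333); at training time it runs a single teacher-forced pass over the
+    shifted target sequence. A rough, illustrative sketch:
+
+        head = LaTeXOCRHead(decoder_args={}, is_export=False)
+        tokens = head(encoded_feat)                   # eval mode
+        logits = head((encoded_feat, tgt_seq, mask))  # train mode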
+ + Args: + encoded_feat: The encoded features with shape[N, 1, H//16, W//16] + tgt_seq: LaTeX-OCR labels with shape [N, L] , L is the max sequence length + xi: The first N-1 LaTeX-OCR sequences in tgt_seq with shape [N, L-1] + mask: The first N-1 LaTeX-OCR attention mask with shape [N, L-1] , L is the max sequence length + + Returns: + The predicted LaTeX sequences with shape [N, L-1, C], C is the number of LaTeX classes + """ + + def __init__( + self, + net=None, + in_channels=256, + out_channels=256, + pad_value=0, + decoder_args=None, + is_export=False, + ): + super().__init__() + decoder = Decoder( + dim=256, depth=4, heads=8, is_export=is_export, **decoder_args + ) + transformer_decoder = TransformerDecoder( + num_tokens=8000, + max_seq_len=512, + attn_layers=decoder, + is_export=is_export, + ) + self.temperature = 0.333 + self.bos_token = 1 + self.eos_token = 2 + self.max_length = 512 + self.pad_value = pad_value + + self.net = transformer_decoder + self.max_seq_len = self.net.max_seq_len + self.is_export = is_export + + @paddle.no_grad() + def generate( + self, + start_tokens, + seq_len, + eos_token=None, + temperature=1.0, + filter_logits_fn=top_k, + filter_thres=0.9, + **kwargs, + ): + was_training = self.net.training + num_dims = len(start_tokens.shape) + + if num_dims == 1: + start_tokens = start_tokens[None, :] + + b, t = start_tokens.shape + + self.net.eval() + out = start_tokens + mask = kwargs.pop("mask", None) + + if mask is None: + mask = paddle.full_like(out, True, dtype=paddle.bool) + + for _ in range(seq_len): + x = out[:, -self.max_seq_len :] + mask = mask[:, -self.max_seq_len :] + logits = self.net(x, mask=mask, **kwargs)[:, -1, :] + if filter_logits_fn in {top_k, top_p}: + filtered_logits = filter_logits_fn(logits, thres=filter_thres) + + probs = F.softmax(filtered_logits / temperature, axis=-1) + else: + raise NotImplementedError("The filter_logits_fn is not supported ") + + sample = paddle.multinomial(probs, 1) + out = paddle.concat((out, sample), axis=-1) + pad_mask = paddle.full(shape=[mask.shape[0], 1], fill_value=1, dtype="bool") + mask = paddle.concat((mask, pad_mask), axis=1) + if ( + eos_token is not None + and ( + paddle.cumsum((out == eos_token).cast(paddle.int64), 1)[:, -1] >= 1 + ).all() + ): + break + out = out[:, t:] + if num_dims == 1: + out = out.squeeze(0) + return out + + @paddle.no_grad() + def generate_export( + self, + start_tokens, + seq_len, + eos_token=None, + context=None, + temperature=1.0, + filter_logits_fn=None, + filter_thres=0.9, + **kwargs, + ): + was_training = self.net.training + num_dims = len(start_tokens.shape) + + if num_dims == 1: + start_tokens = start_tokens[None, :] + + b, t = start_tokens.shape + + self.net.eval() + out = start_tokens + mask = kwargs.pop("mask", None) + + if mask is None: + mask = paddle.full_like(out, True, dtype=paddle.bool) + + i_idx = paddle.full([], 0) + while i_idx < paddle.to_tensor(seq_len): + x = out[:, -self.max_seq_len :] + paddle.jit.api.set_dynamic_shape(x, [-1, -1]) + mask = mask[:, -self.max_seq_len :] + paddle.jit.api.set_dynamic_shape(mask, [-1, -1]) + logits = self.net(x, mask=mask, context=context, seq_len=i_idx, **kwargs)[ + :, -1, : + ] + if filter_logits_fn in {top_k, top_p}: + filtered_logits = filter_logits_fn(logits, thres=filter_thres) + + probs = F.softmax(filtered_logits / temperature, axis=-1) + + sample = paddle.multinomial(probs, 1) + out = paddle.concat((out, sample), axis=-1) + + pad_mask = paddle.full(shape=[mask.shape[0], 1], fill_value=1, dtype="bool") + mask = 
paddle.concat((mask, pad_mask), axis=1) + if ( + eos_token is not None + and ( + paddle.cumsum((out == eos_token).cast(paddle.int64), 1)[:, -1] >= 1 + ).all() + ): + out = out[:, t:] + if num_dims == 1: + out = out.squeeze(0) + return out + i_idx += 1 + out = out[:, t:] + if num_dims == 1: + out = out.squeeze(0) + return out + + # forward for export + def forward(self, inputs, targets=None): + if not self.training: + encoded_feat = inputs + batch_num = encoded_feat.shape[0] + bos_tensor = paddle.full([batch_num, 1], self.bos_token, dtype=paddle.int64) + if self.is_export: + word_pred = self.generate_export( + bos_tensor, + self.max_seq_len, + eos_token=self.eos_token, + context=encoded_feat, + temperature=self.temperature, + filter_logits_fn=top_k, + ) + else: + word_pred = self.generate( + bos_tensor, + self.max_seq_len, + eos_token=self.eos_token, + context=encoded_feat, + temperature=self.temperature, + filter_logits_fn=top_k, + ) + return word_pred + + encoded_feat, tgt_seq, mask = inputs + kwargs = {"context": encoded_feat, "mask": mask.cast(paddle.bool)} + x = tgt_seq + xi = x[:, :-1] + + mask = kwargs.get("mask", None) + if mask is not None and mask.shape[1] == x.shape[1]: + mask = mask[:, :-1] + kwargs["mask"] = mask + out = self.net(xi, **kwargs) + + return out diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index e0a6a87fd3..04579a376a 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -42,6 +42,7 @@ from .rec_postprocess import ( SATRNLabelDecode, ParseQLabelDecode, CPPDLabelDecode, + LaTeXOCRDecode, ) from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess @@ -96,6 +97,7 @@ def build_post_process(config, global_config=None): "SATRNLabelDecode", "ParseQLabelDecode", "CPPDLabelDecode", + "LaTeXOCRDecode", ] if config["name"] == "PSEPostProcess": diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 46b629d531..a81d62f4d8 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -15,6 +15,7 @@ import numpy as np import paddle from paddle.nn import functional as F +from tokenizers import Tokenizer as TokenizerFast import re @@ -1210,3 +1211,53 @@ class CPPDLabelDecode(NRTRLabelDecode): def add_special_char(self, dict_character): dict_character = [""] + dict_character return dict_character + + +class LaTeXOCRDecode(object): + """Convert between latex-symbol and symbol-index""" + + def __init__(self, rec_char_dict_path, **kwargs): + super(LaTeXOCRDecode, self).__init__() + self.tokenizer = TokenizerFast.from_file(rec_char_dict_path) + + def post_process(self, s): + text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? 
{.*?})" + letter = "[a-zA-Z]" + noletter = "[\W_^\d]" + names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)] + s = re.sub(text_reg, lambda match: str(names.pop(0)), s) + news = s + while True: + s = news + news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s) + news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news) + news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news) + if news == s: + break + return s + + def decode(self, tokens): + if len(tokens.shape) == 1: + tokens = tokens[None, :] + dec = [self.tokenizer.decode(tok) for tok in tokens] + dec_str_list = [ + "".join(detok.split(" ")) + .replace("Ġ", " ") + .replace("[EOS]", "") + .replace("[BOS]", "") + .replace("[PAD]", "") + .strip() + for detok in dec + ] + return [self.post_process(dec_str) for dec_str in dec_str_list] + + def __call__(self, preds, label=None, mode="eval", *args, **kwargs): + if mode == "train": + preds_idx = np.array(preds.argmax(axis=2)) + text = self.decode(preds_idx) + else: + text = self.decode(np.array(preds)) + if label is None: + return text + label = self.decode(np.array(label)) + return text, label diff --git a/ppocr/utils/dict/latex_ocr_tokenizer.json b/ppocr/utils/dict/latex_ocr_tokenizer.json new file mode 100644 index 0000000000..e8fd4f6d82 --- /dev/null +++ b/ppocr/utils/dict/latex_ocr_tokenizer.json @@ -0,0 +1 @@ +{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[PAD]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[BOS]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[EOS]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"ByteLevel","add_prefix_space":false,"trim_offsets":true},"post_processor":null,"decoder":null,"model":{"dropout":null,"unk_token":null,"continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[PAD]":0,"[BOS]":1,"[EOS]":2,"!":3,"\"":4,"#":5,"$":6,"&":7,"'":8,"(":9,")":10,"*":11,"+":12,",":13,"-":14,".":15,"/":16,"0":17,"1":18,"2":19,"3":20,"4":21,"5":22,"6":23,"7":24,"8":25,"9":26,":":27,";":28,"<":29,"=":30,">":31,"?":32,"@":33,"A":34,"B":35,"C":36,"D":37,"E":38,"F":39,"G":40,"H":41,"I":42,"J":43,"K":44,"L":45,"M":46,"N":47,"O":48,"P":49,"Q":50,"R":51,"S":52,"T":53,"U":54,"V":55,"W":56,"X":57,"Y":58,"Z":59,"[":60,"\\":61,"]":62,"^":63,"_":64,"`":65,"a":66,"b":67,"c":68,"d":69,"e":70,"f":71,"g":72,"h":73,"i":74,"j":75,"k":76,"l":77,"m":78,"n":79,"o":80,"p":81,"q":82,"r":83,"s":84,"t":85,"u":86,"v":87,"w":88,"x":89,"y":90,"z":91,"{":92,"|":93,"}":94,"~":95,"½":96,"¿":97,"ï":98,"Ċ":99,"č":100,"Ġ":101,"Ġ}":102,"Ġ{":103,"Ġ\\":104,"Ġ_":105,"Ġ^":106,"Ġ2":107,"Ġ)":108,"Ġ(":109,"Ġ1":110,"ra":111,"Ġ=":112,"Ġ-":113,"čĊ":114,"Ġ,":115,"fra":116,"frac":117,"Ġ+":118,"ma":119,"ta":120,"ig":121,"Ġ0":122,"ar":123,"al":124,"le":125,"Ġi":126,"th":127,"Ġx":128,"ft":129,"igh":130,"me":131,"righ":132,"math":133,"Ġn":134,"Ġ.":135,"Ġ\\,":136,"in":137,"ph":138,"Ġd":139,"left":140,"Ġa":141,"right":142,"am":143,"eta":144,"ti":145,"Ġm":146,"mu":147,"Ġ3":148,"Ġk":149,"Ġt":150,"Ġe":151,"Ġr":152,"Ġ&":153,"Ġc":154,"Ġp":155,"si":156,"rm":157,"de":158,"mathrm":159,"Ġ4":160,"Ġs":161,"pr":162,"Ġ~":163,"pha":164,"Ġl":165,"alpha":166,"da":167,"ĠA":168,"Ġ\\;":169,"ot":170,"pi":171,"par":172,"tial":173,"partial":174,"ime":175,"prime":176,"psi":177,"dot":178,"Ġj":179,"Ġb":180,"Ġf":181,"lta":182,"Ġ|":183,"a
mma":184,"bda":185,"ambda":186,"phi":187,"Ġq":188,"bf":189,"Ġg":190,"nu":191,"Ġz":192,"ray":193,"array":194,"ĠM":195,"ĠT":196,"Ġy":197,"cal":198,"bar":199,"ĠN":200,"igma":201,"ĠR":202,"rt":203,"lambda":204,"int":205,"ĠB":206,"ve":207,"ng":208,"qu":209,"ĠL":210,"Ġ/":211,"lo":212,"beta":213,"ngle":214,"Ġu":215,"delta":216,"sq":217,"sqrt":218,"theta":219,"Ġ\\\\":220,"gamma":221,"Ġ]":222,"sigma":223,"ga":224,"mega":225,"ĠD":226,"ĠF":227,"Ġ[":228,"ĠS":229,"mathbf":230,"su":231,"ĠP":232,"lon":233,"Ġv":234,"sum":235,"psilon":236,"ĠV":237,"ĠC":238,"cdot":239,"epsilon":240,"at":241,"hat":242,"ad":243,"quad":244,"Ġh":245,"ho":246,"rho":247,"hi":248,"to":249,"ĠE":250,"la":251,"ĠH":252,"lde":253,"tilde":254,"ĠQ":255,"Ġ5":256,"var":257,"ĠX":258,"ĠG":259,"be":260,"nd":261,"omega":262,"end":263,"gin":264,"begin":265,"tau":266,"Ġ6":267,"na":268,"vec":269,"ĠI":270,"Ġo":271,"rangle":272,"Ġ*":273,"De":274,"Delta":275,"Gamma":276,"pe":277,"fty":278,"infty":279,"ĠK":280,"xi":281,"Ġ8":282,"ow":283,"ĠJ":284,"ĠU":285,"row":286,"tar":287,"ge":288,"Phi":289,"ap":290,"ĠW":291,"co":292,"mes":293,"times":294,"sin":295,"ĠZ":296,"langle":297,"ope":298,"rna":299,"rato":300,"operato":301,"rname":302,"operatorname":303,"tarrow":304,"lin":305,"line":306,"varphi":307,"pm":308,"rline":309,"Lambda":310,"Ġ\\!":311,"Ġ;":312,"dots":313,"cos":314,"Ġw":315,"rightarrow":316,"big":317,"chi":318,"ove":319,"Ġ\\}":320,"overline":321,"Ġ7":322,"ex":323,"pa":324,"st":325,"pro":326,"qquad":327,"iv":328,"equ":329,"equiv":330,"ĠO":331,"ln":332,"Omega":333,"ll":334,"Ġ9":335,"kap":336,"kappa":337,"Big":338,"Ġ\\{":339,"dag":340,"ĠY":341,"\\{":342,"varepsilon":343,"cdots":344,"Ġ:":345,"mathcal":346,"Psi":347,"Ġ>":348,"bo":349,"bol":350,"Ġ<":351,"ger":352,"dagger":353,"ldots":354,"ell":355,"bla":356,"nabla":357,"exp":358,"yle":359,"style":360,"zeta":361,"Sigma":362,"wi":363,"wide":364,"sim":365,"leq":366,"Ġ!":367,"bigg":368,"mathb":369,"mathbb":370,"Ġ\\:":371,"hbar":372,"otimes":373,"bold":374,"\\}":375,"mi":376,"prox":377,"approx":378,"Pi":379,"log":380,"mid":381,"sp":382,"vert":383,"di":384,"prod":385,"per":386,"perp":387,"ystyle":388,"laystyle":389,"splaystyle":390,"displaystyle":391,"meq":392,"simeq":393,"ed":394,"wed":395,"wedge":396,"widetilde":397,"sy":398,"sym":399,"symbol":400,"boldsymbol":401,"ck":402,"tex":403,"text":404,"ri":405,"Th":406,"Theta":407,"geq":408,"se":409,"eq":410,"nde":411,"unde":412,"tan":413,"sc":414,"ast":415,"rc":416,"set":417,"pt":418,"widehat":419,"ci":420,"circ":421,"re":422,"ript":423,"script":424,"underline":425,"Ġ\\|":426,"rel":427,"neq":428,"sta":429,"stack":430,"stackrel":431,"sinh":432,"op":433,"us":434,"cosh":435,"Bigg":436,"ce":437,"textstyle":438,"star":439,"not":440,"frak":441,"mathfrak":442,"mp":443,"biggr":444,"lus":445,"oplus":446,"vartheta":447,"biggl":448,"Bigr":449,"bra":450,"Bigl":451,"fo":452,"sf":453,"sub":454,"subset":455,"ngrightarrow":456,"ec":457,"boldmath":458,"rall":459,"forall":460,"scriptstyle":461,"ect":462,"parrow":463,"uparrow":464,"bj":465,"bject":466,"pto":467,"propto":468,"Ġ'":469,"longrightarrow":470,"bigl":471,"bigr":472,"oint":473,"ps":474,"maps":475,"mapsto":476,"om":477,"lle":478,"\\|":479,"ddot":480,"cu":481,"bin":482,"binom":483,"vdots":484,"angle":485,"leftrightarrow":486,"over":487,"or":488,"mathsf":489,"cup":490,"brace":491,"no":492,"arc":493,"flo":494,"floor":495,"tri":496,"triangle":497,"Xi":498,"cot":499,"bot":500,"cong":501,"it":502,"mbe":503,"numbe":504,"nonumbe":505,"nonumber":506,"cap":507,"Righ":508,"Rightarrow":509,"ze":510,"size":511,"textrm":512,"ne":513,"arctan"
:514,"ralle":515,"paralle":516,"parallel":517,"cfrac":518,"Ġ--":519,"object":520,"ĠObject":521,"brack":522,"sh":523,"arrow":524,"own":525,"varrho":526,"subseteq":527,"rbrace":528,"textbf":529,"imath":530,"od":531,"down":532,"he":533,"land":534,"scriptscriptstyle":535,"scriptsize":536,"che":537,"check":538,"sla":539,"overrightarrow":540,"downarrow":541,"Biggl":542,"gg":543,"nto":544,"phanto":545,"phantom":546,"exi":547,"hline":548,"sts":549,"exists":550,"Biggr":551,"bu":552,"rfloor":553,"ddots":554,"io":555,"iota":556,"llet":557,"bullet":558,"colon":559,"inus":560,"Upsilon":561,"lfloor":562,"lbrack":563,"underbrace":564,"neg":565,"Im":566,"mathit":567,"tin":568,"tiny":569,"jmath":570,"lef":571,"slash":572,"vee":573,"minus":574,"setminus":575,"Re":576,"iint":577,"leftarrow":578,"Ve":579,"Vert":580,"atop":581,"sup":582,"bigcup":583,"wp":584,"dim":585,"sec":586,"supset":587,"Lo":588,"lor":589,"pmod":590,"mod":591,"bigoplus":592,"il":593,"bmod":594,"coth":595,"Le":596,"ftrightarrow":597,"Leftrightarrow":598,"ngleftrightarrow":599,"sma":600,"upsilon":601,"\\,":602,"csc":603,"eph":604,"aleph":605,"bigwedge":606,"arcsin":607,"small":608,"odot":609,"overset":610,"rbrack":611,"mit":612,"lbrace":613,"li":614,"arp":615,"arge":616,"Ġ\\#":617,"bre":618,"textsf":619,"Longrightarrow":620,"breve":621,"em":622,"yset":623,"varpi":624,"ptyset":625,"emptyset":626,"ff":627,"iff":628,"nt":629,"er":630,"lap":631,"lnot":632,"dash":633,"under":634,"slant":635,"arg":636,"underset":637,"Bo":638,"Box":639,"Ġ\"":640,"spa":641,"space":642,"deg":643,"iiint":644,"oo":645,"otnot":646,"footnot":647,"arpoo":648,"footnote":649,"rlap":650,"es":651,"imp":652,"sb":653,"te":654,"bigtriangle":655,"lies":656,"implies":657,"\\;":658,"ker":659,"footnotesize":660,"tharpoo":661,"up":662,"acu":663,"acute":664,"longleftrightarrow":665,"eil":666,"lce":667,"rceil":668,"lceil":669,"vphantom":670,"en":671,"thin":672,"ack":673,"back":674,"tt":675,"backslash":676,"xrightarrow":677,"vdash":678,"top":679,"rightharpoo":680,"varsigma":681,"Longleftrightarrow":682,"mathop":683,"large":684,"bigcap":685,"leqslant":686,"Ġ`":687,"overbrace":688,"nup":689,"rightharpoonup":690,"bigotimes":691,"triangleq":692,"Large":693,"ru":694,"null":695,"bigtriangleup":696,"varno":697,"thing":698,"varnothing":699,"doteq":700,"Ġ\\_":701,"overleftarrow":702,"hf":703,"bigstar":704,"enspace":705,"\\!":706,"stru":707,"strut":708,"ominus":709,"div":710,"ond":711,"amond":712,"ddagger":713,"Ġcm":714,"ni":715,"sk":716,"diamond":717,"rVert":718,"prot":719,"protect":720,"ip":721,"varDelta":722,"notin":723,"skip":724,"lVert":725,"Ġ\\/":726,"dotsc":727,"ill":728,"ule":729,"\\:":730,"hfill":731,"krightarrow":732,"okrightarrow":733,"hookrightarrow":734,"sharp":735,"Vdash":736,"bigvee":737,"subsetneq":738,"supseteq":739,"Ġ?":740,"ngmapsto":741,"longmapsto":742,"cdotp":743,"geqslant":744,"bigtriangledown":745,"dotsb":746,"lim":747,"fl":748,"triangleleft":749,"flat":750,"sl":751,"box":752,"Ġ---":753,"sqcup":754,"jlim":755,"ls":756,"mo":757,"dels":758,"ref":759,"models":760,"tag":761,"Pr":762,"mal":763,"ou":764,"llap":765,"thinspace":766,"enskip":767,"Vec":768,"ebox":769,"kebox":770,"nor":771,"rd":772,"squ":773,"vline":774,"¿½":775,"�":776,"Ġ�":777,"makebox":778,"surd":779,"normal":780,"are":781,"square":782,"pou":783,"mathrel":784,"varOmega":785,"nds":786,"smallsetminus":787,"pounds":788,"ns":789,"ss":790,"smi":791,"mathor":792,"rightlef":793,"textup":794,"tharpoons":795,"smile":796,"mathord":797,"rightleftharpoons":798,"cc":799,"Ġ\\-":800,"succ":801,"ftarrow":802,"rtimes":803,"det":804
,"prec":805,"texttt":806,"oslash":807,"Ġ\\&":808,"arrowvert":809,"lg":810,"Ġmm":811,"inter":812,"ngleftarrow":813,"hfil":814,"intercal":815,"frow":816,"Ġ\\*":817,"frown":818,"mpe":819,"Ġpt":820,"varpro":821,"searrow":822,"bumpe":823,"varprojlim":824,"bumpeq":825,"Down":826,"SS":827,"cd":828,"ere":829,"gcd":830,"ohe":831,"tw":832,"leme":833,"there":834,"injlim":835,"tit":836,"adrightarrow":837,"varinjlim":838,"comp":839,"textit":840,"fore":841,"overleftrightarrow":842,"Downarrow":843,"oheadrightarrow":844,"twoheadrightarrow":845,"lement":846,"therefore":847,"complement":848,"ca":849,"thi":850,"longleftarrow":851,"bigm":852,"triangleright":853,"nearrow":854,"\\#":855,"nce":856,"ral":857,"cance":858,"thick":859,"cancel":860,"Uparrow":861,"nat":862,"ural":863,"mathstrut":864,"suit":865,"bigcirc":866,"smallskip":867,"diamondsuit":868,"normalsize":869,"natural":870,"gt":871,"less":872,"mathtt":873,"bigsqcup":874,"thicksim":875,"lesssim":876,"bow":877,"llde":878,"tie":879,"nullde":880,"miter":881,"limiter":882,"kern":883,"bowtie":884,"nulldelimiter":885,"nulldelimiterspace":886,"Da":887,"hphantom":888,"ro":889,"vDa":890,"barwedge":891,"beth":892,"eqno":893,"vDash":894,"AR":895,"Di":896,"GE":897,"LAR":898,"dskip":899,"ts":900,"Ġ@":901,"medskip":902,"ndown":903,"gets":904,"coprod":905,"dotsm":906,"smash":907,"rightharpoondown":908,"Diamond":909,"LARGE":910,"nrightarrow":911,"nleftrightarrow":912,"rsim":913,"rne":914,"warrow":915,"mathc":916,"corne":917,"textnormal":918,"preceq":919,"gtrsim":920,"roup":921,"corner":922,"Ġ\\[":923,"Ġ\\]":924,"mathope":925,"lefteq":926,"lose":927,"varkappa":928,"Bigm":929,"Biggm":930,"mathclose":931,"mathopen":932,"lefteqn":933,"Bar":934,"Ti":935,"lr":936,"swarrow":937,"uge":938,"vru":939,"xleftarrow":940,"mathnormal":941,"rightrightarrow":942,"rightleftarrow":943,"sqsubseteq":944,"succeq":945,"Tilde":946,"lrcorner":947,"vrule":948,"rightrightarrows":949,"rightleftarrows":950,"AA":951,"Hat":952,"ak":953,"ble":954,"dou":955,"hss":956,"min":957,"nright":958,"nleftarrow":959,"uph":960,"wbre":961,"allo":962,"side":963,"sqcap":964,"hom":965,"bigodot":966,"arpoonright":967,"blebarwedge":968,"doublebarwedge":969,"upharpoonright":970,"wbreak":971,"allowbreak":972,"sideset":973,"--":974,"Huge":975,"amal":976,"do":977,"fbox":978,"group":979,"hskip":980,"lse":981,"pprox":982,"rk":983,"rgroup":984,"rapprox":985,"Ġin":986,"arrayco":987,"sure":988,"varlim":989,"pmb":990,"cite":991,"substack":992,"leftrightarrows":993,"supsetneq":994,"Longleftarrow":995,"updownarrow":996,"ensure":997,"lgroup":998,"gtrapprox":999,"amalg":1000,"lsep":1001,"arraycolsep":1002,"ensuremath":1003,"asym":1004,"ch":1005,"dig":1006,"ddag":1007,"ew":1008,"gra":1009,"gime":1010,"jo":1011,"ltimes":1012,"nleq":1013,"tch":1014,"frame":1015,"max":1016,"thde":1017,"inrel":1018,"ver":1019,"withde":1020,"ointop":1021,"notag":1022,"smallint":1023,"skew":1024,"lims":1025,"asymp":1026,"digamma":1027,"grave":1028,"gimel":1029,"joinrel":1030,"framebox":1031,"withdelims":1032,"Ar":1033,"Rrightarrow":1034,"ae":1035,"ag":1036,"fill":1037,"hspace":1038,"huge":1039,"lq":1040,"nwarrow":1041,"wline":1042,"Ġ14":1043,"mark":1044,"led":1045,"inf":1046,"inde":1047,"Ġex":1048,"pitch":1049,"dotsi":1050,"intop":1051,"rowvert":1052,"llcorner":1053,"black":1054,"leqq":1055,"biggm":1056,"approxeq":1057,"diag":1058,"textsc":1059,"textsl":1060,"circled":1061,"fork":1062,"cur":1063,"newline":1064,"negthick":1065,"atopwithdelims":1066,"Leftarrow":1067,"footnotemark":1068,"uplus":1069,"subsetneqq":1070,"---":1071,"varlimsup":1072,"varliminf":107
3,"verb":1074,"Arrowvert":1075,"pitchfork":1076,"blacksquare":1077,"diagup":1078,"negthickspace":1079,"23":1080,"25":1081,"\\-":1082,"\\/":1083,"ape":1084,"ckap":1085,"dddot":1086,"erline":1087,"ever":1088,"ij":1089,"ice":1090,"ly":1091,"md":1092,"nda":1093,"nnu":1094,"nmid":1095,"nRightarrow":1096,"nVdash":1097,"of":1098,"off":1099,"sho":1100,"spe":1101,"wr":1102,"ymath":1103,"Ġ#":1104,"Ġ\\'":1105,"Ġ\\^":1106,"Ġ10":1107,"Ġ15":1108,"mannu":1109,"igarrow":1110,"fter":1111,"meral":1112,"leftrightharpoo":1113,"rightsqu":1114,"def":1115,"arrayst":1116,"rtmid":1117,"interline":1118,"vearrow":1119,"ngeq":1120,"hoice":1121,"lax":1122,"varGamma":1123,"varpropto":1124,"vartriangle":1125,"varUpsilon":1126,"biguplus":1127,"expa":1128,"Ġ<$":1129,"mathbin":1130,"perca":1131,"textcircled":1132,"textmd":1133,"scsh":1134,"cial":1135,"retch":1136,"relax":1137,"overwithdelims":1138,"noinde":1139,"owns":1140,"veebar":1141,"underbar":1142,"underrightarrow":1143,"upperca":1144,"backsimeq":1145,"trianglelefteq":1146,"boxtimes":1147,"boxed":1148,"preccur":1149,"thickap":1150,"root":1151,"romannu":1152,"mathchoice":1153,"index":1154,"circledcirc":1155,"curvearrow":1156,"everymath":1157,"lyeq":1158,"ndafter":1159,"offinterline":1160,"shortmid":1161,"special":1162,"leftrightharpoons":1163,"rightsquigarrow":1164,"arraystretch":1165,"expandafter":1166,"scshape":1167,"noindent":1168,"uppercase":1169,"preccurlyeq":1170,"thickapprox":1171,"romannumeral":1172,"curvearrowright":1173,"offinterlineskip":1174},"merges":["Ġ }","Ġ {","Ġ \\","Ġ _","Ġ ^","Ġ 2","Ġ )","Ġ (","Ġ 1","r a","Ġ =","Ġ -","č Ċ","Ġ ,","f ra","fra c","Ġ +","m a","t a","i g","Ġ 0","a r","a l","l e","Ġ i","t h","Ġ x","f t","ig h","m e","r igh","ma th","Ġ n","Ġ .","Ġ\\ ,","i n","p h","Ġ d","le ft","Ġ a","righ t","a m","e ta","t i","Ġ m","m u","Ġ 3","Ġ k","Ġ t","Ġ e","Ġ r","Ġ &","Ġ c","Ġ p","s i","r m","d e","math rm","Ġ 4","Ġ s","p r","Ġ ~","ph a","Ġ l","al pha","d a","Ġ A","Ġ\\ ;","o t","p i","p ar","ti al","par tial","i me","pr ime","p si","d ot","Ġ j","Ġ b","Ġ f","l ta","Ġ |","am ma","b da","am bda","ph i","Ġ q","b f","Ġ g","n u","Ġ z","ra y","ar ray","Ġ M","Ġ T","Ġ y","c al","b ar","Ġ N","ig ma","Ġ R","r t","l ambda","in t","Ġ B","v e","n g","q u","Ġ L","Ġ /","l o","b eta","ng le","Ġ u","de lta","s q","sq rt","th eta","Ġ\\ \\","g amma","Ġ ]","s igma","g a","me ga","Ġ D","Ġ F","Ġ [","Ġ S","math bf","s u","Ġ P","lo n","Ġ v","su m","psi lon","Ġ V","Ġ C","c dot","e psilon","a t","h at","a d","qu ad","Ġ h","h o","r ho","h i","t o","Ġ E","l a","Ġ H","l de","ti lde","Ġ Q","Ġ 5","v ar","Ġ X","Ġ G","b e","n d","o mega","e nd","g in","be gin","ta u","Ġ 6","n a","ve c","Ġ I","Ġ o","ra ngle","Ġ *","D e","De lta","G amma","p e","ft y","in fty","Ġ K","x i","Ġ 8","o w","Ġ J","Ġ U","r ow","ta r","g e","P hi","a p","Ġ W","c o","me s","ti mes","s in","Ġ Z","la ngle","o pe","r na","ra to","ope rato","rna me","operato rname","tar row","l in","lin e","var phi","p m","r line","L ambda","Ġ\\ !","Ġ ;","dot s","co s","Ġ w","righ tarrow","b ig","c hi","o ve","Ġ\\ }","ove rline","Ġ 7","e x","p a","s t","pr o","q quad","i v","e qu","equ iv","Ġ O","l n","O mega","l l","Ġ 9","k ap","kap pa","B ig","Ġ\\ {","da g","Ġ Y","\\ {","var epsilon","cdot s","Ġ :","math cal","P si","Ġ >","b o","bo l","Ġ <","ge r","dag ger","l dots","e ll","b la","na bla","ex p","y le","st yle","z eta","S igma","w i","wi de","si m","le q","Ġ !","big g","math b","mathb b","Ġ\\ :","h bar","o times","bol d","\\ }","m i","pro x","ap prox","P i","lo g","mi d","s p","ve rt","d i","pro d","pe r","per p","y style","la 
ystyle","sp laystyle","di splaystyle","me q","si meq","e d","w ed","wed ge","wide tilde","s y","sy m","sym bol","bold symbol","c k","t ex","tex t","r i","T h","Th eta","ge q","s e","e q","n de","u nde","ta n","s c","a st","r c","se t","p t","wide hat","c i","ci rc","r e","ri pt","sc ript","unde rline","Ġ\\ |","re l","n eq","s ta","sta ck","stack rel","sin h","o p","u s","cos h","Big g","c e","text style","s tar","n ot","fra k","math frak","m p","bigg r","l us","op lus","var theta","bigg l","Big r","b ra","Big l","f o","s f","su b","sub set","ng rightarrow","e c","bold math","ra ll","fo rall","script style","ec t","par row","u parrow","b j","bj ect","p to","pro pto","Ġ '","lo ngrightarrow","big l","big r","o int","p s","ma ps","maps to","o m","l le","\\ |","d dot","c u","b in","bin om","v dots","a ngle","left rightarrow","ove r","o r","math sf","cu p","bra ce","n o","ar c","f lo","flo or","t ri","tri angle","X i","c ot","b ot","co ng","i t","m be","nu mbe","no numbe","nonumbe r","c ap","R igh","Righ tarrow","z e","si ze","text rm","n e","arc tan","ra lle","pa ralle","paralle l","c frac","Ġ- -","o bject","ĠO bject","bra ck","s h","ar row","ow n","var rho","subset eq","r brace","text bf","i math","o d","d own","h e","la nd","script scriptstyle","script size","c he","che ck","s la","over rightarrow","down arrow","Bigg l","g g","n to","pha nto","phanto m","e xi","h line","st s","exi sts","Bigg r","b u","r floor","d dots","i o","io ta","lle t","bu llet","co lon","in us","U psilon","l floor","l brack","unde rbrace","ne g","I m","math it","t in","tin y","j math","le f","sla sh","ve e","m inus","set minus","R e","i int","lef tarrow","V e","Ve rt","at op","su p","big cup","w p","di m","se c","sup set","L o","lo r","pm od","m od","big oplus","i l","b mod","co th","L e","ft rightarrow","Le ftrightarrow","ng leftrightarrow","s ma","u psilon","\\ ,","c sc","e ph","al eph","big wedge","arc sin","sma ll","o dot","over set","r brack","mi t","l brace","l i","ar p","ar ge","Ġ\\ #","b re","text sf","Lo ngrightarrow","bre ve","e m","y set","var pi","pt yset","em ptyset","f f","i ff","n t","e r","la p","ln ot","da sh","unde r","sla nt","ar g","under set","B o","Bo x","Ġ \"","s pa","spa ce","de g","i iint","o o","ot not","fo otnot","arp oo","footnot e","r lap","e s","i mp","s b","t e","big triangle","li es","imp lies","\\ ;","k er","footnote size","th arpoo","u p","a cu","acu te","lo ngleftrightarrow","e il","l ce","rc eil","lce il","v phantom","e n","th in","a ck","b ack","t t","back slash","x rightarrow","v dash","to p","righ tharpoo","var sigma","Lo ngleftrightarrow","math op","l arge","big cap","leq slant","Ġ `","over brace","nu p","rightharpoo nup","big otimes","triangle q","L arge","r u","nu ll","bigtriangle up","var no","thin g","varno thing","dot eq","Ġ\\ _","over leftarrow","h f","big star","en space","\\ !","st ru","stru t","om inus","d iv","o nd","am ond","d dagger","Ġc m","n i","s k","di amond","r Vert","pr ot","prot ect","i p","var Delta","not in","sk ip","l Vert","Ġ\\ /","dots c","i ll","u le","\\ :","hf ill","k rightarrow","o krightarrow","ho okrightarrow","sh arp","V dash","big vee","subset neq","supset eq","Ġ ?","ng mapsto","lo ngmapsto","cdot p","geq slant","bigtriangle down","dots b","li m","f l","triangle left","fl at","s l","bo x","Ġ-- -","sq cup","j lim","l s","m o","de ls","re f","mo dels","ta g","P r","ma l","o u","l lap","thin space","en skip","V ec","e box","k ebox","n or","r d","s qu","v line","¿ ½","ï ¿½","Ġ �","ma kebox","su rd","nor mal","ar e","squ are","p ou","math rel","var 
Omega","nd s","small setminus","pou nds","n s","s s","s mi","math or","right lef","text up","tharpoo ns","smi le","mathor d","rightlef tharpoons","c c","Ġ\\ -","su cc","f tarrow","r times","de t","pr ec","text tt","o slash","Ġ\\ &","arrow vert","l g","Ġm m","int er","ngle ftarrow","hf il","inter cal","f row","Ġ\\ *","frow n","m pe","Ġp t","var pro","se arrow","bu mpe","varpro jlim","bumpe q","D own","S S","c d","e re","g cd","o he","t w","le me","th ere","in jlim","ti t","ad rightarrow","var injlim","co mp","tex tit","fo re","over leftrightarrow","Down arrow","ohe adrightarrow","tw oheadrightarrow","leme nt","there fore","comp lement","c a","th i","lo ngleftarrow","big m","triangle right","ne arrow","\\ #","n ce","ra l","ca nce","thi ck","cance l","U parrow","n at","u ral","math strut","su it","big circ","small skip","diamond suit","normal size","nat ural","g t","le ss","math tt","big sqcup","thick sim","less sim","b ow","l lde","ti e","nu llde","mit er","li miter","ker n","bow tie","nullde limiter","nulldelimiter space","D a","h phantom","r o","v Da","bar wedge","be th","eq no","vDa sh","A R","D i","G E","L AR","d skip","t s","Ġ @","me dskip","nd own","ge ts","co prod","dots m","sma sh","rightharpoo ndown","Di amond","LAR GE","n rightarrow","n leftrightarrow","r sim","r ne","w arrow","math c","co rne","text normal","prec eq","gt rsim","ro up","corne r","Ġ\\ [","Ġ\\ ]","math ope","left eq","lo se","var kappa","Big m","Bigg m","mathc lose","mathope n","lefteq n","B ar","T i","l r","s warrow","u ge","v ru","x leftarrow","math normal","right rightarrow","right leftarrow","sq subseteq","succ eq","Ti lde","lr corner","vru le","rightrightarrow s","rightleftarrow s","A A","H at","a k","b le","d ou","h ss","m in","n right","n leftarrow","u ph","w bre","al lo","si de","sq cap","ho m","big odot","arpoo nright","ble barwedge","dou blebarwedge","uph arpoonright","wbre ak","allo wbreak","side set","- -","H uge","a mal","d o","f box","g roup","h skip","l se","p prox","r k","r group","ra pprox","Ġi n","array co","su re","var lim","pm b","ci te","sub stack","leftrightarrow s","supset neq","Lo ngleftarrow","up downarrow","en sure","lg roup","gt rapprox","amal g","lse p","arrayco lsep","ensure math","a sym","c h","d ig","d dag","e w","g ra","g ime","j o","l times","n leq","t ch","fra me","ma x","th de","in rel","ve r","wi thde","oint op","no tag","small int","sk ew","lim s","asym p","dig amma","gra ve","gime l","jo inrel","frame box","withde lims","A r","R rightarrow","a e","a g","f ill","h space","h uge","l q","n warrow","w line","Ġ1 4","ma rk","le d","in f","in de","Ġe x","pi tch","dot si","int op","row vert","ll corner","bla ck","leq q","bigg m","approx eq","di ag","text sc","text sl","circ led","fo rk","cu r","ne wline","neg thick","atop withdelims","Le ftarrow","footnote mark","up lus","subsetneq q","-- -","varlim sup","varlim inf","ver b","Ar rowvert","pitch fork","black square","diag up","negthick space","2 3","2 5","\\ -","\\ /","a pe","c kap","d ddot","e rline","e ver","i j","i ce","l y","m d","n da","n nu","n mid","n Rightarrow","n Vdash","o f","o ff","s ho","s pe","w r","y math","Ġ #","Ġ\\ '","Ġ\\ ^","Ġ1 0","Ġ1 5","ma nnu","ig arrow","ft er","me ral","left rightharpoo","right squ","de f","array st","rt mid","int erline","ve arrow","ng eq","ho ice","la x","var Gamma","var propto","var triangle","var Upsilon","big uplus","ex pa","Ġ< $","mathb in","per ca","text circled","text md","sc sh","ci al","re tch","re lax","over withdelims","no inde","own s","vee bar","under bar","under rightarrow","up 
perca","back simeq","triangleleft eq","box times","box ed","prec cur","thi ckap","ro ot","ro mannu","mathc hoice","inde x","circled circ","cur vearrow","ever ymath","ly eq","nda fter","off interline","sho rtmid","spe cial","leftrightharpoo ns","rightsqu igarrow","arrayst retch","expa ndafter","scsh ape","noinde nt","upperca se","preccur lyeq","thickap prox","romannu meral","curvearrow right","offinterline skip"]}} diff --git a/ppocr/utils/formula_utils/math_txt2pkl.py b/ppocr/utils/formula_utils/math_txt2pkl.py new file mode 100644 index 0000000000..748fcb1ba1 --- /dev/null +++ b/ppocr/utils/formula_utils/math_txt2pkl.py @@ -0,0 +1,70 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pickle +from tqdm import tqdm +import os +import cv2 +import imagesize +from collections import defaultdict +import glob +from os.path import join +import argparse + + +def txt2pickle(images, equations, save_dir): + save_p = os.path.join(save_dir, "latexocr_{}.pkl".format(images.split("/")[-1])) + min_dimensions = (32, 32) + max_dimensions = (672, 192) + max_length = 512 + data = defaultdict(lambda: []) + if images is not None and equations is not None: + images_list = [ + path.replace("\\", "/") for path in glob.glob(join(images, "*.png")) + ] + indices = [int(os.path.basename(img).split(".")[0]) for img in images_list] + eqs = open(equations, "r").read().split("\n") + for i, im in tqdm(enumerate(images_list), total=len(images_list)): + width, height = imagesize.get(im) + if ( + min_dimensions[0] <= width <= max_dimensions[0] + and min_dimensions[1] <= height <= max_dimensions[1] + ): + data[(width, height)].append((eqs[indices[i]], im)) + data = dict(data) + with open(save_p, "wb") as file: + pickle.dump(data, file) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--image_dir", + type=str, + default=".", + help="Input_label or input path to be converted", + ) + parser.add_argument( + "--mathtxt_path", + type=str, + default=".", + help="Input_label or input path to be converted", + ) + parser.add_argument( + "--output_dir", type=str, default="out_label.txt", help="Output file name" + ) + + args = parser.parse_args() + txt2pickle(args.image_dir, args.mathtxt_path, args.output_dir) diff --git a/requirements.txt b/requirements.txt index e513a2e8d9..40afd21d6e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,6 @@ cython Pillow pyyaml requests +albumentations==1.4.10 +tokenizers==0.19.1 +imagesize diff --git a/tools/eval.py b/tools/eval.py index 9ac5498b75..59a36e15a9 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -105,6 +105,8 @@ def main(): if "model_type" in config["Architecture"].keys(): if config["Architecture"]["algorithm"] == "CAN": model_type = "can" + elif config["Architecture"]["algorithm"] == "LaTeXOCR": + model_type = "latexocr" else: model_type = config["Architecture"]["model_type"] else: diff --git a/tools/export_model.py b/tools/export_model.py index 
8ca31c9d58..c10f81d223 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -131,6 +131,11 @@ def export_single_model( ] ] model = to_static(model, input_spec=other_shape) + elif arch_config["algorithm"] == "LaTeXOCR": + other_shape = [ + paddle.static.InputSpec(shape=[None, 1, None, None], dtype="float32"), + ] + model = to_static(model, input_spec=other_shape) elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]: input_spec = [ paddle.static.InputSpec(shape=[None, 512], dtype="int64"), # input_ids diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 06b318eb35..239b09ef19 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -133,6 +133,11 @@ class TextRecognizer(object): "character_dict_path": args.rec_char_dict_path, "use_space_char": args.use_space_char, } + elif self.rec_algorithm == "LaTeXOCR": + postprocess_params = { + "name": "LaTeXOCRDecode", + "rec_char_dict_path": args.rec_char_dict_path, + } elif self.rec_algorithm == "ParseQ": postprocess_params = { "name": "ParseQLabelDecode", @@ -450,6 +455,90 @@ class TextRecognizer(object): return img + def pad_(self, img, divable=32): + threshold = 128 + data = np.array(img.convert("LA")) + if data[..., -1].var() == 0: + data = (data[..., 0]).astype(np.uint8) + else: + data = (255 - data[..., -1]).astype(np.uint8) + data = (data - data.min()) / (data.max() - data.min()) * 255 + if data.mean() > threshold: + # To invert the text to white + gray = 255 * (data < threshold).astype(np.uint8) + else: + gray = 255 * (data > threshold).astype(np.uint8) + data = 255 - data + + coords = cv2.findNonZero(gray) # Find all non-zero points (text) + a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box + rect = data[b : b + h, a : a + w] + im = Image.fromarray(rect).convert("L") + dims = [] + for x in [w, h]: + div, mod = divmod(x, divable) + dims.append(divable * (div + (1 if mod > 0 else 0))) + padded = Image.new("L", dims, 255) + padded.paste(im, (0, 0, im.size[0], im.size[1])) + return padded + + def minmax_size_( + self, + img, + max_dimensions, + min_dimensions, + ): + if max_dimensions is not None: + ratios = [a / b for a, b in zip(img.size, max_dimensions)] + if any([r > 1 for r in ratios]): + size = np.array(img.size) // max(ratios) + img = img.resize(tuple(size.astype(int)), Image.BILINEAR) + if min_dimensions is not None: + # hypothesis: there is a dim in img smaller than min_dimensions, and return a proper dim >= min_dimensions + padded_size = [ + max(img_dim, min_dim) + for img_dim, min_dim in zip(img.size, min_dimensions) + ] + if padded_size != list(img.size): # assert hypothesis + padded_im = Image.new("L", padded_size, 255) + padded_im.paste(img, img.getbbox()) + img = padded_im + return img + + def norm_img_latexocr(self, img): + # CAN only predict gray scale image + shape = (1, 1, 3) + mean = [0.7931, 0.7931, 0.7931] + std = [0.1738, 0.1738, 0.1738] + scale = 255.0 + min_dimensions = [32, 32] + max_dimensions = [672, 192] + mean = np.array(mean).reshape(shape).astype("float32") + std = np.array(std).reshape(shape).astype("float32") + + im_h, im_w = img.shape[:2] + if ( + min_dimensions[0] <= im_w <= max_dimensions[0] + and min_dimensions[1] <= im_h <= max_dimensions[1] + ): + pass + else: + img = Image.fromarray(np.uint8(img)) + img = self.minmax_size_(self.pad_(img), max_dimensions, min_dimensions) + img = np.array(img) + im_h, im_w = img.shape[:2] + img = np.dstack([img, img, img]) + img = (img.astype("float32") * scale - mean) / std + 
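+        # Collapse the normalized 3-channel image back to a single gray
+        # channel, then pad height and width up to multiples of 16 so the
+        # hybrid ViT backbone's 16x16 patch embedding divides the input
+        # evenly (constant padding value of 1).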
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + divide_h = math.ceil(im_h / 16) * 16 + divide_w = math.ceil(im_w / 16) * 16 + img = np.pad( + img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1) + ) + img = img[:, :, np.newaxis].transpose(2, 0, 1) + img = img.astype("float32") + return img + def __call__(self, img_list): img_num = len(img_list) # Calculate the aspect ratio of all text bars @@ -552,6 +641,10 @@ class TextRecognizer(object): word_label_list = [] norm_img_mask_batch.append(norm_image_mask) word_label_list.append(word_label) + elif self.rec_algorithm == "LaTeXOCR": + norm_img = self.norm_img_latexocr(img_list[indices[ino]]) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) else: norm_img = self.resize_norm_img( img_list[indices[ino]], max_wh_ratio @@ -666,6 +759,29 @@ class TextRecognizer(object): if self.benchmark: self.autolog.times.stamp() preds = outputs + elif self.rec_algorithm == "LaTeXOCR": + inputs = [norm_img_batch] + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, input_dict) + preds = outputs + else: + input_names = self.predictor.get_input_names() + input_tensor = [] + for i in range(len(input_names)): + input_tensor_i = self.predictor.get_input_handle(input_names[i]) + input_tensor_i.copy_from_cpu(inputs[i]) + input_tensor.append(input_tensor_i) + self.input_tensor = input_tensor + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + if self.benchmark: + self.autolog.times.stamp() + preds = outputs else: if self.use_onnx: input_dict = {} @@ -692,6 +808,9 @@ class TextRecognizer(object): wh_ratio_list=wh_ratio_list, max_wh_ratio=max_wh_ratio, ) + elif self.postprocess_params["name"] == "LaTeXOCRDecode": + preds = [p.reshape([-1]) for p in preds] + rec_result = self.postprocess_op(preds) else: rec_result = self.postprocess_op(preds) for rno in range(len(rec_result)): diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 0e04c8b636..22df30f866 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -183,6 +183,8 @@ def main(): elif isinstance(post_result, list) and isinstance(post_result[0], int): # for RFLearning CNT branch info = str(post_result[0]) + elif config["Architecture"]["algorithm"] == "LaTeXOCR": + info = str(post_result[0]) else: if len(post_result[0]) >= 2: info = post_result[0][0] + "\t" + str(post_result[0][1]) diff --git a/tools/program.py b/tools/program.py index 3055cdc6e7..1cc5bbac1c 100755 --- a/tools/program.py +++ b/tools/program.py @@ -324,6 +324,8 @@ def train( preds = model(batch) elif algorithm in ["CAN"]: preds = model(batch[:3]) + elif algorithm in ["LaTeXOCR"]: + preds = model(batch) else: preds = model(images) preds = to_float32(preds) @@ -339,6 +341,8 @@ def train( preds = model(batch) elif algorithm in ["CAN"]: preds = model(batch[:3]) + elif algorithm in ["LaTeXOCR"]: + preds = model(batch) else: preds = model(images) loss = loss_class(preds, batch) @@ -360,6 +364,10 @@ def train( elif algorithm in ["CAN"]: model_type = "can" eval_class(preds[0], batch[2:], epoch_reset=(idx == 0)) + elif algorithm in ["LaTeXOCR"]: + model_type = "latexocr" + post_result = post_process_class(preds, batch[1], mode="train") + eval_class(post_result[0], post_result[1], epoch_reset=(idx == 0)) else: if config["Loss"]["name"] in [ "MultiLoss", @@ -600,6 +608,8 @@ def eval( preds = model(batch) elif model_type in ["can"]: 
preds = model(batch[:3]) + elif model_type in ["latexocr"]: + preds = model(batch) elif model_type in ["sr"]: preds = model(batch) sr_img = preds["sr_img"] @@ -614,6 +624,8 @@ def eval( preds = model(batch) elif model_type in ["can"]: preds = model(batch[:3]) + elif model_type in ["latexocr"]: + preds = model(batch) elif model_type in ["sr"]: preds = model(batch) sr_img = preds["sr_img"] @@ -640,6 +652,9 @@ def eval( eval_class(preds, batch_numpy) elif model_type in ["can"]: eval_class(preds[0], batch_numpy[2:], epoch_reset=(idx == 0)) + elif model_type in ["latexocr"]: + post_result = post_process_class(preds, batch[1], "eval") + eval_class(post_result[0], post_result[1], epoch_reset=(idx == 0)) else: post_result = post_process_class(preds, batch_numpy[1]) eval_class(post_result, batch_numpy) @@ -777,6 +792,7 @@ def preprocess(is_train=False): "SVTR_HGNet", "ParseQ", "CPPD", + "LaTeXOCR", ] if use_xpu: