Add new recognition method "ParseQ" (#10836)

* Update PP-OCRv4_introduction.md

* Update PP-OCRv4_introduction.md (#10616)

* Update PP-OCRv4_introduction.md

* Update PP-OCRv4_introduction.md

* Update PP-OCRv4_introduction.md

* Update README.md

* Cherrypicking GH-10217 and GH-10216 to PaddlePaddle:Release/2.7 (#10655)

* Don't break overall processing on a bad image

* Add preprocessing common to OCR tasks
Add preprocessing to options

* Update requirements.txt (#10656)

added missing pyyaml library

* [TIPC]update xpu tipc script (#10658)

* fix-typo (#10642)

Co-authored-by: Dennis <dvorst@users.noreply.github.com>
Co-authored-by: shiyutang <34859558+shiyutang@users.noreply.github.com>

* Fix the DSR error caused by data augmentation (#10662) (#10681)

* Fix the DSR error caused by data augmentation

* Roll back the incorrect change

* Update algorithm_overview_en.md (#10670)

Fixed simple spelling errors.

* Implement recognition method ParseQ

* Document update for new recognition method ParseQ

* add prediction for parseq

* Update rec_vit_parseq.yml

* Update rec_r31_sar.yml

* Update rec_r31_sar.yml

* Update rec_r50_fpn_srn.yml

* Update rec_vit_parseq.py

* Update rec_vit_parseq.yml

* Update rec_parseq_head.py

* Update rec_img_aug.py

* Update rec_vit_parseq.yml

* Update __init__.py

* Update predict_rec.py

* Update paddleocr.py

* Update requirements.txt

* Update utility.py

* Update utility.py

---------

Co-authored-by: xiaoting <31891223+tink2123@users.noreply.github.com>
Co-authored-by: topduke <784990967@qq.com>
Co-authored-by: dyning <dyning.2003@163.com>
Co-authored-by: UserUnknownFactor <63057995+UserUnknownFactor@users.noreply.github.com>
Co-authored-by: itasli <ilyas.tasli@outlook.fr>
Co-authored-by: Kai Song <50285351+USTCKAY@users.noreply.github.com>
Co-authored-by: dvorst <87502756+dvorst@users.noreply.github.com>
Co-authored-by: Dennis <dvorst@users.noreply.github.com>
Co-authored-by: shiyutang <34859558+shiyutang@users.noreply.github.com>
Co-authored-by: Dec20B <1192152456@qq.com>
Co-authored-by: ncoffman <51147417+ncoffman@users.noreply.github.com>
pull/10847/head
ToddBear 2023-09-07 16:36:47 +08:00 committed by GitHub
parent ab86490138
commit 75d16610f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 1404 additions and 25 deletions

View File

@ -69,7 +69,7 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库助力
<a name="技术交流合作"></a>
## 📖 Technical Exchange and Cooperation
- The PaddlePaddle AI suite ([PaddleX](http://10.136.157.23:8080/paddle/paddleX)) provides a one-stop, full-workflow, high-efficiency development platform for training, compression, and inference of PaddlePaddle models. Its mission is to help AI technology land quickly, and its vision is to make everyone an AI Developer!
- PaddleX currently covers areas such as image classification, object detection, image segmentation, 3D, OCR, and time-series forecasting, with 36 built-in basic single models such as RP-DETR, PP-YOLOE, PP-HGNet, PP-LCNet, and PP-LiteSeg, and 12 integrated practical industry solutions such as PP-OCRv4, PP-ChatOCR, PP-ShiTu, PP-TS, vehicle-mounted road litter detection, and prohibited wildlife product recognition.
- PaddleX currently covers areas such as image classification, object detection, image segmentation, 3D, OCR, and time-series forecasting, with 36 built-in basic single models such as RT-DETR, PP-YOLOE, PP-HGNet, PP-LCNet, and PP-LiteSeg, and 12 integrated practical industry solutions such as PP-OCRv4, PP-ChatOCR, PP-ShiTu, PP-TS, vehicle-mounted road litter detection, and prohibited wildlife product recognition.
- PaddleX provides two AI development modes: "Toolbox" and "Developer". The toolbox mode allows tuning of key hyperparameters without code; the developer mode allows low-code training, compression, and inference of single models as well as multi-model pipelined inference, and supports both cloud and local deployment.
- PaddleX also supports collaborative development with profit sharing! PaddleX is iterating rapidly, and individual and enterprise developers are welcome to join us in building a thriving AI technology ecosystem!

View File

@ -0,0 +1,116 @@
Global:
use_gpu: True
epoch_num: 20
log_smooth_window: 20
print_batch_step: 5
save_model_dir: ./output/rec/parseq
save_epoch_step: 3
# evaluation is run every 500 iterations after the 0th iteration
eval_batch_step: [0, 500]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path: ppocr/utils/dict/parseq_dict.txt
character_type: en
max_text_length: 25
num_heads: 8
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_parseq.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: OneCycle
max_lr: 0.0007
Architecture:
model_type: rec
algorithm: ParseQ
in_channels: 3
Transform:
Backbone:
name: ViTParseQ
img_size: [32, 128]
patch_size: [4, 8]
embed_dim: 384
depth: 12
num_heads: 6
mlp_ratio: 4
in_channels: 3
Head:
name: ParseQHead
# Architecture
max_text_length: 25
embed_dim: 384
dec_num_heads: 12
dec_mlp_ratio: 4
dec_depth: 1
# Training
perm_num: 6
perm_forward: true
perm_mirrored: true
dropout: 0.1
# Decoding mode (test)
decode_ar: true
refine_iters: 1
Loss:
name: ParseQLoss
PostProcess:
name: ParseQLabelDecode
Metric:
name: RecMetric
main_indicator: acc
is_filter: True
Train:
dataset:
name: LMDBDataSet
data_dir:
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- ParseQRecAug:
aug_type: 0 # or 1
- ParseQLabelEncode:
- SVTRRecResizeImg:
image_shape: [3, 32, 128]
padding: False
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 192
drop_last: True
num_workers: 4
Eval:
dataset:
name: LMDBDataSet
data_dir:
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- ParseQLabelEncode: # Class handling label
- SVTRRecResizeImg:
image_shape: [3, 32, 128]
padding: False
- KeepKeys:
keep_keys: ['image', 'label', 'length']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 384
num_workers: 4
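
A small sketch of how this new config can be inspected programmatically; the file path is the one added by this PR, and pyyaml was added to requirements.txt in one of the commits above (illustrative only, not part of the training pipeline):

```python
import yaml  # provided by the pyyaml requirement

# Load the ParseQ training config added in this PR and check how the
# ViT backbone, ParseQ head, and dictionary are wired together.
with open("configs/rec/rec_vit_parseq.yml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

print(cfg["Architecture"]["Backbone"]["name"])   # ViTParseQ
print(cfg["Architecture"]["Head"]["name"])       # ParseQHead
print(cfg["Global"]["character_dict_path"])      # ppocr/utils/dict/parseq_dict.txt
```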

View File

@ -81,13 +81,13 @@ PP-OCRv4检测模型对PP-OCRv3中的CMLCollaborative Mutual Learning) 协同
<a name="3"></a>
## 3. Recognition Optimization
The recognition module of PP-OCRv3 is optimized on top of the text recognition algorithm [SVTR](https://arxiv.org/abs/2205.00159). SVTR no longer uses an RNN structure; by introducing a Transformer structure it mines the contextual information of text-line images more effectively and thus improves recognition ability. Directly replacing the PP-OCRv2 recognition model with SVTR_Tiny raises recognition accuracy from 74.8% to 80.1% (+5.3%), but prediction becomes almost 11 times slower (nearly 100 ms per text line on CPU). Therefore, as shown in the figure below, PP-OCRv3 adopts the following 6 optimization strategies to accelerate the recognition model.
The PP-OCRv4 recognition model is further upgraded on top of PP-OCRv3. As shown in the figure below, the overall framework keeps the same pipeline as the PP-OCRv3 recognition model, with optimizations applied to the data, the network structure, and the training strategy.
<div align="center">
<img src="../ppocr_v4/v4_rec_pipeline.png" width=800>
</div>
Based on the above strategies, the PP-OCRv4 recognition model improves accuracy by a further 4% over PP-OCRv3 at comparable speed. The detailed ablation experiments are shown below:
After the strategy optimizations shown in the figure, the PP-OCRv4 recognition model improves accuracy by a further 4% over PP-OCRv3 at comparable speed. The detailed ablation experiments are shown below:
| ID | Strategy | Model Size | Accuracy | Prediction Time (CPU, OpenVINO) |
|-----|-----|--------|----| --- |
@ -103,8 +103,8 @@ PP-OCRv3的识别模块是基于文本识别算法[SVTR](https://arxiv.org/abs/2
**(1) DF: a data mining scheme**
DF (Data Filter) is a simple and effective data mining scheme. The core idea is to use existing models to predict the training data and filter the full dataset based on confidence and prediction results. Specifically, a low-accuracy model is first trained quickly on a small amount of data and used to predict the tens of millions of samples; samples with confidence greater than 0.95 are removed, since they are considered redundant data that do not help improve model accuracy. Then PP-OCRv3 is used as a high-accuracy model to predict the remaining data, and samples with confidence less than 0.15 are removed, since they are considered hard to recognize or of very poor quality.
With this strategy, the training data is reduced from tens of millions to millions of samples, which significantly improves training efficiency; model training time drops from 2 weeks to 5 days, while accuracy rises to 72.7% (+1.2%).
DF (Data Filter) is a simple and effective data mining scheme. The core idea is to use existing models to predict the training data and filter the full training set based on confidence and prediction results. Specifically, a low-accuracy model is first trained quickly on a small amount of data and used to predict the tens of millions of samples; samples with confidence greater than 0.95 are removed, since they are considered redundant samples that do not help improve model accuracy. Then PP-OCRv3 is used as a high-accuracy model to predict the remaining data, and samples with confidence less than 0.15 are removed, since they are considered hard to recognize or of very poor quality.
With this strategy, the training data is reduced from tens of millions to millions of samples and model training time drops from 2 weeks to 5 days, which significantly improves training efficiency, while accuracy rises to 72.7% (+1.2%).
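
A minimal sketch of the two-stage confidence filtering described above (the model objects and their `predict` interface are illustrative assumptions, not actual PaddleOCR APIs):

```python
def data_filter(samples, low_acc_model, high_acc_model,
                high_thresh=0.95, low_thresh=0.15):
    """Two-stage DF filtering over the full training set (sketch)."""
    kept = []
    for img, label in samples:
        _, conf = low_acc_model.predict(img)    # stage 1: quickly trained low-accuracy model
        if conf > high_thresh:
            continue                            # redundant sample, adds little to accuracy
        _, conf = high_acc_model.predict(img)   # stage 2: PP-OCRv3 as the high-accuracy model
        if conf < low_thresh:
            continue                            # hard to recognize or very poor quality
        kept.append((img, label))
    return kept
```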
<div align="center">
@ -118,12 +118,12 @@ PP-LCNetV3系列模型是PP-LCNet系列模型的延续覆盖了更大的精
**(3) Lite-Neck: a Neck structure with fewer parameters**
The overall Lite-Neck structure follows the PP-OCRv3 version with a slight reduction in parameters, which cuts the overall recognition model size from 12M to 8.5M without any loss of accuracy; in CTCHead, raising the dimension of the Neck output features from 64 to 120 increases the model size from 8.5M to 9.6M and improves accuracy by 0.5%.
The overall Lite-Neck structure follows that of the PP-OCRv3 version with a slight reduction in parameters, which cuts the overall recognition model size from 12M to 8.5M without any loss of accuracy; in CTCHead, raising the dimension of the Neck output features from 64 to 120 increases the model size from 8.5M to 9.6M.
**(4) GTC-NRTR: an Attention-guided CTC training strategy**
GTC (Guided Training of CTC) is a strategy already used in PP-OCRv3; it fuses multiple representations of text features and effectively improves text recognition accuracy. In PP-OCRv4, the more stably trained Transformer model NRTR is used as the guidance. Compared with SAR (a structure based on recurrent neural networks), NRTR implements the decoding process with a Transformer, generalizes better, and can effectively guide the learning of the CTC branch, solving the problem of rapid overfitting in simple scenarios. With the model size unchanged, recognition accuracy rises to 73.21% (+0.5%).
GTC (Guided Training of CTC), one of the most effective strategies in the PP-OCRv3 recognition model, fuses multiple representations of text features and effectively improves text recognition accuracy. In PP-OCRv4, the more stably trained Transformer model NRTR is used as the guidance branch. Compared with SAR (the recurrent-neural-network-based structure used in v3), NRTR implements the decoding process with a Transformer, generalizes better, and can effectively guide the learning of the CTC branch, solving the problem of rapid overfitting in simple scenarios. Using the Lite-Neck and GTC-NRTR strategies, recognition accuracy rises to 73.21% (+0.5%). A rough sketch of the joint-training idea is given after the figure below.
<div align="center">
<img src="../ppocr_v4/ppocrv4_gtc.png" width="500">
@ -132,7 +132,7 @@ GTCGuided Training of CTC是在PP-OCRv3中使用过的策略融合
**(5) Multi-Scale: a multi-scale training strategy**
The dynamic-scale training strategy randomly resizes the height of input images during training to increase the robustness of the model. During training, one of three heights (32, 48, 64) is randomly chosen for resizing; experiments show that accuracy on the test set does not drop, while the end-to-end pipeline metric improves by 0.5%.
The dynamic-scale training strategy randomly resizes the height of input images during training to strengthen the robustness of the recognition model when it is used in the end-to-end pipeline. At each training iteration, one of three heights (32, 48, 64) is randomly chosen for resizing; experiments show that although accuracy on the recognition test set does not improve, the end-to-end evaluation metric improves by 0.5%. A minimal sketch of the idea is given after the figure below.
<div align="center">
<img src="../ppocr_v4/multi_scale.png" width="500">
@ -143,9 +143,9 @@ GTCGuided Training of CTC是在PP-OCRv3中使用过的策略融合
The distillation of the recognition model consists of two parts: NRTR head distillation and CTCHead distillation.
For the NRTR head, DKD loss is used for distillation so that the logits output by the student NRTR head stay close to those of the teacher NRTR head. The final NRTR head loss is a weighted sum of the student-teacher DKD loss and the cross entropy loss against the ground truth, and it supervises the training of the student backbone. Through experiments we found that, once the DKD loss is added, removing label smoothing from the cross entropy loss against the ground truth further improves accuracy, so the cross entropy loss used here is the one without label smoothing.
For the NRTR head, DKD loss is used for distillation to pull the NRTR head logits of the student model and the teacher model closer together. The final NRTR head loss is a weighted sum of the student-teacher DKD loss and the cross entropy loss against the ground truth, and it supervises the training of the student backbone. Through experiments we found that, once the DKD loss is added, removing label smoothing from the cross entropy loss against the ground truth further improves accuracy, so the cross entropy loss used here is the one without label smoothing.
For CTCHead, because the CTC output contains blank positions, even when the teacher and the student make the same predictions, the distributions of their output logits can still differ, which hinders knowledge transfer from teacher to student. In the PP-OCRv4 recognition distillation strategy, the CTC output logits are averaged along the text-length dimension, turning the multi-character recognition problem into a multi-character classification problem used to supervise the training of the CTC Head. Combining this strategy with the NRTR head DKD distillation raises the metric from 0.7377 to 0.7545.
For CTCHead, because the CTC output contains blank positions, even when the teacher and the student make the same predictions, the distributions of their output logits can still differ, which hinders knowledge transfer from teacher to student. In the PP-OCRv4 recognition distillation strategy, the CTC output logits are averaged along the text-length dimension, turning the multi-character recognition problem into a multi-character classification problem used to supervise the training of the CTC Head. Combining this strategy with the NRTR head DKD distillation raises the metric from 74.72% to 75.45%.
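
A minimal sketch of the length-averaged CTC distillation idea (tensor shapes and the KL-style matching are illustrative assumptions; the actual distillation losses live in ppocr/losses):

```python
import paddle
import paddle.nn.functional as F

def ctc_mean_distill_loss(student_logits, teacher_logits, temperature=1.0):
    """Average CTC logits over the text-length axis, then match the student's
    class distribution to the teacher's (sketch)."""
    # logits: [B, T, num_classes]; the mean over T turns recognition into classification
    s_mean = student_logits.mean(axis=1)
    t_mean = teacher_logits.mean(axis=1)
    s_log_prob = F.log_softmax(s_mean / temperature, axis=-1)
    t_prob = F.softmax(t_mean / temperature, axis=-1)
    return F.kl_div(s_log_prob, t_prob, reduction='mean') * temperature ** 2
```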
@ -169,11 +169,11 @@ GTCGuided Training of CTC是在PP-OCRv3中使用过的策略融合
| PP-OCRv3_en | 64.04% |
| PP-OCRv4_en | 70.1% |
Meanwhile, the recognition models for the more than 80 supported languages have been upgraded; on the four language families that have evaluation sets, recognition accuracy improves by more than 5% on average, as shown in the table below:
Meanwhile, the recognition models for the more than 80 supported languages have been upgraded; on the four language families that have evaluation sets, recognition accuracy improves by more than 8% on average, as shown in the table below:
| Model | Latin | Arabic | Japanese | Korean |
|-----|-----|--------|----| --- |
| PP-OCR_mul | 69.60% | 40.50% | 38.50% | 55.40% |
| PP-OCRv3_mul | 75.20%| 45.37% | 45.80% | 60.10% |
| PP-OCRv3_mul | 71.57%| 72.90% | 45.85% | 77.23% |
| PP-OCRv4_mul | 80.00%| 75.48% | 56.50% | 83.25% |

View File

@ -86,6 +86,7 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型**欢迎广
- [x] [SPIN](./algorithm_rec_spin.md)
- [x] [RobustScanner](./algorithm_rec_robustscanner.md)
- [x] [RFL](./algorithm_rec_rfl.md)
- [x] [ParseQ](./algorithm_rec_parseq.md)
Following the text recognition training and evaluation procedure of [DTRB](https://arxiv.org/abs/1904.01906) [3], the models are trained on the MJSynth and SynthText text recognition datasets and evaluated on the IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE datasets. The results are as follows:
@ -110,6 +111,7 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型**欢迎广
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
|RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl_att_train.tar) |
|ParseQ|VIT| 91.24% | rec_vit_parseq_synth | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_synth.tgz) |
<a name="13"></a>

View File

@ -0,0 +1,124 @@
# ParseQ
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
- [3.1 Training](#3-1)
- [3.2 Evaluation](#3-2)
- [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
- [4.1 Python Inference](#4-1)
- [4.2 C++ Inference](#4-2)
- [4.3 Serving](#4-3)
- [4.4 More](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. Introduction
Paper:
> [Scene Text Recognition with Permuted Autoregressive Sequence Models](https://arxiv.org/abs/2207.06966)
> Darwin Bautista, Rowel Atienza
> ECCV, 2022
The original paper trains separately on real text recognition datasets (Real) and synthetic text recognition datasets (Synth), and evaluates on the IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE datasets.
Specifically:
- The real text recognition datasets (Real) include the COCO-Text, RCTW17, Uber-Text, ArT, LSVT, MLT19, ReCTS, TextOCR, and OpenVINO datasets.
- The synthetic text recognition datasets (Synth) include the MJSynth and SynthText datasets.
The reproduced results of the algorithm trained on the different datasets are as follows:
|Dataset|Model|Backbone|Config|Acc|Download link|
| --- | --- | --- | --- | --- | --- |
|Synth|ParseQ|VIT|[rec_vit_parseq.yml](../../configs/rec/rec_vit_parseq.yml)|91.24%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_synth.tgz)|
|Real|ParseQ|VIT|[rec_vit_parseq.yml](../../configs/rec/rec_vit_parseq.yml)|94.74%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_real.tgz)|
|||||||
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
Please refer to the [Text Recognition Tutorial](./recognition.md). PaddleOCR modularizes the code, and training a different recognition model only requires **changing the configuration file**.
Training:
Specifically, after the data preparation is completed, training can be started with the following commands:
```
# Single-GPU training (long training period, not recommended)
python3 tools/train.py -c configs/rec/rec_vit_parseq.yml
# Multi-GPU training, specify GPU ids with the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_vit_parseq.yml
```
Evaluation:
```
# GPU evaluation, Global.pretrained_model is the weights to be evaluated
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```
Prediction:
```
# The configuration file used for prediction must match the one used for training
python3 tools/infer_rec.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
```
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, convert the model saved during ParseQ text recognition training into an inference model ( [model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_real.tgz) ). The conversion command is as follows:
```
python3 tools/export_model.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model=./rec_vit_parseq_real/best_accuracy Global.save_inference_dir=./inference/rec_parseq
```
For ParseQ text recognition model inference, the following command can be executed:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_parseq/" --rec_image_shape="3, 32, 128" --rec_algorithm="ParseQ" --rec_char_dict_path="ppocr/utils/dict/parseq_dict.txt" --max_text_length=25 --use_space_char=False
```
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported yet, because the C++ pre- and post-processing do not support ParseQ.
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@InProceedings{bautista2022parseq,
title={Scene Text Recognition with Permuted Autoregressive Sequence Models},
author={Bautista, Darwin and Atienza, Rowel},
booktitle={European Conference on Computer Vision},
pages={178--196},
month={10},
year={2022},
publisher={Springer Nature Switzerland},
address={Cham},
doi={10.1007/978-3-031-19815-1_11},
url={https://doi.org/10.1007/978-3-031-19815-1_11}
}
```

View File

@ -41,8 +41,8 @@ On the ICDAR2015 dataset, the text detection result is as follows:
|DB|ResNet50_vd|86.41%|78.72%|82.38%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
|DB|MobileNetV3|77.29%|73.08%|75.12%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
|SAST|ResNet50_vd|91.39%|83.77%|87.42%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
|PSE|ResNet50_vd|85.81%|79.53%|82.55%|[trianed model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)|
|PSE|MobileNetV3|82.20%|70.48%|75.89%|[trianed model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)|
|PSE|ResNet50_vd|85.81%|79.53%|82.55%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)|
|PSE|MobileNetV3|82.20%|70.48%|75.89%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)|
|DB++|ResNet50|90.89%|82.66%|86.58%|[pretrained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams)/[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_db%2B%2B_icdar15_train.tar)|
On Total-Text dataset, the text detection result is as follows:
@ -83,6 +83,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
- [x] [SPIN](./algorithm_rec_spin_en.md)
- [x] [RobustScanner](./algorithm_rec_robustscanner_en.md)
- [x] [RFL](./algorithm_rec_rfl_en.md)
- [x] [ParseQ](./algorithm_rec_parseq.md)
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
@ -107,6 +108,8 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
|RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl_att_train.tar) |
|ParseQ|VIT| 91.24% | rec_vit_parseq_synth | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_synth.tgz) |
<a name="13"></a>

View File

@ -0,0 +1,123 @@
# ParseQ
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
- [3.1 Training](#3-1)
- [3.2 Evaluation](#3-2)
- [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
- [4.1 Python Inference](#4-1)
- [4.2 C++ Inference](#4-2)
- [4.3 Serving](#4-3)
- [4.4 More](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. Introduction
Paper:
> [Scene Text Recognition with Permuted Autoregressive Sequence Models](https://arxiv.org/abs/2207.06966)
> Darwin Bautista, Rowel Atienza
> ECCV, 2022
The original paper uses real datasets (Real) and synthetic datasets (Synth) for training respectively, and evaluates on the IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE datasets.
- The real datasets include the COCO-Text, RCTW17, Uber-Text, ArT, LSVT, MLT19, ReCTS, TextOCR and OpenVINO datasets.
- The synthetic datasets include the MJSynth and SynthText datasets.
The reproduced results of the algorithm on the different training datasets are as follows:
|Training Dataset|Model|Backbone|config|Acc|Download link|
| --- | --- | --- | --- | --- | --- |
|Synth|ParseQ|VIT|[rec_vit_parseq.yml](../../configs/rec/rec_vit_parseq.yml)|91.24%|[train model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_synth.tgz)|
|Real|ParseQ|VIT|[rec_vit_parseq.yml](../../configs/rec/rec_vit_parseq.yml)|94.74%|[train model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_real.tgz)|
|||||||
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
Training:
Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
```
#Single GPU training (long training period, not recommended)
python3 tools/train.py -c configs/rec/rec_vit_parseq.yml
#Multi GPU training, specify the gpu number through the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_vit_parseq.yml
```
Evaluation:
```
# GPU evaluation
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```
Prediction:
```
# The configuration file used for prediction must match the one used for training
python3 tools/infer_rec.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
```
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, the model saved during the ParseQ text recognition training process is converted into an inference model ( [model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_real.tgz) ). You can use the following command to convert it:
```
python3 tools/export_model.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model=./rec_vit_parseq_real/best_accuracy Global.save_inference_dir=./inference/rec_parseq
```
For ParseQ text recognition model inference, the following command can be executed:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_parseq/" --rec_image_shape="3, 32, 128" --rec_algorithm="ParseQ" --rec_char_dict_path="ppocr/utils/dict/parseq_dict.txt" --max_text_length=25 --use_space_char=False
```
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@InProceedings{bautista2022parseq,
title={Scene Text Recognition with Permuted Autoregressive Sequence Models},
author={Bautista, Darwin and Atienza, Rowel},
booktitle={European Conference on Computer Vision},
pages={178--196},
month={10},
year={2022},
publisher={Springer Nature Switzerland},
address={Cham},
doi={10.1007/978-3-031-19815-1_11},
url={https://doi.org/10.1007/978-3-031-19815-1_11}
}
```

View File

@ -513,7 +513,7 @@ def get_model_config(type, version, model_type, lang):
def img_decode(content: bytes):
np_arr = np.frombuffer(content, dtype=np.uint8)
return cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
def check_img(img):

View File

@ -27,7 +27,7 @@ from .make_pse_gt import MakePseGt
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \
RFLRecResizeImg, SVTRRecAug
RFLRecResizeImg, SVTRRecAug, ParseQRecAug
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste

View File

@ -316,6 +316,35 @@ class CVGaussianNoise(object):
img = np.clip(img + noise, 0, 255).astype(np.uint8)
return img
class CVPossionNoise(object):
def __init__(self, lam=20):
self.lam = lam
if isinstance(lam, numbers.Number):
self.lam = max(int(sample_asym(lam)), 1)
elif isinstance(lam, (tuple, list)) and len(lam) == 2:
self.lam = int(sample_uniform(lam[0], lam[1]))
else:
raise Exception('lam must be number or list with length 2')
def __call__(self, img):
noise = np.random.poisson(lam=self.lam, size=img.shape)
img = np.clip(img + noise, 0, 255).astype(np.uint8)
return img
class CVGaussionBlur(object):
def __init__(self, radius):
self.radius = radius
if isinstance(radius, numbers.Number):
self.radius = max(int(sample_asym(radius)), 1)
elif isinstance(radius, (tuple, list)) and len(radius) == 2:
self.radius = int(sample_uniform(radius[0], radius[1]))
else:
raise Exception('radius must be number or list with length 2')
def __call__(self, img):
fil = cv2.getGaussianKernel(ksize=self.radius, sigma=1, ktype=cv2.CV_32F)
img = cv2.sepFilter2D(img, -1, fil, fil)
return img
class CVMotionBlur(object):
def __init__(self, degrees=12, angle=90):
@ -427,6 +456,29 @@ class SVTRDeterioration(object):
else:
return img
class ParseQDeterioration(object):
def __init__(self, var, degrees, lam, radius, factor, p=0.5):
self.p = p
transforms = []
if var is not None:
transforms.append(CVGaussianNoise(var=var))
if degrees is not None:
transforms.append(CVMotionBlur(degrees=degrees))
if lam is not None:
transforms.append(CVPossionNoise(lam=lam))
if radius is not None:
transforms.append(CVGaussionBlur(radius=radius))
if factor is not None:
transforms.append(CVRescale(factor=factor))
self.transforms = transforms
def __call__(self, img):
if random.random() < self.p:
random.shuffle(self.transforms)
transforms = Compose(self.transforms)
return transforms(img)
else:
return img
class SVTRGeometry(object):
def __init__(self,

View File

@ -1305,6 +1305,37 @@ class NRTRLabelEncode(BaseRecLabelEncode):
dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
return dict_character
class ParseQLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
BOS = '[B]'
EOS = '[E]'
PAD = '[P]'
def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False,
**kwargs):
super(ParseQLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len - 2:
return None
data['length'] = np.array(len(text))
text = [self.dict[self.BOS]] + text + [self.dict[self.EOS]]
text = text + [self.dict[self.PAD]] * (self.max_text_len - len(text))
data['label'] = np.array(text)
return data
def add_special_char(self, dict_character):
dict_character = [self.EOS] + dict_character + [self.BOS, self.PAD]
return dict_character
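# Illustrative example (an assumption for clarity, not part of the original code):
# add_special_char above places '[E]' at index 0 and appends '[B]' and '[P]' at the end
# of the character list. With max_text_length=25, the label "abc" is therefore encoded as
#   [dict['[B]'], dict['a'], dict['b'], dict['c'], dict['[E]'], dict['[P]'], ...]
# padded with '[P]' up to 25 entries in total, while data['length'] keeps the raw length 3.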
class ViTSTRLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """

View File

@ -20,7 +20,7 @@ import copy
from PIL import Image
import PIL
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
from .abinet_aug import CVGeometry, CVDeterioration, CVColorJitter, SVTRGeometry, SVTRDeterioration
from .abinet_aug import CVGeometry, CVDeterioration, CVColorJitter, SVTRGeometry, SVTRDeterioration, ParseQDeterioration
from paddle.vision.transforms import Compose
@ -204,6 +204,36 @@ class SVTRRecAug(object):
data['image'] = img
return data
class ParseQRecAug(object):
def __init__(self,
aug_type=0,
geometry_p=0.5,
deterioration_p=0.25,
colorjitter_p=0.25,
**kwargs):
self.transforms = Compose([
SVTRGeometry(
aug_type=aug_type,
degrees=45,
translate=(0.0, 0.0),
scale=(0.5, 2.),
shear=(45, 15),
distortion=0.5,
p=geometry_p), ParseQDeterioration(
var=20, degrees=6, lam=20, radius=2.0, factor=4, p=deterioration_p),
CVColorJitter(
brightness=0.5,
contrast=0.5,
saturation=0.5,
hue=0.1,
p=colorjitter_p)
])
def __call__(self, data):
img = data['image']
img = self.transforms(img)
data['image'] = img
return data
class ClsResizeImg(object):
def __init__(self, image_shape, **kwargs):

View File

@ -43,6 +43,7 @@ from .rec_rfl_loss import RFLLoss
from .rec_can_loss import CANLoss
from .rec_satrn_loss import SATRNLoss
from .rec_nrtr_loss import NRTRLoss
from .rec_parseq_loss import ParseQLoss
# cls loss
from .cls_loss import ClsLoss
@ -76,7 +77,7 @@ def build_loss(config):
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'CANLoss', 'TelescopeLoss',
'SATRNLoss', 'NRTRLoss'
'SATRNLoss', 'NRTRLoss', 'ParseQLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')

View File

@ -0,0 +1,50 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
class ParseQLoss(nn.Layer):
def __init__(self, **kwargs):
super(ParseQLoss, self).__init__()
def forward(self, predicts, targets):
label = targets[1] # label
label_len = targets[2]
max_step = paddle.max(label_len).cpu().numpy()[0] + 2
tgt = label[:, :max_step]
logits_list = predicts['logits_list']
pad_id = predicts['pad_id']
eos_id = predicts['eos_id']
tgt_out = tgt[:, 1:]
loss = 0
loss_numel = 0
n = (tgt_out != pad_id).sum().item()
for i, logits in enumerate(logits_list):
loss += n * paddle.nn.functional.cross_entropy(input=logits, label=tgt_out.flatten(), ignore_index=pad_id)
loss_numel += n
if i == 1:
tgt_out = paddle.where(condition=tgt_out == eos_id, x=pad_id, y=tgt_out)
n = (tgt_out != pad_id).sum().item()
loss /= loss_numel
return {'loss': loss}

View File

@ -50,11 +50,12 @@ def build_backbone(config, model_type):
from .rec_shallow_cnn import ShallowCNN
from .rec_lcnetv3 import PPLCNetV3
from .rec_hgnet import PPHGNet_small
from .rec_vit_parseq import ViTParseQ
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL',
'DenseNet', 'ShallowCNN', 'PPLCNetV3', 'PPHGNet_small'
'DenseNet', 'ShallowCNN', 'PPLCNetV3', 'PPHGNet_small', 'ViTParseQ'
]
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet

View File

@ -0,0 +1,304 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/PaddlePaddle/PaddleClas/blob/release%2F2.5/ppcls/arch/backbone/model_zoo/vision_transformer.py
"""
from collections.abc import Callable
import numpy as np
import paddle
import paddle.nn as nn
from paddle.nn.initializer import TruncatedNormal, Constant, Normal
trunc_normal_ = TruncatedNormal(std=.02)
normal_ = Normal
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)
def to_2tuple(x):
return tuple([x] * 2)
def drop_path(x, drop_prob=0., training=False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
"""
if drop_prob == 0. or not training:
return x
keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
random_tensor = keep_prob + paddle.rand(shape).astype(x.dtype)
random_tensor = paddle.floor(random_tensor) # binarize
output = x.divide(keep_prob) * random_tensor
return output
class DropPath(nn.Layer):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class Identity(nn.Layer):
def __init__(self):
super(Identity, self).__init__()
def forward(self, input):
return input
class Mlp(nn.Layer):
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Layer):
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
# B= paddle.shape(x)[0]
N, C = x.shape[1:]
qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //
self.num_heads)).transpose((2, 0, 3, 1, 4))
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale
attn = nn.functional.softmax(attn, axis=-1)
attn = self.attn_drop(attn)
x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((-1, N, C))
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Layer):
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer='nn.LayerNorm',
epsilon=1e-5):
super().__init__()
if isinstance(norm_layer, str):
self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
elif isinstance(norm_layer, Callable):
self.norm1 = norm_layer(dim)
else:
raise TypeError(
"The norm_layer must be str or paddle.nn.layer.Layer class")
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
if isinstance(norm_layer, str):
self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
elif isinstance(norm_layer, Callable):
self.norm2 = norm_layer(dim)
else:
raise TypeError(
"The norm_layer must be str or paddle.nn.layer.Layer class")
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Layer):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
if isinstance(img_size, int):
img_size = to_2tuple(img_size)
if isinstance(patch_size, int):
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2D(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose((0, 2, 1))
return x
class VisionTransformer(nn.Layer):
""" Vision Transformer with support for patch input
"""
def __init__(self,
img_size=224,
patch_size=16,
in_channels=3,
class_num=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer='nn.LayerNorm',
epsilon=1e-5,
**kwargs):
super().__init__()
self.class_num = class_num
self.num_features = self.embed_dim = embed_dim
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_channels,
embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.pos_embed = self.create_parameter(shape=(1, num_patches, embed_dim), default_initializer=zeros_)
self.add_parameter("pos_embed", self.pos_embed)
self.cls_token = self.create_parameter(
shape=(1, 1, embed_dim), default_initializer=zeros_)
self.add_parameter("cls_token", self.cls_token)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = np.linspace(0, drop_path_rate, depth)
self.blocks = nn.LayerList([
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
epsilon=epsilon) for i in range(depth)
])
self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
# Classifier head
self.head = nn.Linear(embed_dim,
class_num) if class_num > 0 else Identity()
trunc_normal_(self.pos_embed)
self.out_channels = embed_dim
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
zeros_(m.bias)
ones_(m.weight)
def forward_features(self, x):
B = paddle.shape(x)[0]
x = self.patch_embed(x)
x = x + self.pos_embed
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
class ViTParseQ(VisionTransformer):
def __init__(self, img_size=[224, 224], patch_size=[16, 16], in_channels=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0):
super().__init__(img_size, patch_size, in_channels, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate, class_num=0)
def forward(self, x):
return self.forward_features(x)

View File

@ -40,6 +40,7 @@ def build_head(config):
from .rec_rfl_head import RFLHead
from .rec_can_head import CANHead
from .rec_satrn_head import SATRNHead
from .rec_parseq_head import ParseQHead
# cls head
from .cls_head import ClsHead
@ -56,7 +57,7 @@ def build_head(config):
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead',
'DRRGHead', 'CANHead', 'SATRNHead', 'PFHeadLocal'
'DRRGHead', 'CANHead', 'SATRNHead', 'PFHeadLocal', 'ParseQHead'
]
if config['name'] == 'DRRGHead':

View File

@ -0,0 +1,342 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Code was based on https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
# reference: https://arxiv.org/abs/2207.06966
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
from paddle import nn, ParamAttr
from paddle.nn import functional as F
import numpy as np
from .self_attention import WrapEncoderForFeature
from .self_attention import WrapEncoder
from collections import OrderedDict
from typing import Optional
import copy
from itertools import permutations
class DecoderLayer(paddle.nn.Layer):
"""A Transformer decoder layer supporting two-stream attention (XLNet)
This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch."""
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu', layer_norm_eps=1e-05):
super().__init__()
self.self_attn = paddle.nn.MultiHeadAttention(d_model, nhead, dropout=dropout, need_weights=True) # paddle.nn.MultiHeadAttention defaults to batch_first mode
self.cross_attn = paddle.nn.MultiHeadAttention(d_model, nhead, dropout=dropout, need_weights=True)
self.linear1 = paddle.nn.Linear(in_features=d_model, out_features=dim_feedforward)
self.dropout = paddle.nn.Dropout(p=dropout)
self.linear2 = paddle.nn.Linear(in_features=dim_feedforward, out_features=d_model)
self.norm1 = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps)
self.norm2 = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps)
self.norm_q = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps)
self.norm_c = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps)
self.dropout1 = paddle.nn.Dropout(p=dropout)
self.dropout2 = paddle.nn.Dropout(p=dropout)
self.dropout3 = paddle.nn.Dropout(p=dropout)
if activation == 'gelu':
self.activation = paddle.nn.GELU()
def __setstate__(self, state):
if 'activation' not in state:
state['activation'] = paddle.nn.functional.gelu
super().__setstate__(state)
def forward_stream(self, tgt, tgt_norm, tgt_kv, memory, tgt_mask, tgt_key_padding_mask):
"""Forward pass for a single stream (i.e. content or query)
tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency.
Both tgt_kv and memory are expected to be LayerNorm'd too.
memory is LayerNorm'd by ViT.
"""
if tgt_key_padding_mask is not None:
tgt_mask1 = (tgt_mask!=float('-inf'))[None,None,:,:] & (tgt_key_padding_mask[:,None,None,:]==False)
tgt2, sa_weights = self.self_attn(tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask1)
else:
tgt2, sa_weights = self.self_attn(tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask)
tgt = tgt + self.dropout1(tgt2)
tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(tgt)))))
tgt = tgt + self.dropout3(tgt2)
return tgt, sa_weights, ca_weights
def forward(self, query, content, memory, query_mask=None, content_mask=None, content_key_padding_mask=None, update_content=True):
query_norm = self.norm_q(query)
content_norm = self.norm_c(content)
query = self.forward_stream(query, query_norm, content_norm, memory, query_mask, content_key_padding_mask)[0]
if update_content:
content = self.forward_stream(content, content_norm, content_norm, memory, content_mask, content_key_padding_mask)[0]
return query, content
def get_clones(module, N):
return paddle.nn.LayerList([copy.deepcopy(module) for i in range(N)])
class Decoder(paddle.nn.Layer):
__constants__ = ['norm']
def __init__(self, decoder_layer, num_layers, norm):
super().__init__()
self.layers = get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(self, query, content, memory, query_mask: Optional[paddle.Tensor]=None, content_mask: Optional[paddle.Tensor]=None, content_key_padding_mask: Optional[paddle.Tensor]=None):
for i, mod in enumerate(self.layers):
last = i == len(self.layers) - 1
query, content = mod(query, content, memory, query_mask, content_mask, content_key_padding_mask, update_content=not last)
query = self.norm(query)
return query
class TokenEmbedding(paddle.nn.Layer):
def __init__(self, charset_size: int, embed_dim: int):
super().__init__()
self.embedding = paddle.nn.Embedding(num_embeddings=charset_size, embedding_dim=embed_dim)
self.embed_dim = embed_dim
def forward(self, tokens: paddle.Tensor):
return math.sqrt(self.embed_dim) * self.embedding(tokens.astype(paddle.int64))
def trunc_normal_init(param, **kwargs):
initializer = nn.initializer.TruncatedNormal(**kwargs)
initializer(param, param.block)
def constant_init(param, **kwargs):
initializer = nn.initializer.Constant(**kwargs)
initializer(param, param.block)
def kaiming_normal_init(param, **kwargs):
initializer = nn.initializer.KaimingNormal(**kwargs)
initializer(param, param.block)
class ParseQHead(nn.Layer):
def __init__(self, out_channels, max_text_length, embed_dim, dec_num_heads, dec_mlp_ratio, dec_depth, perm_num, perm_forward, perm_mirrored, decode_ar, refine_iters, dropout, **kwargs):
super().__init__()
self.bos_id = out_channels - 2
self.eos_id = 0
self.pad_id = out_channels - 1
self.max_label_length = max_text_length
self.decode_ar = decode_ar
self.refine_iters = refine_iters
decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout)
self.decoder = Decoder(decoder_layer, num_layers=dec_depth, norm=paddle.nn.LayerNorm(normalized_shape=embed_dim))
self.rng = np.random.default_rng()
self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
self.perm_forward = perm_forward
self.perm_mirrored = perm_mirrored
self.head = paddle.nn.Linear(in_features=embed_dim, out_features=out_channels - 2)
self.text_embed = TokenEmbedding(out_channels, embed_dim)
self.pos_queries = paddle.create_parameter(shape=paddle.empty(shape=[1, max_text_length + 1, embed_dim]).shape, dtype=paddle.empty(shape=[1, max_text_length + 1, embed_dim]).numpy().dtype, default_initializer=paddle.nn.initializer.Assign(paddle.empty(shape=[1, max_text_length + 1, embed_dim])))
self.pos_queries.stop_gradient = False
self.dropout = paddle.nn.Dropout(p=dropout)
self._device = self.parameters()[0].place
trunc_normal_init(self.pos_queries, std=0.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, paddle.nn.Linear):
trunc_normal_init(m.weight, std=0.02)
if m.bias is not None:
constant_init(m.bias, value=0.0)
elif isinstance(m, paddle.nn.Embedding):
trunc_normal_init(m.weight, std=0.02)
if m._padding_idx is not None:
m.weight.data[m._padding_idx].zero_()
elif isinstance(m, paddle.nn.Conv2D):
kaiming_normal_init(m.weight, fan_in=None, nonlinearity='relu')
if m.bias is not None:
constant_init(m.bias, value=0.0)
elif isinstance(m, (paddle.nn.LayerNorm, paddle.nn.BatchNorm2D, paddle.nn.GroupNorm)):
constant_init(m.weight, value=1.0)
constant_init(m.bias, value=0.0)
def no_weight_decay(self):
param_names = {'text_embed.embedding.weight', 'pos_queries'}
enc_param_names = {('encoder.' + n) for n in self.encoder.
no_weight_decay()}
return param_names.union(enc_param_names)
def encode(self, img):
return self.encoder(img)
def decode(self, tgt, memory, tgt_mask=None, tgt_padding_mask=None, tgt_query=None, tgt_query_mask=None):
N, L = tgt.shape
null_ctx = self.text_embed(tgt[:, :1])
if L != 1:
tgt_emb = self.pos_queries[:, :L - 1] + self.text_embed(tgt[:, 1:])
tgt_emb = self.dropout(paddle.concat(x=[null_ctx, tgt_emb], axis=1))
else:
tgt_emb = self.dropout(null_ctx)
if tgt_query is None:
tgt_query = self.pos_queries[:, :L].expand(shape=[N, -1, -1])
tgt_query = self.dropout(tgt_query)
return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask)
def forward_test(self, memory, max_length=None):
testing = max_length is None
max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
bs = memory.shape[0]
num_steps = max_length + 1
pos_queries = self.pos_queries[:, :num_steps].expand(shape=[bs, -1, -1])
tgt_mask = query_mask = paddle.triu(x=paddle.full(shape=(num_steps, num_steps), fill_value=float('-inf')), diagonal=1)
if self.decode_ar:
tgt_in = paddle.full(shape=(bs, num_steps), fill_value=self.pad_id).astype('int64')
tgt_in[:, (0)] = self.bos_id
logits = []
for i in range(paddle.to_tensor(num_steps)):
j = i + 1
tgt_out = self.decode(tgt_in[:, :j], memory, tgt_mask[:j, :j], tgt_query=pos_queries[:, i:j], tgt_query_mask=query_mask[i:j, :j])
p_i = self.head(tgt_out)
logits.append(p_i)
if j < num_steps:
tgt_in[:, (j)] = p_i.squeeze().argmax(axis=-1)
if testing and (tgt_in == self.eos_id).astype('bool').any(axis=-1).astype('bool').all():
break
logits = paddle.concat(x=logits, axis=1)
else:
tgt_in = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype('int64')
tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries)
logits = self.head(tgt_out)
if self.refine_iters:
temp = paddle.triu(x=paddle.ones(shape=[num_steps,num_steps], dtype='bool'), diagonal=2)
posi = np.where(temp.cpu().numpy()==True)
query_mask[posi] = 0
bos = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype('int64')
for i in range(self.refine_iters):
tgt_in = paddle.concat(x=[bos, logits[:, :-1].argmax(axis=-1)], axis=1)
tgt_padding_mask = (tgt_in == self.eos_id).astype(dtype='int32')
tgt_padding_mask = tgt_padding_mask.cpu()
tgt_padding_mask = tgt_padding_mask.cumsum(axis=-1) > 0
tgt_padding_mask = tgt_padding_mask.cuda().astype(dtype='float32')==1.0
tgt_out = self.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query=pos_queries, tgt_query_mask=query_mask[:, :tgt_in.shape[1]])
logits = self.head(tgt_out)
final_output = {"predict":logits}
return final_output
def gen_tgt_perms(self, tgt):
"""Generate shared permutations for the whole batch.
This works because the same attention mask can be used for the shorter sequences
because of the padding mask.
"""
max_num_chars = tgt.shape[1] - 2
if max_num_chars == 1:
return paddle.arange(end=3).unsqueeze(axis=0)
perms = [paddle.arange(end=max_num_chars)] if self.perm_forward else []
max_perms = math.factorial(max_num_chars)
if self.perm_mirrored:
max_perms //= 2
num_gen_perms = min(self.max_gen_perms, max_perms)
if max_num_chars < 5:
if max_num_chars == 4 and self.perm_mirrored:
selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
else:
selector = list(range(max_perms))
perm_pool = paddle.to_tensor(data=list(permutations(range(max_num_chars), max_num_chars)), place=self._device)[selector]
if self.perm_forward:
perm_pool = perm_pool[1:]
perms = paddle.stack(x=perms)
if len(perm_pool):
i = self.rng.choice(len(perm_pool), size=num_gen_perms -
len(perms), replace=False)
perms = paddle.concat(x=[perms, perm_pool[i]])
else:
perms.extend([paddle.randperm(n=max_num_chars) for _ in range(num_gen_perms - len(perms))])
perms = paddle.stack(x=perms)
if self.perm_mirrored:
comp = perms.flip(axis=-1)
x = paddle.stack(x=[perms, comp])
perm_2 = list(range(x.ndim))
perm_2[0] = 1
perm_2[1] = 0
perms = x.transpose(perm=perm_2).reshape((-1, max_num_chars))
bos_idx = paddle.zeros(shape=(len(perms), 1), dtype=perms.dtype)
eos_idx = paddle.full(shape=(len(perms), 1), fill_value=
max_num_chars + 1, dtype=perms.dtype)
perms = paddle.concat(x=[bos_idx, perms + 1, eos_idx], axis=1)
if len(perms) > 1:
perms[(1), 1:] = max_num_chars + 1 - paddle.arange(end=max_num_chars + 1)
return perms
def generate_attn_masks(self, perm):
"""Generate attention masks given a sequence permutation (includes pos. for bos and eos tokens)
:param perm: the permutation sequence. i = 0 is always the BOS
:return: lookahead attention masks
"""
sz = perm.shape[0]
mask = paddle.zeros(shape=(sz, sz))
for i in range(sz):
query_idx = perm[i].cpu().numpy().tolist()
masked_keys = perm[i + 1:].cpu().numpy().tolist()
if len(masked_keys) == 0:
break
mask[query_idx, masked_keys] = float('-inf')
content_mask = mask[:-1, :-1].clone()
mask[paddle.eye(num_rows=sz).astype('bool')] = float('-inf')
query_mask = mask[1:, :-1]
return content_mask, query_mask
def forward_train(self, memory, tgt):
tgt_perms = self.gen_tgt_perms(tgt)
tgt_in = tgt[:, :-1]
tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id)
logits_list = []
final_out = {}
for i, perm in enumerate(tgt_perms):
tgt_mask, query_mask = self.generate_attn_masks(perm)
out = self.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query_mask=query_mask)
logits = self.head(out)
if i == 0:
final_out['predict'] = logits
logits = logits.flatten(stop_axis=1)
logits_list.append(logits)
final_out['logits_list'] = logits_list
final_out['pad_id'] = self.pad_id
final_out['eos_id'] = self.eos_id
return final_out
def forward(self, feat, targets=None):
# feat : B, N, C
# targets : labels, labels_len
if self.training:
label = targets[0] # label
label_len = targets[1]
max_step = paddle.max(label_len).cpu().numpy()[0] + 2
crop_label = label[:, :max_step]
final_out = self.forward_train(feat, crop_label)
else:
final_out = self.forward_test(feat)
return final_out

View File

@ -28,7 +28,7 @@ from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
SPINLabelDecode, VLLabelDecode, RFLLabelDecode, SATRNLabelDecode
SPINLabelDecode, VLLabelDecode, RFLLabelDecode, SATRNLabelDecode, ParseQLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess
@ -53,7 +53,7 @@ def build_post_process(config, global_config=None):
'DistillationSerPostProcess', 'DistillationRePostProcess',
'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess',
'RFLLabelDecode', 'DRRGPostprocess', 'CANLabelDecode',
'SATRNLabelDecode'
'SATRNLabelDecode', 'ParseQLabelDecode'
]
if config['name'] == 'PSEPostProcess':

View File

@ -559,6 +559,95 @@ class SRNLabelDecode(BaseRecLabelDecode):
% beg_or_end
return idx
class ParseQLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
BOS = '[B]'
EOS = '[E]'
PAD = '[P]'
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(ParseQLabelDecode, self).__init__(character_dict_path,
use_space_char)
self.max_text_length = kwargs.get('max_text_length', 25)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, dict):
pred = preds['predict']
else:
pred = preds
char_num = len(self.character_str) + 1  # [B] (BOS) and [P] (PAD) are not predicted; only [E] (EOS) is added
if isinstance(pred, paddle.Tensor):
pred = pred.numpy()
B, L = pred.shape[:2]
pred = np.reshape(pred, [-1, char_num])
preds_idx = np.argmax(pred, axis=1)
preds_prob = np.max(pred, axis=1)
preds_idx = np.reshape(preds_idx, [B, L])
preds_prob = np.reshape(preds_prob, [B, L])
if label is None:
text = self.decode(preds_idx, preds_prob, raw=False)
return text
text = self.decode(preds_idx, preds_prob, raw=False)
label = self.decode(label, None, False)
return text, label
def decode(self, text_index, text_prob=None, raw=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
index = text_index[batch_idx, :]
prob = None
if text_prob is not None:
prob = text_prob[batch_idx, :]
if not raw:
index, prob = self._filter(index, prob)
for idx in range(len(index)):
if index[idx] in ignored_tokens:
continue
char_list.append(self.character[int(index[idx])])
if text_prob is not None:
conf_list.append(prob[idx])
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def add_special_char(self, dict_character):
dict_character = [self.EOS] + dict_character + [self.BOS, self.PAD]
return dict_character
def _filter(self, ids, probs=None):
ids = ids.tolist()
try:
eos_idx = ids.index(self.dict[self.EOS])
except ValueError:
eos_idx = len(ids) # Nothing to truncate.
# Truncate after EOS
ids = ids[:eos_idx]
if probs is not None:
probs = probs[:eos_idx + 1] # but include prob. for EOS (if it exists)
return ids, probs
def get_ignored_tokens(self):
return [self.dict[self.BOS], self.dict[self.EOS], self.dict[self.PAD]]
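# Illustrative example (an assumption for clarity, not part of the original code):
# with '[E]' at index 0, _filter turns ids [5, 7, 0, 3] into [5, 7] (everything after the
# first EOS is dropped) and keeps probs[:3], i.e. the EOS probability is still included;
# get_ignored_tokens then makes decode() skip any remaining [B]/[E]/[P] indices.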
class SARLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """

View File

@ -0,0 +1,94 @@
0
1
2
3
4
5
6
7
8
9
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
:
;
<
=
>
?
@
[
\
]
^
_
`
{
|
}
~

View File

@ -29,18 +29,28 @@ fi
sed -i 's/use_gpu/use_xpu/g' $FILENAME
# disable benchmark as AutoLog required nvidia-smi command
sed -i 's/--benchmark:True/--benchmark:False/g' $FILENAME
# python has been updated to version 3.9 for xpu backend
sed -i "s/python3.7/python3.9/g" $FILENAME
dataline=`cat $FILENAME`
# parser params
IFS=$'\n'
lines=(${dataline})
modelname=$(echo ${lines[1]} | cut -d ":" -f2)
if [ $modelname == "rec_r31_sar" ] || [ $modelname == "rec_mtb_nrtr" ]; then
sed -i "s/Global.epoch_num:lite_train_lite_infer=2/Global.epoch_num:lite_train_lite_infer=1/g" $FILENAME
sed -i "s/gpu_list:0|0,1/gpu_list:0,1/g" $FILENAME
sed -i "s/Global.use_xpu:True|True/Global.use_xpu:True/g" $FILENAME
fi
# replace training config file
grep -n 'tools/.*yml' $FILENAME | cut -d ":" -f 1 \
| while read line_num ; do
train_cmd=$(func_parser_value "${lines[line_num-1]}")
trainer_config=$(func_parser_config ${train_cmd})
sed -i 's/use_gpu/use_xpu/g' "$REPO_ROOT_PATH/$trainer_config"
sed -i 's/use_sync_bn: True/use_sync_bn: False/g' "$REPO_ROOT_PATH/$trainer_config"
done
# change gpu to xpu in execution script

View File

@ -122,6 +122,12 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == "ParseQ":
postprocess_params = {
'name': 'ParseQLabelDecode',
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params)
self.postprocess_params = postprocess_params
self.predictor, self.input_tensor, self.output_tensors, self.config = \
@ -439,7 +445,7 @@ class TextRecognizer(object):
gsrm_slf_attn_bias1_list.append(norm_img[3])
gsrm_slf_attn_bias2_list.append(norm_img[4])
norm_img_batch.append(norm_img[0])
elif self.rec_algorithm in ["SVTR", "SATRN"]:
elif self.rec_algorithm in ["SVTR", "SATRN", "ParseQ"]:
norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]

View File

@ -231,7 +231,7 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input_models = [
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SVTR_LCNet", "SPIN", "VisionLAN",
"RobustScanner", "RFL", 'DRRG', 'SATRN', 'SVTR_HGNet'
"RobustScanner", "RFL", 'DRRG', 'SATRN', 'SVTR_HGNet', "ParseQ",
]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
@ -664,7 +664,7 @@ def preprocess(is_train=False):
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'SVTR_LCNet', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN',
'VisionLAN', 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG',
'CAN', 'Telescope', 'SATRN', 'SVTR_HGNet'
'CAN', 'Telescope', 'SATRN', 'SVTR_HGNet', 'ParseQ',
]
if use_xpu: