Add new recognition method "ParseQ" (#10836)
* Update PP-OCRv4_introduction.md
* Update PP-OCRv4_introduction.md (#10616)
* Update PP-OCRv4_introduction.md
* Update PP-OCRv4_introduction.md
* Update PP-OCRv4_introduction.md
* Update README.md
* Cherrypicking GH-10217 and GH-10216 to PaddlePaddle:Release/2.7 (#10655)
  * Don't break overall processing on a bad image
  * Add preprocessing common to OCR tasks; add preprocessing to options
* Update requirements.txt (#10656): added missing pyyaml library
* [TIPC] Update xpu tipc script (#10658)
* Fix typo (#10642)
  Co-authored-by: Dennis <dvorst@users.noreply.github.com>
  Co-authored-by: shiyutang <34859558+shiyutang@users.noreply.github.com>
* Fix the DSR error caused by data augmentation (#10662) (#10681)
  * Fix the DSR error caused by data augmentation
  * Roll back the incorrect change
* Update algorithm_overview_en.md (#10670): fixed simple spelling errors
* Implement recognition method ParseQ
* Document update for new recognition method ParseQ
* Add prediction for ParseQ
* Update rec_vit_parseq.yml
* Update rec_r31_sar.yml
* Update rec_r31_sar.yml
* Update rec_r50_fpn_srn.yml
* Update rec_vit_parseq.py
* Update rec_vit_parseq.yml
* Update rec_parseq_head.py
* Update rec_img_aug.py
* Update rec_vit_parseq.yml
* Update __init__.py
* Update predict_rec.py
* Update paddleocr.py
* Update requirements.txt
* Update utility.py
* Update utility.py

---------

Co-authored-by: xiaoting <31891223+tink2123@users.noreply.github.com>
Co-authored-by: topduke <784990967@qq.com>
Co-authored-by: dyning <dyning.2003@163.com>
Co-authored-by: UserUnknownFactor <63057995+UserUnknownFactor@users.noreply.github.com>
Co-authored-by: itasli <ilyas.tasli@outlook.fr>
Co-authored-by: Kai Song <50285351+USTCKAY@users.noreply.github.com>
Co-authored-by: dvorst <87502756+dvorst@users.noreply.github.com>
Co-authored-by: Dennis <dvorst@users.noreply.github.com>
Co-authored-by: shiyutang <34859558+shiyutang@users.noreply.github.com>
Co-authored-by: Dec20B <1192152456@qq.com>
Co-authored-by: ncoffman <51147417+ncoffman@users.noreply.github.com>
parent ab86490138
commit 75d16610f4
|
@ -69,7 +69,7 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力
|
|||
<a name="技术交流合作"></a>
|
||||
## 📖 技术交流合作
|
||||
- 飞桨AI套件([PaddleX](http://10.136.157.23:8080/paddle/paddleX))提供了飞桨模型训压推一站式全流程高效率开发平台,其使命是助力AI技术快速落地,愿景是使人人成为AI Developer!
|
||||
- PaddleX 目前覆盖图像分类、目标检测、图像分割、3D、OCR和时序预测等领域方向,已内置了36种基础单模型,例如RP-DETR、PP-YOLOE、PP-HGNet、PP-LCNet、PP-LiteSeg等;集成了12种实用的产业方案,例如PP-OCRv4、PP-ChatOCR、PP-ShiTu、PP-TS、车载路面垃圾检测、野生动物违禁制品识别等。
|
||||
- PaddleX 目前覆盖图像分类、目标检测、图像分割、3D、OCR和时序预测等领域方向,已内置了36种基础单模型,例如RT-DETR、PP-YOLOE、PP-HGNet、PP-LCNet、PP-LiteSeg等;集成了12种实用的产业方案,例如PP-OCRv4、PP-ChatOCR、PP-ShiTu、PP-TS、车载路面垃圾检测、野生动物违禁制品识别等。
|
||||
- PaddleX 提供了“工具箱”和“开发者”两种AI开发模式。工具箱模式可以无代码调优关键超参,开发者模式可以低代码进行单模型训压推和多模型串联推理,同时支持云端和本地端。
|
||||
- PaddleX 还支持联创开发,利润分成!目前 PaddleX 正在快速迭代,欢迎广大的个人开发者和企业开发者参与进来,共创繁荣的 AI 技术生态!
|
||||
|
||||
|
|
|
@ -0,0 +1,116 @@
|
|||
Global:
|
||||
use_gpu: True
|
||||
epoch_num: 20
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 5
|
||||
save_model_dir: ./output/rec/parseq
|
||||
save_epoch_step: 3
|
||||
# evaluation is run every 500 iterations
|
||||
eval_batch_step: [0, 500]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words_en/word_10.png
|
||||
# for data or label process
|
||||
character_dict_path: ppocr/utils/dict/parseq_dict.txt
|
||||
character_type: en
|
||||
max_text_length: 25
|
||||
num_heads: 8
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_parseq.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: OneCycle
|
||||
max_lr: 0.0007
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: ParseQ
|
||||
in_channels: 3
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ViTParseQ
|
||||
img_size: [32, 128]
|
||||
patch_size: [4, 8]
|
||||
embed_dim: 384
|
||||
depth: 12
|
||||
num_heads: 6
|
||||
mlp_ratio: 4
|
||||
in_channels: 3
|
||||
Head:
|
||||
name: ParseQHead
|
||||
# Architecture
|
||||
max_text_length: 25
|
||||
embed_dim: 384
|
||||
dec_num_heads: 12
|
||||
dec_mlp_ratio: 4
|
||||
dec_depth: 1
|
||||
# Training
|
||||
perm_num: 6
|
||||
perm_forward: true
|
||||
perm_mirrored: true
|
||||
dropout: 0.1
|
||||
# Decoding mode (test)
|
||||
decode_ar: true
|
||||
refine_iters: 1
|
||||
|
||||
Loss:
|
||||
name: ParseQLoss
|
||||
|
||||
PostProcess:
|
||||
name: ParseQLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
is_filter: True
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir:
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- ParseQRecAug:
|
||||
aug_type: 0 # or 1
|
||||
- ParseQLabelEncode:
|
||||
- SVTRRecResizeImg:
|
||||
image_shape: [3, 32, 128]
|
||||
padding: False
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 192
|
||||
drop_last: True
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir:
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- ParseQLabelEncode: # Class handling label
|
||||
- SVTRRecResizeImg:
|
||||
image_shape: [3, 32, 128]
|
||||
padding: False
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 384
|
||||
num_workers: 4
|
|
@ -81,13 +81,13 @@ PP-OCRv4检测模型对PP-OCRv3中的CML(Collaborative Mutual Learning) 协同
|
|||
<a name="3"></a>
|
||||
## 3. 识别优化
|
||||
|
||||
PP-OCRv3的识别模块是基于文本识别算法[SVTR](https://arxiv.org/abs/2205.00159)优化。SVTR不再采用RNN结构,通过引入Transformers结构更加有效地挖掘文本行图像的上下文信息,从而提升文本识别能力。直接将PP-OCRv2的识别模型,替换成SVTR_Tiny,识别准确率从74.8%提升到80.1%(+5.3%),但是预测速度慢了将近11倍,CPU上预测一条文本行,将近100ms。因此,如下图所示,PP-OCRv3采用如下6个优化策略进行识别模型加速。
|
||||
PP-OCRv4识别模型在PP-OCRv3的基础上进一步升级。如下图所示,整体的框架图保持了与PP-OCRv3识别模型相同的pipeline,分别进行了数据、网络结构、训练策略等方面的优化。
|
||||
|
||||
<div align="center">
|
||||
<img src="../ppocr_v4/v4_rec_pipeline.png" width=800>
|
||||
</div>
|
||||
|
||||
基于上述策略,PP-OCRv4识别模型相比PP-OCRv3,在速度可比的情况下,精度进一步提升4%。 具体消融实验如下所示:
|
||||
经过如图所示的策略优化,PP-OCRv4识别模型相比PP-OCRv3,在速度可比的情况下,精度进一步提升4%。 具体消融实验如下所示:
|
||||
|
||||
| ID | 策略 | 模型大小 | 精度 | 预测耗时(CPU openvino)|
|
||||
|-----|-----|--------|----| --- |
|
||||
|
@ -103,8 +103,8 @@ PP-OCRv3的识别模块是基于文本识别算法[SVTR](https://arxiv.org/abs/2
|
|||
|
||||
**(1)DF:数据挖掘方案**
|
||||
|
||||
DF(Data Filter) 是一种简单有效的数据挖掘方案。核心思想是利用已有模型预测训练数据,通过置信度和预测结果等信息,对全量数据进行筛选。具体的:首先使用少量数据快速训练得到一个低精度模型,使用该低精度模型对千万级的数据进行预测,去除置信度大于0.95的样本,该部分被认为是对提升模型精度无效的冗余数据。其次使用PP-OCRv3作为高精度模型,对剩余数据进行预测,去除置信度小于0.15的样本,该部分被认为是难以识别或质量很差的样本。
|
||||
使用该策略,千万级别训练数据被精简至百万级,显著提升模型训练效率,模型训练时间从2周减少到5天,同时精度提升至72.7%(+1.2%)。
|
||||
DF(Data Filter) 是一种简单有效的数据挖掘方案。核心思想是利用已有模型预测训练数据,通过置信度和预测结果等信息,对全量的训练数据进行筛选。具体的:首先使用少量数据快速训练得到一个低精度模型,使用该低精度模型对千万级的数据进行预测,去除置信度大于0.95的样本,该部分被认为是对提升模型精度无效的冗余样本。其次使用PP-OCRv3作为高精度模型,对剩余数据进行预测,去除置信度小于0.15的样本,该部分被认为是难以识别或质量很差的样本。
|
||||
使用该策略,千万级别训练数据被精简至百万级,模型训练时间从2周减少到5天,显著提升了训练效率,同时精度提升至72.7%(+1.2%)。
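The two-stage filtering described above can be summarized with a short sketch. This is illustrative only and not the official implementation; `predict_with_confidence(model, img)` is a hypothetical helper returning the predicted text and its confidence, and the thresholds follow the description above.

```python
# Illustrative sketch of the DF (Data Filter) strategy; not the repo's implementation.
def data_filter(samples, low_acc_model, high_acc_model):
    kept = []
    for img, label in samples:
        _, conf = predict_with_confidence(low_acc_model, img)   # hypothetical helper
        if conf > 0.95:   # confidently solved even by a weak model -> redundant sample, drop
            continue
        _, conf = predict_with_confidence(high_acc_model, img)  # PP-OCRv3 as the strong model
        if conf < 0.15:   # hard-to-recognize or low-quality sample -> drop
            continue
        kept.append((img, label))
    return kept
```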
|
||||
|
||||
|
||||
<div align="center">
|
||||
|
@ -118,12 +118,12 @@ PP-LCNetV3系列模型是PP-LCNet系列模型的延续,覆盖了更大的精
|
|||
|
||||
**(3)Lite-Neck:精简参数的Neck结构**
|
||||
|
||||
Lite-Neck整体结构沿用PP-OCRv3版本,在参数上稍作精简,识别模型整体的模型大小可从12M降低到8.5M,而精度不变;在CTCHead中,将Neck输出特征的维度从64提升到120,此时模型大小从8.5M提升到9.6M,精度提升0.5%。
|
||||
Lite-Neck整体结构沿用PP-OCRv3版本的结构,在参数上稍作精简,识别模型整体的模型大小可从12M降低到8.5M,而精度不变;在CTCHead中,将Neck输出特征的维度从64提升到120,此时模型大小从8.5M提升到9.6M。
|
||||
|
||||
|
||||
**(4)GTC-NRTR:Attention指导CTC训练策略**
|
||||
|
||||
GTC(Guided Training of CTC),是在PP-OCRv3中使用过的策略,融合多种文本特征的表达,有效的提升文本识别精度。在PP-OCRv4中使用训练更稳定的Transformer模型NRTR作为指导,相比SAR基于循环神经网络的结构,NRTR基于Transformer实现解码过程泛化能力更强,能有效指导CTC分支学习。解决简单场景下快速过拟合的问题。模型大小不变,识别精度提升至73.21%(+0.5%)。
|
||||
GTC(Guided Training of CTC),是PP-OCRv3识别模型的最有效的策略之一,融合多种文本特征的表达,有效的提升文本识别精度。在PP-OCRv4中使用训练更稳定的Transformer模型NRTR作为指导分支,相比V3版本中的SAR基于循环神经网络的结构,NRTR基于Transformer实现解码过程泛化能力更强,能有效指导CTC分支学习,解决简单场景下快速过拟合的问题。使用Lite-Neck和GTC-NRTR两个策略,识别精度提升至73.21%(+0.5%)。
|
||||
|
||||
<div align="center">
|
||||
<img src="../ppocr_v4/ppocrv4_gtc.png" width="500">
|
||||
|
@ -132,7 +132,7 @@ GTC(Guided Training of CTC),是在PP-OCRv3中使用过的策略,融合
|
|||
|
||||
**(5)Multi-Scale:多尺度训练策略**
|
||||
|
||||
动态尺度训练策略,是在训练过程中随机resize输入图片的高度,以增大模型的鲁棒性。在训练过程中随机选择(32,48,64)三种高度进行resize,实验证明在测试集上评估精度不掉,在端到端串联推理时,指标可以提升0.5%。
|
||||
动态尺度训练策略,是在训练过程中随机resize输入图片的高度,以增强识别模型在端到端串联使用时的鲁棒性。在训练时,每个iter从(32,48,64)三种高度中随机选择一种高度进行resize。实验证明,使用该策略,尽管在识别测试集上准确率没有提升,但在端到端串联评估时,指标提升0.5%。
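As a concrete illustration of the strategy, the resize step could look like the following sketch (an assumed helper, not the repo's actual transform); each training iteration samples one of the three heights and keeps the aspect ratio.

```python
import random
import cv2

def multi_scale_resize(img, heights=(32, 48, 64)):
    """Randomly pick a target height per iteration and resize while keeping aspect ratio."""
    h = random.choice(heights)
    w = max(1, int(img.shape[1] * h / img.shape[0]))
    return cv2.resize(img, (w, h))
```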
|
||||
|
||||
<div align="center">
|
||||
<img src="../ppocr_v4/multi_scale.png" width="500">
|
||||
|
@ -143,9 +143,9 @@ GTC(Guided Training of CTC),是在PP-OCRv3中使用过的策略,融合
|
|||
|
||||
识别模型的蒸馏包含两个部分,NRTRhead蒸馏和CTCHead蒸馏;
|
||||
|
||||
对于NRTR head,使用了DKD loss蒸馏,使学生模型NRTR head输出的logits与教师NRTR head接近。最终NRTR head的loss是学生与教师间的DKD loss和与ground truth的cross entropy loss的加权和,用于监督学生模型的backbone训练。通过实验,我们发现加入DKD loss后,计算与ground truth的cross entropy loss时去除label smoothing可以进一步提高精度,因此我们在这里使用的是不带label smoothing的cross entropy loss。
|
||||
对于NRTR head,使用了DKD loss蒸馏,拉近学生模型和教师模型的NRTR head logits。最终NRTR head的loss是学生与教师间的DKD loss和与ground truth的cross entropy loss的加权和,用于监督学生模型的backbone训练。通过实验,我们发现加入DKD loss后,计算与ground truth的cross entropy loss时去除label smoothing可以进一步提高精度,因此我们在这里使用的是不带label smoothing的cross entropy loss。
|
||||
|
||||
对于CTCHead,由于CTC的输出中存在Blank位,即使教师模型和学生模型的预测结果一样,二者的输出的logits分布也会存在差异,影响教师模型向学生模型的知识传递。PP-OCRv4识别模型蒸馏策略中,将CTC输出logits沿着文本长度维度计算均值,将多字符识别问题转换为多字符分类问题,用于监督CTC Head的训练。使用该策略融合NRTRhead DKD蒸馏策略,指标从0.7377提升到0.7545。
|
||||
对于CTCHead,由于CTC的输出中存在Blank位,即使教师模型和学生模型的预测结果一样,二者的输出的logits分布也会存在差异,影响教师模型向学生模型的知识传递。PP-OCRv4识别模型蒸馏策略中,将CTC输出logits沿着文本长度维度计算均值,将多字符识别问题转换为多字符分类问题,用于监督CTC Head的训练。使用该策略融合NRTRhead DKD蒸馏策略,指标从74.72%提升到75.45%。
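A minimal sketch of the length-averaging idea: CTC logits of shape (B, T, C) are averaged over the time axis T, so the distillation target becomes a single distribution per image. The KL divergence below is an assumed choice of distance for illustration and is not necessarily the exact loss used in PP-OCRv4.

```python
import paddle
import paddle.nn.functional as F

def ctc_mean_distill_loss(student_logits, teacher_logits):
    # logits: (B, T, C); average over the text-length axis T
    s = student_logits.mean(axis=1)
    t = teacher_logits.mean(axis=1)
    return F.kl_div(F.log_softmax(s, axis=-1), F.softmax(t, axis=-1), reduction='mean')
```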
|
||||
|
||||
|
||||
|
||||
|
@ -169,11 +169,11 @@ GTC(Guided Training of CTC),是在PP-OCRv3中使用过的策略,融合
|
|||
| PP-OCRv3_en | 64.04% |
|
||||
| PP-OCRv4_en | 70.1% |
|
||||
|
||||
同时,也对已支持的80余种语言识别模型进行了升级更新,在有评估集的四种语系识别准确率平均提升5%以上,如下表所示:
|
||||
同时,对已支持的80余种语言识别模型进行了升级更新,在有评估集的四种语系识别准确率平均提升8%以上,如下表所示:
|
||||
|
||||
| Model | 拉丁语系 | 阿拉伯语系 | 日语 | 韩语 |
|
||||
|-----|-----|--------|----| --- |
|
||||
| PP-OCR_mul | 69.60% | 40.50% | 38.50% | 55.40% |
|
||||
| PP-OCRv3_mul | 75.20%| 45.37% | 45.80% | 60.10% |
|
||||
| PP-OCRv3_mul | 71.57%| 72.90% | 45.85% | 77.23% |
|
||||
| PP-OCRv4_mul | 80.00%| 75.48% | 56.50% | 83.25% |
|
||||
|
||||
|
|
|
@ -86,6 +86,7 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广
|
|||
- [x] [SPIN](./algorithm_rec_spin.md)
|
||||
- [x] [RobustScanner](./algorithm_rec_robustscanner.md)
|
||||
- [x] [RFL](./algorithm_rec_rfl.md)
|
||||
- [x] [ParseQ](./algorithm_rec_parseq.md)
|
||||
|
||||
参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|
||||
|
||||
|
@ -110,6 +111,7 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广
|
|||
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
|
||||
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
|
||||
|RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl_att_train.tar) |
|
||||
|ParseQ|VIT| 91.24% | rec_vit_parseq_synth | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_synth.tgz) |
|
||||
|
||||
|
||||
<a name="13"></a>
|
||||
|
|
|
@ -0,0 +1,124 @@
|
|||
# ParseQ
|
||||
|
||||
- [1. 算法简介](#1)
|
||||
- [2. 环境配置](#2)
|
||||
- [3. 模型训练、评估、预测](#3)
|
||||
- [3.1 训练](#3-1)
|
||||
- [3.2 评估](#3-2)
|
||||
- [3.3 预测](#3-3)
|
||||
- [4. 推理部署](#4)
|
||||
- [4.1 Python推理](#4-1)
|
||||
- [4.2 C++推理](#4-2)
|
||||
- [4.3 Serving服务化部署](#4-3)
|
||||
- [4.4 更多推理部署](#4-4)
|
||||
- [5. FAQ](#5)
|
||||
|
||||
<a name="1"></a>
|
||||
## 1. 算法简介
|
||||
|
||||
论文信息:
|
||||
> [Scene Text Recognition with Permuted Autoregressive Sequence Models](https://arxiv.org/abs/2207.06966)
|
||||
> Darwin Bautista, Rowel Atienza
|
||||
> ECCV, 2021
|
||||
|
||||
原论文分别使用真实文本识别数据集(Real)和合成文本识别数据集(Synth)进行训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估。
|
||||
其中:
|
||||
- 真实文本识别数据集(Real)包含COCO-Text, RCTW17, Uber-Text, ArT, LSVT, MLT19, ReCTS, TextOCR, OpenVINO数据集
|
||||
- 合成文本识别数据集(Synth)包含MJSynth和SynthText数据集
|
||||
|
||||
在不同数据集上训练的算法的复现效果如下:
|
||||
|
||||
|数据集|模型|骨干网络|配置文件|Acc|下载链接|
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
|Synth|ParseQ|VIT|[rec_vit_parseq.yml](../../configs/rec/rec_vit_parseq.yml)|91.24%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_synth.tgz)|
|
||||
|Real|ParseQ|VIT|[rec_vit_parseq.yml](../../configs/rec/rec_vit_parseq.yml)|94.74%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_real.tgz)|
|
||||
|
||||
|
||||
<a name="2"></a>
|
||||
## 2. 环境配置
|
||||
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
|
||||
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. 模型训练、评估、预测
|
||||
|
||||
请参考[文本识别教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练不同的识别模型只需要**更换配置文件**即可。
|
||||
|
||||
训练
|
||||
|
||||
具体地,在完成数据准备后,便可以启动训练,训练命令如下:
|
||||
|
||||
```
|
||||
#单卡训练(训练周期长,不建议)
|
||||
python3 tools/train.py -c configs/rec/rec_vit_parseq.yml
|
||||
|
||||
#多卡训练,通过--gpus参数指定卡号
|
||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_vit_parseq.yml
|
||||
```
|
||||
|
||||
评估
|
||||
|
||||
```
|
||||
# GPU 评估, Global.pretrained_model 为待测权重
|
||||
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
|
||||
```
|
||||
|
||||
预测:
|
||||
|
||||
```
|
||||
# 预测使用的配置文件必须与训练一致
|
||||
python3 tools/infer_rec.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
|
||||
```
|
||||
|
||||
<a name="4"></a>
|
||||
## 4. 推理部署
|
||||
|
||||
<a name="4-1"></a>
|
||||
### 4.1 Python推理
|
||||
首先将ParseQ文本识别训练过程中保存的模型,转换成inference model。( [模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_real.tgz) ),可以使用如下命令进行转换:
|
||||
|
||||
```
|
||||
python3 tools/export_model.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model=./rec_vit_parseq_real/best_accuracy Global.save_inference_dir=./inference/rec_parseq
|
||||
```
|
||||
|
||||
ParseQ文本识别模型推理,可以执行如下命令:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_parseq/" --rec_image_shape="3, 32, 128" --rec_algorithm="ParseQ" --rec_char_dict_path="ppocr/utils/dict/parseq_dict.txt" --max_text_length=25 --use_space_char=False
|
||||
```
|
||||
|
||||
<a name="4-2"></a>
|
||||
### 4.2 C++推理
|
||||
|
||||
由于C++预处理后处理还未支持ParseQ,所以暂未支持
|
||||
|
||||
<a name="4-3"></a>
|
||||
### 4.3 Serving服务化部署
|
||||
|
||||
暂不支持
|
||||
|
||||
<a name="4-4"></a>
|
||||
### 4.4 更多推理部署
|
||||
|
||||
暂不支持
|
||||
|
||||
<a name="5"></a>
|
||||
## 5. FAQ
|
||||
|
||||
|
||||
## 引用
|
||||
|
||||
```bibtex
|
||||
@InProceedings{bautista2022parseq,
|
||||
title={Scene Text Recognition with Permuted Autoregressive Sequence Models},
|
||||
author={Bautista, Darwin and Atienza, Rowel},
|
||||
booktitle={European Conference on Computer Vision},
|
||||
pages={178--196},
|
||||
month={10},
|
||||
year={2022},
|
||||
publisher={Springer Nature Switzerland},
|
||||
address={Cham},
|
||||
doi={10.1007/978-3-031-19815-1_11},
|
||||
url={https://doi.org/10.1007/978-3-031-19815-1_11}
|
||||
}
|
||||
```
|
|
@ -41,8 +41,8 @@ On the ICDAR2015 dataset, the text detection result is as follows:
|
|||
|DB|ResNet50_vd|86.41%|78.72%|82.38%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
|
||||
|DB|MobileNetV3|77.29%|73.08%|75.12%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
|
||||
|SAST|ResNet50_vd|91.39%|83.77%|87.42%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
|
||||
|PSE|ResNet50_vd|85.81%|79.53%|82.55%|[trianed model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)|
|
||||
|PSE|MobileNetV3|82.20%|70.48%|75.89%|[trianed model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)|
|
||||
|PSE|ResNet50_vd|85.81%|79.53%|82.55%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)|
|
||||
|PSE|MobileNetV3|82.20%|70.48%|75.89%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)|
|
||||
|DB++|ResNet50|90.89%|82.66%|86.58%|[pretrained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams)/[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_db%2B%2B_icdar15_train.tar)|
|
||||
|
||||
On Total-Text dataset, the text detection result is as follows:
|
||||
|
@ -83,6 +83,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
|
|||
- [x] [SPIN](./algorithm_rec_spin_en.md)
|
||||
- [x] [RobustScanner](./algorithm_rec_robustscanner_en.md)
|
||||
- [x] [RFL](./algorithm_rec_rfl_en.md)
|
||||
- [x] [ParseQ](./algorithm_rec_parseq.md)
|
||||
|
||||
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
|
||||
|
||||
|
@ -107,6 +108,8 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|
|||
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
|
||||
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
|
||||
|RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl_att_train.tar) |
|
||||
|ParseQ|VIT| 91.24% | rec_vit_parseq_synth | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_synth.tgz) |
|
||||
|
||||
|
||||
<a name="13"></a>
|
||||
|
||||
|
|
|
@ -0,0 +1,123 @@
|
|||
# ParseQ
|
||||
|
||||
- [1. Introduction](#1)
|
||||
- [2. Environment](#2)
|
||||
- [3. Model Training / Evaluation / Prediction](#3)
|
||||
- [3.1 Training](#3-1)
|
||||
- [3.2 Evaluation](#3-2)
|
||||
- [3.3 Prediction](#3-3)
|
||||
- [4. Inference and Deployment](#4)
|
||||
- [4.1 Python Inference](#4-1)
|
||||
- [4.2 C++ Inference](#4-2)
|
||||
- [4.3 Serving](#4-3)
|
||||
- [4.4 More](#4-4)
|
||||
- [5. FAQ](#5)
|
||||
|
||||
<a name="1"></a>
|
||||
## 1. Introduction
|
||||
|
||||
Paper:
|
||||
> [Scene Text Recognition with Permuted Autoregressive Sequence Models](https://arxiv.org/abs/2207.06966)
|
||||
> Darwin Bautista, Rowel Atienza
|
||||
> ECCV, 2021
|
||||
|
||||
The original paper trains on real datasets (Real) and synthetic datasets (Synth) respectively, and evaluates on the IIIT, SVT, IC03, IC13, IC15, SVTP and CUTE datasets.
|
||||
- The real datasets include COCO-Text, RCTW17, Uber-Text, ArT, LSVT, MLT19, ReCTS, TextOCR and OpenVINO datasets.
|
||||
- The synthetic datasets include the MJSynth and SynthText datasets.
|
||||
|
||||
The algorithm reproduction results are as follows:
|
||||
|
||||
|Training Dataset|Model|Backbone|config|Acc|Download link|
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
|Synth|ParseQ|VIT|[rec_vit_parseq.yml](../../configs/rec/rec_vit_parseq.yml)|91.24%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_synth.tgz)|
|
||||
|Real|ParseQ|VIT|[rec_vit_parseq.yml](../../configs/rec/rec_vit_parseq.yml)|94.74%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_real.tgz)|
|
||||
|
||||
|
||||
<a name="2"></a>
|
||||
## 2. Environment
|
||||
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
|
||||
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. Model Training / Evaluation / Prediction
|
||||
|
||||
Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
|
||||
|
||||
Training:
|
||||
|
||||
Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
|
||||
|
||||
```
|
||||
#Single GPU training (long training period, not recommended)
|
||||
python3 tools/train.py -c configs/rec/rec_vit_parseq.yml
|
||||
|
||||
#Multi GPU training, specify the gpu number through the --gpus parameter
|
||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_vit_parseq.yml
|
||||
```
|
||||
|
||||
Evaluation:
|
||||
|
||||
```
|
||||
# GPU evaluation
|
||||
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
|
||||
```
|
||||
|
||||
Prediction:
|
||||
|
||||
```
|
||||
# The configuration file used for prediction must match the training
|
||||
python3 tools/infer_rec.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
|
||||
```
|
||||
|
||||
<a name="4"></a>
|
||||
## 4. Inference and Deployment
|
||||
|
||||
<a name="4-1"></a>
|
||||
### 4.1 Python Inference
|
||||
First, the model saved during the ParseQ text recognition training process is converted into an inference model ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_real.tgz)). You can use the following command to convert it:
|
||||
|
||||
```
|
||||
python3 tools/export_model.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model=./rec_vit_parseq_real/best_accuracy Global.save_inference_dir=./inference/rec_parseq
|
||||
```
|
||||
|
||||
For ParseQ text recognition model inference, the following commands can be executed:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_parseq/" --rec_image_shape="3, 32, 128" --rec_algorithm="ParseQ" --rec_char_dict_path="ppocr/utils/dict/parseq_dict.txt" --max_text_length=25 --use_space_char=False
|
||||
```
|
||||
|
||||
<a name="4-2"></a>
|
||||
### 4.2 C++ Inference
|
||||
|
||||
Not supported
|
||||
|
||||
<a name="4-3"></a>
|
||||
### 4.3 Serving
|
||||
|
||||
Not supported
|
||||
|
||||
<a name="4-4"></a>
|
||||
### 4.4 More
|
||||
|
||||
Not supported
|
||||
|
||||
<a name="5"></a>
|
||||
## 5. FAQ
|
||||
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@InProceedings{bautista2022parseq,
|
||||
title={Scene Text Recognition with Permuted Autoregressive Sequence Models},
|
||||
author={Bautista, Darwin and Atienza, Rowel},
|
||||
booktitle={European Conference on Computer Vision},
|
||||
pages={178--196},
|
||||
month={10},
|
||||
year={2022},
|
||||
publisher={Springer Nature Switzerland},
|
||||
address={Cham},
|
||||
doi={10.1007/978-3-031-19815-1_11},
|
||||
url={https://doi.org/10.1007/978-3-031-19815-1_11}
|
||||
}
|
||||
```
|
|
@ -513,7 +513,7 @@ def get_model_config(type, version, model_type, lang):
|
|||
|
||||
def img_decode(content: bytes):
|
||||
np_arr = np.frombuffer(content, dtype=np.uint8)
|
||||
return cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
|
||||
return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
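# Note: cv2.IMREAD_UNCHANGED keeps the source channel layout (e.g. alpha channel or
# grayscale) instead of forcing 3-channel BGR; channel normalization is assumed to be
# handled by the downstream preprocessing (e.g. check_img) before recognition.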
|
||||
|
||||
|
||||
def check_img(img):
|
||||
|
|
|
@ -27,7 +27,7 @@ from .make_pse_gt import MakePseGt
|
|||
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
|
||||
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
|
||||
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \
|
||||
RFLRecResizeImg, SVTRRecAug
|
||||
RFLRecResizeImg, SVTRRecAug, ParseQRecAug
|
||||
from .ssl_img_aug import SSLRotateResize
|
||||
from .randaugment import RandAugment
|
||||
from .copy_paste import CopyPaste
|
||||
|
|
|
@ -316,6 +316,35 @@ class CVGaussianNoise(object):
|
|||
img = np.clip(img + noise, 0, 255).astype(np.uint8)
|
||||
return img
|
||||
|
||||
class CVPossionNoise(object):
|
||||
def __init__(self, lam=20):
|
||||
self.lam = lam
|
||||
if isinstance(lam, numbers.Number):
|
||||
self.lam = max(int(sample_asym(lam)), 1)
|
||||
elif isinstance(lam, (tuple, list)) and len(lam) == 2:
|
||||
self.lam = int(sample_uniform(lam[0], lam[1]))
|
||||
else:
|
||||
raise Exception('lam must be number or list with length 2')
|
||||
|
||||
def __call__(self, img):
|
||||
noise = np.random.poisson(lam=self.lam, size=img.shape)
|
||||
img = np.clip(img + noise, 0, 255).astype(np.uint8)
|
||||
return img
|
||||
|
||||
class CVGaussionBlur(object):
|
||||
def __init__(self, radius):
|
||||
self.radius = radius
|
||||
if isinstance(radius, numbers.Number):
|
||||
self.radius = max(int(sample_asym(radius)), 1)
|
||||
elif isinstance(radius, (tuple, list)) and len(radius) == 2:
|
||||
self.radius = int(sample_uniform(radius[0], radius[1]))
|
||||
else:
|
||||
raise Exception('radius must be number or list with length 2')
|
||||
|
||||
def __call__(self, img):
|
||||
fil = cv2.getGaussianKernel(ksize=self.radius, sigma=1, ktype=cv2.CV_32F)
|
||||
img = cv2.sepFilter2D(img, -1, fil, fil)
|
||||
return img
|
||||
|
||||
class CVMotionBlur(object):
|
||||
def __init__(self, degrees=12, angle=90):
|
||||
|
@ -427,6 +456,29 @@ class SVTRDeterioration(object):
|
|||
else:
|
||||
return img
|
||||
|
||||
class ParseQDeterioration(object):
|
||||
def __init__(self, var, degrees, lam, radius, factor, p=0.5):
|
||||
self.p = p
|
||||
transforms = []
|
||||
if var is not None:
|
||||
transforms.append(CVGaussianNoise(var=var))
|
||||
if degrees is not None:
|
||||
transforms.append(CVMotionBlur(degrees=degrees))
|
||||
if lam is not None:
|
||||
transforms.append(CVPossionNoise(lam=lam))
|
||||
if radius is not None:
|
||||
transforms.append(CVGaussionBlur(radius=radius))
|
||||
if factor is not None:
|
||||
transforms.append(CVRescale(factor=factor))
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, img):
|
||||
if random.random() < self.p:
|
||||
random.shuffle(self.transforms)
|
||||
transforms = Compose(self.transforms)
|
||||
return transforms(img)
|
||||
else:
|
||||
return img
|
||||
|
||||
class SVTRGeometry(object):
|
||||
def __init__(self,
|
||||
|
|
|
@ -1305,6 +1305,37 @@ class NRTRLabelEncode(BaseRecLabelEncode):
|
|||
dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
|
||||
return dict_character
|
||||
|
||||
class ParseQLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
BOS = '[B]'
|
||||
EOS = '[E]'
|
||||
PAD = '[P]'
|
||||
|
||||
def __init__(self,
|
||||
max_text_length,
|
||||
character_dict_path=None,
|
||||
use_space_char=False,
|
||||
**kwargs):
|
||||
|
||||
super(ParseQLabelEncode, self).__init__(
|
||||
max_text_length, character_dict_path, use_space_char)
|
||||
|
||||
def __call__(self, data):
|
||||
text = data['label']
|
||||
text = self.encode(text)
|
||||
if text is None:
|
||||
return None
|
||||
if len(text) >= self.max_text_len - 2:
|
||||
return None
|
||||
data['length'] = np.array(len(text))
|
||||
text = [self.dict[self.BOS]] + text + [self.dict[self.EOS]]
|
||||
text = text + [self.dict[self.PAD]] * (self.max_text_len - len(text))
|
||||
data['label'] = np.array(text)
|
||||
return data
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
dict_character = [self.EOS] + dict_character + [self.BOS, self.PAD]
|
||||
return dict_character
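# Illustrative note on the resulting index layout: add_special_char places '[E]' (EOS) at
# index 0, the characters at 1..N, then '[B]' (BOS) at N+1 and '[P]' (PAD) at N+2, which
# matches ParseQHead's assumption of eos_id=0, bos_id=out_channels-2, pad_id=out_channels-1.
# A two-character label encoded as [1, 2] therefore becomes
#   [BOS, 1, 2, EOS, PAD, ..., PAD]   (padded to max_text_len)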
|
||||
|
||||
class ViTSTRLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
|
|
@ -20,7 +20,7 @@ import copy
|
|||
from PIL import Image
|
||||
import PIL
|
||||
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
|
||||
from .abinet_aug import CVGeometry, CVDeterioration, CVColorJitter, SVTRGeometry, SVTRDeterioration
|
||||
from .abinet_aug import CVGeometry, CVDeterioration, CVColorJitter, SVTRGeometry, SVTRDeterioration, ParseQDeterioration
|
||||
from paddle.vision.transforms import Compose
|
||||
|
||||
|
||||
|
@ -204,6 +204,36 @@ class SVTRRecAug(object):
|
|||
data['image'] = img
|
||||
return data
|
||||
|
||||
class ParseQRecAug(object):
|
||||
def __init__(self,
|
||||
aug_type=0,
|
||||
geometry_p=0.5,
|
||||
deterioration_p=0.25,
|
||||
colorjitter_p=0.25,
|
||||
**kwargs):
|
||||
self.transforms = Compose([
|
||||
SVTRGeometry(
|
||||
aug_type=aug_type,
|
||||
degrees=45,
|
||||
translate=(0.0, 0.0),
|
||||
scale=(0.5, 2.),
|
||||
shear=(45, 15),
|
||||
distortion=0.5,
|
||||
p=geometry_p), ParseQDeterioration(
|
||||
var=20, degrees=6, lam=20, radius=2.0, factor=4, p=deterioration_p),
|
||||
CVColorJitter(
|
||||
brightness=0.5,
|
||||
contrast=0.5,
|
||||
saturation=0.5,
|
||||
hue=0.1,
|
||||
p=colorjitter_p)
|
||||
])
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
img = self.transforms(img)
|
||||
data['image'] = img
|
||||
return data
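# Usage sketch (assumed standalone call; during training this op is built from the
# `ParseQRecAug` entry of the yml transforms):
#   aug = ParseQRecAug(aug_type=0)
#   data = aug({'image': img})  # img: HxWxC uint8 array; geometry, deterioration and
#                               # color jitter are each applied with their configured probabilities.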
|
||||
|
||||
class ClsResizeImg(object):
|
||||
def __init__(self, image_shape, **kwargs):
|
||||
|
|
|
@ -43,6 +43,7 @@ from .rec_rfl_loss import RFLLoss
|
|||
from .rec_can_loss import CANLoss
|
||||
from .rec_satrn_loss import SATRNLoss
|
||||
from .rec_nrtr_loss import NRTRLoss
|
||||
from .rec_parseq_loss import ParseQLoss
|
||||
|
||||
# cls loss
|
||||
from .cls_loss import ClsLoss
|
||||
|
@ -76,7 +77,7 @@ def build_loss(config):
|
|||
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
|
||||
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
|
||||
'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'CANLoss', 'TelescopeLoss',
|
||||
'SATRNLoss', 'NRTRLoss'
|
||||
'SATRNLoss', 'NRTRLoss', 'ParseQLoss'
|
||||
]
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class ParseQLoss(nn.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super(ParseQLoss, self).__init__()
|
||||
|
||||
def forward(self, predicts, targets):
|
||||
label = targets[1] # label
|
||||
label_len = targets[2]
|
||||
max_step = paddle.max(label_len).cpu().numpy()[0] + 2
|
||||
tgt = label[:, :max_step]
|
||||
|
||||
logits_list = predicts['logits_list']
|
||||
pad_id = predicts['pad_id']
|
||||
eos_id = predicts['eos_id']
|
||||
|
||||
tgt_out = tgt[:, 1:]
|
||||
loss = 0
|
||||
loss_numel = 0
|
||||
n = (tgt_out != pad_id).sum().item()
|
||||
|
||||
for i, logits in enumerate(logits_list):
|
||||
loss += n * paddle.nn.functional.cross_entropy(input=logits, label=tgt_out.flatten(), ignore_index=pad_id)
|
||||
loss_numel += n
|
||||
if i == 1:
|
||||
tgt_out = paddle.where(condition=tgt_out == eos_id, x=pad_id, y=tgt_out)
|
||||
n = (tgt_out != pad_id).sum().item()
|
||||
loss /= loss_numel
|
||||
|
||||
return {'loss': loss}
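# Input sketch: `predicts` comes from ParseQHead.forward_train and carries 'logits_list'
# (one flattened logits tensor per sampled permutation) together with 'pad_id' and 'eos_id';
# `targets` is the dataloader output, where targets[1] is the padded label and targets[2]
# its length. Each permutation's cross entropy is weighted by the number of non-padding
# tokens, EOS positions are treated as padding after the first two permutations, and the
# accumulated loss is normalized by the total token count.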
|
|
@ -50,11 +50,12 @@ def build_backbone(config, model_type):
|
|||
from .rec_shallow_cnn import ShallowCNN
|
||||
from .rec_lcnetv3 import PPLCNetV3
|
||||
from .rec_hgnet import PPHGNet_small
|
||||
from .rec_vit_parseq import ViTParseQ
|
||||
support_dict = [
|
||||
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
|
||||
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
|
||||
'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL',
|
||||
'DenseNet', 'ShallowCNN', 'PPLCNetV3', 'PPHGNet_small'
|
||||
'DenseNet', 'ShallowCNN', 'PPLCNetV3', 'PPHGNet_small', 'ViTParseQ'
|
||||
]
|
||||
elif model_type == 'e2e':
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
|
|
|
@ -0,0 +1,304 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/PaddlePaddle/PaddleClas/blob/release%2F2.5/ppcls/arch/backbone/model_zoo/vision_transformer.py
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
from paddle.nn.initializer import TruncatedNormal, Constant, Normal
|
||||
|
||||
|
||||
trunc_normal_ = TruncatedNormal(std=.02)
|
||||
normal_ = Normal
|
||||
zeros_ = Constant(value=0.)
|
||||
ones_ = Constant(value=1.)
|
||||
|
||||
|
||||
def to_2tuple(x):
|
||||
return tuple([x] * 2)
|
||||
|
||||
|
||||
def drop_path(x, drop_prob=0., training=False):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
|
||||
"""
|
||||
if drop_prob == 0. or not training:
|
||||
return x
|
||||
keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
|
||||
shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
|
||||
random_tensor = keep_prob + paddle.rand(shape).astype(x.dtype)
|
||||
random_tensor = paddle.floor(random_tensor) # binarize
|
||||
output = x.divide(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
class DropPath(nn.Layer):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
|
||||
class Identity(nn.Layer):
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, input):
|
||||
return input
|
||||
|
||||
|
||||
class Mlp(nn.Layer):
|
||||
def __init__(self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Layer):
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x):
|
||||
# B= paddle.shape(x)[0]
|
||||
N, C = x.shape[1:]
|
||||
qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //
|
||||
self.num_heads)).transpose((2, 0, 3, 1, 4))
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale
|
||||
attn = nn.functional.softmax(attn, axis=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((-1, N, C))
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
class Block(nn.Layer):
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer='nn.LayerNorm',
|
||||
epsilon=1e-5):
|
||||
super().__init__()
|
||||
if isinstance(norm_layer, str):
|
||||
self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
|
||||
elif isinstance(norm_layer, Callable):
|
||||
self.norm1 = norm_layer(dim)
|
||||
else:
|
||||
raise TypeError(
|
||||
"The norm_layer must be str or paddle.nn.layer.Layer class")
|
||||
self.attn = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop)
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
|
||||
if isinstance(norm_layer, str):
|
||||
self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
|
||||
elif isinstance(norm_layer, Callable):
|
||||
self.norm2 = norm_layer(dim)
|
||||
else:
|
||||
raise TypeError(
|
||||
"The norm_layer must be str or paddle.nn.layer.Layer class")
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.attn(self.norm1(x)))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Layer):
|
||||
""" Image to Patch Embedding
|
||||
"""
|
||||
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
||||
super().__init__()
|
||||
if isinstance(img_size, int):
|
||||
img_size = to_2tuple(img_size)
|
||||
if isinstance(patch_size, int):
|
||||
patch_size = to_2tuple(patch_size)
|
||||
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.num_patches = num_patches
|
||||
|
||||
self.proj = nn.Conv2D(
|
||||
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
|
||||
x = self.proj(x).flatten(2).transpose((0, 2, 1))
|
||||
return x
|
||||
|
||||
|
||||
class VisionTransformer(nn.Layer):
|
||||
""" Vision Transformer with support for patch input
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
class_num=1000,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_layer='nn.LayerNorm',
|
||||
epsilon=1e-5,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.class_num = class_num
|
||||
|
||||
self.num_features = self.embed_dim = embed_dim
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_channels,
|
||||
embed_dim=embed_dim)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.pos_embed = self.create_parameter(shape=(1, num_patches, embed_dim), default_initializer=zeros_)
|
||||
self.add_parameter("pos_embed", self.pos_embed)
|
||||
self.cls_token = self.create_parameter(
|
||||
shape=(1, 1, embed_dim), default_initializer=zeros_)
|
||||
self.add_parameter("cls_token", self.cls_token)
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
dpr = np.linspace(0, drop_path_rate, depth)
|
||||
|
||||
self.blocks = nn.LayerList([
|
||||
Block(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
epsilon=epsilon) for i in range(depth)
|
||||
])
|
||||
|
||||
self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
|
||||
|
||||
# Classifier head
|
||||
self.head = nn.Linear(embed_dim,
|
||||
class_num) if class_num > 0 else Identity()
|
||||
|
||||
trunc_normal_(self.pos_embed)
|
||||
self.out_channels = embed_dim
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
zeros_(m.bias)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
zeros_(m.bias)
|
||||
ones_(m.weight)
|
||||
|
||||
def forward_features(self, x):
|
||||
B = paddle.shape(x)[0]
|
||||
x = self.patch_embed(x)
|
||||
x = x + self.pos_embed
|
||||
x = self.pos_drop(x)
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
class ViTParseQ(VisionTransformer):
|
||||
def __init__(self, img_size=[224, 224], patch_size=[16, 16], in_channels=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0):
|
||||
super().__init__(img_size, patch_size, in_channels, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_rate=drop_rate,
|
||||
attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate, class_num=0)
|
||||
|
||||
def forward(self, x):
|
||||
return self.forward_features(x)
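# Shape sketch for the rec_vit_parseq.yml configuration (img_size=[32, 128], patch_size=[4, 8],
# embed_dim=384): an input of shape [B, 3, 32, 128] is split into (32/4)*(128/8) = 128 patches,
# so forward() returns token features of shape [B, 128, 384], which ParseQHead consumes as memory.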
|
|
@ -40,6 +40,7 @@ def build_head(config):
|
|||
from .rec_rfl_head import RFLHead
|
||||
from .rec_can_head import CANHead
|
||||
from .rec_satrn_head import SATRNHead
|
||||
from .rec_parseq_head import ParseQHead
|
||||
|
||||
# cls head
|
||||
from .cls_head import ClsHead
|
||||
|
@ -56,7 +57,7 @@ def build_head(config):
|
|||
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
|
||||
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
|
||||
'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead',
|
||||
'DRRGHead', 'CANHead', 'SATRNHead', 'PFHeadLocal'
|
||||
'DRRGHead', 'CANHead', 'SATRNHead', 'PFHeadLocal', 'ParseQHead'
|
||||
]
|
||||
|
||||
if config['name'] == 'DRRGHead':
|
||||
|
|
|
@ -0,0 +1,342 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Code was based on https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
|
||||
# reference: https://arxiv.org/abs/2207.06966
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import paddle
|
||||
from paddle import nn, ParamAttr
|
||||
from paddle.nn import functional as F
|
||||
import numpy as np
|
||||
from .self_attention import WrapEncoderForFeature
|
||||
from .self_attention import WrapEncoder
|
||||
from collections import OrderedDict
|
||||
from typing import Optional
|
||||
import copy
|
||||
from itertools import permutations
|
||||
|
||||
|
||||
class DecoderLayer(paddle.nn.Layer):
|
||||
"""A Transformer decoder layer supporting two-stream attention (XLNet)
|
||||
This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch."""
|
||||
|
||||
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu', layer_norm_eps=1e-05):
|
||||
super().__init__()
|
||||
self.self_attn = paddle.nn.MultiHeadAttention(d_model, nhead, dropout=dropout, need_weights=True) # paddle.nn.MultiHeadAttention is batch_first by default
|
||||
self.cross_attn = paddle.nn.MultiHeadAttention(d_model, nhead, dropout=dropout, need_weights=True)
|
||||
self.linear1 = paddle.nn.Linear(in_features=d_model, out_features=dim_feedforward)
|
||||
self.dropout = paddle.nn.Dropout(p=dropout)
|
||||
self.linear2 = paddle.nn.Linear(in_features=dim_feedforward, out_features=d_model)
|
||||
self.norm1 = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps)
|
||||
self.norm2 = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps)
|
||||
self.norm_q = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps)
|
||||
self.norm_c = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps)
|
||||
self.dropout1 = paddle.nn.Dropout(p=dropout)
|
||||
self.dropout2 = paddle.nn.Dropout(p=dropout)
|
||||
self.dropout3 = paddle.nn.Dropout(p=dropout)
|
||||
if activation == 'gelu':
|
||||
self.activation = paddle.nn.GELU()
|
||||
|
||||
def __setstate__(self, state):
|
||||
if 'activation' not in state:
|
||||
state['activation'] = paddle.nn.functional.gelu
|
||||
super().__setstate__(state)
|
||||
|
||||
def forward_stream(self, tgt, tgt_norm, tgt_kv, memory, tgt_mask, tgt_key_padding_mask):
|
||||
"""Forward pass for a single stream (i.e. content or query)
|
||||
tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency.
|
||||
Both tgt_kv and memory are expected to be LayerNorm'd too.
|
||||
memory is LayerNorm'd by ViT.
|
||||
"""
|
||||
if tgt_key_padding_mask is not None:
|
||||
tgt_mask1 = (tgt_mask!=float('-inf'))[None,None,:,:] & (tgt_key_padding_mask[:,None,None,:]==False)
|
||||
tgt2, sa_weights = self.self_attn(tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask1)
|
||||
else:
|
||||
tgt2, sa_weights = self.self_attn(tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask)
|
||||
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(tgt)))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
return tgt, sa_weights, ca_weights
|
||||
|
||||
def forward(self, query, content, memory, query_mask=None, content_mask=None, content_key_padding_mask=None, update_content=True):
|
||||
query_norm = self.norm_q(query)
|
||||
content_norm = self.norm_c(content)
|
||||
query = self.forward_stream(query, query_norm, content_norm, memory, query_mask, content_key_padding_mask)[0]
|
||||
if update_content:
|
||||
content = self.forward_stream(content, content_norm, content_norm, memory, content_mask, content_key_padding_mask)[0]
|
||||
return query, content
|
||||
|
||||
|
||||
def get_clones(module, N):
|
||||
return paddle.nn.LayerList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
|
||||
class Decoder(paddle.nn.Layer):
|
||||
__constants__ = ['norm']
|
||||
|
||||
def __init__(self, decoder_layer, num_layers, norm):
|
||||
super().__init__()
|
||||
self.layers = get_clones(decoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
|
||||
def forward(self, query, content, memory, query_mask: Optional[paddle.Tensor]=None, content_mask: Optional[paddle.Tensor]=None, content_key_padding_mask: Optional[paddle.Tensor]=None):
|
||||
for i, mod in enumerate(self.layers):
|
||||
last = i == len(self.layers) - 1
|
||||
query, content = mod(query, content, memory, query_mask, content_mask, content_key_padding_mask, update_content=not last)
|
||||
query = self.norm(query)
|
||||
return query
|
||||
|
||||
|
||||
class TokenEmbedding(paddle.nn.Layer):
|
||||
|
||||
def __init__(self, charset_size: int, embed_dim: int):
|
||||
super().__init__()
|
||||
self.embedding = paddle.nn.Embedding(num_embeddings=charset_size, embedding_dim=embed_dim)
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def forward(self, tokens: paddle.Tensor):
|
||||
return math.sqrt(self.embed_dim) * self.embedding(tokens.astype(paddle.int64))
|
||||
|
||||
|
||||
def trunc_normal_init(param, **kwargs):
|
||||
initializer = nn.initializer.TruncatedNormal(**kwargs)
|
||||
initializer(param, param.block)
|
||||
|
||||
|
||||
def constant_init(param, **kwargs):
|
||||
initializer = nn.initializer.Constant(**kwargs)
|
||||
initializer(param, param.block)
|
||||
|
||||
|
||||
def kaiming_normal_init(param, **kwargs):
|
||||
initializer = nn.initializer.KaimingNormal(**kwargs)
|
||||
initializer(param, param.block)
|
||||
|
||||
|
||||
class ParseQHead(nn.Layer):
|
||||
def __init__(self, out_channels, max_text_length, embed_dim, dec_num_heads, dec_mlp_ratio, dec_depth, perm_num, perm_forward, perm_mirrored, decode_ar, refine_iters, dropout, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.bos_id = out_channels - 2
|
||||
self.eos_id = 0
|
||||
self.pad_id = out_channels - 1
|
||||
|
||||
self.max_label_length = max_text_length
|
||||
self.decode_ar = decode_ar
|
||||
self.refine_iters = refine_iters
|
||||
decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout)
|
||||
self.decoder = Decoder(decoder_layer, num_layers=dec_depth, norm=paddle.nn.LayerNorm(normalized_shape=embed_dim))
|
||||
self.rng = np.random.default_rng()
|
||||
self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
|
||||
self.perm_forward = perm_forward
|
||||
self.perm_mirrored = perm_mirrored
|
||||
self.head = paddle.nn.Linear(in_features=embed_dim, out_features=out_channels - 2)
|
||||
self.text_embed = TokenEmbedding(out_channels, embed_dim)
|
||||
self.pos_queries = paddle.create_parameter(shape=paddle.empty(shape=[1, max_text_length + 1, embed_dim]).shape, dtype=paddle.empty(shape=[1, max_text_length + 1, embed_dim]).numpy().dtype, default_initializer=paddle.nn.initializer.Assign(paddle.empty(shape=[1, max_text_length + 1, embed_dim])))
|
||||
self.pos_queries.stop_gradient = False
|
||||
self.dropout = paddle.nn.Dropout(p=dropout)
|
||||
self._device = self.parameters()[0].place
|
||||
trunc_normal_init(self.pos_queries, std=0.02)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, paddle.nn.Linear):
|
||||
trunc_normal_init(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
constant_init(m.bias, value=0.0)
|
||||
elif isinstance(m, paddle.nn.Embedding):
|
||||
trunc_normal_init(m.weight, std=0.02)
|
||||
if m._padding_idx is not None:
|
||||
m.weight.data[m._padding_idx].zero_()
|
||||
elif isinstance(m, paddle.nn.Conv2D):
|
||||
kaiming_normal_init(m.weight, fan_in=None, nonlinearity='relu')
|
||||
if m.bias is not None:
|
||||
constant_init(m.bias, value=0.0)
|
||||
elif isinstance(m, (paddle.nn.LayerNorm, paddle.nn.BatchNorm2D, paddle.nn.GroupNorm)):
|
||||
constant_init(m.weight, value=1.0)
|
||||
constant_init(m.bias, value=0.0)
|
||||
|
||||
def no_weight_decay(self):
|
||||
param_names = {'text_embed.embedding.weight', 'pos_queries'}
|
||||
enc_param_names = {('encoder.' + n) for n in self.encoder.
|
||||
no_weight_decay()}
|
||||
return param_names.union(enc_param_names)
|
||||
|
||||
def encode(self, img):
|
||||
return self.encoder(img)
|
||||
|
||||
def decode(self, tgt, memory, tgt_mask=None, tgt_padding_mask=None, tgt_query=None, tgt_query_mask=None):
|
||||
N, L = tgt.shape
|
||||
null_ctx = self.text_embed(tgt[:, :1])
|
||||
if L != 1:
|
||||
tgt_emb = self.pos_queries[:, :L - 1] + self.text_embed(tgt[:, 1:])
|
||||
tgt_emb = self.dropout(paddle.concat(x=[null_ctx, tgt_emb], axis=1))
|
||||
else:
|
||||
tgt_emb = self.dropout(null_ctx)
|
||||
if tgt_query is None:
|
||||
tgt_query = self.pos_queries[:, :L].expand(shape=[N, -1, -1])
|
||||
tgt_query = self.dropout(tgt_query)
|
||||
return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask)
|
||||
|
||||
def forward_test(self, memory, max_length=None):
|
||||
testing = max_length is None
|
||||
max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
|
||||
bs = memory.shape[0]
|
||||
num_steps = max_length + 1
|
||||
|
||||
pos_queries = self.pos_queries[:, :num_steps].expand(shape=[bs, -1, -1])
|
||||
tgt_mask = query_mask = paddle.triu(x=paddle.full(shape=(num_steps, num_steps), fill_value=float('-inf')), diagonal=1)
|
||||
if self.decode_ar:
|
||||
tgt_in = paddle.full(shape=(bs, num_steps), fill_value=self.pad_id).astype('int64')
|
||||
tgt_in[:, (0)] = self.bos_id
|
||||
|
||||
logits = []
|
||||
for i in range(paddle.to_tensor(num_steps)):
|
||||
j = i + 1
|
||||
tgt_out = self.decode(tgt_in[:, :j], memory, tgt_mask[:j, :j], tgt_query=pos_queries[:, i:j], tgt_query_mask=query_mask[i:j, :j])
|
||||
p_i = self.head(tgt_out)
|
||||
logits.append(p_i)
|
||||
if j < num_steps:
|
||||
tgt_in[:, (j)] = p_i.squeeze().argmax(axis=-1)
|
||||
if testing and (tgt_in == self.eos_id).astype('bool').any(axis=-1).astype('bool').all():
|
||||
break
|
||||
logits = paddle.concat(x=logits, axis=1)
|
||||
else:
|
||||
tgt_in = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype('int64')
|
||||
tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries)
|
||||
logits = self.head(tgt_out)
|
||||
if self.refine_iters:
|
||||
temp = paddle.triu(x=paddle.ones(shape=[num_steps,num_steps], dtype='bool'), diagonal=2)
|
||||
posi = np.where(temp.cpu().numpy()==True)
|
||||
query_mask[posi] = 0
|
||||
bos = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype('int64')
|
||||
for i in range(self.refine_iters):
|
||||
tgt_in = paddle.concat(x=[bos, logits[:, :-1].argmax(axis=-1)], axis=1)
|
||||
tgt_padding_mask = (tgt_in == self.eos_id).astype(dtype='int32')
|
||||
tgt_padding_mask = tgt_padding_mask.cpu()
|
||||
tgt_padding_mask = tgt_padding_mask.cumsum(axis=-1) > 0
|
||||
tgt_padding_mask = tgt_padding_mask.cuda().astype(dtype='float32')==1.0
|
||||
tgt_out = self.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query=pos_queries, tgt_query_mask=query_mask[:, :tgt_in.shape[1]])
|
||||
logits = self.head(tgt_out)
|
||||
|
||||
final_output = {"predict":logits}
|
||||
|
||||
return final_output
|
||||
|
||||
    def gen_tgt_perms(self, tgt):
        """Generate shared permutations for the whole batch.

        This works because the same attention mask can be used for the shorter sequences
        because of the padding mask.
        """
        max_num_chars = tgt.shape[1] - 2
        if max_num_chars == 1:
            return paddle.arange(end=3).unsqueeze(axis=0)
        perms = [paddle.arange(end=max_num_chars)] if self.perm_forward else []
        max_perms = math.factorial(max_num_chars)
        if self.perm_mirrored:
            max_perms //= 2
        num_gen_perms = min(self.max_gen_perms, max_perms)
        if max_num_chars < 5:
            if max_num_chars == 4 and self.perm_mirrored:
                selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
            else:
                selector = list(range(max_perms))
            perm_pool = paddle.to_tensor(data=list(permutations(range(max_num_chars), max_num_chars)), place=self._device)[selector]
            if self.perm_forward:
                perm_pool = perm_pool[1:]
            perms = paddle.stack(x=perms)
            if len(perm_pool):
                i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(perms), replace=False)
                perms = paddle.concat(x=[perms, perm_pool[i]])
        else:
            perms.extend([paddle.randperm(n=max_num_chars) for _ in range(num_gen_perms - len(perms))])
            perms = paddle.stack(x=perms)
        if self.perm_mirrored:
            comp = perms.flip(axis=-1)
            x = paddle.stack(x=[perms, comp])
            perm_2 = list(range(x.ndim))
            perm_2[0] = 1
            perm_2[1] = 0
            perms = x.transpose(perm=perm_2).reshape((-1, max_num_chars))
        bos_idx = paddle.zeros(shape=(len(perms), 1), dtype=perms.dtype)
        eos_idx = paddle.full(shape=(len(perms), 1), fill_value=max_num_chars + 1, dtype=perms.dtype)
        perms = paddle.concat(x=[bos_idx, perms + 1, eos_idx], axis=1)
        if len(perms) > 1:
            perms[1, 1:] = max_num_chars + 1 - paddle.arange(end=max_num_chars + 1)
        return perms
    def generate_attn_masks(self, perm):
        """Generate attention masks given a sequence permutation (includes pos. for bos and eos tokens)

        :param perm: the permutation sequence. i = 0 is always the BOS
        :return: lookahead attention masks
        """
        sz = perm.shape[0]
        mask = paddle.zeros(shape=(sz, sz))
        for i in range(sz):
            query_idx = perm[i].cpu().numpy().tolist()
            masked_keys = perm[i + 1:].cpu().numpy().tolist()
            if len(masked_keys) == 0:
                break
            mask[query_idx, masked_keys] = float('-inf')
        content_mask = mask[:-1, :-1].clone()
        mask[paddle.eye(num_rows=sz).astype('bool')] = float('-inf')
        query_mask = mask[1:, :-1]
        return content_mask, query_mask
    def forward_train(self, memory, tgt):
        tgt_perms = self.gen_tgt_perms(tgt)
        tgt_in = tgt[:, :-1]
        tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id)
        logits_list = []
        final_out = {}
        for i, perm in enumerate(tgt_perms):
            tgt_mask, query_mask = self.generate_attn_masks(perm)
            out = self.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query_mask=query_mask)
            logits = self.head(out)
            if i == 0:
                final_out['predict'] = logits
            logits = logits.flatten(stop_axis=1)
            logits_list.append(logits)

        final_out['logits_list'] = logits_list
        final_out['pad_id'] = self.pad_id
        final_out['eos_id'] = self.eos_id

        return final_out
    def forward(self, feat, targets=None):
        # feat: B, N, C
        # targets: labels, labels_len

        if self.training:
            label = targets[0]  # label
            label_len = targets[1]
            max_step = paddle.max(label_len).cpu().numpy()[0] + 2
            crop_label = label[:, :max_step]
            final_out = self.forward_train(feat, crop_label)
        else:
            final_out = self.forward_test(feat)

        return final_out
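For intuition, a minimal self-contained sketch of the two ideas the head above combines: a pool of target-order permutations for permutation language modeling at training time, and a greedy token-by-token loop that stops early at EOS at inference time. The token ids, vocabulary size and the fake_step scorer below are illustrative stand-ins, not the PaddleOCR API.

import numpy as np
from itertools import permutations

# Training side: build a shared permutation pool for a 3-character target.
max_num_chars = 3
pool = list(permutations(range(max_num_chars)))   # all 3! = 6 orderings
forward = pool[0]                                  # canonical left-to-right order
mirrored = [p[::-1] for p in pool]                 # perm_mirrored pairs each order with its reverse
print(forward, mirrored[0])                        # (0, 1, 2) (2, 1, 0)

# Inference side: greedy AR decoding that exits as soon as EOS is emitted.
EOS, BOS = 0, 95                                   # illustrative ids only
rng = np.random.default_rng(0)

def fake_step(prefix):
    # Stand-in for decode() + head(): a random score over 96 classes.
    return rng.random(96)

tokens = [BOS]
for _ in range(25 + 1):                            # max_text_length + 1 steps
    next_id = int(fake_step(tokens).argmax())
    tokens.append(next_id)
    if next_id == EOS:                             # early exit, mirroring forward_test
        break
print(tokens)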
@@ -28,7 +28,7 @@ from .fce_postprocess import FCEPostProcess
 from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
     DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
     SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
-    SPINLabelDecode, VLLabelDecode, RFLLabelDecode, SATRNLabelDecode
+    SPINLabelDecode, VLLabelDecode, RFLLabelDecode, SATRNLabelDecode, ParseQLabelDecode
 from .cls_postprocess import ClsPostProcess
 from .pg_postprocess import PGPostProcess
 from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess
@@ -53,7 +53,7 @@ def build_post_process(config, global_config=None):
         'DistillationSerPostProcess', 'DistillationRePostProcess',
         'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess',
         'RFLLabelDecode', 'DRRGPostprocess', 'CANLabelDecode',
-        'SATRNLabelDecode'
+        'SATRNLabelDecode', 'ParseQLabelDecode'
     ]

     if config['name'] == 'PSEPostProcess':
@@ -559,6 +559,95 @@ class SRNLabelDecode(BaseRecLabelDecode):
                % beg_or_end
        return idx


class ParseQLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """
    BOS = '[B]'
    EOS = '[E]'
    PAD = '[P]'

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(ParseQLabelDecode, self).__init__(character_dict_path,
                                                use_space_char)
        self.max_text_length = kwargs.get('max_text_length', 25)

    def __call__(self, preds, label=None, *args, **kwargs):
        if isinstance(preds, dict):
            pred = preds['predict']
        else:
            pred = preds

        char_num = len(self.character_str) + 1  # we do not predict <bos> or <pad>, only the extra <eos>
        if isinstance(pred, paddle.Tensor):
            pred = pred.numpy()
        B, L = pred.shape[:2]
        pred = np.reshape(pred, [-1, char_num])

        preds_idx = np.argmax(pred, axis=1)
        preds_prob = np.max(pred, axis=1)

        preds_idx = np.reshape(preds_idx, [B, L])
        preds_prob = np.reshape(preds_prob, [B, L])

        if label is None:
            text = self.decode(preds_idx, preds_prob, raw=False)
            return text

        text = self.decode(preds_idx, preds_prob, raw=False)
        label = self.decode(label, None, False)

        return text, label

    def decode(self, text_index, text_prob=None, raw=False):
        """ convert text-index into text-label. """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)

        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []

            index = text_index[batch_idx, :]
            prob = None
            if text_prob is not None:
                prob = text_prob[batch_idx, :]

            if not raw:
                index, prob = self._filter(index, prob)

            for idx in range(len(index)):
                if index[idx] in ignored_tokens:
                    continue
                char_list.append(self.character[int(index[idx])])
                if text_prob is not None:
                    conf_list.append(prob[idx])
                else:
                    conf_list.append(1)

            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))

        return result_list

    def add_special_char(self, dict_character):
        dict_character = [self.EOS] + dict_character + [self.BOS, self.PAD]
        return dict_character

    def _filter(self, ids, probs=None):
        ids = ids.tolist()
        try:
            eos_idx = ids.index(self.dict[self.EOS])
        except ValueError:
            eos_idx = len(ids)  # nothing to truncate
        # Truncate after EOS
        ids = ids[:eos_idx]
        if probs is not None:
            probs = probs[:eos_idx + 1]  # but include the prob. for EOS (if it exists)
        return ids, probs

    def get_ignored_tokens(self):
        return [self.dict[self.BOS], self.dict[self.EOS], self.dict[self.PAD]]


class SARLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """
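A hedged usage sketch of the decoder added above, mirroring how tools/infer/predict_rec.py wires it up later in this diff. It assumes the PaddleOCR repo is importable, that the dictionary lives at the path added in this PR, and that the logits are synthetic with shape [B, L, num_classes], where num_classes is 94 characters plus EOS.

import numpy as np
from ppocr.postprocess import build_post_process

# Same keys as the ParseQ branch added to tools/infer/predict_rec.py below.
postprocess_op = build_post_process({
    'name': 'ParseQLabelDecode',
    'character_dict_path': 'ppocr/utils/dict/parseq_dict.txt',  # assumed location of the new dict
    'use_space_char': False,
})

fake_logits = np.random.rand(1, 26, 95).astype('float32')  # [B, max_text_length + 1, 95]
texts = postprocess_op({'predict': fake_logits})
print(texts)  # a list of (decoded_text, mean_confidence) tuples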
@@ -0,0 +1,94 @@
0
1
2
3
4
5
6
7
8
9
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
:
;
<
=
>
?
@
[
\
]
^
_
`
{
|
}
~
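A short sketch of how the 94 characters above interact with ParseQLabelDecode.add_special_char: [E] is inserted at index 0 and [B]/[P] are appended, so the decode dictionary holds 97 entries while the network only scores 95 classes (the characters plus [E]). The list literal below is a stand-in for reading the dictionary file.

# Stand-in for reading the 94-line dictionary added above.
chars = list("0123456789") \
    + [chr(c) for c in range(ord('a'), ord('z') + 1)] \
    + [chr(c) for c in range(ord('A'), ord('Z') + 1)] \
    + list("!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~")
assert len(chars) == 94

# Mirrors add_special_char: EOS first, BOS and PAD last.
character = ['[E]'] + chars + ['[B]', '[P]']
print(len(character))          # 97 entries in the decode dictionary
print(character.index('[E]'))  # 0, which is why _filter truncates at the index of '[E]'
print(len(chars) + 1)          # 95 classes actually predicted (characters + EOS)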
@@ -29,18 +29,28 @@ fi
sed -i 's/use_gpu/use_xpu/g' $FILENAME
# disable benchmark, as AutoLog requires the nvidia-smi command
sed -i 's/--benchmark:True/--benchmark:False/g' $FILENAME
# python has been updated to version 3.9 for the xpu backend
sed -i "s/python3.7/python3.9/g" $FILENAME
dataline=`cat $FILENAME`

# parser params
IFS=$'\n'
lines=(${dataline})

modelname=$(echo ${lines[1]} | cut -d ":" -f2)
if [ $modelname == "rec_r31_sar" ] || [ $modelname == "rec_mtb_nrtr" ]; then
    sed -i "s/Global.epoch_num:lite_train_lite_infer=2/Global.epoch_num:lite_train_lite_infer=1/g" $FILENAME
    sed -i "s/gpu_list:0|0,1/gpu_list:0,1/g" $FILENAME
    sed -i "s/Global.use_xpu:True|True/Global.use_xpu:True/g" $FILENAME
fi

# replace training config file
grep -n 'tools/.*yml' $FILENAME | cut -d ":" -f 1 \
| while read line_num ; do
    train_cmd=$(func_parser_value "${lines[line_num-1]}")
    trainer_config=$(func_parser_config ${train_cmd})
    sed -i 's/use_gpu/use_xpu/g' "$REPO_ROOT_PATH/$trainer_config"
    sed -i 's/use_sync_bn: True/use_sync_bn: False/g' "$REPO_ROOT_PATH/$trainer_config"
done

# change gpu to xpu in execution script
@@ -122,6 +122,12 @@ class TextRecognizer(object):
                 "character_dict_path": args.rec_char_dict_path,
                 "use_space_char": args.use_space_char
             }
+        elif self.rec_algorithm == "ParseQ":
+            postprocess_params = {
+                'name': 'ParseQLabelDecode',
+                "character_dict_path": args.rec_char_dict_path,
+                "use_space_char": args.use_space_char
+            }
         self.postprocess_op = build_post_process(postprocess_params)
         self.postprocess_params = postprocess_params
         self.predictor, self.input_tensor, self.output_tensors, self.config = \
@@ -439,7 +445,7 @@ class TextRecognizer(object):
                 gsrm_slf_attn_bias1_list.append(norm_img[3])
                 gsrm_slf_attn_bias2_list.append(norm_img[4])
                 norm_img_batch.append(norm_img[0])
-            elif self.rec_algorithm in ["SVTR", "SATRN"]:
+            elif self.rec_algorithm in ["SVTR", "SATRN", "ParseQ"]:
                 norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
                                                      self.rec_image_shape)
                 norm_img = norm_img[np.newaxis, :]
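Since the hunk above routes ParseQ through resize_norm_img_svtr, here is a rough standalone sketch of that style of preprocessing. The target shape (3, 32, 128) is an assumption for ParseQ, and the helper name below is illustrative rather than the repo's.

import cv2  # OpenCV, already a PaddleOCR dependency
import numpy as np

def svtr_style_resize_norm(img, image_shape=(3, 32, 128)):
    # Resize to a fixed H x W, scale to [0, 1], then normalize to [-1, 1] (CHW layout).
    _, h, w = image_shape
    resized = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)
    resized = resized.astype('float32').transpose((2, 0, 1)) / 255.0
    return (resized - 0.5) / 0.5

dummy_crop = (np.random.rand(48, 200, 3) * 255).astype('uint8')
batch = svtr_style_resize_norm(dummy_crop)[np.newaxis, :]  # add batch dim, as in the hunk above
print(batch.shape)  # (1, 3, 32, 128)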
@@ -231,7 +231,7 @@ def train(config,
     use_srn = config['Architecture']['algorithm'] == "SRN"
     extra_input_models = [
         "SRN", "NRTR", "SAR", "SEED", "SVTR", "SVTR_LCNet", "SPIN", "VisionLAN",
-        "RobustScanner", "RFL", 'DRRG', 'SATRN', 'SVTR_HGNet'
+        "RobustScanner", "RFL", 'DRRG', 'SATRN', 'SVTR_HGNet', "ParseQ",
     ]
     extra_input = False
     if config['Architecture']['algorithm'] == 'Distillation':
@@ -664,7 +664,7 @@ def preprocess(is_train=False):
         'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
         'SVTR', 'SVTR_LCNet', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN',
         'VisionLAN', 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG',
-        'CAN', 'Telescope', 'SATRN', 'SVTR_HGNet'
+        'CAN', 'Telescope', 'SATRN', 'SVTR_HGNet', 'ParseQ',
     ]

     if use_xpu: