From bbca1e0d66298fd21abcb953a4c73379b8e04996 Mon Sep 17 00:00:00 2001 From: smilelite Date: Sun, 12 Jun 2022 13:53:29 +0800 Subject: [PATCH 01/11] add pr --- .gitignore | 2 +- configs/rec/rec_r32_gaspin_bilstm_att.yml | 117 +++++++ doc/doc_ch/algorithm_overview.md | 2 + doc/doc_ch/algorithm_rec_spin.md | 112 +++++++ doc/doc_en/algorithm_overview_en.md | 2 + doc/doc_en/algorithm_rec_spin_en.md | 112 +++++++ log/workerlog.0 | 131 ++++++++ ppocr/data/imaug/__init__.py | 3 +- ppocr/data/imaug/label_ops.py | 49 +++ ppocr/data/imaug/rec_img_aug.py | 45 +++ ppocr/losses/__init__.py | 4 +- ppocr/losses/rec_spin_att_loss.py | 41 +++ ppocr/modeling/backbones/__init__.py | 3 +- ppocr/modeling/backbones/rec_resnet_32.py | 289 ++++++++++++++++++ ppocr/modeling/heads/__init__.py | 3 +- ppocr/modeling/heads/rec_spin_att_head.py | 203 ++++++++++++ ppocr/modeling/necks/rnn.py | 67 +++- ppocr/modeling/transforms/__init__.py | 4 +- .../modeling/transforms/gaspin_transformer.py | 286 +++++++++++++++++ ppocr/postprocess/__init__.py | 4 +- ppocr/postprocess/rec_postprocess.py | 79 +++++ ppocr/utils/dict/spin_dict.txt | 68 +++++ .../rec_r32_gaspin_bilstm_att.yml | 118 +++++++ .../train_infer_python.txt | 53 ++++ tools/export_model.py | 6 + tools/infer/predict_rec.py | 26 ++ tools/program.py | 5 +- 27 files changed, 1823 insertions(+), 11 deletions(-) create mode 100644 configs/rec/rec_r32_gaspin_bilstm_att.yml create mode 100644 doc/doc_ch/algorithm_rec_spin.md create mode 100644 doc/doc_en/algorithm_rec_spin_en.md create mode 100644 log/workerlog.0 create mode 100644 ppocr/losses/rec_spin_att_loss.py create mode 100644 ppocr/modeling/backbones/rec_resnet_32.py create mode 100644 ppocr/modeling/heads/rec_spin_att_head.py create mode 100644 ppocr/modeling/transforms/gaspin_transformer.py create mode 100644 ppocr/utils/dict/spin_dict.txt create mode 100644 test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml create mode 100644 test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt diff --git a/.gitignore b/.gitignore index caf886a2b..34f0e0cc9 100644 --- a/.gitignore +++ b/.gitignore @@ -10,7 +10,7 @@ __pycache__/ inference/ inference_results/ output/ - +train_data/ *.DS_Store *.vs *.user diff --git a/configs/rec/rec_r32_gaspin_bilstm_att.yml b/configs/rec/rec_r32_gaspin_bilstm_att.yml new file mode 100644 index 000000000..236a17c43 --- /dev/null +++ b/configs/rec/rec_r32_gaspin_bilstm_att.yml @@ -0,0 +1,117 @@ +Global: + use_gpu: True + epoch_num: 6 + log_smooth_window: 50 + print_batch_step: 50 + save_model_dir: ./output/rec/rec_r32_gaspin_bilstm_att/ + save_epoch_step: 3 + # evaluation is run every 2000 iterations after the 4000th iteration + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/ch/word_1.jpg + # for data or label process + character_dict_path: ./ppocr/utils/dict/spin_dict.txt + max_text_length: 25 + infer_mode: False + use_space_char: False + save_res_path: ./output/rec/predicts_r32_gaspin_bilstm_att.txt + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + lr: + name: Piecewise + decay_epochs: [3, 4, 5] + values: [0.001, 0.0003, 0.00009, 0.000027] + clip_norm: 5 + +Architecture: + model_type: rec + algorithm: SPIN + in_channels: 1 + Transform: + name: GA_SPIN + offsets: True + default_type: 6 + loc_lr: 0.1 + stn: True + Backbone: + name: ResNet32 + out_channels: 512 + Neck: + name: SequenceEncoder + encoder_type: cascadernn + hidden_size: 256 + out_channels: [256, 512] + with_linear: True + Head: + name: SPINAttentionHead + hidden_size: 256 + + +Loss: + name: SPINAttentionLoss + ignore_index: 0 + +PostProcess: + name: SPINAttnLabelDecode + character_dict_path: ./ppocr/utils/dict/spin_dict.txt + use_space_char: False + + +Metric: + name: RecMetric + main_indicator: acc + is_filter: True + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ic15_data/ + label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"] + transforms: + - NRTRDecodeImage: # load image + img_mode: BGR + channel_first: False + - SPINAttnLabelEncode: # Class handling label + - SPINRecResizeImg: + image_shape: [100, 32] + interpolation : 2 + mean: [127.5] + std: [127.5] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 8 + drop_last: True + num_workers: 4 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ic15_data + label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"] + transforms: + - NRTRDecodeImage: # load image + img_mode: BGR + channel_first: False + - SPINAttnLabelEncode: # Class handling label + - SPINRecResizeImg: + image_shape: [100, 32] + interpolation : 2 + mean: [127.5] + std: [127.5] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 8 + num_workers: 2 diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index 6227a2149..ef95317ac 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -66,6 +66,7 @@ - [x] [SAR](./algorithm_rec_sar.md) - [x] [SEED](./algorithm_rec_seed.md) - [x] [SVTR](./algorithm_rec_svtr.md) +- [x] [SPIN](./algorithm_rec_spin.md) 参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: @@ -84,6 +85,7 @@ |SAR|Resnet31| 87.20% | rec_r31_sar | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) | |SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) | |SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) | +|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon | diff --git a/doc/doc_ch/algorithm_rec_spin.md b/doc/doc_ch/algorithm_rec_spin.md new file mode 100644 index 000000000..c996992d2 --- /dev/null +++ b/doc/doc_ch/algorithm_rec_spin.md @@ -0,0 +1,112 @@ +# SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition + +- [1. 算法简介](#1) +- [2. 环境配置](#2) +- [3. 模型训练、评估、预测](#3) + - [3.1 训练](#3-1) + - [3.2 评估](#3-2) + - [3.3 预测](#3-3) +- [4. 推理部署](#4) + - [4.1 Python推理](#4-1) + - [4.2 C++推理](#4-2) + - [4.3 Serving服务化部署](#4-3) + - [4.4 更多推理部署](#4-4) +- [5. FAQ](#5) + + +## 1. 算法简介 + +论文信息: +> [SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition](https://arxiv.org/abs/2005.13117) +> Chengwei Zhang, Yunlu Xu, Zhanzhan Cheng, Shiliang Pu, Yi Niu, Fei Wu, Futai Zou +> AAAI, 2020 + +SPIN收录于AAAI2020。主要用于OCR识别任务。在任意形状文本识别中,矫正网络是一种较为常见的前置处理模块,但诸如RARE\ASTER\ESIR等只考虑了空间变换,并没有考虑色度变换。本文提出了一种结构Structure-Preserving Inner Offset Network (SPIN),可以在色彩空间上进行变换。该模块是可微分的,可以加入到任意识别器中。 +使用MJSynth和SynthText两个合成文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下: + +|模型|骨干网络|配置文件|Acc|下载链接| +| --- | --- | --- | --- | --- | +|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon| + + + +## 2. 环境配置 +请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。 + + + +## 3. 模型训练、评估、预测 + +请参考[文本识别教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练不同的识别模型只需要**更换配置文件**即可。 + +训练 + +具体地,在完成数据准备后,便可以启动训练,训练命令如下: + +``` +#单卡训练(训练周期长,不建议) +python3 tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml + +#多卡训练,通过--gpus参数指定卡号 +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml +``` + +评估 + +``` +# GPU 评估, Global.pretrained_model 为待测权重 +python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy +``` + +预测: + +``` +# 预测使用的配置文件必须与训练一致 +python3 tools/infer_rec.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png +``` + + +## 4. 推理部署 + + +### 4.1 Python推理 +首先将SPIN文本识别训练过程中保存的模型,转换成inference model。可以使用如下命令进行转换: + +``` +python3 tools/export_model.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r32_gaspin_bilstm_att +``` +SPIN文本识别模型推理,可以执行如下命令: + +``` +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r32_gaspin_bilstm_att/" --rec_image_shape="3, 32, 100" --rec_algorithm="SPIN" --rec_char_dict_path="/ppocr/utils/dict/spin_dict.txt" --use_space_char=Falsee +``` + + +### 4.2 C++推理 + +由于C++预处理后处理还未支持SPIN,所以暂未支持 + + +### 4.3 Serving服务化部署 + +暂不支持 + + +### 4.4 更多推理部署 + +暂不支持 + + +## 5. FAQ + + +## 引用 + +```bibtex +@article{2020SPIN, + title={SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition}, + author={Chengwei Zhang and Yunlu Xu and Zhanzhan Cheng and Shiliang Pu and Yi Niu and Fei Wu and Futai Zou}, + journal={AAAI2020}, + year={2020}, +} +``` diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index 383cbe39b..608584e01 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -65,6 +65,7 @@ Supported text recognition algorithms (Click the link to get the tutorial): - [x] [SAR](./algorithm_rec_sar_en.md) - [x] [SEED](./algorithm_rec_seed_en.md) - [x] [SVTR](./algorithm_rec_svtr_en.md) +- [x] [SPIN](./algorithm_rec_spin_en.md) Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow: @@ -83,6 +84,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r |SAR|Resnet31| 87.20% | rec_r31_sar | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) | |SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) | |SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) | +|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon | diff --git a/doc/doc_en/algorithm_rec_spin_en.md b/doc/doc_en/algorithm_rec_spin_en.md new file mode 100644 index 000000000..43ab30ce7 --- /dev/null +++ b/doc/doc_en/algorithm_rec_spin_en.md @@ -0,0 +1,112 @@ +# SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition + +- [1. Introduction](#1) +- [2. Environment](#2) +- [3. Model Training / Evaluation / Prediction](#3) + - [3.1 Training](#3-1) + - [3.2 Evaluation](#3-2) + - [3.3 Prediction](#3-3) +- [4. Inference and Deployment](#4) + - [4.1 Python Inference](#4-1) + - [4.2 C++ Inference](#4-2) + - [4.3 Serving](#4-3) + - [4.4 More](#4-4) +- [5. FAQ](#5) + + +## 1. Introduction + +Paper: +> [SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition](https://arxiv.org/abs/2005.13117) +> Chengwei Zhang, Yunlu Xu, Zhanzhan Cheng, Shiliang Pu, Yi Niu, Fei Wu, Futai Zou +> AAAI, 2020 + +Using MJSynth and SynthText two text recognition datasets for training, and evaluating on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE datasets. The algorithm reproduction effect is as follows: + +|Model|Backbone|config|Acc|Download link| +| --- | --- | --- | --- | --- | +|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon| + + + +## 2. Environment +Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code. + + + +## 3. Model Training / Evaluation / Prediction + +Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**. + +Training: + +Specifically, after the data preparation is completed, the training can be started. The training command is as follows: + +``` +#Single GPU training (long training period, not recommended) +python3 tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml + +#Multi GPU training, specify the gpu number through the --gpus parameter +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml +``` + +Evaluation: + +``` +# GPU evaluation +python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy +``` + +Prediction: + +``` +# The configuration file used for prediction must match the training +python3 tools/infer_rec.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png +``` + + +## 4. Inference and Deployment + + +### 4.1 Python Inference +First, the model saved during the SPIN text recognition training process is converted into an inference model. you can use the following command to convert: + +``` +python3 tools/export_model.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r32_gaspin_bilstm_att +``` + +For SPIN text recognition model inference, the following commands can be executed: + +``` +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r32_gaspin_bilstm_att/" --rec_image_shape="3, 32, 100" --rec_algorithm="SPIN" --rec_char_dict_path="/ppocr/utils/dict/spin_dict.txt" --use_space_char=False +``` + + +### 4.2 C++ Inference + +Not supported + + +### 4.3 Serving + +Not supported + + +### 4.4 More + +Not supported + + +## 5. FAQ + + +## Citation + +```bibtex +@article{2020SPIN, + title={SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition}, + author={Chengwei Zhang and Yunlu Xu and Zhanzhan Cheng and Shiliang Pu and Yi Niu and Fei Wu and Futai Zou}, + journal={AAAI2020}, + year={2020}, +} +``` diff --git a/log/workerlog.0 b/log/workerlog.0 new file mode 100644 index 000000000..7983c87df --- /dev/null +++ b/log/workerlog.0 @@ -0,0 +1,131 @@ +D:\Projects\3rdparty\anaconda\envs\pd2\lib\site-packages\socks.py:58: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working + from collections import Callable +D:\Projects\3rdparty\anaconda\envs\pd2\lib\site-packages\win32\lib\pywintypes.py:2: DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses + import imp, sys, os +D:\Projects\3rdparty\anaconda\envs\pd2\lib\site-packages\pkg_resources\_vendor\pyparsing.py:943: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working + collections.MutableMapping.register(ParseResults) +D:\Projects\3rdparty\anaconda\envs\pd2\lib\site-packages\pkg_resources\_vendor\pyparsing.py:3245: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working + elif isinstance( exprs, collections.Iterable ): +[2022/06/12 13:42:08] ppocr INFO: Architecture : +[2022/06/12 13:42:08] ppocr INFO: Backbone : +[2022/06/12 13:42:08] ppocr INFO: name : ResNet32 +[2022/06/12 13:42:08] ppocr INFO: out_channels : 512 +[2022/06/12 13:42:08] ppocr INFO: Head : +[2022/06/12 13:42:08] ppocr INFO: hidden_size : 256 +[2022/06/12 13:42:08] ppocr INFO: name : SPINAttentionHead +[2022/06/12 13:42:08] ppocr INFO: Neck : +[2022/06/12 13:42:08] ppocr INFO: encoder_type : cascadernn +[2022/06/12 13:42:08] ppocr INFO: hidden_size : 256 +[2022/06/12 13:42:08] ppocr INFO: name : SequenceEncoder +[2022/06/12 13:42:08] ppocr INFO: out_channels : [256, 512] +[2022/06/12 13:42:08] ppocr INFO: with_linear : True +[2022/06/12 13:42:08] ppocr INFO: Transform : +[2022/06/12 13:42:08] ppocr INFO: default_type : 6 +[2022/06/12 13:42:08] ppocr INFO: loc_lr : 0.1 +[2022/06/12 13:42:08] ppocr INFO: name : GA_SPIN +[2022/06/12 13:42:08] ppocr INFO: offsets : True +[2022/06/12 13:42:08] ppocr INFO: stn : True +[2022/06/12 13:42:08] ppocr INFO: algorithm : SPIN +[2022/06/12 13:42:08] ppocr INFO: in_channels : 1 +[2022/06/12 13:42:08] ppocr INFO: model_type : rec +[2022/06/12 13:42:08] ppocr INFO: Eval : +[2022/06/12 13:42:08] ppocr INFO: dataset : +[2022/06/12 13:42:08] ppocr INFO: data_dir : ./train_data/ic15_data +[2022/06/12 13:42:08] ppocr INFO: label_file_list : ['./train_data/ic15_data/rec_gt_test.txt'] +[2022/06/12 13:42:08] ppocr INFO: name : SimpleDataSet +[2022/06/12 13:42:08] ppocr INFO: transforms : +[2022/06/12 13:42:08] ppocr INFO: NRTRDecodeImage : +[2022/06/12 13:42:08] ppocr INFO: channel_first : False +[2022/06/12 13:42:08] ppocr INFO: img_mode : BGR +[2022/06/12 13:42:08] ppocr INFO: SPINAttnLabelEncode : None +[2022/06/12 13:42:08] ppocr INFO: SPINRecResizeImg : +[2022/06/12 13:42:08] ppocr INFO: image_shape : [100, 32] +[2022/06/12 13:42:08] ppocr INFO: interpolation : 2 +[2022/06/12 13:42:08] ppocr INFO: mean : [127.5] +[2022/06/12 13:42:08] ppocr INFO: std : [127.5] +[2022/06/12 13:42:08] ppocr INFO: KeepKeys : +[2022/06/12 13:42:08] ppocr INFO: keep_keys : ['image', 'label', 'length'] +[2022/06/12 13:42:08] ppocr INFO: loader : +[2022/06/12 13:42:08] ppocr INFO: batch_size_per_card : 8 +[2022/06/12 13:42:08] ppocr INFO: drop_last : False +[2022/06/12 13:42:08] ppocr INFO: num_workers : 2 +[2022/06/12 13:42:08] ppocr INFO: shuffle : False +[2022/06/12 13:42:08] ppocr INFO: Global : +[2022/06/12 13:42:08] ppocr INFO: cal_metric_during_train : True +[2022/06/12 13:42:08] ppocr INFO: character_dict_path : ./ppocr/utils/dict/spin_dict.txt +[2022/06/12 13:42:08] ppocr INFO: checkpoints : ./inference/rec_r32_gaspin_bilstm_att/best_accuracy +[2022/06/12 13:42:08] ppocr INFO: distributed : False +[2022/06/12 13:42:08] ppocr INFO: epoch_num : 6 +[2022/06/12 13:42:08] ppocr INFO: eval_batch_step : [0, 2000] +[2022/06/12 13:42:08] ppocr INFO: infer_img : doc/imgs_words/ch/word_1.jpg +[2022/06/12 13:42:08] ppocr INFO: infer_mode : False +[2022/06/12 13:42:08] ppocr INFO: log_smooth_window : 50 +[2022/06/12 13:42:08] ppocr INFO: max_text_length : 25 +[2022/06/12 13:42:08] ppocr INFO: pretrained_model : None +[2022/06/12 13:42:08] ppocr INFO: print_batch_step : 50 +[2022/06/12 13:42:08] ppocr INFO: save_epoch_step : 3 +[2022/06/12 13:42:08] ppocr INFO: save_inference_dir : None +[2022/06/12 13:42:08] ppocr INFO: save_model_dir : ./output/rec/rec_r32_gaspin_bilstm_att/ +[2022/06/12 13:42:08] ppocr INFO: save_res_path : ./output/rec/predicts_r32_gaspin_bilstm_att.txt +[2022/06/12 13:42:08] ppocr INFO: use_gpu : True +[2022/06/12 13:42:08] ppocr INFO: use_space_char : False +[2022/06/12 13:42:08] ppocr INFO: use_visualdl : False +[2022/06/12 13:42:08] ppocr INFO: Loss : +[2022/06/12 13:42:08] ppocr INFO: ignore_index : 0 +[2022/06/12 13:42:08] ppocr INFO: name : SPINAttentionLoss +[2022/06/12 13:42:08] ppocr INFO: Metric : +[2022/06/12 13:42:08] ppocr INFO: is_filter : True +[2022/06/12 13:42:08] ppocr INFO: main_indicator : acc +[2022/06/12 13:42:08] ppocr INFO: name : RecMetric +[2022/06/12 13:42:08] ppocr INFO: Optimizer : +[2022/06/12 13:42:08] ppocr INFO: beta1 : 0.9 +[2022/06/12 13:42:08] ppocr INFO: beta2 : 0.999 +[2022/06/12 13:42:08] ppocr INFO: clip_norm : 5 +[2022/06/12 13:42:08] ppocr INFO: lr : +[2022/06/12 13:42:08] ppocr INFO: decay_epochs : [3, 4, 5] +[2022/06/12 13:42:08] ppocr INFO: name : Piecewise +[2022/06/12 13:42:08] ppocr INFO: values : [0.001, 0.0003, 9e-05, 2.7e-05] +[2022/06/12 13:42:08] ppocr INFO: name : AdamW +[2022/06/12 13:42:08] ppocr INFO: PostProcess : +[2022/06/12 13:42:08] ppocr INFO: character_dict_path : ./ppocr/utils/dict/spin_dict.txt +[2022/06/12 13:42:08] ppocr INFO: name : SPINAttnLabelDecode +[2022/06/12 13:42:08] ppocr INFO: use_space_char : False +[2022/06/12 13:42:08] ppocr INFO: Train : +[2022/06/12 13:42:08] ppocr INFO: dataset : +[2022/06/12 13:42:08] ppocr INFO: data_dir : ./train_data/ic15_data/ +[2022/06/12 13:42:08] ppocr INFO: label_file_list : ['./train_data/ic15_data/rec_gt_train.txt'] +[2022/06/12 13:42:08] ppocr INFO: name : SimpleDataSet +[2022/06/12 13:42:08] ppocr INFO: transforms : +[2022/06/12 13:42:08] ppocr INFO: NRTRDecodeImage : +[2022/06/12 13:42:08] ppocr INFO: channel_first : False +[2022/06/12 13:42:08] ppocr INFO: img_mode : BGR +[2022/06/12 13:42:08] ppocr INFO: SPINAttnLabelEncode : None +[2022/06/12 13:42:08] ppocr INFO: SPINRecResizeImg : +[2022/06/12 13:42:08] ppocr INFO: image_shape : [100, 32] +[2022/06/12 13:42:08] ppocr INFO: interpolation : 2 +[2022/06/12 13:42:08] ppocr INFO: mean : [127.5] +[2022/06/12 13:42:08] ppocr INFO: std : [127.5] +[2022/06/12 13:42:08] ppocr INFO: KeepKeys : +[2022/06/12 13:42:08] ppocr INFO: keep_keys : ['image', 'label', 'length'] +[2022/06/12 13:42:08] ppocr INFO: loader : +[2022/06/12 13:42:08] ppocr INFO: batch_size_per_card : 8 +[2022/06/12 13:42:08] ppocr INFO: drop_last : True +[2022/06/12 13:42:08] ppocr INFO: num_workers : 4 +[2022/06/12 13:42:08] ppocr INFO: shuffle : True +[2022/06/12 13:42:08] ppocr INFO: profiler_options : None +[2022/06/12 13:42:08] ppocr INFO: train with paddle 2.2.2 and device CUDAPlace(0) +[2022/06/12 13:42:08] ppocr INFO: Initialize indexs of datasets:['./train_data/ic15_data/rec_gt_test.txt'] +W0612 13:42:08.814790 17600 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.5, Driver API Version: 11.1, Runtime API Version: 10.2 +W0612 13:42:08.832805 17600 device_context.cc:465] device: 0, cuDNN Version: 7.6. +[2022/06/12 13:42:12] ppocr INFO: resume from ./inference/rec_r32_gaspin_bilstm_att/best_accuracy +[2022/06/12 13:42:12] ppocr INFO: metric in ckpt *************** +[2022/06/12 13:42:12] ppocr INFO: acc:0.90589541082154 +[2022/06/12 13:42:12] ppocr INFO: norm_edit_dis:0.9627389225663741 +[2022/06/12 13:42:12] ppocr INFO: fps:1802.1068940938283 +[2022/06/12 13:42:12] ppocr INFO: best_epoch:6 +[2022/06/12 13:42:12] ppocr INFO: start_epoch:7 + eval model:: 0%| | 0/2 [00:00 self.max_text_len: + return None + data['length'] = np.array(len(text)) + target = [0] + text + [1] + padded_text = [0 for _ in range(self.max_text_len + 2)] + + padded_text[:len(target)] = target + data['label'] = np.array(padded_text) + return data + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx" \ + % beg_or_end + return idx \ No newline at end of file diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 32de2b3fc..8caa29e29 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -267,6 +267,51 @@ class PRENResizeImg(object): data['image'] = resized_img.astype(np.float32) return data +class SPINRecResizeImg(object): + def __init__(self, + image_shape, + interpolation=2, + mean=(127.5, 127.5, 127.5), + std=(127.5, 127.5, 127.5), + **kwargs): + self.image_shape = image_shape + + self.mean = np.array(mean, dtype=np.float32) + self.std = np.array(std, dtype=np.float32) + self.interpolation = interpolation + + def __call__(self, data): + img = data['image'] + # different interpolation type corresponding the OpenCV + if self.interpolation == 0: + interpolation = cv2.INTER_NEAREST + elif self.interpolation == 1: + interpolation = cv2.INTER_LINEAR + elif self.interpolation == 2: + interpolation = cv2.INTER_CUBIC + elif self.interpolation == 3: + interpolation = cv2.INTER_AREA + else: + raise Exception("Unsupported interpolation type !!!") + # Deal with the image error during image loading + if img is None: + return None + + img = cv2.resize(img, tuple(self.image_shape), interpolation) + img = np.array(img, np.float32) + img = np.expand_dims(img, -1) + img = img.transpose((2, 0, 1)) + # normalize the image + to_rgb = False + img = img.copy().astype(np.float32) + mean = np.float64(self.mean.reshape(1, -1)) + stdinv = 1 / np.float64(self.std.reshape(1, -1)) + if to_rgb: + cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img -= mean + img *= stdinv + data['image'] = img + return data def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25): imgC, imgH, imgW_min, imgW_max = image_shape diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index de8419b7c..f748b94cf 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -35,6 +35,7 @@ from .rec_sar_loss import SARLoss from .rec_aster_loss import AsterLoss from .rec_pren_loss import PRENLoss from .rec_multi_loss import MultiLoss +from .rec_spin_att_loss import SPINAttentionLoss # cls loss from .cls_loss import ClsLoss @@ -61,7 +62,8 @@ def build_loss(config): 'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', - 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss' + 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss', + 'SPINAttentionLoss' ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/rec_spin_att_loss.py b/ppocr/losses/rec_spin_att_loss.py new file mode 100644 index 000000000..37fd93da5 --- /dev/null +++ b/ppocr/losses/rec_spin_att_loss.py @@ -0,0 +1,41 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn + + +class SPINAttentionLoss(nn.Layer): + def __init__(self, reduction='mean', ignore_index=-100, **kwargs): + super(SPINAttentionLoss, self).__init__() + self.loss_func = nn.CrossEntropyLoss(weight=None, reduction=reduction, ignore_index=ignore_index) + + def forward(self, predicts, batch): + targets = batch[1].astype("int64") + targets = targets[:, 1:] # remove [eos] in label + + label_lengths = batch[2].astype('int64') + batch_size, num_steps, num_classes = predicts.shape[0], predicts.shape[ + 1], predicts.shape[2] + assert len(targets.shape) == len(list(predicts.shape)) - 1, \ + "The target's shape and inputs's shape is [N, d] and [N, num_steps]" + + inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]]) + targets = paddle.reshape(targets, [-1]) + + return {'loss': self.loss_func(inputs, targets)} diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index 072d6e0f8..6b525326a 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -32,10 +32,11 @@ def build_backbone(config, model_type): from .rec_micronet import MicroNet from .rec_efficientb3_pren import EfficientNetb3_PREN from .rec_svtrnet import SVTRNet + from .rec_resnet_32 import ResNet32 support_dict = [ 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', "ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN', - 'SVTRNet' + 'SVTRNet', 'ResNet32' ] elif model_type == "e2e": from .e2e_resnet_vd_pg import ResNet diff --git a/ppocr/modeling/backbones/rec_resnet_32.py b/ppocr/modeling/backbones/rec_resnet_32.py new file mode 100644 index 000000000..0b072dc5f --- /dev/null +++ b/ppocr/modeling/backbones/rec_resnet_32.py @@ -0,0 +1,289 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/backbones/ResNet32.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F +import numpy as np + +__all__ = ["ResNet32"] + +conv_weight_attr = nn.initializer.KaimingNormal() + +class ResNet32(nn.Layer): + """ + Feature Extractor is proposed in FAN Ref [1] + + Ref [1]: Focusing Attention: Towards Accurate Text Recognition in Neural Images ICCV-2017 + """ + + def __init__(self, in_channels, out_channels=512): + """ + + Args: + in_channels (int): input channel + output_channel (int): output channel + """ + super(ResNet32, self).__init__() + self.out_channels = out_channels + self.ConvNet = ResNet(in_channels, out_channels, BasicBlock, [1, 2, 5, 3]) + + def forward(self, inputs): + """ + Args: + inputs (torch.Tensor): input feature + + Returns: + torch.Tensor: output feature + + """ + return self.ConvNet(inputs) + +class BasicBlock(nn.Layer): + """Res-net Basic Block""" + expansion = 1 + + def __init__(self, inplanes, planes, + stride=1, downsample=None, + norm_type='BN', **kwargs): + """ + Args: + inplanes (int): input channel + planes (int): channels of the middle feature + stride (int): stride of the convolution + downsample (int): type of the down_sample + norm_type (str): type of the normalization + **kwargs (None): backup parameter + """ + super(BasicBlock, self).__init__() + self.conv1 = self._conv3x3(inplanes, planes) + self.bn1 = nn.BatchNorm2D(planes) + self.conv2 = self._conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2D(planes) + self.relu = nn.ReLU() + self.downsample = downsample + self.stride = stride + + def _conv3x3(self, in_planes, out_planes, stride=1): + """ + + Args: + in_planes (int): input channel + out_planes (int): channels of the middle feature + stride (int): stride of the convolution + Returns: + nn.Module: Conv2D with kernel = 3 + + """ + + return nn.Conv2D(in_planes, out_planes, + kernel_size=3, stride=stride, + padding=1, weight_attr=conv_weight_attr, + bias_attr=False) + + def forward(self, x): + """ + Args: + x (torch.Tensor): input feature + + Returns: + torch.Tensor: output feature of the BasicBlock + + """ + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) + + return out + +class ResNet(nn.Layer): + """Res-Net network structure""" + def __init__(self, input_channel, + output_channel, block, layers): + """ + + Args: + input_channel (int): input channel + output_channel (int): output channel + block (BasicBlock): convolution block + layers (list): layers of the block + """ + super(ResNet, self).__init__() + + self.output_channel_block = [int(output_channel / 4), + int(output_channel / 2), + output_channel, + output_channel] + + self.inplanes = int(output_channel / 8) + self.conv0_1 = nn.Conv2D(input_channel, int(output_channel / 16), + kernel_size=3, stride=1, + padding=1, + weight_attr=conv_weight_attr, + bias_attr=False) + self.bn0_1 = nn.BatchNorm2D(int(output_channel / 16)) + self.conv0_2 = nn.Conv2D(int(output_channel / 16), self.inplanes, + kernel_size=3, stride=1, + padding=1, + weight_attr=conv_weight_attr, + bias_attr=False) + self.bn0_2 = nn.BatchNorm2D(self.inplanes) + self.relu = nn.ReLU() + + self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) + self.layer1 = self._make_layer(block, + self.output_channel_block[0], + layers[0]) + self.conv1 = nn.Conv2D(self.output_channel_block[0], + self.output_channel_block[0], + kernel_size=3, stride=1, + padding=1, + weight_attr=conv_weight_attr, + bias_attr=False) + self.bn1 = nn.BatchNorm2D(self.output_channel_block[0]) + + self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) + self.layer2 = self._make_layer(block, + self.output_channel_block[1], + layers[1], stride=1) + self.conv2 = nn.Conv2D(self.output_channel_block[1], + self.output_channel_block[1], + kernel_size=3, stride=1, + padding=1, + weight_attr=conv_weight_attr, + bias_attr=False,) + self.bn2 = nn.BatchNorm2D(self.output_channel_block[1]) + + self.maxpool3 = nn.MaxPool2D(kernel_size=2, + stride=(2, 1), + padding=(0, 1)) + self.layer3 = self._make_layer(block, self.output_channel_block[2], + layers[2], stride=1) + self.conv3 = nn.Conv2D(self.output_channel_block[2], + self.output_channel_block[2], + kernel_size=3, stride=1, + padding=1, + weight_attr=conv_weight_attr, + bias_attr=False) + self.bn3 = nn.BatchNorm2D(self.output_channel_block[2]) + + self.layer4 = self._make_layer(block, self.output_channel_block[3], + layers[3], stride=1) + self.conv4_1 = nn.Conv2D(self.output_channel_block[3], + self.output_channel_block[3], + kernel_size=2, stride=(2, 1), + padding=(0, 1), + weight_attr=conv_weight_attr, + bias_attr=False) + self.bn4_1 = nn.BatchNorm2D(self.output_channel_block[3]) + self.conv4_2 = nn.Conv2D(self.output_channel_block[3], + self.output_channel_block[3], + kernel_size=2, stride=1, + padding=0, + weight_attr=conv_weight_attr, + bias_attr=False) + self.bn4_2 = nn.BatchNorm2D(self.output_channel_block[3]) + + def _make_layer(self, block, planes, blocks, stride=1): + """ + + Args: + block (block): convolution block + planes (int): input channels + blocks (list): layers of the block + stride (int): stride of the convolution + + Returns: + nn.Sequential: the combination of the convolution block + + """ + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2D(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, + weight_attr=conv_weight_attr, + bias_attr=False), + nn.BatchNorm2D(planes * block.expansion), + ) + + layers = list() + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + """ + Args: + x (torch.Tensor): input feature + + Returns: + torch.Tensor: output feature of the Resnet + + """ + x = self.conv0_1(x) + x = self.bn0_1(x) + x = self.relu(x) + x = self.conv0_2(x) + x = self.bn0_2(x) + x = self.relu(x) + + x = self.maxpool1(x) + x = self.layer1(x) + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.maxpool2(x) + x = self.layer2(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + x = self.maxpool3(x) + x = self.layer3(x) + x = self.conv3(x) + x = self.bn3(x) + x = self.relu(x) + + x = self.layer4(x) + x = self.conv4_1(x) + x = self.bn4_1(x) + x = self.relu(x) + x = self.conv4_2(x) + x = self.bn4_2(x) + x = self.relu(x) + return x diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 1670ea38e..9b53462b8 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -33,6 +33,7 @@ def build_head(config): from .rec_aster_head import AsterHead from .rec_pren_head import PRENHead from .rec_multi_head import MultiHead + from .rec_spin_att_head import SPINAttentionHead # cls head from .cls_head import ClsHead @@ -46,7 +47,7 @@ def build_head(config): 'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead', - 'MultiHead' + 'MultiHead', 'SPINAttentionHead' ] #table head diff --git a/ppocr/modeling/heads/rec_spin_att_head.py b/ppocr/modeling/heads/rec_spin_att_head.py new file mode 100644 index 000000000..94e69a7ed --- /dev/null +++ b/ppocr/modeling/heads/rec_spin_att_head.py @@ -0,0 +1,203 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import numpy as np + + +class SPINAttentionHead(nn.Layer): + def __init__(self, in_channels, out_channels, hidden_size, **kwargs): + super(SPINAttentionHead, self).__init__() + self.input_size = in_channels + self.hidden_size = hidden_size + self.num_classes = out_channels + + self.attention_cell = AttentionLSTMCell( + in_channels, hidden_size, out_channels, use_gru=False) + self.generator = nn.Linear(hidden_size, out_channels) + + def _char_to_onehot(self, input_char, onehot_dim): + input_ont_hot = F.one_hot(input_char, onehot_dim) + return input_ont_hot + + def forward(self, inputs, targets=None, batch_max_length=25): + batch_size = paddle.shape(inputs)[0] + num_steps = batch_max_length + 1 # +1 for [sos] at end of sentence + + hidden = (paddle.zeros((batch_size, self.hidden_size)), + paddle.zeros((batch_size, self.hidden_size))) + output_hiddens = [] + if self.training: # for train + targets = targets[0] + for i in range(num_steps): + char_onehots = self._char_to_onehot( + targets[:, i], onehot_dim=self.num_classes) + (outputs, hidden), alpha = self.attention_cell(hidden, inputs, + char_onehots) + output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) + output = paddle.concat(output_hiddens, axis=1) + probs = self.generator(output) + else: + targets = paddle.zeros(shape=[batch_size], dtype="int32") + probs = None + char_onehots = None + outputs = None + alpha = None + + for i in range(num_steps): + char_onehots = self._char_to_onehot( + targets, onehot_dim=self.num_classes) + (outputs, hidden), alpha = self.attention_cell(hidden, inputs, + char_onehots) + probs_step = self.generator(outputs) + if probs is None: + probs = paddle.unsqueeze(probs_step, axis=1) + else: + probs = paddle.concat( + [probs, paddle.unsqueeze( + probs_step, axis=1)], axis=1) + next_input = probs_step.argmax(axis=1) + targets = next_input + if not self.training: + probs = paddle.nn.functional.softmax(probs, axis=2) + return probs + + +class AttentionGRUCell(nn.Layer): + def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): + super(AttentionGRUCell, self).__init__() + self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False) + self.h2h = nn.Linear(hidden_size, hidden_size) + self.score = nn.Linear(hidden_size, 1, bias_attr=False) + + self.rnn = nn.GRUCell( + input_size=input_size + num_embeddings, hidden_size=hidden_size) + + self.hidden_size = hidden_size + + def forward(self, prev_hidden, batch_H, char_onehots): + + batch_H_proj = self.i2h(batch_H) + prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1) + + res = paddle.add(batch_H_proj, prev_hidden_proj) + res = paddle.tanh(res) + e = self.score(res) + + alpha = F.softmax(e, axis=1) + alpha = paddle.transpose(alpha, [0, 2, 1]) + context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1) + concat_context = paddle.concat([context, char_onehots], 1) + + cur_hidden = self.rnn(concat_context, prev_hidden) + + return cur_hidden, alpha + + +class AttentionLSTM(nn.Layer): + def __init__(self, in_channels, out_channels, hidden_size, **kwargs): + super(AttentionLSTM, self).__init__() + self.input_size = in_channels + self.hidden_size = hidden_size + self.num_classes = out_channels + + self.attention_cell = AttentionLSTMCell( + in_channels, hidden_size, out_channels, use_gru=False) + self.generator = nn.Linear(hidden_size, out_channels) + + def _char_to_onehot(self, input_char, onehot_dim): + input_ont_hot = F.one_hot(input_char, onehot_dim) + return input_ont_hot + + def forward(self, inputs, targets=None, batch_max_length=25): + batch_size = inputs.shape[0] + num_steps = batch_max_length + + hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros( + (batch_size, self.hidden_size))) + output_hiddens = [] + + if targets is not None: + for i in range(num_steps): + # one-hot vectors for a i-th char + char_onehots = self._char_to_onehot( + targets[:, i], onehot_dim=self.num_classes) + hidden, alpha = self.attention_cell(hidden, inputs, + char_onehots) + + hidden = (hidden[1][0], hidden[1][1]) + output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1)) + output = paddle.concat(output_hiddens, axis=1) + probs = self.generator(output) + + else: + targets = paddle.zeros(shape=[batch_size], dtype="int32") + probs = None + + for i in range(num_steps): + char_onehots = self._char_to_onehot( + targets, onehot_dim=self.num_classes) + hidden, alpha = self.attention_cell(hidden, inputs, + char_onehots) + probs_step = self.generator(hidden[0]) + hidden = (hidden[1][0], hidden[1][1]) + if probs is None: + probs = paddle.unsqueeze(probs_step, axis=1) + else: + probs = paddle.concat( + [probs, paddle.unsqueeze( + probs_step, axis=1)], axis=1) + + next_input = probs_step.argmax(axis=1) + + targets = next_input + + return probs + + +class AttentionLSTMCell(nn.Layer): + def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): + super(AttentionLSTMCell, self).__init__() + self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False) + self.h2h = nn.Linear(hidden_size, hidden_size) + self.score = nn.Linear(hidden_size, 1, bias_attr=False) + if not use_gru: + self.rnn = nn.LSTMCell( + input_size=input_size + num_embeddings, hidden_size=hidden_size) + else: + self.rnn = nn.GRUCell( + input_size=input_size + num_embeddings, hidden_size=hidden_size) + + self.hidden_size = hidden_size + + def forward(self, prev_hidden, batch_H, char_onehots): + batch_H_proj = self.i2h(batch_H) + prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1) + res = paddle.add(batch_H_proj, prev_hidden_proj) + res = paddle.tanh(res) + e = self.score(res) + + alpha = F.softmax(e, axis=1) + alpha = paddle.transpose(alpha, [0, 2, 1]) + context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1) + concat_context = paddle.concat([context, char_onehots], 1) + cur_hidden = self.rnn(concat_context, prev_hidden) + + return cur_hidden, alpha diff --git a/ppocr/modeling/necks/rnn.py b/ppocr/modeling/necks/rnn.py index c8a774b8c..32e626c3f 100644 --- a/ppocr/modeling/necks/rnn.py +++ b/ppocr/modeling/necks/rnn.py @@ -47,6 +47,67 @@ class EncoderWithRNN(nn.Layer): x, _ = self.lstm(x) return x +class BidirectionalLSTM(nn.Layer): + def __init__(self, input_size, + hidden_size, + output_size=None, + num_layers=1, + dropout=0, + direction=False, + time_major=False, + with_linear=False): + super(BidirectionalLSTM, self).__init__() + self.with_linear = with_linear + self.rnn = nn.LSTM(input_size, + hidden_size, + num_layers=num_layers, + dropout=dropout, + direction=direction, + time_major=time_major) + + # text recognition the specified structure LSTM with linear + if self.with_linear: + self.linear = nn.Linear(hidden_size * 2, output_size) + + def forward(self, input_feature): + """ + + Args: + input_feature (Torch.Tensor): visual feature [batch_size x T x input_size] + + Returns: + Torch.Tensor: LSTM output contextual feature [batch_size x T x output_size] + + """ + + # self.rnn.flatten_parameters() # error in export_model + recurrent, _ = self.rnn(input_feature) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) + if self.with_linear: + output = self.linear(recurrent) # batch_size x T x output_size + return output + return recurrent + +class EncoderWithCascadeRNN(nn.Layer): + def __init__(self, in_channels, hidden_size, out_channels, num_layers=2, with_linear=False): + super(EncoderWithCascadeRNN, self).__init__() + self.out_channels = out_channels[-1] + self.encoder = nn.LayerList( + [BidirectionalLSTM( + in_channels if i == 0 else out_channels[i - 1], + hidden_size, + output_size=out_channels[i], + num_layers=1, + direction='bidirectional', + with_linear=with_linear) + for i in range(num_layers)] + ) + + + def forward(self, x): + for i, l in enumerate(self.encoder): + x = l(x) + return x + class EncoderWithFC(nn.Layer): def __init__(self, in_channels, hidden_size): @@ -166,13 +227,17 @@ class SequenceEncoder(nn.Layer): 'reshape': Im2Seq, 'fc': EncoderWithFC, 'rnn': EncoderWithRNN, - 'svtr': EncoderWithSVTR + 'svtr': EncoderWithSVTR, + 'cascadernn': EncoderWithCascadeRNN } assert encoder_type in support_encoder_dict, '{} must in {}'.format( encoder_type, support_encoder_dict.keys()) if encoder_type == "svtr": self.encoder = support_encoder_dict[encoder_type]( self.encoder_reshape.out_channels, **kwargs) + elif encoder_type == 'cascadernn': + self.encoder = support_encoder_dict[encoder_type]( + self.encoder_reshape.out_channels, hidden_size, **kwargs) else: self.encoder = support_encoder_dict[encoder_type]( self.encoder_reshape.out_channels, hidden_size) diff --git a/ppocr/modeling/transforms/__init__.py b/ppocr/modeling/transforms/__init__.py index 405ab3cc6..7e4ffdf46 100755 --- a/ppocr/modeling/transforms/__init__.py +++ b/ppocr/modeling/transforms/__init__.py @@ -18,8 +18,10 @@ __all__ = ['build_transform'] def build_transform(config): from .tps import TPS from .stn import STN_ON + from .gaspin_transformer import GA_SPIN_Transformer as GA_SPIN - support_dict = ['TPS', 'STN_ON'] + + support_dict = ['TPS', 'STN_ON', 'GA_SPIN'] module_name = config.pop('name') assert module_name in support_dict, Exception( diff --git a/ppocr/modeling/transforms/gaspin_transformer.py b/ppocr/modeling/transforms/gaspin_transformer.py new file mode 100644 index 000000000..331c82aae --- /dev/null +++ b/ppocr/modeling/transforms/gaspin_transformer.py @@ -0,0 +1,286 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import nn, ParamAttr +from paddle.nn import functional as F +import numpy as np +import itertools +import functools +from .tps import GridGenerator + +'''This code is refer from: +https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/transformations/gaspin_transformation.py +''' + +class SP_TransformerNetwork(nn.Layer): + """ + Sturture-Preserving Transformation (SPT) as Equa. (2) in Ref. [1] + Ref: [1] SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition. AAAI-2021. + """ + + def __init__(self, nc=1, default_type=5): + """ Based on SPIN + Args: + nc (int): number of input channels (usually in 1 or 3) + default_type (int): the complexity of transformation intensities (by default set to 6 as the paper) + """ + super(SP_TransformerNetwork, self).__init__() + self.power_list = self.cal_K(default_type) + self.sigmoid = nn.Sigmoid() + self.bn = nn.InstanceNorm2D(nc) + + def cal_K(self, k=5): + """ + + Args: + k (int): the complexity of transformation intensities (by default set to 6 as the paper) + + Returns: + List: the normalized intensity of each pixel in [0,1], denoted as \beta [1x(2K+1)] + + """ + from math import log + x = [] + if k != 0: + for i in range(1, k+1): + lower = round(log(1-(0.5/(k+1))*i)/log((0.5/(k+1))*i), 2) + upper = round(1/lower, 2) + x.append(lower) + x.append(upper) + x.append(1.00) + return x + + def forward(self, batch_I, weights, offsets, lambda_color=None): + """ + + Args: + batch_I (torch.Tensor): batch of input images [batch_size x nc x I_height x I_width] + weights: + offsets: the predicted offset by AIN, a scalar + lambda_color: the learnable update gate \alpha in Equa. (5) as + g(x) = (1 - \alpha) \odot x + \alpha \odot x_{offsets} + + Returns: + torch.Tensor: transformed images by SPN as Equa. (4) in Ref. [1] + [batch_size x I_channel_num x I_r_height x I_r_width] + + """ + batch_I = (batch_I + 1) * 0.5 + if offsets is not None: + batch_I = batch_I*(1-lambda_color) + offsets*lambda_color + batch_weight_params = paddle.unsqueeze(paddle.unsqueeze(weights, -1), -1) + batch_I_power = paddle.stack([batch_I.pow(p) for p in self.power_list], axis=1) + + batch_weight_sum = paddle.sum(batch_I_power * batch_weight_params, axis=1) + batch_weight_sum = self.bn(batch_weight_sum) + batch_weight_sum = self.sigmoid(batch_weight_sum) + batch_weight_sum = batch_weight_sum * 2 - 1 + return batch_weight_sum + +class GA_SPIN_Transformer(nn.Layer): + """ + Geometric-Absorbed SPIN Transformation (GA-SPIN) proposed in Ref. [1] + + + Ref: [1] SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition. AAAI-2021. + """ + + def __init__(self, in_channels=1, + I_r_size=(32, 100), + offsets=False, + norm_type='BN', + default_type=6, + loc_lr=1, + stn=True): + """ + Args: + in_channels (int): channel of input features, + set it to 1 if the grayscale images and 3 if RGB input + I_r_size (tuple): size of rectified images (used in STN transformations) + inputDataType (str): the type of input data, + only support 'torch.cuda.FloatTensor' this version + offsets (bool): set it to False if use SPN w.o. AIN, + and set it to True if use SPIN (both with SPN and AIN) + norm_type (str): the normalization type of the module, + set it to 'BN' by default, 'IN' optionally + default_type (int): the K chromatic space, + set it to 3/5/6 depend on the complexity of transformation intensities + loc_lr (float): learning rate of location network + + """ + super(GA_SPIN_Transformer, self).__init__() + self.nc = in_channels + self.spt = True + self.offsets = offsets + self.stn = stn # set to True in GA-SPIN, while set it to False in SPIN + self.I_r_size = I_r_size + self.out_channels = in_channels + if norm_type == 'BN': + norm_layer = functools.partial(nn.BatchNorm2D, use_global_stats=True) + elif norm_type == 'IN': + norm_layer = functools.partial(nn.InstanceNorm2D, weight_attr=False, + use_global_stats=False) + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + + if self.spt: + self.sp_net = SP_TransformerNetwork(in_channels, + default_type) + self.spt_convnet = nn.Sequential( + # 32*100 + nn.Conv2D(in_channels, 32, 3, 1, 1, bias_attr=False), + norm_layer(32), nn.ReLU(), + nn.MaxPool2D(kernel_size=2, stride=2), + # 16*50 + nn.Conv2D(32, 64, 3, 1, 1, bias_attr=False), + norm_layer(64), nn.ReLU(), + nn.MaxPool2D(kernel_size=2, stride=2), + # 8*25 + nn.Conv2D(64, 128, 3, 1, 1, bias_attr=False), + norm_layer(128), nn.ReLU(), + nn.MaxPool2D(kernel_size=2, stride=2), + # 4*12 + ) + self.stucture_fc1 = nn.Sequential( + nn.Conv2D(128, 256, 3, 1, 1, bias_attr=False), + norm_layer(256), nn.ReLU(), + nn.MaxPool2D(kernel_size=2, stride=2), + nn.Conv2D(256, 256, 3, 1, 1, bias_attr=False), + norm_layer(256), nn.ReLU(), # 2*6 + nn.MaxPool2D(kernel_size=2, stride=2), + nn.Conv2D(256, 512, 3, 1, 1, bias_attr=False), + norm_layer(512), nn.ReLU(), # 1*3 + nn.AdaptiveAvgPool2D(1), + nn.Flatten(1, -1), # batch_size x 512 + nn.Linear(512, 256, weight_attr=nn.initializer.Normal(0.001)), + nn.BatchNorm1D(256), nn.ReLU() + ) + self.out_weight = 2*default_type+1 + self.spt_length = 2*default_type+1 + if offsets: + self.out_weight += 1 + if self.stn: + self.F = 20 + self.out_weight += self.F * 2 + self.GridGenerator = GridGenerator(self.F*2, self.F) + + # self.out_weight*=nc + # Init structure_fc2 in LocalizationNetwork + initial_bias = self.init_spin(default_type*2) + initial_bias = initial_bias.reshape(-1) + param_attr = ParamAttr( + learning_rate=loc_lr, + initializer=nn.initializer.Assign(np.zeros([256, self.out_weight]))) + bias_attr = ParamAttr( + learning_rate=loc_lr, + initializer=nn.initializer.Assign(initial_bias)) + self.stucture_fc2 = nn.Linear(256, self.out_weight, + weight_attr=param_attr, + bias_attr=bias_attr) + self.sigmoid = nn.Sigmoid() + + if offsets: + self.offset_fc1 = nn.Sequential(nn.Conv2D(128, 16, + 3, 1, 1, + bias_attr=False), + norm_layer(16), + nn.ReLU(),) + self.offset_fc2 = nn.Conv2D(16, in_channels, + 3, 1, 1) + self.pool = nn.MaxPool2D(2, 2) + + def init_spin(self, nz): + """ + Args: + nz (int): number of paired \betas exponents, which means the value of K x 2 + + """ + init_id = [0.00]*nz+[5.00] + if self.offsets: + init_id += [-5.00] + # init_id *=3 + init = np.array(init_id) + + if self.stn: + F = self.F + ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) + ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) + ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) + initial_bias = initial_bias.reshape(-1) + init = np.concatenate([init, initial_bias], axis=0) + return init + + def forward(self, x, return_weight=False): + """ + Args: + x (torch.cuda.FloatTensor): input image batch + return_weight (bool): set to False by default, + if set to True return the predicted offsets of AIN, denoted as x_{offsets} + + Returns: + torch.Tensor: rectified image [batch_size x I_channel_num x I_height x I_width], the same as the input size + """ + + if self.spt: + feat = self.spt_convnet(x) + fc1 = self.stucture_fc1(feat) + sp_weight_fusion = self.stucture_fc2(fc1) + sp_weight_fusion = sp_weight_fusion.reshape([x.shape[0], self.out_weight, 1]) + if self.offsets: # SPIN w. AIN + lambda_color = sp_weight_fusion[:, self.spt_length, 0] + lambda_color = self.sigmoid(lambda_color).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + sp_weight = sp_weight_fusion[:, :self.spt_length, :] + offsets = self.pool(self.offset_fc2(self.offset_fc1(feat))) + + assert offsets.shape[2] == 2 # 2 + assert offsets.shape[3] == 6 # 16 + offsets = self.sigmoid(offsets) # v12 + + if return_weight: + return offsets + offsets = nn.functional.upsample(offsets, size=(x.shape[2], x.shape[3]), mode='bilinear') + + if self.stn: + batch_C_prime = sp_weight_fusion[:, (self.spt_length + 1):, :].reshape([x.shape[0], self.F, 2]) + build_P_prime = self.GridGenerator(batch_C_prime, self.I_r_size) + build_P_prime_reshape = build_P_prime.reshape([build_P_prime.shape[0], + self.I_r_size[0], + self.I_r_size[1], + 2]) + + else: # SPIN w.o. AIN + sp_weight = sp_weight_fusion[:, :self.spt_length, :] + lambda_color, offsets = None, None + + if self.stn: + batch_C_prime = sp_weight_fusion[:, self.spt_length:, :].reshape([x.shape[0], self.F, 2]) + build_P_prime = self.GridGenerator(batch_C_prime, self.I_r_size) + build_P_prime_reshape = build_P_prime.reshape([build_P_prime.shape[0], + self.I_r_size[0], + self.I_r_size[1], + 2]) + + x = self.sp_net(x, sp_weight, offsets, lambda_color) + if self.stn: + x = F.grid_sample(x=x, grid=build_P_prime_reshape, padding_mode='border') + return x diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index f50b5f1c5..cf2575ee0 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess from .fce_postprocess import FCEPostProcess from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \ DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \ - SEEDLabelDecode, PRENLabelDecode + SEEDLabelDecode, PRENLabelDecode, SPINAttnLabelDecode from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess @@ -42,7 +42,7 @@ def build_post_process(config, global_config=None): 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode', - 'DistillationSARLabelDecode' + 'DistillationSARLabelDecode', 'SPINAttnLabelDecode' ] if config['name'] == 'PSEPostProcess': diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index bf0fd890b..0df8f3ccd 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -752,3 +752,82 @@ class PRENLabelDecode(BaseRecLabelDecode): return text label = self.decode(label) return text, label + +class SPINAttnLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(SPINAttnLabelDecode, self).__init__(character_dict_path, + use_space_char) + + def add_special_char(self, dict_character): + self.beg_str = "sos" + self.end_str = "eos" + dict_character = dict_character + dict_character = [self.beg_str] + [self.end_str] + dict_character + return dict_character + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """ convert text-index into text-label. """ + result_list = [] + ignored_tokens = self.get_ignored_tokens() + [beg_idx, end_idx] = self.get_ignored_tokens() + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] == int(beg_idx): + continue + if int(text_index[batch_idx][idx]) == int(end_idx): + break + if is_remove_duplicate: + # only for predict + if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ + batch_idx][idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = ''.join(char_list) + result_list.append((text.lower(), np.mean(conf_list).tolist())) + return result_list + + def __call__(self, preds, label=None, *args, **kwargs): + """ + text = self.decode(text) + if label is None: + return text + else: + label = self.decode(label, is_remove_duplicate=False) + return text, label + """ + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + if label is None: + return text + label = self.decode(label, is_remove_duplicate=False) + return text, label + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "unsupport type %s in get_beg_end_flag_idx" \ + % beg_or_end + return idx \ No newline at end of file diff --git a/ppocr/utils/dict/spin_dict.txt b/ppocr/utils/dict/spin_dict.txt new file mode 100644 index 000000000..8ee8347fd --- /dev/null +++ b/ppocr/utils/dict/spin_dict.txt @@ -0,0 +1,68 @@ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +: +( +' +- +, +% +> +. +[ +? +) +" += +_ +* +] +; +& ++ +$ +@ +/ +| +! +< +# +` +{ +~ +\ +} +^ \ No newline at end of file diff --git a/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml new file mode 100644 index 000000000..e53396a03 --- /dev/null +++ b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml @@ -0,0 +1,118 @@ +Global: + use_gpu: True + epoch_num: 6 + log_smooth_window: 50 + print_batch_step: 50 + save_model_dir: ./output/rec/rec_r32_gaspin_bilstm_att/ + save_epoch_step: 3 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/ch/word_1.jpg + # for data or label process + character_dict_path: ./ppocr/utils/dict/spin_dict.txt + max_text_length: 25 + infer_mode: False + use_space_char: False + save_res_path: ./output/rec/predicts_r32_gaspin_bilstm_att.txt + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + lr: + name: Piecewise + decay_epochs: [3, 4, 5] + values: [0.001, 0.0003, 0.00009, 0.000027] + + clip_norm: 5 + +Architecture: + model_type: rec + algorithm: SPIN + in_channels: 1 + Transform: + name: GA_SPIN + offsets: True + default_type: 6 + loc_lr: 0.1 + stn: True + Backbone: + name: ResNet32 + out_channels: 512 + Neck: + name: SequenceEncoder + encoder_type: cascadernn + hidden_size: 256 + out_channels: [256, 512] + with_linear: True + Head: + name: SPINAttentionHead + hidden_size: 256 + + +Loss: + name: SPINAttentionLoss + ignore_index: 0 + +PostProcess: + name: SPINAttnLabelDecode + character_dict_path: ./ppocr/utils/dict/spin_dict.txt + use_space_char: False + + +Metric: + name: RecMetric + main_indicator: acc + is_filter: True + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ic15_data/ + label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"] + transforms: + - NRTRDecodeImage: # load image + img_mode: BGR + channel_first: False + - SPINAttnLabelEncode: # Class handling label + - SPINRecResizeImg: + image_shape: [100, 32] + interpolation : 2 + mean: [127.5] + std: [127.5] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 128 + drop_last: True + num_workers: 4 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ic15_data + label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"] + transforms: + - NRTRDecodeImage: # load image + img_mode: BGR + channel_first: False + - SPINAttnLabelEncode: # Class handling label + - SPINRecResizeImg: + image_shape: [100, 32] + interpolation : 2 + mean: [127.5] + std: [127.5] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 + num_workers: 1 diff --git a/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt b/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt new file mode 100644 index 000000000..4915055a5 --- /dev/null +++ b/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:rec_r32_gaspin_bilstm_att +python:python +gpu_list:0|0,1 +Global.use_gpu:True|True +Global.auto_cast:null +Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300 +Global.save_model_dir:./output/ +Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64 +Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./inference/rec_inference +null:null +## +trainer:norm_train +norm_train:tools/train.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o +null:null +## +===========================infer_params=========================== +Global.save_inference_dir:./output/ +Global.checkpoints: +norm_export:tools/export_model.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +## +train_model:./inference/rec_r32_gaspin_bilstm_att/best_accuracy +infer_export:tools/export_model.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o +infer_quant:False +inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict/spin_dict.txt --use_space_char=False --rec_image_shape="3,32,100" --rec_algorithm="SPIN" +--use_gpu:True|False +--enable_mkldnn:True|False +--cpu_threads:1|6 +--rec_batch_num:1|6 +--use_tensorrt:False|False +--precision:fp32|int8 +--rec_model_dir: +--image_dir:./inference/rec_inference +--save_log_path:./test/output/ +--benchmark:True +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,32,100]}] diff --git a/tools/export_model.py b/tools/export_model.py index 3ea0228f8..b8bc5e1ed 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -73,6 +73,12 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None): shape=[None, 3, 64, 512], dtype="float32"), ] model = to_static(model, input_spec=other_shape) + elif arch_config["algorithm"] == "SPIN": + other_shape = [ + paddle.static.InputSpec( + shape=[None, 1, 32, 100], dtype="float32"), + ] + model = to_static(model, input_spec=other_shape) else: infer_shape = [3, -1, -1] if arch_config["model_type"] == "rec": diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 3664ef2ca..09e13d8dc 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -69,6 +69,12 @@ class TextRecognizer(object): "character_dict_path": args.rec_char_dict_path, "use_space_char": args.use_space_char } + elif self.rec_algorithm == "SPIN": + postprocess_params = { + 'name': 'SPINAttnLabelDecode', + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } self.postprocess_op = build_post_process(postprocess_params) self.predictor, self.input_tensor, self.output_tensors, self.config = \ utility.create_predictor(args, 'rec', logger) @@ -250,6 +256,22 @@ class TextRecognizer(object): return padding_im, resize_shape, pad_shape, valid_ratio + def resize_norm_img_spin(self, img): + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + # return padding_im + img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC) + img = np.array(img, np.float32) + img = np.expand_dims(img, -1) + img = img.transpose((2, 0, 1)) + mean = [127.5] + std = [127.5] + mean = np.array(mean, dtype=np.float32) + std = np.array(std, dtype=np.float32) + mean = np.float32(mean.reshape(1, -1)) + stdinv = 1 / np.float32(std.reshape(1, -1)) + img -= mean + img *= stdinv + return img def __call__(self, img_list): img_num = len(img_list) # Calculate the aspect ratio of all text bars @@ -300,6 +322,10 @@ class TextRecognizer(object): self.rec_image_shape) norm_img = norm_img[np.newaxis, :] norm_img_batch.append(norm_img) + elif self.rec_algorithm == 'SPIN': + norm_img = self.resize_norm_img_spin(img_list[indices[ino]]) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) else: norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio) diff --git a/tools/program.py b/tools/program.py index aa0d2698c..51c73d3e5 100755 --- a/tools/program.py +++ b/tools/program.py @@ -207,7 +207,7 @@ def train(config, model.train() use_srn = config['Architecture']['algorithm'] == "SRN" - extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"] + extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN"] extra_input = False if config['Architecture']['algorithm'] == 'Distillation': for key in config['Architecture']["Models"]: @@ -564,7 +564,8 @@ def preprocess(is_train=False): assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', - 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR' + 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR', + 'SPIN' ] if use_xpu: From 1cda437c4dad19d529f9dcb97c264038ab3e3998 Mon Sep 17 00:00:00 2001 From: smilelite Date: Sun, 10 Jul 2022 09:20:59 +0800 Subject: [PATCH 02/11] modified pr --- .gitignore | 1 + configs/rec/rec_r32_gaspin_bilstm_att.yml | 1 - ppocr/losses/rec_spin_att_loss.py | 6 +- ppocr/modeling/heads/rec_spin_att_head.py | 94 +------------------ ppocr/modeling/necks/rnn.py | 11 --- .../modeling/transforms/gaspin_transformer.py | 13 ++- tools/export_model.py | 6 -- 7 files changed, 13 insertions(+), 119 deletions(-) diff --git a/.gitignore b/.gitignore index 34f0e0cc9..3300be325 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ inference/ inference_results/ output/ train_data/ +log/ *.DS_Store *.vs *.user diff --git a/configs/rec/rec_r32_gaspin_bilstm_att.yml b/configs/rec/rec_r32_gaspin_bilstm_att.yml index 236a17c43..e8235415c 100644 --- a/configs/rec/rec_r32_gaspin_bilstm_att.yml +++ b/configs/rec/rec_r32_gaspin_bilstm_att.yml @@ -61,7 +61,6 @@ Loss: PostProcess: name: SPINAttnLabelDecode - character_dict_path: ./ppocr/utils/dict/spin_dict.txt use_space_char: False diff --git a/ppocr/losses/rec_spin_att_loss.py b/ppocr/losses/rec_spin_att_loss.py index 37fd93da5..195780c7b 100644 --- a/ppocr/losses/rec_spin_att_loss.py +++ b/ppocr/losses/rec_spin_att_loss.py @@ -1,4 +1,4 @@ -# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -19,6 +20,9 @@ from __future__ import print_function import paddle from paddle import nn +'''This code is refer from: +https://github.com/hikopensource/DAVAR-Lab-OCR +''' class SPINAttentionLoss(nn.Layer): def __init__(self, reduction='mean', ignore_index=-100, **kwargs): diff --git a/ppocr/modeling/heads/rec_spin_att_head.py b/ppocr/modeling/heads/rec_spin_att_head.py index 94e69a7ed..07a58b083 100644 --- a/ppocr/modeling/heads/rec_spin_att_head.py +++ b/ppocr/modeling/heads/rec_spin_att_head.py @@ -1,4 +1,4 @@ -# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -80,98 +80,6 @@ class SPINAttentionHead(nn.Layer): return probs -class AttentionGRUCell(nn.Layer): - def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): - super(AttentionGRUCell, self).__init__() - self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False) - self.h2h = nn.Linear(hidden_size, hidden_size) - self.score = nn.Linear(hidden_size, 1, bias_attr=False) - - self.rnn = nn.GRUCell( - input_size=input_size + num_embeddings, hidden_size=hidden_size) - - self.hidden_size = hidden_size - - def forward(self, prev_hidden, batch_H, char_onehots): - - batch_H_proj = self.i2h(batch_H) - prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1) - - res = paddle.add(batch_H_proj, prev_hidden_proj) - res = paddle.tanh(res) - e = self.score(res) - - alpha = F.softmax(e, axis=1) - alpha = paddle.transpose(alpha, [0, 2, 1]) - context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1) - concat_context = paddle.concat([context, char_onehots], 1) - - cur_hidden = self.rnn(concat_context, prev_hidden) - - return cur_hidden, alpha - - -class AttentionLSTM(nn.Layer): - def __init__(self, in_channels, out_channels, hidden_size, **kwargs): - super(AttentionLSTM, self).__init__() - self.input_size = in_channels - self.hidden_size = hidden_size - self.num_classes = out_channels - - self.attention_cell = AttentionLSTMCell( - in_channels, hidden_size, out_channels, use_gru=False) - self.generator = nn.Linear(hidden_size, out_channels) - - def _char_to_onehot(self, input_char, onehot_dim): - input_ont_hot = F.one_hot(input_char, onehot_dim) - return input_ont_hot - - def forward(self, inputs, targets=None, batch_max_length=25): - batch_size = inputs.shape[0] - num_steps = batch_max_length - - hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros( - (batch_size, self.hidden_size))) - output_hiddens = [] - - if targets is not None: - for i in range(num_steps): - # one-hot vectors for a i-th char - char_onehots = self._char_to_onehot( - targets[:, i], onehot_dim=self.num_classes) - hidden, alpha = self.attention_cell(hidden, inputs, - char_onehots) - - hidden = (hidden[1][0], hidden[1][1]) - output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1)) - output = paddle.concat(output_hiddens, axis=1) - probs = self.generator(output) - - else: - targets = paddle.zeros(shape=[batch_size], dtype="int32") - probs = None - - for i in range(num_steps): - char_onehots = self._char_to_onehot( - targets, onehot_dim=self.num_classes) - hidden, alpha = self.attention_cell(hidden, inputs, - char_onehots) - probs_step = self.generator(hidden[0]) - hidden = (hidden[1][0], hidden[1][1]) - if probs is None: - probs = paddle.unsqueeze(probs_step, axis=1) - else: - probs = paddle.concat( - [probs, paddle.unsqueeze( - probs_step, axis=1)], axis=1) - - next_input = probs_step.argmax(axis=1) - - targets = next_input - - return probs - - class AttentionLSTMCell(nn.Layer): def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): super(AttentionLSTMCell, self).__init__() diff --git a/ppocr/modeling/necks/rnn.py b/ppocr/modeling/necks/rnn.py index 32e626c3f..33be9400b 100644 --- a/ppocr/modeling/necks/rnn.py +++ b/ppocr/modeling/necks/rnn.py @@ -70,17 +70,6 @@ class BidirectionalLSTM(nn.Layer): self.linear = nn.Linear(hidden_size * 2, output_size) def forward(self, input_feature): - """ - - Args: - input_feature (Torch.Tensor): visual feature [batch_size x T x input_size] - - Returns: - Torch.Tensor: LSTM output contextual feature [batch_size x T x output_size] - - """ - - # self.rnn.flatten_parameters() # error in export_model recurrent, _ = self.rnn(input_feature) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) if self.with_linear: output = self.linear(recurrent) # batch_size x T x output_size diff --git a/ppocr/modeling/transforms/gaspin_transformer.py b/ppocr/modeling/transforms/gaspin_transformer.py index 331c82aae..9440e360d 100644 --- a/ppocr/modeling/transforms/gaspin_transformer.py +++ b/ppocr/modeling/transforms/gaspin_transformer.py @@ -1,4 +1,4 @@ -# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -71,14 +71,14 @@ class SP_TransformerNetwork(nn.Layer): """ Args: - batch_I (torch.Tensor): batch of input images [batch_size x nc x I_height x I_width] + batch_I (Tensor): batch of input images [batch_size x nc x I_height x I_width] weights: offsets: the predicted offset by AIN, a scalar lambda_color: the learnable update gate \alpha in Equa. (5) as g(x) = (1 - \alpha) \odot x + \alpha \odot x_{offsets} Returns: - torch.Tensor: transformed images by SPN as Equa. (4) in Ref. [1] + Tensor: transformed images by SPN as Equa. (4) in Ref. [1] [batch_size x I_channel_num x I_r_height x I_r_width] """ @@ -114,8 +114,6 @@ class GA_SPIN_Transformer(nn.Layer): in_channels (int): channel of input features, set it to 1 if the grayscale images and 3 if RGB input I_r_size (tuple): size of rectified images (used in STN transformations) - inputDataType (str): the type of input data, - only support 'torch.cuda.FloatTensor' this version offsets (bool): set it to False if use SPN w.o. AIN, and set it to True if use SPIN (both with SPN and AIN) norm_type (str): the normalization type of the module, @@ -123,6 +121,7 @@ class GA_SPIN_Transformer(nn.Layer): default_type (int): the K chromatic space, set it to 3/5/6 depend on the complexity of transformation intensities loc_lr (float): learning rate of location network + stn (bool): whther to use stn. """ super(GA_SPIN_Transformer, self).__init__() @@ -233,12 +232,12 @@ class GA_SPIN_Transformer(nn.Layer): def forward(self, x, return_weight=False): """ Args: - x (torch.cuda.FloatTensor): input image batch + x (Tensor): input image batch return_weight (bool): set to False by default, if set to True return the predicted offsets of AIN, denoted as x_{offsets} Returns: - torch.Tensor: rectified image [batch_size x I_channel_num x I_height x I_width], the same as the input size + Tensor: rectified image [batch_size x I_channel_num x I_height x I_width], the same as the input size """ if self.spt: diff --git a/tools/export_model.py b/tools/export_model.py index b8bc5e1ed..3ea0228f8 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -73,12 +73,6 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None): shape=[None, 3, 64, 512], dtype="float32"), ] model = to_static(model, input_spec=other_shape) - elif arch_config["algorithm"] == "SPIN": - other_shape = [ - paddle.static.InputSpec( - shape=[None, 1, 32, 100], dtype="float32"), - ] - model = to_static(model, input_spec=other_shape) else: infer_shape = [3, -1, -1] if arch_config["model_type"] == "rec": From f9e379ab630b51718b867a125e2036082201b205 Mon Sep 17 00:00:00 2001 From: smilelite Date: Sun, 10 Jul 2022 09:24:48 +0800 Subject: [PATCH 03/11] remove log --- log/workerlog.0 | 131 ------------------------------------------------ 1 file changed, 131 deletions(-) delete mode 100644 log/workerlog.0 diff --git a/log/workerlog.0 b/log/workerlog.0 deleted file mode 100644 index 7983c87df..000000000 --- a/log/workerlog.0 +++ /dev/null @@ -1,131 +0,0 @@ -D:\Projects\3rdparty\anaconda\envs\pd2\lib\site-packages\socks.py:58: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working - from collections import Callable -D:\Projects\3rdparty\anaconda\envs\pd2\lib\site-packages\win32\lib\pywintypes.py:2: DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses - import imp, sys, os -D:\Projects\3rdparty\anaconda\envs\pd2\lib\site-packages\pkg_resources\_vendor\pyparsing.py:943: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working - collections.MutableMapping.register(ParseResults) -D:\Projects\3rdparty\anaconda\envs\pd2\lib\site-packages\pkg_resources\_vendor\pyparsing.py:3245: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working - elif isinstance( exprs, collections.Iterable ): -[2022/06/12 13:42:08] ppocr INFO: Architecture : -[2022/06/12 13:42:08] ppocr INFO: Backbone : -[2022/06/12 13:42:08] ppocr INFO: name : ResNet32 -[2022/06/12 13:42:08] ppocr INFO: out_channels : 512 -[2022/06/12 13:42:08] ppocr INFO: Head : -[2022/06/12 13:42:08] ppocr INFO: hidden_size : 256 -[2022/06/12 13:42:08] ppocr INFO: name : SPINAttentionHead -[2022/06/12 13:42:08] ppocr INFO: Neck : -[2022/06/12 13:42:08] ppocr INFO: encoder_type : cascadernn -[2022/06/12 13:42:08] ppocr INFO: hidden_size : 256 -[2022/06/12 13:42:08] ppocr INFO: name : SequenceEncoder -[2022/06/12 13:42:08] ppocr INFO: out_channels : [256, 512] -[2022/06/12 13:42:08] ppocr INFO: with_linear : True -[2022/06/12 13:42:08] ppocr INFO: Transform : -[2022/06/12 13:42:08] ppocr INFO: default_type : 6 -[2022/06/12 13:42:08] ppocr INFO: loc_lr : 0.1 -[2022/06/12 13:42:08] ppocr INFO: name : GA_SPIN -[2022/06/12 13:42:08] ppocr INFO: offsets : True -[2022/06/12 13:42:08] ppocr INFO: stn : True -[2022/06/12 13:42:08] ppocr INFO: algorithm : SPIN -[2022/06/12 13:42:08] ppocr INFO: in_channels : 1 -[2022/06/12 13:42:08] ppocr INFO: model_type : rec -[2022/06/12 13:42:08] ppocr INFO: Eval : -[2022/06/12 13:42:08] ppocr INFO: dataset : -[2022/06/12 13:42:08] ppocr INFO: data_dir : ./train_data/ic15_data -[2022/06/12 13:42:08] ppocr INFO: label_file_list : ['./train_data/ic15_data/rec_gt_test.txt'] -[2022/06/12 13:42:08] ppocr INFO: name : SimpleDataSet -[2022/06/12 13:42:08] ppocr INFO: transforms : -[2022/06/12 13:42:08] ppocr INFO: NRTRDecodeImage : -[2022/06/12 13:42:08] ppocr INFO: channel_first : False -[2022/06/12 13:42:08] ppocr INFO: img_mode : BGR -[2022/06/12 13:42:08] ppocr INFO: SPINAttnLabelEncode : None -[2022/06/12 13:42:08] ppocr INFO: SPINRecResizeImg : -[2022/06/12 13:42:08] ppocr INFO: image_shape : [100, 32] -[2022/06/12 13:42:08] ppocr INFO: interpolation : 2 -[2022/06/12 13:42:08] ppocr INFO: mean : [127.5] -[2022/06/12 13:42:08] ppocr INFO: std : [127.5] -[2022/06/12 13:42:08] ppocr INFO: KeepKeys : -[2022/06/12 13:42:08] ppocr INFO: keep_keys : ['image', 'label', 'length'] -[2022/06/12 13:42:08] ppocr INFO: loader : -[2022/06/12 13:42:08] ppocr INFO: batch_size_per_card : 8 -[2022/06/12 13:42:08] ppocr INFO: drop_last : False -[2022/06/12 13:42:08] ppocr INFO: num_workers : 2 -[2022/06/12 13:42:08] ppocr INFO: shuffle : False -[2022/06/12 13:42:08] ppocr INFO: Global : -[2022/06/12 13:42:08] ppocr INFO: cal_metric_during_train : True -[2022/06/12 13:42:08] ppocr INFO: character_dict_path : ./ppocr/utils/dict/spin_dict.txt -[2022/06/12 13:42:08] ppocr INFO: checkpoints : ./inference/rec_r32_gaspin_bilstm_att/best_accuracy -[2022/06/12 13:42:08] ppocr INFO: distributed : False -[2022/06/12 13:42:08] ppocr INFO: epoch_num : 6 -[2022/06/12 13:42:08] ppocr INFO: eval_batch_step : [0, 2000] -[2022/06/12 13:42:08] ppocr INFO: infer_img : doc/imgs_words/ch/word_1.jpg -[2022/06/12 13:42:08] ppocr INFO: infer_mode : False -[2022/06/12 13:42:08] ppocr INFO: log_smooth_window : 50 -[2022/06/12 13:42:08] ppocr INFO: max_text_length : 25 -[2022/06/12 13:42:08] ppocr INFO: pretrained_model : None -[2022/06/12 13:42:08] ppocr INFO: print_batch_step : 50 -[2022/06/12 13:42:08] ppocr INFO: save_epoch_step : 3 -[2022/06/12 13:42:08] ppocr INFO: save_inference_dir : None -[2022/06/12 13:42:08] ppocr INFO: save_model_dir : ./output/rec/rec_r32_gaspin_bilstm_att/ -[2022/06/12 13:42:08] ppocr INFO: save_res_path : ./output/rec/predicts_r32_gaspin_bilstm_att.txt -[2022/06/12 13:42:08] ppocr INFO: use_gpu : True -[2022/06/12 13:42:08] ppocr INFO: use_space_char : False -[2022/06/12 13:42:08] ppocr INFO: use_visualdl : False -[2022/06/12 13:42:08] ppocr INFO: Loss : -[2022/06/12 13:42:08] ppocr INFO: ignore_index : 0 -[2022/06/12 13:42:08] ppocr INFO: name : SPINAttentionLoss -[2022/06/12 13:42:08] ppocr INFO: Metric : -[2022/06/12 13:42:08] ppocr INFO: is_filter : True -[2022/06/12 13:42:08] ppocr INFO: main_indicator : acc -[2022/06/12 13:42:08] ppocr INFO: name : RecMetric -[2022/06/12 13:42:08] ppocr INFO: Optimizer : -[2022/06/12 13:42:08] ppocr INFO: beta1 : 0.9 -[2022/06/12 13:42:08] ppocr INFO: beta2 : 0.999 -[2022/06/12 13:42:08] ppocr INFO: clip_norm : 5 -[2022/06/12 13:42:08] ppocr INFO: lr : -[2022/06/12 13:42:08] ppocr INFO: decay_epochs : [3, 4, 5] -[2022/06/12 13:42:08] ppocr INFO: name : Piecewise -[2022/06/12 13:42:08] ppocr INFO: values : [0.001, 0.0003, 9e-05, 2.7e-05] -[2022/06/12 13:42:08] ppocr INFO: name : AdamW -[2022/06/12 13:42:08] ppocr INFO: PostProcess : -[2022/06/12 13:42:08] ppocr INFO: character_dict_path : ./ppocr/utils/dict/spin_dict.txt -[2022/06/12 13:42:08] ppocr INFO: name : SPINAttnLabelDecode -[2022/06/12 13:42:08] ppocr INFO: use_space_char : False -[2022/06/12 13:42:08] ppocr INFO: Train : -[2022/06/12 13:42:08] ppocr INFO: dataset : -[2022/06/12 13:42:08] ppocr INFO: data_dir : ./train_data/ic15_data/ -[2022/06/12 13:42:08] ppocr INFO: label_file_list : ['./train_data/ic15_data/rec_gt_train.txt'] -[2022/06/12 13:42:08] ppocr INFO: name : SimpleDataSet -[2022/06/12 13:42:08] ppocr INFO: transforms : -[2022/06/12 13:42:08] ppocr INFO: NRTRDecodeImage : -[2022/06/12 13:42:08] ppocr INFO: channel_first : False -[2022/06/12 13:42:08] ppocr INFO: img_mode : BGR -[2022/06/12 13:42:08] ppocr INFO: SPINAttnLabelEncode : None -[2022/06/12 13:42:08] ppocr INFO: SPINRecResizeImg : -[2022/06/12 13:42:08] ppocr INFO: image_shape : [100, 32] -[2022/06/12 13:42:08] ppocr INFO: interpolation : 2 -[2022/06/12 13:42:08] ppocr INFO: mean : [127.5] -[2022/06/12 13:42:08] ppocr INFO: std : [127.5] -[2022/06/12 13:42:08] ppocr INFO: KeepKeys : -[2022/06/12 13:42:08] ppocr INFO: keep_keys : ['image', 'label', 'length'] -[2022/06/12 13:42:08] ppocr INFO: loader : -[2022/06/12 13:42:08] ppocr INFO: batch_size_per_card : 8 -[2022/06/12 13:42:08] ppocr INFO: drop_last : True -[2022/06/12 13:42:08] ppocr INFO: num_workers : 4 -[2022/06/12 13:42:08] ppocr INFO: shuffle : True -[2022/06/12 13:42:08] ppocr INFO: profiler_options : None -[2022/06/12 13:42:08] ppocr INFO: train with paddle 2.2.2 and device CUDAPlace(0) -[2022/06/12 13:42:08] ppocr INFO: Initialize indexs of datasets:['./train_data/ic15_data/rec_gt_test.txt'] -W0612 13:42:08.814790 17600 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.5, Driver API Version: 11.1, Runtime API Version: 10.2 -W0612 13:42:08.832805 17600 device_context.cc:465] device: 0, cuDNN Version: 7.6. -[2022/06/12 13:42:12] ppocr INFO: resume from ./inference/rec_r32_gaspin_bilstm_att/best_accuracy -[2022/06/12 13:42:12] ppocr INFO: metric in ckpt *************** -[2022/06/12 13:42:12] ppocr INFO: acc:0.90589541082154 -[2022/06/12 13:42:12] ppocr INFO: norm_edit_dis:0.9627389225663741 -[2022/06/12 13:42:12] ppocr INFO: fps:1802.1068940938283 -[2022/06/12 13:42:12] ppocr INFO: best_epoch:6 -[2022/06/12 13:42:12] ppocr INFO: start_epoch:7 - eval model:: 0%| | 0/2 [00:00 Date: Sun, 10 Jul 2022 09:37:26 +0800 Subject: [PATCH 04/11] merge export_model --- .../rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml | 1 - tools/export_model.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml index e53396a03..3999ecda8 100644 --- a/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml +++ b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml @@ -62,7 +62,6 @@ Loss: PostProcess: name: SPINAttnLabelDecode - character_dict_path: ./ppocr/utils/dict/spin_dict.txt use_space_char: False diff --git a/tools/export_model.py b/tools/export_model.py index 3ea0228f8..15c4e35b3 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -84,7 +84,7 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None): "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training" ) infer_shape[-1] = 100 - if arch_config["algorithm"] == "NRTR": + if arch_config["algorithm"] == "NRTR" or arch_config["algorithm"] == "SPIN": infer_shape = [1, 32, 100] elif arch_config["model_type"] == "table": infer_shape = [3, 488, 488] From 7e81a9e62d8ebe1e8213575b6cb0602f54fff424 Mon Sep 17 00:00:00 2001 From: smilelite Date: Sun, 10 Jul 2022 09:45:52 +0800 Subject: [PATCH 05/11] modified annotation --- ppocr/modeling/backbones/rec_resnet_32.py | 28 +++---------------- ppocr/modeling/heads/rec_spin_att_head.py | 1 - .../modeling/transforms/gaspin_transformer.py | 1 - 3 files changed, 4 insertions(+), 26 deletions(-) diff --git a/ppocr/modeling/backbones/rec_resnet_32.py b/ppocr/modeling/backbones/rec_resnet_32.py index 0b072dc5f..cbd19251a 100644 --- a/ppocr/modeling/backbones/rec_resnet_32.py +++ b/ppocr/modeling/backbones/rec_resnet_32.py @@ -1,4 +1,4 @@ -# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,11 +20,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import paddle -from paddle import ParamAttr import paddle.nn as nn -import paddle.nn.functional as F -import numpy as np __all__ = ["ResNet32"] @@ -51,10 +47,10 @@ class ResNet32(nn.Layer): def forward(self, inputs): """ Args: - inputs (torch.Tensor): input feature + inputs: input feature Returns: - torch.Tensor: output feature + output feature """ return self.ConvNet(inputs) @@ -92,7 +88,7 @@ class BasicBlock(nn.Layer): out_planes (int): channels of the middle feature stride (int): stride of the convolution Returns: - nn.Module: Conv2D with kernel = 3 + nn.Layer: Conv2D with kernel = 3 """ @@ -102,14 +98,6 @@ class BasicBlock(nn.Layer): bias_attr=False) def forward(self, x): - """ - Args: - x (torch.Tensor): input feature - - Returns: - torch.Tensor: output feature of the BasicBlock - - """ residual = x out = self.conv1(x) @@ -246,14 +234,6 @@ class ResNet(nn.Layer): return nn.Sequential(*layers) def forward(self, x): - """ - Args: - x (torch.Tensor): input feature - - Returns: - torch.Tensor: output feature of the Resnet - - """ x = self.conv0_1(x) x = self.bn0_1(x) x = self.relu(x) diff --git a/ppocr/modeling/heads/rec_spin_att_head.py b/ppocr/modeling/heads/rec_spin_att_head.py index 07a58b083..8f92d1ef4 100644 --- a/ppocr/modeling/heads/rec_spin_att_head.py +++ b/ppocr/modeling/heads/rec_spin_att_head.py @@ -19,7 +19,6 @@ from __future__ import print_function import paddle import paddle.nn as nn import paddle.nn.functional as F -import numpy as np class SPINAttentionHead(nn.Layer): diff --git a/ppocr/modeling/transforms/gaspin_transformer.py b/ppocr/modeling/transforms/gaspin_transformer.py index 9440e360d..f4719eb21 100644 --- a/ppocr/modeling/transforms/gaspin_transformer.py +++ b/ppocr/modeling/transforms/gaspin_transformer.py @@ -21,7 +21,6 @@ import paddle from paddle import nn, ParamAttr from paddle.nn import functional as F import numpy as np -import itertools import functools from .tps import GridGenerator From f56a7e9c45d87ce911d6ac43fcd41f2718e62292 Mon Sep 17 00:00:00 2001 From: smilelite Date: Sun, 10 Jul 2022 11:25:42 +0800 Subject: [PATCH 06/11] merge overview.md --- doc/doc_ch/algorithm_overview.md | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index a7754b2f0..fbd3ce9eb 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -88,12 +88,9 @@ |SAR|Resnet31| 87.20% | rec_r31_sar | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) | |SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) | |SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) | -<<<<<<< HEAD -|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon | -======= |ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) | |ABINet|Resnet45| 90.75% | rec_r45_abinet | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) | ->>>>>>> b6982b7fc720a1c9346838978b9228025f26c42b +|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon | From 46e3442e2e38f4449d06a77010634d43fab16331 Mon Sep 17 00:00:00 2001 From: smilelite Date: Sun, 10 Jul 2022 11:47:25 +0800 Subject: [PATCH 07/11] add spin --- configs/rec/rec_r32_gaspin_bilstm_att.yml | 4 ++-- ppocr/data/imaug/rec_img_aug.py | 4 +--- .../rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml | 4 ++-- tools/export_model.py | 2 +- tools/infer/predict_rec.py | 3 ++- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/configs/rec/rec_r32_gaspin_bilstm_att.yml b/configs/rec/rec_r32_gaspin_bilstm_att.yml index e8235415c..f7c1b813f 100644 --- a/configs/rec/rec_r32_gaspin_bilstm_att.yml +++ b/configs/rec/rec_r32_gaspin_bilstm_att.yml @@ -75,7 +75,7 @@ Train: data_dir: ./train_data/ic15_data/ label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"] transforms: - - NRTRDecodeImage: # load image + - DecodeImage: # load image img_mode: BGR channel_first: False - SPINAttnLabelEncode: # Class handling label @@ -98,7 +98,7 @@ Eval: data_dir: ./train_data/ic15_data label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"] transforms: - - NRTRDecodeImage: # load image + - DecodeImage: # load image img_mode: BGR channel_first: False - SPINAttnLabelEncode: # Class handling label diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 3f002173d..c5d8a3b2f 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -274,6 +274,7 @@ class SPINRecResizeImg(object): def __call__(self, data): img = data['image'] + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # different interpolation type corresponding the OpenCV if self.interpolation == 0: interpolation = cv2.INTER_NEAREST @@ -294,12 +295,9 @@ class SPINRecResizeImg(object): img = np.expand_dims(img, -1) img = img.transpose((2, 0, 1)) # normalize the image - to_rgb = False img = img.copy().astype(np.float32) mean = np.float64(self.mean.reshape(1, -1)) stdinv = 1 / np.float64(self.std.reshape(1, -1)) - if to_rgb: - cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img -= mean img *= stdinv data['image'] = img diff --git a/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml index 3999ecda8..a08efe579 100644 --- a/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml +++ b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml @@ -76,7 +76,7 @@ Train: data_dir: ./train_data/ic15_data/ label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"] transforms: - - NRTRDecodeImage: # load image + - DecodeImage: # load image img_mode: BGR channel_first: False - SPINAttnLabelEncode: # Class handling label @@ -99,7 +99,7 @@ Eval: data_dir: ./train_data/ic15_data label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"] transforms: - - NRTRDecodeImage: # load image + - DecodeImage: # load image img_mode: BGR channel_first: False - SPINAttnLabelEncode: # Class handling label diff --git a/tools/export_model.py b/tools/export_model.py index afecbff8c..4855c53a9 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -91,7 +91,7 @@ def export_single_model(model, ] # print([None, 3, 32, 128]) model = to_static(model, input_spec=other_shape) - elif arch_config["algorithm"] == "NRTR": + elif arch_config["algorithm"] == "NRTR" or arch_config["algorithm"] == "SPIN": other_shape = [ paddle.static.InputSpec( shape=[None, 1, 32, 100], dtype="float32"), diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 7f9aea09d..5a8cb84f7 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -81,7 +81,6 @@ class TextRecognizer(object): "character_dict_path": args.rec_char_dict_path, "use_space_char": args.use_space_char } - elif self.rec_algorithm == "SPIN": postprocess_params = { 'name': 'SPINAttnLabelDecode', @@ -362,6 +361,8 @@ class TextRecognizer(object): norm_img_batch.append(norm_img) elif self.rec_algorithm == 'SPIN': norm_img = self.resize_norm_img_spin(img_list[indices[ino]]) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) elif self.rec_algorithm == "ABINet": norm_img = self.resize_norm_img_abinet( img_list[indices[ino]], self.rec_image_shape) From 4a3b874a366037caf878e02cf8f9a4da7796a620 Mon Sep 17 00:00:00 2001 From: smilelite Date: Sun, 10 Jul 2022 11:49:48 +0800 Subject: [PATCH 08/11] overview.md --- doc/doc_en/algorithm_overview_en.md | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index 96dae98f9..a579d2447 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -66,12 +66,9 @@ Supported text recognition algorithms (Click the link to get the tutorial): - [x] [SAR](./algorithm_rec_sar_en.md) - [x] [SEED](./algorithm_rec_seed_en.md) - [x] [SVTR](./algorithm_rec_svtr_en.md) -<<<<<<< HEAD -- [x] [SPIN](./algorithm_rec_spin_en.md) -======= - [x] [ViTSTR](./algorithm_rec_vitstr_en.md) - [x] [ABINet](./algorithm_rec_abinet_en.md) ->>>>>>> b6982b7fc720a1c9346838978b9228025f26c42b +- [x] [SPIN](./algorithm_rec_spin_en.md) Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow: From cb370419ec28cb2fbd148b125272946ecc4c88f5 Mon Sep 17 00:00:00 2001 From: smilelite Date: Mon, 11 Jul 2022 23:59:45 +0800 Subject: [PATCH 09/11] modified pr --- ppocr/data/imaug/label_ops.py | 2 +- ppocr/modeling/heads/rec_spin_att_head.py | 5 + ppocr/postprocess/rec_postprocess.py | 147 ++++++++++++---------- tools/export_model.py | 2 +- 4 files changed, 88 insertions(+), 68 deletions(-) diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 94e0dd226..36bc29793 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -1217,7 +1217,7 @@ class ABINetLabelEncode(BaseRecLabelEncode): dict_character = [''] + dict_character return dict_character -class SPINAttnLabelEncode(BaseRecLabelEncode): +class SPINAttnLabelEncode(AttnLabelEncode): """ Convert between text-label and text-index """ def __init__(self, diff --git a/ppocr/modeling/heads/rec_spin_att_head.py b/ppocr/modeling/heads/rec_spin_att_head.py index 8f92d1ef4..86e35e433 100644 --- a/ppocr/modeling/heads/rec_spin_att_head.py +++ b/ppocr/modeling/heads/rec_spin_att_head.py @@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +This code is refer from: +https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/sequence_heads/att_head.py +""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 26ea71fd0..6f64899b7 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -669,7 +669,86 @@ class ABINetLabelDecode(NRTRLabelDecode): return dict_character -class SPINAttnLabelDecode(BaseRecLabelDecode): +# class SPINAttnLabelDecode(BaseRecLabelDecode): +# """ Convert between text-label and text-index """ + +# def __init__(self, character_dict_path=None, use_space_char=False, +# **kwargs): +# super(SPINAttnLabelDecode, self).__init__(character_dict_path, +# use_space_char) + +# def add_special_char(self, dict_character): +# self.beg_str = "sos" +# self.end_str = "eos" +# dict_character = dict_character +# dict_character = [self.beg_str] + [self.end_str] + dict_character +# return dict_character + +# def decode(self, text_index, text_prob=None, is_remove_duplicate=False): +# """ convert text-index into text-label. """ +# result_list = [] +# ignored_tokens = self.get_ignored_tokens() +# [beg_idx, end_idx] = self.get_ignored_tokens() +# batch_size = len(text_index) +# for batch_idx in range(batch_size): +# char_list = [] +# conf_list = [] +# for idx in range(len(text_index[batch_idx])): +# if text_index[batch_idx][idx] == int(beg_idx): +# continue +# if int(text_index[batch_idx][idx]) == int(end_idx): +# break +# if is_remove_duplicate: +# # only for predict +# if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ +# batch_idx][idx]: +# continue +# char_list.append(self.character[int(text_index[batch_idx][ +# idx])]) +# if text_prob is not None: +# conf_list.append(text_prob[batch_idx][idx]) +# else: +# conf_list.append(1) +# text = ''.join(char_list) +# result_list.append((text.lower(), np.mean(conf_list).tolist())) +# return result_list + +# def __call__(self, preds, label=None, *args, **kwargs): +# """ +# text = self.decode(text) +# if label is None: +# return text +# else: +# label = self.decode(label, is_remove_duplicate=False) +# return text, label +# """ +# if isinstance(preds, paddle.Tensor): +# preds = preds.numpy() + +# preds_idx = preds.argmax(axis=2) +# preds_prob = preds.max(axis=2) +# text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) +# if label is None: +# return text +# label = self.decode(label, is_remove_duplicate=False) +# return text, label + +# def get_ignored_tokens(self): +# beg_idx = self.get_beg_end_flag_idx("beg") +# end_idx = self.get_beg_end_flag_idx("end") +# return [beg_idx, end_idx] + +# def get_beg_end_flag_idx(self, beg_or_end): +# if beg_or_end == "beg": +# idx = np.array(self.dict[self.beg_str]) +# elif beg_or_end == "end": +# idx = np.array(self.dict[self.end_str]) +# else: +# assert False, "unsupport type %s in get_beg_end_flag_idx" \ +# % beg_or_end +# return idx + +class SPINAttnLabelDecode(AttnLabelDecode): """ Convert between text-label and text-index """ def __init__(self, character_dict_path=None, use_space_char=False, @@ -682,68 +761,4 @@ class SPINAttnLabelDecode(BaseRecLabelDecode): self.end_str = "eos" dict_character = dict_character dict_character = [self.beg_str] + [self.end_str] + dict_character - return dict_character - - def decode(self, text_index, text_prob=None, is_remove_duplicate=False): - """ convert text-index into text-label. """ - result_list = [] - ignored_tokens = self.get_ignored_tokens() - [beg_idx, end_idx] = self.get_ignored_tokens() - batch_size = len(text_index) - for batch_idx in range(batch_size): - char_list = [] - conf_list = [] - for idx in range(len(text_index[batch_idx])): - if text_index[batch_idx][idx] == int(beg_idx): - continue - if int(text_index[batch_idx][idx]) == int(end_idx): - break - if is_remove_duplicate: - # only for predict - if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ - batch_idx][idx]: - continue - char_list.append(self.character[int(text_index[batch_idx][ - idx])]) - if text_prob is not None: - conf_list.append(text_prob[batch_idx][idx]) - else: - conf_list.append(1) - text = ''.join(char_list) - result_list.append((text.lower(), np.mean(conf_list).tolist())) - return result_list - - def __call__(self, preds, label=None, *args, **kwargs): - """ - text = self.decode(text) - if label is None: - return text - else: - label = self.decode(label, is_remove_duplicate=False) - return text, label - """ - if isinstance(preds, paddle.Tensor): - preds = preds.numpy() - - preds_idx = preds.argmax(axis=2) - preds_prob = preds.max(axis=2) - text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) - if label is None: - return text - label = self.decode(label, is_remove_duplicate=False) - return text, label - - def get_ignored_tokens(self): - beg_idx = self.get_beg_end_flag_idx("beg") - end_idx = self.get_beg_end_flag_idx("end") - return [beg_idx, end_idx] - - def get_beg_end_flag_idx(self, beg_or_end): - if beg_or_end == "beg": - idx = np.array(self.dict[self.beg_str]) - elif beg_or_end == "end": - idx = np.array(self.dict[self.end_str]) - else: - assert False, "unsupport type %s in get_beg_end_flag_idx" \ - % beg_or_end - return idx \ No newline at end of file + return dict_character \ No newline at end of file diff --git a/tools/export_model.py b/tools/export_model.py index 4855c53a9..69ac904c6 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -91,7 +91,7 @@ def export_single_model(model, ] # print([None, 3, 32, 128]) model = to_static(model, input_spec=other_shape) - elif arch_config["algorithm"] == "NRTR" or arch_config["algorithm"] == "SPIN": + elif arch_config["algorithm"] in ["NRTR", "SPIN"]: other_shape = [ paddle.static.InputSpec( shape=[None, 1, 32, 100], dtype="float32"), From f614274672a0c0874c4749f51be0520babf170ca Mon Sep 17 00:00:00 2001 From: smilelite Date: Tue, 12 Jul 2022 22:15:00 +0800 Subject: [PATCH 10/11] modified label_ops --- ppocr/data/imaug/label_ops.py | 17 +----- ppocr/postprocess/rec_postprocess.py | 80 ---------------------------- 2 files changed, 1 insertion(+), 96 deletions(-) diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 36bc29793..775ceec83 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -1248,19 +1248,4 @@ class SPINAttnLabelEncode(AttnLabelEncode): padded_text[:len(target)] = target data['label'] = np.array(padded_text) - return data - - def get_ignored_tokens(self): - beg_idx = self.get_beg_end_flag_idx("beg") - end_idx = self.get_beg_end_flag_idx("end") - return [beg_idx, end_idx] - - def get_beg_end_flag_idx(self, beg_or_end): - if beg_or_end == "beg": - idx = np.array(self.dict[self.beg_str]) - elif beg_or_end == "end": - idx = np.array(self.dict[self.end_str]) - else: - assert False, "Unsupport type %s in get_beg_end_flag_idx" \ - % beg_or_end - return \ No newline at end of file + return data \ No newline at end of file diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 6f64899b7..3e7c29d8d 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -668,86 +668,6 @@ class ABINetLabelDecode(NRTRLabelDecode): dict_character = [''] + dict_character return dict_character - -# class SPINAttnLabelDecode(BaseRecLabelDecode): -# """ Convert between text-label and text-index """ - -# def __init__(self, character_dict_path=None, use_space_char=False, -# **kwargs): -# super(SPINAttnLabelDecode, self).__init__(character_dict_path, -# use_space_char) - -# def add_special_char(self, dict_character): -# self.beg_str = "sos" -# self.end_str = "eos" -# dict_character = dict_character -# dict_character = [self.beg_str] + [self.end_str] + dict_character -# return dict_character - -# def decode(self, text_index, text_prob=None, is_remove_duplicate=False): -# """ convert text-index into text-label. """ -# result_list = [] -# ignored_tokens = self.get_ignored_tokens() -# [beg_idx, end_idx] = self.get_ignored_tokens() -# batch_size = len(text_index) -# for batch_idx in range(batch_size): -# char_list = [] -# conf_list = [] -# for idx in range(len(text_index[batch_idx])): -# if text_index[batch_idx][idx] == int(beg_idx): -# continue -# if int(text_index[batch_idx][idx]) == int(end_idx): -# break -# if is_remove_duplicate: -# # only for predict -# if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ -# batch_idx][idx]: -# continue -# char_list.append(self.character[int(text_index[batch_idx][ -# idx])]) -# if text_prob is not None: -# conf_list.append(text_prob[batch_idx][idx]) -# else: -# conf_list.append(1) -# text = ''.join(char_list) -# result_list.append((text.lower(), np.mean(conf_list).tolist())) -# return result_list - -# def __call__(self, preds, label=None, *args, **kwargs): -# """ -# text = self.decode(text) -# if label is None: -# return text -# else: -# label = self.decode(label, is_remove_duplicate=False) -# return text, label -# """ -# if isinstance(preds, paddle.Tensor): -# preds = preds.numpy() - -# preds_idx = preds.argmax(axis=2) -# preds_prob = preds.max(axis=2) -# text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) -# if label is None: -# return text -# label = self.decode(label, is_remove_duplicate=False) -# return text, label - -# def get_ignored_tokens(self): -# beg_idx = self.get_beg_end_flag_idx("beg") -# end_idx = self.get_beg_end_flag_idx("end") -# return [beg_idx, end_idx] - -# def get_beg_end_flag_idx(self, beg_or_end): -# if beg_or_end == "beg": -# idx = np.array(self.dict[self.beg_str]) -# elif beg_or_end == "end": -# idx = np.array(self.dict[self.end_str]) -# else: -# assert False, "unsupport type %s in get_beg_end_flag_idx" \ -# % beg_or_end -# return idx - class SPINAttnLabelDecode(AttnLabelDecode): """ Convert between text-label and text-index """ From 484bf2f7dcd73d708bdae4269d90f4cbd3ecf7cf Mon Sep 17 00:00:00 2001 From: smilelite Date: Thu, 14 Jul 2022 22:26:10 +0800 Subject: [PATCH 11/11] modified SPINLabelEncode SPINLabelDecode --- configs/rec/rec_r32_gaspin_bilstm_att.yml | 6 +++--- ppocr/data/imaug/label_ops.py | 4 ++-- ppocr/postprocess/__init__.py | 4 ++-- ppocr/postprocess/rec_postprocess.py | 4 ++-- .../rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml | 6 +++--- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/configs/rec/rec_r32_gaspin_bilstm_att.yml b/configs/rec/rec_r32_gaspin_bilstm_att.yml index f7c1b813f..aea71388f 100644 --- a/configs/rec/rec_r32_gaspin_bilstm_att.yml +++ b/configs/rec/rec_r32_gaspin_bilstm_att.yml @@ -60,7 +60,7 @@ Loss: ignore_index: 0 PostProcess: - name: SPINAttnLabelDecode + name: SPINLabelDecode use_space_char: False @@ -78,7 +78,7 @@ Train: - DecodeImage: # load image img_mode: BGR channel_first: False - - SPINAttnLabelEncode: # Class handling label + - SPINLabelEncode: # Class handling label - SPINRecResizeImg: image_shape: [100, 32] interpolation : 2 @@ -101,7 +101,7 @@ Eval: - DecodeImage: # load image img_mode: BGR channel_first: False - - SPINAttnLabelEncode: # Class handling label + - SPINLabelEncode: # Class handling label - SPINRecResizeImg: image_shape: [100, 32] interpolation : 2 diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 775ceec83..97539faf2 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -1217,7 +1217,7 @@ class ABINetLabelEncode(BaseRecLabelEncode): dict_character = [''] + dict_character return dict_character -class SPINAttnLabelEncode(AttnLabelEncode): +class SPINLabelEncode(AttnLabelEncode): """ Convert between text-label and text-index """ def __init__(self, @@ -1226,7 +1226,7 @@ class SPINAttnLabelEncode(AttnLabelEncode): use_space_char=False, lower=True, **kwargs): - super(SPINAttnLabelEncode, self).__init__( + super(SPINLabelEncode, self).__init__( max_text_length, character_dict_path, use_space_char) self.lower = lower def add_special_char(self, dict_character): diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 204ae0bdf..eeebc5803 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -28,7 +28,7 @@ from .fce_postprocess import FCEPostProcess from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \ DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \ SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \ - SPINAttnLabelDecode + SPINLabelDecode from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess @@ -45,7 +45,7 @@ def build_post_process(config, global_config=None): 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode', 'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode', - 'TableMasterLabelDecode', 'SPINAttnLabelDecode' + 'TableMasterLabelDecode', 'SPINLabelDecode' ] if config['name'] == 'PSEPostProcess': diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 3e7c29d8d..3fe29aabe 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -668,12 +668,12 @@ class ABINetLabelDecode(NRTRLabelDecode): dict_character = [''] + dict_character return dict_character -class SPINAttnLabelDecode(AttnLabelDecode): +class SPINLabelDecode(AttnLabelDecode): """ Convert between text-label and text-index """ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs): - super(SPINAttnLabelDecode, self).__init__(character_dict_path, + super(SPINLabelDecode, self).__init__(character_dict_path, use_space_char) def add_special_char(self, dict_character): diff --git a/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml index a08efe579..d0cb20481 100644 --- a/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml +++ b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml @@ -61,7 +61,7 @@ Loss: ignore_index: 0 PostProcess: - name: SPINAttnLabelDecode + name: SPINLabelDecode use_space_char: False @@ -79,7 +79,7 @@ Train: - DecodeImage: # load image img_mode: BGR channel_first: False - - SPINAttnLabelEncode: # Class handling label + - SPINLabelEncode: # Class handling label - SPINRecResizeImg: image_shape: [100, 32] interpolation : 2 @@ -102,7 +102,7 @@ Eval: - DecodeImage: # load image img_mode: BGR channel_first: False - - SPINAttnLabelEncode: # Class handling label + - SPINLabelEncode: # Class handling label - SPINRecResizeImg: image_shape: [100, 32] interpolation : 2