From bbca1e0d66298fd21abcb953a4c73379b8e04996 Mon Sep 17 00:00:00 2001
From: smilelite <xuyang508@163.com>
Date: Sun, 12 Jun 2022 13:53:29 +0800
Subject: [PATCH] add pr

---
 .gitignore                                    |   2 +-
 configs/rec/rec_r32_gaspin_bilstm_att.yml     | 117 +++++++
 doc/doc_ch/algorithm_overview.md              |   2 +
 doc/doc_ch/algorithm_rec_spin.md              | 112 +++++++
 doc/doc_en/algorithm_overview_en.md           |   2 +
 doc/doc_en/algorithm_rec_spin_en.md           | 112 +++++++
 log/workerlog.0                               | 131 ++++++++
 ppocr/data/imaug/__init__.py                  |   3 +-
 ppocr/data/imaug/label_ops.py                 |  49 +++
 ppocr/data/imaug/rec_img_aug.py               |  45 +++
 ppocr/losses/__init__.py                      |   4 +-
 ppocr/losses/rec_spin_att_loss.py             |  41 +++
 ppocr/modeling/backbones/__init__.py          |   3 +-
 ppocr/modeling/backbones/rec_resnet_32.py     | 289 ++++++++++++++++++
 ppocr/modeling/heads/__init__.py              |   3 +-
 ppocr/modeling/heads/rec_spin_att_head.py     | 203 ++++++++++++
 ppocr/modeling/necks/rnn.py                   |  67 +++-
 ppocr/modeling/transforms/__init__.py         |   4 +-
 .../modeling/transforms/gaspin_transformer.py | 286 +++++++++++++++++
 ppocr/postprocess/__init__.py                 |   4 +-
 ppocr/postprocess/rec_postprocess.py          |  79 +++++
 ppocr/utils/dict/spin_dict.txt                |  68 +++++
 .../rec_r32_gaspin_bilstm_att.yml             | 118 +++++++
 .../train_infer_python.txt                    |  53 ++++
 tools/export_model.py                         |   6 +
 tools/infer/predict_rec.py                    |  26 ++
 tools/program.py                              |   5 +-
 27 files changed, 1823 insertions(+), 11 deletions(-)
 create mode 100644 configs/rec/rec_r32_gaspin_bilstm_att.yml
 create mode 100644 doc/doc_ch/algorithm_rec_spin.md
 create mode 100644 doc/doc_en/algorithm_rec_spin_en.md
 create mode 100644 log/workerlog.0
 create mode 100644 ppocr/losses/rec_spin_att_loss.py
 create mode 100644 ppocr/modeling/backbones/rec_resnet_32.py
 create mode 100644 ppocr/modeling/heads/rec_spin_att_head.py
 create mode 100644 ppocr/modeling/transforms/gaspin_transformer.py
 create mode 100644 ppocr/utils/dict/spin_dict.txt
 create mode 100644 test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml
 create mode 100644 test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt

diff --git a/.gitignore b/.gitignore
index caf886a2b..34f0e0cc9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,7 +10,7 @@ __pycache__/
 inference/
 inference_results/
 output/
-
+train_data/
 *.DS_Store
 *.vs
 *.user
diff --git a/configs/rec/rec_r32_gaspin_bilstm_att.yml b/configs/rec/rec_r32_gaspin_bilstm_att.yml
new file mode 100644
index 000000000..236a17c43
--- /dev/null
+++ b/configs/rec/rec_r32_gaspin_bilstm_att.yml
@@ -0,0 +1,117 @@
+Global:
+  use_gpu: True
+  epoch_num: 6
+  log_smooth_window: 50
+  print_batch_step: 50
+  save_model_dir: ./output/rec/rec_r32_gaspin_bilstm_att/
+  save_epoch_step: 3
+  # evaluation is run every 2000 iterations after the 4000th iteration
+  eval_batch_step: [0, 2000]
+  cal_metric_during_train: True
+  pretrained_model:
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/imgs_words/ch/word_1.jpg
+  # for data or label process
+  character_dict_path: ./ppocr/utils/dict/spin_dict.txt
+  max_text_length: 25
+  infer_mode: False
+  use_space_char: False
+  save_res_path: ./output/rec/predicts_r32_gaspin_bilstm_att.txt
+
+
+Optimizer:
+  name: AdamW
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Piecewise
+    decay_epochs: [3, 4, 5]
+    values: [0.001, 0.0003, 0.00009, 0.000027] 
+  clip_norm: 5
+
+Architecture:
+  model_type: rec
+  algorithm: SPIN
+  in_channels: 1
+  Transform:
+    name: GA_SPIN
+    offsets: True
+    default_type: 6
+    loc_lr: 0.1
+    stn: True
+  Backbone:
+    name: ResNet32
+    out_channels: 512
+  Neck:
+    name: SequenceEncoder
+    encoder_type: cascadernn 
+    hidden_size: 256
+    out_channels: [256, 512]
+    with_linear: True
+  Head:
+    name: SPINAttentionHead  
+    hidden_size: 256
+    
+
+Loss:
+  name: SPINAttentionLoss
+  ignore_index: 0
+
+PostProcess:
+  name: SPINAttnLabelDecode
+  character_dict_path: ./ppocr/utils/dict/spin_dict.txt
+  use_space_char: False
+
+
+Metric:
+  name: RecMetric
+  main_indicator: acc
+  is_filter: True
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/ic15_data/
+    label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
+    transforms:
+      - NRTRDecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - SPINAttnLabelEncode: # Class handling label
+      - SPINRecResizeImg:
+          image_shape: [100, 32]
+          interpolation : 2
+          mean: [127.5]
+          std: [127.5]
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+  loader:
+    shuffle: True
+    batch_size_per_card: 8
+    drop_last: True
+    num_workers: 4
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/ic15_data
+    label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
+    transforms:
+      - NRTRDecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - SPINAttnLabelEncode: # Class handling label
+      - SPINRecResizeImg:
+          image_shape: [100, 32]
+          interpolation : 2
+          mean: [127.5]
+          std: [127.5]
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 8
+    num_workers: 2
diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index 6227a2149..ef95317ac 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -66,6 +66,7 @@
 - [x]  [SAR](./algorithm_rec_sar.md)
 - [x]  [SEED](./algorithm_rec_seed.md)
 - [x]  [SVTR](./algorithm_rec_svtr.md)
+- [x]  [SPIN](./algorithm_rec_spin.md)
 
 参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
 
@@ -84,6 +85,7 @@
 |SAR|Resnet31| 87.20% | rec_r31_sar | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
 |SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
 |SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
+|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
 
 
 <a name="2"></a>
diff --git a/doc/doc_ch/algorithm_rec_spin.md b/doc/doc_ch/algorithm_rec_spin.md
new file mode 100644
index 000000000..c996992d2
--- /dev/null
+++ b/doc/doc_ch/algorithm_rec_spin.md
@@ -0,0 +1,112 @@
+# SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition
+
+- [1. 算法简介](#1)
+- [2. 环境配置](#2)
+- [3. 模型训练、评估、预测](#3)
+    - [3.1 训练](#3-1)
+    - [3.2 评估](#3-2)
+    - [3.3 预测](#3-3)
+- [4. 推理部署](#4)
+    - [4.1 Python推理](#4-1)
+    - [4.2 C++推理](#4-2)
+    - [4.3 Serving服务化部署](#4-3)
+    - [4.4 更多推理部署](#4-4)
+- [5. FAQ](#5)
+
+<a name="1"></a>
+## 1. 算法简介
+
+论文信息:
+> [SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition](https://arxiv.org/abs/2005.13117)
+> Chengwei Zhang, Yunlu Xu, Zhanzhan Cheng, Shiliang Pu, Yi Niu, Fei Wu, Futai Zou
+> AAAI, 2020
+
+SPIN收录于AAAI2020。主要用于OCR识别任务。在任意形状文本识别中,矫正网络是一种较为常见的前置处理模块,但诸如RARE\ASTER\ESIR等只考虑了空间变换,并没有考虑色度变换。本文提出了一种结构Structure-Preserving Inner Offset Network (SPIN),可以在色彩空间上进行变换。该模块是可微分的,可以加入到任意识别器中。
+使用MJSynth和SynthText两个合成文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下:
+
+|模型|骨干网络|配置文件|Acc|下载链接|
+| --- | --- | --- | --- | --- |
+|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon|
+
+
+<a name="2"></a>
+## 2. 环境配置
+请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
+
+
+<a name="3"></a>
+## 3. 模型训练、评估、预测
+
+请参考[文本识别教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练不同的识别模型只需要**更换配置文件**即可。
+
+训练
+
+具体地,在完成数据准备后,便可以启动训练,训练命令如下:
+
+```
+#单卡训练(训练周期长,不建议)
+python3 tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
+
+#多卡训练,通过--gpus参数指定卡号
+python3 -m paddle.distributed.launch --gpus '0,1,2,3'  tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
+```
+
+评估
+
+```
+# GPU 评估, Global.pretrained_model 为待测权重
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+预测:
+
+```
+# 预测使用的配置文件必须与训练一致
+python3 tools/infer_rec.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
+```
+
+<a name="4"></a>
+## 4. 推理部署
+
+<a name="4-1"></a>
+### 4.1 Python推理
+首先将SPIN文本识别训练过程中保存的模型,转换成inference model。可以使用如下命令进行转换:
+
+```
+python3 tools/export_model.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy  Global.save_inference_dir=./inference/rec_r32_gaspin_bilstm_att
+```
+SPIN文本识别模型推理,可以执行如下命令:
+
+```
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r32_gaspin_bilstm_att/" --rec_image_shape="3, 32, 100" --rec_algorithm="SPIN" --rec_char_dict_path="/ppocr/utils/dict/spin_dict.txt" --use_space_char=Falsee
+```
+
+<a name="4-2"></a>
+### 4.2 C++推理
+
+由于C++预处理后处理还未支持SPIN,所以暂未支持
+
+<a name="4-3"></a>
+### 4.3 Serving服务化部署
+
+暂不支持
+
+<a name="4-4"></a>
+### 4.4 更多推理部署
+
+暂不支持
+
+<a name="5"></a>
+## 5. FAQ
+
+
+## 引用
+
+```bibtex
+@article{2020SPIN,
+  title={SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition},
+  author={Chengwei Zhang and Yunlu Xu and Zhanzhan Cheng and Shiliang Pu and Yi Niu and Fei Wu and Futai Zou},
+  journal={AAAI2020},
+  year={2020},
+}
+```
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index 383cbe39b..608584e01 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -65,6 +65,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
 - [x]  [SAR](./algorithm_rec_sar_en.md)
 - [x]  [SEED](./algorithm_rec_seed_en.md)
 - [x]  [SVTR](./algorithm_rec_svtr_en.md)
+- [x]  [SPIN](./algorithm_rec_spin_en.md)
 
 Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
 
@@ -83,6 +84,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
 |SAR|Resnet31| 87.20% | rec_r31_sar | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
 |SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
 |SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
+|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
 
 
 <a name="2"></a>
diff --git a/doc/doc_en/algorithm_rec_spin_en.md b/doc/doc_en/algorithm_rec_spin_en.md
new file mode 100644
index 000000000..43ab30ce7
--- /dev/null
+++ b/doc/doc_en/algorithm_rec_spin_en.md
@@ -0,0 +1,112 @@
+# SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition
+
+- [1. Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+    - [3.1 Training](#3-1)
+    - [3.2 Evaluation](#3-2)
+    - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+    - [4.1 Python Inference](#4-1)
+    - [4.2 C++ Inference](#4-2)
+    - [4.3 Serving](#4-3)
+    - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+<a name="1"></a>
+## 1. Introduction
+
+Paper:
+> [SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition](https://arxiv.org/abs/2005.13117)
+> Chengwei Zhang, Yunlu Xu, Zhanzhan Cheng, Shiliang Pu, Yi Niu, Fei Wu, Futai Zou
+> AAAI, 2020
+
+Using MJSynth and SynthText two text recognition datasets for training, and evaluating on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE datasets. The algorithm reproduction effect is as follows:
+
+|Model|Backbone|config|Acc|Download link|
+| --- | --- | --- | --- | --- |
+|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon|
+
+
+<a name="2"></a>
+## 2. Environment
+Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
+
+
+<a name="3"></a>
+## 3. Model Training / Evaluation / Prediction
+
+Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
+
+Training:
+
+Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
+
+```
+#Single GPU training (long training period, not recommended)
+python3 tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
+
+#Multi GPU training, specify the gpu number through the --gpus parameter
+python3 -m paddle.distributed.launch --gpus '0,1,2,3'  tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
+```
+
+Evaluation:
+
+```
+# GPU evaluation
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+Prediction:
+
+```
+# The configuration file used for prediction must match the training
+python3 tools/infer_rec.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
+```
+
+<a name="4"></a>
+## 4. Inference and Deployment
+
+<a name="4-1"></a>
+### 4.1 Python Inference
+First, the model saved during the SPIN text recognition training process is converted into an inference model. you can use the following command to convert:
+
+```
+python3 tools/export_model.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy  Global.save_inference_dir=./inference/rec_r32_gaspin_bilstm_att
+```
+
+For SPIN text recognition model inference, the following commands can be executed:
+
+```
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r32_gaspin_bilstm_att/" --rec_image_shape="3, 32, 100" --rec_algorithm="SPIN" --rec_char_dict_path="/ppocr/utils/dict/spin_dict.txt" --use_space_char=False
+```
+
+<a name="4-2"></a>
+### 4.2 C++ Inference
+
+Not supported
+
+<a name="4-3"></a>
+### 4.3 Serving
+
+Not supported
+
+<a name="4-4"></a>
+### 4.4 More
+
+Not supported
+
+<a name="5"></a>
+## 5. FAQ
+
+
+## Citation
+
+```bibtex
+@article{2020SPIN,
+  title={SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition},
+  author={Chengwei Zhang and Yunlu Xu and Zhanzhan Cheng and Shiliang Pu and Yi Niu and Fei Wu and Futai Zou},
+  journal={AAAI2020},
+  year={2020},
+}
+```
diff --git a/log/workerlog.0 b/log/workerlog.0
new file mode 100644
index 000000000..7983c87df
--- /dev/null
+++ b/log/workerlog.0
@@ -0,0 +1,131 @@
+D:\Projects\3rdparty\anaconda\envs\pd2\lib\site-packages\socks.py:58: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
+  from collections import Callable
+D:\Projects\3rdparty\anaconda\envs\pd2\lib\site-packages\win32\lib\pywintypes.py:2: DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses
+  import imp, sys, os
+D:\Projects\3rdparty\anaconda\envs\pd2\lib\site-packages\pkg_resources\_vendor\pyparsing.py:943: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
+  collections.MutableMapping.register(ParseResults)
+D:\Projects\3rdparty\anaconda\envs\pd2\lib\site-packages\pkg_resources\_vendor\pyparsing.py:3245: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
+  elif isinstance( exprs, collections.Iterable ):
+[2022/06/12 13:42:08] ppocr INFO: Architecture : 
+[2022/06/12 13:42:08] ppocr INFO:     Backbone : 
+[2022/06/12 13:42:08] ppocr INFO:         name : ResNet32
+[2022/06/12 13:42:08] ppocr INFO:         out_channels : 512
+[2022/06/12 13:42:08] ppocr INFO:     Head : 
+[2022/06/12 13:42:08] ppocr INFO:         hidden_size : 256
+[2022/06/12 13:42:08] ppocr INFO:         name : SPINAttentionHead
+[2022/06/12 13:42:08] ppocr INFO:     Neck : 
+[2022/06/12 13:42:08] ppocr INFO:         encoder_type : cascadernn
+[2022/06/12 13:42:08] ppocr INFO:         hidden_size : 256
+[2022/06/12 13:42:08] ppocr INFO:         name : SequenceEncoder
+[2022/06/12 13:42:08] ppocr INFO:         out_channels : [256, 512]
+[2022/06/12 13:42:08] ppocr INFO:         with_linear : True
+[2022/06/12 13:42:08] ppocr INFO:     Transform : 
+[2022/06/12 13:42:08] ppocr INFO:         default_type : 6
+[2022/06/12 13:42:08] ppocr INFO:         loc_lr : 0.1
+[2022/06/12 13:42:08] ppocr INFO:         name : GA_SPIN
+[2022/06/12 13:42:08] ppocr INFO:         offsets : True
+[2022/06/12 13:42:08] ppocr INFO:         stn : True
+[2022/06/12 13:42:08] ppocr INFO:     algorithm : SPIN
+[2022/06/12 13:42:08] ppocr INFO:     in_channels : 1
+[2022/06/12 13:42:08] ppocr INFO:     model_type : rec
+[2022/06/12 13:42:08] ppocr INFO: Eval : 
+[2022/06/12 13:42:08] ppocr INFO:     dataset : 
+[2022/06/12 13:42:08] ppocr INFO:         data_dir : ./train_data/ic15_data
+[2022/06/12 13:42:08] ppocr INFO:         label_file_list : ['./train_data/ic15_data/rec_gt_test.txt']
+[2022/06/12 13:42:08] ppocr INFO:         name : SimpleDataSet
+[2022/06/12 13:42:08] ppocr INFO:         transforms : 
+[2022/06/12 13:42:08] ppocr INFO:             NRTRDecodeImage : 
+[2022/06/12 13:42:08] ppocr INFO:                 channel_first : False
+[2022/06/12 13:42:08] ppocr INFO:                 img_mode : BGR
+[2022/06/12 13:42:08] ppocr INFO:             SPINAttnLabelEncode : None
+[2022/06/12 13:42:08] ppocr INFO:             SPINRecResizeImg : 
+[2022/06/12 13:42:08] ppocr INFO:                 image_shape : [100, 32]
+[2022/06/12 13:42:08] ppocr INFO:                 interpolation : 2
+[2022/06/12 13:42:08] ppocr INFO:                 mean : [127.5]
+[2022/06/12 13:42:08] ppocr INFO:                 std : [127.5]
+[2022/06/12 13:42:08] ppocr INFO:             KeepKeys : 
+[2022/06/12 13:42:08] ppocr INFO:                 keep_keys : ['image', 'label', 'length']
+[2022/06/12 13:42:08] ppocr INFO:     loader : 
+[2022/06/12 13:42:08] ppocr INFO:         batch_size_per_card : 8
+[2022/06/12 13:42:08] ppocr INFO:         drop_last : False
+[2022/06/12 13:42:08] ppocr INFO:         num_workers : 2
+[2022/06/12 13:42:08] ppocr INFO:         shuffle : False
+[2022/06/12 13:42:08] ppocr INFO: Global : 
+[2022/06/12 13:42:08] ppocr INFO:     cal_metric_during_train : True
+[2022/06/12 13:42:08] ppocr INFO:     character_dict_path : ./ppocr/utils/dict/spin_dict.txt
+[2022/06/12 13:42:08] ppocr INFO:     checkpoints : ./inference/rec_r32_gaspin_bilstm_att/best_accuracy
+[2022/06/12 13:42:08] ppocr INFO:     distributed : False
+[2022/06/12 13:42:08] ppocr INFO:     epoch_num : 6
+[2022/06/12 13:42:08] ppocr INFO:     eval_batch_step : [0, 2000]
+[2022/06/12 13:42:08] ppocr INFO:     infer_img : doc/imgs_words/ch/word_1.jpg
+[2022/06/12 13:42:08] ppocr INFO:     infer_mode : False
+[2022/06/12 13:42:08] ppocr INFO:     log_smooth_window : 50
+[2022/06/12 13:42:08] ppocr INFO:     max_text_length : 25
+[2022/06/12 13:42:08] ppocr INFO:     pretrained_model : None
+[2022/06/12 13:42:08] ppocr INFO:     print_batch_step : 50
+[2022/06/12 13:42:08] ppocr INFO:     save_epoch_step : 3
+[2022/06/12 13:42:08] ppocr INFO:     save_inference_dir : None
+[2022/06/12 13:42:08] ppocr INFO:     save_model_dir : ./output/rec/rec_r32_gaspin_bilstm_att/
+[2022/06/12 13:42:08] ppocr INFO:     save_res_path : ./output/rec/predicts_r32_gaspin_bilstm_att.txt
+[2022/06/12 13:42:08] ppocr INFO:     use_gpu : True
+[2022/06/12 13:42:08] ppocr INFO:     use_space_char : False
+[2022/06/12 13:42:08] ppocr INFO:     use_visualdl : False
+[2022/06/12 13:42:08] ppocr INFO: Loss : 
+[2022/06/12 13:42:08] ppocr INFO:     ignore_index : 0
+[2022/06/12 13:42:08] ppocr INFO:     name : SPINAttentionLoss
+[2022/06/12 13:42:08] ppocr INFO: Metric : 
+[2022/06/12 13:42:08] ppocr INFO:     is_filter : True
+[2022/06/12 13:42:08] ppocr INFO:     main_indicator : acc
+[2022/06/12 13:42:08] ppocr INFO:     name : RecMetric
+[2022/06/12 13:42:08] ppocr INFO: Optimizer : 
+[2022/06/12 13:42:08] ppocr INFO:     beta1 : 0.9
+[2022/06/12 13:42:08] ppocr INFO:     beta2 : 0.999
+[2022/06/12 13:42:08] ppocr INFO:     clip_norm : 5
+[2022/06/12 13:42:08] ppocr INFO:     lr : 
+[2022/06/12 13:42:08] ppocr INFO:         decay_epochs : [3, 4, 5]
+[2022/06/12 13:42:08] ppocr INFO:         name : Piecewise
+[2022/06/12 13:42:08] ppocr INFO:         values : [0.001, 0.0003, 9e-05, 2.7e-05]
+[2022/06/12 13:42:08] ppocr INFO:     name : AdamW
+[2022/06/12 13:42:08] ppocr INFO: PostProcess : 
+[2022/06/12 13:42:08] ppocr INFO:     character_dict_path : ./ppocr/utils/dict/spin_dict.txt
+[2022/06/12 13:42:08] ppocr INFO:     name : SPINAttnLabelDecode
+[2022/06/12 13:42:08] ppocr INFO:     use_space_char : False
+[2022/06/12 13:42:08] ppocr INFO: Train : 
+[2022/06/12 13:42:08] ppocr INFO:     dataset : 
+[2022/06/12 13:42:08] ppocr INFO:         data_dir : ./train_data/ic15_data/
+[2022/06/12 13:42:08] ppocr INFO:         label_file_list : ['./train_data/ic15_data/rec_gt_train.txt']
+[2022/06/12 13:42:08] ppocr INFO:         name : SimpleDataSet
+[2022/06/12 13:42:08] ppocr INFO:         transforms : 
+[2022/06/12 13:42:08] ppocr INFO:             NRTRDecodeImage : 
+[2022/06/12 13:42:08] ppocr INFO:                 channel_first : False
+[2022/06/12 13:42:08] ppocr INFO:                 img_mode : BGR
+[2022/06/12 13:42:08] ppocr INFO:             SPINAttnLabelEncode : None
+[2022/06/12 13:42:08] ppocr INFO:             SPINRecResizeImg : 
+[2022/06/12 13:42:08] ppocr INFO:                 image_shape : [100, 32]
+[2022/06/12 13:42:08] ppocr INFO:                 interpolation : 2
+[2022/06/12 13:42:08] ppocr INFO:                 mean : [127.5]
+[2022/06/12 13:42:08] ppocr INFO:                 std : [127.5]
+[2022/06/12 13:42:08] ppocr INFO:             KeepKeys : 
+[2022/06/12 13:42:08] ppocr INFO:                 keep_keys : ['image', 'label', 'length']
+[2022/06/12 13:42:08] ppocr INFO:     loader : 
+[2022/06/12 13:42:08] ppocr INFO:         batch_size_per_card : 8
+[2022/06/12 13:42:08] ppocr INFO:         drop_last : True
+[2022/06/12 13:42:08] ppocr INFO:         num_workers : 4
+[2022/06/12 13:42:08] ppocr INFO:         shuffle : True
+[2022/06/12 13:42:08] ppocr INFO: profiler_options : None
+[2022/06/12 13:42:08] ppocr INFO: train with paddle 2.2.2 and device CUDAPlace(0)
+[2022/06/12 13:42:08] ppocr INFO: Initialize indexs of datasets:['./train_data/ic15_data/rec_gt_test.txt']
+W0612 13:42:08.814790 17600 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.5, Driver API Version: 11.1, Runtime API Version: 10.2
+W0612 13:42:08.832805 17600 device_context.cc:465] device: 0, cuDNN Version: 7.6.
+[2022/06/12 13:42:12] ppocr INFO: resume from ./inference/rec_r32_gaspin_bilstm_att/best_accuracy
+[2022/06/12 13:42:12] ppocr INFO: metric in ckpt ***************
+[2022/06/12 13:42:12] ppocr INFO: acc:0.90589541082154
+[2022/06/12 13:42:12] ppocr INFO: norm_edit_dis:0.9627389225663741
+[2022/06/12 13:42:12] ppocr INFO: fps:1802.1068940938283
+[2022/06/12 13:42:12] ppocr INFO: best_epoch:6
+[2022/06/12 13:42:12] ppocr INFO: start_epoch:7
+
eval model::   0%|          | 0/2 [00:00<?, ?it/s]
eval model::  50%|����������     | 1/2 [00:00<00:00,  4.67it/s]
+[2022/06/12 13:42:12] ppocr INFO: metric eval ***************
+[2022/06/12 13:42:12] ppocr INFO: acc:0.9999987500015626
+[2022/06/12 13:42:12] ppocr INFO: norm_edit_dis:1.0
+[2022/06/12 13:42:12] ppocr INFO: fps:57.50210270524082
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index f0fd578f6..9ad0ffa84 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -23,7 +23,8 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
 from .make_pse_gt import MakePseGt
 
 from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
-    SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
+    SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, \
+    SPINRecResizeImg
 from .ssl_img_aug import SSLRotateResize
 from .randaugment import RandAugment
 from .copy_paste import CopyPaste
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 02a5187da..83e3e7854 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -1044,3 +1044,52 @@ class MultiLabelEncode(BaseRecLabelEncode):
         data_out['label_sar'] = sar['label']
         data_out['length'] = ctc['length']
         return data_out
+
+
+class SPINAttnLabelEncode(BaseRecLabelEncode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 use_space_char=False,
+                 lower=True,
+                 **kwargs):
+        super(SPINAttnLabelEncode, self).__init__(
+            max_text_length, character_dict_path, use_space_char)
+        self.lower = lower
+    def add_special_char(self, dict_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = [self.beg_str] + [self.end_str] + dict_character
+        return dict_character
+
+    def __call__(self, data):
+        text = data['label']
+        text = self.encode(text)
+        if text is None:
+            return None
+        if len(text) > self.max_text_len:
+            return None
+        data['length'] = np.array(len(text))
+        target = [0] + text + [1]
+        padded_text = [0 for _ in range(self.max_text_len + 2)]
+
+        padded_text[:len(target)] = target
+        data['label'] = np.array(padded_text)
+        return data
+
+    def get_ignored_tokens(self):
+        beg_idx = self.get_beg_end_flag_idx("beg")
+        end_idx = self.get_beg_end_flag_idx("end")
+        return [beg_idx, end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        if beg_or_end == "beg":
+            idx = np.array(self.dict[self.beg_str])
+        elif beg_or_end == "end":
+            idx = np.array(self.dict[self.end_str])
+        else:
+            assert False, "Unsupport type %s in get_beg_end_flag_idx" \
+                          % beg_or_end
+        return idx
\ No newline at end of file
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 32de2b3fc..8caa29e29 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -267,6 +267,51 @@ class PRENResizeImg(object):
         data['image'] = resized_img.astype(np.float32)
         return data
 
+class SPINRecResizeImg(object):
+    def __init__(self,
+                 image_shape,
+                 interpolation=2,
+                 mean=(127.5, 127.5, 127.5),
+                 std=(127.5, 127.5, 127.5),
+                 **kwargs):
+        self.image_shape = image_shape
+        
+        self.mean = np.array(mean, dtype=np.float32)
+        self.std = np.array(std, dtype=np.float32)
+        self.interpolation = interpolation
+
+    def __call__(self, data):
+        img = data['image']
+        # different interpolation type corresponding the OpenCV
+        if self.interpolation == 0:
+            interpolation = cv2.INTER_NEAREST
+        elif self.interpolation == 1:
+            interpolation = cv2.INTER_LINEAR
+        elif self.interpolation == 2:
+            interpolation = cv2.INTER_CUBIC
+        elif self.interpolation == 3:
+            interpolation = cv2.INTER_AREA
+        else:
+            raise Exception("Unsupported interpolation type !!!")
+        # Deal with the image error during image loading
+        if img is None:
+            return None
+
+        img = cv2.resize(img, tuple(self.image_shape), interpolation)
+        img = np.array(img, np.float32)
+        img = np.expand_dims(img, -1)
+        img = img.transpose((2, 0, 1))
+        # normalize the image
+        to_rgb = False
+        img = img.copy().astype(np.float32)
+        mean = np.float64(self.mean.reshape(1, -1))
+        stdinv = 1 / np.float64(self.std.reshape(1, -1))
+        if to_rgb:
+            cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 
+        img -= mean
+        img *= stdinv
+        data['image'] = img
+        return data
 
 def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
     imgC, imgH, imgW_min, imgW_max = image_shape
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index de8419b7c..f748b94cf 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -35,6 +35,7 @@ from .rec_sar_loss import SARLoss
 from .rec_aster_loss import AsterLoss
 from .rec_pren_loss import PRENLoss
 from .rec_multi_loss import MultiLoss
+from .rec_spin_att_loss import SPINAttentionLoss
 
 # cls loss
 from .cls_loss import ClsLoss
@@ -61,7 +62,8 @@ def build_loss(config):
         'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
         'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
         'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
-        'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss'
+        'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
+        'SPINAttentionLoss'
     ]
     config = copy.deepcopy(config)
     module_name = config.pop('name')
diff --git a/ppocr/losses/rec_spin_att_loss.py b/ppocr/losses/rec_spin_att_loss.py
new file mode 100644
index 000000000..37fd93da5
--- /dev/null
+++ b/ppocr/losses/rec_spin_att_loss.py
@@ -0,0 +1,41 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+
+class SPINAttentionLoss(nn.Layer):
+    def __init__(self, reduction='mean', ignore_index=-100, **kwargs):
+        super(SPINAttentionLoss, self).__init__()
+        self.loss_func = nn.CrossEntropyLoss(weight=None, reduction=reduction, ignore_index=ignore_index)
+
+    def forward(self, predicts, batch):
+        targets = batch[1].astype("int64")
+        targets = targets[:, 1:] # remove [eos] in label
+
+        label_lengths = batch[2].astype('int64')
+        batch_size, num_steps, num_classes = predicts.shape[0], predicts.shape[
+            1], predicts.shape[2]
+        assert len(targets.shape) == len(list(predicts.shape)) - 1, \
+            "The target's shape and inputs's shape is [N, d] and [N, num_steps]"
+
+        inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]])
+        targets = paddle.reshape(targets, [-1])
+
+        return {'loss': self.loss_func(inputs, targets)}
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 072d6e0f8..6b525326a 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -32,10 +32,11 @@ def build_backbone(config, model_type):
         from .rec_micronet import MicroNet
         from .rec_efficientb3_pren import EfficientNetb3_PREN
         from .rec_svtrnet import SVTRNet
+        from .rec_resnet_32 import ResNet32
         support_dict = [
             'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
             "ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN',
-            'SVTRNet'
+            'SVTRNet', 'ResNet32'
         ]
     elif model_type == "e2e":
         from .e2e_resnet_vd_pg import ResNet
diff --git a/ppocr/modeling/backbones/rec_resnet_32.py b/ppocr/modeling/backbones/rec_resnet_32.py
new file mode 100644
index 000000000..0b072dc5f
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_resnet_32.py
@@ -0,0 +1,289 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from: 
+https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/backbones/ResNet32.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+import numpy as np
+
+__all__ = ["ResNet32"]
+
+conv_weight_attr = nn.initializer.KaimingNormal()
+
+class ResNet32(nn.Layer):
+    """
+    Feature Extractor is proposed in  FAN Ref [1]
+
+    Ref [1]: Focusing Attention: Towards Accurate Text Recognition in Neural Images ICCV-2017
+    """
+
+    def __init__(self, in_channels, out_channels=512):
+        """
+
+        Args:
+            in_channels (int): input channel
+            output_channel (int): output channel
+        """
+        super(ResNet32, self).__init__()
+        self.out_channels = out_channels
+        self.ConvNet = ResNet(in_channels, out_channels, BasicBlock, [1, 2, 5, 3])
+
+    def forward(self, inputs):
+        """
+        Args:
+            inputs (torch.Tensor): input feature
+
+        Returns:
+             torch.Tensor: output feature
+
+        """
+        return self.ConvNet(inputs)
+
+class BasicBlock(nn.Layer):
+    """Res-net Basic Block"""
+    expansion = 1
+
+    def __init__(self, inplanes, planes,
+                 stride=1, downsample=None,
+                 norm_type='BN', **kwargs):
+        """
+        Args:
+            inplanes (int): input channel
+            planes (int): channels of the middle feature
+            stride (int): stride of the convolution
+            downsample (int): type of the down_sample
+            norm_type (str): type of the normalization
+            **kwargs (None): backup parameter
+        """
+        super(BasicBlock, self).__init__()
+        self.conv1 = self._conv3x3(inplanes, planes)
+        self.bn1 = nn.BatchNorm2D(planes)
+        self.conv2 = self._conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2D(planes)
+        self.relu = nn.ReLU()
+        self.downsample = downsample
+        self.stride = stride
+
+    def _conv3x3(self, in_planes, out_planes, stride=1):
+        """
+
+        Args:
+            in_planes (int): input channel
+            out_planes (int): channels of the middle feature
+            stride (int): stride of the convolution
+        Returns:
+            nn.Module: Conv2D with kernel = 3
+
+        """
+
+        return nn.Conv2D(in_planes, out_planes,
+                         kernel_size=3, stride=stride,
+                         padding=1, weight_attr=conv_weight_attr,
+                         bias_attr=False)
+
+    def forward(self, x):
+        """
+        Args:
+            x (torch.Tensor): input feature
+
+        Returns:
+            torch.Tensor: output feature of the BasicBlock
+
+        """
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+class ResNet(nn.Layer):
+    """Res-Net network structure"""
+    def __init__(self, input_channel,
+                 output_channel, block, layers):
+        """
+
+        Args:
+            input_channel (int): input channel
+            output_channel (int): output channel
+            block (BasicBlock): convolution block
+            layers (list): layers of the block
+        """
+        super(ResNet, self).__init__()
+
+        self.output_channel_block = [int(output_channel / 4),
+                                     int(output_channel / 2),
+                                     output_channel,
+                                     output_channel]
+
+        self.inplanes = int(output_channel / 8)
+        self.conv0_1 = nn.Conv2D(input_channel, int(output_channel / 16),
+                                 kernel_size=3, stride=1, 
+                                 padding=1, 
+                                 weight_attr=conv_weight_attr,
+                                 bias_attr=False)
+        self.bn0_1 = nn.BatchNorm2D(int(output_channel / 16))
+        self.conv0_2 = nn.Conv2D(int(output_channel / 16), self.inplanes,
+                                 kernel_size=3, stride=1,
+                                 padding=1, 
+                                 weight_attr=conv_weight_attr,
+                                 bias_attr=False)
+        self.bn0_2 = nn.BatchNorm2D(self.inplanes)
+        self.relu = nn.ReLU()
+
+        self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
+        self.layer1 = self._make_layer(block,
+                                       self.output_channel_block[0],
+                                       layers[0])
+        self.conv1 = nn.Conv2D(self.output_channel_block[0],
+                               self.output_channel_block[0],
+                               kernel_size=3, stride=1,
+                               padding=1, 
+                               weight_attr=conv_weight_attr,
+                               bias_attr=False)
+        self.bn1 = nn.BatchNorm2D(self.output_channel_block[0])
+
+        self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
+        self.layer2 = self._make_layer(block,
+                                       self.output_channel_block[1],
+                                       layers[1], stride=1)
+        self.conv2 = nn.Conv2D(self.output_channel_block[1],
+                               self.output_channel_block[1],
+                               kernel_size=3, stride=1,
+                               padding=1, 
+                               weight_attr=conv_weight_attr,
+                               bias_attr=False,)
+        self.bn2 = nn.BatchNorm2D(self.output_channel_block[1])
+
+        self.maxpool3 = nn.MaxPool2D(kernel_size=2,
+                                     stride=(2, 1),
+                                     padding=(0, 1))
+        self.layer3 = self._make_layer(block, self.output_channel_block[2],
+                                       layers[2], stride=1)
+        self.conv3 = nn.Conv2D(self.output_channel_block[2],
+                               self.output_channel_block[2],
+                               kernel_size=3, stride=1,
+                               padding=1, 
+                               weight_attr=conv_weight_attr,
+                               bias_attr=False)
+        self.bn3 = nn.BatchNorm2D(self.output_channel_block[2])
+
+        self.layer4 = self._make_layer(block, self.output_channel_block[3],
+                                       layers[3], stride=1)
+        self.conv4_1 = nn.Conv2D(self.output_channel_block[3],
+                                 self.output_channel_block[3],
+                                 kernel_size=2, stride=(2, 1),
+                                 padding=(0, 1), 
+                                 weight_attr=conv_weight_attr,
+                                 bias_attr=False)
+        self.bn4_1 = nn.BatchNorm2D(self.output_channel_block[3])
+        self.conv4_2 = nn.Conv2D(self.output_channel_block[3],
+                                 self.output_channel_block[3],
+                                 kernel_size=2, stride=1,
+                                 padding=0, 
+                                 weight_attr=conv_weight_attr,
+                                 bias_attr=False)
+        self.bn4_2 = nn.BatchNorm2D(self.output_channel_block[3])
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        """
+
+        Args:
+            block (block): convolution block
+            planes (int): input channels
+            blocks (list): layers of the block
+            stride (int): stride of the convolution
+
+        Returns:
+            nn.Sequential: the combination of the convolution block
+
+        """
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2D(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride,
+                          weight_attr=conv_weight_attr, 
+                          bias_attr=False),
+                nn.BatchNorm2D(planes * block.expansion),
+            )
+
+        layers = list()
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        """
+        Args:
+            x (torch.Tensor): input feature
+
+        Returns:
+            torch.Tensor: output feature of the Resnet
+
+        """
+        x = self.conv0_1(x)
+        x = self.bn0_1(x)
+        x = self.relu(x)
+        x = self.conv0_2(x)
+        x = self.bn0_2(x)
+        x = self.relu(x)
+
+        x = self.maxpool1(x)
+        x = self.layer1(x)
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+
+        x = self.maxpool2(x)
+        x = self.layer2(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu(x)
+
+        x = self.maxpool3(x)
+        x = self.layer3(x)
+        x = self.conv3(x)
+        x = self.bn3(x)
+        x = self.relu(x)
+
+        x = self.layer4(x)
+        x = self.conv4_1(x)
+        x = self.bn4_1(x)
+        x = self.relu(x)
+        x = self.conv4_2(x)
+        x = self.bn4_2(x)
+        x = self.relu(x)
+        return x
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 1670ea38e..9b53462b8 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -33,6 +33,7 @@ def build_head(config):
     from .rec_aster_head import AsterHead
     from .rec_pren_head import PRENHead
     from .rec_multi_head import MultiHead
+    from .rec_spin_att_head import SPINAttentionHead
 
     # cls head
     from .cls_head import ClsHead
@@ -46,7 +47,7 @@ def build_head(config):
         'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
         'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
         'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
-        'MultiHead'
+        'MultiHead', 'SPINAttentionHead'
     ]
 
     #table head
diff --git a/ppocr/modeling/heads/rec_spin_att_head.py b/ppocr/modeling/heads/rec_spin_att_head.py
new file mode 100644
index 000000000..94e69a7ed
--- /dev/null
+++ b/ppocr/modeling/heads/rec_spin_att_head.py
@@ -0,0 +1,203 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import numpy as np
+
+
+class SPINAttentionHead(nn.Layer):
+    def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
+        super(SPINAttentionHead, self).__init__()
+        self.input_size = in_channels
+        self.hidden_size = hidden_size
+        self.num_classes = out_channels
+
+        self.attention_cell = AttentionLSTMCell(
+            in_channels, hidden_size, out_channels, use_gru=False)
+        self.generator = nn.Linear(hidden_size, out_channels)
+
+    def _char_to_onehot(self, input_char, onehot_dim):
+        input_ont_hot = F.one_hot(input_char, onehot_dim)
+        return input_ont_hot
+
+    def forward(self, inputs, targets=None, batch_max_length=25):
+        batch_size = paddle.shape(inputs)[0]
+        num_steps = batch_max_length + 1 # +1 for [sos] at end of sentence
+
+        hidden = (paddle.zeros((batch_size, self.hidden_size)),
+                    paddle.zeros((batch_size, self.hidden_size)))
+        output_hiddens = []
+        if self.training: # for train
+            targets = targets[0]
+            for i in range(num_steps):
+                char_onehots = self._char_to_onehot(
+                    targets[:, i], onehot_dim=self.num_classes)
+                (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
+                                                               char_onehots)
+                output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
+            output = paddle.concat(output_hiddens, axis=1)
+            probs = self.generator(output)        
+        else:
+            targets = paddle.zeros(shape=[batch_size], dtype="int32")
+            probs = None
+            char_onehots = None
+            outputs = None
+            alpha = None
+
+            for i in range(num_steps):
+                char_onehots = self._char_to_onehot(
+                    targets, onehot_dim=self.num_classes)
+                (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
+                                                               char_onehots)
+                probs_step = self.generator(outputs)
+                if probs is None:
+                    probs = paddle.unsqueeze(probs_step, axis=1)
+                else:
+                    probs = paddle.concat(
+                        [probs, paddle.unsqueeze(
+                            probs_step, axis=1)], axis=1)
+                next_input = probs_step.argmax(axis=1)
+                targets = next_input
+        if not self.training:
+            probs = paddle.nn.functional.softmax(probs, axis=2)
+        return probs
+
+
+class AttentionGRUCell(nn.Layer):
+    def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
+        super(AttentionGRUCell, self).__init__()
+        self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
+        self.h2h = nn.Linear(hidden_size, hidden_size)
+        self.score = nn.Linear(hidden_size, 1, bias_attr=False)
+
+        self.rnn = nn.GRUCell(
+            input_size=input_size + num_embeddings, hidden_size=hidden_size)
+
+        self.hidden_size = hidden_size
+
+    def forward(self, prev_hidden, batch_H, char_onehots):
+
+        batch_H_proj = self.i2h(batch_H)
+        prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
+
+        res = paddle.add(batch_H_proj, prev_hidden_proj)
+        res = paddle.tanh(res)
+        e = self.score(res)
+
+        alpha = F.softmax(e, axis=1)
+        alpha = paddle.transpose(alpha, [0, 2, 1])
+        context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
+        concat_context = paddle.concat([context, char_onehots], 1)
+
+        cur_hidden = self.rnn(concat_context, prev_hidden)
+
+        return cur_hidden, alpha
+
+
+class AttentionLSTM(nn.Layer):
+    def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
+        super(AttentionLSTM, self).__init__()
+        self.input_size = in_channels
+        self.hidden_size = hidden_size
+        self.num_classes = out_channels
+
+        self.attention_cell = AttentionLSTMCell(
+            in_channels, hidden_size, out_channels, use_gru=False)
+        self.generator = nn.Linear(hidden_size, out_channels)
+
+    def _char_to_onehot(self, input_char, onehot_dim):
+        input_ont_hot = F.one_hot(input_char, onehot_dim)
+        return input_ont_hot
+
+    def forward(self, inputs, targets=None, batch_max_length=25):
+        batch_size = inputs.shape[0]
+        num_steps = batch_max_length
+
+        hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
+            (batch_size, self.hidden_size)))
+        output_hiddens = []
+
+        if targets is not None:
+            for i in range(num_steps):
+                # one-hot vectors for a i-th char
+                char_onehots = self._char_to_onehot(
+                    targets[:, i], onehot_dim=self.num_classes)
+                hidden, alpha = self.attention_cell(hidden, inputs,
+                                                    char_onehots)
+
+                hidden = (hidden[1][0], hidden[1][1])
+                output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
+            output = paddle.concat(output_hiddens, axis=1)
+            probs = self.generator(output)
+
+        else:
+            targets = paddle.zeros(shape=[batch_size], dtype="int32")
+            probs = None
+
+            for i in range(num_steps):
+                char_onehots = self._char_to_onehot(
+                    targets, onehot_dim=self.num_classes)
+                hidden, alpha = self.attention_cell(hidden, inputs,
+                                                    char_onehots)
+                probs_step = self.generator(hidden[0])
+                hidden = (hidden[1][0], hidden[1][1])
+                if probs is None:
+                    probs = paddle.unsqueeze(probs_step, axis=1)
+                else:
+                    probs = paddle.concat(
+                        [probs, paddle.unsqueeze(
+                            probs_step, axis=1)], axis=1)
+
+                next_input = probs_step.argmax(axis=1)
+
+                targets = next_input
+
+        return probs
+
+
+class AttentionLSTMCell(nn.Layer):
+    def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
+        super(AttentionLSTMCell, self).__init__()
+        self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
+        self.h2h = nn.Linear(hidden_size, hidden_size)
+        self.score = nn.Linear(hidden_size, 1, bias_attr=False)
+        if not use_gru:
+            self.rnn = nn.LSTMCell(
+                input_size=input_size + num_embeddings, hidden_size=hidden_size)
+        else:
+            self.rnn = nn.GRUCell(
+                input_size=input_size + num_embeddings, hidden_size=hidden_size)
+
+        self.hidden_size = hidden_size
+
+    def forward(self, prev_hidden, batch_H, char_onehots):
+        batch_H_proj = self.i2h(batch_H)
+        prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
+        res = paddle.add(batch_H_proj, prev_hidden_proj)
+        res = paddle.tanh(res)
+        e = self.score(res)
+
+        alpha = F.softmax(e, axis=1)
+        alpha = paddle.transpose(alpha, [0, 2, 1])
+        context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
+        concat_context = paddle.concat([context, char_onehots], 1)
+        cur_hidden = self.rnn(concat_context, prev_hidden)
+
+        return cur_hidden, alpha
diff --git a/ppocr/modeling/necks/rnn.py b/ppocr/modeling/necks/rnn.py
index c8a774b8c..32e626c3f 100644
--- a/ppocr/modeling/necks/rnn.py
+++ b/ppocr/modeling/necks/rnn.py
@@ -47,6 +47,67 @@ class EncoderWithRNN(nn.Layer):
         x, _ = self.lstm(x)
         return x
 
+class BidirectionalLSTM(nn.Layer):
+    def __init__(self, input_size,
+                 hidden_size,
+                 output_size=None,
+                 num_layers=1,
+                 dropout=0,
+                 direction=False,
+                 time_major=False,
+                 with_linear=False):
+        super(BidirectionalLSTM, self).__init__()
+        self.with_linear = with_linear
+        self.rnn = nn.LSTM(input_size,
+                           hidden_size,
+                           num_layers=num_layers,
+                           dropout=dropout,
+                           direction=direction,
+                           time_major=time_major)
+
+        # text recognition the specified structure LSTM with linear
+        if self.with_linear:
+            self.linear = nn.Linear(hidden_size * 2, output_size)
+
+    def forward(self, input_feature):
+        """
+
+        Args:
+            input_feature (Torch.Tensor): visual feature [batch_size x T x input_size]
+
+        Returns:
+            Torch.Tensor: LSTM output contextual feature [batch_size x T x output_size]
+
+        """
+
+        # self.rnn.flatten_parameters() # error in export_model
+        recurrent, _ = self.rnn(input_feature)  # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
+        if self.with_linear:
+            output = self.linear(recurrent)     # batch_size x T x output_size
+            return output
+        return recurrent
+
+class EncoderWithCascadeRNN(nn.Layer):
+    def __init__(self, in_channels, hidden_size, out_channels, num_layers=2, with_linear=False):
+        super(EncoderWithCascadeRNN, self).__init__()
+        self.out_channels = out_channels[-1]
+        self.encoder = nn.LayerList(
+            [BidirectionalLSTM(
+                in_channels if i == 0 else out_channels[i - 1], 
+                hidden_size, 
+                output_size=out_channels[i], 
+                num_layers=1, 
+                direction='bidirectional', 
+                with_linear=with_linear) 
+            for i in range(num_layers)]
+        )
+        
+
+    def forward(self, x):
+        for i, l in enumerate(self.encoder):
+            x = l(x)
+        return x
+
 
 class EncoderWithFC(nn.Layer):
     def __init__(self, in_channels, hidden_size):
@@ -166,13 +227,17 @@ class SequenceEncoder(nn.Layer):
                 'reshape': Im2Seq,
                 'fc': EncoderWithFC,
                 'rnn': EncoderWithRNN,
-                'svtr': EncoderWithSVTR
+                'svtr': EncoderWithSVTR,
+                'cascadernn': EncoderWithCascadeRNN
             }
             assert encoder_type in support_encoder_dict, '{} must in {}'.format(
                 encoder_type, support_encoder_dict.keys())
             if encoder_type == "svtr":
                 self.encoder = support_encoder_dict[encoder_type](
                     self.encoder_reshape.out_channels, **kwargs)
+            elif encoder_type == 'cascadernn':
+                self.encoder = support_encoder_dict[encoder_type](
+                    self.encoder_reshape.out_channels, hidden_size, **kwargs)
             else:
                 self.encoder = support_encoder_dict[encoder_type](
                     self.encoder_reshape.out_channels, hidden_size)
diff --git a/ppocr/modeling/transforms/__init__.py b/ppocr/modeling/transforms/__init__.py
index 405ab3cc6..7e4ffdf46 100755
--- a/ppocr/modeling/transforms/__init__.py
+++ b/ppocr/modeling/transforms/__init__.py
@@ -18,8 +18,10 @@ __all__ = ['build_transform']
 def build_transform(config):
     from .tps import TPS
     from .stn import STN_ON
+    from .gaspin_transformer import GA_SPIN_Transformer as GA_SPIN
 
-    support_dict = ['TPS', 'STN_ON']
+
+    support_dict = ['TPS', 'STN_ON', 'GA_SPIN']
 
     module_name = config.pop('name')
     assert module_name in support_dict, Exception(
diff --git a/ppocr/modeling/transforms/gaspin_transformer.py b/ppocr/modeling/transforms/gaspin_transformer.py
new file mode 100644
index 000000000..331c82aae
--- /dev/null
+++ b/ppocr/modeling/transforms/gaspin_transformer.py
@@ -0,0 +1,286 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import numpy as np
+import itertools
+import functools
+from .tps import GridGenerator
+
+'''This code is refer from:
+https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/transformations/gaspin_transformation.py
+'''
+
+class SP_TransformerNetwork(nn.Layer):
+    """
+    Sturture-Preserving Transformation (SPT) as Equa. (2) in Ref. [1]
+    Ref: [1] SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition. AAAI-2021.
+    """
+
+    def __init__(self, nc=1, default_type=5):
+        """ Based on SPIN
+        Args:
+            nc (int): number of input channels (usually in 1 or 3)
+            default_type (int): the complexity of transformation intensities (by default set to 6 as the paper)
+        """
+        super(SP_TransformerNetwork, self).__init__()
+        self.power_list = self.cal_K(default_type)
+        self.sigmoid = nn.Sigmoid()
+        self.bn = nn.InstanceNorm2D(nc)
+
+    def cal_K(self, k=5):
+        """
+
+        Args:
+            k (int): the complexity of transformation intensities (by default set to 6 as the paper)
+
+        Returns:
+            List: the normalized intensity of each pixel in [0,1], denoted as \beta [1x(2K+1)]
+
+        """
+        from math import log
+        x = []
+        if k != 0:
+            for i in range(1, k+1):
+                lower = round(log(1-(0.5/(k+1))*i)/log((0.5/(k+1))*i), 2)
+                upper = round(1/lower, 2)
+                x.append(lower)
+                x.append(upper)
+        x.append(1.00)
+        return x
+
+    def forward(self, batch_I, weights, offsets, lambda_color=None):
+        """
+
+        Args:
+            batch_I (torch.Tensor): batch of input images [batch_size x nc x I_height x I_width]
+            weights:
+            offsets: the predicted offset by AIN, a scalar
+            lambda_color: the learnable update gate \alpha in Equa. (5) as
+                          g(x) = (1 - \alpha) \odot x + \alpha \odot x_{offsets}
+
+        Returns:
+            torch.Tensor: transformed images by SPN as Equa. (4) in Ref. [1]
+                        [batch_size x I_channel_num x I_r_height x I_r_width]
+
+        """
+        batch_I = (batch_I + 1) * 0.5
+        if offsets is not None:
+            batch_I = batch_I*(1-lambda_color) + offsets*lambda_color
+        batch_weight_params = paddle.unsqueeze(paddle.unsqueeze(weights, -1), -1)
+        batch_I_power = paddle.stack([batch_I.pow(p) for p in self.power_list], axis=1)
+
+        batch_weight_sum = paddle.sum(batch_I_power * batch_weight_params, axis=1)
+        batch_weight_sum = self.bn(batch_weight_sum)
+        batch_weight_sum = self.sigmoid(batch_weight_sum)
+        batch_weight_sum = batch_weight_sum * 2 - 1
+        return batch_weight_sum
+
+class GA_SPIN_Transformer(nn.Layer):
+    """
+    Geometric-Absorbed SPIN Transformation (GA-SPIN) proposed in Ref. [1]
+
+
+    Ref: [1] SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition. AAAI-2021.
+    """
+
+    def __init__(self, in_channels=1,
+                 I_r_size=(32, 100),
+                 offsets=False,
+                 norm_type='BN',
+                 default_type=6,
+                 loc_lr=1,
+                 stn=True):
+        """
+        Args:
+            in_channels (int): channel of input features,
+                                set it to 1 if the grayscale images and 3 if RGB input
+            I_r_size (tuple): size of rectified images (used in STN transformations)
+            inputDataType (str): the type of input data,
+                                only support 'torch.cuda.FloatTensor' this version
+            offsets (bool): set it to False if use SPN w.o. AIN,
+                            and set it to True if use SPIN (both with SPN and AIN)
+            norm_type (str): the normalization type of the module,
+                            set it to 'BN' by default, 'IN' optionally
+            default_type (int): the K chromatic space,
+                                set it to 3/5/6 depend on the complexity of transformation intensities
+            loc_lr (float): learning rate of location network
+
+        """
+        super(GA_SPIN_Transformer, self).__init__()
+        self.nc = in_channels
+        self.spt = True
+        self.offsets = offsets
+        self.stn = stn  # set to True in GA-SPIN, while set it to False in SPIN
+        self.I_r_size = I_r_size
+        self.out_channels = in_channels
+        if norm_type == 'BN':
+            norm_layer = functools.partial(nn.BatchNorm2D, use_global_stats=True)
+        elif norm_type == 'IN':
+            norm_layer = functools.partial(nn.InstanceNorm2D, weight_attr=False,
+                                           use_global_stats=False)
+        else:
+            raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
+
+        if self.spt:
+            self.sp_net = SP_TransformerNetwork(in_channels,
+                                                default_type)
+            self.spt_convnet = nn.Sequential(
+                                  # 32*100
+                                  nn.Conv2D(in_channels, 32, 3, 1, 1, bias_attr=False),
+                                  norm_layer(32), nn.ReLU(),
+                                  nn.MaxPool2D(kernel_size=2, stride=2),
+                                  # 16*50
+                                  nn.Conv2D(32, 64, 3, 1, 1, bias_attr=False),
+                                  norm_layer(64), nn.ReLU(),
+                                  nn.MaxPool2D(kernel_size=2, stride=2),
+                                  # 8*25
+                                  nn.Conv2D(64, 128, 3, 1, 1, bias_attr=False),
+                                  norm_layer(128), nn.ReLU(),
+                                  nn.MaxPool2D(kernel_size=2, stride=2),
+                                  # 4*12
+            )
+            self.stucture_fc1 = nn.Sequential(
+                                  nn.Conv2D(128, 256, 3, 1, 1, bias_attr=False),
+                                  norm_layer(256), nn.ReLU(),
+                                  nn.MaxPool2D(kernel_size=2, stride=2),
+                                  nn.Conv2D(256, 256, 3, 1, 1, bias_attr=False),
+                                  norm_layer(256), nn.ReLU(),  # 2*6
+                                  nn.MaxPool2D(kernel_size=2, stride=2),
+                                  nn.Conv2D(256, 512, 3, 1, 1, bias_attr=False),
+                                  norm_layer(512), nn.ReLU(),  # 1*3
+                                  nn.AdaptiveAvgPool2D(1),
+                                  nn.Flatten(1, -1),  # batch_size x 512
+                                  nn.Linear(512, 256, weight_attr=nn.initializer.Normal(0.001)),
+                                  nn.BatchNorm1D(256), nn.ReLU()
+                                )
+            self.out_weight = 2*default_type+1
+            self.spt_length = 2*default_type+1
+            if offsets:
+                self.out_weight += 1
+            if self.stn:
+                self.F = 20
+                self.out_weight += self.F * 2
+                self.GridGenerator = GridGenerator(self.F*2, self.F)
+                
+            # self.out_weight*=nc
+            # Init structure_fc2 in LocalizationNetwork
+            initial_bias = self.init_spin(default_type*2)
+            initial_bias = initial_bias.reshape(-1)
+            param_attr = ParamAttr(
+                learning_rate=loc_lr,
+                initializer=nn.initializer.Assign(np.zeros([256, self.out_weight])))
+            bias_attr = ParamAttr(
+                learning_rate=loc_lr,
+                initializer=nn.initializer.Assign(initial_bias))
+            self.stucture_fc2 = nn.Linear(256, self.out_weight,
+                                weight_attr=param_attr,
+                                bias_attr=bias_attr)
+            self.sigmoid = nn.Sigmoid()
+
+            if offsets:
+                self.offset_fc1 = nn.Sequential(nn.Conv2D(128, 16,
+                                                          3, 1, 1,
+                                                          bias_attr=False),
+                                                norm_layer(16),
+                                                nn.ReLU(),)
+                self.offset_fc2 = nn.Conv2D(16, in_channels,
+                                            3, 1, 1)
+                self.pool = nn.MaxPool2D(2, 2)
+
+    def init_spin(self, nz):
+        """
+        Args:
+            nz (int): number of paired \betas exponents, which means the value of K x 2
+
+        """
+        init_id = [0.00]*nz+[5.00]
+        if self.offsets:
+            init_id += [-5.00]
+            # init_id *=3
+        init = np.array(init_id)
+
+        if self.stn:
+            F = self.F
+            ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
+            ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
+            ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2))
+            ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+            ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+            initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
+            initial_bias = initial_bias.reshape(-1)
+            init = np.concatenate([init, initial_bias], axis=0)
+        return init
+
+    def forward(self, x, return_weight=False):
+        """
+        Args:
+            x (torch.cuda.FloatTensor): input image batch
+            return_weight (bool): set to False by default,
+                                  if set to True return the predicted offsets of AIN, denoted as x_{offsets}
+
+        Returns:
+            torch.Tensor: rectified image [batch_size x I_channel_num x I_height x I_width], the same as the input size
+        """
+
+        if self.spt:
+            feat = self.spt_convnet(x)
+            fc1 = self.stucture_fc1(feat)
+            sp_weight_fusion = self.stucture_fc2(fc1)
+            sp_weight_fusion = sp_weight_fusion.reshape([x.shape[0], self.out_weight, 1])
+            if self.offsets:  # SPIN w. AIN
+                lambda_color = sp_weight_fusion[:, self.spt_length, 0]
+                lambda_color = self.sigmoid(lambda_color).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+                sp_weight = sp_weight_fusion[:, :self.spt_length, :]
+                offsets = self.pool(self.offset_fc2(self.offset_fc1(feat)))
+
+                assert offsets.shape[2] == 2  # 2
+                assert offsets.shape[3] == 6  # 16
+                offsets = self.sigmoid(offsets)  # v12
+
+                if return_weight:
+                    return offsets
+                offsets = nn.functional.upsample(offsets, size=(x.shape[2], x.shape[3]), mode='bilinear')
+
+                if self.stn:
+                    batch_C_prime = sp_weight_fusion[:, (self.spt_length + 1):, :].reshape([x.shape[0], self.F, 2])
+                    build_P_prime = self.GridGenerator(batch_C_prime, self.I_r_size)
+                    build_P_prime_reshape = build_P_prime.reshape([build_P_prime.shape[0],
+                                                                   self.I_r_size[0],
+                                                                   self.I_r_size[1],
+                                                                   2])
+
+            else:  # SPIN w.o. AIN
+                sp_weight = sp_weight_fusion[:, :self.spt_length, :]
+                lambda_color, offsets = None, None
+
+                if self.stn:
+                    batch_C_prime = sp_weight_fusion[:, self.spt_length:, :].reshape([x.shape[0], self.F, 2])
+                    build_P_prime = self.GridGenerator(batch_C_prime, self.I_r_size)
+                    build_P_prime_reshape = build_P_prime.reshape([build_P_prime.shape[0],
+                                                                   self.I_r_size[0],
+                                                                   self.I_r_size[1],
+                                                                   2])
+
+            x = self.sp_net(x, sp_weight, offsets, lambda_color)
+            if self.stn:
+                x = F.grid_sample(x=x, grid=build_P_prime_reshape, padding_mode='border')
+        return x
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index f50b5f1c5..cf2575ee0 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
 from .fce_postprocess import FCEPostProcess
 from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
     DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
-    SEEDLabelDecode, PRENLabelDecode
+    SEEDLabelDecode, PRENLabelDecode, SPINAttnLabelDecode
 from .cls_postprocess import ClsPostProcess
 from .pg_postprocess import PGPostProcess
 from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None):
         'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
         'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
         'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
-        'DistillationSARLabelDecode'
+        'DistillationSARLabelDecode', 'SPINAttnLabelDecode'
     ]
 
     if config['name'] == 'PSEPostProcess':
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index bf0fd890b..0df8f3ccd 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -752,3 +752,82 @@ class PRENLabelDecode(BaseRecLabelDecode):
             return text
         label = self.decode(label)
         return text, label
+
+class SPINAttnLabelDecode(BaseRecLabelDecode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self, character_dict_path=None, use_space_char=False,
+                 **kwargs):
+        super(SPINAttnLabelDecode, self).__init__(character_dict_path,
+                                              use_space_char)
+
+    def add_special_char(self, dict_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = dict_character
+        dict_character = [self.beg_str] + [self.end_str] + dict_character
+        return dict_character
+
+    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+        """ convert text-index into text-label. """
+        result_list = []
+        ignored_tokens = self.get_ignored_tokens()
+        [beg_idx, end_idx] = self.get_ignored_tokens()
+        batch_size = len(text_index)
+        for batch_idx in range(batch_size):
+            char_list = []
+            conf_list = []
+            for idx in range(len(text_index[batch_idx])):
+                if text_index[batch_idx][idx] == int(beg_idx):
+                    continue
+                if int(text_index[batch_idx][idx]) == int(end_idx):
+                    break
+                if is_remove_duplicate:
+                    # only for predict
+                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
+                            batch_idx][idx]:
+                        continue
+                char_list.append(self.character[int(text_index[batch_idx][
+                    idx])])
+                if text_prob is not None:
+                    conf_list.append(text_prob[batch_idx][idx])
+                else:
+                    conf_list.append(1)
+            text = ''.join(char_list)
+            result_list.append((text.lower(), np.mean(conf_list).tolist()))
+        return result_list
+
+    def __call__(self, preds, label=None, *args, **kwargs):
+        """
+        text = self.decode(text)
+        if label is None:
+            return text
+        else:
+            label = self.decode(label, is_remove_duplicate=False)
+            return text, label
+        """
+        if isinstance(preds, paddle.Tensor):
+            preds = preds.numpy()
+
+        preds_idx = preds.argmax(axis=2)
+        preds_prob = preds.max(axis=2)
+        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+        if label is None:
+            return text
+        label = self.decode(label, is_remove_duplicate=False)
+        return text, label
+
+    def get_ignored_tokens(self):
+        beg_idx = self.get_beg_end_flag_idx("beg")
+        end_idx = self.get_beg_end_flag_idx("end")
+        return [beg_idx, end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        if beg_or_end == "beg":
+            idx = np.array(self.dict[self.beg_str])
+        elif beg_or_end == "end":
+            idx = np.array(self.dict[self.end_str])
+        else:
+            assert False, "unsupport type %s in get_beg_end_flag_idx" \
+                          % beg_or_end
+        return idx
\ No newline at end of file
diff --git a/ppocr/utils/dict/spin_dict.txt b/ppocr/utils/dict/spin_dict.txt
new file mode 100644
index 000000000..8ee8347fd
--- /dev/null
+++ b/ppocr/utils/dict/spin_dict.txt
@@ -0,0 +1,68 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+:
+(
+'
+-
+,
+%
+>
+.
+[
+?
+)
+"
+=
+_
+*
+]
+;
+&
++
+$
+@
+/
+|
+!
+<
+#
+`
+{
+~
+\
+}
+^
\ No newline at end of file
diff --git a/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml
new file mode 100644
index 000000000..e53396a03
--- /dev/null
+++ b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml
@@ -0,0 +1,118 @@
+Global:
+  use_gpu: True
+  epoch_num: 6
+  log_smooth_window: 50
+  print_batch_step: 50
+  save_model_dir: ./output/rec/rec_r32_gaspin_bilstm_att/
+  save_epoch_step: 3
+  # evaluation is run every 5000 iterations after the 4000th iteration
+  eval_batch_step: [0, 2000]
+  cal_metric_during_train: True
+  pretrained_model:
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/imgs_words/ch/word_1.jpg
+  # for data or label process
+  character_dict_path: ./ppocr/utils/dict/spin_dict.txt
+  max_text_length: 25
+  infer_mode: False
+  use_space_char: False
+  save_res_path: ./output/rec/predicts_r32_gaspin_bilstm_att.txt
+
+
+Optimizer:
+  name: AdamW
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Piecewise
+    decay_epochs: [3, 4, 5]
+    values: [0.001, 0.0003, 0.00009, 0.000027] 
+
+  clip_norm: 5
+
+Architecture:
+  model_type: rec
+  algorithm: SPIN
+  in_channels: 1
+  Transform:
+    name: GA_SPIN
+    offsets: True
+    default_type: 6
+    loc_lr: 0.1
+    stn: True
+  Backbone:
+    name: ResNet32
+    out_channels: 512
+  Neck:
+    name: SequenceEncoder
+    encoder_type: cascadernn 
+    hidden_size: 256
+    out_channels: [256, 512]
+    with_linear: True
+  Head:
+    name: SPINAttentionHead  
+    hidden_size: 256
+    
+
+Loss:
+  name: SPINAttentionLoss
+  ignore_index: 0
+
+PostProcess:
+  name: SPINAttnLabelDecode
+  character_dict_path: ./ppocr/utils/dict/spin_dict.txt
+  use_space_char: False
+
+
+Metric:
+  name: RecMetric
+  main_indicator: acc
+  is_filter: True
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/ic15_data/
+    label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
+    transforms:
+      - NRTRDecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - SPINAttnLabelEncode: # Class handling label
+      - SPINRecResizeImg:
+          image_shape: [100, 32]
+          interpolation : 2
+          mean: [127.5]
+          std: [127.5]
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+  loader:
+    shuffle: True
+    batch_size_per_card: 128
+    drop_last: True
+    num_workers: 4
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/ic15_data
+    label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
+    transforms:
+      - NRTRDecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - SPINAttnLabelEncode: # Class handling label
+      - SPINRecResizeImg:
+          image_shape: [100, 32]
+          interpolation : 2
+          mean: [127.5]
+          std: [127.5]
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 1
+    num_workers: 1
diff --git a/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt b/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt
new file mode 100644
index 000000000..4915055a5
--- /dev/null
+++ b/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:rec_r32_gaspin_bilstm_att
+python:python
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./inference/rec_inference
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+train_model:./inference/rec_r32_gaspin_bilstm_att/best_accuracy
+infer_export:tools/export_model.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o
+infer_quant:False
+inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict/spin_dict.txt --use_space_char=False --rec_image_shape="3,32,100" --rec_algorithm="SPIN"
+--use_gpu:True|False
+--enable_mkldnn:True|False
+--cpu_threads:1|6
+--rec_batch_num:1|6
+--use_tensorrt:False|False
+--precision:fp32|int8
+--rec_model_dir:
+--image_dir:./inference/rec_inference
+--save_log_path:./test/output/
+--benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,32,100]}]
diff --git a/tools/export_model.py b/tools/export_model.py
index 3ea0228f8..b8bc5e1ed 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -73,6 +73,12 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
                 shape=[None, 3, 64, 512], dtype="float32"),
         ]
         model = to_static(model, input_spec=other_shape)
+    elif arch_config["algorithm"] == "SPIN":
+        other_shape = [
+            paddle.static.InputSpec(
+                shape=[None, 1, 32, 100], dtype="float32"),
+        ]
+        model = to_static(model, input_spec=other_shape)
     else:
         infer_shape = [3, -1, -1]
         if arch_config["model_type"] == "rec":
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 3664ef2ca..09e13d8dc 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -69,6 +69,12 @@ class TextRecognizer(object):
                 "character_dict_path": args.rec_char_dict_path,
                 "use_space_char": args.use_space_char
             }
+        elif self.rec_algorithm == "SPIN":
+            postprocess_params = {
+                'name': 'SPINAttnLabelDecode',
+                "character_dict_path": args.rec_char_dict_path,
+                "use_space_char": args.use_space_char
+            }
         self.postprocess_op = build_post_process(postprocess_params)
         self.predictor, self.input_tensor, self.output_tensors, self.config = \
             utility.create_predictor(args, 'rec', logger)
@@ -250,6 +256,22 @@ class TextRecognizer(object):
 
         return padding_im, resize_shape, pad_shape, valid_ratio
 
+    def resize_norm_img_spin(self, img):
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        # return padding_im
+        img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC)
+        img = np.array(img, np.float32)
+        img = np.expand_dims(img, -1)
+        img = img.transpose((2, 0, 1))
+        mean = [127.5]
+        std = [127.5]
+        mean = np.array(mean, dtype=np.float32)
+        std = np.array(std, dtype=np.float32)
+        mean = np.float32(mean.reshape(1, -1))
+        stdinv = 1 / np.float32(std.reshape(1, -1))
+        img -= mean
+        img *= stdinv
+        return img
     def __call__(self, img_list):
         img_num = len(img_list)
         # Calculate the aspect ratio of all text bars
@@ -300,6 +322,10 @@ class TextRecognizer(object):
                                                          self.rec_image_shape)
                     norm_img = norm_img[np.newaxis, :]
                     norm_img_batch.append(norm_img)
+                elif self.rec_algorithm == 'SPIN':
+                    norm_img = self.resize_norm_img_spin(img_list[indices[ino]])
+                    norm_img = norm_img[np.newaxis, :]
+                    norm_img_batch.append(norm_img)
                 else:
                     norm_img = self.resize_norm_img(img_list[indices[ino]],
                                                     max_wh_ratio)
diff --git a/tools/program.py b/tools/program.py
index aa0d2698c..51c73d3e5 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -207,7 +207,7 @@ def train(config,
     model.train()
 
     use_srn = config['Architecture']['algorithm'] == "SRN"
-    extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
+    extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN"]
     extra_input = False
     if config['Architecture']['algorithm'] == 'Distillation':
         for key in config['Architecture']["Models"]:
@@ -564,7 +564,8 @@ def preprocess(is_train=False):
     assert alg in [
         'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
         'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
-        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR'
+        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR',
+        'SPIN'
     ]
 
     if use_xpu: