add pr
parent
6220899e4b
commit
bbca1e0d66
|
@ -10,7 +10,7 @@ __pycache__/
|
|||
inference/
|
||||
inference_results/
|
||||
output/
|
||||
|
||||
train_data/
|
||||
*.DS_Store
|
||||
*.vs
|
||||
*.user
|
||||
|
|
|
@ -0,0 +1,117 @@
|
|||
Global:
|
||||
use_gpu: True
|
||||
epoch_num: 6
|
||||
log_smooth_window: 50
|
||||
print_batch_step: 50
|
||||
save_model_dir: ./output/rec/rec_r32_gaspin_bilstm_att/
|
||||
save_epoch_step: 3
|
||||
# evaluation is run every 2000 iterations after the 0th iteration
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
# for data or label process
|
||||
character_dict_path: ./ppocr/utils/dict/spin_dict.txt
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_r32_gaspin_bilstm_att.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: AdamW
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Piecewise
|
||||
decay_epochs: [3, 4, 5]
|
||||
values: [0.001, 0.0003, 0.00009, 0.000027]
|
||||
clip_norm: 5
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: SPIN
|
||||
in_channels: 1
|
||||
Transform:
|
||||
name: GA_SPIN
|
||||
offsets: True
|
||||
default_type: 6
|
||||
loc_lr: 0.1
|
||||
stn: True
|
||||
Backbone:
|
||||
name: ResNet32
|
||||
out_channels: 512
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: cascadernn
|
||||
hidden_size: 256
|
||||
out_channels: [256, 512]
|
||||
with_linear: True
|
||||
Head:
|
||||
name: SPINAttentionHead
|
||||
hidden_size: 256
|
||||
|
||||
|
||||
Loss:
|
||||
name: SPINAttentionLoss
|
||||
ignore_index: 0
|
||||
|
||||
PostProcess:
|
||||
name: SPINAttnLabelDecode
|
||||
character_dict_path: ./ppocr/utils/dict/spin_dict.txt
|
||||
use_space_char: False
|
||||
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
is_filter: True
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/ic15_data/
|
||||
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
|
||||
transforms:
|
||||
- NRTRDecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- SPINAttnLabelEncode: # Class handling label
|
||||
- SPINRecResizeImg:
|
||||
image_shape: [100, 32]
|
||||
interpolation : 2
|
||||
mean: [127.5]
|
||||
std: [127.5]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 8
|
||||
drop_last: True
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/ic15_data
|
||||
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
|
||||
transforms:
|
||||
- NRTRDecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- SPINAttnLabelEncode: # Class handling label
|
||||
- SPINRecResizeImg:
|
||||
image_shape: [100, 32]
|
||||
interpolation : 2
|
||||
mean: [127.5]
|
||||
std: [127.5]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 8
|
||||
num_workers: 2
|
|
@ -66,6 +66,7 @@
|
|||
- [x] [SAR](./algorithm_rec_sar.md)
|
||||
- [x] [SEED](./algorithm_rec_seed.md)
|
||||
- [x] [SVTR](./algorithm_rec_svtr.md)
|
||||
- [x] [SPIN](./algorithm_rec_spin.md)
|
||||
|
||||
参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|
||||
|
||||
|
@ -84,6 +85,7 @@
|
|||
|SAR|Resnet31| 87.20% | rec_r31_sar | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|
||||
|SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
|
||||
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|
||||
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
|
||||
|
||||
|
||||
<a name="2"></a>
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
# SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition
|
||||
|
||||
- [1. 算法简介](#1)
|
||||
- [2. 环境配置](#2)
|
||||
- [3. 模型训练、评估、预测](#3)
|
||||
- [3.1 训练](#3-1)
|
||||
- [3.2 评估](#3-2)
|
||||
- [3.3 预测](#3-3)
|
||||
- [4. 推理部署](#4)
|
||||
- [4.1 Python推理](#4-1)
|
||||
- [4.2 C++推理](#4-2)
|
||||
- [4.3 Serving服务化部署](#4-3)
|
||||
- [4.4 更多推理部署](#4-4)
|
||||
- [5. FAQ](#5)
|
||||
|
||||
<a name="1"></a>
|
||||
## 1. 算法简介
|
||||
|
||||
论文信息:
|
||||
> [SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition](https://arxiv.org/abs/2005.13117)
|
||||
> Chengwei Zhang, Yunlu Xu, Zhanzhan Cheng, Shiliang Pu, Yi Niu, Fei Wu, Futai Zou
|
||||
> AAAI, 2020
|
||||
|
||||
SPIN收录于AAAI2020。主要用于OCR识别任务。在任意形状文本识别中,矫正网络是一种较为常见的前置处理模块,但诸如RARE\ASTER\ESIR等只考虑了空间变换,并没有考虑色度变换。本文提出了一种结构Structure-Preserving Inner Offset Network (SPIN),可以在色彩空间上进行变换。该模块是可微分的,可以加入到任意识别器中。
|
||||
使用MJSynth和SynthText两个合成文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下:
|
||||
|
||||
|模型|骨干网络|配置文件|Acc|下载链接|
|
||||
| --- | --- | --- | --- | --- |
|
||||
|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon|
|
||||
|
||||
|
||||
<a name="2"></a>
|
||||
## 2. 环境配置
|
||||
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
|
||||
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. 模型训练、评估、预测
|
||||
|
||||
请参考[文本识别教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练不同的识别模型只需要**更换配置文件**即可。
|
||||
|
||||
训练
|
||||
|
||||
具体地,在完成数据准备后,便可以启动训练,训练命令如下:
|
||||
|
||||
```
|
||||
#单卡训练(训练周期长,不建议)
|
||||
python3 tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
|
||||
|
||||
#多卡训练,通过--gpus参数指定卡号
|
||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
|
||||
```
|
||||
|
||||
评估
|
||||
|
||||
```
|
||||
# GPU 评估, Global.pretrained_model 为待测权重
|
||||
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
|
||||
```
|
||||
|
||||
预测:
|
||||
|
||||
```
|
||||
# 预测使用的配置文件必须与训练一致
|
||||
python3 tools/infer_rec.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
|
||||
```
|
||||
|
||||
<a name="4"></a>
|
||||
## 4. 推理部署
|
||||
|
||||
<a name="4-1"></a>
|
||||
### 4.1 Python推理
|
||||
首先将SPIN文本识别训练过程中保存的模型,转换成inference model。可以使用如下命令进行转换:
|
||||
|
||||
```
|
||||
python3 tools/export_model.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r32_gaspin_bilstm_att
|
||||
```
|
||||
SPIN文本识别模型推理,可以执行如下命令:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r32_gaspin_bilstm_att/" --rec_image_shape="3, 32, 100" --rec_algorithm="SPIN" --rec_char_dict_path="./ppocr/utils/dict/spin_dict.txt" --use_space_char=False
|
||||
```
|
||||
|
||||
<a name="4-2"></a>
|
||||
### 4.2 C++推理
|
||||
|
||||
由于C++预处理后处理还未支持SPIN,所以暂未支持
|
||||
|
||||
<a name="4-3"></a>
|
||||
### 4.3 Serving服务化部署
|
||||
|
||||
暂不支持
|
||||
|
||||
<a name="4-4"></a>
|
||||
### 4.4 更多推理部署
|
||||
|
||||
暂不支持
|
||||
|
||||
<a name="5"></a>
|
||||
## 5. FAQ
|
||||
|
||||
|
||||
## 引用
|
||||
|
||||
```bibtex
|
||||
@article{2020SPIN,
|
||||
title={SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition},
|
||||
author={Chengwei Zhang and Yunlu Xu and Zhanzhan Cheng and Shiliang Pu and Yi Niu and Fei Wu and Futai Zou},
|
||||
journal={AAAI2020},
|
||||
year={2020},
|
||||
}
|
||||
```
|
|
@ -65,6 +65,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
|
|||
- [x] [SAR](./algorithm_rec_sar_en.md)
|
||||
- [x] [SEED](./algorithm_rec_seed_en.md)
|
||||
- [x] [SVTR](./algorithm_rec_svtr_en.md)
|
||||
- [x] [SPIN](./algorithm_rec_spin_en.md)
|
||||
|
||||
Following the training and evaluation protocol of [DTRB](https://arxiv.org/abs/1904.01906), the text recognition algorithms above are trained on MJSynth and SynthText and evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP and CUTE. The results are as follows:
|
||||
|
||||
|
@ -83,6 +84,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|
|||
|SAR|Resnet31| 87.20% | rec_r31_sar | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|
||||
|SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
|
||||
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|
||||
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
|
||||
|
||||
|
||||
<a name="2"></a>
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
# SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition
|
||||
|
||||
- [1. Introduction](#1)
|
||||
- [2. Environment](#2)
|
||||
- [3. Model Training / Evaluation / Prediction](#3)
|
||||
- [3.1 Training](#3-1)
|
||||
- [3.2 Evaluation](#3-2)
|
||||
- [3.3 Prediction](#3-3)
|
||||
- [4. Inference and Deployment](#4)
|
||||
- [4.1 Python Inference](#4-1)
|
||||
- [4.2 C++ Inference](#4-2)
|
||||
- [4.3 Serving](#4-3)
|
||||
- [4.4 More](#4-4)
|
||||
- [5. FAQ](#5)
|
||||
|
||||
<a name="1"></a>
|
||||
## 1. Introduction
|
||||
|
||||
Paper:
|
||||
> [SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition](https://arxiv.org/abs/2005.13117)
|
||||
> Chengwei Zhang, Yunlu Xu, Zhanzhan Cheng, Shiliang Pu, Yi Niu, Fei Wu, Futai Zou
|
||||
> AAAI, 2020
|
||||
|
||||
The model is trained on two synthetic text recognition datasets, MJSynth and SynthText, and evaluated on the IIIT, SVT, IC03, IC13, IC15, SVTP and CUTE datasets. The reproduced results are as follows:
|
||||
|
||||
|Model|Backbone|config|Acc|Download link|
|
||||
| --- | --- | --- | --- | --- |
|
||||
|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon|
|
||||
|
||||
|
||||
<a name="2"></a>
|
||||
## 2. Environment
|
||||
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
|
||||
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. Model Training / Evaluation / Prediction
|
||||
|
||||
Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
|
||||
|
||||
Training:
|
||||
|
||||
Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
|
||||
|
||||
```
|
||||
#Single GPU training (long training period, not recommended)
|
||||
python3 tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
|
||||
|
||||
#Multi GPU training, specify the gpu number through the --gpus parameter
|
||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
|
||||
```
|
||||
|
||||
Evaluation:
|
||||
|
||||
```
|
||||
# GPU evaluation
|
||||
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
|
||||
```
|
||||
|
||||
Prediction:
|
||||
|
||||
```
|
||||
# The configuration file used for prediction must match the training
|
||||
python3 tools/infer_rec.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
|
||||
```
|
||||
|
||||
<a name="4"></a>
|
||||
## 4. Inference and Deployment
|
||||
|
||||
<a name="4-1"></a>
|
||||
### 4.1 Python Inference
|
||||
First, convert the model saved during SPIN text recognition training into an inference model. You can use the following command to convert it:
|
||||
|
||||
```
|
||||
python3 tools/export_model.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r32_gaspin_bilstm_att
|
||||
```
|
||||
|
||||
For SPIN text recognition model inference, the following commands can be executed:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r32_gaspin_bilstm_att/" --rec_image_shape="3, 32, 100" --rec_algorithm="SPIN" --rec_char_dict_path="./ppocr/utils/dict/spin_dict.txt" --use_space_char=False
|
||||
```
|
||||
|
||||
<a name="4-2"></a>
|
||||
### 4.2 C++ Inference
|
||||
|
||||
Not supported
|
||||
|
||||
<a name="4-3"></a>
|
||||
### 4.3 Serving
|
||||
|
||||
Not supported
|
||||
|
||||
<a name="4-4"></a>
|
||||
### 4.4 More
|
||||
|
||||
Not supported
|
||||
|
||||
<a name="5"></a>
|
||||
## 5. FAQ
|
||||
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{2020SPIN,
|
||||
title={SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition},
|
||||
author={Chengwei Zhang and Yunlu Xu and Zhanzhan Cheng and Shiliang Pu and Yi Niu and Fei Wu and Futai Zou},
|
||||
journal={AAAI2020},
|
||||
year={2020},
|
||||
}
|
||||
```
|
|
@ -0,0 +1,131 @@
|
|||
D:\Projects\3rdparty\anaconda\envs\pd2\lib\site-packages\socks.py:58: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
|
||||
from collections import Callable
|
||||
D:\Projects\3rdparty\anaconda\envs\pd2\lib\site-packages\win32\lib\pywintypes.py:2: DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses
|
||||
import imp, sys, os
|
||||
D:\Projects\3rdparty\anaconda\envs\pd2\lib\site-packages\pkg_resources\_vendor\pyparsing.py:943: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
|
||||
collections.MutableMapping.register(ParseResults)
|
||||
D:\Projects\3rdparty\anaconda\envs\pd2\lib\site-packages\pkg_resources\_vendor\pyparsing.py:3245: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
|
||||
elif isinstance( exprs, collections.Iterable ):
|
||||
[2022/06/12 13:42:08] ppocr INFO: Architecture :
|
||||
[2022/06/12 13:42:08] ppocr INFO: Backbone :
|
||||
[2022/06/12 13:42:08] ppocr INFO: name : ResNet32
|
||||
[2022/06/12 13:42:08] ppocr INFO: out_channels : 512
|
||||
[2022/06/12 13:42:08] ppocr INFO: Head :
|
||||
[2022/06/12 13:42:08] ppocr INFO: hidden_size : 256
|
||||
[2022/06/12 13:42:08] ppocr INFO: name : SPINAttentionHead
|
||||
[2022/06/12 13:42:08] ppocr INFO: Neck :
|
||||
[2022/06/12 13:42:08] ppocr INFO: encoder_type : cascadernn
|
||||
[2022/06/12 13:42:08] ppocr INFO: hidden_size : 256
|
||||
[2022/06/12 13:42:08] ppocr INFO: name : SequenceEncoder
|
||||
[2022/06/12 13:42:08] ppocr INFO: out_channels : [256, 512]
|
||||
[2022/06/12 13:42:08] ppocr INFO: with_linear : True
|
||||
[2022/06/12 13:42:08] ppocr INFO: Transform :
|
||||
[2022/06/12 13:42:08] ppocr INFO: default_type : 6
|
||||
[2022/06/12 13:42:08] ppocr INFO: loc_lr : 0.1
|
||||
[2022/06/12 13:42:08] ppocr INFO: name : GA_SPIN
|
||||
[2022/06/12 13:42:08] ppocr INFO: offsets : True
|
||||
[2022/06/12 13:42:08] ppocr INFO: stn : True
|
||||
[2022/06/12 13:42:08] ppocr INFO: algorithm : SPIN
|
||||
[2022/06/12 13:42:08] ppocr INFO: in_channels : 1
|
||||
[2022/06/12 13:42:08] ppocr INFO: model_type : rec
|
||||
[2022/06/12 13:42:08] ppocr INFO: Eval :
|
||||
[2022/06/12 13:42:08] ppocr INFO: dataset :
|
||||
[2022/06/12 13:42:08] ppocr INFO: data_dir : ./train_data/ic15_data
|
||||
[2022/06/12 13:42:08] ppocr INFO: label_file_list : ['./train_data/ic15_data/rec_gt_test.txt']
|
||||
[2022/06/12 13:42:08] ppocr INFO: name : SimpleDataSet
|
||||
[2022/06/12 13:42:08] ppocr INFO: transforms :
|
||||
[2022/06/12 13:42:08] ppocr INFO: NRTRDecodeImage :
|
||||
[2022/06/12 13:42:08] ppocr INFO: channel_first : False
|
||||
[2022/06/12 13:42:08] ppocr INFO: img_mode : BGR
|
||||
[2022/06/12 13:42:08] ppocr INFO: SPINAttnLabelEncode : None
|
||||
[2022/06/12 13:42:08] ppocr INFO: SPINRecResizeImg :
|
||||
[2022/06/12 13:42:08] ppocr INFO: image_shape : [100, 32]
|
||||
[2022/06/12 13:42:08] ppocr INFO: interpolation : 2
|
||||
[2022/06/12 13:42:08] ppocr INFO: mean : [127.5]
|
||||
[2022/06/12 13:42:08] ppocr INFO: std : [127.5]
|
||||
[2022/06/12 13:42:08] ppocr INFO: KeepKeys :
|
||||
[2022/06/12 13:42:08] ppocr INFO: keep_keys : ['image', 'label', 'length']
|
||||
[2022/06/12 13:42:08] ppocr INFO: loader :
|
||||
[2022/06/12 13:42:08] ppocr INFO: batch_size_per_card : 8
|
||||
[2022/06/12 13:42:08] ppocr INFO: drop_last : False
|
||||
[2022/06/12 13:42:08] ppocr INFO: num_workers : 2
|
||||
[2022/06/12 13:42:08] ppocr INFO: shuffle : False
|
||||
[2022/06/12 13:42:08] ppocr INFO: Global :
|
||||
[2022/06/12 13:42:08] ppocr INFO: cal_metric_during_train : True
|
||||
[2022/06/12 13:42:08] ppocr INFO: character_dict_path : ./ppocr/utils/dict/spin_dict.txt
|
||||
[2022/06/12 13:42:08] ppocr INFO: checkpoints : ./inference/rec_r32_gaspin_bilstm_att/best_accuracy
|
||||
[2022/06/12 13:42:08] ppocr INFO: distributed : False
|
||||
[2022/06/12 13:42:08] ppocr INFO: epoch_num : 6
|
||||
[2022/06/12 13:42:08] ppocr INFO: eval_batch_step : [0, 2000]
|
||||
[2022/06/12 13:42:08] ppocr INFO: infer_img : doc/imgs_words/ch/word_1.jpg
|
||||
[2022/06/12 13:42:08] ppocr INFO: infer_mode : False
|
||||
[2022/06/12 13:42:08] ppocr INFO: log_smooth_window : 50
|
||||
[2022/06/12 13:42:08] ppocr INFO: max_text_length : 25
|
||||
[2022/06/12 13:42:08] ppocr INFO: pretrained_model : None
|
||||
[2022/06/12 13:42:08] ppocr INFO: print_batch_step : 50
|
||||
[2022/06/12 13:42:08] ppocr INFO: save_epoch_step : 3
|
||||
[2022/06/12 13:42:08] ppocr INFO: save_inference_dir : None
|
||||
[2022/06/12 13:42:08] ppocr INFO: save_model_dir : ./output/rec/rec_r32_gaspin_bilstm_att/
|
||||
[2022/06/12 13:42:08] ppocr INFO: save_res_path : ./output/rec/predicts_r32_gaspin_bilstm_att.txt
|
||||
[2022/06/12 13:42:08] ppocr INFO: use_gpu : True
|
||||
[2022/06/12 13:42:08] ppocr INFO: use_space_char : False
|
||||
[2022/06/12 13:42:08] ppocr INFO: use_visualdl : False
|
||||
[2022/06/12 13:42:08] ppocr INFO: Loss :
|
||||
[2022/06/12 13:42:08] ppocr INFO: ignore_index : 0
|
||||
[2022/06/12 13:42:08] ppocr INFO: name : SPINAttentionLoss
|
||||
[2022/06/12 13:42:08] ppocr INFO: Metric :
|
||||
[2022/06/12 13:42:08] ppocr INFO: is_filter : True
|
||||
[2022/06/12 13:42:08] ppocr INFO: main_indicator : acc
|
||||
[2022/06/12 13:42:08] ppocr INFO: name : RecMetric
|
||||
[2022/06/12 13:42:08] ppocr INFO: Optimizer :
|
||||
[2022/06/12 13:42:08] ppocr INFO: beta1 : 0.9
|
||||
[2022/06/12 13:42:08] ppocr INFO: beta2 : 0.999
|
||||
[2022/06/12 13:42:08] ppocr INFO: clip_norm : 5
|
||||
[2022/06/12 13:42:08] ppocr INFO: lr :
|
||||
[2022/06/12 13:42:08] ppocr INFO: decay_epochs : [3, 4, 5]
|
||||
[2022/06/12 13:42:08] ppocr INFO: name : Piecewise
|
||||
[2022/06/12 13:42:08] ppocr INFO: values : [0.001, 0.0003, 9e-05, 2.7e-05]
|
||||
[2022/06/12 13:42:08] ppocr INFO: name : AdamW
|
||||
[2022/06/12 13:42:08] ppocr INFO: PostProcess :
|
||||
[2022/06/12 13:42:08] ppocr INFO: character_dict_path : ./ppocr/utils/dict/spin_dict.txt
|
||||
[2022/06/12 13:42:08] ppocr INFO: name : SPINAttnLabelDecode
|
||||
[2022/06/12 13:42:08] ppocr INFO: use_space_char : False
|
||||
[2022/06/12 13:42:08] ppocr INFO: Train :
|
||||
[2022/06/12 13:42:08] ppocr INFO: dataset :
|
||||
[2022/06/12 13:42:08] ppocr INFO: data_dir : ./train_data/ic15_data/
|
||||
[2022/06/12 13:42:08] ppocr INFO: label_file_list : ['./train_data/ic15_data/rec_gt_train.txt']
|
||||
[2022/06/12 13:42:08] ppocr INFO: name : SimpleDataSet
|
||||
[2022/06/12 13:42:08] ppocr INFO: transforms :
|
||||
[2022/06/12 13:42:08] ppocr INFO: NRTRDecodeImage :
|
||||
[2022/06/12 13:42:08] ppocr INFO: channel_first : False
|
||||
[2022/06/12 13:42:08] ppocr INFO: img_mode : BGR
|
||||
[2022/06/12 13:42:08] ppocr INFO: SPINAttnLabelEncode : None
|
||||
[2022/06/12 13:42:08] ppocr INFO: SPINRecResizeImg :
|
||||
[2022/06/12 13:42:08] ppocr INFO: image_shape : [100, 32]
|
||||
[2022/06/12 13:42:08] ppocr INFO: interpolation : 2
|
||||
[2022/06/12 13:42:08] ppocr INFO: mean : [127.5]
|
||||
[2022/06/12 13:42:08] ppocr INFO: std : [127.5]
|
||||
[2022/06/12 13:42:08] ppocr INFO: KeepKeys :
|
||||
[2022/06/12 13:42:08] ppocr INFO: keep_keys : ['image', 'label', 'length']
|
||||
[2022/06/12 13:42:08] ppocr INFO: loader :
|
||||
[2022/06/12 13:42:08] ppocr INFO: batch_size_per_card : 8
|
||||
[2022/06/12 13:42:08] ppocr INFO: drop_last : True
|
||||
[2022/06/12 13:42:08] ppocr INFO: num_workers : 4
|
||||
[2022/06/12 13:42:08] ppocr INFO: shuffle : True
|
||||
[2022/06/12 13:42:08] ppocr INFO: profiler_options : None
|
||||
[2022/06/12 13:42:08] ppocr INFO: train with paddle 2.2.2 and device CUDAPlace(0)
|
||||
[2022/06/12 13:42:08] ppocr INFO: Initialize indexs of datasets:['./train_data/ic15_data/rec_gt_test.txt']
|
||||
W0612 13:42:08.814790 17600 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.5, Driver API Version: 11.1, Runtime API Version: 10.2
|
||||
W0612 13:42:08.832805 17600 device_context.cc:465] device: 0, cuDNN Version: 7.6.
|
||||
[2022/06/12 13:42:12] ppocr INFO: resume from ./inference/rec_r32_gaspin_bilstm_att/best_accuracy
|
||||
[2022/06/12 13:42:12] ppocr INFO: metric in ckpt ***************
|
||||
[2022/06/12 13:42:12] ppocr INFO: acc:0.90589541082154
|
||||
[2022/06/12 13:42:12] ppocr INFO: norm_edit_dis:0.9627389225663741
|
||||
[2022/06/12 13:42:12] ppocr INFO: fps:1802.1068940938283
|
||||
[2022/06/12 13:42:12] ppocr INFO: best_epoch:6
|
||||
[2022/06/12 13:42:12] ppocr INFO: start_epoch:7
|
||||
eval model:: 0%| | 0/2 [00:00<?, ?it/s]
eval model:: 50%|█████ | 1/2 [00:00<00:00, 4.67it/s]
|
||||
[2022/06/12 13:42:12] ppocr INFO: metric eval ***************
|
||||
[2022/06/12 13:42:12] ppocr INFO: acc:0.9999987500015626
|
||||
[2022/06/12 13:42:12] ppocr INFO: norm_edit_dis:1.0
|
||||
[2022/06/12 13:42:12] ppocr INFO: fps:57.50210270524082
|
|
@ -23,7 +23,8 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
|
|||
from .make_pse_gt import MakePseGt
|
||||
|
||||
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
|
||||
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
|
||||
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, \
|
||||
SPINRecResizeImg
|
||||
from .ssl_img_aug import SSLRotateResize
|
||||
from .randaugment import RandAugment
|
||||
from .copy_paste import CopyPaste
|
||||
|
|
|
@ -1044,3 +1044,52 @@ class MultiLabelEncode(BaseRecLabelEncode):
|
|||
data_out['label_sar'] = sar['label']
|
||||
data_out['length'] = ctc['length']
|
||||
return data_out
|
||||
|
||||
|
||||
class SPINAttnLabelEncode(BaseRecLabelEncode):
    """Convert between text-label and text-index for the SPIN attention head.

    The encoded label has the fixed length ``max_text_len + 2`` and the layout
    ``[0] + encoded_text + [1]``, right-padded with ``0``.
    """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 lower=True,
                 **kwargs):
        """
        Args:
            max_text_length (int): forwarded to the base class
            character_dict_path (str): forwarded to the base class
            use_space_char (bool): forwarded to the base class
            lower (bool): stored on the instance as ``self.lower`` after the
                base class has been initialised
            **kwargs: accepted and ignored
        """
        super(SPINAttnLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char)
        self.lower = lower

    def add_special_char(self, dict_character):
        """Prepend the begin ("sos") and end ("eos") markers to the dictionary.

        Args:
            dict_character (list): the character list to extend

        Returns:
            list: ``["sos", "eos"] + dict_character``
        """
        self.beg_str = "sos"
        self.end_str = "eos"
        dict_character = [self.beg_str] + [self.end_str] + dict_character
        return dict_character

    def __call__(self, data):
        """Encode ``data['label']`` in place.

        Args:
            data (dict): sample holding the raw text under ``'label'``

        Returns:
            dict or None: the same dict with ``'label'`` replaced by the padded
            index array and ``'length'`` set to the encoded text length
            (special markers excluded); ``None`` when the base-class ``encode``
            returns ``None`` or the text is longer than ``self.max_text_len``.
        """
        text = self.encode(data['label'])
        if text is None:
            return None
        if len(text) > self.max_text_len:
            return None
        data['length'] = np.array(len(text))
        # 0 and 1 are the positions at which add_special_char places "sos" and
        # "eos"; presumably the base class numbers the dictionary in that
        # order -- TODO confirm against BaseRecLabelEncode.
        target = [0] + text + [1]
        # Index 0 doubles as the padding value.
        padded_text = [0] * (self.max_text_len + 2)
        padded_text[:len(target)] = target
        data['label'] = np.array(padded_text)
        return data

    def get_ignored_tokens(self):
        """Return ``[beg_idx, end_idx]``, the indices of the two markers."""
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Look up the dictionary index of the begin or end marker.

        Args:
            beg_or_end (str): either ``"beg"`` or ``"end"``

        Returns:
            np.ndarray: 0-d array holding ``self.dict[marker]``

        Raises:
            AssertionError: if ``beg_or_end`` is neither ``"beg"`` nor ``"end"``
        """
        if beg_or_end == "beg":
            return np.array(self.dict[self.beg_str])
        if beg_or_end == "end":
            return np.array(self.dict[self.end_str])
        # Raised explicitly instead of `assert False`: an assert is stripped
        # under `python -O`, which would let execution fall through to a
        # return of an unassigned local.
        raise AssertionError("Unsupport type %s in get_beg_end_flag_idx" %
                             beg_or_end)
|
|
@ -267,6 +267,51 @@ class PRENResizeImg(object):
|
|||
data['image'] = resized_img.astype(np.float32)
|
||||
return data
|
||||
|
||||
class SPINRecResizeImg(object):
    """Resize an image to a fixed size and normalise it for the SPIN recognizer.

    The image is resized with OpenCV, converted to float32, given a leading
    channel axis and normalised as ``(img - mean) / std``.
    """

    def __init__(self,
                 image_shape,
                 interpolation=2,
                 mean=(127.5, 127.5, 127.5),
                 std=(127.5, 127.5, 127.5),
                 **kwargs):
        """
        Args:
            image_shape (list): target size, passed to ``cv2.resize`` as
                ``dsize`` (OpenCV interprets ``dsize`` as (width, height))
            interpolation (int): 0 nearest, 1 linear, 2 cubic, 3 area
            mean (sequence of float): value(s) subtracted from the image
            std (sequence of float): value(s) the image is divided by
            **kwargs: accepted and ignored
        """
        self.image_shape = image_shape

        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)
        self.interpolation = interpolation

    def _resolve_interpolation(self):
        """Map ``self.interpolation`` to the matching OpenCV flag."""
        if self.interpolation == 0:
            return cv2.INTER_NEAREST
        if self.interpolation == 1:
            return cv2.INTER_LINEAR
        if self.interpolation == 2:
            return cv2.INTER_CUBIC
        if self.interpolation == 3:
            return cv2.INTER_AREA
        raise ValueError("Unsupported interpolation type !!!")

    def __call__(self, data):
        """Resize and normalise ``data['image']`` in place.

        Args:
            data (dict): sample holding the image under ``'image'``

        Returns:
            dict or None: the same dict with ``'image'`` replaced by a float32
            array with a leading channel axis; ``None`` if the image is
            ``None`` (image loading failed).

        Raises:
            ValueError: if ``self.interpolation`` is not one of 0, 1, 2, 3.
                The mode is validated before the ``None`` check.
        """
        img = data['image']
        interpolation = self._resolve_interpolation()
        # Deal with the image error during image loading
        if img is None:
            return None

        # `interpolation` must be passed by keyword: the third positional
        # parameter of cv2.resize is `dst`, so a positional flag is not used
        # as the interpolation mode.
        img = cv2.resize(
            img, tuple(self.image_shape), interpolation=interpolation)
        img = np.array(img, np.float32)
        # The 3-axis transpose below only works on a 2-D (single-channel)
        # resized image: (H, W) -> (H, W, 1) -> (1, H, W).
        img = np.expand_dims(img, -1)
        img = img.transpose((2, 0, 1)).copy()

        # normalize the image
        # NOTE(review): the statistics are reshaped to (1, N) and broadcast
        # against the trailing axes of the (1, H, W) image, so single-value
        # mean/std broadcast cleanly; the 3-value defaults look inherited from
        # a colour pipeline -- confirm.
        mean = np.float64(self.mean.reshape(1, -1))
        stdinv = 1 / np.float64(self.std.reshape(1, -1))
        img -= mean
        img *= stdinv
        data['image'] = img
        return data
|
||||
|
||||
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
|
||||
imgC, imgH, imgW_min, imgW_max = image_shape
|
||||
|
|
|
@ -35,6 +35,7 @@ from .rec_sar_loss import SARLoss
|
|||
from .rec_aster_loss import AsterLoss
|
||||
from .rec_pren_loss import PRENLoss
|
||||
from .rec_multi_loss import MultiLoss
|
||||
from .rec_spin_att_loss import SPINAttentionLoss
|
||||
|
||||
# cls loss
|
||||
from .cls_loss import ClsLoss
|
||||
|
@ -61,7 +62,8 @@ def build_loss(config):
|
|||
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
|
||||
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
|
||||
'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
|
||||
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss'
|
||||
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
|
||||
'SPINAttentionLoss'
|
||||
]
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class SPINAttentionLoss(nn.Layer):
    """Cross-entropy loss between attention-decoder outputs and label indices."""

    def __init__(self, reduction='mean', ignore_index=-100, **kwargs):
        """
        Args:
            reduction (str): forwarded to ``nn.CrossEntropyLoss``
            ignore_index (int): forwarded to ``nn.CrossEntropyLoss``
            **kwargs: accepted and ignored
        """
        super(SPINAttentionLoss, self).__init__()
        self.loss_func = nn.CrossEntropyLoss(
            weight=None, reduction=reduction, ignore_index=ignore_index)

    def forward(self, predicts, batch):
        """
        Args:
            predicts: model output; its last axis is used as the class axis
                (assumes shape [N, num_steps, num_classes] -- TODO confirm)
            batch: sequence whose element 1 holds the label indices

        Returns:
            dict: ``{'loss': value}``
        """
        targets = batch[1].astype("int64")
        # Drop the first column of the label, presumably the leading start
        # ("sos") token, which the decoder is not asked to predict -- TODO
        # confirm against the label encoder.
        targets = targets[:, 1:]

        assert len(targets.shape) == len(list(predicts.shape)) - 1, \
            "The target's shape and inputs's shape is [N, d] and [N, num_steps]"

        # Flatten batch and step axes so every decoding step is one sample.
        inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]])
        targets = paddle.reshape(targets, [-1])

        return {'loss': self.loss_func(inputs, targets)}
|
|
@ -32,10 +32,11 @@ def build_backbone(config, model_type):
|
|||
from .rec_micronet import MicroNet
|
||||
from .rec_efficientb3_pren import EfficientNetb3_PREN
|
||||
from .rec_svtrnet import SVTRNet
|
||||
from .rec_resnet_32 import ResNet32
|
||||
support_dict = [
|
||||
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
|
||||
"ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN',
|
||||
'SVTRNet'
|
||||
'SVTRNet', 'ResNet32'
|
||||
]
|
||||
elif model_type == "e2e":
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
|
|
|
@ -0,0 +1,289 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is adapted from:
|
||||
https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/backbones/ResNet32.py
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
__all__ = ["ResNet32"]
|
||||
|
||||
conv_weight_attr = nn.initializer.KaimingNormal()
|
||||
|
||||
class ResNet32(nn.Layer):
    """
    Feature extractor proposed in FAN, Ref [1].

    Ref [1]: Focusing Attention: Towards Accurate Text Recognition in Neural Images ICCV-2017
    """

    def __init__(self, in_channels, out_channels=512):
        """
        Args:
            in_channels (int): input channel
            out_channels (int): output channel
        """
        super(ResNet32, self).__init__()
        self.out_channels = out_channels
        # Forwarded unchanged to ResNet; presumably the number of BasicBlocks
        # per stage -- TODO confirm against the ResNet definition.
        layers = [1, 2, 5, 3]
        self.ConvNet = ResNet(in_channels, out_channels, BasicBlock, layers)

    def forward(self, inputs):
        """
        Args:
            inputs (Tensor): input feature

        Returns:
            Tensor: output of the wrapped ``ConvNet``
        """
        features = self.ConvNet(inputs)
        return features
|
||||
|
||||
class BasicBlock(nn.Layer):
    """Residual basic block: two 3x3 conv + BatchNorm stages plus a shortcut."""
    expansion = 1

    def __init__(self, inplanes, planes,
                 stride=1, downsample=None,
                 norm_type='BN', **kwargs):
        """
        Args:
            inplanes (int): input channels
            planes (int): channels produced by both convolutions
            stride (int): stride of the first convolution
            downsample (nn.Layer | None): optional layer applied to the input
                to produce the shortcut branch; ``None`` uses the input as is
            norm_type (str): accepted for interface compatibility; not read
                by this block (BatchNorm2D is always used)
            **kwargs: accepted and ignored
        """
        super(BasicBlock, self).__init__()
        # The stride must be applied on the main path as well. Otherwise a
        # strided ``downsample`` shortcut yields a smaller map than the main
        # path and the residual addition in ``forward`` fails on shape.
        self.conv1 = self._conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2D(planes)
        self.conv2 = self._conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2D(planes)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def _conv3x3(self, in_planes, out_planes, stride=1):
        """
        Build a bias-free 3x3 convolution with padding 1.

        Args:
            in_planes (int): input channels
            out_planes (int): output channels
            stride (int): stride of the convolution

        Returns:
            nn.Conv2D: the convolution, initialised with ``conv_weight_attr``
        """
        return nn.Conv2D(in_planes, out_planes,
                         kernel_size=3, stride=stride,
                         padding=1, weight_attr=conv_weight_attr,
                         bias_attr=False)

    def forward(self, x):
        """
        Args:
            x (paddle.Tensor): input feature

        Returns:
            paddle.Tensor: ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)``
        """
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)

        return out
|
||||
|
||||
class ResNet(nn.Layer):
    """Four-stage residual network with interleaved conv/BN/ReLU and pooling."""

    def __init__(self, input_channel,
                 output_channel, block, layers):
        """
        Args:
            input_channel (int): input channels
            output_channel (int): channels of the last two stages; earlier
                stages use 1/4 and 1/2 of it, the stem uses 1/16 and 1/8
            block (type): residual block class; it must expose ``expansion``
                and accept ``(inplanes, planes, stride, downsample)``
            layers (list): number of blocks in each of the four stages
        """
        super().__init__()

        stem_width = int(output_channel / 16)
        self.output_channel_block = [int(output_channel / 4),
                                     int(output_channel / 2),
                                     output_channel,
                                     output_channel]
        width1, width2, width3, width4 = self.output_channel_block

        # Attribute names and creation order are kept stable because they
        # define the parameter names of the layer.
        self.inplanes = int(output_channel / 8)
        self.conv0_1 = self._conv3x3_same(input_channel, stem_width)
        self.bn0_1 = nn.BatchNorm2D(stem_width)
        self.conv0_2 = self._conv3x3_same(stem_width, self.inplanes)
        self.bn0_2 = nn.BatchNorm2D(self.inplanes)
        self.relu = nn.ReLU()

        self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
        self.layer1 = self._make_layer(block, width1, layers[0])
        self.conv1 = self._conv3x3_same(width1, width1)
        self.bn1 = nn.BatchNorm2D(width1)

        self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
        self.layer2 = self._make_layer(block, width2, layers[1], stride=1)
        self.conv2 = self._conv3x3_same(width2, width2)
        self.bn2 = nn.BatchNorm2D(width2)

        # From here on the first spatial axis is strided by 2 while the
        # second is kept at stride 1 (with padding 1).
        self.maxpool3 = nn.MaxPool2D(kernel_size=2,
                                     stride=(2, 1),
                                     padding=(0, 1))
        self.layer3 = self._make_layer(block, width3, layers[2], stride=1)
        self.conv3 = self._conv3x3_same(width3, width3)
        self.bn3 = nn.BatchNorm2D(width3)

        self.layer4 = self._make_layer(block, width4, layers[3], stride=1)
        self.conv4_1 = nn.Conv2D(width4, width4,
                                 kernel_size=2, stride=(2, 1),
                                 padding=(0, 1),
                                 weight_attr=conv_weight_attr,
                                 bias_attr=False)
        self.bn4_1 = nn.BatchNorm2D(width4)
        self.conv4_2 = nn.Conv2D(width4, width4,
                                 kernel_size=2, stride=1,
                                 padding=0,
                                 weight_attr=conv_weight_attr,
                                 bias_attr=False)
        self.bn4_2 = nn.BatchNorm2D(width4)

    @staticmethod
    def _conv3x3_same(in_planes, out_planes):
        """Bias-free 3x3 convolution with stride 1 and padding 1."""
        return nn.Conv2D(in_planes, out_planes,
                         kernel_size=3, stride=1,
                         padding=1,
                         weight_attr=conv_weight_attr,
                         bias_attr=False)

    def _make_layer(self, block, planes, blocks, stride=1):
        """
        Stack ``blocks`` residual blocks and update ``self.inplanes``.

        Args:
            block (type): residual block class
            planes (int): channels of the stage (before ``block.expansion``)
            blocks (int): number of blocks to stack
            stride (int): stride handed to the first block and to the
                projection shortcut

        Returns:
            nn.Sequential: the stacked blocks
        """
        stage_width = planes * block.expansion

        # A 1x1 conv + BN projection is needed whenever the shortcut cannot
        # be the identity (different stride or channel count).
        shortcut = None
        if not (stride == 1 and self.inplanes == stage_width):
            shortcut = nn.Sequential(
                nn.Conv2D(self.inplanes, stage_width,
                          kernel_size=1, stride=stride,
                          weight_attr=conv_weight_attr,
                          bias_attr=False),
                nn.BatchNorm2D(stage_width),
            )

        stack = [block(self.inplanes, planes, stride, shortcut)]
        self.inplanes = stage_width
        stack.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*stack)

    def forward(self, x):
        """
        Args:
            x (paddle.Tensor): input feature

        Returns:
            paddle.Tensor: output feature of the network
        """
        # stem
        x = self.relu(self.bn0_1(self.conv0_1(x)))
        x = self.relu(self.bn0_2(self.conv0_2(x)))

        # stage 1
        x = self.layer1(self.maxpool1(x))
        x = self.relu(self.bn1(self.conv1(x)))

        # stage 2
        x = self.layer2(self.maxpool2(x))
        x = self.relu(self.bn2(self.conv2(x)))

        # stage 3
        x = self.layer3(self.maxpool3(x))
        x = self.relu(self.bn3(self.conv3(x)))

        # stage 4
        x = self.layer4(x)
        x = self.relu(self.bn4_1(self.conv4_1(x)))
        x = self.relu(self.bn4_2(self.conv4_2(x)))
        return x
|
|
@ -33,6 +33,7 @@ def build_head(config):
|
|||
from .rec_aster_head import AsterHead
|
||||
from .rec_pren_head import PRENHead
|
||||
from .rec_multi_head import MultiHead
|
||||
from .rec_spin_att_head import SPINAttentionHead
|
||||
|
||||
# cls head
|
||||
from .cls_head import ClsHead
|
||||
|
@ -46,7 +47,7 @@ def build_head(config):
|
|||
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
|
||||
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
|
||||
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
|
||||
'MultiHead'
|
||||
'MultiHead', 'SPINAttentionHead'
|
||||
]
|
||||
|
||||
#table head
|
||||
|
|
|
@ -0,0 +1,203 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SPINAttentionHead(nn.Layer):
    """Attention decoder head built on ``AttentionLSTMCell`` and a linear classifier."""

    def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
        """
        Args:
            in_channels (int): feature size of the encoder output
            out_channels (int): number of classes; also the one-hot width
            hidden_size (int): hidden size of the LSTM cell
            **kwargs: accepted and ignored
        """
        super(SPINAttentionHead, self).__init__()
        self.input_size = in_channels
        self.hidden_size = hidden_size
        self.num_classes = out_channels

        self.attention_cell = AttentionLSTMCell(
            in_channels, hidden_size, out_channels, use_gru=False)
        self.generator = nn.Linear(hidden_size, out_channels)

    def _char_to_onehot(self, input_char, onehot_dim):
        """One-hot encode ``input_char`` with ``onehot_dim`` classes."""
        return F.one_hot(input_char, onehot_dim)

    def forward(self, inputs, targets=None, batch_max_length=25):
        """
        Run ``batch_max_length + 1`` decoding steps.

        Args:
            inputs (paddle.Tensor): encoder features; the first dimension is
                the batch size
            targets: required in training mode. ``targets[0]`` is used as the
                label index tensor and column ``i`` feeds step ``i``
                (teacher forcing). Ignored in evaluation mode.
            batch_max_length (int): maximum text length

        Returns:
            paddle.Tensor: per-step class scores stacked along axis 1. In
            training mode these are the raw classifier outputs; in
            evaluation mode softmax is applied over the class axis.
        """
        batch_size = paddle.shape(inputs)[0]
        num_steps = batch_max_length + 1  # one step more than the text length

        hidden = (paddle.zeros((batch_size, self.hidden_size)),
                  paddle.zeros((batch_size, self.hidden_size)))

        if self.training:
            targets = targets[0]
            step_hiddens = []
            for i in range(num_steps):
                char_onehots = self._char_to_onehot(
                    targets[:, i], onehot_dim=self.num_classes)
                (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
                                                               char_onehots)
                step_hiddens.append(paddle.unsqueeze(outputs, axis=1))
            return self.generator(paddle.concat(step_hiddens, axis=1))

        # Greedy decoding: start from index 0 and feed back the arg-max of
        # every step. Step outputs are collected in a list and concatenated
        # once, instead of re-concatenating the growing tensor at every step.
        targets = paddle.zeros(shape=[batch_size], dtype="int32")
        step_scores = []
        for _ in range(num_steps):
            char_onehots = self._char_to_onehot(
                targets, onehot_dim=self.num_classes)
            (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
                                                           char_onehots)
            probs_step = self.generator(outputs)
            step_scores.append(paddle.unsqueeze(probs_step, axis=1))
            targets = probs_step.argmax(axis=1)

        probs = paddle.concat(step_scores, axis=1)
        return paddle.nn.functional.softmax(probs, axis=2)
|
||||
|
||||
|
||||
class AttentionGRUCell(nn.Layer):
    """Additive-attention step followed by a GRU cell update."""

    def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
        """
        Args:
            input_size (int): feature size of the attended sequence
            hidden_size (int): hidden size of the GRU cell
            num_embeddings (int): width of the one-hot character input
            use_gru (bool): accepted but not read; a GRU cell is always built
        """
        super().__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias_attr=False)

        rnn_input_size = input_size + num_embeddings
        self.rnn = nn.GRUCell(
            input_size=rnn_input_size, hidden_size=hidden_size)

        self.hidden_size = hidden_size

    def forward(self, prev_hidden, batch_H, char_onehots):
        """
        Args:
            prev_hidden (paddle.Tensor): previous hidden state
            batch_H (paddle.Tensor): features to attend over
                (presumably batch x T x input_size -- TODO confirm)
            char_onehots (paddle.Tensor): one-hot input of the current step

        Returns:
            tuple: (result of the GRU cell call, attention weights ``alpha``)
        """
        feat_proj = self.i2h(batch_H)
        state_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)

        # additive attention score for every position along axis 1
        energy = self.score(paddle.tanh(paddle.add(feat_proj, state_proj)))

        alpha = paddle.transpose(F.softmax(energy, axis=1), [0, 2, 1])
        context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
        rnn_input = paddle.concat([context, char_onehots], 1)

        return self.rnn(rnn_input, prev_hidden), alpha
|
||||
|
||||
|
||||
class AttentionLSTM(nn.Layer):
    """Attention decoder that runs ``AttentionLSTMCell`` for a fixed number of steps."""

    def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
        """
        Args:
            in_channels (int): feature size of the encoder output
            out_channels (int): number of classes; also the one-hot width
            hidden_size (int): hidden size of the LSTM cell
            **kwargs: accepted and ignored
        """
        super().__init__()
        self.input_size = in_channels
        self.hidden_size = hidden_size
        self.num_classes = out_channels

        self.attention_cell = AttentionLSTMCell(
            in_channels, hidden_size, out_channels, use_gru=False)
        self.generator = nn.Linear(hidden_size, out_channels)

    def _char_to_onehot(self, input_char, onehot_dim):
        """One-hot encode ``input_char`` with ``onehot_dim`` classes."""
        return F.one_hot(input_char, onehot_dim)

    def forward(self, inputs, targets=None, batch_max_length=25):
        """
        Args:
            inputs (paddle.Tensor): encoder features; ``inputs.shape[0]`` is
                taken as the batch size
            targets (paddle.Tensor | None): label indices; when given,
                column ``i`` feeds step ``i`` (teacher forcing). When
                ``None`` the arg-max of each step is fed to the next one.
            batch_max_length (int): number of decoding steps

        Returns:
            paddle.Tensor: classifier outputs of all steps stacked on axis 1
        """
        batch_size = inputs.shape[0]
        num_steps = batch_max_length

        state = (paddle.zeros((batch_size, self.hidden_size)),
                 paddle.zeros((batch_size, self.hidden_size)))

        if targets is not None:
            step_hiddens = []
            for step in range(num_steps):
                # one-hot vectors of the step-th ground-truth character
                char_onehots = self._char_to_onehot(
                    targets[:, step], onehot_dim=self.num_classes)
                cell_out, alpha = self.attention_cell(state, inputs,
                                                      char_onehots)
                # the second element of the cell result holds the new states
                state = (cell_out[1][0], cell_out[1][1])
                step_hiddens.append(paddle.unsqueeze(state[0], axis=1))
            return self.generator(paddle.concat(step_hiddens, axis=1))

        targets = paddle.zeros(shape=[batch_size], dtype="int32")
        probs = None
        for step in range(num_steps):
            char_onehots = self._char_to_onehot(
                targets, onehot_dim=self.num_classes)
            cell_out, alpha = self.attention_cell(state, inputs,
                                                  char_onehots)
            probs_step = self.generator(cell_out[0])
            state = (cell_out[1][0], cell_out[1][1])

            expanded = paddle.unsqueeze(probs_step, axis=1)
            probs = expanded if probs is None else paddle.concat(
                [probs, expanded], axis=1)

            # greedy feedback for the next step
            targets = probs_step.argmax(axis=1)

        return probs
|
||||
|
||||
|
||||
class AttentionLSTMCell(nn.Layer):
    """Additive-attention step followed by an LSTM (or GRU) cell update."""

    def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
        """
        Args:
            input_size (int): feature size of the attended sequence
            hidden_size (int): hidden size of the recurrent cell
            num_embeddings (int): width of the one-hot character input
            use_gru (bool): build ``nn.GRUCell`` instead of ``nn.LSTMCell``
        """
        super().__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias_attr=False)

        cell_cls = nn.GRUCell if use_gru else nn.LSTMCell
        self.rnn = cell_cls(
            input_size=input_size + num_embeddings, hidden_size=hidden_size)

        self.hidden_size = hidden_size

    def forward(self, prev_hidden, batch_H, char_onehots):
        """
        Args:
            prev_hidden (tuple): previous states; element 0 drives the
                attention query and the whole tuple is handed to the cell
            batch_H (paddle.Tensor): features to attend over
                (presumably batch x T x input_size -- TODO confirm)
            char_onehots (paddle.Tensor): one-hot input of the current step

        Returns:
            tuple: (result of the recurrent cell call, attention weights ``alpha``)
        """
        feat_proj = self.i2h(batch_H)
        state_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)

        # additive attention score for every position along axis 1
        energy = self.score(paddle.tanh(paddle.add(feat_proj, state_proj)))

        alpha = paddle.transpose(F.softmax(energy, axis=1), [0, 2, 1])
        context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
        rnn_input = paddle.concat([context, char_onehots], 1)

        return self.rnn(rnn_input, prev_hidden), alpha
|
|
@ -47,6 +47,67 @@ class EncoderWithRNN(nn.Layer):
|
|||
x, _ = self.lstm(x)
|
||||
return x
|
||||
|
||||
class BidirectionalLSTM(nn.Layer):
    """``nn.LSTM`` optionally followed by a linear projection."""

    def __init__(self, input_size,
                 hidden_size,
                 output_size=None,
                 num_layers=1,
                 dropout=0,
                 direction=False,
                 time_major=False,
                 with_linear=False):
        """
        Args:
            input_size (int): feature size of the input sequence
            hidden_size (int): hidden size of the LSTM
            output_size (int | None): output width of the linear layer;
                only read when ``with_linear`` is True
            num_layers (int): forwarded to ``nn.LSTM``
            dropout (float): forwarded to ``nn.LSTM``
            direction: forwarded to ``nn.LSTM`` unchanged.
                NOTE(review): the default ``False`` is presumably not a
                direction value ``nn.LSTM`` accepts -- confirm; the caller
                visible in this file passes ``'bidirectional'``.
            time_major (bool): forwarded to ``nn.LSTM``
            with_linear (bool): append ``nn.Linear(hidden_size * 2, output_size)``
        """
        super().__init__()
        self.with_linear = with_linear
        lstm_options = dict(num_layers=num_layers,
                            dropout=dropout,
                            direction=direction,
                            time_major=time_major)
        self.rnn = nn.LSTM(input_size, hidden_size, **lstm_options)

        # The projection input is sized for two concatenated directions.
        if self.with_linear:
            self.linear = nn.Linear(hidden_size * 2, output_size)

    def forward(self, input_feature):
        """
        Args:
            input_feature (paddle.Tensor): visual feature
                [batch_size x T x input_size]

        Returns:
            paddle.Tensor: the LSTM output, passed through the linear layer
            when ``with_linear`` is set
        """
        recurrent, _ = self.rnn(input_feature)
        if not self.with_linear:
            return recurrent
        return self.linear(recurrent)
|
||||
|
||||
class EncoderWithCascadeRNN(nn.Layer):
    """Chain of ``BidirectionalLSTM`` stages applied one after another."""

    def __init__(self, in_channels, hidden_size, out_channels, num_layers=2, with_linear=False):
        """
        Args:
            in_channels (int): feature size fed to the first stage
            hidden_size (int): hidden size of every stage
            out_channels (list): per-stage ``output_size``; stage ``i + 1``
                takes ``out_channels[i]`` as its input size
            num_layers (int): number of stages
            with_linear (bool): forwarded to every stage
        """
        super().__init__()
        self.out_channels = out_channels[-1]

        stages = []
        stage_in = in_channels
        for idx in range(num_layers):
            stages.append(
                BidirectionalLSTM(
                    stage_in,
                    hidden_size,
                    output_size=out_channels[idx],
                    num_layers=1,
                    direction='bidirectional',
                    with_linear=with_linear))
            stage_in = out_channels[idx]
        self.encoder = nn.LayerList(stages)

    def forward(self, x):
        """Feed ``x`` through every stage in order and return the result."""
        for stage in self.encoder:
            x = stage(x)
        return x
|
||||
|
||||
|
||||
class EncoderWithFC(nn.Layer):
|
||||
def __init__(self, in_channels, hidden_size):
|
||||
|
@ -166,13 +227,17 @@ class SequenceEncoder(nn.Layer):
|
|||
'reshape': Im2Seq,
|
||||
'fc': EncoderWithFC,
|
||||
'rnn': EncoderWithRNN,
|
||||
'svtr': EncoderWithSVTR
|
||||
'svtr': EncoderWithSVTR,
|
||||
'cascadernn': EncoderWithCascadeRNN
|
||||
}
|
||||
assert encoder_type in support_encoder_dict, '{} must in {}'.format(
|
||||
encoder_type, support_encoder_dict.keys())
|
||||
if encoder_type == "svtr":
|
||||
self.encoder = support_encoder_dict[encoder_type](
|
||||
self.encoder_reshape.out_channels, **kwargs)
|
||||
elif encoder_type == 'cascadernn':
|
||||
self.encoder = support_encoder_dict[encoder_type](
|
||||
self.encoder_reshape.out_channels, hidden_size, **kwargs)
|
||||
else:
|
||||
self.encoder = support_encoder_dict[encoder_type](
|
||||
self.encoder_reshape.out_channels, hidden_size)
|
||||
|
|
|
@ -18,8 +18,10 @@ __all__ = ['build_transform']
|
|||
def build_transform(config):
|
||||
from .tps import TPS
|
||||
from .stn import STN_ON
|
||||
from .gaspin_transformer import GA_SPIN_Transformer as GA_SPIN
|
||||
|
||||
support_dict = ['TPS', 'STN_ON']
|
||||
|
||||
support_dict = ['TPS', 'STN_ON', 'GA_SPIN']
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception(
|
||||
|
|
|
@ -0,0 +1,286 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import paddle
|
||||
from paddle import nn, ParamAttr
|
||||
from paddle.nn import functional as F
|
||||
import numpy as np
|
||||
import itertools
|
||||
import functools
|
||||
from .tps import GridGenerator
|
||||
|
||||
'''This code is adapted from:
|
||||
https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/transformations/gaspin_transformation.py
|
||||
'''
|
||||
|
||||
class SP_TransformerNetwork(nn.Layer):
    r"""
    Structure-Preserving Transformation (SPT) as Equa. (2) in Ref. [1]

    Ref: [1] SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition. AAAI-2021.
    """

    def __init__(self, nc=1, default_type=5):
        r""" Based on SPIN

        Args:
            nc (int): number of input channels (usually 1 or 3)
            default_type (int): the complexity K of the transformation
                intensities (defaults to 5 here)
        """
        super().__init__()
        self.power_list = self.cal_K(default_type)
        self.sigmoid = nn.Sigmoid()
        # Named ``bn`` but built as instance normalisation.
        self.bn = nn.InstanceNorm2D(nc)

    def cal_K(self, k=5):
        r"""
        Build the list of power exponents.

        Args:
            k (int): the complexity of the transformation intensities

        Returns:
            List: ``2 * k + 1`` exponents. For every ``i`` in ``1..k`` a value
            rounded to two decimals and its rounded reciprocal are added,
            followed by a final ``1.00``; denoted as \beta [1x(2K+1)]
        """
        exponents = []
        for i in range(1, k + 1):
            ratio = (0.5 / (k + 1)) * i
            lower = round(math.log(1 - ratio) / math.log(ratio), 2)
            exponents.extend([lower, round(1 / lower, 2)])
        exponents.append(1.00)
        return exponents

    def forward(self, batch_I, weights, offsets, lambda_color=None):
        r"""
        Args:
            batch_I (paddle.Tensor): batch of input images
                [batch_size x nc x I_height x I_width]
            weights: per-exponent weights, one per entry of ``power_list``
            offsets: the offsets predicted by AIN, or ``None`` to skip blending
            lambda_color: the learnable update gate \alpha in Equa. (5) as
                g(x) = (1 - \alpha) \odot x + \alpha \odot x_{offsets}

        Returns:
            paddle.Tensor: transformed images by SPN as Equa. (4) in Ref. [1]
                [batch_size x I_channel_num x I_r_height x I_r_width]
        """
        # presumably maps inputs from [-1, 1] to [0, 1] -- TODO confirm range
        image = (batch_I + 1) * 0.5
        if offsets is not None:
            image = image * (1 - lambda_color) + offsets * lambda_color

        weight_params = paddle.unsqueeze(paddle.unsqueeze(weights, -1), -1)
        powered = paddle.stack(
            [image.pow(p) for p in self.power_list], axis=1)

        # weighted sum over the exponent axis, then normalise and squash
        fused = paddle.sum(powered * weight_params, axis=1)
        fused = self.sigmoid(self.bn(fused))
        return fused * 2 - 1
|
||||
|
||||
class GA_SPIN_Transformer(nn.Layer):
    r"""
    Geometric-Absorbed SPIN Transformation (GA-SPIN) proposed in Ref. [1]

    Ref: [1] SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition. AAAI-2021.
    """

    def __init__(self, in_channels=1,
                 I_r_size=(32, 100),
                 offsets=False,
                 norm_type='BN',
                 default_type=6,
                 loc_lr=1,
                 stn=True):
        r"""
        Args:
            in_channels (int): channel of input features,
                                set it to 1 for grayscale images and 3 for RGB input
            I_r_size (tuple): size of rectified images (used in STN transformations)
            offsets (bool): set it to False to use SPN w.o. AIN,
                            and to True to use SPIN (both SPN and AIN)
            norm_type (str): the normalization type of the module,
                            'BN' by default, 'IN' optionally; anything else
                            raises NotImplementedError
            default_type (int): the K chromatic space,
                                set it to 3/5/6 depending on the complexity of
                                transformation intensities
            loc_lr (float): learning rate of the localization layer
            stn (bool): True for GA-SPIN (adds the grid-sampling branch),
                        False for plain SPIN
        """
        super().__init__()
        self.nc = in_channels
        self.spt = True
        self.offsets = offsets
        self.stn = stn  # set to True in GA-SPIN, while set it to False in SPIN
        self.I_r_size = I_r_size
        self.out_channels = in_channels

        if norm_type == 'BN':
            norm_layer = functools.partial(nn.BatchNorm2D, use_global_stats=True)
        elif norm_type == 'IN':
            # NOTE(review): confirm that InstanceNorm2D accepts
            # ``use_global_stats``; kept exactly as in the original.
            norm_layer = functools.partial(nn.InstanceNorm2D, weight_attr=False,
                                           use_global_stats=False)
        else:
            raise NotImplementedError('normalization layer [%s] is not found' % norm_type)

        def conv_norm_relu(c_in, c_out):
            """3x3 bias-free conv -> norm -> ReLU, created in that order."""
            return [nn.Conv2D(c_in, c_out, 3, 1, 1, bias_attr=False),
                    norm_layer(c_out), nn.ReLU()]

        if self.spt:
            self.sp_net = SP_TransformerNetwork(in_channels,
                                                default_type)
            # size comments follow the default 32*100 input
            self.spt_convnet = nn.Sequential(
                # 32*100
                *conv_norm_relu(in_channels, 32),
                nn.MaxPool2D(kernel_size=2, stride=2),
                # 16*50
                *conv_norm_relu(32, 64),
                nn.MaxPool2D(kernel_size=2, stride=2),
                # 8*25
                *conv_norm_relu(64, 128),
                nn.MaxPool2D(kernel_size=2, stride=2),
                # 4*12
            )
            self.stucture_fc1 = nn.Sequential(
                *conv_norm_relu(128, 256),
                nn.MaxPool2D(kernel_size=2, stride=2),
                *conv_norm_relu(256, 256),  # 2*6
                nn.MaxPool2D(kernel_size=2, stride=2),
                *conv_norm_relu(256, 512),  # 1*3
                nn.AdaptiveAvgPool2D(1),
                nn.Flatten(1, -1),  # batch_size x 512
                # NOTE(review): 0.001 is the first positional argument of
                # Normal -- confirm whether mean or std was intended.
                nn.Linear(512, 256, weight_attr=nn.initializer.Normal(0.001)),
                nn.BatchNorm1D(256), nn.ReLU()
            )

            # Layout of the predicted vector: 2K+1 power weights, then one
            # gate value when offsets are used, then F*2 control-point
            # coordinates when the STN branch is enabled.
            self.out_weight = 2 * default_type + 1
            self.spt_length = 2 * default_type + 1
            if offsets:
                self.out_weight += 1
            if self.stn:
                self.F = 20
                self.out_weight += self.F * 2
                self.GridGenerator = GridGenerator(self.F * 2, self.F)

            # Zero weights plus the bias from init_spin: the layer initially
            # outputs that bias regardless of its input.
            initial_bias = self.init_spin(default_type * 2).reshape(-1)
            param_attr = ParamAttr(
                learning_rate=loc_lr,
                initializer=nn.initializer.Assign(np.zeros([256, self.out_weight])))
            bias_attr = ParamAttr(
                learning_rate=loc_lr,
                initializer=nn.initializer.Assign(initial_bias))
            self.stucture_fc2 = nn.Linear(256, self.out_weight,
                                          weight_attr=param_attr,
                                          bias_attr=bias_attr)
            self.sigmoid = nn.Sigmoid()

            if offsets:
                self.offset_fc1 = nn.Sequential(*conv_norm_relu(128, 16))
                self.offset_fc2 = nn.Conv2D(16, in_channels,
                                            3, 1, 1)
                self.pool = nn.MaxPool2D(2, 2)

    def init_spin(self, nz):
        r"""
        Build the initial bias of ``stucture_fc2``.

        Args:
            nz (int): number of paired \betas exponents, which means the value of K x 2

        Returns:
            np.ndarray: ``nz`` zeros and a 5.0, then -5.0 when offsets are
            used, then the flattened initial control points when the STN
            branch is enabled
        """
        leading = [0.00] * nz + [5.00]
        if self.offsets:
            leading += [-5.00]
        init = np.array(leading)

        if self.stn:
            half = int(self.F / 2)
            xs = np.linspace(-1.0, 1.0, half)
            top = np.stack([xs, np.linspace(0.0, -1.0, num=half)], axis=1)
            bottom = np.stack([xs, np.linspace(1.0, 0.0, num=half)], axis=1)
            ctrl_pts = np.concatenate([top, bottom], axis=0).reshape(-1)
            init = np.concatenate([init, ctrl_pts], axis=0)
        return init

    def forward(self, x, return_weight=False):
        r"""
        Args:
            x (paddle.Tensor): input image batch
            return_weight (bool): set to False by default,
                                  if set to True (and offsets are enabled) return
                                  the predicted offsets of AIN, denoted as x_{offsets}

        Returns:
            paddle.Tensor: rectified image [batch_size x I_channel_num x I_height x I_width]
        """
        if self.spt:
            feat = self.spt_convnet(x)
            fused = self.stucture_fc2(self.stucture_fc1(feat))
            fused = fused.reshape([x.shape[0], self.out_weight, 1])
            sp_weight = fused[:, :self.spt_length, :]

            if self.offsets:  # SPIN w. AIN
                lambda_color = fused[:, self.spt_length, 0]
                lambda_color = self.sigmoid(lambda_color).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                offsets = self.pool(self.offset_fc2(self.offset_fc1(feat)))

                # the offset map is expected to be 2 x 6 at this point
                assert offsets.shape[2] == 2  # 2
                assert offsets.shape[3] == 6  # 16
                offsets = self.sigmoid(offsets)  # v12

                if return_weight:
                    return offsets
                offsets = nn.functional.upsample(offsets, size=(x.shape[2], x.shape[3]), mode='bilinear')
                # control points start after the gate value
                ctrl_start = self.spt_length + 1
            else:  # SPIN w.o. AIN
                lambda_color, offsets = None, None
                ctrl_start = self.spt_length

            if self.stn:
                batch_C_prime = fused[:, ctrl_start:, :].reshape([x.shape[0], self.F, 2])
                build_P_prime = self.GridGenerator(batch_C_prime, self.I_r_size)
                sampling_grid = build_P_prime.reshape([build_P_prime.shape[0],
                                                       self.I_r_size[0],
                                                       self.I_r_size[1],
                                                       2])

            x = self.sp_net(x, sp_weight, offsets, lambda_color)
            if self.stn:
                x = F.grid_sample(x=x, grid=sampling_grid, padding_mode='border')
        return x
|
|
@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
|
|||
from .fce_postprocess import FCEPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
|
||||
DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
|
||||
SEEDLabelDecode, PRENLabelDecode
|
||||
SEEDLabelDecode, PRENLabelDecode, SPINAttnLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
from .pg_postprocess import PGPostProcess
|
||||
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
|
||||
|
@ -42,7 +42,7 @@ def build_post_process(config, global_config=None):
|
|||
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
|
||||
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
|
||||
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
|
||||
'DistillationSARLabelDecode'
|
||||
'DistillationSARLabelDecode', 'SPINAttnLabelDecode'
|
||||
]
|
||||
|
||||
if config['name'] == 'PSEPostProcess':
|
||||
|
|
|
@ -752,3 +752,82 @@ class PRENLabelDecode(BaseRecLabelDecode):
|
|||
return text
|
||||
label = self.decode(label)
|
||||
return text, label
|
||||
|
||||
class SPINAttnLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, character_dict_path=None, use_space_char=False,
             **kwargs):
    # Forwards the dictionary path and the space-character flag to the
    # parent initializer; any extra keyword arguments are accepted and
    # ignored.
    super(SPINAttnLabelDecode, self).__init__(character_dict_path,
                                              use_space_char)
|
||||
|
||||
def add_special_char(self, dict_character):
    """Record the begin/end markers and return them prepended to the dictionary.

    Sets ``self.beg_str`` to "sos" and ``self.end_str`` to "eos"; the
    returned list starts with those two entries, in that order.
    """
    self.beg_str = "sos"
    self.end_str = "eos"
    return [self.beg_str, self.end_str] + dict_character
|
||||
|
||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||
""" convert text-index into text-label. """
|
||||
result_list = []
|
||||
ignored_tokens = self.get_ignored_tokens()
|
||||
[beg_idx, end_idx] = self.get_ignored_tokens()
|
||||
batch_size = len(text_index)
|
||||
for batch_idx in range(batch_size):
|
||||
char_list = []
|
||||
conf_list = []
|
||||
for idx in range(len(text_index[batch_idx])):
|
||||
if text_index[batch_idx][idx] == int(beg_idx):
|
||||
continue
|
||||
if int(text_index[batch_idx][idx]) == int(end_idx):
|
||||
break
|
||||
if is_remove_duplicate:
|
||||
# only for predict
|
||||
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
||||
batch_idx][idx]:
|
||||
continue
|
||||
char_list.append(self.character[int(text_index[batch_idx][
|
||||
idx])])
|
||||
if text_prob is not None:
|
||||
conf_list.append(text_prob[batch_idx][idx])
|
||||
else:
|
||||
conf_list.append(1)
|
||||
text = ''.join(char_list)
|
||||
result_list.append((text.lower(), np.mean(conf_list).tolist()))
|
||||
return result_list
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
"""
|
||||
text = self.decode(text)
|
||||
if label is None:
|
||||
return text
|
||||
else:
|
||||
label = self.decode(label, is_remove_duplicate=False)
|
||||
return text, label
|
||||
"""
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
preds_prob = preds.max(axis=2)
|
||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label, is_remove_duplicate=False)
|
||||
return text, label
|
||||
|
||||
def get_ignored_tokens(self):
|
||||
beg_idx = self.get_beg_end_flag_idx("beg")
|
||||
end_idx = self.get_beg_end_flag_idx("end")
|
||||
return [beg_idx, end_idx]
|
||||
|
||||
def get_beg_end_flag_idx(self, beg_or_end):
|
||||
if beg_or_end == "beg":
|
||||
idx = np.array(self.dict[self.beg_str])
|
||||
elif beg_or_end == "end":
|
||||
idx = np.array(self.dict[self.end_str])
|
||||
else:
|
||||
assert False, "unsupport type %s in get_beg_end_flag_idx" \
|
||||
% beg_or_end
|
||||
return idx
|
|
@ -0,0 +1,68 @@
|
|||
0
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
a
|
||||
b
|
||||
c
|
||||
d
|
||||
e
|
||||
f
|
||||
g
|
||||
h
|
||||
i
|
||||
j
|
||||
k
|
||||
l
|
||||
m
|
||||
n
|
||||
o
|
||||
p
|
||||
q
|
||||
r
|
||||
s
|
||||
t
|
||||
u
|
||||
v
|
||||
w
|
||||
x
|
||||
y
|
||||
z
|
||||
:
|
||||
(
|
||||
'
|
||||
-
|
||||
,
|
||||
%
|
||||
>
|
||||
.
|
||||
[
|
||||
?
|
||||
)
|
||||
"
|
||||
=
|
||||
_
|
||||
*
|
||||
]
|
||||
;
|
||||
&
|
||||
+
|
||||
$
|
||||
@
|
||||
/
|
||||
|
|
||||
!
|
||||
<
|
||||
#
|
||||
`
|
||||
{
|
||||
~
|
||||
\
|
||||
}
|
||||
^
|
|
@ -0,0 +1,118 @@
|
|||
Global:
|
||||
use_gpu: True
|
||||
epoch_num: 6
|
||||
log_smooth_window: 50
|
||||
print_batch_step: 50
|
||||
save_model_dir: ./output/rec/rec_r32_gaspin_bilstm_att/
|
||||
save_epoch_step: 3
|
||||
# evaluation is run every 2000 iterations after the 0th iteration
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
# for data or label process
|
||||
character_dict_path: ./ppocr/utils/dict/spin_dict.txt
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_r32_gaspin_bilstm_att.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: AdamW
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Piecewise
|
||||
decay_epochs: [3, 4, 5]
|
||||
values: [0.001, 0.0003, 0.00009, 0.000027]
|
||||
|
||||
clip_norm: 5
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: SPIN
|
||||
in_channels: 1
|
||||
Transform:
|
||||
name: GA_SPIN
|
||||
offsets: True
|
||||
default_type: 6
|
||||
loc_lr: 0.1
|
||||
stn: True
|
||||
Backbone:
|
||||
name: ResNet32
|
||||
out_channels: 512
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: cascadernn
|
||||
hidden_size: 256
|
||||
out_channels: [256, 512]
|
||||
with_linear: True
|
||||
Head:
|
||||
name: SPINAttentionHead
|
||||
hidden_size: 256
|
||||
|
||||
|
||||
Loss:
|
||||
name: SPINAttentionLoss
|
||||
ignore_index: 0
|
||||
|
||||
PostProcess:
|
||||
name: SPINAttnLabelDecode
|
||||
character_dict_path: ./ppocr/utils/dict/spin_dict.txt
|
||||
use_space_char: False
|
||||
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
is_filter: True
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/ic15_data/
|
||||
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
|
||||
transforms:
|
||||
- NRTRDecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- SPINAttnLabelEncode: # Class handling label
|
||||
- SPINRecResizeImg:
|
||||
image_shape: [100, 32]
|
||||
interpolation : 2
|
||||
mean: [127.5]
|
||||
std: [127.5]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 128
|
||||
drop_last: True
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/ic15_data
|
||||
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
|
||||
transforms:
|
||||
- NRTRDecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- SPINAttnLabelEncode: # Class handling label
|
||||
- SPINRecResizeImg:
|
||||
image_shape: [100, 32]
|
||||
interpolation : 2
|
||||
mean: [127.5]
|
||||
std: [127.5]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1
|
||||
num_workers: 1
|
|
@ -0,0 +1,53 @@
|
|||
===========================train_params===========================
|
||||
model_name:rec_r32_gaspin_bilstm_att
|
||||
python:python
|
||||
gpu_list:0|0,1
|
||||
Global.use_gpu:True|True
|
||||
Global.auto_cast:null
|
||||
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
|
||||
Global.save_model_dir:./output/
|
||||
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
|
||||
Global.pretrained_model:null
|
||||
train_model_name:latest
|
||||
train_infer_img_dir:./inference/rec_inference
|
||||
null:null
|
||||
##
|
||||
trainer:norm_train
|
||||
norm_train:tools/train.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o
|
||||
pact_train:null
|
||||
fpgm_train:null
|
||||
distill_train:null
|
||||
null:null
|
||||
null:null
|
||||
##
|
||||
===========================eval_params===========================
|
||||
eval:tools/eval.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o
|
||||
null:null
|
||||
##
|
||||
===========================infer_params===========================
|
||||
Global.save_inference_dir:./output/
|
||||
Global.checkpoints:
|
||||
norm_export:tools/export_model.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o
|
||||
quant_export:null
|
||||
fpgm_export:null
|
||||
distill_export:null
|
||||
export1:null
|
||||
export2:null
|
||||
##
|
||||
train_model:./inference/rec_r32_gaspin_bilstm_att/best_accuracy
|
||||
infer_export:tools/export_model.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o
|
||||
infer_quant:False
|
||||
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict/spin_dict.txt --use_space_char=False --rec_image_shape="3,32,100" --rec_algorithm="SPIN"
|
||||
--use_gpu:True|False
|
||||
--enable_mkldnn:True|False
|
||||
--cpu_threads:1|6
|
||||
--rec_batch_num:1|6
|
||||
--use_tensorrt:False|False
|
||||
--precision:fp32|int8
|
||||
--rec_model_dir:
|
||||
--image_dir:./inference/rec_inference
|
||||
--save_log_path:./test/output/
|
||||
--benchmark:True
|
||||
null:null
|
||||
===========================infer_benchmark_params==========================
|
||||
random_infer_input:[{float32,[3,32,100]}]
|
|
@ -73,6 +73,12 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
|
|||
shape=[None, 3, 64, 512], dtype="float32"),
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
elif arch_config["algorithm"] == "SPIN":
|
||||
other_shape = [
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 1, 32, 100], dtype="float32"),
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
else:
|
||||
infer_shape = [3, -1, -1]
|
||||
if arch_config["model_type"] == "rec":
|
||||
|
|
|
@ -69,6 +69,12 @@ class TextRecognizer(object):
|
|||
"character_dict_path": args.rec_char_dict_path,
|
||||
"use_space_char": args.use_space_char
|
||||
}
|
||||
elif self.rec_algorithm == "SPIN":
|
||||
postprocess_params = {
|
||||
'name': 'SPINAttnLabelDecode',
|
||||
"character_dict_path": args.rec_char_dict_path,
|
||||
"use_space_char": args.use_space_char
|
||||
}
|
||||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
||||
utility.create_predictor(args, 'rec', logger)
|
||||
|
@ -250,6 +256,22 @@ class TextRecognizer(object):
|
|||
|
||||
return padding_im, resize_shape, pad_shape, valid_ratio
|
||||
|
||||
def resize_norm_img_spin(self, img):
    """Prepare one BGR image for the SPIN recognizer.

    The image is converted to grayscale, resized to 100x32 (width x height),
    turned into a float32 array with a leading channel axis, and normalised
    as ``(pixel - 127.5) * (1 / 127.5)``.

    Returns:
        float32 array of shape (1, 32, 100).
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # NOTE(review): per the OpenCV Python signature
    # resize(src, dsize[, dst[, fx[, fy[, interpolation]]]]) the third
    # positional argument is `dst`, so INTER_CUBIC is probably not applied
    # as the interpolation mode here. Kept as-is to preserve behaviour --
    # confirm against the training-time resize before changing.
    gray = cv2.resize(gray, (100, 32), cv2.INTER_CUBIC)

    # (H, W) -> (H, W, 1) -> (1, H, W)
    chw = np.expand_dims(np.array(gray, np.float32), -1).transpose((2, 0, 1))

    # Single-channel statistics, shaped (1, 1) so they broadcast over chw.
    mean = np.float32(np.array([127.5], dtype=np.float32).reshape(1, -1))
    std = np.float32(np.array([127.5], dtype=np.float32).reshape(1, -1))
    stdinv = 1 / std

    chw -= mean
    chw *= stdinv
    return chw
|
||||
def __call__(self, img_list):
|
||||
img_num = len(img_list)
|
||||
# Calculate the aspect ratio of all text bars
|
||||
|
@ -300,6 +322,10 @@ class TextRecognizer(object):
|
|||
self.rec_image_shape)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
elif self.rec_algorithm == 'SPIN':
|
||||
norm_img = self.resize_norm_img_spin(img_list[indices[ino]])
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
else:
|
||||
norm_img = self.resize_norm_img(img_list[indices[ino]],
|
||||
max_wh_ratio)
|
||||
|
|
|
@ -207,7 +207,7 @@ def train(config,
|
|||
model.train()
|
||||
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
|
||||
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN"]
|
||||
extra_input = False
|
||||
if config['Architecture']['algorithm'] == 'Distillation':
|
||||
for key in config['Architecture']["Models"]:
|
||||
|
@ -564,7 +564,8 @@ def preprocess(is_train=False):
|
|||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
|
||||
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR'
|
||||
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR',
|
||||
'SPIN'
|
||||
]
|
||||
|
||||
if use_xpu:
|
||||
|
|
Loading…
Reference in New Issue