Merge pull request #7741 from zhiminzhang0830/rfl_branch
add text recognition algorithm rflearning
commit 823a8391f1
@@ -0,0 +1,112 @@
Global:
  use_gpu: True
  epoch_num: 6
  log_smooth_window: 20
  print_batch_step: 50
  save_model_dir: ./output/rec/rec_resnet_rfl_att/
  save_epoch_step: 1
  # evaluation is run every 5000 iterations, starting from iteration 0
  eval_batch_step: [0, 5000]
  cal_metric_during_train: True
  pretrained_model: ./pretrain_models/rec_resnet_rfl_visual/best_accuracy.pdparams
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_words_en/word_10.png
  # for data or label process
  character_dict_path:
  max_text_length: 25
  infer_mode: False
  use_space_char: False
  save_res_path: ./output/rec/rec_resnet_rfl.txt


Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  weight_decay: 0.0
  clip_norm_global: 5.0
  lr:
    name: Piecewise
    decay_epochs: [3, 4, 5]
    values: [0.001, 0.0003, 0.00009, 0.000027]

Architecture:
  model_type: rec
  algorithm: RFL
  in_channels: 1
  Transform:
    name: TPS
    num_fiducial: 20
    loc_lr: 1.0
    model_name: large
  Backbone:
    name: ResNetRFL
    use_cnt: True
    use_seq: True
  Neck:
    name: RFAdaptor
    use_v2s: True
    use_s2v: True
  Head:
    name: RFLHead
    in_channels: 512
    hidden_size: 256
    batch_max_legnth: 25
    out_channels: 38
    use_cnt: True
    use_seq: True

Loss:
  name: RFLLoss
  # ignore_index: 0

PostProcess:
  name: RFLLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:
  dataset:
    name: LMDBDataSet
    data_dir: ./train_data/data_lmdb_release/training
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - RFLLabelEncode: # Class handling label
      - RFLRecResizeImg:
          image_shape: [1, 32, 100]
          padding: false
          interpolation: 2
      - KeepKeys:
          keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 64
    drop_last: True
    num_workers: 8

Eval:
  dataset:
    name: LMDBDataSet
    data_dir: ./train_data/data_lmdb_release/validation/
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - RFLLabelEncode: # Class handling label
      - RFLRecResizeImg:
          image_shape: [1, 32, 100]
          padding: false
          interpolation: 2
      - KeepKeys:
          keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 256
    num_workers: 8
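For reference, a minimal sketch of how a `Piecewise` schedule with the `decay_epochs` and `values` above maps an epoch to a learning rate. The `piecewise_lr` helper is hypothetical; the real schedule is constructed by PaddleOCR's optimizer builder from this config.

```python
import bisect

def piecewise_lr(epoch, decay_epochs=(3, 4, 5),
                 values=(0.001, 0.0003, 0.00009, 0.000027)):
    # The number of decay boundaries already passed indexes the value list.
    return values[bisect.bisect_right(decay_epochs, epoch)]

assert piecewise_lr(0) == 0.001     # epochs 0-2 use the base LR
assert piecewise_lr(3) == 0.0003    # first decay at epoch 3
assert piecewise_lr(5) == 0.000027  # final value from epoch 5 onward
```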
@@ -0,0 +1,110 @@
Global:
  use_gpu: True
  epoch_num: 6
  log_smooth_window: 20
  print_batch_step: 50
  save_model_dir: ./output/rec/rec_resnet_rfl_visual/
  save_epoch_step: 1
  # evaluation is run every 5000 iterations, starting from iteration 0
  eval_batch_step: [0, 5000]
  cal_metric_during_train: False
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_words_en/word_10.png
  # for data or label process
  character_dict_path:
  max_text_length: 25
  infer_mode: False
  use_space_char: False
  save_res_path: ./output/rec/rec_resnet_rfl_visual.txt


Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  weight_decay: 0.0
  clip_norm_global: 5.0
  lr:
    name: Piecewise
    decay_epochs: [3, 4, 5]
    values: [0.001, 0.0003, 0.00009, 0.000027]

Architecture:
  model_type: rec
  algorithm: RFL
  in_channels: 1
  Transform:
    name: TPS
    num_fiducial: 20
    loc_lr: 1.0
    model_name: large
  Backbone:
    name: ResNetRFL
    use_cnt: True
    use_seq: False
  Neck:
    name: RFAdaptor
    use_v2s: False
    use_s2v: False
  Head:
    name: RFLHead
    in_channels: 512
    hidden_size: 256
    batch_max_legnth: 25
    out_channels: 38
    use_cnt: True
    use_seq: False

Loss:
  name: RFLLoss

PostProcess:
  name: RFLLabelDecode

Metric:
  name: CNTMetric
  main_indicator: acc

Train:
  dataset:
    name: LMDBDataSet
    data_dir: ./train_data/data_lmdb_release/training
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - RFLLabelEncode: # Class handling label
      - RFLRecResizeImg:
          image_shape: [1, 32, 100]
          padding: false
          interpolation: 2
      - KeepKeys:
          keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 64
    drop_last: True
    num_workers: 8

Eval:
  dataset:
    name: LMDBDataSet
    data_dir: ./train_data/data_lmdb_release/evaluation
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - RFLLabelEncode: # Class handling label
      - RFLRecResizeImg:
          image_shape: [1, 32, 100]
          padding: false
          interpolation: 2
      - KeepKeys:
          keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 256
    num_workers: 8
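A side note on `eval_batch_step: [0, 5000]`: the two numbers are the starting iteration and the evaluation interval. A rough sketch of that convention (the exact check lives in PaddleOCR's training loop, so treat this as an approximation):

```python
def should_eval(global_step, eval_batch_step=(0, 5000)):
    # eval_batch_step = [start_iter, interval]: evaluate every `interval`
    # iterations once `start_iter` has been reached.
    start, interval = eval_batch_step
    return global_step >= start and (global_step - start) % interval == 0
```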
@@ -79,6 +79,7 @@ PaddleOCR will **continuously add** support for cutting-edge OCR algorithms and models; **we welcome…
- [x] [VisionLAN](./algorithm_rec_visionlan.md)
- [x] [SPIN](./algorithm_rec_spin.md)
- [x] [RobustScanner](./algorithm_rec_robustscanner.md)
+- [x] [RFL](./algorithm_rec_rfl.md)

Following the text recognition training and evaluation pipeline of [DTRB](https://arxiv.org/abs/1904.01906)[3], the models are trained on the MJSynth and SynthText datasets and evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE. The results are as follows:

@@ -102,7 +103,7 @@ PaddleOCR will **continuously add** support for cutting-edge OCR algorithms and models; **we welcome…
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
+|RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) |

<a name="2"></a>
@@ -0,0 +1,161 @@
# Scene Text Recognition Algorithm - RFL

- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
    - [3.1 Training](#3-1)
    - [3.2 Evaluation](#3-2)
    - [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
    - [4.1 Python Inference](#4-1)
    - [4.2 C++ Inference](#4-2)
    - [4.3 Serving](#4-3)
    - [4.4 More](#4-4)
- [5. FAQ](#5)

<a name="1"></a>
## 1. Introduction

Paper:
> [Reciprocal Feature Learning via Explicit and Implicit Tasks in Scene Text Recognition](https://arxiv.org/abs/2105.06229.pdf)
> Hui Jiang, Yunlu Xu, Zhanzhan Cheng, Shiliang Pu, Yi Niu, Wenqi Ren, Fei Wu, and Wenming Tan
> ICDAR, 2021

<a name="model"></a>
`RFL` is trained on the MJSynth and SynthText text recognition datasets and evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE. The reproduced results are as follows:

|Model|Backbone|Config|Acc|Download|
| --- | --- | --- | --- | --- |
|RFL-CNT|ResNetRFL|[rec_resnet_rfl_visual.yml](../../configs/rec/rec_resnet_rfl_visual.yml)|93.40%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar)|
|RFL-Att|ResNetRFL|[rec_resnet_rfl_att.yml](../../configs/rec/rec_resnet_rfl_att.yml)|88.63%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar)|

<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone.md) to clone the project code.

<a name="3"></a>
## 3. Model Training / Evaluation / Prediction

<a name="3-1"></a>
### 3.1 Model Training

PaddleOCR is modularized; to train the `RFL` recognition model, **switch the configuration file** to RFL's [configuration file](../../configs/rec/rec_resnet_rfl_att.yml).

#### Start Training

After the data preparation is complete, start training. The training commands are as follows:
```shell
# step 1: train the CNT branch
# single-GPU training (long training period, not recommended)
python3 tools/train.py -c configs/rec/rec_resnet_rfl_visual.yml

# multi-GPU training; specify the GPU ids with the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_resnet_rfl_visual.yml

# step 2: jointly train the CNT and Att branches; set pretrained_model to your local path
# single-GPU training (long training period, not recommended)
python3 tools/train.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model=./output/rec/rec_resnet_rfl_visual/best_accuracy

# multi-GPU training; specify the GPU ids with the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model=./output/rec/rec_resnet_rfl_visual/best_accuracy
```

<a name="3-2"></a>
### 3.2 Evaluation

You can download the [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) and evaluate it with the following command:

```shell
# set pretrained_model to your local path
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model=./output/rec/rec_resnet_rfl_att/best_accuracy
```

<a name="3-3"></a>
### 3.3 Prediction

Predict a single image with the following command:
```shell
# set pretrained_model to your local path
python3 tools/infer_rec.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./output/rec/rec_resnet_rfl_att/best_accuracy
# to predict all images in a folder, set infer_img to the folder, e.g. Global.infer_img='./doc/imgs_words_en/'
```

<a name="4"></a>
## 4. Inference and Deployment

<a name="4-1"></a>
### 4.1 Python Inference
First, convert the best trained model into an inference model. Taking the released model as an example ([model download link](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar)), run:

```shell
# set pretrained_model to your local path
python3 tools/export_model.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model=./output/rec/rec_resnet_rfl_att/best_accuracy Global.save_inference_dir=./inference/rec_resnet_rfl_att/
```
**Note:**
- If you trained the model on your own dataset and modified the dictionary file, make sure `character_dict_path` in the configuration file points to the dictionary file you need.
- If you modified the input size during training, update the `infer_shape` corresponding to RFL in `tools/export_model.py`.

After a successful conversion, the directory contains three files:
```
/inference/rec_resnet_rfl_att/
    ├── inference.pdiparams         # parameter file of the recognition inference model
    ├── inference.pdiparams.info    # parameter info of the recognition inference model, can be ignored
    └── inference.pdmodel           # program file of the recognition inference model
```

Run model inference with the following command:

```shell
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_resnet_rfl_att/' --rec_algorithm='RFL' --rec_image_shape='1,32,100'
# to predict all images in a folder, set image_dir to the folder, e.g. --image_dir='./doc/imgs_words_en/'
```

![](../imgs_words_en/word_10.png)

After executing the command, the prediction result (recognized text and score) for the image above is printed to the screen, for example:
```shell
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9999927282333374)
```

**Note**:
- The image resolution used to train this model is [1, 32, 100]; set `rec_image_shape` to the image shape you trained with.
- At inference time, set `rec_char_dict_path` to specify the dictionary; if you modified the dictionary, point this parameter to your dictionary file.
- If you modified the preprocessing, change RFL's preprocessing in `tools/infer/predict_rec.py` to match your method.

<a name="4-2"></a>
### 4.2 C++ Inference

Not supported yet, because C++ pre- and post-processing for RFL is not implemented.

<a name="4-3"></a>
### 4.3 Serving

Not supported yet.

<a name="4-4"></a>
### 4.4 More

Not supported yet.

<a name="5"></a>
## 5. FAQ

## Citation

```bibtex
@inproceedings{2021Reciprocal,
  title     = {Reciprocal Feature Learning via Explicit and Implicit Tasks in Scene Text Recognition},
  author    = {Jiang, H. and Xu, Y. and Cheng, Z. and Pu, S. and Niu, Y. and Ren, W. and Wu, F. and Tan, W.},
  booktitle = {ICDAR},
  year      = {2021},
  url       = {https://arxiv.org/abs/2105.06229}
}
```
@@ -76,6 +76,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
- [x] [VisionLAN](./algorithm_rec_visionlan_en.md)
- [x] [SPIN](./algorithm_rec_spin_en.md)
- [x] [RobustScanner](./algorithm_rec_robustscanner_en.md)
+- [x] [RFL](./algorithm_rec_rfl_en.md)

Referring to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation results of the text recognition algorithms above (trained on MJSynth and SynthText, evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) are as follows:

@@ -99,7 +100,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r…
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
+|RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) |

<a name="2"></a>
@@ -0,0 +1,143 @@
# RFL

- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
    - [3.1 Training](#3-1)
    - [3.2 Evaluation](#3-2)
    - [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
    - [4.1 Python Inference](#4-1)
    - [4.2 C++ Inference](#4-2)
    - [4.3 Serving](#4-3)
    - [4.4 More](#4-4)
- [5. FAQ](#5)

<a name="1"></a>
## 1. Introduction

Paper:
> [Reciprocal Feature Learning via Explicit and Implicit Tasks in Scene Text Recognition](https://arxiv.org/abs/2105.06229.pdf)
> Hui Jiang, Yunlu Xu, Zhanzhan Cheng, Shiliang Pu, Yi Niu, Wenqi Ren, Fei Wu, and Wenming Tan
> ICDAR, 2021

Trained on the MJSynth and SynthText text recognition datasets and evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE, the reproduced results are as follows:

|Model|Backbone|Config|Acc|Download link|
| --- | --- | --- | --- | --- |
|RFL-CNT|ResNetRFL|[rec_resnet_rfl_visual.yml](../../configs/rec/rec_resnet_rfl_visual.yml)|93.40%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar)|
|RFL-Att|ResNetRFL|[rec_resnet_rfl_att.yml](../../configs/rec/rec_resnet_rfl_att.yml)|88.63%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar)|

<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.

<a name="3"></a>
## 3. Model Training / Evaluation / Prediction

PaddleOCR modularizes the code; training a different recognition model only requires **changing the configuration file**.

Training:

After the data preparation is completed, training can be started. The training commands are as follows:

```shell
# step 1: train the CNT branch
# single-GPU training (long training period, not recommended)
python3 tools/train.py -c configs/rec/rec_resnet_rfl_visual.yml

# multi-GPU training; specify the GPU ids with the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_resnet_rfl_visual.yml

# step 2: jointly train the CNT and Att branches
# single-GPU training (long training period, not recommended)
python3 tools/train.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy

# multi-GPU training; specify the GPU ids with the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```

Evaluation:

```shell
# GPU evaluation
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```

Prediction:

```shell
# The configuration file used for prediction must match the one used for training
python3 tools/infer_rec.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model={path/to/weights}/best_accuracy
```

<a name="4"></a>
## 4. Inference and Deployment

<a name="4-1"></a>
### 4.1 Python Inference
First, convert the model saved during RFL text recognition training into an inference model ([model download link](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar)) with the following command:

```shell
python3 tools/export_model.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_resnet_rfl_att
```

**Note:**
- If you are training the model on your own dataset and have modified the dictionary file, make sure `character_dict_path` in the configuration file points to your dictionary file.
- If you modified the input size during training, update the `infer_shape` corresponding to RFL in the `tools/export_model.py` file.

After the conversion is successful, there are three files in the directory:
```
/inference/rec_resnet_rfl_att/
    ├── inference.pdiparams
    ├── inference.pdiparams.info
    └── inference.pdmodel
```

For RFL text recognition model inference, the following command can be executed:

```shell
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_resnet_rfl_att/' --rec_algorithm='RFL' --rec_image_shape='1,32,100'
```

![](../imgs_words_en/word_10.png)

After executing the command, the prediction result (recognized text and score) of the image above is printed to the screen, for example:

```shell
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9999927282333374)
```

<a name="4-2"></a>
### 4.2 C++ Inference

Not supported

<a name="4-3"></a>
### 4.3 Serving

Not supported

<a name="4-4"></a>
### 4.4 More

Not supported

<a name="5"></a>
## 5. FAQ

## Citation

```bibtex
@inproceedings{2021Reciprocal,
  title     = {Reciprocal Feature Learning via Explicit and Implicit Tasks in Scene Text Recognition},
  author    = {Jiang, H. and Xu, Y. and Cheng, Z. and Pu, S. and Niu, Y. and Ren, W. and Wu, F. and Tan, W.},
  booktitle = {ICDAR},
  year      = {2021},
  url       = {https://arxiv.org/abs/2105.06229}
}
```
@@ -26,7 +26,8 @@ from .make_pse_gt import MakePseGt

from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
    SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
-    ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg
+    ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \
+    RFLRecResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
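For context, the entries under `transforms:` in the configs are instantiated by class name from this module. A simplified sketch modeled on PaddleOCR's `create_operators` (the lookup here uses `globals()` for illustration):

```python
def create_operators(op_param_list, global_config=None):
    # Each entry is {OpName: {param: value, ...}}; look the class up by name
    # and instantiate it with its params plus any global config.
    ops = []
    for operator in op_param_list:
        op_name = list(operator)[0]
        param = {} if operator[op_name] is None else dict(operator[op_name])
        if global_config is not None:
            param.update(global_config)
        ops.append(globals()[op_name](**param))
    return ops
```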
@@ -488,6 +488,62 @@ class AttnLabelEncode(BaseRecLabelEncode):
        return idx


class RFLLabelEncode(BaseRecLabelEncode):
    """ Convert between text-label and text-index """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(RFLLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        self.beg_str = "sos"
        self.end_str = "eos"
        dict_character = [self.beg_str] + dict_character + [self.end_str]
        return dict_character

    def encode_cnt(self, text):
        cnt_label = [0.0] * len(self.character)
        for char_ in text:
            cnt_label[char_] += 1
        return np.array(cnt_label)

    def __call__(self, data):
        text = data['label']
        text = self.encode(text)
        if text is None:
            return None
        if len(text) >= self.max_text_len:
            return None
        cnt_label = self.encode_cnt(text)
        data['length'] = np.array(len(text))
        text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
                                                               - len(text) - 2)
        if len(text) != self.max_text_len:
            return None
        data['label'] = np.array(text)
        data['cnt_label'] = cnt_label
        return data

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        if beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "end":
            idx = np.array(self.dict[self.end_str])
        else:
            assert False, "Unsupported type %s in get_beg_end_flag_idx" \
                % beg_or_end
        return idx


class SEEDLabelEncode(BaseRecLabelEncode):
    """ Convert between text-label and text-index """
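A worked example of what `RFLLabelEncode` produces, assuming the default 36-character digit+lowercase dictionary with `sos`/`eos` added (38 classes) and `max_text_length: 25`; the concrete indices are illustrative:

```python
import numpy as np

character = ['sos'] + list('0123456789abcdefghijklmnopqrstuvwxyz') + ['eos']
text = [character.index(c) for c in "cat"]   # [13, 11, 30]

cnt_label = np.zeros(len(character))         # one count per character class
for idx in text:
    cnt_label[idx] += 1

# label = [sos] + chars + [eos] + zero padding, always max_text_len long
label = [0] + text + [len(character) - 1] + [0] * (25 - len(text) - 2)
assert len(label) == 25 and label[4] == 37   # eos index right after 'cat'
```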
@@ -237,6 +237,33 @@ class VLRecResizeImg(object):
        return data


class RFLRecResizeImg(object):
    def __init__(self, image_shape, padding=True, interpolation=1, **kwargs):
        self.image_shape = image_shape
        self.padding = padding

        self.interpolation = interpolation
        if self.interpolation == 0:
            self.interpolation = cv2.INTER_NEAREST
        elif self.interpolation == 1:
            self.interpolation = cv2.INTER_LINEAR
        elif self.interpolation == 2:
            self.interpolation = cv2.INTER_CUBIC
        elif self.interpolation == 3:
            self.interpolation = cv2.INTER_AREA
        else:
            raise Exception("Unsupported interpolation type !!!")

    def __call__(self, data):
        img = data['image']
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        norm_img, valid_ratio = resize_norm_img(
            img, self.image_shape, self.padding, self.interpolation)
        data['image'] = norm_img
        data['valid_ratio'] = valid_ratio
        return data


class SRNRecResizeImg(object):
    def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
        self.image_shape = image_shape

@@ -414,8 +441,13 @@ class SVTRRecResizeImg(object):
        data['valid_ratio'] = valid_ratio
        return data


class RobustScannerRecResizeImg(object):
-    def __init__(self, image_shape, max_text_length, width_downsample_ratio=0.25, **kwargs):
+    def __init__(self,
+                 image_shape,
+                 max_text_length,
+                 width_downsample_ratio=0.25,
+                 **kwargs):
        self.image_shape = image_shape
        self.width_downsample_ratio = width_downsample_ratio
        self.max_text_length = max_text_length

@@ -432,6 +464,7 @@ class RobustScannerRecResizeImg(object):
        data['word_positons'] = word_positons
        return data


def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
    imgC, imgH, imgW_min, imgW_max = image_shape
    h = img.shape[0]

@@ -467,13 +500,16 @@ def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
    return padding_im, resize_shape, pad_shape, valid_ratio


-def resize_norm_img(img, image_shape, padding=True):
+def resize_norm_img(img,
+                    image_shape,
+                    padding=True,
+                    interpolation=cv2.INTER_LINEAR):
    imgC, imgH, imgW = image_shape
    h = img.shape[0]
    w = img.shape[1]
    if not padding:
        resized_image = cv2.resize(
-            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+            img, (imgW, imgH), interpolation=interpolation)
        resized_w = imgW
    else:
        ratio = w / float(h)
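A usage sketch for the new resize op (assuming it is imported from `ppocr.data.imaug.rec_img_aug`); `interpolation: 2` in the configs selects `cv2.INTER_CUBIC`:

```python
import numpy as np
from ppocr.data.imaug.rec_img_aug import RFLRecResizeImg

op = RFLRecResizeImg(image_shape=[1, 32, 100], padding=False, interpolation=2)
data = {'image': np.zeros((48, 160, 3), dtype=np.uint8)}  # dummy BGR crop
out = op(data)
print(out['image'].shape)  # (1, 32, 100): grayscale, resized, normalized
```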
@@ -38,6 +38,7 @@ from .rec_pren_loss import PRENLoss
from .rec_multi_loss import MultiLoss
from .rec_vl_loss import VLLoss
from .rec_spin_att_loss import SPINAttentionLoss
+from .rec_rfl_loss import RFLLoss

# cls loss
from .cls_loss import ClsLoss

@@ -69,7 +70,7 @@ def build_loss(config):
        'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
        'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
        'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
-        'SLALoss', 'CTLoss'
+        'SLALoss', 'CTLoss', 'RFLLoss'
    ]
    config = copy.deepcopy(config)
    module_name = config.pop('name')
@@ -0,0 +1,68 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_common/models/loss/cross_entropy_loss.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import nn

from .basic_loss import CELoss, DistanceLoss


class RFLLoss(nn.Layer):
    def __init__(self, ignore_index=-100, **kwargs):
        super().__init__()

        self.cnt_loss = nn.MSELoss(**kwargs)
        self.seq_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, predicts, batch):
        self.total_loss = {}
        total_loss = 0.0
        if isinstance(predicts, tuple) or isinstance(predicts, list):
            cnt_outputs, seq_outputs = predicts
        else:
            cnt_outputs, seq_outputs = predicts, None
        # batch [image, label, length, cnt_label]
        if cnt_outputs is not None:
            cnt_loss = self.cnt_loss(cnt_outputs,
                                     paddle.cast(batch[3], paddle.float32))
            self.total_loss['cnt_loss'] = cnt_loss
            total_loss += cnt_loss

        if seq_outputs is not None:
            targets = batch[1].astype("int64")
            label_lengths = batch[2].astype('int64')
            batch_size, num_steps, num_classes = seq_outputs.shape[
                0], seq_outputs.shape[1], seq_outputs.shape[2]
            assert len(targets.shape) == len(list(seq_outputs.shape)) - 1, \
                "The targets' shape and inputs' shape should be [N, d] and [N, num_steps]"

            inputs = seq_outputs[:, :-1, :]
            targets = targets[:, 1:]

            inputs = paddle.reshape(inputs, [-1, inputs.shape[-1]])
            targets = paddle.reshape(targets, [-1])
            seq_loss = self.seq_loss(inputs, targets)
            self.total_loss['seq_loss'] = seq_loss
            total_loss += seq_loss

        self.total_loss['loss'] = total_loss
        return self.total_loss
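A minimal smoke test for the counting-only path (as used by `rec_resnet_rfl_visual.yml`); assumes the class is importable from `ppocr.losses.rec_rfl_loss`:

```python
import paddle
from ppocr.losses.rec_rfl_loss import RFLLoss

loss_fn = RFLLoss()
cnt_outputs = paddle.rand([4, 38])                # count predictions
batch = [None, None, None, paddle.rand([4, 38])]  # batch[3] is cnt_label
out = loss_fn(cnt_outputs, batch)                 # single tensor: seq skipped
print(out['cnt_loss'], out['loss'])
```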
@@ -22,7 +22,7 @@ import copy
__all__ = ["build_metric"]

from .det_metric import DetMetric, DetFCEMetric
-from .rec_metric import RecMetric
+from .rec_metric import RecMetric, CNTMetric
from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric

@@ -38,7 +38,7 @@ def build_metric(config):
    support_dict = [
        "DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
        "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
-        'VQAReTokenMetric', 'SRMetric', 'CTMetric'
+        'VQAReTokenMetric', 'SRMetric', 'CTMetric', 'CNTMetric'
    ]

    config = copy.deepcopy(config)
@@ -16,7 +16,6 @@ from rapidfuzz.distance import Levenshtein
import string


class RecMetric(object):
    def __init__(self,
                 main_indicator='acc',

@@ -74,3 +73,36 @@ class RecMetric(object):
        self.correct_num = 0
        self.all_num = 0
        self.norm_edit_dis = 0


class CNTMetric(object):
    def __init__(self, main_indicator='acc', **kwargs):
        self.main_indicator = main_indicator
        self.eps = 1e-5
        self.reset()

    def __call__(self, pred_label, *args, **kwargs):
        preds, labels = pred_label
        correct_num = 0
        all_num = 0
        for pred, target in zip(preds, labels):
            if pred == target:
                correct_num += 1
            all_num += 1
        self.correct_num += correct_num
        self.all_num += all_num
        return {'acc': correct_num / (all_num + self.eps), }

    def get_metric(self):
        """
        return metrics {
                 'acc': 0,
            }
        """
        acc = 1.0 * self.correct_num / (self.all_num + self.eps)
        self.reset()
        return {'acc': acc}

    def reset(self):
        self.correct_num = 0
        self.all_num = 0
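Usage sketch for `CNTMetric`: it compares predicted character counts with ground-truth text lengths, which is the pair `RFLLabelDecode` returns in counting-only mode:

```python
m = CNTMetric(main_indicator='acc')
preds = [3, 5, 7]           # predicted text lengths
labels = [3, 5, 6]          # ground-truth lengths
print(m((preds, labels)))   # {'acc': 0.666...} for this batch
print(m.get_metric())       # accumulated accuracy, then internal reset
```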
@@ -42,10 +42,11 @@ def build_backbone(config, model_type):
        from .rec_efficientb3_pren import EfficientNetb3_PREN
        from .rec_svtrnet import SVTRNet
        from .rec_vitstr import ViTSTR
+        from .rec_resnet_rfl import ResNetRFL
        support_dict = [
            'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
            'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
-            'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32'
+            'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL'
        ]
    elif model_type == 'e2e':
        from .e2e_resnet_vd_pg import ResNet
@@ -0,0 +1,348 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/backbones/ResNetRFL.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn

from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal

kaiming_init_ = KaimingNormal()
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)


class BasicBlock(nn.Layer):
    """ResNet basic block"""
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 norm_type='BN',
                 **kwargs):
        """
        Args:
            inplanes (int): input channel
            planes (int): channels of the middle feature
            stride (int): stride of the convolution
            downsample (int): type of the down_sample
            norm_type (str): type of the normalization
            **kwargs (None): backup parameter
        """
        super(BasicBlock, self).__init__()
        self.conv1 = self._conv3x3(inplanes, planes)
        self.bn1 = nn.BatchNorm(planes)
        self.conv2 = self._conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm(planes)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def _conv3x3(self, in_planes, out_planes, stride=1):
        return nn.Conv2D(
            in_planes,
            out_planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias_attr=False)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)

        return out


class ResNetRFL(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels=512,
                 use_cnt=True,
                 use_seq=True):
        """
        Args:
            in_channels (int): input channel
            out_channels (int): output channel
        """
        super(ResNetRFL, self).__init__()
        assert use_cnt or use_seq
        self.use_cnt, self.use_seq = use_cnt, use_seq
        self.backbone = RFLBase(in_channels)

        self.out_channels = out_channels
        self.out_channels_block = [
            int(self.out_channels / 4), int(self.out_channels / 2),
            self.out_channels, self.out_channels
        ]
        block = BasicBlock
        layers = [1, 2, 5, 3]
        self.inplanes = int(self.out_channels // 2)

        self.relu = nn.ReLU()
        if self.use_seq:
            self.maxpool3 = nn.MaxPool2D(
                kernel_size=2, stride=(2, 1), padding=(0, 1))
            self.layer3 = self._make_layer(
                block, self.out_channels_block[2], layers[2], stride=1)
            self.conv3 = nn.Conv2D(
                self.out_channels_block[2],
                self.out_channels_block[2],
                kernel_size=3,
                stride=1,
                padding=1,
                bias_attr=False)
            self.bn3 = nn.BatchNorm(self.out_channels_block[2])

            self.layer4 = self._make_layer(
                block, self.out_channels_block[3], layers[3], stride=1)
            self.conv4_1 = nn.Conv2D(
                self.out_channels_block[3],
                self.out_channels_block[3],
                kernel_size=2,
                stride=(2, 1),
                padding=(0, 1),
                bias_attr=False)
            self.bn4_1 = nn.BatchNorm(self.out_channels_block[3])
            self.conv4_2 = nn.Conv2D(
                self.out_channels_block[3],
                self.out_channels_block[3],
                kernel_size=2,
                stride=1,
                padding=0,
                bias_attr=False)
            self.bn4_2 = nn.BatchNorm(self.out_channels_block[3])

        if self.use_cnt:
            self.inplanes = int(self.out_channels // 2)
            self.v_maxpool3 = nn.MaxPool2D(
                kernel_size=2, stride=(2, 1), padding=(0, 1))
            self.v_layer3 = self._make_layer(
                block, self.out_channels_block[2], layers[2], stride=1)
            self.v_conv3 = nn.Conv2D(
                self.out_channels_block[2],
                self.out_channels_block[2],
                kernel_size=3,
                stride=1,
                padding=1,
                bias_attr=False)
            self.v_bn3 = nn.BatchNorm(self.out_channels_block[2])

            self.v_layer4 = self._make_layer(
                block, self.out_channels_block[3], layers[3], stride=1)
            self.v_conv4_1 = nn.Conv2D(
                self.out_channels_block[3],
                self.out_channels_block[3],
                kernel_size=2,
                stride=(2, 1),
                padding=(0, 1),
                bias_attr=False)
            self.v_bn4_1 = nn.BatchNorm(self.out_channels_block[3])
            self.v_conv4_2 = nn.Conv2D(
                self.out_channels_block[3],
                self.out_channels_block[3],
                kernel_size=2,
                stride=1,
                padding=0,
                bias_attr=False)
            self.v_bn4_2 = nn.BatchNorm(self.out_channels_block[3])

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2D(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias_attr=False),
                nn.BatchNorm(planes * block.expansion), )

        layers = list()
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, inputs):
        x_1 = self.backbone(inputs)

        if self.use_cnt:
            v_x = self.v_maxpool3(x_1)
            v_x = self.v_layer3(v_x)
            v_x = self.v_conv3(v_x)
            v_x = self.v_bn3(v_x)
            visual_feature_2 = self.relu(v_x)

            v_x = self.v_layer4(visual_feature_2)
            v_x = self.v_conv4_1(v_x)
            v_x = self.v_bn4_1(v_x)
            v_x = self.relu(v_x)
            v_x = self.v_conv4_2(v_x)
            v_x = self.v_bn4_2(v_x)
            visual_feature_3 = self.relu(v_x)
        else:
            visual_feature_3 = None
        if self.use_seq:
            x = self.maxpool3(x_1)
            x = self.layer3(x)
            x = self.conv3(x)
            x = self.bn3(x)
            x_2 = self.relu(x)

            x = self.layer4(x_2)
            x = self.conv4_1(x)
            x = self.bn4_1(x)
            x = self.relu(x)
            x = self.conv4_2(x)
            x = self.bn4_2(x)
            x_3 = self.relu(x)
        else:
            x_3 = None

        return [visual_feature_3, x_3]


class ResNetBase(nn.Layer):
    def __init__(self, in_channels, out_channels, block, layers):
        super(ResNetBase, self).__init__()

        self.out_channels_block = [
            int(out_channels / 4), int(out_channels / 2), out_channels,
            out_channels
        ]

        self.inplanes = int(out_channels / 8)
        self.conv0_1 = nn.Conv2D(
            in_channels,
            int(out_channels / 16),
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False)
        self.bn0_1 = nn.BatchNorm(int(out_channels / 16))
        self.conv0_2 = nn.Conv2D(
            int(out_channels / 16),
            self.inplanes,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False)
        self.bn0_2 = nn.BatchNorm(self.inplanes)
        self.relu = nn.ReLU()

        self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
        self.layer1 = self._make_layer(block, self.out_channels_block[0],
                                       layers[0])
        self.conv1 = nn.Conv2D(
            self.out_channels_block[0],
            self.out_channels_block[0],
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False)
        self.bn1 = nn.BatchNorm(self.out_channels_block[0])

        self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
        self.layer2 = self._make_layer(
            block, self.out_channels_block[1], layers[1], stride=1)
        self.conv2 = nn.Conv2D(
            self.out_channels_block[1],
            self.out_channels_block[1],
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False)
        self.bn2 = nn.BatchNorm(self.out_channels_block[1])

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2D(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias_attr=False),
                nn.BatchNorm(planes * block.expansion), )

        layers = list()
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv0_1(x)
        x = self.bn0_1(x)
        x = self.relu(x)
        x = self.conv0_2(x)
        x = self.bn0_2(x)
        x = self.relu(x)

        x = self.maxpool1(x)
        x = self.layer1(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.maxpool2(x)
        x = self.layer2(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        return x


class RFLBase(nn.Layer):
    """ Reciprocal feature learning shared backbone network"""

    def __init__(self, in_channels, out_channels=512):
        super(RFLBase, self).__init__()
        self.ConvNet = ResNetBase(in_channels, out_channels, BasicBlock,
                                  [1, 2, 5, 3])

    def forward(self, inputs):
        return self.ConvNet(inputs)
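A shape walk-through (my own check against the layer definitions above, not part of the PR): for a `[N, 1, 32, 100]` input, `ResNetBase` halves H and W twice, giving a shared `[N, 256, 8, 25]` feature; each branch then reduces H to 1 and pads W to 26:

```python
import paddle
from ppocr.modeling.backbones.rec_resnet_rfl import ResNetRFL

model = ResNetRFL(in_channels=1, out_channels=512, use_cnt=True, use_seq=True)
x = paddle.rand([2, 1, 32, 100])          # N, C, H, W as in the configs
cnt_feat, seq_feat = model(x)
print(cnt_feat.shape, seq_feat.shape)     # [2, 512, 1, 26] [2, 512, 1, 26]
```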
@@ -38,6 +38,7 @@ def build_head(config):
    from .rec_abinet_head import ABINetHead
    from .rec_robustscanner_head import RobustScannerHead
    from .rec_visionlan_head import VLHead
+    from .rec_rfl_head import RFLHead

    # cls head
    from .cls_head import ClsHead

@@ -53,7 +54,7 @@ def build_head(config):
        'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
        'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
        'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
-        'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head'
+        'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead'
    ]

    #table head
@@ -149,6 +149,8 @@ class AttentionLSTM(nn.Layer):
        else:
            targets = paddle.zeros(shape=[batch_size], dtype="int32")
            probs = None
            char_onehots = None
            alpha = None

            for i in range(num_steps):
                char_onehots = self._char_to_onehot(

@@ -167,7 +169,8 @@ class AttentionLSTM(nn.Layer):
                next_input = probs_step.argmax(axis=1)

                targets = next_input

        if not self.training:
            probs = paddle.nn.functional.softmax(probs, axis=2)
        return probs
@@ -0,0 +1,108 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/sequence_heads/counting_head.py
"""
import paddle
import paddle.nn as nn
from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal

from .rec_att_head import AttentionLSTM

kaiming_init_ = KaimingNormal()
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)


class CNTHead(nn.Layer):
    def __init__(self,
                 embed_size=512,
                 encode_length=26,
                 out_channels=38,
                 **kwargs):
        super(CNTHead, self).__init__()

        self.out_channels = out_channels

        self.Wv_fusion = nn.Linear(embed_size, embed_size, bias_attr=False)
        self.Prediction_visual = nn.Linear(encode_length * embed_size,
                                           self.out_channels)

    def forward(self, visual_feature):
        b, c, h, w = visual_feature.shape
        visual_feature = visual_feature.reshape([b, c, h * w]).transpose(
            [0, 2, 1])
        visual_feature_num = self.Wv_fusion(visual_feature)  # batch * 26 * 512
        b, n, c = visual_feature_num.shape
        # use the visual feature directly to predict the text length
        visual_feature_num = visual_feature_num.reshape([b, n * c])
        prediction_visual = self.Prediction_visual(visual_feature_num)

        return prediction_visual


class RFLHead(nn.Layer):
    def __init__(self,
                 in_channels=512,
                 hidden_size=256,
                 batch_max_legnth=25,
                 out_channels=38,
                 use_cnt=True,
                 use_seq=True,
                 **kwargs):
        super(RFLHead, self).__init__()
        assert use_cnt or use_seq
        self.use_cnt = use_cnt
        self.use_seq = use_seq
        if self.use_cnt:
            self.cnt_head = CNTHead(
                embed_size=in_channels,
                encode_length=batch_max_legnth + 1,
                out_channels=out_channels,
                **kwargs)
        if self.use_seq:
            self.seq_head = AttentionLSTM(
                in_channels=in_channels,
                out_channels=out_channels,
                hidden_size=hidden_size,
                **kwargs)
        self.batch_max_legnth = batch_max_legnth
        self.num_class = out_channels
        self.apply(self.init_weights)

    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            kaiming_init_(m.weight)
        if isinstance(m, nn.Linear) and m.bias is not None:
            zeros_(m.bias)

    def forward(self, x, targets=None):
        cnt_inputs, seq_inputs = x
        if self.use_cnt:
            cnt_outputs = self.cnt_head(cnt_inputs)
        else:
            cnt_outputs = None
        if self.use_seq:
            if self.training:
                seq_outputs = self.seq_head(seq_inputs, targets[0],
                                            self.batch_max_legnth)
            else:
                seq_outputs = self.seq_head(seq_inputs, None,
                                            self.batch_max_legnth)
            return cnt_outputs, seq_outputs
        else:
            return cnt_outputs
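An inference-time sketch for the head (hedged: the `[N, 26, 512]` sequence input assumes the flattened output of the RFAdaptor neck shown later in this PR, and the eval path delegates to `AttentionLSTM`):

```python
import paddle
from ppocr.modeling.heads.rec_rfl_head import RFLHead

head = RFLHead(in_channels=512, hidden_size=256, batch_max_legnth=25,
               out_channels=38, use_cnt=True, use_seq=True)
head.eval()
cnt_feat = paddle.rand([2, 512, 1, 26])   # counting branch input
seq_feat = paddle.rand([2, 26, 512])      # attention branch input (B, W, C)
cnt_out, seq_out = head((cnt_feat, seq_feat))
print(cnt_out.shape)                      # [2, 38]: per-class counts
```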
@@ -27,9 +27,11 @@ def build_neck(config):
    from .pren_fpn import PRENFPN
    from .csp_pan import CSPPAN
    from .ct_fpn import CTFPN
+    from .rf_adaptor import RFAdaptor
    support_dict = [
        'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
-        'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN'
+        'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN',
+        'RFAdaptor'
    ]

    module_name = config.pop('name')
@@ -0,0 +1,137 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/connects/single_block/RFAdaptor.py
"""

import paddle
import paddle.nn as nn
from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal

kaiming_init_ = KaimingNormal()
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)


class S2VAdaptor(nn.Layer):
    """ Semantic to Visual adaptation module"""

    def __init__(self, in_channels=512):
        super(S2VAdaptor, self).__init__()

        self.in_channels = in_channels  # 512

        # feature strengthening module, channel attention
        self.channel_inter = nn.Linear(
            self.in_channels, self.in_channels, bias_attr=False)
        self.channel_bn = nn.BatchNorm1D(self.in_channels)
        self.channel_act = nn.ReLU()
        self.apply(self.init_weights)

    def init_weights(self, m):
        if isinstance(m, nn.Conv2D):
            kaiming_init_(m.weight)
            if isinstance(m, nn.Conv2D) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm, nn.BatchNorm2D, nn.BatchNorm1D)):
            zeros_(m.bias)
            ones_(m.weight)

    def forward(self, semantic):
        semantic_source = semantic  # batch, channel, height, width

        # feature transformation
        semantic = semantic.squeeze(2).transpose(
            [0, 2, 1])  # batch, width, channel
        channel_att = self.channel_inter(semantic)  # batch, width, channel
        channel_att = channel_att.transpose([0, 2, 1])  # batch, channel, width
        channel_bn = self.channel_bn(channel_att)  # batch, channel, width
        channel_att = self.channel_act(channel_bn)  # batch, channel, width

        # feature enhancement
        channel_output = semantic_source * channel_att.unsqueeze(
            -2)  # batch, channel, 1, width

        return channel_output


class V2SAdaptor(nn.Layer):
    """ Visual to Semantic adaptation module"""

    def __init__(self, in_channels=512, return_mask=False):
        super(V2SAdaptor, self).__init__()

        # parameter initialization
        self.in_channels = in_channels
        self.return_mask = return_mask

        # output transformation
        self.channel_inter = nn.Linear(
            self.in_channels, self.in_channels, bias_attr=False)
        self.channel_bn = nn.BatchNorm1D(self.in_channels)
        self.channel_act = nn.ReLU()

    def forward(self, visual):
        # feature enhancement
        visual = visual.squeeze(2).transpose([0, 2, 1])  # batch, width, channel
        channel_att = self.channel_inter(visual)  # batch, width, channel
        channel_att = channel_att.transpose([0, 2, 1])  # batch, channel, width
        channel_bn = self.channel_bn(channel_att)  # batch, channel, width
        channel_att = self.channel_act(channel_bn)  # batch, channel, width

        # size alignment
        channel_output = channel_att.unsqueeze(-2)  # batch, channel, 1, width

        if self.return_mask:
            return channel_output, channel_att
        return channel_output


class RFAdaptor(nn.Layer):
    def __init__(self, in_channels=512, use_v2s=True, use_s2v=True, **kwargs):
        super(RFAdaptor, self).__init__()
        if use_v2s is True:
            self.neck_v2s = V2SAdaptor(in_channels=in_channels, **kwargs)
        else:
            self.neck_v2s = None
        if use_s2v is True:
            self.neck_s2v = S2VAdaptor(in_channels=in_channels, **kwargs)
        else:
            self.neck_s2v = None
        self.out_channels = in_channels

    def forward(self, x):
        visual_feature, rcg_feature = x
        if visual_feature is not None:
            batch, source_channels, v_source_height, v_source_width = visual_feature.shape
            visual_feature = visual_feature.reshape(
                [batch, source_channels, 1, v_source_height * v_source_width])

        if self.neck_v2s is not None:
            v_rcg_feature = rcg_feature * self.neck_v2s(visual_feature)
        else:
            v_rcg_feature = rcg_feature

        if self.neck_s2v is not None:
            v_visual_feature = visual_feature + self.neck_s2v(rcg_feature)
        else:
            v_visual_feature = visual_feature
        if v_rcg_feature is not None:
            batch, source_channels, source_height, source_width = v_rcg_feature.shape
            v_rcg_feature = v_rcg_feature.reshape(
                [batch, source_channels, 1, source_height * source_width])

            v_rcg_feature = v_rcg_feature.squeeze(2).transpose([0, 2, 1])
        return v_visual_feature, v_rcg_feature
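A shape sketch for the adaptor (hedged; both inputs are the `[N, 512, 1, 26]` branch features from `ResNetRFL`): the visual stream keeps its shape, while the recognition stream is flattened to `[N, 26, 512]` for the attention head.

```python
import paddle
from ppocr.modeling.necks.rf_adaptor import RFAdaptor

adaptor = RFAdaptor(in_channels=512, use_v2s=True, use_s2v=True)
visual = paddle.rand([2, 512, 1, 26])   # counting-branch feature
rcg = paddle.rand([2, 512, 1, 26])      # recognition-branch feature
v_out, s_out = adaptor((visual, rcg))
print(v_out.shape, s_out.shape)         # [2, 512, 1, 26] and [2, 26, 512]
```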
@@ -53,6 +53,9 @@ def build_optimizer(config, epochs, step_each_epoch, model):
    if 'clip_norm' in config:
        clip_norm = config.pop('clip_norm')
        grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
+    elif 'clip_norm_global' in config:
+        clip_norm = config.pop('clip_norm_global')
+        grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=clip_norm)
    else:
        grad_clip = None
    optim = getattr(optimizer, optim_name)(learning_rate=lr,
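The new branch is what `clip_norm_global: 5.0` in the RFL configs activates. A minimal standalone example of the same Paddle API:

```python
import paddle

grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=5.0)
opt = paddle.optimizer.AdamW(
    learning_rate=0.001,
    grad_clip=grad_clip,  # clips by the global norm across all gradients
    parameters=paddle.nn.Linear(4, 4).parameters())
```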
@@ -28,7 +28,7 @@ from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
    DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
    SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
-    SPINLabelDecode, VLLabelDecode
+    SPINLabelDecode, VLLabelDecode, RFLLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess

@@ -49,7 +49,7 @@ def build_post_process(config, global_config=None):
        'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
        'TableMasterLabelDecode', 'SPINLabelDecode',
        'DistillationSerPostProcess', 'DistillationRePostProcess',
-        'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess'
+        'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess', 'RFLLabelDecode'
    ]

    if config['name'] == 'PSEPostProcess':
@@ -242,6 +242,95 @@ class AttnLabelDecode(BaseRecLabelDecode):
        return idx


class RFLLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(RFLLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)

    def add_special_char(self, dict_character):
        self.beg_str = "sos"
        self.end_str = "eos"
        dict_character = [self.beg_str] + dict_character + [self.end_str]
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label. """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        [beg_idx, end_idx] = ignored_tokens
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] in ignored_tokens:
                    continue
                if int(text_index[batch_idx][idx]) == int(end_idx):
                    break
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
                            batch_idx][idx]:
                        continue
                char_list.append(self.character[int(text_index[batch_idx][
                    idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        if isinstance(preds, tuple) or isinstance(preds, list):
            cnt_outputs, seq_outputs = preds
            if isinstance(seq_outputs, paddle.Tensor):
                seq_outputs = seq_outputs.numpy()
            preds_idx = seq_outputs.argmax(axis=2)
            preds_prob = seq_outputs.max(axis=2)
            text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)

            if label is None:
                return text
            label = self.decode(label, is_remove_duplicate=False)
            return text, label
        else:
            cnt_outputs = preds
            if isinstance(cnt_outputs, paddle.Tensor):
                cnt_outputs = cnt_outputs.numpy()
            cnt_length = []
            for lens in cnt_outputs:
                length = round(np.sum(lens))
                cnt_length.append(length)
            if label is None:
                return cnt_length
            label = self.decode(label, is_remove_duplicate=False)
            length = [len(res[0]) for res in label]
            return cnt_length, length

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        if beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "end":
            idx = np.array(self.dict[self.end_str])
        else:
            assert False, "Unsupported type %s in get_beg_end_flag_idx" \
                % beg_or_end
        return idx


class SEEDLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """
||||
|
||||
|
|
|
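Taken together, the decoder handles both RFL heads: a (cnt, seq) tuple from the full model and a bare counting tensor from the visual-only model. A minimal usage sketch (shapes are illustrative; assumes PaddleOCR is importable, and 38 classes = 36 default characters + sos/eos):

import paddle
from ppocr.postprocess import build_post_process

decoder = build_post_process({
    'name': 'RFLLabelDecode',
    'character_dict_path': None,
    'use_space_char': False
})

# Sequence branch: (cnt_outputs, seq_outputs) -> [(text, mean_confidence), ...]
cnt_out = paddle.rand([2, 26])        # counting logits
seq_out = paddle.rand([2, 25, 38])    # [batch, max_text_length, num_classes]
texts = decoder((cnt_out, seq_out))

# Counting branch: a bare tensor -> one predicted text length per image
lengths = decoder(paddle.rand([2, 26]))
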
@@ -0,0 +1,111 @@
Global:
  use_gpu: True
  epoch_num: 6
  log_smooth_window: 20
  print_batch_step: 50
  save_model_dir: ./output/rec/rec_resnet_rfl/
  save_epoch_step: 1
  # evaluation is run every 5000 iterations
  eval_batch_step: [0, 5000]
  cal_metric_during_train: False
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_words_en/word_10.png
  # for data or label process
  character_dict_path:
  max_text_length: 25
  infer_mode: False
  use_space_char: False
  save_res_path: ./output/rec/rec_resnet_rfl.txt


Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  weight_decay: 0.0
  clip_norm_global: 5.0
  lr:
    name: Piecewise
    decay_epochs: [3, 4, 5]
    values: [0.001, 0.0003, 0.00009, 0.000027]

Architecture:
  model_type: rec
  algorithm: RFL
  in_channels: 1
  Transform:
    name: TPS
    num_fiducial: 20
    loc_lr: 1.0
    model_name: large
  Backbone:
    name: ResNetRFL
    use_cnt: True
    use_seq: True
  Neck:
    name: RFAdaptor
    use_v2s: True
    use_s2v: True
  Head:
    name: RFLHead
    in_channels: 512
    hidden_size: 256
    batch_max_legnth: 25
    out_channels: 38
    use_cnt: True
    use_seq: True

Loss:
  name: RFLLoss

PostProcess:
  name: RFLLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/ic15_data/
    label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - RFLLabelEncode: # Class handling label
      - RFLRecResizeImg:
          image_shape: [1, 32, 100]
          interpolation: 2
      - KeepKeys:
          keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 64
    drop_last: True
    num_workers: 8

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/ic15_data
    label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - RFLLabelEncode: # Class handling label
      - RFLRecResizeImg:
          image_shape: [1, 32, 100]
          interpolation: 2
      - KeepKeys:
          keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 256
    num_workers: 8

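The Piecewise schedule above decays the learning rate at epoch boundaries. A rough standalone equivalent (PaddleOCR's Piecewise wrapper converts epochs to steps before building the scheduler; step_each_epoch is an illustrative value here):

import paddle

step_each_epoch = 1000  # illustrative; derived from the dataset in practice
boundaries = [epoch * step_each_epoch for epoch in [3, 4, 5]]
values = [0.001, 0.0003, 0.00009, 0.000027]  # one more value than boundaries
lr = paddle.optimizer.lr.PiecewiseDecay(boundaries=boundaries, values=values)
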
@@ -0,0 +1,53 @@
===========================train_params===========================
model_name:rec_resnet_rfl
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
##
train_model:./inference/rec_resnet_rfl_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_image_shape="1,32,100" --rec_algorithm="RFL" --min_subgraph_size=5
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--rec_model_dir:
--image_dir:./inference/rec_inference
--save_log_path:./test/output/
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[1,32,100]}]

@@ -99,7 +99,7 @@ def export_single_model(model,
         ]
         # print([None, 3, 32, 128])
         model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] in ["NRTR", "SPIN"]:
+    elif arch_config["algorithm"] in ["NRTR", "SPIN", 'RFL']:
         other_shape = [
             paddle.static.InputSpec(
                 shape=[None, 1, 32, 100], dtype="float32"),

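In effect, RFL reuses the single-input export path with a grayscale spec. A sketch of the equivalent standalone export (the `model` object and the output path are assumptions for illustration):

import paddle
from paddle.jit import to_static

# RFL consumes 1 x 32 x 100 grayscale crops, so one dynamic-batch spec suffices.
spec = [paddle.static.InputSpec(shape=[None, 1, 32, 100], dtype="float32")]
static_model = to_static(model, input_spec=spec)  # `model` built elsewhere
paddle.jit.save(static_model, "./inference/rec_resnet_rfl/inference")
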
@@ -100,6 +100,12 @@ class TextRecognizer(object):
                 "use_space_char": args.use_space_char,
                 "rm_symbol": True
             }
+        elif self.rec_algorithm == 'RFL':
+            postprocess_params = {
+                'name': 'RFLLabelDecode',
+                "character_dict_path": None,
+                "use_space_char": args.use_space_char
+            }
         elif self.rec_algorithm == "PREN":
             postprocess_params = {'name': 'PRENLabelDecode'}
         self.postprocess_op = build_post_process(postprocess_params)

@@ -145,6 +151,16 @@ class TextRecognizer(object):
             else:
                 norm_img = norm_img.astype(np.float32) / 128. - 1.
             return norm_img
+        elif self.rec_algorithm == 'RFL':
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+            resized_image = cv2.resize(
+                img, (imgW, imgH), interpolation=cv2.INTER_CUBIC)
+            resized_image = resized_image.astype('float32')
+            resized_image = resized_image / 255
+            resized_image = resized_image[np.newaxis, :]
+            resized_image -= 0.5
+            resized_image /= 0.5
+            return resized_image

         assert imgC == img.shape[2]
         imgW = int((imgH * max_wh_ratio))

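This predict-time branch mirrors the training-time RFLRecResizeImg settings (image_shape [1, 32, 100]; interpolation 2, i.e. cv2.INTER_CUBIC). A quick shape check with an illustrative input:

import cv2
import numpy as np

img = np.zeros((48, 160, 3), dtype=np.uint8)   # any BGR crop
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)   # (48, 160)
resized = cv2.resize(gray, (100, 32), interpolation=cv2.INTER_CUBIC)
norm = (resized.astype('float32') / 255 - 0.5) / 0.5
norm = norm[np.newaxis, :]                     # (1, 32, 100), values in [-1, 1]
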
@@ -97,7 +97,8 @@ def main():
             elif config['Architecture']['algorithm'] == "SAR":
                 op[op_name]['keep_keys'] = ['image', 'valid_ratio']
             elif config['Architecture']['algorithm'] == "RobustScanner":
-                op[op_name]['keep_keys'] = ['image', 'valid_ratio', 'word_positons']
+                op[op_name][
+                    'keep_keys'] = ['image', 'valid_ratio', 'word_positons']
             else:
                 op[op_name]['keep_keys'] = ['image']
             transforms.append(op)

@@ -136,9 +137,10 @@ def main():
         if config['Architecture']['algorithm'] == "RobustScanner":
             valid_ratio = np.expand_dims(batch[1], axis=0)
             word_positons = np.expand_dims(batch[2], axis=0)
-            img_metas = [paddle.to_tensor(valid_ratio),
-                         paddle.to_tensor(word_positons),
-                         ]
+            img_metas = [
+                paddle.to_tensor(valid_ratio),
+                paddle.to_tensor(word_positons),
+            ]
         images = np.expand_dims(batch[0], axis=0)
         images = paddle.to_tensor(images)
         if config['Architecture']['algorithm'] == "SRN":

@@ -160,6 +162,10 @@ def main():
                     "score": float(post_result[key][0][1]),
                 }
             info = json.dumps(rec_info, ensure_ascii=False)
+        elif isinstance(post_result, list) and isinstance(post_result[0],
+                                                          int):
+            # for the RFLearning CNT branch
+            info = str(post_result[0])
         else:
             if len(post_result[0]) >= 2:
                 info = post_result[0][0] + "\t" + str(post_result[0][1])

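So main() now distinguishes a third result shape. A sketch of what post_result looks like per RFL branch (values illustrative):

# CNT-only visual model: RFLLabelDecode returns predicted lengths (ints)
post_result = [13, 7]
info = str(post_result[0])                         # -> "13"

# Full model (seq branch): list of (text, confidence) tuples
post_result = [("hello", 0.98)]
info = post_result[0][0] + "\t" + str(post_result[0][1])
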
@@ -217,7 +217,7 @@ def train(config,
     use_srn = config['Architecture']['algorithm'] == "SRN"
     extra_input_models = [
         "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
-        "RobustScanner"
+        "RobustScanner", "RFL"
     ]
     extra_input = False
     if config['Architecture']['algorithm'] == 'Distillation':

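Adding "RFL" here matters because extra_input gates whether the training loop forwards the extra ground-truth tensors to the model. A sketch of the effect (mirroring the logic in the train loop; `images` and `batch` as in the surrounding code):

if extra_input:
    preds = model(images, data=batch[1:])  # RFL also receives cnt_label
else:
    preds = model(images)
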
@@ -625,7 +625,7 @@ def preprocess(is_train=False):
         'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
         'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
         'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
-        'Gestalt', 'SLANet', 'RobustScanner', 'CT'
+        'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL'
     ]

     if use_xpu: