Add rec algo VisionLAN (#6943)
* add vl * add vl * add vl * add ref * fix head out * add visionlan doc * fix vl infer * update dictpull/6837/head^2
parent
f5692c3f7e
commit
3f65b360ef
|
@ -0,0 +1,106 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 8
|
||||
log_smooth_window: 200
|
||||
print_batch_step: 200
|
||||
save_model_dir: ./output/rec/r45_visionlan
|
||||
save_epoch_step: 1
|
||||
# evaluation is run every 2000 iterations
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: True
|
||||
infer_img: doc/imgs_words/en/word_2.png
|
||||
# for data or label process
|
||||
character_dict_path:
|
||||
max_text_length: &max_text_length 25
|
||||
training_step: &training_step LA
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_visionlan.txt
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
clip_norm: 20.0
|
||||
group_lr: true
|
||||
training_step: *training_step
|
||||
lr:
|
||||
name: Piecewise
|
||||
decay_epochs: [6]
|
||||
values: [0.0001, 0.00001]
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: VisionLAN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet45
|
||||
strides: [2, 2, 2, 1, 1]
|
||||
Head:
|
||||
name: VLHead
|
||||
n_layers: 3
|
||||
n_position: 256
|
||||
n_dim: 512
|
||||
max_text_length: *max_text_length
|
||||
training_step: *training_step
|
||||
|
||||
Loss:
|
||||
name: VLLoss
|
||||
mode: *training_step
|
||||
weight_res: 0.5
|
||||
weight_mas: 0.5
|
||||
|
||||
PostProcess:
|
||||
name: VLLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
is_filter: true
|
||||
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
channel_first: False
|
||||
- ABINetRecAug:
|
||||
- VLLabelEncode: # Class handling label
|
||||
- VLRecResizeImg:
|
||||
image_shape: [3, 64, 256]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 220
|
||||
drop_last: True
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/validation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
channel_first: False
|
||||
- VLLabelEncode: # Class handling label
|
||||
- VLRecResizeImg:
|
||||
image_shape: [3, 64, 256]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 64
|
||||
num_workers: 4
|
||||
|
|
@ -69,6 +69,7 @@
|
|||
- [x] [SVTR](./algorithm_rec_svtr.md)
|
||||
- [x] [ViTSTR](./algorithm_rec_vitstr.md)
|
||||
- [x] [ABINet](./algorithm_rec_abinet.md)
|
||||
- [x] [VisionLAN](./algorithm_rec_visionlan.md)
|
||||
- [x] [SPIN](./algorithm_rec_spin.md)
|
||||
|
||||
参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|
||||
|
@ -90,6 +91,7 @@
|
|||
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|
||||
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) |
|
||||
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|
||||
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|
||||
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,154 @@
|
|||
# 场景文本识别算法-VisionLAN
|
||||
|
||||
- [1. 算法简介](#1)
|
||||
- [2. 环境配置](#2)
|
||||
- [3. 模型训练、评估、预测](#3)
|
||||
- [3.1 训练](#3-1)
|
||||
- [3.2 评估](#3-2)
|
||||
- [3.3 预测](#3-3)
|
||||
- [4. 推理部署](#4)
|
||||
- [4.1 Python推理](#4-1)
|
||||
- [4.2 C++推理](#4-2)
|
||||
- [4.3 Serving服务化部署](#4-3)
|
||||
- [4.4 更多推理部署](#4-4)
|
||||
- [5. FAQ](#5)
|
||||
|
||||
<a name="1"></a>
|
||||
## 1. 算法简介
|
||||
|
||||
论文信息:
|
||||
> [From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network](https://arxiv.org/abs/2108.09661)
|
||||
> Yuxin Wang, Hongtao Xie, Shancheng Fang, Jing Wang, Shenggao Zhu, Yongdong Zhang
|
||||
> ICCV, 2021
|
||||
|
||||
|
||||
<a name="model"></a>
|
||||
`VisionLAN`使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下:
|
||||
|
||||
|模型|骨干网络|配置文件|Acc|下载链接|
|
||||
| --- | --- | --- | --- | --- |
|
||||
|VisionLAN|ResNet45|[rec_r45_visionlan.yml](../../configs/rec/rec_r45_visionlan.yml)|90.3%|[预训练、训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar)|
|
||||
|
||||
<a name="2"></a>
|
||||
## 2. 环境配置
|
||||
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
|
||||
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. 模型训练、评估、预测
|
||||
|
||||
<a name="3-1"></a>
|
||||
### 3.1 模型训练
|
||||
|
||||
请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`VisionLAN`识别模型时需要**更换配置文件**为`VisionLAN`的[配置文件](../../configs/rec/rec_r45_visionlan.yml)。
|
||||
|
||||
#### 启动训练
|
||||
|
||||
|
||||
具体地,在完成数据准备后,便可以启动训练,训练命令如下:
|
||||
```shell
|
||||
#单卡训练(训练周期长,不建议)
|
||||
python3 tools/train.py -c configs/rec/rec_r45_visionlan.yml
|
||||
|
||||
#多卡训练,通过--gpus参数指定卡号
|
||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r45_visionlan.yml
|
||||
```
|
||||
|
||||
<a name="3-2"></a>
|
||||
### 3.2 评估
|
||||
|
||||
可下载已训练完成的[模型文件](#model),使用如下命令进行评估:
|
||||
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。
|
||||
python3 tools/eval.py -c configs/rec/rec_r45_visionlan.yml -o Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy
|
||||
```
|
||||
|
||||
<a name="3-3"></a>
|
||||
### 3.3 预测
|
||||
|
||||
使用如下命令进行单张图片预测:
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。
|
||||
python3 tools/infer_rec.py -c configs/rec/rec_r45_visionlan.yml -o Global.infer_img='./doc/imgs_words/en/word_2.png' Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy
|
||||
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
|
||||
```
|
||||
|
||||
|
||||
<a name="4"></a>
|
||||
## 4. 推理部署
|
||||
|
||||
<a name="4-1"></a>
|
||||
### 4.1 Python推理
|
||||
首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar)),可以使用如下命令进行转换:
|
||||
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。
|
||||
python3 tools/export_model.py -c configs/rec/rec_r45_visionlan.yml -o Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy Global.save_inference_dir=./inference/rec_r45_visionlan/
|
||||
```
|
||||
**注意:**
|
||||
- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否是所需要的字典文件。
|
||||
- 如果您修改了训练时的输入大小,请修改`tools/export_model.py`文件中的对应VisionLAN的`infer_shape`。
|
||||
|
||||
转换成功后,在目录下有三个文件:
|
||||
```
|
||||
./inference/rec_r45_visionlan/
|
||||
├── inference.pdiparams # 识别inference模型的参数文件
|
||||
├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
|
||||
└── inference.pdmodel # 识别inference模型的program文件
|
||||
```
|
||||
|
||||
执行如下命令进行模型推理:
|
||||
|
||||
```shell
|
||||
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words/en/word_2.png' --rec_model_dir='./inference/rec_r45_visionlan/' --rec_algorithm='VisionLAN' --rec_image_shape='3,64,256' --rec_char_dict_path='./ppocr/utils/dict36.txt'
|
||||
# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/imgs_words_en/'。
|
||||
```
|
||||
|
||||

|
||||
|
||||
执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下:
|
||||
结果如下:
|
||||
```shell
|
||||
Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.97076982)
|
||||
```
|
||||
|
||||
**注意**:
|
||||
|
||||
- 训练上述模型采用的图像分辨率是[3,64,256],需要通过参数`rec_image_shape`设置为您训练时的识别图像形状。
|
||||
- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。
|
||||
- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中VisionLAN的预处理为您的预处理方法。
|
||||
|
||||
|
||||
<a name="4-2"></a>
|
||||
### 4.2 C++推理部署
|
||||
|
||||
由于C++预处理后处理还未支持VisionLAN,所以暂未支持
|
||||
|
||||
<a name="4-3"></a>
|
||||
### 4.3 Serving服务化部署
|
||||
|
||||
暂不支持
|
||||
|
||||
<a name="4-4"></a>
|
||||
### 4.4 更多推理部署
|
||||
|
||||
暂不支持
|
||||
|
||||
<a name="5"></a>
|
||||
## 5. FAQ
|
||||
|
||||
1. MJSynth和SynthText两种数据集来自于[VisionLAN源repo](https://github.com/wangyuxin87/VisionLAN) 。
|
||||
2. 我们使用VisionLAN作者提供的预训练模型进行finetune训练。
|
||||
|
||||
## 引用
|
||||
|
||||
```bibtex
|
||||
@inproceedings{wang2021two,
|
||||
title={From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network},
|
||||
author={Wang, Yuxin and Xie, Hongtao and Fang, Shancheng and Wang, Jing and Zhu, Shenggao and Zhang, Yongdong},
|
||||
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
|
||||
pages={14194--14203},
|
||||
year={2021}
|
||||
}
|
||||
```
|
|
@ -68,6 +68,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
|
|||
- [x] [SVTR](./algorithm_rec_svtr_en.md)
|
||||
- [x] [ViTSTR](./algorithm_rec_vitstr_en.md)
|
||||
- [x] [ABINet](./algorithm_rec_abinet_en.md)
|
||||
- [x] [VisionLAN](./algorithm_rec_visionlan_en.md)
|
||||
- [x] [SPIN](./algorithm_rec_spin_en.md)
|
||||
|
||||
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
|
||||
|
@ -89,6 +90,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|
|||
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|
||||
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
|
||||
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|
||||
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|
||||
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,135 @@
|
|||
# VisionLAN
|
||||
|
||||
- [1. Introduction](#1)
|
||||
- [2. Environment](#2)
|
||||
- [3. Model Training / Evaluation / Prediction](#3)
|
||||
- [3.1 Training](#3-1)
|
||||
- [3.2 Evaluation](#3-2)
|
||||
- [3.3 Prediction](#3-3)
|
||||
- [4. Inference and Deployment](#4)
|
||||
- [4.1 Python Inference](#4-1)
|
||||
- [4.2 C++ Inference](#4-2)
|
||||
- [4.3 Serving](#4-3)
|
||||
- [4.4 More](#4-4)
|
||||
- [5. FAQ](#5)
|
||||
|
||||
<a name="1"></a>
|
||||
## 1. Introduction
|
||||
|
||||
Paper:
|
||||
> [From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network](https://arxiv.org/abs/2108.09661)
|
||||
> Yuxin Wang, Hongtao Xie, Shancheng Fang, Jing Wang, Shenggao Zhu, Yongdong Zhang
|
||||
> ICCV, 2021
|
||||
|
||||
Using MJSynth and SynthText two text recognition datasets for training, and evaluating on IIIT, SVT, IC13, IC15, SVTP, CUTE datasets, the algorithm reproduction effect is as follows:
|
||||
|
||||
|Model|Backbone|config|Acc|Download link|
|
||||
| --- | --- | --- | --- | --- |
|
||||
|VisionLAN|ResNet45|[rec_r45_visionlan.yml](../../configs/rec/rec_r45_visionlan.yml)|90.3%|[预训练、训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar)|
|
||||
|
||||
<a name="2"></a>
|
||||
## 2. Environment
|
||||
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
|
||||
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. Model Training / Evaluation / Prediction
|
||||
|
||||
Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
|
||||
|
||||
Training:
|
||||
|
||||
Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
|
||||
|
||||
```
|
||||
#Single GPU training (long training period, not recommended)
|
||||
python3 tools/train.py -c configs/rec/rec_r45_visionlan.yml
|
||||
|
||||
#Multi GPU training, specify the gpu number through the --gpus parameter
|
||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r45_visionlan.yml
|
||||
```
|
||||
|
||||
Evaluation:
|
||||
|
||||
```
|
||||
# GPU evaluation
|
||||
python3 tools/eval.py -c configs/rec/rec_r45_visionlan.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
|
||||
```
|
||||
|
||||
Prediction:
|
||||
|
||||
```
|
||||
# The configuration file used for prediction must match the training
|
||||
python3 tools/infer_rec.py -c configs/rec/rec_r45_visionlan.yml -o Global.infer_img='./doc/imgs_words/en/word_2.png' Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy
|
||||
```
|
||||
|
||||
<a name="4"></a>
|
||||
## 4. Inference and Deployment
|
||||
|
||||
<a name="4-1"></a>
|
||||
### 4.1 Python Inference
|
||||
First, the model saved during the VisionLAN text recognition training process is converted into an inference model. ( [Model download link](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar)) ), you can use the following command to convert:
|
||||
|
||||
```
|
||||
python3 tools/export_model.py -c configs/rec/rec_r45_visionlan.yml -o Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy Global.save_inference_dir=./inference/rec_r45_visionlan/
|
||||
```
|
||||
|
||||
**Note:**
|
||||
- If you are training the model on your own dataset and have modified the dictionary file, please pay attention to modify the `character_dict_path` in the configuration file to the modified dictionary file.
|
||||
- If you modified the input size during training, please modify the `infer_shape` corresponding to VisionLAN in the `tools/export_model.py` file.
|
||||
|
||||
After the conversion is successful, there are three files in the directory:
|
||||
```
|
||||
./inference/rec_r45_visionlan/
|
||||
├── inference.pdiparams
|
||||
├── inference.pdiparams.info
|
||||
└── inference.pdmodel
|
||||
```
|
||||
|
||||
|
||||
For VisionLAN text recognition model inference, the following commands can be executed:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words/en/word_2.png' --rec_model_dir='./inference/rec_r45_visionlan/' --rec_algorithm='VisionLAN' --rec_image_shape='3,64,256' --rec_char_dict_path='./ppocr/utils/dict36.txt'
|
||||
```
|
||||
|
||||

|
||||
|
||||
After executing the command, the prediction result (recognized text and score) of the image above is printed to the screen, an example is as follows:
|
||||
The result is as follows:
|
||||
```shell
|
||||
Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.97076982)
|
||||
```
|
||||
|
||||
<a name="4-2"></a>
|
||||
### 4.2 C++ Inference
|
||||
|
||||
Not supported
|
||||
|
||||
<a name="4-3"></a>
|
||||
### 4.3 Serving
|
||||
|
||||
Not supported
|
||||
|
||||
<a name="4-4"></a>
|
||||
### 4.4 More
|
||||
|
||||
Not supported
|
||||
|
||||
<a name="5"></a>
|
||||
## 5. FAQ
|
||||
|
||||
1. Note that the MJSynth and SynthText datasets come from [VisionLAN repo](https://github.com/wangyuxin87/VisionLAN).
|
||||
2. We use the pre-trained model provided by the VisionLAN authors for finetune training.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@inproceedings{wang2021two,
|
||||
title={From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network},
|
||||
author={Wang, Yuxin and Xie, Hongtao and Fang, Shancheng and Wang, Jing and Zhu, Shenggao and Zhang, Yongdong},
|
||||
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
|
||||
pages={14194--14203},
|
||||
year={2021}
|
||||
}
|
||||
```
|
|
@ -25,8 +25,9 @@ from .make_pse_gt import MakePseGt
|
|||
|
||||
|
||||
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
|
||||
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
|
||||
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, SPINRecResizeImg
|
||||
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
|
||||
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg
|
||||
|
||||
from .ssl_img_aug import SSLRotateResize
|
||||
from .randaugment import RandAugment
|
||||
from .copy_paste import CopyPaste
|
||||
|
|
|
@ -23,6 +23,8 @@ import string
|
|||
from shapely.geometry import LineString, Point, Polygon
|
||||
import json
|
||||
import copy
|
||||
from random import sample
|
||||
|
||||
from ppocr.utils.logging import get_logger
|
||||
from ppocr.data.imaug.vqa.augment import order_by_tbyx
|
||||
|
||||
|
@ -98,12 +100,13 @@ class BaseRecLabelEncode(object):
|
|||
def __init__(self,
|
||||
max_text_length,
|
||||
character_dict_path=None,
|
||||
use_space_char=False):
|
||||
use_space_char=False,
|
||||
lower=False):
|
||||
|
||||
self.max_text_len = max_text_length
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
self.lower = False
|
||||
self.lower = lower
|
||||
|
||||
if character_dict_path is None:
|
||||
logger = get_logger()
|
||||
|
@ -1266,3 +1269,67 @@ class SPINLabelEncode(AttnLabelEncode):
|
|||
padded_text[:len(target)] = target
|
||||
data['label'] = np.array(padded_text)
|
||||
return data
|
||||
|
||||
|
||||
class VLLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self,
|
||||
max_text_length,
|
||||
character_dict_path=None,
|
||||
use_space_char=False,
|
||||
lower=True,
|
||||
**kwargs):
|
||||
super(VLLabelEncode, self).__init__(
|
||||
max_text_length, character_dict_path, use_space_char, lower)
|
||||
self.character = self.character[10:] + self.character[
|
||||
1:10] + [self.character[0]]
|
||||
self.dict = {}
|
||||
for i, char in enumerate(self.character):
|
||||
self.dict[char] = i
|
||||
|
||||
def __call__(self, data):
|
||||
text = data['label'] # original string
|
||||
# generate occluded text
|
||||
len_str = len(text)
|
||||
if len_str <= 0:
|
||||
return None
|
||||
change_num = 1
|
||||
order = list(range(len_str))
|
||||
change_id = sample(order, change_num)[0]
|
||||
label_sub = text[change_id]
|
||||
if change_id == (len_str - 1):
|
||||
label_res = text[:change_id]
|
||||
elif change_id == 0:
|
||||
label_res = text[1:]
|
||||
else:
|
||||
label_res = text[:change_id] + text[change_id + 1:]
|
||||
|
||||
data['label_res'] = label_res # remaining string
|
||||
data['label_sub'] = label_sub # occluded character
|
||||
data['label_id'] = change_id # character index
|
||||
# encode label
|
||||
text = self.encode(text)
|
||||
if text is None:
|
||||
return None
|
||||
text = [i + 1 for i in text]
|
||||
data['length'] = np.array(len(text))
|
||||
text = text + [0] * (self.max_text_len - len(text))
|
||||
data['label'] = np.array(text)
|
||||
label_res = self.encode(label_res)
|
||||
label_sub = self.encode(label_sub)
|
||||
if label_res is None:
|
||||
label_res = []
|
||||
else:
|
||||
label_res = [i + 1 for i in label_res]
|
||||
if label_sub is None:
|
||||
label_sub = []
|
||||
else:
|
||||
label_sub = [i + 1 for i in label_sub]
|
||||
data['length_res'] = np.array(len(label_res))
|
||||
data['length_sub'] = np.array(len(label_sub))
|
||||
label_res = label_res + [0] * (self.max_text_len - len(label_res))
|
||||
label_sub = label_sub + [0] * (self.max_text_len - len(label_sub))
|
||||
data['label_res'] = np.array(label_res)
|
||||
data['label_sub'] = np.array(label_sub)
|
||||
return data
|
||||
|
|
|
@ -205,6 +205,38 @@ class RecResizeImg(object):
|
|||
return data
|
||||
|
||||
|
||||
class VLRecResizeImg(object):
|
||||
def __init__(self,
|
||||
image_shape,
|
||||
infer_mode=False,
|
||||
character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
|
||||
padding=True,
|
||||
**kwargs):
|
||||
self.image_shape = image_shape
|
||||
self.infer_mode = infer_mode
|
||||
self.character_dict_path = character_dict_path
|
||||
self.padding = padding
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
|
||||
imgC, imgH, imgW = self.image_shape
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_w = imgW
|
||||
resized_image = resized_image.astype('float32')
|
||||
if self.image_shape[0] == 1:
|
||||
resized_image = resized_image / 255
|
||||
norm_img = resized_image[np.newaxis, :]
|
||||
else:
|
||||
norm_img = resized_image.transpose((2, 0, 1)) / 255
|
||||
valid_ratio = min(1.0, float(resized_w / imgW))
|
||||
|
||||
data['image'] = norm_img
|
||||
data['valid_ratio'] = valid_ratio
|
||||
return data
|
||||
|
||||
|
||||
class SRNRecResizeImg(object):
|
||||
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
|
||||
self.image_shape = image_shape
|
||||
|
@ -259,6 +291,7 @@ class PRENResizeImg(object):
|
|||
data['image'] = resized_img.astype(np.float32)
|
||||
return data
|
||||
|
||||
|
||||
class SPINRecResizeImg(object):
|
||||
def __init__(self,
|
||||
image_shape,
|
||||
|
@ -303,6 +336,7 @@ class SPINRecResizeImg(object):
|
|||
data['image'] = img
|
||||
return data
|
||||
|
||||
|
||||
class GrayRecResizeImg(object):
|
||||
def __init__(self,
|
||||
image_shape,
|
||||
|
|
|
@ -35,6 +35,7 @@ from .rec_sar_loss import SARLoss
|
|||
from .rec_aster_loss import AsterLoss
|
||||
from .rec_pren_loss import PRENLoss
|
||||
from .rec_multi_loss import MultiLoss
|
||||
from .rec_vl_loss import VLLoss
|
||||
from .rec_spin_att_loss import SPINAttentionLoss
|
||||
|
||||
# cls loss
|
||||
|
@ -63,7 +64,7 @@ def build_loss(config):
|
|||
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
|
||||
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
|
||||
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
|
||||
'TableMasterLoss', 'SPINAttentionLoss'
|
||||
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss'
|
||||
]
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/wangyuxin87/VisionLAN
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class VLLoss(nn.Layer):
|
||||
def __init__(self, mode='LF_1', weight_res=0.5, weight_mas=0.5, **kwargs):
|
||||
super(VLLoss, self).__init__()
|
||||
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean")
|
||||
assert mode in ['LF_1', 'LF_2', 'LA']
|
||||
self.mode = mode
|
||||
self.weight_res = weight_res
|
||||
self.weight_mas = weight_mas
|
||||
|
||||
def flatten_label(self, target):
|
||||
label_flatten = []
|
||||
label_length = []
|
||||
for i in range(0, target.shape[0]):
|
||||
cur_label = target[i].tolist()
|
||||
label_flatten += cur_label[:cur_label.index(0) + 1]
|
||||
label_length.append(cur_label.index(0) + 1)
|
||||
label_flatten = paddle.to_tensor(label_flatten, dtype='int64')
|
||||
label_length = paddle.to_tensor(label_length, dtype='int32')
|
||||
return (label_flatten, label_length)
|
||||
|
||||
def _flatten(self, sources, lengths):
|
||||
return paddle.concat([t[:l] for t, l in zip(sources, lengths)])
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
text_pre = predicts[0]
|
||||
target = batch[1].astype('int64')
|
||||
label_flatten, length = self.flatten_label(target)
|
||||
text_pre = self._flatten(text_pre, length)
|
||||
if self.mode == 'LF_1':
|
||||
loss = self.loss_func(text_pre, label_flatten)
|
||||
else:
|
||||
text_rem = predicts[1]
|
||||
text_mas = predicts[2]
|
||||
target_res = batch[2].astype('int64')
|
||||
target_sub = batch[3].astype('int64')
|
||||
label_flatten_res, length_res = self.flatten_label(target_res)
|
||||
label_flatten_sub, length_sub = self.flatten_label(target_sub)
|
||||
text_rem = self._flatten(text_rem, length_res)
|
||||
text_mas = self._flatten(text_mas, length_sub)
|
||||
loss_ori = self.loss_func(text_pre, label_flatten)
|
||||
loss_res = self.loss_func(text_rem, label_flatten_res)
|
||||
loss_mas = self.loss_func(text_mas, label_flatten_sub)
|
||||
loss = loss_ori + loss_res * self.weight_res + loss_mas * self.weight_mas
|
||||
return {'loss': loss}
|
|
@ -84,11 +84,15 @@ class BasicBlock(nn.Layer):
|
|||
|
||||
|
||||
class ResNet45(nn.Layer):
|
||||
def __init__(self, block=BasicBlock, layers=[3, 4, 6, 6, 3], in_channels=3):
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
block=BasicBlock,
|
||||
layers=[3, 4, 6, 6, 3],
|
||||
strides=[2, 1, 2, 1, 1]):
|
||||
self.inplanes = 32
|
||||
super(ResNet45, self).__init__()
|
||||
self.conv1 = nn.Conv2D(
|
||||
3,
|
||||
in_channels,
|
||||
32,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
|
@ -98,18 +102,13 @@ class ResNet45(nn.Layer):
|
|||
self.bn1 = nn.BatchNorm2D(32)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
self.layer1 = self._make_layer(block, 32, layers[0], stride=2)
|
||||
self.layer2 = self._make_layer(block, 64, layers[1], stride=1)
|
||||
self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 256, layers[3], stride=1)
|
||||
self.layer5 = self._make_layer(block, 512, layers[4], stride=1)
|
||||
self.layer1 = self._make_layer(block, 32, layers[0], stride=strides[0])
|
||||
self.layer2 = self._make_layer(block, 64, layers[1], stride=strides[1])
|
||||
self.layer3 = self._make_layer(block, 128, layers[2], stride=strides[2])
|
||||
self.layer4 = self._make_layer(block, 256, layers[3], stride=strides[3])
|
||||
self.layer5 = self._make_layer(block, 512, layers[4], stride=strides[4])
|
||||
self.out_channels = 512
|
||||
|
||||
# for m in self.modules():
|
||||
# if isinstance(m, nn.Conv2D):
|
||||
# n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
|
||||
# m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
|
@ -137,11 +136,9 @@ class ResNet45(nn.Layer):
|
|||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
# print(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
# print(x)
|
||||
x = self.layer4(x)
|
||||
x = self.layer5(x)
|
||||
return x
|
||||
|
|
|
@ -35,6 +35,7 @@ def build_head(config):
|
|||
from .rec_multi_head import MultiHead
|
||||
from .rec_spin_att_head import SPINAttentionHead
|
||||
from .rec_abinet_head import ABINetHead
|
||||
from .rec_visionlan_head import VLHead
|
||||
|
||||
# cls head
|
||||
from .cls_head import ClsHead
|
||||
|
@ -49,7 +50,8 @@ def build_head(config):
|
|||
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
|
||||
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
|
||||
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
|
||||
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead'
|
||||
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
|
||||
'VLHead'
|
||||
]
|
||||
|
||||
#table head
|
||||
|
|
|
@ -0,0 +1,468 @@
|
|||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/wangyuxin87/VisionLAN
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn.initializer import Normal, XavierNormal
|
||||
import numpy as np
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Layer):
|
||||
def __init__(self, d_hid, n_position=200):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
self.register_buffer(
|
||||
'pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
|
||||
|
||||
def _get_sinusoid_encoding_table(self, n_position, d_hid):
|
||||
''' Sinusoid position encoding table '''
|
||||
|
||||
def get_position_angle_vec(position):
|
||||
return [
|
||||
position / np.power(10000, 2 * (hid_j // 2) / d_hid)
|
||||
for hid_j in range(d_hid)
|
||||
]
|
||||
|
||||
sinusoid_table = np.array(
|
||||
[get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||
sinusoid_table = paddle.to_tensor(sinusoid_table, dtype='float32')
|
||||
sinusoid_table = paddle.unsqueeze(sinusoid_table, axis=0)
|
||||
return sinusoid_table
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.pos_table[:, :x.shape[1]].clone().detach()
|
||||
|
||||
|
||||
class ScaledDotProductAttention(nn.Layer):
|
||||
"Scaled Dot-Product Attention"
|
||||
|
||||
def __init__(self, temperature, attn_dropout=0.1):
|
||||
super(ScaledDotProductAttention, self).__init__()
|
||||
self.temperature = temperature
|
||||
self.dropout = nn.Dropout(attn_dropout)
|
||||
self.softmax = nn.Softmax(axis=2)
|
||||
|
||||
def forward(self, q, k, v, mask=None):
|
||||
k = paddle.transpose(k, perm=[0, 2, 1])
|
||||
attn = paddle.bmm(q, k)
|
||||
attn = attn / self.temperature
|
||||
if mask is not None:
|
||||
attn = attn.masked_fill(mask, -1e9)
|
||||
if mask.dim() == 3:
|
||||
mask = paddle.unsqueeze(mask, axis=1)
|
||||
elif mask.dim() == 2:
|
||||
mask = paddle.unsqueeze(mask, axis=1)
|
||||
mask = paddle.unsqueeze(mask, axis=1)
|
||||
repeat_times = [
|
||||
attn.shape[1] // mask.shape[1], attn.shape[2] // mask.shape[2]
|
||||
]
|
||||
mask = paddle.tile(mask, [1, repeat_times[0], repeat_times[1], 1])
|
||||
attn[mask == 0] = -1e9
|
||||
attn = self.softmax(attn)
|
||||
attn = self.dropout(attn)
|
||||
output = paddle.bmm(attn, v)
|
||||
return output
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Layer):
|
||||
" Multi-Head Attention module"
|
||||
|
||||
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
|
||||
super(MultiHeadAttention, self).__init__()
|
||||
self.n_head = n_head
|
||||
self.d_k = d_k
|
||||
self.d_v = d_v
|
||||
self.w_qs = nn.Linear(
|
||||
d_model,
|
||||
n_head * d_k,
|
||||
weight_attr=ParamAttr(initializer=Normal(
|
||||
mean=0, std=np.sqrt(2.0 / (d_model + d_k)))))
|
||||
self.w_ks = nn.Linear(
|
||||
d_model,
|
||||
n_head * d_k,
|
||||
weight_attr=ParamAttr(initializer=Normal(
|
||||
mean=0, std=np.sqrt(2.0 / (d_model + d_k)))))
|
||||
self.w_vs = nn.Linear(
|
||||
d_model,
|
||||
n_head * d_v,
|
||||
weight_attr=ParamAttr(initializer=Normal(
|
||||
mean=0, std=np.sqrt(2.0 / (d_model + d_v)))))
|
||||
|
||||
self.attention = ScaledDotProductAttention(temperature=np.power(d_k,
|
||||
0.5))
|
||||
self.layer_norm = nn.LayerNorm(d_model)
|
||||
self.fc = nn.Linear(
|
||||
n_head * d_v,
|
||||
d_model,
|
||||
weight_attr=ParamAttr(initializer=XavierNormal()))
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, q, k, v, mask=None):
|
||||
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
|
||||
sz_b, len_q, _ = q.shape
|
||||
sz_b, len_k, _ = k.shape
|
||||
sz_b, len_v, _ = v.shape
|
||||
residual = q
|
||||
|
||||
q = self.w_qs(q)
|
||||
q = paddle.reshape(
|
||||
q, shape=[-1, len_q, n_head, d_k]) # 4*21*512 ---- 4*21*8*64
|
||||
k = self.w_ks(k)
|
||||
k = paddle.reshape(k, shape=[-1, len_k, n_head, d_k])
|
||||
v = self.w_vs(v)
|
||||
v = paddle.reshape(v, shape=[-1, len_v, n_head, d_v])
|
||||
|
||||
q = paddle.transpose(q, perm=[2, 0, 1, 3])
|
||||
q = paddle.reshape(q, shape=[-1, len_q, d_k]) # (n*b) x lq x dk
|
||||
k = paddle.transpose(k, perm=[2, 0, 1, 3])
|
||||
k = paddle.reshape(k, shape=[-1, len_k, d_k]) # (n*b) x lk x dk
|
||||
v = paddle.transpose(v, perm=[2, 0, 1, 3])
|
||||
v = paddle.reshape(v, shape=[-1, len_v, d_v]) # (n*b) x lv x dv
|
||||
|
||||
mask = paddle.tile(
|
||||
mask,
|
||||
[n_head, 1, 1]) if mask is not None else None # (n*b) x .. x ..
|
||||
output = self.attention(q, k, v, mask=mask)
|
||||
output = paddle.reshape(output, shape=[n_head, -1, len_q, d_v])
|
||||
output = paddle.transpose(output, perm=[1, 2, 0, 3])
|
||||
output = paddle.reshape(
|
||||
output, shape=[-1, len_q, n_head * d_v]) # b x lq x (n*dv)
|
||||
output = self.dropout(self.fc(output))
|
||||
output = self.layer_norm(output + residual)
|
||||
return output
|
||||
|
||||
|
||||
class PositionwiseFeedForward(nn.Layer):
|
||||
def __init__(self, d_in, d_hid, dropout=0.1):
|
||||
super(PositionwiseFeedForward, self).__init__()
|
||||
self.w_1 = nn.Conv1D(d_in, d_hid, 1) # position-wise
|
||||
self.w_2 = nn.Conv1D(d_hid, d_in, 1) # position-wise
|
||||
self.layer_norm = nn.LayerNorm(d_in)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
x = paddle.transpose(x, perm=[0, 2, 1])
|
||||
x = self.w_2(F.relu(self.w_1(x)))
|
||||
x = paddle.transpose(x, perm=[0, 2, 1])
|
||||
x = self.dropout(x)
|
||||
x = self.layer_norm(x + residual)
|
||||
return x
|
||||
|
||||
|
||||
class EncoderLayer(nn.Layer):
|
||||
''' Compose with two layers '''
|
||||
|
||||
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
|
||||
super(EncoderLayer, self).__init__()
|
||||
self.slf_attn = MultiHeadAttention(
|
||||
n_head, d_model, d_k, d_v, dropout=dropout)
|
||||
self.pos_ffn = PositionwiseFeedForward(
|
||||
d_model, d_inner, dropout=dropout)
|
||||
|
||||
def forward(self, enc_input, slf_attn_mask=None):
|
||||
enc_output = self.slf_attn(
|
||||
enc_input, enc_input, enc_input, mask=slf_attn_mask)
|
||||
enc_output = self.pos_ffn(enc_output)
|
||||
return enc_output
|
||||
|
||||
|
||||
class Transformer_Encoder(nn.Layer):
|
||||
def __init__(self,
|
||||
n_layers=2,
|
||||
n_head=8,
|
||||
d_word_vec=512,
|
||||
d_k=64,
|
||||
d_v=64,
|
||||
d_model=512,
|
||||
d_inner=2048,
|
||||
dropout=0.1,
|
||||
n_position=256):
|
||||
super(Transformer_Encoder, self).__init__()
|
||||
self.position_enc = PositionalEncoding(
|
||||
d_word_vec, n_position=n_position)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
self.layer_stack = nn.LayerList([
|
||||
EncoderLayer(
|
||||
d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
|
||||
for _ in range(n_layers)
|
||||
])
|
||||
self.layer_norm = nn.LayerNorm(d_model, epsilon=1e-6)
|
||||
|
||||
def forward(self, enc_output, src_mask, return_attns=False):
|
||||
enc_output = self.dropout(
|
||||
self.position_enc(enc_output)) # position embeding
|
||||
for enc_layer in self.layer_stack:
|
||||
enc_output = enc_layer(enc_output, slf_attn_mask=src_mask)
|
||||
enc_output = self.layer_norm(enc_output)
|
||||
return enc_output
|
||||
|
||||
|
||||
class PP_layer(nn.Layer):
|
||||
def __init__(self, n_dim=512, N_max_character=25, n_position=256):
|
||||
|
||||
super(PP_layer, self).__init__()
|
||||
self.character_len = N_max_character
|
||||
self.f0_embedding = nn.Embedding(N_max_character, n_dim)
|
||||
self.w0 = nn.Linear(N_max_character, n_position)
|
||||
self.wv = nn.Linear(n_dim, n_dim)
|
||||
self.we = nn.Linear(n_dim, N_max_character)
|
||||
self.active = nn.Tanh()
|
||||
self.softmax = nn.Softmax(axis=2)
|
||||
|
||||
def forward(self, enc_output):
|
||||
# enc_output: b,256,512
|
||||
reading_order = paddle.arange(self.character_len, dtype='int64')
|
||||
reading_order = reading_order.unsqueeze(0).expand(
|
||||
[enc_output.shape[0], self.character_len]) # (S,) -> (B, S)
|
||||
reading_order = self.f0_embedding(reading_order) # b,25,512
|
||||
|
||||
# calculate attention
|
||||
reading_order = paddle.transpose(reading_order, perm=[0, 2, 1])
|
||||
t = self.w0(reading_order) # b,512,256
|
||||
t = self.active(
|
||||
paddle.transpose(
|
||||
t, perm=[0, 2, 1]) + self.wv(enc_output)) # b,256,512
|
||||
t = self.we(t) # b,256,25
|
||||
t = self.softmax(paddle.transpose(t, perm=[0, 2, 1])) # b,25,256
|
||||
g_output = paddle.bmm(t, enc_output) # b,25,512
|
||||
return g_output
|
||||
|
||||
|
||||
class Prediction(nn.Layer):
|
||||
def __init__(self,
|
||||
n_dim=512,
|
||||
n_position=256,
|
||||
N_max_character=25,
|
||||
n_class=37):
|
||||
super(Prediction, self).__init__()
|
||||
self.pp = PP_layer(
|
||||
n_dim=n_dim, N_max_character=N_max_character, n_position=n_position)
|
||||
self.pp_share = PP_layer(
|
||||
n_dim=n_dim, N_max_character=N_max_character, n_position=n_position)
|
||||
self.w_vrm = nn.Linear(n_dim, n_class) # output layer
|
||||
self.w_share = nn.Linear(n_dim, n_class) # output layer
|
||||
self.nclass = n_class
|
||||
|
||||
def forward(self, cnn_feature, f_res, f_sub, train_mode=False,
|
||||
use_mlm=True):
|
||||
if train_mode:
|
||||
if not use_mlm:
|
||||
g_output = self.pp(cnn_feature) # b,25,512
|
||||
g_output = self.w_vrm(g_output)
|
||||
f_res = 0
|
||||
f_sub = 0
|
||||
return g_output, f_res, f_sub
|
||||
g_output = self.pp(cnn_feature) # b,25,512
|
||||
f_res = self.pp_share(f_res)
|
||||
f_sub = self.pp_share(f_sub)
|
||||
g_output = self.w_vrm(g_output)
|
||||
f_res = self.w_share(f_res)
|
||||
f_sub = self.w_share(f_sub)
|
||||
return g_output, f_res, f_sub
|
||||
else:
|
||||
g_output = self.pp(cnn_feature) # b,25,512
|
||||
g_output = self.w_vrm(g_output)
|
||||
return g_output
|
||||
|
||||
|
||||
class MLM(nn.Layer):
|
||||
"Architecture of MLM"
|
||||
|
||||
def __init__(self, n_dim=512, n_position=256, max_text_length=25):
|
||||
super(MLM, self).__init__()
|
||||
self.MLM_SequenceModeling_mask = Transformer_Encoder(
|
||||
n_layers=2, n_position=n_position)
|
||||
self.MLM_SequenceModeling_WCL = Transformer_Encoder(
|
||||
n_layers=1, n_position=n_position)
|
||||
self.pos_embedding = nn.Embedding(max_text_length, n_dim)
|
||||
self.w0_linear = nn.Linear(1, n_position)
|
||||
self.wv = nn.Linear(n_dim, n_dim)
|
||||
self.active = nn.Tanh()
|
||||
self.we = nn.Linear(n_dim, 1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x, label_pos):
|
||||
# transformer unit for generating mask_c
|
||||
feature_v_seq = self.MLM_SequenceModeling_mask(x, src_mask=None)
|
||||
# position embedding layer
|
||||
label_pos = paddle.to_tensor(label_pos, dtype='int64')
|
||||
pos_emb = self.pos_embedding(label_pos)
|
||||
pos_emb = self.w0_linear(paddle.unsqueeze(pos_emb, axis=2))
|
||||
pos_emb = paddle.transpose(pos_emb, perm=[0, 2, 1])
|
||||
# fusion position embedding with features V & generate mask_c
|
||||
att_map_sub = self.active(pos_emb + self.wv(feature_v_seq))
|
||||
att_map_sub = self.we(att_map_sub) # b,256,1
|
||||
att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1])
|
||||
att_map_sub = self.sigmoid(att_map_sub) # b,1,256
|
||||
# WCL
|
||||
## generate inputs for WCL
|
||||
att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1])
|
||||
f_res = x * (1 - att_map_sub) # second path with remaining string
|
||||
f_sub = x * att_map_sub # first path with occluded character
|
||||
## transformer units in WCL
|
||||
f_res = self.MLM_SequenceModeling_WCL(f_res, src_mask=None)
|
||||
f_sub = self.MLM_SequenceModeling_WCL(f_sub, src_mask=None)
|
||||
return f_res, f_sub, att_map_sub
|
||||
|
||||
|
||||
def trans_1d_2d(x):
|
||||
b, w_h, c = x.shape # b, 256, 512
|
||||
x = paddle.transpose(x, perm=[0, 2, 1])
|
||||
x = paddle.reshape(x, [-1, c, 32, 8])
|
||||
x = paddle.transpose(x, perm=[0, 1, 3, 2]) # [b, c, 8, 32]
|
||||
return x
|
||||
|
||||
|
||||
class MLM_VRM(nn.Layer):
|
||||
"""
|
||||
MLM+VRM, MLM is only used in training.
|
||||
ratio controls the occluded number in a batch.
|
||||
The pipeline of VisionLAN in testing is very concise with only a backbone + sequence modeling(transformer unit) + prediction layer(pp layer).
|
||||
x: input image
|
||||
label_pos: character index
|
||||
training_step: LF or LA process
|
||||
output
|
||||
text_pre: prediction of VRM
|
||||
test_rem: prediction of remaining string in MLM
|
||||
text_mas: prediction of occluded character in MLM
|
||||
mask_c_show: visualization of Mask_c
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
n_layers=3,
|
||||
n_position=256,
|
||||
n_dim=512,
|
||||
max_text_length=25,
|
||||
nclass=37):
|
||||
super(MLM_VRM, self).__init__()
|
||||
self.MLM = MLM(n_dim=n_dim,
|
||||
n_position=n_position,
|
||||
max_text_length=max_text_length)
|
||||
self.SequenceModeling = Transformer_Encoder(
|
||||
n_layers=n_layers, n_position=n_position)
|
||||
self.Prediction = Prediction(
|
||||
n_dim=n_dim,
|
||||
n_position=n_position,
|
||||
N_max_character=max_text_length +
|
||||
1, # N_max_character = 1 eos + 25 characters
|
||||
n_class=nclass)
|
||||
self.nclass = nclass
|
||||
self.max_text_length = max_text_length
|
||||
|
||||
def forward(self, x, label_pos, training_step, train_mode=False):
|
||||
b, c, h, w = x.shape
|
||||
nT = self.max_text_length
|
||||
x = paddle.transpose(x, perm=[0, 1, 3, 2])
|
||||
x = paddle.reshape(x, [-1, c, h * w])
|
||||
x = paddle.transpose(x, perm=[0, 2, 1])
|
||||
if train_mode:
|
||||
if training_step == 'LF_1':
|
||||
f_res = 0
|
||||
f_sub = 0
|
||||
x = self.SequenceModeling(x, src_mask=None)
|
||||
text_pre, test_rem, text_mas = self.Prediction(
|
||||
x, f_res, f_sub, train_mode=True, use_mlm=False)
|
||||
return text_pre, text_pre, text_pre, text_pre
|
||||
elif training_step == 'LF_2':
|
||||
# MLM
|
||||
f_res, f_sub, mask_c = self.MLM(x, label_pos)
|
||||
x = self.SequenceModeling(x, src_mask=None)
|
||||
text_pre, test_rem, text_mas = self.Prediction(
|
||||
x, f_res, f_sub, train_mode=True)
|
||||
mask_c_show = trans_1d_2d(mask_c)
|
||||
return text_pre, test_rem, text_mas, mask_c_show
|
||||
elif training_step == 'LA':
|
||||
# MLM
|
||||
f_res, f_sub, mask_c = self.MLM(x, label_pos)
|
||||
## use the mask_c (1 for occluded character and 0 for remaining characters) to occlude input
|
||||
## ratio controls the occluded number in a batch
|
||||
character_mask = paddle.zeros_like(mask_c)
|
||||
|
||||
ratio = b // 2
|
||||
if ratio >= 1:
|
||||
with paddle.no_grad():
|
||||
character_mask[0:ratio, :, :] = mask_c[0:ratio, :, :]
|
||||
else:
|
||||
character_mask = mask_c
|
||||
x = x * (1 - character_mask)
|
||||
# VRM
|
||||
## transformer unit for VRM
|
||||
x = self.SequenceModeling(x, src_mask=None)
|
||||
## prediction layer for MLM and VSR
|
||||
text_pre, test_rem, text_mas = self.Prediction(
|
||||
x, f_res, f_sub, train_mode=True)
|
||||
mask_c_show = trans_1d_2d(mask_c)
|
||||
return text_pre, test_rem, text_mas, mask_c_show
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else: # VRM is only used in the testing stage
|
||||
f_res = 0
|
||||
f_sub = 0
|
||||
contextual_feature = self.SequenceModeling(x, src_mask=None)
|
||||
text_pre = self.Prediction(
|
||||
contextual_feature,
|
||||
f_res,
|
||||
f_sub,
|
||||
train_mode=False,
|
||||
use_mlm=False)
|
||||
text_pre = paddle.transpose(
|
||||
text_pre, perm=[1, 0, 2]) # (26, b, 37))
|
||||
return text_pre, x
|
||||
|
||||
|
||||
class VLHead(nn.Layer):
|
||||
"""
|
||||
Architecture of VisionLAN
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels=36,
|
||||
n_layers=3,
|
||||
n_position=256,
|
||||
n_dim=512,
|
||||
max_text_length=25,
|
||||
training_step='LA'):
|
||||
super(VLHead, self).__init__()
|
||||
self.MLM_VRM = MLM_VRM(
|
||||
n_layers=n_layers,
|
||||
n_position=n_position,
|
||||
n_dim=n_dim,
|
||||
max_text_length=max_text_length,
|
||||
nclass=out_channels + 1)
|
||||
self.training_step = training_step
|
||||
|
||||
def forward(self, feat, targets=None):
|
||||
|
||||
if self.training:
|
||||
label_pos = targets[-2]
|
||||
text_pre, test_rem, text_mas, mask_map = self.MLM_VRM(
|
||||
feat, label_pos, self.training_step, train_mode=True)
|
||||
return text_pre, test_rem, text_mas, mask_map
|
||||
else:
|
||||
text_pre, x = self.MLM_VRM(
|
||||
feat, targets, self.training_step, train_mode=False)
|
||||
return text_pre, x
|
|
@ -77,11 +77,62 @@ class Adam(object):
|
|||
self.grad_clip = grad_clip
|
||||
self.name = name
|
||||
self.lazy_mode = lazy_mode
|
||||
self.group_lr = kwargs.get('group_lr', False)
|
||||
self.training_step = kwargs.get('training_step', None)
|
||||
|
||||
def __call__(self, model):
|
||||
train_params = [
|
||||
param for param in model.parameters() if param.trainable is True
|
||||
]
|
||||
if self.group_lr:
|
||||
if self.training_step == 'LF_2':
|
||||
import paddle
|
||||
if isinstance(model, paddle.fluid.dygraph.parallel.
|
||||
DataParallel): # multi gpu
|
||||
mlm = model._layers.head.MLM_VRM.MLM.parameters()
|
||||
pre_mlm_pp = model._layers.head.MLM_VRM.Prediction.pp_share.parameters(
|
||||
)
|
||||
pre_mlm_w = model._layers.head.MLM_VRM.Prediction.w_share.parameters(
|
||||
)
|
||||
else: # single gpu
|
||||
mlm = model.head.MLM_VRM.MLM.parameters()
|
||||
pre_mlm_pp = model.head.MLM_VRM.Prediction.pp_share.parameters(
|
||||
)
|
||||
pre_mlm_w = model.head.MLM_VRM.Prediction.w_share.parameters(
|
||||
)
|
||||
|
||||
total = []
|
||||
for param in mlm:
|
||||
total.append(id(param))
|
||||
for param in pre_mlm_pp:
|
||||
total.append(id(param))
|
||||
for param in pre_mlm_w:
|
||||
total.append(id(param))
|
||||
|
||||
group_base_params = [
|
||||
param for param in model.parameters() if id(param) in total
|
||||
]
|
||||
group_small_params = [
|
||||
param for param in model.parameters()
|
||||
if id(param) not in total
|
||||
]
|
||||
train_params = [{
|
||||
'params': group_base_params
|
||||
}, {
|
||||
'params': group_small_params,
|
||||
'learning_rate': self.learning_rate.values[0] * 0.1
|
||||
}]
|
||||
|
||||
else:
|
||||
print(
|
||||
'group lr currently only support VisionLAN in LF_2 training step'
|
||||
)
|
||||
train_params = [
|
||||
param for param in model.parameters()
|
||||
if param.trainable is True
|
||||
]
|
||||
else:
|
||||
train_params = [
|
||||
param for param in model.parameters() if param.trainable is True
|
||||
]
|
||||
|
||||
opt = optim.Adam(
|
||||
learning_rate=self.learning_rate,
|
||||
beta1=self.beta1,
|
||||
|
|
|
@ -28,7 +28,7 @@ from .fce_postprocess import FCEPostProcess
|
|||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
|
||||
DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
|
||||
SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
|
||||
SPINLabelDecode
|
||||
SPINLabelDecode, VLLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
from .pg_postprocess import PGPostProcess
|
||||
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess
|
||||
|
@ -38,31 +38,16 @@ from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
|
|||
|
||||
def build_post_process(config, global_config=None):
|
||||
support_dict = [
|
||||
'DBPostProcess',
|
||||
'EASTPostProcess',
|
||||
'SASTPostProcess',
|
||||
'FCEPostProcess',
|
||||
'CTCLabelDecode',
|
||||
'AttnLabelDecode',
|
||||
'ClsPostProcess',
|
||||
'SRNLabelDecode',
|
||||
'PGPostProcess',
|
||||
'DistillationCTCLabelDecode',
|
||||
'TableLabelDecode',
|
||||
'DistillationDBPostProcess',
|
||||
'NRTRLabelDecode',
|
||||
'SARLabelDecode',
|
||||
'SEEDLabelDecode',
|
||||
'VQASerTokenLayoutLMPostProcess',
|
||||
'VQAReTokenLayoutLMPostProcess',
|
||||
'PRENLabelDecode',
|
||||
'DistillationSARLabelDecode',
|
||||
'ViTSTRLabelDecode',
|
||||
'ABINetLabelDecode',
|
||||
'TableMasterLabelDecode',
|
||||
'SPINLabelDecode',
|
||||
'DistillationSerPostProcess',
|
||||
'DistillationRePostProcess',
|
||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'FCEPostProcess',
|
||||
'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode',
|
||||
'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode',
|
||||
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
|
||||
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
|
||||
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
|
||||
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
|
||||
'TableMasterLabelDecode', 'SPINLabelDecode',
|
||||
'DistillationSerPostProcess', 'DistillationRePostProcess',
|
||||
'VLLabelDecode'
|
||||
]
|
||||
|
||||
if config['name'] == 'PSEPostProcess':
|
||||
|
|
|
@ -668,6 +668,7 @@ class ABINetLabelDecode(NRTRLabelDecode):
|
|||
dict_character = ['</s>'] + dict_character
|
||||
return dict_character
|
||||
|
||||
|
||||
class SPINLabelDecode(AttnLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
|
@ -682,3 +683,105 @@ class SPINLabelDecode(AttnLabelDecode):
|
|||
dict_character = dict_character
|
||||
dict_character = [self.beg_str] + [self.end_str] + dict_character
|
||||
return dict_character
|
||||
|
||||
|
||||
class VLLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, character_dict_path=None, use_space_char=False,
|
||||
**kwargs):
|
||||
super(VLLabelDecode, self).__init__(character_dict_path, use_space_char)
|
||||
self.max_text_length = kwargs.get('max_text_length', 25)
|
||||
self.nclass = len(self.character) + 1
|
||||
self.character = self.character[10:] + self.character[
|
||||
1:10] + [self.character[0]]
|
||||
|
||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||
""" convert text-index into text-label. """
|
||||
result_list = []
|
||||
ignored_tokens = self.get_ignored_tokens()
|
||||
batch_size = len(text_index)
|
||||
for batch_idx in range(batch_size):
|
||||
selection = np.ones(len(text_index[batch_idx]), dtype=bool)
|
||||
if is_remove_duplicate:
|
||||
selection[1:] = text_index[batch_idx][1:] != text_index[
|
||||
batch_idx][:-1]
|
||||
for ignored_token in ignored_tokens:
|
||||
selection &= text_index[batch_idx] != ignored_token
|
||||
|
||||
char_list = [
|
||||
self.character[text_id - 1]
|
||||
for text_id in text_index[batch_idx][selection]
|
||||
]
|
||||
if text_prob is not None:
|
||||
conf_list = text_prob[batch_idx][selection]
|
||||
else:
|
||||
conf_list = [1] * len(selection)
|
||||
if len(conf_list) == 0:
|
||||
conf_list = [0]
|
||||
|
||||
text = ''.join(char_list)
|
||||
result_list.append((text, np.mean(conf_list).tolist()))
|
||||
return result_list
|
||||
|
||||
def __call__(self, preds, label=None, length=None, *args, **kwargs):
|
||||
if len(preds) == 2: # eval mode
|
||||
text_pre, x = preds
|
||||
b = text_pre.shape[1]
|
||||
lenText = self.max_text_length
|
||||
nsteps = self.max_text_length
|
||||
|
||||
if not isinstance(text_pre, paddle.Tensor):
|
||||
text_pre = paddle.to_tensor(text_pre, dtype='float32')
|
||||
|
||||
out_res = paddle.zeros(
|
||||
shape=[lenText, b, self.nclass], dtype=x.dtype)
|
||||
out_length = paddle.zeros(shape=[b], dtype=x.dtype)
|
||||
now_step = 0
|
||||
for _ in range(nsteps):
|
||||
if 0 in out_length and now_step < nsteps:
|
||||
tmp_result = text_pre[now_step, :, :]
|
||||
out_res[now_step] = tmp_result
|
||||
tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
|
||||
for j in range(b):
|
||||
if out_length[j] == 0 and tmp_result[j] == 0:
|
||||
out_length[j] = now_step + 1
|
||||
now_step += 1
|
||||
for j in range(0, b):
|
||||
if int(out_length[j]) == 0:
|
||||
out_length[j] = nsteps
|
||||
start = 0
|
||||
output = paddle.zeros(
|
||||
shape=[int(out_length.sum()), self.nclass], dtype=x.dtype)
|
||||
for i in range(0, b):
|
||||
cur_length = int(out_length[i])
|
||||
output[start:start + cur_length] = out_res[0:cur_length, i, :]
|
||||
start += cur_length
|
||||
net_out = output
|
||||
length = out_length
|
||||
|
||||
else: # train mode
|
||||
net_out = preds[0]
|
||||
length = length
|
||||
net_out = paddle.concat([t[:l] for t, l in zip(net_out, length)])
|
||||
text = []
|
||||
if not isinstance(net_out, paddle.Tensor):
|
||||
net_out = paddle.to_tensor(net_out, dtype='float32')
|
||||
net_out = F.softmax(net_out, axis=1)
|
||||
for i in range(0, length.shape[0]):
|
||||
preds_idx = net_out[int(length[:i].sum()):int(length[:i].sum(
|
||||
) + length[i])].topk(1)[1][:, 0].tolist()
|
||||
preds_text = ''.join([
|
||||
self.character[idx - 1]
|
||||
if idx > 0 and idx <= len(self.character) else ''
|
||||
for idx in preds_idx
|
||||
])
|
||||
preds_prob = net_out[int(length[:i].sum()):int(length[:i].sum(
|
||||
) + length[i])].topk(1)[0][:, 0]
|
||||
preds_prob = paddle.exp(
|
||||
paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
|
||||
text.append((preds_text, preds_prob))
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label)
|
||||
return text, label
|
||||
|
|
|
@ -73,7 +73,7 @@ def main():
|
|||
config['Architecture']["Head"]['out_channels'] = char_num
|
||||
|
||||
model = build_model(config['Architecture'])
|
||||
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
|
||||
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN"]
|
||||
extra_input = False
|
||||
if config['Architecture']['algorithm'] == 'Distillation':
|
||||
for key in config['Architecture']["Models"]:
|
||||
|
|
|
@ -97,6 +97,12 @@ def export_single_model(model,
|
|||
shape=[None, 1, 32, 100], dtype="float32"),
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
elif arch_config["algorithm"] == "VisionLAN":
|
||||
other_shape = [
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 3, 64, 256], dtype="float32"),
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
|
||||
input_spec = [
|
||||
paddle.static.InputSpec(
|
||||
|
|
|
@ -69,6 +69,12 @@ class TextRecognizer(object):
|
|||
"character_dict_path": args.rec_char_dict_path,
|
||||
"use_space_char": args.use_space_char
|
||||
}
|
||||
elif self.rec_algorithm == "VisionLAN":
|
||||
postprocess_params = {
|
||||
'name': 'VLLabelDecode',
|
||||
"character_dict_path": args.rec_char_dict_path,
|
||||
"use_space_char": args.use_space_char
|
||||
}
|
||||
elif self.rec_algorithm == 'ViTSTR':
|
||||
postprocess_params = {
|
||||
'name': 'ViTSTRLabelDecode',
|
||||
|
@ -157,6 +163,16 @@ class TextRecognizer(object):
|
|||
padding_im[:, :, 0:resized_w] = resized_image
|
||||
return padding_im
|
||||
|
||||
def resize_norm_img_vl(self, img, image_shape):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
img = img[:, :, ::-1] # bgr2rgb
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
return resized_image
|
||||
|
||||
def resize_norm_img_srn(self, img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
||||
|
@ -280,6 +296,7 @@ class TextRecognizer(object):
|
|||
img -= mean
|
||||
img *= stdinv
|
||||
return img
|
||||
|
||||
def resize_norm_img_svtr(self, img, image_shape):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
@ -359,6 +376,11 @@ class TextRecognizer(object):
|
|||
self.rec_image_shape)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
elif self.rec_algorithm == "VisionLAN":
|
||||
norm_img = self.resize_norm_img_vl(img_list[indices[ino]],
|
||||
self.rec_image_shape)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
elif self.rec_algorithm == 'SPIN':
|
||||
norm_img = self.resize_norm_img_spin(img_list[indices[ino]])
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
|
|
|
@ -131,7 +131,6 @@ def main():
|
|||
if config['Architecture']['algorithm'] == "SAR":
|
||||
valid_ratio = np.expand_dims(batch[-1], axis=0)
|
||||
img_metas = [paddle.to_tensor(valid_ratio)]
|
||||
|
||||
images = np.expand_dims(batch[0], axis=0)
|
||||
images = paddle.to_tensor(images)
|
||||
if config['Architecture']['algorithm'] == "SRN":
|
||||
|
|
|
@ -227,7 +227,9 @@ def train(config,
|
|||
model.train()
|
||||
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN"]
|
||||
extra_input_models = [
|
||||
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN"
|
||||
]
|
||||
extra_input = False
|
||||
if config['Architecture']['algorithm'] == 'Distillation':
|
||||
for key in config['Architecture']["Models"]:
|
||||
|
@ -269,7 +271,6 @@ def train(config,
|
|||
images = batch[0]
|
||||
if use_srn:
|
||||
model_average = True
|
||||
|
||||
# use amp
|
||||
if scaler:
|
||||
with paddle.amp.auto_cast(level='O2'):
|
||||
|
@ -310,6 +311,9 @@ def train(config,
|
|||
]: # for multi head loss
|
||||
post_result = post_process_class(
|
||||
preds['ctc'], batch[1]) # for CTC head out
|
||||
elif config['Loss']['name'] in ['VLLoss']:
|
||||
post_result = post_process_class(preds, batch[1],
|
||||
batch[-1])
|
||||
else:
|
||||
post_result = post_process_class(preds, batch[1])
|
||||
eval_class(post_result, batch)
|
||||
|
@ -612,7 +616,7 @@ def preprocess(is_train=False):
|
|||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
|
||||
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
|
||||
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN'
|
||||
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN'
|
||||
]
|
||||
|
||||
if use_xpu:
|
||||
|
@ -631,7 +635,7 @@ def preprocess(is_train=False):
|
|||
if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
|
||||
save_model_dir = config['Global']['save_model_dir']
|
||||
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
|
||||
log_writer = VDLLogger(save_model_dir)
|
||||
log_writer = VDLLogger(vdl_writer_path)
|
||||
loggers.append(log_writer)
|
||||
if ('use_wandb' in config['Global'] and
|
||||
config['Global']['use_wandb']) or 'wandb' in config:
|
||||
|
|
Loading…
Reference in New Issue