Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into dygraph
commit
9b8ffb6c21
|
@ -0,0 +1,107 @@
|
|||
# 高精度中文场景文本识别模型SVTR
|
||||
|
||||
## 1. 简介
|
||||
|
||||
PP-OCRv3是百度开源的超轻量级场景文本检测识别模型库,其中超轻量的场景中文识别模型SVTR_LCNet使用了SVTR算法结构。为了保证速度,SVTR_LCNet将SVTR模型的Local Blocks替换为LCNet,使用两层Global Blocks。在中文场景中,PP-OCRv3识别主要使用如下优化策略:
|
||||
- GTC:Attention指导CTC训练策略;
|
||||
- TextConAug:挖掘文字上下文信息的数据增广策略;
|
||||
- TextRotNet:自监督的预训练模型;
|
||||
- UDML:联合互学习策略;
|
||||
- UIM:无标注数据挖掘方案。
|
||||
|
||||
其中 *UIM:无标注数据挖掘方案* 使用了高精度的SVTR中文模型进行无标注文件的刷库,该模型在PP-OCRv3识别的数据集上训练,精度对比如下表。
|
||||
|
||||
|中文识别算法|模型|UIM|精度|
|
||||
| --- | --- | --- |--- |
|
||||
|PP-OCRv3|SVTR_LCNet| w/o |78.4%|
|
||||
|PP-OCRv3|SVTR_LCNet| w |79.4%|
|
||||
|SVTR|SVTR-Tiny|-|82.5%|
|
||||
|
||||
aistudio项目链接: [高精度中文场景文本识别模型SVTR](https://aistudio.baidu.com/aistudio/projectdetail/4263032)
|
||||
|
||||
## 2. SVTR中文模型使用
|
||||
|
||||
### 环境准备
|
||||
|
||||
|
||||
本任务基于Aistudio完成, 具体环境如下:
|
||||
|
||||
- 操作系统: Linux
|
||||
- PaddlePaddle: 2.3
|
||||
- PaddleOCR: dygraph
|
||||
|
||||
下载 PaddleOCR代码
|
||||
|
||||
```bash
|
||||
git clone -b dygraph https://github.com/PaddlePaddle/PaddleOCR
|
||||
```
|
||||
|
||||
安装依赖库
|
||||
|
||||
```bash
|
||||
pip install -r PaddleOCR/requirements.txt -i https://mirror.baidu.com/pypi/simple
|
||||
```
|
||||
|
||||
### 快速使用
|
||||
|
||||
获取SVTR中文模型文件,请扫码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁
|
||||
<div align="center">
|
||||
<img src="https://ai-studio-static-online.cdn.bcebos.com/dd721099bd50478f9d5fb13d8dd00fad69c22d6848244fd3a1d3980d7fefc63e" width = "150" height = "150" />
|
||||
</div>
|
||||
|
||||
```bash
|
||||
# 解压模型文件
|
||||
tar xf svtr_ch_high_accuracy.tar
|
||||
```
|
||||
|
||||
预测中文文本,以下图为例:
|
||||

|
||||
|
||||
预测命令:
|
||||
|
||||
```bash
|
||||
# CPU预测
|
||||
python tools/infer_rec.py -c configs/rec/rec_svtrnet_ch.yml -o Global.pretrained_model=./svtr_ch_high_accuracy/best_accuracy Global.infer_img=./doc/imgs_words/ch/word_1.jpg Global.use_gpu=False
|
||||
|
||||
# GPU预测
|
||||
#python tools/infer_rec.py -c configs/rec/rec_svtrnet_ch.yml -o Global.pretrained_model=./svtr_ch_high_accuracy/best_accuracy Global.infer_img=./doc/imgs_words/ch/word_1.jpg Global.use_gpu=True
|
||||
```
|
||||
|
||||
可以看到最后打印结果为
|
||||
- result: 韩国小馆 0.9853458404541016
|
||||
|
||||
0.9853458404541016为预测置信度。
|
||||
|
||||
### 推理模型导出与预测
|
||||
|
||||
inference 模型(paddle.jit.save保存的模型) 一般是模型训练,把模型结构和模型参数保存在文件中的固化模型,多用于预测部署场景。 训练过程中保存的模型是checkpoints模型,保存的只有模型的参数,多用于恢复训练等。 与checkpoints模型相比,inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合于实际系统集成。
|
||||
|
||||
运行识别模型转inference模型命令,如下:
|
||||
|
||||
```bash
|
||||
python tools/export_model.py -c configs/rec/rec_svtrnet_ch.yml -o Global.pretrained_model=./svtr_ch_high_accuracy/best_accuracy Global.save_inference_dir=./inference/svtr_ch
|
||||
```
|
||||
|
||||
转换成功后,在目录下有三个文件:
|
||||
```shell
|
||||
inference/svtr_ch/
|
||||
├── inference.pdiparams # 识别inference模型的参数文件
|
||||
├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
|
||||
└── inference.pdmodel # 识别inference模型的program文件
|
||||
```
|
||||
|
||||
inference模型预测,命令如下:
|
||||
|
||||
```bash
|
||||
# CPU预测
|
||||
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_1.jpg" --rec_algorithm='SVTR' --rec_model_dir=./inference/svtr_ch/ --rec_image_shape='3, 32, 320' --rec_char_dict_path=ppocr/utils/ppocr_keys_v1.txt --use_gpu=False
|
||||
|
||||
# GPU预测
|
||||
#python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_1.jpg" --rec_algorithm='SVTR' --rec_model_dir=./inference/svtr_ch/ --rec_image_shape='3, 32, 320' --rec_char_dict_path=ppocr/utils/ppocr_keys_v1.txt --use_gpu=True
|
||||
```
|
||||
|
||||
**注意**
|
||||
|
||||
- 使用SVTR算法时,需要指定--rec_algorithm='SVTR'
|
||||
- 如果使用自定义字典训练的模型,需要将--rec_char_dict_path=ppocr/utils/ppocr_keys_v1.txt修改为自定义的字典
|
||||
- --rec_image_shape='3, 32, 320' 该参数不能去掉
|
|
@ -9,7 +9,7 @@ Global:
|
|||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words_en/word_10.png
|
||||
|
@ -49,7 +49,7 @@ Architecture:
|
|||
|
||||
|
||||
Loss:
|
||||
name: NRTRLoss
|
||||
name: CELoss
|
||||
smoothing: True
|
||||
|
||||
PostProcess:
|
||||
|
@ -68,8 +68,8 @@ Train:
|
|||
img_mode: BGR
|
||||
channel_first: False
|
||||
- NRTRLabelEncode: # Class handling label
|
||||
- NRTRRecResizeImg:
|
||||
image_shape: [100, 32]
|
||||
- GrayRecResizeImg:
|
||||
image_shape: [100, 32] # W H
|
||||
resize_type: PIL # PIL or OpenCV
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
|
@ -82,14 +82,14 @@ Train:
|
|||
Eval:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/evaluation/
|
||||
data_dir: ./train_data/data_lmdb_release/validation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- NRTRLabelEncode: # Class handling label
|
||||
- NRTRRecResizeImg:
|
||||
image_shape: [100, 32]
|
||||
- GrayRecResizeImg:
|
||||
image_shape: [100, 32] # W H
|
||||
resize_type: PIL # PIL or OpenCV
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
|
@ -97,5 +97,5 @@ Eval:
|
|||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 256
|
||||
num_workers: 1
|
||||
num_workers: 4
|
||||
use_shared_memory: False
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
Global:
|
||||
use_gpu: True
|
||||
epoch_num: 10
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec/r45_abinet/
|
||||
save_epoch_step: 1
|
||||
# evaluation is run every 2000 iterations
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words_en/word_10.png
|
||||
# for data or label process
|
||||
character_dict_path:
|
||||
character_type: en
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_abinet.txt
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.99
|
||||
clip_norm: 20.0
|
||||
lr:
|
||||
name: Piecewise
|
||||
decay_epochs: [6]
|
||||
values: [0.0001, 0.00001]
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0.
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: ABINet
|
||||
in_channels: 3
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet45
|
||||
Head:
|
||||
name: ABINetHead
|
||||
use_lang: True
|
||||
iter_size: 3
|
||||
|
||||
|
||||
Loss:
|
||||
name: CELoss
|
||||
ignore_index: &ignore_index 100 # Must be greater than the number of character classes
|
||||
|
||||
PostProcess:
|
||||
name: ABINetLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
channel_first: False
|
||||
- ABINetRecAug:
|
||||
- ABINetLabelEncode: # Class handling label
|
||||
ignore_index: *ignore_index
|
||||
- ABINetRecResizeImg:
|
||||
image_shape: [3, 32, 128]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 96
|
||||
drop_last: True
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/validation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
channel_first: False
|
||||
- ABINetLabelEncode: # Class handling label
|
||||
ignore_index: *ignore_index
|
||||
- ABINetRecResizeImg:
|
||||
image_shape: [3, 32, 128]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 256
|
||||
num_workers: 4
|
||||
use_shared_memory: False
|
|
@ -26,7 +26,7 @@ Optimizer:
|
|||
name: AdamW
|
||||
beta1: 0.9
|
||||
beta2: 0.99
|
||||
epsilon: 0.00000008
|
||||
epsilon: 8.e-8
|
||||
weight_decay: 0.05
|
||||
no_weight_decay_name: norm pos_embed
|
||||
one_dim_param_no_weight_decay: true
|
||||
|
@ -77,14 +77,13 @@ Metric:
|
|||
Train:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
data_dir: ./train_data/data_lmdb_release/training
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- CTCLabelEncode: # Class handling label
|
||||
- RecResizeImg:
|
||||
character_dict_path:
|
||||
- SVTRRecResizeImg:
|
||||
image_shape: [3, 64, 256]
|
||||
padding: False
|
||||
- KeepKeys:
|
||||
|
@ -98,14 +97,13 @@ Train:
|
|||
Eval:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/validation/
|
||||
data_dir: ./train_data/data_lmdb_release/validation
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- CTCLabelEncode: # Class handling label
|
||||
- RecResizeImg:
|
||||
character_dict_path:
|
||||
- SVTRRecResizeImg:
|
||||
image_shape: [3, 64, 256]
|
||||
padding: False
|
||||
- KeepKeys:
|
||||
|
|
|
@ -0,0 +1,155 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 100
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec/svtr_ch_all/
|
||||
save_epoch_step: 10
|
||||
eval_batch_step:
|
||||
- 0
|
||||
- 2000
|
||||
cal_metric_during_train: true
|
||||
pretrained_model: null
|
||||
checkpoints: null
|
||||
save_inference_dir: null
|
||||
use_visualdl: false
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
|
||||
max_text_length: 25
|
||||
infer_mode: false
|
||||
use_space_char: true
|
||||
save_res_path: ./output/rec/predicts_svtr_tiny_ch_all.txt
|
||||
Optimizer:
|
||||
name: AdamW
|
||||
beta1: 0.9
|
||||
beta2: 0.99
|
||||
epsilon: 8.0e-08
|
||||
weight_decay: 0.05
|
||||
no_weight_decay_name: norm pos_embed
|
||||
one_dim_param_no_weight_decay: true
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.0005
|
||||
warmup_epoch: 2
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: SVTR
|
||||
Transform: null
|
||||
Backbone:
|
||||
name: SVTRNet
|
||||
img_size:
|
||||
- 32
|
||||
- 320
|
||||
out_char_num: 40
|
||||
out_channels: 96
|
||||
patch_merging: Conv
|
||||
embed_dim:
|
||||
- 64
|
||||
- 128
|
||||
- 256
|
||||
depth:
|
||||
- 3
|
||||
- 6
|
||||
- 3
|
||||
num_heads:
|
||||
- 2
|
||||
- 4
|
||||
- 8
|
||||
mixer:
|
||||
- Local
|
||||
- Local
|
||||
- Local
|
||||
- Local
|
||||
- Local
|
||||
- Local
|
||||
- Global
|
||||
- Global
|
||||
- Global
|
||||
- Global
|
||||
- Global
|
||||
- Global
|
||||
local_mixer:
|
||||
- - 7
|
||||
- 11
|
||||
- - 7
|
||||
- 11
|
||||
- - 7
|
||||
- 11
|
||||
last_stage: true
|
||||
prenorm: false
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: reshape
|
||||
Head:
|
||||
name: CTCHead
|
||||
Loss:
|
||||
name: CTCLoss
|
||||
PostProcess:
|
||||
name: CTCLabelDecode
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data
|
||||
label_file_list:
|
||||
- ./train_data/train_list.txt
|
||||
ext_op_transform_idx: 1
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- RecConAug:
|
||||
prob: 0.5
|
||||
ext_data_num: 2
|
||||
image_shape:
|
||||
- 32
|
||||
- 320
|
||||
- 3
|
||||
- RecAug: null
|
||||
- CTCLabelEncode: null
|
||||
- SVTRRecResizeImg:
|
||||
image_shape:
|
||||
- 3
|
||||
- 32
|
||||
- 320
|
||||
padding: true
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- label
|
||||
- length
|
||||
loader:
|
||||
shuffle: true
|
||||
batch_size_per_card: 256
|
||||
drop_last: true
|
||||
num_workers: 8
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data
|
||||
label_file_list:
|
||||
- ./train_data/val_list.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- CTCLabelEncode: null
|
||||
- SVTRRecResizeImg:
|
||||
image_shape:
|
||||
- 3
|
||||
- 32
|
||||
- 320
|
||||
padding: true
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- label
|
||||
- length
|
||||
loader:
|
||||
shuffle: false
|
||||
drop_last: false
|
||||
batch_size_per_card: 256
|
||||
num_workers: 2
|
||||
profiler_options: null
|
|
@ -0,0 +1,102 @@
|
|||
Global:
|
||||
use_gpu: True
|
||||
epoch_num: 20
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec/vitstr_none_ce/
|
||||
save_epoch_step: 1
|
||||
# evaluation is run every 2000 iterations after the 0th iteration#
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words_en/word_10.png
|
||||
# for data or label process
|
||||
character_dict_path: ppocr/utils/EN_symbol_dict.txt
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_vitstr.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adadelta
|
||||
epsilon: 1.e-8
|
||||
rho: 0.95
|
||||
clip_norm: 5.0
|
||||
lr:
|
||||
learning_rate: 1.0
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: ViTSTR
|
||||
in_channels: 1
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ViTSTR
|
||||
scale: tiny
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: reshape
|
||||
Head:
|
||||
name: CTCHead
|
||||
|
||||
Loss:
|
||||
name: CELoss
|
||||
with_all: True
|
||||
ignore_index: &ignore_index 0 # Must be zero or greater than the number of character classes
|
||||
|
||||
PostProcess:
|
||||
name: ViTSTRLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- ViTSTRLabelEncode: # Class handling label
|
||||
ignore_index: *ignore_index
|
||||
- GrayRecResizeImg:
|
||||
image_shape: [224, 224] # W H
|
||||
resize_type: PIL # PIL or OpenCV
|
||||
inter_type: 'Image.BICUBIC'
|
||||
scale: false
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 48
|
||||
drop_last: True
|
||||
num_workers: 8
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/validation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- ViTSTRLabelEncode: # Class handling label
|
||||
ignore_index: *ignore_index
|
||||
- GrayRecResizeImg:
|
||||
image_shape: [224, 224] # W H
|
||||
resize_type: PIL # PIL or OpenCV
|
||||
inter_type: 'Image.BICUBIC'
|
||||
scale: false
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 256
|
||||
num_workers: 2
|
|
@ -66,6 +66,8 @@
|
|||
- [x] [SAR](./algorithm_rec_sar.md)
|
||||
- [x] [SEED](./algorithm_rec_seed.md)
|
||||
- [x] [SVTR](./algorithm_rec_svtr.md)
|
||||
- [x] [ViTSTR](./algorithm_rec_vitstr.md)
|
||||
- [x] [ABINet](./algorithm_rec_abinet.md)
|
||||
|
||||
参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|
||||
|
||||
|
@ -84,7 +86,8 @@
|
|||
|SAR|Resnet31| 87.20% | rec_r31_sar | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|
||||
|SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
|
||||
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|
||||
|
||||
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce_en | [训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) |
|
||||
|ABINet|Resnet45| 90.75% | rec_r45_abinet_en | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|
||||
|
||||
<a name="2"></a>
|
||||
|
||||
|
|
|
@ -0,0 +1,155 @@
|
|||
# 场景文本识别算法-ABINet
|
||||
|
||||
- [1. 算法简介](#1)
|
||||
- [2. 环境配置](#2)
|
||||
- [3. 模型训练、评估、预测](#3)
|
||||
- [3.1 训练](#3-1)
|
||||
- [3.2 评估](#3-2)
|
||||
- [3.3 预测](#3-3)
|
||||
- [4. 推理部署](#4)
|
||||
- [4.1 Python推理](#4-1)
|
||||
- [4.2 C++推理](#4-2)
|
||||
- [4.3 Serving服务化部署](#4-3)
|
||||
- [4.4 更多推理部署](#4-4)
|
||||
- [5. FAQ](#5)
|
||||
|
||||
<a name="1"></a>
|
||||
## 1. 算法简介
|
||||
|
||||
论文信息:
|
||||
> [ABINet: Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition](https://openaccess.thecvf.com/content/CVPR2021/papers/Fang_Read_Like_Humans_Autonomous_Bidirectional_and_Iterative_Language_Modeling_for_CVPR_2021_paper.pdf)
|
||||
> Shancheng Fang and Hongtao Xie and Yuxin Wang and Zhendong Mao and Yongdong Zhang
|
||||
> CVPR, 2021
|
||||
|
||||
|
||||
<a name="model"></a>
|
||||
`ABINet`使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下:
|
||||
|
||||
|模型|骨干网络|配置文件|Acc|下载链接|
|
||||
| --- | --- | --- | --- | --- |
|
||||
|ABINet|ResNet45|[rec_r45_abinet.yml](../../configs/rec/rec_r45_abinet.yml)|90.75%|[预训练、训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar)|
|
||||
|
||||
<a name="2"></a>
|
||||
## 2. 环境配置
|
||||
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
|
||||
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. 模型训练、评估、预测
|
||||
|
||||
<a name="3-1"></a>
|
||||
### 3.1 模型训练
|
||||
|
||||
请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`ABINet`识别模型时需要**更换配置文件**为`ABINet`的[配置文件](../../configs/rec/rec_r45_abinet.yml)。
|
||||
|
||||
#### 启动训练
|
||||
|
||||
|
||||
具体地,在完成数据准备后,便可以启动训练,训练命令如下:
|
||||
```shell
|
||||
#单卡训练(训练周期长,不建议)
|
||||
python3 tools/train.py -c configs/rec/rec_r45_abinet.yml
|
||||
|
||||
#多卡训练,通过--gpus参数指定卡号
|
||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r45_abinet.yml
|
||||
```
|
||||
|
||||
<a name="3-2"></a>
|
||||
### 3.2 评估
|
||||
|
||||
可下载已训练完成的[模型文件](#model),使用如下命令进行评估:
|
||||
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。
|
||||
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r45_abinet.yml -o Global.pretrained_model=./rec_r45_abinet_train/best_accuracy
|
||||
```
|
||||
|
||||
<a name="3-3"></a>
|
||||
### 3.3 预测
|
||||
|
||||
使用如下命令进行单张图片预测:
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。
|
||||
python3 tools/infer_rec.py -c configs/rec/rec_r45_abinet.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_r45_abinet_train/best_accuracy
|
||||
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
|
||||
```
|
||||
|
||||
|
||||
<a name="4"></a>
|
||||
## 4. 推理部署
|
||||
|
||||
<a name="4-1"></a>
|
||||
### 4.1 Python推理
|
||||
首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) ),可以使用如下命令进行转换:
|
||||
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。
|
||||
python3 tools/export_model.py -c configs/rec/rec_r45_abinet.yml -o Global.pretrained_model=./rec_r45_abinet_train/best_accuracy Global.save_inference_dir=./inference/rec_r45_abinet/
|
||||
```
|
||||
**注意:**
|
||||
- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否是所需要的字典文件。
|
||||
- 如果您修改了训练时的输入大小,请修改`tools/export_model.py`文件中的对应ABINet的`infer_shape`。
|
||||
|
||||
转换成功后,在目录下有三个文件:
|
||||
```
|
||||
/inference/rec_r45_abinet/
|
||||
├── inference.pdiparams # 识别inference模型的参数文件
|
||||
├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
|
||||
└── inference.pdmodel # 识别inference模型的program文件
|
||||
```
|
||||
|
||||
执行如下命令进行模型推理:
|
||||
|
||||
```shell
|
||||
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_r45_abinet/' --rec_algorithm='ABINet' --rec_image_shape='3,32,128' --rec_char_dict_path='./ppocr/utils/ic15_dict.txt'
|
||||
# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/imgs_words_en/'。
|
||||
```
|
||||
|
||||

|
||||
|
||||
执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下:
|
||||
结果如下:
|
||||
```shell
|
||||
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9999995231628418)
|
||||
```
|
||||
|
||||
**注意**:
|
||||
|
||||
- 训练上述模型采用的图像分辨率是[3,32,128],需要通过参数`rec_image_shape`设置为您训练时的识别图像形状。
|
||||
- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。
|
||||
- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中ABINet的预处理为您的预处理方法。
|
||||
|
||||
|
||||
<a name="4-2"></a>
|
||||
### 4.2 C++推理部署
|
||||
|
||||
由于C++预处理后处理还未支持ABINet,所以暂未支持
|
||||
|
||||
<a name="4-3"></a>
|
||||
### 4.3 Serving服务化部署
|
||||
|
||||
暂不支持
|
||||
|
||||
<a name="4-4"></a>
|
||||
### 4.4 更多推理部署
|
||||
|
||||
暂不支持
|
||||
|
||||
<a name="5"></a>
|
||||
## 5. FAQ
|
||||
|
||||
1. MJSynth和SynthText两种数据集来自于[ABINet源repo](https://github.com/FangShancheng/ABINet) 。
|
||||
2. 我们使用ABINet作者提供的预训练模型进行finetune训练。
|
||||
|
||||
## 引用
|
||||
|
||||
```bibtex
|
||||
@article{Fang2021ABINet,
|
||||
title = {ABINet: Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition},
|
||||
author = {Shancheng Fang and Hongtao Xie and Yuxin Wang and Zhendong Mao and Yongdong Zhang},
|
||||
booktitle = {CVPR},
|
||||
year = {2021},
|
||||
url = {https://arxiv.org/abs/2103.06495},
|
||||
pages = {7098-7107}
|
||||
}
|
||||
```
|
|
@ -12,6 +12,7 @@
|
|||
- [4.3 Serving服务化部署](#4-3)
|
||||
- [4.4 更多推理部署](#4-4)
|
||||
- [5. FAQ](#5)
|
||||
- [6. 发行公告](#6)
|
||||
|
||||
<a name="1"></a>
|
||||
## 1. 算法简介
|
||||
|
@ -110,7 +111,7 @@ python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png'
|
|||
执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下:
|
||||
结果如下:
|
||||
```shell
|
||||
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9265879392623901)
|
||||
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9465042352676392)
|
||||
```
|
||||
|
||||
**注意**:
|
||||
|
@ -140,12 +141,147 @@ Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9265879392623901)
|
|||
|
||||
1. `NRTR`论文中使用Beam搜索进行解码字符,但是速度较慢,这里默认未使用Beam搜索,以贪婪搜索进行解码字符。
|
||||
|
||||
<a name="6"></a>
|
||||
## 6. 发行公告
|
||||
|
||||
1. release/2.6更新NRTR代码结构,新版NRTR可加载旧版(release/2.5及之前)模型参数,使用下面示例代码将旧版模型参数转换为新版模型参数:
|
||||
|
||||
```python
|
||||
|
||||
params = paddle.load('path/' + '.pdparams') # 旧版本参数
|
||||
state_dict = model.state_dict() # 新版模型参数
|
||||
new_state_dict = {}
|
||||
|
||||
for k1, v1 in state_dict.items():
|
||||
|
||||
k = k1
|
||||
if 'encoder' in k and 'self_attn' in k and 'qkv' in k and 'weight' in k:
|
||||
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
q = params[k_para.replace('qkv', 'conv1')].transpose((1, 0, 2, 3))
|
||||
k = params[k_para.replace('qkv', 'conv2')].transpose((1, 0, 2, 3))
|
||||
v = params[k_para.replace('qkv', 'conv3')].transpose((1, 0, 2, 3))
|
||||
|
||||
new_state_dict[k1] = np.concatenate([q[:, :, 0, 0], k[:, :, 0, 0], v[:, :, 0, 0]], -1)
|
||||
|
||||
elif 'encoder' in k and 'self_attn' in k and 'qkv' in k and 'bias' in k:
|
||||
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
q = params[k_para.replace('qkv', 'conv1')]
|
||||
k = params[k_para.replace('qkv', 'conv2')]
|
||||
v = params[k_para.replace('qkv', 'conv3')]
|
||||
|
||||
new_state_dict[k1] = np.concatenate([q, k, v], -1)
|
||||
|
||||
elif 'encoder' in k and 'self_attn' in k and 'out_proj' in k:
|
||||
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
new_state_dict[k1] = params[k_para]
|
||||
|
||||
elif 'encoder' in k and 'norm3' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
new_state_dict[k1] = params[k_para.replace('norm3', 'norm2')]
|
||||
|
||||
elif 'encoder' in k and 'norm1' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
new_state_dict[k1] = params[k_para]
|
||||
|
||||
|
||||
elif 'decoder' in k and 'self_attn' in k and 'qkv' in k and 'weight' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
q = params[k_para.replace('qkv', 'conv1')].transpose((1, 0, 2, 3))
|
||||
k = params[k_para.replace('qkv', 'conv2')].transpose((1, 0, 2, 3))
|
||||
v = params[k_para.replace('qkv', 'conv3')].transpose((1, 0, 2, 3))
|
||||
new_state_dict[k1] = np.concatenate([q[:, :, 0, 0], k[:, :, 0, 0], v[:, :, 0, 0]], -1)
|
||||
|
||||
elif 'decoder' in k and 'self_attn' in k and 'qkv' in k and 'bias' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
q = params[k_para.replace('qkv', 'conv1')]
|
||||
k = params[k_para.replace('qkv', 'conv2')]
|
||||
v = params[k_para.replace('qkv', 'conv3')]
|
||||
new_state_dict[k1] = np.concatenate([q, k, v], -1)
|
||||
|
||||
elif 'decoder' in k and 'self_attn' in k and 'out_proj' in k:
|
||||
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
new_state_dict[k1] = params[k_para]
|
||||
|
||||
elif 'decoder' in k and 'cross_attn' in k and 'q' in k and 'weight' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
k_para = k_para.replace('cross_attn', 'multihead_attn')
|
||||
q = params[k_para.replace('q', 'conv1')].transpose((1, 0, 2, 3))
|
||||
new_state_dict[k1] = q[:, :, 0, 0]
|
||||
|
||||
elif 'decoder' in k and 'cross_attn' in k and 'q' in k and 'bias' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
k_para = k_para.replace('cross_attn', 'multihead_attn')
|
||||
q = params[k_para.replace('q', 'conv1')]
|
||||
new_state_dict[k1] = q
|
||||
|
||||
elif 'decoder' in k and 'cross_attn' in k and 'kv' in k and 'weight' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
k_para = k_para.replace('cross_attn', 'multihead_attn')
|
||||
k = params[k_para.replace('kv', 'conv2')].transpose((1, 0, 2, 3))
|
||||
v = params[k_para.replace('kv', 'conv3')].transpose((1, 0, 2, 3))
|
||||
new_state_dict[k1] = np.concatenate([k[:, :, 0, 0], v[:, :, 0, 0]], -1)
|
||||
|
||||
elif 'decoder' in k and 'cross_attn' in k and 'kv' in k and 'bias' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
k_para = k_para.replace('cross_attn', 'multihead_attn')
|
||||
k = params[k_para.replace('kv', 'conv2')]
|
||||
v = params[k_para.replace('kv', 'conv3')]
|
||||
new_state_dict[k1] = np.concatenate([k, v], -1)
|
||||
|
||||
elif 'decoder' in k and 'cross_attn' in k and 'out_proj' in k:
|
||||
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
k_para = k_para.replace('cross_attn', 'multihead_attn')
|
||||
new_state_dict[k1] = params[k_para]
|
||||
elif 'decoder' in k and 'norm' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
new_state_dict[k1] = params[k_para]
|
||||
elif 'mlp' in k and 'weight' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
k_para = k_para.replace('fc', 'conv')
|
||||
k_para = k_para.replace('mlp.', '')
|
||||
w = params[k_para].transpose((1, 0, 2, 3))
|
||||
new_state_dict[k1] = w[:, :, 0, 0]
|
||||
elif 'mlp' in k and 'bias' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
k_para = k_para.replace('fc', 'conv')
|
||||
k_para = k_para.replace('mlp.', '')
|
||||
w = params[k_para]
|
||||
new_state_dict[k1] = w
|
||||
|
||||
else:
|
||||
new_state_dict[k1] = params[k1]
|
||||
|
||||
if list(new_state_dict[k1].shape) != list(v1.shape):
|
||||
print(k1)
|
||||
|
||||
|
||||
for k, v1 in state_dict.items():
|
||||
if k not in new_state_dict.keys():
|
||||
print(1, k)
|
||||
elif list(new_state_dict[k].shape) != list(v1.shape):
|
||||
print(2, k)
|
||||
|
||||
|
||||
|
||||
model.set_state_dict(new_state_dict)
|
||||
paddle.save(model.state_dict(), 'nrtrnew_from_old_params.pdparams')
|
||||
|
||||
```
|
||||
|
||||
2. 新版相比与旧版,代码结构简洁,推理速度有所提高。
|
||||
|
||||
|
||||
## 引用
|
||||
|
||||
```bibtex
|
||||
@article{Sheng2019NRTR,
|
||||
title = {NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition},
|
||||
author = {Fenfen Sheng and Zhineng Chen andBo Xu},
|
||||
author = {Fenfen Sheng and Zhineng Chen and Bo Xu},
|
||||
booktitle = {ICDAR},
|
||||
year = {2019},
|
||||
url = {http://arxiv.org/abs/1806.00926},
|
||||
|
|
|
@ -111,7 +111,6 @@ python3 tools/export_model.py -c ./rec_svtr_tiny_none_ctc_en_train/rec_svtr_tiny
|
|||
|
||||
**注意:**
|
||||
- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否为所正确的字典文件。
|
||||
- 如果您修改了训练时的输入大小,请修改`tools/export_model.py`文件中的对应SVTR的`infer_shape`。
|
||||
|
||||
转换成功后,在目录下有三个文件:
|
||||
```
|
||||
|
|
|
@ -0,0 +1,154 @@
|
|||
# 场景文本识别算法-ViTSTR
|
||||
|
||||
- [1. 算法简介](#1)
|
||||
- [2. 环境配置](#2)
|
||||
- [3. 模型训练、评估、预测](#3)
|
||||
- [3.1 训练](#3-1)
|
||||
- [3.2 评估](#3-2)
|
||||
- [3.3 预测](#3-3)
|
||||
- [4. 推理部署](#4)
|
||||
- [4.1 Python推理](#4-1)
|
||||
- [4.2 C++推理](#4-2)
|
||||
- [4.3 Serving服务化部署](#4-3)
|
||||
- [4.4 更多推理部署](#4-4)
|
||||
- [5. FAQ](#5)
|
||||
|
||||
<a name="1"></a>
|
||||
## 1. 算法简介
|
||||
|
||||
论文信息:
|
||||
> [Vision Transformer for Fast and Efficient Scene Text Recognition](https://arxiv.org/abs/2105.08582)
|
||||
> Rowel Atienza
|
||||
> ICDAR, 2021
|
||||
|
||||
|
||||
<a name="model"></a>
|
||||
`ViTSTR`使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下:
|
||||
|
||||
|模型|骨干网络|配置文件|Acc|下载链接|
|
||||
| --- | --- | --- | --- | --- |
|
||||
|ViTSTR|ViTSTR|[rec_vitstr_none_ce.yml](../../configs/rec/rec_vitstr_none_ce.yml)|79.82%|[训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar)|
|
||||
|
||||
<a name="2"></a>
|
||||
## 2. 环境配置
|
||||
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
|
||||
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. 模型训练、评估、预测
|
||||
|
||||
<a name="3-1"></a>
|
||||
### 3.1 模型训练
|
||||
|
||||
请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`ViTSTR`识别模型时需要**更换配置文件**为`ViTSTR`的[配置文件](../../configs/rec/rec_vitstr_none_ce.yml)。
|
||||
|
||||
#### 启动训练
|
||||
|
||||
|
||||
具体地,在完成数据准备后,便可以启动训练,训练命令如下:
|
||||
```shell
|
||||
#单卡训练(训练周期长,不建议)
|
||||
python3 tools/train.py -c configs/rec/rec_vitstr_none_ce.yml
|
||||
|
||||
#多卡训练,通过--gpus参数指定卡号
|
||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_vitstr_none_ce.yml
|
||||
```
|
||||
|
||||
<a name="3-2"></a>
|
||||
### 3.2 评估
|
||||
|
||||
可下载已训练完成的[模型文件](#model),使用如下命令进行评估:
|
||||
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。
|
||||
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_vitstr_none_ce.yml -o Global.pretrained_model=./rec_vitstr_none_ce_train/best_accuracy
|
||||
```
|
||||
|
||||
<a name="3-3"></a>
|
||||
### 3.3 预测
|
||||
|
||||
使用如下命令进行单张图片预测:
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。
|
||||
python3 tools/infer_rec.py -c configs/rec/rec_vitstr_none_ce.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_vitstr_none_ce_train/best_accuracy
|
||||
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
|
||||
```
|
||||
|
||||
|
||||
<a name="4"></a>
|
||||
## 4. 推理部署
|
||||
|
||||
<a name="4-1"></a>
|
||||
### 4.1 Python推理
|
||||
首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) ),可以使用如下命令进行转换:
|
||||
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。
|
||||
python3 tools/export_model.py -c configs/rec/rec_vitstr_none_ce.yml -o Global.pretrained_model=./rec_vitstr_none_ce_train/best_accuracy Global.save_inference_dir=./inference/rec_vitstr/
|
||||
```
|
||||
**注意:**
|
||||
- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否是所需要的字典文件。
|
||||
- 如果您修改了训练时的输入大小,请修改`tools/export_model.py`文件中的对应ViTSTR的`infer_shape`。
|
||||
|
||||
转换成功后,在目录下有三个文件:
|
||||
```
|
||||
/inference/rec_vitstr/
|
||||
├── inference.pdiparams # 识别inference模型的参数文件
|
||||
├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
|
||||
└── inference.pdmodel # 识别inference模型的program文件
|
||||
```
|
||||
|
||||
执行如下命令进行模型推理:
|
||||
|
||||
```shell
|
||||
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_vitstr/' --rec_algorithm='ViTSTR' --rec_image_shape='1,224,224' --rec_char_dict_path='./ppocr/utils/EN_symbol_dict.txt'
|
||||
# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/imgs_words_en/'。
|
||||
```
|
||||
|
||||

|
||||
|
||||
执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下:
|
||||
结果如下:
|
||||
```shell
|
||||
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9998350143432617)
|
||||
```
|
||||
|
||||
**注意**:
|
||||
|
||||
- 训练上述模型采用的图像分辨率是[1,224,224],需要通过参数`rec_image_shape`设置为您训练时的识别图像形状。
|
||||
- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。
|
||||
- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中ViTSTR的预处理为您的预处理方法。
|
||||
|
||||
|
||||
<a name="4-2"></a>
|
||||
### 4.2 C++推理部署
|
||||
|
||||
由于C++预处理后处理还未支持ViTSTR,所以暂未支持
|
||||
|
||||
<a name="4-3"></a>
|
||||
### 4.3 Serving服务化部署
|
||||
|
||||
暂不支持
|
||||
|
||||
<a name="4-4"></a>
|
||||
### 4.4 更多推理部署
|
||||
|
||||
暂不支持
|
||||
|
||||
<a name="5"></a>
|
||||
## 5. FAQ
|
||||
|
||||
1. 在`ViTSTR`论文中,使用在ImageNet1k上的预训练权重进行初始化训练,我们在训练未采用预训练权重,最终精度没有变化甚至有所提高。
|
||||
2. 我们仅仅复现了`ViTSTR`中的tiny版本,如果需要使用small、base版本,可将[ViTSTR源repo](https://github.com/roatienza/deep-text-recognition-benchmark) 中的预训练权重转为Paddle权重使用。
|
||||
|
||||
## 引用
|
||||
|
||||
```bibtex
|
||||
@article{Atienza2021ViTSTR,
|
||||
title = {Vision Transformer for Fast and Efficient Scene Text Recognition},
|
||||
author = {Rowel Atienza},
|
||||
booktitle = {ICDAR},
|
||||
year = {2021},
|
||||
url = {https://arxiv.org/abs/2105.08582}
|
||||
}
|
||||
```
|
|
@ -65,6 +65,8 @@ Supported text recognition algorithms (Click the link to get the tutorial):
|
|||
- [x] [SAR](./algorithm_rec_sar_en.md)
|
||||
- [x] [SEED](./algorithm_rec_seed_en.md)
|
||||
- [x] [SVTR](./algorithm_rec_svtr_en.md)
|
||||
- [x] [ViTSTR](./algorithm_rec_vitstr_en.md)
|
||||
- [x] [ABINet](./algorithm_rec_abinet_en.md)
|
||||
|
||||
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
|
||||
|
||||
|
@ -83,7 +85,8 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|
|||
|SAR|Resnet31| 87.20% | rec_r31_sar | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|
||||
|SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
|
||||
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|
||||
|
||||
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce_en | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
|
||||
|ABINet|Resnet45| 90.75% | rec_r45_abinet_en | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|
||||
|
||||
<a name="2"></a>
|
||||
|
||||
|
|
|
@ -0,0 +1,136 @@
|
|||
# ABINet
|
||||
|
||||
- [1. Introduction](#1)
|
||||
- [2. Environment](#2)
|
||||
- [3. Model Training / Evaluation / Prediction](#3)
|
||||
- [3.1 Training](#3-1)
|
||||
- [3.2 Evaluation](#3-2)
|
||||
- [3.3 Prediction](#3-3)
|
||||
- [4. Inference and Deployment](#4)
|
||||
- [4.1 Python Inference](#4-1)
|
||||
- [4.2 C++ Inference](#4-2)
|
||||
- [4.3 Serving](#4-3)
|
||||
- [4.4 More](#4-4)
|
||||
- [5. FAQ](#5)
|
||||
|
||||
<a name="1"></a>
|
||||
## 1. Introduction
|
||||
|
||||
Paper:
|
||||
> [ABINet: Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition](https://openaccess.thecvf.com/content/CVPR2021/papers/Fang_Read_Like_Humans_Autonomous_Bidirectional_and_Iterative_Language_Modeling_for_CVPR_2021_paper.pdf)
|
||||
> Shancheng Fang and Hongtao Xie and Yuxin Wang and Zhendong Mao and Yongdong Zhang
|
||||
> CVPR, 2021
|
||||
|
||||
Using MJSynth and SynthText two text recognition datasets for training, and evaluating on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE datasets, the algorithm reproduction effect is as follows:
|
||||
|
||||
|Model|Backbone|config|Acc|Download link|
|
||||
| --- | --- | --- | --- | --- |
|
||||
|ABINet|ResNet45|[rec_r45_abinet.yml](../../configs/rec/rec_r45_abinet.yml)|90.75%|[pretrained & trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar)|
|
||||
|
||||
<a name="2"></a>
|
||||
## 2. Environment
|
||||
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
|
||||
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. Model Training / Evaluation / Prediction
|
||||
|
||||
Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
|
||||
|
||||
Training:
|
||||
|
||||
Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
|
||||
|
||||
```
|
||||
#Single GPU training (long training period, not recommended)
|
||||
python3 tools/train.py -c configs/rec/rec_r45_abinet.yml
|
||||
|
||||
#Multi GPU training, specify the gpu number through the --gpus parameter
|
||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r45_abinet.yml
|
||||
```
|
||||
|
||||
Evaluation:
|
||||
|
||||
```
|
||||
# GPU evaluation
|
||||
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r45_abinet.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
|
||||
```
|
||||
|
||||
Prediction:
|
||||
|
||||
```
|
||||
# The configuration file used for prediction must match the training
|
||||
python3 tools/infer_rec.py -c configs/rec/rec_r45_abinet.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_r45_abinet_train/best_accuracy
|
||||
```
|
||||
|
||||
<a name="4"></a>
|
||||
## 4. Inference and Deployment
|
||||
|
||||
<a name="4-1"></a>
|
||||
### 4.1 Python Inference
|
||||
First, the model saved during the ABINet text recognition training process is converted into an inference model. ( [Model download link](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar)) ), you can use the following command to convert:
|
||||
|
||||
```
|
||||
python3 tools/export_model.py -c configs/rec/rec_r45_abinet.yml -o Global.pretrained_model=./rec_r45_abinet_train/best_accuracy Global.save_inference_dir=./inference/rec_r45_abinet
|
||||
```
|
||||
|
||||
**Note:**
|
||||
- If you are training the model on your own dataset and have modified the dictionary file, please pay attention to modify the `character_dict_path` in the configuration file to the modified dictionary file.
|
||||
- If you modified the input size during training, please modify the `infer_shape` corresponding to ABINet in the `tools/export_model.py` file.
|
||||
|
||||
After the conversion is successful, there are three files in the directory:
|
||||
```
|
||||
/inference/rec_r45_abinet/
|
||||
├── inference.pdiparams
|
||||
├── inference.pdiparams.info
|
||||
└── inference.pdmodel
|
||||
```
|
||||
|
||||
|
||||
For ABINet text recognition model inference, the following commands can be executed:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_r45_abinet/' --rec_algorithm='ABINet' --rec_image_shape='3,32,128' --rec_char_dict_path='./ppocr/utils/ic15_dict.txt'
|
||||
```
|
||||
|
||||

|
||||
|
||||
After executing the command, the prediction result (recognized text and score) of the image above is printed to the screen, an example is as follows:
|
||||
The result is as follows:
|
||||
```shell
|
||||
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9999995231628418)
|
||||
```
|
||||
|
||||
<a name="4-2"></a>
|
||||
### 4.2 C++ Inference
|
||||
|
||||
Not supported
|
||||
|
||||
<a name="4-3"></a>
|
||||
### 4.3 Serving
|
||||
|
||||
Not supported
|
||||
|
||||
<a name="4-4"></a>
|
||||
### 4.4 More
|
||||
|
||||
Not supported
|
||||
|
||||
<a name="5"></a>
|
||||
## 5. FAQ
|
||||
|
||||
1. Note that the MJSynth and SynthText datasets come from [ABINet repo](https://github.com/FangShancheng/ABINet).
|
||||
2. We use the pre-trained model provided by the ABINet authors for finetune training.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{Fang2021ABINet,
|
||||
title = {ABINet: Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition},
|
||||
author = {Shancheng Fang and Hongtao Xie and Yuxin Wang and Zhendong Mao and Yongdong Zhang},
|
||||
booktitle = {CVPR},
|
||||
year = {2021},
|
||||
url = {https://arxiv.org/abs/2103.06495},
|
||||
pages = {7098-7107}
|
||||
}
|
||||
```
|
|
@ -12,6 +12,7 @@
|
|||
- [4.3 Serving](#4-3)
|
||||
- [4.4 More](#4-4)
|
||||
- [5. FAQ](#5)
|
||||
- [6. Release Note](#6)
|
||||
|
||||
<a name="1"></a>
|
||||
## 1. Introduction
|
||||
|
@ -25,7 +26,7 @@ Using MJSynth and SynthText two text recognition datasets for training, and eval
|
|||
|
||||
|Model|Backbone|config|Acc|Download link|
|
||||
| --- | --- | --- | --- | --- |
|
||||
|NRTR|MTB|[rec_mtb_nrtr.yml](../../configs/rec/rec_mtb_nrtr.yml)|84.21%|[train model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar)|
|
||||
|NRTR|MTB|[rec_mtb_nrtr.yml](../../configs/rec/rec_mtb_nrtr.yml)|84.21%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar)|
|
||||
|
||||
<a name="2"></a>
|
||||
## 2. Environment
|
||||
|
@ -98,7 +99,7 @@ python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png'
|
|||
After executing the command, the prediction result (recognized text and score) of the image above is printed to the screen, an example is as follows:
|
||||
The result is as follows:
|
||||
```shell
|
||||
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9265879392623901)
|
||||
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9465042352676392)
|
||||
```
|
||||
|
||||
<a name="4-2"></a>
|
||||
|
@ -121,12 +122,146 @@ Not supported
|
|||
|
||||
1. In the `NRTR` paper, Beam search is used to decode characters, but the speed is slow. Beam search is not used by default here, and greedy search is used to decode characters.
|
||||
|
||||
<a name="6"></a>
|
||||
## 6. Release Note
|
||||
|
||||
1. The release/2.6 version updates the NRTR code structure. The new version of NRTR can load the model parameters of the old version (release/2.5 and before), and you may use the following code to convert the old version model parameters to the new version model parameters:
|
||||
|
||||
```python
|
||||
|
||||
params = paddle.load('path/' + '.pdparams') # the old version parameters
|
||||
state_dict = model.state_dict() # the new version model parameters
|
||||
new_state_dict = {}
|
||||
|
||||
for k1, v1 in state_dict.items():
|
||||
|
||||
k = k1
|
||||
if 'encoder' in k and 'self_attn' in k and 'qkv' in k and 'weight' in k:
|
||||
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
q = params[k_para.replace('qkv', 'conv1')].transpose((1, 0, 2, 3))
|
||||
k = params[k_para.replace('qkv', 'conv2')].transpose((1, 0, 2, 3))
|
||||
v = params[k_para.replace('qkv', 'conv3')].transpose((1, 0, 2, 3))
|
||||
|
||||
new_state_dict[k1] = np.concatenate([q[:, :, 0, 0], k[:, :, 0, 0], v[:, :, 0, 0]], -1)
|
||||
|
||||
elif 'encoder' in k and 'self_attn' in k and 'qkv' in k and 'bias' in k:
|
||||
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
q = params[k_para.replace('qkv', 'conv1')]
|
||||
k = params[k_para.replace('qkv', 'conv2')]
|
||||
v = params[k_para.replace('qkv', 'conv3')]
|
||||
|
||||
new_state_dict[k1] = np.concatenate([q, k, v], -1)
|
||||
|
||||
elif 'encoder' in k and 'self_attn' in k and 'out_proj' in k:
|
||||
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
new_state_dict[k1] = params[k_para]
|
||||
|
||||
elif 'encoder' in k and 'norm3' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
new_state_dict[k1] = params[k_para.replace('norm3', 'norm2')]
|
||||
|
||||
elif 'encoder' in k and 'norm1' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
new_state_dict[k1] = params[k_para]
|
||||
|
||||
|
||||
elif 'decoder' in k and 'self_attn' in k and 'qkv' in k and 'weight' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
q = params[k_para.replace('qkv', 'conv1')].transpose((1, 0, 2, 3))
|
||||
k = params[k_para.replace('qkv', 'conv2')].transpose((1, 0, 2, 3))
|
||||
v = params[k_para.replace('qkv', 'conv3')].transpose((1, 0, 2, 3))
|
||||
new_state_dict[k1] = np.concatenate([q[:, :, 0, 0], k[:, :, 0, 0], v[:, :, 0, 0]], -1)
|
||||
|
||||
elif 'decoder' in k and 'self_attn' in k and 'qkv' in k and 'bias' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
q = params[k_para.replace('qkv', 'conv1')]
|
||||
k = params[k_para.replace('qkv', 'conv2')]
|
||||
v = params[k_para.replace('qkv', 'conv3')]
|
||||
new_state_dict[k1] = np.concatenate([q, k, v], -1)
|
||||
|
||||
elif 'decoder' in k and 'self_attn' in k and 'out_proj' in k:
|
||||
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
new_state_dict[k1] = params[k_para]
|
||||
|
||||
elif 'decoder' in k and 'cross_attn' in k and 'q' in k and 'weight' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
k_para = k_para.replace('cross_attn', 'multihead_attn')
|
||||
q = params[k_para.replace('q', 'conv1')].transpose((1, 0, 2, 3))
|
||||
new_state_dict[k1] = q[:, :, 0, 0]
|
||||
|
||||
elif 'decoder' in k and 'cross_attn' in k and 'q' in k and 'bias' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
k_para = k_para.replace('cross_attn', 'multihead_attn')
|
||||
q = params[k_para.replace('q', 'conv1')]
|
||||
new_state_dict[k1] = q
|
||||
|
||||
elif 'decoder' in k and 'cross_attn' in k and 'kv' in k and 'weight' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
k_para = k_para.replace('cross_attn', 'multihead_attn')
|
||||
k = params[k_para.replace('kv', 'conv2')].transpose((1, 0, 2, 3))
|
||||
v = params[k_para.replace('kv', 'conv3')].transpose((1, 0, 2, 3))
|
||||
new_state_dict[k1] = np.concatenate([k[:, :, 0, 0], v[:, :, 0, 0]], -1)
|
||||
|
||||
elif 'decoder' in k and 'cross_attn' in k and 'kv' in k and 'bias' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
k_para = k_para.replace('cross_attn', 'multihead_attn')
|
||||
k = params[k_para.replace('kv', 'conv2')]
|
||||
v = params[k_para.replace('kv', 'conv3')]
|
||||
new_state_dict[k1] = np.concatenate([k, v], -1)
|
||||
|
||||
elif 'decoder' in k and 'cross_attn' in k and 'out_proj' in k:
|
||||
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
k_para = k_para.replace('cross_attn', 'multihead_attn')
|
||||
new_state_dict[k1] = params[k_para]
|
||||
elif 'decoder' in k and 'norm' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
new_state_dict[k1] = params[k_para]
|
||||
elif 'mlp' in k and 'weight' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
k_para = k_para.replace('fc', 'conv')
|
||||
k_para = k_para.replace('mlp.', '')
|
||||
w = params[k_para].transpose((1, 0, 2, 3))
|
||||
new_state_dict[k1] = w[:, :, 0, 0]
|
||||
elif 'mlp' in k and 'bias' in k:
|
||||
k_para = k[:13] + 'layers.' + k[13:]
|
||||
k_para = k_para.replace('fc', 'conv')
|
||||
k_para = k_para.replace('mlp.', '')
|
||||
w = params[k_para]
|
||||
new_state_dict[k1] = w
|
||||
|
||||
else:
|
||||
new_state_dict[k1] = params[k1]
|
||||
|
||||
if list(new_state_dict[k1].shape) != list(v1.shape):
|
||||
print(k1)
|
||||
|
||||
|
||||
for k, v1 in state_dict.items():
|
||||
if k not in new_state_dict.keys():
|
||||
print(1, k)
|
||||
elif list(new_state_dict[k].shape) != list(v1.shape):
|
||||
print(2, k)
|
||||
|
||||
|
||||
|
||||
model.set_state_dict(new_state_dict)
|
||||
paddle.save(model.state_dict(), 'nrtrnew_from_old_params.pdparams')
|
||||
|
||||
```
|
||||
|
||||
2. The new version has a clean code structure and improved inference speed compared with the old version.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{Sheng2019NRTR,
|
||||
title = {NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition},
|
||||
author = {Fenfen Sheng and Zhineng Chen andBo Xu},
|
||||
author = {Fenfen Sheng and Zhineng Chen and Bo Xu},
|
||||
booktitle = {ICDAR},
|
||||
year = {2019},
|
||||
url = {http://arxiv.org/abs/1806.00926},
|
||||
|
|
|
@ -88,7 +88,6 @@ python3 tools/export_model.py -c configs/rec/rec_svtrnet.yml -o Global.pretraine
|
|||
|
||||
**Note:**
|
||||
- If you are training the model on your own dataset and have modified the dictionary file, please pay attention to modify the `character_dict_path` in the configuration file to the modified dictionary file.
|
||||
- If you modified the input size during training, please modify the `infer_shape` corresponding to SVTR in the `tools/export_model.py` file.
|
||||
|
||||
After the conversion is successful, there are three files in the directory:
|
||||
```
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
# ViTSTR
|
||||
|
||||
- [1. Introduction](#1)
|
||||
- [2. Environment](#2)
|
||||
- [3. Model Training / Evaluation / Prediction](#3)
|
||||
- [3.1 Training](#3-1)
|
||||
- [3.2 Evaluation](#3-2)
|
||||
- [3.3 Prediction](#3-3)
|
||||
- [4. Inference and Deployment](#4)
|
||||
- [4.1 Python Inference](#4-1)
|
||||
- [4.2 C++ Inference](#4-2)
|
||||
- [4.3 Serving](#4-3)
|
||||
- [4.4 More](#4-4)
|
||||
- [5. FAQ](#5)
|
||||
|
||||
<a name="1"></a>
|
||||
## 1. Introduction
|
||||
|
||||
Paper:
|
||||
> [Vision Transformer for Fast and Efficient Scene Text Recognition](https://arxiv.org/abs/2105.08582)
|
||||
> Rowel Atienza
|
||||
> ICDAR, 2021
|
||||
|
||||
Using MJSynth and SynthText two text recognition datasets for training, and evaluating on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE datasets, the algorithm reproduction effect is as follows:
|
||||
|
||||
|Model|Backbone|config|Acc|Download link|
|
||||
| --- | --- | --- | --- | --- |
|
||||
|ViTSTR|ViTSTR|[rec_vitstr_none_ce.yml](../../configs/rec/rec_vitstr_none_ce.yml)|79.82%|[trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar)|
|
||||
|
||||
<a name="2"></a>
|
||||
## 2. Environment
|
||||
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
|
||||
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. Model Training / Evaluation / Prediction
|
||||
|
||||
Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
|
||||
|
||||
Training:
|
||||
|
||||
Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
|
||||
|
||||
```
|
||||
#Single GPU training (long training period, not recommended)
|
||||
python3 tools/train.py -c configs/rec/rec_vitstr_none_ce.yml
|
||||
|
||||
#Multi GPU training, specify the gpu number through the --gpus parameter
|
||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_vitstr_none_ce.yml
|
||||
```
|
||||
|
||||
Evaluation:
|
||||
|
||||
```
|
||||
# GPU evaluation
|
||||
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_vitstr_none_ce.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
|
||||
```
|
||||
|
||||
Prediction:
|
||||
|
||||
```
|
||||
# The configuration file used for prediction must match the training
|
||||
python3 tools/infer_rec.py -c configs/rec/rec_vitstr_none_ce.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_vitstr_none_ce_train/best_accuracy
|
||||
```
|
||||
|
||||
<a name="4"></a>
|
||||
## 4. Inference and Deployment
|
||||
|
||||
<a name="4-1"></a>
|
||||
### 4.1 Python Inference
|
||||
First, the model saved during the ViTSTR text recognition training process is converted into an inference model. ( [Model download link](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar)) ), you can use the following command to convert:
|
||||
|
||||
```
|
||||
python3 tools/export_model.py -c configs/rec/rec_vitstr_none_ce.yml -o Global.pretrained_model=./rec_vitstr_none_ce_train/best_accuracy Global.save_inference_dir=./inference/rec_vitstr
|
||||
```
|
||||
|
||||
**Note:**
|
||||
- If you are training the model on your own dataset and have modified the dictionary file, please pay attention to modify the `character_dict_path` in the configuration file to the modified dictionary file.
|
||||
- If you modified the input size during training, please modify the `infer_shape` corresponding to ViTSTR in the `tools/export_model.py` file.
|
||||
|
||||
After the conversion is successful, there are three files in the directory:
|
||||
```
|
||||
/inference/rec_vitstr/
|
||||
├── inference.pdiparams
|
||||
├── inference.pdiparams.info
|
||||
└── inference.pdmodel
|
||||
```
|
||||
|
||||
|
||||
For ViTSTR text recognition model inference, the following commands can be executed:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_vitstr/' --rec_algorithm='ViTSTR' --rec_image_shape='1,224,224' --rec_char_dict_path='./ppocr/utils/EN_symbol_dict.txt'
|
||||
```
|
||||
|
||||

|
||||
|
||||
After executing the command, the prediction result (recognized text and score) of the image above is printed to the screen, an example is as follows:
|
||||
The result is as follows:
|
||||
```shell
|
||||
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9998350143432617)
|
||||
```
|
||||
|
||||
<a name="4-2"></a>
|
||||
### 4.2 C++ Inference
|
||||
|
||||
Not supported
|
||||
|
||||
<a name="4-3"></a>
|
||||
### 4.3 Serving
|
||||
|
||||
Not supported
|
||||
|
||||
<a name="4-4"></a>
|
||||
### 4.4 More
|
||||
|
||||
Not supported
|
||||
|
||||
<a name="5"></a>
|
||||
## 5. FAQ
|
||||
|
||||
1. In the `ViTSTR` paper, using pre-trained weights on ImageNet1k for initial training, we did not use pre-trained weights in training, and the final accuracy did not change or even improved.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{Atienza2021ViTSTR,
|
||||
title = {Vision Transformer for Fast and Efficient Scene Text Recognition},
|
||||
author = {Rowel Atienza},
|
||||
booktitle = {ICDAR},
|
||||
year = {2021},
|
||||
url = {https://arxiv.org/abs/2105.08582}
|
||||
}
|
||||
```
|
|
@ -22,8 +22,10 @@ from .make_shrink_map import MakeShrinkMap
|
|||
from .random_crop_data import EastRandomCropData, RandomCropImgMask
|
||||
from .make_pse_gt import MakePseGt
|
||||
|
||||
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
|
||||
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
|
||||
|
||||
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
|
||||
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
|
||||
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug
|
||||
from .ssl_img_aug import SSLRotateResize
|
||||
from .randaugment import RandAugment
|
||||
from .copy_paste import CopyPaste
|
||||
|
|
|
@ -0,0 +1,407 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/FangShancheng/ABINet/blob/main/transforms.py
|
||||
"""
|
||||
import math
|
||||
import numbers
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from paddle.vision.transforms import Compose, ColorJitter
|
||||
|
||||
|
||||
def sample_asym(magnitude, size=None):
|
||||
return np.random.beta(1, 4, size) * magnitude
|
||||
|
||||
|
||||
def sample_sym(magnitude, size=None):
|
||||
return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude
|
||||
|
||||
|
||||
def sample_uniform(low, high, size=None):
|
||||
return np.random.uniform(low, high, size=size)
|
||||
|
||||
|
||||
def get_interpolation(type='random'):
|
||||
if type == 'random':
|
||||
choice = [
|
||||
cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA
|
||||
]
|
||||
interpolation = choice[random.randint(0, len(choice) - 1)]
|
||||
elif type == 'nearest':
|
||||
interpolation = cv2.INTER_NEAREST
|
||||
elif type == 'linear':
|
||||
interpolation = cv2.INTER_LINEAR
|
||||
elif type == 'cubic':
|
||||
interpolation = cv2.INTER_CUBIC
|
||||
elif type == 'area':
|
||||
interpolation = cv2.INTER_AREA
|
||||
else:
|
||||
raise TypeError(
|
||||
'Interpolation types only nearest, linear, cubic, area are supported!'
|
||||
)
|
||||
return interpolation
|
||||
|
||||
|
||||
class CVRandomRotation(object):
|
||||
def __init__(self, degrees=15):
|
||||
assert isinstance(degrees,
|
||||
numbers.Number), "degree should be a single number."
|
||||
assert degrees >= 0, "degree must be positive."
|
||||
self.degrees = degrees
|
||||
|
||||
@staticmethod
|
||||
def get_params(degrees):
|
||||
return sample_sym(degrees)
|
||||
|
||||
def __call__(self, img):
|
||||
angle = self.get_params(self.degrees)
|
||||
src_h, src_w = img.shape[:2]
|
||||
M = cv2.getRotationMatrix2D(
|
||||
center=(src_w / 2, src_h / 2), angle=angle, scale=1.0)
|
||||
abs_cos, abs_sin = abs(M[0, 0]), abs(M[0, 1])
|
||||
dst_w = int(src_h * abs_sin + src_w * abs_cos)
|
||||
dst_h = int(src_h * abs_cos + src_w * abs_sin)
|
||||
M[0, 2] += (dst_w - src_w) / 2
|
||||
M[1, 2] += (dst_h - src_h) / 2
|
||||
|
||||
flags = get_interpolation()
|
||||
return cv2.warpAffine(
|
||||
img,
|
||||
M, (dst_w, dst_h),
|
||||
flags=flags,
|
||||
borderMode=cv2.BORDER_REPLICATE)
|
||||
|
||||
|
||||
class CVRandomAffine(object):
|
||||
def __init__(self, degrees, translate=None, scale=None, shear=None):
|
||||
assert isinstance(degrees,
|
||||
numbers.Number), "degree should be a single number."
|
||||
assert degrees >= 0, "degree must be positive."
|
||||
self.degrees = degrees
|
||||
|
||||
if translate is not None:
|
||||
assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
|
||||
"translate should be a list or tuple and it must be of length 2."
|
||||
for t in translate:
|
||||
if not (0.0 <= t <= 1.0):
|
||||
raise ValueError(
|
||||
"translation values should be between 0 and 1")
|
||||
self.translate = translate
|
||||
|
||||
if scale is not None:
|
||||
assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
|
||||
"scale should be a list or tuple and it must be of length 2."
|
||||
for s in scale:
|
||||
if s <= 0:
|
||||
raise ValueError("scale values should be positive")
|
||||
self.scale = scale
|
||||
|
||||
if shear is not None:
|
||||
if isinstance(shear, numbers.Number):
|
||||
if shear < 0:
|
||||
raise ValueError(
|
||||
"If shear is a single number, it must be positive.")
|
||||
self.shear = [shear]
|
||||
else:
|
||||
assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \
|
||||
"shear should be a list or tuple and it must be of length 2."
|
||||
self.shear = shear
|
||||
else:
|
||||
self.shear = shear
|
||||
|
||||
def _get_inverse_affine_matrix(self, center, angle, translate, scale,
|
||||
shear):
|
||||
# https://github.com/pytorch/vision/blob/v0.4.0/torchvision/transforms/functional.py#L717
|
||||
from numpy import sin, cos, tan
|
||||
|
||||
if isinstance(shear, numbers.Number):
|
||||
shear = [shear, 0]
|
||||
|
||||
if not isinstance(shear, (tuple, list)) and len(shear) == 2:
|
||||
raise ValueError(
|
||||
"Shear should be a single value or a tuple/list containing " +
|
||||
"two values. Got {}".format(shear))
|
||||
|
||||
rot = math.radians(angle)
|
||||
sx, sy = [math.radians(s) for s in shear]
|
||||
|
||||
cx, cy = center
|
||||
tx, ty = translate
|
||||
|
||||
# RSS without scaling
|
||||
a = cos(rot - sy) / cos(sy)
|
||||
b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
|
||||
c = sin(rot - sy) / cos(sy)
|
||||
d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)
|
||||
|
||||
# Inverted rotation matrix with scale and shear
|
||||
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
|
||||
M = [d, -b, 0, -c, a, 0]
|
||||
M = [x / scale for x in M]
|
||||
|
||||
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
|
||||
M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
|
||||
M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)
|
||||
|
||||
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
|
||||
M[2] += cx
|
||||
M[5] += cy
|
||||
return M
|
||||
|
||||
@staticmethod
|
||||
def get_params(degrees, translate, scale_ranges, shears, height):
|
||||
angle = sample_sym(degrees)
|
||||
if translate is not None:
|
||||
max_dx = translate[0] * height
|
||||
max_dy = translate[1] * height
|
||||
translations = (np.round(sample_sym(max_dx)),
|
||||
np.round(sample_sym(max_dy)))
|
||||
else:
|
||||
translations = (0, 0)
|
||||
|
||||
if scale_ranges is not None:
|
||||
scale = sample_uniform(scale_ranges[0], scale_ranges[1])
|
||||
else:
|
||||
scale = 1.0
|
||||
|
||||
if shears is not None:
|
||||
if len(shears) == 1:
|
||||
shear = [sample_sym(shears[0]), 0.]
|
||||
elif len(shears) == 2:
|
||||
shear = [sample_sym(shears[0]), sample_sym(shears[1])]
|
||||
else:
|
||||
shear = 0.0
|
||||
|
||||
return angle, translations, scale, shear
|
||||
|
||||
def __call__(self, img):
|
||||
src_h, src_w = img.shape[:2]
|
||||
angle, translate, scale, shear = self.get_params(
|
||||
self.degrees, self.translate, self.scale, self.shear, src_h)
|
||||
|
||||
M = self._get_inverse_affine_matrix((src_w / 2, src_h / 2), angle,
|
||||
(0, 0), scale, shear)
|
||||
M = np.array(M).reshape(2, 3)
|
||||
|
||||
startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1),
|
||||
(0, src_h - 1)]
|
||||
project = lambda x, y, a, b, c: int(a * x + b * y + c)
|
||||
endpoints = [(project(x, y, *M[0]), project(x, y, *M[1]))
|
||||
for x, y in startpoints]
|
||||
|
||||
rect = cv2.minAreaRect(np.array(endpoints))
|
||||
bbox = cv2.boxPoints(rect).astype(dtype=np.int)
|
||||
max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
|
||||
min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
|
||||
|
||||
dst_w = int(max_x - min_x)
|
||||
dst_h = int(max_y - min_y)
|
||||
M[0, 2] += (dst_w - src_w) / 2
|
||||
M[1, 2] += (dst_h - src_h) / 2
|
||||
|
||||
# add translate
|
||||
dst_w += int(abs(translate[0]))
|
||||
dst_h += int(abs(translate[1]))
|
||||
if translate[0] < 0: M[0, 2] += abs(translate[0])
|
||||
if translate[1] < 0: M[1, 2] += abs(translate[1])
|
||||
|
||||
flags = get_interpolation()
|
||||
return cv2.warpAffine(
|
||||
img,
|
||||
M, (dst_w, dst_h),
|
||||
flags=flags,
|
||||
borderMode=cv2.BORDER_REPLICATE)
|
||||
|
||||
|
||||
class CVRandomPerspective(object):
|
||||
def __init__(self, distortion=0.5):
|
||||
self.distortion = distortion
|
||||
|
||||
def get_params(self, width, height, distortion):
|
||||
offset_h = sample_asym(
|
||||
distortion * height / 2, size=4).astype(dtype=np.int)
|
||||
offset_w = sample_asym(
|
||||
distortion * width / 2, size=4).astype(dtype=np.int)
|
||||
topleft = (offset_w[0], offset_h[0])
|
||||
topright = (width - 1 - offset_w[1], offset_h[1])
|
||||
botright = (width - 1 - offset_w[2], height - 1 - offset_h[2])
|
||||
botleft = (offset_w[3], height - 1 - offset_h[3])
|
||||
|
||||
startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1),
|
||||
(0, height - 1)]
|
||||
endpoints = [topleft, topright, botright, botleft]
|
||||
return np.array(
|
||||
startpoints, dtype=np.float32), np.array(
|
||||
endpoints, dtype=np.float32)
|
||||
|
||||
def __call__(self, img):
|
||||
height, width = img.shape[:2]
|
||||
startpoints, endpoints = self.get_params(width, height, self.distortion)
|
||||
M = cv2.getPerspectiveTransform(startpoints, endpoints)
|
||||
|
||||
# TODO: more robust way to crop image
|
||||
rect = cv2.minAreaRect(endpoints)
|
||||
bbox = cv2.boxPoints(rect).astype(dtype=np.int)
|
||||
max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
|
||||
min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
|
||||
min_x, min_y = max(min_x, 0), max(min_y, 0)
|
||||
|
||||
flags = get_interpolation()
|
||||
img = cv2.warpPerspective(
|
||||
img,
|
||||
M, (max_x, max_y),
|
||||
flags=flags,
|
||||
borderMode=cv2.BORDER_REPLICATE)
|
||||
img = img[min_y:, min_x:]
|
||||
return img
|
||||
|
||||
|
||||
class CVRescale(object):
|
||||
def __init__(self, factor=4, base_size=(128, 512)):
|
||||
""" Define image scales using gaussian pyramid and rescale image to target scale.
|
||||
|
||||
Args:
|
||||
factor: the decayed factor from base size, factor=4 keeps target scale by default.
|
||||
base_size: base size the build the bottom layer of pyramid
|
||||
"""
|
||||
if isinstance(factor, numbers.Number):
|
||||
self.factor = round(sample_uniform(0, factor))
|
||||
elif isinstance(factor, (tuple, list)) and len(factor) == 2:
|
||||
self.factor = round(sample_uniform(factor[0], factor[1]))
|
||||
else:
|
||||
raise Exception('factor must be number or list with length 2')
|
||||
# assert factor is valid
|
||||
self.base_h, self.base_w = base_size[:2]
|
||||
|
||||
def __call__(self, img):
|
||||
if self.factor == 0: return img
|
||||
src_h, src_w = img.shape[:2]
|
||||
cur_w, cur_h = self.base_w, self.base_h
|
||||
scale_img = cv2.resize(
|
||||
img, (cur_w, cur_h), interpolation=get_interpolation())
|
||||
for _ in range(self.factor):
|
||||
scale_img = cv2.pyrDown(scale_img)
|
||||
scale_img = cv2.resize(
|
||||
scale_img, (src_w, src_h), interpolation=get_interpolation())
|
||||
return scale_img
|
||||
|
||||
|
||||
class CVGaussianNoise(object):
|
||||
def __init__(self, mean=0, var=20):
|
||||
self.mean = mean
|
||||
if isinstance(var, numbers.Number):
|
||||
self.var = max(int(sample_asym(var)), 1)
|
||||
elif isinstance(var, (tuple, list)) and len(var) == 2:
|
||||
self.var = int(sample_uniform(var[0], var[1]))
|
||||
else:
|
||||
raise Exception('degree must be number or list with length 2')
|
||||
|
||||
def __call__(self, img):
|
||||
noise = np.random.normal(self.mean, self.var**0.5, img.shape)
|
||||
img = np.clip(img + noise, 0, 255).astype(np.uint8)
|
||||
return img
|
||||
|
||||
|
||||
class CVMotionBlur(object):
|
||||
def __init__(self, degrees=12, angle=90):
|
||||
if isinstance(degrees, numbers.Number):
|
||||
self.degree = max(int(sample_asym(degrees)), 1)
|
||||
elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
|
||||
self.degree = int(sample_uniform(degrees[0], degrees[1]))
|
||||
else:
|
||||
raise Exception('degree must be number or list with length 2')
|
||||
self.angle = sample_uniform(-angle, angle)
|
||||
|
||||
def __call__(self, img):
|
||||
M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2),
|
||||
self.angle, 1)
|
||||
motion_blur_kernel = np.zeros((self.degree, self.degree))
|
||||
motion_blur_kernel[self.degree // 2, :] = 1
|
||||
motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M,
|
||||
(self.degree, self.degree))
|
||||
motion_blur_kernel = motion_blur_kernel / self.degree
|
||||
img = cv2.filter2D(img, -1, motion_blur_kernel)
|
||||
img = np.clip(img, 0, 255).astype(np.uint8)
|
||||
return img
|
||||
|
||||
|
||||
class CVGeometry(object):
|
||||
def __init__(self,
|
||||
degrees=15,
|
||||
translate=(0.3, 0.3),
|
||||
scale=(0.5, 2.),
|
||||
shear=(45, 15),
|
||||
distortion=0.5,
|
||||
p=0.5):
|
||||
self.p = p
|
||||
type_p = random.random()
|
||||
if type_p < 0.33:
|
||||
self.transforms = CVRandomRotation(degrees=degrees)
|
||||
elif type_p < 0.66:
|
||||
self.transforms = CVRandomAffine(
|
||||
degrees=degrees, translate=translate, scale=scale, shear=shear)
|
||||
else:
|
||||
self.transforms = CVRandomPerspective(distortion=distortion)
|
||||
|
||||
def __call__(self, img):
|
||||
if random.random() < self.p:
|
||||
return self.transforms(img)
|
||||
else:
|
||||
return img
|
||||
|
||||
|
||||
class CVDeterioration(object):
|
||||
def __init__(self, var, degrees, factor, p=0.5):
|
||||
self.p = p
|
||||
transforms = []
|
||||
if var is not None:
|
||||
transforms.append(CVGaussianNoise(var=var))
|
||||
if degrees is not None:
|
||||
transforms.append(CVMotionBlur(degrees=degrees))
|
||||
if factor is not None:
|
||||
transforms.append(CVRescale(factor=factor))
|
||||
|
||||
random.shuffle(transforms)
|
||||
transforms = Compose(transforms)
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, img):
|
||||
if random.random() < self.p:
|
||||
|
||||
return self.transforms(img)
|
||||
else:
|
||||
return img
|
||||
|
||||
|
||||
class CVColorJitter(object):
|
||||
def __init__(self,
|
||||
brightness=0.5,
|
||||
contrast=0.5,
|
||||
saturation=0.5,
|
||||
hue=0.1,
|
||||
p=0.5):
|
||||
self.p = p
|
||||
self.transforms = ColorJitter(
|
||||
brightness=brightness,
|
||||
contrast=contrast,
|
||||
saturation=saturation,
|
||||
hue=hue)
|
||||
|
||||
def __call__(self, img):
|
||||
if random.random() < self.p: return self.transforms(img)
|
||||
else: return img
|
|
@ -157,37 +157,6 @@ class BaseRecLabelEncode(object):
|
|||
return text_list
|
||||
|
||||
|
||||
class NRTRLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self,
|
||||
max_text_length,
|
||||
character_dict_path=None,
|
||||
use_space_char=False,
|
||||
**kwargs):
|
||||
|
||||
super(NRTRLabelEncode, self).__init__(
|
||||
max_text_length, character_dict_path, use_space_char)
|
||||
|
||||
def __call__(self, data):
|
||||
text = data['label']
|
||||
text = self.encode(text)
|
||||
if text is None:
|
||||
return None
|
||||
if len(text) >= self.max_text_len - 1:
|
||||
return None
|
||||
data['length'] = np.array(len(text))
|
||||
text.insert(0, 2)
|
||||
text.append(3)
|
||||
text = text + [0] * (self.max_text_len - len(text))
|
||||
data['label'] = np.array(text)
|
||||
return data
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
|
||||
return dict_character
|
||||
|
||||
|
||||
class CTCLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
|
@ -1046,3 +1015,99 @@ class MultiLabelEncode(BaseRecLabelEncode):
|
|||
data_out['label_sar'] = sar['label']
|
||||
data_out['length'] = ctc['length']
|
||||
return data_out
|
||||
|
||||
|
||||
class NRTRLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self,
|
||||
max_text_length,
|
||||
character_dict_path=None,
|
||||
use_space_char=False,
|
||||
**kwargs):
|
||||
|
||||
super(NRTRLabelEncode, self).__init__(
|
||||
max_text_length, character_dict_path, use_space_char)
|
||||
|
||||
def __call__(self, data):
|
||||
text = data['label']
|
||||
text = self.encode(text)
|
||||
if text is None:
|
||||
return None
|
||||
if len(text) >= self.max_text_len - 1:
|
||||
return None
|
||||
data['length'] = np.array(len(text))
|
||||
text.insert(0, 2)
|
||||
text.append(3)
|
||||
text = text + [0] * (self.max_text_len - len(text))
|
||||
data['label'] = np.array(text)
|
||||
return data
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
|
||||
return dict_character
|
||||
|
||||
|
||||
class ViTSTRLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self,
|
||||
max_text_length,
|
||||
character_dict_path=None,
|
||||
use_space_char=False,
|
||||
ignore_index=0,
|
||||
**kwargs):
|
||||
|
||||
super(ViTSTRLabelEncode, self).__init__(
|
||||
max_text_length, character_dict_path, use_space_char)
|
||||
self.ignore_index = ignore_index
|
||||
|
||||
def __call__(self, data):
|
||||
text = data['label']
|
||||
text = self.encode(text)
|
||||
if text is None:
|
||||
return None
|
||||
if len(text) >= self.max_text_len:
|
||||
return None
|
||||
data['length'] = np.array(len(text))
|
||||
text.insert(0, self.ignore_index)
|
||||
text.append(1)
|
||||
text = text + [self.ignore_index] * (self.max_text_len + 2 - len(text))
|
||||
data['label'] = np.array(text)
|
||||
return data
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
dict_character = ['<s>', '</s>'] + dict_character
|
||||
return dict_character
|
||||
|
||||
|
||||
class ABINetLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self,
|
||||
max_text_length,
|
||||
character_dict_path=None,
|
||||
use_space_char=False,
|
||||
ignore_index=100,
|
||||
**kwargs):
|
||||
|
||||
super(ABINetLabelEncode, self).__init__(
|
||||
max_text_length, character_dict_path, use_space_char)
|
||||
self.ignore_index = ignore_index
|
||||
|
||||
def __call__(self, data):
|
||||
text = data['label']
|
||||
text = self.encode(text)
|
||||
if text is None:
|
||||
return None
|
||||
if len(text) >= self.max_text_len:
|
||||
return None
|
||||
data['length'] = np.array(len(text))
|
||||
text.append(0)
|
||||
text = text + [self.ignore_index] * (self.max_text_len + 1 - len(text))
|
||||
data['label'] = np.array(text)
|
||||
return data
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
dict_character = ['</s>'] + dict_character
|
||||
return dict_character
|
||||
|
|
|
@ -67,39 +67,6 @@ class DecodeImage(object):
|
|||
return data
|
||||
|
||||
|
||||
class NRTRDecodeImage(object):
|
||||
""" decode image """
|
||||
|
||||
def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
|
||||
self.img_mode = img_mode
|
||||
self.channel_first = channel_first
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
if six.PY2:
|
||||
assert type(img) is str and len(
|
||||
img) > 0, "invalid input 'img' in DecodeImage"
|
||||
else:
|
||||
assert type(img) is bytes and len(
|
||||
img) > 0, "invalid input 'img' in DecodeImage"
|
||||
img = np.frombuffer(img, dtype='uint8')
|
||||
|
||||
img = cv2.imdecode(img, 1)
|
||||
|
||||
if img is None:
|
||||
return None
|
||||
if self.img_mode == 'GRAY':
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
elif self.img_mode == 'RGB':
|
||||
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
|
||||
img = img[:, :, ::-1]
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
if self.channel_first:
|
||||
img = img.transpose((2, 0, 1))
|
||||
data['image'] = img
|
||||
return data
|
||||
|
||||
|
||||
class NormalizeImage(object):
|
||||
""" normalize image such as substract mean, divide std
|
||||
"""
|
||||
|
|
|
@ -19,6 +19,8 @@ import random
|
|||
import copy
|
||||
from PIL import Image
|
||||
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
|
||||
from .abinet_aug import CVGeometry, CVDeterioration, CVColorJitter
|
||||
from paddle.vision.transforms import Compose
|
||||
|
||||
|
||||
class RecAug(object):
|
||||
|
@ -94,6 +96,36 @@ class BaseDataAugmentation(object):
|
|||
return data
|
||||
|
||||
|
||||
class ABINetRecAug(object):
|
||||
def __init__(self,
|
||||
geometry_p=0.5,
|
||||
deterioration_p=0.25,
|
||||
colorjitter_p=0.25,
|
||||
**kwargs):
|
||||
self.transforms = Compose([
|
||||
CVGeometry(
|
||||
degrees=45,
|
||||
translate=(0.0, 0.0),
|
||||
scale=(0.5, 2.),
|
||||
shear=(45, 15),
|
||||
distortion=0.5,
|
||||
p=geometry_p), CVDeterioration(
|
||||
var=20, degrees=6, factor=4, p=deterioration_p),
|
||||
CVColorJitter(
|
||||
brightness=0.5,
|
||||
contrast=0.5,
|
||||
saturation=0.5,
|
||||
hue=0.1,
|
||||
p=colorjitter_p)
|
||||
])
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
img = self.transforms(img)
|
||||
data['image'] = img
|
||||
return data
|
||||
|
||||
|
||||
class RecConAug(object):
|
||||
def __init__(self,
|
||||
prob=0.5,
|
||||
|
@ -148,46 +180,6 @@ class ClsResizeImg(object):
|
|||
return data
|
||||
|
||||
|
||||
class NRTRRecResizeImg(object):
|
||||
def __init__(self, image_shape, resize_type, padding=False, **kwargs):
|
||||
self.image_shape = image_shape
|
||||
self.resize_type = resize_type
|
||||
self.padding = padding
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
image_shape = self.image_shape
|
||||
if self.padding:
|
||||
imgC, imgH, imgW = image_shape
|
||||
# todo: change to 0 and modified image shape
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
ratio = w / float(h)
|
||||
if math.ceil(imgH * ratio) > imgW:
|
||||
resized_w = imgW
|
||||
else:
|
||||
resized_w = int(math.ceil(imgH * ratio))
|
||||
resized_image = cv2.resize(img, (resized_w, imgH))
|
||||
norm_img = np.expand_dims(resized_image, -1)
|
||||
norm_img = norm_img.transpose((2, 0, 1))
|
||||
resized_image = norm_img.astype(np.float32) / 128. - 1.
|
||||
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
|
||||
padding_im[:, :, 0:resized_w] = resized_image
|
||||
data['image'] = padding_im
|
||||
return data
|
||||
if self.resize_type == 'PIL':
|
||||
image_pil = Image.fromarray(np.uint8(img))
|
||||
img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
|
||||
img = np.array(img)
|
||||
if self.resize_type == 'OpenCV':
|
||||
img = cv2.resize(img, self.image_shape)
|
||||
norm_img = np.expand_dims(img, -1)
|
||||
norm_img = norm_img.transpose((2, 0, 1))
|
||||
data['image'] = norm_img.astype(np.float32) / 128. - 1.
|
||||
return data
|
||||
|
||||
|
||||
class RecResizeImg(object):
|
||||
def __init__(self,
|
||||
image_shape,
|
||||
|
@ -268,6 +260,84 @@ class PRENResizeImg(object):
|
|||
return data
|
||||
|
||||
|
||||
class GrayRecResizeImg(object):
|
||||
def __init__(self,
|
||||
image_shape,
|
||||
resize_type,
|
||||
inter_type='Image.ANTIALIAS',
|
||||
scale=True,
|
||||
padding=False,
|
||||
**kwargs):
|
||||
self.image_shape = image_shape
|
||||
self.resize_type = resize_type
|
||||
self.padding = padding
|
||||
self.inter_type = eval(inter_type)
|
||||
self.scale = scale
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
image_shape = self.image_shape
|
||||
if self.padding:
|
||||
imgC, imgH, imgW = image_shape
|
||||
# todo: change to 0 and modified image shape
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
ratio = w / float(h)
|
||||
if math.ceil(imgH * ratio) > imgW:
|
||||
resized_w = imgW
|
||||
else:
|
||||
resized_w = int(math.ceil(imgH * ratio))
|
||||
resized_image = cv2.resize(img, (resized_w, imgH))
|
||||
norm_img = np.expand_dims(resized_image, -1)
|
||||
norm_img = norm_img.transpose((2, 0, 1))
|
||||
resized_image = norm_img.astype(np.float32) / 128. - 1.
|
||||
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
|
||||
padding_im[:, :, 0:resized_w] = resized_image
|
||||
data['image'] = padding_im
|
||||
return data
|
||||
if self.resize_type == 'PIL':
|
||||
image_pil = Image.fromarray(np.uint8(img))
|
||||
img = image_pil.resize(self.image_shape, self.inter_type)
|
||||
img = np.array(img)
|
||||
if self.resize_type == 'OpenCV':
|
||||
img = cv2.resize(img, self.image_shape)
|
||||
norm_img = np.expand_dims(img, -1)
|
||||
norm_img = norm_img.transpose((2, 0, 1))
|
||||
if self.scale:
|
||||
data['image'] = norm_img.astype(np.float32) / 128. - 1.
|
||||
else:
|
||||
data['image'] = norm_img.astype(np.float32) / 255.
|
||||
return data
|
||||
|
||||
|
||||
class ABINetRecResizeImg(object):
|
||||
def __init__(self, image_shape, **kwargs):
|
||||
self.image_shape = image_shape
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
norm_img, valid_ratio = resize_norm_img_abinet(img, self.image_shape)
|
||||
data['image'] = norm_img
|
||||
data['valid_ratio'] = valid_ratio
|
||||
return data
|
||||
|
||||
|
||||
class SVTRRecResizeImg(object):
|
||||
def __init__(self, image_shape, padding=True, **kwargs):
|
||||
self.image_shape = image_shape
|
||||
self.padding = padding
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
|
||||
norm_img, valid_ratio = resize_norm_img(img, self.image_shape,
|
||||
self.padding)
|
||||
data['image'] = norm_img
|
||||
data['valid_ratio'] = valid_ratio
|
||||
return data
|
||||
|
||||
|
||||
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
|
||||
imgC, imgH, imgW_min, imgW_max = image_shape
|
||||
h = img.shape[0]
|
||||
|
@ -386,6 +456,26 @@ def resize_norm_img_srn(img, image_shape):
|
|||
return np.reshape(img_black, (c, row, col)).astype(np.float32)
|
||||
|
||||
|
||||
def resize_norm_img_abinet(img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_w = imgW
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image = resized_image / 255.
|
||||
|
||||
mean = np.array([0.485, 0.456, 0.406])
|
||||
std = np.array([0.229, 0.224, 0.225])
|
||||
resized_image = (
|
||||
resized_image - mean[None, None, ...]) / std[None, None, ...]
|
||||
resized_image = resized_image.transpose((2, 0, 1))
|
||||
resized_image = resized_image.astype('float32')
|
||||
|
||||
valid_ratio = min(1.0, float(resized_w / imgW))
|
||||
return resized_image, valid_ratio
|
||||
|
||||
|
||||
def srn_other_inputs(image_shape, num_heads, max_text_length):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
|
|
@ -30,7 +30,7 @@ from .det_fce_loss import FCELoss
|
|||
from .rec_ctc_loss import CTCLoss
|
||||
from .rec_att_loss import AttentionLoss
|
||||
from .rec_srn_loss import SRNLoss
|
||||
from .rec_nrtr_loss import NRTRLoss
|
||||
from .rec_ce_loss import CELoss
|
||||
from .rec_sar_loss import SARLoss
|
||||
from .rec_aster_loss import AsterLoss
|
||||
from .rec_pren_loss import PRENLoss
|
||||
|
@ -60,7 +60,7 @@ def build_loss(config):
|
|||
support_dict = [
|
||||
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
|
||||
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
|
||||
'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
|
||||
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
|
||||
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss'
|
||||
]
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
class CELoss(nn.Layer):
|
||||
def __init__(self,
|
||||
smoothing=False,
|
||||
with_all=False,
|
||||
ignore_index=-1,
|
||||
**kwargs):
|
||||
super(CELoss, self).__init__()
|
||||
if ignore_index >= 0:
|
||||
self.loss_func = nn.CrossEntropyLoss(
|
||||
reduction='mean', ignore_index=ignore_index)
|
||||
else:
|
||||
self.loss_func = nn.CrossEntropyLoss(reduction='mean')
|
||||
self.smoothing = smoothing
|
||||
self.with_all = with_all
|
||||
|
||||
def forward(self, pred, batch):
|
||||
|
||||
if isinstance(pred, dict): # for ABINet
|
||||
loss = {}
|
||||
loss_sum = []
|
||||
for name, logits in pred.items():
|
||||
if isinstance(logits, list):
|
||||
logit_num = len(logits)
|
||||
all_tgt = paddle.concat([batch[1]] * logit_num, 0)
|
||||
all_logits = paddle.concat(logits, 0)
|
||||
flt_logtis = all_logits.reshape([-1, all_logits.shape[2]])
|
||||
flt_tgt = all_tgt.reshape([-1])
|
||||
else:
|
||||
flt_logtis = logits.reshape([-1, logits.shape[2]])
|
||||
flt_tgt = batch[1].reshape([-1])
|
||||
loss[name + '_loss'] = self.loss_func(flt_logtis, flt_tgt)
|
||||
loss_sum.append(loss[name + '_loss'])
|
||||
loss['loss'] = sum(loss_sum)
|
||||
return loss
|
||||
else:
|
||||
if self.with_all: # for ViTSTR
|
||||
tgt = batch[1]
|
||||
pred = pred.reshape([-1, pred.shape[2]])
|
||||
tgt = tgt.reshape([-1])
|
||||
loss = self.loss_func(pred, tgt)
|
||||
return {'loss': loss}
|
||||
else: # for NRTR
|
||||
max_len = batch[2].max()
|
||||
tgt = batch[1][:, 1:2 + max_len]
|
||||
pred = pred.reshape([-1, pred.shape[2]])
|
||||
tgt = tgt.reshape([-1])
|
||||
if self.smoothing:
|
||||
eps = 0.1
|
||||
n_class = pred.shape[1]
|
||||
one_hot = F.one_hot(tgt, pred.shape[1])
|
||||
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (
|
||||
n_class - 1)
|
||||
log_prb = F.log_softmax(pred, axis=1)
|
||||
non_pad_mask = paddle.not_equal(
|
||||
tgt, paddle.zeros(
|
||||
tgt.shape, dtype=tgt.dtype))
|
||||
loss = -(one_hot * log_prb).sum(axis=1)
|
||||
loss = loss.masked_select(non_pad_mask).mean()
|
||||
else:
|
||||
loss = self.loss_func(pred, tgt)
|
||||
return {'loss': loss}
|
|
@ -1,30 +0,0 @@
|
|||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
class NRTRLoss(nn.Layer):
|
||||
def __init__(self, smoothing=True, **kwargs):
|
||||
super(NRTRLoss, self).__init__()
|
||||
self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
|
||||
self.smoothing = smoothing
|
||||
|
||||
def forward(self, pred, batch):
|
||||
pred = pred.reshape([-1, pred.shape[2]])
|
||||
max_len = batch[2].max()
|
||||
tgt = batch[1][:, 1:2 + max_len]
|
||||
tgt = tgt.reshape([-1])
|
||||
if self.smoothing:
|
||||
eps = 0.1
|
||||
n_class = pred.shape[1]
|
||||
one_hot = F.one_hot(tgt, pred.shape[1])
|
||||
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
|
||||
log_prb = F.log_softmax(pred, axis=1)
|
||||
non_pad_mask = paddle.not_equal(
|
||||
tgt, paddle.zeros(
|
||||
tgt.shape, dtype=tgt.dtype))
|
||||
loss = -(one_hot * log_prb).sum(axis=1)
|
||||
loss = loss.masked_select(non_pad_mask).mean()
|
||||
else:
|
||||
loss = self.loss_func(pred, tgt)
|
||||
return {'loss': loss}
|
|
@ -28,35 +28,37 @@ def build_backbone(config, model_type):
|
|||
from .rec_mv1_enhance import MobileNetV1Enhance
|
||||
from .rec_nrtr_mtb import MTB
|
||||
from .rec_resnet_31 import ResNet31
|
||||
from .rec_resnet_45 import ResNet45
|
||||
from .rec_resnet_aster import ResNet_ASTER
|
||||
from .rec_micronet import MicroNet
|
||||
from .rec_efficientb3_pren import EfficientNetb3_PREN
|
||||
from .rec_svtrnet import SVTRNet
|
||||
from .rec_vitstr import ViTSTR
|
||||
support_dict = [
|
||||
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
|
||||
"ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN',
|
||||
'SVTRNet'
|
||||
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
|
||||
'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR'
|
||||
]
|
||||
elif model_type == "e2e":
|
||||
elif model_type == 'e2e':
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
support_dict = ['ResNet']
|
||||
elif model_type == 'kie':
|
||||
from .kie_unet_sdmgr import Kie_backbone
|
||||
support_dict = ['Kie_backbone']
|
||||
elif model_type == "table":
|
||||
elif model_type == 'table':
|
||||
from .table_resnet_vd import ResNet
|
||||
from .table_mobilenet_v3 import MobileNetV3
|
||||
support_dict = ["ResNet", "MobileNetV3"]
|
||||
support_dict = ['ResNet', 'MobileNetV3']
|
||||
elif model_type == 'vqa':
|
||||
from .vqa_layoutlm import LayoutLMForSer, LayoutLMv2ForSer, LayoutLMv2ForRe, LayoutXLMForSer, LayoutXLMForRe
|
||||
support_dict = [
|
||||
"LayoutLMForSer", "LayoutLMv2ForSer", 'LayoutLMv2ForRe',
|
||||
"LayoutXLMForSer", 'LayoutXLMForRe'
|
||||
'LayoutLMForSer', 'LayoutLMv2ForSer', 'LayoutLMv2ForRe',
|
||||
'LayoutXLMForSer', 'LayoutXLMForRe'
|
||||
]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
module_name = config.pop("name")
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception(
|
||||
"when model typs is {}, backbone only support {}".format(model_type,
|
||||
support_dict))
|
||||
|
|
|
@ -0,0 +1,147 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/FangShancheng/ABINet/tree/main/modules
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
from paddle.nn.initializer import KaimingNormal
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
__all__ = ["ResNet45"]
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
return nn.Conv2D(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
weight_attr=ParamAttr(initializer=KaimingNormal()),
|
||||
bias_attr=False)
|
||||
|
||||
|
||||
def conv3x3(in_channel, out_channel, stride=1):
|
||||
return nn.Conv2D(
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(initializer=KaimingNormal()),
|
||||
bias_attr=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Layer):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_channels, channels, stride=1, downsample=None):
|
||||
super().__init__()
|
||||
self.conv1 = conv1x1(in_channels, channels)
|
||||
self.bn1 = nn.BatchNorm2D(channels)
|
||||
self.relu = nn.ReLU()
|
||||
self.conv2 = conv3x3(channels, channels, stride)
|
||||
self.bn2 = nn.BatchNorm2D(channels)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet45(nn.Layer):
|
||||
def __init__(self, block=BasicBlock, layers=[3, 4, 6, 6, 3], in_channels=3):
|
||||
self.inplanes = 32
|
||||
super(ResNet45, self).__init__()
|
||||
self.conv1 = nn.Conv2D(
|
||||
3,
|
||||
32,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(initializer=KaimingNormal()),
|
||||
bias_attr=False)
|
||||
self.bn1 = nn.BatchNorm2D(32)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
self.layer1 = self._make_layer(block, 32, layers[0], stride=2)
|
||||
self.layer2 = self._make_layer(block, 64, layers[1], stride=1)
|
||||
self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 256, layers[3], stride=1)
|
||||
self.layer5 = self._make_layer(block, 512, layers[4], stride=1)
|
||||
self.out_channels = 512
|
||||
|
||||
# for m in self.modules():
|
||||
# if isinstance(m, nn.Conv2D):
|
||||
# n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
|
||||
# m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
# downsample = True
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2D(
|
||||
self.inplanes,
|
||||
planes * block.expansion,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
weight_attr=ParamAttr(initializer=KaimingNormal()),
|
||||
bias_attr=False),
|
||||
nn.BatchNorm2D(planes * block.expansion), )
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
# print(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
# print(x)
|
||||
x = self.layer4(x)
|
||||
x = self.layer5(x)
|
||||
return x
|
|
@ -147,7 +147,7 @@ class Attention(nn.Layer):
|
|||
dim,
|
||||
num_heads=8,
|
||||
mixer='Global',
|
||||
HW=[8, 25],
|
||||
HW=None,
|
||||
local_k=[7, 11],
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
|
@ -210,7 +210,7 @@ class Block(nn.Layer):
|
|||
num_heads,
|
||||
mixer='Global',
|
||||
local_mixer=[7, 11],
|
||||
HW=[8, 25],
|
||||
HW=None,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
|
@ -274,7 +274,9 @@ class PatchEmbed(nn.Layer):
|
|||
img_size=[32, 100],
|
||||
in_channels=3,
|
||||
embed_dim=768,
|
||||
sub_num=2):
|
||||
sub_num=2,
|
||||
patch_size=[4, 4],
|
||||
mode='pope'):
|
||||
super().__init__()
|
||||
num_patches = (img_size[1] // (2 ** sub_num)) * \
|
||||
(img_size[0] // (2 ** sub_num))
|
||||
|
@ -282,50 +284,56 @@ class PatchEmbed(nn.Layer):
|
|||
self.num_patches = num_patches
|
||||
self.embed_dim = embed_dim
|
||||
self.norm = None
|
||||
if sub_num == 2:
|
||||
self.proj = nn.Sequential(
|
||||
ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=embed_dim // 2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=None),
|
||||
ConvBNLayer(
|
||||
in_channels=embed_dim // 2,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=None))
|
||||
if sub_num == 3:
|
||||
self.proj = nn.Sequential(
|
||||
ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=embed_dim // 4,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=None),
|
||||
ConvBNLayer(
|
||||
in_channels=embed_dim // 4,
|
||||
out_channels=embed_dim // 2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=None),
|
||||
ConvBNLayer(
|
||||
in_channels=embed_dim // 2,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=None))
|
||||
if mode == 'pope':
|
||||
if sub_num == 2:
|
||||
self.proj = nn.Sequential(
|
||||
ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=embed_dim // 2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=None),
|
||||
ConvBNLayer(
|
||||
in_channels=embed_dim // 2,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=None))
|
||||
if sub_num == 3:
|
||||
self.proj = nn.Sequential(
|
||||
ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=embed_dim // 4,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=None),
|
||||
ConvBNLayer(
|
||||
in_channels=embed_dim // 4,
|
||||
out_channels=embed_dim // 2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=None),
|
||||
ConvBNLayer(
|
||||
in_channels=embed_dim // 2,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=None))
|
||||
elif mode == 'linear':
|
||||
self.proj = nn.Conv2D(
|
||||
1, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.num_patches = img_size[0] // patch_size[0] * img_size[
|
||||
1] // patch_size[1]
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
|
|
|
@ -0,0 +1,120 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/roatienza/deep-text-recognition-benchmark/blob/master/modules/vitstr.py
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
from ppocr.modeling.backbones.rec_svtrnet import Block, PatchEmbed, zeros_, trunc_normal_, ones_
|
||||
|
||||
scale_dim_heads = {'tiny': [192, 3], 'small': [384, 6], 'base': [768, 12]}
|
||||
|
||||
|
||||
class ViTSTR(nn.Layer):
|
||||
def __init__(self,
|
||||
img_size=[224, 224],
|
||||
in_channels=1,
|
||||
scale='tiny',
|
||||
seqlen=27,
|
||||
patch_size=[16, 16],
|
||||
embed_dim=None,
|
||||
depth=12,
|
||||
num_heads=None,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_path_rate=0.,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
norm_layer='nn.LayerNorm',
|
||||
act_layer='nn.GELU',
|
||||
epsilon=1e-6,
|
||||
out_channels=None,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.seqlen = seqlen
|
||||
embed_dim = embed_dim if embed_dim is not None else scale_dim_heads[
|
||||
scale][0]
|
||||
num_heads = num_heads if num_heads is not None else scale_dim_heads[
|
||||
scale][1]
|
||||
out_channels = out_channels if out_channels is not None else embed_dim
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=embed_dim,
|
||||
patch_size=patch_size,
|
||||
mode='linear')
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.pos_embed = self.create_parameter(
|
||||
shape=[1, num_patches + 1, embed_dim], default_initializer=zeros_)
|
||||
self.add_parameter("pos_embed", self.pos_embed)
|
||||
self.cls_token = self.create_parameter(
|
||||
shape=[1, 1, embed_dim], default_initializer=zeros_)
|
||||
self.add_parameter("cls_token", self.cls_token)
|
||||
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
dpr = np.linspace(0, drop_path_rate, depth)
|
||||
self.blocks = nn.LayerList([
|
||||
Block(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
act_layer=eval(act_layer),
|
||||
epsilon=epsilon,
|
||||
prenorm=False) for i in range(depth)
|
||||
])
|
||||
self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
|
||||
|
||||
self.out_channels = out_channels
|
||||
|
||||
trunc_normal_(self.pos_embed)
|
||||
trunc_normal_(self.cls_token)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
zeros_(m.bias)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
zeros_(m.bias)
|
||||
ones_(m.weight)
|
||||
|
||||
def forward_features(self, x):
|
||||
B = x.shape[0]
|
||||
x = self.patch_embed(x)
|
||||
cls_tokens = paddle.tile(self.cls_token, repeat_times=[B, 1, 1])
|
||||
x = paddle.concat((cls_tokens, x), axis=1)
|
||||
x = x + self.pos_embed
|
||||
x = self.pos_drop(x)
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = x[:, :self.seqlen]
|
||||
return x.transpose([0, 2, 1]).unsqueeze(2)
|
|
@ -33,6 +33,7 @@ def build_head(config):
|
|||
from .rec_aster_head import AsterHead
|
||||
from .rec_pren_head import PRENHead
|
||||
from .rec_multi_head import MultiHead
|
||||
from .rec_abinet_head import ABINetHead
|
||||
|
||||
# cls head
|
||||
from .cls_head import ClsHead
|
||||
|
@ -46,7 +47,7 @@ def build_head(config):
|
|||
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
|
||||
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
|
||||
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
|
||||
'MultiHead'
|
||||
'MultiHead', 'ABINetHead'
|
||||
]
|
||||
|
||||
#table head
|
||||
|
|
|
@ -1,163 +0,0 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn import Linear
|
||||
from paddle.nn.initializer import XavierUniform as xavier_uniform_
|
||||
from paddle.nn.initializer import Constant as constant_
|
||||
from paddle.nn.initializer import XavierNormal as xavier_normal_
|
||||
|
||||
zeros_ = constant_(value=0.)
|
||||
ones_ = constant_(value=1.)
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Layer):
|
||||
"""Allows the model to jointly attend to information
|
||||
from different representation subspaces.
|
||||
See reference: Attention Is All You Need
|
||||
|
||||
.. math::
|
||||
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
||||
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
|
||||
|
||||
Args:
|
||||
embed_dim: total dimension of the model
|
||||
num_heads: parallel attention layers, or heads
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
dropout=0.,
|
||||
bias=True,
|
||||
add_bias_kv=False,
|
||||
add_zero_attn=False):
|
||||
super(MultiheadAttention, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
|
||||
self._reset_parameters()
|
||||
self.conv1 = paddle.nn.Conv2D(
|
||||
in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
|
||||
self.conv2 = paddle.nn.Conv2D(
|
||||
in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
|
||||
self.conv3 = paddle.nn.Conv2D(
|
||||
in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
|
||||
|
||||
def _reset_parameters(self):
|
||||
xavier_uniform_(self.out_proj.weight)
|
||||
|
||||
def forward(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
key_padding_mask=None,
|
||||
incremental_state=None,
|
||||
attn_mask=None):
|
||||
"""
|
||||
Inputs of forward function
|
||||
query: [target length, batch size, embed dim]
|
||||
key: [sequence length, batch size, embed dim]
|
||||
value: [sequence length, batch size, embed dim]
|
||||
key_padding_mask: if True, mask padding based on batch size
|
||||
incremental_state: if provided, previous time steps are cashed
|
||||
need_weights: output attn_output_weights
|
||||
static_kv: key and value are static
|
||||
|
||||
Outputs of forward function
|
||||
attn_output: [target length, batch size, embed dim]
|
||||
attn_output_weights: [batch size, target length, sequence length]
|
||||
"""
|
||||
q_shape = paddle.shape(query)
|
||||
src_shape = paddle.shape(key)
|
||||
q = self._in_proj_q(query)
|
||||
k = self._in_proj_k(key)
|
||||
v = self._in_proj_v(value)
|
||||
q *= self.scaling
|
||||
q = paddle.transpose(
|
||||
paddle.reshape(
|
||||
q, [q_shape[0], q_shape[1], self.num_heads, self.head_dim]),
|
||||
[1, 2, 0, 3])
|
||||
k = paddle.transpose(
|
||||
paddle.reshape(
|
||||
k, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
|
||||
[1, 2, 0, 3])
|
||||
v = paddle.transpose(
|
||||
paddle.reshape(
|
||||
v, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
|
||||
[1, 2, 0, 3])
|
||||
if key_padding_mask is not None:
|
||||
assert key_padding_mask.shape[0] == q_shape[1]
|
||||
assert key_padding_mask.shape[1] == src_shape[0]
|
||||
attn_output_weights = paddle.matmul(q,
|
||||
paddle.transpose(k, [0, 1, 3, 2]))
|
||||
if attn_mask is not None:
|
||||
attn_mask = paddle.unsqueeze(paddle.unsqueeze(attn_mask, 0), 0)
|
||||
attn_output_weights += attn_mask
|
||||
if key_padding_mask is not None:
|
||||
attn_output_weights = paddle.reshape(
|
||||
attn_output_weights,
|
||||
[q_shape[1], self.num_heads, q_shape[0], src_shape[0]])
|
||||
key = paddle.unsqueeze(paddle.unsqueeze(key_padding_mask, 1), 2)
|
||||
key = paddle.cast(key, 'float32')
|
||||
y = paddle.full(
|
||||
shape=paddle.shape(key), dtype='float32', fill_value='-inf')
|
||||
y = paddle.where(key == 0., key, y)
|
||||
attn_output_weights += y
|
||||
attn_output_weights = F.softmax(
|
||||
attn_output_weights.astype('float32'),
|
||||
axis=-1,
|
||||
dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16
|
||||
else attn_output_weights.dtype)
|
||||
attn_output_weights = F.dropout(
|
||||
attn_output_weights, p=self.dropout, training=self.training)
|
||||
|
||||
attn_output = paddle.matmul(attn_output_weights, v)
|
||||
attn_output = paddle.reshape(
|
||||
paddle.transpose(attn_output, [2, 0, 1, 3]),
|
||||
[q_shape[0], q_shape[1], self.embed_dim])
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
def _in_proj_q(self, query):
|
||||
query = paddle.transpose(query, [1, 2, 0])
|
||||
query = paddle.unsqueeze(query, axis=2)
|
||||
res = self.conv1(query)
|
||||
res = paddle.squeeze(res, axis=2)
|
||||
res = paddle.transpose(res, [2, 0, 1])
|
||||
return res
|
||||
|
||||
def _in_proj_k(self, key):
|
||||
key = paddle.transpose(key, [1, 2, 0])
|
||||
key = paddle.unsqueeze(key, axis=2)
|
||||
res = self.conv2(key)
|
||||
res = paddle.squeeze(res, axis=2)
|
||||
res = paddle.transpose(res, [2, 0, 1])
|
||||
return res
|
||||
|
||||
def _in_proj_v(self, value):
|
||||
value = paddle.transpose(value, [1, 2, 0]) #(1, 2, 0)
|
||||
value = paddle.unsqueeze(value, axis=2)
|
||||
res = self.conv3(value)
|
||||
res = paddle.squeeze(res, axis=2)
|
||||
res = paddle.transpose(res, [2, 0, 1])
|
||||
return res
|
|
@ -0,0 +1,296 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/FangShancheng/ABINet/tree/main/modules
|
||||
"""
|
||||
|
||||
import math
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn import LayerList
|
||||
from ppocr.modeling.heads.rec_nrtr_head import TransformerBlock, PositionalEncoding
|
||||
|
||||
|
||||
class BCNLanguage(nn.Layer):
|
||||
def __init__(self,
|
||||
d_model=512,
|
||||
nhead=8,
|
||||
num_layers=4,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.,
|
||||
max_length=25,
|
||||
detach=True,
|
||||
num_classes=37):
|
||||
super().__init__()
|
||||
|
||||
self.d_model = d_model
|
||||
self.detach = detach
|
||||
self.max_length = max_length + 1 # additional stop token
|
||||
self.proj = nn.Linear(num_classes, d_model, bias_attr=False)
|
||||
self.token_encoder = PositionalEncoding(
|
||||
dropout=0.1, dim=d_model, max_len=self.max_length)
|
||||
self.pos_encoder = PositionalEncoding(
|
||||
dropout=0, dim=d_model, max_len=self.max_length)
|
||||
|
||||
self.decoder = nn.LayerList([
|
||||
TransformerBlock(
|
||||
d_model=d_model,
|
||||
nhead=nhead,
|
||||
dim_feedforward=dim_feedforward,
|
||||
attention_dropout_rate=dropout,
|
||||
residual_dropout_rate=dropout,
|
||||
with_self_attn=False,
|
||||
with_cross_attn=True) for i in range(num_layers)
|
||||
])
|
||||
|
||||
self.cls = nn.Linear(d_model, num_classes)
|
||||
|
||||
def forward(self, tokens, lengths):
|
||||
"""
|
||||
Args:
|
||||
tokens: (B, N, C) where N is length, B is batch size and C is classes number
|
||||
lengths: (B,)
|
||||
"""
|
||||
if self.detach: tokens = tokens.detach()
|
||||
embed = self.proj(tokens) # (B, N, C)
|
||||
embed = self.token_encoder(embed) # (B, N, C)
|
||||
padding_mask = _get_mask(lengths, self.max_length)
|
||||
zeros = paddle.zeros_like(embed) # (B, N, C)
|
||||
qeury = self.pos_encoder(zeros)
|
||||
for decoder_layer in self.decoder:
|
||||
qeury = decoder_layer(qeury, embed, cross_mask=padding_mask)
|
||||
output = qeury # (B, N, C)
|
||||
|
||||
logits = self.cls(output) # (B, N, C)
|
||||
|
||||
return output, logits
|
||||
|
||||
|
||||
def encoder_layer(in_c, out_c, k=3, s=2, p=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2D(in_c, out_c, k, s, p), nn.BatchNorm2D(out_c), nn.ReLU())
|
||||
|
||||
|
||||
def decoder_layer(in_c,
|
||||
out_c,
|
||||
k=3,
|
||||
s=1,
|
||||
p=1,
|
||||
mode='nearest',
|
||||
scale_factor=None,
|
||||
size=None):
|
||||
align_corners = False if mode == 'nearest' else True
|
||||
return nn.Sequential(
|
||||
nn.Upsample(
|
||||
size=size,
|
||||
scale_factor=scale_factor,
|
||||
mode=mode,
|
||||
align_corners=align_corners),
|
||||
nn.Conv2D(in_c, out_c, k, s, p),
|
||||
nn.BatchNorm2D(out_c),
|
||||
nn.ReLU())
|
||||
|
||||
|
||||
class PositionAttention(nn.Layer):
|
||||
def __init__(self,
|
||||
max_length,
|
||||
in_channels=512,
|
||||
num_channels=64,
|
||||
h=8,
|
||||
w=32,
|
||||
mode='nearest',
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.max_length = max_length
|
||||
self.k_encoder = nn.Sequential(
|
||||
encoder_layer(
|
||||
in_channels, num_channels, s=(1, 2)),
|
||||
encoder_layer(
|
||||
num_channels, num_channels, s=(2, 2)),
|
||||
encoder_layer(
|
||||
num_channels, num_channels, s=(2, 2)),
|
||||
encoder_layer(
|
||||
num_channels, num_channels, s=(2, 2)))
|
||||
self.k_decoder = nn.Sequential(
|
||||
decoder_layer(
|
||||
num_channels, num_channels, scale_factor=2, mode=mode),
|
||||
decoder_layer(
|
||||
num_channels, num_channels, scale_factor=2, mode=mode),
|
||||
decoder_layer(
|
||||
num_channels, num_channels, scale_factor=2, mode=mode),
|
||||
decoder_layer(
|
||||
num_channels, in_channels, size=(h, w), mode=mode))
|
||||
|
||||
self.pos_encoder = PositionalEncoding(
|
||||
dropout=0, dim=in_channels, max_len=max_length)
|
||||
self.project = nn.Linear(in_channels, in_channels)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
k, v = x, x
|
||||
|
||||
# calculate key vector
|
||||
features = []
|
||||
for i in range(0, len(self.k_encoder)):
|
||||
k = self.k_encoder[i](k)
|
||||
features.append(k)
|
||||
for i in range(0, len(self.k_decoder) - 1):
|
||||
k = self.k_decoder[i](k)
|
||||
# print(k.shape, features[len(self.k_decoder) - 2 - i].shape)
|
||||
k = k + features[len(self.k_decoder) - 2 - i]
|
||||
k = self.k_decoder[-1](k)
|
||||
|
||||
# calculate query vector
|
||||
# TODO q=f(q,k)
|
||||
zeros = paddle.zeros(
|
||||
(B, self.max_length, C), dtype=x.dtype) # (T, N, C)
|
||||
q = self.pos_encoder(zeros) # (B, N, C)
|
||||
q = self.project(q) # (B, N, C)
|
||||
|
||||
# calculate attention
|
||||
attn_scores = q @k.flatten(2) # (B, N, (H*W))
|
||||
attn_scores = attn_scores / (C**0.5)
|
||||
attn_scores = F.softmax(attn_scores, axis=-1)
|
||||
|
||||
v = v.flatten(2).transpose([0, 2, 1]) # (B, (H*W), C)
|
||||
attn_vecs = attn_scores @v # (B, N, C)
|
||||
|
||||
return attn_vecs, attn_scores.reshape([0, self.max_length, H, W])
|
||||
|
||||
|
||||
class ABINetHead(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
d_model=512,
|
||||
nhead=8,
|
||||
num_layers=3,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.1,
|
||||
max_length=25,
|
||||
use_lang=False,
|
||||
iter_size=1):
|
||||
super().__init__()
|
||||
self.max_length = max_length + 1
|
||||
self.pos_encoder = PositionalEncoding(
|
||||
dropout=0.1, dim=d_model, max_len=8 * 32)
|
||||
self.encoder = nn.LayerList([
|
||||
TransformerBlock(
|
||||
d_model=d_model,
|
||||
nhead=nhead,
|
||||
dim_feedforward=dim_feedforward,
|
||||
attention_dropout_rate=dropout,
|
||||
residual_dropout_rate=dropout,
|
||||
with_self_attn=True,
|
||||
with_cross_attn=False) for i in range(num_layers)
|
||||
])
|
||||
self.decoder = PositionAttention(
|
||||
max_length=max_length + 1, # additional stop token
|
||||
mode='nearest', )
|
||||
self.out_channels = out_channels
|
||||
self.cls = nn.Linear(d_model, self.out_channels)
|
||||
self.use_lang = use_lang
|
||||
if use_lang:
|
||||
self.iter_size = iter_size
|
||||
self.language = BCNLanguage(
|
||||
d_model=d_model,
|
||||
nhead=nhead,
|
||||
num_layers=4,
|
||||
dim_feedforward=dim_feedforward,
|
||||
dropout=dropout,
|
||||
max_length=max_length,
|
||||
num_classes=self.out_channels)
|
||||
# alignment
|
||||
self.w_att_align = nn.Linear(2 * d_model, d_model)
|
||||
self.cls_align = nn.Linear(d_model, self.out_channels)
|
||||
|
||||
def forward(self, x, targets=None):
|
||||
x = x.transpose([0, 2, 3, 1])
|
||||
_, H, W, C = x.shape
|
||||
feature = x.flatten(1, 2)
|
||||
feature = self.pos_encoder(feature)
|
||||
for encoder_layer in self.encoder:
|
||||
feature = encoder_layer(feature)
|
||||
feature = feature.reshape([0, H, W, C]).transpose([0, 3, 1, 2])
|
||||
v_feature, attn_scores = self.decoder(
|
||||
feature) # (B, N, C), (B, C, H, W)
|
||||
vis_logits = self.cls(v_feature) # (B, N, C)
|
||||
logits = vis_logits
|
||||
vis_lengths = _get_length(vis_logits)
|
||||
if self.use_lang:
|
||||
align_logits = vis_logits
|
||||
align_lengths = vis_lengths
|
||||
all_l_res, all_a_res = [], []
|
||||
for i in range(self.iter_size):
|
||||
tokens = F.softmax(align_logits, axis=-1)
|
||||
lengths = align_lengths
|
||||
lengths = paddle.clip(
|
||||
lengths, 2, self.max_length) # TODO:move to langauge model
|
||||
l_feature, l_logits = self.language(tokens, lengths)
|
||||
|
||||
# alignment
|
||||
all_l_res.append(l_logits)
|
||||
fuse = paddle.concat((l_feature, v_feature), -1)
|
||||
f_att = F.sigmoid(self.w_att_align(fuse))
|
||||
output = f_att * v_feature + (1 - f_att) * l_feature
|
||||
align_logits = self.cls_align(output) # (B, N, C)
|
||||
|
||||
align_lengths = _get_length(align_logits)
|
||||
all_a_res.append(align_logits)
|
||||
if self.training:
|
||||
return {
|
||||
'align': all_a_res,
|
||||
'lang': all_l_res,
|
||||
'vision': vis_logits
|
||||
}
|
||||
else:
|
||||
logits = align_logits
|
||||
if self.training:
|
||||
return logits
|
||||
else:
|
||||
return F.softmax(logits, -1)
|
||||
|
||||
|
||||
def _get_length(logit):
|
||||
""" Greed decoder to obtain length from logit"""
|
||||
out = (logit.argmax(-1) == 0)
|
||||
abn = out.any(-1)
|
||||
out_int = out.cast('int32')
|
||||
out = (out_int.cumsum(-1) == 1) & out
|
||||
out = out.cast('int32')
|
||||
out = out.argmax(-1)
|
||||
out = out + 1
|
||||
out = paddle.where(abn, out, paddle.to_tensor(logit.shape[1]))
|
||||
return out
|
||||
|
||||
|
||||
def _get_mask(length, max_length):
|
||||
"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
|
||||
Unmasked positions are filled with float(0.0).
|
||||
"""
|
||||
length = length.unsqueeze(-1)
|
||||
B = paddle.shape(length)[0]
|
||||
grid = paddle.arange(0, max_length).unsqueeze(0).tile([B, 1])
|
||||
zero_mask = paddle.zeros([B, max_length], dtype='float32')
|
||||
inf_mask = paddle.full([B, max_length], '-inf', dtype='float32')
|
||||
diag_mask = paddle.diag(
|
||||
paddle.full(
|
||||
[max_length], '-inf', dtype=paddle.float32),
|
||||
offset=0,
|
||||
name=None)
|
||||
mask = paddle.where(grid >= length, inf_mask, zero_mask)
|
||||
mask = mask.unsqueeze(1) + diag_mask
|
||||
return mask.unsqueeze(1)
|
|
@ -14,20 +14,15 @@
|
|||
|
||||
import math
|
||||
import paddle
|
||||
import copy
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn import LayerList
|
||||
from paddle.nn.initializer import XavierNormal as xavier_uniform_
|
||||
from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
|
||||
# from paddle.nn.initializer import XavierNormal as xavier_uniform_
|
||||
from paddle.nn import Dropout, Linear, LayerNorm
|
||||
import numpy as np
|
||||
from ppocr.modeling.heads.multiheadAttention import MultiheadAttention
|
||||
from paddle.nn.initializer import Constant as constant_
|
||||
from ppocr.modeling.backbones.rec_svtrnet import Mlp, zeros_, ones_
|
||||
from paddle.nn.initializer import XavierNormal as xavier_normal_
|
||||
|
||||
zeros_ = constant_(value=0.)
|
||||
ones_ = constant_(value=1.)
|
||||
|
||||
|
||||
class Transformer(nn.Layer):
|
||||
"""A transformer model. User is able to modify the attributes as needed. The architechture
|
||||
|
@ -45,7 +40,6 @@ class Transformer(nn.Layer):
|
|||
dropout: the dropout value (default=0.1).
|
||||
custom_encoder: custom encoder (default=None).
|
||||
custom_decoder: custom decoder (default=None).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -54,45 +48,49 @@ class Transformer(nn.Layer):
|
|||
num_encoder_layers=6,
|
||||
beam_size=0,
|
||||
num_decoder_layers=6,
|
||||
max_len=25,
|
||||
dim_feedforward=1024,
|
||||
attention_dropout_rate=0.0,
|
||||
residual_dropout_rate=0.1,
|
||||
custom_encoder=None,
|
||||
custom_decoder=None,
|
||||
in_channels=0,
|
||||
out_channels=0,
|
||||
scale_embedding=True):
|
||||
super(Transformer, self).__init__()
|
||||
self.out_channels = out_channels + 1
|
||||
self.max_len = max_len
|
||||
self.embedding = Embeddings(
|
||||
d_model=d_model,
|
||||
vocab=self.out_channels,
|
||||
padding_idx=0,
|
||||
scale_embedding=scale_embedding)
|
||||
self.positional_encoding = PositionalEncoding(
|
||||
dropout=residual_dropout_rate,
|
||||
dim=d_model, )
|
||||
if custom_encoder is not None:
|
||||
self.encoder = custom_encoder
|
||||
else:
|
||||
if num_encoder_layers > 0:
|
||||
encoder_layer = TransformerEncoderLayer(
|
||||
d_model, nhead, dim_feedforward, attention_dropout_rate,
|
||||
residual_dropout_rate)
|
||||
self.encoder = TransformerEncoder(encoder_layer,
|
||||
num_encoder_layers)
|
||||
else:
|
||||
self.encoder = None
|
||||
dropout=residual_dropout_rate, dim=d_model)
|
||||
|
||||
if custom_decoder is not None:
|
||||
self.decoder = custom_decoder
|
||||
if num_encoder_layers > 0:
|
||||
self.encoder = nn.LayerList([
|
||||
TransformerBlock(
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward,
|
||||
attention_dropout_rate,
|
||||
residual_dropout_rate,
|
||||
with_self_attn=True,
|
||||
with_cross_attn=False) for i in range(num_encoder_layers)
|
||||
])
|
||||
else:
|
||||
decoder_layer = TransformerDecoderLayer(
|
||||
d_model, nhead, dim_feedforward, attention_dropout_rate,
|
||||
residual_dropout_rate)
|
||||
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers)
|
||||
self.encoder = None
|
||||
|
||||
self.decoder = nn.LayerList([
|
||||
TransformerBlock(
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward,
|
||||
attention_dropout_rate,
|
||||
residual_dropout_rate,
|
||||
with_self_attn=True,
|
||||
with_cross_attn=True) for i in range(num_decoder_layers)
|
||||
])
|
||||
|
||||
self._reset_parameters()
|
||||
self.beam_size = beam_size
|
||||
self.d_model = d_model
|
||||
self.nhead = nhead
|
||||
|
@ -105,7 +103,7 @@ class Transformer(nn.Layer):
|
|||
|
||||
def _init_weights(self, m):
|
||||
|
||||
if isinstance(m, nn.Conv2D):
|
||||
if isinstance(m, nn.Linear):
|
||||
xavier_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
zeros_(m.bias)
|
||||
|
@ -113,24 +111,20 @@ class Transformer(nn.Layer):
|
|||
def forward_train(self, src, tgt):
|
||||
tgt = tgt[:, :-1]
|
||||
|
||||
tgt_key_padding_mask = self.generate_padding_mask(tgt)
|
||||
tgt = self.embedding(tgt).transpose([1, 0, 2])
|
||||
tgt = self.embedding(tgt)
|
||||
tgt = self.positional_encoding(tgt)
|
||||
tgt_mask = self.generate_square_subsequent_mask(tgt.shape[0])
|
||||
tgt_mask = self.generate_square_subsequent_mask(tgt.shape[1])
|
||||
|
||||
if self.encoder is not None:
|
||||
src = self.positional_encoding(src.transpose([1, 0, 2]))
|
||||
memory = self.encoder(src)
|
||||
src = self.positional_encoding(src)
|
||||
for encoder_layer in self.encoder:
|
||||
src = encoder_layer(src)
|
||||
memory = src # B N C
|
||||
else:
|
||||
memory = src.squeeze(2).transpose([2, 0, 1])
|
||||
output = self.decoder(
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask=tgt_mask,
|
||||
memory_mask=None,
|
||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||
memory_key_padding_mask=None)
|
||||
output = output.transpose([1, 0, 2])
|
||||
memory = src # B N C
|
||||
for decoder_layer in self.decoder:
|
||||
tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
|
||||
output = tgt
|
||||
logit = self.tgt_word_prj(output)
|
||||
return logit
|
||||
|
||||
|
@ -140,8 +134,8 @@ class Transformer(nn.Layer):
|
|||
src: the sequence to the encoder (required).
|
||||
tgt: the sequence to the decoder (required).
|
||||
Shape:
|
||||
- src: :math:`(S, N, E)`.
|
||||
- tgt: :math:`(T, N, E)`.
|
||||
- src: :math:`(B, sN, C)`.
|
||||
- tgt: :math:`(B, tN, C)`.
|
||||
Examples:
|
||||
>>> output = transformer_model(src, tgt)
|
||||
"""
|
||||
|
@ -157,36 +151,35 @@ class Transformer(nn.Layer):
|
|||
return self.forward_test(src)
|
||||
|
||||
def forward_test(self, src):
|
||||
|
||||
bs = paddle.shape(src)[0]
|
||||
if self.encoder is not None:
|
||||
src = self.positional_encoding(paddle.transpose(src, [1, 0, 2]))
|
||||
memory = self.encoder(src)
|
||||
src = self.positional_encoding(src)
|
||||
for encoder_layer in self.encoder:
|
||||
src = encoder_layer(src)
|
||||
memory = src # B N C
|
||||
else:
|
||||
memory = paddle.transpose(paddle.squeeze(src, 2), [2, 0, 1])
|
||||
memory = src
|
||||
dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64)
|
||||
dec_prob = paddle.full((bs, 1), 1., dtype=paddle.float32)
|
||||
for len_dec_seq in range(1, 25):
|
||||
dec_seq_embed = paddle.transpose(self.embedding(dec_seq), [1, 0, 2])
|
||||
for len_dec_seq in range(1, self.max_len):
|
||||
dec_seq_embed = self.embedding(dec_seq)
|
||||
dec_seq_embed = self.positional_encoding(dec_seq_embed)
|
||||
tgt_mask = self.generate_square_subsequent_mask(
|
||||
paddle.shape(dec_seq_embed)[0])
|
||||
output = self.decoder(
|
||||
dec_seq_embed,
|
||||
memory,
|
||||
tgt_mask=tgt_mask,
|
||||
memory_mask=None,
|
||||
tgt_key_padding_mask=None,
|
||||
memory_key_padding_mask=None)
|
||||
dec_output = paddle.transpose(output, [1, 0, 2])
|
||||
paddle.shape(dec_seq_embed)[1])
|
||||
tgt = dec_seq_embed
|
||||
for decoder_layer in self.decoder:
|
||||
tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
|
||||
dec_output = tgt
|
||||
dec_output = dec_output[:, -1, :]
|
||||
word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1)
|
||||
preds_idx = paddle.argmax(word_prob, axis=1)
|
||||
word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=-1)
|
||||
preds_idx = paddle.argmax(word_prob, axis=-1)
|
||||
if paddle.equal_all(
|
||||
preds_idx,
|
||||
paddle.full(
|
||||
paddle.shape(preds_idx), 3, dtype='int64')):
|
||||
break
|
||||
preds_prob = paddle.max(word_prob, axis=1)
|
||||
preds_prob = paddle.max(word_prob, axis=-1)
|
||||
dec_seq = paddle.concat(
|
||||
[dec_seq, paddle.reshape(preds_idx, [-1, 1])], axis=1)
|
||||
dec_prob = paddle.concat(
|
||||
|
@ -194,10 +187,10 @@ class Transformer(nn.Layer):
|
|||
return [dec_seq, dec_prob]
|
||||
|
||||
def forward_beam(self, images):
|
||||
''' Translation work in one batch '''
|
||||
""" Translation work in one batch """
|
||||
|
||||
def get_inst_idx_to_tensor_position_map(inst_idx_list):
|
||||
''' Indicate the position of an instance in a tensor. '''
|
||||
""" Indicate the position of an instance in a tensor. """
|
||||
return {
|
||||
inst_idx: tensor_position
|
||||
for tensor_position, inst_idx in enumerate(inst_idx_list)
|
||||
|
@ -205,7 +198,7 @@ class Transformer(nn.Layer):
|
|||
|
||||
def collect_active_part(beamed_tensor, curr_active_inst_idx,
|
||||
n_prev_active_inst, n_bm):
|
||||
''' Collect tensor parts associated to active instances. '''
|
||||
""" Collect tensor parts associated to active instances. """
|
||||
|
||||
beamed_tensor_shape = paddle.shape(beamed_tensor)
|
||||
n_curr_active_inst = len(curr_active_inst_idx)
|
||||
|
@ -237,9 +230,8 @@ class Transformer(nn.Layer):
|
|||
return active_src_enc, active_inst_idx_to_position_map
|
||||
|
||||
def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output,
|
||||
inst_idx_to_position_map, n_bm,
|
||||
memory_key_padding_mask):
|
||||
''' Decode and update beam status, and then return active beam idx '''
|
||||
inst_idx_to_position_map, n_bm):
|
||||
""" Decode and update beam status, and then return active beam idx """
|
||||
|
||||
def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
|
||||
dec_partial_seq = [
|
||||
|
@ -249,19 +241,15 @@ class Transformer(nn.Layer):
|
|||
dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq])
|
||||
return dec_partial_seq
|
||||
|
||||
def predict_word(dec_seq, enc_output, n_active_inst, n_bm,
|
||||
memory_key_padding_mask):
|
||||
dec_seq = paddle.transpose(self.embedding(dec_seq), [1, 0, 2])
|
||||
def predict_word(dec_seq, enc_output, n_active_inst, n_bm):
|
||||
dec_seq = self.embedding(dec_seq)
|
||||
dec_seq = self.positional_encoding(dec_seq)
|
||||
tgt_mask = self.generate_square_subsequent_mask(
|
||||
paddle.shape(dec_seq)[0])
|
||||
dec_output = self.decoder(
|
||||
dec_seq,
|
||||
enc_output,
|
||||
tgt_mask=tgt_mask,
|
||||
tgt_key_padding_mask=None,
|
||||
memory_key_padding_mask=memory_key_padding_mask, )
|
||||
dec_output = paddle.transpose(dec_output, [1, 0, 2])
|
||||
paddle.shape(dec_seq)[1])
|
||||
tgt = dec_seq
|
||||
for decoder_layer in self.decoder:
|
||||
tgt = decoder_layer(tgt, enc_output, self_mask=tgt_mask)
|
||||
dec_output = tgt
|
||||
dec_output = dec_output[:,
|
||||
-1, :] # Pick the last step: (bh * bm) * d_h
|
||||
word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1)
|
||||
|
@ -281,8 +269,7 @@ class Transformer(nn.Layer):
|
|||
|
||||
n_active_inst = len(inst_idx_to_position_map)
|
||||
dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
|
||||
word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm,
|
||||
None)
|
||||
word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm)
|
||||
# Update the beam with predicted word prob information and collect incomplete instances
|
||||
active_inst_idx_list = collect_active_inst_idx_list(
|
||||
inst_dec_beams, word_prob, inst_idx_to_position_map)
|
||||
|
@ -303,10 +290,10 @@ class Transformer(nn.Layer):
|
|||
with paddle.no_grad():
|
||||
#-- Encode
|
||||
if self.encoder is not None:
|
||||
src = self.positional_encoding(images.transpose([1, 0, 2]))
|
||||
src = self.positional_encoding(images)
|
||||
src_enc = self.encoder(src)
|
||||
else:
|
||||
src_enc = images.squeeze(2).transpose([0, 2, 1])
|
||||
src_enc = images
|
||||
|
||||
n_bm = self.beam_size
|
||||
src_shape = paddle.shape(src_enc)
|
||||
|
@ -317,11 +304,11 @@ class Transformer(nn.Layer):
|
|||
inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
|
||||
active_inst_idx_list)
|
||||
# Decode
|
||||
for len_dec_seq in range(1, 25):
|
||||
for len_dec_seq in range(1, self.max_len):
|
||||
src_enc_copy = src_enc.clone()
|
||||
active_inst_idx_list = beam_decode_step(
|
||||
inst_dec_beams, len_dec_seq, src_enc_copy,
|
||||
inst_idx_to_position_map, n_bm, None)
|
||||
inst_idx_to_position_map, n_bm)
|
||||
if not active_inst_idx_list:
|
||||
break # all instances have finished their path to <EOS>
|
||||
src_enc, inst_idx_to_position_map = collate_active_info(
|
||||
|
@ -354,261 +341,124 @@ class Transformer(nn.Layer):
|
|||
shape=[sz, sz], dtype='float32', fill_value='-inf'),
|
||||
diagonal=1)
|
||||
mask = mask + mask_inf
|
||||
return mask
|
||||
|
||||
def generate_padding_mask(self, x):
|
||||
padding_mask = paddle.equal(x, paddle.to_tensor(0, dtype=x.dtype))
|
||||
return padding_mask
|
||||
|
||||
def _reset_parameters(self):
|
||||
"""Initiate parameters in the transformer model."""
|
||||
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
xavier_uniform_(p)
|
||||
return mask.unsqueeze([0, 1])
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Layer):
|
||||
"""TransformerEncoder is a stack of N encoder layers
|
||||
Args:
|
||||
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
|
||||
num_layers: the number of sub-encoder-layers in the encoder (required).
|
||||
norm: the layer normalization component (optional).
|
||||
"""
|
||||
class MultiheadAttention(nn.Layer):
|
||||
"""Allows the model to jointly attend to information
|
||||
from different representation subspaces.
|
||||
See reference: Attention Is All You Need
|
||||
|
||||
def __init__(self, encoder_layer, num_layers):
|
||||
super(TransformerEncoder, self).__init__()
|
||||
self.layers = _get_clones(encoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
|
||||
def forward(self, src):
|
||||
"""Pass the input through the endocder layers in turn.
|
||||
Args:
|
||||
src: the sequnce to the encoder (required).
|
||||
mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
"""
|
||||
output = src
|
||||
|
||||
for i in range(self.num_layers):
|
||||
output = self.layers[i](output,
|
||||
src_mask=None,
|
||||
src_key_padding_mask=None)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Layer):
|
||||
"""TransformerDecoder is a stack of N decoder layers
|
||||
.. math::
|
||||
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
||||
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
|
||||
|
||||
Args:
|
||||
decoder_layer: an instance of the TransformerDecoderLayer() class (required).
|
||||
num_layers: the number of sub-decoder-layers in the decoder (required).
|
||||
norm: the layer normalization component (optional).
|
||||
embed_dim: total dimension of the model
|
||||
num_heads: parallel attention layers, or heads
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, decoder_layer, num_layers):
|
||||
super(TransformerDecoder, self).__init__()
|
||||
self.layers = _get_clones(decoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
def __init__(self, embed_dim, num_heads, dropout=0., self_attn=False):
|
||||
super(MultiheadAttention, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
# self.dropout = dropout
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.self_attn = self_attn
|
||||
if self_attn:
|
||||
self.qkv = nn.Linear(embed_dim, embed_dim * 3)
|
||||
else:
|
||||
self.q = nn.Linear(embed_dim, embed_dim)
|
||||
self.kv = nn.Linear(embed_dim, embed_dim * 2)
|
||||
self.attn_drop = nn.Dropout(dropout)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim)
|
||||
|
||||
def forward(self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask=None,
|
||||
memory_mask=None,
|
||||
tgt_key_padding_mask=None,
|
||||
memory_key_padding_mask=None):
|
||||
"""Pass the inputs (and mask) through the decoder layer in turn.
|
||||
def forward(self, query, key=None, attn_mask=None):
|
||||
|
||||
Args:
|
||||
tgt: the sequence to the decoder (required).
|
||||
memory: the sequnce from the last layer of the encoder (required).
|
||||
tgt_mask: the mask for the tgt sequence (optional).
|
||||
memory_mask: the mask for the memory sequence (optional).
|
||||
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
|
||||
memory_key_padding_mask: the mask for the memory keys per batch (optional).
|
||||
"""
|
||||
output = tgt
|
||||
for i in range(self.num_layers):
|
||||
output = self.layers[i](
|
||||
output,
|
||||
memory,
|
||||
tgt_mask=tgt_mask,
|
||||
memory_mask=memory_mask,
|
||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||
memory_key_padding_mask=memory_key_padding_mask)
|
||||
qN = query.shape[1]
|
||||
|
||||
return output
|
||||
if self.self_attn:
|
||||
qkv = self.qkv(query).reshape(
|
||||
(0, qN, 3, self.num_heads, self.head_dim)).transpose(
|
||||
(2, 0, 3, 1, 4))
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
else:
|
||||
kN = key.shape[1]
|
||||
q = self.q(query).reshape(
|
||||
[0, qN, self.num_heads, self.head_dim]).transpose([0, 2, 1, 3])
|
||||
kv = self.kv(key).reshape(
|
||||
(0, kN, 2, self.num_heads, self.head_dim)).transpose(
|
||||
(2, 0, 3, 1, 4))
|
||||
k, v = kv[0], kv[1]
|
||||
|
||||
attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale
|
||||
|
||||
if attn_mask is not None:
|
||||
attn += attn_mask
|
||||
|
||||
attn = F.softmax(attn, axis=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape(
|
||||
(0, qN, self.embed_dim))
|
||||
x = self.out_proj(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Layer):
|
||||
"""TransformerEncoderLayer is made up of self-attn and feedforward network.
|
||||
This standard encoder layer is based on the paper "Attention Is All You Need".
|
||||
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
|
||||
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
|
||||
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
|
||||
in a different way during application.
|
||||
|
||||
Args:
|
||||
d_model: the number of expected features in the input (required).
|
||||
nhead: the number of heads in the multiheadattention models (required).
|
||||
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
||||
dropout: the dropout value (default=0.1).
|
||||
|
||||
"""
|
||||
|
||||
class TransformerBlock(nn.Layer):
|
||||
def __init__(self,
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward=2048,
|
||||
attention_dropout_rate=0.0,
|
||||
residual_dropout_rate=0.1):
|
||||
super(TransformerEncoderLayer, self).__init__()
|
||||
self.self_attn = MultiheadAttention(
|
||||
d_model, nhead, dropout=attention_dropout_rate)
|
||||
residual_dropout_rate=0.1,
|
||||
with_self_attn=True,
|
||||
with_cross_attn=False,
|
||||
epsilon=1e-5):
|
||||
super(TransformerBlock, self).__init__()
|
||||
self.with_self_attn = with_self_attn
|
||||
if with_self_attn:
|
||||
self.self_attn = MultiheadAttention(
|
||||
d_model,
|
||||
nhead,
|
||||
dropout=attention_dropout_rate,
|
||||
self_attn=with_self_attn)
|
||||
self.norm1 = LayerNorm(d_model, epsilon=epsilon)
|
||||
self.dropout1 = Dropout(residual_dropout_rate)
|
||||
self.with_cross_attn = with_cross_attn
|
||||
if with_cross_attn:
|
||||
self.cross_attn = MultiheadAttention( #for self_attn of encoder or cross_attn of decoder
|
||||
d_model,
|
||||
nhead,
|
||||
dropout=attention_dropout_rate)
|
||||
self.norm2 = LayerNorm(d_model, epsilon=epsilon)
|
||||
self.dropout2 = Dropout(residual_dropout_rate)
|
||||
|
||||
self.conv1 = Conv2D(
|
||||
in_channels=d_model,
|
||||
out_channels=dim_feedforward,
|
||||
kernel_size=(1, 1))
|
||||
self.conv2 = Conv2D(
|
||||
in_channels=dim_feedforward,
|
||||
out_channels=d_model,
|
||||
kernel_size=(1, 1))
|
||||
self.mlp = Mlp(in_features=d_model,
|
||||
hidden_features=dim_feedforward,
|
||||
act_layer=nn.ReLU,
|
||||
drop=residual_dropout_rate)
|
||||
|
||||
self.norm1 = LayerNorm(d_model)
|
||||
self.norm2 = LayerNorm(d_model)
|
||||
self.dropout1 = Dropout(residual_dropout_rate)
|
||||
self.dropout2 = Dropout(residual_dropout_rate)
|
||||
self.norm3 = LayerNorm(d_model, epsilon=epsilon)
|
||||
|
||||
def forward(self, src, src_mask=None, src_key_padding_mask=None):
|
||||
"""Pass the input through the endocder layer.
|
||||
Args:
|
||||
src: the sequnce to the encoder layer (required).
|
||||
src_mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
"""
|
||||
src2 = self.self_attn(
|
||||
src,
|
||||
src,
|
||||
src,
|
||||
attn_mask=src_mask,
|
||||
key_padding_mask=src_key_padding_mask)
|
||||
src = src + self.dropout1(src2)
|
||||
src = self.norm1(src)
|
||||
|
||||
src = paddle.transpose(src, [1, 2, 0])
|
||||
src = paddle.unsqueeze(src, 2)
|
||||
src2 = self.conv2(F.relu(self.conv1(src)))
|
||||
src2 = paddle.squeeze(src2, 2)
|
||||
src2 = paddle.transpose(src2, [2, 0, 1])
|
||||
src = paddle.squeeze(src, 2)
|
||||
src = paddle.transpose(src, [2, 0, 1])
|
||||
|
||||
src = src + self.dropout2(src2)
|
||||
src = self.norm2(src)
|
||||
return src
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Layer):
|
||||
"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
|
||||
This standard decoder layer is based on the paper "Attention Is All You Need".
|
||||
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
|
||||
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
|
||||
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
|
||||
in a different way during application.
|
||||
|
||||
Args:
|
||||
d_model: the number of expected features in the input (required).
|
||||
nhead: the number of heads in the multiheadattention models (required).
|
||||
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
||||
dropout: the dropout value (default=0.1).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward=2048,
|
||||
attention_dropout_rate=0.0,
|
||||
residual_dropout_rate=0.1):
|
||||
super(TransformerDecoderLayer, self).__init__()
|
||||
self.self_attn = MultiheadAttention(
|
||||
d_model, nhead, dropout=attention_dropout_rate)
|
||||
self.multihead_attn = MultiheadAttention(
|
||||
d_model, nhead, dropout=attention_dropout_rate)
|
||||
|
||||
self.conv1 = Conv2D(
|
||||
in_channels=d_model,
|
||||
out_channels=dim_feedforward,
|
||||
kernel_size=(1, 1))
|
||||
self.conv2 = Conv2D(
|
||||
in_channels=dim_feedforward,
|
||||
out_channels=d_model,
|
||||
kernel_size=(1, 1))
|
||||
|
||||
self.norm1 = LayerNorm(d_model)
|
||||
self.norm2 = LayerNorm(d_model)
|
||||
self.norm3 = LayerNorm(d_model)
|
||||
self.dropout1 = Dropout(residual_dropout_rate)
|
||||
self.dropout2 = Dropout(residual_dropout_rate)
|
||||
self.dropout3 = Dropout(residual_dropout_rate)
|
||||
|
||||
def forward(self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask=None,
|
||||
memory_mask=None,
|
||||
tgt_key_padding_mask=None,
|
||||
memory_key_padding_mask=None):
|
||||
"""Pass the inputs (and mask) through the decoder layer.
|
||||
def forward(self, tgt, memory=None, self_mask=None, cross_mask=None):
|
||||
if self.with_self_attn:
|
||||
tgt1 = self.self_attn(tgt, attn_mask=self_mask)
|
||||
tgt = self.norm1(tgt + self.dropout1(tgt1))
|
||||
|
||||
Args:
|
||||
tgt: the sequence to the decoder layer (required).
|
||||
memory: the sequnce from the last layer of the encoder (required).
|
||||
tgt_mask: the mask for the tgt sequence (optional).
|
||||
memory_mask: the mask for the memory sequence (optional).
|
||||
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
|
||||
memory_key_padding_mask: the mask for the memory keys per batch (optional).
|
||||
|
||||
"""
|
||||
tgt2 = self.self_attn(
|
||||
tgt,
|
||||
tgt,
|
||||
tgt,
|
||||
attn_mask=tgt_mask,
|
||||
key_padding_mask=tgt_key_padding_mask)
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt = self.norm1(tgt)
|
||||
tgt2 = self.multihead_attn(
|
||||
tgt,
|
||||
memory,
|
||||
memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask)
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt = self.norm2(tgt)
|
||||
|
||||
# default
|
||||
tgt = paddle.transpose(tgt, [1, 2, 0])
|
||||
tgt = paddle.unsqueeze(tgt, 2)
|
||||
tgt2 = self.conv2(F.relu(self.conv1(tgt)))
|
||||
tgt2 = paddle.squeeze(tgt2, 2)
|
||||
tgt2 = paddle.transpose(tgt2, [2, 0, 1])
|
||||
tgt = paddle.squeeze(tgt, 2)
|
||||
tgt = paddle.transpose(tgt, [2, 0, 1])
|
||||
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
tgt = self.norm3(tgt)
|
||||
if self.with_cross_attn:
|
||||
tgt2 = self.cross_attn(tgt, key=memory, attn_mask=cross_mask)
|
||||
tgt = self.norm2(tgt + self.dropout2(tgt2))
|
||||
tgt = self.norm3(tgt + self.dropout3(self.mlp(tgt)))
|
||||
return tgt
|
||||
|
||||
|
||||
def _get_clones(module, N):
|
||||
return LayerList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Layer):
|
||||
"""Inject some information about the relative or absolute position of the tokens
|
||||
in the sequence. The positional encodings have the same dimension as
|
||||
|
@ -651,8 +501,9 @@ class PositionalEncoding(nn.Layer):
|
|||
Examples:
|
||||
>>> output = pos_encoder(x)
|
||||
"""
|
||||
x = x.transpose([1, 0, 2])
|
||||
x = x + self.pe[:paddle.shape(x)[0], :]
|
||||
return self.dropout(x)
|
||||
return self.dropout(x).transpose([1, 0, 2])
|
||||
|
||||
|
||||
class PositionalEncoding_2d(nn.Layer):
|
||||
|
@ -725,7 +576,7 @@ class PositionalEncoding_2d(nn.Layer):
|
|||
|
||||
|
||||
class Embeddings(nn.Layer):
|
||||
def __init__(self, d_model, vocab, padding_idx, scale_embedding):
|
||||
def __init__(self, d_model, vocab, padding_idx=None, scale_embedding=True):
|
||||
super(Embeddings, self).__init__()
|
||||
self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)
|
||||
w0 = np.random.normal(0.0, d_model**-0.5,
|
||||
|
@ -742,7 +593,7 @@ class Embeddings(nn.Layer):
|
|||
|
||||
|
||||
class Beam():
|
||||
''' Beam search '''
|
||||
""" Beam search """
|
||||
|
||||
def __init__(self, size, device=False):
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
|
|||
from .fce_postprocess import FCEPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
|
||||
DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
|
||||
SEEDLabelDecode, PRENLabelDecode
|
||||
SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
from .pg_postprocess import PGPostProcess
|
||||
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
|
||||
|
@ -42,7 +42,7 @@ def build_post_process(config, global_config=None):
|
|||
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
|
||||
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
|
||||
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
|
||||
'DistillationSARLabelDecode'
|
||||
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode'
|
||||
]
|
||||
|
||||
if config['name'] == 'PSEPostProcess':
|
||||
|
|
|
@ -140,70 +140,6 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
|
|||
return output
|
||||
|
||||
|
||||
class NRTRLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
|
||||
super(NRTRLabelDecode, self).__init__(character_dict_path,
|
||||
use_space_char)
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
|
||||
if len(preds) == 2:
|
||||
preds_id = preds[0]
|
||||
preds_prob = preds[1]
|
||||
if isinstance(preds_id, paddle.Tensor):
|
||||
preds_id = preds_id.numpy()
|
||||
if isinstance(preds_prob, paddle.Tensor):
|
||||
preds_prob = preds_prob.numpy()
|
||||
if preds_id[0][0] == 2:
|
||||
preds_idx = preds_id[:, 1:]
|
||||
preds_prob = preds_prob[:, 1:]
|
||||
else:
|
||||
preds_idx = preds_id
|
||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label[:, 1:])
|
||||
else:
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
preds_prob = preds.max(axis=2)
|
||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label[:, 1:])
|
||||
return text, label
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
|
||||
return dict_character
|
||||
|
||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||
""" convert text-index into text-label. """
|
||||
result_list = []
|
||||
batch_size = len(text_index)
|
||||
for batch_idx in range(batch_size):
|
||||
char_list = []
|
||||
conf_list = []
|
||||
for idx in range(len(text_index[batch_idx])):
|
||||
if text_index[batch_idx][idx] == 3: # end
|
||||
break
|
||||
try:
|
||||
char_list.append(self.character[int(text_index[batch_idx][
|
||||
idx])])
|
||||
except:
|
||||
continue
|
||||
if text_prob is not None:
|
||||
conf_list.append(text_prob[batch_idx][idx])
|
||||
else:
|
||||
conf_list.append(1)
|
||||
text = ''.join(char_list)
|
||||
result_list.append((text.lower(), np.mean(conf_list).tolist()))
|
||||
return result_list
|
||||
|
||||
|
||||
class AttnLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
|
@ -752,3 +688,122 @@ class PRENLabelDecode(BaseRecLabelDecode):
|
|||
return text
|
||||
label = self.decode(label)
|
||||
return text, label
|
||||
|
||||
|
||||
class NRTRLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
|
||||
super(NRTRLabelDecode, self).__init__(character_dict_path,
|
||||
use_space_char)
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
|
||||
if len(preds) == 2:
|
||||
preds_id = preds[0]
|
||||
preds_prob = preds[1]
|
||||
if isinstance(preds_id, paddle.Tensor):
|
||||
preds_id = preds_id.numpy()
|
||||
if isinstance(preds_prob, paddle.Tensor):
|
||||
preds_prob = preds_prob.numpy()
|
||||
if preds_id[0][0] == 2:
|
||||
preds_idx = preds_id[:, 1:]
|
||||
preds_prob = preds_prob[:, 1:]
|
||||
else:
|
||||
preds_idx = preds_id
|
||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label[:, 1:])
|
||||
else:
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
preds_prob = preds.max(axis=2)
|
||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label[:, 1:])
|
||||
return text, label
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
|
||||
return dict_character
|
||||
|
||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||
""" convert text-index into text-label. """
|
||||
result_list = []
|
||||
batch_size = len(text_index)
|
||||
for batch_idx in range(batch_size):
|
||||
char_list = []
|
||||
conf_list = []
|
||||
for idx in range(len(text_index[batch_idx])):
|
||||
try:
|
||||
char_idx = self.character[int(text_index[batch_idx][idx])]
|
||||
except:
|
||||
continue
|
||||
if char_idx == '</s>': # end
|
||||
break
|
||||
char_list.append(char_idx)
|
||||
if text_prob is not None:
|
||||
conf_list.append(text_prob[batch_idx][idx])
|
||||
else:
|
||||
conf_list.append(1)
|
||||
text = ''.join(char_list)
|
||||
result_list.append((text.lower(), np.mean(conf_list).tolist()))
|
||||
return result_list
|
||||
|
||||
|
||||
class ViTSTRLabelDecode(NRTRLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, character_dict_path=None, use_space_char=False,
|
||||
**kwargs):
|
||||
super(ViTSTRLabelDecode, self).__init__(character_dict_path,
|
||||
use_space_char)
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds[:, 1:].numpy()
|
||||
else:
|
||||
preds = preds[:, 1:]
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
preds_prob = preds.max(axis=2)
|
||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label[:, 1:])
|
||||
return text, label
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
dict_character = ['<s>', '</s>'] + dict_character
|
||||
return dict_character
|
||||
|
||||
|
||||
class ABINetLabelDecode(NRTRLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, character_dict_path=None, use_space_char=False,
|
||||
**kwargs):
|
||||
super(ABINetLabelDecode, self).__init__(character_dict_path,
|
||||
use_space_char)
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
if isinstance(preds, dict):
|
||||
preds = preds['align'][-1].numpy()
|
||||
elif isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
else:
|
||||
preds = preds
|
||||
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
preds_prob = preds.max(axis=2)
|
||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label)
|
||||
return text, label
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
dict_character = ['</s>'] + dict_character
|
||||
return dict_character
|
||||
|
|
|
@ -49,7 +49,7 @@ Architecture:
|
|||
|
||||
|
||||
Loss:
|
||||
name: NRTRLoss
|
||||
name: CELoss
|
||||
smoothing: True
|
||||
|
||||
PostProcess:
|
||||
|
@ -69,7 +69,7 @@ Train:
|
|||
img_mode: BGR
|
||||
channel_first: False
|
||||
- NRTRLabelEncode: # Class handling label
|
||||
- NRTRRecResizeImg:
|
||||
- GrayRecResizeImg:
|
||||
image_shape: [100, 32]
|
||||
resize_type: PIL # PIL or OpenCV
|
||||
- KeepKeys:
|
||||
|
@ -90,7 +90,7 @@ Eval:
|
|||
img_mode: BGR
|
||||
channel_first: False
|
||||
- NRTRLabelEncode: # Class handling label
|
||||
- NRTRRecResizeImg:
|
||||
- GrayRecResizeImg:
|
||||
image_shape: [100, 32]
|
||||
resize_type: PIL # PIL or OpenCV
|
||||
- KeepKeys:
|
||||
|
@ -99,5 +99,5 @@ Eval:
|
|||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 256
|
||||
num_workers: 1
|
||||
num_workers: 4
|
||||
use_shared_memory: False
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
Global:
|
||||
use_gpu: True
|
||||
epoch_num: 10
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec/r45_abinet/
|
||||
save_epoch_step: 1
|
||||
# evaluation is run every 2000 iterations
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words_en/word_10.png
|
||||
# for data or label process
|
||||
character_dict_path:
|
||||
character_type: en
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_abinet.txt
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.99
|
||||
clip_norm: 20.0
|
||||
lr:
|
||||
name: Piecewise
|
||||
decay_epochs: [6]
|
||||
values: [0.0001, 0.00001]
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0.
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: ABINet
|
||||
in_channels: 3
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet45
|
||||
|
||||
Head:
|
||||
name: ABINetHead
|
||||
use_lang: True
|
||||
iter_size: 3
|
||||
|
||||
|
||||
Loss:
|
||||
name: CELoss
|
||||
ignore_index: &ignore_index 100 # Must be greater than the number of character classes
|
||||
|
||||
PostProcess:
|
||||
name: ABINetLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/ic15_data/
|
||||
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
channel_first: False
|
||||
- ABINetRecAug:
|
||||
- ABINetLabelEncode: # Class handling label
|
||||
ignore_index: *ignore_index
|
||||
- ABINetRecResizeImg:
|
||||
image_shape: [3, 32, 128]
|
||||
padding: False
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 96
|
||||
drop_last: True
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/ic15_data
|
||||
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
channel_first: False
|
||||
- ABINetLabelEncode: # Class handling label
|
||||
ignore_index: *ignore_index
|
||||
- ABINetRecResizeImg:
|
||||
image_shape: [3, 32, 128]
|
||||
padding: False
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 256
|
||||
num_workers: 4
|
||||
use_shared_memory: False
|
|
@ -0,0 +1,53 @@
|
|||
===========================train_params===========================
|
||||
model_name:rec_abinet
|
||||
python:python3.7
|
||||
gpu_list:0|0,1
|
||||
Global.use_gpu:True|True
|
||||
Global.auto_cast:null
|
||||
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
|
||||
Global.save_model_dir:./output/
|
||||
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
|
||||
Global.pretrained_model:null
|
||||
train_model_name:latest
|
||||
train_infer_img_dir:./inference/rec_inference
|
||||
null:null
|
||||
##
|
||||
trainer:norm_train
|
||||
norm_train:tools/train.py -c test_tipc/configs/rec_r45_abinet/rec_r45_abinet.yml -o
|
||||
pact_train:null
|
||||
fpgm_train:null
|
||||
distill_train:null
|
||||
null:null
|
||||
null:null
|
||||
##
|
||||
===========================eval_params===========================
|
||||
eval:tools/eval.py -c test_tipc/configs/rec_r45_abinet/rec_r45_abinet.yml -o
|
||||
null:null
|
||||
##
|
||||
===========================infer_params===========================
|
||||
Global.save_inference_dir:./output/
|
||||
Global.checkpoints:
|
||||
norm_export:tools/export_model.py -c test_tipc/configs/rec_r45_abinet/rec_r45_abinet.yml -o
|
||||
quant_export:null
|
||||
fpgm_export:null
|
||||
distill_export:null
|
||||
export1:null
|
||||
export2:null
|
||||
##
|
||||
train_model:./inference/rec_r45_abinet_train/best_accuracy
|
||||
infer_export:tools/export_model.py -c test_tipc/configs/rec_r45_abinet/rec_r45_abinet.yml -o
|
||||
infer_quant:False
|
||||
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,128" --rec_algorithm="ABINet"
|
||||
--use_gpu:True|False
|
||||
--enable_mkldnn:False
|
||||
--cpu_threads:6
|
||||
--rec_batch_num:1|6
|
||||
--use_tensorrt:False
|
||||
--precision:fp32
|
||||
--rec_model_dir:
|
||||
--image_dir:./inference/rec_inference
|
||||
--save_log_path:./test/output/
|
||||
--benchmark:True
|
||||
null:null
|
||||
===========================infer_benchmark_params==========================
|
||||
random_infer_input:[{float32,[3,32,128]}]
|
|
@ -0,0 +1,117 @@
|
|||
Global:
|
||||
use_gpu: True
|
||||
epoch_num: 20
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec/svtr/
|
||||
save_epoch_step: 1
|
||||
# evaluation is run every 2000 iterations after the 0th iteration
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words_en/word_10.png
|
||||
# for data or label process
|
||||
character_dict_path:
|
||||
character_type: en
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_svtr_tiny.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: AdamW
|
||||
beta1: 0.9
|
||||
beta2: 0.99
|
||||
epsilon: 8.e-8
|
||||
weight_decay: 0.05
|
||||
no_weight_decay_name: norm pos_embed
|
||||
one_dim_param_no_weight_decay: true
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.0005
|
||||
warmup_epoch: 2
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: SVTR
|
||||
Transform:
|
||||
name: STN_ON
|
||||
tps_inputsize: [32, 64]
|
||||
tps_outputsize: [32, 100]
|
||||
num_control_points: 20
|
||||
tps_margins: [0.05,0.05]
|
||||
stn_activation: none
|
||||
Backbone:
|
||||
name: SVTRNet
|
||||
img_size: [32, 100]
|
||||
out_char_num: 25
|
||||
out_channels: 192
|
||||
patch_merging: 'Conv'
|
||||
embed_dim: [64, 128, 256]
|
||||
depth: [3, 6, 3]
|
||||
num_heads: [2, 4, 8]
|
||||
mixer: ['Local','Local','Local','Local','Local','Local','Global','Global','Global','Global','Global','Global']
|
||||
local_mixer: [[7, 11], [7, 11], [7, 11]]
|
||||
last_stage: True
|
||||
prenorm: false
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: reshape
|
||||
Head:
|
||||
name: CTCHead
|
||||
|
||||
Loss:
|
||||
name: CTCLoss
|
||||
|
||||
PostProcess:
|
||||
name: CTCLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/ic15_data/
|
||||
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- CTCLabelEncode: # Class handling label
|
||||
- SVTRRecResizeImg:
|
||||
image_shape: [3, 64, 256]
|
||||
padding: False
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 512
|
||||
drop_last: True
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/ic15_data
|
||||
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- CTCLabelEncode: # Class handling label
|
||||
- SVTRRecResizeImg:
|
||||
image_shape: [3, 64, 256]
|
||||
padding: False
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 256
|
||||
num_workers: 2
|
|
@ -0,0 +1,53 @@
|
|||
===========================train_params===========================
|
||||
model_name:rec_svtrnet
|
||||
python:python3.7
|
||||
gpu_list:0|0,1
|
||||
Global.use_gpu:True|True
|
||||
Global.auto_cast:null
|
||||
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
|
||||
Global.save_model_dir:./output/
|
||||
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
|
||||
Global.pretrained_model:null
|
||||
train_model_name:latest
|
||||
train_infer_img_dir:./inference/rec_inference
|
||||
null:null
|
||||
##
|
||||
trainer:norm_train
|
||||
norm_train:tools/train.py -c test_tipc/configs/rec_svtrnet/rec_svtrnet.yml -o
|
||||
pact_train:null
|
||||
fpgm_train:null
|
||||
distill_train:null
|
||||
null:null
|
||||
null:null
|
||||
##
|
||||
===========================eval_params===========================
|
||||
eval:tools/eval.py -c test_tipc/configs/rec_svtrnet/rec_svtrnet.yml -o
|
||||
null:null
|
||||
##
|
||||
===========================infer_params===========================
|
||||
Global.save_inference_dir:./output/
|
||||
Global.checkpoints:
|
||||
norm_export:tools/export_model.py -c test_tipc/configs/rec_svtrnet/rec_svtrnet.yml -o
|
||||
quant_export:null
|
||||
fpgm_export:null
|
||||
distill_export:null
|
||||
export1:null
|
||||
export2:null
|
||||
##
|
||||
train_model:./inference/rec_svtrnet_train/best_accuracy
|
||||
infer_export:tools/export_model.py -c test_tipc/configs/rec_svtrnet/rec_svtrnet.yml -o
|
||||
infer_quant:False
|
||||
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,64,256" --rec_algorithm="SVTR"
|
||||
--use_gpu:True|False
|
||||
--enable_mkldnn:False
|
||||
--cpu_threads:6
|
||||
--rec_batch_num:1|6
|
||||
--use_tensorrt:False
|
||||
--precision:fp32
|
||||
--rec_model_dir:
|
||||
--image_dir:./inference/rec_inference
|
||||
--save_log_path:./test/output/
|
||||
--benchmark:True
|
||||
null:null
|
||||
===========================infer_benchmark_params==========================
|
||||
random_infer_input:[{float32,[3,64,256]}]
|
|
@ -0,0 +1,104 @@
|
|||
Global:
|
||||
use_gpu: True
|
||||
epoch_num: 20
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec/vitstr_none_ce/
|
||||
save_epoch_step: 1
|
||||
# evaluation is run every 2000 iterations after the 0th iteration#
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words_en/word_10.png
|
||||
# for data or label process
|
||||
character_dict_path: ppocr/utils/EN_symbol_dict.txt
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_vitstr.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adadelta
|
||||
epsilon: 1.e-8
|
||||
rho: 0.95
|
||||
clip_norm: 5.0
|
||||
lr:
|
||||
learning_rate: 1.0
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: ViTSTR
|
||||
in_channels: 1
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ViTSTR
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: reshape
|
||||
Head:
|
||||
name: CTCHead
|
||||
|
||||
Loss:
|
||||
name: CELoss
|
||||
smoothing: False
|
||||
with_all: True
|
||||
ignore_index: &ignore_index 0 # Must be zero or greater than the number of character classes
|
||||
|
||||
PostProcess:
|
||||
name: ViTSTRLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/ic15_data/
|
||||
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- ViTSTRLabelEncode: # Class handling label
|
||||
ignore_index: *ignore_index
|
||||
- GrayRecResizeImg:
|
||||
image_shape: [224, 224] # W H
|
||||
resize_type: PIL # PIL or OpenCV
|
||||
inter_type: 'Image.BICUBIC'
|
||||
scale: false
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 48
|
||||
drop_last: True
|
||||
num_workers: 8
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/ic15_data
|
||||
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- ViTSTRLabelEncode: # Class handling label
|
||||
ignore_index: *ignore_index
|
||||
- GrayRecResizeImg:
|
||||
image_shape: [224, 224] # W H
|
||||
resize_type: PIL # PIL or OpenCV
|
||||
inter_type: 'Image.BICUBIC'
|
||||
scale: false
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 256
|
||||
num_workers: 2
|
|
@ -0,0 +1,53 @@
|
|||
===========================train_params===========================
|
||||
model_name:rec_vitstr
|
||||
python:python3.7
|
||||
gpu_list:0|0,1
|
||||
Global.use_gpu:True|True
|
||||
Global.auto_cast:null
|
||||
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
|
||||
Global.save_model_dir:./output/
|
||||
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
|
||||
Global.pretrained_model:null
|
||||
train_model_name:latest
|
||||
train_infer_img_dir:./inference/rec_inference
|
||||
null:null
|
||||
##
|
||||
trainer:norm_train
|
||||
norm_train:tools/train.py -c test_tipc/configs/rec_vitstr_none_ce/rec_vitstr_none_ce.yml -o
|
||||
pact_train:null
|
||||
fpgm_train:null
|
||||
distill_train:null
|
||||
null:null
|
||||
null:null
|
||||
##
|
||||
===========================eval_params===========================
|
||||
eval:tools/eval.py -c test_tipc/configs/rec_vitstr_none_ce/rec_vitstr_none_ce.yml -o
|
||||
null:null
|
||||
##
|
||||
===========================infer_params===========================
|
||||
Global.save_inference_dir:./output/
|
||||
Global.checkpoints:
|
||||
norm_export:tools/export_model.py -c test_tipc/configs/rec_vitstr_none_ce/rec_vitstr_none_ce.yml -o
|
||||
quant_export:null
|
||||
fpgm_export:null
|
||||
distill_export:null
|
||||
export1:null
|
||||
export2:null
|
||||
##
|
||||
train_model:./inference/rec_vitstr_none_ce_train/best_accuracy
|
||||
infer_export:tools/export_model.py -c test_tipc/configs/rec_vitstr_none_ce/rec_vitstr_none_ce.yml -o
|
||||
infer_quant:False
|
||||
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/EN_symbol_dict.txt --rec_image_shape="1,224,224" --rec_algorithm="ViTSTR"
|
||||
--use_gpu:True|False
|
||||
--enable_mkldnn:False
|
||||
--cpu_threads:6
|
||||
--rec_batch_num:1|6
|
||||
--use_tensorrt:False
|
||||
--precision:fp32
|
||||
--rec_model_dir:
|
||||
--image_dir:./inference/rec_inference
|
||||
--save_log_path:./test/output/
|
||||
--benchmark:True
|
||||
null:null
|
||||
===========================infer_benchmark_params==========================
|
||||
random_infer_input:[{float32,[1,224,224]}]
|
|
@ -31,7 +31,12 @@ from ppocr.utils.logging import get_logger
|
|||
from tools.program import load_config, merge_config, ArgsParser
|
||||
|
||||
|
||||
def export_single_model(model, arch_config, save_path, logger, quanter=None):
|
||||
def export_single_model(model,
|
||||
arch_config,
|
||||
save_path,
|
||||
logger,
|
||||
input_shape=None,
|
||||
quanter=None):
|
||||
if arch_config["algorithm"] == "SRN":
|
||||
max_text_length = arch_config["Head"]["max_text_length"]
|
||||
other_shape = [
|
||||
|
@ -64,7 +69,7 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
|
|||
else:
|
||||
other_shape = [
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 3, 64, 256], dtype="float32"),
|
||||
shape=[None] + input_shape, dtype="float32"),
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
elif arch_config["algorithm"] == "PREN":
|
||||
|
@ -73,6 +78,25 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
|
|||
shape=[None, 3, 64, 512], dtype="float32"),
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
elif arch_config["algorithm"] == "ViTSTR":
|
||||
other_shape = [
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 1, 224, 224], dtype="float32"),
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
elif arch_config["algorithm"] == "ABINet":
|
||||
other_shape = [
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 3, 32, 128], dtype="float32"),
|
||||
]
|
||||
# print([None, 3, 32, 128])
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
elif arch_config["algorithm"] == "NRTR":
|
||||
other_shape = [
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 1, 32, 100], dtype="float32"),
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
else:
|
||||
infer_shape = [3, -1, -1]
|
||||
if arch_config["model_type"] == "rec":
|
||||
|
@ -84,8 +108,6 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
|
|||
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
|
||||
)
|
||||
infer_shape[-1] = 100
|
||||
if arch_config["algorithm"] == "NRTR":
|
||||
infer_shape = [1, 32, 100]
|
||||
elif arch_config["model_type"] == "table":
|
||||
infer_shape = [3, 488, 488]
|
||||
model = to_static(
|
||||
|
@ -157,6 +179,13 @@ def main():
|
|||
|
||||
arch_config = config["Architecture"]
|
||||
|
||||
if arch_config["algorithm"] == "SVTR" and arch_config["Head"][
|
||||
"name"] != 'MultiHead':
|
||||
input_shape = config["Eval"]["dataset"]["transforms"][-2][
|
||||
'SVTRRecResizeImg']['image_shape']
|
||||
else:
|
||||
input_shape = None
|
||||
|
||||
if arch_config["algorithm"] in ["Distillation", ]: # distillation model
|
||||
archs = list(arch_config["Models"].values())
|
||||
for idx, name in enumerate(model.model_name_list):
|
||||
|
@ -165,7 +194,8 @@ def main():
|
|||
sub_model_save_path, logger)
|
||||
else:
|
||||
save_path = os.path.join(save_path, "inference")
|
||||
export_single_model(model, arch_config, save_path, logger)
|
||||
export_single_model(
|
||||
model, arch_config, save_path, logger, input_shape=input_shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -69,6 +69,18 @@ class TextRecognizer(object):
|
|||
"character_dict_path": args.rec_char_dict_path,
|
||||
"use_space_char": args.use_space_char
|
||||
}
|
||||
elif self.rec_algorithm == 'ViTSTR':
|
||||
postprocess_params = {
|
||||
'name': 'ViTSTRLabelDecode',
|
||||
"character_dict_path": args.rec_char_dict_path,
|
||||
"use_space_char": args.use_space_char
|
||||
}
|
||||
elif self.rec_algorithm == 'ABINet':
|
||||
postprocess_params = {
|
||||
'name': 'ABINetLabelDecode',
|
||||
"character_dict_path": args.rec_char_dict_path,
|
||||
"use_space_char": args.use_space_char
|
||||
}
|
||||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
||||
utility.create_predictor(args, 'rec', logger)
|
||||
|
@ -96,15 +108,22 @@ class TextRecognizer(object):
|
|||
|
||||
def resize_norm_img(self, img, max_wh_ratio):
|
||||
imgC, imgH, imgW = self.rec_image_shape
|
||||
if self.rec_algorithm == 'NRTR':
|
||||
if self.rec_algorithm == 'NRTR' or self.rec_algorithm == 'ViTSTR':
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
# return padding_im
|
||||
image_pil = Image.fromarray(np.uint8(img))
|
||||
img = image_pil.resize([100, 32], Image.ANTIALIAS)
|
||||
if self.rec_algorithm == 'ViTSTR':
|
||||
img = image_pil.resize([imgW, imgH], Image.BICUBIC)
|
||||
else:
|
||||
img = image_pil.resize([imgW, imgH], Image.ANTIALIAS)
|
||||
img = np.array(img)
|
||||
norm_img = np.expand_dims(img, -1)
|
||||
norm_img = norm_img.transpose((2, 0, 1))
|
||||
return norm_img.astype(np.float32) / 128. - 1.
|
||||
if self.rec_algorithm == 'ViTSTR':
|
||||
norm_img = norm_img.astype(np.float32) / 255.
|
||||
else:
|
||||
norm_img = norm_img.astype(np.float32) / 128. - 1.
|
||||
return norm_img
|
||||
|
||||
assert imgC == img.shape[2]
|
||||
imgW = int((imgH * max_wh_ratio))
|
||||
|
@ -132,17 +151,6 @@ class TextRecognizer(object):
|
|||
padding_im[:, :, 0:resized_w] = resized_image
|
||||
return padding_im
|
||||
|
||||
def resize_norm_img_svtr(self, img, image_shape):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
return resized_image
|
||||
|
||||
def resize_norm_img_srn(self, img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
||||
|
@ -250,6 +258,35 @@ class TextRecognizer(object):
|
|||
|
||||
return padding_im, resize_shape, pad_shape, valid_ratio
|
||||
|
||||
def resize_norm_img_svtr(self, img, image_shape):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
return resized_image
|
||||
|
||||
def resize_norm_img_abinet(self, img, image_shape):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image = resized_image / 255.
|
||||
|
||||
mean = np.array([0.485, 0.456, 0.406])
|
||||
std = np.array([0.229, 0.224, 0.225])
|
||||
resized_image = (
|
||||
resized_image - mean[None, None, ...]) / std[None, None, ...]
|
||||
resized_image = resized_image.transpose((2, 0, 1))
|
||||
resized_image = resized_image.astype('float32')
|
||||
|
||||
return resized_image
|
||||
|
||||
def __call__(self, img_list):
|
||||
img_num = len(img_list)
|
||||
# Calculate the aspect ratio of all text bars
|
||||
|
@ -300,6 +337,11 @@ class TextRecognizer(object):
|
|||
self.rec_image_shape)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
elif self.rec_algorithm == "ABINet":
|
||||
norm_img = self.resize_norm_img_abinet(
|
||||
img_list[indices[ino]], self.rec_image_shape)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
else:
|
||||
norm_img = self.resize_norm_img(img_list[indices[ino]],
|
||||
max_wh_ratio)
|
||||
|
|
|
@ -576,7 +576,8 @@ def preprocess(is_train=False):
|
|||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
|
||||
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR'
|
||||
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR',
|
||||
'ViTSTR', 'ABINet'
|
||||
]
|
||||
|
||||
if use_xpu:
|
||||
|
|
Loading…
Reference in New Issue