add sr model Text Telescope

parent 8babfc86af
commit 0cdfc52507
@@ -0,0 +1,84 @@
Global:
  use_gpu: true
  epoch_num: 100
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/sr/sr_telescope/
  save_epoch_step: 3
  # evaluation is run every 1000 iterations
  eval_batch_step: [0, 1000]
  cal_metric_during_train: False
  pretrained_model:
  checkpoints:
  save_inference_dir: ./output/sr/sr_telescope/infer
  use_visualdl: False
  infer_img: doc/imgs_words_en/word_52.png
  # for data or label process
  character_dict_path:
  max_text_length: 100
  infer_mode: False
  use_space_char: False
  save_res_path: ./output/sr/predicts_telescope.txt

Optimizer:
  name: Adam
  beta1: 0.5
  beta2: 0.999
  clip_norm: 0.25
  lr:
    learning_rate: 0.0001

Architecture:
  model_type: sr
  algorithm: Telescope
  Transform:
    name: TBSRN
    STN: True
    infer_mode: False

Loss:
  name: TelescopeLoss
  confuse_dict_path: ./ppocr/utils/dict/confuse.pkl

PostProcess:
  name: None

Metric:
  name: SRMetric
  main_indicator: all

Train:
  dataset:
    name: LMDBDataSetSR
    data_dir: ./train_data/TextZoom/train
    transforms:
      - SRResize:
          imgH: 32
          imgW: 128
          down_sample_scale: 2
      - KeepKeys:
          keep_keys: ['img_lr', 'img_hr', 'label'] # dataloader will return list in this order
  loader:
    shuffle: False
    batch_size_per_card: 16
    drop_last: True
    num_workers: 4

Eval:
  dataset:
    name: LMDBDataSetSR
    data_dir: ./train_data/TextZoom/test
    transforms:
      - SRResize:
          imgH: 32
          imgW: 128
          down_sample_scale: 2
      - KeepKeys:
          keep_keys: ['img_lr', 'img_hr', 'label'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 16
    num_workers: 0
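For orientation, the `Loss` section above is resolved by name at runtime by `build_loss` in `ppocr/losses/__init__.py` (modified later in this diff). A minimal sketch of the mechanism, paraphrased rather than quoted verbatim, with the remaining YAML keys passed through as constructor arguments:

```python
import copy

def build_loss(config):
    # config is the YAML `Loss` section, e.g.
    # {'name': 'TelescopeLoss', 'confuse_dict_path': './ppocr/utils/dict/confuse.pkl'}
    config = copy.deepcopy(config)
    module_name = config.pop('name')            # selects the loss class by name
    module_class = eval(module_name)(**config)  # remaining keys become kwargs
    return module_class
```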
@@ -0,0 +1,128 @@
# Text Telescope

- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
  - [3.1 Training](#3-1)
  - [3.2 Evaluation](#3-2)
  - [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
  - [4.1 Python Inference](#4-1)
  - [4.2 C++ Inference](#4-2)
  - [4.3 Serving](#4-3)
  - [4.4 More](#4-4)
- [5. FAQ](#5)

<a name="1"></a>
## 1. Introduction

Paper:
> [Scene Text Telescope: Text-Focused Scene Image Super-Resolution](https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Scene_Text_Telescope_Text-Focused_Scene_Image_Super-Resolution_CVPR_2021_paper.pdf)

> Chen, Jingye, Bin Li, and Xiangyang Xue

> CVPR, 2021

Following the [FudanOCR](https://github.com/FudanVI/FudanOCR/tree/main/scene-text-telescope) data download instructions, the super-resolution algorithm achieves the following results on the TextZoom test set:

|Model|Backbone|PSNR_Avg|SSIM_Avg|Config|Download link|
|---|---|---|---|---|---|
|Text Telescope|tbsrn|21.56|0.7411|[configs/sr/sr_telescope.yml](../../configs/sr/sr_telescope.yml)|[trained model](https://paddleocr.bj.bcebos.com/contribution/Telescope_train.tar.gz)|

The [TextZoom dataset](https://paddleocr.bj.bcebos.com/dataset/TextZoom.tar) comes from two super-resolution datasets, RealSR and SR-RAW; both contain LR-HR pairs. TextZoom has 17,367 pairs of training data and 4,373 pairs of test data.

<a name="2"></a>
## 2. Environment

Please refer to ["Environment Preparation"](./environment.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone.md) to clone the project code.

<a name="3"></a>
## 3. Model Training / Evaluation / Prediction

Please refer to the [Text Recognition Tutorial](./recognition.md). PaddleOCR modularizes the code, so training a different recognition model only requires **changing the configuration file**.

- Training

After the data preparation is completed, training can be started. The training command is as follows:

```
# Single-GPU training (long training period, not recommended)
python3 tools/train.py -c configs/sr/sr_telescope.yml

# Multi-GPU training, specify the GPU IDs through the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/sr/sr_telescope.yml
```

- Evaluation

```
# GPU evaluation; Global.pretrained_model is the weight file to be evaluated
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```

- Prediction

```
# The configuration file used for prediction must match the one used for training
python3 tools/infer_sr.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png
```

![](../imgs_words_en/word_52.png)

After executing the command, the super-resolution result of the above image is as follows:

![](../imgs_results/sr_word_52.png)

<a name="4"></a>
## 4. Inference and Deployment

<a name="4-1"></a>
### 4.1 Python Inference

First, convert the model saved during text super-resolution training into an inference model. Taking the [model](https://paddleocr.bj.bcebos.com/contribution/Telescope_train.tar.gz) trained with Text Telescope as an example, the conversion command is as follows:

```shell
python3 tools/export_model.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/sr_out
```

For Text Telescope text super-resolution model inference, execute:

```
python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out --image_dir=doc/imgs_words_en/word_52.png --sr_image_shape=3,32,128
```

After executing the command, the super-resolution result of the image is as follows:

![](./imgs_results/sr_word_52.png)

<a name="4-2"></a>
### 4.2 C++ Inference

Not supported yet.

<a name="4-3"></a>
### 4.3 Serving

Not supported yet.

<a name="4-4"></a>
### 4.4 More

Not supported yet.

<a name="5"></a>
## 5. FAQ

## Citation

```bibtex
@INPROCEEDINGS{9578891,
  author={Chen, Jingye and Li, Bin and Xue, Xiangyang},
  booktitle={2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  title={Scene Text Telescope: Text-Focused Scene Image Super-Resolution},
  year={2021},
  volume={},
  number={},
  pages={12021-12030},
  doi={10.1109/CVPR46437.2021.01185}}
```
@@ -0,0 +1,137 @@
# Text Telescope

- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
  - [3.1 Training](#3-1)
  - [3.2 Evaluation](#3-2)
  - [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
  - [4.1 Python Inference](#4-1)
  - [4.2 C++ Inference](#4-2)
  - [4.3 Serving](#4-3)
  - [4.4 More](#4-4)
- [5. FAQ](#5)

<a name="1"></a>
## 1. Introduction

Paper:
> [Scene Text Telescope: Text-Focused Scene Image Super-Resolution](https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Scene_Text_Telescope_Text-Focused_Scene_Image_Super-Resolution_CVPR_2021_paper.pdf)

> Chen, Jingye, Bin Li, and Xiangyang Xue

> CVPR, 2021

Referring to the [FudanOCR](https://github.com/FudanVI/FudanOCR/tree/main/scene-text-telescope) data download instructions, the super-resolution algorithm achieves the following results on the TextZoom test set:

|Model|Backbone|PSNR_Avg|SSIM_Avg|Config|Download link|
|---|---|---|---|---|---|
|Text Telescope|tbsrn|21.56|0.7411|[configs/sr/sr_telescope.yml](../../configs/sr/sr_telescope.yml)|[train model](https://paddleocr.bj.bcebos.com/contribution/Telescope_train.tar.gz)|

The [TextZoom dataset](https://paddleocr.bj.bcebos.com/dataset/TextZoom.tar) comes from two super-resolution datasets, RealSR and SR-RAW; both contain LR-HR pairs. TextZoom has 17,367 pairs of training data and 4,373 pairs of test data.
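Each split ships as an LMDB database of LR-HR pairs. As a quick sanity check of the download, a record can be read directly; the sketch below assumes the FudanOCR TextZoom key layout (`num-samples`, `image_hr-%09d`, `image_lr-%09d`, `label-%09d`, 1-based indices) and the directory layout from the config in this diff, so adjust the path if the archive unpacks into subfolders:

```python
import io
import lmdb
from PIL import Image

# Minimal sketch; key names assume the FudanOCR/TextZoom LMDB layout.
env = lmdb.open('./train_data/TextZoom/train', readonly=True, lock=False)
with env.begin() as txn:
    num = int(txn.get(b'num-samples'))
    hr = Image.open(io.BytesIO(txn.get(b'image_hr-%09d' % 1)))
    lr = Image.open(io.BytesIO(txn.get(b'image_lr-%09d' % 1)))
    label = txn.get(b'label-%09d' % 1).decode()
    print(num, hr.size, lr.size, label)
```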
<a name="2"></a>
## 2. Environment

Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.

<a name="3"></a>
## 3. Model Training / Evaluation / Prediction

Please refer to the [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, so training a different model only requires **changing the configuration file**.

Training:

After the data preparation is completed, training can be started. The training command is as follows:

```
# Single-GPU training (long training period, not recommended)
python3 tools/train.py -c configs/sr/sr_telescope.yml

# Multi-GPU training, specify the GPU IDs through the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/sr/sr_telescope.yml
```

Evaluation:

```
# GPU evaluation
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```

Prediction:

```
# The configuration file used for prediction must match the one used for training
python3 tools/infer_sr.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png
```

![](../imgs_words_en/word_52.png)

After executing the command, the super-resolution result of the above image is as follows:

![](../imgs_results/sr_word_52.png)

<a name="4"></a>
## 4. Inference and Deployment

<a name="4-1"></a>
### 4.1 Python Inference

First, convert the model saved during training into an inference model ([model download link](https://paddleocr.bj.bcebos.com/contribution/Telescope_train.tar.gz)). Use the following command to convert:

```shell
python3 tools/export_model.py -c configs/sr/sr_telescope.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/sr_out
```

For Text Telescope super-resolution model inference, execute:

```
python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out --image_dir=doc/imgs_words_en/word_52.png --sr_image_shape=3,32,128
```

After executing the command, the super-resolution result of the above image is as follows:

![](../imgs_results/sr_word_52.png)

<a name="4-2"></a>
### 4.2 C++ Inference

Not supported yet.

<a name="4-3"></a>
### 4.3 Serving

Not supported yet.

<a name="4-4"></a>
### 4.4 More

Not supported yet.

<a name="5"></a>
## 5. FAQ

## Citation

```bibtex
@INPROCEEDINGS{9578891,
  author={Chen, Jingye and Li, Bin and Xue, Xiangyang},
  booktitle={2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  title={Scene Text Telescope: Text-Focused Scene Image Super-Resolution},
  year={2021},
  volume={},
  number={},
  pages={12021-12030},
  doi={10.1109/CVPR46437.2021.01185}}
```
@@ -25,8 +25,6 @@ from .det_east_loss import EASTLoss
 from .det_sast_loss import SASTLoss
 from .det_pse_loss import PSELoss
 from .det_fce_loss import FCELoss
-from .det_ct_loss import CTLoss
-from .det_drrg_loss import DRRGLoss

 # rec loss
 from .rec_ctc_loss import CTCLoss
@@ -39,7 +37,6 @@ from .rec_pren_loss import PRENLoss
 from .rec_multi_loss import MultiLoss
 from .rec_vl_loss import VLLoss
 from .rec_spin_att_loss import SPINAttentionLoss
-from .rec_rfl_loss import RFLLoss

 # cls loss
 from .cls_loss import ClsLoss
@@ -62,6 +59,7 @@ from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss

 # sr loss
 from .stroke_focus_loss import StrokeFocusLoss
+from .text_focus_loss import TelescopeLoss


 def build_loss(config):
@@ -71,7 +69,7 @@ def build_loss(config):
         'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
         'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
         'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
-        'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss'
+        'SLALoss', 'TelescopeLoss'
     ]
     config = copy.deepcopy(config)
     module_name = config.pop('name')
@@ -0,0 +1,91 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/FudanVI/FudanOCR/blob/main/scene-text-telescope/loss/text_focus_loss.py
"""

import paddle.nn as nn
import paddle
import numpy as np
import pickle as pkl

standard_alphebet = '-0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
standard_dict = {}
for index in range(len(standard_alphebet)):
    standard_dict[standard_alphebet[index]] = index

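# Build a 37x37 weight table from the pickled character-confusion statistics:
# rows/columns cover the lowercase vocabulary ('-' terminator, digits, a-z).
# Entries are reciprocal confusion counts (inf -> 1 for unseen pairs), and
# upper-case columns are folded into their lower-case counterparts.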
def load_confuse_matrix(confuse_dict_path):
    f = open(confuse_dict_path, 'rb')
    data = pkl.load(f)
    f.close()
    number = data[:10]
    upper = data[10:36]
    lower = data[36:]
    end = np.ones((1, 62))
    pad = np.ones((63, 1))
    rearrange_data = np.concatenate((end, number, lower, upper), axis=0)
    rearrange_data = np.concatenate((pad, rearrange_data), axis=1)
    rearrange_data = 1 / rearrange_data
    rearrange_data[rearrange_data == np.inf] = 1
    rearrange_data = paddle.to_tensor(rearrange_data)

    lower_alpha = 'abcdefghijklmnopqrstuvwxyz'
    # upper_alpha = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    for i in range(63):
        for j in range(63):
            if i != j and standard_alphebet[j] in lower_alpha:
                rearrange_data[i][j] = max(rearrange_data[i][j], rearrange_data[i][j + 26])
    rearrange_data = rearrange_data[:37, :37]

    return rearrange_data

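# Cross-entropy where the softmax terms are re-weighted by the ground-truth
# character's row of the confusion weight table before normalization.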
def weight_cross_entropy(pred, gt, weight_table):
    batch = gt.shape[0]
    weight = weight_table[gt]
    pred_exp = paddle.exp(pred)
    pred_exp_weight = weight * pred_exp
    loss = 0
    for i in range(len(gt)):
        loss -= paddle.log(pred_exp_weight[i][gt[i]] / paddle.sum(pred_exp_weight, 1)[i])
    return loss / batch

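# Total loss = image MSE(SR, HR) + 10 * L1 over recognizer attention maps
#            + 0.0005 * confusion-weighted cross-entropy on the SR prediction.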
class TelescopeLoss(nn.Layer):
    def __init__(self, confuse_dict_path):
        super(TelescopeLoss, self).__init__()
        self.weight_table = load_confuse_matrix(confuse_dict_path)
        self.mse_loss = nn.MSELoss()
        self.ce_loss = nn.CrossEntropyLoss()
        self.l1_loss = nn.L1Loss()

    def forward(self, pred, data):
        sr_img = pred["sr_img"]
        hr_img = pred["hr_img"]
        sr_pred = pred["sr_pred"]
        text_gt = pred["text_gt"]

        word_attention_map_gt = pred["word_attention_map_gt"]
        word_attention_map_pred = pred["word_attention_map_pred"]
        mse_loss = self.mse_loss(sr_img, hr_img)
        attention_loss = self.l1_loss(word_attention_map_gt,
                                      word_attention_map_pred)
        recognition_loss = weight_cross_entropy(sr_pred, text_gt,
                                                self.weight_table)
        loss = mse_loss + attention_loss * 10 + recognition_loss * 0.0005
        return {
            "mse_loss": mse_loss,
            "attention_loss": attention_loss,
            "loss": loss
        }
@@ -15,18 +15,12 @@
 This code is refer from:
 https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/loss/transformer_english_decomposition.py
 """
+import copy
+import math
+
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
-import math, copy
-import numpy as np
-
-# stroke-level alphabet
-alphabet = '0123456789'
-
-
-def get_alphabet_len():
-    return len(alphabet)


 def subsequent_mask(size):
@@ -373,10 +367,10 @@ class Encoder(nn.Layer):


 class Transformer(nn.Layer):
-    def __init__(self, in_channels=1):
+    def __init__(self, in_channels=1, alphabet='0123456789'):
         super(Transformer, self).__init__()

-        word_n_class = get_alphabet_len()
+        self.alphabet = alphabet
+        word_n_class = self.get_alphabet_len()
         self.embedding_word_with_upperword = Embeddings(512, word_n_class)
         self.pe = PositionalEncoding(dim=512, dropout=0.1, max_len=5000)

@@ -388,6 +382,9 @@ class Transformer(nn.Layer):
         if p.dim() > 1:
             nn.initializer.XavierNormal(p)

+    def get_alphabet_len(self):
+        return len(self.alphabet)
+
     def forward(self, image, text_length, text_input, attention_map=None):
         if image.shape[1] == 3:
             R = image[:, 0:1, :, :]

@@ -415,7 +412,7 @@ class Transformer(nn.Layer):

         if self.training:
             total_length = paddle.sum(text_length)
-            probs_res = paddle.zeros([total_length, get_alphabet_len()])
+            probs_res = paddle.zeros([total_length, self.get_alphabet_len()])
             start = 0

             for index, length in enumerate(text_length):
@@ -19,9 +19,10 @@ def build_transform(config):
     from .tps import TPS
     from .stn import STN_ON
     from .tsrn import TSRN
+    from .tbsrn import TBSRN
     from .gaspin_transformer import GA_SPIN_Transformer as GA_SPIN

-    support_dict = ['TPS', 'STN_ON', 'GA_SPIN', 'TSRN']
+    support_dict = ['TPS', 'STN_ON', 'GA_SPIN', 'TSRN', 'TBSRN']

     module_name = config.pop('name')
     assert module_name in support_dict, Exception(
@@ -0,0 +1,264 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/FudanVI/FudanOCR/blob/main/scene-text-telescope/model/tbsrn.py
"""

import math
import warnings
import numpy as np
import paddle
from paddle import nn
import string

warnings.filterwarnings("ignore")

from .tps_spatial_transformer import TPSSpatialTransformer
from .stn import STN as STNHead
from .tsrn import GruBlock, mish, UpsampleBLock
from ppocr.modeling.heads.sr_rensnet_transformer import Transformer, LayerNorm, \
    PositionwiseFeedForward, MultiHeadedAttention


def positionalencoding2d(d_model, height, width):
    """
    :param d_model: dimension of the model
    :param height: height of the positions
    :param width: width of the positions
    :return: d_model*height*width position matrix
    """
    if d_model % 4 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dimension (got dim={:d})".format(d_model))
    pe = paddle.zeros([d_model, height, width])
    # Each dimension use half of d_model
    d_model = int(d_model / 2)
    div_term = paddle.exp(paddle.arange(0., d_model, 2) *
                          -(math.log(10000.0) / d_model))
    pos_w = paddle.arange(0., width, dtype='float32').unsqueeze(1)
    pos_h = paddle.arange(0., height, dtype='float32').unsqueeze(1)

    pe[0:d_model:2, :, :] = paddle.sin(pos_w * div_term).transpose([1, 0]).unsqueeze(1).tile([1, height, 1])
    pe[1:d_model:2, :, :] = paddle.cos(pos_w * div_term).transpose([1, 0]).unsqueeze(1).tile([1, height, 1])
    pe[d_model::2, :, :] = paddle.sin(pos_h * div_term).transpose([1, 0]).unsqueeze(2).tile([1, 1, width])
    pe[d_model + 1::2, :, :] = paddle.cos(pos_h * div_term).transpose([1, 0]).unsqueeze(2).tile([1, 1, width])

    return pe

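# Self-attention block from the Text Telescope paper: a 2-D positional
# encoding is concatenated onto the flattened conv features, which then pass
# through multi-head attention and a position-wise feed-forward layer, each
# with a residual connection and layer norm, before a linear projection back
# to 64 channels.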
class FeatureEnhancer(nn.Layer):
    def __init__(self):
        super(FeatureEnhancer, self).__init__()

        self.multihead = MultiHeadedAttention(h=4, d_model=128, dropout=0.1)
        self.mul_layernorm1 = LayerNorm(features=128)

        self.pff = PositionwiseFeedForward(128, 128)
        self.mul_layernorm3 = LayerNorm(features=128)

        self.linear = nn.Linear(128, 64)

    def forward(self, conv_feature):
        '''
        text : (batch, seq_len, embedding_size)
        global_info: (batch, embedding_size, 1, 1)
        conv_feature: (batch, channel, H, W)
        '''
        batch = conv_feature.shape[0]
        position2d = positionalencoding2d(64, 16, 64).cast('float32').unsqueeze(0).reshape([1, 64, 1024])
        position2d = position2d.tile([batch, 1, 1])
        conv_feature = paddle.concat([conv_feature, position2d], 1)  # batch, 128(64+64), 32, 128
        result = conv_feature.transpose([0, 2, 1])
        origin_result = result
        result = self.mul_layernorm1(origin_result + self.multihead(
            result, result, result, mask=None)[0])
        origin_result = result
        result = self.mul_layernorm3(origin_result + self.pff(result))
        result = self.linear(result)
        return result.transpose([0, 2, 1])


def str_filt(str_, voc_type):
    alpha_dict = {
        'digit': string.digits,
        'lower': string.digits + string.ascii_lowercase,
        'upper': string.digits + string.ascii_letters,
        'all': string.digits + string.ascii_letters + string.punctuation
    }
    if voc_type == 'lower':
        str_ = str_.lower()
    for char in str_:
        if char not in alpha_dict[voc_type]:
            str_ = str_.replace(char, '')
    str_ = str_.lower()
    return str_


class TBSRN(nn.Layer):
    def __init__(self,
                 in_channels=3,
                 scale_factor=2,
                 width=128,
                 height=32,
                 STN=True,
                 srb_nums=5,
                 mask=False,
                 hidden_units=32,
                 infer_mode=False):
        super(TBSRN, self).__init__()
        in_planes = 3
        if mask:
            in_planes = 4
        assert math.log(scale_factor, 2) % 1 == 0
        upsample_block_num = int(math.log(scale_factor, 2))
        self.block1 = nn.Sequential(
            nn.Conv2D(in_planes, 2 * hidden_units, kernel_size=9, padding=4),
            nn.PReLU()
            # nn.ReLU()
        )
        self.srb_nums = srb_nums
        for i in range(srb_nums):
            setattr(self, 'block%d' % (i + 2),
                    RecurrentResidualBlock(2 * hidden_units))

        setattr(self, 'block%d' % (srb_nums + 2),
                nn.Sequential(
                    nn.Conv2D(2 * hidden_units, 2 * hidden_units,
                              kernel_size=3, padding=1),
                    nn.BatchNorm2D(2 * hidden_units)
                ))

        # self.non_local = NonLocalBlock2D(64, 64)
        block_ = [UpsampleBLock(2 * hidden_units, 2)
                  for _ in range(upsample_block_num)]
        block_.append(nn.Conv2D(2 * hidden_units, in_planes,
                                kernel_size=9, padding=4))
        setattr(self, 'block%d' % (srb_nums + 3), nn.Sequential(*block_))
        self.tps_inputsize = [height // scale_factor, width // scale_factor]
        tps_outputsize = [height // scale_factor, width // scale_factor]
        num_control_points = 20
        tps_margins = [0.05, 0.05]
        self.stn = STN
        self.out_channels = in_channels
        if self.stn:
            self.tps = TPSSpatialTransformer(
                output_image_size=tuple(tps_outputsize),
                num_control_points=num_control_points,
                margins=tuple(tps_margins))

            self.stn_head = STNHead(
                in_channels=in_planes,
                num_ctrlpoints=num_control_points,
                activation='none')
        self.infer_mode = infer_mode

        self.english_alphabet = '-0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
        self.english_dict = {}
        for index in range(len(self.english_alphabet)):
            self.english_dict[self.english_alphabet[index]] = index
        transformer = Transformer(alphabet='-0123456789abcdefghijklmnopqrstuvwxyz')
        self.transformer = transformer
        for param in self.transformer.parameters():
            param.trainable = False

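    # Encode a batch of text labels into the frozen recognizer's index space:
    # per-sample lengths, a right-shifted input tensor for teacher forcing,
    # and the flattened ground-truth character indices.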
    def label_encoder(self, label):
        batch = len(label)

        length = [len(i) for i in label]
        length_tensor = paddle.to_tensor(length, dtype='int64')

        max_length = max(length)
        input_tensor = np.zeros((batch, max_length))
        for i in range(batch):
            for j in range(length[i] - 1):
                input_tensor[i][j + 1] = self.english_dict[label[i][j]]

        text_gt = []
        for i in label:
            for j in i:
                text_gt.append(self.english_dict[j])
        text_gt = paddle.to_tensor(text_gt, dtype='int64')

        input_tensor = paddle.to_tensor(input_tensor, dtype='int64')
        return length_tensor, input_tensor, text_gt

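    # Pipeline: (training-only) TPS rectification of the LR input -> conv stem
    # -> srb_nums residual blocks -> fused skip connection -> pixel-shuffle
    # upsampling -> tanh SR image. During training, the frozen recognizer is
    # also run on the HR and SR images to obtain the attention maps and
    # predictions consumed by TelescopeLoss.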
    def forward(self, x):
        output = {}
        if self.infer_mode:
            output["lr_img"] = x
            y = x
        else:
            output["lr_img"] = x[0]
            output["hr_img"] = x[1]
            y = x[0]
        if self.stn and self.training:
            _, ctrl_points_x = self.stn_head(y)
            y, _ = self.tps(y, ctrl_points_x)
        block = {'1': self.block1(y)}
        for i in range(self.srb_nums + 1):
            block[str(i + 2)] = getattr(self,
                                        'block%d' % (i + 2))(block[str(i + 1)])

        block[str(self.srb_nums + 3)] = getattr(self, 'block%d' % (self.srb_nums + 3)) \
            ((block['1'] + block[str(self.srb_nums + 2)]))

        sr_img = paddle.tanh(block[str(self.srb_nums + 3)])
        output["sr_img"] = sr_img

        if self.training:
            hr_img = x[1]

            # add transformer
            label = [str_filt(i, 'lower') + '-' for i in x[2]]
            length_tensor, input_tensor, text_gt = self.label_encoder(label)
            hr_pred, word_attention_map_gt, hr_correct_list = self.transformer(
                hr_img, length_tensor, input_tensor)
            sr_pred, word_attention_map_pred, sr_correct_list = self.transformer(
                sr_img, length_tensor, input_tensor)
            output["hr_img"] = hr_img
            output["hr_pred"] = hr_pred
            output["text_gt"] = text_gt
            output["word_attention_map_gt"] = word_attention_map_gt
            output["sr_pred"] = sr_pred
            output["word_attention_map_pred"] = word_attention_map_pred

        return output

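# Sequential residual block: two conv-BN stages with a mish activation in
# between; the flattened output is refined by FeatureEnhancer and added back
# to the input. The GRU blocks are constructed but unused in this variant.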
class RecurrentResidualBlock(nn.Layer):
    def __init__(self, channels):
        super(RecurrentResidualBlock, self).__init__()
        self.conv1 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2D(channels)
        self.gru1 = GruBlock(channels, channels)
        # self.prelu = nn.ReLU()
        self.prelu = mish()
        self.conv2 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2D(channels)
        self.gru2 = GruBlock(channels, channels)
        self.feature_enhancer = FeatureEnhancer()

        for p in self.parameters():
            if p.dim() > 1:
                paddle.nn.initializer.XavierUniform(p)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.prelu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        size = residual.shape
        residual = residual.reshape([size[0], size[1], -1])
        residual = self.feature_enhancer(residual)
        residual = residual.reshape([size[0], size[1], size[2], size[3]])
        return x + residual
Binary file not shown.