Merge pull request #7940 from dorren002/new_branch
add handwritten mathematical expression recognition algorithm CANpull/7956/head^2
commit
3907c72a08
|
@ -0,0 +1,122 @@
|
|||
Global:
|
||||
use_gpu: True
|
||||
epoch_num: 240
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec/can/
|
||||
save_epoch_step: 1
|
||||
# evaluation is run every 1105 iterations (1 epoch)(batch_size = 8)
|
||||
eval_batch_step: [0, 1105]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/datasets/crohme_demo/hme_00.jpg
|
||||
# for data or label process
|
||||
character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt
|
||||
max_text_length: 36
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_can.txt
|
||||
|
||||
Optimizer:
|
||||
name: Momentum
|
||||
momentum: 0.9
|
||||
clip_norm_global: 100.0
|
||||
lr:
|
||||
name: TwoStepCosine
|
||||
learning_rate: 0.01
|
||||
warmup_epoch: 1
|
||||
weight_decay: 0.0001
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: CAN
|
||||
in_channels: 1
|
||||
Transform:
|
||||
Backbone:
|
||||
name: DenseNet
|
||||
growthRate: 24
|
||||
reduction: 0.5
|
||||
bottleneck: True
|
||||
use_dropout: True
|
||||
input_channel: 1
|
||||
Head:
|
||||
name: CANHead
|
||||
in_channel: 684
|
||||
out_channel: 111
|
||||
max_text_length: 36
|
||||
ratio: 16
|
||||
attdecoder:
|
||||
is_train: True
|
||||
input_size: 256
|
||||
hidden_size: 256
|
||||
encoder_out_channel: 684
|
||||
dropout: True
|
||||
dropout_ratio: 0.5
|
||||
word_num: 111
|
||||
counting_decoder_out_channel: 111
|
||||
attention:
|
||||
attention_dim: 512
|
||||
word_conv_kernel: 1
|
||||
|
||||
Loss:
|
||||
name: CANLoss
|
||||
|
||||
PostProcess:
|
||||
name: CANLabelDecode
|
||||
|
||||
Metric:
|
||||
name: CANMetric
|
||||
main_indicator: exp_rate
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/CROHME/training/images/
|
||||
label_file_list: ["./train_data/CROHME/training/labels.txt"]
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
channel_first: False
|
||||
- NormalizeImage:
|
||||
mean: [0,0,0]
|
||||
std: [1,1,1]
|
||||
order: 'hwc'
|
||||
- GrayImageChannelFormat:
|
||||
inverse: True
|
||||
- CANLabelEncode:
|
||||
lower: False
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label']
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 8
|
||||
drop_last: False
|
||||
num_workers: 4
|
||||
collate_fn: DyMaskCollator
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/CROHME/evaluation/images/
|
||||
label_file_list: ["./train_data/CROHME/evaluation/labels.txt"]
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
channel_first: False
|
||||
- NormalizeImage:
|
||||
mean: [0,0,0]
|
||||
std: [1,1,1]
|
||||
order: 'hwc'
|
||||
- GrayImageChannelFormat:
|
||||
inverse: True
|
||||
- CANLabelEncode:
|
||||
lower: False
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1
|
||||
num_workers: 4
|
||||
collate_fn: DyMaskCollator
|
Binary file not shown.
After Width: | Height: | Size: 12 KiB |
Binary file not shown.
After Width: | Height: | Size: 15 KiB |
Binary file not shown.
After Width: | Height: | Size: 4.8 KiB |
|
@ -0,0 +1,174 @@
|
|||
# 手写数学公式识别算法-CAN
|
||||
|
||||
- [1. 算法简介](#1)
|
||||
- [2. 环境配置](#2)
|
||||
- [3. 模型训练、评估、预测](#3)
|
||||
- [3.1 训练](#3-1)
|
||||
- [3.2 评估](#3-2)
|
||||
- [3.3 预测](#3-3)
|
||||
- [4. 推理部署](#4)
|
||||
- [4.1 Python推理](#4-1)
|
||||
- [4.2 C++推理](#4-2)
|
||||
- [4.3 Serving服务化部署](#4-3)
|
||||
- [4.4 更多推理部署](#4-4)
|
||||
- [5. FAQ](#5)
|
||||
|
||||
<a name="1"></a>
|
||||
## 1. 算法简介
|
||||
|
||||
论文信息:
|
||||
> [When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition](https://arxiv.org/abs/2207.11463)
|
||||
> Bohan Li, Ye Yuan, Dingkang Liang, Xiao Liu, Zhilong Ji, Jinfeng Bai, Wenyu Liu, Xiang Bai
|
||||
> ECCV, 2022
|
||||
|
||||
|
||||
<a name="model"></a>
|
||||
`CAN`使用CROHME手写公式数据集进行训练,在对应测试集上的精度如下:
|
||||
|
||||
|模型 |骨干网络|配置文件|ExpRate|下载链接|
|
||||
| ----- | ----- | ----- | ----- | ----- |
|
||||
|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72|[训练模型](https://paddleocr.bj.bcebos.com/contribution/can_train.tar)|
|
||||
|
||||
<a name="2"></a>
|
||||
## 2. 环境配置
|
||||
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
|
||||
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. 模型训练、评估、预测
|
||||
|
||||
<a name="3-1"></a>
|
||||
### 3.1 模型训练
|
||||
|
||||
请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`CAN`识别模型时需要**更换配置文件**为`CAN`的[配置文件](../../configs/rec/rec_d28_can.yml)。
|
||||
|
||||
#### 启动训练
|
||||
|
||||
|
||||
具体地,在完成数据准备后,便可以启动训练,训练命令如下:
|
||||
```shell
|
||||
#单卡训练(训练周期长,不建议)
|
||||
python3 tools/train.py -c configs/rec/rec_d28_can.yml
|
||||
|
||||
#多卡训练,通过--gpus参数指定卡号
|
||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_d28_can.yml
|
||||
```
|
||||
|
||||
**注意:**
|
||||
- 我们提供的数据集,即[`CROHME数据集`](https://paddleocr.bj.bcebos.com/dataset/CROHME.tar)将手写公式存储为黑底白字的格式,若您自行准备的数据集与之相反,即以白底黑字模式存储,请在训练时做出如下修改
|
||||
```
|
||||
python3 tools/train.py -c configs/rec/rec_d28_can.yml
|
||||
-o Train.dataset.transforms.GrayImageChannelFormat.inverse=False
|
||||
```
|
||||
- 默认每训练1个epoch(1105次iteration)进行1次评估,若您更改训练的batch_size,或更换数据集,请在训练时作出如下修改
|
||||
```
|
||||
python3 tools/train.py -c configs/rec/rec_d28_can.yml
|
||||
-o Global.eval_batch_step=[0, {length_of_dataset//batch_size}]
|
||||
```
|
||||
|
||||
#
|
||||
<a name="3-2"></a>
|
||||
### 3.2 评估
|
||||
|
||||
可下载已训练完成的[模型文件](https://paddleocr.bj.bcebos.com/contribution/can_train.tar),使用如下命令进行评估:
|
||||
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。若使用自行训练保存的模型,请注意修改路径和文件名为{path/to/weights}/{model_name}。
|
||||
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_d28_can.yml -o Global.pretrained_model=./rec_d28_can_train/CAN
|
||||
```
|
||||
|
||||
<a name="3-3"></a>
|
||||
### 3.3 预测
|
||||
|
||||
使用如下命令进行单张图片预测:
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。
|
||||
python3 tools/infer_rec.py -c configs/rec/rec_d28_can.yml -o Architecture.Head.attdecoder.is_train=False Global.infer_img='./doc/datasets/crohme_demo/hme_00.jpg' Global.pretrained_model=./rec_d28_can_train/CAN
|
||||
|
||||
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/datasets/crohme_demo/'。
|
||||
```
|
||||
|
||||
|
||||
<a name="4"></a>
|
||||
## 4. 推理部署
|
||||
|
||||
<a name="4-1"></a>
|
||||
### 4.1 Python推理
|
||||
首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/contribution/can_train.tar) ),可以使用如下命令进行转换:
|
||||
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。
|
||||
python3 tools/export_model.py -c configs/rec/rec_d28_can.yml -o Global.pretrained_model=./rec_d28_can_train/CAN Global.save_inference_dir=./inference/rec_d28_can/ Architecture.Head.attdecoder.is_train=False
|
||||
|
||||
# 目前的静态图模型默认的输出长度最大为36,如果您需要预测更长的序列,请在导出模型时指定其输出序列为合适的值,例如 Architecture.Head.max_text_length=72
|
||||
```
|
||||
**注意:**
|
||||
- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否是所需要的字典文件。
|
||||
|
||||
转换成功后,在目录下有三个文件:
|
||||
```
|
||||
/inference/rec_d28_can/
|
||||
├── inference.pdiparams # 识别inference模型的参数文件
|
||||
├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
|
||||
└── inference.pdmodel # 识别inference模型的program文件
|
||||
```
|
||||
|
||||
执行如下命令进行模型推理:
|
||||
|
||||
```shell
|
||||
python3 tools/infer/predict_rec.py --image_dir="./doc/datasets/crohme_demo/hme_00.jpg" --rec_algorithm="CAN" --rec_batch_num=1 --rec_model_dir="./inference/rec_d28_can/" --rec_char_dict_path="./ppocr/utils/dict/latex_symbol_dict.txt"
|
||||
|
||||
# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/datasets/crohme_demo/'。
|
||||
|
||||
# 如果您需要在白底黑字的图片上进行预测,请设置 --rec_image_inverse=False
|
||||
```
|
||||
|
||||

|
||||
|
||||
执行命令后,上面图像的预测结果(识别的文本)会打印到屏幕上,示例如下:
|
||||
```shell
|
||||
Predicts of ./doc/imgs_hme/hme_00.jpg:['x _ { k } x x _ { k } + y _ { k } y x _ { k }', []]
|
||||
```
|
||||
|
||||
|
||||
**注意**:
|
||||
|
||||
- 需要注意预测图像为**黑底白字**,即手写公式部分为白色,背景为黑色的图片。
|
||||
- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。
|
||||
- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中CAN的预处理为您的预处理方法。
|
||||
|
||||
|
||||
<a name="4-2"></a>
|
||||
### 4.2 C++推理部署
|
||||
|
||||
由于C++预处理后处理还未支持CAN,所以暂未支持
|
||||
|
||||
<a name="4-3"></a>
|
||||
### 4.3 Serving服务化部署
|
||||
|
||||
暂不支持
|
||||
|
||||
<a name="4-4"></a>
|
||||
### 4.4 更多推理部署
|
||||
|
||||
暂不支持
|
||||
|
||||
<a name="5"></a>
|
||||
## 5. FAQ
|
||||
|
||||
1. CROHME数据集来自于[CAN源repo](https://github.com/LBH1024/CAN) 。
|
||||
|
||||
## 引用
|
||||
|
||||
```bibtex
|
||||
@misc{https://doi.org/10.48550/arxiv.2207.11463,
|
||||
doi = {10.48550/ARXIV.2207.11463},
|
||||
url = {https://arxiv.org/abs/2207.11463},
|
||||
author = {Li, Bohan and Yuan, Ye and Liang, Dingkang and Liu, Xiao and Ji, Zhilong and Bai, Jinfeng and Liu, Wenyu and Bai, Xiang},
|
||||
keywords = {Computer Vision and Pattern Recognition (cs.CV), Artificial Intelligence (cs.AI), FOS: Computer and information sciences, FOS: Computer and information sciences},
|
||||
title = {When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition},
|
||||
publisher = {arXiv},
|
||||
year = {2022},
|
||||
copyright = {arXiv.org perpetual, non-exclusive license}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,119 @@
|
|||
# CAN
|
||||
|
||||
- [1. Introduction](#1)
|
||||
- [2. Environment](#2)
|
||||
- [3. Model Training / Evaluation / Prediction](#3)
|
||||
- [3.1 Training](#3-1)
|
||||
- [3.2 Evaluation](#3-2)
|
||||
- [3.3 Prediction](#3-3)
|
||||
- [4. Inference and Deployment](#4)
|
||||
- [4.1 Python Inference](#4-1)
|
||||
- [4.2 C++ Inference](#4-2)
|
||||
- [4.3 Serving](#4-3)
|
||||
- [4.4 More](#4-4)
|
||||
- [5. FAQ](#5)
|
||||
|
||||
<a name="1"></a>
|
||||
## 1. Introduction
|
||||
|
||||
Paper:
|
||||
> [When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition](https://arxiv.org/abs/2207.11463)
|
||||
> Bohan Li, Ye Yuan, Dingkang Liang, Xiao Liu, Zhilong Ji, Jinfeng Bai, Wenyu Liu, Xiang Bai
|
||||
> ECCV, 2022
|
||||
|
||||
Using CROHME handwrittem mathematical expression recognition datasets for training, and evaluating on its test sets, the algorithm reproduction effect is as follows:
|
||||
|
||||
|Model|Backbone|config|exprate|Download link|
|
||||
| --- | --- | --- | --- | --- |
|
||||
|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72|[trained model](https://paddleocr.bj.bcebos.com/contribution/can_train.tar)|
|
||||
|
||||
<a name="2"></a>
|
||||
## 2. Environment
|
||||
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
|
||||
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. Model Training / Evaluation / Prediction
|
||||
|
||||
Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
|
||||
|
||||
Training:
|
||||
|
||||
Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
|
||||
|
||||
```
|
||||
#Single GPU training (long training period, not recommended)
|
||||
python3 tools/train.py -c configs/rec/rec_d28_can.yml
|
||||
|
||||
#Multi GPU training, specify the gpu number through the --gpus parameter
|
||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_d28_can.yml
|
||||
```
|
||||
|
||||
Evaluation:
|
||||
|
||||
```
|
||||
# GPU evaluation
|
||||
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_d28_can.yml -o Global.pretrained_model=./rec_d28_can_train/CAN
|
||||
```
|
||||
|
||||
Prediction:
|
||||
|
||||
```
|
||||
# The configuration file used for prediction must match the training
|
||||
python3 tools/infer_rec.py -c configs/rec/rec_d28_can.yml -o Architecture.Head.attdecoder.is_train=False Global.infer_img='./doc/crohme_demo/hme_00.jpg' Global.pretrained_model=./rec_d28_can_train/CAN
|
||||
```
|
||||
|
||||
<a name="4"></a>
|
||||
## 4. Inference and Deployment
|
||||
|
||||
<a name="4-1"></a>
|
||||
### 4.1 Python Inference
|
||||
First, the model saved during the CAN handwritten mathematical expression recognition training process is converted into an inference model. you can use the following command to convert:
|
||||
|
||||
```
|
||||
python3 tools/export_model.py -c configs/rec/rec_d28_can.yml -o Global.save_inference_dir=./inference/rec_d28_can/ Architecture.Head.attdecoder.is_train=False
|
||||
|
||||
# The default output max length of the model is 36. If you need to predict a longer sequence, please specify its output sequence as an appropriate value when exporting the model, as: Architecture.Head.max_ text_ length=72
|
||||
```
|
||||
|
||||
For CAN handwritten mathematical expression recognition model inference, the following commands can be executed:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_rec.py --image_dir="./doc/crohme_demo/hme_00.jpg" --rec_algorithm="CAN" --rec_batch_num=1 --rec_model_dir="./inference/rec_d28_can/" --rec_char_dict_path="./ppocr/utils/dict/latex_symbol_dict.txt"
|
||||
|
||||
# If you need to predict on a picture with black characters on a white background, please set: -- rec_ image_ inverse=False
|
||||
```
|
||||
|
||||
<a name="4-2"></a>
|
||||
### 4.2 C++ Inference
|
||||
|
||||
Not supported
|
||||
|
||||
<a name="4-3"></a>
|
||||
### 4.3 Serving
|
||||
|
||||
Not supported
|
||||
|
||||
<a name="4-4"></a>
|
||||
### 4.4 More
|
||||
|
||||
Not supported
|
||||
|
||||
<a name="5"></a>
|
||||
## 5. FAQ
|
||||
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{https://doi.org/10.48550/arxiv.2207.11463,
|
||||
doi = {10.48550/ARXIV.2207.11463},
|
||||
url = {https://arxiv.org/abs/2207.11463},
|
||||
author = {Li, Bohan and Yuan, Ye and Liang, Dingkang and Liu, Xiao and Ji, Zhilong and Bai, Jinfeng and Liu, Wenyu and Bai, Xiang},
|
||||
keywords = {Computer Vision and Pattern Recognition (cs.CV), Artificial Intelligence (cs.AI), FOS: Computer and information sciences, FOS: Computer and information sciences},
|
||||
title = {When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition},
|
||||
publisher = {arXiv},
|
||||
year = {2022},
|
||||
copyright = {arXiv.org perpetual, non-exclusive license}
|
||||
}
|
||||
```
|
|
@ -70,3 +70,49 @@ class SSLRotateCollate(object):
|
|||
def __call__(self, batch):
|
||||
output = [np.concatenate(d, axis=0) for d in zip(*batch)]
|
||||
return output
|
||||
|
||||
|
||||
class DyMaskCollator(object):
|
||||
"""
|
||||
batch: [
|
||||
image [batch_size, channel, maxHinbatch, maxWinbatch]
|
||||
image_mask [batch_size, channel, maxHinbatch, maxWinbatch]
|
||||
label [batch_size, maxLabelLen]
|
||||
label_mask [batch_size, maxLabelLen]
|
||||
...
|
||||
]
|
||||
"""
|
||||
|
||||
def __call__(self, batch):
|
||||
max_width, max_height, max_length = 0, 0, 0
|
||||
bs, channel = len(batch), batch[0][0].shape[0]
|
||||
proper_items = []
|
||||
for item in batch:
|
||||
if item[0].shape[1] * max_width > 1600 * 320 or item[0].shape[
|
||||
2] * max_height > 1600 * 320:
|
||||
continue
|
||||
max_height = item[0].shape[1] if item[0].shape[
|
||||
1] > max_height else max_height
|
||||
max_width = item[0].shape[2] if item[0].shape[
|
||||
2] > max_width else max_width
|
||||
max_length = len(item[1]) if len(item[
|
||||
1]) > max_length else max_length
|
||||
proper_items.append(item)
|
||||
|
||||
images, image_masks = np.zeros(
|
||||
(len(proper_items), channel, max_height, max_width),
|
||||
dtype='float32'), np.zeros(
|
||||
(len(proper_items), 1, max_height, max_width), dtype='float32')
|
||||
labels, label_masks = np.zeros(
|
||||
(len(proper_items), max_length), dtype='int64'), np.zeros(
|
||||
(len(proper_items), max_length), dtype='int64')
|
||||
|
||||
for i in range(len(proper_items)):
|
||||
_, h, w = proper_items[i][0].shape
|
||||
images[i][:, :h, :w] = proper_items[i][0]
|
||||
image_masks[i][:, :h, :w] = 1
|
||||
l = len(proper_items[i][1])
|
||||
labels[i][:l] = proper_items[i][1]
|
||||
label_masks[i][:l] = 1
|
||||
|
||||
return images, image_masks, labels, label_masks
|
||||
|
|
|
@ -1474,4 +1474,33 @@ class CTLabelEncode(object):
|
|||
|
||||
data['polys'] = boxes
|
||||
data['texts'] = txts
|
||||
return data
|
||||
return data
|
||||
|
||||
|
||||
class CANLabelEncode(BaseRecLabelEncode):
|
||||
def __init__(self,
|
||||
character_dict_path,
|
||||
max_text_length=100,
|
||||
use_space_char=False,
|
||||
lower=True,
|
||||
**kwargs):
|
||||
super(CANLabelEncode, self).__init__(
|
||||
max_text_length, character_dict_path, use_space_char, lower)
|
||||
|
||||
def encode(self, text_seq):
|
||||
text_seq_encoded = []
|
||||
for text in text_seq:
|
||||
if text not in self.character:
|
||||
continue
|
||||
text_seq_encoded.append(self.dict.get(text))
|
||||
if len(text_seq_encoded) == 0:
|
||||
return None
|
||||
return text_seq_encoded
|
||||
|
||||
def __call__(self, data):
|
||||
label = data['label']
|
||||
if isinstance(label, str):
|
||||
label = label.strip().split()
|
||||
label.append(self.end_str)
|
||||
data['label'] = self.encode(label)
|
||||
return data
|
||||
|
|
|
@ -498,3 +498,27 @@ class ResizeNormalize(object):
|
|||
img_numpy = np.array(img).astype("float32")
|
||||
img_numpy = img_numpy.transpose((2, 0, 1)) / 255
|
||||
return img_numpy
|
||||
|
||||
|
||||
class GrayImageChannelFormat(object):
|
||||
"""
|
||||
format gray scale image's channel: (3,h,w) -> (1,h,w)
|
||||
Args:
|
||||
inverse: inverse gray image
|
||||
"""
|
||||
|
||||
def __init__(self, inverse=False, **kwargs):
|
||||
self.inverse = inverse
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
img_expanded = np.expand_dims(img_single_channel, 0)
|
||||
|
||||
if self.inverse:
|
||||
data['image'] = np.abs(img_expanded - 1)
|
||||
else:
|
||||
data['image'] = img_expanded
|
||||
|
||||
data['src_image'] = img
|
||||
return data
|
|
@ -40,6 +40,7 @@ from .rec_multi_loss import MultiLoss
|
|||
from .rec_vl_loss import VLLoss
|
||||
from .rec_spin_att_loss import SPINAttentionLoss
|
||||
from .rec_rfl_loss import RFLLoss
|
||||
from .rec_can_loss import CANLoss
|
||||
|
||||
# cls loss
|
||||
from .cls_loss import ClsLoss
|
||||
|
@ -71,7 +72,7 @@ def build_loss(config):
|
|||
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
|
||||
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
|
||||
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
|
||||
'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss'
|
||||
'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'CANLoss'
|
||||
]
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/LBH1024/CAN/models/can.py
|
||||
"""
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
class CANLoss(nn.Layer):
|
||||
'''
|
||||
CANLoss is consist of two part:
|
||||
word_average_loss: average accuracy of the symbol
|
||||
counting_loss: counting loss of every symbol
|
||||
'''
|
||||
|
||||
def __init__(self):
|
||||
super(CANLoss, self).__init__()
|
||||
|
||||
self.use_label_mask = False
|
||||
self.out_channel = 111
|
||||
self.cross = nn.CrossEntropyLoss(
|
||||
reduction='none') if self.use_label_mask else nn.CrossEntropyLoss()
|
||||
self.counting_loss = nn.SmoothL1Loss(reduction='mean')
|
||||
self.ratio = 16
|
||||
|
||||
def forward(self, preds, batch):
|
||||
word_probs = preds[0]
|
||||
counting_preds = preds[1]
|
||||
counting_preds1 = preds[2]
|
||||
counting_preds2 = preds[3]
|
||||
labels = batch[2]
|
||||
labels_mask = batch[3]
|
||||
counting_labels = gen_counting_label(labels, self.out_channel, True)
|
||||
counting_loss = self.counting_loss(counting_preds1, counting_labels) + self.counting_loss(counting_preds2, counting_labels) \
|
||||
+ self.counting_loss(counting_preds, counting_labels)
|
||||
|
||||
word_loss = self.cross(
|
||||
paddle.reshape(word_probs, [-1, word_probs.shape[-1]]),
|
||||
paddle.reshape(labels, [-1]))
|
||||
word_average_loss = paddle.sum(
|
||||
paddle.reshape(word_loss * labels_mask, [-1])) / (
|
||||
paddle.sum(labels_mask) + 1e-10
|
||||
) if self.use_label_mask else word_loss
|
||||
loss = word_average_loss + counting_loss
|
||||
return {'loss': loss}
|
||||
|
||||
|
||||
def gen_counting_label(labels, channel, tag):
|
||||
b, t = labels.shape
|
||||
counting_labels = np.zeros([b, channel])
|
||||
|
||||
if tag:
|
||||
ignore = [0, 1, 107, 108, 109, 110]
|
||||
else:
|
||||
ignore = []
|
||||
for i in range(b):
|
||||
for j in range(t):
|
||||
k = labels[i][j]
|
||||
if k in ignore:
|
||||
continue
|
||||
else:
|
||||
counting_labels[i][k] += 1
|
||||
counting_labels = paddle.to_tensor(counting_labels, dtype='float32')
|
||||
return counting_labels
|
|
@ -22,7 +22,7 @@ import copy
|
|||
__all__ = ["build_metric"]
|
||||
|
||||
from .det_metric import DetMetric, DetFCEMetric
|
||||
from .rec_metric import RecMetric, CNTMetric
|
||||
from .rec_metric import RecMetric, CNTMetric, CANMetric
|
||||
from .cls_metric import ClsMetric
|
||||
from .e2e_metric import E2EMetric
|
||||
from .distillation_metric import DistillationMetric
|
||||
|
@ -38,7 +38,7 @@ def build_metric(config):
|
|||
support_dict = [
|
||||
"DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
|
||||
"DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
|
||||
'VQAReTokenMetric', 'SRMetric', 'CTMetric', 'CNTMetric'
|
||||
'VQAReTokenMetric', 'SRMetric', 'CTMetric', 'CNTMetric', 'CANMetric'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -13,6 +13,9 @@
|
|||
# limitations under the License.
|
||||
|
||||
from rapidfuzz.distance import Levenshtein
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
import numpy as np
|
||||
import string
|
||||
|
||||
|
||||
|
@ -106,3 +109,71 @@ class CNTMetric(object):
|
|||
def reset(self):
|
||||
self.correct_num = 0
|
||||
self.all_num = 0
|
||||
|
||||
|
||||
class CANMetric(object):
|
||||
def __init__(self, main_indicator='exp_rate', **kwargs):
|
||||
self.main_indicator = main_indicator
|
||||
self.word_right = []
|
||||
self.exp_right = []
|
||||
self.word_total_length = 0
|
||||
self.exp_total_num = 0
|
||||
self.word_rate = 0
|
||||
self.exp_rate = 0
|
||||
self.reset()
|
||||
self.epoch_reset()
|
||||
|
||||
def __call__(self, preds, batch, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
epoch_reset = v
|
||||
if epoch_reset:
|
||||
self.epoch_reset()
|
||||
word_probs = preds
|
||||
word_label, word_label_mask = batch
|
||||
line_right = 0
|
||||
if word_probs is not None:
|
||||
word_pred = word_probs.argmax(2)
|
||||
word_pred = word_pred.cpu().detach().numpy()
|
||||
word_scores = [
|
||||
SequenceMatcher(
|
||||
None,
|
||||
s1[:int(np.sum(s3))],
|
||||
s2[:int(np.sum(s3))],
|
||||
autojunk=False).ratio() * (
|
||||
len(s1[:int(np.sum(s3))]) + len(s2[:int(np.sum(s3))])) /
|
||||
len(s1[:int(np.sum(s3))]) / 2
|
||||
for s1, s2, s3 in zip(word_label, word_pred, word_label_mask)
|
||||
]
|
||||
batch_size = len(word_scores)
|
||||
for i in range(batch_size):
|
||||
if word_scores[i] == 1:
|
||||
line_right += 1
|
||||
self.word_rate = np.mean(word_scores) #float
|
||||
self.exp_rate = line_right / batch_size #float
|
||||
exp_length, word_length = word_label.shape[:2]
|
||||
self.word_right.append(self.word_rate * word_length)
|
||||
self.exp_right.append(self.exp_rate * exp_length)
|
||||
self.word_total_length = self.word_total_length + word_length
|
||||
self.exp_total_num = self.exp_total_num + exp_length
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
return {
|
||||
'word_rate': 0,
|
||||
"exp_rate": 0,
|
||||
}
|
||||
"""
|
||||
cur_word_rate = sum(self.word_right) / self.word_total_length
|
||||
cur_exp_rate = sum(self.exp_right) / self.exp_total_num
|
||||
self.reset()
|
||||
return {'word_rate': cur_word_rate, "exp_rate": cur_exp_rate}
|
||||
|
||||
def reset(self):
|
||||
self.word_rate = 0
|
||||
self.exp_rate = 0
|
||||
|
||||
def epoch_reset(self):
|
||||
self.word_right = []
|
||||
self.exp_right = []
|
||||
self.word_total_length = 0
|
||||
self.exp_total_num = 0
|
||||
|
|
|
@ -43,10 +43,12 @@ def build_backbone(config, model_type):
|
|||
from .rec_svtrnet import SVTRNet
|
||||
from .rec_vitstr import ViTSTR
|
||||
from .rec_resnet_rfl import ResNetRFL
|
||||
from .rec_densenet import DenseNet
|
||||
support_dict = [
|
||||
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
|
||||
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
|
||||
'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL'
|
||||
'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL',
|
||||
'DenseNet'
|
||||
]
|
||||
elif model_type == 'e2e':
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
|
|
|
@ -0,0 +1,146 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/LBH1024/CAN/models/densenet.py
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
class Bottleneck(nn.Layer):
|
||||
def __init__(self, nChannels, growthRate, use_dropout):
|
||||
super(Bottleneck, self).__init__()
|
||||
interChannels = 4 * growthRate
|
||||
self.bn1 = nn.BatchNorm2D(interChannels)
|
||||
self.conv1 = nn.Conv2D(
|
||||
nChannels, interChannels, kernel_size=1,
|
||||
bias_attr=None) # Xavier initialization
|
||||
self.bn2 = nn.BatchNorm2D(growthRate)
|
||||
self.conv2 = nn.Conv2D(
|
||||
interChannels, growthRate, kernel_size=3, padding=1,
|
||||
bias_attr=None) # Xavier initialization
|
||||
self.use_dropout = use_dropout
|
||||
self.dropout = nn.Dropout(p=0.2)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
if self.use_dropout:
|
||||
out = self.dropout(out)
|
||||
out = F.relu(self.bn2(self.conv2(out)))
|
||||
if self.use_dropout:
|
||||
out = self.dropout(out)
|
||||
out = paddle.concat([x, out], 1)
|
||||
return out
|
||||
|
||||
|
||||
class SingleLayer(nn.Layer):
|
||||
def __init__(self, nChannels, growthRate, use_dropout):
|
||||
super(SingleLayer, self).__init__()
|
||||
self.bn1 = nn.BatchNorm2D(nChannels)
|
||||
self.conv1 = nn.Conv2D(
|
||||
nChannels, growthRate, kernel_size=3, padding=1, bias_attr=False)
|
||||
|
||||
self.use_dropout = use_dropout
|
||||
self.dropout = nn.Dropout(p=0.2)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(F.relu(x))
|
||||
if self.use_dropout:
|
||||
out = self.dropout(out)
|
||||
|
||||
out = paddle.concat([x, out], 1)
|
||||
return out
|
||||
|
||||
|
||||
class Transition(nn.Layer):
|
||||
def __init__(self, nChannels, out_channels, use_dropout):
|
||||
super(Transition, self).__init__()
|
||||
self.bn1 = nn.BatchNorm2D(out_channels)
|
||||
self.conv1 = nn.Conv2D(
|
||||
nChannels, out_channels, kernel_size=1, bias_attr=False)
|
||||
self.use_dropout = use_dropout
|
||||
self.dropout = nn.Dropout(p=0.2)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
if self.use_dropout:
|
||||
out = self.dropout(out)
|
||||
out = F.avg_pool2d(out, 2, ceil_mode=True, exclusive=False)
|
||||
return out
|
||||
|
||||
|
||||
class DenseNet(nn.Layer):
|
||||
def __init__(self, growthRate, reduction, bottleneck, use_dropout,
|
||||
input_channel, **kwargs):
|
||||
super(DenseNet, self).__init__()
|
||||
|
||||
nDenseBlocks = 16
|
||||
nChannels = 2 * growthRate
|
||||
|
||||
self.conv1 = nn.Conv2D(
|
||||
input_channel,
|
||||
nChannels,
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
stride=2,
|
||||
bias_attr=False)
|
||||
self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks,
|
||||
bottleneck, use_dropout)
|
||||
nChannels += nDenseBlocks * growthRate
|
||||
out_channels = int(math.floor(nChannels * reduction))
|
||||
self.trans1 = Transition(nChannels, out_channels, use_dropout)
|
||||
|
||||
nChannels = out_channels
|
||||
self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks,
|
||||
bottleneck, use_dropout)
|
||||
nChannels += nDenseBlocks * growthRate
|
||||
out_channels = int(math.floor(nChannels * reduction))
|
||||
self.trans2 = Transition(nChannels, out_channels, use_dropout)
|
||||
|
||||
nChannels = out_channels
|
||||
self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks,
|
||||
bottleneck, use_dropout)
|
||||
self.out_channels = out_channels
|
||||
|
||||
def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck,
|
||||
use_dropout):
|
||||
layers = []
|
||||
for i in range(int(nDenseBlocks)):
|
||||
if bottleneck:
|
||||
layers.append(Bottleneck(nChannels, growthRate, use_dropout))
|
||||
else:
|
||||
layers.append(SingleLayer(nChannels, growthRate, use_dropout))
|
||||
nChannels += growthRate
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, inputs):
|
||||
x, x_m, y = inputs
|
||||
out = self.conv1(x)
|
||||
out = F.relu(out)
|
||||
out = F.max_pool2d(out, 2, ceil_mode=True)
|
||||
out = self.dense1(out)
|
||||
out = self.trans1(out)
|
||||
out = self.dense2(out)
|
||||
out = self.trans2(out)
|
||||
out = self.dense3(out)
|
||||
return out, x_m, y
|
|
@ -39,6 +39,7 @@ def build_head(config):
|
|||
from .rec_robustscanner_head import RobustScannerHead
|
||||
from .rec_visionlan_head import VLHead
|
||||
from .rec_rfl_head import RFLHead
|
||||
from .rec_can_head import CANHead
|
||||
|
||||
# cls head
|
||||
from .cls_head import ClsHead
|
||||
|
@ -55,7 +56,7 @@ def build_head(config):
|
|||
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
|
||||
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
|
||||
'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead',
|
||||
'DRRGHead'
|
||||
'DRRGHead', 'CANHead'
|
||||
]
|
||||
|
||||
if config['name'] == 'DRRGHead':
|
||||
|
|
|
@ -0,0 +1,319 @@
|
|||
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/LBH1024/CAN/models/can.py
|
||||
https://github.com/LBH1024/CAN/models/counting.py
|
||||
https://github.com/LBH1024/CAN/models/decoder.py
|
||||
https://github.com/LBH1024/CAN/models/attention.py
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle.nn as nn
|
||||
import paddle
|
||||
import math
|
||||
'''
|
||||
Counting Module
|
||||
'''
|
||||
|
||||
|
||||
class ChannelAtt(nn.Layer):
|
||||
def __init__(self, channel, reduction):
|
||||
super(ChannelAtt, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2D(1)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(channel, channel // reduction),
|
||||
nn.ReLU(), nn.Linear(channel // reduction, channel), nn.Sigmoid())
|
||||
|
||||
def forward(self, x):
|
||||
b, c, _, _ = x.shape
|
||||
y = paddle.reshape(self.avg_pool(x), [b, c])
|
||||
y = paddle.reshape(self.fc(y), [b, c, 1, 1])
|
||||
return x * y
|
||||
|
||||
|
||||
class CountingDecoder(nn.Layer):
|
||||
def __init__(self, in_channel, out_channel, kernel_size):
|
||||
super(CountingDecoder, self).__init__()
|
||||
self.in_channel = in_channel
|
||||
self.out_channel = out_channel
|
||||
|
||||
self.trans_layer = nn.Sequential(
|
||||
nn.Conv2D(
|
||||
self.in_channel,
|
||||
512,
|
||||
kernel_size=kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
bias_attr=False),
|
||||
nn.BatchNorm2D(512))
|
||||
|
||||
self.channel_att = ChannelAtt(512, 16)
|
||||
|
||||
self.pred_layer = nn.Sequential(
|
||||
nn.Conv2D(
|
||||
512, self.out_channel, kernel_size=1, bias_attr=False),
|
||||
nn.Sigmoid())
|
||||
|
||||
def forward(self, x, mask):
|
||||
b, _, h, w = x.shape
|
||||
x = self.trans_layer(x)
|
||||
x = self.channel_att(x)
|
||||
x = self.pred_layer(x)
|
||||
|
||||
if mask is not None:
|
||||
x = x * mask
|
||||
x = paddle.reshape(x, [b, self.out_channel, -1])
|
||||
x1 = paddle.sum(x, axis=-1)
|
||||
|
||||
return x1, paddle.reshape(x, [b, self.out_channel, h, w])
|
||||
|
||||
|
||||
'''
|
||||
Attention Decoder
|
||||
'''
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Layer):
|
||||
def __init__(self,
|
||||
num_pos_feats=64,
|
||||
temperature=10000,
|
||||
normalize=False,
|
||||
scale=None):
|
||||
super().__init__()
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
if scale is not None and normalize is False:
|
||||
raise ValueError("normalize should be True if scale is passed")
|
||||
if scale is None:
|
||||
scale = 2 * math.pi
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x, mask):
|
||||
y_embed = paddle.cumsum(mask, 1, dtype='float32')
|
||||
x_embed = paddle.cumsum(mask, 2, dtype='float32')
|
||||
|
||||
if self.normalize:
|
||||
eps = 1e-6
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
dim_t = paddle.arange(self.num_pos_feats, dtype='float32')
|
||||
dim_d = paddle.expand(paddle.to_tensor(2), dim_t.shape)
|
||||
dim_t = self.temperature**(2 * (dim_t / dim_d).astype('int64') /
|
||||
self.num_pos_feats)
|
||||
|
||||
pos_x = paddle.unsqueeze(x_embed, [3]) / dim_t
|
||||
pos_y = paddle.unsqueeze(y_embed, [3]) / dim_t
|
||||
|
||||
pos_x = paddle.flatten(
|
||||
paddle.stack(
|
||||
[
|
||||
paddle.sin(pos_x[:, :, :, 0::2]),
|
||||
paddle.cos(pos_x[:, :, :, 1::2])
|
||||
],
|
||||
axis=4),
|
||||
3)
|
||||
pos_y = paddle.flatten(
|
||||
paddle.stack(
|
||||
[
|
||||
paddle.sin(pos_y[:, :, :, 0::2]),
|
||||
paddle.cos(pos_y[:, :, :, 1::2])
|
||||
],
|
||||
axis=4),
|
||||
3)
|
||||
|
||||
pos = paddle.transpose(
|
||||
paddle.concat(
|
||||
[pos_y, pos_x], axis=3), [0, 3, 1, 2])
|
||||
|
||||
return pos
|
||||
|
||||
|
||||
class AttDecoder(nn.Layer):
|
||||
def __init__(self, ratio, is_train, input_size, hidden_size,
|
||||
encoder_out_channel, dropout, dropout_ratio, word_num,
|
||||
counting_decoder_out_channel, attention):
|
||||
super(AttDecoder, self).__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.out_channel = encoder_out_channel
|
||||
self.attention_dim = attention['attention_dim']
|
||||
self.dropout_prob = dropout
|
||||
self.ratio = ratio
|
||||
self.word_num = word_num
|
||||
|
||||
self.counting_num = counting_decoder_out_channel
|
||||
self.is_train = is_train
|
||||
|
||||
self.init_weight = nn.Linear(self.out_channel, self.hidden_size)
|
||||
self.embedding = nn.Embedding(self.word_num, self.input_size)
|
||||
self.word_input_gru = nn.GRUCell(self.input_size, self.hidden_size)
|
||||
self.word_attention = Attention(hidden_size, attention['attention_dim'])
|
||||
|
||||
self.encoder_feature_conv = nn.Conv2D(
|
||||
self.out_channel,
|
||||
self.attention_dim,
|
||||
kernel_size=attention['word_conv_kernel'],
|
||||
padding=attention['word_conv_kernel'] // 2)
|
||||
|
||||
self.word_state_weight = nn.Linear(self.hidden_size, self.hidden_size)
|
||||
self.word_embedding_weight = nn.Linear(self.input_size,
|
||||
self.hidden_size)
|
||||
self.word_context_weight = nn.Linear(self.out_channel, self.hidden_size)
|
||||
self.counting_context_weight = nn.Linear(self.counting_num,
|
||||
self.hidden_size)
|
||||
self.word_convert = nn.Linear(self.hidden_size, self.word_num)
|
||||
|
||||
if dropout:
|
||||
self.dropout = nn.Dropout(dropout_ratio)
|
||||
|
||||
def forward(self, cnn_features, labels, counting_preds, images_mask):
|
||||
if self.is_train:
|
||||
_, num_steps = labels.shape
|
||||
else:
|
||||
num_steps = 36
|
||||
|
||||
batch_size, _, height, width = cnn_features.shape
|
||||
images_mask = images_mask[:, :, ::self.ratio, ::self.ratio]
|
||||
|
||||
word_probs = paddle.zeros((batch_size, num_steps, self.word_num))
|
||||
word_alpha_sum = paddle.zeros((batch_size, 1, height, width))
|
||||
|
||||
hidden = self.init_hidden(cnn_features, images_mask)
|
||||
counting_context_weighted = self.counting_context_weight(counting_preds)
|
||||
cnn_features_trans = self.encoder_feature_conv(cnn_features)
|
||||
|
||||
position_embedding = PositionEmbeddingSine(256, normalize=True)
|
||||
pos = position_embedding(cnn_features_trans, images_mask[:, 0, :, :])
|
||||
|
||||
cnn_features_trans = cnn_features_trans + pos
|
||||
|
||||
word = paddle.ones([batch_size, 1], dtype='int64') # init word as sos
|
||||
word = word.squeeze(axis=1)
|
||||
for i in range(num_steps):
|
||||
word_embedding = self.embedding(word)
|
||||
_, hidden = self.word_input_gru(word_embedding, hidden)
|
||||
word_context_vec, _, word_alpha_sum = self.word_attention(
|
||||
cnn_features, cnn_features_trans, hidden, word_alpha_sum,
|
||||
images_mask)
|
||||
|
||||
current_state = self.word_state_weight(hidden)
|
||||
word_weighted_embedding = self.word_embedding_weight(word_embedding)
|
||||
word_context_weighted = self.word_context_weight(word_context_vec)
|
||||
|
||||
if self.dropout_prob:
|
||||
word_out_state = self.dropout(
|
||||
current_state + word_weighted_embedding +
|
||||
word_context_weighted + counting_context_weighted)
|
||||
else:
|
||||
word_out_state = current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted
|
||||
|
||||
word_prob = self.word_convert(word_out_state)
|
||||
word_probs[:, i] = word_prob
|
||||
|
||||
if self.is_train:
|
||||
word = labels[:, i]
|
||||
else:
|
||||
word = word_prob.argmax(1)
|
||||
word = paddle.multiply(
|
||||
word, labels[:, i]
|
||||
) # labels are oneslike tensor in infer/predict mode
|
||||
|
||||
return word_probs
|
||||
|
||||
def init_hidden(self, features, feature_mask):
|
||||
average = paddle.sum(paddle.sum(features * feature_mask, axis=-1),
|
||||
axis=-1) / paddle.sum(
|
||||
(paddle.sum(feature_mask, axis=-1)), axis=-1)
|
||||
average = self.init_weight(average)
|
||||
return paddle.tanh(average)
|
||||
|
||||
|
||||
'''
|
||||
Attention Module
|
||||
'''
|
||||
|
||||
|
||||
class Attention(nn.Layer):
|
||||
def __init__(self, hidden_size, attention_dim):
|
||||
super(Attention, self).__init__()
|
||||
self.hidden = hidden_size
|
||||
self.attention_dim = attention_dim
|
||||
self.hidden_weight = nn.Linear(self.hidden, self.attention_dim)
|
||||
self.attention_conv = nn.Conv2D(
|
||||
1, 512, kernel_size=11, padding=5, bias_attr=False)
|
||||
self.attention_weight = nn.Linear(
|
||||
512, self.attention_dim, bias_attr=False)
|
||||
self.alpha_convert = nn.Linear(self.attention_dim, 1)
|
||||
|
||||
def forward(self,
|
||||
cnn_features,
|
||||
cnn_features_trans,
|
||||
hidden,
|
||||
alpha_sum,
|
||||
image_mask=None):
|
||||
query = self.hidden_weight(hidden)
|
||||
alpha_sum_trans = self.attention_conv(alpha_sum)
|
||||
coverage_alpha = self.attention_weight(
|
||||
paddle.transpose(alpha_sum_trans, [0, 2, 3, 1]))
|
||||
alpha_score = paddle.tanh(
|
||||
paddle.unsqueeze(query, [1, 2]) + coverage_alpha + paddle.transpose(
|
||||
cnn_features_trans, [0, 2, 3, 1]))
|
||||
energy = self.alpha_convert(alpha_score)
|
||||
energy = energy - energy.max()
|
||||
energy_exp = paddle.exp(paddle.squeeze(energy, -1))
|
||||
|
||||
if image_mask is not None:
|
||||
energy_exp = energy_exp * paddle.squeeze(image_mask, 1)
|
||||
alpha = energy_exp / (paddle.unsqueeze(
|
||||
paddle.sum(paddle.sum(energy_exp, -1), -1), [1, 2]) + 1e-10)
|
||||
alpha_sum = paddle.unsqueeze(alpha, 1) + alpha_sum
|
||||
context_vector = paddle.sum(
|
||||
paddle.sum((paddle.unsqueeze(alpha, 1) * cnn_features), -1), -1)
|
||||
|
||||
return context_vector, alpha, alpha_sum
|
||||
|
||||
|
||||
class CANHead(nn.Layer):
|
||||
def __init__(self, in_channel, out_channel, ratio, attdecoder, **kwargs):
|
||||
super(CANHead, self).__init__()
|
||||
|
||||
self.in_channel = in_channel
|
||||
self.out_channel = out_channel
|
||||
|
||||
self.counting_decoder1 = CountingDecoder(self.in_channel,
|
||||
self.out_channel, 3) # mscm
|
||||
self.counting_decoder2 = CountingDecoder(self.in_channel,
|
||||
self.out_channel, 5)
|
||||
|
||||
self.decoder = AttDecoder(ratio, **attdecoder)
|
||||
|
||||
self.ratio = ratio
|
||||
|
||||
def forward(self, inputs, targets=None):
|
||||
cnn_features, images_mask, labels = inputs
|
||||
|
||||
counting_mask = images_mask[:, :, ::self.ratio, ::self.ratio]
|
||||
counting_preds1, _ = self.counting_decoder1(cnn_features, counting_mask)
|
||||
counting_preds2, _ = self.counting_decoder2(cnn_features, counting_mask)
|
||||
counting_preds = (counting_preds1 + counting_preds2) / 2
|
||||
|
||||
word_probs = self.decoder(cnn_features, labels, counting_preds,
|
||||
images_mask)
|
||||
return word_probs, counting_preds, counting_preds1, counting_preds2
|
|
@ -18,7 +18,7 @@ from __future__ import print_function
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from paddle.optimizer import lr
|
||||
from .lr_scheduler import CyclicalCosineDecay, OneCycleDecay
|
||||
from .lr_scheduler import CyclicalCosineDecay, OneCycleDecay, TwoStepCosineDecay
|
||||
|
||||
|
||||
class Linear(object):
|
||||
|
@ -386,3 +386,44 @@ class MultiStepDecay(object):
|
|||
end_lr=self.learning_rate,
|
||||
last_epoch=self.last_epoch)
|
||||
return learning_rate
|
||||
|
||||
|
||||
class TwoStepCosine(object):
|
||||
"""
|
||||
Cosine learning rate decay
|
||||
lr = 0.05 * (math.cos(epoch * (math.pi / epochs)) + 1)
|
||||
Args:
|
||||
lr(float): initial learning rate
|
||||
step_each_epoch(int): steps each epoch
|
||||
epochs(int): total training epochs
|
||||
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
learning_rate,
|
||||
step_each_epoch,
|
||||
epochs,
|
||||
warmup_epoch=0,
|
||||
last_epoch=-1,
|
||||
**kwargs):
|
||||
super(TwoStepCosine, self).__init__()
|
||||
self.learning_rate = learning_rate
|
||||
self.T_max1 = step_each_epoch * 200
|
||||
self.T_max2 = step_each_epoch * epochs
|
||||
self.last_epoch = last_epoch
|
||||
self.warmup_epoch = round(warmup_epoch * step_each_epoch)
|
||||
|
||||
def __call__(self):
|
||||
learning_rate = TwoStepCosineDecay(
|
||||
learning_rate=self.learning_rate,
|
||||
T_max1=self.T_max1,
|
||||
T_max2=self.T_max2,
|
||||
last_epoch=self.last_epoch)
|
||||
if self.warmup_epoch > 0:
|
||||
learning_rate = lr.LinearWarmup(
|
||||
learning_rate=learning_rate,
|
||||
warmup_steps=self.warmup_epoch,
|
||||
start_lr=0.0,
|
||||
end_lr=self.learning_rate,
|
||||
last_epoch=self.last_epoch)
|
||||
return learning_rate
|
||||
|
|
|
@ -160,3 +160,63 @@ class OneCycleDecay(LRScheduler):
|
|||
start_step = phase['end_step']
|
||||
|
||||
return computed_lr
|
||||
|
||||
|
||||
class TwoStepCosineDecay(LRScheduler):
|
||||
def __init__(self,
|
||||
learning_rate,
|
||||
T_max1,
|
||||
T_max2,
|
||||
eta_min=0,
|
||||
last_epoch=-1,
|
||||
verbose=False):
|
||||
if not isinstance(T_max1, int):
|
||||
raise TypeError(
|
||||
"The type of 'T_max1' in 'CosineAnnealingDecay' must be 'int', but received %s."
|
||||
% type(T_max1))
|
||||
if not isinstance(T_max2, int):
|
||||
raise TypeError(
|
||||
"The type of 'T_max2' in 'CosineAnnealingDecay' must be 'int', but received %s."
|
||||
% type(T_max2))
|
||||
if not isinstance(eta_min, (float, int)):
|
||||
raise TypeError(
|
||||
"The type of 'eta_min' in 'CosineAnnealingDecay' must be 'float, int', but received %s."
|
||||
% type(eta_min))
|
||||
assert T_max1 > 0 and isinstance(
|
||||
T_max1, int), " 'T_max1' must be a positive integer."
|
||||
assert T_max2 > 0 and isinstance(
|
||||
T_max2, int), " 'T_max1' must be a positive integer."
|
||||
self.T_max1 = T_max1
|
||||
self.T_max2 = T_max2
|
||||
self.eta_min = float(eta_min)
|
||||
super(TwoStepCosineDecay, self).__init__(learning_rate, last_epoch,
|
||||
verbose)
|
||||
|
||||
def get_lr(self):
|
||||
|
||||
if self.last_epoch <= self.T_max1:
|
||||
if self.last_epoch == 0:
|
||||
return self.base_lr
|
||||
elif (self.last_epoch - 1 - self.T_max1) % (2 * self.T_max1) == 0:
|
||||
return self.last_lr + (self.base_lr - self.eta_min) * (
|
||||
1 - math.cos(math.pi / self.T_max1)) / 2
|
||||
|
||||
return (1 + math.cos(math.pi * self.last_epoch / self.T_max1)) / (
|
||||
1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max1)) * (
|
||||
self.last_lr - self.eta_min) + self.eta_min
|
||||
else:
|
||||
if (self.last_epoch - 1 - self.T_max2) % (2 * self.T_max2) == 0:
|
||||
return self.last_lr + (self.base_lr - self.eta_min) * (
|
||||
1 - math.cos(math.pi / self.T_max2)) / 2
|
||||
|
||||
return (1 + math.cos(math.pi * self.last_epoch / self.T_max2)) / (
|
||||
1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max2)) * (
|
||||
self.last_lr - self.eta_min) + self.eta_min
|
||||
|
||||
def _get_closed_form_lr(self):
|
||||
if self.last_epoch <= self.T_max1:
|
||||
return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
|
||||
math.pi * self.last_epoch / self.T_max1)) / 2
|
||||
else:
|
||||
return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
|
||||
math.pi * self.last_epoch / self.T_max2)) / 2
|
||||
|
|
|
@ -37,6 +37,7 @@ from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
|
|||
from .picodet_postprocess import PicoDetPostProcess
|
||||
from .ct_postprocess import CTPostProcess
|
||||
from .drrg_postprocess import DRRGPostprocess
|
||||
from .rec_postprocess import CANLabelDecode
|
||||
|
||||
|
||||
def build_post_process(config, global_config=None):
|
||||
|
@ -51,7 +52,7 @@ def build_post_process(config, global_config=None):
|
|||
'TableMasterLabelDecode', 'SPINLabelDecode',
|
||||
'DistillationSerPostProcess', 'DistillationRePostProcess',
|
||||
'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess',
|
||||
'RFLLabelDecode', 'DRRGPostprocess'
|
||||
'RFLLabelDecode', 'DRRGPostprocess', 'CANLabelDecode'
|
||||
]
|
||||
|
||||
if config['name'] == 'PSEPostProcess':
|
||||
|
|
|
@ -896,3 +896,36 @@ class VLLabelDecode(BaseRecLabelDecode):
|
|||
return text
|
||||
label = self.decode(label)
|
||||
return text, label
|
||||
|
||||
|
||||
class CANLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between latex-symbol and symbol-index """
|
||||
|
||||
def __init__(self, character_dict_path=None, use_space_char=False,
|
||||
**kwargs):
|
||||
super(CANLabelDecode, self).__init__(character_dict_path,
|
||||
use_space_char)
|
||||
|
||||
def decode(self, text_index, preds_prob=None):
|
||||
result_list = []
|
||||
batch_size = len(text_index)
|
||||
for batch_idx in range(batch_size):
|
||||
seq_end = text_index[batch_idx].argmin(0)
|
||||
idx_list = text_index[batch_idx][:seq_end].tolist()
|
||||
symbol_list = [self.character[idx] for idx in idx_list]
|
||||
probs = []
|
||||
if preds_prob is not None:
|
||||
probs = preds_prob[batch_idx][:len(symbol_list)].tolist()
|
||||
|
||||
result_list.append([' '.join(symbol_list), probs])
|
||||
return result_list
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
pred_prob, _, _, _ = preds
|
||||
preds_idx = pred_prob.argmax(axis=2)
|
||||
|
||||
text = self.decode(preds_idx)
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label)
|
||||
return text, label
|
||||
|
|
|
@ -0,0 +1,111 @@
|
|||
eos
|
||||
sos
|
||||
!
|
||||
'
|
||||
(
|
||||
)
|
||||
+
|
||||
,
|
||||
-
|
||||
.
|
||||
/
|
||||
0
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
<
|
||||
=
|
||||
>
|
||||
A
|
||||
B
|
||||
C
|
||||
E
|
||||
F
|
||||
G
|
||||
H
|
||||
I
|
||||
L
|
||||
M
|
||||
N
|
||||
P
|
||||
R
|
||||
S
|
||||
T
|
||||
V
|
||||
X
|
||||
Y
|
||||
[
|
||||
\Delta
|
||||
\alpha
|
||||
\beta
|
||||
\cdot
|
||||
\cdots
|
||||
\cos
|
||||
\div
|
||||
\exists
|
||||
\forall
|
||||
\frac
|
||||
\gamma
|
||||
\geq
|
||||
\in
|
||||
\infty
|
||||
\int
|
||||
\lambda
|
||||
\ldots
|
||||
\leq
|
||||
\lim
|
||||
\log
|
||||
\mu
|
||||
\neq
|
||||
\phi
|
||||
\pi
|
||||
\pm
|
||||
\prime
|
||||
\rightarrow
|
||||
\sigma
|
||||
\sin
|
||||
\sqrt
|
||||
\sum
|
||||
\tan
|
||||
\theta
|
||||
\times
|
||||
]
|
||||
a
|
||||
b
|
||||
c
|
||||
d
|
||||
e
|
||||
f
|
||||
g
|
||||
h
|
||||
i
|
||||
j
|
||||
k
|
||||
l
|
||||
m
|
||||
n
|
||||
o
|
||||
p
|
||||
q
|
||||
r
|
||||
s
|
||||
t
|
||||
u
|
||||
v
|
||||
w
|
||||
x
|
||||
y
|
||||
z
|
||||
\{
|
||||
|
|
||||
\}
|
||||
{
|
||||
}
|
||||
^
|
||||
_
|
|
@ -0,0 +1,122 @@
|
|||
Global:
|
||||
use_gpu: True
|
||||
epoch_num: 240
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec/can/
|
||||
save_epoch_step: 1
|
||||
# evaluation is run every 1105 iterations (1 epoch)(batch_size = 8)
|
||||
eval_batch_step: [0, 1105]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/datasets/crohme_demo/hme_00.jpg
|
||||
# for data or label process
|
||||
character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt
|
||||
max_text_length: 36
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_can.txt
|
||||
|
||||
Optimizer:
|
||||
name: Momentum
|
||||
momentum: 0.9
|
||||
clip_norm_global: 100.0
|
||||
lr:
|
||||
name: TwoStepCosine
|
||||
learning_rate: 0.01
|
||||
warmup_epoch: 1
|
||||
weight_decay: 0.0001
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: CAN
|
||||
in_channels: 1
|
||||
Transform:
|
||||
Backbone:
|
||||
name: DenseNet
|
||||
growthRate: 24
|
||||
reduction: 0.5
|
||||
bottleneck: True
|
||||
use_dropout: True
|
||||
input_channel: 1
|
||||
Head:
|
||||
name: CANHead
|
||||
in_channel: 684
|
||||
out_channel: 111
|
||||
max_text_length: 36
|
||||
ratio: 16
|
||||
attdecoder:
|
||||
is_train: True
|
||||
input_size: 256
|
||||
hidden_size: 256
|
||||
encoder_out_channel: 684
|
||||
dropout: True
|
||||
dropout_ratio: 0.5
|
||||
word_num: 111
|
||||
counting_decoder_out_channel: 111
|
||||
attention:
|
||||
attention_dim: 512
|
||||
word_conv_kernel: 1
|
||||
|
||||
Loss:
|
||||
name: CANLoss
|
||||
|
||||
PostProcess:
|
||||
name: CANLabelDecode
|
||||
|
||||
Metric:
|
||||
name: CANMetric
|
||||
main_indicator: exp_rate
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/CROHME_lite/training/images/
|
||||
label_file_list: ["./train_data/CROHME_lite/training/labels.txt"]
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
channel_first: False
|
||||
- NormalizeImage:
|
||||
mean: [0,0,0]
|
||||
std: [1,1,1]
|
||||
order: 'hwc'
|
||||
- GrayImageChannelFormat:
|
||||
inverse: True
|
||||
- CANLabelEncode:
|
||||
lower: False
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label']
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 8
|
||||
drop_last: False
|
||||
num_workers: 4
|
||||
collate_fn: DyMaskCollator
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/CROHME_lite/evaluation/images/
|
||||
label_file_list: ["./train_data/CROHME_lite/evaluation/labels.txt"]
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
channel_first: False
|
||||
- NormalizeImage:
|
||||
mean: [0,0,0]
|
||||
std: [1,1,1]
|
||||
order: 'hwc'
|
||||
- GrayImageChannelFormat:
|
||||
inverse: True
|
||||
- CANLabelEncode:
|
||||
lower: False
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1
|
||||
num_workers: 4
|
||||
collate_fn: DyMaskCollator
|
|
@ -0,0 +1,53 @@
|
|||
===========================train_params===========================
|
||||
model_name:rec_d28_can
|
||||
python:python3.7
|
||||
gpu_list:0|0,1
|
||||
Global.use_gpu:True|True
|
||||
Global.auto_cast:null
|
||||
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=240
|
||||
Global.save_model_dir:./output/
|
||||
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=8
|
||||
Global.pretrained_model:null
|
||||
train_model_name:latest
|
||||
train_infer_img_dir:./doc/datasets/crohme_demo
|
||||
null:null
|
||||
##
|
||||
trainer:norm_train
|
||||
norm_train:tools/train.py -c test_tipc/configs/rec_d28_can/rec_d28_can.yml -o
|
||||
pact_train:null
|
||||
fpgm_train:null
|
||||
distill_train:null
|
||||
null:null
|
||||
null:null
|
||||
##
|
||||
===========================eval_params===========================
|
||||
eval:tools/eval.py -c test_tipc/configs/rec_d28_can/rec_d28_can.yml -o
|
||||
null:null
|
||||
##
|
||||
===========================infer_params===========================
|
||||
Global.save_inference_dir:./output/
|
||||
Global.checkpoints:
|
||||
norm_export:tools/export_model.py -c test_tipc/configs/rec_d28_can/rec_d28_can.yml -o
|
||||
quant_export:null
|
||||
fpgm_export:null
|
||||
distill_export:null
|
||||
export1:null
|
||||
export2:null
|
||||
##
|
||||
train_model:./inference/rec_d28_can_train/best_accuracy
|
||||
infer_export:tools/export_model.py -c test_tipc/configs/rec_d28_can/rec_d28_can.yml -o
|
||||
infer_quant:False
|
||||
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict/latex_symbol_dict.txt --rec_algorithm="CAN"
|
||||
--use_gpu:True|False
|
||||
--enable_mkldnn:False
|
||||
--cpu_threads:6
|
||||
--rec_batch_num:1
|
||||
--use_tensorrt:False
|
||||
--precision:fp32
|
||||
--rec_model_dir:
|
||||
--image_dir:./doc/datasets/crohme_demo
|
||||
--save_log_path:./test/output/
|
||||
--benchmark:True
|
||||
null:null
|
||||
===========================infer_benchmark_params==========================
|
||||
random_infer_input:[{float32,[1,100,100]}]
|
|
@ -287,6 +287,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
|
|||
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/ct_tipc/total_text_lite2.tar --no-check-certificate
|
||||
cd ./train_data && tar xf total_text_lite2.tar && ln -s total_text_lite2 total_text && cd ../
|
||||
fi
|
||||
if [ ${model_name} == "rec_d28_can" ]; then
|
||||
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/CROHME_lite.tar --no-check-certificate
|
||||
cd ./train_data/ && tar xf CROHME_lite.tar && cd ../
|
||||
fi
|
||||
|
||||
elif [ ${MODE} = "whole_train_whole_infer" ];then
|
||||
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
|
||||
|
|
|
@ -44,6 +44,7 @@
|
|||
| SAST |det_r50_vd_sast_totaltext_v2.0 | 检测 | 支持 | 多机多卡 <br> 混合精度 | - | - |
|
||||
| Rosetta|rec_mv3_none_none_ctc_v2.0 | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
|
||||
| Rosetta|rec_r34_vd_none_none_ctc_v2.0 | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
|
||||
| CAN |rec_d28_can | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
|
||||
| CRNN |rec_mv3_none_bilstm_ctc_v2.0 | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
|
||||
| CRNN |rec_r34_vd_none_bilstm_ctc_v2.0| 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
|
||||
| StarNet|rec_mv3_tps_bilstm_ctc_v2.0 | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
|
||||
|
|
|
@ -74,7 +74,9 @@ def main():
|
|||
config['Architecture']["Head"]['out_channels'] = char_num
|
||||
|
||||
model = build_model(config['Architecture'])
|
||||
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"]
|
||||
extra_input_models = [
|
||||
"SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"
|
||||
]
|
||||
extra_input = False
|
||||
if config['Architecture']['algorithm'] == 'Distillation':
|
||||
for key in config['Architecture']["Models"]:
|
||||
|
@ -83,7 +85,10 @@ def main():
|
|||
else:
|
||||
extra_input = config['Architecture']['algorithm'] in extra_input_models
|
||||
if "model_type" in config['Architecture'].keys():
|
||||
model_type = config['Architecture']['model_type']
|
||||
if config['Architecture']['algorithm'] == 'CAN':
|
||||
model_type = 'can'
|
||||
else:
|
||||
model_type = config['Architecture']['model_type']
|
||||
else:
|
||||
model_type = None
|
||||
|
||||
|
@ -92,7 +97,7 @@ def main():
|
|||
# amp
|
||||
use_amp = config["Global"].get("use_amp", False)
|
||||
amp_level = config["Global"].get("amp_level", 'O2')
|
||||
amp_custom_black_list = config['Global'].get('amp_custom_black_list',[])
|
||||
amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
|
||||
if use_amp:
|
||||
AMP_RELATED_FLAGS_SETTING = {
|
||||
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
|
||||
|
@ -120,7 +125,8 @@ def main():
|
|||
|
||||
# start eval
|
||||
metric = program.eval(model, valid_dataloader, post_process_class,
|
||||
eval_class, model_type, extra_input, scaler, amp_level, amp_custom_black_list)
|
||||
eval_class, model_type, extra_input, scaler,
|
||||
amp_level, amp_custom_black_list)
|
||||
logger.info('metric eval ***************')
|
||||
for k, v in metric.items():
|
||||
logger.info('{}:{}'.format(k, v))
|
||||
|
|
|
@ -123,6 +123,17 @@ def export_single_model(model,
|
|||
]
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
elif arch_config["algorithm"] == "CAN":
|
||||
other_shape = [[
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 1, None, None],
|
||||
dtype="float32"), paddle.static.InputSpec(
|
||||
shape=[None, 1, None, None], dtype="float32"),
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, arch_config['Head']['max_text_length']],
|
||||
dtype="int64")
|
||||
]]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
|
||||
input_spec = [
|
||||
paddle.static.InputSpec(
|
||||
|
|
|
@ -108,6 +108,13 @@ class TextRecognizer(object):
|
|||
}
|
||||
elif self.rec_algorithm == "PREN":
|
||||
postprocess_params = {'name': 'PRENLabelDecode'}
|
||||
elif self.rec_algorithm == "CAN":
|
||||
self.inverse = args.rec_image_inverse
|
||||
postprocess_params = {
|
||||
'name': 'CANLabelDecode',
|
||||
"character_dict_path": args.rec_char_dict_path,
|
||||
"use_space_char": args.use_space_char
|
||||
}
|
||||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
||||
utility.create_predictor(args, 'rec', logger)
|
||||
|
@ -351,6 +358,30 @@ class TextRecognizer(object):
|
|||
|
||||
return resized_image
|
||||
|
||||
def norm_img_can(self, img, image_shape):
|
||||
|
||||
img = cv2.cvtColor(
|
||||
img, cv2.COLOR_BGR2GRAY) # CAN only predict gray scale image
|
||||
|
||||
if self.inverse:
|
||||
img = 255 - img
|
||||
|
||||
if self.rec_image_shape[0] == 1:
|
||||
h, w = img.shape
|
||||
_, imgH, imgW = self.rec_image_shape
|
||||
if h < imgH or w < imgW:
|
||||
padding_h = max(imgH - h, 0)
|
||||
padding_w = max(imgW - w, 0)
|
||||
img_padded = np.pad(img, ((0, padding_h), (0, padding_w)),
|
||||
'constant',
|
||||
constant_values=(255))
|
||||
img = img_padded
|
||||
|
||||
img = np.expand_dims(img, 0) / 255.0 # h,w,c -> c,h,w
|
||||
img = img.astype('float32')
|
||||
|
||||
return img
|
||||
|
||||
def __call__(self, img_list):
|
||||
img_num = len(img_list)
|
||||
# Calculate the aspect ratio of all text bars
|
||||
|
@ -430,6 +461,17 @@ class TextRecognizer(object):
|
|||
word_positions = np.array(range(0, 40)).astype('int64')
|
||||
word_positions = np.expand_dims(word_positions, axis=0)
|
||||
word_positions_list.append(word_positions)
|
||||
elif self.rec_algorithm == "CAN":
|
||||
norm_img = self.norm_img_can(img_list[indices[ino]],
|
||||
max_wh_ratio)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
norm_image_mask = np.ones(norm_img.shape, dtype='float32')
|
||||
word_label = np.ones([1, 36], dtype='int64')
|
||||
norm_img_mask_batch = []
|
||||
word_label_list = []
|
||||
norm_img_mask_batch.append(norm_image_mask)
|
||||
word_label_list.append(word_label)
|
||||
else:
|
||||
norm_img = self.resize_norm_img(img_list[indices[ino]],
|
||||
max_wh_ratio)
|
||||
|
@ -527,6 +569,33 @@ class TextRecognizer(object):
|
|||
if self.benchmark:
|
||||
self.autolog.times.stamp()
|
||||
preds = outputs[0]
|
||||
elif self.rec_algorithm == "CAN":
|
||||
norm_img_mask_batch = np.concatenate(norm_img_mask_batch)
|
||||
word_label_list = np.concatenate(word_label_list)
|
||||
inputs = [norm_img_batch, norm_img_mask_batch, word_label_list]
|
||||
if self.use_onnx:
|
||||
input_dict = {}
|
||||
input_dict[self.input_tensor.name] = norm_img_batch
|
||||
outputs = self.predictor.run(self.output_tensors,
|
||||
input_dict)
|
||||
preds = outputs
|
||||
else:
|
||||
input_names = self.predictor.get_input_names()
|
||||
input_tensor = []
|
||||
for i in range(len(input_names)):
|
||||
input_tensor_i = self.predictor.get_input_handle(
|
||||
input_names[i])
|
||||
input_tensor_i.copy_from_cpu(inputs[i])
|
||||
input_tensor.append(input_tensor_i)
|
||||
self.input_tensor = input_tensor
|
||||
self.predictor.run()
|
||||
outputs = []
|
||||
for output_tensor in self.output_tensors:
|
||||
output = output_tensor.copy_to_cpu()
|
||||
outputs.append(output)
|
||||
if self.benchmark:
|
||||
self.autolog.times.stamp()
|
||||
preds = outputs
|
||||
else:
|
||||
if self.use_onnx:
|
||||
input_dict = {}
|
||||
|
|
|
@ -84,6 +84,7 @@ def init_args():
|
|||
# params for text recognizer
|
||||
parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')
|
||||
parser.add_argument("--rec_model_dir", type=str)
|
||||
parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
|
||||
parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
|
||||
parser.add_argument("--rec_batch_num", type=int, default=6)
|
||||
parser.add_argument("--max_text_length", type=int, default=25)
|
||||
|
|
|
@ -141,6 +141,11 @@ def main():
|
|||
paddle.to_tensor(valid_ratio),
|
||||
paddle.to_tensor(word_positons),
|
||||
]
|
||||
if config['Architecture']['algorithm'] == "CAN":
|
||||
image_mask = paddle.ones(
|
||||
(np.expand_dims(
|
||||
batch[0], axis=0).shape), dtype='float32')
|
||||
label = paddle.ones((1, 36), dtype='int64')
|
||||
images = np.expand_dims(batch[0], axis=0)
|
||||
images = paddle.to_tensor(images)
|
||||
if config['Architecture']['algorithm'] == "SRN":
|
||||
|
@ -149,6 +154,8 @@ def main():
|
|||
preds = model(images, img_metas)
|
||||
elif config['Architecture']['algorithm'] == "RobustScanner":
|
||||
preds = model(images, img_metas)
|
||||
elif config['Architecture']['algorithm'] == "CAN":
|
||||
preds = model([images, image_mask, label])
|
||||
else:
|
||||
preds = model(images)
|
||||
post_result = post_process_class(preds)
|
||||
|
|
|
@ -273,6 +273,8 @@ def train(config,
|
|||
preds = model(images, data=batch[1:])
|
||||
elif model_type in ["kie"]:
|
||||
preds = model(batch)
|
||||
elif algorithm in ['CAN']:
|
||||
preds = model(batch[:3])
|
||||
else:
|
||||
preds = model(images)
|
||||
preds = to_float32(preds)
|
||||
|
@ -286,6 +288,8 @@ def train(config,
|
|||
preds = model(images, data=batch[1:])
|
||||
elif model_type in ["kie", 'sr']:
|
||||
preds = model(batch)
|
||||
elif algorithm in ['CAN']:
|
||||
preds = model(batch[:3])
|
||||
else:
|
||||
preds = model(images)
|
||||
loss = loss_class(preds, batch)
|
||||
|
@ -302,6 +306,9 @@ def train(config,
|
|||
elif model_type in ['table']:
|
||||
post_result = post_process_class(preds, batch)
|
||||
eval_class(post_result, batch)
|
||||
elif algorithm in ['CAN']:
|
||||
model_type = 'can'
|
||||
eval_class(preds[0], batch[2:], epoch_reset=(idx == 0))
|
||||
else:
|
||||
if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
|
||||
]: # for multi head loss
|
||||
|
@ -496,6 +503,8 @@ def eval(model,
|
|||
preds = model(images, data=batch[1:])
|
||||
elif model_type in ["kie"]:
|
||||
preds = model(batch)
|
||||
elif model_type in ['can']:
|
||||
preds = model(batch[:3])
|
||||
elif model_type in ['sr']:
|
||||
preds = model(batch)
|
||||
sr_img = preds["sr_img"]
|
||||
|
@ -508,6 +517,8 @@ def eval(model,
|
|||
preds = model(images, data=batch[1:])
|
||||
elif model_type in ["kie"]:
|
||||
preds = model(batch)
|
||||
elif model_type in ['can']:
|
||||
preds = model(batch[:3])
|
||||
elif model_type in ['sr']:
|
||||
preds = model(batch)
|
||||
sr_img = preds["sr_img"]
|
||||
|
@ -532,6 +543,8 @@ def eval(model,
|
|||
eval_class(post_result, batch_numpy)
|
||||
elif model_type in ['sr']:
|
||||
eval_class(preds, batch_numpy)
|
||||
elif model_type in ['can']:
|
||||
eval_class(preds[0], batch_numpy[2:], epoch_reset=(idx == 0))
|
||||
else:
|
||||
post_result = post_process_class(preds, batch_numpy[1])
|
||||
eval_class(post_result, batch_numpy)
|
||||
|
@ -629,7 +642,7 @@ def preprocess(is_train=False):
|
|||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
|
||||
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
|
||||
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
|
||||
'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG'
|
||||
'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG', 'CAN'
|
||||
]
|
||||
|
||||
if use_xpu:
|
||||
|
|
Loading…
Reference in New Issue