diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index 43c046c5c..638d869a3 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -69,6 +69,7 @@
- [x] [SVTR](./algorithm_rec_svtr.md)
- [x] [ViTSTR](./algorithm_rec_vitstr.md)
- [x] [ABINet](./algorithm_rec_abinet.md)
+- [x] [VisionLAN](./algorithm_rec_visionlan.md)
- [x] [SPIN](./algorithm_rec_spin.md)
- [x] [RobustScanner](./algorithm_rec_robustscanner.md)
@@ -91,6 +92,7 @@
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
+|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon |
diff --git a/doc/doc_ch/algorithm_rec_visionlan.md b/doc/doc_ch/algorithm_rec_visionlan.md
new file mode 100644
index 000000000..0c4fe86e5
--- /dev/null
+++ b/doc/doc_ch/algorithm_rec_visionlan.md
@@ -0,0 +1,154 @@
+# 场景文本识别算法-VisionLAN
+
+- [1. 算法简介](#1)
+- [2. 环境配置](#2)
+- [3. 模型训练、评估、预测](#3)
+ - [3.1 训练](#3-1)
+ - [3.2 评估](#3-2)
+ - [3.3 预测](#3-3)
+- [4. 推理部署](#4)
+ - [4.1 Python推理](#4-1)
+ - [4.2 C++推理](#4-2)
+ - [4.3 Serving服务化部署](#4-3)
+ - [4.4 更多推理部署](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. 算法简介
+
+论文信息:
+> [From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network](https://arxiv.org/abs/2108.09661)
+> Yuxin Wang, Hongtao Xie, Shancheng Fang, Jing Wang, Shenggao Zhu, Yongdong Zhang
+> ICCV, 2021
+
+
+
+`VisionLAN`使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下:
+
+|模型|骨干网络|配置文件|Acc|下载链接|
+| --- | --- | --- | --- | --- |
+|VisionLAN|ResNet45|[rec_r45_visionlan.yml](../../configs/rec/rec_r45_visionlan.yml)|90.3%|[预训练、训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar)|
+
+
+## 2. 环境配置
+请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
+
+
+
+## 3. 模型训练、评估、预测
+
+
+### 3.1 模型训练
+
+请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`VisionLAN`识别模型时需要**更换配置文件**为`VisionLAN`的[配置文件](../../configs/rec/rec_r45_visionlan.yml)。
+
+#### 启动训练
+
+
+具体地,在完成数据准备后,便可以启动训练,训练命令如下:
+```shell
+#单卡训练(训练周期长,不建议)
+python3 tools/train.py -c configs/rec/rec_r45_visionlan.yml
+
+#多卡训练,通过--gpus参数指定卡号
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r45_visionlan.yml
+```
+
+
+### 3.2 评估
+
+可下载已训练完成的[模型文件](#model),使用如下命令进行评估:
+
+```shell
+# 注意将pretrained_model的路径设置为本地路径。
+python3 tools/eval.py -c configs/rec/rec_r45_visionlan.yml -o Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy
+```
+
+
+### 3.3 预测
+
+使用如下命令进行单张图片预测:
+```shell
+# 注意将pretrained_model的路径设置为本地路径。
+python3 tools/infer_rec.py -c configs/rec/rec_r45_visionlan.yml -o Global.infer_img='./doc/imgs_words/en/word_2.png' Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy
+# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
+```
+
+
+
+## 4. 推理部署
+
+
+### 4.1 Python推理
+首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar)),可以使用如下命令进行转换:
+
+```shell
+# 注意将pretrained_model的路径设置为本地路径。
+python3 tools/export_model.py -c configs/rec/rec_r45_visionlan.yml -o Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy Global.save_inference_dir=./inference/rec_r45_visionlan/
+```
+**注意:**
+- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否是所需要的字典文件。
+- 如果您修改了训练时的输入大小,请修改`tools/export_model.py`文件中的对应VisionLAN的`infer_shape`。
+
+转换成功后,在目录下有三个文件:
+```
+./inference/rec_r45_visionlan/
+ ├── inference.pdiparams # 识别inference模型的参数文件
+ ├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
+ └── inference.pdmodel # 识别inference模型的program文件
+```
+
+执行如下命令进行模型推理:
+
+```shell
+python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words/en/word_2.png' --rec_model_dir='./inference/rec_r45_visionlan/' --rec_algorithm='VisionLAN' --rec_image_shape='3,64,256' --rec_char_dict_path='./ppocr/utils/dict36.txt'
+# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/imgs_words_en/'。
+```
+
+
+
+执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下:
+结果如下:
+```shell
+Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.97076982)
+```
+
+**注意**:
+
+- 训练上述模型采用的图像分辨率是[3,64,256],需要通过参数`rec_image_shape`设置为您训练时的识别图像形状。
+- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。
+- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中VisionLAN的预处理为您的预处理方法。
+
+
+
+## 5. FAQ
+
+1. MJSynth和SynthText两种数据集来自于[VisionLAN源repo](https://github.com/wangyuxin87/VisionLAN) 。
+2. 我们使用VisionLAN作者提供的预训练模型进行finetune训练。
+
+## 引用
+
+```bibtex
+@inproceedings{wang2021two,
+ title={From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network},
+ author={Wang, Yuxin and Xie, Hongtao and Fang, Shancheng and Wang, Jing and Zhu, Shenggao and Zhang, Yongdong},
+ booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
+ pages={14194--14203},
+ year={2021}
+}
+```
diff --git a/doc/doc_ch/detection.md b/doc/doc_ch/detection.md
index 2cf073221..eba521350 100644
--- a/doc/doc_ch/detection.md
+++ b/doc/doc_ch/detection.md
@@ -65,7 +65,7 @@ python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/
```
-上述指令中,通过-c 选择训练使用configs/det/det_db_mv3.yml配置文件。
+上述指令中,通过-c 选择训练使用configs/det/det_mv3_db.yml配置文件。
有关配置文件的详细解释,请参考[链接](./config.md)。
您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001
diff --git a/doc/doc_en/PP-OCRv3_introduction_en.md b/doc/doc_en/PP-OCRv3_introduction_en.md
index 481e0b817..815ad9b0e 100644
--- a/doc/doc_en/PP-OCRv3_introduction_en.md
+++ b/doc/doc_en/PP-OCRv3_introduction_en.md
@@ -55,10 +55,11 @@ The ablation experiments are as follows:
|ID|Strategy|Model Size|Hmean|The Inference Time(cpu + mkldnn)|
|-|-|-|-|-|
-|baseline teacher|DB-R50|99M|83.5%|260ms|
+|baseline teacher|PP-OCR server|49M|83.2%|171ms|
|teacher1|DB-R50-LK-PAN|124M|85.0%|396ms|
|teacher2|DB-R50-LK-PAN-DML|124M|86.0%|396ms|
|baseline student|PP-OCRv2|3M|83.2%|117ms|
+|student0|DB-MV3-RSE-FPN|3.6M|84.5%|124ms|
|student1|DB-MV3-CML(teacher2)|3M|84.3%|117ms|
|student2|DB-MV3-RSE-FPN-CML(teacher2)|3.6M|85.4%|124ms|
@@ -199,7 +200,7 @@ UDML (Unified-Deep Mutual Learning) is a strategy proposed in PP-OCRv2 which is
**(6)UIM:Unlabeled Images Mining**
-UIM (Unlabeled Images Mining) is a very simple unlabeled data mining strategy. The main idea is to use a high-precision text recognition model to predict unlabeled images to obtain pseudo-labels, and select samples with high prediction confidence as training data for training lightweight models. Using this strategy, the accuracy of the recognition model is further improved to 79.4% (+1%).
+UIM (Unlabeled Images Mining) is a very simple unlabeled data mining strategy. The main idea is to use a high-precision text recognition model to predict unlabeled images to obtain pseudo-labels, and select samples with high prediction confidence as training data for training lightweight models. Using this strategy, the accuracy of the recognition model is further improved to 79.4% (+1%). In practice, we use the full data set to train the high-precision SVTR_Tiny model (acc=82.5%) for data mining. [SVTR_Tiny model download and tutorial](../../applications/高精度中文识别模型.md).

diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index dd819790d..3412ccbf7 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -68,6 +68,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
- [x] [SVTR](./algorithm_rec_svtr_en.md)
- [x] [ViTSTR](./algorithm_rec_vitstr_en.md)
- [x] [ABINet](./algorithm_rec_abinet_en.md)
+- [x] [VisionLAN](./algorithm_rec_visionlan_en.md)
- [x] [SPIN](./algorithm_rec_spin_en.md)
- [x] [RobustScanner](./algorithm_rec_robustscanner_en.md)
@@ -90,6 +91,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
+|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon |
diff --git a/doc/doc_en/algorithm_rec_visionlan_en.md b/doc/doc_en/algorithm_rec_visionlan_en.md
new file mode 100644
index 000000000..ebd02d52f
--- /dev/null
+++ b/doc/doc_en/algorithm_rec_visionlan_en.md
@@ -0,0 +1,135 @@
+# VisionLAN
+
+- [1. Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+ - [3.1 Training](#3-1)
+ - [3.2 Evaluation](#3-2)
+ - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+ - [4.1 Python Inference](#4-1)
+ - [4.2 C++ Inference](#4-2)
+ - [4.3 Serving](#4-3)
+ - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. Introduction
+
+Paper:
+> [From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network](https://arxiv.org/abs/2108.09661)
+> Yuxin Wang, Hongtao Xie, Shancheng Fang, Jing Wang, Shenggao Zhu, Yongdong Zhang
+> ICCV, 2021
+
+Using MJSynth and SynthText two text recognition datasets for training, and evaluating on IIIT, SVT, IC13, IC15, SVTP, CUTE datasets, the algorithm reproduction effect is as follows:
+
+|Model|Backbone|config|Acc|Download link|
+| --- | --- | --- | --- | --- |
+|VisionLAN|ResNet45|[rec_r45_visionlan.yml](../../configs/rec/rec_r45_visionlan.yml)|90.3%|[预训练、训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar)|
+
+
+## 2. Environment
+Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
+
+
+
+## 3. Model Training / Evaluation / Prediction
+
+Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
+
+Training:
+
+Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
+
+```
+#Single GPU training (long training period, not recommended)
+python3 tools/train.py -c configs/rec/rec_r45_visionlan.yml
+
+#Multi GPU training, specify the gpu number through the --gpus parameter
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r45_visionlan.yml
+```
+
+Evaluation:
+
+```
+# GPU evaluation
+python3 tools/eval.py -c configs/rec/rec_r45_visionlan.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+Prediction:
+
+```
+# The configuration file used for prediction must match the training
+python3 tools/infer_rec.py -c configs/rec/rec_r45_visionlan.yml -o Global.infer_img='./doc/imgs_words/en/word_2.png' Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy
+```
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+First, the model saved during the VisionLAN text recognition training process is converted into an inference model. ( [Model download link](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar)) ), you can use the following command to convert:
+
+```
+python3 tools/export_model.py -c configs/rec/rec_r45_visionlan.yml -o Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy Global.save_inference_dir=./inference/rec_r45_visionlan/
+```
+
+**Note:**
+- If you are training the model on your own dataset and have modified the dictionary file, please pay attention to modify the `character_dict_path` in the configuration file to the modified dictionary file.
+- If you modified the input size during training, please modify the `infer_shape` corresponding to VisionLAN in the `tools/export_model.py` file.
+
+After the conversion is successful, there are three files in the directory:
+```
+./inference/rec_r45_visionlan/
+ ├── inference.pdiparams
+ ├── inference.pdiparams.info
+ └── inference.pdmodel
+```
+
+
+For VisionLAN text recognition model inference, the following commands can be executed:
+
+```
+python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words/en/word_2.png' --rec_model_dir='./inference/rec_r45_visionlan/' --rec_algorithm='VisionLAN' --rec_image_shape='3,64,256' --rec_char_dict_path='./ppocr/utils/dict36.txt'
+```
+
+
+
+After executing the command, the prediction result (recognized text and score) of the image above is printed to the screen, an example is as follows:
+The result is as follows:
+```shell
+Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.97076982)
+```
+
+
+### 4.2 C++ Inference
+
+Not supported
+
+
+### 4.3 Serving
+
+Not supported
+
+
+### 4.4 More
+
+Not supported
+
+
+## 5. FAQ
+
+1. Note that the MJSynth and SynthText datasets come from [VisionLAN repo](https://github.com/wangyuxin87/VisionLAN).
+2. We use the pre-trained model provided by the VisionLAN authors for finetune training.
+
+## Citation
+
+```bibtex
+@inproceedings{wang2021two,
+ title={From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network},
+ author={Wang, Yuxin and Xie, Hongtao and Fang, Shancheng and Wang, Jing and Zhu, Shenggao and Zhang, Yongdong},
+ booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
+ pages={14194--14203},
+ year={2021}
+}
+```
diff --git a/doc/doc_en/detection_en.md b/doc/doc_en/detection_en.md
index f85bf585c..c215e1a46 100644
--- a/doc/doc_en/detection_en.md
+++ b/doc/doc_en/detection_en.md
@@ -51,7 +51,7 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml \
-o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
```
-In the above instruction, use `-c` to select the training to use the `configs/det/det_db_mv3.yml` configuration file.
+In the above instruction, use `-c` to select the training to use the `configs/det/det_mv3_db.yml` configuration file.
For a detailed explanation of the configuration file, please refer to [config](./config_en.md).
You can also use `-o` to change the training parameters without modifying the yml file. For example, adjust the training learning rate to 0.0001
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 30d422b61..0da940a3a 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -25,8 +25,9 @@ from .make_pse_gt import MakePseGt
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
- SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
- ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, SPINRecResizeImg, RobustScannerRecResizeImg
+ SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
+ ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg
+
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 97539faf2..1656c6952 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -23,7 +23,10 @@ import string
from shapely.geometry import LineString, Point, Polygon
import json
import copy
+from random import sample
+
from ppocr.utils.logging import get_logger
+from ppocr.data.imaug.vqa.augment import order_by_tbyx
class ClsLabelEncode(object):
@@ -97,12 +100,13 @@ class BaseRecLabelEncode(object):
def __init__(self,
max_text_length,
character_dict_path=None,
- use_space_char=False):
+ use_space_char=False,
+ lower=False):
self.max_text_len = max_text_length
self.beg_str = "sos"
self.end_str = "eos"
- self.lower = False
+ self.lower = lower
if character_dict_path is None:
logger = get_logger()
@@ -870,6 +874,7 @@ class VQATokenLabelEncode(object):
add_special_ids=False,
algorithm='LayoutXLM',
use_textline_bbox_info=True,
+ order_method=None,
infer_mode=False,
ocr_engine=None,
**kwargs):
@@ -899,6 +904,8 @@ class VQATokenLabelEncode(object):
self.infer_mode = infer_mode
self.ocr_engine = ocr_engine
self.use_textline_bbox_info = use_textline_bbox_info
+ self.order_method = order_method
+ assert self.order_method in [None, "tb-yx"]
def split_bbox(self, bbox, text, tokenizer):
words = text.split()
@@ -938,6 +945,14 @@ class VQATokenLabelEncode(object):
# load bbox and label info
ocr_info = self._load_ocr_info(data)
+ for idx in range(len(ocr_info)):
+ if "bbox" not in ocr_info[idx]:
+ ocr_info[idx]["bbox"] = self.trans_poly_to_bbox(ocr_info[idx][
+ "points"])
+
+ if self.order_method == "tb-yx":
+ ocr_info = order_by_tbyx(ocr_info)
+
# for re
train_re = self.contains_re and not self.infer_mode
if train_re:
@@ -977,7 +992,10 @@ class VQATokenLabelEncode(object):
info["bbox"] = self.trans_poly_to_bbox(info["points"])
encode_res = self.tokenizer.encode(
- text, pad_to_max_seq_len=False, return_attention_mask=True)
+ text,
+ pad_to_max_seq_len=False,
+ return_attention_mask=True,
+ return_token_type_ids=True)
if not self.add_special_ids:
# TODO: use tok.all_special_ids to remove
@@ -1049,10 +1067,10 @@ class VQATokenLabelEncode(object):
return data
def trans_poly_to_bbox(self, poly):
- x1 = np.min([p[0] for p in poly])
- x2 = np.max([p[0] for p in poly])
- y1 = np.min([p[1] for p in poly])
- y2 = np.max([p[1] for p in poly])
+ x1 = int(np.min([p[0] for p in poly]))
+ x2 = int(np.max([p[0] for p in poly]))
+ y1 = int(np.min([p[1] for p in poly]))
+ y2 = int(np.max([p[1] for p in poly]))
return [x1, y1, x2, y2]
def _load_ocr_info(self, data):
@@ -1217,6 +1235,7 @@ class ABINetLabelEncode(BaseRecLabelEncode):
dict_character = [''] + dict_character
return dict_character
+
class SPINLabelEncode(AttnLabelEncode):
""" Convert between text-label and text-index """
@@ -1229,6 +1248,7 @@ class SPINLabelEncode(AttnLabelEncode):
super(SPINLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
self.lower = lower
+
def add_special_char(self, dict_character):
self.beg_str = "sos"
self.end_str = "eos"
@@ -1248,4 +1268,68 @@ class SPINLabelEncode(AttnLabelEncode):
padded_text[:len(target)] = target
data['label'] = np.array(padded_text)
- return data
\ No newline at end of file
+ return data
+
+
+class VLLabelEncode(BaseRecLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ lower=True,
+ **kwargs):
+ super(VLLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char, lower)
+ self.character = self.character[10:] + self.character[
+ 1:10] + [self.character[0]]
+ self.dict = {}
+ for i, char in enumerate(self.character):
+ self.dict[char] = i
+
+ def __call__(self, data):
+ text = data['label'] # original string
+ # generate occluded text
+ len_str = len(text)
+ if len_str <= 0:
+ return None
+ change_num = 1
+ order = list(range(len_str))
+ change_id = sample(order, change_num)[0]
+ label_sub = text[change_id]
+ if change_id == (len_str - 1):
+ label_res = text[:change_id]
+ elif change_id == 0:
+ label_res = text[1:]
+ else:
+ label_res = text[:change_id] + text[change_id + 1:]
+
+ data['label_res'] = label_res # remaining string
+ data['label_sub'] = label_sub # occluded character
+ data['label_id'] = change_id # character index
+ # encode label
+ text = self.encode(text)
+ if text is None:
+ return None
+ text = [i + 1 for i in text]
+ data['length'] = np.array(len(text))
+ text = text + [0] * (self.max_text_len - len(text))
+ data['label'] = np.array(text)
+ label_res = self.encode(label_res)
+ label_sub = self.encode(label_sub)
+ if label_res is None:
+ label_res = []
+ else:
+ label_res = [i + 1 for i in label_res]
+ if label_sub is None:
+ label_sub = []
+ else:
+ label_sub = [i + 1 for i in label_sub]
+ data['length_res'] = np.array(len(label_res))
+ data['length_sub'] = np.array(len(label_sub))
+ label_res = label_res + [0] * (self.max_text_len - len(label_res))
+ label_sub = label_sub + [0] * (self.max_text_len - len(label_sub))
+ data['label_res'] = np.array(label_res)
+ data['label_sub'] = np.array(label_sub)
+ return data
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 8b2309f44..a5e0de849 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -205,6 +205,38 @@ class RecResizeImg(object):
return data
+class VLRecResizeImg(object):
+ def __init__(self,
+ image_shape,
+ infer_mode=False,
+ character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
+ padding=True,
+ **kwargs):
+ self.image_shape = image_shape
+ self.infer_mode = infer_mode
+ self.character_dict_path = character_dict_path
+ self.padding = padding
+
+ def __call__(self, data):
+ img = data['image']
+
+ imgC, imgH, imgW = self.image_shape
+ resized_image = cv2.resize(
+ img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+ resized_w = imgW
+ resized_image = resized_image.astype('float32')
+ if self.image_shape[0] == 1:
+ resized_image = resized_image / 255
+ norm_img = resized_image[np.newaxis, :]
+ else:
+ norm_img = resized_image.transpose((2, 0, 1)) / 255
+ valid_ratio = min(1.0, float(resized_w / imgW))
+
+ data['image'] = norm_img
+ data['valid_ratio'] = valid_ratio
+ return data
+
+
class SRNRecResizeImg(object):
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
self.image_shape = image_shape
@@ -259,6 +291,7 @@ class PRENResizeImg(object):
data['image'] = resized_img.astype(np.float32)
return data
+
class SPINRecResizeImg(object):
def __init__(self,
image_shape,
@@ -267,7 +300,7 @@ class SPINRecResizeImg(object):
std=(127.5, 127.5, 127.5),
**kwargs):
self.image_shape = image_shape
-
+
self.mean = np.array(mean, dtype=np.float32)
self.std = np.array(std, dtype=np.float32)
self.interpolation = interpolation
@@ -303,6 +336,7 @@ class SPINRecResizeImg(object):
data['image'] = img
return data
+
class GrayRecResizeImg(object):
def __init__(self,
image_shape,
diff --git a/ppocr/data/imaug/vqa/__init__.py b/ppocr/data/imaug/vqa/__init__.py
index bde175115..34189bcef 100644
--- a/ppocr/data/imaug/vqa/__init__.py
+++ b/ppocr/data/imaug/vqa/__init__.py
@@ -13,12 +13,10 @@
# limitations under the License.
from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
-from .augment import DistortBBox
__all__ = [
'VQATokenPad',
'VQASerTokenChunk',
'VQAReTokenChunk',
'VQAReTokenRelation',
- 'DistortBBox',
]
diff --git a/ppocr/data/imaug/vqa/augment.py b/ppocr/data/imaug/vqa/augment.py
index fcdc9685e..b95fcdf0f 100644
--- a/ppocr/data/imaug/vqa/augment.py
+++ b/ppocr/data/imaug/vqa/augment.py
@@ -16,22 +16,18 @@ import os
import sys
import numpy as np
import random
+from copy import deepcopy
-class DistortBBox:
- def __init__(self, prob=0.5, max_scale=1, **kwargs):
- """Random distort bbox
- """
- self.prob = prob
- self.max_scale = max_scale
-
- def __call__(self, data):
- if random.random() > self.prob:
- return data
- bbox = np.array(data['bbox'])
- rnd_scale = (np.random.rand(*bbox.shape) - 0.5) * 2 * self.max_scale
- bbox = np.round(bbox + rnd_scale).astype(bbox.dtype)
- data['bbox'] = np.clip(data['bbox'], 0, 1000)
- data['bbox'] = bbox.tolist()
- sys.stdout.flush()
- return data
+def order_by_tbyx(ocr_info):
+ res = sorted(ocr_info, key=lambda r: (r["bbox"][1], r["bbox"][0]))
+ for i in range(len(res) - 1):
+ for j in range(i, 0, -1):
+ if abs(res[j + 1]["bbox"][1] - res[j]["bbox"][1]) < 20 and \
+ (res[j + 1]["bbox"][0] < res[j]["bbox"][0]):
+ tmp = deepcopy(res[j])
+ res[j] = deepcopy(res[j + 1])
+ res[j + 1] = deepcopy(tmp)
+ else:
+ break
+ return res
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 30120ac56..bb82c7e00 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -35,6 +35,7 @@ from .rec_sar_loss import SARLoss
from .rec_aster_loss import AsterLoss
from .rec_pren_loss import PRENLoss
from .rec_multi_loss import MultiLoss
+from .rec_vl_loss import VLLoss
from .rec_spin_att_loss import SPINAttentionLoss
# cls loss
@@ -63,7 +64,7 @@ def build_loss(config):
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
- 'TableMasterLoss', 'SPINAttentionLoss'
+ 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py
index 74490791c..da9faa08b 100644
--- a/ppocr/losses/basic_loss.py
+++ b/ppocr/losses/basic_loss.py
@@ -63,18 +63,21 @@ class KLJSLoss(object):
def __call__(self, p1, p2, reduction="mean"):
if self.mode.lower() == 'kl':
- loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
+ loss = paddle.multiply(p2,
+ paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
loss += paddle.multiply(
- p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
+ p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
loss *= 0.5
elif self.mode.lower() == "js":
- loss = paddle.multiply(p2, paddle.log((2*p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
+ loss = paddle.multiply(
+ p2, paddle.log((2 * p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
loss += paddle.multiply(
- p1, paddle.log((2*p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
+ p1, paddle.log((2 * p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
loss *= 0.5
else:
- raise ValueError("The mode.lower() if KLJSLoss should be one of ['kl', 'js']")
-
+ raise ValueError(
+ "The mode.lower() if KLJSLoss should be one of ['kl', 'js']")
+
if reduction == "mean":
loss = paddle.mean(loss, axis=[1, 2])
elif reduction == "none" or reduction is None:
@@ -154,7 +157,9 @@ class LossFromOutput(nn.Layer):
self.reduction = reduction
def forward(self, predicts, batch):
- loss = predicts[self.key]
+ loss = predicts
+ if self.key is not None and isinstance(predicts, dict):
+ loss = loss[self.key]
if self.reduction == 'mean':
loss = paddle.mean(loss)
elif self.reduction == 'sum':
diff --git a/ppocr/losses/combined_loss.py b/ppocr/losses/combined_loss.py
index f4cdee8f9..8d697d544 100644
--- a/ppocr/losses/combined_loss.py
+++ b/ppocr/losses/combined_loss.py
@@ -24,6 +24,9 @@ from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationSARLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
+from .distillation_loss import DistillationVQASerTokenLayoutLMLoss, DistillationSERDMLLoss
+from .distillation_loss import DistillationLossFromOutput
+from .distillation_loss import DistillationVQADistanceLoss
class CombinedLoss(nn.Layer):
diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py
index 565b066d1..87fed6235 100644
--- a/ppocr/losses/distillation_loss.py
+++ b/ppocr/losses/distillation_loss.py
@@ -21,8 +21,10 @@ from .rec_ctc_loss import CTCLoss
from .rec_sar_loss import SARLoss
from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss
+from .basic_loss import LossFromOutput
from .det_db_loss import DBLoss
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
+from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
def _sum_loss(loss_dict):
@@ -322,3 +324,133 @@ class DistillationDistanceLoss(DistanceLoss):
loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
idx)] = loss
return loss_dict
+
+
+class DistillationVQASerTokenLayoutLMLoss(VQASerTokenLayoutLMLoss):
+ def __init__(self,
+ num_classes,
+ model_name_list=[],
+ key=None,
+ name="loss_ser"):
+ super().__init__(num_classes=num_classes)
+ self.model_name_list = model_name_list
+ self.key = key
+ self.name = name
+
+ def forward(self, predicts, batch):
+ loss_dict = dict()
+ for idx, model_name in enumerate(self.model_name_list):
+ out = predicts[model_name]
+ if self.key is not None:
+ out = out[self.key]
+ loss = super().forward(out, batch)
+ loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"]
+ return loss_dict
+
+
+class DistillationLossFromOutput(LossFromOutput):
+ def __init__(self,
+ reduction="none",
+ model_name_list=[],
+ dist_key=None,
+ key="loss",
+ name="loss_re"):
+ super().__init__(key=key, reduction=reduction)
+ self.model_name_list = model_name_list
+ self.name = name
+ self.dist_key = dist_key
+
+ def forward(self, predicts, batch):
+ loss_dict = dict()
+ for idx, model_name in enumerate(self.model_name_list):
+ out = predicts[model_name]
+ if self.dist_key is not None:
+ out = out[self.dist_key]
+ loss = super().forward(out, batch)
+ loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"]
+ return loss_dict
+
+
+class DistillationSERDMLLoss(DMLLoss):
+ """
+ """
+
+ def __init__(self,
+ act="softmax",
+ use_log=True,
+ num_classes=7,
+ model_name_pairs=[],
+ key=None,
+ name="loss_dml_ser"):
+ super().__init__(act=act, use_log=use_log)
+ assert isinstance(model_name_pairs, list)
+ self.key = key
+ self.name = name
+ self.num_classes = num_classes
+ self.model_name_pairs = model_name_pairs
+
+ def forward(self, predicts, batch):
+ loss_dict = dict()
+ for idx, pair in enumerate(self.model_name_pairs):
+ out1 = predicts[pair[0]]
+ out2 = predicts[pair[1]]
+ if self.key is not None:
+ out1 = out1[self.key]
+ out2 = out2[self.key]
+ out1 = out1.reshape([-1, out1.shape[-1]])
+ out2 = out2.reshape([-1, out2.shape[-1]])
+
+ attention_mask = batch[2]
+ if attention_mask is not None:
+ active_output = attention_mask.reshape([-1, ]) == 1
+ out1 = out1[active_output]
+ out2 = out2[active_output]
+
+ loss_dict["{}_{}".format(self.name, idx)] = super().forward(out1,
+ out2)
+
+ return loss_dict
+
+
+class DistillationVQADistanceLoss(DistanceLoss):
+ def __init__(self,
+ mode="l2",
+ model_name_pairs=[],
+ key=None,
+ name="loss_distance",
+ **kargs):
+ super().__init__(mode=mode, **kargs)
+ assert isinstance(model_name_pairs, list)
+ self.key = key
+ self.model_name_pairs = model_name_pairs
+ self.name = name + "_l2"
+
+ def forward(self, predicts, batch):
+ loss_dict = dict()
+ for idx, pair in enumerate(self.model_name_pairs):
+ out1 = predicts[pair[0]]
+ out2 = predicts[pair[1]]
+ attention_mask = batch[2]
+ if self.key is not None:
+ out1 = out1[self.key]
+ out2 = out2[self.key]
+ if attention_mask is not None:
+ max_len = attention_mask.shape[-1]
+ out1 = out1[:, :max_len]
+ out2 = out2[:, :max_len]
+ out1 = out1.reshape([-1, out1.shape[-1]])
+ out2 = out2.reshape([-1, out2.shape[-1]])
+ if attention_mask is not None:
+ active_output = attention_mask.reshape([-1, ]) == 1
+ out1 = out1[active_output]
+ out2 = out2[active_output]
+
+ loss = super().forward(out1, out2)
+ if isinstance(loss, dict):
+ for key in loss:
+ loss_dict["{}_{}nohu_{}".format(self.name, key,
+ idx)] = loss[key]
+ else:
+ loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
+ idx)] = loss
+ return loss_dict
diff --git a/ppocr/losses/rec_vl_loss.py b/ppocr/losses/rec_vl_loss.py
new file mode 100644
index 000000000..5cd87c709
--- /dev/null
+++ b/ppocr/losses/rec_vl_loss.py
@@ -0,0 +1,70 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/wangyuxin87/VisionLAN
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+
+class VLLoss(nn.Layer):
+ def __init__(self, mode='LF_1', weight_res=0.5, weight_mas=0.5, **kwargs):
+ super(VLLoss, self).__init__()
+ self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean")
+ assert mode in ['LF_1', 'LF_2', 'LA']
+ self.mode = mode
+ self.weight_res = weight_res
+ self.weight_mas = weight_mas
+
+ def flatten_label(self, target):
+ label_flatten = []
+ label_length = []
+ for i in range(0, target.shape[0]):
+ cur_label = target[i].tolist()
+ label_flatten += cur_label[:cur_label.index(0) + 1]
+ label_length.append(cur_label.index(0) + 1)
+ label_flatten = paddle.to_tensor(label_flatten, dtype='int64')
+ label_length = paddle.to_tensor(label_length, dtype='int32')
+ return (label_flatten, label_length)
+
+ def _flatten(self, sources, lengths):
+ return paddle.concat([t[:l] for t, l in zip(sources, lengths)])
+
+ def forward(self, predicts, batch):
+ text_pre = predicts[0]
+ target = batch[1].astype('int64')
+ label_flatten, length = self.flatten_label(target)
+ text_pre = self._flatten(text_pre, length)
+ if self.mode == 'LF_1':
+ loss = self.loss_func(text_pre, label_flatten)
+ else:
+ text_rem = predicts[1]
+ text_mas = predicts[2]
+ target_res = batch[2].astype('int64')
+ target_sub = batch[3].astype('int64')
+ label_flatten_res, length_res = self.flatten_label(target_res)
+ label_flatten_sub, length_sub = self.flatten_label(target_sub)
+ text_rem = self._flatten(text_rem, length_res)
+ text_mas = self._flatten(text_mas, length_sub)
+ loss_ori = self.loss_func(text_pre, label_flatten)
+ loss_res = self.loss_func(text_rem, label_flatten_res)
+ loss_mas = self.loss_func(text_mas, label_flatten_sub)
+ loss = loss_ori + loss_res * self.weight_res + loss_mas * self.weight_mas
+ return {'loss': loss}
diff --git a/ppocr/losses/vqa_token_layoutlm_loss.py b/ppocr/losses/vqa_token_layoutlm_loss.py
index f9cd46347..5d564c0e2 100755
--- a/ppocr/losses/vqa_token_layoutlm_loss.py
+++ b/ppocr/losses/vqa_token_layoutlm_loss.py
@@ -17,26 +17,30 @@ from __future__ import division
from __future__ import print_function
from paddle import nn
+from ppocr.losses.basic_loss import DMLLoss
class VQASerTokenLayoutLMLoss(nn.Layer):
- def __init__(self, num_classes):
+ def __init__(self, num_classes, key=None):
super().__init__()
self.loss_class = nn.CrossEntropyLoss()
self.num_classes = num_classes
self.ignore_index = self.loss_class.ignore_index
+ self.key = key
def forward(self, predicts, batch):
+ if isinstance(predicts, dict) and self.key is not None:
+ predicts = predicts[self.key]
labels = batch[5]
attention_mask = batch[2]
if attention_mask is not None:
active_loss = attention_mask.reshape([-1, ]) == 1
- active_outputs = predicts.reshape(
+ active_output = predicts.reshape(
[-1, self.num_classes])[active_loss]
- active_labels = labels.reshape([-1, ])[active_loss]
- loss = self.loss_class(active_outputs, active_labels)
+ active_label = labels.reshape([-1, ])[active_loss]
+ loss = self.loss_class(active_output, active_label)
else:
loss = self.loss_class(
predicts.reshape([-1, self.num_classes]),
labels.reshape([-1, ]))
- return {'loss': loss}
+ return {'loss': loss}
\ No newline at end of file
diff --git a/ppocr/metrics/distillation_metric.py b/ppocr/metrics/distillation_metric.py
index c440cebdd..e2cbc4dc0 100644
--- a/ppocr/metrics/distillation_metric.py
+++ b/ppocr/metrics/distillation_metric.py
@@ -19,6 +19,8 @@ from .rec_metric import RecMetric
from .det_metric import DetMetric
from .e2e_metric import E2EMetric
from .cls_metric import ClsMetric
+from .vqa_token_ser_metric import VQASerTokenMetric
+from .vqa_token_re_metric import VQAReTokenMetric
class DistillationMetric(object):
diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py
index c6b50d488..ed2a909cb 100644
--- a/ppocr/modeling/architectures/base_model.py
+++ b/ppocr/modeling/architectures/base_model.py
@@ -73,28 +73,40 @@ class BaseModel(nn.Layer):
self.return_all_feats = config.get("return_all_feats", False)
def forward(self, x, data=None):
+
y = dict()
if self.use_transform:
x = self.transform(x)
x = self.backbone(x)
- y["backbone_out"] = x
- if self.use_neck:
- x = self.neck(x)
- y["neck_out"] = x
- if self.use_head:
- x = self.head(x, targets=data)
- # for multi head, save ctc neck out for udml
- if isinstance(x, dict) and 'ctc_neck' in x.keys():
- y["neck_out"] = x["ctc_neck"]
- y["head_out"] = x
- elif isinstance(x, dict):
+ if isinstance(x, dict):
y.update(x)
else:
- y["head_out"] = x
+ y["backbone_out"] = x
+ final_name = "backbone_out"
+ if self.use_neck:
+ x = self.neck(x)
+ if isinstance(x, dict):
+ y.update(x)
+ else:
+ y["neck_out"] = x
+ final_name = "neck_out"
+ if self.use_head:
+ x = self.head(x, targets=data)
+ # for multi head, save ctc neck out for udml
+ if isinstance(x, dict) and 'ctc_neck' in x.keys():
+ y["neck_out"] = x["ctc_neck"]
+ y["head_out"] = x
+ elif isinstance(x, dict):
+ y.update(x)
+ else:
+ y["head_out"] = x
+ final_name = "head_out"
if self.return_all_feats:
if self.training:
return y
+ elif isinstance(x, dict):
+ return x
else:
- return {"head_out": y["head_out"]}
+ return {final_name: x}
else:
return x
diff --git a/ppocr/modeling/backbones/rec_resnet_45.py b/ppocr/modeling/backbones/rec_resnet_45.py
index 9093d0bc9..083eb7f48 100644
--- a/ppocr/modeling/backbones/rec_resnet_45.py
+++ b/ppocr/modeling/backbones/rec_resnet_45.py
@@ -84,11 +84,15 @@ class BasicBlock(nn.Layer):
class ResNet45(nn.Layer):
- def __init__(self, block=BasicBlock, layers=[3, 4, 6, 6, 3], in_channels=3):
+ def __init__(self,
+ in_channels=3,
+ block=BasicBlock,
+ layers=[3, 4, 6, 6, 3],
+ strides=[2, 1, 2, 1, 1]):
self.inplanes = 32
super(ResNet45, self).__init__()
self.conv1 = nn.Conv2D(
- 3,
+ in_channels,
32,
kernel_size=3,
stride=1,
@@ -98,18 +102,13 @@ class ResNet45(nn.Layer):
self.bn1 = nn.BatchNorm2D(32)
self.relu = nn.ReLU()
- self.layer1 = self._make_layer(block, 32, layers[0], stride=2)
- self.layer2 = self._make_layer(block, 64, layers[1], stride=1)
- self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
- self.layer4 = self._make_layer(block, 256, layers[3], stride=1)
- self.layer5 = self._make_layer(block, 512, layers[4], stride=1)
+ self.layer1 = self._make_layer(block, 32, layers[0], stride=strides[0])
+ self.layer2 = self._make_layer(block, 64, layers[1], stride=strides[1])
+ self.layer3 = self._make_layer(block, 128, layers[2], stride=strides[2])
+ self.layer4 = self._make_layer(block, 256, layers[3], stride=strides[3])
+ self.layer5 = self._make_layer(block, 512, layers[4], stride=strides[4])
self.out_channels = 512
- # for m in self.modules():
- # if isinstance(m, nn.Conv2D):
- # n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
- # m.weight.data.normal_(0, math.sqrt(2. / n))
-
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
@@ -137,11 +136,9 @@ class ResNet45(nn.Layer):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
- # print(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
- # print(x)
x = self.layer4(x)
x = self.layer5(x)
return x
diff --git a/ppocr/modeling/backbones/rec_resnet_aster.py b/ppocr/modeling/backbones/rec_resnet_aster.py
index 6a2710dfa..782dc393e 100644
--- a/ppocr/modeling/backbones/rec_resnet_aster.py
+++ b/ppocr/modeling/backbones/rec_resnet_aster.py
@@ -140,4 +140,4 @@ class ResNet_ASTER(nn.Layer):
rnn_feat, _ = self.rnn(cnn_feat)
return rnn_feat
else:
- return cnn_feat
+ return cnn_feat
\ No newline at end of file
diff --git a/ppocr/modeling/backbones/vqa_layoutlm.py b/ppocr/modeling/backbones/vqa_layoutlm.py
index 34dd9d10e..d4ced3508 100644
--- a/ppocr/modeling/backbones/vqa_layoutlm.py
+++ b/ppocr/modeling/backbones/vqa_layoutlm.py
@@ -22,13 +22,22 @@ from paddle import nn
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
from paddlenlp.transformers import LayoutLMv2Model, LayoutLMv2ForTokenClassification, LayoutLMv2ForRelationExtraction
+from paddlenlp.transformers import AutoModel
-__all__ = ["LayoutXLMForSer", 'LayoutLMForSer']
+__all__ = ["LayoutXLMForSer", "LayoutLMForSer"]
pretrained_model_dict = {
- LayoutXLMModel: 'layoutxlm-base-uncased',
- LayoutLMModel: 'layoutlm-base-uncased',
- LayoutLMv2Model: 'layoutlmv2-base-uncased'
+ LayoutXLMModel: {
+ "base": "layoutxlm-base-uncased",
+ "vi": "layoutxlm-wo-backbone-base-uncased",
+ },
+ LayoutLMModel: {
+ "base": "layoutlm-base-uncased",
+ },
+ LayoutLMv2Model: {
+ "base": "layoutlmv2-base-uncased",
+ "vi": "layoutlmv2-wo-backbone-base-uncased",
+ },
}
@@ -36,42 +45,47 @@ class NLPBaseModel(nn.Layer):
def __init__(self,
base_model_class,
model_class,
- type='ser',
+ mode="base",
+ type="ser",
pretrained=True,
checkpoints=None,
**kwargs):
super(NLPBaseModel, self).__init__()
- if checkpoints is not None:
+ if checkpoints is not None: # load the trained model
self.model = model_class.from_pretrained(checkpoints)
- elif isinstance(pretrained, (str, )) and os.path.exists(pretrained):
- self.model = model_class.from_pretrained(pretrained)
- else:
- pretrained_model_name = pretrained_model_dict[base_model_class]
+ else: # load the pretrained-model
+ pretrained_model_name = pretrained_model_dict[base_model_class][
+ mode]
if pretrained is True:
base_model = base_model_class.from_pretrained(
pretrained_model_name)
else:
- base_model = base_model_class(
- **base_model_class.pretrained_init_configuration[
- pretrained_model_name])
- if type == 'ser':
+ base_model = base_model_class.from_pretrained(pretrained)
+ if type == "ser":
self.model = model_class(
- base_model, num_classes=kwargs['num_classes'], dropout=None)
+ base_model, num_classes=kwargs["num_classes"], dropout=None)
else:
self.model = model_class(base_model, dropout=None)
self.out_channels = 1
+ self.use_visual_backbone = True
class LayoutLMForSer(NLPBaseModel):
- def __init__(self, num_classes, pretrained=True, checkpoints=None,
+ def __init__(self,
+ num_classes,
+ pretrained=True,
+ checkpoints=None,
+ mode="base",
**kwargs):
super(LayoutLMForSer, self).__init__(
LayoutLMModel,
LayoutLMForTokenClassification,
- 'ser',
+ mode,
+ "ser",
pretrained,
checkpoints,
- num_classes=num_classes)
+ num_classes=num_classes, )
+ self.use_visual_backbone = False
def forward(self, x):
x = self.model(
@@ -85,62 +99,92 @@ class LayoutLMForSer(NLPBaseModel):
class LayoutLMv2ForSer(NLPBaseModel):
- def __init__(self, num_classes, pretrained=True, checkpoints=None,
+ def __init__(self,
+ num_classes,
+ pretrained=True,
+ checkpoints=None,
+ mode="base",
**kwargs):
super(LayoutLMv2ForSer, self).__init__(
LayoutLMv2Model,
LayoutLMv2ForTokenClassification,
- 'ser',
+ mode,
+ "ser",
pretrained,
checkpoints,
num_classes=num_classes)
+ self.use_visual_backbone = True
+ if hasattr(self.model.layoutlmv2, "use_visual_backbone"
+ ) and self.model.layoutlmv2.use_visual_backbone is False:
+ self.use_visual_backbone = False
def forward(self, x):
+ if self.use_visual_backbone is True:
+ image = x[4]
+ else:
+ image = None
x = self.model(
input_ids=x[0],
bbox=x[1],
attention_mask=x[2],
token_type_ids=x[3],
- image=x[4],
+ image=image,
position_ids=None,
head_mask=None,
labels=None)
- if not self.training:
+ if self.training:
+ res = {"backbone_out": x[0]}
+ res.update(x[1])
+ return res
+ else:
return x
- return x[0]
class LayoutXLMForSer(NLPBaseModel):
- def __init__(self, num_classes, pretrained=True, checkpoints=None,
+ def __init__(self,
+ num_classes,
+ pretrained=True,
+ checkpoints=None,
+ mode="base",
**kwargs):
super(LayoutXLMForSer, self).__init__(
LayoutXLMModel,
LayoutXLMForTokenClassification,
- 'ser',
+ mode,
+ "ser",
pretrained,
checkpoints,
num_classes=num_classes)
+ self.use_visual_backbone = True
def forward(self, x):
+ if self.use_visual_backbone is True:
+ image = x[4]
+ else:
+ image = None
x = self.model(
input_ids=x[0],
bbox=x[1],
attention_mask=x[2],
token_type_ids=x[3],
- image=x[4],
+ image=image,
position_ids=None,
head_mask=None,
labels=None)
- if not self.training:
+ if self.training:
+ res = {"backbone_out": x[0]}
+ res.update(x[1])
+ return res
+ else:
return x
- return x[0]
class LayoutLMv2ForRe(NLPBaseModel):
- def __init__(self, pretrained=True, checkpoints=None, **kwargs):
- super(LayoutLMv2ForRe, self).__init__(LayoutLMv2Model,
- LayoutLMv2ForRelationExtraction,
- 're', pretrained, checkpoints)
+ def __init__(self, pretrained=True, checkpoints=None, mode="base",
+ **kwargs):
+ super(LayoutLMv2ForRe, self).__init__(
+ LayoutLMv2Model, LayoutLMv2ForRelationExtraction, mode, "re",
+ pretrained, checkpoints)
def forward(self, x):
x = self.model(
@@ -158,18 +202,27 @@ class LayoutLMv2ForRe(NLPBaseModel):
class LayoutXLMForRe(NLPBaseModel):
- def __init__(self, pretrained=True, checkpoints=None, **kwargs):
- super(LayoutXLMForRe, self).__init__(LayoutXLMModel,
- LayoutXLMForRelationExtraction,
- 're', pretrained, checkpoints)
+ def __init__(self, pretrained=True, checkpoints=None, mode="base",
+ **kwargs):
+ super(LayoutXLMForRe, self).__init__(
+ LayoutXLMModel, LayoutXLMForRelationExtraction, mode, "re",
+ pretrained, checkpoints)
+ self.use_visual_backbone = True
+ if hasattr(self.model.layoutxlm, "use_visual_backbone"
+ ) and self.model.layoutxlm.use_visual_backbone is False:
+ self.use_visual_backbone = False
def forward(self, x):
+ if self.use_visual_backbone is True:
+ image = x[4]
+ else:
+ image = None
x = self.model(
input_ids=x[0],
bbox=x[1],
attention_mask=x[2],
token_type_ids=x[3],
- image=x[4],
+ image=image,
position_ids=None,
head_mask=None,
labels=None,
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index ca861c352..190622329 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -36,6 +36,7 @@ def build_head(config):
from .rec_spin_att_head import SPINAttentionHead
from .rec_abinet_head import ABINetHead
from .rec_robustscanner_head import RobustScannerHead
+ from .rec_visionlan_head import VLHead
# cls head
from .cls_head import ClsHead
@@ -50,7 +51,8 @@ def build_head(config):
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
- 'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead', 'RobustScannerHead'
+ 'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
+ 'VLHead', 'RobustScannerHead'
]
#table head
diff --git a/ppocr/modeling/heads/rec_visionlan_head.py b/ppocr/modeling/heads/rec_visionlan_head.py
new file mode 100644
index 000000000..86054d9bb
--- /dev/null
+++ b/ppocr/modeling/heads/rec_visionlan_head.py
@@ -0,0 +1,468 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/wangyuxin87/VisionLAN
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn.initializer import Normal, XavierNormal
+import numpy as np
+
+
+class PositionalEncoding(nn.Layer):
+ def __init__(self, d_hid, n_position=200):
+ super(PositionalEncoding, self).__init__()
+ self.register_buffer(
+ 'pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
+
+ def _get_sinusoid_encoding_table(self, n_position, d_hid):
+ ''' Sinusoid position encoding table '''
+
+ def get_position_angle_vec(position):
+ return [
+ position / np.power(10000, 2 * (hid_j // 2) / d_hid)
+ for hid_j in range(d_hid)
+ ]
+
+ sinusoid_table = np.array(
+ [get_position_angle_vec(pos_i) for pos_i in range(n_position)])
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
+ sinusoid_table = paddle.to_tensor(sinusoid_table, dtype='float32')
+ sinusoid_table = paddle.unsqueeze(sinusoid_table, axis=0)
+ return sinusoid_table
+
+ def forward(self, x):
+ return x + self.pos_table[:, :x.shape[1]].clone().detach()
+
+
+class ScaledDotProductAttention(nn.Layer):
+ "Scaled Dot-Product Attention"
+
+ def __init__(self, temperature, attn_dropout=0.1):
+ super(ScaledDotProductAttention, self).__init__()
+ self.temperature = temperature
+ self.dropout = nn.Dropout(attn_dropout)
+ self.softmax = nn.Softmax(axis=2)
+
+ def forward(self, q, k, v, mask=None):
+ k = paddle.transpose(k, perm=[0, 2, 1])
+ attn = paddle.bmm(q, k)
+ attn = attn / self.temperature
+ if mask is not None:
+ attn = attn.masked_fill(mask, -1e9)
+ if mask.dim() == 3:
+ mask = paddle.unsqueeze(mask, axis=1)
+ elif mask.dim() == 2:
+ mask = paddle.unsqueeze(mask, axis=1)
+ mask = paddle.unsqueeze(mask, axis=1)
+ repeat_times = [
+ attn.shape[1] // mask.shape[1], attn.shape[2] // mask.shape[2]
+ ]
+ mask = paddle.tile(mask, [1, repeat_times[0], repeat_times[1], 1])
+ attn[mask == 0] = -1e9
+ attn = self.softmax(attn)
+ attn = self.dropout(attn)
+ output = paddle.bmm(attn, v)
+ return output
+
+
+class MultiHeadAttention(nn.Layer):
+ " Multi-Head Attention module"
+
+ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
+ super(MultiHeadAttention, self).__init__()
+ self.n_head = n_head
+ self.d_k = d_k
+ self.d_v = d_v
+ self.w_qs = nn.Linear(
+ d_model,
+ n_head * d_k,
+ weight_attr=ParamAttr(initializer=Normal(
+ mean=0, std=np.sqrt(2.0 / (d_model + d_k)))))
+ self.w_ks = nn.Linear(
+ d_model,
+ n_head * d_k,
+ weight_attr=ParamAttr(initializer=Normal(
+ mean=0, std=np.sqrt(2.0 / (d_model + d_k)))))
+ self.w_vs = nn.Linear(
+ d_model,
+ n_head * d_v,
+ weight_attr=ParamAttr(initializer=Normal(
+ mean=0, std=np.sqrt(2.0 / (d_model + d_v)))))
+
+ self.attention = ScaledDotProductAttention(temperature=np.power(d_k,
+ 0.5))
+ self.layer_norm = nn.LayerNorm(d_model)
+ self.fc = nn.Linear(
+ n_head * d_v,
+ d_model,
+ weight_attr=ParamAttr(initializer=XavierNormal()))
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, q, k, v, mask=None):
+ d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
+ sz_b, len_q, _ = q.shape
+ sz_b, len_k, _ = k.shape
+ sz_b, len_v, _ = v.shape
+ residual = q
+
+ q = self.w_qs(q)
+ q = paddle.reshape(
+ q, shape=[-1, len_q, n_head, d_k]) # 4*21*512 ---- 4*21*8*64
+ k = self.w_ks(k)
+ k = paddle.reshape(k, shape=[-1, len_k, n_head, d_k])
+ v = self.w_vs(v)
+ v = paddle.reshape(v, shape=[-1, len_v, n_head, d_v])
+
+ q = paddle.transpose(q, perm=[2, 0, 1, 3])
+ q = paddle.reshape(q, shape=[-1, len_q, d_k]) # (n*b) x lq x dk
+ k = paddle.transpose(k, perm=[2, 0, 1, 3])
+ k = paddle.reshape(k, shape=[-1, len_k, d_k]) # (n*b) x lk x dk
+ v = paddle.transpose(v, perm=[2, 0, 1, 3])
+ v = paddle.reshape(v, shape=[-1, len_v, d_v]) # (n*b) x lv x dv
+
+ mask = paddle.tile(
+ mask,
+ [n_head, 1, 1]) if mask is not None else None # (n*b) x .. x ..
+ output = self.attention(q, k, v, mask=mask)
+ output = paddle.reshape(output, shape=[n_head, -1, len_q, d_v])
+ output = paddle.transpose(output, perm=[1, 2, 0, 3])
+ output = paddle.reshape(
+ output, shape=[-1, len_q, n_head * d_v]) # b x lq x (n*dv)
+ output = self.dropout(self.fc(output))
+ output = self.layer_norm(output + residual)
+ return output
+
+
+class PositionwiseFeedForward(nn.Layer):
+ def __init__(self, d_in, d_hid, dropout=0.1):
+ super(PositionwiseFeedForward, self).__init__()
+ self.w_1 = nn.Conv1D(d_in, d_hid, 1) # position-wise
+ self.w_2 = nn.Conv1D(d_hid, d_in, 1) # position-wise
+ self.layer_norm = nn.LayerNorm(d_in)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ residual = x
+ x = paddle.transpose(x, perm=[0, 2, 1])
+ x = self.w_2(F.relu(self.w_1(x)))
+ x = paddle.transpose(x, perm=[0, 2, 1])
+ x = self.dropout(x)
+ x = self.layer_norm(x + residual)
+ return x
+
+
+class EncoderLayer(nn.Layer):
+ ''' Compose with two layers '''
+
+ def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
+ super(EncoderLayer, self).__init__()
+ self.slf_attn = MultiHeadAttention(
+ n_head, d_model, d_k, d_v, dropout=dropout)
+ self.pos_ffn = PositionwiseFeedForward(
+ d_model, d_inner, dropout=dropout)
+
+ def forward(self, enc_input, slf_attn_mask=None):
+ enc_output = self.slf_attn(
+ enc_input, enc_input, enc_input, mask=slf_attn_mask)
+ enc_output = self.pos_ffn(enc_output)
+ return enc_output
+
+
+class Transformer_Encoder(nn.Layer):
+ def __init__(self,
+ n_layers=2,
+ n_head=8,
+ d_word_vec=512,
+ d_k=64,
+ d_v=64,
+ d_model=512,
+ d_inner=2048,
+ dropout=0.1,
+ n_position=256):
+ super(Transformer_Encoder, self).__init__()
+ self.position_enc = PositionalEncoding(
+ d_word_vec, n_position=n_position)
+ self.dropout = nn.Dropout(p=dropout)
+ self.layer_stack = nn.LayerList([
+ EncoderLayer(
+ d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
+ for _ in range(n_layers)
+ ])
+ self.layer_norm = nn.LayerNorm(d_model, epsilon=1e-6)
+
+ def forward(self, enc_output, src_mask, return_attns=False):
+ enc_output = self.dropout(
+ self.position_enc(enc_output)) # position embeding
+ for enc_layer in self.layer_stack:
+ enc_output = enc_layer(enc_output, slf_attn_mask=src_mask)
+ enc_output = self.layer_norm(enc_output)
+ return enc_output
+
+
+class PP_layer(nn.Layer):
+ def __init__(self, n_dim=512, N_max_character=25, n_position=256):
+
+ super(PP_layer, self).__init__()
+ self.character_len = N_max_character
+ self.f0_embedding = nn.Embedding(N_max_character, n_dim)
+ self.w0 = nn.Linear(N_max_character, n_position)
+ self.wv = nn.Linear(n_dim, n_dim)
+ self.we = nn.Linear(n_dim, N_max_character)
+ self.active = nn.Tanh()
+ self.softmax = nn.Softmax(axis=2)
+
+ def forward(self, enc_output):
+ # enc_output: b,256,512
+ reading_order = paddle.arange(self.character_len, dtype='int64')
+ reading_order = reading_order.unsqueeze(0).expand(
+ [enc_output.shape[0], self.character_len]) # (S,) -> (B, S)
+ reading_order = self.f0_embedding(reading_order) # b,25,512
+
+ # calculate attention
+ reading_order = paddle.transpose(reading_order, perm=[0, 2, 1])
+ t = self.w0(reading_order) # b,512,256
+ t = self.active(
+ paddle.transpose(
+ t, perm=[0, 2, 1]) + self.wv(enc_output)) # b,256,512
+ t = self.we(t) # b,256,25
+ t = self.softmax(paddle.transpose(t, perm=[0, 2, 1])) # b,25,256
+ g_output = paddle.bmm(t, enc_output) # b,25,512
+ return g_output
+
+
+class Prediction(nn.Layer):
+ def __init__(self,
+ n_dim=512,
+ n_position=256,
+ N_max_character=25,
+ n_class=37):
+ super(Prediction, self).__init__()
+ self.pp = PP_layer(
+ n_dim=n_dim, N_max_character=N_max_character, n_position=n_position)
+ self.pp_share = PP_layer(
+ n_dim=n_dim, N_max_character=N_max_character, n_position=n_position)
+ self.w_vrm = nn.Linear(n_dim, n_class) # output layer
+ self.w_share = nn.Linear(n_dim, n_class) # output layer
+ self.nclass = n_class
+
+ def forward(self, cnn_feature, f_res, f_sub, train_mode=False,
+ use_mlm=True):
+ if train_mode:
+ if not use_mlm:
+ g_output = self.pp(cnn_feature) # b,25,512
+ g_output = self.w_vrm(g_output)
+ f_res = 0
+ f_sub = 0
+ return g_output, f_res, f_sub
+ g_output = self.pp(cnn_feature) # b,25,512
+ f_res = self.pp_share(f_res)
+ f_sub = self.pp_share(f_sub)
+ g_output = self.w_vrm(g_output)
+ f_res = self.w_share(f_res)
+ f_sub = self.w_share(f_sub)
+ return g_output, f_res, f_sub
+ else:
+ g_output = self.pp(cnn_feature) # b,25,512
+ g_output = self.w_vrm(g_output)
+ return g_output
+
+
+class MLM(nn.Layer):
+ "Architecture of MLM"
+
+ def __init__(self, n_dim=512, n_position=256, max_text_length=25):
+ super(MLM, self).__init__()
+ self.MLM_SequenceModeling_mask = Transformer_Encoder(
+ n_layers=2, n_position=n_position)
+ self.MLM_SequenceModeling_WCL = Transformer_Encoder(
+ n_layers=1, n_position=n_position)
+ self.pos_embedding = nn.Embedding(max_text_length, n_dim)
+ self.w0_linear = nn.Linear(1, n_position)
+ self.wv = nn.Linear(n_dim, n_dim)
+ self.active = nn.Tanh()
+ self.we = nn.Linear(n_dim, 1)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x, label_pos):
+ # transformer unit for generating mask_c
+ feature_v_seq = self.MLM_SequenceModeling_mask(x, src_mask=None)
+ # position embedding layer
+ label_pos = paddle.to_tensor(label_pos, dtype='int64')
+ pos_emb = self.pos_embedding(label_pos)
+ pos_emb = self.w0_linear(paddle.unsqueeze(pos_emb, axis=2))
+ pos_emb = paddle.transpose(pos_emb, perm=[0, 2, 1])
+ # fusion position embedding with features V & generate mask_c
+ att_map_sub = self.active(pos_emb + self.wv(feature_v_seq))
+ att_map_sub = self.we(att_map_sub) # b,256,1
+ att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1])
+ att_map_sub = self.sigmoid(att_map_sub) # b,1,256
+ # WCL
+ ## generate inputs for WCL
+ att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1])
+ f_res = x * (1 - att_map_sub) # second path with remaining string
+ f_sub = x * att_map_sub # first path with occluded character
+ ## transformer units in WCL
+ f_res = self.MLM_SequenceModeling_WCL(f_res, src_mask=None)
+ f_sub = self.MLM_SequenceModeling_WCL(f_sub, src_mask=None)
+ return f_res, f_sub, att_map_sub
+
+
+def trans_1d_2d(x):
+ b, w_h, c = x.shape # b, 256, 512
+ x = paddle.transpose(x, perm=[0, 2, 1])
+ x = paddle.reshape(x, [-1, c, 32, 8])
+ x = paddle.transpose(x, perm=[0, 1, 3, 2]) # [b, c, 8, 32]
+ return x
+
+
+class MLM_VRM(nn.Layer):
+ """
+ MLM+VRM, MLM is only used in training.
+ ratio controls the occluded number in a batch.
+ The pipeline of VisionLAN in testing is very concise with only a backbone + sequence modeling(transformer unit) + prediction layer(pp layer).
+ x: input image
+ label_pos: character index
+ training_step: LF or LA process
+ output
+ text_pre: prediction of VRM
+ test_rem: prediction of remaining string in MLM
+ text_mas: prediction of occluded character in MLM
+ mask_c_show: visualization of Mask_c
+ """
+
+ def __init__(self,
+ n_layers=3,
+ n_position=256,
+ n_dim=512,
+ max_text_length=25,
+ nclass=37):
+ super(MLM_VRM, self).__init__()
+ self.MLM = MLM(n_dim=n_dim,
+ n_position=n_position,
+ max_text_length=max_text_length)
+ self.SequenceModeling = Transformer_Encoder(
+ n_layers=n_layers, n_position=n_position)
+ self.Prediction = Prediction(
+ n_dim=n_dim,
+ n_position=n_position,
+ N_max_character=max_text_length +
+ 1, # N_max_character = 1 eos + 25 characters
+ n_class=nclass)
+ self.nclass = nclass
+ self.max_text_length = max_text_length
+
+ def forward(self, x, label_pos, training_step, train_mode=False):
+ b, c, h, w = x.shape
+ nT = self.max_text_length
+ x = paddle.transpose(x, perm=[0, 1, 3, 2])
+ x = paddle.reshape(x, [-1, c, h * w])
+ x = paddle.transpose(x, perm=[0, 2, 1])
+ if train_mode:
+ if training_step == 'LF_1':
+ f_res = 0
+ f_sub = 0
+ x = self.SequenceModeling(x, src_mask=None)
+ text_pre, test_rem, text_mas = self.Prediction(
+ x, f_res, f_sub, train_mode=True, use_mlm=False)
+ return text_pre, text_pre, text_pre, text_pre
+ elif training_step == 'LF_2':
+ # MLM
+ f_res, f_sub, mask_c = self.MLM(x, label_pos)
+ x = self.SequenceModeling(x, src_mask=None)
+ text_pre, test_rem, text_mas = self.Prediction(
+ x, f_res, f_sub, train_mode=True)
+ mask_c_show = trans_1d_2d(mask_c)
+ return text_pre, test_rem, text_mas, mask_c_show
+ elif training_step == 'LA':
+ # MLM
+ f_res, f_sub, mask_c = self.MLM(x, label_pos)
+ ## use the mask_c (1 for occluded character and 0 for remaining characters) to occlude input
+ ## ratio controls the occluded number in a batch
+ character_mask = paddle.zeros_like(mask_c)
+
+ ratio = b // 2
+ if ratio >= 1:
+ with paddle.no_grad():
+ character_mask[0:ratio, :, :] = mask_c[0:ratio, :, :]
+ else:
+ character_mask = mask_c
+ x = x * (1 - character_mask)
+ # VRM
+ ## transformer unit for VRM
+ x = self.SequenceModeling(x, src_mask=None)
+ ## prediction layer for MLM and VSR
+ text_pre, test_rem, text_mas = self.Prediction(
+ x, f_res, f_sub, train_mode=True)
+ mask_c_show = trans_1d_2d(mask_c)
+ return text_pre, test_rem, text_mas, mask_c_show
+ else:
+ raise NotImplementedError
+ else: # VRM is only used in the testing stage
+ f_res = 0
+ f_sub = 0
+ contextual_feature = self.SequenceModeling(x, src_mask=None)
+ text_pre = self.Prediction(
+ contextual_feature,
+ f_res,
+ f_sub,
+ train_mode=False,
+ use_mlm=False)
+ text_pre = paddle.transpose(
+ text_pre, perm=[1, 0, 2]) # (26, b, 37))
+ return text_pre, x
+
+
+class VLHead(nn.Layer):
+ """
+ Architecture of VisionLAN
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels=36,
+ n_layers=3,
+ n_position=256,
+ n_dim=512,
+ max_text_length=25,
+ training_step='LA'):
+ super(VLHead, self).__init__()
+ self.MLM_VRM = MLM_VRM(
+ n_layers=n_layers,
+ n_position=n_position,
+ n_dim=n_dim,
+ max_text_length=max_text_length,
+ nclass=out_channels + 1)
+ self.training_step = training_step
+
+ def forward(self, feat, targets=None):
+
+ if self.training:
+ label_pos = targets[-2]
+ text_pre, test_rem, text_mas, mask_map = self.MLM_VRM(
+ feat, label_pos, self.training_step, train_mode=True)
+ return text_pre, test_rem, text_mas, mask_map
+ else:
+ text_pre, x = self.MLM_VRM(
+ feat, targets, self.training_step, train_mode=False)
+ return text_pre, x
diff --git a/ppocr/optimizer/optimizer.py b/ppocr/optimizer/optimizer.py
index dd8544e2e..144f011c7 100644
--- a/ppocr/optimizer/optimizer.py
+++ b/ppocr/optimizer/optimizer.py
@@ -77,11 +77,62 @@ class Adam(object):
self.grad_clip = grad_clip
self.name = name
self.lazy_mode = lazy_mode
+ self.group_lr = kwargs.get('group_lr', False)
+ self.training_step = kwargs.get('training_step', None)
def __call__(self, model):
- train_params = [
- param for param in model.parameters() if param.trainable is True
- ]
+ if self.group_lr:
+ if self.training_step == 'LF_2':
+ import paddle
+ if isinstance(model, paddle.fluid.dygraph.parallel.
+ DataParallel): # multi gpu
+ mlm = model._layers.head.MLM_VRM.MLM.parameters()
+ pre_mlm_pp = model._layers.head.MLM_VRM.Prediction.pp_share.parameters(
+ )
+ pre_mlm_w = model._layers.head.MLM_VRM.Prediction.w_share.parameters(
+ )
+ else: # single gpu
+ mlm = model.head.MLM_VRM.MLM.parameters()
+ pre_mlm_pp = model.head.MLM_VRM.Prediction.pp_share.parameters(
+ )
+ pre_mlm_w = model.head.MLM_VRM.Prediction.w_share.parameters(
+ )
+
+ total = []
+ for param in mlm:
+ total.append(id(param))
+ for param in pre_mlm_pp:
+ total.append(id(param))
+ for param in pre_mlm_w:
+ total.append(id(param))
+
+ group_base_params = [
+ param for param in model.parameters() if id(param) in total
+ ]
+ group_small_params = [
+ param for param in model.parameters()
+ if id(param) not in total
+ ]
+ train_params = [{
+ 'params': group_base_params
+ }, {
+ 'params': group_small_params,
+ 'learning_rate': self.learning_rate.values[0] * 0.1
+ }]
+
+ else:
+ print(
+ 'group lr currently only support VisionLAN in LF_2 training step'
+ )
+ train_params = [
+ param for param in model.parameters()
+ if param.trainable is True
+ ]
+ else:
+ train_params = [
+ param for param in model.parameters() if param.trainable is True
+ ]
+
opt = optim.Adam(
learning_rate=self.learning_rate,
beta1=self.beta1,
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index eeebc5803..8f41a005f 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -28,12 +28,13 @@ from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
- SPINLabelDecode
+ SPINLabelDecode, VLLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
-from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
-from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
+from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess
+from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess, DistillationRePostProcess
from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
+from .picodet_postprocess import PicoDetPostProcess
def build_post_process(config, global_config=None):
@@ -45,7 +46,9 @@ def build_post_process(config, global_config=None):
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
- 'TableMasterLabelDecode', 'SPINLabelDecode'
+ 'TableMasterLabelDecode', 'SPINLabelDecode',
+ 'DistillationSerPostProcess', 'DistillationRePostProcess',
+ 'VLLabelDecode', 'PicoDetPostProcess'
]
if config['name'] == 'PSEPostProcess':
diff --git a/ppocr/postprocess/picodet_postprocess.py b/ppocr/postprocess/picodet_postprocess.py
new file mode 100644
index 000000000..1a0aeb438
--- /dev/null
+++ b/ppocr/postprocess/picodet_postprocess.py
@@ -0,0 +1,250 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from scipy.special import softmax
+
+
+def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
+ """
+ Args:
+ box_scores (N, 5): boxes in corner-form and probabilities.
+ iou_threshold: intersection over union threshold.
+ top_k: keep top_k results. If k <= 0, keep all the results.
+ candidate_size: only consider the candidates with the highest scores.
+ Returns:
+ picked: a list of indexes of the kept boxes
+ """
+ scores = box_scores[:, -1]
+ boxes = box_scores[:, :-1]
+ picked = []
+ indexes = np.argsort(scores)
+ indexes = indexes[-candidate_size:]
+ while len(indexes) > 0:
+ current = indexes[-1]
+ picked.append(current)
+ if 0 < top_k == len(picked) or len(indexes) == 1:
+ break
+ current_box = boxes[current, :]
+ indexes = indexes[:-1]
+ rest_boxes = boxes[indexes, :]
+ iou = iou_of(
+ rest_boxes,
+ np.expand_dims(
+ current_box, axis=0), )
+ indexes = indexes[iou <= iou_threshold]
+
+ return box_scores[picked, :]
+
+
+def iou_of(boxes0, boxes1, eps=1e-5):
+ """Return intersection-over-union (Jaccard index) of boxes.
+ Args:
+ boxes0 (N, 4): ground truth boxes.
+ boxes1 (N or 1, 4): predicted boxes.
+ eps: a small number to avoid 0 as denominator.
+ Returns:
+ iou (N): IoU values.
+ """
+ overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
+ overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])
+
+ overlap_area = area_of(overlap_left_top, overlap_right_bottom)
+ area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
+ area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
+ return overlap_area / (area0 + area1 - overlap_area + eps)
+
+
+def area_of(left_top, right_bottom):
+ """Compute the areas of rectangles given two corners.
+ Args:
+ left_top (N, 2): left top corner.
+ right_bottom (N, 2): right bottom corner.
+ Returns:
+ area (N): return the area.
+ """
+ hw = np.clip(right_bottom - left_top, 0.0, None)
+ return hw[..., 0] * hw[..., 1]
+
+
+class PicoDetPostProcess(object):
+ """
+ Args:
+ input_shape (int): network input image size
+ ori_shape (int): ori image shape of before padding
+ scale_factor (float): scale factor of ori image
+ enable_mkldnn (bool): whether to open MKLDNN
+ """
+
+ def __init__(self,
+ layout_dict_path,
+ strides=[8, 16, 32, 64],
+ score_threshold=0.4,
+ nms_threshold=0.5,
+ nms_top_k=1000,
+ keep_top_k=100):
+ self.labels = self.load_layout_dict(layout_dict_path)
+ self.strides = strides
+ self.score_threshold = score_threshold
+ self.nms_threshold = nms_threshold
+ self.nms_top_k = nms_top_k
+ self.keep_top_k = keep_top_k
+
+ def load_layout_dict(self, layout_dict_path):
+ with open(layout_dict_path, 'r', encoding='utf-8') as fp:
+ labels = fp.readlines()
+ return [label.strip('\n') for label in labels]
+
+ def warp_boxes(self, boxes, ori_shape):
+ """Apply transform to boxes
+ """
+ width, height = ori_shape[1], ori_shape[0]
+ n = len(boxes)
+ if n:
+ # warp points
+ xy = np.ones((n * 4, 3))
+ xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
+ n * 4, 2) # x1y1, x2y2, x1y2, x2y1
+ # xy = xy @ M.T # transform
+ xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) # rescale
+ # create new boxes
+ x = xy[:, [0, 2, 4, 6]]
+ y = xy[:, [1, 3, 5, 7]]
+ xy = np.concatenate(
+ (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
+ # clip boxes
+ xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
+ xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
+ return xy.astype(np.float32)
+ else:
+ return boxes
+
+ def img_info(self, ori_img, img):
+ origin_shape = ori_img.shape
+ resize_shape = img.shape
+ im_scale_y = resize_shape[2] / float(origin_shape[0])
+ im_scale_x = resize_shape[3] / float(origin_shape[1])
+ scale_factor = np.array([im_scale_y, im_scale_x], dtype=np.float32)
+ img_shape = np.array(img.shape[2:], dtype=np.float32)
+
+ input_shape = np.array(img).astype('float32').shape[2:]
+ ori_shape = np.array((img_shape, )).astype('float32')
+ scale_factor = np.array((scale_factor, )).astype('float32')
+ return ori_shape, input_shape, scale_factor
+
+ def __call__(self, ori_img, img, preds):
+ scores, raw_boxes = preds['boxes'], preds['boxes_num']
+ batch_size = raw_boxes[0].shape[0]
+ reg_max = int(raw_boxes[0].shape[-1] / 4 - 1)
+ out_boxes_num = []
+ out_boxes_list = []
+ results = []
+ ori_shape, input_shape, scale_factor = self.img_info(ori_img, img)
+
+ for batch_id in range(batch_size):
+ # generate centers
+ decode_boxes = []
+ select_scores = []
+ for stride, box_distribute, score in zip(self.strides, raw_boxes,
+ scores):
+ box_distribute = box_distribute[batch_id]
+ score = score[batch_id]
+ # centers
+ fm_h = input_shape[0] / stride
+ fm_w = input_shape[1] / stride
+ h_range = np.arange(fm_h)
+ w_range = np.arange(fm_w)
+ ww, hh = np.meshgrid(w_range, h_range)
+ ct_row = (hh.flatten() + 0.5) * stride
+ ct_col = (ww.flatten() + 0.5) * stride
+ center = np.stack((ct_col, ct_row, ct_col, ct_row), axis=1)
+
+ # box distribution to distance
+ reg_range = np.arange(reg_max + 1)
+ box_distance = box_distribute.reshape((-1, reg_max + 1))
+ box_distance = softmax(box_distance, axis=1)
+ box_distance = box_distance * np.expand_dims(reg_range, axis=0)
+ box_distance = np.sum(box_distance, axis=1).reshape((-1, 4))
+ box_distance = box_distance * stride
+
+ # top K candidate
+ topk_idx = np.argsort(score.max(axis=1))[::-1]
+ topk_idx = topk_idx[:self.nms_top_k]
+ center = center[topk_idx]
+ score = score[topk_idx]
+ box_distance = box_distance[topk_idx]
+
+ # decode box
+ decode_box = center + [-1, -1, 1, 1] * box_distance
+
+ select_scores.append(score)
+ decode_boxes.append(decode_box)
+
+ # nms
+ bboxes = np.concatenate(decode_boxes, axis=0)
+ confidences = np.concatenate(select_scores, axis=0)
+ picked_box_probs = []
+ picked_labels = []
+ for class_index in range(0, confidences.shape[1]):
+ probs = confidences[:, class_index]
+ mask = probs > self.score_threshold
+ probs = probs[mask]
+ if probs.shape[0] == 0:
+ continue
+ subset_boxes = bboxes[mask, :]
+ box_probs = np.concatenate(
+ [subset_boxes, probs.reshape(-1, 1)], axis=1)
+ box_probs = hard_nms(
+ box_probs,
+ iou_threshold=self.nms_threshold,
+ top_k=self.keep_top_k, )
+ picked_box_probs.append(box_probs)
+ picked_labels.extend([class_index] * box_probs.shape[0])
+
+ if len(picked_box_probs) == 0:
+ out_boxes_list.append(np.empty((0, 4)))
+ out_boxes_num.append(0)
+
+ else:
+ picked_box_probs = np.concatenate(picked_box_probs)
+
+ # resize output boxes
+ picked_box_probs[:, :4] = self.warp_boxes(
+ picked_box_probs[:, :4], ori_shape[batch_id])
+ im_scale = np.concatenate([
+ scale_factor[batch_id][::-1], scale_factor[batch_id][::-1]
+ ])
+ picked_box_probs[:, :4] /= im_scale
+ # clas score box
+ out_boxes_list.append(
+ np.concatenate(
+ [
+ np.expand_dims(
+ np.array(picked_labels),
+ axis=-1), np.expand_dims(
+ picked_box_probs[:, 4], axis=-1),
+ picked_box_probs[:, :4]
+ ],
+ axis=1))
+ out_boxes_num.append(len(picked_labels))
+
+ out_boxes_list = np.concatenate(out_boxes_list, axis=0)
+ out_boxes_num = np.asarray(out_boxes_num).astype(np.int32)
+
+ for dt in out_boxes_list:
+ clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
+ label = self.labels[clsid]
+ result = {'bbox': bbox, 'label': label}
+ results.append(result)
+ return results
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 3fe29aabe..7b994f810 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -668,6 +668,7 @@ class ABINetLabelDecode(NRTRLabelDecode):
dict_character = [''] + dict_character
return dict_character
+
class SPINLabelDecode(AttnLabelDecode):
""" Convert between text-label and text-index """
@@ -681,4 +682,106 @@ class SPINLabelDecode(AttnLabelDecode):
self.end_str = "eos"
dict_character = dict_character
dict_character = [self.beg_str] + [self.end_str] + dict_character
- return dict_character
\ No newline at end of file
+ return dict_character
+
+
+class VLLabelDecode(BaseRecLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=False,
+ **kwargs):
+ super(VLLabelDecode, self).__init__(character_dict_path, use_space_char)
+ self.max_text_length = kwargs.get('max_text_length', 25)
+ self.nclass = len(self.character) + 1
+ self.character = self.character[10:] + self.character[
+ 1:10] + [self.character[0]]
+
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+ """ convert text-index into text-label. """
+ result_list = []
+ ignored_tokens = self.get_ignored_tokens()
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ selection = np.ones(len(text_index[batch_idx]), dtype=bool)
+ if is_remove_duplicate:
+ selection[1:] = text_index[batch_idx][1:] != text_index[
+ batch_idx][:-1]
+ for ignored_token in ignored_tokens:
+ selection &= text_index[batch_idx] != ignored_token
+
+ char_list = [
+ self.character[text_id - 1]
+ for text_id in text_index[batch_idx][selection]
+ ]
+ if text_prob is not None:
+ conf_list = text_prob[batch_idx][selection]
+ else:
+ conf_list = [1] * len(selection)
+ if len(conf_list) == 0:
+ conf_list = [0]
+
+ text = ''.join(char_list)
+ result_list.append((text, np.mean(conf_list).tolist()))
+ return result_list
+
+ def __call__(self, preds, label=None, length=None, *args, **kwargs):
+ if len(preds) == 2: # eval mode
+ text_pre, x = preds
+ b = text_pre.shape[1]
+ lenText = self.max_text_length
+ nsteps = self.max_text_length
+
+ if not isinstance(text_pre, paddle.Tensor):
+ text_pre = paddle.to_tensor(text_pre, dtype='float32')
+
+ out_res = paddle.zeros(
+ shape=[lenText, b, self.nclass], dtype=x.dtype)
+ out_length = paddle.zeros(shape=[b], dtype=x.dtype)
+ now_step = 0
+ for _ in range(nsteps):
+ if 0 in out_length and now_step < nsteps:
+ tmp_result = text_pre[now_step, :, :]
+ out_res[now_step] = tmp_result
+ tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
+ for j in range(b):
+ if out_length[j] == 0 and tmp_result[j] == 0:
+ out_length[j] = now_step + 1
+ now_step += 1
+ for j in range(0, b):
+ if int(out_length[j]) == 0:
+ out_length[j] = nsteps
+ start = 0
+ output = paddle.zeros(
+ shape=[int(out_length.sum()), self.nclass], dtype=x.dtype)
+ for i in range(0, b):
+ cur_length = int(out_length[i])
+ output[start:start + cur_length] = out_res[0:cur_length, i, :]
+ start += cur_length
+ net_out = output
+ length = out_length
+
+ else: # train mode
+ net_out = preds[0]
+ length = length
+ net_out = paddle.concat([t[:l] for t, l in zip(net_out, length)])
+ text = []
+ if not isinstance(net_out, paddle.Tensor):
+ net_out = paddle.to_tensor(net_out, dtype='float32')
+ net_out = F.softmax(net_out, axis=1)
+ for i in range(0, length.shape[0]):
+ preds_idx = net_out[int(length[:i].sum()):int(length[:i].sum(
+ ) + length[i])].topk(1)[1][:, 0].tolist()
+ preds_text = ''.join([
+ self.character[idx - 1]
+ if idx > 0 and idx <= len(self.character) else ''
+ for idx in preds_idx
+ ])
+ preds_prob = net_out[int(length[:i].sum()):int(length[:i].sum(
+ ) + length[i])].topk(1)[0][:, 0]
+ preds_prob = paddle.exp(
+ paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
+ text.append((preds_text, preds_prob))
+ if label is None:
+ return text
+ label = self.decode(label)
+ return text, label
diff --git a/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
index 1d55d13d7..96c25d9aa 100644
--- a/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
+++ b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
@@ -49,3 +49,25 @@ class VQAReTokenLayoutLMPostProcess(object):
result.append((ocr_info_head, ocr_info_tail))
results.append(result)
return results
+
+
+class DistillationRePostProcess(VQAReTokenLayoutLMPostProcess):
+ """
+ DistillationRePostProcess
+ """
+
+ def __init__(self, model_name=["Student"], key=None, **kwargs):
+ super().__init__(**kwargs)
+ if not isinstance(model_name, list):
+ model_name = [model_name]
+ self.model_name = model_name
+ self.key = key
+
+ def __call__(self, preds, *args, **kwargs):
+ output = dict()
+ for name in self.model_name:
+ pred = preds[name]
+ if self.key is not None:
+ pred = pred[self.key]
+ output[name] = super().__call__(pred, *args, **kwargs)
+ return output
diff --git a/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
index 8a6669f71..5541da90a 100644
--- a/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
+++ b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
@@ -93,3 +93,25 @@ class VQASerTokenLayoutLMPostProcess(object):
ocr_info[idx]["pred"] = self.id2label_map_for_show[int(pred_id)]
results.append(ocr_info)
return results
+
+
+class DistillationSerPostProcess(VQASerTokenLayoutLMPostProcess):
+ """
+ DistillationSerPostProcess
+ """
+
+ def __init__(self, class_path, model_name=["Student"], key=None, **kwargs):
+ super().__init__(class_path, **kwargs)
+ if not isinstance(model_name, list):
+ model_name = [model_name]
+ self.model_name = model_name
+ self.key = key
+
+ def __call__(self, preds, batch=None, *args, **kwargs):
+ output = dict()
+ for name in self.model_name:
+ pred = preds[name]
+ if self.key is not None:
+ pred = pred[self.key]
+ output[name] = super().__call__(pred, batch=batch, *args, **kwargs)
+ return output
diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index 3647111fd..e77a6ce01 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -53,8 +53,12 @@ def load_model(config, model, optimizer=None, model_type='det'):
checkpoints = global_config.get('checkpoints')
pretrained_model = global_config.get('pretrained_model')
best_model_dict = {}
+ is_float16 = False
if model_type == 'vqa':
+ # NOTE: for vqa model, resume training is not supported now
+ if config["Architecture"]["algorithm"] in ["Distillation"]:
+ return best_model_dict
checkpoints = config['Architecture']['Backbone']['checkpoints']
# load vqa method metric
if checkpoints:
@@ -78,6 +82,7 @@ def load_model(config, model, optimizer=None, model_type='det'):
logger.warning(
"{}.pdopt is not exists, params of optimizer is not loaded".
format(checkpoints))
+
return best_model_dict
if checkpoints:
@@ -96,6 +101,9 @@ def load_model(config, model, optimizer=None, model_type='det'):
key, params.keys()))
continue
pre_value = params[key]
+ if pre_value.dtype == paddle.float16:
+ pre_value = pre_value.astype(paddle.float32)
+ is_float16 = True
if list(value.shape) == list(pre_value.shape):
new_state_dict[key] = pre_value
else:
@@ -103,7 +111,10 @@ def load_model(config, model, optimizer=None, model_type='det'):
"The shape of model params {} {} not matched with loaded params shape {} !".
format(key, value.shape, pre_value.shape))
model.set_state_dict(new_state_dict)
-
+ if is_float16:
+ logger.info(
+ "The parameter type is float16, which is converted to float32 when loading"
+ )
if optimizer is not None:
if os.path.exists(checkpoints + '.pdopt'):
optim_dict = paddle.load(checkpoints + '.pdopt')
@@ -122,9 +133,10 @@ def load_model(config, model, optimizer=None, model_type='det'):
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
logger.info("resume from {}".format(checkpoints))
elif pretrained_model:
- load_pretrained_params(model, pretrained_model)
+ is_float16 = load_pretrained_params(model, pretrained_model)
else:
logger.info('train from scratch')
+ best_model_dict['is_float16'] = is_float16
return best_model_dict
@@ -138,19 +150,28 @@ def load_pretrained_params(model, path):
params = paddle.load(path + '.pdparams')
state_dict = model.state_dict()
new_state_dict = {}
+ is_float16 = False
for k1 in params.keys():
if k1 not in state_dict.keys():
logger.warning("The pretrained params {} not in model".format(k1))
else:
+ if params[k1].dtype == paddle.float16:
+ params[k1] = params[k1].astype(paddle.float32)
+ is_float16 = True
if list(state_dict[k1].shape) == list(params[k1].shape):
new_state_dict[k1] = params[k1]
else:
logger.warning(
"The shape of model params {} {} not matched with loaded params {} {} !".
format(k1, state_dict[k1].shape, k1, params[k1].shape))
+
model.set_state_dict(new_state_dict)
+ if is_float16:
+ logger.info(
+ "The parameter type is float16, which is converted to float32 when loading"
+ )
logger.info("load pretrain successful from {}".format(path))
- return model
+ return is_float16
def save_model(model,
@@ -166,15 +187,19 @@ def save_model(model,
"""
_mkdir_if_not_exist(model_path, logger)
model_prefix = os.path.join(model_path, prefix)
- paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
+ if config['Architecture']["model_type"] != 'vqa':
+ paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
if config['Architecture']["model_type"] != 'vqa':
paddle.save(model.state_dict(), model_prefix + '.pdparams')
metric_prefix = model_prefix
- else:
+ else: # for vqa system, we follow the save/load rules in NLP
if config['Global']['distributed']:
- model._layers.backbone.model.save_pretrained(model_prefix)
+ arch = model._layers
else:
- model.backbone.model.save_pretrained(model_prefix)
+ arch = model
+ if config["Architecture"]["algorithm"] in ["Distillation"]:
+ arch = arch.Student
+ arch.backbone.model.save_pretrained(model_prefix)
metric_prefix = os.path.join(model_prefix, 'metric')
# save metric and config
with open(metric_prefix + '.states', 'wb') as f:
diff --git a/ppstructure/layout/predict_layout.py b/ppstructure/layout/predict_layout.py
new file mode 100755
index 000000000..a58a63f49
--- /dev/null
+++ b/ppstructure/layout/predict_layout.py
@@ -0,0 +1,130 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import numpy as np
+import time
+
+import tools.infer.utility as utility
+from ppocr.data import create_operators, transform
+from ppocr.postprocess import build_post_process
+from ppocr.utils.logging import get_logger
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+from ppstructure.utility import parse_args
+from picodet_postprocess import PicoDetPostProcess
+
+logger = get_logger()
+
+class LayoutPredictor(object):
+ def __init__(self, args):
+ pre_process_list = [{
+ 'Resize': {
+ 'size': [800, 608]
+ }
+ }, {
+ 'NormalizeImage': {
+ 'std': [0.229, 0.224, 0.225],
+ 'mean': [0.485, 0.456, 0.406],
+ 'scale': '1./255.',
+ 'order': 'hwc'
+ }
+ }, {
+ 'ToCHWImage': None
+ }, {
+ 'KeepKeys': {
+ 'keep_keys': ['image']
+ }
+ }]
+ postprocess_params = {
+ 'name': 'PicoDetPostProcess',
+ "layout_dict_path": args.layout_dict_path,
+ "score_threshold": args.layout_score_threshold,
+ "nms_threshold": args.layout_nms_threshold,
+ }
+
+ self.preprocess_op = create_operators(pre_process_list)
+ self.postprocess_op = build_post_process(postprocess_params)
+ self.predictor, self.input_tensor, self.output_tensors, self.config = \
+ utility.create_predictor(args, 'layout', logger)
+
+ def __call__(self, img):
+ ori_im = img.copy()
+ data = {'image': img}
+ data = transform(data, self.preprocess_op)
+ img = data[0]
+
+ if img is None:
+ return None, 0
+
+ img = np.expand_dims(img, axis=0)
+ img = img.copy()
+
+ preds, elapse = 0, 1
+ starttime = time.time()
+
+ self.input_tensor.copy_from_cpu(img)
+ self.predictor.run()
+
+ np_score_list, np_boxes_list = [], []
+ output_names = self.predictor.get_output_names()
+ num_outs = int(len(output_names) / 2)
+ for out_idx in range(num_outs):
+ np_score_list.append(
+ self.predictor.get_output_handle(output_names[out_idx])
+ .copy_to_cpu())
+ np_boxes_list.append(
+ self.predictor.get_output_handle(output_names[
+ out_idx + num_outs]).copy_to_cpu())
+ preds = dict(boxes=np_score_list, boxes_num=np_boxes_list)
+
+ post_preds = self.postprocess_op(ori_im, img, preds)
+ elapse = time.time() - starttime
+ return post_preds, elapse
+
+
+def main(args):
+ image_file_list = get_image_file_list(args.image_dir)
+ layout_predictor = LayoutPredictor(args)
+ count = 0
+ total_time = 0
+
+ repeats = 50
+ for image_file in image_file_list:
+ img, flag = check_and_read_gif(image_file)
+ if not flag:
+ img = cv2.imread(image_file)
+ if img is None:
+ logger.info("error in loading image:{}".format(image_file))
+ continue
+
+ layout_res, elapse = layout_predictor(img)
+
+ logger.info("result: {}".format(layout_res))
+
+ if count > 0:
+ total_time += elapse
+ count += 1
+ logger.info("Predict time of {}: {}".format(image_file, elapse))
+
+
+if __name__ == "__main__":
+ main(parse_args())
diff --git a/ppstructure/utility.py b/ppstructure/utility.py
index af0616239..d79658f13 100644
--- a/ppstructure/utility.py
+++ b/ppstructure/utility.py
@@ -32,15 +32,18 @@ def init_args():
type=str,
default="../ppocr/utils/dict/table_structure_dict.txt")
# params for layout
+ parser.add_argument("--layout_model_dir", type=str)
parser.add_argument(
- "--layout_path_model",
+ "--layout_dict_path",
type=str,
- default="lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config")
+ default="../ppocr/utils/dict/layout_pubalynet_dict.txt")
parser.add_argument(
- "--layout_label_map",
- type=ast.literal_eval,
- default=None,
- help='label map according to ppstructure/layout/README_ch.md')
+ "--layout_score_threshold",
+ type=float,
+ default=0.5,
+ help="Threshold of score.")
+ parser.add_argument(
+ "--layout_nms_threshold", type=float, default=0.5, help="Threshold of nms.")
# params for vqa
parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM')
parser.add_argument("--ser_model_dir", type=str)
@@ -87,7 +90,7 @@ def draw_structure_result(image, result, font_path):
image = Image.fromarray(image)
boxes, txts, scores = [], [], []
for region in result:
- if region['type'] == 'Table':
+ if region['type'] == 'table':
pass
else:
for text_result in region['res']:
diff --git a/ppstructure/vqa/README.md b/ppstructure/vqa/README.md
index 05635265b..28b794383 100644
--- a/ppstructure/vqa/README.md
+++ b/ppstructure/vqa/README.md
@@ -216,7 +216,7 @@ Use the following command to complete the tandem prediction of `OCR + SER` based
```shell
cd ppstructure
-CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
+CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --vis_font_path=../doc/fonts/simfang.ttf --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
```
After the prediction is successful, the visualization images and results will be saved in the directory specified by the `output` field
diff --git a/ppstructure/vqa/README_ch.md b/ppstructure/vqa/README_ch.md
index b421a82d3..f168110ed 100644
--- a/ppstructure/vqa/README_ch.md
+++ b/ppstructure/vqa/README_ch.md
@@ -215,7 +215,7 @@ python3.7 tools/export_model.py -c configs/vqa/ser/layoutxlm.yml -o Architecture
```shell
cd ppstructure
-CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
+CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --vis_font_path=../doc/fonts/simfang.ttf --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
```
预测成功后,可视化图片和结果会保存在`output`字段指定的目录下
diff --git a/ppstructure/vqa/predict_vqa_token_ser.py b/ppstructure/vqa/predict_vqa_token_ser.py
index de0bbfe72..3097ebcf1 100644
--- a/ppstructure/vqa/predict_vqa_token_ser.py
+++ b/ppstructure/vqa/predict_vqa_token_ser.py
@@ -153,7 +153,7 @@ def main(args):
img_res = draw_ser_results(
image_file,
ser_res,
- font_path="../doc/fonts/simfang.ttf", )
+ font_path=args.vis_font_path, )
img_save_path = os.path.join(args.output,
os.path.basename(image_file))
diff --git a/test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml b/test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml
index a1497ba8f..3eb82d42b 100644
--- a/test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml
+++ b/test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml
@@ -114,7 +114,7 @@ Train:
name: SimpleDataSet
data_dir: ./train_data/ic15_data/
label_file_list:
- - ./train_data/ic15_data/rec_gt_train4w.txt
+ - ./train_data/ic15_data/rec_gt_train.txt
transforms:
- DecodeImage:
img_mode: BGR
diff --git a/test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml b/test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml
index ee884f668..4c8ba0a6f 100644
--- a/test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml
+++ b/test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml
@@ -153,7 +153,7 @@ Train:
data_dir: ./train_data/ic15_data/
ext_op_transform_idx: 1
label_file_list:
- - ./train_data/ic15_data/rec_gt_train4w.txt
+ - ./train_data/ic15_data/rec_gt_train.txt
transforms:
- DecodeImage:
img_mode: BGR
diff --git a/test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt
index 420c6592d..59fc1bd41 100644
--- a/test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt
+++ b/test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt
@@ -52,8 +52,9 @@ null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,48,320]}]
===========================train_benchmark_params==========================
-batch_size:128
+batch_size:64
fp_items:fp32|fp16
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
+
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
similarity index 91%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
index b42ab9db3..73f1d4985 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================cpp_infer_params===========================
-model_name:ch_ppocr_mobile_v2.0
+model_name:ch_ppocr_mobile_v2_0
use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
infer_quant:False
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
similarity index 94%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
index becad991e..00373b61e 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================ch_ppocr_mobile_v2.0===========================
-model_name:ch_ppocr_mobile_v2.0
+model_name:ch_ppocr_mobile_v2_0
python:python3.7
infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
infer_export:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_lite_cpp_arm_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_lite_cpp_arm_cpu.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_lite_cpp_arm_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_lite_cpp_arm_cpu.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_lite_cpp_arm_gpu_opencl.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_lite_cpp_arm_gpu_opencl.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_lite_cpp_arm_gpu_opencl.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_lite_cpp_arm_gpu_opencl.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
similarity index 87%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
index 17c2fbbae..3e01ae573 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
@@ -1,5 +1,5 @@
===========================paddle2onnx_params===========================
-model_name:ch_ppocr_mobile_v2.0
+model_name:ch_ppocr_mobile_v2_0
python:python3.7
2onnx: paddle2onnx
--det_model_dir:./inference/ch_ppocr_mobile_v2.0_det_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
similarity index 96%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
index d18e9f11f..305882aa3 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0
+model_name:ch_ppocr_mobile_v2_0
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
similarity index 97%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index 842c93401..0c366b03d 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0
+model_name:ch_ppocr_mobile_v2_0
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
similarity index 88%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
index 1d1c2ae28..ded332e67 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================cpp_infer_params===========================
-model_name:ch_ppocr_mobile_v2.0_det
+model_name:ch_ppocr_mobile_v2_0_det
use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
infer_quant:False
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_infer_python_jetson.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_infer_python_jetson.txt
similarity index 92%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_infer_python_jetson.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_infer_python_jetson.txt
index 24bb8746a..5f9dfa5f5 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_infer_python_jetson.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_infer_python_jetson.txt
@@ -1,5 +1,5 @@
===========================infer_params===========================
-model_name:ch_ppocr_mobile_v2.0_det
+model_name:ch_ppocr_mobile_v2_0_det
python:python
infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer
infer_export:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_lite_cpp_arm_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_lite_cpp_arm_cpu.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_lite_cpp_arm_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_lite_cpp_arm_cpu.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_lite_cpp_arm_gpu_opencl.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_lite_cpp_arm_gpu_opencl.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_lite_cpp_arm_gpu_opencl.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_lite_cpp_arm_gpu_opencl.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
similarity index 84%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
index 00473d106..8f36ad4b8 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
@@ -1,5 +1,5 @@
===========================paddle2onnx_params===========================
-model_name:ch_ppocr_mobile_v2.0_det
+model_name:ch_ppocr_mobile_v2_0_det
python:python3.7
2onnx: paddle2onnx
--det_model_dir:./inference/ch_ppocr_mobile_v2.0_det_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
similarity index 96%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index c9dd5ad92..6dfd7e7bd 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0_det
+model_name:ch_ppocr_mobile_v2_0_det
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_infer_python.txt
similarity index 98%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_infer_python.txt
index 3db816cc0..f3aa9d0f8 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_det
+model_name:ch_ppocr_mobile_v2_0_det
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_linux_dcu_normal_normal_infer_python_dcu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_linux_dcu_normal_normal_infer_python_dcu.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_linux_dcu_normal_normal_infer_python_dcu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_linux_dcu_normal_normal_infer_python_dcu.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
similarity index 97%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
index 5271f78bb..bf81d0baa 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_det
+model_name:ch_ppocr_mobile_v2_0_det
python:python3.7
gpu_list:192.168.0.1,192.168.0.2;0,1
Global.use_gpu:True
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
similarity index 97%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
index 6b3352f74..df71e9070 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_det
+model_name:ch_ppocr_mobile_v2_0_det
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_pact_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_pact_infer_python.txt
similarity index 97%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_pact_infer_python.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_pact_infer_python.txt
index 04c8d0e19..ba880d1f9 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_pact_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_pact_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_PACT
+model_name:ch_ppocr_mobile_v2_0_det_PACT
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_ptq_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_ptq_infer_python.txt
similarity index 93%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_ptq_infer_python.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_ptq_infer_python.txt
index 2bdec8488..45c4fd1ae 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_ptq_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_ptq_infer_python.txt
@@ -1,5 +1,5 @@
===========================kl_quant_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_KL
+model_name:ch_ppocr_mobile_v2_0_det_KL
python:python3.7
Global.pretrained_model:null
Global.save_inference_dir:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_FPGM/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_FPGM/train_infer_python.txt
similarity index 97%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_FPGM/train_infer_python.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det_FPGM/train_infer_python.txt
index dae3f8053..0f6df1ac5 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_FPGM/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_FPGM/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_FPGM
+model_name:ch_ppocr_mobile_v2_0_det_FPGM
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
similarity index 97%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
index 150a8a031..2014c6dbc 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_FPGM
+model_name:ch_ppocr_mobile_v2_0_det_FPGM
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
similarity index 88%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
index eb2fd0a00..f0e58dd56 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================cpp_infer_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_KL
+model_name:ch_ppocr_mobile_v2_0_det_KL
use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_det_klquant_infer
infer_quant:False
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_python_mac_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_infer_python_mac_cpu.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_python_mac_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_infer_python_mac_cpu.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_python_windows_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_infer_python_windows_gpu_cpu.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_python_windows_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_infer_python_windows_gpu_cpu.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
similarity index 95%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
index ab518de55..c5dc52583 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_KL
+model_name:ch_ppocr_mobile_v2_0_det_KL
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_klquant_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
similarity index 96%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index 049ec7845..82d4db32a 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_KL
+model_name:ch_ppocr_mobile_v2_0_det_KL
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_klquant_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
similarity index 87%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
index 17723f41a..513233059 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================cpp_infer_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_PACT
+model_name:ch_ppocr_mobile_v2_0_det_PACT
use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_det_pact_infer
infer_quant:False
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
similarity index 95%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
index 229f70cf3..3be53952f 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_PACT
+model_name:ch_ppocr_mobile_v2_0_det_PACT
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_pact_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
similarity index 95%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index 909d73891..63e7f8f73 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_PACT
+model_name:ch_ppocr_mobile_v2_0_det_PACT
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_pact_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
similarity index 89%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
index 480fb16cd..332e632bd 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================cpp_infer_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec
+model_name:ch_ppocr_mobile_v2_0_rec
use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_rec_infer/
infer_quant:False
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
similarity index 87%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
index 5bab0c9e4..78b76edae 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
@@ -1,5 +1,5 @@
===========================paddle2onnx_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec
+model_name:ch_ppocr_mobile_v2_0_rec
python:python3.7
2onnx: paddle2onnx
--det_model_dir:
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
similarity index 96%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index c0c5291cc..5c60903f6 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec
+model_name:ch_ppocr_mobile_v2_0_rec
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_infer_python.txt
similarity index 98%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_infer_python.txt
index 36fdb1b91..40f397948 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec
+model_name:ch_ppocr_mobile_v2_0_rec
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
similarity index 97%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
index 631118c0a..2f919d102 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec
+model_name:ch_ppocr_mobile_v2_0_rec
python:python3.7
gpu_list:192.168.0.1,192.168.0.2;0,1
Global.use_gpu:True
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
similarity index 97%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
index bd9c4a8df..f60e2790e 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec
+model_name:ch_ppocr_mobile_v2_0_rec
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_pact_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_pact_infer_python.txt
similarity index 90%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_pact_infer_python.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_pact_infer_python.txt
index 77472fbdf..9c1223f41 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_pact_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_pact_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_PACT
+model_name:ch_ppocr_mobile_v2_0_rec_PACT
python:python3.7
gpu_list:0
Global.use_gpu:True|True
@@ -14,7 +14,7 @@ null:null
##
trainer:pact_train
norm_train:null
-pact_train:deploy/slim/quantization/quant.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/rec_chinese_lite_train_v2.0.yml -o
+pact_train:deploy/slim/quantization/quant.py -c test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/rec_chinese_lite_train_v2.0.yml -o
fpgm_train:null
distill_train:null
null:null
@@ -28,7 +28,7 @@ null:null
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:null
-quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/rec_chinese_lite_train_v2.0.yml -o
+quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/rec_chinese_lite_train_v2.0.yml -o
fpgm_export:null
distill_export:null
export1:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_ptq_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_ptq_infer_python.txt
similarity index 84%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_ptq_infer_python.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_ptq_infer_python.txt
index f63fe4c2b..df47f328b 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_ptq_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_ptq_infer_python.txt
@@ -1,10 +1,10 @@
===========================kl_quant_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_KL
+model_name:ch_ppocr_mobile_v2_0_rec_KL
python:python3.7
Global.pretrained_model:null
Global.save_inference_dir:null
infer_model:./inference/ch_ppocr_mobile_v2.0_rec_infer/
-infer_export:deploy/slim/quantization/quant_kl.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/rec_chinese_lite_train_v2.0.yml -o
+infer_export:deploy/slim/quantization/quant_kl.py -c test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/rec_chinese_lite_train_v2.0.yml -o
infer_quant:True
inference:tools/infer/predict_rec.py --rec_image_shape="3,32,320"
--use_gpu:False|True
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/rec_chinese_lite_train_v2.0.yml b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/rec_chinese_lite_train_v2.0.yml
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/rec_chinese_lite_train_v2.0.yml
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/rec_chinese_lite_train_v2.0.yml
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/train_infer_python.txt
similarity index 87%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/train_infer_python.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/train_infer_python.txt
index 89daceeb5..94c950310 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_FPGM
+model_name:ch_ppocr_mobile_v2_0_rec_FPGM
python:python3.7
gpu_list:0
Global.use_gpu:True|True
@@ -15,7 +15,7 @@ null:null
trainer:fpgm_train
norm_train:null
pact_train:null
-fpgm_train:deploy/slim/prune/sensitivity_anal.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model=./pretrain_models/ch_ppocr_mobile_v2.0_rec_train/best_accuracy
+fpgm_train:deploy/slim/prune/sensitivity_anal.py -c test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model=./pretrain_models/ch_ppocr_mobile_v2.0_rec_train/best_accuracy
distill_train:null
null:null
null:null
@@ -29,7 +29,7 @@ Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:null
quant_export:null
-fpgm_export:deploy/slim/prune/export_prune_model.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/rec_chinese_lite_train_v2.0.yml -o
+fpgm_export:deploy/slim/prune/export_prune_model.py -c test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/rec_chinese_lite_train_v2.0.yml -o
distill_export:null
export1:null
export2:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
similarity index 87%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
index 7abc3e934..71555865a 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_FPGM
+model_name:ch_ppocr_mobile_v2_0_rec_FPGM
python:python3.7
gpu_list:0
Global.use_gpu:True|True
@@ -15,7 +15,7 @@ null:null
trainer:fpgm_train
norm_train:null
pact_train:null
-fpgm_train:deploy/slim/prune/sensitivity_anal.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model=./pretrain_models/ch_ppocr_mobile_v2.0_rec_train/best_accuracy
+fpgm_train:deploy/slim/prune/sensitivity_anal.py -c test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model=./pretrain_models/ch_ppocr_mobile_v2.0_rec_train/best_accuracy
distill_train:null
null:null
null:null
@@ -29,7 +29,7 @@ Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:null
quant_export:null
-fpgm_export:deploy/slim/prune/export_prune_model.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/rec_chinese_lite_train_v2.0.yml -o
+fpgm_export:deploy/slim/prune/export_prune_model.py -c test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/rec_chinese_lite_train_v2.0.yml -o
distill_export:null
export1:null
export2:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
similarity index 89%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
index adf06257a..ef4c93fcd 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================cpp_infer_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_KL
+model_name:ch_ppocr_mobile_v2_0_rec_KL
use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_rec_klquant_infer
infer_quant:False
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
similarity index 95%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
index d9de1cc19..d904e22a7 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_KL
+model_name:ch_ppocr_mobile_v2_0_rec_KL
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_klquant_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
similarity index 96%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index 948e3dceb..de4f7ed2c 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_KL
+model_name:ch_ppocr_mobile_v2_0_rec_KL
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/rec_chinese_lite_train_v2.0.yml b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/rec_chinese_lite_train_v2.0.yml
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/rec_chinese_lite_train_v2.0.yml
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/rec_chinese_lite_train_v2.0.yml
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
similarity index 89%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
index ba2df90f7..74ca7b50b 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================cpp_infer_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_PACT
+model_name:ch_ppocr_mobile_v2_0_rec_PACT
use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_rec_pact_infer
infer_quant:False
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
similarity index 95%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
index 1a49a10f9..5a3047448 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_PACT
+model_name:ch_ppocr_mobile_v2_0_rec_PACT
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_pact_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
similarity index 95%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index f123f3654..5871199bc 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_PACT
+model_name:ch_ppocr_mobile_v2_0_rec_PACT
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/rec_chinese_lite_train_v2.0.yml b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/rec_chinese_lite_train_v2.0.yml
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/rec_chinese_lite_train_v2.0.yml
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/rec_chinese_lite_train_v2.0.yml
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
similarity index 91%
rename from test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
index 7c980b2ba..ba8646fd9 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================cpp_infer_params===========================
-model_name:ch_ppocr_server_v2.0
+model_name:ch_ppocr_server_v2_0
use_opencv:True
infer_model:./inference/ch_ppocr_server_v2.0_det_infer/
infer_quant:False
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
similarity index 94%
rename from test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
index b20596f7a..53f8ab0e7 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================ch_ppocr_server_v2.0===========================
-model_name:ch_ppocr_server_v2.0
+model_name:ch_ppocr_server_v2_0
python:python3.7
infer_model:./inference/ch_ppocr_server_v2.0_det_infer/
infer_export:null
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
similarity index 85%
rename from test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
index e478896a5..9e2cf191f 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
@@ -1,5 +1,5 @@
===========================paddle2onnx_params===========================
-model_name:ch_ppocr_server_v2.0
+model_name:ch_ppocr_server_v2_0
python:python3.7
2onnx: paddle2onnx
--det_model_dir:./inference/ch_ppocr_server_v2.0_det_infer/
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
similarity index 96%
rename from test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
index bbfec44db..55b27e04a 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_server_v2.0
+model_name:ch_ppocr_server_v2_0
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_server_v2.0_det_infer/
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
similarity index 97%
rename from test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index 8853e709d..21b8c9a08 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_server_v2.0
+model_name:ch_ppocr_server_v2_0
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_server_v2.0_det_infer/
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml b/test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml
similarity index 100%
rename from test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml
rename to test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
similarity index 88%
rename from test_tipc/configs/ch_ppocr_server_v2.0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
index 69ae939e2..4a30affd0 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================cpp_infer_params===========================
-model_name:ch_ppocr_server_v2.0_det
+model_name:ch_ppocr_server_v2_0_det
use_opencv:True
infer_model:./inference/ch_ppocr_server_v2.0_det_infer/
infer_quant:False
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
similarity index 82%
rename from test_tipc/configs/ch_ppocr_server_v2.0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
index c8bebf54f..b7dd6e22b 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
@@ -1,5 +1,5 @@
===========================paddle2onnx_params===========================
-model_name:ch_ppocr_server_v2.0_det
+model_name:ch_ppocr_server_v2_0_det
python:python3.7
2onnx: paddle2onnx
--det_model_dir:./inference/ch_ppocr_server_v2.0_det_infer/
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
similarity index 96%
rename from test_tipc/configs/ch_ppocr_server_v2.0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index 018dd1a22..4d4f0679b 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_server_v2.0_det
+model_name:ch_ppocr_server_v2_0_det
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_server_v2.0_det_infer/
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt b/test_tipc/configs/ch_ppocr_server_v2_0_det/train_infer_python.txt
similarity index 92%
rename from test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_det/train_infer_python.txt
index 7b90a4078..90ed29f43 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_det/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_server_v2.0_det
+model_name:ch_ppocr_server_v2_0_det
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml -o
quant_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml -o
quant_export:null
fpgm_export:null
distill_export:null
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
similarity index 90%
rename from test_tipc/configs/ch_ppocr_server_v2.0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
index 12388d967..f398078fc 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_server_v2.0_det
+model_name:ch_ppocr_server_v2_0_det
python:python3.7
gpu_list:192.168.0.1,192.168.0.2;0,1
Global.use_gpu:True
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml -o
quant_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml -o
quant_export:null
fpgm_export:null
distill_export:null
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
similarity index 90%
rename from test_tipc/configs/ch_ppocr_server_v2.0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
index 93ed14cb6..7a2d0a53c 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_server_v2.0_det
+model_name:ch_ppocr_server_v2_0_det
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml -o
quant_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml -o
quant_export:null
fpgm_export:null
distill_export:null
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
similarity index 89%
rename from test_tipc/configs/ch_ppocr_server_v2.0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
index cbec272cc..3f3905516 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================cpp_infer_params===========================
-model_name:ch_ppocr_server_v2.0_rec
+model_name:ch_ppocr_server_v2_0_rec
use_opencv:True
infer_model:./inference/ch_ppocr_server_v2.0_rec_infer/
infer_quant:False
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
similarity index 87%
rename from test_tipc/configs/ch_ppocr_server_v2.0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
index 462f6090d..89b966100 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
@@ -1,5 +1,5 @@
===========================paddle2onnx_params===========================
-model_name:ch_ppocr_server_v2.0_rec
+model_name:ch_ppocr_server_v2_0_rec
python:python3.7
2onnx: paddle2onnx
--det_model_dir:
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
similarity index 96%
rename from test_tipc/configs/ch_ppocr_server_v2.0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index 7f456320b..4133e961c 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_server_v2.0_rec
+model_name:ch_ppocr_server_v2_0_rec
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:null
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml b/test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml
similarity index 100%
rename from test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml
rename to test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt b/test_tipc/configs/ch_ppocr_server_v2_0_rec/train_infer_python.txt
similarity index 88%
rename from test_tipc/configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_rec/train_infer_python.txt
index 9fc117d67..b9a1ae498 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_rec/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_server_v2.0_rec
+model_name:ch_ppocr_server_v2_0_rec
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o Global.print_batch_step=4 Train.loader.shuffle=false
+norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o Global.print_batch_step=4 Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/ch_ppocr_server_v2.0_rec_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py
--use_gpu:True|False
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
similarity index 87%
rename from test_tipc/configs/ch_ppocr_server_v2.0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
index 9884ab247..d5f57aef9 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_server_v2.0_rec
+model_name:ch_ppocr_server_v2_0_rec
python:python3.7
gpu_list:192.168.0.1,192.168.0.2;0,1
Global.use_gpu:True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/ch_ppocr_server_v2.0_rec_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py
--use_gpu:False
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
similarity index 87%
rename from test_tipc/configs/ch_ppocr_server_v2.0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
index 63ddaa4a8..20eb10b8e 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_server_v2.0_rec
+model_name:ch_ppocr_server_v2_0_rec
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/ch_ppocr_server_v2.0_rec_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py
--use_gpu:True|False
diff --git a/test_tipc/configs/det_mv3_east_v2.0/det_mv3_east.yml b/test_tipc/configs/det_mv3_east_v2_0/det_mv3_east.yml
similarity index 100%
rename from test_tipc/configs/det_mv3_east_v2.0/det_mv3_east.yml
rename to test_tipc/configs/det_mv3_east_v2_0/det_mv3_east.yml
diff --git a/test_tipc/configs/det_mv3_east_v2.0/train_infer_python.txt b/test_tipc/configs/det_mv3_east_v2_0/train_infer_python.txt
similarity index 91%
rename from test_tipc/configs/det_mv3_east_v2.0/train_infer_python.txt
rename to test_tipc/configs/det_mv3_east_v2_0/train_infer_python.txt
index 1ec1597a4..9c6d9660d 100644
--- a/test_tipc/configs/det_mv3_east_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_mv3_east_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:det_mv3_east_v2.0
+model_name:det_mv3_east_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/det_mv3_east_v2.0/det_mv3_east.yml -o
+norm_train:tools/train.py -c test_tipc/configs/det_mv3_east_v2_0/det_mv3_east.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -27,7 +27,7 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/det_mv3_east_v2.0/det_mv3_east.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/det_mv3_east_v2_0/det_mv3_east.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/det_mv3_east_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/det_mv3_east_v2.0/det_mv3_east.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/det_mv3_east_v2_0/det_mv3_east.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
diff --git a/test_tipc/configs/det_mv3_pse_v2.0/det_mv3_pse.yml b/test_tipc/configs/det_mv3_pse_v2_0/det_mv3_pse.yml
similarity index 100%
rename from test_tipc/configs/det_mv3_pse_v2.0/det_mv3_pse.yml
rename to test_tipc/configs/det_mv3_pse_v2_0/det_mv3_pse.yml
diff --git a/test_tipc/configs/det_mv3_pse_v2.0/train_infer_python.txt b/test_tipc/configs/det_mv3_pse_v2_0/train_infer_python.txt
similarity index 91%
rename from test_tipc/configs/det_mv3_pse_v2.0/train_infer_python.txt
rename to test_tipc/configs/det_mv3_pse_v2_0/train_infer_python.txt
index daeec69f8..525fdc7d4 100644
--- a/test_tipc/configs/det_mv3_pse_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_mv3_pse_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:det_mv3_pse_v2.0
+model_name:det_mv3_pse_v2_0
python:python3.7
gpu_list:0
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/det_mv3_pse_v2.0/det_mv3_pse.yml -o
+norm_train:tools/train.py -c test_tipc/configs/det_mv3_pse_v2_0/det_mv3_pse.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -27,7 +27,7 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/det_mv3_pse_v2.0/det_mv3_pse.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/det_mv3_pse_v2_0/det_mv3_pse.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/det_mv3_pse_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/det_mv3_pse_v2.0/det_mv3_pse.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/det_mv3_pse_v2_0/det_mv3_pse.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
diff --git a/test_tipc/configs/det_r50_db_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_db_v2_0/train_infer_python.txt
similarity index 96%
rename from test_tipc/configs/det_r50_db_v2.0/train_infer_python.txt
rename to test_tipc/configs/det_r50_db_v2_0/train_infer_python.txt
index 11af0ad18..1d0d9693a 100644
--- a/test_tipc/configs/det_r50_db_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_db_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:det_r50_db_v2.0
+model_name:det_r50_db_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
diff --git a/test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml b/test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/det_r50_vd_dcn_fce_ctw.yml
similarity index 96%
rename from test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml
rename to test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/det_r50_vd_dcn_fce_ctw.yml
index 3a513b8f3..29f6f32a5 100644
--- a/test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml
+++ b/test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/det_r50_vd_dcn_fce_ctw.yml
@@ -8,7 +8,7 @@ Global:
# evaluation is run every 835 iterations
eval_batch_step: [0, 4000]
cal_metric_during_train: False
- pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained
+ pretrained_model: pretrain_models/det_r50_dcn_fce_ctw_v2.0_train/best_accuracy.pdparams
checkpoints:
save_inference_dir:
use_visualdl: False
diff --git a/test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/train_infer_python.txt
similarity index 88%
rename from test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/train_infer_python.txt
rename to test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/train_infer_python.txt
index 2d294fd30..92ded19d6 100644
--- a/test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:det_r50_dcn_fce_ctw_v2.0
+model_name:det_r50_dcn_fce_ctw_v2_0
python:python3.7
gpu_list:0
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml -o Global.print_batch_step=1 Train.loader.shuffle=false
+norm_train:tools/train.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/det_r50_vd_dcn_fce_ctw.yml -o Global.print_batch_step=1 Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
@@ -27,7 +27,7 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/det_r50_vd_dcn_fce_ctw.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/det_r50_dcn_fce_ctw_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/det_r50_vd_dcn_fce_ctw.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
diff --git a/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/det_r50_vd_sast_icdar2015.yml b/test_tipc/configs/det_r50_vd_sast_icdar15_v2_0/det_r50_vd_sast_icdar2015.yml
similarity index 100%
rename from test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/det_r50_vd_sast_icdar2015.yml
rename to test_tipc/configs/det_r50_vd_sast_icdar15_v2_0/det_r50_vd_sast_icdar2015.yml
diff --git a/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_sast_icdar15_v2_0/train_infer_python.txt
similarity index 89%
rename from test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt
rename to test_tipc/configs/det_r50_vd_sast_icdar15_v2_0/train_infer_python.txt
index b70ef46b4..b01f1925b 100644
--- a/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_vd_sast_icdar15_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:det_r50_vd_sast_icdar15_v2.0
+model_name:det_r50_vd_sast_icdar15_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/det_r50_vd_sast_icdar2015.yml -o
+norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_sast_icdar15_v2_0/det_r50_vd_sast_icdar2015.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -27,7 +27,7 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/det_r50_vd_sast_icdar2015.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_sast_icdar15_v2_0/det_r50_vd_sast_icdar2015.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
inference_dir:null
train_model:./inference/det_r50_vd_sast_icdar15_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/det_r50_vd_sast_icdar2015.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_sast_icdar15_v2_0/det_r50_vd_sast_icdar2015.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
diff --git a/test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/det_r50_vd_sast_totaltext.yml b/test_tipc/configs/det_r50_vd_sast_totaltext_v2_0/det_r50_vd_sast_totaltext.yml
similarity index 100%
rename from test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/det_r50_vd_sast_totaltext.yml
rename to test_tipc/configs/det_r50_vd_sast_totaltext_v2_0/det_r50_vd_sast_totaltext.yml
diff --git a/test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_sast_totaltext_v2_0/train_infer_python.txt
similarity index 88%
rename from test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/train_infer_python.txt
rename to test_tipc/configs/det_r50_vd_sast_totaltext_v2_0/train_infer_python.txt
index 7be5af7dd..a47ad6803 100644
--- a/test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_vd_sast_totaltext_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:det_r50_vd_sast_totaltext_v2.0
+model_name:det_r50_vd_sast_totaltext_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./pretrain_models/ResNet50_vd_ssld_pretrained
+norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_sast_totaltext_v2_0/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./pretrain_models/ResNet50_vd_ssld_pretrained
pact_train:null
fpgm_train:null
distill_train:null
@@ -27,7 +27,7 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/det_r50_vd_sast_totaltext.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_sast_totaltext_v2_0/det_r50_vd_sast_totaltext.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
inference_dir:null
train_model:./inference/det_r50_vd_sast_totaltext_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/det_r50_vd_sast_totaltext.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_sast_totaltext_v2_0/det_r50_vd_sast_totaltext.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
diff --git a/test_tipc/configs/layoutxlm_ser/train_infer_python.txt b/test_tipc/configs/layoutxlm_ser/train_infer_python.txt
new file mode 100644
index 000000000..34082bc19
--- /dev/null
+++ b/test_tipc/configs/layoutxlm_ser/train_infer_python.txt
@@ -0,0 +1,59 @@
+===========================train_params===========================
+model_name:layoutxlm_ser
+python:python3.7
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:fp32
+Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8
+Architecture.Backbone.checkpoints:null
+train_model_name:latest
+train_infer_img_dir:ppstructure/docs/vqa/input/zh_val_42.jpg
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Architecture.Backbone.checkpoints:
+norm_export:tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml -o
+quant_export:
+fpgm_export:
+distill_export:null
+export1:null
+export2:null
+##
+infer_model:null
+infer_export:null
+infer_quant:False
+inference:ppstructure/vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_dict_path=train_data/XFUND/class_list_xfun.txt --output=output
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:fp32
+--ser_model_dir:
+--image_dir:./ppstructure/docs/vqa/input/zh_val_42.jpg
+null:null
+--benchmark:False
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,224,224]}]
+===========================train_benchmark_params==========================
+batch_size:4
+fp_items:fp32|fp16
+epoch:3
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98
diff --git a/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml b/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/rec_icdar15_train.yml
similarity index 100%
rename from test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml
rename to test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/rec_icdar15_train.yml
diff --git a/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/train_infer_python.txt
similarity index 89%
rename from test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt
rename to test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/train_infer_python.txt
index 4e34a6a52..db89b4c78 100644
--- a/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:rec_mv3_none_bilstm_ctc_v2.0
+model_name:rec_mv3_none_bilstm_ctc_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o Global.print_batch_step=4 Train.loader.shuffle=false
+norm_train:tools/train.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/rec_icdar15_train.yml -o Global.print_batch_step=4 Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+eval:tools/eval.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/rec_mv3_none_bilstm_ctc_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100"
--use_gpu:True|False
diff --git a/test_tipc/configs/rec_mv3_none_none_ctc_v2.0/rec_icdar15_train.yml b/test_tipc/configs/rec_mv3_none_none_ctc_v2_0/rec_icdar15_train.yml
similarity index 100%
rename from test_tipc/configs/rec_mv3_none_none_ctc_v2.0/rec_icdar15_train.yml
rename to test_tipc/configs/rec_mv3_none_none_ctc_v2_0/rec_icdar15_train.yml
diff --git a/test_tipc/configs/rec_mv3_none_none_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_mv3_none_none_ctc_v2_0/train_infer_python.txt
similarity index 87%
rename from test_tipc/configs/rec_mv3_none_none_ctc_v2.0/train_infer_python.txt
rename to test_tipc/configs/rec_mv3_none_none_ctc_v2_0/train_infer_python.txt
index 593de3ff2..003e91ff3 100644
--- a/test_tipc/configs/rec_mv3_none_none_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_mv3_none_none_ctc_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:rec_mv3_none_none_ctc_v2.0
+model_name:rec_mv3_none_none_ctc_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/rec_mv3_none_none_ctc_v2.0/rec_icdar15_train.yml -o
+norm_train:tools/train.py -c test_tipc/configs/rec_mv3_none_none_ctc_v2_0/rec_icdar15_train.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/rec_mv3_none_none_ctc_v2.0/rec_icdar15_train.yml -o
+eval:tools/eval.py -c test_tipc/configs/rec_mv3_none_none_ctc_v2_0/rec_icdar15_train.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/rec_mv3_none_none_ctc_v2.0/rec_icdar15_train.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/rec_mv3_none_none_ctc_v2_0/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/rec_mv3_none_none_ctc_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_none_none_ctc_v2.0/rec_icdar15_train.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_none_none_ctc_v2_0/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100"
--use_gpu:True|False
diff --git a/test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/rec_mv3_tps_bilstm_att.yml b/test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/rec_mv3_tps_bilstm_att.yml
similarity index 100%
rename from test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/rec_mv3_tps_bilstm_att.yml
rename to test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/rec_mv3_tps_bilstm_att.yml
diff --git a/test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/train_infer_python.txt b/test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/train_infer_python.txt
similarity index 88%
rename from test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/train_infer_python.txt
rename to test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/train_infer_python.txt
index 1b2d9abb0..c7b416c83 100644
--- a/test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:rec_mv3_tps_bilstm_att_v2.0
+model_name:rec_mv3_tps_bilstm_att_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/rec_mv3_tps_bilstm_att.yml -o
+norm_train:tools/train.py -c test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/rec_mv3_tps_bilstm_att.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/rec_mv3_tps_bilstm_att.yml -o
+eval:tools/eval.py -c test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/rec_mv3_tps_bilstm_att.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/rec_mv3_tps_bilstm_att.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/rec_mv3_tps_bilstm_att.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/rec_mv3_tps_bilstm_att_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/rec_mv3_tps_bilstm_att.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/rec_mv3_tps_bilstm_att.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="RARE" --min_subgraph_size=5
--use_gpu:True|False
diff --git a/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml b/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml
similarity index 100%
rename from test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml
rename to test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml
diff --git a/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/train_infer_python.txt
similarity index 89%
rename from test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/train_infer_python.txt
rename to test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/train_infer_python.txt
index 1367c7abd..0c6e2d1da 100644
--- a/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:rec_mv3_tps_bilstm_ctc_v2.0
+model_name:rec_mv3_tps_bilstm_ctc_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+norm_train:tools/train.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+eval:tools/eval.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/rec_mv3_tps_bilstm_ctc_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="StarNet"
--use_gpu:True|False
diff --git a/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml
index d0cb20481..21d56b685 100644
--- a/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml
+++ b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml
@@ -8,7 +8,7 @@ Global:
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [0, 2000]
cal_metric_during_train: True
- pretrained_model:
+ pretrained_model: pretrain_models/rec_r32_gaspin_bilstm_att_train/best_accuracy
checkpoints:
save_inference_dir:
use_visualdl: False
diff --git a/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt b/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt
index 4915055a5..115dfd661 100644
--- a/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt
+++ b/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt
@@ -1,6 +1,6 @@
===========================train_params===========================
model_name:rec_r32_gaspin_bilstm_att
-python:python
+python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
@@ -39,11 +39,11 @@ infer_export:tools/export_model.py -c test_tipc/configs/rec_r32_gaspin_bilstm_at
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict/spin_dict.txt --use_space_char=False --rec_image_shape="3,32,100" --rec_algorithm="SPIN"
--use_gpu:True|False
---enable_mkldnn:True|False
---cpu_threads:1|6
+--enable_mkldnn:False
+--cpu_threads:6
--rec_batch_num:1|6
---use_tensorrt:False|False
---precision:fp32|int8
+--use_tensorrt:False
+--precision:fp32
--rec_model_dir:
--image_dir:./inference/rec_inference
--save_log_path:./test/output/
diff --git a/test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/rec_icdar15_train.yml b/test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/rec_icdar15_train.yml
similarity index 100%
rename from test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/rec_icdar15_train.yml
rename to test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/rec_icdar15_train.yml
diff --git a/test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/train_infer_python.txt
similarity index 86%
rename from test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/train_infer_python.txt
rename to test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/train_infer_python.txt
index 46aa3d719..07a6190b0 100644
--- a/test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:rec_r34_vd_none_bilstm_ctc_v2.0
+model_name:rec_r34_vd_none_bilstm_ctc_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+norm_train:tools/train.py -c test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+eval:tools/eval.py -c test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/rec_r34_vd_none_bilstm_ctc_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100"
--use_gpu:True|False
diff --git a/test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/rec_icdar15_train.yml b/test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/rec_icdar15_train.yml
similarity index 100%
rename from test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/rec_icdar15_train.yml
rename to test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/rec_icdar15_train.yml
diff --git a/test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/train_infer_python.txt
similarity index 86%
rename from test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/train_infer_python.txt
rename to test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/train_infer_python.txt
index 3e066d7b7..145793aa4 100644
--- a/test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:rec_r34_vd_none_none_ctc_v2.0
+model_name:rec_r34_vd_none_none_ctc_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/rec_icdar15_train.yml -o
+norm_train:tools/train.py -c test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/rec_icdar15_train.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/rec_icdar15_train.yml -o
+eval:tools/eval.py -c test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/rec_icdar15_train.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/rec_icdar15_train.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/rec_r34_vd_none_none_ctc_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/rec_icdar15_train.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100"
--use_gpu:True|False
diff --git a/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/rec_r34_vd_tps_bilstm_att.yml b/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/rec_r34_vd_tps_bilstm_att.yml
similarity index 100%
rename from test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/rec_r34_vd_tps_bilstm_att.yml
rename to test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/rec_r34_vd_tps_bilstm_att.yml
diff --git a/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/train_infer_python.txt b/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/train_infer_python.txt
similarity index 87%
rename from test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/train_infer_python.txt
rename to test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/train_infer_python.txt
index 1e4f46633..759518a4a 100644
--- a/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:rec_r34_vd_tps_bilstm_att_v2.0
+model_name:rec_r34_vd_tps_bilstm_att_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/rec_r34_vd_tps_bilstm_att.yml -o
+norm_train:tools/train.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/rec_r34_vd_tps_bilstm_att.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/rec_r34_vd_tps_bilstm_att.yml -o
+eval:tools/eval.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/rec_r34_vd_tps_bilstm_att.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/rec_r34_vd_tps_bilstm_att.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/rec_r34_vd_tps_bilstm_att.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/rec_r34_vd_tps_bilstm_att_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/rec_r34_vd_tps_bilstm_att.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/rec_r34_vd_tps_bilstm_att.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="RARE" --min_subgraph_size=5
--use_gpu:True|False
diff --git a/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml b/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml
similarity index 100%
rename from test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml
rename to test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml
diff --git a/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/train_infer_python.txt
similarity index 88%
rename from test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/train_infer_python.txt
rename to test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/train_infer_python.txt
index 9e795b664..ecc898341 100644
--- a/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:rec_r34_vd_tps_bilstm_ctc_v2.0
+model_name:rec_r34_vd_tps_bilstm_ctc_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+norm_train:tools/train.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+eval:tools/eval.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/rec_r34_vd_tps_bilstm_ctc_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="StarNet"
--use_gpu:True|False
diff --git a/test_tipc/docs/benchmark_train.md b/test_tipc/docs/benchmark_train.md
index a7f95eb6c..50cc13b92 100644
--- a/test_tipc/docs/benchmark_train.md
+++ b/test_tipc/docs/benchmark_train.md
@@ -69,7 +69,8 @@ train_log/
| det_r50_vd_east_v2.0 |[config](../configs/det_r50_vd_east_v2.0/train_infer_python.txt) | 42.485 | 42.624 / 42.663 / 42.561 |0.00239083 | 67.61 |67.825/ 68.299/ 68.51| 0.00999854 | 10,000| 2,000|
| det_r50_vd_pse_v2.0 |[config](../configs/det_r50_vd_pse_v2.0/train_infer_python.txt) | 16.455 | 16.517 / 16.555 / 16.353 |0.012201752 | 27.02 |27.288 / 27.152 / 27.408| 0.009340339 | 10,000| 2,000|
| rec_mv3_none_bilstm_ctc_v2.0 |[config](../configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt) | 2288.358 | 2291.906 / 2293.725 / 2290.05 |0.001602197 | 2336.17 |2327.042 / 2328.093 / 2344.915| 0.007622025 | 600,000| 160,000|
+| layoutxlm_ser |[config](../configs/layoutxlm/train_infer_python.txt) | 18.001 | 18.114 / 18.107 / 18.307 |0.010924783 | 21.982 | 21.507 / 21.116 / 21.406| 0.018180127 | 1490 | 1490|
| PP-Structure-table |[config](../configs/en_table_structure/train_infer_python.txt) | 14.151 | 14.077 / 14.23 / 14.25 |0.012140351 | 16.285 | 16.595 / 16.878 / 16.531 | 0.020559308 | 20,000| 5,000|
| det_r50_dcn_fce_ctw_v2.0 |[config](../configs/det_r50_dcn_fce_ctw_v2.0/train_infer_python.txt) | 14.057 | 14.029 / 14.02 / 14.014 |0.001069214 | 18.298 |18.411 / 18.376 / 18.331| 0.004345228 | 10,000| 2,000|
| ch_PP-OCRv3_det |[config](../configs/ch_PP-OCRv3_det/train_infer_python.txt) | 8.622 | 8.431 / 8.423 / 8.479|0.006604552 | 14.203 |14.346 14.468 14.23| 0.016450097 | 10,000| 2,000|
-| ch_PP-OCRv3_rec |[config](../configs/ch_PP-OCRv3_rec/train_infer_python.txt) | 73.627 | 72.46 / 73.575 / 73.704|0.016878324 | | | | 160,000| 40,000|
\ No newline at end of file
+| ch_PP-OCRv3_rec |[config](../configs/ch_PP-OCRv3_rec/train_infer_python.txt) | 90.239 | 90.077 / 91.513 / 91.325|0.01569176 | | | | 160,000| 40,000|
diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh
index cb3fa2440..76543f39e 100644
--- a/test_tipc/prepare.sh
+++ b/test_tipc/prepare.sh
@@ -22,7 +22,7 @@ trainer_list=$(func_parser_value "${lines[14]}")
if [ ${MODE} = "benchmark_train" ];then
pip install -r requirements.txt
- if [[ ${model_name} =~ "ch_ppocr_mobile_v2.0_det" || ${model_name} =~ "det_mv3_db_v2_0" ]];then
+ if [[ ${model_name} =~ "ch_ppocr_mobile_v2_0_det" || ${model_name} =~ "det_mv3_db_v2_0" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
rm -rf ./train_data/icdar2015
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate
@@ -30,7 +30,7 @@ if [ ${MODE} = "benchmark_train" ];then
ln -s ./icdar2015_benckmark ./icdar2015
cd ../
fi
- if [[ ${model_name} =~ "ch_ppocr_server_v2.0_det" || ${model_name} =~ "ch_PP-OCRv3_det" ]];then
+ if [[ ${model_name} =~ "ch_ppocr_server_v2_0_det" || ${model_name} =~ "ch_PP-OCRv3_det" ]];then
rm -rf ./train_data/icdar2015
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015_benckmark.tar
@@ -55,7 +55,7 @@ if [ ${MODE} = "benchmark_train" ];then
ln -s ./icdar2015_benckmark ./icdar2015
cd ../
fi
- if [[ ${model_name} =~ "det_r50_db_v2.0" || ${model_name} =~ "det_r50_vd_pse_v2_0" ]];then
+ if [[ ${model_name} =~ "det_r50_db_v2_0" || ${model_name} =~ "det_r50_vd_pse_v2_0" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
rm -rf ./train_data/icdar2015
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate
@@ -71,13 +71,23 @@ if [ ${MODE} = "benchmark_train" ];then
ln -s ./icdar2015_benckmark ./icdar2015
cd ../
fi
- if [[ ${model_name} =~ "ch_ppocr_mobile_v2.0_rec" || ${model_name} =~ "ch_ppocr_server_v2.0_rec" || ${model_name} =~ "ch_PP-OCRv2_rec" || ${model_name} =~ "rec_mv3_none_bilstm_ctc_v2.0" || ${model_name} =~ "ch_PP-OCRv3_rec" ]];then
- rm -rf ./train_data/ic15_data_benckmark
+ if [[ ${model_name} =~ "ch_ppocr_mobile_v2_0_rec" || ${model_name} =~ "ch_ppocr_server_v2_0_rec" || ${model_name} =~ "ch_PP-OCRv2_rec" || ${model_name} =~ "rec_mv3_none_bilstm_ctc_v2_0" || ${model_name} =~ "ch_PP-OCRv3_rec" ]];then
+ rm -rf ./train_data/ic15_data
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/ic15_data_benckmark.tar --no-check-certificate
cd ./train_data/ && tar xf ic15_data_benckmark.tar
ln -s ./ic15_data_benckmark ./ic15_data
cd ../
fi
+ if [[ ${model_name} =~ "ch_PP-OCRv2_rec" || ${model_name} =~ "ch_PP-OCRv3_rec" ]];then
+ rm -rf ./train_data/ic15_data
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/ic15_data_benckmark.tar --no-check-certificate
+ cd ./train_data/ && tar xf ic15_data_benckmark.tar
+ ln -s ./ic15_data_benckmark ./ic15_data
+ cd ic15_data
+ mv rec_gt_train4w.txt rec_gt_train.txt
+ cd ../
+ cd ../
+ fi
if [[ ${model_name} == "en_table_structure" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf en_ppocr_mobile_v2.0_table_structure_train.tar && cd ../
@@ -87,7 +97,7 @@ if [ ${MODE} = "benchmark_train" ];then
ln -s ./pubtabnet_benckmark ./pubtabnet
cd ../
fi
- if [[ ${model_name} == "det_r50_dcn_fce_ctw_v2.0" ]]; then
+ if [[ ${model_name} == "det_r50_dcn_fce_ctw_v2_0" ]]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/contribution/det_r50_dcn_fce_ctw_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_r50_dcn_fce_ctw_v2.0_train.tar && cd ../
rm -rf ./train_data/icdar2015
@@ -96,6 +106,19 @@ if [ ${MODE} = "benchmark_train" ];then
ln -s ./icdar2015_benckmark ./icdar2015
cd ../
fi
+ if [ ${model_name} == "layoutxlm_ser" ]; then
+ pip install -r ppstructure/vqa/requirements.txt
+ pip install paddlenlp\>=2.3.5 --force-reinstall -i https://mirrors.aliyun.com/pypi/simple/
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate
+ cd ./train_data/ && tar xf XFUND.tar
+ # expand gt.txt 10 times
+ cd XFUND/zh_train
+ for i in `seq 10`;do cp train.json dup$i.txt;done
+ cat dup* > train.json && rm -rf dup*
+ cd ../../
+
+ cd ../
+ fi
fi
if [ ${MODE} = "lite_train_lite_infer" ];then
@@ -161,7 +184,7 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
cd ./pretrain_models/ && tar xf en_server_pgnetA.tar && cd ../
cd ./train_data && tar xf total_text_lite.tar && ln -s total_text_lite total_text && cd ../
fi
- if [ ${model_name} == "det_r50_vd_sast_icdar15_v2.0" ] || [ ${model_name} == "det_r50_vd_sast_totaltext_v2.0" ]; then
+ if [ ${model_name} == "det_r50_vd_sast_icdar15_v2_0" ] || [ ${model_name} == "det_r50_vd_sast_totaltext_v2_0" ]; then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
@@ -172,16 +195,16 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_mv3_db_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "det_r50_db_v2.0" ]; then
+ if [ ${model_name} == "det_r50_db_v2_0" ]; then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_db_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "ch_ppocr_mobile_v2.0_rec_FPGM" ]; then
+ if [ ${model_name} == "ch_ppocr_mobile_v2_0_rec_FPGM" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf ch_ppocr_mobile_v2.0_rec_train.tar && cd ../
fi
- if [ ${model_name} == "det_mv3_east_v2.0" ]; then
+ if [ ${model_name} == "det_mv3_east_v2_0" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_mv3_east_v2.0_train.tar && cd ../
fi
@@ -189,10 +212,21 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_r50_vd_east_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "det_r50_dcn_fce_ctw_v2.0" ]; then
+ if [ ${model_name} == "det_r50_dcn_fce_ctw_v2_0" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/contribution/det_r50_dcn_fce_ctw_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_r50_dcn_fce_ctw_v2.0_train.tar & cd ../
fi
+ if [ ${model_name} == "rec_r32_gaspin_bilstm_att" ]; then
+ wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/rec_r32_gaspin_bilstm_att_train.tar --no-check-certificate
+ cd ./pretrain_models/ && tar xf rec_r32_gaspin_bilstm_att_train.tar && cd ../
+ fi
+ if [ ${model_name} == "layoutxlm_ser" ]; then
+ pip install -r ppstructure/vqa/requirements.txt
+ pip install paddlenlp\>=2.3.5 --force-reinstall -i https://mirrors.aliyun.com/pypi/simple/
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate
+ cd ./train_data/ && tar xf XFUND.tar
+ cd ../
+ fi
elif [ ${MODE} = "whole_train_whole_infer" ];then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
@@ -220,7 +254,7 @@ elif [ ${MODE} = "whole_train_whole_infer" ];then
cd ./pretrain_models/ && tar xf en_server_pgnetA.tar && cd ../
cd ./train_data && tar xf total_text.tar && ln -s total_text_lite total_text && cd ../
fi
- if [ ${model_name} == "det_r50_vd_sast_totaltext_v2.0" ]; then
+ if [ ${model_name} == "det_r50_vd_sast_totaltext_v2_0" ]; then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
cd ./train_data && tar xf total_text.tar && ln -s total_text_lite total_text && cd ../
@@ -264,32 +298,32 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
cd ./inference && tar xf rec_inference.tar && tar xf ch_det_data_50.tar && cd ../
- if [ ${model_name} = "ch_ppocr_mobile_v2.0_det" ]; then
+ if [ ${model_name} = "ch_ppocr_mobile_v2_0_det" ]; then
eval_model_name="ch_ppocr_mobile_v2.0_det_train"
rm -rf ./train_data/icdar2015
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0_det_PACT" ]; then
+ elif [ ${model_name} = "ch_ppocr_mobile_v2_0_det_PACT" ]; then
eval_model_name="ch_ppocr_mobile_v2.0_det_prune_infer"
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_server_v2.0_det" ]; then
+ elif [ ${model_name} = "ch_ppocr_server_v2_0_det" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_train.tar && tar xf ch_det_data_50.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0" ]; then
+ elif [ ${model_name} = "ch_ppocr_mobile_v2_0" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
+ elif [ ${model_name} = "ch_ppocr_server_v2_0" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_PACT" ]; then
+ elif [ ${model_name} = "ch_ppocr_mobile_v2_0_rec_PACT" ]; then
eval_model_name="ch_ppocr_mobile_v2.0_rec_slim_infer"
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_FPGM" ]; then
+ elif [ ${model_name} = "ch_ppocr_mobile_v2_0_rec_FPGM" ]; then
eval_model_name="ch_PP-OCRv2_rec_infer"
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && cd ../
@@ -334,39 +368,39 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar --no-check-certificate
cd ./inference && tar xf en_server_pgnetA.tar && tar xf ch_det_data_50.tar && cd ../
fi
- if [ ${model_name} == "det_r50_vd_sast_icdar15_v2.0" ]; then
+ if [ ${model_name} == "det_r50_vd_sast_icdar15_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_sast_icdar15_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
fi
- if [ ${model_name} == "rec_mv3_none_none_ctc_v2.0" ]; then
+ if [ ${model_name} == "rec_mv3_none_none_ctc_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_mv3_none_none_ctc_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "rec_r34_vd_none_none_ctc_v2.0" ]; then
+ if [ ${model_name} == "rec_r34_vd_none_none_ctc_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_none_ctc_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_r34_vd_none_none_ctc_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "rec_mv3_none_bilstm_ctc_v2.0" ]; then
+ if [ ${model_name} == "rec_mv3_none_bilstm_ctc_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_mv3_none_bilstm_ctc_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "rec_r34_vd_none_bilstm_ctc_v2.0" ]; then
+ if [ ${model_name} == "rec_r34_vd_none_bilstm_ctc_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_r34_vd_none_bilstm_ctc_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "rec_mv3_tps_bilstm_ctc_v2.0" ]; then
+ if [ ${model_name} == "rec_mv3_tps_bilstm_ctc_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_mv3_tps_bilstm_ctc_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "rec_r34_vd_tps_bilstm_ctc_v2.0" ]; then
+ if [ ${model_name} == "rec_r34_vd_tps_bilstm_ctc_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "ch_ppocr_server_v2.0_rec" ]; then
+ if [ ${model_name} == "ch_ppocr_server_v2_0_rec" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_train.tar --no-check-certificate
cd ./inference/ && tar xf ch_ppocr_server_v2.0_rec_train.tar && cd ../
fi
- if [ ${model_name} == "ch_ppocr_mobile_v2.0_rec" ]; then
+ if [ ${model_name} == "ch_ppocr_mobile_v2_0_rec" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar --no-check-certificate
cd ./inference/ && tar xf ch_ppocr_mobile_v2.0_rec_train.tar && cd ../
fi
@@ -374,11 +408,11 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_mtb_nrtr_train.tar && cd ../
fi
- if [ ${model_name} == "rec_mv3_tps_bilstm_att_v2.0" ]; then
+ if [ ${model_name} == "rec_mv3_tps_bilstm_att_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_mv3_tps_bilstm_att_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "rec_r34_vd_tps_bilstm_att_v2.0" ]; then
+ if [ ${model_name} == "rec_r34_vd_tps_bilstm_att_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_r34_vd_tps_bilstm_att_v2.0_train.tar && cd ../
fi
@@ -391,7 +425,7 @@ elif [ ${MODE} = "whole_infer" ];then
cd ./inference/ && tar xf rec_r50_vd_srn_train.tar && cd ../
fi
- if [ ${model_name} == "det_r50_vd_sast_totaltext_v2.0" ]; then
+ if [ ${model_name} == "det_r50_vd_sast_totaltext_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_sast_totaltext_v2.0_train.tar && cd ../
fi
@@ -399,11 +433,11 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_mv3_db_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
fi
- if [ ${model_name} == "det_r50_db_v2.0" ]; then
+ if [ ${model_name} == "det_r50_db_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_db_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
fi
- if [ ${model_name} == "det_mv3_pse_v2.0" ]; then
+ if [ ${model_name} == "det_mv3_pse_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_mv3_pse_v2.0_train.tar & cd ../
fi
@@ -411,7 +445,7 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_pse_v2.0_train.tar & cd ../
fi
- if [ ${model_name} == "det_mv3_east_v2.0" ]; then
+ if [ ${model_name} == "det_mv3_east_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_mv3_east_v2.0_train.tar & cd ../
fi
@@ -419,7 +453,7 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_east_v2.0_train.tar & cd ../
fi
- if [ ${model_name} == "det_r50_dcn_fce_ctw_v2.0" ]; then
+ if [ ${model_name} == "det_r50_dcn_fce_ctw_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/contribution/det_r50_dcn_fce_ctw_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_dcn_fce_ctw_v2.0_train.tar & cd ../
fi
@@ -434,7 +468,7 @@ fi
if [[ ${model_name} =~ "KL" ]]; then
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015_lite.tar && rm -rf ./icdar2015 && ln -s ./icdar2015_lite ./icdar2015 && cd ../
- if [ ${model_name} = "ch_ppocr_mobile_v2.0_det_KL" ]; then
+ if [ ${model_name} = "ch_ppocr_mobile_v2_0_det_KL" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
@@ -466,7 +500,7 @@ if [[ ${model_name} =~ "KL" ]]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar --no-check-certificate
cd ./inference && tar xf ch_PP-OCRv3_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
fi
- if [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_KL" ]; then
+ if [ ${model_name} = "ch_ppocr_mobile_v2_0_rec_KL" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar --no-check-certificate
@@ -484,35 +518,35 @@ if [[ ${model_name} =~ "KL" ]]; then
fi
if [ ${MODE} = "cpp_infer" ];then
- if [ ${model_name} = "ch_ppocr_mobile_v2.0_det" ]; then
+ if [ ${model_name} = "ch_ppocr_mobile_v2_0_det" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0_det_KL" ]; then
+ elif [ ${model_name} = "ch_ppocr_mobile_v2_0_det_KL" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_det_klquant_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_klquant_infer.tar && tar xf ch_det_data_50.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0_det_PACT" ]; then
+ elif [ ${model_name} = "ch_ppocr_mobile_v2_0_det_PACT" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_det_pact_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_pact_infer.tar && tar xf ch_det_data_50.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec" ]; then
+ elif [ ${model_name} = "ch_ppocr_mobile_v2_0_rec" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf rec_inference.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_KL" ]; then
+ elif [ ${model_name} = "ch_ppocr_mobile_v2_0_rec_KL" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_rec_klquant_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_rec_klquant_infer.tar && tar xf rec_inference.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_PACT" ]; then
+ elif [ ${model_name} = "ch_ppocr_mobile_v2_0_rec_PACT" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_rec_pact_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_rec_pact_infer.tar && tar xf rec_inference.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_server_v2.0_det" ]; then
+ elif [ ${model_name} = "ch_ppocr_server_v2_0_det" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_server_v2.0_rec" ]; then
+ elif [ ${model_name} = "ch_ppocr_server_v2_0_rec" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_rec_infer.tar && tar xf rec_inference.tar && cd ../
@@ -564,12 +598,12 @@ if [ ${MODE} = "cpp_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_PP-OCRv3_rec_pact_infer.tar --no-check-certificate
cd ./inference && tar xf ch_PP-OCRv3_rec_pact_infer.tar && tar xf rec_inference.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0" ]; then
+ elif [ ${model_name} = "ch_ppocr_mobile_v2_0" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
+ elif [ ${model_name} = "ch_ppocr_server_v2_0" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
@@ -597,7 +631,7 @@ if [ ${MODE} = "serving_infer" ];then
${python_name} -m pip install paddle_serving_client
${python_name} -m pip install paddle-serving-app
# wget model
- if [ ${model_name} == "ch_ppocr_mobile_v2.0_det_KL" ] || [ ${model_name} == "ch_ppocr_mobile_v2.0_rec_KL" ] ; then
+ if [ ${model_name} == "ch_ppocr_mobile_v2_0_det_KL" ] || [ ${model_name} == "ch_ppocr_mobile_v2.0_rec_KL" ] ; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_det_klquant_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_rec_klquant_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_klquant_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_klquant_infer.tar && cd ../
@@ -609,7 +643,7 @@ if [ ${MODE} = "serving_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_PP-OCRv3_det_klquant_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_PP-OCRv3_rec_klquant_infer.tar --no-check-certificate
cd ./inference && tar xf ch_PP-OCRv3_det_klquant_infer.tar && tar xf ch_PP-OCRv3_rec_klquant_infer.tar && cd ../
- elif [ ${model_name} == "ch_ppocr_mobile_v2.0_det_PACT" ] || [ ${model_name} == "ch_ppocr_mobile_v2.0_rec_PACT" ] ; then
+ elif [ ${model_name} == "ch_ppocr_mobile_v2_0_det_PACT" ] || [ ${model_name} == "ch_ppocr_mobile_v2.0_rec_PACT" ] ; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_det_pact_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_rec_pact_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_pact_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_pact_infer.tar && cd ../
@@ -621,11 +655,11 @@ if [ ${MODE} = "serving_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_PP-OCRv3_det_pact_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_PP-OCRv3_rec_pact_infer.tar --no-check-certificate
cd ./inference && tar xf ch_PP-OCRv3_det_pact_infer.tar && tar xf ch_PP-OCRv3_rec_pact_infer.tar && cd ../
- elif [[ ${model_name} =~ "ch_ppocr_mobile_v2.0" ]]; then
+ elif [[ ${model_name} =~ "ch_ppocr_mobile_v2_0" ]]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && cd ../
- elif [[ ${model_name} =~ "ch_ppocr_server_v2.0" ]]; then
+ elif [[ ${model_name} =~ "ch_ppocr_server_v2_0" ]]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && cd ../
@@ -650,11 +684,11 @@ if [ ${MODE} = "paddle2onnx_infer" ];then
${python_name} -m pip install paddle2onnx
${python_name} -m pip install onnxruntime
# wget model
- if [[ ${model_name} =~ "ch_ppocr_mobile_v2.0" ]]; then
+ if [[ ${model_name} =~ "ch_ppocr_mobile_v2_0" ]]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && cd ../
- elif [[ ${model_name} =~ "ch_ppocr_server_v2.0" ]]; then
+ elif [[ ${model_name} =~ "ch_ppocr_server_v2_0" ]]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && cd ../
diff --git a/test_tipc/prepare_lite_cpp.sh b/test_tipc/prepare_lite_cpp.sh
index 9148cb5dd..0d3a5ca45 100644
--- a/test_tipc/prepare_lite_cpp.sh
+++ b/test_tipc/prepare_lite_cpp.sh
@@ -49,7 +49,7 @@ model_path=./inference_models
for model in ${lite_model_list[*]}; do
if [[ $model =~ "PP-OCRv2" ]]; then
inference_model_url=https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/${model}.tar
- elif [[ $model =~ "v2.0" ]]; then
+ elif [[ $model =~ "v2_0" ]]; then
inference_model_url=https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/${model}.tar
elif [[ $model =~ "PP-OCRv3" ]]; then
inference_model_url=https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/${model}.tar
diff --git a/test_tipc/test_paddle2onnx.sh b/test_tipc/test_paddle2onnx.sh
index 356bc9804..78d79d0b8 100644
--- a/test_tipc/test_paddle2onnx.sh
+++ b/test_tipc/test_paddle2onnx.sh
@@ -54,7 +54,7 @@ function func_paddle2onnx(){
_script=$1
# paddle2onnx
- if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2.0" ] || [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
+ if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2_0" ] || [ ${model_name} = "ch_ppocr_server_v2_0" ]; then
# trans det
set_dirname=$(func_set_params "--model_dir" "${det_infer_model_dir_value}")
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
@@ -113,7 +113,7 @@ function func_paddle2onnx(){
_save_log_path="${LOG_PATH}/paddle2onnx_infer_cpu.log"
set_gpu=$(func_set_params "${use_gpu_key}" "${use_gpu}")
set_img_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}")
- if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2.0" ] || [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
+ if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2_0" ] || [ ${model_name} = "ch_ppocr_server_v2_0" ]; then
set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}")
set_rec_model_dir=$(func_set_params "${rec_model_key}" "${rec_save_file_value}")
infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} ${set_rec_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 "
@@ -132,7 +132,7 @@ function func_paddle2onnx(){
_save_log_path="${LOG_PATH}/paddle2onnx_infer_gpu.log"
set_gpu=$(func_set_params "${use_gpu_key}" "${use_gpu}")
set_img_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}")
- if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2.0" ] || [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
+ if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2_0" ] || [ ${model_name} = "ch_ppocr_server_v2_0" ]; then
set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}")
set_rec_model_dir=$(func_set_params "${rec_model_key}" "${rec_save_file_value}")
infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} ${set_rec_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 "
diff --git a/test_tipc/test_serving_infer_python.sh b/test_tipc/test_serving_infer_python.sh
index 4ccccc06e..4b7dfcf78 100644
--- a/test_tipc/test_serving_infer_python.sh
+++ b/test_tipc/test_serving_infer_python.sh
@@ -71,7 +71,7 @@ function func_serving(){
# pdserving
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
- if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2.0" ] || [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
+ if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2_0" ] || [ ${model_name} = "ch_ppocr_server_v2_0" ]; then
# trans det
set_dirname=$(func_set_params "--dirname" "${det_infer_model_dir_value}")
set_serving_server=$(func_set_params "--serving_server" "${det_serving_server_value}")
@@ -120,7 +120,7 @@ function func_serving(){
for threads in ${web_cpu_threads_list[*]}; do
set_cpu_threads=$(func_set_params "${web_cpu_threads_key}" "${threads}")
server_log_path="${LOG_PATH}/python_server_cpu_usemkldnn_${use_mkldnn}_threads_${threads}.log"
- if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2.0" ] || [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
+ if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2_0" ] || [ ${model_name} = "ch_ppocr_server_v2_0" ]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
@@ -171,7 +171,7 @@ function func_serving(){
device_type=2
fi
set_precision=$(func_set_params "${web_precision_key}" "${precision}")
- if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2.0" ] || [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
+ if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2_0" ] || [ ${model_name} = "ch_ppocr_server_v2_0" ]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
diff --git a/test_tipc/test_train_inference_python.sh b/test_tipc/test_train_inference_python.sh
index 402f636b1..545cdbba2 100644
--- a/test_tipc/test_train_inference_python.sh
+++ b/test_tipc/test_train_inference_python.sh
@@ -101,6 +101,7 @@ function func_inference(){
_log_path=$4
_img_dir=$5
_flag_quant=$6
+ _gpu=$7
# inference
for use_gpu in ${use_gpu_list[*]}; do
if [ ${use_gpu} = "False" ] || [ ${use_gpu} = "cpu" ]; then
@@ -119,7 +120,7 @@ function func_inference(){
fi # skip when quant model inference but precision is not int8
set_precision=$(func_set_params "${precision_key}" "${precision}")
- _save_log_path="${_log_path}/python_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}.log"
+ _save_log_path="${_log_path}/python_infer_cpu_gpus_${_gpu}_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}.log"
set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
@@ -150,7 +151,7 @@ function func_inference(){
continue
fi
for batch_size in ${batch_size_list[*]}; do
- _save_log_path="${_log_path}/python_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log"
+ _save_log_path="${_log_path}/python_infer_gpu_gpus_${_gpu}_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log"
set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
@@ -184,6 +185,7 @@ if [ ${MODE} = "whole_infer" ]; then
# set CUDA_VISIBLE_DEVICES
eval $env
export Count=0
+ gpu=0
IFS="|"
infer_run_exports=(${infer_export_list})
infer_quant_flag=(${infer_is_quant})
@@ -205,7 +207,7 @@ if [ ${MODE} = "whole_infer" ]; then
fi
#run inference
is_quant=${infer_quant_flag[Count]}
- func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${LOG_PATH}" "${infer_img_dir}" ${is_quant}
+ func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${LOG_PATH}" "${infer_img_dir}" ${is_quant} "${gpu}"
Count=$(($Count + 1))
done
else
@@ -328,7 +330,7 @@ else
else
infer_model_dir=${save_infer_path}
fi
- func_inference "${python}" "${inference_py}" "${infer_model_dir}" "${LOG_PATH}" "${train_infer_img_dir}" "${flag_quant}"
+ func_inference "${python}" "${inference_py}" "${infer_model_dir}" "${LOG_PATH}" "${train_infer_img_dir}" "${flag_quant}" "${gpu}"
eval "unset CUDA_VISIBLE_DEVICES"
fi
diff --git a/tools/eval.py b/tools/eval.py
index 6f5189fd6..38d72d178 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -73,7 +73,7 @@ def main():
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
- extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "RobustScanner"]
+ extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
for key in config['Architecture']["Models"]:
diff --git a/tools/export_model.py b/tools/export_model.py
index 1cd273b20..cea02e4c5 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -113,6 +113,12 @@ def export_single_model(model,
shape=[None, 1, 32, 100], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
+ elif arch_config["algorithm"] == "VisionLAN":
+ other_shape = [
+ paddle.static.InputSpec(
+ shape=[None, 3, 64, 256], dtype="float32"),
+ ]
+ model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec = [
paddle.static.InputSpec(
@@ -233,4 +239,4 @@ def main():
if __name__ == "__main__":
- main()
+ main()
\ No newline at end of file
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 428b47a40..34d9497fe 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -68,7 +68,13 @@ class TextRecognizer(object):
'name': 'SARLabelDecode',
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
- }
+ }
+ elif self.rec_algorithm == "VisionLAN":
+ postprocess_params = {
+ 'name': 'VLLabelDecode',
+ "character_dict_path": args.rec_char_dict_path,
+ "use_space_char": args.use_space_char
+ }
elif self.rec_algorithm == 'ViTSTR':
postprocess_params = {
'name': 'ViTSTRLabelDecode',
@@ -164,6 +170,16 @@ class TextRecognizer(object):
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
+ def resize_norm_img_vl(self, img, image_shape):
+
+ imgC, imgH, imgW = image_shape
+ img = img[:, :, ::-1] # bgr2rgb
+ resized_image = cv2.resize(
+ img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+ resized_image = resized_image.astype('float32')
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ return resized_image
+
def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape
@@ -287,6 +303,7 @@ class TextRecognizer(object):
img -= mean
img *= stdinv
return img
+
def resize_norm_img_svtr(self, img, image_shape):
imgC, imgH, imgW = image_shape
@@ -367,6 +384,11 @@ class TextRecognizer(object):
self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
+ elif self.rec_algorithm == "VisionLAN":
+ norm_img = self.resize_norm_img_vl(img_list[indices[ino]],
+ self.rec_image_shape)
+ norm_img = norm_img[np.newaxis, :]
+ norm_img_batch.append(norm_img)
elif self.rec_algorithm == 'SPIN':
norm_img = self.resize_norm_img_spin(img_list[indices[ino]])
norm_img = norm_img[np.newaxis, :]
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 7eb77dec7..9345106e7 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -38,6 +38,7 @@ def init_args():
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--min_subgraph_size", type=int, default=15)
+ parser.add_argument("--shape_info_filename", type=str, default=None)
parser.add_argument("--precision", type=str, default="fp32")
parser.add_argument("--gpu_mem", type=int, default=500)
@@ -204,9 +205,18 @@ def create_predictor(args, mode, logger):
workspace_size=1 << 30,
precision_mode=precision,
max_batch_size=args.max_batch_size,
- min_subgraph_size=args.min_subgraph_size,
+ min_subgraph_size=args.min_subgraph_size, # skip the minmum trt subgraph
use_calib_mode=False)
- # skip the minmum trt subgraph
+
+ # collect shape
+ if args.shape_info_filename is not None:
+ if not os.path.exists(args.shape_info_filename):
+ config.collect_shape_range_info(args.shape_info_filename)
+ logger.info(f"collect dynamic shape info into : {args.shape_info_filename}")
+ else:
+ logger.info(f"dynamic shape info file( {args.shape_info_filename} ) already exists, not need to generate again.")
+ config.enable_tuned_tensorrt_dynamic_shape(args.shape_info_filename, True)
+
use_dynamic_shape = True
if mode == "det":
min_input_shape = {
diff --git a/tools/infer_rec.py b/tools/infer_rec.py
index 670733cb9..14b14544e 100755
--- a/tools/infer_rec.py
+++ b/tools/infer_rec.py
@@ -139,7 +139,6 @@ def main():
img_metas = [paddle.to_tensor(valid_ratio),
paddle.to_tensor(word_positons),
]
-
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN":
diff --git a/tools/infer_vqa_token_ser_re.py b/tools/infer_vqa_token_ser_re.py
index 20ab1fe17..51378bdae 100755
--- a/tools/infer_vqa_token_ser_re.py
+++ b/tools/infer_vqa_token_ser_re.py
@@ -113,10 +113,13 @@ def make_input(ser_inputs, ser_results):
class SerRePredictor(object):
def __init__(self, config, ser_config):
+ global_config = config['Global']
+ if "infer_mode" in global_config:
+ ser_config["Global"]["infer_mode"] = global_config["infer_mode"]
+
self.ser_engine = SerPredictor(ser_config)
# init re model
- global_config = config['Global']
# build post process
self.post_process_class = build_post_process(config['PostProcess'],
@@ -130,8 +133,8 @@ class SerRePredictor(object):
self.model.eval()
- def __call__(self, img_path):
- ser_results, ser_inputs = self.ser_engine({'img_path': img_path})
+ def __call__(self, data):
+ ser_results, ser_inputs = self.ser_engine(data)
re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
preds = self.model(re_input)
post_result = self.post_process_class(
@@ -173,18 +176,33 @@ if __name__ == '__main__':
ser_re_engine = SerRePredictor(config, ser_config)
- infer_imgs = get_image_file_list(config['Global']['infer_img'])
+ if config["Global"].get("infer_mode", None) is False:
+ data_dir = config['Eval']['dataset']['data_dir']
+ with open(config['Global']['infer_img'], "rb") as f:
+ infer_imgs = f.readlines()
+ else:
+ infer_imgs = get_image_file_list(config['Global']['infer_img'])
+
with open(
os.path.join(config['Global']['save_res_path'],
"infer_results.txt"),
"w",
encoding='utf-8') as fout:
- for idx, img_path in enumerate(infer_imgs):
+ for idx, info in enumerate(infer_imgs):
+ if config["Global"].get("infer_mode", None) is False:
+ data_line = info.decode('utf-8')
+ substr = data_line.strip("\n").split("\t")
+ img_path = os.path.join(data_dir, substr[0])
+ data = {'img_path': img_path, 'label': substr[1]}
+ else:
+ img_path = info
+ data = {'img_path': img_path}
+
save_img_path = os.path.join(
config['Global']['save_res_path'],
os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg")
- result = ser_re_engine(img_path)
+ result = ser_re_engine(data)
result = result[0]
fout.write(img_path + "\t" + json.dumps(
{
diff --git a/tools/program.py b/tools/program.py
index 374d88072..3b49d1a4d 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -154,13 +154,14 @@ def check_xpu(use_xpu):
except Exception as e:
pass
+
def to_float32(preds):
if isinstance(preds, dict):
for k in preds:
if isinstance(preds[k], dict) or isinstance(preds[k], list):
preds[k] = to_float32(preds[k])
else:
- preds[k] = preds[k].astype(paddle.float32)
+ preds[k] = paddle.to_tensor(preds[k], dtype='float32')
elif isinstance(preds, list):
for k in range(len(preds)):
if isinstance(preds[k], dict):
@@ -168,11 +169,12 @@ def to_float32(preds):
elif isinstance(preds[k], list):
preds[k] = to_float32(preds[k])
else:
- preds[k] = preds[k].astype(paddle.float32)
+ preds[k] = paddle.to_tensor(preds[k], dtype='float32')
else:
- preds = preds.astype(paddle.float32)
+ preds = paddle.to_tensor(preds, dtype='float32')
return preds
+
def train(config,
train_dataloader,
valid_dataloader,
@@ -225,7 +227,9 @@ def train(config,
model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
- extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "RobustScanner"]
+ extra_input_models = [
+ "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN", "RobustScanner"
+ ]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
for key in config['Architecture']["Models"]:
@@ -267,7 +271,6 @@ def train(config,
images = batch[0]
if use_srn:
model_average = True
-
# use amp
if scaler:
with paddle.amp.auto_cast(level='O2'):
@@ -308,6 +311,9 @@ def train(config,
]: # for multi head loss
post_result = post_process_class(
preds['ctc'], batch[1]) # for CTC head out
+ elif config['Loss']['name'] in ['VLLoss']:
+ post_result = post_process_class(preds, batch[1],
+ batch[-1])
else:
post_result = post_process_class(preds, batch[1])
eval_class(post_result, batch)
@@ -370,7 +376,8 @@ def train(config,
post_process_class,
eval_class,
model_type,
- extra_input=extra_input)
+ extra_input=extra_input,
+ scaler=scaler)
cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
logger.info(cur_metric_str)
@@ -460,7 +467,8 @@ def eval(model,
post_process_class,
eval_class,
model_type=None,
- extra_input=False):
+ extra_input=False,
+ scaler=None):
model.eval()
with paddle.no_grad():
total_frame = 0.0
@@ -477,12 +485,24 @@ def eval(model,
break
images = batch[0]
start = time.time()
- if model_type == 'table' or extra_input:
- preds = model(images, data=batch[1:])
- elif model_type in ["kie", 'vqa']:
- preds = model(batch)
+
+ # use amp
+ if scaler:
+ with paddle.amp.auto_cast(level='O2'):
+ if model_type == 'table' or extra_input:
+ preds = model(images, data=batch[1:])
+ elif model_type in ["kie", 'vqa']:
+ preds = model(batch)
+ else:
+ preds = model(images)
else:
- preds = model(images)
+ if model_type == 'table' or extra_input:
+ preds = model(images, data=batch[1:])
+ elif model_type in ["kie", 'vqa']:
+ preds = model(batch)
+ else:
+ preds = model(images)
+
batch_numpy = []
for item in batch:
if isinstance(item, paddle.Tensor):
@@ -596,7 +616,8 @@ def preprocess(is_train=False):
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
- 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'RobustScanner'
+ 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
+ 'RobustScanner'
]
if use_xpu:
@@ -615,7 +636,7 @@ def preprocess(is_train=False):
if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
save_model_dir = config['Global']['save_model_dir']
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
- log_writer = VDLLogger(save_model_dir)
+ log_writer = VDLLogger(vdl_writer_path)
loggers.append(log_writer)
if ('use_wandb' in config['Global'] and
config['Global']['use_wandb']) or 'wandb' in config:
diff --git a/tools/train.py b/tools/train.py
index 309d4bb9e..dc8cae8a6 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -119,9 +119,6 @@ def main(config, device, logger, vdl_writer):
config['Loss']['ignore_index'] = char_num - 1
model = build_model(config['Architecture'])
- if config['Global']['distributed']:
- model = paddle.DataParallel(model)
-
model = apply_to_static(model, config, logger)
# build loss
@@ -157,10 +154,13 @@ def main(config, device, logger, vdl_writer):
scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss,
use_dynamic_loss_scaling=use_dynamic_loss_scaling)
- model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer, level='O2', master_weight=True)
+ model, optimizer = paddle.amp.decorate(
+ models=model, optimizers=optimizer, level='O2', master_weight=True)
else:
scaler = None
+ if config['Global']['distributed']:
+ model = paddle.DataParallel(model)
# start train
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,