add re predict

parent 0a59848d3a
commit 06194524ca

@@ -68,6 +68,7 @@ Train:
       - VQAReTokenRelation:
       - VQAReTokenChunk:
           max_seq_len: *max_seq_len
+      - TensorizeEntitiesRelations:
       - Resize:
           size: [224,224]
       - NormalizeImage:
@@ -83,7 +84,6 @@ Train:
     drop_last: False
     batch_size_per_card: 2
     num_workers: 8
-    collate_fn: ListCollator

 Eval:
   dataset:
@@ -105,6 +105,7 @@ Eval:
       - VQAReTokenRelation:
       - VQAReTokenChunk:
           max_seq_len: *max_seq_len
+      - TensorizeEntitiesRelations:
       - Resize:
           size: [224,224]
       - NormalizeImage:
@@ -120,4 +121,3 @@ Eval:
     drop_last: False
     batch_size_per_card: 8
     num_workers: 8
-    collate_fn: ListCollator

@@ -73,6 +73,7 @@ Train:
       - VQAReTokenRelation:
       - VQAReTokenChunk:
           max_seq_len: *max_seq_len
+      - TensorizeEntitiesRelations:
       - Resize:
           size: [224,224]
       - NormalizeImage:
@@ -88,7 +89,6 @@ Train:
     drop_last: False
     batch_size_per_card: 2
     num_workers: 4
-    collate_fn: ListCollator

 Eval:
   dataset:
@@ -112,6 +112,7 @@ Eval:
       - VQAReTokenRelation:
       - VQAReTokenChunk:
           max_seq_len: *max_seq_len
+      - TensorizeEntitiesRelations:
       - Resize:
           size: [224,224]
       - NormalizeImage:
@@ -127,5 +128,3 @@ Eval:
     drop_last: False
     batch_size_per_card: 8
     num_workers: 8
-    collate_fn: ListCollator
-

@@ -116,6 +116,7 @@ Train:
       - VQAReTokenRelation:
       - VQAReTokenChunk:
           max_seq_len: *max_seq_len
+      - TensorizeEntitiesRelations:
       - Resize:
           size: [224,224]
       - NormalizeImage:
@@ -155,6 +156,7 @@ Eval:
       - VQAReTokenRelation:
       - VQAReTokenChunk:
           max_seq_len: *max_seq_len
+      - TensorizeEntitiesRelations:
       - Resize:
           size: [224,224]
       - NormalizeImage:

@@ -30,7 +30,7 @@
 |Model|Backbone|Task|Config|Hmean|Download link|
 | --- | --- |--|--- | --- | --- |
 |LayoutXLM|LayoutXLM-base|SER |[ser_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml)|90.38%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)/[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar)|
-|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[inference model (coming soon)]()|
+|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[inference model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh_infer.tar)|

 <a name="2"></a>

@@ -52,14 +52,14 @@

 ### 4.1 Python Inference

-**Note:** The inference pipeline for the RE task is still being adapted. The following uses the SER task as an example to describe key information extraction with the LayoutXLM model.
+- SER

 First, convert the trained model into an inference model. Taking the LayoutXLM model trained on the XFUND_zh dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)), use the following commands to convert it.

 ``` bash
 wget https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar
 tar -xf ser_LayoutXLM_xfun_zh.tar
-python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh/best_accuracy Global.save_inference_dir=./inference/ser_layoutxlm
+python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/ser_layoutxlm_infer
 ```

 To run SER inference with the LayoutXLM model, execute the following command:
@@ -80,6 +80,34 @@ SER visualization results are saved to the `./output` folder by default. An example result is shown below
 <img src="../../ppstructure/docs/kie/result_ser/zh_val_42_ser.jpg" width="800">
 </div>

+- RE
+
+First, convert the trained model into an inference model. Taking the LayoutXLM model trained on the XFUND_zh dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)), use the following commands to convert it.
+
+``` bash
+wget https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar
+tar -xf re_LayoutXLM_xfun_zh.tar
+python3 tools/export_model.py -c configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/re_layoutxlm_infer
+```
+
+To run RE inference with the LayoutXLM model, execute the following command:
+
+```bash
+cd ppstructure
+python3 kie/predict_kie_token_ser_re.py \
+  --kie_algorithm=LayoutXLM \
+  --re_model_dir=../inference/re_layoutxlm_infer \
+  --ser_model_dir=../inference/ser_layoutxlm_infer \
+  --image_dir=./docs/kie/input/zh_val_42.jpg \
+  --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
+  --vis_font_path=../doc/fonts/simfang.ttf
+```
+
+RE visualization results are saved to the `./output` folder by default. An example result is shown below:
+
+<div align="center">
+    <img src="../../ppstructure/docs/kie/result_re/zh_val_42_re.jpg" width="800">
+</div>

 <a name="4-2"></a>
 ### 4.2 C++ Inference and Deployment

@@ -23,7 +23,7 @@ VI-LayoutXLM improves on LayoutXLM: during downstream task training, it removes
 |Model|Backbone|Task|Config|Hmean|Download link|
 | --- | --- |---| --- | --- | --- |
 |VI-LayoutXLM |VI-LayoutXLM-base | SER |[ser_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml)|93.19%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar)|
-|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[inference model (coming soon)]()|
+|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar)|

 <a name="2"></a>

@@ -45,7 +45,7 @@ VI-LayoutXLM improves on LayoutXLM: during downstream task training, it removes

 ### 4.1 Python Inference

-**Note:** The inference pipeline for the RE task is still being adapted. The following uses the SER task as an example to describe key information extraction with the VI-LayoutXLM model.
+- SER

 First, convert the trained model into an inference model. Taking the VI-LayoutXLM model trained on the XFUND_zh dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)), use the following commands to convert it.

@@ -74,6 +74,36 @@ SER visualization results are saved to the `./output` folder by default. An example result is shown below
 <img src="../../ppstructure/docs/kie/result_ser/zh_val_42_ser.jpg" width="800">
 </div>

+- RE
+
+First, convert the trained model into an inference model. Taking the VI-LayoutXLM model trained on the XFUND_zh dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)), use the following commands to convert it.
+
+``` bash
+wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar
+tar -xf re_vi_layoutxlm_xfund_pretrained.tar
+python3 tools/export_model.py -c configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_vi_layoutxlm_xfund_pretrained/best_accuracy Global.save_inference_dir=./inference/re_vi_layoutxlm_infer
+```
+
+To run RE inference with the VI-LayoutXLM model, execute the following command:
+
+```bash
+cd ppstructure
+python3 kie/predict_kie_token_ser_re.py \
+  --kie_algorithm=LayoutXLM \
+  --re_model_dir=../inference/re_vi_layoutxlm_infer \
+  --ser_model_dir=../inference/ser_vi_layoutxlm_infer \
+  --use_visual_backbone=False \
+  --image_dir=./docs/kie/input/zh_val_42.jpg \
+  --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
+  --vis_font_path=../doc/fonts/simfang.ttf \
+  --ocr_order_method="tb-yx"
+```
+
+RE visualization results are saved to the `./output` folder by default. An example result is shown below:
+
+<div align="center">
+    <img src="../../ppstructure/docs/kie/result_re/zh_val_42_re.jpg" width="800">
+</div>

 <a name="4-2"></a>
 ### 4.2 C++ Inference and Deployment

@@ -28,7 +28,7 @@ On XFUND_zh dataset, the algorithm reproduction Hmean is as follows.
 |Model|Backbone|Task |Config|Hmean|Download link|
 | --- | --- |--|--- | --- | --- |
 |LayoutXLM|LayoutXLM-base|SER |[ser_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml)|90.38%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)/[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar)|
-|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[inference model (coming soon)]()|
+|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[inference model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh_infer.tar)|


 ## 2. Environment
@@ -46,7 +46,7 @@ Please refer to [KIE tutorial](./kie_en.md). PaddleOCR has modularized the code

 ### 4.1 Python Inference

-**Note:** Currently, the RE model inference process is still in the process of adaptation. We take SER model as an example to introduce the KIE process based on LayoutXLM model.
+- SER

 First, we need to export the trained model into an inference model. Take the LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)). Use the following command to export.

@@ -54,7 +54,7 @@ First, we need to export the trained model into an inference model. Take the LayoutXLM
 ``` bash
 wget https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar
 tar -xf ser_LayoutXLM_xfun_zh.tar
-python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh/best_accuracy Global.save_inference_dir=./inference/ser_layoutxlm
+python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/ser_layoutxlm_infer
 ```

 Use the following command to run inference with the LayoutXLM SER model.
@@ -77,6 +77,38 @@ The SER visualization results are saved in the `./output` directory by default.
 </div>


+- RE
+
+First, we need to export the trained model into an inference model. Take the LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)). Use the following command to export.
+
+
+``` bash
+wget https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar
+tar -xf re_LayoutXLM_xfun_zh.tar
+python3 tools/export_model.py -c configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/re_layoutxlm_infer
+```
+
+Use the following command to run inference with the LayoutXLM RE model.
+
+
+```bash
+cd ppstructure
+python3 kie/predict_kie_token_ser_re.py \
+  --kie_algorithm=LayoutXLM \
+  --re_model_dir=../inference/re_layoutxlm_infer \
+  --ser_model_dir=../inference/ser_layoutxlm_infer \
+  --image_dir=./docs/kie/input/zh_val_42.jpg \
+  --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
+  --vis_font_path=../doc/fonts/simfang.ttf
+```
+The RE visualization results are saved in the `./output` directory by default. The results are as follows.
+
+
+<div align="center">
+    <img src="../../ppstructure/docs/kie/result_re/zh_val_42_re.jpg" width="800">
+</div>
+
+
 ### 4.2 C++ Inference

 Not supported

@@ -22,7 +22,7 @@ On XFUND_zh dataset, the algorithm reproduction Hmean is as follows.
 |Model|Backbone|Task |Config|Hmean|Download link|
 | --- | --- |---| --- | --- | --- |
 |VI-LayoutXLM |VI-LayoutXLM-base | SER |[ser_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml)|93.19%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar)|
-|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[inference model (coming soon)]()|
+|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar)|


 Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
@@ -37,7 +37,7 @@ Please refer to [KIE tutorial](./kie_en.md). PaddleOCR has modularized the code

 ### 4.1 Python Inference

-**Note:** Currently, the RE model inference process is still in the process of adaptation. We take SER model as an example to introduce the KIE process based on VI-LayoutXLM model.
+- SER

 First, we need to export the trained model into an inference model. Take the VI-LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)). Use the following command to export.

@@ -70,6 +70,41 @@ The SER visualization results are saved in the `./output` folder by default. The
 </div>


+- RE
+
+First, we need to export the trained model into an inference model. Take the VI-LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)). Use the following command to export.
+
+
+``` bash
+wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar
+tar -xf re_vi_layoutxlm_xfund_pretrained.tar
+python3 tools/export_model.py -c configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_vi_layoutxlm_xfund_pretrained/best_accuracy Global.save_inference_dir=./inference/re_vi_layoutxlm_infer
+```
+
+Use the following command to run inference with the VI-LayoutXLM RE model.
+
+
+```bash
+cd ppstructure
+python3 kie/predict_kie_token_ser_re.py \
+  --kie_algorithm=LayoutXLM \
+  --re_model_dir=../inference/re_vi_layoutxlm_infer \
+  --ser_model_dir=../inference/ser_vi_layoutxlm_infer \
+  --use_visual_backbone=False \
+  --image_dir=./docs/kie/input/zh_val_42.jpg \
+  --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
+  --vis_font_path=../doc/fonts/simfang.ttf \
+  --ocr_order_method="tb-yx"
+```
+
+The RE visualization results are saved in the `./output` folder by default. The results are as follows.
+
+
+<div align="center">
+    <img src="../../ppstructure/docs/kie/result_re/zh_val_42_re.jpg" width="800">
+</div>
+
+
 ### 4.2 C++ Inference

 Not supported

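
The LayoutXLM and VI-LayoutXLM RE commands documented above differ only in two flags (`--use_visual_backbone=False` and `--ocr_order_method="tb-yx"`). A minimal sketch of a wrapper that assembles either command line is shown below. The function name `build_re_command` and its `vi_layoutxlm` switch are hypothetical and not part of PaddleOCR; every flag is copied from the documentation in this diff.

```python
def build_re_command(re_model_dir, ser_model_dir, image_dir, vi_layoutxlm=False):
    """Assemble the argv list for kie/predict_kie_token_ser_re.py.

    Hypothetical helper: it only mirrors the flags shown in the docs above.
    """
    cmd = [
        "python3", "kie/predict_kie_token_ser_re.py",
        "--kie_algorithm=LayoutXLM",
        f"--re_model_dir={re_model_dir}",
        f"--ser_model_dir={ser_model_dir}",
    ]
    if vi_layoutxlm:
        # The VI-LayoutXLM docs pass this flag; the LayoutXLM docs omit it.
        cmd.append("--use_visual_backbone=False")
    cmd += [
        f"--image_dir={image_dir}",
        "--ser_dict_path=../train_data/XFUND/class_list_xfun.txt",
        "--vis_font_path=../doc/fonts/simfang.ttf",
    ]
    if vi_layoutxlm:
        cmd.append("--ocr_order_method=tb-yx")
    return cmd


layoutxlm_cmd = build_re_command(
    "../inference/re_layoutxlm_infer",
    "../inference/ser_layoutxlm_infer",
    "./docs/kie/input/zh_val_42.jpg")
vi_cmd = build_re_command(
    "../inference/re_vi_layoutxlm_infer",
    "../inference/ser_vi_layoutxlm_infer",
    "./docs/kie/input/zh_val_42.jpg",
    vi_layoutxlm=True)
```

The resulting list could be passed to `subprocess.run(cmd, cwd="ppstructure", check=True)`, assuming the models have already been exported as described above.
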
@@ -12,11 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
+from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation, TensorizeEntitiesRelations

 __all__ = [
-    'VQATokenPad',
-    'VQASerTokenChunk',
-    'VQAReTokenChunk',
-    'VQAReTokenRelation',
+    'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation',
+    'TensorizeEntitiesRelations'
 ]

@@ -15,3 +15,4 @@
 from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk
 from .vqa_token_pad import VQATokenPad
 from .vqa_token_relation import VQAReTokenRelation
+from .vqa_re_convert import TensorizeEntitiesRelations

@@ -0,0 +1,51 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+
+class TensorizeEntitiesRelations(object):
+    def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
+        self.max_seq_len = max_seq_len
+        self.infer_mode = infer_mode
+
+    def __call__(self, data):
+        entities = data['entities']
+        relations = data['relations']
+
+        entities_new = np.full(
+            shape=[self.max_seq_len + 1, 3], fill_value=-1, dtype='int64')
+        entities_new[0, 0] = len(entities['start'])
+        entities_new[0, 1] = len(entities['end'])
+        entities_new[0, 2] = len(entities['label'])
+        entities_new[1:len(entities['start']) + 1, 0] = np.array(entities[
+            'start'])
+        entities_new[1:len(entities['end']) + 1, 1] = np.array(entities['end'])
+        entities_new[1:len(entities['label']) + 1, 2] = np.array(entities[
+            'label'])
+
+        relations_new = np.full(
+            shape=[self.max_seq_len * self.max_seq_len + 1, 2],
+            fill_value=-1,
+            dtype='int64')
+        relations_new[0, 0] = len(relations['head'])
+        relations_new[0, 1] = len(relations['tail'])
+        relations_new[1:len(relations['head']) + 1, 0] = np.array(relations[
+            'head'])
+        relations_new[1:len(relations['tail']) + 1, 1] = np.array(relations[
+            'tail'])
+
+        data['entities'] = entities_new
+        data['relations'] = relations_new
+        return data

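
`TensorizeEntitiesRelations` turns the variable-length entity and relation dicts into fixed-shape arrays: row 0 stores the length of each column, rows 1..n store the values, and every unused cell stays at -1. This is presumably what allows the configs in this commit to drop `collate_fn: ListCollator`, although the diff does not say so explicitly. A minimal sketch of the same layout on plain lists follows, so that it runs without NumPy. `pack_columns` and `unpack_columns` are hypothetical helper names, not PaddleOCR functions, and the sample entities are made up.

```python
def pack_columns(columns, max_rows, pad=-1):
    """Pack lists into a fixed (max_rows + 1) x len(columns) table.

    Row 0 holds the length of each column, rows 1..n hold the values,
    and unused cells keep the padding value.
    """
    table = [[pad] * len(columns) for _ in range(max_rows + 1)]
    for col, values in enumerate(columns):
        if len(values) > max_rows:
            raise ValueError("column longer than max_rows")
        table[0][col] = len(values)
        for row, value in enumerate(values, start=1):
            table[row][col] = value
    return table


def unpack_columns(table):
    """Invert pack_columns using the lengths stored in row 0."""
    return [[table[row][col] for row in range(1, table[0][col] + 1)]
            for col in range(len(table[0]))]


# Illustrative data only: two entities with token spans and label ids.
entities = {"start": [0, 5], "end": [5, 9], "label": [1, 2]}
packed = pack_columns(
    [entities["start"], entities["end"], entities["label"]], max_rows=4)
restored = unpack_columns(packed)
```

In the transform itself the entity table has `max_seq_len + 1` rows and the relation table has `max_seq_len * max_seq_len + 1` rows; the sketch uses `max_rows=4` only to keep the example small.
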
@@ -37,23 +37,25 @@ class VQAReTokenMetric(object):
         gt_relations = []
         for b in range(len(self.relations_list)):
             rel_sent = []
-            if "head" in self.relations_list[b]:
-                for head, tail in zip(self.relations_list[b]["head"],
-                                      self.relations_list[b]["tail"]):
+            relation_list = self.relations_list[b]
+            entitie_list = self.entities_list[b]
+            head_len = relation_list[0, 0]
+            if head_len > 0:
+                entitie_start_list = entitie_list[1:entitie_list[0, 0] + 1, 0]
+                entitie_end_list = entitie_list[1:entitie_list[0, 1] + 1, 1]
+                entitie_label_list = entitie_list[1:entitie_list[0, 2] + 1, 2]
+                for head, tail in zip(relation_list[1:head_len + 1, 0],
+                                      relation_list[1:head_len + 1, 1]):
                     rel = {}
                     rel["head_id"] = head
-                    rel["head"] = (
-                        self.entities_list[b]["start"][rel["head_id"]],
-                        self.entities_list[b]["end"][rel["head_id"]])
-                    rel["head_type"] = self.entities_list[b]["label"][rel[
-                        "head_id"]]
+                    rel["head"] = (entitie_start_list[head],
+                                   entitie_end_list[head])
+                    rel["head_type"] = entitie_label_list[head]

                     rel["tail_id"] = tail
-                    rel["tail"] = (
-                        self.entities_list[b]["start"][rel["tail_id"]],
-                        self.entities_list[b]["end"][rel["tail_id"]])
-                    rel["tail_type"] = self.entities_list[b]["label"][rel[
-                        "tail_id"]]
+                    rel["tail"] = (entitie_start_list[tail],
+                                   entitie_end_list[tail])
+                    rel["tail_type"] = entitie_label_list[tail]

                     rel["type"] = 1
                     rel_sent.append(rel)

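
The rewritten metric loop reads the ground truth back out of the length-prefixed arrays produced by `TensorizeEntitiesRelations`. The sketch below reproduces that decoding on nested lists (the code in the diff slices NumPy arrays instead). `decode_gt_relations` is a hypothetical name, the sample tables are made up, and the sketch assumes all three entity columns have the same length, whereas the diff reads each column length separately.

```python
def decode_gt_relations(relation_table, entity_table):
    """Rebuild relation dicts from length-prefixed tables.

    relation_table: rows of [head, tail]; row 0 is [n_heads, n_tails].
    entity_table:   rows of [start, end, label]; row 0 holds column lengths.
    """
    n_rel = relation_table[0][0]
    n_ent = entity_table[0][0]  # assumption: all entity columns share this length
    entities = entity_table[1:n_ent + 1]
    relations = []
    for head, tail in relation_table[1:n_rel + 1]:
        relations.append({
            "head_id": head,
            "head": (entities[head][0], entities[head][1]),
            "head_type": entities[head][2],
            "tail_id": tail,
            "tail": (entities[tail][0], entities[tail][1]),
            "tail_type": entities[tail][2],
            "type": 1,
        })
    return relations


# Illustrative data only: two entities and one relation linking 0 -> 1.
entity_table = [[2, 2, 2], [0, 5, 1], [5, 9, 2], [-1, -1, -1]]
relation_table = [[1, 1], [0, 1], [-1, -1]]
decoded = decode_gt_relations(relation_table, entity_table)
```

Padding rows never reach the loop because both slices stop at the lengths stored in row 0.
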
@@ -218,8 +218,12 @@ class LayoutXLMForRe(NLPBaseModel):
     def forward(self, x):
         if self.use_visual_backbone is True:
             image = x[4]
+            entities = x[5]
+            relations = x[6]
         else:
             image = None
+            entities = x[4]
+            relations = x[5]
         x = self.model(
             input_ids=x[0],
             bbox=x[1],
@@ -229,6 +233,6 @@ class LayoutXLMForRe(NLPBaseModel):
             position_ids=None,
             head_mask=None,
             labels=None,
-            entities=x[5],
-            relations=x[6])
+            entities=entities,
+            relations=relations)
         return x

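
The `forward` change accounts for the input list being one element shorter when the visual backbone is disabled: without an image slot, `entities` and `relations` move from positions 5 and 6 to positions 4 and 5. A small sketch of that index selection is given below. `split_re_inputs` is a hypothetical helper, and the strings `"slot2"` and `"slot3"` are placeholders for inputs this hunk does not show.

```python
def split_re_inputs(x, use_visual_backbone):
    """Return (image, entities, relations) from the positional input list."""
    if use_visual_backbone is True:
        return x[4], x[5], x[6]
    # Without the visual backbone there is no image slot, so the
    # remaining inputs shift down by one position.
    return None, x[4], x[5]


# Placeholder inputs; "slot2" and "slot3" stand for fields not shown in the hunk.
with_image = ["input_ids", "bbox", "slot2", "slot3",
              "image", "entities", "relations"]
without_image = with_image[:4] + with_image[5:]

visual = split_re_inputs(with_image, use_visual_backbone=True)
non_visual = split_re_inputs(without_image, use_visual_backbone=False)
```

Before this commit the model always read `x[5]` and `x[6]`, which would have pointed one slot too far for the shorter list.
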
@@ -21,18 +21,22 @@ class VQAReTokenLayoutLMPostProcess(object):
         super(VQAReTokenLayoutLMPostProcess, self).__init__()

     def __call__(self, preds, label=None, *args, **kwargs):
+        pred_relations = preds['pred_relations']
+        if isinstance(preds['pred_relations'], paddle.Tensor):
+            pred_relations = pred_relations.numpy()
+        pred_relations = self.decode_pred(pred_relations)
+
         if label is not None:
-            return self._metric(preds, label)
+            return self._metric(pred_relations, label)
         else:
-            return self._infer(preds, *args, **kwargs)
+            return self._infer(pred_relations, *args, **kwargs)

-    def _metric(self, preds, label):
-        return preds['pred_relations'], label[6], label[5]
+    def _metric(self, pred_relations, label):
+        return pred_relations, label[6], label[5]

-    def _infer(self, preds, *args, **kwargs):
+    def _infer(self, pred_relations, *args, **kwargs):
         ser_results = kwargs['ser_results']
         entity_idx_dict_batch = kwargs['entity_idx_dict_batch']
-        pred_relations = preds['pred_relations']

         # merge relations and ocr info
         results = []
@@ -50,6 +54,24 @@ class VQAReTokenLayoutLMPostProcess(object):
             results.append(result)
         return results

+    def decode_pred(self, pred_relations):
+        pred_relations_new = []
+        for pred_relation in pred_relations:
+            pred_relation_new = []
+            pred_relation = pred_relation[1:pred_relation[0, 0, 0] + 1]
+            for relation in pred_relation:
+                relation_new = dict()
+                relation_new['head_id'] = relation[0, 0]
+                relation_new['head'] = tuple(relation[1])
+                relation_new['head_type'] = relation[2, 0]
+                relation_new['tail_id'] = relation[3, 0]
+                relation_new['tail'] = tuple(relation[4])
+                relation_new['tail_type'] = relation[5, 0]
+                relation_new['type'] = relation[6, 0]
+                pred_relation_new.append(relation_new)
+            pred_relations_new.append(pred_relation_new)
+        return pred_relations_new
+

 class DistillationRePostProcess(VQAReTokenLayoutLMPostProcess):
     """

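
`decode_pred` implies a packed prediction layout: each sample is a stack of 7x2 blocks, block 0 carries the number of valid relations in cell `[0, 0]`, and each later block stores one relation (ids and types in the first column, spans across both columns). The sketch below mirrors that logic on nested lists so it runs without Paddle or NumPy. The -1 filler in unused cells and the sample values are assumptions made for illustration.

```python
def decode_pred(pred_relations):
    """Turn packed predictions into lists of relation dicts.

    Each sample is a list of 7x2 blocks. Block 0 only carries the number
    of valid relations in cell [0][0]; blocks 1..n are the relations.
    """
    decoded_batch = []
    for sample in pred_relations:
        count = sample[0][0][0]
        decoded = []
        for block in sample[1:count + 1]:
            decoded.append({
                "head_id": block[0][0],
                "head": tuple(block[1]),
                "head_type": block[2][0],
                "tail_id": block[3][0],
                "tail": tuple(block[4]),
                "tail_type": block[5][0],
                "type": block[6][0],
            })
        decoded_batch.append(decoded)
    return decoded_batch


# Illustrative batch with one sample holding one valid relation.
PAD = [[-1, -1]] * 7                       # assumed filler for unused blocks
header = [[1, -1]] + [[-1, -1]] * 6        # cell [0][0] = number of relations
relation = [[0, -1], [0, 5], [1, -1], [1, -1], [5, 9], [2, -1], [1, -1]]
result = decode_pred([[header, relation, PAD]])
```

Doing this decoding once in `__call__` means `_metric` and `_infer` both receive plain lists of dicts, whether the predictions came from a `paddle.Tensor` or from an inference engine output.
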
@@ -51,9 +51,9 @@
 |Model name|Description|Inference model size|Accuracy (hmean)|Inference time (ms)|Download|
 | --- | --- | --- |--- |--- | --- |
 |ser_VI-LayoutXLM_xfund_zh|SER model trained on the XFUND Chinese dataset with VI-LayoutXLM|1.1G| 93.19% | 15.49 | [inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar) |
-|re_VI-LayoutXLM_xfund_zh|RE model trained on the XFUND Chinese dataset with VI-LayoutXLM|1.1G| 83.92% | 15.49 |[inference model (coming soon)]() / [trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar) |
+|re_VI-LayoutXLM_xfund_zh|RE model trained on the XFUND Chinese dataset with VI-LayoutXLM|1.1G| 83.92% | 15.49 |[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar) |
 |ser_LayoutXLM_xfund_zh|SER model trained on the XFUND Chinese dataset with LayoutXLM|1.4G| 90.38% | 19.49 |[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
-|re_LayoutXLM_xfund_zh|RE model trained on the XFUND Chinese dataset with LayoutXLM|1.4G| 74.83% | 19.49 |[inference model (coming soon)]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
+|re_LayoutXLM_xfund_zh|RE model trained on the XFUND Chinese dataset with LayoutXLM|1.4G| 74.83% | 19.49 |[inference model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
 |ser_LayoutLMv2_xfund_zh|SER model trained on the XFUND Chinese dataset with LayoutLMv2|778M| 85.44% | 31.46 |[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
 |re_LayoutLMv2_xfund_zh|RE model trained on the XFUND Chinese dataset with LayoutLMv2|765M| 67.77% | 31.46 |[inference model (coming soon)]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
 |ser_LayoutLM_xfund_zh|SER model trained on the XFUND Chinese dataset with LayoutLM|430M| 77.31% | - |[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |

@@ -209,17 +209,18 @@ python3 ./tools/infer_kie_token_ser_re.py \

 #### 4.2.3 Inference using PaddleInference

-At present, only SER model supports inference using PaddleInference.
-
 First, download the SER and RE inference models.
-
+

 ```bash
 mkdir inference
 cd inference
 wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar
+wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar && tar -xf re_vi_layoutxlm_xfund_infer.tar
 cd ..
 ```

+- SER
+
 Use the following command for inference.


@@ -236,6 +237,26 @@ python3 kie/predict_kie_token_ser.py \

 The visual results and text file will be saved in directory `output`.

+- RE
+
+Use the following command for inference.
+
+
+```bash
+cd ppstructure
+python3 kie/predict_kie_token_ser_re.py \
+  --kie_algorithm=LayoutXLM \
+  --re_model_dir=../inference/re_vi_layoutxlm_xfund_infer \
+  --ser_model_dir=../inference/ser_vi_layoutxlm_xfund_infer \
+  --use_visual_backbone=False \
+  --image_dir=./docs/kie/input/zh_val_42.jpg \
+  --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
+  --vis_font_path=../doc/fonts/simfang.ttf \
+  --ocr_order_method="tb-yx"
+```
+
+The visual results and text file will be saved in directory `output`.
+

 ### 4.3 More


@@ -193,17 +193,18 @@ python3 ./tools/infer_kie_token_ser_re.py \

 #### 4.2.3 Inference with PaddleInference

-Currently, only the SER model supports PaddleInference inference.
-
-First, download the SER inference model.
-
+First, download the SER and RE inference models.
+
 ```bash
 mkdir inference
 cd inference
 wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar
+wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar && tar -xf re_vi_layoutxlm_xfund_infer.tar
 cd ..
 ```

+- SER
+
 Run the following command for inference.

 ```bash
@@ -219,6 +220,26 @@ python3 kie/predict_kie_token_ser.py \

 The visualization results are saved in the `output` directory.

+- RE
+
+Run the following command for inference.
+
+```bash
+cd ppstructure
+python3 kie/predict_kie_token_ser_re.py \
+  --kie_algorithm=LayoutXLM \
+  --re_model_dir=../inference/re_vi_layoutxlm_xfund_infer \
+  --ser_model_dir=../inference/ser_vi_layoutxlm_xfund_infer \
+  --use_visual_backbone=False \
+  --image_dir=./docs/kie/input/zh_val_42.jpg \
+  --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
+  --vis_font_path=../doc/fonts/simfang.ttf \
+  --ocr_order_method="tb-yx"
+```
+
+The visualization results are saved in the `output` directory.
+
+
 ### 4.3 More

 For training, evaluation, and inference of KIE models, see the [key information extraction tutorial](../../doc/doc_ch/kie.md).

@@ -102,16 +102,18 @@ class SerPredictor(object):
         ori_im = img.copy()
         data = {'image': img}
         data = transform(data, self.preprocess_op)
-        img = data[0]
-        if img is None:
+        if data[0] is None:
             return None, 0
-        img = np.expand_dims(img, axis=0)
-        img = img.copy()
         starttime = time.time()

+        for idx in range(len(data)):
+            if isinstance(data[idx], np.ndarray):
+                data[idx] = np.expand_dims(data[idx], axis=0)
+            else:
+                data[idx] = [data[idx]]
+
         for idx in range(len(self.input_tensor)):
-            expand_input = np.expand_dims(data[idx], axis=0)
-            self.input_tensor[idx].copy_from_cpu(expand_input)
+            self.input_tensor[idx].copy_from_cpu(data[idx])

         self.predictor.run()

@@ -122,9 +124,9 @@ class SerPredictor(object):
         preds = outputs[0]

         post_result = self.postprocess_op(
-            preds, segment_offset_ids=[data[6]], ocr_infos=[data[7]])
+            preds, segment_offset_ids=data[6], ocr_infos=data[7])
         elapse = time.time() - starttime
-        return post_result, elapse
+        return post_result, data, elapse


 def main(args):
@@ -145,7 +147,7 @@ def main(args):
         if img is None:
             logger.info("error in loading image:{}".format(image_file))
             continue
-        ser_res, elapse = ser_predictor(img)
+        ser_res, _, elapse = ser_predictor(img)
         ser_res = ser_res[0]

         res_str = '{}\t{}\n'.format(

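
The `SerPredictor` change batches every preprocessed field up front instead of expanding each tensor just before `copy_from_cpu`, so the batched `data` can be returned and reused by the RE stage. A minimal sketch of that batching step follows. `add_batch_axis` is a hypothetical helper; the array type and the expand function are passed in as parameters so that the example runs without NumPy, and in the diff they correspond to `np.ndarray` and `np.expand_dims(x, axis=0)`.

```python
def add_batch_axis(fields, array_type, expand):
    """Give every preprocessed field a leading batch axis of size 1.

    array_type and expand are injected so the sketch has no NumPy
    dependency; they stand in for np.ndarray and np.expand_dims.
    """
    batched = []
    for field in fields:
        if isinstance(field, array_type):
            batched.append(expand(field))
        else:
            # Non-array metadata (for example OCR info) is wrapped in a list.
            batched.append([field])
    return batched


# Stand-ins: tuples play the role of arrays in this example.
sample = [(1, 2, 3), {"text": "total"}]
batched = add_batch_axis(sample, tuple, lambda t: (t,))
```

One point that may be worth double-checking in review: the early exit in the hunk above still reads `return None, 0`, while callers now unpack three values (`ser_res, _, elapse = ser_predictor(img)`).
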
@@ -0,0 +1,124 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import json
+import numpy as np
+import time
+
+import tools.infer.utility as utility
+from tools.infer_kie_token_ser_re import make_input
+from ppocr.postprocess import build_post_process
+from ppocr.utils.logging import get_logger
+from ppocr.utils.visual import draw_re_results
+from ppocr.utils.utility import get_image_file_list, check_and_read
+from ppstructure.utility import parse_args
+from ppstructure.kie.predict_kie_token_ser import SerPredictor
+
+from paddleocr import PaddleOCR
+
+logger = get_logger()
+
+
+class SerRePredictor(object):
+    def __init__(self, args):
+        self.use_visual_backbone = args.use_visual_backbone
+        self.ser_engine = SerPredictor(args)
+
+        postprocess_params = {'name': 'VQAReTokenLayoutLMPostProcess'}
+        self.postprocess_op = build_post_process(postprocess_params)
+        self.predictor, self.input_tensor, self.output_tensors, self.config = \
+            utility.create_predictor(args, 're', logger)
+
+    def __call__(self, img):
+        ori_im = img.copy()
+        starttime = time.time()
+        ser_results, ser_inputs, _ = self.ser_engine(img)
+        re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
+        if self.use_visual_backbone == False:
+            re_input.pop(4)
+        for idx in range(len(self.input_tensor)):
+            self.input_tensor[idx].copy_from_cpu(re_input[idx])
+
+        self.predictor.run()
+        outputs = []
+        for output_tensor in self.output_tensors:
+            output = output_tensor.copy_to_cpu()
+            outputs.append(output)
+        preds = dict(loss=outputs[0], pred_relations=outputs[1])
+
+        post_result = self.postprocess_op(
+            preds,
+            ser_results=ser_results,
+            entity_idx_dict_batch=entity_idx_dict_batch)
+
+        elapse = time.time() - starttime
+        return post_result, elapse
+
+
+def main(args):
+    image_file_list = get_image_file_list(args.image_dir)
+    ser_predictor = SerRePredictor(args)
+    count = 0
+    total_time = 0
+
+    os.makedirs(args.output, exist_ok=True)
+    with open(
+            os.path.join(args.output, 'infer.txt'), mode='w',
+            encoding='utf-8') as f_w:
+        for image_file in image_file_list:
+            img, flag, _ = check_and_read(image_file)
+            if not flag:
+                img = cv2.imread(image_file)
+                img = img[:, :, ::-1]
+            if img is None:
+                logger.info("error in loading image:{}".format(image_file))
+                continue
+            re_res, elapse = ser_predictor(img)
+            re_res = re_res[0]
+
+            res_str = '{}\t{}\n'.format(
+                image_file,
+                json.dumps(
+                    {
+                        "ocr_info": re_res,
+                    }, ensure_ascii=False))
+            f_w.write(res_str)
+
+            img_res = draw_re_results(
+                image_file, re_res, font_path=args.vis_font_path)
+
+            img_save_path = os.path.join(
+                args.output,
+                os.path.splitext(os.path.basename(image_file))[0] +
+                "_ser_re.jpg")
+
+            cv2.imwrite(img_save_path, img_res)
+            logger.info("save vis result to {}".format(img_save_path))
+            if count > 0:
+                total_time += elapse
+            count += 1
+            logger.info("Predict time of {}: {}".format(image_file, elapse))
+
+
+if __name__ == "__main__":
+    main(parse_args())
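The input handling in `SerRePredictor.__call__` can be sketched without Paddle. This is a minimal illustration, assuming (as the `# image` comment and `input_spec.pop(4)` in the export script suggest) that slot 4 of the RE input list is the image tensor. The other slot names and the `drop_image_slot` helper are hypothetical and do not come from the repository.

```python
# Hypothetical slot names; only the position of "image" (index 4) and the
# trailing entities/relations entries are taken from the diff.
RE_INPUT_SLOTS = [
    "input_ids", "bbox", "attention_mask", "token_type_ids",
    "image", "entities", "relations",
]


def drop_image_slot(re_input, use_visual_backbone):
    """Return the list fed to the predictor, without the image when unused."""
    re_input = list(re_input)  # copy so the caller's list is left untouched
    if not use_visual_backbone:
        re_input.pop(4)  # same index the predictor and export script pop
    return re_input


with_image = drop_image_slot(RE_INPUT_SLOTS, use_visual_backbone=True)
without_image = drop_image_slot(RE_INPUT_SLOTS, use_visual_backbone=False)
print(len(with_image), len(without_image))
print(without_image)
```

Because the predictor copies `re_input[idx]` into each input tensor by position, the list length has to match the number of inputs the model was exported with. That is presumably why the same `pop(4)` appears in both the export script and the predictor.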
@@ -52,6 +52,8 @@ def init_args():
     # params for kie
     parser.add_argument("--kie_algorithm", type=str, default='LayoutXLM')
     parser.add_argument("--ser_model_dir", type=str)
+    parser.add_argument("--re_model_dir", type=str)
+    parser.add_argument("--use_visual_backbone", type=str2bool, default=True)
     parser.add_argument(
         "--ser_dict_path",
         type=str,
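`--use_visual_backbone` is parsed with `str2bool` rather than `bool`, presumably because `bool("False")` is `True` in Python. A minimal sketch of that pattern follows. The `str2bool` body here is an assumption about what the repository's helper does, not a copy of it.

```python
import argparse


def str2bool(value):
    # Assumed behaviour: a few common spellings mean True, anything else False.
    return value.lower() in ("true", "t", "1")


parser = argparse.ArgumentParser()
parser.add_argument("--re_model_dir", type=str)
parser.add_argument("--use_visual_backbone", type=str2bool, default=True)

# Simulate: --re_model_dir ./re_model --use_visual_backbone False
args = parser.parse_args(
    ["--re_model_dir", "./re_model", "--use_visual_backbone", "False"])
print(args.re_model_dir, args.use_visual_backbone)
```

With `type=bool` the same command line would yield `True`, since any non-empty string is truthy.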
@@ -115,16 +115,12 @@ def export_single_model(model,
         max_text_length = arch_config["Head"]["max_text_length"]
         other_shape = [
             paddle.static.InputSpec(
-                shape=[None, 3, 48, 160], dtype="float32"),
-
-            [
-                paddle.static.InputSpec(
-                    shape=[None, ],
-                    dtype="float32"),
-                paddle.static.InputSpec(
-                    shape=[None, max_text_length],
-                    dtype="int64")
-            ]
+                shape=[None, 3, 48, 160], dtype="float32"), [
+                    paddle.static.InputSpec(
+                        shape=[None, ], dtype="float32"),
+                    paddle.static.InputSpec(
+                        shape=[None, max_text_length], dtype="int64")
+                ]
         ]
         model = to_static(model, input_spec=other_shape)
     elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
@@ -140,6 +136,13 @@ def export_single_model(model,
             paddle.static.InputSpec(
                 shape=[None, 3, 224, 224], dtype="int64"),  # image
         ]
+        if 'Re' in arch_config['Backbone']['name']:
+            input_spec.extend([
+                paddle.static.InputSpec(
+                    shape=[None, 512, 3], dtype="int64"),  # entities
+                paddle.static.InputSpec(
+                    shape=[None, None, 2], dtype="int64"),  # relations
+            ])
         if model.backbone.use_visual_backbone is False:
             input_spec.pop(4)
         model = to_static(model, input_spec=[input_spec])
@@ -162,6 +162,8 @@ def create_predictor(args, mode, logger):
         model_dir = args.table_model_dir
     elif mode == 'ser':
         model_dir = args.ser_model_dir
+    elif mode == 're':
+        model_dir = args.re_model_dir
     elif mode == "sr":
         model_dir = args.sr_model_dir
     elif mode == 'layout':
@@ -227,7 +229,8 @@ def create_predictor(args, mode, logger):
                     use_calib_mode=False)
 
                 # collect shape
-                trt_shape_f = os.path.join(model_dir, f"{mode}_trt_dynamic_shape.txt")
+                trt_shape_f = os.path.join(model_dir,
+                                           f"{mode}_trt_dynamic_shape.txt")
 
                 if not os.path.exists(trt_shape_f):
                     config.collect_shape_range_info(trt_shape_f)
@@ -262,6 +265,8 @@ def create_predictor(args, mode, logger):
         config.disable_glog_info()
         config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
         config.delete_pass("matmul_transpose_reshape_fuse_pass")
+        if mode == 're':
+            config.delete_pass("simplify_with_basic_ops_pass")
         if mode == 'table':
             config.delete_pass("fc_fuse_pass")  # not supported for table
         config.switch_use_feed_fetch_ops(False)
@@ -63,7 +63,7 @@ class ReArgsParser(ArgsParser):
 
 def make_input(ser_inputs, ser_results):
     entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
-
+    batch_size, max_seq_len = ser_inputs[0].shape[:2]
     entities = ser_inputs[8][0]
     ser_results = ser_results[0]
     assert len(entities) == len(ser_results)
@@ -80,34 +80,44 @@ def make_input(ser_inputs, ser_results):
         start.append(entity['start'])
         end.append(entity['end'])
         label.append(entities_labels[res['pred']])
-    entities = dict(start=start, end=end, label=label)
 
+    entities = np.full([max_seq_len + 1, 3], fill_value=-1)
+    entities[0, 0] = len(start)
+    entities[1:len(start) + 1, 0] = start
+    entities[0, 1] = len(end)
+    entities[1:len(end) + 1, 1] = end
+    entities[0, 2] = len(label)
+    entities[1:len(label) + 1, 2] = label
+
     # relations
     head = []
     tail = []
-    for i in range(len(entities["label"])):
-        for j in range(len(entities["label"])):
-            if entities["label"][i] == 1 and entities["label"][j] == 2:
+    for i in range(len(label)):
+        for j in range(len(label)):
+            if label[i] == 1 and label[j] == 2:
                 head.append(i)
                 tail.append(j)
 
-    relations = dict(head=head, tail=tail)
+    relations = np.full([len(head) + 1, 2], fill_value=-1)
+    relations[0, 0] = len(head)
+    relations[1:len(head) + 1, 0] = head
+    relations[0, 1] = len(tail)
+    relations[1:len(tail) + 1, 1] = tail
+
+    entities = np.expand_dims(entities, axis=0)
+    entities = np.repeat(entities, batch_size, axis=0)
+    relations = np.expand_dims(relations, axis=0)
+    relations = np.repeat(relations, batch_size, axis=0)
+
+    # remove ocr_info segment_offset_id and label in ser input
+    if isinstance(ser_inputs[0], paddle.Tensor):
+        entities = paddle.to_tensor(entities)
+        relations = paddle.to_tensor(relations)
+    ser_inputs = ser_inputs[:5] + [entities, relations]
 
-    batch_size = ser_inputs[0].shape[0]
-    entities_batch = []
-    relations_batch = []
     entity_idx_dict_batch = []
     for b in range(batch_size):
-        entities_batch.append(entities)
-        relations_batch.append(relations)
         entity_idx_dict_batch.append(entity_idx_dict)
-
-    ser_inputs[8] = entities_batch
-    ser_inputs.append(relations_batch)
-    # remove ocr_info segment_offset_id and label in ser input
-    ser_inputs.pop(7)
-    ser_inputs.pop(6)
-    ser_inputs.pop(5)
     return ser_inputs, entity_idx_dict_batch
 
 
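The new `make_input` replaces the `entities` and `relations` dicts with fixed-layout integer arrays: row 0 holds the number of valid entries, the following rows hold the values, and unused rows stay at -1. The sketch below reproduces that layout with plain Python lists in place of `np.full`, so it runs without NumPy. `pack_columns`, `unpack_columns`, and `question_answer_pairs` are illustrative names, not functions from the repository.

```python
def pack_columns(columns, num_rows, fill_value=-1):
    """Pack several lists into a num_rows x len(columns) table.

    Row 0 stores each column's length, rows 1..n store its values,
    and the remaining rows keep fill_value.
    """
    table = [[fill_value] * len(columns) for _ in range(num_rows)]
    for col, values in enumerate(columns):
        if len(values) + 1 > num_rows:
            raise ValueError("column does not fit in the table")
        table[0][col] = len(values)
        for row, value in enumerate(values, start=1):
            table[row][col] = value
    return table


def unpack_columns(table):
    """Inverse of pack_columns: read the count in row 0, then that many rows."""
    num_cols = len(table[0])
    return [[table[row][col] for row in range(1, table[0][col] + 1)]
            for col in range(num_cols)]


def question_answer_pairs(labels, question=1, answer=2):
    """Every (question index, answer index) pair, as in the nested loop."""
    head, tail = [], []
    for i, label_i in enumerate(labels):
        for j, label_j in enumerate(labels):
            if label_i == question and label_j == answer:
                head.append(i)
                tail.append(j)
    return head, tail


# Three entities: one QUESTION (1) followed by two ANSWERs (2).
start, end, label = [0, 5, 9], [4, 8, 12], [1, 2, 2]
entities = pack_columns([start, end, label], num_rows=6)
head, tail = question_answer_pairs(label)
relations = pack_columns([head, tail], num_rows=len(head) + 1)
print(entities[0], relations)
```

A consumer can recover the original lists by reading the count from row 0, as `unpack_columns` does. Plain integer arrays of this kind can be described by an `InputSpec` and copied into a predictor input tensor, which a list of dicts cannot.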
@@ -136,6 +146,8 @@ class SerRePredictor(object):
     def __call__(self, data):
         ser_results, ser_inputs = self.ser_engine(data)
         re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
+        if self.model.backbone.use_visual_backbone is False:
+            re_input.pop(4)
         preds = self.model(re_input)
         post_result = self.post_process_class(
             preds,