Merge branch 'dygraph' into robustscanner_branch

pull/6842/head
smilelite 2022-08-17 21:49:11 +08:00 committed by GitHub
commit 40c45e2ccc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
97 changed files with 4900 additions and 1012 deletions

2
.gitignore vendored
View File

@ -31,4 +31,4 @@ paddleocr.egg-info/
/deploy/android_demo/app/.cxx/
/deploy/android_demo/app/cache/
test_tipc/web/models/
test_tipc/web/node_modules/
test_tipc/web/node_modules/

View File

@ -0,0 +1,143 @@
Global:
use_gpu: true
epoch_num: 100
log_smooth_window: 20
print_batch_step: 20
save_model_dir: ./output/SLANet
save_epoch_step: 400
# evaluation is run every 1000 iterations after the 0th iteration
eval_batch_step: [0, 1000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir: ./output/SLANet/infer
use_visualdl: False
infer_img: doc/table/table.jpg
# for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en
max_text_length: &max_text_length 500
box_format: &box_format 'xyxy' # 'xywh', 'xyxy', 'xyxyxyxy'
infer_mode: False
use_sync_bn: True
save_res_path: 'output/infer'
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
clip_norm: 5.0
lr:
name: Piecewise
learning_rate: 0.001
decay_epochs : [40, 50]
values : [0.001, 0.0001, 0.00005]
regularizer:
name: 'L2'
factor: 0.00000
Architecture:
model_type: table
algorithm: SLANet
Backbone:
name: PPLCNet
scale: 1.0
pretrained: true
use_ssld: true
Neck:
name: CSPPAN
out_channels: 96
Head:
name: SLAHead
hidden_size: 256
max_text_length: *max_text_length
loc_reg_num: &loc_reg_num 4
Loss:
name: SLALoss
structure_weight: 1.0
loc_weight: 2.0
loc_loss: smooth_l1
PostProcess:
name: TableLabelDecode
merge_no_span_structure: &merge_no_span_structure True
Metric:
name: TableMetric
main_indicator: acc
compute_bbox_metric: False
loc_reg_num: *loc_reg_num
box_format: *box_format
Train:
dataset:
name: PubTabDataSet
data_dir: train_data/table/pubtabnet/train/
label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- TableLabelEncode:
learn_empty_box: False
merge_no_span_structure: *merge_no_span_structure
replace_empty_cell_token: False
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
- TableBoxEncode:
in_box_format: *box_format
out_box_format: *box_format
- ResizeTableImage:
max_len: 488
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- PaddingTableImage:
size: [488, 488]
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
loader:
shuffle: True
batch_size_per_card: 48
drop_last: True
num_workers: 1
Eval:
dataset:
name: PubTabDataSet
data_dir: train_data/table/pubtabnet/val/
label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- TableLabelEncode:
learn_empty_box: False
merge_no_span_structure: *merge_no_span_structure
replace_empty_cell_token: False
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
- TableBoxEncode:
in_box_format: *box_format
out_box_format: *box_format
- ResizeTableImage:
max_len: 488
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- PaddingTableImage:
size: [488, 488]
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
loader:
shuffle: False
drop_last: False
batch_size_per_card: 48
num_workers: 1

View File

@ -8,16 +8,15 @@ Global:
eval_batch_step: [0, 6259]
cal_metric_during_train: true
pretrained_model: null
checkpoints:
checkpoints:
save_inference_dir: output/table_master/infer
use_visualdl: false
infer_img: ppstructure/docs/table/table.jpg
save_res_path: ./output/table_master
character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
infer_mode: false
max_text_length: 500
process_total_num: 0
process_cut_num: 0
max_text_length: &max_text_length 500
box_format: &box_format 'xywh' # 'xywh', 'xyxy', 'xyxyxyxy'
Optimizer:
@ -52,7 +51,8 @@ Architecture:
headers: 8
dropout: 0
d_ff: 2024
max_text_length: 500
max_text_length: *max_text_length
loc_reg_num: &loc_reg_num 4
Loss:
name: TableMasterLoss
@ -61,11 +61,13 @@ Loss:
PostProcess:
name: TableMasterLabelDecode
box_shape: pad
merge_no_span_structure: &merge_no_span_structure True
Metric:
name: TableMetric
main_indicator: acc
compute_bbox_metric: False
box_format: *box_format
Train:
dataset:
@ -78,15 +80,18 @@ Train:
channel_first: False
- TableMasterLabelEncode:
learn_empty_box: False
merge_no_span_structure: True
merge_no_span_structure: *merge_no_span_structure
replace_empty_cell_token: True
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
- ResizeTableImage:
max_len: 480
resize_bboxes: True
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: True
in_box_format: *box_format
out_box_format: *box_format
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
@ -112,15 +117,18 @@ Eval:
channel_first: False
- TableMasterLabelEncode:
learn_empty_box: False
merge_no_span_structure: True
merge_no_span_structure: *merge_no_span_structure
replace_empty_cell_token: True
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
- ResizeTableImage:
max_len: 480
resize_bboxes: True
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: True
in_box_format: *box_format
out_box_format: *box_format
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]

View File

@ -17,10 +17,9 @@ Global:
# for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en
max_text_length: 800
max_text_length: &max_text_length 800
box_format: &box_format 'xyxy' # 'xywh', 'xyxy', 'xyxyxyxy'
infer_mode: False
process_total_num: 0
process_cut_num: 0
Optimizer:
name: Adam
@ -44,7 +43,8 @@ Architecture:
name: TableAttentionHead
hidden_size: 256
loc_type: 2
max_text_length: 800
max_text_length: *max_text_length
loc_reg_num: &loc_reg_num 4
Loss:
name: TableAttentionLoss
@ -72,6 +72,8 @@ Train:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
- TableBoxEncode:
- ResizeTableImage:
max_len: 488
@ -94,8 +96,8 @@ Train:
Eval:
dataset:
name: PubTabDataSet
data_dir: /home/zhoujun20/table/PubTabNe/pubtabnet/val/
label_file_list: [/home/zhoujun20/table/PubTabNe/pubtabnet/val_500.jsonl]
data_dir: train_data/table/pubtabnet/val/
label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl]
transforms:
- DecodeImage: # load image
img_mode: BGR
@ -104,6 +106,8 @@ Eval:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
- TableBoxEncode:
- ResizeTableImage:
max_len: 488

View File

@ -30,7 +30,8 @@ DECLARE_string(image_dir);
DECLARE_string(type);
// detection related
DECLARE_string(det_model_dir);
DECLARE_int32(max_side_len);
DECLARE_string(limit_type);
DECLARE_int32(limit_side_len);
DECLARE_double(det_db_thresh);
DECLARE_double(det_db_box_thresh);
DECLARE_double(det_db_unclip_ratio);
@ -48,7 +49,13 @@ DECLARE_int32(rec_batch_num);
DECLARE_string(rec_char_dict_path);
DECLARE_int32(rec_img_h);
DECLARE_int32(rec_img_w);
// structure model related
DECLARE_string(table_model_dir);
DECLARE_int32(table_max_len);
DECLARE_int32(table_batch_num);
DECLARE_string(table_char_dict_path);
// forward related
DECLARE_bool(det);
DECLARE_bool(rec);
DECLARE_bool(cls);
DECLARE_bool(table);

View File

@ -41,8 +41,8 @@ public:
explicit DBDetector(const std::string &model_dir, const bool &use_gpu,
const int &gpu_id, const int &gpu_mem,
const int &cpu_math_library_num_threads,
const bool &use_mkldnn, const int &max_side_len,
const double &det_db_thresh,
const bool &use_mkldnn, const string &limit_type,
const int &limit_side_len, const double &det_db_thresh,
const double &det_db_box_thresh,
const double &det_db_unclip_ratio,
const std::string &det_db_score_mode,
@ -54,7 +54,8 @@ public:
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
this->use_mkldnn_ = use_mkldnn;
this->max_side_len_ = max_side_len;
this->limit_type_ = limit_type;
this->limit_side_len_ = limit_side_len;
this->det_db_thresh_ = det_db_thresh;
this->det_db_box_thresh_ = det_db_box_thresh;
@ -84,7 +85,8 @@ private:
int cpu_math_library_num_threads_ = 4;
bool use_mkldnn_ = false;
int max_side_len_ = 960;
string limit_type_ = "max";
int limit_side_len_ = 960;
double det_db_thresh_ = 0.3;
double det_db_box_thresh_ = 0.5;
@ -106,7 +108,7 @@ private:
Permute permute_op_;
// post-process
PostProcessor post_processor_;
DBPostProcessor post_processor_;
};
} // namespace PaddleOCR

View File

@ -47,11 +47,7 @@ public:
ocr(std::vector<cv::String> cv_all_img_names, bool det = true,
bool rec = true, bool cls = true);
private:
DBDetector *detector_ = nullptr;
Classifier *classifier_ = nullptr;
CRNNRecognizer *recognizer_ = nullptr;
protected:
void det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times);
void rec(std::vector<cv::Mat> img_list,
@ -62,6 +58,11 @@ private:
std::vector<double> &times);
void log(std::vector<double> &det_times, std::vector<double> &rec_times,
std::vector<double> &cls_times, int img_num);
private:
DBDetector *detector_ = nullptr;
Classifier *classifier_ = nullptr;
CRNNRecognizer *recognizer_ = nullptr;
};
} // namespace PaddleOCR

View File

@ -0,0 +1,79 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include <include/paddleocr.h>
#include <include/preprocess_op.h>
#include <include/structure_table.h>
#include <include/utility.h>
using namespace paddle_infer;
namespace PaddleOCR {
// Document-structure pipeline layered on top of the PPOCR text pipeline.
// It owns a StructureTableRecognizer and exposes a single public entry point,
// structure(), which yields one list of StructurePredictResult per image.
class PaddleStructure : public PPOCR {
public:
  explicit PaddleStructure();
  ~PaddleStructure();

  // Processes every image path in `cv_all_img_names` and returns the
  // per-image structure results, in the same order as the input paths.
  std::vector<std::vector<StructurePredictResult>>
  structure(std::vector<cv::String> cv_all_img_names, bool layout = false,
            bool table = true);

private:
  // Table-structure model wrapper; stays nullptr until the constructor
  // creates it.
  StructureTableRecognizer *recognizer_ = nullptr;

  // Runs table recognition on `img` and writes the outcome into
  // `structure_result`; the four vectors receive timing information.
  void table(cv::Mat img, StructurePredictResult &structure_result,
             std::vector<double> &time_info_table,
             std::vector<double> &time_info_det,
             std::vector<double> &time_info_rec,
             std::vector<double> &time_info_cls);

  // Builds a string from the predicted HTML tags, the predicted boxes and the
  // OCR results (presumably the final table HTML -- implementation not
  // visible in this header).
  std::string
  rebuild_table(std::vector<std::string> rec_html_tags,
                std::vector<std::vector<std::vector<int>>> rec_boxes,
                std::vector<OCRPredictResult> &ocr_result);

  // Box-to-box measures used by the table code; definitions live in the
  // implementation file.
  float iou(std::vector<std::vector<int>> &box1,
            std::vector<std::vector<int>> &box2);
  float dis(std::vector<std::vector<int>> &box1,
            std::vector<std::vector<int>> &box2);

  // Ordering predicate: compares element [1] first and falls back to
  // element [0] when the [1] values are equal.
  static bool comparison_dis(const std::vector<float> &dis1,
                             const std::vector<float> &dis2) {
    const bool key_is_smaller = dis1[1] < dis2[1];
    const bool key_is_equal = dis1[1] == dis2[1];
    return key_is_smaller || (key_is_equal && dis1[0] < dis2[0]);
  }
};
} // namespace PaddleOCR

View File

@ -34,7 +34,7 @@ using namespace std;
namespace PaddleOCR {
class PostProcessor {
class DBPostProcessor {
public:
void GetContourArea(const std::vector<std::vector<float>> &box,
float unclip_ratio, float &distance);
@ -90,4 +90,21 @@ private:
}
};
// Post-processing step for the table-structure model.
// Only the interface is declared here; see the implementation file for the
// decoding logic.
class TablePostProcessor {
public:
// Initialises the post-processor from the file at `label_path`.
// NOTE(review): presumably this fills label_list_ from the dictionary file --
// confirm against the implementation.
void init(std::string label_path);
// Decodes one batch of raw model outputs.
// Inputs: flat prediction buffers (`loc_preds`, `structure_probs`) with their
// shapes, plus the per-image `width_list` / `height_list`.
// Outputs (by reference): `rec_scores`, `rec_html_tag_batch` and
// `rec_boxes_batch`.
// NOTE(review): which arguments are inputs vs. outputs is inferred from the
// names and non-const references -- confirm against the implementation.
void
Run(std::vector<float> &loc_preds, std::vector<float> &structure_probs,
std::vector<float> &rec_scores, std::vector<int> &loc_preds_shape,
std::vector<int> &structure_probs_shape,
std::vector<std::vector<std::string>> &rec_html_tag_batch,
std::vector<std::vector<std::vector<std::vector<int>>>> &rec_boxes_batch,
std::vector<int> &width_list, std::vector<int> &height_list);
private:
// Label strings used when decoding.
std::vector<std::string> label_list_;
// Special tokens; presumably the end-/start-of-sequence markers -- confirm.
std::string end = "eos";
std::string beg = "sos";
};
} // namespace PaddleOCR

View File

@ -48,11 +48,12 @@ class PermuteBatch {
public:
virtual void Run(const std::vector<cv::Mat> imgs, float *data);
};
class ResizeImgType0 {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len,
float &ratio_h, float &ratio_w, bool use_tensorrt);
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, string limit_type,
int limit_side_len, float &ratio_h, float &ratio_w,
bool use_tensorrt);
};
class CrnnResizeImg {
@ -69,4 +70,16 @@ public:
const std::vector<int> &rec_image_shape = {3, 48, 192});
};
// Resize pre-processing step for the table-structure model.
// Reads `img` and writes the result to `resize_img`; `max_len` defaults to 488.
// NOTE(review): presumably scales the image so its longer side equals
// `max_len` -- the implementation is not visible here, confirm.
class TableResizeImg {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img,
const int max_len = 488);
};
// Padding pre-processing step for the table-structure model.
// Reads `img` and writes the result to `resize_img`; `max_len` defaults to 488.
// NOTE(review): presumably pads the image up to max_len x max_len -- the
// implementation is not visible here, confirm.
class TablePadImg {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img,
const int max_len = 488);
};
} // namespace PaddleOCR

View File

@ -0,0 +1,100 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include <include/postprocess_op.h>
#include <include/preprocess_op.h>
#include <include/utility.h>
using namespace paddle_infer;
namespace PaddleOCR {
// Wrapper around the table-structure inference model.
// The constructor stores the runtime configuration, initialises the
// post-processor from the label file and loads the model; Run() then performs
// recognition on a list of images.
class StructureTableRecognizer {
public:
  // model_dir:   directory handed to LoadModel().
  // use_gpu / gpu_id / gpu_mem: GPU execution settings.
  // cpu_math_library_num_threads / use_mkldnn: CPU execution settings.
  // label_path:  file handed to TablePostProcessor::init().
  // use_tensorrt / precision: TensorRT switch and precision string.
  // table_batch_num / table_max_len: batch size and image length setting
  //   stored for use by Run().
  explicit StructureTableRecognizer(
      const std::string &model_dir, const bool &use_gpu, const int &gpu_id,
      const int &gpu_mem, const int &cpu_math_library_num_threads,
      const bool &use_mkldnn, const std::string &label_path,
      const bool &use_tensorrt, const std::string &precision,
      const int &table_batch_num, const int &table_max_len) {
    this->use_gpu_ = use_gpu;
    this->gpu_id_ = gpu_id;
    this->gpu_mem_ = gpu_mem;
    this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
    this->use_mkldnn_ = use_mkldnn;
    this->use_tensorrt_ = use_tensorrt;
    this->precision_ = precision;
    this->table_batch_num_ = table_batch_num;
    this->table_max_len_ = table_max_len;

    // The post-processor must be ready before any prediction is decoded.
    this->post_processor_.init(label_path);
    // Loading happens last so it sees the fully populated configuration.
    LoadModel(model_dir);
  }

  // Load Paddle inference model
  void LoadModel(const std::string &model_dir);

  // Runs recognition on `img_list`. Results are returned through
  // `rec_html_tags`, `rec_scores` and `rec_boxes`; `times` receives timing
  // information.
  void Run(std::vector<cv::Mat> img_list,
           std::vector<std::vector<std::string>> &rec_html_tags,
           std::vector<float> &rec_scores,
           std::vector<std::vector<std::vector<std::vector<int>>>> &rec_boxes,
           std::vector<double> &times);

private:
  std::shared_ptr<Predictor> predictor_;

  // Runtime configuration (defaults are overwritten by the constructor).
  bool use_gpu_ = false;
  int gpu_id_ = 0;
  int gpu_mem_ = 4000;
  int cpu_math_library_num_threads_ = 4;
  bool use_mkldnn_ = false;
  int table_max_len_ = 488;

  // Normalisation constants: per-channel mean and reciprocal std.
  std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
  std::vector<float> scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
  bool is_scale_ = true;

  bool use_tensorrt_ = false;
  std::string precision_ = "fp32";
  int table_batch_num_ = 1;

  // pre-process
  TableResizeImg resize_op_;
  Normalize normalize_op_;
  PermuteBatch permute_op_;
  TablePadImg pad_op_;

  // post-process
  TablePostProcessor post_processor_;

}; // class StructureTableRecognizer
} // namespace PaddleOCR

View File

@ -40,6 +40,14 @@ struct OCRPredictResult {
int cls_label = -1;
};
// One detected region of a document together with its recognition output.
struct StructurePredictResult {
// Region coordinates as integers.
// NOTE(review): callers in this change treat it as four values
// (x1, y1, x2, y2) -- confirm this holds for all producers.
std::vector<int> box;
// Region category, e.g. "table".
std::string type;
// OCR results attached to this region.
std::vector<OCRPredictResult> text_res;
// HTML string produced for table regions.
std::string html;
// Score associated with `html`; -1 until it is set.
float html_score = -1;
};
class Utility {
public:
static std::vector<std::string> ReadDict(const std::string &path);
@ -68,6 +76,22 @@ public:
static void CreateDir(const std::string &path);
static void print_result(const std::vector<OCRPredictResult> &ocr_result);
static cv::Mat crop_image(cv::Mat &img, std::vector<int> &area);
static void sorted_boxes(std::vector<OCRPredictResult> &ocr_result);
private:
// Ordering predicate for OCR results based on the first point of each box:
// the point's [1] component is compared first, and its [0] component breaks
// ties.
static bool comparison_box(const OCRPredictResult &result1,
                           const OCRPredictResult &result2) {
  const auto &first_pt_a = result1.box[0];
  const auto &first_pt_b = result2.box[0];
  if (first_pt_a[1] != first_pt_b[1]) {
    return first_pt_a[1] < first_pt_b[1];
  }
  return first_pt_a[0] < first_pt_b[0];
}
};
} // namespace PaddleOCR

View File

@ -171,6 +171,9 @@ inference/
|-- cls
| |--inference.pdiparams
| |--inference.pdmodel
|-- table
| |--inference.pdiparams
| |--inference.pdmodel
```
@ -275,6 +278,17 @@ Specifically,
--cls=true \
```
##### 7. table
```shell
./build/ppocr --det_model_dir=inference/det_db \
--rec_model_dir=inference/rec_rcnn \
--table_model_dir=inference/table \
--image_dir=../../ppstructure/docs/table/table.jpg \
--type=structure \
--table=true
```
More parameters are as follows,
- Common parameters
@ -293,9 +307,9 @@ More parameters are as follows,
|parameter|data type|default|meaning|
| :---: | :---: | :---: | :---: |
|det|bool|true|前向是否执行文字检测|
|rec|bool|true|前向是否执行文字识别|
|cls|bool|false|前向是否执行文字方向分类|
|det|bool|true|Whether to run text detection during inference|
|rec|bool|true|Whether to run text recognition during inference|
|cls|bool|false|Whether to run text direction classification during inference|
- Detection related parameters
@ -329,6 +343,15 @@ More parameters are as follows,
|rec_img_h|int|48|image height of recognition|
|rec_img_w|int|320|image width of recognition|
- Table recognition related parameters
|parameter|data type|default|meaning|
| :---: | :---: | :---: | :---: |
|table_model_dir|string|-|Address of table recognition inference model|
|table_char_dict_path|string|../../ppocr/utils/dict/table_structure_dict.txt|dictionary file|
|table_max_len|int|488|The size of the long side of the input image of the table recognition model; the final input image size of the network is (table_max_len, table_max_len)|
* Multi-language inference is also supported in PaddleOCR, you can refer to [recognition tutorial](../../doc/doc_en/recognition_en.md) for more supported languages and models in PaddleOCR. Specifically, if you want to infer using multi-language models, you just need to modify values of `rec_char_dict_path` and `rec_model_dir`.
@ -344,6 +367,12 @@ predict img: ../../doc/imgs/12.jpg
The detection visualized image saved in ./output//12.jpg
```
- table
```bash
predict img: ../../ppstructure/docs/table/table.jpg
0 type: table, region: [0,0,371,293], res: <html><body><table><thead><tr><td>Methods</td><td>R</td><td>P</td><td>F</td><td>FPS</td></tr></thead><tbody><tr><td>SegLink [26]</td><td>70.0</td><td>86.0</td><td>77.0</td><td>8.9</td></tr><tr><td>PixelLink [4]</td><td>73.2</td><td>83.0</td><td>77.8</td><td>-</td></tr><tr><td>TextSnake [18]</td><td>73.9</td><td>83.2</td><td>78.3</td><td>1.1</td></tr><tr><td>TextField [37]</td><td>75.9</td><td>87.4</td><td>81.3</td><td>5.2 </td></tr><tr><td>MSR[38]</td><td>76.7</td><td>87.4</td><td>81.7</td><td>-</td></tr><tr><td>FTSN [3]</td><td>77.1</td><td>87.6</td><td>82.0</td><td>-</td></tr><tr><td>LSE[30]</td><td>81.7</td><td>84.2</td><td>82.9</td><td>-</td></tr><tr><td>CRAFT [2]</td><td>78.2</td><td>88.2</td><td>82.9</td><td>8.6</td></tr><tr><td>MCN [16]</td><td>79</td><td>88</td><td>83</td><td>-</td></tr><tr><td>ATRR[35]</td><td>82.1</td><td>85.2</td><td>83.6</td><td>-</td></tr><tr><td>PAN [34]</td><td>83.8</td><td>84.4</td><td>84.1</td><td>30.2</td></tr><tr><td>DB[12]</td><td>79.2</td><td>91.5</td><td>84.9</td><td>32.0</td></tr><tr><td>DRRG [41]</td><td>82.30</td><td>88.05</td><td>85.08</td><td>-</td></tr><tr><td>Ours (SynText)</td><td>80.68</td><td>85.40</td><td>82.97</td><td>12.68</td></tr><tr><td>Ours (MLT-17)</td><td>84.54</td><td>86.62</td><td>85.57</td><td>12.31</td></tr></tbody></table></body></html>
```
<a name="3"></a>
## 3. FAQ

View File

@ -181,6 +181,9 @@ inference/
|-- cls
| |--inference.pdiparams
| |--inference.pdmodel
|-- table
| |--inference.pdiparams
| |--inference.pdmodel
```
<a name="22"></a>
@ -285,6 +288,16 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
--cls=true \
```
##### 7. 表格识别
```shell
./build/ppocr --det_model_dir=inference/det_db \
--rec_model_dir=inference/rec_rcnn \
--table_model_dir=inference/table \
--image_dir=../../ppstructure/docs/table/table.jpg \
--type=structure \
--table=true
```
更多支持的可调节参数解释如下:
- 通用参数
@ -328,21 +341,32 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
|cls_thresh|float|0.9|方向分类器的得分阈值|
|cls_batch_num|int|1|方向分类器batchsize|
- 识别模型相关
- 文字识别模型相关
|参数名称|类型|默认参数|意义|
| :---: | :---: | :---: | :---: |
|rec_model_dir|string|-|识别模型inference model地址|
|rec_model_dir|string|-|文字识别模型inference model地址|
|rec_char_dict_path|string|../../ppocr/utils/ppocr_keys_v1.txt|字典文件|
|rec_batch_num|int|6|识别模型batchsize|
|rec_img_h|int|48|识别模型输入图像高度|
|rec_img_w|int|320|识别模型输入图像宽度|
|rec_batch_num|int|6|文字识别模型batchsize|
|rec_img_h|int|48|文字识别模型输入图像高度|
|rec_img_w|int|320|文字识别模型输入图像宽度|
- 表格识别模型相关
|参数名称|类型|默认参数|意义|
| :---: | :---: | :---: | :---: |
|table_model_dir|string|-|表格识别模型inference model地址|
|table_char_dict_path|string|../../ppocr/utils/dict/table_structure_dict.txt|字典文件|
|table_max_len|int|488|表格识别模型输入图像长边大小,最终网络输入图像大小为(table_max_len,table_max_len)|
* PaddleOCR也支持多语言的预测更多支持的语言和模型可以参考[识别文档](../../doc/doc_ch/recognition.md)中的多语言字典与模型部分,如果希望进行多语言预测,只需将修改`rec_char_dict_path`(字典文件路径)以及`rec_model_dir`inference模型路径字段即可。
最终屏幕上会输出检测结果如下。
- ocr
```bash
predict img: ../../doc/imgs/12.jpg
../../doc/imgs/12.jpg
@ -353,6 +377,13 @@ predict img: ../../doc/imgs/12.jpg
The detection visualized image saved in ./output//12.jpg
```
- table
```bash
predict img: ../../ppstructure/docs/table/table.jpg
0 type: table, region: [0,0,371,293], res: <html><body><table><thead><tr><td>Methods</td><td>R</td><td>P</td><td>F</td><td>FPS</td></tr></thead><tbody><tr><td>SegLink [26]</td><td>70.0</td><td>86.0</td><td>77.0</td><td>8.9</td></tr><tr><td>PixelLink [4]</td><td>73.2</td><td>83.0</td><td>77.8</td><td>-</td></tr><tr><td>TextSnake [18]</td><td>73.9</td><td>83.2</td><td>78.3</td><td>1.1</td></tr><tr><td>TextField [37]</td><td>75.9</td><td>87.4</td><td>81.3</td><td>5.2 </td></tr><tr><td>MSR[38]</td><td>76.7</td><td>87.4</td><td>81.7</td><td>-</td></tr><tr><td>FTSN [3]</td><td>77.1</td><td>87.6</td><td>82.0</td><td>-</td></tr><tr><td>LSE[30]</td><td>81.7</td><td>84.2</td><td>82.9</td><td>-</td></tr><tr><td>CRAFT [2]</td><td>78.2</td><td>88.2</td><td>82.9</td><td>8.6</td></tr><tr><td>MCN [16]</td><td>79</td><td>88</td><td>83</td><td>-</td></tr><tr><td>ATRR[35]</td><td>82.1</td><td>85.2</td><td>83.6</td><td>-</td></tr><tr><td>PAN [34]</td><td>83.8</td><td>84.4</td><td>84.1</td><td>30.2</td></tr><tr><td>DB[12]</td><td>79.2</td><td>91.5</td><td>84.9</td><td>32.0</td></tr><tr><td>DRRG [41]</td><td>82.30</td><td>88.05</td><td>85.08</td><td>-</td></tr><tr><td>Ours (SynText)</td><td>80.68</td><td>85.40</td><td>82.97</td><td>12.68</td></tr><tr><td>Ours (MLT-17)</td><td>84.54</td><td>86.62</td><td>85.57</td><td>12.31</td></tr></tbody></table></body></html>
```
<a name="3"></a>
## 3. FAQ

View File

@ -30,7 +30,8 @@ DEFINE_string(
"Perform ocr or structure, the value is selected in ['ocr','structure'].");
// detection related
DEFINE_string(det_model_dir, "", "Path of det inference model.");
DEFINE_int32(max_side_len, 960, "max_side_len of input image.");
DEFINE_string(limit_type, "max", "limit_type of input image.");
DEFINE_int32(limit_side_len, 960, "limit_side_len of input image.");
DEFINE_double(det_db_thresh, 0.3, "Threshold of det_db_thresh.");
DEFINE_double(det_db_box_thresh, 0.6, "Threshold of det_db_box_thresh.");
DEFINE_double(det_db_unclip_ratio, 1.5, "Threshold of det_db_unclip_ratio.");
@ -50,7 +51,16 @@ DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt",
DEFINE_int32(rec_img_h, 48, "rec image height");
DEFINE_int32(rec_img_w, 320, "rec image width");
// structure model related
// Command-line flags for the table-structure model: model directory, input
// image length, batch size and label dictionary path.
DEFINE_string(table_model_dir, "", "Path of table structure inference model.");
DEFINE_int32(table_max_len, 488, "max len size of input image.");
DEFINE_int32(table_batch_num, 1, "table_batch_num.");
DEFINE_string(table_char_dict_path,
              "../../ppocr/utils/dict/table_structure_dict.txt",
              "Path of dictionary.");
// ocr forward related
DEFINE_bool(det, true, "Whether use det in forward.");
DEFINE_bool(rec, true, "Whether use rec in forward.");
DEFINE_bool(cls, false, "Whether use cls in forward.");
DEFINE_bool(cls, false, "Whether use cls in forward.");
DEFINE_bool(table, false, "Whether use table structure in forward.");

View File

@ -19,6 +19,7 @@
#include <include/args.h>
#include <include/paddleocr.h>
#include <include/paddlestructure.h>
using namespace PaddleOCR;
@ -32,6 +33,12 @@ void check_params() {
}
}
if (FLAGS_rec) {
std::cout
<< "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320',"
"if you are using recognition model with PP-OCRv2 or an older "
"version, "
"please set --rec_image_shape='3,32,320"
<< std::endl;
if (FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) {
std::cout << "Usage[rec]: ./ppocr "
"--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
@ -47,6 +54,17 @@ void check_params() {
exit(1);
}
}
if (FLAGS_table) {
if (FLAGS_table_model_dir.empty() || FLAGS_det_model_dir.empty() ||
FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) {
std::cout << "Usage[table]: ./ppocr "
<< "--det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ "
<< "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
<< "--table_model_dir=/PATH/TO/TABLE_INFERENCE_MODEL/ "
<< "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl;
exit(1);
}
}
if (FLAGS_precision != "fp32" && FLAGS_precision != "fp16" &&
FLAGS_precision != "int8") {
cout << "precison should be 'fp32'(default), 'fp16' or 'int8'. " << endl;
@ -54,21 +72,7 @@ void check_params() {
}
}
int main(int argc, char **argv) {
// Parsing command-line
google::ParseCommandLineFlags(&argc, &argv, true);
check_params();
if (!Utility::PathExists(FLAGS_image_dir)) {
std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir
<< endl;
exit(1);
}
std::vector<cv::String> cv_all_img_names;
cv::glob(FLAGS_image_dir, cv_all_img_names);
std::cout << "total images num: " << cv_all_img_names.size() << endl;
void ocr(std::vector<cv::String> &cv_all_img_names) {
PPOCR ocr = PPOCR();
std::vector<std::vector<OCRPredictResult>> ocr_results =
@ -109,3 +113,49 @@ int main(int argc, char **argv) {
}
}
}
// Runs the structure pipeline over every image path and prints, for each
// image, one line per detected region: its index, type, box and either the
// table HTML or the OCR text results.
void structure(std::vector<cv::String> &cv_all_img_names) {
  PaddleOCR::PaddleStructure engine = PaddleOCR::PaddleStructure();
  std::vector<std::vector<StructurePredictResult>> structure_results =
      engine.structure(cv_all_img_names, false, FLAGS_table);

  for (size_t img_idx = 0; img_idx < cv_all_img_names.size(); ++img_idx) {
    std::cout << "predict img: " << cv_all_img_names[img_idx] << std::endl;

    const std::vector<StructurePredictResult> &regions =
        structure_results[img_idx];
    for (size_t region_idx = 0; region_idx < regions.size(); ++region_idx) {
      const StructurePredictResult &region = regions[region_idx];

      std::cout << region_idx << "\ttype: " << region.type << ", region: [";
      std::cout << region.box[0] << "," << region.box[1] << ","
                << region.box[2] << "," << region.box[3] << "], res: ";

      // Anything that is not a table is printed as plain OCR results.
      if (region.type != "table") {
        Utility::print_result(region.text_res);
        continue;
      }
      std::cout << region.html << std::endl;
    }
  }
}
// Program entry point: parses the command-line flags, validates them,
// collects the input images and dispatches to the pipeline selected by
// --type. Returns 0 on success and 1 when --type is not supported.
int main(int argc, char **argv) {
  // Parsing command-line
  google::ParseCommandLineFlags(&argc, &argv, true);
  check_params();

  if (!Utility::PathExists(FLAGS_image_dir)) {
    std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir
              << endl;
    exit(1);
  }

  std::vector<cv::String> cv_all_img_names;
  cv::glob(FLAGS_image_dir, cv_all_img_names);
  std::cout << "total images num: " << cv_all_img_names.size() << endl;

  if (FLAGS_type == "ocr") {
    ocr(cv_all_img_names);
  } else if (FLAGS_type == "structure") {
    structure(cv_all_img_names);
  } else {
    // An unknown pipeline is a usage error: report it on stderr and signal
    // failure through the exit status so scripts can detect it.
    std::cerr << "only value in ['ocr','structure'] is supported" << endl;
    return 1;
  }
  return 0;
}

View File

@ -47,7 +47,7 @@ void DBDetector::LoadModel(const std::string &model_dir) {
{"elementwise_add_7", {1, 56, 2, 2}},
{"nearest_interp_v2_0.tmp_0", {1, 256, 2, 2}}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"x", {1, 3, this->max_side_len_, this->max_side_len_}},
{"x", {1, 3, 1536, 1536}},
{"conv2d_92.tmp_0", {1, 120, 400, 400}},
{"conv2d_91.tmp_0", {1, 24, 200, 200}},
{"conv2d_59.tmp_0", {1, 96, 400, 400}},
@ -109,7 +109,8 @@ void DBDetector::Run(cv::Mat &img,
img.copyTo(srcimg);
auto preprocess_start = std::chrono::steady_clock::now();
this->resize_op_.Run(img, resize_img, this->max_side_len_, ratio_h, ratio_w,
this->resize_op_.Run(img, resize_img, this->limit_type_,
this->limit_side_len_, ratio_h, ratio_w,
this->use_tensorrt_);
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,

View File

@ -23,10 +23,10 @@ PPOCR::PPOCR() {
if (FLAGS_det) {
this->detector_ = new DBDetector(
FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_max_side_len,
FLAGS_det_db_thresh, FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio,
FLAGS_det_db_score_mode, FLAGS_use_dilation, FLAGS_use_tensorrt,
FLAGS_precision);
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_limit_type,
FLAGS_limit_side_len, FLAGS_det_db_thresh, FLAGS_det_db_box_thresh,
FLAGS_det_db_unclip_ratio, FLAGS_det_db_score_mode, FLAGS_use_dilation,
FLAGS_use_tensorrt, FLAGS_precision);
}
if (FLAGS_cls && FLAGS_use_angle_cls) {
@ -56,7 +56,8 @@ void PPOCR::det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
res.box = boxes[i];
ocr_results.push_back(res);
}
// sort boxes from top to bottom, from left to right
Utility::sorted_boxes(ocr_results);
times[0] += det_times[0];
times[1] += det_times[1];
times[2] += det_times[2];

View File

@ -0,0 +1,272 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <include/args.h>
#include <include/paddlestructure.h>
#include "auto_log/autolog.h"
#include <numeric>
#include <sys/stat.h>
namespace PaddleOCR {
// Builds the structure pipeline from the global command-line flags.
// The table recognizer is only created when FLAGS_table is set; otherwise
// recognizer_ keeps its nullptr default.
PaddleStructure::PaddleStructure() {
if (FLAGS_table) {
// Model, device and table-specific settings all come straight from flags.
this->recognizer_ = new StructureTableRecognizer(
FLAGS_table_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_table_char_dict_path,
FLAGS_use_tensorrt, FLAGS_precision, FLAGS_table_batch_num,
FLAGS_table_max_len);
}
};
// Runs structure analysis on every image in `cv_all_img_names`.
//
// Parameters:
//   cv_all_img_names  paths of the images to process.
//   layout            when false, the whole image is treated as a single
//                     "table" region; the layout branch is currently empty,
//                     so no regions are produced when it is true.
//   table             when true, table recognition is run on every "table"
//                     region (requires the table recognizer to be loaded).
//
// Returns one vector of StructurePredictResult per input image, in input
// order. Exits the process with status 1 if an image cannot be read.
std::vector<std::vector<StructurePredictResult>>
PaddleStructure::structure(std::vector<cv::String> cv_all_img_names,
                           bool layout, bool table) {
  // Timing accumulators filled by table(); they are not reported here.
  std::vector<double> time_info_det = {0, 0, 0};
  std::vector<double> time_info_rec = {0, 0, 0};
  std::vector<double> time_info_cls = {0, 0, 0};
  std::vector<double> time_info_table = {0, 0, 0};

  std::vector<std::vector<StructurePredictResult>> structure_results;

  if (!Utility::PathExists(FLAGS_output) && FLAGS_det) {
    mkdir(FLAGS_output.c_str(), 0777);
  }

  // The recognizer is only constructed when FLAGS_table is set, so it may be
  // null here; calling table() in that state would dereference a null
  // pointer. Honour the `table` argument and skip recognition when the model
  // is unavailable.
  const bool run_table = table && this->recognizer_ != nullptr;
  if (table && this->recognizer_ == nullptr) {
    std::cerr << "[WARNING] table recognition requested but the table model "
                 "is not loaded; skipping table recognition."
              << std::endl;
  }

  for (size_t img_idx = 0; img_idx < cv_all_img_names.size(); ++img_idx) {
    std::vector<StructurePredictResult> structure_result;
    cv::Mat srcimg = cv::imread(cv_all_img_names[img_idx], cv::IMREAD_COLOR);
    if (!srcimg.data) {
      std::cerr << "[ERROR] image read failed! image path: "
                << cv_all_img_names[img_idx] << endl;
      exit(1);
    }

    if (layout) {
      // Layout analysis is not implemented: no regions are added.
    } else {
      // Without layout analysis the full image is one table region.
      StructurePredictResult res;
      res.type = "table";
      res.box = std::vector<int>(4, 0);
      res.box[2] = srcimg.cols;
      res.box[3] = srcimg.rows;
      structure_result.push_back(res);
    }

    cv::Mat roi_img;
    // A distinct index name avoids shadowing the image loop index.
    for (size_t region_idx = 0; region_idx < structure_result.size();
         ++region_idx) {
      // crop image
      roi_img = Utility::crop_image(srcimg, structure_result[region_idx].box);
      if (run_table && structure_result[region_idx].type == "table") {
        this->table(roi_img, structure_result[region_idx], time_info_table,
                    time_info_det, time_info_rec, time_info_cls);
      }
    }
    structure_results.push_back(structure_result);
  }
  return structure_results;
}
// Recognizes one table image: predicts the table structure (HTML tags, cell
// boxes and a score), runs text detection + recognition on the same image and
// merges both into structure_result.html / structure_result.html_score.
//
// img:              table image.
// structure_result: receives the rebuilt HTML string and the structure score.
// time_info_table:  the first three timing values reported by the structure
//                   recognizer are added to entries 0..2.
// time_info_det / time_info_rec: forwarded to det() / rec().
// time_info_cls:    not used by this function.
void PaddleStructure::table(cv::Mat img,
                            StructurePredictResult &structure_result,
                            std::vector<double> &time_info_table,
                            std::vector<double> &time_info_det,
                            std::vector<double> &time_info_rec,
                            std::vector<double> &time_info_cls) {
  // predict structure
  std::vector<std::vector<std::string>> structure_html_tags;
  // Must start empty: Run() appends one score per image with push_back, so a
  // pre-filled placeholder would sit at index 0 and hide the real score.
  std::vector<float> structure_scores;
  std::vector<std::vector<std::vector<std::vector<int>>>> structure_boxes;
  std::vector<double> structure_times;
  std::vector<cv::Mat> img_list;
  img_list.push_back(img);
  this->recognizer_->Run(img_list, structure_html_tags, structure_scores,
                         structure_boxes, structure_times);

  // Accumulate the three timing values without reading past either vector.
  for (size_t k = 0;
       k < 3 && k < structure_times.size() && k < time_info_table.size();
       k++) {
    time_info_table[k] += structure_times[k];
  }

  int expand_pixel = 3;
  for (size_t i = 0; i < img_list.size(); i++) {
    // Stop if the recognizer returned fewer results than images.
    if (i >= structure_html_tags.size() || i >= structure_boxes.size() ||
        i >= structure_scores.size()) {
      break;
    }

    // det() appends to its output, so use a fresh vector for each image.
    std::vector<OCRPredictResult> ocr_result;

    // det
    this->det(img_list[i], ocr_result, time_info_det);

    // crop image
    std::vector<cv::Mat> rec_img_list;
    for (size_t j = 0; j < ocr_result.size(); j++) {
      int x_collect[4] = {ocr_result[j].box[0][0], ocr_result[j].box[1][0],
                          ocr_result[j].box[2][0], ocr_result[j].box[3][0]};
      int y_collect[4] = {ocr_result[j].box[0][1], ocr_result[j].box[1][1],
                          ocr_result[j].box[2][1], ocr_result[j].box[3][1]};
      int left = int(*std::min_element(x_collect, x_collect + 4));
      int right = int(*std::max_element(x_collect, x_collect + 4));
      int top = int(*std::min_element(y_collect, y_collect + 4));
      int bottom = int(*std::max_element(y_collect, y_collect + 4));
      // Axis-aligned text box grown by expand_pixel and clamped to the image.
      std::vector<int> box{max(0, left - expand_pixel),
                           max(0, top - expand_pixel),
                           min(img_list[i].cols, right + expand_pixel),
                           min(img_list[i].rows, bottom + expand_pixel)};
      cv::Mat crop_img = Utility::crop_image(img_list[i], box);
      rec_img_list.push_back(crop_img);
    }

    // rec
    this->rec(rec_img_list, ocr_result, time_info_rec);

    // rebuild table
    structure_result.html = this->rebuild_table(
        structure_html_tags[i], structure_boxes[i], ocr_result);
    structure_result.html_score = structure_scores[i];
  }
}
std::string PaddleStructure::rebuild_table(
std::vector<std::string> structure_html_tags,
std::vector<std::vector<std::vector<int>>> structure_boxes,
std::vector<OCRPredictResult> &ocr_result) {
// match text in same cell
std::vector<std::vector<string>> matched(structure_boxes.size(),
std::vector<std::string>());
for (int i = 0; i < ocr_result.size(); i++) {
std::vector<std::vector<float>> dis_list(structure_boxes.size(),
std::vector<float>(3, 100000.0));
for (int j = 0; j < structure_boxes.size(); j++) {
int x_collect[4] = {ocr_result[i].box[0][0], ocr_result[i].box[1][0],
ocr_result[i].box[2][0], ocr_result[i].box[3][0]};
int y_collect[4] = {ocr_result[i].box[0][1], ocr_result[i].box[1][1],
ocr_result[i].box[2][1], ocr_result[i].box[3][1]};
int left = int(*std::min_element(x_collect, x_collect + 4));
int right = int(*std::max_element(x_collect, x_collect + 4));
int top = int(*std::min_element(y_collect, y_collect + 4));
int bottom = int(*std::max_element(y_collect, y_collect + 4));
std::vector<std::vector<int>> box(2, std::vector<int>(2, 0));
box[0][0] = left - 1;
box[0][1] = top - 1;
box[1][0] = right + 1;
box[1][1] = bottom + 1;
dis_list[j][0] = this->dis(box, structure_boxes[j]);
dis_list[j][1] = 1 - this->iou(box, structure_boxes[j]);
dis_list[j][2] = j;
}
// find min dis idx
std::sort(dis_list.begin(), dis_list.end(),
PaddleStructure::comparison_dis);
matched[dis_list[0][2]].push_back(ocr_result[i].text);
}
// get pred html
std::string html_str = "";
int td_tag_idx = 0;
for (int i = 0; i < structure_html_tags.size(); i++) {
if (structure_html_tags[i].find("</td>") != std::string::npos) {
if (structure_html_tags[i].find("<td></td>") != std::string::npos) {
html_str += "<td>";
}
if (matched[td_tag_idx].size() > 0) {
bool b_with = false;
if (matched[td_tag_idx][0].find("<b>") != std::string::npos &&
matched[td_tag_idx].size() > 1) {
b_with = true;
html_str += "<b>";
}
for (int j = 0; j < matched[td_tag_idx].size(); j++) {
std::string content = matched[td_tag_idx][j];
if (matched[td_tag_idx].size() > 1) {
// remove blank, <b> and </b>
if (content.length() > 0 && content.at(0) == ' ') {
content = content.substr(0);
}
if (content.length() > 2 && content.substr(0, 3) == "<b>") {
content = content.substr(3);
}
if (content.length() > 4 &&
content.substr(content.length() - 4) == "</b>") {
content = content.substr(0, content.length() - 4);
}
if (content.empty()) {
continue;
}
// add blank
if (j != matched[td_tag_idx].size() - 1 &&
content.at(content.length() - 1) != ' ') {
content += ' ';
}
}
html_str += content;
}
if (b_with) {
html_str += "</b>";
}
}
if (structure_html_tags[i].find("<td></td>") != std::string::npos) {
html_str += "</td>";
} else {
html_str += structure_html_tags[i];
}
td_tag_idx += 1;
} else {
html_str += structure_html_tags[i];
}
}
return html_str;
}
// Intersection-over-union of two axis-aligned boxes, each given as
// {{x1, y1}, {x2, y2}}. Returns 0 when the boxes do not overlap.
float PaddleStructure::iou(std::vector<std::vector<int>> &box1,
                           std::vector<std::vector<int>> &box2) {
  // Corners of the intersection rectangle.
  const int inter_left = max(box1[0][0], box2[0][0]);
  const int inter_top = max(box1[0][1], box2[0][1]);
  const int inter_right = min(box1[1][0], box2[1][0]);
  const int inter_bottom = min(box1[1][1], box2[1][1]);

  // No overlap along at least one axis.
  if (inter_top >= inter_bottom || inter_left >= inter_right) {
    return 0.0;
  }

  // Box areas; negative extents are clamped to zero.
  const int first_area =
      max(0, box1[1][0] - box1[0][0]) * max(0, box1[1][1] - box1[0][1]);
  const int second_area =
      max(0, box2[1][0] - box2[0][0]) * max(0, box2[1][1] - box2[0][1]);

  const int overlap = (inter_right - inter_left) * (inter_bottom - inter_top);
  // The small epsilon keeps the denominator non-zero.
  return overlap / (first_area + second_area - overlap + 0.00000001);
}
// Corner-based distance between two boxes given as {{x1, y1}, {x2, y2}}:
// the sum of the L1 distances between the two first corners and between the
// two second corners, plus the smaller of those two corner distances.
float PaddleStructure::dis(std::vector<std::vector<int>> &box1,
                           std::vector<std::vector<int>> &box2) {
  // L1 distance between the first corners.
  const int first_corner_gap =
      abs(box2[0][0] - box1[0][0]) + abs(box2[0][1] - box1[0][1]);
  // L1 distance between the second corners.
  const int second_corner_gap =
      abs(box2[1][0] - box1[1][0]) + abs(box2[1][1] - box1[1][1]);

  const float total_gap = first_corner_gap + second_corner_gap;
  const float first_gap = first_corner_gap;
  const float second_gap = second_corner_gap;
  return total_gap + min(first_gap, second_gap);
}
// Releases the table recognizer owned by this object.
PaddleStructure::~PaddleStructure() {
  // delete on a null pointer is a no-op, so no explicit null check is needed.
  delete this->recognizer_;
}
} // namespace PaddleOCR

View File

@ -17,8 +17,8 @@
namespace PaddleOCR {
void PostProcessor::GetContourArea(const std::vector<std::vector<float>> &box,
float unclip_ratio, float &distance) {
void DBPostProcessor::GetContourArea(const std::vector<std::vector<float>> &box,
float unclip_ratio, float &distance) {
int pts_num = 4;
float area = 0.0f;
float dist = 0.0f;
@ -35,8 +35,8 @@ void PostProcessor::GetContourArea(const std::vector<std::vector<float>> &box,
distance = area * unclip_ratio / dist;
}
cv::RotatedRect PostProcessor::UnClip(std::vector<std::vector<float>> box,
const float &unclip_ratio) {
cv::RotatedRect DBPostProcessor::UnClip(std::vector<std::vector<float>> box,
const float &unclip_ratio) {
float distance = 1.0;
GetContourArea(box, unclip_ratio, distance);
@ -67,7 +67,7 @@ cv::RotatedRect PostProcessor::UnClip(std::vector<std::vector<float>> box,
return res;
}
float **PostProcessor::Mat2Vec(cv::Mat mat) {
float **DBPostProcessor::Mat2Vec(cv::Mat mat) {
auto **array = new float *[mat.rows];
for (int i = 0; i < mat.rows; ++i)
array[i] = new float[mat.cols];
@ -81,7 +81,7 @@ float **PostProcessor::Mat2Vec(cv::Mat mat) {
}
std::vector<std::vector<int>>
PostProcessor::OrderPointsClockwise(std::vector<std::vector<int>> pts) {
DBPostProcessor::OrderPointsClockwise(std::vector<std::vector<int>> pts) {
std::vector<std::vector<int>> box = pts;
std::sort(box.begin(), box.end(), XsortInt);
@ -99,7 +99,7 @@ PostProcessor::OrderPointsClockwise(std::vector<std::vector<int>> pts) {
return rect;
}
std::vector<std::vector<float>> PostProcessor::Mat2Vector(cv::Mat mat) {
std::vector<std::vector<float>> DBPostProcessor::Mat2Vector(cv::Mat mat) {
std::vector<std::vector<float>> img_vec;
std::vector<float> tmp;
@ -113,20 +113,20 @@ std::vector<std::vector<float>> PostProcessor::Mat2Vector(cv::Mat mat) {
return img_vec;
}
bool PostProcessor::XsortFp32(std::vector<float> a, std::vector<float> b) {
bool DBPostProcessor::XsortFp32(std::vector<float> a, std::vector<float> b) {
if (a[0] != b[0])
return a[0] < b[0];
return false;
}
bool PostProcessor::XsortInt(std::vector<int> a, std::vector<int> b) {
bool DBPostProcessor::XsortInt(std::vector<int> a, std::vector<int> b) {
if (a[0] != b[0])
return a[0] < b[0];
return false;
}
std::vector<std::vector<float>> PostProcessor::GetMiniBoxes(cv::RotatedRect box,
float &ssid) {
std::vector<std::vector<float>>
DBPostProcessor::GetMiniBoxes(cv::RotatedRect box, float &ssid) {
ssid = std::max(box.size.width, box.size.height);
cv::Mat points;
@ -160,8 +160,8 @@ std::vector<std::vector<float>> PostProcessor::GetMiniBoxes(cv::RotatedRect box,
return array;
}
float PostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
cv::Mat pred) {
float DBPostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
cv::Mat pred) {
int width = pred.cols;
int height = pred.rows;
std::vector<float> box_x;
@ -206,8 +206,8 @@ float PostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
return score;
}
float PostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
cv::Mat pred) {
float DBPostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
cv::Mat pred) {
auto array = box_array;
int width = pred.cols;
int height = pred.rows;
@ -244,7 +244,7 @@ float PostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
return score;
}
std::vector<std::vector<std::vector<int>>> PostProcessor::BoxesFromBitmap(
std::vector<std::vector<std::vector<int>>> DBPostProcessor::BoxesFromBitmap(
const cv::Mat pred, const cv::Mat bitmap, const float &box_thresh,
const float &det_db_unclip_ratio, const std::string &det_db_score_mode) {
const int min_size = 3;
@ -321,9 +321,9 @@ std::vector<std::vector<std::vector<int>>> PostProcessor::BoxesFromBitmap(
return boxes;
}
std::vector<std::vector<std::vector<int>>>
PostProcessor::FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
float ratio_h, float ratio_w, cv::Mat srcimg) {
std::vector<std::vector<std::vector<int>>> DBPostProcessor::FilterTagDetRes(
std::vector<std::vector<std::vector<int>>> boxes, float ratio_h,
float ratio_w, cv::Mat srcimg) {
int oriimg_h = srcimg.rows;
int oriimg_w = srcimg.cols;
@ -352,4 +352,77 @@ PostProcessor::FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
return root_points;
}
// Loads the label dictionary from label_path and surrounds it with the
// begin token (first entry) and the end token (last entry).
void TablePostProcessor::init(std::string label_path) {
  auto &labels = this->label_list_;
  labels = Utility::ReadDict(label_path);
  labels.insert(labels.begin(), this->beg);
  labels.push_back(this->end);
}
// Decodes the raw table-structure network outputs into HTML tags, cell boxes
// and one score per image.
//
// loc_preds / loc_preds_shape:             flattened box predictions and their
//                                          shape; shape[2] values are read per
//                                          step, two per point.
// structure_probs / structure_probs_shape: flattened tag probabilities with
//                                          shape {batch, steps, classes}.
// rec_scores:          receives, per image, the mean of the best-class
//                      probabilities of the kept tags, or -1 when that mean
//                      is NaN or no cell box was decoded.
// rec_html_tag_batch:  receives the decoded tag sequence per image.
// rec_boxes_batch:     receives the decoded cell boxes per image.
// width_list / height_list: per-image factors the box coordinates are
//                      multiplied by.
void TablePostProcessor::Run(
    std::vector<float> &loc_preds, std::vector<float> &structure_probs,
    std::vector<float> &rec_scores, std::vector<int> &loc_preds_shape,
    std::vector<int> &structure_probs_shape,
    std::vector<std::vector<std::string>> &rec_html_tag_batch,
    std::vector<std::vector<std::vector<std::vector<int>>>> &rec_boxes_batch,
    std::vector<int> &width_list, std::vector<int> &height_list) {
  for (int batch_idx = 0; batch_idx < structure_probs_shape[0]; batch_idx++) {
    const int step_num = structure_probs_shape[1];
    const int class_num = structure_probs_shape[2];

    // image tags and boxs
    std::vector<std::string> rec_html_tags;
    std::vector<std::vector<std::vector<int>>> rec_boxes;

    float score = 0.f;
    int count = 0;

    // step
    for (int step_idx = 0; step_idx < step_num; step_idx++) {
      // html tag
      const int step_start_idx = (batch_idx * step_num + step_idx) * class_num;
      // Build the [begin, end) range with pointer arithmetic on data():
      // for the last step of the last image "end" is one past the final
      // element, which operator[] must not be asked for.
      float *probs_begin = structure_probs.data() + step_start_idx;
      float *probs_end = probs_begin + class_num;
      const int char_idx = int(Utility::argmax(probs_begin, probs_end));
      const float char_score =
          float(*std::max_element(probs_begin, probs_end));

      std::string html_tag = this->label_list_[char_idx];
      // The end token terminates decoding (except at the very first step).
      if (step_idx > 0 && html_tag == this->end) {
        break;
      }
      // The begin token carries no content.
      if (html_tag == this->beg) {
        continue;
      }
      count += 1;
      score += char_score;
      rec_html_tags.push_back(html_tag);

      // box: only tags that open a cell have a box attached
      if (html_tag == "<td>" || html_tag == "<td" || html_tag == "<td></td>") {
        std::vector<std::vector<int>> rec_box;
        const int loc_num = loc_preds_shape[2];
        const int loc_start_idx = (batch_idx * step_num + step_idx) * loc_num;
        for (int point_idx = 0; point_idx < loc_num; point_idx += 2) {
          std::vector<int> point(2, 0);
          point[0] = int(loc_preds[loc_start_idx + point_idx] *
                         width_list[batch_idx]);
          point[1] = int(loc_preds[loc_start_idx + point_idx + 1] *
                         height_list[batch_idx]);
          rec_box.push_back(point);
        }
        rec_boxes.push_back(rec_box);
      }
    }

    // count == 0 yields NaN here, which is mapped to -1 just below.
    score /= count;
    if (isnan(score) || rec_boxes.size() == 0) {
      score = -1;
    }
    rec_scores.push_back(score);
    rec_boxes_batch.push_back(rec_boxes);
    rec_html_tag_batch.push_back(rec_html_tags);
  }
}
} // namespace PaddleOCR

View File

@ -69,18 +69,28 @@ void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
}
void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
int max_size_len, float &ratio_h, float &ratio_w,
bool use_tensorrt) {
string limit_type, int limit_side_len, float &ratio_h,
float &ratio_w, bool use_tensorrt) {
int w = img.cols;
int h = img.rows;
float ratio = 1.f;
int max_wh = w >= h ? w : h;
if (max_wh > max_size_len) {
if (h > w) {
ratio = float(max_size_len) / float(h);
} else {
ratio = float(max_size_len) / float(w);
if (limit_type == "min") {
int min_wh = min(h, w);
if (min_wh < limit_side_len) {
if (h < w) {
ratio = float(limit_side_len) / float(h);
} else {
ratio = float(limit_side_len) / float(w);
}
}
} else {
int max_wh = max(h, w);
if (max_wh > limit_side_len) {
if (h > w) {
ratio = float(limit_side_len) / float(h);
} else {
ratio = float(limit_side_len) / float(w);
}
}
}
@ -143,4 +153,26 @@ void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
}
}
// Resizes img so that its longer side becomes max_len while keeping the
// aspect ratio; the result is written to resize_img.
void TableResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
                         const int max_len) {
  int w = img.cols;
  int h = img.rows;

  float ratio = w >= h ? float(max_len) / float(w) : float(max_len) / float(h);

  int resize_h = int(float(h) * ratio);
  int resize_w = int(float(w) * ratio);
  // For very elongated images the shorter side truncates to 0, and
  // cv::resize rejects an empty target size; keep at least one pixel.
  if (resize_h < 1) {
    resize_h = 1;
  }
  if (resize_w < 1) {
    resize_w = 1;
  }

  cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
}
// Pads img with black pixels on the bottom and right so that the result,
// written to resize_img, is max_len x max_len.
void TablePadImg::Run(const cv::Mat &img, cv::Mat &resize_img,
                      const int max_len) {
  const int pad_bottom = max_len - img.rows;
  const int pad_right = max_len - img.cols;
  cv::copyMakeBorder(img, resize_img, 0, pad_bottom, 0, pad_right,
                     cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
}
} // namespace PaddleOCR

View File

@ -0,0 +1,158 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <include/structure_table.h>
namespace PaddleOCR {
// Runs table-structure recognition on img_list in batches of
// table_batch_num_ images.
//
// structure_html_tags: one tag sequence is appended per image, wrapped in
//                      <html><body><table> ... </table></body></html>.
// structure_scores:    one score is appended per image.
// structure_boxes:     the decoded cell boxes are appended per image.
// times:               exactly three values are appended per call:
//                      preprocess, inference and postprocess time in
//                      milliseconds, summed over all batches.
void StructureTableRecognizer::Run(
    std::vector<cv::Mat> img_list,
    std::vector<std::vector<std::string>> &structure_html_tags,
    std::vector<float> &structure_scores,
    std::vector<std::vector<std::vector<std::vector<int>>>> &structure_boxes,
    std::vector<double> &times) {
  // Accumulated stage durations, starting from exactly zero.
  std::chrono::duration<float> preprocess_diff =
      std::chrono::duration<float>::zero();
  std::chrono::duration<float> inference_diff =
      std::chrono::duration<float>::zero();
  std::chrono::duration<float> postprocess_diff =
      std::chrono::duration<float>::zero();

  int img_num = img_list.size();
  for (int beg_img_no = 0; beg_img_no < img_num;
       beg_img_no += this->table_batch_num_) {
    // preprocess
    auto preprocess_start = std::chrono::steady_clock::now();
    int end_img_no = min(img_num, beg_img_no + this->table_batch_num_);
    int batch_num = end_img_no - beg_img_no;
    std::vector<cv::Mat> norm_img_batch;
    std::vector<int> width_list;
    std::vector<int> height_list;
    for (int ino = beg_img_no; ino < end_img_no; ino++) {
      cv::Mat srcimg;
      img_list[ino].copyTo(srcimg);
      cv::Mat resize_img;
      cv::Mat pad_img;
      this->resize_op_.Run(srcimg, resize_img, this->table_max_len_);
      this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
                              this->is_scale_);
      this->pad_op_.Run(resize_img, pad_img, this->table_max_len_);
      norm_img_batch.push_back(pad_img);
      // Original image size, passed to the post-processor together with the
      // box predictions.
      width_list.push_back(srcimg.cols);
      height_list.push_back(srcimg.rows);
    }

    std::vector<float> input(
        batch_num * 3 * this->table_max_len_ * this->table_max_len_, 0.0f);
    this->permute_op_.Run(norm_img_batch, input.data());
    auto preprocess_end = std::chrono::steady_clock::now();
    preprocess_diff += preprocess_end - preprocess_start;

    // inference.
    auto input_names = this->predictor_->GetInputNames();
    auto input_t = this->predictor_->GetInputHandle(input_names[0]);
    input_t->Reshape(
        {batch_num, 3, this->table_max_len_, this->table_max_len_});
    auto inference_start = std::chrono::steady_clock::now();
    input_t->CopyFromCpu(input.data());
    this->predictor_->Run();

    auto output_names = this->predictor_->GetOutputNames();
    auto output_tensor0 = this->predictor_->GetOutputHandle(output_names[0]);
    auto output_tensor1 = this->predictor_->GetOutputHandle(output_names[1]);
    std::vector<int> predict_shape0 = output_tensor0->shape();
    std::vector<int> predict_shape1 = output_tensor1->shape();

    int out_num0 = std::accumulate(predict_shape0.begin(), predict_shape0.end(),
                                   1, std::multiplies<int>());
    int out_num1 = std::accumulate(predict_shape1.begin(), predict_shape1.end(),
                                   1, std::multiplies<int>());
    // Output 0 is passed to the post-processor as box predictions, output 1
    // as tag probabilities.
    std::vector<float> loc_preds;
    std::vector<float> structure_probs;
    loc_preds.resize(out_num0);
    structure_probs.resize(out_num1);

    output_tensor0->CopyToCpu(loc_preds.data());
    output_tensor1->CopyToCpu(structure_probs.data());
    auto inference_end = std::chrono::steady_clock::now();
    inference_diff += inference_end - inference_start;

    // postprocess
    auto postprocess_start = std::chrono::steady_clock::now();
    std::vector<std::vector<std::string>> structure_html_tag_batch;
    std::vector<float> structure_score_batch;
    std::vector<std::vector<std::vector<std::vector<int>>>>
        structure_boxes_batch;
    this->post_processor_.Run(loc_preds, structure_probs, structure_score_batch,
                              predict_shape0, predict_shape1,
                              structure_html_tag_batch, structure_boxes_batch,
                              width_list, height_list);
    for (int m = 0; m < predict_shape0[0]; m++) {
      // Each insert goes to the front, so the final order is
      // <html><body><table> ... </table></body></html>.
      structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(),
                                         "<table>");
      structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(),
                                         "<body>");
      structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(),
                                         "<html>");
      structure_html_tag_batch[m].push_back("</table>");
      structure_html_tag_batch[m].push_back("</body>");
      structure_html_tag_batch[m].push_back("</html>");
      structure_html_tags.push_back(structure_html_tag_batch[m]);
      structure_scores.push_back(structure_score_batch[m]);
      structure_boxes.push_back(structure_boxes_batch[m]);
    }
    auto postprocess_end = std::chrono::steady_clock::now();
    postprocess_diff += postprocess_end - postprocess_start;
  }

  // Report the totals once, after all batches: callers read times[0..2], so
  // the three entries must exist even for empty input and must not be
  // repeated for every batch.
  times.push_back(double(preprocess_diff.count() * 1000));
  times.push_back(double(inference_diff.count() * 1000));
  times.push_back(double(postprocess_diff.count() * 1000));
}
// Creates the predictor for the table-structure model stored in model_dir
// (files inference.pdmodel / inference.pdiparams), configured for GPU
// (optionally with TensorRT) or CPU (optionally with MKLDNN) according to
// the member settings.
void StructureTableRecognizer::LoadModel(const std::string &model_dir) {
  AnalysisConfig config;
  const std::string model_file = model_dir + "/inference.pdmodel";
  const std::string params_file = model_dir + "/inference.pdiparams";
  config.SetModel(model_file, params_file);

  if (!this->use_gpu_) {
    // CPU execution.
    config.DisableGpu();
    if (this->use_mkldnn_) {
      config.EnableMKLDNN();
    }
    config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
  } else {
    // GPU execution.
    config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
    if (this->use_tensorrt_) {
      // fp32 unless precision_ selects "fp16" or "int8".
      auto precision = paddle_infer::Config::Precision::kFloat32;
      if (this->precision_ == "fp16") {
        precision = paddle_infer::Config::Precision::kHalf;
      } else if (this->precision_ == "int8") {
        precision = paddle_infer::Config::Precision::kInt8;
      }
      config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false);
    }
  }

  // false for zero copy tensor
  config.SwitchUseFeedFetchOps(false);
  // true for multiple input
  config.SwitchSpecifyInputNames(true);

  config.SwitchIrOptim(true);
  config.EnableMemoryOptim();
  config.DisableGlogInfo();

  this->predictor_ = CreatePredictor(config);
}
} // namespace PaddleOCR

View File

@ -248,4 +248,33 @@ void Utility::print_result(const std::vector<OCRPredictResult> &ocr_result) {
std::cout << std::endl;
}
}
// Crops the rectangle area = {x1, y1, x2, y2} out of img. The returned image
// always has the requested size (x2 - x1) x (y2 - y1); parts of the rectangle
// that fall outside img stay zero.
cv::Mat Utility::crop_image(cv::Mat &img, std::vector<int> &area) {
  // Requested rectangle clamped to the image (inclusive coordinates).
  const int src_left = std::max(0, area[0]);
  const int src_top = std::max(0, area[1]);
  const int src_right = std::min(img.cols - 1, area[2] - 1);
  const int src_bottom = std::min(img.rows - 1, area[3] - 1);

  // Zero-filled output of the requested size (CV_8UC3 has the value 16).
  cv::Mat crop_im =
      cv::Mat::zeros(area[3] - area[1], area[2] - area[0], CV_8UC3);

  // Where the visible part lands inside the output ...
  const cv::Range dst_rows(src_top - area[1], src_bottom + 1 - area[1]);
  const cv::Range dst_cols(src_left - area[0], src_right + 1 - area[0]);
  cv::Mat crop_im_window = crop_im(dst_rows, dst_cols);

  // ... and where it comes from in the source image.
  const cv::Range src_rows(src_top, src_bottom + 1);
  const cv::Range src_cols(src_left, src_right + 1);
  cv::Mat roi_img = img(src_rows, src_cols);

  // The window is a view into crop_im, so adding onto its zeros writes the
  // source pixels into the output.
  crop_im_window += roi_img;
  return crop_im;
}
void Utility::sorted_boxes(std::vector<OCRPredictResult> &ocr_result) {
std::sort(ocr_result.begin(), ocr_result.end(), Utility::comparison_box);
for (int i = 0; i < ocr_result.size() - 1; i++) {
if (abs(ocr_result[i + 1].box[0][1] - ocr_result[i].box[0][1]) < 10 &&
(ocr_result[i + 1].box[0][0] < ocr_result[i].box[0][0])) {
std::swap(ocr_result[i], ocr_result[i + 1]);
}
}
}
} // namespace PaddleOCR

View File

@ -118,7 +118,7 @@ class OCRSystem(hub.Module):
all_results.append([])
continue
starttime = time.time()
dt_boxes, rec_res = self.text_sys(img)
dt_boxes, rec_res, _ = self.text_sys(img)
elapse = time.time() - starttime
logger.info("Predict time: {}".format(elapse))

View File

@ -59,7 +59,8 @@ pip3 install paddlehub==2.1.0 --upgrade -i https://mirror.baidu.com/pypi/simple
检测模型:./inference/ch_PP-OCRv3_det_infer/
识别模型:./inference/ch_PP-OCRv3_rec_infer/
方向分类器:./inference/ch_ppocr_mobile_v2.0_cls_infer/
表格结构识别模型:./inference/en_ppocr_mobile_v2.0_table_structure_infer/
版面分析模型:./inference/layout_infer/
表格结构识别模型:./inference/ch_ppstructure_mobile_v2.0_SLANet_infer/
```
**模型路径可在`params.py`中查看和修改。** 更多模型可以从PaddleOCR提供的模型库[PP-OCR](../../doc/doc_ch/models_list.md)和[PP-Structure](../../ppstructure/docs/models_list.md)下载,也可以替换成自己训练转换好的模型。
@ -172,7 +173,7 @@ hub serving start -c deploy/hubserving/ocr_system/config.json
## 3. 发送预测请求
配置好服务端,可使用以下命令发送预测请求,获取预测结果:
```python tools/test_hubserving.py server_url image_path```
```python tools/test_hubserving.py --server_url=server_url --image_dir=image_path```
需要给脚本传递2个参数
- **server_url**:服务地址,格式为

View File

@ -61,7 +61,8 @@ Before installing the service module, you need to prepare the inference model an
text detection model: ./inference/ch_PP-OCRv3_det_infer/
text recognition model: ./inference/ch_PP-OCRv3_rec_infer/
text angle classifier: ./inference/ch_ppocr_mobile_v2.0_cls_infer/
table recognition: ./inference/en_ppocr_mobile_v2.0_table_structure_infer/
layout parse model: ./inference/layout_infer/
table recognition: ./inference/ch_ppstructure_mobile_v2.0_SLANet_infer/
```
**The model path can be found and modified in `params.py`.** More models provided by PaddleOCR can be obtained from the [model library](../../doc/doc_en/models_list_en.md). You can also use models trained by yourself.
@ -177,7 +178,7 @@ hub serving start -c deploy/hubserving/ocr_system/config.json
## 3. Send prediction requests
After the service starts, you can use the following command to send a prediction request to obtain the prediction result:
```shell
python tools/test_hubserving.py server_url image_path
python tools/test_hubserving.py --server_url=server_url --image_dir=image_path
```
Two parameters need to be passed to the script:

View File

@ -119,7 +119,7 @@ class StructureSystem(hub.Module):
all_results.append([])
continue
starttime = time.time()
res = self.table_sys(img)
res, _ = self.table_sys(img)
elapse = time.time() - starttime
logger.info("Predict time: {}".format(elapse))
@ -144,6 +144,6 @@ class StructureSystem(hub.Module):
if __name__ == '__main__':
structure_system = StructureSystem()
structure_system._initialize()
image_path = ['./doc/table/1.png']
image_path = ['./ppstructure/docs/table/1.png']
res = structure_system.predict(paths=image_path)
print(res)

View File

@ -23,8 +23,10 @@ def read_params():
cfg = table_read_params()
# params for layout parser model
cfg.layout_path_model = 'lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config'
cfg.layout_label_map = None
cfg.layout_model_dir = ''
cfg.layout_dict_path = './ppocr/utils/dict/layout_publaynet_dict.txt'
cfg.layout_score_threshold = 0.5
cfg.layout_nms_threshold = 0.5
cfg.mode = 'structure'
cfg.output = './output'

View File

@ -118,11 +118,11 @@ class TableSystem(hub.Module):
all_results.append([])
continue
starttime = time.time()
pred_html = self.table_sys(img)
res, _ = self.table_sys(img)
elapse = time.time() - starttime
logger.info("Predict time: {}".format(elapse))
all_results.append({'html': pred_html})
all_results.append({'html': res['html']})
return all_results
@serving
@ -138,6 +138,6 @@ class TableSystem(hub.Module):
if __name__ == '__main__':
table_system = TableSystem()
table_system._initialize()
image_path = ['./doc/table/table.jpg']
image_path = ['./ppstructure/docs/table/table.jpg']
res = table_system.predict(paths=image_path)
print(res)

View File

@ -79,7 +79,7 @@ python3 tools/export_model.py -c configs/rec/rec_r31_sar.yml -o Global.pretraine
SAR文本识别模型推理可以执行如下命令
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_sar/" --rec_image_shape="3, 48, 48, 160" --rec_char_type="ch" --rec_algorithm="SAR" --rec_char_dict_path="ppocr/utils/dict90.txt" --max_text_length=30 --use_space_char=False
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_sar/" --rec_image_shape="3, 48, 48, 160" --rec_algorithm="SAR" --rec_char_dict_path="ppocr/utils/dict90.txt" --max_text_length=30 --use_space_char=False
```
<a name="4-2"></a>

View File

@ -3,6 +3,7 @@
- [数据集汇总](#数据集汇总)
- [1. PubTabNet数据集](#1-pubtabnet数据集)
- [2. 好未来表格识别竞赛数据集](#2-好未来表格识别竞赛数据集)
- [3. WTW中文场景表格数据集](#3-wtw中文场景表格数据集)
这里整理了常用表格识别数据集,持续更新中,欢迎各位小伙伴贡献数据集~
@ -12,6 +13,7 @@
|---|---|---|
| PubTabNet |https://github.com/ibm-aur-nlp/PubTabNet| jsonl格式可直接用[pubtab_dataset.py](../../../ppocr/data/pubtab_dataset.py)加载 |
| 好未来表格识别竞赛数据集 |https://ai.100tal.com/dataset| jsonl格式可直接用[pubtab_dataset.py](../../../ppocr/data/pubtab_dataset.py)加载 |
| WTW中文场景表格数据集 |https://github.com/wangwen-whu/WTW-Dataset| 需要进行转换后才能用[pubtab_dataset.py](../../../ppocr/data/pubtab_dataset.py)加载 |
## 1. PubTabNet数据集
- **数据简介**PubTabNet数据集的训练集合中包含50万张图像验证集合中包含0.9万张图像。部分图像可视化如下所示。
@ -31,3 +33,12 @@
<img src="../../datasets/table_tal_demo/1.jpg" width="500">
<img src="../../datasets/table_tal_demo/2.jpg" width="500">
</div>
## 3. WTW中文场景表格数据集
- **数据简介**:WTW中文场景表格数据集包含表格检测和表格数据两部分数据,数据集中同时包含扫描和拍照两种场景的图像。
https://github.com/wangwen-whu/WTW-Dataset/blob/main/demo/20210816_210413.gif
<div align="center">
<img src="https://github.com/wangwen-whu/WTW-Dataset/blob/main/demo/20210816_210413.gif" width="500">
</div>

View File

@ -0,0 +1,343 @@
# 表格识别
本文提供了PaddleOCR表格识别模型的全流程指南,包括数据准备、模型训练、调优、评估、预测,各个阶段的详细说明:
- [1. 数据准备](#1-数据准备)
- [1.1. 数据集格式](#11-数据集格式)
- [1.2. 数据下载](#12-数据下载)
- [1.3. 数据集生成](#13-数据集生成)
- [2. 开始训练](#2-开始训练)
- [2.1. 启动训练](#21-启动训练)
- [2.2. 断点训练](#22-断点训练)
- [2.3. 更换Backbone 训练](#23-更换backbone-训练)
- [2.4. 混合精度训练](#24-混合精度训练)
- [2.5. 分布式训练](#25-分布式训练)
- [2.6. 其他训练环境](#26-其他训练环境)
- [2.7. 模型微调](#27-模型微调)
- [3. 模型评估与预测](#3-模型评估与预测)
- [3.1. 指标评估](#31-指标评估)
- [3.2. 测试表格结构识别效果](#32-测试表格结构识别效果)
- [4. 模型导出与预测](#4-模型导出与预测)
- [4.1 模型导出](#41-模型导出)
- [4.2 模型预测](#42-模型预测)
- [5. FAQ](#5-faq)
# 1. 数据准备
## 1.1. 数据集格式
PaddleOCR 表格识别模型数据集格式如下:
```txt
img_label # 每张图片标注经过json.dumps()之后的字符串
...
img_label
```
每一行的json格式为:
```txt
{
'filename': PMC5755158_010_01.png, # 图像名
'split': train, # 图像属于训练集还是验证集
'imgid': 0, # 图像的index
'html': {
'structure': {'tokens': ['<thead>', '<tr>', '<td>', ...]}, # 表格的HTML字符串
'cell': [
{
'tokens': ['P', 'a', 'd', 'd', 'l', 'e', 'P', 'a', 'd', 'd', 'l', 'e'], # 表格中的单个文本
'bbox': [x0, y0, x1, y1] # 表格中的单个文本的坐标
}
]
}
}
```
训练数据的默认存储路径是 `PaddleOCR/train_data`,如果您的磁盘上已有数据集,只需创建软链接至数据集目录:
```
# linux and mac os
ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/dataset
# windows
mklink /d <path/to/paddle_ocr>/train_data/dataset <path/to/dataset>
```
## 1.2. 数据下载
公开数据集下载可参考 [table_datasets](dataset/table_datasets.md)。
## 1.3. 数据集生成
使用[TableGeneration](https://github.com/WenmuZhou/TableGeneration)可进行扫描表格图像的生成。
TableGeneration是一个开源表格数据集生成工具其通过浏览器渲染的方式对html字符串进行渲染后获得表格图像。部分样张如下
|类型|样例|
|---|---|
|简单表格|![](https://raw.githubusercontent.com/WenmuZhou/TableGeneration/main/imgs/simple.jpg)|
|彩色表格|![](https://raw.githubusercontent.com/WenmuZhou/TableGeneration/main/imgs/color.jpg)|
# 2. 开始训练
PaddleOCR提供了训练脚本、评估脚本和预测脚本,本节将以 [SLANet](../../configs/table/SLANet.yml) 模型训练PubTabNet英文数据集为例:
## 2.1. 启动训练
*如果您安装的是cpu版本请将配置文件中的 `use_gpu` 字段修改为false*
```
# GPU训练 支持单卡,多卡训练
# 训练日志会自动保存为 "{save_model_dir}" 下的train.log
#单卡训练(训练周期长,不建议)
python3 tools/train.py -c configs/table/SLANet.yml
#多卡训练,通过--gpus参数指定卡号
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/table/SLANet.yml
```
正常启动训练后会看到以下log输出
```
[2022/08/16 03:07:33] ppocr INFO: epoch: [1/400], global_step: 20, lr: 0.000100, acc: 0.000000, loss: 3.915012, structure_loss: 3.229450, loc_loss: 0.670590, avg_reader_cost: 2.63382 s, avg_batch_cost: 6.32390 s, avg_samples: 48.0, ips: 7.59025 samples/s, eta: 9 days, 2:29:27
[2022/08/16 03:08:41] ppocr INFO: epoch: [1/400], global_step: 40, lr: 0.000100, acc: 0.000000, loss: 1.750859, structure_loss: 1.082116, loc_loss: 0.652822, avg_reader_cost: 0.02533 s, avg_batch_cost: 3.37251 s, avg_samples: 48.0, ips: 14.23271 samples/s, eta: 6 days, 23:28:43
[2022/08/16 03:09:46] ppocr INFO: epoch: [1/400], global_step: 60, lr: 0.000100, acc: 0.000000, loss: 1.395154, structure_loss: 0.776803, loc_loss: 0.625030, avg_reader_cost: 0.02550 s, avg_batch_cost: 3.26261 s, avg_samples: 48.0, ips: 14.71214 samples/s, eta: 6 days, 5:11:48
```
log 中自动打印如下信息:
| 字段 | 含义 |
| :----: | :------: |
| epoch | 当前迭代轮次 |
| global_step | 当前迭代次数 |
| lr | 当前学习率 |
| acc | 当前batch的准确率 |
| loss | 当前损失函数 |
| structure_loss | 表格结构损失值 |
| loc_loss | 单元格坐标损失值 |
| avg_reader_cost | 当前 batch 数据处理耗时 |
| avg_batch_cost | 当前 batch 总耗时 |
| avg_samples | 当前 batch 内的样本数 |
| ips | 每秒处理图片的数量 |
PaddleOCR支持训练和评估交替进行, 可以在 `configs/table/SLANet.yml` 中修改 `eval_batch_step` 设置评估频率默认每1000个iter评估一次。评估过程中默认将最佳acc模型保存为 `output/SLANet/best_accuracy`
如果验证集很大,测试将会比较耗时,建议减少评估次数,或训练完再进行评估。
**提示:** 可通过 -c 参数选择 `configs/table/` 路径下的多种模型配置进行训练PaddleOCR支持的表格识别算法可以参考[前沿算法列表](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/algorithm_overview.md#3-%E8%A1%A8%E6%A0%BC%E8%AF%86%E5%88%AB%E7%AE%97%E6%B3%95)
**注意,预测/评估时的配置文件请务必与训练一致。**
## 2.2. 断点训练
如果训练程序中断,希望加载训练中断的模型从而恢复训练,可以通过指定Global.checkpoints指定要加载的模型路径:
```shell
python3 tools/train.py -c configs/table/SLANet.yml -o Global.checkpoints=./your/trained/model
```
**注意**`Global.checkpoints`的优先级高于`Global.pretrained_model`的优先级,即同时指定两个参数时,优先加载`Global.checkpoints`指定的模型,如果`Global.checkpoints`指定的模型路径有误,会加载`Global.pretrained_model`指定的模型。
## 2.3. 更换Backbone 训练
PaddleOCR将网络划分为四部分分别在[ppocr/modeling](../../ppocr/modeling)下。 进入网络的数据将按照顺序(transforms->backbones->necks->heads)依次通过这四个部分。
```bash
├── architectures # 网络的组网代码
├── transforms # 网络的图像变换模块
├── backbones # 网络的特征提取模块
├── necks # 网络的特征增强模块
└── heads # 网络的输出模块
```
如果要更换的Backbone 在PaddleOCR中有对应实现直接修改配置yml文件中`Backbone`部分的参数即可。
如果要使用新的Backbone更换backbones的例子如下:
1. 在 [ppocr/modeling/backbones](../../ppocr/modeling/backbones) 文件夹下新建文件如my_backbone.py。
2. 在 my_backbone.py 文件内添加相关代码,示例代码如下:
```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class MyBackbone(nn.Layer):
def __init__(self, *args, **kwargs):
super(MyBackbone, self).__init__()
# your init code
self.conv = nn.xxxx
def forward(self, inputs):
# your network forward
y = self.conv(inputs)
return y
```
3. 在 [ppocr/modeling/backbones/\__init\__.py](../../ppocr/modeling/backbones/__init__.py)文件内导入添加的`MyBackbone`模块，然后修改配置文件中Backbone进行配置即可使用，格式如下:
```yaml
Backbone:
name: MyBackbone
args1: args1
```
**注意**:如果要更换网络的其他模块,可以参考[文档](./add_new_algorithm.md)。
## 2.4. 混合精度训练
如果您想进一步加快训练速度,可以使用[自动混合精度训练](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/01_paddle2.0_introduction/basic_concept/amp_cn.html) 以单机单卡为例,命令如下:
```shell
python3 tools/train.py -c configs/table/SLANet.yml \
-o Global.pretrained_model=./pretrain_models/SLANet/best_accuracy \
Global.use_amp=True Global.scale_loss=1024.0 Global.use_dynamic_loss_scaling=True
```
## 2.5. 分布式训练
多机多卡训练时,通过 `--ips` 参数设置使用的机器IP地址通过 `--gpus` 参数设置使用的GPU ID
```bash
python3 -m paddle.distributed.launch --ips="xx.xx.xx.xx,xx.xx.xx.xx" --gpus '0,1,2,3' tools/train.py -c configs/table/SLANet.yml \
-o Global.pretrained_model=./pretrain_models/SLANet/best_accuracy
```
**注意:** （1）采用多机多卡训练时，需要替换上面命令中的ips值为您机器的地址，机器之间需要能够相互ping通；（2）训练时需要在多个机器上分别启动命令。查看机器ip地址的命令为`ifconfig`；（3）更多关于分布式训练的性能优势等信息，请参考：[分布式训练教程](./distributed_training.md)。
## 2.6. 其他训练环境
- Windows GPU/CPU
在Windows平台上与Linux平台略有不同:
Windows平台只支持`单卡`的训练与预测，指定GPU进行训练`set CUDA_VISIBLE_DEVICES=0`
在Windows平台DataLoader只支持单进程模式因此需要设置 `num_workers` 为0;
- macOS
不支持GPU模式需要在配置文件中设置`use_gpu`为False其余训练评估预测命令与Linux GPU完全相同。
- Linux DCU
DCU设备上运行需要设置环境变量 `export HIP_VISIBLE_DEVICES=0,1,2,3`其余训练评估预测命令与Linux GPU完全相同。
## 2.7. 模型微调
实际使用过程中,建议加载官方提供的预训练模型,在自己的数据集中进行微调,关于模型的微调方法,请参考:[模型微调教程](./finetune.md)。
# 3. 模型评估与预测
## 3.1. 指标评估
训练中模型参数默认保存在`Global.save_model_dir`目录下。在评估指标时,需要设置`Global.checkpoints`指向保存的参数文件。评估数据集可以通过 `configs/table/SLANet.yml` 修改Eval中的 `label_file_list` 设置。
```
# GPU 评估, Global.checkpoints 为待测权重
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/table/SLANet.yml -o Global.checkpoints={path/to/weights}/best_accuracy
```
运行完成后会输出模型的acc指标如对英文表格识别模型进行评估会见到如下输出。
```bash
[2022/08/16 07:59:55] ppocr INFO: acc:0.7622245132160782
[2022/08/16 07:59:55] ppocr INFO: fps:30.991640622573044
```
## 3.2. 测试表格结构识别效果
使用 PaddleOCR 训练好的模型,可以通过以下脚本进行快速预测。
默认预测图片存储在 `infer_img` 里，通过 `-o Global.pretrained_model` 加载训练好的参数文件：
根据配置文件中设置的 `save_model_dir` 和 `save_epoch_step` 字段，会有以下几种参数被保存下来：
```
output/SLANet/
├── best_accuracy.pdopt
├── best_accuracy.pdparams
├── best_accuracy.states
├── config.yml
├── latest.pdopt
├── latest.pdparams
├── latest.states
└── train.log
```
其中 best_accuracy.* 是评估集上的最优模型latest.* 是最后一个epoch的模型。
```
# 预测表格图像
python3 tools/infer_table.py -c configs/table/SLANet.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=ppstructure/docs/table/table.jpg
```
预测图片:
![](../../ppstructure/docs/table/table.jpg)
得到输入图像的预测结果:
```
['<html>', '<body>', '<table>', '<thead>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</thead>', '<tbody>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</tbody>', '</table>', '</body>', '</html>'],[[320.0562438964844, 197.83375549316406, 350.0928955078125, 214.4309539794922], ... , [318.959228515625, 271.0166931152344, 353.7394104003906, 286.4538269042969]]
```
单元格坐标可视化结果为
![](../../ppstructure/docs/imgs/slanet_result.jpg)
# 4. 模型导出与预测
## 4.1 模型导出
inference 模型(`paddle.jit.save`保存的模型)
一般是模型训练,把模型结构和模型参数保存在文件中的固化模型,多用于预测部署场景。
训练过程中保存的模型是checkpoints模型保存的只有模型的参数多用于恢复训练等。
与checkpoints模型相比inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合于实际系统集成。
表格识别模型转inference模型与文字检测识别的方式相同如下
```
# -c 后面设置训练算法的yml配置文件
# -o 配置可选参数
# Global.pretrained_model 参数设置待转换的训练模型地址,不用添加文件后缀 .pdmodel.pdopt或.pdparams。
# Global.save_inference_dir参数设置转换的模型将保存的地址。
python3 tools/export_model.py -c configs/table/SLANet.yml -o Global.pretrained_model=./pretrain_models/SLANet/best_accuracy Global.save_inference_dir=./inference/SLANet/
```
转换成功后,在目录下有三个文件:
```
inference/SLANet/
├── inference.pdiparams # inference模型的参数文件
├── inference.pdiparams.info # inference模型的参数信息可忽略
└── inference.pdmodel # inference模型的program文件
```
## 4.2 模型预测
模型导出后使用如下命令即可完成inference模型的预测
```bash
python3.7 table/predict_structure.py \
--table_model_dir={path/to/inference model} \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict_ch.txt \
--image_dir=docs/table/table.jpg \
--output=../output/table
```
预测图片:
![](../../ppstructure/docs/table/table.jpg)
得到输入图像的预测结果:
```
['<html>', '<body>', '<table>', '<thead>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</thead>', '<tbody>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</tbody>', '</table>', '</body>', '</html>'],[[320.0562438964844, 197.83375549316406, 350.0928955078125, 214.4309539794922], ... , [318.959228515625, 271.0166931152344, 353.7394104003906, 286.4538269042969]]
```
单元格坐标可视化结果为
![](../../ppstructure/docs/imgs/slanet_result.jpg)
# 5. FAQ
Q1: 训练模型转inference 模型之后预测效果不一致?
**A**：此类问题出现较多，问题多是trained model预测时候的预处理、后处理参数和inference model预测的时候的预处理、后处理参数不一致导致的。可以对比训练使用的配置文件中的预处理、后处理和预测时是否存在差异。

View File

@ -79,7 +79,7 @@ python3 tools/export_model.py -c configs/rec/rec_r31_sar.yml -o Global.pretraine
For SAR text recognition model inference, the following commands can be executed:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_sar/" --rec_image_shape="3, 48, 48, 160" --rec_char_type="ch" --rec_algorithm="SAR" --rec_char_dict_path="ppocr/utils/dict90.txt" --max_text_length=30 --use_space_char=False
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_sar/" --rec_image_shape="3, 48, 48, 160" --rec_algorithm="SAR" --rec_char_dict_path="ppocr/utils/dict90.txt" --max_text_length=30 --use_space_char=False
```
<a name="4-2"></a>

View File

@ -3,6 +3,7 @@
- [Dataset Summary](#dataset-summary)
- [1. PubTabNet](#1-pubtabnet)
- [2. TAL Table Recognition Competition Dataset](#2-tal-table-recognition-competition-dataset)
- [3. WTW Chinese scene table dataset](#3-wtw-chinese-scene-table-dataset)
Here are the commonly used table recognition datasets, which are being updated continuously. Welcome to contribute datasets~
@ -12,6 +13,7 @@ Here are the commonly used table recognition datasets, which are being updated c
|---|---|---|
| PubTabNet |https://github.com/ibm-aur-nlp/PubTabNet| jsonl format, which can be loaded directly with [pubtab_dataset.py](../../../ppocr/data/pubtab_dataset.py) |
| TAL Table Recognition Competition Dataset |https://ai.100tal.com/dataset| jsonl format, which can be loaded directly with [pubtab_dataset.py](../../../ppocr/data/pubtab_dataset.py) |
| WTW Chinese scene table dataset |https://github.com/wangwen-whu/WTW-Dataset| Conversion is required to load with [pubtab_dataset.py](../../../ppocr/data/pubtab_dataset.py)|
## 1. PubTabNet
- **Data Introduction**The training set of the PubTabNet dataset contains 500,000 images and the validation set contains 9000 images. Part of the image visualization is shown below.
@ -30,3 +32,11 @@ Here are the commonly used table recognition datasets, which are being updated c
<img src="../../datasets/table_tal_demo/1.jpg" width="500">
<img src="../../datasets/table_tal_demo/2.jpg" width="500">
</div>
## 3. WTW Chinese scene table dataset
- **Data Introduction**: The WTW Chinese scene table dataset consists of two parts: table detection and table data. The dataset contains images of two scenes, scanned and photographed.
https://github.com/wangwen-whu/WTW-Dataset/blob/main/demo/20210816_210413.gif
<div align="center">
<img src="https://github.com/wangwen-whu/WTW-Dataset/blob/main/demo/20210816_210413.gif?raw=true" width="500">
</div>

View File

@ -3,14 +3,14 @@
**Note:** This tutorial mainly introduces the usage of PP-OCR series models, please refer to [PP-Structure Quick Start](../../ppstructure/docs/quickstart_en.md) for the quick use of document analysis related functions.
- [1. Installation](#1-installation)
- [1.1 Install PaddlePaddle](#11-install-paddlepaddle)
- [1.2 Install PaddleOCR Whl Package](#12-install-paddleocr-whl-package)
- [1.1 Install PaddlePaddle](#11-install-paddlepaddle)
- [1.2 Install PaddleOCR Whl Package](#12-install-paddleocr-whl-package)
- [2. Easy-to-Use](#2-easy-to-use)
- [2.1 Use by Command Line](#21-use-by-command-line)
- [2.1.1 Chinese and English Model](#211-chinese-and-english-model)
- [2.1.2 Multi-language Model](#212-multi-language-model)
- [2.2 Use by Code](#22-use-by-code)
- [2.2.1 Chinese & English Model and Multilingual Model](#221-chinese--english-model-and-multilingual-model)
- [2.1 Use by Command Line](#21-use-by-command-line)
- [2.1.1 Chinese and English Model](#211-chinese-and-english-model)
- [2.1.2 Multi-language Model](#212-multi-language-model)
- [2.2 Use by Code](#22-use-by-code)
- [2.2.1 Chinese & English Model and Multilingual Model](#221-chinese--english-model-and-multilingual-model)
- [3. Summary](#3-summary)
@ -51,12 +51,6 @@ pip install "paddleocr>=2.0.1" # Recommend to use version 2.0.1+
Reference: [Solve shapely installation on windows](https://stackoverflow.com/questions/44398265/install-shapely-oserror-winerror-126-the-specified-module-could-not-be-found)
- **For layout analysis users**, run the following command to install **Layout-Parser**
```bash
pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
```
<a name="2-easy-to-use"></a>
## 2. Easy-to-Use

View File

@ -0,0 +1,354 @@
# Table Recognition
This article provides a full-process guide for the PaddleOCR table recognition model, including data preparation, model training, tuning, evaluation, prediction, and detailed descriptions of each stage:
- [1. Data Preparation](#1-data-preparation)
- [1.1. DataSet Format](#11-dataset-format)
- [1.2. Data Download](#12-data-download)
- [1.3. Dataset Generation](#13-dataset-generation)
- [2. Training](#2-training)
- [2.1. Start Training](#21-start-training)
- [2.2. Resume Training](#22-resume-training)
- [2.3. Training with New Backbone](#23-training-with-new-backbone)
- [2.4. Mixed Precision Training](#24-mixed-precision-training)
- [2.5. Distributed Training](#25-distributed-training)
- [2.6. Training on other platform(Windows/macOS/Linux DCU)](#26-training-on-other-platformwindowsmacoslinux-dcu)
- [2.7. Fine-tuning](#27-fine-tuning)
- [3. Evaluation and Test](#3-evaluation-and-test)
- [3.1. Evaluation](#31-evaluation)
- [3.2. Test table structure recognition effect](#32-test-table-structure-recognition-effect)
- [4. Model export and prediction](#4-model-export-and-prediction)
- [4.1 Model export](#41-model-export)
- [4.2 Prediction](#42-prediction)
- [5. FAQ](#5-faq)
# 1. Data Preparation
## 1.1. DataSet Format
The format of the PaddleOCR table recognition model dataset is as follows:
```txt
img_label # Each image is marked with a string after json.dumps()
...
img_label
```
The json format of each line is:
```txt
{
'filename': PMC5755158_010_01.png,# image name
'split': train, # whether the image belongs to the training set or the validation set
'imgid': 0,# index of image
'html': {
'structure': {'tokens': ['<thead>', '<tr>', '<td>', ...]}, # HTML string of the table
'cell': [
{
'tokens': ['P', 'a', 'd', 'd', 'l', 'e', 'P', 'a', 'd', 'd', 'l', 'e'], # text in cell
'bbox': [x0, y0, x1, y1] # bbox of cell
}
]
}
}
```
The default storage path for training data is `PaddleOCR/train_data`, if you already have a dataset on disk, just create a soft link to the dataset directory:
```
# linux and mac os
ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/dataset
# windows
mklink /d <path/to/paddle_ocr>/train_data/dataset <path/to/dataset>
```
## 1.2. Data Download
To download public datasets, refer to [table_datasets](dataset/table_datasets_en.md).
## 1.3. Dataset Generation
Use [TableGeneration](https://github.com/WenmuZhou/TableGeneration) to generate scanned table images.
TableGeneration is an open source table dataset generation tool, which renders html strings through browser rendering to obtain table images.
Some samples are as follows:
|Type|Sample|
|---|---|
|Simple Table|![](https://raw.githubusercontent.com/WenmuZhou/TableGeneration/main/imgs/simple.jpg)|
|Simple Color Table|![](https://raw.githubusercontent.com/WenmuZhou/TableGeneration/main/imgs/color.jpg)|
# 2. Training
PaddleOCR provides training scripts, evaluation scripts, and prediction scripts. In this section, the [SLANet](../../configs/table/SLANet.yml) model will be used as an example:
## 2.1. Start Training
*If you are installing the cpu version, please modify the `use_gpu` field in the configuration file to false*
```
# GPU training Support single card and multi-card training
# The training log will be automatically saved as train.log under "{save_model_dir}"
# specify the single card training(Long training time, not recommended)
python3 tools/train.py -c configs/table/SLANet.yml
# specify the card number through --gpus
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/table/SLANet.yml
```
After starting training normally, you will see the following log output:
```
[2022/08/16 03:07:33] ppocr INFO: epoch: [1/400], global_step: 20, lr: 0.000100, acc: 0.000000, loss: 3.915012, structure_loss: 3.229450, loc_loss: 0.670590, avg_reader_cost: 2.63382 s, avg_batch_cost: 6.32390 s, avg_samples: 48.0, ips: 7.59025 samples/s, eta: 9 days, 2:29:27
[2022/08/16 03:08:41] ppocr INFO: epoch: [1/400], global_step: 40, lr: 0.000100, acc: 0.000000, loss: 1.750859, structure_loss: 1.082116, loc_loss: 0.652822, avg_reader_cost: 0.02533 s, avg_batch_cost: 3.37251 s, avg_samples: 48.0, ips: 14.23271 samples/s, eta: 6 days, 23:28:43
[2022/08/16 03:09:46] ppocr INFO: epoch: [1/400], global_step: 60, lr: 0.000100, acc: 0.000000, loss: 1.395154, structure_loss: 0.776803, loc_loss: 0.625030, avg_reader_cost: 0.02550 s, avg_batch_cost: 3.26261 s, avg_samples: 48.0, ips: 14.71214 samples/s, eta: 6 days, 5:11:48
```
The following information is automatically printed in the log:
| Field | Meaning |
| :----: | :------: |
| epoch | current iteration round |
| global_step | current iteration count |
| lr | current learning rate |
| acc | The accuracy of the current batch |
| loss | current loss function |
| structure_loss | Table Structure Loss Values |
| loc_loss | Cell Coordinate Loss Value |
| avg_reader_cost | Current batch data processing time |
| avg_batch_cost | The total time spent in the current batch |
| avg_samples | The number of samples in the current batch |
| ips | Number of images processed per second |
PaddleOCR supports alternating training and evaluation. You can modify `eval_batch_step` in `configs/table/SLANet.yml` to set the evaluation frequency. By default, it is evaluated once every 1000 iters. During the evaluation process, the best acc model is saved as `output/SLANet/best_accuracy` by default.
If the validation set is large, the test will be time-consuming. It is recommended to reduce the number of evaluations, or perform evaluation after training.
**Tips:** You can use the -c parameter to select various model configurations under the `configs/table/` path for training. For the table recognition algorithms supported by PaddleOCR, please refer to [Table Algorithms List](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_en/algorithm_overview_en.md#3):
**Note that the configuration file used for prediction/evaluation must be the same as the one used for training.**
## 2.2. Resume Training
If the training program is interrupted and you want to load the interrupted model to resume training, you can specify the path of the model to be loaded via `Global.checkpoints`:
```shell
python3 tools/train.py -c configs/table/SLANet.yml -o Global.checkpoints=./your/trained/model
```
**Note**: The priority of `Global.checkpoints` is higher than that of `Global.pretrained_model`, that is, when both parameters are specified at the same time, the model specified by `Global.checkpoints` will be loaded first. If the model path specified by `Global.checkpoints` is incorrect, the model specified by `Global.pretrained_model` will be loaded.
## 2.3. Training with New Backbone
The network part completes the construction of the network, and PaddleOCR divides the network into four parts, which are under [ppocr/modeling](../../ppocr/modeling). The data entering the network will pass through these four parts in sequence (transforms->backbones->necks->heads).
```bash
├── architectures # Code for building network
├── transforms # Image Transformation Module
├── backbones # Feature extraction module
├── necks # Feature enhancement module
└── heads # Output module
```
If the Backbone to be replaced has a corresponding implementation in PaddleOCR, you can directly modify the parameters in the `Backbone` part of the configuration yml file.
However, if you want to use a new Backbone, an example of replacing the backbones is as follows:
1. Create a new file under the [ppocr/modeling/backbones](../../ppocr/modeling/backbones) folder, such as my_backbone.py.
2. Add code in the my_backbone.py file, the sample code is as follows:
```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class MyBackbone(nn.Layer):
def __init__(self, *args, **kwargs):
super(MyBackbone, self).__init__()
# your init code
self.conv = nn.xxxx
def forward(self, inputs):
# your network forward
y = self.conv(inputs)
return y
```
3. Import the added module in the [ppocr/modeling/backbones/\__init\__.py](../../ppocr/modeling/backbones/__init__.py) file.
After adding the four-part modules of the network, you only need to configure them in the configuration file to use, such as:
```yaml
Backbone:
name: MyBackbone
args1: args1
```
**NOTE**: More details about replacing the Backbone and other modules can be found in [doc](add_new_algorithm_en.md).
## 2.4. Mixed Precision Training
If you want to speed up your training further, you can use [Auto Mixed Precision Training](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/01_paddle2.0_introduction/basic_concept/amp_cn.html), taking a single machine and a single gpu as an example, the commands are as follows:
```shell
python3 tools/train.py -c configs/table/SLANet.yml \
-o Global.pretrained_model=./pretrain_models/SLANet/best_accuracy \
Global.use_amp=True Global.scale_loss=1024.0 Global.use_dynamic_loss_scaling=True
```
## 2.5. Distributed Training
During multi-machine multi-gpu training, use the `--ips` parameter to set the used machine IP address, and the `--gpus` parameter to set the used GPU ID:
```bash
python3 -m paddle.distributed.launch --ips="xx.xx.xx.xx,xx.xx.xx.xx" --gpus '0,1,2,3' tools/train.py -c configs/table/SLANet.yml \
-o Global.pretrained_model=./pretrain_models/SLANet/best_accuracy
```
**Note:** (1) When using multi-machine and multi-gpu training, you need to replace the ips value in the above command with the address of your machine, and the machines need to be able to ping each other. (2) Training needs to be launched separately on multiple machines. The command to view the ip address of the machine is `ifconfig`. (3) For more details about the distributed training speedup ratio, please refer to [Distributed Training Tutorial](./distributed_training_en.md).
## 2.6. Training on other platform(Windows/macOS/Linux DCU)
- Windows GPU/CPU
The Windows platform is slightly different from the Linux platform:
Windows platform only supports `single gpu` training and inference, specify GPU for training `set CUDA_VISIBLE_DEVICES=0`
On the Windows platform, DataLoader only supports single-process mode, so you need to set `num_workers` to 0;
- macOS
GPU mode is not supported, you need to set `use_gpu` to False in the configuration file, and the rest of the training evaluation prediction commands are exactly the same as Linux GPU.
- Linux DCU
Running on a DCU device requires setting the environment variable `export HIP_VISIBLE_DEVICES=0,1,2,3`, and the rest of the training and evaluation prediction commands are exactly the same as the Linux GPU.
## 2.7. Fine-tuning
In the actual use process, it is recommended to load the officially provided pre-training model and fine-tune it in your own data set. For the fine-tuning method of the table recognition model, please refer to: [Model fine-tuning tutorial](./finetune.md).
# 3. Evaluation and Test
## 3.1. Evaluation
The model parameters during training are saved in the `Global.save_model_dir` directory by default. When evaluating metrics, you need to set `Global.checkpoints` to point to the saved parameter file. The evaluation dataset can be changed by modifying the `label_file_list` setting under `Eval` in `configs/table/SLANet.yml`.
```
# GPU evaluation, Global.checkpoints is the weight to be tested
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/table/SLANet.yml -o Global.checkpoints={path/to/weights}/best_accuracy
```
After the operation is completed, the acc indicator of the model will be output. If you evaluate the English table recognition model, you will see the following output.
```bash
[2022/08/16 07:59:55] ppocr INFO: acc:0.7622245132160782
[2022/08/16 07:59:55] ppocr INFO: fps:30.991640622573044
```
## 3.2. Test table structure recognition effect
Using the model trained by PaddleOCR, you can quickly get prediction through the following script.
The default prediction picture is stored in `infer_img`, and the trained weights are specified via `-o Global.pretrained_model`.
According to the `save_model_dir` and `save_epoch_step` fields set in the configuration file, the following parameters will be saved:
```
output/SLANet/
├── best_accuracy.pdopt
├── best_accuracy.pdparams
├── best_accuracy.states
├── config.yml
├── latest.pdopt
├── latest.pdparams
├── latest.states
└── train.log
```
Among them, best_accuracy.* is the best model on the evaluation set; latest.* is the model of the last epoch.
```
# Predict table image
python3 tools/infer_table.py -c configs/table/SLANet.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=ppstructure/docs/table/table.jpg
```
Input image:
![](../../ppstructure/docs/table/table.jpg)
Get the prediction result of the input image:
```
['<html>', '<body>', '<table>', '<thead>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</thead>', '<tbody>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</tbody>', '</table>', '</body>', '</html>'],[[320.0562438964844, 197.83375549316406, 350.0928955078125, 214.4309539794922], ... , [318.959228515625, 271.0166931152344, 353.7394104003906, 286.4538269042969]]
```
The cell coordinates are visualized as
![](../../ppstructure/docs/imgs/slanet_result.jpg)
# 4. Model export and prediction
## 4.1 Model export
inference model (model saved by `paddle.jit.save`)
Generally, it is a frozen model produced after training that saves both the model structure and the model parameters in a file, and is mostly used in prediction and deployment scenarios.
The model saved during the training process is the checkpoints model, and only the parameters of the model are saved, which are mostly used to resume training.
Compared with the checkpoints model, the inference model additionally saves the structural information of the model. It has superior performance in deployment and accelerated inference, is flexible and convenient, and is suitable for actual system integration.
The way to convert the table recognition model to the inference model is the same as for text detection and recognition, as follows:
```
# -c Set the training algorithm yml configuration file
# -o Set optional parameters
# Global.pretrained_model parameter Set the training model address to be converted without adding the file suffix .pdmodel, .pdopt or .pdparams.
# Global.save_inference_dir Set the address where the converted model will be saved.
python3 tools/export_model.py -c configs/table/SLANet.yml -o Global.pretrained_model=./pretrain_models/SLANet/best_accuracy Global.save_inference_dir=./inference/SLANet/
```
After the conversion is successful, there are three files in the model save directory:
```
inference/SLANet/
├── inference.pdiparams # The parameter file of inference model
├── inference.pdiparams.info # The parameter information of inference model, which can be ignored
└── inference.pdmodel # The program file of model
```
## 4.2 Prediction
After the model is exported, use the following command to complete the prediction of the inference model
```bash
python3.7 table/predict_structure.py \
--table_model_dir={path/to/inference model} \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict_ch.txt \
--image_dir=docs/table/table.jpg \
--output=../output/table
```
Input image:
![](../../ppstructure/docs/table/table.jpg)
Get the prediction result of the input image:
```
['<html>', '<body>', '<table>', '<thead>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</thead>', '<tbody>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</tbody>', '</table>', '</body>', '</html>'],[[320.0562438964844, 197.83375549316406, 350.0928955078125, 214.4309539794922], ... , [318.959228515625, 271.0166931152344, 353.7394104003906, 286.4538269042969]]
```
The cell coordinates are visualized as
![](../../ppstructure/docs/imgs/slanet_result.jpg)
# 5. FAQ
Q1: After the trained model is converted to an inference model, the prediction results are inconsistent?
**A**: This kind of problem is common. It is mostly caused by the preprocessing and postprocessing parameters used when predicting with the trained model being inconsistent with those used when predicting with the inference model. Compare the preprocessing and postprocessing settings in the configuration file used for training with the ones used at prediction time to check whether they differ.

View File

@ -47,14 +47,14 @@ __all__ = [
]
SUPPORT_DET_MODEL = ['DB']
VERSION = '2.5.0.3'
VERSION = '2.6'
SUPPORT_REC_MODEL = ['CRNN', 'SVTR_LCNet']
BASE_DIR = os.path.expanduser("~/.paddleocr/")
DEFAULT_OCR_MODEL_VERSION = 'PP-OCRv3'
SUPPORT_OCR_MODEL_VERSION = ['PP-OCR', 'PP-OCRv2', 'PP-OCRv3']
DEFAULT_STRUCTURE_MODEL_VERSION = 'PP-STRUCTURE'
SUPPORT_STRUCTURE_MODEL_VERSION = ['PP-STRUCTURE']
DEFAULT_STRUCTURE_MODEL_VERSION = 'PP-Structurev2'
SUPPORT_STRUCTURE_MODEL_VERSION = ['PP-Structure', 'PP-Structurev2']
MODEL_URLS = {
'OCR': {
'PP-OCRv3': {
@ -263,7 +263,7 @@ MODEL_URLS = {
}
},
'STRUCTURE': {
'PP-STRUCTURE': {
'PP-Structure': {
'table': {
'en': {
'url':
@ -271,6 +271,27 @@ MODEL_URLS = {
'dict_path': 'ppocr/utils/dict/table_structure_dict.txt'
}
}
},
'PP-Structurev2': {
'table': {
'en': {
'url':
'https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_infer.tar',
'dict_path': 'ppocr/utils/dict/table_structure_dict.txt'
},
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar',
'dict_path': 'ppocr/utils/dict/table_structure_dict_ch.txt'
}
},
'layout': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_layout_infer.tar',
'dict_path': 'ppocr/utils/dict/layout_publaynet_dict.txt'
}
}
}
}
}
@ -298,12 +319,15 @@ def parse_args(mMain=True):
"--structure_version",
type=str,
choices=SUPPORT_STRUCTURE_MODEL_VERSION,
default='PP-STRUCTURE',
default='PP-Structurev2',
help='Model version, the current model support list is as follows:'
' 1. STRUCTURE Support en table structure model.')
' 1. PP-Structure Support en table structure model.'
' 2. PP-Structurev2 Support ch and en table structure model.')
for action in parser._actions:
if action.dest in ['rec_char_dict_path', 'table_char_dict_path']:
if action.dest in [
'rec_char_dict_path', 'table_char_dict_path', 'layout_dict_path'
]:
action.default = None
if mMain:
return parser.parse_args()
@ -477,7 +501,7 @@ class PaddleOCR(predict_system.TextSystem):
if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
if det and rec:
dt_boxes, rec_res = self.__call__(img, cls)
dt_boxes, rec_res, _ = self.__call__(img, cls)
return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
elif det and not rec:
dt_boxes, elapse = self.text_detector(img)
@ -506,6 +530,12 @@ class PPStructure(StructureSystem):
if not params.show_log:
logger.setLevel(logging.INFO)
lang, det_lang = parse_lang(params.lang)
if lang == 'ch':
table_lang = 'ch'
else:
table_lang = 'en'
if params.structure_version == 'PP-Structure':
params.merge_no_span_structure = False
# init model dir
det_model_config = get_model_config('OCR', params.ocr_version, 'det',
@ -520,14 +550,20 @@ class PPStructure(StructureSystem):
params.rec_model_dir,
os.path.join(BASE_DIR, 'whl', 'rec', lang), rec_model_config['url'])
table_model_config = get_model_config(
'STRUCTURE', params.structure_version, 'table', 'en')
'STRUCTURE', params.structure_version, 'table', table_lang)
params.table_model_dir, table_url = confirm_model_dir_url(
params.table_model_dir,
os.path.join(BASE_DIR, 'whl', 'table'), table_model_config['url'])
layout_model_config = get_model_config(
'STRUCTURE', params.structure_version, 'layout', 'ch')
params.layout_model_dir, layout_url = confirm_model_dir_url(
params.layout_model_dir,
os.path.join(BASE_DIR, 'whl', 'layout'), layout_model_config['url'])
# download model
maybe_download(params.det_model_dir, det_url)
maybe_download(params.rec_model_dir, rec_url)
maybe_download(params.table_model_dir, table_url)
maybe_download(params.layout_model_dir, layout_url)
if params.rec_char_dict_path is None:
params.rec_char_dict_path = str(
@ -535,7 +571,9 @@ class PPStructure(StructureSystem):
if params.table_char_dict_path is None:
params.table_char_dict_path = str(
Path(__file__).parent / table_model_config['dict_path'])
if params.layout_dict_path is None:
params.layout_dict_path = str(
Path(__file__).parent / layout_model_config['dict_path'])
logger.debug(params)
super().__init__(params)
@ -557,7 +595,7 @@ class PPStructure(StructureSystem):
if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
res = super().__call__(img, return_ocr_result_in_table)
res, _ = super().__call__(img, return_ocr_result_in_table)
return res

View File

@ -575,7 +575,7 @@ class TableLabelEncode(AttnLabelEncode):
replace_empty_cell_token=False,
merge_no_span_structure=False,
learn_empty_box=False,
point_num=2,
loc_reg_num=4,
**kwargs):
self.max_text_len = max_text_length
self.lower = False
@ -590,6 +590,12 @@ class TableLabelEncode(AttnLabelEncode):
line = line.decode('utf-8').strip("\n").strip("\r\n")
dict_character.append(line)
if self.merge_no_span_structure:
if "<td></td>" not in dict_character:
dict_character.append("<td></td>")
if "<td>" in dict_character:
dict_character.remove("<td>")
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
@ -597,7 +603,7 @@ class TableLabelEncode(AttnLabelEncode):
self.idx2char = {v: k for k, v in self.dict.items()}
self.character = dict_character
self.point_num = point_num
self.loc_reg_num = loc_reg_num
self.pad_idx = self.dict[self.beg_str]
self.start_idx = self.dict[self.beg_str]
self.end_idx = self.dict[self.end_str]
@ -653,7 +659,7 @@ class TableLabelEncode(AttnLabelEncode):
# encode box
bboxes = np.zeros(
(self._max_text_len, self.point_num * 2), dtype=np.float32)
(self._max_text_len, self.loc_reg_num), dtype=np.float32)
bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32)
bbox_idx = 0
@ -718,11 +724,11 @@ class TableMasterLabelEncode(TableLabelEncode):
replace_empty_cell_token=False,
merge_no_span_structure=False,
learn_empty_box=False,
point_num=2,
loc_reg_num=4,
**kwargs):
super(TableMasterLabelEncode, self).__init__(
max_text_length, character_dict_path, replace_empty_cell_token,
merge_no_span_structure, learn_empty_box, point_num, **kwargs)
merge_no_span_structure, learn_empty_box, loc_reg_num, **kwargs)
self.pad_idx = self.dict[self.pad_str]
self.unknown_idx = self.dict[self.unknown_str]
@ -743,27 +749,35 @@ class TableMasterLabelEncode(TableLabelEncode):
class TableBoxEncode(object):
def __init__(self, use_xywh=False, **kwargs):
self.use_xywh = use_xywh
def __init__(self, in_box_format='xyxy', out_box_format='xyxy', **kwargs):
assert out_box_format in ['xywh', 'xyxy', 'xyxyxyxy']
self.in_box_format = in_box_format
self.out_box_format = out_box_format
def __call__(self, data):
img_height, img_width = data['image'].shape[:2]
bboxes = data['bboxes']
if self.use_xywh and bboxes.shape[1] == 4:
bboxes = self.xyxy2xywh(bboxes)
if self.in_box_format != self.out_box_format:
if self.out_box_format == 'xywh':
if self.in_box_format == 'xyxyxyxy':
bboxes = self.xyxyxyxy2xywh(bboxes)
elif self.in_box_format == 'xyxy':
bboxes = self.xyxy2xywh(bboxes)
bboxes[:, 0::2] /= img_width
bboxes[:, 1::2] /= img_height
data['bboxes'] = bboxes
return data
def xyxyxyxy2xywh(self, boxes):
    """
    Convert 4-point boxes (x1,y1,x2,y2,x3,y3,x4,y4) to (x,y,w,h).
    Even columns are read as x coordinates and odd columns as y coordinates.
    (x,y) is the minimum x / minimum y of each box and (w,h) is the extent
    of its axis-aligned bounding rectangle.
    :param boxes: 2-D array, one box per row
    :return: float array with one (x, y, w, h) row per input box
    """
    # NOTE(review): xyxy2xywh in this class emits the box centre as (x, y),
    # while this method keeps the top-left corner as the original code did.
    # Confirm which convention consumers of 'xywh' expect.
    xs = boxes[:, 0::2]
    ys = boxes[:, 1::2]
    new_bboxes = np.zeros([len(boxes), 4])
    # Reduce along axis=1 so every box gets its own extremes; without an axis
    # the reduction collapses the whole batch into a single scalar.
    new_bboxes[:, 0] = xs.min(axis=1)  # x1
    new_bboxes[:, 1] = ys.min(axis=1)  # y1
    new_bboxes[:, 2] = xs.max(axis=1) - new_bboxes[:, 0]  # w
    new_bboxes[:, 3] = ys.max(axis=1) - new_bboxes[:, 1]  # h
    return new_bboxes
def xyxy2xywh(self, bboxes):
"""
Convert coord (x1,y1,x2,y2) to (x,y,w,h).
where (x1,y1) is top-left, (x2,y2) is bottom-right.
(x,y) is bbox center and (w,h) is width and height.
:param bboxes: (x1, y1, x2, y2)
:return:
"""
new_bboxes = np.empty_like(bboxes)
new_bboxes[:, 0] = (bboxes[:, 0] + bboxes[:, 2]) / 2 # x center
new_bboxes[:, 1] = (bboxes[:, 1] + bboxes[:, 3]) / 2 # y center

View File

@ -206,7 +206,7 @@ class ResizeTableImage(object):
data['bboxes'] = data['bboxes'] * ratio
data['image'] = resize_img
data['src_img'] = img
data['shape'] = np.array([resize_h, resize_w, ratio, ratio])
data['shape'] = np.array([height, width, ratio, ratio])
data['max_len'] = self.max_len
return data

View File

@ -52,7 +52,7 @@ from .basic_loss import DistanceLoss
from .combined_loss import CombinedLoss
# table loss
from .table_att_loss import TableAttentionLoss
from .table_att_loss import TableAttentionLoss, SLALoss
from .table_master_loss import TableMasterLoss
# vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
@ -67,7 +67,8 @@ def build_loss(config):
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss'
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
'SLALoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')

View File

@ -22,65 +22,11 @@ from paddle.nn import functional as F
class TableAttentionLoss(nn.Layer):
def __init__(self,
structure_weight,
loc_weight,
use_giou=False,
giou_weight=1.0,
**kwargs):
def __init__(self, structure_weight, loc_weight, **kwargs):
super(TableAttentionLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
self.structure_weight = structure_weight
self.loc_weight = loc_weight
self.use_giou = use_giou
self.giou_weight = giou_weight
def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
'''
:param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
:param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
:return: loss
'''
ix1 = paddle.maximum(preds[:, 0], bbox[:, 0])
iy1 = paddle.maximum(preds[:, 1], bbox[:, 1])
ix2 = paddle.minimum(preds[:, 2], bbox[:, 2])
iy2 = paddle.minimum(preds[:, 3], bbox[:, 3])
iw = paddle.clip(ix2 - ix1 + 1e-3, 0., 1e10)
ih = paddle.clip(iy2 - iy1 + 1e-3, 0., 1e10)
# overlap
inters = iw * ih
# union
uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (
preds[:, 3] - preds[:, 1] + 1e-3) + (bbox[:, 2] - bbox[:, 0] + 1e-3
) * (bbox[:, 3] - bbox[:, 1] +
1e-3) - inters + eps
# ious
ious = inters / uni
ex1 = paddle.minimum(preds[:, 0], bbox[:, 0])
ey1 = paddle.minimum(preds[:, 1], bbox[:, 1])
ex2 = paddle.maximum(preds[:, 2], bbox[:, 2])
ey2 = paddle.maximum(preds[:, 3], bbox[:, 3])
ew = paddle.clip(ex2 - ex1 + 1e-3, 0., 1e10)
eh = paddle.clip(ey2 - ey1 + 1e-3, 0., 1e10)
# enclose erea
enclose = ew * eh + eps
giou = ious - (enclose - uni) / enclose
loss = 1 - giou
if reduction == 'mean':
loss = paddle.mean(loss)
elif reduction == 'sum':
loss = paddle.sum(loss)
else:
raise NotImplementedError
return loss
def forward(self, predicts, batch):
structure_probs = predicts['structure_probs']
@ -100,20 +46,48 @@ class TableAttentionLoss(nn.Layer):
loc_targets_mask = loc_targets_mask[:, 1:, :]
loc_loss = F.mse_loss(loc_preds * loc_targets_mask,
loc_targets) * self.loc_weight
if self.use_giou:
loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask,
loc_targets) * self.giou_weight
total_loss = structure_loss + loc_loss + loc_loss_giou
return {
'loss': total_loss,
"structure_loss": structure_loss,
"loc_loss": loc_loss,
"loc_loss_giou": loc_loss_giou
}
else:
total_loss = structure_loss + loc_loss
return {
'loss': total_loss,
"structure_loss": structure_loss,
"loc_loss": loc_loss
}
total_loss = structure_loss + loc_loss
return {
'loss': total_loss,
"structure_loss": structure_loss,
"loc_loss": loc_loss
}
class SLALoss(nn.Layer):
    """
    Combined loss for the SLA table head: cross-entropy on the structure
    tokens plus a masked smooth-L1 loss on the predicted cell locations.
    """

    def __init__(self, structure_weight, loc_weight, loc_loss='mse', **kwargs):
        super(SLALoss, self).__init__()
        self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='mean')
        self.structure_weight = structure_weight
        self.loc_weight = loc_weight
        # NOTE(review): stored only; forward() always applies smooth L1
        # regardless of this value.
        self.loc_loss = loc_loss
        # guards the division when the location mask sums to zero
        self.eps = 1e-12

    def forward(self, predicts, batch):
        """
        @param predicts: dict with 'structure_probs' and 'loc_preds'
        @param batch: sequence where batch[1] holds structure targets,
                      batch[2] location targets and batch[3] the location mask
        @return: dict with 'loss', 'structure_loss' and 'loc_loss'
        """
        # structure branch: the first target position is dropped
        token_targets = batch[1].astype("int64")[:, 1:]
        structure_loss = self.loss_func(predicts['structure_probs'],
                                        token_targets)
        structure_loss = paddle.mean(structure_loss) * self.structure_weight

        # location branch: the first position is dropped here as well
        box_targets = batch[2].astype("float32")[:, 1:, :]
        box_mask = batch[3].astype("float32")[:, 1:, :]
        masked_l1 = F.smooth_l1_loss(
            predicts['loc_preds'] * box_mask,
            box_targets * box_mask,
            reduction='sum')
        loc_loss = masked_l1 * self.loc_weight
        # normalise by the mask sum
        loc_loss = loc_loss / (box_mask.sum() + self.eps)

        return {
            'loss': structure_loss + loc_loss,
            "structure_loss": structure_loss,
            "loc_loss": loc_loss
        }

View File

@ -16,9 +16,14 @@ from ppocr.metrics.det_metric import DetMetric
class TableStructureMetric(object):
def __init__(self, main_indicator='acc', eps=1e-6, **kwargs):
def __init__(self,
main_indicator='acc',
eps=1e-6,
del_thead_tbody=False,
**kwargs):
self.main_indicator = main_indicator
self.eps = eps
self.del_thead_tbody = del_thead_tbody
self.reset()
def __call__(self, pred_label, batch=None, *args, **kwargs):
@ -31,6 +36,13 @@ class TableStructureMetric(object):
gt_structure_batch_list):
pred_str = ''.join(pred)
target_str = ''.join(target)
if self.del_thead_tbody:
pred_str = pred_str.replace('<thead>', '').replace(
'</thead>', '').replace('<tbody>', '').replace('</tbody>',
'')
target_str = target_str.replace('<thead>', '').replace(
'</thead>', '').replace('<tbody>', '').replace('</tbody>',
'')
if pred_str == target_str:
correct_num += 1
all_num += 1
@ -59,7 +71,8 @@ class TableMetric(object):
def __init__(self,
main_indicator='acc',
compute_bbox_metric=False,
point_num=2,
box_format='xyxy',
del_thead_tbody=False,
**kwargs):
"""
@ -67,10 +80,11 @@ class TableMetric(object):
@param main_matric: main_matric for save best_model
@param kwargs:
"""
self.structure_metric = TableStructureMetric()
self.structure_metric = TableStructureMetric(
del_thead_tbody=del_thead_tbody)
self.bbox_metric = DetMetric() if compute_bbox_metric else None
self.main_indicator = main_indicator
self.point_num = point_num
self.box_format = box_format
self.reset()
def __call__(self, pred_label, batch=None, *args, **kwargs):
@ -129,10 +143,14 @@ class TableMetric(object):
self.bbox_metric.reset()
def format_box(self, box):
if self.point_num == 2:
if self.box_format == 'xyxy':
x1, y1, x2, y2 = box
box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
elif self.point_num == 4:
elif self.box_format == 'xywh':
x, y, w, h = box
x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
elif self.box_format == 'xyxyxyxy':
x1, y1, x2, y2, x3, y3, x4, y4 = box
box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
return box

View File

@ -21,7 +21,10 @@ def build_backbone(config, model_type):
from .det_resnet import ResNet
from .det_resnet_vd import ResNet_vd
from .det_resnet_vd_sast import ResNet_SAST
support_dict = ["MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST"]
from .det_pp_lcnet import PPLCNet
support_dict = [
"MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet"
]
if model_type == "table":
from .table_master_resnet import TableResNetExtra
support_dict.append('TableResNetExtra')

View File

@ -0,0 +1,271 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import os
import paddle
import paddle.nn as nn
from paddle import ParamAttr
from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Dropout, Linear
from paddle.regularizer import L2Decay
from paddle.nn.initializer import KaimingNormal
from paddle.utils.download import get_path_from_url
MODEL_URLS = {
"PPLCNet_x0.25":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_25_pretrained.pdparams",
"PPLCNet_x0.35":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_35_pretrained.pdparams",
"PPLCNet_x0.5":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_5_pretrained.pdparams",
"PPLCNet_x0.75":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_75_pretrained.pdparams",
"PPLCNet_x1.0":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x1_0_pretrained.pdparams",
"PPLCNet_x1.5":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x1_5_pretrained.pdparams",
"PPLCNet_x2.0":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_0_pretrained.pdparams",
"PPLCNet_x2.5":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_5_pretrained.pdparams"
}
MODEL_STAGES_PATTERN = {
"PPLCNet": ["blocks2", "blocks3", "blocks4", "blocks5", "blocks6"]
}
# Export the backbone class. The MODEL_URLS keys (e.g. "PPLCNet_x0.25") are
# not names defined in this module, so listing them in __all__ makes
# `from <module> import *` fail with AttributeError.
__all__ = ["PPLCNet"]
# Each element(list) represents a depthwise block, which is composed of k, in_c, out_c, s, use_se.
# k: kernel_size
# in_c: input channel number in depthwise block
# out_c: output channel number in depthwise block
# s: stride in depthwise block
# use_se: whether to use SE block
NET_CONFIG = {
"blocks2":
# k, in_c, out_c, s, use_se
[[3, 16, 32, 1, False]],
"blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
"blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
"blocks5":
[[3, 128, 256, 2, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False],
[5, 256, 256, 1, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False]],
"blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True]]
}
def make_divisible(v, divisor=8, min_value=None):
    """
    Round ``v`` to the nearest multiple of ``divisor``, never going below
    ``min_value`` (which defaults to ``divisor``). If the rounded value ends
    up under 90% of ``v``, one extra ``divisor`` is added.
    """
    lower_bound = divisor if min_value is None else min_value
    nearest_multiple = int(v + divisor / 2) // divisor * divisor
    result = max(lower_bound, nearest_multiple)
    return result + divisor if result < 0.9 * v else result
class ConvBNLayer(nn.Layer):
    """Bias-free Conv2D followed by BatchNorm and a Hardswish activation."""

    def __init__(self,
                 num_channels,
                 filter_size,
                 num_filters,
                 stride,
                 num_groups=1):
        super().__init__()

        same_padding = (filter_size - 1) // 2
        self.conv = Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=same_padding,
            groups=num_groups,
            weight_attr=ParamAttr(initializer=KaimingNormal()),
            bias_attr=False)

        # both BN scale and shift use an L2Decay coefficient of 0.0
        self.bn = BatchNorm(
            num_filters,
            param_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
        self.hardswish = nn.Hardswish()

    def forward(self, x):
        return self.hardswish(self.bn(self.conv(x)))
class DepthwiseSeparable(nn.Layer):
    """
    Depthwise conv (groups == channels), an optional SE block, then a 1x1
    pointwise conv that maps to ``num_filters`` channels.
    """

    def __init__(self,
                 num_channels,
                 num_filters,
                 stride,
                 dw_size=3,
                 use_se=False):
        super().__init__()
        self.use_se = use_se

        # depthwise stage: channel count unchanged, one group per channel
        self.dw_conv = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_channels,
            filter_size=dw_size,
            stride=stride,
            num_groups=num_channels)
        if use_se:
            self.se = SEModule(num_channels)
        # pointwise stage: 1x1 conv changes the channel count
        self.pw_conv = ConvBNLayer(
            num_channels=num_channels,
            filter_size=1,
            num_filters=num_filters,
            stride=1)

    def forward(self, x):
        feat = self.dw_conv(x)
        if self.use_se:
            feat = self.se(feat)
        return self.pw_conv(feat)
class SEModule(nn.Layer):
    """
    Squeeze-and-excitation block: global average pool, two 1x1 convs
    (ReLU between them), a Hardsigmoid gate, and an element-wise multiply
    with the block input.
    """

    def __init__(self, channel, reduction=4):
        super().__init__()
        squeezed = channel // reduction
        self.avg_pool = AdaptiveAvgPool2D(1)
        self.conv1 = Conv2D(
            in_channels=channel,
            out_channels=squeezed,
            kernel_size=1,
            stride=1,
            padding=0)
        self.relu = nn.ReLU()
        self.conv2 = Conv2D(
            in_channels=squeezed,
            out_channels=channel,
            kernel_size=1,
            stride=1,
            padding=0)
        self.hardsigmoid = nn.Hardsigmoid()

    def forward(self, x):
        gate = self.avg_pool(x)
        gate = self.relu(self.conv1(gate))
        gate = self.hardsigmoid(self.conv2(gate))
        # re-weight the input with the computed gate
        return paddle.multiply(x=x, y=gate)
class PPLCNet(nn.Layer):
    """
    PP-LCNet backbone that returns four feature maps (outputs of blocks3,
    blocks4, blocks5 and blocks6).

    Args:
        in_channels (int): channel count of the input. Default: 3
        scale (float): width multiplier applied to every channel count in
            NET_CONFIG. Default: 1.0
        pretrained (bool): when True, weights are fetched from the
            MODEL_URLS entry for this scale and loaded. Default: False
        use_ssld (bool): when True, the "_ssld_pretrained" variant of the
            weight URL is used. Default: False

    Attributes:
        out_channels (list[int]): channel counts of the four returned maps.
    """

    def __init__(self,
                 in_channels=3,
                 scale=1.0,
                 pretrained=False,
                 use_ssld=False):
        super().__init__()
        # The stages below size their outputs with make_divisible(out_c * scale),
        # so the advertised widths must be rounded the same way. Plain int()
        # disagrees with the real tensors for scales such as 0.35.
        self.out_channels = [
            make_divisible(NET_CONFIG[stage][-1][2] * scale)
            for stage in ("blocks3", "blocks4", "blocks5", "blocks6")
        ]
        self.scale = scale

        self.conv1 = ConvBNLayer(
            num_channels=in_channels,
            filter_size=3,
            num_filters=make_divisible(16 * scale),
            stride=2)

        # Attribute names are kept as-is so existing checkpoints still match.
        self.blocks2 = self._make_stage("blocks2", scale)
        self.blocks3 = self._make_stage("blocks3", scale)
        self.blocks4 = self._make_stage("blocks4", scale)
        self.blocks5 = self._make_stage("blocks5", scale)
        self.blocks6 = self._make_stage("blocks6", scale)

        if pretrained:
            url_key = 'PPLCNet_x{}'.format(scale)
            if url_key not in MODEL_URLS:
                # an integer scale such as 1 should still resolve to "PPLCNet_x1.0"
                url_key = 'PPLCNet_x{}'.format(float(scale))
            if url_key not in MODEL_URLS:
                raise KeyError(
                    "No pretrained PPLCNet weights for scale={}; available: {}".
                    format(scale, sorted(MODEL_URLS.keys())))
            self._load_pretrained(MODEL_URLS[url_key], use_ssld=use_ssld)

    @staticmethod
    def _make_stage(stage_name, scale):
        """Build one stage of DepthwiseSeparable blocks from NET_CONFIG[stage_name]."""
        return nn.Sequential(* [
            DepthwiseSeparable(
                num_channels=make_divisible(in_c * scale),
                num_filters=make_divisible(out_c * scale),
                dw_size=k,
                stride=s,
                use_se=se) for (k, in_c, out_c, s, se) in NET_CONFIG[stage_name]
        ])

    def forward(self, x):
        """Run the backbone and return the outputs of blocks3..blocks6 as a list."""
        x = self.conv1(x)
        x = self.blocks2(x)
        outs = []
        for stage in (self.blocks3, self.blocks4, self.blocks5, self.blocks6):
            x = stage(x)
            outs.append(x)
        return outs

    def _load_pretrained(self, pretrained_url, use_ssld=False):
        """
        Download the weights at ``pretrained_url`` (or its SSLD variant) into
        ~/.paddleclas/weights and load them into this layer.
        """
        if use_ssld:
            pretrained_url = pretrained_url.replace("_pretrained",
                                                    "_ssld_pretrained")
        print(pretrained_url)
        local_weight_path = get_path_from_url(
            pretrained_url, os.path.expanduser("~/.paddleclas/weights"))
        param_state_dict = paddle.load(local_weight_path)
        self.set_dict(param_state_dict)

View File

@ -44,7 +44,7 @@ def build_head(config):
#kie head
from .kie_sdmgr_head import SDMGRHead
from .table_att_head import TableAttentionHead
from .table_att_head import TableAttentionHead, SLAHead
from .table_master_head import TableMasterHead
support_dict = [
@ -52,7 +52,7 @@ def build_head(config):
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
'VLHead', 'RobustScannerHead'
'VLHead', 'SLAHead', 'RobustScannerHead'
]
#table head

View File

@ -18,12 +18,26 @@ from __future__ import print_function
import paddle
import paddle.nn as nn
from paddle import ParamAttr
import paddle.nn.functional as F
import numpy as np
from .rec_att_head import AttentionGRUCell
def get_para_bias_attr(l2_decay, k):
    """
    Build the weight and bias ParamAttr for a linear layer.

    @param l2_decay: L2 regularization coefficient; when it is not positive,
                     neither a regularizer nor an initializer is set
    @param k: value used to derive the uniform init bound 1 / sqrt(k)
    @return: [weight_attr, bias_attr]
    """
    # Local import: `math` is not among the imports visible at the top of this
    # module, so without it the l2_decay > 0 path raises NameError.
    import math

    if l2_decay > 0:
        regularizer = paddle.regularizer.L2Decay(l2_decay)
        stdv = 1.0 / math.sqrt(k * 1.0)
        initializer = nn.initializer.Uniform(-stdv, stdv)
    else:
        regularizer = None
        initializer = None
    weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
    bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
    return [weight_attr, bias_attr]
class TableAttentionHead(nn.Layer):
def __init__(self,
in_channels,
@ -32,7 +46,7 @@ class TableAttentionHead(nn.Layer):
in_max_len=488,
max_text_length=800,
out_channels=30,
point_num=2,
loc_reg_num=4,
**kwargs):
super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1]
@ -56,7 +70,7 @@ class TableAttentionHead(nn.Layer):
else:
self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
self.loc_generator = nn.Linear(self.input_size + hidden_size,
point_num * 2)
loc_reg_num)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
@ -129,3 +143,121 @@ class TableAttentionHead(nn.Layer):
loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds)
return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
class SLAHead(nn.Layer):
    """
    Attention-GRU decoder head that emits, at every step, a structure-token
    score vector and a location vector.
    """

    def __init__(self,
                 in_channels,
                 hidden_size,
                 out_channels=30,
                 max_text_length=500,
                 loc_reg_num=4,
                 fc_decay=0.0,
                 **kwargs):
        """
        @param in_channels: sequence of channel counts; only the last entry is used
        @param hidden_size: hidden_size for RNN and Embedding
        @param out_channels: num_classes to rec
        @param max_text_length: decoding runs for max_text_length + 1 steps
        @param loc_reg_num: number of values produced per step by the loc branch
        @param fc_decay: l2_decay passed to get_para_bias_attr for the linear layers
        """
        super().__init__()
        feat_channels = in_channels[-1]
        self.hidden_size = hidden_size
        self.max_text_length = max_text_length
        self.emb = self._char_to_onehot
        self.num_embeddings = out_channels

        # structure branch
        self.structure_attention_cell = AttentionGRUCell(
            feat_channels, hidden_size, self.num_embeddings)
        cls_out_w, cls_out_b = get_para_bias_attr(
            l2_decay=fc_decay, k=hidden_size)
        # created as in the original code, but not attached to any layer
        _spare_w, _spare_b = get_para_bias_attr(
            l2_decay=fc_decay, k=hidden_size)
        cls_hid_w, cls_hid_b = get_para_bias_attr(
            l2_decay=fc_decay, k=hidden_size)
        self.structure_generator = nn.Sequential(
            nn.Linear(
                self.hidden_size,
                self.hidden_size,
                weight_attr=cls_hid_w,
                bias_attr=cls_hid_b),
            nn.Linear(
                hidden_size,
                out_channels,
                weight_attr=cls_out_w,
                bias_attr=cls_out_b))

        # loc branch: output passes through a Sigmoid
        loc_hid_w, loc_hid_b = get_para_bias_attr(
            l2_decay=fc_decay, k=self.hidden_size)
        loc_out_w, loc_out_b = get_para_bias_attr(
            l2_decay=fc_decay, k=self.hidden_size)
        self.loc_generator = nn.Sequential(
            nn.Linear(
                self.hidden_size,
                self.hidden_size,
                weight_attr=loc_hid_w,
                bias_attr=loc_hid_b),
            nn.Linear(
                self.hidden_size,
                loc_reg_num,
                weight_attr=loc_out_w,
                bias_attr=loc_out_b),
            nn.Sigmoid())

    def forward(self, inputs, targets=None):
        """
        @param inputs: sequence of feature maps; only the last one is decoded
        @param targets: when training, targets[0] supplies the previous token
                        for every step (teacher forcing)
        @return: dict with 'structure_probs' and 'loc_preds'
        """
        fea = inputs[-1]
        batch_size = fea.shape[0]
        # flatten everything after the channel axis, then move channels last
        fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], -1])
        fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)

        hidden = paddle.zeros((batch_size, self.hidden_size))
        structure_preds, loc_preds = [], []

        if self.training and targets is not None:
            structure = targets[0]
            for i in range(self.max_text_length + 1):
                hidden, structure_step, loc_step = self._decode(
                    structure[:, i], fea, hidden)
                structure_preds.append(structure_step)
                loc_preds.append(loc_step)
        else:
            # greedy decoding: the argmax of each step feeds the next one
            pre_chars = paddle.zeros(shape=[batch_size], dtype="int32")
            max_text_length = paddle.to_tensor(self.max_text_length)
            # for export
            loc_step, structure_step = None, None
            for i in range(max_text_length + 1):
                hidden, structure_step, loc_step = self._decode(pre_chars, fea,
                                                                hidden)
                pre_chars = structure_step.argmax(axis=1, dtype="int32")
                structure_preds.append(structure_step)
                loc_preds.append(loc_step)

        structure_preds = paddle.stack(structure_preds, axis=1)
        loc_preds = paddle.stack(loc_preds, axis=1)
        if not self.training:
            structure_preds = F.softmax(structure_preds)
        return {'structure_probs': structure_preds, 'loc_preds': loc_preds}

    def _decode(self, pre_chars, features, hidden):
        """
        Predict table label and coordinates for each step
        @param pre_chars: Table label in previous step
        @param features: feature sequence attended over by the GRU cell
        @param hidden: hidden status in previous step
        @return: (new hidden state, structure scores, location prediction)
        """
        onehot = self.emb(pre_chars)
        # output shape is b * self.hidden_size
        (output, hidden), _ = self.structure_attention_cell(hidden, features,
                                                            onehot)
        # both branches read the same attention output
        return (hidden, self.structure_generator(output),
                self.loc_generator(output))

    def _char_to_onehot(self, input_char):
        """One-hot encode token ids over self.num_embeddings classes."""
        return F.one_hot(input_char, self.num_embeddings)

View File

@ -37,7 +37,7 @@ class TableMasterHead(nn.Layer):
d_ff=2048,
dropout=0,
max_text_length=500,
point_num=2,
loc_reg_num=4,
**kwargs):
super(TableMasterHead, self).__init__()
hidden_size = in_channels[-1]
@ -50,7 +50,7 @@ class TableMasterHead(nn.Layer):
self.cls_fc = nn.Linear(hidden_size, out_channels)
self.bbox_fc = nn.Sequential(
# nn.Linear(hidden_size, hidden_size),
nn.Linear(hidden_size, point_num * 2),
nn.Linear(hidden_size, loc_reg_num),
nn.Sigmoid())
self.norm = nn.LayerNorm(hidden_size)
self.embedding = Embeddings(d_model=hidden_size, vocab=out_channels)
@ -59,7 +59,7 @@ class TableMasterHead(nn.Layer):
self.SOS = out_channels - 3
self.PAD = out_channels - 1
self.out_channels = out_channels
self.point_num = point_num
self.loc_reg_num = loc_reg_num
self.max_text_length = max_text_length
def make_mask(self, tgt):
@ -105,7 +105,7 @@ class TableMasterHead(nn.Layer):
output = paddle.zeros(
[input.shape[0], self.max_text_length + 1, self.out_channels])
bbox_output = paddle.zeros(
[input.shape[0], self.max_text_length + 1, self.point_num * 2])
[input.shape[0], self.max_text_length + 1, self.loc_reg_num])
max_text_length = paddle.to_tensor(self.max_text_length)
for i in range(max_text_length + 1):
target_mask = self.make_mask(input)

View File

@ -25,9 +25,10 @@ def build_neck(config):
from .fpn import FPN
from .fce_fpn import FCEFPN
from .pren_fpn import PRENFPN
from .csp_pan import CSPPAN
support_dict = [
'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN'
'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN'
]
module_name = config.pop('name')

View File

@ -0,0 +1,324 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The code is based on:
# https://github.com/PaddlePaddle/PaddleDetection/blob/release%2F2.3/ppdet/modeling/necks/csp_pan.py
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
__all__ = ['CSPPAN']
class ConvBNLayer(nn.Layer):
    """
    Bias-free Conv2D + BatchNorm2D followed by `leaky_relu` or `hard_swish`.
    """

    def __init__(self,
                 in_channel=96,
                 out_channel=96,
                 kernel_size=3,
                 stride=1,
                 groups=1,
                 act='leaky_relu'):
        super(ConvBNLayer, self).__init__()
        initializer = nn.initializer.KaimingUniform()
        self.act = act
        # only these two activations are accepted
        assert self.act in ['leaky_relu', "hard_swish"]
        self.conv = nn.Conv2D(
            in_channels=in_channel,
            out_channels=out_channel,
            kernel_size=kernel_size,
            groups=groups,
            padding=(kernel_size - 1) // 2,
            stride=stride,
            weight_attr=ParamAttr(initializer=initializer),
            bias_attr=False)
        self.bn = nn.BatchNorm2D(out_channel)

    def forward(self, x):
        normed = self.bn(self.conv(x))
        if self.act == "leaky_relu":
            return F.leaky_relu(normed)
        if self.act == "hard_swish":
            return F.hardswish(normed)
        return normed
class DPModule(nn.Layer):
    """
    Depth-wise and point-wise module.

    A grouped conv (groups == out_channel) with BatchNorm and activation,
    followed by a 1x1 conv with BatchNorm and the same activation.

    Args:
        in_channel (int): The input channels of this Module.
        out_channel (int): The output channels of this Module.
        kernel_size (int): The conv2d kernel size of this Module.
        stride (int): The conv2d's stride of this Module.
        act (str): The activation function of this Module,
            Now support `leaky_relu` and `hard_swish`.
    """

    def __init__(self,
                 in_channel=96,
                 out_channel=96,
                 kernel_size=3,
                 stride=1,
                 act='leaky_relu'):
        super(DPModule, self).__init__()
        initializer = nn.initializer.KaimingUniform()
        self.act = act

        # depth-wise part
        self.dwconv = nn.Conv2D(
            in_channels=in_channel,
            out_channels=out_channel,
            kernel_size=kernel_size,
            groups=out_channel,
            padding=(kernel_size - 1) // 2,
            stride=stride,
            weight_attr=ParamAttr(initializer=initializer),
            bias_attr=False)
        self.bn1 = nn.BatchNorm2D(out_channel)

        # point-wise part
        self.pwconv = nn.Conv2D(
            in_channels=out_channel,
            out_channels=out_channel,
            kernel_size=1,
            groups=1,
            padding=0,
            weight_attr=ParamAttr(initializer=initializer),
            bias_attr=False)
        self.bn2 = nn.BatchNorm2D(out_channel)

    def act_func(self, x):
        """Apply the configured activation; any other name leaves x unchanged."""
        if self.act == "leaky_relu":
            return F.leaky_relu(x)
        if self.act == "hard_swish":
            return F.hardswish(x)
        return x

    def forward(self, x):
        depthwise = self.act_func(self.bn1(self.dwconv(x)))
        pointwise = self.act_func(self.bn2(self.pwconv(depthwise)))
        return pointwise
class DarknetBottleneck(nn.Layer):
    """The basic bottleneck block used in Darknet.

    A 1x1 ConvBNLayer reduces to the hidden width, then a second conv
    (ConvBNLayer or DPModule) with ``kernel_size`` maps to ``out_channels``.
    The input is added to the result when identity is enabled and the
    input/output channel counts match.

    Args:
        in_channels (int): The input channels of this Module.
        out_channels (int): The output channels of this Module.
        kernel_size (int): Kernel size of the second convolution. Default: 3
        expansion (float): Ratio of hidden channels to ``out_channels``.
            Default: 0.5
        add_identity (bool): Whether to add identity to the out.
            Default: True
        use_depthwise (bool): Whether to use depthwise separable convolution.
            Default: False
        act (str): Activation name forwarded to both convs.
            Default: "leaky_relu"
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 expansion=0.5,
                 add_identity=True,
                 use_depthwise=False,
                 act="leaky_relu"):
        super(DarknetBottleneck, self).__init__()
        hidden_channels = int(out_channels * expansion)
        if use_depthwise:
            second_conv_cls = DPModule
        else:
            second_conv_cls = ConvBNLayer
        self.conv1 = ConvBNLayer(
            in_channel=in_channels,
            out_channel=hidden_channels,
            kernel_size=1,
            act=act)
        self.conv2 = second_conv_cls(
            in_channel=hidden_channels,
            out_channel=out_channels,
            kernel_size=kernel_size,
            stride=1,
            act=act)
        # the residual add is only applied when the channel counts agree
        self.add_identity = \
            add_identity and in_channels == out_channels

    def forward(self, x):
        out = self.conv2(self.conv1(x))
        if not self.add_identity:
            return out
        return out + x
class CSPLayer(nn.Layer):
    """Cross Stage Partial Layer.

    The input is fed to two parallel kernel-size-1 convolutions. The "main"
    branch is further processed by a stack of ``DarknetBottleneck`` blocks,
    then concatenated with the "short" branch along axis 1 and fused by a
    final kernel-size-1 convolution.

    Args:
        in_channels (int): The input channels of the CSP layer.
        out_channels (int): The output channels of the CSP layer.
        kernel_size (int): Kernel size passed to every bottleneck block.
            Default: 3
        expand_ratio (float): Ratio to adjust the number of channels of the
            hidden layer. Default: 0.5
        num_blocks (int): Number of blocks. Default: 1
        add_identity (bool): Whether to add identity in blocks.
            Default: True
        use_depthwise (bool): Whether to use depthwise separable convolution
            in blocks. Default: False
        act (str): Activation name forwarded to every sub-module.
            Default: "leaky_relu"
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 expand_ratio=0.5,
                 num_blocks=1,
                 add_identity=True,
                 use_depthwise=False,
                 act="leaky_relu"):
        super().__init__()
        branch_width = int(out_channels * expand_ratio)

        self.main_conv = ConvBNLayer(in_channels, branch_width, 1, act=act)
        self.short_conv = ConvBNLayer(in_channels, branch_width, 1, act=act)
        # Consumes the concatenation of both branches, hence twice the width.
        self.final_conv = ConvBNLayer(
            2 * branch_width, out_channels, 1, act=act)

        # Bottlenecks keep the branch width (expansion fixed to 1.0).
        bottlenecks = []
        for _ in range(num_blocks):
            bottlenecks.append(
                DarknetBottleneck(
                    branch_width,
                    branch_width,
                    kernel_size,
                    1.0,
                    add_identity,
                    use_depthwise,
                    act=act))
        self.blocks = nn.Sequential(*bottlenecks)

    def forward(self, x):
        """Return the fused output of the short and main branches."""
        shortcut = self.short_conv(x)
        processed = self.blocks(self.main_conv(x))
        merged = paddle.concat((processed, shortcut), axis=1)
        return self.final_conv(merged)
class Channel_T(nn.Layer):
    """Project each input feature map to a common channel count.

    One ``ConvBNLayer`` with kernel size 1 is created per entry of
    ``in_channels``; the i-th layer is applied to the i-th input.

    Args:
        in_channels (List[int]): Channel count of each input feature map.
            Default: [116, 232, 464]
        out_channels (int): Channel count produced for every input.
            Default: 96
        act (str): Activation name forwarded to every ``ConvBNLayer``.
            Default: "leaky_relu"
    """

    def __init__(self,
                 in_channels=[116, 232, 464],
                 out_channels=96,
                 act="leaky_relu"):
        super(Channel_T, self).__init__()
        self.convs = nn.LayerList()
        for level_channels in in_channels:
            self.convs.append(
                ConvBNLayer(
                    level_channels, out_channels, 1, act=act))

    def forward(self, x):
        """Return a list holding the projected version of every input."""
        outs = []
        for i in range(len(x)):
            outs.append(self.convs[i](x[i]))
        return outs
class CSPPAN(nn.Layer):
    """Path Aggregation Network with CSP module.

    All inputs are first projected to ``out_channels`` by ``Channel_T``.
    A top-down pass then resizes the previously fused map to the next
    input's spatial size and fuses the pair with a ``CSPLayer``. A bottom-up
    pass follows: each output is reduced by a stride-2 convolution and fused
    with the matching top-down result by another ``CSPLayer``.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale)
        kernel_size (int): The conv2d kernel size of this Module. Default: 5
        num_csp_blocks (int): Number of bottlenecks in CSPLayer. Default: 1
        use_depthwise (bool): Whether to use depthwise separable convolution
            in blocks. Default: True
        act (str): Activation name forwarded to every sub-module.
            Default: 'hard_swish'
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=5,
                 num_csp_blocks=1,
                 use_depthwise=True,
                 act='hard_swish'):
        super(CSPPAN, self).__init__()
        num_levels = len(in_channels)
        self.in_channels = in_channels
        self.out_channels = [out_channels] * num_levels

        self.conv_t = Channel_T(in_channels, out_channels, act=act)

        # build top-down blocks
        # NOTE: forward() resizes with F.upsample and does not call this layer.
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.top_down_blocks = nn.LayerList()
        for _ in range(num_levels - 1, 0, -1):
            # Each block consumes two concatenated maps of out_channels each.
            self.top_down_blocks.append(
                CSPLayer(
                    out_channels * 2,
                    out_channels,
                    kernel_size=kernel_size,
                    num_blocks=num_csp_blocks,
                    add_identity=False,
                    use_depthwise=use_depthwise,
                    act=act))

        # build bottom-up blocks
        if use_depthwise:
            downsample_cls = DPModule
        else:
            downsample_cls = ConvBNLayer
        self.downsamples = nn.LayerList()
        self.bottom_up_blocks = nn.LayerList()
        for _ in range(num_levels - 1):
            self.downsamples.append(
                downsample_cls(
                    out_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=2,
                    act=act))
            self.bottom_up_blocks.append(
                CSPLayer(
                    out_channels * 2,
                    out_channels,
                    kernel_size=kernel_size,
                    num_blocks=num_csp_blocks,
                    add_identity=False,
                    use_depthwise=use_depthwise,
                    act=act))

    def forward(self, inputs):
        """
        Args:
            inputs (tuple[Tensor]): input features, one per entry of
                ``in_channels`` (presumably ordered from the largest to the
                smallest spatial size -- confirm against the backbone).
        Returns:
            tuple[Tensor]: CSPPAN features, one per input.
        """
        num_levels = len(self.in_channels)
        assert len(inputs) == num_levels
        feats = self.conv_t(inputs)

        # top-down path
        top_down = [feats[-1]]
        for block_idx in range(num_levels - 1):
            level = num_levels - 1 - block_idx
            coarse = top_down[0]
            fine = feats[level - 1]
            # Resize to dims 2:4 of the finer map's shape, read at run time.
            resized = F.upsample(
                coarse, size=paddle.shape(fine)[2:4], mode="nearest")
            fused = self.top_down_blocks[block_idx](
                paddle.concat([resized, fine], 1))
            top_down.insert(0, fused)

        # bottom-up path
        outs = [top_down[0]]
        for level in range(num_levels - 1):
            reduced = self.downsamples[level](outs[-1])
            lateral = top_down[level + 1]
            merged = paddle.concat([reduced, lateral], 1)
            outs.append(self.bottom_up_blocks[level](merged))

        return tuple(outs)

View File

@ -21,9 +21,29 @@ from .rec_postprocess import AttnLabelDecode
class TableLabelDecode(AttnLabelDecode):
""" """
def __init__(self, character_dict_path, **kwargs):
super(TableLabelDecode, self).__init__(character_dict_path)
self.td_token = ['<td>', '<td', '<eb></eb>', '<td></td>']
def __init__(self,
character_dict_path,
merge_no_span_structure=False,
**kwargs):
dict_character = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
dict_character.append(line)
if merge_no_span_structure:
if "<td></td>" not in dict_character:
dict_character.append("<td></td>")
if "<td>" in dict_character:
dict_character.remove("<td>")
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
self.td_token = ['<td>', '<td', '<td></td>']
def __call__(self, preds, batch=None):
structure_probs = preds['structure_probs']
@ -114,18 +134,21 @@ class TableLabelDecode(AttnLabelDecode):
def _bbox_decode(self, bbox, shape):
h, w, ratio_h, ratio_w, pad_h, pad_w = shape
src_h = h / ratio_h
src_w = w / ratio_w
bbox[0::2] *= src_w
bbox[1::2] *= src_h
bbox[0::2] *= w
bbox[1::2] *= h
return bbox
class TableMasterLabelDecode(TableLabelDecode):
""" """
def __init__(self, character_dict_path, box_shape='ori', **kwargs):
super(TableMasterLabelDecode, self).__init__(character_dict_path)
def __init__(self,
character_dict_path,
box_shape='ori',
merge_no_span_structure=True,
**kwargs):
super(TableMasterLabelDecode, self).__init__(character_dict_path,
merge_no_span_structure)
self.box_shape = box_shape
assert box_shape in [
'ori', 'pad'
@ -157,4 +180,7 @@ class TableMasterLabelDecode(TableLabelDecode):
bbox[1::2] *= h
bbox[0::2] /= ratio_w
bbox[1::2] /= ratio_h
x, y, w, h = bbox
x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
bbox = np.array([x1, y1, x2, y2])
return bbox

View File

@ -0,0 +1,48 @@
<thead>
</thead>
<tbody>
</tbody>
<tr>
</tr>
<td>
<td
>
</td>
colspan="2"
colspan="3"
colspan="4"
colspan="5"
colspan="6"
colspan="7"
colspan="8"
colspan="9"
colspan="10"
colspan="11"
colspan="12"
colspan="13"
colspan="14"
colspan="15"
colspan="16"
colspan="17"
colspan="18"
colspan="19"
colspan="20"
rowspan="2"
rowspan="3"
rowspan="4"
rowspan="5"
rowspan="6"
rowspan="7"
rowspan="8"
rowspan="9"
rowspan="10"
rowspan="11"
rowspan="12"
rowspan="13"
rowspan="14"
rowspan="15"
rowspan="16"
rowspan="17"
rowspan="18"
rowspan="19"
rowspan="20"

View File

@ -41,9 +41,7 @@ def download_with_progressbar(url, save_path):
def maybe_download(model_storage_directory, url):
# using custom model
tar_file_name_list = [
'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
]
tar_file_name_list = ['.pdiparams', '.pdiparams.info', '.pdmodel']
if not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdiparams')
) or not os.path.exists(
@ -57,8 +55,8 @@ def maybe_download(model_storage_directory, url):
for member in tarObj.getmembers():
filename = None
for tar_file_name in tar_file_name_list:
if tar_file_name in member.name:
filename = tar_file_name
if member.name.endswith(tar_file_name):
filename = 'inference' + tar_file_name
if filename is None:
continue
file = tarObj.extractfile(member)

View File

@ -113,14 +113,11 @@ def draw_re_results(image,
return np.array(img_new)
def draw_rectangle(img_path, boxes, use_xywh=False):
def draw_rectangle(img_path, boxes):
boxes = np.array(boxes)
img = cv2.imread(img_path)
img_show = img.copy()
for box in boxes.astype(int):
if use_xywh:
x, y, w, h = box
x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
else:
x1, y1, x2, y2 = box
x1, y1, x2, y2 = box
cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
return img_show

View File

@ -106,9 +106,9 @@ PP-Structure Series Model List (Updating)
|model name|description|model size|download|
| --- | --- | --- | --- |
|ch_PP-OCRv2_det_slim|[New] Slim quantization with distillation lightweight model, supporting Chinese, English, multilingual text detection| 3M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar)|
|ch_PP-OCRv2_rec_slim|[New] Slim quantization with distillation lightweight model, supporting Chinese, English, multilingual text recognition| 9M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_train.tar) |
|en_ppocr_mobile_v2.0_table_structure|Table structure prediction of English table scene trained on PubLayNet dataset| 18.6M |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
|ch_PP-OCRv3_det_slim|[New] slim quantization with distillation lightweight model, supporting Chinese, English, multilingual text detection| 1.1M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_distill_train.tar)|
|ch_PP-OCRv3_rec_slim |[New] Slim quantization with distillation lightweight model, supporting Chinese, English text recognition| 4.9M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_train.tar) |
|ch_ppstructure_mobile_v2.0_SLANet|Chinese table recognition model trained on PubTabNet dataset based on SLANet|9.3M|[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_train.tar) |
### 7.3 DOC-VQA model

View File

@ -120,9 +120,10 @@ PP-Structure系列模型列表更新中
|模型名称|模型简介|模型大小|下载地址|
| --- | --- | --- | --- |
|ch_PP-OCRv2_det_slim|【最新】slim量化+蒸馏版超轻量模型,支持中英文、多语种文本检测| 3M |[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar)|
|ch_PP-OCRv2_rec_slim|【最新】slim量化版超轻量模型支持中英文、数字识别| 9M |[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_train.tar) |
|en_ppocr_mobile_v2.0_table_structure|PubLayNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
|ch_PP-OCRv3_det_slim|【最新】slim量化+蒸馏版超轻量模型,支持中英文、多语种文本检测| 1.1M |[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_distill_train.tar)|
|ch_PP-OCRv3_rec_slim |【最新】slim量化版超轻量模型支持中英文、数字识别| 4.9M |[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_train.tar) |
|ch_ppstructure_mobile_v2.0_SLANet|基于SLANet在PubTabNet数据集上训练的中文表格识别模型|9.3M|[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_train.tar) |
<a name="73"></a>
### 7.3 DocVQA 模型

Binary file not shown.

After

Width:  |  Height:  |  Size: 78 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 369 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 772 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 467 KiB

View File

@ -1,8 +1,7 @@
- [快速安装](#快速安装)
- [1. PaddlePaddle 和 PaddleOCR](#1-paddlepaddle-和-paddleocr)
- [2. 安装其他依赖](#2-安装其他依赖)
- [2.1 版面分析所需 Layout-Parser](#21-版面分析所需--layout-parser)
- [2.2 VQA所需依赖](#22--vqa所需依赖)
- [2.1 VQA所需依赖](#21--vqa所需依赖)
# 快速安装
@ -12,14 +11,7 @@
## 2. 安装其他依赖
### 2.1 版面分析所需 Layout-Parser
Layout-Parser 可通过如下命令安装
```bash
pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
```
### 2.2 VQA所需依赖
### 2.1 VQA所需依赖
* paddleocr
```bash

View File

@ -0,0 +1,30 @@
# Quick installation
- [1. PaddlePaddle and PaddleOCR](#1)
- [2. Install other dependencies](#2)
- [2.1 VQA](#21)
<a name="1"></a>
## 1. PaddlePaddle and PaddleOCR
Please refer to [PaddleOCR installation documentation](../../doc/doc_en/installation_en.md)
<a name="2"></a>
## 2. Install other dependencies
<a name="21"></a>
### 2.1 VQA
* paddleocr
```bash
pip3 install paddleocr
```
* PaddleNLP
```bash
git clone https://github.com/PaddlePaddle/PaddleNLP -b develop
cd PaddleNLP
pip3 install -e .
```

View File

@ -34,7 +34,9 @@
|模型名称|模型简介|推理模型大小|下载地址|
| --- | --- | --- | --- |
|en_ppocr_mobile_v2.0_table_structure|PubTabNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
|en_ppocr_mobile_v2.0_table_structure|基于TableRec-RARE在PubTabNet数据集上训练的英文表格识别模型|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
|en_ppstructure_mobile_v2.0_SLANet|基于SLANet在PubTabNet数据集上训练的英文表格识别模型|9M|[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_train.tar) |
|ch_ppstructure_mobile_v2.0_SLANet|基于SLANet在PubTabNet数据集上训练的中文表格识别模型|9.3M|[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_train.tar) |
<a name="3"></a>

View File

@ -35,7 +35,9 @@ If you need to use other OCR models, you can download the model in [PP-OCR model
|model| description |inference model size|download|
| --- |-----------------------------------------------------------------------------| --- | --- |
|en_ppocr_mobile_v2.0_table_structure| Table structure model for English table scenes trained on PubTabNet dataset |18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
|en_ppocr_mobile_v2.0_table_structure| English table recognition model trained on PubTabNet dataset based on TableRec-RARE |18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
|en_ppstructure_mobile_v2.0_SLANet|English table recognition model trained on PubTabNet dataset based on SLANet|9M|[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_train.tar) |
|ch_ppstructure_mobile_v2.0_SLANet|Chinese table recognition model trained on PubTabNet dataset based on SLANet|9.3M|[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_train.tar) |
<a name="3"></a>
## 3. VQA

View File

@ -1,21 +1,23 @@
# PP-Structure 快速开始
- [1. 安装依赖包](#1)
- [2. 便捷使用](#2)
- [2.1 命令行使用](#21)
- [2.1.1 版面分析+表格识别](#211)
- [2.1.2 版面分析](#212)
- [2.1.3 表格识别](#213)
- [2.1.4 DocVQA](#214)
- [2.2 代码使用](#22)
- [2.2.1 版面分析+表格识别](#221)
- [2.2.2 版面分析](#222)
- [2.2.3 表格识别](#223)
- [2.2.4 DocVQA](#224)
- [2.3 返回结果说明](#23)
- [2.3.1 版面分析+表格识别](#231)
- [2.3.2 DocVQA](#232)
- [2.4 参数说明](#24)
- [1. 安装依赖包](#1-安装依赖包)
- [2. 便捷使用](#2-便捷使用)
- [2.1 命令行使用](#21-命令行使用)
- [2.1.1 图像方向分类+版面分析+表格识别](#211-图像方向分类版面分析表格识别)
- [2.1.2 版面分析+表格识别](#212-版面分析表格识别)
- [2.1.3 版面分析](#213-版面分析)
- [2.1.4 表格识别](#214-表格识别)
- [2.1.5 DocVQA](#215-docvqa)
- [2.2 代码使用](#22-代码使用)
- [2.2.1 图像方向分类+版面分析+表格识别](#221-图像方向分类版面分析表格识别)
- [2.2.2 版面分析+表格识别](#222-版面分析表格识别)
- [2.2.3 版面分析](#223-版面分析)
- [2.2.4 表格识别](#224-表格识别)
- [2.2.5 DocVQA](#225-docvqa)
- [2.3 返回结果说明](#23-返回结果说明)
- [2.3.1 版面分析+表格识别](#231-版面分析表格识别)
- [2.3.2 DocVQA](#232-docvqa)
- [2.4 参数说明](#24-参数说明)
<a name="1"></a>
@ -24,8 +26,6 @@
```bash
# 安装 paddleocr推荐使用2.5+版本
pip3 install "paddleocr>=2.5"
# 安装 版面分析依赖包layoutparser如不需要版面分析功能可跳过
pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
# 安装 DocVQA依赖包paddlenlp如不需要DocVQA功能可跳过
pip install paddlenlp
@ -38,25 +38,31 @@ pip install paddlenlp
### 2.1 命令行使用
<a name="211"></a>
#### 2.1.1 版面分析+表格识别
#### 2.1.1 图像方向分类+版面分析+表格识别
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --image_orientation=true
```
<a name="212"></a>
#### 2.1.2 版面分析+表格识别
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure
```
<a name="212"></a>
#### 2.1.2 版面分析
<a name="213"></a>
#### 2.1.3 版面分析
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --table=false --ocr=false
```
<a name="213"></a>
#### 2.1.3 表格识别
<a name="214"></a>
#### 2.1.4 表格识别
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structure --layout=false
```
<a name="214"></a>
#### 2.1.4 DocVQA
<a name="215"></a>
#### 2.1.5 DocVQA
请参考:[文档视觉问答](../vqa/README.md)。
@ -64,7 +70,36 @@ paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structur
### 2.2 代码使用
<a name="221"></a>
#### 2.2.1 版面分析+表格识别
#### 2.2.1 图像方向分类+版面分析+表格识别
```python
import os
import cv2
from paddleocr import PPStructure,draw_structure_result,save_structure_res
table_engine = PPStructure(show_log=True, image_orientation=True)
save_folder = './output'
img_path = 'PaddleOCR/ppstructure/docs/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)
save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
for line in result:
line.pop('img')
print(line)
from PIL import Image
font_path = 'PaddleOCR/doc/fonts/simfang.ttf' # PaddleOCR下提供字体包
image = Image.open(img_path).convert('RGB')
im_show = draw_structure_result(image, result,font_path=font_path)
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
<a name="222"></a>
#### 2.2.2 版面分析+表格识别
```python
import os
@ -92,8 +127,8 @@ im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
<a name="222"></a>
#### 2.2.2 版面分析
<a name="223"></a>
#### 2.2.3 版面分析
```python
import os
@ -113,8 +148,8 @@ for line in result:
print(line)
```
<a name="223"></a>
#### 2.2.3 表格识别
<a name="224"></a>
#### 2.2.4 表格识别
```python
import os
@ -134,8 +169,8 @@ for line in result:
print(line)
```
<a name="224"></a>
#### 2.2.4 DocVQA
<a name="225"></a>
#### 2.2.5 DocVQA
请参考:[文档视觉问答](../vqa/README.md)。
@ -156,10 +191,10 @@ PP-Structure的返回结果为一个dict组成的list示例如下
```
dict 里各个字段说明如下
| 字段 | 说明 |
| --------------- |-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|type| 图片区域的类型 |
|bbox| 图片区域的在原图的坐标,分别[左上角x左上角y右下角x右下角y] |
| 字段 | 说明|
| --- |---|
|type| 图片区域的类型 |
|bbox| 图片区域的在原图的坐标,分别[左上角x左上角y右下角x右下角y]|
|res| 图片区域的OCR或表格识别结果。<br> 表格: 一个dict字段说明如下<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `html`: 表格的HTML字符串<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; 在代码使用模式下前向传入return_ocr_result_in_table=True可以拿到表格中每个文本的检测识别结果对应为如下字段: <br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `boxes`: 文本检测坐标<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `rec_res`: 文本识别结果。<br> OCR: 一个包含各个单行文字的检测坐标和识别结果的元组 |
运行完成后,每张图片会在`output`字段指定的目录下有一个同名目录图片里的每个表格会存储为一个excel图片区域会被裁剪之后保存下来excel文件和图片名为表格在图片里的坐标。
@ -180,20 +215,26 @@ dict 里各个字段说明如下
<a name="24"></a>
### 2.4 参数说明
| 字段 | 说明 | 默认值 |
|----------------------|----------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------|
| output | excel和识别结果保存的地址 | ./output/table |
| table_max_len | 表格结构模型预测时图像的长边resize尺度 | 488 |
| table_model_dir | 表格结构模型 inference 模型地址 | None |
| table_char_dict_path | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.txt |
| layout_path_model | 版面分析模型模型地址,可以为在线地址或者本地地址,当为本地地址时,需要指定 layout_label_map, 命令行模式下可通过--layout_label_map='{0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}' 指定 | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config |
| layout_label_map | 版面分析模型模型label映射字典 | None |
| model_name_or_path | VQA SER模型地址 | None |
| max_seq_length | VQA SER模型最大支持token长度 | 512 |
| label_map_path | VQA SER 标签文件地址 | ./vqa/labels/labels_ser.txt |
| layout | 前向中是否执行版面分析 | True |
| table | 前向中是否执行表格识别 | True |
| ocr | 对于版面分析中的非表格区域是否执行ocr。当layout为False时会被自动设置为False | True |
| structure_version | 表格结构化模型版本,可选 PP-STRUCTURE。PP-STRUCTURE支持表格结构化模型 | PP-STRUCTURE |
| 字段 | 说明 | 默认值 |
|---|---|---|
| output | 结果保存地址 | ./output/table |
| table_max_len | 表格结构模型预测时图像的长边resize尺度 | 488 |
| table_model_dir | 表格结构模型 inference 模型地址| None |
| table_char_dict_path | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.txt |
| merge_no_span_structure | 表格识别模型中,是否对'\<td>'和'\</td>' 进行合并 | False |
| layout_model_dir | 版面分析模型 inference 模型地址 | None |
| layout_dict_path | 版面分析模型字典| ../ppocr/utils/dict/layout_publaynet_dict.txt |
| layout_score_threshold | 版面分析模型检测框阈值| 0.5|
| layout_nms_threshold | 版面分析模型nms阈值| 0.5|
| vqa_algorithm | vqa模型算法| LayoutXLM|
| ser_model_dir | ser模型 inference 模型地址| None|
| ser_dict_path | ser模型字典| ../train_data/XFUND/class_list_xfun.txt|
| mode | structure or vqa | structure |
| image_orientation | 前向中是否执行图像方向分类 | False |
| layout | 前向中是否执行版面分析 | True |
| table | 前向中是否执行表格识别 | True |
| ocr | 对于版面分析中的非表格区域是否执行ocr。当layout为False时会被自动设置为False| True |
| recovery | 前向中是否执行版面恢复| False |
| structure_version | 模型版本,可选 PP-structure和PP-structurev2 | PP-structure |
大部分参数和PaddleOCR whl包保持一致见 [whl包文档](../../doc/doc_ch/whl.md)

View File

@ -1,21 +1,23 @@
# PP-Structure Quick Start
- [1. Install package](#1)
- [2. Use](#2)
- [2.1 Use by command line](#21)
- [2.1.1 layout analysis + table recognition](#211)
- [2.1.2 layout analysis](#212)
- [2.1.3 table recognition](#213)
- [2.1.4 DocVQA](#214)
- [2.2 Use by code](#22)
- [2.2.1 layout analysis + table recognition](#221)
- [2.2.2 layout analysis](#222)
- [2.2.3 table recognition](#223)
- [2.2.4 DocVQA](#224)
- [2.3 Result description](#23)
- [2.3.1 layout analysis + table recognition](#231)
- [2.3.2 DocVQA](#232)
- [2.4 Parameter Description](#24)
- [1. Install package](#1-install-package)
- [2. Use](#2-use)
- [2.1 Use by command line](#21-use-by-command-line)
- [2.1.1 image orientation + layout analysis + table recognition](#211-image-orientation--layout-analysis--table-recognition)
- [2.1.2 layout analysis + table recognition](#212-layout-analysis--table-recognition)
- [2.1.3 layout analysis](#213-layout-analysis)
- [2.1.4 table recognition](#214-table-recognition)
- [2.1.5 DocVQA](#215-docvqa)
- [2.2 Use by code](#22-use-by-code)
- [2.2.1 image orientation + layout analysis + table recognition](#221-image-orientation--layout-analysis--table-recognition)
- [2.2.2 layout analysis + table recognition](#222-layout-analysis--table-recognition)
- [2.2.3 layout analysis](#223-layout-analysis)
- [2.2.4 table recognition](#224-table-recognition)
- [2.2.5 DocVQA](#225-docvqa)
- [2.3 Result description](#23-result-description)
- [2.3.1 layout analysis + table recognition](#231-layout-analysis--table-recognition)
- [2.3.2 DocVQA](#232-docvqa)
- [2.4 Parameter Description](#24-parameter-description)
<a name="1"></a>
@ -24,8 +26,6 @@
```bash
# Install paddleocr, version 2.5+ is recommended
pip3 install "paddleocr>=2.5"
# Install layoutparser (if you do not use the layout analysis, you can skip it)
pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
# Install the DocVQA dependency package paddlenlp (if you do not use the DocVQA, you can skip it)
pip install paddlenlp
@ -38,25 +38,31 @@ pip install paddlenlp
### 2.1 Use by command line
<a name="211"></a>
#### 2.1.1 layout analysis + table recognition
#### 2.1.1 image orientation + layout analysis + table recognition
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --image_orientation=true
```
<a name="212"></a>
#### 2.1.2 layout analysis + table recognition
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure
```
<a name="212"></a>
#### 2.1.2 layout analysis
<a name="213"></a>
#### 2.1.3 layout analysis
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --table=false --ocr=false
```
<a name="213"></a>
#### 2.1.3 table recognition
<a name="214"></a>
#### 2.1.4 table recognition
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structure --layout=false
```
<a name="214"></a>
#### 2.1.4 DocVQA
<a name="215"></a>
#### 2.1.5 DocVQA
Please refer to: [Documentation Visual Q&A](../vqa/README.md) .
@ -64,7 +70,36 @@ Please refer to: [Documentation Visual Q&A](../vqa/README.md) .
### 2.2 Use by code
<a name="221"></a>
#### 2.2.1 layout analysis + table recognition
#### 2.2.1 image orientation + layout analysis + table recognition
```python
import os
import cv2
from paddleocr import PPStructure,draw_structure_result,save_structure_res
table_engine = PPStructure(show_log=True, image_orientation=True)
save_folder = './output'
img_path = 'PaddleOCR/ppstructure/docs/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)
save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
for line in result:
line.pop('img')
print(line)
from PIL import Image
font_path = 'PaddleOCR/doc/fonts/simfang.ttf' # font file provided in the PaddleOCR repository
image = Image.open(img_path).convert('RGB')
im_show = draw_structure_result(image, result,font_path=font_path)
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
<a name="222"></a>
#### 2.2.2 layout analysis + table recognition
```python
import os
@ -92,8 +127,8 @@ im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
<a name="222"></a>
#### 2.2.2 layout analysis
<a name="223"></a>
#### 2.2.3 layout analysis
```python
import os
@ -113,8 +148,8 @@ for line in result:
print(line)
```
<a name="223"></a>
#### 2.2.3 table recognition
<a name="224"></a>
#### 2.2.4 table recognition
```python
import os
@ -134,8 +169,8 @@ for line in result:
print(line)
```
<a name="224"></a>
#### 2.2.4 DocVQA
<a name="225"></a>
#### 2.2.5 DocVQA
Please refer to: [Documentation Visual Q&A](../vqa/README.md) .
@ -157,8 +192,8 @@ The return of PP-Structure is a list of dicts, the example is as follows:
```
Each field in dict is described as follows:
| field | description |
| --------------- |--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| field | description |
| --- |---|
|type| Type of image area. |
|bbox| The coordinates of the image area in the original image, respectively [upper left corner x, upper left corner y, lower right corner x, lower right corner y]. |
|res| OCR or table recognition result of the image area. <br> table: a dict with field descriptions as follows: <br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `html`: html str of table.<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; In the code usage mode, set return_ocr_result_in_table=True when calling to get the detection and recognition results of each text in the table area, corresponding to the following fields: <br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `boxes`: text detection boxes.<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `rec_res`: text recognition results.<br> OCR: A tuple containing the detection boxes and recognition results of each single text. |
@ -180,19 +215,26 @@ Please refer to: [Documentation Visual Q&A](../vqa/README.md) .
<a name="24"></a>
### 2.4 Parameter Description
| field | description | default |
|----------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------|
| output | The save path of result | ./output/table |
| table_max_len | When the table structure model predicts, the long side of the image | 488 |
| table_model_dir | the path of table structure model | None |
| table_char_dict_path | the dict path of table structure model | ../ppocr/utils/dict/table_structure_dict.txt |
| layout_path_model | The model path of the layout analysis model, which can be an online address or a local path. When it is a local path, layout_label_map needs to be set. In command line mode, use --layout_label_map='{0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}' | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config |
| layout_label_map | Layout analysis model model label mapping dictionary path | None |
| model_name_or_path | the model path of VQA SER model | None |
| max_seq_length | the max token length of VQA SER model | 512 |
| label_map_path | the label path of VQA SER model | ./vqa/labels/labels_ser.txt |
| layout | Whether to perform layout analysis in forward | True |
| table | Whether to perform table recognition in forward | True |
| ocr | Whether to perform ocr for non-table areas in layout analysis. When layout is False, it will be automatically set to False | True |
| structure_version | table structure Model version number, the current model support list is as follows: PP-STRUCTURE support english table structure model | PP-STRUCTURE |
| field | description | default |
|---|---|---|
| output | result save path | ./output/table |
| table_max_len | long side of the image resize in table structure model | 488 |
| table_model_dir | Table structure model inference model path| None |
| table_char_dict_path | The dictionary path of table structure model | ../ppocr/utils/dict/table_structure_dict.txt |
| merge_no_span_structure | In the table recognition model, whether to merge '\<td>' and '\</td>' | False |
| layout_model_dir | Layout analysis model inference model path| None |
| layout_dict_path | The dictionary path of layout analysis model| ../ppocr/utils/dict/layout_publaynet_dict.txt |
| layout_score_threshold | The box score threshold of layout analysis model| 0.5|
| layout_nms_threshold | The NMS threshold of layout analysis model| 0.5|
| vqa_algorithm | vqa model algorithm| LayoutXLM|
| ser_model_dir | Ser model inference model path| None|
| ser_dict_path | The dictionary path of Ser model| ../train_data/XFUND/class_list_xfun.txt|
| mode | structure or vqa | structure |
| image_orientation | Whether to perform image orientation classification in forward | False |
| layout | Whether to perform layout analysis in forward | True |
| table | Whether to perform table recognition in forward | True |
| ocr | Whether to perform ocr for non-table areas in layout analysis. When layout is False, it will be automatically set to False| True |
| recovery | Whether to perform layout recovery in forward| False |
| structure_version | Structure version, optional PP-structure and PP-structurev2 | PP-structure |
Most of the parameters are consistent with the PaddleOCR whl package, see [whl package documentation](../../doc/doc_en/whl.md)

View File

@ -18,7 +18,7 @@ import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
@ -27,11 +27,11 @@ import numpy as np
import time
import logging
from copy import deepcopy
from attrdict import AttrDict
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem
from ppstructure.layout.predict_layout import LayoutPredictor
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args, draw_structure_result
from ppstructure.recovery.recovery_to_doc import convert_info_docx
@ -42,6 +42,14 @@ logger = get_logger()
class StructureSystem(object):
def __init__(self, args):
self.mode = args.mode
self.recovery = args.recovery
self.image_orientation_predictor = None
if args.image_orientation:
import paddleclas
self.image_orientation_predictor = paddleclas.PaddleClas(
model_name="text_image_orientation")
if self.mode == 'structure':
if not args.show_log:
logger.setLevel(logging.INFO)
@ -51,28 +59,14 @@ class StructureSystem(object):
"When args.layout is false, args.ocr is automatically set to false"
)
args.drop_score = 0
# init layout and ocr model
# init model
self.layout_predictor = None
self.text_system = None
self.table_system = None
if args.layout:
import layoutparser as lp
config_path = None
model_path = None
if os.path.isdir(args.layout_path_model):
model_path = args.layout_path_model
else:
config_path = args.layout_path_model
self.table_layout = lp.PaddleDetectionLayoutModel(
config_path=config_path,
model_path=model_path,
label_map=args.layout_label_map,
threshold=0.5,
enable_mkldnn=args.enable_mkldnn,
enforce_cpu=not args.use_gpu,
thread_num=args.cpu_threads)
self.layout_predictor = LayoutPredictor(args)
if args.ocr:
self.text_system = TextSystem(args)
else:
self.table_layout = None
if args.table:
if self.text_system is not None:
self.table_system = TableSystem(
@ -80,39 +74,78 @@ class StructureSystem(object):
self.text_system.text_recognizer)
else:
self.table_system = TableSystem(args)
else:
self.table_system = None
elif self.mode == 'vqa':
raise NotImplementedError
def __call__(self, img, return_ocr_result_in_table=False):
time_dict = {
'image_orientation': 0,
'layout': 0,
'table': 0,
'table_match': 0,
'det': 0,
'rec': 0,
'vqa': 0,
'all': 0
}
start = time.time()
if self.image_orientation_predictor is not None:
tic = time.time()
cls_result = self.image_orientation_predictor.predict(
input_data=img)
cls_res = next(cls_result)
angle = cls_res[0]['label_names'][0]
cv_rotate_code = {
'90': cv2.ROTATE_90_COUNTERCLOCKWISE,
'180': cv2.ROTATE_180,
'270': cv2.ROTATE_90_CLOCKWISE
}
img = cv2.rotate(img, cv_rotate_code[angle])
toc = time.time()
time_dict['image_orientation'] = toc - tic
if self.mode == 'structure':
ori_im = img.copy()
if self.table_layout is not None:
layout_res = self.table_layout.detect(img[..., ::-1])
if self.layout_predictor is not None:
layout_res, elapse = self.layout_predictor(img)
time_dict['layout'] += elapse
else:
h, w = ori_im.shape[:2]
layout_res = [AttrDict(coordinates=[0, 0, w, h], type='Table')]
layout_res = [dict(bbox=None, label='table')]
res_list = []
for region in layout_res:
res = ''
x1, y1, x2, y2 = region.coordinates
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
roi_img = ori_im[y1:y2, x1:x2, :]
if region.type == 'Table':
if region['bbox'] is not None:
x1, y1, x2, y2 = region['bbox']
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
roi_img = ori_im[y1:y2, x1:x2, :]
else:
x1, y1, x2, y2 = 0, 0, w, h
roi_img = ori_im
if region['label'] == 'table':
if self.table_system is not None:
res = self.table_system(roi_img,
return_ocr_result_in_table)
res, table_time_dict = self.table_system(
roi_img, return_ocr_result_in_table)
time_dict['table'] += table_time_dict['table']
time_dict['table_match'] += table_time_dict['match']
time_dict['det'] += table_time_dict['det']
time_dict['rec'] += table_time_dict['rec']
else:
if self.text_system is not None:
if args.recovery:
if self.recovery:
wht_im = np.ones(ori_im.shape, dtype=ori_im.dtype)
wht_im[y1:y2, x1:x2, :] = roi_img
filter_boxes, filter_rec_res = self.text_system(wht_im)
filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
wht_im)
else:
filter_boxes, filter_rec_res = self.text_system(roi_img)
# remove style char
filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
roi_img)
time_dict['det'] += ocr_time_dict['det']
time_dict['rec'] += ocr_time_dict['rec']
# remove style char,
# when using the recognition model trained on the PubtabNet dataset,
# it will recognize the text format in the table, such as <b>
style_token = [
'<strike>', '<strike>', '<sup>', '</sub>', '<b>',
'</b>', '<sub>', '</sup>', '<overline>',
@ -125,7 +158,7 @@ class StructureSystem(object):
for token in style_token:
if token in rec_str:
rec_str = rec_str.replace(token, '')
if not args.recovery:
if not self.recovery:
box += [x1, y1]
res.append({
'text': rec_str,
@ -133,15 +166,17 @@ class StructureSystem(object):
'text_region': box.tolist()
})
res_list.append({
'type': region.type,
'type': region['label'].lower(),
'bbox': [x1, y1, x2, y2],
'img': roi_img,
'res': res
})
return res_list
end = time.time()
time_dict['all'] = end - start
return res_list, time_dict
elif self.mode == 'vqa':
raise NotImplementedError
return None
return None, None
def save_structure_res(res, save_folder, img_name):
@ -156,12 +191,12 @@ def save_structure_res(res, save_folder, img_name):
roi_img = region.pop('img')
f.write('{}\n'.format(json.dumps(region)))
if region['type'] == 'Table' and len(region[
if region['type'] == 'table' and len(region[
'res']) > 0 and 'html' in region['res']:
excel_path = os.path.join(excel_save_folder,
'{}.xlsx'.format(region['bbox']))
to_excel(region['res']['html'], excel_path)
elif region['type'] == 'Figure':
elif region['type'] == 'figure':
img_path = os.path.join(excel_save_folder,
'{}.jpg'.format(region['bbox']))
cv2.imwrite(img_path, roi_img)
@ -187,8 +222,7 @@ def main(args):
if img is None:
logger.error("error in loading image:{}".format(image_file))
continue
starttime = time.time()
res = structure_sys(img)
res, time_dict = structure_sys(img)
if structure_sys.mode == 'structure':
save_structure_res(res, save_folder, img_name)
@ -201,9 +235,8 @@ def main(args):
cv2.imwrite(img_save_path, draw_img)
logger.info('result save to {}'.format(img_save_path))
if args.recovery:
convert_info_docx(img, res, save_folder, img_name)
elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse))
convert_info_docx(img, res, save_folder, img_name)
logger.info("Predict time : {:.3f}s".format(time_dict['all']))
if __name__ == "__main__":

View File

@ -1,126 +1,124 @@
- [Table Recognition](#table-recognition)
- [1. pipeline](#1-pipeline)
- [2. Performance](#2-performance)
- [3. How to use](#3-how-to-use)
- [3.1 quick start](#31-quick-start)
- [3.2 Train](#32-train)
- [3.3 Eval](#33-eval)
- [3.4 Inference](#34-inference)
English | [简体中文](README_ch.md)
# Table Recognition
- [1. pipeline](#1-pipeline)
- [2. Performance](#2-performance)
- [3. Result](#3-result)
- [4. How to use](#4-how-to-use)
- [4.1 Quick start](#41-quick-start)
- [4.2 Train](#42-train)
- [4.3 Calculate TEDS](#43-calculate-teds)
- [5. Reference](#5-reference)
## 1. pipeline
The table recognition mainly contains three models
1. Single line text detection-DB
2. Single line text recognition-CRNN
3. Table structure and cell coordinate prediction-RARE
3. Table structure and cell coordinate prediction-SLANet
The table recognition flow chart is as follows
![tableocr_pipeline](../docs/table/tableocr_pipeline_en.jpg)
1. The coordinates of each single-line text are detected by the DB model and then sent to the recognition model to get the recognition result.
2. The table structure and cell coordinates is predicted by RARE model.
2. The table structure and cell coordinates are predicted by the SLANet model.
3. The recognition result of the cell is combined by the coordinates, recognition result of the single line and the coordinates of the cell.
4. The cell recognition result and the table structure together construct the html string of the table.
## 2. Performance
We evaluated the algorithm on the PubTabNet<sup>[1]</sup> eval dataset, and the performance is as follows:
|Method|Acc|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|Speed|
| --- | --- | --- | ---|
| EDD<sup>[2]</sup> |x| 88.3 |x|
| TableRec-RARE(ours) |73.8%| 93.32 |1550ms|
| SLANet(ours) | 76.2%| 94.98 |766ms|
|Method|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|
| --- | --- |
| EDD<sup>[2]</sup> | 88.3 |
| Ours | 93.32 |
The performance indicators are explained as follows:
- Acc: The accuracy of the table structure in each image, a wrong token is considered an error.
- TEDS: The accuracy of the model's restoration of table information. This indicator evaluates not only the table structure, but also the text content in the table.
- Speed: The inference speed of a single image when the model runs on the CPU machine and MKL is enabled.
## 3. How to use
## 3. Result
### 3.1 quick start
![](../docs/imgs/table_ch_result1.jpg)
![](../docs/imgs/table_ch_result2.jpg)
![](../docs/imgs/table_ch_result3.jpg)
## 4. How to use
### 4.1 Quick start
Use the following commands to quickly complete the identification of a table.
```python
cd PaddleOCR/ppstructure
# download model
mkdir inference && cd inference
# Download the detection model of the ultra-lightweight table English OCR model and unzip it
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar
# Download the recognition model of the ultra-lightweight table English OCR model and unzip it
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar
# Download the ultra-lightweight English table inch model and unzip it
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
# Download the PP-OCRv3 text detection model and unzip it
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar && tar xf ch_PP-OCRv3_det_slim_infer.tar
# Download the PP-OCRv3 text recognition model and unzip it
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar && tar xf ch_PP-OCRv3_rec_slim_infer.tar
# Download the PP-Structurev2 table recognition model and unzip it
wget https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar
cd ..
# run
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=./docs/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ./output/table
```
Note: The above model is trained on the PubLayNet dataset and only supports English scanning scenarios. If you need to identify other scenarios, you need to train the model yourself and replace the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`.
python3.7 table/predict_table.py \
--det_model_dir=inference/ch_PP-OCRv3_det_slim_infer \
--rec_model_dir=inference/ch_PP-OCRv3_rec_slim_infer \
--table_model_dir=inference/ch_ppstructure_mobile_v2.0_SLANet_infer \
--rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict_ch.txt \
--image_dir=docs/table/table.jpg \
--output=../output/table
After running, the excel sheet of each picture will be saved in the directory specified by the output field
### 3.2 Train
In this chapter, we only introduce the training of the table structure model, For model training of [text detection](../../doc/doc_en/detection_en.md) and [text recognition](../../doc/doc_en/recognition_en.md), please refer to the corresponding documents
* data preparation
The training data uses public data set [PubTabNet](https://arxiv.org/abs/1911.10683 ), Can be downloaded from the official [website](https://github.com/ibm-aur-nlp/PubTabNet) 。The PubTabNet data set contains about 500,000 images, as well as annotations in html format。
* Start training
*If you are installing the cpu version of paddle, please modify the `use_gpu` field in the configuration file to false*
```shell
# single GPU training
python3 tools/train.py -c configs/table/table_mv3.yml
# multi-GPU training
# Set the GPU ID used by the '--gpus' parameter.
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/table/table_mv3.yml
```
In the above instruction, use `-c` to select the training to use the `configs/table/table_mv3.yml` configuration file.
For a detailed explanation of the configuration file, please refer to [config](../../doc/doc_en/config_en.md).
After the operation is completed, the excel table of each image will be saved to the directory specified by the output field, and an html file will be produced in the directory to visually view the cell coordinates and the recognized table.
* load trained model and continue training
### 4.2 Train
If you expect to load trained model and continue the training again, you can specify the parameter `Global.checkpoints` as the model path to be loaded.
The training, evaluation and inference process of the text detection model can be referred to [detection](../../doc/doc_en/detection_en.md)
```shell
python3 tools/train.py -c configs/table/table_mv3.yml -o Global.checkpoints=./your/trained/model
```
The training, evaluation and inference process of the text recognition model can be referred to [recognition](../../doc/doc_en/recognition_en.md)
**Note**: The priority of `Global.checkpoints` is higher than that of `Global.pretrain_weights`, that is, when two parameters are specified at the same time, the model specified by `Global.checkpoints` will be loaded first. If the model path specified by `Global.checkpoints` is wrong, the one specified by `Global.pretrain_weights` will be loaded.
The training, evaluation and inference process of the table recognition model can be referred to [table_recognition](../../doc/doc_en/table_recognition_en.md)
### 3.3 Eval
### 4.3 Calculate TEDS
The table uses [TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src) as the evaluation metric of the model. Before the model evaluation, the three models in the pipeline need to be exported as inference models (we have provided them), and the gt for evaluation needs to be prepared. Examples of gt are as follows:
```json
{"PMC4289340_004_00.png": [
["<html>", "<body>", "<table>", "<thead>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</thead>", "<tbody>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</tbody>", "</table>", "</body>", "</html>"],
[[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]],
[["<b>", "F", "e", "a", "t", "u", "r", "e", "</b>"], ["<b>", "G", "b", "3", " ", "+", "</b>"], ["<b>", "G", "b", "3", " ", "-", "</b>"], ["<b>", "P", "a", "t", "i", "e", "n", "t", "s", "</b>"], ["6", "2"], ["4", "5"]]
]}
```txt
PMC5755158_010_01.png <html><body><table><thead><tr><td></td><td><b>Weaning</b></td><td><b>Week 15</b></td><td><b>Off-test</b></td></tr></thead><tbody><tr><td>Weaning</td><td></td><td></td><td></td></tr><tr><td>Week 15</td><td></td><td>0.17 ± 0.08</td><td>0.16 ± 0.03</td></tr><tr><td>Off-test</td><td></td><td>0.80 ± 0.24</td><td>0.19 ± 0.09</td></tr></tbody></table></body></html>
```
Each line in gt consists of the file name and the html string of the table. The file name and the html string of the table are separated by `\t`.
You can also use the following command to generate an evaluation gt file from the annotation file:
```python
python3 ppstructure/table/convert_label2html.py --ori_gt_path /path/to/your_label_file --save_path /path/to/save_file
```
In gt json, the key is the image name, the value is the corresponding gt, and gt is a list composed of four items, and each item is
1. HTML string list of table structure
2. The coordinates of each cell (not including the empty text in the cell)
3. The text information in each cell (not including the empty text in the cell)
Use the following command to evaluate. After the evaluation is completed, the teds indicator will be output.
```python
cd PaddleOCR/ppstructure
python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
python3 table/eval_table.py \
--det_model_dir=path/to/det_model_dir \
--rec_model_dir=path/to/rec_model_dir \
--table_model_dir=path/to/table_model_dir \
--image_dir=../doc/table/1.png \
--rec_char_dict_path=../ppocr/utils/dict/table_dict.txt \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--det_limit_side_len=736 \
--det_limit_type=min \
--gt_path=path/to/gt.txt
```
If the PubTabNet eval dataset is used, the output will be
```bash
teds: 93.32
teds: 94.98
```
### 3.4 Inference
```python
cd PaddleOCR/ppstructure
python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
After running, the excel sheet of each picture will be saved in the directory specified by the output field
Reference
## 5. Reference
1. https://github.com/ibm-aur-nlp/PubTabNet
2. https://arxiv.org/pdf/1911.10683

View File

@ -2,22 +2,22 @@
# 表格识别
- [1. 表格识别 pipeline](#1)
- [2. 性能](#2)
- [3. 使用](#3)
- [3.1 快速开始](#31)
- [3.2 训练](#32)
- [3.3 评估](#33)
- [3.4 预测](#34)
- [1. 表格识别 pipeline](#1-表格识别-pipeline)
- [2. 性能](#2-性能)
- [3. 效果演示](#3-效果演示)
- [4. 使用](#4-使用)
- [4.1 快速开始](#41-快速开始)
- [4.2 训练](#42-训练)
- [4.3 计算TEDS](#43-计算teds)
- [5. Reference](#5-reference)
<a name="1"></a>
## 1. 表格识别 pipeline
表格识别主要包含三个模型
1. 单行文本检测-DB
2. 单行文本识别-CRNN
3. 表格结构和cell坐标预测-RARE
3. 表格结构和cell坐标预测-SLANet
具体流程图如下
@ -26,111 +26,102 @@
流程说明:
1. 图片由单行文字检测模型检测到单行文字的坐标,然后送入识别模型拿到识别结果。
2. 图片由表格结构和cell坐标预测模型拿到表格的结构信息和单元格的坐标信息。
2. 图片由SLANet模型拿到表格的结构信息和单元格的坐标信息。
3. 由单行文字的坐标、识别结果和单元格的坐标一起组合出单元格的识别结果。
4. 单元格的识别结果和表格结构一起构造表格的html字符串。
<a name="2"></a>
## 2. 性能
我们在 PubTabNet<sup>[1]</sup> 评估数据集上对算法进行了评估,性能如下
|算法|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|
| --- | --- |
| EDD<sup>[2]</sup> | 88.3 |
| Ours | 93.32 |
|算法|Acc|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|Speed|
| --- | --- | --- | ---|
| EDD<sup>[2]</sup> |x| 88.3 |x|
| TableRec-RARE(ours) |73.8%| 93.32 |1550ms|
| SLANet(ours) | 76.2%| 94.98 |766ms|
<a name="3"></a>
## 3. 使用
性能指标解释如下:
- Acc: 模型对每张图像里表格结构的识别准确率错一个token就算错误。
- TEDS: 模型对表格信息还原的准确度,此指标评价内容不仅包含表格结构,还包含表格内的文字内容。
- Speed: 模型在CPU机器上开启MKL的情况下单张图片的推理速度。
<a name="31"></a>
### 3.1 快速开始
## 3. 效果演示
![](../docs/imgs/table_ch_result1.jpg)
![](../docs/imgs/table_ch_result2.jpg)
![](../docs/imgs/table_ch_result3.jpg)
## 4. 使用
### 4.1 快速开始
使用如下命令即可快速完成一张表格的识别。
```python
cd PaddleOCR/ppstructure
# 下载模型
mkdir inference && cd inference
# 下载超轻量级表格英文OCR模型的检测模型并解压
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar
# 下载超轻量级表格英文OCR模型的识别模型并解压
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar
# 下载超轻量级英文表格英寸模型并解压
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
# 下载PP-OCRv3文本检测模型并解压
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar && tar xf ch_PP-OCRv3_det_slim_infer.tar
# 下载PP-OCRv3文本识别模型并解压
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar && tar xf ch_PP-OCRv3_rec_slim_infer.tar
# 下载PP-Structurev2表格识别模型并解压
wget https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar
cd ..
# 执行预测
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=./docs/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ./output/table
# 执行表格识别
python table/predict_table.py \
--det_model_dir=inference/ch_PP-OCRv3_det_slim_infer \
--rec_model_dir=inference/ch_PP-OCRv3_rec_slim_infer \
--table_model_dir=inference/ch_ppstructure_mobile_v2.0_SLANet_infer \
--rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict_ch.txt \
--image_dir=docs/table/table.jpg \
--output=../output/table
```
运行完成后每张图片的excel表格会保存到output字段指定的目录下
运行完成后,每张图片的excel表格会保存到output字段指定的目录下,同时在该目录下会生成一个html文件,用于可视化查看单元格坐标和识别的表格。
note: 上述模型是在 PubLayNet 数据集上训练的表格识别模型,仅支持英文扫描场景,如需识别其他场景需要自己训练模型后替换 `det_model_dir`,`rec_model_dir`,`table_model_dir`三个字段即可。
### 4.2 训练
<a name="32"></a>
### 3.2 训练
文本检测模型的训练、评估和推理流程可参考 [detection](../../doc/doc_ch/detection.md)
在这一章节中,我们仅介绍表格结构模型的训练,[文字检测](../../doc/doc_ch/detection.md)和[文字识别](../../doc/doc_ch/recognition.md)的模型训练请参考对应的文档。
文本识别模型的训练、评估和推理流程可参考 [recognition](../../doc/doc_ch/recognition.md)
* 数据准备
表格识别模型的训练、评估和推理流程可参考 [table_recognition](../../doc/doc_ch/table_recognition.md)
训练数据使用公开数据集PubTabNet ([论文](https://arxiv.org/abs/1911.10683)[下载地址](https://github.com/ibm-aur-nlp/PubTabNet))。PubTabNet数据集包含约50万张表格数据的图像以及图像对应的html格式的注释。
* 启动训练
*如果您安装的是cpu版本请将配置文件中的 `use_gpu` 字段修改为false*
```shell
# 单机单卡训练
python3 tools/train.py -c configs/table/table_mv3.yml
# 单机多卡训练,通过 --gpus 参数设置使用的GPU ID
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/table/table_mv3.yml
```
上述指令中,通过-c 选择训练使用configs/table/table_mv3.yml配置文件。有关配置文件的详细解释请参考[链接](../../doc/doc_ch/config.md)。
* 断点训练
如果训练程序中断如果希望加载训练中断的模型从而恢复训练可以通过指定Global.checkpoints指定要加载的模型路径
```shell
python3 tools/train.py -c configs/table/table_mv3.yml -o Global.checkpoints=./your/trained/model
```
**注意**`Global.checkpoints`的优先级高于`Global.pretrain_weights`的优先级,即同时指定两个参数时,优先加载`Global.checkpoints`指定的模型,如果`Global.checkpoints`指定的模型路径有误,会加载`Global.pretrain_weights`指定的模型。
<a name="33"></a>
### 3.3 评估
### 4.3 计算TEDS
表格使用 [TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src) 作为模型的评估指标。在进行模型评估之前需要将pipeline中的三个模型分别导出为inference模型(我们已经提供好)还需要准备评估的gt gt示例如下:
```json
{"PMC4289340_004_00.png": [
["<html>", "<body>", "<table>", "<thead>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</thead>", "<tbody>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</tbody>", "</table>", "</body>", "</html>"],
[[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]],
[["<b>", "F", "e", "a", "t", "u", "r", "e", "</b>"], ["<b>", "G", "b", "3", " ", "+", "</b>"], ["<b>", "G", "b", "3", " ", "-", "</b>"], ["<b>", "P", "a", "t", "i", "e", "n", "t", "s", "</b>"], ["6", "2"], ["4", "5"]]
]}
```txt
PMC5755158_010_01.png <html><body><table><thead><tr><td></td><td><b>Weaning</b></td><td><b>Week 15</b></td><td><b>Off-test</b></td></tr></thead><tbody><tr><td>Weaning</td><td></td><td></td><td></td></tr><tr><td>Week 15</td><td></td><td>0.17 ± 0.08</td><td>0.16 ± 0.03</td></tr><tr><td>Off-test</td><td></td><td>0.80 ± 0.24</td><td>0.19 ± 0.09</td></tr></tbody></table></body></html>
```
gt每一行都由文件名和表格的html字符串组成文件名和表格的html字符串之间使用`\t`分隔。
也可使用如下命令由标注文件生成评估的gt文件
```python
python3 ppstructure/table/convert_label2html.py --ori_gt_path /path/to/your_label_file --save_path /path/to/save_file
```
json 中key为图片名value为对应的gtgt是一个由三个item组成的list每个item分别为
1. 表格结构的html字符串list
2. 每个cell的坐标 (不包括cell里文字为空的)
3. 每个cell里的文字信息 (不包括cell里文字为空的)
准备完成后使用如下命令进行评估评估完成后会输出teds指标。
```python
cd PaddleOCR/ppstructure
python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
python3 table/eval_table.py \
--det_model_dir=path/to/det_model_dir \
--rec_model_dir=path/to/rec_model_dir \
--table_model_dir=path/to/table_model_dir \
--image_dir=../doc/table/1.png \
--rec_char_dict_path=../ppocr/utils/dict/table_dict.txt \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--det_limit_side_len=736 \
--det_limit_type=min \
--gt_path=path/to/gt.txt
```
如使用PubTabNet评估数据集,将会输出
```bash
teds: 93.32
teds: 94.98
```
<a name="34"></a>
### 3.4 预测
```python
cd PaddleOCR/ppstructure
python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
# Reference
## 5. Reference
1. https://github.com/ibm-aur-nlp/PubTabNet
2. https://arxiv.org/pdf/1911.10683

View File

@ -0,0 +1,102 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
convert table label to html
"""
import json
import argparse
from tqdm import tqdm
def save_pred_txt(key, val, tmp_file_path):
    """
    Append one ``key<TAB>val`` record, terminated by a newline, to a text file.

    The file is opened in append mode ('a+') with UTF-8 encoding, so repeated
    calls accumulate one line per call.
    @param key: value written before the tab
    @param val: value written after the tab
    @param tmp_file_path: path of the file to append to
    @return: None
    """
    with open(tmp_file_path, 'a+', encoding='utf-8') as out_file:
        record = '{}\t{}\n'.format(key, val)
        out_file.write(record)
def skip_char(text, sp_char_list):
    """
    Strip every listed substring from the text.

    The substrings are removed one after another, in list order, each with a
    single ``str.replace`` pass over the current result.
    @param text: text in cell
    @param sp_char_list: style chars and special codes to remove
    @return: the text with all listed substrings removed
    """
    cleaned = text
    for token in sp_char_list:
        cleaned = cleaned.replace(token, '')
    return cleaned
def gen_html(img):
    '''
    Build the full HTML string of a table from its tokenized annotation.

    The structure tokens are copied (the input is not modified). Cell texts
    are inserted right after each cell-opening token ('<td>' or the closing
    '>' of an attributed '<td ...'), pairing cells and positions from the end
    backwards so earlier insert positions stay valid. Cells without tokens,
    or whose text is empty once style tags and blanks are removed, are left
    empty.
    @param img: annotation dict with img['html']['structure']['tokens'] and
                img['html']['cells'] (each cell has a 'tokens' list)
    @return: '<html><body><table>...</table></body></html>' string
    '''
    # style tags / blank characters ignored when deciding whether a cell is empty
    sp_char_list = ['<b>', '</b>', '\u2028', ' ', '<i>', '</i>']

    tokens = img['html']['structure']['tokens'].copy()
    cell_starts = []
    for idx, tag in enumerate(tokens):
        if tag in ('<td>', '>'):
            cell_starts.append(idx)

    cells = img['html']['cells']
    for pos, cell in zip(cell_starts[::-1], cells[::-1]):
        if not cell['tokens']:
            continue
        text = ''.join(cell['tokens'])
        # skip empty text
        if len(skip_char(text, sp_char_list)) != 0:
            tokens.insert(pos + 1, text)

    body = ''.join(tokens)
    return '<html><body><table>{}</table></body></html>'.format(body)
def load_gt_data(gt_path):
    """
    Load a JSON-lines label file and index its records by file name.

    Each line is decoded as UTF-8, stripped of its trailing newline and parsed
    as one JSON object; the object is stored under its 'filename' value.
    @param gt_path: path of the label file (one JSON object per line)
    @return: dict mapping info['filename'] to the whole decoded record
    """
    with open(gt_path, 'rb') as label_file:
        raw_lines = label_file.readlines()

    records = {}
    for raw in tqdm(raw_lines):
        record = json.loads(raw.decode('utf-8').strip("\n"))
        records[record['filename']] = record
    return records
def convert(origin_gt_path, save_path):
    """
    Generate an html gt file from a label file.

    Loads every record of the label file, renders its table as HTML and
    appends one ``image_name<TAB>html`` line per record to the output file.
    @param origin_gt_path: path of the original label file
    @param save_path: path of the text file the html lines are appended to
    @return: None
    """
    labels = load_gt_data(origin_gt_path)
    for name, label in tqdm(labels.items()):
        save_pred_txt(name, gen_html(label), save_path)
    print('conver finish')
def parse_args():
    """
    Parse the command-line options of the label conversion script.

    Both options are required strings:
      --ori_gt_path  label gt path
      --save_path    path to save file
    @return: argparse.Namespace with ``ori_gt_path`` and ``save_path``
    """
    arg_parser = argparse.ArgumentParser(description="args for paddleserving")
    required_options = (
        ("--ori_gt_path", "label gt path"),
        ("--save_path", "path to save file"),
    )
    for flag, help_text in required_options:
        arg_parser.add_argument(flag, type=str, required=True, help=help_text)
    return arg_parser.parse_args()
if __name__ == '__main__':
args = parse_args()
convert(args.ori_gt_path, args.save_path)

View File

@ -1,4 +1,4 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
import cv2
import json
import pickle
import paddle
from tqdm import tqdm
from ppstructure.table.table_metric import TEDS
from ppstructure.table.predict_table import TableSystem
@ -33,40 +36,74 @@ def parse_args():
parser.add_argument("--gt_path", type=str)
return parser.parse_args()
def main(gt_path, img_root, args):
teds = TEDS(n_jobs=16)
def load_txt(txt_path):
    """
    Load a ``image_name<TAB>html`` text file into a dict.

    Returns an empty dict when the file does not exist. Each line is stripped
    of surrounding whitespace and split at the FIRST tab only, so:
      * blank lines are ignored instead of raising ValueError,
      * a record whose html part is empty maps to '' (stripping removes the
        trailing tab, which previously broke the two-value unpacking),
      * tabs inside the html string are kept as part of the html.
    When an image name occurs more than once, the last line wins.
    @param txt_path: path of the UTF-8 text file
    @return: dict mapping image name to html string
    """
    pred_html_dict = {}
    if not os.path.exists(txt_path):
        return pred_html_dict
    with open(txt_path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # tolerate empty lines, e.g. a trailing newline at end of file
                continue
            # partition never raises: a missing separator yields an empty html
            img_name, _, pred_html = line.partition('\t')
            pred_html_dict[img_name] = pred_html
    return pred_html_dict
def load_result(path):
    """
    Load a pickled result cache from disk.

    @param path: path of the pickle file
    @return: the unpickled object, or an empty dict when the file does not exist
    """
    data = {}
    if os.path.exists(path):
        # NOTE: pickle.load can execute arbitrary code; only load cache files
        # that were produced by this tool, never files from untrusted sources.
        # The context manager closes the handle, which a bare open() left to
        # the garbage collector.
        with open(path, 'rb') as f:
            data = pickle.load(f)
    return data
def save_result(path, data):
    """
    Merge new entries into the pickled result cache at ``path``.

    The existing cache (an empty dict when the file is missing) is loaded,
    updated with ``data`` (new values override existing keys) and written back
    in full.
    @param path: path of the pickle file
    @param data: dict of entries to add or overwrite
    @return: None
    """
    merged = load_result(path)
    merged.update(data)
    with open(path, 'wb') as cache_file:
        pickle.dump(merged, cache_file)
def main(gt_path, img_root, args):
os.makedirs(args.output, exist_ok=True)
# init TableSystem
text_sys = TableSystem(args)
jsons_gt = json.load(open(gt_path)) # gt
# load gt and preds html result
gt_html_dict = load_txt(gt_path)
ocr_result = load_result(os.path.join(args.output, 'ocr.pickle'))
structure_result = load_result(
os.path.join(args.output, 'structure.pickle'))
pred_htmls = []
gt_htmls = []
for img_name in tqdm(jsons_gt):
# read image
img = cv2.imread(os.path.join(img_root,img_name))
pred_html = text_sys(img)
for img_name, gt_html in tqdm(gt_html_dict.items()):
img = cv2.imread(os.path.join(img_root, img_name))
# run ocr and save result
if img_name not in ocr_result:
dt_boxes, rec_res, _, _ = text_sys._ocr(img)
ocr_result[img_name] = [dt_boxes, rec_res]
save_result(os.path.join(args.output, 'ocr.pickle'), ocr_result)
# run structure and save result
if img_name not in structure_result:
structure_res, _ = text_sys._structure(img)
structure_result[img_name] = structure_res
save_result(
os.path.join(args.output, 'structure.pickle'), structure_result)
dt_boxes, rec_res = ocr_result[img_name]
structure_res = structure_result[img_name]
# match ocr and structure
pred_html = text_sys.match(structure_res, dt_boxes, rec_res)
pred_htmls.append(pred_html)
gt_structures, gt_bboxes, gt_contents = jsons_gt[img_name]
gt_html, gt = get_gt_html(gt_structures, gt_contents)
gt_htmls.append(gt_html)
# compute teds
teds = TEDS(n_jobs=16)
scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)
logger.info('teds:', sum(scores) / len(scores))
def get_gt_html(gt_structures, gt_contents):
    """
    Rebuild the ground-truth html by inserting cell contents before each
    structure token that contains '</td>'.
    :param gt_structures: sequence of structure tokens (strings).
    :param gt_contents: per-cell token lists, indexed in the order cells close.
    :return: (joined html string, list of html tokens)
    """
    html_tokens = []
    cell_idx = 0
    for tag in gt_structures:
        if '</td>' in tag:
            # content goes in front of the closing tag of its cell
            if gt_contents[cell_idx] != []:
                html_tokens.extend(gt_contents[cell_idx])
            cell_idx += 1
        html_tokens.append(tag)
    return ''.join(html_tokens), html_tokens
logger.info('teds: {}'.format(sum(scores) / len(scores)))
if __name__ == '__main__':
args = parse_args()
main(args.gt_path,args.image_dir, args)
main(args.gt_path, args.image_dir, args)

View File

@ -1,11 +1,29 @@
import json
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from ppstructure.table.table_master_match import deal_eb_token, deal_bb
def distance(box_1, box_2):
x1, y1, x2, y2 = box_1
x3, y3, x4, y4 = box_2
dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2)
dis_2 = abs(x3 - x1) + abs(y3 - y1)
dis_3 = abs(x4- x2) + abs(y4 - y2)
return dis + min(dis_2, dis_3)
x1, y1, x2, y2 = box_1
x3, y3, x4, y4 = box_2
dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
dis_2 = abs(x3 - x1) + abs(y3 - y1)
dis_3 = abs(x4 - x2) + abs(y4 - y2)
return dis + min(dis_2, dis_3)
def compute_iou(rec1, rec2):
"""
@ -18,175 +36,157 @@ def compute_iou(rec1, rec2):
# computing area of each rectangles
S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
# computing the sum_area
sum_area = S_rec1 + S_rec2
# find the each edge of intersect rectangle
left_line = max(rec1[1], rec2[1])
right_line = min(rec1[3], rec2[3])
top_line = max(rec1[0], rec2[0])
bottom_line = min(rec1[2], rec2[2])
# judge if there is an intersect
if left_line >= right_line or top_line >= bottom_line:
return 0.0
else:
intersect = (right_line - left_line) * (bottom_line - top_line)
return (intersect / (sum_area - intersect))*1.0
return (intersect / (sum_area - intersect)) * 1.0
def matcher_merge(ocr_bboxes, pred_bboxes):
all_dis = []
ious = []
matched = {}
for i, gt_box in enumerate(ocr_bboxes):
distances = []
for j, pred_box in enumerate(pred_bboxes):
# compute l1 distence and IOU between two boxes
distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box)))
sorted_distances = distances.copy()
# select nearest cell
sorted_distances = sorted(sorted_distances, key = lambda item: (item[1], item[0]))
if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i]
class TableMatch:
def __init__(self, filter_ocr_result=False, use_master=False):
self.filter_ocr_result = filter_ocr_result
self.use_master = use_master
def __call__(self, structure_res, dt_boxes, rec_res):
pred_structures, pred_bboxes = structure_res
if self.filter_ocr_result:
dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes,
rec_res)
matched_index = self.match_result(dt_boxes, pred_bboxes)
if self.use_master:
pred_html, pred = self.get_pred_html_master(pred_structures,
matched_index, rec_res)
else:
matched[distances.index(sorted_distances[0])].append(i)
return matched#, sum(ious) / len(ious)
pred_html, pred = self.get_pred_html(pred_structures, matched_index,
rec_res)
return pred_html
def complex_num(pred_bboxes):
complex_nums = []
for bbox in pred_bboxes:
distances = []
temp_ious = []
for pred_bbox in pred_bboxes:
if bbox != pred_bbox:
distances.append(distance(bbox, pred_bbox))
temp_ious.append(compute_iou(bbox, pred_bbox))
complex_nums.append(temp_ious[distances.index(min(distances))])
return sum(complex_nums) / len(complex_nums)
def get_rows(pred_bboxes):
pre_bbox = pred_bboxes[0]
res = []
step = 0
for i in range(len(pred_bboxes)):
bbox = pred_bboxes[i]
if bbox[1] - pre_bbox[1] > 2 or bbox[0] - pre_bbox[0] < 0:
break
else:
res.append(bbox)
step += 1
for i in range(step):
pred_bboxes.pop(0)
return res, pred_bboxes
def refine_rows(pred_bboxes): # 微调整行的框,使在一条水平线上
ys_1 = []
ys_2 = []
for box in pred_bboxes:
ys_1.append(box[1])
ys_2.append(box[3])
min_y_1 = sum(ys_1) / len(ys_1)
min_y_2 = sum(ys_2) / len(ys_2)
re_boxes = []
for box in pred_bboxes:
box[1] = min_y_1
box[3] = min_y_2
re_boxes.append(box)
return re_boxes
def matcher_refine_row(gt_bboxes, pred_bboxes):
before_refine_pred_bboxes = pred_bboxes.copy()
pred_bboxes = []
while(len(before_refine_pred_bboxes) != 0):
row_bboxes, before_refine_pred_bboxes = get_rows(before_refine_pred_bboxes)
print(row_bboxes)
pred_bboxes.extend(refine_rows(row_bboxes))
all_dis = []
ious = []
matched = {}
for i, gt_box in enumerate(gt_bboxes):
distances = []
#temp_ious = []
for j, pred_box in enumerate(pred_bboxes):
distances.append(distance(gt_box, pred_box))
#temp_ious.append(compute_iou(gt_box, pred_box))
#all_dis.append(min(distances))
#ious.append(temp_ious[distances.index(min(distances))])
if distances.index(min(distances)) not in matched.keys():
matched[distances.index(min(distances))] = [i]
else:
matched[distances.index(min(distances))].append(i)
return matched#, sum(ious) / len(ious)
#先挑选出一行,再进行匹配
def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes):
gt_box_index = 0
delete_gt_bboxes = gt_bboxes.copy()
match_bboxes_ready = []
matched = {}
while(len(delete_gt_bboxes) != 0):
row_bboxes, delete_gt_bboxes = get_rows(delete_gt_bboxes)
row_bboxes = sorted(row_bboxes, key = lambda key: key[0])
if len(pred_bboxes_rows) > 0:
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
print(row_bboxes)
for i, gt_box in enumerate(row_bboxes):
#print(gt_box)
pred_distances = []
distances = []
for pred_bbox in pred_bboxes:
pred_distances.append(distance(gt_box, pred_bbox))
for j, pred_box in enumerate(match_bboxes_ready):
distances.append(distance(gt_box, pred_box))
index = pred_distances.index(min(distances))
#print('index', index)
if index not in matched.keys():
matched[index] = [gt_box_index]
def match_result(self, dt_boxes, pred_bboxes):
matched = {}
for i, gt_box in enumerate(dt_boxes):
distances = []
for j, pred_box in enumerate(pred_bboxes):
if len(pred_box) == 8:
pred_box = [
np.min(pred_box[0::2]), np.min(pred_box[1::2]),
np.max(pred_box[0::2]), np.max(pred_box[1::2])
]
distances.append((distance(gt_box, pred_box),
1. - compute_iou(gt_box, pred_box)
)) # compute iou and l1 distance
sorted_distances = distances.copy()
# select det box by iou and l1 distance
sorted_distances = sorted(
sorted_distances, key=lambda item: (item[1], item[0]))
if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i]
else:
matched[index].append(gt_box_index)
gt_box_index += 1
return matched
matched[distances.index(sorted_distances[0])].append(i)
return matched
def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
'''
gt_bboxes: 排序后
pred_bboxes:
'''
pre_bbox = gt_bboxes[0]
matched = {}
match_bboxes_ready = []
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
for i, gt_box in enumerate(gt_bboxes):
pred_distances = []
for pred_bbox in pred_bboxes:
pred_distances.append(distance(gt_box, pred_bbox))
distances = []
gap_pre = gt_box[1] - pre_bbox[1]
gap_pre_1 = gt_box[0] - pre_bbox[2]
#print(gap_pre, len(pred_bboxes_rows))
if (gap_pre_1 < 0 and len(pred_bboxes_rows) > 0):
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
if len(pred_bboxes_rows) == 1:
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) > 0:
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) == 0:
break
#print(match_bboxes_ready)
for j, pred_box in enumerate(match_bboxes_ready):
distances.append(distance(gt_box, pred_box))
index = pred_distances.index(min(distances))
#print(gt_box, index)
#match_bboxes_ready.pop(distances.index(min(distances)))
print(gt_box, match_bboxes_ready[distances.index(min(distances))])
if index not in matched.keys():
matched[index] = [i]
else:
matched[index].append(i)
pre_bbox = gt_box
return matched
def get_pred_html(self, pred_structures, matched_index, ocr_contents):
end_html = []
td_index = 0
for tag in pred_structures:
if '</td>' in tag:
if '<td></td>' == tag:
end_html.extend('<td>')
if td_index in matched_index.keys():
b_with = False
if '<b>' in ocr_contents[matched_index[td_index][
0]] and len(matched_index[td_index]) > 1:
b_with = True
end_html.extend('<b>')
for i, td_index_index in enumerate(matched_index[td_index]):
content = ocr_contents[td_index_index][0]
if len(matched_index[td_index]) > 1:
if len(content) == 0:
continue
if content[0] == ' ':
content = content[1:]
if '<b>' in content:
content = content[3:]
if '</b>' in content:
content = content[:-4]
if len(content) == 0:
continue
if i != len(matched_index[
td_index]) - 1 and ' ' != content[-1]:
content += ' '
end_html.extend(content)
if b_with:
end_html.extend('</b>')
if '<td></td>' == tag:
end_html.append('</td>')
else:
end_html.append(tag)
td_index += 1
else:
end_html.append(tag)
return ''.join(end_html), end_html
def get_pred_html_master(self, pred_structures, matched_index,
ocr_contents):
end_html = []
td_index = 0
for token in pred_structures:
if '</td>' in token:
txt = ''
b_with = False
if td_index in matched_index.keys():
if '<b>' in ocr_contents[matched_index[td_index][
0]] and len(matched_index[td_index]) > 1:
b_with = True
for i, td_index_index in enumerate(matched_index[td_index]):
content = ocr_contents[td_index_index][0]
if len(matched_index[td_index]) > 1:
if len(content) == 0:
continue
if content[0] == ' ':
content = content[1:]
if '<b>' in content:
content = content[3:]
if '</b>' in content:
content = content[:-4]
if len(content) == 0:
continue
if i != len(matched_index[
td_index]) - 1 and ' ' != content[-1]:
content += ' '
txt += content
if b_with:
txt = '<b>{}</b>'.format(txt)
if '<td></td>' == token:
token = '<td>{}</td>'.format(txt)
else:
token = '{}</td>'.format(txt)
td_index += 1
token = deal_eb_token(token)
end_html.append(token)
html = ''.join(end_html)
html = deal_bb(html)
return html, end_html
def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
y1 = pred_bboxes[:, 1::2].min()
new_dt_boxes = []
new_rec_res = []
for box, rec in zip(dt_boxes, rec_res):
if np.max(box[1::2]) < y1:
continue
new_dt_boxes.append(box)
new_rec_res.append(rec)
return new_dt_boxes, new_rec_res

View File

@ -16,7 +16,7 @@ import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
@ -73,12 +73,14 @@ class TableStructurer(object):
postprocess_params = {
'name': 'TableLabelDecode',
"character_dict_path": args.table_char_dict_path,
'merge_no_span_structure': args.merge_no_span_structure
}
else:
postprocess_params = {
'name': 'TableMasterLabelDecode',
"character_dict_path": args.table_char_dict_path,
'box_shape': 'pad'
'box_shape': 'pad',
'merge_no_span_structure': args.merge_no_span_structure
}
self.preprocess_op = create_operators(pre_process_list)
@ -87,6 +89,7 @@ class TableStructurer(object):
utility.create_predictor(args, 'table', logger)
def __call__(self, img):
starttime = time.time()
ori_im = img.copy()
data = {'image': img}
data = transform(data, self.preprocess_op)
@ -95,7 +98,6 @@ class TableStructurer(object):
return None, 0
img = np.expand_dims(img, axis=0)
img = img.copy()
starttime = time.time()
self.input_tensor.copy_from_cpu(img)
self.predictor.run()
@ -126,7 +128,6 @@ def main(args):
table_structurer = TableStructurer(args)
count = 0
total_time = 0
use_xywh = args.table_algorithm in ['TableMaster']
os.makedirs(args.output, exist_ok=True)
with open(
os.path.join(args.output, 'infer.txt'), mode='w',
@ -146,7 +147,10 @@ def main(args):
f_w.write("result: {}, {}\n".format(structure_str_list,
bbox_list_str))
img = draw_rectangle(image_file, bbox_list, use_xywh)
if len(bbox_list) > 0 and len(bbox_list[0]) == 4:
img = draw_rectangle(image_file, pred_res['cell_bbox'])
else:
img = utility.draw_boxes(img, bbox_list)
img_save_path = os.path.join(args.output,
os.path.basename(image_file))
cv2.imwrite(img_save_path, img)

View File

@ -18,20 +18,23 @@ import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import copy
import logging
import numpy as np
import time
import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det
import tools.infer.utility as utility
from tools.infer.predict_system import sorted_boxes
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from ppstructure.table.matcher import distance, compute_iou
from ppstructure.table.matcher import TableMatch
from ppstructure.table.table_master_match import TableMasterMatcher
from ppstructure.utility import parse_args
import ppstructure.table.predict_structure as predict_strture
@ -55,11 +58,20 @@ def expand(pix, det_box, shape):
class TableSystem(object):
def __init__(self, args, text_detector=None, text_recognizer=None):
if not args.show_log:
logger.setLevel(logging.INFO)
self.text_detector = predict_det.TextDetector(
args) if text_detector is None else text_detector
self.text_recognizer = predict_rec.TextRecognizer(
args) if text_recognizer is None else text_recognizer
self.table_structurer = predict_strture.TableStructurer(args)
if args.table_algorithm in ['TableMaster']:
self.match = TableMasterMatcher()
else:
self.match = TableMatch(filter_ocr_result=True)
self.benchmark = args.benchmark
self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
args, 'table', logger)
@ -85,145 +97,72 @@ class TableSystem(object):
def __call__(self, img, return_ocr_result_in_table=False):
result = dict()
ori_im = img.copy()
time_dict = {'det': 0, 'rec': 0, 'table': 0, 'all': 0, 'match': 0}
start = time.time()
structure_res, elapse = self._structure(copy.deepcopy(img))
result['cell_bbox'] = structure_res[1].tolist()
time_dict['table'] = elapse
dt_boxes, rec_res, det_elapse, rec_elapse = self._ocr(
copy.deepcopy(img))
time_dict['det'] = det_elapse
time_dict['rec'] = rec_elapse
if return_ocr_result_in_table:
result['boxes'] = dt_boxes #[x.tolist() for x in dt_boxes]
result['rec_res'] = rec_res
tic = time.time()
pred_html = self.match(structure_res, dt_boxes, rec_res)
toc = time.time()
time_dict['match'] = toc - tic
result['html'] = pred_html
if self.benchmark:
self.autolog.times.end(stamp=True)
end = time.time()
time_dict['all'] = end - start
if self.benchmark:
self.autolog.times.stamp()
return result, time_dict
def _structure(self, img):
if self.benchmark:
self.autolog.times.start()
structure_res, elapse = self.table_structurer(copy.deepcopy(img))
return structure_res, elapse
def _ocr(self, img):
h, w = img.shape[:2]
if self.benchmark:
self.autolog.times.stamp()
dt_boxes, elapse = self.text_detector(copy.deepcopy(img))
dt_boxes, det_elapse = self.text_detector(copy.deepcopy(img))
dt_boxes = sorted_boxes(dt_boxes)
if return_ocr_result_in_table:
result['boxes'] = [x.tolist() for x in dt_boxes]
r_boxes = []
for box in dt_boxes:
x_min = box[:, 0].min() - 1
x_max = box[:, 0].max() + 1
y_min = box[:, 1].min() - 1
y_max = box[:, 1].max() + 1
x_min = max(0, box[:, 0].min() - 1)
x_max = min(w, box[:, 0].max() + 1)
y_min = max(0, box[:, 1].min() - 1)
y_max = min(h, box[:, 1].max() + 1)
box = [x_min, y_min, x_max, y_max]
r_boxes.append(box)
dt_boxes = np.array(r_boxes)
logger.debug("dt_boxes num : {}, elapse : {}".format(
len(dt_boxes), elapse))
len(dt_boxes), det_elapse))
if dt_boxes is None:
return None, None
img_crop_list = []
for i in range(len(dt_boxes)):
det_box = dt_boxes[i]
x0, y0, x1, y1 = expand(2, det_box, ori_im.shape)
text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :]
x0, y0, x1, y1 = expand(2, det_box, img.shape)
text_rect = img[int(y0):int(y1), int(x0):int(x1), :]
img_crop_list.append(text_rect)
rec_res, elapse = self.text_recognizer(img_crop_list)
rec_res, rec_elapse = self.text_recognizer(img_crop_list)
logger.debug("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse))
if self.benchmark:
self.autolog.times.stamp()
if return_ocr_result_in_table:
result['rec_res'] = rec_res
pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res)
result['html'] = pred_html
if self.benchmark:
self.autolog.times.end(stamp=True)
return result
def rebuild_table(self, structure_res, dt_boxes, rec_res):
pred_structures, pred_bboxes = structure_res
dt_boxes, rec_res = self.filter_ocr_result(pred_bboxes,dt_boxes, rec_res)
matched_index = self.match_result(dt_boxes, pred_bboxes)
pred_html, pred = self.get_pred_html(pred_structures, matched_index,
rec_res)
return pred_html, pred
def filter_ocr_result(self, pred_bboxes,dt_boxes, rec_res):
y1 = pred_bboxes[:,1::2].min()
new_dt_boxes = []
new_rec_res = []
for box,rec in zip(dt_boxes, rec_res):
if np.max(box[1::2]) < y1:
continue
new_dt_boxes.append(box)
new_rec_res.append(rec)
return new_dt_boxes, new_rec_res
def match_result(self, dt_boxes, pred_bboxes):
matched = {}
for i, gt_box in enumerate(dt_boxes):
# gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
distances = []
for j, pred_box in enumerate(pred_bboxes):
distances.append((distance(gt_box, pred_box),
1. - compute_iou(gt_box, pred_box)
)) # 获取两两cell之间的L1距离和 1- IOU
sorted_distances = distances.copy()
# 根据距离和IOU挑选最"近"的cell
sorted_distances = sorted(
sorted_distances, key=lambda item: (item[1], item[0]))
if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i]
else:
matched[distances.index(sorted_distances[0])].append(i)
return matched
def get_pred_html(self, pred_structures, matched_index, ocr_contents):
end_html = []
td_index = 0
for tag in pred_structures:
if '</td>' in tag:
if td_index in matched_index.keys():
b_with = False
if '<b>' in ocr_contents[matched_index[td_index][
0]] and len(matched_index[td_index]) > 1:
b_with = True
end_html.extend('<b>')
for i, td_index_index in enumerate(matched_index[td_index]):
content = ocr_contents[td_index_index][0]
if len(matched_index[td_index]) > 1:
if len(content) == 0:
continue
if content[0] == ' ':
content = content[1:]
if '<b>' in content:
content = content[3:]
if '</b>' in content:
content = content[:-4]
if len(content) == 0:
continue
if i != len(matched_index[
td_index]) - 1 and ' ' != content[-1]:
content += ' '
end_html.extend(content)
if b_with:
end_html.extend('</b>')
end_html.append(tag)
td_index += 1
else:
end_html.append(tag)
return ''.join(end_html), end_html
def sorted_boxes(dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i]
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp
return _boxes
len(rec_res), rec_elapse))
return dt_boxes, rec_res, det_elapse, rec_elapse
def to_excel(html_table, excel_path):
@ -236,8 +175,23 @@ def main(args):
image_file_list = image_file_list[args.process_id::args.total_process_num]
os.makedirs(args.output, exist_ok=True)
text_sys = TableSystem(args)
table_sys = TableSystem(args)
img_num = len(image_file_list)
f_html = open(
os.path.join(args.output, 'show.html'), mode='w', encoding='utf-8')
f_html.write('<html>\n<body>\n')
f_html.write('<table border="1">\n')
f_html.write(
"<meta http-equiv=\"Content-Type\" content=\"text/html; charset=utf-8\" />"
)
f_html.write("<tr>\n")
f_html.write('<td>img name\n')
f_html.write('<td>ori image</td>')
f_html.write('<td>table html</td>')
f_html.write('<td>cell box</td>')
f_html.write("</tr>\n")
for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag = check_and_read_gif(image_file)
@ -249,13 +203,35 @@ def main(args):
logger.error("error in loading image:{}".format(image_file))
continue
starttime = time.time()
pred_res = text_sys(img)
pred_res, _ = table_sys(img)
pred_html = pred_res['html']
logger.info(pred_html)
to_excel(pred_html, excel_path)
logger.info('excel saved to {}'.format(excel_path))
elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse))
if len(pred_res['cell_bbox']) > 0 and len(pred_res['cell_bbox'][
0]) == 4:
img = predict_strture.draw_rectangle(image_file,
pred_res['cell_bbox'])
else:
img = utility.draw_boxes(img, pred_res['cell_bbox'])
img_save_path = os.path.join(args.output, os.path.basename(image_file))
cv2.imwrite(img_save_path, img)
f_html.write("<tr>\n")
f_html.write(f'<td> {os.path.basename(image_file)} <br/>\n')
f_html.write(f'<td><img src="{image_file}" width=640></td>\n')
f_html.write('<td><table border="1">' + pred_html.replace(
'<html><body><table>', '').replace('</table></body></html>', '') +
'</table></td>\n')
f_html.write(
f'<td><img src="{os.path.basename(image_file)}" width=640></td>\n')
f_html.write("</tr>\n")
f_html.write("</table>\n")
f_html.close()
if args.benchmark:
text_sys.autolog.report()

View File

@ -0,0 +1,953 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/match.py
"""
import os
import re
import cv2
import glob
import copy
import math
import pickle
import numpy as np
from shapely.geometry import Polygon, MultiPoint
"""
Useful function in matching.
"""
def remove_empty_bboxes(bboxes):
    """
    remove [0., 0., 0., 0.] in structure master bboxes.
    len(bboxes.shape) must be 2.
    :param bboxes: 2-D array-like, one bbox per row.
    :return: np.ndarray holding only the rows that have at least one non-zero
        value; the column count is preserved even when every row is removed.
    """
    bboxes = np.asarray(bboxes)
    if bboxes.size == 0:
        return bboxes
    # A row is empty only when *every* coordinate is zero. Testing the sum
    # would also drop rows whose coordinates merely cancel out.
    keep = np.any(bboxes != 0, axis=1)
    return bboxes[keep]
def xywh2xyxy(bboxes):
    """
    Convert center-format boxes (cx, cy, w, h) to corner format (x1, y1, x2, y2).
    :param bboxes: np.ndarray, a single box (1-D) or one box per row (2-D).
    :return: np.ndarray with the same shape and dtype as ``bboxes``.
    :raises ValueError: when ``bboxes`` is neither 1-D nor 2-D.
    """
    if len(bboxes.shape) not in (1, 2):
        raise ValueError
    # Ellipsis indexing addresses the last axis, so one code path serves both
    # the single-box and the batched layout.
    corners = np.empty_like(bboxes)
    corners[..., 0] = bboxes[..., 0] - bboxes[..., 2] / 2
    corners[..., 1] = bboxes[..., 1] - bboxes[..., 3] / 2
    corners[..., 2] = bboxes[..., 0] + bboxes[..., 2] / 2
    corners[..., 3] = bboxes[..., 1] + bboxes[..., 3] / 2
    return corners
def xyxy2xywh(bboxes):
    """
    Convert corner-format boxes (x1, y1, x2, y2) to center format (cx, cy, w, h).
    :param bboxes: np.ndarray, a single box (1-D) or one box per row (2-D).
    :return: np.ndarray with the same shape and dtype as ``bboxes``.
    :raises ValueError: when ``bboxes`` is neither 1-D nor 2-D.
    """
    if len(bboxes.shape) not in (1, 2):
        raise ValueError
    # Ellipsis indexing addresses the last axis, so one code path serves both
    # the single-box and the batched layout.
    centers = np.empty_like(bboxes)
    centers[..., 0] = bboxes[..., 0] + (bboxes[..., 2] - bboxes[..., 0]) / 2
    centers[..., 1] = bboxes[..., 1] + (bboxes[..., 3] - bboxes[..., 1]) / 2
    centers[..., 2] = bboxes[..., 2] - bboxes[..., 0]
    centers[..., 3] = bboxes[..., 3] - bboxes[..., 1]
    return centers
def pickle_load(path, prefix='end2end'):
    """
    Load pickled data from a file, or merge every '<prefix>_*.pkl' in a folder.
    :param path: a pickle file, or a directory containing '<prefix>_*.pkl' files.
    :param prefix: file name prefix used when ``path`` is a directory.
    :return: the unpickled object (file), or one dict merged from all matching
        files (directory).
    :raises ValueError: when ``path`` is neither a file nor a directory.
    """
    # NOTE(review): pickle can execute arbitrary code while loading; only use
    # this on result files produced by a trusted run.
    if os.path.isfile(path):
        with open(path, 'rb') as f:
            return pickle.load(f)
    if os.path.isdir(path):
        data = dict()
        search_path = os.path.join(path, '{}_*.pkl'.format(prefix))
        # glob returns files in arbitrary order; sorting makes the result
        # reproducible when several files carry the same key.
        for pkl in sorted(glob.glob(search_path)):
            with open(pkl, 'rb') as f:
                data.update(pickle.load(f))
        return data
    raise ValueError('{} is neither a file nor a directory'.format(path))
def convert_coord(xyxy):
    """
    Convert two points format to four points format.
    :param xyxy: box given as (x1, y1, x2, y2).
    :return: float32 array of shape [4, 2] holding the corners in the order
        (x1, y1), (x2, y1), (x2, y2), (x1, y2).
    """
    left, top, right, bottom = xyxy[0], xyxy[1], xyxy[2], xyxy[3]
    corners = [
        [left, top],
        [right, top],
        [right, bottom],
        [left, bottom],
    ]
    return np.array(corners, dtype=np.float32)
def cal_iou(bbox1, bbox2):
    """
    Overlap ratio of two polygons given as point arrays.
    The denominator is the area of the convex hull of all points of both
    boxes, not the exact area of their union.
    :param bbox1: points of the first box.
    :param bbox2: points of the second box.
    :return: 0 when the boxes do not intersect or the hull has zero area,
        otherwise intersection area divided by hull area.
    """
    poly1 = Polygon(bbox1).convex_hull
    poly2 = Polygon(bbox2).convex_hull
    all_points = np.concatenate((bbox1, bbox2))
    if not poly1.intersects(poly2):
        return 0
    inter_area = poly1.intersection(poly2).area
    hull_area = MultiPoint(all_points).convex_hull.area
    if hull_area == 0:
        return 0
    return float(inter_area) / hull_area
def cal_distance(p1, p2):
    """
    Euclidean distance between two 2-D points.
    :param p1: point (x, y)
    :param p2: point (x, y)
    :return: distance as a float
    """
    return math.sqrt(((p1[0] - p2[0])**2) + ((p1[1] - p2[1])**2))
def is_inside(center_point, corner_point):
    """
    Find if center_point inside the bbox(corner_point) or not.
    Points lying exactly on the border count as inside.
    :param center_point: center point (x, y)
    :param corner_point: corner point ((x1,y1),(x2,y2))
    :return: True when the point is inside the bbox, else False
    """
    x_ok = corner_point[0][0] <= center_point[0] <= corner_point[1][0]
    y_ok = corner_point[0][1] <= center_point[1] <= corner_point[1][1]
    return bool(x_ok and y_ok)
def find_no_match(match_list, all_end2end_nums, type='end2end'):
    """
    Find out no match end2end bbox in previous match list.
    :param match_list: matching pairs.
    :param all_end2end_nums: numbers of end2end_xywh
    :param type: 'end2end' corresponding to idx 0, 'master' corresponding to idx 1.
    :return: no match pse bbox index list
    :raises ValueError: when ``type`` is neither 'end2end' nor 'master'.
    """
    if type == 'end2end':
        idx = 0
    elif type == 'master':
        idx = 1
    else:
        raise ValueError(
            "type must be 'end2end' or 'master', got {!r}".format(type))

    # m[0] is end2end index m[1] is master index
    # A set gives constant-time membership tests; the original scanned a list
    # for every candidate index.
    matched_bbox_indexs = {m[idx] for m in match_list}
    return [n for n in range(all_end2end_nums) if n not in matched_bbox_indexs]
def is_abs_lower_than_threshold(this_bbox, target_bbox, threshold=3):
    """
    Tell whether two bboxes belong to the same row.
    Only the value at index 1 (the y coordinate) of each bbox is compared.
    :param this_bbox: bbox under test
    :param target_bbox: reference bbox
    :param threshold: exclusive upper bound on the absolute y difference
    :return: True when the y difference is below ``threshold``, else False
    """
    y_gap = abs(this_bbox[1] - target_bbox[1])
    return bool(y_gap < threshold)
def sort_line_bbox(g, bg):
    """
    Sorted the bbox in the same line(group)
    compare coord 'x' value, where 'y' value is closed in the same group.
    :param g: index in the same group
    :param bg: bbox in the same group
    :return: (indexes sorted by x, bboxes sorted by x)
    """
    # Sort (index, bbox) pairs on x alone. The sort is stable, so boxes that
    # share the same x keep their input order. Looking each x up with
    # list.index() put such boxes into the same slot, losing one of them and
    # leaving a None in the result.
    pairs = sorted(zip(g, bg), key=lambda pair: pair[1][0])
    g_sorted = [index for index, _ in pairs]
    bg_sorted = [bbox for _, bbox in pairs]
    return g_sorted, bg_sorted
def flatten(sorted_groups, sorted_bbox_groups):
    """
    Flatten row-grouped indexes and bboxes into two parallel flat lists.
    :param sorted_groups: list of index lists, one per row
    :param sorted_bbox_groups: list of bbox lists, one per row
    :return: (flat index list, flat bbox list)
    """
    pairs = [(idx, bbox)
             for row_idxs, row_bboxes in zip(sorted_groups, sorted_bbox_groups)
             for idx, bbox in zip(row_idxs, row_bboxes)]
    idxs = [idx for idx, _ in pairs]
    bboxes = [bbox for _, bbox in pairs]
    return idxs, bboxes
def sort_bbox(end2end_xywh_bboxes, no_match_end2end_indexes):
    """
    This function will group the render end2end bboxes in row.
    Rows are ordered by y, and the bboxes inside each row by x.
    :param end2end_xywh_bboxes: bboxes to group
    :param no_match_end2end_indexes: index of each bbox, parallel to the bboxes
    :return: (flat sorted index list, flat sorted bbox list,
        row-grouped index lists, row-grouped bbox lists)
    """
    groups, bbox_groups = [], []
    for index, this_bbox in zip(no_match_end2end_indexes, end2end_xywh_bboxes):
        for g, bg in zip(groups, bbox_groups):
            # this_bbox is belong to bg's row or not
            if is_abs_lower_than_threshold(this_bbox, bg[0]):
                g.append(index)
                bg.append(this_bbox)
                break
        else:
            # no existing row accepted this_bbox (or there is no row yet),
            # create a row.
            groups.append([index])
            bbox_groups.append([this_bbox])

    # sorted bboxes in a group
    rows = [sort_line_bbox(g, bg) for g, bg in zip(groups, bbox_groups)]

    # sorted groups, sort by coord y's value of the first bbox of each row.
    sorted_ys = sorted(row_bboxes[0][1] for _, row_bboxes in rows)
    sorted_groups = [None] * len(rows)
    sorted_bbox_groups = [None] * len(rows)
    for row_idxs, row_bboxes in rows:
        slot = sorted_ys.index(row_bboxes[0][1])
        sorted_groups[slot] = row_idxs
        sorted_bbox_groups[slot] = row_bboxes

    # flatten, get final result
    end2end_sorted_idx_list, end2end_sorted_bbox_list = flatten(
        sorted_groups, sorted_bbox_groups)

    return (end2end_sorted_idx_list, end2end_sorted_bbox_list, sorted_groups,
            sorted_bbox_groups)
def get_bboxes_list(end2end_result, structure_master_result):
    """
    This function is use to convert end2end results and structure master results to
    List of xyxy bbox format and List of xywh bbox format
    :param end2end_result: iterable of dicts, each with a 'bbox' in xyxy format
    :param structure_master_result: dict with a 'bbox' array; this code passes it
        to xyxy2xywh, i.e. it is handled as xyxy here
    :return: (end2end xyxy bboxes, end2end xywh bboxes,
        structure master xywh bboxes, structure master xyxy bboxes)
    """
    # end2end
    end2end_xyxy_list = [item['bbox'] for item in end2end_result]
    end2end_xywh_list = [xyxy2xywh(bbox) for bbox in end2end_xyxy_list]
    end2end_xyxy_bboxes = np.array(end2end_xyxy_list)
    end2end_xywh_bboxes = np.array(end2end_xywh_list)

    # structure master: drop all-zero bboxes before converting
    structure_master_xyxy_bboxes = remove_empty_bboxes(
        structure_master_result['bbox'])
    structure_master_xywh_bboxes = xyxy2xywh(structure_master_xyxy_bboxes)

    return (end2end_xyxy_bboxes, end2end_xywh_bboxes,
            structure_master_xywh_bboxes, structure_master_xyxy_bboxes)
def center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes):
    """
    Judge end2end Bbox's center point is inside structure master Bbox or not,
    if end2end Bbox's center is in structure master Bbox, get matching pair.
    One end2end bbox yields a pair for every master bbox that contains it.
    :param end2end_xywh_bboxes: end2end bboxes in xywh format
    :param structure_master_xyxy_bboxes: master bboxes in xyxy format
    :return: match pairs list, e.g. [[0,1], [1,2], ...]
    """
    match_pairs_list = []
    for i, end2end_xywh in enumerate(end2end_xywh_bboxes):
        for j, master_xyxy in enumerate(structure_master_xyxy_bboxes):
            center = (end2end_xywh[0], end2end_xywh[1])
            corners = ((master_xyxy[0], master_xyxy[1]),
                       (master_xyxy[2], master_xyxy[3]))
            if is_inside(center, corners):
                match_pairs_list.append([i, j])
    return match_pairs_list
def iou_rule_match(end2end_xyxy_bboxes, end2end_xyxy_indexes,
                   structure_master_xyxy_bboxes):
    """
    Use iou to find matching list.
    choose max iou value bbox as match pair.
    An end2end bbox whose iou is 0 with every master bbox gets no pair.
    :param end2end_xyxy_bboxes: end2end bboxes in xyxy format
    :param end2end_xyxy_indexes: original end2end indexes.
    :param structure_master_xyxy_bboxes: master bboxes in xyxy format
    :return: match pairs list, e.g. [[0,1], [1,2], ...]
    """
    match_pair_list = []
    for end2end_index, end2end_xyxy in zip(end2end_xyxy_indexes,
                                           end2end_xyxy_bboxes):
        best_iou = 0
        best_master = None
        for j, master_xyxy in enumerate(structure_master_xyxy_bboxes):
            iou = cal_iou(
                convert_coord(end2end_xyxy), convert_coord(master_xyxy))
            # strict '>' keeps the first master bbox among equal ious
            if iou > best_iou:
                best_iou, best_master = iou, j
        if best_master is not None:
            match_pair_list.append([end2end_index, best_master])
    return match_pair_list
def distance_rule_match(end2end_indexes, end2end_bboxes, master_indexes,
                        master_bboxes):
    """
    Match every still-unmatched master box to the nearest still-unmatched
    end2end box, comparing the points formed by elements 0 and 1 of each box
    with ``cal_distance``.

    One pair is produced per master box. If there are no end2end boxes the
    pair stays at its initial value [0, 0]; the caller in this file only
    invokes this rule when both sides are non-empty.

    :param end2end_indexes: original indexes of the end2end boxes.
    :param end2end_bboxes: the end2end boxes themselves.
    :param master_indexes: original indexes of the master boxes.
    :param master_bboxes: the master boxes themselves.
    :return: match pairs list [[end2end_index, master_index], ...]
    """
    nearest_pairs = []
    for m_idx, m_box in zip(master_indexes, master_bboxes):
        best_distance = np.inf
        best_pair = [0, 0]  # [end2end index, master index]
        for e_idx, e_box in zip(end2end_indexes, end2end_bboxes):
            e_point = (e_box[0], e_box[1])
            m_point = (m_box[0], m_box[1])
            distance = cal_distance(m_point, e_point)
            if distance < best_distance:
                best_pair = [e_idx, m_idx]
                best_distance = distance
        nearest_pairs.append(best_pair)
    return nearest_pairs
def extra_match(no_match_end2end_indexes, master_bbox_nums):
    """
    Pair the leftover end2end indexes with virtual master indexes.

    The k-th leftover end2end index is paired with master index
    ``master_bbox_nums + k``, i.e. the virtual boxes are numbered right after
    the real master boxes.

    :param no_match_end2end_indexes: end2end indexes that are still unmatched.
    :param master_bbox_nums: number of real master boxes.
    :return: match pairs list [[end2end_index, virtual_master_index], ...]
    """
    return [
        [end2end_index, virtual_index]
        for virtual_index, end2end_index in enumerate(
            no_match_end2end_indexes, master_bbox_nums)
    ]
def get_match_dict(match_list):
    """
    Group match pairs by master index.

    :param match_list: pairs of [end2end_index, master_index].
    :return: dict mapping master index -> list of end2end indexes, each list
        in the order the pairs appear in match_list.
    """
    grouped = dict()
    for pair in match_list:
        grouped.setdefault(pair[1], []).append(pair[0])
    return grouped
def deal_successive_space(text):
    """
    Normalise space characters in recognised text.

    1. A run of three spaces stands for one real space and is protected with
       the temporary marker '<space>'.
    2. Every remaining single space is a split token and is removed.
    3. The marker is turned back into one real space.

    :param text: raw text.
    :return: text with split-token spaces removed and real spaces kept.
    """
    protected = text.replace(' ' * 3, '<space>')
    without_split_tokens = protected.replace(' ', '')
    return without_split_tokens.replace('<space>', ' ')
def reduce_repeat_bb(text_list, break_token):
    """
    Collapse a list of bold fragments into a single bold string, e.g.
    ['<b>Local</b>', '<b>government</b>', '<b>unit</b>'] becomes
    ['<b>Local government unit</b>'].

    The merge only happens when every item starts with '<b>'; otherwise the
    list is returned untouched. (Other styles such as <i>...</i> are not
    handled here.)

    :param text_list: text fragments belonging to one cell.
    :param break_token: separator placed between the merged fragments.
    :return: a one-element list with the merged bold text, or text_list itself.
    """
    bold_count = sum(1 for item in text_list if item.startswith('<b>'))
    if bold_count != len(text_list):
        return text_list
    stripped = [
        item.replace('<b>', '').replace('</b>', '') for item in text_list
    ]
    return ['<b>' + break_token.join(stripped) + '</b>']
def get_match_text_dict(match_dict, end2end_info, break_token=' '):
    """
    Build the text of every matched master cell.

    :param match_dict: master index -> list of end2end indexes.
    :param end2end_info: indexable by end2end index; each entry provides the
        recognised text under the key 'text'.
    :param break_token: separator used when one cell owns several texts.
    :return: dict mapping master index -> joined cell text.
    """
    cell_texts = dict()
    for master_index, end2end_index_list in match_dict.items():
        fragments = [end2end_info[k]['text'] for k in end2end_index_list]
        # merge all-bold fragments into a single <b>...</b> first
        fragments = reduce_repeat_bb(fragments, break_token)
        cell_texts[master_index] = break_token.join(fragments)
    return cell_texts
def merge_span_token(master_token_list):
    """
    Merge the pieces of a span-style cell (rowspan and/or colspan) into one
    token.

    Recognised layouts, starting at a '<td' token:
      '<td' + ' colspan="3"' + '>' + '</td>'                   -> 4 tokens
      '<td' + ' rowspan="2"' + ' colspan="3"' + '>' + '</td>'  -> 5 tokens
    Every other token is copied through unchanged. Processing stops at the
    first '</tbody>' and the result always ends with exactly one '</tbody>'.

    The input list is not modified. A span cell that is cut off at the end of
    the sequence is merged from whatever tokens are left instead of raising
    IndexError.

    :param master_token_list: list of structure tokens (str).
    :return: new list of tokens with span cells merged.
    """

    def _is_span(token):
        return token.startswith(' colspan=') or token.startswith(' rowspan=')

    # Work on a copy so the caller's list is left untouched.
    tokens = list(master_token_list)
    if not tokens or tokens[-1] != '</tbody>':
        tokens.append('</tbody>')
    last = len(tokens) - 1  # index of the terminating '</tbody>'

    merged = []
    pointer = 0
    while pointer < last and tokens[pointer] != '</tbody>':
        width = 1
        if tokens[pointer] == '<td':
            if _is_span(tokens[pointer + 1]):
                # One span attribute -> 4 tokens, two attributes -> 5 tokens.
                # tokens[pointer + 2] exists: a span token is never the
                # terminator, so pointer + 1 < last.
                width = 5 if _is_span(tokens[pointer + 2]) else 4
            elif pointer + 2 > last:
                # Dangling '<td' right before the terminator: give up, as the
                # previous implementation did.
                print("Break in merge...")
                break
            elif _is_span(tokens[pointer + 2]):
                width = 5
        # Never swallow the terminating '</tbody>' into a merged cell.
        end = min(pointer + width, last)
        merged.append(''.join(tokens[pointer:end]))
        pointer = end
    merged.append('</tbody>')
    return merged
def deal_eb_token(master_token):
    """
    Expand the empty-box placeholders <eb></eb>, <eb1></eb1>, ... <eb10></eb10>
    into the literal <td>...</td> cell each one stands for.

    Placeholder -> cell content:
      <eb>   empty                      <eb6>  <i> </i>
      <eb1>  one space                  <eb7>  <b><i></i></b>
      <eb2>  <b> </b>                   <eb8>  <b><i> </i></b>
      <eb3>  two U+2028 characters      <eb9>  <i></i>
      <eb4>  <sup> </sup>               <eb10> <b> U+2028 U+2028 </b>
      <eb5>  <b></b>

    :param master_token: a structure token (or any string) to expand.
    :return: the string with every placeholder replaced.
    """
    placeholder_to_cell = (
        ('<eb></eb>', '<td></td>'),
        ('<eb1></eb1>', '<td> </td>'),
        ('<eb2></eb2>', '<td><b> </b></td>'),
        ('<eb3></eb3>', '<td>\u2028\u2028</td>'),
        ('<eb4></eb4>', '<td><sup> </sup></td>'),
        ('<eb5></eb5>', '<td><b></b></td>'),
        ('<eb6></eb6>', '<td><i> </i></td>'),
        ('<eb7></eb7>', '<td><b><i></i></b></td>'),
        ('<eb8></eb8>', '<td><b><i> </i></b></td>'),
        ('<eb9></eb9>', '<td><i></i></td>'),
        ('<eb10></eb10>', '<td><b> \u2028 \u2028 </b></td>'),
    )
    for placeholder, cell in placeholder_to_cell:
        master_token = master_token.replace(placeholder, cell)
    return master_token
def insert_text_to_token(master_token_list, match_text_dict):
    """
    Insert OCR text into the structure tokens and return the joined string.

    Tokens are first passed through merge_span_token. Cells (tokens starting
    with '<td') are numbered in order of appearance; a cell whose number is a
    key of match_text_dict gets that text inserted between '>' and '<'.
    A cell without an entry, or whose number is not smaller than
    len(match_text_dict), is left out of the result entirely. All emitted
    tokens go through deal_eb_token.

    :param master_token_list: list of structure tokens.
    :param match_text_dict: cell number -> text.
    :return: the merged result as a single string.
    """
    pieces = []
    cell_number = 0
    for token in merge_span_token(master_token_list):
        if token.startswith('<td'):
            current = cell_number
            cell_number += 1
            if (current > len(match_text_dict) - 1 or
                    current not in match_text_dict.keys()):
                # no text for this cell: the cell token itself is dropped
                continue
            token = token.replace('><',
                                  '>{}<'.format(match_text_dict[current]))
        pieces.append(deal_eb_token(token))
    return ''.join(pieces)
def deal_isolate_span(thead_part):
    """
    Repair "isolated span" fragments caused by a wrong structure prediction,
    e.g. '<td></td> rowspan="2"></b></td>' is rewritten to
    '<td rowspan="2"></td>'.

    :param thead_part: the <thead>...</thead> string to repair.
    :return: thead_part with every isolated span fragment replaced.
    """
    # Raw strings: '\d' inside a normal string literal is an invalid escape
    # sequence.
    # 1. find the isolated span fragments.
    isolate_pattern = r'<td></td> rowspan="(\d)+" colspan="(\d)+"></b></td>|' \
                      r'<td></td> colspan="(\d)+" rowspan="(\d)+"></b></td>|' \
                      r'<td></td> rowspan="(\d)+"></b></td>|' \
                      r'<td></td> colspan="(\d)+"></b></td>'
    # 2. pattern that extracts the span attributes out of such a fragment.
    span_pattern = r' rowspan="(\d)+" colspan="(\d)+"|' \
                   r' colspan="(\d)+" rowspan="(\d)+"|' \
                   r' rowspan="(\d)+"|' \
                   r' colspan="(\d)+"'

    # Collect all fragments from the original string before replacing any.
    isolate_list = [m.group() for m in re.finditer(isolate_pattern, thead_part)]
    for isolate_item in isolate_list:
        span_part = re.search(span_pattern, isolate_item)
        if span_part is None:
            # Cannot happen for fragments matched above; kept as a guard.
            continue
        # 3. put the span attributes back inside the <td ...> tag and
        # 4. replace the broken fragment.
        corrected_item = '<td{}></td>'.format(span_part.group())
        thead_part = thead_part.replace(isolate_item, corrected_item)
    return thead_part
def deal_duplicate_bb(thead_part):
    """
    Collapse duplicated <b> / </b> tags so that each <td>...</td> cell keeps
    exactly one <b>...</b> pair wrapping its whole content.

    Cells with at most one '<b>' and at most one '</b>' are left unchanged.
    Cells whose opening tag carries rowspan/colspan attributes are handled the
    same way as plain '<td>' cells, so the re-inserted '<b>' and '</b>' stay
    balanced.

    :param thead_part: the <thead>...</thead> string.
    :return: thead_part with duplicated bold tags collapsed.
    """
    # 1. find every <td ...>...</td> cell (raw strings: '\d' is an invalid
    #    escape sequence in a normal string literal).
    td_pattern = r'<td rowspan="(\d)+" colspan="(\d)+">(.+?)</td>|' \
                 r'<td colspan="(\d)+" rowspan="(\d)+">(.+?)</td>|' \
                 r'<td rowspan="(\d)+">(.+?)</td>|' \
                 r'<td colspan="(\d)+">(.+?)</td>|' \
                 r'<td>(.*?)</td>'
    td_list = [t.group() for t in re.finditer(td_pattern, thead_part)]

    for td_item in td_list:
        # 2. only cells with more than one <b> or </b> need fixing.
        if td_item.count('<b>') <= 1 and td_item.count('</b>') <= 1:
            continue
        # 2.1 remove every <b> and </b>.
        new_td_item = td_item.replace('<b>', '').replace('</b>', '')
        # 2.2 re-open bold right after the opening tag, with or without
        #     span attributes, and close it right before </td>.
        new_td_item = re.sub(r'(<td(?:\s[^>]*)?>)', r'\1<b>', new_td_item)
        new_td_item = new_td_item.replace('</td>', '</b></td>')
        # 3. replace the original cell.
        thead_part = thead_part.replace(td_item, new_td_item)
    return thead_part
def deal_bb(result_token):
    """
    Wrap the content of every cell inside <thead>...</thead> in <b>...</b>.

    The working assumption of this module is that bold text only occurs in the
    table head, so the bold tags are inserted by hand. Empty cells stay
    '<td></td>'. If the input has no <thead>...</thead> part it is returned
    unchanged.

    :param result_token: the merged result string.
    :return: result_token with the thead part rewritten.
    """
    # find the <thead></thead> part (searched once and reused).
    thead_match = re.search(r'<thead>(.*?)</thead>', result_token)
    if thead_match is None:
        return result_token
    origin_thead_part = thead_match.group()
    thead_part = origin_thead_part

    # check whether "rowspan" or "colspan" occurs in the thead part.
    span_pattern = r'<td rowspan="(\d)+" colspan="(\d)+">|' \
                   r'<td colspan="(\d)+" rowspan="(\d)+">|' \
                   r'<td rowspan="(\d)+">|' \
                   r'<td colspan="(\d)+">'
    span_list = [s.group() for s in re.finditer(span_pattern, thead_part)]

    if not span_list:
        # branch 1: no span cells in the head.
        # 1. replace <td> with <td><b>, and </td> with </b></td>
        # 2. the text recogniser may already have produced <b> or </b>,
        #    so collapse <b><b> to <b> and </b></b> to </b>
        thead_part = thead_part.replace('<td>', '<td><b>')\
            .replace('</td>', '</b></td>')\
            .replace('<b><b>', '<b>')\
            .replace('</b></b>', '</b>')
    else:
        # branch 2: the head contains span cells.
        # first the span cells: '<td ...>' becomes '<td ...><b>'
        for sp in span_list:
            thead_part = thead_part.replace(sp, sp.replace('>', '><b>'))
        # every cell end becomes '</b></td>'
        thead_part = thead_part.replace('</td>', '</b></td>')
        # collapse runs of <b> and of </b> into a single tag
        thead_part = re.sub("(<b>)+", "<b>", thead_part)
        thead_part = re.sub("(</b>)+", "</b>", thead_part)
        # then the ordinary cells, as in branch 1
        thead_part = thead_part.replace('<td>', '<td><b>').replace('<b><b>',
                                                                   '<b>')

    # an empty cell carries no <b></b>; a cell holding a space keeps its
    # <td><b> </b></td> form.
    thead_part = thead_part.replace('<td><b></b></td>', '<td></td>')
    # deal with duplicated <b></b>
    thead_part = deal_duplicate_bb(thead_part)
    # deal with isolated span tokens caused by a wrong structure prediction,
    # e.g. PMC5994107_011_00.png
    thead_part = deal_isolate_span(thead_part)
    # replace the original thead part with the new one.
    result_token = result_token.replace(origin_thead_part, thead_part)
    return result_token
class Matcher:
    """
    Match end2end (text detection + recognition) results against structure
    recognition results and merge the recognised text into the structure
    tokens.
    """

    def __init__(self, end2end_file, structure_master_file):
        """
        Load both result files.

        :param end2end_file: end2end results predicted by end2end inference.
        :param structure_master_file: structure recognition results predicted
            by structure master inference.
        """
        self.end2end_file = end2end_file
        self.structure_master_file = structure_master_file
        # NOTE(review): pickle_load is defined outside this block; if it
        # unpickles these files, only load files from a trusted source,
        # since unpickling can execute arbitrary code.
        self.end2end_results = pickle_load(end2end_file, prefix='end2end')
        self.structure_master_results = pickle_load(
            structure_master_file, prefix='structure')

    def match(self):
        """
        Match end2end boxes to structure-master boxes for every file present
        in both result sets.

        Pre-process: convert both results to xyxy / xywh box arrays. Then:
            1. centre rule: the end2end box centre lies inside a master box
            2. iou rule: for end2end boxes still unmatched
            3. distance rule: between still unmatched end2end and master boxes
        End2end boxes left over after rule 3 are matched to virtual master
        boxes (see extra_match).

        :return: dict file_name -> dict with the keys 'match_list',
            'match_list_add_extra_match', 'sorted_groups',
            'sorted_bboxes_groups' and 'matched_master_token_list'.
        """
        match_results = dict()
        for file_name, end2end_result in self.end2end_results.items():
            if file_name not in self.structure_master_results:
                continue
            structure_master_result = self.structure_master_results[file_name]
            (end2end_xyxy_bboxes, end2end_xywh_bboxes,
             structure_master_xywh_bboxes, structure_master_xyxy_bboxes) = \
                get_bboxes_list(end2end_result, structure_master_result)
            end2end_nums = len(end2end_xywh_bboxes)
            master_nums = len(structure_master_xywh_bboxes)

            match_list = []

            # rule 1: center rule
            match_list.extend(
                center_rule_match(end2end_xywh_bboxes,
                                  structure_master_xyxy_bboxes))

            # rule 2: iou rule, only for end2end boxes rule 1 did not match
            center_no_match_end2end_indexs = \
                find_no_match(match_list, end2end_nums, type='end2end')
            if len(center_no_match_end2end_indexs) > 0:
                center_no_match_end2end_xyxy = end2end_xyxy_bboxes[
                    center_no_match_end2end_indexs]
                match_list.extend(
                    iou_rule_match(center_no_match_end2end_xyxy,
                                   center_no_match_end2end_indexs,
                                   structure_master_xyxy_bboxes))

            # rule 3: distance rule, between the end2end boxes and the master
            # boxes that are still unmatched; yields one pair per such master
            # box.
            centerIou_no_match_end2end_indexs = \
                find_no_match(match_list, end2end_nums, type='end2end')
            centerIou_no_match_master_indexs = \
                find_no_match(match_list, master_nums, type='master')
            if len(centerIou_no_match_master_indexs) > 0 and len(
                    centerIou_no_match_end2end_indexs) > 0:
                centerIou_no_match_end2end_xywh = end2end_xywh_bboxes[
                    centerIou_no_match_end2end_indexs]
                centerIou_no_match_master_xywh = structure_master_xywh_bboxes[
                    centerIou_no_match_master_indexs]
                match_list.extend(
                    distance_rule_match(centerIou_no_match_end2end_indexs,
                                        centerIou_no_match_end2end_xywh,
                                        centerIou_no_match_master_indexs,
                                        centerIou_no_match_master_xywh))

            # TODO:
            # Remaining unmatched end2end boxes are inserted at the end.
            # After the distance rule every master box has at least one
            # end2end box, but there may be surplus end2end boxes because the
            # number of master boxes is cut by the max length. For those we
            # create virtual master boxes and match against them; the extra
            # boxes are processed further in _format.
            no_match_end2end_indexes = \
                find_no_match(match_list, end2end_nums, type='end2end')
            if len(no_match_end2end_indexes) > 0:
                no_match_end2end_xywh = end2end_xywh_bboxes[
                    no_match_end2end_indexes]
                # sort the remaining unmatched end2end boxes into rows
                (end2end_sorted_indexes_list, end2end_sorted_bboxes_list,
                 sorted_groups, sorted_bboxes_groups) = \
                    sort_bbox(no_match_end2end_xywh, no_match_end2end_indexes)
                # make virtual master boxes and match them with the rest
                extra_match_list = extra_match(end2end_sorted_indexes_list,
                                               master_nums)
                match_list_add_extra_match = copy.deepcopy(match_list)
                match_list_add_extra_match.extend(extra_match_list)
            else:
                # every end2end box is matched
                match_list_add_extra_match = copy.deepcopy(match_list)
                sorted_groups = []
                sorted_bboxes_groups = []

            match_result_dict = {
                'match_list': match_list,
                'match_list_add_extra_match': match_list_add_extra_match,
                'sorted_groups': sorted_groups,
                'sorted_bboxes_groups': sorted_bboxes_groups
            }

            # format output
            match_result_dict = self._format(match_result_dict, file_name)
            match_results[file_name] = match_result_dict

        return match_results

    def _format(self, match_result, file_name):
        """
        Extend the master tokens with virtual tokens and store them in the
        match result under 'matched_master_token_list'.

        One '<tr>' + N * '<td></td>' + '</tr>' row is created for every group
        in match_result['sorted_groups'] (N = group size).

        :param match_result: dict produced by match() for one file.
        :param file_name: key into self.structure_master_results.
        :return: match_result (the same dict, updated in place).
        """
        master_token = self.structure_master_results[file_name]['text']
        sorted_groups = match_result['sorted_groups']

        # create virtual master tokens
        virtual_master_token_list = []
        for line_group in sorted_groups:
            virtual_master_token_list.append('<tr>')
            virtual_master_token_list.extend(['<td></td>'] * len(line_group))
            virtual_master_token_list.append('</tr>')

        # insert virtual master tokens
        master_token_list = master_token.split(',')
        if master_token_list[-1] == '</tbody>':
            # Complete prediction (not cut by max length). Inserting virtual
            # tokens here lowers the TEDS score on the val set, so the token
            # list is deliberately left as it is. To really extend it:
            # master_token_list = master_token_list[:-1]
            # master_token_list.extend(virtual_master_token_list)
            # master_token_list.append('</tbody>')
            pass
        elif master_token_list[-1] == '<td></td>':
            # cut in the middle of a row: close the row first
            master_token_list.append('</tr>')
            master_token_list.extend(virtual_master_token_list)
            master_token_list.append('</tbody>')
        else:
            master_token_list.extend(virtual_master_token_list)
            master_token_list.append('</tbody>')

        # format output
        match_result.setdefault('matched_master_token_list', master_token_list)
        return match_result

    def get_merge_result(self, match_results):
        """
        Merge the OCR text into the structure tokens to get the final result.

        :param match_results: dict returned by match().
        :return: dict file_name -> merged result string.
        """
        merged_results = dict()

        # break_token separates the texts when one master box owns several
        # end2end boxes.
        break_token = ' '

        for file_name, match_info in match_results.items():
            end2end_info = self.end2end_results[file_name]
            master_token_list = match_info['matched_master_token_list']
            match_list = match_info['match_list_add_extra_match']

            match_dict = get_match_dict(match_list)
            match_text_dict = get_match_text_dict(match_dict, end2end_info,
                                                  break_token)
            merged_result = insert_text_to_token(master_token_list,
                                                 match_text_dict)
            merged_result = deal_bb(merged_result)

            merged_results[file_name] = merged_result

        return merged_results
class TableMasterMatcher(Matcher):
    """
    In-memory variant of Matcher: results are passed to __call__ instead of
    being loaded from files.
    """

    def __init__(self):
        # Deliberately does not call Matcher.__init__, so nothing is loaded.
        pass

    def __call__(self, structure_res, dt_boxes, rec_res, img_name=1):
        """
        Build the HTML for one image.

        :param structure_res: pair (pred_structures, pred_bboxes).
        :param dt_boxes: detected text boxes.
        :param rec_res: recognition results; element 0 of each item is used
            as the text.
        :param img_name: key under which the single image is stored.
        :return: the merged table wrapped in <html><body><table>...</table></body></html>.
        """
        self.end2end_results = {
            img_name: [
                dict(bbox=np.array(box), text=rec[0])
                for box, rec in zip(dt_boxes, rec_res)
            ]
        }

        pred_structures, pred_bboxes = structure_res
        # The first and last three structure tokens are dropped, presumably
        # the html/body/table wrapper that is re-added below; verify against
        # the structure model output.
        self.structure_master_results = {
            img_name: {
                'text': ','.join(pred_structures[3:-3]),
                'bbox': pred_bboxes,
            }
        }

        # match, then merge the text into the structure tokens
        merged_results = self.get_merge_result(self.match())
        table_html = merged_results[img_name]
        return '<html><body><table>' + table_html + '</table></body></html>'

View File

@ -27,6 +27,8 @@ def init_args():
parser.add_argument("--table_max_len", type=int, default=488)
parser.add_argument("--table_algorithm", type=str, default='TableAttn')
parser.add_argument("--table_model_dir", type=str)
parser.add_argument(
"--merge_no_span_structure", type=str2bool, default=True)
parser.add_argument(
"--table_char_dict_path",
type=str,
@ -36,14 +38,17 @@ def init_args():
parser.add_argument(
"--layout_dict_path",
type=str,
default="../ppocr/utils/dict/layout_pubalynet_dict.txt")
default="../ppocr/utils/dict/layout_publaynet_dict.txt")
parser.add_argument(
"--layout_score_threshold",
type=float,
default=0.5,
help="Threshold of score.")
parser.add_argument(
"--layout_nms_threshold", type=float, default=0.5, help="Threshold of nms.")
"--layout_nms_threshold",
type=float,
default=0.5,
help="Threshold of nms.")
# params for vqa
parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM')
parser.add_argument("--ser_model_dir", type=str)
@ -59,6 +64,11 @@ def init_args():
type=str,
default='structure',
help='structure and vqa is supported')
parser.add_argument(
"--image_orientation",
type=bool,
default=False,
help='Whether to enable image orientation recognition')
parser.add_argument(
"--layout",
type=str2bool,

View File

@ -41,7 +41,11 @@ logger = get_logger()
class SerPredictor(object):
def __init__(self, args):
self.ocr_engine = PaddleOCR(
use_angle_cls=False, show_log=False, use_gpu=args.use_gpu)
use_angle_cls=args.use_angle_cls,
det_model_dir=args.det_model_dir,
rec_model_dir=args.rec_model_dir,
show_log=False,
use_gpu=args.use_gpu)
pre_process_list = [{
'VQATokenLabelEncode': {

View File

@ -58,10 +58,11 @@ function status_check(){
run_command=$2
run_log=$3
model_name=$4
log_path=$5
if [ $last_status -eq 0 ]; then
echo -e "\033[33m Run successfully with command - ${model_name} - ${run_command}! \033[0m" | tee -a ${run_log}
echo -e "\033[33m Run successfully with command - ${model_name} - ${run_command} - ${log_path} \033[0m" | tee -a ${run_log}
else
echo -e "\033[33m Run failed with command - ${model_name} - ${run_command}! \033[0m" | tee -a ${run_log}
echo -e "\033[33m Run failed with command - ${model_name} - ${run_command} - ${log_path} \033[0m" | tee -a ${run_log}
fi
}

View File

@ -1,58 +0,0 @@
===========================train_params===========================
model_name:det_r18_db_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_lite_infer=4
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
norm_train:tools/train.py -c configs/det/det_res18_db_v2.0.yml -o
quant_export:null
fpgm_export:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:null
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
##
train_model:null
infer_export:null
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
--save_log_path:null
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_benchmark_params==========================
batch_size:8|16
fp_items:fp32|fp16
epoch:15
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile

View File

@ -19,8 +19,6 @@ Global:
character_type: en
max_text_length: 800
infer_mode: False
process_total_num: 0
process_cut_num: 0
Optimizer:
name: Adam

View File

@ -52,7 +52,7 @@ null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
===========================train_benchmark_params==========================
batch_size:4
batch_size:8
fp_items:fp32|fp16
epoch:3
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile

View File

@ -16,8 +16,6 @@ Global:
character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
infer_mode: false
max_text_length: 500
process_total_num: 0
process_cut_num: 0
Optimizer:
@ -86,7 +84,7 @@ Train:
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: True
box_format: 'xywh'
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
@ -120,7 +118,7 @@ Eval:
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: True
box_format: 'xywh'
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]

View File

@ -0,0 +1,59 @@
===========================train_params===========================
model_name:vi_layoutxlm_ser
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8
Architecture.Backbone.checkpoints:null
train_model_name:latest
train_infer_img_dir:ppstructure/docs/vqa/input/zh_val_42.jpg
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ./configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Architecture.Backbone.checkpoints:
norm_export:tools/export_model.py -c ./configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml -o
quant_export:
fpgm_export:
distill_export:null
export1:null
export2:null
##
infer_model:null
infer_export:null
infer_quant:False
inference:ppstructure/vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_dict_path=train_data/XFUND/class_list_xfun.txt --output=output --ocr_order_method=tb-yx
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--ser_model_dir:
--image_dir:./ppstructure/docs/vqa/input/zh_val_42.jpg
null:null
--benchmark:False
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
===========================train_benchmark_params==========================
batch_size:4
fp_items:fp32|fp16
epoch:3
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98

View File

@ -106,7 +106,7 @@ if [ ${MODE} = "benchmark_train" ];then
ln -s ./icdar2015_benckmark ./icdar2015
cd ../
fi
if [ ${model_name} == "layoutxlm_ser" ]; then
if [ ${model_name} == "layoutxlm_ser" ] || [ ${model_name} == "vi_layoutxlm_ser" ]; then
pip install -r ppstructure/vqa/requirements.txt
pip install paddlenlp\>=2.3.5 --force-reinstall -i https://mirrors.aliyun.com/pypi/simple/
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate
@ -220,7 +220,7 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/rec_r32_gaspin_bilstm_att_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf rec_r32_gaspin_bilstm_att_train.tar && cd ../
fi
if [ ${model_name} == "layoutxlm_ser" ]; then
if [ ${model_name} == "layoutxlm_ser" ] || [ ${model_name} == "vi_layoutxlm_ser" ]; then
pip install -r ppstructure/vqa/requirements.txt
pip install paddlenlp\>=2.3.5 --force-reinstall -i https://mirrors.aliyun.com/pypi/simple/
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate

View File

@ -84,7 +84,7 @@ function func_cpp_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${command}" "${status_log}" "${model_name}"
status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
done
@ -117,7 +117,7 @@ function func_cpp_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${command}" "${status_log}" "${model_name}"
status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done

View File

@ -88,7 +88,7 @@ function func_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${command}" "${status_log}" "${model_name}"
status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
done
@ -119,7 +119,7 @@ function func_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${command}" "${status_log}" "${model_name}"
status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
@ -146,14 +146,15 @@ if [ ${MODE} = "whole_infer" ]; then
for infer_model in ${infer_model_dir_list[*]}; do
# run export
if [ ${infer_run_exports[Count]} != "null" ];then
_save_log_path="${_log_path}/python_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}_infermodel_${infer_model}.log"
save_infer_dir=$(dirname $infer_model)
set_export_weight=$(func_set_params "${export_weight}" "${infer_model}")
set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}")
export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key}"
export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key} > ${_save_log_path} 2>&1 "
echo ${infer_run_exports[Count]}
eval $export_cmd
status_export=$?
status_check $status_export "${export_cmd}" "${status_log}" "${model_name}"
status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
else
save_infer_dir=${infer_model}
fi

View File

@ -66,7 +66,7 @@ function func_paddle2onnx(){
trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_det_log} 2>&1 "
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_det_log}"
# trans rec
set_dirname=$(func_set_params "--model_dir" "${rec_infer_model_dir_value}")
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
@ -78,7 +78,7 @@ function func_paddle2onnx(){
trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_rec_log} 2>&1 "
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}"
elif [[ ${model_name} =~ "det" ]]; then
# trans det
set_dirname=$(func_set_params "--model_dir" "${det_infer_model_dir_value}")
@ -91,7 +91,7 @@ function func_paddle2onnx(){
trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_det_log} 2>&1 "
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_det_log}"
elif [[ ${model_name} =~ "rec" ]]; then
# trans rec
set_dirname=$(func_set_params "--model_dir" "${rec_infer_model_dir_value}")
@ -104,7 +104,7 @@ function func_paddle2onnx(){
trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_rec_log} 2>&1 "
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}"
fi
# python inference
@ -127,7 +127,7 @@ function func_paddle2onnx(){
eval $infer_model_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
elif [ ${use_gpu} = "True" ] || [ ${use_gpu} = "gpu" ]; then
_save_log_path="${LOG_PATH}/paddle2onnx_infer_gpu.log"
set_gpu=$(func_set_params "${use_gpu_key}" "${use_gpu}")
@ -146,7 +146,7 @@ function func_paddle2onnx(){
eval $infer_model_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
else
echo "Does not support hardware other than CPU and GPU Currently!"
fi
@ -158,4 +158,4 @@ echo "################### run test ###################"
export Count=0
IFS="|"
func_paddle2onnx
func_paddle2onnx

View File

@ -84,7 +84,7 @@ function func_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${command}" "${status_log}" "${model_name}"
status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
done
@ -109,7 +109,7 @@ function func_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${command}" "${status_log}" "${model_name}"
status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
@ -145,7 +145,7 @@ if [ ${MODE} = "whole_infer" ]; then
echo $export_cmd
eval $export_cmd
status_export=$?
status_check $status_export "${export_cmd}" "${status_log}" "${model_name}"
status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" "${export_log_path}"
else
save_infer_dir=${infer_model}
fi

View File

@ -83,7 +83,7 @@ function func_serving(){
trans_model_cmd="${python_list[0]} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client} > ${trans_rec_log} 2>&1 "
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}"
set_image_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}")
python_list=(${python_list})
cd ${serving_dir_value}
@ -95,14 +95,14 @@ function func_serving(){
web_service_cpp_cmd="nohup ${python_list[0]} ${web_service_py} --model ${det_server_value} ${rec_server_value} ${op_key} ${op_value} ${port_key} ${port_value} > ${server_log_path} 2>&1 &"
eval $web_service_cpp_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cpp_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${web_service_cpp_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
sleep 5s
_save_log_path="${LOG_PATH}/cpp_client_cpu.log"
cpp_client_cmd="${python_list[0]} ${cpp_client_py} ${det_client_value} ${rec_client_value} > ${_save_log_path} 2>&1"
eval $cpp_client_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
ps ux | grep -i ${port_value} | awk '{print $2}' | xargs kill -s 9
else
server_log_path="${LOG_PATH}/cpp_server_gpu.log"
@ -114,7 +114,7 @@ function func_serving(){
eval $cpp_client_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
ps ux | grep -i ${port_value} | awk '{print $2}' | xargs kill -s 9
fi
done

View File

@ -126,19 +126,19 @@ function func_serving(){
web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
elif [[ ${model_name} =~ "det" ]]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
elif [[ ${model_name} =~ "rec" ]]; then
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
fi
sleep 2s
for pipeline in ${pipeline_py[*]}; do
@ -147,7 +147,7 @@ function func_serving(){
eval $pipeline_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
sleep 2s
done
ps ux | grep -E 'web_service' | awk '{print $2}' | xargs kill -s 9
@ -177,19 +177,19 @@ function func_serving(){
web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
elif [[ ${model_name} =~ "det" ]]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
elif [[ ${model_name} =~ "rec" ]]; then
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
fi
sleep 2s
for pipeline in ${pipeline_py[*]}; do
@ -198,7 +198,7 @@ function func_serving(){
eval $pipeline_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
sleep 2s
done
ps ux | grep -E 'web_service' | awk '{print $2}' | xargs kill -s 9

View File

@ -133,7 +133,7 @@ function func_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${command}" "${status_log}" "${model_name}"
status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
done
@ -164,7 +164,7 @@ function func_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${command}" "${status_log}" "${model_name}"
status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
@ -201,7 +201,7 @@ if [ ${MODE} = "whole_infer" ]; then
echo $export_cmd
eval $export_cmd
status_export=$?
status_check $status_export "${export_cmd}" "${status_log}" "${model_name}"
status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" "${export_log_path}"
else
save_infer_dir=${infer_model}
fi
@ -298,7 +298,7 @@ else
# run train
eval $cmd
eval "cat ${save_log}/train.log >> ${save_log}.log"
status_check $? "${cmd}" "${status_log}" "${model_name}"
status_check $? "${cmd}" "${status_log}" "${model_name}" "${save_log}.log"
set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${train_model_name}")
@ -309,7 +309,7 @@ else
eval_log_path="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}_nodes_${nodes}_eval.log"
eval_cmd="${python} ${eval_py} ${set_eval_pretrain} ${set_use_gpu} ${set_eval_params1} > ${eval_log_path} 2>&1 "
eval $eval_cmd
status_check $? "${eval_cmd}" "${status_log}" "${model_name}"
status_check $? "${eval_cmd}" "${status_log}" "${model_name}" "${eval_log_path}"
fi
# run export model
if [ ${run_export} != "null" ]; then
@ -320,7 +320,7 @@ else
set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_path}")
export_cmd="${python} ${run_export} ${set_export_weight} ${set_save_infer_key} > ${export_log_path} 2>&1 "
eval $export_cmd
status_check $? "${export_cmd}" "${status_log}" "${model_name}"
status_check $? "${export_cmd}" "${status_log}" "${model_name}" "${export_log_path}"
#run inference
eval $env

View File

@ -58,6 +58,8 @@ def export_single_model(model,
other_shape = [
paddle.static.InputSpec(
shape=[None, 3, 48, 160], dtype="float32"),
[paddle.static.InputSpec(
shape=[None], dtype="float32")]
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "SVTR":
@ -144,7 +146,7 @@ def export_single_model(model,
else:
infer_shape = [3, -1, -1]
if arch_config["model_type"] == "rec":
infer_shape = [3, 48, -1] # for rec model, H must be 32
infer_shape = [3, 32, -1] # for rec model, H must be 32
if "Transform" in arch_config and arch_config[
"Transform"] is not None and arch_config["Transform"][
"name"] == "TPS":
@ -156,6 +158,8 @@ def export_single_model(model,
infer_shape = [3, 488, 488]
if arch_config["algorithm"] == "TableMaster":
infer_shape = [3, 480, 480]
if arch_config["algorithm"] == "SLANet":
infer_shape = [3, -1, -1]
model = to_static(
model,
input_spec=[
@ -248,4 +252,4 @@ def main():
if __name__ == "__main__":
main()
main()

View File

@ -458,7 +458,8 @@ class TextRecognizer(object):
valid_ratios = np.concatenate(valid_ratios)
inputs = [
norm_img_batch,
valid_ratios,
np.array(
[valid_ratios], dtype=np.float32),
]
if self.use_onnx:
input_dict = {}

View File

@ -65,9 +65,11 @@ class TextSystem(object):
self.crop_image_res_index += bbox_num
def __call__(self, img, cls=True):
# Per-stage timings returned to the caller; the keys must match the names
# assigned later in this method ('det', 'cls', 'rec', 'all').
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
start = time.time()
ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img)
time_dict['det'] = elapse
logger.debug("dt_boxes num : {}, elapse : {}".format(
len(dt_boxes), elapse))
if dt_boxes is None:
@ -83,10 +85,12 @@ class TextSystem(object):
if self.use_angle_cls and cls:
img_crop_list, angle_list, elapse = self.text_classifier(
img_crop_list)
time_dict['cls'] = elapse
logger.debug("cls num : {}, elapse : {}".format(
len(img_crop_list), elapse))
rec_res, elapse = self.text_recognizer(img_crop_list)
time_dict['rec'] = elapse
logger.debug("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse))
if self.args.save_crop_res:
@ -98,7 +102,9 @@ class TextSystem(object):
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_result)
return filter_boxes, filter_rec_res
end = time.time()
time_dict['all'] = end - start
return filter_boxes, filter_rec_res, time_dict
def sorted_boxes(dt_boxes):
@ -133,9 +139,11 @@ def main(args):
os.makedirs(draw_img_save_dir, exist_ok=True)
save_results = []
logger.info("In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
"if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320")
logger.info(
"In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
"if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320"
)
# warm up 10 times
if args.warmup:
img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
@ -155,7 +163,7 @@ def main(args):
logger.debug("error in loading image:{}".format(image_file))
continue
starttime = time.time()
dt_boxes, rec_res = text_sys(img)
dt_boxes, rec_res, time_dict = text_sys(img)
elapse = time.time() - starttime
total_time += elapse
@ -198,7 +206,10 @@ def main(args):
text_sys.text_detector.autolog.report()
text_sys.text_recognizer.autolog.report()
with open(os.path.join(draw_img_save_dir, "system_results.txt"), 'w', encoding='utf-8') as f:
with open(
os.path.join(draw_img_save_dir, "system_results.txt"),
'w',
encoding='utf-8') as f:
f.writelines(save_results)

View File

@ -163,6 +163,8 @@ def create_predictor(args, mode, logger):
model_dir = args.ser_model_dir
elif mode == "sr":
model_dir = args.sr_model_dir
elif mode == 'layout':
model_dir = args.layout_model_dir
else:
model_dir = args.e2e_model_dir

View File

@ -37,6 +37,7 @@ from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
from ppocr.utils.visual import draw_rectangle
from tools.infer.utility import draw_boxes
import tools.program as program
import cv2
@ -56,7 +57,6 @@ def main(config, device, logger, vdl_writer):
model = build_model(config['Architecture'])
algorithm = config['Architecture']['algorithm']
use_xywh = algorithm in ['TableMaster']
load_model(config, model)
@ -106,9 +106,13 @@ def main(config, device, logger, vdl_writer):
f_w.write("result: {}, {}\n".format(structure_str_list,
bbox_list_str))
img = draw_rectangle(file, bbox_list, use_xywh)
if len(bbox_list) > 0 and len(bbox_list[0]) == 4:
img = draw_rectangle(file, bbox_list)
else:
img = draw_boxes(cv2.imread(file), bbox_list)
cv2.imwrite(
os.path.join(save_res_path, os.path.basename(file)), img)
logger.info('save result to {}'.format(save_res_path))
logger.info("success!")

View File

@ -653,7 +653,7 @@ def preprocess(is_train=False):
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
'Gestalt', 'RobustScanner'
'Gestalt', 'SLANet', 'RobustScanner'
]
if use_xpu:

View File

@ -119,6 +119,10 @@ def main(config, device, logger, vdl_writer):
config['Loss']['ignore_index'] = char_num - 1
model = build_model(config['Architecture'])
use_sync_bn = config["Global"].get("use_sync_bn", False)
if use_sync_bn:
model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
logger.info('convert_sync_batchnorm')
model = apply_to_static(model, config, logger)