Merge branch 'dygraph' into robustscanner_branch
commit 40c45e2ccc

Changed paths:
  configs/table
  deploy
  ppocr/data/imaug
  ppocr/metrics
  ppocr/modeling
  ppocr/postprocess
  ppstructure
  test_tipc/configs/det_r18_vd_db_v2_0
  test_tipc/configs/en_table_structure
  test_tipc/configs/layoutxlm_ser
  test_tipc/configs/table_master
  test_tipc/configs/vi_layoutxlm_ser
@@ -31,4 +31,4 @@ paddleocr.egg-info/
/deploy/android_demo/app/.cxx/
/deploy/android_demo/app/cache/
test_tipc/web/models/
test_tipc/web/node_modules/
@@ -0,0 +1,143 @@
Global:
  use_gpu: true
  epoch_num: 100
  log_smooth_window: 20
  print_batch_step: 20
  save_model_dir: ./output/SLANet
  save_epoch_step: 400
  # evaluation is run every 1000 iterations after the 0th iteration
  eval_batch_step: [0, 1000]
  cal_metric_during_train: True
  pretrained_model:
  checkpoints:
  save_inference_dir: ./output/SLANet/infer
  use_visualdl: False
  infer_img: doc/table/table.jpg
  # for data or label process
  character_dict_path: ppocr/utils/dict/table_structure_dict.txt
  character_type: en
  max_text_length: &max_text_length 500
  box_format: &box_format 'xyxy' # 'xywh', 'xyxy', 'xyxyxyxy'
  infer_mode: False
  use_sync_bn: True
  save_res_path: 'output/infer'

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  clip_norm: 5.0
  lr:
    name: Piecewise
    learning_rate: 0.001
    decay_epochs: [40, 50]
    values: [0.001, 0.0001, 0.00005]
  regularizer:
    name: 'L2'
    factor: 0.00000

Architecture:
  model_type: table
  algorithm: SLANet
  Backbone:
    name: PPLCNet
    scale: 1.0
    pretrained: true
    use_ssld: true
  Neck:
    name: CSPPAN
    out_channels: 96
  Head:
    name: SLAHead
    hidden_size: 256
    max_text_length: *max_text_length
    loc_reg_num: &loc_reg_num 4

Loss:
  name: SLALoss
  structure_weight: 1.0
  loc_weight: 2.0
  loc_loss: smooth_l1

PostProcess:
  name: TableLabelDecode
  merge_no_span_structure: &merge_no_span_structure True

Metric:
  name: TableMetric
  main_indicator: acc
  compute_bbox_metric: False
  loc_reg_num: *loc_reg_num
  box_format: *box_format

Train:
  dataset:
    name: PubTabDataSet
    data_dir: train_data/table/pubtabnet/train/
    label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - TableLabelEncode:
          learn_empty_box: False
          merge_no_span_structure: *merge_no_span_structure
          replace_empty_cell_token: False
          loc_reg_num: *loc_reg_num
          max_text_length: *max_text_length
      - TableBoxEncode:
          in_box_format: *box_format
          out_box_format: *box_format
      - ResizeTableImage:
          max_len: 488
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - PaddingTableImage:
          size: [488, 488]
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
  loader:
    shuffle: True
    batch_size_per_card: 48
    drop_last: True
    num_workers: 1

Eval:
  dataset:
    name: PubTabDataSet
    data_dir: train_data/table/pubtabnet/val/
    label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - TableLabelEncode:
          learn_empty_box: False
          merge_no_span_structure: *merge_no_span_structure
          replace_empty_cell_token: False
          loc_reg_num: *loc_reg_num
          max_text_length: *max_text_length
      - TableBoxEncode:
          in_box_format: *box_format
          out_box_format: *box_format
      - ResizeTableImage:
          max_len: 488
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - PaddingTableImage:
          size: [488, 488]
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 48
    num_workers: 1
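The config above leans on YAML anchors (`&max_text_length`, `&box_format`, `&loc_reg_num`, `&merge_no_span_structure`): a value is defined once and reused everywhere it is aliased with `*`, so the head, loss, metric, and data transforms cannot drift apart. A minimal sketch of how the anchors resolve at parse time (PyYAML assumed; the snippet is illustrative and not part of the diff):

```python
import yaml

# A stripped-down excerpt of SLANet.yml: the anchor (&) defines the value,
# every alias (*) reuses it, so changing 500 in one place updates all readers.
snippet = """
Global:
  max_text_length: &max_text_length 500
Architecture:
  Head:
    max_text_length: *max_text_length
"""
cfg = yaml.safe_load(snippet)
assert cfg["Architecture"]["Head"]["max_text_length"] == 500
print(cfg)
```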
@@ -8,16 +8,15 @@ Global:
   eval_batch_step: [0, 6259]
   cal_metric_during_train: true
   pretrained_model: null
   checkpoints:
   save_inference_dir: output/table_master/infer
   use_visualdl: false
   infer_img: ppstructure/docs/table/table.jpg
   save_res_path: ./output/table_master
   character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
   infer_mode: false
-  max_text_length: 500
-  process_total_num: 0
-  process_cut_num: 0
+  max_text_length: &max_text_length 500
+  box_format: &box_format 'xywh' # 'xywh', 'xyxy', 'xyxyxyxy'

 Optimizer:

@@ -52,7 +51,8 @@ Architecture:
     headers: 8
     dropout: 0
     d_ff: 2024
-    max_text_length: 500
+    max_text_length: *max_text_length
+    loc_reg_num: &loc_reg_num 4

 Loss:
   name: TableMasterLoss

@@ -61,11 +61,13 @@ Loss:

 PostProcess:
   name: TableMasterLabelDecode
   box_shape: pad
+  merge_no_span_structure: &merge_no_span_structure True

 Metric:
   name: TableMetric
   main_indicator: acc
   compute_bbox_metric: False
+  box_format: *box_format

 Train:
   dataset:
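`merge_no_span_structure` controls whether a `<td>` immediately followed by `</td>` is collapsed into the single structure token `<td></td>`, which shortens the predicted tag sequence for cells without row/column spans. A rough Python sketch of the idea (the helper itself is hypothetical; only the token names follow the table structure dict):

```python
def merge_no_span_structure(tokens):
    """Collapse '<td>' immediately followed by '</td>' into one token."""
    merged, i = [], 0
    while i < len(tokens):
        if tokens[i] == "<td>" and i + 1 < len(tokens) and tokens[i + 1] == "</td>":
            merged.append("<td></td>")
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged

print(merge_no_span_structure(
    ["<tr>", "<td>", "</td>", "<td", ' colspan="2">', "</td>", "</tr>"]))
# ['<tr>', '<td></td>', '<td', ' colspan="2">', '</td>', '</tr>']
```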
@@ -78,15 +80,18 @@ Train:
          channel_first: False
       - TableMasterLabelEncode:
           learn_empty_box: False
-          merge_no_span_structure: True
+          merge_no_span_structure: *merge_no_span_structure
           replace_empty_cell_token: True
+          loc_reg_num: *loc_reg_num
+          max_text_length: *max_text_length
       - ResizeTableImage:
           max_len: 480
           resize_bboxes: True
       - PaddingTableImage:
           size: [480, 480]
       - TableBoxEncode:
-          use_xywh: True
+          in_box_format: *box_format
+          out_box_format: *box_format
       - NormalizeImage:
           scale: 1./255.
           mean: [0.5, 0.5, 0.5]

@@ -112,15 +117,18 @@ Eval:
          channel_first: False
       - TableMasterLabelEncode:
           learn_empty_box: False
-          merge_no_span_structure: True
+          merge_no_span_structure: *merge_no_span_structure
           replace_empty_cell_token: True
+          loc_reg_num: *loc_reg_num
+          max_text_length: *max_text_length
       - ResizeTableImage:
           max_len: 480
           resize_bboxes: True
       - PaddingTableImage:
           size: [480, 480]
       - TableBoxEncode:
-          use_xywh: True
+          in_box_format: *box_format
+          out_box_format: *box_format
       - NormalizeImage:
           scale: 1./255.
           mean: [0.5, 0.5, 0.5]
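`box_format` differs between the configs: SLANet regresses `'xyxy'` corner boxes while TableMaster uses `'xywh'` center boxes, and the generalized `TableBoxEncode` (replacing the old `use_xywh` flag) converts between them via `in_box_format`/`out_box_format`. The conversion itself is simple; an illustrative sketch, not the repo's implementation:

```python
def xyxy_to_xywh(box):
    """(x1, y1, x2, y2) corners -> (cx, cy, w, h) center format."""
    x1, y1, x2, y2 = box
    return ((x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1)

def xywh_to_xyxy(box):
    """(cx, cy, w, h) center format -> (x1, y1, x2, y2) corners."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)

# round-trip sanity check
assert xywh_to_xyxy(xyxy_to_xywh((10, 20, 50, 80))) == (10.0, 20.0, 50.0, 80.0)
```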
@@ -17,10 +17,9 @@ Global:
   # for data or label process
   character_dict_path: ppocr/utils/dict/table_structure_dict.txt
   character_type: en
-  max_text_length: 800
+  max_text_length: &max_text_length 800
+  box_format: &box_format 'xyxy' # 'xywh', 'xyxy', 'xyxyxyxy'
   infer_mode: False
-  process_total_num: 0
-  process_cut_num: 0

 Optimizer:
   name: Adam

@@ -44,7 +43,8 @@ Architecture:
     name: TableAttentionHead
     hidden_size: 256
     loc_type: 2
-    max_text_length: 800
+    max_text_length: *max_text_length
+    loc_reg_num: &loc_reg_num 4

 Loss:
   name: TableAttentionLoss

@@ -72,6 +72,8 @@ Train:
           learn_empty_box: False
           merge_no_span_structure: False
           replace_empty_cell_token: False
+          loc_reg_num: *loc_reg_num
+          max_text_length: *max_text_length
       - TableBoxEncode:
       - ResizeTableImage:
           max_len: 488

@@ -94,8 +96,8 @@ Train:
 Eval:
   dataset:
     name: PubTabDataSet
-    data_dir: /home/zhoujun20/table/PubTabNe/pubtabnet/val/
-    label_file_list: [/home/zhoujun20/table/PubTabNe/pubtabnet/val_500.jsonl]
+    data_dir: train_data/table/pubtabnet/val/
+    label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl]
     transforms:
       - DecodeImage: # load image
           img_mode: BGR

@@ -104,6 +106,8 @@ Eval:
           learn_empty_box: False
           merge_no_span_structure: False
           replace_empty_cell_token: False
+          loc_reg_num: *loc_reg_num
+          max_text_length: *max_text_length
       - TableBoxEncode:
       - ResizeTableImage:
           max_len: 488
@@ -30,7 +30,8 @@ DECLARE_string(image_dir);
 DECLARE_string(type);
 // detection related
 DECLARE_string(det_model_dir);
-DECLARE_int32(max_side_len);
+DECLARE_string(limit_type);
+DECLARE_int32(limit_side_len);
 DECLARE_double(det_db_thresh);
 DECLARE_double(det_db_box_thresh);
 DECLARE_double(det_db_unclip_ratio);

@@ -48,7 +49,13 @@ DECLARE_int32(rec_batch_num);
 DECLARE_string(rec_char_dict_path);
 DECLARE_int32(rec_img_h);
 DECLARE_int32(rec_img_w);
+// structure model related
+DECLARE_string(table_model_dir);
+DECLARE_int32(table_max_len);
+DECLARE_int32(table_batch_num);
+DECLARE_string(table_char_dict_path);
 // forward related
 DECLARE_bool(det);
 DECLARE_bool(rec);
 DECLARE_bool(cls);
+DECLARE_bool(table);
@@ -41,8 +41,8 @@ public:
   explicit DBDetector(const std::string &model_dir, const bool &use_gpu,
                       const int &gpu_id, const int &gpu_mem,
                       const int &cpu_math_library_num_threads,
-                      const bool &use_mkldnn, const int &max_side_len,
-                      const double &det_db_thresh,
+                      const bool &use_mkldnn, const string &limit_type,
+                      const int &limit_side_len, const double &det_db_thresh,
                       const double &det_db_box_thresh,
                       const double &det_db_unclip_ratio,
                       const std::string &det_db_score_mode,

@@ -54,7 +54,8 @@ public:
     this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
     this->use_mkldnn_ = use_mkldnn;

-    this->max_side_len_ = max_side_len;
+    this->limit_type_ = limit_type;
+    this->limit_side_len_ = limit_side_len;

     this->det_db_thresh_ = det_db_thresh;
     this->det_db_box_thresh_ = det_db_box_thresh;

@@ -84,7 +85,8 @@ private:
   int cpu_math_library_num_threads_ = 4;
   bool use_mkldnn_ = false;

-  int max_side_len_ = 960;
+  string limit_type_ = "max";
+  int limit_side_len_ = 960;

   double det_db_thresh_ = 0.3;
   double det_db_box_thresh_ = 0.5;

@@ -106,7 +108,7 @@ private:
   Permute permute_op_;

   // post-process
-  PostProcessor post_processor_;
+  DBPostProcessor post_processor_;
 };

 } // namespace PaddleOCR
@@ -47,11 +47,7 @@ public:
   ocr(std::vector<cv::String> cv_all_img_names, bool det = true,
       bool rec = true, bool cls = true);

-private:
-  DBDetector *detector_ = nullptr;
-  Classifier *classifier_ = nullptr;
-  CRNNRecognizer *recognizer_ = nullptr;
-
 protected:
   void det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
            std::vector<double> &times);
   void rec(std::vector<cv::Mat> img_list,

@@ -62,6 +58,11 @@ private:
            std::vector<double> &times);
   void log(std::vector<double> &det_times, std::vector<double> &rec_times,
            std::vector<double> &cls_times, int img_num);
+
+private:
+  DBDetector *detector_ = nullptr;
+  Classifier *classifier_ = nullptr;
+  CRNNRecognizer *recognizer_ = nullptr;
 };

 } // namespace PaddleOCR
@@ -0,0 +1,79 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>

#include <cstring>
#include <fstream>
#include <numeric>

#include <include/paddleocr.h>
#include <include/preprocess_op.h>
#include <include/structure_table.h>
#include <include/utility.h>

using namespace paddle_infer;

namespace PaddleOCR {

class PaddleStructure : public PPOCR {
public:
  explicit PaddleStructure();
  ~PaddleStructure();
  std::vector<std::vector<StructurePredictResult>>
  structure(std::vector<cv::String> cv_all_img_names, bool layout = false,
            bool table = true);

private:
  StructureTableRecognizer *recognizer_ = nullptr;

  void table(cv::Mat img, StructurePredictResult &structure_result,
             std::vector<double> &time_info_table,
             std::vector<double> &time_info_det,
             std::vector<double> &time_info_rec,
             std::vector<double> &time_info_cls);
  std::string
  rebuild_table(std::vector<std::string> rec_html_tags,
                std::vector<std::vector<std::vector<int>>> rec_boxes,
                std::vector<OCRPredictResult> &ocr_result);

  float iou(std::vector<std::vector<int>> &box1,
            std::vector<std::vector<int>> &box2);
  float dis(std::vector<std::vector<int>> &box1,
            std::vector<std::vector<int>> &box2);

  // order candidate cells by the overlap term (index 1) first,
  // then by the corner distance (index 0)
  static bool comparison_dis(const std::vector<float> &dis1,
                             const std::vector<float> &dis2) {
    if (dis1[1] < dis2[1]) {
      return true;
    } else if (dis1[1] == dis2[1]) {
      return dis1[0] < dis2[0];
    } else {
      return false;
    }
  }
};

} // namespace PaddleOCR
@@ -34,7 +34,7 @@ using namespace std;

 namespace PaddleOCR {

-class PostProcessor {
+class DBPostProcessor {
 public:
   void GetContourArea(const std::vector<std::vector<float>> &box,
                       float unclip_ratio, float &distance);

@@ -90,4 +90,21 @@ private:
   }
 };

+class TablePostProcessor {
+public:
+  void init(std::string label_path);
+  void
+  Run(std::vector<float> &loc_preds, std::vector<float> &structure_probs,
+      std::vector<float> &rec_scores, std::vector<int> &loc_preds_shape,
+      std::vector<int> &structure_probs_shape,
+      std::vector<std::vector<std::string>> &rec_html_tag_batch,
+      std::vector<std::vector<std::vector<std::vector<int>>>> &rec_boxes_batch,
+      std::vector<int> &width_list, std::vector<int> &height_list);
+
+private:
+  std::vector<std::string> label_list_;
+  std::string end = "eos";
+  std::string beg = "sos";
+};
+
 } // namespace PaddleOCR
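`TablePostProcessor::Run` receives the model outputs as flat 1-D buffers plus their shapes, and greedy-decodes each step by taking the argmax over the class axis. A Python sketch of the same indexing scheme, assuming `structure_probs` is a `[batch, steps, num_classes]` tensor flattened row-major and `label_list` is the structure dict wrapped with the `"sos"`/`"eos"` tokens (illustrative only):

```python
import numpy as np

def decode_structure(structure_probs, shape, label_list):
    """Greedy-decode structure tokens from a flattened [batch, steps, classes] array."""
    batch, steps, classes = shape
    probs = np.asarray(structure_probs).reshape(batch, steps, classes)
    results = []
    for b in range(batch):
        tags = []
        for t in range(steps):
            idx = int(probs[b, t].argmax())  # same role as Utility::argmax per step
            tag = label_list[idx]
            if t > 0 and tag == "eos":       # end token finishes the sequence
                break
            if tag == "sos":                 # start token is skipped
                continue
            tags.append(tag)
        results.append(tags)
    return results
```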
@@ -48,11 +48,12 @@ class PermuteBatch {
 public:
   virtual void Run(const std::vector<cv::Mat> imgs, float *data);
 };

 class ResizeImgType0 {
 public:
-  virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len,
-                   float &ratio_h, float &ratio_w, bool use_tensorrt);
+  virtual void Run(const cv::Mat &img, cv::Mat &resize_img, string limit_type,
+                   int limit_side_len, float &ratio_h, float &ratio_w,
+                   bool use_tensorrt);
 };

 class CrnnResizeImg {

@@ -69,4 +70,16 @@ public:
                    const std::vector<int> &rec_image_shape = {3, 48, 192});
 };

+class TableResizeImg {
+public:
+  virtual void Run(const cv::Mat &img, cv::Mat &resize_img,
+                   const int max_len = 488);
+};
+
+class TablePadImg {
+public:
+  virtual void Run(const cv::Mat &img, cv::Mat &resize_img,
+                   const int max_len = 488);
+};
+
 } // namespace PaddleOCR
@@ -0,0 +1,100 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>

#include <cstring>
#include <fstream>
#include <numeric>

#include <include/postprocess_op.h>
#include <include/preprocess_op.h>
#include <include/utility.h>

using namespace paddle_infer;

namespace PaddleOCR {

class StructureTableRecognizer {
public:
  explicit StructureTableRecognizer(
      const std::string &model_dir, const bool &use_gpu, const int &gpu_id,
      const int &gpu_mem, const int &cpu_math_library_num_threads,
      const bool &use_mkldnn, const string &label_path,
      const bool &use_tensorrt, const std::string &precision,
      const int &table_batch_num, const int &table_max_len) {
    this->use_gpu_ = use_gpu;
    this->gpu_id_ = gpu_id;
    this->gpu_mem_ = gpu_mem;
    this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
    this->use_mkldnn_ = use_mkldnn;
    this->use_tensorrt_ = use_tensorrt;
    this->precision_ = precision;
    this->table_batch_num_ = table_batch_num;
    this->table_max_len_ = table_max_len;

    this->post_processor_.init(label_path);
    LoadModel(model_dir);
  }

  // Load Paddle inference model
  void LoadModel(const std::string &model_dir);

  void Run(std::vector<cv::Mat> img_list,
           std::vector<std::vector<std::string>> &rec_html_tags,
           std::vector<float> &rec_scores,
           std::vector<std::vector<std::vector<std::vector<int>>>> &rec_boxes,
           std::vector<double> &times);

private:
  std::shared_ptr<Predictor> predictor_;

  bool use_gpu_ = false;
  int gpu_id_ = 0;
  int gpu_mem_ = 4000;
  int cpu_math_library_num_threads_ = 4;
  bool use_mkldnn_ = false;
  int table_max_len_ = 488;

  std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
  std::vector<float> scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
  bool is_scale_ = true;

  bool use_tensorrt_ = false;
  std::string precision_ = "fp32";
  int table_batch_num_ = 1;

  // pre-process
  TableResizeImg resize_op_;
  Normalize normalize_op_;
  PermuteBatch permute_op_;
  TablePadImg pad_op_;

  // post-process
  TablePostProcessor post_processor_;

}; // class StructureTableRecognizer

} // namespace PaddleOCR
@@ -40,6 +40,14 @@ struct OCRPredictResult {
   int cls_label = -1;
 };

+struct StructurePredictResult {
+  std::vector<int> box;
+  std::string type;
+  std::vector<OCRPredictResult> text_res;
+  std::string html;
+  float html_score = -1;
+};
+
 class Utility {
 public:
   static std::vector<std::string> ReadDict(const std::string &path);

@@ -68,6 +76,22 @@ public:
   static void CreateDir(const std::string &path);

   static void print_result(const std::vector<OCRPredictResult> &ocr_result);

+  static cv::Mat crop_image(cv::Mat &img, std::vector<int> &area);
+
+  static void sorted_boxes(std::vector<OCRPredictResult> &ocr_result);
+
+private:
+  static bool comparison_box(const OCRPredictResult &result1,
+                             const OCRPredictResult &result2) {
+    if (result1.box[0][1] < result2.box[0][1]) {
+      return true;
+    } else if (result1.box[0][1] == result2.box[0][1]) {
+      return result1.box[0][0] < result2.box[0][0];
+    } else {
+      return false;
+    }
+  }
 };

 } // namespace PaddleOCR
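`Utility::sorted_boxes` orders detected text boxes into reading order: top-to-bottom first, left-to-right as the tie-breaker, keyed on the top-left corner `box[0]`. The same ordering in Python, for reference (illustrative only):

```python
# Each box is a list of four (x, y) corner points; box[0] is the top-left corner.
boxes = [[(120, 40)], [(10, 40)], [(10, 5)]]

# Sort by y of the top-left point, then by x - mirroring Utility::comparison_box.
boxes.sort(key=lambda box: (box[0][1], box[0][0]))
print(boxes)  # [[(10, 5)], [(10, 40)], [(120, 40)]]
```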
@@ -171,6 +171,9 @@ inference/
 |-- cls
 |   |--inference.pdiparams
 |   |--inference.pdmodel
+|-- table
+|   |--inference.pdiparams
+|   |--inference.pdmodel
 ```
@@ -275,6 +278,17 @@ Specifically,
 --cls=true \
 ```

+##### 7. table
+```shell
+./build/ppocr --det_model_dir=inference/det_db \
+    --rec_model_dir=inference/rec_rcnn \
+    --table_model_dir=inference/table \
+    --image_dir=../../ppstructure/docs/table/table.jpg \
+    --type=structure \
+    --table=true
+```
+
 More parameters are as follows,

 - Common parameters
@@ -293,9 +307,9 @@ More parameters are as follows,

 |parameter|data type|default|meaning|
 | :---: | :---: | :---: | :---: |
-|det|bool|true|前向是否执行文字检测|
-|rec|bool|true|前向是否执行文字识别|
-|cls|bool|false|前向是否执行文字方向分类|
+|det|bool|true|Whether to perform text detection in the forward pass|
+|rec|bool|true|Whether to perform text recognition in the forward pass|
+|cls|bool|false|Whether to perform text direction classification in the forward pass|

 - Detection related parameters
@@ -329,6 +343,15 @@ More parameters are as follows,
 |rec_img_h|int|48|image height of recognition|
 |rec_img_w|int|320|image width of recognition|

+- Table recognition related parameters
+
+|parameter|data type|default|meaning|
+| :---: | :---: | :---: | :---: |
+|table_model_dir|string|-|Address of table recognition inference model|
+|table_char_dict_path|string|../../ppocr/utils/dict/table_structure_dict.txt|dictionary file|
+|table_max_len|int|488|Long side of the table recognition model's input image; the final network input size is (table_max_len, table_max_len)|
+
 * Multi-language inference is also supported in PaddleOCR; you can refer to the [recognition tutorial](../../doc/doc_en/recognition_en.md) for more supported languages and models. To run inference with a multi-language model, you only need to modify the values of `rec_char_dict_path` and `rec_model_dir`.
@@ -344,6 +367,12 @@ predict img: ../../doc/imgs/12.jpg
 The detection visualized image saved in ./output//12.jpg
 ```

+- table
+
+```bash
+predict img: ../../ppstructure/docs/table/table.jpg
+0	type: table, region: [0,0,371,293], res: <html><body><table><thead><tr><td>Methods</td><td>R</td><td>P</td><td>F</td><td>FPS</td></tr></thead><tbody><tr><td>SegLink [26]</td><td>70.0</td><td>86.0</td><td>77.0</td><td>8.9</td></tr><tr><td>PixelLink [4]</td><td>73.2</td><td>83.0</td><td>77.8</td><td>-</td></tr><tr><td>TextSnake [18]</td><td>73.9</td><td>83.2</td><td>78.3</td><td>1.1</td></tr><tr><td>TextField [37]</td><td>75.9</td><td>87.4</td><td>81.3</td><td>5.2 </td></tr><tr><td>MSR[38]</td><td>76.7</td><td>87.4</td><td>81.7</td><td>-</td></tr><tr><td>FTSN [3]</td><td>77.1</td><td>87.6</td><td>82.0</td><td>-</td></tr><tr><td>LSE[30]</td><td>81.7</td><td>84.2</td><td>82.9</td><td>-</td></tr><tr><td>CRAFT [2]</td><td>78.2</td><td>88.2</td><td>82.9</td><td>8.6</td></tr><tr><td>MCN [16]</td><td>79</td><td>88</td><td>83</td><td>-</td></tr><tr><td>ATRR[35]</td><td>82.1</td><td>85.2</td><td>83.6</td><td>-</td></tr><tr><td>PAN [34]</td><td>83.8</td><td>84.4</td><td>84.1</td><td>30.2</td></tr><tr><td>DB[12]</td><td>79.2</td><td>91.5</td><td>84.9</td><td>32.0</td></tr><tr><td>DRRG [41]</td><td>82.30</td><td>88.05</td><td>85.08</td><td>-</td></tr><tr><td>Ours (SynText)</td><td>80.68</td><td>85.40</td><td>82.97</td><td>12.68</td></tr><tr><td>Ours (MLT-17)</td><td>84.54</td><td>86.62</td><td>85.57</td><td>12.31</td></tr></tbody></table></body></html>
+```
+
 <a name="3"></a>
 ## 3. FAQ
@@ -181,6 +181,9 @@ inference/
 |-- cls
 |   |--inference.pdiparams
 |   |--inference.pdmodel
+|-- table
+|   |--inference.pdiparams
+|   |--inference.pdmodel
 ```

 <a name="22"></a>
@@ -285,6 +288,16 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
 --cls=true \
 ```

+##### 7. Table recognition
+```shell
+./build/ppocr --det_model_dir=inference/det_db \
+    --rec_model_dir=inference/rec_rcnn \
+    --table_model_dir=inference/table \
+    --image_dir=../../ppstructure/docs/table/table.jpg \
+    --type=structure \
+    --table=true
+```
+
 More supported adjustable parameters are explained below:

 - Common parameters
@@ -328,21 +341,32 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
 |cls_thresh|float|0.9|score threshold of the direction classifier|
 |cls_batch_num|int|1|batch size of the direction classifier|

-- Recognition model related
+- Text recognition model related

 |parameter|type|default|meaning|
 | :---: | :---: | :---: | :---: |
-|rec_model_dir|string|-|path of the recognition inference model|
+|rec_model_dir|string|-|path of the text recognition inference model|
 |rec_char_dict_path|string|../../ppocr/utils/ppocr_keys_v1.txt|dictionary file|
-|rec_batch_num|int|6|batch size of the recognition model|
-|rec_img_h|int|48|input image height of the recognition model|
-|rec_img_w|int|320|input image width of the recognition model|
+|rec_batch_num|int|6|batch size of the text recognition model|
+|rec_img_h|int|48|input image height of the text recognition model|
+|rec_img_w|int|320|input image width of the text recognition model|

+- Table recognition model related
+
+|parameter|type|default|meaning|
+| :---: | :---: | :---: | :---: |
+|table_model_dir|string|-|path of the table recognition inference model|
+|table_char_dict_path|string|../../ppocr/utils/dict/table_structure_dict.txt|dictionary file|
+|table_max_len|int|488|long side of the table recognition model's input image; the final network input size is (table_max_len, table_max_len)|
+
 * PaddleOCR also supports multi-language inference; see the multi-language dictionaries and models section of the [recognition documentation](../../doc/doc_ch/recognition.md) for the supported languages and models. For multi-language inference, just modify the `rec_char_dict_path` (dictionary file path) and `rec_model_dir` (inference model path) fields.

 The detection results will finally be printed to the screen, as follows.

 - ocr

 ```bash
 predict img: ../../doc/imgs/12.jpg
 ../../doc/imgs/12.jpg
@@ -353,6 +377,13 @@ predict img: ../../doc/imgs/12.jpg
 The detection visualized image saved in ./output//12.jpg
 ```

+- table
+
+```bash
+predict img: ../../ppstructure/docs/table/table.jpg
+0	type: table, region: [0,0,371,293], res: <html><body><table><thead><tr><td>Methods</td><td>R</td><td>P</td><td>F</td><td>FPS</td></tr></thead><tbody><tr><td>SegLink [26]</td><td>70.0</td><td>86.0</td><td>77.0</td><td>8.9</td></tr><tr><td>PixelLink [4]</td><td>73.2</td><td>83.0</td><td>77.8</td><td>-</td></tr><tr><td>TextSnake [18]</td><td>73.9</td><td>83.2</td><td>78.3</td><td>1.1</td></tr><tr><td>TextField [37]</td><td>75.9</td><td>87.4</td><td>81.3</td><td>5.2 </td></tr><tr><td>MSR[38]</td><td>76.7</td><td>87.4</td><td>81.7</td><td>-</td></tr><tr><td>FTSN [3]</td><td>77.1</td><td>87.6</td><td>82.0</td><td>-</td></tr><tr><td>LSE[30]</td><td>81.7</td><td>84.2</td><td>82.9</td><td>-</td></tr><tr><td>CRAFT [2]</td><td>78.2</td><td>88.2</td><td>82.9</td><td>8.6</td></tr><tr><td>MCN [16]</td><td>79</td><td>88</td><td>83</td><td>-</td></tr><tr><td>ATRR[35]</td><td>82.1</td><td>85.2</td><td>83.6</td><td>-</td></tr><tr><td>PAN [34]</td><td>83.8</td><td>84.4</td><td>84.1</td><td>30.2</td></tr><tr><td>DB[12]</td><td>79.2</td><td>91.5</td><td>84.9</td><td>32.0</td></tr><tr><td>DRRG [41]</td><td>82.30</td><td>88.05</td><td>85.08</td><td>-</td></tr><tr><td>Ours (SynText)</td><td>80.68</td><td>85.40</td><td>82.97</td><td>12.68</td></tr><tr><td>Ours (MLT-17)</td><td>84.54</td><td>86.62</td><td>85.57</td><td>12.31</td></tr></tbody></table></body></html>
+```
+
 <a name="3"></a>
 ## 3. FAQ
@@ -30,7 +30,8 @@ DEFINE_string(
     "Perform ocr or structure, the value is selected in ['ocr','structure'].");
 // detection related
 DEFINE_string(det_model_dir, "", "Path of det inference model.");
-DEFINE_int32(max_side_len, 960, "max_side_len of input image.");
+DEFINE_string(limit_type, "max", "limit_type of input image.");
+DEFINE_int32(limit_side_len, 960, "limit_side_len of input image.");
 DEFINE_double(det_db_thresh, 0.3, "Threshold of det_db_thresh.");
 DEFINE_double(det_db_box_thresh, 0.6, "Threshold of det_db_box_thresh.");
 DEFINE_double(det_db_unclip_ratio, 1.5, "Threshold of det_db_unclip_ratio.");

@@ -50,7 +51,16 @@ DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt",
 DEFINE_int32(rec_img_h, 48, "rec image height");
 DEFINE_int32(rec_img_w, 320, "rec image width");

+// structure model related
+DEFINE_string(table_model_dir, "", "Path of table structure inference model.");
+DEFINE_int32(table_max_len, 488, "max len size of input image.");
+DEFINE_int32(table_batch_num, 1, "table_batch_num.");
+DEFINE_string(table_char_dict_path,
+              "../../ppocr/utils/dict/table_structure_dict.txt",
+              "Path of dictionary.");
+
 // ocr forward related
 DEFINE_bool(det, true, "Whether use det in forward.");
 DEFINE_bool(rec, true, "Whether use rec in forward.");
 DEFINE_bool(cls, false, "Whether use cls in forward.");
+DEFINE_bool(table, false, "Whether use table structure in forward.");
@@ -19,6 +19,7 @@

 #include <include/args.h>
 #include <include/paddleocr.h>
+#include <include/paddlestructure.h>

 using namespace PaddleOCR;

@@ -32,6 +33,12 @@ void check_params() {
     }
   }
   if (FLAGS_rec) {
+    std::cout
+        << "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320'; "
+           "if you are using a recognition model from PP-OCRv2 or an older "
+           "version, please set --rec_image_shape='3,32,320'"
+        << std::endl;
     if (FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) {
       std::cout << "Usage[rec]: ./ppocr "
                    "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
@@ -47,6 +54,17 @@ void check_params() {
       exit(1);
     }
   }
+  if (FLAGS_table) {
+    if (FLAGS_table_model_dir.empty() || FLAGS_det_model_dir.empty() ||
+        FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) {
+      std::cout << "Usage[table]: ./ppocr "
+                << "--det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ "
+                << "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
+                << "--table_model_dir=/PATH/TO/TABLE_INFERENCE_MODEL/ "
+                << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl;
+      exit(1);
+    }
+  }
   if (FLAGS_precision != "fp32" && FLAGS_precision != "fp16" &&
       FLAGS_precision != "int8") {
     cout << "precision should be 'fp32'(default), 'fp16' or 'int8'. " << endl;
@@ -54,21 +72,7 @@ void check_params() {
   }
 }

-int main(int argc, char **argv) {
-  // Parsing command-line
-  google::ParseCommandLineFlags(&argc, &argv, true);
-  check_params();
-
-  if (!Utility::PathExists(FLAGS_image_dir)) {
-    std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir
-              << endl;
-    exit(1);
-  }
-
-  std::vector<cv::String> cv_all_img_names;
-  cv::glob(FLAGS_image_dir, cv_all_img_names);
-  std::cout << "total images num: " << cv_all_img_names.size() << endl;
-
+void ocr(std::vector<cv::String> &cv_all_img_names) {
   PPOCR ocr = PPOCR();

   std::vector<std::vector<OCRPredictResult>> ocr_results =
@@ -109,3 +113,49 @@ int main(int argc, char **argv) {
     }
   }
 }
+
+void structure(std::vector<cv::String> &cv_all_img_names) {
+  PaddleOCR::PaddleStructure engine = PaddleOCR::PaddleStructure();
+  std::vector<std::vector<StructurePredictResult>> structure_results =
+      engine.structure(cv_all_img_names, false, FLAGS_table);
+  for (int i = 0; i < cv_all_img_names.size(); i++) {
+    cout << "predict img: " << cv_all_img_names[i] << endl;
+    for (int j = 0; j < structure_results[i].size(); j++) {
+      std::cout << j << "\ttype: " << structure_results[i][j].type
+                << ", region: [";
+      std::cout << structure_results[i][j].box[0] << ","
+                << structure_results[i][j].box[1] << ","
+                << structure_results[i][j].box[2] << ","
+                << structure_results[i][j].box[3] << "], res: ";
+      if (structure_results[i][j].type == "table") {
+        std::cout << structure_results[i][j].html << std::endl;
+      } else {
+        Utility::print_result(structure_results[i][j].text_res);
+      }
+    }
+  }
+}
+
+int main(int argc, char **argv) {
+  // Parsing command-line
+  google::ParseCommandLineFlags(&argc, &argv, true);
+  check_params();
+
+  if (!Utility::PathExists(FLAGS_image_dir)) {
+    std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir
+              << endl;
+    exit(1);
+  }
+
+  std::vector<cv::String> cv_all_img_names;
+  cv::glob(FLAGS_image_dir, cv_all_img_names);
+  std::cout << "total images num: " << cv_all_img_names.size() << endl;
+
+  if (FLAGS_type == "ocr") {
+    ocr(cv_all_img_names);
+  } else if (FLAGS_type == "structure") {
+    structure(cv_all_img_names);
+  } else {
+    std::cout << "only value in ['ocr','structure'] is supported" << endl;
+  }
+}
@@ -47,7 +47,7 @@ void DBDetector::LoadModel(const std::string &model_dir) {
           {"elementwise_add_7", {1, 56, 2, 2}},
           {"nearest_interp_v2_0.tmp_0", {1, 256, 2, 2}}};
       std::map<std::string, std::vector<int>> max_input_shape = {
-          {"x", {1, 3, this->max_side_len_, this->max_side_len_}},
+          {"x", {1, 3, 1536, 1536}},
           {"conv2d_92.tmp_0", {1, 120, 400, 400}},
           {"conv2d_91.tmp_0", {1, 24, 200, 200}},
           {"conv2d_59.tmp_0", {1, 96, 400, 400}},

@@ -109,7 +109,8 @@ void DBDetector::Run(cv::Mat &img,
   img.copyTo(srcimg);

   auto preprocess_start = std::chrono::steady_clock::now();
-  this->resize_op_.Run(img, resize_img, this->max_side_len_, ratio_h, ratio_w,
+  this->resize_op_.Run(img, resize_img, this->limit_type_,
+                       this->limit_side_len_, ratio_h, ratio_w,
                        this->use_tensorrt_);

   this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
@@ -23,10 +23,10 @@ PPOCR::PPOCR() {
   if (FLAGS_det) {
     this->detector_ = new DBDetector(
         FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
-        FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_max_side_len,
-        FLAGS_det_db_thresh, FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio,
-        FLAGS_det_db_score_mode, FLAGS_use_dilation, FLAGS_use_tensorrt,
-        FLAGS_precision);
+        FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_limit_type,
+        FLAGS_limit_side_len, FLAGS_det_db_thresh, FLAGS_det_db_box_thresh,
+        FLAGS_det_db_unclip_ratio, FLAGS_det_db_score_mode, FLAGS_use_dilation,
+        FLAGS_use_tensorrt, FLAGS_precision);
   }

   if (FLAGS_cls && FLAGS_use_angle_cls) {

@@ -56,7 +56,8 @@ void PPOCR::det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
     res.box = boxes[i];
     ocr_results.push_back(res);
   }
+  // sort boxes from top to bottom, from left to right
+  Utility::sorted_boxes(ocr_results);
   times[0] += det_times[0];
   times[1] += det_times[1];
   times[2] += det_times[2];
@@ -0,0 +1,272 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <include/args.h>
#include <include/paddlestructure.h>

#include "auto_log/autolog.h"
#include <numeric>
#include <sys/stat.h>

namespace PaddleOCR {

PaddleStructure::PaddleStructure() {
  if (FLAGS_table) {
    this->recognizer_ = new StructureTableRecognizer(
        FLAGS_table_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
        FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_table_char_dict_path,
        FLAGS_use_tensorrt, FLAGS_precision, FLAGS_table_batch_num,
        FLAGS_table_max_len);
  }
};

std::vector<std::vector<StructurePredictResult>>
PaddleStructure::structure(std::vector<cv::String> cv_all_img_names,
                           bool layout, bool table) {
  std::vector<double> time_info_det = {0, 0, 0};
  std::vector<double> time_info_rec = {0, 0, 0};
  std::vector<double> time_info_cls = {0, 0, 0};
  std::vector<double> time_info_table = {0, 0, 0};

  std::vector<std::vector<StructurePredictResult>> structure_results;

  if (!Utility::PathExists(FLAGS_output) && FLAGS_det) {
    mkdir(FLAGS_output.c_str(), 0777);
  }
  for (int i = 0; i < cv_all_img_names.size(); ++i) {
    std::vector<StructurePredictResult> structure_result;
    cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
    if (!srcimg.data) {
      std::cerr << "[ERROR] image read failed! image path: "
                << cv_all_img_names[i] << endl;
      exit(1);
    }
    if (layout) {
      // layout analysis is not implemented yet
    } else {
      // treat the whole image as a single table region
      StructurePredictResult res;
      res.type = "table";
      res.box = std::vector<int>(4, 0);
      res.box[2] = srcimg.cols;
      res.box[3] = srcimg.rows;
      structure_result.push_back(res);
    }
    cv::Mat roi_img;
    for (int i = 0; i < structure_result.size(); i++) {
      // crop image
      roi_img = Utility::crop_image(srcimg, structure_result[i].box);
      if (structure_result[i].type == "table") {
        this->table(roi_img, structure_result[i], time_info_table,
                    time_info_det, time_info_rec, time_info_cls);
      }
    }
    structure_results.push_back(structure_result);
  }
  return structure_results;
};

void PaddleStructure::table(cv::Mat img,
                            StructurePredictResult &structure_result,
                            std::vector<double> &time_info_table,
                            std::vector<double> &time_info_det,
                            std::vector<double> &time_info_rec,
                            std::vector<double> &time_info_cls) {
  // predict structure
  std::vector<std::vector<std::string>> structure_html_tags;
  std::vector<float> structure_scores(1, 0);
  std::vector<std::vector<std::vector<std::vector<int>>>> structure_boxes;
  std::vector<double> structure_imes;
  std::vector<cv::Mat> img_list;
  img_list.push_back(img);
  this->recognizer_->Run(img_list, structure_html_tags, structure_scores,
                         structure_boxes, structure_imes);
  time_info_table[0] += structure_imes[0];
  time_info_table[1] += structure_imes[1];
  time_info_table[2] += structure_imes[2];

  std::vector<OCRPredictResult> ocr_result;
  std::string html;
  int expand_pixel = 3;

  for (int i = 0; i < img_list.size(); i++) {
    // det
    this->det(img_list[i], ocr_result, time_info_det);
    // crop image
    std::vector<cv::Mat> rec_img_list;
    for (int j = 0; j < ocr_result.size(); j++) {
      int x_collect[4] = {ocr_result[j].box[0][0], ocr_result[j].box[1][0],
                          ocr_result[j].box[2][0], ocr_result[j].box[3][0]};
      int y_collect[4] = {ocr_result[j].box[0][1], ocr_result[j].box[1][1],
                          ocr_result[j].box[2][1], ocr_result[j].box[3][1]};
      int left = int(*std::min_element(x_collect, x_collect + 4));
      int right = int(*std::max_element(x_collect, x_collect + 4));
      int top = int(*std::min_element(y_collect, y_collect + 4));
      int bottom = int(*std::max_element(y_collect, y_collect + 4));
      // expand the text box slightly, clipped to the image border
      std::vector<int> box{max(0, left - expand_pixel),
                           max(0, top - expand_pixel),
                           min(img_list[i].cols, right + expand_pixel),
                           min(img_list[i].rows, bottom + expand_pixel)};
      cv::Mat crop_img = Utility::crop_image(img_list[i], box);
      rec_img_list.push_back(crop_img);
    }
    // rec
    this->rec(rec_img_list, ocr_result, time_info_rec);
    // rebuild table
    html = this->rebuild_table(structure_html_tags[i], structure_boxes[i],
                               ocr_result);
    structure_result.html = html;
    structure_result.html_score = structure_scores[i];
  }
};

std::string PaddleStructure::rebuild_table(
    std::vector<std::string> structure_html_tags,
    std::vector<std::vector<std::vector<int>>> structure_boxes,
    std::vector<OCRPredictResult> &ocr_result) {
  // match text in same cell
  std::vector<std::vector<string>> matched(structure_boxes.size(),
                                           std::vector<std::string>());

  for (int i = 0; i < ocr_result.size(); i++) {
    std::vector<std::vector<float>> dis_list(structure_boxes.size(),
                                             std::vector<float>(3, 100000.0));
    for (int j = 0; j < structure_boxes.size(); j++) {
      int x_collect[4] = {ocr_result[i].box[0][0], ocr_result[i].box[1][0],
                          ocr_result[i].box[2][0], ocr_result[i].box[3][0]};
      int y_collect[4] = {ocr_result[i].box[0][1], ocr_result[i].box[1][1],
                          ocr_result[i].box[2][1], ocr_result[i].box[3][1]};
      int left = int(*std::min_element(x_collect, x_collect + 4));
      int right = int(*std::max_element(x_collect, x_collect + 4));
      int top = int(*std::min_element(y_collect, y_collect + 4));
      int bottom = int(*std::max_element(y_collect, y_collect + 4));
      std::vector<std::vector<int>> box(2, std::vector<int>(2, 0));
      box[0][0] = left - 1;
      box[0][1] = top - 1;
      box[1][0] = right + 1;
      box[1][1] = bottom + 1;

      dis_list[j][0] = this->dis(box, structure_boxes[j]);
      dis_list[j][1] = 1 - this->iou(box, structure_boxes[j]);
      dis_list[j][2] = j;
    }
    // find min dis idx
    std::sort(dis_list.begin(), dis_list.end(),
              PaddleStructure::comparison_dis);
    matched[dis_list[0][2]].push_back(ocr_result[i].text);
  }
  // get pred html
  std::string html_str = "";
  int td_tag_idx = 0;
  for (int i = 0; i < structure_html_tags.size(); i++) {
    if (structure_html_tags[i].find("</td>") != std::string::npos) {
      if (structure_html_tags[i].find("<td></td>") != std::string::npos) {
        html_str += "<td>";
      }
      if (matched[td_tag_idx].size() > 0) {
        bool b_with = false;
        if (matched[td_tag_idx][0].find("<b>") != std::string::npos &&
            matched[td_tag_idx].size() > 1) {
          b_with = true;
          html_str += "<b>";
        }
        for (int j = 0; j < matched[td_tag_idx].size(); j++) {
          std::string content = matched[td_tag_idx][j];
          if (matched[td_tag_idx].size() > 1) {
            // remove blank, <b> and </b>
            if (content.length() > 0 && content.at(0) == ' ') {
              content = content.substr(1);
            }
            if (content.length() > 2 && content.substr(0, 3) == "<b>") {
              content = content.substr(3);
            }
            if (content.length() > 4 &&
                content.substr(content.length() - 4) == "</b>") {
              content = content.substr(0, content.length() - 4);
            }
            if (content.empty()) {
              continue;
            }
            // add blank
            if (j != matched[td_tag_idx].size() - 1 &&
                content.at(content.length() - 1) != ' ') {
              content += ' ';
            }
          }
          html_str += content;
        }
        if (b_with) {
          html_str += "</b>";
        }
      }
      if (structure_html_tags[i].find("<td></td>") != std::string::npos) {
        html_str += "</td>";
      } else {
        html_str += structure_html_tags[i];
      }
      td_tag_idx += 1;
    } else {
      html_str += structure_html_tags[i];
    }
  }
  return html_str;
}

float PaddleStructure::iou(std::vector<std::vector<int>> &box1,
                           std::vector<std::vector<int>> &box2) {
  int area1 = max(0, box1[1][0] - box1[0][0]) * max(0, box1[1][1] - box1[0][1]);
  int area2 = max(0, box2[1][0] - box2[0][0]) * max(0, box2[1][1] - box2[0][1]);

  // computing the sum_area
  int sum_area = area1 + area2;

  // find each point of the intersect rectangle
  int x1 = max(box1[0][0], box2[0][0]);
  int y1 = max(box1[0][1], box2[0][1]);
  int x2 = min(box1[1][0], box2[1][0]);
  int y2 = min(box1[1][1], box2[1][1]);

  // judge if there is an intersect
  if (y1 >= y2 || x1 >= x2) {
    return 0.0;
  } else {
    int intersect = (x2 - x1) * (y2 - y1);
    return intersect / (sum_area - intersect + 0.00000001);
  }
}

float PaddleStructure::dis(std::vector<std::vector<int>> &box1,
                           std::vector<std::vector<int>> &box2) {
  int x1_1 = box1[0][0];
  int y1_1 = box1[0][1];
  int x2_1 = box1[1][0];
  int y2_1 = box1[1][1];

  int x1_2 = box2[0][0];
  int y1_2 = box2[0][1];
  int x2_2 = box2[1][0];
  int y2_2 = box2[1][1];

  float dis =
      abs(x1_2 - x1_1) + abs(y1_2 - y1_1) + abs(x2_2 - x2_1) + abs(y2_2 - y2_1);
  float dis_2 = abs(x1_2 - x1_1) + abs(y1_2 - y1_1);
  float dis_3 = abs(x2_2 - x2_1) + abs(y2_2 - y2_1);
  return dis + min(dis_2, dis_3);
}

PaddleStructure::~PaddleStructure() {
  if (this->recognizer_ != nullptr) {
    delete this->recognizer_;
  }
};

} // namespace PaddleOCR
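`rebuild_table` assigns every OCR text box to a table cell by scoring each candidate cell with an L1 corner distance (`dis`) and an overlap term (`1 - iou`), then picking the cell with the smallest `(1 - iou, dis)` pair, in that order (see `comparison_dis`). A small Python re-implementation of the scoring on toy boxes, illustrative only:

```python
def iou(b1, b2):
    # boxes are ((x1, y1), (x2, y2)) corner pairs
    (ax1, ay1), (ax2, ay2) = b1
    (bx1, by1), (bx2, by2) = b2
    x1, y1 = max(ax1, bx1), max(ay1, by1)
    x2, y2 = min(ax2, bx2), min(ay2, by2)
    if x1 >= x2 or y1 >= y2:
        return 0.0
    inter = (x2 - x1) * (y2 - y1)
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    return inter / (union + 1e-8)

def dis(b1, b2):
    # L1 distance over both corners, plus the closer corner counted again,
    # mirroring PaddleStructure::dis
    (ax1, ay1), (ax2, ay2) = b1
    (bx1, by1), (bx2, by2) = b2
    d1 = abs(bx1 - ax1) + abs(by1 - ay1)
    d2 = abs(bx2 - ax2) + abs(by2 - ay2)
    return d1 + d2 + min(d1, d2)

text_box = ((12, 10), (48, 30))
cells = [((0, 0), (50, 35)), ((50, 0), (100, 35))]
# lower (1 - iou, dis) wins; the left cell clearly contains the text box
best = min(range(len(cells)),
           key=lambda j: (1 - iou(text_box, cells[j]), dis(text_box, cells[j])))
print(best)  # 0
```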
@@ -17,8 +17,8 @@

 namespace PaddleOCR {

-void PostProcessor::GetContourArea(const std::vector<std::vector<float>> &box,
-                                   float unclip_ratio, float &distance) {
+void DBPostProcessor::GetContourArea(const std::vector<std::vector<float>> &box,
+                                     float unclip_ratio, float &distance) {
   int pts_num = 4;
   float area = 0.0f;
   float dist = 0.0f;

@@ -35,8 +35,8 @@ void PostProcessor::GetContourArea(const std::vector<std::vector<float>> &box,
   distance = area * unclip_ratio / dist;
 }

-cv::RotatedRect PostProcessor::UnClip(std::vector<std::vector<float>> box,
-                                      const float &unclip_ratio) {
+cv::RotatedRect DBPostProcessor::UnClip(std::vector<std::vector<float>> box,
+                                        const float &unclip_ratio) {
   float distance = 1.0;

   GetContourArea(box, unclip_ratio, distance);

@@ -67,7 +67,7 @@ cv::RotatedRect PostProcessor::UnClip(std::vector<std::vector<float>> box,
   return res;
 }

-float **PostProcessor::Mat2Vec(cv::Mat mat) {
+float **DBPostProcessor::Mat2Vec(cv::Mat mat) {
   auto **array = new float *[mat.rows];
   for (int i = 0; i < mat.rows; ++i)
     array[i] = new float[mat.cols];

@@ -81,7 +81,7 @@ float **PostProcessor::Mat2Vec(cv::Mat mat) {
 }

 std::vector<std::vector<int>>
-PostProcessor::OrderPointsClockwise(std::vector<std::vector<int>> pts) {
+DBPostProcessor::OrderPointsClockwise(std::vector<std::vector<int>> pts) {
   std::vector<std::vector<int>> box = pts;
   std::sort(box.begin(), box.end(), XsortInt);

@@ -99,7 +99,7 @@ PostProcessor::OrderPointsClockwise(std::vector<std::vector<int>> pts) {
   return rect;
 }

-std::vector<std::vector<float>> PostProcessor::Mat2Vector(cv::Mat mat) {
+std::vector<std::vector<float>> DBPostProcessor::Mat2Vector(cv::Mat mat) {
   std::vector<std::vector<float>> img_vec;
   std::vector<float> tmp;

@@ -113,20 +113,20 @@ std::vector<std::vector<float>> PostProcessor::Mat2Vector(cv::Mat mat) {
   return img_vec;
 }

-bool PostProcessor::XsortFp32(std::vector<float> a, std::vector<float> b) {
+bool DBPostProcessor::XsortFp32(std::vector<float> a, std::vector<float> b) {
   if (a[0] != b[0])
     return a[0] < b[0];
   return false;
 }

-bool PostProcessor::XsortInt(std::vector<int> a, std::vector<int> b) {
+bool DBPostProcessor::XsortInt(std::vector<int> a, std::vector<int> b) {
   if (a[0] != b[0])
     return a[0] < b[0];
   return false;
 }

-std::vector<std::vector<float>> PostProcessor::GetMiniBoxes(cv::RotatedRect box,
-                                                            float &ssid) {
+std::vector<std::vector<float>>
+DBPostProcessor::GetMiniBoxes(cv::RotatedRect box, float &ssid) {
   ssid = std::max(box.size.width, box.size.height);

   cv::Mat points;

@@ -160,8 +160,8 @@ std::vector<std::vector<float>> PostProcessor::GetMiniBoxes(cv::RotatedRect box,
   return array;
 }

-float PostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
-                                     cv::Mat pred) {
+float DBPostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
+                                       cv::Mat pred) {
   int width = pred.cols;
   int height = pred.rows;
   std::vector<float> box_x;

@@ -206,8 +206,8 @@ float PostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
   return score;
 }

-float PostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
-                                  cv::Mat pred) {
+float DBPostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
+                                    cv::Mat pred) {
   auto array = box_array;
   int width = pred.cols;
   int height = pred.rows;

@@ -244,7 +244,7 @@ float PostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
   return score;
 }

-std::vector<std::vector<std::vector<int>>> PostProcessor::BoxesFromBitmap(
+std::vector<std::vector<std::vector<int>>> DBPostProcessor::BoxesFromBitmap(
     const cv::Mat pred, const cv::Mat bitmap, const float &box_thresh,
     const float &det_db_unclip_ratio, const std::string &det_db_score_mode) {
   const int min_size = 3;

@@ -321,9 +321,9 @@ std::vector<std::vector<std::vector<int>>> PostProcessor::BoxesFromBitmap(
   return boxes;
 }

-std::vector<std::vector<std::vector<int>>>
-PostProcessor::FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
-                               float ratio_h, float ratio_w, cv::Mat srcimg) {
+std::vector<std::vector<std::vector<int>>> DBPostProcessor::FilterTagDetRes(
+    std::vector<std::vector<std::vector<int>>> boxes, float ratio_h,
+    float ratio_w, cv::Mat srcimg) {
   int oriimg_h = srcimg.rows;
   int oriimg_w = srcimg.cols;

@@ -352,4 +352,77 @@ PostProcessor::FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
   return root_points;
 }

+void TablePostProcessor::init(std::string label_path) {
+  // wrap the dict entries with the start ("sos") and end ("eos") tokens
+  this->label_list_ = Utility::ReadDict(label_path);
+  this->label_list_.insert(this->label_list_.begin(), this->beg);
+  this->label_list_.push_back(this->end);
+}
+
+void TablePostProcessor::Run(
+    std::vector<float> &loc_preds, std::vector<float> &structure_probs,
+    std::vector<float> &rec_scores, std::vector<int> &loc_preds_shape,
+    std::vector<int> &structure_probs_shape,
+    std::vector<std::vector<std::string>> &rec_html_tag_batch,
+    std::vector<std::vector<std::vector<std::vector<int>>>> &rec_boxes_batch,
+    std::vector<int> &width_list, std::vector<int> &height_list) {
+  for (int batch_idx = 0; batch_idx < structure_probs_shape[0]; batch_idx++) {
+    // image tags and boxes
+    std::vector<std::string> rec_html_tags;
+    std::vector<std::vector<std::vector<int>>> rec_boxes;
+
+    float score = 0.f;
+    int count = 0;
+    float char_score = 0.f;
+    int char_idx = 0;
+
+    // step
+    for (int step_idx = 0; step_idx < structure_probs_shape[1]; step_idx++) {
+      std::string html_tag;
+      std::vector<std::vector<int>> rec_box;
+      // html tag
+      int step_start_idx = (batch_idx * structure_probs_shape[1] + step_idx) *
+                           structure_probs_shape[2];
+      char_idx = int(Utility::argmax(
+          &structure_probs[step_start_idx],
+          &structure_probs[step_start_idx + structure_probs_shape[2]]));
+      char_score = float(*std::max_element(
+          &structure_probs[step_start_idx],
+          &structure_probs[step_start_idx + structure_probs_shape[2]]));
+      html_tag = this->label_list_[char_idx];
+
+      if (step_idx > 0 && html_tag == this->end) {
+        break;
+      }
+      if (html_tag == this->beg) {
+        continue;
+      }
+      count += 1;
+      score += char_score;
+      rec_html_tags.push_back(html_tag);
+      // box
+      if (html_tag == "<td>" || html_tag == "<td" || html_tag == "<td></td>") {
+        for (int point_idx = 0; point_idx < loc_preds_shape[2];
+             point_idx += 2) {
+          std::vector<int> point(2, 0);
+          step_start_idx = (batch_idx * structure_probs_shape[1] + step_idx) *
+                               loc_preds_shape[2] +
+                           point_idx;
+          point[0] = int(loc_preds[step_start_idx] * width_list[batch_idx]);
+          point[1] =
+              int(loc_preds[step_start_idx + 1] * height_list[batch_idx]);
+          rec_box.push_back(point);
+        }
+        rec_boxes.push_back(rec_box);
+      }
+    }
+    score /= count;
+    if (isnan(score) || rec_boxes.size() == 0) {
+      score = -1;
+    }
+    rec_scores.push_back(score);
+    rec_boxes_batch.push_back(rec_boxes);
+    rec_html_tag_batch.push_back(rec_html_tags);
+  }
+}
+
 } // namespace PaddleOCR
@@ -69,18 +69,28 @@ void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
 }

 void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
-                         int max_size_len, float &ratio_h, float &ratio_w,
-                         bool use_tensorrt) {
+                         string limit_type, int limit_side_len, float &ratio_h,
+                         float &ratio_w, bool use_tensorrt) {
   int w = img.cols;
   int h = img.rows;

   float ratio = 1.f;
-  int max_wh = w >= h ? w : h;
-  if (max_wh > max_size_len) {
-    if (h > w) {
-      ratio = float(max_size_len) / float(h);
-    } else {
-      ratio = float(max_size_len) / float(w);
+  if (limit_type == "min") {
+    int min_wh = min(h, w);
+    if (min_wh < limit_side_len) {
+      if (h < w) {
+        ratio = float(limit_side_len) / float(h);
+      } else {
+        ratio = float(limit_side_len) / float(w);
+      }
+    }
+  } else {
+    int max_wh = max(h, w);
+    if (max_wh > limit_side_len) {
+      if (h > w) {
+        ratio = float(limit_side_len) / float(h);
+      } else {
+        ratio = float(limit_side_len) / float(w);
+      }
     }
   }
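The rewritten `ResizeImgType0` supports two policies: with `limit_type == "max"` the long side is capped at `limit_side_len`, while with `"min"` the short side is raised to at least `limit_side_len`. A numeric sketch of the ratio computation (Python, illustrative):

```python
def resize_ratio(h, w, limit_type="max", limit_side_len=960):
    """Mirror of the ratio logic in ResizeImgType0::Run."""
    ratio = 1.0
    if limit_type == "min":
        if min(h, w) < limit_side_len:   # upscale so the short side reaches the limit
            ratio = limit_side_len / min(h, w)
    else:
        if max(h, w) > limit_side_len:   # downscale so the long side fits the limit
            ratio = limit_side_len / max(h, w)
    return ratio

print(resize_ratio(1000, 2000))            # 0.48: long side 2000 -> 960
print(resize_ratio(500, 700, "min", 736))  # 1.472: short side 500 -> 736
```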
@@ -143,4 +153,26 @@ void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
   }
 }

+void TableResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
+                         const int max_len) {
+  int w = img.cols;
+  int h = img.rows;
+
+  int max_wh = w >= h ? w : h;
+  float ratio = w >= h ? float(max_len) / float(w) : float(max_len) / float(h);
+
+  int resize_h = int(float(h) * ratio);
+  int resize_w = int(float(w) * ratio);
+
+  cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
+}
+
+void TablePadImg::Run(const cv::Mat &img, cv::Mat &resize_img,
+                      const int max_len) {
+  int w = img.cols;
+  int h = img.rows;
+  cv::copyMakeBorder(img, resize_img, 0, max_len - h, 0, max_len - w,
+                     cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
+}
+
 } // namespace PaddleOCR
@ -0,0 +1,158 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <include/structure_table.h>

namespace PaddleOCR {

void StructureTableRecognizer::Run(
    std::vector<cv::Mat> img_list,
    std::vector<std::vector<std::string>> &structure_html_tags,
    std::vector<float> &structure_scores,
    std::vector<std::vector<std::vector<std::vector<int>>>> &structure_boxes,
    std::vector<double> &times) {
  std::chrono::duration<float> preprocess_diff =
      std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
  std::chrono::duration<float> inference_diff =
      std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
  std::chrono::duration<float> postprocess_diff =
      std::chrono::steady_clock::now() - std::chrono::steady_clock::now();

  int img_num = img_list.size();
  for (int beg_img_no = 0; beg_img_no < img_num;
       beg_img_no += this->table_batch_num_) {
    // preprocess
    auto preprocess_start = std::chrono::steady_clock::now();
    int end_img_no = min(img_num, beg_img_no + this->table_batch_num_);
    int batch_num = end_img_no - beg_img_no;
    std::vector<cv::Mat> norm_img_batch;
    std::vector<int> width_list;
    std::vector<int> height_list;
    for (int ino = beg_img_no; ino < end_img_no; ino++) {
      cv::Mat srcimg;
      img_list[ino].copyTo(srcimg);
      cv::Mat resize_img;
      cv::Mat pad_img;
      this->resize_op_.Run(srcimg, resize_img, this->table_max_len_);
      this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
                              this->is_scale_);
      this->pad_op_.Run(resize_img, pad_img, this->table_max_len_);
      norm_img_batch.push_back(pad_img);
      width_list.push_back(srcimg.cols);
      height_list.push_back(srcimg.rows);
    }

    std::vector<float> input(
        batch_num * 3 * this->table_max_len_ * this->table_max_len_, 0.0f);
    this->permute_op_.Run(norm_img_batch, input.data());
    auto preprocess_end = std::chrono::steady_clock::now();
    preprocess_diff += preprocess_end - preprocess_start;
    // inference.
    auto input_names = this->predictor_->GetInputNames();
    auto input_t = this->predictor_->GetInputHandle(input_names[0]);
    input_t->Reshape(
        {batch_num, 3, this->table_max_len_, this->table_max_len_});
    auto inference_start = std::chrono::steady_clock::now();
    input_t->CopyFromCpu(input.data());
    this->predictor_->Run();
    auto output_names = this->predictor_->GetOutputNames();
    auto output_tensor0 = this->predictor_->GetOutputHandle(output_names[0]);
    auto output_tensor1 = this->predictor_->GetOutputHandle(output_names[1]);
    std::vector<int> predict_shape0 = output_tensor0->shape();
    std::vector<int> predict_shape1 = output_tensor1->shape();

    int out_num0 = std::accumulate(predict_shape0.begin(), predict_shape0.end(),
                                   1, std::multiplies<int>());
    int out_num1 = std::accumulate(predict_shape1.begin(), predict_shape1.end(),
                                   1, std::multiplies<int>());
    std::vector<float> loc_preds;
    std::vector<float> structure_probs;
    loc_preds.resize(out_num0);
    structure_probs.resize(out_num1);

    output_tensor0->CopyToCpu(loc_preds.data());
    output_tensor1->CopyToCpu(structure_probs.data());
    auto inference_end = std::chrono::steady_clock::now();
    inference_diff += inference_end - inference_start;
    // postprocess
    auto postprocess_start = std::chrono::steady_clock::now();
    std::vector<std::vector<std::string>> structure_html_tag_batch;
    std::vector<float> structure_score_batch;
    std::vector<std::vector<std::vector<std::vector<int>>>>
        structure_boxes_batch;
    this->post_processor_.Run(loc_preds, structure_probs, structure_score_batch,
                              predict_shape0, predict_shape1,
                              structure_html_tag_batch, structure_boxes_batch,
                              width_list, height_list);
    for (int m = 0; m < predict_shape0[0]; m++) {
      structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(),
                                         "<table>");
      structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(),
                                         "<body>");
      structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(),
                                         "<html>");
      structure_html_tag_batch[m].push_back("</table>");
      structure_html_tag_batch[m].push_back("</body>");
      structure_html_tag_batch[m].push_back("</html>");
      structure_html_tags.push_back(structure_html_tag_batch[m]);
      structure_scores.push_back(structure_score_batch[m]);
      structure_boxes.push_back(structure_boxes_batch[m]);
    }
    auto postprocess_end = std::chrono::steady_clock::now();
    postprocess_diff += postprocess_end - postprocess_start;
    times.push_back(double(preprocess_diff.count() * 1000));
    times.push_back(double(inference_diff.count() * 1000));
    times.push_back(double(postprocess_diff.count() * 1000));
  }
}

void StructureTableRecognizer::LoadModel(const std::string &model_dir) {
  AnalysisConfig config;
  config.SetModel(model_dir + "/inference.pdmodel",
                  model_dir + "/inference.pdiparams");

  if (this->use_gpu_) {
    config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
    if (this->use_tensorrt_) {
      auto precision = paddle_infer::Config::Precision::kFloat32;
      if (this->precision_ == "fp16") {
        precision = paddle_infer::Config::Precision::kHalf;
      }
      if (this->precision_ == "int8") {
        precision = paddle_infer::Config::Precision::kInt8;
      }
      config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false);
    }
  } else {
    config.DisableGpu();
    if (this->use_mkldnn_) {
      config.EnableMKLDNN();
    }
    config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
  }

  // false for zero copy tensor
  config.SwitchUseFeedFetchOps(false);
  // true for multiple input
  config.SwitchSpecifyInputNames(true);

  config.SwitchIrOptim(true);

  config.EnableMemoryOptim();
  config.DisableGlogInfo();

  this->predictor_ = CreatePredictor(config);
}
} // namespace PaddleOCR
@ -248,4 +248,33 @@ void Utility::print_result(const std::vector<OCRPredictResult> &ocr_result) {
   std::cout << std::endl;
 }
 }

cv::Mat Utility::crop_image(cv::Mat &img, std::vector<int> &area) {
  cv::Mat crop_im;
  int crop_x1 = std::max(0, area[0]);
  int crop_y1 = std::max(0, area[1]);
  int crop_x2 = std::min(img.cols - 1, area[2] - 1);
  int crop_y2 = std::min(img.rows - 1, area[3] - 1);

  crop_im = cv::Mat::zeros(area[3] - area[1], area[2] - area[0], 16);
  cv::Mat crop_im_window =
      crop_im(cv::Range(crop_y1 - area[1], crop_y2 + 1 - area[1]),
              cv::Range(crop_x1 - area[0], crop_x2 + 1 - area[0]));
  cv::Mat roi_img =
      img(cv::Range(crop_y1, crop_y2 + 1), cv::Range(crop_x1, crop_x2 + 1));
  crop_im_window += roi_img;
  return crop_im;
}

void Utility::sorted_boxes(std::vector<OCRPredictResult> &ocr_result) {
  std::sort(ocr_result.begin(), ocr_result.end(), Utility::comparison_box);

  for (int i = 0; i < ocr_result.size() - 1; i++) {
    if (abs(ocr_result[i + 1].box[0][1] - ocr_result[i].box[0][1]) < 10 &&
        (ocr_result[i + 1].box[0][0] < ocr_result[i].box[0][0])) {
      std::swap(ocr_result[i], ocr_result[i + 1]);
    }
  }
}

} // namespace PaddleOCR
@ -118,7 +118,7 @@ class OCRSystem(hub.Module):
                 all_results.append([])
                 continue
             starttime = time.time()
-            dt_boxes, rec_res = self.text_sys(img)
+            dt_boxes, rec_res, _ = self.text_sys(img)
             elapse = time.time() - starttime
             logger.info("Predict time: {}".format(elapse))
@ -59,7 +59,8 @@ pip3 install paddlehub==2.1.0 --upgrade -i https://mirror.baidu.com/pypi/simple
 text detection model: ./inference/ch_PP-OCRv3_det_infer/
 text recognition model: ./inference/ch_PP-OCRv3_rec_infer/
 text angle classifier: ./inference/ch_ppocr_mobile_v2.0_cls_infer/
-table structure recognition model: ./inference/en_ppocr_mobile_v2.0_table_structure_infer/
+layout analysis model: ./inference/layout_infer/
+table structure recognition model: ./inference/ch_ppstructure_mobile_v2.0_SLANet_infer/
```

**Model paths can be viewed and modified in `params.py`.** More models are available from the PaddleOCR model lists [PP-OCR](../../doc/doc_ch/models_list.md) and [PP-Structure](../../ppstructure/docs/models_list.md); you can also substitute models you have trained and converted yourself.
@ -172,7 +173,7 @@ hub serving start -c deploy/hubserving/ocr_system/config.json
 ## 3. Send prediction requests
 After the service is configured, you can use the following command to send a prediction request and get the result:

-```python tools/test_hubserving.py server_url image_path```
+```python tools/test_hubserving.py --server_url=server_url --image_dir=image_path```

 Two parameters need to be passed to the script:
 - **server_url**: the service address, in the format
@ -61,7 +61,8 @@ Before installing the service module, you need to prepare the inference model an
 text detection model: ./inference/ch_PP-OCRv3_det_infer/
 text recognition model: ./inference/ch_PP-OCRv3_rec_infer/
 text angle classifier: ./inference/ch_ppocr_mobile_v2.0_cls_infer/
-table structure recognition model: ./inference/en_ppocr_mobile_v2.0_table_structure_infer/
+layout parse model: ./inference/layout_infer/
+table structure recognition model: ./inference/ch_ppstructure_mobile_v2.0_SLANet_infer/
```

**The model path can be found and modified in `params.py`.** More models provided by PaddleOCR can be obtained from the [model library](../../doc/doc_en/models_list_en.md). You can also use models trained by yourself.
@ -177,7 +178,7 @@ hub serving start -c deploy/hubserving/ocr_system/config.json
 ## 3. Send prediction requests
 After the service starts, you can use the following command to send a prediction request to obtain the prediction result:
 ```shell
-python tools/test_hubserving.py server_url image_path
+python tools/test_hubserving.py --server_url=server_url --image_dir=image_path
 ```
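Under the hood, `tools/test_hubserving.py` essentially base64-encodes each image and POSTs it as JSON to the serving endpoint. A minimal sketch of that request, assuming the default `ocr_system` module on the default port 8866 and a sample image path:

```python
# Minimal sketch of a PaddleHub serving request; module name, port,
# and image path are the assumed defaults, not taken from this diff.
import base64
import json

import requests

with open("doc/imgs/11.jpg", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode("utf8")

data = {"images": [img_b64]}
headers = {"Content-Type": "application/json"}
url = "http://127.0.0.1:8866/predict/ocr_system"

r = requests.post(url, headers=headers, data=json.dumps(data))
print(r.json()["results"])  # hub serving wraps the prediction in "results"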

 Two parameters need to be passed to the script:
@ -119,7 +119,7 @@ class StructureSystem(hub.Module):
                 all_results.append([])
                 continue
             starttime = time.time()
-            res = self.table_sys(img)
+            res, _ = self.table_sys(img)
             elapse = time.time() - starttime
             logger.info("Predict time: {}".format(elapse))
@ -144,6 +144,6 @@ class StructureSystem(hub.Module):
 if __name__ == '__main__':
     structure_system = StructureSystem()
     structure_system._initialize()
-    image_path = ['./doc/table/1.png']
+    image_path = ['./ppstructure/docs/table/1.png']
     res = structure_system.predict(paths=image_path)
     print(res)
@ -23,8 +23,10 @@ def read_params():
     cfg = table_read_params()

     # params for layout parser model
-    cfg.layout_path_model = 'lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config'
-    cfg.layout_label_map = None
+    cfg.layout_model_dir = ''
+    cfg.layout_dict_path = './ppocr/utils/dict/layout_publaynet_dict.txt'
+    cfg.layout_score_threshold = 0.5
+    cfg.layout_nms_threshold = 0.5

     cfg.mode = 'structure'
     cfg.output = './output'
@ -118,11 +118,11 @@ class TableSystem(hub.Module):
                 all_results.append([])
                 continue
             starttime = time.time()
-            pred_html = self.table_sys(img)
+            res, _ = self.table_sys(img)
             elapse = time.time() - starttime
             logger.info("Predict time: {}".format(elapse))

-            all_results.append({'html': pred_html})
+            all_results.append({'html': res['html']})
         return all_results

     @serving
@ -138,6 +138,6 @@ class TableSystem(hub.Module):
 if __name__ == '__main__':
     table_system = TableSystem()
     table_system._initialize()
-    image_path = ['./doc/table/table.jpg']
+    image_path = ['./ppstructure/docs/table/table.jpg']
     res = table_system.predict(paths=image_path)
     print(res)
@ -79,7 +79,7 @@ python3 tools/export_model.py -c configs/rec/rec_r31_sar.yml -o Global.pretraine
 For SAR text recognition model inference, run the following command:

 ```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_sar/" --rec_image_shape="3, 48, 48, 160" --rec_char_type="ch" --rec_algorithm="SAR" --rec_char_dict_path="ppocr/utils/dict90.txt" --max_text_length=30 --use_space_char=False
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_sar/" --rec_image_shape="3, 48, 48, 160" --rec_algorithm="SAR" --rec_char_dict_path="ppocr/utils/dict90.txt" --max_text_length=30 --use_space_char=False
 ```

 <a name="4-2"></a>
@ -3,6 +3,7 @@
 - [Dataset Summary](#dataset-summary)
 - [1. PubTabNet](#1-pubtabnet)
 - [2. TAL Table Recognition Competition Dataset](#2-tal-table-recognition-competition-dataset)
+- [3. WTW Chinese scene table dataset](#3-wtw-chinese-scene-table-dataset)

 Commonly used table recognition datasets are collected here and updated continuously; contributions of new datasets are welcome.
@ -12,6 +13,7 @@
 |---|---|---|
 | PubTabNet |https://github.com/ibm-aur-nlp/PubTabNet| jsonl format, can be loaded directly with [pubtab_dataset.py](../../../ppocr/data/pubtab_dataset.py) |
 | TAL Table Recognition Competition Dataset |https://ai.100tal.com/dataset| jsonl format, can be loaded directly with [pubtab_dataset.py](../../../ppocr/data/pubtab_dataset.py) |
+| WTW Chinese scene table dataset |https://github.com/wangwen-whu/WTW-Dataset| requires conversion before it can be loaded with [pubtab_dataset.py](../../../ppocr/data/pubtab_dataset.py) |

 ## 1. PubTabNet
 - **Data Introduction**: The PubTabNet training set contains 500,000 images and the validation set contains 9,000 images. Some of the images are visualized below.
@ -31,3 +33,12 @@
 <img src="../../datasets/table_tal_demo/1.jpg" width="500">
 <img src="../../datasets/table_tal_demo/2.jpg" width="500">
 </div>

+## 3. WTW Chinese scene table dataset
+- **Data Introduction**: The WTW Chinese scene table dataset consists of two parts, table detection and table data, and contains images from two kinds of scenes: scanned and photographed.
+
+https://github.com/wangwen-whu/WTW-Dataset/blob/main/demo/20210816_210413.gif
+
+<div align="center">
+<img src="https://github.com/wangwen-whu/WTW-Dataset/blob/main/demo/20210816_210413.gif" width="500">
+</div>
@ -0,0 +1,343 @@
# Table Recognition

This article provides a full-process guide for the PaddleOCR table recognition model, covering data preparation, model training, tuning, evaluation, and prediction, with detailed descriptions of each stage:

- [1. Data Preparation](#1-data-preparation)
  - [1.1. Dataset Format](#11-dataset-format)
  - [1.2. Data Download](#12-data-download)
  - [1.3. Dataset Generation](#13-dataset-generation)
- [2. Training](#2-training)
  - [2.1. Start Training](#21-start-training)
  - [2.2. Resume Training](#22-resume-training)
  - [2.3. Training with a New Backbone](#23-training-with-a-new-backbone)
  - [2.4. Mixed Precision Training](#24-mixed-precision-training)
  - [2.5. Distributed Training](#25-distributed-training)
  - [2.6. Other Training Environments](#26-other-training-environments)
  - [2.7. Fine-tuning](#27-fine-tuning)
- [3. Model Evaluation and Prediction](#3-model-evaluation-and-prediction)
  - [3.1. Evaluation](#31-evaluation)
  - [3.2. Test table structure recognition effect](#32-test-table-structure-recognition-effect)
- [4. Model Export and Prediction](#4-model-export-and-prediction)
  - [4.1 Model Export](#41-model-export)
  - [4.2 Prediction](#42-prediction)
- [5. FAQ](#5-faq)

# 1. Data Preparation

## 1.1. Dataset Format

The PaddleOCR table recognition model dataset format is as follows:
```txt
img_label # the annotation of each image is a string produced by json.dumps()
...
img_label
```

The JSON on each line has the format:
```txt
{
   'filename': PMC5755158_010_01.png,                           # image name
   'split': 'train',                                            # whether the image belongs to the training or validation set
   'imgid': 0,                                                  # image index
   'html': {
     'structure': {'tokens': ['<thead>', '<tr>', '<td>', ...]}, # HTML structure string of the table
     'cell': [
       {
         'tokens': ['P', 'a', 'd', 'd', 'l', 'e', 'P', 'a', 'd', 'd', 'l', 'e'], # text of a single cell in the table
         'bbox': [x0, y0, x1, y1]                               # coordinates of a single cell in the table
       }
     ]
   }
}
```
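A label file in this format can be consumed line by line with the standard library; a minimal sketch of reading one annotation (the label file path follows the training config above):

```python
# Minimal sketch of reading one PubTabNet-style annotation per line.
import json

with open("train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl") as f:
    for line in f:
        info = json.loads(line)
        tokens = info["html"]["structure"]["tokens"]  # table structure tokens
        cells = info["html"]["cell"]                  # per-cell text and bbox
        print(info["filename"], len(tokens), len(cells))
        break  # just inspect the first record
```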

The default storage path for training data is `PaddleOCR/train_data`. If you already have a dataset on disk, just create a soft link to the dataset directory:

```
# linux and mac os
ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/dataset
# windows
mklink /d <path/to/paddle_ocr>/train_data/dataset <path/to/dataset>
```

## 1.2. Data Download

For public dataset downloads, refer to [table_datasets](dataset/table_datasets.md).

## 1.3. Dataset Generation

[TableGeneration](https://github.com/WenmuZhou/TableGeneration) can be used to generate scanned table images.

TableGeneration is an open-source table dataset generation tool that renders HTML strings in a browser to obtain table images. Some samples are as follows:

|Type|Sample|
|---|---|
|Simple table||
|Colored table||

# 2. Training

PaddleOCR provides training, evaluation, and prediction scripts. This section takes training the [SLANet](../../configs/table/SLANet.yml) model on the PubTabNet English dataset as an example:

## 2.1. Start Training

*If you installed the CPU version, please set the `use_gpu` field in the configuration file to false.*

```
# GPU training; single-card and multi-card training are supported
# The training log is automatically saved as train.log under "{save_model_dir}"

# single-card training (long training time, not recommended)
python3 tools/train.py -c configs/table/SLANet.yml

# multi-card training, specify the card numbers through --gpus
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/table/SLANet.yml
```

After training starts normally, you will see the following log output:

```
[2022/08/16 03:07:33] ppocr INFO: epoch: [1/400], global_step: 20, lr: 0.000100, acc: 0.000000, loss: 3.915012, structure_loss: 3.229450, loc_loss: 0.670590, avg_reader_cost: 2.63382 s, avg_batch_cost: 6.32390 s, avg_samples: 48.0, ips: 7.59025 samples/s, eta: 9 days, 2:29:27
[2022/08/16 03:08:41] ppocr INFO: epoch: [1/400], global_step: 40, lr: 0.000100, acc: 0.000000, loss: 1.750859, structure_loss: 1.082116, loc_loss: 0.652822, avg_reader_cost: 0.02533 s, avg_batch_cost: 3.37251 s, avg_samples: 48.0, ips: 14.23271 samples/s, eta: 6 days, 23:28:43
[2022/08/16 03:09:46] ppocr INFO: epoch: [1/400], global_step: 60, lr: 0.000100, acc: 0.000000, loss: 1.395154, structure_loss: 0.776803, loc_loss: 0.625030, avg_reader_cost: 0.02550 s, avg_batch_cost: 3.26261 s, avg_samples: 48.0, ips: 14.71214 samples/s, eta: 6 days, 5:11:48
```

The following information is automatically printed in the log:

| Field | Meaning |
| :----: | :------: |
| epoch | current epoch |
| global_step | current iteration count |
| lr | current learning rate |
| acc | accuracy of the current batch |
| loss | current loss value |
| structure_loss | table structure loss value |
| loc_loss | cell coordinate loss value |
| avg_reader_cost | data loading time of the current batch |
| avg_batch_cost | total time of the current batch |
| avg_samples | number of samples in the current batch |
| ips | number of images processed per second |


PaddleOCR supports alternating training and evaluation. You can modify `eval_batch_step` in `configs/table/SLANet.yml` to set the evaluation frequency; by default, evaluation runs once every 1000 iterations. During evaluation, the model with the best acc is saved as `output/SLANet/best_accuracy` by default.

If the validation set is large, evaluation will be time-consuming. It is recommended to reduce the number of evaluations, or to evaluate after training finishes.

**Tips:** The -c parameter can select any of the model configurations under the `configs/table/` path for training. For the table recognition algorithms supported by PaddleOCR, refer to the [algorithm overview](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/algorithm_overview.md#3-%E8%A1%A8%E6%A0%BC%E8%AF%86%E5%88%AB%E7%AE%97%E6%B3%95).

**Note: the configuration file used for prediction/evaluation must be consistent with the one used for training.**

## 2.2. Resume Training

If the training program is interrupted and you want to resume training from the interrupted model, specify the path of the model to load via Global.checkpoints:
```shell
python3 tools/train.py -c configs/table/SLANet.yml -o Global.checkpoints=./your/trained/model
```

**Note**: `Global.checkpoints` has higher priority than `Global.pretrained_model`; when both parameters are specified, the model pointed to by `Global.checkpoints` is loaded first. If the path given by `Global.checkpoints` is wrong, the model specified by `Global.pretrained_model` is loaded instead.

## 2.3. Training with a New Backbone

PaddleOCR divides a network into four parts, located under [ppocr/modeling](../../ppocr/modeling). Data entering the network passes through these four parts in sequence (transforms->backbones->necks->heads).

```bash
├── architectures # network assembly code
├── transforms    # image transformation module
├── backbones     # feature extraction module
├── necks         # feature enhancement module
└── heads         # output module
```
If the Backbone you want to switch to has a corresponding implementation in PaddleOCR, just modify the parameters in the `Backbone` section of the configuration yml file.

If you want to use a new Backbone, an example of replacing the backbone is as follows:

1. Create a new file under the [ppocr/modeling/backbones](../../ppocr/modeling/backbones) folder, such as my_backbone.py.
2. Add the relevant code to my_backbone.py; sample code is as follows:

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class MyBackbone(nn.Layer):
    def __init__(self, *args, **kwargs):
        super(MyBackbone, self).__init__()
        # your init code
        self.conv = nn.xxxx

    def forward(self, inputs):
        # your network forward
        y = self.conv(inputs)
        return y
```

3. Import the added `MyBackbone` module in the [ppocr/modeling/backbones/\__init\__.py](../../ppocr/modeling/backbones/__init__.py) file, then configure the Backbone in the configuration file to use it, in the following format:

```yaml
Backbone:
  name: MyBackbone
  args1: args1
```

**Note**: To replace other modules of the network, refer to the [documentation](./add_new_algorithm.md).

## 2.4. Mixed Precision Training

If you want to speed up training further, you can use [Auto Mixed Precision Training](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/01_paddle2.0_introduction/basic_concept/amp_cn.html). Taking a single machine with a single gpu as an example, the commands are as follows:

```shell
python3 tools/train.py -c configs/table/SLANet.yml \
     -o Global.pretrained_model=./pretrain_models/SLANet/best_accuracy \
     Global.use_amp=True Global.scale_loss=1024.0 Global.use_dynamic_loss_scaling=True
```

## 2.5. Distributed Training

During multi-machine multi-gpu training, use the `--ips` parameter to set the machine IP addresses and the `--gpus` parameter to set the GPU IDs:

```bash
python3 -m paddle.distributed.launch --ips="xx.xx.xx.xx,xx.xx.xx.xx" --gpus '0,1,2,3' tools/train.py -c configs/table/SLANet.yml \
     -o Global.pretrained_model=./pretrain_models/SLANet/best_accuracy
```

**Note:** (1) When using multi-machine multi-gpu training, replace the ips value in the above command with the addresses of your machines, which must be able to ping each other; (2) training must be launched separately on each machine (the command to view a machine's ip address is `ifconfig`); (3) for more on the performance advantages of distributed training, refer to the [Distributed Training Tutorial](./distributed_training.md).


## 2.6. Other Training Environments

- Windows GPU/CPU
The Windows platform differs slightly from Linux:
Windows supports only `single-card` training and prediction; specify the GPU for training with `set CUDA_VISIBLE_DEVICES=0`.
On Windows, DataLoader supports only single-process mode, so set `num_workers` to 0.

- macOS
GPU mode is not supported; set `use_gpu` to False in the configuration file. The remaining training, evaluation, and prediction commands are exactly the same as for Linux GPU.

- Linux DCU
Running on DCU devices requires setting the environment variable `export HIP_VISIBLE_DEVICES=0,1,2,3`. The remaining training, evaluation, and prediction commands are exactly the same as for Linux GPU.

## 2.7. Fine-tuning

In practice, it is recommended to load the officially provided pre-trained model and fine-tune it on your own dataset. For fine-tuning methods, refer to the [Model Fine-tuning Tutorial](./finetune.md).


# 3. Model Evaluation and Prediction

## 3.1. Evaluation

Model parameters are saved in the `Global.save_model_dir` directory by default during training. When evaluating metrics, set `Global.checkpoints` to point to the saved parameter file. The evaluation dataset can be set via the `label_file_list` field under Eval in `configs/table/SLANet.yml`.

```
# GPU evaluation, Global.checkpoints is the weight to be tested
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/table/SLANet.yml -o Global.checkpoints={path/to/weights}/best_accuracy
```

After the run completes, the acc metric of the model is printed. When evaluating the English table recognition model, you will see output like the following.
```bash
[2022/08/16 07:59:55] ppocr INFO: acc:0.7622245132160782
[2022/08/16 07:59:55] ppocr INFO: fps:30.991640622573044
```

## 3.2. Test table structure recognition effect

Using a model trained with PaddleOCR, you can quickly get predictions through the following script.

The default prediction image is stored in `infer_img`, and the trained parameter file is loaded via `-o Global.checkpoints`:

According to the `save_model_dir` and `save_epoch_step` fields set in the configuration file, the following parameters are saved:

```
output/SLANet/
├── best_accuracy.pdopt
├── best_accuracy.pdparams
├── best_accuracy.states
├── config.yml
├── latest.pdopt
├── latest.pdparams
├── latest.states
└── train.log
```
Among them, best_accuracy.* is the best model on the evaluation set; latest.* is the model of the last epoch.

```
# Predict a table image
python3 tools/infer_table.py -c configs/table/SLANet.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=ppstructure/docs/table/table.jpg
```

Input image:



Prediction result of the input image:

```
['<html>', '<body>', '<table>', '<thead>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</thead>', '<tbody>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</tbody>', '</table>', '</body>', '</html>'],[[320.0562438964844, 197.83375549316406, 350.0928955078125, 214.4309539794922], ... , [318.959228515625, 271.0166931152344, 353.7394104003906, 286.4538269042969]]
```

The visualized cell coordinates are:



# 4. Model Export and Prediction

## 4.1 Model Export

An inference model (a model saved by `paddle.jit.save`) is a frozen model that stores both the model structure and the model parameters in files, and is mostly used for deployment scenarios.
The model saved during training is a checkpoints model, which stores only the model parameters and is mostly used to resume training.
Compared with a checkpoints model, an inference model additionally stores structural information. It performs better in deployment and inference acceleration, is flexible and convenient, and is suitable for integration with real systems.

Converting a table recognition model to an inference model works the same way as for text detection and recognition, as follows:

```
# -c sets the yml configuration file of the training algorithm
# -o sets optional parameters
# Global.pretrained_model sets the path of the trained model to convert; there is no need to add the file suffix .pdmodel, .pdopt or .pdparams.
# Global.save_inference_dir sets the path where the converted model will be saved.

python3 tools/export_model.py -c configs/table/SLANet.yml -o Global.pretrained_model=./pretrain_models/SLANet/best_accuracy Global.save_inference_dir=./inference/SLANet/
```

After a successful conversion, there are three files in the directory:

```
inference/SLANet/
    ├── inference.pdiparams         # parameter file of the inference model
    ├── inference.pdiparams.info    # parameter information of the inference model, can be ignored
    └── inference.pdmodel           # program file of the inference model
```
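These three files are what the Paddle Inference runtime consumes. A minimal sketch of loading them directly from Python, mirroring the C++ `LoadModel` shown earlier in this diff (the model directory and dummy input shape are assumptions for illustration):

```python
# Minimal sketch of loading the exported files with Paddle Inference;
# the model directory is assumed for illustration.
import numpy as np
from paddle.inference import Config, create_predictor

config = Config("inference/SLANet/inference.pdmodel",
                "inference/SLANet/inference.pdiparams")
config.disable_gpu()

predictor = create_predictor(config)
input_name = predictor.get_input_names()[0]
input_handle = predictor.get_input_handle(input_name)

# SLANet expects a normalized, padded 488x488 image (see the config above);
# a zero tensor stands in for a real preprocessed image here.
input_handle.copy_from_cpu(np.zeros((1, 3, 488, 488), dtype="float32"))
predictor.run()

# two outputs: cell location predictions and structure probabilities
outputs = [predictor.get_output_handle(n).copy_to_cpu()
           for n in predictor.get_output_names()]
```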

## 4.2 Prediction

After the model is exported, use the following command to run prediction with the inference model:

```python
python3.7 table/predict_structure.py \
    --table_model_dir={path/to/inference model} \
    --table_char_dict_path=../ppocr/utils/dict/table_structure_dict_ch.txt \
    --image_dir=docs/table/table.jpg \
    --output=../output/table
```

Input image:



Prediction result of the input image:

```
['<html>', '<body>', '<table>', '<thead>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</thead>', '<tbody>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</tbody>', '</table>', '</body>', '</html>'],[[320.0562438964844, 197.83375549316406, 350.0928955078125, 214.4309539794922], ... , [318.959228515625, 271.0166931152344, 353.7394104003906, 286.4538269042969]]
```

The visualized cell coordinates are:



# 5. FAQ

Q1: Why are the prediction results inconsistent after converting the trained model to an inference model?

**A**: This is a common class of problem, mostly caused by differences between the preprocessing/postprocessing parameters used when predicting with the trained model and those used when predicting with the inference model. Compare the preprocessing and postprocessing settings in the training configuration file with those used at prediction time and check for differences.
@ -79,7 +79,7 @@ python3 tools/export_model.py -c configs/rec/rec_r31_sar.yml -o Global.pretraine
 For SAR text recognition model inference, the following commands can be executed:

 ```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_sar/" --rec_image_shape="3, 48, 48, 160" --rec_char_type="ch" --rec_algorithm="SAR" --rec_char_dict_path="ppocr/utils/dict90.txt" --max_text_length=30 --use_space_char=False
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_sar/" --rec_image_shape="3, 48, 48, 160" --rec_algorithm="SAR" --rec_char_dict_path="ppocr/utils/dict90.txt" --max_text_length=30 --use_space_char=False
 ```

 <a name="4-2"></a>
@ -3,6 +3,7 @@
 - [Dataset Summary](#dataset-summary)
 - [1. PubTabNet](#1-pubtabnet)
 - [2. TAL Table Recognition Competition Dataset](#2-tal-table-recognition-competition-dataset)
+- [3. WTW Chinese scene table dataset](#3-wtw-chinese-scene-table-dataset)

 Here are the commonly used table recognition datasets, which are being updated continuously. Contributions of new datasets are welcome.
@ -12,6 +13,7 @@ Here are the commonly used table recognition datasets, which are being updated c
 |---|---|---|
 | PubTabNet |https://github.com/ibm-aur-nlp/PubTabNet| jsonl format, which can be loaded directly with [pubtab_dataset.py](../../../ppocr/data/pubtab_dataset.py) |
 | TAL Table Recognition Competition Dataset |https://ai.100tal.com/dataset| jsonl format, which can be loaded directly with [pubtab_dataset.py](../../../ppocr/data/pubtab_dataset.py) |
+| WTW Chinese scene table dataset |https://github.com/wangwen-whu/WTW-Dataset| conversion is required before it can be loaded with [pubtab_dataset.py](../../../ppocr/data/pubtab_dataset.py) |

 ## 1. PubTabNet
 - **Data Introduction**: The training set of the PubTabNet dataset contains 500,000 images and the validation set contains 9,000 images. Some of the images are visualized below.
@ -30,3 +32,11 @@ Here are the commonly used table recognition datasets, which are being updated c
 <img src="../../datasets/table_tal_demo/1.jpg" width="500">
 <img src="../../datasets/table_tal_demo/2.jpg" width="500">
 </div>

+## 3. WTW Chinese scene table dataset
+- **Data Introduction**: The WTW Chinese scene table dataset consists of two parts, table detection and table data. The dataset contains images from two kinds of scenes: scanned and photographed.
+
+https://github.com/wangwen-whu/WTW-Dataset/blob/main/demo/20210816_210413.gif
+
+<div align="center">
+<img src="https://github.com/wangwen-whu/WTW-Dataset/blob/main/demo/20210816_210413.gif" width="500">
+</div>
@ -3,14 +3,14 @@
 **Note:** This tutorial mainly introduces the usage of PP-OCR series models, please refer to [PP-Structure Quick Start](../../ppstructure/docs/quickstart_en.md) for the quick use of document analysis related functions.

 - [1. Installation](#1-installation)
-    - [1.1 Install PaddlePaddle](#11-install-paddlepaddle)
-    - [1.2 Install PaddleOCR Whl Package](#12-install-paddleocr-whl-package)
+  - [1.1 Install PaddlePaddle](#11-install-paddlepaddle)
+  - [1.2 Install PaddleOCR Whl Package](#12-install-paddleocr-whl-package)
 - [2. Easy-to-Use](#2-easy-to-use)
-    - [2.1 Use by Command Line](#21-use-by-command-line)
-        - [2.1.1 Chinese and English Model](#211-chinese-and-english-model)
-        - [2.1.2 Multi-language Model](#212-multi-language-model)
-    - [2.2 Use by Code](#22-use-by-code)
-        - [2.2.1 Chinese & English Model and Multilingual Model](#221-chinese--english-model-and-multilingual-model)
+  - [2.1 Use by Command Line](#21-use-by-command-line)
+    - [2.1.1 Chinese and English Model](#211-chinese-and-english-model)
+    - [2.1.2 Multi-language Model](#212-multi-language-model)
+  - [2.2 Use by Code](#22-use-by-code)
+    - [2.2.1 Chinese & English Model and Multilingual Model](#221-chinese--english-model-and-multilingual-model)
 - [3. Summary](#3-summary)
@ -51,12 +51,6 @@ pip install "paddleocr>=2.0.1" # Recommend to use version 2.0.1+

 Reference: [Solve shapely installation on windows](https://stackoverflow.com/questions/44398265/install-shapely-oserror-winerror-126-the-specified-module-could-not-be-found)

-- **For layout analysis users**, run the following command to install **Layout-Parser**
-
-```bash
-pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
-```

 <a name="2-easy-to-use"></a>

 ## 2. Easy-to-Use
@ -0,0 +1,354 @@
# Table Recognition

This article provides a full-process guide for the PaddleOCR table recognition model, including data preparation, model training, tuning, evaluation, and prediction, with detailed descriptions of each stage:

- [1. Data Preparation](#1-data-preparation)
  - [1.1. Dataset Format](#11-dataset-format)
  - [1.2. Data Download](#12-data-download)
  - [1.3. Dataset Generation](#13-dataset-generation)
- [2. Training](#2-training)
  - [2.1. Start Training](#21-start-training)
  - [2.2. Resume Training](#22-resume-training)
  - [2.3. Training with New Backbone](#23-training-with-new-backbone)
  - [2.4. Mixed Precision Training](#24-mixed-precision-training)
  - [2.5. Distributed Training](#25-distributed-training)
  - [2.6. Training on Other Platforms (Windows/macOS/Linux DCU)](#26-training-on-other-platforms-windowsmacoslinux-dcu)
  - [2.7. Fine-tuning](#27-fine-tuning)
- [3. Evaluation and Test](#3-evaluation-and-test)
  - [3.1. Evaluation](#31-evaluation)
  - [3.2. Test table structure recognition effect](#32-test-table-structure-recognition-effect)
- [4. Model Export and Prediction](#4-model-export-and-prediction)
  - [4.1 Model Export](#41-model-export)
  - [4.2 Prediction](#42-prediction)
- [5. FAQ](#5-faq)

# 1. Data Preparation

## 1.1. Dataset Format

The format of the PaddleOCR table recognition model dataset is as follows:
```txt
img_label # the annotation of each image is a string produced by json.dumps()
...
img_label
```

The JSON format of each line is:
```txt
{
   'filename': PMC5755158_010_01.png,                           # image name
   'split': 'train',                                            # whether the image belongs to the training set or the validation set
   'imgid': 0,                                                  # index of image
   'html': {
     'structure': {'tokens': ['<thead>', '<tr>', '<td>', ...]}, # HTML structure string of the table
     'cell': [
       {
         'tokens': ['P', 'a', 'd', 'd', 'l', 'e', 'P', 'a', 'd', 'd', 'l', 'e'], # text in cell
         'bbox': [x0, y0, x1, y1]                               # bbox of cell
       }
     ]
   }
}
```

The default storage path for training data is `PaddleOCR/train_data`. If you already have a dataset on disk, just create a soft link to the dataset directory:

```
# linux and mac os
ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/dataset
# windows
mklink /d <path/to/paddle_ocr>/train_data/dataset <path/to/dataset>
```

## 1.2. Data Download

For public dataset downloads, refer to [table_datasets](dataset/table_datasets_en.md).

## 1.3. Dataset Generation

Use [TableGeneration](https://github.com/WenmuZhou/TableGeneration) to generate scanned table images.

TableGeneration is an open-source table dataset generation tool that renders HTML strings in a browser to obtain table images.

Some samples are as follows:

|Type|Sample|
|---|---|
|Simple Table||
|Simple Color Table||

# 2. Training

PaddleOCR provides training scripts, evaluation scripts, and prediction scripts. In this section, the [SLANet](../../configs/table/SLANet.yml) model will be used as an example:

## 2.1. Start Training

*If you installed the CPU version, please set the `use_gpu` field in the configuration file to false.*

```
# GPU training; single-card and multi-card training are supported
# The training log will be automatically saved as train.log under "{save_model_dir}"

# single-card training (long training time, not recommended)
python3 tools/train.py -c configs/table/SLANet.yml

# multi-card training, specify the card numbers through --gpus
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/table/SLANet.yml
```

After starting training normally, you will see the following log output:

```
[2022/08/16 03:07:33] ppocr INFO: epoch: [1/400], global_step: 20, lr: 0.000100, acc: 0.000000, loss: 3.915012, structure_loss: 3.229450, loc_loss: 0.670590, avg_reader_cost: 2.63382 s, avg_batch_cost: 6.32390 s, avg_samples: 48.0, ips: 7.59025 samples/s, eta: 9 days, 2:29:27
[2022/08/16 03:08:41] ppocr INFO: epoch: [1/400], global_step: 40, lr: 0.000100, acc: 0.000000, loss: 1.750859, structure_loss: 1.082116, loc_loss: 0.652822, avg_reader_cost: 0.02533 s, avg_batch_cost: 3.37251 s, avg_samples: 48.0, ips: 14.23271 samples/s, eta: 6 days, 23:28:43
[2022/08/16 03:09:46] ppocr INFO: epoch: [1/400], global_step: 60, lr: 0.000100, acc: 0.000000, loss: 1.395154, structure_loss: 0.776803, loc_loss: 0.625030, avg_reader_cost: 0.02550 s, avg_batch_cost: 3.26261 s, avg_samples: 48.0, ips: 14.71214 samples/s, eta: 6 days, 5:11:48
```

The following information is automatically printed in the log:

| Field | Meaning |
| :----: | :------: |
| epoch | current epoch |
| global_step | current iteration count |
| lr | current learning rate |
| acc | accuracy of the current batch |
| loss | current loss value |
| structure_loss | table structure loss value |
| loc_loss | cell coordinate loss value |
| avg_reader_cost | data loading time of the current batch |
| avg_batch_cost | total time of the current batch |
| avg_samples | number of samples in the current batch |
| ips | number of images processed per second |


PaddleOCR supports alternating training and evaluation. You can modify `eval_batch_step` in `configs/table/SLANet.yml` to set the evaluation frequency. By default, evaluation runs once every 1000 iters. During evaluation, the model with the best acc is saved as `output/SLANet/best_accuracy` by default.

If the validation set is large, the test will be time-consuming. It is recommended to reduce the number of evaluations, or perform the evaluation after training.

**Tips:** You can use the -c parameter to select various model configurations under the `configs/table/` path for training. For the table recognition algorithms supported by PaddleOCR, please refer to the [Table Algorithms List](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_en/algorithm_overview_en.md#3).

**Note that the configuration file for prediction/evaluation must be the same as the one used for training.**

## 2.2. Resume Training

If the training program is interrupted and you want to resume training from the interrupted model, specify the path of the model to be loaded via Global.checkpoints:

```shell
python3 tools/train.py -c configs/table/SLANet.yml -o Global.checkpoints=./your/trained/model
```
**Note**: `Global.checkpoints` has higher priority than `Global.pretrained_model`; when both parameters are specified at the same time, the model specified by `Global.checkpoints` is loaded first. If the model path specified by `Global.checkpoints` is incorrect, the model specified by `Global.pretrained_model` is loaded instead.

## 2.3. Training with New Backbone

PaddleOCR divides the network into four parts, which are under [ppocr/modeling](../../ppocr/modeling). The data entering the network passes through these four parts in sequence (transforms->backbones->necks->heads).

```bash
├── architectures # Code for building the network
├── transforms    # Image transformation module
├── backbones     # Feature extraction module
├── necks         # Feature enhancement module
└── heads         # Output module
```

If the Backbone to be replaced has a corresponding implementation in PaddleOCR, you can directly modify the parameters in the `Backbone` part of the configuration yml file.

However, if you want to use a new Backbone, an example of replacing the backbone is as follows:

1. Create a new file under the [ppocr/modeling/backbones](../../ppocr/modeling/backbones) folder, such as my_backbone.py.
2. Add code to the my_backbone.py file; sample code is as follows:

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class MyBackbone(nn.Layer):
    def __init__(self, *args, **kwargs):
        super(MyBackbone, self).__init__()
        # your init code
        self.conv = nn.xxxx

    def forward(self, inputs):
        # your network forward
        y = self.conv(inputs)
        return y
```

3. Import the added module in the [ppocr/modeling/backbones/\__init\__.py](../../ppocr/modeling/backbones/__init__.py) file.

After adding the module, you only need to configure it in the configuration file to use it, such as:

```yaml
Backbone:
  name: MyBackbone
  args1: args1
```
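The registration in step 3 is just a Python import plus a lookup of the class by its `name` from the yml config. A minimal sketch of what this looks like in `__init__.py`, conceptually; the real builder also validates the name against a support list, and `my_backbone`/`MyBackbone` are the hypothetical names from the example above:

```python
# ppocr/modeling/backbones/__init__.py (sketch; names are from the example above)
# Exposing the new class here lets the builder resolve `name: MyBackbone`
# from the yml config to the Python class.
from .my_backbone import MyBackbone

__all__ = ['build_backbone']


def build_backbone(config, model_type):
    # Conceptually: take the class name out of the config and instantiate
    # the class with the remaining keys as keyword arguments.
    module_name = config.pop('name')        # e.g. 'MyBackbone'
    module_class = eval(module_name)(**config)
    return module_class
```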

**NOTE**: More details about replacing the Backbone and other modules can be found in the [doc](add_new_algorithm_en.md).

## 2.4. Mixed Precision Training

If you want to speed up training further, you can use [Auto Mixed Precision Training](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/01_paddle2.0_introduction/basic_concept/amp_cn.html). Taking a single machine with a single gpu as an example, the commands are as follows:

```shell
python3 tools/train.py -c configs/table/SLANet.yml \
     -o Global.pretrained_model=./pretrain_models/SLANet/best_accuracy \
     Global.use_amp=True Global.scale_loss=1024.0 Global.use_dynamic_loss_scaling=True
```

## 2.5. Distributed Training

During multi-machine multi-gpu training, use the `--ips` parameter to set the machine IP addresses and the `--gpus` parameter to set the GPU IDs:

```bash
python3 -m paddle.distributed.launch --ips="xx.xx.xx.xx,xx.xx.xx.xx" --gpus '0,1,2,3' tools/train.py -c configs/table/SLANet.yml \
     -o Global.pretrained_model=./pretrain_models/SLANet/best_accuracy
```

**Note:** (1) When using multi-machine multi-gpu training, you need to replace the ips value in the above command with the addresses of your machines, and the machines must be able to ping each other. (2) Training needs to be launched separately on each machine. The command to view a machine's ip address is `ifconfig`. (3) For more details about the distributed training speedup ratio, please refer to the [Distributed Training Tutorial](./distributed_training_en.md).

## 2.6. Training on Other Platforms (Windows/macOS/Linux DCU)

- Windows GPU/CPU
The Windows platform is slightly different from the Linux platform:
The Windows platform only supports `single gpu` training and inference; specify the GPU for training with `set CUDA_VISIBLE_DEVICES=0`.
On the Windows platform, DataLoader only supports single-process mode, so you need to set `num_workers` to 0.

- macOS
GPU mode is not supported; you need to set `use_gpu` to False in the configuration file. The rest of the training, evaluation, and prediction commands are exactly the same as for Linux GPU.

- Linux DCU
Running on a DCU device requires setting the environment variable `export HIP_VISIBLE_DEVICES=0,1,2,3`. The rest of the training, evaluation, and prediction commands are exactly the same as for Linux GPU.

## 2.7. Fine-tuning

In actual use, it is recommended to load the officially provided pre-trained model and fine-tune it on your own dataset. For the fine-tuning method of the table recognition model, please refer to: [Model fine-tuning tutorial](./finetune.md).


# 3. Evaluation and Test

## 3.1. Evaluation

The model parameters during training are saved in the `Global.save_model_dir` directory by default. When evaluating metrics, you need to set `Global.checkpoints` to point to the saved parameter file. The evaluation dataset can be modified via the `label_file_list` setting under Eval in `configs/table/SLANet.yml`.

```
# GPU evaluation, Global.checkpoints is the weight to be tested
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/table/SLANet.yml -o Global.checkpoints={path/to/weights}/best_accuracy
```

After the operation is completed, the acc metric of the model will be output. If you evaluate the English table recognition model, you will see the following output.

```bash
[2022/08/16 07:59:55] ppocr INFO: acc:0.7622245132160782
[2022/08/16 07:59:55] ppocr INFO: fps:30.991640622573044
```

## 3.2. Test table structure recognition effect

Using a model trained by PaddleOCR, you can quickly get predictions through the following script.

The default prediction picture is stored in `infer_img`, and the trained weights are specified via `-o Global.checkpoints`:

According to the `save_model_dir` and `save_epoch_step` fields set in the configuration file, the following parameters will be saved:

```
output/SLANet/
├── best_accuracy.pdopt
├── best_accuracy.pdparams
├── best_accuracy.states
├── config.yml
├── latest.pdopt
├── latest.pdparams
├── latest.states
└── train.log
```
Among them, best_accuracy.* is the best model on the evaluation set; latest.* is the model of the last epoch.

```
# Predict table image
python3 tools/infer_table.py -c configs/table/SLANet.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=ppstructure/docs/table/table.jpg
```

Input image:



Get the prediction result of the input image:

```
['<html>', '<body>', '<table>', '<thead>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</thead>', '<tbody>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</tbody>', '</table>', '</body>', '</html>'],[[320.0562438964844, 197.83375549316406, 350.0928955078125, 214.4309539794922], ... , [318.959228515625, 271.0166931152344, 353.7394104003906, 286.4538269042969]]
```
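The returned structure tokens already form a complete HTML document, so rendering them in a browser only requires joining the list into a string. A minimal sketch, with `tokens` and the output path as illustrative names rather than PaddleOCR API:

```python
# Minimal sketch: join the predicted structure tokens into renderable HTML.
# `tokens` and the output path are illustrative, not part of the PaddleOCR API.
tokens = ['<html>', '<body>', '<table>', '<tr>', '<td></td>', '</tr>',
          '</table>', '</body>', '</html>']

html = ''.join(tokens)
with open('output/table/pred.html', 'w') as f:
    f.write(html)
```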

The cell coordinates are visualized as



# 4. Model Export and Prediction

## 4.1 Model Export

An inference model (a model saved by `paddle.jit.save`) is a frozen model that stores both the model structure and the model parameters in files, and is mostly used for deployment scenarios.
The model saved during training is a checkpoints model, which stores only the model parameters and is mostly used to resume training.
Compared with the checkpoints model, the inference model additionally stores the structural information of the model. It performs better in deployment and inference acceleration, is flexible and convenient, and is suitable for integration with real systems.

The way to convert the table recognition model to an inference model is the same as for text detection and recognition, as follows:

```
# -c Set the training algorithm yml configuration file
# -o Set optional parameters
# Global.pretrained_model sets the path of the trained model to convert, without the file suffix .pdmodel, .pdopt or .pdparams.
# Global.save_inference_dir sets the path where the converted model will be saved.

python3 tools/export_model.py -c configs/table/SLANet.yml -o Global.pretrained_model=./pretrain_models/SLANet/best_accuracy Global.save_inference_dir=./inference/SLANet/
```

After the conversion is successful, there are three files in the model save directory:

```
inference/SLANet/
    ├── inference.pdiparams         # The parameter file of the inference model
    ├── inference.pdiparams.info    # The parameter information of the inference model, which can be ignored
    └── inference.pdmodel           # The program file of the inference model
```

## 4.2 Prediction

After the model is exported, use the following command to run prediction with the inference model:

```python
python3.7 table/predict_structure.py \
    --table_model_dir={path/to/inference model} \
    --table_char_dict_path=../ppocr/utils/dict/table_structure_dict_ch.txt \
    --image_dir=docs/table/table.jpg \
    --output=../output/table
```

Input image:



Get the prediction result of the input image:

```
['<html>', '<body>', '<table>', '<thead>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</thead>', '<tbody>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</tbody>', '</table>', '</body>', '</html>'],[[320.0562438964844, 197.83375549316406, 350.0928955078125, 214.4309539794922], ... , [318.959228515625, 271.0166931152344, 353.7394104003906, 286.4538269042969]]
```

The cell coordinates are visualized as



# 5. FAQ

Q1: Why is the prediction result inconsistent after the trained model is converted to an inference model?

**A**: This kind of problem is common and is mostly caused by inconsistencies between the preprocessing/postprocessing parameters used when predicting with the trained model and those used when predicting with the inference model. You can compare the preprocessing, postprocessing, and prediction settings in the configuration file used for training with those used at prediction time.
paddleocr.py
@ -47,14 +47,14 @@ __all__ = [
 ]

 SUPPORT_DET_MODEL = ['DB']
-VERSION = '2.5.0.3'
+VERSION = '2.6'
 SUPPORT_REC_MODEL = ['CRNN', 'SVTR_LCNet']
 BASE_DIR = os.path.expanduser("~/.paddleocr/")

 DEFAULT_OCR_MODEL_VERSION = 'PP-OCRv3'
 SUPPORT_OCR_MODEL_VERSION = ['PP-OCR', 'PP-OCRv2', 'PP-OCRv3']
-DEFAULT_STRUCTURE_MODEL_VERSION = 'PP-STRUCTURE'
-SUPPORT_STRUCTURE_MODEL_VERSION = ['PP-STRUCTURE']
+DEFAULT_STRUCTURE_MODEL_VERSION = 'PP-Structurev2'
+SUPPORT_STRUCTURE_MODEL_VERSION = ['PP-Structure', 'PP-Structurev2']
 MODEL_URLS = {
     'OCR': {
         'PP-OCRv3': {
@ -263,7 +263,7 @@ MODEL_URLS = {
         }
     },
     'STRUCTURE': {
-        'PP-STRUCTURE': {
+        'PP-Structure': {
             'table': {
                 'en': {
                     'url':
@ -271,6 +271,27 @@ MODEL_URLS = {
                     'dict_path': 'ppocr/utils/dict/table_structure_dict.txt'
                 }
             }
         },
+        'PP-Structurev2': {
+            'table': {
+                'en': {
+                    'url':
+                    'https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_infer.tar',
+                    'dict_path': 'ppocr/utils/dict/table_structure_dict.txt'
+                },
+                'ch': {
+                    'url':
+                    'https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar',
+                    'dict_path': 'ppocr/utils/dict/table_structure_dict_ch.txt'
+                }
+            },
+            'layout': {
+                'ch': {
+                    'url':
+                    'https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_layout_infer.tar',
+                    'dict_path': 'ppocr/utils/dict/layout_publaynet_dict.txt'
+                }
+            }
+        }
     }
 }
|
@ -298,12 +319,15 @@ def parse_args(mMain=True):
|
|||
"--structure_version",
|
||||
type=str,
|
||||
choices=SUPPORT_STRUCTURE_MODEL_VERSION,
|
||||
default='PP-STRUCTURE',
|
||||
default='PP-Structurev2',
|
||||
help='Model version, the current model support list is as follows:'
|
||||
' 1. STRUCTURE Support en table structure model.')
|
||||
' 1. PP-Structure Support en table structure model.'
|
||||
' 2. PP-Structurev2 Support ch and en table structure model.')
|
||||
|
||||
for action in parser._actions:
|
||||
if action.dest in ['rec_char_dict_path', 'table_char_dict_path']:
|
||||
if action.dest in [
|
||||
'rec_char_dict_path', 'table_char_dict_path', 'layout_dict_path'
|
||||
]:
|
||||
action.default = None
|
||||
if mMain:
|
||||
return parser.parse_args()
|
||||
|
@ -477,7 +501,7 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
if isinstance(img, np.ndarray) and len(img.shape) == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
if det and rec:
|
||||
dt_boxes, rec_res = self.__call__(img, cls)
|
||||
dt_boxes, rec_res, _ = self.__call__(img, cls)
|
||||
return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
|
||||
elif det and not rec:
|
||||
dt_boxes, elapse = self.text_detector(img)
|
||||
|
@ -506,6 +530,12 @@ class PPStructure(StructureSystem):
|
|||
if not params.show_log:
|
||||
logger.setLevel(logging.INFO)
|
||||
lang, det_lang = parse_lang(params.lang)
|
||||
if lang == 'ch':
|
||||
table_lang = 'ch'
|
||||
else:
|
||||
table_lang = 'en'
|
||||
if params.structure_version == 'PP-Structure':
|
||||
params.merge_no_span_structure = False
|
||||
|
||||
# init model dir
|
||||
det_model_config = get_model_config('OCR', params.ocr_version, 'det',
|
||||
|
@ -520,14 +550,20 @@ class PPStructure(StructureSystem):
|
|||
params.rec_model_dir,
|
||||
os.path.join(BASE_DIR, 'whl', 'rec', lang), rec_model_config['url'])
|
||||
table_model_config = get_model_config(
|
||||
'STRUCTURE', params.structure_version, 'table', 'en')
|
||||
'STRUCTURE', params.structure_version, 'table', table_lang)
|
||||
params.table_model_dir, table_url = confirm_model_dir_url(
|
||||
params.table_model_dir,
|
||||
os.path.join(BASE_DIR, 'whl', 'table'), table_model_config['url'])
|
||||
layout_model_config = get_model_config(
|
||||
'STRUCTURE', params.structure_version, 'layout', 'ch')
|
||||
params.layout_model_dir, layout_url = confirm_model_dir_url(
|
||||
params.layout_model_dir,
|
||||
os.path.join(BASE_DIR, 'whl', 'layout'), layout_model_config['url'])
|
||||
# download model
|
||||
maybe_download(params.det_model_dir, det_url)
|
||||
maybe_download(params.rec_model_dir, rec_url)
|
||||
maybe_download(params.table_model_dir, table_url)
|
||||
maybe_download(params.layout_model_dir, layout_url)
|
||||
|
||||
if params.rec_char_dict_path is None:
|
||||
params.rec_char_dict_path = str(
|
||||
|
@ -535,7 +571,9 @@ class PPStructure(StructureSystem):
|
|||
if params.table_char_dict_path is None:
|
||||
params.table_char_dict_path = str(
|
||||
Path(__file__).parent / table_model_config['dict_path'])
|
||||
|
||||
if params.layout_dict_path is None:
|
||||
params.layout_dict_path = str(
|
||||
Path(__file__).parent / layout_model_config['dict_path'])
|
||||
logger.debug(params)
|
||||
super().__init__(params)
|
||||
|
||||
|
@ -557,7 +595,7 @@ class PPStructure(StructureSystem):
|
|||
if isinstance(img, np.ndarray) and len(img.shape) == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
|
||||
res = super().__call__(img, return_ocr_result_in_table)
|
||||
res, _ = super().__call__(img, return_ocr_result_in_table)
|
||||
return res
|
||||
|
||||
|
||||
|
|
|
@@ -575,7 +575,7 @@ class TableLabelEncode(AttnLabelEncode):
                  replace_empty_cell_token=False,
                  merge_no_span_structure=False,
                  learn_empty_box=False,
-                 point_num=2,
+                 loc_reg_num=4,
                  **kwargs):
         self.max_text_len = max_text_length
         self.lower = False
@@ -590,6 +590,12 @@ class TableLabelEncode(AttnLabelEncode):
             line = line.decode('utf-8').strip("\n").strip("\r\n")
             dict_character.append(line)
 
+        if self.merge_no_span_structure:
+            if "<td></td>" not in dict_character:
+                dict_character.append("<td></td>")
+            if "<td>" in dict_character:
+                dict_character.remove("<td>")
+
         dict_character = self.add_special_char(dict_character)
         self.dict = {}
         for i, char in enumerate(dict_character):
@@ -597,7 +603,7 @@ class TableLabelEncode(AttnLabelEncode):
         self.idx2char = {v: k for k, v in self.dict.items()}
 
         self.character = dict_character
-        self.point_num = point_num
+        self.loc_reg_num = loc_reg_num
         self.pad_idx = self.dict[self.beg_str]
         self.start_idx = self.dict[self.beg_str]
         self.end_idx = self.dict[self.end_str]
@@ -653,7 +659,7 @@ class TableLabelEncode(AttnLabelEncode):
 
         # encode box
         bboxes = np.zeros(
-            (self._max_text_len, self.point_num * 2), dtype=np.float32)
+            (self._max_text_len, self.loc_reg_num), dtype=np.float32)
         bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32)
 
         bbox_idx = 0
@@ -718,11 +724,11 @@ class TableMasterLabelEncode(TableLabelEncode):
                  replace_empty_cell_token=False,
                  merge_no_span_structure=False,
                  learn_empty_box=False,
-                 point_num=2,
+                 loc_reg_num=4,
                  **kwargs):
         super(TableMasterLabelEncode, self).__init__(
             max_text_length, character_dict_path, replace_empty_cell_token,
-            merge_no_span_structure, learn_empty_box, point_num, **kwargs)
+            merge_no_span_structure, learn_empty_box, loc_reg_num, **kwargs)
         self.pad_idx = self.dict[self.pad_str]
         self.unknown_idx = self.dict[self.unknown_str]
@@ -743,27 +749,35 @@ class TableMasterLabelEncode(TableLabelEncode):
 
 
 class TableBoxEncode(object):
-    def __init__(self, use_xywh=False, **kwargs):
-        self.use_xywh = use_xywh
+    def __init__(self, in_box_format='xyxy', out_box_format='xyxy', **kwargs):
+        assert out_box_format in ['xywh', 'xyxy', 'xyxyxyxy']
+        self.in_box_format = in_box_format
+        self.out_box_format = out_box_format
 
     def __call__(self, data):
         img_height, img_width = data['image'].shape[:2]
         bboxes = data['bboxes']
-        if self.use_xywh and bboxes.shape[1] == 4:
-            bboxes = self.xyxy2xywh(bboxes)
+        if self.in_box_format != self.out_box_format:
+            if self.out_box_format == 'xywh':
+                if self.in_box_format == 'xyxyxyxy':
+                    bboxes = self.xyxyxyxy2xywh(bboxes)
+                elif self.in_box_format == 'xyxy':
+                    bboxes = self.xyxy2xywh(bboxes)
+
         bboxes[:, 0::2] /= img_width
         bboxes[:, 1::2] /= img_height
         data['bboxes'] = bboxes
         return data
 
+    def xyxyxyxy2xywh(self, bboxes):
+        new_bboxes = np.zeros([len(bboxes), 4])
+        new_bboxes[:, 0] = bboxes[:, 0::2].min(axis=-1)  # x1
+        new_bboxes[:, 1] = bboxes[:, 1::2].min(axis=-1)  # y1
+        new_bboxes[:, 2] = bboxes[:, 0::2].max(axis=-1) - new_bboxes[:, 0]  # w
+        new_bboxes[:, 3] = bboxes[:, 1::2].max(axis=-1) - new_bboxes[:, 1]  # h
+        return new_bboxes
+
     def xyxy2xywh(self, bboxes):
         """
         Convert coord (x1, y1, x2, y2) to (x, y, w, h),
         where (x1, y1) is top-left and (x2, y2) is bottom-right;
         (x, y) is the bbox center and (w, h) its width and height.
         :param bboxes: (x1, y1, x2, y2)
         :return:
         """
         new_bboxes = np.empty_like(bboxes)
         new_bboxes[:, 0] = (bboxes[:, 0] + bboxes[:, 2]) / 2  # x center
         new_bboxes[:, 1] = (bboxes[:, 1] + bboxes[:, 3]) / 2  # y center

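As a side note on the box formats handled above, the xyxy-to-xywh conversion that `xyxy2xywh` implements can be checked with a few lines of plain numpy. This is only an illustration; the sample box is made up.

```python
import numpy as np

# One made-up box in (x1, y1, x2, y2) form: top-left (10, 20), bottom-right (50, 80).
bboxes = np.array([[10., 20., 50., 80.]])

xywh = np.empty_like(bboxes)
xywh[:, 0] = (bboxes[:, 0] + bboxes[:, 2]) / 2  # x center -> 30
xywh[:, 1] = (bboxes[:, 1] + bboxes[:, 3]) / 2  # y center -> 50
xywh[:, 2] = bboxes[:, 2] - bboxes[:, 0]        # width    -> 40
xywh[:, 3] = bboxes[:, 3] - bboxes[:, 1]        # height   -> 60
print(xywh)  # [[30. 50. 40. 60.]]
```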
@@ -206,7 +206,7 @@ class ResizeTableImage(object):
         data['bboxes'] = data['bboxes'] * ratio
         data['image'] = resize_img
         data['src_img'] = img
-        data['shape'] = np.array([resize_h, resize_w, ratio, ratio])
+        data['shape'] = np.array([height, width, ratio, ratio])
         data['max_len'] = self.max_len
         return data

@@ -52,7 +52,7 @@ from .basic_loss import DistanceLoss
 from .combined_loss import CombinedLoss
 
 # table loss
-from .table_att_loss import TableAttentionLoss
+from .table_att_loss import TableAttentionLoss, SLALoss
 from .table_master_loss import TableMasterLoss
 # vqa token loss
 from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
@@ -67,7 +67,8 @@ def build_loss(config):
         'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
         'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
         'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
-        'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss'
+        'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
+        'SLALoss'
     ]
     config = copy.deepcopy(config)
     module_name = config.pop('name')

@@ -22,65 +22,11 @@ from paddle.nn import functional as F
 
 
 class TableAttentionLoss(nn.Layer):
-    def __init__(self,
-                 structure_weight,
-                 loc_weight,
-                 use_giou=False,
-                 giou_weight=1.0,
-                 **kwargs):
+    def __init__(self, structure_weight, loc_weight, **kwargs):
         super(TableAttentionLoss, self).__init__()
         self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
         self.structure_weight = structure_weight
         self.loc_weight = loc_weight
-        self.use_giou = use_giou
-        self.giou_weight = giou_weight
-
-    def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
-        '''
-        :param preds: [[x1, y1, x2, y2], ...]
-        :param bbox: [[x1, y1, x2, y2], ...]
-        :return: loss
-        '''
-        ix1 = paddle.maximum(preds[:, 0], bbox[:, 0])
-        iy1 = paddle.maximum(preds[:, 1], bbox[:, 1])
-        ix2 = paddle.minimum(preds[:, 2], bbox[:, 2])
-        iy2 = paddle.minimum(preds[:, 3], bbox[:, 3])
-
-        iw = paddle.clip(ix2 - ix1 + 1e-3, 0., 1e10)
-        ih = paddle.clip(iy2 - iy1 + 1e-3, 0., 1e10)
-
-        # overlap
-        inters = iw * ih
-
-        # union
-        uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (
-            preds[:, 3] - preds[:, 1] + 1e-3) + (bbox[:, 2] - bbox[:, 0] + 1e-3
-                                                 ) * (bbox[:, 3] - bbox[:, 1] +
-                                                      1e-3) - inters + eps
-
-        # ious
-        ious = inters / uni
-
-        ex1 = paddle.minimum(preds[:, 0], bbox[:, 0])
-        ey1 = paddle.minimum(preds[:, 1], bbox[:, 1])
-        ex2 = paddle.maximum(preds[:, 2], bbox[:, 2])
-        ey2 = paddle.maximum(preds[:, 3], bbox[:, 3])
-        ew = paddle.clip(ex2 - ex1 + 1e-3, 0., 1e10)
-        eh = paddle.clip(ey2 - ey1 + 1e-3, 0., 1e10)
-
-        # enclose area
-        enclose = ew * eh + eps
-        giou = ious - (enclose - uni) / enclose
-
-        loss = 1 - giou
-
-        if reduction == 'mean':
-            loss = paddle.mean(loss)
-        elif reduction == 'sum':
-            loss = paddle.sum(loss)
-        else:
-            raise NotImplementedError
-        return loss
-
     def forward(self, predicts, batch):
         structure_probs = predicts['structure_probs']
@@ -100,20 +46,48 @@ class TableAttentionLoss(nn.Layer):
         loc_targets_mask = loc_targets_mask[:, 1:, :]
         loc_loss = F.mse_loss(loc_preds * loc_targets_mask,
                               loc_targets) * self.loc_weight
-        if self.use_giou:
-            loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask,
-                                           loc_targets) * self.giou_weight
-            total_loss = structure_loss + loc_loss + loc_loss_giou
-            return {
-                'loss': total_loss,
-                "structure_loss": structure_loss,
-                "loc_loss": loc_loss,
-                "loc_loss_giou": loc_loss_giou
-            }
-        else:
-            total_loss = structure_loss + loc_loss
-            return {
-                'loss': total_loss,
-                "structure_loss": structure_loss,
-                "loc_loss": loc_loss
-            }
+        total_loss = structure_loss + loc_loss
+        return {
+            'loss': total_loss,
+            "structure_loss": structure_loss,
+            "loc_loss": loc_loss
+        }
+
+
+class SLALoss(nn.Layer):
+    def __init__(self, structure_weight, loc_weight, loc_loss='mse', **kwargs):
+        super(SLALoss, self).__init__()
+        self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='mean')
+        self.structure_weight = structure_weight
+        self.loc_weight = loc_weight
+        self.loc_loss = loc_loss
+        self.eps = 1e-12
+
+    def forward(self, predicts, batch):
+        structure_probs = predicts['structure_probs']
+        structure_targets = batch[1].astype("int64")
+        structure_targets = structure_targets[:, 1:]
+
+        structure_loss = self.loss_func(structure_probs, structure_targets)
+
+        structure_loss = paddle.mean(structure_loss) * self.structure_weight
+
+        loc_preds = predicts['loc_preds']
+        loc_targets = batch[2].astype("float32")
+        loc_targets_mask = batch[3].astype("float32")
+        loc_targets = loc_targets[:, 1:, :]
+        loc_targets_mask = loc_targets_mask[:, 1:, :]
+
+        loc_loss = F.smooth_l1_loss(
+            loc_preds * loc_targets_mask,
+            loc_targets * loc_targets_mask,
+            reduction='sum') * self.loc_weight
+
+        loc_loss = loc_loss / (loc_targets_mask.sum() + self.eps)
+        total_loss = structure_loss + loc_loss
+        return {
+            'loss': total_loss,
+            "structure_loss": structure_loss,
+            "loc_loss": loc_loss
+        }

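To see how the new SLALoss normalizes its location term, the toy check below reproduces the masked smooth-L1 computation in isolation. It is a sketch only: the shapes are made up, and the weight value simply mirrors the `loc_weight: 2.0` in the SLANet config.

```python
import paddle
import paddle.nn.functional as F

loc_preds = paddle.rand([2, 5, 4])    # batch, decode steps, (x1, y1, x2, y2)
loc_targets = paddle.rand([2, 5, 4])
mask = paddle.ones([2, 5, 1])         # 1 = real cell, 0 = padded step
loc_weight = 2.0                      # as in the SLANet config above

loss = F.smooth_l1_loss(
    loc_preds * mask, loc_targets * mask, reduction='sum') * loc_weight
loss = loss / (mask.sum() + 1e-12)    # average over unmasked entries only
print(loss)
```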
@@ -16,9 +16,14 @@ from ppocr.metrics.det_metric import DetMetric
 
 
 class TableStructureMetric(object):
-    def __init__(self, main_indicator='acc', eps=1e-6, **kwargs):
+    def __init__(self,
+                 main_indicator='acc',
+                 eps=1e-6,
+                 del_thead_tbody=False,
+                 **kwargs):
         self.main_indicator = main_indicator
         self.eps = eps
+        self.del_thead_tbody = del_thead_tbody
         self.reset()
 
     def __call__(self, pred_label, batch=None, *args, **kwargs):
@@ -31,6 +36,13 @@ class TableStructureMetric(object):
                 gt_structure_batch_list):
             pred_str = ''.join(pred)
             target_str = ''.join(target)
+            if self.del_thead_tbody:
+                pred_str = pred_str.replace('<thead>', '').replace(
+                    '</thead>', '').replace('<tbody>', '').replace('</tbody>',
+                                                                   '')
+                target_str = target_str.replace('<thead>', '').replace(
+                    '</thead>', '').replace('<tbody>', '').replace('</tbody>',
+                                                                   '')
             if pred_str == target_str:
                 correct_num += 1
             all_num += 1
@@ -59,7 +71,8 @@ class TableMetric(object):
     def __init__(self,
                  main_indicator='acc',
                  compute_bbox_metric=False,
-                 point_num=2,
+                 box_format='xyxy',
+                 del_thead_tbody=False,
                  **kwargs):
         """
 
@@ -67,10 +80,11 @@ class TableMetric(object):
         @param main_indicator: main indicator used to save the best model
         @param kwargs:
         """
-        self.structure_metric = TableStructureMetric()
+        self.structure_metric = TableStructureMetric(
+            del_thead_tbody=del_thead_tbody)
         self.bbox_metric = DetMetric() if compute_bbox_metric else None
         self.main_indicator = main_indicator
-        self.point_num = point_num
+        self.box_format = box_format
         self.reset()
 
     def __call__(self, pred_label, batch=None, *args, **kwargs):
@@ -129,10 +143,14 @@ class TableMetric(object):
             self.bbox_metric.reset()
 
     def format_box(self, box):
-        if self.point_num == 2:
+        if self.box_format == 'xyxy':
             x1, y1, x2, y2 = box
             box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
-        elif self.point_num == 4:
+        elif self.box_format == 'xywh':
+            x, y, w, h = box
+            x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
+            box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
+        elif self.box_format == 'xyxyxyxy':
             x1, y1, x2, y2, x3, y3, x4, y4 = box
             box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
         return box

@@ -21,7 +21,10 @@ def build_backbone(config, model_type):
         from .det_resnet import ResNet
         from .det_resnet_vd import ResNet_vd
         from .det_resnet_vd_sast import ResNet_SAST
-        support_dict = ["MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST"]
+        from .det_pp_lcnet import PPLCNet
+        support_dict = [
+            "MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet"
+        ]
     if model_type == "table":
         from .table_master_resnet import TableResNetExtra
         support_dict.append('TableResNetExtra')

@@ -0,0 +1,271 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import, division, print_function

import os
import paddle
import paddle.nn as nn
from paddle import ParamAttr
from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Dropout, Linear
from paddle.regularizer import L2Decay
from paddle.nn.initializer import KaimingNormal
from paddle.utils.download import get_path_from_url

MODEL_URLS = {
    "PPLCNet_x0.25":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_25_pretrained.pdparams",
    "PPLCNet_x0.35":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_35_pretrained.pdparams",
    "PPLCNet_x0.5":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_5_pretrained.pdparams",
    "PPLCNet_x0.75":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_75_pretrained.pdparams",
    "PPLCNet_x1.0":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x1_0_pretrained.pdparams",
    "PPLCNet_x1.5":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x1_5_pretrained.pdparams",
    "PPLCNet_x2.0":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_0_pretrained.pdparams",
    "PPLCNet_x2.5":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_5_pretrained.pdparams"
}

MODEL_STAGES_PATTERN = {
    "PPLCNet": ["blocks2", "blocks3", "blocks4", "blocks5", "blocks6"]
}

__all__ = list(MODEL_URLS.keys())

# Each element(list) represents a depthwise block, which is composed of k, in_c, out_c, s, use_se.
# k: kernel_size
# in_c: input channel number in depthwise block
# out_c: output channel number in depthwise block
# s: stride in depthwise block
# use_se: whether to use SE block

NET_CONFIG = {
    "blocks2":
    # k, in_c, out_c, s, use_se
    [[3, 16, 32, 1, False]],
    "blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
    "blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
    "blocks5":
    [[3, 128, 256, 2, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False],
     [5, 256, 256, 1, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False]],
    "blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True]]
}


def make_divisible(v, divisor=8, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 num_channels,
                 filter_size,
                 num_filters,
                 stride,
                 num_groups=1):
        super().__init__()

        self.conv = Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=num_groups,
            weight_attr=ParamAttr(initializer=KaimingNormal()),
            bias_attr=False)

        self.bn = BatchNorm(
            num_filters,
            param_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
        self.hardswish = nn.Hardswish()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.hardswish(x)
        return x


class DepthwiseSeparable(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 stride,
                 dw_size=3,
                 use_se=False):
        super().__init__()
        self.use_se = use_se
        self.dw_conv = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_channels,
            filter_size=dw_size,
            stride=stride,
            num_groups=num_channels)
        if use_se:
            self.se = SEModule(num_channels)
        self.pw_conv = ConvBNLayer(
            num_channels=num_channels,
            filter_size=1,
            num_filters=num_filters,
            stride=1)

    def forward(self, x):
        x = self.dw_conv(x)
        if self.use_se:
            x = self.se(x)
        x = self.pw_conv(x)
        return x


class SEModule(nn.Layer):
    def __init__(self, channel, reduction=4):
        super().__init__()
        self.avg_pool = AdaptiveAvgPool2D(1)
        self.conv1 = Conv2D(
            in_channels=channel,
            out_channels=channel // reduction,
            kernel_size=1,
            stride=1,
            padding=0)
        self.relu = nn.ReLU()
        self.conv2 = Conv2D(
            in_channels=channel // reduction,
            out_channels=channel,
            kernel_size=1,
            stride=1,
            padding=0)
        self.hardsigmoid = nn.Hardsigmoid()

    def forward(self, x):
        identity = x
        x = self.avg_pool(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.hardsigmoid(x)
        x = paddle.multiply(x=identity, y=x)
        return x


class PPLCNet(nn.Layer):
    def __init__(self,
                 in_channels=3,
                 scale=1.0,
                 pretrained=False,
                 use_ssld=False):
        super().__init__()
        self.out_channels = [
            int(NET_CONFIG["blocks3"][-1][2] * scale),
            int(NET_CONFIG["blocks4"][-1][2] * scale),
            int(NET_CONFIG["blocks5"][-1][2] * scale),
            int(NET_CONFIG["blocks6"][-1][2] * scale)
        ]
        self.scale = scale

        self.conv1 = ConvBNLayer(
            num_channels=in_channels,
            filter_size=3,
            num_filters=make_divisible(16 * scale),
            stride=2)

        self.blocks2 = nn.Sequential(* [
            DepthwiseSeparable(
                num_channels=make_divisible(in_c * scale),
                num_filters=make_divisible(out_c * scale),
                dw_size=k,
                stride=s,
                use_se=se)
            for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks2"])
        ])

        self.blocks3 = nn.Sequential(* [
            DepthwiseSeparable(
                num_channels=make_divisible(in_c * scale),
                num_filters=make_divisible(out_c * scale),
                dw_size=k,
                stride=s,
                use_se=se)
            for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks3"])
        ])

        self.blocks4 = nn.Sequential(* [
            DepthwiseSeparable(
                num_channels=make_divisible(in_c * scale),
                num_filters=make_divisible(out_c * scale),
                dw_size=k,
                stride=s,
                use_se=se)
            for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks4"])
        ])

        self.blocks5 = nn.Sequential(* [
            DepthwiseSeparable(
                num_channels=make_divisible(in_c * scale),
                num_filters=make_divisible(out_c * scale),
                dw_size=k,
                stride=s,
                use_se=se)
            for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks5"])
        ])

        self.blocks6 = nn.Sequential(* [
            DepthwiseSeparable(
                num_channels=make_divisible(in_c * scale),
                num_filters=make_divisible(out_c * scale),
                dw_size=k,
                stride=s,
                use_se=se)
            for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks6"])
        ])

        if pretrained:
            self._load_pretrained(
                MODEL_URLS['PPLCNet_x{}'.format(scale)], use_ssld=use_ssld)

    def forward(self, x):
        outs = []
        x = self.conv1(x)
        x = self.blocks2(x)
        x = self.blocks3(x)
        outs.append(x)
        x = self.blocks4(x)
        outs.append(x)
        x = self.blocks5(x)
        outs.append(x)
        x = self.blocks6(x)
        outs.append(x)
        return outs

    def _load_pretrained(self, pretrained_url, use_ssld=False):
        if use_ssld:
            pretrained_url = pretrained_url.replace("_pretrained",
                                                    "_ssld_pretrained")
        print(pretrained_url)
        local_weight_path = get_path_from_url(
            pretrained_url, os.path.expanduser("~/.paddleclas/weights"))
        param_state_dict = paddle.load(local_weight_path)
        self.set_dict(param_state_dict)
        return

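A quick sanity check of the backbone's multi-scale outputs. This is a sketch: the import path follows this PR's `from .det_pp_lcnet import PPLCNet`, and the 488x488 input size is taken from the SLANet config shown earlier.

```python
import paddle
from ppocr.modeling.backbones.det_pp_lcnet import PPLCNet

model = PPLCNet(in_channels=3, scale=1.0, pretrained=False)
x = paddle.rand([1, 3, 488, 488])
for feat in model(x):
    print(feat.shape)  # strides 4/8/16/32, channels [64, 128, 256, 512]
```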
@@ -44,7 +44,7 @@ def build_head(config):
     #kie head
     from .kie_sdmgr_head import SDMGRHead
 
-    from .table_att_head import TableAttentionHead
+    from .table_att_head import TableAttentionHead, SLAHead
     from .table_master_head import TableMasterHead
 
     support_dict = [
@@ -52,7 +52,7 @@ def build_head(config):
         'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
         'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
         'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
-        'VLHead', 'RobustScannerHead'
+        'VLHead', 'SLAHead', 'RobustScannerHead'
     ]
 
     #table head

@@ -18,12 +18,26 @@ from __future__ import print_function
 
 import paddle
 import paddle.nn as nn
+from paddle import ParamAttr
 import paddle.nn.functional as F
 import numpy as np
 
+from .rec_att_head import AttentionGRUCell
+
+
+def get_para_bias_attr(l2_decay, k):
+    if l2_decay > 0:
+        regularizer = paddle.regularizer.L2Decay(l2_decay)
+        stdv = 1.0 / math.sqrt(k * 1.0)
+        initializer = nn.initializer.Uniform(-stdv, stdv)
+    else:
+        regularizer = None
+        initializer = None
+    weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
+    bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
+    return [weight_attr, bias_attr]
+
+
 class TableAttentionHead(nn.Layer):
     def __init__(self,
                  in_channels,
@@ -32,7 +46,7 @@ class TableAttentionHead(nn.Layer):
                  in_max_len=488,
                  max_text_length=800,
                  out_channels=30,
-                 point_num=2,
+                 loc_reg_num=4,
                  **kwargs):
         super(TableAttentionHead, self).__init__()
         self.input_size = in_channels[-1]
@@ -56,7 +70,7 @@ class TableAttentionHead(nn.Layer):
         else:
             self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
         self.loc_generator = nn.Linear(self.input_size + hidden_size,
-                                       point_num * 2)
+                                       loc_reg_num)
 
     def _char_to_onehot(self, input_char, onehot_dim):
         input_ont_hot = F.one_hot(input_char, onehot_dim)
@@ -129,3 +143,121 @@ class TableAttentionHead(nn.Layer):
         loc_preds = self.loc_generator(loc_concat)
         loc_preds = F.sigmoid(loc_preds)
         return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
+
+
+class SLAHead(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 hidden_size,
+                 out_channels=30,
+                 max_text_length=500,
+                 loc_reg_num=4,
+                 fc_decay=0.0,
+                 **kwargs):
+        """
+        @param in_channels: input shape
+        @param hidden_size: hidden_size for RNN and Embedding
+        @param out_channels: num_classes to rec
+        @param max_text_length: max text pred
+        """
+        super().__init__()
+        in_channels = in_channels[-1]
+        self.hidden_size = hidden_size
+        self.max_text_length = max_text_length
+        self.emb = self._char_to_onehot
+        self.num_embeddings = out_channels
+
+        # structure
+        self.structure_attention_cell = AttentionGRUCell(
+            in_channels, hidden_size, self.num_embeddings)
+        weight_attr, bias_attr = get_para_bias_attr(
+            l2_decay=fc_decay, k=hidden_size)
+        weight_attr1_1, bias_attr1_1 = get_para_bias_attr(
+            l2_decay=fc_decay, k=hidden_size)
+        weight_attr1_2, bias_attr1_2 = get_para_bias_attr(
+            l2_decay=fc_decay, k=hidden_size)
+        self.structure_generator = nn.Sequential(
+            nn.Linear(
+                self.hidden_size,
+                self.hidden_size,
+                weight_attr=weight_attr1_2,
+                bias_attr=bias_attr1_2),
+            nn.Linear(
+                hidden_size,
+                out_channels,
+                weight_attr=weight_attr,
+                bias_attr=bias_attr))
+        # loc
+        weight_attr1, bias_attr1 = get_para_bias_attr(
+            l2_decay=fc_decay, k=self.hidden_size)
+        weight_attr2, bias_attr2 = get_para_bias_attr(
+            l2_decay=fc_decay, k=self.hidden_size)
+        self.loc_generator = nn.Sequential(
+            nn.Linear(
+                self.hidden_size,
+                self.hidden_size,
+                weight_attr=weight_attr1,
+                bias_attr=bias_attr1),
+            nn.Linear(
+                self.hidden_size,
+                loc_reg_num,
+                weight_attr=weight_attr2,
+                bias_attr=bias_attr2),
+            nn.Sigmoid())
+
+    def forward(self, inputs, targets=None):
+        fea = inputs[-1]
+        batch_size = fea.shape[0]
+        # reshape
+        fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], -1])
+        fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
+
+        hidden = paddle.zeros((batch_size, self.hidden_size))
+        structure_preds = []
+        loc_preds = []
+        if self.training and targets is not None:
+            structure = targets[0]
+            for i in range(self.max_text_length + 1):
+                hidden, structure_step, loc_step = self._decode(structure[:, i],
+                                                                fea, hidden)
+                structure_preds.append(structure_step)
+                loc_preds.append(loc_step)
+        else:
+            pre_chars = paddle.zeros(shape=[batch_size], dtype="int32")
+            max_text_length = paddle.to_tensor(self.max_text_length)
+            # for export
+            loc_step, structure_step = None, None
+            for i in range(max_text_length + 1):
+                hidden, structure_step, loc_step = self._decode(pre_chars, fea,
+                                                                hidden)
+                pre_chars = structure_step.argmax(axis=1, dtype="int32")
+                structure_preds.append(structure_step)
+                loc_preds.append(loc_step)
+        structure_preds = paddle.stack(structure_preds, axis=1)
+        loc_preds = paddle.stack(loc_preds, axis=1)
+        if not self.training:
+            structure_preds = F.softmax(structure_preds)
+        return {'structure_probs': structure_preds, 'loc_preds': loc_preds}
+
+    def _decode(self, pre_chars, features, hidden):
+        """
+        Predict table label and coordinates for each step
+        @param pre_chars: Table label in previous step
+        @param features:
+        @param hidden: hidden status in previous step
+        @return:
+        """
+        emb_feature = self.emb(pre_chars)
+        # output shape is b * self.hidden_size
+        (output, hidden), alpha = self.structure_attention_cell(
+            hidden, features, emb_feature)
+
+        # structure
+        structure_step = self.structure_generator(output)
+        # loc
+        loc_step = self.loc_generator(output)
+        return hidden, structure_step, loc_step
+
+    def _char_to_onehot(self, input_char):
+        input_ont_hot = F.one_hot(input_char, self.num_embeddings)
+        return input_ont_hot

@@ -37,7 +37,7 @@ class TableMasterHead(nn.Layer):
                  d_ff=2048,
                  dropout=0,
                  max_text_length=500,
-                 point_num=2,
+                 loc_reg_num=4,
                  **kwargs):
         super(TableMasterHead, self).__init__()
         hidden_size = in_channels[-1]
@@ -50,7 +50,7 @@ class TableMasterHead(nn.Layer):
         self.cls_fc = nn.Linear(hidden_size, out_channels)
         self.bbox_fc = nn.Sequential(
             # nn.Linear(hidden_size, hidden_size),
-            nn.Linear(hidden_size, point_num * 2),
+            nn.Linear(hidden_size, loc_reg_num),
             nn.Sigmoid())
         self.norm = nn.LayerNorm(hidden_size)
         self.embedding = Embeddings(d_model=hidden_size, vocab=out_channels)
@@ -59,7 +59,7 @@ class TableMasterHead(nn.Layer):
         self.SOS = out_channels - 3
         self.PAD = out_channels - 1
         self.out_channels = out_channels
-        self.point_num = point_num
+        self.loc_reg_num = loc_reg_num
         self.max_text_length = max_text_length
 
     def make_mask(self, tgt):
@@ -105,7 +105,7 @@ class TableMasterHead(nn.Layer):
         output = paddle.zeros(
             [input.shape[0], self.max_text_length + 1, self.out_channels])
         bbox_output = paddle.zeros(
-            [input.shape[0], self.max_text_length + 1, self.point_num * 2])
+            [input.shape[0], self.max_text_length + 1, self.loc_reg_num])
         max_text_length = paddle.to_tensor(self.max_text_length)
         for i in range(max_text_length + 1):
             target_mask = self.make_mask(input)

@@ -25,9 +25,10 @@ def build_neck(config):
     from .fpn import FPN
     from .fce_fpn import FCEFPN
     from .pren_fpn import PRENFPN
+    from .csp_pan import CSPPAN
     support_dict = [
         'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
-        'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN'
+        'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN'
     ]
 
     module_name = config.pop('name')

@@ -0,0 +1,324 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# The code is based on:
# https://github.com/PaddlePaddle/PaddleDetection/blob/release%2F2.3/ppdet/modeling/necks/csp_pan.py

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr

__all__ = ['CSPPAN']


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 in_channel=96,
                 out_channel=96,
                 kernel_size=3,
                 stride=1,
                 groups=1,
                 act='leaky_relu'):
        super(ConvBNLayer, self).__init__()
        initializer = nn.initializer.KaimingUniform()
        self.act = act
        assert self.act in ['leaky_relu', "hard_swish"]
        self.conv = nn.Conv2D(
            in_channels=in_channel,
            out_channels=out_channel,
            kernel_size=kernel_size,
            groups=groups,
            padding=(kernel_size - 1) // 2,
            stride=stride,
            weight_attr=ParamAttr(initializer=initializer),
            bias_attr=False)
        self.bn = nn.BatchNorm2D(out_channel)

    def forward(self, x):
        x = self.bn(self.conv(x))
        if self.act == "leaky_relu":
            x = F.leaky_relu(x)
        elif self.act == "hard_swish":
            x = F.hardswish(x)
        return x


class DPModule(nn.Layer):
    """
    Depth-wise and point-wise module.
    Args:
        in_channel (int): The input channels of this Module.
        out_channel (int): The output channels of this Module.
        kernel_size (int): The conv2d kernel size of this Module.
        stride (int): The conv2d's stride of this Module.
        act (str): The activation function of this Module,
                   Now support `leaky_relu` and `hard_swish`.
    """

    def __init__(self,
                 in_channel=96,
                 out_channel=96,
                 kernel_size=3,
                 stride=1,
                 act='leaky_relu'):
        super(DPModule, self).__init__()
        initializer = nn.initializer.KaimingUniform()
        self.act = act
        self.dwconv = nn.Conv2D(
            in_channels=in_channel,
            out_channels=out_channel,
            kernel_size=kernel_size,
            groups=out_channel,
            padding=(kernel_size - 1) // 2,
            stride=stride,
            weight_attr=ParamAttr(initializer=initializer),
            bias_attr=False)
        self.bn1 = nn.BatchNorm2D(out_channel)
        self.pwconv = nn.Conv2D(
            in_channels=out_channel,
            out_channels=out_channel,
            kernel_size=1,
            groups=1,
            padding=0,
            weight_attr=ParamAttr(initializer=initializer),
            bias_attr=False)
        self.bn2 = nn.BatchNorm2D(out_channel)

    def act_func(self, x):
        if self.act == "leaky_relu":
            x = F.leaky_relu(x)
        elif self.act == "hard_swish":
            x = F.hardswish(x)
        return x

    def forward(self, x):
        x = self.act_func(self.bn1(self.dwconv(x)))
        x = self.act_func(self.bn2(self.pwconv(x)))
        return x


class DarknetBottleneck(nn.Layer):
    """The basic bottleneck block used in Darknet.
    Each Block consists of two ConvModules and the input is added to the
    final output. Each ConvModule is composed of Conv, BN, and act.
    The first convLayer has filter size of 1x1 and the second one has the
    filter size of 3x3.
    Args:
        in_channels (int): The input channels of this Module.
        out_channels (int): The output channels of this Module.
        expansion (float): Ratio to adjust the number of hidden channels.
            Default: 0.5
        add_identity (bool): Whether to add identity to the out.
            Default: True
        use_depthwise (bool): Whether to use depthwise separable convolution.
            Default: False
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 expansion=0.5,
                 add_identity=True,
                 use_depthwise=False,
                 act="leaky_relu"):
        super(DarknetBottleneck, self).__init__()
        hidden_channels = int(out_channels * expansion)
        conv_func = DPModule if use_depthwise else ConvBNLayer
        self.conv1 = ConvBNLayer(
            in_channel=in_channels,
            out_channel=hidden_channels,
            kernel_size=1,
            act=act)
        self.conv2 = conv_func(
            in_channel=hidden_channels,
            out_channel=out_channels,
            kernel_size=kernel_size,
            stride=1,
            act=act)
        self.add_identity = \
            add_identity and in_channels == out_channels

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)

        if self.add_identity:
            return out + identity
        else:
            return out


class CSPLayer(nn.Layer):
    """Cross Stage Partial Layer.
    Args:
        in_channels (int): The input channels of the CSP layer.
        out_channels (int): The output channels of the CSP layer.
        expand_ratio (float): Ratio to adjust the number of channels of the
            hidden layer. Default: 0.5
        num_blocks (int): Number of blocks. Default: 1
        add_identity (bool): Whether to add identity in blocks.
            Default: True
        use_depthwise (bool): Whether to use depthwise separable convolution in
            blocks. Default: False
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 expand_ratio=0.5,
                 num_blocks=1,
                 add_identity=True,
                 use_depthwise=False,
                 act="leaky_relu"):
        super().__init__()
        mid_channels = int(out_channels * expand_ratio)
        self.main_conv = ConvBNLayer(in_channels, mid_channels, 1, act=act)
        self.short_conv = ConvBNLayer(in_channels, mid_channels, 1, act=act)
        self.final_conv = ConvBNLayer(
            2 * mid_channels, out_channels, 1, act=act)

        self.blocks = nn.Sequential(* [
            DarknetBottleneck(
                mid_channels,
                mid_channels,
                kernel_size,
                1.0,
                add_identity,
                use_depthwise,
                act=act) for _ in range(num_blocks)
        ])

    def forward(self, x):
        x_short = self.short_conv(x)

        x_main = self.main_conv(x)
        x_main = self.blocks(x_main)

        x_final = paddle.concat((x_main, x_short), axis=1)
        return self.final_conv(x_final)


class Channel_T(nn.Layer):
    def __init__(self,
                 in_channels=[116, 232, 464],
                 out_channels=96,
                 act="leaky_relu"):
        super(Channel_T, self).__init__()
        self.convs = nn.LayerList()
        for i in range(len(in_channels)):
            self.convs.append(
                ConvBNLayer(
                    in_channels[i], out_channels, 1, act=act))

    def forward(self, x):
        outs = [self.convs[i](x[i]) for i in range(len(x))]
        return outs


class CSPPAN(nn.Layer):
    """Path Aggregation Network with CSP module.
    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        kernel_size (int): The conv2d kernel size of this Module.
        num_csp_blocks (int): Number of bottlenecks in CSPLayer. Default: 1
        use_depthwise (bool): Whether to use depthwise separable convolution in
            blocks. Default: True
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=5,
                 num_csp_blocks=1,
                 use_depthwise=True,
                 act='hard_swish'):
        super(CSPPAN, self).__init__()
        self.in_channels = in_channels
        self.out_channels = [out_channels] * len(in_channels)
        conv_func = DPModule if use_depthwise else ConvBNLayer

        self.conv_t = Channel_T(in_channels, out_channels, act=act)

        # build top-down blocks
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.top_down_blocks = nn.LayerList()
        for idx in range(len(in_channels) - 1, 0, -1):
            self.top_down_blocks.append(
                CSPLayer(
                    out_channels * 2,
                    out_channels,
                    kernel_size=kernel_size,
                    num_blocks=num_csp_blocks,
                    add_identity=False,
                    use_depthwise=use_depthwise,
                    act=act))

        # build bottom-up blocks
        self.downsamples = nn.LayerList()
        self.bottom_up_blocks = nn.LayerList()
        for idx in range(len(in_channels) - 1):
            self.downsamples.append(
                conv_func(
                    out_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=2,
                    act=act))
            self.bottom_up_blocks.append(
                CSPLayer(
                    out_channels * 2,
                    out_channels,
                    kernel_size=kernel_size,
                    num_blocks=num_csp_blocks,
                    add_identity=False,
                    use_depthwise=use_depthwise,
                    act=act))

    def forward(self, inputs):
        """
        Args:
            inputs (tuple[Tensor]): input features.
        Returns:
            tuple[Tensor]: CSPPAN features.
        """
        assert len(inputs) == len(self.in_channels)
        inputs = self.conv_t(inputs)

        # top-down path
        inner_outs = [inputs[-1]]
        for idx in range(len(self.in_channels) - 1, 0, -1):
            feat_high = inner_outs[0]
            feat_low = inputs[idx - 1]
            upsample_feat = F.upsample(
                feat_high, size=paddle.shape(feat_low)[2:4], mode="nearest")

            inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx](
                paddle.concat([upsample_feat, feat_low], 1))
            inner_outs.insert(0, inner_out)

        # bottom-up path
        outs = [inner_outs[0]]
        for idx in range(len(self.in_channels) - 1):
            feat_low = outs[-1]
            feat_height = inner_outs[idx + 1]
            downsample_feat = self.downsamples[idx](feat_low)
            out = self.bottom_up_blocks[idx](paddle.concat(
                [downsample_feat, feat_height], 1))
            outs.append(out)

        return tuple(outs)

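A matching check that the CSPPAN neck accepts the backbone's four feature maps and emits `out_channels` at every scale. This is a sketch: the import path follows this PR's `from .csp_pan import CSPPAN`, and the spatial sizes are made up but halve per level, as the downsampling path expects.

```python
import paddle
from ppocr.modeling.necks.csp_pan import CSPPAN

neck = CSPPAN(in_channels=[64, 128, 256, 512], out_channels=96)
feats = [
    paddle.rand([1, c, s, s])
    for c, s in zip([64, 128, 256, 512], [64, 32, 16, 8])
]
for out in neck(feats):
    print(out.shape)  # each output map has 96 channels
```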
@@ -21,9 +21,29 @@ from .rec_postprocess import AttnLabelDecode
 class TableLabelDecode(AttnLabelDecode):
     """ """
 
-    def __init__(self, character_dict_path, **kwargs):
-        super(TableLabelDecode, self).__init__(character_dict_path)
-        self.td_token = ['<td>', '<td', '<eb></eb>', '<td></td>']
+    def __init__(self,
+                 character_dict_path,
+                 merge_no_span_structure=False,
+                 **kwargs):
+        dict_character = []
+        with open(character_dict_path, "rb") as fin:
+            lines = fin.readlines()
+            for line in lines:
+                line = line.decode('utf-8').strip("\n").strip("\r\n")
+                dict_character.append(line)
+
+        if merge_no_span_structure:
+            if "<td></td>" not in dict_character:
+                dict_character.append("<td></td>")
+            if "<td>" in dict_character:
+                dict_character.remove("<td>")
+
+        dict_character = self.add_special_char(dict_character)
+        self.dict = {}
+        for i, char in enumerate(dict_character):
+            self.dict[char] = i
+        self.character = dict_character
+        self.td_token = ['<td>', '<td', '<td></td>']
 
     def __call__(self, preds, batch=None):
         structure_probs = preds['structure_probs']
@@ -114,18 +134,21 @@ class TableLabelDecode(AttnLabelDecode):
 
     def _bbox_decode(self, bbox, shape):
         h, w, ratio_h, ratio_w, pad_h, pad_w = shape
-        src_h = h / ratio_h
-        src_w = w / ratio_w
-        bbox[0::2] *= src_w
-        bbox[1::2] *= src_h
+        bbox[0::2] *= w
+        bbox[1::2] *= h
         return bbox
 
 
 class TableMasterLabelDecode(TableLabelDecode):
     """ """
 
-    def __init__(self, character_dict_path, box_shape='ori', **kwargs):
-        super(TableMasterLabelDecode, self).__init__(character_dict_path)
+    def __init__(self,
+                 character_dict_path,
+                 box_shape='ori',
+                 merge_no_span_structure=True,
+                 **kwargs):
+        super(TableMasterLabelDecode, self).__init__(character_dict_path,
+                                                     merge_no_span_structure)
         self.box_shape = box_shape
         assert box_shape in [
             'ori', 'pad'
@@ -157,4 +180,7 @@ class TableMasterLabelDecode(TableLabelDecode):
         bbox[1::2] *= h
         bbox[0::2] /= ratio_w
         bbox[1::2] /= ratio_h
+        x, y, w, h = bbox
+        x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
+        bbox = np.array([x1, y1, x2, y2])
         return bbox

@@ -0,0 +1,48 @@
<thead>
</thead>
<tbody>
</tbody>
<tr>
</tr>
<td>
<td
>
</td>
colspan="2"
colspan="3"
colspan="4"
colspan="5"
colspan="6"
colspan="7"
colspan="8"
colspan="9"
colspan="10"
colspan="11"
colspan="12"
colspan="13"
colspan="14"
colspan="15"
colspan="16"
colspan="17"
colspan="18"
colspan="19"
colspan="20"
rowspan="2"
rowspan="3"
rowspan="4"
rowspan="5"
rowspan="6"
rowspan="7"
rowspan="8"
rowspan="9"
rowspan="10"
rowspan="11"
rowspan="12"
rowspan="13"
rowspan="14"
rowspan="15"
rowspan="16"
rowspan="17"
rowspan="18"
rowspan="19"
rowspan="20"

@@ -41,9 +41,7 @@ def download_with_progressbar(url, save_path):
 
 def maybe_download(model_storage_directory, url):
     # using custom model
-    tar_file_name_list = [
-        'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
-    ]
+    tar_file_name_list = ['.pdiparams', '.pdiparams.info', '.pdmodel']
     if not os.path.exists(
             os.path.join(model_storage_directory, 'inference.pdiparams')
     ) or not os.path.exists(
@@ -57,8 +55,8 @@ def maybe_download(model_storage_directory, url):
             for member in tarObj.getmembers():
                 filename = None
                 for tar_file_name in tar_file_name_list:
-                    if tar_file_name in member.name:
-                        filename = tar_file_name
+                    if member.name.endswith(tar_file_name):
+                        filename = 'inference' + tar_file_name
                 if filename is None:
                     continue
                 file = tarObj.extractfile(member)

@@ -113,14 +113,11 @@ def draw_re_results(image,
     return np.array(img_new)
 
 
-def draw_rectangle(img_path, boxes, use_xywh=False):
+def draw_rectangle(img_path, boxes):
     boxes = np.array(boxes)
     img = cv2.imread(img_path)
     img_show = img.copy()
     for box in boxes.astype(int):
-        if use_xywh:
-            x, y, w, h = box
-            x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
-        else:
-            x1, y1, x2, y2 = box
+        x1, y1, x2, y2 = box
         cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
     return img_show

@@ -106,9 +106,9 @@ PP-Structure Series Model List (Updating)
 
 |model name|description|model size|download|
 | --- | --- | --- | --- |
-|ch_PP-OCRv2_det_slim|[New] Slim quantization with distillation lightweight model, supporting Chinese, English, multilingual text detection| 3M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar)|
-|ch_PP-OCRv2_rec_slim|[New] Slim quantization with distillation lightweight model, supporting Chinese, English, multilingual text recognition| 9M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_train.tar) |
-|en_ppocr_mobile_v2.0_table_structure|Table structure prediction of English table scene trained on PubLayNet dataset| 18.6M |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
+|ch_PP-OCRv3_det_slim|[New] Slim quantization with distillation lightweight model, supporting Chinese, English, multilingual text detection| 1.1M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_distill_train.tar)|
+|ch_PP-OCRv3_rec_slim |[New] Slim quantization with distillation lightweight model, supporting Chinese, English text recognition| 4.9M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_train.tar) |
+|ch_ppstructure_mobile_v2.0_SLANet|Chinese table recognition model trained on PubTabNet dataset based on SLANet|9.3M|[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_train.tar) |
 
 ### 7.3 DOC-VQA model

@@ -120,9 +120,10 @@ PP-Structure series model list (updating)
 
 |model name|description|model size|download|
 | --- | --- | --- | --- |
-|ch_PP-OCRv2_det_slim|[New] Slim quantization + distillation ultra-lightweight model, supporting Chinese, English and multilingual text detection| 3M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar)|
-|ch_PP-OCRv2_rec_slim|[New] Slim quantized ultra-lightweight model, supporting Chinese, English and digit recognition| 9M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_train.tar) |
-|en_ppocr_mobile_v2.0_table_structure|Table structure prediction for English table scenes, trained on the PubLayNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
+|ch_PP-OCRv3_det_slim|[New] Slim quantization + distillation ultra-lightweight model, supporting Chinese, English and multilingual text detection| 1.1M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_distill_train.tar)|
+|ch_PP-OCRv3_rec_slim |[New] Slim quantized ultra-lightweight model, supporting Chinese, English and digit recognition| 4.9M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_train.tar) |
+|ch_ppstructure_mobile_v2.0_SLANet|Chinese table recognition model trained on the PubTabNet dataset, based on SLANet|9.3M|[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_train.tar) |
 
+<a name="73"></a>
 ### 7.3 DocVQA models

Binary files not shown (4 images added: 78 KiB, 369 KiB, 772 KiB, 467 KiB).

@@ -1,8 +1,7 @@
 - [Quick installation](#快速安装)
   - [1. PaddlePaddle and PaddleOCR](#1-paddlepaddle-和-paddleocr)
   - [2. Install other dependencies](#2-安装其他依赖)
-  - [2.1 Layout-Parser for layout analysis](#21-版面分析所需--layout-parser)
-  - [2.2 Dependencies for VQA](#22--vqa所需依赖)
+  - [2.1 Dependencies for VQA](#21--vqa所需依赖)
 
 # Quick installation
 
@@ -12,14 +11,7 @@
 
 ## 2. Install other dependencies
 
-### 2.1 Layout-Parser for layout analysis
-
-Layout-Parser can be installed with the following command:
-
-```bash
-pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
-```
-### 2.2 Dependencies for VQA
+### 2.1 Dependencies for VQA
 * paddleocr
 
 ```bash

@@ -0,0 +1,30 @@
# Quick installation

- [1. PaddlePaddle and PaddleOCR](#1)
- [2. Install other dependencies](#2)
  - [2.1 VQA](#21)


<a name="1"></a>
## 1. PaddlePaddle and PaddleOCR

Please refer to [PaddleOCR installation documentation](../../doc/doc_en/installation_en.md)

<a name="2"></a>
## 2. Install other dependencies

<a name="21"></a>
### 2.1 VQA

* paddleocr

```bash
pip3 install paddleocr
```

* PaddleNLP
```bash
git clone https://github.com/PaddlePaddle/PaddleNLP -b develop
cd PaddleNLP
pip3 install -e .
```

@ -34,7 +34,9 @@
|
|||
|
||||
|模型名称|模型简介|推理模型大小|下载地址|
|
||||
| --- | --- | --- | --- |
|
||||
|en_ppocr_mobile_v2.0_table_structure|PubTabNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
|
||||
|en_ppocr_mobile_v2.0_table_structure|基于TableRec-RARE在PubTabNet数据集上训练的英文表格识别模型|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
|
||||
|en_ppstructure_mobile_v2.0_SLANet|English table recognition model trained on the PubTabNet dataset, based on SLANet|9M|[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_train.tar) |
|
||||
|ch_ppstructure_mobile_v2.0_SLANet|Chinese table recognition model trained on the PubTabNet dataset, based on SLANet|9.3M|[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_train.tar) |
|
||||
|
||||
<a name="3"></a>
|
||||
|
||||
|
|
|
@ -35,7 +35,9 @@ If you need to use other OCR models, you can download the model in [PP-OCR model
|
|||
|
||||
|model| description |inference model size|download|
|
||||
| --- |-----------------------------------------------------------------------------| --- | --- |
|
||||
|en_ppocr_mobile_v2.0_table_structure| Table structure model for English table scenes trained on PubTabNet dataset |18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
|
||||
|en_ppocr_mobile_v2.0_table_structure| English table recognition model trained on PubTabNet dataset based on TableRec-RARE |18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
|
||||
|en_ppstructure_mobile_v2.0_SLANet|English table recognition model trained on PubTabNet dataset based on SLANet|9M|[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_train.tar) |
|
||||
|ch_ppstructure_mobile_v2.0_SLANet|Chinese table recognition model trained on PubTabNet dataset based on SLANet|9.3M|[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_train.tar) |
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. VQA
|
||||
|
|
|
@ -1,21 +1,23 @@
|
|||
# PP-Structure 快速开始
|
||||
|
||||
- [1. 安装依赖包](#1)
|
||||
- [2. 便捷使用](#2)
|
||||
- [2.1 命令行使用](#21)
|
||||
- [2.1.1 版面分析+表格识别](#211)
|
||||
- [2.1.2 版面分析](#212)
|
||||
- [2.1.3 表格识别](#213)
|
||||
- [2.1.4 DocVQA](#214)
|
||||
- [2.2 代码使用](#22)
|
||||
- [2.2.1 版面分析+表格识别](#221)
|
||||
- [2.2.2 版面分析](#222)
|
||||
- [2.2.3 表格识别](#223)
|
||||
- [2.2.4 DocVQA](#224)
|
||||
- [2.3 返回结果说明](#23)
|
||||
- [2.3.1 版面分析+表格识别](#231)
|
||||
- [2.3.2 DocVQA](#232)
|
||||
- [2.4 参数说明](#24)
|
||||
- [1. 安装依赖包](#1-安装依赖包)
|
||||
- [2. 便捷使用](#2-便捷使用)
|
||||
- [2.1 命令行使用](#21-命令行使用)
|
||||
- [2.1.1 图像方向分类+版面分析+表格识别](#211-图像方向分类版面分析表格识别)
|
||||
- [2.1.2 版面分析+表格识别](#212-版面分析表格识别)
|
||||
- [2.1.3 版面分析](#213-版面分析)
|
||||
- [2.1.4 表格识别](#214-表格识别)
|
||||
- [2.1.5 DocVQA](#215-docvqa)
|
||||
- [2.2 代码使用](#22-代码使用)
|
||||
- [2.2.1 图像方向分类版面分析表格识别](#221-图像方向分类版面分析表格识别)
|
||||
- [2.2.2 版面分析+表格识别](#222-版面分析表格识别)
|
||||
- [2.2.3 版面分析](#223-版面分析)
|
||||
- [2.2.4 表格识别](#224-表格识别)
|
||||
- [2.2.5 DocVQA](#225-docvqa)
|
||||
- [2.3 返回结果说明](#23-返回结果说明)
|
||||
- [2.3.1 版面分析+表格识别](#231-版面分析表格识别)
|
||||
- [2.3.2 DocVQA](#232-docvqa)
|
||||
- [2.4 参数说明](#24-参数说明)
|
||||
|
||||
|
||||
<a name="1"></a>
|
||||
|
@ -24,8 +26,6 @@
|
|||
```bash
|
||||
# 安装 paddleocr,推荐使用2.5+版本
|
||||
pip3 install "paddleocr>=2.5"
|
||||
# 安装 版面分析依赖包layoutparser(如不需要版面分析功能,可跳过)
|
||||
pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
|
||||
# 安装 DocVQA依赖包paddlenlp(如不需要DocVQA功能,可跳过)
|
||||
pip install paddlenlp
|
||||
|
||||
|
@ -38,25 +38,31 @@ pip install paddlenlp
|
|||
### 2.1 命令行使用
|
||||
|
||||
<a name="211"></a>
|
||||
#### 2.1.1 版面分析+表格识别
|
||||
#### 2.1.1 图像方向分类+版面分析+表格识别
|
||||
```bash
|
||||
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --image_orientation=true
|
||||
```
|
||||
|
||||
<a name="212"></a>
|
||||
#### 2.1.2 版面分析+表格识别
|
||||
```bash
|
||||
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure
|
||||
```
|
||||
|
||||
<a name="212"></a>
|
||||
#### 2.1.2 版面分析
|
||||
<a name="213"></a>
|
||||
#### 2.1.3 版面分析
|
||||
```bash
|
||||
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --table=false --ocr=false
|
||||
```
|
||||
|
||||
<a name="213"></a>
|
||||
#### 2.1.3 表格识别
|
||||
<a name="214"></a>
|
||||
#### 2.1.4 表格识别
|
||||
```bash
|
||||
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structure --layout=false
|
||||
```
|
||||
|
||||
<a name="214"></a>
|
||||
#### 2.1.4 DocVQA
|
||||
<a name="215"></a>
|
||||
#### 2.1.5 DocVQA
|
||||
|
||||
请参考:[文档视觉问答](../vqa/README.md)。
|
||||
|
||||
|
@ -64,7 +70,36 @@ paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structur
|
|||
### 2.2 代码使用
|
||||
|
||||
<a name="221"></a>
|
||||
#### 2.2.1 版面分析+表格识别
|
||||
#### 2.2.1 图像方向分类版面分析表格识别
|
||||
|
||||
```python
|
||||
import os
|
||||
import cv2
|
||||
from paddleocr import PPStructure,draw_structure_result,save_structure_res
|
||||
|
||||
table_engine = PPStructure(show_log=True, image_orientation=True)
|
||||
|
||||
save_folder = './output'
|
||||
img_path = 'PaddleOCR/ppstructure/docs/table/1.png'
|
||||
img = cv2.imread(img_path)
|
||||
result = table_engine(img)
|
||||
save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
|
||||
|
||||
for line in result:
|
||||
line.pop('img')
|
||||
print(line)
|
||||
|
||||
from PIL import Image
|
||||
|
||||
font_path = 'PaddleOCR/doc/fonts/simfang.ttf' # font file provided with PaddleOCR
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
im_show = draw_structure_result(image, result,font_path=font_path)
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
||||
<a name="222"></a>
|
||||
#### 2.2.2 版面分析+表格识别
|
||||
|
||||
```python
|
||||
import os
|
||||
|
@ -92,8 +127,8 @@ im_show = Image.fromarray(im_show)
|
|||
im_show.save('result.jpg')
|
||||
```
|
||||
|
||||
<a name="222"></a>
|
||||
#### 2.2.2 版面分析
|
||||
<a name="223"></a>
|
||||
#### 2.2.3 版面分析
|
||||
|
||||
```python
|
||||
import os
|
||||
|
@ -113,8 +148,8 @@ for line in result:
|
|||
print(line)
|
||||
```
|
||||
|
||||
<a name="223"></a>
|
||||
#### 2.2.3 表格识别
|
||||
<a name="224"></a>
|
||||
#### 2.2.4 表格识别
|
||||
|
||||
```python
|
||||
import os
|
||||
|
@ -134,8 +169,8 @@ for line in result:
|
|||
print(line)
|
||||
```
|
||||
|
||||
<a name="224"></a>
|
||||
#### 2.2.4 DocVQA
|
||||
<a name="225"></a>
|
||||
#### 2.2.5 DocVQA
|
||||
|
||||
请参考:[文档视觉问答](../vqa/README.md)。
|
||||
|
||||
|
@ -156,10 +191,10 @@ PP-Structure的返回结果为一个dict组成的list,示例如下
|
|||
```
|
||||
dict 里各个字段说明如下
|
||||
|
||||
| 字段 | 说明 |
|
||||
| --------------- |-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
|type| 图片区域的类型 |
|
||||
|bbox| 图片区域的在原图的坐标,分别[左上角x,左上角y,右下角x,右下角y] |
|
||||
| 字段 | 说明|
|
||||
| --- |---|
|
||||
|type| 图片区域的类型 |
|
||||
|bbox| 图片区域的在原图的坐标,分别[左上角x,左上角y,右下角x,右下角y]|
|
||||
|res| 图片区域的OCR或表格识别结果。<br> 表格: 一个dict,字段说明如下<br>        `html`: 表格的HTML字符串<br>        在代码使用模式下,前向传入return_ocr_result_in_table=True可以拿到表格中每个文本的检测识别结果,对应为如下字段: <br>        `boxes`: 文本检测坐标<br>        `rec_res`: 文本识别结果。<br> OCR: 一个包含各个单行文字的检测坐标和识别结果的元组 |
|
||||
|
||||
运行完成后,每张图片会在`output`字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名为表格在图片里的坐标。
|
||||
|
@ -180,20 +215,26 @@ dict 里各个字段说明如下
|
|||
<a name="24"></a>
|
||||
### 2.4 参数说明
|
||||
|
||||
| 字段 | 说明 | 默认值 |
|
||||
|----------------------|----------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------|
|
||||
| output | excel和识别结果保存的地址 | ./output/table |
|
||||
| table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 |
|
||||
| table_model_dir | 表格结构模型 inference 模型地址 | None |
|
||||
| table_char_dict_path | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.txt |
|
||||
| layout_path_model | 版面分析模型模型地址,可以为在线地址或者本地地址,当为本地地址时,需要指定 layout_label_map, 命令行模式下可通过--layout_label_map='{0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}' 指定 | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config |
|
||||
| layout_label_map | 版面分析模型模型label映射字典 | None |
|
||||
| model_name_or_path | VQA SER模型地址 | None |
|
||||
| max_seq_length | VQA SER模型最大支持token长度 | 512 |
|
||||
| label_map_path | VQA SER 标签文件地址 | ./vqa/labels/labels_ser.txt |
|
||||
| layout | 前向中是否执行版面分析 | True |
|
||||
| table | 前向中是否执行表格识别 | True |
|
||||
| ocr | 对于版面分析中的非表格区域,是否执行ocr。当layout为False时会被自动设置为False | True |
|
||||
| structure_version | 表格结构化模型版本,可选 PP-STRUCTURE。PP-STRUCTURE支持表格结构化模型 | PP-STRUCTURE |
|
||||
| 字段 | 说明 | 默认值 |
|
||||
|---|---|---|
|
||||
| output | 结果保存地址 | ./output/table |
|
||||
| table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 |
|
||||
| table_model_dir | 表格结构模型 inference 模型地址| None |
|
||||
| table_char_dict_path | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.txt |
|
||||
| merge_no_span_structure | 表格识别模型中,是否对'\<td>'和'\</td>' 进行合并 | False |
|
||||
| layout_model_dir | 版面分析模型 inference 模型地址 | None |
|
||||
| layout_dict_path | 版面分析模型字典| ../ppocr/utils/dict/layout_publaynet_dict.txt |
|
||||
| layout_score_threshold | 版面分析模型检测框阈值| 0.5|
|
||||
| layout_nms_threshold | 版面分析模型nms阈值| 0.5|
|
||||
| vqa_algorithm | vqa模型算法| LayoutXLM|
|
||||
| ser_model_dir | ser模型 inference 模型地址| None|
|
||||
| ser_dict_path | ser模型字典| ../train_data/XFUND/class_list_xfun.txt|
|
||||
| mode | structure or vqa | structure |
|
||||
| image_orientation | 前向中是否执行图像方向分类 | False |
|
||||
| layout | 前向中是否执行版面分析 | True |
|
||||
| table | 前向中是否执行表格识别 | True |
|
||||
| ocr | 对于版面分析中的非表格区域,是否执行ocr。当layout为False时会被自动设置为False| True |
|
||||
| recovery | 前向中是否执行版面恢复| False |
|
||||
| structure_version | 模型版本,可选 PP-structure和PP-structurev2 | PP-structure |
|
||||
|
||||
大部分参数和PaddleOCR whl包保持一致,见 [whl包文档](../../doc/doc_ch/whl.md)
|
||||
|
|
|
@ -1,21 +1,23 @@
|
|||
# PP-Structure Quick Start
|
||||
|
||||
- [1. Install package](#1)
|
||||
- [2. Use](#2)
|
||||
- [2.1 Use by command line](#21)
|
||||
- [2.1.1 layout analysis + table recognition](#211)
|
||||
- [2.1.2 layout analysis](#212)
|
||||
- [2.1.3 table recognition](#213)
|
||||
- [2.1.4 DocVQA](#214)
|
||||
- [2.2 Use by code](#22)
|
||||
- [2.2.1 layout analysis + table recognition](#221)
|
||||
- [2.2.2 layout analysis](#222)
|
||||
- [2.2.3 table recognition](#223)
|
||||
- [2.2.4 DocVQA](#224)
|
||||
- [2.3 Result description](#23)
|
||||
- [2.3.1 layout analysis + table recognition](#231)
|
||||
- [2.3.2 DocVQA](#232)
|
||||
- [2.4 Parameter Description](#24)
|
||||
- [1. Install package](#1-install-package)
|
||||
- [2. Use](#2-use)
|
||||
- [2.1 Use by command line](#21-use-by-command-line)
|
||||
- [2.1.1 image orientation + layout analysis + table recognition](#211-image-orientation--layout-analysis--table-recognition)
|
||||
- [2.1.2 layout analysis + table recognition](#212-layout-analysis--table-recognition)
|
||||
- [2.1.3 layout analysis](#213-layout-analysis)
|
||||
- [2.1.4 table recognition](#214-table-recognition)
|
||||
- [2.1.5 DocVQA](#215-docvqa)
|
||||
- [2.2 Use by code](#22-use-by-code)
|
||||
- [2.2.1 image orientation + layout analysis + table recognition](#221-image-orientation--layout-analysis--table-recognition)
|
||||
- [2.2.2 layout analysis + table recognition](#222-layout-analysis--table-recognition)
|
||||
- [2.2.3 layout analysis](#223-layout-analysis)
|
||||
- [2.2.4 table recognition](#224-table-recognition)
|
||||
- [2.2.5 DocVQA](#225-docvqa)
|
||||
- [2.3 Result description](#23-result-description)
|
||||
- [2.3.1 layout analysis + table recognition](#231-layout-analysis--table-recognition)
|
||||
- [2.3.2 DocVQA](#232-docvqa)
|
||||
- [2.4 Parameter Description](#24-parameter-description)
|
||||
|
||||
|
||||
<a name="1"></a>
|
||||
|
@ -24,8 +26,6 @@
|
|||
```bash
|
||||
# Install paddleocr, version 2.5+ is recommended
|
||||
pip3 install "paddleocr>=2.5"
|
||||
# Install layoutparser (if you do not use the layout analysis, you can skip it)
|
||||
pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
|
||||
# Install the DocVQA dependency package paddlenlp (if you do not use the DocVQA, you can skip it)
|
||||
pip install paddlenlp
|
||||
|
||||
|
@ -38,25 +38,31 @@ pip install paddlenlp
|
|||
### 2.1 Use by command line
|
||||
|
||||
<a name="211"></a>
|
||||
#### 2.1.1 layout analysis + table recognition
|
||||
#### 2.1.1 image orientation + layout analysis + table recognition
|
||||
```bash
|
||||
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --image_orientation=true
|
||||
```
|
||||
|
||||
<a name="212"></a>
|
||||
#### 2.1.2 layout analysis + table recognition
|
||||
```bash
|
||||
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure
|
||||
```
|
||||
|
||||
<a name="212"></a>
|
||||
#### 2.1.2 layout analysis
|
||||
<a name="213"></a>
|
||||
#### 2.1.3 layout analysis
|
||||
```bash
|
||||
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --table=false --ocr=false
|
||||
```
|
||||
|
||||
<a name="213"></a>
|
||||
#### 2.1.3 table recognition
|
||||
<a name="214"></a>
|
||||
#### 2.1.4 table recognition
|
||||
```bash
|
||||
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structure --layout=false
|
||||
```
|
||||
|
||||
<a name="214"></a>
|
||||
#### 2.1.4 DocVQA
|
||||
<a name="215"></a>
|
||||
#### 2.1.5 DocVQA
|
||||
|
||||
Please refer to [Document Visual Q&A](../vqa/README.md).
|
||||
|
||||
|
@ -64,7 +70,36 @@ Please refer to: [Documentation Visual Q&A](../vqa/README.md) .
|
|||
### 2.2 Use by code
|
||||
|
||||
<a name="221"></a>
|
||||
#### 2.2.1 layout analysis + table recognition
|
||||
#### 2.2.1 image orientation + layout analysis + table recognition
|
||||
|
||||
```python
|
||||
import os
|
||||
import cv2
|
||||
from paddleocr import PPStructure,draw_structure_result,save_structure_res
|
||||
|
||||
table_engine = PPStructure(show_log=True, image_orientation=True)
|
||||
|
||||
save_folder = './output'
|
||||
img_path = 'PaddleOCR/ppstructure/docs/table/1.png'
|
||||
img = cv2.imread(img_path)
|
||||
result = table_engine(img)
|
||||
save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
|
||||
|
||||
for line in result:
|
||||
line.pop('img')
|
||||
print(line)
|
||||
|
||||
from PIL import Image
|
||||
|
||||
font_path = 'PaddleOCR/doc/fonts/simfang.ttf' # font file provided with PaddleOCR
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
im_show = draw_structure_result(image, result,font_path=font_path)
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
||||
<a name="222"></a>
|
||||
#### 2.2.2 layout analysis + table recognition
|
||||
|
||||
```python
|
||||
import os
|
||||
|
@ -92,8 +127,8 @@ im_show = Image.fromarray(im_show)
|
|||
im_show.save('result.jpg')
|
||||
```
|
||||
|
||||
<a name="222"></a>
|
||||
#### 2.2.2 layout analysis
|
||||
<a name="223"></a>
|
||||
#### 2.2.3 layout analysis
|
||||
|
||||
```python
|
||||
import os
|
||||
|
@ -113,8 +148,8 @@ for line in result:
|
|||
print(line)
|
||||
```
|
||||
|
||||
<a name="223"></a>
|
||||
#### 2.2.3 table recognition
|
||||
<a name="224"></a>
|
||||
#### 2.2.4 table recognition
|
||||
|
||||
```python
|
||||
import os
|
||||
|
@ -134,8 +169,8 @@ for line in result:
|
|||
print(line)
|
||||
```
|
||||
|
||||
<a name="224"></a>
|
||||
#### 2.2.4 DocVQA
|
||||
<a name="225"></a>
|
||||
#### 2.2.5 DocVQA
|
||||
|
||||
Please refer to [Document Visual Q&A](../vqa/README.md).
|
||||
|
||||
|
@ -157,8 +192,8 @@ The return of PP-Structure is a list of dicts, the example is as follows:
|
|||
```
|
||||
Each field in dict is described as follows:
|
||||
|
||||
| field | description |
|
||||
| --------------- |--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| field | description |
|
||||
| --- |---|
|
||||
|type| Type of image area. |
|
||||
|bbox| The coordinates of the image area in the original image, respectively [upper left corner x, upper left corner y, lower right corner x, lower right corner y]. |
|
||||
|res| OCR or table recognition result of the image area. <br> Table: a dict with the following fields: <br>        `html`: HTML string of the table.<br>        In code usage mode, passing return_ocr_result_in_table=True in the forward call also returns the detection and recognition result of each text in the table area, in the following fields: <br>        `boxes`: text detection boxes.<br>        `rec_res`: text recognition results.<br> OCR: a tuple containing the detection boxes and recognition results of each single-line text. |
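As a concrete illustration of these fields, here is a minimal sketch that walks a result list and handles table and text regions separately (it assumes the lower-cased `type` values used in this change; everything else follows the field table above):

```python
import cv2
from paddleocr import PPStructure

engine = PPStructure(show_log=False)
result = engine(cv2.imread('PaddleOCR/ppstructure/docs/table/1.png'))

for region in result:
    if region['type'] == 'table':
        # res is a dict; 'html' is always present, 'boxes'/'rec_res' only when
        # return_ocr_result_in_table=True is passed in the forward call
        print(region['bbox'], region['res']['html'][:80])
    else:
        # res holds one entry per text line: 'text', 'confidence', 'text_region'
        for line in region['res']:
            print(region['type'], line['text'])
```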
|
||||
|
@ -180,19 +215,26 @@ Please refer to: [Documentation Visual Q&A](../vqa/README.md) .
|
|||
<a name="24"></a>
|
||||
### 2.4 Parameter Description
|
||||
|
||||
| field | description | default |
|
||||
|----------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------|
|
||||
| output | The save path of result | ./output/table |
|
||||
| table_max_len | When the table structure model predicts, the long side of the image | 488 |
|
||||
| table_model_dir | the path of table structure model | None |
|
||||
| table_char_dict_path | the dict path of table structure model | ../ppocr/utils/dict/table_structure_dict.txt |
|
||||
| layout_path_model | The model path of the layout analysis model, which can be an online address or a local path. When it is a local path, layout_label_map needs to be set. In command line mode, use --layout_label_map='{0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}' | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config |
|
||||
| layout_label_map | Layout analysis model model label mapping dictionary path | None |
|
||||
| model_name_or_path | the model path of VQA SER model | None |
|
||||
| max_seq_length | the max token length of VQA SER model | 512 |
|
||||
| label_map_path | the label path of VQA SER model | ./vqa/labels/labels_ser.txt |
|
||||
| layout | Whether to perform layout analysis in forward | True |
|
||||
| table | Whether to perform table recognition in forward | True |
|
||||
| ocr | Whether to perform ocr for non-table areas in layout analysis. When layout is False, it will be automatically set to False | True |
|
||||
| structure_version | Table structure model version; currently PP-STRUCTURE is supported (English table structure model) | PP-STRUCTURE |
|
||||
| field | description | default |
|
||||
|---|---|---|
|
||||
| output | result save path | ./output/table |
|
||||
| table_max_len | The length to which the long side of the image is resized for the table structure model | 488 |
|
||||
| table_model_dir | Table structure model inference model path| None |
|
||||
| table_char_dict_path | The dictionary path of table structure model | ../ppocr/utils/dict/table_structure_dict.txt |
|
||||
| merge_no_span_structure | In the table recognition model, whether to merge '\<td>' and '\</td>' | False |
|
||||
| layout_model_dir | Layout analysis model inference model path| None |
|
||||
| layout_dict_path | The dictionary path of layout analysis model| ../ppocr/utils/dict/layout_publaynet_dict.txt |
|
||||
| layout_score_threshold | The detection box score threshold of the layout analysis model| 0.5|
|
||||
| layout_nms_threshold | The NMS threshold of the layout analysis model| 0.5|
|
||||
| vqa_algorithm | vqa model algorithm| LayoutXLM|
|
||||
| ser_model_dir | Ser model inference model path| None|
|
||||
| ser_dict_path | The dictionary path of Ser model| ../train_data/XFUND/class_list_xfun.txt|
|
||||
| mode | structure or vqa | structure |
|
||||
| image_orientation | Whether to perform image orientation classification in forward | False |
|
||||
| layout | Whether to perform layout analysis in forward | True |
|
||||
| table | Whether to perform table recognition in forward | True |
|
||||
| ocr | Whether to perform ocr for non-table areas in layout analysis. When layout is False, it will be automatically set to False| True |
|
||||
| recovery | Whether to perform layout recovery in forward| False |
|
||||
| structure_version | Structure version, optional PP-structure and PP-structurev2 | PP-structure |
|
||||
|
||||
Most of the parameters are consistent with the PaddleOCR whl package, see [whl package documentation](../../doc/doc_en/whl.md)
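As a sketch of how a few of these parameters can be overridden from Python (this assumes the whl wrapper forwards the flags above to the pipeline unchanged; defaults apply otherwise):

```python
from paddleocr import PPStructure

engine = PPStructure(
    table_max_len=488,             # resize target for the table structure model
    merge_no_span_structure=True,  # merge '<td>' and '</td>' tokens
    layout_score_threshold=0.5,    # layout detection box threshold
    image_orientation=False,       # skip orientation classification
    recovery=False)                # no layout recovery
```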
|
||||
|
|
|
@ -18,7 +18,7 @@ import subprocess
|
|||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../')))
|
||||
|
||||
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
||||
import cv2
|
||||
|
@ -27,11 +27,11 @@ import numpy as np
|
|||
import time
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from attrdict import AttrDict
|
||||
|
||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||
from ppocr.utils.logging import get_logger
|
||||
from tools.infer.predict_system import TextSystem
|
||||
from ppstructure.layout.predict_layout import LayoutPredictor
|
||||
from ppstructure.table.predict_table import TableSystem, to_excel
|
||||
from ppstructure.utility import parse_args, draw_structure_result
|
||||
from ppstructure.recovery.recovery_to_doc import convert_info_docx
|
||||
|
@ -42,6 +42,14 @@ logger = get_logger()
|
|||
class StructureSystem(object):
|
||||
def __init__(self, args):
|
||||
self.mode = args.mode
|
||||
self.recovery = args.recovery
|
||||
|
||||
self.image_orientation_predictor = None
|
||||
if args.image_orientation:
|
||||
import paddleclas
|
||||
self.image_orientation_predictor = paddleclas.PaddleClas(
|
||||
model_name="text_image_orientation")
|
||||
|
||||
if self.mode == 'structure':
|
||||
if not args.show_log:
|
||||
logger.setLevel(logging.INFO)
|
||||
|
@ -51,28 +59,14 @@ class StructureSystem(object):
|
|||
"When args.layout is false, args.ocr is automatically set to false"
|
||||
)
|
||||
args.drop_score = 0
|
||||
# init layout and ocr model
|
||||
# init model
|
||||
self.layout_predictor = None
|
||||
self.text_system = None
|
||||
self.table_system = None
|
||||
if args.layout:
|
||||
import layoutparser as lp
|
||||
config_path = None
|
||||
model_path = None
|
||||
if os.path.isdir(args.layout_path_model):
|
||||
model_path = args.layout_path_model
|
||||
else:
|
||||
config_path = args.layout_path_model
|
||||
self.table_layout = lp.PaddleDetectionLayoutModel(
|
||||
config_path=config_path,
|
||||
model_path=model_path,
|
||||
label_map=args.layout_label_map,
|
||||
threshold=0.5,
|
||||
enable_mkldnn=args.enable_mkldnn,
|
||||
enforce_cpu=not args.use_gpu,
|
||||
thread_num=args.cpu_threads)
|
||||
self.layout_predictor = LayoutPredictor(args)
|
||||
if args.ocr:
|
||||
self.text_system = TextSystem(args)
|
||||
else:
|
||||
self.table_layout = None
|
||||
if args.table:
|
||||
if self.text_system is not None:
|
||||
self.table_system = TableSystem(
|
||||
|
@ -80,39 +74,78 @@ class StructureSystem(object):
|
|||
self.text_system.text_recognizer)
|
||||
else:
|
||||
self.table_system = TableSystem(args)
|
||||
else:
|
||||
self.table_system = None
|
||||
|
||||
elif self.mode == 'vqa':
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, img, return_ocr_result_in_table=False):
|
||||
time_dict = {
|
||||
'image_orientation': 0,
|
||||
'layout': 0,
|
||||
'table': 0,
|
||||
'table_match': 0,
|
||||
'det': 0,
|
||||
'rec': 0,
|
||||
'vqa': 0,
|
||||
'all': 0
|
||||
}
|
||||
start = time.time()
|
||||
if self.image_orientation_predictor is not None:
|
||||
tic = time.time()
|
||||
cls_result = self.image_orientation_predictor.predict(
|
||||
input_data=img)
|
||||
cls_res = next(cls_result)
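# paddleclas returns a generator; the top-1 label is the page rotation
# angle as a string ('0', '90', '180' or '270')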
|
||||
angle = cls_res[0]['label_names'][0]
|
||||
cv_rotate_code = {
|
||||
'90': cv2.ROTATE_90_COUNTERCLOCKWISE,
|
||||
'180': cv2.ROTATE_180,
|
||||
'270': cv2.ROTATE_90_CLOCKWISE
|
||||
}
|
||||
if angle in cv_rotate_code:
    img = cv2.rotate(img, cv_rotate_code[angle])
|
||||
toc = time.time()
|
||||
time_dict['image_orientation'] = toc - tic
|
||||
if self.mode == 'structure':
|
||||
ori_im = img.copy()
|
||||
if self.table_layout is not None:
|
||||
layout_res = self.table_layout.detect(img[..., ::-1])
|
||||
if self.layout_predictor is not None:
|
||||
layout_res, elapse = self.layout_predictor(img)
|
||||
time_dict['layout'] += elapse
|
||||
else:
|
||||
h, w = ori_im.shape[:2]
|
||||
layout_res = [AttrDict(coordinates=[0, 0, w, h], type='Table')]
|
||||
layout_res = [dict(bbox=None, label='table')]
|
||||
res_list = []
|
||||
for region in layout_res:
|
||||
res = ''
|
||||
x1, y1, x2, y2 = region.coordinates
|
||||
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
|
||||
roi_img = ori_im[y1:y2, x1:x2, :]
|
||||
if region.type == 'Table':
|
||||
if region['bbox'] is not None:
|
||||
x1, y1, x2, y2 = region['bbox']
|
||||
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
|
||||
roi_img = ori_im[y1:y2, x1:x2, :]
|
||||
else:
|
||||
x1, y1, x2, y2 = 0, 0, w, h
|
||||
roi_img = ori_im
|
||||
if region['label'] == 'table':
|
||||
if self.table_system is not None:
|
||||
res = self.table_system(roi_img,
|
||||
return_ocr_result_in_table)
|
||||
res, table_time_dict = self.table_system(
|
||||
roi_img, return_ocr_result_in_table)
|
||||
time_dict['table'] += table_time_dict['table']
|
||||
time_dict['table_match'] += table_time_dict['match']
|
||||
time_dict['det'] += table_time_dict['det']
|
||||
time_dict['rec'] += table_time_dict['rec']
|
||||
else:
|
||||
if self.text_system is not None:
|
||||
if args.recovery:
|
||||
if self.recovery:
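# place the cropped region on a blank page-sized canvas so that the
# detected text boxes keep page-absolute coordinates for layout recovery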
|
||||
wht_im = np.ones(ori_im.shape, dtype=ori_im.dtype)
|
||||
wht_im[y1:y2, x1:x2, :] = roi_img
|
||||
filter_boxes, filter_rec_res = self.text_system(wht_im)
|
||||
filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
|
||||
wht_im)
|
||||
else:
|
||||
filter_boxes, filter_rec_res = self.text_system(roi_img)
|
||||
# remove style char
|
||||
filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
|
||||
roi_img)
|
||||
time_dict['det'] += ocr_time_dict['det']
|
||||
time_dict['rec'] += ocr_time_dict['rec']
|
||||
|
||||
# remove style char,
|
||||
# when using the recognition model trained on the PubTabNet dataset,
|
||||
# it will recognize the text format in the table, such as <b>
|
||||
style_token = [
|
||||
'<strike>', '</strike>', '<sup>', '</sub>', '<b>',
|
||||
'</b>', '<sub>', '</sup>', '<overline>',
|
||||
|
@ -125,7 +158,7 @@ class StructureSystem(object):
|
|||
for token in style_token:
|
||||
if token in rec_str:
|
||||
rec_str = rec_str.replace(token, '')
|
||||
if not args.recovery:
|
||||
if not self.recovery:
|
||||
box += [x1, y1]
|
||||
res.append({
|
||||
'text': rec_str,
|
||||
|
@ -133,15 +166,17 @@ class StructureSystem(object):
|
|||
'text_region': box.tolist()
|
||||
})
|
||||
res_list.append({
|
||||
'type': region.type,
|
||||
'type': region['label'].lower(),
|
||||
'bbox': [x1, y1, x2, y2],
|
||||
'img': roi_img,
|
||||
'res': res
|
||||
})
|
||||
return res_list
|
||||
end = time.time()
|
||||
time_dict['all'] = end - start
|
||||
return res_list, time_dict
|
||||
elif self.mode == 'vqa':
|
||||
raise NotImplementedError
|
||||
return None
|
||||
return None, None
|
||||
|
||||
|
||||
def save_structure_res(res, save_folder, img_name):
|
||||
|
@ -156,12 +191,12 @@ def save_structure_res(res, save_folder, img_name):
|
|||
roi_img = region.pop('img')
|
||||
f.write('{}\n'.format(json.dumps(region)))
|
||||
|
||||
if region['type'] == 'Table' and len(region[
|
||||
if region['type'] == 'table' and len(region[
|
||||
'res']) > 0 and 'html' in region['res']:
|
||||
excel_path = os.path.join(excel_save_folder,
|
||||
'{}.xlsx'.format(region['bbox']))
|
||||
to_excel(region['res']['html'], excel_path)
|
||||
elif region['type'] == 'Figure':
|
||||
elif region['type'] == 'figure':
|
||||
img_path = os.path.join(excel_save_folder,
|
||||
'{}.jpg'.format(region['bbox']))
|
||||
cv2.imwrite(img_path, roi_img)
|
||||
|
@ -187,8 +222,7 @@ def main(args):
|
|||
if img is None:
|
||||
logger.error("error in loading image:{}".format(image_file))
|
||||
continue
|
||||
starttime = time.time()
|
||||
res = structure_sys(img)
|
||||
res, time_dict = structure_sys(img)
|
||||
|
||||
if structure_sys.mode == 'structure':
|
||||
save_structure_res(res, save_folder, img_name)
|
||||
|
@ -201,9 +235,8 @@ def main(args):
|
|||
cv2.imwrite(img_save_path, draw_img)
|
||||
logger.info('result save to {}'.format(img_save_path))
|
||||
if args.recovery:
|
||||
convert_info_docx(img, res, save_folder, img_name)
|
||||
elapse = time.time() - starttime
|
||||
logger.info("Predict time : {:.3f}s".format(elapse))
|
||||
convert_info_docx(img, res, save_folder, img_name)
|
||||
logger.info("Predict time : {:.3f}s".format(time_dict['all']))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,126 +1,124 @@
|
|||
- [Table Recognition](#table-recognition)
|
||||
- [1. pipeline](#1-pipeline)
|
||||
- [2. Performance](#2-performance)
|
||||
- [3. How to use](#3-how-to-use)
|
||||
- [3.1 quick start](#31-quick-start)
|
||||
- [3.2 Train](#32-train)
|
||||
- [3.3 Eval](#33-eval)
|
||||
- [3.4 Inference](#34-inference)
|
||||
|
||||
English | [简体中文](README_ch.md)
|
||||
|
||||
# Table Recognition
|
||||
|
||||
- [1. pipeline](#1-pipeline)
|
||||
- [2. Performance](#2-performance)
|
||||
- [3. Result](#3-result)
|
||||
- [4. How to use](#4-how-to-use)
|
||||
- [4.1 Quick start](#41-quick-start)
|
||||
- [4.2 Train](#42-train)
|
||||
- [4.3 Calculate TEDS](#43-calculate-teds)
|
||||
- [5. Reference](#5-reference)
|
||||
|
||||
|
||||
## 1. pipeline
|
||||
Table recognition mainly involves three models:
|
||||
1. Single line text detection-DB
|
||||
2. Single line text recognition-CRNN
|
||||
3. Table structure and cell coordinate prediction-RARE
|
||||
3. Table structure and cell coordinate prediction-SLANet
|
||||
|
||||
The table recognition flow chart is as follows
|
||||
|
||||

|
||||
|
||||
1. The coordinates of single-line texts are detected by the DB model and then sent to the recognition model to obtain the recognition results.
|
||||
2. The table structure and cell coordinates are predicted by the RARE model.
|
||||
2. The table structure and cell coordinates are predicted by the SLANet model.
|
||||
3. The recognition result of each cell is assembled from the single-line text coordinates and recognition results together with the cell coordinates (a sketch of this matching step follows this list).
|
||||
4. The cell recognition results and the table structure together construct the HTML string of the table.
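To make step 3 concrete, here is a toy sketch of the matching idea; the helper name and the simple center-containment test are ours for illustration, while the actual matcher in ppstructure/table/predict_table.py scores candidate cells with box distances and IoU:

```python
def assign_lines_to_cells(cell_boxes, line_boxes, line_texts):
    """Assign each detected text line to the cell whose box contains its center."""
    cells = {i: [] for i in range(len(cell_boxes))}
    for (x1, y1, x2, y2), text in zip(line_boxes, line_texts):
        cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
        for i, (cx1, cy1, cx2, cy2) in enumerate(cell_boxes):
            if cx1 <= cx <= cx2 and cy1 <= cy <= cy2:
                cells[i].append(text)
                break
    return {i: ' '.join(texts) for i, texts in cells.items()}

# one cell spanning the whole row picks up both text lines
print(assign_lines_to_cells(
    [(0, 0, 100, 20)],
    [(2, 2, 40, 18), (50, 2, 90, 18)],
    ['Feature', 'Gb3 +']))  # -> {0: 'Feature Gb3 +'}
```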
|
||||
|
||||
## 2. Performance
|
||||
We evaluated the algorithm on the PubTabNet<sup>[1]</sup> eval dataset, and the performance is as follows:
|
||||
|
||||
|Method|Acc|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|Speed|
|
||||
| --- | --- | --- | ---|
|
||||
| EDD<sup>[2]</sup> |x| 88.3 |x|
|
||||
| TableRec-RARE(ours) |73.8%| 93.32 |1550ms|
|
||||
| SLANet(ours) | 76.2%| 94.98 |766ms|
|
||||
|
||||
|Method|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|
|
||||
| --- | --- |
|
||||
| EDD<sup>[2]</sup> | 88.3 |
|
||||
| Ours | 93.32 |
|
||||
The performance indicators are explained as follows:
|
||||
- Acc: The accuracy of the table structure in each image, a wrong token is considered an error.
|
||||
- TEDS: The accuracy of the model's restoration of table information. This indicator evaluates not only the table structure, but also the text content in the table.
|
||||
- Speed: The inference speed for a single image when the model runs on a CPU machine with MKL enabled.
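For reference, TEDS (introduced together with PubTabNet<sup>[2]</sup>) compares the predicted and ground-truth tables as HTML trees: `TEDS(T_a, T_b) = 1 - EditDist(T_a, T_b) / max(|T_a|, |T_b|)`, where `EditDist` is the tree edit distance and `|T|` is the number of nodes in `T`; the TEDS numbers above report this score as a percentage.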
|
||||
|
||||
## 3. How to use
|
||||
## 3. Result
|
||||
|
||||
### 3.1 quick start
|
||||

|
||||

|
||||

|
||||
|
||||
## 4. How to use
|
||||
|
||||
### 4.1 Quick start
|
||||
|
||||
Use the following commands to quickly complete the identification of a table.
|
||||
|
||||
```bash
|
||||
cd PaddleOCR/ppstructure
|
||||
|
||||
# download model
|
||||
mkdir inference && cd inference
|
||||
# Download the detection model of the ultra-lightweight table English OCR model and unzip it
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar
|
||||
# Download the recognition model of the ultra-lightweight table English OCR model and unzip it
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar
|
||||
# Download the ultra-lightweight English table structure model and unzip it
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
|
||||
# Download the PP-OCRv3 text detection model and unzip it
|
||||
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar && tar xf ch_PP-OCRv3_det_slim_infer.tar
|
||||
# Download the PP-OCRv3 text recognition model and unzip it
|
||||
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar && tar xf ch_PP-OCRv3_rec_slim_infer.tar
|
||||
# Download the PP-Structurev2 form recognition model and unzip it
|
||||
wget https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar
|
||||
cd ..
|
||||
# run
|
||||
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=./docs/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ./output/table
|
||||
```
|
||||
Note: The above model is trained on the PubTabNet dataset and only supports English scanning scenarios. If you need to recognize other scenarios, you need to train the model yourself and replace the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`.
|
||||
```bash
python3.7 table/predict_table.py \
|
||||
--det_model_dir=inference/ch_PP-OCRv3_det_slim_infer \
|
||||
--rec_model_dir=inference/ch_PP-OCRv3_rec_slim_infer \
|
||||
--table_model_dir=inference/ch_ppstructure_mobile_v2.0_SLANet_infer \
|
||||
--rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \
|
||||
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict_ch.txt \
|
||||
--image_dir=docs/table/table.jpg \
|
||||
--output=../output/table
```
|
||||
|
||||
After running, the Excel sheet of each image will be saved in the directory specified by the output field.
|
||||
|
||||
### 3.2 Train
|
||||
|
||||
In this chapter, we only introduce the training of the table structure model. For training the [text detection](../../doc/doc_en/detection_en.md) and [text recognition](../../doc/doc_en/recognition_en.md) models, please refer to the corresponding documents.
|
||||
|
||||
* data preparation
|
||||
The training data uses the public dataset [PubTabNet](https://arxiv.org/abs/1911.10683), which can be downloaded from the official [website](https://github.com/ibm-aur-nlp/PubTabNet). The PubTabNet dataset contains about 500,000 images, as well as annotations in HTML format.
|
||||
|
||||
* Start training
|
||||
*If you installed the CPU version of PaddlePaddle, please modify the `use_gpu` field in the configuration file to false.*
|
||||
```shell
|
||||
# single GPU training
|
||||
python3 tools/train.py -c configs/table/table_mv3.yml
|
||||
# multi-GPU training
|
||||
# Set the GPU ID used by the '--gpus' parameter.
|
||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/table/table_mv3.yml
|
||||
```
|
||||
|
||||
In the above command, `-c` selects `configs/table/table_mv3.yml` as the training configuration file.
|
||||
For a detailed explanation of the configuration file, please refer to [config](../../doc/doc_en/config_en.md).
|
||||
After running, the Excel table of each image will be saved to the directory specified by the output field, and an HTML file will also be produced in that directory for visually checking the cell coordinates and the recognized table.
|
||||
|
||||
* load trained model and continue training
|
||||
### 4.2 Train
|
||||
|
||||
If you want to load a trained model and resume training, you can specify the parameter `Global.checkpoints` as the model path to be loaded.
|
||||
For the training, evaluation and inference process of the text detection model, please refer to [detection](../../doc/doc_en/detection_en.md)
|
||||
|
||||
```shell
|
||||
python3 tools/train.py -c configs/table/table_mv3.yml -o Global.checkpoints=./your/trained/model
|
||||
```
|
||||
For the training, evaluation and inference process of the text recognition model, please refer to [recognition](../../doc/doc_en/recognition_en.md)
|
||||
|
||||
**Note**: The priority of `Global.checkpoints` is higher than that of `Global.pretrain_weights`, that is, when two parameters are specified at the same time, the model specified by `Global.checkpoints` will be loaded first. If the model path specified by `Global.checkpoints` is wrong, the one specified by `Global.pretrain_weights` will be loaded.
|
||||
For the training, evaluation and inference process of the table recognition model, please refer to [table_recognition](../../doc/doc_en/table_recognition_en.md)
|
||||
|
||||
### 3.3 Eval
|
||||
### 4.3 Calculate TEDS
|
||||
|
||||
The table uses [TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src) as the evaluation metric of the model. Before the model evaluation, the three models in the pipeline need to be exported as inference models (we have provided them), and the gt for evaluation needs to be prepared. Examples of gt are as follows:
|
||||
```json
|
||||
{"PMC4289340_004_00.png": [
|
||||
["<html>", "<body>", "<table>", "<thead>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</thead>", "<tbody>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</tbody>", "</table>", "</body>", "</html>"],
|
||||
[[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]],
|
||||
[["<b>", "F", "e", "a", "t", "u", "r", "e", "</b>"], ["<b>", "G", "b", "3", " ", "+", "</b>"], ["<b>", "G", "b", "3", " ", "-", "</b>"], ["<b>", "P", "a", "t", "i", "e", "n", "t", "s", "</b>"], ["6", "2"], ["4", "5"]]
|
||||
]}
|
||||
```
```txt
|
||||
PMC5755158_010_01.png <html><body><table><thead><tr><td></td><td><b>Weaning</b></td><td><b>Week 15</b></td><td><b>Off-test</b></td></tr></thead><tbody><tr><td>Weaning</td><td>–</td><td>–</td><td>–</td></tr><tr><td>Week 15</td><td>–</td><td>0.17 ± 0.08</td><td>0.16 ± 0.03</td></tr><tr><td>Off-test</td><td>–</td><td>0.80 ± 0.24</td><td>0.19 ± 0.09</td></tr></tbody></table></body></html>
|
||||
```
|
||||
Each line in the gt file consists of the file name and the HTML string of the table, separated by `\t`.
|
||||
|
||||
You can also use the following command to generate an evaluation gt file from the annotation file:
|
||||
```bash
|
||||
python3 ppstructure/table/convert_label2html.py --ori_gt_path /path/to/your_label_file --save_path /path/to/save_file
|
||||
```
|
||||
In the gt json, the key is the image name and the value is the corresponding gt, which is a list composed of three items:
|
||||
1. HTML string list of table structure
|
||||
2. The coordinates of each cell (not including the empty text in the cell)
|
||||
3. The text information in each cell (not including the empty text in the cell)
|
||||
|
||||
Use the following command to evaluate. After the evaluation is completed, the TEDS metric will be output.
|
||||
```bash
|
||||
cd PaddleOCR/ppstructure
|
||||
python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
|
||||
python3 table/eval_table.py \
|
||||
--det_model_dir=path/to/det_model_dir \
|
||||
--rec_model_dir=path/to/rec_model_dir \
|
||||
--table_model_dir=path/to/table_model_dir \
|
||||
--image_dir=../doc/table/1.png \
|
||||
--rec_char_dict_path=../ppocr/utils/dict/table_dict.txt \
|
||||
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
|
||||
--det_limit_side_len=736 \
|
||||
--det_limit_type=min \
|
||||
--gt_path=path/to/gt.txt
|
||||
```
|
||||
|
||||
If the PubTabNet eval dataset is used, the output will be
|
||||
```bash
|
||||
teds: 93.32
|
||||
teds: 94.98
|
||||
```
|
||||
|
||||
### 3.4 Inference
|
||||
|
||||
```bash
|
||||
cd PaddleOCR/ppstructure
|
||||
python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
|
||||
```
|
||||
After running, the Excel sheet of each image will be saved in the directory specified by the output field.
|
||||
|
||||
Reference
|
||||
## 5. Reference
|
||||
1. https://github.com/ibm-aur-nlp/PubTabNet
|
||||
2. https://arxiv.org/pdf/1911.10683
|
||||
|
|
|
@ -2,22 +2,22 @@
|
|||
|
||||
# 表格识别
|
||||
|
||||
- [1. 表格识别 pipeline](#1)
|
||||
- [2. 性能](#2)
|
||||
- [3. 使用](#3)
|
||||
- [3.1 快速开始](#31)
|
||||
- [3.2 训练](#32)
|
||||
- [3.3 评估](#33)
|
||||
- [3.4 预测](#34)
|
||||
- [1. 表格识别 pipeline](#1-表格识别-pipeline)
|
||||
- [2. 性能](#2-性能)
|
||||
- [3. 效果演示](#3-效果演示)
|
||||
- [4. 使用](#4-使用)
|
||||
- [4.1 快速开始](#41-快速开始)
|
||||
- [4.2 训练](#42-训练)
|
||||
- [4.3 计算TEDS](#43-计算teds)
|
||||
- [5. Reference](#5-reference)
|
||||
|
||||
|
||||
<a name="1"></a>
|
||||
## 1. 表格识别 pipeline
|
||||
|
||||
表格识别主要包含三个模型
|
||||
1. 单行文本检测-DB
|
||||
2. 单行文本识别-CRNN
|
||||
3. 表格结构和cell坐标预测-RARE
|
||||
3. 表格结构和cell坐标预测-SLANet
|
||||
|
||||
具体流程图如下
|
||||
|
||||
|
@ -26,111 +26,102 @@
|
|||
流程说明:
|
||||
|
||||
1. 图片由单行文字检测模型检测到单行文字的坐标,然后送入识别模型拿到识别结果。
|
||||
2. 图片由表格结构和cell坐标预测模型拿到表格的结构信息和单元格的坐标信息。
|
||||
2. 图片由SLANet模型拿到表格的结构信息和单元格的坐标信息。
|
||||
3. 由单行文字的坐标、识别结果和单元格的坐标一起组合出单元格的识别结果。
|
||||
4. 单元格的识别结果和表格结构一起构造表格的html字符串。
|
||||
|
||||
|
||||
<a name="2"></a>
|
||||
## 2. 性能
|
||||
|
||||
我们在 PubTabNet<sup>[1]</sup> 评估数据集上对算法进行了评估,性能如下
|
||||
|
||||
|
||||
|算法|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|
|
||||
| --- | --- |
|
||||
| EDD<sup>[2]</sup> | 88.3 |
|
||||
| Ours | 93.32 |
|
||||
|算法|Acc|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|Speed|
|
||||
| --- | --- | --- | ---|
|
||||
| EDD<sup>[2]</sup> |x| 88.3 |x|
|
||||
| TableRec-RARE(ours) |73.8%| 93.32 |1550ms|
|
||||
| SLANet(ours) | 76.2%| 94.98 |766ms|
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. 使用
|
||||
性能指标解释如下:
|
||||
- Acc: 模型对每张图像里表格结构的识别准确率,错一个token就算错误。
|
||||
- TEDS: 模型对表格信息还原的准确度,此指标评价内容不仅包含表格结构,还包含表格内的文字内容。
|
||||
- Speed: 模型在CPU机器上,开启MKL的情况下,单张图片的推理速度。
|
||||
|
||||
<a name="31"></a>
|
||||
### 3.1 快速开始
|
||||
## 3. 效果演示
|
||||
|
||||

|
||||

|
||||

|
||||
|
||||
## 4. 使用
|
||||
|
||||
### 4.1 快速开始
|
||||
|
||||
使用如下命令即可快速完成一张表格的识别。
|
||||
```bash
|
||||
cd PaddleOCR/ppstructure
|
||||
|
||||
# 下载模型
|
||||
mkdir inference && cd inference
|
||||
# 下载超轻量级表格英文OCR模型的检测模型并解压
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar
|
||||
# 下载超轻量级表格英文OCR模型的识别模型并解压
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar
|
||||
# 下载超轻量级英文表格结构模型并解压
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
|
||||
# 下载PP-OCRv3文本检测模型并解压
|
||||
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar && tar xf ch_PP-OCRv3_det_slim_infer.tar
|
||||
# 下载PP-OCRv3文本识别模型并解压
|
||||
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar && tar xf ch_PP-OCRv3_rec_slim_infer.tar
|
||||
# 下载PP-Structurev2表格识别模型并解压
|
||||
wget https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar
|
||||
cd ..
|
||||
# 执行预测
|
||||
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=./docs/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ./output/table
|
||||
# 执行表格识别
|
||||
python table/predict_table.py \
|
||||
--det_model_dir=inference/ch_PP-OCRv3_det_slim_infer \
|
||||
--rec_model_dir=inference/ch_PP-OCRv3_rec_slim_infer \
|
||||
--table_model_dir=inference/ch_ppstructure_mobile_v2.0_SLANet_infer \
|
||||
--rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \
|
||||
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict_ch.txt \
|
||||
--image_dir=docs/table/table.jpg \
|
||||
--output=../output/table
|
||||
```
|
||||
运行完成后,每张图片的excel表格会保存到output字段指定的目录下
|
||||
运行完成后,每张图片的excel表格会保存到output字段指定的目录下,同时在该目录下会生成一个html文件,用于可视化查看单元格坐标和识别的表格。
|
||||
|
||||
note: 上述模型是在 PubTabNet 数据集上训练的表格识别模型,仅支持英文扫描场景,如需识别其他场景需要自己训练模型后替换 `det_model_dir`,`rec_model_dir`,`table_model_dir`三个字段即可。
|
||||
### 4.2 训练
|
||||
|
||||
<a name="32"></a>
|
||||
### 3.2 训练
|
||||
文本检测模型的训练、评估和推理流程可参考 [detection](../../doc/doc_ch/detection.md)
|
||||
|
||||
在这一章节中,我们仅介绍表格结构模型的训练,[文字检测](../../doc/doc_ch/detection.md)和[文字识别](../../doc/doc_ch/recognition.md)的模型训练请参考对应的文档。
|
||||
文本识别模型的训练、评估和推理流程可参考 [recognition](../../doc/doc_ch/recognition.md)
|
||||
|
||||
* 数据准备
|
||||
表格识别模型的训练、评估和推理流程可参考 [table_recognition](../../doc/doc_ch/table_recognition.md)
|
||||
|
||||
训练数据使用公开数据集PubTabNet ([论文](https://arxiv.org/abs/1911.10683),[下载地址](https://github.com/ibm-aur-nlp/PubTabNet))。PubTabNet数据集包含约50万张表格数据的图像,以及图像对应的html格式的注释。
|
||||
|
||||
* 启动训练
|
||||
|
||||
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
|
||||
```shell
|
||||
# 单机单卡训练
|
||||
python3 tools/train.py -c configs/table/table_mv3.yml
|
||||
# 单机多卡训练,通过 --gpus 参数设置使用的GPU ID
|
||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/table/table_mv3.yml
|
||||
```
|
||||
|
||||
上述指令中,通过-c 选择训练使用configs/table/table_mv3.yml配置文件。有关配置文件的详细解释,请参考[链接](../../doc/doc_ch/config.md)。
|
||||
|
||||
* 断点训练
|
||||
|
||||
如果训练程序中断,如果希望加载训练中断的模型从而恢复训练,可以通过指定Global.checkpoints指定要加载的模型路径:
|
||||
```shell
|
||||
python3 tools/train.py -c configs/table/table_mv3.yml -o Global.checkpoints=./your/trained/model
|
||||
```
|
||||
|
||||
**注意**:`Global.checkpoints`的优先级高于`Global.pretrain_weights`的优先级,即同时指定两个参数时,优先加载`Global.checkpoints`指定的模型,如果`Global.checkpoints`指定的模型路径有误,会加载`Global.pretrain_weights`指定的模型。
|
||||
|
||||
<a name="33"></a>
|
||||
### 3.3 评估
|
||||
### 4.3 计算TEDS
|
||||
|
||||
表格使用 [TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src) 作为模型的评估指标。在进行模型评估之前,需要将pipeline中的三个模型分别导出为inference模型(我们已经提供好),还需要准备评估的gt, gt示例如下:
|
||||
```json
|
||||
{"PMC4289340_004_00.png": [
|
||||
["<html>", "<body>", "<table>", "<thead>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</thead>", "<tbody>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</tbody>", "</table>", "</body>", "</html>"],
|
||||
[[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]],
|
||||
[["<b>", "F", "e", "a", "t", "u", "r", "e", "</b>"], ["<b>", "G", "b", "3", " ", "+", "</b>"], ["<b>", "G", "b", "3", " ", "-", "</b>"], ["<b>", "P", "a", "t", "i", "e", "n", "t", "s", "</b>"], ["6", "2"], ["4", "5"]]
|
||||
]}
|
||||
```
```txt
|
||||
PMC5755158_010_01.png <html><body><table><thead><tr><td></td><td><b>Weaning</b></td><td><b>Week 15</b></td><td><b>Off-test</b></td></tr></thead><tbody><tr><td>Weaning</td><td>–</td><td>–</td><td>–</td></tr><tr><td>Week 15</td><td>–</td><td>0.17 ± 0.08</td><td>0.16 ± 0.03</td></tr><tr><td>Off-test</td><td>–</td><td>0.80 ± 0.24</td><td>0.19 ± 0.09</td></tr></tbody></table></body></html>
|
||||
```
|
||||
gt每一行都由文件名和表格的html字符串组成,文件名和表格的html字符串之间使用`\t`分隔。
|
||||
|
||||
也可使用如下命令,由标注文件生成评估的gt文件:
|
||||
```bash
|
||||
python3 ppstructure/table/convert_label2html.py --ori_gt_path /path/to/your_label_file --save_path /path/to/save_file
|
||||
```
|
||||
json 中,key为图片名,value为对应的gt,gt是一个由三个item组成的list,每个item分别为
|
||||
1. 表格结构的html字符串list
|
||||
2. 每个cell的坐标 (不包括cell里文字为空的)
|
||||
3. 每个cell里的文字信息 (不包括cell里文字为空的)
|
||||
|
||||
准备完成后使用如下命令进行评估,评估完成后会输出teds指标。
|
||||
```bash
|
||||
cd PaddleOCR/ppstructure
|
||||
python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
|
||||
python3 table/eval_table.py \
|
||||
--det_model_dir=path/to/det_model_dir \
|
||||
--rec_model_dir=path/to/rec_model_dir \
|
||||
--table_model_dir=path/to/table_model_dir \
|
||||
--image_dir=../doc/table/1.png \
|
||||
--rec_char_dict_path=../ppocr/utils/dict/table_dict.txt \
|
||||
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
|
||||
--det_limit_side_len=736 \
|
||||
--det_limit_type=min \
|
||||
--gt_path=path/to/gt.txt
|
||||
```
|
||||
如使用PubTabNet评估数据集,将会输出
|
||||
```bash
|
||||
teds: 93.32
|
||||
teds: 94.98
|
||||
```
|
||||
|
||||
<a name="34"></a>
|
||||
### 3.4 预测
|
||||
|
||||
```bash
|
||||
cd PaddleOCR/ppstructure
|
||||
python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
|
||||
```
|
||||
|
||||
# Reference
|
||||
## 5. Reference
|
||||
1. https://github.com/ibm-aur-nlp/PubTabNet
|
||||
2. https://arxiv.org/pdf/1911.10683
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
convert table label to html
|
||||
"""
|
||||
|
||||
import json
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def save_pred_txt(key, val, tmp_file_path):
|
||||
with open(tmp_file_path, 'a+', encoding='utf-8') as f:
|
||||
f.write('{}\t{}\n'.format(key, val))
|
||||
|
||||
|
||||
def skip_char(text, sp_char_list):
|
||||
"""
|
||||
skip empty cell
|
||||
@param text: text in cell
|
||||
@param sp_char_list: style char and special code
|
||||
@return:
|
||||
"""
|
||||
for sp_char in sp_char_list:
|
||||
text = text.replace(sp_char, '')
|
||||
return text
|
||||
|
||||
|
||||
def gen_html(img):
|
||||
'''
|
||||
Formats HTML code from tokenized annotation of img
|
||||
'''
|
||||
html_code = img['html']['structure']['tokens'].copy()
|
||||
to_insert = [i for i, tag in enumerate(html_code) if tag in ('<td>', '>')]
|
||||
for i, cell in zip(to_insert[::-1], img['html']['cells'][::-1]):
|
||||
if cell['tokens']:
|
||||
text = ''.join(cell['tokens'])
|
||||
# skip empty text
|
||||
sp_char_list = ['<b>', '</b>', '\u2028', ' ', '<i>', '</i>']
|
||||
text_remove_style = skip_char(text, sp_char_list)
|
||||
if len(text_remove_style) == 0:
|
||||
continue
|
||||
html_code.insert(i + 1, text)
|
||||
html_code = ''.join(html_code)
|
||||
html_code = '<html><body><table>{}</table></body></html>'.format(html_code)
|
||||
return html_code
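# Example: for a toy PubTabNet-style record (hypothetical data)
#   {'html': {'structure': {'tokens': ['<tr>', '<td>', '</td>', '</tr>']},
#             'cells': [{'tokens': ['6', '2']}]}}
# gen_html returns:
#   '<html><body><table><tr><td>62</td></tr></table></body></html>'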
|
||||
|
||||
|
||||
def load_gt_data(gt_path):
|
||||
"""
|
||||
load gt
|
||||
@param gt_path:
|
||||
@return:
|
||||
"""
|
||||
data_list = {}
|
||||
with open(gt_path, 'rb') as f:
|
||||
lines = f.readlines()
|
||||
for line in tqdm(lines):
|
||||
data_line = line.decode('utf-8').strip("\n")
|
||||
info = json.loads(data_line)
|
||||
data_list[info['filename']] = info
|
||||
return data_list
|
||||
|
||||
|
||||
def convert(origin_gt_path, save_path):
|
||||
"""
|
||||
gen html from label file
|
||||
@param origin_gt_path:
|
||||
@param save_path:
|
||||
@return:
|
||||
"""
|
||||
data_dict = load_gt_data(origin_gt_path)
|
||||
for img_name, gt in tqdm(data_dict.items()):
|
||||
html = gen_html(gt)
|
||||
save_pred_txt(img_name, html, save_path)
|
||||
print('convert finished')
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="convert table label to html")
|
||||
parser.add_argument(
|
||||
"--ori_gt_path", type=str, required=True, help="label gt path")
|
||||
parser.add_argument(
|
||||
"--save_path", type=str, required=True, help="path to save file")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
convert(args.ori_gt_path, args.save_path)

@@ -1,4 +1,4 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))

import cv2
import json
import pickle
import paddle
from tqdm import tqdm
from ppstructure.table.table_metric import TEDS
from ppstructure.table.predict_table import TableSystem

@@ -33,40 +36,74 @@ def parse_args():
    parser.add_argument("--gt_path", type=str)
    return parser.parse_args()


def main(gt_path, img_root, args):
    teds = TEDS(n_jobs=16)


def load_txt(txt_path):
    pred_html_dict = {}
    if not os.path.exists(txt_path):
        return pred_html_dict
    with open(txt_path, encoding='utf-8') as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip().split('\t')
            img_name, pred_html = line
            pred_html_dict[img_name] = pred_html
    return pred_html_dict


def load_result(path):
    data = {}
    if os.path.exists(path):
        data = pickle.load(open(path, 'rb'))
    return data


def save_result(path, data):
    old_data = load_result(path)
    old_data.update(data)
    with open(path, 'wb') as f:
        pickle.dump(old_data, f)
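
# Hedged usage sketch: load_result/save_result form a simple resumable cache,
# so an interrupted evaluation can pick up where it left off. The path and the
# cached values below are illustrative only.
def _demo_result_cache():
    cache_path = 'output/demo.pickle'  # hypothetical path
    cache = load_result(cache_path)  # returns {} on the first run
    if 'img_0.png' not in cache:
        cache['img_0.png'] = ['fake_boxes', 'fake_rec']
        save_result(cache_path, cache)  # merges with whatever is on disk
    return cache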


def main(gt_path, img_root, args):
    os.makedirs(args.output, exist_ok=True)
    # init TableSystem
    text_sys = TableSystem(args)
    jsons_gt = json.load(open(gt_path))  # gt
    # load gt and preds html result
    gt_html_dict = load_txt(gt_path)

    ocr_result = load_result(os.path.join(args.output, 'ocr.pickle'))
    structure_result = load_result(
        os.path.join(args.output, 'structure.pickle'))

    pred_htmls = []
    gt_htmls = []
    for img_name in tqdm(jsons_gt):
        # read image
        img = cv2.imread(os.path.join(img_root,img_name))
        pred_html = text_sys(img)
    for img_name, gt_html in tqdm(gt_html_dict.items()):
        img = cv2.imread(os.path.join(img_root, img_name))
        # run ocr and save result
        if img_name not in ocr_result:
            dt_boxes, rec_res, _, _ = text_sys._ocr(img)
            ocr_result[img_name] = [dt_boxes, rec_res]
            save_result(os.path.join(args.output, 'ocr.pickle'), ocr_result)
        # run structure and save result
        if img_name not in structure_result:
            structure_res, _ = text_sys._structure(img)
            structure_result[img_name] = structure_res
            save_result(
                os.path.join(args.output, 'structure.pickle'), structure_result)
        dt_boxes, rec_res = ocr_result[img_name]
        structure_res = structure_result[img_name]
        # match ocr and structure
        pred_html = text_sys.match(structure_res, dt_boxes, rec_res)

        pred_htmls.append(pred_html)

        gt_structures, gt_bboxes, gt_contents = jsons_gt[img_name]
        gt_html, gt = get_gt_html(gt_structures, gt_contents)
        gt_htmls.append(gt_html)

    # compute teds
    teds = TEDS(n_jobs=16)
    scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)
    logger.info('teds:', sum(scores) / len(scores))


def get_gt_html(gt_structures, gt_contents):
    end_html = []
    td_index = 0
    for tag in gt_structures:
        if '</td>' in tag:
            if gt_contents[td_index] != []:
                end_html.extend(gt_contents[td_index])
            end_html.append(tag)
            td_index += 1
        else:
            end_html.append(tag)
    return ''.join(end_html), end_html
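
# Illustrative sketch of get_gt_html(): the cell contents are spliced in front
# of each closing </td>. The token lists here are made up for the example.
def _demo_get_gt_html():
    gt_structures = ['<tr>', '<td>', '</td>', '<td>', '</td>', '</tr>']
    gt_contents = [['cell', ' 1'], []]
    html, _ = get_gt_html(gt_structures, gt_contents)
    # html == '<tr><td>cell 1</td><td></td></tr>'
    return html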
    logger.info('teds: {}'.format(sum(scores) / len(scores)))


if __name__ == '__main__':
    args = parse_args()
    main(args.gt_path,args.image_dir, args)
    main(args.gt_path, args.image_dir, args)


@@ -1,11 +1,29 @@
import json
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from ppstructure.table.table_master_match import deal_eb_token, deal_bb


def distance(box_1, box_2):
    x1, y1, x2, y2 = box_1
    x3, y3, x4, y4 = box_2
    dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2)
    dis_2 = abs(x3 - x1) + abs(y3 - y1)
    dis_3 = abs(x4- x2) + abs(y4 - y2)
    return dis + min(dis_2, dis_3)
    x1, y1, x2, y2 = box_1
    x3, y3, x4, y4 = box_2
    dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
    dis_2 = abs(x3 - x1) + abs(y3 - y1)
    dis_3 = abs(x4 - x2) + abs(y4 - y2)
    return dis + min(dis_2, dis_3)


def compute_iou(rec1, rec2):
    """

@@ -18,175 +36,157 @@ def compute_iou(rec1, rec2):
    # computing area of each rectangle
    S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
    S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])

    # computing the sum_area
    sum_area = S_rec1 + S_rec2

    # find each edge of the intersect rectangle
    left_line = max(rec1[1], rec2[1])
    right_line = min(rec1[3], rec2[3])
    top_line = max(rec1[0], rec2[0])
    bottom_line = min(rec1[2], rec2[2])

    # judge if there is an intersect
    if left_line >= right_line or top_line >= bottom_line:
        return 0.0
    else:
        intersect = (right_line - left_line) * (bottom_line - top_line)
        return (intersect / (sum_area - intersect))*1.0

        return (intersect / (sum_area - intersect)) * 1.0
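
# Quick numeric check of the two geometry helpers (toy boxes, values worked
# out by hand; a sketch, not part of the matching pipeline):
def _demo_box_metrics():
    box_a = [0, 0, 10, 10]
    box_b = [5, 5, 15, 15]
    d = distance(box_a, box_b)  # 20 + min(10, 10) = 30
    iou = compute_iou(box_a, box_b)  # 25 / (200 - 25) = 1/7
    return d, iou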


def matcher_merge(ocr_bboxes, pred_bboxes):
    all_dis = []
    ious = []
    matched = {}
    for i, gt_box in enumerate(ocr_bboxes):
        distances = []
        for j, pred_box in enumerate(pred_bboxes):
            # compute l1 distance and IOU between two boxes
            distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box)))
        sorted_distances = distances.copy()
        # select nearest cell
        sorted_distances = sorted(sorted_distances, key = lambda item: (item[1], item[0]))
        if distances.index(sorted_distances[0]) not in matched.keys():
            matched[distances.index(sorted_distances[0])] = [i]
class TableMatch:
    def __init__(self, filter_ocr_result=False, use_master=False):
        self.filter_ocr_result = filter_ocr_result
        self.use_master = use_master

    def __call__(self, structure_res, dt_boxes, rec_res):
        pred_structures, pred_bboxes = structure_res
        if self.filter_ocr_result:
            dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes,
                                                        rec_res)
        matched_index = self.match_result(dt_boxes, pred_bboxes)
        if self.use_master:
            pred_html, pred = self.get_pred_html_master(pred_structures,
                                                        matched_index, rec_res)
        else:
            matched[distances.index(sorted_distances[0])].append(i)
    return matched  #, sum(ious) / len(ious)
            pred_html, pred = self.get_pred_html(pred_structures, matched_index,
                                                 rec_res)
        return pred_html

def complex_num(pred_bboxes):
    complex_nums = []
    for bbox in pred_bboxes:
        distances = []
        temp_ious = []
        for pred_bbox in pred_bboxes:
            if bbox != pred_bbox:
                distances.append(distance(bbox, pred_bbox))
                temp_ious.append(compute_iou(bbox, pred_bbox))
        complex_nums.append(temp_ious[distances.index(min(distances))])
    return sum(complex_nums) / len(complex_nums)

def get_rows(pred_bboxes):
    pre_bbox = pred_bboxes[0]
    res = []
    step = 0
    for i in range(len(pred_bboxes)):
        bbox = pred_bboxes[i]
        if bbox[1] - pre_bbox[1] > 2 or bbox[0] - pre_bbox[0] < 0:
            break
        else:
            res.append(bbox)
            step += 1
    for i in range(step):
        pred_bboxes.pop(0)
    return res, pred_bboxes

def refine_rows(pred_bboxes):  # fine-tune the row's boxes so they lie on one horizontal line
    ys_1 = []
    ys_2 = []
    for box in pred_bboxes:
        ys_1.append(box[1])
        ys_2.append(box[3])
    min_y_1 = sum(ys_1) / len(ys_1)
    min_y_2 = sum(ys_2) / len(ys_2)
    re_boxes = []
    for box in pred_bboxes:
        box[1] = min_y_1
        box[3] = min_y_2
        re_boxes.append(box)
    return re_boxes

def matcher_refine_row(gt_bboxes, pred_bboxes):
    before_refine_pred_bboxes = pred_bboxes.copy()
    pred_bboxes = []
    while(len(before_refine_pred_bboxes) != 0):
        row_bboxes, before_refine_pred_bboxes = get_rows(before_refine_pred_bboxes)
        print(row_bboxes)
        pred_bboxes.extend(refine_rows(row_bboxes))
    all_dis = []
    ious = []
    matched = {}
    for i, gt_box in enumerate(gt_bboxes):
        distances = []
        #temp_ious = []
        for j, pred_box in enumerate(pred_bboxes):
            distances.append(distance(gt_box, pred_box))
            #temp_ious.append(compute_iou(gt_box, pred_box))
        #all_dis.append(min(distances))
        #ious.append(temp_ious[distances.index(min(distances))])
        if distances.index(min(distances)) not in matched.keys():
            matched[distances.index(min(distances))] = [i]
        else:
            matched[distances.index(min(distances))].append(i)
    return matched  #, sum(ious) / len(ious)


# pick out one row first, then do the matching
def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes):
    gt_box_index = 0
    delete_gt_bboxes = gt_bboxes.copy()
    match_bboxes_ready = []
    matched = {}
    while(len(delete_gt_bboxes) != 0):
        row_bboxes, delete_gt_bboxes = get_rows(delete_gt_bboxes)
        row_bboxes = sorted(row_bboxes, key = lambda key: key[0])
        if len(pred_bboxes_rows) > 0:
            match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
        print(row_bboxes)
        for i, gt_box in enumerate(row_bboxes):
            #print(gt_box)
            pred_distances = []
            distances = []
            for pred_bbox in pred_bboxes:
                pred_distances.append(distance(gt_box, pred_bbox))
            for j, pred_box in enumerate(match_bboxes_ready):
                distances.append(distance(gt_box, pred_box))
            index = pred_distances.index(min(distances))
            #print('index', index)
            if index not in matched.keys():
                matched[index] = [gt_box_index]
    def match_result(self, dt_boxes, pred_bboxes):
        matched = {}
        for i, gt_box in enumerate(dt_boxes):
            distances = []
            for j, pred_box in enumerate(pred_bboxes):
                if len(pred_box) == 8:
                    pred_box = [
                        np.min(pred_box[0::2]), np.min(pred_box[1::2]),
                        np.max(pred_box[0::2]), np.max(pred_box[1::2])
                    ]
                distances.append((distance(gt_box, pred_box),
                                  1. - compute_iou(gt_box, pred_box)
                                  ))  # compute iou and l1 distance
            sorted_distances = distances.copy()
            # select det box by iou and l1 distance
            sorted_distances = sorted(
                sorted_distances, key=lambda item: (item[1], item[0]))
            if distances.index(sorted_distances[0]) not in matched.keys():
                matched[distances.index(sorted_distances[0])] = [i]
            else:
                matched[index].append(gt_box_index)
            gt_box_index += 1
    return matched
                matched[distances.index(sorted_distances[0])].append(i)
        return matched
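
# Hedged sketch of TableMatch.match_result(): each OCR det box (xyxy) is
# assigned to the predicted cell it overlaps best (sorted by 1 - IoU, then L1
# distance); the result maps cell index -> list of det box indices. Toy boxes:
def _demo_match_result():
    dt_boxes = [[0, 0, 10, 10], [12, 0, 20, 10]]
    pred_bboxes = [[0, 0, 11, 11], [11, 0, 21, 11]]
    return TableMatch().match_result(dt_boxes, pred_bboxes)  # {0: [0], 1: [1]}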

def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
    '''
    gt_bboxes: sorted
    pred_bboxes:
    '''
    pre_bbox = gt_bboxes[0]
    matched = {}
    match_bboxes_ready = []
    match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
    for i, gt_box in enumerate(gt_bboxes):

        pred_distances = []
        for pred_bbox in pred_bboxes:
            pred_distances.append(distance(gt_box, pred_bbox))
        distances = []
        gap_pre = gt_box[1] - pre_bbox[1]
        gap_pre_1 = gt_box[0] - pre_bbox[2]
        #print(gap_pre, len(pred_bboxes_rows))
        if (gap_pre_1 < 0 and len(pred_bboxes_rows) > 0):
            match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
        if len(pred_bboxes_rows) == 1:
            match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
        if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) > 0:
            match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
        if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) == 0:
            break
        #print(match_bboxes_ready)
        for j, pred_box in enumerate(match_bboxes_ready):
            distances.append(distance(gt_box, pred_box))
        index = pred_distances.index(min(distances))
        #print(gt_box, index)
        #match_bboxes_ready.pop(distances.index(min(distances)))
        print(gt_box, match_bboxes_ready[distances.index(min(distances))])
        if index not in matched.keys():
            matched[index] = [i]
        else:
            matched[index].append(i)
        pre_bbox = gt_box
    return matched
    def get_pred_html(self, pred_structures, matched_index, ocr_contents):
        end_html = []
        td_index = 0
        for tag in pred_structures:
            if '</td>' in tag:
                if '<td></td>' == tag:
                    end_html.extend('<td>')
                if td_index in matched_index.keys():
                    b_with = False
                    if '<b>' in ocr_contents[matched_index[td_index][
                            0]] and len(matched_index[td_index]) > 1:
                        b_with = True
                        end_html.extend('<b>')
                    for i, td_index_index in enumerate(matched_index[td_index]):
                        content = ocr_contents[td_index_index][0]
                        if len(matched_index[td_index]) > 1:
                            if len(content) == 0:
                                continue
                            if content[0] == ' ':
                                content = content[1:]
                            if '<b>' in content:
                                content = content[3:]
                            if '</b>' in content:
                                content = content[:-4]
                            if len(content) == 0:
                                continue
                            if i != len(matched_index[
                                    td_index]) - 1 and ' ' != content[-1]:
                                content += ' '
                        end_html.extend(content)
                    if b_with:
                        end_html.extend('</b>')
                if '<td></td>' == tag:
                    end_html.append('</td>')
                else:
                    end_html.append(tag)
                td_index += 1
            else:
                end_html.append(tag)
        return ''.join(end_html), end_html

    def get_pred_html_master(self, pred_structures, matched_index,
                             ocr_contents):
        end_html = []
        td_index = 0
        for token in pred_structures:
            if '</td>' in token:
                txt = ''
                b_with = False
                if td_index in matched_index.keys():
                    if '<b>' in ocr_contents[matched_index[td_index][
                            0]] and len(matched_index[td_index]) > 1:
                        b_with = True
                    for i, td_index_index in enumerate(matched_index[td_index]):
                        content = ocr_contents[td_index_index][0]
                        if len(matched_index[td_index]) > 1:
                            if len(content) == 0:
                                continue
                            if content[0] == ' ':
                                content = content[1:]
                            if '<b>' in content:
                                content = content[3:]
                            if '</b>' in content:
                                content = content[:-4]
                            if len(content) == 0:
                                continue
                            if i != len(matched_index[
                                    td_index]) - 1 and ' ' != content[-1]:
                                content += ' '
                        txt += content
                if b_with:
                    txt = '<b>{}</b>'.format(txt)
                if '<td></td>' == token:
                    token = '<td>{}</td>'.format(txt)
                else:
                    token = '{}</td>'.format(txt)
                td_index += 1
            token = deal_eb_token(token)
            end_html.append(token)
        html = ''.join(end_html)
        html = deal_bb(html)
        return html, end_html

    def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
        y1 = pred_bboxes[:, 1::2].min()
        new_dt_boxes = []
        new_rec_res = []

        for box, rec in zip(dt_boxes, rec_res):
            if np.max(box[1::2]) < y1:
                continue
            new_dt_boxes.append(box)
            new_rec_res.append(rec)
        return new_dt_boxes, new_rec_res
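
# Sketch of the OCR filtering rule: det boxes lying completely above the top
# of the predicted table cells (e.g. a caption line) are dropped. Toy arrays:
def _demo_filter_ocr_result():
    pred_bboxes = np.array([[0, 50, 100, 80], [0, 90, 100, 120]])  # min y = 50
    dt_boxes = [np.array([0, 10, 100, 30]), np.array([0, 55, 100, 70])]
    rec_res = [('caption', 0.9), ('cell text', 0.9)]
    # keeps only the second box / 'cell text'
    return TableMatch()._filter_ocr_result(pred_bboxes, dt_boxes, rec_res)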

@@ -16,7 +16,7 @@ import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

@@ -73,12 +73,14 @@ class TableStructurer(object):
            postprocess_params = {
                'name': 'TableLabelDecode',
                "character_dict_path": args.table_char_dict_path,
                'merge_no_span_structure': args.merge_no_span_structure
            }
        else:
            postprocess_params = {
                'name': 'TableMasterLabelDecode',
                "character_dict_path": args.table_char_dict_path,
                'box_shape': 'pad'
                'box_shape': 'pad',
                'merge_no_span_structure': args.merge_no_span_structure
            }

        self.preprocess_op = create_operators(pre_process_list)

@@ -87,6 +89,7 @@ class TableStructurer(object):
            utility.create_predictor(args, 'table', logger)

    def __call__(self, img):
        starttime = time.time()
        ori_im = img.copy()
        data = {'image': img}
        data = transform(data, self.preprocess_op)

@@ -95,7 +98,6 @@ class TableStructurer(object):
            return None, 0
        img = np.expand_dims(img, axis=0)
        img = img.copy()
        starttime = time.time()

        self.input_tensor.copy_from_cpu(img)
        self.predictor.run()

@@ -126,7 +128,6 @@ def main(args):
    table_structurer = TableStructurer(args)
    count = 0
    total_time = 0
    use_xywh = args.table_algorithm in ['TableMaster']
    os.makedirs(args.output, exist_ok=True)
    with open(
            os.path.join(args.output, 'infer.txt'), mode='w',

@@ -146,7 +147,10 @@ def main(args):
            f_w.write("result: {}, {}\n".format(structure_str_list,
                                                bbox_list_str))

            img = draw_rectangle(image_file, bbox_list, use_xywh)
            if len(bbox_list) > 0 and len(bbox_list[0]) == 4:
                img = draw_rectangle(image_file, pred_res['cell_bbox'])
            else:
                img = utility.draw_boxes(img, bbox_list)
            img_save_path = os.path.join(args.output,
                                         os.path.basename(image_file))
            cv2.imwrite(img_save_path, img)

@@ -18,20 +18,23 @@ import subprocess

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import copy
import logging
import numpy as np
import time
import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det
import tools.infer.utility as utility
from tools.infer.predict_system import sorted_boxes
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from ppstructure.table.matcher import distance, compute_iou
from ppstructure.table.matcher import TableMatch
from ppstructure.table.table_master_match import TableMasterMatcher
from ppstructure.utility import parse_args
import ppstructure.table.predict_structure as predict_strture

@@ -55,11 +58,20 @@ def expand(pix, det_box, shape):

class TableSystem(object):
    def __init__(self, args, text_detector=None, text_recognizer=None):
        if not args.show_log:
            logger.setLevel(logging.INFO)

        self.text_detector = predict_det.TextDetector(
            args) if text_detector is None else text_detector
        self.text_recognizer = predict_rec.TextRecognizer(
            args) if text_recognizer is None else text_recognizer

        self.table_structurer = predict_strture.TableStructurer(args)
        if args.table_algorithm in ['TableMaster']:
            self.match = TableMasterMatcher()
        else:
            self.match = TableMatch(filter_ocr_result=True)

        self.benchmark = args.benchmark
        self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
            args, 'table', logger)

@@ -85,145 +97,72 @@ class TableSystem(object):

    def __call__(self, img, return_ocr_result_in_table=False):
        result = dict()
        ori_im = img.copy()
        time_dict = {'det': 0, 'rec': 0, 'table': 0, 'all': 0, 'match': 0}
        start = time.time()

        structure_res, elapse = self._structure(copy.deepcopy(img))
        result['cell_bbox'] = structure_res[1].tolist()
        time_dict['table'] = elapse

        dt_boxes, rec_res, det_elapse, rec_elapse = self._ocr(
            copy.deepcopy(img))
        time_dict['det'] = det_elapse
        time_dict['rec'] = rec_elapse

        if return_ocr_result_in_table:
            result['boxes'] = dt_boxes  #[x.tolist() for x in dt_boxes]
            result['rec_res'] = rec_res

        tic = time.time()
        pred_html = self.match(structure_res, dt_boxes, rec_res)
        toc = time.time()
        time_dict['match'] = toc - tic
        result['html'] = pred_html
        if self.benchmark:
            self.autolog.times.end(stamp=True)
        end = time.time()
        time_dict['all'] = end - start
        if self.benchmark:
            self.autolog.times.stamp()
        return result, time_dict

    def _structure(self, img):
        if self.benchmark:
            self.autolog.times.start()
        structure_res, elapse = self.table_structurer(copy.deepcopy(img))
        return structure_res, elapse

    def _ocr(self, img):
        h, w = img.shape[:2]
        if self.benchmark:
            self.autolog.times.stamp()
        dt_boxes, elapse = self.text_detector(copy.deepcopy(img))
        dt_boxes, det_elapse = self.text_detector(copy.deepcopy(img))
        dt_boxes = sorted_boxes(dt_boxes)
        if return_ocr_result_in_table:
            result['boxes'] = [x.tolist() for x in dt_boxes]

        r_boxes = []
        for box in dt_boxes:
            x_min = box[:, 0].min() - 1
            x_max = box[:, 0].max() + 1
            y_min = box[:, 1].min() - 1
            y_max = box[:, 1].max() + 1
            x_min = max(0, box[:, 0].min() - 1)
            x_max = min(w, box[:, 0].max() + 1)
            y_min = max(0, box[:, 1].min() - 1)
            y_max = min(h, box[:, 1].max() + 1)
            box = [x_min, y_min, x_max, y_max]
            r_boxes.append(box)
        dt_boxes = np.array(r_boxes)
        logger.debug("dt_boxes num : {}, elapse : {}".format(
            len(dt_boxes), elapse))
            len(dt_boxes), det_elapse))
        if dt_boxes is None:
            return None, None

        img_crop_list = []
        for i in range(len(dt_boxes)):
            det_box = dt_boxes[i]
            x0, y0, x1, y1 = expand(2, det_box, ori_im.shape)
            text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :]
            x0, y0, x1, y1 = expand(2, det_box, img.shape)
            text_rect = img[int(y0):int(y1), int(x0):int(x1), :]
            img_crop_list.append(text_rect)
        rec_res, elapse = self.text_recognizer(img_crop_list)
        rec_res, rec_elapse = self.text_recognizer(img_crop_list)
        logger.debug("rec_res num : {}, elapse : {}".format(
            len(rec_res), elapse))
        if self.benchmark:
            self.autolog.times.stamp()
        if return_ocr_result_in_table:
            result['rec_res'] = rec_res
        pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res)
        result['html'] = pred_html
        if self.benchmark:
            self.autolog.times.end(stamp=True)
        return result

    def rebuild_table(self, structure_res, dt_boxes, rec_res):
        pred_structures, pred_bboxes = structure_res
        dt_boxes, rec_res = self.filter_ocr_result(pred_bboxes,dt_boxes, rec_res)
        matched_index = self.match_result(dt_boxes, pred_bboxes)
        pred_html, pred = self.get_pred_html(pred_structures, matched_index,
                                             rec_res)
        return pred_html, pred

    def filter_ocr_result(self, pred_bboxes,dt_boxes, rec_res):
        y1 = pred_bboxes[:,1::2].min()
        new_dt_boxes = []
        new_rec_res = []

        for box,rec in zip(dt_boxes, rec_res):
            if np.max(box[1::2]) < y1:
                continue
            new_dt_boxes.append(box)
            new_rec_res.append(rec)
        return new_dt_boxes, new_rec_res


    def match_result(self, dt_boxes, pred_bboxes):
        matched = {}
        for i, gt_box in enumerate(dt_boxes):
            # gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
            distances = []
            for j, pred_box in enumerate(pred_bboxes):
                distances.append((distance(gt_box, pred_box),
                                  1. - compute_iou(gt_box, pred_box)
                                  ))  # get the L1 distance and 1 - IoU between each pair of cells
            sorted_distances = distances.copy()
            # pick the "nearest" cell by distance and IoU
            sorted_distances = sorted(
                sorted_distances, key=lambda item: (item[1], item[0]))
            if distances.index(sorted_distances[0]) not in matched.keys():
                matched[distances.index(sorted_distances[0])] = [i]
            else:
                matched[distances.index(sorted_distances[0])].append(i)
        return matched

    def get_pred_html(self, pred_structures, matched_index, ocr_contents):
        end_html = []
        td_index = 0
        for tag in pred_structures:
            if '</td>' in tag:
                if td_index in matched_index.keys():
                    b_with = False
                    if '<b>' in ocr_contents[matched_index[td_index][
                            0]] and len(matched_index[td_index]) > 1:
                        b_with = True
                        end_html.extend('<b>')
                    for i, td_index_index in enumerate(matched_index[td_index]):
                        content = ocr_contents[td_index_index][0]
                        if len(matched_index[td_index]) > 1:
                            if len(content) == 0:
                                continue
                            if content[0] == ' ':
                                content = content[1:]
                            if '<b>' in content:
                                content = content[3:]
                            if '</b>' in content:
                                content = content[:-4]
                            if len(content) == 0:
                                continue
                            if i != len(matched_index[
                                    td_index]) - 1 and ' ' != content[-1]:
                                content += ' '
                        end_html.extend(content)
                    if b_with:
                        end_html.extend('</b>')

                end_html.append(tag)
                td_index += 1
            else:
                end_html.append(tag)
        return ''.join(end_html), end_html


def sorted_boxes(dt_boxes):
    """
    Sort text boxes in order from top to bottom, left to right
    args:
        dt_boxes(array): detected text boxes with shape [4, 2]
    return:
        sorted boxes(array) with shape [4, 2]
    """
    num_boxes = dt_boxes.shape[0]
    sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
    _boxes = list(sorted_boxes)

    for i in range(num_boxes - 1):
        if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
                (_boxes[i + 1][0][0] < _boxes[i][0][0]):
            tmp = _boxes[i]
            _boxes[i] = _boxes[i + 1]
            _boxes[i + 1] = tmp
    return _boxes
            len(rec_res), rec_elapse))
        return dt_boxes, rec_res, det_elapse, rec_elapse
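
# Illustrative sketch of sorted_boxes(): quads are ordered top-to-bottom, then
# left-to-right within a line (boxes whose top-left y values differ by less
# than 10 px count as one line). Toy quads with shape [N, 4, 2]:
def _demo_sorted_boxes():
    boxes = np.array([
        [[50, 12], [90, 12], [90, 30], [50, 30]],  # first line, right box
        [[0, 10], [40, 10], [40, 30], [0, 30]],  # first line, left box
        [[0, 60], [40, 60], [40, 80], [0, 80]],  # second line
    ])
    return sorted_boxes(boxes)  # left, then right, then the lower line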


def to_excel(html_table, excel_path):

@@ -236,8 +175,23 @@ def main(args):
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    os.makedirs(args.output, exist_ok=True)

    text_sys = TableSystem(args)
    table_sys = TableSystem(args)
    img_num = len(image_file_list)

    f_html = open(
        os.path.join(args.output, 'show.html'), mode='w', encoding='utf-8')
    f_html.write('<html>\n<body>\n')
    f_html.write('<table border="1">\n')
    f_html.write(
        "<meta http-equiv=\"Content-Type\" content=\"text/html; charset=utf-8\" />"
    )
    f_html.write("<tr>\n")
    f_html.write('<td>img name\n')
    f_html.write('<td>ori image</td>')
    f_html.write('<td>table html</td>')
    f_html.write('<td>cell box</td>')
    f_html.write("</tr>\n")

    for i, image_file in enumerate(image_file_list):
        logger.info("[{}/{}] {}".format(i, img_num, image_file))
        img, flag = check_and_read_gif(image_file)

@@ -249,13 +203,35 @@ def main(args):
            logger.error("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        pred_res = text_sys(img)
        pred_res, _ = table_sys(img)
        pred_html = pred_res['html']
        logger.info(pred_html)
        to_excel(pred_html, excel_path)
        logger.info('excel saved to {}'.format(excel_path))
        elapse = time.time() - starttime
        logger.info("Predict time : {:.3f}s".format(elapse))

        if len(pred_res['cell_bbox']) > 0 and len(pred_res['cell_bbox'][
                0]) == 4:
            img = predict_strture.draw_rectangle(image_file,
                                                 pred_res['cell_bbox'])
        else:
            img = utility.draw_boxes(img, pred_res['cell_bbox'])
        img_save_path = os.path.join(args.output, os.path.basename(image_file))
        cv2.imwrite(img_save_path, img)

        f_html.write("<tr>\n")
        f_html.write(f'<td> {os.path.basename(image_file)} <br/>\n')
        f_html.write(f'<td><img src="{image_file}" width=640></td>\n')
        f_html.write('<td><table border="1">' + pred_html.replace(
            '<html><body><table>', '').replace('</table></body></html>', '') +
                     '</table></td>\n')
        f_html.write(
            f'<td><img src="{os.path.basename(image_file)}" width=640></td>\n')
        f_html.write("</tr>\n")
    f_html.write("</table>\n")
    f_html.close()

    if args.benchmark:
        text_sys.autolog.report()

@@ -0,0 +1,953 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/match.py
"""

import os
import re
import cv2
import glob
import copy
import math
import pickle
import numpy as np

from shapely.geometry import Polygon, MultiPoint
"""
Useful functions for matching.
"""


def remove_empty_bboxes(bboxes):
    """
    remove [0., 0., 0., 0.] in structure master bboxes.
    len(bboxes.shape) must be 2.
    :param bboxes:
    :return:
    """
    new_bboxes = []
    for bbox in bboxes:
        if sum(bbox) == 0.:
            continue
        new_bboxes.append(bbox)
    return np.array(new_bboxes)


def xywh2xyxy(bboxes):
    if len(bboxes.shape) == 1:
        new_bboxes = np.empty_like(bboxes)
        new_bboxes[0] = bboxes[0] - bboxes[2] / 2
        new_bboxes[1] = bboxes[1] - bboxes[3] / 2
        new_bboxes[2] = bboxes[0] + bboxes[2] / 2
        new_bboxes[3] = bboxes[1] + bboxes[3] / 2
        return new_bboxes
    elif len(bboxes.shape) == 2:
        new_bboxes = np.empty_like(bboxes)
        new_bboxes[:, 0] = bboxes[:, 0] - bboxes[:, 2] / 2
        new_bboxes[:, 1] = bboxes[:, 1] - bboxes[:, 3] / 2
        new_bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2] / 2
        new_bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3] / 2
        return new_bboxes
    else:
        raise ValueError


def xyxy2xywh(bboxes):
    if len(bboxes.shape) == 1:
        new_bboxes = np.empty_like(bboxes)
        new_bboxes[0] = bboxes[0] + (bboxes[2] - bboxes[0]) / 2
        new_bboxes[1] = bboxes[1] + (bboxes[3] - bboxes[1]) / 2
        new_bboxes[2] = bboxes[2] - bboxes[0]
        new_bboxes[3] = bboxes[3] - bboxes[1]
        return new_bboxes
    elif len(bboxes.shape) == 2:
        new_bboxes = np.empty_like(bboxes)
        new_bboxes[:, 0] = bboxes[:, 0] + (bboxes[:, 2] - bboxes[:, 0]) / 2
        new_bboxes[:, 1] = bboxes[:, 1] + (bboxes[:, 3] - bboxes[:, 1]) / 2
        new_bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
        new_bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
        return new_bboxes
    else:
        raise ValueError
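
# Round-trip sanity check for the two converters (handmade values):
def _demo_bbox_convert():
    xyxy = np.array([10., 20., 30., 60.])
    xywh = xyxy2xywh(xyxy)  # [20., 40., 20., 40.] as (cx, cy, w, h)
    back = xywh2xyxy(xywh)  # [10., 20., 30., 60.] again
    return xywh, back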


def pickle_load(path, prefix='end2end'):
    if os.path.isfile(path):
        data = pickle.load(open(path, 'rb'))
    elif os.path.isdir(path):
        data = dict()
        search_path = os.path.join(path, '{}_*.pkl'.format(prefix))
        pkls = glob.glob(search_path)
        for pkl in pkls:
            this_data = pickle.load(open(pkl, 'rb'))
            data.update(this_data)
    else:
        raise ValueError
    return data


def convert_coord(xyxy):
    """
    Convert two points format to four points format.
    :param xyxy:
    :return:
    """
    new_bbox = np.zeros([4, 2], dtype=np.float32)
    new_bbox[0, 0], new_bbox[0, 1] = xyxy[0], xyxy[1]
    new_bbox[1, 0], new_bbox[1, 1] = xyxy[2], xyxy[1]
    new_bbox[2, 0], new_bbox[2, 1] = xyxy[2], xyxy[3]
    new_bbox[3, 0], new_bbox[3, 1] = xyxy[0], xyxy[3]
    return new_bbox


def cal_iou(bbox1, bbox2):
    bbox1_poly = Polygon(bbox1).convex_hull
    bbox2_poly = Polygon(bbox2).convex_hull
    union_poly = np.concatenate((bbox1, bbox2))

    if not bbox1_poly.intersects(bbox2_poly):
        iou = 0
    else:
        inter_area = bbox1_poly.intersection(bbox2_poly).area
        union_area = MultiPoint(union_poly).convex_hull.area
        if union_area == 0:
            iou = 0
        else:
            iou = float(inter_area) / union_area
    return iou


def cal_distance(p1, p2):
    delta_x = p1[0] - p2[0]
    delta_y = p1[1] - p2[1]
    d = math.sqrt((delta_x**2) + (delta_y**2))
    return d


def is_inside(center_point, corner_point):
    """
    Find out whether center_point is inside the bbox (corner_point) or not.
    :param center_point: center point (x, y)
    :param corner_point: corner point ((x1, y1), (x2, y2))
    :return:
    """
    x_flag = False
    y_flag = False
    if (center_point[0] >= corner_point[0][0]) and (
            center_point[0] <= corner_point[1][0]):
        x_flag = True
    if (center_point[1] >= corner_point[0][1]) and (
            center_point[1] <= corner_point[1][1]):
        y_flag = True
    if x_flag and y_flag:
        return True
    else:
        return False


def find_no_match(match_list, all_end2end_nums, type='end2end'):
    """
    Find out the no-match end2end bboxes in the previous match list.
    :param match_list: matching pairs.
    :param all_end2end_nums: number of end2end_xywh boxes
    :param type: 'end2end' corresponds to idx 0, 'master' corresponds to idx 1.
    :return: no-match pse bbox index list
    """
    if type == 'end2end':
        idx = 0
    elif type == 'master':
        idx = 1
    else:
        raise ValueError

    no_match_indexs = []
    # m[0] is the end2end index, m[1] is the master index
    matched_bbox_indexs = [m[idx] for m in match_list]
    for n in range(all_end2end_nums):
        if n not in matched_bbox_indexs:
            no_match_indexs.append(n)
    return no_match_indexs


def is_abs_lower_than_threshold(this_bbox, target_bbox, threshold=3):
    # only consider the y axis, for grouping into rows.
    delta = abs(this_bbox[1] - target_bbox[1])
    if delta < threshold:
        return True
    else:
        return False


def sort_line_bbox(g, bg):
    """
    Sort the bboxes in the same line (group)
    by comparing the coord 'x' value, since the 'y' values are close within a group.
    :param g: indexes in the same group
    :param bg: bboxes in the same group
    :return:
    """

    xs = [bg_item[0] for bg_item in bg]
    xs_sorted = sorted(xs)

    g_sorted = [None] * len(xs_sorted)
    bg_sorted = [None] * len(xs_sorted)
    for g_item, bg_item in zip(g, bg):
        idx = xs_sorted.index(bg_item[0])
        bg_sorted[idx] = bg_item
        g_sorted[idx] = g_item

    return g_sorted, bg_sorted


def flatten(sorted_groups, sorted_bbox_groups):
    idxs = []
    bboxes = []
    for group, bbox_group in zip(sorted_groups, sorted_bbox_groups):
        for g, bg in zip(group, bbox_group):
            idxs.append(g)
            bboxes.append(bg)
    return idxs, bboxes


def sort_bbox(end2end_xywh_bboxes, no_match_end2end_indexes):
    """
    This function groups the rendered end2end bboxes into rows.
    :param end2end_xywh_bboxes:
    :param no_match_end2end_indexes:
    :return:
    """
    groups = []
    bbox_groups = []
    for index, end2end_xywh_bbox in zip(no_match_end2end_indexes,
                                        end2end_xywh_bboxes):
        this_bbox = end2end_xywh_bbox
        if len(groups) == 0:
            groups.append([index])
            bbox_groups.append([this_bbox])
        else:
            flag = False
            for g, bg in zip(groups, bbox_groups):
                # does this_bbox belong to bg's row or not?
                if is_abs_lower_than_threshold(this_bbox, bg[0]):
                    g.append(index)
                    bg.append(this_bbox)
                    flag = True
                    break
            if not flag:
                # this_bbox does not belong to any existing row, create a new one.
                groups.append([index])
                bbox_groups.append([this_bbox])

    # sort the bboxes within a group
    tmp_groups, tmp_bbox_groups = [], []
    for g, bg in zip(groups, bbox_groups):
        g_sorted, bg_sorted = sort_line_bbox(g, bg)
        tmp_groups.append(g_sorted)
        tmp_bbox_groups.append(bg_sorted)

    # sort the groups by their coord y value.
    sorted_groups = [None] * len(tmp_groups)
    sorted_bbox_groups = [None] * len(tmp_bbox_groups)
    ys = [bg[0][1] for bg in tmp_bbox_groups]
    sorted_ys = sorted(ys)
    for g, bg in zip(tmp_groups, tmp_bbox_groups):
        idx = sorted_ys.index(bg[0][1])
        sorted_groups[idx] = g
        sorted_bbox_groups[idx] = bg

    # flatten to get the final result
    end2end_sorted_idx_list, end2end_sorted_bbox_list \
        = flatten(sorted_groups, sorted_bbox_groups)

    return end2end_sorted_idx_list, end2end_sorted_bbox_list, sorted_groups, sorted_bbox_groups
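
# Hedged sketch of sort_bbox(): xywh boxes are clustered into rows (y values
# within 3 px of the row's first box), each row is sorted by x, and rows are
# sorted by y. Toy input; the indexes come back in reading order:
def _demo_sort_bbox():
    bboxes = np.array([
        [80., 11., 20., 10.],  # row 1, right
        [10., 10., 20., 10.],  # row 1, left
        [10., 40., 20., 10.],  # row 2
    ])
    idxs, _, _, _ = sort_bbox(bboxes, [0, 1, 2])
    return idxs  # [1, 0, 2]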


def get_bboxes_list(end2end_result, structure_master_result):
    """
    This function is used to convert end2end results and structure master results
    to a list of xyxy bboxes and a list of xywh bboxes.
    :param end2end_result: bbox's format is xyxy
    :param structure_master_result: bbox's format is xywh
    :return: 4 kinds of bbox lists
    """
    # end2end
    end2end_xyxy_list = []
    end2end_xywh_list = []
    for end2end_item in end2end_result:
        src_bbox = end2end_item['bbox']
        end2end_xyxy_list.append(src_bbox)
        xywh_bbox = xyxy2xywh(src_bbox)
        end2end_xywh_list.append(xywh_bbox)
    end2end_xyxy_bboxes = np.array(end2end_xyxy_list)
    end2end_xywh_bboxes = np.array(end2end_xywh_list)

    # structure master
    src_bboxes = structure_master_result['bbox']
    src_bboxes = remove_empty_bboxes(src_bboxes)
    structure_master_xyxy_bboxes = src_bboxes
    xywh_bbox = xyxy2xywh(src_bboxes)
    structure_master_xywh_bboxes = xywh_bbox

    return end2end_xyxy_bboxes, end2end_xywh_bboxes, structure_master_xywh_bboxes, structure_master_xyxy_bboxes


def center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes):
    """
    Judge whether an end2end bbox's center point is inside a structure master bbox;
    if it is, record the matching pair.
    :param end2end_xywh_bboxes:
    :param structure_master_xyxy_bboxes:
    :return: match pairs list, e.g. [[0,1], [1,2], ...]
    """
    match_pairs_list = []
    for i, end2end_xywh in enumerate(end2end_xywh_bboxes):
        for j, master_xyxy in enumerate(structure_master_xyxy_bboxes):
            x_end2end, y_end2end = end2end_xywh[0], end2end_xywh[1]
            x_master1, y_master1, x_master2, y_master2 \
                = master_xyxy[0], master_xyxy[1], master_xyxy[2], master_xyxy[3]
            center_point_end2end = (x_end2end, y_end2end)
            corner_point_master = ((x_master1, y_master1),
                                   (x_master2, y_master2))
            if is_inside(center_point_end2end, corner_point_master):
                match_pairs_list.append([i, j])
    return match_pairs_list


def iou_rule_match(end2end_xyxy_bboxes, end2end_xyxy_indexes,
                   structure_master_xyxy_bboxes):
    """
    Use IoU to find the matching list,
    choosing the bbox with the max IoU value as the match pair.
    :param end2end_xyxy_bboxes:
    :param end2end_xyxy_indexes: original end2end indexes.
    :param structure_master_xyxy_bboxes:
    :return: match pairs list, e.g. [[0,1], [1,2], ...]
    """
    match_pair_list = []
    for end2end_xyxy_index, end2end_xyxy in zip(end2end_xyxy_indexes,
                                                end2end_xyxy_bboxes):
        max_iou = 0
        max_match = [None, None]
        for j, master_xyxy in enumerate(structure_master_xyxy_bboxes):
            end2end_4xy = convert_coord(end2end_xyxy)
            master_4xy = convert_coord(master_xyxy)
            iou = cal_iou(end2end_4xy, master_4xy)
            if iou > max_iou:
                max_match[0], max_match[1] = end2end_xyxy_index, j
                max_iou = iou

        if max_match[0] is None:
            # no match
            continue
        match_pair_list.append(max_match)
    return match_pair_list


def distance_rule_match(end2end_indexes, end2end_bboxes, master_indexes,
                        master_bboxes):
    """
    Get matching between no-match end2end bboxes and no-match master bboxes,
    using min distance to match.
    This rule only runs when (no-match end2end nums > 0) and (no-match master nums > 0).
    It will return master_bboxes_nums match pairs.
    :param end2end_indexes:
    :param end2end_bboxes:
    :param master_indexes:
    :param master_bboxes:
    :return: match_pairs list, e.g. [[0,1], [1,2], ...]
    """
    min_match_list = []
    for j, master_bbox in zip(master_indexes, master_bboxes):
        min_distance = np.inf
        min_match = [0, 0]  # i, j
        for i, end2end_bbox in zip(end2end_indexes, end2end_bboxes):
            x_end2end, y_end2end = end2end_bbox[0], end2end_bbox[1]
            x_master, y_master = master_bbox[0], master_bbox[1]
            end2end_point = (x_end2end, y_end2end)
            master_point = (x_master, y_master)
            dist = cal_distance(master_point, end2end_point)
            if dist < min_distance:
                min_match[0], min_match[1] = i, j
                min_distance = dist
        min_match_list.append(min_match)
    return min_match_list


def extra_match(no_match_end2end_indexes, master_bbox_nums):
    """
    This function creates some virtual master bboxes
    and matches them with the no-match end2end indexes.
    :param no_match_end2end_indexes:
    :param master_bbox_nums:
    :return:
    """
    end_nums = len(no_match_end2end_indexes) + master_bbox_nums
    extra_match_list = []
    for i in range(master_bbox_nums, end_nums):
        end2end_index = no_match_end2end_indexes[i - master_bbox_nums]
        extra_match_list.append([end2end_index, i])
    return extra_match_list


def get_match_dict(match_list):
    """
    Convert match_list to a dict, where the key is the master bbox's index and
    the value is the list of matched end2end bbox indexes.
    :param match_list:
    :return:
    """
    match_dict = dict()
    for match_pair in match_list:
        end2end_index, master_index = match_pair[0], match_pair[1]
        if master_index not in match_dict.keys():
            match_dict[master_index] = [end2end_index]
        else:
            match_dict[master_index].append(end2end_index)
    return match_dict


def deal_successive_space(text):
    """
    deal with successive space characters in text
    1. Replace ' '*3 with '<space>', which marks a real space in the text
    2. Remove ' ', which is a split token, not a true space
    3. Replace '<space>' with ' ' to get the real text
    :param text:
    :return:
    """
    text = text.replace(' ' * 3, '<space>')
    text = text.replace(' ', '')
    text = text.replace('<space>', ' ')
    return text
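
# Worked example of the three-step space normalisation (input made up):
# three spaces mark a real space, a single space is only a split token.
def _demo_deal_successive_space():
    return deal_successive_space('a   b c')  # 'a<space>b c' -> 'a<space>bc' -> 'a bc'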


def reduce_repeat_bb(text_list, break_token):
    """
    convert ['<b>Local</b>', '<b>government</b>', '<b>unit</b>'] to ['<b>Local government unit</b>']
    PS: a style like <i>Local</i> may also exist; it can be processed the same way.
    :param text_list:
    :param break_token:
    :return:
    """
    count = 0
    for text in text_list:
        if text.startswith('<b>'):
            count += 1
    if count == len(text_list):
        new_text_list = []
        for text in text_list:
            text = text.replace('<b>', '').replace('</b>', '')
            new_text_list.append(text)
        return ['<b>' + break_token.join(new_text_list) + '</b>']
    else:
        return text_list
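
# Worked example: per-word bold tags collapse into one bold span only when
# every token is bold; otherwise the list comes back unchanged.
def _demo_reduce_repeat_bb():
    merged = reduce_repeat_bb(['<b>Local</b>', '<b>government</b>'], ' ')
    # merged == ['<b>Local government</b>']
    mixed = reduce_repeat_bb(['<b>Local</b>', 'government'], ' ')
    # mixed is returned unchanged
    return merged, mixed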


def get_match_text_dict(match_dict, end2end_info, break_token=' '):
    match_text_dict = dict()
    for master_index, end2end_index_list in match_dict.items():
        text_list = [
            end2end_info[end2end_index]['text']
            for end2end_index in end2end_index_list
        ]
        text_list = reduce_repeat_bb(text_list, break_token)
        text = break_token.join(text_list)
        match_text_dict[master_index] = text
    return match_text_dict


def merge_span_token(master_token_list):
    """
    Merge the span style tokens (row span or col span).
    :param master_token_list:
    :return:
    """
    new_master_token_list = []
    pointer = 0
    if master_token_list[-1] != '</tbody>':
        master_token_list.append('</tbody>')
    while master_token_list[pointer] != '</tbody>':
        try:
            if master_token_list[pointer] == '<td':
                if master_token_list[pointer + 1].startswith(
                        ' colspan=') or master_token_list[
                            pointer + 1].startswith(' rowspan='):
                    """
                    example:
                    pattern <td colspan="3">
                    '<td' + 'colspan=" "' + '>' + '</td>'
                    """
                    tmp = ''.join(master_token_list[pointer:pointer + 3 + 1])
                    pointer += 4
                    new_master_token_list.append(tmp)

                elif master_token_list[pointer + 2].startswith(
                        ' colspan=') or master_token_list[
                            pointer + 2].startswith(' rowspan='):
                    """
                    example:
                    pattern <td rowspan="2" colspan="3">
                    '<td' + 'rowspan=" "' + 'colspan=" "' + '>' + '</td>'
                    """
                    tmp = ''.join(master_token_list[pointer:pointer + 4 + 1])
                    pointer += 5
                    new_master_token_list.append(tmp)

                else:
                    new_master_token_list.append(master_token_list[pointer])
                    pointer += 1
            else:
                new_master_token_list.append(master_token_list[pointer])
                pointer += 1
        except:
            print("Break in merge...")
            break
    new_master_token_list.append('</tbody>')

    return new_master_token_list
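
# Worked example of span merging: the split '<td', ' colspan="2"', '>' and
# '</td>' tokens become one cell token, and '</tbody>' is appended if missing:
def _demo_merge_span_token():
    tokens = ['<td', ' colspan="2"', '>', '</td>', '<td></td>']
    return merge_span_token(tokens)
    # -> ['<td colspan="2"></td>', '<td></td>', '</tbody>']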


def deal_eb_token(master_token):
    """
    post process with <eb></eb>, <eb1></eb1>, ...
    emptyBboxTokenDict = {
        "[]": '<eb></eb>',
        "[' ']": '<eb1></eb1>',
        "['<b>', ' ', '</b>']": '<eb2></eb2>',
        "['\\u2028', '\\u2028']": '<eb3></eb3>',
        "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
        "['<b>', '</b>']": '<eb5></eb5>',
        "['<i>', ' ', '</i>']": '<eb6></eb6>',
        "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
        "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
        "['<i>', '</i>']": '<eb9></eb9>',
        "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']": '<eb10></eb10>',
    }
    :param master_token:
    :return:
    """
    master_token = master_token.replace('<eb></eb>', '<td></td>')
    master_token = master_token.replace('<eb1></eb1>', '<td> </td>')
    master_token = master_token.replace('<eb2></eb2>', '<td><b> </b></td>')
    master_token = master_token.replace('<eb3></eb3>', '<td>\u2028\u2028</td>')
    master_token = master_token.replace('<eb4></eb4>', '<td><sup> </sup></td>')
    master_token = master_token.replace('<eb5></eb5>', '<td><b></b></td>')
    master_token = master_token.replace('<eb6></eb6>', '<td><i> </i></td>')
    master_token = master_token.replace('<eb7></eb7>',
                                        '<td><b><i></i></b></td>')
    master_token = master_token.replace('<eb8></eb8>',
                                        '<td><b><i> </i></b></td>')
    master_token = master_token.replace('<eb9></eb9>', '<td><i></i></td>')
    master_token = master_token.replace('<eb10></eb10>',
                                        '<td><b> \u2028 \u2028 </b></td>')
    return master_token
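
# Worked example: empty-bbox placeholder tokens expand back to literal cells.
def _demo_deal_eb_token():
    return deal_eb_token('<eb></eb><eb2></eb2>')  # '<td></td><td><b> </b></td>'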
|
||||
|
||||
|
||||
def insert_text_to_token(master_token_list, match_text_dict):
|
||||
"""
|
||||
Insert OCR text result to structure token.
|
||||
:param master_token_list:
|
||||
:param match_text_dict:
|
||||
:return:
|
||||
"""
|
||||
master_token_list = merge_span_token(master_token_list)
|
||||
merged_result_list = []
|
||||
text_count = 0
|
||||
for master_token in master_token_list:
|
||||
if master_token.startswith('<td'):
|
||||
if text_count > len(match_text_dict) - 1:
|
||||
text_count += 1
|
||||
continue
|
||||
elif text_count not in match_text_dict.keys():
|
||||
text_count += 1
|
||||
continue
|
||||
else:
|
||||
master_token = master_token.replace(
|
||||
'><', '>{}<'.format(match_text_dict[text_count]))
|
||||
text_count += 1
|
||||
master_token = deal_eb_token(master_token)
|
||||
merged_result_list.append(master_token)
|
||||
|
||||
return ''.join(merged_result_list)
|
||||
|
||||
|
||||
def deal_isolate_span(thead_part):
|
||||
"""
|
||||
Deal with isolate span cases in this function.
|
||||
It causes by wrong prediction in structure recognition model.
|
||||
eg. predict <td rowspan="2"></td> to <td></td> rowspan="2"></b></td>.
|
||||
:param thead_part:
|
||||
:return:
|
||||
"""
|
||||
# 1. find out isolate span tokens.
|
||||
isolate_pattern = "<td></td> rowspan=\"(\d)+\" colspan=\"(\d)+\"></b></td>|" \
|
||||
"<td></td> colspan=\"(\d)+\" rowspan=\"(\d)+\"></b></td>|" \
|
||||
"<td></td> rowspan=\"(\d)+\"></b></td>|" \
|
||||
"<td></td> colspan=\"(\d)+\"></b></td>"
|
||||
isolate_iter = re.finditer(isolate_pattern, thead_part)
|
||||
isolate_list = [i.group() for i in isolate_iter]
|
||||
|
||||
# 2. find out span number, by step 1 results.
|
||||
span_pattern = " rowspan=\"(\d)+\" colspan=\"(\d)+\"|" \
|
||||
" colspan=\"(\d)+\" rowspan=\"(\d)+\"|" \
|
||||
" rowspan=\"(\d)+\"|" \
|
||||
" colspan=\"(\d)+\""
|
||||
corrected_list = []
|
||||
for isolate_item in isolate_list:
|
||||
span_part = re.search(span_pattern, isolate_item)
|
||||
spanStr_in_isolateItem = span_part.group()
|
||||
# 3. merge the span number into the span token format string.
|
||||
if spanStr_in_isolateItem is not None:
|
||||
corrected_item = '<td{}></td>'.format(spanStr_in_isolateItem)
|
||||
corrected_list.append(corrected_item)
|
||||
else:
|
||||
corrected_list.append(None)
|
||||
|
||||
# 4. replace original isolated token.
|
||||
for corrected_item, isolate_item in zip(corrected_list, isolate_list):
|
||||
if corrected_item is not None:
|
||||
thead_part = thead_part.replace(isolate_item, corrected_item)
|
||||
else:
|
||||
pass
|
||||
return thead_part
|
||||
|
||||
|
||||
def deal_duplicate_bb(thead_part):
|
||||
"""
|
||||
Deal duplicate <b> or </b> after replace.
|
||||
Keep one <b></b> in a <td></td> token.
|
||||
:param thead_part:
|
||||
:return:
|
||||
"""
|
||||
# 1. find out <td></td> in <thead></thead>.
|
||||
td_pattern = "<td rowspan=\"(\d)+\" colspan=\"(\d)+\">(.+?)</td>|" \
|
||||
"<td colspan=\"(\d)+\" rowspan=\"(\d)+\">(.+?)</td>|" \
|
||||
"<td rowspan=\"(\d)+\">(.+?)</td>|" \
|
||||
"<td colspan=\"(\d)+\">(.+?)</td>|" \
|
||||
"<td>(.*?)</td>"
|
||||
td_iter = re.finditer(td_pattern, thead_part)
|
||||
td_list = [t.group() for t in td_iter]
|
||||
|
||||
# 2. is multiply <b></b> in <td></td> or not?
|
||||
new_td_list = []
|
||||
for td_item in td_list:
|
||||
if td_item.count('<b>') > 1 or td_item.count('</b>') > 1:
|
||||
# multiply <b></b> in <td></td> case.
|
||||
# 1. remove all <b></b>
|
||||
td_item = td_item.replace('<b>', '').replace('</b>', '')
|
||||
# 2. replace <tb> -> <tb><b>, </tb> -> </b></tb>.
|
||||
td_item = td_item.replace('<td>', '<td><b>').replace('</td>',
|
||||
'</b></td>')
|
||||
new_td_list.append(td_item)
|
||||
else:
|
||||
new_td_list.append(td_item)
|
||||
|
||||
# 3. replace original thead part.
|
||||
for td_item, new_td_item in zip(td_list, new_td_list):
|
||||
thead_part = thead_part.replace(td_item, new_td_item)
|
||||
return thead_part
|
||||
|
||||
|
||||
def deal_bb(result_token):
    """
    In our opinion, <b></b> always occurs in the text context of <thead></thead>.
    This function finds all tokens inside <thead></thead> and inserts <b></b> manually.
    :param result_token:
    :return:
    """
    # find the <thead></thead> part.
    thead_pattern = '<thead>(.*?)</thead>'
    if re.search(thead_pattern, result_token) is None:
        return result_token
    thead_part = re.search(thead_pattern, result_token).group()
    origin_thead_part = copy.deepcopy(thead_part)

    # check whether "rowspan" or "colspan" occurs in the <thead></thead> part.
    span_pattern = "<td rowspan=\"(\d)+\" colspan=\"(\d)+\">|<td colspan=\"(\d)+\" rowspan=\"(\d)+\">|<td rowspan=\"(\d)+\">|<td colspan=\"(\d)+\">"
    span_iter = re.finditer(span_pattern, thead_part)
    span_list = [s.group() for s in span_iter]
    has_span_in_head = len(span_list) > 0

    if not has_span_in_head:
        # branch 1: <thead></thead> does not include "rowspan" or "colspan".
        # 1. replace <td> with <td><b>, and </td> with </b></td>
        # 2. text-line recognition may already predict <b> or </b>,
        #    so collapse <b><b> to <b>, and </b></b> to </b>
        thead_part = thead_part.replace('<td>', '<td><b>')\
            .replace('</td>', '</b></td>')\
            .replace('<b><b>', '<b>')\
            .replace('</b></b>', '</b>')
    else:
        # branch 2: <thead></thead> includes "rowspan" or "colspan".
        # First, deal with the rowspan/colspan cells:
        # 1. replace > with ><b>
        # 2. replace </td> with </b></td>
        # 3. text-line recognition may already predict <b> or </b>,
        #    so collapse duplicated <b> and </b> tokens.
        # Second, deal with the ordinary cells like branch 1.

        # replace ">" with "><b>"
        replaced_span_list = []
        for sp in span_list:
            replaced_span_list.append(sp.replace('>', '><b>'))
        for sp, rsp in zip(span_list, replaced_span_list):
            thead_part = thead_part.replace(sp, rsp)

        # replace "</td>" with "</b></td>"
        thead_part = thead_part.replace('</td>', '</b></td>')

        # remove duplicated <b> with re.sub
        mb_pattern = "(<b>)+"
        single_b_string = "<b>"
        thead_part = re.sub(mb_pattern, single_b_string, thead_part)

        mgb_pattern = "(</b>)+"
        single_gb_string = "</b>"
        thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)

        # ordinary cases like branch 1
        thead_part = thead_part.replace('<td>', '<td><b>').replace('<b><b>',
                                                                   '<b>')

    # convert <td><b></b></td> back to <td></td>: an empty cell has no <b></b>,
    # but a space cell (<td> </td>) keeps the form <td><b> </b></td>.
    thead_part = thead_part.replace('<td><b></b></td>', '<td></td>')
    # deal with duplicated <b></b>
    thead_part = deal_duplicate_bb(thead_part)
    # deal with isolated span tokens, which are caused by wrong structure predictions.
    # e.g. PMC5994107_011_00.png
    thead_part = deal_isolate_span(thead_part)
    # replace the original thead part with the new one.
    result_token = result_token.replace(origin_thead_part, thead_part)
    return result_token

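Note: a quick illustration of deal_bb on a toy token string (illustration only, not part of the diff; it assumes this module is importable so deal_bb can be called directly):

    # Hedged example: bold-wrapping a plain header row.
    tokens = ('<thead><tr><td>Name</td><td>Age</td></tr></thead>'
              '<tbody><tr><td>foo</td><td>1</td></tr></tbody>')
    print(deal_bb(tokens))
    # -> <thead><tr><td><b>Name</b></td><td><b>Age</b></td></tr></thead>
    #    <tbody><tr><td>foo</td><td>1</td></tr></tbody>   (tbody untouched)
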
class Matcher:
    def __init__(self, end2end_file, structure_master_file):
        """
        This class matches the end2end results with the structure recognition results.
        :param end2end_file: end2end results predicted by end2end inference.
        :param structure_master_file: structure recognition results predicted by structure master inference.
        """
        self.end2end_file = end2end_file
        self.structure_master_file = structure_master_file
        self.end2end_results = pickle_load(end2end_file, prefix='end2end')
        self.structure_master_results = pickle_load(
            structure_master_file, prefix='structure')

    def match(self):
        """
        Match process:
        pre-process : convert end2end and structure master results to xyxy and xywh ndarray formats.
        1. center rule: is the pseBbox center inside the masterBbox?
        2. iou rule: iou between pseBbox and masterBbox.
        3. distance rule: minimum distance between center points.
        :return:
        """
        match_results = dict()
        for idx, (file_name,
                  end2end_result) in enumerate(self.end2end_results.items()):
            match_list = []
            if file_name not in self.structure_master_results:
                continue
            structure_master_result = self.structure_master_results[file_name]
            end2end_xyxy_bboxes, end2end_xywh_bboxes, structure_master_xywh_bboxes, structure_master_xyxy_bboxes = \
                get_bboxes_list(end2end_result, structure_master_result)

            # rule 1: center rule
            center_rule_match_list = \
                center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes)
            match_list.extend(center_rule_match_list)

            # rule 2: iou rule
            # first, find the indexes not matched in the previous step.
            center_no_match_end2end_indexs = \
                find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end')
            if len(center_no_match_end2end_indexs) > 0:
                center_no_match_end2end_xyxy = end2end_xyxy_bboxes[
                    center_no_match_end2end_indexs]
                # second, match by the iou rule
                iou_rule_match_list = \
                    iou_rule_match(center_no_match_end2end_xyxy, center_no_match_end2end_indexs, structure_master_xyxy_bboxes)
                match_list.extend(iou_rule_match_list)

            # rule 3: distance rule
            # match between unmatched end2end bboxes and unmatched master bboxes;
            # it returns master_bboxes_nums match pairs.
            # first, find the indexes not matched in the previous steps.
            centerIou_no_match_end2end_indexs = \
                find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end')
            centerIou_no_match_master_indexs = \
                find_no_match(match_list, len(structure_master_xywh_bboxes), type='master')
            if len(centerIou_no_match_master_indexs) > 0 and len(
                    centerIou_no_match_end2end_indexs) > 0:
                centerIou_no_match_end2end_xywh = end2end_xywh_bboxes[
                    centerIou_no_match_end2end_indexs]
                centerIou_no_match_master_xywh = structure_master_xywh_bboxes[
                    centerIou_no_match_master_indexs]
                distance_match_list = distance_rule_match(
                    centerIou_no_match_end2end_indexs,
                    centerIou_no_match_end2end_xywh,
                    centerIou_no_match_master_indexs,
                    centerIou_no_match_master_xywh)
                match_list.extend(distance_match_list)

            # TODO: handle the remaining unmatched pseBboxes by inserting them at the end.
            # After the step-3 distance rule, every master bbox matches at least one end2end bbox,
            # but extra end2end bboxes may remain, because the number of master bboxes
            # is cut by the max sequence length.
            # For these remaining end2end bboxes, we make some virtual master bboxes and match against them.
            # The extra inserted bboxes are further processed in the _format function.
            # This operation increases the TEDS score.
            no_match_end2end_indexes = \
                find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end')
            if len(no_match_end2end_indexes) > 0:
                no_match_end2end_xywh = end2end_xywh_bboxes[
                    no_match_end2end_indexes]
                # sort the remaining unmatched end2end bboxes into rows
                end2end_sorted_indexes_list, end2end_sorted_bboxes_list, sorted_groups, sorted_bboxes_groups = \
                    sort_bbox(no_match_end2end_xywh, no_match_end2end_indexes)
                # make virtual master bboxes and match them with the unmatched end2end bboxes.
                extra_match_list = extra_match(
                    end2end_sorted_indexes_list,
                    len(structure_master_xywh_bboxes))
                match_list_add_extra_match = copy.deepcopy(match_list)
                match_list_add_extra_match.extend(extra_match_list)
            else:
                # no unmatched end2end bboxes
                match_list_add_extra_match = copy.deepcopy(match_list)
                sorted_groups = []
                sorted_bboxes_groups = []

            match_result_dict = {
                'match_list': match_list,
                'match_list_add_extra_match': match_list_add_extra_match,
                'sorted_groups': sorted_groups,
                'sorted_bboxes_groups': sorted_bboxes_groups
            }

            # format output
            match_result_dict = self._format(match_result_dict, file_name)

            match_results[file_name] = match_result_dict

        return match_results

    def _format(self, match_result, file_name):
        """
        Extend the master token (insert virtual master tokens), and format the matching result.
        :param match_result:
        :param file_name:
        :return:
        """
        end2end_info = self.end2end_results[file_name]
        master_info = self.structure_master_results[file_name]
        master_token = master_info['text']
        sorted_groups = match_result['sorted_groups']

        # create virtual master tokens
        virtual_master_token_list = []
        for line_group in sorted_groups:
            tmp_list = ['<tr>']
            item_nums = len(line_group)
            for _ in range(item_nums):
                tmp_list.append('<td></td>')
            tmp_list.append('</tr>')
            virtual_master_token_list.extend(tmp_list)

        # insert virtual master tokens
        master_token_list = master_token.split(',')
        if master_token_list[-1] == '</tbody>':
            # complete prediction (not cut by max length).
            # Inserting virtual master tokens here would drop the TEDS score on the
            # val set, so we do not extend virtual tokens in this situation.

            # fake extend: this extends a temporary slice, so master_token_list
            # is deliberately left unchanged.
            master_token_list[:-1].extend(virtual_master_token_list)

            # real extend (kept for reference):
            # master_token_list = master_token_list[:-1]
            # master_token_list.extend(virtual_master_token_list)
            # master_token_list.append('</tbody>')

        elif master_token_list[-1] == '<td></td>':
            master_token_list.append('</tr>')
            master_token_list.extend(virtual_master_token_list)
            master_token_list.append('</tbody>')
        else:
            master_token_list.extend(virtual_master_token_list)
            master_token_list.append('</tbody>')

        # format output
        match_result.setdefault('matched_master_token_list', master_token_list)
        return match_result

    def get_merge_result(self, match_results):
        """
        Merge the OCR results into the structure tokens to get the final results.
        :param match_results:
        :return:
        """
        merged_results = dict()

        # break_token separates text lines when one master bbox matches multiple end2end bboxes.
        break_token = ' '

        for idx, (file_name, match_info) in enumerate(match_results.items()):
            end2end_info = self.end2end_results[file_name]
            master_token_list = match_info['matched_master_token_list']
            match_list = match_info['match_list_add_extra_match']

            match_dict = get_match_dict(match_list)
            match_text_dict = get_match_text_dict(match_dict, end2end_info,
                                                  break_token)
            merged_result = insert_text_to_token(master_token_list,
                                                 match_text_dict)
            merged_result = deal_bb(merged_result)

            merged_results[file_name] = merged_result

        return merged_results

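Note: the three matching rules live in helper functions defined elsewhere in this file. A toy re-implementation of rule 1, for intuition only (the repo's center_rule_match is the authoritative version; this sketch assumes the xywh format stores the box center):

    import numpy as np

    def toy_center_rule(end2end_xywh, master_xyxy):
        """Match each end2end box to every master box that contains its center."""
        matches = []
        for i, (cx, cy, _, _) in enumerate(end2end_xywh):
            for j, (x1, y1, x2, y2) in enumerate(master_xyxy):
                if x1 <= cx <= x2 and y1 <= cy <= y2:
                    matches.append((i, j))
        return matches

    end2end = np.array([[15., 12., 10., 8.]])  # one text line, center (15, 12)
    master = np.array([[10., 10., 40., 20.]])  # one predicted cell box
    print(toy_center_rule(end2end, master))    # [(0, 0)]
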
class TableMasterMatcher(Matcher):
    def __init__(self):
        pass

    def __call__(self, structure_res, dt_boxes, rec_res, img_name=1):
        end2end_results = {img_name: []}
        for dt_box, res in zip(dt_boxes, rec_res):
            d = dict(
                bbox=np.array(dt_box),
                text=res[0], )
            end2end_results[img_name].append(d)

        self.end2end_results = end2end_results

        structure_master_result_dict = {img_name: {}}
        pred_structures, pred_bboxes = structure_res
        pred_structures = ','.join(pred_structures[3:-3])
        structure_master_result_dict[img_name]['text'] = pred_structures
        structure_master_result_dict[img_name]['bbox'] = pred_bboxes
        self.structure_master_results = structure_master_result_dict

        # match
        match_results = self.match()
        merged_results = self.get_merge_result(match_results)
        pred_html = merged_results[img_name]
        pred_html = '<html><body><table>' + pred_html + '</table></body></html>'
        return pred_html

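Note: a hedged usage sketch of the callable interface above; the variable names are placeholders for the outputs of the table-structure predictor and the OCR system, not values defined in this diff:

    matcher = TableMasterMatcher()
    # structure_res = (pred_structures, pred_bboxes) from the structure model;
    # dt_boxes / rec_res come from text detection and recognition.
    pred_html = matcher(structure_res, dt_boxes, rec_res, img_name='table_01')
    # pred_html: '<html><body><table>...</table></body></html>'
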
@ -27,6 +27,8 @@ def init_args():
     parser.add_argument("--table_max_len", type=int, default=488)
     parser.add_argument("--table_algorithm", type=str, default='TableAttn')
     parser.add_argument("--table_model_dir", type=str)
+    parser.add_argument(
+        "--merge_no_span_structure", type=str2bool, default=True)
     parser.add_argument(
         "--table_char_dict_path",
         type=str,
@ -36,14 +38,17 @@ def init_args():
     parser.add_argument(
         "--layout_dict_path",
         type=str,
-        default="../ppocr/utils/dict/layout_pubalynet_dict.txt")
+        default="../ppocr/utils/dict/layout_publaynet_dict.txt")
     parser.add_argument(
         "--layout_score_threshold",
         type=float,
         default=0.5,
         help="Threshold of score.")
     parser.add_argument(
-        "--layout_nms_threshold", type=float, default=0.5, help="Threshold of nms.")
+        "--layout_nms_threshold",
+        type=float,
+        default=0.5,
+        help="Threshold of nms.")
     # params for vqa
     parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM')
     parser.add_argument("--ser_model_dir", type=str)
@ -59,6 +64,11 @@ def init_args():
         type=str,
         default='structure',
         help='structure and vqa is supported')
+    parser.add_argument(
+        "--image_orientation",
+        type=bool,
+        default=False,
+        help='Whether to enable image orientation recognition')
     parser.add_argument(
         "--layout",
         type=str2bool,
@ -41,7 +41,11 @@ logger = get_logger()
 class SerPredictor(object):
     def __init__(self, args):
         self.ocr_engine = PaddleOCR(
-            use_angle_cls=False, show_log=False, use_gpu=args.use_gpu)
+            use_angle_cls=args.use_angle_cls,
+            det_model_dir=args.det_model_dir,
+            rec_model_dir=args.rec_model_dir,
+            show_log=False,
+            use_gpu=args.use_gpu)

         pre_process_list = [{
             'VQATokenLabelEncode': {
@ -58,10 +58,11 @@ function status_check(){
     run_command=$2
     run_log=$3
     model_name=$4
+    log_path=$5
     if [ $last_status -eq 0 ]; then
-        echo -e "\033[33m Run successfully with command - ${model_name} - ${run_command}! \033[0m" | tee -a ${run_log}
+        echo -e "\033[33m Run successfully with command - ${model_name} - ${run_command} - ${log_path} \033[0m" | tee -a ${run_log}
     else
-        echo -e "\033[33m Run failed with command - ${model_name} - ${run_command}! \033[0m" | tee -a ${run_log}
+        echo -e "\033[33m Run failed with command - ${model_name} - ${run_command} - ${log_path} \033[0m" | tee -a ${run_log}
     fi
 }
@ -1,58 +0,0 @@
-===========================train_params===========================
-model_name:det_r18_db_v2_0
-python:python3.7
-gpu_list:0|0,1
-Global.use_gpu:True|True
-Global.auto_cast:null
-Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
-Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_lite_infer=4
-Global.pretrained_model:null
-train_model_name:latest
-train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
-null:null
-##
-trainer:norm_train
-norm_train:tools/train.py -c configs/det/det_res18_db_v2.0.yml -o
-quant_export:null
-fpgm_export:null
-distill_train:null
-null:null
-null:null
-##
-===========================eval_params===========================
-eval:null
-null:null
-##
-===========================infer_params===========================
-Global.save_inference_dir:./output/
-Global.checkpoints:
-norm_export:null
-quant_export:null
-fpgm_export:null
-distill_export:null
-export1:null
-export2:null
-##
-train_model:null
-infer_export:null
-infer_quant:False
-inference:tools/infer/predict_det.py
---use_gpu:True|False
---enable_mkldnn:False
---cpu_threads:6
---rec_batch_num:1
---use_tensorrt:False
---precision:fp32
---det_model_dir:
---image_dir:./inference/ch_det_data_50/all-sum-510/
---save_log_path:null
---benchmark:True
-null:null
-===========================infer_benchmark_params==========================
-random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
-===========================train_benchmark_params==========================
-batch_size:8|16
-fp_items:fp32|fp16
-epoch:15
---profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
@ -19,8 +19,6 @@ Global:
   character_type: en
   max_text_length: 800
   infer_mode: False
-  process_total_num: 0
-  process_cut_num: 0

 Optimizer:
   name: Adam
@ -52,7 +52,7 @@ null:null
 ===========================infer_benchmark_params==========================
 random_infer_input:[{float32,[3,224,224]}]
 ===========================train_benchmark_params==========================
-batch_size:4
+batch_size:8
 fp_items:fp32|fp16
 epoch:3
 --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
@ -16,8 +16,6 @@ Global:
   character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
   infer_mode: false
   max_text_length: 500
-  process_total_num: 0
-  process_cut_num: 0


 Optimizer:
@ -86,7 +84,7 @@ Train:
       - PaddingTableImage:
           size: [480, 480]
       - TableBoxEncode:
-          use_xywh: True
+          box_format: 'xywh'
       - NormalizeImage:
           scale: 1./255.
           mean: [0.5, 0.5, 0.5]
@ -120,7 +118,7 @@ Eval:
       - PaddingTableImage:
           size: [480, 480]
       - TableBoxEncode:
-          use_xywh: True
+          box_format: 'xywh'
       - NormalizeImage:
           scale: 1./255.
           mean: [0.5, 0.5, 0.5]
@ -0,0 +1,59 @@
+===========================train_params===========================
+model_name:vi_layoutxlm_ser
+python:python3.7
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:fp32
+Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8
+Architecture.Backbone.checkpoints:null
+train_model_name:latest
+train_infer_img_dir:ppstructure/docs/vqa/input/zh_val_42.jpg
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c ./configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Architecture.Backbone.checkpoints:
+norm_export:tools/export_model.py -c ./configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml -o
+quant_export:
+fpgm_export:
+distill_export:null
+export1:null
+export2:null
+##
+infer_model:null
+infer_export:null
+infer_quant:False
+inference:ppstructure/vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_dict_path=train_data/XFUND/class_list_xfun.txt --output=output --ocr_order_method=tb-yx
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:fp32
+--ser_model_dir:
+--image_dir:./ppstructure/docs/vqa/input/zh_val_42.jpg
+null:null
+--benchmark:False
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,224,224]}]
+===========================train_benchmark_params==========================
+batch_size:4
+fp_items:fp32|fp16
+epoch:3
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98
@ -106,7 +106,7 @@ if [ ${MODE} = "benchmark_train" ];then
         ln -s ./icdar2015_benckmark ./icdar2015
         cd ../
     fi
-    if [ ${model_name} == "layoutxlm_ser" ]; then
+    if [ ${model_name} == "layoutxlm_ser" ] || [ ${model_name} == "vi_layoutxlm_ser" ]; then
         pip install -r ppstructure/vqa/requirements.txt
         pip install paddlenlp\>=2.3.5 --force-reinstall -i https://mirrors.aliyun.com/pypi/simple/
         wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate
@ -220,7 +220,7 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
         wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/rec_r32_gaspin_bilstm_att_train.tar --no-check-certificate
         cd ./pretrain_models/ && tar xf rec_r32_gaspin_bilstm_att_train.tar && cd ../
     fi
-    if [ ${model_name} == "layoutxlm_ser" ]; then
+    if [ ${model_name} == "layoutxlm_ser" ] || [ ${model_name} == "vi_layoutxlm_ser" ]; then
         pip install -r ppstructure/vqa/requirements.txt
         pip install paddlenlp\>=2.3.5 --force-reinstall -i https://mirrors.aliyun.com/pypi/simple/
         wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate
@ -84,7 +84,7 @@ function func_cpp_inference(){
                     eval $command
                     last_status=${PIPESTATUS[0]}
                     eval "cat ${_save_log_path}"
-                    status_check $last_status "${command}" "${status_log}" "${model_name}"
+                    status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
                 done
             done
         done
@ -117,7 +117,7 @@ function func_cpp_inference(){
             eval $command
             last_status=${PIPESTATUS[0]}
             eval "cat ${_save_log_path}"
-            status_check $last_status "${command}" "${status_log}" "${model_name}"
+            status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"

         done
     done
@ -88,7 +88,7 @@ function func_inference(){
                     eval $command
                     last_status=${PIPESTATUS[0]}
                     eval "cat ${_save_log_path}"
-                    status_check $last_status "${command}" "${status_log}" "${model_name}"
+                    status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
                 done
             done
         done
@ -119,7 +119,7 @@ function func_inference(){
             eval $command
             last_status=${PIPESTATUS[0]}
             eval "cat ${_save_log_path}"
-            status_check $last_status "${command}" "${status_log}" "${model_name}"
+            status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"

         done
     done
@ -146,14 +146,15 @@ if [ ${MODE} = "whole_infer" ]; then
     for infer_model in ${infer_model_dir_list[*]}; do
         # run export
         if [ ${infer_run_exports[Count]} != "null" ];then
+            _save_log_path="${_log_path}/python_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}_infermodel_${infer_model}.log"
             save_infer_dir=$(dirname $infer_model)
             set_export_weight=$(func_set_params "${export_weight}" "${infer_model}")
             set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}")
-            export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key}"
+            export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key} > ${_save_log_path} 2>&1 "
             echo ${infer_run_exports[Count]}
             eval $export_cmd
             status_export=$?
-            status_check $status_export "${export_cmd}" "${status_log}" "${model_name}"
+            status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
         else
             save_infer_dir=${infer_model}
         fi
@ -66,7 +66,7 @@ function func_paddle2onnx(){
         trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_det_log} 2>&1 "
         eval $trans_model_cmd
         last_status=${PIPESTATUS[0]}
-        status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
+        status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_det_log}"
         # trans rec
         set_dirname=$(func_set_params "--model_dir" "${rec_infer_model_dir_value}")
         set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
@ -78,7 +78,7 @@ function func_paddle2onnx(){
         trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_rec_log} 2>&1 "
         eval $trans_model_cmd
         last_status=${PIPESTATUS[0]}
-        status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
+        status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}"
     elif [[ ${model_name} =~ "det" ]]; then
         # trans det
         set_dirname=$(func_set_params "--model_dir" "${det_infer_model_dir_value}")
@ -91,7 +91,7 @@ function func_paddle2onnx(){
         trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_det_log} 2>&1 "
         eval $trans_model_cmd
         last_status=${PIPESTATUS[0]}
-        status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
+        status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_det_log}"
     elif [[ ${model_name} =~ "rec" ]]; then
         # trans rec
         set_dirname=$(func_set_params "--model_dir" "${rec_infer_model_dir_value}")
@ -104,7 +104,7 @@ function func_paddle2onnx(){
         trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_rec_log} 2>&1 "
         eval $trans_model_cmd
         last_status=${PIPESTATUS[0]}
-        status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
+        status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}"
     fi

     # python inference
@ -127,7 +127,7 @@ function func_paddle2onnx(){
         eval $infer_model_cmd
         last_status=${PIPESTATUS[0]}
         eval "cat ${_save_log_path}"
-        status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}"
+        status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
     elif [ ${use_gpu} = "True" ] || [ ${use_gpu} = "gpu" ]; then
         _save_log_path="${LOG_PATH}/paddle2onnx_infer_gpu.log"
         set_gpu=$(func_set_params "${use_gpu_key}" "${use_gpu}")
@ -146,7 +146,7 @@ function func_paddle2onnx(){
         eval $infer_model_cmd
         last_status=${PIPESTATUS[0]}
         eval "cat ${_save_log_path}"
-        status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}"
+        status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
     else
         echo "Does not support hardware other than CPU and GPU Currently!"
     fi
@ -158,4 +158,4 @@ echo "################### run test ###################"

 export Count=0
 IFS="|"
-func_paddle2onnx
+func_paddle2onnx
@ -84,7 +84,7 @@ function func_inference(){
                     eval $command
                     last_status=${PIPESTATUS[0]}
                     eval "cat ${_save_log_path}"
-                    status_check $last_status "${command}" "${status_log}" "${model_name}"
+                    status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
                 done
             done
         done
@ -109,7 +109,7 @@ function func_inference(){
             eval $command
             last_status=${PIPESTATUS[0]}
             eval "cat ${_save_log_path}"
-            status_check $last_status "${command}" "${status_log}" "${model_name}"
+            status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"

         done
     done
@ -145,7 +145,7 @@ if [ ${MODE} = "whole_infer" ]; then
             echo $export_cmd
             eval $export_cmd
             status_export=$?
-            status_check $status_export "${export_cmd}" "${status_log}" "${model_name}"
+            status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" "${export_log_path}"
         else
             save_infer_dir=${infer_model}
         fi
@ -83,7 +83,7 @@ function func_serving(){
         trans_model_cmd="${python_list[0]} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client} > ${trans_rec_log} 2>&1 "
         eval $trans_model_cmd
         last_status=${PIPESTATUS[0]}
-        status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
+        status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}"
         set_image_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}")
         python_list=(${python_list})
         cd ${serving_dir_value}
@ -95,14 +95,14 @@ function func_serving(){
             web_service_cpp_cmd="nohup ${python_list[0]} ${web_service_py} --model ${det_server_value} ${rec_server_value} ${op_key} ${op_value} ${port_key} ${port_value} > ${server_log_path} 2>&1 &"
             eval $web_service_cpp_cmd
             last_status=${PIPESTATUS[0]}
-            status_check $last_status "${web_service_cpp_cmd}" "${status_log}" "${model_name}"
+            status_check $last_status "${web_service_cpp_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
             sleep 5s
             _save_log_path="${LOG_PATH}/cpp_client_cpu.log"
             cpp_client_cmd="${python_list[0]} ${cpp_client_py} ${det_client_value} ${rec_client_value} > ${_save_log_path} 2>&1"
             eval $cpp_client_cmd
             last_status=${PIPESTATUS[0]}
             eval "cat ${_save_log_path}"
-            status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}"
+            status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
             ps ux | grep -i ${port_value} | awk '{print $2}' | xargs kill -s 9
         else
             server_log_path="${LOG_PATH}/cpp_server_gpu.log"
@ -114,7 +114,7 @@ function func_serving(){
             eval $cpp_client_cmd
             last_status=${PIPESTATUS[0]}
             eval "cat ${_save_log_path}"
-            status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}"
+            status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
             ps ux | grep -i ${port_value} | awk '{print $2}' | xargs kill -s 9
         fi
     done
@ -126,19 +126,19 @@ function func_serving(){
             web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
             eval $web_service_cmd
             last_status=${PIPESTATUS[0]}
-            status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
+            status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
         elif [[ ${model_name} =~ "det" ]]; then
             set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
             web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} > ${server_log_path} 2>&1 &"
             eval $web_service_cmd
             last_status=${PIPESTATUS[0]}
-            status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
+            status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
         elif [[ ${model_name} =~ "rec" ]]; then
             set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
             web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
             eval $web_service_cmd
             last_status=${PIPESTATUS[0]}
-            status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
+            status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
         fi
         sleep 2s
         for pipeline in ${pipeline_py[*]}; do
@ -147,7 +147,7 @@ function func_serving(){
             eval $pipeline_cmd
             last_status=${PIPESTATUS[0]}
             eval "cat ${_save_log_path}"
-            status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
+            status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
             sleep 2s
         done
         ps ux | grep -E 'web_service' | awk '{print $2}' | xargs kill -s 9
@ -177,19 +177,19 @@ function func_serving(){
             web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
             eval $web_service_cmd
             last_status=${PIPESTATUS[0]}
-            status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
+            status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
         elif [[ ${model_name} =~ "det" ]]; then
             set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
             web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} > ${server_log_path} 2>&1 &"
             eval $web_service_cmd
             last_status=${PIPESTATUS[0]}
-            status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
+            status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
         elif [[ ${model_name} =~ "rec" ]]; then
             set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
             web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
             eval $web_service_cmd
             last_status=${PIPESTATUS[0]}
-            status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
+            status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
         fi
         sleep 2s
         for pipeline in ${pipeline_py[*]}; do
@ -198,7 +198,7 @@ function func_serving(){
             eval $pipeline_cmd
             last_status=${PIPESTATUS[0]}
             eval "cat ${_save_log_path}"
-            status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
+            status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
             sleep 2s
         done
         ps ux | grep -E 'web_service' | awk '{print $2}' | xargs kill -s 9
@ -133,7 +133,7 @@ function func_inference(){
                     eval $command
                     last_status=${PIPESTATUS[0]}
                     eval "cat ${_save_log_path}"
-                    status_check $last_status "${command}" "${status_log}" "${model_name}"
+                    status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
                 done
             done
         done
@ -164,7 +164,7 @@ function func_inference(){
             eval $command
             last_status=${PIPESTATUS[0]}
             eval "cat ${_save_log_path}"
-            status_check $last_status "${command}" "${status_log}" "${model_name}"
+            status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"

         done
     done
@ -201,7 +201,7 @@ if [ ${MODE} = "whole_infer" ]; then
             echo $export_cmd
             eval $export_cmd
             status_export=$?
-            status_check $status_export "${export_cmd}" "${status_log}" "${model_name}"
+            status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" "${export_log_path}"
         else
             save_infer_dir=${infer_model}
         fi
@ -298,7 +298,7 @@ else
         # run train
         eval $cmd
         eval "cat ${save_log}/train.log >> ${save_log}.log"
-        status_check $? "${cmd}" "${status_log}" "${model_name}"
+        status_check $? "${cmd}" "${status_log}" "${model_name}" "${save_log}.log"

         set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${train_model_name}")

@ -309,7 +309,7 @@ else
             eval_log_path="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}_nodes_${nodes}_eval.log"
             eval_cmd="${python} ${eval_py} ${set_eval_pretrain} ${set_use_gpu} ${set_eval_params1} > ${eval_log_path} 2>&1 "
             eval $eval_cmd
-            status_check $? "${eval_cmd}" "${status_log}" "${model_name}"
+            status_check $? "${eval_cmd}" "${status_log}" "${model_name}" "${eval_log_path}"
         fi
         # run export model
         if [ ${run_export} != "null" ]; then
@ -320,7 +320,7 @@ else
             set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_path}")
             export_cmd="${python} ${run_export} ${set_export_weight} ${set_save_infer_key} > ${export_log_path} 2>&1 "
             eval $export_cmd
-            status_check $? "${export_cmd}" "${status_log}" "${model_name}"
+            status_check $? "${export_cmd}" "${status_log}" "${model_name}" "${export_log_path}"

             #run inference
             eval $env
@ -58,6 +58,8 @@ def export_single_model(model,
         other_shape = [
             paddle.static.InputSpec(
                 shape=[None, 3, 48, 160], dtype="float32"),
+            [paddle.static.InputSpec(
+                shape=[None], dtype="float32")]
         ]
         model = to_static(model, input_spec=other_shape)
     elif arch_config["algorithm"] == "SVTR":
@ -144,7 +146,7 @@ def export_single_model(model,
     else:
         infer_shape = [3, -1, -1]
         if arch_config["model_type"] == "rec":
-            infer_shape = [3, 48, -1]  # for rec model, H must be 32
+            infer_shape = [3, 32, -1]  # for rec model, H must be 32
             if "Transform" in arch_config and arch_config[
                     "Transform"] is not None and arch_config["Transform"][
                         "name"] == "TPS":
@ -156,6 +158,8 @@ def export_single_model(model,
         infer_shape = [3, 488, 488]
         if arch_config["algorithm"] == "TableMaster":
             infer_shape = [3, 480, 480]
+        if arch_config["algorithm"] == "SLANet":
+            infer_shape = [3, -1, -1]
     model = to_static(
         model,
         input_spec=[
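Note: the SLANet branch above keeps height and width dynamic. A minimal hedged sketch of what such an export amounts to, assuming `model` is a trained nn.Layer (the real export_single_model handles this generically):

    import paddle
    from paddle.static import InputSpec

    # [None, 3, -1, -1]: batch and spatial dims stay dynamic, so the exported
    # graph accepts variable table-image sizes.
    static_model = paddle.jit.to_static(
        model, input_spec=[InputSpec(shape=[None, 3, -1, -1], dtype='float32')])
    paddle.jit.save(static_model, './output/SLANet/infer/inference')
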
@ -248,4 +252,4 @@ def main():


 if __name__ == "__main__":
-    main()
+    main()
@ -458,7 +458,8 @@ class TextRecognizer(object):
                 valid_ratios = np.concatenate(valid_ratios)
                 inputs = [
                     norm_img_batch,
-                    valid_ratios,
+                    np.array(
+                        [valid_ratios], dtype=np.float32),
                 ]
                 if self.use_onnx:
                     input_dict = {}
@ -65,9 +65,11 @@ class TextSystem(object):
         self.crop_image_res_index += bbox_num

     def __call__(self, img, cls=True):
+        time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
+        start = time.time()
         ori_im = img.copy()
         dt_boxes, elapse = self.text_detector(img)
-
+        time_dict['det'] = elapse
         logger.debug("dt_boxes num : {}, elapse : {}".format(
             len(dt_boxes), elapse))
         if dt_boxes is None:
@ -83,10 +85,12 @@ class TextSystem(object):
         if self.use_angle_cls and cls:
             img_crop_list, angle_list, elapse = self.text_classifier(
                 img_crop_list)
+            time_dict['cls'] = elapse
             logger.debug("cls num : {}, elapse : {}".format(
                 len(img_crop_list), elapse))

         rec_res, elapse = self.text_recognizer(img_crop_list)
+        time_dict['rec'] = elapse
         logger.debug("rec_res num : {}, elapse : {}".format(
             len(rec_res), elapse))
         if self.args.save_crop_res:
@ -98,7 +102,9 @@ class TextSystem(object):
             if score >= self.drop_score:
                 filter_boxes.append(box)
                 filter_rec_res.append(rec_result)
-        return filter_boxes, filter_rec_res
+        end = time.time()
+        time_dict['all'] = end - start
+        return filter_boxes, filter_rec_res, time_dict


 def sorted_boxes(dt_boxes):
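Note: with the extra return value above, callers now unpack three items (see the main() hunk below); a hedged sketch of consuming the profiling dict:

    dt_boxes, rec_res, time_dict = text_sys(img)
    print('det: {:.3f}s, cls: {:.3f}s, rec: {:.3f}s, total: {:.3f}s'.format(
        time_dict['det'], time_dict['cls'], time_dict['rec'], time_dict['all']))
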
@ -133,9 +139,11 @@ def main(args):
     os.makedirs(draw_img_save_dir, exist_ok=True)
     save_results = []

-    logger.info("In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
-                "if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320")
-
+    logger.info(
+        "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
+        "if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320"
+    )

     # warm up 10 times
     if args.warmup:
         img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
@ -155,7 +163,7 @@ def main(args):
             logger.debug("error in loading image:{}".format(image_file))
             continue
         starttime = time.time()
-        dt_boxes, rec_res = text_sys(img)
+        dt_boxes, rec_res, time_dict = text_sys(img)
         elapse = time.time() - starttime
         total_time += elapse
@ -198,7 +206,10 @@ def main(args):
         text_sys.text_detector.autolog.report()
         text_sys.text_recognizer.autolog.report()

-    with open(os.path.join(draw_img_save_dir, "system_results.txt"), 'w', encoding='utf-8') as f:
+    with open(
+            os.path.join(draw_img_save_dir, "system_results.txt"),
+            'w',
+            encoding='utf-8') as f:
         f.writelines(save_results)

@ -163,6 +163,8 @@ def create_predictor(args, mode, logger):
         model_dir = args.ser_model_dir
     elif mode == "sr":
         model_dir = args.sr_model_dir
+    elif mode == 'layout':
+        model_dir = args.layout_model_dir
     else:
         model_dir = args.e2e_model_dir

@ -37,6 +37,7 @@ from ppocr.postprocess import build_post_process
 from ppocr.utils.save_load import load_model
 from ppocr.utils.utility import get_image_file_list
 from ppocr.utils.visual import draw_rectangle
+from tools.infer.utility import draw_boxes
 import tools.program as program
 import cv2

@ -56,7 +57,6 @@ def main(config, device, logger, vdl_writer):

     model = build_model(config['Architecture'])
     algorithm = config['Architecture']['algorithm']
-    use_xywh = algorithm in ['TableMaster']

     load_model(config, model)

@ -106,9 +106,13 @@ def main(config, device, logger, vdl_writer):
             f_w.write("result: {}, {}\n".format(structure_str_list,
                                                 bbox_list_str))

-        img = draw_rectangle(file, bbox_list, use_xywh)
+        if len(bbox_list) > 0 and len(bbox_list[0]) == 4:
+            img = draw_rectangle(file, bbox_list)
+        else:
+            img = draw_boxes(cv2.imread(file), bbox_list)
         cv2.imwrite(
             os.path.join(save_res_path, os.path.basename(file)), img)
     logger.info('save result to {}'.format(save_res_path))
     logger.info("success!")

@ -653,7 +653,7 @@ def preprocess(is_train=False):
         'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
         'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
         'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
-        'Gestalt', 'RobustScanner'
+        'Gestalt', 'SLANet', 'RobustScanner'
     ]

     if use_xpu:
@ -119,6 +119,10 @@ def main(config, device, logger, vdl_writer):
             config['Loss']['ignore_index'] = char_num - 1

     model = build_model(config['Architecture'])
+    use_sync_bn = config["Global"].get("use_sync_bn", False)
+    if use_sync_bn:
+        model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+        logger.info('convert_sync_batchnorm')

     model = apply_to_static(model, config, logger)

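Note: convert_sync_batchnorm swaps every BatchNorm sublayer for SyncBatchNorm so that batch statistics are synchronized across data-parallel cards during training. A minimal hedged sketch on a toy layer:

    import paddle
    import paddle.nn as nn

    model = nn.Sequential(nn.Conv2D(3, 8, 3), nn.BatchNorm2D(8), nn.ReLU())
    # Returns the layer tree with BatchNorm2D sublayers replaced by SyncBatchNorm.
    model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    print(model)  # the BatchNorm2D sublayer now prints as SyncBatchNorm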