add table cpp infer
parent
97f7f74808
commit
3867c8cc87
|
@ -30,7 +30,8 @@ DECLARE_string(image_dir);
|
|||
DECLARE_string(type);
|
||||
// detection related
|
||||
DECLARE_string(det_model_dir);
|
||||
DECLARE_int32(max_side_len);
|
||||
DECLARE_string(limit_type);
|
||||
DECLARE_int32(limit_side_len);
|
||||
DECLARE_double(det_db_thresh);
|
||||
DECLARE_double(det_db_box_thresh);
|
||||
DECLARE_double(det_db_unclip_ratio);
|
||||
|
@ -48,7 +49,13 @@ DECLARE_int32(rec_batch_num);
|
|||
DECLARE_string(rec_char_dict_path);
|
||||
DECLARE_int32(rec_img_h);
|
||||
DECLARE_int32(rec_img_w);
|
||||
// structure model related
|
||||
DECLARE_string(table_model_dir);
|
||||
DECLARE_int32(table_max_len);
|
||||
DECLARE_int32(table_batch_num);
|
||||
DECLARE_string(table_char_dict_path);
|
||||
// forward related
|
||||
DECLARE_bool(det);
|
||||
DECLARE_bool(rec);
|
||||
DECLARE_bool(cls);
|
||||
DECLARE_bool(table);
|
|
@ -41,8 +41,8 @@ public:
|
|||
explicit DBDetector(const std::string &model_dir, const bool &use_gpu,
|
||||
const int &gpu_id, const int &gpu_mem,
|
||||
const int &cpu_math_library_num_threads,
|
||||
const bool &use_mkldnn, const int &max_side_len,
|
||||
const double &det_db_thresh,
|
||||
const bool &use_mkldnn, const string &limit_type,
|
||||
const int &limit_side_len, const double &det_db_thresh,
|
||||
const double &det_db_box_thresh,
|
||||
const double &det_db_unclip_ratio,
|
||||
const std::string &det_db_score_mode,
|
||||
|
@ -54,7 +54,8 @@ public:
|
|||
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
|
||||
this->use_mkldnn_ = use_mkldnn;
|
||||
|
||||
this->max_side_len_ = max_side_len;
|
||||
this->limit_type_ = limit_type;
|
||||
this->limit_side_len_ = limit_side_len;
|
||||
|
||||
this->det_db_thresh_ = det_db_thresh;
|
||||
this->det_db_box_thresh_ = det_db_box_thresh;
|
||||
|
@ -84,7 +85,8 @@ private:
|
|||
int cpu_math_library_num_threads_ = 4;
|
||||
bool use_mkldnn_ = false;
|
||||
|
||||
int max_side_len_ = 960;
|
||||
string limit_type_ = "max";
|
||||
int limit_side_len_ = 960;
|
||||
|
||||
double det_db_thresh_ = 0.3;
|
||||
double det_db_box_thresh_ = 0.5;
|
||||
|
@ -106,7 +108,7 @@ private:
|
|||
Permute permute_op_;
|
||||
|
||||
// post-process
|
||||
PostProcessor post_processor_;
|
||||
DBPostProcessor post_processor_;
|
||||
};
|
||||
|
||||
} // namespace PaddleOCR
|
|
@ -47,11 +47,7 @@ public:
|
|||
ocr(std::vector<cv::String> cv_all_img_names, bool det = true,
|
||||
bool rec = true, bool cls = true);
|
||||
|
||||
private:
|
||||
DBDetector *detector_ = nullptr;
|
||||
Classifier *classifier_ = nullptr;
|
||||
CRNNRecognizer *recognizer_ = nullptr;
|
||||
|
||||
protected:
|
||||
void det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
|
||||
std::vector<double> ×);
|
||||
void rec(std::vector<cv::Mat> img_list,
|
||||
|
@ -62,6 +58,11 @@ private:
|
|||
std::vector<double> ×);
|
||||
void log(std::vector<double> &det_times, std::vector<double> &rec_times,
|
||||
std::vector<double> &cls_times, int img_num);
|
||||
|
||||
private:
|
||||
DBDetector *detector_ = nullptr;
|
||||
Classifier *classifier_ = nullptr;
|
||||
CRNNRecognizer *recognizer_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace PaddleOCR
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "opencv2/core.hpp"
|
||||
#include "opencv2/imgcodecs.hpp"
|
||||
#include "opencv2/imgproc.hpp"
|
||||
#include "paddle_api.h"
|
||||
#include "paddle_inference_api.h"
|
||||
#include <chrono>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <vector>
|
||||
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <numeric>
|
||||
|
||||
#include <include/paddleocr.h>
|
||||
#include <include/preprocess_op.h>
|
||||
#include <include/structure_table.h>
|
||||
#include <include/utility.h>
|
||||
|
||||
using namespace paddle_infer;
|
||||
|
||||
namespace PaddleOCR {
|
||||
|
||||
class PaddleStructure : public PPOCR {
|
||||
public:
|
||||
explicit PaddleStructure();
|
||||
~PaddleStructure();
|
||||
std::vector<std::vector<StructurePredictResult>>
|
||||
structure(std::vector<cv::String> cv_all_img_names, bool layout = false,
|
||||
bool table = true);
|
||||
|
||||
private:
|
||||
StructureTableRecognizer *recognizer_ = nullptr;
|
||||
|
||||
void table(cv::Mat img, StructurePredictResult &structure_result,
|
||||
std::vector<double> &time_info_table,
|
||||
std::vector<double> &time_info_det,
|
||||
std::vector<double> &time_info_rec,
|
||||
std::vector<double> &time_info_cls);
|
||||
std::string
|
||||
rebuild_table(std::vector<std::string> rec_html_tags,
|
||||
std::vector<std::vector<std::vector<int>>> rec_boxes,
|
||||
std::vector<OCRPredictResult> &ocr_result);
|
||||
|
||||
float iou(std::vector<std::vector<int>> &box1,
|
||||
std::vector<std::vector<int>> &box2);
|
||||
float dis(std::vector<std::vector<int>> &box1,
|
||||
std::vector<std::vector<int>> &box2);
|
||||
|
||||
static bool comparison_dis(const std::vector<float> &dis1,
|
||||
const std::vector<float> &dis2) {
|
||||
if (dis1[1] < dis2[1]) {
|
||||
return true;
|
||||
} else if (dis1[1] == dis2[1]) {
|
||||
return dis1[0] < dis2[0];
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace PaddleOCR
|
|
@ -34,7 +34,7 @@ using namespace std;
|
|||
|
||||
namespace PaddleOCR {
|
||||
|
||||
class PostProcessor {
|
||||
class DBPostProcessor {
|
||||
public:
|
||||
void GetContourArea(const std::vector<std::vector<float>> &box,
|
||||
float unclip_ratio, float &distance);
|
||||
|
@ -90,4 +90,21 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
class TablePostProcessor {
|
||||
public:
|
||||
void init(std::string label_path);
|
||||
void
|
||||
Run(std::vector<float> &loc_preds, std::vector<float> &structure_probs,
|
||||
std::vector<float> &rec_scores, std::vector<int> &loc_preds_shape,
|
||||
std::vector<int> &structure_probs_shape,
|
||||
std::vector<std::vector<std::string>> &rec_html_tag_batch,
|
||||
std::vector<std::vector<std::vector<std::vector<int>>>> &rec_boxes_batch,
|
||||
std::vector<int> &width_list, std::vector<int> &height_list);
|
||||
|
||||
private:
|
||||
std::vector<std::string> label_list_;
|
||||
std::string end = "eos";
|
||||
std::string beg = "sos";
|
||||
};
|
||||
|
||||
} // namespace PaddleOCR
|
||||
|
|
|
@ -51,8 +51,9 @@ public:
|
|||
|
||||
class ResizeImgType0 {
|
||||
public:
|
||||
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len,
|
||||
float &ratio_h, float &ratio_w, bool use_tensorrt);
|
||||
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, string limit_type,
|
||||
int limit_side_len, float &ratio_h, float &ratio_w,
|
||||
bool use_tensorrt);
|
||||
};
|
||||
|
||||
class CrnnResizeImg {
|
||||
|
@ -69,4 +70,16 @@ public:
|
|||
const std::vector<int> &rec_image_shape = {3, 48, 192});
|
||||
};
|
||||
|
||||
class TableResizeImg {
|
||||
public:
|
||||
virtual void Run(const cv::Mat &img, cv::Mat &resize_img,
|
||||
const int max_len = 488);
|
||||
};
|
||||
|
||||
class TablePadImg {
|
||||
public:
|
||||
virtual void Run(const cv::Mat &img, cv::Mat &resize_img,
|
||||
const int max_len = 488);
|
||||
};
|
||||
|
||||
} // namespace PaddleOCR
|
|
@ -0,0 +1,100 @@
|
|||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "opencv2/core.hpp"
|
||||
#include "opencv2/imgcodecs.hpp"
|
||||
#include "opencv2/imgproc.hpp"
|
||||
#include "paddle_api.h"
|
||||
#include "paddle_inference_api.h"
|
||||
#include <chrono>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <vector>
|
||||
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <numeric>
|
||||
|
||||
#include <include/postprocess_op.h>
|
||||
#include <include/preprocess_op.h>
|
||||
#include <include/utility.h>
|
||||
|
||||
using namespace paddle_infer;
|
||||
|
||||
namespace PaddleOCR {
|
||||
|
||||
class StructureTableRecognizer {
|
||||
public:
|
||||
explicit StructureTableRecognizer(
|
||||
const std::string &model_dir, const bool &use_gpu, const int &gpu_id,
|
||||
const int &gpu_mem, const int &cpu_math_library_num_threads,
|
||||
const bool &use_mkldnn, const string &label_path,
|
||||
const bool &use_tensorrt, const std::string &precision,
|
||||
const int &table_batch_num, const int &table_max_len) {
|
||||
this->use_gpu_ = use_gpu;
|
||||
this->gpu_id_ = gpu_id;
|
||||
this->gpu_mem_ = gpu_mem;
|
||||
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
|
||||
this->use_mkldnn_ = use_mkldnn;
|
||||
this->use_tensorrt_ = use_tensorrt;
|
||||
this->precision_ = precision;
|
||||
this->table_batch_num_ = table_batch_num;
|
||||
this->table_max_len_ = table_max_len;
|
||||
|
||||
this->post_processor_.init(label_path);
|
||||
LoadModel(model_dir);
|
||||
}
|
||||
|
||||
// Load Paddle inference model
|
||||
void LoadModel(const std::string &model_dir);
|
||||
|
||||
void Run(std::vector<cv::Mat> img_list,
|
||||
std::vector<std::vector<std::string>> &rec_html_tags,
|
||||
std::vector<float> &rec_scores,
|
||||
std::vector<std::vector<std::vector<std::vector<int>>>> &rec_boxes,
|
||||
std::vector<double> ×);
|
||||
|
||||
private:
|
||||
std::shared_ptr<Predictor> predictor_;
|
||||
|
||||
bool use_gpu_ = false;
|
||||
int gpu_id_ = 0;
|
||||
int gpu_mem_ = 4000;
|
||||
int cpu_math_library_num_threads_ = 4;
|
||||
bool use_mkldnn_ = false;
|
||||
int table_max_len_ = 488;
|
||||
|
||||
std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
|
||||
std::vector<float> scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
|
||||
bool is_scale_ = true;
|
||||
|
||||
bool use_tensorrt_ = false;
|
||||
std::string precision_ = "fp32";
|
||||
int table_batch_num_ = 1;
|
||||
|
||||
// pre-process
|
||||
TableResizeImg resize_op_;
|
||||
Normalize normalize_op_;
|
||||
PermuteBatch permute_op_;
|
||||
TablePadImg pad_op_;
|
||||
|
||||
// post-process
|
||||
TablePostProcessor post_processor_;
|
||||
|
||||
}; // class StructureTableRecognizer
|
||||
|
||||
} // namespace PaddleOCR
|
|
@ -40,6 +40,14 @@ struct OCRPredictResult {
|
|||
int cls_label = -1;
|
||||
};
|
||||
|
||||
struct StructurePredictResult {
|
||||
std::vector<int> box;
|
||||
std::string type;
|
||||
std::vector<OCRPredictResult> text_res;
|
||||
std::string html;
|
||||
float html_score = -1;
|
||||
};
|
||||
|
||||
class Utility {
|
||||
public:
|
||||
static std::vector<std::string> ReadDict(const std::string &path);
|
||||
|
@ -68,6 +76,22 @@ public:
|
|||
static void CreateDir(const std::string &path);
|
||||
|
||||
static void print_result(const std::vector<OCRPredictResult> &ocr_result);
|
||||
|
||||
static cv::Mat crop_image(cv::Mat &img, std::vector<int> &area);
|
||||
|
||||
static void sorted_boxes(std::vector<OCRPredictResult> &ocr_result);
|
||||
|
||||
private:
|
||||
static bool comparison_box(const OCRPredictResult &result1,
|
||||
const OCRPredictResult &result2) {
|
||||
if (result1.box[0][1] < result2.box[0][1]) {
|
||||
return true;
|
||||
} else if (result1.box[0][1] == result2.box[0][1]) {
|
||||
return result1.box[0][0] < result2.box[0][0];
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace PaddleOCR
|
|
@ -30,7 +30,8 @@ DEFINE_string(
|
|||
"Perform ocr or structure, the value is selected in ['ocr','structure'].");
|
||||
// detection related
|
||||
DEFINE_string(det_model_dir, "", "Path of det inference model.");
|
||||
DEFINE_int32(max_side_len, 960, "max_side_len of input image.");
|
||||
DEFINE_string(limit_type, "max", "limit_type of input image.");
|
||||
DEFINE_int32(limit_side_len, 960, "limit_side_len of input image.");
|
||||
DEFINE_double(det_db_thresh, 0.3, "Threshold of det_db_thresh.");
|
||||
DEFINE_double(det_db_box_thresh, 0.6, "Threshold of det_db_box_thresh.");
|
||||
DEFINE_double(det_db_unclip_ratio, 1.5, "Threshold of det_db_unclip_ratio.");
|
||||
|
@ -50,7 +51,16 @@ DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt",
|
|||
DEFINE_int32(rec_img_h, 48, "rec image height");
|
||||
DEFINE_int32(rec_img_w, 320, "rec image width");
|
||||
|
||||
// structure model related
|
||||
DEFINE_string(table_model_dir, "", "Path of table struture inference model.");
|
||||
DEFINE_int32(table_max_len, 488, "max len size of input image.");
|
||||
DEFINE_int32(table_batch_num, 1, "table_batch_num.");
|
||||
DEFINE_string(table_char_dict_path,
|
||||
"../../ppocr/utils/dict/table_structure_dict.txt",
|
||||
"Path of dictionary.");
|
||||
|
||||
// ocr forward related
|
||||
DEFINE_bool(det, true, "Whether use det in forward.");
|
||||
DEFINE_bool(rec, true, "Whether use rec in forward.");
|
||||
DEFINE_bool(cls, false, "Whether use cls in forward.");
|
||||
DEFINE_bool(table, false, "Whether use table structure in forward.");
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <include/args.h>
|
||||
#include <include/paddleocr.h>
|
||||
#include <include/paddlestructure.h>
|
||||
|
||||
using namespace PaddleOCR;
|
||||
|
||||
|
@ -32,6 +33,12 @@ void check_params() {
|
|||
}
|
||||
}
|
||||
if (FLAGS_rec) {
|
||||
std::cout
|
||||
<< "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320',"
|
||||
"if you are using recognition model with PP-OCRv2 or an older "
|
||||
"version, "
|
||||
"please set --rec_image_shape='3,32,320"
|
||||
<< std::endl;
|
||||
if (FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) {
|
||||
std::cout << "Usage[rec]: ./ppocr "
|
||||
"--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
|
||||
|
@ -47,6 +54,17 @@ void check_params() {
|
|||
exit(1);
|
||||
}
|
||||
}
|
||||
if (FLAGS_table) {
|
||||
if (FLAGS_table_model_dir.empty() || FLAGS_det_model_dir.empty() ||
|
||||
FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) {
|
||||
std::cout << "Usage[table]: ./ppocr "
|
||||
<< "--det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ "
|
||||
<< "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
|
||||
<< "--table_model_dir=/PATH/TO/TABLE_INFERENCE_MODEL/ "
|
||||
<< "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
if (FLAGS_precision != "fp32" && FLAGS_precision != "fp16" &&
|
||||
FLAGS_precision != "int8") {
|
||||
cout << "precison should be 'fp32'(default), 'fp16' or 'int8'. " << endl;
|
||||
|
@ -54,21 +72,7 @@ void check_params() {
|
|||
}
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
// Parsing command-line
|
||||
google::ParseCommandLineFlags(&argc, &argv, true);
|
||||
check_params();
|
||||
|
||||
if (!Utility::PathExists(FLAGS_image_dir)) {
|
||||
std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir
|
||||
<< endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::vector<cv::String> cv_all_img_names;
|
||||
cv::glob(FLAGS_image_dir, cv_all_img_names);
|
||||
std::cout << "total images num: " << cv_all_img_names.size() << endl;
|
||||
|
||||
void ocr(std::vector<cv::String> &cv_all_img_names) {
|
||||
PPOCR ocr = PPOCR();
|
||||
|
||||
std::vector<std::vector<OCRPredictResult>> ocr_results =
|
||||
|
@ -109,3 +113,49 @@ int main(int argc, char **argv) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
void structure(std::vector<cv::String> &cv_all_img_names) {
|
||||
PaddleOCR::PaddleStructure engine = PaddleOCR::PaddleStructure();
|
||||
std::vector<std::vector<StructurePredictResult>> structure_results =
|
||||
engine.structure(cv_all_img_names, false, FLAGS_table);
|
||||
for (int i = 0; i < cv_all_img_names.size(); i++) {
|
||||
cout << cv_all_img_names[i] << "\n";
|
||||
for (int j = 0; j < structure_results[i].size(); j++) {
|
||||
std::cout << j << "\ttype: " << structure_results[i][j].type
|
||||
<< ", region: [";
|
||||
std::cout << structure_results[i][j].box[0] << ","
|
||||
<< structure_results[i][j].box[1] << ","
|
||||
<< structure_results[i][j].box[2] << ","
|
||||
<< structure_results[i][j].box[3] << "], res: ";
|
||||
if (structure_results[i][j].type == "Table") {
|
||||
std::cout << structure_results[i][j].html << std::endl;
|
||||
} else {
|
||||
Utility::print_result(structure_results[i][j].text_res);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
// Parsing command-line
|
||||
google::ParseCommandLineFlags(&argc, &argv, true);
|
||||
check_params();
|
||||
|
||||
if (!Utility::PathExists(FLAGS_image_dir)) {
|
||||
std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir
|
||||
<< endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::vector<cv::String> cv_all_img_names;
|
||||
cv::glob(FLAGS_image_dir, cv_all_img_names);
|
||||
std::cout << "total images num: " << cv_all_img_names.size() << endl;
|
||||
|
||||
if (FLAGS_type == "ocr") {
|
||||
ocr(cv_all_img_names);
|
||||
} else if (FLAGS_type == "structure") {
|
||||
structure(cv_all_img_names);
|
||||
} else {
|
||||
std::cout << "only value in ['ocr','structure'] is supported" << endl;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -47,7 +47,7 @@ void DBDetector::LoadModel(const std::string &model_dir) {
|
|||
{"elementwise_add_7", {1, 56, 2, 2}},
|
||||
{"nearest_interp_v2_0.tmp_0", {1, 256, 2, 2}}};
|
||||
std::map<std::string, std::vector<int>> max_input_shape = {
|
||||
{"x", {1, 3, this->max_side_len_, this->max_side_len_}},
|
||||
{"x", {1, 3, 1536, 1536}},
|
||||
{"conv2d_92.tmp_0", {1, 120, 400, 400}},
|
||||
{"conv2d_91.tmp_0", {1, 24, 200, 200}},
|
||||
{"conv2d_59.tmp_0", {1, 96, 400, 400}},
|
||||
|
@ -109,7 +109,8 @@ void DBDetector::Run(cv::Mat &img,
|
|||
img.copyTo(srcimg);
|
||||
|
||||
auto preprocess_start = std::chrono::steady_clock::now();
|
||||
this->resize_op_.Run(img, resize_img, this->max_side_len_, ratio_h, ratio_w,
|
||||
this->resize_op_.Run(img, resize_img, this->limit_type_,
|
||||
this->limit_side_len_, ratio_h, ratio_w,
|
||||
this->use_tensorrt_);
|
||||
|
||||
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
|
||||
|
|
|
@ -23,10 +23,10 @@ PPOCR::PPOCR() {
|
|||
if (FLAGS_det) {
|
||||
this->detector_ = new DBDetector(
|
||||
FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
|
||||
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_max_side_len,
|
||||
FLAGS_det_db_thresh, FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio,
|
||||
FLAGS_det_db_score_mode, FLAGS_use_dilation, FLAGS_use_tensorrt,
|
||||
FLAGS_precision);
|
||||
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_limit_type,
|
||||
FLAGS_limit_side_len, FLAGS_det_db_thresh, FLAGS_det_db_box_thresh,
|
||||
FLAGS_det_db_unclip_ratio, FLAGS_det_db_score_mode, FLAGS_use_dilation,
|
||||
FLAGS_use_tensorrt, FLAGS_precision);
|
||||
}
|
||||
|
||||
if (FLAGS_cls && FLAGS_use_angle_cls) {
|
||||
|
@ -56,7 +56,8 @@ void PPOCR::det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
|
|||
res.box = boxes[i];
|
||||
ocr_results.push_back(res);
|
||||
}
|
||||
|
||||
// sort boex from top to bottom, from left to right
|
||||
Utility::sorted_boxes(ocr_results);
|
||||
times[0] += det_times[0];
|
||||
times[1] += det_times[1];
|
||||
times[2] += det_times[2];
|
||||
|
|
|
@ -0,0 +1,272 @@
|
|||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <include/args.h>
|
||||
#include <include/paddlestructure.h>
|
||||
|
||||
#include "auto_log/autolog.h"
|
||||
#include <numeric>
|
||||
#include <sys/stat.h>
|
||||
|
||||
namespace PaddleOCR {
|
||||
|
||||
PaddleStructure::PaddleStructure() {
|
||||
if (FLAGS_table) {
|
||||
this->recognizer_ = new StructureTableRecognizer(
|
||||
FLAGS_table_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
|
||||
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_table_char_dict_path,
|
||||
FLAGS_use_tensorrt, FLAGS_precision, FLAGS_table_batch_num,
|
||||
FLAGS_table_max_len);
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<std::vector<StructurePredictResult>>
|
||||
PaddleStructure::structure(std::vector<cv::String> cv_all_img_names,
|
||||
bool layout, bool table) {
|
||||
std::vector<double> time_info_det = {0, 0, 0};
|
||||
std::vector<double> time_info_rec = {0, 0, 0};
|
||||
std::vector<double> time_info_cls = {0, 0, 0};
|
||||
std::vector<double> time_info_table = {0, 0, 0};
|
||||
|
||||
std::vector<std::vector<StructurePredictResult>> structure_results;
|
||||
|
||||
if (!Utility::PathExists(FLAGS_output) && FLAGS_det) {
|
||||
mkdir(FLAGS_output.c_str(), 0777);
|
||||
}
|
||||
for (int i = 0; i < cv_all_img_names.size(); ++i) {
|
||||
std::vector<StructurePredictResult> structure_result;
|
||||
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
|
||||
if (!srcimg.data) {
|
||||
std::cerr << "[ERROR] image read failed! image path: "
|
||||
<< cv_all_img_names[i] << endl;
|
||||
exit(1);
|
||||
}
|
||||
if (layout) {
|
||||
} else {
|
||||
StructurePredictResult res;
|
||||
res.type = "Table";
|
||||
res.box = std::vector<int>(4, 0);
|
||||
res.box[2] = srcimg.cols;
|
||||
res.box[3] = srcimg.rows;
|
||||
structure_result.push_back(res);
|
||||
}
|
||||
cv::Mat roi_img;
|
||||
for (int i = 0; i < structure_result.size(); i++) {
|
||||
// crop image
|
||||
roi_img = Utility::crop_image(srcimg, structure_result[i].box);
|
||||
if (structure_result[i].type == "Table") {
|
||||
this->table(roi_img, structure_result[i], time_info_table,
|
||||
time_info_det, time_info_rec, time_info_cls);
|
||||
}
|
||||
}
|
||||
structure_results.push_back(structure_result);
|
||||
}
|
||||
return structure_results;
|
||||
};
|
||||
|
||||
void PaddleStructure::table(cv::Mat img,
|
||||
StructurePredictResult &structure_result,
|
||||
std::vector<double> &time_info_table,
|
||||
std::vector<double> &time_info_det,
|
||||
std::vector<double> &time_info_rec,
|
||||
std::vector<double> &time_info_cls) {
|
||||
// predict structure
|
||||
std::vector<std::vector<std::string>> structure_html_tags;
|
||||
std::vector<float> structure_scores(1, 0);
|
||||
std::vector<std::vector<std::vector<std::vector<int>>>> structure_boxes;
|
||||
std::vector<double> structure_imes;
|
||||
std::vector<cv::Mat> img_list;
|
||||
img_list.push_back(img);
|
||||
this->recognizer_->Run(img_list, structure_html_tags, structure_scores,
|
||||
structure_boxes, structure_imes);
|
||||
time_info_table[0] += structure_imes[0];
|
||||
time_info_table[1] += structure_imes[1];
|
||||
time_info_table[2] += structure_imes[2];
|
||||
|
||||
std::vector<OCRPredictResult> ocr_result;
|
||||
std::string html;
|
||||
int expand_pixel = 3;
|
||||
|
||||
for (int i = 0; i < img_list.size(); i++) {
|
||||
// det
|
||||
this->det(img_list[i], ocr_result, time_info_det);
|
||||
// crop image
|
||||
std::vector<cv::Mat> rec_img_list;
|
||||
for (int j = 0; j < ocr_result.size(); j++) {
|
||||
int x_collect[4] = {ocr_result[j].box[0][0], ocr_result[j].box[1][0],
|
||||
ocr_result[j].box[2][0], ocr_result[j].box[3][0]};
|
||||
int y_collect[4] = {ocr_result[j].box[0][1], ocr_result[j].box[1][1],
|
||||
ocr_result[j].box[2][1], ocr_result[j].box[3][1]};
|
||||
int left = int(*std::min_element(x_collect, x_collect + 4));
|
||||
int right = int(*std::max_element(x_collect, x_collect + 4));
|
||||
int top = int(*std::min_element(y_collect, y_collect + 4));
|
||||
int bottom = int(*std::max_element(y_collect, y_collect + 4));
|
||||
std::vector<int> box{max(0, left - expand_pixel),
|
||||
max(0, top - expand_pixel),
|
||||
min(img_list[i].cols, right + expand_pixel),
|
||||
min(img_list[i].rows, bottom + expand_pixel)};
|
||||
cv::Mat crop_img = Utility::crop_image(img_list[i], box);
|
||||
rec_img_list.push_back(crop_img);
|
||||
}
|
||||
// rec
|
||||
this->rec(rec_img_list, ocr_result, time_info_rec);
|
||||
// rebuild table
|
||||
html = this->rebuild_table(structure_html_tags[i], structure_boxes[i],
|
||||
ocr_result);
|
||||
structure_result.html = html;
|
||||
structure_result.html_score = structure_scores[i];
|
||||
}
|
||||
};
|
||||
|
||||
std::string PaddleStructure::rebuild_table(
|
||||
std::vector<std::string> structure_html_tags,
|
||||
std::vector<std::vector<std::vector<int>>> structure_boxes,
|
||||
std::vector<OCRPredictResult> &ocr_result) {
|
||||
// match text in same cell
|
||||
std::vector<std::vector<string>> matched(structure_boxes.size(),
|
||||
std::vector<std::string>());
|
||||
|
||||
for (int i = 0; i < ocr_result.size(); i++) {
|
||||
std::vector<std::vector<float>> dis_list(structure_boxes.size(),
|
||||
std::vector<float>(3, 100000.0));
|
||||
for (int j = 0; j < structure_boxes.size(); j++) {
|
||||
int x_collect[4] = {ocr_result[i].box[0][0], ocr_result[i].box[1][0],
|
||||
ocr_result[i].box[2][0], ocr_result[i].box[3][0]};
|
||||
int y_collect[4] = {ocr_result[i].box[0][1], ocr_result[i].box[1][1],
|
||||
ocr_result[i].box[2][1], ocr_result[i].box[3][1]};
|
||||
int left = int(*std::min_element(x_collect, x_collect + 4));
|
||||
int right = int(*std::max_element(x_collect, x_collect + 4));
|
||||
int top = int(*std::min_element(y_collect, y_collect + 4));
|
||||
int bottom = int(*std::max_element(y_collect, y_collect + 4));
|
||||
std::vector<std::vector<int>> box(2, std::vector<int>(2, 0));
|
||||
box[0][0] = left - 1;
|
||||
box[0][1] = top - 1;
|
||||
box[1][0] = right + 1;
|
||||
box[1][1] = bottom + 1;
|
||||
|
||||
dis_list[j][0] = this->dis(box, structure_boxes[j]);
|
||||
dis_list[j][1] = 1 - this->iou(box, structure_boxes[j]);
|
||||
dis_list[j][2] = j;
|
||||
}
|
||||
// find min dis idx
|
||||
std::sort(dis_list.begin(), dis_list.end(),
|
||||
PaddleStructure::comparison_dis);
|
||||
matched[dis_list[0][2]].push_back(ocr_result[i].text);
|
||||
}
|
||||
// get pred html
|
||||
std::string html_str = "";
|
||||
int td_tag_idx = 0;
|
||||
for (int i = 0; i < structure_html_tags.size(); i++) {
|
||||
if (structure_html_tags[i].find("</td>") != std::string::npos) {
|
||||
if (structure_html_tags[i].find("<td></td>") != std::string::npos) {
|
||||
html_str += "<td>";
|
||||
}
|
||||
if (matched[td_tag_idx].size() > 0) {
|
||||
bool b_with = false;
|
||||
if (matched[td_tag_idx][0].find("<b>") != std::string::npos &&
|
||||
matched[td_tag_idx].size() > 1) {
|
||||
b_with = true;
|
||||
html_str += "<b>";
|
||||
}
|
||||
for (int j = 0; j < matched[td_tag_idx].size(); j++) {
|
||||
std::string content = matched[td_tag_idx][j];
|
||||
if (matched[td_tag_idx].size() > 1) {
|
||||
// remove blank, <b> and </b>
|
||||
if (content.length() > 0 && content.at(0) == ' ') {
|
||||
content = content.substr(0);
|
||||
}
|
||||
if (content.length() > 2 && content.substr(0, 3) == "<b>") {
|
||||
content = content.substr(3);
|
||||
}
|
||||
if (content.length() > 4 &&
|
||||
content.substr(content.length() - 4) == "</b>") {
|
||||
content = content.substr(0, content.length() - 4);
|
||||
}
|
||||
if (content.empty()) {
|
||||
continue;
|
||||
}
|
||||
// add blank
|
||||
if (j != matched[td_tag_idx].size() - 1 &&
|
||||
content.at(content.length() - 1) != ' ') {
|
||||
content += ' ';
|
||||
}
|
||||
}
|
||||
html_str += content;
|
||||
}
|
||||
if (b_with) {
|
||||
html_str += "</b>";
|
||||
}
|
||||
}
|
||||
if (structure_html_tags[i].find("<td></td>") != std::string::npos) {
|
||||
html_str += "</td>";
|
||||
} else {
|
||||
html_str += structure_html_tags[i];
|
||||
}
|
||||
td_tag_idx += 1;
|
||||
} else {
|
||||
html_str += structure_html_tags[i];
|
||||
}
|
||||
}
|
||||
return html_str;
|
||||
}
|
||||
|
||||
float PaddleStructure::iou(std::vector<std::vector<int>> &box1,
|
||||
std::vector<std::vector<int>> &box2) {
|
||||
int area1 = max(0, box1[1][0] - box1[0][0]) * max(0, box1[1][1] - box1[0][1]);
|
||||
int area2 = max(0, box2[1][0] - box2[0][0]) * max(0, box2[1][1] - box2[0][1]);
|
||||
|
||||
// computing the sum_area
|
||||
int sum_area = area1 + area2;
|
||||
|
||||
// find the each point of intersect rectangle
|
||||
int x1 = max(box1[0][0], box2[0][0]);
|
||||
int y1 = max(box1[0][1], box2[0][1]);
|
||||
int x2 = min(box1[1][0], box2[1][0]);
|
||||
int y2 = min(box1[1][1], box2[1][1]);
|
||||
|
||||
// judge if there is an intersect
|
||||
if (y1 >= y2 || x1 >= x2) {
|
||||
return 0.0;
|
||||
} else {
|
||||
int intersect = (x2 - x1) * (y2 - y1);
|
||||
return intersect / (sum_area - intersect + 0.00000001);
|
||||
}
|
||||
}
|
||||
|
||||
float PaddleStructure::dis(std::vector<std::vector<int>> &box1,
|
||||
std::vector<std::vector<int>> &box2) {
|
||||
int x1_1 = box1[0][0];
|
||||
int y1_1 = box1[0][1];
|
||||
int x2_1 = box1[1][0];
|
||||
int y2_1 = box1[1][1];
|
||||
|
||||
int x1_2 = box2[0][0];
|
||||
int y1_2 = box2[0][1];
|
||||
int x2_2 = box2[1][0];
|
||||
int y2_2 = box2[1][1];
|
||||
|
||||
float dis =
|
||||
abs(x1_2 - x1_1) + abs(y1_2 - y1_1) + abs(x2_2 - x2_1) + abs(y2_2 - y2_1);
|
||||
float dis_2 = abs(x1_2 - x1_1) + abs(y1_2 - y1_1);
|
||||
float dis_3 = abs(x2_2 - x2_1) + abs(y2_2 - y2_1);
|
||||
return dis + min(dis_2, dis_3);
|
||||
}
|
||||
|
||||
PaddleStructure::~PaddleStructure() {
|
||||
if (this->recognizer_ != nullptr) {
|
||||
delete this->recognizer_;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace PaddleOCR
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
namespace PaddleOCR {
|
||||
|
||||
void PostProcessor::GetContourArea(const std::vector<std::vector<float>> &box,
|
||||
void DBPostProcessor::GetContourArea(const std::vector<std::vector<float>> &box,
|
||||
float unclip_ratio, float &distance) {
|
||||
int pts_num = 4;
|
||||
float area = 0.0f;
|
||||
|
@ -35,7 +35,7 @@ void PostProcessor::GetContourArea(const std::vector<std::vector<float>> &box,
|
|||
distance = area * unclip_ratio / dist;
|
||||
}
|
||||
|
||||
cv::RotatedRect PostProcessor::UnClip(std::vector<std::vector<float>> box,
|
||||
cv::RotatedRect DBPostProcessor::UnClip(std::vector<std::vector<float>> box,
|
||||
const float &unclip_ratio) {
|
||||
float distance = 1.0;
|
||||
|
||||
|
@ -67,7 +67,7 @@ cv::RotatedRect PostProcessor::UnClip(std::vector<std::vector<float>> box,
|
|||
return res;
|
||||
}
|
||||
|
||||
float **PostProcessor::Mat2Vec(cv::Mat mat) {
|
||||
float **DBPostProcessor::Mat2Vec(cv::Mat mat) {
|
||||
auto **array = new float *[mat.rows];
|
||||
for (int i = 0; i < mat.rows; ++i)
|
||||
array[i] = new float[mat.cols];
|
||||
|
@ -81,7 +81,7 @@ float **PostProcessor::Mat2Vec(cv::Mat mat) {
|
|||
}
|
||||
|
||||
std::vector<std::vector<int>>
|
||||
PostProcessor::OrderPointsClockwise(std::vector<std::vector<int>> pts) {
|
||||
DBPostProcessor::OrderPointsClockwise(std::vector<std::vector<int>> pts) {
|
||||
std::vector<std::vector<int>> box = pts;
|
||||
std::sort(box.begin(), box.end(), XsortInt);
|
||||
|
||||
|
@ -99,7 +99,7 @@ PostProcessor::OrderPointsClockwise(std::vector<std::vector<int>> pts) {
|
|||
return rect;
|
||||
}
|
||||
|
||||
std::vector<std::vector<float>> PostProcessor::Mat2Vector(cv::Mat mat) {
|
||||
std::vector<std::vector<float>> DBPostProcessor::Mat2Vector(cv::Mat mat) {
|
||||
std::vector<std::vector<float>> img_vec;
|
||||
std::vector<float> tmp;
|
||||
|
||||
|
@ -113,20 +113,20 @@ std::vector<std::vector<float>> PostProcessor::Mat2Vector(cv::Mat mat) {
|
|||
return img_vec;
|
||||
}
|
||||
|
||||
bool PostProcessor::XsortFp32(std::vector<float> a, std::vector<float> b) {
|
||||
bool DBPostProcessor::XsortFp32(std::vector<float> a, std::vector<float> b) {
|
||||
if (a[0] != b[0])
|
||||
return a[0] < b[0];
|
||||
return false;
|
||||
}
|
||||
|
||||
bool PostProcessor::XsortInt(std::vector<int> a, std::vector<int> b) {
|
||||
bool DBPostProcessor::XsortInt(std::vector<int> a, std::vector<int> b) {
|
||||
if (a[0] != b[0])
|
||||
return a[0] < b[0];
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::vector<float>> PostProcessor::GetMiniBoxes(cv::RotatedRect box,
|
||||
float &ssid) {
|
||||
std::vector<std::vector<float>>
|
||||
DBPostProcessor::GetMiniBoxes(cv::RotatedRect box, float &ssid) {
|
||||
ssid = std::max(box.size.width, box.size.height);
|
||||
|
||||
cv::Mat points;
|
||||
|
@ -160,7 +160,7 @@ std::vector<std::vector<float>> PostProcessor::GetMiniBoxes(cv::RotatedRect box,
|
|||
return array;
|
||||
}
|
||||
|
||||
float PostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
|
||||
float DBPostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
|
||||
cv::Mat pred) {
|
||||
int width = pred.cols;
|
||||
int height = pred.rows;
|
||||
|
@ -206,7 +206,7 @@ float PostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
|
|||
return score;
|
||||
}
|
||||
|
||||
float PostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
|
||||
float DBPostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
|
||||
cv::Mat pred) {
|
||||
auto array = box_array;
|
||||
int width = pred.cols;
|
||||
|
@ -244,7 +244,7 @@ float PostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
|
|||
return score;
|
||||
}
|
||||
|
||||
std::vector<std::vector<std::vector<int>>> PostProcessor::BoxesFromBitmap(
|
||||
std::vector<std::vector<std::vector<int>>> DBPostProcessor::BoxesFromBitmap(
|
||||
const cv::Mat pred, const cv::Mat bitmap, const float &box_thresh,
|
||||
const float &det_db_unclip_ratio, const std::string &det_db_score_mode) {
|
||||
const int min_size = 3;
|
||||
|
@ -321,9 +321,9 @@ std::vector<std::vector<std::vector<int>>> PostProcessor::BoxesFromBitmap(
|
|||
return boxes;
|
||||
}
|
||||
|
||||
std::vector<std::vector<std::vector<int>>>
|
||||
PostProcessor::FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
|
||||
float ratio_h, float ratio_w, cv::Mat srcimg) {
|
||||
std::vector<std::vector<std::vector<int>>> DBPostProcessor::FilterTagDetRes(
|
||||
std::vector<std::vector<std::vector<int>>> boxes, float ratio_h,
|
||||
float ratio_w, cv::Mat srcimg) {
|
||||
int oriimg_h = srcimg.rows;
|
||||
int oriimg_w = srcimg.cols;
|
||||
|
||||
|
@ -352,4 +352,77 @@ PostProcessor::FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
|
|||
return root_points;
|
||||
}
|
||||
|
||||
void TablePostProcessor::init(std::string label_path) {
|
||||
this->label_list_ = Utility::ReadDict(label_path);
|
||||
this->label_list_.insert(this->label_list_.begin(), this->beg);
|
||||
this->label_list_.push_back(this->end);
|
||||
}
|
||||
|
||||
void TablePostProcessor::Run(
|
||||
std::vector<float> &loc_preds, std::vector<float> &structure_probs,
|
||||
std::vector<float> &rec_scores, std::vector<int> &loc_preds_shape,
|
||||
std::vector<int> &structure_probs_shape,
|
||||
std::vector<std::vector<std::string>> &rec_html_tag_batch,
|
||||
std::vector<std::vector<std::vector<std::vector<int>>>> &rec_boxes_batch,
|
||||
std::vector<int> &width_list, std::vector<int> &height_list) {
|
||||
for (int batch_idx = 0; batch_idx < structure_probs_shape[0]; batch_idx++) {
|
||||
// image tags and boxs
|
||||
std::vector<std::string> rec_html_tags;
|
||||
std::vector<std::vector<std::vector<int>>> rec_boxes;
|
||||
|
||||
float score = 0.f;
|
||||
int count = 0;
|
||||
float char_score = 0.f;
|
||||
int char_idx = 0;
|
||||
|
||||
// step
|
||||
for (int step_idx = 0; step_idx < structure_probs_shape[1]; step_idx++) {
|
||||
std::string html_tag;
|
||||
std::vector<std::vector<int>> rec_box;
|
||||
// html tag
|
||||
int step_start_idx = (batch_idx * structure_probs_shape[1] + step_idx) *
|
||||
structure_probs_shape[2];
|
||||
char_idx = int(Utility::argmax(
|
||||
&structure_probs[step_start_idx],
|
||||
&structure_probs[step_start_idx + structure_probs_shape[2]]));
|
||||
char_score = float(*std::max_element(
|
||||
&structure_probs[step_start_idx],
|
||||
&structure_probs[step_start_idx + structure_probs_shape[2]]));
|
||||
html_tag = this->label_list_[char_idx];
|
||||
|
||||
if (step_idx > 0 && html_tag == this->end) {
|
||||
break;
|
||||
}
|
||||
if (html_tag == this->beg) {
|
||||
continue;
|
||||
}
|
||||
count += 1;
|
||||
score += char_score;
|
||||
rec_html_tags.push_back(html_tag);
|
||||
// box
|
||||
if (html_tag == "<td>" || html_tag == "<td") {
|
||||
for (int point_idx = 0; point_idx < loc_preds_shape[2];
|
||||
point_idx += 2) {
|
||||
std::vector<int> point(2, 0);
|
||||
step_start_idx = (batch_idx * structure_probs_shape[1] + step_idx) *
|
||||
loc_preds_shape[2] +
|
||||
point_idx;
|
||||
point[0] = int(loc_preds[step_start_idx] * width_list[batch_idx]);
|
||||
point[1] =
|
||||
int(loc_preds[step_start_idx + 1] * height_list[batch_idx]);
|
||||
rec_box.push_back(point);
|
||||
}
|
||||
rec_boxes.push_back(rec_box);
|
||||
}
|
||||
}
|
||||
score /= count;
|
||||
if (isnan(score) || rec_boxes.size() == 0 || rec_html_tags.size() == 0) {
|
||||
score = -1;
|
||||
}
|
||||
rec_scores.push_back(score);
|
||||
rec_boxes_batch.push_back(rec_boxes);
|
||||
rec_html_tag_batch.push_back(rec_html_tags);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace PaddleOCR
|
||||
|
|
|
@ -69,18 +69,28 @@ void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
|
|||
}
|
||||
|
||||
void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
|
||||
int max_size_len, float &ratio_h, float &ratio_w,
|
||||
bool use_tensorrt) {
|
||||
string limit_type, int limit_side_len, float &ratio_h,
|
||||
float &ratio_w, bool use_tensorrt) {
|
||||
int w = img.cols;
|
||||
int h = img.rows;
|
||||
|
||||
float ratio = 1.f;
|
||||
int max_wh = w >= h ? w : h;
|
||||
if (max_wh > max_size_len) {
|
||||
if (h > w) {
|
||||
ratio = float(max_size_len) / float(h);
|
||||
if (limit_type == "min") {
|
||||
int min_wh = min(h, w);
|
||||
if (min_wh < limit_side_len) {
|
||||
if (h < w) {
|
||||
ratio = float(limit_side_len) / float(h);
|
||||
} else {
|
||||
ratio = float(max_size_len) / float(w);
|
||||
ratio = float(limit_side_len) / float(w);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
int max_wh = max(h, w);
|
||||
if (max_wh > limit_side_len) {
|
||||
if (h > w) {
|
||||
ratio = float(limit_side_len) / float(h);
|
||||
} else {
|
||||
ratio = float(limit_side_len) / float(w);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -143,4 +153,26 @@ void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
|
|||
}
|
||||
}
|
||||
|
||||
void TableResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
|
||||
const int max_len) {
|
||||
int w = img.cols;
|
||||
int h = img.rows;
|
||||
|
||||
int max_wh = w >= h ? w : h;
|
||||
float ratio = w >= h ? float(max_len) / float(w) : float(max_len) / float(h);
|
||||
|
||||
int resize_h = int(float(h) * ratio);
|
||||
int resize_w = int(float(w) * ratio);
|
||||
|
||||
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
|
||||
}
|
||||
|
||||
void TablePadImg::Run(const cv::Mat &img, cv::Mat &resize_img,
|
||||
const int max_len) {
|
||||
int w = img.cols;
|
||||
int h = img.rows;
|
||||
cv::copyMakeBorder(img, resize_img, 0, max_len - h, 0, max_len - w,
|
||||
cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
|
||||
}
|
||||
|
||||
} // namespace PaddleOCR
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <include/structure_table.h>
|
||||
|
||||
namespace PaddleOCR {
|
||||
|
||||
void StructureTableRecognizer::Run(
|
||||
std::vector<cv::Mat> img_list,
|
||||
std::vector<std::vector<std::string>> &structure_html_tags,
|
||||
std::vector<float> &structure_scores,
|
||||
std::vector<std::vector<std::vector<std::vector<int>>>> &structure_boxes,
|
||||
std::vector<double> ×) {
|
||||
std::chrono::duration<float> preprocess_diff =
|
||||
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
|
||||
std::chrono::duration<float> inference_diff =
|
||||
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
|
||||
std::chrono::duration<float> postprocess_diff =
|
||||
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
|
||||
|
||||
int img_num = img_list.size();
|
||||
for (int beg_img_no = 0; beg_img_no < img_num;
|
||||
beg_img_no += this->table_batch_num_) {
|
||||
// preprocess
|
||||
auto preprocess_start = std::chrono::steady_clock::now();
|
||||
int end_img_no = min(img_num, beg_img_no + this->table_batch_num_);
|
||||
int batch_num = end_img_no - beg_img_no;
|
||||
std::vector<cv::Mat> norm_img_batch;
|
||||
std::vector<int> width_list;
|
||||
std::vector<int> height_list;
|
||||
for (int ino = beg_img_no; ino < end_img_no; ino++) {
|
||||
cv::Mat srcimg;
|
||||
img_list[ino].copyTo(srcimg);
|
||||
cv::Mat resize_img;
|
||||
cv::Mat pad_img;
|
||||
this->resize_op_.Run(srcimg, resize_img, this->table_max_len_);
|
||||
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
|
||||
this->is_scale_);
|
||||
this->pad_op_.Run(resize_img, pad_img, this->table_max_len_);
|
||||
norm_img_batch.push_back(pad_img);
|
||||
width_list.push_back(srcimg.cols);
|
||||
height_list.push_back(srcimg.rows);
|
||||
}
|
||||
|
||||
std::vector<float> input(
|
||||
batch_num * 3 * this->table_max_len_ * this->table_max_len_, 0.0f);
|
||||
this->permute_op_.Run(norm_img_batch, input.data());
|
||||
auto preprocess_end = std::chrono::steady_clock::now();
|
||||
preprocess_diff += preprocess_end - preprocess_start;
|
||||
// inference.
|
||||
auto input_names = this->predictor_->GetInputNames();
|
||||
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
|
||||
input_t->Reshape(
|
||||
{batch_num, 3, this->table_max_len_, this->table_max_len_});
|
||||
auto inference_start = std::chrono::steady_clock::now();
|
||||
input_t->CopyFromCpu(input.data());
|
||||
this->predictor_->Run();
|
||||
auto output_names = this->predictor_->GetOutputNames();
|
||||
auto output_tensor0 = this->predictor_->GetOutputHandle(output_names[0]);
|
||||
auto output_tensor1 = this->predictor_->GetOutputHandle(output_names[1]);
|
||||
std::vector<int> predict_shape0 = output_tensor0->shape();
|
||||
std::vector<int> predict_shape1 = output_tensor1->shape();
|
||||
|
||||
int out_num0 = std::accumulate(predict_shape0.begin(), predict_shape0.end(),
|
||||
1, std::multiplies<int>());
|
||||
int out_num1 = std::accumulate(predict_shape1.begin(), predict_shape1.end(),
|
||||
1, std::multiplies<int>());
|
||||
std::vector<float> loc_preds;
|
||||
std::vector<float> structure_probs;
|
||||
loc_preds.resize(out_num0);
|
||||
structure_probs.resize(out_num1);
|
||||
|
||||
output_tensor0->CopyToCpu(loc_preds.data());
|
||||
output_tensor1->CopyToCpu(structure_probs.data());
|
||||
auto inference_end = std::chrono::steady_clock::now();
|
||||
inference_diff += inference_end - inference_start;
|
||||
// postprocess
|
||||
auto postprocess_start = std::chrono::steady_clock::now();
|
||||
std::vector<std::vector<std::string>> structure_html_tag_batch;
|
||||
std::vector<float> structure_score_batch;
|
||||
std::vector<std::vector<std::vector<std::vector<int>>>>
|
||||
structure_boxes_batch;
|
||||
this->post_processor_.Run(loc_preds, structure_probs, structure_score_batch,
|
||||
predict_shape0, predict_shape1,
|
||||
structure_html_tag_batch, structure_boxes_batch,
|
||||
width_list, height_list);
|
||||
for (int m = 0; m < predict_shape0[0]; m++) {
|
||||
|
||||
structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(),
|
||||
"<table>");
|
||||
structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(),
|
||||
"<body>");
|
||||
structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(),
|
||||
"<html>");
|
||||
structure_html_tag_batch[m].push_back("</table>");
|
||||
structure_html_tag_batch[m].push_back("</body>");
|
||||
structure_html_tag_batch[m].push_back("</html>");
|
||||
structure_html_tags.push_back(structure_html_tag_batch[m]);
|
||||
structure_scores.push_back(structure_score_batch[m]);
|
||||
structure_boxes.push_back(structure_boxes_batch[m]);
|
||||
}
|
||||
auto postprocess_end = std::chrono::steady_clock::now();
|
||||
postprocess_diff += postprocess_end - postprocess_start;
|
||||
times.push_back(double(preprocess_diff.count() * 1000));
|
||||
times.push_back(double(inference_diff.count() * 1000));
|
||||
times.push_back(double(postprocess_diff.count() * 1000));
|
||||
}
|
||||
}
|
||||
|
||||
void StructureTableRecognizer::LoadModel(const std::string &model_dir) {
|
||||
AnalysisConfig config;
|
||||
config.SetModel(model_dir + "/inference.pdmodel",
|
||||
model_dir + "/inference.pdiparams");
|
||||
|
||||
if (this->use_gpu_) {
|
||||
config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
|
||||
if (this->use_tensorrt_) {
|
||||
auto precision = paddle_infer::Config::Precision::kFloat32;
|
||||
if (this->precision_ == "fp16") {
|
||||
precision = paddle_infer::Config::Precision::kHalf;
|
||||
}
|
||||
if (this->precision_ == "int8") {
|
||||
precision = paddle_infer::Config::Precision::kInt8;
|
||||
}
|
||||
config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false);
|
||||
}
|
||||
} else {
|
||||
config.DisableGpu();
|
||||
if (this->use_mkldnn_) {
|
||||
config.EnableMKLDNN();
|
||||
}
|
||||
config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
|
||||
}
|
||||
|
||||
// false for zero copy tensor
|
||||
config.SwitchUseFeedFetchOps(false);
|
||||
// true for multiple input
|
||||
config.SwitchSpecifyInputNames(true);
|
||||
|
||||
config.SwitchIrOptim(true);
|
||||
|
||||
config.EnableMemoryOptim();
|
||||
config.DisableGlogInfo();
|
||||
|
||||
this->predictor_ = CreatePredictor(config);
|
||||
}
|
||||
} // namespace PaddleOCR
|
|
@ -248,4 +248,33 @@ void Utility::print_result(const std::vector<OCRPredictResult> &ocr_result) {
|
|||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
cv::Mat Utility::crop_image(cv::Mat &img, std::vector<int> &area) {
|
||||
cv::Mat crop_im;
|
||||
int crop_x1 = std::max(0, area[0]);
|
||||
int crop_y1 = std::max(0, area[1]);
|
||||
int crop_x2 = std::min(img.cols - 1, area[2] - 1);
|
||||
int crop_y2 = std::min(img.rows - 1, area[3] - 1);
|
||||
|
||||
crop_im = cv::Mat::zeros(area[3] - area[1], area[2] - area[0], 16);
|
||||
cv::Mat crop_im_window =
|
||||
crop_im(cv::Range(crop_y1 - area[1], crop_y2 + 1 - area[1]),
|
||||
cv::Range(crop_x1 - area[0], crop_x2 + 1 - area[0]));
|
||||
cv::Mat roi_img =
|
||||
img(cv::Range(crop_y1, crop_y2 + 1), cv::Range(crop_x1, crop_x2 + 1));
|
||||
crop_im_window += roi_img;
|
||||
return crop_im;
|
||||
}
|
||||
|
||||
void Utility::sorted_boxes(std::vector<OCRPredictResult> &ocr_result) {
|
||||
std::sort(ocr_result.begin(), ocr_result.end(), Utility::comparison_box);
|
||||
|
||||
for (int i = 0; i < ocr_result.size() - 1; i++) {
|
||||
if (abs(ocr_result[i + 1].box[0][1] - ocr_result[i].box[0][1]) < 10 &&
|
||||
(ocr_result[i + 1].box[0][0] < ocr_result[i].box[0][0])) {
|
||||
std::swap(ocr_result[i], ocr_result[i + 1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace PaddleOCR
|
|
@ -62,7 +62,7 @@ class TableMatch:
|
|||
def __call__(self, structure_res, dt_boxes, rec_res):
|
||||
pred_structures, pred_bboxes = structure_res
|
||||
if self.filter_ocr_result:
|
||||
dt_boxes, rec_res = self.filter_ocr_result(pred_bboxes, dt_boxes,
|
||||
dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes,
|
||||
rec_res)
|
||||
matched_index = self.match_result(dt_boxes, pred_bboxes)
|
||||
if self.use_master:
|
||||
|
@ -179,7 +179,7 @@ class TableMatch:
|
|||
html = deal_bb(html)
|
||||
return html, end_html
|
||||
|
||||
def filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
|
||||
def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
|
||||
y1 = pred_bboxes[:, 1::2].min()
|
||||
new_dt_boxes = []
|
||||
new_rec_res = []
|
||||
|
|
|
@ -70,7 +70,7 @@ class TableSystem(object):
|
|||
if args.table_algorithm in ['TableMaster']:
|
||||
self.match = TableMasterMatcher()
|
||||
else:
|
||||
self.match = TableMatch()
|
||||
self.match = TableMatch(filter_ocr_result=True)
|
||||
|
||||
self.benchmark = args.benchmark
|
||||
self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
|
||||
|
|
Loading…
Reference in New Issue