commit e10128ac4c
@@ -13,7 +13,6 @@ SET(TENSORRT_DIR "" CACHE PATH "Compile demo with TensorRT")
 set(DEMO_NAME "ocr_system")
 
-
 macro(safe_set_static_flag)
   foreach(flag_var
       CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
 
@@ -44,6 +44,9 @@ public:
   inline static size_t argmax(ForwardIterator first, ForwardIterator last) {
     return std::distance(first, std::max_element(first, last));
   }
+
+  static void GetAllFiles(const char *dir_name,
+                          std::vector<std::string> &all_inputs);
 };
 
 } // namespace PaddleOCR

@@ -0,0 +1 @@
+/paddle/test/PaddleOCR/deploy/cpp_infer/inference

@@ -228,10 +228,7 @@ char_list_file ../../ppocr/utils/ppocr_keys_v1.txt # dictionary file
 visualize 1 # Whether to visualize the results. When set to 1, the prediction result is saved as `ocr_vis.png` in the current folder.
 ```
 
-* PaddleOCR also supports multi-language prediction. For more supported languages and models, refer to the multi-language dictionaries and models section of the [recognition tutorial](../../doc/doc_ch/recognition.md).
-To run multi-language prediction, simply modify the `char_list_file` (dictionary file path) and `rec_model_dir` (inference model path) fields in `tools/config.txt`.
-
-The correspondence between multi-language models and dictionary files is given in the [documentation](../../doc/doc_ch/multi_languages.md#预测部署).
+* PaddleOCR also supports multi-language prediction. For more supported languages and models, refer to the multi-language dictionaries and models section of the [recognition tutorial](../../doc/doc_ch/recognition.md). To run multi-language prediction, simply modify the `char_list_file` (dictionary file path) and `rec_model_dir` (inference model path) fields in `tools/config.txt`.
 
 The detection results are finally printed to the screen, as shown below.

@@ -78,7 +78,6 @@ opencv3/
 
 #### 1.2.1 Direct download and installation
 
 * Different CUDA versions of the Linux inference library (based on GCC 4.8.2) are provided on the
 [Paddle inference library official website](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0/guides/05_inference_deployment/inference/build_and_install_lib_cn.html). You can view and select the appropriate version of the inference library on the official website.
 
-

@@ -92,7 +91,7 @@ Finally you can see the following files in the folder of `paddle_inference/`.
 
 #### 1.2.2 Compile from the source code
 * If you want to get the latest Paddle inference library features, you can download the latest code from the Paddle GitHub repository and compile the inference library from source. It is recommended to use an inference library with Paddle version greater than or equal to 2.0.1.
-* You can refer to the [Paddle inference library](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0/guides/05_inference_deployment/inference/build_and_install_lib_cn.html#congyuanmabianyi) to get the Paddle source code from GitHub, and then compile it to generate the latest inference library. The method of using git to access the code is as follows.
+* You can refer to the [Paddle inference library](https://www.paddlepaddle.org.cn/documentation/docs/en/advanced_guide/inference_deployment/inference/build_and_install_lib_en.html) to get the Paddle source code from GitHub, and then compile it to generate the latest inference library. The method of using git to access the code is as follows.
 
 
 ```shell

@@ -100,7 +99,7 @@ git clone https://github.com/PaddlePaddle/Paddle.git
 git checkout release/2.1
 ```
 
-* After entering the Paddle directory, the compilation method is as follows.
+* After entering the Paddle directory, the commands to compile the Paddle inference library are as follows.
 
 ```shell
 rm -rf build

@@ -235,7 +234,6 @@ visualize 1 # Whether to visualize the results,when it is set as 1, The predic
 ```
 
 * Multi-language inference is also supported in PaddleOCR. You can refer to the [recognition tutorial](../../doc/doc_en/recognition_en.md) for more supported languages and models in PaddleOCR. Specifically, if you want to infer using multi-language models, you just need to modify the values of `char_list_file` and `rec_model_dir` in the file `tools/config.txt`.
-The corresponding relationship between the multi-language model and the dictionary file can be found in the [document](../../doc/doc_en/multi_languages_en.md#inference).
 
 
 The detection results will be shown on the screen, which is as follows.

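(Editorial note, not part of the commit: as a concrete example of the two-field change described above, switching the recognizer to the French model would mean setting `rec_model_dir` to something like `./inference/french_mobile_v2.0_rec_infer/` and `char_list_file` to `../../ppocr/utils/dict/french_dict.txt`; the exact model directory name depends on the release you download, so treat these paths as illustrative.)
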
@@ -668,7 +668,7 @@ void DisposeOutPts(OutPt *&pp) {
 //------------------------------------------------------------------------------
 
 inline void InitEdge(TEdge *e, TEdge *eNext, TEdge *ePrev, const IntPoint &Pt) {
-  std::memset(e, 0, sizeof(TEdge));
+  std::memset(e, int(0), sizeof(TEdge));
   e->Next = eNext;
   e->Prev = ePrev;
   e->Curr = Pt;

@@ -1895,17 +1895,17 @@ void Clipper::InsertLocalMinimaIntoAEL(const cInt botY) {
     TEdge *rb = lm->RightBound;
 
     OutPt *Op1 = 0;
-    if (!lb) {
+    if (!lb || !rb) {
       // nb: don't insert LB into either AEL or SEL
       InsertEdgeIntoAEL(rb, 0);
       SetWindingCount(*rb);
       if (IsContributing(*rb))
         Op1 = AddOutPt(rb, rb->Bot);
-    } else if (!rb) {
-      InsertEdgeIntoAEL(lb, 0);
-      SetWindingCount(*lb);
-      if (IsContributing(*lb))
-        Op1 = AddOutPt(lb, lb->Bot);
+      //} else if (!rb) {
+      //  InsertEdgeIntoAEL(lb, 0);
+      //  SetWindingCount(*lb);
+      //  if (IsContributing(*lb))
+      //    Op1 = AddOutPt(lb, lb->Bot);
       InsertScanbeam(lb->Top.Y);
     } else {
       InsertEdgeIntoAEL(lb, 0);

@@ -2547,13 +2547,13 @@ void Clipper::ProcessHorizontal(TEdge *horzEdge) {
     if (dir == dLeftToRight) {
       maxIt = m_Maxima.begin();
       while (maxIt != m_Maxima.end() && *maxIt <= horzEdge->Bot.X)
-        maxIt++;
+        ++maxIt;
       if (maxIt != m_Maxima.end() && *maxIt >= eLastHorz->Top.X)
         maxIt = m_Maxima.end();
     } else {
       maxRit = m_Maxima.rbegin();
       while (maxRit != m_Maxima.rend() && *maxRit > horzEdge->Bot.X)
-        maxRit++;
+        ++maxRit;
       if (maxRit != m_Maxima.rend() && *maxRit <= eLastHorz->Top.X)
         maxRit = m_Maxima.rend();
     }

@@ -2576,13 +2576,13 @@ void Clipper::ProcessHorizontal(TEdge *horzEdge) {
         while (maxIt != m_Maxima.end() && *maxIt < e->Curr.X) {
           if (horzEdge->OutIdx >= 0 && !IsOpen)
             AddOutPt(horzEdge, IntPoint(*maxIt, horzEdge->Bot.Y));
-          maxIt++;
+          ++maxIt;
         }
       } else {
         while (maxRit != m_Maxima.rend() && *maxRit > e->Curr.X) {
           if (horzEdge->OutIdx >= 0 && !IsOpen)
             AddOutPt(horzEdge, IntPoint(*maxRit, horzEdge->Bot.Y));
-          maxRit++;
+          ++maxRit;
         }
       }
     };

@@ -21,10 +21,10 @@ std::vector<std::string> OCRConfig::split(const std::string &str,
   std::vector<std::string> res;
   if ("" == str)
     return res;
-  char *strs = new char[str.length() + 1];
+  char strs[str.length() + 1];
   std::strcpy(strs, str.c_str());
 
-  char *d = new char[delim.length() + 1];
+  char d[delim.length() + 1];
   std::strcpy(d, delim.c_str());
 
   char *p = std::strtok(strs, d);

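(Editorial note, not part of the commit: the replaced `new char[...]` buffers were never released with `delete[]`, so moving to stack buffers in `split` also removes a small per-call memory leak. Variable-length arrays such as `char strs[str.length() + 1]` are a GCC/Clang extension rather than standard C++, which matters only if this file is ever built with a stricter compiler.)
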
@@ -61,4 +61,4 @@ void OCRConfig::PrintConfigInfo() {
   std::cout << "=======End of Paddle OCR inference config======" << std::endl;
 }
 
-} // namespace PaddleOCR
\ No newline at end of file
+} // namespace PaddleOCR

@@ -27,9 +27,12 @@
 #include <fstream>
 #include <numeric>
 
+#include <glog/logging.h>
 #include <include/config.h>
 #include <include/ocr_det.h>
 #include <include/ocr_rec.h>
+#include <include/utility.h>
+#include <sys/stat.h>
 
 using namespace std;
 using namespace cv;

@@ -47,13 +50,8 @@ int main(int argc, char **argv) {
   config.PrintConfigInfo();
 
   std::string img_path(argv[2]);
-
-  cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
-
-  if (!srcimg.data) {
-    std::cerr << "[ERROR] image read failed! image path: " << img_path << "\n";
-    exit(1);
-  }
+  std::vector<std::string> all_img_names;
+  Utility::GetAllFiles((char *)img_path.c_str(), all_img_names);
 
   DBDetector det(config.det_model_dir, config.use_gpu, config.gpu_id,
                  config.gpu_mem, config.cpu_math_library_num_threads,

@@ -76,18 +74,30 @@ int main(int argc, char **argv) {
                  config.use_tensorrt, config.use_fp16);
 
   auto start = std::chrono::system_clock::now();
-  std::vector<std::vector<std::vector<int>>> boxes;
-  det.Run(srcimg, boxes);
-
-  rec.Run(boxes, srcimg, cls);
-  auto end = std::chrono::system_clock::now();
-  auto duration =
-      std::chrono::duration_cast<std::chrono::microseconds>(end - start);
-  std::cout << "Cost "
-            << double(duration.count()) *
-                   std::chrono::microseconds::period::num /
-                   std::chrono::microseconds::period::den
-            << "s" << std::endl;
+  for (auto img_dir : all_img_names) {
+    LOG(INFO) << "The predict img: " << img_dir;
+
+    cv::Mat srcimg = cv::imread(img_dir, cv::IMREAD_COLOR);
+    if (!srcimg.data) {
+      std::cerr << "[ERROR] image read failed! image path: " << img_path
+                << "\n";
+      exit(1);
+    }
+
+    std::vector<std::vector<std::vector<int>>> boxes;
+
+    det.Run(srcimg, boxes);
+
+    rec.Run(boxes, srcimg, cls);
+    auto end = std::chrono::system_clock::now();
+    auto duration =
+        std::chrono::duration_cast<std::chrono::microseconds>(end - start);
+    std::cout << "Cost "
+              << double(duration.count()) *
+                     std::chrono::microseconds::period::num /
+                     std::chrono::microseconds::period::den
+              << "s" << std::endl;
+  }
 
   return 0;
 }

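One thing worth noticing in the hunk above: `start` is captured once before the loop while `end` is taken inside it, so the printed `Cost` is the cumulative time since before the first image rather than the latency of the current image. A minimal per-image timing sketch (mine, not from the commit; the `sleep_for` stands in for the `det.Run` + `rec.Run` pass) would move `start` inside the loop:

```cpp
#include <chrono>
#include <iostream>
#include <thread>

int main() {
  const int num_images = 3;
  for (int i = 0; i < num_images; ++i) {
    // Reset the clock for every image instead of once before the loop.
    auto start = std::chrono::system_clock::now();
    std::this_thread::sleep_for(std::chrono::milliseconds(50)); // stand-in
    auto end = std::chrono::system_clock::now();
    auto duration =
        std::chrono::duration_cast<std::chrono::microseconds>(end - start);
    std::cout << "Cost " << double(duration.count()) / 1e6 << "s" << std::endl;
  }
  return 0;
}
```
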
@@ -30,6 +30,42 @@ void DBDetector::LoadModel(const std::string &model_dir) {
           this->use_fp16_ ? paddle_infer::Config::Precision::kHalf
                           : paddle_infer::Config::Precision::kFloat32,
           false, false);
+      std::map<std::string, std::vector<int>> min_input_shape = {
+          {"x", {1, 3, 50, 50}},
+          {"conv2d_92.tmp_0", {1, 96, 20, 20}},
+          {"conv2d_91.tmp_0", {1, 96, 10, 10}},
+          {"nearest_interp_v2_1.tmp_0", {1, 96, 10, 10}},
+          {"nearest_interp_v2_2.tmp_0", {1, 96, 20, 20}},
+          {"nearest_interp_v2_3.tmp_0", {1, 24, 20, 20}},
+          {"nearest_interp_v2_4.tmp_0", {1, 24, 20, 20}},
+          {"nearest_interp_v2_5.tmp_0", {1, 24, 20, 20}},
+          {"elementwise_add_7", {1, 56, 2, 2}},
+          {"nearest_interp_v2_0.tmp_0", {1, 96, 2, 2}}};
+      std::map<std::string, std::vector<int>> max_input_shape = {
+          {"x", {1, 3, this->max_side_len_, this->max_side_len_}},
+          {"conv2d_92.tmp_0", {1, 96, 400, 400}},
+          {"conv2d_91.tmp_0", {1, 96, 200, 200}},
+          {"nearest_interp_v2_1.tmp_0", {1, 96, 200, 200}},
+          {"nearest_interp_v2_2.tmp_0", {1, 96, 400, 400}},
+          {"nearest_interp_v2_3.tmp_0", {1, 24, 400, 400}},
+          {"nearest_interp_v2_4.tmp_0", {1, 24, 400, 400}},
+          {"nearest_interp_v2_5.tmp_0", {1, 24, 400, 400}},
+          {"elementwise_add_7", {1, 56, 400, 400}},
+          {"nearest_interp_v2_0.tmp_0", {1, 96, 400, 400}}};
+      std::map<std::string, std::vector<int>> opt_input_shape = {
+          {"x", {1, 3, 640, 640}},
+          {"conv2d_92.tmp_0", {1, 96, 160, 160}},
+          {"conv2d_91.tmp_0", {1, 96, 80, 80}},
+          {"nearest_interp_v2_1.tmp_0", {1, 96, 80, 80}},
+          {"nearest_interp_v2_2.tmp_0", {1, 96, 160, 160}},
+          {"nearest_interp_v2_3.tmp_0", {1, 24, 160, 160}},
+          {"nearest_interp_v2_4.tmp_0", {1, 24, 160, 160}},
+          {"nearest_interp_v2_5.tmp_0", {1, 24, 160, 160}},
+          {"elementwise_add_7", {1, 56, 40, 40}},
+          {"nearest_interp_v2_0.tmp_0", {1, 96, 40, 40}}};
+
+      config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
+                                    opt_input_shape);
     }
   } else {
     config.DisableGpu();

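(Reading note on the three maps above, not part of the commit: Paddle's TensorRT dynamic-shape mode requires a minimum, maximum, and optimal shape for each variable-shape tensor in the TRT subgraph; the engine accepts anything in the [min, max] range and is tuned for the opt shape. Here `x` is the detector's input image, and names such as `conv2d_92.tmp_0` are intermediate tensor names baked into this particular DB detection model, so they would have to be re-derived if the model topology changed.)
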
@@ -48,7 +84,7 @@ void DBDetector::LoadModel(const std::string &model_dir) {
   config.SwitchIrOptim(true);
 
   config.EnableMemoryOptim();
-  config.DisableGlogInfo();
+  // config.DisableGlogInfo();
 
   this->predictor_ = CreatePredictor(config);
 }

@@ -106,6 +106,15 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
           this->use_fp16_ ? paddle_infer::Config::Precision::kHalf
                           : paddle_infer::Config::Precision::kFloat32,
           false, false);
+      std::map<std::string, std::vector<int>> min_input_shape = {
+          {"x", {1, 3, 32, 10}}};
+      std::map<std::string, std::vector<int>> max_input_shape = {
+          {"x", {1, 3, 32, 2000}}};
+      std::map<std::string, std::vector<int>> opt_input_shape = {
+          {"x", {1, 3, 32, 320}}};
+
+      config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
+                                    opt_input_shape);
     }
   } else {
     config.DisableGpu();

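(Reading note, not part of the commit: the recognizer only declares shapes for its input `x` because CRNN crops are resized to a fixed height of 32 (see `CrnnResizeImg` below), so only the width varies with the aspect ratio of the text line; 10 to 2000 covers very short to very long lines, and 320, a 10:1 width-to-height ratio, is the tuning point.)
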
@@ -47,16 +47,13 @@ void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
     e /= 255.0;
   }
   (*im).convertTo(*im, CV_32FC3, e);
-  for (int h = 0; h < im->rows; h++) {
-    for (int w = 0; w < im->cols; w++) {
-      im->at<cv::Vec3f>(h, w)[0] =
-          (im->at<cv::Vec3f>(h, w)[0] - mean[0]) * scale[0];
-      im->at<cv::Vec3f>(h, w)[1] =
-          (im->at<cv::Vec3f>(h, w)[1] - mean[1]) * scale[1];
-      im->at<cv::Vec3f>(h, w)[2] =
-          (im->at<cv::Vec3f>(h, w)[2] - mean[2]) * scale[2];
-    }
-  }
+  std::vector<cv::Mat> bgr_channels(3);
+  cv::split(*im, bgr_channels);
+  for (auto i = 0; i < bgr_channels.size(); i++) {
+    bgr_channels[i].convertTo(bgr_channels[i], CV_32FC1, 1.0 * scale[i],
+                              (0.0 - mean[i]) * scale[i]);
+  }
+  cv::merge(bgr_channels, *im);
 }
 
 void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,

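The rewrite above is behavior-preserving: `convertTo` computes `dst = src * alpha + beta`, so with `alpha = scale[i]` and `beta = -mean[i] * scale[i]` each channel becomes `(src - mean[i]) * scale[i]`, exactly what the removed per-pixel loop did. A standalone sketch (mine, not from the commit) that checks the equivalence numerically:

```cpp
#include <opencv2/opencv.hpp>

#include <iostream>
#include <vector>

int main() {
  std::vector<float> mean = {0.485f, 0.456f, 0.406f};
  std::vector<float> scale = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
  cv::Mat im(4, 4, CV_32FC3);
  cv::randu(im, cv::Scalar::all(0), cv::Scalar::all(1));

  // Old implementation: per-pixel, per-channel arithmetic.
  cv::Mat a = im.clone();
  for (int h = 0; h < a.rows; h++)
    for (int w = 0; w < a.cols; w++)
      for (int c = 0; c < 3; c++)
        a.at<cv::Vec3f>(h, w)[c] =
            (a.at<cv::Vec3f>(h, w)[c] - mean[c]) * scale[c];

  // New implementation: split + convertTo + merge, as in the hunk above.
  cv::Mat b = im.clone();
  std::vector<cv::Mat> bgr_channels(3);
  cv::split(b, bgr_channels);
  for (size_t i = 0; i < bgr_channels.size(); i++)
    bgr_channels[i].convertTo(bgr_channels[i], CV_32FC1, 1.0 * scale[i],
                              (0.0 - mean[i]) * scale[i]);
  cv::merge(bgr_channels, b);

  // Expect a value on the order of float rounding error.
  std::cout << "max abs diff: " << cv::norm(a, b, cv::NORM_INF) << std::endl;
  return 0;
}
```
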
@@ -77,19 +74,13 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
 
   int resize_h = int(float(h) * ratio);
   int resize_w = int(float(w) * ratio);
 
   resize_h = max(int(round(float(resize_h) / 32) * 32), 32);
   resize_w = max(int(round(float(resize_w) / 32) * 32), 32);
 
-  if (!use_tensorrt) {
-    cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
-    ratio_h = float(resize_h) / float(h);
-    ratio_w = float(resize_w) / float(w);
-  } else {
-    cv::resize(img, resize_img, cv::Size(640, 640));
-    ratio_h = float(640) / float(h);
-    ratio_w = float(640) / float(w);
-  }
+  cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
+  ratio_h = float(resize_h) / float(h);
+  ratio_w = float(resize_w) / float(w);
 }
 
 void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,

@@ -108,23 +99,12 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
     resize_w = imgW;
   else
     resize_w = int(ceilf(imgH * ratio));
-  if (!use_tensorrt) {
-    cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
-               cv::INTER_LINEAR);
-    cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0,
-                       int(imgW - resize_img.cols), cv::BORDER_CONSTANT,
-                       {127, 127, 127});
-  } else {
-    int k = int(img.cols * 32 / img.rows);
-    if (k >= 100) {
-      cv::resize(img, resize_img, cv::Size(100, 32), 0.f, 0.f,
-                 cv::INTER_LINEAR);
-    } else {
-      cv::resize(img, resize_img, cv::Size(k, 32), 0.f, 0.f, cv::INTER_LINEAR);
-      cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, int(100 - k),
-                         cv::BORDER_CONSTANT, {127, 127, 127});
-    }
-  }
+
+  cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
+             cv::INTER_LINEAR);
+  cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0,
+                     int(imgW - resize_img.cols), cv::BORDER_CONSTANT,
+                     {127, 127, 127});
 }
 
 void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,

@@ -142,15 +122,11 @@ void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
   else
     resize_w = int(ceilf(imgH * ratio));
 
-  if (!use_tensorrt) {
-    cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
-               cv::INTER_LINEAR);
-    if (resize_w < imgW) {
-      cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w,
-                         cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
-    }
-  } else {
-    cv::resize(img, resize_img, cv::Size(100, 32), 0.f, 0.f, cv::INTER_LINEAR);
-  }
+  cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
+             cv::INTER_LINEAR);
+  if (resize_w < imgW) {
+    cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w,
+                       cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
+  }
 }
 

@@ -12,12 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <dirent.h>
+#include <include/utility.h>
 #include <iostream>
 #include <ostream>
+#include <sys/stat.h>
+#include <sys/types.h>
 #include <vector>
-
-#include <include/utility.h>
 
 namespace PaddleOCR {
 
 std::vector<std::string> Utility::ReadDict(const std::string &path) {

@@ -57,4 +59,37 @@ void Utility::VisualizeBboxes(
             << std::endl;
 }
 
+// list all files under a directory
+void Utility::GetAllFiles(const char *dir_name,
+                          std::vector<std::string> &all_inputs) {
+  if (NULL == dir_name) {
+    std::cout << " dir_name is null ! " << std::endl;
+    return;
+  }
+  struct stat s;
+  lstat(dir_name, &s);
+  if (!S_ISDIR(s.st_mode)) {
+    std::cout << "dir_name is not a valid directory !" << std::endl;
+    all_inputs.push_back(dir_name);
+    return;
+  } else {
+    struct dirent *filename; // return value for readdir()
+    DIR *dir;                // return value for opendir()
+    dir = opendir(dir_name);
+    if (NULL == dir) {
+      std::cout << "Can not open dir " << dir_name << std::endl;
+      return;
+    }
+    std::cout << "Successfully opened the dir !" << std::endl;
+    while ((filename = readdir(dir)) != NULL) {
+      if (strcmp(filename->d_name, ".") == 0 ||
+          strcmp(filename->d_name, "..") == 0)
+        continue;
+      // img_dir + std::string("/") + all_inputs[0];
+      all_inputs.push_back(dir_name + std::string("/") +
+                           std::string(filename->d_name));
+    }
+  }
+}
+
 } // namespace PaddleOCR

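A minimal caller sketch for the new utility (mine, not from the commit; it assumes the same headers and namespace as the hunk above). `GetAllFiles` accepts either a single file path, which it pushes through unchanged, or a directory, whose entries it appends as `dir_name/entry` without recursing or filtering by extension:

```cpp
#include <include/utility.h>

#include <iostream>
#include <string>
#include <vector>

int main(int argc, char **argv) {
  if (argc < 2) {
    std::cerr << "usage: " << argv[0] << " <image file or directory>\n";
    return 1;
  }
  std::vector<std::string> all_img_names;
  // Works for both "./imgs/1.jpg" (one entry) and "./imgs" (all entries).
  PaddleOCR::Utility::GetAllFiles(argv[1], all_img_names);
  for (const auto &name : all_img_names)
    std::cout << name << std::endl;
  return 0;
}
```
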
@@ -1,7 +1,7 @@
-OPENCV_DIR=your_opencv_dir
-LIB_DIR=your_paddle_inference_dir
-CUDA_LIB_DIR=your_cuda_lib_dir
-CUDNN_LIB_DIR=your_cudnn_lib_dir
+OPENCV_DIR=/paddle/test/opencv-3.4.7/opencv3
+LIB_DIR=/paddle/test/PaddleOCR/deploy/paddle_inference
+CUDA_LIB_DIR=/usr/local/cuda/lib64
+CUDNN_LIB_DIR=/usr/lib/x86_64-linux-gnu/
 
 BUILD_DIR=build
 rm -rf ${BUILD_DIR}

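(Editorial note, not part of the commit: the replaced values were generic placeholders, while the new ones are absolute paths from the committer's own test machine; anyone reusing `tools/build.sh` still needs to point these four variables at their local OpenCV build, Paddle inference library, CUDA, and cuDNN directories.)
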
@@ -18,3 +18,5 @@ cmake .. \
     -DCUDA_LIB=${CUDA_LIB_DIR} \
 
 make -j
+
+

@@ -20,10 +20,10 @@ cls_thresh 0.9
 
 # rec config
 rec_model_dir ./inference/ch_ppocr_mobile_v2.0_rec_infer/
-char_list_file ../../ppocr/utils/ppocr_keys_v1.txt
+char_list_file ../../ppocr/utils/ppocr_keys_v1.txt
 
 # show the detection results
-visualize 1
+visualize 0
 
 # use_tensorrt
 use_tensorrt 0