From 6cf954e1c749053309a4c8502faf51e83e1bcbfa Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Wed, 21 Sep 2022 13:37:20 +0000
Subject: [PATCH 1/3] add flexible configuration to disable det model
---
deploy/python/predict_system.py | 14 ++++++++++++--
docs/en/PPShiTu/PPShiTuV2_introduction.md | 15 +++++++++++++++
docs/zh_CN/quick_start/quick_start_recognition.md | 14 ++++++++++++++
3 files changed, 41 insertions(+), 2 deletions(-)
diff --git a/deploy/python/predict_system.py b/deploy/python/predict_system.py
index 44ff8a2e1..c029636b1 100644
--- a/deploy/python/predict_system.py
+++ b/deploy/python/predict_system.py
@@ -31,7 +31,14 @@ class SystemPredictor(object):
self.config = config
self.rec_predictor = RecPredictor(config)
- self.det_predictor = DetPredictor(config)
+
+ if not config["Global"]["det_inference_model_dir"]:
+ logger.info(
+ f"find 'Global.det_inference_model_dir' empty({config['Global']['det_inference_model_dir']}), so det_predictor is disabled"
+ )
+ self.det_predictor = None
+ else:
+ self.det_predictor = DetPredictor(config)
assert 'IndexProcess' in config.keys(), "Index config not found ... "
self.return_k = self.config['IndexProcess']['return_k']
@@ -92,7 +99,10 @@ class SystemPredictor(object):
def predict(self, img):
output = []
# st1: get all detection results
- results = self.det_predictor.predict(img)
+ if self.det_predictor:
+ results = self.det_predictor.predict(img)
+ else:
+ results = []
# st2: add the whole image for recognition to improve recall
results = self.append_self(results, img.shape)
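The guarded flow above boils down to the following minimal sketch (illustrative only: a plain stand-in replaces the real `DetPredictor`, and results are plain dicts rather than the actual predictor output):

```python
# Minimal sketch of the optional-detector flow; a stand-in object (or None)
# replaces the real DetPredictor, and results are plain dicts.
def get_detection_results(img_shape, det_predictor=None):
    # st1: run detection only when a detector was built from
    # Global.det_inference_model_dir; otherwise start from an empty list
    results = det_predictor.predict(img_shape) if det_predictor is not None else []
    # st2: always append the whole image to improve recall (mirrors append_self)
    h, w = img_shape[:2]
    results.append({"class_id": 0, "score": 1.0, "bbox": [0, 0, w, h]})
    return results


if __name__ == "__main__":
    # With the detector disabled, only the whole-image box remains.
    print(get_detection_results((802, 1200, 3)))
```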
diff --git a/docs/en/PPShiTu/PPShiTuV2_introduction.md b/docs/en/PPShiTu/PPShiTuV2_introduction.md
index bae44aea1..2ad62acdb 100644
--- a/docs/en/PPShiTu/PPShiTuV2_introduction.md
+++ b/docs/en/PPShiTu/PPShiTuV2_introduction.md
@@ -198,6 +198,21 @@ The final output is as follows.
[{'bbox': [437, 71, 660, 728], 'rec_docs': '元气森林', 'rec_scores': 0.7740249}, {'bbox': [221, 72, 449, 701], 'rec_docs': '元气森林', 'rec_scores': 0.6950992}, {'bbox': [794, 104, 979, 652], 'rec_docs': '元气森林', 'rec_scores': 0.6305153}]
```
+The recognition process supports flexible configuration. Users can choose to skip the object detection model and feed a single whole image directly into the feature extraction model; the resulting feature vector is then used for the subsequent retrieval, which reduces the overall latency of the recognition pipeline. This can be done with the commands below.
+```shell
+# Use the following command to use the GPU for whole-image prediction
+python3.7 python/predict_system.py -c configs/inference_general.yaml -o Global.det_inference_model_dir=None
+
+# Use the following command to use the CPU for whole-image prediction
+python3.7 python/predict_system.py -c configs/inference_general.yaml -o Global.use_gpu=False -o Global.det_inference_model_dir=None
+```
+
+The final output is as follows.
+```log
+INFO: find 'Global.det_inference_model_dir' empty(), so det_predictor is disabled
+[{'bbox': [0, 0, 1200, 802], 'rec_docs': '元气森林', 'rec_scores': 0.5696486}]
+```
+
#### 4.3.2 multi images prediction
If you want to predict the images in the folder, you can directly modify the `Global.infer_imgs` field in the configuration file, or you can modify the corresponding configuration through the following -o parameter.
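To quantify the latency saving described above, the two modes can be compared with a rough timing script such as the sketch below (illustration only; it assumes the same commands and working directory as in this document and a working Python environment):

```python
# Rough wall-clock comparison of recognition with and without the detector.
import subprocess
import time

def timed(cmd):
    start = time.perf_counter()
    subprocess.run(cmd, shell=True, check=True)
    return time.perf_counter() - start

with_det = timed("python3.7 python/predict_system.py -c configs/inference_general.yaml")
whole_img = timed("python3.7 python/predict_system.py -c configs/inference_general.yaml "
                  "-o Global.det_inference_model_dir=None")
print(f"with detector: {with_det:.2f}s, whole-image only: {whole_img:.2f}s")
```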
diff --git a/docs/zh_CN/quick_start/quick_start_recognition.md b/docs/zh_CN/quick_start/quick_start_recognition.md
index a8a8c4fd2..b8912daad 100644
--- a/docs/zh_CN/quick_start/quick_start_recognition.md
+++ b/docs/zh_CN/quick_start/quick_start_recognition.md
@@ -223,6 +223,20 @@ python3.7 python/predict_system.py -c configs/inference_general.yaml -o Global.u

+The recognition pipeline supports flexible configuration: you can skip the mainbody detection model and feed a single whole image directly into the feature extraction model; the resulting feature vector is then used for the subsequent retrieval, which reduces the overall latency of the pipeline. Whole-image recognition can be run directly with the following commands.
+```shell
+# Use the following command to run whole-image prediction on the GPU
+python3.7 python/predict_system.py -c configs/inference_general.yaml -o Global.det_inference_model_dir=None
+
+# Use the following command to run whole-image prediction on the CPU
+python3.7 python/predict_system.py -c configs/inference_general.yaml -o Global.use_gpu=False -o Global.det_inference_model_dir=None
+```
+
+The final output is as follows.
+```log
+INFO: find 'Global.det_inference_model_dir' empty(), so det_predictor is disabled
+[{'bbox': [0, 0, 1200, 802], 'rec_docs': '元气森林', 'rec_scores': 0.5696486}]
+```
From b31d67ea3219d2bd0cbcc81a7d7659736e7b5275 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Thu, 22 Sep 2022 02:42:48 +0000
Subject: [PATCH 2/3] add flexible configuration to disable det model (C++)
---
deploy/cpp_shitu/src/main.cpp | 576 +++++++++++++++++-----------------
1 file changed, 295 insertions(+), 281 deletions(-)
diff --git a/deploy/cpp_shitu/src/main.cpp b/deploy/cpp_shitu/src/main.cpp
index be37d3afd..6bdbf6574 100644
--- a/deploy/cpp_shitu/src/main.cpp
+++ b/deploy/cpp_shitu/src/main.cpp
@@ -37,306 +37,320 @@
using namespace std;
using namespace cv;
-DEFINE_string(config,
-"", "Path of yaml file");
-DEFINE_string(c,
-"", "Path of yaml file");
+DEFINE_string(config, "", "Path of yaml file");
+DEFINE_string(c, "", "Path of yaml file");
-void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
- const std::vector<std::string> &all_img_paths,
+void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
+ const std::vector<std::string> &all_img_paths,
const int batch_size, Detection::ObjectDetector *det,
- std::vector<Detection::ObjectResult> &im_result,
- std::vector<int> &im_bbox_num, std::vector<double> &det_t,
- const bool visual_det = false,
+ std::vector<Detection::ObjectResult> &im_result,
+ std::vector<double> &det_t, const bool visual_det = false,
const bool run_benchmark = false,
const std::string &output_dir = "output") {
- int steps = ceil(float(all_img_paths.size()) / batch_size);
- // printf("total images = %d, batch_size = %d, total steps = %d\n",
- // all_img_paths.size(), batch_size, steps);
- for (int idx = 0; idx < steps; idx++) {
- int left_image_cnt = all_img_paths.size() - idx * batch_size;
- if (left_image_cnt > batch_size) {
- left_image_cnt = batch_size;
- }
- // for (int bs = 0; bs < left_image_cnt; bs++) {
- // std::string image_file_path = all_img_paths.at(idx * batch_size+bs);
- // cv::Mat im = cv::imread(image_file_path, 1);
- // batch_imgs.insert(batch_imgs.end(), im);
- // }
-
- // Store all detected result
- std::vector<Detection::ObjectResult> result;
- std::vector<int> bbox_num;
- std::vector<double> det_times;
- bool is_rbox = false;
- if (run_benchmark) {
- det->Predict(batch_imgs, 10, 10, &result, &bbox_num, &det_times);
- } else {
- det->Predict(batch_imgs, 0, 1, &result, &bbox_num, &det_times);
- // get labels and colormap
- auto labels = det->GetLabelList();
- auto colormap = Detection::GenerateColorMap(labels.size());
-
- int item_start_idx = 0;
- for (int i = 0; i < left_image_cnt; i++) {
- cv::Mat im = batch_imgs[i];
- int detect_num = 0;
-
- for (int j = 0; j < bbox_num[i]; j++) {
- Detection::ObjectResult item = result[item_start_idx + j];
- if (item.confidence < det->GetThreshold() || item.class_id == -1) {
- continue;
- }
- detect_num += 1;
- im_result.push_back(item);
- if (visual_det) {
- if (item.rect.size() > 6) {
- is_rbox = true;
- printf(
- "class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
- item.class_id, item.confidence, item.rect[0], item.rect[1],
- item.rect[2], item.rect[3], item.rect[4], item.rect[5],
- item.rect[6], item.rect[7]);
- } else {
- printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
- item.class_id, item.confidence, item.rect[0], item.rect[1],
- item.rect[2], item.rect[3]);
- }
- }
- }
- im_bbox_num.push_back(detect_num);
- item_start_idx = item_start_idx + bbox_num[i];
-
- // Visualization result
- if (visual_det) {
- std::cout << all_img_paths.at(idx * batch_size + i)
- << " The number of detected box: " << detect_num
- << std::endl;
- cv::Mat vis_img = Detection::VisualizeResult(im, im_result, labels,
- colormap, is_rbox);
- std::vector<int> compression_params;
- compression_params.push_back(CV_IMWRITE_JPEG_QUALITY);
- compression_params.push_back(95);
- std::string output_path(output_dir);
- if (output_dir.rfind(OS_PATH_SEP) != output_dir.size() - 1) {
- output_path += OS_PATH_SEP;
- }
- std::string image_file_path = all_img_paths.at(idx * batch_size + i);
- output_path +=
- image_file_path.substr(image_file_path.find_last_of('/') + 1);
- cv::imwrite(output_path, vis_img, compression_params);
- printf("Visualized output saved as %s\n", output_path.c_str());
- }
- }
- }
- det_t[0] += det_times[0];
- det_t[1] += det_times[1];
- det_t[2] += det_times[2];
+ int steps = ceil(float(all_img_paths.size()) / batch_size);
+ // printf("total images = %d, batch_size = %d, total steps = %d\n",
+ // all_img_paths.size(), batch_size, steps);
+ for (int idx = 0; idx < steps; idx++) {
+ int left_image_cnt = all_img_paths.size() - idx * batch_size;
+ if (left_image_cnt > batch_size) {
+ left_image_cnt = batch_size;
}
+ // for (int bs = 0; bs < left_image_cnt; bs++) {
+ // std::string image_file_path = all_img_paths.at(idx * batch_size+bs);
+ // cv::Mat im = cv::imread(image_file_path, 1);
+ // batch_imgs.insert(batch_imgs.end(), im);
+ // }
+
+ // Store all detected result
+ std::vector<Detection::ObjectResult> result;
+ std::vector<int> bbox_num;
+ std::vector<double> det_times;
+ bool is_rbox = false;
+ if (run_benchmark) {
+ det->Predict(batch_imgs, 10, 10, &result, &bbox_num, &det_times);
+ } else {
+ det->Predict(batch_imgs, 0, 1, &result, &bbox_num, &det_times);
+ // get labels and colormap
+ auto labels = det->GetLabelList();
+ auto colormap = Detection::GenerateColorMap(labels.size());
+
+ int item_start_idx = 0;
+ for (int i = 0; i < left_image_cnt; i++) {
+ cv::Mat im = batch_imgs[i];
+ int detect_num = 0;
+
+ for (int j = 0; j < bbox_num[i]; j++) {
+ Detection::ObjectResult item = result[item_start_idx + j];
+ if (item.confidence < det->GetThreshold() || item.class_id == -1) {
+ continue;
+ }
+ detect_num += 1;
+ im_result.push_back(item);
+ if (visual_det) {
+ if (item.rect.size() > 6) {
+ is_rbox = true;
+ printf(
+ "class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
+ item.class_id, item.confidence, item.rect[0], item.rect[1],
+ item.rect[2], item.rect[3], item.rect[4], item.rect[5],
+ item.rect[6], item.rect[7]);
+ } else {
+ printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
+ item.class_id, item.confidence, item.rect[0], item.rect[1],
+ item.rect[2], item.rect[3]);
+ }
+ }
+ }
+ // im_bbox_num.push_back(detect_num);
+ item_start_idx = item_start_idx + bbox_num[i];
+
+ // Visualization result
+ if (visual_det) {
+ std::cout << all_img_paths.at(idx * batch_size + i)
+ << " The number of detected box: " << detect_num
+ << std::endl;
+ cv::Mat vis_img = Detection::VisualizeResult(im, im_result, labels,
+ colormap, is_rbox);
+ std::vector<int> compression_params;
+ compression_params.push_back(CV_IMWRITE_JPEG_QUALITY);
+ compression_params.push_back(95);
+ std::string output_path(output_dir);
+ if (output_dir.rfind(OS_PATH_SEP) != output_dir.size() - 1) {
+ output_path += OS_PATH_SEP;
+ }
+ std::string image_file_path = all_img_paths.at(idx * batch_size + i);
+ output_path +=
+ image_file_path.substr(image_file_path.find_last_of('/') + 1);
+ cv::imwrite(output_path, vis_img, compression_params);
+ printf("Visualized output saved as %s\n", output_path.c_str());
+ }
+ }
+ }
+ det_t[0] += det_times[0];
+ det_t[1] += det_times[1];
+ det_t[2] += det_times[2];
+ }
}
void PrintResult(std::string &img_path,
- std::vector<Detection::ObjectResult> &det_result,
- std::vector<int> &indeices, VectorSearch &vector_search,
+ std::vector<Detection::ObjectResult> &det_result,
+ std::vector<int> &indeices, VectorSearch *vector_search_ptr,
SearchResult &search_result) {
- printf("%s:\n", img_path.c_str());
- for (int i = 0; i < indeices.size(); ++i) {
- int t = indeices[i];
- printf("\tresult%d: bbox[%d, %d, %d, %d], score: %f, label: %s\n", i,
- det_result[t].rect[0], det_result[t].rect[1], det_result[t].rect[2],
- det_result[t].rect[3], det_result[t].confidence,
- vector_search.GetLabel(search_result.I[search_result.return_k * t])
- .c_str());
- }
+ printf("%s:\n", img_path.c_str());
+ for (int i = 0; i < indeices.size(); ++i) {
+ int t = indeices[i];
+ printf(
+ "\tresult%d: bbox[%d, %d, %d, %d], score: %f, label: %s\n", i,
+ det_result[t].rect[0], det_result[t].rect[1], det_result[t].rect[2],
+ det_result[t].rect[3], det_result[t].confidence,
+ vector_search_ptr->GetLabel(search_result.I[search_result.return_k * t])
+ .c_str());
+ }
}
int main(int argc, char **argv) {
- google::ParseCommandLineFlags(&argc, &argv, true);
- std::string yaml_path = "";
- if (FLAGS_config == "" && FLAGS_c == "") {
- std::cerr << "[ERROR] usage: " << std::endl
- << argv[0] << " -c $yaml_path" << std::endl
- << "or:" << std::endl
- << argv[0] << " -config $yaml_path" << std::endl;
- exit(1);
- } else if (FLAGS_config != "") {
- yaml_path = FLAGS_config;
- } else {
- yaml_path = FLAGS_c;
+ google::ParseCommandLineFlags(&argc, &argv, true);
+ std::string yaml_path = "";
+ if (FLAGS_config == "" && FLAGS_c == "") {
+ std::cerr << "[ERROR] usage: " << std::endl
+ << argv[0] << " -c $yaml_path" << std::endl
+ << "or:" << std::endl
+ << argv[0] << " -config $yaml_path" << std::endl;
+ exit(1);
+ } else if (FLAGS_config != "") {
+ yaml_path = FLAGS_config;
+ } else {
+ yaml_path = FLAGS_c;
+ }
+
+ YamlConfig config(yaml_path);
+ config.PrintConfigInfo();
+
+ // initialize detector
+ Detection::ObjectDetector *detector_ptr = nullptr;
+ if (config.config_file["Global"]["det_inference_model_dir"].Type() !=
+ YAML::NodeType::Null &&
+ !config.config_file["Global"]["det_inference_model_dir"]
+ .as<std::string>()
+ .empty()) {
+ detector_ptr = new Detection::ObjectDetector(config.config_file);
+ }
+
+ // initialize feature_extractor
+ Feature::FeatureExtracter *feature_extracter_ptr =
+ new Feature::FeatureExtracter(config.config_file);
+ // initialize vector_searcher
+ VectorSearch *vector_searcher_ptr = new VectorSearch(config.config_file);
+
+ // config
+ const int batch_size = config.config_file["Global"]["batch_size"].as<int>();
+ bool visual_det = false;
+ if (config.config_file["Global"]["visual_det"].IsDefined()) {
+ visual_det = config.config_file["Global"]["visual_det"].as<bool>();
+ }
+ bool benchmark = false;
+ if (config.config_file["Global"]["benchmark"].IsDefined()) {
+ benchmark = config.config_file["Global"]["benchmark"].as<bool>();
+ }
+ int max_det_results = 5;
+ if (config.config_file["Global"]["max_det_results"].IsDefined()) {
+ max_det_results = config.config_file["Global"]["max_det_results"].as<int>();
+ }
+ float rec_nms_thresold = 0.05;
+ if (config.config_file["Global"]["rec_nms_thresold"].IsDefined()) {
+ rec_nms_thresold =
+ config.config_file["Global"]["rec_nms_thresold"].as<float>();
+ }
+
+ // load image_file_path
+ std::string path =
+ config.config_file["Global"]["infer_imgs"].as<std::string>();
+ std::vector<std::string> img_files_list;
+ if (cv::utils::fs::isDirectory(path)) {
+ std::vector<cv::String> filenames;
+ cv::glob(path, filenames);
+ for (auto f : filenames) {
+ img_files_list.push_back(f);
+ }
+ } else {
+ img_files_list.push_back(path);
+ }
+ std::cout << "img_file_list length: " << img_files_list.size() << std::endl;
+ // for time log
+ std::vector<double> cls_times = {0, 0, 0};
+ std::vector<double> det_times = {0, 0, 0};
+ std::vector<double> search_times = {0, 0, 0};
+ int instance_num = 0;
+ // for read images
+ std::vector<cv::Mat> batch_imgs;
+ std::vector<std::string> img_paths;
+ // for detection
+ std::vector<Detection::ObjectResult> det_result;
+
+ // for vector search
+ std::vector<float> features;
+ std::vector<float> feature;
+ // for nms
+ std::vector<int> indeices;
+
+ int warmup_iter = img_files_list.size() > 5 ? 5 : img_files_list.size();
+ if (benchmark) {
+ img_files_list.insert(img_files_list.begin(), img_files_list.begin(),
+ img_files_list.begin() + warmup_iter);
+ }
+
+ for (int idx = 0; idx < img_files_list.size(); ++idx) {
+ std::string img_path = img_files_list[idx];
+ cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
+ if (!srcimg.data) {
+ std::cerr << "[ERROR] image read failed! image path: " << img_path
+ << "\n";
+ exit(-1);
+ }
+ cv::cvtColor(srcimg, srcimg, cv::COLOR_BGR2RGB);
+
+ batch_imgs.push_back(srcimg);
+ img_paths.push_back(img_path);
+
+ // step1: get all detection results if enable detector
+ if (detector_ptr != nullptr) {
+ DetPredictImage(batch_imgs, img_paths, batch_size, detector_ptr,
+ det_result, det_times, visual_det, false);
}
- YamlConfig config(yaml_path);
- config.PrintConfigInfo();
+ // select max_det_results bbox
+ if (det_result.size() > max_det_results) {
+ det_result.resize(max_det_results);
+ }
+ instance_num += det_result.size();
- // initialize detector, rec_Model, vector_search
- Feature::FeatureExtracter feature_extracter(config.config_file);
- Detection::ObjectDetector detector(config.config_file);
- VectorSearch searcher(config.config_file);
+ // step2: add the whole image for recognition to improve recall
+ Detection::ObjectResult result_whole_img = {
+ {0, 0, srcimg.cols - 1, srcimg.rows - 1}, 0, 1.0};
+ det_result.push_back(result_whole_img);
- // config
- const int batch_size = config.config_file["Global"]["batch_size"].as<int>();
- bool visual_det = false;
- if (config.config_file["Global"]["visual_det"].IsDefined()) {
- visual_det = config.config_file["Global"]["visual_det"].as<bool>();
- }
- bool benchmark = false;
- if (config.config_file["Global"]["benchmark"].IsDefined()) {
- benchmark = config.config_file["Global"]["benchmark"].as<bool>();
- }
- int max_det_results = 5;
- if (config.config_file["Global"]["max_det_results"].IsDefined()) {
- max_det_results = config.config_file["Global"]["max_det_results"].as<int>();
- }
- float rec_nms_thresold = 0.05;
- if (config.config_file["Global"]["rec_nms_thresold"].IsDefined()) {
- rec_nms_thresold =
- config.config_file["Global"]["rec_nms_thresold"].as<float>();
+ // step3: extract feature for all boxes in an image
+ SearchResult search_result;
+ for (int j = 0; j < det_result.size(); ++j) {
+ int w = det_result[j].rect[2] - det_result[j].rect[0];
+ int h = det_result[j].rect[3] - det_result[j].rect[1];
+ cv::Rect rect(det_result[j].rect[0], det_result[j].rect[1], w, h);
+ cv::Mat crop_img = srcimg(rect);
+ feature_extracter_ptr->Run(crop_img, feature, cls_times);
+ features.insert(features.end(), feature.begin(), feature.end());
}
- // load image_file_path
- std::string path =
- config.config_file["Global"]["infer_imgs"].as<std::string>();
- std::vector<std::string> img_files_list;
- if (cv::utils::fs::isDirectory(path)) {
- std::vector<cv::String> filenames;
- cv::glob(path, filenames);
- for (auto f : filenames) {
- img_files_list.push_back(f);
- }
- } else {
- img_files_list.push_back(path);
- }
- std::cout << "img_file_list length: " << img_files_list.size() << std::endl;
- // for time log
- std::vector<double> cls_times = {0, 0, 0};
- std::vector<double> det_times = {0, 0, 0};
- std::vector<double> search_times = {0, 0, 0};
- int instance_num = 0;
- // for read images
- std::vector<cv::Mat> batch_imgs;
- std::vector<std::string> img_paths;
- // for detection
- std::vector<Detection::ObjectResult> det_result;
- std::vector<int> det_bbox_num;
- // for vector search
- std::vector<float> features;
- std::vector<float> feature;
- // for nms
- std::vector<int> indeices;
+ // step4: get search result
+ auto search_start = std::chrono::steady_clock::now();
+ search_result =
+ vector_searcher_ptr->Search(features.data(), det_result.size());
+ auto search_end = std::chrono::steady_clock::now();
- int warmup_iter = img_files_list.size() > 5 ? 5 : img_files_list.size();
- if (benchmark) {
- img_files_list.insert(img_files_list.begin(), img_files_list.begin(),
- img_files_list.begin() + warmup_iter);
+ // nms for search result
+ for (int i = 0; i < det_result.size(); ++i) {
+ det_result[i].confidence = search_result.D[search_result.return_k * i];
+ }
+ NMSBoxes(det_result, vector_searcher_ptr->GetThreshold(), rec_nms_thresold,
+ indeices);
+ auto nms_end = std::chrono::steady_clock::now();
+ std::chrono::duration<float> search_diff = search_end - search_start;
+ search_times[1] += double(search_diff.count() * 1000);
+
+ std::chrono::duration<float> nms_diff = nms_end - search_end;
+ search_times[2] += double(nms_diff.count() * 1000);
+
+ // print result
+ if (not benchmark or (benchmark and idx >= warmup_iter))
+ PrintResult(img_path, det_result, indeices, vector_searcher_ptr,
+ search_result);
+
+ // for postprocess
+ batch_imgs.clear();
+ img_paths.clear();
+ det_result.clear();
+ feature.clear();
+ features.clear();
+ indeices.clear();
+ if (benchmark and warmup_iter == idx + 1) {
+ det_times = {0, 0, 0};
+ cls_times = {0, 0, 0};
+ search_times = {0, 0, 0};
+ instance_num = 0;
+ }
+ }
+
+ if (benchmark) {
+ std::string presion = "fp32";
+ if (config.config_file["Global"]["use_fp16"].IsDefined() and
+ config.config_file["Global"]["use_fp16"].as<bool>())
+ presion = "fp16";
+ bool use_gpu = config.config_file["Global"]["use_gpu"].as<bool>();
+ bool use_tensorrt = config.config_file["Global"]["use_tensorrt"].as<bool>();
+ bool enable_mkldnn =
+ config.config_file["Global"]["enable_mkldnn"].as<bool>();
+ int cpu_num_threads =
+ config.config_file["Global"]["cpu_num_threads"].as<int>();
+ int batch_size = config.config_file["Global"]["batch_size"].as<int>();
+ std::vector<int> shape =
+ config.config_file["Global"]["image_shape"].as<std::vector<int>>();
+ std::string det_shape = std::to_string(shape[0]);
+
+ for (int i = 1; i < shape.size(); ++i) {
+ det_shape = det_shape + ", " + std::to_string(shape[i]);
}
- for (int idx = 0; idx < img_files_list.size(); ++idx) {
- std::string img_path = img_files_list[idx];
- cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
- if (!srcimg.data) {
- std::cerr << "[ERROR] image read failed! image path: " << img_path
- << "\n";
- exit(-1);
- }
- cv::cvtColor(srcimg, srcimg, cv::COLOR_BGR2RGB);
-
- batch_imgs.push_back(srcimg);
- img_paths.push_back(img_path);
-
- // step1: get all detection results
- DetPredictImage(batch_imgs, img_paths, batch_size, &detector, det_result,
- det_bbox_num, det_times, visual_det, false);
-
- // select max_det_results bbox
- if (det_result.size() > max_det_results) {
- det_result.resize(max_det_results);
- }
- instance_num += det_result.size();
-
- // step2: add the whole image for recognition to improve recall
- Detection::ObjectResult result_whole_img = {
- {0, 0, srcimg.cols - 1, srcimg.rows - 1}, 0, 1.0};
- det_result.push_back(result_whole_img);
- det_bbox_num[0] = det_result.size() + 1;
-
- // step3: extract feature for all boxes in an inmage
- SearchResult search_result;
- for (int j = 0; j < det_result.size(); ++j) {
- int w = det_result[j].rect[2] - det_result[j].rect[0];
- int h = det_result[j].rect[3] - det_result[j].rect[1];
- cv::Rect rect(det_result[j].rect[0], det_result[j].rect[1], w, h);
- cv::Mat crop_img = srcimg(rect);
- feature_extracter.Run(crop_img, feature, cls_times);
- features.insert(features.end(), feature.begin(), feature.end());
- }
-
- // step4: get search result
- auto search_start = std::chrono::steady_clock::now();
- search_result = searcher.Search(features.data(), det_result.size());
- auto search_end = std::chrono::steady_clock::now();
-
- // nms for search result
- for (int i = 0; i < det_result.size(); ++i) {
- det_result[i].confidence = search_result.D[search_result.return_k * i];
- }
- NMSBoxes(det_result, searcher.GetThreshold(), rec_nms_thresold, indeices);
- auto nms_end = std::chrono::steady_clock::now();
- std::chrono::duration<float> search_diff = search_end - search_start;
- search_times[1] += double(search_diff.count() * 1000);
-
- std::chrono::duration<float> nms_diff = nms_end - search_end;
- search_times[2] += double(nms_diff.count() * 1000);
-
- // print result
- if (not benchmark or (benchmark and idx >= warmup_iter))
- PrintResult(img_path, det_result, indeices, searcher, search_result);
-
- // for postprocess
- batch_imgs.clear();
- img_paths.clear();
- det_bbox_num.clear();
- det_result.clear();
- feature.clear();
- features.clear();
- indeices.clear();
- if (benchmark and warmup_iter == idx + 1) {
- det_times = {0, 0, 0};
- cls_times = {0, 0, 0};
- search_times = {0, 0, 0};
- instance_num = 0;
- }
- }
-
- if (benchmark) {
- std::string presion = "fp32";
- if (config.config_file["Global"]["use_fp16"].IsDefined() and
- config.config_file["Global"]["use_fp16"].as<bool>())
- presion = "fp16";
- bool use_gpu = config.config_file["Global"]["use_gpu"].as<bool>();
- bool use_tensorrt = config.config_file["Global"]["use_tensorrt"].as<bool>();
- bool enable_mkldnn =
- config.config_file["Global"]["enable_mkldnn"].as<bool>();
- int cpu_num_threads =
- config.config_file["Global"]["cpu_num_threads"].as<int>();
- int batch_size = config.config_file["Global"]["batch_size"].as<int>();
- std::vector<int> shape =
- config.config_file["Global"]["image_shape"].as < std::vector < int >> ();
- std::string det_shape = std::to_string(shape[0]);
- for (int i = 1; i < shape.size(); ++i)
- det_shape = det_shape + ", " + std::to_string(shape[i]);
-
- AutoLogger autolog_det("Det", use_gpu, use_tensorrt, enable_mkldnn,
- cpu_num_threads, batch_size, det_shape, presion,
- det_times, img_files_list.size() - warmup_iter);
- autolog_det.report();
- AutoLogger autolog_rec("Rec", use_gpu, use_tensorrt, enable_mkldnn,
- cpu_num_threads, batch_size, "3, 224, 224", presion,
- cls_times, instance_num);
- autolog_rec.report();
- AutoLogger autolog_search("Search", false, use_tensorrt, enable_mkldnn,
- cpu_num_threads, batch_size, "dynamic", presion,
- search_times, instance_num);
- autolog_search.report();
- }
- return 0;
+ AutoLogger autolog_det("Det", use_gpu, use_tensorrt, enable_mkldnn,
+ cpu_num_threads, batch_size, det_shape, presion,
+ det_times, img_files_list.size() - warmup_iter);
+ autolog_det.report();
+ AutoLogger autolog_rec("Rec", use_gpu, use_tensorrt, enable_mkldnn,
+ cpu_num_threads, batch_size, "3, 224, 224", presion,
+ cls_times, instance_num);
+ autolog_rec.report();
+ AutoLogger autolog_search("Search", false, use_tensorrt, enable_mkldnn,
+ cpu_num_threads, batch_size, "dynamic", presion,
+ search_times, instance_num);
+ autolog_search.report();
+ }
+ return 0;
}
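The C++ guard above constructs `Detection::ObjectDetector` only when `Global.det_inference_model_dir` is neither a YAML null nor an empty string. The same condition, sketched in Python with PyYAML purely for illustration (PyYAML is an assumption here, not a dependency of this patch):

```python
# Sketch of the null-or-empty check used to decide whether to build the detector.
import yaml  # assumes PyYAML is installed

for snippet in ('Global:\n  det_inference_model_dir: null',
                'Global:\n  det_inference_model_dir: ""',
                'Global:\n  det_inference_model_dir: ./det_model/'):
    det_dir = yaml.safe_load(snippet)["Global"]["det_inference_model_dir"]
    build_detector = det_dir is not None and str(det_dir).strip() != ""
    print(f"{det_dir!r:>18} -> build detector: {build_detector}")
```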
From e44ecc198700db9fda7449dcfab6bf363f73ad57 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Thu, 22 Sep 2022 03:03:30 +0000
Subject: [PATCH 3/3] update PP-ShiTu inference doc
---
deploy/cpp_shitu/src/main.cpp | 4 ++++
deploy/python/predict_system.py | 2 +-
docs/zh_CN/deployment/PP-ShiTu/cpp.md | 20 ++++++++++++++++---
.../quick_start/quick_start_recognition.md | 2 +-
4 files changed, 23 insertions(+), 5 deletions(-)
diff --git a/deploy/cpp_shitu/src/main.cpp b/deploy/cpp_shitu/src/main.cpp
index 6bdbf6574..b4c384d50 100644
--- a/deploy/cpp_shitu/src/main.cpp
+++ b/deploy/cpp_shitu/src/main.cpp
@@ -174,6 +174,10 @@ int main(int argc, char **argv) {
.as<std::string>()
.empty()) {
detector_ptr = new Detection::ObjectDetector(config.config_file);
+ } else {
+ std::cout << "Found 'Global.det_inference_model_dir' empty, so "
+ "det_predictor is disabled"
+ << std::endl;
}
// initialize feature_extractor
diff --git a/deploy/python/predict_system.py b/deploy/python/predict_system.py
index c029636b1..19bf5c325 100644
--- a/deploy/python/predict_system.py
+++ b/deploy/python/predict_system.py
@@ -34,7 +34,7 @@ class SystemPredictor(object):
if not config["Global"]["det_inference_model_dir"]:
logger.info(
- f"find 'Global.det_inference_model_dir' empty({config['Global']['det_inference_model_dir']}), so det_predictor is disabled"
+ f"Found 'Global.det_inference_model_dir' empty({config['Global']['det_inference_model_dir']}), so det_predictor is disabled"
)
self.det_predictor = None
else:
diff --git a/docs/zh_CN/deployment/PP-ShiTu/cpp.md b/docs/zh_CN/deployment/PP-ShiTu/cpp.md
index 3f45f5036..6a802a610 100644
--- a/docs/zh_CN/deployment/PP-ShiTu/cpp.md
+++ b/docs/zh_CN/deployment/PP-ShiTu/cpp.md
@@ -348,11 +348,25 @@ cd ..
./build/pp_shitu -c inference_drink.yaml
```
- Using `drink_dataset_v2.0/test_images/nongfu_spring.jpeg` as the input image, the above inference command produces the following result
+ By default, `../drink_dataset_v2.0/test_images/100.jpeg` is used as the input image, and the above inference command produces the following result
```log
- ../../deploy/drink_dataset_v2.0/test_images/nongfu_spring.jpeg:
- result0: bbox[0, 0, 729, 1094], score: 0.688691, label: 农夫山泉-饮用天然水
+ ../drink_dataset_v2.0/test_images/100.jpeg:
+ result0: bbox[437, 72, 660, 723], score: 0.769916, label: 元气森林
+ result1: bbox[220, 71, 449, 685], score: 0.695485, label: 元气森林
+ result2: bbox[795, 104, 979, 653], score: 0.626963, label: 元气森林
+ ```
+
+ The recognition pipeline supports flexible configuration: you can skip the mainbody detection model and feed a single whole image directly into the feature extraction model; the resulting feature vector is then used for the subsequent retrieval, which reduces the overall latency of the pipeline. Simply set `Global.det_inference_model_dir` to `null` or `""`, then run the following inference command
+ ```shell
+ ./build/pp_shitu -c inference_drink.yaml
+ ```
+
+ The final output is as follows
+ ```log
+ Found 'Global.det_inference_model_dir' empty, so det_predictor is disabled
+ ../drink_dataset_v2.0/test_images/100.jpeg:
+ result0: bbox[0, 0, 1199, 801], score: 0.568903, label: 元气森林
```
Because the OpenCV implementations used by Python and C++ differ slightly, the Python and C++ inference results may show minor discrepancies, but this generally does not affect the final retrieval results.
diff --git a/docs/zh_CN/quick_start/quick_start_recognition.md b/docs/zh_CN/quick_start/quick_start_recognition.md
index b8912daad..511a3c607 100644
--- a/docs/zh_CN/quick_start/quick_start_recognition.md
+++ b/docs/zh_CN/quick_start/quick_start_recognition.md
@@ -234,7 +234,7 @@ python3.7 python/predict_system.py -c configs/inference_general.yaml -o Global.u
The final output is as follows
```log
-INFO: find 'Global.det_inference_model_dir' empty(), so det_predictor is disabled
+INFO: Found 'Global.det_inference_model_dir' empty(), so det_predictor is disabled
[{'bbox': [0, 0, 1200, 802], 'rec_docs': '元气森林', 'rec_scores': 0.5696486}]
```
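The results printed by `predict_system.py` in the examples above are a plain list of dicts, so they are easy to post-process; for instance, keeping only the highest-scoring match (illustrative snippet, values taken from the sample output above):

```python
# Pick the highest-scoring recognition result from the format shown in the docs.
results = [{"bbox": [0, 0, 1200, 802], "rec_docs": "元气森林", "rec_scores": 0.5696486}]
best = max(results, key=lambda r: r["rec_scores"])
print(best["rec_docs"], best["bbox"], best["rec_scores"])
```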