Merge pull request #2337 from HydrogenSulfate/add_optional_det

Add flexible configuration of PP-ShiTu detection model
pull/2343/head
HydrogenSulfate 2022-09-22 20:39:10 +08:00 committed by GitHub
commit 1d397104be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 357 additions and 286 deletions

View File

@ -37,306 +37,324 @@
using namespace std;
using namespace cv;
DEFINE_string(config,
"", "Path of yaml file");
DEFINE_string(c,
"", "Path of yaml file");
DEFINE_string(config, "", "Path of yaml file");
DEFINE_string(c, "", "Path of yaml file");
void DetPredictImage(const std::vector <cv::Mat> &batch_imgs,
const std::vector <std::string> &all_img_paths,
void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
const std::vector<std::string> &all_img_paths,
const int batch_size, Detection::ObjectDetector *det,
std::vector <Detection::ObjectResult> &im_result,
std::vector<int> &im_bbox_num, std::vector<double> &det_t,
const bool visual_det = false,
std::vector<Detection::ObjectResult> &im_result,
std::vector<double> &det_t, const bool visual_det = false,
const bool run_benchmark = false,
const std::string &output_dir = "output") {
int steps = ceil(float(all_img_paths.size()) / batch_size);
// printf("total images = %d, batch_size = %d, total steps = %d\n",
// all_img_paths.size(), batch_size, steps);
for (int idx = 0; idx < steps; idx++) {
int left_image_cnt = all_img_paths.size() - idx * batch_size;
if (left_image_cnt > batch_size) {
left_image_cnt = batch_size;
}
// for (int bs = 0; bs < left_image_cnt; bs++) {
// std::string image_file_path = all_img_paths.at(idx * batch_size+bs);
// cv::Mat im = cv::imread(image_file_path, 1);
// batch_imgs.insert(batch_imgs.end(), im);
// }
// Store all detected result
std::vector <Detection::ObjectResult> result;
std::vector<int> bbox_num;
std::vector<double> det_times;
bool is_rbox = false;
if (run_benchmark) {
det->Predict(batch_imgs, 10, 10, &result, &bbox_num, &det_times);
} else {
det->Predict(batch_imgs, 0, 1, &result, &bbox_num, &det_times);
// get labels and colormap
auto labels = det->GetLabelList();
auto colormap = Detection::GenerateColorMap(labels.size());
int item_start_idx = 0;
for (int i = 0; i < left_image_cnt; i++) {
cv::Mat im = batch_imgs[i];
int detect_num = 0;
for (int j = 0; j < bbox_num[i]; j++) {
Detection::ObjectResult item = result[item_start_idx + j];
if (item.confidence < det->GetThreshold() || item.class_id == -1) {
continue;
}
detect_num += 1;
im_result.push_back(item);
if (visual_det) {
if (item.rect.size() > 6) {
is_rbox = true;
printf(
"class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
item.class_id, item.confidence, item.rect[0], item.rect[1],
item.rect[2], item.rect[3], item.rect[4], item.rect[5],
item.rect[6], item.rect[7]);
} else {
printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
item.class_id, item.confidence, item.rect[0], item.rect[1],
item.rect[2], item.rect[3]);
}
}
}
im_bbox_num.push_back(detect_num);
item_start_idx = item_start_idx + bbox_num[i];
// Visualization result
if (visual_det) {
std::cout << all_img_paths.at(idx * batch_size + i)
<< " The number of detected box: " << detect_num
<< std::endl;
cv::Mat vis_img = Detection::VisualizeResult(im, im_result, labels,
colormap, is_rbox);
std::vector<int> compression_params;
compression_params.push_back(CV_IMWRITE_JPEG_QUALITY);
compression_params.push_back(95);
std::string output_path(output_dir);
if (output_dir.rfind(OS_PATH_SEP) != output_dir.size() - 1) {
output_path += OS_PATH_SEP;
}
std::string image_file_path = all_img_paths.at(idx * batch_size + i);
output_path +=
image_file_path.substr(image_file_path.find_last_of('/') + 1);
cv::imwrite(output_path, vis_img, compression_params);
printf("Visualized output saved as %s\n", output_path.c_str());
}
}
}
det_t[0] += det_times[0];
det_t[1] += det_times[1];
det_t[2] += det_times[2];
int steps = ceil(float(all_img_paths.size()) / batch_size);
// printf("total images = %d, batch_size = %d, total steps = %d\n",
// all_img_paths.size(), batch_size, steps);
for (int idx = 0; idx < steps; idx++) {
int left_image_cnt = all_img_paths.size() - idx * batch_size;
if (left_image_cnt > batch_size) {
left_image_cnt = batch_size;
}
// for (int bs = 0; bs < left_image_cnt; bs++) {
// std::string image_file_path = all_img_paths.at(idx * batch_size+bs);
// cv::Mat im = cv::imread(image_file_path, 1);
// batch_imgs.insert(batch_imgs.end(), im);
// }
// Store all detected result
std::vector<Detection::ObjectResult> result;
std::vector<int> bbox_num;
std::vector<double> det_times;
bool is_rbox = false;
if (run_benchmark) {
det->Predict(batch_imgs, 10, 10, &result, &bbox_num, &det_times);
} else {
det->Predict(batch_imgs, 0, 1, &result, &bbox_num, &det_times);
// get labels and colormap
auto labels = det->GetLabelList();
auto colormap = Detection::GenerateColorMap(labels.size());
int item_start_idx = 0;
for (int i = 0; i < left_image_cnt; i++) {
cv::Mat im = batch_imgs[i];
int detect_num = 0;
for (int j = 0; j < bbox_num[i]; j++) {
Detection::ObjectResult item = result[item_start_idx + j];
if (item.confidence < det->GetThreshold() || item.class_id == -1) {
continue;
}
detect_num += 1;
im_result.push_back(item);
if (visual_det) {
if (item.rect.size() > 6) {
is_rbox = true;
printf(
"class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
item.class_id, item.confidence, item.rect[0], item.rect[1],
item.rect[2], item.rect[3], item.rect[4], item.rect[5],
item.rect[6], item.rect[7]);
} else {
printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
item.class_id, item.confidence, item.rect[0], item.rect[1],
item.rect[2], item.rect[3]);
}
}
}
// im_bbox_num.push_back(detect_num);
item_start_idx = item_start_idx + bbox_num[i];
// Visualization result
if (visual_det) {
std::cout << all_img_paths.at(idx * batch_size + i)
<< " The number of detected box: " << detect_num
<< std::endl;
cv::Mat vis_img = Detection::VisualizeResult(im, im_result, labels,
colormap, is_rbox);
std::vector<int> compression_params;
compression_params.push_back(CV_IMWRITE_JPEG_QUALITY);
compression_params.push_back(95);
std::string output_path(output_dir);
if (output_dir.rfind(OS_PATH_SEP) != output_dir.size() - 1) {
output_path += OS_PATH_SEP;
}
std::string image_file_path = all_img_paths.at(idx * batch_size + i);
output_path +=
image_file_path.substr(image_file_path.find_last_of('/') + 1);
cv::imwrite(output_path, vis_img, compression_params);
printf("Visualized output saved as %s\n", output_path.c_str());
}
}
}
det_t[0] += det_times[0];
det_t[1] += det_times[1];
det_t[2] += det_times[2];
}
}
void PrintResult(std::string &img_path,
std::vector <Detection::ObjectResult> &det_result,
std::vector<int> &indeices, VectorSearch &vector_search,
std::vector<Detection::ObjectResult> &det_result,
std::vector<int> &indeices, VectorSearch *vector_search_ptr,
SearchResult &search_result) {
printf("%s:\n", img_path.c_str());
for (int i = 0; i < indeices.size(); ++i) {
int t = indeices[i];
printf("\tresult%d: bbox[%d, %d, %d, %d], score: %f, label: %s\n", i,
det_result[t].rect[0], det_result[t].rect[1], det_result[t].rect[2],
det_result[t].rect[3], det_result[t].confidence,
vector_search.GetLabel(search_result.I[search_result.return_k * t])
.c_str());
}
printf("%s:\n", img_path.c_str());
for (int i = 0; i < indeices.size(); ++i) {
int t = indeices[i];
printf(
"\tresult%d: bbox[%d, %d, %d, %d], score: %f, label: %s\n", i,
det_result[t].rect[0], det_result[t].rect[1], det_result[t].rect[2],
det_result[t].rect[3], det_result[t].confidence,
vector_search_ptr->GetLabel(search_result.I[search_result.return_k * t])
.c_str());
}
}
int main(int argc, char **argv) {
google::ParseCommandLineFlags(&argc, &argv, true);
std::string yaml_path = "";
if (FLAGS_config == "" && FLAGS_c == "") {
std::cerr << "[ERROR] usage: " << std::endl
<< argv[0] << " -c $yaml_path" << std::endl
<< "or:" << std::endl
<< argv[0] << " -config $yaml_path" << std::endl;
exit(1);
} else if (FLAGS_config != "") {
yaml_path = FLAGS_config;
} else {
yaml_path = FLAGS_c;
google::ParseCommandLineFlags(&argc, &argv, true);
std::string yaml_path = "";
if (FLAGS_config == "" && FLAGS_c == "") {
std::cerr << "[ERROR] usage: " << std::endl
<< argv[0] << " -c $yaml_path" << std::endl
<< "or:" << std::endl
<< argv[0] << " -config $yaml_path" << std::endl;
exit(1);
} else if (FLAGS_config != "") {
yaml_path = FLAGS_config;
} else {
yaml_path = FLAGS_c;
}
YamlConfig config(yaml_path);
config.PrintConfigInfo();
// initialize detector
Detection::ObjectDetector *detector_ptr = nullptr;
if (config.config_file["Global"]["det_inference_model_dir"].Type() !=
YAML::NodeType::Null &&
!config.config_file["Global"]["det_inference_model_dir"]
.as<std::string>()
.empty()) {
detector_ptr = new Detection::ObjectDetector(config.config_file);
} else {
std::cout << "Found 'Global.det_inference_model_dir' empty, so "
"det_predictor is disabled"
<< std::endl;
}
// initialize feature_extractor
Feature::FeatureExtracter *feature_extracter_ptr =
new Feature::FeatureExtracter(config.config_file);
// initialize vector_searcher
VectorSearch *vector_searcher_ptr = new VectorSearch(config.config_file);
// config
const int batch_size = config.config_file["Global"]["batch_size"].as<int>();
bool visual_det = false;
if (config.config_file["Global"]["visual_det"].IsDefined()) {
visual_det = config.config_file["Global"]["visual_det"].as<bool>();
}
bool benchmark = false;
if (config.config_file["Global"]["benchmark"].IsDefined()) {
benchmark = config.config_file["Global"]["benchmark"].as<bool>();
}
int max_det_results = 5;
if (config.config_file["Global"]["max_det_results"].IsDefined()) {
max_det_results = config.config_file["Global"]["max_det_results"].as<int>();
}
float rec_nms_thresold = 0.05;
if (config.config_file["Global"]["rec_nms_thresold"].IsDefined()) {
rec_nms_thresold =
config.config_file["Global"]["rec_nms_thresold"].as<float>();
}
// load image_file_path
std::string path =
config.config_file["Global"]["infer_imgs"].as<std::string>();
std::vector<std::string> img_files_list;
if (cv::utils::fs::isDirectory(path)) {
std::vector<cv::String> filenames;
cv::glob(path, filenames);
for (auto f : filenames) {
img_files_list.push_back(f);
}
} else {
img_files_list.push_back(path);
}
std::cout << "img_file_list length: " << img_files_list.size() << std::endl;
// for time log
std::vector<double> cls_times = {0, 0, 0};
std::vector<double> det_times = {0, 0, 0};
std::vector<double> search_times = {0, 0, 0};
int instance_num = 0;
// for read images
std::vector<cv::Mat> batch_imgs;
std::vector<std::string> img_paths;
// for detection
std::vector<Detection::ObjectResult> det_result;
// for vector search
std::vector<float> features;
std::vector<float> feature;
// for nms
std::vector<int> indeices;
int warmup_iter = img_files_list.size() > 5 ? 5 : img_files_list.size();
if (benchmark) {
img_files_list.insert(img_files_list.begin(), img_files_list.begin(),
img_files_list.begin() + warmup_iter);
}
for (int idx = 0; idx < img_files_list.size(); ++idx) {
std::string img_path = img_files_list[idx];
cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: " << img_path
<< "\n";
exit(-1);
}
cv::cvtColor(srcimg, srcimg, cv::COLOR_BGR2RGB);
batch_imgs.push_back(srcimg);
img_paths.push_back(img_path);
// step1: get all detection results if enable detector
if (detector_ptr != nullptr) {
DetPredictImage(batch_imgs, img_paths, batch_size, detector_ptr,
det_result, det_times, visual_det, false);
}
YamlConfig config(yaml_path);
config.PrintConfigInfo();
// select max_det_results bbox
if (det_result.size() > max_det_results) {
det_result.resize(max_det_results);
}
instance_num += det_result.size();
// initialize detector, rec_Model, vector_search
Feature::FeatureExtracter feature_extracter(config.config_file);
Detection::ObjectDetector detector(config.config_file);
VectorSearch searcher(config.config_file);
// step2: add the whole image for recognition to improve recall
Detection::ObjectResult result_whole_img = {
{0, 0, srcimg.cols - 1, srcimg.rows - 1}, 0, 1.0};
det_result.push_back(result_whole_img);
// config
const int batch_size = config.config_file["Global"]["batch_size"].as<int>();
bool visual_det = false;
if (config.config_file["Global"]["visual_det"].IsDefined()) {
visual_det = config.config_file["Global"]["visual_det"].as<bool>();
}
bool benchmark = false;
if (config.config_file["Global"]["benchmark"].IsDefined()) {
benchmark = config.config_file["Global"]["benchmark"].as<bool>();
}
int max_det_results = 5;
if (config.config_file["Global"]["max_det_results"].IsDefined()) {
max_det_results = config.config_file["Global"]["max_det_results"].as<int>();
}
float rec_nms_thresold = 0.05;
if (config.config_file["Global"]["rec_nms_thresold"].IsDefined()) {
rec_nms_thresold =
config.config_file["Global"]["rec_nms_thresold"].as<float>();
// step3: extract feature for all boxes in an inmage
SearchResult search_result;
for (int j = 0; j < det_result.size(); ++j) {
int w = det_result[j].rect[2] - det_result[j].rect[0];
int h = det_result[j].rect[3] - det_result[j].rect[1];
cv::Rect rect(det_result[j].rect[0], det_result[j].rect[1], w, h);
cv::Mat crop_img = srcimg(rect);
feature_extracter_ptr->Run(crop_img, feature, cls_times);
features.insert(features.end(), feature.begin(), feature.end());
}
// load image_file_path
std::string path =
config.config_file["Global"]["infer_imgs"].as<std::string>();
std::vector <std::string> img_files_list;
if (cv::utils::fs::isDirectory(path)) {
std::vector <cv::String> filenames;
cv::glob(path, filenames);
for (auto f : filenames) {
img_files_list.push_back(f);
}
} else {
img_files_list.push_back(path);
}
std::cout << "img_file_list length: " << img_files_list.size() << std::endl;
// for time log
std::vector<double> cls_times = {0, 0, 0};
std::vector<double> det_times = {0, 0, 0};
std::vector<double> search_times = {0, 0, 0};
int instance_num = 0;
// for read images
std::vector <cv::Mat> batch_imgs;
std::vector <std::string> img_paths;
// for detection
std::vector <Detection::ObjectResult> det_result;
std::vector<int> det_bbox_num;
// for vector search
std::vector<float> features;
std::vector<float> feature;
// for nms
std::vector<int> indeices;
// step4: get search result
auto search_start = std::chrono::steady_clock::now();
search_result =
vector_searcher_ptr->Search(features.data(), det_result.size());
auto search_end = std::chrono::steady_clock::now();
int warmup_iter = img_files_list.size() > 5 ? 5 : img_files_list.size();
if (benchmark) {
img_files_list.insert(img_files_list.begin(), img_files_list.begin(),
img_files_list.begin() + warmup_iter);
// nms for search result
for (int i = 0; i < det_result.size(); ++i) {
det_result[i].confidence = search_result.D[search_result.return_k * i];
}
NMSBoxes(det_result, vector_searcher_ptr->GetThreshold(), rec_nms_thresold,
indeices);
auto nms_end = std::chrono::steady_clock::now();
std::chrono::duration<float> search_diff = search_end - search_start;
search_times[1] += double(search_diff.count() * 1000);
std::chrono::duration<float> nms_diff = nms_end - search_end;
search_times[2] += double(nms_diff.count() * 1000);
// print result
if (not benchmark or (benchmark and idx >= warmup_iter))
PrintResult(img_path, det_result, indeices, vector_searcher_ptr,
search_result);
// for postprocess
batch_imgs.clear();
img_paths.clear();
det_result.clear();
feature.clear();
features.clear();
indeices.clear();
if (benchmark and warmup_iter == idx + 1) {
det_times = {0, 0, 0};
cls_times = {0, 0, 0};
search_times = {0, 0, 0};
instance_num = 0;
}
}
if (benchmark) {
std::string presion = "fp32";
if (config.config_file["Global"]["use_fp16"].IsDefined() and
config.config_file["Global"]["use_fp16"].as<bool>())
presion = "fp16";
bool use_gpu = config.config_file["Global"]["use_gpu"].as<bool>();
bool use_tensorrt = config.config_file["Global"]["use_tensorrt"].as<bool>();
bool enable_mkldnn =
config.config_file["Global"]["enable_mkldnn"].as<bool>();
int cpu_num_threads =
config.config_file["Global"]["cpu_num_threads"].as<int>();
int batch_size = config.config_file["Global"]["batch_size"].as<int>();
std::vector<int> shape =
config.config_file["Global"]["image_shape"].as<std::vector<int>>();
std::string det_shape = std::to_string(shape[0]);
for (int i = 1; i < shape.size(); ++i) {
det_shape = det_shape + ", " + std::to_string(shape[i]);
}
for (int idx = 0; idx < img_files_list.size(); ++idx) {
std::string img_path = img_files_list[idx];
cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: " << img_path
<< "\n";
exit(-1);
}
cv::cvtColor(srcimg, srcimg, cv::COLOR_BGR2RGB);
batch_imgs.push_back(srcimg);
img_paths.push_back(img_path);
// step1: get all detection results
DetPredictImage(batch_imgs, img_paths, batch_size, &detector, det_result,
det_bbox_num, det_times, visual_det, false);
// select max_det_results bbox
if (det_result.size() > max_det_results) {
det_result.resize(max_det_results);
}
instance_num += det_result.size();
// step2: add the whole image for recognition to improve recall
Detection::ObjectResult result_whole_img = {
{0, 0, srcimg.cols - 1, srcimg.rows - 1}, 0, 1.0};
det_result.push_back(result_whole_img);
det_bbox_num[0] = det_result.size() + 1;
// step3: extract feature for all boxes in an inmage
SearchResult search_result;
for (int j = 0; j < det_result.size(); ++j) {
int w = det_result[j].rect[2] - det_result[j].rect[0];
int h = det_result[j].rect[3] - det_result[j].rect[1];
cv::Rect rect(det_result[j].rect[0], det_result[j].rect[1], w, h);
cv::Mat crop_img = srcimg(rect);
feature_extracter.Run(crop_img, feature, cls_times);
features.insert(features.end(), feature.begin(), feature.end());
}
// step4: get search result
auto search_start = std::chrono::steady_clock::now();
search_result = searcher.Search(features.data(), det_result.size());
auto search_end = std::chrono::steady_clock::now();
// nms for search result
for (int i = 0; i < det_result.size(); ++i) {
det_result[i].confidence = search_result.D[search_result.return_k * i];
}
NMSBoxes(det_result, searcher.GetThreshold(), rec_nms_thresold, indeices);
auto nms_end = std::chrono::steady_clock::now();
std::chrono::duration<float> search_diff = search_end - search_start;
search_times[1] += double(search_diff.count() * 1000);
std::chrono::duration<float> nms_diff = nms_end - search_end;
search_times[2] += double(nms_diff.count() * 1000);
// print result
if (not benchmark or (benchmark and idx >= warmup_iter))
PrintResult(img_path, det_result, indeices, searcher, search_result);
// for postprocess
batch_imgs.clear();
img_paths.clear();
det_bbox_num.clear();
det_result.clear();
feature.clear();
features.clear();
indeices.clear();
if (benchmark and warmup_iter == idx + 1) {
det_times = {0, 0, 0};
cls_times = {0, 0, 0};
search_times = {0, 0, 0};
instance_num = 0;
}
}
if (benchmark) {
std::string presion = "fp32";
if (config.config_file["Global"]["use_fp16"].IsDefined() and
config.config_file["Global"]["use_fp16"].as<bool>())
presion = "fp16";
bool use_gpu = config.config_file["Global"]["use_gpu"].as<bool>();
bool use_tensorrt = config.config_file["Global"]["use_tensorrt"].as<bool>();
bool enable_mkldnn =
config.config_file["Global"]["enable_mkldnn"].as<bool>();
int cpu_num_threads =
config.config_file["Global"]["cpu_num_threads"].as<int>();
int batch_size = config.config_file["Global"]["batch_size"].as<int>();
std::vector<int> shape =
config.config_file["Global"]["image_shape"].as < std::vector < int >> ();
std::string det_shape = std::to_string(shape[0]);
for (int i = 1; i < shape.size(); ++i)
det_shape = det_shape + ", " + std::to_string(shape[i]);
AutoLogger autolog_det("Det", use_gpu, use_tensorrt, enable_mkldnn,
cpu_num_threads, batch_size, det_shape, presion,
det_times, img_files_list.size() - warmup_iter);
autolog_det.report();
AutoLogger autolog_rec("Rec", use_gpu, use_tensorrt, enable_mkldnn,
cpu_num_threads, batch_size, "3, 224, 224", presion,
cls_times, instance_num);
autolog_rec.report();
AutoLogger autolog_search("Search", false, use_tensorrt, enable_mkldnn,
cpu_num_threads, batch_size, "dynamic", presion,
search_times, instance_num);
autolog_search.report();
}
return 0;
AutoLogger autolog_det("Det", use_gpu, use_tensorrt, enable_mkldnn,
cpu_num_threads, batch_size, det_shape, presion,
det_times, img_files_list.size() - warmup_iter);
autolog_det.report();
AutoLogger autolog_rec("Rec", use_gpu, use_tensorrt, enable_mkldnn,
cpu_num_threads, batch_size, "3, 224, 224", presion,
cls_times, instance_num);
autolog_rec.report();
AutoLogger autolog_search("Search", false, use_tensorrt, enable_mkldnn,
cpu_num_threads, batch_size, "dynamic", presion,
search_times, instance_num);
autolog_search.report();
}
return 0;
}

View File

@ -31,7 +31,14 @@ class SystemPredictor(object):
self.config = config
self.rec_predictor = RecPredictor(config)
self.det_predictor = DetPredictor(config)
if not config["Global"]["det_inference_model_dir"]:
logger.info(
f"Found 'Global.det_inference_model_dir' empty({config['Global']['det_inference_model_dir']}), so det_predictor is disabled"
)
self.det_predictor = None
else:
self.det_predictor = DetPredictor(config)
assert 'IndexProcess' in config.keys(), "Index config not found ... "
self.return_k = self.config['IndexProcess']['return_k']
@ -92,7 +99,10 @@ class SystemPredictor(object):
def predict(self, img):
output = []
# st1: get all detection results
results = self.det_predictor.predict(img)
if self.det_predictor:
results = self.det_predictor.predict(img)
else:
results = []
# st2: add the whole image for recognition to improve recall
results = self.append_self(results, img.shape)

View File

@ -198,6 +198,21 @@ The final output is as follows.
[{'bbox': [437, 71, 660, 728], 'rec_docs': '元气森林', 'rec_scores': 0.7740249}, {'bbox': [221, 72, 449, 701], 'rec_docs': '元气森林', 'rec_scores': 0.6950992}, {'bbox': [794, 104, 979, 652], 'rec_docs': '元气森林', 'rec_scores': 0.6305153}]
```
The recognition process supports flexible configuration. Users can choose not to use the object detection model, but directly input a single whole image into the feature extraction model, and calculate the feature vector for subsequent retrieval, thereby reducing the time-consuming of the overall recognition process. It can be achieved by the script below
```shell
# Use the following command to use the GPU for whole-image prediction
python3.7 python/predict_system.py -c configs/inference_general.yaml -o Global.det_inference_model_dir=None
# Use the following command to use the CPU for whole-image prediction
python3.7 python/predict_system.py -c configs/inference_general.yaml -o Global.use_gpu=False -o Global.det_inference_model_dir=None
```
The final output is as follows
```log
INFO: find 'Global.det_inference_model_dir' empty(), so det_predictor is disabled
[{'bbox': [0, 0, 1200, 802], 'rec_docs': '元气森林', 'rec_scores': 0.5696486}]
```
#### 4.3.2 multi images prediction
If you want to predict the images in the folder, you can directly modify the `Global.infer_imgs` field in the configuration file, or you can modify the corresponding configuration through the following -o parameter.

View File

@ -348,11 +348,25 @@ cd ..
./build/pp_shitu -c inference_drink.yaml
```
`drink_dataset_v2.0/test_images/nongfu_spring.jpeg` 作为输入图像,则执行上述推理命令可以得到如下结果
默认以 `../drink_dataset_v2.0/test_images/100.jpeg` 作为输入图像,则执行上述推理命令可以得到如下结果
```log
../../deploy/drink_dataset_v2.0/test_images/nongfu_spring.jpeg:
result0: bbox[0, 0, 729, 1094], score: 0.688691, label: 农夫山泉-饮用天然水
../drink_dataset_v2.0/test_images/100.jpeg:
result0: bbox[437, 72, 660, 723], score: 0.769916, label: 元气森林
result1: bbox[220, 71, 449, 685], score: 0.695485, label: 元气森林
result2: bbox[795, 104, 979, 653], score: 0.626963, label: 元气森林
```
识别流程支持灵活配置,用户可以选择不使用主体检测模型,而直接将单幅整图输入到特征提取模型,计算特征向量供后续检索使用,从而减少整体识别流程的耗时。只需将`Global.det_inference_model_dir`后的字段改为`null`或者`""`,再运行以下推理命令即可
```shell
./build/pp_shitu -c inference_drink.yaml
```
最终输出结果如下
```log
Found 'Global.det_inference_model_dir' empty, so det_predictor is disabled
../drink_dataset_v2.0/test_images/100.jpeg:
result0: bbox[0, 0, 1199, 801], score: 0.568903, label: 元气森林
```
由于python和C++的opencv实现存在部分不同可能导致python推理和C++推理结果有微小差异。但基本不影响最终的检索结果。

View File

@ -223,6 +223,20 @@ python3.7 python/predict_system.py -c configs/inference_general.yaml -o Global.u
![](../../images/recognition/drink_data_demo/output/100.jpeg)
识别流程支持灵活配置,用户可以选择不使用主体检测模型,而直接将单幅整图输入到特征提取模型,计算特征向量供后续检索使用,从而减少整体识别流程的耗时。可以按照以下命令直接进行整图识别
```shell
# 使用下面的命令使用 GPU 进行整图预测
python3.7 python/predict_system.py -c configs/inference_general.yaml -o Global.det_inference_model_dir=None
# 使用下面的命令使用 CPU 进行整图预测
python3.7 python/predict_system.py -c configs/inference_general.yaml -o Global.use_gpu=False -o Global.det_inference_model_dir=None
```
最终输出结果如下
```log
INFO: Found 'Global.det_inference_model_dir' empty(), so det_predictor is disabled
[{'bbox': [0, 0, 1200, 802], 'rec_docs': '元气森林', 'rec_scores': 0.5696486}]
```
<a name="基于文件夹的批量识别"></a>