mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
add flexible configuration for disable det model(C++)
This commit is contained in:
parent
6cf954e1c7
commit
b31d67ea32
@ -37,17 +37,14 @@
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace cv;
|
using namespace cv;
|
||||||
|
|
||||||
DEFINE_string(config,
|
DEFINE_string(config, "", "Path of yaml file");
|
||||||
"", "Path of yaml file");
|
DEFINE_string(c, "", "Path of yaml file");
|
||||||
DEFINE_string(c,
|
|
||||||
"", "Path of yaml file");
|
|
||||||
|
|
||||||
void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
|
void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
|
||||||
const std::vector<std::string> &all_img_paths,
|
const std::vector<std::string> &all_img_paths,
|
||||||
const int batch_size, Detection::ObjectDetector *det,
|
const int batch_size, Detection::ObjectDetector *det,
|
||||||
std::vector<Detection::ObjectResult> &im_result,
|
std::vector<Detection::ObjectResult> &im_result,
|
||||||
std::vector<int> &im_bbox_num, std::vector<double> &det_t,
|
std::vector<double> &det_t, const bool visual_det = false,
|
||||||
const bool visual_det = false,
|
|
||||||
const bool run_benchmark = false,
|
const bool run_benchmark = false,
|
||||||
const std::string &output_dir = "output") {
|
const std::string &output_dir = "output") {
|
||||||
int steps = ceil(float(all_img_paths.size()) / batch_size);
|
int steps = ceil(float(all_img_paths.size()) / batch_size);
|
||||||
@ -104,7 +101,7 @@ void DetPredictImage(const std::vector <cv::Mat> &batch_imgs,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
im_bbox_num.push_back(detect_num);
|
// im_bbox_num.push_back(detect_num);
|
||||||
item_start_idx = item_start_idx + bbox_num[i];
|
item_start_idx = item_start_idx + bbox_num[i];
|
||||||
|
|
||||||
// Visualization result
|
// Visualization result
|
||||||
@ -137,15 +134,16 @@ void DetPredictImage(const std::vector <cv::Mat> &batch_imgs,
|
|||||||
|
|
||||||
void PrintResult(std::string &img_path,
|
void PrintResult(std::string &img_path,
|
||||||
std::vector<Detection::ObjectResult> &det_result,
|
std::vector<Detection::ObjectResult> &det_result,
|
||||||
std::vector<int> &indeices, VectorSearch &vector_search,
|
std::vector<int> &indeices, VectorSearch *vector_search_ptr,
|
||||||
SearchResult &search_result) {
|
SearchResult &search_result) {
|
||||||
printf("%s:\n", img_path.c_str());
|
printf("%s:\n", img_path.c_str());
|
||||||
for (int i = 0; i < indeices.size(); ++i) {
|
for (int i = 0; i < indeices.size(); ++i) {
|
||||||
int t = indeices[i];
|
int t = indeices[i];
|
||||||
printf("\tresult%d: bbox[%d, %d, %d, %d], score: %f, label: %s\n", i,
|
printf(
|
||||||
|
"\tresult%d: bbox[%d, %d, %d, %d], score: %f, label: %s\n", i,
|
||||||
det_result[t].rect[0], det_result[t].rect[1], det_result[t].rect[2],
|
det_result[t].rect[0], det_result[t].rect[1], det_result[t].rect[2],
|
||||||
det_result[t].rect[3], det_result[t].confidence,
|
det_result[t].rect[3], det_result[t].confidence,
|
||||||
vector_search.GetLabel(search_result.I[search_result.return_k * t])
|
vector_search_ptr->GetLabel(search_result.I[search_result.return_k * t])
|
||||||
.c_str());
|
.c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -168,10 +166,21 @@ int main(int argc, char **argv) {
|
|||||||
YamlConfig config(yaml_path);
|
YamlConfig config(yaml_path);
|
||||||
config.PrintConfigInfo();
|
config.PrintConfigInfo();
|
||||||
|
|
||||||
// initialize detector, rec_Model, vector_search
|
// initialize detector
|
||||||
Feature::FeatureExtracter feature_extracter(config.config_file);
|
Detection::ObjectDetector *detector_ptr = nullptr;
|
||||||
Detection::ObjectDetector detector(config.config_file);
|
if (config.config_file["Global"]["det_inference_model_dir"].Type() !=
|
||||||
VectorSearch searcher(config.config_file);
|
YAML::NodeType::Null &&
|
||||||
|
!config.config_file["Global"]["det_inference_model_dir"]
|
||||||
|
.as<std::string>()
|
||||||
|
.empty()) {
|
||||||
|
detector_ptr = new Detection::ObjectDetector(config.config_file);
|
||||||
|
}
|
||||||
|
|
||||||
|
// initialize feature_extractor
|
||||||
|
Feature::FeatureExtracter *feature_extracter_ptr =
|
||||||
|
new Feature::FeatureExtracter(config.config_file);
|
||||||
|
// initialize vector_searcher
|
||||||
|
VectorSearch *vector_searcher_ptr = new VectorSearch(config.config_file);
|
||||||
|
|
||||||
// config
|
// config
|
||||||
const int batch_size = config.config_file["Global"]["batch_size"].as<int>();
|
const int batch_size = config.config_file["Global"]["batch_size"].as<int>();
|
||||||
@ -217,7 +226,7 @@ int main(int argc, char **argv) {
|
|||||||
std::vector<std::string> img_paths;
|
std::vector<std::string> img_paths;
|
||||||
// for detection
|
// for detection
|
||||||
std::vector<Detection::ObjectResult> det_result;
|
std::vector<Detection::ObjectResult> det_result;
|
||||||
std::vector<int> det_bbox_num;
|
|
||||||
// for vector search
|
// for vector search
|
||||||
std::vector<float> features;
|
std::vector<float> features;
|
||||||
std::vector<float> feature;
|
std::vector<float> feature;
|
||||||
@ -243,9 +252,11 @@ int main(int argc, char **argv) {
|
|||||||
batch_imgs.push_back(srcimg);
|
batch_imgs.push_back(srcimg);
|
||||||
img_paths.push_back(img_path);
|
img_paths.push_back(img_path);
|
||||||
|
|
||||||
// step1: get all detection results
|
// step1: get all detection results if enable detector
|
||||||
DetPredictImage(batch_imgs, img_paths, batch_size, &detector, det_result,
|
if (detector_ptr != nullptr) {
|
||||||
det_bbox_num, det_times, visual_det, false);
|
DetPredictImage(batch_imgs, img_paths, batch_size, detector_ptr,
|
||||||
|
det_result, det_times, visual_det, false);
|
||||||
|
}
|
||||||
|
|
||||||
// select max_det_results bbox
|
// select max_det_results bbox
|
||||||
if (det_result.size() > max_det_results) {
|
if (det_result.size() > max_det_results) {
|
||||||
@ -257,7 +268,6 @@ int main(int argc, char **argv) {
|
|||||||
Detection::ObjectResult result_whole_img = {
|
Detection::ObjectResult result_whole_img = {
|
||||||
{0, 0, srcimg.cols - 1, srcimg.rows - 1}, 0, 1.0};
|
{0, 0, srcimg.cols - 1, srcimg.rows - 1}, 0, 1.0};
|
||||||
det_result.push_back(result_whole_img);
|
det_result.push_back(result_whole_img);
|
||||||
det_bbox_num[0] = det_result.size() + 1;
|
|
||||||
|
|
||||||
// step3: extract feature for all boxes in an inmage
|
// step3: extract feature for all boxes in an inmage
|
||||||
SearchResult search_result;
|
SearchResult search_result;
|
||||||
@ -266,20 +276,22 @@ int main(int argc, char **argv) {
|
|||||||
int h = det_result[j].rect[3] - det_result[j].rect[1];
|
int h = det_result[j].rect[3] - det_result[j].rect[1];
|
||||||
cv::Rect rect(det_result[j].rect[0], det_result[j].rect[1], w, h);
|
cv::Rect rect(det_result[j].rect[0], det_result[j].rect[1], w, h);
|
||||||
cv::Mat crop_img = srcimg(rect);
|
cv::Mat crop_img = srcimg(rect);
|
||||||
feature_extracter.Run(crop_img, feature, cls_times);
|
feature_extracter_ptr->Run(crop_img, feature, cls_times);
|
||||||
features.insert(features.end(), feature.begin(), feature.end());
|
features.insert(features.end(), feature.begin(), feature.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
// step4: get search result
|
// step4: get search result
|
||||||
auto search_start = std::chrono::steady_clock::now();
|
auto search_start = std::chrono::steady_clock::now();
|
||||||
search_result = searcher.Search(features.data(), det_result.size());
|
search_result =
|
||||||
|
vector_searcher_ptr->Search(features.data(), det_result.size());
|
||||||
auto search_end = std::chrono::steady_clock::now();
|
auto search_end = std::chrono::steady_clock::now();
|
||||||
|
|
||||||
// nms for search result
|
// nms for search result
|
||||||
for (int i = 0; i < det_result.size(); ++i) {
|
for (int i = 0; i < det_result.size(); ++i) {
|
||||||
det_result[i].confidence = search_result.D[search_result.return_k * i];
|
det_result[i].confidence = search_result.D[search_result.return_k * i];
|
||||||
}
|
}
|
||||||
NMSBoxes(det_result, searcher.GetThreshold(), rec_nms_thresold, indeices);
|
NMSBoxes(det_result, vector_searcher_ptr->GetThreshold(), rec_nms_thresold,
|
||||||
|
indeices);
|
||||||
auto nms_end = std::chrono::steady_clock::now();
|
auto nms_end = std::chrono::steady_clock::now();
|
||||||
std::chrono::duration<float> search_diff = search_end - search_start;
|
std::chrono::duration<float> search_diff = search_end - search_start;
|
||||||
search_times[1] += double(search_diff.count() * 1000);
|
search_times[1] += double(search_diff.count() * 1000);
|
||||||
@ -289,12 +301,12 @@ int main(int argc, char **argv) {
|
|||||||
|
|
||||||
// print result
|
// print result
|
||||||
if (not benchmark or (benchmark and idx >= warmup_iter))
|
if (not benchmark or (benchmark and idx >= warmup_iter))
|
||||||
PrintResult(img_path, det_result, indeices, searcher, search_result);
|
PrintResult(img_path, det_result, indeices, vector_searcher_ptr,
|
||||||
|
search_result);
|
||||||
|
|
||||||
// for postprocess
|
// for postprocess
|
||||||
batch_imgs.clear();
|
batch_imgs.clear();
|
||||||
img_paths.clear();
|
img_paths.clear();
|
||||||
det_bbox_num.clear();
|
|
||||||
det_result.clear();
|
det_result.clear();
|
||||||
feature.clear();
|
feature.clear();
|
||||||
features.clear();
|
features.clear();
|
||||||
@ -322,8 +334,10 @@ int main(int argc, char **argv) {
|
|||||||
std::vector<int> shape =
|
std::vector<int> shape =
|
||||||
config.config_file["Global"]["image_shape"].as<std::vector<int>>();
|
config.config_file["Global"]["image_shape"].as<std::vector<int>>();
|
||||||
std::string det_shape = std::to_string(shape[0]);
|
std::string det_shape = std::to_string(shape[0]);
|
||||||
for (int i = 1; i < shape.size(); ++i)
|
|
||||||
|
for (int i = 1; i < shape.size(); ++i) {
|
||||||
det_shape = det_shape + ", " + std::to_string(shape[i]);
|
det_shape = det_shape + ", " + std::to_string(shape[i]);
|
||||||
|
}
|
||||||
|
|
||||||
AutoLogger autolog_det("Det", use_gpu, use_tensorrt, enable_mkldnn,
|
AutoLogger autolog_det("Det", use_gpu, use_tensorrt, enable_mkldnn,
|
||||||
cpu_num_threads, batch_size, det_shape, presion,
|
cpu_num_threads, batch_size, det_shape, presion,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user