first version for whole pipeline
parent
c1a598f9ba
commit
ba65fa9ae2
|
@ -0,0 +1,85 @@
|
|||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// This code is adapted from opencv(https://github.com/opencv/opencv)
|
||||
|
||||
#include <algorithm>
|
||||
#include <include/object_detector.h>
|
||||
|
||||
// Comparator that orders (score, payload) pairs from highest score to lowest.
//
// Returns true only when `pair1` carries a strictly greater score than
// `pair2`. Equal scores compare as equivalent, so a stable sort keeps such
// entries in their input order. The payload (`second`) is never inspected.
template <typename T>
static inline bool SortScorePairDescend(const std::pair<float, T> &pair1,
                                        const std::pair<float, T> &pair2) {
  const float lhs_score = pair1.first;
  const float rhs_score = pair2.first;
  return rhs_score < lhs_score;
}
|
||||
|
||||
// Computes the intersection-over-union (IoU) of two detection boxes.
//
// a, b: detections whose `rect[0..3]` are read. Indices 0 and 2 are treated
//       as the low/high bound on one axis and 1 and 3 as the low/high bound
//       on the other; the "+ 1" terms treat the bounds as inclusive.
// Returns intersection_area / union_area, or 0 when the union area is not
// positive (degenerate or inverted boxes), so that no NaN/inf reaches the
// caller's threshold comparison.
//
// Declared `inline` because this definition appears to live in a header;
// without it, including the header from two translation units produces
// multiple-definition link errors.
inline float RectOverlap(const PaddleDetection::ObjectResult &a,
                         const PaddleDetection::ObjectResult &b) {
  // Multiply in float so that very large boxes cannot overflow `int`.
  const float area_a = static_cast<float>(a.rect[2] - a.rect[0] + 1) *
                       static_cast<float>(a.rect[3] - a.rect[1] + 1);
  const float area_b = static_cast<float>(b.rect[2] - b.rect[0] + 1) *
                       static_cast<float>(b.rect[3] - b.rect[1] + 1);

  // Extent of the overlap along each axis, clamped to 0 for disjoint boxes.
  const int inter_w = std::max(
      std::min(a.rect[2], b.rect[2]) - std::max(a.rect[0], b.rect[0]) + 1, 0);
  const int inter_h = std::max(
      std::min(a.rect[3], b.rect[3]) - std::max(a.rect[1], b.rect[1]) + 1, 0);
  const float area_inter =
      static_cast<float>(inter_w) * static_cast<float>(inter_h);

  const float area_union = area_a + area_b - area_inter;
  if (area_union <= 0.0f) {
    return 0.0f;
  }
  return area_inter / area_union;
}
|
||||
|
||||
// Get max scores with corresponding indices.
|
||||
// scores: a set of scores.
|
||||
// threshold: only consider scores higher than the threshold.
|
||||
// top_k: if -1, keep all; otherwise, keep at most top_k.
|
||||
// score_index_vec: store the sorted (score, index) pair.
|
||||
inline void
|
||||
GetMaxScoreIndex(const std::vector<PaddleDetection::ObjectResult> &det_result,
|
||||
const float threshold,
|
||||
std::vector<std::pair<float, int>> &score_index_vec) {
|
||||
// Generate index score pairs.
|
||||
for (size_t i = 0; i < det_result.size(); ++i) {
|
||||
if (det_result[i].confidence > threshold) {
|
||||
score_index_vec.push_back(std::make_pair(det_result[i].confidence, i));
|
||||
}
|
||||
}
|
||||
|
||||
// Sort the score pair according to the scores in descending order
|
||||
std::stable_sort(score_index_vec.begin(), score_index_vec.end(),
|
||||
SortScorePairDescend<int>);
|
||||
|
||||
// // Keep top_k scores if needed.
|
||||
// if (top_k > 0 && top_k < (int)score_index_vec.size())
|
||||
// {
|
||||
// score_index_vec.resize(top_k);
|
||||
// }
|
||||
}
|
||||
|
||||
// Greedy non-maximum suppression over a set of detections.
//
// det_result: detections to filter. Taken by const reference (the original
//             by-value parameter copied the whole vector on every call);
//             existing call sites compile unchanged.
// score_threshold: detections whose confidence is not strictly greater than
//             this value are discarded before suppression.
// nms_threshold: a candidate is dropped when its RectOverlap with any
//             already-kept detection is greater than this value.
// indices: output. Cleared, then filled with the indices (into det_result)
//             of the kept detections, in descending confidence order.
//
// Declared `inline` because this definition appears to live in a header;
// without it, including the header from two translation units produces
// multiple-definition link errors.
inline void NMSBoxes(const std::vector<PaddleDetection::ObjectResult> &det_result,
                     const float score_threshold, const float nms_threshold,
                     std::vector<int> &indices) {
  // Candidates above the score threshold, sorted by descending confidence.
  std::vector<std::pair<float, int>> score_index_vec;
  GetMaxScoreIndex(det_result, score_threshold, score_index_vec);

  // Do nms: walk candidates best-first and keep one only if it does not
  // overlap too much with anything kept so far.
  indices.clear();
  for (const auto &score_index : score_index_vec) {
    const int idx = score_index.second;
    bool keep = true;
    for (size_t k = 0; k < indices.size() && keep; ++k) {
      const int kept_idx = indices[k];
      const float overlap = RectOverlap(det_result[idx], det_result[kept_idx]);
      keep = overlap <= nms_threshold;
    }
    if (keep) {
      indices.push_back(idx);
    }
  }
}
|
|
@ -65,7 +65,6 @@ public:
|
|||
this->use_fp16_ = config_file["Global"]["use_fp16"].as<bool>();
|
||||
this->model_dir_ =
|
||||
config_file["Global"]["det_inference_model_dir"].as<std::string>();
|
||||
this->nms_thres_ = config_file["Global"]["rec_nms_thresold"].as<float>();
|
||||
this->threshold_ = config_file["Global"]["threshold"].as<float>();
|
||||
this->max_det_results_ = config_file["Global"]["max_det_results"].as<int>();
|
||||
this->image_shape_ =
|
||||
|
@ -105,7 +104,6 @@ private:
|
|||
bool batch_size_ = 1;
|
||||
bool use_fp16_ = false;
|
||||
std::string model_dir_;
|
||||
float nms_thres_ = 0.02;
|
||||
float threshold_ = 0.5;
|
||||
float max_det_results_ = 5;
|
||||
std::vector<int> image_shape_ = {3, 640, 640};
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include <auto_log/autolog.h>
|
||||
#include <gflags/gflags.h>
|
||||
#include <include/cls.h>
|
||||
#include <include/nms.h>
|
||||
#include <include/object_detector.h>
|
||||
#include <include/vector_search.h>
|
||||
#include <include/yaml_config.h>
|
||||
|
@ -132,6 +133,21 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
|
|||
}
|
||||
}
|
||||
|
||||
void PrintResult(std::string &img_path,
|
||||
std::vector<PaddleDetection::ObjectResult> &det_result,
|
||||
std::vector<int> &indeices, VectorSearch &vector_search,
|
||||
SearchResult &search_result) {
|
||||
printf("%s:\n", img_path.c_str());
|
||||
for (int i = 0; i < indeices.size(); ++i) {
|
||||
int t = indeices[i];
|
||||
printf("\tresult%d: bbox[%d, %d, %d, %d], score: %f, label: %s\n", i,
|
||||
det_result[t].rect[0], det_result[t].rect[1], det_result[t].rect[2],
|
||||
det_result[t].rect[3], det_result[t].confidence,
|
||||
vector_search.GetLabel(search_result.I[search_result.return_k * t])
|
||||
.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
google::ParseCommandLineFlags(&argc, &argv, true);
|
||||
std::string yaml_path = "";
|
||||
|
@ -169,6 +185,11 @@ int main(int argc, char **argv) {
|
|||
if (config.config_file["Global"]["max_det_results"].IsDefined()) {
|
||||
max_det_results = config.config_file["Global"]["max_det_results"].as<int>();
|
||||
}
|
||||
float rec_nms_thresold = 0.05;
|
||||
if (config.config_file["Global"]["rec_nms_thresold"].IsDefined()) {
|
||||
rec_nms_thresold =
|
||||
config.config_file["Global"]["rec_nms_thresold"].as<float>();
|
||||
}
|
||||
|
||||
// load image_file_path
|
||||
std::string path =
|
||||
|
@ -184,16 +205,20 @@ int main(int argc, char **argv) {
|
|||
img_files_list.push_back(path);
|
||||
}
|
||||
std::cout << "img_file_list length: " << img_files_list.size() << std::endl;
|
||||
|
||||
double elapsed_time = 0.0;
|
||||
// for time log
|
||||
std::vector<double> cls_times = {0, 0, 0};
|
||||
std::vector<double> det_times = {0, 0, 0};
|
||||
// for read images
|
||||
std::vector<cv::Mat> batch_imgs;
|
||||
std::vector<std::string> img_paths;
|
||||
// for detection
|
||||
std::vector<PaddleDetection::ObjectResult> det_result;
|
||||
std::vector<int> det_bbox_num;
|
||||
// for vector search
|
||||
std::vector<float> features;
|
||||
std::vector<float> feature;
|
||||
// for nms
|
||||
std::vector<int> indeices;
|
||||
|
||||
int warmup_iter = img_files_list.size() > 5 ? 5 : 0;
|
||||
for (int idx = 0; idx < img_files_list.size(); ++idx) {
|
||||
|
@ -214,8 +239,8 @@ int main(int argc, char **argv) {
|
|||
det_bbox_num, det_times, visual_det, run_benchmark);
|
||||
|
||||
// select max_det_results bbox
|
||||
while (det_result.size() > max_det_results) {
|
||||
det_result.pop_back();
|
||||
if (det_result.size() > max_det_results) {
|
||||
det_result.resize(max_det_results);
|
||||
}
|
||||
// step2: add the whole image for recognition to improve recall
|
||||
PaddleDetection::ObjectResult result_whole_img = {
|
||||
|
@ -238,6 +263,13 @@ int main(int argc, char **argv) {
|
|||
search_result = searcher.Search(features.data(), det_result.size());
|
||||
|
||||
// nms for search result
|
||||
for (int i = 0; i < det_result.size(); ++i) {
|
||||
det_result[i].confidence = search_result.D[search_result.return_k * i];
|
||||
}
|
||||
NMSBoxes(det_result, detector.GetThreshold(), rec_nms_thresold, indeices);
|
||||
|
||||
// print result
|
||||
PrintResult(img_path, det_result, indeices, searcher, search_result);
|
||||
|
||||
// for postprocess
|
||||
batch_imgs.clear();
|
||||
|
@ -246,6 +278,7 @@ int main(int argc, char **argv) {
|
|||
det_result.clear();
|
||||
feature.clear();
|
||||
features.clear();
|
||||
indeices.clear();
|
||||
}
|
||||
|
||||
std::string presion = "fp32";
|
||||
|
|
Loading…
Reference in New Issue