Merge pull request #2337 from HydrogenSulfate/add_optional_det
Add flexible configuration of PP-ShiTu detection model
commit 1d397104be
```diff
@@ -37,17 +37,14 @@
 using namespace std;
 using namespace cv;

-DEFINE_string(config,
-              "", "Path of yaml file");
-DEFINE_string(c,
-              "", "Path of yaml file");
+DEFINE_string(config, "", "Path of yaml file");
+DEFINE_string(c, "", "Path of yaml file");

 void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
                      const std::vector<std::string> &all_img_paths,
                      const int batch_size, Detection::ObjectDetector *det,
                      std::vector<Detection::ObjectResult> &im_result,
-                     std::vector<int> &im_bbox_num, std::vector<double> &det_t,
-                     const bool visual_det = false,
+                     std::vector<double> &det_t, const bool visual_det = false,
                      const bool run_benchmark = false,
                      const std::string &output_dir = "output") {
   int steps = ceil(float(all_img_paths.size()) / batch_size);
```
```diff
@@ -104,7 +101,7 @@ void DetPredictImage(const std::vector <cv::Mat> &batch_imgs,
        }
      }
    }
-    im_bbox_num.push_back(detect_num);
+    // im_bbox_num.push_back(detect_num);
    item_start_idx = item_start_idx + bbox_num[i];

    // Visualization result
```
```diff
@@ -137,15 +134,16 @@ void DetPredictImage(const std::vector <cv::Mat> &batch_imgs,

 void PrintResult(std::string &img_path,
                  std::vector<Detection::ObjectResult> &det_result,
-                 std::vector<int> &indeices, VectorSearch &vector_search,
+                 std::vector<int> &indeices, VectorSearch *vector_search_ptr,
                  SearchResult &search_result) {
   printf("%s:\n", img_path.c_str());
   for (int i = 0; i < indeices.size(); ++i) {
     int t = indeices[i];
-    printf("\tresult%d: bbox[%d, %d, %d, %d], score: %f, label: %s\n", i,
+    printf(
+        "\tresult%d: bbox[%d, %d, %d, %d], score: %f, label: %s\n", i,
         det_result[t].rect[0], det_result[t].rect[1], det_result[t].rect[2],
         det_result[t].rect[3], det_result[t].confidence,
-        vector_search.GetLabel(search_result.I[search_result.return_k * t])
+        vector_search_ptr->GetLabel(search_result.I[search_result.return_k * t])
             .c_str());
   }
 }
```
```diff
@@ -168,10 +166,25 @@ int main(int argc, char **argv) {
   YamlConfig config(yaml_path);
   config.PrintConfigInfo();

-  // initialize detector, rec_Model, vector_search
-  Feature::FeatureExtracter feature_extracter(config.config_file);
-  Detection::ObjectDetector detector(config.config_file);
-  VectorSearch searcher(config.config_file);
+  // initialize detector
+  Detection::ObjectDetector *detector_ptr = nullptr;
+  if (config.config_file["Global"]["det_inference_model_dir"].Type() !=
+          YAML::NodeType::Null &&
+      !config.config_file["Global"]["det_inference_model_dir"]
+           .as<std::string>()
+           .empty()) {
+    detector_ptr = new Detection::ObjectDetector(config.config_file);
+  } else {
+    std::cout << "Found 'Global.det_inference_model_dir' empty, so "
+                 "det_predictor is disabled"
+              << std::endl;
+  }
+
+  // initialize feature_extractor
+  Feature::FeatureExtracter *feature_extracter_ptr =
+      new Feature::FeatureExtracter(config.config_file);
+  // initialize vector_searcher
+  VectorSearch *vector_searcher_ptr = new VectorSearch(config.config_file);

   // config
   const int batch_size = config.config_file["Global"]["batch_size"].as<int>();
```
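A quick illustration of the guard above, as a minimal Python sketch (not part of this PR; the YAML fragment and model paths are made up): the detector branch is built only when `Global.det_inference_model_dir` exists and is a non-empty string.

```python
# Minimal sketch, not part of this PR: the same emptiness check as the C++ guard above.
# The YAML fragment is hypothetical; only Global.det_inference_model_dir matters here.
import yaml

cfg_text = """
Global:
  det_inference_model_dir: ""   # "" or null disables the detector
  rec_inference_model_dir: ./models/rec_model
"""
cfg = yaml.safe_load(cfg_text)

det_dir = cfg["Global"].get("det_inference_model_dir")
if det_dir:  # non-empty string -> build the detector
    print("detector enabled, model dir:", det_dir)
else:        # None (null) or "" -> same situation the C++ else-branch reports
    print("Found 'Global.det_inference_model_dir' empty, so det_predictor is disabled")
```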
```diff
@@ -217,7 +230,7 @@ int main(int argc, char **argv) {
   std::vector<std::string> img_paths;
   // for detection
   std::vector<Detection::ObjectResult> det_result;
-  std::vector<int> det_bbox_num;
+
   // for vector search
   std::vector<float> features;
   std::vector<float> feature;
```
```diff
@@ -243,9 +256,11 @@ int main(int argc, char **argv) {
     batch_imgs.push_back(srcimg);
     img_paths.push_back(img_path);

-    // step1: get all detection results
-    DetPredictImage(batch_imgs, img_paths, batch_size, &detector, det_result,
-                    det_bbox_num, det_times, visual_det, false);
+    // step1: get all detection results if enable detector
+    if (detector_ptr != nullptr) {
+      DetPredictImage(batch_imgs, img_paths, batch_size, detector_ptr,
+                      det_result, det_times, visual_det, false);
+    }

     // select max_det_results bbox
     if (det_result.size() > max_det_results) {
```
```diff
@@ -257,7 +272,6 @@ int main(int argc, char **argv) {
     Detection::ObjectResult result_whole_img = {
         {0, 0, srcimg.cols - 1, srcimg.rows - 1}, 0, 1.0};
     det_result.push_back(result_whole_img);
-    det_bbox_num[0] = det_result.size() + 1;

     // step3: extract feature for all boxes in an inmage
     SearchResult search_result;
```
```diff
@@ -266,20 +280,22 @@ int main(int argc, char **argv) {
       int h = det_result[j].rect[3] - det_result[j].rect[1];
       cv::Rect rect(det_result[j].rect[0], det_result[j].rect[1], w, h);
       cv::Mat crop_img = srcimg(rect);
-      feature_extracter.Run(crop_img, feature, cls_times);
+      feature_extracter_ptr->Run(crop_img, feature, cls_times);
       features.insert(features.end(), feature.begin(), feature.end());
     }

     // step4: get search result
     auto search_start = std::chrono::steady_clock::now();
-    search_result = searcher.Search(features.data(), det_result.size());
+    search_result =
+        vector_searcher_ptr->Search(features.data(), det_result.size());
     auto search_end = std::chrono::steady_clock::now();

     // nms for search result
     for (int i = 0; i < det_result.size(); ++i) {
       det_result[i].confidence = search_result.D[search_result.return_k * i];
     }
-    NMSBoxes(det_result, searcher.GetThreshold(), rec_nms_thresold, indeices);
+    NMSBoxes(det_result, vector_searcher_ptr->GetThreshold(), rec_nms_thresold,
+             indeices);
     auto nms_end = std::chrono::steady_clock::now();
     std::chrono::duration<float> search_diff = search_end - search_start;
     search_times[1] += double(search_diff.count() * 1000);
```
```diff
@@ -289,12 +305,12 @@ int main(int argc, char **argv) {

     // print result
     if (not benchmark or (benchmark and idx >= warmup_iter))
-      PrintResult(img_path, det_result, indeices, searcher, search_result);
+      PrintResult(img_path, det_result, indeices, vector_searcher_ptr,
+                  search_result);

     // for postprocess
     batch_imgs.clear();
     img_paths.clear();
-    det_bbox_num.clear();
     det_result.clear();
     feature.clear();
     features.clear();
```
```diff
@@ -322,8 +338,10 @@ int main(int argc, char **argv) {
     std::vector<int> shape =
         config.config_file["Global"]["image_shape"].as<std::vector<int>>();
     std::string det_shape = std::to_string(shape[0]);
-    for (int i = 1; i < shape.size(); ++i)
+    for (int i = 1; i < shape.size(); ++i) {
       det_shape = det_shape + ", " + std::to_string(shape[i]);
+    }

     AutoLogger autolog_det("Det", use_gpu, use_tensorrt, enable_mkldnn,
                            cpu_num_threads, batch_size, det_shape, presion,
```
```diff
@@ -31,6 +31,13 @@ class SystemPredictor(object):

         self.config = config
         self.rec_predictor = RecPredictor(config)
-        self.det_predictor = DetPredictor(config)
+
+        if not config["Global"]["det_inference_model_dir"]:
+            logger.info(
+                f"Found 'Global.det_inference_model_dir' empty({config['Global']['det_inference_model_dir']}), so det_predictor is disabled"
+            )
+            self.det_predictor = None
+        else:
+            self.det_predictor = DetPredictor(config)

         assert 'IndexProcess' in config.keys(), "Index config not found ... "
```
```diff
@@ -92,7 +99,10 @@ class SystemPredictor(object):
     def predict(self, img):
         output = []
         # st1: get all detection results
-        results = self.det_predictor.predict(img)
+        if self.det_predictor:
+            results = self.det_predictor.predict(img)
+        else:
+            results = []

         # st2: add the whole image for recognition to improve recall
         results = self.append_self(results, img.shape)
```
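To make the new control flow concrete, here is a small self-contained sketch (hypothetical stand-in classes, not the real `DetPredictor`/`RecPredictor`): when the detector is disabled, the detection list starts empty and the whole-image box added by `append_self` is the only region sent on to recognition.

```python
# Self-contained sketch of the detector-optional flow above; DummyDetPredictor is a
# hypothetical stand-in for DetPredictor, and the dicts only carry a bbox field.
class DummyDetPredictor:
    def predict(self, img_shape):
        return [{"bbox": [437, 71, 660, 728]}]  # pretend one object was detected

def append_self(results, shape):
    # st2 in the diff: always add the whole image as a candidate to improve recall
    h, w = shape[0], shape[1]
    results.append({"bbox": [0, 0, w, h]})
    return results

def predict(img_shape, det_predictor=None):
    # st1 in the diff: skip detection entirely when det_predictor is None
    results = det_predictor.predict(img_shape) if det_predictor else []
    return append_self(results, img_shape)

print(predict((802, 1200, 3)))                       # whole-image box only
print(predict((802, 1200, 3), DummyDetPredictor()))  # detected box + whole-image box
```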
@@ -198,6 +198,21 @@ The final output is as follows.

```log
[{'bbox': [437, 71, 660, 728], 'rec_docs': '元气森林', 'rec_scores': 0.7740249}, {'bbox': [221, 72, 449, 701], 'rec_docs': '元气森林', 'rec_scores': 0.6950992}, {'bbox': [794, 104, 979, 652], 'rec_docs': '元气森林', 'rec_scores': 0.6305153}]
```

The recognition process can be configured flexibly: users can choose not to use the object detection model and instead feed a single whole image directly into the feature extraction model, computing the feature vector for the subsequent retrieval step, which reduces the overall time cost of the recognition pipeline. This can be done with the commands below.

```shell
# Use the following command to run whole-image prediction on the GPU
python3.7 python/predict_system.py -c configs/inference_general.yaml -o Global.det_inference_model_dir=None

# Use the following command to run whole-image prediction on the CPU
python3.7 python/predict_system.py -c configs/inference_general.yaml -o Global.use_gpu=False -o Global.det_inference_model_dir=None
```

The final output is as follows.

```log
INFO: Found 'Global.det_inference_model_dir' empty(), so det_predictor is disabled
[{'bbox': [0, 0, 1200, 802], 'rec_docs': '元气森林', 'rec_scores': 0.5696486}]
```

#### 4.3.2 Multiple-image prediction

If you want to predict all the images in a folder, you can either modify the `Global.infer_imgs` field in the configuration file directly, or override the corresponding configuration through the `-o` parameter (a rough sketch of what folder prediction amounts to follows below).
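As a rough, hypothetical sketch (the folder path is illustrative; the real loop lives in `predict_system.py`), folder prediction simply collects every image file under the configured directory and runs the same per-image detection plus retrieval pipeline on each one:

```python
# Hypothetical sketch: when Global.infer_imgs points at a folder, every image in it
# is collected and pushed through the same predict() call one by one.
import glob
import os

infer_imgs = "./drink_dataset_v2.0/test_images"  # hypothetical folder path
exts = (".jpg", ".jpeg", ".png", ".bmp")
image_list = sorted(
    p for p in glob.glob(os.path.join(infer_imgs, "*")) if p.lower().endswith(exts)
)
for path in image_list:
    print("would run predict() on:", path)  # each image goes through detection + retrieval
```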
````diff
@@ -348,11 +348,25 @@ cd ..
 ./build/pp_shitu -c inference_drink.yaml
 ```

-Using `drink_dataset_v2.0/test_images/nongfu_spring.jpeg` as the input image, running the above inference command gives the following result
+By default, `../drink_dataset_v2.0/test_images/100.jpeg` is used as the input image; running the above inference command gives the following result

 ```log
-../../deploy/drink_dataset_v2.0/test_images/nongfu_spring.jpeg:
-    result0: bbox[0, 0, 729, 1094], score: 0.688691, label: 农夫山泉-饮用天然水
+../drink_dataset_v2.0/test_images/100.jpeg:
+    result0: bbox[437, 72, 660, 723], score: 0.769916, label: 元气森林
+    result1: bbox[220, 71, 449, 685], score: 0.695485, label: 元气森林
+    result2: bbox[795, 104, 979, 653], score: 0.626963, label: 元气森林
 ```

+The recognition process can be configured flexibly: users can choose not to use the object detection model and instead feed a single whole image directly into the feature extraction model, computing the feature vector for the subsequent retrieval step, which reduces the overall time cost of the recognition pipeline. Simply change the value after `Global.det_inference_model_dir` to `null` or `""`, then run the inference command below
+```shell
+./build/pp_shitu -c inference_drink.yaml
+```
+
+The final output is as follows
+```log
+Found 'Global.det_inference_model_dir' empty, so det_predictor is disabled
+../drink_dataset_v2.0/test_images/100.jpeg:
+    result0: bbox[0, 0, 1199, 801], score: 0.568903, label: 元气森林
+```
+
 Because the OpenCV code paths used by Python and C++ differ slightly, the Python and C++ inference results may show minor differences, but this generally does not affect the final retrieval result.
````
@@ -223,6 +223,20 @@ python3.7 python/predict_system.py -c configs/inference_general.yaml -o Global.u



The recognition process can be configured flexibly: users can choose not to use the object detection model and instead feed a single whole image directly into the feature extraction model, computing the feature vector for the subsequent retrieval step, which reduces the overall time cost of the recognition pipeline. Whole-image recognition can be run directly with the commands below.

```shell
# Use the following command to run whole-image prediction on the GPU
python3.7 python/predict_system.py -c configs/inference_general.yaml -o Global.det_inference_model_dir=None

# Use the following command to run whole-image prediction on the CPU
python3.7 python/predict_system.py -c configs/inference_general.yaml -o Global.use_gpu=False -o Global.det_inference_model_dir=None
```

The final output is as follows.

```log
INFO: Found 'Global.det_inference_model_dir' empty(), so det_predictor is disabled
[{'bbox': [0, 0, 1200, 802], 'rec_docs': '元气森林', 'rec_scores': 0.5696486}]
```

<a name="基于文件夹的批量识别"></a>