fix ppshitu_lite bugs and fix README.md

pull/1652/head
dongshuilong 2022-01-25 19:26:45 +08:00
parent 96e3a20877
commit 267f6508f0
11 changed files with 211 additions and 204 deletions

View File

@ -1,4 +1,4 @@
# Paddle-Lite端侧部署
# PP-ShiTu在Paddle-Lite端侧部署
本教程将介绍基于[Paddle Lite](https://github.com/PaddlePaddle/Paddle-Lite) 在移动端部署PaddleDetection模型的详细步骤。
@ -125,30 +125,71 @@ Paddle-Lite 提供了多种策略来自动优化原始的模型,其中包括
#### 2.1.3 转换示例
下面以PaddleDetection中的 `PicoDet` 模型为例,介绍使用`paddle_lite_opt`完成预训练模型到inference模型再到Paddle-Lite优化模型的转换。
下面介绍使用`paddle_lite_opt`完成主体检测模型和识别模型的预训练模型转成inference模型最终转换成Paddle-Lite的优化模型的过程。
##### 2.1.3.1 转换主体检测模型
```shell
# 当前目录为 $PaddleClas/deploy/lite_shitu
# $code_path需替换成相应的运行目录,可以根据需要,将$code_path设置成需要的目录
export $code_path=~
cd $code_path
git clone https://github.com/PaddlePaddle/PaddleDetection.git
# 进入PaddleDetection根目录
cd PaddleDetection_root_path
cd PaddleDetection
# 将预训练模型导出为inference模型
python tools/export_model.py -c configs/picodet/picodet_s_320_coco.yml \
-o weights=https://paddledet.bj.bcebos.com/models/picodet_s_320_coco.pdparams --output_dir=output_inference
python tools/export_model.py -c configs/picodet/application/mainbody_detection/picodet_lcnet_x2_5_640_mainbody.yml -o weights=https://paddledet.bj.bcebos.com/models/picodet_lcnet_x2_5_640_mainbody.pdparams --output_dir=inference
# 将inference模型转化为Paddle-Lite优化模型
# FP32
paddle_lite_opt --valid_targets=arm --model_file=output_inference/picodet_s_320_coco/model.pdmodel --param_file=output_inference/picodet_s_320_coco/model.pdiparams --optimize_out=output_inference/picodet_s_320_coco/model
# FP16
paddle_lite_opt --valid_targets=arm --model_file=output_inference/picodet_s_320_coco/model.pdmodel --param_file=output_inference/picodet_s_320_coco/model.pdiparams --optimize_out=output_inference/picodet_s_320_coco/model --enable_fp16=true
paddle_lite_opt --model_file=inference/picodet_lcnet_x2_5_640_mainbody/model.pdmodel --param_file=inference/picodet_lcnet_x2_5_640_mainbody/model.pdiparams --optimize_out=inference/picodet_lcnet_x2_5_640_mainbody/mainbody_det
# 将inference模型配置转化为json格式
python deploy/lite/convert_yml_to_json.py output_inference/picodet_s_320_coco/infer_cfg.yml
# 将转好的模型复制到lite_shitu目录下
cd $PaddleClas/deploy/lite_shitu
mkdir models
cp $code_path/PaddleDetection/inference/picodet_lcnet_x2_5_640_mainbody/mainbody_det.nb $PaddleClas/deploy/lite_shitu/models
```
最终在output_inference/picodet_s_320_coco/文件夹下生成`model.nb` 和 `infer_cfg.json`的文件。
##### 2.1.3.2 转换识别模型
```shell
# 转换inference model
待补充,生成的inference model存储在PaddleClas/inference下同时生成label.txt也存在此文件夹下
# 转换为Paddle-Lite模型
paddle_lite_opt --model_file=inference/inference.pdmodel --param_file=inference/inference.pdiparams --optimize_out=inference/rec
# 将模型、label文件拷贝到lite_shitu下
cp inference/rec.nb deploy/lite_shitu/models/
cp inference/label.txt deploy/lite_shitu/models/
cd deploy/lite_shitu
```
**注意**`--optimize_out` 参数为优化后模型的保存路径,无需加后缀`.nb``--model_file` 参数为模型结构信息文件的路径,`--param_file` 参数为模型权重信息文件的路径,请注意文件名。
##### 2.1.3.3 准备测试图像
```shell
mkdir images
# 根据需要准备测试图像可以在images文件夹中存放多张图像
cp ../images/wangzai.jpg images/
```
##### 2.1.3.4 将yaml文件转换成json文件
```shell
# 如果测试单张图像
python generate_json_config.py --det_model_path models/mainbody_det.nb --rec_model_path models/rec.nb --rec_label_path models/label.txt --img_path images/wangzai.jpg
# or
# 如果测试多张图像
python generate_json_config.py --det_model_path models/mainbody_det.nb --rec_model_path models/rec.nb --rec_label_path models/label.txt --img_dir images
# 执行完成后会在lit_shitu下生成shitu_config.json配置文件
```
### 2.2 与手机联调
首先需要进行一些准备工作。
@ -183,41 +224,28 @@ List of devices attached
4. 编译lite部署代码生成移动端可执行文件
```shell
cd {PadddleDetection_Root}
cd deploy/lite/
cd $PaddleClas/deploy/lite_shitu
inference_lite_path=/{lite prediction library path}/inference_lite_lib.android.armv8.gcc.c++_static.with_extra.with_cv/
mkdir $inference_lite_path/demo/cxx/lite
mkdir $inference_lite_path/demo/cxx/ppshitu_lite
cp -r Makefile src/ include/ *runtime_config.json $inference_lite_path/demo/cxx/lite
cp -r Makefile src/ include/ *.json models/ images/ $inference_lite_path/demo/cxx/ppshitu_lite
cd $inference_lite_path/demo/cxx/lite
cd $inference_lite_path/demo/cxx/ppshitu_lite
# 执行编译等待完成后得到可执行文件main
make ARM_ABI=arm8
#如果是arm7则执行 make ARM_ABI = arm7 (或者在Makefile中修改该项)
```
5. 准备优化后的模型、预测库文件、测试图像。
```shell
mkdir deploy
cp main *runtime_config.json deploy/
mv models deploy/
mv images deploy/
cp pp_shitu deploy/
cd deploy
mkdir model_det
mkdir model_keypoint
# 将优化后的模型、预测库文件、测试图像放置在预测库中的demo/cxx/detection文件夹下
cp {PadddleDetection_Root}/output_inference/picodet_s_320_coco/model.nb ./model_det/
cp {PadddleDetection_Root}/output_inference/picodet_s_320_coco/infer_cfg.json ./model_det/
# 如果需要关键点模型,则只需操作:
cp {PadddleDetection_Root}/output_inference/hrnet_w32_256x192/model.nb ./model_keypoint/
cp {PadddleDetection_Root}/output_inference/hrnet_w32_256x192/infer_cfg.json ./model_keypoint/
# 将测试图像复制到deploy文件夹中
cp [your_test_img].jpg ./demo.jpg
# 将C++预测动态库so文件复制到deploy文件夹中
cp ../../../cxx/lib/libpaddle_light_api_shared.so ./
@ -227,45 +255,19 @@ cp ../../../cxx/lib/libpaddle_light_api_shared.so ./
```
deploy/
|-- model_det/
| |--model.nb 优化后的检测模型文件
| |--infer_cfg.json 检测器模型配置文件
|-- model_keypoint/
| |--model.nb 优化后的关键点模型文件
| |--infer_cfg.json 关键点模型配置文件
|-- main 生成的移动端执行文件
|-- det_runtime_config.json 目标检测执行时参数配置文件
|-- keypoint_runtime_config.json 关键点检测执行时参数配置文件
|-- models/
| |--mainbody_det.nb 优化后的主体检测模型文件
| |--rec.nb 优化后的识别模型文件
| |--label.txt 识别模型的label文件
|-- images/
| ... 图片文件
|-- pp_shitu 生成的移动端执行文件
|-- shitu_config.json 执行时参数配置文件
|-- libpaddle_light_api_shared.so Paddle-Lite库文件
```
**注意:**
* `det_runtime_config.json` 包含了目标检测的超参数,请按需进行修改:
```shell
{
"model_dir_det": "./model_det/", #检测器模型路径
"batch_size_det": 1, #检测预测时batchsize
"threshold_det": 0.5, #检测器输出阈值
"image_file": "demo.jpg", #测试图片
"image_dir": "", #测试图片文件夹
"run_benchmark": true, #性能测试开关
"cpu_threads": 4 #线程数
}
```
* `keypoint_runtime_config.json` 包含了关键点检测的超参数,请按需进行修改:
```shell
{
"model_dir_keypoint": "./model_keypoint/", #关键点模型路径(不使用需为空字符)
"batch_size_keypoint": 8, #关键点预测时batchsize
"threshold_keypoint": 0.5, #关键点输出阈值
"image_file": "demo.jpg", #测试图片
"image_dir": "", #测试图片文件夹
"run_benchmark": true, #性能测试开关
"cpu_threads": 4 #线程数
}
```
* `shitu_config.json` 包含了目标检测的超参数,请按需进行修改
6. 启动调试上述步骤完成后就可以使用ADB将文件夹 `deploy/` push到手机上运行步骤如下
@ -278,23 +280,20 @@ cd /data/local/tmp/deploy
export LD_LIBRARY_PATH=/data/local/tmp/deploy:$LD_LIBRARY_PATH
# 修改权限为可执行
chmod 777 main
# 以检测为例,执行程序
./main det_runtime_config.json
chmod 777 pp_shitu
# 执行程序
./pp_shitu shitu_config.json
```
如果对代码做了修改则需要重新编译并push到手机上。
运行效果如下:
<div align="center">
<img src="../../docs/images/lite_demo.jpg" width="600">
</div>
![](../../docs/images/ppshitu_lite_demo.png)
## FAQ
Q1如果想更换模型怎么办需要重新按照流程走一遍吗
Q1如果想更换模型怎么办需要重新按照流程走一遍吗
A1如果已经走通了上述步骤更换模型只需要替换 `.nb` 模型文件即可,同时要注意修改下配置文件中的 `.nb` 文件路径以及类别映射文件(如有必要)。
Q2换一个图测试怎么做
A2替换 deploy 下的测试图像为你想要测试的图像,使用 ADB 再次 push 到手机上即可。
Q2换一个图测试怎么做
A2替换 deploy 下的测试图像为你想要测试的图像,并重新生成json配置文件或者直接修改图像路径使用 ADB 再次 push 到手机上即可。

View File

@ -7,40 +7,63 @@ import yaml
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--yaml_path',
type=str,
default='../configs/inference_drink.yaml')
parser.add_argument('--img_dir', type=str, default=None,
help='The dir path for inference images')
parser.add_argument('--det_model_path',
type=str, default='./det.nb',
help="The model path for mainbody detection")
parser.add_argument('--rec_model_path', type=str, default='./rec.nb', help="The rec model path")
parser.add_argument('--rec-label-path', type=str, default='./label.txt', help='The rec model label')
parser.add_argument('--arch',
type=str,
default='GFL',
help='The model structure for detection model')
parser.add_argument('--fpn-stride',
type=list,
default=[8, 16, 32, 64],
help="The fpn strid for detection model")
parser.add_argument('--keep_top_k',
type=int,
default=100,
help='The params for nms(postprocess for detection)')
parser.add_argument('--nms-name',
type=str,
default='MultiClassNMS',
help='The nms name for postprocess of detection model')
parser.add_argument('--nms_threshold',
type=float,
default=0.5,
help='The nms nms_threshold for detection postprocess')
parser.add_argument('--nms_top_k',
type=int,
default=1000,
help='The nms_top_k in postprocess of detection model')
parser.add_argument(
'--yaml_path', type=str, default='../configs/inference_drink.yaml')
parser.add_argument(
'--img_dir',
type=str,
default=None,
help='The dir path for inference images')
parser.add_argument(
'--img_path',
type=str,
default=None,
help='The dir path for inference images')
parser.add_argument(
'--det_model_path',
type=str,
default='./det.nb',
help="The model path for mainbody detection")
parser.add_argument(
'--rec_model_path',
type=str,
default='./rec.nb',
help="The rec model path")
parser.add_argument(
'--rec_label_path',
type=str,
default='./label.txt',
help='The rec model label')
parser.add_argument(
'--arch',
type=str,
default='PicoDet',
help='The model structure for detection model')
parser.add_argument(
'--fpn-stride',
type=list,
default=[8, 16, 32, 64],
help="The fpn strid for detection model")
parser.add_argument(
'--keep_top_k',
type=int,
default=100,
help='The params for nms(postprocess for detection)')
parser.add_argument(
'--nms-name',
type=str,
default='MultiClassNMS',
help='The nms name for postprocess of detection model')
parser.add_argument(
'--nms_threshold',
type=float,
default=0.5,
help='The nms nms_threshold for detection postprocess')
parser.add_argument(
'--nms_top_k',
type=int,
default=1000,
help='The nms_top_k in postprocess of detection model')
parser.add_argument(
'--score_threshold',
type=float,
@ -55,18 +78,24 @@ def main():
config_yaml = yaml.safe_load(open(args.yaml_path))
config_json = {}
config_json["Global"] = {}
config_json["Global"]["infer_imgs"] = config_yaml["Global"]["infer_imgs"]
config_json["Global"]["infer_imgs_dir"] = args.img_dir
config_json["Global"][
"infer_imgs"] = args.img_path if args.img_path else config_yaml[
"Global"]["infer_imgs"]
if args.img_dir is not None:
config_json["Global"]["infer_imgs_dir"] = args.img_dir
config_json["Global"]["infer_imgs"] = None
else:
config_json["Global"][
"infer_imgs"] = args.img_path if args.img_path else config_yaml[
"Global"]["infer_imgs"]
config_json["Global"]["batch_size"] = config_yaml["Global"]["batch_size"]
config_json["Global"]["cpu_num_threads"] = config_yaml["Global"][
"cpu_num_threads"]
config_json["Global"]["cpu_num_threads"] = min(
config_yaml["Global"]["cpu_num_threads"], 4)
config_json["Global"]["image_shape"] = config_yaml["Global"]["image_shape"]
config_json["Global"][
"det_model_path"] = args.det_model_path
config_json["Global"][
"rec_model_path"] = args.rec_model_path
config_json["Global"]["det_model_path"] = args.det_model_path
config_json["Global"]["rec_model_path"] = args.rec_model_path
config_json["Global"]["rec_label_path"] = args.rec_label_path
config_json["Global"]["labe_list"] = config_yaml["Global"]["labe_list"]
config_json["Global"]["label_list"] = config_yaml["Global"]["labe_list"]
config_json["Global"]["rec_nms_thresold"] = config_yaml["Global"][
"rec_nms_thresold"]
config_json["Global"]["max_det_results"] = config_yaml["Global"][

View File

@ -60,7 +60,7 @@ class ConfigPaser {
// Get label_list for visualization
if (config["Global"].isMember("label_list")) {
label_list_.clear();
for (auto item : config["label_list"]) {
for (auto item : config["Global"]["label_list"]) {
label_list_.emplace_back(item.as<std::string>());
}
} else {

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle_api.h" // NOLINT
#include "json/json.h"
#include <arm_neon.h>
@ -54,8 +55,10 @@ public:
}
LoadLabel(config_file["Global"]["rec_label_path"].as<std::string>());
SetPreProcessParam(config_file["RecPreProcess"]["transform_ops"]);
if (!config_file["Global"].isMember("return_k"))
if (!config_file["Global"].isMember("return_k")){
this->topk = config_file["Global"]["return_k"].as<int>();
}
printf("rec model create!\n");
}
void LoadLabel(std::string path) {

View File

@ -14,13 +14,14 @@
#pragma once
#include <string>
#include <vector>
#include <memory>
#include <utility>
#include <ctime>
#include <numeric>
#include <algorithm>
#include <ctime>
#include <include/recognition.h>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
namespace PPShiTu {
@ -32,8 +33,11 @@ struct ObjectResult {
int class_id;
// Confidence of detected object
float confidence;
// RecModel result
std::vector<RESULT> rec_result;
};
void nms(std::vector<ObjectResult> &input_boxes, float nms_threshold);
void nms(std::vector<ObjectResult> &input_boxes, float nms_threshold, bool rec_nms=false);
} // namespace PPShiTu
} // namespace PPShiTu

View File

@ -73,17 +73,10 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
std::vector<double> det_t = {0, 0, 0};
int steps = ceil(float(batch_imgs.size()) / batch_size_det);
for (int idx = 0; idx < steps; idx++) {
std::vector<cv::Mat> batch_imgs;
int left_image_cnt = batch_imgs.size() - idx * batch_size_det;
if (left_image_cnt > batch_size_det) {
left_image_cnt = batch_size_det;
}
/* for (int bs = 0; bs < left_image_cnt; bs++) { */
/* std::string image_file_path = all_img_paths.at(idx * batch_size_det +
* bs); */
/* cv::Mat im = cv::imread(image_file_path, 1); */
/* batch_imgs.insert(batch_imgs.end(), im); */
/* } */
// Store all detected result
std::vector<PPShiTu::ObjectResult> result;
std::vector<int> bbox_num;
@ -108,32 +101,7 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
}
detect_num += 1;
im_result.push_back(item);
/* if (item.rect.size() > 6) { */
/* is_rbox = true; */
/* printf("class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
*/
/* item.class_id, */
/* item.confidence, */
/* item.rect[0], */
/* item.rect[1], */
/* item.rect[2], */
/* item.rect[3], */
/* item.rect[4], */
/* item.rect[5], */
/* item.rect[6], */
/* item.rect[7]); */
/* } else { */
/* printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n", */
/* item.class_id, */
/* item.confidence, */
/* item.rect[0], */
/* item.rect[1], */
/* item.rect[2], */
/* item.rect[3]); */
/* } */
}
/* std::cout << all_img_paths.at(idx * batch_size_det + i) */
/* << " The number of detected box: " << detect_num << std::endl; */
item_start_idx = item_start_idx + bbox_num[i];
}
@ -144,14 +112,13 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
}
void PrintResult(const std::string &image_path,
std::vector<PPShiTu::ObjectResult> &det_result,
std::vector<std::vector<PPShiTu::RESULT>> &rec_results) {
std::vector<PPShiTu::ObjectResult> &det_result) {
printf("%s:\n", image_path.c_str());
for (int i = 0; i < det_result.size(); ++i) {
printf("\tresult%d: bbox[%d, %d, %d, %d], score: %f, label: %s\n", i,
det_result[i].rect[0], det_result[i].rect[1], det_result[i].rect[2],
det_result[i].rect[3], rec_results[i][0].score,
rec_results[i][0].class_name.c_str());
det_result[i].rect[3], det_result[i].rec_result[0].score,
det_result[i].rec_result[0].class_name.c_str());
}
}
@ -163,37 +130,33 @@ int main(int argc, char **argv) {
return -1;
}
std::string config_path = argv[1];
std::string img_path = "";
std::string img_dir = "";
if (argc >= 3) {
img_path = argv[2];
img_dir = argv[2];
}
// Parsing command-line
PPShiTu::load_jsonf(config_path, RT_Config);
if (RT_Config["Global"]["det_inference_model_dir"]
.as<std::string>()
.empty()) {
std::cout << "Please set [det_inference_model_dir] in " << config_path
<< std::endl;
if (RT_Config["Global"]["det_model_path"].as<std::string>().empty()) {
std::cout << "Please set [det_model_path] in " << config_path << std::endl;
return -1;
}
if (RT_Config["Global"]["infer_imgs"].as<std::string>().empty() &&
img_path.empty()) {
img_dir.empty()) {
std::cout << "Please set [infer_imgs] in " << config_path
<< " Or use command: <" << argv[0] << " [shitu_config]"
<< " [image_dir]>" << std::endl;
return -1;
}
if (!img_path.empty()) {
if (!img_dir.empty()) {
std::cout << "Use image_dir in command line overide the path in config file"
<< std::endl;
RT_Config["Global"]["infer_imgs_dir"] = img_path;
RT_Config["Global"]["infer_imgs_dir"] = img_dir;
RT_Config["Global"]["infer_imgs"] = "";
}
// Load model and create a object detector
PPShiTu::ObjectDetector det(
RT_Config,
RT_Config["Global"]["det_inference_model_dir"].as<std::string>(),
RT_Config, RT_Config["Global"]["det_model_path"].as<std::string>(),
RT_Config["Global"]["cpu_num_threads"].as<int>(),
RT_Config["Global"]["batch_size"].as<int>());
// create rec model
@ -202,7 +165,6 @@ int main(int argc, char **argv) {
std::vector<PPShiTu::ObjectResult> det_result;
std::vector<cv::Mat> batch_imgs;
std::vector<std::vector<PPShiTu::RESULT>> rec_results;
double rec_time;
if (!RT_Config["Global"]["infer_imgs"].as<std::string>().empty() ||
!RT_Config["Global"]["infer_imgs_dir"].as<std::string>().empty()) {
@ -239,7 +201,7 @@ int main(int argc, char **argv) {
// add the whole image for recognition to improve recall
PPShiTu::ObjectResult result_whole_img = {
{0, 0, srcimg.cols - 1, srcimg.rows - 1}, 0, 1.0};
{0, 0, srcimg.cols, srcimg.rows}, 0, 1.0};
det_result.push_back(result_whole_img);
// get rec result
@ -250,13 +212,14 @@ int main(int argc, char **argv) {
cv::Mat crop_img = srcimg(rect);
std::vector<PPShiTu::RESULT> result =
rec.RunRecModel(crop_img, rec_time);
rec_results.push_back(result);
det_result[j].rec_result.assign(result.begin(), result.end());
}
PrintResult(img_path, det_result, rec_results);
// rec nms
PPShiTu::nms(det_result,
RT_Config["Global"]["rec_nms_thresold"].as<float>(), true);
PrintResult(img_path, det_result);
batch_imgs.clear();
det_result.clear();
rec_results.clear();
}
}
return 0;

View File

@ -95,7 +95,7 @@ cv::Mat VisualizeResult(const cv::Mat& img,
void ObjectDetector::Preprocess(const cv::Mat& ori_im) {
// Clone the image : keep the original mat for postprocess
cv::Mat im = ori_im.clone();
cv::cvtColor(im, im, cv::COLOR_BGR2RGB);
// cv::cvtColor(im, im, cv::COLOR_BGR2RGB);
preprocessor_.Run(&im, &inputs_);
}
@ -235,7 +235,7 @@ void ObjectDetector::Predict(const std::vector<cv::Mat>& imgs,
auto postprocess_start = std::chrono::steady_clock::now();
// Get output tensor
output_data_list_.clear();
int num_class = 80;
int num_class = 1;
int reg_max = 7;
auto output_names = predictor_->GetOutputNames();
// TODO: Unified model output.
@ -281,6 +281,7 @@ void ObjectDetector::Predict(const std::vector<cv::Mat>& imgs,
out_bbox_num_data_.data());
}
// Postprocessing result
result->clear();
if (config_.arch_ == "PicoDet") {
PPShiTu::PicoDetPostProcess(

View File

@ -127,11 +127,11 @@ void TopDownEvalAffine::Run(cv::Mat* im, ImageBlob* data) {
// Preprocessor op running order
const std::vector<std::string> Preprocessor::RUN_ORDER = {"InitInfo",
"TopDownEvalAffine",
"Resize",
"NormalizeImage",
"PadStride",
"Permute"};
"DetTopDownEvalAffine",
"DetResize",
"DetNormalizeImage",
"DetPadStride",
"DetPermute"};
void Preprocessor::Run(cv::Mat* im, ImageBlob* data) {
for (const auto& name : RUN_ORDER) {

View File

@ -38,7 +38,7 @@ std::vector<RESULT> Recognition::RunRecModel(const cv::Mat &img,
// Get output and post process
std::unique_ptr<const Tensor> output_tensor(
std::move(this->predictor->GetOutput(0)));
std::move(this->predictor->GetOutput(1)));
auto *output_data = output_tensor->data<float>();
auto end = std::chrono::system_clock::now();
auto duration =

View File

@ -16,14 +16,23 @@
namespace PPShiTu {
void nms(std::vector<ObjectResult> &input_boxes, float nms_threshold) {
std::sort(input_boxes.begin(),
input_boxes.end(),
[](ObjectResult a, ObjectResult b) { return a.confidence > b.confidence; });
void nms(std::vector<ObjectResult> &input_boxes, float nms_threshold,
bool rec_nms) {
if (!rec_nms) {
std::sort(input_boxes.begin(), input_boxes.end(),
[](ObjectResult a, ObjectResult b) {
return a.confidence > b.confidence;
});
} else {
std::sort(input_boxes.begin(), input_boxes.end(),
[](ObjectResult a, ObjectResult b) {
return a.rec_result[0].score > b.rec_result[0].score;
});
}
std::vector<float> vArea(input_boxes.size());
for (int i = 0; i < int(input_boxes.size()); ++i) {
vArea[i] = (input_boxes.at(i).rect[2] - input_boxes.at(i).rect[0] + 1)
* (input_boxes.at(i).rect[3] - input_boxes.at(i).rect[1] + 1);
vArea[i] = (input_boxes.at(i).rect[2] - input_boxes.at(i).rect[0] + 1) *
(input_boxes.at(i).rect[3] - input_boxes.at(i).rect[1] + 1);
}
for (int i = 0; i < int(input_boxes.size()); ++i) {
for (int j = i + 1; j < int(input_boxes.size());) {
@ -36,14 +45,13 @@ void nms(std::vector<ObjectResult> &input_boxes, float nms_threshold) {
float inter = w * h;
float ovr = inter / (vArea[i] + vArea[j] - inter);
if (ovr >= nms_threshold) {
input_boxes.erase(input_boxes.begin() + j);
vArea.erase(vArea.begin() + j);
}
else {
j++;
input_boxes.erase(input_boxes.begin() + j);
vArea.erase(vArea.begin() + j);
} else {
j++;
}
}
}
}
} // namespace PPShiTu
} // namespace PPShiTu

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB