diff --git a/deploy/lite_shitu/README.md b/deploy/lite_shitu/README.md
index e775b134a..5d6dde45d 100644
--- a/deploy/lite_shitu/README.md
+++ b/deploy/lite_shitu/README.md
@@ -1,4 +1,4 @@
-# Deploying on the edge with Paddle-Lite
+# Deploying PP-ShiTu on the edge with Paddle-Lite
This tutorial walks through the detailed steps of deploying PP-ShiTu models on mobile devices with [Paddle Lite](https://github.com/PaddlePaddle/Paddle-Lite).
@@ -125,30 +125,71 @@ Paddle-Lite provides a variety of strategies to automatically optimize the original model, including
#### 2.1.3 Conversion example
-Below, taking the `PicoDet` model from PaddleDetection as an example, we describe how to use `paddle_lite_opt` to convert a pretrained model into an inference model, and then into a Paddle-Lite optimized model.
+The following describes how to use `paddle_lite_opt` to convert the pretrained mainbody detection and recognition models into inference models, and finally into Paddle-Lite optimized models.
+
+##### 2.1.3.1 Convert the mainbody detection model
```shell
+# the current directory is $PaddleClas/deploy/lite_shitu
+# replace $code_path with your working directory; set it to any directory you need
+export code_path=~
+cd $code_path
+git clone https://github.com/PaddlePaddle/PaddleDetection.git
# enter the PaddleDetection root directory
-cd PaddleDetection_root_path
+cd PaddleDetection
# export the pretrained model as an inference model
-python tools/export_model.py -c configs/picodet/picodet_s_320_coco.yml \
- -o weights=https://paddledet.bj.bcebos.com/models/picodet_s_320_coco.pdparams --output_dir=output_inference
+python tools/export_model.py -c configs/picodet/application/mainbody_detection/picodet_lcnet_x2_5_640_mainbody.yml -o weights=https://paddledet.bj.bcebos.com/models/picodet_lcnet_x2_5_640_mainbody.pdparams --output_dir=inference
# convert the inference model into a Paddle-Lite optimized model
-# FP32
-paddle_lite_opt --valid_targets=arm --model_file=output_inference/picodet_s_320_coco/model.pdmodel --param_file=output_inference/picodet_s_320_coco/model.pdiparams --optimize_out=output_inference/picodet_s_320_coco/model
-# FP16
-paddle_lite_opt --valid_targets=arm --model_file=output_inference/picodet_s_320_coco/model.pdmodel --param_file=output_inference/picodet_s_320_coco/model.pdiparams --optimize_out=output_inference/picodet_s_320_coco/model --enable_fp16=true
+paddle_lite_opt --model_file=inference/picodet_lcnet_x2_5_640_mainbody/model.pdmodel --param_file=inference/picodet_lcnet_x2_5_640_mainbody/model.pdiparams --optimize_out=inference/picodet_lcnet_x2_5_640_mainbody/mainbody_det
-# 将inference模型配置转化为json格式
-python deploy/lite/convert_yml_to_json.py output_inference/picodet_s_320_coco/infer_cfg.yml
+# copy the converted model into the lite_shitu directory
+cd $PaddleClas/deploy/lite_shitu
+mkdir models
+cp $code_path/PaddleDetection/inference/picodet_lcnet_x2_5_640_mainbody/mainbody_det.nb $PaddleClas/deploy/lite_shitu/models
```
-Finally, `model.nb` and `infer_cfg.json` files are generated under the output_inference/picodet_s_320_coco/ folder.
+##### 2.1.3.2 Convert the recognition model
+
+```shell
+# export the inference model
+# (to be added) the generated inference model will be stored under PaddleClas/inference, and a label.txt will also be generated in that folder
+
+# convert it into a Paddle-Lite model
+paddle_lite_opt --model_file=inference/inference.pdmodel --param_file=inference/inference.pdiparams --optimize_out=inference/rec
+
+# copy the model and label file into lite_shitu
+cp inference/rec.nb deploy/lite_shitu/models/
+cp inference/label.txt deploy/lite_shitu/models/
+cd deploy/lite_shitu
+```
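+
+Until the export command above is filled in, the following sketch shows what a typical PaddleClas export call looks like (an assumption based on the general recognition model; the config path and pretrained weights below are placeholders to adapt to your own model):
+
+```shell
+# hypothetical example; replace the config and weights with those of your recognition model
+python tools/export_model.py \
+    -c ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5.yaml \
+    -o Global.pretrained_model=./general_PPLCNet_x2_5_pretrained_v1.0 \
+    -o Global.save_inference_dir=./inference
+```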
**Note**: the `--optimize_out` parameter is the save path of the optimized model and should not include the `.nb` suffix; `--model_file` is the path of the model structure file and `--param_file` is the path of the model weights file; pay attention to the file names.
+##### 2.1.3.3 Prepare test images
+
+```shell
+mkdir images
+# prepare test images as needed; multiple images can be placed in the images folder
+cp ../images/wangzai.jpg images/
+```
+
+##### 2.1.3.4 Convert the yaml file into a json file
+
+```shell
+# if testing a single image
+python generate_json_config.py --det_model_path models/mainbody_det.nb --rec_model_path models/rec.nb --rec_label_path models/label.txt --img_path images/wangzai.jpg
+# or
+# if testing multiple images
+python generate_json_config.py --det_model_path models/mainbody_det.nb --rec_model_path models/rec.nb --rec_label_path models/label.txt --img_dir images
+
+# after it finishes, a shitu_config.json config file is generated under lite_shitu
+```
+
### 2.2 Debugging on the phone
Some preparation is needed first.
@@ -183,41 +224,28 @@ List of devices attached
4. Build the lite deployment code to generate the mobile executable
```shell
-cd {PadddleDetection_Root}
-cd deploy/lite/
+cd $PaddleClas/deploy/lite_shitu
inference_lite_path=/{lite prediction library path}/inference_lite_lib.android.armv8.gcc.c++_static.with_extra.with_cv/
-mkdir $inference_lite_path/demo/cxx/lite
+mkdir $inference_lite_path/demo/cxx/ppshitu_lite
-cp -r Makefile src/ include/ *runtime_config.json $inference_lite_path/demo/cxx/lite
+cp -r Makefile src/ include/ *.json models/ images/ $inference_lite_path/demo/cxx/ppshitu_lite
-cd $inference_lite_path/demo/cxx/lite
+cd $inference_lite_path/demo/cxx/ppshitu_lite
# run the build; after it finishes, the executable pp_shitu is generated
make ARM_ABI=arm8
# for arm7, run make ARM_ABI=arm7 (or change this option in the Makefile)
-
```
5. Prepare the optimized models, the prediction library files, and the test images.
```shell
mkdir deploy
-cp main *runtime_config.json deploy/
+mv models deploy/
+mv images deploy/
+cp pp_shitu deploy/
cd deploy
-mkdir model_det
-mkdir model_keypoint
-
-# place the optimized models, prediction library files, and test images under the demo/cxx/detection folder of the prediction library
-cp {PadddleDetection_Root}/output_inference/picodet_s_320_coco/model.nb ./model_det/
-cp {PadddleDetection_Root}/output_inference/picodet_s_320_coco/infer_cfg.json ./model_det/
-
-# if the keypoint model is needed, just run:
-cp {PadddleDetection_Root}/output_inference/hrnet_w32_256x192/model.nb ./model_keypoint/
-cp {PadddleDetection_Root}/output_inference/hrnet_w32_256x192/infer_cfg.json ./model_keypoint/
-
-# copy the test image into the deploy folder
-cp [your_test_img].jpg ./demo.jpg
# copy the C++ prediction shared library (.so) into the deploy folder
cp ../../../cxx/lib/libpaddle_light_api_shared.so ./
@@ -227,45 +255,19 @@ cp ../../../cxx/lib/libpaddle_light_api_shared.so ./
```
deploy/
-|-- model_det/
-|   |--model.nb                  optimized detection model file
-|   |--infer_cfg.json            detector model config file
-|-- model_keypoint/
-|   |--model.nb                  optimized keypoint model file
-|   |--infer_cfg.json            keypoint model config file
-|-- main                         generated mobile executable
-|-- det_runtime_config.json      runtime parameter config file for detection
-|-- keypoint_runtime_config.json runtime parameter config file for keypoint detection
+|-- models/
+|   |--mainbody_det.nb           optimized mainbody detection model file
+|   |--rec.nb                    optimized recognition model file
+|   |--label.txt                 label file for the recognition model
+|-- images/
+|   ...                          image files
+|-- pp_shitu                     generated mobile executable
+|-- shitu_config.json            runtime parameter config file
|-- libpaddle_light_api_shared.so  Paddle-Lite library file
```
**Note:**
-* `det_runtime_config.json` contains the hyperparameters for object detection; modify them as needed:
-
-```shell
-{
- "model_dir_det": "./model_det/", #检测器模型路径
- "batch_size_det": 1, #检测预测时batchsize
- "threshold_det": 0.5, #检测器输出阈值
- "image_file": "demo.jpg", #测试图片
- "image_dir": "", #测试图片文件夹
- "run_benchmark": true, #性能测试开关
- "cpu_threads": 4 #线程数
-}
-```
-
-* `keypoint_runtime_config.json` contains the hyperparameters for keypoint detection; modify them as needed:
-```shell
-{
- "model_dir_keypoint": "./model_keypoint/", #关键点模型路径(不使用需为空字符)
- "batch_size_keypoint": 8, #关键点预测时batchsize
- "threshold_keypoint": 0.5, #关键点输出阈值
- "image_file": "demo.jpg", #测试图片
- "image_dir": "", #测试图片文件夹
- "run_benchmark": true, #性能测试开关
- "cpu_threads": 4 #线程数
-}
-```
+* `shitu_config.json` contains the runtime hyperparameters for detection and recognition; modify them as needed.
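+
+For reference, a generated `shitu_config.json` roughly follows the structure sketched below (illustrative only, based on `generate_json_config.py`; the exact keys and values come from your yaml config and command-line arguments):
+
+```shell
+{
+  "Global": {
+    "infer_imgs": "images/wangzai.jpg",          # test image (null when an image dir is used)
+    "det_model_path": "models/mainbody_det.nb",  # mainbody detection model path
+    "rec_model_path": "models/rec.nb",           # recognition model path
+    "rec_label_path": "models/label.txt",        # recognition label file
+    "batch_size": 1,                             # batch size for prediction
+    "cpu_num_threads": 4                         # number of threads (capped at 4)
+  }
+}
+```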
6. Start debugging. After the steps above, use ADB to push the `deploy/` folder to the phone and run it; the steps are as follows:
@@ -278,23 +280,20 @@ cd /data/local/tmp/deploy
export LD_LIBRARY_PATH=/data/local/tmp/deploy:$LD_LIBRARY_PATH
# make it executable
-chmod 777 main
-# run the program, taking detection as an example
-./main det_runtime_config.json
+chmod 777 pp_shitu
+# run the program
+./pp_shitu shitu_config.json
```
If you modify the code, you need to rebuild and push it to the phone again.
The running result is as follows:
-
-

-
-
+![](../../docs/images/ppshitu_lite_demo.png)
## FAQ
-Q1: If I want to switch to another model, do I need to go through the whole process again?
+Q1: If I want to switch to another model, do I need to go through the whole process again?
A1: If you have already walked through the steps above, switching models only requires replacing the `.nb` model file; remember to update the `.nb` file paths and the class mapping file (if necessary) in the config file.
-Q2: How do I test with a different image?
-A2: Replace the test image under deploy with the image you want to test, and push it to the phone again with ADB.
+Q2: How do I test with a different image?
+A2: Replace the test image under deploy with the image you want to test, regenerate the json config file (or directly modify its image path), then push it to the phone again with ADB.
diff --git a/deploy/lite_shitu/generate_json_config.py b/deploy/lite_shitu/generate_json_config.py
index fb54685db..1525cdab9 100644
--- a/deploy/lite_shitu/generate_json_config.py
+++ b/deploy/lite_shitu/generate_json_config.py
@@ -7,40 +7,63 @@ import yaml
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument('--yaml_path',
- type=str,
- default='../configs/inference_drink.yaml')
- parser.add_argument('--img_dir', type=str, default=None,
- help='The dir path for inference images')
- parser.add_argument('--det_model_path',
- type=str, default='./det.nb',
- help="The model path for mainbody detection")
- parser.add_argument('--rec_model_path', type=str, default='./rec.nb', help="The rec model path")
- parser.add_argument('--rec-label-path', type=str, default='./label.txt', help='The rec model label')
- parser.add_argument('--arch',
- type=str,
- default='GFL',
- help='The model structure for detection model')
- parser.add_argument('--fpn-stride',
- type=list,
- default=[8, 16, 32, 64],
- help="The fpn strid for detection model")
- parser.add_argument('--keep_top_k',
- type=int,
- default=100,
- help='The params for nms(postprocess for detection)')
- parser.add_argument('--nms-name',
- type=str,
- default='MultiClassNMS',
- help='The nms name for postprocess of detection model')
- parser.add_argument('--nms_threshold',
- type=float,
- default=0.5,
- help='The nms nms_threshold for detection postprocess')
- parser.add_argument('--nms_top_k',
- type=int,
- default=1000,
- help='The nms_top_k in postprocess of detection model')
+ parser.add_argument(
+ '--yaml_path', type=str, default='../configs/inference_drink.yaml')
+ parser.add_argument(
+ '--img_dir',
+ type=str,
+ default=None,
+ help='The dir path for inference images')
+ parser.add_argument(
+ '--img_path',
+ type=str,
+ default=None,
+        help='The image path for inference')
+ parser.add_argument(
+ '--det_model_path',
+ type=str,
+ default='./det.nb',
+ help="The model path for mainbody detection")
+ parser.add_argument(
+ '--rec_model_path',
+ type=str,
+ default='./rec.nb',
+ help="The rec model path")
+ parser.add_argument(
+ '--rec_label_path',
+ type=str,
+ default='./label.txt',
+ help='The rec model label')
+ parser.add_argument(
+ '--arch',
+ type=str,
+ default='PicoDet',
+ help='The model structure for detection model')
+    parser.add_argument(
+        '--fpn-stride',
+        nargs='+',
+        type=int,
+        default=[8, 16, 32, 64],
+        help="The fpn stride for detection model")
+ parser.add_argument(
+ '--keep_top_k',
+ type=int,
+ default=100,
+ help='The params for nms(postprocess for detection)')
+ parser.add_argument(
+ '--nms-name',
+ type=str,
+ default='MultiClassNMS',
+ help='The nms name for postprocess of detection model')
+ parser.add_argument(
+ '--nms_threshold',
+ type=float,
+ default=0.5,
+        help='The NMS threshold for detection postprocess')
+ parser.add_argument(
+ '--nms_top_k',
+ type=int,
+ default=1000,
+ help='The nms_top_k in postprocess of detection model')
parser.add_argument(
'--score_threshold',
type=float,
@@ -55,18 +78,24 @@ def main():
config_yaml = yaml.safe_load(open(args.yaml_path))
config_json = {}
config_json["Global"] = {}
- config_json["Global"]["infer_imgs"] = config_yaml["Global"]["infer_imgs"]
- config_json["Global"]["infer_imgs_dir"] = args.img_dir
+    if args.img_dir is not None:
+        config_json["Global"]["infer_imgs_dir"] = args.img_dir
+        config_json["Global"]["infer_imgs"] = None
+    else:
+        config_json["Global"][
+            "infer_imgs"] = args.img_path if args.img_path else config_yaml[
+                "Global"]["infer_imgs"]
config_json["Global"]["batch_size"] = config_yaml["Global"]["batch_size"]
- config_json["Global"]["cpu_num_threads"] = config_yaml["Global"][
- "cpu_num_threads"]
+ config_json["Global"]["cpu_num_threads"] = min(
+ config_yaml["Global"]["cpu_num_threads"], 4)
config_json["Global"]["image_shape"] = config_yaml["Global"]["image_shape"]
- config_json["Global"][
- "det_model_path"] = args.det_model_path
- config_json["Global"][
- "rec_model_path"] = args.rec_model_path
+ config_json["Global"]["det_model_path"] = args.det_model_path
+ config_json["Global"]["rec_model_path"] = args.rec_model_path
config_json["Global"]["rec_label_path"] = args.rec_label_path
- config_json["Global"]["labe_list"] = config_yaml["Global"]["labe_list"]
+ config_json["Global"]["label_list"] = config_yaml["Global"]["labe_list"]
config_json["Global"]["rec_nms_thresold"] = config_yaml["Global"][
"rec_nms_thresold"]
config_json["Global"]["max_det_results"] = config_yaml["Global"][
diff --git a/deploy/lite_shitu/include/config_parser.h b/deploy/lite_shitu/include/config_parser.h
index 1228d1980..dca0e5a68 100644
--- a/deploy/lite_shitu/include/config_parser.h
+++ b/deploy/lite_shitu/include/config_parser.h
@@ -60,7 +60,7 @@ class ConfigPaser {
// Get label_list for visualization
if (config["Global"].isMember("label_list")) {
label_list_.clear();
- for (auto item : config["label_list"]) {
+ for (auto item : config["Global"]["label_list"]) {
        label_list_.emplace_back(item.as<std::string>());
}
} else {
diff --git a/deploy/lite_shitu/include/recognition.h b/deploy/lite_shitu/include/recognition.h
index b7d2a48c2..0c45e946e 100644
--- a/deploy/lite_shitu/include/recognition.h
+++ b/deploy/lite_shitu/include/recognition.h
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#pragma once
#include "paddle_api.h" // NOLINT
#include "json/json.h"
#include
@@ -54,8 +55,10 @@ public:
}
    LoadLabel(config_file["Global"]["rec_label_path"].as<std::string>());
SetPreProcessParam(config_file["RecPreProcess"]["transform_ops"]);
- if (!config_file["Global"].isMember("return_k"))
+    if (config_file["Global"].isMember("return_k")) {
      this->topk = config_file["Global"]["return_k"].as<int>();
+    }
+ printf("rec model create!\n");
}
void LoadLabel(std::string path) {
diff --git a/deploy/lite_shitu/include/utils.h b/deploy/lite_shitu/include/utils.h
index cbda311ec..18a04cf34 100644
--- a/deploy/lite_shitu/include/utils.h
+++ b/deploy/lite_shitu/include/utils.h
@@ -14,13 +14,14 @@
#pragma once
-#include
-#include
-#include
-#include
-#include
-#include
#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
namespace PPShiTu {
@@ -32,8 +33,11 @@ struct ObjectResult {
int class_id;
// Confidence of detected object
float confidence;
+
+ // RecModel result
+  std::vector<RESULT> rec_result;
};
-void nms(std::vector &input_boxes, float nms_threshold);
+void nms(std::vector &input_boxes, float nms_threshold, bool rec_nms=false);
-} // namespace PPShiTu
+} // namespace PPShiTu
diff --git a/deploy/lite_shitu/src/main.cc b/deploy/lite_shitu/src/main.cc
index ae23e7c55..e9c05c72a 100644
--- a/deploy/lite_shitu/src/main.cc
+++ b/deploy/lite_shitu/src/main.cc
@@ -73,17 +73,10 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
  std::vector<double> det_t = {0, 0, 0};
int steps = ceil(float(batch_imgs.size()) / batch_size_det);
for (int idx = 0; idx < steps; idx++) {
-    std::vector<cv::Mat> batch_imgs;
int left_image_cnt = batch_imgs.size() - idx * batch_size_det;
if (left_image_cnt > batch_size_det) {
left_image_cnt = batch_size_det;
}
- /* for (int bs = 0; bs < left_image_cnt; bs++) { */
- /* std::string image_file_path = all_img_paths.at(idx * batch_size_det +
- * bs); */
- /* cv::Mat im = cv::imread(image_file_path, 1); */
- /* batch_imgs.insert(batch_imgs.end(), im); */
- /* } */
// Store all detected result
    std::vector<PPShiTu::ObjectResult> result;
    std::vector<int> bbox_num;
@@ -108,32 +101,7 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
}
detect_num += 1;
im_result.push_back(item);
- /* if (item.rect.size() > 6) { */
- /* is_rbox = true; */
- /* printf("class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
- */
- /* item.class_id, */
- /* item.confidence, */
- /* item.rect[0], */
- /* item.rect[1], */
- /* item.rect[2], */
- /* item.rect[3], */
- /* item.rect[4], */
- /* item.rect[5], */
- /* item.rect[6], */
- /* item.rect[7]); */
- /* } else { */
- /* printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n", */
- /* item.class_id, */
- /* item.confidence, */
- /* item.rect[0], */
- /* item.rect[1], */
- /* item.rect[2], */
- /* item.rect[3]); */
- /* } */
}
- /* std::cout << all_img_paths.at(idx * batch_size_det + i) */
- /* << " The number of detected box: " << detect_num << std::endl; */
item_start_idx = item_start_idx + bbox_num[i];
}
@@ -144,14 +112,13 @@ void DetPredictImage(const std::vector<cv::Mat> &batch_imgs,
}
void PrintResult(const std::string &image_path,
- std::vector &det_result,
- std::vector> &rec_results) {
+ std::vector &det_result) {
printf("%s:\n", image_path.c_str());
for (int i = 0; i < det_result.size(); ++i) {
printf("\tresult%d: bbox[%d, %d, %d, %d], score: %f, label: %s\n", i,
det_result[i].rect[0], det_result[i].rect[1], det_result[i].rect[2],
- det_result[i].rect[3], rec_results[i][0].score,
- rec_results[i][0].class_name.c_str());
+ det_result[i].rect[3], det_result[i].rec_result[0].score,
+ det_result[i].rec_result[0].class_name.c_str());
}
}
@@ -163,37 +130,33 @@ int main(int argc, char **argv) {
return -1;
}
std::string config_path = argv[1];
- std::string img_path = "";
+ std::string img_dir = "";
if (argc >= 3) {
- img_path = argv[2];
+ img_dir = argv[2];
}
// Parsing command-line
PPShiTu::load_jsonf(config_path, RT_Config);
- if (RT_Config["Global"]["det_inference_model_dir"]
-          .as<std::string>()
- .empty()) {
- std::cout << "Please set [det_inference_model_dir] in " << config_path
- << std::endl;
+ if (RT_Config["Global"]["det_model_path"].as().empty()) {
+ std::cout << "Please set [det_model_path] in " << config_path << std::endl;
return -1;
}
if (RT_Config["Global"]["infer_imgs"].as().empty() &&
- img_path.empty()) {
+ img_dir.empty()) {
std::cout << "Please set [infer_imgs] in " << config_path
<< " Or use command: <" << argv[0] << " [shitu_config]"
<< " [image_dir]>" << std::endl;
return -1;
}
- if (!img_path.empty()) {
+ if (!img_dir.empty()) {
std::cout << "Use image_dir in command line overide the path in config file"
<< std::endl;
- RT_Config["Global"]["infer_imgs_dir"] = img_path;
+ RT_Config["Global"]["infer_imgs_dir"] = img_dir;
RT_Config["Global"]["infer_imgs"] = "";
}
// Load model and create a object detector
PPShiTu::ObjectDetector det(
- RT_Config,
- RT_Config["Global"]["det_inference_model_dir"].as(),
+ RT_Config, RT_Config["Global"]["det_model_path"].as(),
RT_Config["Global"]["cpu_num_threads"].as(),
RT_Config["Global"]["batch_size"].as());
// create rec model
@@ -202,7 +165,6 @@ int main(int argc, char **argv) {
  std::vector<PPShiTu::ObjectResult> det_result;
  std::vector<cv::Mat> batch_imgs;
-  std::vector<std::vector<PPShiTu::RESULT>> rec_results;
double rec_time;
if (!RT_Config["Global"]["infer_imgs"].as().empty() ||
!RT_Config["Global"]["infer_imgs_dir"].as().empty()) {
@@ -239,7 +201,7 @@ int main(int argc, char **argv) {
// add the whole image for recognition to improve recall
PPShiTu::ObjectResult result_whole_img = {
- {0, 0, srcimg.cols - 1, srcimg.rows - 1}, 0, 1.0};
+ {0, 0, srcimg.cols, srcimg.rows}, 0, 1.0};
det_result.push_back(result_whole_img);
// get rec result
@@ -250,13 +212,14 @@ int main(int argc, char **argv) {
cv::Mat crop_img = srcimg(rect);
      std::vector<PPShiTu::RESULT> result =
rec.RunRecModel(crop_img, rec_time);
- rec_results.push_back(result);
+ det_result[j].rec_result.assign(result.begin(), result.end());
}
- PrintResult(img_path, det_result, rec_results);
-
+ // rec nms
+ PPShiTu::nms(det_result,
+ RT_Config["Global"]["rec_nms_thresold"].as(), true);
+ PrintResult(img_path, det_result);
batch_imgs.clear();
det_result.clear();
- rec_results.clear();
}
}
return 0;
diff --git a/deploy/lite_shitu/src/object_detector.cc b/deploy/lite_shitu/src/object_detector.cc
index 59ccba3af..ffea31bb9 100644
--- a/deploy/lite_shitu/src/object_detector.cc
+++ b/deploy/lite_shitu/src/object_detector.cc
@@ -95,7 +95,7 @@ cv::Mat VisualizeResult(const cv::Mat& img,
void ObjectDetector::Preprocess(const cv::Mat& ori_im) {
// Clone the image : keep the original mat for postprocess
cv::Mat im = ori_im.clone();
- cv::cvtColor(im, im, cv::COLOR_BGR2RGB);
+ // cv::cvtColor(im, im, cv::COLOR_BGR2RGB);
preprocessor_.Run(&im, &inputs_);
}
@@ -235,7 +235,7 @@ void ObjectDetector::Predict(const std::vector<cv::Mat>& imgs,
auto postprocess_start = std::chrono::steady_clock::now();
// Get output tensor
output_data_list_.clear();
- int num_class = 80;
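+  // the mainbody detector predicts a single foreground class (vs. the 80 COCO classes)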
+ int num_class = 1;
int reg_max = 7;
auto output_names = predictor_->GetOutputNames();
// TODO: Unified model output.
@@ -281,6 +281,7 @@ void ObjectDetector::Predict(const std::vector<cv::Mat>& imgs,
out_bbox_num_data_.data());
}
// Postprocessing result
+
result->clear();
if (config_.arch_ == "PicoDet") {
PPShiTu::PicoDetPostProcess(
diff --git a/deploy/lite_shitu/src/preprocess_op.cc b/deploy/lite_shitu/src/preprocess_op.cc
index 90a54a1c6..9c74d6ee7 100644
--- a/deploy/lite_shitu/src/preprocess_op.cc
+++ b/deploy/lite_shitu/src/preprocess_op.cc
@@ -127,11 +127,11 @@ void TopDownEvalAffine::Run(cv::Mat* im, ImageBlob* data) {
// Preprocessor op running order
const std::vector<std::string> Preprocessor::RUN_ORDER = {"InitInfo",
- "TopDownEvalAffine",
- "Resize",
- "NormalizeImage",
- "PadStride",
- "Permute"};
+ "DetTopDownEvalAffine",
+ "DetResize",
+ "DetNormalizeImage",
+ "DetPadStride",
+ "DetPermute"};
void Preprocessor::Run(cv::Mat* im, ImageBlob* data) {
for (const auto& name : RUN_ORDER) {
diff --git a/deploy/lite_shitu/src/recognition.cc b/deploy/lite_shitu/src/recognition.cc
index 8d3a99476..0e711f386 100644
--- a/deploy/lite_shitu/src/recognition.cc
+++ b/deploy/lite_shitu/src/recognition.cc
@@ -38,7 +38,7 @@ std::vector<RESULT> Recognition::RunRecModel(const cv::Mat &img,
// Get output and post process
  std::unique_ptr<const paddle::lite_api::Tensor> output_tensor(
- std::move(this->predictor->GetOutput(0)));
+ std::move(this->predictor->GetOutput(1)));
  auto *output_data = output_tensor->data<float>();
auto end = std::chrono::system_clock::now();
auto duration =
diff --git a/deploy/lite_shitu/src/utils.cc b/deploy/lite_shitu/src/utils.cc
index 88b2c563c..3bc461770 100644
--- a/deploy/lite_shitu/src/utils.cc
+++ b/deploy/lite_shitu/src/utils.cc
@@ -16,14 +16,23 @@
namespace PPShiTu {
-void nms(std::vector &input_boxes, float nms_threshold) {
- std::sort(input_boxes.begin(),
- input_boxes.end(),
- [](ObjectResult a, ObjectResult b) { return a.confidence > b.confidence; });
+void nms(std::vector &input_boxes, float nms_threshold,
+ bool rec_nms) {
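+  // sort by detection confidence, or by top-1 recognition score when rec_nms is set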
+ if (!rec_nms) {
+ std::sort(input_boxes.begin(), input_boxes.end(),
+ [](ObjectResult a, ObjectResult b) {
+ return a.confidence > b.confidence;
+ });
+ } else {
+ std::sort(input_boxes.begin(), input_boxes.end(),
+ [](ObjectResult a, ObjectResult b) {
+ return a.rec_result[0].score > b.rec_result[0].score;
+ });
+ }
  std::vector<float> vArea(input_boxes.size());
for (int i = 0; i < int(input_boxes.size()); ++i) {
- vArea[i] = (input_boxes.at(i).rect[2] - input_boxes.at(i).rect[0] + 1)
- * (input_boxes.at(i).rect[3] - input_boxes.at(i).rect[1] + 1);
+ vArea[i] = (input_boxes.at(i).rect[2] - input_boxes.at(i).rect[0] + 1) *
+ (input_boxes.at(i).rect[3] - input_boxes.at(i).rect[1] + 1);
}
for (int i = 0; i < int(input_boxes.size()); ++i) {
for (int j = i + 1; j < int(input_boxes.size());) {
@@ -36,14 +45,13 @@ void nms(std::vector<ObjectResult> &input_boxes, float nms_threshold) {
float inter = w * h;
float ovr = inter / (vArea[i] + vArea[j] - inter);
if (ovr >= nms_threshold) {
- input_boxes.erase(input_boxes.begin() + j);
- vArea.erase(vArea.begin() + j);
- }
- else {
- j++;
+ input_boxes.erase(input_boxes.begin() + j);
+ vArea.erase(vArea.begin() + j);
+ } else {
+ j++;
}
}
}
}
-} // namespace PPShiTu
+} // namespace PPShiTu
diff --git a/docs/images/ppshitu_lite_demo.png b/docs/images/ppshitu_lite_demo.png
new file mode 100644
index 000000000..a9b48293d
Binary files /dev/null and b/docs/images/ppshitu_lite_demo.png differ