upload cplusplus serving
parent
d69a6e8f24
commit
4e0ca2d003
|
@ -0,0 +1,7 @@
|
|||
nohup python3 -m paddle_serving_server.serve \
|
||||
--model picodet_PPLCNet_x2_5_mainbody_lite_v1.0_serving \
|
||||
--port 9293 >>log_mainbody_detection.txt 1&>2 &
|
||||
|
||||
nohup python3 -m paddle_serving_server.serve \
|
||||
--model general_PPLCNet_x2_5_lite_v1.0_serving \
|
||||
--port 9294 >>log_feature_extraction.txt 1&>2 &
|
|
@ -0,0 +1,201 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
from paddle_serving_client import Client
|
||||
from paddle_serving_app.reader import *
|
||||
import cv2
|
||||
import faiss
|
||||
import os
|
||||
import pickle
|
||||
|
||||
|
||||
class MainbodyDetect():
|
||||
"""
|
||||
pp-shitu mainbody detect.
|
||||
include preprocess, process, postprocess
|
||||
return detect results
|
||||
Attention: Postprocess include num limit and box filter; no nms
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.preprocess = DetectionSequential([
|
||||
DetectionFile2Image(), DetectionNormalize(
|
||||
[0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True),
|
||||
DetectionResize(
|
||||
(640, 640), False, interpolation=2), DetectionTranspose(
|
||||
(2, 0, 1))
|
||||
])
|
||||
|
||||
self.client = Client()
|
||||
self.client.load_client_config(
|
||||
"picodet_PPLCNet_x2_5_mainbody_lite_v1.0_client/serving_client_conf.prototxt"
|
||||
)
|
||||
self.client.connect(['127.0.0.1:9293'])
|
||||
|
||||
self.max_det_result = 5
|
||||
self.conf_threshold = 0.2
|
||||
|
||||
def predict(self, imgpath):
|
||||
im, im_info = self.preprocess(sys.argv[1])
|
||||
im_shape = np.array(im.shape[1:]).reshape(-1)
|
||||
scale_factor = np.array(list(im_info['scale_factor'])).reshape(-1)
|
||||
fetch_map = self.client.predict(
|
||||
feed={
|
||||
"image": im,
|
||||
"im_shape": im_shape,
|
||||
"scale_factor": scale_factor,
|
||||
},
|
||||
fetch=["save_infer_model/scale_0.tmp_1"],
|
||||
batch=False)
|
||||
return self.postprocess(fetch_map, imgpath)
|
||||
|
||||
def postprocess(self, fetch_map, imgpath):
|
||||
#1. get top max_det_result
|
||||
det_results = fetch_map["save_infer_model/scale_0.tmp_1"]
|
||||
if len(det_results) > self.max_det_result:
|
||||
boxes_reserved = fetch_map[
|
||||
"save_infer_model/scale_0.tmp_1"][:self.max_det_result]
|
||||
else:
|
||||
boxes_reserved = det_results
|
||||
|
||||
#2. do conf threshold
|
||||
boxes_list = []
|
||||
for i in range(boxes_reserved.shape[0]):
|
||||
if (boxes_reserved[i, 1]) > self.conf_threshold:
|
||||
boxes_list.append(boxes_reserved[i, :])
|
||||
|
||||
#3. add origin image box
|
||||
origin_img = cv2.imread(imgpath)
|
||||
boxes_list.append(
|
||||
np.array([0, 1.0, 0, 0, origin_img.shape[1], origin_img.shape[0]]))
|
||||
return np.array(boxes_list)
|
||||
|
||||
|
||||
class ObjectRecognition():
|
||||
"""
|
||||
pp-shitu object recognion for all objects detected by MainbodyDetect.
|
||||
include preprocess, process, postprocess
|
||||
preprocess include preprocess for each image and batching.
|
||||
Batch process
|
||||
postprocess include retrieval and nms
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.client = Client()
|
||||
self.client.load_client_config(
|
||||
"general_PPLCNet_x2_5_lite_v1.0_client/serving_client_conf.prototxt"
|
||||
)
|
||||
self.client.connect(["127.0.0.1:9294"])
|
||||
|
||||
self.seq = Sequential([
|
||||
BGR2RGB(), Resize((224, 224)), Div(255),
|
||||
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225],
|
||||
False), Transpose((2, 0, 1))
|
||||
])
|
||||
|
||||
self.searcher, self.id_map = self.init_index()
|
||||
|
||||
self.rec_nms_thresold = 0.05
|
||||
self.rec_score_thres = 0.5
|
||||
self.feature_normalize = True
|
||||
self.return_k = 1
|
||||
|
||||
def init_index(self):
|
||||
index_dir = "../../drink_dataset_v1.0/index"
|
||||
assert os.path.exists(os.path.join(
|
||||
index_dir, "vector.index")), "vector.index not found ..."
|
||||
assert os.path.exists(os.path.join(
|
||||
index_dir, "id_map.pkl")), "id_map.pkl not found ... "
|
||||
|
||||
searcher = faiss.read_index(os.path.join(index_dir, "vector.index"))
|
||||
|
||||
with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd:
|
||||
id_map = pickle.load(fd)
|
||||
return searcher, id_map
|
||||
|
||||
def predict(self, det_boxes, imgpath):
|
||||
#1. preprocess
|
||||
batch_imgs = []
|
||||
origin_img = cv2.imread(imgpath)
|
||||
for i in range(det_boxes.shape[0]):
|
||||
box = det_boxes[i]
|
||||
x1, y1, x2, y2 = [int(x) for x in box[2:]]
|
||||
cropped_img = origin_img[y1:y2, x1:x2, :].copy()
|
||||
tmp = self.seq(cropped_img)
|
||||
batch_imgs.append(tmp)
|
||||
batch_imgs = np.array(batch_imgs)
|
||||
|
||||
#2. process
|
||||
fetch_map = self.client.predict(
|
||||
feed={"x": batch_imgs}, fetch=["features"], batch=True)
|
||||
batch_features = fetch_map["features"]
|
||||
|
||||
#3. postprocess
|
||||
if self.feature_normalize:
|
||||
feas_norm = np.sqrt(
|
||||
np.sum(np.square(batch_features), axis=1, keepdims=True))
|
||||
batch_features = np.divide(batch_features, feas_norm)
|
||||
scores, docs = self.searcher.search(batch_features, self.return_k)
|
||||
|
||||
results = []
|
||||
for i in range(scores.shape[0]):
|
||||
pred = {}
|
||||
if scores[i][0] >= self.rec_score_thres:
|
||||
pred["bbox"] = [int(x) for x in det_boxes[i, 2:]]
|
||||
pred["rec_docs"] = self.id_map[docs[i][0]].split()[1]
|
||||
pred["rec_scores"] = scores[i][0]
|
||||
results.append(pred)
|
||||
return self.nms_to_rec_results(results)
|
||||
|
||||
def nms_to_rec_results(self, results):
|
||||
filtered_results = []
|
||||
x1 = np.array([r["bbox"][0] for r in results]).astype("float32")
|
||||
y1 = np.array([r["bbox"][1] for r in results]).astype("float32")
|
||||
x2 = np.array([r["bbox"][2] for r in results]).astype("float32")
|
||||
y2 = np.array([r["bbox"][3] for r in results]).astype("float32")
|
||||
scores = np.array([r["rec_scores"] for r in results])
|
||||
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
order = scores.argsort()[::-1]
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
xx1 = np.maximum(x1[i], x1[order[1:]])
|
||||
yy1 = np.maximum(y1[i], y1[order[1:]])
|
||||
xx2 = np.minimum(x2[i], x2[order[1:]])
|
||||
yy2 = np.minimum(y2[i], y2[order[1:]])
|
||||
|
||||
w = np.maximum(0.0, xx2 - xx1 + 1)
|
||||
h = np.maximum(0.0, yy2 - yy1 + 1)
|
||||
inter = w * h
|
||||
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
||||
inds = np.where(ovr <= self.rec_nms_thresold)[0]
|
||||
order = order[inds + 1]
|
||||
filtered_results.append(results[i])
|
||||
return filtered_results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
det = MainbodyDetect()
|
||||
rec = ObjectRecognition()
|
||||
|
||||
#1. get det_results
|
||||
det_results = det.predict(sys.argv[1])
|
||||
print(det_results)
|
||||
|
||||
#2. get rec_results
|
||||
rec_results = rec.predict(det_results, sys.argv[1])
|
||||
print(rec_results)
|
|
@ -0,0 +1,2 @@
|
|||
#run cls server:
|
||||
nohup python3 -m paddle_serving_server.serve --model ResNet50_vd_serving --port 9292 &
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from paddle_serving_client import Client
|
||||
|
||||
#app
|
||||
from paddle_serving_app.reader import Sequential, URL2Image, Resize
|
||||
from paddle_serving_app.reader import CenterCrop, RGB2BGR, Transpose, Div, Normalize
|
||||
import time
|
||||
|
||||
client = Client()
|
||||
client.load_client_config("./ResNet50_vd_serving/serving_server_conf.prototxt")
|
||||
client.connect(["127.0.0.1:9292"])
|
||||
|
||||
label_dict = {}
|
||||
label_idx = 0
|
||||
with open("imagenet.label") as fin:
|
||||
for line in fin:
|
||||
label_dict[label_idx] = line.strip()
|
||||
label_idx += 1
|
||||
|
||||
#preprocess
|
||||
seq = Sequential([
|
||||
URL2Image(), Resize(256), CenterCrop(224), RGB2BGR(), Transpose((2, 0, 1)),
|
||||
Div(255), Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True)
|
||||
])
|
||||
|
||||
start = time.time()
|
||||
image_file = "https://paddle-serving.bj.bcebos.com/imagenet-example/daisy.jpg"
|
||||
for i in range(1):
|
||||
img = seq(image_file)
|
||||
fetch_map = client.predict(
|
||||
feed={"inputs": img}, fetch=["prediction"], batch=False)
|
||||
|
||||
prob = max(fetch_map["prediction"][0])
|
||||
label = label_dict[fetch_map["prediction"][0].tolist().index(prob)].strip(
|
||||
).replace(",", "")
|
||||
print("prediction: {}, probability: {}".format(label, prob))
|
||||
end = time.time()
|
||||
print(end - start)
|
|
@ -4,9 +4,13 @@
|
|||
- [3. 图像分类服务部署](#3)
|
||||
- [3.1 模型转换](#3.1)
|
||||
- [3.2 服务部署和请求](#3.2)
|
||||
- [3.2.1 Python Serving](#3.2.1)
|
||||
- [3.2.2 C++ Serving](#3.2.2)
|
||||
- [4. 图像识别服务部署](#4)
|
||||
- [4.1 模型转换](#4.1)
|
||||
- [4.2 服务部署和请求](#4.2)
|
||||
- [4.2.1 Python Serving](#4.2.1)
|
||||
- [4.2.2 C++ Serving](#4.2.2)
|
||||
- [5. FAQ](#5)
|
||||
|
||||
<a name="1"></a>
|
||||
|
@ -88,7 +92,7 @@ ResNet50_vd 推理模型转换完成后,会在当前文件夹多出 `ResNet50_
|
|||
|- serving_client_conf.prototxt
|
||||
|- serving_client_conf.stream.prototxt
|
||||
```
|
||||
得到模型文件之后,需要修改 `ResNet50_vd_server` 下文件 `serving_server_conf.prototxt` 中的 alias 名字:将 `fetch_var` 中的 `alias_name` 改为 `prediction`
|
||||
得到模型文件之后,需要分别修改 `ResNet50_vd_server` 和 `ResNet50_vd_client` 下文件 `serving_server_conf.prototxt` 中的 alias 名字:将 `fetch_var` 中的 `alias_name` 改为 `prediction`
|
||||
|
||||
**备注**: Serving 为了兼容不同模型的部署,提供了输入输出重命名的功能。这样,不同的模型在推理部署时,只需要修改配置文件的 alias_name 即可,无需修改代码即可完成推理部署。
|
||||
修改后的 serving_server_conf.prototxt 如下所示:
|
||||
|
@ -112,15 +116,18 @@ fetch_var {
|
|||
```
|
||||
<a name="3.2"></a>
|
||||
### 3.2 服务部署和请求
|
||||
paddleserving 目录包含了启动 pipeline 服务和发送预测请求的代码,包括:
|
||||
paddleserving 目录包含了启动 pipeline 服务、C++ serving服务和发送预测请求的代码,包括:
|
||||
```shell
|
||||
__init__.py
|
||||
config.yml # 启动服务的配置文件
|
||||
config.yml # 启动pipeline服务的配置文件
|
||||
pipeline_http_client.py # http方式发送pipeline预测请求的脚本
|
||||
pipeline_rpc_client.py # rpc方式发送pipeline预测请求的脚本
|
||||
classification_web_service.py # 启动pipeline服务端的脚本
|
||||
run_cpp_serving.sh # 启动C++ Serving部署的脚本
|
||||
test_cpp_serving_client.py # rpc方式发送C++ serving预测请求的脚本
|
||||
```
|
||||
|
||||
<a name="3.2.1"></a>
|
||||
#### 3.2.1 Python Serving
|
||||
- 启动服务:
|
||||
```shell
|
||||
# 启动服务,运行日志保存在 log.txt
|
||||
|
@ -137,6 +144,22 @@ python3 pipeline_http_client.py
|
|||
成功运行后,模型预测的结果会打印在 cmd 窗口中,结果示例为:
|
||||

|
||||
|
||||
<a name="3.2.2"></a>
|
||||
#### 3.2.2 C++ Serving
|
||||
- 启动服务:
|
||||
```shell
|
||||
# 启动服务, 服务在后台运行,运行日志保存在 nohup.txt
|
||||
sh run_cpp_serving.sh
|
||||
```
|
||||
|
||||
- 发送请求:
|
||||
```shell
|
||||
# 发送服务请求
|
||||
python3 test_cpp_serving_client.py
|
||||
```
|
||||
成功运行后,模型预测的结果会打印在 cmd 窗口中,结果示例为:
|
||||

|
||||
|
||||
<a name="4"></a>
|
||||
## 4.图像识别服务部署
|
||||
使用 PaddleServing 做服务化部署时,需要将保存的 inference 模型转换为 Serving 模型。 下面以 PP-ShiTu 中的超轻量图像识别模型为例,介绍图像识别服务的部署。
|
||||
|
@ -162,7 +185,11 @@ python3 -m paddle_serving_client.convert --dirname ./general_PPLCNet_x2_5_lite_v
|
|||
--serving_server ./general_PPLCNet_x2_5_lite_v1.0_serving/ \
|
||||
--serving_client ./general_PPLCNet_x2_5_lite_v1.0_client/
|
||||
```
|
||||
<<<<<<< HEAD
|
||||
识别推理模型转换完成后,会在当前文件夹多出 `general_PPLCNet_x2_5_lite_v1.0_serving/` 和 `general_PPLCNet_x2_5_lite_v1.0_serving/` 的文件夹。分别修改 `general_PPLCNet_x2_5_lite_v1.0_serving/` 和 `general_PPLCNet_x2_5_lite_v1.0_client/` 目录下的 serving_server_conf.prototxt 中的 alias 名字: 将 `fetch_var` 中的 `alias_name` 改为 `features`。
|
||||
=======
|
||||
识别推理模型转换完成后,会在当前文件夹多出 `general_PPLCNet_x2_5_lite_v1.0_serving/` 和 `general_PPLCNet_x2_5_lite_v1.0_client/` 的文件夹。修改 `general_PPLCNet_x2_5_lite_v1.0_serving/` 目录下的 serving_server_conf.prototxt 中的 alias 名字: 将 `fetch_var` 中的 `alias_name` 改为 `features`。
|
||||
>>>>>>> d69a6e8f242cd894b41e9608bbca23172bcd3193
|
||||
修改后的 serving_server_conf.prototxt 内容如下:
|
||||
```
|
||||
feed_var {
|
||||
|
@ -207,14 +234,19 @@ wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_da
|
|||
```shell
|
||||
cd ./deploy/paddleserving/recognition
|
||||
```
|
||||
paddleserving 目录包含启动 pipeline 服务和发送预测请求的代码,包括:
|
||||
paddleserving 目录包含启动 Python Pipeline 服务、C++ Serving 服务和发送预测请求的代码,包括:
|
||||
```
|
||||
__init__.py
|
||||
config.yml # 启动服务的配置文件
|
||||
config.yml # 启动python pipeline服务的配置文件
|
||||
pipeline_http_client.py # http方式发送pipeline预测请求的脚本
|
||||
pipeline_rpc_client.py # rpc方式发送pipeline预测请求的脚本
|
||||
recognition_web_service.py # 启动pipeline服务端的脚本
|
||||
run_cpp_serving.sh # 启动C++ Serving部署的脚本
|
||||
test_cpp_serving_client.py # rpc方式发送C++ serving预测请求的脚本
|
||||
```
|
||||
|
||||
<a name="4.2.1"></a>
|
||||
#### 4.2.1 Python Serving
|
||||
- 启动服务:
|
||||
```
|
||||
# 启动服务,运行日志保存在 log.txt
|
||||
|
@ -230,6 +262,23 @@ python3 pipeline_http_client.py
|
|||
成功运行后,模型预测的结果会打印在 cmd 窗口中,结果示例为:
|
||||

|
||||
|
||||
<a name="4.2.2"></a>
|
||||
#### 4.2.2 C++ Serving
|
||||
- 启动服务:
|
||||
```shell
|
||||
# 启动服务: 此处会在后台同时启动主体检测和特征提取服务,端口号分别为9293和9294;
|
||||
# 运行日志分别保存在 log_mainbody_detection.txt 和 log_feature_extraction.txt中
|
||||
sh run_cpp_serving.sh
|
||||
```
|
||||
|
||||
- 发送请求:
|
||||
```shell
|
||||
# 发送服务请求
|
||||
python3 test_cpp_serving_client.py
|
||||
```
|
||||
成功运行后,模型预测的结果会打印在 cmd 窗口中,结果示例为:
|
||||

|
||||
|
||||
<a name="5"></a>
|
||||
## 5.FAQ
|
||||
**Q1**: 发送请求后没有结果返回或者提示输出解码报错
|
||||
|
|
Loading…
Reference in New Issue