commit
7dc5bc9b73
|
@ -1,4 +1,65 @@
|
|||
# 模型服务化部署
|
||||
|
||||
[Paddle Serving](https://github.com/PaddlePaddle/Serving) 旨在帮助深度学习开发者轻易部署在线预测服务,支持一键部署工业级的服务能力、客户端和服务端之间高并发和高效通信、并支持多种编程语言开发客户端等特点,详细使用请参考 [Paddle Serving 相关文档](https://github.com/PaddlePaddle/Serving)。
|
||||
## 一、简介
|
||||
[Paddle Serving](https://github.com/PaddlePaddle/Serving) 旨在帮助深度学习开发者轻易部署在线预测服务,支持一键部署工业级的服务能力、客户端和服务端之间高并发和高效通信、并支持多种编程语言开发客户端。
|
||||
|
||||
该部分以HTTP预测服务部署为例,介绍怎样在PaddleClas中使用PaddleServing部署模型服务。
|
||||
|
||||
|
||||
## 二、Serving安装
|
||||
|
||||
Serving官网推荐使用docker安装并部署Serving环境。首先需要拉取docker环境并创建基于Serving的docker。
|
||||
|
||||
```shell
|
||||
nvidia-docker pull hub.baidubce.com/paddlepaddle/serving:0.2.0-gpu
|
||||
nvidia-docker run -p 9292:9292 --name test -dit hub.baidubce.com/paddlepaddle/serving:0.2.0-gpu
|
||||
nvidia-docker exec -it test bash
|
||||
```
|
||||
|
||||
进入docker后,需要安装Serving相关的python包。
|
||||
|
||||
```shell
|
||||
pip install paddlepaddle-gpu
|
||||
pip install paddle-serving-client
|
||||
pip install paddle-serving-server-gpu
|
||||
```
|
||||
|
||||
* 如果安装速度太慢,可以通过`-i https://pypi.tuna.tsinghua.edu.cn/simple`更换源,加速安装过程。
|
||||
|
||||
* 如果希望部署CPU服务,可以安装serving-server的cpu版本,安装命令如下。
|
||||
|
||||
```shell
|
||||
pip install paddle-serving-server
|
||||
```
|
||||
|
||||
### 三、导出模型
|
||||
|
||||
使用`tools/export_serving_model.py`脚本导出Serving模型,以`ResNet50_vd`为例,使用方法如下。
|
||||
|
||||
```shell
|
||||
python tools/export_serving_model.py -m ResNet50_vd -p ./pretrained/ResNet50_vd_pretrained/ -o serving
|
||||
```
|
||||
|
||||
最终在serving文件夹下会生成`ppcls_client_conf`与`ppcls_model`两个文件夹,分别存储了client配置、模型参数与结构文件。
|
||||
|
||||
|
||||
### 四、服务部署与请求
|
||||
|
||||
* 使用下面的方式启动Serving服务。
|
||||
|
||||
```shell
|
||||
python tools/serving/image_service_gpu.py serving/ppcls_model workdir 9292
|
||||
```
|
||||
|
||||
其中`serving/ppcls_model`为刚才保存的Serving模型地址,`workdir`为为工作目录,`9292`为服务的端口号。
|
||||
|
||||
|
||||
* 使用下面的脚本向Serving服务发送识别请求,并返回结果。
|
||||
|
||||
```
|
||||
python tools/serving/image_http_client.py 9292 ./docs/images/logo.png
|
||||
```
|
||||
|
||||
`9292`为发送请求的端口号,需要与服务启动时的端口号保持一致,`./docs/images/logo.png`为待识别的图像文件。最终返回Top1识别结果的类别ID以及概率值。
|
||||
|
||||
* 更多的服务部署类型,如`RPC预测服务`等,可以参考Serving的github官网:[https://github.com/PaddlePaddle/Serving/tree/develop/python/examples/imagenet](https://github.com/PaddlePaddle/Serving/tree/develop/python/examples/imagenet)
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from ppcls.modeling import architectures
|
||||
|
||||
import paddle.fluid as fluid
|
||||
import paddle_serving_client.io as serving_io
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-m", "--model", type=str)
|
||||
parser.add_argument("-p", "--pretrained_model", type=str)
|
||||
parser.add_argument("-o", "--output_path", type=str, default="")
|
||||
parser.add_argument("--class_dim", type=int, default=1000)
|
||||
parser.add_argument("--img_size", type=int, default=224)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def create_input(img_size=224):
|
||||
image = fluid.data(
|
||||
name='image', shape=[None, 3, img_size, img_size], dtype='float32')
|
||||
return image
|
||||
|
||||
|
||||
def create_model(args, model, input, class_dim=1000):
|
||||
if args.model == "GoogLeNet":
|
||||
out, _, _ = model.net(input=input, class_dim=class_dim)
|
||||
else:
|
||||
out = model.net(input=input, class_dim=class_dim)
|
||||
out = fluid.layers.softmax(out)
|
||||
return out
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
model = architectures.__dict__[args.model]()
|
||||
|
||||
place = fluid.CPUPlace()
|
||||
exe = fluid.Executor(place)
|
||||
|
||||
startup_prog = fluid.Program()
|
||||
infer_prog = fluid.Program()
|
||||
|
||||
with fluid.program_guard(infer_prog, startup_prog):
|
||||
with fluid.unique_name.guard():
|
||||
image = create_input(args.img_size)
|
||||
out = create_model(args, model, image, class_dim=args.class_dim)
|
||||
|
||||
infer_prog = infer_prog.clone(for_test=True)
|
||||
fluid.load(
|
||||
program=infer_prog, model_path=args.pretrained_model, executor=exe)
|
||||
|
||||
model_path = os.path.join(args.output_path, "ppcls_model")
|
||||
conf_path = os.path.join(args.output_path, "ppcls_client_conf")
|
||||
serving_io.save_model(model_path, conf_path, {"image": image},
|
||||
{"prediction": out}, infer_prog)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import requests
|
||||
import base64
|
||||
import json
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
py_version = sys.version_info[0]
|
||||
|
||||
|
||||
def predict(image_path, server):
|
||||
if py_version == 2:
|
||||
image = base64.b64encode(open(image_path).read())
|
||||
else:
|
||||
image = base64.b64encode(open(image_path, "rb").read()).decode("utf-8")
|
||||
req = json.dumps({"feed": [{"image": image}], "fetch": ["prediction"]})
|
||||
r = requests.post(
|
||||
server, data=req, headers={"Content-Type": "application/json"})
|
||||
try:
|
||||
pred = r.json()["result"]["prediction"][0]
|
||||
cls_id = np.argmax(pred)
|
||||
score = pred[cls_id]
|
||||
pred = {"cls_id": cls_id, "score": score}
|
||||
return pred
|
||||
except ValueError:
|
||||
print(r.text)
|
||||
return r
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
server = "http://127.0.0.1:{}/image/prediction".format(sys.argv[1])
|
||||
image_file = sys.argv[2]
|
||||
res = predict(image_file, server)
|
||||
print("res:", res)
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import base64
|
||||
from paddle_serving_server.web_service import WebService
|
||||
import utils
|
||||
|
||||
|
||||
class ImageService(WebService):
|
||||
def __init__(self, name):
|
||||
super(ImageService, self).__init__(name=name)
|
||||
self.operators = self.create_operators()
|
||||
|
||||
def create_operators(self):
|
||||
size = 224
|
||||
img_mean = [0.485, 0.456, 0.406]
|
||||
img_std = [0.229, 0.224, 0.225]
|
||||
img_scale = 1.0 / 255.0
|
||||
decode_op = utils.DecodeImage()
|
||||
resize_op = utils.ResizeImage(resize_short=256)
|
||||
crop_op = utils.CropImage(size=(size, size))
|
||||
normalize_op = utils.NormalizeImage(
|
||||
scale=img_scale, mean=img_mean, std=img_std)
|
||||
totensor_op = utils.ToTensor()
|
||||
return [decode_op, resize_op, crop_op, normalize_op, totensor_op]
|
||||
|
||||
def _process_image(self, data, ops):
|
||||
for op in ops:
|
||||
data = op(data)
|
||||
return data
|
||||
|
||||
def preprocess(self, feed={}, fetch=[]):
|
||||
feed_batch = []
|
||||
for ins in feed:
|
||||
if "image" not in ins:
|
||||
raise ("feed data error!")
|
||||
sample = base64.b64decode(ins["image"])
|
||||
img = self._process_image(sample, self.operators)
|
||||
feed_batch.append({"image": img})
|
||||
return feed_batch, fetch
|
||||
|
||||
|
||||
image_service = ImageService(name="image")
|
||||
image_service.load_model_config(sys.argv[1])
|
||||
image_service.prepare_server(
|
||||
workdir=sys.argv[2], port=int(sys.argv[3]), device="cpu")
|
||||
image_service.run_server()
|
||||
image_service.run_flask()
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import base64
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
|
||||
import utils
|
||||
|
||||
|
||||
class ImageService(WebService):
|
||||
def __init__(self, name):
|
||||
super(ImageService, self).__init__(name=name)
|
||||
self.operators = self.create_operators()
|
||||
|
||||
def create_operators(self):
|
||||
size = 224
|
||||
img_mean = [0.485, 0.456, 0.406]
|
||||
img_std = [0.229, 0.224, 0.225]
|
||||
img_scale = 1.0 / 255.0
|
||||
decode_op = utils.DecodeImage()
|
||||
resize_op = utils.ResizeImage(resize_short=256)
|
||||
crop_op = utils.CropImage(size=(size, size))
|
||||
normalize_op = utils.NormalizeImage(
|
||||
scale=img_scale, mean=img_mean, std=img_std)
|
||||
totensor_op = utils.ToTensor()
|
||||
return [decode_op, resize_op, crop_op, normalize_op, totensor_op]
|
||||
|
||||
def _process_image(self, data, ops):
|
||||
for op in ops:
|
||||
data = op(data)
|
||||
return data
|
||||
|
||||
def preprocess(self, feed={}, fetch=[]):
|
||||
feed_batch = []
|
||||
for ins in feed:
|
||||
if "image" not in ins:
|
||||
raise ("feed data error!")
|
||||
sample = base64.b64decode(ins["image"])
|
||||
img = self._process_image(sample, self.operators)
|
||||
feed_batch.append({"image": img})
|
||||
return feed_batch, fetch
|
||||
|
||||
|
||||
image_service = ImageService(name="image")
|
||||
image_service.load_model_config(sys.argv[1])
|
||||
image_service.set_gpus("0")
|
||||
image_service.prepare_server(
|
||||
workdir=sys.argv[2], port=int(sys.argv[3]), device="gpu")
|
||||
image_service.run_server()
|
||||
image_service.run_flask()
|
|
@ -0,0 +1,84 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DecodeImage(object):
|
||||
def __init__(self, to_rgb=True):
|
||||
self.to_rgb = to_rgb
|
||||
|
||||
def __call__(self, img):
|
||||
data = np.frombuffer(img, dtype='uint8')
|
||||
img = cv2.imdecode(data, 1)
|
||||
if self.to_rgb:
|
||||
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
|
||||
img.shape)
|
||||
img = img[:, :, ::-1]
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class ResizeImage(object):
|
||||
def __init__(self, resize_short=None):
|
||||
self.resize_short = resize_short
|
||||
|
||||
def __call__(self, img):
|
||||
img_h, img_w = img.shape[:2]
|
||||
percent = float(self.resize_short) / min(img_w, img_h)
|
||||
w = int(round(img_w * percent))
|
||||
h = int(round(img_h * percent))
|
||||
return cv2.resize(img, (w, h))
|
||||
|
||||
|
||||
class CropImage(object):
|
||||
def __init__(self, size):
|
||||
if type(size) is int:
|
||||
self.size = (size, size)
|
||||
else:
|
||||
self.size = size
|
||||
|
||||
def __call__(self, img):
|
||||
w, h = self.size
|
||||
img_h, img_w = img.shape[:2]
|
||||
w_start = (img_w - w) // 2
|
||||
h_start = (img_h - h) // 2
|
||||
|
||||
w_end = w_start + w
|
||||
h_end = h_start + h
|
||||
return img[h_start:h_end, w_start:w_end, :]
|
||||
|
||||
|
||||
class NormalizeImage(object):
|
||||
def __init__(self, scale=None, mean=None, std=None):
|
||||
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
|
||||
mean = mean if mean is not None else [0.485, 0.456, 0.406]
|
||||
std = std if std is not None else [0.229, 0.224, 0.225]
|
||||
|
||||
shape = (1, 1, 3)
|
||||
self.mean = np.array(mean).reshape(shape).astype('float32')
|
||||
self.std = np.array(std).reshape(shape).astype('float32')
|
||||
|
||||
def __call__(self, img):
|
||||
return (img.astype('float32') * self.scale - self.mean) / self.std
|
||||
|
||||
|
||||
class ToTensor(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, img):
|
||||
img = img.transpose((2, 0, 1))
|
||||
return img
|
Loading…
Reference in New Issue