Merge pull request #2045 from HydrogenSulfate/add_cpp_serving

Add cpp serving chain
pull/2057/head
Walter 2022-06-14 11:41:56 +08:00 committed by GitHub
commit 05c393d938
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 1734 additions and 554 deletions

View File

@ -0,0 +1,88 @@
# 使用镜像:
# registry.baidubce.com/paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82
# 编译Serving Server
# client和app可以直接使用release版本
# server因为加入了自定义OP需要重新编译
# 默认编译时的${PWD}=PaddleClas/deploy/paddleserving/
python_name=${1:-'python'}
apt-get update
apt install -y libcurl4-openssl-dev libbz2-dev
wget -nc https://paddle-serving.bj.bcebos.com/others/centos_ssl.tar
tar xf centos_ssl.tar
rm -rf centos_ssl.tar
mv libcrypto.so.1.0.2k /usr/lib/libcrypto.so.1.0.2k
mv libssl.so.1.0.2k /usr/lib/libssl.so.1.0.2k
ln -sf /usr/lib/libcrypto.so.1.0.2k /usr/lib/libcrypto.so.10
ln -sf /usr/lib/libssl.so.1.0.2k /usr/lib/libssl.so.10
ln -sf /usr/lib/libcrypto.so.10 /usr/lib/libcrypto.so
ln -sf /usr/lib/libssl.so.10 /usr/lib/libssl.so
# 安装go依赖
rm -rf /usr/local/go
wget -qO- https://paddle-ci.cdn.bcebos.com/go1.17.2.linux-amd64.tar.gz | tar -xz -C /usr/local
export GOROOT=/usr/local/go
export GOPATH=/root/gopath
export PATH=$PATH:$GOPATH/bin:$GOROOT/bin
go env -w GO111MODULE=on
go env -w GOPROXY=https://goproxy.cn,direct
go install github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway@v1.15.2
go install github.com/grpc-ecosystem/grpc-gateway/protoc-gen-swagger@v1.15.2
go install github.com/golang/protobuf/protoc-gen-go@v1.4.3
go install google.golang.org/grpc@v1.33.0
go env -w GO111MODULE=auto
# 下载opencv库
wget https://paddle-qa.bj.bcebos.com/PaddleServing/opencv3.tar.gz
tar -xvf opencv3.tar.gz
rm -rf opencv3.tar.gz
export OPENCV_DIR=$PWD/opencv3
# clone Serving
git clone https://github.com/PaddlePaddle/Serving.git -b develop --depth=1
cd Serving # PaddleClas/deploy/paddleserving/Serving
export Serving_repo_path=$PWD
git submodule update --init --recursive
${python_name} -m pip install -r python/requirements.txt
# set env
export PYTHON_INCLUDE_DIR=$(${python_name} -c "from distutils.sysconfig import get_python_inc; print(get_python_inc())")
export PYTHON_LIBRARIES=$(${python_name} -c "import distutils.sysconfig as sysconfig; print(sysconfig.get_config_var('LIBDIR'))")
export PYTHON_EXECUTABLE=`which ${python_name}`
export CUDA_PATH='/usr/local/cuda'
export CUDNN_LIBRARY='/usr/local/cuda/lib64/'
export CUDA_CUDART_LIBRARY='/usr/local/cuda/lib64/'
export TENSORRT_LIBRARY_PATH='/usr/local/TensorRT6-cuda10.1-cudnn7/targets/x86_64-linux-gnu/'
# cp 自定义OP代码
\cp ../preprocess/general_clas_op.* ${Serving_repo_path}/core/general-server/op
\cp ../preprocess/preprocess_op.* ${Serving_repo_path}/core/predictor/tools/pp_shitu_tools
# 编译Server
mkdir server-build-gpu-opencv
cd server-build-gpu-opencv
cmake -DPYTHON_INCLUDE_DIR=$PYTHON_INCLUDE_DIR \
-DPYTHON_LIBRARIES=$PYTHON_LIBRARIES \
-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
-DCUDA_TOOLKIT_ROOT_DIR=${CUDA_PATH} \
-DCUDNN_LIBRARY=${CUDNN_LIBRARY} \
-DCUDA_CUDART_LIBRARY=${CUDA_CUDART_LIBRARY} \
-DTENSORRT_ROOT=${TENSORRT_LIBRARY_PATH} \
-DOPENCV_DIR=${OPENCV_DIR} \
-DWITH_OPENCV=ON \
-DSERVER=ON \
-DWITH_GPU=ON ..
make -j32
${python_name} -m pip install python/dist/paddle*
# export SERVING_BIN
export SERVING_BIN=$PWD/core/general-server/serving
cd ../../

View File

@ -0,0 +1,206 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/general-server/op/general_clas_op.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h"
#include "core/util/include/timer.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
namespace baidu {
namespace paddle_serving {
namespace serving {
using baidu::paddle_serving::Timer;
using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Response;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
int GeneralClasOp::inference() {
VLOG(2) << "Going to run inference";
const std::vector<std::string> pre_node_names = pre_names();
if (pre_node_names.size() != 1) {
LOG(ERROR) << "This op(" << op_name()
<< ") can only have one predecessor op, but received "
<< pre_node_names.size();
return -1;
}
const std::string pre_name = pre_node_names[0];
const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name);
if (!input_blob) {
LOG(ERROR) << "input_blob is nullptr,error";
return -1;
}
uint64_t log_id = input_blob->GetLogId();
VLOG(2) << "(logid=" << log_id << ") Get precedent op name: " << pre_name;
GeneralBlob *output_blob = mutable_data<GeneralBlob>();
if (!output_blob) {
LOG(ERROR) << "output_blob is nullptr,error";
return -1;
}
output_blob->SetLogId(log_id);
if (!input_blob) {
LOG(ERROR) << "(logid=" << log_id
<< ") Failed mutable depended argument, op:" << pre_name;
return -1;
}
const TensorVector *in = &input_blob->tensor_vector;
TensorVector *out = &output_blob->tensor_vector;
int batch_size = input_blob->_batch_size;
output_blob->_batch_size = batch_size;
VLOG(2) << "(logid=" << log_id << ") infer batch size: " << batch_size;
Timer timeline;
int64_t start = timeline.TimeStampUS();
timeline.Start();
// only support string type
char *total_input_ptr = static_cast<char *>(in->at(0).data.data());
std::string base64str = total_input_ptr;
cv::Mat img = Base2Mat(base64str);
// RGB2BGR
cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
// Resize
cv::Mat resize_img;
resize_op_.Run(img, resize_img, resize_short_size_);
// CenterCrop
crop_op_.Run(resize_img, crop_size_);
// Normalize
normalize_op_.Run(&resize_img, mean_, scale_, is_scale_);
// Permute
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
permute_op_.Run(&resize_img, input.data());
float maxValue = *max_element(input.begin(), input.end());
float minValue = *min_element(input.begin(), input.end());
TensorVector *real_in = new TensorVector();
if (!real_in) {
LOG(ERROR) << "real_in is nullptr,error";
return -1;
}
std::vector<int> input_shape;
int in_num = 0;
void *databuf_data = NULL;
char *databuf_char = NULL;
size_t databuf_size = 0;
input_shape = {1, 3, resize_img.rows, resize_img.cols};
in_num = std::accumulate(input_shape.begin(), input_shape.end(), 1,
std::multiplies<int>());
databuf_size = in_num * sizeof(float);
databuf_data = MempoolWrapper::instance().malloc(databuf_size);
if (!databuf_data) {
LOG(ERROR) << "Malloc failed, size: " << databuf_size;
return -1;
}
memcpy(databuf_data, input.data(), databuf_size);
databuf_char = reinterpret_cast<char *>(databuf_data);
paddle::PaddleBuf paddleBuf(databuf_char, databuf_size);
paddle::PaddleTensor tensor_in;
tensor_in.name = in->at(0).name;
tensor_in.dtype = paddle::PaddleDType::FLOAT32;
tensor_in.shape = {1, 3, resize_img.rows, resize_img.cols};
tensor_in.lod = in->at(0).lod;
tensor_in.data = paddleBuf;
real_in->push_back(tensor_in);
if (InferManager::instance().infer(engine_name().c_str(), real_in, out,
batch_size)) {
LOG(ERROR) << "(logid=" << log_id
<< ") Failed do infer in fluid model: " << engine_name().c_str();
return -1;
}
int64_t end = timeline.TimeStampUS();
CopyBlobInfo(input_blob, output_blob);
AddBlobInfo(output_blob, start);
AddBlobInfo(output_blob, end);
return 0;
}
cv::Mat GeneralClasOp::Base2Mat(std::string &base64_data) {
cv::Mat img;
std::string s_mat;
s_mat = base64Decode(base64_data.data(), base64_data.size());
std::vector<char> base64_img(s_mat.begin(), s_mat.end());
img = cv::imdecode(base64_img, cv::IMREAD_COLOR); // CV_LOAD_IMAGE_COLOR
return img;
}
std::string GeneralClasOp::base64Decode(const char *Data, int DataByte) {
const char DecodeTable[] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
62, // '+'
0, 0, 0,
63, // '/'
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z'
0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z'
};
std::string strDecode;
int nValue;
int i = 0;
while (i < DataByte) {
if (*Data != '\r' && *Data != '\n') {
nValue = DecodeTable[*Data++] << 18;
nValue += DecodeTable[*Data++] << 12;
strDecode += (nValue & 0x00FF0000) >> 16;
if (*Data != '=') {
nValue += DecodeTable[*Data++] << 6;
strDecode += (nValue & 0x0000FF00) >> 8;
if (*Data != '=') {
nValue += DecodeTable[*Data++];
strDecode += nValue & 0x000000FF;
}
}
i += 4;
} else // 回车换行,跳过
{
Data++;
i++;
}
}
return strDecode;
}
DEFINE_OP(GeneralClasOp);
} // namespace serving
} // namespace paddle_serving
} // namespace baidu

View File

@ -0,0 +1,70 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "core/general-server/general_model_service.pb.h"
#include "core/general-server/op/general_infer_helper.h"
#include "core/predictor/tools/pp_shitu_tools/preprocess_op.h"
#include "paddle_inference_api.h" // NOLINT
#include <string>
#include <vector>
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
namespace baidu {
namespace paddle_serving {
namespace serving {
class GeneralClasOp
: public baidu::paddle_serving::predictor::OpWithChannel<GeneralBlob> {
public:
typedef std::vector<paddle::PaddleTensor> TensorVector;
DECLARE_OP(GeneralClasOp);
int inference();
private:
// clas preprocess
std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
std::vector<float> scale_ = {0.229f, 0.224f, 0.225f};
bool is_scale_ = true;
int resize_short_size_ = 256;
int crop_size_ = 224;
PaddleClas::ResizeImg resize_op_;
PaddleClas::Normalize normalize_op_;
PaddleClas::Permute permute_op_;
PaddleClas::CenterCropImg crop_op_;
// read pics
cv::Mat Base2Mat(std::string &base64_data);
std::string base64Decode(const char *Data, int DataByte);
};
} // namespace serving
} // namespace paddle_serving
} // namespace baidu

View File

@ -0,0 +1,149 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <math.h>
#include <numeric>
#include "preprocess_op.h"
namespace Feature {
void Permute::Run(const cv::Mat *im, float *data) {
int rh = im->rows;
int rw = im->cols;
int rc = im->channels();
for (int i = 0; i < rc; ++i) {
cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i);
}
}
void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &std, float scale) {
(*im).convertTo(*im, CV_32FC3, scale);
for (int h = 0; h < im->rows; h++) {
for (int w = 0; w < im->cols; w++) {
im->at<cv::Vec3f>(h, w)[0] =
(im->at<cv::Vec3f>(h, w)[0] - mean[0]) / std[0];
im->at<cv::Vec3f>(h, w)[1] =
(im->at<cv::Vec3f>(h, w)[1] - mean[1]) / std[1];
im->at<cv::Vec3f>(h, w)[2] =
(im->at<cv::Vec3f>(h, w)[2] - mean[2]) / std[2];
}
}
}
void CenterCropImg::Run(cv::Mat &img, const int crop_size) {
int resize_w = img.cols;
int resize_h = img.rows;
int w_start = int((resize_w - crop_size) / 2);
int h_start = int((resize_h - crop_size) / 2);
cv::Rect rect(w_start, h_start, crop_size, crop_size);
img = img(rect);
}
void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
int resize_short_size, int size) {
int resize_h = 0;
int resize_w = 0;
if (size > 0) {
resize_h = size;
resize_w = size;
} else {
int w = img.cols;
int h = img.rows;
float ratio = 1.f;
if (h < w) {
ratio = float(resize_short_size) / float(h);
} else {
ratio = float(resize_short_size) / float(w);
}
resize_h = round(float(h) * ratio);
resize_w = round(float(w) * ratio);
}
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
}
} // namespace Feature
namespace PaddleClas {
void Permute::Run(const cv::Mat *im, float *data) {
int rh = im->rows;
int rw = im->cols;
int rc = im->channels();
for (int i = 0; i < rc; ++i) {
cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i);
}
}
void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &scale, const bool is_scale) {
double e = 1.0;
if (is_scale) {
e /= 255.0;
}
(*im).convertTo(*im, CV_32FC3, e);
for (int h = 0; h < im->rows; h++) {
for (int w = 0; w < im->cols; w++) {
im->at<cv::Vec3f>(h, w)[0] =
(im->at<cv::Vec3f>(h, w)[0] - mean[0]) / scale[0];
im->at<cv::Vec3f>(h, w)[1] =
(im->at<cv::Vec3f>(h, w)[1] - mean[1]) / scale[1];
im->at<cv::Vec3f>(h, w)[2] =
(im->at<cv::Vec3f>(h, w)[2] - mean[2]) / scale[2];
}
}
}
void CenterCropImg::Run(cv::Mat &img, const int crop_size) {
int resize_w = img.cols;
int resize_h = img.rows;
int w_start = int((resize_w - crop_size) / 2);
int h_start = int((resize_h - crop_size) / 2);
cv::Rect rect(w_start, h_start, crop_size, crop_size);
img = img(rect);
}
void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
int resize_short_size) {
int w = img.cols;
int h = img.rows;
float ratio = 1.f;
if (h < w) {
ratio = float(resize_short_size) / float(h);
} else {
ratio = float(resize_short_size) / float(w);
}
int resize_h = round(float(h) * ratio);
int resize_w = round(float(w) * ratio);
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
}
} // namespace PaddleClas

View File

@ -0,0 +1,81 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
namespace Feature {
class Normalize {
public:
virtual void Run(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &std, float scale);
};
// RGB -> CHW
class Permute {
public:
virtual void Run(const cv::Mat *im, float *data);
};
class CenterCropImg {
public:
virtual void Run(cv::Mat &im, const int crop_size = 224);
};
class ResizeImg {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len,
int size = 0);
};
} // namespace Feature
namespace PaddleClas {
class Normalize {
public:
virtual void Run(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &scale, const bool is_scale = true);
};
// RGB -> CHW
class Permute {
public:
virtual void Run(const cv::Mat *im, float *data);
};
class CenterCropImg {
public:
virtual void Run(cv::Mat &im, const int crop_size = 224);
};
class ResizeImg {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len);
};
} // namespace PaddleClas

View File

@ -0,0 +1,32 @@
feed_var {
name: "x"
alias_name: "x"
is_lod_tensor: false
feed_type: 1
shape: 3
shape: 224
shape: 224
}
feed_var {
name: "boxes"
alias_name: "boxes"
is_lod_tensor: false
feed_type: 1
shape: 6
}
fetch_var {
name: "save_infer_model/scale_0.tmp_1"
alias_name: "features"
is_lod_tensor: false
fetch_type: 1
shape: 512
}
fetch_var {
name: "boxes"
alias_name: "boxes"
is_lod_tensor: false
fetch_type: 1
shape: 6
}

View File

@ -0,0 +1,30 @@
feed_var {
name: "x"
alias_name: "x"
is_lod_tensor: false
feed_type: 1
shape: 3
shape: 224
shape: 224
}
feed_var {
name: "boxes"
alias_name: "boxes"
is_lod_tensor: false
feed_type: 1
shape: 6
}
fetch_var {
name: "save_infer_model/scale_0.tmp_1"
alias_name: "features"
is_lod_tensor: false
fetch_type: 1
shape: 512
}
fetch_var {
name: "boxes"
alias_name: "boxes"
is_lod_tensor: false
fetch_type: 1
shape: 6
}

View File

@ -0,0 +1,29 @@
feed_var {
name: "im_shape"
alias_name: "im_shape"
is_lod_tensor: false
feed_type: 1
shape: 2
}
feed_var {
name: "image"
alias_name: "image"
is_lod_tensor: false
feed_type: 7
shape: -1
shape: -1
shape: 3
}
fetch_var {
name: "save_infer_model/scale_0.tmp_1"
alias_name: "save_infer_model/scale_0.tmp_1"
is_lod_tensor: true
fetch_type: 1
shape: -1
}
fetch_var {
name: "save_infer_model/scale_1.tmp_1"
alias_name: "save_infer_model/scale_1.tmp_1"
is_lod_tensor: false
fetch_type: 2
}

View File

@ -0,0 +1,29 @@
feed_var {
name: "im_shape"
alias_name: "im_shape"
is_lod_tensor: false
feed_type: 1
shape: 2
}
feed_var {
name: "image"
alias_name: "image"
is_lod_tensor: false
feed_type: 7
shape: -1
shape: -1
shape: 3
}
fetch_var {
name: "save_infer_model/scale_0.tmp_1"
alias_name: "save_infer_model/scale_0.tmp_1"
is_lod_tensor: true
fetch_type: 1
shape: -1
}
fetch_var {
name: "save_infer_model/scale_1.tmp_1"
alias_name: "save_infer_model/scale_1.tmp_1"
is_lod_tensor: false
fetch_type: 2
}

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
from paddle_serving_client import Client
@ -22,181 +21,101 @@ import faiss
import os
import pickle
class MainbodyDetect():
"""
pp-shitu mainbody detect.
include preprocess, process, postprocess
return detect results
Attention: Postprocess include num limit and box filter; no nms
"""
def __init__(self):
self.preprocess = DetectionSequential([
DetectionFile2Image(), DetectionNormalize(
[0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True),
DetectionResize(
(640, 640), False, interpolation=2), DetectionTranspose(
(2, 0, 1))
])
self.client = Client()
self.client.load_client_config(
"../../models/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_client/serving_client_conf.prototxt"
)
self.client.connect(['127.0.0.1:9293'])
self.max_det_result = 5
self.conf_threshold = 0.2
def predict(self, imgpath):
im, im_info = self.preprocess(imgpath)
im_shape = np.array(im.shape[1:]).reshape(-1)
scale_factor = np.array(list(im_info['scale_factor'])).reshape(-1)
fetch_map = self.client.predict(
feed={
"image": im,
"im_shape": im_shape,
"scale_factor": scale_factor,
},
fetch=["save_infer_model/scale_0.tmp_1"],
batch=False)
return self.postprocess(fetch_map, imgpath)
def postprocess(self, fetch_map, imgpath):
#1. get top max_det_result
det_results = fetch_map["save_infer_model/scale_0.tmp_1"]
if len(det_results) > self.max_det_result:
boxes_reserved = fetch_map[
"save_infer_model/scale_0.tmp_1"][:self.max_det_result]
else:
boxes_reserved = det_results
#2. do conf threshold
boxes_list = []
for i in range(boxes_reserved.shape[0]):
if (boxes_reserved[i, 1]) > self.conf_threshold:
boxes_list.append(boxes_reserved[i, :])
#3. add origin image box
origin_img = cv2.imread(imgpath)
boxes_list.append(
np.array([0, 1.0, 0, 0, origin_img.shape[1], origin_img.shape[0]]))
return np.array(boxes_list)
rec_nms_thresold = 0.05
rec_score_thres = 0.5
feature_normalize = True
return_k = 1
index_dir = "../../drink_dataset_v1.0/index"
class ObjectRecognition():
"""
pp-shitu object recognion for all objects detected by MainbodyDetect.
include preprocess, process, postprocess
preprocess include preprocess for each image and batching.
Batch process
postprocess include retrieval and nms
"""
def init_index(index_dir):
assert os.path.exists(os.path.join(
index_dir, "vector.index")), "vector.index not found ..."
assert os.path.exists(os.path.join(
index_dir, "id_map.pkl")), "id_map.pkl not found ... "
def __init__(self):
self.client = Client()
self.client.load_client_config(
"../../models/general_PPLCNet_x2_5_lite_v1.0_client/serving_client_conf.prototxt"
)
self.client.connect(["127.0.0.1:9294"])
searcher = faiss.read_index(os.path.join(index_dir, "vector.index"))
self.seq = Sequential([
BGR2RGB(), Resize((224, 224)), Div(255),
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225],
False), Transpose((2, 0, 1))
])
self.searcher, self.id_map = self.init_index()
self.rec_nms_thresold = 0.05
self.rec_score_thres = 0.5
self.feature_normalize = True
self.return_k = 1
def init_index(self):
index_dir = "../../drink_dataset_v1.0/index"
assert os.path.exists(os.path.join(
index_dir, "vector.index")), "vector.index not found ..."
assert os.path.exists(os.path.join(
index_dir, "id_map.pkl")), "id_map.pkl not found ... "
searcher = faiss.read_index(os.path.join(index_dir, "vector.index"))
with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd:
id_map = pickle.load(fd)
return searcher, id_map
def predict(self, det_boxes, imgpath):
#1. preprocess
batch_imgs = []
origin_img = cv2.imread(imgpath)
for i in range(det_boxes.shape[0]):
box = det_boxes[i]
x1, y1, x2, y2 = [int(x) for x in box[2:]]
cropped_img = origin_img[y1:y2, x1:x2, :].copy()
tmp = self.seq(cropped_img)
batch_imgs.append(tmp)
batch_imgs = np.array(batch_imgs)
#2. process
fetch_map = self.client.predict(
feed={"x": batch_imgs}, fetch=["features"], batch=True)
batch_features = fetch_map["features"]
#3. postprocess
if self.feature_normalize:
feas_norm = np.sqrt(
np.sum(np.square(batch_features), axis=1, keepdims=True))
batch_features = np.divide(batch_features, feas_norm)
scores, docs = self.searcher.search(batch_features, self.return_k)
results = []
for i in range(scores.shape[0]):
pred = {}
if scores[i][0] >= self.rec_score_thres:
pred["bbox"] = [int(x) for x in det_boxes[i, 2:]]
pred["rec_docs"] = self.id_map[docs[i][0]].split()[1]
pred["rec_scores"] = scores[i][0]
results.append(pred)
return self.nms_to_rec_results(results)
def nms_to_rec_results(self, results):
filtered_results = []
x1 = np.array([r["bbox"][0] for r in results]).astype("float32")
y1 = np.array([r["bbox"][1] for r in results]).astype("float32")
x2 = np.array([r["bbox"][2] for r in results]).astype("float32")
y2 = np.array([r["bbox"][3] for r in results]).astype("float32")
scores = np.array([r["rec_scores"] for r in results])
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
while order.size > 0:
i = order[0]
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= self.rec_nms_thresold)[0]
order = order[inds + 1]
filtered_results.append(results[i])
return filtered_results
with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd:
id_map = pickle.load(fd)
return searcher, id_map
#get box
def nms_to_rec_results(results, thresh=0.1):
filtered_results = []
x1 = np.array([r["bbox"][0] for r in results]).astype("float32")
y1 = np.array([r["bbox"][1] for r in results]).astype("float32")
x2 = np.array([r["bbox"][2] for r in results]).astype("float32")
y2 = np.array([r["bbox"][3] for r in results]).astype("float32")
scores = np.array([r["rec_scores"] for r in results])
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
while order.size > 0:
i = order[0]
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= thresh)[0]
order = order[inds + 1]
filtered_results.append(results[i])
return filtered_results
def postprocess(fetch_dict, feature_normalize, det_boxes, searcher, id_map,
return_k, rec_score_thres, rec_nms_thresold):
batch_features = fetch_dict["features"]
#do feature norm
if feature_normalize:
feas_norm = np.sqrt(
np.sum(np.square(batch_features), axis=1, keepdims=True))
batch_features = np.divide(batch_features, feas_norm)
scores, docs = searcher.search(batch_features, return_k)
results = []
for i in range(scores.shape[0]):
pred = {}
if scores[i][0] >= rec_score_thres:
pred["bbox"] = [int(x) for x in det_boxes[i, 2:]]
pred["rec_docs"] = id_map[docs[i][0]].split()[1]
pred["rec_scores"] = scores[i][0]
results.append(pred)
#do nms
results = nms_to_rec_results(results, rec_nms_thresold)
return results
#do client
if __name__ == "__main__":
det = MainbodyDetect()
rec = ObjectRecognition()
client = Client()
client.load_client_config([
"../../models/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_client",
"../../models/general_PPLCNet_x2_5_lite_v1.0_client"
])
client.connect(['127.0.0.1:9400'])
#1. get det_results
imgpath = "../../drink_dataset_v1.0/test_images/001.jpeg"
det_results = det.predict(imgpath)
im = cv2.imread("../../drink_dataset_v1.0/test_images/001.jpeg")
im_shape = np.array(im.shape[:2]).reshape(-1)
fetch_map = client.predict(
feed={"image": im,
"im_shape": im_shape},
fetch=["features", "boxes"],
batch=False)
#2. get rec_results
rec_results = rec.predict(det_results, imgpath)
print(rec_results)
#add retrieval procedure
det_boxes = fetch_map["boxes"]
searcher, id_map = init_index(index_dir)
results = postprocess(fetch_map, feature_normalize, det_boxes, searcher,
id_map, return_k, rec_score_thres, rec_nms_thresold)
print(results)

View File

@ -12,16 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from paddle_serving_client import Client
#app
from paddle_serving_app.reader import Sequential, URL2Image, Resize
from paddle_serving_app.reader import CenterCrop, RGB2BGR, Transpose, Div, Normalize
import base64
import time
from paddle_serving_client import Client
def bytes_to_base64(image: bytes) -> str:
"""encode bytes into base64 string
"""
return base64.b64encode(image).decode('utf8')
client = Client()
client.load_client_config("./ResNet50_vd_serving/serving_server_conf.prototxt")
client.load_client_config("./ResNet50_client/serving_client_conf.prototxt")
client.connect(["127.0.0.1:9292"])
label_dict = {}
@ -31,22 +35,17 @@ with open("imagenet.label") as fin:
label_dict[label_idx] = line.strip()
label_idx += 1
#preprocess
seq = Sequential([
URL2Image(), Resize(256), CenterCrop(224), RGB2BGR(), Transpose((2, 0, 1)),
Div(255), Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True)
])
start = time.time()
image_file = "https://paddle-serving.bj.bcebos.com/imagenet-example/daisy.jpg"
image_file = "./daisy.jpg"
for i in range(1):
img = seq(image_file)
fetch_map = client.predict(
feed={"inputs": img}, fetch=["prediction"], batch=False)
prob = max(fetch_map["prediction"][0])
label = label_dict[fetch_map["prediction"][0].tolist().index(prob)].strip(
).replace(",", "")
print("prediction: {}, probability: {}".format(label, prob))
end = time.time()
print(end - start)
start = time.time()
with open(image_file, 'rb') as img_file:
image_data = img_file.read()
image = bytes_to_base64(image_data)
fetch_dict = client.predict(
feed={"inputs": image}, fetch=["prediction"], batch=False)
prob = max(fetch_dict["prediction"][0])
label = label_dict[fetch_dict["prediction"][0].tolist().index(
prob)].strip().replace(",", "")
print("prediction: {}, probability: {}".format(label, prob))
end = time.time()
print(end - start)

View File

@ -112,4 +112,5 @@ bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/MobileNetV3/Mo
- [test_lite_arm_cpu_cpp 使用](docs/test_lite_arm_cpu_cpp.md): 测试基于Paddle-Lite的ARM CPU端c++预测部署功能.
- [test_paddle2onnx 使用](docs/test_paddle2onnx.md)测试Paddle2ONNX的模型转化功能并验证正确性。
- [test_serving_infer_python 使用](docs/test_serving_infer_python.md)测试python serving功能。
- [test_serving_infer_cpp 使用](docs/test_serving_infer_cpp.md)测试cpp serving功能。
- [test_train_fleet_inference_python 使用](./docs/test_train_fleet_inference_python.md)测试基于Python的多机多卡训练与推理等基本功能。

View File

@ -0,0 +1,14 @@
===========================serving_params===========================
model_name:MobileNetV3_large_x1_0
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV3_large_x1_0_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/MobileNetV3_large_x1_0_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/MobileNetV3_large_x1_0_serving/
--serving_client:./deploy/paddleserving/MobileNetV3_large_x1_0_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py

View File

@ -0,0 +1,18 @@
===========================serving_params===========================
model_name:PPShiTu
python:python3.7
cls_inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/general_PPLCNet_x2_5_lite_v1.0_infer.tar
det_inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./models/general_PPLCNet_x2_5_lite_v1.0_infer/
--dirname:./models/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./models/general_PPLCNet_x2_5_lite_v1.0_serving/
--serving_client:./models/general_PPLCNet_x2_5_lite_v1.0_client/
--serving_server:./models/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_serving/
--serving_client:./models/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_client/
serving_dir:./paddleserving/recognition
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py

View File

@ -0,0 +1,14 @@
===========================serving_params===========================
model_name:PPHGNet_small
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPHGNet_small_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPHGNet_small_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPHGNet_small_serving/
--serving_client:./deploy/paddleserving/PPHGNet_small_client/
serving_dir:./deploy/paddleserving
web_service:classification_web_service.py
--use_gpu:0|null
pipline:pipeline_http_client.py

View File

@ -0,0 +1,14 @@
===========================serving_params===========================
model_name:PPHGNet_tiny
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPHGNet_tiny_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPHGNet_tiny_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPHGNet_tiny_serving/
--serving_client:./deploy/paddleserving/PPHGNet_tiny_client/
serving_dir:./deploy/paddleserving
web_service:classification_web_service.py
--use_gpu:0|null
pipline:pipeline_http_client.py

View File

@ -0,0 +1,14 @@
===========================serving_params===========================
model_name:PPLCNet_x0_25
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x0_25_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPLCNet_x0_25_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPLCNet_x0_25_serving/
--serving_client:./deploy/paddleserving/PPLCNet_x0_25_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py

View File

@ -0,0 +1,14 @@
===========================serving_params===========================
model_name:PPLCNet_x0_35
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x0_35_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPLCNet_x0_35_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPLCNet_x0_35_serving/
--serving_client:./deploy/paddleserving/PPLCNet_x0_35_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py

View File

@ -0,0 +1,14 @@
===========================serving_params===========================
model_name:PPLCNet_x0_5
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x0_5_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPLCNet_x0_5_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPLCNet_x0_5_serving/
--serving_client:./deploy/paddleserving/PPLCNet_x0_5_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py

View File

@ -0,0 +1,14 @@
===========================serving_params===========================
model_name:PPLCNet_x0_75
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x0_75_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPLCNet_x0_75_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPLCNet_x0_75_serving/
--serving_client:./deploy/paddleserving/PPLCNet_x0_75_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py

View File

@ -0,0 +1,14 @@
===========================serving_params===========================
model_name:PPLCNet_x1_0
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x1_0_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPLCNet_x1_0_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPLCNet_x1_0_serving/
--serving_client:./deploy/paddleserving/PPLCNet_x1_0_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py

View File

@ -0,0 +1,14 @@
===========================serving_params===========================
model_name:PPLCNet_x1_5
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x1_5_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPLCNet_x1_5_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPLCNet_x1_5_serving/
--serving_client:./deploy/paddleserving/PPLCNet_x1_5_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py

View File

@ -0,0 +1,14 @@
===========================serving_params===========================
model_name:PPLCNet_x2_0
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x2_0_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPLCNet_x2_0_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPLCNet_x2_0_serving/
--serving_client:./deploy/paddleserving/PPLCNet_x2_0_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py

View File

@ -0,0 +1,14 @@
===========================serving_params===========================
model_name:PPLCNet_x2_5
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x2_5_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPLCNet_x2_5_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPLCNet_x2_5_serving/
--serving_client:./deploy/paddleserving/PPLCNet_x2_5_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py

View File

@ -0,0 +1,14 @@
===========================serving_params===========================
model_name:PPLCNetV2_base
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNetV2_base_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPLCNetV2_base_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPLCNetV2_base_serving/
--serving_client:./deploy/paddleserving/PPLCNetV2_base_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py

View File

@ -0,0 +1,14 @@
===========================serving_params===========================
model_name:ResNet50
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet50_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/ResNet50_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/ResNet50_serving/
--serving_client:./deploy/paddleserving/ResNet50_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py

View File

@ -0,0 +1,14 @@
===========================serving_params===========================
model_name:ResNet50_vd
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet50_vd_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/ResNet50_vd_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/ResNet50_vd_serving/
--serving_client:./deploy/paddleserving/ResNet50_vd_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py

View File

@ -0,0 +1,14 @@
===========================serving_params===========================
model_name:SwinTransformer_tiny_patch4_window7_224
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/SwinTransformer_tiny_patch4_window7_224_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/SwinTransformer_tiny_patch4_window7_224_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/SwinTransformer_tiny_patch4_window7_224_serving/
--serving_client:./deploy/paddleserving/SwinTransformer_tiny_patch4_window7_224_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py

View File

@ -0,0 +1,91 @@
# Linux GPU/CPU PYTHON 服务化部署测试
Linux GPU/CPU PYTHON 服务化部署测试的主程序为`test_serving_infer_cpp.sh`可以测试基于Python的模型服务化部署功能。
## 1. 测试结论汇总
- 推理相关:
| 算法名称 | 模型名称 | device_CPU | device_GPU |
| :----: | :----: | :----: | :----: |
| MobileNetV3 | MobileNetV3_large_x1_0 | 支持 | 支持 |
| PP-ShiTu | PPShiTu_general_rec、PPShiTu_mainbody_det | 支持 | 支持 |
| PPHGNet | PPHGNet_small | 支持 | 支持 |
| PPHGNet | PPHGNet_tiny | 支持 | 支持 |
| PPLCNet | PPLCNet_x0_25 | 支持 | 支持 |
| PPLCNet | PPLCNet_x0_35 | 支持 | 支持 |
| PPLCNet | PPLCNet_x0_5 | 支持 | 支持 |
| PPLCNet | PPLCNet_x0_75 | 支持 | 支持 |
| PPLCNet | PPLCNet_x1_0 | 支持 | 支持 |
| PPLCNet | PPLCNet_x1_5 | 支持 | 支持 |
| PPLCNet | PPLCNet_x2_0 | 支持 | 支持 |
| PPLCNet | PPLCNet_x2_5 | 支持 | 支持 |
| PPLCNetV2 | PPLCNetV2_base | 支持 | 支持 |
| ResNet | ResNet50 | 支持 | 支持 |
| ResNet | ResNet50_vd | 支持 | 支持 |
| SwinTransformer | SwinTransformer_tiny_patch4_window7_224 | 支持 | 支持 |
## 2. 测试流程
### 2.1 准备数据
分类模型默认使用`./deploy/paddleserving/daisy.jpg`作为测试输入图片,无需下载
识别模型默认使用`drink_dataset_v1.0/test_images/001.jpeg`作为测试输入图片,在**2.2 准备环境**中会下载好。
### 2.2 准备环境
- 安装PaddlePaddle如果您已经安装了2.2或者以上版本的paddlepaddle那么无需运行下面的命令安装paddlepaddle。
```shell
# 需要安装2.2及以上版本的Paddle
# 安装GPU版本的Paddle
python3.7 -m pip install paddlepaddle-gpu==2.2.0
# 安装CPU版本的Paddle
python3.7 -m pip install paddlepaddle==2.2.0
```
- 安装依赖
```shell
python3.7 -m pip install -r requirements.txt
```
- 安装 PaddleServing 相关组件包括serving_client、serving-app自动编译并安装带自定义OP的 serving_server 包,以及自动下载并解压推理模型
```bash
bash test_tipc/prepare.sh test_tipc/configs/PPLCNet/PPLCNet_x1_0_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt serving_infer
```
### 2.3 功能测试
测试方法如下所示,希望测试不同的模型文件,只需更换为自己的参数配置文件,即可完成对应模型的测试。
```bash
bash test_tipc/test_serving_infer_cpp.sh ${your_params_file}
```
以`PPLCNet_x1_0`的`Linux GPU/CPU C++ 服务化部署测试`为例,命令如下所示。
```bash
bash test_tipc/test_serving_infer_cpp.sh test_tipc/configs/PPLCNet/PPLCNet_x1_0_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
```
输出结果如下,表示命令运行成功。
```
Run successfully with command - PPLCNet_x1_0 - python3.7 test_cpp_serving_client.py > ../../test_tipc/output/PPLCNet_x1_0/server_infer_cpp_gpu_pipeline_batchsize_1.log 2>&1 !
Run successfully with command - PPLCNet_x1_0 - python3.7 test_cpp_serving_client.py > ../../test_tipc/output/PPLCNet_x1_0/server_infer_cpp_cpu_pipeline_batchsize_1.log 2>&1 !
```
预测结果会自动保存在 `./test_tipc/output/PPLCNet_x1_0/server_infer_gpu_pipeline_http_batchsize_1.log` ,可以看到 PaddleServing 的运行结果:
```
WARNING: Logging before InitGoogleLogging() is written to STDERR
I0612 09:55:16.109890 38303 naming_service_thread.cpp:202] brpc::policy::ListNamingService("127.0.0.1:9292"): added 1
I0612 09:55:16.172924 38303 general_model.cpp:490] [client]logid=0,client_cost=60.772ms,server_cost=57.6ms.
prediction: daisy, probability: 0.9099399447441101
0.06275796890258789
```
如果运行失败,也会在终端中输出运行失败的日志信息以及对应的运行命令。可以基于该命令,分析运行失败的原因。

View File

@ -1,6 +1,6 @@
# Linux GPU/CPU PYTHON 服务化部署测试
Linux GPU/CPU PYTHON 服务化部署测试的主程序为`test_serving_infer.sh`可以测试基于Python的模型服务化部署功能。
Linux GPU/CPU PYTHON 服务化部署测试的主程序为`test_serving_infer_python.sh`可以测试基于Python的模型服务化部署功能。
## 1. 测试结论汇总
@ -60,14 +60,14 @@ Linux GPU/CPU PYTHON 服务化部署测试的主程序为`test_serving_infer.sh
测试方法如下所示,希望测试不同的模型文件,只需更换为自己的参数配置文件,即可完成对应模型的测试。
```bash
bash test_tipc/test_serving_infer_python.sh ${your_params_file} lite_train_lite_infer
bash test_tipc/test_serving_infer_python.sh ${your_params_file}
```
以`ResNet50`的`Linux GPU/CPU PYTHON 服务化部署测试`为例,命令如下所示。
```bash
bash test_tipc/test_serving_infer_python.sh test_tipc/configs/ResNet50/ResNet50_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt serving_infer
bash test_tipc/test_serving_infer_python.sh test_tipc/configs/ResNet50/ResNet50_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
```
输出结果如下,表示命令运行成功。

View File

@ -200,15 +200,25 @@ fi
if [[ ${MODE} = "serving_infer" ]]; then
# prepare serving env
python_name=$(func_parser_value "${lines[2]}")
${python_name} -m pip install paddle-serving-server-gpu==0.7.0.post102
${python_name} -m pip install paddle_serving_client==0.7.0
${python_name} -m pip install paddle-serving-app==0.7.0
${python_name} -m pip install paddle_serving_client==0.9.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
${python_name} -m pip install paddle-serving-app==0.9.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
python_name=$(func_parser_value "${lines[2]}")
if [[ ${FILENAME} =~ "cpp" ]]; then
pushd ./deploy/paddleserving
bash build_server.sh ${python_name}
popd
else
${python_name} -m pip install install paddle-serving-server-gpu==0.9.0.post101 -i https://pypi.tuna.tsinghua.edu.cn/simple
fi
if [[ ${model_name} =~ "ShiTu" ]]; then
${python_name} -m pip install faiss-cpu==1.7.1post2 -i https://pypi.tuna.tsinghua.edu.cn/simple
cls_inference_model_url=$(func_parser_value "${lines[3]}")
cls_tar_name=$(func_get_url_file_name "${cls_inference_model_url}")
det_inference_model_url=$(func_parser_value "${lines[4]}")
det_tar_name=$(func_get_url_file_name "${det_inference_model_url}")
cd ./deploy
wget -nc https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v1.0.tar --no-check-certificate
tar -xf drink_dataset_v1.0.tar
mkdir models
cd models
wget -nc ${cls_inference_model_url} && tar xf ${cls_tar_name}

View File

@ -1,353 +0,0 @@
#!/bin/bash
source test_tipc/common_func.sh
FILENAME=$1
dataline=$(awk 'NR==1, NR==19{print}' $FILENAME)
# parser params
IFS=$'\n'
lines=(${dataline})
function func_get_url_file_name(){
strs=$1
IFS="/"
array=(${strs})
tmp=${array[${#array[@]}-1]}
echo ${tmp}
}
# parser serving
model_name=$(func_parser_value "${lines[1]}")
python=$(func_parser_value "${lines[2]}")
trans_model_py=$(func_parser_value "${lines[4]}")
infer_model_dir_key=$(func_parser_key "${lines[5]}")
infer_model_dir_value=$(func_parser_value "${lines[5]}")
model_filename_key=$(func_parser_key "${lines[6]}")
model_filename_value=$(func_parser_value "${lines[6]}")
params_filename_key=$(func_parser_key "${lines[7]}")
params_filename_value=$(func_parser_value "${lines[7]}")
serving_server_key=$(func_parser_key "${lines[8]}")
serving_server_value=$(func_parser_value "${lines[8]}")
serving_client_key=$(func_parser_key "${lines[9]}")
serving_client_value=$(func_parser_value "${lines[9]}")
serving_dir_value=$(func_parser_value "${lines[10]}")
web_service_py=$(func_parser_value "${lines[11]}")
web_use_gpu_key=$(func_parser_key "${lines[12]}")
web_use_gpu_list=$(func_parser_value "${lines[12]}")
pipeline_py=$(func_parser_value "${lines[13]}")
function func_serving_cls(){
LOG_PATH="../../test_tipc/output/${model_name}"
mkdir -p ${LOG_PATH}
status_log="${LOG_PATH}/results_serving.log"
IFS='|'
# pdserving
set_dirname=$(func_set_params "${infer_model_dir_key}" "${infer_model_dir_value}")
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
set_serving_server=$(func_set_params "${serving_server_key}" "${serving_server_value}")
set_serving_client=$(func_set_params "${serving_client_key}" "${serving_client_value}")
trans_model_cmd="${python} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
eval $trans_model_cmd
# modify the alias_name of fetch_var to "outputs"
server_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"prediction\"/' ${serving_server_value}/serving_server_conf.prototxt"
eval ${server_fetch_var_line_cmd}
client_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"prediction\"/' ${serving_client_value}/serving_client_conf.prototxt"
eval ${client_fetch_var_line_cmd}
prototxt_dataline=$(awk 'NR==1, NR==3{print}' ${serving_server_value}/serving_server_conf.prototxt)
IFS=$'\n'
prototxt_lines=(${prototxt_dataline})
feed_var_name=$(func_parser_value "${prototxt_lines[2]}")
IFS='|'
cd ${serving_dir_value}
unset https_proxy
unset http_proxy
# modify the input_name in "classification_web_service.py" to be consistent with feed_var.name in prototxt
set_web_service_feet_var_cmd="sed -i '/preprocess/,/input_imgs}/s/{.*: input_imgs}/{${feed_var_name}: input_imgs}/' ${web_service_py}"
eval ${set_web_service_feet_var_cmd}
model_config=21
serving_server_dir_name=$(func_get_url_file_name "$serving_server_value")
set_model_config_cmd="sed -i '${model_config}s/model_config: .*/model_config: ${serving_server_dir_name}/' config.yml"
eval ${set_model_config_cmd}
for python in ${python[*]}; do
if [[ ${python} = "cpp" ]]; then
for use_gpu in ${web_use_gpu_list[*]}; do
if [ ${use_gpu} = "null" ]; then
web_service_cpp_cmd="${python} -m paddle_serving_server.serve --model ppocr_det_mobile_2.0_serving/ ppocr_rec_mobile_2.0_serving/ --port 9293"
eval $web_service_cmd
sleep 5s
_save_log_path="${LOG_PATH}/server_infer_cpp_cpu_pipeline_usemkldnn_False_threads_4_batchsize_1.log"
pipeline_cmd="${python} ocr_cpp_client.py ppocr_det_mobile_2.0_client/ ppocr_rec_mobile_2.0_client/"
eval $pipeline_cmd
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
sleep 5s
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
else
web_service_cpp_cmd="${python} -m paddle_serving_server.serve --model ppocr_det_mobile_2.0_serving/ ppocr_rec_mobile_2.0_serving/ --port 9293 --gpu_id=0"
eval $web_service_cmd
sleep 5s
_save_log_path="${LOG_PATH}/server_infer_cpp_cpu_pipeline_usemkldnn_False_threads_4_batchsize_1.log"
pipeline_cmd="${python} ocr_cpp_client.py ppocr_det_mobile_2.0_client/ ppocr_rec_mobile_2.0_client/"
eval $pipeline_cmd
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
sleep 5s
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
fi
done
else
# python serving
for use_gpu in ${web_use_gpu_list[*]}; do
if [[ ${use_gpu} = "null" ]]; then
device_type_line=24
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 0/' config.yml"
eval $set_device_type_cmd
devices_line=27
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"\"/' config.yml"
eval $set_devices_cmd
web_service_cmd="${python} ${web_service_py} &"
eval $web_service_cmd
sleep 5s
for pipeline in ${pipeline_py[*]}; do
_save_log_path="${LOG_PATH}/server_infer_cpu_${pipeline%_client*}_batchsize_1.log"
pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1 "
eval $pipeline_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
sleep 5s
done
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
elif [ ${use_gpu} -eq 0 ]; then
if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
continue
fi
if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
continue
fi
if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [[ ${_flag_quant} = "True" ]]; then
continue
fi
device_type_line=24
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 1/' config.yml"
eval $set_device_type_cmd
devices_line=27
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"${use_gpu}\"/' config.yml"
eval $set_devices_cmd
web_service_cmd="${python} ${web_service_py} & "
eval $web_service_cmd
sleep 5s
for pipeline in ${pipeline_py[*]}; do
_save_log_path="${LOG_PATH}/server_infer_gpu_${pipeline%_client*}_batchsize_1.log"
pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1"
eval $pipeline_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
sleep 5s
done
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
else
echo "Does not support hardware [${use_gpu}] other than CPU and GPU Currently!"
fi
done
fi
done
}
function func_serving_rec(){
LOG_PATH="../../../test_tipc/output/${model_name}"
mkdir -p ${LOG_PATH}
status_log="${LOG_PATH}/results_serving.log"
trans_model_py=$(func_parser_value "${lines[5]}")
cls_infer_model_dir_key=$(func_parser_key "${lines[6]}")
cls_infer_model_dir_value=$(func_parser_value "${lines[6]}")
det_infer_model_dir_key=$(func_parser_key "${lines[7]}")
det_infer_model_dir_value=$(func_parser_value "${lines[7]}")
model_filename_key=$(func_parser_key "${lines[8]}")
model_filename_value=$(func_parser_value "${lines[8]}")
params_filename_key=$(func_parser_key "${lines[9]}")
params_filename_value=$(func_parser_value "${lines[9]}")
cls_serving_server_key=$(func_parser_key "${lines[10]}")
cls_serving_server_value=$(func_parser_value "${lines[10]}")
cls_serving_client_key=$(func_parser_key "${lines[11]}")
cls_serving_client_value=$(func_parser_value "${lines[11]}")
det_serving_server_key=$(func_parser_key "${lines[12]}")
det_serving_server_value=$(func_parser_value "${lines[12]}")
det_serving_client_key=$(func_parser_key "${lines[13]}")
det_serving_client_value=$(func_parser_value "${lines[13]}")
serving_dir_value=$(func_parser_value "${lines[14]}")
web_service_py=$(func_parser_value "${lines[15]}")
web_use_gpu_key=$(func_parser_key "${lines[16]}")
web_use_gpu_list=$(func_parser_value "${lines[16]}")
pipeline_py=$(func_parser_value "${lines[17]}")
IFS='|'
# pdserving
cd ./deploy
set_dirname=$(func_set_params "${cls_infer_model_dir_key}" "${cls_infer_model_dir_value}")
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
set_serving_server=$(func_set_params "${cls_serving_server_key}" "${cls_serving_server_value}")
set_serving_client=$(func_set_params "${cls_serving_client_key}" "${cls_serving_client_value}")
cls_trans_model_cmd="${python} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
eval $cls_trans_model_cmd
set_dirname=$(func_set_params "${det_infer_model_dir_key}" "${det_infer_model_dir_value}")
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
set_serving_server=$(func_set_params "${det_serving_server_key}" "${det_serving_server_value}")
set_serving_client=$(func_set_params "${det_serving_client_key}" "${det_serving_client_value}")
det_trans_model_cmd="${python} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
eval $det_trans_model_cmd
# modify the alias_name of fetch_var to "outputs"
server_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"features\"/' $cls_serving_server_value/serving_server_conf.prototxt"
eval ${server_fetch_var_line_cmd}
client_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"features\"/' $cls_serving_client_value/serving_client_conf.prototxt"
eval ${client_fetch_var_line_cmd}
prototxt_dataline=$(awk 'NR==1, NR==3{print}' ${cls_serving_server_value}/serving_server_conf.prototxt)
IFS=$'\n'
prototxt_lines=(${prototxt_dataline})
feed_var_name=$(func_parser_value "${prototxt_lines[2]}")
IFS='|'
cd ${serving_dir_value}
unset https_proxy
unset http_proxy
# modify the input_name in "recognition_web_service.py" to be consistent with feed_var.name in prototxt
set_web_service_feet_var_cmd="sed -i '/preprocess/,/input_imgs}/s/{.*: input_imgs}/{${feed_var_name}: input_imgs}/' ${web_service_py}"
eval ${set_web_service_feet_var_cmd}
for python in ${python[*]}; do
if [[ ${python} = "cpp" ]]; then
for use_gpu in ${web_use_gpu_list[*]}; do
if [ ${use_gpu} = "null" ]; then
web_service_cpp_cmd="${python} web_service_py"
eval $web_service_cmd
sleep 5s
_save_log_path="${LOG_PATH}/server_infer_cpp_cpu_pipeline_usemkldnn_False_threads_4_batchsize_1.log"
pipeline_cmd="${python} ocr_cpp_client.py ppocr_det_mobile_2.0_client/ ppocr_rec_mobile_2.0_client/"
eval $pipeline_cmd
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
sleep 5s
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
else
web_service_cpp_cmd="${python} web_service_py"
eval $web_service_cmd
sleep 5s
_save_log_path="${LOG_PATH}/server_infer_cpp_cpu_pipeline_usemkldnn_False_threads_4_batchsize_1.log"
pipeline_cmd="${python} ocr_cpp_client.py ppocr_det_mobile_2.0_client/ ppocr_rec_mobile_2.0_client/"
eval $pipeline_cmd
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
sleep 5s
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
fi
done
else
# python serving
for use_gpu in ${web_use_gpu_list[*]}; do
if [[ ${use_gpu} = "null" ]]; then
device_type_line=24
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 0/' config.yml"
eval $set_device_type_cmd
devices_line=27
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"\"/' config.yml"
eval $set_devices_cmd
web_service_cmd="${python} ${web_service_py} &"
eval $web_service_cmd
sleep 5s
for pipeline in ${pipeline_py[*]}; do
_save_log_path="${LOG_PATH}/server_infer_cpu_${pipeline%_client*}_batchsize_1.log"
pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1 "
eval $pipeline_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
sleep 5s
done
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
elif [ ${use_gpu} -eq 0 ]; then
if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
continue
fi
if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
continue
fi
if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [[ ${_flag_quant} = "True" ]]; then
continue
fi
device_type_line=24
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 1/' config.yml"
eval $set_device_type_cmd
devices_line=27
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"${use_gpu}\"/' config.yml"
eval $set_devices_cmd
web_service_cmd="${python} ${web_service_py} & "
eval $web_service_cmd
sleep 10s
for pipeline in ${pipeline_py[*]}; do
_save_log_path="${LOG_PATH}/server_infer_gpu_${pipeline%_client*}_batchsize_1.log"
pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1"
eval $pipeline_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
sleep 10s
done
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
else
echo "Does not support hardware [${use_gpu}] other than CPU and GPU Currently!"
fi
done
fi
done
}
# set cuda device
GPUID=$2
if [ ${#GPUID} -le 0 ];then
env=" "
else
env="export CUDA_VISIBLE_DEVICES=${GPUID}"
fi
set CUDA_VISIBLE_DEVICES
eval $env
echo "################### run test ###################"
export Count=0
IFS="|"
if [[ ${model_name} =~ "ShiTu" ]]; then
func_serving_rec
else
func_serving_cls
fi

View File

@ -0,0 +1,262 @@
#!/bin/bash
source test_tipc/common_func.sh
FILENAME=$1
dataline=$(awk 'NR==1, NR==19{print}' $FILENAME)
# parser params
IFS=$'\n'
lines=(${dataline})
function func_get_url_file_name(){
strs=$1
IFS="/"
array=(${strs})
tmp=${array[${#array[@]}-1]}
echo ${tmp}
}
# parser serving
model_name=$(func_parser_value "${lines[1]}")
python=$(func_parser_value "${lines[2]}")
trans_model_py=$(func_parser_value "${lines[4]}")
infer_model_dir_key=$(func_parser_key "${lines[5]}")
infer_model_dir_value=$(func_parser_value "${lines[5]}")
model_filename_key=$(func_parser_key "${lines[6]}")
model_filename_value=$(func_parser_value "${lines[6]}")
params_filename_key=$(func_parser_key "${lines[7]}")
params_filename_value=$(func_parser_value "${lines[7]}")
serving_server_key=$(func_parser_key "${lines[8]}")
serving_server_value=$(func_parser_value "${lines[8]}")
serving_client_key=$(func_parser_key "${lines[9]}")
serving_client_value=$(func_parser_value "${lines[9]}")
serving_dir_value=$(func_parser_value "${lines[10]}")
web_service_py=$(func_parser_value "${lines[11]}")
web_use_gpu_key=$(func_parser_key "${lines[12]}")
web_use_gpu_list=$(func_parser_value "${lines[12]}")
pipeline_py=$(func_parser_value "${lines[13]}")
function func_serving_cls(){
LOG_PATH="test_tipc/output/${model_name}"
mkdir -p ${LOG_PATH}
LOG_PATH="../../${LOG_PATH}"
status_log="${LOG_PATH}/results_serving.log"
IFS='|'
# pdserving
set_dirname=$(func_set_params "${infer_model_dir_key}" "${infer_model_dir_value}")
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
set_serving_server=$(func_set_params "${serving_server_key}" "${serving_server_value}")
set_serving_client=$(func_set_params "${serving_client_key}" "${serving_client_value}")
for python_ in ${python[*]}; do
if [[ ${python_} =~ "python" ]]; then
trans_model_cmd="${python_} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
eval ${trans_model_cmd}
break
fi
done
# modify the alias_name of fetch_var to "outputs"
server_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"prediction\"/' ${serving_server_value}/serving_server_conf.prototxt"
eval ${server_fetch_var_line_cmd}
client_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"prediction\"/' ${serving_client_value}/serving_client_conf.prototxt"
eval ${client_fetch_var_line_cmd}
prototxt_dataline=$(awk 'NR==1, NR==3{print}' ${serving_server_value}/serving_server_conf.prototxt)
IFS=$'\n'
prototxt_lines=(${prototxt_dataline})
feed_var_name=$(func_parser_value "${prototxt_lines[2]}")
IFS='|'
cd ${serving_dir_value}
unset https_proxy
unset http_proxy
for item in ${python[*]}; do
if [[ ${item} =~ "python" ]]; then
python_=${item}
break
fi
done
serving_client_dir_name=$(func_get_url_file_name "$serving_client_value")
set_client_feed_type_cmd="sed -i '/feed_type/,/: .*/s/feed_type: .*/feed_type: 20/' ${serving_client_dir_name}/serving_client_conf.prototxt"
eval ${set_client_feed_type_cmd}
set_client_shape_cmd="sed -i '/shape: 3/,/shape: 3/s/shape: 3/shape: 1/' ${serving_client_dir_name}/serving_client_conf.prototxt"
eval ${set_client_shape_cmd}
set_client_shape224_cmd="sed -i '/shape: 224/,/shape: 224/s/shape: 224//' ${serving_client_dir_name}/serving_client_conf.prototxt"
eval ${set_client_shape224_cmd}
set_client_shape224_cmd="sed -i '/shape: 224/,/shape: 224/s/shape: 224//' ${serving_client_dir_name}/serving_client_conf.prototxt"
eval ${set_client_shape224_cmd}
set_pipeline_load_config_cmd="sed -i '/load_client_config/,/.prototxt/s/.\/.*\/serving_client_conf.prototxt/.\/${serving_client_dir_name}\/serving_client_conf.prototxt/' ${pipeline_py}"
eval ${set_pipeline_load_config_cmd}
set_pipeline_feed_var_cmd="sed -i '/feed=/,/: image}/s/feed={.*: image}/feed={${feed_var_name}: image}/' ${pipeline_py}"
eval ${set_pipeline_feed_var_cmd}
serving_server_dir_name=$(func_get_url_file_name "$serving_server_value")
for use_gpu in ${web_use_gpu_list[*]}; do
if [[ ${use_gpu} = "null" ]]; then
web_service_cpp_cmd="${python_} -m paddle_serving_server.serve --model ${serving_server_dir_name} --op GeneralClasOp --port 9292 &"
eval ${web_service_cpp_cmd}
sleep 5s
_save_log_path="${LOG_PATH}/server_infer_cpp_cpu_pipeline_batchsize_1.log"
pipeline_cmd="${python_} test_cpp_serving_client.py > ${_save_log_path} 2>&1 "
eval ${pipeline_cmd}
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check ${last_status} "${pipeline_cmd}" "${status_log}" "${model_name}"
eval "${python_} -m paddle_serving_server.serve stop"
sleep 5s
else
web_service_cpp_cmd="${python_} -m paddle_serving_server.serve --model ${serving_server_dir_name} --op GeneralClasOp --port 9292 --gpu_id=${use_gpu} &"
eval ${web_service_cpp_cmd}
sleep 8s
_save_log_path="${LOG_PATH}/server_infer_cpp_gpu_pipeline_batchsize_1.log"
pipeline_cmd="${python_} test_cpp_serving_client.py > ${_save_log_path} 2>&1 "
eval ${pipeline_cmd}
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check ${last_status} "${pipeline_cmd}" "${status_log}" "${model_name}"
sleep 5s
eval "${python_} -m paddle_serving_server.serve stop"
fi
done
}
function func_serving_rec(){
LOG_PATH="test_tipc/output/${model_name}"
mkdir -p ${LOG_PATH}
LOG_PATH="../../../${LOG_PATH}"
status_log="${LOG_PATH}/results_serving.log"
trans_model_py=$(func_parser_value "${lines[5]}")
cls_infer_model_dir_key=$(func_parser_key "${lines[6]}")
cls_infer_model_dir_value=$(func_parser_value "${lines[6]}")
det_infer_model_dir_key=$(func_parser_key "${lines[7]}")
det_infer_model_dir_value=$(func_parser_value "${lines[7]}")
model_filename_key=$(func_parser_key "${lines[8]}")
model_filename_value=$(func_parser_value "${lines[8]}")
params_filename_key=$(func_parser_key "${lines[9]}")
params_filename_value=$(func_parser_value "${lines[9]}")
cls_serving_server_key=$(func_parser_key "${lines[10]}")
cls_serving_server_value=$(func_parser_value "${lines[10]}")
cls_serving_client_key=$(func_parser_key "${lines[11]}")
cls_serving_client_value=$(func_parser_value "${lines[11]}")
det_serving_server_key=$(func_parser_key "${lines[12]}")
det_serving_server_value=$(func_parser_value "${lines[12]}")
det_serving_client_key=$(func_parser_key "${lines[13]}")
det_serving_client_value=$(func_parser_value "${lines[13]}")
serving_dir_value=$(func_parser_value "${lines[14]}")
web_service_py=$(func_parser_value "${lines[15]}")
web_use_gpu_key=$(func_parser_key "${lines[16]}")
web_use_gpu_list=$(func_parser_value "${lines[16]}")
pipeline_py=$(func_parser_value "${lines[17]}")
IFS='|'
for python_ in ${python[*]}; do
if [[ ${python_} =~ "python" ]]; then
python_interp=${python_}
break
fi
done
# pdserving
cd ./deploy
set_dirname=$(func_set_params "${cls_infer_model_dir_key}" "${cls_infer_model_dir_value}")
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
set_serving_server=$(func_set_params "${cls_serving_server_key}" "${cls_serving_server_value}")
set_serving_client=$(func_set_params "${cls_serving_client_key}" "${cls_serving_client_value}")
cls_trans_model_cmd="${python_interp} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
eval ${cls_trans_model_cmd}
set_dirname=$(func_set_params "${det_infer_model_dir_key}" "${det_infer_model_dir_value}")
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
set_serving_server=$(func_set_params "${det_serving_server_key}" "${det_serving_server_value}")
set_serving_client=$(func_set_params "${det_serving_client_key}" "${det_serving_client_value}")
det_trans_model_cmd="${python_interp} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
eval ${det_trans_model_cmd}
cp_prototxt_cmd="cp ./paddleserving/recognition/preprocess/general_PPLCNet_x2_5_lite_v1.0_serving/*.prototxt ${cls_serving_server_value}"
eval ${cp_prototxt_cmd}
cp_prototxt_cmd="cp ./paddleserving/recognition/preprocess/general_PPLCNet_x2_5_lite_v1.0_client/*.prototxt ${cls_serving_client_value}"
eval ${cp_prototxt_cmd}
cp_prototxt_cmd="cp ./paddleserving/recognition/preprocess/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_client/*.prototxt ${det_serving_client_value}"
eval ${cp_prototxt_cmd}
cp_prototxt_cmd="cp ./paddleserving/recognition/preprocess/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_serving/*.prototxt ${det_serving_server_value}"
eval ${cp_prototxt_cmd}
prototxt_dataline=$(awk 'NR==1, NR==3{print}' ${cls_serving_server_value}/serving_server_conf.prototxt)
IFS=$'\n'
prototxt_lines=(${prototxt_dataline})
feed_var_name=$(func_parser_value "${prototxt_lines[2]}")
IFS='|'
cd ${serving_dir_value}
unset https_proxy
unset http_proxy
export SERVING_BIN=${PWD}/../Serving/server-build-gpu-opencv/core/general-server/serving
for use_gpu in ${web_use_gpu_list[*]}; do
if [ ${use_gpu} = "null" ]; then
det_serving_server_dir_name=$(func_get_url_file_name "$det_serving_server_value")
web_service_cpp_cmd="${python_interp} -m paddle_serving_server.serve --model ../../${det_serving_server_value} ../../${cls_serving_server_value} --op GeneralPicodetOp GeneralFeatureExtractOp --port 9400 &"
eval ${web_service_cpp_cmd}
sleep 5s
_save_log_path="${LOG_PATH}/server_infer_cpp_cpu_batchsize_1.log"
pipeline_cmd="${python_interp} ${pipeline_py} > ${_save_log_path} 2>&1 "
eval ${pipeline_cmd}
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check ${last_status} "${pipeline_cmd}" "${status_log}" "${model_name}"
eval "${python_} -m paddle_serving_server.serve stop"
sleep 5s
else
det_serving_server_dir_name=$(func_get_url_file_name "$det_serving_server_value")
web_service_cpp_cmd="${python_interp} -m paddle_serving_server.serve --model ../../${det_serving_server_value} ../../${cls_serving_server_value} --op GeneralPicodetOp GeneralFeatureExtractOp --port 9400 --gpu_id=${use_gpu} &"
eval ${web_service_cpp_cmd}
sleep 5s
_save_log_path="${LOG_PATH}/server_infer_cpp_gpu_batchsize_1.log"
pipeline_cmd="${python_interp} ${pipeline_py} > ${_save_log_path} 2>&1 "
eval ${pipeline_cmd}
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check ${last_status} "${pipeline_cmd}" "${status_log}" "${model_name}"
eval "${python_} -m paddle_serving_server.serve stop"
sleep 5s
fi
done
}
# set cuda device
GPUID=$2
if [ ${#GPUID} -le 0 ];then
env=" "
else
env="export CUDA_VISIBLE_DEVICES=${GPUID}"
fi
set CUDA_VISIBLE_DEVICES
eval ${env}
echo "################### run test ###################"
export Count=0
IFS="|"
if [[ ${model_name} =~ "ShiTu" ]]; then
func_serving_rec
else
func_serving_cls
fi

View File

@ -0,0 +1,309 @@
#!/bin/bash
source test_tipc/common_func.sh
FILENAME=$1
dataline=$(awk 'NR==1, NR==19{print}' $FILENAME)
# parser params
IFS=$'\n'
lines=(${dataline})
function func_get_url_file_name(){
strs=$1
IFS="/"
array=(${strs})
tmp=${array[${#array[@]}-1]}
echo ${tmp}
}
# parser serving
model_name=$(func_parser_value "${lines[1]}")
python=$(func_parser_value "${lines[2]}")
trans_model_py=$(func_parser_value "${lines[4]}")
infer_model_dir_key=$(func_parser_key "${lines[5]}")
infer_model_dir_value=$(func_parser_value "${lines[5]}")
model_filename_key=$(func_parser_key "${lines[6]}")
model_filename_value=$(func_parser_value "${lines[6]}")
params_filename_key=$(func_parser_key "${lines[7]}")
params_filename_value=$(func_parser_value "${lines[7]}")
serving_server_key=$(func_parser_key "${lines[8]}")
serving_server_value=$(func_parser_value "${lines[8]}")
serving_client_key=$(func_parser_key "${lines[9]}")
serving_client_value=$(func_parser_value "${lines[9]}")
serving_dir_value=$(func_parser_value "${lines[10]}")
web_service_py=$(func_parser_value "${lines[11]}")
web_use_gpu_key=$(func_parser_key "${lines[12]}")
web_use_gpu_list=$(func_parser_value "${lines[12]}")
pipeline_py=$(func_parser_value "${lines[13]}")
function func_serving_cls(){
LOG_PATH="test_tipc/output/${model_name}"
mkdir -p ${LOG_PATH}
LOG_PATH="../../${LOG_PATH}"
status_log="${LOG_PATH}/results_serving.log"
IFS='|'
# pdserving
set_dirname=$(func_set_params "${infer_model_dir_key}" "${infer_model_dir_value}")
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
set_serving_server=$(func_set_params "${serving_server_key}" "${serving_server_value}")
set_serving_client=$(func_set_params "${serving_client_key}" "${serving_client_value}")
for python_ in ${python[*]}; do
if [[ ${python_} =~ "python" ]]; then
trans_model_cmd="${python_} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
eval ${trans_model_cmd}
break
fi
done
# modify the alias_name of fetch_var to "outputs"
server_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"prediction\"/' ${serving_server_value}/serving_server_conf.prototxt"
eval ${server_fetch_var_line_cmd}
client_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"prediction\"/' ${serving_client_value}/serving_client_conf.prototxt"
eval ${client_fetch_var_line_cmd}
prototxt_dataline=$(awk 'NR==1, NR==3{print}' ${serving_server_value}/serving_server_conf.prototxt)
IFS=$'\n'
prototxt_lines=(${prototxt_dataline})
feed_var_name=$(func_parser_value "${prototxt_lines[2]}")
IFS='|'
cd ${serving_dir_value}
unset https_proxy
unset http_proxy
# python serving
# modify the input_name in "classification_web_service.py" to be consistent with feed_var.name in prototxt
set_web_service_feed_var_cmd="sed -i '/preprocess/,/input_imgs}/s/{.*: input_imgs}/{${feed_var_name}: input_imgs}/' ${web_service_py}"
eval ${set_web_service_feed_var_cmd}
model_config=21
serving_server_dir_name=$(func_get_url_file_name "$serving_server_value")
set_model_config_cmd="sed -i '${model_config}s/model_config: .*/model_config: ${serving_server_dir_name}/' config.yml"
eval ${set_model_config_cmd}
for use_gpu in ${web_use_gpu_list[*]}; do
if [[ ${use_gpu} = "null" ]]; then
device_type_line=24
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 0/' config.yml"
eval ${set_device_type_cmd}
devices_line=27
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"\"/' config.yml"
eval ${set_devices_cmd}
web_service_cmd="${python_} ${web_service_py} &"
eval ${web_service_cmd}
sleep 5s
for pipeline in ${pipeline_py[*]}; do
_save_log_path="${LOG_PATH}/server_infer_cpu_${pipeline%_client*}_batchsize_1.log"
pipeline_cmd="${python_} ${pipeline} > ${_save_log_path} 2>&1 "
eval ${pipeline_cmd}
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
sleep 5s
done
eval "${python_} -m paddle_serving_server.serve stop"
elif [ ${use_gpu} -eq 0 ]; then
if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
continue
fi
if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
continue
fi
if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [[ ${_flag_quant} = "True" ]]; then
continue
fi
device_type_line=24
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 1/' config.yml"
eval ${set_device_type_cmd}
devices_line=27
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"${use_gpu}\"/' config.yml"
eval ${set_devices_cmd}
web_service_cmd="${python_} ${web_service_py} & "
eval ${web_service_cmd}
sleep 5s
for pipeline in ${pipeline_py[*]}; do
_save_log_path="${LOG_PATH}/server_infer_gpu_${pipeline%_client*}_batchsize_1.log"
pipeline_cmd="${python_} ${pipeline} > ${_save_log_path} 2>&1"
eval ${pipeline_cmd}
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
sleep 5s
done
eval "${python_} -m paddle_serving_server.serve stop"
else
echo "Does not support hardware [${use_gpu}] other than CPU and GPU Currently!"
fi
done
}
function func_serving_rec(){
LOG_PATH="test_tipc/output/${model_name}"
mkdir -p ${LOG_PATH}
LOG_PATH="../../../${LOG_PATH}"
status_log="${LOG_PATH}/results_serving.log"
trans_model_py=$(func_parser_value "${lines[5]}")
cls_infer_model_dir_key=$(func_parser_key "${lines[6]}")
cls_infer_model_dir_value=$(func_parser_value "${lines[6]}")
det_infer_model_dir_key=$(func_parser_key "${lines[7]}")
det_infer_model_dir_value=$(func_parser_value "${lines[7]}")
model_filename_key=$(func_parser_key "${lines[8]}")
model_filename_value=$(func_parser_value "${lines[8]}")
params_filename_key=$(func_parser_key "${lines[9]}")
params_filename_value=$(func_parser_value "${lines[9]}")
cls_serving_server_key=$(func_parser_key "${lines[10]}")
cls_serving_server_value=$(func_parser_value "${lines[10]}")
cls_serving_client_key=$(func_parser_key "${lines[11]}")
cls_serving_client_value=$(func_parser_value "${lines[11]}")
det_serving_server_key=$(func_parser_key "${lines[12]}")
det_serving_server_value=$(func_parser_value "${lines[12]}")
det_serving_client_key=$(func_parser_key "${lines[13]}")
det_serving_client_value=$(func_parser_value "${lines[13]}")
serving_dir_value=$(func_parser_value "${lines[14]}")
web_service_py=$(func_parser_value "${lines[15]}")
web_use_gpu_key=$(func_parser_key "${lines[16]}")
web_use_gpu_list=$(func_parser_value "${lines[16]}")
pipeline_py=$(func_parser_value "${lines[17]}")
IFS='|'
for python_ in ${python[*]}; do
if [[ ${python_} =~ "python" ]]; then
python_interp=${python_}
break
fi
done
# pdserving
cd ./deploy
set_dirname=$(func_set_params "${cls_infer_model_dir_key}" "${cls_infer_model_dir_value}")
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
set_serving_server=$(func_set_params "${cls_serving_server_key}" "${cls_serving_server_value}")
set_serving_client=$(func_set_params "${cls_serving_client_key}" "${cls_serving_client_value}")
cls_trans_model_cmd="${python_interp} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
eval ${cls_trans_model_cmd}
set_dirname=$(func_set_params "${det_infer_model_dir_key}" "${det_infer_model_dir_value}")
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
set_serving_server=$(func_set_params "${det_serving_server_key}" "${det_serving_server_value}")
set_serving_client=$(func_set_params "${det_serving_client_key}" "${det_serving_client_value}")
det_trans_model_cmd="${python_interp} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
eval ${det_trans_model_cmd}
# modify the alias_name of fetch_var to "outputs"
server_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"features\"/' $cls_serving_server_value/serving_server_conf.prototxt"
eval ${server_fetch_var_line_cmd}
client_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"features\"/' $cls_serving_client_value/serving_client_conf.prototxt"
eval ${client_fetch_var_line_cmd}
prototxt_dataline=$(awk 'NR==1, NR==3{print}' ${cls_serving_server_value}/serving_server_conf.prototxt)
IFS=$'\n'
prototxt_lines=(${prototxt_dataline})
feed_var_name=$(func_parser_value "${prototxt_lines[2]}")
IFS='|'
cd ${serving_dir_value}
unset https_proxy
unset http_proxy
# modify the input_name in "recognition_web_service.py" to be consistent with feed_var.name in prototxt
set_web_service_feed_var_cmd="sed -i '/preprocess/,/input_imgs}/s/{.*: input_imgs}/{${feed_var_name}: input_imgs}/' ${web_service_py}"
eval ${set_web_service_feed_var_cmd}
# python serving
for use_gpu in ${web_use_gpu_list[*]}; do
if [[ ${use_gpu} = "null" ]]; then
device_type_line=24
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 0/' config.yml"
eval ${set_device_type_cmd}
devices_line=27
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"\"/' config.yml"
eval ${set_devices_cmd}
web_service_cmd="${python} ${web_service_py} &"
eval ${web_service_cmd}
sleep 5s
for pipeline in ${pipeline_py[*]}; do
_save_log_path="${LOG_PATH}/server_infer_cpu_${pipeline%_client*}_batchsize_1.log"
pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1 "
eval ${pipeline_cmd}
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
sleep 5s
done
eval "${python_} -m paddle_serving_server.serve stop"
elif [ ${use_gpu} -eq 0 ]; then
if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
continue
fi
if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
continue
fi
if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [[ ${_flag_quant} = "True" ]]; then
continue
fi
device_type_line=24
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 1/' config.yml"
eval ${set_device_type_cmd}
devices_line=27
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"${use_gpu}\"/' config.yml"
eval ${set_devices_cmd}
web_service_cmd="${python} ${web_service_py} & "
eval ${web_service_cmd}
sleep 10s
for pipeline in ${pipeline_py[*]}; do
_save_log_path="${LOG_PATH}/server_infer_gpu_${pipeline%_client*}_batchsize_1.log"
pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1"
eval ${pipeline_cmd}
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
sleep 10s
done
eval "${python_} -m paddle_serving_server.serve stop"
else
echo "Does not support hardware [${use_gpu}] other than CPU and GPU Currently!"
fi
done
}
# set cuda device
GPUID=$2
if [ ${#GPUID} -le 0 ];then
env=" "
else
env="export CUDA_VISIBLE_DEVICES=${GPUID}"
fi
set CUDA_VISIBLE_DEVICES
eval ${env}
echo "################### run test ###################"
export Count=0
IFS="|"
if [[ ${model_name} =~ "ShiTu" ]]; then
func_serving_rec
else
func_serving_cls
fi