mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
add cpp serving infer(except PPShiTu)
This commit is contained in:
parent
3fd426aa01
commit
ac821bbe4d
74
deploy/paddleserving/build_server.sh
Normal file
74
deploy/paddleserving/build_server.sh
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
#使用镜像:
|
||||||
|
#registry.baidubce.com/paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82
|
||||||
|
|
||||||
|
#编译Serving Server:
|
||||||
|
|
||||||
|
#client和app可以直接使用release版本
|
||||||
|
|
||||||
|
#server因为加入了自定义OP,需要重新编译
|
||||||
|
|
||||||
|
#默认编译时的${PWD}=PaddleClas/deploy/paddleserving/
|
||||||
|
|
||||||
|
python_name=${1:-'python'}
|
||||||
|
|
||||||
|
apt-get update
|
||||||
|
apt install -y libcurl4-openssl-dev libbz2-dev
|
||||||
|
wget https://paddle-serving.bj.bcebos.com/others/centos_ssl.tar && tar xf centos_ssl.tar && rm -rf centos_ssl.tar && mv libcrypto.so.1.0.2k /usr/lib/libcrypto.so.1.0.2k && mv libssl.so.1.0.2k /usr/lib/libssl.so.1.0.2k && ln -sf /usr/lib/libcrypto.so.1.0.2k /usr/lib/libcrypto.so.10 && ln -sf /usr/lib/libssl.so.1.0.2k /usr/lib/libssl.so.10 && ln -sf /usr/lib/libcrypto.so.10 /usr/lib/libcrypto.so && ln -sf /usr/lib/libssl.so.10 /usr/lib/libssl.so
|
||||||
|
|
||||||
|
# 安装go依赖
|
||||||
|
rm -rf /usr/local/go
|
||||||
|
wget -qO- https://paddle-ci.cdn.bcebos.com/go1.17.2.linux-amd64.tar.gz | tar -xz -C /usr/local
|
||||||
|
export GOROOT=/usr/local/go
|
||||||
|
export GOPATH=/root/gopath
|
||||||
|
export PATH=$PATH:$GOPATH/bin:$GOROOT/bin
|
||||||
|
go env -w GO111MODULE=on
|
||||||
|
go env -w GOPROXY=https://goproxy.cn,direct
|
||||||
|
go install github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway@v1.15.2
|
||||||
|
go install github.com/grpc-ecosystem/grpc-gateway/protoc-gen-swagger@v1.15.2
|
||||||
|
go install github.com/golang/protobuf/protoc-gen-go@v1.4.3
|
||||||
|
go install google.golang.org/grpc@v1.33.0
|
||||||
|
go env -w GO111MODULE=auto
|
||||||
|
|
||||||
|
# 下载opencv库
|
||||||
|
wget https://paddle-qa.bj.bcebos.com/PaddleServing/opencv3.tar.gz && tar -xvf opencv3.tar.gz && rm -rf opencv3.tar.gz
|
||||||
|
export OPENCV_DIR=$PWD/opencv3
|
||||||
|
|
||||||
|
# clone Serving
|
||||||
|
git clone https://github.com/PaddlePaddle/Serving.git -b develop --depth=1
|
||||||
|
cd Serving # PaddleClas/deploy/paddleserving/Serving
|
||||||
|
export Serving_repo_path=$PWD
|
||||||
|
git submodule update --init --recursive
|
||||||
|
${python_name} -m pip install -r python/requirements.txt
|
||||||
|
|
||||||
|
# set env
|
||||||
|
export PYTHON_INCLUDE_DIR=$(${python_name} -c "from distutils.sysconfig import get_python_inc; print(get_python_inc())")
|
||||||
|
export PYTHON_LIBRARIES=$(${python_name} -c "import distutils.sysconfig as sysconfig; print(sysconfig.get_config_var('LIBDIR'))")
|
||||||
|
export PYTHON_EXECUTABLE=`which ${python_name}`
|
||||||
|
|
||||||
|
export CUDA_PATH='/usr/local/cuda'
|
||||||
|
export CUDNN_LIBRARY='/usr/local/cuda/lib64/'
|
||||||
|
export CUDA_CUDART_LIBRARY='/usr/local/cuda/lib64/'
|
||||||
|
export TENSORRT_LIBRARY_PATH='/usr/local/TensorRT6-cuda10.1-cudnn7/targets/x86_64-linux-gnu/'
|
||||||
|
|
||||||
|
# cp 自定义OP代码
|
||||||
|
\cp ../preprocess/general_clas_op.* ${Serving_repo_path}/core/general-server/op
|
||||||
|
\cp ../preprocess/preprocess_op.* ${Serving_repo_path}/core/predictor/tools/pp_shitu_tools
|
||||||
|
|
||||||
|
# 编译Server, export SERVING_BIN
|
||||||
|
mkdir server-build-gpu-opencv && cd server-build-gpu-opencv
|
||||||
|
cmake -DPYTHON_INCLUDE_DIR=$PYTHON_INCLUDE_DIR \
|
||||||
|
-DPYTHON_LIBRARIES=$PYTHON_LIBRARIES \
|
||||||
|
-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
|
||||||
|
-DCUDA_TOOLKIT_ROOT_DIR=${CUDA_PATH} \
|
||||||
|
-DCUDNN_LIBRARY=${CUDNN_LIBRARY} \
|
||||||
|
-DCUDA_CUDART_LIBRARY=${CUDA_CUDART_LIBRARY} \
|
||||||
|
-DTENSORRT_ROOT=${TENSORRT_LIBRARY_PATH} \
|
||||||
|
-DOPENCV_DIR=${OPENCV_DIR} \
|
||||||
|
-DWITH_OPENCV=ON \
|
||||||
|
-DSERVER=ON \
|
||||||
|
-DWITH_GPU=ON ..
|
||||||
|
make -j32
|
||||||
|
|
||||||
|
${python_name} -m pip install python/dist/paddle*
|
||||||
|
export SERVING_BIN=$PWD/core/general-server/serving
|
||||||
|
cd ../../
|
206
deploy/paddleserving/preprocess/general_clas_op.cpp
Normal file
206
deploy/paddleserving/preprocess/general_clas_op.cpp
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "core/general-server/op/general_clas_op.h"
|
||||||
|
#include "core/predictor/framework/infer.h"
|
||||||
|
#include "core/predictor/framework/memory.h"
|
||||||
|
#include "core/predictor/framework/resource.h"
|
||||||
|
#include "core/util/include/timer.h"
|
||||||
|
#include <algorithm>
|
||||||
|
#include <iostream>
|
||||||
|
#include <memory>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
namespace baidu {
|
||||||
|
namespace paddle_serving {
|
||||||
|
namespace serving {
|
||||||
|
|
||||||
|
using baidu::paddle_serving::Timer;
|
||||||
|
using baidu::paddle_serving::predictor::MempoolWrapper;
|
||||||
|
using baidu::paddle_serving::predictor::general_model::Tensor;
|
||||||
|
using baidu::paddle_serving::predictor::general_model::Response;
|
||||||
|
using baidu::paddle_serving::predictor::general_model::Request;
|
||||||
|
using baidu::paddle_serving::predictor::InferManager;
|
||||||
|
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
|
||||||
|
|
||||||
|
int GeneralClasOp::inference() {
|
||||||
|
VLOG(2) << "Going to run inference";
|
||||||
|
const std::vector<std::string> pre_node_names = pre_names();
|
||||||
|
if (pre_node_names.size() != 1) {
|
||||||
|
LOG(ERROR) << "This op(" << op_name()
|
||||||
|
<< ") can only have one predecessor op, but received "
|
||||||
|
<< pre_node_names.size();
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
const std::string pre_name = pre_node_names[0];
|
||||||
|
|
||||||
|
const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name);
|
||||||
|
if (!input_blob) {
|
||||||
|
LOG(ERROR) << "input_blob is nullptr,error";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
uint64_t log_id = input_blob->GetLogId();
|
||||||
|
VLOG(2) << "(logid=" << log_id << ") Get precedent op name: " << pre_name;
|
||||||
|
|
||||||
|
GeneralBlob *output_blob = mutable_data<GeneralBlob>();
|
||||||
|
if (!output_blob) {
|
||||||
|
LOG(ERROR) << "output_blob is nullptr,error";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
output_blob->SetLogId(log_id);
|
||||||
|
|
||||||
|
if (!input_blob) {
|
||||||
|
LOG(ERROR) << "(logid=" << log_id
|
||||||
|
<< ") Failed mutable depended argument, op:" << pre_name;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const TensorVector *in = &input_blob->tensor_vector;
|
||||||
|
TensorVector *out = &output_blob->tensor_vector;
|
||||||
|
|
||||||
|
int batch_size = input_blob->_batch_size;
|
||||||
|
output_blob->_batch_size = batch_size;
|
||||||
|
VLOG(2) << "(logid=" << log_id << ") infer batch size: " << batch_size;
|
||||||
|
|
||||||
|
Timer timeline;
|
||||||
|
int64_t start = timeline.TimeStampUS();
|
||||||
|
timeline.Start();
|
||||||
|
|
||||||
|
// only support string type
|
||||||
|
|
||||||
|
char *total_input_ptr = static_cast<char *>(in->at(0).data.data());
|
||||||
|
std::string base64str = total_input_ptr;
|
||||||
|
|
||||||
|
cv::Mat img = Base2Mat(base64str);
|
||||||
|
|
||||||
|
// RGB2BGR
|
||||||
|
cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
|
||||||
|
|
||||||
|
// Resize
|
||||||
|
cv::Mat resize_img;
|
||||||
|
resize_op_.Run(img, resize_img, resize_short_size_);
|
||||||
|
|
||||||
|
// CenterCrop
|
||||||
|
crop_op_.Run(resize_img, crop_size_);
|
||||||
|
|
||||||
|
// Normalize
|
||||||
|
normalize_op_.Run(&resize_img, mean_, scale_, is_scale_);
|
||||||
|
|
||||||
|
// Permute
|
||||||
|
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
|
||||||
|
permute_op_.Run(&resize_img, input.data());
|
||||||
|
float maxValue = *max_element(input.begin(), input.end());
|
||||||
|
float minValue = *min_element(input.begin(), input.end());
|
||||||
|
|
||||||
|
TensorVector *real_in = new TensorVector();
|
||||||
|
if (!real_in) {
|
||||||
|
LOG(ERROR) << "real_in is nullptr,error";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> input_shape;
|
||||||
|
int in_num = 0;
|
||||||
|
void *databuf_data = NULL;
|
||||||
|
char *databuf_char = NULL;
|
||||||
|
size_t databuf_size = 0;
|
||||||
|
|
||||||
|
input_shape = {1, 3, resize_img.rows, resize_img.cols};
|
||||||
|
in_num = std::accumulate(input_shape.begin(), input_shape.end(), 1,
|
||||||
|
std::multiplies<int>());
|
||||||
|
|
||||||
|
databuf_size = in_num * sizeof(float);
|
||||||
|
databuf_data = MempoolWrapper::instance().malloc(databuf_size);
|
||||||
|
if (!databuf_data) {
|
||||||
|
LOG(ERROR) << "Malloc failed, size: " << databuf_size;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
memcpy(databuf_data, input.data(), databuf_size);
|
||||||
|
databuf_char = reinterpret_cast<char *>(databuf_data);
|
||||||
|
paddle::PaddleBuf paddleBuf(databuf_char, databuf_size);
|
||||||
|
paddle::PaddleTensor tensor_in;
|
||||||
|
tensor_in.name = in->at(0).name;
|
||||||
|
tensor_in.dtype = paddle::PaddleDType::FLOAT32;
|
||||||
|
tensor_in.shape = {1, 3, resize_img.rows, resize_img.cols};
|
||||||
|
tensor_in.lod = in->at(0).lod;
|
||||||
|
tensor_in.data = paddleBuf;
|
||||||
|
real_in->push_back(tensor_in);
|
||||||
|
|
||||||
|
if (InferManager::instance().infer(engine_name().c_str(), real_in, out,
|
||||||
|
batch_size)) {
|
||||||
|
LOG(ERROR) << "(logid=" << log_id
|
||||||
|
<< ") Failed do infer in fluid model: " << engine_name().c_str();
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t end = timeline.TimeStampUS();
|
||||||
|
CopyBlobInfo(input_blob, output_blob);
|
||||||
|
AddBlobInfo(output_blob, start);
|
||||||
|
AddBlobInfo(output_blob, end);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
cv::Mat GeneralClasOp::Base2Mat(std::string &base64_data) {
|
||||||
|
cv::Mat img;
|
||||||
|
std::string s_mat;
|
||||||
|
s_mat = base64Decode(base64_data.data(), base64_data.size());
|
||||||
|
std::vector<char> base64_img(s_mat.begin(), s_mat.end());
|
||||||
|
img = cv::imdecode(base64_img, cv::IMREAD_COLOR); // CV_LOAD_IMAGE_COLOR
|
||||||
|
return img;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string GeneralClasOp::base64Decode(const char *Data, int DataByte) {
|
||||||
|
const char DecodeTable[] = {
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
62, // '+'
|
||||||
|
0, 0, 0,
|
||||||
|
63, // '/'
|
||||||
|
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
|
||||||
|
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z'
|
||||||
|
0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
|
||||||
|
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z'
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string strDecode;
|
||||||
|
int nValue;
|
||||||
|
int i = 0;
|
||||||
|
while (i < DataByte) {
|
||||||
|
if (*Data != '\r' && *Data != '\n') {
|
||||||
|
nValue = DecodeTable[*Data++] << 18;
|
||||||
|
nValue += DecodeTable[*Data++] << 12;
|
||||||
|
strDecode += (nValue & 0x00FF0000) >> 16;
|
||||||
|
if (*Data != '=') {
|
||||||
|
nValue += DecodeTable[*Data++] << 6;
|
||||||
|
strDecode += (nValue & 0x0000FF00) >> 8;
|
||||||
|
if (*Data != '=') {
|
||||||
|
nValue += DecodeTable[*Data++];
|
||||||
|
strDecode += nValue & 0x000000FF;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
i += 4;
|
||||||
|
} else // 回车换行,跳过
|
||||||
|
{
|
||||||
|
Data++;
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strDecode;
|
||||||
|
}
|
||||||
|
DEFINE_OP(GeneralClasOp);
|
||||||
|
} // namespace serving
|
||||||
|
} // namespace paddle_serving
|
||||||
|
} // namespace baidu
|
70
deploy/paddleserving/preprocess/general_clas_op.h
Normal file
70
deploy/paddleserving/preprocess/general_clas_op.h
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "core/general-server/general_model_service.pb.h"
|
||||||
|
#include "core/general-server/op/general_infer_helper.h"
|
||||||
|
#include "core/predictor/tools/pp_shitu_tools/preprocess_op.h"
|
||||||
|
#include "paddle_inference_api.h" // NOLINT
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "opencv2/core.hpp"
|
||||||
|
#include "opencv2/imgcodecs.hpp"
|
||||||
|
#include "opencv2/imgproc.hpp"
|
||||||
|
#include <chrono>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <iostream>
|
||||||
|
#include <ostream>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <cstring>
|
||||||
|
#include <fstream>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
namespace baidu {
|
||||||
|
namespace paddle_serving {
|
||||||
|
namespace serving {
|
||||||
|
|
||||||
|
class GeneralClasOp
|
||||||
|
: public baidu::paddle_serving::predictor::OpWithChannel<GeneralBlob> {
|
||||||
|
public:
|
||||||
|
typedef std::vector<paddle::PaddleTensor> TensorVector;
|
||||||
|
|
||||||
|
DECLARE_OP(GeneralClasOp);
|
||||||
|
|
||||||
|
int inference();
|
||||||
|
|
||||||
|
private:
|
||||||
|
// clas preprocess
|
||||||
|
std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
|
||||||
|
std::vector<float> scale_ = {0.229f, 0.224f, 0.225f};
|
||||||
|
bool is_scale_ = true;
|
||||||
|
|
||||||
|
int resize_short_size_ = 256;
|
||||||
|
int crop_size_ = 224;
|
||||||
|
|
||||||
|
PaddleClas::ResizeImg resize_op_;
|
||||||
|
PaddleClas::Normalize normalize_op_;
|
||||||
|
PaddleClas::Permute permute_op_;
|
||||||
|
PaddleClas::CenterCropImg crop_op_;
|
||||||
|
|
||||||
|
// read pics
|
||||||
|
cv::Mat Base2Mat(std::string &base64_data);
|
||||||
|
std::string base64Decode(const char *Data, int DataByte);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace serving
|
||||||
|
} // namespace paddle_serving
|
||||||
|
} // namespace baidu
|
149
deploy/paddleserving/preprocess/preprocess_op.cpp
Normal file
149
deploy/paddleserving/preprocess/preprocess_op.cpp
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "opencv2/core.hpp"
|
||||||
|
#include "opencv2/imgcodecs.hpp"
|
||||||
|
#include "opencv2/imgproc.hpp"
|
||||||
|
#include "paddle_api.h"
|
||||||
|
#include "paddle_inference_api.h"
|
||||||
|
#include <chrono>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <iostream>
|
||||||
|
#include <ostream>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <cstring>
|
||||||
|
#include <fstream>
|
||||||
|
#include <math.h>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
#include "preprocess_op.h"
|
||||||
|
|
||||||
|
namespace Feature {
|
||||||
|
|
||||||
|
void Permute::Run(const cv::Mat *im, float *data) {
|
||||||
|
int rh = im->rows;
|
||||||
|
int rw = im->cols;
|
||||||
|
int rc = im->channels();
|
||||||
|
for (int i = 0; i < rc; ++i) {
|
||||||
|
cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
|
||||||
|
const std::vector<float> &std, float scale) {
|
||||||
|
(*im).convertTo(*im, CV_32FC3, scale);
|
||||||
|
for (int h = 0; h < im->rows; h++) {
|
||||||
|
for (int w = 0; w < im->cols; w++) {
|
||||||
|
im->at<cv::Vec3f>(h, w)[0] =
|
||||||
|
(im->at<cv::Vec3f>(h, w)[0] - mean[0]) / std[0];
|
||||||
|
im->at<cv::Vec3f>(h, w)[1] =
|
||||||
|
(im->at<cv::Vec3f>(h, w)[1] - mean[1]) / std[1];
|
||||||
|
im->at<cv::Vec3f>(h, w)[2] =
|
||||||
|
(im->at<cv::Vec3f>(h, w)[2] - mean[2]) / std[2];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void CenterCropImg::Run(cv::Mat &img, const int crop_size) {
|
||||||
|
int resize_w = img.cols;
|
||||||
|
int resize_h = img.rows;
|
||||||
|
int w_start = int((resize_w - crop_size) / 2);
|
||||||
|
int h_start = int((resize_h - crop_size) / 2);
|
||||||
|
cv::Rect rect(w_start, h_start, crop_size, crop_size);
|
||||||
|
img = img(rect);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
|
||||||
|
int resize_short_size, int size) {
|
||||||
|
int resize_h = 0;
|
||||||
|
int resize_w = 0;
|
||||||
|
if (size > 0) {
|
||||||
|
resize_h = size;
|
||||||
|
resize_w = size;
|
||||||
|
} else {
|
||||||
|
int w = img.cols;
|
||||||
|
int h = img.rows;
|
||||||
|
|
||||||
|
float ratio = 1.f;
|
||||||
|
if (h < w) {
|
||||||
|
ratio = float(resize_short_size) / float(h);
|
||||||
|
} else {
|
||||||
|
ratio = float(resize_short_size) / float(w);
|
||||||
|
}
|
||||||
|
resize_h = round(float(h) * ratio);
|
||||||
|
resize_w = round(float(w) * ratio);
|
||||||
|
}
|
||||||
|
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace Feature
|
||||||
|
|
||||||
|
namespace PaddleClas {
|
||||||
|
void Permute::Run(const cv::Mat *im, float *data) {
|
||||||
|
int rh = im->rows;
|
||||||
|
int rw = im->cols;
|
||||||
|
int rc = im->channels();
|
||||||
|
for (int i = 0; i < rc; ++i) {
|
||||||
|
cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
|
||||||
|
const std::vector<float> &scale, const bool is_scale) {
|
||||||
|
double e = 1.0;
|
||||||
|
if (is_scale) {
|
||||||
|
e /= 255.0;
|
||||||
|
}
|
||||||
|
(*im).convertTo(*im, CV_32FC3, e);
|
||||||
|
for (int h = 0; h < im->rows; h++) {
|
||||||
|
for (int w = 0; w < im->cols; w++) {
|
||||||
|
im->at<cv::Vec3f>(h, w)[0] =
|
||||||
|
(im->at<cv::Vec3f>(h, w)[0] - mean[0]) / scale[0];
|
||||||
|
im->at<cv::Vec3f>(h, w)[1] =
|
||||||
|
(im->at<cv::Vec3f>(h, w)[1] - mean[1]) / scale[1];
|
||||||
|
im->at<cv::Vec3f>(h, w)[2] =
|
||||||
|
(im->at<cv::Vec3f>(h, w)[2] - mean[2]) / scale[2];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void CenterCropImg::Run(cv::Mat &img, const int crop_size) {
|
||||||
|
int resize_w = img.cols;
|
||||||
|
int resize_h = img.rows;
|
||||||
|
int w_start = int((resize_w - crop_size) / 2);
|
||||||
|
int h_start = int((resize_h - crop_size) / 2);
|
||||||
|
cv::Rect rect(w_start, h_start, crop_size, crop_size);
|
||||||
|
img = img(rect);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
|
||||||
|
int resize_short_size) {
|
||||||
|
int w = img.cols;
|
||||||
|
int h = img.rows;
|
||||||
|
|
||||||
|
float ratio = 1.f;
|
||||||
|
if (h < w) {
|
||||||
|
ratio = float(resize_short_size) / float(h);
|
||||||
|
} else {
|
||||||
|
ratio = float(resize_short_size) / float(w);
|
||||||
|
}
|
||||||
|
|
||||||
|
int resize_h = round(float(h) * ratio);
|
||||||
|
int resize_w = round(float(w) * ratio);
|
||||||
|
|
||||||
|
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace PaddleClas
|
81
deploy/paddleserving/preprocess/preprocess_op.h
Normal file
81
deploy/paddleserving/preprocess/preprocess_op.h
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "opencv2/core.hpp"
|
||||||
|
#include "opencv2/imgcodecs.hpp"
|
||||||
|
#include "opencv2/imgproc.hpp"
|
||||||
|
#include <chrono>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <iostream>
|
||||||
|
#include <ostream>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <cstring>
|
||||||
|
#include <fstream>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
namespace Feature {
|
||||||
|
|
||||||
|
class Normalize {
|
||||||
|
public:
|
||||||
|
virtual void Run(cv::Mat *im, const std::vector<float> &mean,
|
||||||
|
const std::vector<float> &std, float scale);
|
||||||
|
};
|
||||||
|
|
||||||
|
// RGB -> CHW
|
||||||
|
class Permute {
|
||||||
|
public:
|
||||||
|
virtual void Run(const cv::Mat *im, float *data);
|
||||||
|
};
|
||||||
|
|
||||||
|
class CenterCropImg {
|
||||||
|
public:
|
||||||
|
virtual void Run(cv::Mat &im, const int crop_size = 224);
|
||||||
|
};
|
||||||
|
|
||||||
|
class ResizeImg {
|
||||||
|
public:
|
||||||
|
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len,
|
||||||
|
int size = 0);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace Feature
|
||||||
|
|
||||||
|
namespace PaddleClas {
|
||||||
|
|
||||||
|
class Normalize {
|
||||||
|
public:
|
||||||
|
virtual void Run(cv::Mat *im, const std::vector<float> &mean,
|
||||||
|
const std::vector<float> &scale, const bool is_scale = true);
|
||||||
|
};
|
||||||
|
|
||||||
|
// RGB -> CHW
|
||||||
|
class Permute {
|
||||||
|
public:
|
||||||
|
virtual void Run(const cv::Mat *im, float *data);
|
||||||
|
};
|
||||||
|
|
||||||
|
class CenterCropImg {
|
||||||
|
public:
|
||||||
|
virtual void Run(cv::Mat &im, const int crop_size = 224);
|
||||||
|
};
|
||||||
|
|
||||||
|
class ResizeImg {
|
||||||
|
public:
|
||||||
|
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace PaddleClas
|
@ -0,0 +1,14 @@
|
|||||||
|
===========================serving_params===========================
|
||||||
|
model_name:MobileNetV3_large_x1_0
|
||||||
|
python:python3.7
|
||||||
|
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV3_large_x1_0_infer.tar
|
||||||
|
trans_model:-m paddle_serving_client.convert
|
||||||
|
--dirname:./deploy/paddleserving/MobileNetV3_large_x1_0_infer/
|
||||||
|
--model_filename:inference.pdmodel
|
||||||
|
--params_filename:inference.pdiparams
|
||||||
|
--serving_server:./deploy/paddleserving/MobileNetV3_large_x1_0_serving/
|
||||||
|
--serving_client:./deploy/paddleserving/MobileNetV3_large_x1_0_client/
|
||||||
|
serving_dir:./deploy/paddleserving
|
||||||
|
web_service:null
|
||||||
|
--use_gpu:0|null
|
||||||
|
pipline:test_cpp_serving_client.py
|
@ -0,0 +1,14 @@
|
|||||||
|
===========================serving_params===========================
|
||||||
|
model_name:PPHGNet_small
|
||||||
|
python:python3.7
|
||||||
|
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPHGNet_small_infer.tar
|
||||||
|
trans_model:-m paddle_serving_client.convert
|
||||||
|
--dirname:./deploy/paddleserving/PPHGNet_small_infer/
|
||||||
|
--model_filename:inference.pdmodel
|
||||||
|
--params_filename:inference.pdiparams
|
||||||
|
--serving_server:./deploy/paddleserving/PPHGNet_small_serving/
|
||||||
|
--serving_client:./deploy/paddleserving/PPHGNet_small_client/
|
||||||
|
serving_dir:./deploy/paddleserving
|
||||||
|
web_service:classification_web_service.py
|
||||||
|
--use_gpu:0|null
|
||||||
|
pipline:pipeline_http_client.py
|
@ -0,0 +1,14 @@
|
|||||||
|
===========================serving_params===========================
|
||||||
|
model_name:PPHGNet_tiny
|
||||||
|
python:python3.7
|
||||||
|
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPHGNet_tiny_infer.tar
|
||||||
|
trans_model:-m paddle_serving_client.convert
|
||||||
|
--dirname:./deploy/paddleserving/PPHGNet_tiny_infer/
|
||||||
|
--model_filename:inference.pdmodel
|
||||||
|
--params_filename:inference.pdiparams
|
||||||
|
--serving_server:./deploy/paddleserving/PPHGNet_tiny_serving/
|
||||||
|
--serving_client:./deploy/paddleserving/PPHGNet_tiny_client/
|
||||||
|
serving_dir:./deploy/paddleserving
|
||||||
|
web_service:classification_web_service.py
|
||||||
|
--use_gpu:0|null
|
||||||
|
pipline:pipeline_http_client.py
|
@ -0,0 +1,14 @@
|
|||||||
|
===========================serving_params===========================
|
||||||
|
model_name:PPLCNet_x0_25
|
||||||
|
python:python3.7
|
||||||
|
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x0_25_infer.tar
|
||||||
|
trans_model:-m paddle_serving_client.convert
|
||||||
|
--dirname:./deploy/paddleserving/PPLCNet_x0_25_infer/
|
||||||
|
--model_filename:inference.pdmodel
|
||||||
|
--params_filename:inference.pdiparams
|
||||||
|
--serving_server:./deploy/paddleserving/PPLCNet_x0_25_serving/
|
||||||
|
--serving_client:./deploy/paddleserving/PPLCNet_x0_25_client/
|
||||||
|
serving_dir:./deploy/paddleserving
|
||||||
|
web_service:null
|
||||||
|
--use_gpu:0|null
|
||||||
|
pipline:test_cpp_serving_client.py
|
@ -0,0 +1,14 @@
|
|||||||
|
===========================serving_params===========================
|
||||||
|
model_name:PPLCNet_x0_35
|
||||||
|
python:python3.7
|
||||||
|
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x0_35_infer.tar
|
||||||
|
trans_model:-m paddle_serving_client.convert
|
||||||
|
--dirname:./deploy/paddleserving/PPLCNet_x0_35_infer/
|
||||||
|
--model_filename:inference.pdmodel
|
||||||
|
--params_filename:inference.pdiparams
|
||||||
|
--serving_server:./deploy/paddleserving/PPLCNet_x0_35_serving/
|
||||||
|
--serving_client:./deploy/paddleserving/PPLCNet_x0_35_client/
|
||||||
|
serving_dir:./deploy/paddleserving
|
||||||
|
web_service:null
|
||||||
|
--use_gpu:0|null
|
||||||
|
pipline:test_cpp_serving_client.py
|
@ -0,0 +1,14 @@
|
|||||||
|
===========================serving_params===========================
|
||||||
|
model_name:PPLCNet_x0_5
|
||||||
|
python:python3.7
|
||||||
|
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x0_5_infer.tar
|
||||||
|
trans_model:-m paddle_serving_client.convert
|
||||||
|
--dirname:./deploy/paddleserving/PPLCNet_x0_5_infer/
|
||||||
|
--model_filename:inference.pdmodel
|
||||||
|
--params_filename:inference.pdiparams
|
||||||
|
--serving_server:./deploy/paddleserving/PPLCNet_x0_5_serving/
|
||||||
|
--serving_client:./deploy/paddleserving/PPLCNet_x0_5_client/
|
||||||
|
serving_dir:./deploy/paddleserving
|
||||||
|
web_service:null
|
||||||
|
--use_gpu:0|null
|
||||||
|
pipline:test_cpp_serving_client.py
|
@ -0,0 +1,14 @@
|
|||||||
|
===========================serving_params===========================
|
||||||
|
model_name:PPLCNet_x0_75
|
||||||
|
python:python3.7
|
||||||
|
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x0_75_infer.tar
|
||||||
|
trans_model:-m paddle_serving_client.convert
|
||||||
|
--dirname:./deploy/paddleserving/PPLCNet_x0_75_infer/
|
||||||
|
--model_filename:inference.pdmodel
|
||||||
|
--params_filename:inference.pdiparams
|
||||||
|
--serving_server:./deploy/paddleserving/PPLCNet_x0_75_serving/
|
||||||
|
--serving_client:./deploy/paddleserving/PPLCNet_x0_75_client/
|
||||||
|
serving_dir:./deploy/paddleserving
|
||||||
|
web_service:null
|
||||||
|
--use_gpu:0|null
|
||||||
|
pipline:test_cpp_serving_client.py
|
@ -0,0 +1,14 @@
|
|||||||
|
===========================serving_params===========================
|
||||||
|
model_name:PPLCNet_x1_0
|
||||||
|
python:python3.7
|
||||||
|
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x1_0_infer.tar
|
||||||
|
trans_model:-m paddle_serving_client.convert
|
||||||
|
--dirname:./deploy/paddleserving/PPLCNet_x1_0_infer/
|
||||||
|
--model_filename:inference.pdmodel
|
||||||
|
--params_filename:inference.pdiparams
|
||||||
|
--serving_server:./deploy/paddleserving/PPLCNet_x1_0_serving/
|
||||||
|
--serving_client:./deploy/paddleserving/PPLCNet_x1_0_client/
|
||||||
|
serving_dir:./deploy/paddleserving
|
||||||
|
web_service:null
|
||||||
|
--use_gpu:0|null
|
||||||
|
pipline:test_cpp_serving_client.py
|
@ -0,0 +1,14 @@
|
|||||||
|
===========================serving_params===========================
|
||||||
|
model_name:PPLCNet_x1_5
|
||||||
|
python:python3.7
|
||||||
|
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x1_5_infer.tar
|
||||||
|
trans_model:-m paddle_serving_client.convert
|
||||||
|
--dirname:./deploy/paddleserving/PPLCNet_x1_5_infer/
|
||||||
|
--model_filename:inference.pdmodel
|
||||||
|
--params_filename:inference.pdiparams
|
||||||
|
--serving_server:./deploy/paddleserving/PPLCNet_x1_5_serving/
|
||||||
|
--serving_client:./deploy/paddleserving/PPLCNet_x1_5_client/
|
||||||
|
serving_dir:./deploy/paddleserving
|
||||||
|
web_service:null
|
||||||
|
--use_gpu:0|null
|
||||||
|
pipline:test_cpp_serving_client.py
|
@ -0,0 +1,14 @@
|
|||||||
|
===========================serving_params===========================
|
||||||
|
model_name:PPLCNet_x2_0
|
||||||
|
python:python3.7
|
||||||
|
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x2_0_infer.tar
|
||||||
|
trans_model:-m paddle_serving_client.convert
|
||||||
|
--dirname:./deploy/paddleserving/PPLCNet_x2_0_infer/
|
||||||
|
--model_filename:inference.pdmodel
|
||||||
|
--params_filename:inference.pdiparams
|
||||||
|
--serving_server:./deploy/paddleserving/PPLCNet_x2_0_serving/
|
||||||
|
--serving_client:./deploy/paddleserving/PPLCNet_x2_0_client/
|
||||||
|
serving_dir:./deploy/paddleserving
|
||||||
|
web_service:null
|
||||||
|
--use_gpu:0|null
|
||||||
|
pipline:test_cpp_serving_client.py
|
@ -0,0 +1,14 @@
|
|||||||
|
===========================serving_params===========================
|
||||||
|
model_name:PPLCNet_x2_5
|
||||||
|
python:python3.7
|
||||||
|
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x2_5_infer.tar
|
||||||
|
trans_model:-m paddle_serving_client.convert
|
||||||
|
--dirname:./deploy/paddleserving/PPLCNet_x2_5_infer/
|
||||||
|
--model_filename:inference.pdmodel
|
||||||
|
--params_filename:inference.pdiparams
|
||||||
|
--serving_server:./deploy/paddleserving/PPLCNet_x2_5_serving/
|
||||||
|
--serving_client:./deploy/paddleserving/PPLCNet_x2_5_client/
|
||||||
|
serving_dir:./deploy/paddleserving
|
||||||
|
web_service:null
|
||||||
|
--use_gpu:0|null
|
||||||
|
pipline:test_cpp_serving_client.py
|
@ -0,0 +1,14 @@
|
|||||||
|
===========================serving_params===========================
|
||||||
|
model_name:PPLCNetV2_base
|
||||||
|
python:python3.7
|
||||||
|
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNetV2_base_infer.tar
|
||||||
|
trans_model:-m paddle_serving_client.convert
|
||||||
|
--dirname:./deploy/paddleserving/PPLCNetV2_base_infer/
|
||||||
|
--model_filename:inference.pdmodel
|
||||||
|
--params_filename:inference.pdiparams
|
||||||
|
--serving_server:./deploy/paddleserving/PPLCNetV2_base_serving/
|
||||||
|
--serving_client:./deploy/paddleserving/PPLCNetV2_base_client/
|
||||||
|
serving_dir:./deploy/paddleserving
|
||||||
|
web_service:null
|
||||||
|
--use_gpu:0|null
|
||||||
|
pipline:test_cpp_serving_client.py
|
@ -0,0 +1,14 @@
|
|||||||
|
===========================serving_params===========================
|
||||||
|
model_name:ResNet50
|
||||||
|
python:python3.7
|
||||||
|
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet50_infer.tar
|
||||||
|
trans_model:-m paddle_serving_client.convert
|
||||||
|
--dirname:./deploy/paddleserving/ResNet50_infer/
|
||||||
|
--model_filename:inference.pdmodel
|
||||||
|
--params_filename:inference.pdiparams
|
||||||
|
--serving_server:./deploy/paddleserving/ResNet50_serving/
|
||||||
|
--serving_client:./deploy/paddleserving/ResNet50_client/
|
||||||
|
serving_dir:./deploy/paddleserving
|
||||||
|
web_service:null
|
||||||
|
--use_gpu:0|null
|
||||||
|
pipline:test_cpp_serving_client.py
|
@ -0,0 +1,14 @@
|
|||||||
|
===========================serving_params===========================
|
||||||
|
model_name:ResNet50_vd
|
||||||
|
python:python3.7
|
||||||
|
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet50_vd_infer.tar
|
||||||
|
trans_model:-m paddle_serving_client.convert
|
||||||
|
--dirname:./deploy/paddleserving/ResNet50_vd_infer/
|
||||||
|
--model_filename:inference.pdmodel
|
||||||
|
--params_filename:inference.pdiparams
|
||||||
|
--serving_server:./deploy/paddleserving/ResNet50_vd_serving/
|
||||||
|
--serving_client:./deploy/paddleserving/ResNet50_vd_client/
|
||||||
|
serving_dir:./deploy/paddleserving
|
||||||
|
web_service:null
|
||||||
|
--use_gpu:0|null
|
||||||
|
pipline:test_cpp_serving_client.py
|
@ -0,0 +1,14 @@
|
|||||||
|
===========================serving_params===========================
|
||||||
|
model_name:SwinTransformer_tiny_patch4_window7_224
|
||||||
|
python:python3.7
|
||||||
|
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/SwinTransformer_tiny_patch4_window7_224_infer.tar
|
||||||
|
trans_model:-m paddle_serving_client.convert
|
||||||
|
--dirname:./deploy/paddleserving/SwinTransformer_tiny_patch4_window7_224_infer/
|
||||||
|
--model_filename:inference.pdmodel
|
||||||
|
--params_filename:inference.pdiparams
|
||||||
|
--serving_server:./deploy/paddleserving/SwinTransformer_tiny_patch4_window7_224_serving/
|
||||||
|
--serving_client:./deploy/paddleserving/SwinTransformer_tiny_patch4_window7_224_client/
|
||||||
|
serving_dir:./deploy/paddleserving
|
||||||
|
web_service:null
|
||||||
|
--use_gpu:0|null
|
||||||
|
pipline:test_cpp_serving_client.py
|
@ -199,10 +199,16 @@ fi
|
|||||||
|
|
||||||
if [[ ${MODE} = "serving_infer" ]]; then
|
if [[ ${MODE} = "serving_infer" ]]; then
|
||||||
# prepare serving env
|
# prepare serving env
|
||||||
|
${python_name} -m pip install paddle_serving_client==0.9.0
|
||||||
|
${python_name} -m pip install paddle-serving-app==0.9.0
|
||||||
python_name=$(func_parser_value "${lines[2]}")
|
python_name=$(func_parser_value "${lines[2]}")
|
||||||
${python_name} -m pip install install paddle-serving-server-gpu==0.7.0.post102
|
if [[ ${FILENAME} =~ "cpp" ]]; then
|
||||||
${python_name} -m pip install paddle_serving_client==0.7.0
|
pushd ./deploy/paddleserving
|
||||||
${python_name} -m pip install paddle-serving-app==0.7.0
|
bash build_server.sh
|
||||||
|
popd
|
||||||
|
else
|
||||||
|
${python_name} -m pip install install paddle-serving-server-gpu==0.9.0.post102
|
||||||
|
fi
|
||||||
if [[ ${model_name} =~ "ShiTu" ]]; then
|
if [[ ${model_name} =~ "ShiTu" ]]; then
|
||||||
cls_inference_model_url=$(func_parser_value "${lines[3]}")
|
cls_inference_model_url=$(func_parser_value "${lines[3]}")
|
||||||
cls_tar_name=$(func_get_url_file_name "${cls_inference_model_url}")
|
cls_tar_name=$(func_get_url_file_name "${cls_inference_model_url}")
|
||||||
|
@ -50,8 +50,13 @@ function func_serving_cls(){
|
|||||||
set_serving_server=$(func_set_params "${serving_server_key}" "${serving_server_value}")
|
set_serving_server=$(func_set_params "${serving_server_key}" "${serving_server_value}")
|
||||||
set_serving_client=$(func_set_params "${serving_client_key}" "${serving_client_value}")
|
set_serving_client=$(func_set_params "${serving_client_key}" "${serving_client_value}")
|
||||||
|
|
||||||
trans_model_cmd="${python} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
|
for python_ in ${python[*]}; do
|
||||||
eval $trans_model_cmd
|
if [[ ${python_} =~ "python" ]]; then
|
||||||
|
trans_model_cmd="${python_} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
|
||||||
|
eval $trans_model_cmd
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
# modify the alias_name of fetch_var to "outputs"
|
# modify the alias_name of fetch_var to "outputs"
|
||||||
server_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"prediction\"/' ${serving_server_value}/serving_server_conf.prototxt"
|
server_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"prediction\"/' ${serving_server_value}/serving_server_conf.prototxt"
|
||||||
@ -69,103 +74,124 @@ function func_serving_cls(){
|
|||||||
unset https_proxy
|
unset https_proxy
|
||||||
unset http_proxy
|
unset http_proxy
|
||||||
|
|
||||||
# modify the input_name in "classification_web_service.py" to be consistent with feed_var.name in prototxt
|
if [[ ${FILENAME} =~ "cpp" ]]; then
|
||||||
set_web_service_feet_var_cmd="sed -i '/preprocess/,/input_imgs}/s/{.*: input_imgs}/{${feed_var_name}: input_imgs}/' ${web_service_py}"
|
for item in ${python[*]}; do
|
||||||
eval ${set_web_service_feet_var_cmd}
|
if [[ ${item} =~ "python" ]]; then
|
||||||
|
python_=${item}
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
set_pipeline_py_feed_var_cmd="sed -i '/feed/,/img}/s/{.*: img}/{${feed_var_name}: img}/' ${pipeline_py}"
|
||||||
|
echo $PWD - $set_pipeline_py_feed_var_cmd
|
||||||
|
eval ${set_pipeline_py_feed_var_cmd}
|
||||||
|
serving_server_dir_name=$(func_get_url_file_name "$serving_server_value")
|
||||||
|
set_client_config_line_cmd="sed -i '/client/,/serving_server_conf.prototxt/s/.\/.*\/serving_server_conf.prototxt/.\/${serving_server_dir_name}\/serving_server_conf.prototxt/' ${pipeline_py}"
|
||||||
|
echo $PWD - $set_client_config_line_cmd
|
||||||
|
eval ${set_client_config_line_cmd}
|
||||||
|
for use_gpu in ${web_use_gpu_list[*]}; do
|
||||||
|
if [ ${use_gpu} = "null" ]; then
|
||||||
|
web_service_cpp_cmd="${python_} -m paddle_serving_server.serve --model ${serving_server_dir_name} --port 9292 &"
|
||||||
|
echo $PWD - $web_service_cpp_cmd
|
||||||
|
eval ${web_service_cpp_cmd}
|
||||||
|
sleep 5s
|
||||||
|
_save_log_path="${LOG_PATH}/server_infer_cpp_cpu_pipeline_batchsize_1.log"
|
||||||
|
pipeline_cmd="${python_} test_cpp_serving_client.py > ${_save_log_path} 2>&1 "
|
||||||
|
echo ${pipeline_cmd}
|
||||||
|
eval ${pipeline_cmd}
|
||||||
|
last_status=${PIPESTATUS[0]}
|
||||||
|
eval "cat ${_save_log_path}"
|
||||||
|
status_check ${last_status} "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||||
|
ps ux | grep -E 'serving_server|pipeline' | awk '{print $2}' | xargs kill -s 9
|
||||||
|
sleep 5s
|
||||||
|
else
|
||||||
|
web_service_cpp_cmd="${python_} -m paddle_serving_server.serve --model ${serving_server_dir_name} --port 9292 --gpu_id=${use_gpu} &"
|
||||||
|
echo $PWD - $web_service_cpp_cmd
|
||||||
|
eval ${web_service_cpp_cmd}
|
||||||
|
sleep 5s
|
||||||
|
|
||||||
model_config=21
|
_save_log_path="${LOG_PATH}/server_infer_cpp_gpu_pipeline_batchsize_1.log"
|
||||||
serving_server_dir_name=$(func_get_url_file_name "$serving_server_value")
|
pipeline_cmd="${python_} test_cpp_serving_client.py > ${_save_log_path} 2>&1 "
|
||||||
set_model_config_cmd="sed -i '${model_config}s/model_config: .*/model_config: ${serving_server_dir_name}/' config.yml"
|
|
||||||
eval ${set_model_config_cmd}
|
|
||||||
|
|
||||||
for python in ${python[*]}; do
|
echo $PWD - $pipeline_cmd
|
||||||
if [[ ${python} = "cpp" ]]; then
|
eval $pipeline_cmd
|
||||||
for use_gpu in ${web_use_gpu_list[*]}; do
|
last_status=${PIPESTATUS[0]}
|
||||||
if [ ${use_gpu} = "null" ]; then
|
eval "cat ${_save_log_path}"
|
||||||
web_service_cpp_cmd="${python} -m paddle_serving_server.serve --model ppocr_det_mobile_2.0_serving/ ppocr_rec_mobile_2.0_serving/ --port 9293"
|
status_check ${last_status} "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||||
eval $web_service_cmd
|
ps ux | grep -E 'serving_server|pipeline' | awk '{print $2}' | xargs kill -s 9
|
||||||
sleep 5s
|
sleep 5s
|
||||||
_save_log_path="${LOG_PATH}/server_infer_cpp_cpu_pipeline_usemkldnn_False_threads_4_batchsize_1.log"
|
fi
|
||||||
pipeline_cmd="${python} ocr_cpp_client.py ppocr_det_mobile_2.0_client/ ppocr_rec_mobile_2.0_client/"
|
done
|
||||||
|
else
|
||||||
|
# python serving
|
||||||
|
# modify the input_name in "classification_web_service.py" to be consistent with feed_var.name in prototxt
|
||||||
|
set_web_service_feed_var_cmd="sed -i '/preprocess/,/input_imgs}/s/{.*: input_imgs}/{${feed_var_name}: input_imgs}/' ${web_service_py}"
|
||||||
|
eval ${set_web_service_feed_var_cmd}
|
||||||
|
|
||||||
|
model_config=21
|
||||||
|
serving_server_dir_name=$(func_get_url_file_name "$serving_server_value")
|
||||||
|
set_model_config_cmd="sed -i '${model_config}s/model_config: .*/model_config: ${serving_server_dir_name}/' config.yml"
|
||||||
|
eval ${set_model_config_cmd}
|
||||||
|
|
||||||
|
for use_gpu in ${web_use_gpu_list[*]}; do
|
||||||
|
if [[ ${use_gpu} = "null" ]]; then
|
||||||
|
device_type_line=24
|
||||||
|
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 0/' config.yml"
|
||||||
|
eval $set_device_type_cmd
|
||||||
|
|
||||||
|
devices_line=27
|
||||||
|
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"\"/' config.yml"
|
||||||
|
eval $set_devices_cmd
|
||||||
|
|
||||||
|
web_service_cmd="${python_} ${web_service_py} &"
|
||||||
|
eval $web_service_cmd
|
||||||
|
sleep 5s
|
||||||
|
for pipeline in ${pipeline_py[*]}; do
|
||||||
|
_save_log_path="${LOG_PATH}/server_infer_cpu_${pipeline%_client*}_batchsize_1.log"
|
||||||
|
pipeline_cmd="${python_} ${pipeline} > ${_save_log_path} 2>&1 "
|
||||||
eval $pipeline_cmd
|
eval $pipeline_cmd
|
||||||
status_check $last_status "${pipeline_cmd}" "${status_log}"
|
last_status=${PIPESTATUS[0]}
|
||||||
|
eval "cat ${_save_log_path}"
|
||||||
|
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||||
sleep 5s
|
sleep 5s
|
||||||
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
done
|
||||||
else
|
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
||||||
web_service_cpp_cmd="${python} -m paddle_serving_server.serve --model ppocr_det_mobile_2.0_serving/ ppocr_rec_mobile_2.0_serving/ --port 9293 --gpu_id=0"
|
elif [ ${use_gpu} -eq 0 ]; then
|
||||||
eval $web_service_cmd
|
if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
|
||||||
sleep 5s
|
continue
|
||||||
_save_log_path="${LOG_PATH}/server_infer_cpp_cpu_pipeline_usemkldnn_False_threads_4_batchsize_1.log"
|
fi
|
||||||
pipeline_cmd="${python} ocr_cpp_client.py ppocr_det_mobile_2.0_client/ ppocr_rec_mobile_2.0_client/"
|
if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [[ ${_flag_quant} = "True" ]]; then
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
device_type_line=24
|
||||||
|
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 1/' config.yml"
|
||||||
|
eval $set_device_type_cmd
|
||||||
|
|
||||||
|
devices_line=27
|
||||||
|
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"${use_gpu}\"/' config.yml"
|
||||||
|
eval $set_devices_cmd
|
||||||
|
|
||||||
|
web_service_cmd="${python_} ${web_service_py} & "
|
||||||
|
eval $web_service_cmd
|
||||||
|
sleep 5s
|
||||||
|
for pipeline in ${pipeline_py[*]}; do
|
||||||
|
_save_log_path="${LOG_PATH}/server_infer_gpu_${pipeline%_client*}_batchsize_1.log"
|
||||||
|
pipeline_cmd="${python_} ${pipeline} > ${_save_log_path} 2>&1"
|
||||||
eval $pipeline_cmd
|
eval $pipeline_cmd
|
||||||
status_check $last_status "${pipeline_cmd}" "${status_log}"
|
last_status=${PIPESTATUS[0]}
|
||||||
|
eval "cat ${_save_log_path}"
|
||||||
|
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||||
sleep 5s
|
sleep 5s
|
||||||
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
done
|
||||||
fi
|
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
||||||
done
|
else
|
||||||
else
|
echo "Does not support hardware [${use_gpu}] other than CPU and GPU Currently!"
|
||||||
# python serving
|
fi
|
||||||
for use_gpu in ${web_use_gpu_list[*]}; do
|
done
|
||||||
if [[ ${use_gpu} = "null" ]]; then
|
fi
|
||||||
device_type_line=24
|
|
||||||
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 0/' config.yml"
|
|
||||||
eval $set_device_type_cmd
|
|
||||||
|
|
||||||
devices_line=27
|
|
||||||
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"\"/' config.yml"
|
|
||||||
eval $set_devices_cmd
|
|
||||||
|
|
||||||
web_service_cmd="${python} ${web_service_py} &"
|
|
||||||
eval $web_service_cmd
|
|
||||||
sleep 5s
|
|
||||||
for pipeline in ${pipeline_py[*]}; do
|
|
||||||
_save_log_path="${LOG_PATH}/server_infer_cpu_${pipeline%_client*}_batchsize_1.log"
|
|
||||||
pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1 "
|
|
||||||
eval $pipeline_cmd
|
|
||||||
last_status=${PIPESTATUS[0]}
|
|
||||||
eval "cat ${_save_log_path}"
|
|
||||||
status_check $last_status "${pipeline_cmd}" "${status_log}"
|
|
||||||
sleep 5s
|
|
||||||
done
|
|
||||||
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
|
||||||
elif [ ${use_gpu} -eq 0 ]; then
|
|
||||||
if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
|
|
||||||
continue
|
|
||||||
fi
|
|
||||||
if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
|
|
||||||
continue
|
|
||||||
fi
|
|
||||||
if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [[ ${_flag_quant} = "True" ]]; then
|
|
||||||
continue
|
|
||||||
fi
|
|
||||||
|
|
||||||
device_type_line=24
|
|
||||||
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 1/' config.yml"
|
|
||||||
eval $set_device_type_cmd
|
|
||||||
|
|
||||||
devices_line=27
|
|
||||||
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"${use_gpu}\"/' config.yml"
|
|
||||||
eval $set_devices_cmd
|
|
||||||
|
|
||||||
web_service_cmd="${python} ${web_service_py} & "
|
|
||||||
eval $web_service_cmd
|
|
||||||
sleep 5s
|
|
||||||
for pipeline in ${pipeline_py[*]}; do
|
|
||||||
_save_log_path="${LOG_PATH}/server_infer_gpu_${pipeline%_client*}_batchsize_1.log"
|
|
||||||
pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1"
|
|
||||||
eval $pipeline_cmd
|
|
||||||
last_status=${PIPESTATUS[0]}
|
|
||||||
eval "cat ${_save_log_path}"
|
|
||||||
status_check $last_status "${pipeline_cmd}" "${status_log}"
|
|
||||||
sleep 5s
|
|
||||||
done
|
|
||||||
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
|
||||||
else
|
|
||||||
echo "Does not support hardware [${use_gpu}] other than CPU and GPU Currently!"
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -200,6 +226,12 @@ function func_serving_rec(){
|
|||||||
pipeline_py=$(func_parser_value "${lines[17]}")
|
pipeline_py=$(func_parser_value "${lines[17]}")
|
||||||
|
|
||||||
IFS='|'
|
IFS='|'
|
||||||
|
for python_ in ${python[*]}; do
|
||||||
|
if [[ ${python_} =~ "python" ]]; then
|
||||||
|
python_interp=${python_}
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
# pdserving
|
# pdserving
|
||||||
cd ./deploy
|
cd ./deploy
|
||||||
@ -208,7 +240,7 @@ function func_serving_rec(){
|
|||||||
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
|
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
|
||||||
set_serving_server=$(func_set_params "${cls_serving_server_key}" "${cls_serving_server_value}")
|
set_serving_server=$(func_set_params "${cls_serving_server_key}" "${cls_serving_server_value}")
|
||||||
set_serving_client=$(func_set_params "${cls_serving_client_key}" "${cls_serving_client_value}")
|
set_serving_client=$(func_set_params "${cls_serving_client_key}" "${cls_serving_client_value}")
|
||||||
cls_trans_model_cmd="${python} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
|
cls_trans_model_cmd="${python_interp} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
|
||||||
eval $cls_trans_model_cmd
|
eval $cls_trans_model_cmd
|
||||||
|
|
||||||
set_dirname=$(func_set_params "${det_infer_model_dir_key}" "${det_infer_model_dir_value}")
|
set_dirname=$(func_set_params "${det_infer_model_dir_key}" "${det_infer_model_dir_value}")
|
||||||
@ -216,7 +248,7 @@ function func_serving_rec(){
|
|||||||
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
|
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
|
||||||
set_serving_server=$(func_set_params "${det_serving_server_key}" "${det_serving_server_value}")
|
set_serving_server=$(func_set_params "${det_serving_server_key}" "${det_serving_server_value}")
|
||||||
set_serving_client=$(func_set_params "${det_serving_client_key}" "${det_serving_client_value}")
|
set_serving_client=$(func_set_params "${det_serving_client_key}" "${det_serving_client_value}")
|
||||||
det_trans_model_cmd="${python} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
|
det_trans_model_cmd="${python_interp} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
|
||||||
eval $det_trans_model_cmd
|
eval $det_trans_model_cmd
|
||||||
|
|
||||||
# modify the alias_name of fetch_var to "outputs"
|
# modify the alias_name of fetch_var to "outputs"
|
||||||
@ -235,99 +267,130 @@ function func_serving_rec(){
|
|||||||
unset https_proxy
|
unset https_proxy
|
||||||
unset http_proxy
|
unset http_proxy
|
||||||
|
|
||||||
# modify the input_name in "recognition_web_service.py" to be consistent with feed_var.name in prototxt
|
if [[ ${FILENAME} =~ "cpp" ]]; then
|
||||||
set_web_service_feet_var_cmd="sed -i '/preprocess/,/input_imgs}/s/{.*: input_imgs}/{${feed_var_name}: input_imgs}/' ${web_service_py}"
|
det_serving_client_dir_name=$(func_get_url_file_name "$det_serving_client_value")
|
||||||
eval ${set_web_service_feet_var_cmd}
|
set_det_client_config_line_cmd="sed -i '/MainbodyDetect/,/serving_client_conf.prototxt/s/models\/.*\/serving_client_conf.prototxt/models\/${det_serving_client_dir_name}\/serving_client_conf.prototxt/' ${pipeline_py}"
|
||||||
|
echo $PWD - $set_det_client_config_line_cmd
|
||||||
|
eval ${set_det_client_config_line_cmd}
|
||||||
|
|
||||||
for python in ${python[*]}; do
|
cls_serving_client_dir_name=$(func_get_url_file_name "$cls_serving_client_value")
|
||||||
if [[ ${python} = "cpp" ]]; then
|
set_cls_client_config_line_cmd="sed -i '/ObjectRecognition/,/serving_client_conf.prototxt/s/models\/.*\/serving_client_conf.prototxt/models\/${cls_serving_client_dir_name}\/serving_client_conf.prototxt/' ${pipeline_py}"
|
||||||
for use_gpu in ${web_use_gpu_list[*]}; do
|
echo $PWD - $set_cls_client_config_line_cmd
|
||||||
if [ ${use_gpu} = "null" ]; then
|
eval ${set_cls_client_config_line_cmd}
|
||||||
web_service_cpp_cmd="${python} web_service_py"
|
|
||||||
|
|
||||||
eval $web_service_cmd
|
set_pipeline_py_feed_var_cmd="sed -i '/ObjectRecognition/,/feed={\"x\": batch_imgs}/s/{.*: batch_imgs}/{${feed_var_name}: batch_imgs}/' ${pipeline_py}"
|
||||||
sleep 5s
|
echo $PWD - $set_pipeline_py_feed_var_cmd
|
||||||
_save_log_path="${LOG_PATH}/server_infer_cpp_cpu_pipeline_usemkldnn_False_threads_4_batchsize_1.log"
|
eval ${set_pipeline_py_feed_var_cmd}
|
||||||
pipeline_cmd="${python} ocr_cpp_client.py ppocr_det_mobile_2.0_client/ ppocr_rec_mobile_2.0_client/"
|
|
||||||
|
for use_gpu in ${web_use_gpu_list[*]}; do
|
||||||
|
if [ ${use_gpu} = "null" ]; then
|
||||||
|
|
||||||
|
det_serving_server_dir_name=$(func_get_url_file_name "$det_serving_server_value")
|
||||||
|
web_service_cpp_cmd="${python_interp} -m paddle_serving_server.serve --model ../../models/${det_serving_server_dir_name} --port 9293 >>log_mainbody_detection.txt &"
|
||||||
|
|
||||||
|
cls_serving_server_dir_name=$(func_get_url_file_name "$cls_serving_server_value")
|
||||||
|
web_service_cpp_cmd2="${python_interp} -m paddle_serving_server.serve --model ../../models/${cls_serving_server_dir_name} --port 9294 >>log_feature_extraction.txt &"
|
||||||
|
echo $PWD - $web_service_cpp_cmd
|
||||||
|
eval $web_service_cpp_cmd
|
||||||
|
echo $PWD - $web_service_cpp_cmd2
|
||||||
|
eval $web_service_cpp_cmd2
|
||||||
|
sleep 5s
|
||||||
|
_save_log_path="${LOG_PATH}/server_infer_cpp_cpu_batchsize_1.log"
|
||||||
|
pipeline_cmd="${python_interp} test_cpp_serving_client.py > ${_save_log_path} 2>&1 "
|
||||||
|
echo ${pipeline_cmd}
|
||||||
|
eval ${pipeline_cmd}
|
||||||
|
last_status=${PIPESTATUS[0]}
|
||||||
|
eval "cat ${_save_log_path}"
|
||||||
|
status_check ${last_status} "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||||
|
ps ux | grep -E 'serving_server|pipeline' | awk '{print $2}' | xargs kill -s 9
|
||||||
|
sleep 5s
|
||||||
|
else
|
||||||
|
det_serving_server_dir_name=$(func_get_url_file_name "$det_serving_server_value")
|
||||||
|
web_service_cpp_cmd="${python_interp} -m paddle_serving_server.serve --model ../../models/${det_serving_server_dir_name} --port 9293 --gpu_id=${use_gpu} >>log_mainbody_detection.txt &"
|
||||||
|
|
||||||
|
cls_serving_server_dir_name=$(func_get_url_file_name "$cls_serving_server_value")
|
||||||
|
web_service_cpp_cmd2="${python_interp} -m paddle_serving_server.serve --model ../../models/${cls_serving_server_dir_name} --port 9294 --gpu_id=${use_gpu} >>log_feature_extraction.txt &"
|
||||||
|
echo $PWD - $web_service_cpp_cmd
|
||||||
|
eval $web_service_cpp_cmd
|
||||||
|
echo $PWD - $web_service_cpp_cmd2
|
||||||
|
eval $web_service_cpp_cmd2
|
||||||
|
sleep 5s
|
||||||
|
_save_log_path="${LOG_PATH}/server_infer_cpp_gpu_batchsize_1.log"
|
||||||
|
pipeline_cmd="${python_interp} test_cpp_serving_client.py > ${_save_log_path} 2>&1 "
|
||||||
|
echo ${pipeline_cmd}
|
||||||
|
eval ${pipeline_cmd}
|
||||||
|
last_status=${PIPESTATUS[0]}
|
||||||
|
eval "cat ${_save_log_path}"
|
||||||
|
status_check ${last_status} "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||||
|
ps ux | grep -E 'serving_server|pipeline' | awk '{print $2}' | xargs kill -s 9
|
||||||
|
sleep 5s
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
else
|
||||||
|
# modify the input_name in "recognition_web_service.py" to be consistent with feed_var.name in prototxt
|
||||||
|
set_web_service_feed_var_cmd="sed -i '/preprocess/,/input_imgs}/s/{.*: input_imgs}/{${feed_var_name}: input_imgs}/' ${web_service_py}"
|
||||||
|
eval ${set_web_service_feed_var_cmd}
|
||||||
|
# python serving
|
||||||
|
for use_gpu in ${web_use_gpu_list[*]}; do
|
||||||
|
if [[ ${use_gpu} = "null" ]]; then
|
||||||
|
device_type_line=24
|
||||||
|
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 0/' config.yml"
|
||||||
|
eval $set_device_type_cmd
|
||||||
|
|
||||||
|
devices_line=27
|
||||||
|
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"\"/' config.yml"
|
||||||
|
eval $set_devices_cmd
|
||||||
|
|
||||||
|
web_service_cmd="${python} ${web_service_py} &"
|
||||||
|
eval $web_service_cmd
|
||||||
|
sleep 5s
|
||||||
|
for pipeline in ${pipeline_py[*]}; do
|
||||||
|
_save_log_path="${LOG_PATH}/server_infer_cpu_${pipeline%_client*}_batchsize_1.log"
|
||||||
|
pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1 "
|
||||||
eval $pipeline_cmd
|
eval $pipeline_cmd
|
||||||
status_check $last_status "${pipeline_cmd}" "${status_log}"
|
last_status=${PIPESTATUS[0]}
|
||||||
|
eval "cat ${_save_log_path}"
|
||||||
|
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||||
sleep 5s
|
sleep 5s
|
||||||
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
done
|
||||||
else
|
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
||||||
web_service_cpp_cmd="${python} web_service_py"
|
elif [ ${use_gpu} -eq 0 ]; then
|
||||||
eval $web_service_cmd
|
if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
|
||||||
sleep 5s
|
continue
|
||||||
_save_log_path="${LOG_PATH}/server_infer_cpp_cpu_pipeline_usemkldnn_False_threads_4_batchsize_1.log"
|
fi
|
||||||
pipeline_cmd="${python} ocr_cpp_client.py ppocr_det_mobile_2.0_client/ ppocr_rec_mobile_2.0_client/"
|
if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
|
||||||
eval $pipeline_cmd
|
continue
|
||||||
status_check $last_status "${pipeline_cmd}" "${status_log}"
|
fi
|
||||||
sleep 5s
|
if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [[ ${_flag_quant} = "True" ]]; then
|
||||||
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
continue
|
||||||
fi
|
fi
|
||||||
done
|
|
||||||
else
|
|
||||||
# python serving
|
|
||||||
for use_gpu in ${web_use_gpu_list[*]}; do
|
|
||||||
if [[ ${use_gpu} = "null" ]]; then
|
|
||||||
device_type_line=24
|
|
||||||
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 0/' config.yml"
|
|
||||||
eval $set_device_type_cmd
|
|
||||||
|
|
||||||
devices_line=27
|
device_type_line=24
|
||||||
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"\"/' config.yml"
|
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 1/' config.yml"
|
||||||
eval $set_devices_cmd
|
eval $set_device_type_cmd
|
||||||
|
|
||||||
web_service_cmd="${python} ${web_service_py} &"
|
devices_line=27
|
||||||
eval $web_service_cmd
|
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"${use_gpu}\"/' config.yml"
|
||||||
sleep 5s
|
eval $set_devices_cmd
|
||||||
for pipeline in ${pipeline_py[*]}; do
|
|
||||||
_save_log_path="${LOG_PATH}/server_infer_cpu_${pipeline%_client*}_batchsize_1.log"
|
|
||||||
pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1 "
|
|
||||||
eval $pipeline_cmd
|
|
||||||
last_status=${PIPESTATUS[0]}
|
|
||||||
eval "cat ${_save_log_path}"
|
|
||||||
status_check $last_status "${pipeline_cmd}" "${status_log}"
|
|
||||||
sleep 5s
|
|
||||||
done
|
|
||||||
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
|
||||||
elif [ ${use_gpu} -eq 0 ]; then
|
|
||||||
if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
|
|
||||||
continue
|
|
||||||
fi
|
|
||||||
if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
|
|
||||||
continue
|
|
||||||
fi
|
|
||||||
if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [[ ${_flag_quant} = "True" ]]; then
|
|
||||||
continue
|
|
||||||
fi
|
|
||||||
|
|
||||||
device_type_line=24
|
web_service_cmd="${python} ${web_service_py} & "
|
||||||
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 1/' config.yml"
|
eval $web_service_cmd
|
||||||
eval $set_device_type_cmd
|
sleep 10s
|
||||||
|
for pipeline in ${pipeline_py[*]}; do
|
||||||
devices_line=27
|
_save_log_path="${LOG_PATH}/server_infer_gpu_${pipeline%_client*}_batchsize_1.log"
|
||||||
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"${use_gpu}\"/' config.yml"
|
pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1"
|
||||||
eval $set_devices_cmd
|
eval $pipeline_cmd
|
||||||
|
last_status=${PIPESTATUS[0]}
|
||||||
web_service_cmd="${python} ${web_service_py} & "
|
eval "cat ${_save_log_path}"
|
||||||
eval $web_service_cmd
|
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||||
sleep 10s
|
sleep 10s
|
||||||
for pipeline in ${pipeline_py[*]}; do
|
done
|
||||||
_save_log_path="${LOG_PATH}/server_infer_gpu_${pipeline%_client*}_batchsize_1.log"
|
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
||||||
pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1"
|
else
|
||||||
eval $pipeline_cmd
|
echo "Does not support hardware [${use_gpu}] other than CPU and GPU Currently!"
|
||||||
last_status=${PIPESTATUS[0]}
|
fi
|
||||||
eval "cat ${_save_log_path}"
|
done
|
||||||
status_check $last_status "${pipeline_cmd}" "${status_log}"
|
fi
|
||||||
sleep 10s
|
|
||||||
done
|
|
||||||
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
|
||||||
else
|
|
||||||
echo "Does not support hardware [${use_gpu}] other than CPU and GPU Currently!"
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user