Merge branch 'develop' into dev/add_en_doc
commit
0c34078bc2
|
@ -12,3 +12,4 @@ build/
|
|||
log/
|
||||
nohup.out
|
||||
.DS_Store
|
||||
.idea
|
||||
|
|
|
@ -20,10 +20,10 @@
|
|||
|
||||
|
||||
## 近期更新
|
||||
- 📢将于**<u>6月15-6月17日晚20:30</u>**进行为期三天的课程直播,详细介绍超轻量图像分类方案,对各场景模型优化原理及使用方式进行拆解,之后还有产业案例全流程实操,对各类痛难点解决方案进行手把手教学,加上现场互动答疑,抓紧扫码上车吧!
|
||||
- 📢将于**6月15-6月17日晚20:30** 进行为期三天的课程直播,详细介绍超轻量图像分类方案,对各场景模型优化原理及使用方式进行拆解,之后还有产业案例全流程实操,对各类痛难点解决方案进行手把手教学,加上现场互动答疑,抓紧扫码上车吧!
|
||||
|
||||
<div align="center">
|
||||
<img src="https://user-images.githubusercontent.com/80816848/173404459-9426c0ed-4801-4f75-876f-2e6ec47255f5.png" width = "200" height = "200"/>
|
||||
<img src="https://user-images.githubusercontent.com/45199522/173483779-2332f990-4941-4f8d-baee-69b62035fc31.png" width = "200" height = "200"/>
|
||||
</div>
|
||||
|
||||
- 🔥️ 2022.6.15 发布[PULC超轻量图像分类实用方案](docs/zh_CN/PULC/PULC_train.md),CPU推理3ms,精度比肩SwinTransformer,覆盖人、车、OCR场景九大常见任务。
|
||||
|
@ -47,11 +47,11 @@ PaddleClas发布了[PP-HGNet](docs/zh_CN/models/PP-HGNet.md)、[PP-LCNetv2](docs
|
|||
|
||||
## 欢迎加入技术交流群
|
||||
|
||||
* 您可以扫描下面的QQ/微信二维码(添加小助手微信并回复“C”),加入PaddleClas微信交流群,获得更高效的问题答疑,与各行各业开发者充分交流,期待您的加入。
|
||||
* 您可以扫描下面的微信/QQ二维码(添加小助手微信并回复“C”),加入PaddleClas微信交流群,获得更高效的问题答疑,与各行各业开发者充分交流,期待您的加入。
|
||||
|
||||
<div align="center">
|
||||
<img src="https://user-images.githubusercontent.com/80816848/164383225-e375eb86-716e-41b4-a9e0-4b8a3976c1aa.jpg" width="200"/>
|
||||
<img src="https://user-images.githubusercontent.com/48054808/160531099-9811bbe6-cfbb-47d5-8bdb-c2b40684d7dd.png" width="200"/>
|
||||
<img src="https://user-images.githubusercontent.com/80816848/164383225-e375eb86-716e-41b4-a9e0-4b8a3976c1aa.jpg" width="200"/>
|
||||
</div>
|
||||
|
||||
## 快速体验
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
PaddleClas is an image classification and image recognition toolset for industry and academia, helping users train better computer vision models and apply them in real scenarios.
|
||||
|
||||
<div align="center">
|
||||
<img src="./docs/images/class_simple.gif" width = "600" />
|
||||
<img src="./docs/images/class_simple_en.gif" width = "600" />
|
||||
|
||||
PULC demo images
|
||||
</div>
|
||||
|
@ -38,7 +38,7 @@ image classification and image recognition algorithms.
|
|||
Based on th algorithms above, PaddleClas release PP-ShiTu image recognition system and [**P**ractical **U**ltra **L**ight-weight image **C**lassification solutions](docs/en/PULC/PULC_quickstart_en.md).
|
||||
|
||||
|
||||

|
||||

|
||||
|
||||
## Welcome to Join the Technical Exchange Group
|
||||
|
||||
|
@ -51,6 +51,7 @@ Based on th algorithms above, PaddleClas release PP-ShiTu image recognition syst
|
|||
|
||||
## Quick Start
|
||||
Quick experience of PP-ShiTu image recognition system:[Link](./docs/en/tutorials/quick_start_recognition_en.md)
|
||||
|
||||
Quick experience of **P**ractical **U**ltra **L**ight-weight image **C**lassification models:[Link](docs/en/PULC/PULC_quickstart_en.md)
|
||||
|
||||
## Tutorials
|
||||
|
@ -114,7 +115,7 @@ For a new unknown category, there is no need to retrain the model, just prepare
|
|||
|
||||
## PULC demo images
|
||||
<div align="center">
|
||||
<img src="docs/images/classification.gif">
|
||||
<img src="docs/images/classification_en.gif">
|
||||
</div>
|
||||
|
||||
<a name="Rec_Demo_images"></a>
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
# 使用镜像:
|
||||
# registry.baidubce.com/paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82
|
||||
|
||||
# 编译Serving Server:
|
||||
|
||||
# client和app可以直接使用release版本
|
||||
|
||||
# server因为加入了自定义OP,需要重新编译
|
||||
|
||||
# 默认编译时的${PWD}=PaddleClas/deploy/paddleserving/
|
||||
|
||||
python_name=${1:-'python'}
|
||||
|
||||
apt-get update
|
||||
apt install -y libcurl4-openssl-dev libbz2-dev
|
||||
wget -nc https://paddle-serving.bj.bcebos.com/others/centos_ssl.tar
|
||||
tar xf centos_ssl.tar
|
||||
rm -rf centos_ssl.tar
|
||||
mv libcrypto.so.1.0.2k /usr/lib/libcrypto.so.1.0.2k
|
||||
mv libssl.so.1.0.2k /usr/lib/libssl.so.1.0.2k
|
||||
ln -sf /usr/lib/libcrypto.so.1.0.2k /usr/lib/libcrypto.so.10
|
||||
ln -sf /usr/lib/libssl.so.1.0.2k /usr/lib/libssl.so.10
|
||||
ln -sf /usr/lib/libcrypto.so.10 /usr/lib/libcrypto.so
|
||||
ln -sf /usr/lib/libssl.so.10 /usr/lib/libssl.so
|
||||
|
||||
# 安装go依赖
|
||||
rm -rf /usr/local/go
|
||||
wget -qO- https://paddle-ci.cdn.bcebos.com/go1.17.2.linux-amd64.tar.gz | tar -xz -C /usr/local
|
||||
export GOROOT=/usr/local/go
|
||||
export GOPATH=/root/gopath
|
||||
export PATH=$PATH:$GOPATH/bin:$GOROOT/bin
|
||||
go env -w GO111MODULE=on
|
||||
go env -w GOPROXY=https://goproxy.cn,direct
|
||||
go install github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway@v1.15.2
|
||||
go install github.com/grpc-ecosystem/grpc-gateway/protoc-gen-swagger@v1.15.2
|
||||
go install github.com/golang/protobuf/protoc-gen-go@v1.4.3
|
||||
go install google.golang.org/grpc@v1.33.0
|
||||
go env -w GO111MODULE=auto
|
||||
|
||||
# 下载opencv库
|
||||
wget https://paddle-qa.bj.bcebos.com/PaddleServing/opencv3.tar.gz
|
||||
tar -xvf opencv3.tar.gz
|
||||
rm -rf opencv3.tar.gz
|
||||
export OPENCV_DIR=$PWD/opencv3
|
||||
|
||||
# clone Serving
|
||||
git clone https://github.com/PaddlePaddle/Serving.git -b develop --depth=1
|
||||
|
||||
cd Serving # PaddleClas/deploy/paddleserving/Serving
|
||||
export Serving_repo_path=$PWD
|
||||
git submodule update --init --recursive
|
||||
${python_name} -m pip install -r python/requirements.txt
|
||||
|
||||
# set env
|
||||
export PYTHON_INCLUDE_DIR=$(${python_name} -c "from distutils.sysconfig import get_python_inc; print(get_python_inc())")
|
||||
export PYTHON_LIBRARIES=$(${python_name} -c "import distutils.sysconfig as sysconfig; print(sysconfig.get_config_var('LIBDIR'))")
|
||||
export PYTHON_EXECUTABLE=`which ${python_name}`
|
||||
|
||||
export CUDA_PATH='/usr/local/cuda'
|
||||
export CUDNN_LIBRARY='/usr/local/cuda/lib64/'
|
||||
export CUDA_CUDART_LIBRARY='/usr/local/cuda/lib64/'
|
||||
export TENSORRT_LIBRARY_PATH='/usr/local/TensorRT6-cuda10.1-cudnn7/targets/x86_64-linux-gnu/'
|
||||
|
||||
# cp 自定义OP代码
|
||||
\cp ../preprocess/general_clas_op.* ${Serving_repo_path}/core/general-server/op
|
||||
\cp ../preprocess/preprocess_op.* ${Serving_repo_path}/core/predictor/tools/pp_shitu_tools
|
||||
|
||||
# 编译Server
|
||||
mkdir server-build-gpu-opencv
|
||||
cd server-build-gpu-opencv
|
||||
cmake -DPYTHON_INCLUDE_DIR=$PYTHON_INCLUDE_DIR \
|
||||
-DPYTHON_LIBRARIES=$PYTHON_LIBRARIES \
|
||||
-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
|
||||
-DCUDA_TOOLKIT_ROOT_DIR=${CUDA_PATH} \
|
||||
-DCUDNN_LIBRARY=${CUDNN_LIBRARY} \
|
||||
-DCUDA_CUDART_LIBRARY=${CUDA_CUDART_LIBRARY} \
|
||||
-DTENSORRT_ROOT=${TENSORRT_LIBRARY_PATH} \
|
||||
-DOPENCV_DIR=${OPENCV_DIR} \
|
||||
-DWITH_OPENCV=ON \
|
||||
-DSERVER=ON \
|
||||
-DWITH_GPU=ON ..
|
||||
make -j32
|
||||
|
||||
${python_name} -m pip install python/dist/paddle*
|
||||
|
||||
# export SERVING_BIN
|
||||
export SERVING_BIN=$PWD/core/general-server/serving
|
||||
cd ../../
|
|
@ -0,0 +1,206 @@
|
|||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "core/general-server/op/general_clas_op.h"
|
||||
#include "core/predictor/framework/infer.h"
|
||||
#include "core/predictor/framework/memory.h"
|
||||
#include "core/predictor/framework/resource.h"
|
||||
#include "core/util/include/timer.h"
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
|
||||
namespace baidu {
|
||||
namespace paddle_serving {
|
||||
namespace serving {
|
||||
|
||||
using baidu::paddle_serving::Timer;
|
||||
using baidu::paddle_serving::predictor::MempoolWrapper;
|
||||
using baidu::paddle_serving::predictor::general_model::Tensor;
|
||||
using baidu::paddle_serving::predictor::general_model::Response;
|
||||
using baidu::paddle_serving::predictor::general_model::Request;
|
||||
using baidu::paddle_serving::predictor::InferManager;
|
||||
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
|
||||
|
||||
int GeneralClasOp::inference() {
|
||||
VLOG(2) << "Going to run inference";
|
||||
const std::vector<std::string> pre_node_names = pre_names();
|
||||
if (pre_node_names.size() != 1) {
|
||||
LOG(ERROR) << "This op(" << op_name()
|
||||
<< ") can only have one predecessor op, but received "
|
||||
<< pre_node_names.size();
|
||||
return -1;
|
||||
}
|
||||
const std::string pre_name = pre_node_names[0];
|
||||
|
||||
const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name);
|
||||
if (!input_blob) {
|
||||
LOG(ERROR) << "input_blob is nullptr,error";
|
||||
return -1;
|
||||
}
|
||||
uint64_t log_id = input_blob->GetLogId();
|
||||
VLOG(2) << "(logid=" << log_id << ") Get precedent op name: " << pre_name;
|
||||
|
||||
GeneralBlob *output_blob = mutable_data<GeneralBlob>();
|
||||
if (!output_blob) {
|
||||
LOG(ERROR) << "output_blob is nullptr,error";
|
||||
return -1;
|
||||
}
|
||||
output_blob->SetLogId(log_id);
|
||||
|
||||
if (!input_blob) {
|
||||
LOG(ERROR) << "(logid=" << log_id
|
||||
<< ") Failed mutable depended argument, op:" << pre_name;
|
||||
return -1;
|
||||
}
|
||||
|
||||
const TensorVector *in = &input_blob->tensor_vector;
|
||||
TensorVector *out = &output_blob->tensor_vector;
|
||||
|
||||
int batch_size = input_blob->_batch_size;
|
||||
output_blob->_batch_size = batch_size;
|
||||
VLOG(2) << "(logid=" << log_id << ") infer batch size: " << batch_size;
|
||||
|
||||
Timer timeline;
|
||||
int64_t start = timeline.TimeStampUS();
|
||||
timeline.Start();
|
||||
|
||||
// only support string type
|
||||
|
||||
char *total_input_ptr = static_cast<char *>(in->at(0).data.data());
|
||||
std::string base64str = total_input_ptr;
|
||||
|
||||
cv::Mat img = Base2Mat(base64str);
|
||||
|
||||
// RGB2BGR
|
||||
cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
|
||||
|
||||
// Resize
|
||||
cv::Mat resize_img;
|
||||
resize_op_.Run(img, resize_img, resize_short_size_);
|
||||
|
||||
// CenterCrop
|
||||
crop_op_.Run(resize_img, crop_size_);
|
||||
|
||||
// Normalize
|
||||
normalize_op_.Run(&resize_img, mean_, scale_, is_scale_);
|
||||
|
||||
// Permute
|
||||
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
|
||||
permute_op_.Run(&resize_img, input.data());
|
||||
float maxValue = *max_element(input.begin(), input.end());
|
||||
float minValue = *min_element(input.begin(), input.end());
|
||||
|
||||
TensorVector *real_in = new TensorVector();
|
||||
if (!real_in) {
|
||||
LOG(ERROR) << "real_in is nullptr,error";
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::vector<int> input_shape;
|
||||
int in_num = 0;
|
||||
void *databuf_data = NULL;
|
||||
char *databuf_char = NULL;
|
||||
size_t databuf_size = 0;
|
||||
|
||||
input_shape = {1, 3, resize_img.rows, resize_img.cols};
|
||||
in_num = std::accumulate(input_shape.begin(), input_shape.end(), 1,
|
||||
std::multiplies<int>());
|
||||
|
||||
databuf_size = in_num * sizeof(float);
|
||||
databuf_data = MempoolWrapper::instance().malloc(databuf_size);
|
||||
if (!databuf_data) {
|
||||
LOG(ERROR) << "Malloc failed, size: " << databuf_size;
|
||||
return -1;
|
||||
}
|
||||
|
||||
memcpy(databuf_data, input.data(), databuf_size);
|
||||
databuf_char = reinterpret_cast<char *>(databuf_data);
|
||||
paddle::PaddleBuf paddleBuf(databuf_char, databuf_size);
|
||||
paddle::PaddleTensor tensor_in;
|
||||
tensor_in.name = in->at(0).name;
|
||||
tensor_in.dtype = paddle::PaddleDType::FLOAT32;
|
||||
tensor_in.shape = {1, 3, resize_img.rows, resize_img.cols};
|
||||
tensor_in.lod = in->at(0).lod;
|
||||
tensor_in.data = paddleBuf;
|
||||
real_in->push_back(tensor_in);
|
||||
|
||||
if (InferManager::instance().infer(engine_name().c_str(), real_in, out,
|
||||
batch_size)) {
|
||||
LOG(ERROR) << "(logid=" << log_id
|
||||
<< ") Failed do infer in fluid model: " << engine_name().c_str();
|
||||
return -1;
|
||||
}
|
||||
|
||||
int64_t end = timeline.TimeStampUS();
|
||||
CopyBlobInfo(input_blob, output_blob);
|
||||
AddBlobInfo(output_blob, start);
|
||||
AddBlobInfo(output_blob, end);
|
||||
return 0;
|
||||
}
|
||||
|
||||
cv::Mat GeneralClasOp::Base2Mat(std::string &base64_data) {
|
||||
cv::Mat img;
|
||||
std::string s_mat;
|
||||
s_mat = base64Decode(base64_data.data(), base64_data.size());
|
||||
std::vector<char> base64_img(s_mat.begin(), s_mat.end());
|
||||
img = cv::imdecode(base64_img, cv::IMREAD_COLOR); // CV_LOAD_IMAGE_COLOR
|
||||
return img;
|
||||
}
|
||||
|
||||
std::string GeneralClasOp::base64Decode(const char *Data, int DataByte) {
|
||||
const char DecodeTable[] = {
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
62, // '+'
|
||||
0, 0, 0,
|
||||
63, // '/'
|
||||
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
|
||||
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z'
|
||||
0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
|
||||
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z'
|
||||
};
|
||||
|
||||
std::string strDecode;
|
||||
int nValue;
|
||||
int i = 0;
|
||||
while (i < DataByte) {
|
||||
if (*Data != '\r' && *Data != '\n') {
|
||||
nValue = DecodeTable[*Data++] << 18;
|
||||
nValue += DecodeTable[*Data++] << 12;
|
||||
strDecode += (nValue & 0x00FF0000) >> 16;
|
||||
if (*Data != '=') {
|
||||
nValue += DecodeTable[*Data++] << 6;
|
||||
strDecode += (nValue & 0x0000FF00) >> 8;
|
||||
if (*Data != '=') {
|
||||
nValue += DecodeTable[*Data++];
|
||||
strDecode += nValue & 0x000000FF;
|
||||
}
|
||||
}
|
||||
i += 4;
|
||||
} else // 回车换行,跳过
|
||||
{
|
||||
Data++;
|
||||
i++;
|
||||
}
|
||||
}
|
||||
return strDecode;
|
||||
}
|
||||
DEFINE_OP(GeneralClasOp);
|
||||
} // namespace serving
|
||||
} // namespace paddle_serving
|
||||
} // namespace baidu
|
|
@ -0,0 +1,70 @@
|
|||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "core/general-server/general_model_service.pb.h"
|
||||
#include "core/general-server/op/general_infer_helper.h"
|
||||
#include "core/predictor/tools/pp_shitu_tools/preprocess_op.h"
|
||||
#include "paddle_inference_api.h" // NOLINT
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "opencv2/core.hpp"
|
||||
#include "opencv2/imgcodecs.hpp"
|
||||
#include "opencv2/imgproc.hpp"
|
||||
#include <chrono>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <vector>
|
||||
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <numeric>
|
||||
|
||||
namespace baidu {
|
||||
namespace paddle_serving {
|
||||
namespace serving {
|
||||
|
||||
class GeneralClasOp
|
||||
: public baidu::paddle_serving::predictor::OpWithChannel<GeneralBlob> {
|
||||
public:
|
||||
typedef std::vector<paddle::PaddleTensor> TensorVector;
|
||||
|
||||
DECLARE_OP(GeneralClasOp);
|
||||
|
||||
int inference();
|
||||
|
||||
private:
|
||||
// clas preprocess
|
||||
std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
|
||||
std::vector<float> scale_ = {0.229f, 0.224f, 0.225f};
|
||||
bool is_scale_ = true;
|
||||
|
||||
int resize_short_size_ = 256;
|
||||
int crop_size_ = 224;
|
||||
|
||||
PaddleClas::ResizeImg resize_op_;
|
||||
PaddleClas::Normalize normalize_op_;
|
||||
PaddleClas::Permute permute_op_;
|
||||
PaddleClas::CenterCropImg crop_op_;
|
||||
|
||||
// read pics
|
||||
cv::Mat Base2Mat(std::string &base64_data);
|
||||
std::string base64Decode(const char *Data, int DataByte);
|
||||
};
|
||||
|
||||
} // namespace serving
|
||||
} // namespace paddle_serving
|
||||
} // namespace baidu
|
|
@ -0,0 +1,149 @@
|
|||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "opencv2/core.hpp"
|
||||
#include "opencv2/imgcodecs.hpp"
|
||||
#include "opencv2/imgproc.hpp"
|
||||
#include "paddle_api.h"
|
||||
#include "paddle_inference_api.h"
|
||||
#include <chrono>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <vector>
|
||||
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <math.h>
|
||||
#include <numeric>
|
||||
|
||||
#include "preprocess_op.h"
|
||||
|
||||
namespace Feature {
|
||||
|
||||
void Permute::Run(const cv::Mat *im, float *data) {
|
||||
int rh = im->rows;
|
||||
int rw = im->cols;
|
||||
int rc = im->channels();
|
||||
for (int i = 0; i < rc; ++i) {
|
||||
cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i);
|
||||
}
|
||||
}
|
||||
|
||||
void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
|
||||
const std::vector<float> &std, float scale) {
|
||||
(*im).convertTo(*im, CV_32FC3, scale);
|
||||
for (int h = 0; h < im->rows; h++) {
|
||||
for (int w = 0; w < im->cols; w++) {
|
||||
im->at<cv::Vec3f>(h, w)[0] =
|
||||
(im->at<cv::Vec3f>(h, w)[0] - mean[0]) / std[0];
|
||||
im->at<cv::Vec3f>(h, w)[1] =
|
||||
(im->at<cv::Vec3f>(h, w)[1] - mean[1]) / std[1];
|
||||
im->at<cv::Vec3f>(h, w)[2] =
|
||||
(im->at<cv::Vec3f>(h, w)[2] - mean[2]) / std[2];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CenterCropImg::Run(cv::Mat &img, const int crop_size) {
|
||||
int resize_w = img.cols;
|
||||
int resize_h = img.rows;
|
||||
int w_start = int((resize_w - crop_size) / 2);
|
||||
int h_start = int((resize_h - crop_size) / 2);
|
||||
cv::Rect rect(w_start, h_start, crop_size, crop_size);
|
||||
img = img(rect);
|
||||
}
|
||||
|
||||
void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
|
||||
int resize_short_size, int size) {
|
||||
int resize_h = 0;
|
||||
int resize_w = 0;
|
||||
if (size > 0) {
|
||||
resize_h = size;
|
||||
resize_w = size;
|
||||
} else {
|
||||
int w = img.cols;
|
||||
int h = img.rows;
|
||||
|
||||
float ratio = 1.f;
|
||||
if (h < w) {
|
||||
ratio = float(resize_short_size) / float(h);
|
||||
} else {
|
||||
ratio = float(resize_short_size) / float(w);
|
||||
}
|
||||
resize_h = round(float(h) * ratio);
|
||||
resize_w = round(float(w) * ratio);
|
||||
}
|
||||
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
|
||||
}
|
||||
|
||||
} // namespace Feature
|
||||
|
||||
namespace PaddleClas {
|
||||
void Permute::Run(const cv::Mat *im, float *data) {
|
||||
int rh = im->rows;
|
||||
int rw = im->cols;
|
||||
int rc = im->channels();
|
||||
for (int i = 0; i < rc; ++i) {
|
||||
cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i);
|
||||
}
|
||||
}
|
||||
|
||||
void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
|
||||
const std::vector<float> &scale, const bool is_scale) {
|
||||
double e = 1.0;
|
||||
if (is_scale) {
|
||||
e /= 255.0;
|
||||
}
|
||||
(*im).convertTo(*im, CV_32FC3, e);
|
||||
for (int h = 0; h < im->rows; h++) {
|
||||
for (int w = 0; w < im->cols; w++) {
|
||||
im->at<cv::Vec3f>(h, w)[0] =
|
||||
(im->at<cv::Vec3f>(h, w)[0] - mean[0]) / scale[0];
|
||||
im->at<cv::Vec3f>(h, w)[1] =
|
||||
(im->at<cv::Vec3f>(h, w)[1] - mean[1]) / scale[1];
|
||||
im->at<cv::Vec3f>(h, w)[2] =
|
||||
(im->at<cv::Vec3f>(h, w)[2] - mean[2]) / scale[2];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CenterCropImg::Run(cv::Mat &img, const int crop_size) {
|
||||
int resize_w = img.cols;
|
||||
int resize_h = img.rows;
|
||||
int w_start = int((resize_w - crop_size) / 2);
|
||||
int h_start = int((resize_h - crop_size) / 2);
|
||||
cv::Rect rect(w_start, h_start, crop_size, crop_size);
|
||||
img = img(rect);
|
||||
}
|
||||
|
||||
void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
|
||||
int resize_short_size) {
|
||||
int w = img.cols;
|
||||
int h = img.rows;
|
||||
|
||||
float ratio = 1.f;
|
||||
if (h < w) {
|
||||
ratio = float(resize_short_size) / float(h);
|
||||
} else {
|
||||
ratio = float(resize_short_size) / float(w);
|
||||
}
|
||||
|
||||
int resize_h = round(float(h) * ratio);
|
||||
int resize_w = round(float(w) * ratio);
|
||||
|
||||
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
|
||||
}
|
||||
|
||||
} // namespace PaddleClas
|
|
@ -0,0 +1,81 @@
|
|||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "opencv2/core.hpp"
|
||||
#include "opencv2/imgcodecs.hpp"
|
||||
#include "opencv2/imgproc.hpp"
|
||||
#include <chrono>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <vector>
|
||||
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <numeric>
|
||||
|
||||
namespace Feature {
|
||||
|
||||
class Normalize {
|
||||
public:
|
||||
virtual void Run(cv::Mat *im, const std::vector<float> &mean,
|
||||
const std::vector<float> &std, float scale);
|
||||
};
|
||||
|
||||
// RGB -> CHW
|
||||
class Permute {
|
||||
public:
|
||||
virtual void Run(const cv::Mat *im, float *data);
|
||||
};
|
||||
|
||||
class CenterCropImg {
|
||||
public:
|
||||
virtual void Run(cv::Mat &im, const int crop_size = 224);
|
||||
};
|
||||
|
||||
class ResizeImg {
|
||||
public:
|
||||
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len,
|
||||
int size = 0);
|
||||
};
|
||||
|
||||
} // namespace Feature
|
||||
|
||||
namespace PaddleClas {
|
||||
|
||||
class Normalize {
|
||||
public:
|
||||
virtual void Run(cv::Mat *im, const std::vector<float> &mean,
|
||||
const std::vector<float> &scale, const bool is_scale = true);
|
||||
};
|
||||
|
||||
// RGB -> CHW
|
||||
class Permute {
|
||||
public:
|
||||
virtual void Run(const cv::Mat *im, float *data);
|
||||
};
|
||||
|
||||
class CenterCropImg {
|
||||
public:
|
||||
virtual void Run(cv::Mat &im, const int crop_size = 224);
|
||||
};
|
||||
|
||||
class ResizeImg {
|
||||
public:
|
||||
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len);
|
||||
};
|
||||
|
||||
} // namespace PaddleClas
|
|
@ -0,0 +1,32 @@
|
|||
feed_var {
|
||||
name: "x"
|
||||
alias_name: "x"
|
||||
is_lod_tensor: false
|
||||
feed_type: 1
|
||||
shape: 3
|
||||
shape: 224
|
||||
shape: 224
|
||||
}
|
||||
feed_var {
|
||||
name: "boxes"
|
||||
alias_name: "boxes"
|
||||
is_lod_tensor: false
|
||||
feed_type: 1
|
||||
shape: 6
|
||||
}
|
||||
fetch_var {
|
||||
name: "save_infer_model/scale_0.tmp_1"
|
||||
alias_name: "features"
|
||||
is_lod_tensor: false
|
||||
fetch_type: 1
|
||||
shape: 512
|
||||
}
|
||||
fetch_var {
|
||||
name: "boxes"
|
||||
alias_name: "boxes"
|
||||
is_lod_tensor: false
|
||||
fetch_type: 1
|
||||
shape: 6
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
feed_var {
|
||||
name: "x"
|
||||
alias_name: "x"
|
||||
is_lod_tensor: false
|
||||
feed_type: 1
|
||||
shape: 3
|
||||
shape: 224
|
||||
shape: 224
|
||||
}
|
||||
feed_var {
|
||||
name: "boxes"
|
||||
alias_name: "boxes"
|
||||
is_lod_tensor: false
|
||||
feed_type: 1
|
||||
shape: 6
|
||||
}
|
||||
fetch_var {
|
||||
name: "save_infer_model/scale_0.tmp_1"
|
||||
alias_name: "features"
|
||||
is_lod_tensor: false
|
||||
fetch_type: 1
|
||||
shape: 512
|
||||
}
|
||||
fetch_var {
|
||||
name: "boxes"
|
||||
alias_name: "boxes"
|
||||
is_lod_tensor: false
|
||||
fetch_type: 1
|
||||
shape: 6
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
feed_var {
|
||||
name: "im_shape"
|
||||
alias_name: "im_shape"
|
||||
is_lod_tensor: false
|
||||
feed_type: 1
|
||||
shape: 2
|
||||
}
|
||||
feed_var {
|
||||
name: "image"
|
||||
alias_name: "image"
|
||||
is_lod_tensor: false
|
||||
feed_type: 7
|
||||
shape: -1
|
||||
shape: -1
|
||||
shape: 3
|
||||
}
|
||||
fetch_var {
|
||||
name: "save_infer_model/scale_0.tmp_1"
|
||||
alias_name: "save_infer_model/scale_0.tmp_1"
|
||||
is_lod_tensor: true
|
||||
fetch_type: 1
|
||||
shape: -1
|
||||
}
|
||||
fetch_var {
|
||||
name: "save_infer_model/scale_1.tmp_1"
|
||||
alias_name: "save_infer_model/scale_1.tmp_1"
|
||||
is_lod_tensor: false
|
||||
fetch_type: 2
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
feed_var {
|
||||
name: "im_shape"
|
||||
alias_name: "im_shape"
|
||||
is_lod_tensor: false
|
||||
feed_type: 1
|
||||
shape: 2
|
||||
}
|
||||
feed_var {
|
||||
name: "image"
|
||||
alias_name: "image"
|
||||
is_lod_tensor: false
|
||||
feed_type: 7
|
||||
shape: -1
|
||||
shape: -1
|
||||
shape: 3
|
||||
}
|
||||
fetch_var {
|
||||
name: "save_infer_model/scale_0.tmp_1"
|
||||
alias_name: "save_infer_model/scale_0.tmp_1"
|
||||
is_lod_tensor: true
|
||||
fetch_type: 1
|
||||
shape: -1
|
||||
}
|
||||
fetch_var {
|
||||
name: "save_infer_model/scale_1.tmp_1"
|
||||
alias_name: "save_infer_model/scale_1.tmp_1"
|
||||
is_lod_tensor: false
|
||||
fetch_type: 2
|
||||
}
|
|
@ -12,7 +12,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
from paddle_serving_client import Client
|
||||
|
@ -22,101 +21,14 @@ import faiss
|
|||
import os
|
||||
import pickle
|
||||
|
||||
|
||||
class MainbodyDetect():
|
||||
"""
|
||||
pp-shitu mainbody detect.
|
||||
include preprocess, process, postprocess
|
||||
return detect results
|
||||
Attention: Postprocess include num limit and box filter; no nms
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.preprocess = DetectionSequential([
|
||||
DetectionFile2Image(), DetectionNormalize(
|
||||
[0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True),
|
||||
DetectionResize(
|
||||
(640, 640), False, interpolation=2), DetectionTranspose(
|
||||
(2, 0, 1))
|
||||
])
|
||||
|
||||
self.client = Client()
|
||||
self.client.load_client_config(
|
||||
"../../models/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_client/serving_client_conf.prototxt"
|
||||
)
|
||||
self.client.connect(['127.0.0.1:9293'])
|
||||
|
||||
self.max_det_result = 5
|
||||
self.conf_threshold = 0.2
|
||||
|
||||
def predict(self, imgpath):
|
||||
im, im_info = self.preprocess(imgpath)
|
||||
im_shape = np.array(im.shape[1:]).reshape(-1)
|
||||
scale_factor = np.array(list(im_info['scale_factor'])).reshape(-1)
|
||||
|
||||
fetch_map = self.client.predict(
|
||||
feed={
|
||||
"image": im,
|
||||
"im_shape": im_shape,
|
||||
"scale_factor": scale_factor,
|
||||
},
|
||||
fetch=["save_infer_model/scale_0.tmp_1"],
|
||||
batch=False)
|
||||
return self.postprocess(fetch_map, imgpath)
|
||||
|
||||
def postprocess(self, fetch_map, imgpath):
|
||||
#1. get top max_det_result
|
||||
det_results = fetch_map["save_infer_model/scale_0.tmp_1"]
|
||||
if len(det_results) > self.max_det_result:
|
||||
boxes_reserved = fetch_map[
|
||||
"save_infer_model/scale_0.tmp_1"][:self.max_det_result]
|
||||
else:
|
||||
boxes_reserved = det_results
|
||||
|
||||
#2. do conf threshold
|
||||
boxes_list = []
|
||||
for i in range(boxes_reserved.shape[0]):
|
||||
if (boxes_reserved[i, 1]) > self.conf_threshold:
|
||||
boxes_list.append(boxes_reserved[i, :])
|
||||
|
||||
#3. add origin image box
|
||||
origin_img = cv2.imread(imgpath)
|
||||
boxes_list.append(
|
||||
np.array([0, 1.0, 0, 0, origin_img.shape[1], origin_img.shape[0]]))
|
||||
return np.array(boxes_list)
|
||||
|
||||
|
||||
class ObjectRecognition():
|
||||
"""
|
||||
pp-shitu object recognion for all objects detected by MainbodyDetect.
|
||||
include preprocess, process, postprocess
|
||||
preprocess include preprocess for each image and batching.
|
||||
Batch process
|
||||
postprocess include retrieval and nms
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.client = Client()
|
||||
self.client.load_client_config(
|
||||
"../../models/general_PPLCNet_x2_5_lite_v1.0_client/serving_client_conf.prototxt"
|
||||
)
|
||||
self.client.connect(["127.0.0.1:9294"])
|
||||
|
||||
self.seq = Sequential([
|
||||
BGR2RGB(), Resize((224, 224)), Div(255),
|
||||
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225],
|
||||
False), Transpose((2, 0, 1))
|
||||
])
|
||||
|
||||
self.searcher, self.id_map = self.init_index()
|
||||
|
||||
self.rec_nms_thresold = 0.05
|
||||
self.rec_score_thres = 0.5
|
||||
self.feature_normalize = True
|
||||
self.return_k = 1
|
||||
|
||||
def init_index(self):
|
||||
rec_nms_thresold = 0.05
|
||||
rec_score_thres = 0.5
|
||||
feature_normalize = True
|
||||
return_k = 1
|
||||
index_dir = "../../drink_dataset_v1.0/index"
|
||||
|
||||
|
||||
def init_index(index_dir):
|
||||
assert os.path.exists(os.path.join(
|
||||
index_dir, "vector.index")), "vector.index not found ..."
|
||||
assert os.path.exists(os.path.join(
|
||||
|
@ -128,42 +40,11 @@ class ObjectRecognition():
|
|||
id_map = pickle.load(fd)
|
||||
return searcher, id_map
|
||||
|
||||
def predict(self, det_boxes, imgpath):
|
||||
#1. preprocess
|
||||
batch_imgs = []
|
||||
origin_img = cv2.imread(imgpath)
|
||||
for i in range(det_boxes.shape[0]):
|
||||
box = det_boxes[i]
|
||||
x1, y1, x2, y2 = [int(x) for x in box[2:]]
|
||||
cropped_img = origin_img[y1:y2, x1:x2, :].copy()
|
||||
tmp = self.seq(cropped_img)
|
||||
batch_imgs.append(tmp)
|
||||
batch_imgs = np.array(batch_imgs)
|
||||
|
||||
#2. process
|
||||
fetch_map = self.client.predict(
|
||||
feed={"x": batch_imgs}, fetch=["features"], batch=True)
|
||||
batch_features = fetch_map["features"]
|
||||
|
||||
#3. postprocess
|
||||
if self.feature_normalize:
|
||||
feas_norm = np.sqrt(
|
||||
np.sum(np.square(batch_features), axis=1, keepdims=True))
|
||||
batch_features = np.divide(batch_features, feas_norm)
|
||||
scores, docs = self.searcher.search(batch_features, self.return_k)
|
||||
|
||||
results = []
|
||||
for i in range(scores.shape[0]):
|
||||
pred = {}
|
||||
if scores[i][0] >= self.rec_score_thres:
|
||||
pred["bbox"] = [int(x) for x in det_boxes[i, 2:]]
|
||||
pred["rec_docs"] = self.id_map[docs[i][0]].split()[1]
|
||||
pred["rec_scores"] = scores[i][0]
|
||||
results.append(pred)
|
||||
return self.nms_to_rec_results(results)
|
||||
|
||||
def nms_to_rec_results(self, results):
|
||||
#get box
|
||||
def nms_to_rec_results(results, thresh=0.1):
|
||||
filtered_results = []
|
||||
|
||||
x1 = np.array([r["bbox"][0] for r in results]).astype("float32")
|
||||
y1 = np.array([r["bbox"][1] for r in results]).astype("float32")
|
||||
x2 = np.array([r["bbox"][2] for r in results]).astype("float32")
|
||||
|
@ -183,20 +64,58 @@ class ObjectRecognition():
|
|||
h = np.maximum(0.0, yy2 - yy1 + 1)
|
||||
inter = w * h
|
||||
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
||||
inds = np.where(ovr <= self.rec_nms_thresold)[0]
|
||||
inds = np.where(ovr <= thresh)[0]
|
||||
order = order[inds + 1]
|
||||
filtered_results.append(results[i])
|
||||
return filtered_results
|
||||
|
||||
|
||||
def postprocess(fetch_dict, feature_normalize, det_boxes, searcher, id_map,
|
||||
return_k, rec_score_thres, rec_nms_thresold):
|
||||
batch_features = fetch_dict["features"]
|
||||
|
||||
#do feature norm
|
||||
if feature_normalize:
|
||||
feas_norm = np.sqrt(
|
||||
np.sum(np.square(batch_features), axis=1, keepdims=True))
|
||||
batch_features = np.divide(batch_features, feas_norm)
|
||||
|
||||
scores, docs = searcher.search(batch_features, return_k)
|
||||
|
||||
results = []
|
||||
for i in range(scores.shape[0]):
|
||||
pred = {}
|
||||
if scores[i][0] >= rec_score_thres:
|
||||
pred["bbox"] = [int(x) for x in det_boxes[i, 2:]]
|
||||
pred["rec_docs"] = id_map[docs[i][0]].split()[1]
|
||||
pred["rec_scores"] = scores[i][0]
|
||||
results.append(pred)
|
||||
|
||||
#do nms
|
||||
results = nms_to_rec_results(results, rec_nms_thresold)
|
||||
return results
|
||||
|
||||
|
||||
#do client
|
||||
if __name__ == "__main__":
|
||||
det = MainbodyDetect()
|
||||
rec = ObjectRecognition()
|
||||
client = Client()
|
||||
client.load_client_config([
|
||||
"../../models/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_client",
|
||||
"../../models/general_PPLCNet_x2_5_lite_v1.0_client"
|
||||
])
|
||||
client.connect(['127.0.0.1:9400'])
|
||||
|
||||
#1. get det_results
|
||||
imgpath = "../../drink_dataset_v1.0/test_images/001.jpeg"
|
||||
det_results = det.predict(imgpath)
|
||||
im = cv2.imread("../../drink_dataset_v1.0/test_images/001.jpeg")
|
||||
im_shape = np.array(im.shape[:2]).reshape(-1)
|
||||
fetch_map = client.predict(
|
||||
feed={"image": im,
|
||||
"im_shape": im_shape},
|
||||
fetch=["features", "boxes"],
|
||||
batch=False)
|
||||
|
||||
#2. get rec_results
|
||||
rec_results = rec.predict(det_results, imgpath)
|
||||
print(rec_results)
|
||||
#add retrieval procedure
|
||||
det_boxes = fetch_map["boxes"]
|
||||
searcher, id_map = init_index(index_dir)
|
||||
results = postprocess(fetch_map, feature_normalize, det_boxes, searcher,
|
||||
id_map, return_k, rec_score_thres, rec_nms_thresold)
|
||||
print(results)
|
||||
|
|
|
@ -12,16 +12,20 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from paddle_serving_client import Client
|
||||
|
||||
#app
|
||||
from paddle_serving_app.reader import Sequential, URL2Image, Resize
|
||||
from paddle_serving_app.reader import CenterCrop, RGB2BGR, Transpose, Div, Normalize
|
||||
import base64
|
||||
import time
|
||||
|
||||
from paddle_serving_client import Client
|
||||
|
||||
|
||||
def bytes_to_base64(image: bytes) -> str:
|
||||
"""encode bytes into base64 string
|
||||
"""
|
||||
return base64.b64encode(image).decode('utf8')
|
||||
|
||||
|
||||
client = Client()
|
||||
client.load_client_config("./ResNet50_vd_serving/serving_server_conf.prototxt")
|
||||
client.load_client_config("./ResNet50_client/serving_client_conf.prototxt")
|
||||
client.connect(["127.0.0.1:9292"])
|
||||
|
||||
label_dict = {}
|
||||
|
@ -31,22 +35,17 @@ with open("imagenet.label") as fin:
|
|||
label_dict[label_idx] = line.strip()
|
||||
label_idx += 1
|
||||
|
||||
#preprocess
|
||||
seq = Sequential([
|
||||
URL2Image(), Resize(256), CenterCrop(224), RGB2BGR(), Transpose((2, 0, 1)),
|
||||
Div(255), Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True)
|
||||
])
|
||||
|
||||
start = time.time()
|
||||
image_file = "https://paddle-serving.bj.bcebos.com/imagenet-example/daisy.jpg"
|
||||
image_file = "./daisy.jpg"
|
||||
for i in range(1):
|
||||
img = seq(image_file)
|
||||
fetch_map = client.predict(
|
||||
feed={"inputs": img}, fetch=["prediction"], batch=False)
|
||||
|
||||
prob = max(fetch_map["prediction"][0])
|
||||
label = label_dict[fetch_map["prediction"][0].tolist().index(prob)].strip(
|
||||
).replace(",", "")
|
||||
start = time.time()
|
||||
with open(image_file, 'rb') as img_file:
|
||||
image_data = img_file.read()
|
||||
image = bytes_to_base64(image_data)
|
||||
fetch_dict = client.predict(
|
||||
feed={"inputs": image}, fetch=["prediction"], batch=False)
|
||||
prob = max(fetch_dict["prediction"][0])
|
||||
label = label_dict[fetch_dict["prediction"][0].tolist().index(
|
||||
prob)].strip().replace(",", "")
|
||||
print("prediction: {}, probability: {}".format(label, prob))
|
||||
end = time.time()
|
||||
print(end - start)
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 3.8 MiB |
Binary file not shown.
After Width: | Height: | Size: 5.1 MiB |
|
@ -51,10 +51,10 @@ Optimizer:
|
|||
one_dim_param_no_weight_decay: True
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 1e-4
|
||||
eta_min: 2e-6
|
||||
learning_rate: 5e-5
|
||||
eta_min: 1e-6
|
||||
warmup_epoch: 5
|
||||
warmup_start_lr: 2e-7
|
||||
warmup_start_lr: 1e-7
|
||||
|
||||
|
||||
# data loader for train and eval
|
||||
|
|
|
@ -371,6 +371,11 @@ def run(dataloader,
|
|||
"Except RuntimeError when reading data from dataloader, try to read once again..."
|
||||
)
|
||||
continue
|
||||
except IndexError:
|
||||
logger.warning(
|
||||
"Except IndexError when reading data from dataloader, try to read once again..."
|
||||
)
|
||||
continue
|
||||
idx += 1
|
||||
# ignore the warmup iters
|
||||
if idx == 5:
|
||||
|
|
|
@ -112,4 +112,5 @@ bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/MobileNetV3/Mo
|
|||
- [test_lite_arm_cpu_cpp 使用](docs/test_lite_arm_cpu_cpp.md): 测试基于Paddle-Lite的ARM CPU端c++预测部署功能.
|
||||
- [test_paddle2onnx 使用](docs/test_paddle2onnx.md):测试Paddle2ONNX的模型转化功能,并验证正确性。
|
||||
- [test_serving_infer_python 使用](docs/test_serving_infer_python.md):测试python serving功能。
|
||||
- [test_serving_infer_cpp 使用](docs/test_serving_infer_cpp.md):测试cpp serving功能。
|
||||
- [test_train_fleet_inference_python 使用](./docs/test_train_fleet_inference_python.md):测试基于Python的多机多卡训练与推理等基本功能。
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
===========================serving_params===========================
|
||||
model_name:MobileNetV3_large_x1_0
|
||||
python:python3.7
|
||||
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV3_large_x1_0_infer.tar
|
||||
trans_model:-m paddle_serving_client.convert
|
||||
--dirname:./deploy/paddleserving/MobileNetV3_large_x1_0_infer/
|
||||
--model_filename:inference.pdmodel
|
||||
--params_filename:inference.pdiparams
|
||||
--serving_server:./deploy/paddleserving/MobileNetV3_large_x1_0_serving/
|
||||
--serving_client:./deploy/paddleserving/MobileNetV3_large_x1_0_client/
|
||||
serving_dir:./deploy/paddleserving
|
||||
web_service:null
|
||||
--use_gpu:0|null
|
||||
pipline:test_cpp_serving_client.py
|
|
@ -0,0 +1,18 @@
|
|||
===========================serving_params===========================
|
||||
model_name:PPShiTu
|
||||
python:python3.7
|
||||
cls_inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/general_PPLCNet_x2_5_lite_v1.0_infer.tar
|
||||
det_inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer.tar
|
||||
trans_model:-m paddle_serving_client.convert
|
||||
--dirname:./models/general_PPLCNet_x2_5_lite_v1.0_infer/
|
||||
--dirname:./models/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer/
|
||||
--model_filename:inference.pdmodel
|
||||
--params_filename:inference.pdiparams
|
||||
--serving_server:./models/general_PPLCNet_x2_5_lite_v1.0_serving/
|
||||
--serving_client:./models/general_PPLCNet_x2_5_lite_v1.0_client/
|
||||
--serving_server:./models/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_serving/
|
||||
--serving_client:./models/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_client/
|
||||
serving_dir:./paddleserving/recognition
|
||||
web_service:null
|
||||
--use_gpu:0|null
|
||||
pipline:test_cpp_serving_client.py
|
|
@ -0,0 +1,14 @@
|
|||
===========================serving_params===========================
|
||||
model_name:PPHGNet_small
|
||||
python:python3.7
|
||||
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPHGNet_small_infer.tar
|
||||
trans_model:-m paddle_serving_client.convert
|
||||
--dirname:./deploy/paddleserving/PPHGNet_small_infer/
|
||||
--model_filename:inference.pdmodel
|
||||
--params_filename:inference.pdiparams
|
||||
--serving_server:./deploy/paddleserving/PPHGNet_small_serving/
|
||||
--serving_client:./deploy/paddleserving/PPHGNet_small_client/
|
||||
serving_dir:./deploy/paddleserving
|
||||
web_service:classification_web_service.py
|
||||
--use_gpu:0|null
|
||||
pipline:pipeline_http_client.py
|
|
@ -0,0 +1,14 @@
|
|||
===========================serving_params===========================
|
||||
model_name:PPHGNet_tiny
|
||||
python:python3.7
|
||||
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPHGNet_tiny_infer.tar
|
||||
trans_model:-m paddle_serving_client.convert
|
||||
--dirname:./deploy/paddleserving/PPHGNet_tiny_infer/
|
||||
--model_filename:inference.pdmodel
|
||||
--params_filename:inference.pdiparams
|
||||
--serving_server:./deploy/paddleserving/PPHGNet_tiny_serving/
|
||||
--serving_client:./deploy/paddleserving/PPHGNet_tiny_client/
|
||||
serving_dir:./deploy/paddleserving
|
||||
web_service:classification_web_service.py
|
||||
--use_gpu:0|null
|
||||
pipline:pipeline_http_client.py
|
|
@ -0,0 +1,14 @@
|
|||
===========================serving_params===========================
|
||||
model_name:PPLCNet_x0_25
|
||||
python:python3.7
|
||||
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x0_25_infer.tar
|
||||
trans_model:-m paddle_serving_client.convert
|
||||
--dirname:./deploy/paddleserving/PPLCNet_x0_25_infer/
|
||||
--model_filename:inference.pdmodel
|
||||
--params_filename:inference.pdiparams
|
||||
--serving_server:./deploy/paddleserving/PPLCNet_x0_25_serving/
|
||||
--serving_client:./deploy/paddleserving/PPLCNet_x0_25_client/
|
||||
serving_dir:./deploy/paddleserving
|
||||
web_service:null
|
||||
--use_gpu:0|null
|
||||
pipline:test_cpp_serving_client.py
|
|
@ -0,0 +1,14 @@
|
|||
===========================serving_params===========================
|
||||
model_name:PPLCNet_x0_35
|
||||
python:python3.7
|
||||
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x0_35_infer.tar
|
||||
trans_model:-m paddle_serving_client.convert
|
||||
--dirname:./deploy/paddleserving/PPLCNet_x0_35_infer/
|
||||
--model_filename:inference.pdmodel
|
||||
--params_filename:inference.pdiparams
|
||||
--serving_server:./deploy/paddleserving/PPLCNet_x0_35_serving/
|
||||
--serving_client:./deploy/paddleserving/PPLCNet_x0_35_client/
|
||||
serving_dir:./deploy/paddleserving
|
||||
web_service:null
|
||||
--use_gpu:0|null
|
||||
pipline:test_cpp_serving_client.py
|
|
@ -0,0 +1,14 @@
|
|||
===========================serving_params===========================
|
||||
model_name:PPLCNet_x0_5
|
||||
python:python3.7
|
||||
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x0_5_infer.tar
|
||||
trans_model:-m paddle_serving_client.convert
|
||||
--dirname:./deploy/paddleserving/PPLCNet_x0_5_infer/
|
||||
--model_filename:inference.pdmodel
|
||||
--params_filename:inference.pdiparams
|
||||
--serving_server:./deploy/paddleserving/PPLCNet_x0_5_serving/
|
||||
--serving_client:./deploy/paddleserving/PPLCNet_x0_5_client/
|
||||
serving_dir:./deploy/paddleserving
|
||||
web_service:null
|
||||
--use_gpu:0|null
|
||||
pipline:test_cpp_serving_client.py
|
|
@ -0,0 +1,14 @@
|
|||
===========================serving_params===========================
|
||||
model_name:PPLCNet_x0_75
|
||||
python:python3.7
|
||||
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x0_75_infer.tar
|
||||
trans_model:-m paddle_serving_client.convert
|
||||
--dirname:./deploy/paddleserving/PPLCNet_x0_75_infer/
|
||||
--model_filename:inference.pdmodel
|
||||
--params_filename:inference.pdiparams
|
||||
--serving_server:./deploy/paddleserving/PPLCNet_x0_75_serving/
|
||||
--serving_client:./deploy/paddleserving/PPLCNet_x0_75_client/
|
||||
serving_dir:./deploy/paddleserving
|
||||
web_service:null
|
||||
--use_gpu:0|null
|
||||
pipline:test_cpp_serving_client.py
|
|
@ -0,0 +1,14 @@
|
|||
===========================serving_params===========================
|
||||
model_name:PPLCNet_x1_0
|
||||
python:python3.7
|
||||
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x1_0_infer.tar
|
||||
trans_model:-m paddle_serving_client.convert
|
||||
--dirname:./deploy/paddleserving/PPLCNet_x1_0_infer/
|
||||
--model_filename:inference.pdmodel
|
||||
--params_filename:inference.pdiparams
|
||||
--serving_server:./deploy/paddleserving/PPLCNet_x1_0_serving/
|
||||
--serving_client:./deploy/paddleserving/PPLCNet_x1_0_client/
|
||||
serving_dir:./deploy/paddleserving
|
||||
web_service:null
|
||||
--use_gpu:0|null
|
||||
pipline:test_cpp_serving_client.py
|
|
@ -0,0 +1,14 @@
|
|||
===========================serving_params===========================
|
||||
model_name:PPLCNet_x1_5
|
||||
python:python3.7
|
||||
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x1_5_infer.tar
|
||||
trans_model:-m paddle_serving_client.convert
|
||||
--dirname:./deploy/paddleserving/PPLCNet_x1_5_infer/
|
||||
--model_filename:inference.pdmodel
|
||||
--params_filename:inference.pdiparams
|
||||
--serving_server:./deploy/paddleserving/PPLCNet_x1_5_serving/
|
||||
--serving_client:./deploy/paddleserving/PPLCNet_x1_5_client/
|
||||
serving_dir:./deploy/paddleserving
|
||||
web_service:null
|
||||
--use_gpu:0|null
|
||||
pipline:test_cpp_serving_client.py
|
|
@ -0,0 +1,14 @@
|
|||
===========================serving_params===========================
|
||||
model_name:PPLCNet_x2_0
|
||||
python:python3.7
|
||||
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x2_0_infer.tar
|
||||
trans_model:-m paddle_serving_client.convert
|
||||
--dirname:./deploy/paddleserving/PPLCNet_x2_0_infer/
|
||||
--model_filename:inference.pdmodel
|
||||
--params_filename:inference.pdiparams
|
||||
--serving_server:./deploy/paddleserving/PPLCNet_x2_0_serving/
|
||||
--serving_client:./deploy/paddleserving/PPLCNet_x2_0_client/
|
||||
serving_dir:./deploy/paddleserving
|
||||
web_service:null
|
||||
--use_gpu:0|null
|
||||
pipline:test_cpp_serving_client.py
|
|
@ -0,0 +1,14 @@
|
|||
===========================serving_params===========================
|
||||
model_name:PPLCNet_x2_5
|
||||
python:python3.7
|
||||
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x2_5_infer.tar
|
||||
trans_model:-m paddle_serving_client.convert
|
||||
--dirname:./deploy/paddleserving/PPLCNet_x2_5_infer/
|
||||
--model_filename:inference.pdmodel
|
||||
--params_filename:inference.pdiparams
|
||||
--serving_server:./deploy/paddleserving/PPLCNet_x2_5_serving/
|
||||
--serving_client:./deploy/paddleserving/PPLCNet_x2_5_client/
|
||||
serving_dir:./deploy/paddleserving
|
||||
web_service:null
|
||||
--use_gpu:0|null
|
||||
pipline:test_cpp_serving_client.py
|
|
@ -0,0 +1,14 @@
|
|||
===========================serving_params===========================
|
||||
model_name:PPLCNetV2_base
|
||||
python:python3.7
|
||||
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNetV2_base_infer.tar
|
||||
trans_model:-m paddle_serving_client.convert
|
||||
--dirname:./deploy/paddleserving/PPLCNetV2_base_infer/
|
||||
--model_filename:inference.pdmodel
|
||||
--params_filename:inference.pdiparams
|
||||
--serving_server:./deploy/paddleserving/PPLCNetV2_base_serving/
|
||||
--serving_client:./deploy/paddleserving/PPLCNetV2_base_client/
|
||||
serving_dir:./deploy/paddleserving
|
||||
web_service:null
|
||||
--use_gpu:0|null
|
||||
pipline:test_cpp_serving_client.py
|
|
@ -0,0 +1,14 @@
|
|||
===========================serving_params===========================
|
||||
model_name:ResNet50
|
||||
python:python3.7
|
||||
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet50_infer.tar
|
||||
trans_model:-m paddle_serving_client.convert
|
||||
--dirname:./deploy/paddleserving/ResNet50_infer/
|
||||
--model_filename:inference.pdmodel
|
||||
--params_filename:inference.pdiparams
|
||||
--serving_server:./deploy/paddleserving/ResNet50_serving/
|
||||
--serving_client:./deploy/paddleserving/ResNet50_client/
|
||||
serving_dir:./deploy/paddleserving
|
||||
web_service:null
|
||||
--use_gpu:0|null
|
||||
pipline:test_cpp_serving_client.py
|
|
@ -0,0 +1,14 @@
|
|||
===========================serving_params===========================
|
||||
model_name:ResNet50_vd
|
||||
python:python3.7
|
||||
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet50_vd_infer.tar
|
||||
trans_model:-m paddle_serving_client.convert
|
||||
--dirname:./deploy/paddleserving/ResNet50_vd_infer/
|
||||
--model_filename:inference.pdmodel
|
||||
--params_filename:inference.pdiparams
|
||||
--serving_server:./deploy/paddleserving/ResNet50_vd_serving/
|
||||
--serving_client:./deploy/paddleserving/ResNet50_vd_client/
|
||||
serving_dir:./deploy/paddleserving
|
||||
web_service:null
|
||||
--use_gpu:0|null
|
||||
pipline:test_cpp_serving_client.py
|
|
@ -0,0 +1,14 @@
|
|||
===========================serving_params===========================
|
||||
model_name:SwinTransformer_tiny_patch4_window7_224
|
||||
python:python3.7
|
||||
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/SwinTransformer_tiny_patch4_window7_224_infer.tar
|
||||
trans_model:-m paddle_serving_client.convert
|
||||
--dirname:./deploy/paddleserving/SwinTransformer_tiny_patch4_window7_224_infer/
|
||||
--model_filename:inference.pdmodel
|
||||
--params_filename:inference.pdiparams
|
||||
--serving_server:./deploy/paddleserving/SwinTransformer_tiny_patch4_window7_224_serving/
|
||||
--serving_client:./deploy/paddleserving/SwinTransformer_tiny_patch4_window7_224_client/
|
||||
serving_dir:./deploy/paddleserving
|
||||
web_service:null
|
||||
--use_gpu:0|null
|
||||
pipline:test_cpp_serving_client.py
|
|
@ -0,0 +1,91 @@
|
|||
# Linux GPU/CPU PYTHON 服务化部署测试
|
||||
|
||||
Linux GPU/CPU PYTHON 服务化部署测试的主程序为`test_serving_infer_cpp.sh`,可以测试基于Python的模型服务化部署功能。
|
||||
|
||||
|
||||
## 1. 测试结论汇总
|
||||
|
||||
- 推理相关:
|
||||
|
||||
| 算法名称 | 模型名称 | device_CPU | device_GPU |
|
||||
| :----: | :----: | :----: | :----: |
|
||||
| MobileNetV3 | MobileNetV3_large_x1_0 | 支持 | 支持 |
|
||||
| PP-ShiTu | PPShiTu_general_rec、PPShiTu_mainbody_det | 支持 | 支持 |
|
||||
| PPHGNet | PPHGNet_small | 支持 | 支持 |
|
||||
| PPHGNet | PPHGNet_tiny | 支持 | 支持 |
|
||||
| PPLCNet | PPLCNet_x0_25 | 支持 | 支持 |
|
||||
| PPLCNet | PPLCNet_x0_35 | 支持 | 支持 |
|
||||
| PPLCNet | PPLCNet_x0_5 | 支持 | 支持 |
|
||||
| PPLCNet | PPLCNet_x0_75 | 支持 | 支持 |
|
||||
| PPLCNet | PPLCNet_x1_0 | 支持 | 支持 |
|
||||
| PPLCNet | PPLCNet_x1_5 | 支持 | 支持 |
|
||||
| PPLCNet | PPLCNet_x2_0 | 支持 | 支持 |
|
||||
| PPLCNet | PPLCNet_x2_5 | 支持 | 支持 |
|
||||
| PPLCNetV2 | PPLCNetV2_base | 支持 | 支持 |
|
||||
| ResNet | ResNet50 | 支持 | 支持 |
|
||||
| ResNet | ResNet50_vd | 支持 | 支持 |
|
||||
| SwinTransformer | SwinTransformer_tiny_patch4_window7_224 | 支持 | 支持 |
|
||||
|
||||
|
||||
## 2. 测试流程
|
||||
|
||||
### 2.1 准备数据
|
||||
|
||||
分类模型默认使用`./deploy/paddleserving/daisy.jpg`作为测试输入图片,无需下载
|
||||
识别模型默认使用`drink_dataset_v1.0/test_images/001.jpeg`作为测试输入图片,在**2.2 准备环境**中会下载好。
|
||||
|
||||
### 2.2 准备环境
|
||||
|
||||
|
||||
- 安装PaddlePaddle:如果您已经安装了2.2或者以上版本的paddlepaddle,那么无需运行下面的命令安装paddlepaddle。
|
||||
```shell
|
||||
# 需要安装2.2及以上版本的Paddle
|
||||
# 安装GPU版本的Paddle
|
||||
python3.7 -m pip install paddlepaddle-gpu==2.2.0
|
||||
# 安装CPU版本的Paddle
|
||||
python3.7 -m pip install paddlepaddle==2.2.0
|
||||
```
|
||||
|
||||
- 安装依赖
|
||||
```shell
|
||||
python3.7 -m pip install -r requirements.txt
|
||||
```
|
||||
- 安装 PaddleServing 相关组件,包括serving_client、serving-app,自动编译并安装带自定义OP的 serving_server 包,以及自动下载并解压推理模型
|
||||
```bash
|
||||
bash test_tipc/prepare.sh test_tipc/configs/PPLCNet/PPLCNet_x1_0_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt serving_infer
|
||||
```
|
||||
|
||||
### 2.3 功能测试
|
||||
|
||||
测试方法如下所示,希望测试不同的模型文件,只需更换为自己的参数配置文件,即可完成对应模型的测试。
|
||||
|
||||
```bash
|
||||
bash test_tipc/test_serving_infer_cpp.sh ${your_params_file}
|
||||
```
|
||||
|
||||
以`PPLCNet_x1_0`的`Linux GPU/CPU C++ 服务化部署测试`为例,命令如下所示。
|
||||
|
||||
|
||||
```bash
|
||||
bash test_tipc/test_serving_infer_cpp.sh test_tipc/configs/PPLCNet/PPLCNet_x1_0_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
|
||||
```
|
||||
|
||||
输出结果如下,表示命令运行成功。
|
||||
|
||||
```
|
||||
Run successfully with command - PPLCNet_x1_0 - python3.7 test_cpp_serving_client.py > ../../test_tipc/output/PPLCNet_x1_0/server_infer_cpp_gpu_pipeline_batchsize_1.log 2>&1 !
|
||||
Run successfully with command - PPLCNet_x1_0 - python3.7 test_cpp_serving_client.py > ../../test_tipc/output/PPLCNet_x1_0/server_infer_cpp_cpu_pipeline_batchsize_1.log 2>&1 !
|
||||
```
|
||||
|
||||
预测结果会自动保存在 `./test_tipc/output/PPLCNet_x1_0/server_infer_gpu_pipeline_http_batchsize_1.log` ,可以看到 PaddleServing 的运行结果:
|
||||
|
||||
```
|
||||
WARNING: Logging before InitGoogleLogging() is written to STDERR
|
||||
I0612 09:55:16.109890 38303 naming_service_thread.cpp:202] brpc::policy::ListNamingService("127.0.0.1:9292"): added 1
|
||||
I0612 09:55:16.172924 38303 general_model.cpp:490] [client]logid=0,client_cost=60.772ms,server_cost=57.6ms.
|
||||
prediction: daisy, probability: 0.9099399447441101
|
||||
0.06275796890258789
|
||||
```
|
||||
|
||||
|
||||
如果运行失败,也会在终端中输出运行失败的日志信息以及对应的运行命令。可以基于该命令,分析运行失败的原因。
|
|
@ -1,6 +1,6 @@
|
|||
# Linux GPU/CPU PYTHON 服务化部署测试
|
||||
|
||||
Linux GPU/CPU PYTHON 服务化部署测试的主程序为`test_serving_infer.sh`,可以测试基于Python的模型服务化部署功能。
|
||||
Linux GPU/CPU PYTHON 服务化部署测试的主程序为`test_serving_infer_python.sh`,可以测试基于Python的模型服务化部署功能。
|
||||
|
||||
|
||||
## 1. 测试结论汇总
|
||||
|
@ -60,14 +60,14 @@ Linux GPU/CPU PYTHON 服务化部署测试的主程序为`test_serving_infer.sh
|
|||
测试方法如下所示,希望测试不同的模型文件,只需更换为自己的参数配置文件,即可完成对应模型的测试。
|
||||
|
||||
```bash
|
||||
bash test_tipc/test_serving_infer_python.sh ${your_params_file} lite_train_lite_infer
|
||||
bash test_tipc/test_serving_infer_python.sh ${your_params_file}
|
||||
```
|
||||
|
||||
以`ResNet50`的`Linux GPU/CPU PYTHON 服务化部署测试`为例,命令如下所示。
|
||||
|
||||
|
||||
```bash
|
||||
bash test_tipc/test_serving_infer_python.sh test_tipc/configs/ResNet50/ResNet50_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt serving_infer
|
||||
bash test_tipc/test_serving_infer_python.sh test_tipc/configs/ResNet50/ResNet50_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
|
||||
```
|
||||
|
||||
输出结果如下,表示命令运行成功。
|
||||
|
|
|
@ -200,15 +200,25 @@ fi
|
|||
if [[ ${MODE} = "serving_infer" ]]; then
|
||||
# prepare serving env
|
||||
python_name=$(func_parser_value "${lines[2]}")
|
||||
${python_name} -m pip install paddle-serving-server-gpu==0.7.0.post102
|
||||
${python_name} -m pip install paddle_serving_client==0.7.0
|
||||
${python_name} -m pip install paddle-serving-app==0.7.0
|
||||
${python_name} -m pip install paddle_serving_client==0.9.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
${python_name} -m pip install paddle-serving-app==0.9.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
python_name=$(func_parser_value "${lines[2]}")
|
||||
if [[ ${FILENAME} =~ "cpp" ]]; then
|
||||
pushd ./deploy/paddleserving
|
||||
bash build_server.sh ${python_name}
|
||||
popd
|
||||
else
|
||||
${python_name} -m pip install install paddle-serving-server-gpu==0.9.0.post101 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
fi
|
||||
if [[ ${model_name} =~ "ShiTu" ]]; then
|
||||
${python_name} -m pip install faiss-cpu==1.7.1post2 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
cls_inference_model_url=$(func_parser_value "${lines[3]}")
|
||||
cls_tar_name=$(func_get_url_file_name "${cls_inference_model_url}")
|
||||
det_inference_model_url=$(func_parser_value "${lines[4]}")
|
||||
det_tar_name=$(func_get_url_file_name "${det_inference_model_url}")
|
||||
cd ./deploy
|
||||
wget -nc https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v1.0.tar --no-check-certificate
|
||||
tar -xf drink_dataset_v1.0.tar
|
||||
mkdir models
|
||||
cd models
|
||||
wget -nc ${cls_inference_model_url} && tar xf ${cls_tar_name}
|
||||
|
|
|
@ -8,5 +8,11 @@ num_workers=8
|
|||
|
||||
# get data
|
||||
bash test_tipc/static/${model_item}/benchmark_common/prepare.sh
|
||||
|
||||
cd ./dataset/ILSVRC2012
|
||||
cat train_list.txt >> tmp
|
||||
for i in {1..10}; do cat tmp >> train_list.txt; done
|
||||
cd ../../
|
||||
|
||||
# run
|
||||
bash test_tipc/static/${model_item}/benchmark_common/run_benchmark.sh ${model_item} ${bs_item} ${fp_item} ${run_mode} ${device_num} ${max_epochs} ${num_workers} 2>&1;
|
||||
|
|
|
@ -8,5 +8,11 @@ num_workers=8
|
|||
|
||||
# get data
|
||||
bash test_tipc/static/${model_item}/benchmark_common/prepare.sh
|
||||
|
||||
cd ./dataset/ILSVRC2012
|
||||
cat train_list.txt >> tmp
|
||||
for i in {1..10}; do cat tmp >> train_list.txt; done
|
||||
cd ../../
|
||||
|
||||
# run
|
||||
bash test_tipc/static/${model_item}/benchmark_common/run_benchmark.sh ${model_item} ${bs_item} ${fp_item} ${run_mode} ${device_num} ${max_epochs} ${num_workers} 2>&1;
|
||||
|
|
|
@ -8,5 +8,11 @@ num_workers=8
|
|||
|
||||
# get data
|
||||
bash test_tipc/static/${model_item}/benchmark_common/prepare.sh
|
||||
|
||||
cd ./dataset/ILSVRC2012
|
||||
cat train_list.txt >> tmp
|
||||
for i in {1..10}; do cat tmp >> train_list.txt; done
|
||||
cd ../../
|
||||
|
||||
# run
|
||||
bash test_tipc/static/${model_item}/benchmark_common/run_benchmark.sh ${model_item} ${bs_item} ${fp_item} ${run_mode} ${device_num} ${max_epochs} ${num_workers} 2>&1;
|
||||
|
|
|
@ -1,353 +0,0 @@
|
|||
#!/bin/bash
|
||||
source test_tipc/common_func.sh
|
||||
|
||||
FILENAME=$1
|
||||
dataline=$(awk 'NR==1, NR==19{print}' $FILENAME)
|
||||
|
||||
# parser params
|
||||
IFS=$'\n'
|
||||
lines=(${dataline})
|
||||
|
||||
function func_get_url_file_name(){
|
||||
strs=$1
|
||||
IFS="/"
|
||||
array=(${strs})
|
||||
tmp=${array[${#array[@]}-1]}
|
||||
echo ${tmp}
|
||||
}
|
||||
|
||||
# parser serving
|
||||
model_name=$(func_parser_value "${lines[1]}")
|
||||
python=$(func_parser_value "${lines[2]}")
|
||||
trans_model_py=$(func_parser_value "${lines[4]}")
|
||||
infer_model_dir_key=$(func_parser_key "${lines[5]}")
|
||||
infer_model_dir_value=$(func_parser_value "${lines[5]}")
|
||||
model_filename_key=$(func_parser_key "${lines[6]}")
|
||||
model_filename_value=$(func_parser_value "${lines[6]}")
|
||||
params_filename_key=$(func_parser_key "${lines[7]}")
|
||||
params_filename_value=$(func_parser_value "${lines[7]}")
|
||||
serving_server_key=$(func_parser_key "${lines[8]}")
|
||||
serving_server_value=$(func_parser_value "${lines[8]}")
|
||||
serving_client_key=$(func_parser_key "${lines[9]}")
|
||||
serving_client_value=$(func_parser_value "${lines[9]}")
|
||||
serving_dir_value=$(func_parser_value "${lines[10]}")
|
||||
web_service_py=$(func_parser_value "${lines[11]}")
|
||||
web_use_gpu_key=$(func_parser_key "${lines[12]}")
|
||||
web_use_gpu_list=$(func_parser_value "${lines[12]}")
|
||||
pipeline_py=$(func_parser_value "${lines[13]}")
|
||||
|
||||
|
||||
function func_serving_cls(){
|
||||
LOG_PATH="../../test_tipc/output/${model_name}"
|
||||
mkdir -p ${LOG_PATH}
|
||||
status_log="${LOG_PATH}/results_serving.log"
|
||||
IFS='|'
|
||||
|
||||
# pdserving
|
||||
set_dirname=$(func_set_params "${infer_model_dir_key}" "${infer_model_dir_value}")
|
||||
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
|
||||
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
|
||||
set_serving_server=$(func_set_params "${serving_server_key}" "${serving_server_value}")
|
||||
set_serving_client=$(func_set_params "${serving_client_key}" "${serving_client_value}")
|
||||
|
||||
trans_model_cmd="${python} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
|
||||
eval $trans_model_cmd
|
||||
|
||||
# modify the alias_name of fetch_var to "outputs"
|
||||
server_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"prediction\"/' ${serving_server_value}/serving_server_conf.prototxt"
|
||||
eval ${server_fetch_var_line_cmd}
|
||||
client_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"prediction\"/' ${serving_client_value}/serving_client_conf.prototxt"
|
||||
eval ${client_fetch_var_line_cmd}
|
||||
|
||||
prototxt_dataline=$(awk 'NR==1, NR==3{print}' ${serving_server_value}/serving_server_conf.prototxt)
|
||||
IFS=$'\n'
|
||||
prototxt_lines=(${prototxt_dataline})
|
||||
feed_var_name=$(func_parser_value "${prototxt_lines[2]}")
|
||||
IFS='|'
|
||||
|
||||
cd ${serving_dir_value}
|
||||
unset https_proxy
|
||||
unset http_proxy
|
||||
|
||||
# modify the input_name in "classification_web_service.py" to be consistent with feed_var.name in prototxt
|
||||
set_web_service_feet_var_cmd="sed -i '/preprocess/,/input_imgs}/s/{.*: input_imgs}/{${feed_var_name}: input_imgs}/' ${web_service_py}"
|
||||
eval ${set_web_service_feet_var_cmd}
|
||||
|
||||
model_config=21
|
||||
serving_server_dir_name=$(func_get_url_file_name "$serving_server_value")
|
||||
set_model_config_cmd="sed -i '${model_config}s/model_config: .*/model_config: ${serving_server_dir_name}/' config.yml"
|
||||
eval ${set_model_config_cmd}
|
||||
|
||||
for python in ${python[*]}; do
|
||||
if [[ ${python} = "cpp" ]]; then
|
||||
for use_gpu in ${web_use_gpu_list[*]}; do
|
||||
if [ ${use_gpu} = "null" ]; then
|
||||
web_service_cpp_cmd="${python} -m paddle_serving_server.serve --model ppocr_det_mobile_2.0_serving/ ppocr_rec_mobile_2.0_serving/ --port 9293"
|
||||
eval $web_service_cmd
|
||||
sleep 5s
|
||||
_save_log_path="${LOG_PATH}/server_infer_cpp_cpu_pipeline_usemkldnn_False_threads_4_batchsize_1.log"
|
||||
pipeline_cmd="${python} ocr_cpp_client.py ppocr_det_mobile_2.0_client/ ppocr_rec_mobile_2.0_client/"
|
||||
eval $pipeline_cmd
|
||||
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||
sleep 5s
|
||||
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
||||
else
|
||||
web_service_cpp_cmd="${python} -m paddle_serving_server.serve --model ppocr_det_mobile_2.0_serving/ ppocr_rec_mobile_2.0_serving/ --port 9293 --gpu_id=0"
|
||||
eval $web_service_cmd
|
||||
sleep 5s
|
||||
_save_log_path="${LOG_PATH}/server_infer_cpp_cpu_pipeline_usemkldnn_False_threads_4_batchsize_1.log"
|
||||
pipeline_cmd="${python} ocr_cpp_client.py ppocr_det_mobile_2.0_client/ ppocr_rec_mobile_2.0_client/"
|
||||
eval $pipeline_cmd
|
||||
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||
sleep 5s
|
||||
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
||||
fi
|
||||
done
|
||||
else
|
||||
# python serving
|
||||
for use_gpu in ${web_use_gpu_list[*]}; do
|
||||
if [[ ${use_gpu} = "null" ]]; then
|
||||
device_type_line=24
|
||||
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 0/' config.yml"
|
||||
eval $set_device_type_cmd
|
||||
|
||||
devices_line=27
|
||||
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"\"/' config.yml"
|
||||
eval $set_devices_cmd
|
||||
|
||||
web_service_cmd="${python} ${web_service_py} &"
|
||||
eval $web_service_cmd
|
||||
sleep 5s
|
||||
for pipeline in ${pipeline_py[*]}; do
|
||||
_save_log_path="${LOG_PATH}/server_infer_cpu_${pipeline%_client*}_batchsize_1.log"
|
||||
pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1 "
|
||||
eval $pipeline_cmd
|
||||
last_status=${PIPESTATUS[0]}
|
||||
eval "cat ${_save_log_path}"
|
||||
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||
sleep 5s
|
||||
done
|
||||
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
||||
elif [ ${use_gpu} -eq 0 ]; then
|
||||
if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
|
||||
continue
|
||||
fi
|
||||
if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
|
||||
continue
|
||||
fi
|
||||
if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [[ ${_flag_quant} = "True" ]]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
device_type_line=24
|
||||
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 1/' config.yml"
|
||||
eval $set_device_type_cmd
|
||||
|
||||
devices_line=27
|
||||
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"${use_gpu}\"/' config.yml"
|
||||
eval $set_devices_cmd
|
||||
|
||||
web_service_cmd="${python} ${web_service_py} & "
|
||||
eval $web_service_cmd
|
||||
sleep 5s
|
||||
for pipeline in ${pipeline_py[*]}; do
|
||||
_save_log_path="${LOG_PATH}/server_infer_gpu_${pipeline%_client*}_batchsize_1.log"
|
||||
pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1"
|
||||
eval $pipeline_cmd
|
||||
last_status=${PIPESTATUS[0]}
|
||||
eval "cat ${_save_log_path}"
|
||||
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||
sleep 5s
|
||||
done
|
||||
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
||||
else
|
||||
echo "Does not support hardware [${use_gpu}] other than CPU and GPU Currently!"
|
||||
fi
|
||||
done
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
|
||||
function func_serving_rec(){
|
||||
LOG_PATH="../../../test_tipc/output/${model_name}"
|
||||
mkdir -p ${LOG_PATH}
|
||||
status_log="${LOG_PATH}/results_serving.log"
|
||||
trans_model_py=$(func_parser_value "${lines[5]}")
|
||||
cls_infer_model_dir_key=$(func_parser_key "${lines[6]}")
|
||||
cls_infer_model_dir_value=$(func_parser_value "${lines[6]}")
|
||||
det_infer_model_dir_key=$(func_parser_key "${lines[7]}")
|
||||
det_infer_model_dir_value=$(func_parser_value "${lines[7]}")
|
||||
model_filename_key=$(func_parser_key "${lines[8]}")
|
||||
model_filename_value=$(func_parser_value "${lines[8]}")
|
||||
params_filename_key=$(func_parser_key "${lines[9]}")
|
||||
params_filename_value=$(func_parser_value "${lines[9]}")
|
||||
|
||||
cls_serving_server_key=$(func_parser_key "${lines[10]}")
|
||||
cls_serving_server_value=$(func_parser_value "${lines[10]}")
|
||||
cls_serving_client_key=$(func_parser_key "${lines[11]}")
|
||||
cls_serving_client_value=$(func_parser_value "${lines[11]}")
|
||||
|
||||
det_serving_server_key=$(func_parser_key "${lines[12]}")
|
||||
det_serving_server_value=$(func_parser_value "${lines[12]}")
|
||||
det_serving_client_key=$(func_parser_key "${lines[13]}")
|
||||
det_serving_client_value=$(func_parser_value "${lines[13]}")
|
||||
|
||||
serving_dir_value=$(func_parser_value "${lines[14]}")
|
||||
web_service_py=$(func_parser_value "${lines[15]}")
|
||||
web_use_gpu_key=$(func_parser_key "${lines[16]}")
|
||||
web_use_gpu_list=$(func_parser_value "${lines[16]}")
|
||||
pipeline_py=$(func_parser_value "${lines[17]}")
|
||||
|
||||
IFS='|'
|
||||
|
||||
# pdserving
|
||||
cd ./deploy
|
||||
set_dirname=$(func_set_params "${cls_infer_model_dir_key}" "${cls_infer_model_dir_value}")
|
||||
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
|
||||
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
|
||||
set_serving_server=$(func_set_params "${cls_serving_server_key}" "${cls_serving_server_value}")
|
||||
set_serving_client=$(func_set_params "${cls_serving_client_key}" "${cls_serving_client_value}")
|
||||
cls_trans_model_cmd="${python} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
|
||||
eval $cls_trans_model_cmd
|
||||
|
||||
set_dirname=$(func_set_params "${det_infer_model_dir_key}" "${det_infer_model_dir_value}")
|
||||
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
|
||||
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
|
||||
set_serving_server=$(func_set_params "${det_serving_server_key}" "${det_serving_server_value}")
|
||||
set_serving_client=$(func_set_params "${det_serving_client_key}" "${det_serving_client_value}")
|
||||
det_trans_model_cmd="${python} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
|
||||
eval $det_trans_model_cmd
|
||||
|
||||
# modify the alias_name of fetch_var to "outputs"
|
||||
server_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"features\"/' $cls_serving_server_value/serving_server_conf.prototxt"
|
||||
eval ${server_fetch_var_line_cmd}
|
||||
client_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"features\"/' $cls_serving_client_value/serving_client_conf.prototxt"
|
||||
eval ${client_fetch_var_line_cmd}
|
||||
|
||||
prototxt_dataline=$(awk 'NR==1, NR==3{print}' ${cls_serving_server_value}/serving_server_conf.prototxt)
|
||||
IFS=$'\n'
|
||||
prototxt_lines=(${prototxt_dataline})
|
||||
feed_var_name=$(func_parser_value "${prototxt_lines[2]}")
|
||||
IFS='|'
|
||||
|
||||
cd ${serving_dir_value}
|
||||
unset https_proxy
|
||||
unset http_proxy
|
||||
|
||||
# modify the input_name in "recognition_web_service.py" to be consistent with feed_var.name in prototxt
|
||||
set_web_service_feet_var_cmd="sed -i '/preprocess/,/input_imgs}/s/{.*: input_imgs}/{${feed_var_name}: input_imgs}/' ${web_service_py}"
|
||||
eval ${set_web_service_feet_var_cmd}
|
||||
|
||||
for python in ${python[*]}; do
|
||||
if [[ ${python} = "cpp" ]]; then
|
||||
for use_gpu in ${web_use_gpu_list[*]}; do
|
||||
if [ ${use_gpu} = "null" ]; then
|
||||
web_service_cpp_cmd="${python} web_service_py"
|
||||
|
||||
eval $web_service_cmd
|
||||
sleep 5s
|
||||
_save_log_path="${LOG_PATH}/server_infer_cpp_cpu_pipeline_usemkldnn_False_threads_4_batchsize_1.log"
|
||||
pipeline_cmd="${python} ocr_cpp_client.py ppocr_det_mobile_2.0_client/ ppocr_rec_mobile_2.0_client/"
|
||||
eval $pipeline_cmd
|
||||
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||
sleep 5s
|
||||
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
||||
else
|
||||
web_service_cpp_cmd="${python} web_service_py"
|
||||
eval $web_service_cmd
|
||||
sleep 5s
|
||||
_save_log_path="${LOG_PATH}/server_infer_cpp_cpu_pipeline_usemkldnn_False_threads_4_batchsize_1.log"
|
||||
pipeline_cmd="${python} ocr_cpp_client.py ppocr_det_mobile_2.0_client/ ppocr_rec_mobile_2.0_client/"
|
||||
eval $pipeline_cmd
|
||||
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||
sleep 5s
|
||||
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
||||
fi
|
||||
done
|
||||
else
|
||||
# python serving
|
||||
for use_gpu in ${web_use_gpu_list[*]}; do
|
||||
if [[ ${use_gpu} = "null" ]]; then
|
||||
device_type_line=24
|
||||
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 0/' config.yml"
|
||||
eval $set_device_type_cmd
|
||||
|
||||
devices_line=27
|
||||
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"\"/' config.yml"
|
||||
eval $set_devices_cmd
|
||||
|
||||
web_service_cmd="${python} ${web_service_py} &"
|
||||
eval $web_service_cmd
|
||||
sleep 5s
|
||||
for pipeline in ${pipeline_py[*]}; do
|
||||
_save_log_path="${LOG_PATH}/server_infer_cpu_${pipeline%_client*}_batchsize_1.log"
|
||||
pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1 "
|
||||
eval $pipeline_cmd
|
||||
last_status=${PIPESTATUS[0]}
|
||||
eval "cat ${_save_log_path}"
|
||||
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||
sleep 5s
|
||||
done
|
||||
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
||||
elif [ ${use_gpu} -eq 0 ]; then
|
||||
if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
|
||||
continue
|
||||
fi
|
||||
if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
|
||||
continue
|
||||
fi
|
||||
if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [[ ${_flag_quant} = "True" ]]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
device_type_line=24
|
||||
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 1/' config.yml"
|
||||
eval $set_device_type_cmd
|
||||
|
||||
devices_line=27
|
||||
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"${use_gpu}\"/' config.yml"
|
||||
eval $set_devices_cmd
|
||||
|
||||
web_service_cmd="${python} ${web_service_py} & "
|
||||
eval $web_service_cmd
|
||||
sleep 10s
|
||||
for pipeline in ${pipeline_py[*]}; do
|
||||
_save_log_path="${LOG_PATH}/server_infer_gpu_${pipeline%_client*}_batchsize_1.log"
|
||||
pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1"
|
||||
eval $pipeline_cmd
|
||||
last_status=${PIPESTATUS[0]}
|
||||
eval "cat ${_save_log_path}"
|
||||
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||
sleep 10s
|
||||
done
|
||||
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
||||
else
|
||||
echo "Does not support hardware [${use_gpu}] other than CPU and GPU Currently!"
|
||||
fi
|
||||
done
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
|
||||
# set cuda device
|
||||
GPUID=$2
|
||||
if [ ${#GPUID} -le 0 ];then
|
||||
env=" "
|
||||
else
|
||||
env="export CUDA_VISIBLE_DEVICES=${GPUID}"
|
||||
fi
|
||||
set CUDA_VISIBLE_DEVICES
|
||||
eval $env
|
||||
|
||||
|
||||
echo "################### run test ###################"
|
||||
|
||||
export Count=0
|
||||
IFS="|"
|
||||
if [[ ${model_name} =~ "ShiTu" ]]; then
|
||||
func_serving_rec
|
||||
else
|
||||
func_serving_cls
|
||||
fi
|
|
@ -0,0 +1,262 @@
|
|||
#!/bin/bash
|
||||
source test_tipc/common_func.sh
|
||||
|
||||
FILENAME=$1
|
||||
dataline=$(awk 'NR==1, NR==19{print}' $FILENAME)
|
||||
|
||||
# parser params
|
||||
IFS=$'\n'
|
||||
lines=(${dataline})
|
||||
|
||||
function func_get_url_file_name(){
|
||||
strs=$1
|
||||
IFS="/"
|
||||
array=(${strs})
|
||||
tmp=${array[${#array[@]}-1]}
|
||||
echo ${tmp}
|
||||
}
|
||||
|
||||
# parser serving
|
||||
model_name=$(func_parser_value "${lines[1]}")
|
||||
python=$(func_parser_value "${lines[2]}")
|
||||
trans_model_py=$(func_parser_value "${lines[4]}")
|
||||
infer_model_dir_key=$(func_parser_key "${lines[5]}")
|
||||
infer_model_dir_value=$(func_parser_value "${lines[5]}")
|
||||
model_filename_key=$(func_parser_key "${lines[6]}")
|
||||
model_filename_value=$(func_parser_value "${lines[6]}")
|
||||
params_filename_key=$(func_parser_key "${lines[7]}")
|
||||
params_filename_value=$(func_parser_value "${lines[7]}")
|
||||
serving_server_key=$(func_parser_key "${lines[8]}")
|
||||
serving_server_value=$(func_parser_value "${lines[8]}")
|
||||
serving_client_key=$(func_parser_key "${lines[9]}")
|
||||
serving_client_value=$(func_parser_value "${lines[9]}")
|
||||
serving_dir_value=$(func_parser_value "${lines[10]}")
|
||||
web_service_py=$(func_parser_value "${lines[11]}")
|
||||
web_use_gpu_key=$(func_parser_key "${lines[12]}")
|
||||
web_use_gpu_list=$(func_parser_value "${lines[12]}")
|
||||
pipeline_py=$(func_parser_value "${lines[13]}")
|
||||
|
||||
|
||||
function func_serving_cls(){
|
||||
LOG_PATH="test_tipc/output/${model_name}"
|
||||
mkdir -p ${LOG_PATH}
|
||||
LOG_PATH="../../${LOG_PATH}"
|
||||
status_log="${LOG_PATH}/results_serving.log"
|
||||
IFS='|'
|
||||
|
||||
# pdserving
|
||||
set_dirname=$(func_set_params "${infer_model_dir_key}" "${infer_model_dir_value}")
|
||||
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
|
||||
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
|
||||
set_serving_server=$(func_set_params "${serving_server_key}" "${serving_server_value}")
|
||||
set_serving_client=$(func_set_params "${serving_client_key}" "${serving_client_value}")
|
||||
|
||||
for python_ in ${python[*]}; do
|
||||
if [[ ${python_} =~ "python" ]]; then
|
||||
trans_model_cmd="${python_} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
|
||||
eval ${trans_model_cmd}
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
# modify the alias_name of fetch_var to "outputs"
|
||||
server_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"prediction\"/' ${serving_server_value}/serving_server_conf.prototxt"
|
||||
eval ${server_fetch_var_line_cmd}
|
||||
|
||||
client_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"prediction\"/' ${serving_client_value}/serving_client_conf.prototxt"
|
||||
eval ${client_fetch_var_line_cmd}
|
||||
|
||||
prototxt_dataline=$(awk 'NR==1, NR==3{print}' ${serving_server_value}/serving_server_conf.prototxt)
|
||||
IFS=$'\n'
|
||||
prototxt_lines=(${prototxt_dataline})
|
||||
feed_var_name=$(func_parser_value "${prototxt_lines[2]}")
|
||||
IFS='|'
|
||||
|
||||
cd ${serving_dir_value}
|
||||
unset https_proxy
|
||||
unset http_proxy
|
||||
|
||||
for item in ${python[*]}; do
|
||||
if [[ ${item} =~ "python" ]]; then
|
||||
python_=${item}
|
||||
break
|
||||
fi
|
||||
done
|
||||
serving_client_dir_name=$(func_get_url_file_name "$serving_client_value")
|
||||
set_client_feed_type_cmd="sed -i '/feed_type/,/: .*/s/feed_type: .*/feed_type: 20/' ${serving_client_dir_name}/serving_client_conf.prototxt"
|
||||
eval ${set_client_feed_type_cmd}
|
||||
set_client_shape_cmd="sed -i '/shape: 3/,/shape: 3/s/shape: 3/shape: 1/' ${serving_client_dir_name}/serving_client_conf.prototxt"
|
||||
eval ${set_client_shape_cmd}
|
||||
set_client_shape224_cmd="sed -i '/shape: 224/,/shape: 224/s/shape: 224//' ${serving_client_dir_name}/serving_client_conf.prototxt"
|
||||
eval ${set_client_shape224_cmd}
|
||||
set_client_shape224_cmd="sed -i '/shape: 224/,/shape: 224/s/shape: 224//' ${serving_client_dir_name}/serving_client_conf.prototxt"
|
||||
eval ${set_client_shape224_cmd}
|
||||
|
||||
set_pipeline_load_config_cmd="sed -i '/load_client_config/,/.prototxt/s/.\/.*\/serving_client_conf.prototxt/.\/${serving_client_dir_name}\/serving_client_conf.prototxt/' ${pipeline_py}"
|
||||
eval ${set_pipeline_load_config_cmd}
|
||||
|
||||
set_pipeline_feed_var_cmd="sed -i '/feed=/,/: image}/s/feed={.*: image}/feed={${feed_var_name}: image}/' ${pipeline_py}"
|
||||
eval ${set_pipeline_feed_var_cmd}
|
||||
|
||||
serving_server_dir_name=$(func_get_url_file_name "$serving_server_value")
|
||||
|
||||
for use_gpu in ${web_use_gpu_list[*]}; do
|
||||
if [[ ${use_gpu} = "null" ]]; then
|
||||
web_service_cpp_cmd="${python_} -m paddle_serving_server.serve --model ${serving_server_dir_name} --op GeneralClasOp --port 9292 &"
|
||||
eval ${web_service_cpp_cmd}
|
||||
sleep 5s
|
||||
_save_log_path="${LOG_PATH}/server_infer_cpp_cpu_pipeline_batchsize_1.log"
|
||||
pipeline_cmd="${python_} test_cpp_serving_client.py > ${_save_log_path} 2>&1 "
|
||||
eval ${pipeline_cmd}
|
||||
last_status=${PIPESTATUS[0]}
|
||||
eval "cat ${_save_log_path}"
|
||||
status_check ${last_status} "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||
eval "${python_} -m paddle_serving_server.serve stop"
|
||||
sleep 5s
|
||||
else
|
||||
web_service_cpp_cmd="${python_} -m paddle_serving_server.serve --model ${serving_server_dir_name} --op GeneralClasOp --port 9292 --gpu_id=${use_gpu} &"
|
||||
eval ${web_service_cpp_cmd}
|
||||
sleep 8s
|
||||
|
||||
_save_log_path="${LOG_PATH}/server_infer_cpp_gpu_pipeline_batchsize_1.log"
|
||||
pipeline_cmd="${python_} test_cpp_serving_client.py > ${_save_log_path} 2>&1 "
|
||||
eval ${pipeline_cmd}
|
||||
last_status=${PIPESTATUS[0]}
|
||||
eval "cat ${_save_log_path}"
|
||||
status_check ${last_status} "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||
sleep 5s
|
||||
eval "${python_} -m paddle_serving_server.serve stop"
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
|
||||
function func_serving_rec(){
|
||||
LOG_PATH="test_tipc/output/${model_name}"
|
||||
mkdir -p ${LOG_PATH}
|
||||
LOG_PATH="../../../${LOG_PATH}"
|
||||
status_log="${LOG_PATH}/results_serving.log"
|
||||
trans_model_py=$(func_parser_value "${lines[5]}")
|
||||
cls_infer_model_dir_key=$(func_parser_key "${lines[6]}")
|
||||
cls_infer_model_dir_value=$(func_parser_value "${lines[6]}")
|
||||
det_infer_model_dir_key=$(func_parser_key "${lines[7]}")
|
||||
det_infer_model_dir_value=$(func_parser_value "${lines[7]}")
|
||||
model_filename_key=$(func_parser_key "${lines[8]}")
|
||||
model_filename_value=$(func_parser_value "${lines[8]}")
|
||||
params_filename_key=$(func_parser_key "${lines[9]}")
|
||||
params_filename_value=$(func_parser_value "${lines[9]}")
|
||||
|
||||
cls_serving_server_key=$(func_parser_key "${lines[10]}")
|
||||
cls_serving_server_value=$(func_parser_value "${lines[10]}")
|
||||
cls_serving_client_key=$(func_parser_key "${lines[11]}")
|
||||
cls_serving_client_value=$(func_parser_value "${lines[11]}")
|
||||
|
||||
det_serving_server_key=$(func_parser_key "${lines[12]}")
|
||||
det_serving_server_value=$(func_parser_value "${lines[12]}")
|
||||
det_serving_client_key=$(func_parser_key "${lines[13]}")
|
||||
det_serving_client_value=$(func_parser_value "${lines[13]}")
|
||||
|
||||
serving_dir_value=$(func_parser_value "${lines[14]}")
|
||||
web_service_py=$(func_parser_value "${lines[15]}")
|
||||
web_use_gpu_key=$(func_parser_key "${lines[16]}")
|
||||
web_use_gpu_list=$(func_parser_value "${lines[16]}")
|
||||
pipeline_py=$(func_parser_value "${lines[17]}")
|
||||
|
||||
IFS='|'
|
||||
for python_ in ${python[*]}; do
|
||||
if [[ ${python_} =~ "python" ]]; then
|
||||
python_interp=${python_}
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
# pdserving
|
||||
cd ./deploy
|
||||
set_dirname=$(func_set_params "${cls_infer_model_dir_key}" "${cls_infer_model_dir_value}")
|
||||
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
|
||||
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
|
||||
set_serving_server=$(func_set_params "${cls_serving_server_key}" "${cls_serving_server_value}")
|
||||
set_serving_client=$(func_set_params "${cls_serving_client_key}" "${cls_serving_client_value}")
|
||||
cls_trans_model_cmd="${python_interp} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
|
||||
eval ${cls_trans_model_cmd}
|
||||
|
||||
set_dirname=$(func_set_params "${det_infer_model_dir_key}" "${det_infer_model_dir_value}")
|
||||
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
|
||||
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
|
||||
set_serving_server=$(func_set_params "${det_serving_server_key}" "${det_serving_server_value}")
|
||||
set_serving_client=$(func_set_params "${det_serving_client_key}" "${det_serving_client_value}")
|
||||
det_trans_model_cmd="${python_interp} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
|
||||
eval ${det_trans_model_cmd}
|
||||
|
||||
cp_prototxt_cmd="cp ./paddleserving/recognition/preprocess/general_PPLCNet_x2_5_lite_v1.0_serving/*.prototxt ${cls_serving_server_value}"
|
||||
eval ${cp_prototxt_cmd}
|
||||
cp_prototxt_cmd="cp ./paddleserving/recognition/preprocess/general_PPLCNet_x2_5_lite_v1.0_client/*.prototxt ${cls_serving_client_value}"
|
||||
eval ${cp_prototxt_cmd}
|
||||
cp_prototxt_cmd="cp ./paddleserving/recognition/preprocess/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_client/*.prototxt ${det_serving_client_value}"
|
||||
eval ${cp_prototxt_cmd}
|
||||
cp_prototxt_cmd="cp ./paddleserving/recognition/preprocess/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_serving/*.prototxt ${det_serving_server_value}"
|
||||
eval ${cp_prototxt_cmd}
|
||||
|
||||
prototxt_dataline=$(awk 'NR==1, NR==3{print}' ${cls_serving_server_value}/serving_server_conf.prototxt)
|
||||
IFS=$'\n'
|
||||
prototxt_lines=(${prototxt_dataline})
|
||||
feed_var_name=$(func_parser_value "${prototxt_lines[2]}")
|
||||
IFS='|'
|
||||
|
||||
cd ${serving_dir_value}
|
||||
unset https_proxy
|
||||
unset http_proxy
|
||||
|
||||
export SERVING_BIN=${PWD}/../Serving/server-build-gpu-opencv/core/general-server/serving
|
||||
for use_gpu in ${web_use_gpu_list[*]}; do
|
||||
if [ ${use_gpu} = "null" ]; then
|
||||
det_serving_server_dir_name=$(func_get_url_file_name "$det_serving_server_value")
|
||||
web_service_cpp_cmd="${python_interp} -m paddle_serving_server.serve --model ../../${det_serving_server_value} ../../${cls_serving_server_value} --op GeneralPicodetOp GeneralFeatureExtractOp --port 9400 &"
|
||||
eval ${web_service_cpp_cmd}
|
||||
sleep 5s
|
||||
_save_log_path="${LOG_PATH}/server_infer_cpp_cpu_batchsize_1.log"
|
||||
pipeline_cmd="${python_interp} ${pipeline_py} > ${_save_log_path} 2>&1 "
|
||||
eval ${pipeline_cmd}
|
||||
last_status=${PIPESTATUS[0]}
|
||||
eval "cat ${_save_log_path}"
|
||||
status_check ${last_status} "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||
eval "${python_} -m paddle_serving_server.serve stop"
|
||||
sleep 5s
|
||||
else
|
||||
det_serving_server_dir_name=$(func_get_url_file_name "$det_serving_server_value")
|
||||
web_service_cpp_cmd="${python_interp} -m paddle_serving_server.serve --model ../../${det_serving_server_value} ../../${cls_serving_server_value} --op GeneralPicodetOp GeneralFeatureExtractOp --port 9400 --gpu_id=${use_gpu} &"
|
||||
eval ${web_service_cpp_cmd}
|
||||
sleep 5s
|
||||
_save_log_path="${LOG_PATH}/server_infer_cpp_gpu_batchsize_1.log"
|
||||
pipeline_cmd="${python_interp} ${pipeline_py} > ${_save_log_path} 2>&1 "
|
||||
eval ${pipeline_cmd}
|
||||
last_status=${PIPESTATUS[0]}
|
||||
eval "cat ${_save_log_path}"
|
||||
status_check ${last_status} "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||
eval "${python_} -m paddle_serving_server.serve stop"
|
||||
sleep 5s
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
|
||||
# set cuda device
|
||||
GPUID=$2
|
||||
if [ ${#GPUID} -le 0 ];then
|
||||
env=" "
|
||||
else
|
||||
env="export CUDA_VISIBLE_DEVICES=${GPUID}"
|
||||
fi
|
||||
set CUDA_VISIBLE_DEVICES
|
||||
eval ${env}
|
||||
|
||||
|
||||
echo "################### run test ###################"
|
||||
|
||||
export Count=0
|
||||
IFS="|"
|
||||
if [[ ${model_name} =~ "ShiTu" ]]; then
|
||||
func_serving_rec
|
||||
else
|
||||
func_serving_cls
|
||||
fi
|
|
@ -0,0 +1,309 @@
|
|||
#!/bin/bash
|
||||
source test_tipc/common_func.sh
|
||||
|
||||
FILENAME=$1
|
||||
dataline=$(awk 'NR==1, NR==19{print}' $FILENAME)
|
||||
|
||||
# parser params
|
||||
IFS=$'\n'
|
||||
lines=(${dataline})
|
||||
|
||||
function func_get_url_file_name(){
|
||||
strs=$1
|
||||
IFS="/"
|
||||
array=(${strs})
|
||||
tmp=${array[${#array[@]}-1]}
|
||||
echo ${tmp}
|
||||
}
|
||||
|
||||
# parser serving
|
||||
model_name=$(func_parser_value "${lines[1]}")
|
||||
python=$(func_parser_value "${lines[2]}")
|
||||
trans_model_py=$(func_parser_value "${lines[4]}")
|
||||
infer_model_dir_key=$(func_parser_key "${lines[5]}")
|
||||
infer_model_dir_value=$(func_parser_value "${lines[5]}")
|
||||
model_filename_key=$(func_parser_key "${lines[6]}")
|
||||
model_filename_value=$(func_parser_value "${lines[6]}")
|
||||
params_filename_key=$(func_parser_key "${lines[7]}")
|
||||
params_filename_value=$(func_parser_value "${lines[7]}")
|
||||
serving_server_key=$(func_parser_key "${lines[8]}")
|
||||
serving_server_value=$(func_parser_value "${lines[8]}")
|
||||
serving_client_key=$(func_parser_key "${lines[9]}")
|
||||
serving_client_value=$(func_parser_value "${lines[9]}")
|
||||
serving_dir_value=$(func_parser_value "${lines[10]}")
|
||||
web_service_py=$(func_parser_value "${lines[11]}")
|
||||
web_use_gpu_key=$(func_parser_key "${lines[12]}")
|
||||
web_use_gpu_list=$(func_parser_value "${lines[12]}")
|
||||
pipeline_py=$(func_parser_value "${lines[13]}")
|
||||
|
||||
|
||||
function func_serving_cls(){
|
||||
LOG_PATH="test_tipc/output/${model_name}"
|
||||
mkdir -p ${LOG_PATH}
|
||||
LOG_PATH="../../${LOG_PATH}"
|
||||
status_log="${LOG_PATH}/results_serving.log"
|
||||
IFS='|'
|
||||
|
||||
# pdserving
|
||||
set_dirname=$(func_set_params "${infer_model_dir_key}" "${infer_model_dir_value}")
|
||||
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
|
||||
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
|
||||
set_serving_server=$(func_set_params "${serving_server_key}" "${serving_server_value}")
|
||||
set_serving_client=$(func_set_params "${serving_client_key}" "${serving_client_value}")
|
||||
|
||||
for python_ in ${python[*]}; do
|
||||
if [[ ${python_} =~ "python" ]]; then
|
||||
trans_model_cmd="${python_} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
|
||||
eval ${trans_model_cmd}
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
# modify the alias_name of fetch_var to "outputs"
|
||||
server_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"prediction\"/' ${serving_server_value}/serving_server_conf.prototxt"
|
||||
eval ${server_fetch_var_line_cmd}
|
||||
|
||||
client_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"prediction\"/' ${serving_client_value}/serving_client_conf.prototxt"
|
||||
eval ${client_fetch_var_line_cmd}
|
||||
|
||||
prototxt_dataline=$(awk 'NR==1, NR==3{print}' ${serving_server_value}/serving_server_conf.prototxt)
|
||||
IFS=$'\n'
|
||||
prototxt_lines=(${prototxt_dataline})
|
||||
feed_var_name=$(func_parser_value "${prototxt_lines[2]}")
|
||||
IFS='|'
|
||||
|
||||
cd ${serving_dir_value}
|
||||
unset https_proxy
|
||||
unset http_proxy
|
||||
|
||||
# python serving
|
||||
# modify the input_name in "classification_web_service.py" to be consistent with feed_var.name in prototxt
|
||||
set_web_service_feed_var_cmd="sed -i '/preprocess/,/input_imgs}/s/{.*: input_imgs}/{${feed_var_name}: input_imgs}/' ${web_service_py}"
|
||||
eval ${set_web_service_feed_var_cmd}
|
||||
|
||||
model_config=21
|
||||
serving_server_dir_name=$(func_get_url_file_name "$serving_server_value")
|
||||
set_model_config_cmd="sed -i '${model_config}s/model_config: .*/model_config: ${serving_server_dir_name}/' config.yml"
|
||||
eval ${set_model_config_cmd}
|
||||
|
||||
for use_gpu in ${web_use_gpu_list[*]}; do
|
||||
if [[ ${use_gpu} = "null" ]]; then
|
||||
device_type_line=24
|
||||
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 0/' config.yml"
|
||||
eval ${set_device_type_cmd}
|
||||
|
||||
devices_line=27
|
||||
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"\"/' config.yml"
|
||||
eval ${set_devices_cmd}
|
||||
|
||||
web_service_cmd="${python_} ${web_service_py} &"
|
||||
eval ${web_service_cmd}
|
||||
sleep 5s
|
||||
for pipeline in ${pipeline_py[*]}; do
|
||||
_save_log_path="${LOG_PATH}/server_infer_cpu_${pipeline%_client*}_batchsize_1.log"
|
||||
pipeline_cmd="${python_} ${pipeline} > ${_save_log_path} 2>&1 "
|
||||
eval ${pipeline_cmd}
|
||||
last_status=${PIPESTATUS[0]}
|
||||
eval "cat ${_save_log_path}"
|
||||
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||
sleep 5s
|
||||
done
|
||||
eval "${python_} -m paddle_serving_server.serve stop"
|
||||
elif [ ${use_gpu} -eq 0 ]; then
|
||||
if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
|
||||
continue
|
||||
fi
|
||||
if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
|
||||
continue
|
||||
fi
|
||||
if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [[ ${_flag_quant} = "True" ]]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
device_type_line=24
|
||||
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 1/' config.yml"
|
||||
eval ${set_device_type_cmd}
|
||||
|
||||
devices_line=27
|
||||
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"${use_gpu}\"/' config.yml"
|
||||
eval ${set_devices_cmd}
|
||||
|
||||
web_service_cmd="${python_} ${web_service_py} & "
|
||||
eval ${web_service_cmd}
|
||||
sleep 5s
|
||||
for pipeline in ${pipeline_py[*]}; do
|
||||
_save_log_path="${LOG_PATH}/server_infer_gpu_${pipeline%_client*}_batchsize_1.log"
|
||||
pipeline_cmd="${python_} ${pipeline} > ${_save_log_path} 2>&1"
|
||||
eval ${pipeline_cmd}
|
||||
last_status=${PIPESTATUS[0]}
|
||||
eval "cat ${_save_log_path}"
|
||||
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||
sleep 5s
|
||||
done
|
||||
eval "${python_} -m paddle_serving_server.serve stop"
|
||||
else
|
||||
echo "Does not support hardware [${use_gpu}] other than CPU and GPU Currently!"
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
|
||||
function func_serving_rec(){
|
||||
LOG_PATH="test_tipc/output/${model_name}"
|
||||
mkdir -p ${LOG_PATH}
|
||||
LOG_PATH="../../../${LOG_PATH}"
|
||||
status_log="${LOG_PATH}/results_serving.log"
|
||||
trans_model_py=$(func_parser_value "${lines[5]}")
|
||||
cls_infer_model_dir_key=$(func_parser_key "${lines[6]}")
|
||||
cls_infer_model_dir_value=$(func_parser_value "${lines[6]}")
|
||||
det_infer_model_dir_key=$(func_parser_key "${lines[7]}")
|
||||
det_infer_model_dir_value=$(func_parser_value "${lines[7]}")
|
||||
model_filename_key=$(func_parser_key "${lines[8]}")
|
||||
model_filename_value=$(func_parser_value "${lines[8]}")
|
||||
params_filename_key=$(func_parser_key "${lines[9]}")
|
||||
params_filename_value=$(func_parser_value "${lines[9]}")
|
||||
|
||||
cls_serving_server_key=$(func_parser_key "${lines[10]}")
|
||||
cls_serving_server_value=$(func_parser_value "${lines[10]}")
|
||||
cls_serving_client_key=$(func_parser_key "${lines[11]}")
|
||||
cls_serving_client_value=$(func_parser_value "${lines[11]}")
|
||||
|
||||
det_serving_server_key=$(func_parser_key "${lines[12]}")
|
||||
det_serving_server_value=$(func_parser_value "${lines[12]}")
|
||||
det_serving_client_key=$(func_parser_key "${lines[13]}")
|
||||
det_serving_client_value=$(func_parser_value "${lines[13]}")
|
||||
|
||||
serving_dir_value=$(func_parser_value "${lines[14]}")
|
||||
web_service_py=$(func_parser_value "${lines[15]}")
|
||||
web_use_gpu_key=$(func_parser_key "${lines[16]}")
|
||||
web_use_gpu_list=$(func_parser_value "${lines[16]}")
|
||||
pipeline_py=$(func_parser_value "${lines[17]}")
|
||||
|
||||
IFS='|'
|
||||
for python_ in ${python[*]}; do
|
||||
if [[ ${python_} =~ "python" ]]; then
|
||||
python_interp=${python_}
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
# pdserving
|
||||
cd ./deploy
|
||||
set_dirname=$(func_set_params "${cls_infer_model_dir_key}" "${cls_infer_model_dir_value}")
|
||||
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
|
||||
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
|
||||
set_serving_server=$(func_set_params "${cls_serving_server_key}" "${cls_serving_server_value}")
|
||||
set_serving_client=$(func_set_params "${cls_serving_client_key}" "${cls_serving_client_value}")
|
||||
cls_trans_model_cmd="${python_interp} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
|
||||
eval ${cls_trans_model_cmd}
|
||||
|
||||
set_dirname=$(func_set_params "${det_infer_model_dir_key}" "${det_infer_model_dir_value}")
|
||||
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
|
||||
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
|
||||
set_serving_server=$(func_set_params "${det_serving_server_key}" "${det_serving_server_value}")
|
||||
set_serving_client=$(func_set_params "${det_serving_client_key}" "${det_serving_client_value}")
|
||||
det_trans_model_cmd="${python_interp} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
|
||||
eval ${det_trans_model_cmd}
|
||||
|
||||
# modify the alias_name of fetch_var to "outputs"
|
||||
server_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"features\"/' $cls_serving_server_value/serving_server_conf.prototxt"
|
||||
eval ${server_fetch_var_line_cmd}
|
||||
client_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"features\"/' $cls_serving_client_value/serving_client_conf.prototxt"
|
||||
eval ${client_fetch_var_line_cmd}
|
||||
|
||||
prototxt_dataline=$(awk 'NR==1, NR==3{print}' ${cls_serving_server_value}/serving_server_conf.prototxt)
|
||||
IFS=$'\n'
|
||||
prototxt_lines=(${prototxt_dataline})
|
||||
feed_var_name=$(func_parser_value "${prototxt_lines[2]}")
|
||||
IFS='|'
|
||||
|
||||
cd ${serving_dir_value}
|
||||
unset https_proxy
|
||||
unset http_proxy
|
||||
|
||||
# modify the input_name in "recognition_web_service.py" to be consistent with feed_var.name in prototxt
|
||||
set_web_service_feed_var_cmd="sed -i '/preprocess/,/input_imgs}/s/{.*: input_imgs}/{${feed_var_name}: input_imgs}/' ${web_service_py}"
|
||||
eval ${set_web_service_feed_var_cmd}
|
||||
# python serving
|
||||
for use_gpu in ${web_use_gpu_list[*]}; do
|
||||
if [[ ${use_gpu} = "null" ]]; then
|
||||
device_type_line=24
|
||||
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 0/' config.yml"
|
||||
eval ${set_device_type_cmd}
|
||||
|
||||
devices_line=27
|
||||
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"\"/' config.yml"
|
||||
eval ${set_devices_cmd}
|
||||
|
||||
web_service_cmd="${python} ${web_service_py} &"
|
||||
eval ${web_service_cmd}
|
||||
sleep 5s
|
||||
for pipeline in ${pipeline_py[*]}; do
|
||||
_save_log_path="${LOG_PATH}/server_infer_cpu_${pipeline%_client*}_batchsize_1.log"
|
||||
pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1 "
|
||||
eval ${pipeline_cmd}
|
||||
last_status=${PIPESTATUS[0]}
|
||||
eval "cat ${_save_log_path}"
|
||||
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||
sleep 5s
|
||||
done
|
||||
eval "${python_} -m paddle_serving_server.serve stop"
|
||||
elif [ ${use_gpu} -eq 0 ]; then
|
||||
if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
|
||||
continue
|
||||
fi
|
||||
if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
|
||||
continue
|
||||
fi
|
||||
if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [[ ${_flag_quant} = "True" ]]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
device_type_line=24
|
||||
set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 1/' config.yml"
|
||||
eval ${set_device_type_cmd}
|
||||
|
||||
devices_line=27
|
||||
set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"${use_gpu}\"/' config.yml"
|
||||
eval ${set_devices_cmd}
|
||||
|
||||
web_service_cmd="${python} ${web_service_py} & "
|
||||
eval ${web_service_cmd}
|
||||
sleep 10s
|
||||
for pipeline in ${pipeline_py[*]}; do
|
||||
_save_log_path="${LOG_PATH}/server_infer_gpu_${pipeline%_client*}_batchsize_1.log"
|
||||
pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1"
|
||||
eval ${pipeline_cmd}
|
||||
last_status=${PIPESTATUS[0]}
|
||||
eval "cat ${_save_log_path}"
|
||||
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
|
||||
sleep 10s
|
||||
done
|
||||
eval "${python_} -m paddle_serving_server.serve stop"
|
||||
else
|
||||
echo "Does not support hardware [${use_gpu}] other than CPU and GPU Currently!"
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
|
||||
# set cuda device
|
||||
GPUID=$2
|
||||
if [ ${#GPUID} -le 0 ];then
|
||||
env=" "
|
||||
else
|
||||
env="export CUDA_VISIBLE_DEVICES=${GPUID}"
|
||||
fi
|
||||
set CUDA_VISIBLE_DEVICES
|
||||
eval ${env}
|
||||
|
||||
|
||||
echo "################### run test ###################"
|
||||
|
||||
export Count=0
|
||||
IFS="|"
|
||||
if [[ ${model_name} =~ "ShiTu" ]]; then
|
||||
func_serving_rec
|
||||
else
|
||||
func_serving_cls
|
||||
fi
|
Loading…
Reference in New Issue