Update HubServing (#870)
* Update HubServing
* Fix the relative path
* Update doc of HubServing

branch: pull/900/head
parent: d63e0e88cd
commit: 666d912678
deploy/hubserving/clas/module.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,15 +18,14 @@ sys.path.insert(0, ".")

 import time

-from paddlehub.utils.log import logger
-from paddlehub.module.module import moduleinfo, serving
 import cv2
 import numpy as np
 import paddle.nn as nn
+from paddlehub.module.module import moduleinfo, serving

-from tools.infer.predict import Predictor
-from tools.infer.utils import b64_to_np, postprocess
-from deploy.hubserving.clas.params import read_params
+from hubserving.clas.params import get_default_confg
+from python.predict_cls import ClsPredictor
+from utils import config
+from utils.encode_decode import b64_to_np


 @moduleinfo(
@@ -41,19 +40,24 @@ class ClasSystem(nn.Layer):
         """
         initialize with the necessary elements
         """
-        cfg = read_params()
+        self._config = self._load_config(
+            use_gpu=use_gpu, enable_mkldnn=enable_mkldnn)
+        self.cls_predictor = ClsPredictor(self._config)
+
+    def _load_config(self, use_gpu=None, enable_mkldnn=None):
+        cfg = get_default_confg()
+        cfg = config.AttrDict(cfg)
+        config.create_attr_dict(cfg)
         if use_gpu is not None:
-            cfg.use_gpu = use_gpu
+            cfg.Global.use_gpu = use_gpu
         if enable_mkldnn is not None:
-            cfg.enable_mkldnn = enable_mkldnn
-        cfg.hubserving = True
+            cfg.Global.enable_mkldnn = enable_mkldnn
-        cfg.enable_benchmark = False
-        self.args = cfg
-        if cfg.use_gpu:
+        if cfg.Global.use_gpu:
             try:
                 _places = os.environ["CUDA_VISIBLE_DEVICES"]
                 int(_places[0])
-                print("Use GPU, GPU Memery:{}".format(cfg.gpu_mem))
+                print("Use GPU, GPU Memery:{}".format(cfg.Global.gpu_mem))
                 print("CUDA_VISIBLE_DEVICES: ", _places)
             except:
                 raise RuntimeError(
@@ -62,24 +66,36 @@ class ClasSystem(nn.Layer):
         else:
             print("Use CPU")
             print("Enable MKL-DNN") if enable_mkldnn else None
-        self.predictor = Predictor(self.args)
+        return cfg

-    def predict(self, batch_input_data, top_k=1):
-        assert isinstance(
-            batch_input_data,
-            np.ndarray), "The input data is inconsistent with expectations."
+    def predict(self, inputs):
+        if not isinstance(inputs, list):
+            raise Exception(
+                "The input data is inconsistent with expectations.")

         starttime = time.time()
-        batch_outputs = self.predictor.predict(batch_input_data)
+        outputs = self.cls_predictor.predict(inputs)
         elapse = time.time() - starttime
-        batch_result_list = postprocess(batch_outputs, top_k)
-        return {"prediction": batch_result_list, "elapse": elapse}
+        preds = self.cls_predictor.postprocess(outputs)
+        return {"prediction": preds, "elapse": elapse}

     @serving
-    def serving_method(self, images, revert_params, **kwargs):
+    def serving_method(self, images, revert_params):
         """
         Run as a service.
         """
         input_data = b64_to_np(images, revert_params)
-        results = self.predict(batch_input_data=input_data, **kwargs)
+        results = self.predict(inputs=list(input_data))
         return results
+
+
+if __name__ == "__main__":
+    import cv2
+    import paddlehub as hub
+
+    module = hub.Module(name="clas_system")
+    img_path = "./hubserving/ILSVRC2012_val_00006666.JPEG"
+    img = cv2.imread(img_path)[:, :, ::-1]
+    img = cv2.resize(img, (224, 224)).transpose((2, 0, 1))
+    res = module.predict([img.astype(np.float32)])
+    print("The returned result of {}: {}".format(img_path, res))
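For reference, a minimal client-side sketch of the payload the new `serving_method` expects. This is not part of the commit; it mirrors the wire format used by `np_to_b64`/`b64_to_np` and by `hubserving/test_hubserving.py` below, assumes the `clas_system` module is already serving on port 8866, and the exact `requests.post` keyword arguments are an assumption (that call is truncated in this diff):

```python
import base64
import json

import cv2
import numpy as np
import requests

# Bring one image to the NCHW float32 layout the module's __main__ example
# uses: resize to 224x224, HWC -> CHW, add a batch dimension.
img = cv2.imread("./hubserving/ILSVRC2012_val_00006666.JPEG")[:, :, ::-1]
img = cv2.resize(img, (224, 224)).transpose((2, 0, 1)).astype(np.float32)
batch = np.expand_dims(img, axis=0)

# Same wire format as np_to_b64/b64_to_np in deploy/utils/encode_decode.py:
# the raw bytes are base64-encoded, shape and dtype travel in "revert_params".
payload = {
    "images": base64.b64encode(batch).decode("utf8"),
    "revert_params": {"shape": list(batch.shape), "dtype": str(batch.dtype)},
}
resp = requests.post(
    "http://127.0.0.1:8866/predict/clas_system",
    headers={"Content-type": "application/json"},
    data=json.dumps(payload))
results = resp.json()["results"]
print(results["prediction"], results["elapse"])
```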
deploy/hubserving/clas/params.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,28 +17,24 @@ from __future__ import division
 from __future__ import print_function


-class Config(object):
-    pass
-
-
-def read_params():
-    cfg = Config()
-
-    cfg.model_file = "./inference/cls_infer.pdmodel"
-    cfg.params_file = "./inference/cls_infer.pdiparams"
-    cfg.batch_size = 1
-    cfg.use_gpu = False
-    cfg.enable_mkldnn = False
-    cfg.ir_optim = True
-    cfg.gpu_mem = 8000
-    cfg.use_fp16 = False
-    cfg.use_tensorrt = False
-    cfg.cpu_num_threads = 10
-    cfg.enable_profile = False
-
-    # params for preprocess
-    cfg.resize_short = 256
-    cfg.resize = 224
-    cfg.normalize = True
-
-    return cfg
+def get_default_confg():
+    return {
+        'Global': {
+            "inference_model_dir": "../inference/",
+            "batch_size": 1,
+            'use_gpu': False,
+            'use_fp16': False,
+            'enable_mkldnn': False,
+            'cpu_num_threads': 1,
+            'use_tensorrt': False,
+            'ir_optim': False,
+            "gpu_mem": 8000,
+            'enable_profile': False,
+            "enable_benchmark": False
+        },
+        'PostProcess': {
+            'name': 'Topk',
+            'topk': 5,
+            'class_id_map_file': './utils/imagenet1k_label_list.txt'
+        }
+    }
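A small sketch (not from this commit) of how the nested dict returned by `get_default_confg` is consumed. `ClasSystem._load_config` above wraps it in `config.AttrDict` and calls `config.create_attr_dict`, after which values are read and written as attributes such as `cfg.Global.use_gpu`; the `PostProcess.topk` override below is a hypothetical illustration of the same pattern, and the imports assume the script is run from `PaddleClas/deploy` (as the module itself assumes):

```python
from hubserving.clas.params import get_default_confg
from utils import config

# Mirror what ClasSystem._load_config does with the default config dict.
cfg = config.AttrDict(get_default_confg())
config.create_attr_dict(cfg)

# The same overrides that _load_config applies when arguments are given:
cfg.Global.use_gpu = False
cfg.Global.enable_mkldnn = True

# Hypothetical tweak: return the top 3 results instead of the default 5.
cfg.PostProcess.topk = 3

print(cfg.Global.inference_model_dir, cfg.PostProcess.topk)
```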
deploy/hubserving/clas/test.py (deleted)

@@ -1,35 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import sys
-__dir__ = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(os.path.abspath(os.path.join(__dir__, '../../../')))
-import argparse
-import numpy as np
-import cv2
-import paddlehub as hub
-from tools.infer.utils import preprocess
-
-args = argparse.Namespace(resize_short=256, resize=224, normalize=True)
-
-img_path_list = ["./deploy/hubserving/ILSVRC2012_val_00006666.JPEG", ]
-
-module = hub.Module(name="clas_system")
-for i, img_path in enumerate(img_path_list):
-    img = cv2.imread(img_path)[:, :, ::-1]
-    img = preprocess(img, args)
-    batch_input_data = np.expand_dims(img, axis=0)
-    res = module.predict(batch_input_data)
-    print("The returned result of {}: {}".format(img_path, res))
deploy/hubserving/readme.md

@@ -4,7 +4,7 @@
 The HubServing service deployment package `clas` contains the following three files:
 ```
-deploy/hubserving/clas/
+hubserving/clas/
   └─  __init__.py    Empty file, required
   └─  config.json    Configuration file, optional, passed as an argument when starting the service with a config
   └─  module.py      Main module, required, contains the complete service logic
@@ -21,16 +21,16 @@ pip3 install paddlehub==2.0.0b1 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/s
 ### 2. Download the inference model
 Before installing the service module, prepare the inference model and put it in the correct path. The default model path is:
 ```
-Classification inference model structure file: ./inference/cls_infer.pdmodel
-Classification inference model weight file:    ./inference/cls_infer.pdiparams
+Classification inference model structure file: PaddleClas/inference/inference.pdmodel
+Classification inference model weight file:    PaddleClas/inference/inference.pdiparams
 ```

 **Note**:
-* The model path can be viewed and modified in `./PaddleClas/deploy/hubserving/clas/params.py`.
+* The model file path can be viewed and modified in `PaddleClas/deploy/hubserving/clas/params.py`:
   ```python
-  cfg.model_file = "./inference/cls_infer.pdmodel"
-  cfg.params_file = "./inference/cls_infer.pdiparams"
+  "inference_model_dir": "../inference/"
   ```
+  Note that the model file names (both the .pdmodel and the .pdiparams file) must be `inference`.
 * We also provide a large number of pretrained models based on the ImageNet-1k dataset. For the model list and download addresses, see the [model zoo overview](../../docs/zh_CN/models/models_intro.md). You can also use your own trained and converted models.

 ### 3. Install the service module
@@ -38,14 +38,17 @@ pip3 install paddlehub==2.0.0b1 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/s

 * On Linux, an installation example is as follows:
 ```shell
-# Install the service module:
-hub install deploy/hubserving/clas/
+cd PaddleClas/deploy
+# Install the service module:
+hub install hubserving/clas/
 ```

 * On Windows (where the folder separator is `\`), an installation example is as follows:

 ```shell
+cd PaddleClas\deploy
 # Install the service module:
-hub install deploy\hubserving\clas\
+hub install hubserving\clas\
 ```

 ### 4. Start the service
@@ -59,7 +62,6 @@ $ hub serving start --modules Module1==Version1 \
 ```

 **Parameters:**
-
 |Parameter|Purpose|
 |-|-|
 |--modules/-m| [**Required**] Model(s) preinstalled for PaddleHub Serving, listed as multiple Module==Version key-value pairs<br>*`When Version is not specified, the latest version is selected by default`*|

@@ -108,30 +110,32 @@
 For example, to start the pipelined service on GPU card 3:
 ```shell
+cd PaddleClas/deploy
 export CUDA_VISIBLE_DEVICES=3
-hub serving start -c deploy/hubserving/clas/config.json
+hub serving start -c hubserving/clas/config.json
 ```

 ## Send a prediction request
 After the server is configured, you can send a prediction request and get the prediction result with the following command:

-```python tools/test_hubserving.py server_url image_path```
+```shell
+cd PaddleClas/deploy
+python hubserving/test_hubserving.py server_url image_path
+```

 Two required parameters need to be passed to the script:
 - **server_url**: the service address, in the format
 `http://[ip_address]:[port]/predict/[module_name]`
 - **image_path**: the path of the test image, either a single image or a directory of images.
-- **top_k**: [**Optional**] Return the top `top_k` `score`s, `1` by default.
 - **batch_size**: [**Optional**] Predict in batches of `batch_size`, `1` by default.
-- **resize_short**: [**Optional**] Resize the image proportionally so that its shorter side equals `resize_short`, `256` by default.
-- **resize**: [**Optional**] Resize the image to `resize * resize`, `224` by default.
-- **normalize**: [**Optional**] Whether to normalize the image, `True` by default.
-
-**Note**: If you use `Transformer`-series models, such as `DeiT_***_384` and `ViT_***_384`, pay attention to the input data size of the model; `--resize_short=384 --resize=384` needs to be specified.
-
 An example request:
-```python tools/test_hubserving.py --server_url http://127.0.0.1:8866/predict/clas_system --image_file ./deploy/hubserving/ILSVRC2012_val_00006666.JPEG --top_k 5```
+```shell
+python hubserving/test_hubserving.py --server_url http://127.0.0.1:8866/predict/clas_system --image_file ./hubserving/ILSVRC2012_val_00006666.JPEG --batch_size 8
+```

 ### Description of the returned result format
 The returned result is a list containing the top-k classification results with their scores, as well as the prediction time for the image, as follows:

@@ -143,7 +147,7 @@ list: 返回结果
 └─ float: prediction time for this image, in seconds
 ```

-**Note:** If you need to add, delete, or modify the returned fields, you can make the changes in the `module.py` file of the corresponding module; for the complete process, refer to the next section on customizing the service module.
+**Note:** If you need to add, delete, or modify the returned fields, you can modify the corresponding module; for the complete process, refer to the next section on customizing the service module.

 ## Customize and modify the service module
 If you need to modify the service logic, you generally need to take the following steps:

@@ -151,16 +155,30 @@ list: 返回结果
 - 1. Stop the service
 ```hub serving stop --port/-p XXXX```

-- 2. Modify the code in the corresponding `module.py`, `params.py`, and other files according to your actual needs.
-  For example, if you need to replace the model used by the deployed service, modify the model path parameters `cfg.model_file` and `cfg.params_file` in `params.py`.
-
-  After modifying and installing (`hub install deploy/hubserving/clas/`), and before deploying, you can test the installed service module with `python deploy/hubserving/clas/test.py`.
+- 2. Modify the code in the corresponding `module.py`, `params.py`, and other files according to your actual needs. After `module.py` is modified, the module needs to be reinstalled (`hub install hubserving/clas/`) and redeployed. Before deploying, you can test the installed service module with `python hubserving/clas/module.py`.

 - 3. Uninstall the old service package
 ```hub uninstall clas_system```

 - 4. Install the modified new service package
-```hub install deploy/hubserving/clas/```
+```hub install hubserving/clas/```

 - 5. Restart the service
 ```hub serving start -m clas_system```

+**Note**:
+Commonly used parameters can be modified in [params.py](./clas/params.py):
+* To replace the model, modify the model file path parameter:
+  ```python
+  "inference_model_dir":
+  ```
+* To change the number of `top-k` results returned during post-processing:
+  ```python
+  'topk':
+  ```
+* To change the label to class id mapping file used during post-processing:
+  ```python
+  'class_id_map_file':
+  ```
+
+To avoid unnecessary latency and to allow prediction in batches of batch_size, the data preprocessing logic (including resize, crop, and other operations) is done on the client side, and therefore needs to be modified in [test_hubserving.py](./test_hubserving.py#L35-L52).
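Following up on the result format described in the readme above, a minimal sketch (not from this commit) of consuming one response. The field names `prediction`, `elapse`, `class_ids`, and `scores` are the ones read by `hubserving/test_hubserving.py` below; the numeric values here are made up for illustration:

```python
# `response` stands for the parsed JSON body of one prediction request,
# e.g. response = requests.post(...).json()
response = {
    "results": {
        "prediction": [
            {"class_ids": [8, 7, 86, 82, 21],
             "scores": [0.91, 0.03, 0.02, 0.01, 0.01]},
        ],
        "elapse": 0.042,
    }
}

results = response["results"]
for img_idx, pred in enumerate(results["prediction"]):
    top1_id = pred["class_ids"][0]
    top1_score = pred["scores"][0]
    print("image {}: top-1 class {} ({:.2f}), batch elapse {:.3f}s".format(
        img_idx, top1_id, top1_score, results["elapse"]))
```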
deploy/hubserving/test_hubserving.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,29 +15,54 @@
 import os
 import sys
 __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))

-from tools.infer.utils import parse_args, get_image_list, preprocess, np_to_b64
-from ppcls.utils import logger
-import numpy as np
-import cv2
 import time
 import requests
 import json
 import base64
+import argparse
+
+import numpy as np
+import cv2
+
+from utils import logger
+from utils.get_image_list import get_image_list
+from utils import config
+from utils.encode_decode import np_to_b64
+from python.preprocess import create_operators
+
+preprocess_config = [{
+    'ResizeImage': {
+        'resize_short': 256
+    }
+}, {
+    'CropImage': {
+        'size': 224
+    }
+}, {
+    'NormalizeImage': {
+        'scale': 0.00392157,
+        'mean': [0.485, 0.456, 0.406],
+        'std': [0.229, 0.224, 0.225],
+        'order': ''
+    }
+}, {
+    'ToCHWImage': None
+}]


 def main(args):
     image_path_list = get_image_list(args.image_file)
     headers = {"Content-type": "application/json"}
+    preprocess_ops = create_operators(preprocess_config)

-    cnt = 0
     predict_time = 0
     all_score = 0.0
     start_time = time.time()

-    batch_input_list = []
+    img_data_list = []
     img_name_list = []
+    cnt = 0
     for idx, img_path in enumerate(image_path_list):

@@ -48,22 +73,23 @@ def main(args):
                 format(img_path))
             continue
         else:
             img = img[:, :, ::-1]
-            data = preprocess(img, args)
-            batch_input_list.append(data)
+            for ops in preprocess_ops:
+                img = ops(img)
+            img = np.array(img)
+            img_data_list.append(img)

             img_name = img_path.split('/')[-1]
             img_name_list.append(img_name)
             cnt += 1
         if cnt % args.batch_size == 0 or (idx + 1) == len(image_path_list):
-            batch_input = np.array(batch_input_list)
-            b64str, revert_shape = np_to_b64(batch_input)
+            inputs = np.array(img_data_list)
+            b64str, revert_shape = np_to_b64(inputs)
             data = {
                 "images": b64str,
                 "revert_params": {
                     "shape": revert_shape,
-                    "dtype": str(batch_input.dtype)
-                },
-                "top_k": args.top_k
+                    "dtype": str(inputs.dtype)
+                }
             }
             try:
                 r = requests.post(

@@ -80,24 +106,25 @@ def main(args):
                 continue
             else:
                 results = r.json()["results"]
-                batch_result_list = results["prediction"]
+                preds = results["prediction"]
                 elapse = results["elapse"]

-                cnt += len(batch_result_list)
+                cnt += len(preds)
                 predict_time += elapse

-                for number, result_list in enumerate(batch_result_list):
+                for number, result_list in enumerate(preds):
                     all_score += result_list["scores"][0]
                     result_str = ""
-                    for i in range(len(result_list["clas_ids"])):
+                    for i in range(len(result_list["class_ids"])):
                         result_str += "{}: {:.2f}\t".format(
-                            result_list["clas_ids"][i],
+                            result_list["class_ids"][i],
                             result_list["scores"][i])
-                    logger.info("File:{}, The top-{} result(s): {}".format(
-                        img_name_list[number], args.top_k, result_str))
+                    logger.info("File:{}, The result(s): {}".format(
+                        img_name_list[number], result_str))

             finally:
-                batch_input_list = []
+                img_data_list = []
                 img_name_list = []

     total_time = time.time() - start_time

@@ -109,5 +136,10 @@ def main(args):


 if __name__ == '__main__':
-    args = parse_args()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--server_url", type=str)
+    parser.add_argument("--image_file", type=str)
+    parser.add_argument("--batch_size", type=int, default=1)
+    args = parser.parse_args()

     main(args)
deploy/python/predict_cls.py

@@ -24,16 +24,22 @@ from utils import logger
 from utils import config
 from utils.predictor import Predictor
 from utils.get_image_list import get_image_list
-from preprocess import create_operators
-from postprocess import build_postprocess
+from python.preprocess import create_operators
+from python.postprocess import build_postprocess


 class ClsPredictor(Predictor):
     def __init__(self, config):
         super().__init__(config["Global"])
-        self.preprocess_ops = create_operators(config["PreProcess"][
-            "transform_ops"])
-        self.postprocess = build_postprocess(config["PostProcess"])
+
+        self.preprocess_ops = []
+        self.postprocess = None
+        if "PreProcess" in config:
+            if "transform_ops" in config["PreProcess"]:
+                self.preprocess_ops = create_operators(config["PreProcess"][
+                    "transform_ops"])
+        if "PostProcess" in config:
+            self.postprocess = build_postprocess(config["PostProcess"])

     def predict(self, images):
         input_names = self.paddle_predictor.get_input_names()
deploy/python/preprocess.py

@@ -26,7 +26,7 @@ import cv2
 import numpy as np
 import importlib

-from det_preprocess import DetNormalizeImage, DetPadStride, DetPermute, DetResize
+from python.det_preprocess import DetNormalizeImage, DetPadStride, DetPermute, DetResize


 def create_operators(params):
deploy/utils/__init__.py

@@ -2,3 +2,4 @@ from . import logger
 from . import config
 from . import get_image_list
 from . import predictor
+from . import encode_decode
deploy/utils/config.py

@@ -1,4 +1,4 @@
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,7 +16,9 @@ import os
 import copy
 import argparse
+import yaml

 from utils import logger

 __all__ = ['get_config']
deploy/utils/encode_decode.py (new file)

@@ -0,0 +1,31 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+
+import numpy as np
+
+
+def np_to_b64(images):
+    img_str = base64.b64encode(images).decode('utf8')
+    return img_str, images.shape
+
+
+def b64_to_np(b64str, revert_params):
+    shape = revert_params["shape"]
+    dtype = revert_params["dtype"]
+    dtype = getattr(np, dtype) if isinstance(str, type(dtype)) else dtype
+    data = base64.b64decode(b64str.encode('utf8'))
+    data = np.fromstring(data, dtype).reshape(shape)
+    return data
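A round-trip sketch (not from this commit) of the two helpers added here, illustrating the `revert_params` contract used by `serving_method` and `test_hubserving.py`. The import path assumes the script is run from `PaddleClas/deploy`. Note that `np.fromstring`, as used above, is deprecated for binary data in newer NumPy releases (`np.frombuffer` is the usual replacement); the committed code is shown as-is.

```python
import numpy as np

from utils.encode_decode import np_to_b64, b64_to_np

batch = np.random.rand(2, 3, 224, 224).astype("float32")

# Encode: the array's raw bytes become a base64 string; shape travels separately.
b64str, shape = np_to_b64(batch)
revert_params = {"shape": shape, "dtype": str(batch.dtype)}

# Decode: the server side rebuilds the ndarray from the string and revert_params.
restored = b64_to_np(b64str, revert_params)
assert restored.shape == batch.shape
assert np.allclose(restored, batch)
```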
File diff suppressed because it is too large