[Enhancement] Enhance ppl for all codebases (#177)

* enhance ppl for all codebases

* fix dump info

* fix md and use not None

* remove redundant codes

* safe convert empty ppl tensor

* add examples and remove useless lines
This commit is contained in:
AllentDan 2021-11-16 19:16:46 +08:00 committed by GitHub
parent 0043848a52
commit a4dceb4bb4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 217 additions and 58 deletions

View File

@ -0,0 +1,5 @@
# Base config files this deploy config builds on: the dynamic classification
# settings plus the shared PPL backend settings.
_base_ = ['./classification_dynamic.py', '../_base_/backends/ppl.py']
# Input size used for ONNX export.
onnx_config = dict(input_shape=[224, 224])
# Shape passed to PPL as `opt_shape`; presumably (N, C, H, W) -- TODO confirm.
backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 224, 224]))

View File

@ -1 +0,0 @@
_base_ = ['./classification_dynamic.py', '../_base_/backends/ppl.py']

View File

@ -1,3 +0,0 @@
_base_ = ['./classification_static.py', '../_base_/backends/ppl.py']
onnx_config = dict(input_shape=None)

View File

@ -0,0 +1,5 @@
# Base config files this deploy config builds on: the dynamic base settings
# plus the shared PPL backend settings.
_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/ppl.py']
# Input size used for ONNX export; the order looks like (width, height) when
# compared with `opt_shape` below -- TODO confirm.
onnx_config = dict(input_shape=(1344, 800))
# Shape passed to PPL as `opt_shape`; presumably (N, C, H, W) -- TODO confirm.
backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 800, 1344]))

View File

@ -1 +0,0 @@
_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/ppl.py']

View File

@ -1 +0,0 @@
_base_ = ['../_base_/base_static.py', '../../_base_/backends/ppl.py']

View File

@ -0,0 +1,5 @@
# Base config files this deploy config builds on: the dynamic base settings
# plus the shared PPL backend settings.
_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/ppl.py']
# Input size used for ONNX export; the order looks like (width, height) when
# compared with `opt_shape` below -- TODO confirm.
onnx_config = dict(input_shape=(1344, 800))
# Shape passed to PPL as `opt_shape`; presumably (N, C, H, W) -- TODO confirm.
backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 800, 1344]))

View File

@ -1 +0,0 @@
_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/ppl.py']

View File

@ -1 +0,0 @@
_base_ = ['../_base_/base_static.py', '../../_base_/backends/ppl.py']

View File

@ -0,0 +1,5 @@
# Base config files this deploy config builds on: the dynamic mask base
# settings plus the shared PPL backend settings.
_base_ = ['../_base_/mask_base_dynamic.py', '../../_base_/backends/ppl.py']
# Input size used for ONNX export; the order looks like (width, height) when
# compared with `opt_shape` below -- TODO confirm.
onnx_config = dict(input_shape=(1344, 800))
# Shape passed to PPL as `opt_shape`; presumably (N, C, H, W) -- TODO confirm.
backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 800, 1344]))

View File

@ -0,0 +1,5 @@
# Base config files this deploy config builds on: the dynamic
# super-resolution settings plus the shared PPL backend settings.
_base_ = ['./super-resolution_dynamic.py', '../../_base_/backends/ppl.py']
# Input size used for ONNX export.
onnx_config = dict(input_shape=(32, 32))
# Shape passed to PPL as `opt_shape`; presumably (N, C, H, W) -- TODO confirm.
backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 32, 32]))

View File

@ -1 +0,0 @@
_base_ = ['./super-resolution_dynamic.py', '../../_base_/backends/ppl.py']

View File

@ -1,3 +0,0 @@
_base_ = ['./super-resolution_static.py', '../../_base_/backends/ppl.py']
onnx_config = dict(input_shape=None)

View File

@ -0,0 +1,5 @@
# Base config files this deploy config builds on: the dynamic text-detection
# settings plus the shared PPL backend settings.
_base_ = ['./text-detection_dynamic.py', '../../_base_/backends/ppl.py']
# Input size used for ONNX export.
onnx_config = dict(input_shape=(640, 640))
# Shape passed to PPL as `opt_shape`; presumably (N, C, H, W) -- TODO confirm.
backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 640, 640]))

View File

@ -1 +0,0 @@
_base_ = ['./text-detection_dynamic.py', '../../_base_/backends/ppl.py']

View File

@ -1,3 +0,0 @@
_base_ = ['./text-detection_static.py', '../../_base_/backends/ppl.py']
onnx_config = dict(input_shape=None)

View File

@ -0,0 +1,5 @@
# Base config files this deploy config builds on: the dynamic segmentation
# settings plus the shared PPL backend settings.
_base_ = ['./segmentation_dynamic.py', '../_base_/backends/ppl.py']
# Input size used for ONNX export; the order looks like (width, height) when
# compared with `opt_shape` below -- TODO confirm.
onnx_config = dict(input_shape=[1024, 512])
# Shape passed to PPL as `opt_shape`; presumably (N, C, H, W) -- TODO confirm.
backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 512, 1024]))

View File

@ -1 +0,0 @@
_base_ = ['./segmentation_dynamic.py', '../_base_/backends/ppl.py']

View File

@ -1,3 +0,0 @@
_base_ = ['./segmentation_static.py', '../_base_/backends/ppl.py']
onnx_config = dict(input_shape=None)

View File

@ -13,7 +13,7 @@ This tutorial is based on Linux systems like Ubuntu-18.04.
Example:
```bash
python tools/deploy.py \
configs/mmdet/single-stage/single-stage_ppl_dynamic.py \
configs/mmdet/single-stage/single-stage_ppl_dynamic-800x1344.py \
/mmdetection_dir/mmdetection/configs/retinanet/retinanet_r50_fpn_1x_coco.py \
/tmp/snapshots/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth \
tests/data/tiger.jpeg \

View File

@ -12,4 +12,5 @@ def is_available():
if is_available():
from .ppl_utils import PPLWrapper, register_engines
__all__ = ['register_engines', 'PPLWrapper']
from .onnx2ppl import onnx2ppl
__all__ = ['register_engines', 'PPLWrapper', 'onnx2ppl']

View File

@ -0,0 +1,68 @@
from typing import Optional, Sequence
import torch
from pyppl import nn as pplnn
from mmdeploy.apis.ppl import register_engines
def parse_cuda_device_id(device: str) -> int:
    """Parse cuda device index from a string.

    Args:
        device (str): The typical style of string specifying cuda device,
            e.g.: 'cuda:0'.

    Returns:
        int: The parsed device id, defaults to `0`.
    """
    # Strings shorter than 'cuda:N' (6 characters) cannot carry an index, so
    # they are never handed to torch and always map to the default device 0.
    if len(device) < 6:
        return 0
    # `torch.device(...).index` is None when the string names a device type
    # without an explicit index; fall back to the documented default instead
    # of leaking None to callers that expect an int.
    index = torch.device(device).index
    return 0 if index is None else index
def onnx2ppl(algo_file: str,
             onnx_model: str,
             device: str = 'cuda:0',
             input_shapes: Optional[Sequence[Sequence[int]]] = None,
             **kwargs):
    """Convert ONNX to PPL.

    PPL is capable of optimizing onnx model. The optimized algorithm is saved
    into `algo_file` in json format. Note that `input_shapes` actually require
    multiple shapes of inputs in its original design. But in the pipeline of
    our codebase, we only pass one input shape which can be modified by users'
    own preferences.

    Args:
        algo_file (str): File path to save PPL optimization algorithm.
        onnx_model (str): Input onnx model.
        device (str): A string specifying cuda device, defaults to 'cuda:0'.
            The literal 'cpu' is also accepted.
        input_shapes (Sequence[Sequence[int]] | None): shapes for PPL
            optimization, default to None.
        **kwargs: Accepted for call compatibility; not used by this function.

    Raises:
        AssertionError: If `device` is neither 'cpu' nor a string containing
            'cuda', or if the ONNX runtime builder cannot be created.

    Examples:
        >>> from mmdeploy.apis.ppl import onnx2ppl
        >>>
        >>> onnx2ppl(algo_file = 'example.json', onnx_model = 'example.onnx')
    """
    if device == 'cpu':
        # -1 is the device id handed to `register_engines` for the cpu case.
        device_id = -1
    else:
        # Parenthesised so both literals form ONE message; previously the
        # second literal was a stand-alone statement and the message was cut
        # off after "must contain ".
        assert 'cuda' in device, (
            f'unexpected device: {device}, must contain '
            '`cpu` or `cuda`')
        device_id = parse_cuda_device_id(device)

    if input_shapes is None:
        input_shapes = [[1, 3, 224, 224]]  # PPL default shape for optimization

    # Registering the engines with `export_algo_file` set is what asks PPL to
    # write the selected algorithms to `algo_file`.
    engines = register_engines(
        device_id,
        disable_avx512=False,
        quick_select=False,
        export_algo_file=algo_file,
        input_shapes=input_shapes)
    runtime_builder = pplnn.OnnxRuntimeBuilderFactory.CreateFromFile(
        onnx_model, engines)
    assert runtime_builder is not None, 'Failed to create '\
        'OnnxRuntimeBuilder.'

View File

@ -1,18 +1,21 @@
import logging
import sys
from typing import Dict
from typing import Dict, Sequence
import numpy as np
import pyppl.common as pplcommon
import pyppl.nn as pplnn
import torch
from pyppl import common as pplcommon
from pyppl import nn as pplnn
from mmdeploy.utils.timer import TimeCounter
def register_engines(device_id: int,
disable_avx512: bool = False,
quick_select: bool = False):
quick_select: bool = False,
input_shapes: Sequence[Sequence[int]] = None,
export_algo_file: str = None,
import_algo_file: str = None):
"""Register engines for ppl runtime.
Args:
@ -21,6 +24,9 @@ def register_engines(device_id: int,
Defaults to `False`.
quick_select (bool): Whether to use default algorithms.
Defaults to `False`.
input_shapes (Sequence[Sequence[int]]): shapes for PPL optimization.
export_algo_file (str): File path for exporting PPL optimization file.
import_algo_file (str): File path for loading PPL optimization file.
Returns:
list[pplnn.Engine]: A list of registered ppl engines.
@ -59,6 +65,33 @@ def register_engines(device_id: int,
pplcommon.GetRetCodeStr(status))
sys.exit(-1)
if input_shapes is not None:
status = cuda_engine.Configure(pplnn.CUDA_CONF_SET_INPUT_DIMS,
input_shapes)
if status != pplcommon.RC_SUCCESS:
logging.error(
'cuda engine Configure(CUDA_CONF_SET_INPUT_DIMS) failed: '
+ pplcommon.GetRetCodeStr(status))
sys.exit(-1)
if export_algo_file is not None:
status = cuda_engine.Configure(pplnn.CUDA_CONF_EXPORT_ALGORITHMS,
export_algo_file)
if status != pplcommon.RC_SUCCESS:
logging.error(
'cuda engine Configure(CUDA_CONF_EXPORT_ALGORITHMS) '
'failed: ' + pplcommon.GetRetCodeStr(status))
sys.exit(-1)
if import_algo_file is not None:
status = cuda_engine.Configure(pplnn.CUDA_CONF_IMPORT_ALGORITHMS,
import_algo_file)
if status != pplcommon.RC_SUCCESS:
logging.error(
'cuda engine Configure(CUDA_CONF_IMPORT_ALGORITHMS) '
'failed: ' + pplcommon.GetRetCodeStr(status))
sys.exit(-1)
engines.append(pplnn.Engine(cuda_engine))
return engines
@ -68,7 +101,8 @@ class PPLWrapper(torch.nn.Module):
"""PPL wrapper for inference.
Args:
model_file (str): Input onnx model file.
onnx_file (str): Path of input ONNX model file.
algo_file (str): Path of PPL algorithm file.
device_id (int): Device id to put model.
Examples:
@ -76,21 +110,23 @@ class PPLWrapper(torch.nn.Module):
>>> import torch
>>>
>>> onnx_file = 'model.onnx'
>>> model = PPLWrapper(onnx_file, 0)
>>> model = PPLWrapper(onnx_file, 'end2end.json', 0)
>>> inputs = dict(input=torch.randn(1, 3, 224, 224))
>>> outputs = model(inputs)
>>> print(outputs)
"""
def __init__(self, model_file: str, device_id: int):
def __init__(self, onnx_file: str, algo_file: str, device_id: int):
super(PPLWrapper, self).__init__()
# enable quick select by default to speed up pipeline
# TODO: open it to users after ppl supports saving serialized models
# TODO: disable_avx512 will be removed or open to users in config
engines = register_engines(
device_id, disable_avx512=False, quick_select=True)
device_id,
disable_avx512=False,
quick_select=False,
import_algo_file=algo_file)
runtime_builder = pplnn.OnnxRuntimeBuilderFactory.CreateFromFile(
model_file, engines)
onnx_file, engines)
assert runtime_builder is not None, 'Failed to create '\
'OnnxRuntimeBuilder.'
@ -119,7 +155,12 @@ class PPLWrapper(torch.nn.Module):
outputs = []
for i in range(self.runtime.GetOutputCount()):
out_tensor = self.runtime.GetOutputTensor(i).ConvertToHost()
outputs.append(np.array(out_tensor, copy=False))
if out_tensor:
outputs.append(np.array(out_tensor, copy=False))
else:
out_shape = self.runtime.GetOutputTensor(
i).GetShape().GetDims()
outputs.append(np.random.rand(*out_shape))
return outputs
@TimeCounter.count_time()

View File

@ -132,15 +132,17 @@ class PPLClassifier(DeployBaseClassifier):
"""Wrapper for classifier's inference with PPL.
Args:
model_file (str): Path of input ONNX model file.
onnx_file (str): Path of input ONNX model file.
algo_file (str): Path of PPL algorithm file.
class_names (Sequence[str]): A list of string specifying class names.
device_id (int): An integer represents device index.
"""
def __init__(self, model_file, class_names, device_id):
def __init__(self, onnx_file, algo_file, class_names, device_id):
super(PPLClassifier, self).__init__(class_names, device_id)
from mmdeploy.apis.ppl import PPLWrapper
model = PPLWrapper(model_file=model_file, device_id=device_id)
model = PPLWrapper(
onnx_file=onnx_file, algo_file=algo_file, device_id=device_id)
self.model = model
self.CLASSES = class_names

View File

@ -300,7 +300,7 @@ class PPLDetector(DeployBaseDetector):
"""Wrapper for detection's inference with PPL.
Args:
model_file (str): Path of input ONNX model file.
model_file (str): The path of input model file.
class_names (Sequence[str]): A list of string specifying class names.
device_id (int): An integer represents device index.
"""
@ -308,7 +308,7 @@ class PPLDetector(DeployBaseDetector):
def __init__(self, model_file, class_names, device_id, **kwargs):
super(PPLDetector, self).__init__(class_names, device_id)
from mmdeploy.apis.ppl import PPLWrapper
self.model = PPLWrapper(model_file, device_id)
self.model = PPLWrapper(*model_file, device_id=device_id)
def forward_test(self, imgs: torch.Tensor, *args, **kwargs):
"""Implement forward test.

View File

@ -211,14 +211,16 @@ class PPLRestorer(DeployBaseRestorer):
"""Wrapper for restorer's inference with ppl.
Args:
model_file (str): The path of input model file.
onnx_file (str): Path of input ONNX model file.
algo_file (str): Path of PPL algorithm file.
device_id (int): An integer represents device index.
test_cfg (mmcv.Config): The test config in model config, which is
used in evaluation.
"""
def __init__(self,
model_file: str,
onnx_file: str,
algo_file: str,
device_id: int,
test_cfg: Optional[mmcv.Config] = None,
**kwargs):
@ -226,7 +228,7 @@ class PPLRestorer(DeployBaseRestorer):
device_id, test_cfg=test_cfg, **kwargs)
from mmdeploy.apis.ppl import PPLWrapper
self.model = PPLWrapper(model_file, device_id)
self.model = PPLWrapper(onnx_file, algo_file, device_id)
def forward_dummy(self, lq: torch.Tensor, *args, **kwargs):
"""Run test inference for restorer with PPL.
@ -278,7 +280,7 @@ def build_restorer(model_files: Sequence[str], backend: Backend,
backend_model_class = model_map[model_type]
backend_model = backend_model_class(
model_files[0], device_id=device_id, test_cfg=model_cfg.test_cfg)
*model_files, device_id=device_id, test_cfg=model_cfg.test_cfg)
return backend_model

View File

@ -373,7 +373,7 @@ class PPLDetector(DeployBaseTextDetector):
"""Wrapper for TextDetector with PPL.
Args:
model_file (str): The path of input model file.
model_file (Sequence[str]): Paths of input model files.
cfg (str | mmcv.ConfigDict): Input model config.
device_id (int): An integer represents device index.
show_score (bool): Whether to show scores. Defaults to `False`.
@ -388,7 +388,7 @@ class PPLDetector(DeployBaseTextDetector):
**kwargs):
super(PPLDetector, self).__init__(cfg, device_id, show_score)
from mmdeploy.apis.ppl import PPLWrapper
model = PPLWrapper(model_file, device_id)
model = PPLWrapper(model_file[0], model_file[1], device_id)
self.model = model
def forward_of_backend(self, img: torch.Tensor, img_metas: Sequence[dict],
@ -411,7 +411,8 @@ class PPLRecognizer(DeployBaseRecognizer):
"""Wrapper for TextRecognizer with PPL.
Args:
model_file (str): The path of input model file.
onnx_file (str): Path of input ONNX model file.
algo_file (str): Path of PPL algorithm file.
cfg (str | mmcv.ConfigDict): Input model config.
device_id (int): An integer represents device index.
show_score (bool): Whether to show scores. Defaults to `False`.
@ -419,6 +420,7 @@ class PPLRecognizer(DeployBaseRecognizer):
def __init__(self,
model_file: str,
algo_file: str,
cfg: Union[mmcv.Config, mmcv.ConfigDict],
device_id: int,
show_score: bool = False,
@ -426,7 +428,7 @@ class PPLRecognizer(DeployBaseRecognizer):
**kwargs):
super(PPLRecognizer, self).__init__(cfg, device_id, show_score)
from mmdeploy.apis.ppl import PPLWrapper
model = PPLWrapper(model_file, device_id)
model = PPLWrapper(model_file, algo_file, device_id)
self.model = model
def forward_of_backend(self, img: torch.Tensor, img_metas: Sequence[dict],

View File

@ -144,7 +144,7 @@ class PPLSegmentor(DeployBaseSegmentor):
"""Wrapper for segmentation's inference with PPL.
Args:
model_file (str): The path of input model file.
model_file (Sequence[str]): Paths of input params and bin files.
class_names (Sequence[str]): A list of string specifying class names.
palette (np.ndarray): The palette of segmentation map.
device_id (int): An integer represents device index.
@ -154,7 +154,7 @@ class PPLSegmentor(DeployBaseSegmentor):
palette: np.ndarray, device_id: int):
super(PPLSegmentor, self).__init__(class_names, palette, device_id)
from mmdeploy.apis.ppl import PPLWrapper
self.model = PPLWrapper(model_file, device_id)
self.model = PPLWrapper(model_file[0], model_file[1], device_id)
def forward_test(self, imgs: torch.Tensor, img_metas: Sequence[dict],
**kwargs):

View File

@ -87,7 +87,10 @@ def onnx2backend(backend, onnx_file):
elif backend == Backend.ONNXRUNTIME:
return onnx_file
elif backend == Backend.PPL:
return onnx_file
from mmdeploy.apis.ppl import onnx2ppl
algo_file = tempfile.NamedTemporaryFile(suffix='.json').name
onnx2ppl(algo_file=algo_file, onnx_model=onnx_file)
return onnx_file, algo_file
elif backend == Backend.NCNN:
from mmdeploy.apis.ncnn import get_onnx2ncnn_path
onnx2ncnn_path = get_onnx2ncnn_path()
@ -117,7 +120,7 @@ def create_wrapper(backend, model_files):
return ort_model
elif backend == Backend.PPL:
from mmdeploy.apis.ppl import PPLWrapper
ppl_model = PPLWrapper(model_files, 0)
ppl_model = PPLWrapper(model_files[0], None, device_id=0)
return ppl_model
elif backend == Backend.NCNN:
from mmdeploy.apis.ncnn import NCNNWrapper

View File

@ -94,7 +94,7 @@ def test_PPLClassifier():
wrapper.set(outputs=outputs)
from mmdeploy.mmcls.apis.inference import PPLClassifier
ppl_classifier = PPLClassifier('', [''], 0)
ppl_classifier = PPLClassifier('', '', [''], 0)
imgs = torch.rand(1, 3, 64, 64)
results = ppl_classifier.forward(imgs, return_loss=False)

View File

@ -86,7 +86,7 @@ def test_PPLRestorer():
wrapper.set(outputs=outputs)
from mmdeploy.mmedit.apis.inference import PPLRestorer
ppl_restorer = PPLRestorer('', 0)
ppl_restorer = PPLRestorer('', '', 0)
imgs = torch.rand(1, 3, 64, 64)
results = ppl_restorer.forward(imgs)

View File

@ -89,7 +89,7 @@ def test_PPLDetector():
from mmdeploy.mmocr.apis.inference import PPLDetector
model_config = mmcv.Config.fromfile(
'tests/test_mmocr/data/config/dbnet.py')
ppl_detector = PPLDetector('', model_config, 0, False)
ppl_detector = PPLDetector(['', ''], model_config, 0, False)
imgs = [torch.rand(1, 3, 64, 64)]
img_metas = [[{
'ori_shape': [64, 64, 3],
@ -194,7 +194,7 @@ def test_PPLRecognizer():
from mmdeploy.mmocr.apis.inference import PPLRecognizer
model_config = mmcv.Config.fromfile(
'tests/test_mmocr/data/config/crnn.py')
ppl_recognizer = PPLRecognizer('', model_config, 0, False)
ppl_recognizer = PPLRecognizer('', '', model_config, 0, False)
imgs = [torch.rand(1, 1, 32, 32)]
img_metas = [[{'resize_shape': [32, 32], 'valid_ratio': 1.0}]]

View File

@ -89,7 +89,7 @@ def test_PPLSegmentor():
wrapper.set(outputs=outputs)
from mmdeploy.mmseg.apis.inference import PPLSegmentor
ppl_segmentor = PPLSegmentor('', ['' for i in range(19)],
ppl_segmentor = PPLSegmentor(['', ''], ['' for i in range(19)],
np.empty([19], dtype=int), 0)
imgs = torch.rand(1, 3, 64, 64)
img_metas = [[{

View File

@ -96,12 +96,12 @@ def main():
# load deploy_cfg
deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)
if args.dump_info:
dump_info(deploy_cfg, model_cfg, args.work_dir)
# create work_dir if not
mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
if args.dump_info:
dump_info(deploy_cfg, model_cfg, args.work_dir)
ret_value = mp.Value('d', 0, lock=False)
# convert onnx
@ -241,6 +241,31 @@ def main():
openvino_files.append(model_xml_path)
backend_files = openvino_files
elif backend == Backend.PPL:
from mmdeploy.apis.ppl import \
is_available as is_available_ppl
assert is_available_ppl(), \
'PPL is not available, please install PPL first.'
from mmdeploy.apis.ppl import onnx2ppl
ppl_files = []
for onnx_path in onnx_files:
algo_file = onnx_path.replace('.onnx', '.json')
model_inputs = get_model_inputs(deploy_cfg)
assert 'opt_shape' in model_inputs, 'expect opt_shape '
'in deploy config for ppl'
# PPL accepts only 1 input shape for optimization,
# may get changed in the future
input_shapes = [model_inputs.opt_shape]
create_process(
f'onnx2ppl with {onnx_path}',
target=onnx2ppl,
args=(algo_file, onnx_path),
kwargs=dict(device=args.device, input_shapes=input_shapes),
ret_value=ret_value)
ppl_files += [onnx_path, algo_file]
backend_files = ppl_files
if args.test_img is None:
args.test_img = args.img
# visualize model of the backend