mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Enhancement] Enhance ppl for all codebases (#177)
* enhance ppl for all codebases * fix dump info * fix md and use not None * remove redundant codes * safe convert empty ppl tensor * add examples and remove useless lines
This commit is contained in:
parent
0043848a52
commit
a4dceb4bb4
5
configs/mmcls/classification_ppl_dynamic-224x224.py
Normal file
5
configs/mmcls/classification_ppl_dynamic-224x224.py
Normal file
@ -0,0 +1,5 @@
|
||||
_base_ = ['./classification_dynamic.py', '../_base_/backends/ppl.py']
|
||||
|
||||
onnx_config = dict(input_shape=[224, 224])
|
||||
|
||||
backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 224, 224]))
|
@ -1 +0,0 @@
|
||||
_base_ = ['./classification_dynamic.py', '../_base_/backends/ppl.py']
|
@ -1,3 +0,0 @@
|
||||
_base_ = ['./classification_static.py', '../_base_/backends/ppl.py']
|
||||
|
||||
onnx_config = dict(input_shape=None)
|
@ -0,0 +1,5 @@
|
||||
_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/ppl.py']
|
||||
|
||||
onnx_config = dict(input_shape=(1344, 800))
|
||||
|
||||
backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 800, 1344]))
|
@ -1 +0,0 @@
|
||||
_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/ppl.py']
|
@ -1 +0,0 @@
|
||||
_base_ = ['../_base_/base_static.py', '../../_base_/backends/ppl.py']
|
@ -0,0 +1,5 @@
|
||||
_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/ppl.py']
|
||||
|
||||
onnx_config = dict(input_shape=(1344, 800))
|
||||
|
||||
backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 800, 1344]))
|
@ -1 +0,0 @@
|
||||
_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/ppl.py']
|
@ -1 +0,0 @@
|
||||
_base_ = ['../_base_/base_static.py', '../../_base_/backends/ppl.py']
|
5
configs/mmdet/with-mask/mask_ppl_dynamic-800x1344.py
Normal file
5
configs/mmdet/with-mask/mask_ppl_dynamic-800x1344.py
Normal file
@ -0,0 +1,5 @@
|
||||
_base_ = ['../_base_/mask_base_dynamic.py', '../../_base_/backends/ppl.py']
|
||||
|
||||
onnx_config = dict(input_shape=(1344, 800))
|
||||
|
||||
backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 800, 1344]))
|
@ -0,0 +1,5 @@
|
||||
_base_ = ['./super-resolution_dynamic.py', '../../_base_/backends/ppl.py']
|
||||
|
||||
onnx_config = dict(input_shape=(32, 32))
|
||||
|
||||
backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 32, 32]))
|
@ -1 +0,0 @@
|
||||
_base_ = ['./super-resolution_dynamic.py', '../../_base_/backends/ppl.py']
|
@ -1,3 +0,0 @@
|
||||
_base_ = ['./super-resolution_static.py', '../../_base_/backends/ppl.py']
|
||||
|
||||
onnx_config = dict(input_shape=None)
|
@ -0,0 +1,5 @@
|
||||
_base_ = ['./text-detection_dynamic.py', '../../_base_/backends/ppl.py']
|
||||
|
||||
onnx_config = dict(input_shape=(640, 640))
|
||||
|
||||
backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 640, 640]))
|
@ -1 +0,0 @@
|
||||
_base_ = ['./text-detection_dynamic.py', '../../_base_/backends/ppl.py']
|
@ -1,3 +0,0 @@
|
||||
_base_ = ['./text-detection_static.py', '../../_base_/backends/ppl.py']
|
||||
|
||||
onnx_config = dict(input_shape=None)
|
5
configs/mmseg/segmentation_ppl_dynamic-512x1024.py
Normal file
5
configs/mmseg/segmentation_ppl_dynamic-512x1024.py
Normal file
@ -0,0 +1,5 @@
|
||||
_base_ = ['./segmentation_dynamic.py', '../_base_/backends/ppl.py']
|
||||
|
||||
onnx_config = dict(input_shape=[1024, 512])
|
||||
|
||||
backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 512, 1024]))
|
@ -1 +0,0 @@
|
||||
_base_ = ['./segmentation_dynamic.py', '../_base_/backends/ppl.py']
|
@ -1,3 +0,0 @@
|
||||
_base_ = ['./segmentation_static.py', '../_base_/backends/ppl.py']
|
||||
|
||||
onnx_config = dict(input_shape=None)
|
@ -13,7 +13,7 @@ This tutorial is based on Linux systems like Ubuntu-18.04.
|
||||
Example:
|
||||
```bash
|
||||
python tools/deploy.py \
|
||||
configs/mmdet/single-stage/single-stage_ppl_dynamic.py \
|
||||
configs/mmdet/single-stage/single-stage_ppl_dynamic-800x1344.py \
|
||||
/mmdetection_dir/mmdetection/configs/retinanet/retinanet_r50_fpn_1x_coco.py \
|
||||
/tmp/snapshots/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth \
|
||||
tests/data/tiger.jpeg \
|
||||
|
@ -12,4 +12,5 @@ def is_available():
|
||||
|
||||
if is_available():
|
||||
from .ppl_utils import PPLWrapper, register_engines
|
||||
__all__ = ['register_engines', 'PPLWrapper']
|
||||
from .onnx2ppl import onnx2ppl
|
||||
__all__ = ['register_engines', 'PPLWrapper', 'onnx2ppl']
|
||||
|
68
mmdeploy/apis/ppl/onnx2ppl.py
Normal file
68
mmdeploy/apis/ppl/onnx2ppl.py
Normal file
@ -0,0 +1,68 @@
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import torch
|
||||
from pyppl import nn as pplnn
|
||||
|
||||
from mmdeploy.apis.ppl import register_engines
|
||||
|
||||
|
||||
def parse_cuda_device_id(device: str) -> int:
|
||||
"""Parse cuda device index from a string.
|
||||
|
||||
Args:
|
||||
device (str): The typical style of string specifying cuda device,
|
||||
e.g.: 'cuda:0'.
|
||||
|
||||
Returns:
|
||||
int: The parsed device id, defaults to `0`.
|
||||
"""
|
||||
device_id = 0
|
||||
if len(device) >= 6:
|
||||
device_id = torch.device(device).index
|
||||
return device_id
|
||||
|
||||
|
||||
def onnx2ppl(algo_file: str,
|
||||
onnx_model: str,
|
||||
device: str = 'cuda:0',
|
||||
input_shapes: Optional[Sequence[Sequence[int]]] = None,
|
||||
**kwargs):
|
||||
"""Convert ONNX to PPL.
|
||||
|
||||
PPL is capable of optimizing onnx model. The optimized algorithm is saved
|
||||
into `algo_file` in json format. Note that `input_shapes` actually require
|
||||
multiple shapes of inputs in its original design. But in the pipeline of
|
||||
our codebase, we only pass one input shape which can be modified by users'
|
||||
own preferences.
|
||||
|
||||
Args:
|
||||
algo_file (str): File path to save PPL optimization algorithm.
|
||||
onnx_model (str): Input onnx model.
|
||||
device (str): A string specifying cuda device, defaults to 'cuda:0'.
|
||||
input_shapes (Sequence[Sequence[int]] | None): shapes for PPL
|
||||
optimization, default to None.
|
||||
|
||||
Examples:
|
||||
>>> from mmdeploy.apis.ppl import onnx2ppl
|
||||
>>>
|
||||
>>> onnx2ppl(algo_file = 'example.json', onnx_model = 'example.onnx')
|
||||
"""
|
||||
if device == 'cpu':
|
||||
device_id = -1
|
||||
else:
|
||||
assert 'cuda' in device, f'unexpected device: {device}, must contain '
|
||||
'`cpu` or `cuda`'
|
||||
device_id = parse_cuda_device_id(device)
|
||||
if input_shapes is None:
|
||||
input_shapes = [[1, 3, 224, 224]] # PPL default shape for optimization
|
||||
|
||||
engines = register_engines(
|
||||
device_id,
|
||||
disable_avx512=False,
|
||||
quick_select=False,
|
||||
export_algo_file=algo_file,
|
||||
input_shapes=input_shapes)
|
||||
runtime_builder = pplnn.OnnxRuntimeBuilderFactory.CreateFromFile(
|
||||
onnx_model, engines)
|
||||
assert runtime_builder is not None, 'Failed to create '\
|
||||
'OnnxRuntimeBuilder.'
|
@ -1,18 +1,21 @@
|
||||
import logging
|
||||
import sys
|
||||
from typing import Dict
|
||||
from typing import Dict, Sequence
|
||||
|
||||
import numpy as np
|
||||
import pyppl.common as pplcommon
|
||||
import pyppl.nn as pplnn
|
||||
import torch
|
||||
from pyppl import common as pplcommon
|
||||
from pyppl import nn as pplnn
|
||||
|
||||
from mmdeploy.utils.timer import TimeCounter
|
||||
|
||||
|
||||
def register_engines(device_id: int,
|
||||
disable_avx512: bool = False,
|
||||
quick_select: bool = False):
|
||||
quick_select: bool = False,
|
||||
input_shapes: Sequence[Sequence[int]] = None,
|
||||
export_algo_file: str = None,
|
||||
import_algo_file: str = None):
|
||||
"""Register engines for ppl runtime.
|
||||
|
||||
Args:
|
||||
@ -21,6 +24,9 @@ def register_engines(device_id: int,
|
||||
Defaults to `False`.
|
||||
quick_select (bool): Whether to use default algorithms.
|
||||
Defaults to `False`.
|
||||
input_shapes (Sequence[Sequence[int]]): shapes for PPL optimization.
|
||||
export_algo_file (str): File path for exporting PPL optimization file.
|
||||
import_algo_file (str): File path for loading PPL optimization file.
|
||||
|
||||
Returns:
|
||||
list[pplnn.Engine]: A list of registered ppl engines.
|
||||
@ -59,6 +65,33 @@ def register_engines(device_id: int,
|
||||
pplcommon.GetRetCodeStr(status))
|
||||
sys.exit(-1)
|
||||
|
||||
if input_shapes is not None:
|
||||
status = cuda_engine.Configure(pplnn.CUDA_CONF_SET_INPUT_DIMS,
|
||||
input_shapes)
|
||||
if status != pplcommon.RC_SUCCESS:
|
||||
logging.error(
|
||||
'cuda engine Configure(CUDA_CONF_SET_INPUT_DIMS) failed: '
|
||||
+ pplcommon.GetRetCodeStr(status))
|
||||
sys.exit(-1)
|
||||
|
||||
if export_algo_file is not None:
|
||||
status = cuda_engine.Configure(pplnn.CUDA_CONF_EXPORT_ALGORITHMS,
|
||||
export_algo_file)
|
||||
if status != pplcommon.RC_SUCCESS:
|
||||
logging.error(
|
||||
'cuda engine Configure(CUDA_CONF_EXPORT_ALGORITHMS) '
|
||||
'failed: ' + pplcommon.GetRetCodeStr(status))
|
||||
sys.exit(-1)
|
||||
|
||||
if import_algo_file is not None:
|
||||
status = cuda_engine.Configure(pplnn.CUDA_CONF_IMPORT_ALGORITHMS,
|
||||
import_algo_file)
|
||||
if status != pplcommon.RC_SUCCESS:
|
||||
logging.error(
|
||||
'cuda engine Configure(CUDA_CONF_IMPORT_ALGORITHMS) '
|
||||
'failed: ' + pplcommon.GetRetCodeStr(status))
|
||||
sys.exit(-1)
|
||||
|
||||
engines.append(pplnn.Engine(cuda_engine))
|
||||
|
||||
return engines
|
||||
@ -68,7 +101,8 @@ class PPLWrapper(torch.nn.Module):
|
||||
"""PPL wrapper for inference.
|
||||
|
||||
Args:
|
||||
model_file (str): Input onnx model file.
|
||||
onnx_file (str): Path of input ONNX model file.
|
||||
algo_file (str): Path of PPL algorithm file.
|
||||
device_id (int): Device id to put model.
|
||||
|
||||
Examples:
|
||||
@ -76,21 +110,23 @@ class PPLWrapper(torch.nn.Module):
|
||||
>>> import torch
|
||||
>>>
|
||||
>>> onnx_file = 'model.onnx'
|
||||
>>> model = PPLWrapper(onnx_file, 0)
|
||||
>>> model = PPLWrapper(onnx_file, 'end2end.json', 0)
|
||||
>>> inputs = dict(input=torch.randn(1, 3, 224, 224))
|
||||
>>> outputs = model(inputs)
|
||||
>>> print(outputs)
|
||||
"""
|
||||
|
||||
def __init__(self, model_file: str, device_id: int):
|
||||
def __init__(self, onnx_file: str, algo_file: str, device_id: int):
|
||||
super(PPLWrapper, self).__init__()
|
||||
# enable quick select by default to speed up pipeline
|
||||
# TODO: open it to users after ppl supports saving serialized models
|
||||
# TODO: disable_avx512 will be removed or open to users in config
|
||||
engines = register_engines(
|
||||
device_id, disable_avx512=False, quick_select=True)
|
||||
device_id,
|
||||
disable_avx512=False,
|
||||
quick_select=False,
|
||||
import_algo_file=algo_file)
|
||||
runtime_builder = pplnn.OnnxRuntimeBuilderFactory.CreateFromFile(
|
||||
model_file, engines)
|
||||
onnx_file, engines)
|
||||
assert runtime_builder is not None, 'Failed to create '\
|
||||
'OnnxRuntimeBuilder.'
|
||||
|
||||
@ -119,7 +155,12 @@ class PPLWrapper(torch.nn.Module):
|
||||
outputs = []
|
||||
for i in range(self.runtime.GetOutputCount()):
|
||||
out_tensor = self.runtime.GetOutputTensor(i).ConvertToHost()
|
||||
outputs.append(np.array(out_tensor, copy=False))
|
||||
if out_tensor:
|
||||
outputs.append(np.array(out_tensor, copy=False))
|
||||
else:
|
||||
out_shape = self.runtime.GetOutputTensor(
|
||||
i).GetShape().GetDims()
|
||||
outputs.append(np.random.rand(*out_shape))
|
||||
return outputs
|
||||
|
||||
@TimeCounter.count_time()
|
||||
|
@ -132,15 +132,17 @@ class PPLClassifier(DeployBaseClassifier):
|
||||
"""Wrapper for classifier's inference with PPL.
|
||||
|
||||
Args:
|
||||
model_file (str): Path of input ONNX model file.
|
||||
onnx_file (str): Path of input ONNX model file.
|
||||
algo_file (str): Path of PPL algorithm file.
|
||||
class_names (Sequence[str]): A list of string specifying class names.
|
||||
device_id (int): An integer represents device index.
|
||||
"""
|
||||
|
||||
def __init__(self, model_file, class_names, device_id):
|
||||
def __init__(self, onnx_file, algo_file, class_names, device_id):
|
||||
super(PPLClassifier, self).__init__(class_names, device_id)
|
||||
from mmdeploy.apis.ppl import PPLWrapper
|
||||
model = PPLWrapper(model_file=model_file, device_id=device_id)
|
||||
model = PPLWrapper(
|
||||
onnx_file=onnx_file, algo_file=algo_file, device_id=device_id)
|
||||
self.model = model
|
||||
self.CLASSES = class_names
|
||||
|
||||
|
@ -300,7 +300,7 @@ class PPLDetector(DeployBaseDetector):
|
||||
"""Wrapper for detection's inference with PPL.
|
||||
|
||||
Args:
|
||||
model_file (str): Path of input ONNX model file.
|
||||
model_file (str): The path of input model file.
|
||||
class_names (Sequence[str]): A list of string specifying class names.
|
||||
device_id (int): An integer represents device index.
|
||||
"""
|
||||
@ -308,7 +308,7 @@ class PPLDetector(DeployBaseDetector):
|
||||
def __init__(self, model_file, class_names, device_id, **kwargs):
|
||||
super(PPLDetector, self).__init__(class_names, device_id)
|
||||
from mmdeploy.apis.ppl import PPLWrapper
|
||||
self.model = PPLWrapper(model_file, device_id)
|
||||
self.model = PPLWrapper(*model_file, device_id=device_id)
|
||||
|
||||
def forward_test(self, imgs: torch.Tensor, *args, **kwargs):
|
||||
"""Implement forward test.
|
||||
|
@ -211,14 +211,16 @@ class PPLRestorer(DeployBaseRestorer):
|
||||
"""Wrapper for restorer's inference with ppl.
|
||||
|
||||
Args:
|
||||
model_file (str): The path of input model file.
|
||||
onnx_file (str): Path of input ONNX model file.
|
||||
algo_file (str): Path of PPL algorithm file.
|
||||
device_id (int): An integer represents device index.
|
||||
test_cfg (mmcv.Config): The test config in model config, which is
|
||||
used in evaluation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_file: str,
|
||||
onnx_file: str,
|
||||
algo_file: str,
|
||||
device_id: int,
|
||||
test_cfg: Optional[mmcv.Config] = None,
|
||||
**kwargs):
|
||||
@ -226,7 +228,7 @@ class PPLRestorer(DeployBaseRestorer):
|
||||
device_id, test_cfg=test_cfg, **kwargs)
|
||||
|
||||
from mmdeploy.apis.ppl import PPLWrapper
|
||||
self.model = PPLWrapper(model_file, device_id)
|
||||
self.model = PPLWrapper(onnx_file, algo_file, device_id)
|
||||
|
||||
def forward_dummy(self, lq: torch.Tensor, *args, **kwargs):
|
||||
"""Run test inference for restorer with PPL.
|
||||
@ -278,7 +280,7 @@ def build_restorer(model_files: Sequence[str], backend: Backend,
|
||||
backend_model_class = model_map[model_type]
|
||||
|
||||
backend_model = backend_model_class(
|
||||
model_files[0], device_id=device_id, test_cfg=model_cfg.test_cfg)
|
||||
*model_files, device_id=device_id, test_cfg=model_cfg.test_cfg)
|
||||
|
||||
return backend_model
|
||||
|
||||
|
@ -373,7 +373,7 @@ class PPLDetector(DeployBaseTextDetector):
|
||||
"""Wrapper for TextDetector with PPL.
|
||||
|
||||
Args:
|
||||
model_file (str): The path of input model file.
|
||||
model_file (Sequence[str]): Paths of input model files.
|
||||
cfg (str | mmcv.ConfigDict): Input model config.
|
||||
device_id (int): An integer represents device index.
|
||||
show_score (bool): Whether to show scores. Defaults to `False`.
|
||||
@ -388,7 +388,7 @@ class PPLDetector(DeployBaseTextDetector):
|
||||
**kwargs):
|
||||
super(PPLDetector, self).__init__(cfg, device_id, show_score)
|
||||
from mmdeploy.apis.ppl import PPLWrapper
|
||||
model = PPLWrapper(model_file, device_id)
|
||||
model = PPLWrapper(model_file[0], model_file[1], device_id)
|
||||
self.model = model
|
||||
|
||||
def forward_of_backend(self, img: torch.Tensor, img_metas: Sequence[dict],
|
||||
@ -411,7 +411,8 @@ class PPLRecognizer(DeployBaseRecognizer):
|
||||
"""Wrapper for TextRecognizer with PPL.
|
||||
|
||||
Args:
|
||||
model_file (str): The path of input model file.
|
||||
onnx_file (str): Path of input ONNX model file.
|
||||
algo_file (str): Path of PPL algorithm file.
|
||||
cfg (str | mmcv.ConfigDict): Input model config.
|
||||
device_id (int): An integer represents device index.
|
||||
show_score (bool): Whether to show scores. Defaults to `False`.
|
||||
@ -419,6 +420,7 @@ class PPLRecognizer(DeployBaseRecognizer):
|
||||
|
||||
def __init__(self,
|
||||
model_file: str,
|
||||
algo_file: str,
|
||||
cfg: Union[mmcv.Config, mmcv.ConfigDict],
|
||||
device_id: int,
|
||||
show_score: bool = False,
|
||||
@ -426,7 +428,7 @@ class PPLRecognizer(DeployBaseRecognizer):
|
||||
**kwargs):
|
||||
super(PPLRecognizer, self).__init__(cfg, device_id, show_score)
|
||||
from mmdeploy.apis.ppl import PPLWrapper
|
||||
model = PPLWrapper(model_file, device_id)
|
||||
model = PPLWrapper(model_file, algo_file, device_id)
|
||||
self.model = model
|
||||
|
||||
def forward_of_backend(self, img: torch.Tensor, img_metas: Sequence[dict],
|
||||
|
@ -144,7 +144,7 @@ class PPLSegmentor(DeployBaseSegmentor):
|
||||
"""Wrapper for segmentation's inference with PPL.
|
||||
|
||||
Args:
|
||||
model_file (str): The path of input model file.
|
||||
model_file (Sequence[str]): Paths of input params and bin files.
|
||||
class_names (Sequence[str]): A list of string specifying class names.
|
||||
palette (np.ndarray): The palette of segmentation map.
|
||||
device_id (int): An integer represents device index.
|
||||
@ -154,7 +154,7 @@ class PPLSegmentor(DeployBaseSegmentor):
|
||||
palette: np.ndarray, device_id: int):
|
||||
super(PPLSegmentor, self).__init__(class_names, palette, device_id)
|
||||
from mmdeploy.apis.ppl import PPLWrapper
|
||||
self.model = PPLWrapper(model_file, device_id)
|
||||
self.model = PPLWrapper(model_file[0], model_file[1], device_id)
|
||||
|
||||
def forward_test(self, imgs: torch.Tensor, img_metas: Sequence[dict],
|
||||
**kwargs):
|
||||
|
@ -87,7 +87,10 @@ def onnx2backend(backend, onnx_file):
|
||||
elif backend == Backend.ONNXRUNTIME:
|
||||
return onnx_file
|
||||
elif backend == Backend.PPL:
|
||||
return onnx_file
|
||||
from mmdeploy.apis.ppl import onnx2ppl
|
||||
algo_file = tempfile.NamedTemporaryFile(suffix='.json').name
|
||||
onnx2ppl(algo_file=algo_file, onnx_model=onnx_file)
|
||||
return onnx_file, algo_file
|
||||
elif backend == Backend.NCNN:
|
||||
from mmdeploy.apis.ncnn import get_onnx2ncnn_path
|
||||
onnx2ncnn_path = get_onnx2ncnn_path()
|
||||
@ -117,7 +120,7 @@ def create_wrapper(backend, model_files):
|
||||
return ort_model
|
||||
elif backend == Backend.PPL:
|
||||
from mmdeploy.apis.ppl import PPLWrapper
|
||||
ppl_model = PPLWrapper(model_files, 0)
|
||||
ppl_model = PPLWrapper(model_files[0], None, device_id=0)
|
||||
return ppl_model
|
||||
elif backend == Backend.NCNN:
|
||||
from mmdeploy.apis.ncnn import NCNNWrapper
|
||||
|
@ -94,7 +94,7 @@ def test_PPLClassifier():
|
||||
wrapper.set(outputs=outputs)
|
||||
|
||||
from mmdeploy.mmcls.apis.inference import PPLClassifier
|
||||
ppl_classifier = PPLClassifier('', [''], 0)
|
||||
ppl_classifier = PPLClassifier('', '', [''], 0)
|
||||
imgs = torch.rand(1, 3, 64, 64)
|
||||
|
||||
results = ppl_classifier.forward(imgs, return_loss=False)
|
||||
|
@ -86,7 +86,7 @@ def test_PPLRestorer():
|
||||
wrapper.set(outputs=outputs)
|
||||
|
||||
from mmdeploy.mmedit.apis.inference import PPLRestorer
|
||||
ppl_restorer = PPLRestorer('', 0)
|
||||
ppl_restorer = PPLRestorer('', '', 0)
|
||||
imgs = torch.rand(1, 3, 64, 64)
|
||||
|
||||
results = ppl_restorer.forward(imgs)
|
||||
|
@ -89,7 +89,7 @@ def test_PPLDetector():
|
||||
from mmdeploy.mmocr.apis.inference import PPLDetector
|
||||
model_config = mmcv.Config.fromfile(
|
||||
'tests/test_mmocr/data/config/dbnet.py')
|
||||
ppl_detector = PPLDetector('', model_config, 0, False)
|
||||
ppl_detector = PPLDetector(['', ''], model_config, 0, False)
|
||||
imgs = [torch.rand(1, 3, 64, 64)]
|
||||
img_metas = [[{
|
||||
'ori_shape': [64, 64, 3],
|
||||
@ -194,7 +194,7 @@ def test_PPLRecognizer():
|
||||
from mmdeploy.mmocr.apis.inference import PPLRecognizer
|
||||
model_config = mmcv.Config.fromfile(
|
||||
'tests/test_mmocr/data/config/crnn.py')
|
||||
ppl_recognizer = PPLRecognizer('', model_config, 0, False)
|
||||
ppl_recognizer = PPLRecognizer('', '', model_config, 0, False)
|
||||
imgs = [torch.rand(1, 1, 32, 32)]
|
||||
img_metas = [[{'resize_shape': [32, 32], 'valid_ratio': 1.0}]]
|
||||
|
||||
|
@ -89,7 +89,7 @@ def test_PPLSegmentor():
|
||||
wrapper.set(outputs=outputs)
|
||||
|
||||
from mmdeploy.mmseg.apis.inference import PPLSegmentor
|
||||
ppl_segmentor = PPLSegmentor('', ['' for i in range(19)],
|
||||
ppl_segmentor = PPLSegmentor(['', ''], ['' for i in range(19)],
|
||||
np.empty([19], dtype=int), 0)
|
||||
imgs = torch.rand(1, 3, 64, 64)
|
||||
img_metas = [[{
|
||||
|
@ -96,12 +96,12 @@ def main():
|
||||
# load deploy_cfg
|
||||
deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)
|
||||
|
||||
if args.dump_info:
|
||||
dump_info(deploy_cfg, model_cfg, args.work_dir)
|
||||
|
||||
# create work_dir if not
|
||||
mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
|
||||
|
||||
if args.dump_info:
|
||||
dump_info(deploy_cfg, model_cfg, args.work_dir)
|
||||
|
||||
ret_value = mp.Value('d', 0, lock=False)
|
||||
|
||||
# convert onnx
|
||||
@ -241,6 +241,31 @@ def main():
|
||||
openvino_files.append(model_xml_path)
|
||||
backend_files = openvino_files
|
||||
|
||||
elif backend == Backend.PPL:
|
||||
from mmdeploy.apis.ppl import \
|
||||
is_available as is_available_ppl
|
||||
assert is_available_ppl(), \
|
||||
'PPL is not available, please install PPL first.'
|
||||
|
||||
from mmdeploy.apis.ppl import onnx2ppl
|
||||
ppl_files = []
|
||||
for onnx_path in onnx_files:
|
||||
algo_file = onnx_path.replace('.onnx', '.json')
|
||||
model_inputs = get_model_inputs(deploy_cfg)
|
||||
assert 'opt_shape' in model_inputs, 'expect opt_shape '
|
||||
'in deploy config for ppl'
|
||||
# PPL accepts only 1 input shape for optimization,
|
||||
# may get changed in the future
|
||||
input_shapes = [model_inputs.opt_shape]
|
||||
create_process(
|
||||
f'onnx2ppl with {onnx_path}',
|
||||
target=onnx2ppl,
|
||||
args=(algo_file, onnx_path),
|
||||
kwargs=dict(device=args.device, input_shapes=input_shapes),
|
||||
ret_value=ret_value)
|
||||
ppl_files += [onnx_path, algo_file]
|
||||
backend_files = ppl_files
|
||||
|
||||
if args.test_img is None:
|
||||
args.test_img = args.img
|
||||
# visualize model of the backend
|
||||
|
Loading…
x
Reference in New Issue
Block a user