From a4dceb4bb4b99b40d71ed4f4df6470a60076eeae Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Tue, 16 Nov 2021 19:16:46 +0800 Subject: [PATCH] [Enhancement] Enhance ppl for all codebases (#177) * enhance ppl for all codebases * fix dump info * fix md and use not None * remove redundant codes * safe convert empty ppl tensor * add examples and remove useless lines --- .../classification_ppl_dynamic-224x224.py | 5 ++ configs/mmcls/classification_ppl_dynamic.py | 1 - configs/mmcls/classification_ppl_static.py | 3 - .../single-stage_ppl_dynamic-800x1344.py | 5 ++ .../single-stage/single-stage_ppl_dynamic.py | 1 - .../single-stage/single-stage_ppl_static.py | 1 - .../two-stage_ppl_dynamic-800x1344.py | 5 ++ .../mmdet/two-stage/two-stage_ppl_dynamic.py | 1 - .../mmdet/two-stage/two-stage_ppl_static.py | 1 - .../with-mask/mask_ppl_dynamic-800x1344.py | 5 ++ .../super-resolution_ppl_dynamic-32x32.py | 5 ++ .../super-resolution_ppl_dynamic.py | 1 - .../super-resolution_ppl_static.py | 3 - .../text-detection_ppl_dynamic-640x640.py | 5 ++ .../text-detection_ppl_dynamic.py | 1 - .../text-detection_ppl_static.py | 3 - .../segmentation_ppl_dynamic-512x1024.py | 5 ++ configs/mmseg/segmentation_ppl_dynamic.py | 1 - configs/mmseg/segmentation_ppl_static.py | 3 - docs/backends/ppl.md | 2 +- mmdeploy/apis/ppl/__init__.py | 3 +- mmdeploy/apis/ppl/onnx2ppl.py | 68 +++++++++++++++++++ mmdeploy/apis/ppl/ppl_utils.py | 63 ++++++++++++++--- mmdeploy/mmcls/apis/inference.py | 8 ++- mmdeploy/mmdet/apis/inference.py | 4 +- mmdeploy/mmedit/apis/inference.py | 10 +-- mmdeploy/mmocr/apis/inference.py | 10 +-- mmdeploy/mmseg/apis/inference.py | 4 +- tests/test_apis/test_wrapper.py | 7 +- tests/test_mmcls/test_mmcls_apis.py | 2 +- tests/test_mmedit/test_mmedit_apis.py | 2 +- tests/test_mmocr/test_mmocr_apis.py | 4 +- tests/test_mmseg/test_mmseg_apis.py | 2 +- tools/deploy.py | 31 ++++++++- 34 files changed, 217 insertions(+), 58 deletions(-) create mode 100644 configs/mmcls/classification_ppl_dynamic-224x224.py delete mode 100644 configs/mmcls/classification_ppl_dynamic.py delete mode 100644 configs/mmcls/classification_ppl_static.py create mode 100644 configs/mmdet/single-stage/single-stage_ppl_dynamic-800x1344.py delete mode 100644 configs/mmdet/single-stage/single-stage_ppl_dynamic.py delete mode 100644 configs/mmdet/single-stage/single-stage_ppl_static.py create mode 100644 configs/mmdet/two-stage/two-stage_ppl_dynamic-800x1344.py delete mode 100644 configs/mmdet/two-stage/two-stage_ppl_dynamic.py delete mode 100644 configs/mmdet/two-stage/two-stage_ppl_static.py create mode 100644 configs/mmdet/with-mask/mask_ppl_dynamic-800x1344.py create mode 100644 configs/mmedit/super-resolution/super-resolution_ppl_dynamic-32x32.py delete mode 100644 configs/mmedit/super-resolution/super-resolution_ppl_dynamic.py delete mode 100644 configs/mmedit/super-resolution/super-resolution_ppl_static.py create mode 100644 configs/mmocr/text-detection/text-detection_ppl_dynamic-640x640.py delete mode 100644 configs/mmocr/text-detection/text-detection_ppl_dynamic.py delete mode 100644 configs/mmocr/text-detection/text-detection_ppl_static.py create mode 100644 configs/mmseg/segmentation_ppl_dynamic-512x1024.py delete mode 100644 configs/mmseg/segmentation_ppl_dynamic.py delete mode 100644 configs/mmseg/segmentation_ppl_static.py create mode 100644 mmdeploy/apis/ppl/onnx2ppl.py diff --git a/configs/mmcls/classification_ppl_dynamic-224x224.py b/configs/mmcls/classification_ppl_dynamic-224x224.py new file mode 100644 index 000000000..f640d281a --- /dev/null +++ b/configs/mmcls/classification_ppl_dynamic-224x224.py @@ -0,0 +1,5 @@ +_base_ = ['./classification_dynamic.py', '../_base_/backends/ppl.py'] + +onnx_config = dict(input_shape=[224, 224]) + +backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 224, 224])) diff --git a/configs/mmcls/classification_ppl_dynamic.py b/configs/mmcls/classification_ppl_dynamic.py deleted file mode 100644 index 15eb62c4e..000000000 --- a/configs/mmcls/classification_ppl_dynamic.py +++ /dev/null @@ -1 +0,0 @@ -_base_ = ['./classification_dynamic.py', '../_base_/backends/ppl.py'] diff --git a/configs/mmcls/classification_ppl_static.py b/configs/mmcls/classification_ppl_static.py deleted file mode 100644 index a5c577a69..000000000 --- a/configs/mmcls/classification_ppl_static.py +++ /dev/null @@ -1,3 +0,0 @@ -_base_ = ['./classification_static.py', '../_base_/backends/ppl.py'] - -onnx_config = dict(input_shape=None) diff --git a/configs/mmdet/single-stage/single-stage_ppl_dynamic-800x1344.py b/configs/mmdet/single-stage/single-stage_ppl_dynamic-800x1344.py new file mode 100644 index 000000000..eabdf42f2 --- /dev/null +++ b/configs/mmdet/single-stage/single-stage_ppl_dynamic-800x1344.py @@ -0,0 +1,5 @@ +_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/ppl.py'] + +onnx_config = dict(input_shape=(1344, 800)) + +backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 800, 1344])) diff --git a/configs/mmdet/single-stage/single-stage_ppl_dynamic.py b/configs/mmdet/single-stage/single-stage_ppl_dynamic.py deleted file mode 100644 index 5d8068fa3..000000000 --- a/configs/mmdet/single-stage/single-stage_ppl_dynamic.py +++ /dev/null @@ -1 +0,0 @@ -_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/ppl.py'] diff --git a/configs/mmdet/single-stage/single-stage_ppl_static.py b/configs/mmdet/single-stage/single-stage_ppl_static.py deleted file mode 100644 index b2eaf0f97..000000000 --- a/configs/mmdet/single-stage/single-stage_ppl_static.py +++ /dev/null @@ -1 +0,0 @@ -_base_ = ['../_base_/base_static.py', '../../_base_/backends/ppl.py'] diff --git a/configs/mmdet/two-stage/two-stage_ppl_dynamic-800x1344.py b/configs/mmdet/two-stage/two-stage_ppl_dynamic-800x1344.py new file mode 100644 index 000000000..eabdf42f2 --- /dev/null +++ b/configs/mmdet/two-stage/two-stage_ppl_dynamic-800x1344.py @@ -0,0 +1,5 @@ +_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/ppl.py'] + +onnx_config = dict(input_shape=(1344, 800)) + +backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 800, 1344])) diff --git a/configs/mmdet/two-stage/two-stage_ppl_dynamic.py b/configs/mmdet/two-stage/two-stage_ppl_dynamic.py deleted file mode 100644 index 5d8068fa3..000000000 --- a/configs/mmdet/two-stage/two-stage_ppl_dynamic.py +++ /dev/null @@ -1 +0,0 @@ -_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/ppl.py'] diff --git a/configs/mmdet/two-stage/two-stage_ppl_static.py b/configs/mmdet/two-stage/two-stage_ppl_static.py deleted file mode 100644 index b2eaf0f97..000000000 --- a/configs/mmdet/two-stage/two-stage_ppl_static.py +++ /dev/null @@ -1 +0,0 @@ -_base_ = ['../_base_/base_static.py', '../../_base_/backends/ppl.py'] diff --git a/configs/mmdet/with-mask/mask_ppl_dynamic-800x1344.py b/configs/mmdet/with-mask/mask_ppl_dynamic-800x1344.py new file mode 100644 index 000000000..e9dbdebb9 --- /dev/null +++ b/configs/mmdet/with-mask/mask_ppl_dynamic-800x1344.py @@ -0,0 +1,5 @@ +_base_ = ['../_base_/mask_base_dynamic.py', '../../_base_/backends/ppl.py'] + +onnx_config = dict(input_shape=(1344, 800)) + +backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 800, 1344])) diff --git a/configs/mmedit/super-resolution/super-resolution_ppl_dynamic-32x32.py b/configs/mmedit/super-resolution/super-resolution_ppl_dynamic-32x32.py new file mode 100644 index 000000000..af9f0a504 --- /dev/null +++ b/configs/mmedit/super-resolution/super-resolution_ppl_dynamic-32x32.py @@ -0,0 +1,5 @@ +_base_ = ['./super-resolution_dynamic.py', '../../_base_/backends/ppl.py'] + +onnx_config = dict(input_shape=(32, 32)) + +backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 32, 32])) diff --git a/configs/mmedit/super-resolution/super-resolution_ppl_dynamic.py b/configs/mmedit/super-resolution/super-resolution_ppl_dynamic.py deleted file mode 100644 index 96fc23f4e..000000000 --- a/configs/mmedit/super-resolution/super-resolution_ppl_dynamic.py +++ /dev/null @@ -1 +0,0 @@ -_base_ = ['./super-resolution_dynamic.py', '../../_base_/backends/ppl.py'] diff --git a/configs/mmedit/super-resolution/super-resolution_ppl_static.py b/configs/mmedit/super-resolution/super-resolution_ppl_static.py deleted file mode 100644 index 0ba1ef277..000000000 --- a/configs/mmedit/super-resolution/super-resolution_ppl_static.py +++ /dev/null @@ -1,3 +0,0 @@ -_base_ = ['./super-resolution_static.py', '../../_base_/backends/ppl.py'] - -onnx_config = dict(input_shape=None) diff --git a/configs/mmocr/text-detection/text-detection_ppl_dynamic-640x640.py b/configs/mmocr/text-detection/text-detection_ppl_dynamic-640x640.py new file mode 100644 index 000000000..1c0a6de45 --- /dev/null +++ b/configs/mmocr/text-detection/text-detection_ppl_dynamic-640x640.py @@ -0,0 +1,5 @@ +_base_ = ['./text-detection_dynamic.py', '../../_base_/backends/ppl.py'] + +onnx_config = dict(input_shape=(640, 640)) + +backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 640, 640])) diff --git a/configs/mmocr/text-detection/text-detection_ppl_dynamic.py b/configs/mmocr/text-detection/text-detection_ppl_dynamic.py deleted file mode 100644 index 8e9bf8a02..000000000 --- a/configs/mmocr/text-detection/text-detection_ppl_dynamic.py +++ /dev/null @@ -1 +0,0 @@ -_base_ = ['./text-detection_dynamic.py', '../../_base_/backends/ppl.py'] diff --git a/configs/mmocr/text-detection/text-detection_ppl_static.py b/configs/mmocr/text-detection/text-detection_ppl_static.py deleted file mode 100644 index 5bc4bdd9a..000000000 --- a/configs/mmocr/text-detection/text-detection_ppl_static.py +++ /dev/null @@ -1,3 +0,0 @@ -_base_ = ['./text-detection_static.py', '../../_base_/backends/ppl.py'] - -onnx_config = dict(input_shape=None) diff --git a/configs/mmseg/segmentation_ppl_dynamic-512x1024.py b/configs/mmseg/segmentation_ppl_dynamic-512x1024.py new file mode 100644 index 000000000..b7f1c673d --- /dev/null +++ b/configs/mmseg/segmentation_ppl_dynamic-512x1024.py @@ -0,0 +1,5 @@ +_base_ = ['./segmentation_dynamic.py', '../_base_/backends/ppl.py'] + +onnx_config = dict(input_shape=[1024, 512]) + +backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 512, 1024])) diff --git a/configs/mmseg/segmentation_ppl_dynamic.py b/configs/mmseg/segmentation_ppl_dynamic.py deleted file mode 100644 index c45dd6233..000000000 --- a/configs/mmseg/segmentation_ppl_dynamic.py +++ /dev/null @@ -1 +0,0 @@ -_base_ = ['./segmentation_dynamic.py', '../_base_/backends/ppl.py'] diff --git a/configs/mmseg/segmentation_ppl_static.py b/configs/mmseg/segmentation_ppl_static.py deleted file mode 100644 index c809ad3df..000000000 --- a/configs/mmseg/segmentation_ppl_static.py +++ /dev/null @@ -1,3 +0,0 @@ -_base_ = ['./segmentation_static.py', '../_base_/backends/ppl.py'] - -onnx_config = dict(input_shape=None) diff --git a/docs/backends/ppl.md b/docs/backends/ppl.md index 6cc4f15ff..0af329b6a 100644 --- a/docs/backends/ppl.md +++ b/docs/backends/ppl.md @@ -13,7 +13,7 @@ This tutorial is based on Linux systems like Ubuntu-18.04. Example: ```bash python tools/deploy.py \ - configs/mmdet/single-stage/single-stage_ppl_dynamic.py \ + configs/mmdet/single-stage/single-stage_ppl_dynamic-800x1344.py \ /mmdetection_dir/mmdetection/configs/retinanet/retinanet_r50_fpn_1x_coco.py \ /tmp/snapshots/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth \ tests/data/tiger.jpeg \ diff --git a/mmdeploy/apis/ppl/__init__.py b/mmdeploy/apis/ppl/__init__.py index 85d24fb3c..4b0d4e954 100644 --- a/mmdeploy/apis/ppl/__init__.py +++ b/mmdeploy/apis/ppl/__init__.py @@ -12,4 +12,5 @@ def is_available(): if is_available(): from .ppl_utils import PPLWrapper, register_engines - __all__ = ['register_engines', 'PPLWrapper'] + from .onnx2ppl import onnx2ppl + __all__ = ['register_engines', 'PPLWrapper', 'onnx2ppl'] diff --git a/mmdeploy/apis/ppl/onnx2ppl.py b/mmdeploy/apis/ppl/onnx2ppl.py new file mode 100644 index 000000000..331f62cb2 --- /dev/null +++ b/mmdeploy/apis/ppl/onnx2ppl.py @@ -0,0 +1,68 @@ +from typing import Optional, Sequence + +import torch +from pyppl import nn as pplnn + +from mmdeploy.apis.ppl import register_engines + + +def parse_cuda_device_id(device: str) -> int: + """Parse cuda device index from a string. + + Args: + device (str): The typical style of string specifying cuda device, + e.g.: 'cuda:0'. + + Returns: + int: The parsed device id, defaults to `0`. + """ + device_id = 0 + if len(device) >= 6: + device_id = torch.device(device).index + return device_id + + +def onnx2ppl(algo_file: str, + onnx_model: str, + device: str = 'cuda:0', + input_shapes: Optional[Sequence[Sequence[int]]] = None, + **kwargs): + """Convert ONNX to PPL. + + PPL is capable of optimizing onnx model. The optimized algorithm is saved + into `algo_file` in json format. Note that `input_shapes` actually require + multiple shapes of inputs in its original design. But in the pipeline of + our codebase, we only pass one input shape which can be modified by users' + own preferences. + + Args: + algo_file (str): File path to save PPL optimization algorithm. + onnx_model (str): Input onnx model. + device (str): A string specifying cuda device, defaults to 'cuda:0'. + input_shapes (Sequence[Sequence[int]] | None): shapes for PPL + optimization, default to None. + + Examples: + >>> from mmdeploy.apis.ppl import onnx2ppl + >>> + >>> onnx2ppl(algo_file = 'example.json', onnx_model = 'example.onnx') + """ + if device == 'cpu': + device_id = -1 + else: + assert 'cuda' in device, f'unexpected device: {device}, must contain ' + '`cpu` or `cuda`' + device_id = parse_cuda_device_id(device) + if input_shapes is None: + input_shapes = [[1, 3, 224, 224]] # PPL default shape for optimization + + engines = register_engines( + device_id, + disable_avx512=False, + quick_select=False, + export_algo_file=algo_file, + input_shapes=input_shapes) + runtime_builder = pplnn.OnnxRuntimeBuilderFactory.CreateFromFile( + onnx_model, engines) + assert runtime_builder is not None, 'Failed to create '\ + 'OnnxRuntimeBuilder.' diff --git a/mmdeploy/apis/ppl/ppl_utils.py b/mmdeploy/apis/ppl/ppl_utils.py index 370181035..9ad8fcfa8 100644 --- a/mmdeploy/apis/ppl/ppl_utils.py +++ b/mmdeploy/apis/ppl/ppl_utils.py @@ -1,18 +1,21 @@ import logging import sys -from typing import Dict +from typing import Dict, Sequence import numpy as np -import pyppl.common as pplcommon -import pyppl.nn as pplnn import torch +from pyppl import common as pplcommon +from pyppl import nn as pplnn from mmdeploy.utils.timer import TimeCounter def register_engines(device_id: int, disable_avx512: bool = False, - quick_select: bool = False): + quick_select: bool = False, + input_shapes: Sequence[Sequence[int]] = None, + export_algo_file: str = None, + import_algo_file: str = None): """Register engines for ppl runtime. Args: @@ -21,6 +24,9 @@ def register_engines(device_id: int, Defaults to `False`. quick_select (bool): Whether to use default algorithms. Defaults to `False`. + input_shapes (Sequence[Sequence[int]]): shapes for PPL optimization. + export_algo_file (str): File path for exporting PPL optimization file. + import_algo_file (str): File path for loading PPL optimization file. Returns: list[pplnn.Engine]: A list of registered ppl engines. @@ -59,6 +65,33 @@ def register_engines(device_id: int, pplcommon.GetRetCodeStr(status)) sys.exit(-1) + if input_shapes is not None: + status = cuda_engine.Configure(pplnn.CUDA_CONF_SET_INPUT_DIMS, + input_shapes) + if status != pplcommon.RC_SUCCESS: + logging.error( + 'cuda engine Configure(CUDA_CONF_SET_INPUT_DIMS) failed: ' + + pplcommon.GetRetCodeStr(status)) + sys.exit(-1) + + if export_algo_file is not None: + status = cuda_engine.Configure(pplnn.CUDA_CONF_EXPORT_ALGORITHMS, + export_algo_file) + if status != pplcommon.RC_SUCCESS: + logging.error( + 'cuda engine Configure(CUDA_CONF_EXPORT_ALGORITHMS) ' + 'failed: ' + pplcommon.GetRetCodeStr(status)) + sys.exit(-1) + + if import_algo_file is not None: + status = cuda_engine.Configure(pplnn.CUDA_CONF_IMPORT_ALGORITHMS, + import_algo_file) + if status != pplcommon.RC_SUCCESS: + logging.error( + 'cuda engine Configure(CUDA_CONF_IMPORT_ALGORITHMS) ' + 'failed: ' + pplcommon.GetRetCodeStr(status)) + sys.exit(-1) + engines.append(pplnn.Engine(cuda_engine)) return engines @@ -68,7 +101,8 @@ class PPLWrapper(torch.nn.Module): """PPL wrapper for inference. Args: - model_file (str): Input onnx model file. + onnx_file (str): Path of input ONNX model file. + algo_file (str): Path of PPL algorithm file. device_id (int): Device id to put model. Examples: @@ -76,21 +110,23 @@ class PPLWrapper(torch.nn.Module): >>> import torch >>> >>> onnx_file = 'model.onnx' - >>> model = PPLWrapper(onnx_file, 0) + >>> model = PPLWrapper(onnx_file, 'end2end.json', 0) >>> inputs = dict(input=torch.randn(1, 3, 224, 224)) >>> outputs = model(inputs) >>> print(outputs) """ - def __init__(self, model_file: str, device_id: int): + def __init__(self, onnx_file: str, algo_file: str, device_id: int): super(PPLWrapper, self).__init__() # enable quick select by default to speed up pipeline - # TODO: open it to users after ppl supports saving serialized models # TODO: disable_avx512 will be removed or open to users in config engines = register_engines( - device_id, disable_avx512=False, quick_select=True) + device_id, + disable_avx512=False, + quick_select=False, + import_algo_file=algo_file) runtime_builder = pplnn.OnnxRuntimeBuilderFactory.CreateFromFile( - model_file, engines) + onnx_file, engines) assert runtime_builder is not None, 'Failed to create '\ 'OnnxRuntimeBuilder.' @@ -119,7 +155,12 @@ class PPLWrapper(torch.nn.Module): outputs = [] for i in range(self.runtime.GetOutputCount()): out_tensor = self.runtime.GetOutputTensor(i).ConvertToHost() - outputs.append(np.array(out_tensor, copy=False)) + if out_tensor: + outputs.append(np.array(out_tensor, copy=False)) + else: + out_shape = self.runtime.GetOutputTensor( + i).GetShape().GetDims() + outputs.append(np.random.rand(*out_shape)) return outputs @TimeCounter.count_time() diff --git a/mmdeploy/mmcls/apis/inference.py b/mmdeploy/mmcls/apis/inference.py index e46c75578..b5fe05996 100644 --- a/mmdeploy/mmcls/apis/inference.py +++ b/mmdeploy/mmcls/apis/inference.py @@ -132,15 +132,17 @@ class PPLClassifier(DeployBaseClassifier): """Wrapper for classifier's inference with PPL. Args: - model_file (str): Path of input ONNX model file. + onnx_file (str): Path of input ONNX model file. + algo_file (str): Path of PPL algorithm file. class_names (Sequence[str]): A list of string specifying class names. device_id (int): An integer represents device index. """ - def __init__(self, model_file, class_names, device_id): + def __init__(self, onnx_file, algo_file, class_names, device_id): super(PPLClassifier, self).__init__(class_names, device_id) from mmdeploy.apis.ppl import PPLWrapper - model = PPLWrapper(model_file=model_file, device_id=device_id) + model = PPLWrapper( + onnx_file=onnx_file, algo_file=algo_file, device_id=device_id) self.model = model self.CLASSES = class_names diff --git a/mmdeploy/mmdet/apis/inference.py b/mmdeploy/mmdet/apis/inference.py index 14ccdc244..6d83c0ef1 100644 --- a/mmdeploy/mmdet/apis/inference.py +++ b/mmdeploy/mmdet/apis/inference.py @@ -300,7 +300,7 @@ class PPLDetector(DeployBaseDetector): """Wrapper for detection's inference with PPL. Args: - model_file (str): Path of input ONNX model file. + model_file (str): The path of input model file. class_names (Sequence[str]): A list of string specifying class names. device_id (int): An integer represents device index. """ @@ -308,7 +308,7 @@ class PPLDetector(DeployBaseDetector): def __init__(self, model_file, class_names, device_id, **kwargs): super(PPLDetector, self).__init__(class_names, device_id) from mmdeploy.apis.ppl import PPLWrapper - self.model = PPLWrapper(model_file, device_id) + self.model = PPLWrapper(*model_file, device_id=device_id) def forward_test(self, imgs: torch.Tensor, *args, **kwargs): """Implement forward test. diff --git a/mmdeploy/mmedit/apis/inference.py b/mmdeploy/mmedit/apis/inference.py index 77e26e469..c675574f1 100644 --- a/mmdeploy/mmedit/apis/inference.py +++ b/mmdeploy/mmedit/apis/inference.py @@ -211,14 +211,16 @@ class PPLRestorer(DeployBaseRestorer): """Wrapper for restorer's inference with ppl. Args: - model_file (str): The path of input model file. + onnx_file (str): Path of input ONNX model file. + algo_file (str): Path of PPL algorithm file. device_id (int): An integer represents device index. test_cfg (mmcv.Config): The test config in model config, which is used in evaluation. """ def __init__(self, - model_file: str, + onnx_file: str, + algo_file: str, device_id: int, test_cfg: Optional[mmcv.Config] = None, **kwargs): @@ -226,7 +228,7 @@ class PPLRestorer(DeployBaseRestorer): device_id, test_cfg=test_cfg, **kwargs) from mmdeploy.apis.ppl import PPLWrapper - self.model = PPLWrapper(model_file, device_id) + self.model = PPLWrapper(onnx_file, algo_file, device_id) def forward_dummy(self, lq: torch.Tensor, *args, **kwargs): """Run test inference for restorer with PPL. @@ -278,7 +280,7 @@ def build_restorer(model_files: Sequence[str], backend: Backend, backend_model_class = model_map[model_type] backend_model = backend_model_class( - model_files[0], device_id=device_id, test_cfg=model_cfg.test_cfg) + *model_files, device_id=device_id, test_cfg=model_cfg.test_cfg) return backend_model diff --git a/mmdeploy/mmocr/apis/inference.py b/mmdeploy/mmocr/apis/inference.py index a7cc52a50..ab8e8997f 100644 --- a/mmdeploy/mmocr/apis/inference.py +++ b/mmdeploy/mmocr/apis/inference.py @@ -373,7 +373,7 @@ class PPLDetector(DeployBaseTextDetector): """Wrapper for TextDetector with PPL. Args: - model_file (str): The path of input model file. + model_file (Sequence[str]): Paths of input model files. cfg (str | mmcv.ConfigDict): Input model config. device_id (int): An integer represents device index. show_score (bool): Whether to show scores. Defaults to `False`. @@ -388,7 +388,7 @@ class PPLDetector(DeployBaseTextDetector): **kwargs): super(PPLDetector, self).__init__(cfg, device_id, show_score) from mmdeploy.apis.ppl import PPLWrapper - model = PPLWrapper(model_file, device_id) + model = PPLWrapper(model_file[0], model_file[1], device_id) self.model = model def forward_of_backend(self, img: torch.Tensor, img_metas: Sequence[dict], @@ -411,7 +411,8 @@ class PPLRecognizer(DeployBaseRecognizer): """Wrapper for TextRecognizer with PPL. Args: - model_file (str): The path of input model file. + onnx_file (str): Path of input ONNX model file. + algo_file (str): Path of PPL algorithm file. cfg (str | mmcv.ConfigDict): Input model config. device_id (int): An integer represents device index. show_score (bool): Whether to show scores. Defaults to `False`. @@ -419,6 +420,7 @@ class PPLRecognizer(DeployBaseRecognizer): def __init__(self, model_file: str, + algo_file: str, cfg: Union[mmcv.Config, mmcv.ConfigDict], device_id: int, show_score: bool = False, @@ -426,7 +428,7 @@ class PPLRecognizer(DeployBaseRecognizer): **kwargs): super(PPLRecognizer, self).__init__(cfg, device_id, show_score) from mmdeploy.apis.ppl import PPLWrapper - model = PPLWrapper(model_file, device_id) + model = PPLWrapper(model_file, algo_file, device_id) self.model = model def forward_of_backend(self, img: torch.Tensor, img_metas: Sequence[dict], diff --git a/mmdeploy/mmseg/apis/inference.py b/mmdeploy/mmseg/apis/inference.py index 5f95678f3..1b9b53c01 100644 --- a/mmdeploy/mmseg/apis/inference.py +++ b/mmdeploy/mmseg/apis/inference.py @@ -144,7 +144,7 @@ class PPLSegmentor(DeployBaseSegmentor): """Wrapper for segmentation's inference with PPL. Args: - model_file (str): The path of input model file. + model_file (Sequence[str]): Paths of input params and bin files. class_names (Sequence[str]): A list of string specifying class names. palette (np.ndarray): The palette of segmentation map. device_id (int): An integer represents device index. @@ -154,7 +154,7 @@ class PPLSegmentor(DeployBaseSegmentor): palette: np.ndarray, device_id: int): super(PPLSegmentor, self).__init__(class_names, palette, device_id) from mmdeploy.apis.ppl import PPLWrapper - self.model = PPLWrapper(model_file, device_id) + self.model = PPLWrapper(model_file[0], model_file[1], device_id) def forward_test(self, imgs: torch.Tensor, img_metas: Sequence[dict], **kwargs): diff --git a/tests/test_apis/test_wrapper.py b/tests/test_apis/test_wrapper.py index ad600ca89..c61c95559 100644 --- a/tests/test_apis/test_wrapper.py +++ b/tests/test_apis/test_wrapper.py @@ -87,7 +87,10 @@ def onnx2backend(backend, onnx_file): elif backend == Backend.ONNXRUNTIME: return onnx_file elif backend == Backend.PPL: - return onnx_file + from mmdeploy.apis.ppl import onnx2ppl + algo_file = tempfile.NamedTemporaryFile(suffix='.json').name + onnx2ppl(algo_file=algo_file, onnx_model=onnx_file) + return onnx_file, algo_file elif backend == Backend.NCNN: from mmdeploy.apis.ncnn import get_onnx2ncnn_path onnx2ncnn_path = get_onnx2ncnn_path() @@ -117,7 +120,7 @@ def create_wrapper(backend, model_files): return ort_model elif backend == Backend.PPL: from mmdeploy.apis.ppl import PPLWrapper - ppl_model = PPLWrapper(model_files, 0) + ppl_model = PPLWrapper(model_files[0], None, device_id=0) return ppl_model elif backend == Backend.NCNN: from mmdeploy.apis.ncnn import NCNNWrapper diff --git a/tests/test_mmcls/test_mmcls_apis.py b/tests/test_mmcls/test_mmcls_apis.py index c92165627..3b20b022b 100644 --- a/tests/test_mmcls/test_mmcls_apis.py +++ b/tests/test_mmcls/test_mmcls_apis.py @@ -94,7 +94,7 @@ def test_PPLClassifier(): wrapper.set(outputs=outputs) from mmdeploy.mmcls.apis.inference import PPLClassifier - ppl_classifier = PPLClassifier('', [''], 0) + ppl_classifier = PPLClassifier('', '', [''], 0) imgs = torch.rand(1, 3, 64, 64) results = ppl_classifier.forward(imgs, return_loss=False) diff --git a/tests/test_mmedit/test_mmedit_apis.py b/tests/test_mmedit/test_mmedit_apis.py index 61298b6ea..14fcf9b2f 100644 --- a/tests/test_mmedit/test_mmedit_apis.py +++ b/tests/test_mmedit/test_mmedit_apis.py @@ -86,7 +86,7 @@ def test_PPLRestorer(): wrapper.set(outputs=outputs) from mmdeploy.mmedit.apis.inference import PPLRestorer - ppl_restorer = PPLRestorer('', 0) + ppl_restorer = PPLRestorer('', '', 0) imgs = torch.rand(1, 3, 64, 64) results = ppl_restorer.forward(imgs) diff --git a/tests/test_mmocr/test_mmocr_apis.py b/tests/test_mmocr/test_mmocr_apis.py index 371b36687..ff9da9a3c 100755 --- a/tests/test_mmocr/test_mmocr_apis.py +++ b/tests/test_mmocr/test_mmocr_apis.py @@ -89,7 +89,7 @@ def test_PPLDetector(): from mmdeploy.mmocr.apis.inference import PPLDetector model_config = mmcv.Config.fromfile( 'tests/test_mmocr/data/config/dbnet.py') - ppl_detector = PPLDetector('', model_config, 0, False) + ppl_detector = PPLDetector(['', ''], model_config, 0, False) imgs = [torch.rand(1, 3, 64, 64)] img_metas = [[{ 'ori_shape': [64, 64, 3], @@ -194,7 +194,7 @@ def test_PPLRecognizer(): from mmdeploy.mmocr.apis.inference import PPLRecognizer model_config = mmcv.Config.fromfile( 'tests/test_mmocr/data/config/crnn.py') - ppl_recognizer = PPLRecognizer('', model_config, 0, False) + ppl_recognizer = PPLRecognizer('', '', model_config, 0, False) imgs = [torch.rand(1, 1, 32, 32)] img_metas = [[{'resize_shape': [32, 32], 'valid_ratio': 1.0}]] diff --git a/tests/test_mmseg/test_mmseg_apis.py b/tests/test_mmseg/test_mmseg_apis.py index a172d864d..f038ef285 100644 --- a/tests/test_mmseg/test_mmseg_apis.py +++ b/tests/test_mmseg/test_mmseg_apis.py @@ -89,7 +89,7 @@ def test_PPLSegmentor(): wrapper.set(outputs=outputs) from mmdeploy.mmseg.apis.inference import PPLSegmentor - ppl_segmentor = PPLSegmentor('', ['' for i in range(19)], + ppl_segmentor = PPLSegmentor(['', ''], ['' for i in range(19)], np.empty([19], dtype=int), 0) imgs = torch.rand(1, 3, 64, 64) img_metas = [[{ diff --git a/tools/deploy.py b/tools/deploy.py index 61ef4900b..3c04d3c2d 100644 --- a/tools/deploy.py +++ b/tools/deploy.py @@ -96,12 +96,12 @@ def main(): # load deploy_cfg deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path) - if args.dump_info: - dump_info(deploy_cfg, model_cfg, args.work_dir) - # create work_dir if not mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) + if args.dump_info: + dump_info(deploy_cfg, model_cfg, args.work_dir) + ret_value = mp.Value('d', 0, lock=False) # convert onnx @@ -241,6 +241,31 @@ def main(): openvino_files.append(model_xml_path) backend_files = openvino_files + elif backend == Backend.PPL: + from mmdeploy.apis.ppl import \ + is_available as is_available_ppl + assert is_available_ppl(), \ + 'PPL is not available, please install PPL first.' + + from mmdeploy.apis.ppl import onnx2ppl + ppl_files = [] + for onnx_path in onnx_files: + algo_file = onnx_path.replace('.onnx', '.json') + model_inputs = get_model_inputs(deploy_cfg) + assert 'opt_shape' in model_inputs, 'expect opt_shape ' + 'in deploy config for ppl' + # PPL accepts only 1 input shape for optimization, + # may get changed in the future + input_shapes = [model_inputs.opt_shape] + create_process( + f'onnx2ppl with {onnx_path}', + target=onnx2ppl, + args=(algo_file, onnx_path), + kwargs=dict(device=args.device, input_shapes=input_shapes), + ret_value=ret_value) + ppl_files += [onnx_path, algo_file] + backend_files = ppl_files + if args.test_img is None: args.test_img = args.img # visualize model of the backend