[Refactor] Add backend manager for 1.x (#1515)

* backend manager 1.x

* update pplnn init

* rename file

* add to backend

* add check env and misc

* fix action

* fix ut

* fix comment
q.yao 2022-12-28 11:38:01 +08:00 committed by GitHub
parent 6288141bd5
commit d6fdb3e860
66 changed files with 1959 additions and 981 deletions

View File

@ -79,4 +79,4 @@ jobs:
python -m pip install torch==1.8.2 torchvision==0.9.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cpu python -m pip install torch==1.8.2 torchvision==0.9.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cpu
python -m pip install mmcv-lite python -m pip install mmcv-lite
python tools/scripts/build_ubuntu_x64_ncnn.py python tools/scripts/build_ubuntu_x64_ncnn.py
python -c 'import mmdeploy.apis.ncnn as ncnn_api; assert ncnn_api.is_available() and ncnn_api.is_custom_ops_available()' python -c 'import mmdeploy.apis.ncnn as ncnn_api; assert ncnn_api.is_available(with_custom_ops=True)'

View File

@ -36,7 +36,7 @@ jobs:
python -m pip install torch==1.8.2 torchvision==0.9.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cpu python -m pip install torch==1.8.2 torchvision==0.9.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cpu
python -m pip install mmcv-lite openmim python -m pip install mmcv-lite openmim
python tools/scripts/build_ubuntu_x64_ort.py python tools/scripts/build_ubuntu_x64_ort.py
python -c 'import mmdeploy.apis.onnxruntime as ort_api; assert ort_api.is_available() and ort_api.is_custom_ops_available()' python -c 'import mmdeploy.apis.onnxruntime as ort_api; assert ort_api.is_available(with_custom_ops=True)'
- name: test mmcls full pipeline - name: test mmcls full pipeline
run: | run: |
python -m mim install $(cat requirements/codebases.txt | grep mmcls) python -m mim install $(cat requirements/codebases.txt | grep mmcls)

View File

@ -31,3 +31,6 @@ mmdeploy_export(${PROJECT_NAME}_obj)
mmdeploy_add_module(${PROJECT_NAME} MODULE EXCLUDE "") mmdeploy_add_module(${PROJECT_NAME} MODULE EXCLUDE "")
target_link_libraries(${PROJECT_NAME} PUBLIC ${PROJECT_NAME}_obj) target_link_libraries(${PROJECT_NAME} PUBLIC ${PROJECT_NAME}_obj)
add_library(mmdeploy::torchscript_ops ALIAS ${PROJECT_NAME}) add_library(mmdeploy::torchscript_ops ALIAS ${PROJECT_NAME})
set(_TORCHJIT_OPS_DIR ${CMAKE_SOURCE_DIR}/mmdeploy/lib)
install(TARGETS ${PROJECT_NAME} DESTINATION ${_TORCHJIT_OPS_DIR})

View File

@ -123,32 +123,20 @@ The backends in MMDeploy must support the ONNX. The backend loads the ".onnx" fi
__all__ += ['onnx2ncnn', 'get_output_model_file'] __all__ += ['onnx2ncnn', 'get_output_model_file']
``` ```
Then add the codes about conversion to `tools/deploy.py` using these APIs if necessary. Create a backend manager class that derives from `BaseBackendManager` and implement its `to_backend` class method.
**Example:** **Example:**
```Python ```Python
# tools/deploy.py @classmethod
# ... def to_backend(cls,
elif backend == Backend.NCNN: ir_files: Sequence[str],
from mmdeploy.apis.ncnn import is_available as is_available_ncnn deploy_cfg: Any,
work_dir: str,
if not is_available_ncnn(): log_level: int = logging.INFO,
logging.error('ncnn support is not available.') device: str = 'cpu',
exit(-1) **kwargs) -> Sequence[str]:
return ir_files
from mmdeploy.apis.ncnn import onnx2ncnn, get_output_model_file
backend_files = []
for onnx_path in onnx_files:
create_process(
f'onnx2ncnn with {onnx_path}',
target=onnx2ncnn,
args=(onnx_path, args.work_dir),
kwargs=dict(),
ret_value=ret_value)
backend_files += get_output_model_file(onnx_path, args.work_dir)
# ...
``` ```
6. Convert the models of OpenMMLab to backends (if necessary) and inference on backend engine. If you find some incompatible operators when testing, you can try to rewrite the original model for the backend following the [rewriter tutorial](support_new_model.md) or add custom operators. 6. Convert the models of OpenMMLab to backends (if necessary) and inference on backend engine. If you find some incompatible operators when testing, you can try to rewrite the original model for the backend following the [rewriter tutorial](support_new_model.md) or add custom operators.
@ -209,23 +197,26 @@ Although the backend engines are usually implemented in C/C++, it is convenient
self.sess.run_with_iobinding(io_binding) self.sess.run_with_iobinding(io_binding)
``` ```
4. Add a default initialization method for the new wrapper in `mmdeploy/codebase/base/backend_model.py` 4. Create a backend manager class that derives from `BaseBackendManager` and implement its `build_wrapper` class method.
**Example:** **Example:**
```Python ```Python
@staticmethod @BACKEND_MANAGERS.register('onnxruntime')
def _build_wrapper(backend: Backend, class ONNXRuntimeManager(BaseBackendManager):
backend_files: Sequence[str], @classmethod
device: str, def build_wrapper(cls,
input_names: Optional[Sequence[str]] = None, backend_files: Sequence[str],
output_names: Optional[Sequence[str]] = None): device: str = 'cpu',
if backend == Backend.ONNXRUNTIME: input_names: Optional[Sequence[str]] = None,
from mmdeploy.backend.onnxruntime import ORTWrapper output_names: Optional[Sequence[str]] = None,
return ORTWrapper( deploy_cfg: Optional[Any] = None,
onnx_file=backend_files[0], **kwargs):
device=device, from .wrapper import ORTWrapper
output_names=output_names) return ORTWrapper(
onnx_file=backend_files[0],
device=device,
output_names=output_names)
``` ```
5. Add docstring and unit tests for new code :). 5. Add docstring and unit tests for new code :).
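For orientation, a minimal sketch of how such a registered manager might be looked up and used after conversion; the backend name, ONNX path and output names below are illustrative, not part of this PR:
```Python
from mmdeploy.backend.base import get_backend_manager

# Look up the manager registered under the backend name.
backend_mgr = get_backend_manager('onnxruntime')

if backend_mgr.is_available():
    # Build the runtime wrapper that the deploy pipeline will run.
    wrapper = backend_mgr.build_wrapper(
        backend_files=['work_dir/end2end.onnx'],  # placeholder ONNX path
        device='cpu',
        output_names=['output'])
```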

View File

@ -123,32 +123,20 @@ The backends in MMDeploy must support ONNX, so the backend can directly load ".onnx"
__all__ += ['onnx2ncnn', 'get_output_model_file'] __all__ += ['onnx2ncnn', 'get_output_model_file']
``` ```
Then add the conversion code to `tools/deploy.py` with these APIs as needed. Derive a class from `BaseBackendManager` and implement its `to_backend` class method.
**Example** **Example**
```Python ```Python
# tools/deploy.py @classmethod
# ... def to_backend(cls,
elif backend == Backend.NCNN: ir_files: Sequence[str],
from mmdeploy.apis.ncnn import is_available as is_available_ncnn deploy_cfg: Any,
work_dir: str,
if not is_available_ncnn(): log_level: int = logging.INFO,
logging.error('ncnn support is not available.') device: str = 'cpu',
exit(-1) **kwargs) -> Sequence[str]:
return ir_files
from mmdeploy.apis.ncnn import onnx2ncnn, get_output_model_file
backend_files = []
for onnx_path in onnx_files:
create_process(
f'mmdeploy_onnx2ncnn with {onnx_path}',
target=onnx2ncnn,
args=(onnx_path, args.work_dir),
kwargs=dict(),
ret_value=ret_value)
backend_files += get_output_model_file(onnx_path, args.work_dir)
# ...
``` ```
6. Convert the OpenMMLab models (if necessary) and run inference on the backend engine. If you find some incompatible operators during testing, you can try to rewrite the original model for the backend following the [rewriter tutorial](support_new_model.md) or add custom operators.
@ -210,22 +198,26 @@ The backends in MMDeploy must support ONNX, so the backend can directly load ".onnx"
self.sess.run_with_iobinding(io_binding) self.sess.run_with_iobinding(io_binding)
``` ```
4. Add a default initialization method for the new wrapper in `mmdeploy/codebase/base/backend_model.py` 4. Derive a class from `BaseBackendManager` and implement its `build_wrapper` class method
**Example** **Example**
```Python ```Python
@staticmethod @BACKEND_MANAGERS.register('onnxruntime')
def _build_wrapper(backend: Backend, class ONNXRuntimeManager(BaseBackendManager):
backend_files: Sequence[str], @classmethod
device: str, def build_wrapper(cls,
output_names: Optional[Sequence[str]] = None): backend_files: Sequence[str],
if backend == Backend.ONNXRUNTIME: device: str = 'cpu',
from mmdeploy.backend.onnxruntime import ORTWrapper input_names: Optional[Sequence[str]] = None,
return ORTWrapper( output_names: Optional[Sequence[str]] = None,
onnx_file=backend_files[0], deploy_cfg: Optional[Any] = None,
device=device, **kwargs):
output_names=output_names) from .wrapper import ORTWrapper
return ORTWrapper(
onnx_file=backend_files[0],
device=device,
output_names=output_names)
``` ```
5. Add docstrings and unit tests for the new backend engine code :).

View File

@ -1,19 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .calibration import create_calib_input_data
from .extract_model import extract_model
from .inference import inference_model
from .pytorch2onnx import torch2onnx
from .pytorch2torchscript import torch2torchscript
from .utils import build_task_processor, get_predefined_partition_cfg
from .visualize import visualize_model
# mmcv & mmengine dependency __all__ = [
try: 'create_calib_input_data', 'extract_model', 'inference_model',
from .calibration import create_calib_input_data 'torch2onnx', 'torch2torchscript', 'build_task_processor',
from .extract_model import extract_model 'get_predefined_partition_cfg', 'visualize_model'
from .inference import inference_model ]
from .pytorch2onnx import torch2onnx
from .pytorch2torchscript import torch2torchscript
from .utils import build_task_processor, get_predefined_partition_cfg
from .visualize import visualize_model
__all__ = [
'create_calib_input_data', 'extract_model', 'inference_model',
'torch2onnx', 'torch2torchscript', 'build_task_processor',
'get_predefined_partition_cfg', 'visualize_model'
]
except Exception:
pass

View File

@ -4,11 +4,7 @@ from typing import Optional, Union
from mmengine import Config from mmengine import Config
from mmdeploy.core import patch_model
from mmdeploy.utils import (IR, cfg_apply_marks, get_backend, get_ir_config,
load_config)
from .core import PIPELINE_MANAGER, no_mp from .core import PIPELINE_MANAGER, no_mp
from .utils import create_calib_input_data as create_calib_input_data_impl
@PIPELINE_MANAGER.register_pipeline() @PIPELINE_MANAGER.register_pipeline()
@ -34,6 +30,11 @@ def create_calib_input_data(calib_file: str,
dataset_type (str, optional): The dataset type. Defaults to 'val'. dataset_type (str, optional): The dataset type. Defaults to 'val'.
device (str, optional): Device to create dataset. Defaults to 'cpu'. device (str, optional): Device to create dataset. Defaults to 'cpu'.
""" """
from mmdeploy.core import patch_model
from mmdeploy.utils import (IR, cfg_apply_marks, get_backend,
get_ir_config, load_config)
from .utils import create_calib_input_data as create_calib_input_data_impl
with no_mp(): with no_mp():
if dataset_cfg is None: if dataset_cfg is None:
dataset_cfg = model_cfg dataset_cfg = model_cfg

View File

@ -5,7 +5,6 @@ from typing import Dict, Iterable, Optional, Union
import onnx import onnx
from .core import PIPELINE_MANAGER from .core import PIPELINE_MANAGER
from .onnx import extract_partition
@PIPELINE_MANAGER.register_pipeline() @PIPELINE_MANAGER.register_pipeline()
@ -63,5 +62,6 @@ def extract_model(model: Union[str, onnx.ModelProto],
onnx.ModelProto: The extracted model. onnx.ModelProto: The extracted model.
""" """
from .onnx import extract_partition
return extract_partition(model, start_marker, end_marker, start_name_map, return extract_partition(model, start_marker, end_marker, start_name_map,
end_name_map, dynamic_axes, save_file) end_name_map, dynamic_axes, save_file)

View File

@ -3,9 +3,6 @@ from typing import Any, Sequence, Union
import mmengine import mmengine
import numpy as np import numpy as np
import torch
from mmdeploy.utils import get_input_shape, load_config
def inference_model(model_cfg: Union[str, mmengine.Config], def inference_model(model_cfg: Union[str, mmengine.Config],
@ -37,6 +34,9 @@ def inference_model(model_cfg: Union[str, mmengine.Config],
Returns: Returns:
Any: The inference results Any: The inference results
""" """
import torch
from mmdeploy.utils import get_input_shape, load_config
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg) deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
from mmdeploy.apis.utils import build_task_processor from mmdeploy.apis.utils import build_task_processor

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.backend.ncnn import from_onnx as _from_onnx from mmdeploy.backend.ncnn import from_onnx as _from_onnx
from mmdeploy.backend.ncnn import is_available, is_custom_ops_available from mmdeploy.backend.ncnn import is_available
from ..core import PIPELINE_MANAGER from ..core import PIPELINE_MANAGER
from_onnx = PIPELINE_MANAGER.register_pipeline()(_from_onnx) from_onnx = PIPELINE_MANAGER.register_pipeline()(_from_onnx)
__all__ = ['is_available', 'is_custom_ops_available', 'from_onnx'] __all__ = ['is_available', 'from_onnx']
if is_available(): if is_available():
try: try:

View File

@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.backend.onnxruntime import is_available, is_custom_ops_available from mmdeploy.backend.onnxruntime import is_available
__all__ = ['is_available', 'is_custom_ops_available'] __all__ = ['is_available']

View File

@ -4,11 +4,7 @@ from typing import Any, Optional, Union
import mmengine import mmengine
from mmdeploy.apis.core.pipeline_manager import no_mp
from mmdeploy.utils import (Backend, get_backend, get_dynamic_axes,
get_input_shape, get_onnx_config, load_config)
from .core import PIPELINE_MANAGER from .core import PIPELINE_MANAGER
from .onnx import export
@PIPELINE_MANAGER.register_pipeline() @PIPELINE_MANAGER.register_pipeline()
@ -48,6 +44,12 @@ def torch2onnx(img: Any,
defaults to `None`. defaults to `None`.
device (str): A string specifying device type, defaults to 'cuda:0'. device (str): A string specifying device type, defaults to 'cuda:0'.
""" """
from mmdeploy.apis.core.pipeline_manager import no_mp
from mmdeploy.utils import (Backend, get_backend, get_dynamic_axes,
get_input_shape, get_onnx_config, load_config)
from .onnx import export
# load deploy_cfg if necessary # load deploy_cfg if necessary
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg) deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
mmengine.mkdir_or_exist(osp.abspath(work_dir)) mmengine.mkdir_or_exist(osp.abspath(work_dir))

View File

@ -3,11 +3,8 @@ import os.path as osp
from typing import Any, Optional, Union from typing import Any, Optional, Union
import mmengine import mmengine
import torch
from mmdeploy.apis.core.pipeline_manager import PIPELINE_MANAGER, no_mp from mmdeploy.apis.core.pipeline_manager import PIPELINE_MANAGER, no_mp
from mmdeploy.utils import get_backend, get_input_shape, load_config
from .torch_jit import trace
@PIPELINE_MANAGER.register_pipeline() @PIPELINE_MANAGER.register_pipeline()
@ -32,6 +29,11 @@ def torch2torchscript(img: Any,
defaults to `None`. defaults to `None`.
device (str): A string specifying device type, defaults to 'cuda:0'. device (str): A string specifying device type, defaults to 'cuda:0'.
""" """
import torch
from mmdeploy.utils import get_backend, get_input_shape, load_config
from .torch_jit import trace
# load deploy_cfg if necessary # load deploy_cfg if necessary
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg) deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
mmengine.mkdir_or_exist(osp.abspath(work_dir)) mmengine.mkdir_or_exist(osp.abspath(work_dir))

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.backend.tensorrt import is_available, is_custom_ops_available from mmdeploy.backend.tensorrt import is_available
from ..core import PIPELINE_MANAGER from ..core import PIPELINE_MANAGER
__all__ = ['is_available', 'is_custom_ops_available'] __all__ = ['is_available']
if is_available(): if is_available():
from mmdeploy.backend.tensorrt import from_onnx as _from_onnx from mmdeploy.backend.tensorrt import from_onnx as _from_onnx

View File

@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .calibration import create_calib_input_data from .calibration import create_calib_input_data
from .utils import build_task_processor, get_predefined_partition_cfg from .utils import (build_task_processor, get_predefined_partition_cfg,
to_backend)
__all__ = [ __all__ = [
'create_calib_input_data', 'build_task_processor', 'create_calib_input_data', 'build_task_processor',
'get_predefined_partition_cfg' 'get_predefined_partition_cfg', 'to_backend'
] ]

View File

@ -2,12 +2,9 @@
from copy import deepcopy from copy import deepcopy
from typing import Callable, Dict, Optional from typing import Callable, Dict, Optional
import h5py
import torch import torch
import tqdm
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from mmdeploy.core import RewriterContext, reset_mark_function_count
from ..core import PIPELINE_MANAGER from ..core import PIPELINE_MANAGER
@ -46,7 +43,10 @@ def create_calib_input_data(calib_file: str,
'val', defaults to 'val'. 'val', defaults to 'val'.
device (str): Specifying the device to run on, defaults to 'cpu'. device (str): Specifying the device to run on, defaults to 'cpu'.
""" """
import h5py
import tqdm
from mmdeploy.core import RewriterContext, reset_mark_function_count
backend = 'default' backend = 'default'
with h5py.File(calib_file, mode='w') as file: with h5py.File(calib_file, mode='w') as file:

View File

@ -1,10 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Any, Optional, Sequence
import mmengine import mmengine
from mmdeploy.codebase import BaseTask, get_codebase_class, import_codebase from mmdeploy.codebase import BaseTask, get_codebase_class, import_codebase
from mmdeploy.utils import (get_backend, get_codebase, get_task_type, from mmdeploy.utils import (get_backend, get_codebase, get_task_type,
parse_device_id) parse_device_id)
from mmdeploy.utils.config_utils import get_codebase_external_module from mmdeploy.utils.config_utils import get_codebase_external_module
from ..core import PIPELINE_MANAGER
def check_backend_device(deploy_cfg: mmengine.Config, device: str): def check_backend_device(deploy_cfg: mmengine.Config, device: str):
@ -66,3 +70,35 @@ def get_predefined_partition_cfg(deploy_cfg: mmengine.Config,
codebase = get_codebase_class(codebase_type) codebase = get_codebase_class(codebase_type)
task_processor_class = codebase.get_task_class(task) task_processor_class = codebase.get_task_class(task)
return task_processor_class.get_partition_cfg(partition_type) return task_processor_class.get_partition_cfg(partition_type)
@PIPELINE_MANAGER.register_pipeline()
def to_backend(backend_name: str,
ir_files: Sequence[str],
work_dir: str,
deploy_cfg: Optional[Any] = None,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
"""Convert intermediate representation to given backend.
Args:
backend_name (str): The name of the backend.
ir_files (Sequence[str]): The intermediate representation files.
work_dir (str): The work directory, backend files and logs should
be saved in this directory.
deploy_cfg (Any): The deploy config.
log_level (int, optional): The log level. Defaults to logging.INFO.
device (str, optional): The device type. Defaults to 'cpu'.
Returns:
Sequence[str]: Backend files.
"""
from mmdeploy.backend.base import get_backend_manager
backend_mgr = get_backend_manager(backend_name)
return backend_mgr.to_backend(
ir_files=ir_files,
work_dir=work_dir,
deploy_cfg=deploy_cfg,
log_level=log_level,
device=device,
**kwargs)
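With this dispatcher registered as a pipeline, a caller such as `tools/deploy.py` no longer needs per-backend branches; a hedged sketch of the call site, assuming `ir_files`, `deploy_cfg` and parsed `args` already exist:
```Python
from mmdeploy.apis.utils import to_backend
from mmdeploy.utils import get_backend

# The backend name is taken from the deploy config, e.g. Backend.NCNN -> 'ncnn'.
backend = get_backend(deploy_cfg).value

backend_files = to_backend(
    backend,
    ir_files,              # ONNX/TorchScript files from the previous stage
    work_dir=args.work_dir,
    deploy_cfg=deploy_cfg,
    device=args.device)
```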

View File

@ -5,13 +5,12 @@ import mmengine
import numpy as np import numpy as np
import torch import torch
from mmdeploy.codebase import BaseTask
from mmdeploy.utils import Backend, get_backend, get_input_shape, load_config from mmdeploy.utils import Backend, get_backend, get_input_shape, load_config
def visualize_model(model_cfg: Union[str, mmengine.Config], def visualize_model(model_cfg: Union[str, mmengine.Config],
deploy_cfg: Union[str, mmengine.Config], deploy_cfg: Union[str, mmengine.Config],
model: Union[str, Sequence[str], BaseTask], model: Union[str, Sequence[str]],
img: Union[str, np.ndarray], img: Union[str, np.ndarray],
device: str, device: str,
backend: Optional[Backend] = None, backend: Optional[Backend] = None,

View File

@ -1,19 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import importlib from .backend_manager import AscendManager
from .utils import update_sdk_pipeline from .utils import update_sdk_pipeline
_BackendManager = AscendManager
def is_available(): is_available = _BackendManager.is_available
"""Check whether acl is installed. build_wrapper = _BackendManager.build_wrapper
Returns: __all__ = ['update_sdk_pipeline', 'AscendManager']
bool: True if acl package is installed.
"""
return importlib.util.find_spec('acl') is not None
__all__ = ['update_sdk_pipeline']
if is_available(): if is_available():
from .wrapper import AscendWrapper, Error from .wrapper import AscendWrapper, Error

View File

@ -0,0 +1,91 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
from typing import Any, Optional, Sequence
from ..base import BACKEND_MANAGERS, BaseBackendManager
@BACKEND_MANAGERS.register('ascend')
class AscendManager(BaseBackendManager):
@classmethod
def build_wrapper(cls,
backend_files: Sequence[str],
device: str = 'cpu',
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
deploy_cfg: Optional[Any] = None,
**kwargs):
"""Build the wrapper for the backend model.
Args:
backend_files (Sequence[str]): Backend files.
device (str, optional): The device info. Defaults to 'cpu'.
input_names (Optional[Sequence[str]], optional): input names.
Defaults to None.
output_names (Optional[Sequence[str]], optional): output names.
Defaults to None.
deploy_cfg (Optional[Any], optional): The deploy config. Defaults
to None.
"""
from .wrapper import AscendWrapper
return AscendWrapper(model=backend_files[0], device=device)
@classmethod
def is_available(cls, with_custom_ops: bool = False) -> bool:
"""Check whether backend is installed.
Args:
with_custom_ops (bool): check custom ops exists.
Returns:
bool: True if backend package is installed.
"""
import importlib
return importlib.util.find_spec('acl') is not None
@classmethod
def get_version(cls) -> str:
"""Get the version of the backend."""
if not cls.is_available():
return 'None'
else:
import pkg_resources
try:
return pkg_resources.get_distribution('acl').version
except Exception:
return 'None'
@classmethod
def to_backend(cls,
ir_files: Sequence[str],
work_dir: str,
deploy_cfg: Any,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
"""Convert intermediate representation to given backend.
Args:
ir_files (Sequence[str]): The intermediate representation files.
work_dir (str): The work directory, backend files and logs should
be saved in this directory.
deploy_cfg (Any): The deploy config.
log_level (int, optional): The log level. Defaults to logging.INFO.
device (str, optional): The device type. Defaults to 'cpu'.
Returns:
Sequence[str]: Backend files.
"""
from mmdeploy.utils import get_model_inputs
from .onnx2ascend import from_onnx
model_inputs = get_model_inputs(deploy_cfg)
om_files = []
for model_id, onnx_path in enumerate(ir_files):
om_path = osp.splitext(onnx_path)[0] + '.om'
from_onnx(onnx_path, work_dir, model_inputs[model_id])
om_files.append(om_path)
backend_files = om_files
return backend_files

View File

@ -1,9 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .backend_manager import (BACKEND_MANAGERS, BaseBackendManager,
get_backend_manager)
from .backend_wrapper_registry import (BACKEND_WRAPPER, get_backend_file_count, from .backend_wrapper_registry import (BACKEND_WRAPPER, get_backend_file_count,
get_backend_wrapper_class) get_backend_wrapper_class)
from .base_wrapper import BaseWrapper from .base_wrapper import BaseWrapper
__all__ = [ __all__ = [
'BACKEND_MANAGERS', 'BaseBackendManager', 'get_backend_manager',
'BaseWrapper', 'BACKEND_WRAPPER', 'get_backend_wrapper_class', 'BaseWrapper', 'BACKEND_WRAPPER', 'get_backend_wrapper_class',
'get_backend_file_count' 'get_backend_file_count'
] ]

View File

@ -0,0 +1,173 @@
# Copyright (c) OpenMMLab. All rights reserved.
import importlib
import logging
from abc import ABCMeta
from typing import Any, Callable, Optional, Sequence
class BaseBackendManager(metaclass=ABCMeta):
"""Abstract interface of backend manager."""
@classmethod
def build_wrapper(cls,
backend_files: Sequence[str],
device: str = 'cpu',
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
deploy_cfg: Optional[Any] = None,
**kwargs):
"""Build the wrapper for the backend model.
Args:
backend_files (Sequence[str]): Backend files.
device (str, optional): The device info. Defaults to 'cpu'.
input_names (Optional[Sequence[str]], optional): input names.
Defaults to None.
output_names (Optional[Sequence[str]], optional): output names.
Defaults to None.
deploy_cfg (Optional[Any], optional): The deploy config. Defaults
to None.
"""
raise NotImplementedError(
f'build_wrapper has not been implemented for `{cls.__name__}`')
@classmethod
def is_available(cls, with_custom_ops: bool = False) -> bool:
"""Check whether backend is installed.
Args:
with_custom_ops (bool): check custom ops exists.
Returns:
bool: True if backend package is installed.
"""
raise NotImplementedError(
f'is_available has not been implemented for "{cls.__name__}"')
@classmethod
def get_version(cls) -> str:
"""Get the version of the backend."""
raise NotImplementedError(
f'get_version has not been implemented for "{cls.__name__}"')
@classmethod
def check_env(cls, log_callback: Callable = lambda _: _) -> str:
"""Check current environment.
Returns:
str: Info about the environment.
"""
try:
available = cls.is_available()
if available:
try:
backend_version = cls.get_version()
except NotImplementedError:
backend_version = 'Unknown'
else:
backend_version = 'None'
info = f'{cls.backend_name}:\t{backend_version}'
except Exception:
info = f'{cls.backend_name}:\tCheckFailed'
log_callback(info)
return info
@classmethod
def to_backend(cls,
ir_files: Sequence[str],
work_dir: str,
deploy_cfg: Any,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
"""Convert intermediate representation to given backend.
Args:
ir_files (Sequence[str]): The intermediate representation files.
work_dir (str): The work directory, backend files and logs should
be saved in this directory.
deploy_cfg (Any): The deploy config.
log_level (int, optional): The log level. Defaults to logging.INFO.
device (str, optional): The device type. Defaults to 'cpu'.
Returns:
Sequence[str]: Backend files.
"""
raise NotImplementedError(
f'to_backend has not been implemented for `{cls.__name__}`')
class BackendManagerRegistry:
"""backend manager registry."""
def __init__(self):
self._module_dict = {}
def register(self, name: str, enum_name: Optional[str] = None):
"""register backend manager.
Args:
name (str): name of the backend
enum_name (Optional[str], optional): enum name of the backend.
if not given, the upper case of name would be used.
"""
from mmdeploy.utils import get_root_logger
logger = get_root_logger()
if enum_name is None:
enum_name = name.upper()
def wrap_manager(cls):
from mmdeploy.utils import Backend
if not hasattr(Backend, enum_name):
from aenum import extend_enum
extend_enum(Backend, enum_name, name)
logger.info(f'Registering new backend: {enum_name} = {name}.')
if name in self._module_dict:
logger.info(
f'Backend manager of `{name}` has already been registered.'
)
self._module_dict[name] = cls
cls.backend_name = name
return cls
return wrap_manager
def find(self, name: str) -> BaseBackendManager:
"""Find the backend manager with name.
Args:
name (str): backend name.
Returns:
BaseBackendManager: backend manager of the given backend.
"""
# try import backend if backend is in `mmdeploy.backend`
try:
importlib.import_module('mmdeploy.backend.' + name)
except Exception:
pass
return self._module_dict.get(name, None)
BACKEND_MANAGERS = BackendManagerRegistry()
def get_backend_manager(name: str) -> BaseBackendManager:
"""Get backend manager.
Args:
name (str): name of the backend.
Returns:
BaseBackendManager: The backend manager of given name
"""
from enum import Enum
if isinstance(name, Enum):
name = name.value
return BACKEND_MANAGERS.find(name)
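As a hedged illustration of the registry mechanics (the backend name `mybackend` and its wrapper are hypothetical, not part of this PR), registering a manager is one decorator, and the registry extends the `Backend` enum on the fly for unknown names:
```Python
from mmdeploy.backend.base import (BACKEND_MANAGERS, BaseBackendManager,
                                   get_backend_manager)


@BACKEND_MANAGERS.register('mybackend')  # adds Backend.MYBACKEND via aenum if missing
class MyBackendManager(BaseBackendManager):

    @classmethod
    def is_available(cls, with_custom_ops: bool = False) -> bool:
        import importlib
        return importlib.util.find_spec('mybackend') is not None

    @classmethod
    def build_wrapper(cls, backend_files, device='cpu', input_names=None,
                      output_names=None, deploy_cfg=None, **kwargs):
        # A real implementation would return the backend's wrapper here.
        raise NotImplementedError('MyBackendWrapper is hypothetical')


# The manager can then be retrieved by name anywhere in the pipeline.
assert get_backend_manager('mybackend') is MyBackendManager
```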

View File

@ -1,18 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .backend_manager import CoreMLManager
import importlib _BackendManager = CoreMLManager
is_available = _BackendManager.is_available
build_wrapper = _BackendManager.build_wrapper
def is_available(): __all__ = ['CoreMLManager']
"""Check whether coremltools is installed.
Returns:
bool: True if coremltools package is installed.
"""
return importlib.util.find_spec('coremltools') is not None
__all__ = []
if is_available(): if is_available():
from . import ops from . import ops

View File

@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
from typing import Any, Optional, Sequence
from ..base import BACKEND_MANAGERS, BaseBackendManager
@BACKEND_MANAGERS.register('coreml')
class CoreMLManager(BaseBackendManager):
@classmethod
def build_wrapper(cls,
backend_files: Sequence[str],
device: str = 'cpu',
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
deploy_cfg: Optional[Any] = None,
**kwargs):
"""Build the wrapper for the backend model.
Args:
backend_files (Sequence[str]): Backend files.
device (str, optional): The device info. Defaults to 'cpu'.
input_names (Optional[Sequence[str]], optional): input names.
Defaults to None.
output_names (Optional[Sequence[str]], optional): output names.
Defaults to None.
deploy_cfg (Optional[Any], optional): The deploy config. Defaults
to None.
"""
from .wrapper import CoreMLWrapper
return CoreMLWrapper(model_file=backend_files[0])
@classmethod
def is_available(cls, with_custom_ops: bool = False) -> bool:
"""Check whether backend is installed.
Args:
with_custom_ops (bool): check custom ops exists.
Returns:
bool: True if backend package is installed.
"""
import importlib
return importlib.util.find_spec('coremltools') is not None
@classmethod
def get_version(cls) -> str:
"""Get the version of the backend."""
if not cls.is_available():
return 'None'
else:
import pkg_resources
try:
return pkg_resources.get_distribution('coremltools').version
except Exception:
return 'None'
@classmethod
def to_backend(cls,
ir_files: Sequence[str],
work_dir: str,
deploy_cfg: Any,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
"""Convert intermediate representation to given backend.
Args:
ir_files (Sequence[str]): The intermediate representation files.
work_dir (str): The work directory, backend files and logs should
be saved in this directory.
deploy_cfg (Any): The deploy config.
log_level (int, optional): The log level. Defaults to logging.INFO.
device (str, optional): The device type. Defaults to 'cpu'.
Returns:
Sequence[str]: Backend files.
"""
from .torchscript2coreml import from_torchscript
coreml_files = []
for model_id, torchscript_path in enumerate(ir_files):
torchscript_name = osp.splitext(osp.split(torchscript_path)[1])[0]
output_file_prefix = osp.join(work_dir, torchscript_name)
from_torchscript(model_id, torchscript_path, output_file_prefix,
deploy_cfg, coreml_files)
return coreml_files

View File

@ -1,38 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import importlib from .backend_manager import NCNNManager
import os.path as osp
from .init_plugins import get_onnx2ncnn_path, get_ops_path
from .onnx2ncnn import from_onnx from .onnx2ncnn import from_onnx
_BackendManager = NCNNManager
def is_available(): is_available = _BackendManager.is_available
"""Check whether ncnn and onnx2ncnn tool are installed. build_wrapper = _BackendManager.build_wrapper
Returns: __all__ = ['NCNNManager', 'from_onnx']
bool: True if ncnn and onnx2ncnn tool are installed.
"""
has_pyncnn = importlib.util.find_spec('ncnn') is not None
onnx2ncnn = get_onnx2ncnn_path()
return has_pyncnn and osp.exists(onnx2ncnn)
def is_custom_ops_available():
"""Check whether ncnn extension and custom ops are installed.
Returns:
bool: True if ncnn extension and custom ops are compiled.
"""
has_pyncnn_ext = importlib.util.find_spec(
'mmdeploy.backend.ncnn.ncnn_ext') is not None
ncnn_ops_path = get_ops_path()
return has_pyncnn_ext and osp.exists(ncnn_ops_path)
__all__ = ['from_onnx']
if is_available(): if is_available():
try: try:

View File

@ -0,0 +1,145 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
import sys
from typing import Any, Callable, Optional, Sequence
from mmdeploy.utils import get_backend_config, get_root_logger
from ..base import BACKEND_MANAGERS, BaseBackendManager
@BACKEND_MANAGERS.register('ncnn')
class NCNNManager(BaseBackendManager):
@classmethod
def build_wrapper(cls,
backend_files: Sequence[str],
device: str = 'cpu',
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
deploy_cfg: Optional[Any] = None,
**kwargs):
"""Build the wrapper for the backend model.
Args:
backend_files (Sequence[str]): Backend files.
device (str, optional): The device info. Defaults to 'cpu'.
input_names (Optional[Sequence[str]], optional): input names.
Defaults to None.
output_names (Optional[Sequence[str]], optional): output names.
Defaults to None.
deploy_cfg (Optional[Any], optional): The deploy config. Defaults
to None.
"""
from .wrapper import NCNNWrapper
# For unit tests, deploy_cfg will not be passed into the build_wrapper
# function.
if deploy_cfg:
backend_config = get_backend_config(deploy_cfg)
use_vulkan = backend_config.get('use_vulkan', False)
else:
use_vulkan = False
return NCNNWrapper(
param_file=backend_files[0],
bin_file=backend_files[1],
output_names=output_names,
use_vulkan=use_vulkan)
@classmethod
def is_available(cls, with_custom_ops: bool = False) -> bool:
"""Check whether backend is installed.
Args:
with_custom_ops (bool): check custom ops exists.
Returns:
bool: True if backend package is installed.
"""
import importlib
from .init_plugins import get_onnx2ncnn_path, get_ops_path
has_pyncnn = importlib.util.find_spec('ncnn') is not None
onnx2ncnn = get_onnx2ncnn_path()
ret = has_pyncnn and (onnx2ncnn is not None)
if ret and with_custom_ops:
has_pyncnn_ext = importlib.util.find_spec(
'mmdeploy.backend.ncnn.ncnn_ext') is not None
op_path = get_ops_path()
custom_ops_exist = osp.exists(op_path)
ret = ret and has_pyncnn_ext and custom_ops_exist
return ret
@classmethod
def get_version(cls) -> str:
"""Get the version of the backend."""
if not cls.is_available():
return 'None'
else:
import pkg_resources
try:
return pkg_resources.get_distribution('ncnn').version
except Exception:
return 'None'
@classmethod
def check_env(cls, log_callback: Callable = lambda _: _) -> str:
"""Check current environment.
Returns:
str: Info about the environment.
"""
info = super().check_env(log_callback=log_callback)
available = cls.is_available()
ops_available = cls.is_available(with_custom_ops=True)
ops_available = 'Available' if ops_available else 'NotAvailable'
if available:
ops_info = f'ncnn custom ops:\t{ops_available}'
log_callback(ops_info)
info = f'{info}\n{ops_info}'
return info
@classmethod
def to_backend(cls,
ir_files: Sequence[str],
work_dir: str,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
"""Convert intermediate representation to given backend.
Args:
ir_files (Sequence[str]): The intermediate representation files.
work_dir (str): The work directory, backend files and logs should
be saved in this directory.
log_level (int, optional): The log level. Defaults to logging.INFO.
device (str, optional): The device type. Defaults to 'cpu'.
Returns:
Sequence[str]: Backend files.
"""
logger = get_root_logger()
from . import is_available
if not is_available():
logger.error('ncnn support is not available, please make sure:\n'
'1) `mmdeploy_onnx2ncnn` exists in `PATH`\n'
'2) `import ncnn` succeeds in Python')
sys.exit(1)
from mmdeploy.apis.ncnn import get_output_model_file
from .onnx2ncnn import from_onnx
backend_files = []
for onnx_path in ir_files:
model_param_path, model_bin_path = get_output_model_file(
onnx_path, work_dir)
onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
from_onnx(onnx_path, osp.join(work_dir, onnx_name))
backend_files += [model_param_path, model_bin_path]
return backend_files
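The `use_vulkan` flag read in `build_wrapper` above is expected to live in the deploy config's backend section; a hedged fragment, with all other config fields omitted:
```Python
# Fragment of a deploy config consumed by NCNNManager.build_wrapper;
# onnx_config, codebase_config and the remaining fields are omitted here.
backend_config = dict(type='ncnn', use_vulkan=True)
```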

View File

@ -1,34 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import importlib from .backend_manager import ONNXRuntimeManager
import os.path as osp
from .init_plugins import get_ops_path _BackendManager = ONNXRuntimeManager
is_available = _BackendManager.is_available
build_wrapper = _BackendManager.build_wrapper
def is_available(): __all__ = ['ONNXRuntimeManager']
"""Check whether ONNX Runtime package is installed.
Returns:
bool: True if ONNX Runtime package is installed.
"""
return importlib.util.find_spec('onnxruntime') is not None
def is_custom_ops_available():
"""Check whether ONNX Runtime custom ops are installed.
Returns:
bool: True if ONNX Runtime custom ops are compiled.
"""
onnxruntime_op_path = get_ops_path()
return osp.exists(onnxruntime_op_path)
if is_available(): if is_available():
try: try:
# import wrapper if pytorch is available # import wrapper if pytorch is available
from .wrapper import ORTWrapper from .wrapper import ORTWrapper
__all__ = ['ORTWrapper'] __all__ += ['ORTWrapper']
except Exception: except Exception:
pass pass

View File

@ -0,0 +1,142 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
from typing import Any, Callable, Optional, Sequence
from ..base import BACKEND_MANAGERS, BaseBackendManager
@BACKEND_MANAGERS.register('onnxruntime')
class ONNXRuntimeManager(BaseBackendManager):
@classmethod
def build_wrapper(cls,
backend_files: Sequence[str],
device: str = 'cpu',
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
deploy_cfg: Optional[Any] = None,
**kwargs):
"""Build the wrapper for the backend model.
Args:
backend_files (Sequence[str]): Backend files.
device (str, optional): The device info. Defaults to 'cpu'.
input_names (Optional[Sequence[str]], optional): input names.
Defaults to None.
output_names (Optional[Sequence[str]], optional): output names.
Defaults to None.
deploy_cfg (Optional[Any], optional): The deploy config. Defaults
to None.
"""
from .wrapper import ORTWrapper
return ORTWrapper(
onnx_file=backend_files[0],
device=device,
output_names=output_names)
@classmethod
def is_available(cls, with_custom_ops: bool = False) -> bool:
"""Check whether backend is installed.
Args:
with_custom_ops (bool): check custom ops exists.
Returns:
bool: True if backend package is installed.
"""
import importlib
ret = importlib.util.find_spec('onnxruntime') is not None
if ret and with_custom_ops:
from .init_plugins import get_ops_path
ops_path = get_ops_path()
custom_ops_exist = osp.exists(ops_path)
ret = ret and custom_ops_exist
return ret
@classmethod
def get_version(cls) -> str:
"""Get the version of the backend."""
if not cls.is_available():
return 'None'
else:
import pkg_resources
try:
ort_version = pkg_resources.get_distribution(
'onnxruntime').version
except Exception:
ort_version = 'None'
try:
ort_gpu_version = pkg_resources.get_distribution(
'onnxruntime-gpu').version
except Exception:
ort_gpu_version = 'None'
if ort_gpu_version != 'None':
return ort_gpu_version
else:
return ort_version
@classmethod
def check_env(cls, log_callback: Callable = lambda _: _) -> str:
"""Check current environment.
Returns:
str: Info about the environment.
"""
import pkg_resources
try:
if cls.is_available():
ops_available = cls.is_available(with_custom_ops=True)
ops_available = 'Available' \
if ops_available else 'NotAvailable'
try:
ort_version = pkg_resources.get_distribution(
'onnxruntime').version
except Exception:
ort_version = 'None'
try:
ort_gpu_version = pkg_resources.get_distribution(
'onnxruntime-gpu').version
except Exception:
ort_gpu_version = 'None'
ort_info = f'ONNXRuntime:\t{ort_version}'
log_callback(ort_info)
ort_gpu_info = f'ONNXRuntime-gpu:\t{ort_gpu_version}'
log_callback(ort_gpu_info)
ort_ops_info = f'ONNXRuntime custom ops:\t{ops_available}'
log_callback(ort_ops_info)
info = f'{ort_info}\n{ort_gpu_info}\n{ort_ops_info}'
else:
info = 'ONNXRuntime:\tNone'
log_callback(info)
except Exception:
info = f'{cls.backend_name}:\tCheckFailed'
log_callback(info)
return info
@classmethod
def to_backend(cls,
ir_files: Sequence[str],
work_dir: str,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
"""Convert intermediate representation to given backend.
Args:
ir_files (Sequence[str]): The intermediate representation files.
work_dir (str): The work directory, backend files and logs should
be saved in this directory.
log_level (int, optional): The log level. Defaults to logging.INFO.
device (str, optional): The device type. Defaults to 'cpu'.
Returns:
Sequence[str]: Backend files.
"""
return ir_files
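For reference, a short sketch of how the environment check this manager provides might be invoked; the printed values depend on the local installation:
```Python
from mmdeploy.backend.onnxruntime import ONNXRuntimeManager
from mmdeploy.utils import get_root_logger

logger = get_root_logger()
# Logs lines such as "ONNXRuntime:\t1.8.1", "ONNXRuntime-gpu:\tNone" and
# "ONNXRuntime custom ops:\tAvailable", then returns the same text.
env_info = ONNXRuntimeManager.check_env(log_callback=logger.info)
```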

View File

@ -44,11 +44,11 @@ class ORTWrapper(BaseWrapper):
logger = get_root_logger() logger = get_root_logger()
if osp.exists(ort_custom_op_path): if osp.exists(ort_custom_op_path):
session_options.register_custom_ops_library(ort_custom_op_path) session_options.register_custom_ops_library(ort_custom_op_path)
logger.info(f'Successfully loaded onnxruntime custom ops from \ logger.info('Successfully loaded onnxruntime custom ops from '
{ort_custom_op_path}') f'{ort_custom_op_path}')
else: else:
logger.warning(f'The library of onnxruntime custom ops does \ logger.warning('The library of onnxruntime custom ops does '
not exist: {ort_custom_op_path}') f'not exist: {ort_custom_op_path}')
device_id = parse_device_id(device) device_id = parse_device_id(device)
providers = ['CPUExecutionProvider'] \ providers = ['CPUExecutionProvider'] \
if device == 'cpu' else \ if device == 'cpu' else \

View File

@ -1,20 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import importlib from .backend_manager import OpenVINOManager
_BackendManager = OpenVINOManager
def is_available() -> bool: is_available = _BackendManager.is_available
"""Checking if OpenVINO is installed. build_wrapper = _BackendManager.build_wrapper
Returns:
bool: True if OpenVINO is installed.
"""
return importlib.util.find_spec('openvino') is not None
__all__ = ['OpenVINOManager']
if is_available(): if is_available():
from .onnx2openvino import get_output_model_file from .onnx2openvino import get_output_model_file
from .utils import ModelOptimizerOptions from .utils import ModelOptimizerOptions
from .wrapper import OpenVINOWrapper from .wrapper import OpenVINOWrapper
__all__ = [ __all__ += [
'OpenVINOWrapper', 'get_output_model_file', 'ModelOptimizerOptions' 'OpenVINOWrapper', 'get_output_model_file', 'ModelOptimizerOptions'
] ]

View File

@ -0,0 +1,101 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Any, Optional, Sequence
from ..base import BACKEND_MANAGERS, BaseBackendManager
@BACKEND_MANAGERS.register('openvino')
class OpenVINOManager(BaseBackendManager):
@classmethod
def build_wrapper(cls,
backend_files: Sequence[str],
device: str = 'cpu',
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
deploy_cfg: Optional[Any] = None,
**kwargs):
"""Build the wrapper for the backend model.
Args:
backend_files (Sequence[str]): Backend files.
device (str, optional): The device info. Defaults to 'cpu'.
input_names (Optional[Sequence[str]], optional): input names.
Defaults to None.
output_names (Optional[Sequence[str]], optional): output names.
Defaults to None.
deploy_cfg (Optional[Any], optional): The deploy config. Defaults
to None.
"""
from .wrapper import OpenVINOWrapper
return OpenVINOWrapper(
ir_model_file=backend_files[0], output_names=output_names)
@classmethod
def is_available(cls, with_custom_ops: bool = False) -> bool:
"""Check whether backend is installed.
Args:
with_custom_ops (bool): check custom ops exists.
Returns:
bool: True if backend package is installed.
"""
import importlib
ret = importlib.util.find_spec('openvino') is not None
return ret
@classmethod
def get_version(cls) -> str:
"""Get the version of the backend."""
if not cls.is_available():
return 'None'
else:
import pkg_resources
try:
return pkg_resources.get_distribution('openvino').version
except Exception:
return 'None'
@classmethod
def to_backend(cls,
ir_files: Sequence[str],
work_dir: str,
deploy_cfg: Any,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
"""Convert intermediate representation to given backend.
Args:
ir_files (Sequence[str]): The intermediate representation files.
work_dir (str): The work directory, backend files and logs should
be saved in this directory.
deploy_cfg (Any): The deploy config.
log_level (int, optional): The log level. Defaults to logging.INFO.
device (str, optional): The device type. Defaults to 'cpu'.
Returns:
Sequence[str]: Backend files.
"""
from . import is_available
assert is_available(), \
'OpenVINO is not available, please install OpenVINO first.'
from mmdeploy.apis.openvino import (get_input_info_from_cfg,
get_mo_options_from_cfg,
get_output_model_file)
from mmdeploy.utils import get_ir_config
from .onnx2openvino import from_onnx
openvino_files = []
for onnx_path in ir_files:
model_xml_path = get_output_model_file(onnx_path, work_dir)
input_info = get_input_info_from_cfg(deploy_cfg)
output_names = get_ir_config(deploy_cfg).output_names
mo_options = get_mo_options_from_cfg(deploy_cfg)
from_onnx(onnx_path, work_dir, input_info, output_names,
mo_options)
openvino_files.append(model_xml_path)
return openvino_files

View File

@ -1,17 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import importlib from .backend_manager import PPLNNManager
_BackendManager = PPLNNManager
is_available = _BackendManager.is_available
build_wrapper = _BackendManager.build_wrapper
def is_available(): __all__ = ['PPLNNManager']
"""Check whether pplnn is installed.
Returns:
bool: True if pplnn package is installed.
"""
return importlib.util.find_spec('pyppl') is not None
__all__ = []
if is_available(): if is_available():
from .utils import register_engines from .utils import register_engines

View File

@ -0,0 +1,108 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
from typing import Any, Optional, Sequence
from ..base import BACKEND_MANAGERS, BaseBackendManager
@BACKEND_MANAGERS.register('pplnn')
class PPLNNManager(BaseBackendManager):
@classmethod
def build_wrapper(cls,
backend_files: Sequence[str],
device: str = 'cpu',
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
deploy_cfg: Optional[Any] = None,
**kwargs):
"""Build the wrapper for the backend model.
Args:
backend_files (Sequence[str]): Backend files.
device (str, optional): The device info. Defaults to 'cpu'.
input_names (Optional[Sequence[str]], optional): input names.
Defaults to None.
output_names (Optional[Sequence[str]], optional): output names.
Defaults to None.
deploy_cfg (Optional[Any], optional): The deploy config. Defaults
to None.
"""
from .wrapper import PPLNNWrapper
return PPLNNWrapper(
onnx_file=backend_files[0],
algo_file=backend_files[1] if len(backend_files) > 1 else None,
device=device,
output_names=output_names)
@classmethod
def is_available(cls, with_custom_ops: bool = False) -> bool:
"""Check whether backend is installed.
Args:
with_custom_ops (bool): check custom ops exists.
Returns:
bool: True if backend package is installed.
"""
import importlib
ret = importlib.util.find_spec('pyppl') is not None
return ret
@classmethod
def get_version(cls) -> str:
"""Get the version of the backend."""
if not cls.is_available():
return 'None'
else:
import pkg_resources
try:
return pkg_resources.get_distribution('pyppl').version
except Exception:
return 'None'
@classmethod
def to_backend(cls,
ir_files: Sequence[str],
work_dir: str,
deploy_cfg: Any,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
"""Convert intermediate representation to given backend.
Args:
ir_files (Sequence[str]): The intermediate representation files.
work_dir (str): The work directory, backend files and logs should
be saved in this directory.
deploy_cfg (Any): The deploy config.
log_level (int, optional): The log level. Defaults to logging.INFO.
device (str, optional): The device type. Defaults to 'cpu'.
Returns:
Sequence[str]: Backend files.
"""
from mmdeploy.utils import get_model_inputs
from . import is_available
from .onnx2pplnn import from_onnx
assert is_available(), \
'PPLNN is not available, please install PPLNN first.'
pplnn_files = []
for onnx_path in ir_files:
algo_file = onnx_path.replace('.onnx', '.json')
model_inputs = get_model_inputs(deploy_cfg)
assert 'opt_shape' in model_inputs, 'Expect opt_shape ' \
'in deploy config for PPLNN'
# PPLNN accepts only 1 input shape for optimization,
# may get changed in the future
input_shapes = [model_inputs.opt_shape]
algo_prefix = osp.splitext(algo_file)[0]
from_onnx(
onnx_path,
algo_prefix,
device=device,
input_shapes=input_shapes)
pplnn_files += [onnx_path, algo_file]
return pplnn_files

View File

@ -1,17 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import importlib from .backend_manager import RKNNManager
_BackendManager = RKNNManager
is_available = _BackendManager.is_available
build_wrapper = _BackendManager.build_wrapper
def is_available(): __all__ = ['RKNNManager']
"""Check whether rknn is installed.
Returns:
bool: True if rknn package is installed.
"""
return importlib.util.find_spec('rknn') is not None
__all__ = []
if is_available(): if is_available():
from .wrapper import RKNNWrapper from .wrapper import RKNNWrapper

View File

@ -0,0 +1,155 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
from typing import Any, Callable, Optional, Sequence
from mmdeploy.utils import get_common_config
from ..base import BACKEND_MANAGERS, BaseBackendManager
@BACKEND_MANAGERS.register('rknn')
class RKNNManager(BaseBackendManager):
@classmethod
def build_wrapper(cls,
backend_files: Sequence[str],
device: str = 'cpu',
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
deploy_cfg: Optional[Any] = None,
**kwargs):
"""Build the wrapper for the backend model.
Args:
backend_files (Sequence[str]): Backend files.
device (str, optional): The device info. Defaults to 'cpu'.
input_names (Optional[Sequence[str]], optional): input names.
Defaults to None.
output_names (Optional[Sequence[str]], optional): output names.
Defaults to None.
deploy_cfg (Optional[Any], optional): The deploy config. Defaults
to None.
"""
from .wrapper import RKNNWrapper
common_config = get_common_config(deploy_cfg)
return RKNNWrapper(
model=backend_files[0],
common_config=common_config,
input_names=input_names,
output_names=output_names)
@classmethod
def is_available(cls, with_custom_ops: bool = False) -> bool:
"""Check whether backend is installed.
Args:
with_custom_ops (bool): check custom ops exists.
Returns:
bool: True if backend package is installed.
"""
import importlib
ret = None
try:
ret = importlib.util.find_spec('rknn-toolkit2') is not None
except Exception:
pass
if not ret:
try:
ret = importlib.util.find_spec('rknn-toolkit') is not None
except Exception:
pass
return ret
@classmethod
def get_version(cls) -> str:
"""Get the version of the backend."""
if not cls.is_available():
return 'None'
else:
import pkg_resources
rknn_version = None
rknn2_version = None
try:
rknn_version = pkg_resources.get_distribution(
'rknn-toolkit').version
except Exception:
pass
try:
rknn2_version = pkg_resources.get_distribution(
'rknn-toolkit2').version
except Exception:
pass
if rknn2_version is not None:
return rknn2_version
elif rknn_version is not None:
return rknn_version
return 'None'
@classmethod
def check_env(cls, log_callback: Callable = lambda _: _) -> str:
"""Check current environment.
Returns:
str: Info about the environment.
"""
import pkg_resources
try:
rknn_version = 'None'
rknn2_version = 'None'
try:
rknn_version = pkg_resources.get_distribution(
'rknn-toolkit').version
except Exception:
pass
try:
rknn2_version = pkg_resources.get_distribution(
'rknn-toolkit2').version
except Exception:
pass
rknn_info = f'rknn-toolkit:\t{rknn_version}'
rknn2_info = f'rknn2-toolkit:\t{rknn2_version}'
log_callback(rknn_info)
log_callback(rknn2_info)
info = '\n'.join([rknn_info, rknn2_info])
except Exception:
info = f'{cls.backend_name}:\tCheckFailed'
log_callback(info)
return info
@classmethod
def to_backend(cls,
ir_files: Sequence[str],
work_dir: str,
deploy_cfg: Any,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
"""Convert intermediate representation to given backend.
Args:
ir_files (Sequence[str]): The intermediate representation files.
work_dir (str): The work directory, backend files and logs should
be saved in this directory.
deploy_cfg (Any): The deploy config.
log_level (int, optional): The log level. Defaults to logging.INFO.
device (str, optional): The device type. Defaults to 'cpu'.
Returns:
Sequence[str]: Backend files.
"""
from . import is_available
assert is_available(
), 'RKNN is not available, please install RKNN first.'
from .onnx2rknn import onnx2rknn
backend_files = []
for model_id, onnx_path in zip(range(len(ir_files)), ir_files):
pre_fix_name = osp.splitext(osp.split(onnx_path)[1])[0]
output_path = osp.join(work_dir, pre_fix_name + '.rknn')
onnx2rknn(onnx_path, output_path, deploy_cfg)
backend_files.append(output_path)
return backend_files

View File

@ -25,7 +25,7 @@ def rknn_package_info():
def onnx2rknn(onnx_model: str, def onnx2rknn(onnx_model: str,
output_path: str, output_path: str,
deploy_cfg: Union[str, mmengine.Config], deploy_cfg: Union[str, mmengine.Config],
model_cfg: Union[str, mmengine.Config], model_cfg: Optional[Union[str, mmengine.Config]] = None,
dataset_file: Optional[str] = None, dataset_file: Optional[str] = None,
**kwargs): **kwargs):
"""Convert ONNX to RKNN. """Convert ONNX to RKNN.
@ -55,7 +55,7 @@ def onnx2rknn(onnx_model: str,
input_size_list = get_backend_config(deploy_cfg).get( input_size_list = get_backend_config(deploy_cfg).get(
'input_size_list', None) 'input_size_list', None)
# update norm value # update norm value
if get_rknn_quantization(deploy_cfg) is True: if get_rknn_quantization(deploy_cfg) is True and model_cfg is not None:
transform = get_normalization(model_cfg) transform = get_normalization(model_cfg)
common_params.update( common_params.update(
dict( dict(

View File

@ -1,37 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import importlib from .backend_manager import SDKManager
import os
import sys
from mmdeploy.utils import get_file_path _BackendManager = SDKManager
is_available = _BackendManager.is_available
_is_available = False build_wrapper = _BackendManager.build_wrapper
module_name = 'mmdeploy_python'
candidates = [
f'../../../build/lib/{module_name}.*.so',
f'../../../build/bin/*/{module_name}.*.pyd'
]
lib_path = get_file_path(os.path.dirname(__file__), candidates)
if lib_path:
lib_dir = os.path.dirname(lib_path)
sys.path.append(lib_dir)
if importlib.util.find_spec(module_name) is not None:
_is_available = True
def is_available() -> bool:
return _is_available
__all__ = ['SDKManager']
if is_available(): if is_available():
try: try:
from .wrapper import SDKWrapper from .wrapper import SDKWrapper
__all__ = ['SDKWrapper'] __all__ += ['SDKWrapper']
except Exception: except Exception:
pass pass

View File

@ -0,0 +1,83 @@
# Copyright (c) OpenMMLab. All rights reserved.
import importlib
import os.path as osp
import sys
from typing import Any, Optional, Sequence
from mmdeploy.utils import get_file_path
from ..base import BACKEND_MANAGERS, BaseBackendManager
_is_available = False
module_name = 'mmdeploy_python'
candidates = [
f'../../../build/lib/{module_name}.*.so',
f'../../../build/bin/*/{module_name}.*.pyd'
]
lib_path = get_file_path(osp.dirname(__file__), candidates)
if lib_path:
lib_dir = osp.dirname(lib_path)
sys.path.append(lib_dir)
if importlib.util.find_spec(module_name) is not None:
_is_available = True
@BACKEND_MANAGERS.register('sdk')
class SDKManager(BaseBackendManager):
@classmethod
def build_wrapper(cls,
backend_files: Sequence[str],
device: str = 'cpu',
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
deploy_cfg: Optional[Any] = None,
**kwargs):
"""Build the wrapper for the backend model.
Args:
backend_files (Sequence[str]): Backend files.
device (str, optional): The device info. Defaults to 'cpu'.
input_names (Optional[Sequence[str]], optional): input names.
Defaults to None.
output_names (Optional[Sequence[str]], optional): output names.
Defaults to None.
deploy_cfg (Optional[Any], optional): The deploy config. Defaults
to None.
"""
assert deploy_cfg is not None, \
'Building SDKWrapper requires deploy_cfg'
from mmdeploy.backend.sdk import SDKWrapper
from mmdeploy.utils import SDK_TASK_MAP, get_task_type
task_name = SDK_TASK_MAP[get_task_type(deploy_cfg)]['cls_name']
return SDKWrapper(
model_file=backend_files[0], task_name=task_name, device=device)
@classmethod
def is_available(cls, with_custom_ops: bool = False) -> bool:
"""Check whether backend is installed.
Args:
with_custom_ops (bool): check custom ops exists.
Returns:
bool: True if backend package is installed.
"""
global _is_available
return _is_available
@classmethod
def get_version(cls) -> str:
"""Get the version of the backend."""
if not cls.is_available():
return 'None'
else:
import pkg_resources
try:
return pkg_resources.get_distribution('mmdeploy').version
except Exception:
return 'None'
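A short sketch (assuming mmdeploy is importable) of how the registered name is used to look the manager up and query it:

```Python
from mmdeploy.backend.base import get_backend_manager

sdk_mgr = get_backend_manager('sdk')
# True only when the mmdeploy_python bindings can be imported
print(sdk_mgr.is_available())
# falls back to the string 'None' when the package is missing
print(sdk_mgr.get_version())
```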

View File

@ -14,7 +14,7 @@ from mmdeploy.utils.config_utils import get_backend_config
 from mmdeploy.utils.constants import SDK_TASK_MAP as task_map


-def get_mmdpeloy_version() -> str:
+def get_mmdeploy_version() -> str:
     """Return the version of MMDeploy."""
     import mmdeploy
     version = mmdeploy.__version__
@ -261,7 +261,7 @@ def get_deploy(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config,
     cls_name = task_map[task]['cls_name']
     _, customs = get_model_name_customs(
         deploy_cfg, model_cfg, work_dir=work_dir, device=device)
-    version = get_mmdpeloy_version()
+    version = get_mmdeploy_version()
     models = get_models(deploy_cfg, model_cfg, work_dir, device)
     return dict(version=version, task=cls_name, models=models, customs=customs)
@ -312,7 +312,7 @@ def get_detail(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config,
         dict: Composed of version, codebase, codebase_config, onnx_config,
             backend_config and calib_config.
     """
-    version = get_mmdpeloy_version()
+    version = get_mmdeploy_version()
     codebase = get_task(deploy_cfg)
     codebase['pth'] = pth
     codebase['config'] = model_cfg.filename

View File

@ -1,25 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import os.path as osp
-
-from .init_plugins import get_onnx2dlc_path
+from .backend_manager import SNPEManager
 from .onnx2dlc import from_onnx

+_BackendManager = SNPEManager
+is_available = _BackendManager.is_available
+build_wrapper = _BackendManager.build_wrapper

-def is_available():
-    """Check whether ncnn and snpe-onnx-to-dlc tool are installed.
-
-    Returns:
-        bool: True if snpe-onnx-to-dlc tool are installed.
-    """
-    onnx2dlc = get_onnx2dlc_path()
-    if onnx2dlc is None:
-        return False
-    return osp.exists(onnx2dlc)
-
-__all__ = ['from_onnx']
+__all__ = ['from_onnx', 'SNPEManager']

 if is_available():
     try:
         from .wrapper import SNPEWrapper

View File

@ -0,0 +1,99 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
import os.path as osp
import sys
from typing import Any, Optional, Sequence
from mmdeploy.utils import get_root_logger
from ..base import BACKEND_MANAGERS, BaseBackendManager
@BACKEND_MANAGERS.register('snpe')
class SNPEManager(BaseBackendManager):
@classmethod
def build_wrapper(cls,
backend_files: Sequence[str],
device: str = 'cpu',
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
deploy_cfg: Optional[Any] = None,
**kwargs):
"""Build the wrapper for the backend model.
Args:
backend_files (Sequence[str]): Backend files.
device (str, optional): The device info. Defaults to 'cpu'.
input_names (Optional[Sequence[str]], optional): input names.
Defaults to None.
output_names (Optional[Sequence[str]], optional): output names.
Defaults to None.
deploy_cfg (Optional[Any], optional): The deploy config. Defaults
to None.
"""
from .wrapper import SNPEWrapper
uri = None
if 'uri' in kwargs:
uri = kwargs['uri']
return SNPEWrapper(
dlc_file=backend_files[0], uri=uri, output_names=output_names)
@classmethod
def is_available(cls, with_custom_ops: bool = False) -> bool:
"""Check whether backend is installed.
Args:
with_custom_ops (bool): check custom ops exists.
Returns:
bool: True if backend package is installed.
"""
from .onnx2dlc import get_onnx2dlc_path
onnx2dlc = get_onnx2dlc_path()
if onnx2dlc is None:
return False
return osp.exists(onnx2dlc)
@classmethod
def to_backend(cls,
ir_files: Sequence[str],
work_dir: str,
log_level: int = logging.INFO,
device: str = 'cpu',
uri: str = '',
**kwargs) -> Sequence[str]:
"""Convert intermediate representation to given backend.
Args:
ir_files (Sequence[str]): The intermediate representation files.
work_dir (str): The work directory, backend files and logs should
be saved in this directory.
log_level (int, optional): The log level. Defaults to logging.INFO.
device (str, optional): The device type. Defaults to 'cpu'.
Returns:
Sequence[str]: Backend files.
"""
from . import is_available
logger = get_root_logger()
if not is_available():
logger.error('snpe support is not available, please check\n'
'1) `snpe-onnx-to-dlc` existed in `PATH`\n'
'2) snpe only support\n'
'ubuntu18.04')
sys.exit(1)
from mmdeploy.apis.snpe import get_env_key, get_output_model_file
from .onnx2dlc import from_onnx
if get_env_key() not in os.environ:
os.environ[get_env_key()] = uri
backend_files = []
for onnx_path in ir_files:
dlc_path = get_output_model_file(onnx_path, work_dir)
onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
from_onnx(onnx_path, osp.join(work_dir, onnx_name))
backend_files += [dlc_path]
return backend_files
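Sketch of driving the SNPE conversion through the same manager interface; the uri and paths are placeholders, and `snpe-onnx-to-dlc` must be on `PATH`:

```Python
from mmdeploy.backend.base import get_backend_manager

snpe_mgr = get_backend_manager('snpe')
if snpe_mgr.is_available():
    dlc_files = snpe_mgr.to_backend(
        ['work_dir/end2end.onnx'],
        work_dir='work_dir',
        uri='192.168.1.10:60000')  # hypothetical inference service address
```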

View File

@ -1,35 +1,18 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 # flake8: noqa
-import importlib
-import os.path as osp
-
-from .init_plugins import get_ops_path, load_tensorrt_plugin
-
-
-def is_available():
-    """Check whether TensorRT package is installed and cuda is available.
-
-    Returns:
-        bool: True if TensorRT package is installed and cuda is available.
-    """
-    return importlib.util.find_spec('tensorrt') is not None
-
-
-def is_custom_ops_available():
-    """Check whether TensorRT custom ops are installed.
-
-    Returns:
-        bool: True if TensorRT custom ops are compiled.
-    """
-    tensorrt_op_path = get_ops_path()
-    return osp.exists(tensorrt_op_path)
+from .backend_manager import TensorRTManager
+from .init_plugins import load_tensorrt_plugin
+
+_BackendManager = TensorRTManager
+is_available = _BackendManager.is_available
+build_wrapper = _BackendManager.build_wrapper
+
+__all__ = ['TensorRTManager']

 if is_available():
     from .utils import from_onnx, load, save
-    __all__ = ['from_onnx', 'save', 'load', 'load_tensorrt_plugin']
+    __all__ += ['from_onnx', 'save', 'load', 'load_tensorrt_plugin']
     try:
         # import wrapper if pytorch is available

View File

@ -0,0 +1,138 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
from typing import Any, Callable, Optional, Sequence
from ..base import BACKEND_MANAGERS, BaseBackendManager
@BACKEND_MANAGERS.register('tensorrt')
class TensorRTManager(BaseBackendManager):
@classmethod
def build_wrapper(cls,
backend_files: Sequence[str],
device: str = 'cpu',
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
deploy_cfg: Optional[Any] = None,
**kwargs):
"""Build the wrapper for the backend model.
Args:
backend_files (Sequence[str]): Backend files.
device (str, optional): The device info. Defaults to 'cpu'.
input_names (Optional[Sequence[str]], optional): input names.
Defaults to None.
output_names (Optional[Sequence[str]], optional): output names.
Defaults to None.
deploy_cfg (Optional[Any], optional): The deploy config. Defaults
to None.
"""
from .wrapper import TRTWrapper
return TRTWrapper(engine=backend_files[0], output_names=output_names)
@classmethod
def is_available(cls, with_custom_ops: bool = False) -> bool:
"""Check whether backend is installed.
Args:
with_custom_ops (bool): check custom ops exists.
Returns:
bool: True if backend package is installed.
"""
import importlib
ret = importlib.util.find_spec('tensorrt') is not None
if ret and with_custom_ops:
from .init_plugins import get_ops_path
ops_path = get_ops_path()
custom_ops_exist = osp.exists(ops_path)
ret = ret and custom_ops_exist
return ret
@classmethod
def get_version(cls) -> str:
"""Get the version of the backend."""
if not cls.is_available():
return 'None'
else:
import pkg_resources
try:
return pkg_resources.get_distribution('tensorrt').version
except Exception:
return 'None'
@classmethod
def check_env(cls, log_callback: Callable = lambda _: _) -> str:
"""Check current environment.
Returns:
str: Info about the environment.
"""
info = super().check_env(log_callback=log_callback)
available = cls.is_available()
ops_available = cls.is_available(with_custom_ops=True)
ops_available = 'Available' if ops_available else 'NotAvailable'
if available:
ops_info = f'tensorrt custom ops:\t{ops_available}'
log_callback(ops_info)
info = f'{info}\n{ops_info}'
return info
@classmethod
def to_backend(cls,
ir_files: Sequence[str],
work_dir: str,
deploy_cfg: Any,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
"""Convert intermediate representation to given backend.
Args:
ir_files (Sequence[str]): The intermediate representation files.
work_dir (str): The work directory, backend files and logs should
be saved in this directory.
deploy_cfg (Any): The deploy config.
log_level (int, optional): The log level. Defaults to logging.INFO.
device (str, optional): The device type. Defaults to 'cpu'.
Returns:
Sequence[str]: Backend files.
"""
import os.path as osp
from mmdeploy.utils import get_model_inputs, get_partition_config
model_params = get_model_inputs(deploy_cfg)
partition_cfgs = get_partition_config(deploy_cfg)
assert len(model_params) == len(ir_files)
from . import is_available
assert is_available(), (
'TensorRT is not available,'
' please install TensorRT and build TensorRT custom ops first.')
from .onnx2tensorrt import onnx2tensorrt
backend_files = []
for model_id, model_param, onnx_path in zip(
range(len(ir_files)), model_params, ir_files):
onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
save_file = model_param.get('save_file', onnx_name + '.engine')
partition_type = 'end2end' if partition_cfgs is None \
else onnx_name
onnx2tensorrt(
work_dir,
save_file,
model_id,
deploy_cfg,
onnx_path,
device=device,
partition_type=partition_type)
backend_files.append(osp.join(work_dir, save_file))
return backend_files
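Sketch of the environment checks exposed by the manager, mirroring what `tools/check_env.py` does after this refactor; the printed strings are illustrative:

```Python
from mmdeploy.backend.base import get_backend_manager

trt_mgr = get_backend_manager('tensorrt')
# plain availability vs. availability including the compiled custom-op library
print(trt_mgr.is_available())
print(trt_mgr.is_available(with_custom_ops=True))
if trt_mgr.is_available():
    # logs and returns a summary along the lines of
    # "tensorrt: 8.x ...\ntensorrt custom ops:\tAvailable"
    print(trt_mgr.check_env())
```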

View File

@ -1,7 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, Sequence, Union
+from typing import Any, Dict, Sequence, Union

-import h5py
 import numpy as np
 import pycuda.autoinit  # noqa:F401
 import pycuda.driver as cuda
@ -25,13 +24,14 @@ class HDF5Calibrator(trt.IInt8Calibrator):
     def __init__(
             self,
-            calib_file: Union[str, h5py.File],
+            calib_file: Union[str, Any],
             input_shapes: Dict[str, Sequence[int]],
             model_type: str = 'end2end',
             device_id: int = 0,
             algorithm: trt.CalibrationAlgoType = DEFAULT_CALIBRATION_ALGORITHM,
             **kwargs):
         super().__init__()
+        import h5py
         if isinstance(calib_file, str):
             calib_file = h5py.File(calib_file, mode='r')

View File

@ -140,14 +140,15 @@ def from_onnx(onnx_model: Union[str, onnx.ModelProto],
         >>> })
     """
-    import os
-    old_cuda_device = os.environ.get('CUDA_DEVICE', None)
-    os.environ['CUDA_DEVICE'] = str(device_id)
-    import pycuda.autoinit  # noqa:F401
-    if old_cuda_device is not None:
-        os.environ['CUDA_DEVICE'] = old_cuda_device
-    else:
-        os.environ.pop('CUDA_DEVICE')
+    if device_id != 0:
+        import os
+        old_cuda_device = os.environ.get('CUDA_DEVICE', None)
+        os.environ['CUDA_DEVICE'] = str(device_id)
+        import pycuda.autoinit  # noqa:F401
+        if old_cuda_device is not None:
+            os.environ['CUDA_DEVICE'] = old_cuda_device
+        else:
+            os.environ.pop('CUDA_DEVICE')

     load_tensorrt_plugin()
     # create builder and network

View File

@ -1,18 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 # flake8: noqa
+from .backend_manager import TorchScriptManager
 from .init_plugins import get_ops_path, ops_available

+_BackendManager = TorchScriptManager
+is_available = _BackendManager.is_available
+build_wrapper = _BackendManager.build_wrapper

-def is_available():
-    """Torchscript available.
-
-    Returns:
-        bool: Always True.
-    """
-    return True
-
-__all__ = ['get_ops_path', 'ops_available']
+__all__ = ['get_ops_path', 'ops_available', 'TorchScriptManager']

 if is_available():
     from .wrapper import TorchscriptWrapper

View File

@ -0,0 +1,104 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Any, Callable, Optional, Sequence
from ..base import BACKEND_MANAGERS, BaseBackendManager
@BACKEND_MANAGERS.register('torchscript')
class TorchScriptManager(BaseBackendManager):
@classmethod
def build_wrapper(cls,
backend_files: Sequence[str],
device: str = 'cpu',
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
deploy_cfg: Optional[Any] = None,
**kwargs):
"""Build the wrapper for the backend model.
Args:
backend_files (Sequence[str]): Backend files.
device (str, optional): The device info. Defaults to 'cpu'.
input_names (Optional[Sequence[str]], optional): input names.
Defaults to None.
output_names (Optional[Sequence[str]], optional): output names.
Defaults to None.
deploy_cfg (Optional[Any], optional): The deploy config. Defaults
to None.
"""
from .wrapper import TorchscriptWrapper
return TorchscriptWrapper(
model=backend_files[0],
input_names=input_names,
output_names=output_names)
@classmethod
def is_available(cls, with_custom_ops: bool = False) -> bool:
"""Check whether backend is installed.
Args:
with_custom_ops (bool): check custom ops exists.
Returns:
bool: True if backend package is installed.
"""
import importlib
ret = importlib.util.find_spec('torch') is not None
if ret and with_custom_ops:
from .init_plugins import ops_available
ret = ret and ops_available()
return ret
@classmethod
def get_version(cls) -> str:
"""Get the version of the backend."""
if not cls.is_available():
return 'None'
else:
import pkg_resources
try:
return pkg_resources.get_distribution('torch').version
except Exception:
return 'None'
@classmethod
def check_env(cls, log_callback: Callable = lambda _: _) -> str:
"""Check current environment.
Returns:
str: Info about the environment.
"""
info = super().check_env(log_callback=log_callback)
available = cls.is_available()
ops_available = cls.is_available(with_custom_ops=True)
ops_available = 'Available' if ops_available else 'NotAvailable'
if available:
ops_info = f'torchscript custom ops:\t{ops_available}'
log_callback(ops_info)
info = f'{info}\n{ops_info}'
return info
@classmethod
def to_backend(cls,
ir_files: Sequence[str],
work_dir: str,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
"""Convert intermediate representation to given backend.
Args:
ir_files (Sequence[str]): The intermediate representation files.
work_dir (str): The work directory, backend files and logs should
be saved in this directory.
log_level (int, optional): The log level. Defaults to logging.INFO.
device (str, optional): The device type. Defaults to 'cpu'.
Returns:
Sequence[str]: Backend files.
"""
return ir_files
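A sketch of building and running a TorchScript wrapper through the manager; the model path and tensor names are placeholders:

```Python
import torch

from mmdeploy.backend.base import get_backend_manager

ts_mgr = get_backend_manager('torchscript')
wrapper = ts_mgr.build_wrapper(
    ['work_dir/end2end.torchscript.pt'],  # hypothetical traced model
    device='cpu',
    input_names=['input'],
    output_names=['output'])
outputs = wrapper({'input': torch.rand(1, 3, 224, 224)})['output']
```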

View File

@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import glob
 import os.path as osp
@ -9,14 +8,14 @@ def get_ops_path() -> str:
     Returns:
         str: A path of the torchscript extension library.
     """
-    wildcard = osp.abspath(
-        osp.join(
-            osp.dirname(__file__),
-            '../../../build/lib/libmmdeploy_torchscript_ops.so'))
-
-    paths = glob.glob(wildcard)
-    lib_path = paths[0] if len(paths) > 0 else ''
-    return lib_path
+    from mmdeploy.utils import get_file_path
+    candidates = [
+        '../../lib/libmmdeploy_torchscript_ops.so',
+        '../../lib/mmdeploy_torchscript_ops.dll',
+        '../../../build/lib/libmmdeploy_torchscript_ops.so',
+        '../../../build/bin/*/mmdeploy_torchscript_ops.dll'
+    ]
+    return get_file_path(osp.dirname(__file__), candidates)


 def ops_available() -> bool:

View File

@ -1,16 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import importlib
 import sys

+from .backend_manager import TVMManager

-def is_available() -> bool:
-    """Check whether tvm package is installed.
-
-    Returns:
-        bool: True if tvm package is installed.
-    """
-    return importlib.util.find_spec('tvm') is not None
+_BackendManager = TVMManager
+is_available = _BackendManager.is_available
+build_wrapper = _BackendManager.build_wrapper


 def get_library_ext() -> str:
@ -26,12 +21,14 @@ def get_library_ext() -> str:
     return '.so'

+__all__ = ['TVMManager']

 if is_available():
     from .onnx2tvm import from_onnx
     from .quantize import HDF5Dataset
     from .tuner import build_tvm_tuner
-    __all__ = ['from_onnx', 'build_tvm_tuner', 'HDF5Dataset', 'TVMManager']
+    __all__ += ['from_onnx', 'build_tvm_tuner', 'HDF5Dataset']
     try:
         # import wrapper if pytorch is available

View File

@ -0,0 +1,135 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
from typing import Any, Optional, Sequence
from ..base import BACKEND_MANAGERS, BaseBackendManager
@BACKEND_MANAGERS.register('tvm')
class TVMManager(BaseBackendManager):
@classmethod
def build_wrapper(cls,
backend_files: Sequence[str],
device: str = 'cpu',
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
deploy_cfg: Optional[Any] = None,
**kwargs):
"""Build the wrapper for the backend model.
Args:
backend_files (Sequence[str]): Backend files.
device (str, optional): The device info. Defaults to 'cpu'.
input_names (Optional[Sequence[str]], optional): input names.
Defaults to None.
output_names (Optional[Sequence[str]], optional): output names.
Defaults to None.
deploy_cfg (Optional[Any], optional): The deploy config. Defaults
to None.
"""
from .wrapper import TVMWrapper
bytecode = None if len(backend_files) <= 1 else backend_files[1]
return TVMWrapper(
backend_files[0],
bytecode=bytecode,
output_names=output_names,
device=device)
@classmethod
def is_available(cls, with_custom_ops: bool = False) -> bool:
"""Check whether backend is installed.
Args:
with_custom_ops (bool): check custom ops exists.
Returns:
bool: True if backend package is installed.
"""
import importlib
ret = importlib.util.find_spec('tvm') is not None
return ret
@classmethod
def get_version(cls) -> str:
"""Get the version of the backend."""
if not cls.is_available():
return 'None'
else:
import pkg_resources
try:
return pkg_resources.get_distribution('tvm').version
except Exception:
return 'None'
@classmethod
def to_backend(cls,
ir_files: Sequence[str],
work_dir: str,
deploy_cfg: Any,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
"""Convert intermediate representation to given backend.
Args:
ir_files (Sequence[str]): The intermediate representation files.
work_dir (str): The work directory, backend files and logs should
be saved in this directory.
deploy_cfg (Any): The deploy config.
log_level (int, optional): The log level. Defaults to logging.INFO.
device (str, optional): The device type. Defaults to 'cpu'.
Returns:
Sequence[str]: Backend files.
"""
import copy
from mmdeploy.apis.tvm import get_library_ext
from mmdeploy.utils import (get_calib_filename, get_model_inputs,
get_partition_config)
from .onnx2tvm import from_onnx
model_inputs = get_model_inputs(deploy_cfg)
if device.startswith('cuda'):
target = 'cuda'
else:
target = 'llvm'
lib_ext = get_library_ext()
tvm_files = []
for model_id, onnx_path in enumerate(ir_files):
model_input = copy.deepcopy(model_inputs[model_id])
use_vm = model_input.get('use_vm', False)
if 'target' not in model_input['tuner']:
model_input['tuner']['target'] = target
lib_path = osp.splitext(onnx_path)[0] + lib_ext
code_path = osp.splitext(
onnx_path)[0] + '.code' if use_vm else None
model_input['output_file'] = lib_path
model_input['onnx_model'] = onnx_path
model_input['bytecode_file'] = code_path
# create calibration dataset
if 'qconfig' in model_input:
from .quantize import HDF5Dataset
calib_filename = get_calib_filename(deploy_cfg)
calib_path = osp.join(work_dir, calib_filename)
partition_cfgs = get_partition_config(deploy_cfg)
onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
partition_type = 'end2end' if partition_cfgs is None \
else onnx_name
dataset = HDF5Dataset(
calib_path,
model_input['shape'],
model_type=partition_type,
device=target)
model_input['dataset'] = dataset()
from_onnx(**model_input)
tvm_files += [lib_path, code_path]
return tvm_files
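Sketch of wrapping a converted TVM model: the compiled library comes first and, when the virtual machine is used, the bytecode file second (see `bytecode = backend_files[1]` above); paths are placeholders:

```Python
from mmdeploy.backend.base import get_backend_manager

tvm_mgr = get_backend_manager('tvm')
wrapper = tvm_mgr.build_wrapper(
    ['work_dir/end2end.so', 'work_dir/end2end.code'],
    device='cpu',
    output_names=['output'])
```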

View File

@ -6,8 +6,7 @@ import mmengine
 from mmengine.model import BaseModel
 from torch import nn

-from mmdeploy.utils import (SDK_TASK_MAP, Backend, get_backend_config,
-                            get_common_config, get_ir_config, get_task_type)
+from mmdeploy.utils import Backend, get_ir_config


 class BaseBackendModel(BaseModel, metaclass=ABCMeta):
@ -58,88 +57,13 @@ class BaseBackendModel(BaseModel, metaclass=ABCMeta):
                 names from the model.
             deploy_cfg: Deployment config file.
         """
-        if backend == Backend.ONNXRUNTIME:
-            from mmdeploy.backend.onnxruntime import ORTWrapper
-            return ORTWrapper(
-                onnx_file=backend_files[0],
-                device=device,
-                output_names=output_names)
-        elif backend == Backend.TENSORRT:
-            from mmdeploy.backend.tensorrt import TRTWrapper
-            return TRTWrapper(
-                engine=backend_files[0], output_names=output_names)
-        elif backend == Backend.PPLNN:
-            from mmdeploy.backend.pplnn import PPLNNWrapper
-            return PPLNNWrapper(
-                onnx_file=backend_files[0],
-                algo_file=backend_files[1] if len(backend_files) > 1 else None,
-                device=device,
-                output_names=output_names)
-        elif backend == Backend.NCNN:
-            from mmdeploy.backend.ncnn import NCNNWrapper
-            # For unittest deploy_config will not pass into _build_wrapper
-            # function.
-            if deploy_cfg:
-                backend_config = get_backend_config(deploy_cfg)
-                use_vulkan = backend_config.get('use_vulkan', False)
-            else:
-                use_vulkan = False
-            return NCNNWrapper(
-                param_file=backend_files[0],
-                bin_file=backend_files[1],
-                output_names=output_names,
-                use_vulkan=use_vulkan)
-        elif backend == Backend.OPENVINO:
-            from mmdeploy.backend.openvino import OpenVINOWrapper
-            return OpenVINOWrapper(
-                ir_model_file=backend_files[0], output_names=output_names)
-        elif backend == Backend.SDK:
-            assert deploy_cfg is not None, \
-                'Building SDKWrapper requires deploy_cfg'
-            from mmdeploy.backend.sdk.wrapper import SDKWrapper
-            task_name = SDK_TASK_MAP[get_task_type(deploy_cfg)]['cls_name']
-            return SDKWrapper(
-                model_file=backend_files[0],
-                task_name=task_name,
-                device=device)
-        elif backend == Backend.TORCHSCRIPT:
-            from mmdeploy.backend.torchscript import TorchscriptWrapper
-            return TorchscriptWrapper(
-                model=backend_files[0],
-                input_names=input_names,
-                output_names=output_names)
-        elif backend == Backend.RKNN:
-            from mmdeploy.backend.rknn import RKNNWrapper
-            common_config = get_common_config(deploy_cfg)
-            return RKNNWrapper(
-                model=backend_files[0],
-                common_config=common_config,
-                input_names=input_names,
-                output_names=output_names)
-        elif backend == Backend.ASCEND:
-            from mmdeploy.backend.ascend import AscendWrapper
-            return AscendWrapper(model=backend_files[0], device=device)
-        elif backend == Backend.SNPE:
-            from mmdeploy.backend.snpe import SNPEWrapper
-            uri = None
-            if 'uri' in kwargs:
-                uri = kwargs['uri']
-            return SNPEWrapper(
-                dlc_file=backend_files[0], uri=uri, output_names=output_names)
-        elif backend == Backend.COREML:
-            from mmdeploy.backend.coreml import CoreMLWrapper
-            return CoreMLWrapper(model_file=backend_files[0])
-        elif backend == Backend.TVM:
-            from mmdeploy.backend.tvm import TVMWrapper
-            bytecode = None if len(backend_files) == 1 else backend_files[1]
-            return TVMWrapper(
-                lib=backend_files[0],
-                output_names=output_names,
-                bytecode=bytecode,
-                device=device)
-        else:
-            raise NotImplementedError(f'Unknown backend type: {backend.value}')
+        from mmdeploy.backend.base import get_backend_manager
+        backend_mgr = get_backend_manager(backend.value)
+        if backend_mgr is None:
+            raise NotImplementedError(
+                f'Unsupported backend type: {backend.value}')
+        return backend_mgr.build_wrapper(backend_files, device, input_names,
+                                         output_names, deploy_cfg, **kwargs)

     def destroy(self):
         if hasattr(self, 'wrapper') and hasattr(self.wrapper, 'destroy'):

View File

@ -239,14 +239,14 @@ class End2EndModel(BaseBackendModel):
                 masks = batch_masks[i]
                 img_h, img_w = img_metas[i]['img_shape'][:2]
                 ori_h, ori_w = img_metas[i]['ori_shape'][:2]
-                export_postprocess_mask = True
+                export_postprocess_mask = False
                 if self.deploy_cfg is not None:
                     mmdet_deploy_cfg = get_post_processing_params(
                         self.deploy_cfg)
                     # this flag enable postprocess when export.
                     export_postprocess_mask = mmdet_deploy_cfg.get(
-                        'export_postprocess_mask', True)
+                        'export_postprocess_mask', False)
                 if not export_postprocess_mask:
                     masks = End2EndModel.postprocessing_masks(
                         dets[:, :4], masks, ori_w, ori_h, self.device)

View File

@ -64,7 +64,8 @@ def fcn_mask_head__predict_by_feat(self,
     # grid sample is not supported by most engine
     # so we add a flag to disable it.
     mmdet_params = get_post_processing_params(ctx.cfg)
-    export_postprocess_mask = mmdet_params.get('export_postprocess_mask', True)
+    export_postprocess_mask = mmdet_params.get('export_postprocess_mask',
+                                               False)
     if not export_postprocess_mask:
         return mask_pred

View File

@ -15,7 +15,10 @@ def get_library_version(lib):
     """
     try:
         lib = importlib.import_module(lib)
-        version = lib.__version__
+        if hasattr(lib, '__version__'):
+            version = lib.__version__
+        else:
+            version = None
     except Exception:
         version = None

View File

@ -4,7 +4,6 @@ import os.path as osp
 import random
 import string
 import tempfile
-import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
@ -35,47 +34,13 @@ def backend_checker(backend: Backend, require_plugin: bool = False):
             will also check if the backend plugin has been compiled. Default
             to `False`.
     """
-    is_custom_ops_available = None
-    if backend == Backend.ONNXRUNTIME:
-        from mmdeploy.apis.onnxruntime import is_available
-        if require_plugin:
-            from mmdeploy.apis.onnxruntime import is_custom_ops_available
-    elif backend == Backend.TENSORRT:
-        from mmdeploy.apis.tensorrt import is_available
-        if require_plugin:
-            from mmdeploy.apis.tensorrt import is_custom_ops_available
-    elif backend == Backend.PPLNN:
-        from mmdeploy.apis.pplnn import is_available
-    elif backend == Backend.NCNN:
-        from mmdeploy.apis.ncnn import is_available
-        if require_plugin:
-            from mmdeploy.apis.ncnn import is_custom_ops_available
-    elif backend == Backend.OPENVINO:
-        from mmdeploy.apis.openvino import is_available
-    elif backend == Backend.RKNN:
-        # device not require as backend is not really running
-        from mmdeploy.apis.rknn import is_available
-    elif backend == Backend.ASCEND:
-        from mmdeploy.apis.ascend import is_available
-    elif backend == Backend.TVM:
-        from mmdeploy.apis.ascend import is_available
-    else:
-        warnings.warn('The backend checker is not available')
-        return
+    from mmdeploy.backend.base import get_backend_manager
+
+    backend_mgr = get_backend_manager(backend.value)
+    result = backend_mgr.is_available(with_custom_ops=require_plugin)

     checker = pytest.mark.skipif(
-        not is_available(), reason=f'{backend.value} package is not available')
-    if require_plugin and is_custom_ops_available is not None:
-        plugin_checker = pytest.mark.skipif(
-            not is_custom_ops_available(),
-            reason=f'{backend.value} plugin is not available')
-
-        def double_checker(func):
-            func = checker(func)
-            func = plugin_checker(func)
-            return func
-
-        return double_checker
+        not result, reason=f'{backend.value} package is not available')

     return checker
@ -90,40 +55,13 @@ def check_backend(backend: Backend, require_plugin: bool = False):
             will also check if the backend plugin has been compiled. Default
             to `False`.
     """
-    is_custom_ops_available = None
-    if backend == Backend.ONNXRUNTIME:
-        from mmdeploy.apis.onnxruntime import is_available
-        if require_plugin:
-            from mmdeploy.apis.onnxruntime import is_custom_ops_available
-    elif backend == Backend.TENSORRT:
-        from mmdeploy.apis.tensorrt import is_available
-        if require_plugin:
-            from mmdeploy.apis.tensorrt import is_custom_ops_available
-    elif backend == Backend.PPLNN:
-        from mmdeploy.apis.pplnn import is_available
-    elif backend == Backend.NCNN:
-        from mmdeploy.apis.ncnn import is_available
-        if require_plugin:
-            from mmdeploy.apis.ncnn import is_custom_ops_available
-    elif backend == Backend.OPENVINO:
-        from mmdeploy.apis.openvino import is_available
-    elif backend == Backend.TORCHSCRIPT:
-        from mmdeploy.backend.torchscript import ops_available as is_available
-    elif backend == Backend.RKNN:
-        from mmdeploy.backend.rknn import is_available
-    elif backend == Backend.ASCEND:
-        from mmdeploy.backend.ascend import is_available
-    elif backend == Backend.TVM:
-        from mmdeploy.backend.ascend import is_available
-    else:
-        warnings.warn('The backend checker is not available')
-        return
-
-    if not is_available():
+    from mmdeploy.backend.base import get_backend_manager
+
+    backend_mgr = get_backend_manager(backend.value)
+    result = backend_mgr.is_available(with_custom_ops=require_plugin)
+    if not result:
         pytest.skip(f'{backend.value} package is not available')
-    if require_plugin and is_custom_ops_available is not None:
-        if not is_custom_ops_available():
-            pytest.skip(f'{backend.value} plugin is not available')


 class WrapFunction(nn.Module):
@ -455,6 +393,7 @@ def get_backend_outputs(ir_file_path: str,
         If the backend specified in 'deploy_cfg' is not available,
         then None will be returned.
     """
+    from mmdeploy.apis.utils import to_backend
     backend = get_backend(deploy_cfg)
     flatten_model_inputs = get_flatten_inputs(model_inputs)
     ir_config = get_ir_config(deploy_cfg)
@ -465,109 +404,32 @@ def get_backend_outputs(ir_file_path: str,
         k for k, v in flatten_model_inputs.items() if k != 'ctx'
     ]
-    # prepare backend model and input features
-    if backend == Backend.TENSORRT:
-        # convert to engine
-        import mmdeploy.apis.tensorrt as trt_apis
-        if not (trt_apis.is_available()
-                and trt_apis.is_custom_ops_available()):
-            return None
-        trt_file_path = tempfile.NamedTemporaryFile(suffix='.engine').name
-        trt_apis.onnx2tensorrt(
-            '',
-            trt_file_path,
-            0,
-            deploy_cfg=deploy_cfg,
-            onnx_model=ir_file_path)
-        backend_files = [trt_file_path]
-        for k, v in model_inputs.items():
-            model_inputs[k] = model_inputs[k].cuda()
-        backend_feats = model_inputs
-        device = 'cuda:0'
-    elif backend == Backend.ONNXRUNTIME:
-        import mmdeploy.apis.onnxruntime as ort_apis
-        if not (ort_apis.is_available()
-                and ort_apis.is_custom_ops_available()):
-            return None
-        feature_list = []
-        backend_feats = {}
-        for k, item in model_inputs.items():
-            if type(item) is torch.Tensor:
-                feature_list.append(item)
-            elif type(item) is tuple or list:
-                for i in item:
-                    assert type(i) is torch.Tensor, 'model_inputs contains '
-                    'nested sequence of torch.Tensor'
-                    feature_list.append(i)
-            else:
-                raise TypeError('values of model_inputs are expected to be '
-                                'torch.Tensor or its sequence, '
-                                f'but got {type(model_inputs)}')
-        # for onnx file generated with list[torch.Tensor] input,
-        # the input dict keys are just numbers if not specified
-        for i in range(len(feature_list)):
-            if i < len(input_names):
-                backend_feats[input_names[i]] = feature_list[i]
-            else:
-                backend_feats[str(i)] = feature_list[i]
-        backend_files = [ir_file_path]
-        device = 'cpu'
-    elif backend == Backend.NCNN:
-        import mmdeploy.apis.ncnn as ncnn_apis
-        if not (ncnn_apis.is_available()
-                and ncnn_apis.is_custom_ops_available()):
-            return None
-        param_path, bin_path = ncnn_apis.get_output_model_file(ir_file_path)
-        ncnn_files_prefix = osp.splitext(ir_file_path)[0]
-        ncnn_apis.from_onnx(ir_file_path, ncnn_files_prefix)
-        backend_files = [param_path, bin_path]
-        backend_feats = flatten_model_inputs
-        device = 'cpu'
-    elif backend == Backend.OPENVINO:
-        import mmdeploy.apis.openvino as openvino_apis
-        if not openvino_apis.is_available():
-            return None
-        from mmdeploy.apis.openvino import get_mo_options_from_cfg
-        openvino_work_dir = tempfile.TemporaryDirectory().name
-        openvino_file_path = openvino_apis.get_output_model_file(
-            ir_file_path, openvino_work_dir)
-        input_info = {
-            name: value.shape
-            for name, value in flatten_model_inputs.items()
-        }
-        mo_options = get_mo_options_from_cfg(deploy_cfg)
-        openvino_apis.from_onnx(ir_file_path, openvino_work_dir, input_info,
-                                output_names, mo_options)
-        backend_files = [openvino_file_path]
-        backend_feats = flatten_model_inputs
-        device = 'cpu'
-    elif backend == Backend.DEFAULT:
-        return None
-    elif backend == Backend.TORCHSCRIPT:
-        backend_files = [ir_file_path]
-        device = 'cpu'
-        backend_feats = [v for _, v in model_inputs.items()]
-    elif backend == Backend.ASCEND:
-        # Ascend model conversion
-        import mmdeploy.apis.ascend as ascend_apis
-        from mmdeploy.utils import get_model_inputs
-        if not ascend_apis.is_available():
-            return None
-        work_dir = osp.split(ir_file_path)[0]
-        # convert model
-        convert_args = get_model_inputs(deploy_cfg)
-        ascend_apis.from_onnx(ir_file_path, work_dir, convert_args[0])
-        om_file_name = osp.splitext(osp.split(ir_file_path)[1])[0]
-        backend_files = [osp.join(work_dir, om_file_name + '.om')]
-        backend_feats = flatten_model_inputs
-        device = 'cpu'
-    else:
-        raise NotImplementedError(
-            f'Unimplemented backend type: {backend.value}')
+    work_dir = tempfile.TemporaryDirectory().name
+    device = 'cpu'
+
+    # TODO: Try to wrap these code into backend manager
+    if backend != Backend.TORCHSCRIPT:
+        model_inputs = flatten_model_inputs
+    if backend == Backend.TENSORRT:
+        device = 'cuda'
+        model_inputs = dict((k, v.cuda()) for k, v in model_inputs.items())
+    elif backend == Backend.OPENVINO:
+        input_info = {
+            name: value.shape
+            for name, value in flatten_model_inputs.items()
+        }
+        deploy_cfg['backend_config']['model_inputs'] = [
+            dict(opt_shapes=input_info)
+        ]
+    backend_files = to_backend(
+        backend.value, [ir_file_path],
+        work_dir=work_dir,
+        deploy_cfg=deploy_cfg,
+        device=device)
+    backend_feats = model_inputs
+
+    if backend == Backend.TORCHSCRIPT:
+        backend_feats = [v for _, v in model_inputs.items()]

     from mmdeploy.codebase.base import BaseBackendModel
     backend_model = BaseBackendModel._build_wrapper(
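Sketch of how a test uses the refactored helpers, relying on the backend manager's `is_available(with_custom_ops=...)` underneath (assuming the helpers live in `mmdeploy.utils.test` as in the file above):

```Python
from mmdeploy.utils import Backend
from mmdeploy.utils.test import backend_checker, check_backend


@backend_checker(Backend.TENSORRT, require_plugin=True)
def test_decorated_case():
    ...


def test_skipping_case():
    # skips the test when onnxruntime or its custom ops are missing
    check_backend(Backend.ONNXRUNTIME, require_plugin=True)
    ...
```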

View File

@ -1,2 +1,4 @@
 -r requirements/codebases.txt
 -r requirements/backends.txt
+h5py
+tqdm

View File

@ -1,6 +1,5 @@
 aenum
 grpcio
-h5py
 matplotlib
 mmengine
 multiprocess
@ -10,4 +9,3 @@ prettytable
 protobuf<=3.20.1
 six
 terminaltables
-tqdm

View File

@ -3,7 +3,6 @@ import os.path as osp
 import tempfile
 from multiprocessing import Process

-import h5py
 from mmengine import Config

 from mmdeploy.apis import create_calib_input_data
@ -171,6 +170,7 @@ def get_model_cfg():
 def run_test_create_calib_end2end():
+    import h5py
     model_cfg = get_model_cfg()
     deploy_cfg = get_end2end_deploy_cfg()
     create_calib_input_data(
@ -203,6 +203,7 @@ def test_create_calib_end2end():
 def run_test_create_calib_parittion():
+    import h5py
     model_cfg = get_model_cfg()
     deploy_cfg = get_partition_deploy_cfg()
     create_calib_input_data(

View File

@ -145,96 +145,40 @@ def onnx2backend(backend, onnx_file):
 def create_wrapper(backend, model_files):
-    if backend == Backend.TENSORRT:
-        from mmdeploy.backend.tensorrt import TRTWrapper
-        trt_model = TRTWrapper(model_files, output_names)
-        return trt_model
-    elif backend == Backend.ONNXRUNTIME:
-        from mmdeploy.backend.onnxruntime import ORTWrapper
-        ort_model = ORTWrapper(model_files, 'cpu', output_names)
-        return ort_model
-    elif backend == Backend.PPLNN:
-        from mmdeploy.backend.pplnn import PPLNNWrapper
-        onnx_file, algo_file = model_files
-        pplnn_model = PPLNNWrapper(onnx_file, algo_file, 'cpu', output_names)
-        return pplnn_model
-    elif backend == Backend.NCNN:
-        from mmdeploy.backend.ncnn import NCNNWrapper
-        param_file, bin_file = model_files
-        ncnn_model = NCNNWrapper(param_file, bin_file, output_names)
-        return ncnn_model
-    elif backend == Backend.OPENVINO:
-        from mmdeploy.backend.openvino import OpenVINOWrapper
-        openvino_model = OpenVINOWrapper(model_files, output_names)
-        return openvino_model
-    elif backend == Backend.TORCHSCRIPT:
-        from mmdeploy.backend.torchscript import TorchscriptWrapper
-        torchscript_model = TorchscriptWrapper(
-            model_files, input_names=input_names, output_names=output_names)
-        return torchscript_model
-    elif backend == Backend.RKNN:
-        from mmdeploy.backend.rknn import RKNNWrapper
-        rknn_model = RKNNWrapper(
-            model_files,
-            common_config=dict(target_platform=target_platform),
-            input_names=input_names,
-            output_names=output_names)
-        return rknn_model
-    elif backend == Backend.ASCEND:
-        from mmdeploy.backend.ascend import AscendWrapper
-        ascend_model = AscendWrapper(model_files)
-        return ascend_model
-    elif backend == Backend.TVM:
-        from mmdeploy.backend.tvm import TVMWrapper
-        tvm_model = TVMWrapper(model_files, output_names=output_names)
-        return tvm_model
-    else:
-        raise NotImplementedError(f'Unknown backend type: {backend.value}')
+    from mmdeploy.backend.base import get_backend_manager
+    backend_mgr = get_backend_manager(backend.value)
+    deploy_cfg = None
+    if isinstance(model_files, str):
+        model_files = [model_files]
+    elif backend == Backend.RKNN:
+        deploy_cfg = dict(
+            backend_config=dict(
+                common_config=dict(target_platform=target_platform)))
+    return backend_mgr.build_wrapper(
+        model_files,
+        input_names=input_names,
+        output_names=output_names,
+        deploy_cfg=deploy_cfg)


 def run_wrapper(backend, wrapper, input):
     if backend == Backend.TENSORRT:
         input = input.cuda()
-        results = wrapper({'input': input})['output']
-        results = results.detach().cpu()
-        return results
-    elif backend == Backend.ONNXRUNTIME:
-        results = wrapper({'input': input})['output']
-        results = results.detach().cpu()
-        return results
-    elif backend == Backend.PPLNN:
-        results = wrapper({'input': input})['output']
-        results = results.detach().cpu()
-        return results
-    elif backend == Backend.NCNN:
-        input = input.float()
-        results = wrapper({'input': input})['output']
-        results = results.detach().cpu()
-        return results
-    elif backend == Backend.OPENVINO:
-        results = wrapper({'input': input})['output']
-        results = results.detach().cpu()
-        return results
-    elif backend == Backend.TORCHSCRIPT:
-        results = wrapper({'input': input})['output']
-        return results
-    elif backend == Backend.RKNN:
-        results = wrapper({'input': input})
-    elif backend == Backend.ASCEND:
-        results = wrapper({'input': input})['output']
-        return results
-    elif backend == Backend.TVM:
-        results = wrapper({'input': input})['output']
-        return results
-    else:
-        raise NotImplementedError(f'Unknown backend type: {backend.value}')
+
+    results = wrapper({'input': input})
+
+    if backend != Backend.RKNN:
+        results = results['output']
+
+    results = results.detach().cpu()
+
+    return results


-ALL_BACKEND = [
-    Backend.TENSORRT, Backend.ONNXRUNTIME, Backend.PPLNN, Backend.NCNN,
-    Backend.OPENVINO, Backend.TORCHSCRIPT, Backend.ASCEND, Backend.RKNN,
-    Backend.TVM
-]
+ALL_BACKEND = list(Backend)
+ALL_BACKEND.remove(Backend.DEFAULT)
+ALL_BACKEND.remove(Backend.PYTORCH)
+ALL_BACKEND.remove(Backend.SDK)


 @pytest.mark.parametrize('backend', ALL_BACKEND)

View File

@ -710,6 +710,7 @@ def test_forward_of_base_detector(model_cfg_path, backend):
                 pre_top_k=-1,
                 keep_top_k=100,
                 background_label_id=-1,
+                export_postprocess_mask=False,
             ))))

     model_cfg = Config(dict(model=mmengine.load(model_cfg_path)))
     model_cfg.model = _replace_r50_with_r18(model_cfg.model)

View File

@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Sequence
+
 import pytest
 import torch
 from mmengine import Config
@ -20,7 +22,7 @@ from mmdeploy.codebase.mmdet.deploy.object_detection_model import End2EndModel
 def assert_det_results(results, module_name: str = 'model'):
     assert results is not None, f'failed to get output using {module_name}'
-    assert isinstance(results, tuple)
+    assert isinstance(results, Sequence)
     assert len(results) == 2
     assert results[0].shape[0] == results[1].shape[0]
     assert results[0].shape[1] == results[1].shape[1]
@ -28,7 +30,7 @@ def assert_det_results(results, module_name: str = 'model'):
 def assert_forward_results(results, module_name: str = 'model'):
     assert results is not None, f'failed to get output using {module_name}'
-    assert isinstance(results, list)
+    assert isinstance(results, Sequence)
     assert len(results) == 1
     assert isinstance(results[0].pred_instances, InstanceData)
     assert results[0].pred_instances.bboxes.shape[-1] == 4
assert results[0].pred_instances.bboxes.shape[-1] == 4 assert results[0].pred_instances.bboxes.shape[-1] == 4

View File

@ -96,7 +96,7 @@ def get_single_roi_extractor():
 @pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
 def test_rotated_single_roi_extractor(backend_type: Backend):
-    check_backend(backend_type)
+    check_backend(backend_type, True)

     single_roi_extractor = get_single_roi_extractor()
     output_names = ['roi_feat']
@ -226,7 +226,7 @@ def test_oriented_rpn_head__predict_by_feat(backend_type: Backend):
 @pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
 def test_gv_ratio_roi_head__predict_bbox(backend_type: Backend):
-    check_backend(backend_type)
+    check_backend(backend_type, True)
     from mmrotate.models.roi_heads import GVRatioRoIHead
     output_names = ['dets', 'labels']
     deploy_cfg = Config(
@ -369,7 +369,7 @@ def get_rotated_rtmdet_head_model():
 @pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
 def test_rotated_rtmdet_head_predict_by_feat(backend_type: Backend):
     """Test predict_by_feat rewrite of RTMDet-R."""
-    check_backend(backend_type)
+    check_backend(backend_type, require_plugin=True)
     rtm_r_head = get_rotated_rtmdet_head_model()
     rtm_r_head.cpu().eval()
     s = 128

View File

@ -119,7 +119,7 @@ def test_upconvblock_forward(backend, is_dynamic_shape):
         dict(
             backend_config=dict(type=backend.value),
             onnx_config=dict(
-                input_names=['skip', 'x'],
+                input_names=['x', 'skip'],
                 output_names=['output'],
                 dynamic_axes=dynamic_axes),
             codebase_config=dict(

View File

@ -31,7 +31,7 @@ def generate_mmseg_deploy_config(backend='onnxruntime'):
             keep_initializers_as_inputs=False,
             opset_version=11,
             input_shape=None,
-            input_names=['input'],
+            input_names=['inputs'],
             output_names=['output'])))
     return deploy_cfg

View File

@ -4,8 +4,7 @@ from mmcv.utils import collect_env as collect_base_env
 from mmengine.utils import get_git_hash

 import mmdeploy
-from mmdeploy.utils import (get_backend_version, get_codebase_version,
-                            get_root_logger)
+from mmdeploy.utils import get_codebase_version, get_root_logger


 def collect_env():
@ -17,44 +16,16 @@ def collect_env():


 def check_backend():
-    backend_versions = get_backend_version()
-    ort_version = backend_versions['onnxruntime']
-    trt_version = backend_versions['tensorrt']
-    ncnn_version = backend_versions['ncnn']
-    tvm_version = backend_versions['tvm']
-
-    import mmdeploy.apis.onnxruntime as ort_apis
-    logger = get_root_logger()
-    logger.info(f'onnxruntime: {ort_version}\tops_is_avaliable : '
-                f'{ort_apis.is_custom_ops_available()}')
-
-    import mmdeploy.apis.tensorrt as trt_apis
-    logger.info(f'tensorrt: {trt_version}\tops_is_avaliable : '
-                f'{trt_apis.is_custom_ops_available()}')
-
-    import mmdeploy.apis.ncnn as ncnn_apis
-    logger.info(f'ncnn: {ncnn_version}\tops_is_avaliable : '
-                f'{ncnn_apis.is_custom_ops_available()}')
-
-    logger.info(f'tvm: {tvm_version}')
-
-    import mmdeploy.apis.pplnn as pplnn_apis
-    logger.info(f'pplnn_is_avaliable: {pplnn_apis.is_available()}')
-
-    import mmdeploy.apis.openvino as openvino_apis
-    logger.info(f'openvino_is_avaliable: {openvino_apis.is_available()}')
-
-    import mmdeploy.apis.snpe as snpe_apis
-    logger.info(f'snpe_is_available: {snpe_apis.is_available()}')
-
-    import mmdeploy.apis.ascend as ascend_apis
-    logger.info(f'ascend_is_available: {ascend_apis.is_available()}')
-
-    import mmdeploy.apis.coreml as coreml_apis
-    logger.info(f'coreml_is_available: {coreml_apis.is_available()}')
-
-    import mmdeploy.apis.rknn as rknn_apis
-    logger.info(f'rknn_is_avaliable: {rknn_apis.is_available()}')
+    from mmdeploy.backend.base import get_backend_manager
+    from mmdeploy.utils import Backend
+    exclude_backend_lists = [Backend.DEFAULT, Backend.PYTORCH, Backend.SDK]
+    backend_lists = [
+        backend for backend in Backend if backend not in exclude_backend_lists
+    ]
+
+    for backend in backend_lists:
+        backend_mgr = get_backend_manager(backend.value)
+        backend_mgr.check_env(logger.info)


 def check_codebase():

View File

@ -13,11 +13,11 @@ from mmdeploy.apis import (create_calib_input_data, extract_model,
                            get_predefined_partition_cfg, torch2onnx,
                            torch2torchscript, visualize_model)
 from mmdeploy.apis.core import PIPELINE_MANAGER
+from mmdeploy.apis.utils import to_backend
 from mmdeploy.backend.sdk.export_info import export2SDK
 from mmdeploy.utils import (IR, Backend, get_backend, get_calib_filename,
-                            get_ir_config, get_model_inputs,
-                            get_partition_config, get_root_logger, load_config,
-                            target_wrapper)
+                            get_ir_config, get_partition_config,
+                            get_root_logger, load_config, target_wrapper)


 def parse_args():
@ -193,269 +193,79 @@ def main():
backend_files = ir_files backend_files = ir_files
# convert backend # convert backend
backend = get_backend(deploy_cfg) backend = get_backend(deploy_cfg)
if backend == Backend.TENSORRT:
model_params = get_model_inputs(deploy_cfg)
assert len(model_params) == len(ir_files)
from mmdeploy.apis.tensorrt import is_available as trt_is_available # preprocess deploy_cfg
assert trt_is_available( if backend == Backend.RKNN:
), 'TensorRT is not available,' \ # TODO: Add this to task_processor in the future
+ ' please install TensorRT and build TensorRT custom ops first.' import tempfile
from mmdeploy.apis.tensorrt import onnx2tensorrt from mmdeploy.utils import (get_common_config, get_normalization,
PIPELINE_MANAGER.enable_multiprocess(True, [onnx2tensorrt]) get_quantization_config,
PIPELINE_MANAGER.set_log_level(log_level, [onnx2tensorrt]) get_rknn_quantization)
quantization_cfg = get_quantization_config(deploy_cfg)
backend_files = [] common_params = get_common_config(deploy_cfg)
for model_id, model_param, onnx_path in zip( if get_rknn_quantization(deploy_cfg) is True:
range(len(ir_files)), model_params, ir_files): transform = get_normalization(model_cfg)
onnx_name = osp.splitext(osp.split(onnx_path)[1])[0] common_params.update(
save_file = model_param.get('save_file', onnx_name + '.engine') dict(
mean_values=[transform['mean']],
partition_type = 'end2end' if partition_cfgs is None \ std_values=[transform['std']]))
else onnx_name
onnx2tensorrt(
args.work_dir,
save_file,
model_id,
deploy_cfg_path,
onnx_path,
device=args.device,
partition_type=partition_type)
backend_files.append(osp.join(args.work_dir, save_file))
elif backend == Backend.NCNN:
from mmdeploy.apis.ncnn import is_available as is_available_ncnn
if not is_available_ncnn():
logger.error('ncnn support is not available, please make sure \
1) `mmdeploy_onnx2ncnn` existed in `PATH` \
2) python import ncnn success')
exit(1)
import mmdeploy.apis.ncnn as ncnn_api
from mmdeploy.apis.ncnn import get_output_model_file
PIPELINE_MANAGER.set_log_level(log_level, [ncnn_api.from_onnx])
backend_files = []
for onnx_path in ir_files:
model_param_path, model_bin_path = get_output_model_file(
onnx_path, args.work_dir)
onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
ncnn_api.from_onnx(onnx_path, osp.join(args.work_dir, onnx_name))
if quant:
from onnx2ncnn_quant_table import get_table
from mmdeploy.apis.ncnn import get_quant_model_file, ncnn2int8
deploy_cfg, model_cfg = load_config(deploy_cfg_path,
model_cfg_path)
quant_onnx, quant_table, quant_param, quant_bin = get_quant_model_file( # noqa: E501
onnx_path, args.work_dir)
create_process(
'ncnn quant table',
target=get_table,
args=(onnx_path, deploy_cfg, model_cfg, quant_onnx,
quant_table, quant_image_dir, args.device),
kwargs=dict(),
ret_value=ret_value)
create_process(
'ncnn_int8',
target=ncnn2int8,
args=(model_param_path, model_bin_path, quant_table,
quant_param, quant_bin),
kwargs=dict(),
ret_value=ret_value)
backend_files += [quant_param, quant_bin]
else:
backend_files += [model_param_path, model_bin_path]
elif backend == Backend.SNPE:
from mmdeploy.apis.snpe import is_available as is_available
if not is_available():
logger.error('snpe support is not available, please check \
1) `snpe-onnx-to-dlc` existed in `PATH` 2) snpe only support \
ubuntu18.04')
exit(1)
import mmdeploy.apis.snpe as snpe_api
from mmdeploy.apis.snpe import get_env_key, get_output_model_file
if get_env_key() not in os.environ:
os.environ[get_env_key()] = args.uri
PIPELINE_MANAGER.set_log_level(log_level, [snpe_api.from_onnx])
backend_files = []
for onnx_path in ir_files:
dlc_path = get_output_model_file(onnx_path, args.work_dir)
onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
snpe_api.from_onnx(onnx_path, osp.join(args.work_dir, onnx_name))
backend_files = [dlc_path]
elif backend == Backend.OPENVINO:
from mmdeploy.apis.openvino import \
is_available as is_available_openvino
assert is_available_openvino(), \
'OpenVINO is not available, please install OpenVINO first.'
import mmdeploy.apis.openvino as openvino_api
from mmdeploy.apis.openvino import (get_input_info_from_cfg,
get_mo_options_from_cfg,
get_output_model_file)
PIPELINE_MANAGER.set_log_level(log_level, [openvino_api.from_onnx])
openvino_files = []
for onnx_path in ir_files:
model_xml_path = get_output_model_file(onnx_path, args.work_dir)
input_info = get_input_info_from_cfg(deploy_cfg)
output_names = get_ir_config(deploy_cfg).output_names
mo_options = get_mo_options_from_cfg(deploy_cfg)
openvino_api.from_onnx(onnx_path, args.work_dir, input_info,
output_names, mo_options)
openvino_files.append(model_xml_path)
backend_files = openvino_files
elif backend == Backend.PPLNN:
from mmdeploy.apis.pplnn import is_available as is_available_pplnn
assert is_available_pplnn(), \
'PPLNN is not available, please install PPLNN first.'
from mmdeploy.apis.pplnn import from_onnx
pplnn_pipeline_funcs = [from_onnx]
PIPELINE_MANAGER.set_log_level(log_level, pplnn_pipeline_funcs)
pplnn_files = []
for onnx_path in ir_files:
algo_file = onnx_path.replace('.onnx', '.json')
model_inputs = get_model_inputs(deploy_cfg)
assert 'opt_shape' in model_inputs, 'Expect opt_shape ' \
'in deploy config for PPLNN'
# PPLNN accepts only 1 input shape for optimization,
# may get changed in the future
input_shapes = [model_inputs.opt_shape]
algo_prefix = osp.splitext(algo_file)[0]
from_onnx(
onnx_path,
algo_prefix,
device=args.device,
input_shapes=input_shapes)
pplnn_files += [onnx_path, algo_file]
backend_files = pplnn_files
elif backend == Backend.RKNN:
from mmdeploy.apis.rknn import is_available as rknn_is_available
assert rknn_is_available(
), 'RKNN is not available, please install RKNN first.'
from mmdeploy.apis.rknn import onnx2rknn
PIPELINE_MANAGER.enable_multiprocess(True, [onnx2rknn])
PIPELINE_MANAGER.set_log_level(logging.INFO, [onnx2rknn])
backend_files = []
for model_id, onnx_path in zip(range(len(ir_files)), ir_files):
pre_fix_name = osp.splitext(osp.split(onnx_path)[1])[0]
output_path = osp.join(args.work_dir, pre_fix_name + '.rknn')
import tempfile
dataset_file = tempfile.NamedTemporaryFile(suffix='.txt').name
with open(dataset_file, 'w') as f:
f.writelines([osp.abspath(args.img)])
onnx2rknn(
onnx_path,
output_path,
deploy_cfg_path,
model_cfg_path,
dataset_file=dataset_file)
backend_files.append(output_path)
elif backend == Backend.ASCEND:
from mmdeploy.apis.ascend import from_onnx
ascend_pipeline_funcs = [from_onnx]
PIPELINE_MANAGER.set_log_level(log_level, ascend_pipeline_funcs)
model_inputs = get_model_inputs(deploy_cfg)
om_files = []
for model_id, onnx_path in enumerate(ir_files):
om_path = osp.splitext(onnx_path)[0] + '.om'
from_onnx(onnx_path, args.work_dir, model_inputs[model_id])
om_files.append(om_path)
backend_files = om_files
        if args.dump_info:
            from mmdeploy.backend.ascend import update_sdk_pipeline
            update_sdk_pipeline(args.work_dir)
    elif backend == Backend.COREML:
        from mmdeploy.apis.coreml import from_torchscript, get_model_suffix
        coreml_pipeline_funcs = [from_torchscript]
        PIPELINE_MANAGER.set_log_level(log_level, coreml_pipeline_funcs)
        model_inputs = get_model_inputs(deploy_cfg)
        coreml_files = []
        for model_id, torchscript_path in enumerate(ir_files):
            torchscript_name = osp.splitext(osp.split(torchscript_path)[1])[0]
            output_file_prefix = osp.join(args.work_dir, torchscript_name)
            convert_to = deploy_cfg.backend_config.convert_to
            from_torchscript(torchscript_path, output_file_prefix,
                             ir_config.input_names, ir_config.output_names,
                             model_inputs[model_id].input_shapes, convert_to)
            suffix = get_model_suffix(convert_to)
            coreml_files.append(output_file_prefix + suffix)
        backend_files = coreml_files
    elif backend == Backend.TVM:
        import copy

        from mmdeploy.apis.tvm import from_onnx, get_library_ext
        PIPELINE_MANAGER.set_log_level(log_level, [from_onnx])
        model_inputs = get_model_inputs(deploy_cfg)

        if args.device.startswith('cuda'):
            target = 'cuda'
        else:
            target = 'llvm'

        lib_ext = get_library_ext()

        tvm_files = []
        for model_id, onnx_path in enumerate(ir_files):
            model_input = copy.deepcopy(model_inputs[model_id])
            use_vm = model_input.get('use_vm', False)
            if 'target' not in model_input['tuner']:
                model_input['tuner']['target'] = target
            lib_path = osp.splitext(onnx_path)[0] + lib_ext
            code_path = osp.splitext(
                onnx_path)[0] + '.code' if use_vm else None
            model_input['output_file'] = lib_path
            model_input['onnx_model'] = onnx_path
            model_input['bytecode_file'] = code_path

            # create calibration dataset
            if 'qconfig' in model_input:
                calib_path = osp.join(args.work_dir, calib_filename)
                from mmdeploy.backend.tvm import HDF5Dataset
                partition_type = 'end2end' if partition_cfgs is None \
                    else onnx_name
                dataset = HDF5Dataset(
                    calib_path,
                    model_input['shape'],
                    model_type=partition_type,
                    device=target)
                model_input['dataset'] = dataset()

            from_onnx(**model_input)
            tvm_files += [lib_path, code_path]

        backend_files = tvm_files

        dataset_file = tempfile.NamedTemporaryFile(suffix='.txt').name
        with open(dataset_file, 'w') as f:
            f.writelines([osp.abspath(args.img)])
        quantization_cfg.setdefault('dataset', dataset_file)

    if backend == Backend.ASCEND:
        # TODO: Add this to backend manager in the future
        if args.dump_info:
            from mmdeploy.backend.ascend import update_sdk_pipeline
            update_sdk_pipeline(args.work_dir)

    # convert to backend
    PIPELINE_MANAGER.set_log_level(log_level, [to_backend])
    if backend == Backend.TENSORRT:
        PIPELINE_MANAGER.enable_multiprocess(True, [to_backend])
    backend_files = to_backend(
        backend,
        ir_files,
        work_dir=args.work_dir,
        deploy_cfg=deploy_cfg,
        log_level=log_level,
        device=args.device,
        uri=args.uri)

    # ncnn quantization
    if backend == Backend.NCNN and quant:
        from onnx2ncnn_quant_table import get_table

        from mmdeploy.apis.ncnn import get_quant_model_file, ncnn2int8
        model_param_paths = backend_files[::2]
        model_bin_paths = backend_files[1::2]
        backend_files = []
        for onnx_path, model_param_path, model_bin_path in zip(
                ir_files, model_param_paths, model_bin_paths):

            deploy_cfg, model_cfg = load_config(deploy_cfg_path,
                                                model_cfg_path)
            quant_onnx, quant_table, quant_param, quant_bin = get_quant_model_file(  # noqa: E501
                onnx_path, args.work_dir)

            create_process(
                'ncnn quant table',
                target=get_table,
                args=(onnx_path, deploy_cfg, model_cfg, quant_onnx,
                      quant_table, quant_image_dir, args.device),
                kwargs=dict(),
                ret_value=ret_value)

            create_process(
                'ncnn_int8',
                target=ncnn2int8,
                args=(model_param_path, model_bin_path, quant_table,
                      quant_param, quant_bin),
                kwargs=dict(),
                ret_value=ret_value)
            backend_files += [quant_param, quant_bin]

    if args.test_img is None:
        args.test_img = args.img
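For readers following the refactor, a minimal sketch (not part of this commit) of driving the new conversion entry point directly. It mirrors the `to_backend(...)` call above; the `mmdeploy.apis.utils` import location, the config path, and the omission of `uri` are assumptions for illustration only.

```Python
# Hypothetical standalone use of the backend-manager based conversion.
# The import location and the file paths below are assumptions, not taken
# from this commit; the keyword arguments mirror the call in tools/deploy.py.
import logging

from mmdeploy.apis.utils import to_backend  # assumed import path
from mmdeploy.utils import get_backend, load_config

deploy_cfg = load_config('configs/some_backend_config.py')[0]  # placeholder
ir_files = ['work_dir/end2end.onnx']  # produced by the IR export step

backend = get_backend(deploy_cfg)
# Dispatch goes through the backend manager registered for `backend`;
# the returned list contains the backend model files.
backend_files = to_backend(
    backend,
    ir_files,
    work_dir='work_dir',
    deploy_cfg=deploy_cfg,
    log_level=logging.INFO,
    device='cpu')
print(backend_files)
```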