[Refactor] Add 'to_backend' in BackendManager (#1522)
* Refactor to backend
* export_postprocess_mask = False as default
* update zh_cn docs
* solve comment
* fix comment

pull/1574/head
parent 26d71ce0a8
commit 5285caf30a
|
@ -123,32 +123,20 @@ The backends in MMDeploy must support the ONNX. The backend loads the ".onnx" fi
|
|||
__all__ += ['onnx2ncnn', 'get_output_model_file']
|
||||
```
|
||||
|
||||
Then add the conversion code to `tools/deploy.py` using these APIs if necessary.
|
||||
Create a backend manager class that derives from `BaseBackendManager` and implement its `to_backend` class method.
|
||||
|
||||
**Example:**
|
||||
|
||||
```Python
|
||||
# tools/deploy.py
|
||||
# ...
|
||||
elif backend == Backend.NCNN:
|
||||
from mmdeploy.apis.ncnn import is_available as is_available_ncnn
|
||||
|
||||
if not is_available_ncnn():
|
||||
logging.error('ncnn support is not available.')
|
||||
exit(-1)
|
||||
|
||||
from mmdeploy.apis.ncnn import onnx2ncnn, get_output_model_file
|
||||
|
||||
backend_files = []
|
||||
for onnx_path in onnx_files:
|
||||
create_process(
|
||||
f'onnx2ncnn with {onnx_path}',
|
||||
target=onnx2ncnn,
|
||||
args=(onnx_path, args.work_dir),
|
||||
kwargs=dict(),
|
||||
ret_value=ret_value)
|
||||
backend_files += get_output_model_file(onnx_path, args.work_dir)
|
||||
# ...
|
||||
@classmethod
|
||||
def to_backend(cls,
|
||||
ir_files: Sequence[str],
|
||||
deploy_cfg: Any,
|
||||
work_dir: str,
|
||||
log_level: int = logging.INFO,
|
||||
device: str = 'cpu',
|
||||
**kwargs) -> Sequence[str]:
|
||||
return ir_files
|
||||
```
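With a manager implementing `to_backend`, the per-backend branches that used to live in `tools/deploy.py` (such as the removed ncnn block above) collapse into a single dispatch call. A minimal sketch of that call, following the refactored `tools/deploy.py` in this commit (the surrounding variables `backend`, `ir_files`, `deploy_cfg`, `log_level` and `args` are assumed to be prepared earlier in the script):

```Python
# tools/deploy.py (after the refactor): conversion is delegated to the
# registered backend manager instead of a per-backend if/elif chain.
from mmdeploy.apis.utils import to_backend

backend_files = to_backend(
    backend,              # Backend enum member or backend name
    ir_files,             # intermediate representation files, e.g. ONNX
    work_dir=args.work_dir,
    deploy_cfg=deploy_cfg,
    log_level=log_level,
    device=args.device,
    uri=args.uri)         # extra kwargs are forwarded to the manager
```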
|
||||
|
||||
6. Convert the models of OpenMMLab to backends (if necessary) and inference on backend engine. If you find some incompatible operators when testing, you can try to rewrite the original model for the backend following the [rewriter tutorial](support_new_model.md) or add custom operators.
|
||||
|
|
|
@ -106,7 +106,7 @@ MMDeploy 中的后端必须支持 ONNX,因此后端能直接加载“.onnx”
|
|||
call([onnx2ncnn_path, onnx_path, save_param, save_bin])\
|
||||
```
|
||||
|
||||
5. 在 `mmdeploy/apis` 中创建新后端库并声明对应 APIs
|
||||
从 BackendManager 派生类,实现 `to_backend` 类方法。
|
||||
|
||||
**例子**
|
||||
|
||||
|
@ -128,32 +128,20 @@ MMDeploy 中的后端必须支持 ONNX,因此后端能直接加载“.onnx”
|
|||
**例子**
|
||||
|
||||
```Python
|
||||
# tools/deploy.py
|
||||
# ...
|
||||
elif backend == Backend.NCNN:
|
||||
from mmdeploy.apis.ncnn import is_available as is_available_ncnn
|
||||
|
||||
if not is_available_ncnn():
|
||||
logging.error('ncnn support is not available.')
|
||||
exit(-1)
|
||||
|
||||
from mmdeploy.apis.ncnn import onnx2ncnn, get_output_model_file
|
||||
|
||||
backend_files = []
|
||||
for onnx_path in onnx_files:
|
||||
create_process(
|
||||
f'mmdeploy_onnx2ncnn with {onnx_path}',
|
||||
target=onnx2ncnn,
|
||||
args=(onnx_path, args.work_dir),
|
||||
kwargs=dict(),
|
||||
ret_value=ret_value)
|
||||
backend_files += get_output_model_file(onnx_path, args.work_dir)
|
||||
# ...
|
||||
@classmethod
|
||||
def to_backend(cls,
|
||||
ir_files: Sequence[str],
|
||||
deploy_cfg: Any,
|
||||
work_dir: str,
|
||||
log_level: int = logging.INFO,
|
||||
device: str = 'cpu',
|
||||
**kwargs) -> Sequence[str]:
|
||||
return ir_files
|
||||
```
|
||||
|
||||
6. 将 OpenMMLab 的模型转换后(如有必要)并在后端引擎上进行推理。如果在测试时发现一些不兼容的算子,可以尝试按照[重写器教程](support_new_model.md)为后端重写原始模型或添加自定义算子。
|
||||
5. 将 OpenMMLab 的模型转换后(如有必要)并在后端引擎上进行推理。如果在测试时发现一些不兼容的算子,可以尝试按照[重写器教程](support_new_model.md)为后端重写原始模型或添加自定义算子。
|
||||
|
||||
7. 为新后端引擎代码添加相关注释和单元测试:).
|
||||
6. 为新后端引擎代码添加相关注释和单元测试:).
|
||||
|
||||
## 支持后端推理
|
||||
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .calibration import create_calib_input_data
|
||||
from .utils import build_task_processor, get_predefined_partition_cfg
|
||||
from .utils import (build_task_processor, get_predefined_partition_cfg,
|
||||
to_backend)
|
||||
|
||||
__all__ = [
|
||||
'create_calib_input_data', 'build_task_processor',
|
||||
'get_predefined_partition_cfg'
|
||||
'get_predefined_partition_cfg', 'to_backend'
|
||||
]
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
import mmcv
|
||||
|
||||
from mmdeploy.codebase import BaseTask, get_codebase_class, import_codebase
|
||||
from mmdeploy.utils import (get_backend, get_codebase, get_task_type,
|
||||
parse_device_id)
|
||||
from ..core import PIPELINE_MANAGER
|
||||
|
||||
|
||||
def check_backend_device(deploy_cfg: mmcv.Config, device: str):
|
||||
|
@ -62,3 +66,36 @@ def get_predefined_partition_cfg(deploy_cfg: mmcv.Config, partition_type: str):
|
|||
codebase = get_codebase_class(codebase_type)
|
||||
task_processor_class = codebase.get_task_class(task)
|
||||
return task_processor_class.get_partition_cfg(partition_type)
|
||||
|
||||
|
||||
@PIPELINE_MANAGER.register_pipeline()
|
||||
def to_backend(backend_name: str,
|
||||
ir_files: Sequence[str],
|
||||
work_dir: str,
|
||||
deploy_cfg: Optional[Any] = None,
|
||||
log_level: int = logging.INFO,
|
||||
device: str = 'cpu',
|
||||
**kwargs) -> Sequence[str]:
|
||||
"""Convert intermediate representation to given backend.
|
||||
|
||||
Args:
|
||||
backend_name (str): The name of the backend.
|
||||
ir_files (Sequence[str]): The intermediate representation files.
|
||||
work_dir (str): The work directory, backend files and logs should
|
||||
be saved in this directory.
|
||||
deploy_cfg (Any): The deploy config.
|
||||
log_level (int, optional): The log level. Defaults to logging.INFO.
|
||||
device (str, optional): The device type. Defaults to 'cpu'.
|
||||
|
||||
Returns:
|
||||
Sequence[str]: Backend files.
|
||||
"""
|
||||
from mmdeploy.backend.base import get_backend_manager
|
||||
backend_mgr = get_backend_manager(backend_name)
|
||||
return backend_mgr.to_backend(
|
||||
ir_files=ir_files,
|
||||
work_dir=work_dir,
|
||||
deploy_cfg=deploy_cfg,
|
||||
log_level=log_level,
|
||||
device=device,
|
||||
**kwargs)
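For reference, a direct call of this pipeline function mirrors the test helper refactored later in this commit. A minimal sketch (the file paths are illustrative placeholders; for ONNX Runtime the manager simply returns the input files unchanged):

```python
from mmdeploy.apis.utils import to_backend

# Placeholder paths; ONNX Runtime consumes the ONNX file directly, so the
# returned backend files equal the input IR files.
backend_files = to_backend(
    'onnxruntime',
    ['work_dir/end2end.onnx'],
    work_dir='work_dir',
    deploy_cfg=None,
    device='cpu')
```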
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import os.path as osp
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
from ..base import BACKEND_MANAGERS, BaseBackendManager
|
||||
|
@ -29,3 +31,38 @@ class AscendManager(BaseBackendManager):
|
|||
"""
|
||||
from .wrapper import AscendWrapper
|
||||
return AscendWrapper(model=backend_files[0], device=device)
|
||||
|
||||
@classmethod
|
||||
def to_backend(cls,
|
||||
ir_files: Sequence[str],
|
||||
work_dir: str,
|
||||
deploy_cfg: Any,
|
||||
log_level: int = logging.INFO,
|
||||
device: str = 'cpu',
|
||||
**kwargs) -> Sequence[str]:
|
||||
"""Convert intermediate representation to given backend.
|
||||
|
||||
Args:
|
||||
ir_files (Sequence[str]): The intermediate representation files.
|
||||
work_dir (str): The work directory, backend files and logs should
|
||||
be saved in this directory.
|
||||
deploy_cfg (Any): The deploy config.
|
||||
log_level (int, optional): The log level. Defaults to logging.INFO.
|
||||
device (str, optional): The device type. Defaults to 'cpu'.
|
||||
|
||||
Returns:
|
||||
Sequence[str]: Backend files.
|
||||
"""
|
||||
from mmdeploy.utils import get_model_inputs
|
||||
from .onnx2ascend import from_onnx
|
||||
|
||||
model_inputs = get_model_inputs(deploy_cfg)
|
||||
|
||||
om_files = []
|
||||
for model_id, onnx_path in enumerate(ir_files):
|
||||
om_path = osp.splitext(onnx_path)[0] + '.om'
|
||||
from_onnx(onnx_path, work_dir, model_inputs[model_id])
|
||||
om_files.append(om_path)
|
||||
backend_files = om_files
|
||||
|
||||
return backend_files
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .backend_manager import BACKEND_MANAGERS, BaseBackendManager
|
||||
from .backend_manager import (BACKEND_MANAGERS, BaseBackendManager,
|
||||
get_backend_manager)
|
||||
from .backend_wrapper_registry import (BACKEND_WRAPPER, get_backend_file_count,
|
||||
get_backend_wrapper_class)
|
||||
from .base_wrapper import BaseWrapper
|
||||
|
||||
__all__ = [
|
||||
'BACKEND_MANAGERS', 'BaseBackendManager', 'BaseWrapper', 'BACKEND_WRAPPER',
|
||||
'get_backend_wrapper_class', 'get_backend_file_count'
|
||||
'BACKEND_MANAGERS', 'BaseBackendManager', 'get_backend_manager',
|
||||
'BaseWrapper', 'BACKEND_WRAPPER', 'get_backend_wrapper_class',
|
||||
'get_backend_file_count'
|
||||
]
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import importlib
|
||||
import logging
|
||||
from abc import ABCMeta
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
|
@ -28,7 +29,31 @@ class BaseBackendManager(metaclass=ABCMeta):
|
|||
to None.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f'build_wrapper has not been implemented for {cls}')
|
||||
f'build_wrapper has not been implemented for `{cls.__name__}`')
|
||||
|
||||
@classmethod
|
||||
def to_backend(cls,
|
||||
ir_files: Sequence[str],
|
||||
work_dir: str,
|
||||
deploy_cfg: Any,
|
||||
log_level: int = logging.INFO,
|
||||
device: str = 'cpu',
|
||||
**kwargs) -> Sequence[str]:
|
||||
"""Convert intermediate representation to given backend.
|
||||
|
||||
Args:
|
||||
ir_files (Sequence[str]): The intermediate representation files.
|
||||
work_dir (str): The work directory, backend files and logs should
|
||||
be saved in this directory.
|
||||
deploy_cfg (Any): The deploy config.
|
||||
log_level (int, optional): The log level. Defaults to logging.INFO.
|
||||
device (str, optional): The device type. Defaults to 'cpu'.
|
||||
|
||||
Returns:
|
||||
Sequence[str]: Backend files.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f'to_backend has not been implemented for `{cls.__name__}`')
|
||||
|
||||
|
||||
class BackendManagerRegistry:
|
||||
|
@ -89,3 +114,18 @@ class BackendManagerRegistry:
|
|||
|
||||
|
||||
BACKEND_MANAGERS = BackendManagerRegistry()
|
||||
|
||||
|
||||
def get_backend_manager(name: str) -> BaseBackendManager:
|
||||
"""Get backend manager.
|
||||
|
||||
Args:
|
||||
name (str): name of the backend.
|
||||
|
||||
Returns:
|
||||
BaseBackendManager: The backend manager of given name
|
||||
"""
|
||||
from enum import Enum
|
||||
if isinstance(name, Enum):
|
||||
name = name.value
|
||||
return BACKEND_MANAGERS.find(name)
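A small usage sketch of this helper (assuming the ONNX Runtime backend is installed so its manager can be resolved); both a plain string and the corresponding `Backend` enum member resolve to the same registered manager:

```python
from mmdeploy.backend.base import get_backend_manager
from mmdeploy.utils import Backend

# The enum member's value is unwrapped to 'onnxruntime' before the registry
# lookup, so both calls should return the same manager.
mgr_from_str = get_backend_manager('onnxruntime')
mgr_from_enum = get_backend_manager(Backend.ONNXRUNTIME)
assert mgr_from_str is mgr_from_enum
```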
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import os.path as osp
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
from ..base import BACKEND_MANAGERS, BaseBackendManager
|
||||
|
@ -29,3 +31,36 @@ class CoreMLManager(BaseBackendManager):
|
|||
"""
|
||||
from .wrapper import CoreMLWrapper
|
||||
return CoreMLWrapper(model_file=backend_files[0])
|
||||
|
||||
@classmethod
|
||||
def to_backend(cls,
|
||||
ir_files: Sequence[str],
|
||||
work_dir: str,
|
||||
deploy_cfg: Any,
|
||||
log_level: int = logging.INFO,
|
||||
device: str = 'cpu',
|
||||
**kwargs) -> Sequence[str]:
|
||||
"""Convert intermediate representation to given backend.
|
||||
|
||||
Args:
|
||||
ir_files (Sequence[str]): The intermediate representation files.
|
||||
work_dir (str): The work directory, backend files and logs should
|
||||
be saved in this directory.
|
||||
deploy_cfg (Any): The deploy config.
|
||||
log_level (int, optional): The log level. Defaults to logging.INFO.
|
||||
device (str, optional): The device type. Defaults to 'cpu'.
|
||||
|
||||
Returns:
|
||||
Sequence[str]: Backend files.
|
||||
"""
|
||||
from .torchscript2coreml import from_torchscript
|
||||
|
||||
coreml_files = []
|
||||
for model_id, torchscript_path in enumerate(ir_files):
|
||||
torchscript_name = osp.splitext(osp.split(torchscript_path)[1])[0]
|
||||
output_file_prefix = osp.join(work_dir, torchscript_name)
|
||||
|
||||
from_torchscript(model_id, torchscript_path, output_file_prefix,
|
||||
deploy_cfg, coreml_files)
|
||||
|
||||
return coreml_files
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import logging
|
||||
import os.path as osp
|
||||
import sys
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
from mmdeploy.utils import get_backend_config
|
||||
from mmdeploy.utils import get_backend_config, get_root_logger
|
||||
from ..base import BACKEND_MANAGERS, BaseBackendManager
|
||||
|
||||
|
||||
|
@ -43,3 +45,46 @@ class NCNNManager(BaseBackendManager):
|
|||
bin_file=backend_files[1],
|
||||
output_names=output_names,
|
||||
use_vulkan=use_vulkan)
|
||||
|
||||
@classmethod
|
||||
def to_backend(cls,
|
||||
ir_files: Sequence[str],
|
||||
work_dir: str,
|
||||
log_level: int = logging.INFO,
|
||||
device: str = 'cpu',
|
||||
**kwargs) -> Sequence[str]:
|
||||
"""Convert intermediate representation to given backend.
|
||||
|
||||
Args:
|
||||
ir_files (Sequence[str]): The intermediate representation files.
|
||||
work_dir (str): The work directory, backend files and logs should
|
||||
be saved in this directory.
|
||||
log_level (int, optional): The log level. Defaults to logging.INFO.
|
||||
device (str, optional): The device type. Defaults to 'cpu'.
|
||||
|
||||
Returns:
|
||||
Sequence[str]: Backend files.
|
||||
"""
|
||||
logger = get_root_logger()
|
||||
|
||||
from . import is_available
|
||||
|
||||
if not is_available():
|
||||
logger.error('ncnn support is not available, please make sure:\n'
|
||||
'1) `mmdeploy_onnx2ncnn` existed in `PATH`\n'
|
||||
'2) python import ncnn success')
|
||||
sys.exit(1)
|
||||
|
||||
from mmdeploy.apis.ncnn import get_output_model_file
|
||||
from .onnx2ncnn import from_onnx
|
||||
|
||||
backend_files = []
|
||||
for onnx_path in ir_files:
|
||||
model_param_path, model_bin_path = get_output_model_file(
|
||||
onnx_path, work_dir)
|
||||
onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
|
||||
from_onnx(onnx_path, osp.join(work_dir, onnx_name))
|
||||
|
||||
backend_files += [model_param_path, model_bin_path]
|
||||
|
||||
return backend_files
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
from ..base import BACKEND_MANAGERS, BaseBackendManager
|
||||
|
@ -33,3 +34,24 @@ class ONNXRuntimeManager(BaseBackendManager):
|
|||
onnx_file=backend_files[0],
|
||||
device=device,
|
||||
output_names=output_names)
|
||||
|
||||
@classmethod
|
||||
def to_backend(cls,
|
||||
ir_files: Sequence[str],
|
||||
work_dir: str,
|
||||
log_level: int = logging.INFO,
|
||||
device: str = 'cpu',
|
||||
**kwargs) -> Sequence[str]:
|
||||
"""Convert intermediate representation to given backend.
|
||||
|
||||
Args:
|
||||
ir_files (Sequence[str]): The intermediate representation files.
|
||||
work_dir (str): The work directory, backend files and logs should
|
||||
be saved in this directory.
|
||||
log_level (int, optional): The log level. Defaults to logging.INFO.
|
||||
device (str, optional): The device type. Defaults to 'cpu'.
|
||||
|
||||
Returns:
|
||||
Sequence[str]: Backend files.
|
||||
"""
|
||||
return ir_files
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
from ..base import BACKEND_MANAGERS, BaseBackendManager
|
||||
|
@ -31,3 +32,46 @@ class OpenVINOManager(BaseBackendManager):
|
|||
from .wrapper import OpenVINOWrapper
|
||||
return OpenVINOWrapper(
|
||||
ir_model_file=backend_files[0], output_names=output_names)
|
||||
|
||||
@classmethod
|
||||
def to_backend(cls,
|
||||
ir_files: Sequence[str],
|
||||
work_dir: str,
|
||||
deploy_cfg: Any,
|
||||
log_level: int = logging.INFO,
|
||||
device: str = 'cpu',
|
||||
**kwargs) -> Sequence[str]:
|
||||
"""Convert intermediate representation to given backend.
|
||||
|
||||
Args:
|
||||
ir_files (Sequence[str]): The intermediate representation files.
|
||||
work_dir (str): The work directory, backend files and logs should
|
||||
be saved in this directory.
|
||||
deploy_cfg (Any): The deploy config.
|
||||
log_level (int, optional): The log level. Defaults to logging.INFO.
|
||||
device (str, optional): The device type. Defaults to 'cpu'.
|
||||
|
||||
Returns:
|
||||
Sequence[str]: Backend files.
|
||||
"""
|
||||
from . import is_available
|
||||
assert is_available(), \
|
||||
'OpenVINO is not available, please install OpenVINO first.'
|
||||
|
||||
from mmdeploy.apis.openvino import (get_input_info_from_cfg,
|
||||
get_mo_options_from_cfg,
|
||||
get_output_model_file)
|
||||
from mmdeploy.utils import get_ir_config
|
||||
from .onnx2openvino import from_onnx
|
||||
|
||||
openvino_files = []
|
||||
for onnx_path in ir_files:
|
||||
model_xml_path = get_output_model_file(onnx_path, work_dir)
|
||||
input_info = get_input_info_from_cfg(deploy_cfg)
|
||||
output_names = get_ir_config(deploy_cfg).output_names
|
||||
mo_options = get_mo_options_from_cfg(deploy_cfg)
|
||||
from_onnx(onnx_path, work_dir, input_info, output_names,
|
||||
mo_options)
|
||||
openvino_files.append(model_xml_path)
|
||||
|
||||
return openvino_files
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import logging
|
||||
import os.path as osp
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
from ..base import BACKEND_MANAGERS, BaseBackendManager
|
||||
|
@ -34,3 +35,49 @@ class PPLNNManager(BaseBackendManager):
|
|||
algo_file=backend_files[1] if len(backend_files) > 1 else None,
|
||||
device=device,
|
||||
output_names=output_names)
|
||||
|
||||
@classmethod
|
||||
def to_backend(cls,
|
||||
ir_files: Sequence[str],
|
||||
work_dir: str,
|
||||
deploy_cfg: Any,
|
||||
log_level: int = logging.INFO,
|
||||
device: str = 'cpu',
|
||||
**kwargs) -> Sequence[str]:
|
||||
"""Convert intermediate representation to given backend.
|
||||
|
||||
Args:
|
||||
ir_files (Sequence[str]): The intermediate representation files.
|
||||
work_dir (str): The work directory, backend files and logs should
|
||||
be saved in this directory.
|
||||
deploy_cfg (Any): The deploy config.
|
||||
log_level (int, optional): The log level. Defaults to logging.INFO.
|
||||
device (str, optional): The device type. Defaults to 'cpu'.
|
||||
|
||||
Returns:
|
||||
Sequence[str]: Backend files.
|
||||
"""
|
||||
from mmdeploy.utils import get_model_inputs
|
||||
from . import is_available
|
||||
from .onnx2pplnn import from_onnx
|
||||
assert is_available(), \
|
||||
'PPLNN is not available, please install PPLNN first.'
|
||||
|
||||
pplnn_files = []
|
||||
for onnx_path in ir_files:
|
||||
algo_file = onnx_path.replace('.onnx', '.json')
|
||||
model_inputs = get_model_inputs(deploy_cfg)
|
||||
assert 'opt_shape' in model_inputs, 'Expect opt_shape ' \
|
||||
'in deploy config for PPLNN'
|
||||
# PPLNN accepts only 1 input shape for optimization,
|
||||
# may get changed in the future
|
||||
input_shapes = [model_inputs.opt_shape]
|
||||
algo_prefix = osp.splitext(algo_file)[0]
|
||||
from_onnx(
|
||||
onnx_path,
|
||||
algo_prefix,
|
||||
device=device,
|
||||
input_shapes=input_shapes)
|
||||
pplnn_files += [onnx_path, algo_file]
|
||||
|
||||
return pplnn_files
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import os.path as osp
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
from mmdeploy.utils import get_common_config
|
||||
|
@ -35,3 +37,39 @@ class RKNNManager(BaseBackendManager):
|
|||
model=backend_files[0],
|
||||
common_config=common_config,
|
||||
output_names=output_names)
|
||||
|
||||
@classmethod
|
||||
def to_backend(cls,
|
||||
ir_files: Sequence[str],
|
||||
work_dir: str,
|
||||
deploy_cfg: Any,
|
||||
log_level: int = logging.INFO,
|
||||
device: str = 'cpu',
|
||||
**kwargs) -> Sequence[str]:
|
||||
"""Convert intermediate representation to given backend.
|
||||
|
||||
Args:
|
||||
ir_files (Sequence[str]): The intermediate representation files.
|
||||
work_dir (str): The work directory, backend files and logs should
|
||||
be saved in this directory.
|
||||
deploy_cfg (Any): The deploy config.
|
||||
log_level (int, optional): The log level. Defaults to logging.INFO.
|
||||
device (str, optional): The device type. Defaults to 'cpu'.
|
||||
|
||||
Returns:
|
||||
Sequence[str]: Backend files.
|
||||
"""
|
||||
from . import is_available
|
||||
assert is_available(
|
||||
), 'RKNN is not available, please install RKNN first.'
|
||||
|
||||
from .onnx2rknn import onnx2rknn
|
||||
|
||||
backend_files = []
|
||||
for model_id, onnx_path in zip(range(len(ir_files)), ir_files):
|
||||
pre_fix_name = osp.splitext(osp.split(onnx_path)[1])[0]
|
||||
output_path = osp.join(work_dir, pre_fix_name + '.rknn')
|
||||
onnx2rknn(onnx_path, output_path, deploy_cfg)
|
||||
backend_files.append(output_path)
|
||||
|
||||
return backend_files
|
||||
|
|
|
@ -14,7 +14,7 @@ from mmdeploy.utils import (get_backend_config, get_common_config,
|
|||
def onnx2rknn(onnx_model: str,
|
||||
output_path: str,
|
||||
deploy_cfg: Union[str, mmcv.Config],
|
||||
model_cfg: Union[str, mmcv.Config],
|
||||
model_cfg: Optional[Union[str, mmcv.Config]] = None,
|
||||
dataset_file: Optional[str] = None,
|
||||
**kwargs):
|
||||
"""Convert ONNX to RKNN.
|
||||
|
@ -43,7 +43,7 @@ def onnx2rknn(onnx_model: str,
|
|||
input_size_list = get_backend_config(deploy_cfg).get(
|
||||
'input_size_list', None)
|
||||
# update norm value
|
||||
if get_rknn_quantization(deploy_cfg) is True:
|
||||
if get_rknn_quantization(deploy_cfg) is True and model_cfg is not None:
|
||||
transform = get_normalization(model_cfg)
|
||||
common_params.update(
|
||||
dict(
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
from mmdeploy.utils import get_root_logger
|
||||
from ..base import BACKEND_MANAGERS, BaseBackendManager
|
||||
|
||||
|
||||
|
@ -33,3 +37,48 @@ class SNPEManager(BaseBackendManager):
|
|||
uri = kwargs['uri']
|
||||
return SNPEWrapper(
|
||||
dlc_file=backend_files[0], uri=uri, output_names=output_names)
|
||||
|
||||
@classmethod
|
||||
def to_backend(cls,
|
||||
ir_files: Sequence[str],
|
||||
work_dir: str,
|
||||
log_level: int = logging.INFO,
|
||||
device: str = 'cpu',
|
||||
uri: str = '',
|
||||
**kwargs) -> Sequence[str]:
|
||||
"""Convert intermediate representation to given backend.
|
||||
|
||||
Args:
|
||||
ir_files (Sequence[str]): The intermediate representation files.
|
||||
work_dir (str): The work directory, backend files and logs should
|
||||
be saved in this directory.
|
||||
log_level (int, optional): The log level. Defaults to logging.INFO.
|
||||
device (str, optional): The device type. Defaults to 'cpu'.
|
||||
|
||||
Returns:
|
||||
Sequence[str]: Backend files.
|
||||
"""
|
||||
from . import is_available
|
||||
logger = get_root_logger()
|
||||
|
||||
if not is_available():
|
||||
logger.error('snpe support is not available, please check\n'
|
||||
'1) `snpe-onnx-to-dlc` existed in `PATH`\n'
|
||||
'2) snpe only support\n'
|
||||
'ubuntu18.04')
|
||||
exit(1)
|
||||
|
||||
from mmdeploy.apis.snpe import get_env_key, get_output_model_file
|
||||
from .onnx2dlc import from_onnx
|
||||
|
||||
if get_env_key() not in os.environ:
|
||||
os.environ[get_env_key()] = uri
|
||||
|
||||
backend_files = []
|
||||
for onnx_path in ir_files:
|
||||
dlc_path = get_output_model_file(onnx_path, work_dir)
|
||||
onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
|
||||
from_onnx(onnx_path, osp.join(work_dir, onnx_name))
|
||||
backend_files += [dlc_path]
|
||||
|
||||
return backend_files
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
from ..base import BACKEND_MANAGERS, BaseBackendManager
|
||||
|
@ -31,3 +32,58 @@ class TensorRTManager(BaseBackendManager):
|
|||
|
||||
from .wrapper import TRTWrapper
|
||||
return TRTWrapper(engine=backend_files[0], output_names=output_names)
|
||||
|
||||
@classmethod
|
||||
def to_backend(cls,
|
||||
ir_files: Sequence[str],
|
||||
work_dir: str,
|
||||
deploy_cfg: Any,
|
||||
log_level: int = logging.INFO,
|
||||
device: str = 'cpu',
|
||||
**kwargs) -> Sequence[str]:
|
||||
"""Convert intermediate representation to given backend.
|
||||
|
||||
Args:
|
||||
ir_files (Sequence[str]): The intermediate representation files.
|
||||
work_dir (str): The work directory, backend files and logs should
|
||||
be saved in this directory.
|
||||
deploy_cfg (Any): The deploy config.
|
||||
log_level (int, optional): The log level. Defaults to logging.INFO.
|
||||
device (str, optional): The device type. Defaults to 'cpu'.
|
||||
|
||||
Returns:
|
||||
Sequence[str]: Backend files.
|
||||
"""
|
||||
import os.path as osp
|
||||
|
||||
from mmdeploy.utils import get_model_inputs, get_partition_config
|
||||
model_params = get_model_inputs(deploy_cfg)
|
||||
partition_cfgs = get_partition_config(deploy_cfg)
|
||||
assert len(model_params) == len(ir_files)
|
||||
|
||||
from . import is_available
|
||||
assert is_available(), (
|
||||
'TensorRT is not available,'
|
||||
' please install TensorRT and build TensorRT custom ops first.')
|
||||
|
||||
from .onnx2tensorrt import onnx2tensorrt
|
||||
backend_files = []
|
||||
for model_id, model_param, onnx_path in zip(
|
||||
range(len(ir_files)), model_params, ir_files):
|
||||
onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
|
||||
save_file = model_param.get('save_file', onnx_name + '.engine')
|
||||
|
||||
partition_type = 'end2end' if partition_cfgs is None \
|
||||
else onnx_name
|
||||
onnx2tensorrt(
|
||||
work_dir,
|
||||
save_file,
|
||||
model_id,
|
||||
deploy_cfg,
|
||||
onnx_path,
|
||||
device=device,
|
||||
partition_type=partition_type)
|
||||
|
||||
backend_files.append(osp.join(work_dir, save_file))
|
||||
|
||||
return backend_files
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
from ..base import BACKEND_MANAGERS, BaseBackendManager
|
||||
|
@ -33,3 +34,24 @@ class TorchScriptManager(BaseBackendManager):
|
|||
model=backend_files[0],
|
||||
input_names=input_names,
|
||||
output_names=output_names)
|
||||
|
||||
@classmethod
|
||||
def to_backend(cls,
|
||||
ir_files: Sequence[str],
|
||||
work_dir: str,
|
||||
log_level: int = logging.INFO,
|
||||
device: str = 'cpu',
|
||||
**kwargs) -> Sequence[str]:
|
||||
"""Convert intermediate representation to given backend.
|
||||
|
||||
Args:
|
||||
ir_files (Sequence[str]): The intermediate representation files.
|
||||
work_dir (str): The work directory, backend files and logs should
|
||||
be saved in this directory.
|
||||
log_level (int, optional): The log level. Defaults to logging.INFO.
|
||||
device (str, optional): The device type. Defaults to 'cpu'.
|
||||
|
||||
Returns:
|
||||
Sequence[str]: Backend files.
|
||||
"""
|
||||
return ir_files
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import logging
|
||||
import os.path as osp
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
from ..base import BACKEND_MANAGERS, BaseBackendManager
|
||||
|
@ -35,3 +37,75 @@ class TVMManager(BaseBackendManager):
|
|||
bytecode=bytecode,
|
||||
output_names=output_names,
|
||||
device=device)
|
||||
|
||||
@classmethod
|
||||
def to_backend(cls,
|
||||
ir_files: Sequence[str],
|
||||
work_dir: str,
|
||||
deploy_cfg: Any,
|
||||
log_level: int = logging.INFO,
|
||||
device: str = 'cpu',
|
||||
**kwargs) -> Sequence[str]:
|
||||
"""Convert intermediate representation to given backend.
|
||||
|
||||
Args:
|
||||
ir_files (Sequence[str]): The intermediate representation files.
|
||||
work_dir (str): The work directory, backend files and logs should
|
||||
be saved in this directory.
|
||||
deploy_cfg (Any): The deploy config.
|
||||
log_level (int, optional): The log level. Defaults to logging.INFO.
|
||||
device (str, optional): The device type. Defaults to 'cpu'.
|
||||
|
||||
Returns:
|
||||
Sequence[str]: Backend files.
|
||||
"""
|
||||
|
||||
import copy
|
||||
|
||||
from mmdeploy.apis.tvm import get_library_ext
|
||||
from mmdeploy.utils import (get_calib_filename, get_model_inputs,
|
||||
get_partition_config)
|
||||
from .onnx2tvm import from_onnx
|
||||
model_inputs = get_model_inputs(deploy_cfg)
|
||||
|
||||
if device.startswith('cuda'):
|
||||
target = 'cuda'
|
||||
else:
|
||||
target = 'llvm'
|
||||
|
||||
lib_ext = get_library_ext()
|
||||
|
||||
tvm_files = []
|
||||
for model_id, onnx_path in enumerate(ir_files):
|
||||
model_input = copy.deepcopy(model_inputs[model_id])
|
||||
use_vm = model_input.get('use_vm', False)
|
||||
if 'target' not in model_input['tuner']:
|
||||
model_input['tuner']['target'] = target
|
||||
lib_path = osp.splitext(onnx_path)[0] + lib_ext
|
||||
code_path = osp.splitext(
|
||||
onnx_path)[0] + '.code' if use_vm else None
|
||||
model_input['output_file'] = lib_path
|
||||
model_input['onnx_model'] = onnx_path
|
||||
model_input['bytecode_file'] = code_path
|
||||
|
||||
# create calibration dataset
|
||||
if 'qconfig' in model_input:
|
||||
from .quantize import HDF5Dataset
|
||||
calib_filename = get_calib_filename(deploy_cfg)
|
||||
calib_path = osp.join(work_dir, calib_filename)
|
||||
partition_cfgs = get_partition_config(deploy_cfg)
|
||||
onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
|
||||
partition_type = 'end2end' if partition_cfgs is None \
|
||||
else onnx_name
|
||||
dataset = HDF5Dataset(
|
||||
calib_path,
|
||||
model_input['shape'],
|
||||
model_type=partition_type,
|
||||
device=target)
|
||||
model_input['dataset'] = dataset()
|
||||
|
||||
from_onnx(**model_input)
|
||||
|
||||
tvm_files += [lib_path, code_path]
|
||||
|
||||
return tvm_files
|
||||
|
|
|
@ -54,9 +54,9 @@ class BaseBackendModel(torch.nn.Module, metaclass=ABCMeta):
|
|||
names from the model.
|
||||
deploy_cfg: Deployment config file.
|
||||
"""
|
||||
from mmdeploy.backend.base import BACKEND_MANAGERS
|
||||
from mmdeploy.backend.base import get_backend_manager
|
||||
|
||||
backend_mgr = BACKEND_MANAGERS.find(backend.value)
|
||||
backend_mgr = get_backend_manager(backend.value)
|
||||
if backend_mgr is None:
|
||||
raise NotImplementedError(
|
||||
f'Unsupported backend type: {backend.value}')
|
||||
|
|
|
@ -231,14 +231,14 @@ class End2EndModel(BaseBackendModel):
|
|||
masks = batch_masks[i]
|
||||
img_h, img_w = img_metas[i]['img_shape'][:2]
|
||||
ori_h, ori_w = img_metas[i]['ori_shape'][:2]
|
||||
export_postprocess_mask = True
|
||||
export_postprocess_mask = False
|
||||
if self.deploy_cfg is not None:
|
||||
|
||||
mmdet_deploy_cfg = get_post_processing_params(
|
||||
self.deploy_cfg)
|
||||
# this flag enable postprocess when export.
|
||||
export_postprocess_mask = mmdet_deploy_cfg.get(
|
||||
'export_postprocess_mask', True)
|
||||
'export_postprocess_mask', False)
|
||||
if not export_postprocess_mask:
|
||||
masks = End2EndModel.postprocessing_masks(
|
||||
dets[:, :4], masks, ori_w, ori_h, self.device)
|
||||
|
|
|
@ -38,7 +38,8 @@ def fcn_mask_head__get_seg_masks(ctx, self, mask_pred, det_bboxes, det_labels,
|
|||
# grid sample is not supported by most engines
|
||||
# so we add a flag to disable it.
|
||||
mmdet_params = get_post_processing_params(ctx.cfg)
|
||||
export_postprocess_mask = mmdet_params.get('export_postprocess_mask', True)
|
||||
export_postprocess_mask = mmdet_params.get('export_postprocess_mask',
|
||||
False)
|
||||
if not export_postprocess_mask:
|
||||
return mask_pred
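Because the default for `export_postprocess_mask` flips to `False` here and in the detection model wrapper above, a deployment that still wants the mask post-processing baked into the exported graph now has to request it explicitly. A sketch of the relevant deploy-config fragment, assuming the usual mmdet layout (field names follow the test config later in this commit; the values are illustrative):

```python
# Illustrative deploy-config fragment: opt back into in-graph mask
# post-processing, which this commit disables by default.
codebase_config = dict(
    type='mmdet',
    task='ObjectDetection',
    post_processing=dict(
        export_postprocess_mask=True,  # new default is False
        # ... other post-processing params (score_threshold, iou_threshold, ...)
    ))
```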
|
||||
|
||||
|
|
|
@ -114,6 +114,8 @@ def check_backend(backend: Backend, require_plugin: bool = False):
|
|||
from mmdeploy.backend.ascend import is_available
|
||||
elif backend == Backend.TVM:
|
||||
from mmdeploy.backend.tvm import is_available
|
||||
elif backend == Backend.COREML:
|
||||
from mmdeploy.backend.coreml import is_available
|
||||
else:
|
||||
warnings.warn('The backend checker is not available')
|
||||
return
|
||||
|
@ -459,6 +461,7 @@ def get_backend_outputs(ir_file_path: str,
|
|||
If the backend specified in 'deploy_cfg' is not available,
|
||||
then None will be returned.
|
||||
"""
|
||||
from mmdeploy.apis.utils import to_backend
|
||||
backend = get_backend(deploy_cfg)
|
||||
flatten_model_inputs = get_flatten_inputs(model_inputs)
|
||||
ir_config = get_ir_config(deploy_cfg)
|
||||
|
@ -469,109 +472,32 @@ def get_backend_outputs(ir_file_path: str,
|
|||
k for k, v in flatten_model_inputs.items() if k != 'ctx'
|
||||
]
|
||||
|
||||
# prepare backend model and input features
|
||||
work_dir = tempfile.TemporaryDirectory().name
|
||||
device = 'cpu'
|
||||
|
||||
# TODO: Try to wrap these code into backend manager
|
||||
if backend != Backend.TORCHSCRIPT:
|
||||
model_inputs = flatten_model_inputs
|
||||
if backend == Backend.TENSORRT:
|
||||
# convert to engine
|
||||
import mmdeploy.apis.tensorrt as trt_apis
|
||||
if not (trt_apis.is_available()
|
||||
and trt_apis.is_custom_ops_available()):
|
||||
return None
|
||||
trt_file_path = tempfile.NamedTemporaryFile(suffix='.engine').name
|
||||
trt_apis.onnx2tensorrt(
|
||||
'',
|
||||
trt_file_path,
|
||||
0,
|
||||
deploy_cfg=deploy_cfg,
|
||||
onnx_model=ir_file_path)
|
||||
backend_files = [trt_file_path]
|
||||
for k, v in model_inputs.items():
|
||||
model_inputs[k] = model_inputs[k].cuda()
|
||||
|
||||
backend_feats = model_inputs
|
||||
device = 'cuda:0'
|
||||
elif backend == Backend.ONNXRUNTIME:
|
||||
import mmdeploy.apis.onnxruntime as ort_apis
|
||||
if not (ort_apis.is_available()
|
||||
and ort_apis.is_custom_ops_available()):
|
||||
return None
|
||||
feature_list = []
|
||||
backend_feats = {}
|
||||
for k, item in model_inputs.items():
|
||||
if type(item) is torch.Tensor:
|
||||
feature_list.append(item)
|
||||
elif type(item) is tuple or list:
|
||||
for i in item:
|
||||
assert type(i) is torch.Tensor, ('model_inputs contains '
                                 'nested sequence of torch.Tensor')
|
||||
feature_list.append(i)
|
||||
else:
|
||||
raise TypeError('values of model_inputs are expected to be '
|
||||
'torch.Tensor or its sequence, '
|
||||
f'but got {type(model_inputs)}')
|
||||
|
||||
# for onnx file generated with list[torch.Tensor] input,
|
||||
# the input dict keys are just numbers if not specified
|
||||
for i in range(len(feature_list)):
|
||||
if i < len(input_names):
|
||||
backend_feats[input_names[i]] = feature_list[i]
|
||||
else:
|
||||
backend_feats[str(i)] = feature_list[i]
|
||||
backend_files = [ir_file_path]
|
||||
device = 'cpu'
|
||||
elif backend == Backend.NCNN:
|
||||
import mmdeploy.apis.ncnn as ncnn_apis
|
||||
if not (ncnn_apis.is_available()
|
||||
and ncnn_apis.is_custom_ops_available()):
|
||||
return None
|
||||
param_path, bin_path = ncnn_apis.get_output_model_file(ir_file_path)
|
||||
ncnn_files_prefix = osp.splitext(ir_file_path)[0]
|
||||
ncnn_apis.from_onnx(ir_file_path, ncnn_files_prefix)
|
||||
backend_files = [param_path, bin_path]
|
||||
backend_feats = flatten_model_inputs
|
||||
device = 'cpu'
|
||||
|
||||
device = 'cuda'
|
||||
model_inputs = dict((k, v.cuda()) for k, v in model_inputs.items())
|
||||
elif backend == Backend.OPENVINO:
|
||||
import mmdeploy.apis.openvino as openvino_apis
|
||||
if not openvino_apis.is_available():
|
||||
return None
|
||||
from mmdeploy.apis.openvino import get_mo_options_from_cfg
|
||||
openvino_work_dir = tempfile.TemporaryDirectory().name
|
||||
openvino_file_path = openvino_apis.get_output_model_file(
|
||||
ir_file_path, openvino_work_dir)
|
||||
input_info = {
|
||||
name: value.shape
|
||||
for name, value in flatten_model_inputs.items()
|
||||
}
|
||||
mo_options = get_mo_options_from_cfg(deploy_cfg)
|
||||
openvino_apis.from_onnx(ir_file_path, openvino_work_dir, input_info,
|
||||
output_names, mo_options)
|
||||
backend_files = [openvino_file_path]
|
||||
backend_feats = flatten_model_inputs
|
||||
device = 'cpu'
|
||||
deploy_cfg['backend_config']['model_inputs'] = [
|
||||
dict(opt_shapes=input_info)
|
||||
]
|
||||
backend_files = to_backend(
|
||||
backend.value, [ir_file_path],
|
||||
work_dir=work_dir,
|
||||
deploy_cfg=deploy_cfg,
|
||||
device=device)
|
||||
backend_feats = model_inputs
|
||||
|
||||
elif backend == Backend.DEFAULT:
|
||||
return None
|
||||
elif backend == Backend.TORCHSCRIPT:
|
||||
backend_files = [ir_file_path]
|
||||
device = 'cpu'
|
||||
if backend == Backend.TORCHSCRIPT:
|
||||
backend_feats = [v for _, v in model_inputs.items()]
|
||||
elif backend == Backend.ASCEND:
|
||||
# Ascend model conversion
|
||||
import mmdeploy.apis.ascend as ascend_apis
|
||||
from mmdeploy.utils import get_model_inputs
|
||||
if not ascend_apis.is_available():
|
||||
return None
|
||||
work_dir = osp.split(ir_file_path)[0]
|
||||
# convert model
|
||||
convert_args = get_model_inputs(deploy_cfg)
|
||||
ascend_apis.from_onnx(ir_file_path, work_dir, convert_args[0])
|
||||
om_file_name = osp.splitext(osp.split(ir_file_path)[1])[0]
|
||||
backend_files = [osp.join(work_dir, om_file_name + '.om')]
|
||||
backend_feats = flatten_model_inputs
|
||||
device = 'cpu'
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'Unimplemented backend type: {backend.value}')
|
||||
|
||||
from mmdeploy.codebase.base import BaseBackendModel
|
||||
backend_model = BaseBackendModel._build_wrapper(
|
||||
|
|
|
@ -146,94 +146,41 @@ def onnx2backend(backend, onnx_file):
|
|||
|
||||
|
||||
def create_wrapper(backend, model_files):
|
||||
if backend == Backend.TENSORRT:
|
||||
from mmdeploy.backend.tensorrt import TRTWrapper
|
||||
trt_model = TRTWrapper(model_files, output_names)
|
||||
return trt_model
|
||||
elif backend == Backend.ONNXRUNTIME:
|
||||
from mmdeploy.backend.onnxruntime import ORTWrapper
|
||||
ort_model = ORTWrapper(model_files, 'cpu', output_names)
|
||||
return ort_model
|
||||
elif backend == Backend.PPLNN:
|
||||
from mmdeploy.backend.pplnn import PPLNNWrapper
|
||||
onnx_file, algo_file = model_files
|
||||
pplnn_model = PPLNNWrapper(onnx_file, algo_file, 'cpu', output_names)
|
||||
return pplnn_model
|
||||
elif backend == Backend.NCNN:
|
||||
from mmdeploy.backend.ncnn import NCNNWrapper
|
||||
param_file, bin_file = model_files
|
||||
ncnn_model = NCNNWrapper(param_file, bin_file, output_names)
|
||||
return ncnn_model
|
||||
elif backend == Backend.OPENVINO:
|
||||
from mmdeploy.backend.openvino import OpenVINOWrapper
|
||||
openvino_model = OpenVINOWrapper(model_files, output_names)
|
||||
return openvino_model
|
||||
elif backend == Backend.TORCHSCRIPT:
|
||||
from mmdeploy.backend.torchscript import TorchscriptWrapper
|
||||
torchscript_model = TorchscriptWrapper(
|
||||
model_files, input_names=input_names, output_names=output_names)
|
||||
return torchscript_model
|
||||
from mmdeploy.backend.base import get_backend_manager
|
||||
backend_mgr = get_backend_manager(backend.value)
|
||||
deploy_cfg = None
|
||||
if isinstance(model_files, str):
|
||||
model_files = [model_files]
|
||||
|
||||
elif backend == Backend.RKNN:
|
||||
from mmdeploy.backend.rknn import RKNNWrapper
|
||||
rknn_model = RKNNWrapper(
|
||||
model_files,
|
||||
common_config=dict(target_platform=target_platform),
|
||||
output_names=output_names)
|
||||
return rknn_model
|
||||
elif backend == Backend.ASCEND:
|
||||
from mmdeploy.backend.ascend import AscendWrapper
|
||||
ascend_model = AscendWrapper(model_files)
|
||||
return ascend_model
|
||||
elif backend == Backend.TVM:
|
||||
from mmdeploy.backend.tvm import TVMWrapper
|
||||
tvm_model = TVMWrapper(model_files, output_names=output_names)
|
||||
return tvm_model
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown backend type: {backend.value}')
|
||||
deploy_cfg = dict(
|
||||
backend_config=dict(
|
||||
common_config=dict(target_platform=target_platform)))
|
||||
return backend_mgr.build_wrapper(
|
||||
model_files,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
deploy_cfg=deploy_cfg)
|
||||
|
||||
|
||||
def run_wrapper(backend, wrapper, input):
|
||||
if backend == Backend.TENSORRT:
|
||||
input = input.cuda()
|
||||
results = wrapper({'input': input})['output']
|
||||
results = results.detach().cpu()
|
||||
return results
|
||||
elif backend == Backend.ONNXRUNTIME:
|
||||
results = wrapper({'input': input})['output']
|
||||
results = results.detach().cpu()
|
||||
return results
|
||||
elif backend == Backend.PPLNN:
|
||||
results = wrapper({'input': input})['output']
|
||||
results = results.detach().cpu()
|
||||
return results
|
||||
elif backend == Backend.NCNN:
|
||||
input = input.float()
|
||||
results = wrapper({'input': input})['output']
|
||||
results = results.detach().cpu()
|
||||
return results
|
||||
elif backend == Backend.OPENVINO:
|
||||
results = wrapper({'input': input})['output']
|
||||
results = results.detach().cpu()
|
||||
return results
|
||||
elif backend == Backend.TORCHSCRIPT:
|
||||
results = wrapper({'input': input})['output']
|
||||
return results
|
||||
elif backend == Backend.RKNN:
|
||||
results = wrapper({'input': input})
|
||||
elif backend == Backend.ASCEND:
|
||||
results = wrapper({'input': input})['output']
|
||||
return results
|
||||
elif backend == Backend.TVM:
|
||||
results = wrapper({'input': input})['output']
|
||||
return results
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown backend type: {backend.value}')
|
||||
|
||||
results = wrapper({'input': input})
|
||||
|
||||
if backend != Backend.RKNN:
|
||||
results = results['output']
|
||||
|
||||
results = results.detach().cpu()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
ALL_BACKEND = [
|
||||
Backend.TENSORRT, Backend.ONNXRUNTIME, Backend.PPLNN, Backend.NCNN,
|
||||
Backend.OPENVINO, Backend.TORCHSCRIPT, Backend.ASCEND, Backend.RKNN,
|
||||
Backend.TVM
|
||||
Backend.COREML, Backend.TVM
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -574,6 +574,7 @@ def test_forward_of_base_detector(model_cfg_path, backend):
|
|||
pre_top_k=-1,
|
||||
keep_top_k=100,
|
||||
background_label_id=-1,
|
||||
export_postprocess_mask=False,
|
||||
))))
|
||||
|
||||
model_cfg = mmcv.Config(dict(model=mmcv.load(model_cfg_path)))
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Sequence
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
|
@ -23,7 +24,7 @@ from mmdeploy.codebase.mmdet.deploy.object_detection_model import End2EndModel
|
|||
|
||||
def assert_det_results(results, module_name: str = 'model'):
|
||||
assert results is not None, f'failed to get output using {module_name}'
|
||||
assert isinstance(results, list)
|
||||
assert isinstance(results, Sequence)
|
||||
assert len(results) == 2
|
||||
assert results[0].shape[0] == results[1].shape[0]
|
||||
assert results[0].shape[1] == results[1].shape[1]
|
||||
|
|
|
@ -386,8 +386,9 @@ def test_simple_test_of_encode_decode_recognizer(backend):
|
|||
def test_forward_of_fpnc(backend: Backend):
|
||||
"""Test forward rewrite of fpnc."""
|
||||
check_backend(backend)
|
||||
fpnc = get_fpnc_neck_model()
|
||||
fpnc = get_fpnc_neck_model().cuda()
|
||||
fpnc.eval()
|
||||
input = torch.rand(1, 1, 64, 64).cuda()
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(
|
||||
|
@ -397,9 +398,9 @@ def test_forward_of_fpnc(backend: Backend):
|
|||
dict(
|
||||
input_shapes=dict(
|
||||
inputs=dict(
|
||||
min_shape=[1, 3, 64, 64],
|
||||
opt_shape=[1, 3, 64, 64],
|
||||
max_shape=[1, 3, 64, 64])))
|
||||
min_shape=input.shape,
|
||||
opt_shape=input.shape,
|
||||
max_shape=input.shape)))
|
||||
]),
|
||||
onnx_config=dict(
|
||||
input_shape=None,
|
||||
|
@ -407,7 +408,6 @@ def test_forward_of_fpnc(backend: Backend):
|
|||
output_names=['output']),
|
||||
codebase_config=dict(type='mmocr', task='TextDetection')))
|
||||
|
||||
input = torch.rand(1, 3, 64, 64).cuda()
|
||||
model_inputs = {
|
||||
'inputs': input,
|
||||
}
|
||||
|
|
|
@ -297,7 +297,7 @@ def test_upconvblock_forward(backend, is_dynamic_shape):
|
|||
dict(
|
||||
backend_config=dict(type=backend.value),
|
||||
onnx_config=dict(
|
||||
input_names=['skip', 'x'],
|
||||
input_names=['x', 'skip'],
|
||||
output_names=['output'],
|
||||
dynamic_axes=dynamic_axes),
|
||||
codebase_config=dict(
|
||||
|
|
tools/deploy.py
|
@ -13,11 +13,11 @@ from mmdeploy.apis import (create_calib_input_data, extract_model,
|
|||
get_predefined_partition_cfg, torch2onnx,
|
||||
torch2torchscript, visualize_model)
|
||||
from mmdeploy.apis.core import PIPELINE_MANAGER
|
||||
from mmdeploy.apis.utils import to_backend
|
||||
from mmdeploy.backend.sdk.export_info import export2SDK
|
||||
from mmdeploy.utils import (IR, Backend, get_backend, get_calib_filename,
|
||||
get_ir_config, get_model_inputs,
|
||||
get_partition_config, get_root_logger, load_config,
|
||||
target_wrapper)
|
||||
get_ir_config, get_partition_config,
|
||||
get_root_logger, load_config, target_wrapper)
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -193,268 +193,79 @@ def main():
|
|||
backend_files = ir_files
|
||||
# convert backend
|
||||
backend = get_backend(deploy_cfg)
|
||||
if backend == Backend.TENSORRT:
|
||||
model_params = get_model_inputs(deploy_cfg)
|
||||
assert len(model_params) == len(ir_files)
|
||||
|
||||
from mmdeploy.apis.tensorrt import is_available as trt_is_available
|
||||
assert trt_is_available(), (
|
||||
'TensorRT is not available,'
|
||||
' please install TensorRT and build TensorRT custom ops first.')
|
||||
# preprocess deploy_cfg
|
||||
if backend == Backend.RKNN:
|
||||
# TODO: Add this to task_processor in the future
|
||||
import tempfile
|
||||
|
||||
from mmdeploy.apis.tensorrt import onnx2tensorrt
|
||||
PIPELINE_MANAGER.enable_multiprocess(True, [onnx2tensorrt])
|
||||
PIPELINE_MANAGER.set_log_level(log_level, [onnx2tensorrt])
|
||||
|
||||
backend_files = []
|
||||
for model_id, model_param, onnx_path in zip(
|
||||
range(len(ir_files)), model_params, ir_files):
|
||||
onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
|
||||
save_file = model_param.get('save_file', onnx_name + '.engine')
|
||||
|
||||
partition_type = 'end2end' if partition_cfgs is None \
|
||||
else onnx_name
|
||||
onnx2tensorrt(
|
||||
args.work_dir,
|
||||
save_file,
|
||||
model_id,
|
||||
deploy_cfg_path,
|
||||
onnx_path,
|
||||
device=args.device,
|
||||
partition_type=partition_type)
|
||||
|
||||
backend_files.append(osp.join(args.work_dir, save_file))
|
||||
|
||||
elif backend == Backend.NCNN:
|
||||
from mmdeploy.apis.ncnn import is_available as is_available_ncnn
|
||||
|
||||
if not is_available_ncnn():
|
||||
logger.error('ncnn support is not available, please make sure:\n'
|
||||
'1) `mmdeploy_onnx2ncnn` existed in `PATH`\n'
|
||||
'2) python import ncnn success')
|
||||
exit(1)
|
||||
|
||||
import mmdeploy.apis.ncnn as ncnn_api
|
||||
from mmdeploy.apis.ncnn import get_output_model_file
|
||||
|
||||
PIPELINE_MANAGER.set_log_level(log_level, [ncnn_api.from_onnx])
|
||||
|
||||
backend_files = []
|
||||
for onnx_path in ir_files:
|
||||
model_param_path, model_bin_path = get_output_model_file(
|
||||
onnx_path, args.work_dir)
|
||||
onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
|
||||
ncnn_api.from_onnx(onnx_path, osp.join(args.work_dir, onnx_name))
|
||||
|
||||
if quant:
|
||||
from onnx2ncnn_quant_table import get_table
|
||||
|
||||
from mmdeploy.apis.ncnn import get_quant_model_file, ncnn2int8
|
||||
|
||||
deploy_cfg, model_cfg = load_config(deploy_cfg_path,
|
||||
model_cfg_path)
|
||||
quant_onnx, quant_table, quant_param, quant_bin = get_quant_model_file( # noqa: E501
|
||||
onnx_path, args.work_dir)
|
||||
|
||||
create_process(
|
||||
'ncnn quant table',
|
||||
target=get_table,
|
||||
args=(onnx_path, deploy_cfg, model_cfg, quant_onnx,
|
||||
quant_table, quant_image_dir, args.device),
|
||||
kwargs=dict(),
|
||||
ret_value=ret_value)
|
||||
|
||||
create_process(
|
||||
'ncnn_int8',
|
||||
target=ncnn2int8,
|
||||
args=(model_param_path, model_bin_path, quant_table,
|
||||
quant_param, quant_bin),
|
||||
kwargs=dict(),
|
||||
ret_value=ret_value)
|
||||
backend_files += [quant_param, quant_bin]
|
||||
else:
|
||||
backend_files += [model_param_path, model_bin_path]
|
||||
|
||||
elif backend == Backend.SNPE:
|
||||
from mmdeploy.apis.snpe import is_available as is_available
|
||||
|
||||
if not is_available():
|
||||
logger.error('snpe support is not available, please check\n'
|
||||
'1) `snpe-onnx-to-dlc` existed in `PATH`\n'
|
||||
'2) snpe only support\n'
|
||||
'ubuntu18.04')
|
||||
exit(1)
|
||||
|
||||
import mmdeploy.apis.snpe as snpe_api
|
||||
from mmdeploy.apis.snpe import get_env_key, get_output_model_file
|
||||
|
||||
if get_env_key() not in os.environ:
|
||||
os.environ[get_env_key()] = args.uri
|
||||
|
||||
PIPELINE_MANAGER.set_log_level(log_level, [snpe_api.from_onnx])
|
||||
|
||||
backend_files = []
|
||||
for onnx_path in ir_files:
|
||||
dlc_path = get_output_model_file(onnx_path, args.work_dir)
|
||||
onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
|
||||
snpe_api.from_onnx(onnx_path, osp.join(args.work_dir, onnx_name))
|
||||
backend_files = [dlc_path]
|
||||
|
||||
elif backend == Backend.OPENVINO:
|
||||
from mmdeploy.apis.openvino import \
|
||||
is_available as is_available_openvino
|
||||
assert is_available_openvino(), \
|
||||
'OpenVINO is not available, please install OpenVINO first.'
|
||||
|
||||
import mmdeploy.apis.openvino as openvino_api
|
||||
from mmdeploy.apis.openvino import (get_input_info_from_cfg,
|
||||
get_mo_options_from_cfg,
|
||||
get_output_model_file)
|
||||
|
||||
PIPELINE_MANAGER.set_log_level(log_level, [openvino_api.from_onnx])
|
||||
|
||||
openvino_files = []
|
||||
for onnx_path in ir_files:
|
||||
model_xml_path = get_output_model_file(onnx_path, args.work_dir)
|
||||
input_info = get_input_info_from_cfg(deploy_cfg)
|
||||
output_names = get_ir_config(deploy_cfg).output_names
|
||||
mo_options = get_mo_options_from_cfg(deploy_cfg)
|
||||
openvino_api.from_onnx(onnx_path, args.work_dir, input_info,
|
||||
output_names, mo_options)
|
||||
openvino_files.append(model_xml_path)
|
||||
backend_files = openvino_files
|
||||
|
||||
elif backend == Backend.PPLNN:
|
||||
from mmdeploy.apis.pplnn import is_available as is_available_pplnn
|
||||
assert is_available_pplnn(), \
|
||||
'PPLNN is not available, please install PPLNN first.'
|
||||
|
||||
from mmdeploy.apis.pplnn import from_onnx
|
||||
|
||||
pplnn_pipeline_funcs = [from_onnx]
|
||||
PIPELINE_MANAGER.set_log_level(log_level, pplnn_pipeline_funcs)
|
||||
|
||||
pplnn_files = []
|
||||
for onnx_path in ir_files:
|
||||
algo_file = onnx_path.replace('.onnx', '.json')
|
||||
model_inputs = get_model_inputs(deploy_cfg)
|
||||
assert 'opt_shape' in model_inputs, 'Expect opt_shape ' \
|
||||
'in deploy config for PPLNN'
|
||||
# PPLNN accepts only 1 input shape for optimization,
|
||||
# may get changed in the future
|
||||
input_shapes = [model_inputs.opt_shape]
|
||||
algo_prefix = osp.splitext(algo_file)[0]
|
||||
from_onnx(
|
||||
onnx_path,
|
||||
algo_prefix,
|
||||
device=args.device,
|
||||
input_shapes=input_shapes)
|
||||
pplnn_files += [onnx_path, algo_file]
|
||||
backend_files = pplnn_files
|
||||
|
||||
elif backend == Backend.RKNN:
|
||||
from mmdeploy.apis.rknn import is_available as rknn_is_available
|
||||
assert rknn_is_available(
|
||||
), 'RKNN is not available, please install RKNN first.'
|
||||
|
||||
from mmdeploy.apis.rknn import onnx2rknn
|
||||
PIPELINE_MANAGER.enable_multiprocess(True, [onnx2rknn])
|
||||
PIPELINE_MANAGER.set_log_level(logging.INFO, [onnx2rknn])
|
||||
|
||||
backend_files = []
|
||||
for model_id, onnx_path in zip(range(len(ir_files)), ir_files):
|
||||
pre_fix_name = osp.splitext(osp.split(onnx_path)[1])[0]
|
||||
output_path = osp.join(args.work_dir, pre_fix_name + '.rknn')
|
||||
import tempfile
|
||||
dataset_file = tempfile.NamedTemporaryFile(suffix='.txt').name
|
||||
with open(dataset_file, 'w') as f:
|
||||
f.writelines([osp.abspath(args.img)])
|
||||
onnx2rknn(
|
||||
onnx_path,
|
||||
output_path,
|
||||
deploy_cfg_path,
|
||||
model_cfg_path,
|
||||
dataset_file=dataset_file)
|
||||
|
||||
backend_files.append(output_path)
|
||||
elif backend == Backend.ASCEND:
|
||||
from mmdeploy.apis.ascend import from_onnx
|
||||
|
||||
ascend_pipeline_funcs = [from_onnx]
|
||||
PIPELINE_MANAGER.set_log_level(log_level, ascend_pipeline_funcs)
|
||||
|
||||
model_inputs = get_model_inputs(deploy_cfg)
|
||||
|
||||
om_files = []
|
||||
for model_id, onnx_path in enumerate(ir_files):
|
||||
om_path = osp.splitext(onnx_path)[0] + '.om'
|
||||
from_onnx(onnx_path, args.work_dir, model_inputs[model_id])
|
||||
om_files.append(om_path)
|
||||
backend_files = om_files
|
||||
from mmdeploy.utils import (get_common_config, get_normalization,
|
||||
get_quantization_config,
|
||||
get_rknn_quantization)
|
||||
quantization_cfg = get_quantization_config(deploy_cfg)
|
||||
common_params = get_common_config(deploy_cfg)
|
||||
if get_rknn_quantization(deploy_cfg) is True:
|
||||
transform = get_normalization(model_cfg)
|
||||
common_params.update(
|
||||
dict(
|
||||
mean_values=[transform['mean']],
|
||||
std_values=[transform['std']]))
|
||||
|
||||
dataset_file = tempfile.NamedTemporaryFile(suffix='.txt').name
|
||||
with open(dataset_file, 'w') as f:
|
||||
f.writelines([osp.abspath(args.img)])
|
||||
quantization_cfg.setdefault('dataset', dataset_file)
|
||||
if backend == Backend.ASCEND:
|
||||
# TODO: Add this to backend manager in the future
|
||||
if args.dump_info:
|
||||
from mmdeploy.backend.ascend import update_sdk_pipeline
|
||||
update_sdk_pipeline(args.work_dir)
|
||||
|
||||
elif backend == Backend.COREML:
|
||||
from mmdeploy.apis.coreml import from_torchscript
|
||||
coreml_pipeline_funcs = [from_torchscript]
|
||||
PIPELINE_MANAGER.set_log_level(log_level, coreml_pipeline_funcs)
|
||||
# convert to backend
|
||||
PIPELINE_MANAGER.set_log_level(log_level, [to_backend])
|
||||
if backend == Backend.TENSORRT:
|
||||
PIPELINE_MANAGER.enable_multiprocess(True, [to_backend])
|
||||
backend_files = to_backend(
|
||||
backend,
|
||||
ir_files,
|
||||
work_dir=args.work_dir,
|
||||
deploy_cfg=deploy_cfg,
|
||||
log_level=log_level,
|
||||
device=args.device,
|
||||
uri=args.uri)
|
||||
|
||||
coreml_files = []
|
||||
for model_id, torchscript_path in enumerate(ir_files):
|
||||
torchscript_name = osp.splitext(osp.split(torchscript_path)[1])[0]
|
||||
output_file_prefix = osp.join(args.work_dir, torchscript_name)
|
||||
# ncnn quantization
|
||||
if backend == Backend.NCNN and quant:
|
||||
from onnx2ncnn_quant_table import get_table
|
||||
|
||||
from_torchscript(model_id, torchscript_path, output_file_prefix,
|
||||
deploy_cfg, coreml_files)
|
||||
from mmdeploy.apis.ncnn import get_quant_model_file, ncnn2int8
|
||||
model_param_paths = backend_files[::2]
|
||||
model_bin_paths = backend_files[1::2]
|
||||
backend_files = []
|
||||
for onnx_path, model_param_path, model_bin_path in zip(
|
||||
ir_files, model_param_paths, model_bin_paths):
|
||||
|
||||
backend_files = coreml_files
|
||||
elif backend == Backend.TVM:
|
||||
import copy
|
||||
deploy_cfg, model_cfg = load_config(deploy_cfg_path,
|
||||
model_cfg_path)
|
||||
quant_onnx, quant_table, quant_param, quant_bin = get_quant_model_file( # noqa: E501
|
||||
onnx_path, args.work_dir)
|
||||
|
||||
from mmdeploy.apis.tvm import from_onnx, get_library_ext
|
||||
PIPELINE_MANAGER.set_log_level(log_level, [from_onnx])
|
||||
model_inputs = get_model_inputs(deploy_cfg)
|
||||
create_process(
|
||||
'ncnn quant table',
|
||||
target=get_table,
|
||||
args=(onnx_path, deploy_cfg, model_cfg, quant_onnx,
|
||||
quant_table, quant_image_dir, args.device),
|
||||
kwargs=dict(),
|
||||
ret_value=ret_value)
|
||||
|
||||
if args.device.startswith('cuda'):
|
||||
target = 'cuda'
|
||||
else:
|
||||
target = 'llvm'
|
||||
|
||||
lib_ext = get_library_ext()
|
||||
|
||||
tvm_files = []
|
||||
for model_id, onnx_path in enumerate(ir_files):
|
||||
model_input = copy.deepcopy(model_inputs[model_id])
|
||||
use_vm = model_input.get('use_vm', False)
|
||||
if 'target' not in model_input['tuner']:
|
||||
model_input['tuner']['target'] = target
|
||||
lib_path = osp.splitext(onnx_path)[0] + lib_ext
|
||||
code_path = osp.splitext(
|
||||
onnx_path)[0] + '.code' if use_vm else None
|
||||
model_input['output_file'] = lib_path
|
||||
model_input['onnx_model'] = onnx_path
|
||||
model_input['bytecode_file'] = code_path
|
||||
|
||||
# create calibration dataset
|
||||
if 'qconfig' in model_input:
|
||||
calib_path = osp.join(args.work_dir, calib_filename)
|
||||
from mmdeploy.backend.tvm import HDF5Dataset
|
||||
partition_type = 'end2end' if partition_cfgs is None \
|
||||
else onnx_name
|
||||
dataset = HDF5Dataset(
|
||||
calib_path,
|
||||
model_input['shape'],
|
||||
model_type=partition_type,
|
||||
device=target)
|
||||
model_input['dataset'] = dataset()
|
||||
|
||||
from_onnx(**model_input)
|
||||
|
||||
tvm_files += [lib_path, code_path]
|
||||
|
||||
backend_files = tvm_files
|
||||
create_process(
|
||||
'ncnn_int8',
|
||||
target=ncnn2int8,
|
||||
args=(model_param_path, model_bin_path, quant_table,
|
||||
quant_param, quant_bin),
|
||||
kwargs=dict(),
|
||||
ret_value=ret_value)
|
||||
backend_files += [quant_param, quant_bin]
|
||||
|
||||
if args.test_img is None:
|
||||
args.test_img = args.img
|
||||
|
|