[Refactor][API2.0] API refactor 2.0 (#529)

* [Refactor][API2.0] Add onnx export and jit trace (#419): first commit; add async call; add new API onnx export and jit trace; add decorator; fix ci; fix torchscript ci; fix loader; better pipeline manager; remove comment, better import; add kwargs; remove comment; better pipeline manager; remove print
* [Refactor][API2.0] API partition calibration (#433): first commit; add async call; add new API onnx export and jit trace; add decorator; fix ci; fix torchscript ci; fix loader; better pipeline manager; remove comment, better import; add partition; move calibration; better create_calib_table; better deploy; add kwargs; remove comment; better pipeline manager; rename API, remove redundant variable, and misc
* [Refactor][API2.0] API ncnn openvino (#435): first commit; add async call; add new API onnx export and jit trace; add decorator; fix ci; fix torchscript ci; fix loader; better pipeline manager; remove comment, better import; add ncnn API; finish ncnn API; add openvino support; add kwargs; remove comment; better pipeline manager; merge fix; merge util and onnx2ncnn; fix docstring
* [Refactor][API2.0] API for TensorRT (#519): first commit; add async call; add new API onnx export and jit trace; add decorator; fix ci; fix torchscript ci; fix loader; better pipeline manager; remove comment, better import; add partition; move calibration; better create_calib_table; better deploy; add kwargs; remove comment; add TensorRT API; better pipeline manager; add TensorRT new API; remove print; rename API, remove redundant variable, and misc; add docstring
* [Refactor][API2.0] API ppl other (#528): first commit; add async call; add new API onnx export and jit trace; add decorator; fix ci; fix torchscript ci; fix loader; better pipeline manager; remove comment, better import; add kwargs; add new APIs for pplnn sdk and misc; remove comment; better pipeline manager; merge fix; update tools/onnx2pplnn.py; rename function

parent aabab46d8a
commit b32fc41bed
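For orientation, here is a minimal sketch of how the refactored API 2.0 is meant to be called. It is an illustration based on the signatures visible in this diff; the config, image, and checkpoint paths are placeholders, not values from this PR.

```python
# Hedged sketch of the API 2.0 flow introduced by this refactor.
from mmdeploy.apis import torch2onnx

torch2onnx(
    img='demo.jpg',                      # dummy input used to build model inputs (placeholder)
    work_dir='work_dir',
    save_file='end2end.onnx',
    deploy_cfg='deploy_cfg.py',          # hypothetical deploy config path
    model_cfg='model_cfg.py',            # hypothetical model config path
    model_checkpoint='model.pth',        # hypothetical checkpoint path
    device='cuda:0')

# Backend conversion now goes through a uniform `from_onnx` entry point
# (registered as a pipeline function) instead of the old backend-specific
# helpers such as create_trt_engine, onnx2ncnn, or onnx2pplnn.
from mmdeploy.apis.tensorrt import from_onnx  # noqa: F401
```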

@@ -143,3 +143,6 @@ bin/
#
!docs/zh_cn/build
!docs/en/build

# ncnn
mmdeploy/backend/ncnn/onnx2ncnn

@@ -38,11 +38,11 @@ The backends in MMDeploy must support the ONNX. The backend loads the ".onnx" fi

if is_available():
    from .utils import create_trt_engine, load_trt_engine, save_trt_engine
    from .utils import from_onnx, load, save
    from .wrapper import TRTWrapper

    __all__ = [
        'create_trt_engine', 'save_trt_engine', 'load_trt_engine', 'TRTWrapper'
        'from_onnx', 'save', 'load', 'TRTWrapper'
    ]
```

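The rename above maps the old TensorRT helpers onto the generic `from_onnx`/`save`/`load` trio. A hedged before/after sketch; the keyword arguments are assumptions for illustration, not taken verbatim from this diff:

```python
# New names introduced by this PR (old: create_trt_engine / save_trt_engine / load_trt_engine).
from mmdeploy.apis.tensorrt import from_onnx, load, save

engine = from_onnx(
    'end2end.onnx',                  # ONNX model produced by torch2onnx
    'end2end',                       # output path prefix (assumed parameter)
    input_shapes=dict(               # assumed keyword for dynamic-shape ranges
        input=dict(
            min_shape=[1, 3, 320, 320],
            opt_shape=[1, 3, 800, 1344],
            max_shape=[1, 3, 1344, 1344])))
save(engine, 'end2end.engine')       # replaces save_trt_engine
engine = load('end2end.engine')      # replaces load_trt_engine
```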
@@ -38,11 +38,11 @@ MMDeploy 中的后端必须支持 ONNX,因此后端能直接加载“.onnx”

if is_available():
    from .utils import create_trt_engine, load_trt_engine, save_trt_engine
    from .utils import from_onnx, load, save
    from .wrapper import TRTWrapper

    __all__ = [
        'create_trt_engine', 'save_trt_engine', 'load_trt_engine', 'TRTWrapper'
        'from_onnx', 'save', 'load', 'TRTWrapper'
    ]
```

@@ -1,14 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .calibration import create_calib_table
from .extract_model import extract_model
from .inference import inference_model
from .pytorch2onnx import torch2onnx, torch2onnx_impl
from .pytorch2torchscript import torch2torchscript, torch2torchscript_impl
from .utils import build_task_processor, get_predefined_partition_cfg
from .visualize import visualize_model

__all__ = [
    'create_calib_table', 'extract_model', 'inference_model', 'torch2onnx',
    'torch2onnx_impl', 'torch2torchscript', 'torch2torchscript_impl',
    'build_task_processor', 'get_predefined_partition_cfg', 'visualize_model'
]
# mmcv dependency
try:
    from .calibration import create_calib_input_data
    from .extract_model import extract_model
    from .inference import inference_model
    from .pytorch2onnx import torch2onnx
    from .pytorch2torchscript import torch2torchscript
    from .utils import build_task_processor, get_predefined_partition_cfg
    from .visualize import visualize_model

    __all__ = [
        'create_calib_input_data', 'extract_model', 'inference_model',
        'torch2onnx', 'torch2torchscript', 'build_task_processor',
        'get_predefined_partition_cfg', 'visualize_model'
    ]
except Exception:
    pass

@@ -1,107 +1,78 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Union
|
||||
|
||||
import h5py
|
||||
import mmcv
|
||||
import torch
|
||||
from mmcv.parallel import MMDataParallel
|
||||
|
||||
from mmdeploy.core import (RewriterContext, patch_model,
|
||||
reset_mark_function_count)
|
||||
from mmdeploy.core import patch_model
|
||||
from mmdeploy.utils import cfg_apply_marks, load_config
|
||||
from .core import PIPELINE_MANAGER, no_mp
|
||||
from .utils import create_calib_input_data as create_calib_input_data_impl
|
||||
|
||||
|
||||
def create_calib_table(calib_file: str,
|
||||
deploy_cfg: Union[str, mmcv.Config],
|
||||
model_cfg: Union[str, mmcv.Config],
|
||||
model_checkpoint: Optional[str] = None,
|
||||
dataset_cfg: Optional[Union[str, mmcv.Config]] = None,
|
||||
dataset_type: str = 'val',
|
||||
device: str = 'cuda:0',
|
||||
**kwargs) -> None:
|
||||
"""Create calibration table.
|
||||
|
||||
Examples:
|
||||
>>> from mmdeploy.apis import create_calib_table
|
||||
>>> from mmdeploy.utils import get_calib_filename, load_config
|
||||
>>> deploy_cfg = 'configs/mmdet/detection/' \
|
||||
'detection_tensorrt-int8_dynamic-320x320-1344x1344.py'
|
||||
>>> deploy_cfg = load_config(deploy_cfg)[0]
|
||||
>>> calib_file = get_calib_filename(deploy_cfg)
|
||||
>>> model_cfg = 'mmdetection/configs/fcos/' \
|
||||
'fcos_r50_caffe_fpn_gn-head_1x_coco.py'
|
||||
>>> model_checkpoint = 'checkpoints/' \
|
||||
'fcos_r50_caffe_fpn_gn-head_1x_coco-821213aa.pth'
|
||||
>>> create_calib_table(calib_file, deploy_cfg, \
|
||||
model_cfg, model_checkpoint, device='cuda:0')
|
||||
@PIPELINE_MANAGER.register_pipeline()
|
||||
def create_calib_input_data(calib_file: str,
|
||||
deploy_cfg: Union[str, mmcv.Config],
|
||||
model_cfg: Union[str, mmcv.Config],
|
||||
model_checkpoint: Optional[str] = None,
|
||||
dataset_cfg: Optional[Union[str,
|
||||
mmcv.Config]] = None,
|
||||
dataset_type: str = 'val',
|
||||
device: str = 'cpu') -> None:
|
||||
"""Create dataset for post-training quantization.
|
||||
|
||||
Args:
|
||||
calib_file (str): Input calibration file.
|
||||
deploy_cfg (str | mmcv.Config): Deployment config.
|
||||
model_cfg (str | mmcv.Config): The model config.
|
||||
model_checkpoint (str): PyTorch model checkpoint, defaults to `None`.
|
||||
dataset_cfg (str | mmcv.Config): Dataset config, defaults to `None`
|
||||
dataset_type (str): A string specifying dataset type, e.g.: 'test',
|
||||
'val', defaults to 'val'.
|
||||
device (str): Specifying the device to run on, defaults to 'cuda:0'.
|
||||
calib_file (str): The output calibration data file.
|
||||
deploy_cfg (str | mmcv.Config): Deployment config file or
|
||||
Config object.
|
||||
model_cfg (str | mmcv.Config): Model config file or Config object.
|
||||
model_checkpoint (str): A checkpoint path of PyTorch model,
|
||||
defaults to `None`.
|
||||
dataset_cfg (Optional[Union[str, mmcv.Config]], optional): Model
|
||||
config to provide calibration dataset. If none, use `model_cfg`
|
||||
as the dataset config. Defaults to None.
|
||||
dataset_type (str, optional): The dataset type. Defaults to 'val'.
|
||||
device (str, optional): Device to create dataset. Defaults to 'cpu'.
|
||||
"""
|
||||
if dataset_cfg is None:
|
||||
dataset_cfg = model_cfg
|
||||
with no_mp():
|
||||
if dataset_cfg is None:
|
||||
dataset_cfg = model_cfg
|
||||
|
||||
# load cfg if necessary
|
||||
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
|
||||
device_id = torch.device(device).index
|
||||
if device_id is None:
|
||||
device_id = 0
|
||||
device_id = torch.device(device).index
|
||||
if device_id is None:
|
||||
device_id = 0
|
||||
|
||||
if dataset_cfg is None:
|
||||
dataset_cfg = model_cfg
|
||||
# load dataset_cfg if necessary
|
||||
dataset_cfg = load_config(dataset_cfg)[0]
|
||||
# load cfg if necessary
|
||||
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
|
||||
|
||||
from mmdeploy.apis.utils import build_task_processor
|
||||
task_processor = build_task_processor(model_cfg, deploy_cfg, device)
|
||||
if dataset_cfg is None:
|
||||
dataset_cfg = model_cfg
|
||||
|
||||
apply_marks = cfg_apply_marks(deploy_cfg)
|
||||
backend = 'default'
|
||||
model = task_processor.init_pytorch_model(model_checkpoint)
|
||||
dataset = task_processor.build_dataset(dataset_cfg, dataset_type)
|
||||
# load dataset_cfg if necessary
|
||||
dataset_cfg = load_config(dataset_cfg)[0]
|
||||
|
||||
# patch model
|
||||
patched_model = patch_model(model, cfg=deploy_cfg, backend=backend)
|
||||
from mmdeploy.apis.utils import build_task_processor
|
||||
task_processor = build_task_processor(model_cfg, deploy_cfg, device)
|
||||
|
||||
with h5py.File(calib_file, mode='w') as file:
|
||||
calib_data_group = file.create_group('calib_data')
|
||||
apply_marks = cfg_apply_marks(deploy_cfg)
|
||||
|
||||
model = task_processor.init_pytorch_model(model_checkpoint)
|
||||
dataset = task_processor.build_dataset(dataset_cfg, dataset_type)
|
||||
|
||||
# patch model
|
||||
patched_model = patch_model(model, cfg=deploy_cfg)
|
||||
|
||||
if not apply_marks:
|
||||
# create end2end group
|
||||
input_data_group = calib_data_group.create_group('end2end')
|
||||
input_group = input_data_group.create_group('input')
|
||||
dataloader = task_processor.build_dataloader(
|
||||
dataset, 1, 1, dist=False, shuffle=False)
|
||||
patched_model = MMDataParallel(patched_model, device_ids=[device_id])
|
||||
prog_bar = mmcv.ProgressBar(len(dataset))
|
||||
for data_id, input_data in enumerate(dataloader):
|
||||
|
||||
if not apply_marks:
|
||||
# save end2end data
|
||||
input_tensor = task_processor.get_tensor_from_input(input_data)
|
||||
input_ndarray = input_tensor.detach().cpu().numpy()
|
||||
input_group.create_dataset(
|
||||
str(data_id),
|
||||
shape=input_ndarray.shape,
|
||||
compression='gzip',
|
||||
compression_opts=4,
|
||||
data=input_ndarray)
|
||||
|
||||
with torch.no_grad(), RewriterContext(
|
||||
cfg=deploy_cfg,
|
||||
backend=backend,
|
||||
create_calib=True,
|
||||
calib_file=file,
|
||||
data_id=data_id):
|
||||
reset_mark_function_count()
|
||||
_ = task_processor.run_inference(patched_model, input_data)
|
||||
file.flush()
|
||||
|
||||
prog_bar.update()
|
||||
create_calib_input_data_impl(
|
||||
calib_file,
|
||||
patched_model,
|
||||
dataloader,
|
||||
get_tensor_func=task_processor.get_tensor_from_input,
|
||||
inference_func=task_processor.run_inference,
|
||||
model_partition=apply_marks,
|
||||
context_info=dict(cfg=deploy_cfg),
|
||||
device=device)
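
The hunk above is the whole public surface of the renamed calibration API. A minimal hedged call, reusing the config names from the old docstring example (the exact paths are placeholders):

```python
# Hedged sketch of calling the renamed calibration API.
from mmdeploy.apis import create_calib_input_data

create_calib_input_data(
    'calib_data.h5',  # output calibration data file
    deploy_cfg='configs/mmdet/detection/detection_tensorrt-int8_dynamic-320x320-1344x1344.py',
    model_cfg='mmdetection/configs/fcos/fcos_r50_caffe_fpn_gn-head_1x_coco.py',
    model_checkpoint='checkpoints/fcos_r50_caffe_fpn_gn-head_1x_coco-821213aa.pth',
    dataset_type='val',
    device='cuda:0')
```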

@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .pipeline_manager import PIPELINE_MANAGER, no_mp

__all__ = ['PIPELINE_MANAGER', 'no_mp']

@@ -0,0 +1,376 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from mmdeploy.utils import get_root_logger
|
||||
|
||||
try:
|
||||
import torch.multiprocessing as mp
|
||||
except Exception:
|
||||
import multiprocessing as mp
|
||||
|
||||
|
||||
def _get_func_name(func: Callable) -> str:
|
||||
"""get function name."""
|
||||
assert isinstance(func, Callable), f'{func} is not a Callable object.'
|
||||
_func_name = None
|
||||
if hasattr(func, '__qualname__'):
|
||||
_func_name = f'{func.__module__}.{func.__qualname__}'
|
||||
elif hasattr(func, '__class__'):
|
||||
_func_name = func.__class__
|
||||
else:
|
||||
_func_name = str(func)
|
||||
return _func_name
|
||||
|
||||
|
||||
class PipelineCaller:
|
||||
"""Classes to record the attribute of each pipeline function."""
|
||||
|
||||
def __init__(self,
|
||||
module_name: str,
|
||||
impl_name: str,
|
||||
func_name: Optional[str] = None,
|
||||
log_level: int = logging.DEBUG,
|
||||
is_multiprocess_available: bool = True) -> None:
|
||||
if func_name is not None:
|
||||
self._func_name = func_name
|
||||
else:
|
||||
self._func_name = impl_name
|
||||
# Can not save the function directly since multiprocessing with spawn mode
# requires all fields to be picklable.
|
||||
self._module_name = module_name
|
||||
self._impl_name = impl_name
|
||||
self._is_multiprocess_available = is_multiprocess_available
|
||||
self._enable_multiprocess = False
|
||||
self._mp_dict = None
|
||||
self._mp_async = False
|
||||
self._call_id = 0
|
||||
self._log_level = log_level
|
||||
self._input_hooks: List[Callable] = []
|
||||
self._output_hooks: List[Callable] = []
|
||||
|
||||
@property
|
||||
def is_multiprocess_available(self) -> bool:
|
||||
"""check if multiprocess is available for this pipeline."""
|
||||
return self._is_multiprocess_available
|
||||
|
||||
@property
|
||||
def is_multiprocess(self) -> bool:
|
||||
"""check if this pipeline is multiprocess."""
|
||||
return self._enable_multiprocess
|
||||
|
||||
@property
|
||||
def input_hooks(self) -> List[Callable]:
|
||||
"""get input hooks."""
|
||||
return self._input_hooks
|
||||
|
||||
@property
|
||||
def output_hooks(self) -> List[Callable]:
|
||||
"""get output hooks."""
|
||||
return self._output_hooks
|
||||
|
||||
def pop_mp_output(self, call_id: int = None) -> Any:
|
||||
"""pop multiprocess output."""
|
||||
assert self._mp_dict is not None, 'mp_dict is None.'
|
||||
call_id = self._call_id if call_id is None else call_id
|
||||
assert call_id in self._mp_dict, \
|
||||
f'`{self._func_name}` with Call id: {call_id} failed.'
|
||||
ret = self._mp_dict[call_id]
|
||||
self._mp_dict.pop(call_id)
|
||||
return ret
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
do_multiprocess = self.is_multiprocess_available \
|
||||
and self.is_multiprocess\
|
||||
and self._mp_dict is not None
|
||||
|
||||
logger = get_root_logger(log_level=self._log_level)
|
||||
mp_log_str = 'subprocess' if do_multiprocess else 'main process'
|
||||
logger.log(self._log_level,
|
||||
f'Start pipeline {self._func_name} in {mp_log_str}')
|
||||
|
||||
for input_hook in self.input_hooks:
|
||||
args, kwargs = input_hook(*args, **kwargs)
|
||||
|
||||
module_name = self._module_name
|
||||
impl_name = self._impl_name
|
||||
# TODO: find another way to load function
|
||||
mod = importlib.import_module(module_name)
|
||||
func = getattr(mod, impl_name, None)
|
||||
assert func is not None, \
|
||||
f'Can not find implementation of {self._func_name}'
|
||||
ret = func(*args, **kwargs)
|
||||
for output_hook in self.output_hooks:
|
||||
ret = output_hook(ret)
|
||||
|
||||
if do_multiprocess:
|
||||
self._mp_dict[self._call_id] = ret
|
||||
|
||||
logger.log(self._log_level, f'Finish pipeline {self._func_name}')
|
||||
return ret
|
||||
|
||||
|
||||
class PipelineResult:
|
||||
"""The result of async pipeline."""
|
||||
|
||||
def __init__(self, manager: Any, call_id: int) -> None:
|
||||
self._manager = manager
|
||||
self._call_id = call_id
|
||||
|
||||
@property
|
||||
def call_id(self) -> int:
|
||||
return self._call_id
|
||||
|
||||
def get(self) -> Any:
|
||||
"""get result."""
|
||||
return self._manager.get_result_sync(self._call_id)
|
||||
|
||||
|
||||
FUNC_NAME_TYPE = Union[str, Callable]
|
||||
|
||||
|
||||
class PipelineManager:
|
||||
"""This is a tool to manager all pipeline functions."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._enable_multiprocess = True
|
||||
self._mp_manager = None
|
||||
self._callers: Dict[str, PipelineCaller] = dict()
|
||||
self._call_id = 0
|
||||
self._proc_async: Dict[int, (str, mp.Process)] = dict()
|
||||
|
||||
@property
|
||||
def mp_manager(self) -> Optional[mp.Manager]:
|
||||
"""get multiprocess manager."""
|
||||
return self._mp_manager
|
||||
|
||||
def get_caller(self, func_name: FUNC_NAME_TYPE) -> PipelineCaller:
|
||||
"""get caller of given function."""
|
||||
if isinstance(func_name, Callable):
|
||||
func_name = _get_func_name(func_name)
|
||||
assert func_name in self._callers, \
|
||||
f'{func_name} has not been registered.'
|
||||
return self._callers[func_name]
|
||||
|
||||
def __set_caller_val(self,
|
||||
val_name: str,
|
||||
val: Any,
|
||||
func_name: Optional[FUNC_NAME_TYPE] = None) -> None:
|
||||
"""helper to set any caller value."""
|
||||
if func_name is None:
|
||||
for func_name_ in self._callers:
|
||||
setattr(self.get_caller(func_name_), val_name, val)
|
||||
else:
|
||||
setattr(self.get_caller(func_name), val_name, val)
|
||||
|
||||
def _create_mp_manager(self) -> None:
|
||||
"""create multiprocess manager if not exists."""
|
||||
if self._mp_manager is None:
|
||||
self._mp_manager = mp.Manager()
|
||||
|
||||
def _enable_multiprocess_single(self,
|
||||
val: bool,
|
||||
func_name: FUNC_NAME_TYPE = None) -> None:
|
||||
"""implement of enable_multiprocess."""
|
||||
pipe_caller = self.get_caller(func_name)
|
||||
# check if multiprocess is available for this function
|
||||
if not pipe_caller.is_multiprocess_available:
|
||||
return
|
||||
pipe_caller._enable_multiprocess = val
|
||||
if val is True and self.mp_manager is not None:
|
||||
pipe_caller._mp_dict = self.mp_manager.dict()
|
||||
else:
|
||||
pipe_caller._mp_dict = None
|
||||
|
||||
def enable_multiprocess(
|
||||
self,
|
||||
val: bool,
|
||||
func_names: Optional[Union[FUNC_NAME_TYPE,
|
||||
Sequence[FUNC_NAME_TYPE]]] = None
|
||||
) -> None:
|
||||
"""enable multiprocess for pipeline function.
|
||||
|
||||
Args:
|
||||
val (bool): enable or disable multiprocess.
|
||||
func_names (str | List[str]): function names to enable. If
|
||||
func_name is None, all registered function will be enabled.
|
||||
"""
|
||||
if val is True:
|
||||
self._create_mp_manager()
|
||||
if func_names is None:
|
||||
for func_name in self._callers:
|
||||
self._enable_multiprocess_single(val, func_name=func_name)
|
||||
else:
|
||||
if isinstance(func_names, str):
|
||||
func_names = [func_names]
|
||||
for func_name in func_names:
|
||||
self._enable_multiprocess_single(val, func_name=func_name)
|
||||
|
||||
def set_mp_async(self,
|
||||
val: bool,
|
||||
func_name: Optional[FUNC_NAME_TYPE] = None) -> None:
|
||||
"""set multiprocess async of the pipeline function.
|
||||
|
||||
Args:
|
||||
val (bool): enable async call.
|
||||
func_name (str | None): function name to set. If func_name is
|
||||
None, all registered function will be set.
|
||||
"""
|
||||
self.__set_caller_val('_mp_async', val, func_name)
|
||||
|
||||
def set_log_level(
|
||||
self,
|
||||
level: int,
|
||||
func_names: Optional[Union[FUNC_NAME_TYPE,
|
||||
Sequence[FUNC_NAME_TYPE]]] = None
|
||||
) -> None:
|
||||
"""set log level of the pipeline function.
|
||||
|
||||
Args:
|
||||
level (int): the log level.
|
||||
func_names (str | List[str]): function names to set. If func_names
|
||||
is None, all registered function will be set.
|
||||
"""
|
||||
if isinstance(func_names, str):
|
||||
func_names = [func_names]
|
||||
for func_name in func_names:
|
||||
self.__set_caller_val('_log_level', level, func_name)
|
||||
|
||||
def get_input_hooks(self, func_name: FUNC_NAME_TYPE):
|
||||
"""get input hooks of given function name.
|
||||
|
||||
Args:
|
||||
func_name (str): function name.
|
||||
"""
|
||||
pipe_caller = self.get_caller(func_name)
|
||||
return pipe_caller.input_hooks
|
||||
|
||||
def get_output_hooks(self, func_name: FUNC_NAME_TYPE):
|
||||
"""get output hooks of given function name.
|
||||
|
||||
Args:
|
||||
func_name (str): function name.
|
||||
"""
|
||||
pipe_caller = self.get_caller(func_name)
|
||||
return pipe_caller.output_hooks
|
||||
|
||||
def call_function_local(self, func_name: FUNC_NAME_TYPE, *args,
|
||||
**kwargs) -> Any:
|
||||
"""call pipeline function.
|
||||
|
||||
Args:
|
||||
func_name (str): function name to be called.
|
||||
|
||||
Returns:
|
||||
Any: The result of call function
|
||||
"""
|
||||
pipe_caller = self.get_caller(func_name)
|
||||
pipe_caller._call_id = self._call_id
|
||||
self._call_id += 1
|
||||
return pipe_caller(*args, **kwargs)
|
||||
|
||||
def call_function_async(self, func_name: FUNC_NAME_TYPE, *args,
|
||||
**kwargs) -> int:
|
||||
"""call pipeline function.
|
||||
|
||||
Args:
|
||||
func_name (str): function name to be called.
|
||||
|
||||
Returns:
|
||||
int: Call id of this function
|
||||
"""
|
||||
pipe_caller = self.get_caller(func_name)
|
||||
assert pipe_caller.is_multiprocess, \
|
||||
f'multiprocess of {func_name} has not been enabled.'
|
||||
|
||||
call_id = self._call_id
|
||||
pipe_caller._call_id = call_id
|
||||
self._call_id += 1
|
||||
proc = mp.Process(target=pipe_caller, args=args, kwargs=kwargs)
|
||||
proc.start()
|
||||
self._proc_async[call_id] = (func_name, proc)
|
||||
|
||||
return call_id
|
||||
|
||||
def get_result_sync(self, call_id: int):
|
||||
"""get result of async call."""
|
||||
assert call_id in self._proc_async, f'Unknown call id: {call_id}'
|
||||
func_name, proc = self._proc_async.pop(call_id)
|
||||
proc.join()
|
||||
ret = self.get_caller(func_name).pop_mp_output(call_id)
|
||||
|
||||
return ret
|
||||
|
||||
def call_function(self, func_name: FUNC_NAME_TYPE, *args, **kwargs) -> Any:
|
||||
"""call pipeline function.
|
||||
|
||||
Args:
|
||||
func_name (str): function name to be called.
|
||||
|
||||
Returns:
|
||||
Any: The result of call function
|
||||
"""
|
||||
pipe_caller = self.get_caller(func_name)
|
||||
|
||||
if self._enable_multiprocess and pipe_caller.is_multiprocess:
|
||||
call_id = self.call_function_async(func_name, *args, **kwargs)
|
||||
if pipe_caller._mp_async:
|
||||
return PipelineResult(self, call_id)
|
||||
return self.get_result_sync(call_id)
|
||||
else:
|
||||
return self.call_function_local(func_name, *args, **kwargs)
|
||||
|
||||
def register_pipeline(self,
|
||||
is_multiprocess_available: bool = True,
|
||||
log_level: int = logging.DEBUG):
|
||||
"""register the pipeline function."""
|
||||
|
||||
def _register(func):
|
||||
assert isinstance(func, Callable), f'{func} is not Callable.'
|
||||
func_name_ = _get_func_name(func)
|
||||
|
||||
# save the implementation into the registry module
|
||||
impl_name = f'_pipe_{func.__name__}__impl_'
|
||||
frame = inspect.stack()[1]
|
||||
outer_mod = inspect.getmodule(frame[0])
|
||||
mod_name = outer_mod.__name__
|
||||
setattr(outer_mod, impl_name, func)
|
||||
|
||||
# create caller
|
||||
pipe_caller = PipelineCaller(
|
||||
mod_name,
|
||||
impl_name,
|
||||
func_name=func_name_,
|
||||
log_level=log_level,
|
||||
is_multiprocess_available=is_multiprocess_available)
|
||||
PIPELINE_MANAGER._callers[func_name_] = pipe_caller
|
||||
|
||||
# wrap call
|
||||
@wraps(func)
|
||||
def _wrap(*args, **kwargs):
|
||||
return self.call_function(func_name_, *args, **kwargs)
|
||||
|
||||
return _wrap
|
||||
|
||||
return _register
|
||||
|
||||
|
||||
PIPELINE_MANAGER = PipelineManager()
|
||||
|
||||
|
||||
class no_mp:
|
||||
"""The context manager used to disable multiprocess."""
|
||||
|
||||
def __init__(self, manager: PipelineManager = PIPELINE_MANAGER) -> None:
|
||||
self._manager = manager
|
||||
self._old_enable_multiprocess = True
|
||||
|
||||
def __enter__(self):
|
||||
self._old_enable_multiprocess = self._manager._enable_multiprocess
|
||||
self._manager._enable_multiprocess = False
|
||||
|
||||
def __exit__(self, type, val, tb):
|
||||
self._manager._enable_multiprocess = self._old_enable_multiprocess
|
|
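The new `pipeline_manager.py` above is the core of API 2.0: every conversion function is registered through `register_pipeline()` and dispatched via `PIPELINE_MANAGER`, optionally in a subprocess. A hedged sketch of the pattern (an illustration, not code from the PR):

```python
# Registration and dispatch pattern introduced by pipeline_manager.py.
from mmdeploy.apis.core import PIPELINE_MANAGER, no_mp


@PIPELINE_MANAGER.register_pipeline()
def convert(x: int) -> int:
    """Toy pipeline function standing in for torch2onnx, from_onnx, etc."""
    return x * 2


# Disable multiprocessing for the duration of the context and call directly.
with no_mp():
    assert convert(3) == 6

# Subprocess / async execution is opted into per function, e.g.:
#   PIPELINE_MANAGER.enable_multiprocess(True, func_names=[convert])
#   PIPELINE_MANAGER.set_mp_async(True, func_name=convert)
#   result = convert(5)   # returns a PipelineResult; result.get() -> 10
# (with the spawn start method the registered function must live in an
# importable module, as the comments in PipelineCaller note).
```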
@@ -1,33 +1,31 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from typing import Dict, Iterable, Optional, Union
|
||||
|
||||
import onnx
|
||||
import onnx.helper
|
||||
import onnx.utils
|
||||
|
||||
from mmdeploy.core.optimizers import (attribute_to_dict, create_extractor,
|
||||
get_new_name, parse_extractor_io_string,
|
||||
remove_identity, rename_value)
|
||||
from mmdeploy.utils import get_root_logger
|
||||
from .core import PIPELINE_MANAGER
|
||||
from .onnx import extract_partition
|
||||
|
||||
|
||||
@PIPELINE_MANAGER.register_pipeline()
|
||||
def extract_model(model: Union[str, onnx.ModelProto],
|
||||
start: Union[str, Iterable[str]],
|
||||
end: Union[str, Iterable[str]],
|
||||
start_marker: Union[str, Iterable[str]],
|
||||
end_marker: Union[str, Iterable[str]],
|
||||
start_name_map: Optional[Dict[str, str]] = None,
|
||||
end_name_map: Optional[Dict[str, str]] = None,
|
||||
dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None,
|
||||
save_file: Optional[str] = None) -> onnx.ModelProto:
|
||||
"""Extract sub-model from an ONNX model.
|
||||
"""Extract partition-model from an ONNX model.
|
||||
|
||||
The sub-model is defined by the names of the input and output tensors
|
||||
The partition-model is defined by the names of the input and output tensors
|
||||
exactly.
|
||||
|
||||
Examples:
|
||||
>>> from mmdeploy.apis import extract_model
|
||||
>>> model = 'work_dir/fastrcnn.onnx'
|
||||
>>> start = 'detector:input'
|
||||
>>> end = ['extract_feat:output', 'multiclass_nms[0]:input']
|
||||
>>> start_marker = 'detector:input'
|
||||
>>> end_marker = ['extract_feat:output', 'multiclass_nms[0]:input']
|
||||
>>> dynamic_axes = {
|
||||
'input': {
|
||||
0: 'batch',
|
||||
|
@@ -44,13 +42,14 @@ def extract_model(model: Union[str, onnx.ModelProto],
|
|||
}
|
||||
}
|
||||
>>> save_file = 'partition_model.onnx'
|
||||
>>> extract_model(model, start, end, dynamic_axes=dynamic_axes, \
|
||||
>>> extract_model(model, start_marker, end_marker, \
|
||||
dynamic_axes=dynamic_axes, \
|
||||
save_file=save_file)
|
||||
|
||||
Args:
|
||||
model (str | onnx.ModelProto): Input ONNX model to be extracted.
|
||||
start (str | Sequence[str]): Start marker(s) to extract.
|
||||
end (str | Sequence[str]): End marker(s) to extract.
|
||||
start_marker (str | Sequence[str]): Start marker(s) to extract.
|
||||
end_marker (str | Sequence[str]): End marker(s) to extract.
|
||||
start_name_map (Dict[str, str]): A mapping of start names, defaults to
|
||||
`None`.
|
||||
end_name_map (Dict[str, str]): A mapping of end names, defaults to
|
||||
|
@@ -61,142 +60,8 @@ def extract_model(model: Union[str, onnx.ModelProto],
|
|||
`None`.
|
||||
|
||||
Returns:
|
||||
onnx.ModelProto: The extracted sub-model.
|
||||
onnx.ModelProto: The extracted model.
|
||||
"""
|
||||
if isinstance(model, str):
|
||||
model = onnx.load(model)
|
||||
|
||||
num_value_info = len(model.graph.value_info)
|
||||
inputs = []
|
||||
outputs = []
|
||||
logger = get_root_logger()
|
||||
if not isinstance(start, (list, tuple)):
|
||||
start = [start]
|
||||
for s in start:
|
||||
start_name, func_id, start_type = parse_extractor_io_string(s)
|
||||
for node in model.graph.node:
|
||||
if node.op_type == 'Mark':
|
||||
attr = attribute_to_dict(node.attribute)
|
||||
if attr['func'] == start_name and attr[
|
||||
'type'] == start_type and attr['func_id'] == func_id:
|
||||
name = node.input[0]
|
||||
if name not in inputs:
|
||||
new_name = get_new_name(
|
||||
attr, mark_name=s, name_map=start_name_map)
|
||||
rename_value(model, name, new_name)
|
||||
if not any([
|
||||
v_info.name == new_name
|
||||
for v_info in model.graph.value_info
|
||||
]):
|
||||
new_val_info = onnx.helper.make_tensor_value_info(
|
||||
new_name, attr['dtype'], attr['shape'])
|
||||
model.graph.value_info.append(new_val_info)
|
||||
inputs.append(new_name)
|
||||
|
||||
logger.info(f'inputs: {", ".join(inputs)}')
|
||||
|
||||
# collect outputs
|
||||
if not isinstance(end, (list, tuple)):
|
||||
end = [end]
|
||||
for e in end:
|
||||
end_name, func_id, end_type = parse_extractor_io_string(e)
|
||||
for node in model.graph.node:
|
||||
if node.op_type == 'Mark':
|
||||
attr = attribute_to_dict(node.attribute)
|
||||
if attr['func'] == end_name and attr[
|
||||
'type'] == end_type and attr['func_id'] == func_id:
|
||||
name = node.output[0]
|
||||
if name not in outputs:
|
||||
new_name = get_new_name(
|
||||
attr, mark_name=e, name_map=end_name_map)
|
||||
rename_value(model, name, new_name)
|
||||
if not any([
|
||||
v_info.name == new_name
|
||||
for v_info in model.graph.value_info
|
||||
]):
|
||||
new_val_info = onnx.helper.make_tensor_value_info(
|
||||
new_name, attr['dtype'], attr['shape'])
|
||||
model.graph.value_info.append(new_val_info)
|
||||
outputs.append(new_name)
|
||||
|
||||
logger.info(f'outputs: {", ".join(outputs)}')
|
||||
|
||||
# replace Mark with Identity
|
||||
for node in model.graph.node:
|
||||
if node.op_type == 'Mark':
|
||||
del node.attribute[:]
|
||||
node.domain = ''
|
||||
node.op_type = 'Identity'
|
||||
|
||||
extractor = create_extractor(model)
|
||||
extracted_model = extractor.extract_model(inputs, outputs)
|
||||
|
||||
# remove all Identity, this may be done by onnx simplifier
|
||||
remove_identity(extracted_model)
|
||||
|
||||
# collect all used inputs
|
||||
used = set()
|
||||
for node in extracted_model.graph.node:
|
||||
for input in node.input:
|
||||
used.add(input)
|
||||
|
||||
for output in extracted_model.graph.output:
|
||||
used.add(output.name)
|
||||
|
||||
# delete unused inputs
|
||||
success = True
|
||||
while success:
|
||||
success = False
|
||||
for i, input in enumerate(extracted_model.graph.input):
|
||||
if input.name not in used:
|
||||
del extracted_model.graph.input[i]
|
||||
success = True
|
||||
break
|
||||
|
||||
# eliminate output without shape
|
||||
for xs in [extracted_model.graph.output]:
|
||||
for x in xs:
|
||||
if not x.type.tensor_type.shape.dim:
|
||||
logger.info(f'fixing output shape: {x.name}')
|
||||
x.CopyFrom(
|
||||
onnx.helper.make_tensor_value_info(
|
||||
x.name, x.type.tensor_type.elem_type, []))
|
||||
|
||||
# eliminate 0-batch dimension, dirty workaround for two-stage detectors
|
||||
for input in extracted_model.graph.input:
|
||||
if input.name in inputs:
|
||||
if input.type.tensor_type.shape.dim[0].dim_value == 0:
|
||||
input.type.tensor_type.shape.dim[0].dim_value = 1
|
||||
|
||||
# eliminate duplicated value_info for inputs
|
||||
success = True
|
||||
# num_value_info == 0 if dynamic shape
|
||||
if num_value_info == 0:
|
||||
while len(extracted_model.graph.value_info) > 0:
|
||||
extracted_model.graph.value_info.pop()
|
||||
while success:
|
||||
success = False
|
||||
for i, x in enumerate(extracted_model.graph.value_info):
|
||||
if x.name in inputs:
|
||||
del extracted_model.graph.value_info[i]
|
||||
success = True
|
||||
break
|
||||
|
||||
# dynamic shape support
|
||||
if dynamic_axes is not None:
|
||||
for input_node in extracted_model.graph.input:
|
||||
if input_node.name in dynamic_axes:
|
||||
axes = dynamic_axes[input_node.name]
|
||||
for k, v in axes.items():
|
||||
input_node.type.tensor_type.shape.dim[k].dim_value = 0
|
||||
input_node.type.tensor_type.shape.dim[k].dim_param = v
|
||||
for output_node in extracted_model.graph.output:
|
||||
for idx, dim in enumerate(output_node.type.tensor_type.shape.dim):
|
||||
dim.dim_value = 0
|
||||
dim.dim_param = f'dim_{idx}'
|
||||
|
||||
# save extract_model if save_file is given
|
||||
if save_file is not None:
|
||||
onnx.save(extracted_model, save_file)
|
||||
|
||||
return extracted_model
|
||||
return extract_partition(model, start_marker, end_marker, start_name_map,
|
||||
end_name_map, dynamic_axes, save_file)
|
||||
|
|
|
@@ -1,13 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.backend.ncnn import is_available, is_plugin_available
from mmdeploy.backend.ncnn import from_onnx as _from_onnx
from mmdeploy.backend.ncnn import is_available, is_custom_ops_available
from ..core import PIPELINE_MANAGER

__all__ = ['is_available', 'is_plugin_available']
from_onnx = PIPELINE_MANAGER.register_pipeline()(_from_onnx)

__all__ = ['is_available', 'is_custom_ops_available', 'from_onnx']

if is_available():
    from mmdeploy.backend.ncnn.onnx2ncnn import (get_output_model_file,
                                                 onnx2ncnn)
    from mmdeploy.backend.ncnn.quant import get_quant_model_file, ncnn2int8
    __all__ += [
        'onnx2ncnn', 'get_output_model_file', 'ncnn2int8',
        'get_quant_model_file'
    ]
    try:
        from mmdeploy.backend.ncnn.onnx2ncnn import get_output_model_file
        from mmdeploy.backend.ncnn.quant import get_quant_model_file, ncnn2int8
        __all__ += [
            'get_output_model_file', 'ncnn2int8', 'get_quant_model_file'
        ]
    except Exception:
        pass

@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .export import export
from .partition import extract_partition

__all__ = ['export', 'extract_partition']

@@ -0,0 +1,127 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from mmdeploy.apis.core import PIPELINE_MANAGER
|
||||
from mmdeploy.core import RewriterContext, patch_model
|
||||
from mmdeploy.utils import Backend, get_root_logger
|
||||
|
||||
|
||||
@PIPELINE_MANAGER.register_pipeline()
|
||||
def export(model: torch.nn.Module,
|
||||
args: Union[torch.Tensor, Tuple, Dict],
|
||||
output_path_prefix: str,
|
||||
backend: Union[Backend, str] = 'default',
|
||||
input_metas: Optional[Dict] = None,
|
||||
context_info: Dict = dict(),
|
||||
input_names: Optional[Sequence[str]] = None,
|
||||
output_names: Optional[Sequence[str]] = None,
|
||||
opset_version: int = 11,
|
||||
dynamic_axes: Optional[Dict] = None,
|
||||
verbose: bool = False,
|
||||
keep_initializers_as_inputs: Optional[bool] = None,
|
||||
**kwargs):
|
||||
"""Export a PyTorch model into ONNX format. This is a wrap of
|
||||
`torch.onnx.export` with some enhancement.
|
||||
|
||||
Examples:
|
||||
>>> from mmdeploy.apis.onnx import export
|
||||
>>>
|
||||
>>> model = create_model()
|
||||
>>> args = get_input_tensor()
|
||||
>>>
|
||||
>>> export(
|
||||
>>> model,
|
||||
>>> args,
|
||||
>>> 'place/to/save/model',
|
||||
>>> backend='tensorrt',
|
||||
>>> input_names=['input'],
|
||||
>>> output_names=['output'],
|
||||
>>> dynamic_axes={'input': {
|
||||
>>> 0: 'batch',
|
||||
>>> 2: 'height',
|
||||
>>> 3: 'width'
|
||||
>>> }})
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): the model to be exported.
|
||||
args (torch.Tensor|Tuple|Dict): Dummy input of the model.
|
||||
output_path_prefix (str): The output file prefix. The model will
|
||||
be saved to `<output_path_prefix>.onnx`.
|
||||
backend (Backend|str): Which backend the exported graph will be used with.
Different backends generate different graphs.
|
||||
input_metas (Dict): The constant inputs of the model.
|
||||
context_info (Dict): The information that would be used in the context
|
||||
of exporting.
|
||||
input_names (Sequence[str]): The input names of the model.
|
||||
output_names (Sequence[str]): The output names of the model.
|
||||
opset_version (int): The version of ONNX opset version. 11 as default.
|
||||
dynamic_axes (Dict): The information used to determine which axes are
|
||||
dynamic.
|
||||
verbose (bool): Enable verbose mode in `torch.onnx.export`.
|
||||
keep_initializers_as_inputs (bool): Whether we should add inputs for
|
||||
each initializer.
|
||||
"""
|
||||
output_path = output_path_prefix + '.onnx'
|
||||
|
||||
logger = get_root_logger()
|
||||
logger.info(f'Export PyTorch model to ONNX: {output_path}.')
|
||||
|
||||
def _add_or_update(cfg: dict, key: str, val: Any):
|
||||
if key in cfg and isinstance(cfg[key], dict) and isinstance(val, dict):
|
||||
cfg[key].update(val)
|
||||
else:
|
||||
cfg[key] = val
|
||||
|
||||
context_info = deepcopy(context_info)
|
||||
deploy_cfg = context_info.pop('deploy_cfg', dict())
|
||||
ir_config = dict(
|
||||
type='onnx',
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
opset_version=opset_version,
|
||||
dynamic_axes=dynamic_axes,
|
||||
verbose=verbose,
|
||||
keep_initializers_as_inputs=keep_initializers_as_inputs)
|
||||
_add_or_update(deploy_cfg, 'ir_config', ir_config)
|
||||
|
||||
if isinstance(backend, Backend):
|
||||
backend = backend.value
|
||||
backend_config = dict(type=backend)
|
||||
_add_or_update(deploy_cfg, 'backend_config', backend_config)
|
||||
|
||||
context_info['cfg'] = deploy_cfg
|
||||
if 'backend' not in context_info:
|
||||
context_info['backend'] = backend
|
||||
if 'opset' not in context_info:
|
||||
context_info['opset'] = opset_version
|
||||
|
||||
# patch model
|
||||
patched_model = patch_model(model, cfg=deploy_cfg, backend=backend)
|
||||
|
||||
with RewriterContext(**context_info), torch.no_grad():
|
||||
# patch input_metas
|
||||
if input_metas is not None:
|
||||
assert isinstance(
|
||||
input_metas, dict
|
||||
), f'Expect input_metas type is dict, get {type(input_metas)}.'
|
||||
model_forward = model.forward
|
||||
model.forward = partial(model.forward, **input_metas)
|
||||
|
||||
torch.onnx.export(
|
||||
patched_model,
|
||||
args,
|
||||
output_path,
|
||||
export_params=True,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
opset_version=opset_version,
|
||||
dynamic_axes=dynamic_axes,
|
||||
keep_initializers_as_inputs=keep_initializers_as_inputs,
|
||||
verbose=verbose)
|
||||
|
||||
if input_metas is not None:
|
||||
model.forward = model_forward
|
|
@@ -0,0 +1,205 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Iterable, Optional, Union
|
||||
|
||||
import onnx
|
||||
import onnx.helper
|
||||
import onnx.utils
|
||||
|
||||
from mmdeploy.apis.core import PIPELINE_MANAGER
|
||||
from mmdeploy.core.optimizers import (attribute_to_dict, create_extractor,
|
||||
get_new_name, parse_extractor_io_string,
|
||||
remove_identity, rename_value)
|
||||
from mmdeploy.utils import get_root_logger
|
||||
|
||||
|
||||
@PIPELINE_MANAGER.register_pipeline()
|
||||
def extract_partition(model: Union[str, onnx.ModelProto],
|
||||
start_marker: Union[str, Iterable[str]],
|
||||
end_marker: Union[str, Iterable[str]],
|
||||
start_name_map: Optional[Dict[str, str]] = None,
|
||||
end_name_map: Optional[Dict[str, str]] = None,
|
||||
dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None,
|
||||
save_file: Optional[str] = None) -> onnx.ModelProto:
|
||||
"""Extract partition-model from an ONNX model.
|
||||
|
||||
The partition-model is defined by the names of the input and output tensors
|
||||
exactly.
|
||||
|
||||
Examples:
|
||||
>>> from mmdeploy.apis import extract_model
|
||||
>>> model = 'work_dir/fastrcnn.onnx'
|
||||
>>> start_marker = 'detector:input'
|
||||
>>> end_marker = ['extract_feat:output', 'multiclass_nms[0]:input']
|
||||
>>> dynamic_axes = {
|
||||
'input': {
|
||||
0: 'batch',
|
||||
2: 'height',
|
||||
3: 'width'
|
||||
},
|
||||
'scores': {
|
||||
0: 'batch',
|
||||
1: 'num_boxes',
|
||||
},
|
||||
'boxes': {
|
||||
0: 'batch',
|
||||
1: 'num_boxes',
|
||||
}
|
||||
}
|
||||
>>> save_file = 'partition_model.onnx'
|
||||
>>> extract_partition(model, start_marker, end_marker, \
|
||||
dynamic_axes=dynamic_axes, \
|
||||
save_file=save_file)
|
||||
|
||||
Args:
|
||||
model (str | onnx.ModelProto): Input ONNX model to be extracted.
|
||||
start_marker (str | Sequence[str]): Start marker(s) to extract.
|
||||
end_marker (str | Sequence[str]): End marker(s) to extract.
|
||||
start_name_map (Dict[str, str]): A mapping of start names, defaults to
|
||||
`None`.
|
||||
end_name_map (Dict[str, str]): A mapping of end names, defaults to
|
||||
`None`.
|
||||
dynamic_axes (Dict[str, Dict[int, str]]): A dictionary to specify
|
||||
dynamic axes of input/output, defaults to `None`.
|
||||
save_file (str): A file to save the extracted model, defaults to
|
||||
`None`.
|
||||
|
||||
Returns:
|
||||
onnx.ModelProto: The extracted model.
|
||||
"""
|
||||
if isinstance(model, str):
|
||||
model = onnx.load(model)
|
||||
|
||||
num_value_info = len(model.graph.value_info)
|
||||
inputs = []
|
||||
outputs = []
|
||||
logger = get_root_logger()
|
||||
if not isinstance(start_marker, (list, tuple)):
|
||||
start_marker = [start_marker]
|
||||
for s in start_marker:
|
||||
start_name, func_id, start_type = parse_extractor_io_string(s)
|
||||
for node in model.graph.node:
|
||||
if node.op_type == 'Mark':
|
||||
attr = attribute_to_dict(node.attribute)
|
||||
if attr['func'] == start_name and attr[
|
||||
'type'] == start_type and attr['func_id'] == func_id:
|
||||
name = node.input[0]
|
||||
if name not in inputs:
|
||||
new_name = get_new_name(
|
||||
attr, mark_name=s, name_map=start_name_map)
|
||||
rename_value(model, name, new_name)
|
||||
if not any([
|
||||
v_info.name == new_name
|
||||
for v_info in model.graph.value_info
|
||||
]):
|
||||
new_val_info = onnx.helper.make_tensor_value_info(
|
||||
new_name, attr['dtype'], attr['shape'])
|
||||
model.graph.value_info.append(new_val_info)
|
||||
inputs.append(new_name)
|
||||
|
||||
logger.info(f'inputs: {", ".join(inputs)}')
|
||||
|
||||
# collect outputs
|
||||
if not isinstance(end_marker, (list, tuple)):
|
||||
end_marker = [end_marker]
|
||||
for e in end_marker:
|
||||
end_name, func_id, end_type = parse_extractor_io_string(e)
|
||||
for node in model.graph.node:
|
||||
if node.op_type == 'Mark':
|
||||
attr = attribute_to_dict(node.attribute)
|
||||
if attr['func'] == end_name and attr[
|
||||
'type'] == end_type and attr['func_id'] == func_id:
|
||||
name = node.output[0]
|
||||
if name not in outputs:
|
||||
new_name = get_new_name(
|
||||
attr, mark_name=e, name_map=end_name_map)
|
||||
rename_value(model, name, new_name)
|
||||
if not any([
|
||||
v_info.name == new_name
|
||||
for v_info in model.graph.value_info
|
||||
]):
|
||||
new_val_info = onnx.helper.make_tensor_value_info(
|
||||
new_name, attr['dtype'], attr['shape'])
|
||||
model.graph.value_info.append(new_val_info)
|
||||
outputs.append(new_name)
|
||||
|
||||
logger.info(f'outputs: {", ".join(outputs)}')
|
||||
|
||||
# replace Mark with Identity
|
||||
for node in model.graph.node:
|
||||
if node.op_type == 'Mark':
|
||||
del node.attribute[:]
|
||||
node.domain = ''
|
||||
node.op_type = 'Identity'
|
||||
|
||||
extractor = create_extractor(model)
|
||||
extracted_model = extractor.extract_model(inputs, outputs)
|
||||
|
||||
# remove all Identity, this may be done by onnx simplifier
|
||||
remove_identity(extracted_model)
|
||||
|
||||
# collect all used inputs
|
||||
used = set()
|
||||
for node in extracted_model.graph.node:
|
||||
for input in node.input:
|
||||
used.add(input)
|
||||
|
||||
for output in extracted_model.graph.output:
|
||||
used.add(output.name)
|
||||
|
||||
# delete unused inputs
|
||||
success = True
|
||||
while success:
|
||||
success = False
|
||||
for i, input in enumerate(extracted_model.graph.input):
|
||||
if input.name not in used:
|
||||
del extracted_model.graph.input[i]
|
||||
success = True
|
||||
break
|
||||
|
||||
# eliminate output without shape
|
||||
for xs in [extracted_model.graph.output]:
|
||||
for x in xs:
|
||||
if not x.type.tensor_type.shape.dim:
|
||||
logger.info(f'fixing output shape: {x.name}')
|
||||
x.CopyFrom(
|
||||
onnx.helper.make_tensor_value_info(
|
||||
x.name, x.type.tensor_type.elem_type, []))
|
||||
|
||||
# eliminate 0-batch dimension, dirty workaround for two-stage detectors
|
||||
for input in extracted_model.graph.input:
|
||||
if input.name in inputs:
|
||||
if input.type.tensor_type.shape.dim[0].dim_value == 0:
|
||||
input.type.tensor_type.shape.dim[0].dim_value = 1
|
||||
|
||||
# eliminate duplicated value_info for inputs
|
||||
success = True
|
||||
# num_value_info == 0 if dynamic shape
|
||||
if num_value_info == 0:
|
||||
while len(extracted_model.graph.value_info) > 0:
|
||||
extracted_model.graph.value_info.pop()
|
||||
while success:
|
||||
success = False
|
||||
for i, x in enumerate(extracted_model.graph.value_info):
|
||||
if x.name in inputs:
|
||||
del extracted_model.graph.value_info[i]
|
||||
success = True
|
||||
break
|
||||
|
||||
# dynamic shape support
|
||||
if dynamic_axes is not None:
|
||||
for input_node in extracted_model.graph.input:
|
||||
if input_node.name in dynamic_axes:
|
||||
axes = dynamic_axes[input_node.name]
|
||||
for k, v in axes.items():
|
||||
input_node.type.tensor_type.shape.dim[k].dim_value = 0
|
||||
input_node.type.tensor_type.shape.dim[k].dim_param = v
|
||||
for output_node in extracted_model.graph.output:
|
||||
for idx, dim in enumerate(output_node.type.tensor_type.shape.dim):
|
||||
dim.dim_value = 0
|
||||
dim.dim_param = f'dim_{idx}'
|
||||
|
||||
# save extract_model if save_file is given
|
||||
if save_file is not None:
|
||||
onnx.save(extracted_model, save_file)
|
||||
|
||||
return extracted_model
|
|
@@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.backend.onnxruntime import is_available, is_plugin_available
from mmdeploy.backend.onnxruntime import is_available, is_custom_ops_available

__all__ = ['is_available', 'is_plugin_available']
__all__ = ['is_available', 'is_custom_ops_available']

@@ -4,10 +4,14 @@ from mmdeploy.backend.openvino import is_available
__all__ = ['is_available']

if is_available():
    from mmdeploy.backend.openvino.onnx2openvino import (get_output_model_file,
                                                          onnx2openvino)
    from mmdeploy.backend.openvino.onnx2openvino import from_onnx as _from_onnx
    from mmdeploy.backend.openvino.onnx2openvino import get_output_model_file
    from ..core import PIPELINE_MANAGER

    from_onnx = PIPELINE_MANAGER.register_pipeline()(_from_onnx)

    from .utils import get_input_info_from_cfg, get_mo_options_from_cfg
    __all__ += [
        'onnx2openvino', 'get_output_model_file', 'get_input_info_from_cfg',
        'from_onnx', 'get_output_model_file', 'get_input_info_from_cfg',
        'get_mo_options_from_cfg'
    ]

@@ -4,6 +4,8 @@ from mmdeploy.backend.pplnn import is_available
__all__ = ['is_available']

if is_available():
    from mmdeploy.backend.pplnn.onnx2pplnn import onnx2pplnn
    from mmdeploy.backend.pplnn.onnx2pplnn import from_onnx as _from_onnx
    from ..core import PIPELINE_MANAGER
    from_onnx = PIPELINE_MANAGER.register_pipeline()(_from_onnx)

    __all__ += ['onnx2pplnn']
    __all__ += ['from_onnx']

@@ -1,60 +1,18 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
|
||||
from mmdeploy.core import RewriterContext, patch_model
|
||||
from mmdeploy.apis.core.pipeline_manager import no_mp
|
||||
from mmdeploy.utils import (get_backend, get_dynamic_axes, get_input_shape,
|
||||
get_onnx_config, load_config)
|
||||
from .core import PIPELINE_MANAGER
|
||||
from .onnx import export
|
||||
|
||||
|
||||
def torch2onnx_impl(model: torch.nn.Module, input: Union[torch.Tensor, Tuple],
|
||||
deploy_cfg: Union[str, mmcv.Config], output_file: str):
|
||||
"""Converting torch model to ONNX.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Input pytorch model.
|
||||
input (torch.Tensor | Tuple): Input tensor used to convert model.
|
||||
deploy_cfg (str | mmcv.Config): Deployment config file or
|
||||
Config object.
|
||||
output_file (str): Output file to save ONNX model.
|
||||
"""
|
||||
# load deploy_cfg if needed
|
||||
deploy_cfg = load_config(deploy_cfg)[0]
|
||||
|
||||
onnx_cfg = get_onnx_config(deploy_cfg)
|
||||
backend = get_backend(deploy_cfg).value
|
||||
opset_version = onnx_cfg.get('opset_version', 11)
|
||||
|
||||
input_names = onnx_cfg['input_names']
|
||||
output_names = onnx_cfg['output_names']
|
||||
axis_names = input_names + output_names
|
||||
dynamic_axes = get_dynamic_axes(deploy_cfg, axis_names)
|
||||
verbose = not onnx_cfg.get('strip_doc_string', True) or onnx_cfg.get(
|
||||
'verbose', False)
|
||||
|
||||
# patch model
|
||||
patched_model = patch_model(model, cfg=deploy_cfg, backend=backend)
|
||||
|
||||
with RewriterContext(
|
||||
cfg=deploy_cfg, backend=backend,
|
||||
opset=opset_version), torch.no_grad():
|
||||
torch.onnx.export(
|
||||
patched_model,
|
||||
input,
|
||||
output_file,
|
||||
export_params=onnx_cfg['export_params'],
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
opset_version=opset_version,
|
||||
dynamic_axes=dynamic_axes,
|
||||
keep_initializers_as_inputs=onnx_cfg[
|
||||
'keep_initializers_as_inputs'],
|
||||
verbose=verbose)
|
||||
|
||||
|
||||
@PIPELINE_MANAGER.register_pipeline()
|
||||
def torch2onnx(img: Any,
|
||||
work_dir: str,
|
||||
save_file: str,
|
||||
|
@@ -94,10 +52,10 @@ def torch2onnx(img: Any,
|
|||
# load deploy_cfg if necessary
|
||||
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
|
||||
mmcv.mkdir_or_exist(osp.abspath(work_dir))
|
||||
output_file = osp.join(work_dir, save_file)
|
||||
|
||||
input_shape = get_input_shape(deploy_cfg)
|
||||
|
||||
# create model an inputs
|
||||
from mmdeploy.apis import build_task_processor
|
||||
task_processor = build_task_processor(model_cfg, deploy_cfg, device)
|
||||
|
||||
|
@@ -106,8 +64,34 @@ def torch2onnx(img: Any,
|
|||
if not isinstance(model_inputs, torch.Tensor) and len(model_inputs) == 1:
|
||||
model_inputs = model_inputs[0]
|
||||
|
||||
torch2onnx_impl(
|
||||
torch_model,
|
||||
model_inputs,
|
||||
deploy_cfg=deploy_cfg,
|
||||
output_file=output_file)
|
||||
# export to onnx
|
||||
context_info = dict()
|
||||
context_info['deploy_cfg'] = deploy_cfg
|
||||
output_prefix = osp.join(work_dir,
|
||||
osp.splitext(osp.basename(save_file))[0])
|
||||
backend = get_backend(deploy_cfg).value
|
||||
|
||||
onnx_cfg = get_onnx_config(deploy_cfg)
|
||||
opset_version = onnx_cfg.get('opset_version', 11)
|
||||
|
||||
input_names = onnx_cfg['input_names']
|
||||
output_names = onnx_cfg['output_names']
|
||||
axis_names = input_names + output_names
|
||||
dynamic_axes = get_dynamic_axes(deploy_cfg, axis_names)
|
||||
verbose = not onnx_cfg.get('strip_doc_string', True) or onnx_cfg.get(
|
||||
'verbose', False)
|
||||
keep_initializers_as_inputs = onnx_cfg.get('keep_initializers_as_inputs',
|
||||
True)
|
||||
with no_mp():
|
||||
export(
|
||||
torch_model,
|
||||
model_inputs,
|
||||
output_path_prefix=output_prefix,
|
||||
backend=backend,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
context_info=context_info,
|
||||
opset_version=opset_version,
|
||||
dynamic_axes=dynamic_axes,
|
||||
verbose=verbose,
|
||||
keep_initializers_as_inputs=keep_initializers_as_inputs)
|
||||
|
|
|
@@ -1,73 +1,16 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from typing import Any, Optional, Sequence, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
from packaging.version import parse as version_parse
|
||||
|
||||
from mmdeploy.backend.torchscript import get_ops_path
|
||||
from mmdeploy.core import RewriterContext, patch_model
|
||||
from mmdeploy.utils import (IR, get_backend, get_input_shape, get_root_logger,
|
||||
load_config)
|
||||
|
||||
|
||||
def torch2torchscript_impl(model: torch.nn.Module,
|
||||
inputs: Union[torch.Tensor, Sequence[torch.Tensor]],
|
||||
deploy_cfg: Union[str,
|
||||
mmcv.Config], output_file: str):
|
||||
"""Converting torch model to torchscript.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Input pytorch model.
|
||||
inputs (torch.Tensor | Sequence[torch.Tensor]): Input tensors used to
|
||||
convert model.
|
||||
deploy_cfg (str | mmcv.Config): Deployment config file or
|
||||
Config object.
|
||||
output_file (str): Output file to save torchscript model.
|
||||
"""
|
||||
# load custom ops if exist
|
||||
custom_ops_path = get_ops_path()
|
||||
if osp.exists(custom_ops_path):
|
||||
torch.ops.load_library(custom_ops_path)
|
||||
|
||||
deploy_cfg = load_config(deploy_cfg)[0]
|
||||
|
||||
backend = get_backend(deploy_cfg).value
|
||||
|
||||
patched_model = patch_model(model, cfg=deploy_cfg, backend=backend)
|
||||
|
||||
with RewriterContext(
|
||||
cfg=deploy_cfg, backend=backend,
|
||||
ir=IR.TORCHSCRIPT), torch.no_grad(), torch.jit.optimized_execution(
|
||||
True):
|
||||
# for exporting models with weight that depends on inputs
|
||||
patched_model(*inputs) if isinstance(inputs, Sequence) \
|
||||
else patched_model(inputs)
|
||||
ts_model = torch.jit.trace(patched_model, inputs)
|
||||
|
||||
# perform optimize, note that optimizing models may trigger errors when
|
||||
# loading the saved .pt file, as described in
|
||||
# https://github.com/pytorch/pytorch/issues/62706
|
||||
logger = get_root_logger()
|
||||
logger.info('perform torchscript optimizer.')
|
||||
try:
|
||||
# custom optimizer
|
||||
from mmdeploy.backend.torchscript import ts_optimizer
|
||||
logger = get_root_logger()
|
||||
ts_optimizer.optimize_for_backend(
|
||||
ts_model._c, ir=IR.TORCHSCRIPT.value, backend=backend)
|
||||
except Exception:
|
||||
# use pytorch builtin optimizer
|
||||
ts_model = torch.jit.freeze(ts_model)
|
||||
torch_version = version_parse(torch.__version__)
|
||||
if torch_version.minor >= 9:
|
||||
ts_model = torch.jit.optimize_for_inference(ts_model)
|
||||
|
||||
# save model
|
||||
torch.jit.save(ts_model, output_file)
|
||||
from mmdeploy.apis.core.pipeline_manager import PIPELINE_MANAGER, no_mp
|
||||
from mmdeploy.utils import get_backend, get_input_shape, load_config
|
||||
from .torch_jit import trace
|
||||
|
||||
|
||||
@PIPELINE_MANAGER.register_pipeline()
|
||||
def torch2torchscript(img: Any,
|
||||
work_dir: str,
|
||||
save_file: str,
|
||||
|
@@ -92,7 +35,6 @@ def torch2torchscript(img: Any,
|
|||
# load deploy_cfg if necessary
|
||||
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
|
||||
mmcv.mkdir_or_exist(osp.abspath(work_dir))
|
||||
output_file = osp.join(work_dir, save_file)
|
||||
|
||||
input_shape = get_input_shape(deploy_cfg)
|
||||
|
||||
|
@@ -104,8 +46,15 @@ def torch2torchscript(img: Any,
|
|||
if not isinstance(model_inputs, torch.Tensor):
|
||||
model_inputs = model_inputs[0]
|
||||
|
||||
torch2torchscript_impl(
|
||||
torch_model,
|
||||
model_inputs,
|
||||
deploy_cfg=deploy_cfg,
|
||||
output_file=output_file)
|
||||
context_info = dict(deploy_cfg=deploy_cfg)
|
||||
backend = get_backend(deploy_cfg).value
|
||||
output_prefix = osp.join(work_dir, osp.splitext(save_file)[0])
|
||||
|
||||
with no_mp():
|
||||
trace(
|
||||
torch_model,
|
||||
model_inputs,
|
||||
output_path_prefix=output_prefix,
|
||||
backend=backend,
|
||||
context_info=context_info,
|
||||
check_trace=False)
|
||||
|
|
|
@@ -0,0 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.backend.sdk import is_available

__all__ = ['is_available']

if is_available():
    try:
        from mmdeploy.backend.sdk.export_info import export2SDK as _export2SDK
        from ..core import PIPELINE_MANAGER
        export2SDK = PIPELINE_MANAGER.register_pipeline()(_export2SDK)

        __all__ += ['export2SDK']
    except Exception:
        pass

@@ -1,9 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.backend.tensorrt import is_available, is_plugin_available
from mmdeploy.backend.tensorrt import from_onnx as _from_onnx
from mmdeploy.backend.tensorrt import (is_available, is_custom_ops_available,
                                       load, save)
from ..core import PIPELINE_MANAGER

__all__ = ['is_available', 'is_plugin_available']
from_onnx = PIPELINE_MANAGER.register_pipeline()(_from_onnx)

__all__ = [
    'is_available', 'is_custom_ops_available', 'from_onnx', 'save', 'load'
]

if is_available():
    from mmdeploy.backend.tensorrt.onnx2tensorrt import onnx2tensorrt
    try:
        from mmdeploy.backend.tensorrt.onnx2tensorrt import \
            onnx2tensorrt as _onnx2tensorrt

        __all__ += ['onnx2tensorrt']
        onnx2tensorrt = PIPELINE_MANAGER.register_pipeline()(_onnx2tensorrt)
        __all__ += ['onnx2tensorrt']
    except Exception:
        pass
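With this change, `mmdeploy.apis.tensorrt.from_onnx` is the pipeline-registered entry point for building an engine. A minimal usage sketch follows; the file paths and shape values are placeholders for illustration, and it assumes TensorRT and the MMDeploy custom ops are installed:

```python
# Minimal sketch of the new apis-level TensorRT converter (paths/shapes are placeholders).
from mmdeploy.apis.tensorrt import from_onnx, is_available

if is_available():
    from_onnx(
        'work_dir/end2end.onnx',
        'work_dir/end2end',  # the engine is written to <prefix>.engine
        input_shapes=dict(
            input=dict(
                min_shape=[1, 3, 320, 320],
                opt_shape=[1, 3, 800, 1344],
                max_shape=[1, 3, 1344, 1344])),
        fp16_mode=False,
        max_workspace_size=1 << 30)
```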
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.backend.torchscript import get_ops_path
from .trace import trace

__all__ = ['get_ops_path', 'trace']
@@ -0,0 +1,122 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch
from packaging.version import parse as version_parse

from mmdeploy.core import RewriterContext, patch_model
from mmdeploy.utils import IR, Backend, get_root_logger
from ..core import PIPELINE_MANAGER


@PIPELINE_MANAGER.register_pipeline()
def trace(func: torch.nn.Module,
          inputs: Union[torch.Tensor, Tuple],
          output_path_prefix: Optional[str] = None,
          backend: Union[Backend, str] = 'default',
          context_info: Dict = dict(),
          check_trace: bool = True,
          check_tolerance: float = 1e-05) -> torch.jit.TracedModule:
    """A wrapper of `torch.jit.trace` with some enhancement.

    Examples:
        >>> from mmdeploy.apis.torch_jit import trace
        >>>
        >>> func = create_model()
        >>> inputs = get_input_tensor()
        >>>
        >>> jit_model = trace(
        >>>     func,
        >>>     inputs,
        >>>     backend='torchscript',
        >>>     check_trace=False)
        >>>

    Args:
        func (torch.nn.Module): A Python function or `torch.nn.Module` that
            will be run with `example_inputs`.
        inputs (torch.Tensor, Tuple): A tuple of example inputs that will be
            passed to the function while tracing.
        output_path_prefix (str): The model would be serialized in
            `<output_path_prefix>.pth`, None if you don't want to
            save the model.
        backend (Backend|str): Which backend will the graph be used. Different
            backend would generate different graph.
        context_info (Dict): The information that would be used in the context
            of exporting.
        check_trace (bool): Check if the same inputs run through traced code
            produce the same outputs.
        check_tolerance (float): Floating-point comparison tolerance to use in
            the checker procedure.

    Returns:
        torch.jit.TracedModule: The traced torch jit model.
    """
    logger = get_root_logger()
    logger.info('Export PyTorch model to torchscript.')

    def _add_or_update(cfg: dict, key: str, val: Any):
        if key in cfg and isinstance(cfg[key], dict) and isinstance(val, dict):
            cfg[key].update(val)
        else:
            cfg[key] = val

    context_info = deepcopy(context_info)
    deploy_cfg = context_info.pop('deploy_cfg', dict())
    ir_config = dict(type='torchscript')
    _add_or_update(deploy_cfg, 'ir_config', ir_config)

    if isinstance(backend, Backend):
        backend = backend.value
    backend_config = dict(type=backend)
    _add_or_update(deploy_cfg, 'backend_config', backend_config)

    context_info['cfg'] = deploy_cfg
    if 'backend' not in context_info:
        context_info['backend'] = backend
    elif context_info['backend'] != backend:
        logger.warning(
            f'Find backend {context_info["backend"]} in context_info.'
            f' Expect {backend}.')
    if 'ir' not in context_info:
        context_info['ir'] = IR.TORCHSCRIPT
    elif context_info['ir'] != backend:
        logger.warning(f'Find ir {context_info["ir"]} in context_info.'
                       f' Expect {IR.TORCHSCRIPT}.')

    # patch model
    if isinstance(func, torch.nn.Module):
        func = patch_model(func, cfg=deploy_cfg, backend=backend)

    with RewriterContext(**context_info), torch.no_grad():
        # for exporting models with weight that depends on inputs
        func(*inputs) if isinstance(inputs, Sequence) \
            else func(inputs)
        ts_model = torch.jit.trace(
            func,
            inputs,
            check_trace=check_trace,
            check_tolerance=check_tolerance)

    logger.info('perform torchscript optimizer.')
    try:
        # custom optimizer
        from mmdeploy.backend.torchscript import ts_optimizer
        logger = get_root_logger()
        ts_optimizer.optimize_for_backend(
            ts_model._c, ir=IR.TORCHSCRIPT.value, backend=backend)
    except Exception:
        # use pytorch builtin optimizer
        ts_model = torch.jit.freeze(ts_model)
        torch_version = version_parse(torch.__version__)
        if torch_version.minor >= 9:
            ts_model = torch.jit.optimize_for_inference(ts_model)

    # save model
    if output_path_prefix is not None:
        output_path = output_path_prefix + '.pt'
        logger.info(f'Save PyTorch model: {output_path}.')
        torch.jit.save(ts_model, output_path)

    return ts_model
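For context, a minimal sketch of driving the new `trace` wrapper directly; the toy module, shapes, and output path below are illustrative and not part of the change:

```python
# Illustrative only: trace a toy module with the new wrapper.
import torch
from mmdeploy.apis.torch_jit import trace

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
dummy_input = torch.rand(1, 3, 32, 32)

ts_model = trace(
    model,
    dummy_input,
    output_path_prefix='work_dir/end2end',  # writes work_dir/end2end.pt
    backend='torchscript',
    check_trace=False)
```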
@@ -0,0 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .calibration import create_calib_input_data
from .utils import build_task_processor, get_predefined_partition_cfg

__all__ = [
    'create_calib_input_data', 'build_task_processor',
    'get_predefined_partition_cfg'
]
@@ -0,0 +1,90 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Callable, Dict, Optional

import h5py
import torch
import tqdm
from torch.utils.data import DataLoader

from mmdeploy.core import RewriterContext, reset_mark_function_count
from ..core import PIPELINE_MANAGER


@PIPELINE_MANAGER.register_pipeline()
def create_calib_input_data(calib_file: str,
                            model: torch.nn.Module,
                            dataloader: DataLoader,
                            get_tensor_func: Optional[Callable] = None,
                            inference_func: Optional[Callable] = None,
                            model_partition: bool = False,
                            context_info: Dict = dict(),
                            device: str = 'cpu') -> None:
    """Create calibration table.

    Examples:
        >>> from mmdeploy.apis.utils import create_calib_input_data
        >>> from mmdeploy.utils import get_calib_filename, load_config
        >>> deploy_cfg = 'configs/mmdet/detection/'
            'detection_tensorrt-int8_dynamic-320x320-1344x1344.py'
        >>> deploy_cfg = load_config(deploy_cfg)[0]
        >>> calib_file = get_calib_filename(deploy_cfg)
        >>> model_cfg = 'mmdetection/configs/fcos/'
            'fcos_r50_caffe_fpn_gn-head_1x_coco.py'
        >>> model_checkpoint = 'checkpoints/'
            'fcos_r50_caffe_fpn_gn-head_1x_coco-821213aa.pth'
        >>> create_calib_input_data(calib_file, deploy_cfg,
            model_cfg, model_checkpoint, device='cuda:0')

    Args:
        calib_file (str): Input calibration file.
        deploy_cfg (str | mmcv.Config): Deployment config.
        model_cfg (str | mmcv.Config): The model config.
        model_checkpoint (str): PyTorch model checkpoint, defaults to `None`.
        dataset_cfg (str | mmcv.Config): Dataset config, defaults to `None`
        dataset_type (str): A string specifying dataset type, e.g.: 'test',
            'val', defaults to 'val'.
        device (str): Specifying the device to run on, defaults to 'cpu'.
    """

    backend = 'default'

    with h5py.File(calib_file, mode='w') as file:
        calib_data_group = file.create_group('calib_data')

        if not model_partition:
            # create end2end group
            input_data_group = calib_data_group.create_group('end2end')
            input_group = input_data_group.create_group('input')
        for data_id, input_data in enumerate(tqdm.tqdm(dataloader)):

            if not model_partition:
                # save end2end data
                if get_tensor_func is not None:
                    input_tensor = get_tensor_func(input_data)
                else:
                    input_tensor = input_data
                input_ndarray = input_tensor.detach().cpu().numpy()
                input_group.create_dataset(
                    str(data_id),
                    shape=input_ndarray.shape,
                    compression='gzip',
                    compression_opts=4,
                    data=input_ndarray)
            else:
                context_info_ = deepcopy(context_info)
                if 'cfg' not in context_info:
                    context_info_['cfg'] = dict()
                context_info_['backend'] = backend
                context_info_['create_calib'] = True
                context_info_['calib_file'] = file
                context_info_['data_id'] = data_id

                with torch.no_grad(), RewriterContext(**context_info_):
                    reset_mark_function_count()
                    if inference_func is not None:
                        inference_func(model, input_data)
                    else:
                        model(input_data)

        file.flush()
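A minimal sketch of the new calibration API in end-to-end mode, assuming `h5py` is installed; the toy model, dataset, and file name are placeholders, not part of the change:

```python
# Illustrative sketch: write end-to-end calibration inputs to an HDF5 file.
import torch
from torch.utils.data import DataLoader, TensorDataset

from mmdeploy.apis.utils import create_calib_input_data

model = torch.nn.Conv2d(3, 8, 3).eval()
dataset = TensorDataset(torch.rand(8, 3, 32, 32))
dataloader = DataLoader(dataset, batch_size=1)

create_calib_input_data(
    'calib_data.h5',
    model,
    dataloader,
    get_tensor_func=lambda batch: batch[0],  # TensorDataset yields a tuple
    model_partition=False,
    device='cpu')
```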
@@ -3,6 +3,7 @@ import importlib
import os.path as osp

from .init_plugins import get_onnx2ncnn_path, get_ops_path
from .onnx2ncnn import from_onnx


def is_available():

@@ -19,7 +20,7 @@ def is_available():
    return has_pyncnn and osp.exists(onnx2ncnn)


def is_plugin_available():
def is_custom_ops_available():
    """Check whether ncnn extension and custom ops are installed.

    Returns:

@@ -31,10 +32,12 @@ def is_plugin_available():
    return has_pyncnn_ext and osp.exists(ncnn_ops_path)


__all__ = ['from_onnx']

if is_available():
    try:
        from .wrapper import NCNNWrapper

        __all__ = ['NCNNWrapper']
        __all__ += ['NCNNWrapper']
    except Exception:
        pass
@@ -1,8 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import tempfile
from subprocess import call
from typing import List
from typing import List, Union

import onnx

from .init_plugins import get_onnx2ncnn_path

@@ -32,7 +35,8 @@ def get_output_model_file(onnx_path: str, work_dir: str) -> List[str]:
    return [save_param, save_bin]


def onnx2ncnn(onnx_path: str, save_param: str, save_bin: str):
def from_onnx(onnx_model: Union[onnx.ModelProto, str],
              output_file_prefix: str):
    """Convert ONNX to ncnn.

    The inputs of ncnn include a model file and a weight file. We need to use

@@ -40,18 +44,24 @@ def onnx2ncnn(onnx_path: str, save_param: str, save_bin: str):
    a `.bin` file. The output files will save to work_dir.

    Example:
        >>> from mmdeploy.backend.ncnn.onnx2ncnn import onnx2ncnn
        >>> from mmdeploy.apis.ncnn import from_onnx
        >>> onnx_path = 'work_dir/end2end.onnx'
        >>> save_param = 'work_dir/end2end.param'
        >>> save_bin = 'work_dir/end2end.bin'
        >>> onnx2ncnn(onnx_path, save_param, save_bin)
        >>> output_file_prefix = 'work_dir/end2end'
        >>> from_onnx(onnx_path, output_file_prefix)

    Args:
        onnx_path (str): The path of the onnx model.
        save_param (str): The path to save the output `.param` file.
        save_bin (str): The path to save the output `.bin` file.
        onnx_path (ModelProto|str): The path of the onnx model.
        output_file_prefix (str): The path to save the output ncnn file.
    """

    onnx2ncnn_path = get_onnx2ncnn_path()
    if not isinstance(onnx_model, str):
        onnx_path = tempfile.NamedTemporaryFile(suffix='.onnx').name
        onnx.save(onnx_model, onnx_path)
    else:
        onnx_path = onnx_model

    save_param = output_file_prefix + '.param'
    save_bin = output_file_prefix + '.bin'

    onnx2ncnn_path = get_onnx2ncnn_path()
    call([onnx2ncnn_path, onnx_path, save_param, save_bin])
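The converter resolves the `onnx2ncnn` binary via `get_onnx2ncnn_path()` and writes `<prefix>.param` and `<prefix>.bin`. A minimal sketch, assuming the binary has been built and `work_dir/end2end.onnx` (a placeholder path) exists:

```python
# Minimal sketch of the new ncnn converter (paths are placeholders).
from mmdeploy.backend.ncnn import from_onnx, is_available

if is_available():
    from_onnx('work_dir/end2end.onnx', 'work_dir/end2end')
    # Produces work_dir/end2end.param and work_dir/end2end.bin
```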
@@ -15,7 +15,7 @@ def is_available():
    return importlib.util.find_spec('onnxruntime') is not None


def is_plugin_available():
def is_custom_ops_available():
    """Check whether ONNX Runtime custom ops are installed.

    Returns:
@ -1,11 +1,12 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import subprocess
|
||||
import tempfile
|
||||
from subprocess import PIPE, CalledProcessError, run
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, Optional, Sequence, Union
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
import onnx
|
||||
|
||||
from mmdeploy.utils import get_root_logger
|
||||
from .utils import ModelOptimizerOptions
|
||||
|
@ -55,30 +56,33 @@ def get_output_model_file(onnx_path: str, work_dir: str) -> str:
|
|||
return model_xml
|
||||
|
||||
|
||||
def onnx2openvino(input_info: Dict[str, Union[List[int], torch.Size]],
|
||||
output_names: List[str],
|
||||
onnx_path: str,
|
||||
work_dir: str,
|
||||
mo_options: Optional[ModelOptimizerOptions] = None):
|
||||
def from_onnx(onnx_model: Union[str, onnx.ModelProto],
|
||||
output_file_prefix: str,
|
||||
input_info: Dict[str, Sequence[int]],
|
||||
output_names: Sequence[str],
|
||||
mo_options: Optional[ModelOptimizerOptions] = None):
|
||||
"""Convert ONNX to OpenVINO.
|
||||
|
||||
Examples:
|
||||
>>> from mmdeploy.backend.openvino.onnx2openvino import onnx2openvino
|
||||
>>> from mmdeploy.apis.openvino import from_onnx
|
||||
>>> input_info = {'input': [1,3,800,1344]}
|
||||
>>> output_names = ['dets', 'labels']
|
||||
>>> onnx_path = 'work_dir/end2end.onnx'
|
||||
>>> work_dir = 'work_dir'
|
||||
>>> onnx2openvino(input_info, output_names, onnx_path, work_dir)
|
||||
>>> output_dir = 'work_dir'
|
||||
>>> from_onnx( onnx_path, output_dir, input_info, output_names)
|
||||
|
||||
Args:
|
||||
input_info (Dict[str, Union[List[int], torch.Size]]):
|
||||
onnx_model (str|ModelProto): The onnx model or its path.
|
||||
output_file_prefix (str): The path to the directory for saving
|
||||
the results.
|
||||
input_info (Dict[str, Sequence[int]]):
|
||||
The shape of each input.
|
||||
output_names (List[str]): Output names. Example: ['dets', 'labels'].
|
||||
onnx_path (str): The path to the onnx model.
|
||||
work_dir (str): The path to the directory for saving the results.
|
||||
output_names (Sequence[str]): Output names. Example:
|
||||
['dets', 'labels'].
|
||||
mo_options (None | ModelOptimizerOptions): The class with
|
||||
additional arguments for the Model Optimizer.
|
||||
"""
|
||||
work_dir = output_file_prefix
|
||||
input_names = ','.join(input_info.keys())
|
||||
input_shapes = ','.join(str(list(elem)) for elem in input_info.values())
|
||||
output = ','.join(output_names)
|
||||
|
@ -89,6 +93,12 @@ def onnx2openvino(input_info: Dict[str, Union[List[int], torch.Size]],
|
|||
raise RuntimeError(
|
||||
'OpenVINO Model Optimizer is not found or configured improperly')
|
||||
|
||||
if isinstance(onnx_model, str):
|
||||
onnx_path = onnx_model
|
||||
else:
|
||||
onnx_path = tempfile.NamedTemporaryFile(suffix='.onnx').name
|
||||
onnx.save(onnx_model, onnx_path)
|
||||
|
||||
mo_args = f'--input_model="{onnx_path}" '\
|
||||
f'--output_dir="{work_dir}" ' \
|
||||
f'--output="{output}" ' \
|
||||
|
|
|
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import importlib

from .utils import register_engines


def is_available():
    """Check whether pplnn is installed.

@@ -11,6 +13,8 @@ def is_available():
    return importlib.util.find_spec('pyppl') is not None


__all__ = ['register_engines']

if is_available():
    from .wrapper import PPLNNWrapper, register_engines
    __all__ = ['register_engines', 'PPLNNWrapper']
    from .wrapper import PPLNNWrapper
    __all__ += ['PPLNNWrapper']
@ -4,14 +4,14 @@ from typing import Optional, Sequence
|
|||
from pyppl import nn as pplnn
|
||||
|
||||
from mmdeploy.utils.device import parse_cuda_device_id
|
||||
from .wrapper import register_engines
|
||||
from .utils import register_engines
|
||||
|
||||
|
||||
def onnx2pplnn(algo_file: str,
|
||||
onnx_model: str,
|
||||
device: str = 'cuda:0',
|
||||
input_shapes: Optional[Sequence[Sequence[int]]] = None,
|
||||
**kwargs):
|
||||
def from_onnx(onnx_model: str,
|
||||
output_file_prefix: str,
|
||||
device: str = 'cuda:0',
|
||||
input_shapes: Optional[Sequence[Sequence[int]]] = None,
|
||||
**kwargs):
|
||||
"""Convert ONNX to PPLNN.
|
||||
|
||||
PPLNN is capable of optimizing onnx model. The optimized algorithm is saved
|
||||
|
@ -21,16 +21,18 @@ def onnx2pplnn(algo_file: str,
|
|||
own preferences.
|
||||
|
||||
Args:
|
||||
algo_file (str): File path to save PPLNN optimization algorithm.
|
||||
output_file_prefix (str): File path to save PPLNN optimization
|
||||
algorithm and ONNX file
|
||||
onnx_model (str): Input onnx model.
|
||||
device (str): A string specifying device, defaults to 'cuda:0'.
|
||||
input_shapes (Sequence[Sequence[int]] | None): Shapes for PPLNN
|
||||
optimization, default to None.
|
||||
|
||||
Examples:
|
||||
>>> from mmdeploy.apis.pplnn import onnx2pplnn
|
||||
>>> from mmdeploy.apis.pplnn import from_onnx
|
||||
>>>
|
||||
>>> onnx2pplnn(algo_file = 'example.json', onnx_model = 'example.onnx')
|
||||
>>> from_onnx(onnx_model = 'example.onnx',
|
||||
output_file_prefix = 'example')
|
||||
"""
|
||||
if device == 'cpu':
|
||||
device_id = -1
|
||||
|
@ -42,6 +44,8 @@ def onnx2pplnn(algo_file: str,
|
|||
input_shapes = [[1, 3, 224,
|
||||
224]] # PPLNN default shape for optimization
|
||||
|
||||
algo_file = output_file_prefix + '.json'
|
||||
onnx_output_path = output_file_prefix + '.onnx'
|
||||
engines = register_engines(
|
||||
device_id,
|
||||
disable_avx512=False,
|
||||
|
@ -52,3 +56,6 @@ def onnx2pplnn(algo_file: str,
|
|||
onnx_model, engines)
|
||||
assert runtime_builder is not None, 'Failed to create '\
|
||||
'OnnxRuntimeBuilder.'
|
||||
import shutil
|
||||
if onnx_output_path != onnx_model:
|
||||
shutil.copy2(onnx_model, onnx_output_path)
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import sys
|
||||
from typing import List, Sequence
|
||||
|
||||
import pyppl.common as pplcommon
|
||||
import pyppl.nn as pplnn
|
||||
|
||||
from mmdeploy.utils import get_root_logger
|
||||
|
||||
|
||||
def register_engines(device_id: int,
|
||||
disable_avx512: bool = False,
|
||||
quick_select: bool = False,
|
||||
input_shapes: Sequence[Sequence[int]] = None,
|
||||
export_algo_file: str = None,
|
||||
import_algo_file: str = None) -> List[pplnn.Engine]:
|
||||
"""Register engines for pplnn runtime.
|
||||
|
||||
Args:
|
||||
device_id (int): Specifying device index. `-1` for cpu.
|
||||
disable_avx512 (bool): Whether to disable avx512 for x86.
|
||||
Defaults to `False`.
|
||||
quick_select (bool): Whether to use default algorithms.
|
||||
Defaults to `False`.
|
||||
input_shapes (Sequence[Sequence[int]]): shapes for PPLNN optimization.
|
||||
export_algo_file (str): File path for exporting PPLNN optimization
|
||||
file.
|
||||
import_algo_file (str): File path for loading PPLNN optimization file.
|
||||
|
||||
Returns:
|
||||
list[pplnn.Engine]: A list of registered pplnn engines.
|
||||
"""
|
||||
engines = []
|
||||
logger = get_root_logger()
|
||||
if device_id == -1:
|
||||
x86_options = pplnn.X86EngineOptions()
|
||||
x86_engine = pplnn.X86EngineFactory.Create(x86_options)
|
||||
if not x86_engine:
|
||||
logger.error('Failed to create x86 engine')
|
||||
sys.exit(-1)
|
||||
|
||||
if disable_avx512:
|
||||
status = x86_engine.Configure(pplnn.X86_CONF_DISABLE_AVX512)
|
||||
if status != pplcommon.RC_SUCCESS:
|
||||
logger.error('x86 engine Configure() failed: ' +
|
||||
pplcommon.GetRetCodeStr(status))
|
||||
sys.exit(-1)
|
||||
|
||||
engines.append(pplnn.Engine(x86_engine))
|
||||
|
||||
else:
|
||||
cuda_options = pplnn.CudaEngineOptions()
|
||||
cuda_options.device_id = device_id
|
||||
|
||||
cuda_engine = pplnn.CudaEngineFactory.Create(cuda_options)
|
||||
if not cuda_engine:
|
||||
logger.error('Failed to create cuda engine.')
|
||||
sys.exit(-1)
|
||||
|
||||
if quick_select:
|
||||
status = cuda_engine.Configure(
|
||||
pplnn.CUDA_CONF_USE_DEFAULT_ALGORITHMS)
|
||||
if status != pplcommon.RC_SUCCESS:
|
||||
logger.error('cuda engine Configure() failed: ' +
|
||||
pplcommon.GetRetCodeStr(status))
|
||||
sys.exit(-1)
|
||||
|
||||
if input_shapes is not None:
|
||||
status = cuda_engine.Configure(pplnn.CUDA_CONF_SET_INPUT_DIMS,
|
||||
input_shapes)
|
||||
if status != pplcommon.RC_SUCCESS:
|
||||
logger.error(
|
||||
'cuda engine Configure(CUDA_CONF_SET_INPUT_DIMS) failed: '
|
||||
+ pplcommon.GetRetCodeStr(status))
|
||||
sys.exit(-1)
|
||||
|
||||
if export_algo_file is not None:
|
||||
status = cuda_engine.Configure(pplnn.CUDA_CONF_EXPORT_ALGORITHMS,
|
||||
export_algo_file)
|
||||
if status != pplcommon.RC_SUCCESS:
|
||||
logger.error(
|
||||
'cuda engine Configure(CUDA_CONF_EXPORT_ALGORITHMS) '
|
||||
'failed: ' + pplcommon.GetRetCodeStr(status))
|
||||
sys.exit(-1)
|
||||
|
||||
if import_algo_file is not None:
|
||||
status = cuda_engine.Configure(pplnn.CUDA_CONF_IMPORT_ALGORITHMS,
|
||||
import_algo_file)
|
||||
if status != pplcommon.RC_SUCCESS:
|
||||
logger.error(
|
||||
'cuda engine Configure(CUDA_CONF_IMPORT_ALGORITHMS) '
|
||||
'failed: ' + pplcommon.GetRetCodeStr(status))
|
||||
sys.exit(-1)
|
||||
|
||||
engines.append(pplnn.Engine(cuda_engine))
|
||||
|
||||
return engines
|
|
@ -1,6 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import sys
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
from typing import Dict, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
|
@ -8,98 +7,10 @@ import pyppl.common as pplcommon
|
|||
import pyppl.nn as pplnn
|
||||
import torch
|
||||
|
||||
from mmdeploy.utils import Backend, get_root_logger, parse_device_id
|
||||
from mmdeploy.utils import Backend, parse_device_id
|
||||
from mmdeploy.utils.timer import TimeCounter
|
||||
from ..base import BACKEND_WRAPPER, BaseWrapper
|
||||
|
||||
|
||||
def register_engines(device_id: int,
|
||||
disable_avx512: bool = False,
|
||||
quick_select: bool = False,
|
||||
input_shapes: Sequence[Sequence[int]] = None,
|
||||
export_algo_file: str = None,
|
||||
import_algo_file: str = None) -> List[pplnn.Engine]:
|
||||
"""Register engines for pplnn runtime.
|
||||
|
||||
Args:
|
||||
device_id (int): Specifying device index. `-1` for cpu.
|
||||
disable_avx512 (bool): Whether to disable avx512 for x86.
|
||||
Defaults to `False`.
|
||||
quick_select (bool): Whether to use default algorithms.
|
||||
Defaults to `False`.
|
||||
input_shapes (Sequence[Sequence[int]]): shapes for PPLNN optimization.
|
||||
export_algo_file (str): File path for exporting PPLNN optimization
|
||||
file.
|
||||
import_algo_file (str): File path for loading PPLNN optimization file.
|
||||
|
||||
Returns:
|
||||
list[pplnn.Engine]: A list of registered pplnn engines.
|
||||
"""
|
||||
engines = []
|
||||
logger = get_root_logger()
|
||||
if device_id == -1:
|
||||
x86_options = pplnn.X86EngineOptions()
|
||||
x86_engine = pplnn.X86EngineFactory.Create(x86_options)
|
||||
if not x86_engine:
|
||||
logger.error('Failed to create x86 engine')
|
||||
sys.exit(-1)
|
||||
|
||||
if disable_avx512:
|
||||
status = x86_engine.Configure(pplnn.X86_CONF_DISABLE_AVX512)
|
||||
if status != pplcommon.RC_SUCCESS:
|
||||
logger.error('x86 engine Configure() failed: ' +
|
||||
pplcommon.GetRetCodeStr(status))
|
||||
sys.exit(-1)
|
||||
|
||||
engines.append(pplnn.Engine(x86_engine))
|
||||
|
||||
else:
|
||||
cuda_options = pplnn.CudaEngineOptions()
|
||||
cuda_options.device_id = device_id
|
||||
|
||||
cuda_engine = pplnn.CudaEngineFactory.Create(cuda_options)
|
||||
if not cuda_engine:
|
||||
logger.error('Failed to create cuda engine.')
|
||||
sys.exit(-1)
|
||||
|
||||
if quick_select:
|
||||
status = cuda_engine.Configure(
|
||||
pplnn.CUDA_CONF_USE_DEFAULT_ALGORITHMS)
|
||||
if status != pplcommon.RC_SUCCESS:
|
||||
logger.error('cuda engine Configure() failed: ' +
|
||||
pplcommon.GetRetCodeStr(status))
|
||||
sys.exit(-1)
|
||||
|
||||
if input_shapes is not None:
|
||||
status = cuda_engine.Configure(pplnn.CUDA_CONF_SET_INPUT_DIMS,
|
||||
input_shapes)
|
||||
if status != pplcommon.RC_SUCCESS:
|
||||
logger.error(
|
||||
'cuda engine Configure(CUDA_CONF_SET_INPUT_DIMS) failed: '
|
||||
+ pplcommon.GetRetCodeStr(status))
|
||||
sys.exit(-1)
|
||||
|
||||
if export_algo_file is not None:
|
||||
status = cuda_engine.Configure(pplnn.CUDA_CONF_EXPORT_ALGORITHMS,
|
||||
export_algo_file)
|
||||
if status != pplcommon.RC_SUCCESS:
|
||||
logger.error(
|
||||
'cuda engine Configure(CUDA_CONF_EXPORT_ALGORITHMS) '
|
||||
'failed: ' + pplcommon.GetRetCodeStr(status))
|
||||
sys.exit(-1)
|
||||
|
||||
if import_algo_file is not None:
|
||||
status = cuda_engine.Configure(pplnn.CUDA_CONF_IMPORT_ALGORITHMS,
|
||||
import_algo_file)
|
||||
if status != pplcommon.RC_SUCCESS:
|
||||
logger.error(
|
||||
'cuda engine Configure(CUDA_CONF_IMPORT_ALGORITHMS) '
|
||||
'failed: ' + pplcommon.GetRetCodeStr(status))
|
||||
sys.exit(-1)
|
||||
|
||||
engines.append(pplnn.Engine(cuda_engine))
|
||||
|
||||
return engines
|
||||
from .utils import register_engines
|
||||
|
||||
|
||||
@BACKEND_WRAPPER.register_module(Backend.PPLNN.value)
|
||||
|
|
|
@@ -350,8 +350,8 @@ def get_detail(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config,
        calib_config=calib_config)


def dump_info(deploy_cfg: Union[str, mmcv.Config],
              model_cfg: Union[str, mmcv.Config], work_dir: str, pth: str):
def export2SDK(deploy_cfg: Union[str, mmcv.Config],
               model_cfg: Union[str, mmcv.Config], work_dir: str, pth: str):
    """Export information to SDK. This function dump `deploy.json`,
    `pipeline.json` and `detail.json` to work dir.
@@ -16,7 +16,7 @@ def is_available():
    return importlib.util.find_spec('tensorrt') is not None


def is_plugin_available():
def is_custom_ops_available():
    """Check whether TensorRT custom ops are installed.

    Returns:

@@ -27,12 +27,9 @@ def is_plugin_available():


if is_available():
    from .utils import create_trt_engine, load_trt_engine, save_trt_engine
    from .utils import from_onnx, load, save

    __all__ = [
        'create_trt_engine', 'save_trt_engine', 'load_trt_engine',
        'load_tensorrt_plugin'
    ]
    __all__ = ['from_onnx', 'save', 'load', 'load_tensorrt_plugin']

    try:
        # import wrapper if pytorch is available
@@ -8,7 +8,7 @@ import onnx
from mmdeploy.utils import (get_calib_filename, get_common_config,
                            get_model_inputs, load_config, parse_device_id)
from mmdeploy.utils.config_utils import get_ir_config
from .utils import create_trt_engine, get_trt_log_level, save_trt_engine
from .utils import from_onnx, get_trt_log_level


def onnx2tensorrt(work_dir: str,

@@ -72,8 +72,13 @@ def onnx2tensorrt(work_dir: str,
        but given: {device}'

    device_id = parse_device_id(device)
    engine = create_trt_engine(
    assert save_file.endswith(
        '.engine'
    ), 'Expect save file ends with `.engine`.' f' but get {save_file}'
    save_path = osp.join(work_dir, save_file)
    from_onnx(
        onnx_model,
        osp.splitext(save_path)[0],
        input_shapes=input_shapes,
        log_level=get_trt_log_level(),
        fp16_mode=final_params.get('fp16_mode', False),

@@ -81,5 +86,3 @@ def onnx2tensorrt(work_dir: str,
        int8_param=int8_param,
        max_workspace_size=final_params.get('max_workspace_size', 0),
        device_id=device_id)

    save_trt_engine(engine, osp.join(work_dir, save_file))
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Dict, Sequence, Union
from typing import Dict, Optional, Sequence, Union

import onnx
import tensorrt as trt

@@ -10,38 +10,68 @@ from mmdeploy.utils import get_root_logger
from .init_plugins import load_tensorrt_plugin


def create_trt_engine(onnx_model: Union[str, onnx.ModelProto],
                      input_shapes: Dict[str, Sequence[int]],
                      log_level: trt.Logger.Severity = trt.Logger.ERROR,
                      fp16_mode: bool = False,
                      int8_mode: bool = False,
                      int8_param: dict = None,
                      max_workspace_size: int = 0,
                      device_id: int = 0,
                      **kwargs) -> trt.ICudaEngine:
def save(engine: trt.ICudaEngine, path: str) -> None:
    """Serialize TensorRT engine to disk.

    Args:
        engine (tensorrt.ICudaEngine): TensorRT engine to be serialized.
        path (str): The absolute disk path to write the engine.
    """
    with open(path, mode='wb') as f:
        f.write(bytearray(engine.serialize()))


def load(path: str) -> trt.ICudaEngine:
    """Deserialize TensorRT engine from disk.

    Args:
        path (str): The disk path to read the engine.

    Returns:
        tensorrt.ICudaEngine: The TensorRT engine loaded from disk.
    """
    load_tensorrt_plugin()
    with trt.Logger() as logger, trt.Runtime(logger) as runtime:
        with open(path, mode='rb') as f:
            engine_bytes = f.read()
        engine = runtime.deserialize_cuda_engine(engine_bytes)
        return engine


def from_onnx(onnx_model: Union[str, onnx.ModelProto],
              output_file_prefix: str,
              input_shapes: Dict[str, Sequence[int]],
              max_workspace_size: int = 0,
              fp16_mode: bool = False,
              int8_mode: bool = False,
              int8_param: Optional[dict] = None,
              device_id: int = 0,
              log_level: trt.Logger.Severity = trt.Logger.ERROR,
              **kwargs) -> trt.ICudaEngine:
    """Create a tensorrt engine from ONNX.

    Args:
        onnx_model (str or onnx.ModelProto): Input onnx model to convert from.
        output_file_prefix (str): The path to save the output ncnn file.
        input_shapes (Dict[str, Sequence[int]]): The min/opt/max shape of
            each input.
        log_level (trt.Logger.Severity): The log level of TensorRT. Defaults to
            `trt.Logger.ERROR`.
        max_workspace_size (int): To set max workspace size of TensorRT engine.
            some tactics and layers need large workspace. Defaults to `0`.
        fp16_mode (bool): Specifying whether to enable fp16 mode.
            Defaults to `False`.
        int8_mode (bool): Specifying whether to enable int8 mode.
            Defaults to `False`.
        int8_param (dict): A dict of parameter int8 mode. Defaults to `None`.
        max_workspace_size (int): To set max workspace size of TensorRT engine.
            some tactics and layers need large workspace. Defaults to `0`.
        device_id (int): Choice the device to create engine. Defaults to `0`.
        log_level (trt.Logger.Severity): The log level of TensorRT. Defaults to
            `trt.Logger.ERROR`.

    Returns:
        tensorrt.ICudaEngine: The TensorRT engine created from onnx_model.

    Example:
        >>> from mmdeploy.apis.tensorrt import create_trt_engine
        >>> engine = create_trt_engine(
        >>> from mmdeploy.apis.tensorrt import from_onnx
        >>> engine = from_onnx(
        >>>             "onnx_model.onnx",
        >>>             {'input': {"min_shape" : [1, 3, 160, 160],
        >>>                        "opt_shape" : [1, 3, 320, 320],

@@ -121,37 +151,11 @@ def create_trt_engine(onnx_model: Union[str, onnx.ModelProto],
    engine = builder.build_engine(network, config)

    assert engine is not None, 'Failed to create TensorRT engine'

    save(engine, output_file_prefix + '.engine')
    return engine


def save_trt_engine(engine: trt.ICudaEngine, path: str) -> None:
    """Serialize TensorRT engine to disk.

    Args:
        engine (tensorrt.ICudaEngine): TensorRT engine to be serialized.
        path (str): The absolute disk path to write the engine.
    """
    with open(path, mode='wb') as f:
        f.write(bytearray(engine.serialize()))


def load_trt_engine(path: str) -> trt.ICudaEngine:
    """Deserialize TensorRT engine from disk.

    Args:
        path (str): The disk path to read the engine.

    Returns:
        tensorrt.ICudaEngine: The TensorRT engine loaded from disk.
    """
    load_tensorrt_plugin()
    with trt.Logger() as logger, trt.Runtime(logger) as runtime:
        with open(path, mode='rb') as f:
            engine_bytes = f.read()
        engine = runtime.deserialize_cuda_engine(engine_bytes)
        return engine


def get_trt_log_level() -> trt.Logger.Severity:
    """Get tensorrt log level from root logger.

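A short sketch of how the renamed backend helpers fit together: `from_onnx` builds and serializes the engine, while `save` and `load` handle explicit round trips. The paths and shapes below are placeholders and assume a working TensorRT installation:

```python
# Sketch of the save/load round trip around from_onnx (paths are placeholders).
from mmdeploy.backend.tensorrt import from_onnx, load, save

engine = from_onnx(
    'work_dir/end2end.onnx',
    'work_dir/end2end',  # from_onnx also writes work_dir/end2end.engine
    input_shapes=dict(
        input=dict(
            min_shape=[1, 3, 224, 224],
            opt_shape=[1, 3, 224, 224],
            max_shape=[1, 3, 224, 224])))

# Re-serialize the in-memory engine elsewhere, then load it back.
save(engine, 'work_dir/end2end_copy.engine')
engine = load('work_dir/end2end_copy.engine')
```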
@@ -8,7 +8,7 @@ from mmdeploy.utils import Backend
from mmdeploy.utils.timer import TimeCounter
from ..base import BACKEND_WRAPPER, BaseWrapper
from .init_plugins import load_tensorrt_plugin
from .utils import load_trt_engine
from .utils import load


def torch_dtype_from_trt(dtype: trt.DataType) -> torch.dtype:

@@ -81,7 +81,7 @@ class TRTWrapper(BaseWrapper):
        load_tensorrt_plugin()
        self.engine = engine
        if isinstance(self.engine, str):
            self.engine = load_trt_engine(engine)
            self.engine = load(engine)

        if not isinstance(self.engine, trt.ICudaEngine):
            raise TypeError(f'`engine` should be str or trt.ICudaEngine, \
@@ -15,7 +15,6 @@ def get_logger(name: str,
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified, a FileHandler will also be added.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler

@@ -23,7 +22,6 @@ def get_logger(name: str,
        log_level (int): The logger level.
        file_mode (str): The file mode used in opening log file.
            Defaults to 'w'.

    Returns:
        logging.Logger: The expected logger.
    """
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import os.path as osp
|
||||
import random
|
||||
import string
|
||||
import tempfile
|
||||
|
@ -28,21 +29,21 @@ def backend_checker(backend: Backend, require_plugin: bool = False):
|
|||
will also check if the backend plugin has been compiled. Default
|
||||
to `False`.
|
||||
"""
|
||||
is_plugin_available = None
|
||||
is_custom_ops_available = None
|
||||
if backend == Backend.ONNXRUNTIME:
|
||||
from mmdeploy.apis.onnxruntime import is_available
|
||||
if require_plugin:
|
||||
from mmdeploy.apis.onnxruntime import is_plugin_available
|
||||
from mmdeploy.apis.onnxruntime import is_custom_ops_available
|
||||
elif backend == Backend.TENSORRT:
|
||||
from mmdeploy.apis.tensorrt import is_available
|
||||
if require_plugin:
|
||||
from mmdeploy.apis.tensorrt import is_plugin_available
|
||||
from mmdeploy.apis.tensorrt import is_custom_ops_available
|
||||
elif backend == Backend.PPLNN:
|
||||
from mmdeploy.apis.pplnn import is_available
|
||||
elif backend == Backend.NCNN:
|
||||
from mmdeploy.apis.ncnn import is_available
|
||||
if require_plugin:
|
||||
from mmdeploy.apis.ncnn import is_plugin_available
|
||||
from mmdeploy.apis.ncnn import is_custom_ops_available
|
||||
elif backend == Backend.OPENVINO:
|
||||
from mmdeploy.apis.openvino import is_available
|
||||
else:
|
||||
|
@ -51,9 +52,9 @@ def backend_checker(backend: Backend, require_plugin: bool = False):
|
|||
|
||||
checker = pytest.mark.skipif(
|
||||
not is_available(), reason=f'{backend.value} package is not available')
|
||||
if require_plugin and is_plugin_available is not None:
|
||||
if require_plugin and is_custom_ops_available is not None:
|
||||
plugin_checker = pytest.mark.skipif(
|
||||
not is_plugin_available(),
|
||||
not is_custom_ops_available(),
|
||||
reason=f'{backend.value} plugin is not available')
|
||||
|
||||
def double_checker(func):
|
||||
|
@ -76,21 +77,21 @@ def check_backend(backend: Backend, require_plugin: bool = False):
|
|||
will also check if the backend plugin has been compiled. Default
|
||||
to `False`.
|
||||
"""
|
||||
is_plugin_available = None
|
||||
is_custom_ops_available = None
|
||||
if backend == Backend.ONNXRUNTIME:
|
||||
from mmdeploy.apis.onnxruntime import is_available
|
||||
if require_plugin:
|
||||
from mmdeploy.apis.onnxruntime import is_plugin_available
|
||||
from mmdeploy.apis.onnxruntime import is_custom_ops_available
|
||||
elif backend == Backend.TENSORRT:
|
||||
from mmdeploy.apis.tensorrt import is_available
|
||||
if require_plugin:
|
||||
from mmdeploy.apis.tensorrt import is_plugin_available
|
||||
from mmdeploy.apis.tensorrt import is_custom_ops_available
|
||||
elif backend == Backend.PPLNN:
|
||||
from mmdeploy.apis.pplnn import is_available
|
||||
elif backend == Backend.NCNN:
|
||||
from mmdeploy.apis.ncnn import is_available
|
||||
if require_plugin:
|
||||
from mmdeploy.apis.ncnn import is_plugin_available
|
||||
from mmdeploy.apis.ncnn import is_custom_ops_available
|
||||
elif backend == Backend.OPENVINO:
|
||||
from mmdeploy.apis.openvino import is_available
|
||||
elif backend == Backend.TORCHSCRIPT:
|
||||
|
@ -101,8 +102,8 @@ def check_backend(backend: Backend, require_plugin: bool = False):
|
|||
|
||||
if not is_available():
|
||||
pytest.skip(f'{backend.value} package is not available')
|
||||
if require_plugin and is_plugin_available is not None:
|
||||
if not is_plugin_available():
|
||||
if require_plugin and is_custom_ops_available is not None:
|
||||
if not is_custom_ops_available():
|
||||
pytest.skip(f'{backend.value} plugin is not available')
|
||||
|
||||
|
||||
|
@ -409,14 +410,18 @@ def get_ts_model(wrapped_model: nn.Module,
|
|||
"""
|
||||
ir_file_path = tempfile.NamedTemporaryFile(suffix='.pt').name
|
||||
backend = get_backend(deploy_cfg)
|
||||
patched_model = patch_model(
|
||||
wrapped_model, cfg=deploy_cfg, backend=backend.value)
|
||||
|
||||
from mmdeploy.apis.pytorch2torchscript import torch2torchscript_impl
|
||||
torch2torchscript_impl(
|
||||
patched_model, [v for _, v in model_inputs.items()],
|
||||
deploy_cfg=deploy_cfg,
|
||||
output_file=ir_file_path)
|
||||
from mmdeploy.apis.torch_jit import trace
|
||||
context_info = dict(deploy_cfg=deploy_cfg)
|
||||
output_prefix = osp.splitext(ir_file_path)[0]
|
||||
|
||||
example_inputs = [v for _, v in model_inputs.items()]
|
||||
trace(
|
||||
wrapped_model,
|
||||
example_inputs,
|
||||
output_path_prefix=output_prefix,
|
||||
backend=backend,
|
||||
context_info=context_info)
|
||||
return ir_file_path
|
||||
|
||||
|
||||
|
@ -450,7 +455,8 @@ def get_backend_outputs(ir_file_path: str,
|
|||
if backend == Backend.TENSORRT:
|
||||
# convert to engine
|
||||
import mmdeploy.apis.tensorrt as trt_apis
|
||||
if not (trt_apis.is_available() and trt_apis.is_plugin_available()):
|
||||
if not (trt_apis.is_available()
|
||||
and trt_apis.is_custom_ops_available()):
|
||||
return None
|
||||
trt_file_path = tempfile.NamedTemporaryFile(suffix='.engine').name
|
||||
trt_apis.onnx2tensorrt(
|
||||
|
@ -467,7 +473,8 @@ def get_backend_outputs(ir_file_path: str,
|
|||
device = 'cuda:0'
|
||||
elif backend == Backend.ONNXRUNTIME:
|
||||
import mmdeploy.apis.onnxruntime as ort_apis
|
||||
if not (ort_apis.is_available() and ort_apis.is_plugin_available()):
|
||||
if not (ort_apis.is_available()
|
||||
and ort_apis.is_custom_ops_available()):
|
||||
return None
|
||||
feature_list = []
|
||||
backend_feats = {}
|
||||
|
@ -495,12 +502,14 @@ def get_backend_outputs(ir_file_path: str,
|
|||
device = 'cpu'
|
||||
elif backend == Backend.NCNN:
|
||||
import mmdeploy.apis.ncnn as ncnn_apis
|
||||
if not (ncnn_apis.is_available() and ncnn_apis.is_plugin_available()):
|
||||
if not (ncnn_apis.is_available()
|
||||
and ncnn_apis.is_custom_ops_available()):
|
||||
return None
|
||||
work_dir = tempfile.TemporaryDirectory().name
|
||||
param_path, bin_path = ncnn_apis.get_output_model_file(
|
||||
ir_file_path, work_dir)
|
||||
ncnn_apis.onnx2ncnn(ir_file_path, param_path, bin_path)
|
||||
ir_file_name = osp.splitext(ir_file_path)[0]
|
||||
ncnn_apis.from_onnx(ir_file_path, osp.join(work_dir, ir_file_name))
|
||||
backend_files = [param_path, bin_path]
|
||||
backend_feats = flatten_model_inputs
|
||||
device = 'cpu'
|
||||
|
@ -518,8 +527,8 @@ def get_backend_outputs(ir_file_path: str,
|
|||
for name, value in flatten_model_inputs.items()
|
||||
}
|
||||
mo_options = get_mo_options_from_cfg(deploy_cfg)
|
||||
openvino_apis.onnx2openvino(input_info, output_names, ir_file_path,
|
||||
openvino_work_dir, mo_options)
|
||||
openvino_apis.from_onnx(ir_file_path, openvino_work_dir, input_info,
|
||||
output_names, mo_options)
|
||||
backend_files = [openvino_file_path]
|
||||
backend_feats = flatten_model_inputs
|
||||
device = 'cpu'
|
||||
|
|
|
@@ -6,7 +6,10 @@ import sys
import traceback
from typing import Callable, Optional, Union

import multiprocess as mp
try:
    from torch import multiprocessing as mp
except ImportError:
    import multiprocess as mp

from mmdeploy.utils.logging import get_logger

@@ -6,7 +6,7 @@ from multiprocessing import Process
import h5py
import mmcv

from mmdeploy.apis import create_calib_table
from mmdeploy.apis import create_calib_input_data

calib_file = tempfile.NamedTemporaryFile(suffix='.h5').name
ann_file = 'tests/data/annotation.json'

@@ -173,7 +173,7 @@ def get_model_cfg():
def run_test_create_calib_end2end():
    model_cfg = get_model_cfg()
    deploy_cfg = get_end2end_deploy_cfg()
    create_calib_table(
    create_calib_input_data(
        calib_file,
        deploy_cfg,
        model_cfg,

@@ -205,7 +205,7 @@ def test_create_calib_end2end():
def run_test_create_calib_parittion():
    model_cfg = get_model_cfg()
    deploy_cfg = get_partition_deploy_cfg()
    create_calib_table(
    create_calib_input_data(
        calib_file,
        deploy_cfg,
        model_cfg,
@@ -4,7 +4,7 @@ import tempfile
import onnx
import torch

from mmdeploy.apis import extract_model
from mmdeploy.apis.onnx import extract_partition
from mmdeploy.core import mark

output_file = tempfile.NamedTemporaryFile(suffix='.onnx').name

@@ -33,7 +33,7 @@ def test_extract():
    torch.onnx.export(model, (x, y), output_file)
    onnx_model = onnx.load(output_file)

    extracted = extract_model(onnx_model, 'add:input', 'add:output')
    extracted = extract_partition(onnx_model, 'add:input', 'add:output')

    assert extracted.graph.input[0].name == 'x'
    assert extracted.graph.input[1].name == 'y'
@@ -55,13 +55,14 @@ def generate_onnx_file(model):

@backend_checker(Backend.NCNN)
def test_onnx2ncnn():
    from mmdeploy.apis.ncnn import onnx2ncnn
    from mmdeploy.apis.ncnn import from_onnx
    model = test_model
    generate_onnx_file(model)

    work_dir, _ = osp.split(onnx_file)
    save_param, save_bin = get_output_model_file(onnx_file, work_dir=work_dir)
    onnx2ncnn(onnx_file, save_param, save_bin)
    file_name = osp.splitext(onnx_file)[0]
    from_onnx(onnx_file, osp.join(work_dir, file_name))
    assert osp.exists(work_dir)
    assert osp.exists(save_param)
    assert osp.exists(save_bin)
@ -80,8 +80,8 @@ def get_deploy_cfg_with_mo_args():
|
|||
[get_base_deploy_cfg, get_deploy_cfg_with_mo_args])
|
||||
@backend_checker(Backend.OPENVINO)
|
||||
def test_onnx2openvino(get_deploy_cfg):
|
||||
from mmdeploy.apis.openvino import (get_mo_options_from_cfg,
|
||||
get_output_model_file, onnx2openvino)
|
||||
from mmdeploy.apis.openvino import (from_onnx, get_mo_options_from_cfg,
|
||||
get_output_model_file)
|
||||
pytorch_model = TestModel().eval()
|
||||
export_img = torch.rand([1, 3, 8, 8])
|
||||
onnx_file = tempfile.NamedTemporaryFile(suffix='.onnx').name
|
||||
|
@ -95,8 +95,7 @@ def test_onnx2openvino(get_deploy_cfg):
|
|||
openvino_dir = tempfile.TemporaryDirectory().name
|
||||
deploy_cfg = get_deploy_cfg()
|
||||
mo_options = get_mo_options_from_cfg(deploy_cfg)
|
||||
onnx2openvino(input_info, output_names, onnx_file, openvino_dir,
|
||||
mo_options)
|
||||
from_onnx(onnx_file, openvino_dir, input_info, output_names, mo_options)
|
||||
openvino_model_path = get_output_model_file(onnx_file, openvino_dir)
|
||||
assert osp.exists(openvino_model_path), \
|
||||
'The file (.xml) for OpenVINO IR has not been created.'
|
||||
|
@ -117,8 +116,8 @@ def test_can_not_run_onnx2openvino_without_mo():
|
|||
|
||||
is_error = False
|
||||
try:
|
||||
from mmdeploy.apis.openvino import onnx2openvino
|
||||
onnx2openvino({}, ['output'], 'tmp.onnx', '/tmp')
|
||||
from mmdeploy.apis.openvino import from_onnx
|
||||
from_onnx('tmp.onnx', '/tmp', {}, ['output'])
|
||||
except RuntimeError:
|
||||
is_error = True
|
||||
|
||||
|
|
|
@@ -75,7 +75,7 @@ def generate_onnx_file(model):
@backend_checker(Backend.TENSORRT)
def test_onnx2tensorrt():
    from mmdeploy.apis.tensorrt import onnx2tensorrt
    from mmdeploy.backend.tensorrt import load_trt_engine
    from mmdeploy.backend.tensorrt import load
    model = test_model
    generate_onnx_file(model)
    deploy_cfg = get_deploy_cfg()

@@ -85,5 +85,5 @@ def test_onnx2tensorrt():
    onnx2tensorrt(work_dir, save_file, 0, deploy_cfg, onnx_file)
    assert osp.exists(work_dir)
    assert osp.exists(engine_file)
    engine = load_trt_engine(engine_file)
    engine = load(engine_file)
    assert engine is not None
@ -8,7 +8,9 @@ import pytest
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmdeploy.apis import torch2onnx_impl
|
||||
from mmdeploy.apis.onnx import export
|
||||
from mmdeploy.utils.config_utils import (get_backend, get_dynamic_axes,
|
||||
get_onnx_config)
|
||||
from mmdeploy.utils.test import get_random_name
|
||||
|
||||
onnx_file = tempfile.NamedTemporaryFile(suffix='.onnx').name
|
||||
|
@ -63,7 +65,33 @@ def get_deploy_cfg(input_name, output_name, dynamic_axes):
|
|||
[dynamic_axes_dict, dynamic_axes_list])
|
||||
def test_torch2onnx(input_name, output_name, dynamic_axes):
|
||||
deploy_cfg = get_deploy_cfg(input_name, output_name, dynamic_axes)
|
||||
torch2onnx_impl(test_model, test_img, deploy_cfg, onnx_file)
|
||||
|
||||
output_prefix = osp.splitext(onnx_file)[0]
|
||||
context_info = dict(cfg=deploy_cfg)
|
||||
backend = get_backend(deploy_cfg).value
|
||||
onnx_cfg = get_onnx_config(deploy_cfg)
|
||||
opset_version = onnx_cfg.get('opset_version', 11)
|
||||
|
||||
input_names = onnx_cfg['input_names']
|
||||
output_names = onnx_cfg['output_names']
|
||||
axis_names = input_names + output_names
|
||||
dynamic_axes = get_dynamic_axes(deploy_cfg, axis_names)
|
||||
verbose = not onnx_cfg.get('strip_doc_string', True) or onnx_cfg.get(
|
||||
'verbose', False)
|
||||
keep_initializers_as_inputs = onnx_cfg.get('keep_initializers_as_inputs',
|
||||
True)
|
||||
export(
|
||||
test_model,
|
||||
test_img,
|
||||
context_info=context_info,
|
||||
output_path_prefix=output_prefix,
|
||||
backend=backend,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
opset_version=opset_version,
|
||||
dynamic_axes=dynamic_axes,
|
||||
verbose=verbose,
|
||||
keep_initializers_as_inputs=keep_initializers_as_inputs)
|
||||
|
||||
assert osp.exists(onnx_file)
|
||||
|
||||
|
|
|
@@ -84,4 +84,5 @@ def test_torch2torchscript(input_name, output_name):
        model_cfg=get_model_cfg(),
        device='cpu')

    print(ts_file)
    assert osp.exists(ts_file)
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
|
@ -49,28 +50,35 @@ def generate_onnx_file():
|
|||
def generate_torchscript_file():
|
||||
import mmcv
|
||||
|
||||
from mmdeploy.apis import torch2torchscript_impl
|
||||
deploy_cfg = mmcv.Config(
|
||||
{'backend_config': dict(type=Backend.TORCHSCRIPT.value)})
|
||||
with torch.no_grad():
|
||||
torch2torchscript_impl(model, torch.rand(1, 3, 8, 8), deploy_cfg,
|
||||
ts_file)
|
||||
backend = Backend.TORCHSCRIPT.value
|
||||
deploy_cfg = mmcv.Config({'backend_config': dict(type=backend)})
|
||||
|
||||
from mmdeploy.apis.torch_jit import trace
|
||||
context_info = dict(deploy_cfg=deploy_cfg)
|
||||
output_prefix = osp.splitext(ts_file)[0]
|
||||
|
||||
example_inputs = torch.rand(1, 3, 8, 8)
|
||||
trace(
|
||||
model,
|
||||
example_inputs,
|
||||
output_path_prefix=output_prefix,
|
||||
backend=backend,
|
||||
context_info=context_info)
|
||||
|
||||
|
||||
def onnx2backend(backend, onnx_file):
|
||||
if backend == Backend.TENSORRT:
|
||||
from mmdeploy.backend.tensorrt import (create_trt_engine,
|
||||
save_trt_engine)
|
||||
from mmdeploy.backend.tensorrt import from_onnx
|
||||
backend_file = tempfile.NamedTemporaryFile(suffix='.engine').name
|
||||
engine = create_trt_engine(
|
||||
onnx_file, {
|
||||
from_onnx(
|
||||
onnx_file,
|
||||
osp.splitext(backend_file)[0], {
|
||||
'input': {
|
||||
'min_shape': [1, 3, 8, 8],
|
||||
'opt_shape': [1, 3, 8, 8],
|
||||
'max_shape': [1, 3, 8, 8]
|
||||
}
|
||||
})
|
||||
save_trt_engine(engine, backend_file)
|
||||
return backend_file
|
||||
elif backend == Backend.ONNXRUNTIME:
|
||||
return onnx_file
|
||||
|
@ -87,13 +95,13 @@ def onnx2backend(backend, onnx_file):
|
|||
subprocess.call([onnx2ncnn_path, onnx_file, param_file, bin_file])
|
||||
return param_file, bin_file
|
||||
elif backend == Backend.OPENVINO:
|
||||
from mmdeploy.apis.openvino import get_output_model_file, onnx2openvino
|
||||
from mmdeploy.apis.openvino import from_onnx, get_output_model_file
|
||||
backend_dir = tempfile.TemporaryDirectory().name
|
||||
backend_file = get_output_model_file(onnx_file, backend_dir)
|
||||
input_info = {'input': test_img.shape}
|
||||
output_names = ['output']
|
||||
work_dir = backend_dir
|
||||
onnx2openvino(input_info, output_names, onnx_file, work_dir)
|
||||
from_onnx(onnx_file, work_dir, input_info, output_names)
|
||||
return backend_file
|
||||
|
||||
|
||||
|
|
|
@ -230,21 +230,27 @@ class TestNCNNExporter:
|
|||
tolerate_small_mismatch)
|
||||
|
||||
def onnx2ncnn(self, model, model_name, output_names, save_dir=None):
|
||||
if save_dir is None:
|
||||
onnx_file_path = tempfile.NamedTemporaryFile(suffix='.onnx').name
|
||||
ncnn_param_path = tempfile.NamedTemporaryFile(suffix='.param').name
|
||||
ncnn_bin_path = tempfile.NamedTemporaryFile(suffix='.bin').name
|
||||
else:
|
||||
|
||||
def _from_onnx(self, model, model_name, output_names, save_dir=None):
|
||||
onnx_file_path = os.path.join(save_dir, model_name + '.onnx')
|
||||
ncnn_param_path = os.path.join(save_dir, model_name + '.param')
|
||||
ncnn_bin_path = os.path.join(save_dir, model_name + '.bin')
|
||||
|
||||
onnx.save_model(model, onnx_file_path)
|
||||
onnx.save_model(model, onnx_file_path)
|
||||
|
||||
from mmdeploy.backend.ncnn.onnx2ncnn import onnx2ncnn
|
||||
onnx2ncnn(onnx_file_path, ncnn_param_path, ncnn_bin_path)
|
||||
from mmdeploy.backend.ncnn import from_onnx
|
||||
from_onnx(onnx_file_path, os.path.join(save_dir, model_name))
|
||||
|
||||
from mmdeploy.backend.ncnn import NCNNWrapper
|
||||
ncnn_model = NCNNWrapper(ncnn_param_path, ncnn_bin_path, output_names)
|
||||
from mmdeploy.backend.ncnn import NCNNWrapper
|
||||
ncnn_model = NCNNWrapper(ncnn_param_path, ncnn_bin_path,
|
||||
output_names)
|
||||
|
||||
return ncnn_model
|
||||
return ncnn_model
|
||||
|
||||
if save_dir is None:
|
||||
with tempfile.TemporaryDirectory() as save_dir:
|
||||
return _from_onnx(
|
||||
self, model, model_name, output_names, save_dir=save_dir)
|
||||
else:
|
||||
return _from_onnx(
|
||||
self, model, model_name, output_names, save_dir=save_dir)
|
||||
|
|
|
@@ -10,9 +10,9 @@ import pytest
import torch.multiprocessing as mp

import mmdeploy.utils as util
from mmdeploy.backend.sdk.export_info import export2SDK
from mmdeploy.utils import target_wrapper
from mmdeploy.utils.constants import Backend, Codebase, Task
from mmdeploy.utils.export_info import dump_info
from mmdeploy.utils.test import get_random_name

correct_model_path = 'tests/data/srgan.py'

@@ -413,9 +413,11 @@ def test_AdvancedEnum():
    assert k.value == v


@pytest.mark.skipif(
    not importlib.util.find_spec('mmedit'), reason='requires mmedit')
def test_export_info():
    with tempfile.TemporaryDirectory() as dir:
        dump_info(correct_deploy_cfg, correct_model_cfg, dir, '')
        export2SDK(correct_deploy_cfg, correct_model_cfg, dir, '')
        deploy_json = os.path.join(dir, 'deploy.json')
        pipeline_json = os.path.join(dir, 'pipeline.json')
        detail_json = os.path.join(dir, 'detail.json')
tools/deploy.py (130 lines changed)
@ -8,14 +8,15 @@ import mmcv
|
|||
import torch.multiprocessing as mp
|
||||
from torch.multiprocessing import Process, set_start_method
|
||||
|
||||
from mmdeploy.apis import (create_calib_table, extract_model,
|
||||
from mmdeploy.apis import (create_calib_input_data, extract_model,
|
||||
get_predefined_partition_cfg, torch2onnx,
|
||||
torch2torchscript, visualize_model)
|
||||
from mmdeploy.apis.core import PIPELINE_MANAGER
|
||||
from mmdeploy.backend.sdk.export_info import export2SDK
|
||||
from mmdeploy.utils import (IR, Backend, get_backend, get_calib_filename,
|
||||
get_ir_config, get_model_inputs,
|
||||
get_partition_config, get_root_logger, load_config,
|
||||
target_wrapper)
|
||||
from mmdeploy.utils.export_info import dump_info
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -91,7 +92,14 @@ def main():
|
|||
args = parse_args()
|
||||
set_start_method('spawn')
|
||||
logger = get_root_logger()
|
||||
logger.setLevel(args.log_level)
|
||||
log_level = logging.getLevelName(args.log_level)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
pipeline_funcs = [
|
||||
torch2onnx, torch2torchscript, extract_model, create_calib_input_data
|
||||
]
|
||||
PIPELINE_MANAGER.enable_multiprocess(True, pipeline_funcs)
|
||||
PIPELINE_MANAGER.set_log_level(log_level, pipeline_funcs)
|
||||
|
||||
deploy_cfg_path = args.deploy_cfg
|
||||
model_cfg_path = args.model_cfg
|
||||
|
@ -106,7 +114,7 @@ def main():
|
|||
mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
|
||||
|
||||
if args.dump_info:
|
||||
dump_info(deploy_cfg, model_cfg, args.work_dir, pth=checkpoint_path)
|
||||
export2SDK(deploy_cfg, model_cfg, args.work_dir, pth=checkpoint_path)
|
||||
|
||||
ret_value = mp.Value('d', 0, lock=False)
|
||||
|
||||
|
@ -114,13 +122,14 @@ def main():
|
|||
ir_config = get_ir_config(deploy_cfg)
|
||||
ir_save_file = ir_config['save_file']
|
||||
ir_type = IR.get(ir_config['type'])
|
||||
create_process(
|
||||
f'torch2{ir_type.value}',
|
||||
target=torch2ir(ir_type),
|
||||
args=(args.img, args.work_dir, ir_save_file, deploy_cfg_path,
|
||||
model_cfg_path, checkpoint_path),
|
||||
kwargs=dict(device=args.device),
|
||||
ret_value=ret_value)
|
||||
torch2ir(ir_type)(
|
||||
args.img,
|
||||
args.work_dir,
|
||||
ir_save_file,
|
||||
deploy_cfg_path,
|
||||
model_cfg_path,
|
||||
checkpoint_path,
|
||||
device=args.device)
|
||||
|
||||
# convert backend
|
||||
ir_files = [osp.join(args.work_dir, ir_save_file)]
|
||||
|
@ -146,12 +155,12 @@ def main():
|
|||
end = partition_cfg['end']
|
||||
dynamic_axes = partition_cfg.get('dynamic_axes', None)
|
||||
|
||||
create_process(
|
||||
f'partition model {save_file} with start: {start}, end: {end}',
|
||||
extract_model,
|
||||
args=(origin_ir_file, start, end),
|
||||
kwargs=dict(dynamic_axes=dynamic_axes, save_file=save_path),
|
||||
ret_value=ret_value)
|
||||
extract_model(
|
||||
origin_ir_file,
|
||||
start,
|
||||
end,
|
||||
dynamic_axes=dynamic_axes,
|
||||
save_file=save_path)
|
||||
|
||||
ir_files.append(save_path)
|
||||
|
||||
|
@@ -159,17 +168,14 @@ def main():
    calib_filename = get_calib_filename(deploy_cfg)
    if calib_filename is not None:
        calib_path = osp.join(args.work_dir, calib_filename)

        create_process(
            'calibration',
            create_calib_table,
            args=(calib_path, deploy_cfg_path, model_cfg_path,
                  checkpoint_path),
            kwargs=dict(
                dataset_cfg=args.calib_dataset_cfg,
                dataset_type='val',
                device=args.device),
            ret_value=ret_value)
        create_calib_input_data(
            calib_path,
            deploy_cfg_path,
            model_cfg_path,
            checkpoint_path,
            dataset_cfg=args.calib_dataset_cfg,
            dataset_type='val',
            device=args.device)

    backend_files = ir_files
    # convert backend
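
The calibration step follows the same pattern under its new name. A sketch of `create_calib_input_data` with placeholder paths; the `.h5` file name is only an example:

```python
# Sketch only: create calibration data for INT8 conversion with the renamed API.
# Paths are placeholders; in practice the file name comes from the deploy config.
from mmdeploy.apis import create_calib_input_data

create_calib_input_data(
    'work_dir/calib_data.h5',   # output calibration file (example name)
    'deploy_cfg.py',            # deploy config
    'model_cfg.py',             # model config
    'model.pth',                # checkpoint
    dataset_cfg=None,           # optional separate calibration dataset config
    dataset_type='val',
    device='cuda:0')
```
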
@@ -179,10 +185,14 @@ def main():
        assert len(model_params) == len(ir_files)

        from mmdeploy.apis.tensorrt import is_available as trt_is_available
        from mmdeploy.apis.tensorrt import onnx2tensorrt
        assert trt_is_available(
        ), 'TensorRT is not available,' \
            + ' please install TensorRT and build TensorRT custom ops first.'

        from mmdeploy.apis.tensorrt import onnx2tensorrt
        PIPELINE_MANAGER.enable_multiprocess(True, [onnx2tensorrt])
        PIPELINE_MANAGER.set_log_level(logging.INFO, [onnx2tensorrt])

        backend_files = []
        for model_id, model_param, onnx_path in zip(
                range(len(ir_files)), model_params, ir_files):

@@ -191,13 +201,14 @@ def main():

            partition_type = 'end2end' if partition_cfgs is None \
                else onnx_name
            create_process(
                f'onnx2tensorrt of {onnx_path}',
                target=onnx2tensorrt,
                args=(args.work_dir, save_file, model_id, deploy_cfg_path,
                      onnx_path),
                kwargs=dict(device=args.device, partition_type=partition_type),
                ret_value=ret_value)
            onnx2tensorrt(
                args.work_dir,
                save_file,
                model_id,
                deploy_cfg_path,
                onnx_path,
                device=args.device,
                partition_type=partition_type)

            backend_files.append(osp.join(args.work_dir, save_file))
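
Likewise, the config-driven TensorRT conversion is now a plain call to `onnx2tensorrt`. A sketch with placeholder paths and a placeholder deploy config:

```python
# Sketch only: build a TensorRT engine through the config-driven API.
# Paths are placeholders.
from mmdeploy.apis.tensorrt import onnx2tensorrt

onnx2tensorrt(
    'work_dir',               # directory that receives the engine file
    'end2end.engine',         # engine file name inside work_dir
    0,                        # model_id: index into the config's model_inputs
    'deploy_cfg.py',          # deploy config with TensorRT backend settings
    'work_dir/end2end.onnx',  # ONNX model to convert
    device='cuda:0',
    partition_type='end2end')
```
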
@@ -208,18 +219,17 @@ def main():
            logger.error('ncnn support is not available.')
            exit(1)

        from mmdeploy.apis.ncnn import get_output_model_file, onnx2ncnn
        import mmdeploy.apis.ncnn as ncnn_api
        from mmdeploy.apis.ncnn import get_output_model_file

        PIPELINE_MANAGER.set_log_level(log_level, [ncnn_api.from_onnx])

        backend_files = []
        for onnx_path in ir_files:
            model_param_path, model_bin_path = get_output_model_file(
                onnx_path, args.work_dir)
            create_process(
                f'onnx2ncnn with {onnx_path}',
                target=onnx2ncnn,
                args=(onnx_path, model_param_path, model_bin_path),
                kwargs=dict(),
                ret_value=ret_value)
            onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
            ncnn_api.from_onnx(onnx_path, osp.join(args.work_dir, onnx_name))

            if quant:
                from onnx2ncnn_quant_table import get_table
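
For ncnn, `from_onnx` takes an output prefix rather than explicit `.param`/`.bin` paths. A sketch with placeholder paths; `get_output_model_file` is used here only to show where the outputs are expected:

```python
# Sketch only: convert ONNX to ncnn. from_onnx takes an output prefix; the
# .param/.bin files are derived from it. Paths are placeholders.
import os.path as osp

from mmdeploy.apis.ncnn import from_onnx, get_output_model_file

onnx_path = 'work_dir/end2end.onnx'
onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]

# Where tools/deploy.py expects the converted files to appear.
model_param_path, model_bin_path = get_output_model_file(onnx_path, 'work_dir')

from_onnx(onnx_path, osp.join('work_dir', onnx_name))
```
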
@@ -256,23 +266,21 @@ def main():
        assert is_available_openvino(), \
            'OpenVINO is not available, please install OpenVINO first.'

        import mmdeploy.apis.openvino as openvino_api
        from mmdeploy.apis.openvino import (get_input_info_from_cfg,
                                            get_mo_options_from_cfg,
                                            get_output_model_file,
                                            onnx2openvino)
                                            get_output_model_file)

        PIPELINE_MANAGER.set_log_level(log_level, [openvino_api.from_onnx])

        openvino_files = []
        for onnx_path in ir_files:
            model_xml_path = get_output_model_file(onnx_path, args.work_dir)
            input_info = get_input_info_from_cfg(deploy_cfg)
            output_names = get_ir_config(deploy_cfg).output_names
            mo_options = get_mo_options_from_cfg(deploy_cfg)
            create_process(
                f'onnx2openvino with {onnx_path}',
                target=onnx2openvino,
                args=(input_info, output_names, onnx_path, args.work_dir,
                      mo_options),
                kwargs=dict(),
                ret_value=ret_value)
            openvino_api.from_onnx(onnx_path, args.work_dir, input_info,
                                   output_names, mo_options)
            openvino_files.append(model_xml_path)
        backend_files = openvino_files
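
The OpenVINO path keeps pulling its arguments from the deploy config; only the entry point changes to `from_onnx`. A sketch with a placeholder config path (the shape shown in the comment is an assumed example):

```python
# Sketch only: convert ONNX to OpenVINO IR, deriving the arguments from a
# deploy config. The config and ONNX paths are placeholders.
import mmdeploy.apis.openvino as openvino_api
from mmdeploy.apis.openvino import (get_input_info_from_cfg,
                                    get_mo_options_from_cfg,
                                    get_output_model_file)
from mmdeploy.utils import get_ir_config, load_config

deploy_cfg = load_config('deploy_cfg.py')[0]
onnx_path = 'work_dir/end2end.onnx'

input_info = get_input_info_from_cfg(deploy_cfg)  # e.g. {'input': [1, 3, 224, 224]} (assumed)
output_names = get_ir_config(deploy_cfg).output_names
mo_options = get_mo_options_from_cfg(deploy_cfg)

openvino_api.from_onnx(onnx_path, 'work_dir', input_info, output_names,
                       mo_options)

# The generated .xml path, recovered the same way tools/deploy.py does.
model_xml_path = get_output_model_file(onnx_path, 'work_dir')
```
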
@@ -281,7 +289,11 @@ def main():
        assert is_available_pplnn(), \
            'PPLNN is not available, please install PPLNN first.'

        from mmdeploy.apis.pplnn import onnx2pplnn
        from mmdeploy.apis.pplnn import from_onnx

        pplnn_pipeline_funcs = [from_onnx]
        PIPELINE_MANAGER.set_log_level(logging.INFO, pplnn_pipeline_funcs)

        pplnn_files = []
        for onnx_path in ir_files:
            algo_file = onnx_path.replace('.onnx', '.json')

@@ -291,12 +303,12 @@ def main():
            # PPLNN accepts only 1 input shape for optimization,
            # may get changed in the future
            input_shapes = [model_inputs.opt_shape]
            create_process(
                f'onnx2pplnn with {onnx_path}',
                target=onnx2pplnn,
                args=(algo_file, onnx_path),
                kwargs=dict(device=args.device, input_shapes=input_shapes),
                ret_value=ret_value)
            algo_prefix = osp.splitext(algo_file)[0]
            from_onnx(
                onnx_path,
                algo_prefix,
                device=args.device,
                input_shapes=input_shapes)
            pplnn_files += [onnx_path, algo_file]
        backend_files = pplnn_files

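
For PPLNN, `from_onnx` takes the ONNX path first and an output prefix, with the device and a single optimization shape as keyword arguments. A sketch with placeholder values:

```python
# Sketch only: convert ONNX to PPLNN. from_onnx writes the converted model and
# a JSON algorithm file based on the given prefix. Values are placeholders.
from mmdeploy.apis.pplnn import from_onnx

from_onnx(
    'work_dir/end2end.onnx',
    'work_dir/end2end',              # output/algorithm prefix
    device='cuda:0',
    # PPLNN currently accepts a single shape for optimization.
    input_shapes=[[1, 3, 224, 224]])
```
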
@@ -6,7 +6,7 @@ import os.path as osp
import onnx
import onnx.helper

from mmdeploy.apis import extract_model
from mmdeploy.apis.onnx import extract_partition
from mmdeploy.utils import get_root_logger


@@ -53,7 +53,7 @@ def main():
    marks = collect_avaiable_marks(model)
    logger.info('Available marks:\n {}'.format('\n '.join(marks)))

    extracted_model = extract_model(model, args.start, args.end)
    extracted_model = extract_partition(model, args.start, args.end)

    if osp.splitext(args.output_model)[-1] != '.onnx':
        args.output_model += '.onnx'
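
Partition extraction now lives in `mmdeploy.apis.onnx.extract_partition` and works on an in-memory ONNX model. A sketch with placeholder paths; the mark names are hypothetical and must match marks actually recorded in the model, and saving with `onnx.save` is an assumption about how the result is persisted:

```python
# Sketch only: extract a sub-graph between two recorded marks of an ONNX model.
# The mark names and paths are placeholders.
import onnx

from mmdeploy.apis.onnx import extract_partition

model = onnx.load('work_dir/end2end.onnx')

start = 'detector_forward:input'   # placeholder mark name
end = 'multiclass_nms:input'       # placeholder mark name

extracted_model = extract_partition(model, start, end)
onnx.save(extracted_model, 'work_dir/partition.onnx')  # assumed persistence step
```
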
@@ -2,15 +2,14 @@
import argparse
import logging

from mmdeploy.apis.ncnn import onnx2ncnn
from mmdeploy.apis.ncnn import from_onnx
from mmdeploy.utils import get_root_logger


def parse_args():
    parser = argparse.ArgumentParser(description='Convert ONNX to ncnn.')
    parser.add_argument('onnx_path', help='ONNX model path')
    parser.add_argument('output_param', help='output ncnn param path')
    parser.add_argument('output_bin', help='output bin path')
    parser.add_argument('output_prefix', help='output ncnn model path')
    parser.add_argument(
        '--log-level',
        help='set log level',

@@ -26,12 +25,11 @@ def main():
    logger = get_root_logger(log_level=args.log_level)

    onnx_path = args.onnx_path
    output_param = args.output_param
    output_bin = args.output_bin
    output_prefix = args.output_prefix

    logger.info(f'onnx2ncnn: \n\tonnx_path: {onnx_path} ')
    try:
        onnx2ncnn(onnx_path, output_param, output_bin)
        from_onnx(onnx_path, output_prefix)
        logger.info('onnx2ncnn success.')
    except Exception as e:
        logger.error(e)

@@ -3,7 +3,7 @@ import argparse
import collections
import logging

from mmdeploy.apis.pplnn import onnx2pplnn
from mmdeploy.apis.pplnn import from_onnx
from mmdeploy.utils import get_root_logger


@@ -11,7 +11,7 @@ def parse_args():
    parser = argparse.ArgumentParser(description='Convert ONNX to PPLNN.')
    parser.add_argument('onnx_path', help='ONNX model path')
    parser.add_argument(
        'output_path', help='output PPLNN algorithm path in json format')
        'output_prefix', help='output PPLNN algorithm prefix in json format')
    parser.add_argument(
        '--device',
        help='`the device of model during conversion',

@@ -36,7 +36,7 @@ def main():
    logger = get_root_logger(log_level=args.log_level)

    onnx_path = args.onnx_path
    output_path = args.output_path
    output_prefix = args.output_prefix
    device = args.device

    input_shapes = eval(args.opt_shapes)

@@ -50,10 +50,10 @@ def main():
        input_shapes = [input_shapes]

    logger.info(f'onnx2ppl: \n\tonnx_path: {onnx_path} '
                f'\n\toutput_path: {output_path}'
                f'\n\toutput_prefix: {output_prefix}'
                f'\n\topt_shapes: {input_shapes}')
    try:
        onnx2pplnn(output_path, onnx_path, device, input_shapes)
        from_onnx(onnx_path, output_prefix, device, input_shapes)
        logger.info('onnx2tpplnn success.')
    except Exception as e:
        logger.error(e)

@@ -2,7 +2,7 @@
import argparse
import logging

from mmdeploy.backend.tensorrt import create_trt_engine, save_trt_engine
from mmdeploy.backend.tensorrt import from_onnx
from mmdeploy.backend.tensorrt.utils import get_trt_log_level
from mmdeploy.utils import (get_common_config, get_model_inputs,
                            get_root_logger, load_config)

@@ -12,7 +12,7 @@ def parse_args():
    parser = argparse.ArgumentParser(description='Convert ONNX to TensorRT.')
    parser.add_argument('deploy_cfg', help='deploy config path')
    parser.add_argument('onnx_path', help='ONNX model path')
    parser.add_argument('output', help='output TensorRT engine path')
    parser.add_argument('output_prefix', help='output TensorRT engine prefix')
    parser.add_argument('--device-id', help='`the CUDA device id', default=0)
    parser.add_argument(
        '--calib-file',

@@ -35,7 +35,7 @@ def main():
    deploy_cfg_path = args.deploy_cfg
    deploy_cfg = load_config(deploy_cfg_path)[0]
    onnx_path = args.onnx_path
    output_path = args.output
    output_prefix = args.output_prefix
    device_id = args.device_id
    calib_file = args.calib_file

@@ -56,8 +56,9 @@ def main():
    logger.info(f'onnx2tensorrt: \n\tonnx_path: {onnx_path} '
                f'\n\tdeploy_cfg: {deploy_cfg_path}')
    try:
        engine = create_trt_engine(
        from_onnx(
            onnx_path,
            output_prefix,
            input_shapes=final_params['input_shapes'],
            log_level=get_trt_log_level(),
            fp16_mode=final_params.get('fp16_mode', False),

@@ -66,7 +67,6 @@ def main():
            max_workspace_size=final_params.get('max_workspace_size', 0),
            device_id=device_id)

        save_trt_engine(engine, output_path)
        logger.info('onnx2tensorrt success.')
    except Exception as e:
        logger.error(e)
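
At the backend level, `from_onnx` replaces the `create_trt_engine`/`save_trt_engine` pair and writes the engine itself given an output prefix. A sketch with placeholder paths; the nested `input_shapes` layout mirrors the `model_inputs` section of a TensorRT deploy config and is an assumption here:

```python
# Sketch only: low-level TensorRT conversion. from_onnx builds the engine and
# writes it under the given prefix, so no separate save call is needed.
# Paths, shapes and workspace size are placeholders.
from mmdeploy.backend.tensorrt import from_onnx
from mmdeploy.backend.tensorrt.utils import get_trt_log_level

from_onnx(
    'work_dir/end2end.onnx',
    'work_dir/end2end',                   # output prefix for the engine file
    input_shapes=dict(
        input=dict(                       # assumed per-input min/opt/max layout
            min_shape=[1, 3, 224, 224],
            opt_shape=[1, 3, 224, 224],
            max_shape=[1, 3, 224, 224])),
    log_level=get_trt_log_level(),
    fp16_mode=False,
    max_workspace_size=1 << 30,
    device_id=0)
```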