[Docstring]: Coding style and docstring revision for mmdeploy.apis (#87)

* check style of mmdeploy.apis.ncnn

* finish style check of mmdeploy.apis.onnxruntime

* check mmdeploy.apis.ppl

* check mmdeploy.apis.tensorrt

* update docstring for mmdeploy.apis

* update some docstrings

* make style consistent

* update

* resolve comments

* resolve comments
RunningLeon 2021-09-24 10:40:39 +08:00 committed by GitHub
parent 5453f9befa
commit de096d5f00
21 changed files with 580 additions and 147 deletions

View File

@ -20,13 +20,23 @@ def create_calib_table(calib_file: str,
dataset_type: str = 'val',
device: str = 'cuda:0',
**kwargs) -> None:
"""Create calibration table.
Args:
calib_file (str): Input calibration file.
deploy_cfg (str | mmcv.Config): Deployment config.
model_cfg (str | mmcv.Config): The model config.
model_checkpoint (str): PyTorch model checkpoint, defaults to `None`.
dataset_cfg (str | mmcv.Config): Dataset config, defaults to `None`.
dataset_type (str): A string specifying dataset type, e.g.: 'test',
'val', defaults to 'val'.
device (str): Specifying the device to run on, defaults to 'cuda:0'.
"""
if dataset_cfg is None:
dataset_cfg = model_cfg
# load cfg if necessary
deploy_cfg = load_config(deploy_cfg)[0]
model_cfg = load_config(model_cfg)[0]
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
device_id = torch.device(device).index
if device_id is None:
device_id = 0
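
For reference, a minimal usage sketch of `create_calib_table` as documented above; the file paths and the import location are assumptions, not part of this commit:

# Hypothetical example: all paths are placeholders.
from mmdeploy.apis import create_calib_table

create_calib_table(
    calib_file='work_dir/calib_data.h5',
    deploy_cfg='deploy_cfg.py',
    model_cfg='model_cfg.py',
    model_checkpoint='model.pth',
    dataset_type='val',
    device='cuda:0')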

View File

@ -17,7 +17,27 @@ def extract_model(model: Union[str, onnx.ModelProto],
end_name_map: Optional[Dict[str, str]] = None,
dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None,
save_file: Optional[str] = None):
"""Extract sub-model from an ONNX model.
The sub-model is defined by the names of the input and output tensors
exactly.
Args:
model (str | onnx.ModelProto): Input ONNX model to be extracted.
start (str | Sequence[str]): Start marker(s) to extract.
end (str | Sequence[str]): End marker(s) to extract.
start_name_map (Dict[str, str]): A mapping of start names, defaults to
`None`.
end_name_map (Dict[str, str]): A mapping of end names, defaults to
`None`.
dynamic_axes (Dict[str, Dict[int, str]]): A dictionary to specify
dynamic axes of input/output, defaults to `None`.
save_file (str): A file to save the extracted model, defaults to
`None`.
Returns:
onnx.ModelProto: The extracted sub-model.
"""
if isinstance(model, str):
model = onnx.load(model)
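
For reference, a usage sketch of `extract_model` based on the parameters documented above; the marker names, file names and import location are assumptions:

# Hypothetical example: tensor names and paths are placeholders.
from mmdeploy.apis import extract_model

sub_model = extract_model(
    'end2end.onnx',
    start='input',
    end='backbone_output',
    dynamic_axes={'input': {0: 'batch'}},
    save_file='partition0.onnx')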

View File

@ -1,5 +1,7 @@
from typing import Optional
from typing import Optional, Sequence, Union
import mmcv
import numpy as np
import torch
from mmdeploy.utils import (Backend, get_backend, get_codebase,
@ -8,15 +10,29 @@ from .utils import (create_input, init_backend_model, init_pytorch_model,
run_inference, visualize)
def inference_model(model_cfg,
deploy_cfg,
model,
img,
def inference_model(model_cfg: Union[str, mmcv.Config],
deploy_cfg: Union[str, mmcv.Config],
model: Union[str, Sequence[str], torch.nn.Module],
img: Union[str, np.ndarray],
device: str,
backend: Optional[Backend] = None,
output_file: Optional[str] = None,
show_result=False):
show_result: bool = False):
"""Run inference with PyTorch or backend model and show results.
Args:
model_cfg (str | mmcv.Config): Model config file or Config object.
deploy_cfg (str | mmcv.Config): Deployment config file or Config
object.
model (str | list[str] | torch.nn.Module): Input model or file(s).
img (str | np.ndarray): Input image file or numpy array for inference.
device (str): A string specifying device type.
backend (Backend): Specifying backend type, defaults to `None`.
output_file (str): Output file to save visualized image, defaults to
`None`. Only valid if `show_result` is set to `False`.
show_result (bool): Whether to show plotted image in windows, defaults
to `False`.
"""
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
codebase = get_codebase(deploy_cfg)
@ -27,7 +43,12 @@ def inference_model(model_cfg,
if isinstance(model, str):
model = [model]
if isinstance(model, (list, tuple)):
assert len(model) > 0, 'Model should have at least one element.'
assert all([isinstance(m, str) for m in model]), 'All elements in the \
list should be str'
if backend == Backend.PYTORCH:
model = init_pytorch_model(codebase, model_cfg, model[0], device)
else:
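
For reference, a usage sketch of `inference_model` as documented above; paths and the import location are assumptions:

# Hypothetical example: config, model and image paths are placeholders.
from mmdeploy.apis import inference_model

inference_model(
    model_cfg='model_cfg.py',
    deploy_cfg='deploy_cfg.py',
    model=['end2end.onnx'],
    img='demo.jpg',
    device='cuda:0',
    output_file='output.png',
    show_result=False)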

View File

@ -7,6 +7,11 @@ __all__ = ['get_ops_path', 'get_onnx2ncnn_path']
def is_available():
"""Check whether ncnn with extension is installed.
Returns:
bool: True if ncnn and its extension are installed.
"""
ncnn_ops_path = get_ops_path()
if not osp.exists(ncnn_ops_path):
return False

View File

@ -3,7 +3,11 @@ import os
def get_ops_path():
"""Get NCNN custom ops library path."""
"""Get NCNN custom ops library path.
Returns:
str: The library path of NCNN custom ops.
"""
wildcard = os.path.abspath(
os.path.join(
os.path.dirname(__file__),
@ -15,7 +19,11 @@ def get_ops_path():
def get_onnx2ncnn_path():
"""Get onnx2ncnn path."""
"""Get onnx2ncnn path.
Returns:
str: The path of the onnx2ncnn tool.
"""
wildcard = os.path.abspath(
os.path.join(
os.path.dirname(__file__), '../../../build/bin/onnx2ncnn'))

View File

@ -9,11 +9,24 @@ from mmdeploy.utils.timer import TimeCounter
class NCNNWrapper(torch.nn.Module):
"""NCNN Wrapper.
"""NCNN wrapper class for inference.
Arguments:
param_file (str): param file path
bin_file (str): bin file path
Args:
param_file (str): Path of a parameter file.
bin_file (str): Path of a binary file.
output_names (list[str] | tuple[str]): Names of the model outputs. Defaults
to `None`.
Examples:
>>> from mmdeploy.apis.ncnn import NCNNWrapper
>>> import torch
>>>
>>> param_file = 'model.params'
>>> bin_file = 'model.bin'
>>> model = NCNNWrapper(param_file, bin_file)
>>> inputs = dict(input=torch.randn(1, 3, 224, 224))
>>> outputs = model(inputs)
>>> print(outputs)
"""
def __init__(self,
@ -31,10 +44,20 @@ class NCNNWrapper(torch.nn.Module):
self._net = net
self._output_names = output_names
def set_output_names(self, output_names):
def set_output_names(self, output_names: Iterable[str]):
"""Set names of the model outputs.
Args:
output_names (list[str] | tuple[str]): Names of the model outputs.
"""
self._output_names = output_names
def get_output_names(self):
"""Get names of the model outputs.
Returns:
list[str]: Names of the model outputs.
"""
if self._output_names is not None:
return self._output_names
else:
@ -42,12 +65,21 @@ class NCNNWrapper(torch.nn.Module):
return self._net.output_names()
def forward(self, inputs: Dict[str, torch.Tensor]):
"""Run forward inference.
Args:
inputs (Dict[str, torch.Tensor]): Key-value pairs of model inputs.
Returns:
Dict[str, torch.Tensor]: Key-value pairs of model outputs.
"""
input_list = list(inputs.values())
batch_size = input_list[0].size(0)
for tensor in input_list[1:]:
assert tensor.size(
0) == batch_size, 'All tensor should have same batch size'
assert tensor.device.type == 'cpu', 'NCNN only support cpu device'
for input_tensor in input_list[1:]:
assert input_tensor.size(
0) == batch_size, 'All tensors should have same batch size'
assert input_tensor.device.type == 'cpu', \
'NCNN only supports cpu device'
# set output names
output_names = self.get_output_names()
@ -55,30 +87,42 @@ class NCNNWrapper(torch.nn.Module):
# create output dict
outputs = dict([name, [None] * batch_size] for name in output_names)
# inference
# run inference
for batch_id in range(batch_size):
# create extractor
ex = self._net.create_extractor()
# set input
for k, v in inputs.items():
in_data = ncnn.Mat(v[batch_id].detach().cpu().numpy())
ex.input(k, in_data)
# set inputs
for name, input_tensor in inputs.items():
input_mat = ncnn.Mat(
input_tensor[batch_id].detach().cpu().numpy())
ex.input(name, input_mat)
# get output
# get outputs
result = self.ncnn_execute(extractor=ex, output_names=output_names)
for name in output_names:
outputs[name][batch_id] = torch.from_numpy(
np.array(result[name]))
# stack outputs together
for k, v in outputs.items():
outputs[k] = torch.stack(v)
for name, output_tensor in outputs.items():
outputs[name] = torch.stack(output_tensor)
return outputs
@TimeCounter.count_time()
def ncnn_execute(self, extractor, output_names):
def ncnn_execute(self, extractor: ncnn.Extractor,
output_names: Iterable[str]):
"""Run inference with NCNN.
Args:
extractor (ncnn.Extractor): NCNN extractor to extract output.
output_names (Iterable[str]): A list of strings specifying
output names.
Returns:
dict[str, ncnn.Mat]: Inference results of NCNN model.
"""
result = {}
for name in output_names:
out_ret, out = extractor.extract(name)

View File

@ -5,6 +5,12 @@ from .init_plugins import get_ops_path
def is_available():
"""Check whether onnxruntime and its custom ops are installed.
Returns:
bool: True if onnxruntime package is installed and its
custom ops are compiled.
"""
onnxruntime_op_path = get_ops_path()
if not osp.exists(onnxruntime_op_path):
return False

View File

@ -3,7 +3,11 @@ import os
def get_ops_path():
"""Get ONNX Runtime plugins library path."""
"""Get the library path of onnxruntime custom ops.
Returns:
str: The library path to onnxruntime custom ops.
"""
wildcard = os.path.abspath(
os.path.join(
os.path.dirname(__file__),

View File

@ -1,5 +1,6 @@
import logging
import os.path as osp
from typing import Sequence
from typing import Dict, Sequence
import numpy as np
import onnxruntime as ort
@ -10,11 +11,22 @@ from .init_plugins import get_ops_path
class ORTWrapper(torch.nn.Module):
"""ONNXRuntime Wrapper.
"""ONNXRuntime wrapper for inference.
Arguments:
onnx_file (str): Input onnx model file
device_id (int): The device id to put model
Args:
onnx_file (str): Input onnx model file.
device_id (int): The device id to place the model on.
output_names (list[str] | tuple[str]): Names of the model outputs.
Examples:
>>> from mmdeploy.apis.onnxruntime import ORTWrapper
>>> import torch
>>>
>>> onnx_file = 'model.onnx'
>>> model = ORTWrapper(onnx_file, -1)
>>> inputs = dict(input=torch.randn(1, 3, 224, 224, device='cpu'))
>>> outputs = model(inputs)
>>> print(outputs)
"""
def __init__(self,
@ -28,6 +40,12 @@ class ORTWrapper(torch.nn.Module):
# register custom op for onnxruntime
if osp.exists(ort_custom_op_path):
session_options.register_custom_ops_library(ort_custom_op_path)
logging.info(f'Successfully loaded onnxruntime custom ops from \
{ort_custom_op_path}')
else:
logging.warning(f'The library of onnxruntime custom ops does \
not exist: {ort_custom_op_path}')
sess = ort.InferenceSession(onnx_file, session_options)
providers = ['CPUExecutionProvider']
@ -46,13 +64,14 @@ class ORTWrapper(torch.nn.Module):
self.is_cuda_available = is_cuda_available
self.device_type = 'cuda' if is_cuda_available else 'cpu'
def forward(self, inputs):
"""
Arguments:
inputs (dict): the input name and tensor pairs
input_names: list of input name
Return:
list[np.ndarray]: list of output numpy array
def forward(self, inputs: Dict[str, torch.Tensor]):
"""Run forward inference.
Args:
inputs (Dict[str, torch.Tensor]): The input name and tensor pairs.
Returns:
list[np.ndarray]: A list of output numpy arrays.
"""
for name, input_tensor in inputs.items():
# set io binding for inputs/outputs
@ -75,5 +94,11 @@ class ORTWrapper(torch.nn.Module):
return outputs
@TimeCounter.count_time()
def ort_execute(self, io_binding):
def ort_execute(self, io_binding: ort.IOBinding):
"""Run inference with ONNXRuntime session.
Args:
io_binding (ort.IOBinding): An I/O binding that maps inputs/outputs to a
specified device, e.g. GPU.
"""
self.sess.run_with_iobinding(io_binding)

View File

@ -2,7 +2,11 @@ import importlib
def is_available():
"""Check whether ppl is installed."""
"""Check whether ppl is installed.
Returns:
bool: True if ppl package is installed.
"""
return importlib.util.find_spec('pyppl') is not None

View File

@ -1,5 +1,6 @@
import logging
import sys
from typing import Dict
import numpy as np
import pyppl.common as pplcommon
@ -15,9 +16,14 @@ def register_engines(device_id: int,
"""Register engines for ppl runtime.
Args:
device_id (int): -1 for cpu.
device_id (int): Specifying device index. `-1` for cpu.
disable_avx512 (bool): Whether to disable avx512 for x86.
Defaults to `False`.
quick_select (bool): Whether to use default algorithms.
Defaults to `False`.
Returns:
list[pplnn.Engine]: A list of registered ppl engines.
"""
engines = []
if device_id == -1:
@ -59,11 +65,21 @@ def register_engines(device_id: int,
class PPLWrapper(torch.nn.Module):
"""PPLWrapper Wrapper.
"""PPL wrapper for inference.
Arguments:
model_file (str): Input onnx model file
device_id (int): The device id to put model
Args:
model_file (str): Input onnx model file.
device_id (int): Device id to place the model on.
Examples:
>>> from mmdeploy.apis.ppl import PPLWrapper
>>> import torch
>>>
>>> onnx_file = 'model.onnx'
>>> model = PPLWrapper(onnx_file, 0)
>>> inputs = dict(input=torch.randn(1, 3, 224, 224))
>>> outputs = model(inputs)
>>> print(outputs)
"""
def __init__(self, model_file: str, device_id: int):
@ -87,14 +103,16 @@ class PPLWrapper(torch.nn.Module):
for i in range(runtime.GetInputCount())
}
def forward(self, input_data):
"""
Arguments:
input_data (dict): the input name and tensor pairs
def forward(self, inputs: Dict[str, torch.Tensor]):
"""Run forward inference.
Args:
inputs (Dict[str, torch.Tensor]): Input name and tensor pairs.
Returns:
list[np.ndarray]: list of output numpy array
list[np.ndarray]: A list of output numpy arrays.
"""
for name, input_tensor in input_data.items():
for name, input_tensor in inputs.items():
input_tensor = input_tensor.contiguous()
self.inputs[name].ConvertFromHost(input_tensor.cpu().numpy())
self.ppl_execute()
@ -106,6 +124,7 @@ class PPLWrapper(torch.nn.Module):
@TimeCounter.count_time()
def ppl_execute(self):
"""Run inference with PPL."""
status = self.runtime.Run()
assert status == pplcommon.RC_SUCCESS, 'Run() '\
'failed: ' + pplcommon.GetRetCodeStr(status)

View File

@ -13,12 +13,17 @@ from .utils import create_input, init_pytorch_model
def torch2onnx_impl(model: torch.nn.Module, input: torch.Tensor,
deploy_cfg: Union[str, mmcv.Config], output_file: str):
"""Converting torch model to ONNX.
Args:
model (torch.nn.Module): Input pytorch model.
input (torch.Tensor): Input tensor used to convert model.
deploy_cfg (str | mmcv.Config): Deployment config file or
Config object.
output_file (str): Output file to save ONNX model.
"""
# load deploy_cfg if needed
if isinstance(deploy_cfg, str):
deploy_cfg = mmcv.Config.fromfile(deploy_cfg)
if not isinstance(deploy_cfg, mmcv.Config):
raise TypeError('deploy_cfg must be a filename or Config object, '
f'but got {type(deploy_cfg)}')
deploy_cfg = load_config(deploy_cfg)[0]
pytorch2onnx_cfg = get_onnx_config(deploy_cfg)
backend = get_backend(deploy_cfg).value
@ -51,7 +56,20 @@ def torch2onnx(img: Any,
model_cfg: Union[str, mmcv.Config],
model_checkpoint: Optional[str] = None,
device: str = 'cuda:0'):
"""Convert PyToch model to ONNX model.
Args:
img (str | np.ndarray | torch.Tensor): Input image used to assist
converting model.
work_dir (str): A working directory to save files.
save_file (str): Filename to save onnx model.
deploy_cfg (str | mmcv.Config): Deployment config file or
Config object.
model_cfg (str | mmcv.Config): Model config file or Config object.
model_checkpoint (str): A checkpoint path of PyTorch model,
defaults to `None`.
device (str): A string specifying device type, defaults to 'cuda:0'.
"""
# load deploy_cfg if necessary
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
mmcv.mkdir_or_exist(osp.abspath(work_dir))
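
For reference, a usage sketch of `torch2onnx` as documented above; paths and the import location are assumptions:

# Hypothetical example: all paths are placeholders.
from mmdeploy.apis import torch2onnx

torch2onnx(
    'demo.jpg',
    work_dir='work_dir',
    save_file='end2end.onnx',
    deploy_cfg='deploy_cfg.py',
    model_cfg='model_cfg.py',
    model_checkpoint='model.pth',
    device='cuda:0')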

View File

@ -6,6 +6,11 @@ from .init_plugins import get_ops_path, load_tensorrt_plugin
def is_available():
"""Check whether TensorRT and plugins are installed.
Returns:
bool: True if TensorRT and plugins are installed.
"""
tensorrt_op_path = get_ops_path()
if not osp.exists(tensorrt_op_path):
return False

View File

@ -1,3 +1,5 @@
from typing import Dict, Sequence, Union
import h5py
import numpy as np
import tensorrt as trt
@ -7,14 +9,26 @@ DEFAULT_CALIBRATION_ALGORITHM = trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2
class HDF5Calibrator(trt.IInt8Calibrator):
"""HDF5 calibrator.
def __init__(self,
calib_file,
opt_shape_dict,
model_type='end2end',
device_id=0,
algorithm=DEFAULT_CALIBRATION_ALGORITHM,
**kwargs):
Args:
calib_file (str | h5py.File): Input calibration file.
input_shapes (Dict[str, Sequence[int]]): The min/opt/max shape of
each input.
model_type (str): Input model type, defaults to 'end2end'.
device_id (int): Cuda device id, defaults to 0.
algorithm (trt.CalibrationAlgoType): Calibration algo type, defaults
to `trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2`.
"""
def __init__(
self,
calib_file: Union[str, h5py.File],
input_shapes: Dict[str, Sequence[int]],
model_type: str = 'end2end',
device_id: int = 0,
algorithm: trt.CalibrationAlgoType = DEFAULT_CALIBRATION_ALGORITHM,
**kwargs):
super().__init__()
if isinstance(calib_file, str):
@ -29,7 +43,7 @@ class HDF5Calibrator(trt.IInt8Calibrator):
self.calib_data = calib_data
self.device_id = device_id
self.algorithm = algorithm
self.opt_shape_dict = opt_shape_dict
self.input_shapes = input_shapes
self.kwargs = kwargs
# create buffers that will hold data batches
@ -45,7 +59,7 @@ class HDF5Calibrator(trt.IInt8Calibrator):
if hasattr(self, 'calib_file'):
self.calib_file.close()
def get_batch(self, names, **kwargs):
def get_batch(self, names: Sequence[str], **kwargs):
if self.count < self.dataset_length:
ret = []
@ -55,7 +69,7 @@ class HDF5Calibrator(trt.IInt8Calibrator):
data_torch = torch.from_numpy(data_np)
# tile the tensor so we can keep the same distribute
opt_shape = self.opt_shape_dict[name]['opt_shape']
opt_shape = self.input_shapes[name]['opt_shape']
data_shape = data_torch.shape
reps = [
@ -87,7 +101,7 @@ class HDF5Calibrator(trt.IInt8Calibrator):
return self.batch_size
def read_calibration_cache(self, *args, **kwargs):
return None
pass
def write_calibration_cache(self, cache, *args, **kwargs):
pass

View File

@ -5,7 +5,11 @@ import os
def get_ops_path():
"""Get TensorRT plugins library path."""
"""Get path of the TensorRT plugin library.
Returns:
str: A path of the TensorRT plugin library.
"""
wildcard = os.path.abspath(
os.path.join(
os.path.dirname(__file__),
@ -17,11 +21,18 @@ def get_ops_path():
def load_tensorrt_plugin():
"""load TensorRT plugins library."""
"""Load TensorRT plugins library.
Returns:
bool: True if TensorRT plugin library is successfully loaded.
"""
lib_path = get_ops_path()
success = False
if os.path.exists(lib_path):
ctypes.CDLL(lib_path)
return 0
logging.info(f'Successfully loaded tensorrt plugins from {lib_path}')
success = True
else:
logging.warning('Can not load tensorrt custom ops.')
return -1
logging.warning(f'Could not load the library of tensorrt plugins. \
Because the file does not exist: {lib_path}')
return success
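
For reference, a short sketch of the plugin loader above; the import location is an assumption:

# Hypothetical example: returns False (and logs a warning) if the plugin
# library has not been built.
from mmdeploy.apis.tensorrt import load_tensorrt_plugin

if load_tensorrt_plugin():
    print('TensorRT custom ops are ready to use.')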

View File

@ -10,7 +10,16 @@ from mmdeploy.utils import (get_calib_filename, get_common_config,
from .tensorrt_utils import create_trt_engine, save_trt_engine
def parse_device_id(device: str):
def parse_device_id(device: str) -> int:
"""Parse cuda device index from a string.
Args:
device (str): The typical style of string specifying cuda device,
e.g.: 'cuda:0'.
Returns:
int: The parsed device id; `0` if the device string carries no index.
"""
device_id = 0
if len(device) >= 6:
device_id = int(device[5:])
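
With the parsing logic above, a device string such as 'cuda:1' yields index 1, while a bare 'cuda' falls back to 0; a quick illustrative check:

# Illustration of the behaviour of parse_device_id shown above.
assert parse_device_id('cuda:1') == 1
assert parse_device_id('cuda') == 0
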
@ -25,6 +34,18 @@ def onnx2tensorrt(work_dir: str,
device: str = 'cuda:0',
partition_type: str = 'end2end',
**kwargs):
"""Convert ONNX to TensorRT.
Args:
work_dir (str): A working directory.
save_file (str): File path to save TensorRT engine.
model_id (int): Index of input model.
deploy_cfg (str | mmcv.Config): Deployment config.
onnx_model (str | onnx.ModelProto): Input onnx model.
device (str): A string specifying cuda device, defaults to 'cuda:0'.
partition_type (str): Specifying partition type of a model, defaults to
'end2end'.
"""
# load deploy_cfg if necessary
deploy_cfg = load_config(deploy_cfg)[0]
@ -43,11 +64,13 @@ def onnx2tensorrt(work_dir: str,
int8_param['calib_file'] = osp.join(work_dir, calib_file)
int8_param['model_type'] = partition_type
assert device.startswith('cuda'), 'TensorRT require cuda device.'
assert device.startswith('cuda'), f'TensorRT requires cuda device, \
but given: {device}'
device_id = parse_device_id(device)
engine = create_trt_engine(
onnx_model,
opt_shape_dict=final_params['input_shapes'],
input_shapes=final_params['input_shapes'],
log_level=final_params.get('log_level', trt.Logger.WARNING),
fp16_mode=final_params.get('fp16_mode', False),
int8_mode=final_params.get('int8_mode', False),
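
For reference, a usage sketch of `onnx2tensorrt` as documented above; paths, the model index and the import location are assumptions:

# Hypothetical example: paths are placeholders.
from mmdeploy.apis.tensorrt import onnx2tensorrt

onnx2tensorrt(
    work_dir='work_dir',
    save_file='end2end.engine',
    model_id=0,
    deploy_cfg='deploy_cfg.py',
    onnx_model='work_dir/end2end.onnx',
    device='cuda:0',
    partition_type='end2end')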

View File

@ -1,3 +1,5 @@
from typing import Dict, Sequence, Union
import onnx
import tensorrt as trt
import torch
@ -7,37 +9,42 @@ from mmdeploy.utils.timer import TimeCounter
from .calib_utils import HDF5Calibrator
def create_trt_engine(onnx_model,
opt_shape_dict,
log_level=trt.Logger.ERROR,
fp16_mode=False,
int8_mode=False,
int8_param=None,
max_workspace_size=0,
device_id=0,
def create_trt_engine(onnx_model: Union[str, onnx.ModelProto],
input_shapes: Dict[str, Sequence[int]],
log_level: trt.Logger.Severity = trt.Logger.ERROR,
fp16_mode: bool = False,
int8_mode: bool = False,
int8_param: dict = None,
max_workspace_size: int = 0,
device_id: int = 0,
**kwargs):
"""Create a tensorrt engine from ONNX.
Arguments:
onnx_model (str or onnx.ModelProto): the onnx model to convert from
opt_shape_dict (dict): the min/opt/max shape of each input
log_level (TensorRT log level): the log level of TensorRT
fp16_mode (bool): enable fp16 mode
int8_mode (bool): enable int8 mode
int8_param (None|dict): parameter of int8 mode.
max_workspace_size (int): set max workspace size of TensorRT engine.
some tactic and layers need large workspace.
device_id (int): choice the device to create engine.
Args:
onnx_model (str or onnx.ModelProto): Input onnx model to convert from.
input_shapes (Dict[str, Sequence[int]]): The min/opt/max shape of
each input.
log_level (trt.Logger.Severity): The log level of TensorRT. Defaults to
`trt.Logger.ERROR`.
fp16_mode (bool): Specifying whether to enable fp16 mode.
Defaults to `False`.
int8_mode (bool): Specifying whether to enable int8 mode.
Defaults to `False`.
int8_param (dict): A dict of parameters for int8 mode. Defaults to `None`.
max_workspace_size (int): Max workspace size of the TensorRT engine;
some tactics and layers need a large workspace. Defaults to `0`.
device_id (int): Choose the device on which to create the engine. Defaults to `0`.
Returns:
tensorrt.ICudaEngine: the TensorRT engine created from onnx_model
tensorrt.ICudaEngine: The TensorRT engine created from onnx_model.
Example:
>>> from mmdeploy.apis.tensorrt import create_trt_engine
>>> engine = create_trt_engine(
>>> "onnx_model.onnx",
>>> {'input': {"min_shape" : [1, 3, 160, 160],
>>> "opt_shape" :[1, 3, 320, 320],
>>> "max_shape" :[1, 3, 640, 640]}},
>>> "opt_shape" : [1, 3, 320, 320],
>>> "max_shape" : [1, 3, 640, 640]}},
>>> log_level=trt.Logger.WARNING,
>>> fp16_mode=True,
>>> max_workspace_size=1 << 30,
@ -62,7 +69,7 @@ def create_trt_engine(onnx_model,
error_msgs = ''
for error in range(parser.num_errors):
error_msgs += f'{parser.get_error(error)}\n'
raise RuntimeError(f'parse onnx failed:\n{error_msgs}')
raise RuntimeError(f'Failed to parse onnx: {onnx_model}\n{error_msgs}')
# config builder
if version.parse(trt.__version__) < version.parse('8'):
@ -72,7 +79,7 @@ def create_trt_engine(onnx_model,
config.max_workspace_size = max_workspace_size
profile = builder.create_optimization_profile()
for input_name, param in opt_shape_dict.items():
for input_name, param in input_shapes.items():
min_shape = param['min_shape']
opt_shape = param['opt_shape']
max_shape = param['max_shape']
@ -88,7 +95,7 @@ def create_trt_engine(onnx_model,
assert int8_param is not None
config.int8_calibrator = HDF5Calibrator(
int8_param['calib_file'],
opt_shape_dict,
input_shapes,
model_type=int8_param['model_type'],
device_id=device_id,
algorithm=int8_param.get(
@ -100,29 +107,29 @@ def create_trt_engine(onnx_model,
# create engine
with torch.cuda.device(device):
engine = builder.build_engine(network, config)
assert engine is not None, f'Failed to create engine from {onnx_model}'
return engine
def save_trt_engine(engine, path):
def save_trt_engine(engine: trt.ICudaEngine, path: str):
"""Serialize TensorRT engine to disk.
Arguments:
engine (tensorrt.ICudaEngine): TensorRT engine to serialize
path (str): disk path to write the engine
Args:
engine (tensorrt.ICudaEngine): TensorRT engine to be serialized.
path (str): The disk path to write the engine.
"""
with open(path, mode='wb') as f:
f.write(bytearray(engine.serialize()))
def load_trt_engine(path):
def load_trt_engine(path: str):
"""Deserialize TensorRT engine from disk.
Arguments:
path (str): disk path to read the engine
Args:
path (str): The disk path to read the engine.
Returns:
tensorrt.ICudaEngine: the TensorRT engine loaded from disk
tensorrt.ICudaEngine: The TensorRT engine loaded from disk.
"""
with trt.Logger() as logger, trt.Runtime(logger) as runtime:
with open(path, mode='rb') as f:
@ -131,8 +138,16 @@ def load_trt_engine(path):
return engine
def torch_dtype_from_trt(dtype):
"""Convert pytorch dtype to TensorRT dtype."""
def torch_dtype_from_trt(dtype: trt.DataType):
"""Convert pytorch dtype to TensorRT dtype.
Args:
dtype (trt.DataType): The data type in tensorrt.
Returns:
torch.dtype: The corresponding data type in torch.
"""
if dtype == trt.bool:
return torch.bool
elif dtype == trt.int8:
@ -144,38 +159,54 @@ def torch_dtype_from_trt(dtype):
elif dtype == trt.float32:
return torch.float32
else:
raise TypeError('%s is not supported by torch' % dtype)
raise TypeError(f'{dtype} is not supported by torch')
def torch_device_from_trt(device):
"""Convert pytorch device to TensorRT device."""
def torch_device_from_trt(device: trt.TensorLocation):
"""Convert pytorch device to TensorRT device.
Args:
device (trt.TensorLocation): The device in tensorrt.
Returns:
torch.device: The corresponding device in torch.
"""
if device == trt.TensorLocation.DEVICE:
return torch.device('cuda')
elif device == trt.TensorLocation.HOST:
return torch.device('cpu')
else:
return TypeError('%s is not supported by torch' % device)
raise TypeError(f'{device} is not supported by torch')
class TRTWrapper(torch.nn.Module):
"""TensorRT engine Wrapper.
"""TensorRT engine wrapper for inference.
Arguments:
engine (tensorrt.ICudaEngine): TensorRT engine to wrap
Args:
engine (tensorrt.ICudaEngine): TensorRT engine to wrap.
Note:
If the engine is converted from an onnx model, the input_names and
output_names should be the same as in the onnx model.
Examples:
>>> from mmdeploy.apis.tensorrt import TRTWrapper
>>> import torch
>>> engine_file = 'resnet.engine'
>>> model = TRTWrapper(engine_file)
>>> inputs = dict(input=torch.randn(1, 3, 224, 224))
>>> outputs = model(inputs)
>>> print(outputs)
"""
def __init__(self, engine):
def __init__(self, engine: Union[str, trt.ICudaEngine]):
super(TRTWrapper, self).__init__()
self.engine = engine
if isinstance(self.engine, str):
self.engine = load_trt_engine(engine)
if not isinstance(self.engine, trt.ICudaEngine):
raise TypeError('engine should be str or trt.ICudaEngine')
raise TypeError(f'`engine` should be str or trt.ICudaEngine, \
but given: {type(self.engine)}')
self._register_state_dict_hook(TRTWrapper._on_state_dict)
self.context = self.engine.create_execution_context()
@ -206,13 +237,14 @@ class TRTWrapper(torch.nn.Module):
self.input_names = state_dict[prefix + 'input_names']
self.output_names = state_dict[prefix + 'output_names']
def forward(self, inputs):
"""
Arguments:
inputs (dict): dict of input name-tensors pair
def forward(self, inputs: Dict[str, torch.Tensor]):
"""Run forward inference.
Args:
inputs (Dict[str, torch.Tensor]): The input name and tensor pairs.
Returns:
dict: dict of output name-tensors pair
Dict[str, torch.Tensor]: The output name and tensor pairs.
"""
assert self.input_names is not None
assert self.output_names is not None
@ -243,6 +275,11 @@ class TRTWrapper(torch.nn.Module):
return outputs
@TimeCounter.count_time()
def trt_execute(self, bindings):
def trt_execute(self, bindings: Sequence[int]):
"""Run inference with TensorRT.
Args:
bindings (list[int]): A list of integers binding the inputs/outputs.
"""
self.context.execute_async_v2(bindings,
torch.cuda.current_stream().cuda_stream)

View File

@ -1,20 +1,35 @@
import warnings
from typing import Any
from typing import Optional
import mmcv
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset
from mmdeploy.utils import Codebase
def single_gpu_test(codebase: str,
def single_gpu_test(codebase: Codebase,
model: nn.Module,
data_loader: DataLoader,
show: bool = False,
out_dir: Any = None,
out_dir: Optional[str] = None,
show_score_thr: float = 0.3):
"""Run test with single gpu.
Args:
codebase (Codebase): Specifying codebase type.
model (torch.nn.Module): Input model to test.
data_loader (DataLoader): PyTorch data loader.
show (bool): Specifying whether to show plotted results. Defaults
to `False`.
out_dir (str): A directory to save results, defaults to `None`.
show_score_thr (float): A threshold to show detection results,
defaults to `0.3`.
Returns:
list: The prediction results.
"""
if codebase == Codebase.MMCLS:
from mmcls.apis import single_gpu_test
outputs = single_gpu_test(model, data_loader, show, out_dir)
@ -36,14 +51,35 @@ def single_gpu_test(codebase: str,
return outputs
def post_process_outputs(outputs,
dataset,
def post_process_outputs(outputs: list,
dataset: Dataset,
model_cfg: mmcv.Config,
codebase: str,
metrics: str = None,
out: str = None,
metric_options: dict = None,
codebase: Codebase,
metrics: Optional[str] = None,
out: Optional[str] = None,
metric_options: Optional[dict] = None,
format_only: bool = False):
"""Perform post-processing to predictions of model.
Args:
outputs (list): A list of predictions of model inference.
dataset (Dataset): Input dataset to run test.
model_cfg (mmcv.Config): The model config.
codebase (Codebase): Specifying codebase type.
metrics (str): Evaluation metrics, which depends on
the codebase and the dataset, e.g., "bbox", "segm", "proposal"
for COCO, and "mAP", "recall" for PASCAL VOC in mmdet; "accuracy",
"precision", "recall", "f1_score", "support" for single label
dataset, and "mAP", "CP", "CR", "CF1", "OP", "OR", "OF1" for
multi-label dataset in mmcls. Defaults to `None`.
out (str): Output result file in pickle format, defaults to `None`.
metric_options (dict): Custom options for evaluation, will be kwargs
for dataset.evaluate() function. Defaults to `None`.
format_only (bool): Format the output results without performing
evaluation. It is useful when you want to format the result
to a specific format and submit it to the test server. Defaults
to `False`.
"""
if codebase == Codebase.MMCLS:
if metrics:
results = dataset.evaluate(outputs, metrics, metric_options)
@ -121,7 +157,7 @@ def post_process_outputs(outputs,
print(f'\nwriting results to {out}')
mmcv.dump(outputs, out)
# The Dataset doesn't need metrics
print('')
print('\n')
# print metrics
stats = dataset.evaluate(outputs)
for stat in stats:
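
For reference, a sketch of how the two helpers above chain together with the builders from mmdeploy.apis.utils shown later in this diff; all paths, the backend model file and the import locations are assumptions:

# Hypothetical example: paths are placeholders; 'accuracy' is one of the
# mmcls metrics listed in the docstring above.
import mmcv
from mmdeploy.utils import Codebase
from mmdeploy.apis import post_process_outputs, single_gpu_test
from mmdeploy.apis.utils import (build_dataloader, build_dataset,
                                 init_backend_model)

model_cfg = mmcv.Config.fromfile('model_cfg.py')
model = init_backend_model(['end2end.onnx'], model_cfg, 'deploy_cfg.py',
                           device_id=0)
dataset = build_dataset(Codebase.MMCLS, model_cfg, dataset_type='val')
data_loader = build_dataloader(Codebase.MMCLS, dataset, samples_per_gpu=1,
                               workers_per_gpu=1)
outputs = single_gpu_test(Codebase.MMCLS, model, data_loader)
post_process_outputs(outputs, dataset, model_cfg, Codebase.MMCLS,
                     metrics='accuracy')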

View File

@ -2,6 +2,8 @@ from typing import Any, Dict, Optional, Sequence, Union
import mmcv
import numpy as np
import torch
from torch.utils.data import Dataset
from mmdeploy.utils import Backend, Codebase, Task, get_codebase, load_config
@ -11,6 +13,19 @@ def init_pytorch_model(codebase: Codebase,
model_checkpoint: Optional[str] = None,
device: str = 'cuda:0',
cfg_options: Optional[Dict] = None):
"""Initialize torch model.
Args:
codebase (Codebase): Specifying codebase type.
model_cfg (str | mmcv.Config): Model config file or Config object.
model_checkpoint (str): The checkpoint file of torch model, defaults
to `None`.
device (str): A string specifying device type, defaults to 'cuda:0'.
cfg_options (dict): Optional config key-value pairs, defaults to `None`.
Returns:
nn.Module: An initialized torch model.
"""
if codebase == Codebase.MMCLS:
from mmcls.apis import init_model
model = init_model(model_cfg, model_checkpoint, device, cfg_options)
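
For reference, a short usage sketch of `init_pytorch_model` as documented above; the config and checkpoint paths and the import location are assumptions:

# Hypothetical example: paths are placeholders.
from mmdeploy.utils import Codebase
from mmdeploy.apis.utils import init_pytorch_model

torch_model = init_pytorch_model(
    Codebase.MMCLS,
    'model_cfg.py',
    model_checkpoint='model.pth',
    device='cuda:0')
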
@ -46,6 +61,22 @@ def create_input(codebase: Codebase,
input_shape: Sequence[int] = None,
device: str = 'cuda:0',
**kwargs):
"""Create input for model.
Args:
codebase (Codebase): Specifying codebase type.
task (Task): Specifying task type.
model_cfg (str | mmcv.Config): Model config file or loaded Config
object.
imgs (str | np.ndarray): Input image(s).
input_shape (list[int]): Input shape of image in (width, height)
format, defaults to `None`.
device (str): A string specifying device type, defaults to 'cuda:0'.
Returns:
tuple: (data, img), meta information for the input image and input
image tensor.
"""
model_cfg = load_config(model_cfg)[0]
cfg = model_cfg.copy()
@ -78,6 +109,19 @@ def init_backend_model(model_files: Sequence[str],
deploy_cfg: Union[str, mmcv.Config],
device_id: int = 0,
**kwargs):
"""Initialize backend model.
Args:
model_files (list[str]): Input model files.
model_cfg (str | mmcv.Config): Model config file or
loaded Config object.
deploy_cfg (str | mmcv.Config): Deployment config file or
loaded Config object.
device_id (int): An integer specifying device index.
Returns:
nn.Module: An initialized model.
"""
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
codebase = get_codebase(deploy_cfg)
@ -111,7 +155,19 @@ def init_backend_model(model_files: Sequence[str],
raise NotImplementedError(f'Unknown codebase type: {codebase.value}')
def run_inference(codebase: Codebase, model_inputs, model):
def run_inference(codebase: Codebase, model_inputs: dict,
model: torch.nn.Module):
"""Run once inference for a model of nn.Module.
Args:
codebase (Codebase): Specifying codebase type.
model_inputs (dict): A dict containing the model input tensors and
meta info.
model (nn.Module): Input model.
Returns:
list: The predictions of model inference.
"""
if codebase == Codebase.MMCLS:
return model(**model_inputs, return_loss=False)[0]
elif codebase == Codebase.MMDET:
@ -133,12 +189,23 @@ def run_inference(codebase: Codebase, model_inputs, model):
def visualize(codebase: Codebase,
image: Union[str, np.ndarray],
result,
model,
result: list,
model: torch.nn.Module,
output_file: str,
backend: Backend,
show_result=False):
show_result: bool = False):
"""Visualize predictions of a model.
Args:
codebase (Codebase): Specifying codebase type.
image (str | np.ndarray): Input image to draw predictions on.
result (list): A list of predictions.
model (nn.Module): Input model.
output_file (str): Output file to save drawn image.
backend (Backend): Specifying backend type.
show_result (bool): Whether to show result in windows, defaults
to `False`.
"""
show_img = mmcv.imread(image) if isinstance(image, str) else image
output_file = None if show_result else output_file
@ -164,6 +231,18 @@ def visualize(codebase: Codebase,
def get_partition_cfg(codebase: Codebase, partition_type: str):
"""Get a certain partition config.
Notes:
Currently only the mmdet codebase is supported.
Args:
codebase (Codebase): Specifying codebase type.
partition_type (str): A string specifying partition type.
Returns:
dict: A dictionary of partition config.
"""
if codebase == Codebase.MMDET:
from mmdeploy.mmdet.export import get_partition_cfg \
as get_partition_cfg_mmdet
@ -176,6 +255,17 @@ def build_dataset(codebase: Codebase,
dataset_cfg: Union[str, mmcv.Config],
dataset_type: str = 'val',
**kwargs):
"""Build dataset for different codebase.
Args:
codebase (Codebase): Specifying codebase type.
dataset_cfg (str | mmcv.Config): Dataset config file or Config object.
dataset_type (str): Specifying dataset type, e.g.: 'train', 'test',
'val', defaults to 'val'.
Returns:
Dataset: The built dataset.
"""
if codebase == Codebase.MMCLS:
from mmdeploy.mmcls.export import build_dataset \
as build_dataset_mmcls
@ -198,8 +288,21 @@ def build_dataset(codebase: Codebase,
raise NotImplementedError(f'Unknown codebase type: {codebase.value}')
def build_dataloader(codebase: Codebase, dataset, samples_per_gpu: int,
workers_per_gpu: int, **kwargs):
def build_dataloader(codebase: Codebase, dataset: Dataset,
samples_per_gpu: int, workers_per_gpu: int, **kwargs):
"""Build PyTorch dataloader.
Args:
codebase (Codebase): Specifying codebase type.
dataset (Dataset): A PyTorch dataset.
samples_per_gpu (int): Number of training samples on each GPU, i.e.,
batch size of each GPU.
workers_per_gpu (int): How many subprocesses to use for data loading
for each GPU.
Returns:
DataLoader: A PyTorch dataloader.
"""
if codebase == Codebase.MMCLS:
from mmdeploy.mmcls.export import build_dataloader \
as build_dataloader_mmcls
@ -229,7 +332,16 @@ def build_dataloader(codebase: Codebase, dataset, samples_per_gpu: int,
raise NotImplementedError(f'Unknown codebase type: {codebase.value}')
def get_tensor_from_input(codebase: Codebase, input_data):
def get_tensor_from_input(codebase: Codebase, input_data: tuple):
"""Get input tensor from input data.
Args:
codebase (Codebase): Specifying codebase type.
input_data (tuple): Input data containing meta info and image tensor.
Returns:
torch.Tensor: Input tensor of image.
"""
if codebase == Codebase.MMCLS:
from mmdeploy.mmcls.export import get_tensor_from_input \
as get_tensor_from_input_mmcls

View File

@ -42,7 +42,15 @@ def _dfs_search_reacable_nodes_fast(self, node_output_name, graph_input_nodes,
impl(node_output_name, graph_input_nodes, reachable_nodes)
def create_extractor(model):
def create_extractor(model: onnx.ModelProto):
"""Create Extractor for ONNX.
Args:
model (onnx.ModelProto): An input onnx model.
Returns:
Extractor: Extractor for the onnx model.
"""
assert version.parse(onnx.__version__) >= version.parse('1.8.0')
# patch extractor
onnx.utils.Extractor._dfs_search_reachable_nodes = \

View File

@ -91,6 +91,9 @@ def main():
# load deploy_cfg
deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)
# merge options for model cfg
if args.cfg_options is not None:
model_cfg.merge_from_dict(args.cfg_options)
# prepare the dataset loader
codebase = get_codebase(deploy_cfg)