[Docstring]: Coding style and docstring revision for mmdeploy.apis (#87)
* check style of mmdeploy.apis.ncnn
* finish check style with mmdeploy.apis.onnxruntime
* check mmdeploy.apis.ppl
* check mmdeploy.apis.tensorrt
* update docstring for mmdeploy.apis
* update some docstring
* make style consistent
* update
* resolve comments
* resolve comments

parent 5453f9befa
commit de096d5f00
@ -20,13 +20,23 @@ def create_calib_table(calib_file: str,
                       dataset_type: str = 'val',
                       device: str = 'cuda:0',
                       **kwargs) -> None:
    """Create calibration table.

    Args:
        calib_file (str): Input calibration file.
        deploy_cfg (str | mmcv.Config): Deployment config.
        model_cfg (str | mmcv.Config): The model config.
        model_checkpoint (str): PyTorch model checkpoint, defaults to `None`.
        dataset_cfg (str | mmcv.Config): Dataset config, defaults to `None`
        dataset_type (str): A string specifying dataset type, e.g.: 'test',
            'val', defaults to 'val'.
        device (str): Specifying the device to run on, defaults to 'cuda:0'.
    """
    if dataset_cfg is None:
        dataset_cfg = model_cfg

    # load cfg if necessary
    deploy_cfg = load_config(deploy_cfg)[0]
    model_cfg = load_config(model_cfg)[0]
    deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
    device_id = torch.device(device).index
    if device_id is None:
        device_id = 0
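For reference, a minimal usage sketch of create_calib_table (not part of this diff; the config and file paths are placeholders, and the import from mmdeploy.apis is assumed):

    >>> from mmdeploy.apis import create_calib_table  # assumed export path
    >>> create_calib_table(
    >>>     'work_dir/calib_data.h5',            # placeholder calibration file
    >>>     deploy_cfg='configs/deploy_cfg.py',
    >>>     model_cfg='configs/model_cfg.py',
    >>>     model_checkpoint='checkpoint.pth',
    >>>     dataset_type='val',
    >>>     device='cuda:0')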
@ -17,7 +17,27 @@ def extract_model(model: Union[str, onnx.ModelProto],
                  end_name_map: Optional[Dict[str, str]] = None,
                  dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None,
                  save_file: Optional[str] = None):
    """Extract sub-model from an ONNX model.

    The sub-model is defined by the names of the input and output tensors
    exactly.

    Args:
        model (str | onnx.ModelProto): Input ONNX model to be extracted.
        start (str | Sequence[str]): Start marker(s) to extract.
        end (str | Sequence[str]): End marker(s) to extract.
        start_name_map (Dict[str, str]): A mapping of start names, defaults to
            `None`.
        end_name_map (Dict[str, str]): A mapping of end names, defaults to
            `None`.
        dynamic_axes (Dict[str, Dict[int, str]]): A dictionary to specify
            dynamic axes of input/output, defaults to `None`.
        save_file (str): A file to save the extracted model, defaults to
            `None`.

    Returns:
        onnx.ModelProto: The extracted sub-model.
    """
    if isinstance(model, str):
        model = onnx.load(model)
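For reference, a minimal usage sketch of extract_model (not part of this diff; the tensor-name markers and file names below are placeholders, and the import from mmdeploy.apis is assumed):

    >>> from mmdeploy.apis import extract_model  # assumed export path
    >>> # 'start' and 'end' are tensor names in the ONNX graph; placeholders here
    >>> sub_model = extract_model(
    >>>     'end2end.onnx', start='input', end='output',
    >>>     save_file='partition.onnx')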
@ -1,5 +1,7 @@
from typing import Optional
from typing import Optional, Sequence, Union

import mmcv
import numpy as np
import torch

from mmdeploy.utils import (Backend, get_backend, get_codebase,
@ -8,15 +10,29 @@ from .utils import (create_input, init_backend_model, init_pytorch_model,
                    run_inference, visualize)


def inference_model(model_cfg,
                    deploy_cfg,
                    model,
                    img,
def inference_model(model_cfg: Union[str, mmcv.Config],
                    deploy_cfg: Union[str, mmcv.Config],
                    model: Union[str, Sequence[str], torch.nn.Module],
                    img: Union[str, np.ndarray],
                    device: str,
                    backend: Optional[Backend] = None,
                    output_file: Optional[str] = None,
                    show_result=False):
                    show_result: bool = False):
    """Run inference with PyTorch or backend model and show results.

    Args:
        model_cfg (str | mmcv.Config): Model config file or Config object.
        deploy_cfg (str | mmcv.Config): Deployment config file or Config
            object.
        model (str | list[str] | torch.nn.Module): Input model or file(s).
        img (str | np.ndarray): Input image file or numpy array for inference.
        device (str): A string specifying device type.
        backend (Backend): Specifying backend type, defaults to `None`.
        output_file (str): Output file to save visualized image, defaults to
            `None`. Only valid if `show_result` is set to `False`.
        show_result (bool): Whether to show plotted image in windows, defaults
            to `False`.
    """
    deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)

    codebase = get_codebase(deploy_cfg)
@ -27,7 +43,12 @@ def inference_model(model_cfg,

    if isinstance(model, str):
        model = [model]

    if isinstance(model, (list, tuple)):
        assert len(model) > 0, 'Model should have at least one element.'
        assert all([isinstance(m, str) for m in model]), 'All elements in the \
            list should be str'

    if backend == Backend.PYTORCH:
        model = init_pytorch_model(codebase, model_cfg, model[0], device)
    else:
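For reference, a minimal usage sketch of inference_model with a PyTorch checkpoint (not part of this diff; config paths and file names are placeholders, and the import from mmdeploy.apis is assumed):

    >>> from mmdeploy.apis import inference_model  # assumed export path
    >>> inference_model(
    >>>     model_cfg='configs/model_cfg.py',    # placeholder configs
    >>>     deploy_cfg='configs/deploy_cfg.py',
    >>>     model=['checkpoint.pth'],            # one file per model
    >>>     img='demo.jpg',
    >>>     device='cuda:0',
    >>>     backend=Backend.PYTORCH,             # optional, defaults to `None`
    >>>     output_file='output.jpg')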
@ -7,6 +7,11 @@ __all__ = ['get_ops_path', 'get_onnx2ncnn_path']


def is_available():
    """Check whether ncnn with extension is installed.

    Returns:
        bool: True if ncnn and its extension are installed.
    """
    ncnn_ops_path = get_ops_path()
    if not osp.exists(ncnn_ops_path):
        return False

@ -3,7 +3,11 @@ import os


def get_ops_path():
    """Get NCNN custom ops library path."""
    """Get NCNN custom ops library path.

    Returns:
        str: The library path of NCNN custom ops.
    """
    wildcard = os.path.abspath(
        os.path.join(
            os.path.dirname(__file__),

@ -15,7 +19,11 @@ def get_ops_path():


def get_onnx2ncnn_path():
    """Get onnx2ncnn path."""
    """Get onnx2ncnn path.

    Returns:
        str: A path of onnx2ncnn tool.
    """
    wildcard = os.path.abspath(
        os.path.join(
            os.path.dirname(__file__), '../../../build/bin/onnx2ncnn'))
@ -9,11 +9,24 @@ from mmdeploy.utils.timer import TimeCounter


class NCNNWrapper(torch.nn.Module):
    """NCNN Wrapper.
    """NCNN wrapper class for inference.

    Arguments:
        param_file (str): param file path
        bin_file (str): bin file path
    Args:
        param_file (str): Path of a parameter file.
        bin_file (str): Path of a binary file.
        output_names (list[str] | tuple[str]): Names to model outputs. Defaults
            to `None`.

    Examples:
        >>> from mmdeploy.apis.ncnn import NCNNWrapper
        >>> import torch
        >>>
        >>> param_file = 'model.params'
        >>> bin_file = 'model.bin'
        >>> model = NCNNWrapper(param_file, bin_file)
        >>> inputs = dict(input=torch.randn(1, 3, 224, 224))
        >>> outputs = model(inputs)
        >>> print(outputs)
    """

    def __init__(self,
@ -31,10 +44,20 @@ class NCNNWrapper(torch.nn.Module):
        self._net = net
        self._output_names = output_names

    def set_output_names(self, output_names):
    def set_output_names(self, output_names: Iterable[str]):
        """Set names of the model outputs.

        Args:
            output_names (list[str] | tuple[str]): Names to model outputs.
        """
        self._output_names = output_names

    def get_output_names(self):
        """Get names of the model outputs.

        Returns:
            list[str]: Names to model outputs.
        """
        if self._output_names is not None:
            return self._output_names
        else:
@ -42,12 +65,21 @@ class NCNNWrapper(torch.nn.Module):
            return self._net.output_names()

    def forward(self, inputs: Dict[str, torch.Tensor]):
        """Run forward inference.

        Args:
            inputs (Dict[str, torch.Tensor]): Key-value pairs of model inputs.

        Returns:
            Dict[str, torch.Tensor]: Key-value pairs of model outputs.
        """
        input_list = list(inputs.values())
        batch_size = input_list[0].size(0)
        for tensor in input_list[1:]:
            assert tensor.size(
                0) == batch_size, 'All tensor should have same batch size'
            assert tensor.device.type == 'cpu', 'NCNN only support cpu device'
        for input_tensor in input_list[1:]:
            assert input_tensor.size(
                0) == batch_size, 'All tensors should have same batch size'
            assert input_tensor.device.type == 'cpu', \
                'NCNN only supports cpu device'

        # set output names
        output_names = self.get_output_names()
@ -55,30 +87,42 @@ class NCNNWrapper(torch.nn.Module):
        # create output dict
        outputs = dict([name, [None] * batch_size] for name in output_names)

        # inference
        # run inference
        for batch_id in range(batch_size):
            # create extractor
            ex = self._net.create_extractor()

            # set input
            for k, v in inputs.items():
                in_data = ncnn.Mat(v[batch_id].detach().cpu().numpy())
                ex.input(k, in_data)
            # set inputs
            for name, input_tensor in inputs.items():
                input_mat = ncnn.Mat(
                    input_tensor[batch_id].detach().cpu().numpy())
                ex.input(name, input_mat)

            # get output
            # get outputs
            result = self.ncnn_execute(extractor=ex, output_names=output_names)
            for name in output_names:
                outputs[name][batch_id] = torch.from_numpy(
                    np.array(result[name]))

        # stack outputs together
        for k, v in outputs.items():
            outputs[k] = torch.stack(v)
        for name, input_tensor in outputs.items():
            outputs[name] = torch.stack(input_tensor)

        return outputs

    @TimeCounter.count_time()
    def ncnn_execute(self, extractor, output_names):
    def ncnn_execute(self, extractor: ncnn.Extractor,
                     output_names: Iterable[str]):
        """Run inference with NCNN.

        Args:
            extractor (ncnn.Extractor): NCNN extractor to extract output.
            output_names (Iterable[str]): A list of string specifying
                output names.

        Returns:
            dict[str, ncnn.Mat]: Inference results of NCNN model.
        """
        result = {}
        for name in output_names:
            out_ret, out = extractor.extract(name)
@ -5,6 +5,12 @@ from .init_plugins import get_ops_path


def is_available():
    """Check whether onnxruntime and its custom ops are installed.

    Returns:
        bool: True if onnxruntime package is installed and its
            custom ops are compiled.
    """
    onnxruntime_op_path = get_ops_path()
    if not osp.exists(onnxruntime_op_path):
        return False

@ -3,7 +3,11 @@ import os


def get_ops_path():
    """Get ONNX Runtime plugins library path."""
    """Get the library path of onnxruntime custom ops.

    Returns:
        str: The library path to onnxruntime custom ops.
    """
    wildcard = os.path.abspath(
        os.path.join(
            os.path.dirname(__file__),
@ -1,5 +1,6 @@
import logging
import os.path as osp
from typing import Sequence
from typing import Dict, Sequence

import numpy as np
import onnxruntime as ort
@ -10,11 +11,22 @@ from .init_plugins import get_ops_path


class ORTWrapper(torch.nn.Module):
    """ONNXRuntime Wrapper.
    """ONNXRuntime wrapper for inference.

    Arguments:
        onnx_file (str): Input onnx model file
        device_id (int): The device id to put model
    Args:
        onnx_file (str): Input onnx model file.
        device_id (int): The device id to input model.
        output_names (list[str] | tuple[str]): Names to model outputs.

    Examples:
        >>> from mmdeploy.apis.onnxruntime import ORTWrapper
        >>> import torch
        >>>
        >>> onnx_file = 'model.onnx'
        >>> model = ORTWrapper(onnx_file, -1)
        >>> inputs = dict(input=torch.randn(1, 3, 224, 224, device='cpu'))
        >>> outputs = model(inputs)
        >>> print(outputs)
    """

    def __init__(self,
@ -28,6 +40,12 @@ class ORTWrapper(torch.nn.Module):
        # register custom op for onnxruntime
        if osp.exists(ort_custom_op_path):
            session_options.register_custom_ops_library(ort_custom_op_path)
            logging.info(f'Successfully loaded onnxruntime custom ops from \
                {ort_custom_op_path}')
        else:
            logging.warning(f'The library of onnxruntime custom ops does \
                not exist: {ort_custom_op_path}')

        sess = ort.InferenceSession(onnx_file, session_options)

        providers = ['CPUExecutionProvider']
@ -46,13 +64,14 @@ class ORTWrapper(torch.nn.Module):
        self.is_cuda_available = is_cuda_available
        self.device_type = 'cuda' if is_cuda_available else 'cpu'

    def forward(self, inputs):
        """
        Arguments:
            inputs (dict): the input name and tensor pairs
            input_names: list of input name
        Return:
            list[np.ndarray]: list of output numpy array
    def forward(self, inputs: Dict[str, torch.Tensor]):
        """Run forward inference.

        Args:
            inputs (Dict[str, torch.Tensor]): The input name and tensor pairs.

        Returns:
            list[np.ndarray]: A list of output numpy array.
        """
        for name, input_tensor in inputs.items():
            # set io binding for inputs/outputs
@ -75,5 +94,11 @@ class ORTWrapper(torch.nn.Module):
        return outputs

    @TimeCounter.count_time()
    def ort_execute(self, io_binding):
    def ort_execute(self, io_binding: ort.IOBinding):
        """Run inference with ONNXRuntime session.

        Args:
            io_binding (ort.IOBinding): To bind input/output to a specified
                device, e.g. GPU.
        """
        self.sess.run_with_iobinding(io_binding)
@ -2,7 +2,11 @@ import importlib


def is_available():
    """Check whether ppl is installed."""
    """Check whether ppl is installed.

    Returns:
        bool: True if ppl package is installed.
    """
    return importlib.util.find_spec('pyppl') is not None

@ -1,5 +1,6 @@
import logging
import sys
from typing import Dict

import numpy as np
import pyppl.common as pplcommon
@ -15,9 +16,14 @@ def register_engines(device_id: int,
    """Register engines for ppl runtime.

    Args:
        device_id (int): -1 for cpu.
        device_id (int): Specifying device index. `-1` for cpu.
        disable_avx512 (bool): Whether to disable avx512 for x86.
            Defaults to `False`.
        quick_select (bool): Whether to use default algorithms.
            Defaults to `False`.

    Returns:
        list[pplnn.Engine]: A list of registered ppl engines.
    """
    engines = []
    if device_id == -1:
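For reference, a minimal call sketch of register_engines based on the signature above (not part of this diff; device index 0 picks a cuda device, -1 falls back to cpu):

    >>> engines = register_engines(0, disable_avx512=False, quick_select=False)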
@ -59,11 +65,21 @@ def register_engines(device_id: int,


class PPLWrapper(torch.nn.Module):
    """PPLWrapper Wrapper.
    """PPL wrapper for inference.

    Arguments:
        model_file (str): Input onnx model file
        device_id (int): The device id to put model
    Args:
        model_file (str): Input onnx model file.
        device_id (int): Device id to put model.

    Examples:
        >>> from mmdeploy.apis.ppl import PPLWrapper
        >>> import torch
        >>>
        >>> onnx_file = 'model.onnx'
        >>> model = PPLWrapper(onnx_file, 0)
        >>> inputs = dict(input=torch.randn(1, 3, 224, 224))
        >>> outputs = model(inputs)
        >>> print(outputs)
    """

    def __init__(self, model_file: str, device_id: int):
@ -87,14 +103,16 @@ class PPLWrapper(torch.nn.Module):
            for i in range(runtime.GetInputCount())
        }

    def forward(self, input_data):
        """
        Arguments:
            input_data (dict): the input name and tensor pairs
    def forward(self, inputs: Dict[str, torch.Tensor]):
        """Run forward inference.

        Args:
            inputs (Dict[str, torch.Tensor]): Input name and tensor pairs.

        Return:
            list[np.ndarray]: list of output numpy array
            list[np.ndarray]: A list of output numpy array.
        """
        for name, input_tensor in input_data.items():
        for name, input_tensor in inputs.items():
            input_tensor = input_tensor.contiguous()
            self.inputs[name].ConvertFromHost(input_tensor.cpu().numpy())
        self.ppl_execute()
@ -106,6 +124,7 @@ class PPLWrapper(torch.nn.Module):

    @TimeCounter.count_time()
    def ppl_execute(self):
        """Run inference with PPL."""
        status = self.runtime.Run()
        assert status == pplcommon.RC_SUCCESS, 'Run() '\
            'failed: ' + pplcommon.GetRetCodeStr(status)
@ -13,12 +13,17 @@ from .utils import create_input, init_pytorch_model

def torch2onnx_impl(model: torch.nn.Module, input: torch.Tensor,
                    deploy_cfg: Union[str, mmcv.Config], output_file: str):
    """Converting torch model to ONNX.

    Args:
        model (torch.nn.Module): Input pytorch model.
        input (torch.Tensor): Input tensor used to convert model.
        deploy_cfg (str | mmcv.Config): Deployment config file or
            Config object.
        output_file (str): Output file to save ONNX model.
    """
    # load deploy_cfg if needed
    if isinstance(deploy_cfg, str):
        deploy_cfg = mmcv.Config.fromfile(deploy_cfg)
    if not isinstance(deploy_cfg, mmcv.Config):
        raise TypeError('deploy_cfg must be a filename or Config object, '
                        f'but got {type(deploy_cfg)}')
    deploy_cfg = load_config(deploy_cfg)[0]

    pytorch2onnx_cfg = get_onnx_config(deploy_cfg)
    backend = get_backend(deploy_cfg).value
@ -51,7 +56,20 @@ def torch2onnx(img: Any,
               model_cfg: Union[str, mmcv.Config],
               model_checkpoint: Optional[str] = None,
               device: str = 'cuda:0'):
    """Convert PyTorch model to ONNX model.

    Args:
        img (str | np.ndarray | torch.Tensor): Input image used to assist
            converting model.
        work_dir (str): A working directory to save files.
        save_file (str): Filename to save onnx model.
        deploy_cfg (str | mmcv.Config): Deployment config file or
            Config object.
        model_cfg (str | mmcv.Config): Model config file or Config object.
        model_checkpoint (str): A checkpoint path of PyTorch model,
            defaults to `None`.
        device (str): A string specifying device type, defaults to 'cuda:0'.
    """
    # load deploy_cfg if necessary
    deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
    mmcv.mkdir_or_exist(osp.abspath(work_dir))
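For reference, a minimal usage sketch of torch2onnx (not part of this diff; paths and config names are placeholders, and the import from mmdeploy.apis is assumed):

    >>> from mmdeploy.apis import torch2onnx  # assumed export path
    >>> torch2onnx(
    >>>     'demo.jpg',                          # image used to build the export input
    >>>     work_dir='work_dir',
    >>>     save_file='end2end.onnx',
    >>>     deploy_cfg='configs/deploy_cfg.py',  # placeholder configs
    >>>     model_cfg='configs/model_cfg.py',
    >>>     model_checkpoint='checkpoint.pth',
    >>>     device='cuda:0')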
@ -6,6 +6,11 @@ from .init_plugins import get_ops_path, load_tensorrt_plugin


def is_available():
    """Check whether TensorRT and plugins are installed.

    Returns:
        bool: True if TensorRT and plugins are installed.
    """
    tensorrt_op_path = get_ops_path()
    if not osp.exists(tensorrt_op_path):
        return False

@ -1,3 +1,5 @@
from typing import Dict, Sequence, Union

import h5py
import numpy as np
import tensorrt as trt
@ -7,14 +9,26 @@ DEFAULT_CALIBRATION_ALGORITHM = trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2


class HDF5Calibrator(trt.IInt8Calibrator):
    """HDF5 calibrator.

    def __init__(self,
                 calib_file,
                 opt_shape_dict,
                 model_type='end2end',
                 device_id=0,
                 algorithm=DEFAULT_CALIBRATION_ALGORITHM,
                 **kwargs):
    Args:
        calib_file (str | h5py.File): Input calibration file.
        input_shapes (Dict[str, Sequence[int]]): The min/opt/max shape of
            each input.
        model_type (str): Input model type, defaults to 'end2end'.
        device_id (int): Cuda device id, defaults to 0.
        algorithm (trt.CalibrationAlgoType): Calibration algo type, defaults
            to `trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2`.
    """

    def __init__(
            self,
            calib_file: Union[str, h5py.File],
            input_shapes: Dict[str, Sequence[int]],
            model_type: str = 'end2end',
            device_id: int = 0,
            algorithm: trt.CalibrationAlgoType = DEFAULT_CALIBRATION_ALGORITHM,
            **kwargs):
        super().__init__()

        if isinstance(calib_file, str):
@ -29,7 +43,7 @@ class HDF5Calibrator(trt.IInt8Calibrator):
        self.calib_data = calib_data
        self.device_id = device_id
        self.algorithm = algorithm
        self.opt_shape_dict = opt_shape_dict
        self.input_shapes = input_shapes
        self.kwargs = kwargs

        # create buffers that will hold data batches

@ -45,7 +59,7 @@ class HDF5Calibrator(trt.IInt8Calibrator):
        if hasattr(self, 'calib_file'):
            self.calib_file.close()

    def get_batch(self, names, **kwargs):
    def get_batch(self, names: Sequence[str], **kwargs):
        if self.count < self.dataset_length:

            ret = []

@ -55,7 +69,7 @@ class HDF5Calibrator(trt.IInt8Calibrator):
            data_torch = torch.from_numpy(data_np)

            # tile the tensor so we can keep the same distribute
            opt_shape = self.opt_shape_dict[name]['opt_shape']
            opt_shape = self.input_shapes[name]['opt_shape']
            data_shape = data_torch.shape

            reps = [

@ -87,7 +101,7 @@ class HDF5Calibrator(trt.IInt8Calibrator):
        return self.batch_size

    def read_calibration_cache(self, *args, **kwargs):
        return None
        pass

    def write_calibration_cache(self, cache, *args, **kwargs):
        pass
@ -5,7 +5,11 @@ import os


def get_ops_path():
    """Get TensorRT plugins library path."""
    """Get path of the TensorRT plugin library.

    Returns:
        str: A path of the TensorRT plugin library.
    """
    wildcard = os.path.abspath(
        os.path.join(
            os.path.dirname(__file__),

@ -17,11 +21,18 @@ def get_ops_path():


def load_tensorrt_plugin():
    """load TensorRT plugins library."""
    """Load TensorRT plugins library.

    Returns:
        bool: True if TensorRT plugin library is successfully loaded.
    """
    lib_path = get_ops_path()
    success = False
    if os.path.exists(lib_path):
        ctypes.CDLL(lib_path)
        return 0
        logging.info(f'Successfully loaded tensorrt plugins from {lib_path}')
        success = True
    else:
        logging.warning('Can not load tensorrt custom ops.')
        return -1
        logging.warning(f'Could not load the library of tensorrt plugins. \
            Because the file does not exist: {lib_path}')
    return success
@ -10,7 +10,16 @@ from mmdeploy.utils import (get_calib_filename, get_common_config,
from .tensorrt_utils import create_trt_engine, save_trt_engine


def parse_device_id(device: str):
def parse_device_id(device: str) -> int:
    """Parse cuda device index from a string.

    Args:
        device (str): The typical style of string specifying cuda device,
            e.g.: 'cuda:0'.

    Returns:
        int: The parsed device id, defaults to `0`.
    """
    device_id = 0
    if len(device) >= 6:
        device_id = int(device[5:])
@ -25,6 +34,18 @@ def onnx2tensorrt(work_dir: str,
                  device: str = 'cuda:0',
                  partition_type: str = 'end2end',
                  **kwargs):
    """Convert ONNX to TensorRT.

    Args:
        work_dir (str): A working directory.
        save_file (str): File path to save TensorRT engine.
        model_id (int): Index of input model.
        deploy_cfg (str | mmcv.Config): Deployment config.
        onnx_model (str | onnx.ModelProto): input onnx model.
        device (str): A string specifying cuda device, defaults to 'cuda:0'.
        partition_type (str): Specifying partition type of a model, defaults to
            'end2end'.
    """

    # load deploy_cfg if necessary
    deploy_cfg = load_config(deploy_cfg)[0]
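For reference, a minimal usage sketch of onnx2tensorrt (not part of this diff; paths and the import location are placeholders/assumptions):

    >>> from mmdeploy.apis.tensorrt import onnx2tensorrt  # assumed export path
    >>> onnx2tensorrt(
    >>>     work_dir='work_dir',
    >>>     save_file='end2end.engine',
    >>>     model_id=0,
    >>>     deploy_cfg='configs/deploy_cfg.py',   # placeholder config
    >>>     onnx_model='work_dir/end2end.onnx',
    >>>     device='cuda:0')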
@ -43,11 +64,13 @@ def onnx2tensorrt(work_dir: str,
        int8_param['calib_file'] = osp.join(work_dir, calib_file)
        int8_param['model_type'] = partition_type

    assert device.startswith('cuda'), 'TensorRT require cuda device.'
    assert device.startswith('cuda'), f'TensorRT requires cuda device, \
        but given: {device}'

    device_id = parse_device_id(device)
    engine = create_trt_engine(
        onnx_model,
        opt_shape_dict=final_params['input_shapes'],
        input_shapes=final_params['input_shapes'],
        log_level=final_params.get('log_level', trt.Logger.WARNING),
        fp16_mode=final_params.get('fp16_mode', False),
        int8_mode=final_params.get('int8_mode', False),
@ -1,3 +1,5 @@
from typing import Dict, Sequence, Union

import onnx
import tensorrt as trt
import torch
@ -7,37 +9,42 @@ from mmdeploy.utils.timer import TimeCounter
from .calib_utils import HDF5Calibrator


def create_trt_engine(onnx_model,
                      opt_shape_dict,
                      log_level=trt.Logger.ERROR,
                      fp16_mode=False,
                      int8_mode=False,
                      int8_param=None,
                      max_workspace_size=0,
                      device_id=0,
def create_trt_engine(onnx_model: Union[str, onnx.ModelProto],
                      input_shapes: Dict[str, Sequence[int]],
                      log_level: trt.Logger.Severity = trt.Logger.ERROR,
                      fp16_mode: bool = False,
                      int8_mode: bool = False,
                      int8_param: dict = None,
                      max_workspace_size: int = 0,
                      device_id: int = 0,
                      **kwargs):
    """Create a tensorrt engine from ONNX.

    Arguments:
        onnx_model (str or onnx.ModelProto): the onnx model to convert from
        opt_shape_dict (dict): the min/opt/max shape of each input
        log_level (TensorRT log level): the log level of TensorRT
        fp16_mode (bool): enable fp16 mode
        int8_mode (bool): enable int8 mode
        int8_param (None|dict): parameter of int8 mode.
        max_workspace_size (int): set max workspace size of TensorRT engine.
            some tactic and layers need large workspace.
        device_id (int): choice the device to create engine.
    Args:
        onnx_model (str or onnx.ModelProto): Input onnx model to convert from.
        input_shapes (Dict[str, Sequence[int]]): The min/opt/max shape of
            each input.
        log_level (trt.Logger.Severity): The log level of TensorRT. Defaults to
            `trt.Logger.ERROR`.
        fp16_mode (bool): Specifying whether to enable fp16 mode.
            Defaults to `False`.
        int8_mode (bool): Specifying whether to enable int8 mode.
            Defaults to `False`.
        int8_param (dict): A dict of parameter int8 mode. Defaults to `None`.
        max_workspace_size (int): To set max workspace size of TensorRT engine.
            some tactics and layers need large workspace. Defaults to `0`.
        device_id (int): Choice the device to create engine. Defaults to `0`.

    Returns:
        tensorrt.ICudaEngine: the TensorRT engine created from onnx_model
        tensorrt.ICudaEngine: The TensorRT engine created from onnx_model.

    Example:
        >>> from mmdeploy.apis.tensorrt import create_trt_engine
        >>> engine = create_trt_engine(
        >>>     "onnx_model.onnx",
        >>>     {'input': {"min_shape" : [1, 3, 160, 160],
        >>>                "opt_shape" :[1, 3, 320, 320],
        >>>                "max_shape" :[1, 3, 640, 640]}},
        >>>                "opt_shape" : [1, 3, 320, 320],
        >>>                "max_shape" : [1, 3, 640, 640]}},
        >>>     log_level=trt.Logger.WARNING,
        >>>     fp16_mode=True,
        >>>     max_workspace_size=1 << 30,
@ -62,7 +69,7 @@ def create_trt_engine(onnx_model,
        error_msgs = ''
        for error in range(parser.num_errors):
            error_msgs += f'{parser.get_error(error)}\n'
        raise RuntimeError(f'parse onnx failed:\n{error_msgs}')
        raise RuntimeError(f'Failed to parse onnx: {onnx_model}\n{error_msgs}')

    # config builder
    if version.parse(trt.__version__) < version.parse('8'):

@ -72,7 +79,7 @@ def create_trt_engine(onnx_model,
    config.max_workspace_size = max_workspace_size
    profile = builder.create_optimization_profile()

    for input_name, param in opt_shape_dict.items():
    for input_name, param in input_shapes.items():
        min_shape = param['min_shape']
        opt_shape = param['opt_shape']
        max_shape = param['max_shape']

@ -88,7 +95,7 @@ def create_trt_engine(onnx_model,
        assert int8_param is not None
        config.int8_calibrator = HDF5Calibrator(
            int8_param['calib_file'],
            opt_shape_dict,
            input_shapes,
            model_type=int8_param['model_type'],
            device_id=device_id,
            algorithm=int8_param.get(
@ -100,29 +107,29 @@ def create_trt_engine(onnx_model,
    # create engine
    with torch.cuda.device(device):
        engine = builder.build_engine(network, config)

    assert engine is not None, f'Failed to create engine from {onnx_model}'
    return engine


def save_trt_engine(engine, path):
def save_trt_engine(engine: trt.ICudaEngine, path: str):
    """Serialize TensorRT engine to disk.

    Arguments:
        engine (tensorrt.ICudaEngine): TensorRT engine to serialize
        path (str): disk path to write the engine
    Args:
        engine (tensorrt.ICudaEngine): TensorRT engine to be serialized.
        path (str): The disk path to write the engine.
    """
    with open(path, mode='wb') as f:
        f.write(bytearray(engine.serialize()))


def load_trt_engine(path):
def load_trt_engine(path: str):
    """Deserialize TensorRT engine from disk.

    Arguments:
        path (str): disk path to read the engine
    Args:
        path (str): The disk path to read the engine.

    Returns:
        tensorrt.ICudaEngine: the TensorRT engine loaded from disk
        tensorrt.ICudaEngine: The TensorRT engine loaded from disk.
    """
    with trt.Logger() as logger, trt.Runtime(logger) as runtime:
        with open(path, mode='rb') as f:
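For reference, a minimal round-trip sketch combining the three functions above (not part of this diff; file names and shapes are placeholders):

    >>> engine = create_trt_engine(
    >>>     'end2end.onnx',
    >>>     {'input': dict(min_shape=[1, 3, 224, 224],
    >>>                    opt_shape=[1, 3, 224, 224],
    >>>                    max_shape=[1, 3, 224, 224])},
    >>>     fp16_mode=True, max_workspace_size=1 << 30, device_id=0)
    >>> save_trt_engine(engine, 'end2end.engine')
    >>> engine = load_trt_engine('end2end.engine')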
@ -131,8 +138,16 @@ def load_trt_engine(path):
    return engine


def torch_dtype_from_trt(dtype):
    """Convert pytorch dtype to TensorRT dtype."""
def torch_dtype_from_trt(dtype: trt.DataType):
    """Convert pytorch dtype to TensorRT dtype.

    Args:
        dtype (trt.DataType): The data type in tensorrt.

    Returns:
        torch.dtype: The corresponding data type in torch.
    """

    if dtype == trt.bool:
        return torch.bool
    elif dtype == trt.int8:
@ -144,38 +159,54 @@ def torch_dtype_from_trt(dtype):
    elif dtype == trt.float32:
        return torch.float32
    else:
        raise TypeError('%s is not supported by torch' % dtype)
        raise TypeError(f'{dtype} is not supported by torch')


def torch_device_from_trt(device):
    """Convert pytorch device to TensorRT device."""
def torch_device_from_trt(device: trt.TensorLocation):
    """Convert pytorch device to TensorRT device.

    Args:
        device (trt.TensorLocation): The device in tensorrt.

    Returns:
        torch.device: The corresponding device in torch.
    """
    if device == trt.TensorLocation.DEVICE:
        return torch.device('cuda')
    elif device == trt.TensorLocation.HOST:
        return torch.device('cpu')
    else:
        return TypeError('%s is not supported by torch' % device)
        return TypeError(f'{device} is not supported by torch')


class TRTWrapper(torch.nn.Module):
    """TensorRT engine Wrapper.
    """TensorRT engine wrapper for inference.

    Arguments:
        engine (tensorrt.ICudaEngine): TensorRT engine to wrap
    Args:
        engine (tensorrt.ICudaEngine): TensorRT engine to wrap.

    Note:
        If the engine is converted from onnx model. The input_names and
        output_names should be the same as onnx model.

    Examples:
        >>> from mmdeploy.apis.tensorrt import TRTWrapper
        >>> engine_file = 'resnet.engine'
        >>> model = TRTWrapper(engine_file)
        >>> inputs = dict(input=torch.randn(1, 3, 224, 224))
        >>> outputs = model(inputs)
        >>> print(outputs)
    """

    def __init__(self, engine):
    def __init__(self, engine: Union[str, trt.ICudaEngine]):
        super(TRTWrapper, self).__init__()
        self.engine = engine
        if isinstance(self.engine, str):
            self.engine = load_trt_engine(engine)

        if not isinstance(self.engine, trt.ICudaEngine):
            raise TypeError('engine should be str or trt.ICudaEngine')
            raise TypeError(f'`engine` should be str or trt.ICudaEngine, \
                but given: {type(self.engine)}')

        self._register_state_dict_hook(TRTWrapper._on_state_dict)
        self.context = self.engine.create_execution_context()
@ -206,13 +237,14 @@ class TRTWrapper(torch.nn.Module):
        self.input_names = state_dict[prefix + 'input_names']
        self.output_names = state_dict[prefix + 'output_names']

    def forward(self, inputs):
        """
        Arguments:
            inputs (dict): dict of input name-tensors pair
    def forward(self, inputs: Dict[str, torch.Tensor]):
        """Run forward inference.

        Args:
            inputs (Dict[str, torch.Tensor]): The input name and tensor pairs.

        Return:
            dict: dict of output name-tensors pair
            Dict[str, torch.Tensor]: The output name and tensor pairs.
        """
        assert self.input_names is not None
        assert self.output_names is not None
@ -243,6 +275,11 @@ class TRTWrapper(torch.nn.Module):
        return outputs

    @TimeCounter.count_time()
    def trt_execute(self, bindings):
    def trt_execute(self, bindings: Sequence[int]):
        """Run inference with TensorRT.

        Args:
            bindings (list[int]): A list of integer binding the input/output.
        """
        self.context.execute_async_v2(bindings,
                                      torch.cuda.current_stream().cuda_stream)
@ -1,20 +1,35 @@
import warnings
from typing import Any
from typing import Optional

import mmcv
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset

from mmdeploy.utils import Codebase


def single_gpu_test(codebase: str,
def single_gpu_test(codebase: Codebase,
                    model: nn.Module,
                    data_loader: DataLoader,
                    show: bool = False,
                    out_dir: Any = None,
                    out_dir: Optional[str] = None,
                    show_score_thr: float = 0.3):
    """Run test with single gpu.

    Args:
        codebase (Codebase): Specifying codebase type.
        model (torch.nn.Module): Input model from nn.Module.
        data_loader (DataLoader): PyTorch data loader.
        show (bool): Specifying whether to show plotted results. Defaults
            to `False`.
        out_dir (str): A directory to save results, defaults to `None`.
        show_score_thr (float): A threshold to show detection results,
            defaults to `0.3`.

    Returns:
        list: The prediction results.
    """
    if codebase == Codebase.MMCLS:
        from mmcls.apis import single_gpu_test
        outputs = single_gpu_test(model, data_loader, show, out_dir)
@ -36,14 +51,35 @@ def single_gpu_test(codebase: str,
    return outputs


def post_process_outputs(outputs,
                         dataset,
def post_process_outputs(outputs: list,
                         dataset: Dataset,
                         model_cfg: mmcv.Config,
                         codebase: str,
                         metrics: str = None,
                         out: str = None,
                         metric_options: dict = None,
                         codebase: Codebase,
                         metrics: Optional[str] = None,
                         out: Optional[str] = None,
                         metric_options: Optional[dict] = None,
                         format_only: bool = False):
    """Perform post-processing to predictions of model.

    Args:
        outputs (list): A list of predictions of model inference.
        dataset (Dataset): Input dataset to run test.
        model_cfg (mmcv.Config): The model config.
        codebase (Codebase): Specifying codebase type.
        metrics (str): Evaluation metrics, which depends on
            the codebase and the dataset, e.g., "bbox", "segm", "proposal"
            for COCO, and "mAP", "recall" for PASCAL VOC in mmdet; "accuracy",
            "precision", "recall", "f1_score", "support" for single label
            dataset, and "mAP", "CP", "CR", "CF1", "OP", "OR", "OF1" for
            multi-label dataset in mmcls. Defaults is `None`.
        out (str): Output result file in pickle format, defaults to `None`.
        metric_options (dict): Custom options for evaluation, will be kwargs
            for dataset.evaluate() function. Defaults to `None`.
        format_only (bool): Format the output results without perform
            evaluation. It is useful when you want to format the result
            to a specific format and submit it to the test server. Defaults
            to `False`.
    """
    if codebase == Codebase.MMCLS:
        if metrics:
            results = dataset.evaluate(outputs, metrics, metric_options)
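For reference, a minimal evaluation sketch combining the two functions above (not part of this diff; `model`, `data_loader`, `dataset` and `model_cfg` are assumed to be built beforehand, e.g. via init_backend_model and build_dataloader):

    >>> from mmdeploy.utils import Codebase
    >>> outputs = single_gpu_test(Codebase.MMDET, model, data_loader, show=False)
    >>> post_process_outputs(outputs, dataset, model_cfg, Codebase.MMDET,
    >>>                      metrics='bbox')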
@ -121,7 +157,7 @@ def post_process_outputs(outputs,
        print(f'\nwriting results to {out}')
        mmcv.dump(outputs, out)
    # The Dataset doesn't need metrics
    print('')
    print('\n')
    # print metrics
    stats = dataset.evaluate(outputs)
    for stat in stats:

@ -2,6 +2,8 @@ from typing import Any, Dict, Optional, Sequence, Union

import mmcv
import numpy as np
import torch
from torch.utils.data import Dataset

from mmdeploy.utils import Backend, Codebase, Task, get_codebase, load_config
@ -11,6 +13,19 @@ def init_pytorch_model(codebase: Codebase,
                       model_checkpoint: Optional[str] = None,
                       device: str = 'cuda:0',
                       cfg_options: Optional[Dict] = None):
    """Initialize torch model.

    Args:
        codebase (Codebase): Specifying codebase type.
        model_cfg (str | mmcv.Config): Model config file or Config object.
        model_checkpoint (str): The checkpoint file of torch model, defaults
            to `None`.
        device (str): A string specifying device type, defaults to 'cuda:0'.
        cfg_options (dict): Optional config key-pair parameters.

    Returns:
        nn.Module: An initialized torch model.
    """
    if codebase == Codebase.MMCLS:
        from mmcls.apis import init_model
        model = init_model(model_cfg, model_checkpoint, device, cfg_options)
@ -46,6 +61,22 @@ def create_input(codebase: Codebase,
                 input_shape: Sequence[int] = None,
                 device: str = 'cuda:0',
                 **kwargs):
    """Create input for model.

    Args:
        codebase (Codebase): Specifying codebase type.
        task (Task): Specifying task type.
        model_cfg (str | mmcv.Config): model config file or loaded Config
            object.
        imgs (str | np.ndarray): Input image(s).
        input_shape (list[int]): Input shape of image in (width, height)
            format, defaults to `None`.
        device (str): A string specifying device type, defaults to 'cuda:0'.

    Returns:
        tuple: (data, img), meta information for the input image and input
            image tensor.
    """
    model_cfg = load_config(model_cfg)[0]

    cfg = model_cfg.copy()
@ -78,6 +109,19 @@ def init_backend_model(model_files: Sequence[str],
                       deploy_cfg: Union[str, mmcv.Config],
                       device_id: int = 0,
                       **kwargs):
    """Initialize backend model.

    Args:
        model_files (list[str]): Input model files.
        model_cfg (str | mmcv.Config): Model config file or
            loaded Config object.
        deploy_cfg (str | mmcv.Config): Deployment config file or
            loaded Config object.
        device_id (int): An integer specifying device index.

    Returns:
        nn.Module: An initialized model.
    """
    deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)

    codebase = get_codebase(deploy_cfg)
@ -111,7 +155,19 @@ def init_backend_model(model_files: Sequence[str],
    raise NotImplementedError(f'Unknown codebase type: {codebase.value}')


def run_inference(codebase: Codebase, model_inputs, model):
def run_inference(codebase: Codebase, model_inputs: dict,
                  model: torch.nn.Module):
    """Run once inference for a model of nn.Module.

    Args:
        codebase (Codebase): Specifying codebase type.
        model_inputs (dict): A dict containing model inputs tensor and
            meta info.
        model (nn.Module): Input model.

    Returns:
        list: The predictions of model inference.
    """
    if codebase == Codebase.MMCLS:
        return model(**model_inputs, return_loss=False)[0]
    elif codebase == Codebase.MMDET:
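For reference, a minimal sketch chaining init_pytorch_model, create_input and run_inference (not part of this diff; `model_cfg` is an assumed loaded config, the checkpoint path is a placeholder, and `task` stands for a mmdeploy.utils.Task member whose names are not shown in this diff):

    >>> from mmdeploy.utils import Codebase, Task
    >>> task = ...  # a Task member matching the model, e.g. classification
    >>> model = init_pytorch_model(Codebase.MMCLS, model_cfg, 'checkpoint.pth',
    >>>                            device='cuda:0')
    >>> model_inputs, _ = create_input(Codebase.MMCLS, task, model_cfg,
    >>>                                'demo.jpg', device='cuda:0')
    >>> results = run_inference(Codebase.MMCLS, model_inputs, model)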
@ -133,12 +189,23 @@ def run_inference(codebase: Codebase, model_inputs, model):

def visualize(codebase: Codebase,
              image: Union[str, np.ndarray],
              result,
              model,
              result: list,
              model: torch.nn.Module,
              output_file: str,
              backend: Backend,
              show_result=False):
              show_result: bool = False):
    """Visualize predictions of a model.

    Args:
        codebase (Codebase): Specifying codebase type.
        image (str | np.ndarray): Input image to draw predictions on.
        result (list): A list of predictions.
        model (nn.Module): Input model.
        output_file (str): Output file to save drawn image.
        backend (Backend): Specifying backend type.
        show_result (bool): Whether to show result in windows, defaults
            to `False`.
    """
    show_img = mmcv.imread(image) if isinstance(image, str) else image
    output_file = None if show_result else output_file
@ -164,6 +231,18 @@ def visualize(codebase: Codebase,


def get_partition_cfg(codebase: Codebase, partition_type: str):
    """Get a certain partition config.

    Notes:
        Currently only support mmdet codebase.

    Args:
        codebase (Codebase): Specifying codebase type.
        partition_type (str): A string specifying partition type.

    Returns:
        dict: A dictionary of partition config.
    """
    if codebase == Codebase.MMDET:
        from mmdeploy.mmdet.export import get_partition_cfg \
            as get_partition_cfg_mmdet
@ -176,6 +255,17 @@ def build_dataset(codebase: Codebase,
                  dataset_cfg: Union[str, mmcv.Config],
                  dataset_type: str = 'val',
                  **kwargs):
    """Build dataset for different codebase.

    Args:
        codebase (Codebase): Specifying codebase type.
        dataset_cfg (str | mmcv.Config): Dataset config file or Config object.
        dataset_type (str): Specifying dataset type, e.g.: 'train', 'test',
            'val', defaults to 'val'.

    Returns:
        Dataset: The built dataset.
    """
    if codebase == Codebase.MMCLS:
        from mmdeploy.mmcls.export import build_dataset \
            as build_dataset_mmcls
@ -198,8 +288,21 @@ def build_dataset(codebase: Codebase,
    raise NotImplementedError(f'Unknown codebase type: {codebase.value}')


def build_dataloader(codebase: Codebase, dataset, samples_per_gpu: int,
                     workers_per_gpu: int, **kwargs):
def build_dataloader(codebase: Codebase, dataset: Dataset,
                     samples_per_gpu: int, workers_per_gpu: int, **kwargs):
    """Build PyTorch dataloader.

    Args:
        codebase (Codebase): Specifying codebase type.
        dataset (Dataset): A PyTorch dataset.
        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
            batch size of each GPU.
        workers_per_gpu (int): How many subprocesses to use for data loading
            for each GPU.

    Returns:
        DataLoader: A PyTorch dataloader.
    """
    if codebase == Codebase.MMCLS:
        from mmdeploy.mmcls.export import build_dataloader \
            as build_dataloader_mmcls
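For reference, a minimal sketch combining build_dataset and build_dataloader (not part of this diff; `model_cfg` is an assumed loaded mmcv.Config whose data section defines the 'val' split):

    >>> dataset = build_dataset(Codebase.MMCLS, dataset_cfg=model_cfg,
    >>>                         dataset_type='val')
    >>> data_loader = build_dataloader(Codebase.MMCLS, dataset,
    >>>                                samples_per_gpu=1, workers_per_gpu=2)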
@ -229,7 +332,16 @@ def build_dataloader(codebase: Codebase, dataset, samples_per_gpu: int,
    raise NotImplementedError(f'Unknown codebase type: {codebase.value}')


def get_tensor_from_input(codebase: Codebase, input_data):
def get_tensor_from_input(codebase: Codebase, input_data: tuple):
    """Get input tensor from input data.

    Args:
        codebase (Codebase): Specifying codebase type.
        input_data (tuple): Input data containing meta info and image tensor.

    Returns:
        torch.Tensor: Input tensor of image.
    """
    if codebase == Codebase.MMCLS:
        from mmdeploy.mmcls.export import get_tensor_from_input \
            as get_tensor_from_input_mmcls
@ -42,7 +42,15 @@ def _dfs_search_reacable_nodes_fast(self, node_output_name, graph_input_nodes,
    impl(node_output_name, graph_input_nodes, reachable_nodes)


def create_extractor(model):
def create_extractor(model: onnx.ModelProto):
    """Create Extractor for ONNX.

    Args:
        model (onnx.ModelProto): An input onnx model.

    Returns:
        Extractor: Extractor for the onnx.
    """
    assert version.parse(onnx.__version__) >= version.parse('1.8.0')
    # patch extractor
    onnx.utils.Extractor._dfs_search_reachable_nodes = \
@ -91,6 +91,9 @@ def main():

    # load deploy_cfg
    deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)
    # merge options for model cfg
    if args.cfg_options is not None:
        model_cfg.merge_from_dict(args.cfg_options)

    # prepare the dataset loader
    codebase = get_codebase(deploy_cfg)