mirror of https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00

add tensorrt support (#2)

parent 6eb2e89016
commit 6c47ee3d2a
@@ -1,2 +1,2 @@
 [settings]
-known_third_party = mmcv,mmdet,numpy,setuptools,torch
+known_third_party = mmcv,mmdet,numpy,setuptools,tensorrt,torch
@@ -1 +1,7 @@
+import tensorrt as trt
+
 backend = 'tensorrt'
+tensorrt_param = dict(
+    log_level=trt.Logger.WARNING,
+    fp16_mode=False,
+    save_file='onnx2tensorrt.engine')
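For reference, a deploy config that inherits this backend file can be loaded and tweaked programmatically before conversion; a minimal sketch, assuming a hypothetical config path:

import mmcv

# hypothetical config whose _base_ includes ../_base_/backends/tensorrt.py
cfg = mmcv.Config.fromfile('configs/mmcls/resnet_tensorrt.py')
cfg.tensorrt_param['fp16_mode'] = True  # enable fp16 for this run only
print(cfg.backend, cfg.tensorrt_param['save_file'])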
@@ -1,2 +1,5 @@
 pytorch2onnx = dict(
-    export_params=True, keep_initializers_as_inputs=False, opset_version=11)
+    export_params=True,
+    keep_initializers_as_inputs=False,
+    opset_version=11,
+    save_file='torch2onnx.onnx')
@@ -1 +1,5 @@
 _base_ = ['./mmcls_base.py', '../_base_/backends/tensorrt.py']
+tensorrt_param = dict(
+    opt_shape_dict=dict(
+        input=[[1, 3, 224, 224], [4, 3, 224, 224], [32, 3, 224, 224]]),
+    max_workspace_size=1 << 30)
@@ -1 +1,5 @@
 _base_ = ['./base.py', '../_base_/backends/tensorrt.py']
+tensorrt_param = dict(
+    opt_shape_dict=dict(
+        input=[[1, 3, 320, 320], [1, 3, 800, 1344], [1, 3, 1344, 1344]]),
+    max_workspace_size=1 << 30)
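In these configs, the three shape lists per input are the min/opt/max shapes fed to the TensorRT optimization profile (see onnx2trt below). A hypothetical variant for a 512x512 detector input:

tensorrt_param = dict(
    opt_shape_dict=dict(
        input=[[1, 3, 256, 256],  # smallest shape the engine must accept
               [1, 3, 512, 512],  # shape TensorRT tunes tactics for
               [1, 3, 1024, 1024]]),  # largest shape the engine must accept
    max_workspace_size=1 << 30)  # 1 GiB of builder workspace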
@@ -11,7 +11,8 @@ from .utils import create_input, init_model
 
 
 def torch2onnx(img: Any,
-               work_dir: Optional[str],
+               work_dir: str,
+               save_file: str,
                deploy_cfg: Union[str, mmcv.Config],
                model_cfg: Union[str, mmcv.Config],
                model_checkpoint: Optional[str] = None,
@@ -22,18 +23,18 @@ def torch2onnx(img: Any,
     # load deploy_cfg if needed
     if isinstance(deploy_cfg, str):
         deploy_cfg = mmcv.Config.fromfile(deploy_cfg)
-    elif not isinstance(deploy_cfg, mmcv.Config):
+    if not isinstance(deploy_cfg, mmcv.Config):
         raise TypeError('deploy_cfg must be a filename or Config object, '
                         f'but got {type(deploy_cfg)}')
     # load model_cfg if needed
     if isinstance(model_cfg, str):
         model_cfg = mmcv.Config.fromfile(model_cfg)
-    elif not isinstance(model_cfg, mmcv.Config):
+    if not isinstance(model_cfg, mmcv.Config):
         raise TypeError('config must be a filename or Config object, '
                         f'but got {type(model_cfg)}')
 
     mmcv.mkdir_or_exist(osp.abspath(work_dir))
-    output_file = osp.join(work_dir, 'torch2onnx.onnx')
+    output_file = osp.join(work_dir, save_file)
 
     pytorch2onnx_cfg = deploy_cfg['pytorch2onnx']
     codebase = deploy_cfg['codebase']
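With the new save_file argument, a direct call looks roughly like the sketch below; the paths are hypothetical, and the device keyword is inferred from the deploy tool's call further down:

torch2onnx(
    'demo.jpg',
    work_dir='work_dir',
    save_file='torch2onnx.onnx',  # matches pytorch2onnx.save_file above
    deploy_cfg='configs/mmcls/resnet_tensorrt.py',
    model_cfg='resnet18.py',
    model_checkpoint='resnet18.pth',
    device='cuda:0')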
mmdeploy/apis/tensorrt/__init__.py (new file, 14 lines)
@@ -0,0 +1,14 @@
+# flake8: noqa
+from .init_plugins import load_tensorrt_plugin
+from .onnx2tensorrt import onnx2tensorrt
+from .tensorrt_utils import (TRTWrapper, load_trt_engine, onnx2trt,
+                             save_trt_engine)
+
+# load tensorrt plugin lib
+load_tensorrt_plugin()
+
+__all__ = [
+    'onnx2trt', 'save_trt_engine', 'load_trt_engine', 'TRTWrapper',
+    'onnx2tensorrt'
+]
mmdeploy/apis/tensorrt/init_plugins.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+import ctypes
+import glob
+import logging
+import os
+
+
+def get_tensorrt_op_path():
+    """Get TensorRT plugins library path."""
+    wildcard = os.path.abspath(
+        os.path.join(
+            os.path.dirname(__file__),
+            '../../../build/lib/libmmlab_tensorrt_ops.so'))
+
+    paths = glob.glob(wildcard)
+    lib_path = paths[0] if len(paths) > 0 else ''
+    return lib_path
+
+
+def load_tensorrt_plugin():
+    """Load TensorRT plugins library."""
+    lib_path = get_tensorrt_op_path()
+    if os.path.exists(lib_path):
+        ctypes.CDLL(lib_path)
+        return 0
+    else:
+        logging.warning('Cannot load TensorRT custom ops.')
+        return -1
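Since load_tensorrt_plugin() only logs a warning on failure, callers that require the custom ops can check its return code; a minimal sketch:

from mmdeploy.apis.tensorrt import load_tensorrt_plugin

if load_tensorrt_plugin() != 0:
    # libmmlab_tensorrt_ops.so was not found under build/lib
    raise RuntimeError('TensorRT custom ops are required but not built.')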
mmdeploy/apis/tensorrt/onnx2tensorrt.py (new file, 49 lines)
@@ -0,0 +1,49 @@
+import os.path as osp
+from typing import Optional, Union
+
+import tensorrt as trt
+
+import mmcv
+import onnx
+import torch.multiprocessing as mp
+
+from .tensorrt_utils import onnx2trt, save_trt_engine
+
+
+def onnx2tensorrt(work_dir: str,
+                  save_file: str,
+                  deploy_cfg: Union[str, mmcv.Config],
+                  onnx_model: Union[str, onnx.ModelProto],
+                  device: str = 'cuda:0',
+                  ret_value: Optional[mp.Value] = None):
+    if ret_value is not None:
+        ret_value.value = -1
+
+    # load deploy_cfg if needed
+    if isinstance(deploy_cfg, str):
+        deploy_cfg = mmcv.Config.fromfile(deploy_cfg)
+    elif not isinstance(deploy_cfg, mmcv.Config):
+        raise TypeError('deploy_cfg must be a filename or Config object, '
+                        f'but got {type(deploy_cfg)}')
+
+    mmcv.mkdir_or_exist(osp.abspath(work_dir))
+
+    assert 'tensorrt_param' in deploy_cfg
+
+    tensorrt_param = deploy_cfg['tensorrt_param']
+
+    assert device.startswith('cuda')
+    device_id = 0
+    if len(device) >= 6:
+        device_id = int(device[5:])
+    engine = onnx2trt(
+        onnx_model,
+        opt_shape_dict=tensorrt_param['opt_shape_dict'],
+        log_level=tensorrt_param.get('log_level', trt.Logger.WARNING),
+        fp16_mode=tensorrt_param.get('fp16_mode', False),
+        max_workspace_size=tensorrt_param.get('max_workspace_size', 0),
+        device_id=device_id)
+
+    save_trt_engine(engine, osp.join(work_dir, save_file))
+
+    if ret_value is not None:
+        ret_value.value = 0
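Because the deploy tool runs this conversion in a child process, status comes back through a shared value; a minimal sketch of that pattern, assuming hypothetical paths:

import torch.multiprocessing as mp

from mmdeploy.apis.tensorrt import onnx2tensorrt

if __name__ == '__main__':
    mp.set_start_method('spawn')
    ret_value = mp.Value('d', 0, lock=False)
    process = mp.Process(
        target=onnx2tensorrt,
        args=('work_dir', 'onnx2tensorrt.engine',
              'configs/mmcls/resnet_tensorrt.py',  # hypothetical config
              'work_dir/torch2onnx.onnx'),
        kwargs=dict(device='cuda:0', ret_value=ret_value))
    process.start()
    process.join()
    assert ret_value.value == 0, 'onnx2tensorrt failed'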
mmdeploy/apis/tensorrt/tensorrt_utils.py (new file, 222 lines)
@@ -0,0 +1,222 @@
+import onnx
+import tensorrt as trt
+import torch
+
+
+def onnx2trt(onnx_model,
+             opt_shape_dict,
+             log_level=trt.Logger.ERROR,
+             fp16_mode=False,
+             max_workspace_size=0,
+             device_id=0):
+    """Convert an ONNX model to a TensorRT engine.
+
+    Arguments:
+        onnx_model (str or onnx.ModelProto): the onnx model to convert from
+        opt_shape_dict (dict): the min/opt/max shape of each input
+        log_level (TensorRT log level): the log level of TensorRT
+        fp16_mode (bool): enable fp16 mode
+        max_workspace_size (int): set max workspace size of TensorRT engine.
+            some tactics and layers need a large workspace.
+        device_id (int): the device on which to create the engine.
+
+    Returns:
+        tensorrt.ICudaEngine: the TensorRT engine created from onnx_model
+
+    Example:
+        >>> engine = onnx2trt(
+        >>>     "onnx_model.onnx",
+        >>>     {'input': [[1, 3, 160, 160],
+        >>>                [1, 3, 320, 320],
+        >>>                [1, 3, 640, 640]]},
+        >>>     log_level=trt.Logger.WARNING,
+        >>>     fp16_mode=True,
+        >>>     max_workspace_size=1 << 30,
+        >>>     device_id=0)
+    """
+    device = torch.device('cuda:{}'.format(device_id))
+    # create builder and network
+    logger = trt.Logger(log_level)
+    builder = trt.Builder(logger)
+    EXPLICIT_BATCH = 1 << (int)(
+        trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    network = builder.create_network(EXPLICIT_BATCH)
+
+    # parse onnx
+    parser = trt.OnnxParser(network, logger)
+
+    if isinstance(onnx_model, str):
+        onnx_model = onnx.load(onnx_model)
+
+    if not parser.parse(onnx_model.SerializeToString()):
+        error_msgs = ''
+        for error in range(parser.num_errors):
+            error_msgs += f'{parser.get_error(error)}\n'
+        raise RuntimeError(f'parse onnx failed:\n{error_msgs}')
+
+    # config builder
+    builder.max_workspace_size = max_workspace_size
+
+    config = builder.create_builder_config()
+    config.max_workspace_size = max_workspace_size
+    profile = builder.create_optimization_profile()
+
+    for input_name, param in opt_shape_dict.items():
+        min_shape = tuple(param[0][:])
+        opt_shape = tuple(param[1][:])
+        max_shape = tuple(param[2][:])
+        profile.set_shape(input_name, min_shape, opt_shape, max_shape)
+    config.add_optimization_profile(profile)
+
+    if fp16_mode:
+        builder.fp16_mode = fp16_mode
+        config.set_flag(trt.BuilderFlag.FP16)
+
+    # create engine
+    with torch.cuda.device(device):
+        engine = builder.build_engine(network, config)
+
+    return engine
+
+
+def save_trt_engine(engine, path):
+    """Serialize TensorRT engine to disk.
+
+    Arguments:
+        engine (tensorrt.ICudaEngine): TensorRT engine to serialize
+        path (str): disk path to write the engine
+    """
+    with open(path, mode='wb') as f:
+        f.write(bytearray(engine.serialize()))
+
+
+def load_trt_engine(path):
+    """Deserialize TensorRT engine from disk.
+
+    Arguments:
+        path (str): disk path to read the engine
+
+    Returns:
+        tensorrt.ICudaEngine: the TensorRT engine loaded from disk
+    """
+    with trt.Logger() as logger, trt.Runtime(logger) as runtime:
+        with open(path, mode='rb') as f:
+            engine_bytes = f.read()
+        engine = runtime.deserialize_cuda_engine(engine_bytes)
+        return engine
+
+
+def torch_dtype_from_trt(dtype):
+    """Convert a TensorRT dtype to a PyTorch dtype."""
+    if dtype == trt.bool:
+        return torch.bool
+    elif dtype == trt.int8:
+        return torch.int8
+    elif dtype == trt.int32:
+        return torch.int32
+    elif dtype == trt.float16:
+        return torch.float16
+    elif dtype == trt.float32:
+        return torch.float32
+    else:
+        raise TypeError('%s is not supported by torch' % dtype)
+
+
+def torch_device_from_trt(device):
+    """Convert a TensorRT device to a PyTorch device."""
+    if device == trt.TensorLocation.DEVICE:
+        return torch.device('cuda')
+    elif device == trt.TensorLocation.HOST:
+        return torch.device('cpu')
+    else:
+        raise TypeError('%s is not supported by torch' % device)
+
+
+class TRTWrapper(torch.nn.Module):
+    """TensorRT engine wrapper.
+
+    Arguments:
+        engine (tensorrt.ICudaEngine or str): TensorRT engine to wrap, or
+            the path of a serialized engine to load.
+
+    Note:
+        If the engine is converted from an ONNX model, the input and output
+        names are the same as in the ONNX model.
+    """
+
+    def __init__(self, engine):
+        super(TRTWrapper, self).__init__()
+        self.engine = engine
+        if isinstance(self.engine, str):
+            self.engine = load_trt_engine(engine)
+
+        if not isinstance(self.engine, trt.ICudaEngine):
+            raise TypeError('engine should be str or trt.ICudaEngine')
+
+        self._register_state_dict_hook(TRTWrapper._on_state_dict)
+        self.context = self.engine.create_execution_context()
+
+        self._load_io_names()
+
+    def _load_io_names(self):
+        # get input and output names from engine
+        names = [_ for _ in self.engine]
+        input_names = list(filter(self.engine.binding_is_input, names))
+        output_names = list(set(names) - set(input_names))
+        self.input_names = input_names
+        self.output_names = output_names
+
+    def _on_state_dict(self, state_dict, prefix, local_metadata):
+        state_dict[prefix + 'engine'] = bytearray(self.engine.serialize())
+        state_dict[prefix + 'input_names'] = self.input_names
+        state_dict[prefix + 'output_names'] = self.output_names
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        engine_bytes = state_dict[prefix + 'engine']
+
+        with trt.Logger() as logger, trt.Runtime(logger) as runtime:
+            self.engine = runtime.deserialize_cuda_engine(engine_bytes)
+            self.context = self.engine.create_execution_context()
+
+        self.input_names = state_dict[prefix + 'input_names']
+        self.output_names = state_dict[prefix + 'output_names']
+
+    def forward(self, inputs):
+        """Run inference.
+
+        Arguments:
+            inputs (dict): dict of input name-tensor pairs
+
+        Returns:
+            dict: dict of output name-tensor pairs
+        """
+        assert self.input_names is not None
+        assert self.output_names is not None
+        bindings = [None] * (len(self.input_names) + len(self.output_names))
+
+        for input_name, input_tensor in inputs.items():
+            idx = self.engine.get_binding_index(input_name)
+
+            if input_tensor.dtype == torch.long:
+                input_tensor = input_tensor.int()
+            self.context.set_binding_shape(idx, tuple(input_tensor.shape))
+            bindings[idx] = input_tensor.contiguous().data_ptr()
+
+        # create output tensors
+        outputs = {}
+        for i, output_name in enumerate(self.output_names):
+            idx = self.engine.get_binding_index(output_name)
+            dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
+            shape = tuple(self.context.get_binding_shape(idx))
+
+            device = torch_device_from_trt(self.engine.get_location(idx))
+            output = torch.empty(size=shape, dtype=dtype, device=device)
+            outputs[output_name] = output
+            bindings[idx] = output.data_ptr()
+
+        self.context.execute_async_v2(bindings,
+                                      torch.cuda.current_stream().cuda_stream)
+
+        return outputs
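A minimal end-to-end sketch with the helpers above: build an engine, round trip it through disk, then run inference through the wrapper. The ONNX path is hypothetical; the input name 'input' matches the opt_shape_dict keys used in the configs:

import torch

engine = onnx2trt(
    'model.onnx',
    {'input': [[1, 3, 224, 224], [4, 3, 224, 224], [32, 3, 224, 224]]},
    max_workspace_size=1 << 30)
save_trt_engine(engine, 'model.engine')

model = TRTWrapper('model.engine')  # load_trt_engine() is called internally
with torch.no_grad():
    outputs = model(dict(input=torch.randn(1, 3, 224, 224).cuda()))
for name, tensor in outputs.items():
    print(name, tensor.shape)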
@@ -45,3 +45,13 @@ def rewrite_topk_tensorrt(rewriter,
     k = int(k)
     return rewriter.origin_func(
         input, k, dim=dim, largest=largest, sorted=sorted)
+
+
+@FUNCTION_REWRITERS.register_rewriter(
+    func_name='torch.Tensor.repeat', backend='tensorrt')
+def rewrite_repeat_tensorrt(rewriter, input, *size):
+    origin_func = rewriter.origin_func
+    if input.dim() == 1 and len(size) == 1:
+        return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
+    else:
+        return origin_func(input, *size)
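The repeat rewriter above works around 1-D Tensor.repeat on the TensorRT backend by lifting the tensor to 2-D, tiling, and squeezing back. The rewrite is an exact equivalence in plain PyTorch:

import torch

x = torch.arange(3)
# what the rewriter emits for a 1-D tensor and a single repeat factor
assert torch.equal(x.unsqueeze(0).repeat(1, 2).squeeze(0), x.repeat(2))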
@@ -66,9 +66,11 @@ def nms_tensorrt(symbolic_wrapper, g, boxes, scores,
     score_threshold = sym_help._maybe_get_const(score_threshold, 'f')
 
     return g.op(
-        'NonMaxSuppression',
+        'mmcv::NonMaxSuppression',
         boxes,
         scores,
         max_output_boxes_per_class_i=max_output_boxes_per_class,
         iou_threshold_f=iou_threshold,
-        score_threshold_f=score_threshold)
+        score_threshold_f=score_threshold,
+        center_point_box_i=0,
+        offset_i=0)
@@ -19,6 +19,11 @@ def parse_args():
     parser.add_argument('--work-dir', help='the dir to save logs and models')
     parser.add_argument(
         '--device', help='device used for conversion', default='cpu')
+    parser.add_argument(
+        '--log-level',
+        help='set log level',
+        default='INFO',
+        choices=list(logging._nameToLevel.keys()))
     args = parser.parse_args()
 
     return args
@@ -28,20 +33,31 @@ def main():
     args = parse_args()
     set_start_method('spawn')
 
-    deploy_cfg = args.deploy_cfg
-    model_cfg = args.model_cfg
-    checkpoint = args.checkpoint
+    logger = logging.getLogger()
+    logger.setLevel(args.log_level)
+
+    deploy_cfg_path = args.deploy_cfg
+    model_cfg_path = args.model_cfg
+    checkpoint_path = args.checkpoint
+
+    # load deploy_cfg
+    deploy_cfg = mmcv.Config.fromfile(deploy_cfg_path)
+    if not isinstance(deploy_cfg, mmcv.Config):
+        raise TypeError('deploy_cfg must be a filename or Config object, '
+                        f'but got {type(deploy_cfg)}')
 
     # create work_dir if not
     mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
 
     ret_value = mp.Value('d', 0, lock=False)
 
-    # convert model
+    # convert onnx
     logging.info('start torch2onnx conversion.')
+    onnx_save_file = deploy_cfg['pytorch2onnx']['save_file']
     process = Process(
         target=torch2onnx,
-        args=(args.img, args.work_dir, deploy_cfg, model_cfg, checkpoint),
+        args=(args.img, args.work_dir, onnx_save_file, deploy_cfg_path,
+              model_cfg_path, checkpoint_path),
         kwargs=dict(device=args.device, ret_value=ret_value))
     process.start()
     process.join()
@@ -52,6 +68,29 @@ def main():
     else:
         logging.info('torch2onnx success.')
 
+    # convert backend
+    backend = deploy_cfg.get('backend', 'default')
+    onnx_paths = [osp.join(args.work_dir, onnx_save_file)]
+    if backend == 'tensorrt':
+        logging.info('start onnx2tensorrt conversion.')
+        from mmdeploy.apis.tensorrt import onnx2tensorrt
+        trt_save_file = deploy_cfg['tensorrt_param'].get(
+            'save_file', 'onnx2tensorrt.engine')
+        for onnx_path in onnx_paths:
+            process = Process(
+                target=onnx2tensorrt,
+                args=(args.work_dir, trt_save_file, deploy_cfg_path,
+                      onnx_path),
+                kwargs=dict(device=args.device, ret_value=ret_value))
+            process.start()
+            process.join()
+
+            if ret_value.value != 0:
+                logging.error('onnx2tensorrt failed.')
+                exit()
+            else:
+                logging.info('onnx2tensorrt success.')
+
     logging.info('All process success.')
 
 
 if __name__ == '__main__':
     main()