From 6c47ee3d2a7bb3e961041a695fcce336db5dada9 Mon Sep 17 00:00:00 2001
From: "q.yao"
Date: Wed, 23 Jun 2021 13:14:28 +0800
Subject: [PATCH] add tensorrt support (#2)

---
 .isort.cfg                               |   2 +-
 configs/_base_/backends/tensorrt.py      |   6 +
 configs/_base_/torch2onnx.py             |   5 +-
 configs/mmcls/mmcls_tensorrt.py          |   4 +
 configs/mmdet/tensorrt.py                |   4 +
 mmdeploy/apis/pytorch2onnx.py            |   9 +-
 mmdeploy/apis/tensorrt/__init__.py       |  14 ++
 mmdeploy/apis/tensorrt/init_plugins.py   |  27 +++
 mmdeploy/apis/tensorrt/onnx2tensorrt.py  |  49 +++++
 mmdeploy/apis/tensorrt/tensorrt_utils.py | 219 +++++++++++++++++++++++
 mmdeploy/mmcv/funcs.py                   |  10 +
 mmdeploy/mmcv/ops.py                     |   6 +-
 tools/deploy.py                          |  50 ++++-
 13 files changed, 392 insertions(+), 13 deletions(-)
 create mode 100644 mmdeploy/apis/tensorrt/__init__.py
 create mode 100644 mmdeploy/apis/tensorrt/init_plugins.py
 create mode 100644 mmdeploy/apis/tensorrt/onnx2tensorrt.py
 create mode 100644 mmdeploy/apis/tensorrt/tensorrt_utils.py

diff --git a/.isort.cfg b/.isort.cfg
index 90b502f8c..dc8b6bbc3 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -1,2 +1,2 @@
 [settings]
-known_third_party = mmcv,mmdet,numpy,setuptools,torch
+known_third_party = mmcv,mmdet,numpy,setuptools,tensorrt,torch
diff --git a/configs/_base_/backends/tensorrt.py b/configs/_base_/backends/tensorrt.py
index ef413236a..cb248b4e3 100644
--- a/configs/_base_/backends/tensorrt.py
+++ b/configs/_base_/backends/tensorrt.py
@@ -1 +1,7 @@
+import tensorrt as trt
+
 backend = 'tensorrt'
+tensorrt_param = dict(
+    log_level=trt.Logger.WARNING,
+    fp16_mode=False,
+    save_file='onnx2tensorrt.engine')
diff --git a/configs/_base_/torch2onnx.py b/configs/_base_/torch2onnx.py
index 2573de7d7..ff2a9c4dc 100644
--- a/configs/_base_/torch2onnx.py
+++ b/configs/_base_/torch2onnx.py
@@ -1,2 +1,5 @@
 pytorch2onnx = dict(
-    export_params=True, keep_initializers_as_inputs=False, opset_version=11)
+    export_params=True,
+    keep_initializers_as_inputs=False,
+    opset_version=11,
+    save_file='torch2onnx.onnx')
diff --git a/configs/mmcls/mmcls_tensorrt.py b/configs/mmcls/mmcls_tensorrt.py
index 0eea93a6d..76d054453 100644
--- a/configs/mmcls/mmcls_tensorrt.py
+++ b/configs/mmcls/mmcls_tensorrt.py
@@ -1 +1,5 @@
 _base_ = ['./mmcls_base.py', '../_base_/backends/tensorrt.py']
+tensorrt_param = dict(
+    opt_shape_dict=dict(
+        input=[[1, 3, 224, 224], [4, 3, 224, 224], [32, 3, 224, 224]]),
+    max_workspace_size=1 << 30)
diff --git a/configs/mmdet/tensorrt.py b/configs/mmdet/tensorrt.py
index 390a369e7..2fcad32a9 100644
--- a/configs/mmdet/tensorrt.py
+++ b/configs/mmdet/tensorrt.py
@@ -1 +1,5 @@
 _base_ = ['./base.py', '../_base_/backends/tensorrt.py']
+tensorrt_param = dict(
+    opt_shape_dict=dict(
+        input=[[1, 3, 320, 320], [1, 3, 800, 1344], [1, 3, 1344, 1344]]),
+    max_workspace_size=1 << 30)
diff --git a/mmdeploy/apis/pytorch2onnx.py b/mmdeploy/apis/pytorch2onnx.py
index 66f360286..5ab3c8fc9 100644
--- a/mmdeploy/apis/pytorch2onnx.py
+++ b/mmdeploy/apis/pytorch2onnx.py
@@ -11,7 +11,8 @@ from .utils import create_input, init_model
 
 
 def torch2onnx(img: Any,
-               work_dir: Optional[str],
+               work_dir: str,
+               save_file: str,
                deploy_cfg: Union[str, mmcv.Config],
                model_cfg: Union[str, mmcv.Config],
                model_checkpoint: Optional[str] = None,
@@ -22,18 +23,18 @@ def torch2onnx(img: Any,
     # load deploy_cfg if needed
     if isinstance(deploy_cfg, str):
         deploy_cfg = mmcv.Config.fromfile(deploy_cfg)
-    elif not isinstance(deploy_cfg, mmcv.Config):
+    if not isinstance(deploy_cfg, mmcv.Config):
         raise TypeError('deploy_cfg must be a filename or Config object, '
                         f'but got {type(deploy_cfg)}')
 
     # load model_cfg if needed
     if isinstance(model_cfg, str):
         model_cfg = mmcv.Config.fromfile(model_cfg)
-    elif not isinstance(model_cfg, mmcv.Config):
+    if not isinstance(model_cfg, mmcv.Config):
         raise TypeError('config must be a filename or Config object, '
                         f'but got {type(model_cfg)}')
 
     mmcv.mkdir_or_exist(osp.abspath(work_dir))
-    output_file = osp.join(work_dir, 'torch2onnx.onnx')
+    output_file = osp.join(work_dir, save_file)
     pytorch2onnx_cfg = deploy_cfg['pytorch2onnx']
     codebase = deploy_cfg['codebase']
diff --git a/mmdeploy/apis/tensorrt/__init__.py b/mmdeploy/apis/tensorrt/__init__.py
new file mode 100644
index 000000000..2a4c8c4fe
--- /dev/null
+++ b/mmdeploy/apis/tensorrt/__init__.py
@@ -0,0 +1,14 @@
+# flake8: noqa
+from .init_plugins import load_tensorrt_plugin
+from .onnx2tensorrt import onnx2tensorrt
+from .tensorrt_utils import (TRTWrapper, load_trt_engine, onnx2trt,
+                             save_trt_engine)
+
+# load tensorrt plugin lib
+load_tensorrt_plugin()
+
+__all__ = [
+    'load_tensorrt_plugin', 'onnx2trt',
+    'save_trt_engine', 'load_trt_engine',
+    'TRTWrapper', 'onnx2tensorrt'
+]
diff --git a/mmdeploy/apis/tensorrt/init_plugins.py b/mmdeploy/apis/tensorrt/init_plugins.py
new file mode 100644
index 000000000..f3bdc5779
--- /dev/null
+++ b/mmdeploy/apis/tensorrt/init_plugins.py
@@ -0,0 +1,27 @@
+import ctypes
+import glob
+import logging
+import os
+
+
+def get_tensorrt_op_path():
+    """Get TensorRT plugins library path."""
+    wildcard = os.path.abspath(
+        os.path.join(
+            os.path.dirname(__file__),
+            '../../../build/lib/libmmlab_tensorrt_ops.so'))
+
+    paths = glob.glob(wildcard)
+    lib_path = paths[0] if len(paths) > 0 else ''
+    return lib_path
+
+
+def load_tensorrt_plugin():
+    """Load TensorRT plugins library."""
+    lib_path = get_tensorrt_op_path()
+    if os.path.exists(lib_path):
+        ctypes.CDLL(lib_path)
+        return 0
+    else:
+        logging.warning('Cannot load tensorrt custom ops.')
+        return -1
diff --git a/mmdeploy/apis/tensorrt/onnx2tensorrt.py b/mmdeploy/apis/tensorrt/onnx2tensorrt.py
new file mode 100644
index 000000000..2b4cbef07
--- /dev/null
+++ b/mmdeploy/apis/tensorrt/onnx2tensorrt.py
@@ -0,0 +1,49 @@
+import os.path as osp
+from typing import Optional, Union
+
+import mmcv
+import onnx
+import tensorrt as trt
+import torch.multiprocessing as mp
+
+from .tensorrt_utils import onnx2trt, save_trt_engine
+
+
+def onnx2tensorrt(work_dir: str,
+                  save_file: str,
+                  deploy_cfg: Union[str, mmcv.Config],
+                  onnx_model: Union[str, onnx.ModelProto],
+                  device: str = 'cuda:0',
+                  ret_value: Optional[mp.Value] = None):
+    if ret_value is not None:
+        ret_value.value = -1
+
+    # load deploy_cfg if needed
+    if isinstance(deploy_cfg, str):
+        deploy_cfg = mmcv.Config.fromfile(deploy_cfg)
+    if not isinstance(deploy_cfg, mmcv.Config):
+        raise TypeError('deploy_cfg must be a filename or Config object, '
+                        f'but got {type(deploy_cfg)}')
+
+    mmcv.mkdir_or_exist(osp.abspath(work_dir))
+
+    assert 'tensorrt_param' in deploy_cfg
+
+    tensorrt_param = deploy_cfg['tensorrt_param']
+
+    assert device.startswith('cuda')
+    device_id = 0
+    if len(device) >= 6:
+        device_id = int(device[5:])
+    engine = onnx2trt(
+        onnx_model,
+        opt_shape_dict=tensorrt_param['opt_shape_dict'],
+        log_level=tensorrt_param.get('log_level', trt.Logger.WARNING),
+        fp16_mode=tensorrt_param.get('fp16_mode', False),
+        max_workspace_size=tensorrt_param.get('max_workspace_size', 0),
+        device_id=device_id)
+
+    save_trt_engine(engine, osp.join(work_dir, save_file))
+
+    if ret_value is not None:
+        ret_value.value = 0
diff --git a/mmdeploy/apis/tensorrt/tensorrt_utils.py b/mmdeploy/apis/tensorrt/tensorrt_utils.py
new file mode 100644
index 000000000..5254a9266
--- /dev/null
+++ b/mmdeploy/apis/tensorrt/tensorrt_utils.py
@@ -0,0 +1,219 @@
+import onnx
+import tensorrt as trt
+import torch
+
+
+def onnx2trt(onnx_model,
+             opt_shape_dict,
+             log_level=trt.Logger.ERROR,
+             fp16_mode=False,
+             max_workspace_size=0,
+             device_id=0):
+    """Convert onnx model to tensorrt engine.
+
+    Arguments:
+        onnx_model (str or onnx.ModelProto): the onnx model to convert from
+        opt_shape_dict (dict): the min/opt/max shape of each input
+        log_level (TensorRT log level): the log level of TensorRT
+        fp16_mode (bool): enable fp16 mode
+        max_workspace_size (int): set max workspace size of TensorRT engine.
+            some tactics and layers need large workspace.
+        device_id (int): choose the device to create engine.
+
+    Returns:
+        tensorrt.ICudaEngine: the TensorRT engine created from onnx_model
+
+    Example:
+        >>> engine = onnx2trt(
+        >>>     "onnx_model.onnx",
+        >>>     {'input': [[1, 3, 160, 160],
+        >>>                [1, 3, 320, 320],
+        >>>                [1, 3, 640, 640]]},
+        >>>     log_level=trt.Logger.WARNING,
+        >>>     fp16_mode=True,
+        >>>     max_workspace_size=1 << 30,
+        >>>     device_id=0)
+    """
+    device = torch.device('cuda:{}'.format(device_id))
+    # create builder and network
+    logger = trt.Logger(log_level)
+    builder = trt.Builder(logger)
+    EXPLICIT_BATCH = 1 << int(
+        trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    network = builder.create_network(EXPLICIT_BATCH)
+
+    # parse onnx
+    parser = trt.OnnxParser(network, logger)
+
+    if isinstance(onnx_model, str):
+        onnx_model = onnx.load(onnx_model)
+
+    if not parser.parse(onnx_model.SerializeToString()):
+        error_msgs = ''
+        for error in range(parser.num_errors):
+            error_msgs += f'{parser.get_error(error)}\n'
+        raise RuntimeError(f'parse onnx failed:\n{error_msgs}')
+
+    # config builder
+    config = builder.create_builder_config()
+    config.max_workspace_size = max_workspace_size
+    profile = builder.create_optimization_profile()
+
+    for input_name, param in opt_shape_dict.items():
+        min_shape = tuple(param[0][:])
+        opt_shape = tuple(param[1][:])
+        max_shape = tuple(param[2][:])
+        profile.set_shape(input_name, min_shape, opt_shape, max_shape)
+    config.add_optimization_profile(profile)
+
+    if fp16_mode:
+        config.set_flag(trt.BuilderFlag.FP16)
+
+    # create engine
+    with torch.cuda.device(device):
+        engine = builder.build_engine(network, config)
+
+    return engine
+
+
+def save_trt_engine(engine, path):
+    """Serialize TensorRT engine to disk.
+
+    Arguments:
+        engine (tensorrt.ICudaEngine): TensorRT engine to serialize
+        path (str): disk path to write the engine
+    """
+    with open(path, mode='wb') as f:
+        f.write(bytearray(engine.serialize()))
+
+
+def load_trt_engine(path):
+    """Deserialize TensorRT engine from disk.
+
+    Arguments:
+        path (str): disk path to read the engine
+
+    Returns:
+        tensorrt.ICudaEngine: the TensorRT engine loaded from disk
+    """
+    with trt.Logger() as logger, trt.Runtime(logger) as runtime:
+        with open(path, mode='rb') as f:
+            engine_bytes = f.read()
+        engine = runtime.deserialize_cuda_engine(engine_bytes)
+        return engine
+
+
+def torch_dtype_from_trt(dtype):
+    """Convert TensorRT dtype to pytorch dtype."""
+    if dtype == trt.bool:
+        return torch.bool
+    elif dtype == trt.int8:
+        return torch.int8
+    elif dtype == trt.int32:
+        return torch.int32
+    elif dtype == trt.float16:
+        return torch.float16
+    elif dtype == trt.float32:
+        return torch.float32
+    else:
+        raise TypeError('%s is not supported by torch' % dtype)
+
+
+def torch_device_from_trt(device):
+    """Convert TensorRT device to pytorch device."""
+    if device == trt.TensorLocation.DEVICE:
+        return torch.device('cuda')
+    elif device == trt.TensorLocation.HOST:
+        return torch.device('cpu')
+    else:
+        raise TypeError('%s is not supported by torch' % device)
+
+
+class TRTWrapper(torch.nn.Module):
+    """TensorRT engine wrapper.
+
+    Arguments:
+        engine (tensorrt.ICudaEngine or str): TensorRT engine to wrap or
+            path of a serialized engine file
+
+    Note:
+        Input and output names are read from the engine bindings. If the
+        engine is converted from an onnx model, they are the same as the
+        input and output names of the onnx model.
+    """
+
+    def __init__(self, engine):
+        super(TRTWrapper, self).__init__()
+        self.engine = engine
+        if isinstance(self.engine, str):
+            self.engine = load_trt_engine(engine)
+
+        if not isinstance(self.engine, trt.ICudaEngine):
+            raise TypeError('engine should be str or trt.ICudaEngine')
+
+        self._register_state_dict_hook(TRTWrapper._on_state_dict)
+        self.context = self.engine.create_execution_context()
+
+        self._load_io_names()
+
+    def _load_io_names(self):
+        # get input and output names from engine
+        names = list(self.engine)
+        input_names = list(filter(self.engine.binding_is_input, names))
+        output_names = list(set(names) - set(input_names))
+        self.input_names = input_names
+        self.output_names = output_names
+
+    def _on_state_dict(self, state_dict, prefix, local_metadata):
+        state_dict[prefix + 'engine'] = bytearray(self.engine.serialize())
+        state_dict[prefix + 'input_names'] = self.input_names
+        state_dict[prefix + 'output_names'] = self.output_names
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+                              strict, missing_keys, unexpected_keys,
+                              error_msgs):
+        engine_bytes = state_dict[prefix + 'engine']
+
+        with trt.Logger() as logger, trt.Runtime(logger) as runtime:
+            self.engine = runtime.deserialize_cuda_engine(engine_bytes)
+            self.context = self.engine.create_execution_context()
+
+        self.input_names = state_dict[prefix + 'input_names']
+        self.output_names = state_dict[prefix + 'output_names']
+
+    def forward(self, inputs):
+        """
+        Arguments:
+            inputs (dict): dict mapping input names to tensors
+
+        Returns:
+            dict: dict mapping output names to tensors
+        """
+        assert self.input_names is not None
+        assert self.output_names is not None
+        bindings = [None] * (len(self.input_names) + len(self.output_names))
+
+        for input_name, input_tensor in inputs.items():
+            idx = self.engine.get_binding_index(input_name)
+
+            if input_tensor.dtype == torch.long:
+                input_tensor = input_tensor.int()
+            self.context.set_binding_shape(idx, tuple(input_tensor.shape))
+            bindings[idx] = input_tensor.contiguous().data_ptr()
+
+        # create output tensors
+        outputs = {}
+        for output_name in self.output_names:
+            idx = self.engine.get_binding_index(output_name)
+            dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
+            shape = tuple(self.context.get_binding_shape(idx))
+
+            device = torch_device_from_trt(self.engine.get_location(idx))
+            output = torch.empty(size=shape, dtype=dtype, device=device)
+            outputs[output_name] = output
+            bindings[idx] = output.data_ptr()
+
+        self.context.execute_async_v2(bindings,
+                                      torch.cuda.current_stream().cuda_stream)
+
+        return outputs
diff --git a/mmdeploy/mmcv/funcs.py b/mmdeploy/mmcv/funcs.py
index 9b5462fd7..ae647dd87 100644
--- a/mmdeploy/mmcv/funcs.py
+++ b/mmdeploy/mmcv/funcs.py
@@ -45,3 +45,13 @@ def rewrite_topk_tensorrt(rewriter,
     k = int(k)
     return rewriter.origin_func(
         input, k, dim=dim, largest=largest, sorted=sorted)
+
+
+@FUNCTION_REWRITERS.register_rewriter(
+    func_name='torch.Tensor.repeat', backend='tensorrt')
+def rewrite_repeat_tensorrt(rewriter, input, *size):
+    origin_func = rewriter.origin_func
+    if input.dim() == 1 and len(size) == 1:
+        return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
+    else:
+        return origin_func(input, *size)
diff --git a/mmdeploy/mmcv/ops.py b/mmdeploy/mmcv/ops.py
index 8947d1e6c..cb0250b86 100644
--- a/mmdeploy/mmcv/ops.py
+++ b/mmdeploy/mmcv/ops.py
@@ -66,9 +66,11 @@ def nms_tensorrt(symbolic_wrapper, g, boxes, scores,
     score_threshold = sym_help._maybe_get_const(score_threshold, 'f')
 
     return g.op(
-        'NonMaxSuppression',
+        'mmcv::NonMaxSuppression',
         boxes,
         scores,
         max_output_boxes_per_class_i=max_output_boxes_per_class,
         iou_threshold_f=iou_threshold,
-        score_threshold_f=score_threshold)
+        score_threshold_f=score_threshold,
+        center_point_box_i=0,
+        offset_i=0)
diff --git a/tools/deploy.py b/tools/deploy.py
index a3b1f75dc..380d0094b 100644
--- a/tools/deploy.py
+++ b/tools/deploy.py
@@ -19,6 +19,11 @@ def parse_args():
     parser.add_argument('--work-dir', help='the dir to save logs and models')
     parser.add_argument(
         '--device', help='device used for conversion', default='cpu')
+    parser.add_argument(
+        '--log-level',
+        help='set log level',
+        default='INFO',
+        choices=list(logging._nameToLevel.keys()))
     args = parser.parse_args()
     return args
 
@@ -28,20 +33,31 @@ def main():
     args = parse_args()
     set_start_method('spawn')
 
-    deploy_cfg = args.deploy_cfg
-    model_cfg = args.model_cfg
-    checkpoint = args.checkpoint
+    logger = logging.getLogger()
+    logger.setLevel(args.log_level)
+
+    deploy_cfg_path = args.deploy_cfg
+    model_cfg_path = args.model_cfg
+    checkpoint_path = args.checkpoint
+
+    # load deploy_cfg
+    deploy_cfg = mmcv.Config.fromfile(deploy_cfg_path)
+    if not isinstance(deploy_cfg, mmcv.Config):
+        raise TypeError('deploy_cfg must be a filename or Config object, '
+                        f'but got {type(deploy_cfg)}')
 
     # create work_dir if not
     mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
 
     ret_value = mp.Value('d', 0, lock=False)
 
-    # convert model
+    # convert onnx
     logging.info('start torch2onnx conversion.')
+    onnx_save_file = deploy_cfg['pytorch2onnx']['save_file']
     process = Process(
         target=torch2onnx,
-        args=(args.img, args.work_dir, deploy_cfg, model_cfg, checkpoint),
+        args=(args.img, args.work_dir, onnx_save_file, deploy_cfg_path,
+              model_cfg_path, checkpoint_path),
         kwargs=dict(device=args.device, ret_value=ret_value))
     process.start()
     process.join()
@@ -49,9 +65,33 @@ def main():
     if ret_value.value != 0:
         logging.error('torch2onnx failed.')
         exit()
     else:
         logging.info('torch2onnx success.')
 
+    # convert backend
+    backend = deploy_cfg.get('backend', 'default')
+    onnx_paths = [osp.join(args.work_dir, onnx_save_file)]
+    if backend == 'tensorrt':
+        logging.info('start onnx2tensorrt conversion.')
+        from mmdeploy.apis.tensorrt import onnx2tensorrt
+        trt_save_file = deploy_cfg['tensorrt_param']['save_file']
+        for onnx_path in onnx_paths:
+            process = Process(
+                target=onnx2tensorrt,
+                args=(args.work_dir, trt_save_file, deploy_cfg_path,
+                      onnx_path),
+                kwargs=dict(device=args.device, ret_value=ret_value))
+            process.start()
+            process.join()
+
+            if ret_value.value != 0:
+                logging.error('onnx2tensorrt failed.')
+                exit()
+            else:
+                logging.info('onnx2tensorrt success.')
+
+    logging.info('All processes succeeded.')
+
 
 if __name__ == '__main__':
     main()
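
Post-patch note (editor): the snippet below is a minimal smoke test of the API introduced by this patch, not part of the diff itself. The engine path and work directory name are assumptions; the file name follows the default save_file='onnx2tensorrt.engine' from configs/_base_/backends/tensorrt.py, and the binding name 'input' with a 1x3x224x224 shape follows opt_shape_dict in configs/mmcls/mmcls_tensorrt.py. A CUDA-capable machine with TensorRT installed is required; the custom plugin library is loaded automatically on import when it has been built.

    import torch

    from mmdeploy.apis.tensorrt import TRTWrapper

    # Wrap the engine written by tools/deploy.py. TRTWrapper also accepts an
    # already deserialized trt.ICudaEngine instead of a file path.
    model = TRTWrapper('work_dir/onnx2tensorrt.engine')

    # Binding name and shape come from the mmcls config above; any batch size
    # within the declared [1, 32] optimization profile is valid.
    data = dict(input=torch.randn(1, 3, 224, 224, device='cuda'))
    outputs = model(data)

    # forward() launches execute_async_v2 on the current CUDA stream, so
    # synchronize before consuming the results on the host.
    torch.cuda.synchronize()
    print({name: tensor.shape for name, tensor in outputs.items()})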