diff --git a/projects/easydeploy/model/__init__.py b/projects/easydeploy/model/__init__.py
index bccb936a..52d6043e 100644
--- a/projects/easydeploy/model/__init__.py
+++ b/projects/easydeploy/model/__init__.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .backendwrapper import BackendWrapper
+from .backendwrapper import ORTWrapper, TRTWrapper
 from .model import DeployModel
 
-__all__ = ['DeployModel', 'BackendWrapper']
+__all__ = ['DeployModel', 'TRTWrapper', 'ORTWrapper']
diff --git a/projects/easydeploy/model/backendwrapper.py b/projects/easydeploy/model/backendwrapper.py
index 32478f0f..35949e28 100644
--- a/projects/easydeploy/model/backendwrapper.py
+++ b/projects/easydeploy/model/backendwrapper.py
@@ -1,5 +1,5 @@
 import warnings
-from collections import OrderedDict, namedtuple
+from collections import namedtuple
 from functools import partial
 from pathlib import Path
 from typing import List, Optional, Union
@@ -12,167 +12,202 @@ try:
 except Exception:
     trt = None
 import torch
-from numpy import ndarray
-from torch import Tensor
 
 warnings.filterwarnings(action='ignore', category=DeprecationWarning)
 
 
-class BackendWrapper:
+class TRTWrapper(torch.nn.Module):
 
-    def __init__(
-            self,
-            weight: Union[str, Path],
-            device: Optional[Union[str, int, torch.device]] = None) -> None:
+    def __init__(self, weight: Union[str, Path],
+                 device: Optional[Union[str, int, torch.device]]):
+        super().__init__()
         weight = Path(weight) if isinstance(weight, str) else weight
-        assert weight.exists() and weight.suffix in ('.onnx', '.engine',
-                                                     '.plan')
+        assert weight.exists() and weight.suffix in ('.engine', '.plan')
         if isinstance(device, str):
             device = torch.device(device)
         elif isinstance(device, int):
             device = torch.device(f'cuda:{device}')
         self.weight = weight
         self.device = device
-        self.__build_model()
-        self.__init_runtime()
-        self.__warm_up(10)
+        self.stream = torch.cuda.Stream(device=device)
+        self.__init_engine()
+        self.__init_bindings()
 
-    def __build_model(self) -> None:
-        model_info = dict()
-        num_input = num_output = 0
-        names = []
-        is_dynamic = False
-        if self.weight.suffix == '.onnx':
-            model_info['backend'] = 'ONNXRuntime'
-            providers = ['CPUExecutionProvider']
-            if 'cuda' in self.device.type:
-                providers.insert(0, 'CUDAExecutionProvider')
-            model = onnxruntime.InferenceSession(
-                str(self.weight), providers=providers)
-            for i, tensor in enumerate(model.get_inputs()):
-                model_info[tensor.name] = dict(
-                    shape=tensor.shape, dtype=tensor.type)
-                num_input += 1
-                names.append(tensor.name)
-                is_dynamic |= any(
-                    map(lambda x: isinstance(x, str), tensor.shape))
-            for i, tensor in enumerate(model.get_outputs()):
-                model_info[tensor.name] = dict(
-                    shape=tensor.shape, dtype=tensor.type)
-                num_output += 1
-                names.append(tensor.name)
-        else:
-            model_info['backend'] = 'TensorRT'
-            logger = trt.Logger(trt.Logger.ERROR)
-            trt.init_libnvinfer_plugins(logger, namespace='')
-            with trt.Runtime(logger) as runtime:
-                model = runtime.deserialize_cuda_engine(
-                    self.weight.read_bytes())
-            profile_shape = []
-            for i in range(model.num_bindings):
-                name = model.get_binding_name(i)
-                shape = tuple(model.get_binding_shape(i))
-                dtype = trt.nptype(model.get_binding_dtype(i))
-                is_dynamic |= (-1 in shape)
-                if model.binding_is_input(i):
-                    num_input += 1
-                    profile_shape.append(model.get_profile_shape(i, 0))
-                else:
-                    num_output += 1
-                model_info[name] = dict(shape=shape, dtype=dtype)
-                names.append(name)
-            model_info['profile_shape'] = profile_shape
+    def __init_engine(self):
+        logger = trt.Logger(trt.Logger.ERROR)
+        self.log = partial(logger.log, trt.Logger.ERROR)
+        trt.init_libnvinfer_plugins(logger, namespace='')
+        self.logger = logger
+        with trt.Runtime(logger) as runtime:
+            model = runtime.deserialize_cuda_engine(self.weight.read_bytes())
 
-        self.num_input = num_input
-        self.num_output = num_output
-        self.names = names
-        self.is_dynamic = is_dynamic
-        self.model = model
-        self.model_info = model_info
+        context = model.create_execution_context()
 
-    def __init_runtime(self) -> None:
-        bindings = OrderedDict()
-        Binding = namedtuple('Binding',
-                             ('name', 'dtype', 'shape', 'data', 'ptr'))
-        if self.model_info['backend'] == 'TensorRT':
-            context = self.model.create_execution_context()
-            for name in self.names:
-                shape, dtype = self.model_info[name].values()
-                if self.is_dynamic:
-                    cpu_tensor, gpu_tensor, ptr = None, None, None
-                else:
-                    cpu_tensor = np.empty(shape, dtype=np.dtype(dtype))
-                    gpu_tensor = torch.from_numpy(cpu_tensor).to(self.device)
-                    ptr = int(gpu_tensor.data_ptr())
-                bindings[name] = Binding(name, dtype, shape, gpu_tensor, ptr)
-        else:
-            output_names = []
-            for i, name in enumerate(self.names):
-                if i >= self.num_input:
-                    output_names.append(name)
-                shape, dtype = self.model_info[name].values()
-                bindings[name] = Binding(name, dtype, shape, None, None)
-            context = partial(self.model.run, output_names)
-        self.addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
-        self.bindings = bindings
-        self.context = context
+        names = [model.get_binding_name(i) for i in range(model.num_bindings)]
 
-    def __infer(
-            self, inputs: List[Union[ndarray,
-                                     Tensor]]) -> List[Union[ndarray, Tensor]]:
-        assert len(inputs) == self.num_input
-        if self.model_info['backend'] == 'TensorRT':
-            outputs = []
-            for i, (name, gpu_input) in enumerate(
-                    zip(self.names[:self.num_input], inputs)):
-                if self.is_dynamic:
-                    self.context.set_binding_shape(i, gpu_input.shape)
-                self.addrs[name] = gpu_input.data_ptr()
+        num_inputs, num_outputs = 0, 0
 
-            for i, name in enumerate(self.names[self.num_input:]):
-                i += self.num_input
-                if self.is_dynamic:
-                    shape = tuple(self.context.get_binding_shape(i))
-                    dtype = self.bindings[name].dtype
-                    cpu_tensor = np.empty(shape, dtype=np.dtype(dtype))
-                    out = torch.from_numpy(cpu_tensor).to(self.device)
-                    self.addrs[name] = out.data_ptr()
-                else:
-                    out = self.bindings[name].data
-                outputs.append(out)
-            assert self.context.execute_v2(list(
-                self.addrs.values())), 'Infer fault'
-        else:
-            input_feed = {
-                name: inputs[i]
-                for i, name in enumerate(self.names[:self.num_input])
-            }
-            outputs = self.context(input_feed)
-        return outputs
-
-    def __warm_up(self, n=10) -> None:
-        for _ in range(n):
-            _tmp = []
-            if self.model_info['backend'] == 'TensorRT':
-                for i, name in enumerate(self.names[:self.num_input]):
-                    if self.is_dynamic:
-                        shape = self.model_info['profile_shape'][i][1]
-                        dtype = self.bindings[name].dtype
-                        cpu_tensor = np.empty(shape, dtype=np.dtype(dtype))
-                        _tmp.append(
-                            torch.from_numpy(cpu_tensor).to(self.device))
-                    else:
-                        _tmp.append(self.bindings[name].data)
+        for i in range(model.num_bindings):
+            if model.binding_is_input(i):
+                num_inputs += 1
             else:
-                print('Please warm up ONNXRuntime model by yourself')
-                print("So this model doesn't warm up")
-                return
-            _ = self.__infer(_tmp)
+                num_outputs += 1
 
-    def __call__(
-            self, inputs: Union[List, Tensor,
-                                ndarray]) -> List[Union[Tensor, ndarray]]:
-        if not isinstance(inputs, list):
-            inputs = [inputs]
-        outputs = self.__infer(inputs)
-        return outputs
+        self.is_dynamic = -1 in model.get_binding_shape(0)
+
+        self.model = model
+        self.context = context
+        self.input_names = names[:num_inputs]
+        self.output_names = names[num_inputs:]
+        self.num_inputs = num_inputs
+        self.num_outputs = num_outputs
+        self.num_bindings = num_inputs + num_outputs
+        self.bindings: List[int] = [0] * self.num_bindings
+
+    def __init_bindings(self):
+        Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))
+        inputs_info = []
+        outputs_info = []
+
+        for i, name in enumerate(self.input_names):
+            assert self.model.get_binding_name(i) == name
+            dtype = self.torch_dtype_from_trt(self.model.get_binding_dtype(i))
+            shape = tuple(self.model.get_binding_shape(i))
+            inputs_info.append(Binding(name, dtype, shape))
+
+        for i, name in enumerate(self.output_names):
+            i += self.num_inputs
+            assert self.model.get_binding_name(i) == name
+            dtype = self.torch_dtype_from_trt(self.model.get_binding_dtype(i))
+            shape = tuple(self.model.get_binding_shape(i))
+            outputs_info.append(Binding(name, dtype, shape))
+        self.inputs_info = inputs_info
+        self.outputs_info = outputs_info
+        if not self.is_dynamic:
+            self.output_tensor = [
+                torch.empty(o.shape, dtype=o.dtype, device=self.device)
+                for o in outputs_info
+            ]
+
+    def forward(self, *inputs):
+        assert len(inputs) == self.num_inputs
+
+        contiguous_inputs: List[torch.Tensor] = [
+            i.contiguous() for i in inputs
+        ]
+
+        for i in range(self.num_inputs):
+            self.bindings[i] = contiguous_inputs[i].data_ptr()
+            if self.is_dynamic:
+                self.context.set_binding_shape(
+                    i, tuple(contiguous_inputs[i].shape))
+
+        # create output tensors
+        outputs: List[torch.Tensor] = []
+
+        for i in range(self.num_outputs):
+            j = i + self.num_inputs
+            if self.is_dynamic:
+                # query the actual output shape from the execution context and
+                # allocate a fresh tensor with the dtype recorded at init time
+                shape = tuple(self.context.get_binding_shape(j))
+                output = torch.empty(
+                    size=shape,
+                    dtype=self.outputs_info[i].dtype,
+                    device=self.device)
+            else:
+                output = self.output_tensor[i]
+            outputs.append(output)
+            self.bindings[j] = output.data_ptr()
+
+        self.context.execute_async_v2(self.bindings, self.stream.cuda_stream)
+        self.stream.synchronize()
+
+        return tuple(outputs)
+
+    @staticmethod
+    def torch_dtype_from_trt(dtype: trt.DataType) -> torch.dtype:
+        """Convert TensorRT data types to PyTorch data types.
+
+        Args:
+            dtype (trt.DataType): A TensorRT data type.
+
+        Returns:
+            torch.dtype: The equivalent PyTorch data type.
+        """
+        if dtype == trt.int8:
+            return torch.int8
+        elif trt.__version__ >= '7.0' and dtype == trt.bool:
+            return torch.bool
+        elif dtype == trt.int32:
+            return torch.int32
+        elif dtype == trt.float16:
+            return torch.float16
+        elif dtype == trt.float32:
+            return torch.float32
+        else:
+            raise TypeError(f'{dtype} is not supported by torch')
+
+
+class ORTWrapper(torch.nn.Module):
+
+    def __init__(self, weight: Union[str, Path],
+                 device: Optional[Union[str, int, torch.device]]):
+        super().__init__()
+        weight = Path(weight) if isinstance(weight, str) else weight
+        assert weight.exists() and weight.suffix == '.onnx'
+
+        if isinstance(device, str):
+            device = torch.device(device)
+        elif isinstance(device, int):
+            device = torch.device(f'cuda:{device}')
+        self.weight = weight
+        self.device = device
+        self.__init_session()
+        self.__init_bindings()
+
+    def __init_session(self):
+        providers = ['CPUExecutionProvider']
+        if 'cuda' in self.device.type:
+            providers.insert(0, 'CUDAExecutionProvider')
+
+        session = onnxruntime.InferenceSession(
+            str(self.weight), providers=providers)
+        self.session = session
+
+    def __init_bindings(self):
+        Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))
+        inputs_info = []
+        outputs_info = []
+        self.is_dynamic = False
+        for tensor in self.session.get_inputs():
+            if any(not isinstance(dim, int) for dim in tensor.shape):
+                self.is_dynamic = True
+            inputs_info.append(
+                Binding(tensor.name, tensor.type, tuple(tensor.shape)))
+
+        for tensor in self.session.get_outputs():
+            outputs_info.append(
+                Binding(tensor.name, tensor.type, tuple(tensor.shape)))
+        self.inputs_info = inputs_info
+        self.outputs_info = outputs_info
+        self.num_inputs = len(inputs_info)
+
+    def forward(self, *inputs):
+        assert len(inputs) == self.num_inputs
+
+        contiguous_inputs: List[np.ndarray] = [
+            x.contiguous().cpu().numpy() for x in inputs
+        ]
+
+        if not self.is_dynamic:
+            # make sure the input shape matches the static input shape
+            for i in range(self.num_inputs):
+                assert contiguous_inputs[i].shape == self.inputs_info[i].shape
+
+        outputs = self.session.run([o.name for o in self.outputs_info], {
+            info.name: contiguous_inputs[i]
+            for i, info in enumerate(self.inputs_info)
+        })
+
+        return tuple(torch.from_numpy(o).to(self.device) for o in outputs)
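Note: a minimal usage sketch of the two wrappers, assuming a model already exported by export.py below. The weight paths and the 1x3x640x640 input shape are hypothetical, and the output order is taken from the output_names set in export.py:

    import torch
    from projects.easydeploy.model import ORTWrapper, TRTWrapper

    device = torch.device('cuda:0')
    # 'end2end.engine' / 'end2end.onnx' are placeholder file names
    model = TRTWrapper('end2end.engine', device)
    # model = ORTWrapper('end2end.onnx', device)  # ONNXRuntime variant
    dummy = torch.randn(1, 3, 640, 640, device=device)
    num_dets, boxes, scores, labels = model(dummy)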
+ """ + if dtype == trt.int8: + return torch.int8 + elif trt.__version__ >= '7.0' and dtype == trt.bool: + return torch.bool + elif dtype == trt.int32: + return torch.int32 + elif dtype == trt.float16: + return torch.float16 + elif dtype == trt.float32: + return torch.float32 + else: + raise TypeError(f'{dtype} is not supported by torch') + + +class ORTWrapper(torch.nn.Module): + + def __init__(self, weight: Union[str, Path], + device: Optional[torch.device]): + super().__init__() + weight = Path(weight) if isinstance(weight, str) else weight + assert weight.exists() and weight.suffix == '.onnx' + + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device(f'cuda:{device}') + self.weight = weight + self.device = device + self.__init_session() + self.__init_bindings() + + def __init_session(self): + providers = ['CPUExecutionProvider'] + if 'cuda' in self.device.type: + providers.insert(0, 'CUDAExecutionProvider') + + session = onnxruntime.InferenceSession( + str(self.weight), providers=providers) + self.session = session + + def __init_bindings(self): + Binding = namedtuple('Binding', ('name', 'dtype', 'shape')) + inputs_info = [] + outputs_info = [] + self.is_dynamic = False + for i, tensor in enumerate(self.session.get_inputs()): + if any(not isinstance(i, int) for i in tensor.shape): + self.is_dynamic = True + inputs_info.append( + Binding(tensor.name, tensor.type, tuple(tensor.shape))) + + for i, tensor in enumerate(self.session.get_outputs()): + outputs_info.append( + Binding(tensor.name, tensor.type, tuple(tensor.shape))) + self.inputs_info = inputs_info + self.outputs_info = outputs_info + self.num_inputs = len(inputs_info) + + def forward(self, *inputs): + + assert len(inputs) == self.num_inputs + + contiguous_inputs: List[np.ndarray] = [ + i.contiguous().cpu().numpy() for i in inputs + ] + + if not self.is_dynamic: + # make sure input shape is right for static input shape + for i in range(self.num_inputs): + assert contiguous_inputs[i].shape == self.inputs_info[i].shape + + outputs = self.session.run([o.name for o in self.outputs_info], { + j.name: contiguous_inputs[i] + for i, j in enumerate(self.inputs_info) + }) + + return tuple(torch.from_numpy(o).to(self.device) for o in outputs) diff --git a/projects/easydeploy/tools/build_engine.py b/projects/easydeploy/tools/build_engine.py index afb95887..b400c9db 100644 --- a/projects/easydeploy/tools/build_engine.py +++ b/projects/easydeploy/tools/build_engine.py @@ -6,9 +6,13 @@ try: import tensorrt as trt except Exception: trt = None +import warnings + import numpy as np import torch +warnings.filterwarnings(action='ignore', category=DeprecationWarning) + class EngineBuilder: diff --git a/projects/easydeploy/tools/export.py b/projects/easydeploy/tools/export.py index 2c0f0877..39d9fcfc 100644 --- a/projects/easydeploy/tools/export.py +++ b/projects/easydeploy/tools/export.py @@ -90,7 +90,7 @@ def main(): iou_threshold=args.iou_threshold, score_threshold=args.score_threshold, backend=args.backend) - output_names = ['num_det', 'det_boxes', 'det_scores', 'det_classes'] + output_names = ['num_dets', 'boxes', 'scores', 'labels'] baseModel = build_model_from_cfg(args.config, args.checkpoint, args.device) deploy_model = DeployModel( diff --git a/projects/easydeploy/tools/image-demo.py b/projects/easydeploy/tools/image-demo.py new file mode 100644 index 00000000..2b1da95f --- /dev/null +++ b/projects/easydeploy/tools/image-demo.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. 
diff --git a/projects/easydeploy/tools/image-demo.py b/projects/easydeploy/tools/image-demo.py
new file mode 100644
index 00000000..2b1da95f
--- /dev/null
+++ b/projects/easydeploy/tools/image-demo.py
@@ -0,0 +1,148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from projects.easydeploy.model import ORTWrapper, TRTWrapper  # isort:skip
+import os
+import random
+from argparse import ArgumentParser
+
+import cv2
+import mmcv
+import numpy as np
+import torch
+from mmcv.transforms import Compose
+from mmdet.utils import get_test_pipeline_cfg
+from mmengine.config import Config
+from mmengine.utils import ProgressBar
+
+from mmyolo.utils import register_all_modules
+from mmyolo.utils.misc import get_file_list
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument(
+        'img', help='Image path, include image file, dir and URL.')
+    parser.add_argument('config', help='Config file')
+    parser.add_argument('checkpoint', help='Checkpoint file')
+    parser.add_argument(
+        '--out-dir', default='./output', help='Path to output file')
+    parser.add_argument(
+        '--device', default='cuda:0', help='Device used for inference')
+    parser.add_argument(
+        '--show', action='store_true', help='Show the detection results')
+    args = parser.parse_args()
+    return args
+
+
+def preprocess(config):
+    data_preprocess = config.get('model', {}).get('data_preprocessor', {})
+    mean = data_preprocess.get('mean', [0., 0., 0.])
+    std = data_preprocess.get('std', [1., 1., 1.])
+    mean = torch.tensor(mean, dtype=torch.float32).reshape(1, 3, 1, 1)
+    std = torch.tensor(std, dtype=torch.float32).reshape(1, 3, 1, 1)
+
+    class PreProcess(torch.nn.Module):
+
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            x = x[None].float()
+            x -= mean.to(x.device)
+            x /= std.to(x.device)
+            return x
+
+    return PreProcess().eval()
+
+
+def main():
+    args = parse_args()
+
+    # register all modules in mmdet into the registries
+    register_all_modules()
+
+    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(1000)]
+
+    # build the model from an exported ONNX file or TensorRT engine
+    if args.checkpoint.endswith('.onnx'):
+        model = ORTWrapper(args.checkpoint, args.device)
+    elif args.checkpoint.endswith('.engine') or args.checkpoint.endswith(
+            '.plan'):
+        model = TRTWrapper(args.checkpoint, args.device)
+    else:
+        raise NotImplementedError
+
+    model.to(args.device)
+
+    cfg = Config.fromfile(args.config)
+
+    test_pipeline = get_test_pipeline_cfg(cfg)
+    test_pipeline[0].type = 'mmdet.LoadImageFromNDArray'
+    test_pipeline = Compose(test_pipeline)
+
+    pre_pipeline = preprocess(cfg)
+
+    if not os.path.exists(args.out_dir) and not args.show:
+        os.mkdir(args.out_dir)
+
+    # get file list
+    files, source_type = get_file_list(args.img)
+
+    # start detector inference
+    progress_bar = ProgressBar(len(files))
+    for i, file in enumerate(files):
+        bgr = mmcv.imread(file)
+        rgb = mmcv.imconvert(bgr, 'bgr', 'rgb')
+        data, samples = test_pipeline(dict(img=rgb, img_id=i)).values()
+        pad_param = samples.get('pad_param',
+                                np.array([0, 0, 0, 0], dtype=np.float32))
+        h, w = samples.get('ori_shape', rgb.shape[:2])
+        # pad_param is (top, bottom, left, right); build a (left, top, left,
+        # top) offset to shift xyxy boxes back into the unpadded image
+        pad_param = torch.asarray(
+            [pad_param[2], pad_param[0], pad_param[2], pad_param[0]],
+            device=args.device)
+        scale_factor = samples.get('scale_factor', [1., 1.])
+        scale_factor = torch.asarray(scale_factor * 2, device=args.device)
+        data = pre_pipeline(data).to(args.device)
+
+        result = model(data)
+        if source_type['is_dir']:
+            filename = os.path.relpath(file, args.img).replace('/', '_')
+        else:
+            filename = os.path.basename(file)
+        out_file = None if args.show else os.path.join(args.out_dir, filename)
+
+        # Get candidate predict info by num_dets
+        num_dets, bboxes, scores, labels = result
+        scores = scores[0, :num_dets]
+        bboxes = bboxes[0, :num_dets]
+        labels = labels[0, :num_dets]
+        bboxes -= pad_param
+        bboxes /= scale_factor
+
+        bboxes[:, 0::2].clamp_(0, w)
+        bboxes[:, 1::2].clamp_(0, h)
+        bboxes = bboxes.round().int()
+
+        for (bbox, score, label) in zip(bboxes, scores, labels):
+            bbox = bbox.tolist()
+            color = colors[label]
+            name = f'cls:{label}_score:{score:0.4f}'
+
+            cv2.rectangle(bgr, bbox[:2], bbox[2:], color, 2)
+            cv2.putText(
+                bgr,
+                name, (bbox[0], bbox[1] - 2),
+                cv2.FONT_HERSHEY_SIMPLEX,
+                0.75, [225, 255, 255],
+                thickness=2)
+
+        if args.show:
+            mmcv.imshow(bgr, 'result', 0)
+        else:
+            mmcv.imwrite(bgr, out_file)
+        progress_bar.update()
+
+
+if __name__ == '__main__':
+    main()
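Note: given a model exported by export.py (and, for TensorRT, built into an engine by build_engine.py), the demo can be invoked as, for example,

    python projects/easydeploy/tools/image-demo.py demo.jpg CONFIG end2end.engine --device cuda:0

where the image, CONFIG, and end2end.engine paths are placeholders for your own files.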