[Feature] Support PyTorch model forward for TensorRT inference (#377)

* Support PyTorch model forward for TensorRT

* Add ORT wrapper

* Fix import

* Add deploy image-demo

* Raise NotImplementedError

* Fix ONNX Runtime GPU tensor output
tripleMu 2022-12-29 12:11:16 +08:00 committed by GitHub
parent 66c80e91e1
commit a44495868d
5 changed files with 336 additions and 149 deletions

View File

@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backendwrapper import BackendWrapper
from .backendwrapper import ORTWrapper, TRTWrapper
from .model import DeployModel
__all__ = ['DeployModel', 'BackendWrapper']
__all__ = ['DeployModel', 'TRTWrapper', 'ORTWrapper']
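With this change, callers import a backend-specific wrapper instead of the old BackendWrapper. A minimal sketch of the new import surface, with hypothetical artifact paths:

from projects.easydeploy.model import ORTWrapper, TRTWrapper

# both wrappers accept a device string, an int index or a torch.device
trt_model = TRTWrapper('end2end.engine', device='cuda:0')  # hypothetical path
ort_model = ORTWrapper('end2end.onnx', device='cpu')  # hypothetical path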

View File

@@ -1,5 +1,5 @@
import warnings
from collections import OrderedDict, namedtuple
from collections import namedtuple
from functools import partial
from pathlib import Path
from typing import List, Optional, Union
@@ -12,167 +12,202 @@ try:
except Exception:
trt = None
import torch
from numpy import ndarray
from torch import Tensor
warnings.filterwarnings(action='ignore', category=DeprecationWarning)
class BackendWrapper:
class TRTWrapper(torch.nn.Module):
def __init__(
self,
weight: Union[str, Path],
device: Optional[Union[str, int, torch.device]] = None) -> None:
def __init__(self, weight: Union[str, Path],
device: Optional[torch.device]):
super().__init__()
weight = Path(weight) if isinstance(weight, str) else weight
assert weight.exists() and weight.suffix in ('.onnx', '.engine',
'.plan')
assert weight.exists() and weight.suffix in ('.engine', '.plan')
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device(f'cuda:{device}')
self.weight = weight
self.device = device
self.__build_model()
self.__init_runtime()
self.__warm_up(10)
self.stream = torch.cuda.Stream(device=device)
self.__init_engine()
self.__init_bindings()
def __build_model(self) -> None:
model_info = dict()
num_input = num_output = 0
names = []
is_dynamic = False
if self.weight.suffix == '.onnx':
model_info['backend'] = 'ONNXRuntime'
providers = ['CPUExecutionProvider']
if 'cuda' in self.device.type:
providers.insert(0, 'CUDAExecutionProvider')
model = onnxruntime.InferenceSession(
str(self.weight), providers=providers)
for i, tensor in enumerate(model.get_inputs()):
model_info[tensor.name] = dict(
shape=tensor.shape, dtype=tensor.type)
num_input += 1
names.append(tensor.name)
is_dynamic |= any(
map(lambda x: isinstance(x, str), tensor.shape))
for i, tensor in enumerate(model.get_outputs()):
model_info[tensor.name] = dict(
shape=tensor.shape, dtype=tensor.type)
num_output += 1
names.append(tensor.name)
else:
model_info['backend'] = 'TensorRT'
logger = trt.Logger(trt.Logger.ERROR)
trt.init_libnvinfer_plugins(logger, namespace='')
with trt.Runtime(logger) as runtime:
model = runtime.deserialize_cuda_engine(
self.weight.read_bytes())
profile_shape = []
for i in range(model.num_bindings):
name = model.get_binding_name(i)
shape = tuple(model.get_binding_shape(i))
dtype = trt.nptype(model.get_binding_dtype(i))
is_dynamic |= (-1 in shape)
if model.binding_is_input(i):
num_input += 1
profile_shape.append(model.get_profile_shape(i, 0))
else:
num_output += 1
model_info[name] = dict(shape=shape, dtype=dtype)
names.append(name)
model_info['profile_shape'] = profile_shape
def __init_engine(self):
logger = trt.Logger(trt.Logger.ERROR)
self.log = partial(logger.log, trt.Logger.ERROR)
trt.init_libnvinfer_plugins(logger, namespace='')
self.logger = logger
with trt.Runtime(logger) as runtime:
model = runtime.deserialize_cuda_engine(self.weight.read_bytes())
self.num_input = num_input
self.num_output = num_output
self.names = names
self.is_dynamic = is_dynamic
self.model = model
self.model_info = model_info
context = model.create_execution_context()
def __init_runtime(self) -> None:
bindings = OrderedDict()
Binding = namedtuple('Binding',
('name', 'dtype', 'shape', 'data', 'ptr'))
if self.model_info['backend'] == 'TensorRT':
context = self.model.create_execution_context()
for name in self.names:
shape, dtype = self.model_info[name].values()
if self.is_dynamic:
cpu_tensor, gpu_tensor, ptr = None, None, None
else:
cpu_tensor = np.empty(shape, dtype=np.dtype(dtype))
gpu_tensor = torch.from_numpy(cpu_tensor).to(self.device)
ptr = int(gpu_tensor.data_ptr())
bindings[name] = Binding(name, dtype, shape, gpu_tensor, ptr)
else:
output_names = []
for i, name in enumerate(self.names):
if i >= self.num_input:
output_names.append(name)
shape, dtype = self.model_info[name].values()
bindings[name] = Binding(name, dtype, shape, None, None)
context = partial(self.model.run, output_names)
self.addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
self.bindings = bindings
self.context = context
names = [model.get_binding_name(i) for i in range(model.num_bindings)]
def __infer(
self, inputs: List[Union[ndarray,
Tensor]]) -> List[Union[ndarray, Tensor]]:
assert len(inputs) == self.num_input
if self.model_info['backend'] == 'TensorRT':
outputs = []
for i, (name, gpu_input) in enumerate(
zip(self.names[:self.num_input], inputs)):
if self.is_dynamic:
self.context.set_binding_shape(i, gpu_input.shape)
self.addrs[name] = gpu_input.data_ptr()
num_inputs, num_outputs = 0, 0
for i, name in enumerate(self.names[self.num_input:]):
i += self.num_input
if self.is_dynamic:
shape = tuple(self.context.get_binding_shape(i))
dtype = self.bindings[name].dtype
cpu_tensor = np.empty(shape, dtype=np.dtype(dtype))
out = torch.from_numpy(cpu_tensor).to(self.device)
self.addrs[name] = out.data_ptr()
else:
out = self.bindings[name].data
outputs.append(out)
assert self.context.execute_v2(list(
self.addrs.values())), 'Infer fault'
else:
input_feed = {
name: inputs[i]
for i, name in enumerate(self.names[:self.num_input])
}
outputs = self.context(input_feed)
return outputs
def __warm_up(self, n=10) -> None:
for _ in range(n):
_tmp = []
if self.model_info['backend'] == 'TensorRT':
for i, name in enumerate(self.names[:self.num_input]):
if self.is_dynamic:
shape = self.model_info['profile_shape'][i][1]
dtype = self.bindings[name].dtype
cpu_tensor = np.empty(shape, dtype=np.dtype(dtype))
_tmp.append(
torch.from_numpy(cpu_tensor).to(self.device))
else:
_tmp.append(self.bindings[name].data)
for i in range(model.num_bindings):
if model.binding_is_input(i):
num_inputs += 1
else:
print('Please warm up ONNXRuntime model by yourself')
print("So this model doesn't warm up")
return
_ = self.__infer(_tmp)
num_outputs += 1
def __call__(
self, inputs: Union[List, Tensor,
ndarray]) -> List[Union[Tensor, ndarray]]:
if not isinstance(inputs, list):
inputs = [inputs]
outputs = self.__infer(inputs)
return outputs
self.is_dynamic = -1 in model.get_binding_shape(0)
self.model = model
self.context = context
self.input_names = names[:num_inputs]
self.output_names = names[num_inputs:]
self.num_inputs = num_inputs
self.num_outputs = num_outputs
self.num_bindings = num_inputs + num_outputs
self.bindings: List[int] = [0] * self.num_bindings
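# raw device pointers for execute_async_v2, filled in on every forward call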
def __init_bindings(self):
Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))
inputs_info = []
outputs_info = []
for i, name in enumerate(self.input_names):
assert self.model.get_binding_name(i) == name
dtype = self.torch_dtype_from_trt(self.model.get_binding_dtype(i))
shape = tuple(self.model.get_binding_shape(i))
inputs_info.append(Binding(name, dtype, shape))
for i, name in enumerate(self.output_names):
i += self.num_inputs
assert self.model.get_binding_name(i) == name
dtype = self.torch_dtype_from_trt(self.model.get_binding_dtype(i))
shape = tuple(self.model.get_binding_shape(i))
outputs_info.append(Binding(name, dtype, shape))
self.inputs_info = inputs_info
self.outputs_info = outputs_info
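# static-shape engines: allocate the output buffers once and reuse them on every call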
if not self.is_dynamic:
self.output_tensor = [
torch.empty(o.shape, dtype=o.dtype, device=self.device)
for o in outputs_info
]
def forward(self, *inputs):
assert len(inputs) == self.num_inputs
contiguous_inputs: List[torch.Tensor] = [
i.contiguous() for i in inputs
]
for i in range(self.num_inputs):
self.bindings[i] = contiguous_inputs[i].data_ptr()
if self.is_dynamic:
self.context.set_binding_shape(
i, tuple(contiguous_inputs[i].shape))
# create output tensors
outputs: List[torch.Tensor] = []
for i in range(self.num_outputs):
j = i + self.num_inputs
if self.is_dynamic:
shape = tuple(self.context.get_binding_shape(j))
output = torch.empty(
size=shape,
dtype=self.outputs_info[i].dtype,
device=self.device)
else:
output = self.output_tensor[i]
outputs.append(output)
self.bindings[j] = output.data_ptr()
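# enqueue inference on the wrapper's private CUDA stream, then block until the outputs are ready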
self.context.execute_async_v2(self.bindings, self.stream.cuda_stream)
self.stream.synchronize()
return tuple(outputs)
@staticmethod
def torch_dtype_from_trt(dtype: trt.DataType) -> torch.dtype:
"""Convert TensorRT data types to PyTorch data types.
Args:
dtype (trt.DataType): A TensorRT data type.
Returns:
The equivalent PyTorch data type.
"""
if dtype == trt.int8:
return torch.int8
elif trt.__version__ >= '7.0' and dtype == trt.bool:
return torch.bool
elif dtype == trt.int32:
return torch.int32
elif dtype == trt.float16:
return torch.float16
elif dtype == trt.float32:
return torch.float32
else:
raise TypeError(f'{dtype} is not supported by torch')
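Because TRTWrapper is now a torch.nn.Module, inference becomes a plain forward call on GPU tensors and the results never leave the device. A minimal usage sketch, assuming a hypothetical static-shape 1x3x640x640 engine that exports the four end2end NMS outputs unpacked by the demo below:

import torch
from projects.easydeploy.model import TRTWrapper

model = TRTWrapper('yolov5s.engine', device=torch.device('cuda:0'))  # hypothetical path
x = torch.randn(1, 3, 640, 640, device='cuda:0')
num_dets, boxes, scores, labels = model(x)  # torch.Tensor outputs, still on GPU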
class ORTWrapper(torch.nn.Module):
def __init__(self, weight: Union[str, Path],
device: Optional[torch.device]):
super().__init__()
weight = Path(weight) if isinstance(weight, str) else weight
assert weight.exists() and weight.suffix == '.onnx'
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device(f'cuda:{device}')
self.weight = weight
self.device = device
self.__init_session()
self.__init_bindings()
def __init_session(self):
providers = ['CPUExecutionProvider']
if 'cuda' in self.device.type:
providers.insert(0, 'CUDAExecutionProvider')
session = onnxruntime.InferenceSession(
str(self.weight), providers=providers)
self.session = session
def __init_bindings(self):
Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))
inputs_info = []
outputs_info = []
self.is_dynamic = False
for i, tensor in enumerate(self.session.get_inputs()):
if any(not isinstance(i, int) for i in tensor.shape):
self.is_dynamic = True
inputs_info.append(
Binding(tensor.name, tensor.type, tuple(tensor.shape)))
for i, tensor in enumerate(self.session.get_outputs()):
outputs_info.append(
Binding(tensor.name, tensor.type, tuple(tensor.shape)))
self.inputs_info = inputs_info
self.outputs_info = outputs_info
self.num_inputs = len(inputs_info)
def forward(self, *inputs):
assert len(inputs) == self.num_inputs
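# onnxruntime consumes host-side NumPy arrays, so inputs are copied to the CPU first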
contiguous_inputs: List[np.ndarray] = [
i.contiguous().cpu().numpy() for i in inputs
]
if not self.is_dynamic:
# make sure input shape is right for static input shape
for i in range(self.num_inputs):
assert contiguous_inputs[i].shape == self.inputs_info[i].shape
outputs = self.session.run([o.name for o in self.outputs_info], {
j.name: contiguous_inputs[i]
for i, j in enumerate(self.inputs_info)
})
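# wrap the host outputs back into torch tensors on the requested device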
return tuple(torch.from_numpy(o).to(self.device) for o in outputs)
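ORTWrapper keeps the same call signature but round-trips through NumPy: inputs are copied to the host for onnxruntime, and outputs are wrapped back into torch tensors on self.device, which is the GPU-tensor fix noted in the commit message. A sketch under the same assumptions as above:

import torch
from projects.easydeploy.model import ORTWrapper

model = ORTWrapper('yolov5s.onnx', device='cuda:0')  # hypothetical path
num_dets, boxes, scores, labels = model(torch.randn(1, 3, 640, 640))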

View File

@@ -6,9 +6,13 @@ try:
import tensorrt as trt
except Exception:
trt = None
import warnings
import numpy as np
import torch
warnings.filterwarnings(action='ignore', category=DeprecationWarning)
class EngineBuilder:

View File

@@ -90,7 +90,7 @@ def main():
iou_threshold=args.iou_threshold,
score_threshold=args.score_threshold,
backend=args.backend)
output_names = ['num_det', 'det_boxes', 'det_scores', 'det_classes']
output_names = ['num_dets', 'boxes', 'scores', 'labels']
baseModel = build_model_from_cfg(args.config, args.checkpoint, args.device)
deploy_model = DeployModel(

View File

@@ -0,0 +1,148 @@
# Copyright (c) OpenMMLab. All rights reserved.
from projects.easydeploy.model import ORTWrapper, TRTWrapper # isort:skip
import os
import random
from argparse import ArgumentParser
import cv2
import mmcv
import numpy as np
import torch
from mmcv.transforms import Compose
from mmdet.utils import get_test_pipeline_cfg
from mmengine.config import Config
from mmengine.utils import ProgressBar
from mmyolo.utils import register_all_modules
from mmyolo.utils.misc import get_file_list
def parse_args():
parser = ArgumentParser()
parser.add_argument(
'img', help='Image path, include image file, dir and URL.')
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument(
'--out-dir', default='./output', help='Path to output file')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--show', action='store_true', help='Show the detection results')
args = parser.parse_args()
return args
def preprocess(config):
data_preprocess = config.get('model', {}).get('data_preprocessor', {})
mean = data_preprocess.get('mean', [0., 0., 0.])
std = data_preprocess.get('std', [1., 1., 1.])
mean = torch.tensor(mean, dtype=torch.float32).reshape(1, 3, 1, 1)
std = torch.tensor(std, dtype=torch.float32).reshape(1, 3, 1, 1)
class PreProcess(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
x = x[None].float()
x -= mean.to(x.device)
x /= std.to(x.device)
return x
return PreProcess().eval()
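# e.g. preprocess(cfg)(torch.from_numpy(rgb).permute(2, 0, 1)) turns a CHW uint8
# image tensor into a normalized float NCHW tensor ready for the backend wrapper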
def main():
args = parse_args()
# register all modules in mmdet into the registries
register_all_modules()
colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(1000)]
# build the model from a config file and a checkpoint file
if args.checkpoint.endswith('.onnx'):
model = ORTWrapper(args.checkpoint, args.device)
elif args.checkpoint.endswith('.engine') or args.checkpoint.endswith(
'.plan'):
model = TRTWrapper(args.checkpoint, args.device)
else:
raise NotImplementedError
model.to(args.device)
cfg = Config.fromfile(args.config)
test_pipeline = get_test_pipeline_cfg(cfg)
test_pipeline[0].type = 'mmdet.LoadImageFromNDArray'
test_pipeline = Compose(test_pipeline)
pre_pipeline = preprocess(cfg)
if not os.path.exists(args.out_dir) and not args.show:
os.mkdir(args.out_dir)
# get file list
files, source_type = get_file_list(args.img)
# start detector inference
progress_bar = ProgressBar(len(files))
for i, file in enumerate(files):
# result = inference_detector(model, file)
bgr = mmcv.imread(file)
rgb = mmcv.imconvert(bgr, 'bgr', 'rgb')
data, samples = test_pipeline(dict(img=rgb, img_id=i)).values()
pad_param = samples.get('pad_param',
np.array([0, 0, 0, 0], dtype=np.float32))
h, w = samples.get('ori_shape', rgb.shape[:2])
pad_param = torch.asarray(
[pad_param[2], pad_param[0], pad_param[2], pad_param[0]],
device=args.device)
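# scale_factor is (w_scale, h_scale); repeating it gives (w, h, w, h) to divide xyxy boxes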
scale_factor = samples.get('scale_factor', [1., 1])
scale_factor = torch.asarray(scale_factor * 2, device=args.device)
data = pre_pipeline(data).to(args.device)
result = model(data)
if source_type['is_dir']:
filename = os.path.relpath(file, args.img).replace('/', '_')
else:
filename = os.path.basename(file)
out_file = None if args.show else os.path.join(args.out_dir, filename)
# Get candidate predict info by num_dets
num_dets, bboxes, scores, labels = result
scores = scores[0, :num_dets]
bboxes = bboxes[0, :num_dets]
labels = labels[0, :num_dets]
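# map boxes from the padded network input back to the original image:
# subtract the (left, top) letterbox padding, then undo the resize scale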
bboxes -= pad_param
bboxes /= scale_factor
bboxes[:, 0::2].clamp_(0, w)
bboxes[:, 1::2].clamp_(0, h)
bboxes = bboxes.round().int()
for (bbox, score, label) in zip(bboxes, scores, labels):
bbox = bbox.tolist()
color = colors[label]
name = f'cls:{label}_score:{score:0.4f}'
cv2.rectangle(bgr, bbox[:2], bbox[2:], color, 2)
cv2.putText(
bgr,
name, (bbox[0], bbox[1] - 2),
cv2.FONT_HERSHEY_SIMPLEX,
0.75, [225, 255, 255],
thickness=2)
if args.show:
mmcv.imshow(bgr, 'result', 0)
else:
mmcv.imwrite(bgr, out_file)
progress_bar.update()
if __name__ == '__main__':
main()