mirror of https://github.com/open-mmlab/mmyolo.git
[Feature] Support PyTorch model forward for TensorRT inference (#377)
* Support pytorch model forward for TensorRT
* Add ort wrapper
* Fix import
* Add deploy image-demo
* raise NotImplementedError
* Fix onnxruntime output gpu tensor

Branch: pull/413/head
parent 66c80e91e1
commit a44495868d
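For context before the diffs: a minimal usage sketch of the two wrappers this commit introduces. The artifact names are placeholders, and the four-way unpacking assumes the detection outputs renamed in this commit (num_dets, boxes, scores, labels):

import torch

from projects.easydeploy.model import ORTWrapper, TRTWrapper

# Placeholder artifacts; build/export these with the easydeploy tools first.
trt_model = TRTWrapper('yolov5s.engine', device='cuda:0')
ort_model = ORTWrapper('yolov5s.onnx', device='cuda:0')

x = torch.randn(1, 3, 640, 640, device='cuda:0')
# Both wrappers are now torch.nn.Module, so inference is a plain forward call.
num_dets, boxes, scores, labels = trt_model(x)
num_dets, boxes, scores, labels = ort_model(x)  # returned as torch tensors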
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .backendwrapper import BackendWrapper
+from .backendwrapper import ORTWrapper, TRTWrapper
 from .model import DeployModel
 
-__all__ = ['DeployModel', 'BackendWrapper']
+__all__ = ['DeployModel', 'TRTWrapper', 'ORTWrapper']
@@ -1,5 +1,5 @@
 import warnings
-from collections import OrderedDict, namedtuple
+from collections import namedtuple
 from functools import partial
 from pathlib import Path
 from typing import List, Optional, Union
@@ -12,167 +12,202 @@ try:
 except Exception:
     trt = None
 import torch
 from numpy import ndarray
 from torch import Tensor
 
 warnings.filterwarnings(action='ignore', category=DeprecationWarning)
 
 
-class BackendWrapper:
+class TRTWrapper(torch.nn.Module):
 
-    def __init__(
-            self,
-            weight: Union[str, Path],
-            device: Optional[Union[str, int, torch.device]] = None) -> None:
+    def __init__(self, weight: Union[str, Path],
+                 device: Optional[torch.device]):
+        super().__init__()
         weight = Path(weight) if isinstance(weight, str) else weight
-        assert weight.exists() and weight.suffix in ('.onnx', '.engine',
-                                                     '.plan')
+        assert weight.exists() and weight.suffix in ('.engine', '.plan')
         if isinstance(device, str):
             device = torch.device(device)
         elif isinstance(device, int):
             device = torch.device(f'cuda:{device}')
         self.weight = weight
         self.device = device
-        self.__build_model()
-        self.__init_runtime()
-        self.__warm_up(10)
+        self.stream = torch.cuda.Stream(device=device)
+        self.__init_engine()
+        self.__init_bindings()
 
-    def __build_model(self) -> None:
-        model_info = dict()
-        num_input = num_output = 0
-        names = []
-        is_dynamic = False
-        if self.weight.suffix == '.onnx':
-            model_info['backend'] = 'ONNXRuntime'
-            providers = ['CPUExecutionProvider']
-            if 'cuda' in self.device.type:
-                providers.insert(0, 'CUDAExecutionProvider')
-            model = onnxruntime.InferenceSession(
-                str(self.weight), providers=providers)
-            for i, tensor in enumerate(model.get_inputs()):
-                model_info[tensor.name] = dict(
-                    shape=tensor.shape, dtype=tensor.type)
-                num_input += 1
-                names.append(tensor.name)
-                is_dynamic |= any(
-                    map(lambda x: isinstance(x, str), tensor.shape))
-            for i, tensor in enumerate(model.get_outputs()):
-                model_info[tensor.name] = dict(
-                    shape=tensor.shape, dtype=tensor.type)
-                num_output += 1
-                names.append(tensor.name)
-        else:
-            model_info['backend'] = 'TensorRT'
-            logger = trt.Logger(trt.Logger.ERROR)
-            trt.init_libnvinfer_plugins(logger, namespace='')
-            with trt.Runtime(logger) as runtime:
-                model = runtime.deserialize_cuda_engine(
-                    self.weight.read_bytes())
-            profile_shape = []
-            for i in range(model.num_bindings):
-                name = model.get_binding_name(i)
-                shape = tuple(model.get_binding_shape(i))
-                dtype = trt.nptype(model.get_binding_dtype(i))
-                is_dynamic |= (-1 in shape)
-                if model.binding_is_input(i):
-                    num_input += 1
-                    profile_shape.append(model.get_profile_shape(i, 0))
-                else:
-                    num_output += 1
-                model_info[name] = dict(shape=shape, dtype=dtype)
-                names.append(name)
-            model_info['profile_shape'] = profile_shape
-
-        self.num_input = num_input
-        self.num_output = num_output
-        self.names = names
-        self.is_dynamic = is_dynamic
-        self.model = model
-        self.model_info = model_info
-
-    def __init_runtime(self) -> None:
-        bindings = OrderedDict()
-        Binding = namedtuple('Binding',
-                             ('name', 'dtype', 'shape', 'data', 'ptr'))
-        if self.model_info['backend'] == 'TensorRT':
-            context = self.model.create_execution_context()
-            for name in self.names:
-                shape, dtype = self.model_info[name].values()
-                if self.is_dynamic:
-                    cpu_tensor, gpu_tensor, ptr = None, None, None
-                else:
-                    cpu_tensor = np.empty(shape, dtype=np.dtype(dtype))
-                    gpu_tensor = torch.from_numpy(cpu_tensor).to(self.device)
-                    ptr = int(gpu_tensor.data_ptr())
-                bindings[name] = Binding(name, dtype, shape, gpu_tensor, ptr)
-        else:
-            output_names = []
-            for i, name in enumerate(self.names):
-                if i >= self.num_input:
-                    output_names.append(name)
-                shape, dtype = self.model_info[name].values()
-                bindings[name] = Binding(name, dtype, shape, None, None)
-            context = partial(self.model.run, output_names)
-        self.addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
-        self.bindings = bindings
-        self.context = context
-
-    def __infer(
-            self, inputs: List[Union[ndarray,
-                                     Tensor]]) -> List[Union[ndarray, Tensor]]:
-        assert len(inputs) == self.num_input
-        if self.model_info['backend'] == 'TensorRT':
-            outputs = []
-            for i, (name, gpu_input) in enumerate(
-                    zip(self.names[:self.num_input], inputs)):
-                if self.is_dynamic:
-                    self.context.set_binding_shape(i, gpu_input.shape)
-                self.addrs[name] = gpu_input.data_ptr()
-
-            for i, name in enumerate(self.names[self.num_input:]):
-                i += self.num_input
-                if self.is_dynamic:
-                    shape = tuple(self.context.get_binding_shape(i))
-                    dtype = self.bindings[name].dtype
-                    cpu_tensor = np.empty(shape, dtype=np.dtype(dtype))
-                    out = torch.from_numpy(cpu_tensor).to(self.device)
-                    self.addrs[name] = out.data_ptr()
-                else:
-                    out = self.bindings[name].data
-                outputs.append(out)
-            assert self.context.execute_v2(list(
-                self.addrs.values())), 'Infer fault'
-        else:
-            input_feed = {
-                name: inputs[i]
-                for i, name in enumerate(self.names[:self.num_input])
-            }
-            outputs = self.context(input_feed)
-        return outputs
-
-    def __warm_up(self, n=10) -> None:
-        for _ in range(n):
-            _tmp = []
-            if self.model_info['backend'] == 'TensorRT':
-                for i, name in enumerate(self.names[:self.num_input]):
-                    if self.is_dynamic:
-                        shape = self.model_info['profile_shape'][i][1]
-                        dtype = self.bindings[name].dtype
-                        cpu_tensor = np.empty(shape, dtype=np.dtype(dtype))
-                        _tmp.append(
-                            torch.from_numpy(cpu_tensor).to(self.device))
-                    else:
-                        _tmp.append(self.bindings[name].data)
-            else:
-                print('Please warm up ONNXRuntime model by yourself')
-                print("So this model doesn't warm up")
-                return
-            _ = self.__infer(_tmp)
-
-    def __call__(
-            self, inputs: Union[List, Tensor,
-                                ndarray]) -> List[Union[Tensor, ndarray]]:
-        if not isinstance(inputs, list):
-            inputs = [inputs]
-        outputs = self.__infer(inputs)
-        return outputs
+    def __init_engine(self):
+        logger = trt.Logger(trt.Logger.ERROR)
+        self.log = partial(logger.log, trt.Logger.ERROR)
+        trt.init_libnvinfer_plugins(logger, namespace='')
+        self.logger = logger
+        with trt.Runtime(logger) as runtime:
+            model = runtime.deserialize_cuda_engine(self.weight.read_bytes())
+
+        context = model.create_execution_context()
+
+        names = [model.get_binding_name(i) for i in range(model.num_bindings)]
+
+        num_inputs, num_outputs = 0, 0
+
+        for i in range(model.num_bindings):
+            if model.binding_is_input(i):
+                num_inputs += 1
+            else:
+                num_outputs += 1
+
+        self.is_dynamic = -1 in model.get_binding_shape(0)
+
+        self.model = model
+        self.context = context
+        self.input_names = names[:num_inputs]
+        self.output_names = names[num_inputs:]
+        self.num_inputs = num_inputs
+        self.num_outputs = num_outputs
+        self.num_bindings = num_inputs + num_outputs
+        self.bindings: List[int] = [0] * self.num_bindings
+
+    def __init_bindings(self):
+        Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))
+        inputs_info = []
+        outputs_info = []
+
+        for i, name in enumerate(self.input_names):
+            assert self.model.get_binding_name(i) == name
+            dtype = self.torch_dtype_from_trt(self.model.get_binding_dtype(i))
+            shape = tuple(self.model.get_binding_shape(i))
+            inputs_info.append(Binding(name, dtype, shape))
+
+        for i, name in enumerate(self.output_names):
+            i += self.num_inputs
+            assert self.model.get_binding_name(i) == name
+            dtype = self.torch_dtype_from_trt(self.model.get_binding_dtype(i))
+            shape = tuple(self.model.get_binding_shape(i))
+            outputs_info.append(Binding(name, dtype, shape))
+        self.inputs_info = inputs_info
+        self.outputs_info = outputs_info
+        if not self.is_dynamic:
+            self.output_tensor = [
+                torch.empty(o.shape, dtype=o.dtype, device=self.device)
+                for o in outputs_info
+            ]
+
+    def forward(self, *inputs):
+        assert len(inputs) == self.num_inputs
+
+        contiguous_inputs: List[torch.Tensor] = [
+            i.contiguous() for i in inputs
+        ]
+
+        for i in range(self.num_inputs):
+            self.bindings[i] = contiguous_inputs[i].data_ptr()
+            if self.is_dynamic:
+                self.context.set_binding_shape(
+                    i, tuple(contiguous_inputs[i].shape))
+
+        # create output tensors
+        outputs: List[torch.Tensor] = []
+
+        for i in range(self.num_outputs):
+            j = i + self.num_inputs
+            if self.is_dynamic:
+                shape = tuple(self.context.get_binding_shape(j))
+                output = torch.empty(
+                    size=shape,
+                    # dtype recorded from the engine in __init_bindings
+                    dtype=self.outputs_info[i].dtype,
+                    device=self.device)
+            else:
+                output = self.output_tensor[i]
+            outputs.append(output)
+            self.bindings[j] = output.data_ptr()
+
+        self.context.execute_async_v2(self.bindings, self.stream.cuda_stream)
+        self.stream.synchronize()
+
+        return tuple(outputs)
+
+    @staticmethod
+    def torch_dtype_from_trt(dtype: trt.DataType) -> torch.dtype:
+        """Convert TensorRT data types to PyTorch data types.
+
+        Args:
+            dtype (TRTDataType): A TensorRT data type.
+        Returns:
+            The equivalent PyTorch data type.
+        """
+        if dtype == trt.int8:
+            return torch.int8
+        elif trt.__version__ >= '7.0' and dtype == trt.bool:
+            return torch.bool
+        elif dtype == trt.int32:
+            return torch.int32
+        elif dtype == trt.float16:
+            return torch.float16
+        elif dtype == trt.float32:
+            return torch.float32
+        else:
+            raise TypeError(f'{dtype} is not supported by torch')
+
+
+class ORTWrapper(torch.nn.Module):
+
+    def __init__(self, weight: Union[str, Path],
+                 device: Optional[torch.device]):
+        super().__init__()
+        weight = Path(weight) if isinstance(weight, str) else weight
+        assert weight.exists() and weight.suffix == '.onnx'
+
+        if isinstance(device, str):
+            device = torch.device(device)
+        elif isinstance(device, int):
+            device = torch.device(f'cuda:{device}')
+        self.weight = weight
+        self.device = device
+        self.__init_session()
+        self.__init_bindings()
+
+    def __init_session(self):
+        providers = ['CPUExecutionProvider']
+        if 'cuda' in self.device.type:
+            providers.insert(0, 'CUDAExecutionProvider')
+
+        session = onnxruntime.InferenceSession(
+            str(self.weight), providers=providers)
+        self.session = session
+
+    def __init_bindings(self):
+        Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))
+        inputs_info = []
+        outputs_info = []
+        self.is_dynamic = False
+        for i, tensor in enumerate(self.session.get_inputs()):
+            if any(not isinstance(i, int) for i in tensor.shape):
+                self.is_dynamic = True
+            inputs_info.append(
+                Binding(tensor.name, tensor.type, tuple(tensor.shape)))
+
+        for i, tensor in enumerate(self.session.get_outputs()):
+            outputs_info.append(
+                Binding(tensor.name, tensor.type, tuple(tensor.shape)))
+        self.inputs_info = inputs_info
+        self.outputs_info = outputs_info
+        self.num_inputs = len(inputs_info)
+
+    def forward(self, *inputs):
+        assert len(inputs) == self.num_inputs
+
+        contiguous_inputs: List[np.ndarray] = [
+            i.contiguous().cpu().numpy() for i in inputs
+        ]
+
+        if not self.is_dynamic:
+            # make sure input shape is right for static input shape
+            for i in range(self.num_inputs):
+                assert contiguous_inputs[i].shape == self.inputs_info[i].shape
+
+        outputs = self.session.run([o.name for o in self.outputs_info], {
+            j.name: contiguous_inputs[i]
+            for i, j in enumerate(self.inputs_info)
+        })
+
+        return tuple(torch.from_numpy(o).to(self.device) for o in outputs)
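A note on the dynamic-shape path in TRTWrapper.forward above: for engines built with -1 dimensions, the wrapper calls context.set_binding_shape() for every input and allocates output tensors per call, while static engines reuse the preallocated output_tensor list. A minimal sketch, assuming a dynamic-batch engine file (the name is a placeholder):

import torch

from projects.easydeploy.model import TRTWrapper

# Illustrative only: batch size varies across calls against one engine.
model = TRTWrapper('yolov5s_dynamic.engine', device='cuda:0')
for batch in (1, 2, 4):
    x = torch.randn(batch, 3, 640, 640, device='cuda:0')
    outputs = model(x)  # set_binding_shape() runs once per input binding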
@@ -6,9 +6,13 @@ try:
     import tensorrt as trt
 except Exception:
     trt = None
+import warnings
 
 import numpy as np
 import torch
 
+warnings.filterwarnings(action='ignore', category=DeprecationWarning)
+
 
 class EngineBuilder:
@@ -90,7 +90,7 @@ def main():
         iou_threshold=args.iou_threshold,
         score_threshold=args.score_threshold,
         backend=args.backend)
-    output_names = ['num_det', 'det_boxes', 'det_scores', 'det_classes']
+    output_names = ['num_dets', 'boxes', 'scores', 'labels']
     baseModel = build_model_from_cfg(args.config, args.checkpoint, args.device)
 
     deploy_model = DeployModel(
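The renamed output names line up with the four end2end detection outputs consumed by the new image demo below. Presumably output_names is forwarded to torch.onnx.export later in this script; a sketch under that assumption (deploy_model, fake_input, and the file name are placeholders, not from the commit):

# Assumption: output_names reaches torch.onnx.export, so the exported
# graph exposes the four detection outputs under these names.
torch.onnx.export(
    deploy_model,
    fake_input,  # placeholder example input tensor
    'end2end.onnx',  # placeholder output path
    input_names=['images'],
    output_names=['num_dets', 'boxes', 'scores', 'labels'])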
@@ -0,0 +1,148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from projects.easydeploy.model import ORTWrapper, TRTWrapper  # isort:skip
+import os
+import random
+from argparse import ArgumentParser
+
+import cv2
+import mmcv
+import numpy as np
+import torch
+from mmcv.transforms import Compose
+from mmdet.utils import get_test_pipeline_cfg
+from mmengine.config import Config
+from mmengine.utils import ProgressBar
+
+from mmyolo.utils import register_all_modules
+from mmyolo.utils.misc import get_file_list
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument(
+        'img', help='Image path, include image file, dir and URL.')
+    parser.add_argument('config', help='Config file')
+    parser.add_argument('checkpoint', help='Checkpoint file')
+    parser.add_argument(
+        '--out-dir', default='./output', help='Path to output file')
+    parser.add_argument(
+        '--device', default='cuda:0', help='Device used for inference')
+    parser.add_argument(
+        '--show', action='store_true', help='Show the detection results')
+    args = parser.parse_args()
+    return args
+
+
+def preprocess(config):
+    data_preprocess = config.get('model', {}).get('data_preprocessor', {})
+    mean = data_preprocess.get('mean', [0., 0., 0.])
+    std = data_preprocess.get('std', [1., 1., 1.])
+    mean = torch.tensor(mean, dtype=torch.float32).reshape(1, 3, 1, 1)
+    std = torch.tensor(std, dtype=torch.float32).reshape(1, 3, 1, 1)
+
+    class PreProcess(torch.nn.Module):
+
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            x = x[None].float()
+            x -= mean.to(x.device)
+            x /= std.to(x.device)
+            return x
+
+    return PreProcess().eval()
+
+
+def main():
+    args = parse_args()
+
+    # register all modules in mmdet into the registries
+    register_all_modules()
+
+    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(1000)]
+
+    # build the model from a config file and a checkpoint file
+    if args.checkpoint.endswith('.onnx'):
+        model = ORTWrapper(args.checkpoint, args.device)
+    elif args.checkpoint.endswith('.engine') or args.checkpoint.endswith(
+            '.plan'):
+        model = TRTWrapper(args.checkpoint, args.device)
+    else:
+        raise NotImplementedError
+
+    model.to(args.device)
+
+    cfg = Config.fromfile(args.config)
+
+    test_pipeline = get_test_pipeline_cfg(cfg)
+    test_pipeline[0].type = 'mmdet.LoadImageFromNDArray'
+    test_pipeline = Compose(test_pipeline)
+
+    pre_pipeline = preprocess(cfg)
+
+    if not os.path.exists(args.out_dir) and not args.show:
+        os.mkdir(args.out_dir)
+
+    # get file list
+    files, source_type = get_file_list(args.img)
+
+    # start detector inference
+    progress_bar = ProgressBar(len(files))
+    for i, file in enumerate(files):
+        # result = inference_detector(model, file)
+
+        bgr = mmcv.imread(file)
+        rgb = mmcv.imconvert(bgr, 'bgr', 'rgb')
+        data, samples = test_pipeline(dict(img=rgb, img_id=i)).values()
+        pad_param = samples.get('pad_param',
+                                np.array([0, 0, 0, 0], dtype=np.float32))
+        h, w = samples.get('ori_shape', rgb.shape[:2])
+        pad_param = torch.asarray(
+            [pad_param[2], pad_param[0], pad_param[2], pad_param[0]],
+            device=args.device)
+        scale_factor = samples.get('scale_factor', [1., 1])
+        scale_factor = torch.asarray(scale_factor * 2, device=args.device)
+        data = pre_pipeline(data).to(args.device)
+
+        result = model(data)
+        if source_type['is_dir']:
+            filename = os.path.relpath(file, args.img).replace('/', '_')
+        else:
+            filename = os.path.basename(file)
+        out_file = None if args.show else os.path.join(args.out_dir, filename)
+
+        # Get candidate predict info by num_dets
+        num_dets, bboxes, scores, labels = result
+        scores = scores[0, :num_dets]
+        bboxes = bboxes[0, :num_dets]
+        labels = labels[0, :num_dets]
+        bboxes -= pad_param
+        bboxes /= scale_factor
+
+        bboxes[:, 0::2].clamp_(0, w)
+        bboxes[:, 1::2].clamp_(0, h)
+        bboxes = bboxes.round().int()
+
+        for (bbox, score, label) in zip(bboxes, scores, labels):
+            bbox = bbox.tolist()
+            color = colors[label]
+            name = f'cls:{label}_score:{score:0.4f}'
+
+            cv2.rectangle(bgr, bbox[:2], bbox[2:], color, 2)
+            cv2.putText(
+                bgr,
+                name, (bbox[0], bbox[1] - 2),
+                cv2.FONT_HERSHEY_SIMPLEX,
+                0.75, [225, 255, 255],
+                thickness=2)
+
+        if args.show:
+            mmcv.imshow(bgr, 'result', 0)
+        else:
+            mmcv.imwrite(bgr, out_file)
+        progress_bar.update()
+
+
+if __name__ == '__main__':
+    main()
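The preprocess() helper in the new demo folds the config's data_preprocessor mean/std into a tiny nn.Module so normalization runs on device alongside inference. A minimal check of the arithmetic it performs (the mean/std values here are assumptions, not read from a real config):

import torch

# Assumed normalization constants; a real config supplies its own.
mean = torch.tensor([0., 0., 0.]).reshape(1, 3, 1, 1)
std = torch.tensor([255., 255., 255.]).reshape(1, 3, 1, 1)

img = torch.randint(0, 256, (3, 640, 640), dtype=torch.uint8)
x = (img[None].float() - mean) / std  # same ops as PreProcess.forward
assert x.shape == (1, 3, 640, 640) and float(x.max()) <= 1.0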