mirror of https://github.com/open-mmlab/mmyolo.git
214 lines
7.1 KiB
Python
214 lines
7.1 KiB
Python
import warnings
|
|
from collections import namedtuple
|
|
from functools import partial
|
|
from pathlib import Path
|
|
from typing import List, Optional, Union
|
|
|
|
import numpy as np
|
|
import onnxruntime
|
|
|
|
try:
|
|
import tensorrt as trt
|
|
except Exception:
|
|
trt = None
|
|
import torch
|
|
|
|
# Install a global filter that silences every DeprecationWarning.
# NOTE(review): presumably meant to hide noise from the inference backends
# imported above -- confirm which warnings this is intended to suppress.
warnings.filterwarnings(action='ignore', category=DeprecationWarning)
|
|
|
|
|
|
class TRTWrapper(torch.nn.Module):
    """Run a serialized TensorRT engine behind a ``torch.nn.Module`` API.

    The engine file is deserialized once at construction time. Every call
    to :meth:`forward` hands the data pointers of the input tensors to the
    execution context, runs the engine on a dedicated CUDA stream and
    returns the outputs as torch tensors.

    Args:
        weight (str | Path): Path to an existing ``.engine`` or ``.plan``
            file.
        device (str | int | torch.device | None): Device used for the CUDA
            stream and for the output tensors. A string is passed to
            ``torch.device``; an int is interpreted as a CUDA index.
    """

    def __init__(self, weight: Union[str, Path],
                 device: Optional[Union[str, int, torch.device]]):
        super().__init__()
        weight = Path(weight) if isinstance(weight, str) else weight
        assert weight.exists() and weight.suffix in ('.engine', '.plan'), \
            f'{weight} must be an existing .engine or .plan file'
        if isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')
        self.weight = weight
        self.device = device
        self.stream = torch.cuda.Stream(device=device)
        self.__init_engine()
        self.__init_bindings()

    def __init_engine(self):
        """Deserialize the engine and record its binding layout.

        Raises:
            RuntimeError: If TensorRT cannot deserialize the engine file.
        """
        logger = trt.Logger(trt.Logger.ERROR)
        self.log = partial(logger.log, trt.Logger.ERROR)
        trt.init_libnvinfer_plugins(logger, namespace='')
        self.logger = logger
        with trt.Runtime(logger) as runtime:
            model = runtime.deserialize_cuda_engine(self.weight.read_bytes())
        if model is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise RuntimeError(
                f'Failed to deserialize TensorRT engine from {self.weight}')

        context = model.create_execution_context()
        binding_ids = range(model.num_bindings)
        names = [model.get_binding_name(i) for i in binding_ids]
        num_inputs = sum(bool(model.binding_is_input(i)) for i in binding_ids)
        num_outputs = model.num_bindings - num_inputs

        # Only the first binding is inspected to decide whether shapes have
        # to be set at run time.
        self.is_dynamic = -1 in model.get_binding_shape(0)

        self.model = model
        self.context = context
        # Bindings are assumed to list all inputs before all outputs;
        # ``__init_bindings`` asserts that the names line up.
        self.input_names = names[:num_inputs]
        self.output_names = names[num_inputs:]
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_bindings = num_inputs + num_outputs
        # One raw data pointer per binding, refreshed on every forward call.
        self.bindings: List[int] = [0] * self.num_bindings

    def __init_bindings(self):
        """Collect name, torch dtype and shape of every input and output."""
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))

        def describe(index, name):
            assert self.model.get_binding_name(index) == name
            dtype = self.torch_dtype_from_trt(
                self.model.get_binding_dtype(index))
            shape = tuple(self.model.get_binding_shape(index))
            return Binding(name, dtype, shape)

        inputs_info = [
            describe(i, name) for i, name in enumerate(self.input_names)
        ]
        outputs_info = [
            describe(i, name)
            for i, name in enumerate(self.output_names, self.num_inputs)
        ]
        self.inputs_info = inputs_info
        self.outputs_info = outputs_info
        if not self.is_dynamic:
            # Static engines have fixed output shapes, so the output buffers
            # are allocated once and reused by every forward call.
            self.output_tensor = [
                torch.empty(o.shape, dtype=o.dtype, device=self.device)
                for o in outputs_info
            ]

    def forward(self, *inputs):
        """Execute the engine on ``inputs``.

        Args:
            *inputs (torch.Tensor): One tensor per engine input, in binding
                order.

        Returns:
            tuple[torch.Tensor]: One tensor per engine output. For a static
            engine these are the preallocated buffers, so their contents
            are overwritten by the next call.
        """
        assert len(inputs) == self.num_inputs

        contiguous_inputs: List[torch.Tensor] = [
            tensor.contiguous() for tensor in inputs
        ]

        for i, tensor in enumerate(contiguous_inputs):
            self.bindings[i] = tensor.data_ptr()
            if self.is_dynamic:
                self.context.set_binding_shape(i, tuple(tensor.shape))

        # create output tensors
        outputs: List[torch.Tensor] = []
        for i in range(self.num_outputs):
            j = i + self.num_inputs
            if self.is_dynamic:
                # Output shapes are only known after the input shapes have
                # been set on the context above.
                shape = tuple(self.context.get_binding_shape(j))
                output = torch.empty(
                    size=shape,
                    dtype=self.outputs_info[i].dtype,
                    device=self.device)
            else:
                output = self.output_tensor[i]
            outputs.append(output)
            self.bindings[j] = output.data_ptr()

        self.context.execute_async_v2(self.bindings, self.stream.cuda_stream)
        self.stream.synchronize()

        return tuple(outputs)

    @staticmethod
    def torch_dtype_from_trt(dtype: 'trt.DataType') -> torch.dtype:
        """Convert TensorRT data types to PyTorch data types.

        Args:
            dtype (TRTDataType): A TensorRT data type.

        Returns:
            The equivalent PyTorch data type.

        Raises:
            TypeError: If ``dtype`` has no PyTorch counterpart listed here.
        """
        # Compare the major version numerically: as strings, '10.0' sorts
        # before '7.0'.
        major_version = int(trt.__version__.split('.')[0])
        if dtype == trt.int8:
            return torch.int8
        elif major_version >= 7 and dtype == trt.bool:
            return torch.bool
        elif dtype == trt.int32:
            return torch.int32
        elif dtype == trt.float16:
            return torch.float16
        elif dtype == trt.float32:
            return torch.float32
        else:
            raise TypeError(f'{dtype} is not supported by torch')
|
|
|
|
|
|
class ORTWrapper(torch.nn.Module):
    """Run an ONNX model through ONNX Runtime behind a ``torch.nn.Module``.

    Inputs are copied to host memory as numpy arrays, fed to the inference
    session, and the results are returned as torch tensors on ``device``.

    Args:
        weight (str | Path): Path to an existing ``.onnx`` file.
        device (str | int | torch.device | None): Device the outputs are
            moved to; a CUDA device also enables the CUDA execution
            provider. A string is passed to ``torch.device``, an int is
            interpreted as a CUDA index and ``None`` falls back to the CPU.
    """

    def __init__(self, weight: Union[str, Path],
                 device: Optional[Union[str, int, torch.device]]):
        super().__init__()
        weight = Path(weight) if isinstance(weight, str) else weight
        assert weight.exists() and weight.suffix == '.onnx', \
            f'{weight} must be an existing .onnx file'

        if device is None:
            # The signature allows None, but the provider selection needs a
            # concrete device; default to the CPU.
            device = torch.device('cpu')
        elif isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')
        self.weight = weight
        self.device = device
        self.__init_session()
        self.__init_bindings()

    def __init_session(self):
        """Create the inference session with providers matching the device."""
        providers = ['CPUExecutionProvider']
        if 'cuda' in self.device.type:
            # Listed first so it takes priority; the CPU provider stays in
            # the list behind it.
            providers.insert(0, 'CUDAExecutionProvider')

        self.session = onnxruntime.InferenceSession(
            str(self.weight), providers=providers)

    def __init_bindings(self):
        """Collect name, type and shape of every model input and output."""
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))
        inputs_info = [
            Binding(tensor.name, tensor.type, tuple(tensor.shape))
            for tensor in self.session.get_inputs()
        ]
        outputs_info = [
            Binding(tensor.name, tensor.type, tuple(tensor.shape))
            for tensor in self.session.get_outputs()
        ]
        # Any input dimension that is not a plain int marks the model as
        # having dynamic input shapes.
        self.is_dynamic = any(not isinstance(dim, int)
                              for info in inputs_info for dim in info.shape)
        self.inputs_info = inputs_info
        self.outputs_info = outputs_info
        self.num_inputs = len(inputs_info)

    def forward(self, *inputs):
        """Run the session on ``inputs``.

        Args:
            *inputs (torch.Tensor): One tensor per model input, in the order
                reported by the session.

        Returns:
            tuple[torch.Tensor]: One tensor per model output, moved to
            ``self.device``.
        """
        assert len(inputs) == self.num_inputs

        # detach() so tensors that require grad can be converted to numpy.
        contiguous_inputs: List[np.ndarray] = [
            tensor.detach().contiguous().cpu().numpy() for tensor in inputs
        ]

        if not self.is_dynamic:
            # make sure input shape is right for static input shape
            for array, info in zip(contiguous_inputs, self.inputs_info):
                assert array.shape == info.shape, \
                    f'input {info.name} expects shape {info.shape}, ' \
                    f'got {array.shape}'

        feed = {
            info.name: array
            for info, array in zip(self.inputs_info, contiguous_inputs)
        }
        outputs = self.session.run([o.name for o in self.outputs_info], feed)

        return tuple(torch.from_numpy(o).to(self.device) for o in outputs)
|