EasyCV/easycv/toolkit/blade/trt_plugin_utils.py
wuziheng 9aaa600f79
Yolox improve with REPConv/ASFF/TOOD (#154)
* add attention layer and more loss function

* add attention layer and various loss functions

* add siou loss

* add tah,various attention layers, and different loss functions

* add asff sim, gsconv

* blade utils fit faster

* blade optimize for yolox static & fp16

* decode output for yolox control by cfg

* add reparameterize_models for export

* e2e trt_nms plugin export support and numeric test

* split preprocess from end2end+blade, speedup from 17ms->7.2ms

Co-authored-by: zouxinyi0625 <zouxinyi.zxy@alibaba-inc.com>
2022-08-24 18:11:15 +08:00

157 lines
5.0 KiB
Python

# This is a Python wrapper implementation for TensorRT plugins; the original plugin documentation is at
# https://github.com/NVIDIA/TensorRT/tree/main/plugin/
# We use a Python wrapper to build an ONNX-TRTPlugin engine and then wrap it as a jit script module.
# After this, we can replace some of the original model's ops with this plugin during Blade export,
# to speed up the ops that are not well optimized by Blade itself.
# Here we provide a TRTPlugin-EfficientNMS implementation.
import torch
from torch import nn
class TRT8_NMS(torch.autograd.Function):
    """Autograd wrapper for the TensorRT ``EfficientNMS_TRT`` plugin.

    ``forward`` never performs real NMS: it only fabricates placeholder
    tensors with the shapes and dtypes the plugin will produce, which is all
    that ``torch.jit.trace`` / ONNX export need. ``symbolic`` emits the
    actual ``TRT::EfficientNMS_TRT`` node into the exported graph.
    """

    @staticmethod
    def forward(
        ctx,
        boxes,
        scores,
        background_class=-1,
        box_coding=1,
        iou_threshold=0.45,
        max_output_boxes=100,
        plugin_version='1',
        score_activation=0,
        score_threshold=0.25,
    ):
        # Only shapes/dtypes matter here; the values are throwaway randoms.
        bs, _, n_cls = scores.shape
        counts = torch.randint(
            0, max_output_boxes, (bs, 1), dtype=torch.int32)
        kept_boxes = torch.randn(bs, max_output_boxes, 4)
        kept_scores = torch.randn(bs, max_output_boxes)
        kept_classes = torch.randint(
            0, n_cls, (bs, max_output_boxes), dtype=torch.int32)
        return counts, kept_boxes, kept_scores, kept_classes

    @staticmethod
    def symbolic(g,
                 boxes,
                 scores,
                 background_class=-1,
                 box_coding=1,
                 iou_threshold=0.45,
                 max_output_boxes=100,
                 plugin_version='1',
                 score_activation=0,
                 score_threshold=0.25):
        # Emit the EfficientNMS_TRT node; `_i`/`_f`/`_s` suffixes encode the
        # ONNX attribute types (int / float / string).
        nms_outputs = g.op(
            'TRT::EfficientNMS_TRT',
            boxes,
            scores,
            background_class_i=background_class,
            box_coding_i=box_coding,
            iou_threshold_f=iou_threshold,
            max_output_boxes_i=max_output_boxes,
            plugin_version_s=plugin_version,
            score_activation_i=score_activation,
            score_threshold_f=score_threshold,
            outputs=4)
        num_det, det_boxes, det_scores, det_classes = nms_outputs
        return num_det, det_boxes, det_scores, det_classes
class ONNX_TRT8(nn.Module):
    '''onnx module with TensorRT NMS operation.

    Wraps raw detector output of shape ``(batch, num_boxes, 5 + num_classes)``
    (4 box coords, objectness, per-class scores — presumably YOLOX-style;
    verify against caller) and routes it through :class:`TRT8_NMS` so the
    exported graph ends in an ``EfficientNMS_TRT`` node.
    '''

    def __init__(self,
                 max_obj=100,
                 iou_thres=0.45,
                 score_thres=0.25,
                 max_wh=None,
                 device=None):
        """
        Args:
            max_obj (int): maximum detections kept per image.
            iou_thres (float): NMS IoU threshold.
            score_thres (float): minimum score for a box to be kept.
            max_wh: must be ``None``; kept only for signature compatibility.
            device: torch device for the module (defaults to CPU).
        """
        super().__init__()
        assert max_wh is None
        self.device = device if device else torch.device('cpu')
        # BUGFIX: trailing commas previously turned these two attributes into
        # 1-element tuples; EfficientNMS_TRT expects plain int attributes.
        self.background_class = -1
        self.box_coding = 1
        self.iou_threshold = iou_thres
        self.max_obj = max_obj
        self.plugin_version = '1'
        self.score_activation = 0
        self.score_threshold = score_thres

    def forward(self, x):
        # Split prediction tensor into boxes, objectness, class scores.
        box = x[:, :, :4]
        conf = x[:, :, 4:5]
        score = x[:, :, 5:]
        # Fold objectness into class scores before NMS.
        # NOTE(review): this is an in-place multiply on a view of `x`,
        # so the caller's tensor is mutated — confirm callers don't reuse x.
        score *= conf
        num_det, det_boxes, det_scores, det_classes = TRT8_NMS.apply(
            box, score, self.background_class, self.box_coding,
            self.iou_threshold, self.max_obj, self.plugin_version,
            self.score_activation, self.score_threshold)
        return num_det, det_boxes, det_scores, det_classes
def create_tensorrt_efficientnms(example_scores,
                                 iou_thres=0.45,
                                 score_thres=0.25):
    """Build a TensorRT EfficientNMS engine and wrap it as a jit module.

    Traces :class:`ONNX_TRT8` with ``example_scores``, exports the trace to
    ONNX (ending in a single ``EfficientNMS_TRT`` node), converts the ONNX
    proto to a TensorRT engine via torch_blade, and returns a
    ``torch.jit.script`` module whose ``forward`` executes that engine.

    Args:
        example_scores (torch.Tensor): example prediction tensor used for
            tracing/export — presumably ``(batch, num_boxes, 5 + num_classes)``
            to match ONNX_TRT8's slicing; the engine is specialized to this
            static shape (TODO confirm shape contract with callers).
        iou_thres (float): NMS IoU threshold baked into the engine.
        score_thres (float): score threshold baked into the engine.

    Returns:
        torch.jit.ScriptModule: module returning the four EfficientNMS
        outputs (num_det, detection_boxes, detection_scores,
        detection_classes) for an input shaped like ``example_scores``.
    """
    from torch_blade import tensorrt
    import torch_blade._torch_blade._backends as backends
    import io
    model = torch.jit.trace(
        ONNX_TRT8(iou_thres=iou_thres, score_thres=score_thres),
        example_scores)
    example_outputs = model(example_scores)
    input_names = ['input']
    output_names = [
        'num_det', 'detection_boxes', 'detection_scores', 'detection_classes'
    ]
    # Export to an in-memory ONNX proto instead of a file on disk.
    with io.BytesIO() as onnx_proto_f:
        # NOTE(review): `example_outputs` was removed from torch.onnx.export
        # in newer PyTorch releases; this call assumes an older torch version.
        torch.onnx.export(
            model,
            example_scores,
            onnx_proto_f,
            input_names=input_names,
            output_names=output_names,
            example_outputs=example_outputs)
        onnx_proto = onnx_proto_f.getvalue()

    def _copy_meta(data, name, dtype, sizes):
        # Populate a torch_blade TensorInfo with name/dtype/shape metadata;
        # torch_blade only distinguishes 'Float' vs 'Int' here.
        data.name = name
        if dtype.is_floating_point:
            data.dtype = 'Float'
        else:
            data.dtype = 'Int'
        data.sizes = sizes
        return data

    # Describe the engine's I/O signature, then compile ONNX -> TensorRT.
    state = backends.EngineState()
    state.inputs = [
        _copy_meta(backends.TensorInfo(), name, tensor.dtype,
                   list(tensor.shape))
        for name, tensor in zip(input_names, [example_scores])
    ]
    # Output sizes are left empty ([]) — presumably torch_blade infers them
    # from the engine; TODO confirm against torch_blade docs.
    state.outputs = [
        _copy_meta(backends.TensorInfo(), name, tensor.dtype, [])
        for name, tensor in zip(output_names, example_outputs)
    ]
    state = tensorrt.cvt_onnx_to_tensorrt(onnx_proto, state, [], dict())

    class Model(torch.nn.Module):
        """Thin jit-scriptable wrapper that executes the built TRT engine."""

        def __init__(self, state):
            super().__init__()
            self._trt_engine_ext = backends.create_engine(state)

        def forward(self, x):
            return self._trt_engine_ext.execute([x])

    trt_ext = torch.jit.script(Model(state))
    return trt_ext