mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
* add attention layer and more loss function * add attention layer and various loss functions * add siou loss * add tah,various attention layers, and different loss functions * add asff sim, gsconv * blade utils fit faster * blade optimize for yolox static & fp16 * decode output for yolox control by cfg * add reparameterize_models for export * e2e trt_nms plugin export support and numeric test * split preprocess from end2end+blade, speedup from 17ms->7.2ms Co-authored-by: zouxinyi0625 <zouxinyi.zxy@alibaba-inc.com>
157 lines
5.0 KiB
Python
157 lines
5.0 KiB
Python
# This is a TensorRT Plugin Python wrapper implementation; the original plugin
# documentation is at
# https://github.com/NVIDIA/TensorRT/tree/main/plugin/
# We use a Python wrapper to build an ONNX-TRTPlugin engine and then wrap it as
# a jit script module. After this, we can replace some of the original model's
# ops with this plugin during Blade export, to speed up those ops that are not
# well optimized by Blade itself.
# Here we provide a TRTPlugin-EfficientNMS implementation.
|
import torch
|
|
from torch import nn
|
|
|
|
|
|
class TRT8_NMS(torch.autograd.Function):
    """Symbolic wrapper around the TensorRT ``EfficientNMS_TRT`` plugin.

    ``forward`` only fabricates dummy outputs with the correct shapes and
    dtypes so the op can be traced; the real NMS work happens inside TensorRT
    once the exported graph is built. ``symbolic`` emits the actual plugin
    node during ONNX export.
    """

    @staticmethod
    def forward(
        ctx,
        boxes,
        scores,
        background_class=-1,
        box_coding=1,
        iou_threshold=0.45,
        max_output_boxes=100,
        plugin_version='1',
        score_activation=0,
        score_threshold=0.25,
    ):
        # Fabricate outputs matching the plugin's signature: values are
        # irrelevant for tracing, only shape and dtype matter.
        batch, _, n_classes = scores.shape
        per_image = (batch, max_output_boxes)
        num_det = torch.randint(
            0, max_output_boxes, (batch, 1), dtype=torch.int32)
        det_boxes = torch.randn(batch, max_output_boxes, 4)
        det_scores = torch.randn(*per_image)
        det_classes = torch.randint(
            0, n_classes, per_image, dtype=torch.int32)
        return num_det, det_boxes, det_scores, det_classes

    @staticmethod
    def symbolic(g,
                 boxes,
                 scores,
                 background_class=-1,
                 box_coding=1,
                 iou_threshold=0.45,
                 max_output_boxes=100,
                 plugin_version='1',
                 score_activation=0,
                 score_threshold=0.25):
        # Attribute-name suffixes encode the ONNX attribute type:
        # _i = int, _f = float, _s = string.
        plugin_attrs = dict(
            background_class_i=background_class,
            box_coding_i=box_coding,
            iou_threshold_f=iou_threshold,
            max_output_boxes_i=max_output_boxes,
            plugin_version_s=plugin_version,
            score_activation_i=score_activation,
            score_threshold_f=score_threshold,
        )
        num_det, det_boxes, det_scores, det_classes = g.op(
            'TRT::EfficientNMS_TRT', boxes, scores, outputs=4, **plugin_attrs)
        return num_det, det_boxes, det_scores, det_classes
|
|
|
|
|
|
class ONNX_TRT8(nn.Module):
    '''onnx module with TensorRT NMS operation.'''

    def __init__(self,
                 max_obj=100,
                 iou_thres=0.45,
                 score_thres=0.25,
                 max_wh=None,
                 device=None):
        """Wrap decoded detections with the TensorRT EfficientNMS plugin op.

        Args:
            max_obj: maximum number of detections kept per image.
            iou_thres: IoU threshold used by NMS.
            score_thres: score threshold used by NMS.
            max_wh: must be None; kept for signature compatibility with the
                non-TRT export path.
            device: device to associate with the module (defaults to CPU).
        """
        super().__init__()
        assert max_wh is None
        self.device = device if device else torch.device('cpu')
        # BUGFIX: these two assignments previously ended with a trailing
        # comma, which made the attributes 1-element tuples ((-1,) and (1,))
        # instead of plain ints; the tuples were then forwarded as the
        # plugin's int attributes. They must be scalars.
        self.background_class = -1
        self.box_coding = 1
        self.iou_threshold = iou_thres
        self.max_obj = max_obj
        self.plugin_version = '1'
        self.score_activation = 0
        self.score_threshold = score_thres

    def forward(self, x):
        """Split raw predictions and apply the NMS plugin op.

        Args:
            x: prediction tensor sliced as [..., :4] boxes, [..., 4:5]
                objectness, [..., 5:] per-class scores.

        Returns:
            Tuple of (num_det, det_boxes, det_scores, det_classes).
        """
        box = x[:, :, :4]
        conf = x[:, :, 4:5]
        score = x[:, :, 5:]
        # Class scores are conditioned on objectness before NMS.
        score *= conf
        num_det, det_boxes, det_scores, det_classes = TRT8_NMS.apply(
            box, score, self.background_class, self.box_coding,
            self.iou_threshold, self.max_obj, self.plugin_version,
            self.score_activation, self.score_threshold)
        return num_det, det_boxes, det_scores, det_classes
|
|
|
|
|
|
def create_tensorrt_efficientnms(example_scores,
                                 iou_thres=0.45,
                                 score_thres=0.25):
    """Build a scripted TensorRT EfficientNMS engine module.

    Traces ``ONNX_TRT8`` with ``example_scores``, exports the trace to ONNX,
    converts the ONNX proto to a TensorRT engine via torch_blade, and wraps
    the engine in a scriptable ``nn.Module`` whose ``forward`` executes it.

    Args:
        example_scores: example decoded-prediction tensor used both for
            tracing and for recording the engine's input meta
            (name / dtype / shape).
        iou_thres: IoU threshold handed to the NMS plugin.
        score_thres: score threshold handed to the NMS plugin.

    Returns:
        A ``torch.jit.ScriptModule`` that runs the TensorRT engine on its
        single input tensor.
    """
    from torch_blade import tensorrt
    import torch_blade._torch_blade._backends as backends
    import io

    # Trace the NMS wrapper so it can be exported to ONNX below.
    model = torch.jit.trace(
        ONNX_TRT8(iou_thres=iou_thres, score_thres=score_thres),
        example_scores)
    example_outputs = model(example_scores)

    input_names = ['input']
    output_names = [
        'num_det', 'detection_boxes', 'detection_scores', 'detection_classes'
    ]
    with io.BytesIO() as onnx_proto_f:
        # NOTE(review): the `example_outputs` kwarg was removed from
        # torch.onnx.export in newer torch releases — this call assumes an
        # older torch version; confirm against the pinned dependency.
        torch.onnx.export(
            model,
            example_scores,
            onnx_proto_f,
            input_names=input_names,
            output_names=output_names,
            example_outputs=example_outputs)
        onnx_proto = onnx_proto_f.getvalue()

    def _copy_meta(data, name, dtype, sizes):
        # Fill a torch_blade TensorInfo in place; only the Float/Int
        # distinction is recorded for the dtype.
        data.name = name
        if dtype.is_floating_point:
            data.dtype = 'Float'
        else:
            data.dtype = 'Int'
        data.sizes = sizes
        return data

    # Describe engine I/O so the ONNX->TensorRT conversion knows the
    # bindings' names, dtypes and (for inputs) shapes.
    state = backends.EngineState()
    state.inputs = [
        _copy_meta(backends.TensorInfo(), name, tensor.dtype,
                   list(tensor.shape))
        for name, tensor in zip(input_names, [example_scores])
    ]
    state.outputs = [
        _copy_meta(backends.TensorInfo(), name, tensor.dtype, [])
        for name, tensor in zip(output_names, example_outputs)
    ]
    state = tensorrt.cvt_onnx_to_tensorrt(onnx_proto, state, [], dict())

    class Model(torch.nn.Module):
        # Thin scriptable wrapper that owns the built TensorRT engine and
        # executes it on a single input tensor.

        def __init__(self, state):
            super().__init__()
            self._trt_engine_ext = backends.create_engine(state)

        def forward(self, x):
            return self._trt_engine_ext.execute([x])

    trt_ext = torch.jit.script(Model(state))
    return trt_ext
|