Add end2end yolov7 onnx export for TensorRT8.0+ and onnxruntime(testing now) (#273)

* Add end2end yolov7 onnx export for TensorRT8.0+

* Add usage in README

* Update yolo.py

* Update yolo.py

* Add tensorrt onnxruntime examples

* Add usage in README

Co-authored-by: Alexey <AlexeyAB@users.noreply.github.com>
pull/314/head
tripleMu 2022-07-23 10:36:51 +08:00 committed by GitHub
parent aae70703f7
commit 1c59e43d9d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 4064 additions and 9 deletions

View File

@ -164,6 +164,7 @@ python export.py --weights yolov7-tiny.pt --grid --include-nms
```
**ONNX to TensorRT**
```shell
git clone https://github.com/Linaom1214/tensorrt-python.git
cd tensorrt-python
@ -197,6 +198,25 @@ Yolov7-mask & YOLOv7-pose
</a>
</div>
## End2End Detect for TensorRT8+ and onnxruntime
Usage:
```shell
# export end2end onnx for TensorRT8+ backend
python export.py --weights yolov7-d6.pt --grid --end2end --simplify --topk-all 100 --iou-thres 0.65 --conf-thres 0.35
# convert onnx to TensorRT engine
/usr/src/tensorrt/bin/trtexec --onnx=yolov7-d6.onnx --saveEngine=yolov7-d6.engine --fp16
# export end2end onnx for onnxruntime backend
python export.py --weights yolov7-d6.pt --grid --end2end --simplify --topk-all 100 --iou-thres 0.65 --conf-thres 0.35 --max-wh 7680
```
For more information on TensorRT end2end detection, see [end2end_tensorrt.ipynb](end2end_tensorrt.ipynb).
For more information on onnxruntime end2end detection, see [end2end_onnxruntime.ipynb](end2end_onnxruntime.ipynb).
## Acknowledgements
<details><summary> <b>Expand</b> </summary>

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -8,7 +8,7 @@ import torch
import torch.nn as nn
import models
from models.experimental import attempt_load
from models.experimental import attempt_load, End2End
from utils.activations import Hardswish, SiLU
from utils.general import set_logging, check_img_size
from utils.torch_utils import select_device
@ -21,6 +21,11 @@ if __name__ == '__main__':
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes')
parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
parser.add_argument('--end2end', action='store_true', help='export end2end onnx')
parser.add_argument('--max-wh', type=int, default=None, help='None for tensorrt nms, int value for onnx-runtime nms')
parser.add_argument('--topk-all', type=int, default=100, help='topk objects for every images')
parser.add_argument('--iou-thres', type=float, default=0.45, help='iou threshold for NMS')
parser.add_argument('--conf-thres', type=float, default=0.25, help='conf threshold for NMS')
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--simplify', action='store_true', help='simplify onnx model')
parser.add_argument('--include-nms', action='store_true', help='export end2end onnx')
@ -74,14 +79,31 @@ if __name__ == '__main__':
print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
f = opt.weights.replace('.pt', '.onnx') # filename
model.eval()
output_names = ['classes', 'boxes'] if y is None else ['output']
if opt.grid and opt.end2end:
# The conditional must be parenthesised: `%` binds tighter than `if ... else`, so
# without the parentheses the whole message collapsed to the bare string
# 'onnxruntime' whenever --max-wh was given.
print('\nStarting export end2end onnx model for %s...' % ('TensorRT' if opt.max_wh is None else 'onnxruntime'))
model = End2End(model,opt.topk_all,opt.iou_thres,opt.conf_thres,opt.max_wh,device)
if opt.end2end and opt.max_wh is None:
output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes']
shapes = [opt.batch_size, 1, opt.batch_size, opt.topk_all, 4,
opt.batch_size, opt.topk_all, opt.batch_size, opt.topk_all]
else:
output_names = ['output']
torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
output_names=['classes', 'boxes'] if y is None else ['output'],
output_names=output_names,
dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640)
'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None)
'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic and not opt.end2end else None)
# Checks
onnx_model = onnx.load(f) # load onnx model
onnx.checker.check_model(onnx_model) # check onnx model
if opt.end2end and opt.max_wh is None:
for i in onnx_model.graph.output:
for j in i.type.tensor_type.shape.dim:
j.dim_param = str(shapes.pop(0))
# print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model
# # Metadata
@ -98,9 +120,11 @@ if __name__ == '__main__':
print('\nStarting to simplify ONNX...')
onnx_model, check = onnxsim.simplify(onnx_model)
assert check, 'assert check failed'
onnx.save(onnx_model, f)
except Exception as e:
print(f'Simplifier failure: {e}')
# print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model
onnx.save(onnx_model,f)
print('ONNX export success, saved as %s' % f)
if opt.include_nms:

View File

@ -1,4 +1,5 @@
import numpy as np
import random
import torch
import torch.nn as nn
@ -80,6 +81,159 @@ class Ensemble(nn.ModuleList):
return y, None # inference, train output
class ORT_NMS(torch.autograd.Function):
    """Tracing placeholder that exports as the ONNX ``NonMaxSuppression`` op.

    ``forward`` performs no suppression at all: it fabricates a randomly sized
    index tensor so that the traced graph sees an output whose first dimension
    varies. The real operator is emitted by ``symbolic`` during ONNX export.
    """

    @staticmethod
    def forward(ctx,
                boxes,
                scores,
                max_output_boxes_per_class=torch.tensor([100]),
                iou_threshold=torch.tensor([0.45]),
                score_threshold=torch.tensor([0.25])):
        """Return a fake ``(num_det, 3)`` int64 tensor of selected indices.

        Column 0 holds sorted random batch indices in ``[0, scores.shape[0])``,
        column 1 is all zeros, and column 2 counts up from 100. The three
        threshold arguments are accepted only so the signature lines up with
        ``symbolic``; they are not read here.
        """
        target = boxes.device
        n_batch = scores.shape[0]
        # Random detection count (0..100 inclusive) for this dummy pass.
        n_sel = random.randint(0, 100)
        batch_col = torch.sort(torch.randint(0, n_batch, (n_sel,))).values.to(target)
        class_col = torch.zeros((n_sel,), dtype=torch.int64).to(target)
        box_col = torch.arange(100, 100 + n_sel).to(target)
        # Stack as rows, then transpose so each detection is one row.
        stacked = torch.stack([batch_col, class_col, box_col], 0)
        return stacked.T.contiguous().to(torch.int64)

    @staticmethod
    def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
        """Emit the ``NonMaxSuppression`` node with all five inputs wired through."""
        nms_inputs = (boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)
        return g.op("NonMaxSuppression", *nms_inputs)
class TRT_NMS(torch.autograd.Function):
    """Tracing placeholder that exports as the ``TRT::EfficientNMS_TRT`` plugin op.

    ``forward`` only produces random tensors with the four output shapes and
    dtypes; ``symbolic`` writes the actual plugin node and its attributes into
    the ONNX graph.
    """

    @staticmethod
    def forward(
        ctx,
        boxes,
        scores,
        background_class=-1,
        box_coding=1,
        iou_threshold=0.45,
        max_output_boxes=100,
        plugin_version="1",
        score_activation=0,
        score_threshold=0.25,
    ):
        """Return dummy ``(num_det, det_boxes, det_scores, det_classes)``.

        Shapes are ``(B, 1)``, ``(B, K, 4)``, ``(B, K)`` and ``(B, K)`` where
        ``B`` and the class count come from the 3-D ``scores`` tensor and ``K``
        is ``max_output_boxes``. ``num_det`` and ``det_classes`` are int32.
        """
        n_batch, _, n_classes = scores.shape
        per_image = (n_batch, max_output_boxes)
        # Keep this draw order: it determines the values for a given RNG seed.
        count = torch.randint(0, max_output_boxes, (n_batch, 1), dtype=torch.int32)
        fake_boxes = torch.randn(*per_image, 4)
        fake_scores = torch.randn(*per_image)
        fake_classes = torch.randint(0, n_classes, per_image, dtype=torch.int32)
        return count, fake_boxes, fake_scores, fake_classes

    @staticmethod
    def symbolic(g,
                 boxes,
                 scores,
                 background_class=-1,
                 box_coding=1,
                 iou_threshold=0.45,
                 max_output_boxes=100,
                 plugin_version="1",
                 score_activation=0,
                 score_threshold=0.25):
        """Emit the four-output ``TRT::EfficientNMS_TRT`` node.

        The ``_i`` / ``_f`` / ``_s`` keyword suffixes are the type tags passed
        to ``g.op`` for int, float and string attributes.
        """
        num_dets, det_boxes, det_scores, det_classes = g.op(
            "TRT::EfficientNMS_TRT",
            boxes,
            scores,
            background_class_i=background_class,
            box_coding_i=box_coding,
            iou_threshold_f=iou_threshold,
            max_output_boxes_i=max_output_boxes,
            plugin_version_s=plugin_version,
            score_activation_i=score_activation,
            score_threshold_f=score_threshold,
            outputs=4,
        )
        return num_dets, det_boxes, det_scores, det_classes
class ONNX_ORT(nn.Module):
    """Post-processing head that appends ONNX-Runtime NMS to raw detections.

    Args:
        max_obj: value passed to NMS as ``max_output_boxes_per_class``.
        iou_thres: IoU threshold passed to NMS.
        score_thres: score threshold passed to NMS.
        max_wh: per-class offset added to the boxes before NMS; if
            max_wh != 0 NMS is non-agnostic, else agnostic.
        device: device for the constant tensors; falls back to CPU when falsy.
    """

    def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, device=None):
        super().__init__()
        self.device = device if device else torch.device("cpu")
        # Place every constant on the resolved device (previously the raw,
        # possibly-None ``device`` argument was used for these three).
        self.max_obj = torch.tensor([max_obj]).to(self.device)
        self.iou_threshold = torch.tensor([iou_thres]).to(self.device)
        self.score_threshold = torch.tensor([score_thres]).to(self.device)
        self.max_wh = max_wh  # if max_wh != 0 : non-agnostic else : agnostic
        # Right-multiplying a row (a, b, c, d) by this matrix yields
        # (a - c/2, b - d/2, a + c/2, b + d/2), i.e. centre/size -> corners.
        self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
                                           dtype=torch.float32,
                                           device=self.device)

    def forward(self, x):
        """Run NMS on ``x`` and return one row per kept detection.

        ``x`` is indexed as ``(batch, boxes, 4 + 1 + classes)``: four box
        values, one confidence value, then per-class scores. The result has
        seven columns: batch index, four corner coordinates, class id, score.
        """
        boxes = x[:, :, :4]
        conf = x[:, :, 4:5]
        # Out-of-place multiply: the slices are views of ``x``, and an in-place
        # ``*=`` would silently overwrite the caller's tensor.
        scores = x[:, :, 5:] * conf
        boxes = boxes @ self.convert_matrix
        max_score, category_id = scores.max(2, keepdim=True)
        # Shift each box by class_id * max_wh so different classes do not
        # suppress one another inside the single-class NMS call.
        dis = category_id.float() * self.max_wh
        nmsbox = boxes + dis
        max_score_tp = max_score.transpose(1, 2).contiguous()
        selected_indices = ORT_NMS.apply(nmsbox, max_score_tp, self.max_obj, self.iou_threshold, self.score_threshold)
        # Column 0 is the batch index, column 2 the box index.
        X, Y = selected_indices[:, 0], selected_indices[:, 2]
        selected_boxes = boxes[X, Y, :]
        selected_categories = category_id[X, Y, :].float()
        selected_scores = max_score[X, Y, :]
        X = X.unsqueeze(1).float()
        return torch.cat([X, selected_boxes, selected_categories, selected_scores], 1)
class ONNX_TRT(nn.Module):
    """Post-processing head that appends the TensorRT EfficientNMS plugin.

    Args:
        max_obj: maximum number of boxes the plugin outputs per image.
        iou_thres: IoU threshold passed to the plugin.
        score_thres: score threshold passed to the plugin.
        max_wh: must be ``None``; kept only so the signature matches
            ``ONNX_ORT``.
        device: stored on the module; falls back to CPU when falsy.
    """

    def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None):
        super().__init__()
        assert max_wh is None
        self.device = device if device else torch.device('cpu')
        # Plain ints: the previous trailing commas stored the tuples (-1,) and
        # (1,), which do not match the scalar defaults of TRT_NMS.
        self.background_class = -1
        self.box_coding = 1
        self.iou_threshold = iou_thres
        self.max_obj = max_obj
        self.plugin_version = '1'
        self.score_activation = 0
        self.score_threshold = score_thres

    def forward(self, x):
        """Return ``(num_det, det_boxes, det_scores, det_classes)`` for ``x``.

        ``x`` is indexed as ``(batch, boxes, 4 + 1 + classes)``: four box
        values, one confidence value, then per-class scores.
        """
        boxes = x[:, :, :4]
        conf = x[:, :, 4:5]
        # Out-of-place multiply: the slices are views of ``x``, and an in-place
        # ``*=`` would silently overwrite the caller's tensor.
        scores = x[:, :, 5:] * conf
        num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(boxes, scores, self.background_class, self.box_coding,
                                                                    self.iou_threshold, self.max_obj,
                                                                    self.plugin_version, self.score_activation,
                                                                    self.score_threshold)
        return num_det, det_boxes, det_scores, det_classes
class End2End(nn.Module):
    """Wrap a detector with an NMS head so the exported ONNX graph is end-to-end.

    ``max_wh`` picks the head: ``None`` selects ``ONNX_TRT``, an int selects
    ``ONNX_ORT``.
    """

    def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None):
        super().__init__()
        if not device:
            device = torch.device('cpu')
        assert max_wh is None or isinstance(max_wh, int)
        self.model = model.to(device)
        # Switch the last layer of the wrapped model to its end2end output.
        self.model.model[-1].end2end = True
        if max_wh is None:
            self.patch_model = ONNX_TRT
        else:
            self.patch_model = ONNX_ORT
        self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device)
        self.end2end.eval()

    def forward(self, x):
        """Run the detector, then feed its output through the NMS head."""
        return self.end2end(self.model(x))
def attempt_load(weights, map_location=None):
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
model = Ensemble()
@ -104,3 +258,5 @@ def attempt_load(weights, map_location=None):
for k in ['names', 'stride']:
setattr(model, k, getattr(model[-1], k))
return model # return ensemble

View File

@ -23,7 +23,9 @@ except ImportError:
class Detect(nn.Module):
stride = None # strides computed during build
export = False # onnx export
end2end = False
include_nms = False
def __init__(self, nc=80, anchors=(), ch=()): # detection layer
super(Detect, self).__init__()
self.nc = nc # number of classes
@ -58,10 +60,17 @@ class Detect(nn.Module):
y = torch.cat((xy, wh, y[..., 4:]), -1)
z.append(y.view(bs, -1, self.no))
if self.include_nms:
if self.training:
out = x
elif self.end2end:
out = torch.cat(z, 1)
elif self.include_nms:
z = self.convert(z)
out = (z, )
else:
out = (torch.cat(z, 1), x)
return x if self.training else (z, ) if self.include_nms else (torch.cat(z, 1), x)
return out
@staticmethod
def _make_grid(nx=20, ny=20):
@ -84,6 +93,7 @@ class Detect(nn.Module):
class IDetect(nn.Module):
stride = None # strides computed during build
export = False # onnx export
end2end = False
include_nms = False
def __init__(self, nc=80, anchors=(), ch=()): # detection layer
@ -140,10 +150,17 @@ class IDetect(nn.Module):
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
z.append(y.view(bs, -1, self.no))
if self.include_nms:
if self.training:
out = x
elif self.end2end:
out = torch.cat(z, 1)
elif self.include_nms:
z = self.convert(z)
out = (z, )
else:
out = (torch.cat(z, 1), x)
return x if self.training else (z, ) if self.include_nms else (torch.cat(z, 1), x)
return out
def fuse(self):
print("IDetect.fuse")