mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
Add tensorrt>=7.0.0
checks (#6193)
* Add `tensorrt>=7.0.0` checks * Update export.py * Update common.py * Update export.py
This commit is contained in:
parent
a2f4a1799b
commit
7b31a531b4
12
export.py
12
export.py
@ -61,8 +61,8 @@ from models.experimental import attempt_load
|
|||||||
from models.yolo import Detect
|
from models.yolo import Detect
|
||||||
from utils.activations import SiLU
|
from utils.activations import SiLU
|
||||||
from utils.datasets import LoadImages
|
from utils.datasets import LoadImages
|
||||||
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, colorstr, file_size, print_args,
|
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, colorstr,
|
||||||
url2file)
|
file_size, print_args, url2file)
|
||||||
from utils.torch_utils import select_device
|
from utils.torch_utils import select_device
|
||||||
|
|
||||||
|
|
||||||
@ -174,14 +174,14 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
|
|||||||
check_requirements(('tensorrt',))
|
check_requirements(('tensorrt',))
|
||||||
import tensorrt as trt
|
import tensorrt as trt
|
||||||
|
|
||||||
opset = (12, 13)[trt.__version__[0] == '8'] # test on TensorRT 7.x and 8.x
|
if trt.__version__[0] == 7: # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
|
||||||
if opset == 12: # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
|
|
||||||
grid = model.model[-1].anchor_grid
|
grid = model.model[-1].anchor_grid
|
||||||
model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
|
model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
|
||||||
export_onnx(model, im, file, opset, train, False, simplify)
|
export_onnx(model, im, file, 12, train, False, simplify) # opset 12
|
||||||
model.model[-1].anchor_grid = grid
|
model.model[-1].anchor_grid = grid
|
||||||
else: # TensorRT >= 8
|
else: # TensorRT >= 8
|
||||||
export_onnx(model, im, file, opset, train, False, simplify)
|
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=8.0.0
|
||||||
|
export_onnx(model, im, file, 13, train, False, simplify) # opset 13
|
||||||
onnx = file.with_suffix('.onnx')
|
onnx = file.with_suffix('.onnx')
|
||||||
assert onnx.exists(), f'failed to export ONNX file: {onnx}'
|
assert onnx.exists(), f'failed to export ONNX file: {onnx}'
|
||||||
|
|
||||||
|
@ -337,7 +337,7 @@ class DetectMultiBackend(nn.Module):
|
|||||||
elif engine: # TensorRT
|
elif engine: # TensorRT
|
||||||
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
||||||
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
|
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
|
||||||
check_version(trt.__version__, '8.0.0', verbose=True) # version requirement
|
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
|
||||||
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
|
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
|
||||||
logger = trt.Logger(trt.Logger.INFO)
|
logger = trt.Logger(trt.Logger.INFO)
|
||||||
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
|
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user