diff --git a/export.py b/export.py index 6cf1db2c4..a0cb5fdc5 100644 --- a/export.py +++ b/export.py @@ -61,8 +61,8 @@ from models.experimental import attempt_load from models.yolo import Detect from utils.activations import SiLU from utils.datasets import LoadImages -from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, colorstr, file_size, print_args, - url2file) +from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, colorstr, + file_size, print_args, url2file) from utils.torch_utils import select_device @@ -174,14 +174,14 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F check_requirements(('tensorrt',)) import tensorrt as trt - opset = (12, 13)[trt.__version__[0] == '8'] # test on TensorRT 7.x and 8.x - if opset == 12: # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012 + if trt.__version__[0] == 7: # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012 grid = model.model[-1].anchor_grid model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid] - export_onnx(model, im, file, opset, train, False, simplify) + export_onnx(model, im, file, 12, train, False, simplify) # opset 12 model.model[-1].anchor_grid = grid else: # TensorRT >= 8 - export_onnx(model, im, file, opset, train, False, simplify) + check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=8.0.0 + export_onnx(model, im, file, 13, train, False, simplify) # opset 13 onnx = file.with_suffix('.onnx') assert onnx.exists(), f'failed to export ONNX file: {onnx}' diff --git a/models/common.py b/models/common.py index 284dd2bb3..836314568 100644 --- a/models/common.py +++ b/models/common.py @@ -337,7 +337,7 @@ class DetectMultiBackend(nn.Module): elif engine: # TensorRT LOGGER.info(f'Loading {w} for TensorRT inference...') import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download - check_version(trt.__version__, '8.0.0', verbose=True) # version requirement + check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0 Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr')) logger = trt.Logger(trt.Logger.INFO) with open(w, 'rb') as f, trt.Runtime(logger) as runtime: