Mirror of https://github.com/ultralytics/yolov5.git (synced 2025-06-03 14:49:29 +08:00)

feat: enable timing cache for engine export
commit 10bf52d087
parent 8ee1670d57

export.py (21 lines changed)
@@ -593,7 +593,8 @@ def export_coreml(model, im, file, int8, half, nms, mlmodel, prefix=colorstr("CoreML:")):

 @try_export
-def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr("TensorRT:")):
+def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, cache="",
+                  prefix=colorstr("TensorRT:")):
     """
     Export a YOLOv5 model to TensorRT engine format, requiring GPU and TensorRT>=7.0.0.

@@ -606,6 +607,7 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr("TensorRT:")):
         simplify (bool): Set to True to simplify the model during export.
         workspace (int): Workspace size in GB (default is 4).
         verbose (bool): Set to True for verbose logging output.
+        cache (str): Path to save the TensorRT timing cache.
         prefix (str): Log message prefix.

     Returns:
@@ -660,6 +662,11 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr("TensorRT:")):
         config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)
     else:  # TensorRT versions 7, 8
         config.max_workspace_size = workspace * 1 << 30
+    if cache:  # enable timing cache
+        Path(cache).parent.mkdir(parents=True, exist_ok=True)
+        buf = Path(cache).read_bytes() if Path(cache).exists() else b""
+        timing_cache = config.create_timing_cache(buf)
+        config.set_timing_cache(timing_cache, ignore_mismatch=True)
     flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
     network = builder.create_network(flag)
     parser = trt.OnnxParser(network, logger)
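The timing cache stores TensorRT's kernel auto-tuning (tactic) results, so repeated engine builds can skip the often lengthy tactic search. A minimal standalone sketch of the same load-or-create pattern, assuming a TensorRT 8.x+ Python install (the cache path here is illustrative, not part of export.py):

    from pathlib import Path

    import tensorrt as trt

    cache_path = Path("trt/timing.cache")  # hypothetical location

    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    config = builder.create_builder_config()

    # An empty byte string creates a fresh cache; existing bytes resume a previous one.
    buf = cache_path.read_bytes() if cache_path.exists() else b""
    timing_cache = config.create_timing_cache(buf)
    # ignore_mismatch=True lets the build proceed even if the cache was produced on a
    # different GPU or TensorRT version (its entries are then simply not reused).
    config.set_timing_cache(timing_cache, ignore_mismatch=True)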
@@ -688,6 +695,9 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr("TensorRT:")):
     build = builder.build_serialized_network if is_trt10 else builder.build_engine
     with build(network, config) as engine, open(f, "wb") as t:
         t.write(engine if is_trt10 else engine.serialize())
+    if cache:  # save timing cache
+        with open(cache, "wb") as c:
+            c.write(config.get_timing_cache().serialize())
     return f, None

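Because the cache attached to the builder config is populated as a side effect of the build, persisting it is just a serialize-and-write once the engine exists. Continuing the sketch above (same hypothetical cache_path):

    # The build call fills the attached timing cache; serialize() yields a
    # buffer that can be written straight to disk, as the diff above does.
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    with open(cache_path, "wb") as c:
        c.write(config.get_timing_cache().serialize())

Note that the diff writes the cache back unconditionally whenever a cache path is set, so timings accumulate across exports on the same machine.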
@@ -1277,6 +1287,7 @@ def run(
     int8=False,  # CoreML/TF INT8 quantization
     per_tensor=False,  # TF per tensor quantization
     dynamic=False,  # ONNX/TF/TensorRT: dynamic axes
+    cache="",  # TensorRT: timing cache path
     simplify=False,  # ONNX: simplify model
     mlmodel=False,  # CoreML: Export in *.mlmodel format
     opset=12,  # ONNX: opset version
@@ -1306,6 +1317,7 @@ def run(
         int8 (bool): Apply INT8 quantization for CoreML or TensorFlow models. Default is False.
         per_tensor (bool): Apply per tensor quantization for TensorFlow models. Default is False.
         dynamic (bool): Enable dynamic axes for ONNX, TensorFlow, or TensorRT exports. Default is False.
+        cache (str): TensorRT timing cache path. Default is an empty string.
         simplify (bool): Simplify the ONNX model during export. Default is False.
         opset (int): ONNX opset version. Default is 12.
         verbose (bool): Enable verbose logging for TensorRT export. Default is False.
@@ -1341,6 +1353,7 @@ def run(
         int8=False,
         per_tensor=False,
         dynamic=False,
+        cache="",
         simplify=False,
         opset=12,
         verbose=False,
@@ -1378,7 +1391,8 @@ def run(
     # Input
     gs = int(max(model.stride))  # grid size (max stride)
     imgsz = [check_img_size(x, gs) for x in imgsz]  # verify img_size are gs-multiples
-    im = torch.zeros(batch_size, 3, *imgsz).to(device)  # image size(1,3,320,192) BCHW iDetection
+    ch = next(model.parameters()).size(1)  # required input image channels
+    im = torch.zeros(batch_size, ch, *imgsz).to(device)  # image size(1,3,320,192) BCHW iDetection

     # Update model
     model.eval()
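This hunk also drops the hardcoded 3-channel dummy input: the channel count is read from the model's first parameter, so models trained on non-RGB inputs (e.g. single-channel or RGB+IR) export with a correctly shaped tensor. A toy illustration of why dim 1 of the first parameter gives the input channels (the Sequential model is purely for demonstration; a Conv2d weight is shaped (out_ch, in_ch, kH, kW)):

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(4, 16, 3), nn.ReLU())  # pretend 4-channel model
    ch = next(model.parameters()).size(1)  # dim 1 of the first conv weight = input channels
    im = torch.zeros(1, ch, 640, 640)  # dummy input now matches the model
    print(ch, im.shape)  # 4 torch.Size([1, 4, 640, 640])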
@@ -1402,7 +1416,7 @@ def run(
     if jit:  # TorchScript
         f[0], _ = export_torchscript(model, im, file, optimize)
     if engine:  # TensorRT required before ONNX
-        f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
+        f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose, cache)
     if onnx or xml:  # OpenVINO requires ONNX
         f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)
     if xml:  # OpenVINO
@@ -1497,6 +1511,7 @@ def parse_opt(known=False):
     parser.add_argument("--int8", action="store_true", help="CoreML/TF/OpenVINO INT8 quantization")
     parser.add_argument("--per-tensor", action="store_true", help="TF per-tensor quantization")
     parser.add_argument("--dynamic", action="store_true", help="ONNX/TF/TensorRT: dynamic axes")
+    parser.add_argument("--cache", type=str, default="", help="TensorRT: timing cache file path")
     parser.add_argument("--simplify", action="store_true", help="ONNX: simplify model")
     parser.add_argument("--mlmodel", action="store_true", help="CoreML: Export in *.mlmodel format")
     parser.add_argument("--opset", type=int, default=17, help="ONNX: opset version")
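With the new flag wired through run() to export_engine(), a timing cache can be reused across exports from the command line, e.g. (cache filename illustrative):

    python export.py --weights yolov5s.pt --include engine --device 0 --cache trt.timing.cache

The first build populates trt.timing.cache; subsequent builds on the same GPU and TensorRT version should start noticeably faster.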