Merge branch 'master' into main
commit
32b97fb179
14
export.py
14
export.py
|
@ -346,6 +346,7 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose
|
|||
onnx = file.with_suffix(".onnx")
|
||||
|
||||
LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...")
|
||||
is_trt10 = int(trt.__version__.split(".")[0]) >= 10 # is TensorRT >= 10
|
||||
assert onnx.exists(), f"failed to export ONNX file: {onnx}"
|
||||
f = file.with_suffix(".engine") # TensorRT engine file
|
||||
logger = trt.Logger(trt.Logger.INFO)
|
||||
|
@ -354,9 +355,10 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose
|
|||
|
||||
builder = trt.Builder(logger)
|
||||
config = builder.create_builder_config()
|
||||
config.max_workspace_size = workspace * 1 << 30
|
||||
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice
|
||||
|
||||
if is_trt10:
|
||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)
|
||||
else: # TensorRT versions 7, 8
|
||||
config.max_workspace_size = workspace * 1 << 30
|
||||
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
||||
network = builder.create_network(flag)
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
|
@ -381,8 +383,10 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose
|
|||
LOGGER.info(f"{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}")
|
||||
if builder.platform_has_fast_fp16 and half:
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
with builder.build_engine(network, config) as engine, open(f, "wb") as t:
|
||||
t.write(engine.serialize())
|
||||
|
||||
build = builder.build_serialized_network if is_trt10 else builder.build_engine
|
||||
with build(network, config) as engine, open(f, "wb") as t:
|
||||
t.write(engine if is_trt10 else engine.serialize())
|
||||
return f, None
|
||||
|
||||
|
||||
|
|
|
@ -527,18 +527,34 @@ class DetectMultiBackend(nn.Module):
|
|||
output_names = []
|
||||
fp16 = False # default updated below
|
||||
dynamic = False
|
||||
for i in range(model.num_bindings):
|
||||
name = model.get_binding_name(i)
|
||||
dtype = trt.nptype(model.get_binding_dtype(i))
|
||||
if model.binding_is_input(i):
|
||||
if -1 in tuple(model.get_binding_shape(i)): # dynamic
|
||||
dynamic = True
|
||||
context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
|
||||
if dtype == np.float16:
|
||||
fp16 = True
|
||||
else: # output
|
||||
output_names.append(name)
|
||||
shape = tuple(context.get_binding_shape(i))
|
||||
is_trt10 = not hasattr(model, "num_bindings")
|
||||
num = range(model.num_io_tensors) if is_trt10 else range(model.num_bindings)
|
||||
for i in num:
|
||||
if is_trt10:
|
||||
name = model.get_tensor_name(i)
|
||||
dtype = trt.nptype(model.get_tensor_dtype(name))
|
||||
is_input = model.get_tensor_mode(name) == trt.TensorIOMode.INPUT
|
||||
if is_input:
|
||||
if -1 in tuple(model.get_tensor_shape(name)): # dynamic
|
||||
dynamic = True
|
||||
context.set_input_shape(name, tuple(model.get_profile_shape(name, 0)[2]))
|
||||
if dtype == np.float16:
|
||||
fp16 = True
|
||||
else: # output
|
||||
output_names.append(name)
|
||||
shape = tuple(context.get_tensor_shape(name))
|
||||
else:
|
||||
name = model.get_binding_name(i)
|
||||
dtype = trt.nptype(model.get_binding_dtype(i))
|
||||
if model.binding_is_input(i):
|
||||
if -1 in tuple(model.get_binding_shape(i)): # dynamic
|
||||
dynamic = True
|
||||
context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
|
||||
if dtype == np.float16:
|
||||
fp16 = True
|
||||
else: # output
|
||||
output_names.append(name)
|
||||
shape = tuple(context.get_binding_shape(i))
|
||||
im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
|
||||
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
|
||||
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
|
||||
|
|
Loading…
Reference in New Issue