update API for TensorRT8.4 (#1144)

(cherry picked from commit 4a150e5e1bec7703781ba80fcb1336ccf26e6074)
This commit is contained in:
q.yao 2022-10-24 10:52:18 +08:00 committed by lvhan028
parent f311cfd437
commit 8236e9ebc5

View File

@@ -169,7 +169,12 @@ def from_onnx(onnx_model: Union[str, onnx.ModelProto],
builder.max_workspace_size = max_workspace_size
config = builder.create_builder_config()
config.max_workspace_size = max_workspace_size
if hasattr(config, 'set_memory_pool_limit'):
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
max_workspace_size)
else:
config.max_workspace_size = max_workspace_size
cuda_version = search_cuda_version()
if cuda_version is not None:
@@ -187,14 +192,19 @@ def from_onnx(onnx_model: Union[str, onnx.ModelProto],
opt_shape = param['opt_shape']
max_shape = param['max_shape']
profile.set_shape(input_name, min_shape, opt_shape, max_shape)
config.add_optimization_profile(profile)
if config.add_optimization_profile(profile) < 0:
logger.warning(f'Invalid optimization profile {profile}.')
if fp16_mode:
if not getattr(builder, 'platform_has_fast_fp16', True):
logger.warning('Platform does not has fast native fp16.')
if version.parse(trt.__version__) < version.parse('8'):
builder.fp16_mode = fp16_mode
config.set_flag(trt.BuilderFlag.FP16)
if int8_mode:
if not getattr(builder, 'platform_has_fast_int8', True):
logger.warning('Platform does not has fast native int8.')
from .calib_utils import HDF5Calibrator
config.set_flag(trt.BuilderFlag.INT8)
assert int8_param is not None