mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
update API for TensorRT8.4 (#1144)
(cherry picked from commit 4a150e5e1bec7703781ba80fcb1336ccf26e6074)
This commit is contained in:
parent
f311cfd437
commit
8236e9ebc5
@ -169,7 +169,12 @@ def from_onnx(onnx_model: Union[str, onnx.ModelProto],
|
||||
builder.max_workspace_size = max_workspace_size
|
||||
|
||||
config = builder.create_builder_config()
|
||||
config.max_workspace_size = max_workspace_size
|
||||
|
||||
if hasattr(config, 'set_memory_pool_limit'):
|
||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
|
||||
max_workspace_size)
|
||||
else:
|
||||
config.max_workspace_size = max_workspace_size
|
||||
|
||||
cuda_version = search_cuda_version()
|
||||
if cuda_version is not None:
|
||||
@ -187,14 +192,19 @@ def from_onnx(onnx_model: Union[str, onnx.ModelProto],
|
||||
opt_shape = param['opt_shape']
|
||||
max_shape = param['max_shape']
|
||||
profile.set_shape(input_name, min_shape, opt_shape, max_shape)
|
||||
config.add_optimization_profile(profile)
|
||||
if config.add_optimization_profile(profile) < 0:
|
||||
logger.warning(f'Invalid optimization profile {profile}.')
|
||||
|
||||
if fp16_mode:
|
||||
if not getattr(builder, 'platform_has_fast_fp16', True):
|
||||
logger.warning('Platform does not has fast native fp16.')
|
||||
if version.parse(trt.__version__) < version.parse('8'):
|
||||
builder.fp16_mode = fp16_mode
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
|
||||
if int8_mode:
|
||||
if not getattr(builder, 'platform_has_fast_int8', True):
|
||||
logger.warning('Platform does not has fast native int8.')
|
||||
from .calib_utils import HDF5Calibrator
|
||||
config.set_flag(trt.BuilderFlag.INT8)
|
||||
assert int8_param is not None
|
||||
|
Loading…
x
Reference in New Issue
Block a user