mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
update API for TensorRT8.4 (#1144)
(cherry picked from commit 4a150e5e1bec7703781ba80fcb1336ccf26e6074)
This commit is contained in:
parent
f311cfd437
commit
8236e9ebc5
@ -169,6 +169,11 @@ def from_onnx(onnx_model: Union[str, onnx.ModelProto],
|
|||||||
builder.max_workspace_size = max_workspace_size
|
builder.max_workspace_size = max_workspace_size
|
||||||
|
|
||||||
config = builder.create_builder_config()
|
config = builder.create_builder_config()
|
||||||
|
|
||||||
|
if hasattr(config, 'set_memory_pool_limit'):
|
||||||
|
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
|
||||||
|
max_workspace_size)
|
||||||
|
else:
|
||||||
config.max_workspace_size = max_workspace_size
|
config.max_workspace_size = max_workspace_size
|
||||||
|
|
||||||
cuda_version = search_cuda_version()
|
cuda_version = search_cuda_version()
|
||||||
@ -187,14 +192,19 @@ def from_onnx(onnx_model: Union[str, onnx.ModelProto],
|
|||||||
opt_shape = param['opt_shape']
|
opt_shape = param['opt_shape']
|
||||||
max_shape = param['max_shape']
|
max_shape = param['max_shape']
|
||||||
profile.set_shape(input_name, min_shape, opt_shape, max_shape)
|
profile.set_shape(input_name, min_shape, opt_shape, max_shape)
|
||||||
config.add_optimization_profile(profile)
|
if config.add_optimization_profile(profile) < 0:
|
||||||
|
logger.warning(f'Invalid optimization profile {profile}.')
|
||||||
|
|
||||||
if fp16_mode:
|
if fp16_mode:
|
||||||
|
if not getattr(builder, 'platform_has_fast_fp16', True):
|
||||||
|
logger.warning('Platform does not has fast native fp16.')
|
||||||
if version.parse(trt.__version__) < version.parse('8'):
|
if version.parse(trt.__version__) < version.parse('8'):
|
||||||
builder.fp16_mode = fp16_mode
|
builder.fp16_mode = fp16_mode
|
||||||
config.set_flag(trt.BuilderFlag.FP16)
|
config.set_flag(trt.BuilderFlag.FP16)
|
||||||
|
|
||||||
if int8_mode:
|
if int8_mode:
|
||||||
|
if not getattr(builder, 'platform_has_fast_int8', True):
|
||||||
|
logger.warning('Platform does not has fast native int8.')
|
||||||
from .calib_utils import HDF5Calibrator
|
from .calib_utils import HDF5Calibrator
|
||||||
config.set_flag(trt.BuilderFlag.INT8)
|
config.set_flag(trt.BuilderFlag.INT8)
|
||||||
assert int8_param is not None
|
assert int8_param is not None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user