[Enhance] support TensorRT engine for onnxruntime (#1739)

* Support trt engine for onnxruntime

* Apply lint

* Check trt execution provider

* Fix typo

* Fix provider order

* Check device
pull/1797/head
YH 2023-02-20 15:18:09 +09:00 committed by GitHub
parent b1be9c67f3
commit fd47fa2071
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 10 additions and 3 deletions

View File

@ -50,9 +50,16 @@ class ORTWrapper(BaseWrapper):
logger.warning('The library of onnxruntime custom ops does '
f'not exist: {ort_custom_op_path}')
device_id = parse_device_id(device)
providers = ['CPUExecutionProvider'] \
if device == 'cpu' else \
[('CUDAExecutionProvider', {'device_id': device_id})]
providers = []
if 'cuda' in device:
if 'TensorrtExecutionProvider' in ort.get_available_providers():
providers.append(('TensorrtExecutionProvider', {
'device_id': device_id
}))
providers.append(('CUDAExecutionProvider', {
'device_id': device_id
}))
providers.append('CPUExecutionProvider')
sess = ort.InferenceSession(
onnx_file, session_options, providers=providers)
if output_names is None: