[Enhance] support TensorRT engine for onnxruntime (#1739)
* Support trt engine for onnxruntime * Apply lint * Check trt execution provider * Fix typo * Fix provider order * Check device (pull/1797/head)
parent
b1be9c67f3
commit
fd47fa2071
|
@ -50,9 +50,16 @@ class ORTWrapper(BaseWrapper):
|
|||
logger.warning('The library of onnxruntime custom ops does '
|
||||
f'not exist: {ort_custom_op_path}')
|
||||
device_id = parse_device_id(device)
|
||||
providers = ['CPUExecutionProvider'] \
|
||||
if device == 'cpu' else \
|
||||
[('CUDAExecutionProvider', {'device_id': device_id})]
|
||||
providers = []
|
||||
if 'cuda' in device:
|
||||
if 'TensorrtExecutionProvider' in ort.get_available_providers():
|
||||
providers.append(('TensorrtExecutionProvider', {
|
||||
'device_id': device_id
|
||||
}))
|
||||
providers.append(('CUDAExecutionProvider', {
|
||||
'device_id': device_id
|
||||
}))
|
||||
providers.append('CPUExecutionProvider')
|
||||
sess = ort.InferenceSession(
|
||||
onnx_file, session_options, providers=providers)
|
||||
if output_names is None:
|
||||
|
|
Loading…
Reference in New Issue