TensorRT `assert im.device.type != 'cpu'` on export (#6340)
* TensorRT `assert im.device.type != 'cpu'` on export * Update export.pypull/6343/head
parent
fd55271c04
commit
e2e95b2d8e
|
@ -184,9 +184,10 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
|
||||||
check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0
|
check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0
|
||||||
export_onnx(model, im, file, 13, train, False, simplify) # opset 13
|
export_onnx(model, im, file, 13, train, False, simplify) # opset 13
|
||||||
onnx = file.with_suffix('.onnx')
|
onnx = file.with_suffix('.onnx')
|
||||||
assert onnx.exists(), f'failed to export ONNX file: {onnx}'
|
|
||||||
|
|
||||||
LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
|
LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
|
||||||
|
assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
|
||||||
|
assert onnx.exists(), f'failed to export ONNX file: {onnx}'
|
||||||
f = file.with_suffix('.engine') # TensorRT engine file
|
f = file.with_suffix('.engine') # TensorRT engine file
|
||||||
logger = trt.Logger(trt.Logger.INFO)
|
logger = trt.Logger(trt.Logger.INFO)
|
||||||
if verbose:
|
if verbose:
|
||||||
|
|
Loading…
Reference in New Issue