Use `export_formats()` in export.py (#6705)

* Use `export_formats()` in export.py

* list fix
Glenn Jocher 2022-02-19 16:08:33 +01:00 committed by GitHub
parent a297efc383
commit de9c25b35e
1 changed file with 12 additions and 10 deletions

@@ -433,9 +433,12 @@ def run(data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'
         conf_thres=0.25  # TF.js NMS: confidence threshold
         ):
     t = time.time()
-    include = [x.lower() for x in include]
-    tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'))  # TensorFlow exports
-    file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)
+    include = [x.lower() for x in include]  # to lowercase
+    formats = tuple(export_formats()['Argument'][1:])  # --include arguments
+    flags = [x in include for x in formats]
+    assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {formats}'
+    jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags  # export booleans
+    file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)  # PyTorch weights

     # Load PyTorch model
     device = select_device(device)
@@ -475,20 +478,19 @@ def run(data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'
     # Exports
     f = [''] * 10  # exported filenames
     warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)  # suppress TracerWarning
-    if 'torchscript' in include:
+    if jit:
         f[0] = export_torchscript(model, im, file, optimize)
-    if 'engine' in include:  # TensorRT required before ONNX
+    if engine:  # TensorRT required before ONNX
         f[1] = export_engine(model, im, file, train, half, simplify, workspace, verbose)
-    if ('onnx' in include) or ('openvino' in include):  # OpenVINO requires ONNX
+    if onnx or xml:  # OpenVINO requires ONNX
         f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)
-    if 'openvino' in include:
+    if xml:  # OpenVINO
         f[3] = export_openvino(model, im, file)
-    if 'coreml' in include:
+    if coreml:
         _, f[4] = export_coreml(model, im, file)

     # TensorFlow Exports
-    if any(tf_exports):
-        pb, tflite, edgetpu, tfjs = tf_exports[1:]
+    if any((saved_model, pb, tflite, edgetpu, tfjs)):
         if int8 or edgetpu:  # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
             check_requirements(('flatbuffers==1.12',))  # required before `import tensorflow`
         assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
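
The change replaces per-format string checks with a single list of booleans derived from the `export_formats()` table. A minimal standalone sketch of that parsing and validation step, with the format arguments hard-coded in place of the real `export_formats()['Argument']` column (an assumption for illustration, ordered to match the unpacking in the diff):

# Minimal sketch of the new --include handling (illustrative only).
# FORMATS stands in for export_formats()['Argument'][1:] from the diff.
FORMATS = ('torchscript', 'onnx', 'openvino', 'engine', 'coreml',
           'saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs')

def parse_include(include):
    include = [x.lower() for x in include]  # normalize case
    flags = [x in include for x in FORMATS]  # one boolean per known format
    assert sum(flags) == len(include), \
        f'ERROR: Invalid --include {include}, valid --include arguments are {FORMATS}'
    return flags

# Example: request ONNX and OpenVINO exports
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = parse_include(['onnx', 'openvino'])
print(onnx, xml, jit)  # True True False

# An unknown format, e.g. parse_include(['onnxruntime']), now raises an
# AssertionError that lists the valid --include arguments.

Compared with the old `tf_exports` list, every requested format is validated up front and each exporter gets a named boolean, so a typo in `--include` fails fast instead of being silently skipped.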