export.py return exported files/dirs (#6343)

* `export.py` return exported files/dirs

* Path to str
This commit is contained in:
Glenn Jocher 2022-01-18 15:18:23 -10:00 committed by GitHub
parent e2e95b2d8e
commit 0cf932bf63
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -434,16 +434,17 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} ({file_size(file):.1f} MB)") LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} ({file_size(file):.1f} MB)")
# Exports # Exports
f = [''] * 10 # exported filenames
if 'torchscript' in include: if 'torchscript' in include:
f = export_torchscript(model, im, file, optimize) f[0] = export_torchscript(model, im, file, optimize)
if 'engine' in include: # TensorRT required before ONNX if 'engine' in include: # TensorRT required before ONNX
f = export_engine(model, im, file, train, half, simplify, workspace, verbose) f[1] = export_engine(model, im, file, train, half, simplify, workspace, verbose)
if ('onnx' in include) or ('openvino' in include): # OpenVINO requires ONNX if ('onnx' in include) or ('openvino' in include): # OpenVINO requires ONNX
f = export_onnx(model, im, file, opset, train, dynamic, simplify) f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)
if 'openvino' in include: if 'openvino' in include:
f = export_openvino(model, im, file) f[3] = export_openvino(model, im, file)
if 'coreml' in include: if 'coreml' in include:
_, f = export_coreml(model, im, file) _, f[4] = export_coreml(model, im, file)
# TensorFlow Exports # TensorFlow Exports
if any(tf_exports): if any(tf_exports):
@ -451,25 +452,27 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707 if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow` check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow`
assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.' assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
model, f = export_saved_model(model, im, file, dynamic, tf_nms=nms or agnostic_nms or tfjs, model, f[5] = export_saved_model(model, im, file, dynamic, tf_nms=nms or agnostic_nms or tfjs,
agnostic_nms=agnostic_nms or tfjs, topk_per_class=topk_per_class, agnostic_nms=agnostic_nms or tfjs, topk_per_class=topk_per_class,
topk_all=topk_all, conf_thres=conf_thres, iou_thres=iou_thres) # keras model topk_all=topk_all, conf_thres=conf_thres, iou_thres=iou_thres) # keras model
if pb or tfjs: # pb prerequisite to tfjs if pb or tfjs: # pb prerequisite to tfjs
f = export_pb(model, im, file) f[6] = export_pb(model, im, file)
if tflite or edgetpu: if tflite or edgetpu:
f = export_tflite(model, im, file, int8=int8 or edgetpu, data=data, ncalib=100) f[7] = export_tflite(model, im, file, int8=int8 or edgetpu, data=data, ncalib=100)
if edgetpu: if edgetpu:
f = export_edgetpu(model, im, file) f[8] = export_edgetpu(model, im, file)
if tfjs: if tfjs:
f = export_tfjs(model, im, file) f[9] = export_tfjs(model, im, file)
# Finish # Finish
f = [str(x) for x in f if x] # filter out '' and None
LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)' LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)'
f"\nResults saved to {colorstr('bold', file.parent.resolve())}" f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
f"\nVisualize with https://netron.app" f"\nVisualize with https://netron.app"
f"\nDetect with `python detect.py --weights {f}`" f"\nDetect with `python detect.py --weights {f[-1]}`"
f" or `model = torch.hub.load('ultralytics/yolov5', 'custom', '{f}')" f" or `model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}')"
f"\nValidate with `python val.py --weights {f}`") f"\nValidate with `python val.py --weights {f[-1]}`")
return f # return list of exported files/dirs
def parse_opt(): def parse_opt():