Update CoreML exports to support newer *.mlpackage outputs (#13222)

* Implement *.mlpackage generation and make it the default for CoreML model exports

Signed-off-by: Ryan Hirasaki <ryanhirasaki@gmail.com>

* Provide a command-line argument to export as *.mlmodel instead of *.mlpackage for CoreML

Signed-off-by: Ryan Hirasaki <ryanhirasaki@gmail.com>

* Remove macOS check for CoreML quantization

The macOS requirement for quantization was removed in coremltools 6.0

Signed-off-by: Ryan Hirasaki <ryanhirasaki@gmail.com>

* Undo removal of warning catching

Signed-off-by: Ryan Hirasaki <ryanhirasaki@gmail.com>

* Change file extension references from mlmodel to mlpackage

Signed-off-by: Ryan Hirasaki <ryanhirasaki@gmail.com>

* Auto-format by https://ultralytics.com/actions

---------

Signed-off-by: Ryan Hirasaki <ryanhirasaki@gmail.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
pull/13236/head
Ryan Hirasaki 2024-07-28 19:09:57 -05:00 committed by GitHub
parent dcf1242558
commit 6096750fcc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 55 additions and 20 deletions

1
.gitignore vendored
View File

@ -54,6 +54,7 @@ VOC/
*.onnx *.onnx
*.engine *.engine
*.mlmodel *.mlmodel
*.mlpackage
*.torchscript *.torchscript
*.tflite *.tflite
*.h5 *.h5

View File

@ -9,7 +9,7 @@ TorchScript | `torchscript` | yolov5s.torchscrip
ONNX | `onnx` | yolov5s.onnx ONNX | `onnx` | yolov5s.onnx
OpenVINO | `openvino` | yolov5s_openvino_model/ OpenVINO | `openvino` | yolov5s_openvino_model/
TensorRT | `engine` | yolov5s.engine TensorRT | `engine` | yolov5s.engine
CoreML | `coreml` | yolov5s.mlmodel CoreML | `coreml` | yolov5s.mlpackage
TensorFlow SavedModel | `saved_model` | yolov5s_saved_model/ TensorFlow SavedModel | `saved_model` | yolov5s_saved_model/
TensorFlow GraphDef | `pb` | yolov5s.pb TensorFlow GraphDef | `pb` | yolov5s.pb
TensorFlow Lite | `tflite` | yolov5s.tflite TensorFlow Lite | `tflite` | yolov5s.tflite

View File

@ -20,7 +20,7 @@ Usage - formats:
yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn
yolov5s_openvino_model # OpenVINO yolov5s_openvino_model # OpenVINO
yolov5s.engine # TensorRT yolov5s.engine # TensorRT
yolov5s.mlmodel # CoreML (macOS-only) yolov5s.mlpackage # CoreML (macOS-only)
yolov5s_saved_model # TensorFlow SavedModel yolov5s_saved_model # TensorFlow SavedModel
yolov5s.pb # TensorFlow GraphDef yolov5s.pb # TensorFlow GraphDef
yolov5s.tflite # TensorFlow Lite yolov5s.tflite # TensorFlow Lite

View File

@ -169,7 +169,7 @@ def export_formats():
["ONNX", "onnx", ".onnx", True, True], ["ONNX", "onnx", ".onnx", True, True],
["OpenVINO", "openvino", "_openvino_model", True, False], ["OpenVINO", "openvino", "_openvino_model", True, False],
["TensorRT", "engine", ".engine", False, True], ["TensorRT", "engine", ".engine", False, True],
["CoreML", "coreml", ".mlmodel", True, False], ["CoreML", "coreml", ".mlpackage", True, False],
["TensorFlow SavedModel", "saved_model", "_saved_model", True, True], ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True],
["TensorFlow GraphDef", "pb", ".pb", True, True], ["TensorFlow GraphDef", "pb", ".pb", True, True],
["TensorFlow Lite", "tflite", ".tflite", True, False], ["TensorFlow Lite", "tflite", ".tflite", True, False],
@ -520,7 +520,7 @@ def export_paddle(model, im, file, metadata, prefix=colorstr("PaddlePaddle:")):
@try_export @try_export
def export_coreml(model, im, file, int8, half, nms, prefix=colorstr("CoreML:")): def export_coreml(model, im, file, int8, half, nms, mlmodel, prefix=colorstr("CoreML:")):
""" """
Export a YOLOv5 model to CoreML format with optional NMS, INT8, and FP16 support. Export a YOLOv5 model to CoreML format with optional NMS, INT8, and FP16 support.
@ -531,6 +531,7 @@ def export_coreml(model, im, file, int8, half, nms, prefix=colorstr("CoreML:")):
int8 (bool): Flag indicating whether to use INT8 quantization (default is False). int8 (bool): Flag indicating whether to use INT8 quantization (default is False).
half (bool): Flag indicating whether to use FP16 quantization (default is False). half (bool): Flag indicating whether to use FP16 quantization (default is False).
nms (bool): Flag indicating whether to include Non-Maximum Suppression (default is False). nms (bool): Flag indicating whether to include Non-Maximum Suppression (default is False).
mlmodel (bool): Flag indicating whether to export as older *.mlmodel format (default is False).
prefix (str): Prefix string for logging purposes (default is 'CoreML:'). prefix (str): Prefix string for logging purposes (default is 'CoreML:').
Returns: Returns:
@ -548,27 +549,46 @@ def export_coreml(model, im, file, int8, half, nms, prefix=colorstr("CoreML:")):
model = Model(cfg, ch=3, nc=80) model = Model(cfg, ch=3, nc=80)
im = torch.randn(1, 3, 640, 640) im = torch.randn(1, 3, 640, 640)
file = Path("yolov5s_coreml") file = Path("yolov5s_coreml")
export_coreml(model, im, file, int8=False, half=False, nms=True) export_coreml(model, im, file, int8=False, half=False, nms=True, mlmodel=False)
``` ```
""" """
check_requirements("coremltools") check_requirements("coremltools")
import coremltools as ct import coremltools as ct
LOGGER.info(f"\n{prefix} starting export with coremltools {ct.__version__}...") LOGGER.info(f"\n{prefix} starting export with coremltools {ct.__version__}...")
f = file.with_suffix(".mlmodel") if mlmodel:
f = file.with_suffix(".mlmodel")
convert_to = "neuralnetwork"
precision = None
else:
f = file.with_suffix(".mlpackage")
convert_to = "mlprogram"
if half:
precision = ct.precision.FLOAT16
else:
precision = ct.precision.FLOAT32
if nms: if nms:
model = iOSModel(model, im) model = iOSModel(model, im)
ts = torch.jit.trace(model, im, strict=False) # TorchScript model ts = torch.jit.trace(model, im, strict=False) # TorchScript model
ct_model = ct.convert(ts, inputs=[ct.ImageType("image", shape=im.shape, scale=1 / 255, bias=[0, 0, 0])]) ct_model = ct.convert(
bits, mode = (8, "kmeans_lut") if int8 else (16, "linear") if half else (32, None) ts,
inputs=[ct.ImageType("image", shape=im.shape, scale=1 / 255, bias=[0, 0, 0])],
convert_to=convert_to,
compute_precision=precision,
)
bits, mode = (8, "kmeans") if int8 else (16, "linear") if half else (32, None)
if bits < 32: if bits < 32:
if MACOS: # quantization only supported on macOS if mlmodel:
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress numpy==1.20 float warning warnings.filterwarnings(
"ignore", category=DeprecationWarning
) # suppress numpy==1.20 float warning, fixed in coremltools==7.0
ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode) ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
else: elif bits == 8:
print(f"{prefix} quantization only supported on macOS, skipping...") op_config = ct.optimize.coreml.OpPalettizerConfig(mode=mode, nbits=bits, weight_threshold=512)
config = ct.optimize.coreml.OptimizationConfig(global_config=op_config)
ct_model = ct.optimize.coreml.palettize_weights(ct_model, config)
ct_model.save(f) ct_model.save(f)
return f, ct_model return f, ct_model
@ -1070,7 +1090,7 @@ def add_tflite_metadata(file, metadata, num_outputs):
tmp_file.unlink() tmp_file.unlink()
def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline:")): def pipeline_coreml(model, im, file, names, y, mlmodel, prefix=colorstr("CoreML Pipeline:")):
""" """
Convert a PyTorch YOLOv5 model to CoreML format with Non-Maximum Suppression (NMS), handling different input/output Convert a PyTorch YOLOv5 model to CoreML format with Non-Maximum Suppression (NMS), handling different input/output
shapes, and saving the model. shapes, and saving the model.
@ -1082,6 +1102,7 @@ def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline:
file (Path): Path to save the converted CoreML model. file (Path): Path to save the converted CoreML model.
names (dict[int, str]): Dictionary mapping class indices to class names. names (dict[int, str]): Dictionary mapping class indices to class names.
y (torch.Tensor): Output tensor from the PyTorch model's forward pass. y (torch.Tensor): Output tensor from the PyTorch model's forward pass.
mlmodel (bool): Flag indicating whether to export as older *.mlmodel format (default is False).
prefix (str): Custom prefix for logging messages. prefix (str): Custom prefix for logging messages.
Returns: Returns:
@ -1114,6 +1135,11 @@ def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline:
import coremltools as ct import coremltools as ct
from PIL import Image from PIL import Image
if mlmodel:
f = file.with_suffix(".mlmodel") # filename
else:
f = file.with_suffix(".mlpackage") # filename
print(f"{prefix} starting pipeline with coremltools {ct.__version__}...") print(f"{prefix} starting pipeline with coremltools {ct.__version__}...")
batch_size, ch, h, w = list(im.shape) # BCHW batch_size, ch, h, w = list(im.shape) # BCHW
t = time.time() t = time.time()
@ -1156,7 +1182,12 @@ def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline:
print(spec.description) print(spec.description)
# Model from spec # Model from spec
model = ct.models.MLModel(spec) weights_dir = None
if mlmodel:
weights_dir = None
else:
weights_dir = str(f / "Data/com.apple.CoreML/weights")
model = ct.models.MLModel(spec, weights_dir=weights_dir)
# 3. Create NMS protobuf # 3. Create NMS protobuf
nms_spec = ct.proto.Model_pb2.Model() nms_spec = ct.proto.Model_pb2.Model()
@ -1227,8 +1258,7 @@ def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline:
) )
# Save the model # Save the model
f = file.with_suffix(".mlmodel") # filename model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir)
model = ct.models.MLModel(pipeline.spec)
model.input_description["image"] = "Input image" model.input_description["image"] = "Input image"
model.input_description["iouThreshold"] = f"(optional) IOU Threshold override (default: {nms.iouThreshold})" model.input_description["iouThreshold"] = f"(optional) IOU Threshold override (default: {nms.iouThreshold})"
model.input_description["confidenceThreshold"] = ( model.input_description["confidenceThreshold"] = (
@ -1256,6 +1286,7 @@ def run(
per_tensor=False, # TF per tensor quantization per_tensor=False, # TF per tensor quantization
dynamic=False, # ONNX/TF/TensorRT: dynamic axes dynamic=False, # ONNX/TF/TensorRT: dynamic axes
simplify=False, # ONNX: simplify model simplify=False, # ONNX: simplify model
mlmodel=False, # CoreML: Export in *.mlmodel format
opset=12, # ONNX: opset version opset=12, # ONNX: opset version
verbose=False, # TensorRT: verbose log verbose=False, # TensorRT: verbose log
workspace=4, # TensorRT: workspace size (GB) workspace=4, # TensorRT: workspace size (GB)
@ -1293,6 +1324,7 @@ def run(
topk_all (int): Top-K boxes for all classes to keep for TensorFlow.js NMS. Default is 100. topk_all (int): Top-K boxes for all classes to keep for TensorFlow.js NMS. Default is 100.
iou_thres (float): IoU threshold for NMS. Default is 0.45. iou_thres (float): IoU threshold for NMS. Default is 0.45.
conf_thres (float): Confidence threshold for NMS. Default is 0.25. conf_thres (float): Confidence threshold for NMS. Default is 0.25.
mlmodel (bool): Flag to use *.mlmodel for CoreML export. Default is False.
Returns: Returns:
None None
@ -1320,6 +1352,7 @@ def run(
simplify=False, simplify=False,
opset=12, opset=12,
verbose=False, verbose=False,
mlmodel=False,
workspace=4, workspace=4,
nms=False, nms=False,
agnostic_nms=False, agnostic_nms=False,
@ -1383,9 +1416,9 @@ def run(
if xml: # OpenVINO if xml: # OpenVINO
f[3], _ = export_openvino(file, metadata, half, int8, data) f[3], _ = export_openvino(file, metadata, half, int8, data)
if coreml: # CoreML if coreml: # CoreML
f[4], ct_model = export_coreml(model, im, file, int8, half, nms) f[4], ct_model = export_coreml(model, im, file, int8, half, nms, mlmodel)
if nms: if nms:
pipeline_coreml(ct_model, im, file, model.names, y) pipeline_coreml(ct_model, im, file, model.names, y, mlmodel)
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
assert not tflite or not tfjs, "TFLite and TF.js models must be exported separately, please pass only one type." assert not tflite or not tfjs, "TFLite and TF.js models must be exported separately, please pass only one type."
assert not isinstance(model, ClassificationModel), "ClassificationModel export to TF formats not yet supported." assert not isinstance(model, ClassificationModel), "ClassificationModel export to TF formats not yet supported."
@ -1473,6 +1506,7 @@ def parse_opt(known=False):
parser.add_argument("--per-tensor", action="store_true", help="TF per-tensor quantization") parser.add_argument("--per-tensor", action="store_true", help="TF per-tensor quantization")
parser.add_argument("--dynamic", action="store_true", help="ONNX/TF/TensorRT: dynamic axes") parser.add_argument("--dynamic", action="store_true", help="ONNX/TF/TensorRT: dynamic axes")
parser.add_argument("--simplify", action="store_true", help="ONNX: simplify model") parser.add_argument("--simplify", action="store_true", help="ONNX: simplify model")
parser.add_argument("--mlmodel", action="store_true", help="CoreML: Export in *.mlmodel format")
parser.add_argument("--opset", type=int, default=17, help="ONNX: opset version") parser.add_argument("--opset", type=int, default=17, help="ONNX: opset version")
parser.add_argument("--verbose", action="store_true", help="TensorRT: verbose log") parser.add_argument("--verbose", action="store_true", help="TensorRT: verbose log")
parser.add_argument("--workspace", type=int, default=4, help="TensorRT: workspace size (GB)") parser.add_argument("--workspace", type=int, default=4, help="TensorRT: workspace size (GB)")

View File

@ -444,7 +444,7 @@ class DetectMultiBackend(nn.Module):
# ONNX Runtime: *.onnx # ONNX Runtime: *.onnx
# ONNX OpenCV DNN: *.onnx --dnn # ONNX OpenCV DNN: *.onnx --dnn
# OpenVINO: *_openvino_model # OpenVINO: *_openvino_model
# CoreML: *.mlmodel # CoreML: *.mlpackage
# TensorRT: *.engine # TensorRT: *.engine
# TensorFlow SavedModel: *_saved_model # TensorFlow SavedModel: *_saved_model
# TensorFlow GraphDef: *.pb # TensorFlow GraphDef: *.pb

2
val.py
View File

@ -11,7 +11,7 @@ Usage - formats:
yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn
yolov5s_openvino_model # OpenVINO yolov5s_openvino_model # OpenVINO
yolov5s.engine # TensorRT yolov5s.engine # TensorRT
yolov5s.mlmodel # CoreML (macOS-only) yolov5s.mlpackage # CoreML (macOS-only)
yolov5s_saved_model # TensorFlow SavedModel yolov5s_saved_model # TensorFlow SavedModel
yolov5s.pb # TensorFlow GraphDef yolov5s.pb # TensorFlow GraphDef
yolov5s.tflite # TensorFlow Lite yolov5s.tflite # TensorFlow Lite