diff --git a/.gitignore b/.gitignore index 6bcedfac6..d8b9c068b 100755 --- a/.gitignore +++ b/.gitignore @@ -54,6 +54,7 @@ VOC/ *.onnx *.engine *.mlmodel +*.mlpackage *.torchscript *.tflite *.h5 diff --git a/benchmarks.py b/benchmarks.py index e92a645fb..996b8d438 100644 --- a/benchmarks.py +++ b/benchmarks.py @@ -9,7 +9,7 @@ TorchScript | `torchscript` | yolov5s.torchscrip ONNX | `onnx` | yolov5s.onnx OpenVINO | `openvino` | yolov5s_openvino_model/ TensorRT | `engine` | yolov5s.engine -CoreML | `coreml` | yolov5s.mlmodel +CoreML | `coreml` | yolov5s.mlpackage TensorFlow SavedModel | `saved_model` | yolov5s_saved_model/ TensorFlow GraphDef | `pb` | yolov5s.pb TensorFlow Lite | `tflite` | yolov5s.tflite diff --git a/detect.py b/detect.py index 61dd8d549..8a25ac235 100644 --- a/detect.py +++ b/detect.py @@ -20,7 +20,7 @@ Usage - formats: yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn yolov5s_openvino_model # OpenVINO yolov5s.engine # TensorRT - yolov5s.mlmodel # CoreML (macOS-only) + yolov5s.mlpackage # CoreML (macOS-only) yolov5s_saved_model # TensorFlow SavedModel yolov5s.pb # TensorFlow GraphDef yolov5s.tflite # TensorFlow Lite diff --git a/export.py b/export.py index f5db08096..c1524ec2d 100644 --- a/export.py +++ b/export.py @@ -169,7 +169,7 @@ def export_formats(): ["ONNX", "onnx", ".onnx", True, True], ["OpenVINO", "openvino", "_openvino_model", True, False], ["TensorRT", "engine", ".engine", False, True], - ["CoreML", "coreml", ".mlmodel", True, False], + ["CoreML", "coreml", ".mlpackage", True, False], ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True], ["TensorFlow GraphDef", "pb", ".pb", True, True], ["TensorFlow Lite", "tflite", ".tflite", True, False], @@ -520,7 +520,7 @@ def export_paddle(model, im, file, metadata, prefix=colorstr("PaddlePaddle:")): @try_export -def export_coreml(model, im, file, int8, half, nms, prefix=colorstr("CoreML:")): +def export_coreml(model, im, file, int8, half, nms, mlmodel, prefix=colorstr("CoreML:")): """ Export a YOLOv5 model to CoreML format with optional NMS, INT8, and FP16 support. @@ -531,6 +531,7 @@ def export_coreml(model, im, file, int8, half, nms, prefix=colorstr("CoreML:")): int8 (bool): Flag indicating whether to use INT8 quantization (default is False). half (bool): Flag indicating whether to use FP16 quantization (default is False). nms (bool): Flag indicating whether to include Non-Maximum Suppression (default is False). + mlmodel (bool): Flag indicating whether to export as older *.mlmodel format (default is False). prefix (str): Prefix string for logging purposes (default is 'CoreML:'). Returns: @@ -548,27 +549,46 @@ def export_coreml(model, im, file, int8, half, nms, prefix=colorstr("CoreML:")): model = Model(cfg, ch=3, nc=80) im = torch.randn(1, 3, 640, 640) file = Path("yolov5s_coreml") - export_coreml(model, im, file, int8=False, half=False, nms=True) + export_coreml(model, im, file, int8=False, half=False, nms=True, mlmodel=False) ``` """ check_requirements("coremltools") import coremltools as ct LOGGER.info(f"\n{prefix} starting export with coremltools {ct.__version__}...") - f = file.with_suffix(".mlmodel") + if mlmodel: + f = file.with_suffix(".mlmodel") + convert_to = "neuralnetwork" + precision = None + else: + f = file.with_suffix(".mlpackage") + convert_to = "mlprogram" + if half: + precision = ct.precision.FLOAT16 + else: + precision = ct.precision.FLOAT32 if nms: model = iOSModel(model, im) ts = torch.jit.trace(model, im, strict=False) # TorchScript model - ct_model = ct.convert(ts, inputs=[ct.ImageType("image", shape=im.shape, scale=1 / 255, bias=[0, 0, 0])]) - bits, mode = (8, "kmeans_lut") if int8 else (16, "linear") if half else (32, None) + ct_model = ct.convert( + ts, + inputs=[ct.ImageType("image", shape=im.shape, scale=1 / 255, bias=[0, 0, 0])], + convert_to=convert_to, + compute_precision=precision, + ) + bits, mode = (8, "kmeans") if int8 else (16, "linear") if half else (32, None) if bits < 32: - if MACOS: # quantization only supported on macOS + if mlmodel: with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress numpy==1.20 float warning + warnings.filterwarnings( + "ignore", category=DeprecationWarning + ) # suppress numpy==1.20 float warning, fixed in coremltools==7.0 ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode) - else: - print(f"{prefix} quantization only supported on macOS, skipping...") + elif bits == 8: + op_config = ct.optimize.coreml.OpPalettizerConfig(mode=mode, nbits=bits, weight_threshold=512) + config = ct.optimize.coreml.OptimizationConfig(global_config=op_config) + ct_model = ct.optimize.coreml.palettize_weights(ct_model, config) ct_model.save(f) return f, ct_model @@ -1070,7 +1090,7 @@ def add_tflite_metadata(file, metadata, num_outputs): tmp_file.unlink() -def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline:")): +def pipeline_coreml(model, im, file, names, y, mlmodel, prefix=colorstr("CoreML Pipeline:")): """ Convert a PyTorch YOLOv5 model to CoreML format with Non-Maximum Suppression (NMS), handling different input/output shapes, and saving the model. @@ -1082,6 +1102,7 @@ def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline: file (Path): Path to save the converted CoreML model. names (dict[int, str]): Dictionary mapping class indices to class names. y (torch.Tensor): Output tensor from the PyTorch model's forward pass. + mlmodel (bool): Flag indicating whether to export as older *.mlmodel format (default is False). prefix (str): Custom prefix for logging messages. Returns: @@ -1114,6 +1135,11 @@ def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline: import coremltools as ct from PIL import Image + if mlmodel: + f = file.with_suffix(".mlmodel") # filename + else: + f = file.with_suffix(".mlpackage") # filename + print(f"{prefix} starting pipeline with coremltools {ct.__version__}...") batch_size, ch, h, w = list(im.shape) # BCHW t = time.time() @@ -1156,7 +1182,12 @@ def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline: print(spec.description) # Model from spec - model = ct.models.MLModel(spec) + weights_dir = None + if mlmodel: + weights_dir = None + else: + weights_dir = str(f / "Data/com.apple.CoreML/weights") + model = ct.models.MLModel(spec, weights_dir=weights_dir) # 3. Create NMS protobuf nms_spec = ct.proto.Model_pb2.Model() @@ -1227,8 +1258,7 @@ def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline: ) # Save the model - f = file.with_suffix(".mlmodel") # filename - model = ct.models.MLModel(pipeline.spec) + model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir) model.input_description["image"] = "Input image" model.input_description["iouThreshold"] = f"(optional) IOU Threshold override (default: {nms.iouThreshold})" model.input_description["confidenceThreshold"] = ( @@ -1256,6 +1286,7 @@ def run( per_tensor=False, # TF per tensor quantization dynamic=False, # ONNX/TF/TensorRT: dynamic axes simplify=False, # ONNX: simplify model + mlmodel=False, # CoreML: Export in *.mlmodel format opset=12, # ONNX: opset version verbose=False, # TensorRT: verbose log workspace=4, # TensorRT: workspace size (GB) @@ -1293,6 +1324,7 @@ def run( topk_all (int): Top-K boxes for all classes to keep for TensorFlow.js NMS. Default is 100. iou_thres (float): IoU threshold for NMS. Default is 0.45. conf_thres (float): Confidence threshold for NMS. Default is 0.25. + mlmodel (bool): Flag to use *.mlmodel for CoreML export. Default is False. Returns: None @@ -1320,6 +1352,7 @@ def run( simplify=False, opset=12, verbose=False, + mlmodel=False, workspace=4, nms=False, agnostic_nms=False, @@ -1383,9 +1416,9 @@ def run( if xml: # OpenVINO f[3], _ = export_openvino(file, metadata, half, int8, data) if coreml: # CoreML - f[4], ct_model = export_coreml(model, im, file, int8, half, nms) + f[4], ct_model = export_coreml(model, im, file, int8, half, nms, mlmodel) if nms: - pipeline_coreml(ct_model, im, file, model.names, y) + pipeline_coreml(ct_model, im, file, model.names, y, mlmodel) if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats assert not tflite or not tfjs, "TFLite and TF.js models must be exported separately, please pass only one type." assert not isinstance(model, ClassificationModel), "ClassificationModel export to TF formats not yet supported." @@ -1473,6 +1506,7 @@ def parse_opt(known=False): parser.add_argument("--per-tensor", action="store_true", help="TF per-tensor quantization") parser.add_argument("--dynamic", action="store_true", help="ONNX/TF/TensorRT: dynamic axes") parser.add_argument("--simplify", action="store_true", help="ONNX: simplify model") + parser.add_argument("--mlmodel", action="store_true", help="CoreML: Export in *.mlmodel format") parser.add_argument("--opset", type=int, default=17, help="ONNX: opset version") parser.add_argument("--verbose", action="store_true", help="TensorRT: verbose log") parser.add_argument("--workspace", type=int, default=4, help="TensorRT: workspace size (GB)") diff --git a/models/common.py b/models/common.py index 049dfc0b9..1e0ffdd3a 100644 --- a/models/common.py +++ b/models/common.py @@ -444,7 +444,7 @@ class DetectMultiBackend(nn.Module): # ONNX Runtime: *.onnx # ONNX OpenCV DNN: *.onnx --dnn # OpenVINO: *_openvino_model - # CoreML: *.mlmodel + # CoreML: *.mlpackage # TensorRT: *.engine # TensorFlow SavedModel: *_saved_model # TensorFlow GraphDef: *.pb diff --git a/val.py b/val.py index c4e1e402e..b8db6122f 100644 --- a/val.py +++ b/val.py @@ -11,7 +11,7 @@ Usage - formats: yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn yolov5s_openvino_model # OpenVINO yolov5s.engine # TensorRT - yolov5s.mlmodel # CoreML (macOS-only) + yolov5s.mlpackage # CoreML (macOS-only) yolov5s_saved_model # TensorFlow SavedModel yolov5s.pb # TensorFlow GraphDef yolov5s.tflite # TensorFlow Lite