@@ -169,7 +169,7 @@ def export_formats():
         ["ONNX", "onnx", ".onnx", True, True],
         ["OpenVINO", "openvino", "_openvino_model", True, False],
         ["TensorRT", "engine", ".engine", False, True],
-        ["CoreML", "coreml", ".mlmodel", True, False],
+        ["CoreML", "coreml", ".mlpackage", True, False],
         ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True],
         ["TensorFlow GraphDef", "pb", ".pb", True, True],
         ["TensorFlow Lite", "tflite", ".tflite", True, False],
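Note for reviewers: the `Suffix` column above is what the rest of the toolchain keys on when locating exported artifacts, so the default CoreML output now resolves to `*.mlpackage`. A quick, illustrative check of the updated row (assuming `export_formats()` still returns a pandas DataFrame with `Format`/`Suffix` columns, as in the current file):

```python
# Illustrative only: inspect the updated CoreML row of the export-formats table.
from export import export_formats

fmts = export_formats()
print(fmts[fmts["Format"] == "CoreML"])  # Suffix now reads ".mlpackage" instead of ".mlmodel"
```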
@@ -520,7 +520,7 @@ def export_paddle(model, im, file, metadata, prefix=colorstr("PaddlePaddle:")):
 
 
 @try_export
-def export_coreml(model, im, file, int8, half, nms, prefix=colorstr("CoreML:")):
+def export_coreml(model, im, file, int8, half, nms, mlmodel, prefix=colorstr("CoreML:")):
     """
     Export a YOLOv5 model to CoreML format with optional NMS, INT8, and FP16 support.
 
@@ -531,6 +531,7 @@ def export_coreml(model, im, file, int8, half, nms, prefix=colorstr("CoreML:")):
         int8 (bool): Flag indicating whether to use INT8 quantization (default is False).
         half (bool): Flag indicating whether to use FP16 quantization (default is False).
         nms (bool): Flag indicating whether to include Non-Maximum Suppression (default is False).
+        mlmodel (bool): Flag indicating whether to export as older *.mlmodel format (default is False).
         prefix (str): Prefix string for logging purposes (default is 'CoreML:').
 
     Returns:
@@ -548,27 +549,46 @@ def export_coreml(model, im, file, int8, half, nms, prefix=colorstr("CoreML:")):
         model = Model(cfg, ch=3, nc=80)
         im = torch.randn(1, 3, 640, 640)
         file = Path("yolov5s_coreml")
-        export_coreml(model, im, file, int8=False, half=False, nms=True)
+        export_coreml(model, im, file, int8=False, half=False, nms=True, mlmodel=False)
         ```
     """
     check_requirements("coremltools")
     import coremltools as ct
 
     LOGGER.info(f"\n{prefix} starting export with coremltools {ct.__version__}...")
-    f = file.with_suffix(".mlmodel")
+    if mlmodel:
+        f = file.with_suffix(".mlmodel")
+        convert_to = "neuralnetwork"
+        precision = None
+    else:
+        f = file.with_suffix(".mlpackage")
+        convert_to = "mlprogram"
+        if half:
+            precision = ct.precision.FLOAT16
+        else:
+            precision = ct.precision.FLOAT32
 
     if nms:
         model = iOSModel(model, im)
     ts = torch.jit.trace(model, im, strict=False)  # TorchScript model
-    ct_model = ct.convert(ts, inputs=[ct.ImageType("image", shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
-    bits, mode = (8, "kmeans_lut") if int8 else (16, "linear") if half else (32, None)
+    ct_model = ct.convert(
+        ts,
+        inputs=[ct.ImageType("image", shape=im.shape, scale=1 / 255, bias=[0, 0, 0])],
+        convert_to=convert_to,
+        compute_precision=precision,
+    )
+    bits, mode = (8, "kmeans") if int8 else (16, "linear") if half else (32, None)
     if bits < 32:
-        if MACOS:  # quantization only supported on macOS
+        if mlmodel:
             with warnings.catch_warnings():
-                warnings.filterwarnings("ignore", category=DeprecationWarning)  # suppress numpy==1.20 float warning
+                warnings.filterwarnings(
+                    "ignore", category=DeprecationWarning
+                )  # suppress numpy==1.20 float warning, fixed in coremltools==7.0
                 ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
-        else:
-            print(f"{prefix} quantization only supported on macOS, skipping...")
+        elif bits == 8:
+            op_config = ct.optimize.coreml.OpPalettizerConfig(mode=mode, nbits=bits, weight_threshold=512)
+            config = ct.optimize.coreml.OptimizationConfig(global_config=op_config)
+            ct_model = ct.optimize.coreml.palettize_weights(ct_model, config)
     ct_model.save(f)
     return f, ct_model
 
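The net effect of this hunk: the legacy path (`--mlmodel`) keeps the old `neuralnetwork` backend plus `quantize_weights`, while the new default targets an ML Program (`*.mlpackage`), fixes compute precision at convert time, and uses the coremltools 7 `ct.optimize.coreml` palettization API for INT8. A standalone sketch of the two new calls, using a toy Conv2d in place of a traced YOLOv5 model (illustrative only, assumes coremltools>=7.0):

```python
# Standalone sketch of the two coremltools calls added above (assumes coremltools>=7.0).
# The toy Conv2d stands in for a traced YOLOv5 model; names here are illustrative only.
import coremltools as ct
import torch

model = torch.nn.Conv2d(3, 32, 3).eval()
im = torch.zeros(1, 3, 640, 640)
ts = torch.jit.trace(model, im, strict=False)  # TorchScript trace, as in export_coreml()

# New default path: ML Program (*.mlpackage) with compute precision fixed at convert time.
ct_model = ct.convert(
    ts,
    inputs=[ct.ImageType("image", shape=im.shape, scale=1 / 255, bias=[0, 0, 0])],
    convert_to="mlprogram",
    compute_precision=ct.precision.FLOAT16,
)

# INT8 path for ML Programs: k-means weight palettization via the coremltools 7 optimize API
# (replaces quantize_weights, which only applies to the legacy neuralnetwork backend).
op_config = ct.optimize.coreml.OpPalettizerConfig(mode="kmeans", nbits=8, weight_threshold=512)
config = ct.optimize.coreml.OptimizationConfig(global_config=op_config)
ct_model = ct.optimize.coreml.palettize_weights(ct_model, config)
ct_model.save("toy.mlpackage")
```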
@@ -1070,7 +1090,7 @@ def add_tflite_metadata(file, metadata, num_outputs):
         tmp_file.unlink()
 
 
-def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline:")):
+def pipeline_coreml(model, im, file, names, y, mlmodel, prefix=colorstr("CoreML Pipeline:")):
     """
     Convert a PyTorch YOLOv5 model to CoreML format with Non-Maximum Suppression (NMS), handling different input/output
     shapes, and saving the model.
@@ -1082,6 +1102,7 @@ def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline:
         file (Path): Path to save the converted CoreML model.
         names (dict[int, str]): Dictionary mapping class indices to class names.
         y (torch.Tensor): Output tensor from the PyTorch model's forward pass.
+        mlmodel (bool): Flag indicating whether to export as older *.mlmodel format (default is False).
         prefix (str): Custom prefix for logging messages.
 
     Returns:
@@ -1114,6 +1135,11 @@ def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline:
     import coremltools as ct
     from PIL import Image
 
+    if mlmodel:
+        f = file.with_suffix(".mlmodel")  # filename
+    else:
+        f = file.with_suffix(".mlpackage")  # filename
+
     print(f"{prefix} starting pipeline with coremltools {ct.__version__}...")
     batch_size, ch, h, w = list(im.shape)  # BCHW
     t = time.time()
@@ -1156,7 +1182,12 @@ def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline:
     print(spec.description)
 
     # Model from spec
-    model = ct.models.MLModel(spec)
+    weights_dir = None
+    if mlmodel:
+        weights_dir = None
+    else:
+        weights_dir = str(f / "Data/com.apple.CoreML/weights")
+    model = ct.models.MLModel(spec, weights_dir=weights_dir)
 
     # 3. Create NMS protobuf
     nms_spec = ct.proto.Model_pb2.Model()
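Background on the `weights_dir` handling above: a `neuralnetwork` spec embeds its weights, so `MLModel(spec)` was enough, but an ML Program stores its weights as a blob inside the `.mlpackage`, and rebuilding an `MLModel` from the spec needs that directory passed back. A minimal sketch of the same round trip outside export.py (toy model, illustrative paths, assumes coremltools>=7.0):

```python
# Minimal round trip showing why weights_dir must be passed back for *.mlpackage models
# (toy model and paths are illustrative; assumes coremltools>=7.0).
import coremltools as ct
import torch

ts = torch.jit.trace(torch.nn.Conv2d(3, 32, 3).eval(), torch.zeros(1, 3, 64, 64))
ct_model = ct.convert(ts, inputs=[ct.TensorType(shape=(1, 3, 64, 64))], convert_to="mlprogram")
ct_model.save("toy.mlpackage")

spec = ct_model.get_spec()
# neuralnetwork (*.mlmodel): weights are embedded in the spec, so MLModel(spec) is enough.
# mlprogram (*.mlpackage): weights live in a blob store inside the package, so the
# directory must be supplied when rebuilding the model, matching the hunk above.
weights_dir = "toy.mlpackage/Data/com.apple.CoreML/weights"
rebuilt = ct.models.MLModel(spec, weights_dir=weights_dir)
```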
@@ -1227,8 +1258,7 @@ def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline:
     )
 
     # Save the model
-    f = file.with_suffix(".mlmodel")  # filename
-    model = ct.models.MLModel(pipeline.spec)
+    model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir)
     model.input_description["image"] = "Input image"
     model.input_description["iouThreshold"] = f"(optional) IOU Threshold override (default: {nms.iouThreshold})"
     model.input_description["confidenceThreshold"] = (
@@ -1256,6 +1286,7 @@ def run(
     per_tensor=False,  # TF per tensor quantization
     dynamic=False,  # ONNX/TF/TensorRT: dynamic axes
     simplify=False,  # ONNX: simplify model
+    mlmodel=False,  # CoreML: Export in *.mlmodel format
     opset=12,  # ONNX: opset version
     verbose=False,  # TensorRT: verbose log
     workspace=4,  # TensorRT: workspace size (GB)
@@ -1293,6 +1324,7 @@ def run(
         topk_all (int): Top-K boxes for all classes to keep for TensorFlow.js NMS. Default is 100.
         iou_thres (float): IoU threshold for NMS. Default is 0.45.
         conf_thres (float): Confidence threshold for NMS. Default is 0.25.
+        mlmodel (bool): Flag to use *.mlmodel for CoreML export. Default is False.
 
     Returns:
         None
@@ -1320,6 +1352,7 @@ def run(
             simplify=False,
             opset=12,
             verbose=False,
+            mlmodel=False,
             workspace=4,
             nms=False,
             agnostic_nms=False,
@@ -1383,9 +1416,9 @@ def run(
     if xml:  # OpenVINO
         f[3], _ = export_openvino(file, metadata, half, int8, data)
     if coreml:  # CoreML
-        f[4], ct_model = export_coreml(model, im, file, int8, half, nms)
+        f[4], ct_model = export_coreml(model, im, file, int8, half, nms, mlmodel)
         if nms:
-            pipeline_coreml(ct_model, im, file, model.names, y)
+            pipeline_coreml(ct_model, im, file, model.names, y, mlmodel)
     if any((saved_model, pb, tflite, edgetpu, tfjs)):  # TensorFlow formats
         assert not tflite or not tfjs, "TFLite and TF.js models must be exported separately, please pass only one type."
         assert not isinstance(model, ClassificationModel), "ClassificationModel export to TF formats not yet supported."
@@ -1473,6 +1506,7 @@ def parse_opt(known=False):
     parser.add_argument("--per-tensor", action="store_true", help="TF per-tensor quantization")
     parser.add_argument("--dynamic", action="store_true", help="ONNX/TF/TensorRT: dynamic axes")
    parser.add_argument("--simplify", action="store_true", help="ONNX: simplify model")
+    parser.add_argument("--mlmodel", action="store_true", help="CoreML: Export in *.mlmodel format")
     parser.add_argument("--opset", type=int, default=17, help="ONNX: opset version")
     parser.add_argument("--verbose", action="store_true", help="TensorRT: verbose log")
     parser.add_argument("--workspace", type=int, default=4, help="TensorRT: workspace size (GB)")
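Usage after this change: `python export.py --weights yolov5s.pt --include coreml` now produces `yolov5s.mlpackage`, and adding `--mlmodel` restores the legacy `yolov5s.mlmodel` output. The same from the Python API (sketch, assumes a local checkout with this patch applied and a `yolov5s.pt` checkpoint on disk):

```python
# Hedged usage sketch of the new flag via the Python API.
from export import run

run(weights="yolov5s.pt", include=("coreml",))                # exports yolov5s.mlpackage (default)
run(weights="yolov5s.pt", include=("coreml",), mlmodel=True)  # exports legacy yolov5s.mlmodel
```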