mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
Add NMS to CoreML exports (#11361)
* Add NMS to CoreML exports * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
60e29e2d86
commit
a66fa8314c
150
export.py
150
export.py
@ -77,6 +77,25 @@ from utils.torch_utils import select_device, smart_inference_mode
|
|||||||
MACOS = platform.system() == 'Darwin' # macOS environment
|
MACOS = platform.system() == 'Darwin' # macOS environment
|
||||||
|
|
||||||
|
|
||||||
|
class iOSModel(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, model, im):
|
||||||
|
super().__init__()
|
||||||
|
b, c, h, w = im.shape # batch, channel, height, width
|
||||||
|
self.model = model
|
||||||
|
self.nc = model.nc # number of classes
|
||||||
|
if w == h:
|
||||||
|
self.normalize = 1. / w
|
||||||
|
else:
|
||||||
|
self.normalize = torch.tensor([1. / w, 1. / h, 1. / w, 1. / h]) # broadcast (slower, smaller)
|
||||||
|
# np = model(im)[0].shape[1] # number of points
|
||||||
|
# self.normalize = torch.tensor([1. / w, 1. / h, 1. / w, 1. / h]).expand(np, 4) # explicit (faster, larger)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
xywh, conf, cls = self.model(x)[0].squeeze().split((4, 1, self.nc), 1)
|
||||||
|
return cls * conf, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4)
|
||||||
|
|
||||||
|
|
||||||
def export_formats():
|
def export_formats():
|
||||||
# YOLOv5 export formats
|
# YOLOv5 export formats
|
||||||
x = [
|
x = [
|
||||||
@ -223,7 +242,7 @@ def export_paddle(model, im, file, metadata, prefix=colorstr('PaddlePaddle:')):
|
|||||||
|
|
||||||
|
|
||||||
@try_export
|
@try_export
|
||||||
def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
|
def export_coreml(model, im, file, int8, half, nms, prefix=colorstr('CoreML:')):
|
||||||
# YOLOv5 CoreML export
|
# YOLOv5 CoreML export
|
||||||
check_requirements('coremltools')
|
check_requirements('coremltools')
|
||||||
import coremltools as ct
|
import coremltools as ct
|
||||||
@ -231,6 +250,8 @@ def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
|
|||||||
LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
|
LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
|
||||||
f = file.with_suffix('.mlmodel')
|
f = file.with_suffix('.mlmodel')
|
||||||
|
|
||||||
|
if nms:
|
||||||
|
model = iOSModel(model, im)
|
||||||
ts = torch.jit.trace(model, im, strict=False) # TorchScript model
|
ts = torch.jit.trace(model, im, strict=False) # TorchScript model
|
||||||
ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
|
ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
|
||||||
bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None)
|
bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None)
|
||||||
@ -506,6 +527,129 @@ def add_tflite_metadata(file, metadata, num_outputs):
|
|||||||
tmp_file.unlink()
|
tmp_file.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def pipeline_coreml(model, im, file, names, y, prefix=colorstr('CoreML Pipeline:')):
|
||||||
|
# YOLOv5 CoreML pipeline
|
||||||
|
import coremltools as ct
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
print(f'{prefix} starting pipeline with coremltools {ct.__version__}...')
|
||||||
|
batch_size, ch, h, w = list(im.shape) # BCHW
|
||||||
|
t = time.time()
|
||||||
|
|
||||||
|
# Output shapes
|
||||||
|
spec = model.get_spec()
|
||||||
|
out0, out1 = iter(spec.description.output)
|
||||||
|
if platform.system() == 'Darwin':
|
||||||
|
img = Image.new('RGB', (w, h)) # img(192 width, 320 height)
|
||||||
|
# img = torch.zeros((*opt.img_size, 3)).numpy() # img size(320,192,3) iDetection
|
||||||
|
out = model.predict({'image': img})
|
||||||
|
out0_shape, out1_shape = out[out0.name].shape, out[out1.name].shape
|
||||||
|
else: # linux and windows can not run model.predict(), get sizes from pytorch output y
|
||||||
|
s = tuple(y[0].shape)
|
||||||
|
out0_shape, out1_shape = (s[1], s[2] - 5), (s[1], 4) # (3780, 80), (3780, 4)
|
||||||
|
|
||||||
|
# Checks
|
||||||
|
nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height
|
||||||
|
na, nc = out0_shape
|
||||||
|
# na, nc = out0.type.multiArrayType.shape # number anchors, classes
|
||||||
|
assert len(names) == nc, f'{len(names)} names found for nc={nc}' # check
|
||||||
|
|
||||||
|
# Define output shapes (missing)
|
||||||
|
out0.type.multiArrayType.shape[:] = out0_shape # (3780, 80)
|
||||||
|
out1.type.multiArrayType.shape[:] = out1_shape # (3780, 4)
|
||||||
|
# spec.neuralNetwork.preprocessing[0].featureName = '0'
|
||||||
|
|
||||||
|
# Flexible input shapes
|
||||||
|
# from coremltools.models.neural_network import flexible_shape_utils
|
||||||
|
# s = [] # shapes
|
||||||
|
# s.append(flexible_shape_utils.NeuralNetworkImageSize(320, 192))
|
||||||
|
# s.append(flexible_shape_utils.NeuralNetworkImageSize(640, 384)) # (height, width)
|
||||||
|
# flexible_shape_utils.add_enumerated_image_sizes(spec, feature_name='image', sizes=s)
|
||||||
|
# r = flexible_shape_utils.NeuralNetworkImageSizeRange() # shape ranges
|
||||||
|
# r.add_height_range((192, 640))
|
||||||
|
# r.add_width_range((192, 640))
|
||||||
|
# flexible_shape_utils.update_image_size_range(spec, feature_name='image', size_range=r)
|
||||||
|
|
||||||
|
# Print
|
||||||
|
print(spec.description)
|
||||||
|
|
||||||
|
# Model from spec
|
||||||
|
model = ct.models.MLModel(spec)
|
||||||
|
|
||||||
|
# 3. Create NMS protobuf
|
||||||
|
nms_spec = ct.proto.Model_pb2.Model()
|
||||||
|
nms_spec.specificationVersion = 5
|
||||||
|
for i in range(2):
|
||||||
|
decoder_output = model._spec.description.output[i].SerializeToString()
|
||||||
|
nms_spec.description.input.add()
|
||||||
|
nms_spec.description.input[i].ParseFromString(decoder_output)
|
||||||
|
nms_spec.description.output.add()
|
||||||
|
nms_spec.description.output[i].ParseFromString(decoder_output)
|
||||||
|
|
||||||
|
nms_spec.description.output[0].name = 'confidence'
|
||||||
|
nms_spec.description.output[1].name = 'coordinates'
|
||||||
|
|
||||||
|
output_sizes = [nc, 4]
|
||||||
|
for i in range(2):
|
||||||
|
ma_type = nms_spec.description.output[i].type.multiArrayType
|
||||||
|
ma_type.shapeRange.sizeRanges.add()
|
||||||
|
ma_type.shapeRange.sizeRanges[0].lowerBound = 0
|
||||||
|
ma_type.shapeRange.sizeRanges[0].upperBound = -1
|
||||||
|
ma_type.shapeRange.sizeRanges.add()
|
||||||
|
ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i]
|
||||||
|
ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i]
|
||||||
|
del ma_type.shape[:]
|
||||||
|
|
||||||
|
nms = nms_spec.nonMaximumSuppression
|
||||||
|
nms.confidenceInputFeatureName = out0.name # 1x507x80
|
||||||
|
nms.coordinatesInputFeatureName = out1.name # 1x507x4
|
||||||
|
nms.confidenceOutputFeatureName = 'confidence'
|
||||||
|
nms.coordinatesOutputFeatureName = 'coordinates'
|
||||||
|
nms.iouThresholdInputFeatureName = 'iouThreshold'
|
||||||
|
nms.confidenceThresholdInputFeatureName = 'confidenceThreshold'
|
||||||
|
nms.iouThreshold = 0.45
|
||||||
|
nms.confidenceThreshold = 0.25
|
||||||
|
nms.pickTop.perClass = True
|
||||||
|
nms.stringClassLabels.vector.extend(names.values())
|
||||||
|
nms_model = ct.models.MLModel(nms_spec)
|
||||||
|
|
||||||
|
# 4. Pipeline models together
|
||||||
|
pipeline = ct.models.pipeline.Pipeline(input_features=[('image', ct.models.datatypes.Array(3, ny, nx)),
|
||||||
|
('iouThreshold', ct.models.datatypes.Double()),
|
||||||
|
('confidenceThreshold', ct.models.datatypes.Double())],
|
||||||
|
output_features=['confidence', 'coordinates'])
|
||||||
|
pipeline.add_model(model)
|
||||||
|
pipeline.add_model(nms_model)
|
||||||
|
|
||||||
|
# Correct datatypes
|
||||||
|
pipeline.spec.description.input[0].ParseFromString(model._spec.description.input[0].SerializeToString())
|
||||||
|
pipeline.spec.description.output[0].ParseFromString(nms_model._spec.description.output[0].SerializeToString())
|
||||||
|
pipeline.spec.description.output[1].ParseFromString(nms_model._spec.description.output[1].SerializeToString())
|
||||||
|
|
||||||
|
# Update metadata
|
||||||
|
pipeline.spec.specificationVersion = 5
|
||||||
|
pipeline.spec.description.metadata.versionString = 'https://github.com/ultralytics/yolov5'
|
||||||
|
pipeline.spec.description.metadata.shortDescription = 'https://github.com/ultralytics/yolov5'
|
||||||
|
pipeline.spec.description.metadata.author = 'glenn.jocher@ultralytics.com'
|
||||||
|
pipeline.spec.description.metadata.license = 'https://github.com/ultralytics/yolov5/blob/master/LICENSE'
|
||||||
|
pipeline.spec.description.metadata.userDefined.update({
|
||||||
|
'classes': ','.join(names.values()),
|
||||||
|
'iou_threshold': str(nms.iouThreshold),
|
||||||
|
'confidence_threshold': str(nms.confidenceThreshold)})
|
||||||
|
|
||||||
|
# Save the model
|
||||||
|
f = file.with_suffix('.mlmodel') # filename
|
||||||
|
model = ct.models.MLModel(pipeline.spec)
|
||||||
|
model.input_description['image'] = 'Input image'
|
||||||
|
model.input_description['iouThreshold'] = f'(optional) IOU Threshold override (default: {nms.iouThreshold})'
|
||||||
|
model.input_description['confidenceThreshold'] = \
|
||||||
|
f'(optional) Confidence Threshold override (default: {nms.confidenceThreshold})'
|
||||||
|
model.output_description['confidence'] = 'Boxes × Class confidence (see user-defined metadata "classes")'
|
||||||
|
model.output_description['coordinates'] = 'Boxes × [x, y, width, height] (relative to image size)'
|
||||||
|
model.save(f) # pipelined
|
||||||
|
print(f'{prefix} pipeline success ({time.time() - t:.2f}s), saved as {f} ({file_size(f):.1f} MB)')
|
||||||
|
|
||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
def run(
|
def run(
|
||||||
data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
|
data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
|
||||||
@ -584,7 +728,9 @@ def run(
|
|||||||
if xml: # OpenVINO
|
if xml: # OpenVINO
|
||||||
f[3], _ = export_openvino(file, metadata, half)
|
f[3], _ = export_openvino(file, metadata, half)
|
||||||
if coreml: # CoreML
|
if coreml: # CoreML
|
||||||
f[4], _ = export_coreml(model, im, file, int8, half)
|
f[4], ct_model = export_coreml(model, im, file, int8, half, nms)
|
||||||
|
if nms:
|
||||||
|
pipeline_coreml(ct_model, im, file, model.names, y)
|
||||||
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
|
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
|
||||||
assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
|
assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
|
||||||
assert not isinstance(model, ClassificationModel), 'ClassificationModel export to TF formats not yet supported.'
|
assert not isinstance(model, ClassificationModel), 'ClassificationModel export to TF formats not yet supported.'
|
||||||
|
Loading…
x
Reference in New Issue
Block a user