mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
Add TFLite Metadata to TFLite and Edge TPU models (#9903)
* added embedded meta data to tflite models * added try block for inference * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactored tfite meta data into separate function * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Creat tmp file in /tmp * Update export.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update export.py * Update export.py * Update export.py * Update export.py * Update common.py * Update export.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update common.py Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
fba61e5583
commit
54f49fa581
39
export.py
39
export.py
@ -45,6 +45,7 @@ TensorFlow.js:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import contextlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
@ -453,6 +454,39 @@ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
|
|||||||
return f, None
|
return f, None
|
||||||
|
|
||||||
|
|
||||||
|
def add_tflite_metadata(file, metadata, num_outputs):
|
||||||
|
# Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata
|
||||||
|
with contextlib.suppress(ImportError):
|
||||||
|
# check_requirements('tflite_support')
|
||||||
|
from tflite_support import flatbuffers
|
||||||
|
from tflite_support import metadata as _metadata
|
||||||
|
from tflite_support import metadata_schema_py_generated as _metadata_fb
|
||||||
|
|
||||||
|
tmp_file = Path('/tmp/meta.txt')
|
||||||
|
with open(tmp_file, 'w') as meta_f:
|
||||||
|
meta_f.write(str(metadata))
|
||||||
|
|
||||||
|
model_meta = _metadata_fb.ModelMetadataT()
|
||||||
|
label_file = _metadata_fb.AssociatedFileT()
|
||||||
|
label_file.name = tmp_file.name
|
||||||
|
model_meta.associatedFiles = [label_file]
|
||||||
|
|
||||||
|
subgraph = _metadata_fb.SubGraphMetadataT()
|
||||||
|
subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
|
||||||
|
subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * num_outputs
|
||||||
|
model_meta.subgraphMetadata = [subgraph]
|
||||||
|
|
||||||
|
b = flatbuffers.Builder(0)
|
||||||
|
b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
||||||
|
metadata_buf = b.Output()
|
||||||
|
|
||||||
|
populator = _metadata.MetadataPopulator.with_model_file(file)
|
||||||
|
populator.load_metadata_buffer(metadata_buf)
|
||||||
|
populator.load_associated_files([str(tmp_file)])
|
||||||
|
populator.populate()
|
||||||
|
tmp_file.unlink()
|
||||||
|
|
||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
def run(
|
def run(
|
||||||
data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
|
data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
|
||||||
@ -550,8 +584,9 @@ def run(
|
|||||||
f[6], _ = export_pb(s_model, file)
|
f[6], _ = export_pb(s_model, file)
|
||||||
if tflite or edgetpu:
|
if tflite or edgetpu:
|
||||||
f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
|
f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
|
||||||
if edgetpu:
|
if edgetpu:
|
||||||
f[8], _ = export_edgetpu(file)
|
f[8], _ = export_edgetpu(file)
|
||||||
|
add_tflite_metadata(f[8] or f[7], metadata, num_outputs=len(s_model.outputs))
|
||||||
if tfjs:
|
if tfjs:
|
||||||
f[9], _ = export_tfjs(file)
|
f[9], _ = export_tfjs(file)
|
||||||
if paddle: # PaddlePaddle
|
if paddle: # PaddlePaddle
|
||||||
|
@ -3,10 +3,13 @@
|
|||||||
Common modules
|
Common modules
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import contextlib
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import platform
|
import platform
|
||||||
import warnings
|
import warnings
|
||||||
|
import zipfile
|
||||||
from collections import OrderedDict, namedtuple
|
from collections import OrderedDict, namedtuple
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -462,6 +465,12 @@ class DetectMultiBackend(nn.Module):
|
|||||||
interpreter.allocate_tensors() # allocate
|
interpreter.allocate_tensors() # allocate
|
||||||
input_details = interpreter.get_input_details() # inputs
|
input_details = interpreter.get_input_details() # inputs
|
||||||
output_details = interpreter.get_output_details() # outputs
|
output_details = interpreter.get_output_details() # outputs
|
||||||
|
# load metadata
|
||||||
|
with contextlib.suppress(zipfile.BadZipFile):
|
||||||
|
with zipfile.ZipFile(w, "r") as model:
|
||||||
|
meta_file = model.namelist()[0]
|
||||||
|
meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
|
||||||
|
stride, names = int(meta['stride']), meta['names']
|
||||||
elif tfjs: # TF.js
|
elif tfjs: # TF.js
|
||||||
raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported')
|
raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported')
|
||||||
elif paddle: # PaddlePaddle
|
elif paddle: # PaddlePaddle
|
||||||
|
Loading…
x
Reference in New Issue
Block a user