Add TFLite Metadata to TFLite and Edge TPU models (#9903)

* added embedded meta data to tflite models

* added try block for inference

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactored tfite meta data into separate function

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Creat tmp file in /tmp

* Update export.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update export.py

* Update export.py

* Update export.py

* Update export.py

* Update common.py

* Update export.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update common.py

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
paradigm 2022-10-25 17:53:22 +02:00 committed by GitHub
parent fba61e5583
commit 54f49fa581
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 46 additions and 2 deletions

View File

@ -45,6 +45,7 @@ TensorFlow.js:
""" """
import argparse import argparse
import contextlib
import json import json
import os import os
import platform import platform
@ -453,6 +454,39 @@ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
return f, None return f, None
def add_tflite_metadata(file, metadata, num_outputs):
# Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata
with contextlib.suppress(ImportError):
# check_requirements('tflite_support')
from tflite_support import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb
tmp_file = Path('/tmp/meta.txt')
with open(tmp_file, 'w') as meta_f:
meta_f.write(str(metadata))
model_meta = _metadata_fb.ModelMetadataT()
label_file = _metadata_fb.AssociatedFileT()
label_file.name = tmp_file.name
model_meta.associatedFiles = [label_file]
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * num_outputs
model_meta.subgraphMetadata = [subgraph]
b = flatbuffers.Builder(0)
b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()
populator = _metadata.MetadataPopulator.with_model_file(file)
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files([str(tmp_file)])
populator.populate()
tmp_file.unlink()
@smart_inference_mode() @smart_inference_mode()
def run( def run(
data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
@ -550,8 +584,9 @@ def run(
f[6], _ = export_pb(s_model, file) f[6], _ = export_pb(s_model, file)
if tflite or edgetpu: if tflite or edgetpu:
f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms) f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
if edgetpu: if edgetpu:
f[8], _ = export_edgetpu(file) f[8], _ = export_edgetpu(file)
add_tflite_metadata(f[8] or f[7], metadata, num_outputs=len(s_model.outputs))
if tfjs: if tfjs:
f[9], _ = export_tfjs(file) f[9], _ = export_tfjs(file)
if paddle: # PaddlePaddle if paddle: # PaddlePaddle

View File

@ -3,10 +3,13 @@
Common modules Common modules
""" """
import ast
import contextlib
import json import json
import math import math
import platform import platform
import warnings import warnings
import zipfile
from collections import OrderedDict, namedtuple from collections import OrderedDict, namedtuple
from copy import copy from copy import copy
from pathlib import Path from pathlib import Path
@ -462,6 +465,12 @@ class DetectMultiBackend(nn.Module):
interpreter.allocate_tensors() # allocate interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs output_details = interpreter.get_output_details() # outputs
# load metadata
with contextlib.suppress(zipfile.BadZipFile):
with zipfile.ZipFile(w, "r") as model:
meta_file = model.namelist()[0]
meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
stride, names = int(meta['stride']), meta['names']
elif tfjs: # TF.js elif tfjs: # TF.js
raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported') raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported')
elif paddle: # PaddlePaddle elif paddle: # PaddlePaddle