Edge TPU inference fix (#6686)

* refactor: use edgetpu flag

* fix: remove bitwise and assignation to tflite

* Cleanup and fix tflite

* Cleanup

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/6705/head
Raffaele Galliera 2022-02-19 08:10:07 -06:00 committed by GitHub
parent 0365379016
commit a297efc383
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 19 deletions

View File

@ -279,17 +279,17 @@ class DetectMultiBackend(nn.Module):
# YOLOv5 MultiBackend class for python inference on various backends
def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
# Usage:
# PyTorch: weights = *.pt
# TorchScript: *.torchscript
# CoreML: *.mlmodel
# OpenVINO: *.xml
# TensorFlow: *_saved_model
# TensorFlow: *.pb
# TensorFlow Lite: *.tflite
# TensorFlow Edge TPU: *_edgetpu.tflite
# ONNX Runtime: *.onnx
# OpenCV DNN: *.onnx with dnn=True
# TensorRT: *.engine
# PyTorch: weights = *.pt
# TorchScript: *.torchscript
# ONNX Runtime: *.onnx
# ONNX OpenCV DNN: *.onnx with --dnn
# OpenVINO: *.xml
# CoreML: *.mlmodel
# TensorRT: *.engine
# TensorFlow SavedModel: *_saved_model
# TensorFlow GraphDef: *.pb
# TensorFlow Lite: *.tflite
# TensorFlow Edge TPU: *_edgetpu.tflite
from models.experimental import attempt_download, attempt_load # scoped to avoid circular import
super().__init__()
@ -367,19 +367,19 @@ class DetectMultiBackend(nn.Module):
def wrap_frozen_graph(gd, inputs, outputs):
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
tf.nest.map_structure(x.graph.as_graph_element, outputs))
ge = x.graph.as_graph_element
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
graph_def = tf.Graph().as_graph_def()
graph_def.ParseFromString(open(w, 'rb').read())
frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
elif tflite: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
gd = tf.Graph().as_graph_def() # graph_def
gd.ParseFromString(open(w, 'rb').read())
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs="Identity:0")
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
from tflite_runtime.interpreter import Interpreter, load_delegate
except ImportError:
import tensorflow as tf
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
if 'edgetpu' in w.lower(): # Edge TPU https://coral.ai/software/#edgetpu-runtime
if edgetpu: # Edge TPU https://coral.ai/software/#edgetpu-runtime
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
delegate = {'Linux': 'libedgetpu.so.1',
'Darwin': 'libedgetpu.1.dylib',
@ -391,6 +391,8 @@ class DetectMultiBackend(nn.Module):
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
elif tfjs:
raise Exception('ERROR: YOLOv5 TF.js inference is not supported')
self.__dict__.update(locals()) # assign all variables to self
def forward(self, im, augment=False, visualize=False, val=False):
@ -436,7 +438,7 @@ class DetectMultiBackend(nn.Module):
y = (self.model(im, training=False) if self.keras else self.model(im)[0]).numpy()
elif self.pb: # GraphDef
y = self.frozen_func(x=self.tf.constant(im)).numpy()
elif self.tflite: # Lite
else: # Lite or Edge TPU
input, output = self.input_details[0], self.output_details[0]
int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
if int8: