Edge TPU inference fix (#6686)
* refactor: use edgetpu flag * fix: remove bitwise and assignation to tflite * Cleanup and fix tflite * Cleanup Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/6705/head
parent
0365379016
commit
a297efc383
|
@ -279,17 +279,17 @@ class DetectMultiBackend(nn.Module):
|
|||
# YOLOv5 MultiBackend class for python inference on various backends
|
||||
def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
|
||||
# Usage:
|
||||
# PyTorch: weights = *.pt
|
||||
# TorchScript: *.torchscript
|
||||
# CoreML: *.mlmodel
|
||||
# OpenVINO: *.xml
|
||||
# TensorFlow: *_saved_model
|
||||
# TensorFlow: *.pb
|
||||
# TensorFlow Lite: *.tflite
|
||||
# TensorFlow Edge TPU: *_edgetpu.tflite
|
||||
# ONNX Runtime: *.onnx
|
||||
# OpenCV DNN: *.onnx with dnn=True
|
||||
# TensorRT: *.engine
|
||||
# PyTorch: weights = *.pt
|
||||
# TorchScript: *.torchscript
|
||||
# ONNX Runtime: *.onnx
|
||||
# ONNX OpenCV DNN: *.onnx with --dnn
|
||||
# OpenVINO: *.xml
|
||||
# CoreML: *.mlmodel
|
||||
# TensorRT: *.engine
|
||||
# TensorFlow SavedModel: *_saved_model
|
||||
# TensorFlow GraphDef: *.pb
|
||||
# TensorFlow Lite: *.tflite
|
||||
# TensorFlow Edge TPU: *_edgetpu.tflite
|
||||
from models.experimental import attempt_download, attempt_load # scoped to avoid circular import
|
||||
|
||||
super().__init__()
|
||||
|
@ -367,19 +367,19 @@ class DetectMultiBackend(nn.Module):
|
|||
|
||||
def wrap_frozen_graph(gd, inputs, outputs):
|
||||
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
|
||||
return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
|
||||
tf.nest.map_structure(x.graph.as_graph_element, outputs))
|
||||
ge = x.graph.as_graph_element
|
||||
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
|
||||
|
||||
graph_def = tf.Graph().as_graph_def()
|
||||
graph_def.ParseFromString(open(w, 'rb').read())
|
||||
frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
|
||||
elif tflite: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
||||
gd = tf.Graph().as_graph_def() # graph_def
|
||||
gd.ParseFromString(open(w, 'rb').read())
|
||||
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs="Identity:0")
|
||||
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
||||
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
|
||||
from tflite_runtime.interpreter import Interpreter, load_delegate
|
||||
except ImportError:
|
||||
import tensorflow as tf
|
||||
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
|
||||
if 'edgetpu' in w.lower(): # Edge TPU https://coral.ai/software/#edgetpu-runtime
|
||||
if edgetpu: # Edge TPU https://coral.ai/software/#edgetpu-runtime
|
||||
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
|
||||
delegate = {'Linux': 'libedgetpu.so.1',
|
||||
'Darwin': 'libedgetpu.1.dylib',
|
||||
|
@ -391,6 +391,8 @@ class DetectMultiBackend(nn.Module):
|
|||
interpreter.allocate_tensors() # allocate
|
||||
input_details = interpreter.get_input_details() # inputs
|
||||
output_details = interpreter.get_output_details() # outputs
|
||||
elif tfjs:
|
||||
raise Exception('ERROR: YOLOv5 TF.js inference is not supported')
|
||||
self.__dict__.update(locals()) # assign all variables to self
|
||||
|
||||
def forward(self, im, augment=False, visualize=False, val=False):
|
||||
|
@ -436,7 +438,7 @@ class DetectMultiBackend(nn.Module):
|
|||
y = (self.model(im, training=False) if self.keras else self.model(im)[0]).numpy()
|
||||
elif self.pb: # GraphDef
|
||||
y = self.frozen_func(x=self.tf.constant(im)).numpy()
|
||||
elif self.tflite: # Lite
|
||||
else: # Lite or Edge TPU
|
||||
input, output = self.input_details[0], self.output_details[0]
|
||||
int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
|
||||
if int8:
|
||||
|
|
Loading…
Reference in New Issue