TensorFlow SegmentationModel support (#9472)
* TensorFlow SegmentationModel support * TensorFlow SegmentationModel support * TensorFlow SegmentationModel support * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * TFLite fixes * GraphDef fixes * Update ci-testing.yml Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/9491/head
parent
120e27e38e
commit
fda8aa551d
|
@ -43,7 +43,7 @@ jobs:
|
|||
python benchmarks.py --data coco128.yaml --weights ${{ matrix.model }}.pt --img 320 --hard-fail 0.29
|
||||
- name: Benchmark SegmentationModel
|
||||
run: |
|
||||
python benchmarks.py --data coco128-seg.yaml --weights ${{ matrix.model }}-seg.pt --img 320
|
||||
python benchmarks.py --data coco128-seg.yaml --weights ${{ matrix.model }}-seg.pt --img 320 --hard-fail 0.22
|
||||
|
||||
Tests:
|
||||
timeout-minutes: 60
|
||||
|
|
|
@ -341,7 +341,7 @@ def export_saved_model(model,
|
|||
m = m.get_concrete_function(spec)
|
||||
frozen_func = convert_variables_to_constants_v2(m)
|
||||
tfm = tf.Module()
|
||||
tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x)[0], [spec])
|
||||
tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x), [spec])
|
||||
tfm.__call__(im)
|
||||
tf.saved_model.save(tfm,
|
||||
f,
|
||||
|
|
|
@ -427,10 +427,17 @@ class DetectMultiBackend(nn.Module):
|
|||
ge = x.graph.as_graph_element
|
||||
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
|
||||
|
||||
def gd_outputs(gd):
|
||||
name_list, input_list = [], []
|
||||
for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
|
||||
name_list.append(node.name)
|
||||
input_list.extend(node.input)
|
||||
return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))
|
||||
|
||||
gd = tf.Graph().as_graph_def() # TF GraphDef
|
||||
with open(w, 'rb') as f:
|
||||
gd.ParseFromString(f.read())
|
||||
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs="Identity:0")
|
||||
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
|
||||
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
||||
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
|
||||
from tflite_runtime.interpreter import Interpreter, load_delegate
|
||||
|
@ -528,22 +535,26 @@ class DetectMultiBackend(nn.Module):
|
|||
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
|
||||
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
|
||||
if self.saved_model: # SavedModel
|
||||
y = (self.model(im, training=False) if self.keras else self.model(im)).numpy()
|
||||
y = self.model(im, training=False) if self.keras else self.model(im)
|
||||
elif self.pb: # GraphDef
|
||||
y = self.frozen_func(x=self.tf.constant(im)).numpy()
|
||||
y = self.frozen_func(x=self.tf.constant(im))
|
||||
else: # Lite or Edge TPU
|
||||
input, output = self.input_details[0], self.output_details[0]
|
||||
input = self.input_details[0]
|
||||
int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
|
||||
if int8:
|
||||
scale, zero_point = input['quantization']
|
||||
im = (im / scale + zero_point).astype(np.uint8) # de-scale
|
||||
self.interpreter.set_tensor(input['index'], im)
|
||||
self.interpreter.invoke()
|
||||
y = self.interpreter.get_tensor(output['index'])
|
||||
if int8:
|
||||
scale, zero_point = output['quantization']
|
||||
y = (y.astype(np.float32) - zero_point) * scale # re-scale
|
||||
y[..., :4] *= [w, h, w, h] # xywh normalized to pixels
|
||||
y = []
|
||||
for output in self.output_details:
|
||||
x = self.interpreter.get_tensor(output['index'])
|
||||
if int8:
|
||||
scale, zero_point = output['quantization']
|
||||
x = (x.astype(np.float32) - zero_point) * scale # re-scale
|
||||
y.append(x)
|
||||
y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
|
||||
y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels
|
||||
|
||||
if isinstance(y, (list, tuple)):
|
||||
return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
|
||||
|
|
15
models/tf.py
15
models/tf.py
|
@ -299,15 +299,15 @@ class TFDetect(keras.layers.Layer):
|
|||
x[i] = tf.reshape(x[i], [-1, ny * nx, self.na, self.no])
|
||||
|
||||
if not self.training: # inference
|
||||
y = tf.sigmoid(x[i])
|
||||
y = x[i]
|
||||
grid = tf.transpose(self.grid[i], [0, 2, 1, 3]) - 0.5
|
||||
anchor_grid = tf.transpose(self.anchor_grid[i], [0, 2, 1, 3]) * 4
|
||||
xy = (y[..., 0:2] * 2 + grid) * self.stride[i] # xy
|
||||
wh = y[..., 2:4] ** 2 * anchor_grid
|
||||
xy = (tf.sigmoid(y[..., 0:2]) * 2 + grid) * self.stride[i] # xy
|
||||
wh = tf.sigmoid(y[..., 2:4]) ** 2 * anchor_grid
|
||||
# Normalize xywh to 0-1 to reduce calibration error
|
||||
xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
|
||||
wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
|
||||
y = tf.concat([xy, wh, y[..., 4:]], -1)
|
||||
y = tf.concat([xy, wh, tf.sigmoid(y[..., 4:5 + self.nc]), y[..., 5 + self.nc:]], -1)
|
||||
z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no]))
|
||||
|
||||
return tf.transpose(x, [0, 2, 1, 3]) if self.training else (tf.concat(z, 1), x)
|
||||
|
@ -333,8 +333,9 @@ class TFSegment(TFDetect):
|
|||
|
||||
def call(self, x):
|
||||
p = self.proto(x[0])
|
||||
p = tf.transpose(p, [0, 3, 1, 2]) # from shape(1,160,160,32) to shape(1,32,160,160)
|
||||
x = self.detect(self, x)
|
||||
return (x, p) if self.training else ((x[0], p),)
|
||||
return (x, p) if self.training else (x[0], p)
|
||||
|
||||
|
||||
class TFProto(keras.layers.Layer):
|
||||
|
@ -485,8 +486,8 @@ class TFModel:
|
|||
conf_thres,
|
||||
clip_boxes=False)
|
||||
return nms, x[1]
|
||||
return x[0] # output only first tensor [1,6300,85] = [xywh, conf, class0, class1, ...]
|
||||
# x = x[0][0] # [x(1,6300,85), ...] to x(6300,85)
|
||||
return x # output [1,6300,85] = [xywh, conf, class0, class1, ...]
|
||||
# x = x[0] # [x(1,6300,85), ...] to x(6300,85)
|
||||
# xywh = x[..., :4] # x(6300,4) boxes
|
||||
# conf = x[..., 4:5] # x(6300,1) confidences
|
||||
# cls = tf.reshape(tf.cast(tf.argmax(x[..., 5:], axis=1), tf.float32), (-1, 1)) # x(6300,1) classes
|
||||
|
|
Loading…
Reference in New Issue