Fix TFLite Segment inference

pull/13488/head
Mohammed Yasin 2025-01-10 23:42:20 -06:00
parent 86fd1ab270
commit 53bc3a0ac4
1 changed files with 3 additions and 0 deletions

View File

@ -750,6 +750,9 @@ class DetectMultiBackend(nn.Module):
scale, zero_point = output["quantization"]
x = (x.astype(np.float32) - zero_point) * scale # re-scale
y.append(x)
if len(y) == 2: # segment with (det, proto) output order reversed
if len(y[1].shape) != 4:
y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32)
y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels