diff --git a/models/common.py b/models/common.py index f7192219c..ce8467ee3 100644 --- a/models/common.py +++ b/models/common.py @@ -749,6 +749,9 @@ class DetectMultiBackend(nn.Module): scale, zero_point = output["quantization"] x = (x.astype(np.float32) - zero_point) * scale # re-scale y.append(x) + if len(y) == 2: # segment with (det, proto) output order reversed + if len(y[1].shape) != 4: + y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32) y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y] y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels