DetectMultiBackend() return `device` update (#6958)
Fixes ONNX validation that returns outputs on CPU.pull/6960/head
parent
c84dd27d62
commit
52c1399fdc
|
@ -458,7 +458,8 @@ class DetectMultiBackend(nn.Module):
|
|||
y = (y.astype(np.float32) - zero_point) * scale # re-scale
|
||||
y[..., :4] *= [w, h, w, h] # xywh normalized to pixels
|
||||
|
||||
y = torch.tensor(y) if isinstance(y, np.ndarray) else y
|
||||
if isinstance(y, np.ndarray):
|
||||
y = torch.tensor(y, device=self.device)
|
||||
return (y, []) if val else y
|
||||
|
||||
def warmup(self, imgsz=(1, 3, 640, 640)):
|
||||
|
|
Loading…
Reference in New Issue