DetectMultiBackend() return `device` update (#6958)

Fixes ONNX validation that returns outputs on CPU.
pull/6960/head
Glenn Jocher 2022-03-12 13:16:29 +01:00 committed by GitHub
parent c84dd27d62
commit 52c1399fdc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 1 deletions

View File

@ -458,7 +458,8 @@ class DetectMultiBackend(nn.Module):
y = (y.astype(np.float32) - zero_point) * scale # re-scale
y[..., :4] *= [w, h, w, h] # xywh normalized to pixels
y = torch.tensor(y) if isinstance(y, np.ndarray) else y
if isinstance(y, np.ndarray):
y = torch.tensor(y, device=self.device)
return (y, []) if val else y
def warmup(self, imgsz=(1, 3, 640, 640)):