Fix numpy to torch cls streaming bug (#9112)
* Fix numpy to torch cls streaming bug Resolves https://github.com/ultralytics/yolov5/issues/9111 Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/9117/head
parent
51c9f92297
commit
e6f54c5b32
|
@ -30,6 +30,7 @@ import platform
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
import torch.backends.cudnn as cudnn
|
import torch.backends.cudnn as cudnn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
@ -101,7 +102,7 @@ def run(
|
||||||
seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
|
seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
|
||||||
for path, im, im0s, vid_cap, s in dataset:
|
for path, im, im0s, vid_cap, s in dataset:
|
||||||
with dt[0]:
|
with dt[0]:
|
||||||
im = im.to(device)
|
im = torch.Tensor(im).to(device)
|
||||||
im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
|
im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
|
||||||
if len(im.shape) == 3:
|
if len(im.shape) == 3:
|
||||||
im = im[None] # expand for batch dim
|
im = im[None] # expand for batch dim
|
||||||
|
|
Loading…
Reference in New Issue