Update DetectMultiBackend for tuple outputs 2 (#9275)
* Update DetectMultiBackend for tuple outputs 2 Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update * Update * Update Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/9295/head
parent
96c3c7f71d
commit
7aa263c5f2
|
@ -457,7 +457,7 @@ class DetectMultiBackend(nn.Module):
|
|||
|
||||
self.__dict__.update(locals()) # assign all variables to self
|
||||
|
||||
def forward(self, im, augment=False, visualize=False, val=False):
|
||||
def forward(self, im, augment=False, visualize=False):
|
||||
# YOLOv5 MultiBackend inference
|
||||
b, ch, h, w = im.shape # batch, channel, height, width
|
||||
if self.fp16 and im.dtype != torch.float16:
|
||||
|
@ -521,10 +521,12 @@ class DetectMultiBackend(nn.Module):
|
|||
y[..., :4] *= [w, h, w, h] # xywh normalized to pixels
|
||||
|
||||
if isinstance(y, (list, tuple)):
|
||||
y = y[0]
|
||||
if isinstance(y, np.ndarray):
|
||||
y = torch.from_numpy(y).to(self.device)
|
||||
return (y, []) if val else y
|
||||
return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
|
||||
else:
|
||||
return self.from_numpy(y)
|
||||
|
||||
def from_numpy(self, x):
|
||||
return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x
|
||||
|
||||
def warmup(self, imgsz=(1, 3, 640, 640)):
|
||||
# Warmup model by running inference once
|
||||
|
|
|
@ -813,6 +813,9 @@ def non_max_suppression(prediction,
|
|||
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
|
||||
"""
|
||||
|
||||
if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation model, output = (inference_out, loss_out)
|
||||
prediction = prediction[0] # select only inference output
|
||||
|
||||
bs = prediction.shape[0] # batch size
|
||||
nc = prediction.shape[2] - 5 # number of classes
|
||||
xc = prediction[..., 4] > conf_thres # candidates
|
||||
|
|
4
val.py
4
val.py
|
@ -204,11 +204,11 @@ def run(
|
|||
|
||||
# Inference
|
||||
with dt[1]:
|
||||
out, train_out = model(im) if training else model(im, augment=augment, val=True) # inference, loss outputs
|
||||
out, train_out = model(im) if compute_loss else (model(im, augment=augment), None)
|
||||
|
||||
# Loss
|
||||
if compute_loss:
|
||||
loss += compute_loss([x.float() for x in train_out], targets)[1] # box, obj, cls
|
||||
loss += compute_loss(train_out, targets)[1] # box, obj, cls
|
||||
|
||||
# NMS
|
||||
targets[:, 2:] *= torch.tensor((width, height, width, height), device=device) # to pixels
|
||||
|
|
Loading…
Reference in New Issue