diff --git a/utils/general.py b/utils/general.py index a855691d3..d31b043a1 100644 --- a/utils/general.py +++ b/utils/general.py @@ -843,7 +843,9 @@ def non_max_suppression( if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation model, output = (inference_out, loss_out) prediction = prediction[0] # select only inference output - if 'mps' in prediction.device.type: # MPS not fully supported yet, convert tensors to CPU before NMS + device = prediction.device + mps = 'mps' in device.type # Apple MPS + if mps: # MPS not fully supported yet, convert tensors to CPU before NMS prediction = prediction.cpu() bs = prediction.shape[0] # batch size nc = prediction.shape[2] - nm - 5 # number of classes @@ -930,6 +932,8 @@ def non_max_suppression( i = i[iou.sum(1) > 1] # require redundancy output[xi] = x[i] + if mps: + output[xi] = output[xi].to(device) if (time.time() - t) > time_limit: LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded') break # time limit exceeded