NMS MPS device wrapper (#9620)
* NMS MPS device wrapper May resolve https://github.com/ultralytics/yolov5/issues/9613 Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update general.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/9558/head^2
parent
7314363f26
commit
2373d5470e
utils
|
@ -843,7 +843,9 @@ def non_max_suppression(
|
|||
if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation model, output = (inference_out, loss_out)
|
||||
prediction = prediction[0] # select only inference output
|
||||
|
||||
if 'mps' in prediction.device.type: # MPS not fully supported yet, convert tensors to CPU before NMS
|
||||
device = prediction.device
|
||||
mps = 'mps' in device.type # Apple MPS
|
||||
if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
|
||||
prediction = prediction.cpu()
|
||||
bs = prediction.shape[0] # batch size
|
||||
nc = prediction.shape[2] - nm - 5 # number of classes
|
||||
|
@ -930,6 +932,8 @@ def non_max_suppression(
|
|||
i = i[iou.sum(1) > 1] # require redundancy
|
||||
|
||||
output[xi] = x[i]
|
||||
if mps:
|
||||
output[xi] = output[xi].to(device)
|
||||
if (time.time() - t) > time_limit:
|
||||
LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
|
||||
break # time limit exceeded
|
||||
|
|
Loading…
Reference in New Issue