NMS MPS device wrapper (#9620)
* NMS MPS device wrapper May resolve https://github.com/ultralytics/yolov5/issues/9613 Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update general.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/9558/head^2
parent
7314363f26
commit
2373d5470e
|
@ -843,7 +843,9 @@ def non_max_suppression(
|
||||||
if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation model, output = (inference_out, loss_out)
|
if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation model, output = (inference_out, loss_out)
|
||||||
prediction = prediction[0] # select only inference output
|
prediction = prediction[0] # select only inference output
|
||||||
|
|
||||||
if 'mps' in prediction.device.type: # MPS not fully supported yet, convert tensors to CPU before NMS
|
device = prediction.device
|
||||||
|
mps = 'mps' in device.type # Apple MPS
|
||||||
|
if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
|
||||||
prediction = prediction.cpu()
|
prediction = prediction.cpu()
|
||||||
bs = prediction.shape[0] # batch size
|
bs = prediction.shape[0] # batch size
|
||||||
nc = prediction.shape[2] - nm - 5 # number of classes
|
nc = prediction.shape[2] - nm - 5 # number of classes
|
||||||
|
@ -930,6 +932,8 @@ def non_max_suppression(
|
||||||
i = i[iou.sum(1) > 1] # require redundancy
|
i = i[iou.sum(1) > 1] # require redundancy
|
||||||
|
|
||||||
output[xi] = x[i]
|
output[xi] = x[i]
|
||||||
|
if mps:
|
||||||
|
output[xi] = output[xi].to(device)
|
||||||
if (time.time() - t) > time_limit:
|
if (time.time() - t) > time_limit:
|
||||||
LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
|
LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
|
||||||
break # time limit exceeded
|
break # time limit exceeded
|
||||||
|
|
Loading…
Reference in New Issue