Apple MPS -> CPU NMS fallback strategy (#9600)
Until more ops are fully supported on MPS, this update allows seamless MPS inference, at the cost of an MPS-to-CPU tensor transfer before NMS (so NMS times are slower).
Partially resolves https://github.com/ultralytics/yolov5/issues/9596
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/9610/head
parent
bd9c0c42ae
commit
c4c0ee8fc3
|
@@ -843,6 +843,8 @@ def non_max_suppression(
|
||||||
if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation model, output = (inference_out, loss_out)
|
if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation model, output = (inference_out, loss_out)
|
||||||
prediction = prediction[0] # select only inference output
|
prediction = prediction[0] # select only inference output
|
||||||
|
|
||||||
|
if 'mps' in prediction.device.type: # MPS not fully supported yet, convert tensors to CPU before NMS
|
||||||
|
prediction = prediction.cpu()
|
||||||
bs = prediction.shape[0] # batch size
|
bs = prediction.shape[0] # batch size
|
||||||
nc = prediction.shape[2] - nm - 5 # number of classes
|
nc = prediction.shape[2] - nm - 5 # number of classes
|
||||||
xc = prediction[..., 4] > conf_thres # candidates
|
xc = prediction[..., 4] > conf_thres # candidates
|
||||||
|
|
Loading…
Reference in New Issue