Avoid FP64 ops for MPS support in train.py (#8511)
Avoid FP64 ops for MPS support Resolves https://github.com/ultralytics/yolov5/pull/7878#issuecomment-1177952614pull/8484/head^2
parent
9d7bc06ae7
commit
dd28df98c2
|
@ -644,7 +644,7 @@ def labels_to_class_weights(labels, nc=80):
|
|||
return torch.Tensor()
|
||||
|
||||
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
|
||||
classes = labels[:, 0].astype(np.int) # labels = [class xywh]
|
||||
classes = labels[:, 0].astype(int) # labels = [class xywh]
|
||||
weights = np.bincount(classes, minlength=nc) # occurrences per class
|
||||
|
||||
# Prepend gridpoint count (for uCE training)
|
||||
|
@ -654,13 +654,13 @@ def labels_to_class_weights(labels, nc=80):
|
|||
weights[weights == 0] = 1 # replace empty bins with 1
|
||||
weights = 1 / weights # number of targets per class
|
||||
weights /= weights.sum() # normalize
|
||||
return torch.from_numpy(weights)
|
||||
return torch.from_numpy(weights).float()
|
||||
|
||||
|
||||
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
|
||||
# Produces image weights based on class_weights and image contents
|
||||
# Usage: index = random.choices(range(n), weights=image_weights, k=1) # weighted image sample
|
||||
class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
|
||||
class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
|
||||
return (class_weights.reshape(1, nc) * class_counts).sum(1)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue