Avoid FP64 ops for MPS support in train.py (#8511)

Avoid FP64 ops for MPS support

Resolves https://github.com/ultralytics/yolov5/pull/7878#issuecomment-1177952614
pull/8484/head^2
Glenn Jocher 2022-07-07 20:36:23 +02:00 committed by GitHub
parent 9d7bc06ae7
commit dd28df98c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 3 deletions

View File

@ -644,7 +644,7 @@ def labels_to_class_weights(labels, nc=80):
return torch.Tensor()
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
classes = labels[:, 0].astype(np.int) # labels = [class xywh]
classes = labels[:, 0].astype(int) # labels = [class xywh]
weights = np.bincount(classes, minlength=nc) # occurrences per class
# Prepend gridpoint count (for uCE training)
@ -654,13 +654,13 @@ def labels_to_class_weights(labels, nc=80):
weights[weights == 0] = 1 # replace empty bins with 1
weights = 1 / weights # number of targets per class
weights /= weights.sum() # normalize
return torch.from_numpy(weights)
return torch.from_numpy(weights).float()
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
# Produces image weights based on class_weights and image contents
# Usage: index = random.choices(range(n), weights=image_weights, k=1) # weighted image sample
class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
return (class_weights.reshape(1, nc) * class_counts).sum(1)