Reduce val device transfers (#7525)

pull/7527/head
Glenn Jocher 2022-04-21 20:06:57 -07:00 committed by GitHub
parent 23718df1c6
commit d2e698c75c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 8 deletions

14
val.py
View File

@ -220,14 +220,14 @@ def run(
# Metrics
for si, pred in enumerate(out):
labels = targets[targets[:, 0] == si, 1:]
nl = len(labels)
tcls = labels[:, 0].tolist() if nl else [] # target class
nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
path, shape = Path(paths[si]), shapes[si][0]
correct = torch.zeros(npr, niou, dtype=torch.bool, device=device) # init
seen += 1
if len(pred) == 0:
if npr == 0:
if nl:
stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
stats.append((correct, *torch.zeros((3, 0))))
continue
# Predictions
@ -244,9 +244,7 @@ def run(
correct = process_batch(predn, labelsn, iouv)
if plots:
confusion_matrix.process_batch(predn, labelsn)
else:
correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool)
stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) # (correct, conf, pcls, tcls)
stats.append((correct, pred[:, 4], pred[:, 5], labels[:, 0])) # (correct, conf, pcls, tcls)
# Save/log
if save_txt:
@ -265,7 +263,7 @@ def run(
callbacks.run('on_val_batch_end')
# Compute metrics
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)] # to numpy
if len(stats) and stats[0].any():
tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95