Reduce val device transfers (#7525)
parent
23718df1c6
commit
d2e698c75c
14
val.py
14
val.py
|
@ -220,14 +220,14 @@ def run(
|
|||
# Metrics
|
||||
for si, pred in enumerate(out):
|
||||
labels = targets[targets[:, 0] == si, 1:]
|
||||
nl = len(labels)
|
||||
tcls = labels[:, 0].tolist() if nl else [] # target class
|
||||
nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
|
||||
path, shape = Path(paths[si]), shapes[si][0]
|
||||
correct = torch.zeros(npr, niou, dtype=torch.bool, device=device) # init
|
||||
seen += 1
|
||||
|
||||
if len(pred) == 0:
|
||||
if npr == 0:
|
||||
if nl:
|
||||
stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
|
||||
stats.append((correct, *torch.zeros((3, 0))))
|
||||
continue
|
||||
|
||||
# Predictions
|
||||
|
@ -244,9 +244,7 @@ def run(
|
|||
correct = process_batch(predn, labelsn, iouv)
|
||||
if plots:
|
||||
confusion_matrix.process_batch(predn, labelsn)
|
||||
else:
|
||||
correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool)
|
||||
stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) # (correct, conf, pcls, tcls)
|
||||
stats.append((correct, pred[:, 4], pred[:, 5], labels[:, 0])) # (correct, conf, pcls, tcls)
|
||||
|
||||
# Save/log
|
||||
if save_txt:
|
||||
|
@ -265,7 +263,7 @@ def run(
|
|||
callbacks.run('on_val_batch_end')
|
||||
|
||||
# Compute metrics
|
||||
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
|
||||
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)] # to numpy
|
||||
if len(stats) and stats[0].any():
|
||||
tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
|
||||
ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95
|
||||
|
|
Loading…
Reference in New Issue