Filter validation set from RF leakage USA. Add tag per PR curve

pull/2109/head
hanoch 2024-11-13 16:49:31 +02:00
parent 7391d5fc94
commit 4afc52b1cd
3 changed files with 23 additions and 3651 deletions

View File

@ -247,8 +247,8 @@ def test(data,
stats_all_large.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))
# Plot images aa = np.repeat(img[0,:,:,:].cpu().permute(1,2,0).numpy(), 3, axis=2).astype('float32') cv2.imwrite('test/exp40/test_batch88_labels__.jpg', aa*255)
if plots and batch_i < 10 or 1:
# conf_thresh_plot = 0.1
if plots and batch_i > 10 or 1:
# conf_thresh_plot = 0.1 # the plot threshold the connfidence
f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels
Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start()
f = save_dir / f'test_batch{batch_i}_pred.jpg' # predictions
@ -265,12 +265,12 @@ def test(data,
p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, v5_metric=v5_metric, save_dir=save_dir, names=names) #based on correct @ IOU=0.5 of pred box with target
if not training or 1:
if bool(stats_person_medium):
p_med, r_med, ap_med, f1_med, ap_class_med = ap_per_class(*stats_person_medium, plot=plots, v5_metric=v5_metric, save_dir=save_dir, names=names)
p_med, r_med, ap_med, f1_med, ap_class_med = ap_per_class(*stats_person_medium, plot=plots, v5_metric=v5_metric, save_dir=save_dir, names=names, tag='small_objects')
ap50_med, ap_med = ap_med[:, 0], ap_med.mean(1) # AP@0.5, AP@0.5:0.95
mp_med, mr_med, map50_med, map_med = p_med.mean(), r_med.mean(), ap50_med.mean(), ap_med.mean()
nt_med = np.bincount(stats_person_medium[3].astype(np.int64), minlength=nc) # number of targets per class
if bool(stats_all_large):
p_large, r_large, ap_large, f1_large, ap_class_large = ap_per_class(*stats_all_large, plot=plots, v5_metric=v5_metric, save_dir=save_dir, names=names)
p_large, r_large, ap_large, f1_large, ap_class_large = ap_per_class(*stats_all_large, plot=plots, v5_metric=v5_metric, save_dir=save_dir, names=names, tag='large_objects')
ap50_large, ap_large = ap_large[:, 0], ap_large.mean(1) # AP@0.5, AP@0.5:0.95
mp_large, mr_large, map50_large, map_large = p_large.mean(), r_large.mean(), ap50_large.mean(), ap_large.mean()
nt_large = np.bincount(stats_all_large[3].astype(np.int64), minlength=nc) # number of targets per class

File diff suppressed because it is too large Load Diff

View File

@ -5,7 +5,7 @@ from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import os
from . import general
@ -15,7 +15,7 @@ def fitness(x):
return (x[:, :4] * w).sum(1)
def ap_per_class(tp, conf, pred_cls, target_cls, v5_metric=False, plot=False, save_dir='.', names=()):
def ap_per_class(tp, conf, pred_cls, target_cls, v5_metric=False, plot=False, save_dir='.', names=(), tag=''):
""" Compute the average precision, given the recall and precision curves.
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
# Arguments
@ -69,10 +69,11 @@ def ap_per_class(tp, conf, pred_cls, target_cls, v5_metric=False, plot=False, sa
# Compute F1 (harmonic mean of precision and recall)
f1 = 2 * p * r / (p + r + 1e-16)
if plot:
plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names)
plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1')
plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision')
plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall')
if bool(py):
plot_pr_curve(px, py, ap, Path(os.path.join(save_dir, 'PR_curve_' + tag + '.png')), names)
plot_mc_curve(px, f1, Path(os.path.join(save_dir, 'F1_curve_' + tag + '.png')), ylabel='F1')
plot_mc_curve(px, p, Path(os.path.join(save_dir, 'P_curve_' + tag + '.png')), ylabel='Precision')
plot_mc_curve(px, p, Path(os.path.join(save_dir, 'R_curve_' + tag + '.png')), ylabel='Recall')
i = f1.mean(0).argmax() # max F1 index
return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32')
@ -197,8 +198,12 @@ def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=(), precisions_of_i
for i, y in enumerate(py.T):
recall_of_interest_per_class = np.zeros_like(precisions_of_interest)
if np.array(precisions_of_interest).min() > np.array(precisions_of_interest).max():
recall_of_interest_per_class = [px[int(np.where(y.reshape(-1) > x)[0][-1])] for x in precisions_of_interest]
try:
if np.array(precisions_of_interest).min() < np.array(py[:, i]).max(): # make sure that precision has the values in the interest ROI
recall_of_interest_per_class = [px[int(np.where(y.reshape(-1) > x)[0][-1])] for x in precisions_of_interest]
except Exception as e:
print(e)
print(precisions_of_interest)
ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision)
ax.plot(recall_of_interest_per_class, precisions_of_interest, '*', color='green')
@ -219,8 +224,12 @@ def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=(), precisions_of_i
ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision)
recall_of_interest = np.zeros_like(precisions_of_interest)
if np.array(precisions_of_interest).min() > np.array(precisions_of_interest).max():
recall_of_interest = [px[int(np.where(py.mean(1).reshape(-1) > x)[0][-1])] for x in precisions_of_interest]
try:
if np.array(precisions_of_interest).min() < np.array(py.mean(1)).max():
recall_of_interest = [px[int(np.where(py.mean(1).reshape(-1) > x)[0][-1])] for x in precisions_of_interest]
except Exception as e:
print(e)
print(precisions_of_interest)
ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean()) # py [ap , num_clases]
ax.plot(recall_of_interest, precisions_of_interest, '*')