diff --git a/test.py b/test.py index 15dd24362..f62747593 100644 --- a/test.py +++ b/test.py @@ -213,7 +213,7 @@ def test(data, # Compute statistics stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy if len(stats) and stats[0].any(): - p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, fname=save_dir / 'precision-recall_curve.png') + p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names) p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1) # [P, R, AP@0.5, AP@0.5:0.95] mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean() nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class diff --git a/utils/metrics.py b/utils/metrics.py index d4a10db18..62add1da1 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -1,5 +1,7 @@ # Model validation metrics +from pathlib import Path + import matplotlib.pyplot as plt import numpy as np @@ -10,7 +12,7 @@ def fitness(x): return (x[:, :4] * w).sum(1) -def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, fname='precision-recall_curve.png'): +def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision-recall_curve.png', names=[]): """ Compute the average precision, given the recall and precision curves. Source: https://github.com/rafaelpadilla/Object-Detection-Metrics. # Arguments @@ -19,7 +21,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, fname='precision-re pred_cls: Predicted object classes (nparray). target_cls: True object classes (nparray). plot: Plot precision-recall curve at mAP@0.5 - fname: Plot filename + save_dir: Plot save directory # Returns The average precision as computed in py-faster-rcnn. """ @@ -66,17 +68,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, fname='precision-re f1 = 2 * p * r / (p + r + 1e-16) if plot: - py = np.stack(py, axis=1) - fig, ax = plt.subplots(1, 1, figsize=(5, 5)) - ax.plot(px, py, linewidth=0.5, color='grey') # plot(recall, precision) - ax.plot(px, py.mean(1), linewidth=2, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean()) - ax.set_xlabel('Recall') - ax.set_ylabel('Precision') - ax.set_xlim(0, 1) - ax.set_ylim(0, 1) - plt.legend() - fig.tight_layout() - fig.savefig(fname, dpi=200) + plot_pr_curve(px, py, ap, save_dir, names) return p, r, ap, f1, unique_classes.astype('int32') @@ -108,3 +100,23 @@ def compute_ap(recall, precision): ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve return ap, mpre, mrec + + +def plot_pr_curve(px, py, ap, save_dir='.', names=()): + fig, ax = plt.subplots(1, 1, figsize=(9, 6)) + py = np.stack(py, axis=1) + + if 0 < len(names) < 21: # show mAP in legend if < 10 classes + for i, y in enumerate(py.T): + ax.plot(px, y, linewidth=1, label=f'{names[i]} %.3f' % ap[i, 0]) # plot(recall, precision) + else: + ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision) + + ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean()) + ax.set_xlabel('Recall') + ax.set_ylabel('Precision') + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") + fig.tight_layout() + fig.savefig(Path(save_dir) / 'precision_recall_curve.png', dpi=250) diff --git a/utils/plots.py b/utils/plots.py index 1429155ef..3653a2561 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -65,7 +65,7 @@ def plot_one_box(x, img, color=None, label=None, line_thickness=None): cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA) -def plot_wh_methods(): # from utils.general import *; plot_wh_methods() +def plot_wh_methods(): # from utils.plots import *; plot_wh_methods() # Compares the two methods for width-height anchor multiplication # https://github.com/ultralytics/yolov3/issues/168 x = np.arange(-4.0, 4.0, .1) @@ -200,7 +200,7 @@ def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''): plt.savefig(Path(save_dir) / 'LR.png', dpi=200) -def plot_test_txt(): # from utils.general import *; plot_test() +def plot_test_txt(): # from utils.plots import *; plot_test() # Plot test.txt histograms x = np.loadtxt('test.txt', dtype=np.float32) box = xyxy2xywh(x[:, :4]) @@ -217,7 +217,7 @@ def plot_test_txt(): # from utils.general import *; plot_test() plt.savefig('hist1d.png', dpi=200) -def plot_targets_txt(): # from utils.general import *; plot_targets_txt() +def plot_targets_txt(): # from utils.plots import *; plot_targets_txt() # Plot targets.txt histograms x = np.loadtxt('targets.txt', dtype=np.float32).T s = ['x targets', 'y targets', 'width targets', 'height targets'] @@ -230,7 +230,7 @@ def plot_targets_txt(): # from utils.general import *; plot_targets_txt() plt.savefig('targets.jpg', dpi=200) -def plot_study_txt(f='study.txt', x=None): # from utils.general import *; plot_study_txt() +def plot_study_txt(f='study.txt', x=None): # from utils.plots import *; plot_study_txt() # Plot study.txt generated by test.py fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True) ax = ax.ravel() @@ -294,7 +294,7 @@ def plot_labels(labels, save_dir=''): pass -def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.general import *; plot_evolution() +def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution() # Plot hyperparameter evolution results in evolve.txt with open(yaml_file) as f: hyp = yaml.load(f, Loader=yaml.FullLoader) @@ -318,7 +318,7 @@ def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.general im print('\nPlot saved as evolve.png') -def plot_results_overlay(start=0, stop=0): # from utils.general import *; plot_results_overlay() +def plot_results_overlay(start=0, stop=0): # from utils.plots import *; plot_results_overlay() # Plot training 'results*.txt', overlaying train and val losses s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95'] # legends t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles @@ -342,20 +342,18 @@ def plot_results_overlay(start=0, stop=0): # from utils.general import *; plot_ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): - # from utils.general import *; plot_results(save_dir='runs/train/exp0') - # Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov5#reproduce-our-training + # Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp') fig, ax = plt.subplots(2, 5, figsize=(12, 6)) ax = ax.ravel() s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall', 'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95'] if bucket: - # os.system('rm -rf storage.googleapis.com') # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id] files = ['results%g.txt' % x for x in id] c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id) os.system(c) else: - files = glob.glob(str(Path(save_dir) / 'results*.txt')) + glob.glob('../../Downloads/results*.txt') + files = list(Path(save_dir).glob('results*.txt')) assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir) for fi, f in enumerate(files): try: @@ -367,7 +365,7 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): if i in [0, 1, 2, 5, 6, 7]: y[y == 0] = np.nan # don't show zero loss values # y /= y[0] # normalize - label = labels[fi] if len(labels) else Path(f).stem + label = labels[fi] if len(labels) else f.stem ax[i].plot(x, y, marker='.', label=label, linewidth=1, markersize=6) ax[i].set_title(s[i]) # if i in [5, 6, 7]: # share train and val loss y axes