Precision-Recall Curve Feature Addition (#1107)
* initial commit
* Update general.py: indent update
* Update general.py: refactor duplicate code
* 200 dpi
parent 9eae82e3a3
commit 5fac5ad165
test.py (21 changes)
@@ -30,9 +30,9 @@ def test(data,
          verbose=False,
          model=None,
          dataloader=None,
-         save_dir='',
-         merge=False,
-         save_txt=False):
+         save_dir=Path(''),  # for saving images
+         save_txt=False,  # for auto-labelling
+         plots=True):
     # Initialize/load model and set device
     training = model is not None
     if training:  # called by train.py
@@ -41,7 +41,7 @@ def test(data,
     else:  # called directly
         set_logging()
         device = select_device(opt.device, batch_size=batch_size)
-        merge, save_txt = opt.merge, opt.save_txt  # use Merge NMS, save *.txt labels
+        save_txt = opt.save_txt  # save *.txt labels
         if save_txt:
             out = Path('inference/output')
             if os.path.exists(out):
@@ -49,7 +49,7 @@ def test(data,
             os.makedirs(out)  # make new output folder

     # Remove previous
-    for f in glob.glob(str(Path(save_dir) / 'test_batch*.jpg')):
+    for f in glob.glob(str(save_dir / 'test_batch*.jpg')):
         os.remove(f)

     # Load model
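Note: the switch from save_dir='' to save_dir=Path('') is what allows dropping the Path(...) wrapper in the glob above, since pathlib.Path overloads the / operator for joining. A minimal sketch of that behavior (the directory name here is illustrative):

from pathlib import Path

save_dir = Path('runs/test')          # hypothetical output directory
f = save_dir / 'test_batch0.jpg'      # '/' joins path segments
print(str(f))                         # runs/test/test_batch0.jpg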
@@ -110,7 +110,7 @@ def test(data,

         # Run NMS
         t = time_synchronized()
-        output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, merge=merge)
+        output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres)
         t1 += time_synchronized() - t

         # Statistics per image
@@ -186,16 +186,16 @@ def test(data,
             stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))

         # Plot images
-        if batch_i < 1:
-            f = Path(save_dir) / ('test_batch%g_gt.jpg' % batch_i)  # filename
+        if plots and batch_i < 1:
+            f = save_dir / ('test_batch%g_gt.jpg' % batch_i)  # filename
             plot_images(img, targets, paths, str(f), names)  # ground truth
-            f = Path(save_dir) / ('test_batch%g_pred.jpg' % batch_i)
+            f = save_dir / ('test_batch%g_pred.jpg' % batch_i)
             plot_images(img, output_to_target(output, width, height), paths, str(f), names)  # predictions

     # Compute statistics
     stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
     if len(stats) and stats[0].any():
-        p, r, ap, f1, ap_class = ap_per_class(*stats)
+        p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, fname=save_dir / 'precision-recall_curve.png')
         p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1)  # [P, R, AP@0.5, AP@0.5:0.95]
         mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
         nt = np.bincount(stats[3].astype(np.int64), minlength=nc)  # number of targets per class
@@ -261,7 +261,6 @@ if __name__ == '__main__':
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
     parser.add_argument('--augment', action='store_true', help='augmented inference')
-    parser.add_argument('--merge', action='store_true', help='use Merge NMS')
     parser.add_argument('--verbose', action='store_true', help='report mAP by class')
     parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
     opt = parser.parse_args()
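Taken together, the test.py changes add a plots switch and make save_dir a Path. A minimal sketch of a direct call under the new signature; the data file, output directory, and the weights argument are assumptions from the surrounding codebase, not shown in this diff:

from pathlib import Path
import test  # yolov5's test.py, assuming it is run from the repo root

results, maps, times = test.test('data/coco128.yaml',
                                 weights='yolov5s.pt',            # assumed parameter
                                 batch_size=16,
                                 imgsz=640,
                                 save_dir=Path('runs/test'),      # test-batch images land here
                                 plots=True)                      # also writes precision-recall_curve.png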
train.py (6 changes)
@@ -1,5 +1,4 @@
 import argparse
-import glob
 import logging
 import math
 import os
@@ -309,15 +308,14 @@ def train(hyp, opt, device, tb_writer=None):
             ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
             final_epoch = epoch + 1 == epochs
             if not opt.notest or final_epoch:  # Calculate mAP
-                if final_epoch:  # replot predictions
-                    [os.remove(x) for x in glob.glob(str(log_dir / 'test_batch*_pred.jpg')) if os.path.exists(x)]
                 results, maps, times = test.test(opt.data,
                                                  batch_size=total_batch_size,
                                                  imgsz=imgsz_test,
                                                  model=ema.ema,
                                                  single_cls=opt.single_cls,
                                                  dataloader=testloader,
-                                                 save_dir=log_dir)
+                                                 save_dir=log_dir,
+                                                 plots=epoch == 0 or final_epoch)  # plot first and last

             # Write
             with open(results_file, 'a') as f:
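The plots=epoch == 0 or final_epoch argument gates plotting to the first and last epochs, replacing the delete-and-replot block removed above. Since the keyword value is the whole boolean expression, a quick illustration of how it evaluates (the 3-epoch run is hypothetical):

epochs = 3  # hypothetical run length
for epoch in range(epochs):
    final_epoch = epoch + 1 == epochs
    plots = epoch == 0 or final_epoch  # parses as (epoch == 0) or final_epoch
    print(epoch, plots)                # 0 True / 1 False / 2 True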
general.py

@@ -245,14 +245,16 @@ def clip_coords(boxes, img_shape):
     boxes[:, 3].clamp_(0, img_shape[0])  # y2


-def ap_per_class(tp, conf, pred_cls, target_cls):
+def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, fname='precision-recall_curve.png'):
     """ Compute the average precision, given the recall and precision curves.
     Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
     # Arguments
-        tp: True positives (nparray, nx1 or nx10).
+        tp: True positives (nparray, nx1 or nx10).
         conf: Objectness value from 0-1 (nparray).
-        pred_cls: Predicted object classes (nparray).
-        target_cls: True object classes (nparray).
+        pred_cls: Predicted object classes (nparray).
+        target_cls: True object classes (nparray).
+        plot: Plot precision-recall curve at mAP@0.5
+        fname: Plot filename
     # Returns
         The average precision as computed in py-faster-rcnn.
     """
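A minimal sketch of how the extended signature might be exercised, assuming ap_per_class is imported from utils.general; the synthetic arrays below are shaped as the docstring describes and all values are made up for illustration:

import numpy as np

n, nc = 100, 3                             # hypothetical detection and class counts
tp = np.random.rand(n, 10) > 0.5           # true-positive flags at 10 IoU thresholds
conf = np.random.rand(n)                   # objectness scores in [0, 1]
pred_cls = np.random.randint(0, nc, n)     # predicted class per detection
target_cls = np.random.randint(0, nc, 80)  # ground-truth class labels

p, r, ap, f1, classes = ap_per_class(tp, conf, pred_cls, target_cls,
                                     plot=True, fname='precision-recall_curve.png')
print(ap.shape)  # (len(classes), 10): one AP per class per IoU threshold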
@@ -265,6 +267,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls):
     unique_classes = np.unique(target_cls)

     # Create Precision-Recall curve and compute AP for each class
+    px, py = np.linspace(0, 1, 1000), []  # for plotting
     pr_score = 0.1  # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
     s = [unique_classes.shape[0], tp.shape[1]]  # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95)
     ap, p, r = np.zeros(s), np.zeros(s), np.zeros(s)
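Here px is a fixed grid of 1000 recall values and py collects each class's precision resampled onto that grid, so curves with different native recall points can later be stacked and averaged. A small sketch of that interpolation step, with invented recall/precision pairs for one class:

import numpy as np

px = np.linspace(0, 1, 1000)               # shared recall grid
recall = np.array([0.0, 0.4, 0.7, 0.9])    # one class's recall points (made up)
precision = np.array([1.0, 0.9, 0.8, 0.5]) # matching precision values (made up)
py_cls = np.interp(px, recall, precision)  # precision resampled onto px
print(py_cls[0], py_cls[-1])               # 1.0 at recall 0, 0.5 beyond the last point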
@@ -289,22 +292,26 @@ def ap_per_class(tp, conf, pred_cls, target_cls):
             p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0])  # p at pr_score

             # AP from recall-precision curve
+            py.append(np.interp(px, recall[:, 0], precision[:, 0]))  # precision at mAP@0.5
             for j in range(tp.shape[1]):
                 ap[ci, j] = compute_ap(recall[:, j], precision[:, j])

-            # Plot
-            # fig, ax = plt.subplots(1, 1, figsize=(5, 5))
-            # ax.plot(recall, precision)
-            # ax.set_xlabel('Recall')
-            # ax.set_ylabel('Precision')
-            # ax.set_xlim(0, 1.01)
-            # ax.set_ylim(0, 1.01)
-            # fig.tight_layout()
-            # fig.savefig('PR_curve.png', dpi=300)
-
     # Compute F1 score (harmonic mean of precision and recall)
     f1 = 2 * p * r / (p + r + 1e-16)

+    if plot:
+        py = np.stack(py, axis=1)
+        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
+        ax.plot(px, py, linewidth=0.5, color='grey')  # plot(recall, precision)
+        ax.plot(px, py.mean(1), linewidth=2, color='blue', label='all classes')
+        ax.set_xlabel('Recall')
+        ax.set_ylabel('Precision')
+        ax.set_xlim(0, 1)
+        ax.set_ylim(0, 1)
+        plt.legend()
+        fig.tight_layout()
+        fig.savefig(fname, dpi=200)
+
     return p, r, ap, f1, unique_classes.astype('int32')
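One subtlety in the hunk above: np.interp requires its x-coordinates to be increasing, but detections are sorted by descending confidence, so both arguments are negated when evaluating precision at pr_score. A tiny demonstration with invented values:

import numpy as np

pr_score = 0.1
conf_sorted = np.array([0.9, 0.6, 0.3, 0.05])  # descending, as after sorting detections
precision = np.array([1.0, 0.85, 0.7, 0.5])    # made-up precision at each cutoff
# negating both axes makes x increasing, as np.interp requires
p_at_score = np.interp(-pr_score, -conf_sorted, precision)
print(p_at_score)  # -> 0.54, interpolated between the 0.3 and 0.05 cutoffs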
@@ -1011,8 +1018,6 @@ def plot_wh_methods():  # from utils.general import *; plot_wh_methods()
 def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
     tl = 3  # line thickness
     tf = max(tl - 1, 1)  # font thickness
-    if os.path.isfile(fname):  # do not overwrite
-        return None
-
     if isinstance(images, torch.Tensor):
         images = images.cpu().float().numpy()