Confusion matrix (#1474)
* initial commit * add plotting * matrix to cpu * bug fix * update plot * update plot * update plot * update plot * update plot * update plot * update plot * update plot * update plot * update plot * update plot * update plot * cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * seaborn pandas to requirements.txt * seaborn pandas to requirements.txt * update wandb plotting * remove pandas * if plots * if plots * if plots * if plots * if plots * initial commit * add plotting * matrix to cpu * bug fix * update plot * update plot * update plot * update plot * update plot * update plot * update plot * update plot * update plot * update plot * update plot * update plot * cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * cleanup * seaborn pandas to requirements.txt * seaborn pandas to requirements.txt * update wandb plotting * remove pandas * if plots * if plots * if plots * if plots * if plots * Cat apriori to autolabels * cleanuppull/1488/head
parent
95fa65339f
commit
0a3ff71ae0
|
@ -16,8 +16,9 @@ tqdm>=4.41.0
|
||||||
# logging -------------------------------------
|
# logging -------------------------------------
|
||||||
# wandb
|
# wandb
|
||||||
|
|
||||||
# coco ----------------------------------------
|
# plotting ------------------------------------
|
||||||
# pycocotools>=2.0
|
seaborn
|
||||||
|
pandas
|
||||||
|
|
||||||
# export --------------------------------------
|
# export --------------------------------------
|
||||||
# coremltools==4.0
|
# coremltools==4.0
|
||||||
|
@ -26,4 +27,4 @@ tqdm>=4.41.0
|
||||||
|
|
||||||
# extras --------------------------------------
|
# extras --------------------------------------
|
||||||
# thop # FLOPS computation
|
# thop # FLOPS computation
|
||||||
# seaborn # plotting
|
# pycocotools>=2.0 # COCO mAP
|
||||||
|
|
15
test.py
15
test.py
|
@ -14,7 +14,7 @@ from utils.datasets import create_dataloader
|
||||||
from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, box_iou, \
|
from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, box_iou, \
|
||||||
non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path
|
non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path
|
||||||
from utils.loss import compute_loss
|
from utils.loss import compute_loss
|
||||||
from utils.metrics import ap_per_class
|
from utils.metrics import ap_per_class, ConfusionMatrix
|
||||||
from utils.plots import plot_images, output_to_target
|
from utils.plots import plot_images, output_to_target
|
||||||
from utils.torch_utils import select_device, time_synchronized
|
from utils.torch_utils import select_device, time_synchronized
|
||||||
|
|
||||||
|
@ -89,6 +89,7 @@ def test(data,
|
||||||
dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt, pad=0.5, rect=True)[0]
|
dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt, pad=0.5, rect=True)[0]
|
||||||
|
|
||||||
seen = 0
|
seen = 0
|
||||||
|
confusion_matrix = ConfusionMatrix(nc=nc)
|
||||||
names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
|
names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
|
||||||
coco91class = coco80_to_coco91_class()
|
coco91class = coco80_to_coco91_class()
|
||||||
s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
|
s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
|
||||||
|
@ -176,6 +177,8 @@ def test(data,
|
||||||
# target boxes
|
# target boxes
|
||||||
tbox = xywh2xyxy(labels[:, 1:5])
|
tbox = xywh2xyxy(labels[:, 1:5])
|
||||||
scale_coords(img[si].shape[1:], tbox, shapes[si][0], shapes[si][1]) # native-space labels
|
scale_coords(img[si].shape[1:], tbox, shapes[si][0], shapes[si][1]) # native-space labels
|
||||||
|
if plots:
|
||||||
|
confusion_matrix.process_batch(pred, torch.cat((labels[:, 0:1], tbox), 1))
|
||||||
|
|
||||||
# Per target class
|
# Per target class
|
||||||
for cls in torch.unique(tcls_tensor):
|
for cls in torch.unique(tcls_tensor):
|
||||||
|
@ -218,10 +221,12 @@ def test(data,
|
||||||
else:
|
else:
|
||||||
nt = torch.zeros(1)
|
nt = torch.zeros(1)
|
||||||
|
|
||||||
# W&B logging
|
# Plots
|
||||||
if plots and wandb and wandb.run:
|
if plots:
|
||||||
wandb.log({"Images": wandb_images})
|
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
|
||||||
wandb.log({"Validation": [wandb.Image(str(x), caption=x.name) for x in sorted(save_dir.glob('test*.jpg'))]})
|
if wandb and wandb.run:
|
||||||
|
wandb.log({"Images": wandb_images})
|
||||||
|
wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]})
|
||||||
|
|
||||||
# Print results
|
# Print results
|
||||||
pf = '%20s' + '%12.3g' * 6 # print format
|
pf = '%20s' + '%12.3g' * 6 # print format
|
||||||
|
|
5
train.py
5
train.py
|
@ -396,8 +396,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||||
if plots:
|
if plots:
|
||||||
plot_results(save_dir=save_dir) # save as results.png
|
plot_results(save_dir=save_dir) # save as results.png
|
||||||
if wandb:
|
if wandb:
|
||||||
wandb.log({"Results": [wandb.Image(str(save_dir / x), caption=x) for x in
|
files = ['results.png', 'precision_recall_curve.png', 'confusion_matrix.png']
|
||||||
['results.png', 'precision_recall_curve.png']]})
|
wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
|
||||||
|
if (save_dir / f).exists()]})
|
||||||
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
||||||
else:
|
else:
|
||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
|
|
|
@ -4,6 +4,9 @@ from pathlib import Path
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from . import general
|
||||||
|
|
||||||
|
|
||||||
def fitness(x):
|
def fitness(x):
|
||||||
|
@ -102,6 +105,84 @@ def compute_ap(recall, precision):
|
||||||
return ap, mpre, mrec
|
return ap, mpre, mrec
|
||||||
|
|
||||||
|
|
||||||
|
class ConfusionMatrix:
|
||||||
|
# Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
|
||||||
|
def __init__(self, nc, conf=0.25, iou_thres=0.45):
|
||||||
|
self.matrix = np.zeros((nc + 1, nc + 1))
|
||||||
|
self.nc = nc # number of classes
|
||||||
|
self.conf = conf
|
||||||
|
self.iou_thres = iou_thres
|
||||||
|
|
||||||
|
def process_batch(self, detections, labels):
|
||||||
|
"""
|
||||||
|
Return intersection-over-union (Jaccard index) of boxes.
|
||||||
|
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
|
||||||
|
Arguments:
|
||||||
|
detections (Array[N, 6]), x1, y1, x2, y2, conf, class
|
||||||
|
labels (Array[M, 5]), class, x1, y1, x2, y2
|
||||||
|
Returns:
|
||||||
|
None, updates confusion matrix accordingly
|
||||||
|
"""
|
||||||
|
detections = detections[detections[:, 4] > self.conf]
|
||||||
|
gt_classes = labels[:, 0].int()
|
||||||
|
detection_classes = detections[:, 5].int()
|
||||||
|
iou = general.box_iou(labels[:, 1:], detections[:, :4])
|
||||||
|
|
||||||
|
x = torch.where(iou > self.iou_thres)
|
||||||
|
if x[0].shape[0]:
|
||||||
|
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
|
||||||
|
if x[0].shape[0] > 1:
|
||||||
|
matches = matches[matches[:, 2].argsort()[::-1]]
|
||||||
|
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
|
||||||
|
matches = matches[matches[:, 2].argsort()[::-1]]
|
||||||
|
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
|
||||||
|
else:
|
||||||
|
matches = np.zeros((0, 3))
|
||||||
|
|
||||||
|
n = matches.shape[0] > 0
|
||||||
|
m0, m1, _ = matches.transpose().astype(np.int16)
|
||||||
|
for i, gc in enumerate(gt_classes):
|
||||||
|
j = m0 == i
|
||||||
|
if n and sum(j) == 1:
|
||||||
|
self.matrix[gc, detection_classes[m1[j]]] += 1 # correct
|
||||||
|
else:
|
||||||
|
self.matrix[gc, self.nc] += 1 # background FP
|
||||||
|
|
||||||
|
if n:
|
||||||
|
for i, dc in enumerate(detection_classes):
|
||||||
|
if not any(m1 == i):
|
||||||
|
self.matrix[self.nc, dc] += 1 # background FN
|
||||||
|
|
||||||
|
def matrix(self):
|
||||||
|
return self.matrix
|
||||||
|
|
||||||
|
def plot(self, save_dir='', names=()):
|
||||||
|
try:
|
||||||
|
import seaborn as sn
|
||||||
|
|
||||||
|
array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize
|
||||||
|
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
|
||||||
|
|
||||||
|
fig = plt.figure(figsize=(12, 9))
|
||||||
|
sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size
|
||||||
|
labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels
|
||||||
|
sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
|
||||||
|
xticklabels=names + ['background FN'] if labels else "auto",
|
||||||
|
yticklabels=names + ['background FP'] if labels else "auto").set_facecolor((1, 1, 1))
|
||||||
|
fig.axes[0].set_xlabel('True')
|
||||||
|
fig.axes[0].set_ylabel('Predicted')
|
||||||
|
fig.tight_layout()
|
||||||
|
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def print(self):
|
||||||
|
for i in range(self.nc + 1):
|
||||||
|
print(' '.join(map(str, self.matrix[i])))
|
||||||
|
|
||||||
|
|
||||||
|
# Plots ----------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
def plot_pr_curve(px, py, ap, save_dir='.', names=()):
|
def plot_pr_curve(px, py, ap, save_dir='.', names=()):
|
||||||
fig, ax = plt.subplots(1, 1, figsize=(9, 6))
|
fig, ax = plt.subplots(1, 1, figsize=(9, 6))
|
||||||
py = np.stack(py, axis=1)
|
py = np.stack(py, axis=1)
|
||||||
|
|
Loading…
Reference in New Issue