EasyCV/easycv/core/evaluation/metrics.py

203 lines
7.3 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import matplotlib.pyplot as plt
import numpy as np
import torch
def ap_per_class(tp,
                 conf,
                 pred_cls,
                 target_cls,
                 plot=False,
                 fname='precision-recall_curve.png'):
    """Compute per-class precision, recall, AP and F1 for detections.

    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.

    Args:
        tp (np.ndarray): True-positive flags of each detection, shape
            (n, k) where column j corresponds to the j-th IoU threshold
            (e.g. k=1 for mAP@0.5, k=10 for mAP@0.5:0.95).
        conf (np.ndarray): Objectness/confidence of each detection in [0, 1].
        pred_cls (np.ndarray): Predicted class of each detection.
        target_cls (np.ndarray): Class of each ground-truth object.
        plot (bool): Save a precision-recall plot built from the first
            column of ``tp``. Skipped when no class produced a curve.
            Default: False.
        fname (str): Output path of the plot.

    Returns:
        tuple: ``(p, r, ap, f1, classes)``. ``p``, ``r``, ``ap`` and ``f1``
        have shape (num_classes, k). ``p`` and ``r`` are interpolated at
        confidence 0.1. ``classes`` holds the sorted unique values of
        ``target_cls`` as int32. Classes without predictions keep zeros.
    """
    # Sort detections by descending confidence.
    order = np.argsort(-conf)
    tp, conf, pred_cls = tp[order], conf[order], pred_cls[order]

    unique_classes = np.unique(target_cls)

    px, py = np.linspace(0, 1, 1000), []  # recall grid + per-class curves for plotting
    pr_score = 0.1  # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
    shape = [unique_classes.shape[0], tp.shape[1]]  # (num classes, num IoU thresholds)
    ap, p, r = np.zeros(shape), np.zeros(shape), np.zeros(shape)

    for ci, c in enumerate(unique_classes):
        i = pred_cls == c
        n_gt = (target_cls == c).sum()  # number of ground-truth objects
        n_p = i.sum()  # number of predicted objects
        if n_p == 0 or n_gt == 0:
            continue

        # Cumulative false/true positives along descending confidence.
        fpc = (1 - tp[i]).cumsum(0)
        tpc = tp[i].cumsum(0)

        recall = tpc / (n_gt + 1e-16)
        # np.interp needs increasing xp; confidences are descending, so negate.
        r[ci] = np.interp(-pr_score, -conf[i], recall[:, 0])

        precision = tpc / (tpc + fpc)
        p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0])

        # Precision resampled on the recall grid (mAP@0.5 column) for the plot.
        py.append(np.interp(px, recall[:, 0], precision[:, 0]))

        for j in range(tp.shape[1]):
            ap[ci, j] = compute_ap(recall[:, j], precision[:, j])

    # F1 score (harmonic mean of precision and recall).
    f1 = 2 * p * r / (p + r + 1e-16)

    # np.stack raises on an empty list, so only plot when a curve exists.
    if plot and py:
        curves = np.stack(py, axis=1)  # (len(px), num plotted classes)
        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
        try:
            ax.plot(px, curves, linewidth=0.5, color='grey')
            ax.plot(
                px,
                curves.mean(1),
                linewidth=2,
                color='blue',
                label='all classes')
            ax.set_xlabel('Recall')
            ax.set_ylabel('Precision')
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.legend()
            fig.tight_layout()
            fig.savefig(fname, dpi=200)
        finally:
            # Release the figure; pyplot otherwise keeps it alive.
            plt.close(fig)

    return p, r, ap, f1, unique_classes.astype('int32')
def compute_ap(recall, precision, method='interp'):
    """Compute the average precision, given the recall and precision curves.

    Source: https://github.com/rbgirshick/py-faster-rcnn.

    Args:
        recall (array-like): The recall curve (1-D). Presumably
            non-decreasing, as produced by cumulative sums.
        precision (array-like): The precision curve, same length as
            ``recall``.
        method (str): ``'interp'`` for 101-point interpolation (COCO) or
            ``'continuous'`` for the exact area under the precision
            envelope. Default: ``'interp'``.

    Returns:
        float: The average precision. 0.0 when the curves are empty.

    Raises:
        ValueError: If ``method`` is not one of the supported values.
    """
    if method not in ('interp', 'continuous'):
        raise ValueError(
            f"method must be 'interp' or 'continuous', got {method!r}")

    recall = np.asarray(recall)
    precision = np.asarray(precision)
    if recall.size == 0:
        # No points on the curve: nothing to integrate.
        return 0.0

    # Append sentinel values to beginning and end.
    mrec = np.concatenate(([0.], recall, [min(recall[-1] + 1E-3, 1.)]))
    mpre = np.concatenate(([0.], precision, [0.]))

    # Precision envelope: running max from the right makes it non-increasing.
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))

    if method == 'interp':
        x = np.linspace(0, 1, 101)  # 101-point interp (COCO)
        # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
        trapezoid = getattr(np, 'trapezoid', None)
        if trapezoid is None:
            trapezoid = np.trapz
        ap = trapezoid(np.interp(x, mrec, mpre), x)
    else:  # 'continuous'
        # Indices where the recall (x axis) changes.
        i = np.where(mrec[1:] != mrec[:-1])[0]
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap
def f_score(precision, recall, beta=1):
    """Calculate the F-beta score.

    Args:
        precision (float | np.ndarray | torch.Tensor): The precision value.
        recall (float | np.ndarray | torch.Tensor): The recall value.
        beta (int | float): Determines the weight of recall in the combined
            score. ``inf`` returns ``recall`` unchanged. Default: 1.

    Returns:
        float | np.ndarray | torch.Tensor: The F-beta score. Where the
        denominator ``beta**2 * precision + recall`` is 0, it is replaced by
        1 to avoid division by zero.
    """
    if np.isposinf(beta):
        return recall

    beta2 = beta**2
    denom = beta2 * precision + recall
    if isinstance(denom, (np.ndarray, torch.Tensor)):
        # Element-wise guard; denom is a fresh array, so in-place is safe.
        denom[denom == 0.] = 1
    elif denom == 0.:
        # Python / NumPy scalars do not support item assignment.
        denom = 1
    return (1 + beta2) * precision * recall / denom
def accuracy(pred, target, topk=1, thresh=None, ignore_index=None):
    """Calculate accuracy according to the prediction and target.

    Args:
        pred (torch.Tensor): The model prediction, shape (N, num_class, ...).
        target (torch.Tensor): The target of each prediction, shape (N, ...).
        topk (int | tuple[int], optional): If the predictions in ``topk``
            match the target, the predictions will be regarded as
            correct ones. Defaults to 1.
        thresh (float, optional): If not None, predictions whose score is
            not above this threshold are considered incorrect.
            Defaults to None.
        ignore_index (int | None): Target label excluded from both the
            correct count and the total. Defaults to None.

    Returns:
        torch.Tensor | list[torch.Tensor]: Accuracy in percent. If ``topk``
        is an int, a single tensor is returned. Otherwise a list with one
        tensor per ``topk`` value is returned.
    """
    assert isinstance(topk, (int, tuple))
    if isinstance(topk, int):
        topk = (topk, )
        return_single = True
    else:
        return_single = False

    maxk = max(topk)
    if pred.size(0) == 0:
        accu = [pred.new_tensor(0.) for _ in range(len(topk))]
        return accu[0] if return_single else accu
    assert pred.ndim == target.ndim + 1
    assert pred.size(0) == target.size(0)
    assert maxk <= pred.size(1), \
        f'maxk {maxk} exceeds pred dimension {pred.size(1)}'

    pred_value, pred_label = pred.topk(maxk, dim=1)
    # transpose to shape (maxk, N, ...)
    pred_label = pred_label.transpose(0, 1)
    correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
    if thresh is not None:
        # Only prediction values larger than thresh are counted as correct.
        # Use transpose(0, 1) (as for pred_label) rather than .t(): .t()
        # rejects tensors with more than 2 dims, e.g. segmentation scores.
        correct = correct & (pred_value > thresh).transpose(0, 1)

    if ignore_index is not None:
        valid = target != ignore_index
        correct = correct[:, valid]
        total_num = target[valid].numel()
    else:
        total_num = target.numel()

    # Avoid causing ZeroDivisionError when all pixels of an image are
    # ignored; eps is added to both numerator and denominator.
    eps = torch.finfo(torch.float32).eps
    total_num = total_num + eps  # loop-invariant, computed once

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps
        res.append(correct_k.mul_(100.0 / total_num))
    return res[0] if return_single else res