Revised according to comments
parent
a3407c2cf4
commit
48a781cd6a
|
@ -5,6 +5,15 @@ import torch
|
||||||
def average_precision(pred, target):
|
def average_precision(pred, target):
|
||||||
""" Calculate the average precision for a single class
|
""" Calculate the average precision for a single class
|
||||||
|
|
||||||
|
AP summarizes a precision-recall curve as the weighted mean of maximum
|
||||||
|
precisions obtained for any r'>r, where r is the recall:
|
||||||
|
|
||||||
|
..math::
|
||||||
|
\\text{AP} = \\sum_n (R_n - R_{n-1}) P_n
|
||||||
|
|
||||||
|
Note that no approximation is involved since the curve is piecewise
|
||||||
|
constant.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pred (np.ndarray): The model prediction with shape (N, ).
|
pred (np.ndarray): The model prediction with shape (N, ).
|
||||||
target (np.ndarray): The target of each prediction with shape (N, ).
|
target (np.ndarray): The target of each prediction with shape (N, ).
|
||||||
|
@ -19,17 +28,17 @@ def average_precision(pred, target):
|
||||||
sort_target = target[sort_inds]
|
sort_target = target[sort_inds]
|
||||||
|
|
||||||
# count true positive examples
|
# count true positive examples
|
||||||
p_inds = sort_target == 1
|
pos_inds = sort_target == 1
|
||||||
tp = np.cumsum(p_inds)
|
tp = np.cumsum(pos_inds)
|
||||||
total_p = tp[-1]
|
total_pos = tp[-1]
|
||||||
|
|
||||||
# count not difficult examples
|
# count not difficult examples
|
||||||
pn_inds = sort_target != -1
|
pn_inds = sort_target != -1
|
||||||
pn = np.cumsum(pn_inds)
|
pn = np.cumsum(pn_inds)
|
||||||
|
|
||||||
tp[np.logical_not(p_inds)] = 0
|
tp[np.logical_not(pos_inds)] = 0
|
||||||
precision = tp / np.maximum(pn, eps)
|
precision = tp / np.maximum(pn, eps)
|
||||||
ap = np.sum(precision) / np.maximum(total_p, eps)
|
ap = np.sum(precision) / np.maximum(total_pos, eps)
|
||||||
return ap
|
return ap
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue