Revised according to comments

pull/123/head
lixinran 2020-12-31 10:02:23 +08:00
parent a3407c2cf4
commit 48a781cd6a
1 changed files with 14 additions and 5 deletions

View File

@ -5,6 +5,15 @@ import torch
def average_precision(pred, target):
""" Calculate the average precision for a single class
AP summarizes a precision-recall curve as the weighted mean of maximum
precisions obtained for any r'>r, where r is the recall:
..math::
\\text{AP} = \\sum_n (R_n - R_{n-1}) P_n
Note that no approximation is involved since the curve is piecewise
constant.
Args:
pred (np.ndarray): The model prediction with shape (N, ).
target (np.ndarray): The target of each prediction with shape (N, ).
@ -19,17 +28,17 @@ def average_precision(pred, target):
sort_target = target[sort_inds]
# count true positive examples
p_inds = sort_target == 1
tp = np.cumsum(p_inds)
total_p = tp[-1]
pos_inds = sort_target == 1
tp = np.cumsum(pos_inds)
total_pos = tp[-1]
# count not difficult examples
pn_inds = sort_target != -1
pn = np.cumsum(pn_inds)
tp[np.logical_not(p_inds)] = 0
tp[np.logical_not(pos_inds)] = 0
precision = tp / np.maximum(pn, eps)
ap = np.sum(precision) / np.maximum(total_p, eps)
ap = np.sum(precision) / np.maximum(total_pos, eps)
return ap