# faiss/contrib/evaluation.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import unittest
from multiprocessing.pool import ThreadPool
###############################################################
# Simple functions to evaluate knn results
def knn_intersection_measure(I1, I2):
    """Fraction of neighbor ids shared between two knn result tables.

    Both tables must have shape (nq, rank); each row is compared as a
    set of ids, and the total intersection size is normalized by the
    number of entries.
    """
    nq, rank = I1.shape
    assert I2.shape == (nq, rank)
    total_common = 0
    for row1, row2 in zip(I1, I2):
        total_common += np.intersect1d(row1, row2).size
    return total_common / I1.size
###############################################################
# Range search results can be compared with Precision-Recall
def filter_range_results(lims, D, I, thresh):
    """Keep only the range-search results with distance strictly below thresh.

    lims, D, I is the usual CSR-style layout (lims has nq + 1 entries);
    returns (new_lims, new_D, new_I) in the same layout.
    """
    nq = lims.size - 1
    keep = D < thresh
    new_lims = np.zeros_like(lims)
    for q in range(nq):
        # number of survivors for query q, accumulated into the limits
        new_lims[q + 1] = new_lims[q] + keep[lims[q]:lims[q + 1]].sum()
    return new_lims, D[keep], I[keep]
def range_PR(lims_ref, Iref, lims_new, Inew, mode="overall"):
    """compute the precision and recall of range search results. The
    function does not take the distances into account.

    lims_ref, Iref: reference (ground-truth) results, CSR layout
    lims_new, Inew: results to evaluate, same layout
    mode: "overall" pools the counts over all queries, "average"
        averages per-query precision and recall (see counts_to_PR)
    returns (precision, recall)
    """

    def ref_result_for(i):
        return Iref[lims_ref[i]:lims_ref[i + 1]]

    def new_result_for(i):
        return Inew[lims_new[i]:lims_new[i + 1]]

    nq = lims_ref.size - 1
    assert lims_new.size - 1 == nq

    ninter = np.zeros(nq, dtype="int64")

    def compute_PR_for(q):
        # ground truth results for this query
        gt_ids = ref_result_for(q)
        # results for this query
        new_ids = new_result_for(q)
        # there are no set functions in numpy so let's do this
        ninter[q] = np.intersect1d(gt_ids, new_ids).size

    # run in a thread pool, which helps in spite of the GIL;
    # the context manager releases the worker threads when done
    # (the previous version leaked the pool)
    with ThreadPool(20) as pool:
        pool.map(compute_PR_for, range(nq))

    return counts_to_PR(
        lims_ref[1:] - lims_ref[:-1],
        lims_new[1:] - lims_new[:-1],
        ninter,
        mode=mode
    )
def counts_to_PR(ngt, nres, ninter, mode="overall"):
    """ computes a precision-recall for a set of queries.
    ngt = nb of GT results per query
    nres = nb of found results per query
    ninter = nb of correct results per query (smaller than nres of course)
    mode = "overall": pool the counts over all queries;
           "average": average the per-query precision and recall
    returns (precision, recall)
    """
    if mode == "overall":
        ngt, nres, ninter = ngt.sum(), nres.sum(), ninter.sum()

        if nres > 0:
            precision = ninter / nres
        else:
            # nothing returned: no wrong result either
            precision = 1.0

        if ngt > 0:
            recall = ninter / ngt
        elif nres == 0:
            # no GT and nothing returned: perfect recall
            recall = 1.0
        else:
            recall = 0.0

        return precision, recall

    elif mode == "average":
        # work on copies: the original version clobbered the caller's
        # arrays with the divide-by-zero fixups below
        ngt = ngt.copy()
        nres = nres.copy()
        ninter = ninter.copy()

        # average precision and recall over queries
        mask = ngt == 0
        ngt[mask] = 1

        recalls = ninter / ngt
        # queries without GT have recall 1 iff nothing was returned
        recalls[mask] = (nres[mask] == 0).astype(float)

        # avoid division by 0
        mask = nres == 0
        assert np.all(ninter[mask] == 0)
        ninter[mask] = 1
        nres[mask] = 1

        precisions = ninter / nres

        return precisions.mean(), recalls.mean()

    else:
        raise AssertionError(f"unknown mode {mode!r}")
def sort_range_res_2(lims, D, I):
    """Sort each query's (D, I) result list by increasing distance.

    Returns new arrays (sorted_I, sorted_D); the inputs are untouched.
    """
    sorted_I = np.empty_like(I)
    sorted_D = np.empty_like(D)
    for begin, end in zip(lims[:-1], lims[1:]):
        order = D[begin:end].argsort()
        sorted_I[begin:end] = I[begin:end][order]
        sorted_D[begin:end] = D[begin:end][order]
    return sorted_I, sorted_D
def sort_range_res_1(lims, I):
    """Sort the ids of each query's result list in ascending order.

    Returns a new array; the input is untouched.
    """
    out = np.empty_like(I)
    for q in range(len(lims) - 1):
        begin, end = lims[q], lims[q + 1]
        out[begin:end] = np.sort(I[begin:end])
    return out
def range_PR_multiple_thresholds(
        lims_ref, Iref,
        lims_new, Dnew, Inew,
        thresholds,
        mode="overall", do_sort="ref,new"
):
    """ compute precision-recall values for range search results
    for several thresholds on the "new" results.
    This is to plot PR curves

    lims_ref, Iref: reference results (CSR layout)
    lims_new, Dnew, Inew: results to evaluate, with distances
    thresholds: array of increasing distance thresholds
    do_sort: which inputs still need sorting ("ref" by id, "new" by
        distance); pass "" if both are already sorted
    returns (precisions, recalls), each of size len(thresholds)
    """
    # ref should be sorted by ids
    if "ref" in do_sort:
        Iref = sort_range_res_1(lims_ref, Iref)
    # new should be sorted by distances
    if "new" in do_sort:
        Inew, Dnew = sort_range_res_2(lims_new, Dnew, Inew)

    def ref_result_for(i):
        return Iref[lims_ref[i]:lims_ref[i + 1]]

    def new_result_for(i):
        l0, l1 = lims_new[i], lims_new[i + 1]
        return Inew[l0:l1], Dnew[l0:l1]

    nq = lims_ref.size - 1
    assert lims_new.size - 1 == nq
    nt = len(thresholds)

    # per query and per threshold: (nb GT, nb results, nb correct)
    counts = np.zeros((nq, nt, 3), dtype="int64")

    def compute_PR_for(q):
        gt_ids = ref_result_for(q)
        res_ids, res_dis = new_result_for(q)

        counts[q, :, 0] = len(gt_ids)

        if res_dis.size == 0:
            # the rest remains at 0
            return

        # which offsets we are interested in
        nres = np.searchsorted(res_dis, thresholds)
        counts[q, :, 1] = nres

        if gt_ids.size == 0:
            return

        # find number of TPs at each stage in the result list
        ii = np.searchsorted(gt_ids, res_ids)
        ii[ii == len(gt_ids)] = -1
        n_ok = np.cumsum(gt_ids[ii] == res_ids)

        # focus on threshold points
        n_ok = np.hstack(([0], n_ok))
        counts[q, :, 2] = n_ok[nres]

    # run in a thread pool, which helps in spite of the GIL;
    # the context manager releases the worker threads when done
    # (the previous version leaked the pool)
    with ThreadPool(20) as pool:
        pool.map(compute_PR_for, range(nq))

    precisions = np.zeros(nt)
    recalls = np.zeros(nt)
    for t in range(nt):
        p, r = counts_to_PR(
            counts[:, t, 0], counts[:, t, 1], counts[:, t, 2],
            mode=mode
        )
        precisions[t] = p
        recalls[t] = r

    return precisions, recalls
###############################################################
# Functions that compare search results with a reference result.
# They are intended for use in tests
def test_ref_knn_with_draws(Dref, Iref, Dnew, Inew):
""" test that knn search results are identical, raise if not """
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
# here we have to be careful because of draws
testcase = unittest.TestCase() # because it makes nice error messages
for i in range(len(Iref)):
if np.all(Iref[i] == Inew[i]): # easy case
continue
# we can deduce nothing about the latest line
skip_dis = Dref[i, -1]
for dis in np.unique(Dref):
if dis == skip_dis:
continue
mask = Dref[i, :] == dis
testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask]))
def test_ref_range_results(lims_ref, Dref, Iref,
lims_new, Dnew, Inew):
""" compare range search results wrt. a reference result,
throw if it fails """
np.testing.assert_array_equal(lims_ref, lims_new)
nq = len(lims_ref) - 1
for i in range(nq):
l0, l1 = lims_ref[i], lims_ref[i + 1]
Ii_ref = Iref[l0:l1]
Ii_new = Inew[l0:l1]
Di_ref = Dref[l0:l1]
Di_new = Dnew[l0:l1]
if np.all(Ii_ref == Ii_new): # easy
pass
else:
def sort_by_ids(I, D):
o = I.argsort()
return I[o], D[o]
# sort both
(Ii_ref, Di_ref) = sort_by_ids(Ii_ref, Di_ref)
(Ii_new, Di_new) = sort_by_ids(Ii_new, Di_new)
np.testing.assert_array_equal(Ii_ref, Ii_new)
np.testing.assert_array_almost_equal(Di_ref, Di_new, decimal=5)
###############################################################
# OperatingPoints functions
# this is the Python version of the AutoTune object in C++
class OperatingPoints:
    """
    Manages a set of search parameters with associated performance and time.
    Keeps the Pareto optimal points.
    """

    def __init__(self):
        # Pareto-optimal points seen so far, as (key, perf, t) tuples
        self.operating_points = [
            # (self.do_nothing_key(), 0.0, 0.0)
        ]
        # dominated points; kept because they still inform predict_bounds
        self.suboptimal_points = []

    def compare_keys(self, k1, k2):
        """ return 1 if k1 >= k2 in all dimensions (ie. k1 is at least as
        slow and as accurate as k2), -1 if k2 >= k1, 0 if the keys are
        not comparable. To implement in subclasses. """
        # NOTE: was `raise NotImplemented`, which raises a TypeError in
        # Python 3 (NotImplemented is not an exception)
        raise NotImplementedError

    def do_nothing_key(self):
        """ parameters to say we do nothing, takes 0 time and has 0 performance"""
        raise NotImplementedError

    def is_pareto_optimal(self, perf_new, t_new):
        """ whether (perf_new, t_new) would be kept: no recorded operating
        point is both at least as accurate and at least as fast """
        for _, perf, t in self.operating_points:
            if perf >= perf_new and t <= t_new:
                return False
        return True

    def predict_bounds(self, key):
        """ predicts the bound on time and performance """
        min_time = 0.0
        max_perf = 1.0
        for key2, perf, t in self.operating_points + self.suboptimal_points:
            cmp = self.compare_keys(key, key2)
            if cmp > 0:  # key >= key2: key cannot be faster than key2
                if t > min_time:
                    min_time = t
            if cmp < 0:  # key <= key2: key cannot be more accurate than key2
                if perf < max_perf:
                    max_perf = perf
        return max_perf, min_time

    def should_run_experiment(self, key):
        """ whether the best possible outcome for key could still be
        Pareto optimal """
        (max_perf, min_time) = self.predict_bounds(key)
        return self.is_pareto_optimal(max_perf, min_time)

    def add_operating_point(self, key, perf, t):
        """ record a measured (perf, t) for key; returns True iff the
        point is kept as Pareto optimal """
        if self.is_pareto_optimal(perf, t):
            i = 0
            # maybe it shadows some other operating point completely?
            while i < len(self.operating_points):
                _, perf2, t2 = self.operating_points[i]
                if perf >= perf2 and t < t2:
                    # the new point dominates this one: demote it
                    self.suboptimal_points.append(
                        self.operating_points.pop(i))
                else:
                    i += 1
            self.operating_points.append((key, perf, t))
            return True
        else:
            self.suboptimal_points.append((key, perf, t))
            return False
class OperatingPointsWithRanges(OperatingPoints):
    """
    Set of parameters that are each picked from a discrete range of values.
    An increase of each parameter is assumed to make the operation slower
    and more accurate.
    A key = int array of indices in the ordered set of parameters.
    """

    def __init__(self):
        OperatingPoints.__init__(self)
        # list of (name, values)
        self.ranges = []

    def add_range(self, name, values):
        """ declare a new parameter with its ordered list of values """
        self.ranges.append((name, values))

    def compare_keys(self, k1, k2):
        """ partial order on keys: 1 if k1 dominates k2 on every axis,
        -1 if k2 dominates, 0 if they are not comparable """
        if np.all(k1 >= k2):
            return 1
        return -1 if np.all(k2 >= k1) else 0

    def do_nothing_key(self):
        # index 0 on every axis = smallest value of each parameter
        return np.zeros(len(self.ranges), dtype=int)

    def num_experiments(self):
        """ total size of the parameter grid """
        return np.prod([len(vals) for _, vals in self.ranges])

    def cno_to_key(self, cno):
        """Convert a sequential experiment number to a key"""
        key = np.zeros(len(self.ranges), dtype=int)
        for axis, (_, vals) in enumerate(self.ranges):
            cno, key[axis] = divmod(cno, len(vals))
        assert cno == 0
        return key

    def get_parameters(self, k):
        """Convert a key to a dictionary with parameter values"""
        params = {}
        for axis, (name, vals) in enumerate(self.ranges):
            params[name] = vals[k[axis]]
        return params

    def restrict_range(self, name, max_val):
        """ remove too large values from a range"""
        for name2, vals in self.ranges:
            if name2 == name:
                # in-place so existing references to the list see the change
                vals[:] = [v for v in vals if v < max_val]
                return
        raise RuntimeError(f"parameter {name} not found")