# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import numpy as np import unittest from multiprocessing.pool import ThreadPool ############################################################### # Simple functions to evaluate knn results def knn_intersection_measure(I1, I2): """ computes the intersection measure of two result tables """ nq, rank = I1.shape assert I2.shape == (nq, rank) ninter = sum( np.intersect1d(I1[i], I2[i]).size for i in range(nq) ) return ninter / I1.size ############################################################### # Range search results can be compared with Precision-Recall def filter_range_results(lims, D, I, thresh): """ select a set of results """ nq = lims.size - 1 mask = D < thresh new_lims = np.zeros_like(lims) for i in range(nq): new_lims[i + 1] = new_lims[i] + mask[lims[i] : lims[i + 1]].sum() return new_lims, D[mask], I[mask] def range_PR(lims_ref, Iref, lims_new, Inew, mode="overall"): """compute the precision and recall of range search results. The function does not take the distances into account. """ def ref_result_for(i): return Iref[lims_ref[i]:lims_ref[i + 1]] def new_result_for(i): return Inew[lims_new[i]:lims_new[i + 1]] nq = lims_ref.size - 1 assert lims_new.size - 1 == nq ninter = np.zeros(nq, dtype="int64") def compute_PR_for(q): # ground truth results for this query gt_ids = ref_result_for(q) # results for this query new_ids = new_result_for(q) # there are no set functions in numpy so let's do this inter = np.intersect1d(gt_ids, new_ids) ninter[q] = len(inter) # run in a thread pool, which helps in spite of the GIL pool = ThreadPool(20) pool.map(compute_PR_for, range(nq)) return counts_to_PR( lims_ref[1:] - lims_ref[:-1], lims_new[1:] - lims_new[:-1], ninter, mode=mode ) def counts_to_PR(ngt, nres, ninter, mode="overall"): """ computes a precision-recall for a ser of queries. ngt = nb of GT results per query nres = nb of found results per query ninter = nb of correct results per query (smaller than nres of course) """ if mode == "overall": ngt, nres, ninter = ngt.sum(), nres.sum(), ninter.sum() if nres > 0: precision = ninter / nres else: precision = 1.0 if ngt > 0: recall = ninter / ngt elif nres == 0: recall = 1.0 else: recall = 0.0 return precision, recall elif mode == "average": # average precision and recall over queries mask = ngt == 0 ngt[mask] = 1 recalls = ninter / ngt recalls[mask] = (nres[mask] == 0).astype(float) # avoid division by 0 mask = nres == 0 assert np.all(ninter[mask] == 0) ninter[mask] = 1 nres[mask] = 1 precisions = ninter / nres return precisions.mean(), recalls.mean() else: raise AssertionError() def sort_range_res_2(lims, D, I): """ sort 2 arrays using the first as key """ I2 = np.empty_like(I) D2 = np.empty_like(D) nq = len(lims) - 1 for i in range(nq): l0, l1 = lims[i], lims[i + 1] ii = I[l0:l1] di = D[l0:l1] o = di.argsort() I2[l0:l1] = ii[o] D2[l0:l1] = di[o] return I2, D2 def sort_range_res_1(lims, I): I2 = np.empty_like(I) nq = len(lims) - 1 for i in range(nq): l0, l1 = lims[i], lims[i + 1] I2[l0:l1] = I[l0:l1] I2[l0:l1].sort() return I2 def range_PR_multiple_thresholds( lims_ref, Iref, lims_new, Dnew, Inew, thresholds, mode="overall", do_sort="ref,new" ): """ compute precision-recall values for range search results for several thresholds on the "new" results. This is to plot PR curves """ # ref should be sorted by ids if "ref" in do_sort: Iref = sort_range_res_1(lims_ref, Iref) # new should be sorted by distances if "new" in do_sort: Inew, Dnew = sort_range_res_2(lims_new, Dnew, Inew) def ref_result_for(i): return Iref[lims_ref[i]:lims_ref[i + 1]] def new_result_for(i): l0, l1 = lims_new[i], lims_new[i + 1] return Inew[l0:l1], Dnew[l0:l1] nq = lims_ref.size - 1 assert lims_new.size - 1 == nq nt = len(thresholds) counts = np.zeros((nq, nt, 3), dtype="int64") def compute_PR_for(q): gt_ids = ref_result_for(q) res_ids, res_dis = new_result_for(q) counts[q, :, 0] = len(gt_ids) if res_dis.size == 0: # the rest remains at 0 return # which offsets we are interested in nres= np.searchsorted(res_dis, thresholds) counts[q, :, 1] = nres if gt_ids.size == 0: return # find number of TPs at each stage in the result list ii = np.searchsorted(gt_ids, res_ids) ii[ii == len(gt_ids)] = -1 n_ok = np.cumsum(gt_ids[ii] == res_ids) # focus on threshold points n_ok = np.hstack(([0], n_ok)) counts[q, :, 2] = n_ok[nres] pool = ThreadPool(20) pool.map(compute_PR_for, range(nq)) # print(counts.transpose(2, 1, 0)) precisions = np.zeros(nt) recalls = np.zeros(nt) for t in range(nt): p, r = counts_to_PR( counts[:, t, 0], counts[:, t, 1], counts[:, t, 2], mode=mode ) precisions[t] = p recalls[t] = r return precisions, recalls ############################################################### # Functions that compare search results with a reference result. # They are intended for use in tests def test_ref_knn_with_draws(Dref, Iref, Dnew, Inew): """ test that knn search results are identical, raise if not """ np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5) # here we have to be careful because of draws testcase = unittest.TestCase() # because it makes nice error messages for i in range(len(Iref)): if np.all(Iref[i] == Inew[i]): # easy case continue # we can deduce nothing about the latest line skip_dis = Dref[i, -1] for dis in np.unique(Dref): if dis == skip_dis: continue mask = Dref[i, :] == dis testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask])) def test_ref_range_results(lims_ref, Dref, Iref, lims_new, Dnew, Inew): """ compare range search results wrt. a reference result, throw if it fails """ np.testing.assert_array_equal(lims_ref, lims_new) nq = len(lims_ref) - 1 for i in range(nq): l0, l1 = lims_ref[i], lims_ref[i + 1] Ii_ref = Iref[l0:l1] Ii_new = Inew[l0:l1] Di_ref = Dref[l0:l1] Di_new = Dnew[l0:l1] if np.all(Ii_ref == Ii_new): # easy pass else: def sort_by_ids(I, D): o = I.argsort() return I[o], D[o] # sort both (Ii_ref, Di_ref) = sort_by_ids(Ii_ref, Di_ref) (Ii_new, Di_new) = sort_by_ids(Ii_new, Di_new) np.testing.assert_array_equal(Ii_ref, Ii_new) np.testing.assert_array_almost_equal(Di_ref, Di_new, decimal=5) ############################################################### # OperatingPoints functions # this is the Python version of the AutoTune object in C++ class OperatingPoints: """ Manages a set of search parameters with associated performance and time. Keeps the Pareto optimal points. """ def __init__(self): # list of (key, perf, t) self.operating_points = [ # (self.do_nothing_key(), 0.0, 0.0) ] self.suboptimal_points = [] def compare_keys(self, k1, k2): """ return -1 if k1 > k2, 1 if k2 > k1, 0 otherwise """ raise NotImplemented def do_nothing_key(self): """ parameters to say we do noting, takes 0 time and has 0 performance""" raise NotImplemented def is_pareto_optimal(self, perf_new, t_new): for _, perf, t in self.operating_points: if perf >= perf_new and t <= t_new: return False return True def predict_bounds(self, key): """ predicts the bound on time and performance """ min_time = 0.0 max_perf = 1.0 for key2, perf, t in self.operating_points + self.suboptimal_points: cmp = self.compare_keys(key, key2) if cmp > 0: # key2 > key if t > min_time: min_time = t if cmp < 0: # key2 < key if perf < max_perf: max_perf = perf return max_perf, min_time def should_run_experiment(self, key): (max_perf, min_time) = self.predict_bounds(key) return self.is_pareto_optimal(max_perf, min_time) def add_operating_point(self, key, perf, t): if self.is_pareto_optimal(perf, t): i = 0 # maybe it shadows some other operating point completely? while i < len(self.operating_points): op_Ls, perf2, t2 = self.operating_points[i] if perf >= perf2 and t < t2: self.suboptimal_points.append( self.operating_points.pop(i)) else: i += 1 self.operating_points.append((key, perf, t)) return True else: self.suboptimal_points.append((key, perf, t)) return False class OperatingPointsWithRanges(OperatingPoints): """ Set of parameters that are each picked from a discrete range of values. An increase of each parameter is assumed to make the operation slower and more accurate. A key = int array of indices in the ordered set of parameters. """ def __init__(self): OperatingPoints.__init__(self) # list of (name, values) self.ranges = [] def add_range(self, name, values): self.ranges.append((name, values)) def compare_keys(self, k1, k2): if np.all(k1 >= k2): return 1 if np.all(k2 >= k1): return -1 return 0 def do_nothing_key(self): return np.zeros(len(self.ranges), dtype=int) def num_experiments(self): return np.prod([len(values) for name, values in self.ranges]) def cno_to_key(self, cno): """Convert a sequential experiment number to a key""" k = np.zeros(len(self.ranges), dtype=int) for i, (name, values) in enumerate(self.ranges): k[i] = cno % len(values) cno //= len(values) assert cno == 0 return k def get_parameters(self, k): """Convert a key to a dictionary with parameter values""" return { name: values[k[i]] for i, (name, values) in enumerate(self.ranges) } def restrict_range(self, name, max_val): """ remove too large values from a range""" for name2, values in self.ranges: if name == name2: val2 = [v for v in values if v < max_val] values[:] = val2 return raise RuntimeError(f"parameter {name} not found")