diff --git a/contrib/README.md b/contrib/README.md index b2efbe7d7..117546359 100644 --- a/contrib/README.md +++ b/contrib/README.md @@ -54,3 +54,7 @@ Defintion of how to access data for some standard datsets. ### factory_tools.py Functions related to factory strings. + +### evaluation.py + +A few non-trivial evaluation functions for search results diff --git a/contrib/evaluation.py b/contrib/evaluation.py new file mode 100644 index 000000000..d69dfaf88 --- /dev/null +++ b/contrib/evaluation.py @@ -0,0 +1,268 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import unittest + +from multiprocessing.dummy import Pool as ThreadPool + +############################################################### +# Simple functions to evaluate knn results + +def knn_intersection_measure(I1, I2): + """ computes the intersection measure of two result tables + """ + nq, rank = I1.shape + assert I2.shape == (nq, rank) + ninter = sum( + np.intersect1d(I1[i], I2[i]).size + for i in range(nq) + ) + return ninter / I1.size + +############################################################### +# Range search results can be compared with Precision-Recall + +def filter_range_results(lims, D, I, thresh): + """ select a set of results """ + nq = lims.size - 1 + mask = D < thresh + new_lims = np.zeros_like(lims) + for i in range(nq): + new_lims[i + 1] = new_lims[i] + mask[lims[i] : lims[i + 1]].sum() + return new_lims, D[mask], I[mask] + + +def range_PR(lims_ref, Iref, lims_new, Inew, mode="overall"): + """compute the precision and recall of range search results. The + function does not take the distances into account. """ + + def ref_result_for(i): + return Iref[lims_ref[i]:lims_ref[i + 1]] + + def new_result_for(i): + return Inew[lims_new[i]:lims_new[i + 1]] + + nq = lims_ref.size - 1 + assert lims_new.size - 1 == nq + + ninter = np.zeros(nq, dtype="int64") + + def compute_PR_for(q): + + # ground truth results for this query + gt_ids = ref_result_for(q) + + # results for this query + new_ids = new_result_for(q) + + # there are no set functions in numpy so let's do this + inter = np.intersect1d(gt_ids, new_ids) + + ninter[q] = len(inter) + + # run in a thread pool, which helps in spite of the GIL + pool = ThreadPool(20) + pool.map(compute_PR_for, range(nq)) + + return counts_to_PR( + lims_ref[1:] - lims_ref[:-1], + lims_new[1:] - lims_new[:-1], + ninter, + mode=mode + ) + + +def counts_to_PR(ngt, nres, ninter, mode="overall"): + """ computes a precision-recall for a ser of queries. + ngt = nb of GT results per query + nres = nb of found results per query + ninter = nb of correct results per query (smaller than nres of course) + """ + + if mode == "overall": + ngt, nres, ninter = ngt.sum(), nres.sum(), ninter.sum() + + if nres > 0: + precision = ninter / nres + else: + precision = 1.0 + + if ngt > 0: + recall = ninter / ngt + elif nres == 0: + recall = 1.0 + else: + recall = 0.0 + + return precision, recall + + elif mode == "average": + # average precision and recall over queries + + mask = ngt == 0 + ngt[mask] = 1 + + recalls = ninter / ngt + recalls[mask] = (nres[mask] == 0).astype(float) + + # avoid division by 0 + mask = nres == 0 + assert np.all(ninter[mask] == 0) + ninter[mask] = 1 + nres[mask] = 1 + + precisions = ninter / nres + + return precisions.mean(), recalls.mean() + + else: + raise AssertionError() + +def sort_range_res_2(lims, D, I): + """ sort 2 arrays using the first as key """ + I2 = np.empty_like(I) + D2 = np.empty_like(D) + nq = len(lims) - 1 + for i in range(nq): + l0, l1 = lims[i], lims[i + 1] + ii = I[l0:l1] + di = D[l0:l1] + o = di.argsort() + I2[l0:l1] = ii[o] + D2[l0:l1] = di[o] + return I2, D2 + + +def sort_range_res_1(lims, I): + I2 = np.empty_like(I) + nq = len(lims) - 1 + for i in range(nq): + l0, l1 = lims[i], lims[i + 1] + I2[l0:l1] = I[l0:l1] + I2[l0:l1].sort() + return I2 + + +def range_PR_multiple_thresholds( + lims_ref, Iref, + lims_new, Dnew, Inew, + thresholds, + mode="overall", do_sort="ref,new" + ): + """ compute precision-recall values for range search results + for several thresholds on the "new" results. + This is to plot PR curves + """ + # ref should be sorted by ids + if "ref" in do_sort: + Iref = sort_range_res_1(lims_ref, Iref) + + # new should be sorted by distances + if "new" in do_sort: + Inew, Dnew = sort_range_res_2(lims_new, Dnew, Inew) + + def ref_result_for(i): + return Iref[lims_ref[i]:lims_ref[i + 1]] + + def new_result_for(i): + l0, l1 = lims_new[i], lims_new[i + 1] + return Inew[l0:l1], Dnew[l0:l1] + + nq = lims_ref.size - 1 + assert lims_new.size - 1 == nq + + nt = len(thresholds) + counts = np.zeros((nq, nt, 3), dtype="int64") + + def compute_PR_for(q): + gt_ids = ref_result_for(q) + res_ids, res_dis = new_result_for(q) + + counts[q, :, 0] = len(gt_ids) + + if res_dis.size == 0: + # the rest remains at 0 + return + + # which offsets we are interested in + nres= np.searchsorted(res_dis, thresholds) + counts[q, :, 1] = nres + + if gt_ids.size == 0: + return + + # find number of TPs at each stage in the result list + ii = np.searchsorted(gt_ids, res_ids) + ii[ii == len(gt_ids)] = -1 + n_ok = np.cumsum(gt_ids[ii] == res_ids) + + # focus on threshold points + n_ok = np.hstack(([0], n_ok)) + counts[q, :, 2] = n_ok[nres] + + pool = ThreadPool(20) + pool.map(compute_PR_for, range(nq)) + # print(counts.transpose(2, 1, 0)) + + precisions = np.zeros(nt) + recalls = np.zeros(nt) + for t in range(nt): + p, r = counts_to_PR( + counts[:, t, 0], counts[:, t, 1], counts[:, t, 2], + mode=mode + ) + precisions[t] = p + recalls[t] = r + + return precisions, recalls + + + + +############################################################### +# Functions that compare search results with a reference result. +# They are intended for use in tests + +def test_ref_knn_with_draws(Dref, Iref, Dnew, Inew): + """ test that knn search results are identical, raise if not """ + np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5) + # here we have to be careful because of draws + testcase = unittest.TestCase() # because it makes nice error messages + for i in range(len(Iref)): + if np.all(Iref[i] == Inew[i]): # easy case + continue + # we can deduce nothing about the latest line + skip_dis = Dref[i, -1] + for dis in np.unique(Dref): + if dis == skip_dis: + continue + mask = Dref[i, :] == dis + testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask])) + + +def test_ref_range_results(lims_ref, Dref, Iref, + lims_new, Dnew, Inew): + """ compare range search results wrt. a reference result, + throw if it fails """ + np.testing.assert_array_equal(lims_ref, lims_new) + nq = len(lims_ref) - 1 + for i in range(nq): + l0, l1 = lims_ref[i], lims_ref[i + 1] + Ii_ref = Iref[l0:l1] + Ii_new = Inew[l0:l1] + Di_ref = Dref[l0:l1] + Di_new = Dnew[l0:l1] + if np.all(Ii_ref == Ii_new): # easy + pass + else: + def sort_by_ids(I, D): + o = I.argsort() + return I[o], D[o] + # sort both + (Ii_ref, Di_ref) = sort_by_ids(Ii_ref, Di_ref) + (Ii_new, Di_new) = sort_by_ids(Ii_new, Di_new) + np.testing.assert_array_equal(Ii_ref, Ii_new) + np.testing.assert_array_almost_equal(Di_ref, Di_new, decimal=5) diff --git a/contrib/exhaustive_search.py b/contrib/exhaustive_search.py index 32517b539..638c4f4d2 100644 --- a/contrib/exhaustive_search.py +++ b/contrib/exhaustive_search.py @@ -11,7 +11,7 @@ import logging LOG = logging.getLogger(__name__) -def knn_ground_truth(xq, db_iterator, k): +def knn_ground_truth(xq, db_iterator, k, metric_type=faiss.METRIC_L2): """Computes the exact KNN search results for a dataset that possibly does not fit in RAM but for which we have an iterator that returns it block by block. @@ -21,12 +21,12 @@ def knn_ground_truth(xq, db_iterator, k): nq, d = xq.shape rh = faiss.ResultHeap(nq, k) - index = faiss.IndexFlatL2(d) + index = faiss.IndexFlat(d, metric_type) if faiss.get_num_gpus(): LOG.info('running on %d GPUs' % faiss.get_num_gpus()) index = faiss.index_cpu_to_all_gpus(index) - # compute ground-truth by blocks of bs, and add to heaps + # compute ground-truth by blocks, and add to heaps i0 = 0 for xbi in db_iterator: ni = xbi.shape[0] @@ -44,4 +44,184 @@ def knn_ground_truth(xq, db_iterator, k): return rh.D, rh.I # knn function used to be here -knn = faiss.knn \ No newline at end of file +knn = faiss.knn + + + + +def range_search_gpu(xq, r2, index_gpu, xb): + """ GPU does not support range search, so we emulate it with + knn search + fallback to CPU index """ + nq, d = xq.shape + LOG.debug("GPU search %d queries" % nq) + k = min(index_gpu.ntotal, 1024) + D, I = index_gpu.search(xq, k) + if index_gpu.metric_type == faiss.METRIC_L2: + mask = D[:, k - 1] < r2 + else: + mask = D[:, k - 1] > r2 + if mask.sum() > 0: + LOG.debug("CPU search remain %d" % mask.sum()) + index_cpu = faiss.IndexFlat(d, index_gpu.metric_type) + index_cpu.add(xb) + lim_remain, D_remain, I_remain = index_cpu.range_search(xq[mask], r2) + LOG.debug("combine") + D_res, I_res = [], [] + nr = 0 + for i in range(nq): + if not mask[i]: + if index_gpu.metric_type == faiss.METRIC_L2: + nv = (D[i, :] < r2).sum() + else: + nv = (D[i, :] > r2).sum() + D_res.append(D[i, :nv]) + I_res.append(I[i, :nv]) + else: + l0, l1 = lim_remain[nr], lim_remain[nr + 1] + D_res.append(D_remain[l0:l1]) + I_res.append(I_remain[l0:l1]) + nr += 1 + lims = np.cumsum([0] + [len(di) for di in D_res]) + return lims, np.hstack(D_res), np.hstack(I_res) + + +def range_ground_truth(xq, db_iterator, threshold, metric_type=faiss.METRIC_L2, + shard=False, ngpu=-1): + """Computes the range-search search results for a dataset that possibly + does not fit in RAM but for which we have an iterator that + returns it block by block. + """ + nq, d = xq.shape + t0 = time.time() + xq = np.ascontiguousarray(xq, dtype='float32') + + index = faiss.IndexFlat(d, metric_type) + if ngpu == -1: + ngpu = faiss.get_num_gpus() + if ngpu: + LOG.info('running on %d GPUs' % faiss.get_num_gpus()) + co = faiss.GpuMultipleClonerOptions() + co.shard = shard + index_gpu = faiss.index_cpu_to_all_gpus(index, co=co, ngpu=ngpu) + + # compute ground-truth by blocks + i0 = 0 + D = [[] for _i in range(nq)] + I = [[] for _i in range(nq)] + all_lims = [] + for xbi in db_iterator: + ni = xbi.shape[0] + if ngpu > 0: + index_gpu.add(xbi) + lims_i, Di, Ii = range_search_gpu(xq, threshold, index_gpu, xbi) + index_gpu.reset() + else: + index.add(xbi) + lims_i, Di, Ii = index.range_search(xq, threshold) + index.reset() + Ii += i0 + for j in range(nq): + l0, l1 = lims_i[j], lims_i[j + 1] + if l1 > l0: + D[j].append(Di[l0:l1]) + I[j].append(Ii[l0:l1]) + i0 += ni + LOG.info("%d db elements, %.3f s" % (i0, time.time() - t0)) + + empty_I = np.zeros(0, dtype='int64') + empty_D = np.zeros(0, dtype='float32') + # import pdb; pdb.set_trace() + D = [(np.hstack(i) if i != [] else empty_D) for i in D] + I = [(np.hstack(i) if i != [] else empty_I) for i in I] + sizes = [len(i) for i in I] + assert len(sizes) == nq + lims = np.zeros(nq + 1, dtype="uint64") + lims[1:] = np.cumsum(sizes) + return lims, np.hstack(D), np.hstack(I) + + +def threshold_radius(nres, dis, ids, thresh): + """ select a set of results """ + mask = dis < thresh + new_nres = np.zeros_like(nres) + o = 0 + for i, nr in enumerate(nres): + nr = int(nr) # avoid issues with int64 + uint64 + new_nres[i] = mask[o : o + nr].sum() + o += nr + return new_nres, dis[mask], ids[mask] + + +def apply_maxres(res_batches, target_nres): + """find radius that reduces number of results to target_nres, and + applies it in-place to the result batches""" + alldis = np.hstack([dis for _, dis, _ in res_batches]) + alldis.partition(target_nres) + radius = alldis[target_nres] + + if alldis.dtype == 'float32': + radius = float(radius) + else: + radius = int(radius) + print(' setting radius to %s' % radius) + totres = 0 + for i, (nres, dis, ids) in enumerate(res_batches): + nres, dis, ids = threshold_radius(nres, dis, ids, radius) + totres += len(dis) + res_batches[i] = nres, dis, ids + print(' updated previous results, new nb results %d' % totres) + return radius, totres + + +def range_search_max_results(index, query_iterator, radius, + max_results=None, min_results=None): + """ Performs a range search with many queries (given by an iterator) + and adjusts the threshold on-the-fly so that the total results + table does not grow larger than max_results """ + + if max_results is not None: + if min_results is None: + min_results = int(0.8 * max_results) + + t_start = time.time() + t_search = t_post_process = 0 + qtot = totres = raw_totres = 0 + res_batches = [] + + for xqi in query_iterator: + t0 = time.time() + lims_i, Di, Ii = index.range_search(xqi, radius) + nres_i = lims_i[1:] - lims_i[:-1] + raw_totres += len(Di) + qtot += len(xqi) + + t1 = time.time() + if xqi.dtype != np.float32: + # for binary indexes + # weird Faiss quirk that returns floats for Hamming distances + Di = Di.astype('int16') + + totres += len(Di) + res_batches.append((nres_i, Di, Ii)) + + if max_results is not None and totres > max_results: + LOG.info('too many results %d > %d, scaling back radius' % + (totres, max_results)) + radius, totres = apply_maxres(res_batches, min_results) + t2 = time.time() + t_search += t1 - t0 + t_post_process += t2 - t1 + LOG.debug(' [%.3f s] %d queries done, %d results' % ( + time.time() - t_start, qtot, totres)) + + LOG.info(' search done in %.3f s + %.3f s, total %d results, end threshold %g' % ( + t_search, t_post_process, totres, radius)) + + nres = np.hstack([nres_i for nres_i, dis_i, ids_i in res_batches]) + dis = np.hstack([dis_i for nres_i, dis_i, ids_i in res_batches]) + ids = np.hstack([ids_i for nres_i, dis_i, ids_i in res_batches]) + + lims = np.zeros(len(nres) + 1, dtype='uint64') + lims[1:] = np.cumsum(nres) + + return radius, lims, dis, ids diff --git a/faiss/gpu/test/test_contrib.py b/faiss/gpu/test/test_contrib.py index 9194a6b2e..a4f2e0cd4 100644 --- a/faiss/gpu/test/test_contrib.py +++ b/faiss/gpu/test/test_contrib.py @@ -8,7 +8,9 @@ import unittest import numpy as np from faiss.contrib import datasets -from faiss.contrib.exhaustive_search import knn_ground_truth +from faiss.contrib.exhaustive_search import knn_ground_truth, range_ground_truth +from faiss.contrib import evaluation + from common import get_dataset_2 @@ -33,3 +35,29 @@ class TestComputeGT(unittest.TestCase): np.testing.assert_array_equal(Iref, Inew) np.testing.assert_almost_equal(Dref, Dnew, decimal=4) + + def do_test_range(self, metric): + ds = datasets.SyntheticDataset(32, 0, 1000, 10) + xq = ds.get_queries() + xb = ds.get_database() + D, I = faiss.knn(xq, xb, 10, distance_type=metric) + threshold = float(D[:, -1].mean()) + + index = faiss.IndexFlat(32, metric) + index.add(xb) + ref_lims, ref_D, ref_I = index.range_search(xq, threshold) + + new_lims, new_D, new_I = range_ground_truth( + xq, ds.database_iterator(bs=100), threshold, + metric_type=metric) + + evaluation.test_ref_range_results( + ref_lims, ref_D, ref_I, + new_lims, new_D, new_I + ) + + def test_range_L2(self): + self.do_test_range(faiss.METRIC_L2) + + def test_range_IP(self): + self.do_test_range(faiss.METRIC_INNER_PRODUCT) diff --git a/tests/test_binary_hashindex.py b/tests/test_binary_hashindex.py old mode 100755 new mode 100644 diff --git a/tests/test_contrib.py b/tests/test_contrib.py index 917d80ef2..859908138 100644 --- a/tests/test_contrib.py +++ b/tests/test_contrib.py @@ -10,14 +10,18 @@ import platform from faiss.contrib import datasets from faiss.contrib import inspect_tools +from faiss.contrib import evaluation from common import get_dataset_2 try: - from faiss.contrib.exhaustive_search import knn_ground_truth, knn + from faiss.contrib.exhaustive_search import knn_ground_truth, knn, range_ground_truth + from faiss.contrib.exhaustive_search import range_search_max_results except: pass # Submodule import broken in python 2. + + @unittest.skipIf(platform.python_version_tuple()[0] < '3', \ 'Submodule import broken in python 2.') class TestComputeGT(unittest.TestCase): @@ -80,11 +84,9 @@ class TestDatasets(unittest.TestCase): class TestExhaustiveSearch(unittest.TestCase): def test_knn_cpu(self): - xb = np.random.rand(200, 32).astype('float32') xq = np.random.rand(100, 32).astype('float32') - index = faiss.IndexFlatL2(32) index.add(xb) Dref, Iref = index.search(xq, 10) @@ -104,6 +106,72 @@ class TestExhaustiveSearch(unittest.TestCase): assert np.all(Inew == Iref) assert np.allclose(Dref, Dnew) + def do_test_range(self, metric): + ds = datasets.SyntheticDataset(32, 0, 1000, 10) + xq = ds.get_queries() + xb = ds.get_database() + D, I = faiss.knn(xq, xb, 10, distance_type=metric) + threshold = float(D[:, -1].mean()) + + index = faiss.IndexFlat(32, metric) + index.add(xb) + ref_lims, ref_D, ref_I = index.range_search(xq, threshold) + + new_lims, new_D, new_I = range_ground_truth( + xq, ds.database_iterator(bs=100), threshold, ngpu=0, + metric_type=metric) + + evaluation.test_ref_range_results( + ref_lims, ref_D, ref_I, + new_lims, new_D, new_I + ) + + def test_range_L2(self): + self.do_test_range(faiss.METRIC_L2) + + def test_range_IP(self): + self.do_test_range(faiss.METRIC_INNER_PRODUCT) + + def test_query_iterator(self, metric=faiss.METRIC_L2): + ds = datasets.SyntheticDataset(32, 0, 1000, 1000) + xq = ds.get_queries() + xb = ds.get_database() + D, I = faiss.knn(xq, xb, 10, distance_type=metric) + threshold = float(D[:, -1].mean()) + print(threshold) + + index = faiss.IndexFlat(32, metric) + index.add(xb) + ref_lims, ref_D, ref_I = index.range_search(xq, threshold) + + def matrix_iterator(xb, bs): + for i0 in range(0, xb.shape[0], bs): + yield xb[i0:i0 + bs] + + # check repro OK + _, new_lims, new_D, new_I = range_search_max_results( + index, matrix_iterator(xq, 100), threshold) + + evaluation.test_ref_range_results( + ref_lims, ref_D, ref_I, + new_lims, new_D, new_I + ) + + max_res = ref_lims[-1] // 2 + + new_threshold, new_lims, new_D, new_I = range_search_max_results( + index, matrix_iterator(xq, 100), threshold, max_results=max_res) + + self.assertLessEqual(new_lims[-1], max_res) + + ref_lims, ref_D, ref_I = index.range_search(xq, new_threshold) + + evaluation.test_ref_range_results( + ref_lims, ref_D, ref_I, + new_lims, new_D, new_I + ) + + class TestInspect(unittest.TestCase): @@ -123,3 +191,77 @@ class TestInspect(unittest.TestCase): # verify ynew = x @ A.T + b np.testing.assert_array_almost_equal(yref, ynew) + + +class TestRangeEval(unittest.TestCase): + + def test_precision_recall(self): + Iref = [ + [1, 2, 3], + [5, 6], + [], + [] + ] + Inew = [ + [1, 2], + [6, 7], + [1], + [] + ] + + lims_ref = np.cumsum([0] + [len(x) for x in Iref]) + Iref = np.hstack(Iref) + lims_new = np.cumsum([0] + [len(x) for x in Inew]) + Inew = np.hstack(Inew) + + precision, recall = evaluation.range_PR(lims_ref, Iref, lims_new, Inew) + print(precision, recall) + + self.assertEqual(precision, 0.6) + self.assertEqual(recall, 0.6) + + def test_PR_multiple(self): + metric = faiss.METRIC_L2 + ds = datasets.SyntheticDataset(32, 1000, 1000, 10) + xq = ds.get_queries() + xb = ds.get_database() + + # good for ~10k results + threshold = 15 + + index = faiss.IndexFlat(32, metric) + index.add(xb) + ref_lims, ref_D, ref_I = index.range_search(xq, threshold) + + # now make a slightly suboptimal index + index2 = faiss.index_factory(32, "PCA16,Flat") + index2.train(ds.get_train()) + index2.add(xb) + + # PCA reduces distances so will have more results + new_lims, new_D, new_I = index2.range_search(xq, threshold) + + all_thr = np.array([5.0, 10.0, 12.0, 15.0]) + for mode in "overall", "average": + ref_precisions = np.zeros_like(all_thr) + ref_recalls = np.zeros_like(all_thr) + + for i, thr in enumerate(all_thr): + + lims2, _, I2 = evaluation.filter_range_results( + new_lims, new_D, new_I, thr) + + prec, recall = evaluation.range_PR( + ref_lims, ref_I, lims2, I2, mode=mode) + + ref_precisions[i] = prec + ref_recalls[i] = recall + + precisions, recalls = evaluation.range_PR_multiple_thresholds( + ref_lims, ref_I, + new_lims, new_D, new_I, all_thr, + mode=mode + ) + + np.testing.assert_array_almost_equal(ref_precisions, precisions) + np.testing.assert_array_almost_equal(ref_recalls, recalls) diff --git a/tests/test_extra_distances.py b/tests/test_extra_distances.py old mode 100755 new mode 100644 diff --git a/tests/test_io.py b/tests/test_io.py old mode 100755 new mode 100644 diff --git a/tests/test_oom_exception.py b/tests/test_oom_exception.py old mode 100755 new mode 100644 diff --git a/tests/test_standalone_codec.py b/tests/test_standalone_codec.py old mode 100755 new mode 100644