diff --git a/contrib/evaluation.py b/contrib/evaluation.py index d69dfaf88..51a4a7499 100644 --- a/contrib/evaluation.py +++ b/contrib/evaluation.py @@ -220,8 +220,6 @@ def range_PR_multiple_thresholds( return precisions, recalls - - ############################################################### # Functions that compare search results with a reference result. # They are intended for use in tests diff --git a/contrib/exhaustive_search.py b/contrib/exhaustive_search.py index e12ae827a..4f3bd1a89 100644 --- a/contrib/exhaustive_search.py +++ b/contrib/exhaustive_search.py @@ -11,6 +11,7 @@ import logging LOG = logging.getLogger(__name__) + def knn_ground_truth(xq, db_iterator, k, metric_type=faiss.METRIC_L2): """Computes the exact KNN search results for a dataset that possibly does not fit in RAM but for which we have an iterator that @@ -146,21 +147,27 @@ def range_ground_truth(xq, db_iterator, threshold, metric_type=faiss.METRIC_L2, return lims, np.hstack(D), np.hstack(I) -def threshold_radius_nres(nres, dis, ids, thresh): +def threshold_radius_nres(nres, dis, ids, thresh, keep_max=False): """ select a set of results """ - mask = dis < thresh + if keep_max: + mask = dis > thresh + else: + mask = dis < thresh new_nres = np.zeros_like(nres) o = 0 for i, nr in enumerate(nres): nr = int(nr) # avoid issues with int64 + uint64 - new_nres[i] = mask[o : o + nr].sum() + new_nres[i] = mask[o:o + nr].sum() o += nr return new_nres, dis[mask], ids[mask] -def threshold_radius(lims, dis, ids, thresh): +def threshold_radius(lims, dis, ids, thresh, keep_max=False): """ restrict range-search results to those below a given radius """ - mask = dis < thresh + if keep_max: + mask = dis > thresh + else: + mask = dis < thresh new_lims = np.zeros_like(lims) n = len(lims) - 1 for i in range(n): @@ -169,12 +176,18 @@ def threshold_radius(lims, dis, ids, thresh): return new_lims, dis[mask], ids[mask] -def apply_maxres(res_batches, target_nres): +def apply_maxres(res_batches, target_nres, keep_max=False): """find radius that reduces number of results to target_nres, and - applies it in-place to the result batches used in range_search_max_results""" + applies it in-place to the result batches used in + range_search_max_results""" alldis = np.hstack([dis for _, dis, _ in res_batches]) - alldis.partition(target_nres) - radius = alldis[target_nres] + assert len(alldis) > target_nres + if keep_max: + alldis.partition(len(alldis) - target_nres - 1) + radius = alldis[-1 - target_nres] + else: + alldis.partition(target_nres) + radius = alldis[target_nres] if alldis.dtype == 'float32': radius = float(radius) @@ -183,7 +196,8 @@ def apply_maxres(res_batches, target_nres): LOG.debug(' setting radius to %s' % radius) totres = 0 for i, (nres, dis, ids) in enumerate(res_batches): - nres, dis, ids = threshold_radius_nres(nres, dis, ids, radius) + nres, dis, ids = threshold_radius_nres( + nres, dis, ids, radius, keep_max=keep_max) totres += len(dis) res_batches[i] = nres, dis, ids LOG.debug(' updated previous results, new nb results %d' % totres) @@ -192,7 +206,7 @@ def apply_maxres(res_batches, target_nres): def range_search_max_results(index, query_iterator, radius, max_results=None, min_results=None, - shard=False, ngpu=0): + shard=False, ngpu=0, clip_to_min=False): """Performs a range search with many queries (given by an iterator) and adjusts the threshold on-the-fly so that the total results table does not grow larger than max_results. 
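A quick illustration of what the new `keep_max` flag changes (not part of the patch; the arrays below are made up). For L2 search, smaller distances are better, so results are kept *below* the threshold; for inner-product search, larger similarities are better, so `keep_max=True` keeps results *above* it:

```python
import numpy as np
from faiss.contrib.exhaustive_search import threshold_radius

# 2 queries with 3 and 2 range-search results each (made-up values)
lims = np.array([0, 3, 5], dtype='int64')
dis = np.array([0.1, 0.5, 0.9, 0.2, 0.8], dtype='float32')
ids = np.array([10, 11, 12, 20, 21], dtype='int64')

# L2 semantics (default): keep distances < 0.6
new_lims, new_dis, new_ids = threshold_radius(lims, dis, ids, 0.6)
# new_lims = [0, 2, 3], new_dis = [0.1, 0.5, 0.2]

# inner-product semantics: keep similarities > 0.6
new_lims, new_dis, new_ids = threshold_radius(lims, dis, ids, 0.6, keep_max=True)
# new_lims = [0, 1, 2], new_dis = [0.9, 0.8]
```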
@@ -200,10 +214,16 @@ def range_search_max_results(index, query_iterator, radius,
     If ngpu != 0, the function moves the index to this many GPUs to
     speed up search.
     """
+    # TODO: all result manipulations are done in Python; they should move to
+    # C++ if performance becomes critical
 
-    if max_results is not None:
-        if min_results is None:
-            min_results = int(0.8 * max_results)
+    if min_results is None:
+        assert max_results is not None
+        min_results = int(0.8 * max_results)
+
+    if max_results is None:
+        assert min_results is not None
+        max_results = int(min_results * 1.5)
 
     if ngpu == -1:
         ngpu = faiss.get_num_gpus()
@@ -242,15 +262,26 @@ def range_search_max_results(index, query_iterator, radius,
         if max_results is not None and totres > max_results:
             LOG.info('too many results %d > %d, scaling back radius' %
                      (totres, max_results))
-            radius, totres = apply_maxres(res_batches, min_results)
+            radius, totres = apply_maxres(
+                res_batches, min_results,
+                keep_max=index.metric_type == faiss.METRIC_INNER_PRODUCT
+            )
         t2 = time.time()
         t_search += t1 - t0
         t_post_process += t2 - t1
         LOG.debug('   [%.3f s] %d queries done, %d results' % (
             time.time() - t_start, qtot, totres))
 
-    LOG.info('   search done in %.3f s + %.3f s, total %d results, end threshold %g' % (
-        t_search, t_post_process, totres, radius))
+    LOG.info(
+        'search done in %.3f s + %.3f s, total %d results, end threshold %g' % (
+            t_search, t_post_process, totres, radius)
+    )
+
+    if clip_to_min and totres > min_results:
+        radius, totres = apply_maxres(
+            res_batches, min_results,
+            keep_max=index.metric_type == faiss.METRIC_INNER_PRODUCT
+        )
 
     nres = np.hstack([nres_i for nres_i, dis_i, ids_i in res_batches])
     dis = np.hstack([dis_i for nres_i, dis_i, ids_i in res_batches])
@@ -260,3 +291,18 @@ def range_search_max_results(index, query_iterator, radius,
     lims[1:] = np.cumsum(nres)
 
     return radius, lims, dis, ids
+
+
+def exponential_query_iterator(xq, start_bs=32, max_bs=20000):
+    """ produces batches of progressively increasing size. This is useful to
+    adjust the search radius progressively without accumulating too many
+    intermediate results """
+    nq = len(xq)
+    bs = start_bs
+    i = 0
+    while i < nq:
+        xqi = xq[i:i + bs]
+        yield xqi
+        if bs < max_bs:
+            bs *= 2
+        i += len(xqi)
diff --git a/faiss/python/__init__.py b/faiss/python/__init__.py
index 2b6f19f8f..7f748dd0e 100644
--- a/faiss/python/__init__.py
+++ b/faiss/python/__init__.py
@@ -1600,12 +1600,15 @@ class ResultHeap:
     """Accumulate query results from a sliced dataset.
The final result will be in self.D, self.I.""" - def __init__(self, nq, k): + def __init__(self, nq, k, keep_max=False): " nq: number of query vectors, k: number of results per query " self.I = np.zeros((nq, k), dtype='int64') self.D = np.zeros((nq, k), dtype='float32') self.nq, self.k = nq, k - heaps = float_maxheap_array_t() + if keep_max: + heaps = float_minheap_array_t() + else: + heaps = float_maxheap_array_t() heaps.k = k heaps.nh = nq heaps.val = swig_ptr(self.D) @@ -1615,11 +1618,12 @@ class ResultHeap: def add_result(self, D, I): """D, I do not need to be in a particular order (heap or sorted)""" - assert D.shape == (self.nq, self.k) - assert I.shape == (self.nq, self.k) + nq, kd = D.shape + assert I.shape == (nq, kd) + assert nq == self.nq self.heaps.addn_with_ids( - self.k, swig_ptr(D), - swig_ptr(I), self.k) + kd, swig_ptr(D), + swig_ptr(I), kd) def finalize(self): self.heaps.reorder() diff --git a/faiss/python/swigfaiss.swig b/faiss/python/swigfaiss.swig index f2cb2d160..698b166ba 100644 --- a/faiss/python/swigfaiss.swig +++ b/faiss/python/swigfaiss.swig @@ -614,10 +614,12 @@ void gpu_sync_all_devices() DOWNCAST (RemapDimensionsTransform) DOWNCAST (OPQMatrix) DOWNCAST (PCAMatrix) + DOWNCAST (ITQMatrix) DOWNCAST (RandomRotationMatrix) DOWNCAST (LinearTransform) DOWNCAST (NormalizationTransform) DOWNCAST (CenteringTransform) + DOWNCAST (ITQTransform) DOWNCAST (VectorTransform) { assert(false); diff --git a/tests/test_build_blocks.py b/tests/test_build_blocks.py index 585c8b212..0a546aeb8 100644 --- a/tests/test_build_blocks.py +++ b/tests/test_build_blocks.py @@ -470,5 +470,33 @@ class TestNNDescentKNNG(unittest.TestCase): assert recall > 0.99 -if __name__ == '__main__': - unittest.main() +class TestResultHeap(unittest.TestCase): + + def test_keep_min(self): + self.run_test(False) + + def test_keep_max(self): + self.run_test(True) + + def run_test(self, keep_max): + nq = 100 + nb = 1000 + restab = faiss.rand((nq, nb), 123) + ids = faiss.randint((nq, nb), 1324, 10000) + all_rh = {} + for nstep in 1, 3: + rh = faiss.ResultHeap(nq, 10, keep_max=keep_max) + for i in range(nstep): + i0, i1 = i * nb // nstep, (i + 1) * nb // nstep + D = restab[:, i0:i1].copy() + I = ids[:, i0:i1].copy() + rh.add_result(D, I) + rh.finalize() + if keep_max: + assert np.all(rh.D[:, :-1] >= rh.D[:, 1:]) + else: + assert np.all(rh.D[:, :-1] <= rh.D[:, 1:]) + all_rh[nstep] = rh + + np.testing.assert_equal(all_rh[1].D, all_rh[3].D) + np.testing.assert_equal(all_rh[1].I, all_rh[3].I) diff --git a/tests/test_contrib.py b/tests/test_contrib.py index 71e287903..109e82442 100644 --- a/tests/test_contrib.py +++ b/tests/test_contrib.py @@ -15,8 +15,9 @@ from faiss.contrib import ivf_tools from common_faiss_tests import get_dataset_2 try: - from faiss.contrib.exhaustive_search import knn_ground_truth, knn, range_ground_truth - from faiss.contrib.exhaustive_search import range_search_max_results + from faiss.contrib.exhaustive_search import \ + knn_ground_truth, knn, range_ground_truth, \ + range_search_max_results, exponential_query_iterator except: pass # Submodule import broken in python 2. 
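For reference, a sketch of how the extended `ResultHeap` can merge k-NN results computed on database slices (sizes and data below are made up; `keep_max=True` matches inner-product search, where higher scores are better):

```python
import numpy as np
import faiss

d, k, nq, nb = 32, 10, 50, 1000
rs = np.random.RandomState(123)
xq = rs.rand(nq, d).astype('float32')
xb = rs.rand(nb, d).astype('float32')

rh = faiss.ResultHeap(nq, k, keep_max=True)
for i0 in range(0, nb, 250):
    # search one 250-vector slice of the database
    index = faiss.IndexFlatIP(d)
    index.add(xb[i0:i0 + 250])
    D, I = index.search(xq, k)
    rh.add_result(D, I + i0)   # shift ids to global numbering
rh.finalize()
# rh.D is now sorted best-first (descending) per query, rh.I holds the ids
```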
@@ -151,7 +152,7 @@ class TestExhaustiveSearch(unittest.TestCase):
 
         # check repro OK
         _, new_lims, new_D, new_I = range_search_max_results(
-            index, matrix_iterator(xq, 100), threshold)
+            index, matrix_iterator(xq, 100), threshold, max_results=1e10)
 
         evaluation.test_ref_range_results(
             ref_lims, ref_D, ref_I,
@@ -381,3 +382,38 @@ class TestPreassigned(unittest.TestCase):
         for q in range(len(xq)):
             l0, l1 = lims[q], lims[q + 1]
             self.assertTrue(set(I[q]) <= set(IR[l0:l1]))
+
+
+class TestRangeSearchMaxResults(unittest.TestCase):
+
+    def do_test(self, metric_type):
+        ds = datasets.SyntheticDataset(32, 0, 1000, 200)
+        index = faiss.IndexFlat(ds.d, metric_type)
+        index.add(ds.get_database())
+
+        # find a reasonable radius
+        D, _ = index.search(ds.get_queries(), 10)
+        radius0 = float(np.median(D[:, -1]))
+
+        # baseline = search with that radius
+        lims_ref, Dref, Iref = index.range_search(ds.get_queries(), radius0)
+
+        # now check that, using just the total number of results, we can
+        # reconstruct the same result table
+        query_iterator = exponential_query_iterator(ds.get_queries())
+
+        init_radius = 1e10 if metric_type == faiss.METRIC_L2 else -1e10
+        radius1, lims_new, Dnew, Inew = range_search_max_results(
+            index, query_iterator, init_radius, min_results=Dref.size, clip_to_min=True
+        )
+
+        evaluation.test_ref_range_results(
+            lims_ref, Dref, Iref,
+            lims_new, Dnew, Inew
+        )
+
+    def test_L2(self):
+        self.do_test(faiss.METRIC_L2)
+
+    def test_IP(self):
+        self.do_test(faiss.METRIC_INNER_PRODUCT)
diff --git a/tests/test_factory.py b/tests/test_factory.py
index 2a22ee4cc..8f7b918a3 100644
--- a/tests/test_factory.py
+++ b/tests/test_factory.py
@@ -168,6 +168,7 @@ class TestCloneSize(unittest.TestCase):
         index2 = faiss.clone_index(index)
         assert index2.ntotal == 100
 
+
 class TestCloneIVFPQ(unittest.TestCase):
 
     def test_clone(self):
@@ -177,3 +178,12 @@ class TestCloneIVFPQ(unittest.TestCase):
         index.add(xb)
         index2 = faiss.clone_index(index)
         assert index2.ntotal == index.ntotal
+
+
+class TestVTDowncast(unittest.TestCase):
+
+    def test_itq_transform(self):
+        codec = faiss.index_factory(16, "ITQ8,LSHt")
+
+        itqt = faiss.downcast_VectorTransform(codec.chain.at(0))
+        itqt.pca_then_itq  # attribute exists only after a successful downcast
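To summarize the calling convention exercised by `TestRangeSearchMaxResults`, here is a usage sketch with made-up data. When only `min_results` is given, `max_results` defaults to 1.5x that value, and `clip_to_min=True` trims the final result table back down to roughly `min_results`:

```python
import numpy as np
import faiss
from faiss.contrib.exhaustive_search import (
    range_search_max_results, exponential_query_iterator)

d = 32
rs = np.random.RandomState(123)
xb = rs.rand(10000, d).astype('float32')
xq = rs.rand(100, d).astype('float32')
index = faiss.IndexFlatIP(d)
index.add(xb)

# start from a radius that filters nothing: for inner product that is a very
# negative value (for METRIC_L2 it would be a very large one, e.g. 1e10)
radius, lims, D, I = range_search_max_results(
    index, exponential_query_iterator(xq), -1e10,
    min_results=10000, clip_to_min=True)
# results of query q: I[lims[q]:lims[q+1]], similarities D[lims[q]:lims[q+1]]
```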