three small fixes (#1972)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1972

This fixes a few issues that I ran into and adds tests:

- range_search_max_results with IP search
- a few missing downcasts for VectorTransforms
- ResultHeap supports max IP search

Reviewed By: wickedfoo

Differential Revision: D29525093

fbshipit-source-id: d4ff0aff1d83af9717ff1aaa2fe3cda7b53019a3

parent 7cce100c92
commit 1829aa92a1
@@ -220,8 +220,6 @@ def range_PR_multiple_thresholds(
     return precisions, recalls
 
 
-
-
 ###############################################################
 # Functions that compare search results with a reference result.
 # They are intended for use in tests
@@ -11,6 +11,7 @@ import logging
 
 LOG = logging.getLogger(__name__)
 
+
 def knn_ground_truth(xq, db_iterator, k, metric_type=faiss.METRIC_L2):
     """Computes the exact KNN search results for a dataset that possibly
     does not fit in RAM but for which we have an iterator that
@@ -146,21 +147,27 @@ def range_ground_truth(xq, db_iterator, threshold, metric_type=faiss.METRIC_L2,
     return lims, np.hstack(D), np.hstack(I)
 
 
-def threshold_radius_nres(nres, dis, ids, thresh):
+def threshold_radius_nres(nres, dis, ids, thresh, keep_max=False):
     """ select a set of results """
-    mask = dis < thresh
+    if keep_max:
+        mask = dis > thresh
+    else:
+        mask = dis < thresh
     new_nres = np.zeros_like(nres)
     o = 0
     for i, nr in enumerate(nres):
         nr = int(nr)   # avoid issues with int64 + uint64
-        new_nres[i] = mask[o : o + nr].sum()
+        new_nres[i] = mask[o:o + nr].sum()
         o += nr
     return new_nres, dis[mask], ids[mask]
 
 
-def threshold_radius(lims, dis, ids, thresh):
+def threshold_radius(lims, dis, ids, thresh, keep_max=False):
     """ restrict range-search results to those below a given radius """
-    mask = dis < thresh
+    if keep_max:
+        mask = dis > thresh
+    else:
+        mask = dis < thresh
     new_lims = np.zeros_like(lims)
     n = len(lims) - 1
     for i in range(n):
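The keep_max flag added above simply flips the comparison: with METRIC_L2, smaller distances are better; with inner product, larger scores are better. A minimal sketch of the masking on made-up toy data (the arrays below are hypothetical, not from the patch):

```python
import numpy as np

# a flat result table for 2 queries: query 0 owns 3 results, query 1 owns 2
dis = np.array([0.1, 0.9, 0.4, 0.2, 0.8], dtype='float32')
ids = np.array([10, 11, 12, 20, 21], dtype='int64')

thresh = 0.5
mask_l2 = dis < thresh  # default: keep results below the radius (L2)
mask_ip = dis > thresh  # keep_max=True: keep results above it (IP)

print(ids[mask_l2])  # [10 12 20]
print(ids[mask_ip])  # [11 21]
```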
@@ -169,12 +176,18 @@ def threshold_radius(lims, dis, ids, thresh):
     return new_lims, dis[mask], ids[mask]
 
 
-def apply_maxres(res_batches, target_nres):
+def apply_maxres(res_batches, target_nres, keep_max=False):
     """find radius that reduces number of results to target_nres, and
-    applies it in-place to the result batches used in range_search_max_results"""
+    applies it in-place to the result batches used in
+    range_search_max_results"""
     alldis = np.hstack([dis for _, dis, _ in res_batches])
-    alldis.partition(target_nres)
-    radius = alldis[target_nres]
+    assert len(alldis) > target_nres
+    if keep_max:
+        alldis.partition(len(alldis) - target_nres - 1)
+        radius = alldis[-1 - target_nres]
+    else:
+        alldis.partition(target_nres)
+        radius = alldis[target_nres]
 
     if alldis.dtype == 'float32':
         radius = float(radius)
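apply_maxres picks the new radius with np.partition rather than a full sort, and the new keep_max branch selects from the top of the distribution instead of the bottom. A self-contained sketch of that selection on random data (values are illustrative only):

```python
import numpy as np

alldis = np.random.RandomState(123).rand(1000).astype('float32')
target_nres = 100

# keep_max=False: find the radius below which ~target_nres results survive
d = alldis.copy()
d.partition(target_nres)        # places the element of rank target_nres
radius = d[target_nres]
print((alldis < radius).sum())  # 100

# keep_max=True: find the radius above which ~target_nres results survive
d = alldis.copy()
d.partition(len(d) - target_nres - 1)
radius = d[-1 - target_nres]
print((alldis > radius).sum())  # 100
```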
@@ -183,7 +196,8 @@ def apply_maxres(res_batches, target_nres):
     LOG.debug('   setting radius to %s' % radius)
     totres = 0
     for i, (nres, dis, ids) in enumerate(res_batches):
-        nres, dis, ids = threshold_radius_nres(nres, dis, ids, radius)
+        nres, dis, ids = threshold_radius_nres(
+            nres, dis, ids, radius, keep_max=keep_max)
         totres += len(dis)
         res_batches[i] = nres, dis, ids
     LOG.debug('   updated previous results, new nb results %d' % totres)
@@ -192,7 +206,7 @@ def apply_maxres(res_batches, target_nres):
 
 def range_search_max_results(index, query_iterator, radius,
                              max_results=None, min_results=None,
-                             shard=False, ngpu=0):
+                             shard=False, ngpu=0, clip_to_min=False):
     """Performs a range search with many queries (given by an iterator)
     and adjusts the threshold on-the-fly so that the total results
     table does not grow larger than max_results.
@@ -200,10 +214,16 @@ def range_search_max_results(index, query_iterator, radius,
     If ngpu != 0, the function moves the index to this many GPUs to
     speed up search.
     """
-    if min_results is None:
-        assert max_results is not None
-        min_results = int(0.8 * max_results)
+    # TODO: all result manipulations are in python, should move to C++ if perf
+    # critical
+
+    if max_results is not None:
+        if min_results is None:
+            min_results = int(0.8 * max_results)
+
+    if max_results is None:
+        assert min_results is not None
+        max_results = int(min_results * 1.5)
 
     if ngpu == -1:
         ngpu = faiss.get_num_gpus()
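The rewritten defaulting means either bound can now be derived from the other: min_results falls back to 80% of max_results, and max_results to 1.5x min_results. Restated as a standalone helper for clarity (the function name is hypothetical, for illustration only):

```python
def default_result_bounds(max_results=None, min_results=None):
    # mirrors the defaulting logic in range_search_max_results
    if max_results is not None and min_results is None:
        min_results = int(0.8 * max_results)  # prune down to 80% of the cap
    if max_results is None:
        assert min_results is not None        # at least one bound is needed
        max_results = int(min_results * 1.5)  # allow 50% headroom
    return max_results, min_results

print(default_result_bounds(max_results=1000))  # (1000, 800)
print(default_result_bounds(min_results=800))   # (1200, 800)
```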
@@ -242,15 +262,26 @@ def range_search_max_results(index, query_iterator, radius,
         if max_results is not None and totres > max_results:
             LOG.info('too many results %d > %d, scaling back radius' %
                      (totres, max_results))
-            radius, totres = apply_maxres(res_batches, min_results)
+            radius, totres = apply_maxres(
+                res_batches, min_results,
+                keep_max=index.metric_type == faiss.METRIC_INNER_PRODUCT
+            )
         t2 = time.time()
         t_search += t1 - t0
         t_post_process += t2 - t1
         LOG.debug('   [%.3f s] %d queries done, %d results' % (
             time.time() - t_start, qtot, totres))
 
-    LOG.info('   search done in %.3f s + %.3f s, total %d results, end threshold %g' % (
-        t_search, t_post_process, totres, radius))
+    LOG.info(
+        'search done in %.3f s + %.3f s, total %d results, end threshold %g' % (
+            t_search, t_post_process, totres, radius)
+    )
+
+    if clip_to_min and totres > min_results:
+        radius, totres = apply_maxres(
+            res_batches, min_results,
+            keep_max=index.metric_type == faiss.METRIC_INNER_PRODUCT
+        )
 
     nres = np.hstack([nres_i for nres_i, dis_i, ids_i in res_batches])
     dis = np.hstack([dis_i for nres_i, dis_i, ids_i in res_batches])
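Taken together, these changes make the max-results logic usable for inner-product search, where the initial threshold must be very low (every result scores above it) and pruning raises it; keep_max is keyed off index.metric_type inside the function. A rough usage sketch on random data, assuming the contrib module from this patch:

```python
import numpy as np
import faiss
from faiss.contrib.exhaustive_search import (
    exponential_query_iterator, range_search_max_results)

d = 32
rs = np.random.RandomState(0)
xb = rs.rand(10000, d).astype('float32')
xq = rs.rand(100, d).astype('float32')

index = faiss.IndexFlat(d, faiss.METRIC_INNER_PRODUCT)
index.add(xb)

# -1e10 plays the role that +1e10 plays for L2: an unbounded start radius
radius, lims, D, I = range_search_max_results(
    index, exponential_query_iterator(xq), -1e10,
    max_results=10000)
# results for query q are D[lims[q]:lims[q + 1]], all above the final radius
```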
@@ -260,3 +291,18 @@ def range_search_max_results(index, query_iterator, radius,
     lims[1:] = np.cumsum(nres)
 
     return radius, lims, dis, ids
+
+
+def exponential_query_iterator(xq, start_bs=32, max_bs=20000):
+    """ produces batches of progressively increasing sizes. This is useful to
+    adjust the search radius progressively without overflowing with
+    intermediate results """
+    nq = len(xq)
+    bs = start_bs
+    i = 0
+    while i < nq:
+        xqi = xq[i:i + bs]
+        yield xqi
+        if bs < max_bs:
+            bs *= 2
+        i += len(xqi)
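The iterator doubles the batch size up to max_bs, so the radius gets calibrated on cheap small batches before the bulk of the queries is processed. For example, with 1000 queries:

```python
import numpy as np
from faiss.contrib.exhaustive_search import exponential_query_iterator

xq = np.zeros((1000, 32), dtype='float32')
print([len(b) for b in exponential_query_iterator(xq)])
# [32, 64, 128, 256, 512, 8]
```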
@@ -1600,12 +1600,15 @@ class ResultHeap:
     """Accumulate query results from a sliced dataset. The final result will
     be in self.D, self.I."""
 
-    def __init__(self, nq, k):
+    def __init__(self, nq, k, keep_max=False):
         " nq: number of query vectors, k: number of results per query "
         self.I = np.zeros((nq, k), dtype='int64')
         self.D = np.zeros((nq, k), dtype='float32')
         self.nq, self.k = nq, k
-        heaps = float_maxheap_array_t()
+        if keep_max:
+            heaps = float_minheap_array_t()
+        else:
+            heaps = float_maxheap_array_t()
         heaps.k = k
         heaps.nh = nq
         heaps.val = swig_ptr(self.D)
@@ -1615,11 +1618,12 @@
 
     def add_result(self, D, I):
         """D, I do not need to be in a particular order (heap or sorted)"""
-        assert D.shape == (self.nq, self.k)
-        assert I.shape == (self.nq, self.k)
+        nq, kd = D.shape
+        assert I.shape == (nq, kd)
+        assert nq == self.nq
         self.heaps.addn_with_ids(
-            self.k, swig_ptr(D),
-            swig_ptr(I), self.k)
+            kd, swig_ptr(D),
+            swig_ptr(I), kd)
 
     def finalize(self):
         self.heaps.reorder()
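With keep_max=True the ResultHeap keeps the k largest values per query, which is what inner-product search needs; finalize() then sorts each row in decreasing order. A sketch of accumulating top-k IP results over database slices (random data, and the one-index-per-slice layout is chosen only for illustration):

```python
import numpy as np
import faiss

nq, d, k = 10, 16, 5
rs = np.random.RandomState(42)
xq = rs.rand(nq, d).astype('float32')

rh = faiss.ResultHeap(nq, k, keep_max=True)
for i0 in range(0, 1000, 250):
    xb = rs.rand(250, d).astype('float32')  # one database slice
    index = faiss.IndexFlat(d, faiss.METRIC_INNER_PRODUCT)
    index.add(xb)
    Di, Ii = index.search(xq, k)
    rh.add_result(Di, Ii + i0)              # shift ids to global numbering
rh.finalize()
# rh.D: per-query scores in decreasing order; rh.I: matching global ids
```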
@@ -614,10 +614,12 @@ void gpu_sync_all_devices()
         DOWNCAST (RemapDimensionsTransform)
         DOWNCAST (OPQMatrix)
         DOWNCAST (PCAMatrix)
+        DOWNCAST (ITQMatrix)
         DOWNCAST (RandomRotationMatrix)
         DOWNCAST (LinearTransform)
         DOWNCAST (NormalizationTransform)
         DOWNCAST (CenteringTransform)
+        DOWNCAST (ITQTransform)
         DOWNCAST (VectorTransform)
         {
             assert(false);
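These two DOWNCAST entries are what let faiss.downcast_VectorTransform return the concrete subclass; before the fix an ITQTransform came back as a plain VectorTransform, so subclass-only fields were unreachable from Python. A short sketch mirroring the new test further below:

```python
import faiss

codec = faiss.index_factory(16, "ITQ8,LSHt")
# chain.at(0) returns the abstract VectorTransform base class
vt = faiss.downcast_VectorTransform(codec.chain.at(0))
print(type(vt).__name__)  # ITQTransform, now that the downcast exists
print(vt.pca_then_itq)    # subclass field, visible only after the downcast
```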
@@ -470,5 +470,33 @@ class TestNNDescentKNNG(unittest.TestCase):
         assert recall > 0.99
 
 
-if __name__ == '__main__':
-    unittest.main()
+class TestResultHeap(unittest.TestCase):
+
+    def test_keep_min(self):
+        self.run_test(False)
+
+    def test_keep_max(self):
+        self.run_test(True)
+
+    def run_test(self, keep_max):
+        nq = 100
+        nb = 1000
+        restab = faiss.rand((nq, nb), 123)
+        ids = faiss.randint((nq, nb), 1324, 10000)
+        all_rh = {}
+        for nstep in 1, 3:
+            rh = faiss.ResultHeap(nq, 10, keep_max=keep_max)
+            for i in range(nstep):
+                i0, i1 = i * nb // nstep, (i + 1) * nb // nstep
+                D = restab[:, i0:i1].copy()
+                I = ids[:, i0:i1].copy()
+                rh.add_result(D, I)
+            rh.finalize()
+            if keep_max:
+                assert np.all(rh.D[:, :-1] >= rh.D[:, 1:])
+            else:
+                assert np.all(rh.D[:, :-1] <= rh.D[:, 1:])
+            all_rh[nstep] = rh
+
+        np.testing.assert_equal(all_rh[1].D, all_rh[3].D)
+        np.testing.assert_equal(all_rh[1].I, all_rh[3].I)
@@ -15,8 +15,9 @@ from faiss.contrib import ivf_tools
 
 from common_faiss_tests import get_dataset_2
 try:
-    from faiss.contrib.exhaustive_search import knn_ground_truth, knn, range_ground_truth
-    from faiss.contrib.exhaustive_search import range_search_max_results
+    from faiss.contrib.exhaustive_search import \
+        knn_ground_truth, knn, range_ground_truth, \
+        range_search_max_results, exponential_query_iterator
+
 except:
     pass  # Submodule import broken in python 2.
@@ -151,7 +152,7 @@ class TestExhaustiveSearch(unittest.TestCase):
 
         # check repro OK
         _, new_lims, new_D, new_I = range_search_max_results(
-            index, matrix_iterator(xq, 100), threshold)
+            index, matrix_iterator(xq, 100), threshold, max_results=1e10)
 
         evaluation.test_ref_range_results(
             ref_lims, ref_D, ref_I,
@@ -381,3 +382,38 @@ class TestPreassigned(unittest.TestCase):
         for q in range(len(xq)):
             l0, l1 = lims[q], lims[q + 1]
             self.assertTrue(set(I[q]) <= set(IR[l0:l1]))
+
+
+class TestRangeSearchMaxResults(unittest.TestCase):
+
+    def do_test(self, metric_type):
+        ds = datasets.SyntheticDataset(32, 0, 1000, 200)
+        index = faiss.IndexFlat(ds.d, metric_type)
+        index.add(ds.get_database())
+
+        # find a reasonable radius
+        D, _ = index.search(ds.get_queries(), 10)
+        radius0 = float(np.median(D[:, -1]))
+
+        # baseline = search with that radius
+        lims_ref, Dref, Iref = index.range_search(ds.get_queries(), radius0)
+
+        # now see if using just the total number of results, we can get back the same
+        # result table
+        query_iterator = exponential_query_iterator(ds.get_queries())
+
+        init_radius = 1e10 if metric_type == faiss.METRIC_L2 else -1e10
+        radius1, lims_new, Dnew, Inew = range_search_max_results(
+            index, query_iterator, init_radius, min_results=Dref.size, clip_to_min=True
+        )
+
+        evaluation.test_ref_range_results(
+            lims_ref, Dref, Iref,
+            lims_new, Dnew, Inew
+        )
+
+    def test_L2(self):
+        self.do_test(faiss.METRIC_L2)
+
+    def test_IP(self):
+        self.do_test(faiss.METRIC_INNER_PRODUCT)
@@ -168,6 +168,7 @@ class TestCloneSize(unittest.TestCase):
         index2 = faiss.clone_index(index)
         assert index2.ntotal == 100
 
+
 class TestCloneIVFPQ(unittest.TestCase):
 
     def test_clone(self):
@@ -177,3 +178,12 @@ class TestCloneIVFPQ(unittest.TestCase):
         index.add(xb)
         index2 = faiss.clone_index(index)
         assert index2.ntotal == index.ntotal
+
+
+class TestVTDowncast(unittest.TestCase):
+
+    def test_itq_transform(self):
+        codec = faiss.index_factory(16, "ITQ8,LSHt")
+
+        itqt = faiss.downcast_VectorTransform(codec.chain.at(0))
+        itqt.pca_then_itq