three small fixes (#1972)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1972

This fixes a few issues that I ran into + adds tests:

- range_search_max_results with IP search

- a few missing downcasts for VectorTRansforms

- ResultHeap supports max IP search

Reviewed By: wickedfoo

Differential Revision: D29525093

fbshipit-source-id: d4ff0aff1d83af9717ff1aaa2fe3cda7b53019a3
pull/1983/head
Matthijs Douze 2021-07-01 16:06:59 -07:00 committed by Facebook GitHub Bot
parent 7cce100c92
commit 1829aa92a1
7 changed files with 154 additions and 30 deletions

View File

@ -220,8 +220,6 @@ def range_PR_multiple_thresholds(
return precisions, recalls
###############################################################
# Functions that compare search results with a reference result.
# They are intended for use in tests

View File

@ -11,6 +11,7 @@ import logging
LOG = logging.getLogger(__name__)
def knn_ground_truth(xq, db_iterator, k, metric_type=faiss.METRIC_L2):
"""Computes the exact KNN search results for a dataset that possibly
does not fit in RAM but for which we have an iterator that
@ -146,21 +147,27 @@ def range_ground_truth(xq, db_iterator, threshold, metric_type=faiss.METRIC_L2,
return lims, np.hstack(D), np.hstack(I)
def threshold_radius_nres(nres, dis, ids, thresh):
def threshold_radius_nres(nres, dis, ids, thresh, keep_max=False):
""" select a set of results """
mask = dis < thresh
if keep_max:
mask = dis > thresh
else:
mask = dis < thresh
new_nres = np.zeros_like(nres)
o = 0
for i, nr in enumerate(nres):
nr = int(nr) # avoid issues with int64 + uint64
new_nres[i] = mask[o : o + nr].sum()
new_nres[i] = mask[o:o + nr].sum()
o += nr
return new_nres, dis[mask], ids[mask]
def threshold_radius(lims, dis, ids, thresh):
def threshold_radius(lims, dis, ids, thresh, keep_max=False):
""" restrict range-search results to those below a given radius """
mask = dis < thresh
if keep_max:
mask = dis > thresh
else:
mask = dis < thresh
new_lims = np.zeros_like(lims)
n = len(lims) - 1
for i in range(n):
@ -169,12 +176,18 @@ def threshold_radius(lims, dis, ids, thresh):
return new_lims, dis[mask], ids[mask]
def apply_maxres(res_batches, target_nres):
def apply_maxres(res_batches, target_nres, keep_max=False):
"""find radius that reduces number of results to target_nres, and
applies it in-place to the result batches used in range_search_max_results"""
applies it in-place to the result batches used in
range_search_max_results"""
alldis = np.hstack([dis for _, dis, _ in res_batches])
alldis.partition(target_nres)
radius = alldis[target_nres]
assert len(alldis) > target_nres
if keep_max:
alldis.partition(len(alldis) - target_nres - 1)
radius = alldis[-1 - target_nres]
else:
alldis.partition(target_nres)
radius = alldis[target_nres]
if alldis.dtype == 'float32':
radius = float(radius)
@ -183,7 +196,8 @@ def apply_maxres(res_batches, target_nres):
LOG.debug(' setting radius to %s' % radius)
totres = 0
for i, (nres, dis, ids) in enumerate(res_batches):
nres, dis, ids = threshold_radius_nres(nres, dis, ids, radius)
nres, dis, ids = threshold_radius_nres(
nres, dis, ids, radius, keep_max=keep_max)
totres += len(dis)
res_batches[i] = nres, dis, ids
LOG.debug(' updated previous results, new nb results %d' % totres)
@ -192,7 +206,7 @@ def apply_maxres(res_batches, target_nres):
def range_search_max_results(index, query_iterator, radius,
max_results=None, min_results=None,
shard=False, ngpu=0):
shard=False, ngpu=0, clip_to_min=False):
"""Performs a range search with many queries (given by an iterator)
and adjusts the threshold on-the-fly so that the total results
table does not grow larger than max_results.
@ -200,10 +214,16 @@ def range_search_max_results(index, query_iterator, radius,
If ngpu != 0, the function moves the index to this many GPUs to
speed up search.
"""
# TODO: all result manipulations are in python, should move to C++ if perf
# critical
if max_results is not None:
if min_results is None:
min_results = int(0.8 * max_results)
if min_results is None:
assert max_results is not None
min_results = int(0.8 * max_results)
if max_results is None:
assert min_results is not None
max_results = int(min_results * 1.5)
if ngpu == -1:
ngpu = faiss.get_num_gpus()
@ -242,15 +262,26 @@ def range_search_max_results(index, query_iterator, radius,
if max_results is not None and totres > max_results:
LOG.info('too many results %d > %d, scaling back radius' %
(totres, max_results))
radius, totres = apply_maxres(res_batches, min_results)
radius, totres = apply_maxres(
res_batches, min_results,
keep_max=index.metric_type == faiss.METRIC_INNER_PRODUCT
)
t2 = time.time()
t_search += t1 - t0
t_post_process += t2 - t1
LOG.debug(' [%.3f s] %d queries done, %d results' % (
time.time() - t_start, qtot, totres))
LOG.info(' search done in %.3f s + %.3f s, total %d results, end threshold %g' % (
t_search, t_post_process, totres, radius))
LOG.info(
'search done in %.3f s + %.3f s, total %d results, end threshold %g' % (
t_search, t_post_process, totres, radius)
)
if clip_to_min and totres > min_results:
radius, totres = apply_maxres(
res_batches, min_results,
keep_max=index.metric_type == faiss.METRIC_INNER_PRODUCT
)
nres = np.hstack([nres_i for nres_i, dis_i, ids_i in res_batches])
dis = np.hstack([dis_i for nres_i, dis_i, ids_i in res_batches])
@ -260,3 +291,18 @@ def range_search_max_results(index, query_iterator, radius,
lims[1:] = np.cumsum(nres)
return radius, lims, dis, ids
def exponential_query_iterator(xq, start_bs=32, max_bs=20000):
""" produces batches of progressively increasing sizes. This is useful to
adjust the search radius progressively without overflowing with
intermediate results """
nq = len(xq)
bs = start_bs
i = 0
while i < nq:
xqi = xq[i:i + bs]
yield xqi
if bs < max_bs:
bs *= 2
i += len(xqi)

View File

@ -1600,12 +1600,15 @@ class ResultHeap:
"""Accumulate query results from a sliced dataset. The final result will
be in self.D, self.I."""
def __init__(self, nq, k):
def __init__(self, nq, k, keep_max=False):
" nq: number of query vectors, k: number of results per query "
self.I = np.zeros((nq, k), dtype='int64')
self.D = np.zeros((nq, k), dtype='float32')
self.nq, self.k = nq, k
heaps = float_maxheap_array_t()
if keep_max:
heaps = float_minheap_array_t()
else:
heaps = float_maxheap_array_t()
heaps.k = k
heaps.nh = nq
heaps.val = swig_ptr(self.D)
@ -1615,11 +1618,12 @@ class ResultHeap:
def add_result(self, D, I):
"""D, I do not need to be in a particular order (heap or sorted)"""
assert D.shape == (self.nq, self.k)
assert I.shape == (self.nq, self.k)
nq, kd = D.shape
assert I.shape == (nq, kd)
assert nq == self.nq
self.heaps.addn_with_ids(
self.k, swig_ptr(D),
swig_ptr(I), self.k)
kd, swig_ptr(D),
swig_ptr(I), kd)
def finalize(self):
self.heaps.reorder()

View File

@ -614,10 +614,12 @@ void gpu_sync_all_devices()
DOWNCAST (RemapDimensionsTransform)
DOWNCAST (OPQMatrix)
DOWNCAST (PCAMatrix)
DOWNCAST (ITQMatrix)
DOWNCAST (RandomRotationMatrix)
DOWNCAST (LinearTransform)
DOWNCAST (NormalizationTransform)
DOWNCAST (CenteringTransform)
DOWNCAST (ITQTransform)
DOWNCAST (VectorTransform)
{
assert(false);

View File

@ -470,5 +470,33 @@ class TestNNDescentKNNG(unittest.TestCase):
assert recall > 0.99
if __name__ == '__main__':
unittest.main()
class TestResultHeap(unittest.TestCase):
def test_keep_min(self):
self.run_test(False)
def test_keep_max(self):
self.run_test(True)
def run_test(self, keep_max):
nq = 100
nb = 1000
restab = faiss.rand((nq, nb), 123)
ids = faiss.randint((nq, nb), 1324, 10000)
all_rh = {}
for nstep in 1, 3:
rh = faiss.ResultHeap(nq, 10, keep_max=keep_max)
for i in range(nstep):
i0, i1 = i * nb // nstep, (i + 1) * nb // nstep
D = restab[:, i0:i1].copy()
I = ids[:, i0:i1].copy()
rh.add_result(D, I)
rh.finalize()
if keep_max:
assert np.all(rh.D[:, :-1] >= rh.D[:, 1:])
else:
assert np.all(rh.D[:, :-1] <= rh.D[:, 1:])
all_rh[nstep] = rh
np.testing.assert_equal(all_rh[1].D, all_rh[3].D)
np.testing.assert_equal(all_rh[1].I, all_rh[3].I)

View File

@ -15,8 +15,9 @@ from faiss.contrib import ivf_tools
from common_faiss_tests import get_dataset_2
try:
from faiss.contrib.exhaustive_search import knn_ground_truth, knn, range_ground_truth
from faiss.contrib.exhaustive_search import range_search_max_results
from faiss.contrib.exhaustive_search import \
knn_ground_truth, knn, range_ground_truth, \
range_search_max_results, exponential_query_iterator
except:
pass # Submodule import broken in python 2.
@ -151,7 +152,7 @@ class TestExhaustiveSearch(unittest.TestCase):
# check repro OK
_, new_lims, new_D, new_I = range_search_max_results(
index, matrix_iterator(xq, 100), threshold)
index, matrix_iterator(xq, 100), threshold, max_results=1e10)
evaluation.test_ref_range_results(
ref_lims, ref_D, ref_I,
@ -381,3 +382,38 @@ class TestPreassigned(unittest.TestCase):
for q in range(len(xq)):
l0, l1 = lims[q], lims[q + 1]
self.assertTrue(set(I[q]) <= set(IR[l0:l1]))
class TestRangeSearchMaxResults(unittest.TestCase):
def do_test(self, metric_type):
ds = datasets.SyntheticDataset(32, 0, 1000, 200)
index = faiss.IndexFlat(ds.d, metric_type)
index.add(ds.get_database())
# find a reasonable radius
D, _ = index.search(ds.get_queries(), 10)
radius0 = float(np.median(D[:, -1]))
# baseline = search with that radius
lims_ref, Dref, Iref = index.range_search(ds.get_queries(), radius0)
# now see if using just the total number of results, we can get back the same
# result table
query_iterator = exponential_query_iterator(ds.get_queries())
init_radius = 1e10 if metric_type == faiss.METRIC_L2 else -1e10
radius1, lims_new, Dnew, Inew = range_search_max_results(
index, query_iterator, init_radius, min_results=Dref.size, clip_to_min=True
)
evaluation.test_ref_range_results(
lims_ref, Dref, Iref,
lims_new, Dnew, Inew
)
def test_L2(self):
self.do_test(faiss.METRIC_L2)
def test_IP(self):
self.do_test(faiss.METRIC_INNER_PRODUCT)

View File

@ -168,6 +168,7 @@ class TestCloneSize(unittest.TestCase):
index2 = faiss.clone_index(index)
assert index2.ntotal == 100
class TestCloneIVFPQ(unittest.TestCase):
def test_clone(self):
@ -177,3 +178,12 @@ class TestCloneIVFPQ(unittest.TestCase):
index.add(xb)
index2 = faiss.clone_index(index)
assert index2.ntotal == index.ntotal
class TestVTDowncast(unittest.TestCase):
def test_itq_transform(self):
codec = faiss.index_factory(16, "ITQ8,LSHt")
itqt = faiss.downcast_VectorTransform(codec.chain.at(0))
itqt.pca_then_itq