Add range search accuracy evaluation

Summary:
Added a few functions in contrib to:
- run range searches by batches on the query or the database side
- emulate range search on GPU: search on GPU with k=1024, if the farthest neighbor is still within range, re-perform search on CPU
- provide reference implementations of precision-recall computation for range search results
- optimized code to plot precision-recall curves (i.e., sweep over thresholds)

The new functions are mainly in a new `evaluation.py`

Reviewed By: wickedfoo

Differential Revision: D25627619

fbshipit-source-id: 58f90654c32c925557d7bbf8083efbb710712e03
pull/1591/head
Matthijs Douze 2020-12-17 17:15:54 -08:00 committed by Facebook GitHub Bot
parent 32df3f3198
commit 3dd7ba8ff9
10 changed files with 630 additions and 8 deletions

View File

@ -54,3 +54,7 @@ Defintion of how to access data for some standard datsets.
### factory_tools.py
Functions related to factory strings.
### evaluation.py
A few non-trivial evaluation functions for search results

View File

@ -0,0 +1,268 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import unittest
from multiprocessing.dummy import Pool as ThreadPool
###############################################################
# Simple functions to evaluate knn results
def knn_intersection_measure(I1, I2):
    """Return the fraction of ids shared by two knn result tables.

    I1 and I2 must have the same (nq, rank) shape. Rows are compared
    pairwise as sets of ids, and the number of common ids is divided by
    the total number of entries in I1.
    """
    nq, rank = I1.shape
    assert I2.shape == (nq, rank)
    n_common = 0
    for row1, row2 in zip(I1, I2):
        n_common += np.intersect1d(row1, row2).size
    return n_common / I1.size
###############################################################
# Range search results can be compared with Precision-Recall
def filter_range_results(lims, D, I, thresh):
    """Keep only the range search results whose distance is below thresh.

    lims delimits the per-query segments of D and I: the results of query
    q are stored in D[lims[q]:lims[q + 1]] and I[lims[q]:lims[q + 1]].

    Returns (new_lims, new_D, new_I) in the same layout, where only the
    entries with D < thresh are kept. new_lims has the same dtype as lims.
    """
    mask = D < thresh
    new_lims = np.zeros_like(lims)
    if lims.size > 1:
        # n_kept[j] = number of kept entries among the first j results, so
        # the new limits are a plain lookup instead of a per-query loop
        n_kept = np.zeros(mask.size + 1, dtype="int64")
        n_kept[1:] = np.cumsum(mask, dtype="int64")
        # index with signed integers: lims is typically uint64
        bounds = lims.astype("int64")
        new_lims[1:] = n_kept[bounds[1:]] - n_kept[bounds[0]]
    return new_lims, D[mask], I[mask]
def range_PR(lims_ref, Iref, lims_new, Inew, mode="overall"):
    """Compute the precision and recall of range search results.

    The reference results of query q are Iref[lims_ref[q]:lims_ref[q + 1]]
    and the results to evaluate are Inew[lims_new[q]:lims_new[q + 1]].
    The function does not take the distances into account.

    mode is forwarded to counts_to_PR ("overall" or "average").
    Returns a (precision, recall) pair.
    """
    nq = lims_ref.size - 1
    assert lims_new.size - 1 == nq

    # nb of ids present in both result lists, per query
    ninter = np.zeros(nq, dtype="int64")

    def compute_PR_for(q):
        # ground truth results for this query
        gt_ids = Iref[lims_ref[q]:lims_ref[q + 1]]
        # results for this query
        new_ids = Inew[lims_new[q]:lims_new[q + 1]]
        # there are no set functions in numpy so let's do this
        ninter[q] = np.intersect1d(gt_ids, new_ids).size

    # run in a thread pool, which helps in spite of the GIL; the context
    # manager makes sure the worker threads are released afterwards
    with ThreadPool(20) as pool:
        pool.map(compute_PR_for, range(nq))

    return counts_to_PR(
        lims_ref[1:] - lims_ref[:-1],
        lims_new[1:] - lims_new[:-1],
        ninter,
        mode=mode
    )
def counts_to_PR(ngt, nres, ninter, mode="overall"):
    """Compute a precision-recall pair for a set of queries.

    ngt = nb of GT results per query
    nres = nb of found results per query
    ninter = nb of correct results per query (smaller than nres of course)

    mode selects how the per-query counts are combined:
    - "overall": the counts are summed over queries before dividing
    - "average": precision and recall are computed per query, then averaged

    Conventions for degenerate cases: the precision is 1 when no result was
    returned; the recall is 1 when there is neither ground truth nor result
    and 0 when there are results but no ground truth.

    The input arrays are not modified. Returns (precision, recall); raises
    AssertionError for an unknown mode.
    """
    if mode == "overall":
        ngt, nres, ninter = ngt.sum(), nres.sum(), ninter.sum()

        precision = ninter / nres if nres > 0 else 1.0

        if ngt > 0:
            recall = ninter / ngt
        elif nres == 0:
            recall = 1.0
        else:
            recall = 0.0

        return precision, recall

    if mode == "average":
        # average precision and recall over queries
        no_gt = ngt == 0
        no_res = nres == 0
        assert np.all(ninter[no_res] == 0)

        # np.where builds new arrays, so the caller's counts stay untouched
        # while the divisions by 0 are avoided
        recalls = ninter / np.where(no_gt, 1, ngt)
        recalls[no_gt] = (nres[no_gt] == 0).astype(float)

        # queries without results count as precision 1 (1 / 1)
        precisions = np.where(no_res, 1, ninter) / np.where(no_res, 1, nres)

        return precisions.mean(), recalls.mean()

    raise AssertionError("unknown mode %r" % (mode,))
def sort_range_res_2(lims, D, I):
    """Sort the results of each query by increasing distance.

    Within every segment [lims[q], lims[q + 1]) the distances D are the
    sort key and the ids I are permuted accordingly. Returns the reordered
    (I, D) pair as new arrays.
    """
    sorted_I = np.empty_like(I)
    sorted_D = np.empty_like(D)
    for begin, end in zip(lims[:-1], lims[1:]):
        order = D[begin:end].argsort()
        sorted_D[begin:end] = D[begin:end][order]
        sorted_I[begin:end] = I[begin:end][order]
    return sorted_I, sorted_D
def sort_range_res_1(lims, I):
    """Return a copy of I where the ids of each query are sorted.

    The ids of query q are the segment I[lims[q]:lims[q + 1]]; each
    segment is sorted in increasing order independently of the others.
    """
    sorted_I = np.empty_like(I)
    for q in range(len(lims) - 1):
        begin, end = lims[q], lims[q + 1]
        sorted_I[begin:end] = np.sort(I[begin:end])
    return sorted_I
def range_PR_multiple_thresholds(
        lims_ref, Iref,
        lims_new, Dnew, Inew,
        thresholds,
        mode="overall", do_sort="ref,new"
):
    """Compute precision-recall values for range search results
    for several thresholds on the "new" results.
    This is to plot PR curves.

    For each threshold, only the new results with a distance strictly
    below the threshold are taken into account.

    do_sort tells which inputs still need sorting: the computation needs
    the reference ids sorted by id ("ref") and the new results sorted by
    distance ("new") within each query.
    mode is forwarded to counts_to_PR.

    Returns (precisions, recalls), two arrays with one entry per threshold.
    """
    # ref should be sorted by ids
    if "ref" in do_sort:
        Iref = sort_range_res_1(lims_ref, Iref)

    # new should be sorted by distances
    if "new" in do_sort:
        Inew, Dnew = sort_range_res_2(lims_new, Dnew, Inew)

    nq = lims_ref.size - 1
    assert lims_new.size - 1 == nq

    nt = len(thresholds)
    # counts[q, t] = (nb GT results, nb results, nb correct results)
    # of query q at threshold t
    counts = np.zeros((nq, nt, 3), dtype="int64")

    def compute_PR_for(q):
        gt_ids = Iref[lims_ref[q]:lims_ref[q + 1]]
        l0, l1 = lims_new[q], lims_new[q + 1]
        res_ids, res_dis = Inew[l0:l1], Dnew[l0:l1]

        counts[q, :, 0] = len(gt_ids)

        if res_dis.size == 0:
            # the rest remains at 0
            return

        # which offsets we are interested in
        nres = np.searchsorted(res_dis, thresholds)
        counts[q, :, 1] = nres

        if gt_ids.size == 0:
            return

        # find number of TPs at each stage in the result list
        ii = np.searchsorted(gt_ids, res_ids)
        # ids beyond the largest GT id get an in-bounds index; the
        # equality test below rejects them
        ii[ii == len(gt_ids)] = -1
        n_ok = np.cumsum(gt_ids[ii] == res_ids)

        # focus on threshold points
        n_ok = np.hstack(([0], n_ok))
        counts[q, :, 2] = n_ok[nres]

    # the context manager releases the worker threads when done
    with ThreadPool(20) as pool:
        pool.map(compute_PR_for, range(nq))

    precisions = np.zeros(nt)
    recalls = np.zeros(nt)
    for t in range(nt):
        precisions[t], recalls[t] = counts_to_PR(
            counts[:, t, 0], counts[:, t, 1], counts[:, t, 2],
            mode=mode
        )

    return precisions, recalls
###############################################################
# Functions that compare search results with a reference result.
# They are intended for use in tests
def test_ref_knn_with_draws(Dref, Iref, Dnew, Inew):
    """Test that knn search results are identical, raise if not.

    The distance tables must match to 5 decimals. The ids may differ only
    by a permutation among results that have the same reference distance
    (draws). Ids at the last distance of a row are not compared, since
    draws at that distance may have been cut off by the result size.

    Raises AssertionError on mismatch.
    """
    np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
    # here we have to be careful because of draws
    testcase = unittest.TestCase()   # because it makes nice error messages
    for i in range(len(Iref)):
        if np.all(Iref[i] == Inew[i]):  # easy case
            continue
        # we can deduce nothing about the latest line
        skip_dis = Dref[i, -1]
        # only the distances of this row can select anything in it
        for dis in np.unique(Dref[i]):
            if dis == skip_dis:
                continue
            mask = Dref[i, :] == dis
            testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask]))


# helper with a test-like name: keep pytest from collecting it as a test
# when it is imported into a test module
test_ref_knn_with_draws.__test__ = False
def test_ref_range_results(lims_ref, Dref, Iref,
                           lims_new, Dnew, Inew):
    """Compare range search results wrt. a reference result,
    throw if it fails.

    Both results must have identical lims. Within each query the results
    may come in a different order: they are then aligned by id before the
    ids are compared exactly and the distances to 5 decimals.

    Raises AssertionError on mismatch.
    """
    def sort_by_ids(I, D):
        o = I.argsort()
        return I[o], D[o]

    np.testing.assert_array_equal(lims_ref, lims_new)
    nq = len(lims_ref) - 1
    for i in range(nq):
        # lims are equal, so the same bounds apply to both results
        l0, l1 = lims_ref[i], lims_ref[i + 1]
        Ii_ref, Ii_new = Iref[l0:l1], Inew[l0:l1]
        Di_ref, Di_new = Dref[l0:l1], Dnew[l0:l1]
        if not np.all(Ii_ref == Ii_new):
            # not the easy case: sort both before comparing
            Ii_ref, Di_ref = sort_by_ids(Ii_ref, Di_ref)
            Ii_new, Di_new = sort_by_ids(Ii_new, Di_new)
            np.testing.assert_array_equal(Ii_ref, Ii_new)
        np.testing.assert_array_almost_equal(Di_ref, Di_new, decimal=5)


# helper with a test-like name: keep pytest from collecting it as a test
# when it is imported into a test module
test_ref_range_results.__test__ = False

View File

@ -11,7 +11,7 @@ import logging
LOG = logging.getLogger(__name__)
def knn_ground_truth(xq, db_iterator, k):
def knn_ground_truth(xq, db_iterator, k, metric_type=faiss.METRIC_L2):
"""Computes the exact KNN search results for a dataset that possibly
does not fit in RAM but for which we have an iterator that
returns it block by block.
@ -21,12 +21,12 @@ def knn_ground_truth(xq, db_iterator, k):
nq, d = xq.shape
rh = faiss.ResultHeap(nq, k)
index = faiss.IndexFlatL2(d)
index = faiss.IndexFlat(d, metric_type)
if faiss.get_num_gpus():
LOG.info('running on %d GPUs' % faiss.get_num_gpus())
index = faiss.index_cpu_to_all_gpus(index)
# compute ground-truth by blocks of bs, and add to heaps
# compute ground-truth by blocks, and add to heaps
i0 = 0
for xbi in db_iterator:
ni = xbi.shape[0]
@ -44,4 +44,184 @@ def knn_ground_truth(xq, db_iterator, k):
return rh.D, rh.I
# knn function used to be here
knn = faiss.knn
knn = faiss.knn
def range_search_gpu(xq, r2, index_gpu, xb):
    """ GPU does not support range search, so we emulate it with
    knn search + fallback to CPU index.

    A knn search with k = min(index_gpu.ntotal, 1024) is run on index_gpu.
    Queries whose k-th neighbor is still within the radius r2 may have
    more results than k, so they are searched again with a CPU flat index
    built on xb. For METRIC_L2 "within the radius" means distance < r2,
    for any other metric it means distance > r2.

    Returns (lims, D, I): the results of query q are D[lims[q]:lims[q + 1]]
    and I[lims[q]:lims[q + 1]].
    """
    nq, d = xq.shape
    is_l2 = index_gpu.metric_type == faiss.METRIC_L2
    LOG.debug("GPU search %d queries", nq)
    k = min(index_gpu.ntotal, 1024)
    if nq == 0 or k == 0:
        # no query or empty index: there cannot be any result, and the
        # indexing / stacking below would fail on the empty tables
        return (
            np.zeros(nq + 1, dtype="int64"),
            np.zeros(0, dtype="float32"),
            np.zeros(0, dtype="int64"),
        )
    D, I = index_gpu.search(xq, k)
    # queries for which the farthest returned neighbor is still in range
    if is_l2:
        mask = D[:, k - 1] < r2
    else:
        mask = D[:, k - 1] > r2
    n_remain = int(mask.sum())
    if n_remain > 0:
        LOG.debug("CPU search remain %d", n_remain)
        index_cpu = faiss.IndexFlat(d, index_gpu.metric_type)
        index_cpu.add(xb)
        lim_remain, D_remain, I_remain = index_cpu.range_search(xq[mask], r2)
    LOG.debug("combine")
    D_res, I_res = [], []
    nr = 0   # position of the current query in the CPU results
    for i in range(nq):
        if not mask[i]:
            # the GPU result is complete: keep its in-range prefix
            if is_l2:
                nv = (D[i, :] < r2).sum()
            else:
                nv = (D[i, :] > r2).sum()
            D_res.append(D[i, :nv])
            I_res.append(I[i, :nv])
        else:
            l0, l1 = lim_remain[nr], lim_remain[nr + 1]
            D_res.append(D_remain[l0:l1])
            I_res.append(I_remain[l0:l1])
            nr += 1
    lims = np.cumsum([0] + [len(di) for di in D_res])
    return lims, np.hstack(D_res), np.hstack(I_res)
def range_ground_truth(xq, db_iterator, threshold, metric_type=faiss.METRIC_L2,
                       shard=False, ngpu=-1):
    """Computes the range-search search results for a dataset that possibly
    does not fit in RAM but for which we have an iterator that
    returns it block by block.

    xq: query vectors, converted to a contiguous float32 array
    db_iterator: yields the database blocks one after the other
    threshold: radius passed to the range search
    metric_type: metric of the flat index used for the search
    shard: value set on GpuMultipleClonerOptions.shard for the GPU index
    ngpu: number of GPUs to use; -1 means faiss.get_num_gpus(), 0 runs the
        search on CPU

    Returns (lims, D, I) where the results of query q are stored in
    D[lims[q]:lims[q + 1]] and I[lims[q]:lims[q + 1]]; the ids in I are
    offsets in the whole database (concatenation of all blocks).
    """
    nq, d = xq.shape
    t0 = time.time()
    xq = np.ascontiguousarray(xq, dtype='float32')

    index = faiss.IndexFlat(d, metric_type)
    if ngpu == -1:
        ngpu = faiss.get_num_gpus()
    # single flag so that the index construction and the search loop
    # always agree on which index is used
    use_gpu = ngpu > 0
    if use_gpu:
        LOG.info('running on %d GPUs', ngpu)
        co = faiss.GpuMultipleClonerOptions()
        co.shard = shard
        index_gpu = faiss.index_cpu_to_all_gpus(index, co=co, ngpu=ngpu)

    # compute ground-truth by blocks
    i0 = 0
    # per-query lists of result chunks, one chunk per database block
    D = [[] for _i in range(nq)]
    I = [[] for _i in range(nq)]
    for xbi in db_iterator:
        ni = xbi.shape[0]
        if use_gpu:
            index_gpu.add(xbi)
            lims_i, Di, Ii = range_search_gpu(xq, threshold, index_gpu, xbi)
            index_gpu.reset()
        else:
            index.add(xbi)
            lims_i, Di, Ii = index.range_search(xq, threshold)
            index.reset()
        # ids are relative to the block: shift them to database offsets
        Ii += i0
        for j in range(nq):
            l0, l1 = lims_i[j], lims_i[j + 1]
            if l1 > l0:
                D[j].append(Di[l0:l1])
                I[j].append(Ii[l0:l1])
        i0 += ni
        LOG.info("%d db elements, %.3f s", i0, time.time() - t0)

    empty_I = np.zeros(0, dtype='int64')
    empty_D = np.zeros(0, dtype='float32')
    D = [(np.hstack(chunks) if chunks else empty_D) for chunks in D]
    I = [(np.hstack(chunks) if chunks else empty_I) for chunks in I]
    sizes = [len(i) for i in I]
    assert len(sizes) == nq
    lims = np.zeros(nq + 1, dtype="uint64")
    lims[1:] = np.cumsum(sizes)
    # np.hstack refuses an empty list, which happens when there is no query
    all_D = np.hstack(D) if D else empty_D
    all_I = np.hstack(I) if I else empty_I
    return lims, all_D, all_I
def threshold_radius(nres, dis, ids, thresh):
    """Keep only the results whose distance is strictly below thresh.

    nres[q] is the number of consecutive entries of dis / ids that belong
    to query q. Returns (new_nres, new_dis, new_ids) in the same layout,
    restricted to the entries with dis < thresh.
    """
    keep = dis < thresh
    kept_per_query = np.zeros_like(nres)
    start = 0
    for q in range(len(nres)):
        # int() avoids issues with int64 + uint64
        stop = start + int(nres[q])
        kept_per_query[q] = keep[start:stop].sum()
        start = stop
    return kept_per_query, dis[keep], ids[keep]
def apply_maxres(res_batches, target_nres):
    """Find radius that reduces number of results to target_nres, and
    applies it in-place to the result batches.

    res_batches is a list of (nres, dis, ids) tuples; each entry is
    replaced by its filtered version (see threshold_radius).
    The radius is the distance at rank target_nres among all results, so
    target_nres must be smaller than the total number of results.

    Returns (radius, totres) where totres is the number of results left.
    """
    alldis = np.hstack([dis for _, dis, _ in res_batches])
    alldis.partition(target_nres)
    radius = alldis[target_nres]

    # convert to a plain python number of the matching kind
    if alldis.dtype == 'float32':
        radius = float(radius)
    else:
        radius = int(radius)
    LOG.debug(' setting radius to %s', radius)
    totres = 0
    for i, (nres, dis, ids) in enumerate(res_batches):
        nres, dis, ids = threshold_radius(nres, dis, ids, radius)
        totres += len(dis)
        res_batches[i] = nres, dis, ids
    LOG.debug(' updated previous results, new nb results %d', totres)
    return radius, totres
def range_search_max_results(index, query_iterator, radius,
                             max_results=None, min_results=None):
    """Run a range search for many queries given by an iterator, adjusting
    the threshold on-the-fly so that the total results table does not grow
    larger than max_results.

    Whenever the number of results exceeds max_results, the radius is
    reduced so that about min_results results remain (default: 80% of
    max_results), and the reduced radius is used for the next batches.

    Returns (radius, lims, dis, ids) with the final radius and the
    concatenated results of all query batches.
    """
    if max_results is not None and min_results is None:
        min_results = int(0.8 * max_results)

    t_start = time.time()
    t_search = 0
    t_post_process = 0
    qtot = 0
    totres = 0
    raw_totres = 0
    res_batches = []

    for xqi in query_iterator:
        t0 = time.time()
        lims_i, Di, Ii = index.range_search(xqi, radius)

        nres_i = lims_i[1:] - lims_i[:-1]
        raw_totres += len(Di)
        qtot += len(xqi)

        t1 = time.time()
        if xqi.dtype != np.float32:
            # for binary indexes
            # weird Faiss quirk that returns floats for Hamming distances
            Di = Di.astype('int16')

        totres += len(Di)
        res_batches.append((nres_i, Di, Ii))

        too_many = max_results is not None and totres > max_results
        if too_many:
            LOG.info('too many results %d > %d, scaling back radius' %
                     (totres, max_results))
            radius, totres = apply_maxres(res_batches, min_results)
        t2 = time.time()
        t_search += t1 - t0
        t_post_process += t2 - t1
        LOG.debug(' [%.3f s] %d queries done, %d results' % (
            time.time() - t_start, qtot, totres))

    LOG.info(' search done in %.3f s + %.3f s, total %d results, end threshold %g' % (
        t_search, t_post_process, totres, radius))

    # concatenate the per-batch tables
    nres = np.hstack([batch[0] for batch in res_batches])
    dis = np.hstack([batch[1] for batch in res_batches])
    ids = np.hstack([batch[2] for batch in res_batches])

    lims = np.zeros(len(nres) + 1, dtype='uint64')
    lims[1:] = np.cumsum(nres)

    return radius, lims, dis, ids

View File

@ -8,7 +8,9 @@ import unittest
import numpy as np
from faiss.contrib import datasets
from faiss.contrib.exhaustive_search import knn_ground_truth
from faiss.contrib.exhaustive_search import knn_ground_truth, range_ground_truth
from faiss.contrib import evaluation
from common import get_dataset_2
@ -33,3 +35,29 @@ class TestComputeGT(unittest.TestCase):
np.testing.assert_array_equal(Iref, Inew)
np.testing.assert_almost_equal(Dref, Dnew, decimal=4)
def do_test_range(self, metric):
    """Check range_ground_truth against a plain IndexFlat range search."""
    dim = 32
    ds = datasets.SyntheticDataset(dim, 0, 1000, 10)
    queries = ds.get_queries()
    database = ds.get_database()

    # radius = mean distance to the 10th neighbor
    knn_dis, _ = faiss.knn(queries, database, 10, distance_type=metric)
    threshold = float(knn_dis[:, -1].mean())

    ref_index = faiss.IndexFlat(dim, metric)
    ref_index.add(database)
    ref_results = ref_index.range_search(queries, threshold)

    new_results = range_ground_truth(
        queries, ds.database_iterator(bs=100), threshold,
        metric_type=metric)

    # both are (lims, D, I) triplets
    evaluation.test_ref_range_results(*ref_results, *new_results)
def test_range_L2(self):
    """Range ground truth with the L2 metric."""
    self.do_test_range(metric=faiss.METRIC_L2)
def test_range_IP(self):
    """Range ground truth with the inner product metric."""
    self.do_test_range(metric=faiss.METRIC_INNER_PRODUCT)

0
tests/test_binary_hashindex.py 100755 → 100644
View File

View File

@ -10,14 +10,18 @@ import platform
from faiss.contrib import datasets
from faiss.contrib import inspect_tools
from faiss.contrib import evaluation
from common import get_dataset_2
try:
from faiss.contrib.exhaustive_search import knn_ground_truth, knn
from faiss.contrib.exhaustive_search import knn_ground_truth, knn, range_ground_truth
from faiss.contrib.exhaustive_search import range_search_max_results
except:
pass # Submodule import broken in python 2.
@unittest.skipIf(platform.python_version_tuple()[0] < '3', \
'Submodule import broken in python 2.')
class TestComputeGT(unittest.TestCase):
@ -80,11 +84,9 @@ class TestDatasets(unittest.TestCase):
class TestExhaustiveSearch(unittest.TestCase):
def test_knn_cpu(self):
xb = np.random.rand(200, 32).astype('float32')
xq = np.random.rand(100, 32).astype('float32')
index = faiss.IndexFlatL2(32)
index.add(xb)
Dref, Iref = index.search(xq, 10)
@ -104,6 +106,72 @@ class TestExhaustiveSearch(unittest.TestCase):
assert np.all(Inew == Iref)
assert np.allclose(Dref, Dnew)
def do_test_range(self, metric):
    """Check the CPU path (ngpu=0) of range_ground_truth against a plain
    IndexFlat range search."""
    dim = 32
    ds = datasets.SyntheticDataset(dim, 0, 1000, 10)
    queries = ds.get_queries()
    database = ds.get_database()

    # radius = mean distance to the 10th neighbor
    knn_dis, _ = faiss.knn(queries, database, 10, distance_type=metric)
    threshold = float(knn_dis[:, -1].mean())

    ref_index = faiss.IndexFlat(dim, metric)
    ref_index.add(database)
    ref_results = ref_index.range_search(queries, threshold)

    new_results = range_ground_truth(
        queries, ds.database_iterator(bs=100), threshold, ngpu=0,
        metric_type=metric)

    # both are (lims, D, I) triplets
    evaluation.test_ref_range_results(*ref_results, *new_results)
def test_range_L2(self):
    """Range ground truth on CPU with the L2 metric."""
    self.do_test_range(metric=faiss.METRIC_L2)
def test_range_IP(self):
    """Range ground truth on CPU with the inner product metric."""
    self.do_test_range(metric=faiss.METRIC_INNER_PRODUCT)
def test_query_iterator(self, metric=faiss.METRIC_L2):
    """Check range_search_max_results, without and with a cap on the
    number of results, against a direct range search."""
    dim = 32
    ds = datasets.SyntheticDataset(dim, 0, 1000, 1000)
    queries = ds.get_queries()
    database = ds.get_database()

    # radius = mean distance to the 10th neighbor
    knn_dis, _ = faiss.knn(queries, database, 10, distance_type=metric)
    threshold = float(knn_dis[:, -1].mean())
    print(threshold)

    index = faiss.IndexFlat(dim, metric)
    index.add(database)
    ref_results = index.range_search(queries, threshold)

    def iter_batches(mat, bs):
        # successive slices of at most bs rows
        return (mat[i0:i0 + bs] for i0 in range(0, mat.shape[0], bs))

    # check repro OK
    _, new_lims, new_D, new_I = range_search_max_results(
        index, iter_batches(queries, 100), threshold)

    evaluation.test_ref_range_results(*ref_results, new_lims, new_D, new_I)

    # now cap the result table at half of the reference size
    max_res = ref_results[0][-1] // 2

    new_threshold, new_lims, new_D, new_I = range_search_max_results(
        index, iter_batches(queries, 100), threshold, max_results=max_res)

    self.assertLessEqual(new_lims[-1], max_res)

    # the capped results should be those of the reduced threshold
    ref_results = index.range_search(queries, new_threshold)

    evaluation.test_ref_range_results(*ref_results, new_lims, new_D, new_I)
class TestInspect(unittest.TestCase):
@ -123,3 +191,77 @@ class TestInspect(unittest.TestCase):
# verify
ynew = x @ A.T + b
np.testing.assert_array_almost_equal(yref, ynew)
class TestRangeEval(unittest.TestCase):
    """Tests for the precision-recall functions on range search results."""

    def test_precision_recall(self):
        """range_PR on a small hand-made example."""
        ref_rows = [
            [1, 2, 3],
            [5, 6],
            [],
            []
        ]
        new_rows = [
            [1, 2],
            [6, 7],
            [1],
            []
        ]

        # flatten the per-query lists into the (lims, ids) layout
        lims_ref = np.cumsum([0] + [len(row) for row in ref_rows])
        Iref = np.hstack(ref_rows)
        lims_new = np.cumsum([0] + [len(row) for row in new_rows])
        Inew = np.hstack(new_rows)

        precision, recall = evaluation.range_PR(lims_ref, Iref, lims_new, Inew)

        print(precision, recall)

        # 3 common ids out of 5 results and 5 ground-truth entries
        self.assertEqual(precision, 0.6)
        self.assertEqual(recall, 0.6)

    def test_PR_multiple(self):
        """range_PR_multiple_thresholds should agree with filtering the
        results and calling range_PR for each threshold."""
        metric = faiss.METRIC_L2
        ds = datasets.SyntheticDataset(32, 1000, 1000, 10)
        queries = ds.get_queries()
        database = ds.get_database()

        # good for ~10k results
        threshold = 15

        index = faiss.IndexFlat(32, metric)
        index.add(database)
        ref_lims, ref_D, ref_I = index.range_search(queries, threshold)

        # now make a slightly suboptimal index
        index2 = faiss.index_factory(32, "PCA16,Flat")
        index2.train(ds.get_train())
        index2.add(database)

        # PCA reduces distances so will have more results
        new_lims, new_D, new_I = index2.range_search(queries, threshold)

        all_thr = np.array([5.0, 10.0, 12.0, 15.0])
        for mode in "overall", "average":
            # reference: one filtered evaluation per threshold
            ref_precisions = np.zeros_like(all_thr)
            ref_recalls = np.zeros_like(all_thr)

            for i, thr in enumerate(all_thr):
                lims2, _, I2 = evaluation.filter_range_results(
                    new_lims, new_D, new_I, thr)
                ref_precisions[i], ref_recalls[i] = evaluation.range_PR(
                    ref_lims, ref_I, lims2, I2, mode=mode)

            # all thresholds in a single call
            precisions, recalls = evaluation.range_PR_multiple_thresholds(
                ref_lims, ref_I,
                new_lims, new_D, new_I, all_thr,
                mode=mode
            )

            np.testing.assert_array_almost_equal(ref_precisions, precisions)
            np.testing.assert_array_almost_equal(ref_recalls, recalls)

0
tests/test_extra_distances.py 100755 → 100644
View File

0
tests/test_io.py 100755 → 100644
View File

0
tests/test_oom_exception.py 100755 → 100644
View File

0
tests/test_standalone_codec.py 100755 → 100644
View File