Add range search accuracy evaluation
Summary: Added a few functions in contrib to: (1) run range searches by batches on the query or the database side; (2) emulate range search on GPU: search on GPU with k=1024 and, if the farthest neighbor is still within range, re-perform the search on CPU; (3) serve as reference implementations for precision-recall on range search datasets; (4) provide optimized code to plot precision-recall curves (i.e., sweep over thresholds). The new functions are mainly in a new `evaluation.py`. Reviewed By: wickedfoo. Differential Revision: D25627619. fbshipit-source-id: 58f90654c32c925557d7bbf8083efbb710712e03 (pull/1591/head)
parent
32df3f3198
commit
3dd7ba8ff9
|
@ -54,3 +54,7 @@ Definition of how to access data for some standard datasets.
|
|||
### factory_tools.py
|
||||
|
||||
Functions related to factory strings.
|
||||
|
||||
### evaluation.py
|
||||
|
||||
A few non-trivial evaluation functions for search results
|
||||
|
|
|
@ -0,0 +1,268 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
from multiprocessing.dummy import Pool as ThreadPool
|
||||
|
||||
###############################################################
|
||||
# Simple functions to evaluate knn results
|
||||
|
||||
def knn_intersection_measure(I1, I2):
    """Return the intersection measure of two knn result tables.

    I1 and I2 are 2D id tables of identical shape (one row per query).
    For every query the number of ids present in both rows is counted
    (np.intersect1d, so duplicated ids count once); the total is divided
    by the number of entries in I1.
    """
    n_queries, k = I1.shape
    assert I2.shape == (n_queries, k)
    shared = 0
    for row_a, row_b in zip(I1, I2):
        shared += np.intersect1d(row_a, row_b).size
    return shared / I1.size
|
||||
|
||||
###############################################################
|
||||
# Range search results can be compared with Precision-Recall
|
||||
|
||||
def filter_range_results(lims, D, I, thresh):
    """Keep only the range search results whose distance is < thresh.

    lims delimits the per-query slices of the flat arrays D (distances)
    and I (ids): query q owns entries lims[q]:lims[q + 1].

    Returns (new_lims, D_kept, I_kept) where new_lims has the same shape
    and dtype as lims and delimits the surviving entries.
    """
    keep = D < thresh
    n_queries = lims.size - 1
    filtered_lims = np.zeros_like(lims)
    for q in range(n_queries):
        begin, end = lims[q], lims[q + 1]
        # running total of the entries that survive for queries 0..q
        filtered_lims[q + 1] = (
            filtered_lims[q] + np.count_nonzero(keep[begin:end])
        )
    return filtered_lims, D[keep], I[keep]
|
||||
|
||||
|
||||
def range_PR(lims_ref, Iref, lims_new, Inew, mode="overall"):
    """Compute the precision and recall of range search results.

    The function does not take the distances into account: only the sets
    of ids returned for each query are compared.

    lims_ref, Iref: reference (ground-truth) results; query q owns
        Iref[lims_ref[q]:lims_ref[q + 1]]
    lims_new, Inew: results to evaluate, same layout
    mode: forwarded to counts_to_PR ("overall" or "average")

    Returns the (precision, recall) pair computed by counts_to_PR.
    """
    nq = lims_ref.size - 1
    assert lims_new.size - 1 == nq

    ninter = np.zeros(nq, dtype="int64")

    def count_common_ids(q):
        # ground truth results for this query
        gt_ids = Iref[lims_ref[q]:lims_ref[q + 1]]
        # results for this query
        new_ids = Inew[lims_new[q]:lims_new[q + 1]]
        # there are no set functions in numpy so let's do this
        ninter[q] = np.intersect1d(gt_ids, new_ids).size

    # run in a thread pool, which helps in spite of the GIL; the context
    # manager makes sure the worker threads are released when done
    with ThreadPool(20) as pool:
        pool.map(count_common_ids, range(nq))

    return counts_to_PR(
        lims_ref[1:] - lims_ref[:-1],
        lims_new[1:] - lims_new[:-1],
        ninter,
        mode=mode
    )
|
||||
|
||||
|
||||
def counts_to_PR(ngt, nres, ninter, mode="overall"):
    """Compute a precision-recall pair for a set of queries.

    ngt = nb of GT results per query
    nres = nb of found results per query
    ninter = nb of correct results per query (smaller than nres of course)
    mode = "overall": sum the counts over all queries, then compute one
        precision and one recall;
        "average": compute precision and recall per query and average them

    Conventions for empty sets: precision is 1 when nothing was returned;
    recall is 1 when there is no ground truth and nothing was returned,
    and 0 when there is no ground truth but something was returned.

    The input arrays are never modified.

    Returns (precision, recall). Raises AssertionError on an unknown mode.
    """
    if mode == "overall":
        ngt, nres, ninter = ngt.sum(), nres.sum(), ninter.sum()

        precision = ninter / nres if nres > 0 else 1.0

        if ngt > 0:
            recall = ninter / ngt
        elif nres == 0:
            recall = 1.0
        else:
            recall = 0.0

        return precision, recall

    elif mode == "average":
        # average precision and recall over queries

        no_gt = ngt == 0
        # divide by 1 instead of 0, the affected entries are overwritten
        # just below (np.where builds a new array, the input is untouched)
        recalls = ninter / np.where(no_gt, 1, ngt)
        recalls[no_gt] = (nres[no_gt] == 0).astype(float)

        # avoid division by 0
        no_res = nres == 0
        assert np.all(ninter[no_res] == 0)
        precisions = ninter / np.where(no_res, 1, nres)
        precisions[no_res] = 1.0

        return precisions.mean(), recalls.mean()

    else:
        raise AssertionError("unknown mode %r" % (mode,))
|
||||
|
||||
def sort_range_res_2(lims, D, I):
    """Sort each query's range search results by increasing distance.

    Query q owns entries lims[q]:lims[q + 1] of the flat arrays D and I.
    Within every such slice both arrays are reordered with the argsort of
    the distances. The inputs are left untouched.

    Returns (I_sorted, D_sorted) -- note the ids come first.
    """
    ids_out = np.empty_like(I)
    dis_out = np.empty_like(D)
    for begin, end in zip(lims[:-1], lims[1:]):
        dis_slice = D[begin:end]
        order = dis_slice.argsort()
        ids_out[begin:end] = I[begin:end][order]
        dis_out[begin:end] = dis_slice[order]
    return ids_out, dis_out
|
||||
|
||||
|
||||
def sort_range_res_1(lims, I):
    """Return a copy of I where each query's slice is sorted by value.

    Query q owns entries lims[q]:lims[q + 1] of the flat array I.
    The input array is not modified.
    """
    sorted_ids = np.empty_like(I)
    for q in range(len(lims) - 1):
        begin, end = lims[q], lims[q + 1]
        sorted_ids[begin:end] = np.sort(I[begin:end])
    return sorted_ids
|
||||
|
||||
|
||||
def range_PR_multiple_thresholds(
        lims_ref, Iref,
        lims_new, Dnew, Inew,
        thresholds,
        mode="overall", do_sort="ref,new"
):
    """Compute precision-recall values for range search results
    for several thresholds on the "new" results.
    This is to plot PR curves.

    lims_ref, Iref: reference results (ids only)
    lims_new, Dnew, Inew: results to evaluate, with their distances
    thresholds: sequence of thresholds; for each one, only the new
        results with distance < threshold are kept
    mode: forwarded to counts_to_PR ("overall" or "average")
    do_sort: string; if it contains "ref" the reference ids are sorted
        per query, if it contains "new" the new results are sorted by
        distance per query. The algorithm relies on both orderings, so
        only disable a sort when the input is already ordered that way.

    Returns (precisions, recalls), two float arrays with one entry per
    threshold.
    """
    # ref should be sorted by ids
    if "ref" in do_sort:
        Iref = sort_range_res_1(lims_ref, Iref)

    # new should be sorted by distances
    if "new" in do_sort:
        Inew, Dnew = sort_range_res_2(lims_new, Dnew, Inew)

    nq = lims_ref.size - 1
    assert lims_new.size - 1 == nq

    nt = len(thresholds)
    # counts[q, t] = (nb of GT results, nb of results kept at threshold t,
    #                 nb of correct results among the kept ones)
    counts = np.zeros((nq, nt, 3), dtype="int64")

    def compute_PR_for(q):
        gt_ids = Iref[lims_ref[q]:lims_ref[q + 1]]
        l0, l1 = lims_new[q], lims_new[q + 1]
        res_ids, res_dis = Inew[l0:l1], Dnew[l0:l1]

        counts[q, :, 0] = len(gt_ids)

        if res_dis.size == 0:
            # the rest remains at 0
            return

        # which offsets we are interested in: with distances sorted in
        # increasing order, this is the nb of results below each threshold
        nres = np.searchsorted(res_dis, thresholds)
        counts[q, :, 1] = nres

        if gt_ids.size == 0:
            return

        # find number of TPs at each stage in the result list
        ii = np.searchsorted(gt_ids, res_ids)
        # ids larger than every GT id would index out of bounds; point
        # them at the last GT id, which cannot match them
        ii[ii == len(gt_ids)] = -1
        n_ok = np.cumsum(gt_ids[ii] == res_ids)

        # focus on threshold points
        n_ok = np.hstack(([0], n_ok))
        counts[q, :, 2] = n_ok[nres]

    # the context manager releases the worker threads once the map is done
    with ThreadPool(20) as pool:
        pool.map(compute_PR_for, range(nq))

    precisions = np.zeros(nt)
    recalls = np.zeros(nt)
    for t in range(nt):
        p, r = counts_to_PR(
            counts[:, t, 0], counts[:, t, 1], counts[:, t, 2],
            mode=mode
        )
        precisions[t] = p
        recalls[t] = r

    return precisions, recalls
|
||||
|
||||
|
||||
|
||||
|
||||
###############################################################
|
||||
# Functions that compare search results with a reference result.
|
||||
# They are intended for use in tests
|
||||
|
||||
def test_ref_knn_with_draws(Dref, Iref, Dnew, Inew):
    """Test that knn search results are identical, raise if not.

    Distances must match to 5 decimals. Ids may be permuted among
    results that share exactly the same reference distance (draws).
    The group sharing the distance of the last column is not checked,
    because some of its members may have been cut off at rank k.
    """
    np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
    # here we have to be careful because of draws
    testcase = unittest.TestCase()   # because it makes nice error messages
    for i in range(len(Iref)):
        if np.all(Iref[i] == Inew[i]):    # easy case
            continue
        # we can deduce nothing about the latest line
        skip_dis = Dref[i, -1]
        # only the distances that occur in this row can select anything,
        # so there is no need to scan the distances of the whole table
        for dis in np.unique(Dref[i]):
            if dis == skip_dis:
                continue
            mask = Dref[i, :] == dis
            testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask]))
|
||||
|
||||
|
||||
def test_ref_range_results(lims_ref, Dref, Iref,
                           lims_new, Dnew, Inew):
    """Compare range search results with a reference result and raise
    if they differ.

    The lims arrays must be equal. For each query the ids must be the
    same set (the order may differ) and the distances attached to them
    must match to 5 decimals.
    """

    def ordered_by_id(ids, dis):
        perm = ids.argsort()
        return ids[perm], dis[perm]

    np.testing.assert_array_equal(lims_ref, lims_new)
    for q in range(len(lims_ref) - 1):
        begin, end = lims_ref[q], lims_ref[q + 1]
        ids_ref, dis_ref = Iref[begin:end], Dref[begin:end]
        ids_new, dis_new = Inew[begin:end], Dnew[begin:end]
        if not np.all(ids_ref == ids_new):
            # same results possibly in a different order: sort both
            ids_ref, dis_ref = ordered_by_id(ids_ref, dis_ref)
            ids_new, dis_new = ordered_by_id(ids_new, dis_new)
        np.testing.assert_array_equal(ids_ref, ids_new)
        np.testing.assert_array_almost_equal(dis_ref, dis_new, decimal=5)
|
|
@ -11,7 +11,7 @@ import logging
|
|||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
def knn_ground_truth(xq, db_iterator, k):
|
||||
def knn_ground_truth(xq, db_iterator, k, metric_type=faiss.METRIC_L2):
|
||||
"""Computes the exact KNN search results for a dataset that possibly
|
||||
does not fit in RAM but for which we have an iterator that
|
||||
returns it block by block.
|
||||
|
@ -21,12 +21,12 @@ def knn_ground_truth(xq, db_iterator, k):
|
|||
nq, d = xq.shape
|
||||
rh = faiss.ResultHeap(nq, k)
|
||||
|
||||
index = faiss.IndexFlatL2(d)
|
||||
index = faiss.IndexFlat(d, metric_type)
|
||||
if faiss.get_num_gpus():
|
||||
LOG.info('running on %d GPUs' % faiss.get_num_gpus())
|
||||
index = faiss.index_cpu_to_all_gpus(index)
|
||||
|
||||
# compute ground-truth by blocks of bs, and add to heaps
|
||||
# compute ground-truth by blocks, and add to heaps
|
||||
i0 = 0
|
||||
for xbi in db_iterator:
|
||||
ni = xbi.shape[0]
|
||||
|
@ -44,4 +44,184 @@ def knn_ground_truth(xq, db_iterator, k):
|
|||
return rh.D, rh.I
|
||||
|
||||
# knn function used to be here
|
||||
knn = faiss.knn
|
||||
knn = faiss.knn
|
||||
|
||||
|
||||
|
||||
|
||||
def range_search_gpu(xq, r2, index_gpu, xb):
    """Emulate a range search on an index that only offers knn search.

    GPU does not support range search, so a knn search with
    k = min(ntotal, 1024) is run on index_gpu. Queries whose k-th result
    is still within the radius may have more results than k, so they are
    searched again with an exact CPU flat index built from xb.

    xq: query vectors, one per row
    r2: radius; results are kept when distance < r2 for METRIC_L2 and
        when distance > r2 for any other metric
    index_gpu: index exposing ntotal, metric_type and search()
    xb: database vectors used to build the CPU fallback index

    Returns (lims, D, I) in the flat range search layout.
    """
    nq, d = xq.shape
    LOG.debug("GPU search %d queries" % nq)
    k = min(index_gpu.ntotal, 1024)
    D, I = index_gpu.search(xq, k)

    is_l2 = index_gpu.metric_type == faiss.METRIC_L2

    def within_radius(dis):
        return dis < r2 if is_l2 else dis > r2

    # queries for which the knn result list may be truncated
    needs_cpu = within_radius(D[:, k - 1])
    n_cpu = needs_cpu.sum()
    if n_cpu > 0:
        LOG.debug("CPU search remain %d" % n_cpu)
        index_cpu = faiss.IndexFlat(d, index_gpu.metric_type)
        index_cpu.add(xb)
        lim_remain, D_remain, I_remain = index_cpu.range_search(
            xq[needs_cpu], r2)

    LOG.debug("combine")
    D_res, I_res = [], []
    cpu_rank = 0   # position of the current query in the CPU result set
    for q in range(nq):
        if needs_cpu[q]:
            l0, l1 = lim_remain[cpu_rank], lim_remain[cpu_rank + 1]
            D_res.append(D_remain[l0:l1])
            I_res.append(I_remain[l0:l1])
            cpu_rank += 1
        else:
            # keep the leading knn results that fall within the radius
            nv = within_radius(D[q, :]).sum()
            D_res.append(D[q, :nv])
            I_res.append(I[q, :nv])

    lims = np.cumsum([0] + [len(di) for di in D_res])
    return lims, np.hstack(D_res), np.hstack(I_res)
|
||||
|
||||
|
||||
def range_ground_truth(xq, db_iterator, threshold, metric_type=faiss.METRIC_L2,
                       shard=False, ngpu=-1):
    """Computes the range-search search results for a dataset that possibly
    does not fit in RAM but for which we have an iterator that
    returns it block by block.

    xq: query vectors (converted to a contiguous float32 array)
    db_iterator: yields blocks of database vectors, one vector per row
    threshold: range search radius
    metric_type: faiss metric of the flat index used for the search
    shard: value set on GpuMultipleClonerOptions.shard when GPUs are used
    ngpu: number of GPUs to use; -1 means faiss.get_num_gpus(), 0 forces
        the CPU path

    Returns (lims, D, I): lims is a uint64 array of size nq + 1, D and I
    hold the distances and database ids (offset by the position of their
    block in the iteration) of query q in lims[q]:lims[q + 1].
    """
    nq, d = xq.shape
    t0 = time.time()
    xq = np.ascontiguousarray(xq, dtype='float32')

    index = faiss.IndexFlat(d, metric_type)
    if ngpu == -1:
        ngpu = faiss.get_num_gpus()
    if ngpu:
        # report the number of GPUs actually requested, not all present
        LOG.info('running on %d GPUs' % ngpu)
        co = faiss.GpuMultipleClonerOptions()
        co.shard = shard
        index_gpu = faiss.index_cpu_to_all_gpus(index, co=co, ngpu=ngpu)

    # compute ground-truth by blocks
    i0 = 0
    D = [[] for _i in range(nq)]
    I = [[] for _i in range(nq)]
    for xbi in db_iterator:
        ni = xbi.shape[0]
        if ngpu > 0:
            index_gpu.add(xbi)
            lims_i, Di, Ii = range_search_gpu(xq, threshold, index_gpu, xbi)
            index_gpu.reset()
        else:
            index.add(xbi)
            lims_i, Di, Ii = index.range_search(xq, threshold)
            index.reset()
        # ids are relative to the block: shift them to database ids
        Ii += i0
        for j in range(nq):
            l0, l1 = lims_i[j], lims_i[j + 1]
            if l1 > l0:
                D[j].append(Di[l0:l1])
                I[j].append(Ii[l0:l1])
        i0 += ni
        LOG.info("%d db elements, %.3f s" % (i0, time.time() - t0))

    # queries without any result need an explicitly typed empty array
    empty_I = np.zeros(0, dtype='int64')
    empty_D = np.zeros(0, dtype='float32')
    D = [(np.hstack(i) if i else empty_D) for i in D]
    I = [(np.hstack(i) if i else empty_I) for i in I]
    sizes = [len(i) for i in I]
    assert len(sizes) == nq
    lims = np.zeros(nq + 1, dtype="uint64")
    lims[1:] = np.cumsum(sizes)
    return lims, np.hstack(D), np.hstack(I)
|
||||
|
||||
|
||||
def threshold_radius(nres, dis, ids, thresh):
    """Keep only the results whose distance is < thresh.

    nres: number of results per query; the results of consecutive
        queries are stored back to back in the flat arrays dis and ids
    dis, ids: flat arrays of distances and ids
    thresh: results with dis >= thresh are dropped

    Returns (new_nres, dis_kept, ids_kept); new_nres has the same shape
    and dtype as nres.
    """
    keep = dis < thresh
    kept_counts = np.zeros_like(nres)
    begin = 0
    for q, count in enumerate(nres):
        # plain Python int: avoid issues with int64 + uint64
        end = begin + int(count)
        kept_counts[q] = keep[begin:end].sum()
        begin = end
    return kept_counts, dis[keep], ids[keep]
|
||||
|
||||
|
||||
def apply_maxres(res_batches, target_nres):
    """Find the radius that reduces the number of results to target_nres
    and apply it in-place to the result batches.

    res_batches: list of (nres, dis, ids) tuples; every entry is replaced
        by its version filtered with threshold_radius
    target_nres: rank of the distance, among all distances of all
        batches, that becomes the new radius

    Returns (radius, totres): the new radius as a Python float (float32
    distances) or int (any other dtype), and the number of results left.
    """
    alldis = np.hstack([dis for _, dis, _ in res_batches])
    # a partial sort is enough to read the distance at the target rank
    radius = np.partition(alldis, target_nres)[target_nres]

    to_python = float if alldis.dtype == 'float32' else int
    radius = to_python(radius)
    print('   setting radius to %s' % radius)

    res_batches[:] = [
        threshold_radius(nres, dis, ids, radius)
        for nres, dis, ids in res_batches
    ]
    totres = sum(len(dis) for _, dis, _ in res_batches)
    print('   updated previous results, new nb results %d' % totres)
    return radius, totres
|
||||
|
||||
|
||||
def range_search_max_results(index, query_iterator, radius,
                             max_results=None, min_results=None):
    """Perform a range search with many queries (given by an iterator)
    and adjust the threshold on-the-fly so that the total results
    table does not grow larger than max_results.

    index: index exposing range_search()
    query_iterator: yields blocks of query vectors
    radius: initial search radius
    max_results: when the accumulated number of results exceeds this,
        the radius is reduced with apply_maxres; None disables the limit
    min_results: number of results targeted when reducing the radius;
        defaults to 80% of max_results

    Returns (radius, lims, dis, ids): the final radius and the results of
    all queries in the flat range search layout (lims is uint64).
    """
    if max_results is not None and min_results is None:
        min_results = int(0.8 * max_results)

    t_start = time.time()
    t_search = t_post_process = 0
    qtot = totres = 0
    res_batches = []

    for xqi in query_iterator:
        t0 = time.time()
        lims_i, Di, Ii = index.range_search(xqi, radius)
        nres_i = lims_i[1:] - lims_i[:-1]
        qtot += len(xqi)
        t1 = time.time()

        if xqi.dtype != np.float32:
            # for binary indexes
            # weird Faiss quirk that returns floats for Hamming distances
            Di = Di.astype('int16')

        totres += len(Di)
        res_batches.append((nres_i, Di, Ii))

        too_many = max_results is not None and totres > max_results
        if too_many:
            LOG.info('too many results %d > %d, scaling back radius' %
                     (totres, max_results))
            # shrinks the stored batches and the radius used from now on
            radius, totres = apply_maxres(res_batches, min_results)

        t2 = time.time()
        t_search += t1 - t0
        t_post_process += t2 - t1
        LOG.debug('   [%.3f s] %d queries done, %d results' % (
            time.time() - t_start, qtot, totres))

    LOG.info('   search done in %.3f s + %.3f s, total %d results, end threshold %g' % (
        t_search, t_post_process, totres, radius))

    # stitch the batches together
    nres = np.hstack([batch[0] for batch in res_batches])
    dis = np.hstack([batch[1] for batch in res_batches])
    ids = np.hstack([batch[2] for batch in res_batches])

    lims = np.zeros(len(nres) + 1, dtype='uint64')
    lims[1:] = np.cumsum(nres)

    return radius, lims, dis, ids
|
||||
|
|
|
@ -8,7 +8,9 @@ import unittest
|
|||
import numpy as np
|
||||
|
||||
from faiss.contrib import datasets
|
||||
from faiss.contrib.exhaustive_search import knn_ground_truth
|
||||
from faiss.contrib.exhaustive_search import knn_ground_truth, range_ground_truth
|
||||
from faiss.contrib import evaluation
|
||||
|
||||
|
||||
from common import get_dataset_2
|
||||
|
||||
|
@ -33,3 +35,29 @@ class TestComputeGT(unittest.TestCase):
|
|||
|
||||
np.testing.assert_array_equal(Iref, Inew)
|
||||
np.testing.assert_almost_equal(Dref, Dnew, decimal=4)
|
||||
|
||||
def do_test_range(self, metric):
    """Check range_ground_truth (default device selection) against a
    plain flat index range search, for the given metric."""
    dataset = datasets.SyntheticDataset(32, 0, 1000, 10)
    xb = dataset.get_database()
    xq = dataset.get_queries()

    # pick a radius from the knn distances: the mean of the 10th one
    knn_D, _ = faiss.knn(xq, xb, 10, distance_type=metric)
    threshold = float(knn_D[:, -1].mean())

    ref_index = faiss.IndexFlat(32, metric)
    ref_index.add(xb)
    reference = ref_index.range_search(xq, threshold)

    computed = range_ground_truth(
        xq, dataset.database_iterator(bs=100), threshold,
        metric_type=metric)

    ref_lims, ref_D, ref_I = reference
    new_lims, new_D, new_I = computed
    evaluation.test_ref_range_results(
        ref_lims, ref_D, ref_I,
        new_lims, new_D, new_I
    )
|
||||
|
||||
def test_range_L2(self):
    """Run the range ground-truth check with the L2 metric."""
    metric = faiss.METRIC_L2
    self.do_test_range(metric)
|
||||
|
||||
def test_range_IP(self):
    """Run the range ground-truth check with the inner product metric."""
    metric = faiss.METRIC_INNER_PRODUCT
    self.do_test_range(metric)
|
||||
|
|
|
@ -10,14 +10,18 @@ import platform
|
|||
|
||||
from faiss.contrib import datasets
|
||||
from faiss.contrib import inspect_tools
|
||||
from faiss.contrib import evaluation
|
||||
|
||||
from common import get_dataset_2
|
||||
try:
|
||||
from faiss.contrib.exhaustive_search import knn_ground_truth, knn
|
||||
from faiss.contrib.exhaustive_search import knn_ground_truth, knn, range_ground_truth
|
||||
from faiss.contrib.exhaustive_search import range_search_max_results
|
||||
|
||||
except:
|
||||
pass # Submodule import broken in python 2.
|
||||
|
||||
|
||||
|
||||
@unittest.skipIf(platform.python_version_tuple()[0] < '3', \
|
||||
'Submodule import broken in python 2.')
|
||||
class TestComputeGT(unittest.TestCase):
|
||||
|
@ -80,11 +84,9 @@ class TestDatasets(unittest.TestCase):
|
|||
class TestExhaustiveSearch(unittest.TestCase):
|
||||
|
||||
def test_knn_cpu(self):
|
||||
|
||||
xb = np.random.rand(200, 32).astype('float32')
|
||||
xq = np.random.rand(100, 32).astype('float32')
|
||||
|
||||
|
||||
index = faiss.IndexFlatL2(32)
|
||||
index.add(xb)
|
||||
Dref, Iref = index.search(xq, 10)
|
||||
|
@ -104,6 +106,72 @@ class TestExhaustiveSearch(unittest.TestCase):
|
|||
assert np.all(Inew == Iref)
|
||||
assert np.allclose(Dref, Dnew)
|
||||
|
||||
def do_test_range(self, metric):
    """Check range_ground_truth on the CPU path (ngpu=0) against a
    plain flat index range search, for the given metric."""
    dataset = datasets.SyntheticDataset(32, 0, 1000, 10)
    xb = dataset.get_database()
    xq = dataset.get_queries()

    # pick a radius from the knn distances: the mean of the 10th one
    knn_D, _ = faiss.knn(xq, xb, 10, distance_type=metric)
    threshold = float(knn_D[:, -1].mean())

    ref_index = faiss.IndexFlat(32, metric)
    ref_index.add(xb)
    reference = ref_index.range_search(xq, threshold)

    computed = range_ground_truth(
        xq, dataset.database_iterator(bs=100), threshold, ngpu=0,
        metric_type=metric)

    ref_lims, ref_D, ref_I = reference
    new_lims, new_D, new_I = computed
    evaluation.test_ref_range_results(
        ref_lims, ref_D, ref_I,
        new_lims, new_D, new_I
    )
|
||||
|
||||
def test_range_L2(self):
    """Run the CPU range ground-truth check with the L2 metric."""
    metric = faiss.METRIC_L2
    self.do_test_range(metric)
|
||||
|
||||
def test_range_IP(self):
    """Run the CPU range ground-truth check with the inner product
    metric."""
    metric = faiss.METRIC_INNER_PRODUCT
    self.do_test_range(metric)
|
||||
|
||||
def test_query_iterator(self, metric=faiss.METRIC_L2):
    """Check range_search_max_results, first without a result limit
    (must reproduce a plain range search) and then with max_results set
    to half of the reference result count."""
    dataset = datasets.SyntheticDataset(32, 0, 1000, 1000)
    xb = dataset.get_database()
    xq = dataset.get_queries()

    knn_D, _ = faiss.knn(xq, xb, 10, distance_type=metric)
    threshold = float(knn_D[:, -1].mean())
    print(threshold)

    index = faiss.IndexFlat(32, metric)
    index.add(xb)
    ref_lims, ref_D, ref_I = index.range_search(xq, threshold)

    def matrix_iterator(x, bs):
        # yield consecutive blocks of at most bs rows
        start = 0
        while start < x.shape[0]:
            yield x[start:start + bs]
            start += bs

    # check repro OK
    _, new_lims, new_D, new_I = range_search_max_results(
        index, matrix_iterator(xq, 100), threshold)

    evaluation.test_ref_range_results(
        ref_lims, ref_D, ref_I,
        new_lims, new_D, new_I
    )

    # now cap the number of results: the radius has to shrink
    max_res = ref_lims[-1] // 2

    new_threshold, new_lims, new_D, new_I = range_search_max_results(
        index, matrix_iterator(xq, 100), threshold, max_results=max_res)

    self.assertLessEqual(new_lims[-1], max_res)

    # a plain search with the returned radius must give the same results
    ref_lims, ref_D, ref_I = index.range_search(xq, new_threshold)

    evaluation.test_ref_range_results(
        ref_lims, ref_D, ref_I,
        new_lims, new_D, new_I
    )
|
||||
|
||||
|
||||
|
||||
class TestInspect(unittest.TestCase):
|
||||
|
||||
|
@ -123,3 +191,77 @@ class TestInspect(unittest.TestCase):
|
|||
# verify
|
||||
ynew = x @ A.T + b
|
||||
np.testing.assert_array_almost_equal(yref, ynew)
|
||||
|
||||
|
||||
class TestRangeEval(unittest.TestCase):
    """Tests of the range search precision-recall helpers in
    faiss.contrib.evaluation."""

    def test_precision_recall(self):
        """range_PR on a tiny hand-written example."""
        ref_lists = [
            [1, 2, 3],
            [5, 6],
            [],
            []
        ]
        new_lists = [
            [1, 2],
            [6, 7],
            [1],
            []
        ]

        # flatten both tables into the (lims, ids) range search layout
        lims_ref = np.cumsum([0] + [len(x) for x in ref_lists])
        Iref = np.hstack(ref_lists)
        lims_new = np.cumsum([0] + [len(x) for x in new_lists])
        Inew = np.hstack(new_lists)

        precision, recall = evaluation.range_PR(lims_ref, Iref, lims_new, Inew)
        print(precision, recall)

        self.assertEqual(precision, 0.6)
        self.assertEqual(recall, 0.6)

    def test_PR_multiple(self):
        """range_PR_multiple_thresholds must agree with filtering the
        results at each threshold and calling range_PR."""
        metric = faiss.METRIC_L2
        ds = datasets.SyntheticDataset(32, 1000, 1000, 10)
        xq = ds.get_queries()
        xb = ds.get_database()

        # good for ~10k results
        threshold = 15

        index = faiss.IndexFlat(32, metric)
        index.add(xb)
        ref_lims, ref_D, ref_I = index.range_search(xq, threshold)

        # now make a slightly suboptimal index
        index2 = faiss.index_factory(32, "PCA16,Flat")
        index2.train(ds.get_train())
        index2.add(xb)

        # PCA reduces distances so will have more results
        new_lims, new_D, new_I = index2.range_search(xq, threshold)

        all_thr = np.array([5.0, 10.0, 12.0, 15.0])
        for mode in "overall", "average":
            ref_precisions = np.zeros_like(all_thr)
            ref_recalls = np.zeros_like(all_thr)

            # reference: one filter + one PR computation per threshold
            for i, thr in enumerate(all_thr):
                lims2, _, I2 = evaluation.filter_range_results(
                    new_lims, new_D, new_I, thr)
                ref_precisions[i], ref_recalls[i] = evaluation.range_PR(
                    ref_lims, ref_I, lims2, I2, mode=mode)

            precisions, recalls = evaluation.range_PR_multiple_thresholds(
                ref_lims, ref_I,
                new_lims, new_D, new_I, all_thr,
                mode=mode
            )

            np.testing.assert_array_almost_equal(ref_precisions, precisions)
            np.testing.assert_array_almost_equal(ref_recalls, recalls)
|
||||
|
|
Loading…
Reference in New Issue