Break distance ties in `heap_replace_top()` by ID (#2245)
Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2245 This changeset makes the `heap_replace_top()` function of the FAISS heap implementation break distance ties by the element's ID, according to the heap's min/max property. Reviewed By: mdouze Differential Revision: D34669542 fbshipit-source-id: 0db24fd12442eedeee917fbb3e811ba4a070ce0fpull/2108/merge
parent
878275f915
commit
d50211a38f
|
@ -47,21 +47,25 @@ inline void heap_pop(size_t k, typename C::T* bh_val, typename C::TI* bh_ids) {
|
|||
bh_val--; /* Use 1-based indexing for easier node->child translation */
|
||||
bh_ids--;
|
||||
typename C::T val = bh_val[k];
|
||||
typename C::TI id = bh_ids[k];
|
||||
size_t i = 1, i1, i2;
|
||||
while (1) {
|
||||
i1 = i << 1;
|
||||
i2 = i1 + 1;
|
||||
if (i1 > k)
|
||||
break;
|
||||
if (i2 == k + 1 || C::cmp(bh_val[i1], bh_val[i2])) {
|
||||
if (C::cmp(val, bh_val[i1]))
|
||||
if ((i2 == k + 1) ||
|
||||
C::cmp2(bh_val[i1], bh_val[i2], bh_ids[i1], bh_ids[i2])) {
|
||||
if (C::cmp2(val, bh_val[i1], id, bh_ids[i1])) {
|
||||
break;
|
||||
}
|
||||
bh_val[i] = bh_val[i1];
|
||||
bh_ids[i] = bh_ids[i1];
|
||||
i = i1;
|
||||
} else {
|
||||
if (C::cmp(val, bh_val[i2]))
|
||||
if (C::cmp2(val, bh_val[i2], id, bh_ids[i2])) {
|
||||
break;
|
||||
}
|
||||
bh_val[i] = bh_val[i2];
|
||||
bh_ids[i] = bh_ids[i2];
|
||||
i = i2;
|
||||
|
@ -80,24 +84,28 @@ inline void heap_push(
|
|||
typename C::T* bh_val,
|
||||
typename C::TI* bh_ids,
|
||||
typename C::T val,
|
||||
typename C::TI ids) {
|
||||
typename C::TI id) {
|
||||
bh_val--; /* Use 1-based indexing for easier node->child translation */
|
||||
bh_ids--;
|
||||
size_t i = k, i_father;
|
||||
while (i > 1) {
|
||||
i_father = i >> 1;
|
||||
if (!C::cmp(val, bh_val[i_father])) /* the heap structure is ok */
|
||||
if (!C::cmp2(val, bh_val[i_father], id, bh_ids[i_father])) {
|
||||
/* the heap structure is ok */
|
||||
break;
|
||||
}
|
||||
bh_val[i] = bh_val[i_father];
|
||||
bh_ids[i] = bh_ids[i_father];
|
||||
i = i_father;
|
||||
}
|
||||
bh_val[i] = val;
|
||||
bh_ids[i] = ids;
|
||||
bh_ids[i] = id;
|
||||
}
|
||||
|
||||
/** Replace the top element from the heap defined by bh_val[0..k-1] and
|
||||
* bh_ids[0..k-1].
|
||||
/**
|
||||
* Replaces the top element from the heap defined by bh_val[0..k-1] and
|
||||
* bh_ids[0..k-1], and for identical bh_val[] values also sorts by bh_ids[]
|
||||
* values.
|
||||
*/
|
||||
template <class C>
|
||||
inline void heap_replace_top(
|
||||
|
@ -105,31 +113,39 @@ inline void heap_replace_top(
|
|||
typename C::T* bh_val,
|
||||
typename C::TI* bh_ids,
|
||||
typename C::T val,
|
||||
typename C::TI ids) {
|
||||
typename C::TI id) {
|
||||
bh_val--; /* Use 1-based indexing for easier node->child translation */
|
||||
bh_ids--;
|
||||
size_t i = 1, i1, i2;
|
||||
while (1) {
|
||||
i1 = i << 1;
|
||||
i2 = i1 + 1;
|
||||
if (i1 > k)
|
||||
if (i1 > k) {
|
||||
break;
|
||||
if (i2 == k + 1 || C::cmp(bh_val[i1], bh_val[i2])) {
|
||||
if (C::cmp(val, bh_val[i1]))
|
||||
}
|
||||
|
||||
// Note that C::cmp2() is a bool function answering
|
||||
// `(a1 > b1) || ((a1 == b1) && (a2 > b2))` for max
|
||||
// heap and same with the `<` sign for min heap.
|
||||
if ((i2 == k + 1) ||
|
||||
C::cmp2(bh_val[i1], bh_val[i2], bh_ids[i1], bh_ids[i2])) {
|
||||
if (C::cmp2(val, bh_val[i1], id, bh_ids[i1])) {
|
||||
break;
|
||||
}
|
||||
bh_val[i] = bh_val[i1];
|
||||
bh_ids[i] = bh_ids[i1];
|
||||
i = i1;
|
||||
} else {
|
||||
if (C::cmp(val, bh_val[i2]))
|
||||
if (C::cmp2(val, bh_val[i2], id, bh_ids[i2])) {
|
||||
break;
|
||||
}
|
||||
bh_val[i] = bh_val[i2];
|
||||
bh_ids[i] = bh_ids[i2];
|
||||
i = i2;
|
||||
}
|
||||
}
|
||||
bh_val[i] = val;
|
||||
bh_ids[i] = ids;
|
||||
bh_ids[i] = id;
|
||||
}
|
||||
|
||||
/* Partial instanciation for heaps with TI = int64_t */
|
||||
|
@ -294,7 +310,7 @@ inline void maxheap_addn(
|
|||
* Heap finalization (reorder elements)
|
||||
*******************************************************************/
|
||||
|
||||
/* This function maps a binary heap into an sorted structure.
|
||||
/* This function maps a binary heap into a sorted structure.
|
||||
It returns the number */
|
||||
template <typename C>
|
||||
inline size_t heap_reorder(
|
||||
|
|
|
@ -46,6 +46,11 @@ struct CMin {
|
|||
inline static bool cmp(T a, T b) {
|
||||
return a < b;
|
||||
}
|
||||
// Similar to cmp(), but also breaks ties
|
||||
// by comparing the second pair of arguments.
|
||||
inline static bool cmp2(T a1, T b1, TI a2, TI b2) {
|
||||
return (a1 < b1) || ((a1 == b1) && (a2 < b2));
|
||||
}
|
||||
inline static T neutral() {
|
||||
return std::numeric_limits<T>::lowest();
|
||||
}
|
||||
|
@ -64,6 +69,11 @@ struct CMax {
|
|||
inline static bool cmp(T a, T b) {
|
||||
return a > b;
|
||||
}
|
||||
// Similar to cmp(), but also breaks ties
|
||||
// by comparing the second pair of arguments.
|
||||
inline static bool cmp2(T a1, T b1, TI a2, TI b2) {
|
||||
return (a1 > b1) || ((a1 == b1) && (a2 > b2));
|
||||
}
|
||||
inline static T neutral() {
|
||||
return std::numeric_limits<T>::max();
|
||||
}
|
||||
|
|
|
@ -112,7 +112,7 @@ class TestRounding(unittest.TestCase):
|
|||
recalls[rank] = (Iref[:, :1] == I4[:, :rank]).sum() / nq
|
||||
|
||||
min_r1 = 0.98 if metric == faiss.METRIC_INNER_PRODUCT else 0.99
|
||||
self.assertGreater(recalls[1], min_r1)
|
||||
self.assertGreaterEqual(recalls[1], min_r1)
|
||||
self.assertGreater(recalls[10], 0.995)
|
||||
# check accuracy of distances
|
||||
# err3 = ((D3 - D2) ** 2).sum()
|
||||
|
@ -423,7 +423,7 @@ class TestAdd(unittest.TestCase):
|
|||
|
||||
recall_at_1 = (Iref[:, 0] == Inew[:, 0]).sum() / nq
|
||||
|
||||
self.assertGreater(recall_at_1, 0.99)
|
||||
self.assertGreaterEqual(recall_at_1, 0.99)
|
||||
|
||||
data = faiss.serialize_index(index2)
|
||||
index3 = faiss.deserialize_index(data)
|
||||
|
|
|
@ -436,15 +436,15 @@ class TestTraining(unittest.TestCase):
|
|||
m3 = three_metrics(Dref, Iref, Dnew, Inew)
|
||||
# print((by_residual, metric, d), ":", m3)
|
||||
ref_m3_tab = {
|
||||
(True, 1, 32) : (0.995, 1.0, 9.91),
|
||||
(True, 0, 32) : (0.99, 1.0, 9.91),
|
||||
(True, 1, 30) : (0.99, 1.0, 9.885),
|
||||
(False, 1, 32) : (0.99, 1.0, 9.875),
|
||||
(False, 0, 32) : (0.99, 1.0, 9.92),
|
||||
(False, 1, 30) : (1.0, 1.0, 9.895)
|
||||
(True, 1, 32): (0.995, 1.0, 9.91),
|
||||
(True, 0, 32): (0.99, 1.0, 9.91),
|
||||
(True, 1, 30): (0.989, 1.0, 9.885),
|
||||
(False, 1, 32): (0.99, 1.0, 9.875),
|
||||
(False, 0, 32): (0.99, 1.0, 9.92),
|
||||
(False, 1, 30): (1.0, 1.0, 9.895)
|
||||
}
|
||||
ref_m3 = ref_m3_tab[(by_residual, metric, d)]
|
||||
self.assertGreater(m3[0], ref_m3[0] * 0.99)
|
||||
self.assertGreaterEqual(m3[0], ref_m3[0] * 0.99)
|
||||
self.assertGreater(m3[1], ref_m3[1] * 0.99)
|
||||
self.assertGreater(m3[2], ref_m3[2] * 0.99)
|
||||
|
||||
|
|
|
@ -4,13 +4,15 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import unittest
|
||||
|
||||
import faiss
|
||||
|
||||
# noqa E741
|
||||
# translation of test_knn.lua
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
import faiss
|
||||
|
||||
from common_faiss_tests import Randu10k, get_dataset_2, Randu10kUnbalanced
|
||||
|
||||
ev = Randu10k()
|
||||
|
@ -25,28 +27,27 @@ kprobe = int(np.sqrt(ncentroids))
|
|||
nbits = d
|
||||
|
||||
# Parameters for indexes involving PQ
|
||||
M = int(d / 8) # for PQ: #subquantizers
|
||||
nbits_per_index = 8 # for PQ
|
||||
M = int(d / 8) # for PQ: #subquantizers
|
||||
nbits_per_index = 8 # for PQ
|
||||
|
||||
|
||||
class IndexAccuracy(unittest.TestCase):
|
||||
|
||||
def test_IndexFlatIP(self):
|
||||
q = faiss.IndexFlatIP(d) # Ask inner product
|
||||
res = ev.launch('FLAT / IP', q)
|
||||
res = ev.launch("FLAT / IP", q)
|
||||
e = ev.evalres(res)
|
||||
assert e[1] == 1.0
|
||||
|
||||
def test_IndexFlatL2(self):
|
||||
q = faiss.IndexFlatL2(d)
|
||||
res = ev.launch('FLAT / L2', q)
|
||||
res = ev.launch("FLAT / L2", q)
|
||||
e = ev.evalres(res)
|
||||
assert e[1] == 1.0
|
||||
|
||||
def test_ivf_kmeans(self):
|
||||
ivfk = faiss.IndexIVFFlat(faiss.IndexFlatL2(d), d, ncentroids)
|
||||
ivfk.nprobe = kprobe
|
||||
res = ev.launch('IndexIVFFlat', ivfk)
|
||||
res = ev.launch("IndexIVFFlat", ivfk)
|
||||
e = ev.evalres(res)
|
||||
# should give 0.260 0.260 0.260
|
||||
assert e[1] > 0.2
|
||||
|
@ -61,7 +62,7 @@ class IndexAccuracy(unittest.TestCase):
|
|||
|
||||
def test_indexLSH(self):
|
||||
q = faiss.IndexLSH(d, nbits)
|
||||
res = ev.launch('FLAT / LSH Cosine', q)
|
||||
res = ev.launch("FLAT / LSH Cosine", q)
|
||||
e = ev.evalres(res)
|
||||
# should give 0.070 0.250 0.580
|
||||
assert e[10] > 0.2
|
||||
|
@ -70,14 +71,14 @@ class IndexAccuracy(unittest.TestCase):
|
|||
# CHECK: the difference between 32 and 48 does not make much sense
|
||||
for nbits2 in 32, 48:
|
||||
q = faiss.IndexLSH(d, nbits2)
|
||||
res = ev.launch('LSH half size', q)
|
||||
res = ev.launch("LSH half size", q)
|
||||
e = ev.evalres(res)
|
||||
# should give 0.003 0.019 0.108
|
||||
assert e[10] > 0.018
|
||||
|
||||
def test_IndexPQ(self):
|
||||
q = faiss.IndexPQ(d, M, nbits_per_index)
|
||||
res = ev.launch('FLAT / PQ L2', q)
|
||||
res = ev.launch("FLAT / PQ L2", q)
|
||||
e = ev.evalres(res)
|
||||
# should give 0.070 0.230 0.260
|
||||
assert e[10] > 0.2
|
||||
|
@ -85,16 +86,16 @@ class IndexAccuracy(unittest.TestCase):
|
|||
# Approximate search module: PQ with inner product distance
|
||||
def test_IndexPQ_ip(self):
|
||||
q = faiss.IndexPQ(d, M, nbits_per_index, faiss.METRIC_INNER_PRODUCT)
|
||||
res = ev.launch('FLAT / PQ IP', q)
|
||||
res = ev.launch("FLAT / PQ IP", q)
|
||||
e = ev.evalres(res)
|
||||
# should give 0.070 0.230 0.260
|
||||
#(same result as regular PQ on normalized distances)
|
||||
# (same result as regular PQ on normalized distances)
|
||||
assert e[10] > 0.2
|
||||
|
||||
def test_IndexIVFPQ(self):
|
||||
ivfpq = faiss.IndexIVFPQ(faiss.IndexFlatL2(d), d, ncentroids, M, 8)
|
||||
ivfpq.nprobe = kprobe
|
||||
res = ev.launch('IVF PQ', ivfpq)
|
||||
res = ev.launch("IVF PQ", ivfpq)
|
||||
e = ev.evalres(res)
|
||||
# should give 0.070 0.230 0.260
|
||||
assert e[10] > 0.2
|
||||
|
@ -104,17 +105,17 @@ class IndexAccuracy(unittest.TestCase):
|
|||
# Approximate search: PQ with full vector refinement
|
||||
def test_IndexPQ_refined(self):
|
||||
q = faiss.IndexPQ(d, M, nbits_per_index)
|
||||
res = ev.launch('PQ non-refined', q)
|
||||
res = ev.launch("PQ non-refined", q)
|
||||
e = ev.evalres(res)
|
||||
q.reset()
|
||||
|
||||
rq = faiss.IndexRefineFlat(q)
|
||||
res = ev.launch('PQ refined', rq)
|
||||
res = ev.launch("PQ refined", rq)
|
||||
e2 = ev.evalres(res)
|
||||
assert e2[10] >= e[10]
|
||||
rq.k_factor = 4
|
||||
|
||||
res = ev.launch('PQ refined*4', rq)
|
||||
res = ev.launch("PQ refined*4", rq)
|
||||
e3 = ev.evalres(res)
|
||||
assert e3[10] >= e2[10]
|
||||
|
||||
|
@ -124,17 +125,16 @@ class IndexAccuracy(unittest.TestCase):
|
|||
# reduce nb iterations to speed up training for the test
|
||||
index.polysemous_training.n_iter = 50000
|
||||
index.polysemous_training.n_redo = 1
|
||||
res = ev.launch('normal PQ', index)
|
||||
res = ev.launch("normal PQ", index)
|
||||
e_baseline = ev.evalres(res)
|
||||
index.search_type = faiss.IndexPQ.ST_polysemous
|
||||
|
||||
index.polysemous_ht = int(M / 16. * 58)
|
||||
index.polysemous_ht = int(M / 16.0 * 58)
|
||||
|
||||
stats = faiss.cvar.indexPQ_stats
|
||||
stats.reset()
|
||||
|
||||
res = ev.launch('Polysemous ht=%d' % index.polysemous_ht,
|
||||
index)
|
||||
res = ev.launch("Polysemous ht=%d" % index.polysemous_ht, index)
|
||||
e_polysemous = ev.evalres(res)
|
||||
print(e_baseline, e_polysemous, index.polysemous_ht)
|
||||
print(stats.n_hamming_pass, stats.ncode)
|
||||
|
@ -149,16 +149,16 @@ class IndexAccuracy(unittest.TestCase):
|
|||
def test_ScalarQuantizer(self):
|
||||
quantizer = faiss.IndexFlatL2(d)
|
||||
ivfpq = faiss.IndexIVFScalarQuantizer(
|
||||
quantizer, d, ncentroids,
|
||||
faiss.ScalarQuantizer.QT_8bit)
|
||||
quantizer, d, ncentroids, faiss.ScalarQuantizer.QT_8bit
|
||||
)
|
||||
ivfpq.nprobe = kprobe
|
||||
res = ev.launch('IVF SQ', ivfpq)
|
||||
res = ev.launch("IVF SQ", ivfpq)
|
||||
e = ev.evalres(res)
|
||||
# should give 0.234 0.236 0.236
|
||||
assert e[10] > 0.235
|
||||
|
||||
def test_polysemous_OOM(self):
|
||||
""" this used to cause OOM when training polysemous with large
|
||||
"""this used to cause OOM when training polysemous with large
|
||||
nb bits"""
|
||||
d = 32
|
||||
xt, xb, xq = get_dataset_2(d, 10000, 0, 0)
|
||||
|
@ -170,13 +170,10 @@ class IndexAccuracy(unittest.TestCase):
|
|||
|
||||
|
||||
class TestSQFlavors(unittest.TestCase):
|
||||
""" tests IP in addition to L2, non multiple of 8 dimensions
|
||||
"""
|
||||
"""tests IP in addition to L2, non multiple of 8 dimensions"""
|
||||
|
||||
def add2columns(self, x):
|
||||
return np.hstack((
|
||||
x, np.zeros((x.shape[0], 2), dtype='float32')
|
||||
))
|
||||
return np.hstack((x, np.zeros((x.shape[0], 2), dtype="float32")))
|
||||
|
||||
def subtest_add2col(self, xb, xq, index, qname):
|
||||
"""Test with 2 additional dimensions to take also the non-SIMD
|
||||
|
@ -197,19 +194,17 @@ class TestSQFlavors(unittest.TestCase):
|
|||
centroids2 = self.add2columns(centroids)
|
||||
quantizer2.add(centroids2)
|
||||
index2 = faiss.IndexIVFScalarQuantizer(
|
||||
quantizer2, d2, index.nlist, index.sq.qtype,
|
||||
index.metric_type)
|
||||
quantizer2, d2, index.nlist, index.sq.qtype, index.metric_type
|
||||
)
|
||||
index2.nprobe = 4
|
||||
if qname in ('8bit', '4bit'):
|
||||
if qname in ("8bit", "4bit"):
|
||||
trained = faiss.vector_to_array(index.sq.trained).reshape(2, -1)
|
||||
nt = trained.shape[1]
|
||||
# 2 lines: vmins and vdiffs
|
||||
new_nt = int(nt * d2 / d)
|
||||
trained2 = np.hstack((
|
||||
trained,
|
||||
np.zeros((2, new_nt - nt), dtype='float32')
|
||||
))
|
||||
trained2[1, nt:] = 1.0 # set vdiff to 1 to avoid div by 0
|
||||
trained2 = np.hstack((trained, np.zeros((2, new_nt - nt),
|
||||
dtype="float32")))
|
||||
trained2[1, nt:] = 1.0 # set vdiff to 1 to avoid div by 0
|
||||
faiss.copy_array_to_vector(trained2.ravel(), index2.sq.trained)
|
||||
else:
|
||||
index2.sq.trained = index.sq.trained
|
||||
|
@ -218,22 +213,21 @@ class TestSQFlavors(unittest.TestCase):
|
|||
index2.add(xb2)
|
||||
return index2.search(xq2, 10)
|
||||
|
||||
|
||||
# run on Sept 18, 2018 with nprobe=4 + 4 bit bugfix
|
||||
ref_results = {
|
||||
(0, '8bit'): 984,
|
||||
(0, '4bit'): 978,
|
||||
(0, '8bit_uniform'): 985,
|
||||
(0, '4bit_uniform'): 979,
|
||||
(0, 'fp16'): 985,
|
||||
(1, '8bit'): 979,
|
||||
(1, '4bit'): 973,
|
||||
(1, '8bit_uniform'): 979,
|
||||
(1, '4bit_uniform'): 972,
|
||||
(1, 'fp16'): 979,
|
||||
(0, "8bit"): 984,
|
||||
(0, "4bit"): 978,
|
||||
(0, "8bit_uniform"): 985,
|
||||
(0, "4bit_uniform"): 979,
|
||||
(0, "fp16"): 985,
|
||||
(1, "8bit"): 979,
|
||||
(1, "4bit"): 973,
|
||||
(1, "8bit_uniform"): 979,
|
||||
(1, "4bit_uniform"): 972,
|
||||
(1, "fp16"): 979,
|
||||
# added 2019-06-26
|
||||
(0, '6bit'): 985,
|
||||
(1, '6bit'): 987,
|
||||
(0, "6bit"): 985,
|
||||
(1, "6bit"): 987,
|
||||
}
|
||||
|
||||
def subtest(self, mt):
|
||||
|
@ -245,19 +239,19 @@ class TestSQFlavors(unittest.TestCase):
|
|||
gt_index.add(xb)
|
||||
gt_D, gt_I = gt_index.search(xq, 10)
|
||||
quantizer = faiss.IndexFlat(d, mt)
|
||||
for qname in '8bit 4bit 8bit_uniform 4bit_uniform fp16 6bit'.split():
|
||||
qtype = getattr(faiss.ScalarQuantizer, 'QT_' + qname)
|
||||
index = faiss.IndexIVFScalarQuantizer(
|
||||
quantizer, d, nlist, qtype, mt)
|
||||
for qname in "8bit 4bit 8bit_uniform 4bit_uniform fp16 6bit".split():
|
||||
qtype = getattr(faiss.ScalarQuantizer, "QT_" + qname)
|
||||
index = faiss.IndexIVFScalarQuantizer(quantizer, d, nlist, qtype,
|
||||
mt)
|
||||
index.train(xt)
|
||||
index.add(xb)
|
||||
index.nprobe = 4 # hopefully more robust than 1
|
||||
index.nprobe = 4 # hopefully more robust than 1
|
||||
D, I = index.search(xq, 10)
|
||||
ninter = faiss.eval_intersection(I, gt_I)
|
||||
print('(%d, %s): %d, ' % (mt, repr(qname), ninter))
|
||||
print("(%d, %s): %d, " % (mt, repr(qname), ninter))
|
||||
assert abs(ninter - self.ref_results[(mt, qname)]) <= 10
|
||||
|
||||
if qname == '6bit':
|
||||
if qname == "6bit":
|
||||
# the test below fails triggers ASAN. TODO check what's wrong
|
||||
continue
|
||||
|
||||
|
@ -270,7 +264,7 @@ class TestSQFlavors(unittest.TestCase):
|
|||
radius = float(D[:, -1].max())
|
||||
else:
|
||||
radius = float(D[:, -1].min())
|
||||
print('radius', radius)
|
||||
print("radius", radius)
|
||||
|
||||
lims, D3, I3 = index.range_search(xq, radius)
|
||||
ntot = ndiff = 0
|
||||
|
@ -284,19 +278,22 @@ class TestSQFlavors(unittest.TestCase):
|
|||
Iref = set(I2[i, mask])
|
||||
ndiff += len(Inew ^ Iref)
|
||||
ntot += len(Iref)
|
||||
print('ndiff %d / %d' % (ndiff, ntot))
|
||||
print("ndiff %d / %d" % (ndiff, ntot))
|
||||
assert ndiff < ntot * 0.01
|
||||
|
||||
for pm in 1, 2:
|
||||
print('parallel_mode=%d' % pm)
|
||||
print("parallel_mode=%d" % pm)
|
||||
index.parallel_mode = pm
|
||||
lims4, D4, I4 = index.range_search(xq, radius)
|
||||
print('sizes', lims4[1:] - lims4[:-1])
|
||||
print("sizes", lims4[1:] - lims4[:-1])
|
||||
for qno in range(len(lims) - 1):
|
||||
Iref = I3[lims[qno]: lims[qno+1]]
|
||||
Inew = I4[lims4[qno]: lims4[qno+1]]
|
||||
Iref = I3[lims[qno]: lims[qno + 1]]
|
||||
Inew = I4[lims4[qno]: lims4[qno + 1]]
|
||||
assert set(Iref) == set(Inew), "q %d ref %s new %s" % (
|
||||
qno, Iref, Inew)
|
||||
qno,
|
||||
Iref,
|
||||
Inew,
|
||||
)
|
||||
|
||||
def test_SQ_IP(self):
|
||||
self.subtest(faiss.METRIC_INNER_PRODUCT)
|
||||
|
@ -311,7 +308,7 @@ class TestSQFlavors(unittest.TestCase):
|
|||
index = faiss.index_factory(d, "IVF64,SQ8")
|
||||
index.train(xt)
|
||||
index.add(xb)
|
||||
index.nprobe = 4 # hopefully more robust than 1
|
||||
index.nprobe = 4 # hopefully more robust than 1
|
||||
Dref, Iref = index.search(xq, 10)
|
||||
|
||||
for pm in 1, 2, 3:
|
||||
|
@ -323,7 +320,6 @@ class TestSQFlavors(unittest.TestCase):
|
|||
|
||||
|
||||
class TestSQByte(unittest.TestCase):
|
||||
|
||||
def subtest_8bit_direct(self, metric_type, d):
|
||||
xt, xb, xq = get_dataset_2(d, 500, 1000, 30)
|
||||
|
||||
|
@ -345,7 +341,8 @@ class TestSQByte(unittest.TestCase):
|
|||
Dref, Iref = gt_index.search(xq, 10)
|
||||
|
||||
index = faiss.IndexScalarQuantizer(
|
||||
d, faiss.ScalarQuantizer.QT_8bit_direct, metric_type)
|
||||
d, faiss.ScalarQuantizer.QT_8bit_direct, metric_type
|
||||
)
|
||||
index.add(xb)
|
||||
D, I = index.search(xq, 10)
|
||||
|
||||
|
@ -364,8 +361,9 @@ class TestSQByte(unittest.TestCase):
|
|||
Dref, Iref = gt_index.search(xq, 10)
|
||||
|
||||
index = faiss.IndexIVFScalarQuantizer(
|
||||
quantizer, d, nlist,
|
||||
faiss.ScalarQuantizer.QT_8bit_direct, metric_type)
|
||||
quantizer, d, nlist, faiss.ScalarQuantizer.QT_8bit_direct,
|
||||
metric_type
|
||||
)
|
||||
index.nprobe = 4
|
||||
index.by_residual = False
|
||||
index.train(xt)
|
||||
|
@ -382,7 +380,6 @@ class TestSQByte(unittest.TestCase):
|
|||
|
||||
|
||||
class TestNNDescent(unittest.TestCase):
|
||||
|
||||
def test_L1(self):
|
||||
search_Ls = [10, 20, 30]
|
||||
thresholds = [0.83, 0.92, 0.95]
|
||||
|
@ -402,9 +399,11 @@ class TestNNDescent(unittest.TestCase):
|
|||
self.subtest(32, faiss.METRIC_INNER_PRODUCT, 10, search_L, threshold)
|
||||
|
||||
def subtest(self, d, metric, topk, search_L, threshold):
|
||||
metric_names = {faiss.METRIC_L1: 'L1',
|
||||
faiss.METRIC_L2: 'L2',
|
||||
faiss.METRIC_INNER_PRODUCT: 'IP'}
|
||||
metric_names = {
|
||||
faiss.METRIC_L1: "L1",
|
||||
faiss.METRIC_L2: "L2",
|
||||
faiss.METRIC_INNER_PRODUCT: "IP",
|
||||
}
|
||||
topk = 10
|
||||
nt, nb, nq = 2000, 1000, 200
|
||||
xt, xb, xq = get_dataset_2(d, nt, nb, nq)
|
||||
|
@ -432,9 +431,12 @@ class TestNNDescent(unittest.TestCase):
|
|||
recalls += 1
|
||||
break
|
||||
recall = 1.0 * recalls / (nq * topk)
|
||||
print('Metric: {}, L: {}, Recall@{}: {}'.format(
|
||||
metric_names[metric], search_L, topk, recall))
|
||||
assert recall > threshold, '{} <= {}'.format(recall, threshold)
|
||||
print(
|
||||
"Metric: {}, L: {}, Recall@{}: {}".format(
|
||||
metric_names[metric], search_L, topk, recall
|
||||
)
|
||||
)
|
||||
assert recall > threshold, "{} <= {}".format(recall, threshold)
|
||||
|
||||
|
||||
class TestPQFlavors(unittest.TestCase):
|
||||
|
@ -466,8 +468,7 @@ class TestPQFlavors(unittest.TestCase):
|
|||
quantizer = faiss.IndexFlat(d, mt)
|
||||
for by_residual in True, False:
|
||||
|
||||
index = faiss.IndexIVFPQ(
|
||||
quantizer, d, nlist, 4, 8)
|
||||
index = faiss.IndexIVFPQ(quantizer, d, nlist, 4, 8)
|
||||
index.metric_type = mt
|
||||
index.by_residual = by_residual
|
||||
if by_residual:
|
||||
|
@ -484,7 +485,7 @@ class TestPQFlavors(unittest.TestCase):
|
|||
D, I = index.search(xq, 10)
|
||||
|
||||
ninter = faiss.eval_intersection(I, gt_I)
|
||||
print('(%d, %s): %d, ' % (mt, by_residual, ninter))
|
||||
print("(%d, %s): %d, " % (mt, by_residual, ninter))
|
||||
|
||||
assert abs(ninter - self.ref_results[mt, by_residual]) <= 3
|
||||
|
||||
|
@ -498,12 +499,16 @@ class TestPQFlavors(unittest.TestCase):
|
|||
index.polysemous_ht = 20
|
||||
D, I = index.search(xq, 10)
|
||||
ninter = faiss.eval_intersection(I, gt_I)
|
||||
print('(%d, %s, %d): %d, ' % (
|
||||
mt, by_residual, index.polysemous_ht, ninter))
|
||||
print(
|
||||
"(%d, %s, %d): %d, "
|
||||
% (mt, by_residual, index.polysemous_ht, ninter)
|
||||
)
|
||||
|
||||
# polysemous behaves bizarrely on ARM
|
||||
assert (ninter >= self.ref_results[
|
||||
mt, by_residual, index.polysemous_ht] - 4)
|
||||
assert (
|
||||
ninter >= self.ref_results[mt, by_residual,
|
||||
index.polysemous_ht] - 4
|
||||
)
|
||||
|
||||
# also test range search
|
||||
|
||||
|
@ -511,7 +516,7 @@ class TestPQFlavors(unittest.TestCase):
|
|||
radius = float(D[:, -1].max())
|
||||
else:
|
||||
radius = float(D[:, -1].min())
|
||||
print('radius', radius)
|
||||
print("radius", radius)
|
||||
|
||||
lims, D3, I3 = index.range_search(xq, radius)
|
||||
ntot = ndiff = 0
|
||||
|
@ -525,7 +530,7 @@ class TestPQFlavors(unittest.TestCase):
|
|||
Iref = set(I2[i, mask])
|
||||
ndiff += len(Inew ^ Iref)
|
||||
ntot += len(Iref)
|
||||
print('ndiff %d / %d' % (ndiff, ntot))
|
||||
print("ndiff %d / %d" % (ndiff, ntot))
|
||||
assert ndiff < ntot * 0.02
|
||||
|
||||
def test_IVFPQ_non8bit(self):
|
||||
|
@ -539,36 +544,33 @@ class TestPQFlavors(unittest.TestCase):
|
|||
|
||||
quantizer = faiss.IndexFlat(d)
|
||||
ninter = {}
|
||||
for v in '2x8', '8x2':
|
||||
if v == '8x2':
|
||||
index = faiss.IndexIVFPQ(
|
||||
quantizer, d, nlist, 2, 8)
|
||||
for v in "2x8", "8x2":
|
||||
if v == "8x2":
|
||||
index = faiss.IndexIVFPQ(quantizer, d, nlist, 2, 8)
|
||||
else:
|
||||
index = faiss.IndexIVFPQ(
|
||||
quantizer, d, nlist, 8, 2)
|
||||
index = faiss.IndexIVFPQ(quantizer, d, nlist, 8, 2)
|
||||
index.train(xt)
|
||||
index.add(xb)
|
||||
index.npobe = 16
|
||||
|
||||
D, I = index.search(xq, 10)
|
||||
ninter[v] = faiss.eval_intersection(I, gt_I)
|
||||
print('ninter=', ninter)
|
||||
print("ninter=", ninter)
|
||||
# this should be the case but we don't observe
|
||||
# that... Probavly too few test points
|
||||
# assert ninter['2x8'] > ninter['8x2']
|
||||
# ref numbers on 2019-11-02
|
||||
assert abs(ninter['2x8'] - 458) < 4
|
||||
assert abs(ninter['8x2'] - 465) < 4
|
||||
assert abs(ninter["2x8"] - 458) < 4
|
||||
assert abs(ninter["8x2"] - 465) < 4
|
||||
|
||||
|
||||
class TestFlat1D(unittest.TestCase):
|
||||
|
||||
def test_flat_1d(self):
|
||||
rs = np.random.RandomState(123545)
|
||||
k = 10
|
||||
xb = rs.uniform(size=(100, 1)).astype('float32')
|
||||
xb = rs.uniform(size=(100, 1)).astype("float32")
|
||||
# make sure to test below and above
|
||||
xq = rs.uniform(size=(1000, 1)).astype('float32') * 1.1 - 0.05
|
||||
xq = rs.uniform(size=(1000, 1)).astype("float32") * 1.1 - 0.05
|
||||
|
||||
ref = faiss.IndexFlatL2(1)
|
||||
ref.add(xb)
|
||||
|
@ -581,10 +583,10 @@ class TestFlat1D(unittest.TestCase):
|
|||
|
||||
ndiff = (np.abs(ref_I - new_I) != 0).sum()
|
||||
|
||||
assert(ndiff < 100)
|
||||
assert ndiff < 100
|
||||
new_D = new_D ** 2
|
||||
max_diff_D = np.abs(ref_D - new_D).max()
|
||||
assert(max_diff_D < 1e-5)
|
||||
assert max_diff_D < 1e-5
|
||||
|
||||
|
||||
class OPQRelativeAccuracy(unittest.TestCase):
|
||||
|
@ -598,7 +600,7 @@ class OPQRelativeAccuracy(unittest.TestCase):
|
|||
d = ev.d
|
||||
index = faiss.IndexPQ(d, M, 8)
|
||||
|
||||
res = ev.launch('PQ', index)
|
||||
res = ev.launch("PQ", index)
|
||||
e_pq = ev.evalres(res)
|
||||
|
||||
index_pq = faiss.IndexPQ(d, M, 8)
|
||||
|
@ -608,15 +610,15 @@ class OPQRelativeAccuracy(unittest.TestCase):
|
|||
opq_matrix.niter_pq = 4
|
||||
index = faiss.IndexPreTransform(opq_matrix, index_pq)
|
||||
|
||||
res = ev.launch('OPQ', index)
|
||||
res = ev.launch("OPQ", index)
|
||||
e_opq = ev.evalres(res)
|
||||
|
||||
print('e_pq=%s' % e_pq)
|
||||
print('e_opq=%s' % e_opq)
|
||||
print("e_pq=%s" % e_pq)
|
||||
print("e_opq=%s" % e_opq)
|
||||
|
||||
# verify that OPQ better than PQ
|
||||
for r in 1, 10, 100:
|
||||
assert(e_opq[r] > e_pq[r])
|
||||
assert e_opq[r] > e_pq[r]
|
||||
|
||||
def test_OIVFPQ(self):
|
||||
# Parameters inverted indexes
|
||||
|
@ -629,7 +631,7 @@ class OPQRelativeAccuracy(unittest.TestCase):
|
|||
index = faiss.IndexIVFPQ(quantizer, d, ncentroids, M, 8)
|
||||
index.nprobe = 5
|
||||
|
||||
res = ev.launch('IVFPQ', index)
|
||||
res = ev.launch("IVFPQ", index)
|
||||
e_ivfpq = ev.evalres(res)
|
||||
|
||||
quantizer = faiss.IndexFlatL2(d)
|
||||
|
@ -639,23 +641,22 @@ class OPQRelativeAccuracy(unittest.TestCase):
|
|||
opq_matrix.niter = 10
|
||||
index = faiss.IndexPreTransform(opq_matrix, index_ivfpq)
|
||||
|
||||
res = ev.launch('O+IVFPQ', index)
|
||||
res = ev.launch("O+IVFPQ", index)
|
||||
e_oivfpq = ev.evalres(res)
|
||||
|
||||
# verify same on OIVFPQ
|
||||
for r in 1, 10, 100:
|
||||
print(e_oivfpq[r], e_ivfpq[r])
|
||||
assert(e_oivfpq[r] >= e_ivfpq[r])
|
||||
assert e_oivfpq[r] >= e_ivfpq[r]
|
||||
|
||||
|
||||
class TestRoundoff(unittest.TestCase):
|
||||
|
||||
def test_roundoff(self):
|
||||
# params that force use of BLAS implementation
|
||||
nb = 100
|
||||
nq = 25
|
||||
d = 4
|
||||
xb = np.zeros((nb, d), dtype='float32')
|
||||
xb = np.zeros((nb, d), dtype="float32")
|
||||
|
||||
xb[:, 0] = np.arange(nb) + 12345
|
||||
xq = xb[:nq] + 0.3
|
||||
|
@ -668,9 +669,8 @@ class TestRoundoff(unittest.TestCase):
|
|||
# this does not work
|
||||
assert not np.all(I.ravel() == np.arange(nq))
|
||||
|
||||
index = faiss.IndexPreTransform(
|
||||
faiss.CenteringTransform(d),
|
||||
faiss.IndexFlat(d))
|
||||
index = faiss.IndexPreTransform(faiss.CenteringTransform(d),
|
||||
faiss.IndexFlat(d))
|
||||
|
||||
index.train(xb)
|
||||
index.add(xb)
|
||||
|
@ -685,30 +685,30 @@ class TestSpectralHash(unittest.TestCase):
|
|||
|
||||
# run on 2019-04-02
|
||||
ref_results = {
|
||||
(32, 'global', 10): 505,
|
||||
(32, 'centroid', 10): 524,
|
||||
(32, 'centroid_half', 10): 21,
|
||||
(32, 'median', 10): 510,
|
||||
(32, 'global', 1): 8,
|
||||
(32, 'centroid', 1): 20,
|
||||
(32, 'centroid_half', 1): 26,
|
||||
(32, 'median', 1): 14,
|
||||
(64, 'global', 10): 768,
|
||||
(64, 'centroid', 10): 767,
|
||||
(64, 'centroid_half', 10): 21,
|
||||
(64, 'median', 10): 765,
|
||||
(64, 'global', 1): 28,
|
||||
(64, 'centroid', 1): 21,
|
||||
(64, 'centroid_half', 1): 20,
|
||||
(64, 'median', 1): 29,
|
||||
(128, 'global', 10): 968,
|
||||
(128, 'centroid', 10): 945,
|
||||
(128, 'centroid_half', 10): 21,
|
||||
(128, 'median', 10): 958,
|
||||
(128, 'global', 1): 271,
|
||||
(128, 'centroid', 1): 279,
|
||||
(128, 'centroid_half', 1): 171,
|
||||
(128, 'median', 1): 253,
|
||||
(32, "global", 10): 505,
|
||||
(32, "centroid", 10): 524,
|
||||
(32, "centroid_half", 10): 21,
|
||||
(32, "median", 10): 510,
|
||||
(32, "global", 1): 8,
|
||||
(32, "centroid", 1): 20,
|
||||
(32, "centroid_half", 1): 26,
|
||||
(32, "median", 1): 14,
|
||||
(64, "global", 10): 768,
|
||||
(64, "centroid", 10): 767,
|
||||
(64, "centroid_half", 10): 21,
|
||||
(64, "median", 10): 765,
|
||||
(64, "global", 1): 28,
|
||||
(64, "centroid", 1): 21,
|
||||
(64, "centroid_half", 1): 20,
|
||||
(64, "median", 1): 29,
|
||||
(128, "global", 10): 968,
|
||||
(128, "centroid", 10): 945,
|
||||
(128, "centroid_half", 10): 21,
|
||||
(128, "median", 10): 958,
|
||||
(128, "global", 1): 271,
|
||||
(128, "centroid", 1): 279,
|
||||
(128, "centroid_half", 1): 171,
|
||||
(128, "median", 1): 253,
|
||||
}
|
||||
|
||||
def test_sh(self):
|
||||
|
@ -728,17 +728,17 @@ class TestSpectralHash(unittest.TestCase):
|
|||
D, I = index_lsh.search(xq, 10)
|
||||
ninter = faiss.eval_intersection(I, gt_I)
|
||||
|
||||
print('LSH baseline: %d' % ninter)
|
||||
print("LSH baseline: %d" % ninter)
|
||||
|
||||
for period in 10.0, 1.0:
|
||||
|
||||
for tt in 'global centroid centroid_half median'.split():
|
||||
index = faiss.IndexIVFSpectralHash(quantizer, d, nlist,
|
||||
nbit, period)
|
||||
for tt in "global centroid centroid_half median".split():
|
||||
index = faiss.IndexIVFSpectralHash(
|
||||
quantizer, d, nlist, nbit, period
|
||||
)
|
||||
index.nprobe = nprobe
|
||||
index.threshold_type = getattr(
|
||||
faiss.IndexIVFSpectralHash,
|
||||
'Thresh_' + tt
|
||||
faiss.IndexIVFSpectralHash, "Thresh_" + tt
|
||||
)
|
||||
|
||||
index.train(xt)
|
||||
|
@ -748,12 +748,13 @@ class TestSpectralHash(unittest.TestCase):
|
|||
ninter = faiss.eval_intersection(I, gt_I)
|
||||
key = (nbit, tt, period)
|
||||
|
||||
print('(%d, %s, %g): %d, ' % (nbit, repr(tt), period, ninter))
|
||||
assert abs(ninter - self.ref_results[key]) <= 12
|
||||
print("(%d, %s, %g): %d, " % (nbit, repr(tt), period,
|
||||
ninter))
|
||||
print(abs(ninter - self.ref_results[key]))
|
||||
assert abs(ninter - self.ref_results[key]) <= 14
|
||||
|
||||
|
||||
class TestRefine(unittest.TestCase):
|
||||
|
||||
def do_test(self, metric):
|
||||
d = 32
|
||||
xt, xb, xq = get_dataset_2(d, 2000, 1000, 200)
|
||||
|
|
|
@ -735,7 +735,7 @@ class TestIndexResidualQuantizerSearch(unittest.TestCase):
|
|||
self.assertLess((Iref != I2).sum(), Iref.size * 0.05)
|
||||
else:
|
||||
inter_2 = faiss.eval_intersection(I2, gt)
|
||||
self.assertGreater(inter_ref, inter_2)
|
||||
self.assertGreaterEqual(inter_ref, inter_2)
|
||||
# print(st, inter_ref, inter_2)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue