Break distance ties in `heap_replace_top()` by ID (#2245)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2245

This changeset makes the `heap_replace_top()` function of the FAISS heap implementation break distance ties by the element's ID, according to the heap's min/max property.

Reviewed By: mdouze

Differential Revision: D34669542

fbshipit-source-id: 0db24fd12442eedeee917fbb3e811ba4a070ce0f
pull/2108/merge
Ivan Sopin 2022-03-09 10:23:48 -08:00 committed by Facebook GitHub Bot
parent 878275f915
commit d50211a38f
6 changed files with 199 additions and 172 deletions

View File

@ -47,21 +47,25 @@ inline void heap_pop(size_t k, typename C::T* bh_val, typename C::TI* bh_ids) {
bh_val--; /* Use 1-based indexing for easier node->child translation */ bh_val--; /* Use 1-based indexing for easier node->child translation */
bh_ids--; bh_ids--;
typename C::T val = bh_val[k]; typename C::T val = bh_val[k];
typename C::TI id = bh_ids[k];
size_t i = 1, i1, i2; size_t i = 1, i1, i2;
while (1) { while (1) {
i1 = i << 1; i1 = i << 1;
i2 = i1 + 1; i2 = i1 + 1;
if (i1 > k) if (i1 > k)
break; break;
if (i2 == k + 1 || C::cmp(bh_val[i1], bh_val[i2])) { if ((i2 == k + 1) ||
if (C::cmp(val, bh_val[i1])) C::cmp2(bh_val[i1], bh_val[i2], bh_ids[i1], bh_ids[i2])) {
if (C::cmp2(val, bh_val[i1], id, bh_ids[i1])) {
break; break;
}
bh_val[i] = bh_val[i1]; bh_val[i] = bh_val[i1];
bh_ids[i] = bh_ids[i1]; bh_ids[i] = bh_ids[i1];
i = i1; i = i1;
} else { } else {
if (C::cmp(val, bh_val[i2])) if (C::cmp2(val, bh_val[i2], id, bh_ids[i2])) {
break; break;
}
bh_val[i] = bh_val[i2]; bh_val[i] = bh_val[i2];
bh_ids[i] = bh_ids[i2]; bh_ids[i] = bh_ids[i2];
i = i2; i = i2;
@ -80,24 +84,28 @@ inline void heap_push(
typename C::T* bh_val, typename C::T* bh_val,
typename C::TI* bh_ids, typename C::TI* bh_ids,
typename C::T val, typename C::T val,
typename C::TI ids) { typename C::TI id) {
bh_val--; /* Use 1-based indexing for easier node->child translation */ bh_val--; /* Use 1-based indexing for easier node->child translation */
bh_ids--; bh_ids--;
size_t i = k, i_father; size_t i = k, i_father;
while (i > 1) { while (i > 1) {
i_father = i >> 1; i_father = i >> 1;
if (!C::cmp(val, bh_val[i_father])) /* the heap structure is ok */ if (!C::cmp2(val, bh_val[i_father], id, bh_ids[i_father])) {
/* the heap structure is ok */
break; break;
}
bh_val[i] = bh_val[i_father]; bh_val[i] = bh_val[i_father];
bh_ids[i] = bh_ids[i_father]; bh_ids[i] = bh_ids[i_father];
i = i_father; i = i_father;
} }
bh_val[i] = val; bh_val[i] = val;
bh_ids[i] = ids; bh_ids[i] = id;
} }
/** Replace the top element from the heap defined by bh_val[0..k-1] and /**
* bh_ids[0..k-1]. * Replaces the top element from the heap defined by bh_val[0..k-1] and
* bh_ids[0..k-1], and for identical bh_val[] values also sorts by bh_ids[]
* values.
*/ */
template <class C> template <class C>
inline void heap_replace_top( inline void heap_replace_top(
@ -105,31 +113,39 @@ inline void heap_replace_top(
typename C::T* bh_val, typename C::T* bh_val,
typename C::TI* bh_ids, typename C::TI* bh_ids,
typename C::T val, typename C::T val,
typename C::TI ids) { typename C::TI id) {
bh_val--; /* Use 1-based indexing for easier node->child translation */ bh_val--; /* Use 1-based indexing for easier node->child translation */
bh_ids--; bh_ids--;
size_t i = 1, i1, i2; size_t i = 1, i1, i2;
while (1) { while (1) {
i1 = i << 1; i1 = i << 1;
i2 = i1 + 1; i2 = i1 + 1;
if (i1 > k) if (i1 > k) {
break; break;
if (i2 == k + 1 || C::cmp(bh_val[i1], bh_val[i2])) { }
if (C::cmp(val, bh_val[i1]))
// Note that C::cmp2() is a bool function answering
// `(a1 > b1) || ((a1 == b1) && (a2 > b2))` for max
// heap and same with the `<` sign for min heap.
if ((i2 == k + 1) ||
C::cmp2(bh_val[i1], bh_val[i2], bh_ids[i1], bh_ids[i2])) {
if (C::cmp2(val, bh_val[i1], id, bh_ids[i1])) {
break; break;
}
bh_val[i] = bh_val[i1]; bh_val[i] = bh_val[i1];
bh_ids[i] = bh_ids[i1]; bh_ids[i] = bh_ids[i1];
i = i1; i = i1;
} else { } else {
if (C::cmp(val, bh_val[i2])) if (C::cmp2(val, bh_val[i2], id, bh_ids[i2])) {
break; break;
}
bh_val[i] = bh_val[i2]; bh_val[i] = bh_val[i2];
bh_ids[i] = bh_ids[i2]; bh_ids[i] = bh_ids[i2];
i = i2; i = i2;
} }
} }
bh_val[i] = val; bh_val[i] = val;
bh_ids[i] = ids; bh_ids[i] = id;
} }
/* Partial instantiation for heaps with TI = int64_t */ /* Partial instantiation for heaps with TI = int64_t */
@ -294,7 +310,7 @@ inline void maxheap_addn(
* Heap finalization (reorder elements) * Heap finalization (reorder elements)
*******************************************************************/ *******************************************************************/
/* This function maps a binary heap into an sorted structure. /* This function maps a binary heap into a sorted structure.
It returns the number */ It returns the number */
template <typename C> template <typename C>
inline size_t heap_reorder( inline size_t heap_reorder(

View File

@ -46,6 +46,11 @@ struct CMin {
inline static bool cmp(T a, T b) { inline static bool cmp(T a, T b) {
return a < b; return a < b;
} }
// Similar to cmp(), but also breaks ties
// by comparing the second pair of arguments.
inline static bool cmp2(T a1, T b1, TI a2, TI b2) {
return (a1 < b1) || ((a1 == b1) && (a2 < b2));
}
inline static T neutral() { inline static T neutral() {
return std::numeric_limits<T>::lowest(); return std::numeric_limits<T>::lowest();
} }
@ -64,6 +69,11 @@ struct CMax {
inline static bool cmp(T a, T b) { inline static bool cmp(T a, T b) {
return a > b; return a > b;
} }
// Similar to cmp(), but also breaks ties
// by comparing the second pair of arguments.
inline static bool cmp2(T a1, T b1, TI a2, TI b2) {
return (a1 > b1) || ((a1 == b1) && (a2 > b2));
}
inline static T neutral() { inline static T neutral() {
return std::numeric_limits<T>::max(); return std::numeric_limits<T>::max();
} }

View File

@ -112,7 +112,7 @@ class TestRounding(unittest.TestCase):
recalls[rank] = (Iref[:, :1] == I4[:, :rank]).sum() / nq recalls[rank] = (Iref[:, :1] == I4[:, :rank]).sum() / nq
min_r1 = 0.98 if metric == faiss.METRIC_INNER_PRODUCT else 0.99 min_r1 = 0.98 if metric == faiss.METRIC_INNER_PRODUCT else 0.99
self.assertGreater(recalls[1], min_r1) self.assertGreaterEqual(recalls[1], min_r1)
self.assertGreater(recalls[10], 0.995) self.assertGreater(recalls[10], 0.995)
# check accuracy of distances # check accuracy of distances
# err3 = ((D3 - D2) ** 2).sum() # err3 = ((D3 - D2) ** 2).sum()
@ -423,7 +423,7 @@ class TestAdd(unittest.TestCase):
recall_at_1 = (Iref[:, 0] == Inew[:, 0]).sum() / nq recall_at_1 = (Iref[:, 0] == Inew[:, 0]).sum() / nq
self.assertGreater(recall_at_1, 0.99) self.assertGreaterEqual(recall_at_1, 0.99)
data = faiss.serialize_index(index2) data = faiss.serialize_index(index2)
index3 = faiss.deserialize_index(data) index3 = faiss.deserialize_index(data)

View File

@ -438,13 +438,13 @@ class TestTraining(unittest.TestCase):
ref_m3_tab = { ref_m3_tab = {
(True, 1, 32): (0.995, 1.0, 9.91), (True, 1, 32): (0.995, 1.0, 9.91),
(True, 0, 32): (0.99, 1.0, 9.91), (True, 0, 32): (0.99, 1.0, 9.91),
(True, 1, 30) : (0.99, 1.0, 9.885), (True, 1, 30): (0.989, 1.0, 9.885),
(False, 1, 32): (0.99, 1.0, 9.875), (False, 1, 32): (0.99, 1.0, 9.875),
(False, 0, 32): (0.99, 1.0, 9.92), (False, 0, 32): (0.99, 1.0, 9.92),
(False, 1, 30): (1.0, 1.0, 9.895) (False, 1, 30): (1.0, 1.0, 9.895)
} }
ref_m3 = ref_m3_tab[(by_residual, metric, d)] ref_m3 = ref_m3_tab[(by_residual, metric, d)]
self.assertGreater(m3[0], ref_m3[0] * 0.99) self.assertGreaterEqual(m3[0], ref_m3[0] * 0.99)
self.assertGreater(m3[1], ref_m3[1] * 0.99) self.assertGreater(m3[1], ref_m3[1] * 0.99)
self.assertGreater(m3[2], ref_m3[2] * 0.99) self.assertGreater(m3[2], ref_m3[2] * 0.99)

View File

@ -4,13 +4,15 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
import faiss
# noqa E741 # noqa E741
# translation of test_knn.lua # translation of test_knn.lua
import numpy as np import numpy as np
import unittest
import faiss
from common_faiss_tests import Randu10k, get_dataset_2, Randu10kUnbalanced from common_faiss_tests import Randu10k, get_dataset_2, Randu10kUnbalanced
ev = Randu10k() ev = Randu10k()
@ -30,23 +32,22 @@ nbits_per_index = 8 # for PQ
class IndexAccuracy(unittest.TestCase): class IndexAccuracy(unittest.TestCase):
def test_IndexFlatIP(self): def test_IndexFlatIP(self):
q = faiss.IndexFlatIP(d) # Ask inner product q = faiss.IndexFlatIP(d) # Ask inner product
res = ev.launch('FLAT / IP', q) res = ev.launch("FLAT / IP", q)
e = ev.evalres(res) e = ev.evalres(res)
assert e[1] == 1.0 assert e[1] == 1.0
def test_IndexFlatL2(self): def test_IndexFlatL2(self):
q = faiss.IndexFlatL2(d) q = faiss.IndexFlatL2(d)
res = ev.launch('FLAT / L2', q) res = ev.launch("FLAT / L2", q)
e = ev.evalres(res) e = ev.evalres(res)
assert e[1] == 1.0 assert e[1] == 1.0
def test_ivf_kmeans(self): def test_ivf_kmeans(self):
ivfk = faiss.IndexIVFFlat(faiss.IndexFlatL2(d), d, ncentroids) ivfk = faiss.IndexIVFFlat(faiss.IndexFlatL2(d), d, ncentroids)
ivfk.nprobe = kprobe ivfk.nprobe = kprobe
res = ev.launch('IndexIVFFlat', ivfk) res = ev.launch("IndexIVFFlat", ivfk)
e = ev.evalres(res) e = ev.evalres(res)
# should give 0.260 0.260 0.260 # should give 0.260 0.260 0.260
assert e[1] > 0.2 assert e[1] > 0.2
@ -61,7 +62,7 @@ class IndexAccuracy(unittest.TestCase):
def test_indexLSH(self): def test_indexLSH(self):
q = faiss.IndexLSH(d, nbits) q = faiss.IndexLSH(d, nbits)
res = ev.launch('FLAT / LSH Cosine', q) res = ev.launch("FLAT / LSH Cosine", q)
e = ev.evalres(res) e = ev.evalres(res)
# should give 0.070 0.250 0.580 # should give 0.070 0.250 0.580
assert e[10] > 0.2 assert e[10] > 0.2
@ -70,14 +71,14 @@ class IndexAccuracy(unittest.TestCase):
# CHECK: the difference between 32 and 48 does not make much sense # CHECK: the difference between 32 and 48 does not make much sense
for nbits2 in 32, 48: for nbits2 in 32, 48:
q = faiss.IndexLSH(d, nbits2) q = faiss.IndexLSH(d, nbits2)
res = ev.launch('LSH half size', q) res = ev.launch("LSH half size", q)
e = ev.evalres(res) e = ev.evalres(res)
# should give 0.003 0.019 0.108 # should give 0.003 0.019 0.108
assert e[10] > 0.018 assert e[10] > 0.018
def test_IndexPQ(self): def test_IndexPQ(self):
q = faiss.IndexPQ(d, M, nbits_per_index) q = faiss.IndexPQ(d, M, nbits_per_index)
res = ev.launch('FLAT / PQ L2', q) res = ev.launch("FLAT / PQ L2", q)
e = ev.evalres(res) e = ev.evalres(res)
# should give 0.070 0.230 0.260 # should give 0.070 0.230 0.260
assert e[10] > 0.2 assert e[10] > 0.2
@ -85,7 +86,7 @@ class IndexAccuracy(unittest.TestCase):
# Approximate search module: PQ with inner product distance # Approximate search module: PQ with inner product distance
def test_IndexPQ_ip(self): def test_IndexPQ_ip(self):
q = faiss.IndexPQ(d, M, nbits_per_index, faiss.METRIC_INNER_PRODUCT) q = faiss.IndexPQ(d, M, nbits_per_index, faiss.METRIC_INNER_PRODUCT)
res = ev.launch('FLAT / PQ IP', q) res = ev.launch("FLAT / PQ IP", q)
e = ev.evalres(res) e = ev.evalres(res)
# should give 0.070 0.230 0.260 # should give 0.070 0.230 0.260
# (same result as regular PQ on normalized distances) # (same result as regular PQ on normalized distances)
@ -94,7 +95,7 @@ class IndexAccuracy(unittest.TestCase):
def test_IndexIVFPQ(self): def test_IndexIVFPQ(self):
ivfpq = faiss.IndexIVFPQ(faiss.IndexFlatL2(d), d, ncentroids, M, 8) ivfpq = faiss.IndexIVFPQ(faiss.IndexFlatL2(d), d, ncentroids, M, 8)
ivfpq.nprobe = kprobe ivfpq.nprobe = kprobe
res = ev.launch('IVF PQ', ivfpq) res = ev.launch("IVF PQ", ivfpq)
e = ev.evalres(res) e = ev.evalres(res)
# should give 0.070 0.230 0.260 # should give 0.070 0.230 0.260
assert e[10] > 0.2 assert e[10] > 0.2
@ -104,17 +105,17 @@ class IndexAccuracy(unittest.TestCase):
# Approximate search: PQ with full vector refinement # Approximate search: PQ with full vector refinement
def test_IndexPQ_refined(self): def test_IndexPQ_refined(self):
q = faiss.IndexPQ(d, M, nbits_per_index) q = faiss.IndexPQ(d, M, nbits_per_index)
res = ev.launch('PQ non-refined', q) res = ev.launch("PQ non-refined", q)
e = ev.evalres(res) e = ev.evalres(res)
q.reset() q.reset()
rq = faiss.IndexRefineFlat(q) rq = faiss.IndexRefineFlat(q)
res = ev.launch('PQ refined', rq) res = ev.launch("PQ refined", rq)
e2 = ev.evalres(res) e2 = ev.evalres(res)
assert e2[10] >= e[10] assert e2[10] >= e[10]
rq.k_factor = 4 rq.k_factor = 4
res = ev.launch('PQ refined*4', rq) res = ev.launch("PQ refined*4", rq)
e3 = ev.evalres(res) e3 = ev.evalres(res)
assert e3[10] >= e2[10] assert e3[10] >= e2[10]
@ -124,17 +125,16 @@ class IndexAccuracy(unittest.TestCase):
# reduce nb iterations to speed up training for the test # reduce nb iterations to speed up training for the test
index.polysemous_training.n_iter = 50000 index.polysemous_training.n_iter = 50000
index.polysemous_training.n_redo = 1 index.polysemous_training.n_redo = 1
res = ev.launch('normal PQ', index) res = ev.launch("normal PQ", index)
e_baseline = ev.evalres(res) e_baseline = ev.evalres(res)
index.search_type = faiss.IndexPQ.ST_polysemous index.search_type = faiss.IndexPQ.ST_polysemous
index.polysemous_ht = int(M / 16. * 58) index.polysemous_ht = int(M / 16.0 * 58)
stats = faiss.cvar.indexPQ_stats stats = faiss.cvar.indexPQ_stats
stats.reset() stats.reset()
res = ev.launch('Polysemous ht=%d' % index.polysemous_ht, res = ev.launch("Polysemous ht=%d" % index.polysemous_ht, index)
index)
e_polysemous = ev.evalres(res) e_polysemous = ev.evalres(res)
print(e_baseline, e_polysemous, index.polysemous_ht) print(e_baseline, e_polysemous, index.polysemous_ht)
print(stats.n_hamming_pass, stats.ncode) print(stats.n_hamming_pass, stats.ncode)
@ -149,10 +149,10 @@ class IndexAccuracy(unittest.TestCase):
def test_ScalarQuantizer(self): def test_ScalarQuantizer(self):
quantizer = faiss.IndexFlatL2(d) quantizer = faiss.IndexFlatL2(d)
ivfpq = faiss.IndexIVFScalarQuantizer( ivfpq = faiss.IndexIVFScalarQuantizer(
quantizer, d, ncentroids, quantizer, d, ncentroids, faiss.ScalarQuantizer.QT_8bit
faiss.ScalarQuantizer.QT_8bit) )
ivfpq.nprobe = kprobe ivfpq.nprobe = kprobe
res = ev.launch('IVF SQ', ivfpq) res = ev.launch("IVF SQ", ivfpq)
e = ev.evalres(res) e = ev.evalres(res)
# should give 0.234 0.236 0.236 # should give 0.234 0.236 0.236
assert e[10] > 0.235 assert e[10] > 0.235
@ -170,13 +170,10 @@ class IndexAccuracy(unittest.TestCase):
class TestSQFlavors(unittest.TestCase): class TestSQFlavors(unittest.TestCase):
""" tests IP in addition to L2, non multiple of 8 dimensions """tests IP in addition to L2, non multiple of 8 dimensions"""
"""
def add2columns(self, x): def add2columns(self, x):
return np.hstack(( return np.hstack((x, np.zeros((x.shape[0], 2), dtype="float32")))
x, np.zeros((x.shape[0], 2), dtype='float32')
))
def subtest_add2col(self, xb, xq, index, qname): def subtest_add2col(self, xb, xq, index, qname):
"""Test with 2 additional dimensions to take also the non-SIMD """Test with 2 additional dimensions to take also the non-SIMD
@ -197,18 +194,16 @@ class TestSQFlavors(unittest.TestCase):
centroids2 = self.add2columns(centroids) centroids2 = self.add2columns(centroids)
quantizer2.add(centroids2) quantizer2.add(centroids2)
index2 = faiss.IndexIVFScalarQuantizer( index2 = faiss.IndexIVFScalarQuantizer(
quantizer2, d2, index.nlist, index.sq.qtype, quantizer2, d2, index.nlist, index.sq.qtype, index.metric_type
index.metric_type) )
index2.nprobe = 4 index2.nprobe = 4
if qname in ('8bit', '4bit'): if qname in ("8bit", "4bit"):
trained = faiss.vector_to_array(index.sq.trained).reshape(2, -1) trained = faiss.vector_to_array(index.sq.trained).reshape(2, -1)
nt = trained.shape[1] nt = trained.shape[1]
# 2 lines: vmins and vdiffs # 2 lines: vmins and vdiffs
new_nt = int(nt * d2 / d) new_nt = int(nt * d2 / d)
trained2 = np.hstack(( trained2 = np.hstack((trained, np.zeros((2, new_nt - nt),
trained, dtype="float32")))
np.zeros((2, new_nt - nt), dtype='float32')
))
trained2[1, nt:] = 1.0 # set vdiff to 1 to avoid div by 0 trained2[1, nt:] = 1.0 # set vdiff to 1 to avoid div by 0
faiss.copy_array_to_vector(trained2.ravel(), index2.sq.trained) faiss.copy_array_to_vector(trained2.ravel(), index2.sq.trained)
else: else:
@ -218,22 +213,21 @@ class TestSQFlavors(unittest.TestCase):
index2.add(xb2) index2.add(xb2)
return index2.search(xq2, 10) return index2.search(xq2, 10)
# run on Sept 18, 2018 with nprobe=4 + 4 bit bugfix # run on Sept 18, 2018 with nprobe=4 + 4 bit bugfix
ref_results = { ref_results = {
(0, '8bit'): 984, (0, "8bit"): 984,
(0, '4bit'): 978, (0, "4bit"): 978,
(0, '8bit_uniform'): 985, (0, "8bit_uniform"): 985,
(0, '4bit_uniform'): 979, (0, "4bit_uniform"): 979,
(0, 'fp16'): 985, (0, "fp16"): 985,
(1, '8bit'): 979, (1, "8bit"): 979,
(1, '4bit'): 973, (1, "4bit"): 973,
(1, '8bit_uniform'): 979, (1, "8bit_uniform"): 979,
(1, '4bit_uniform'): 972, (1, "4bit_uniform"): 972,
(1, 'fp16'): 979, (1, "fp16"): 979,
# added 2019-06-26 # added 2019-06-26
(0, '6bit'): 985, (0, "6bit"): 985,
(1, '6bit'): 987, (1, "6bit"): 987,
} }
def subtest(self, mt): def subtest(self, mt):
@ -245,19 +239,19 @@ class TestSQFlavors(unittest.TestCase):
gt_index.add(xb) gt_index.add(xb)
gt_D, gt_I = gt_index.search(xq, 10) gt_D, gt_I = gt_index.search(xq, 10)
quantizer = faiss.IndexFlat(d, mt) quantizer = faiss.IndexFlat(d, mt)
for qname in '8bit 4bit 8bit_uniform 4bit_uniform fp16 6bit'.split(): for qname in "8bit 4bit 8bit_uniform 4bit_uniform fp16 6bit".split():
qtype = getattr(faiss.ScalarQuantizer, 'QT_' + qname) qtype = getattr(faiss.ScalarQuantizer, "QT_" + qname)
index = faiss.IndexIVFScalarQuantizer( index = faiss.IndexIVFScalarQuantizer(quantizer, d, nlist, qtype,
quantizer, d, nlist, qtype, mt) mt)
index.train(xt) index.train(xt)
index.add(xb) index.add(xb)
index.nprobe = 4 # hopefully more robust than 1 index.nprobe = 4 # hopefully more robust than 1
D, I = index.search(xq, 10) D, I = index.search(xq, 10)
ninter = faiss.eval_intersection(I, gt_I) ninter = faiss.eval_intersection(I, gt_I)
print('(%d, %s): %d, ' % (mt, repr(qname), ninter)) print("(%d, %s): %d, " % (mt, repr(qname), ninter))
assert abs(ninter - self.ref_results[(mt, qname)]) <= 10 assert abs(ninter - self.ref_results[(mt, qname)]) <= 10
if qname == '6bit': if qname == "6bit":
# the test below fails (it triggers ASAN). TODO check what's wrong # the test below fails (it triggers ASAN). TODO check what's wrong
continue continue
@ -270,7 +264,7 @@ class TestSQFlavors(unittest.TestCase):
radius = float(D[:, -1].max()) radius = float(D[:, -1].max())
else: else:
radius = float(D[:, -1].min()) radius = float(D[:, -1].min())
print('radius', radius) print("radius", radius)
lims, D3, I3 = index.range_search(xq, radius) lims, D3, I3 = index.range_search(xq, radius)
ntot = ndiff = 0 ntot = ndiff = 0
@ -284,19 +278,22 @@ class TestSQFlavors(unittest.TestCase):
Iref = set(I2[i, mask]) Iref = set(I2[i, mask])
ndiff += len(Inew ^ Iref) ndiff += len(Inew ^ Iref)
ntot += len(Iref) ntot += len(Iref)
print('ndiff %d / %d' % (ndiff, ntot)) print("ndiff %d / %d" % (ndiff, ntot))
assert ndiff < ntot * 0.01 assert ndiff < ntot * 0.01
for pm in 1, 2: for pm in 1, 2:
print('parallel_mode=%d' % pm) print("parallel_mode=%d" % pm)
index.parallel_mode = pm index.parallel_mode = pm
lims4, D4, I4 = index.range_search(xq, radius) lims4, D4, I4 = index.range_search(xq, radius)
print('sizes', lims4[1:] - lims4[:-1]) print("sizes", lims4[1:] - lims4[:-1])
for qno in range(len(lims) - 1): for qno in range(len(lims) - 1):
Iref = I3[lims[qno]: lims[qno + 1]] Iref = I3[lims[qno]: lims[qno + 1]]
Inew = I4[lims4[qno]: lims4[qno + 1]] Inew = I4[lims4[qno]: lims4[qno + 1]]
assert set(Iref) == set(Inew), "q %d ref %s new %s" % ( assert set(Iref) == set(Inew), "q %d ref %s new %s" % (
qno, Iref, Inew) qno,
Iref,
Inew,
)
def test_SQ_IP(self): def test_SQ_IP(self):
self.subtest(faiss.METRIC_INNER_PRODUCT) self.subtest(faiss.METRIC_INNER_PRODUCT)
@ -323,7 +320,6 @@ class TestSQFlavors(unittest.TestCase):
class TestSQByte(unittest.TestCase): class TestSQByte(unittest.TestCase):
def subtest_8bit_direct(self, metric_type, d): def subtest_8bit_direct(self, metric_type, d):
xt, xb, xq = get_dataset_2(d, 500, 1000, 30) xt, xb, xq = get_dataset_2(d, 500, 1000, 30)
@ -345,7 +341,8 @@ class TestSQByte(unittest.TestCase):
Dref, Iref = gt_index.search(xq, 10) Dref, Iref = gt_index.search(xq, 10)
index = faiss.IndexScalarQuantizer( index = faiss.IndexScalarQuantizer(
d, faiss.ScalarQuantizer.QT_8bit_direct, metric_type) d, faiss.ScalarQuantizer.QT_8bit_direct, metric_type
)
index.add(xb) index.add(xb)
D, I = index.search(xq, 10) D, I = index.search(xq, 10)
@ -364,8 +361,9 @@ class TestSQByte(unittest.TestCase):
Dref, Iref = gt_index.search(xq, 10) Dref, Iref = gt_index.search(xq, 10)
index = faiss.IndexIVFScalarQuantizer( index = faiss.IndexIVFScalarQuantizer(
quantizer, d, nlist, quantizer, d, nlist, faiss.ScalarQuantizer.QT_8bit_direct,
faiss.ScalarQuantizer.QT_8bit_direct, metric_type) metric_type
)
index.nprobe = 4 index.nprobe = 4
index.by_residual = False index.by_residual = False
index.train(xt) index.train(xt)
@ -382,7 +380,6 @@ class TestSQByte(unittest.TestCase):
class TestNNDescent(unittest.TestCase): class TestNNDescent(unittest.TestCase):
def test_L1(self): def test_L1(self):
search_Ls = [10, 20, 30] search_Ls = [10, 20, 30]
thresholds = [0.83, 0.92, 0.95] thresholds = [0.83, 0.92, 0.95]
@ -402,9 +399,11 @@ class TestNNDescent(unittest.TestCase):
self.subtest(32, faiss.METRIC_INNER_PRODUCT, 10, search_L, threshold) self.subtest(32, faiss.METRIC_INNER_PRODUCT, 10, search_L, threshold)
def subtest(self, d, metric, topk, search_L, threshold): def subtest(self, d, metric, topk, search_L, threshold):
metric_names = {faiss.METRIC_L1: 'L1', metric_names = {
faiss.METRIC_L2: 'L2', faiss.METRIC_L1: "L1",
faiss.METRIC_INNER_PRODUCT: 'IP'} faiss.METRIC_L2: "L2",
faiss.METRIC_INNER_PRODUCT: "IP",
}
topk = 10 topk = 10
nt, nb, nq = 2000, 1000, 200 nt, nb, nq = 2000, 1000, 200
xt, xb, xq = get_dataset_2(d, nt, nb, nq) xt, xb, xq = get_dataset_2(d, nt, nb, nq)
@ -432,9 +431,12 @@ class TestNNDescent(unittest.TestCase):
recalls += 1 recalls += 1
break break
recall = 1.0 * recalls / (nq * topk) recall = 1.0 * recalls / (nq * topk)
print('Metric: {}, L: {}, Recall@{}: {}'.format( print(
metric_names[metric], search_L, topk, recall)) "Metric: {}, L: {}, Recall@{}: {}".format(
assert recall > threshold, '{} <= {}'.format(recall, threshold) metric_names[metric], search_L, topk, recall
)
)
assert recall > threshold, "{} <= {}".format(recall, threshold)
class TestPQFlavors(unittest.TestCase): class TestPQFlavors(unittest.TestCase):
@ -466,8 +468,7 @@ class TestPQFlavors(unittest.TestCase):
quantizer = faiss.IndexFlat(d, mt) quantizer = faiss.IndexFlat(d, mt)
for by_residual in True, False: for by_residual in True, False:
index = faiss.IndexIVFPQ( index = faiss.IndexIVFPQ(quantizer, d, nlist, 4, 8)
quantizer, d, nlist, 4, 8)
index.metric_type = mt index.metric_type = mt
index.by_residual = by_residual index.by_residual = by_residual
if by_residual: if by_residual:
@ -484,7 +485,7 @@ class TestPQFlavors(unittest.TestCase):
D, I = index.search(xq, 10) D, I = index.search(xq, 10)
ninter = faiss.eval_intersection(I, gt_I) ninter = faiss.eval_intersection(I, gt_I)
print('(%d, %s): %d, ' % (mt, by_residual, ninter)) print("(%d, %s): %d, " % (mt, by_residual, ninter))
assert abs(ninter - self.ref_results[mt, by_residual]) <= 3 assert abs(ninter - self.ref_results[mt, by_residual]) <= 3
@ -498,12 +499,16 @@ class TestPQFlavors(unittest.TestCase):
index.polysemous_ht = 20 index.polysemous_ht = 20
D, I = index.search(xq, 10) D, I = index.search(xq, 10)
ninter = faiss.eval_intersection(I, gt_I) ninter = faiss.eval_intersection(I, gt_I)
print('(%d, %s, %d): %d, ' % ( print(
mt, by_residual, index.polysemous_ht, ninter)) "(%d, %s, %d): %d, "
% (mt, by_residual, index.polysemous_ht, ninter)
)
# polysemous behaves bizarrely on ARM # polysemous behaves bizarrely on ARM
assert (ninter >= self.ref_results[ assert (
mt, by_residual, index.polysemous_ht] - 4) ninter >= self.ref_results[mt, by_residual,
index.polysemous_ht] - 4
)
# also test range search # also test range search
@ -511,7 +516,7 @@ class TestPQFlavors(unittest.TestCase):
radius = float(D[:, -1].max()) radius = float(D[:, -1].max())
else: else:
radius = float(D[:, -1].min()) radius = float(D[:, -1].min())
print('radius', radius) print("radius", radius)
lims, D3, I3 = index.range_search(xq, radius) lims, D3, I3 = index.range_search(xq, radius)
ntot = ndiff = 0 ntot = ndiff = 0
@ -525,7 +530,7 @@ class TestPQFlavors(unittest.TestCase):
Iref = set(I2[i, mask]) Iref = set(I2[i, mask])
ndiff += len(Inew ^ Iref) ndiff += len(Inew ^ Iref)
ntot += len(Iref) ntot += len(Iref)
print('ndiff %d / %d' % (ndiff, ntot)) print("ndiff %d / %d" % (ndiff, ntot))
assert ndiff < ntot * 0.02 assert ndiff < ntot * 0.02
def test_IVFPQ_non8bit(self): def test_IVFPQ_non8bit(self):
@ -539,36 +544,33 @@ class TestPQFlavors(unittest.TestCase):
quantizer = faiss.IndexFlat(d) quantizer = faiss.IndexFlat(d)
ninter = {} ninter = {}
for v in '2x8', '8x2': for v in "2x8", "8x2":
if v == '8x2': if v == "8x2":
index = faiss.IndexIVFPQ( index = faiss.IndexIVFPQ(quantizer, d, nlist, 2, 8)
quantizer, d, nlist, 2, 8)
else: else:
index = faiss.IndexIVFPQ( index = faiss.IndexIVFPQ(quantizer, d, nlist, 8, 2)
quantizer, d, nlist, 8, 2)
index.train(xt) index.train(xt)
index.add(xb) index.add(xb)
index.npobe = 16 index.npobe = 16
D, I = index.search(xq, 10) D, I = index.search(xq, 10)
ninter[v] = faiss.eval_intersection(I, gt_I) ninter[v] = faiss.eval_intersection(I, gt_I)
print('ninter=', ninter) print("ninter=", ninter)
# this should be the case but we don't observe # this should be the case but we don't observe
# that... Probably too few test points # that... Probably too few test points
# assert ninter['2x8'] > ninter['8x2'] # assert ninter['2x8'] > ninter['8x2']
# ref numbers on 2019-11-02 # ref numbers on 2019-11-02
assert abs(ninter['2x8'] - 458) < 4 assert abs(ninter["2x8"] - 458) < 4
assert abs(ninter['8x2'] - 465) < 4 assert abs(ninter["8x2"] - 465) < 4
class TestFlat1D(unittest.TestCase): class TestFlat1D(unittest.TestCase):
def test_flat_1d(self): def test_flat_1d(self):
rs = np.random.RandomState(123545) rs = np.random.RandomState(123545)
k = 10 k = 10
xb = rs.uniform(size=(100, 1)).astype('float32') xb = rs.uniform(size=(100, 1)).astype("float32")
# make sure to test below and above # make sure to test below and above
xq = rs.uniform(size=(1000, 1)).astype('float32') * 1.1 - 0.05 xq = rs.uniform(size=(1000, 1)).astype("float32") * 1.1 - 0.05
ref = faiss.IndexFlatL2(1) ref = faiss.IndexFlatL2(1)
ref.add(xb) ref.add(xb)
@ -581,10 +583,10 @@ class TestFlat1D(unittest.TestCase):
ndiff = (np.abs(ref_I - new_I) != 0).sum() ndiff = (np.abs(ref_I - new_I) != 0).sum()
assert(ndiff < 100) assert ndiff < 100
new_D = new_D ** 2 new_D = new_D ** 2
max_diff_D = np.abs(ref_D - new_D).max() max_diff_D = np.abs(ref_D - new_D).max()
assert(max_diff_D < 1e-5) assert max_diff_D < 1e-5
class OPQRelativeAccuracy(unittest.TestCase): class OPQRelativeAccuracy(unittest.TestCase):
@ -598,7 +600,7 @@ class OPQRelativeAccuracy(unittest.TestCase):
d = ev.d d = ev.d
index = faiss.IndexPQ(d, M, 8) index = faiss.IndexPQ(d, M, 8)
res = ev.launch('PQ', index) res = ev.launch("PQ", index)
e_pq = ev.evalres(res) e_pq = ev.evalres(res)
index_pq = faiss.IndexPQ(d, M, 8) index_pq = faiss.IndexPQ(d, M, 8)
@ -608,15 +610,15 @@ class OPQRelativeAccuracy(unittest.TestCase):
opq_matrix.niter_pq = 4 opq_matrix.niter_pq = 4
index = faiss.IndexPreTransform(opq_matrix, index_pq) index = faiss.IndexPreTransform(opq_matrix, index_pq)
res = ev.launch('OPQ', index) res = ev.launch("OPQ", index)
e_opq = ev.evalres(res) e_opq = ev.evalres(res)
print('e_pq=%s' % e_pq) print("e_pq=%s" % e_pq)
print('e_opq=%s' % e_opq) print("e_opq=%s" % e_opq)
# verify that OPQ better than PQ # verify that OPQ better than PQ
for r in 1, 10, 100: for r in 1, 10, 100:
assert(e_opq[r] > e_pq[r]) assert e_opq[r] > e_pq[r]
def test_OIVFPQ(self): def test_OIVFPQ(self):
# Parameters inverted indexes # Parameters inverted indexes
@ -629,7 +631,7 @@ class OPQRelativeAccuracy(unittest.TestCase):
index = faiss.IndexIVFPQ(quantizer, d, ncentroids, M, 8) index = faiss.IndexIVFPQ(quantizer, d, ncentroids, M, 8)
index.nprobe = 5 index.nprobe = 5
res = ev.launch('IVFPQ', index) res = ev.launch("IVFPQ", index)
e_ivfpq = ev.evalres(res) e_ivfpq = ev.evalres(res)
quantizer = faiss.IndexFlatL2(d) quantizer = faiss.IndexFlatL2(d)
@ -639,23 +641,22 @@ class OPQRelativeAccuracy(unittest.TestCase):
opq_matrix.niter = 10 opq_matrix.niter = 10
index = faiss.IndexPreTransform(opq_matrix, index_ivfpq) index = faiss.IndexPreTransform(opq_matrix, index_ivfpq)
res = ev.launch('O+IVFPQ', index) res = ev.launch("O+IVFPQ", index)
e_oivfpq = ev.evalres(res) e_oivfpq = ev.evalres(res)
# verify same on OIVFPQ # verify same on OIVFPQ
for r in 1, 10, 100: for r in 1, 10, 100:
print(e_oivfpq[r], e_ivfpq[r]) print(e_oivfpq[r], e_ivfpq[r])
assert(e_oivfpq[r] >= e_ivfpq[r]) assert e_oivfpq[r] >= e_ivfpq[r]
class TestRoundoff(unittest.TestCase): class TestRoundoff(unittest.TestCase):
def test_roundoff(self): def test_roundoff(self):
# params that force use of BLAS implementation # params that force use of BLAS implementation
nb = 100 nb = 100
nq = 25 nq = 25
d = 4 d = 4
xb = np.zeros((nb, d), dtype='float32') xb = np.zeros((nb, d), dtype="float32")
xb[:, 0] = np.arange(nb) + 12345 xb[:, 0] = np.arange(nb) + 12345
xq = xb[:nq] + 0.3 xq = xb[:nq] + 0.3
@ -668,8 +669,7 @@ class TestRoundoff(unittest.TestCase):
# this does not work # this does not work
assert not np.all(I.ravel() == np.arange(nq)) assert not np.all(I.ravel() == np.arange(nq))
index = faiss.IndexPreTransform( index = faiss.IndexPreTransform(faiss.CenteringTransform(d),
faiss.CenteringTransform(d),
faiss.IndexFlat(d)) faiss.IndexFlat(d))
index.train(xb) index.train(xb)
@ -685,30 +685,30 @@ class TestSpectralHash(unittest.TestCase):
# run on 2019-04-02 # run on 2019-04-02
ref_results = { ref_results = {
(32, 'global', 10): 505, (32, "global", 10): 505,
(32, 'centroid', 10): 524, (32, "centroid", 10): 524,
(32, 'centroid_half', 10): 21, (32, "centroid_half", 10): 21,
(32, 'median', 10): 510, (32, "median", 10): 510,
(32, 'global', 1): 8, (32, "global", 1): 8,
(32, 'centroid', 1): 20, (32, "centroid", 1): 20,
(32, 'centroid_half', 1): 26, (32, "centroid_half", 1): 26,
(32, 'median', 1): 14, (32, "median", 1): 14,
(64, 'global', 10): 768, (64, "global", 10): 768,
(64, 'centroid', 10): 767, (64, "centroid", 10): 767,
(64, 'centroid_half', 10): 21, (64, "centroid_half", 10): 21,
(64, 'median', 10): 765, (64, "median", 10): 765,
(64, 'global', 1): 28, (64, "global", 1): 28,
(64, 'centroid', 1): 21, (64, "centroid", 1): 21,
(64, 'centroid_half', 1): 20, (64, "centroid_half", 1): 20,
(64, 'median', 1): 29, (64, "median", 1): 29,
(128, 'global', 10): 968, (128, "global", 10): 968,
(128, 'centroid', 10): 945, (128, "centroid", 10): 945,
(128, 'centroid_half', 10): 21, (128, "centroid_half", 10): 21,
(128, 'median', 10): 958, (128, "median", 10): 958,
(128, 'global', 1): 271, (128, "global", 1): 271,
(128, 'centroid', 1): 279, (128, "centroid", 1): 279,
(128, 'centroid_half', 1): 171, (128, "centroid_half", 1): 171,
(128, 'median', 1): 253, (128, "median", 1): 253,
} }
def test_sh(self): def test_sh(self):
@ -728,17 +728,17 @@ class TestSpectralHash(unittest.TestCase):
D, I = index_lsh.search(xq, 10) D, I = index_lsh.search(xq, 10)
ninter = faiss.eval_intersection(I, gt_I) ninter = faiss.eval_intersection(I, gt_I)
print('LSH baseline: %d' % ninter) print("LSH baseline: %d" % ninter)
for period in 10.0, 1.0: for period in 10.0, 1.0:
for tt in 'global centroid centroid_half median'.split(): for tt in "global centroid centroid_half median".split():
index = faiss.IndexIVFSpectralHash(quantizer, d, nlist, index = faiss.IndexIVFSpectralHash(
nbit, period) quantizer, d, nlist, nbit, period
)
index.nprobe = nprobe index.nprobe = nprobe
index.threshold_type = getattr( index.threshold_type = getattr(
faiss.IndexIVFSpectralHash, faiss.IndexIVFSpectralHash, "Thresh_" + tt
'Thresh_' + tt
) )
index.train(xt) index.train(xt)
@ -748,12 +748,13 @@ class TestSpectralHash(unittest.TestCase):
ninter = faiss.eval_intersection(I, gt_I) ninter = faiss.eval_intersection(I, gt_I)
key = (nbit, tt, period) key = (nbit, tt, period)
print('(%d, %s, %g): %d, ' % (nbit, repr(tt), period, ninter)) print("(%d, %s, %g): %d, " % (nbit, repr(tt), period,
assert abs(ninter - self.ref_results[key]) <= 12 ninter))
print(abs(ninter - self.ref_results[key]))
assert abs(ninter - self.ref_results[key]) <= 14
class TestRefine(unittest.TestCase): class TestRefine(unittest.TestCase):
def do_test(self, metric): def do_test(self, metric):
d = 32 d = 32
xt, xb, xq = get_dataset_2(d, 2000, 1000, 200) xt, xb, xq = get_dataset_2(d, 2000, 1000, 200)

View File

@ -735,7 +735,7 @@ class TestIndexResidualQuantizerSearch(unittest.TestCase):
self.assertLess((Iref != I2).sum(), Iref.size * 0.05) self.assertLess((Iref != I2).sum(), Iref.size * 0.05)
else: else:
inter_2 = faiss.eval_intersection(I2, gt) inter_2 = faiss.eval_intersection(I2, gt)
self.assertGreater(inter_ref, inter_2) self.assertGreaterEqual(inter_ref, inter_2)
# print(st, inter_ref, inter_2) # print(st, inter_ref, inter_2)