2017-06-21 21:54:28 +08:00
|
|
|
# Copyright (c) 2015-present, Facebook, Inc.
|
|
|
|
# All rights reserved.
|
|
|
|
#
|
2017-07-30 15:18:45 +08:00
|
|
|
# This source code is licensed under the BSD+Patents license found in the
|
2017-06-21 21:54:28 +08:00
|
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
#! /usr/bin/env python2
|
|
|
|
|
2017-07-18 17:51:27 +08:00
|
|
|
"""this is a basic test script for simple indices work"""
|
2017-06-21 21:54:28 +08:00
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import unittest
|
|
|
|
import faiss
|
|
|
|
|
|
|
|
|
2017-08-10 02:13:51 +08:00
|
|
|
def get_dataset(d, nb, nt, nq):
    """Build a reproducible set of random float32 matrices for the tests.

    A RandomState seeded with 123 is drawn from three times, in this
    order: the database matrix (nb x d), the training matrix (nt x d)
    and the query matrix (nq x d).  The draw order fixes the values, so
    it must not be changed.

    Returns the tuple (xt, xb, xq): training, database and query vectors.
    """
    rng = np.random.RandomState(123)
    # database first, then training, then queries -- same order as the draws
    shapes = ((nb, d), (nt, d), (nq, d))
    database, training, queries = [
        rng.rand(*shape).astype('float32') for shape in shapes
    ]
    return (training, database, queries)
|
|
|
|
|
|
|
|
|
2017-06-21 21:54:28 +08:00
|
|
|
class EvalIVFPQAccuracy(unittest.TestCase):
    """Accuracy check of IndexIVFPQ against exact (flat L2) search."""

    def test_IndexIVFPQ(self):
        """The exact nearest neighbour must show up in the top-10 IVFPQ
        results for more than 40% of the queries."""
        (xt, xb, xq) = get_dataset(64, 1000, 1500, 200)
        dim = xt.shape[1]

        # ground truth: exact 1-nearest neighbour of every query
        exact = faiss.IndexFlatL2(dim)
        exact.add(xb)
        _, gt_nns = exact.search(xq, 1)

        # The coarse quantizer is kept in its own local so that it stays
        # referenced for as long as the IVFPQ index is used.
        coarse_quantizer = faiss.IndexFlatL2(dim)
        # constructor arguments per the faiss API: 25 inverted lists,
        # 16 sub-quantizers, 8 bits per code
        ivfpq = faiss.IndexIVFPQ(coarse_quantizer, dim, 25, 16, 8)
        ivfpq.train(xt)
        ivfpq.add(xb)
        ivfpq.nprobe = 5
        _, nns = ivfpq.search(xq, 10)

        # gt_nns has a single column, so the comparison broadcasts it
        # against each of the 10 result columns.
        hits = (nns == gt_nns)
        n_ok = hits.sum()

        self.assertGreater(n_ok, 0.4 * xq.shape[0])
|
|
|
|
|
|
|
|
|
|
|
|
class TestMultiIndexQuantizer(unittest.TestCase):
    """Consistency checks for faiss.MultiIndexQuantizer searches."""

    def test_search_k1(self):
        """Searching with k = 1 and with k = 5 must agree on the first
        result, both for the labels and for the distances."""
        # verify codepath for k = 1 and k > 1
        dim = 64
        # nb = 0: only training and query vectors are used here
        xt, _, xq = get_dataset(dim, 0, 1500, 200)

        miq = faiss.MultiIndexQuantizer(dim, 2, 6)
        miq.train(xt)

        D1, I1 = miq.search(xq, 1)
        D5, I5 = miq.search(xq, 5)

        # labels first, then distances: the first column must be identical
        for single, multi in ((I1, I5), (D1, D5)):
            largest_gap = np.abs(single[:, :1] - multi[:, :1]).max()
            self.assertEqual(largest_gap, 0)
|
|
|
|
|
|
|
|
|
2017-07-18 17:51:27 +08:00
|
|
|
class TestScalarQuantizer(unittest.TestCase):
    """Compare the top-1 accuracy of the four scalar quantizer variants,
    with and without an inverted file."""

    def test_4variants_ivf(self):
        """IVF flavour: flat storage must be at least as accurate as
        8-bit, 8-bit at least as accurate as 4-bit, and each non-uniform
        variant at least as accurate as its uniform counterpart."""
        dim = 32
        (xt, xb, xq) = get_dataset(dim, 10000, 1500, 200)

        # common quantizer
        quantizer = faiss.IndexFlatL2(dim)

        ncent = 128

        # exact search results used as the reference
        exact = faiss.IndexFlatL2(dim)
        exact.add(xb)
        _, I_ref = exact.search(xq, 10)

        def top1_hits(index):
            # number of queries whose first result matches the reference
            index.nprobe = 4
            index.train(xt)
            index.add(xb)
            _, found = index.search(xq, 10)
            return (found[:, 0] == I_ref[:, 0]).sum()

        nok = {}

        nok['flat'] = top1_hits(
            faiss.IndexIVFFlat(quantizer, dim, ncent, faiss.METRIC_L2))

        for qname in "QT_4bit QT_4bit_uniform QT_8bit QT_8bit_uniform".split():
            qtype = getattr(faiss.ScalarQuantizer, qname)
            nok[qname] = top1_hits(
                faiss.IndexIVFScalarQuantizer(quantizer, dim, ncent,
                                              qtype, faiss.METRIC_L2))

        print(nok)

        # each pair is (expected at least as accurate, expected no better)
        expected_order = (
            ('flat', 'QT_8bit'),
            ('QT_8bit', 'QT_4bit'),
            ('QT_8bit', 'QT_8bit_uniform'),
            ('QT_4bit', 'QT_4bit_uniform'),
        )
        for better, worse in expected_order:
            self.assertGreaterEqual(nok[better], nok[worse])

    def test_4variants(self):
        """Non-IVF flavour: 8-bit must be at least as accurate as 4-bit,
        and each non-uniform variant at least as accurate as its uniform
        counterpart."""
        dim = 32
        (xt, xb, xq) = get_dataset(dim, 10000, 1500, 200)

        # exact search results used as the reference
        exact = faiss.IndexFlatL2(dim)
        exact.add(xb)
        _, I_ref = exact.search(xq, 10)

        def top1_hits(index):
            # number of queries whose first result matches the reference
            index.train(xt)
            index.add(xb)
            _, found = index.search(xq, 10)
            return (found[:, 0] == I_ref[:, 0]).sum()

        nok = {}

        for qname in "QT_4bit QT_4bit_uniform QT_8bit QT_8bit_uniform".split():
            qtype = getattr(faiss.ScalarQuantizer, qname)
            nok[qname] = top1_hits(
                faiss.IndexScalarQuantizer(dim, qtype, faiss.METRIC_L2))

        print(nok)

        # each pair is (expected at least as accurate, expected no better)
        expected_order = (
            ('QT_8bit', 'QT_4bit'),
            ('QT_8bit', 'QT_8bit_uniform'),
            ('QT_4bit', 'QT_4bit_uniform'),
        )
        for better, worse in expected_order:
            self.assertGreaterEqual(nok[better], nok[worse])
|
2017-06-21 21:54:28 +08:00
|
|
|
|
|
|
|
|
2017-08-10 02:13:51 +08:00
|
|
|
class TestRangeSearch(unittest.TestCase):
    """Cross-check range_search against k-nearest-neighbour search."""

    def test_range_search(self):
        """Every one of the 5 nearest neighbours that lies below the
        threshold must appear exactly once in the range search result of
        the same query, with a distance equal up to 1e-4."""
        d, nt, nq, nb = 4, 100, 10, 50

        (xt, xb, xq) = get_dataset(d, nb, nt, nq)

        index = faiss.IndexFlatL2(d)
        index.add(xb)

        Dref, Iref = index.search(xq, 5)

        thresh = 0.1   # *squared* distance
        lims, D, I = index.range_search(xq, thresh)

        for i, (ref_ids, ref_dists) in enumerate(zip(Iref, Dref)):
            # results of query i occupy the slice lims[i]:lims[i + 1]
            begin, end = lims[i], lims[i + 1]
            Iline = I[begin:end]
            Dline = D[begin:end]
            for j, dis in zip(ref_ids, ref_dists):
                if not (dis < thresh):
                    # outside the radius: not expected in the range result
                    continue
                matches, = np.where(Iline == j)
                self.assertTrue(matches.size == 1)
                pos = matches[0]
                self.assertGreaterEqual(1e-4, abs(Dline[pos] - dis))
|
|
|
|
|
2017-06-21 21:54:28 +08:00
|
|
|
|
|
|
|
# Run every test case of this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
|