# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
import faiss
import numpy as np
from faiss.contrib.datasets import SyntheticDataset

from common_faiss_tests import Randu10k

ru = Randu10k()
xb = ru.xb
xt = ru.xt
xq = ru.xq
nb, d = xb.shape
nq, d = xq.shape
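

# TestMerge1 builds IVF-type indexes (IVFFlat, IVFPQ, IVFPQR) and an
# IndexIDMap over the random Randu10k data, and checks that merging shards
# with merge_from reproduces the results of a single index built over the
# full dataset; the remove tests below reuse the same index construction.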
class TestMerge1(unittest.TestCase):
    def make_index_for_merge(self, quant, index_type, master_index):
        ncent = 40
        if index_type == 1:
            index = faiss.IndexIVFFlat(quant, d, ncent, faiss.METRIC_L2)
            if master_index:
                index.is_trained = True
        elif index_type == 2:
            index = faiss.IndexIVFPQ(quant, d, ncent, 4, 8)
            if master_index:
                index.pq = master_index.pq
                index.is_trained = True
        elif index_type == 3:
            index = faiss.IndexIVFPQR(quant, d, ncent, 4, 8, 8, 8)
            if master_index:
                index.pq = master_index.pq
                index.refine_pq = master_index.refine_pq
                index.is_trained = True
        elif index_type == 4:
            # quant used as the actual index
            index = faiss.IndexIDMap(quant)
        return index
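
    # do_test_merge: build a reference index over all of xb, then add the
    # same vectors in three shards to indexes that reuse the reference's
    # trained parameters, merge the shards with merge_from, and check that
    # the merged search results almost exactly match the reference.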
    def do_test_merge(self, index_type):
        k = 16
        quant = faiss.IndexFlatL2(d)
        ref_index = self.make_index_for_merge(quant, index_type, False)

        # trains the quantizer
        ref_index.train(xt)

        print('ref search')
        ref_index.add(xb)
        _Dref, Iref = ref_index.search(xq, k)
        print(Iref[:5, :6])

        indexes = []
        ni = 3
        for i in range(ni):
            i0 = int(i * nb / ni)
            i1 = int((i + 1) * nb / ni)
            index = self.make_index_for_merge(quant, index_type, ref_index)
            index.is_trained = True
            index.add(xb[i0:i1])
            indexes.append(index)

        index = indexes[0]

        for i in range(1, ni):
            print('merge ntotal=%d other.ntotal=%d ' % (
                index.ntotal, indexes[i].ntotal))
            index.merge_from(indexes[i], index.ntotal)

        _D, I = index.search(xq, k)
        print(I[:5, :6])

        ndiff = (I != Iref).sum()
        print('%d / %d differences' % (ndiff, nq * k))
        assert (ndiff < nq * k / 1000.)

    def test_merge(self):
        self.do_test_merge(1)
        self.do_test_merge(2)
        self.do_test_merge(3)

    ######################################
    # remove tests that piggyback on merge
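
    # do_test_remove: search the index, remove every even id that appears in
    # the top-k results with IDSelectorBatch / remove_ids, then check that
    # the remaining (odd) results keep their order and distances.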
    def do_test_remove(self, index_type):
        k = 16
        quant = faiss.IndexFlatL2(d)
        index = self.make_index_for_merge(quant, index_type, None)

        # trains the quantizer
        index.train(xt)

        if index_type < 4:
            index.add(xb)
        else:
            gen = np.random.RandomState(1234)
            id_list = gen.permutation(nb * 7)[:nb].astype('int64')
            index.add_with_ids(xb, id_list)

        print('ref search ntotal=%d' % index.ntotal)
        Dref, Iref = index.search(xq, k)

        toremove = np.zeros(nq * k, dtype='int64')
        nr = 0
        for i in range(nq):
            for j in range(k):
                # remove all even results (it's ok if there are duplicates
                # in the list of ids)
                if Iref[i, j] % 2 == 0:
                    # store the id, then advance the count so that
                    # toremove[:nr] holds the batch to remove
                    toremove[nr] = Iref[i, j]
                    nr = nr + 1

        print('nr=', nr)

        idsel = faiss.IDSelectorBatch(
            nr, faiss.swig_ptr(toremove))

        for i in range(nr):
            assert (idsel.is_member(int(toremove[i])))

        nremoved = index.remove_ids(idsel)

        print('nremoved=%d ntotal=%d' % (nremoved, index.ntotal))

        D, I = index.search(xq, k)

        # make sure results are in the same order with even ones removed
        ndiff = 0
        for i in range(nq):
            j2 = 0
            for j in range(k):
                if Iref[i, j] % 2 != 0:
                    if I[i, j2] != Iref[i, j]:
                        ndiff += 1
                        assert abs(D[i, j2] - Dref[i, j]) < 1e-5
                    j2 += 1
        # draws are ordered arbitrarily
        assert ndiff < 5

    def test_remove(self):
        self.do_test_remove(1)
        self.do_test_remove(2)
        self.do_test_remove(4)


# Test merge_from method for all IndexFlatCodes Types
class TestMerge2(unittest.TestCase):
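
    # do_flat_codes_test: get reference results from an index holding the
    # whole database, then rebuild it as two halves, merge them with
    # merge_from and expect exactly the same search results.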
    def do_flat_codes_test(self, factory_key):
        ds = SyntheticDataset(32, 300, 300, 100)
        index1 = faiss.index_factory(ds.d, factory_key)
        index1.train(ds.get_train())
        index1.add(ds.get_database())
        _, Iref = index1.search(ds.get_queries(), 5)
        index1.reset()
        index2 = faiss.clone_index(index1)
        index1.add(ds.get_database()[:100])
        index2.add(ds.get_database()[100:])
        index1.merge_from(index2)
        _, Inew = index1.search(ds.get_queries(), 5)
        np.testing.assert_array_equal(Inew, Iref)

    def test_merge_IndexFlat(self):
        self.do_flat_codes_test("Flat")

    def test_merge_IndexPQ(self):
        self.do_flat_codes_test("PQ8np")

    def test_merge_IndexLSH(self):
        self.do_flat_codes_test("LSHr")

    def test_merge_IndexScalarQuantizer(self):
        self.do_flat_codes_test("SQ4")

    def test_merge_PreTransform(self):
        self.do_flat_codes_test("PCA16,SQ4")
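
    # do_fast_scan_test: same merge check for FastScan indexes, whose codes
    # are packed in blocks of 32 vectors; the different split sizes exercise
    # merging on complete and incomplete blocks.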
    def do_fast_scan_test(self, factory_key, size1, with_add_id=False):
        ds = SyntheticDataset(110, 1000, 1000, 100)
        index_trained = faiss.index_factory(ds.d, factory_key)
        index_trained.train(ds.get_train())
        # test both clone and index_read/write
        if True:
            index1 = faiss.deserialize_index(
                faiss.serialize_index(index_trained))
        else:
            index1 = faiss.clone_index(index_trained)
        # assert index1.aq.qnorm.ntotal == index_trained.aq.qnorm.ntotal

        index1.add(ds.get_database())
        _, Iref = index1.search(ds.get_queries(), 5)
        index1.reset()
        index2 = faiss.clone_index(index_trained)
        index1.add(ds.get_database()[:size1])
        index2.add(ds.get_database()[size1:])
        if with_add_id:
            index1.merge_from(index2, add_id=index1.ntotal)
        else:
            index1.merge_from(index2)
        _, Inew = index1.search(ds.get_queries(), 5)
        np.testing.assert_array_equal(Inew, Iref)

    def test_merge_IndexFastScan_complete_block(self):
        self.do_fast_scan_test("PQ5x4fs", 320)

    def test_merge_IndexFastScan_not_complete_block(self):
        self.do_fast_scan_test("PQ11x4fs", 310)

    def test_merge_IndexFastScan_even_M(self):
        self.do_fast_scan_test("PQ10x4fs", 500)

    def test_merge_IndexAdditiveQuantizerFastScan(self):
        self.do_fast_scan_test("RQ10x4fs_32_Nrq2x4", 330)

    def test_merge_IVFFastScan(self):
        self.do_fast_scan_test("IVF20,PQ5x4fs", 123, with_add_id=True)
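
    # do_test_with_ids: same merge check through IDMap / IDMap2 wrappers with
    # user-provided ids; for IDMap2 the id-to-vector mapping is checked for
    # consistency after the merge.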
    def do_test_with_ids(self, factory_key):
        ds = SyntheticDataset(32, 300, 300, 100)
        rs = np.random.RandomState(123)
        ids = rs.choice(10000, ds.nb, replace=False).astype('int64')
        index1 = faiss.index_factory(ds.d, factory_key)
        index1.train(ds.get_train())
        index1.add_with_ids(ds.get_database(), ids)
        _, Iref = index1.search(ds.get_queries(), 5)
        index1.reset()
        index2 = faiss.clone_index(index1)
        index1.add_with_ids(ds.get_database()[:100], ids[:100])
        index2.add_with_ids(ds.get_database()[100:], ids[100:])
        index1.merge_from(index2)
        _, Inew = index1.search(ds.get_queries(), 5)
        np.testing.assert_array_equal(Inew, Iref)
        if "IDMap2" in factory_key:
            index1.check_consistency()

    def test_merge_IDMap(self):
        self.do_test_with_ids("Flat,IDMap")

    def test_merge_IDMap2(self):
        self.do_test_with_ids("Flat,IDMap2")
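

# TestRemoveFastScan checks that remove_ids on a FastScan index, whose codes
# are packed in blocks, gives the same results as building the index from
# only the kept vectors.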
class TestRemoveFastScan(unittest.TestCase):

    def do_fast_scan_test(self, factory_key, size1):
        ds = SyntheticDataset(110, 1000, 1000, 100)
        index1 = faiss.index_factory(ds.d, factory_key)
        index1.train(ds.get_train())
        index1.reset()
        tokeep = [i % 3 == 0 for i in range(ds.nb)]
        index1.add(ds.get_database()[tokeep])
        _, Iref = index1.search(ds.get_queries(), 5)
        index1.reset()
        index1.add(ds.get_database())
        index1.remove_ids(np.where(np.logical_not(tokeep))[0])
        _, Inew = index1.search(ds.get_queries(), 5)
        np.testing.assert_array_equal(Inew, Iref)

    def test_remove(self):
        self.do_fast_scan_test("PQ5x4fs", 320)