import unittest

import faiss
import numpy as np
from faiss.contrib.datasets import SyntheticDataset

from common_faiss_tests import Randu10k

# Shared dataset for TestMerge1: xt is used for training, xb is the database
# that gets added to the indexes and xq holds the query vectors.
ru = Randu10k()
xb = ru.xb
xt = ru.xt
xq = ru.xq
nb, d = xb.shape
nq, d = xq.shape


class TestMerge1(unittest.TestCase):
    """Merge and remove tests on IVF-style indexes sharing one quantizer."""

    def make_index_for_merge(self, quant, index_type, master_index):
        """Build an index of the requested kind on top of `quant`.

        Args:
            quant: index used as the coarse quantizer (types 1-3) or wrapped
                directly by the IDMap (type 4).
            index_type: 1 = IndexIVFFlat, 2 = IndexIVFPQ, 3 = IndexIVFPQR,
                4 = IndexIDMap.
            master_index: falsy to get an untrained index; otherwise an
                already trained index of the same kind whose product
                quantizer(s) are copied, after which the new index is
                flagged as trained.

        Returns:
            The newly constructed index.

        Raises:
            ValueError: if `index_type` is not one of 1, 2, 3, 4.
        """
        ncent = 40
        if index_type == 1:
            index = faiss.IndexIVFFlat(quant, d, ncent, faiss.METRIC_L2)
            if master_index:
                index.is_trained = True
        elif index_type == 2:
            index = faiss.IndexIVFPQ(quant, d, ncent, 4, 8)
            if master_index:
                index.pq = master_index.pq
                index.is_trained = True
        elif index_type == 3:
            index = faiss.IndexIVFPQR(quant, d, ncent, 4, 8, 8, 8)
            if master_index:
                index.pq = master_index.pq
                index.refine_pq = master_index.refine_pq
                index.is_trained = True
        elif index_type == 4:
            # quant used as the actual index
            index = faiss.IndexIDMap(quant)
        else:
            # Without this branch the function would fail later with an
            # opaque UnboundLocalError on `index`.
            raise ValueError('unsupported index_type: %r' % (index_type,))
        return index

    def do_test_merge(self, index_type):
        """Check that merging partial indexes reproduces a full index.

        A reference index is trained and filled with the whole of `xb`. The
        same data is then split over `ni` indexes that share the trained
        quantizer; these are merged one by one into the first and the search
        results are compared with the reference. Fewer than 0.1% of the
        result entries may differ.
        """
        k = 16
        quant = faiss.IndexFlatL2(d)
        ref_index = self.make_index_for_merge(quant, index_type, False)

        # trains the quantizer
        ref_index.train(xt)

        print('ref search')
        ref_index.add(xb)
        _Dref, Iref = ref_index.search(xq, k)
        print(Iref[:5, :6])

        indexes = []
        ni = 3
        for i in range(ni):
            # contiguous slice of the database handled by sub-index i
            i0 = int(i * nb / ni)
            i1 = int((i + 1) * nb / ni)
            index = self.make_index_for_merge(quant, index_type, ref_index)
            index.is_trained = True
            index.add(xb[i0:i1])
            indexes.append(index)

        index = indexes[0]

        for i in range(1, ni):
            print('merge ntotal=%d other.ntotal=%d ' % (
                index.ntotal, indexes[i].ntotal))
            # The second argument is the current size of the destination, so
            # that the merged ids line up with those of the reference index.
            index.merge_from(indexes[i], index.ntotal)

        _D, I = index.search(xq, k)
        print(I[:5, :6])

        ndiff = (I != Iref).sum()
        print('%d / %d differences' % (ndiff, nq * k))
        self.assertLess(ndiff, nq * k / 1000.)

    def test_merge(self):
        """Run the merge check for index types 1, 2 and 3."""
        self.do_test_merge(1)
        self.do_test_merge(2)
        self.do_test_merge(3)

    ######################################
    # remove tests that piggyback on merge

    def do_test_remove(self, index_type):
        """Check that removing ids leaves the other results untouched.

        Every even id that shows up in the reference search results is
        removed from the index. A second search must then return the odd
        reference results in the same order, up to a few arbitrary
        orderings of equidistant results.
        """
        k = 16
        quant = faiss.IndexFlatL2(d)
        index = self.make_index_for_merge(quant, index_type, None)

        # trains the quantizer
        index.train(xt)

        if index_type < 4:
            index.add(xb)
        else:
            gen = np.random.RandomState(1234)
            id_list = gen.permutation(nb * 7)[:nb].astype('int64')
            index.add_with_ids(xb, id_list)

        print('ref search ntotal=%d' % index.ntotal)
        Dref, Iref = index.search(xq, k)

        # Remove all even results (it's ok if there are duplicates in the
        # list of ids). The ids are stored from position 0 so that the first
        # `nr` entries handed to the selector are exactly the collected ids.
        toremove = np.ascontiguousarray(Iref[Iref % 2 == 0], dtype='int64')
        nr = int(toremove.size)
        print('nr=', nr)

        idsel = faiss.IDSelectorBatch(
            nr, faiss.swig_ptr(toremove))

        for i in range(nr):
            self.assertTrue(idsel.is_member(int(toremove[i])))

        nremoved = index.remove_ids(idsel)

        print('nremoved=%d ntotal=%d' % (nremoved, index.ntotal))

        D, I = index.search(xq, k)

        # make sure results are in the same order with even ones removed
        ndiff = 0
        for i in range(nq):
            j2 = 0  # position in the new results matching reference rank j
            for j in range(k):
                if Iref[i, j] % 2 != 0:
                    if I[i, j2] != Iref[i, j]:
                        ndiff += 1
                    # even when the ids differ, the distances must agree
                    self.assertLess(abs(D[i, j2] - Dref[i, j]), 1e-5)
                    j2 += 1
        # draws are ordered arbitrarily
        self.assertLess(ndiff, 5)

    def test_remove(self):
        """Run the removal check for index types 1, 2 and 4."""
        self.do_test_remove(1)
        self.do_test_remove(2)
        self.do_test_remove(4)


# Test merge_from method for all IndexFlatCodes Types
class TestMerge2(unittest.TestCase):
    """merge_from tests for indexes built with the index factory."""

    def do_flat_codes_test(self, factory_key):
        """Check that merging two halves gives the same results as one index.

        The index described by `factory_key` is trained, filled with the
        whole database and searched to get reference results. It is then
        reset and cloned; the first 100 database vectors go to the original,
        the remaining ones to the clone, and the clone is merged back in.
        """
        ds = SyntheticDataset(32, 300, 300, 100)
        index1 = faiss.index_factory(ds.d, factory_key)
        index1.train(ds.get_train())
        index1.add(ds.get_database())
        _, Iref = index1.search(ds.get_queries(), 5)
        index1.reset()
        # cloned after the reset, so index2 is trained but empty
        index2 = faiss.clone_index(index1)
        index1.add(ds.get_database()[:100])
        index2.add(ds.get_database()[100:])
        index1.merge_from(index2)
        _, Inew = index1.search(ds.get_queries(), 5)
        np.testing.assert_array_equal(Inew, Iref)

    def test_merge_IndexFlat(self):
        self.do_flat_codes_test("Flat")

    def test_merge_IndexPQ(self):
        self.do_flat_codes_test("PQ8")

    def test_merge_IndexLSH(self):
        self.do_flat_codes_test("LSHr")

    def test_merge_IndexScalarQuantizer(self):
        self.do_flat_codes_test("SQ4")

    def test_merge_PreTransform(self):
        self.do_flat_codes_test("PCA16,SQ4")

    def do_fast_scan_test(self, factory_key, size1):
        """Same check as do_flat_codes_test for fast-scan indexes.

        Args:
            factory_key: index factory string of the index under test.
            size1: number of database vectors added to the first index; the
                remaining ones go to the second index before the merge.

        The second index is built and trained from scratch with the same
        factory string and training set rather than cloned.
        """
        ds = SyntheticDataset(110, 1000, 1000, 100)
        index1 = faiss.index_factory(ds.d, factory_key)
        index1.train(ds.get_train())
        index1.add(ds.get_database())
        _, Iref = index1.search(ds.get_queries(), 5)
        index1.reset()
        index2 = faiss.index_factory(ds.d, factory_key)
        index2.train(ds.get_train())
        index1.add(ds.get_database()[:size1])
        index2.add(ds.get_database()[size1:])
        index1.merge_from(index2)
        _, Inew = index1.search(ds.get_queries(), 5)
        np.testing.assert_array_equal(Inew, Iref)

    def test_merge_IndexFastScan_complete_block(self):
        self.do_fast_scan_test("PQ5x4fs", 320)

    def test_merge_IndexFastScan_not_complete_block(self):
        self.do_fast_scan_test("PQ11x4fs", 310)

    def test_merge_IndexFastScan_even_M(self):
        self.do_fast_scan_test("PQ10x4fs", 500)

    def test_merge_IndexAdditiveQuantizerFastScan(self):
        self.do_fast_scan_test("RQ10x4fs_32_Nrq2x4", 330)

    def do_test_with_ids(self, factory_key):
        """Check merge_from on indexes that store explicit ids.

        Works like do_flat_codes_test, except that the vectors are added
        with distinct random ids drawn from [0, 10000). For factory strings
        containing "IDMap2" the consistency check of the merged index is
        run as well.
        """
        ds = SyntheticDataset(32, 300, 300, 100)
        rs = np.random.RandomState(123)
        ids = rs.choice(10000, ds.nb, replace=False).astype('int64')
        index1 = faiss.index_factory(ds.d, factory_key)
        index1.train(ds.get_train())
        index1.add_with_ids(ds.get_database(), ids)
        _, Iref = index1.search(ds.get_queries(), 5)
        index1.reset()
        # cloned after the reset, so index2 is trained but empty
        index2 = faiss.clone_index(index1)
        index1.add_with_ids(ds.get_database()[:100], ids[:100])
        index2.add_with_ids(ds.get_database()[100:], ids[100:])
        index1.merge_from(index2)
        _, Inew = index1.search(ds.get_queries(), 5)
        np.testing.assert_array_equal(Inew, Iref)
        if "IDMap2" in factory_key:
            index1.check_consistency()

    def test_merge_IDMap(self):
        self.do_test_with_ids("Flat,IDMap")

    def test_merge_IDMap2(self):
        self.do_test_with_ids("Flat,IDMap2")