T132029385 support merge for IndexFlatCodes (#2488)

Summary:
support merge for all IndexFlatCodes children
Make merge_from and check_compatible_for_merge virtual methods of Index, and have IndexIVF and IndexFlatCodes (the only supported types) override them.

This is part 1 of 2, as merge_into has not been updated yet.

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2488

Test Plan:
cd build
make -j
make test
cd faiss/python && python setup.py build
cd ../../..
PYTHONPATH="$(ls -d ./build/faiss/python/build/lib*/)" pytest tests/test_*.py

# fbcode

 buck test //faiss/tests/:test_index_merge

buck test //faiss/tests/:test_io

Reviewed By: mdouze

Differential Revision: D39726378

Pulled By: AbdelrahmanElmeniawy

fbshipit-source-id: 6739477fddcad3c7a990f3aae9be07c1b2b74fef
pull/2489/head
Abdelrahman Elmeniawy 2022-09-23 07:19:21 -07:00 committed by Facebook GitHub Bot
parent f00de85645
commit c6c7862089
12 changed files with 297 additions and 223 deletions

View File

@ -168,4 +168,12 @@ DistanceComputer* Index::get_distance_computer() const {
}
}
/** Base-class stub: merging is not supported generically. Subclasses that
 * can merge (IndexIVF, IndexFlatCodes) override this; all others throw. */
void Index::merge_from(Index& /* otherIndex */, idx_t /* add_id */) {
    FAISS_THROW_MSG("merge_from() not implemented");
}
/** Base-class stub: compatibility checking is only meaningful for index
 * types that implement merge_from, which override this; all others throw. */
void Index::check_compatible_for_merge(const Index& /* otherIndex */) const {
    FAISS_THROW_MSG("check_compatible_for_merge() not implemented");
}
} // namespace faiss

View File

@ -252,6 +252,17 @@ struct Index {
* @param x output vectors, size n * d
*/
virtual void sa_decode(idx_t n, const uint8_t* bytes, float* x) const;
/** moves the entries from another dataset to self.
* On output, other is empty.
* add_id is added to all moved ids
* (for sequential ids, this would be this->ntotal) */
virtual void merge_from(Index& otherIndex, idx_t add_id = 0);
/** check that the two indexes are compatible (ie, they are
* trained in the same way and have the same
* parameters). Otherwise throw. */
virtual void check_compatible_for_merge(const Index& otherIndex) const;
};
} // namespace faiss

View File

@ -73,4 +73,28 @@ FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
FAISS_THROW_MSG("not implemented");
}
void IndexFlatCodes::check_compatible_for_merge(const Index& otherIndex) const {
// minimal sanity checks
const IndexFlatCodes* other =
dynamic_cast<const IndexFlatCodes*>(&otherIndex);
FAISS_THROW_IF_NOT(other);
FAISS_THROW_IF_NOT(other->d == d);
FAISS_THROW_IF_NOT(other->code_size == code_size);
FAISS_THROW_IF_NOT_MSG(
typeid(*this) == typeid(*other),
"can only merge indexes of the same type");
}
/** Move all codes from otherIndex into this index; otherIndex is emptied.
 * add_id must be 0 since flat-codes indexes use sequential ids only.
 * Throws (via check_compatible_for_merge) if the indexes are incompatible. */
void IndexFlatCodes::merge_from(Index& otherIndex, idx_t add_id) {
    FAISS_THROW_IF_NOT_MSG(add_id == 0, "cannot set ids in FlatCodes index");
    check_compatible_for_merge(otherIndex);
    // safe: check_compatible_for_merge verified the dynamic type above
    IndexFlatCodes* other = static_cast<IndexFlatCodes*>(&otherIndex);
    codes.resize((ntotal + other->ntotal) * code_size);
    if (other->ntotal > 0) {
        // guard: data() of an empty vector may be null, and memcpy with a
        // null source pointer is undefined behavior even for size 0
        memcpy(codes.data() + (ntotal * code_size),
               other->codes.data(),
               other->ntotal * code_size);
    }
    ntotal += other->ntotal;
    // contract: the source index is left empty after the merge
    other->reset();
}
} // namespace faiss

View File

@ -50,6 +50,10 @@ struct IndexFlatCodes : Index {
DistanceComputer* get_distance_computer() const override {
return get_FlatCodesDistanceComputer();
}
void check_compatible_for_merge(const Index& otherIndex) const override;
virtual void merge_from(Index& otherIndex, idx_t add_id = 0) override;
};
} // namespace faiss

View File

@ -974,26 +974,28 @@ void IndexIVF::train_residual(idx_t /*n*/, const float* /*x*/) {
// does nothing by default
}
void IndexIVF::check_compatible_for_merge(const IndexIVF& other) const {
void IndexIVF::check_compatible_for_merge(const Index& otherIndex) const {
// minimal sanity checks
FAISS_THROW_IF_NOT(other.d == d);
FAISS_THROW_IF_NOT(other.nlist == nlist);
FAISS_THROW_IF_NOT(other.code_size == code_size);
const IndexIVF* other = dynamic_cast<const IndexIVF*>(&otherIndex);
FAISS_THROW_IF_NOT(other);
FAISS_THROW_IF_NOT(other->d == d);
FAISS_THROW_IF_NOT(other->nlist == nlist);
FAISS_THROW_IF_NOT(other->code_size == code_size);
FAISS_THROW_IF_NOT_MSG(
typeid(*this) == typeid(other),
typeid(*this) == typeid(*other),
"can only merge indexes of the same type");
FAISS_THROW_IF_NOT_MSG(
this->direct_map.no() && other.direct_map.no(),
this->direct_map.no() && other->direct_map.no(),
"merge direct_map not implemented");
}
void IndexIVF::merge_from(IndexIVF& other, idx_t add_id) {
check_compatible_for_merge(other);
void IndexIVF::merge_from(Index& otherIndex, idx_t add_id) {
check_compatible_for_merge(otherIndex);
IndexIVF* other = static_cast<IndexIVF*>(&otherIndex);
invlists->merge_from(other->invlists, add_id);
invlists->merge_from(other.invlists, add_id);
ntotal += other.ntotal;
other.ntotal = 0;
ntotal += other->ntotal;
other->ntotal = 0;
}
void IndexIVF::replace_invlists(InvertedLists* il, bool own) {

View File

@ -302,15 +302,9 @@ struct IndexIVF : Index, Level1Quantizer {
size_t remove_ids(const IDSelector& sel) override;
/** check that the two indexes are compatible (ie, they are
* trained in the same way and have the same
* parameters). Otherwise throw. */
void check_compatible_for_merge(const IndexIVF& other) const;
void check_compatible_for_merge(const Index& otherIndex) const override;
/** moves the entries from another dataset to self. On output,
* other is empty. add_id is added to all moved ids (for
* sequential ids, this would be this->ntotal */
virtual void merge_from(IndexIVF& other, idx_t add_id);
virtual void merge_from(Index& otherIndex, idx_t add_id) override;
/** copy a subset of the entries index to the other index
*

View File

@ -201,11 +201,11 @@ void IndexIVFPQR::reconstruct_from_offset(
}
}
void IndexIVFPQR::merge_from(IndexIVF& other_in, idx_t add_id) {
IndexIVFPQR* other = dynamic_cast<IndexIVFPQR*>(&other_in);
void IndexIVFPQR::merge_from(Index& otherIndex, idx_t add_id) {
IndexIVFPQR* other = dynamic_cast<IndexIVFPQR*>(&otherIndex);
FAISS_THROW_IF_NOT(other);
IndexIVF::merge_from(other_in, add_id);
IndexIVF::merge_from(otherIndex, add_id);
refine_codes.insert(
refine_codes.end(),

View File

@ -51,7 +51,7 @@ struct IndexIVFPQR : IndexIVFPQ {
void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
const override;
void merge_from(IndexIVF& other, idx_t add_id) override;
void merge_from(Index& otherIndex, idx_t add_id) override;
void search_preassigned(
idx_t n,

View File

@ -278,3 +278,63 @@ class TestPickle(unittest.TestCase):
def test_ivf(self):
self.dump_load_factory("IVF5,Flat")
class Test_IO_VectorTransform(unittest.TestCase):
    """Round-trip tests for write_VectorTransform / read_VectorTransform,
    mixing the file-name based API with the IOWriter/IOReader based one."""

    def test_write_vector_transform(self):
        """Write using an IOWriter pointer, read back using a file name."""
        d, n = 32, 1000
        x = np.random.uniform(size=(n, d)).astype('float32')
        quantizer = faiss.IndexFlatL2(d)
        # IndexIVFSpectralHash carries a VectorTransform in its .vt field,
        # which is what gets serialized here
        index = faiss.IndexIVFSpectralHash(quantizer, d, n, 8, 1.0)
        index.train(x)
        index.add(x)
        fd, fname = tempfile.mkstemp()
        os.close(fd)
        try:
            writer = faiss.FileIOWriter(fname)
            faiss.write_VectorTransform(index.vt, writer)
            # deleting the writer closes and flushes the file
            del writer

            vt = faiss.read_VectorTransform(fname)

            assert vt.d_in == index.vt.d_in
            assert vt.d_out == index.vt.d_out
            assert vt.is_trained
        finally:
            if os.path.exists(fname):
                os.unlink(fname)

    def test_read_vector_transform(self):
        """Write using a file name, read back using an IOReader pointer."""
        # (this docstring was previously a stray string expression sitting
        # between the two methods; it is now attached to the method)
        d, n = 32, 1000
        x = np.random.uniform(size=(n, d)).astype('float32')
        quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFSpectralHash(quantizer, d, n, 8, 1.0)
        index.train(x)
        index.add(x)
        fd, fname = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_VectorTransform(index.vt, fname)
            reader = faiss.FileIOReader(fname)
            vt = faiss.read_VectorTransform(reader)
            # deleting the reader closes the file
            del reader

            assert vt.d_in == index.vt.d_in
            assert vt.d_out == index.vt.d_out
            assert vt.is_trained
        finally:
            if os.path.exists(fname):
                os.unlink(fname)

View File

@ -1,66 +0,0 @@
import os
import tempfile
import unittest
import numpy as np
import faiss
# NOTE(review): per the diff header (@ -1,66 +0,0) this whole file is being
# deleted; the same tests were moved into tests/test_io.py by this commit.
class Test_IO_VectorTransform(unittest.TestCase):
    """
    test write_VectorTransform using IOWriter Pointer
    and read_VectorTransform using file name
    """
    def test_write_vector_transform(self):
        # small IVFSpectralHash index: its .vt member is the VectorTransform
        # being serialized
        d, n = 32, 1000
        x = np.random.uniform(size=(n, d)).astype('float32')
        quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFSpectralHash(quantizer, d, n, 8, 1.0)
        index.train(x)
        index.add(x)
        fd, fname = tempfile.mkstemp()
        os.close(fd)
        try:
            writer = faiss.FileIOWriter(fname)
            faiss.write_VectorTransform(index.vt, writer)
            # deleting the writer closes and flushes the file
            del writer

            vt = faiss.read_VectorTransform(fname)

            assert vt.d_in == index.vt.d_in
            assert vt.d_out == index.vt.d_out
            assert vt.is_trained
        finally:
            if os.path.exists(fname):
                os.unlink(fname)

    # NOTE(review): this string is a stray expression statement, not the
    # docstring of the method below — it belongs inside
    # test_read_vector_transform.
    """
    test write_VectorTransform using file name
    and read_VectorTransform using IOWriter Pointer
    """
    def test_read_vector_transform(self):
        d, n = 32, 1000
        x = np.random.uniform(size=(n, d)).astype('float32')
        quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFSpectralHash(quantizer, d, n, 8, 1.0)
        index.train(x)
        index.add(x)
        fd, fname = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_VectorTransform(index.vt, fname)
            reader = faiss.FileIOReader(fname)
            vt = faiss.read_VectorTransform(reader)
            # deleting the reader closes the file
            del reader

            assert vt.d_in == index.vt.d_in
            assert vt.d_out == index.vt.d_out
            assert vt.is_trained
        finally:
            if os.path.exists(fname):
                os.unlink(fname)

View File

@ -0,0 +1,169 @@
import unittest
import faiss
import numpy as np
from faiss.contrib.datasets import SyntheticDataset
from common_faiss_tests import Randu10k
# Shared module-level fixture from common_faiss_tests.
ru = Randu10k()

xb = ru.xb  # database vectors
xt = ru.xt  # training vectors
xq = ru.xq  # query vectors

nb, d = xb.shape
nq, d = xq.shape
class Merge(unittest.TestCase):
    """Tests merge_from and remove_ids on IVF-type indexes and IndexIDMap."""

    def make_index_for_merge(self, quant, index_type, master_index):
        """Build an index of the requested type on top of quantizer `quant`.

        If `master_index` is given, its trained parameters (pq / refine_pq)
        are reused so the new index is merge-compatible with it.
        """
        ncent = 40
        if index_type == 1:
            index = faiss.IndexIVFFlat(quant, d, ncent, faiss.METRIC_L2)
            if master_index:
                index.is_trained = True
        elif index_type == 2:
            index = faiss.IndexIVFPQ(quant, d, ncent, 4, 8)
            if master_index:
                index.pq = master_index.pq
                index.is_trained = True
        elif index_type == 3:
            index = faiss.IndexIVFPQR(quant, d, ncent, 4, 8, 8, 8)
            if master_index:
                index.pq = master_index.pq
                index.refine_pq = master_index.refine_pq
                index.is_trained = True
        elif index_type == 4:
            # quant used as the actual index
            index = faiss.IndexIDMap(quant)
        return index

    def do_test_merge(self, index_type):
        """Split the database over 3 indexes, merge them, compare with a
        single reference index built from the whole database."""
        k = 16
        quant = faiss.IndexFlatL2(d)
        ref_index = self.make_index_for_merge(quant, index_type, False)

        # trains the quantizer
        ref_index.train(xt)

        print('ref search')
        ref_index.add(xb)
        _Dref, Iref = ref_index.search(xq, k)
        print(Iref[:5, :6])

        indexes = []
        ni = 3
        for i in range(ni):
            i0 = int(i * nb / ni)
            i1 = int((i + 1) * nb / ni)
            index = self.make_index_for_merge(quant, index_type, ref_index)
            index.is_trained = True
            index.add(xb[i0:i1])
            indexes.append(index)

        index = indexes[0]

        for i in range(1, ni):
            print('merge ntotal=%d other.ntotal=%d ' % (
                index.ntotal, indexes[i].ntotal))
            index.merge_from(indexes[i], index.ntotal)

        _D, I = index.search(xq, k)
        print(I[:5, :6])
        ndiff = (I != Iref).sum()

        print('%d / %d differences' % (ndiff, nq * k))
        # allow a tiny number of differences due to ties
        assert (ndiff < nq * k / 1000.)

    def test_merge(self):
        self.do_test_merge(1)
        self.do_test_merge(2)
        self.do_test_merge(3)

    def do_test_remove(self, index_type):
        """Remove all even result ids and check the remaining results keep
        their relative order and distances."""
        k = 16
        quant = faiss.IndexFlatL2(d)
        index = self.make_index_for_merge(quant, index_type, None)

        # trains the quantizer
        index.train(xt)

        if index_type < 4:
            index.add(xb)
        else:
            # IndexIDMap needs explicit ids
            gen = np.random.RandomState(1234)
            id_list = gen.permutation(nb * 7)[:nb].astype('int64')
            index.add_with_ids(xb, id_list)

        print('ref search ntotal=%d' % index.ntotal)
        Dref, Iref = index.search(xq, k)

        toremove = np.zeros(nq * k, dtype='int64')
        nr = 0
        for i in range(nq):
            for j in range(k):
                # remove all even results (it's ok if there are duplicates
                # in the list of ids)
                if Iref[i, j] % 2 == 0:
                    # BUGFIX: store the id first, then advance the count.
                    # The previous code incremented nr before writing, which
                    # left toremove[0] unused and stored the last id one slot
                    # past the nr entries handed to IDSelectorBatch.
                    toremove[nr] = Iref[i, j]
                    nr = nr + 1
        print('nr=', nr)

        idsel = faiss.IDSelectorBatch(
            nr, faiss.swig_ptr(toremove))
        for i in range(nr):
            assert (idsel.is_member(int(toremove[i])))
        nremoved = index.remove_ids(idsel)

        print('nremoved=%d ntotal=%d' % (nremoved, index.ntotal))
        D, I = index.search(xq, k)

        # make sure results are in the same order with even ones removed
        ndiff = 0
        for i in range(nq):
            j2 = 0
            for j in range(k):
                if Iref[i, j] % 2 != 0:
                    if I[i, j2] != Iref[i, j]:
                        ndiff += 1
                    assert abs(D[i, j2] - Dref[i, j]) < 1e-5
                    j2 += 1
        # draws are ordered arbitrarily
        assert ndiff < 5

    def test_remove(self):
        self.do_test_remove(1)
        self.do_test_remove(2)
        self.do_test_remove(4)
# Test merge_from method for all IndexFlatCodes Types
class IndexFlatCodes_merge(unittest.TestCase):
    """Merging two halves of a dataset must give the same search results as
    adding the whole dataset to one index."""

    def do_flat_codes_test(self, index):
        """index: factory string for the IndexFlatCodes subclass to test."""
        ds = SyntheticDataset(32, 300, 300, 100)
        # reference: train and add everything to one index
        full = faiss.index_factory(ds.d, index)
        full.train(ds.get_train())
        full.add(ds.get_database())
        _, Iref = full.search(ds.get_queries(), 5)
        # empty the index but keep its training, then split the database
        # over this index and a clone
        full.reset()
        other = faiss.clone_index(full)
        full.add(ds.get_database()[:100])
        other.add(ds.get_database()[100:])
        full.merge_from(other)
        _, Inew = full.search(ds.get_queries(), 5)
        np.testing.assert_array_equal(Inew, Iref)

    def test_merge_IndexFlat(self):
        self.do_flat_codes_test("Flat")

    def test_merge_IndexPQ(self):
        self.do_flat_codes_test("PQ8")

    def test_merge_IndexLSH(self):
        self.do_flat_codes_test("LSHr")

    def test_merge_IndexScalarQuantizer(self):
        self.do_flat_codes_test("SQ4")

View File

@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.
import os
import sys
import numpy as np
import faiss
import unittest
@ -129,138 +128,7 @@ class Shards(unittest.TestCase):
ndiff = (I != Iref).sum()
print('%d / %d differences' % (ndiff, nq * k))
assert(ndiff < nq * k / 1000.)
# NOTE(review): this is the pre-move copy of the Merge tests, removed from
# tests/test_index.py by this commit (relocated to tests/test_index_merge.py).
class Merge(unittest.TestCase):
    # Tests merge_from and remove_ids on IVF-type indexes and IndexIDMap.

    def make_index_for_merge(self, quant, index_type, master_index):
        # Build an index of the requested type; if master_index is given,
        # reuse its trained parameters (pq / refine_pq) so the new index is
        # merge-compatible with it.
        ncent = 40
        if index_type == 1:
            index = faiss.IndexIVFFlat(quant, d, ncent, faiss.METRIC_L2)
            if master_index:
                index.is_trained = True
        elif index_type == 2:
            index = faiss.IndexIVFPQ(quant, d, ncent, 4, 8)
            if master_index:
                index.pq = master_index.pq
                index.is_trained = True
        elif index_type == 3:
            index = faiss.IndexIVFPQR(quant, d, ncent, 4, 8, 8, 8)
            if master_index:
                index.pq = master_index.pq
                index.refine_pq = master_index.refine_pq
                index.is_trained = True
        elif index_type == 4:
            # quant used as the actual index
            index = faiss.IndexIDMap(quant)
        return index

    def do_test_merge(self, index_type):
        # Split the database over 3 indexes, merge them, and compare with a
        # reference index built from the whole database.
        k = 16
        quant = faiss.IndexFlatL2(d)
        ref_index = self.make_index_for_merge(quant, index_type, False)

        # trains the quantizer
        ref_index.train(xt)

        print('ref search')
        ref_index.add(xb)
        _Dref, Iref = ref_index.search(xq, k)
        print(Iref[:5, :6])

        indexes = []
        ni = 3
        for i in range(ni):
            i0 = int(i * nb / ni)
            i1 = int((i + 1) * nb / ni)
            index = self.make_index_for_merge(quant, index_type, ref_index)
            index.is_trained = True
            index.add(xb[i0:i1])
            indexes.append(index)

        index = indexes[0]

        for i in range(1, ni):
            print('merge ntotal=%d other.ntotal=%d ' % (
                index.ntotal, indexes[i].ntotal))
            index.merge_from(indexes[i], index.ntotal)

        _D, I = index.search(xq, k)
        print(I[:5, :6])
        ndiff = (I != Iref).sum()

        print('%d / %d differences' % (ndiff, nq * k))
        assert(ndiff < nq * k / 1000.)

    def test_merge(self):
        self.do_test_merge(1)
        self.do_test_merge(2)
        self.do_test_merge(3)

    def do_test_remove(self, index_type):
        # Remove all even result ids and check the remaining results keep
        # their relative order and distances.
        k = 16
        quant = faiss.IndexFlatL2(d)
        index = self.make_index_for_merge(quant, index_type, None)

        # trains the quantizer
        index.train(xt)

        if index_type < 4:
            index.add(xb)
        else:
            # IndexIDMap needs explicit ids
            gen = np.random.RandomState(1234)
            id_list = gen.permutation(nb * 7)[:nb].astype('int64')
            index.add_with_ids(xb, id_list)

        print('ref search ntotal=%d' % index.ntotal)
        Dref, Iref = index.search(xq, k)

        toremove = np.zeros(nq * k, dtype='int64')
        nr = 0
        for i in range(nq):
            for j in range(k):
                # remove all even results (it's ok if there are duplicates
                # in the list of ids)
                if Iref[i, j] % 2 == 0:
                    # NOTE(review): nr is incremented before the store, so
                    # toremove[0] stays unused and the last id lands one slot
                    # past the nr entries passed to IDSelectorBatch — looks
                    # like an off-by-one; confirm against the relocated copy.
                    nr = nr + 1
                    toremove[nr] = Iref[i, j]
        print('nr=', nr)

        idsel = faiss.IDSelectorBatch(
            nr, faiss.swig_ptr(toremove))
        for i in range(nr):
            assert(idsel.is_member(int(toremove[i])))
        nremoved = index.remove_ids(idsel)

        print('nremoved=%d ntotal=%d' % (nremoved, index.ntotal))
        D, I = index.search(xq, k)

        # make sure results are in the same order with even ones removed
        ndiff = 0
        for i in range(nq):
            j2 = 0
            for j in range(k):
                if Iref[i, j] % 2 != 0:
                    if I[i, j2] != Iref[i, j]:
                        ndiff += 1
                    assert abs(D[i, j2] - Dref[i, j]) < 1e-5
                    j2 += 1
        # draws are ordered arbitrarily
        assert ndiff < 5

    def test_remove(self):
        self.do_test_remove(1)
        self.do_test_remove(2)
        self.do_test_remove(4)
assert (ndiff < nq * k / 1000.)
if __name__ == '__main__':