T132029385 support merge for IndexFlatCodes (#2488)
Summary: Support merge for all IndexFlatCodes children. Make merge_from and check_compatible_for_merge virtual methods of Index, so that IndexIVF and IndexFlatCodes (the only supported types) inherit them from Index and override them. This is part 1 of 2, as merge_into is still not updated.
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2488
Test Plan:
cd build
make -j
make test
cd faiss/python && python setup.py build
cd ../../..
PYTHONPATH="$(ls -d ./build/faiss/python/build/lib*/)" pytest tests/test_*.py
# fbcode
buck test //faiss/tests/:test_index_merge
buck test //faiss/tests/:test_io
Reviewed By: mdouze
Differential Revision: D39726378
Pulled By: AbdelrahmanElmeniawy
fbshipit-source-id: 6739477fddcad3c7a990f3aae9be07c1b2b74fef
pull/2489/head
parent
f00de85645
commit
c6c7862089
|
@ -168,4 +168,12 @@ DistanceComputer* Index::get_distance_computer() const {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Base-class fallback for merge_from().
 *
 * The generic Index has no storage it knows how to merge, so this
 * implementation unconditionally throws; index types that support merging
 * override it. Both parameters are intentionally unnamed because they are
 * never read here.
 */
void Index::merge_from(Index& /* otherIndex */, idx_t /* add_id */) {
    FAISS_THROW_MSG("merge_from() not implemented");
}
|
||||
|
||||
/**
 * Base-class fallback for check_compatible_for_merge().
 *
 * Always throws: compatibility can only be decided by index types that
 * actually implement merging and override this method. The argument is
 * intentionally unnamed because it is never read here.
 */
void Index::check_compatible_for_merge(const Index& /* otherIndex */) const {
    FAISS_THROW_MSG("check_compatible_for_merge() not implemented");
}
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -252,6 +252,17 @@ struct Index {
|
|||
* @param x output vectors, size n * d
|
||||
*/
|
||||
virtual void sa_decode(idx_t n, const uint8_t* bytes, float* x) const;
|
||||
|
||||
/** Moves the entries from another index into this one.
 *
 * On output, otherIndex is empty.
 *
 * The implementation provided by Index itself always throws
 * ("merge_from() not implemented"); index types that support merging
 * override this method.
 *
 * @param otherIndex  index whose entries are moved into this one
 * @param add_id      value added to all moved ids
 *                    (for sequential ids, this would be this->ntotal)
 */
|
||||
virtual void merge_from(Index& otherIndex, idx_t add_id = 0);
|
||||
|
||||
/** Checks that the two indexes are compatible for merging (i.e. they
 * are trained in the same way and have the same parameters).
 * Throws otherwise.
 *
 * The implementation provided by Index itself always throws
 * ("check_compatible_for_merge() not implemented"); index types that
 * support merging override this method.
 *
 * @param otherIndex  index to compare against this one
 */
|
||||
virtual void check_compatible_for_merge(const Index& otherIndex) const;
|
||||
};
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -73,4 +73,28 @@ FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
|
|||
FAISS_THROW_MSG("not implemented");
|
||||
}
|
||||
|
||||
void IndexFlatCodes::check_compatible_for_merge(const Index& otherIndex) const {
|
||||
// minimal sanity checks
|
||||
const IndexFlatCodes* other =
|
||||
dynamic_cast<const IndexFlatCodes*>(&otherIndex);
|
||||
FAISS_THROW_IF_NOT(other);
|
||||
FAISS_THROW_IF_NOT(other->d == d);
|
||||
FAISS_THROW_IF_NOT(other->code_size == code_size);
|
||||
FAISS_THROW_IF_NOT_MSG(
|
||||
typeid(*this) == typeid(*other),
|
||||
"can only merge indexes of the same type");
|
||||
}
|
||||
|
||||
/**
 * Move all codes of otherIndex to the end of this index.
 *
 * On return this index holds its previous codes followed by those of
 * otherIndex, ntotal is increased accordingly and otherIndex has been
 * reset().
 *
 * @param otherIndex  index to take the codes from; must pass
 *                    check_compatible_for_merge() and must not be this
 *                    index itself
 * @param add_id      must be 0: a flat-codes index has no explicit ids to
 *                    shift
 * @throws if add_id != 0, if otherIndex is this index, or if the two
 *         indexes are not compatible
 */
void IndexFlatCodes::merge_from(Index& otherIndex, idx_t add_id) {
    FAISS_THROW_IF_NOT_MSG(add_id == 0, "cannot set ids in FlatCodes index");

    // Merging an index into itself would first append its own codes and
    // then erase everything through the final reset(); refuse it instead
    // of silently losing the data.
    FAISS_THROW_IF_NOT_MSG(
            &otherIndex != this, "cannot merge an index into itself");

    check_compatible_for_merge(otherIndex);

    // The compatibility check above has already rejected arguments that
    // are not IndexFlatCodes, so the unchecked downcast is fine here.
    IndexFlatCodes* other = static_cast<IndexFlatCodes*>(&otherIndex);

    codes.resize((ntotal + other->ntotal) * code_size);

    // An empty vector may report a null data() pointer, and passing a null
    // pointer to memcpy is undefined even for a zero-byte copy.
    if (!other->codes.empty()) {
        memcpy(codes.data() + (ntotal * code_size),
               other->codes.data(),
               other->ntotal * code_size);
    }

    ntotal += other->ntotal;

    // Leave the source index empty, as promised by the merge_from contract.
    other->reset();
}
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -50,6 +50,10 @@ struct IndexFlatCodes : Index {
|
|||
DistanceComputer* get_distance_computer() const override {
|
||||
return get_FlatCodesDistanceComputer();
|
||||
}
|
||||
|
||||
/// Throws unless otherIndex is an IndexFlatCodes with the same d, the same
/// code_size and the same dynamic type as this index.
void check_compatible_for_merge(const Index& otherIndex) const override;
|
||||
|
||||
/// Appends the codes of otherIndex to this index and resets otherIndex.
/// add_id must be 0 (throws "cannot set ids in FlatCodes index" otherwise).
virtual void merge_from(Index& otherIndex, idx_t add_id = 0) override;
|
||||
};
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -974,26 +974,28 @@ void IndexIVF::train_residual(idx_t /*n*/, const float* /*x*/) {
|
|||
// does nothing by default
|
||||
}
|
||||
|
||||
void IndexIVF::check_compatible_for_merge(const IndexIVF& other) const {
|
||||
void IndexIVF::check_compatible_for_merge(const Index& otherIndex) const {
|
||||
// minimal sanity checks
|
||||
FAISS_THROW_IF_NOT(other.d == d);
|
||||
FAISS_THROW_IF_NOT(other.nlist == nlist);
|
||||
FAISS_THROW_IF_NOT(other.code_size == code_size);
|
||||
const IndexIVF* other = dynamic_cast<const IndexIVF*>(&otherIndex);
|
||||
FAISS_THROW_IF_NOT(other);
|
||||
FAISS_THROW_IF_NOT(other->d == d);
|
||||
FAISS_THROW_IF_NOT(other->nlist == nlist);
|
||||
FAISS_THROW_IF_NOT(other->code_size == code_size);
|
||||
FAISS_THROW_IF_NOT_MSG(
|
||||
typeid(*this) == typeid(other),
|
||||
typeid(*this) == typeid(*other),
|
||||
"can only merge indexes of the same type");
|
||||
FAISS_THROW_IF_NOT_MSG(
|
||||
this->direct_map.no() && other.direct_map.no(),
|
||||
this->direct_map.no() && other->direct_map.no(),
|
||||
"merge direct_map not implemented");
|
||||
}
|
||||
|
||||
void IndexIVF::merge_from(IndexIVF& other, idx_t add_id) {
|
||||
check_compatible_for_merge(other);
|
||||
void IndexIVF::merge_from(Index& otherIndex, idx_t add_id) {
|
||||
check_compatible_for_merge(otherIndex);
|
||||
IndexIVF* other = static_cast<IndexIVF*>(&otherIndex);
|
||||
invlists->merge_from(other->invlists, add_id);
|
||||
|
||||
invlists->merge_from(other.invlists, add_id);
|
||||
|
||||
ntotal += other.ntotal;
|
||||
other.ntotal = 0;
|
||||
ntotal += other->ntotal;
|
||||
other->ntotal = 0;
|
||||
}
|
||||
|
||||
void IndexIVF::replace_invlists(InvertedLists* il, bool own) {
|
||||
|
|
|
@ -302,15 +302,9 @@ struct IndexIVF : Index, Level1Quantizer {
|
|||
|
||||
size_t remove_ids(const IDSelector& sel) override;
|
||||
|
||||
/** check that the two indexes are compatible (ie, they are
|
||||
* trained in the same way and have the same
|
||||
* parameters). Otherwise throw. */
|
||||
void check_compatible_for_merge(const IndexIVF& other) const;
|
||||
void check_compatible_for_merge(const Index& otherIndex) const override;
|
||||
|
||||
/** moves the entries from another dataset to self. On output,
|
||||
* other is empty. add_id is added to all moved ids (for
|
||||
* sequential ids, this would be this->ntotal */
|
||||
virtual void merge_from(IndexIVF& other, idx_t add_id);
|
||||
virtual void merge_from(Index& otherIndex, idx_t add_id) override;
|
||||
|
||||
/** copy a subset of the entries index to the other index
|
||||
*
|
||||
|
|
|
@ -201,11 +201,11 @@ void IndexIVFPQR::reconstruct_from_offset(
|
|||
}
|
||||
}
|
||||
|
||||
void IndexIVFPQR::merge_from(IndexIVF& other_in, idx_t add_id) {
|
||||
IndexIVFPQR* other = dynamic_cast<IndexIVFPQR*>(&other_in);
|
||||
void IndexIVFPQR::merge_from(Index& otherIndex, idx_t add_id) {
|
||||
IndexIVFPQR* other = dynamic_cast<IndexIVFPQR*>(&otherIndex);
|
||||
FAISS_THROW_IF_NOT(other);
|
||||
|
||||
IndexIVF::merge_from(other_in, add_id);
|
||||
IndexIVF::merge_from(otherIndex, add_id);
|
||||
|
||||
refine_codes.insert(
|
||||
refine_codes.end(),
|
||||
|
|
|
@ -51,7 +51,7 @@ struct IndexIVFPQR : IndexIVFPQ {
|
|||
void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
|
||||
const override;
|
||||
|
||||
void merge_from(IndexIVF& other, idx_t add_id) override;
|
||||
void merge_from(Index& otherIndex, idx_t add_id) override;
|
||||
|
||||
void search_preassigned(
|
||||
idx_t n,
|
||||
|
|
|
@ -278,3 +278,63 @@ class TestPickle(unittest.TestCase):
|
|||
|
||||
def test_ivf(self):
|
||||
self.dump_load_factory("IVF5,Flat")
|
||||
|
||||
|
||||
class Test_IO_VectorTransform(unittest.TestCase):
    """Round-trip tests for write_VectorTransform / read_VectorTransform.

    Each test serializes the VectorTransform of a trained
    IndexIVFSpectralHash through one API flavour (file name or IO object)
    and reads it back through the other one.
    """

    def _make_trained_index(self):
        """Build, train and fill an IndexIVFSpectralHash on random data.

        Returns (quantizer, index). The quantizer is returned as well so
        that the caller keeps a Python reference to it for as long as the
        index is in use.
        """
        d, n = 32, 1000
        x = np.random.uniform(size=(n, d)).astype('float32')
        quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFSpectralHash(quantizer, d, n, 8, 1.0)
        index.train(x)
        index.add(x)
        return quantizer, index

    def _check_same_transform(self, vt, ref_vt):
        """Check that the transform read back matches the reference one."""
        self.assertEqual(vt.d_in, ref_vt.d_in)
        self.assertEqual(vt.d_out, ref_vt.d_out)
        self.assertTrue(vt.is_trained)

    def test_write_vector_transform(self):
        """
        test write_VectorTransform using IOWriter Pointer
        and read_VectorTransform using file name
        """
        _quantizer, index = self._make_trained_index()
        fd, fname = tempfile.mkstemp()
        os.close(fd)
        try:
            writer = faiss.FileIOWriter(fname)
            faiss.write_VectorTransform(index.vt, writer)
            # drop the writer before the file is read back by name
            del writer

            vt = faiss.read_VectorTransform(fname)

            self._check_same_transform(vt, index.vt)
        finally:
            if os.path.exists(fname):
                os.unlink(fname)

    def test_read_vector_transform(self):
        """
        test write_VectorTransform using file name
        and read_VectorTransform using IOReader Pointer
        """
        _quantizer, index = self._make_trained_index()
        fd, fname = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_VectorTransform(index.vt, fname)

            reader = faiss.FileIOReader(fname)
            vt = faiss.read_VectorTransform(reader)
            # drop the reader before the temporary file is removed
            del reader

            self._check_same_transform(vt, index.vt)
        finally:
            if os.path.exists(fname):
                os.unlink(fname)
|
||||
|
|
|
@ -1,66 +0,0 @@
|
|||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import faiss
|
||||
|
||||
|
||||
class Test_IO_VectorTransform(unittest.TestCase):
|
||||
"""
|
||||
test write_VectorTransform using IOWriter Pointer
|
||||
and read_VectorTransform using file name
|
||||
"""
|
||||
def test_write_vector_transform(self):
|
||||
d, n = 32, 1000
|
||||
x = np.random.uniform(size=(n, d)).astype('float32')
|
||||
quantizer = faiss.IndexFlatL2(d)
|
||||
index = faiss.IndexIVFSpectralHash(quantizer, d, n, 8, 1.0)
|
||||
index.train(x)
|
||||
index.add(x)
|
||||
fd, fname = tempfile.mkstemp()
|
||||
os.close(fd)
|
||||
try:
|
||||
|
||||
writer = faiss.FileIOWriter(fname)
|
||||
faiss.write_VectorTransform(index.vt, writer)
|
||||
del writer
|
||||
|
||||
vt = faiss.read_VectorTransform(fname)
|
||||
|
||||
assert vt.d_in == index.vt.d_in
|
||||
assert vt.d_out == index.vt.d_out
|
||||
assert vt.is_trained
|
||||
|
||||
finally:
|
||||
if os.path.exists(fname):
|
||||
os.unlink(fname)
|
||||
|
||||
"""
|
||||
test write_VectorTransform using file name
|
||||
and read_VectorTransform using IOWriter Pointer
|
||||
"""
|
||||
def test_read_vector_transform(self):
|
||||
d, n = 32, 1000
|
||||
x = np.random.uniform(size=(n, d)).astype('float32')
|
||||
quantizer = faiss.IndexFlatL2(d)
|
||||
index = faiss.IndexIVFSpectralHash(quantizer, d, n, 8, 1.0)
|
||||
index.train(x)
|
||||
index.add(x)
|
||||
fd, fname = tempfile.mkstemp()
|
||||
os.close(fd)
|
||||
try:
|
||||
|
||||
faiss.write_VectorTransform(index.vt, fname)
|
||||
|
||||
reader = faiss.FileIOReader(fname)
|
||||
vt = faiss.read_VectorTransform(reader)
|
||||
del reader
|
||||
|
||||
assert vt.d_in == index.vt.d_in
|
||||
assert vt.d_out == index.vt.d_out
|
||||
assert vt.is_trained
|
||||
|
||||
finally:
|
||||
if os.path.exists(fname):
|
||||
os.unlink(fname)
|
|
@ -0,0 +1,169 @@
|
|||
import unittest
|
||||
import faiss
|
||||
import numpy as np
|
||||
from faiss.contrib.datasets import SyntheticDataset
|
||||
|
||||
from common_faiss_tests import Randu10k
|
||||
|
||||
# Shared dataset for the tests in this file: database (xb), training (xt)
# and query (xq) vectors taken from one Randu10k instance.
ru = Randu10k()
xb, xt, xq = ru.xb, ru.xt, ru.xq

# Number of database / query vectors; d is rebound by the second unpacking,
# so it ends up holding the second dimension of xq.
nb, d = xb.shape
nq, d = xq.shape
|
||||
|
||||
|
||||
class Merge(unittest.TestCase):
    """Tests for merge_from() and remove_ids().

    Uses the module-level arrays xb, xt, xq and their sizes nb, nq, d.
    """

    def make_index_for_merge(self, quant, index_type, master_index):
        """Create an index of the requested kind on top of `quant`.

        index_type selects the class:
          1 -> IndexIVFFlat, 2 -> IndexIVFPQ, 3 -> IndexIVFPQR,
          4 -> IndexIDMap (quant is then used as the actual index).

        When master_index is truthy, the new IVF index is marked as
        trained and, for the PQ variants, takes over the pq (and
        refine_pq) of master_index instead of being trained itself.

        Raises ValueError for any other index_type.
        """
        ncent = 40
        if index_type == 1:
            index = faiss.IndexIVFFlat(quant, d, ncent, faiss.METRIC_L2)
            if master_index:
                index.is_trained = True
        elif index_type == 2:
            index = faiss.IndexIVFPQ(quant, d, ncent, 4, 8)
            if master_index:
                index.pq = master_index.pq
                index.is_trained = True
        elif index_type == 3:
            index = faiss.IndexIVFPQR(quant, d, ncent, 4, 8, 8, 8)
            if master_index:
                index.pq = master_index.pq
                index.refine_pq = master_index.refine_pq
                index.is_trained = True
        elif index_type == 4:
            # quant used as the actual index
            index = faiss.IndexIDMap(quant)
        else:
            # fail with a clear message instead of an unbound local below
            raise ValueError('unsupported index_type: %r' % (index_type,))
        return index

    def do_test_merge(self, index_type):
        """Merging ni partial indexes must give (almost) the same search
        results as one reference index built on the whole database."""
        k = 16
        quant = faiss.IndexFlatL2(d)
        ref_index = self.make_index_for_merge(quant, index_type, False)

        # trains the quantizer
        ref_index.train(xt)

        print('ref search')
        ref_index.add(xb)
        _Dref, Iref = ref_index.search(xq, k)
        print(Iref[:5, :6])

        # build ni indexes, each holding one contiguous slice of xb and
        # sharing the trained quantizer of the reference index
        indexes = []
        ni = 3
        for i in range(ni):
            i0 = int(i * nb / ni)
            i1 = int((i + 1) * nb / ni)
            index = self.make_index_for_merge(quant, index_type, ref_index)
            index.is_trained = True
            index.add(xb[i0:i1])
            indexes.append(index)

        index = indexes[0]

        for i in range(1, ni):
            print('merge ntotal=%d other.ntotal=%d ' % (
                index.ntotal, indexes[i].ntotal))
            # each sub-index numbered its vectors from 0, so shift the
            # merged ids by the current size of the destination
            index.merge_from(indexes[i], index.ntotal)

        _D, I = index.search(xq, k)
        print(I[:5, :6])

        ndiff = (I != Iref).sum()
        print('%d / %d differences' % (ndiff, nq * k))
        self.assertLess(ndiff, nq * k / 1000.)

    def test_merge(self):
        self.do_test_merge(1)
        self.do_test_merge(2)
        self.do_test_merge(3)

    def do_test_remove(self, index_type):
        """After removing every even id found in the reference results,
        the remaining (odd) results must come back in the same order."""
        k = 16
        quant = faiss.IndexFlatL2(d)
        index = self.make_index_for_merge(quant, index_type, None)

        # trains the quantizer
        index.train(xt)

        if index_type < 4:
            index.add(xb)
        else:
            gen = np.random.RandomState(1234)
            id_list = gen.permutation(nb * 7)[:nb].astype('int64')
            index.add_with_ids(xb, id_list)

        print('ref search ntotal=%d' % index.ntotal)
        Dref, Iref = index.search(xq, k)

        # remove all even results (it's ok if there are duplicates
        # in the list of ids). Selecting them with a mask stores the ids
        # in slots 0..nr-1, so every one of them reaches the selector.
        toremove = np.ascontiguousarray(Iref[Iref % 2 == 0], dtype='int64')
        nr = toremove.size

        print('nr=', nr)

        idsel = faiss.IDSelectorBatch(
            nr, faiss.swig_ptr(toremove))

        for i in range(nr):
            self.assertTrue(idsel.is_member(int(toremove[i])))

        nremoved = index.remove_ids(idsel)

        print('nremoved=%d ntotal=%d' % (nremoved, index.ntotal))

        D, I = index.search(xq, k)

        # make sure results are in the same order with even ones removed
        ndiff = 0
        for i in range(nq):
            j2 = 0
            for j in range(k):
                if Iref[i, j] % 2 != 0:
                    if I[i, j2] != Iref[i, j]:
                        ndiff += 1
                    assert abs(D[i, j2] - Dref[i, j]) < 1e-5
                    j2 += 1
        # draws are ordered arbitrarily
        self.assertLess(ndiff, 5)

    def test_remove(self):
        self.do_test_remove(1)
        self.do_test_remove(2)
        self.do_test_remove(4)
|
||||
|
||||
|
||||
# Test merge_from method for all IndexFlatCodes Types
|
||||
class IndexFlatCodes_merge(unittest.TestCase):
    """merge_from() on flat-codes indexes built through index_factory."""

    def do_flat_codes_test(self, index):
        """Check that splitting the database over two indexes and merging
        them gives the same search results as a single index.

        `index` is the index_factory description string of the index type
        under test.
        """
        ds = SyntheticDataset(32, 300, 300, 100)
        database = ds.get_database()
        queries = ds.get_queries()

        # reference results: the whole database in one index
        merged = faiss.index_factory(ds.d, index)
        merged.train(ds.get_train())
        merged.add(database)
        _, Iref = merged.search(queries, 5)

        # empty the trained index and clone it, so that both halves share
        # the same training
        merged.reset()
        second_half = faiss.clone_index(merged)
        merged.add(database[:100])
        second_half.add(database[100:])

        merged.merge_from(second_half)
        _, Inew = merged.search(queries, 5)
        np.testing.assert_array_equal(Inew, Iref)

    def test_merge_IndexFlat(self):
        self.do_flat_codes_test("Flat")

    def test_merge_IndexPQ(self):
        self.do_flat_codes_test("PQ8")

    def test_merge_IndexLSH(self):
        self.do_flat_codes_test("LSHr")

    def test_merge_IndexScalarQuantizer(self):
        self.do_flat_codes_test("SQ4")
|
|
@ -4,7 +4,6 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
import faiss
|
||||
import unittest
|
||||
|
@ -129,138 +128,7 @@ class Shards(unittest.TestCase):
|
|||
ndiff = (I != Iref).sum()
|
||||
|
||||
print('%d / %d differences' % (ndiff, nq * k))
|
||||
assert(ndiff < nq * k / 1000.)
|
||||
|
||||
|
||||
class Merge(unittest.TestCase):
|
||||
|
||||
def make_index_for_merge(self, quant, index_type, master_index):
|
||||
ncent = 40
|
||||
if index_type == 1:
|
||||
index = faiss.IndexIVFFlat(quant, d, ncent, faiss.METRIC_L2)
|
||||
if master_index:
|
||||
index.is_trained = True
|
||||
elif index_type == 2:
|
||||
index = faiss.IndexIVFPQ(quant, d, ncent, 4, 8)
|
||||
if master_index:
|
||||
index.pq = master_index.pq
|
||||
index.is_trained = True
|
||||
elif index_type == 3:
|
||||
index = faiss.IndexIVFPQR(quant, d, ncent, 4, 8, 8, 8)
|
||||
if master_index:
|
||||
index.pq = master_index.pq
|
||||
index.refine_pq = master_index.refine_pq
|
||||
index.is_trained = True
|
||||
elif index_type == 4:
|
||||
# quant used as the actual index
|
||||
index = faiss.IndexIDMap(quant)
|
||||
return index
|
||||
|
||||
def do_test_merge(self, index_type):
|
||||
k = 16
|
||||
quant = faiss.IndexFlatL2(d)
|
||||
ref_index = self.make_index_for_merge(quant, index_type, False)
|
||||
|
||||
# trains the quantizer
|
||||
ref_index.train(xt)
|
||||
|
||||
print('ref search')
|
||||
ref_index.add(xb)
|
||||
_Dref, Iref = ref_index.search(xq, k)
|
||||
print(Iref[:5, :6])
|
||||
|
||||
indexes = []
|
||||
ni = 3
|
||||
for i in range(ni):
|
||||
i0 = int(i * nb / ni)
|
||||
i1 = int((i + 1) * nb / ni)
|
||||
index = self.make_index_for_merge(quant, index_type, ref_index)
|
||||
index.is_trained = True
|
||||
index.add(xb[i0:i1])
|
||||
indexes.append(index)
|
||||
|
||||
index = indexes[0]
|
||||
|
||||
for i in range(1, ni):
|
||||
print('merge ntotal=%d other.ntotal=%d ' % (
|
||||
index.ntotal, indexes[i].ntotal))
|
||||
index.merge_from(indexes[i], index.ntotal)
|
||||
|
||||
_D, I = index.search(xq, k)
|
||||
print(I[:5, :6])
|
||||
|
||||
ndiff = (I != Iref).sum()
|
||||
print('%d / %d differences' % (ndiff, nq * k))
|
||||
assert(ndiff < nq * k / 1000.)
|
||||
|
||||
def test_merge(self):
|
||||
self.do_test_merge(1)
|
||||
self.do_test_merge(2)
|
||||
self.do_test_merge(3)
|
||||
|
||||
def do_test_remove(self, index_type):
|
||||
k = 16
|
||||
quant = faiss.IndexFlatL2(d)
|
||||
index = self.make_index_for_merge(quant, index_type, None)
|
||||
|
||||
# trains the quantizer
|
||||
index.train(xt)
|
||||
|
||||
if index_type < 4:
|
||||
index.add(xb)
|
||||
else:
|
||||
gen = np.random.RandomState(1234)
|
||||
id_list = gen.permutation(nb * 7)[:nb].astype('int64')
|
||||
index.add_with_ids(xb, id_list)
|
||||
|
||||
print('ref search ntotal=%d' % index.ntotal)
|
||||
Dref, Iref = index.search(xq, k)
|
||||
|
||||
toremove = np.zeros(nq * k, dtype='int64')
|
||||
nr = 0
|
||||
for i in range(nq):
|
||||
for j in range(k):
|
||||
# remove all even results (it's ok if there are duplicates
|
||||
# in the list of ids)
|
||||
if Iref[i, j] % 2 == 0:
|
||||
nr = nr + 1
|
||||
toremove[nr] = Iref[i, j]
|
||||
|
||||
print('nr=', nr)
|
||||
|
||||
idsel = faiss.IDSelectorBatch(
|
||||
nr, faiss.swig_ptr(toremove))
|
||||
|
||||
for i in range(nr):
|
||||
assert(idsel.is_member(int(toremove[i])))
|
||||
|
||||
nremoved = index.remove_ids(idsel)
|
||||
|
||||
print('nremoved=%d ntotal=%d' % (nremoved, index.ntotal))
|
||||
|
||||
D, I = index.search(xq, k)
|
||||
|
||||
# make sure results are in the same order with even ones removed
|
||||
ndiff = 0
|
||||
for i in range(nq):
|
||||
j2 = 0
|
||||
for j in range(k):
|
||||
if Iref[i, j] % 2 != 0:
|
||||
if I[i, j2] != Iref[i, j]:
|
||||
ndiff += 1
|
||||
assert abs(D[i, j2] - Dref[i, j]) < 1e-5
|
||||
j2 += 1
|
||||
# draws are ordered arbitrarily
|
||||
assert ndiff < 5
|
||||
|
||||
def test_remove(self):
|
||||
self.do_test_remove(1)
|
||||
self.do_test_remove(2)
|
||||
self.do_test_remove(4)
|
||||
|
||||
|
||||
|
||||
|
||||
assert (ndiff < nq * k / 1000.)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue