add remove and merge features for IndexFastScan (#2497)
Summary: * Modify pq4_get_paked_element to make it not depend on an auxiliary table * Create pq4_set_packed_element which sets a single element in codes in packed format (These methods would be used in merge and remove for IndexFastScan get method is also used in FastScan indices for reconstruction) * Add remove feature for IndexFastScan * Add merge feature for indexFast Scan Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2497 Test Plan: cd build && make -j make test cd faiss/python && python setup.py build && cd ../../.. PYTHONPATH="$(ls -d ./build/faiss/python/build/lib*/)" pytest tests/test_*.py Reviewed By: mdouze Differential Revision: D39927403 Pulled By: mdouze fbshipit-source-id: 45271b98419203dfb1cea4f4e7eaf0662523a5b5pull/2532/head
parent
16d5ec755f
commit
47a9953a35
|
@ -14,6 +14,7 @@
|
|||
#include <omp.h>
|
||||
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
#include <faiss/impl/IDSelector.h>
|
||||
#include <faiss/impl/LookupTableScaler.h>
|
||||
#include <faiss/impl/ResultHandler.h>
|
||||
#include <faiss/utils/distances.h>
|
||||
|
@ -97,6 +98,61 @@ void IndexFastScan::add(idx_t n, const float* x) {
|
|||
ntotal += n;
|
||||
}
|
||||
|
||||
size_t IndexFastScan::remove_ids(const IDSelector& sel) {
|
||||
idx_t j = 0;
|
||||
for (idx_t i = 0; i < ntotal; i++) {
|
||||
if (sel.is_member(i)) {
|
||||
// should be removed
|
||||
} else {
|
||||
if (i > j) {
|
||||
for (int sq = 0; sq < M; sq++) {
|
||||
uint8_t code =
|
||||
pq4_get_packed_element(codes.data(), bbs, M, i, sq);
|
||||
pq4_set_packed_element(codes.data(), code, bbs, M, j, sq);
|
||||
}
|
||||
}
|
||||
j++;
|
||||
}
|
||||
}
|
||||
size_t nremove = ntotal - j;
|
||||
if (nremove > 0) {
|
||||
ntotal = j;
|
||||
ntotal2 = roundup(ntotal, bbs);
|
||||
size_t new_size = ntotal2 * M2 / 2;
|
||||
codes.resize(new_size);
|
||||
}
|
||||
return nremove;
|
||||
}
|
||||
|
||||
void IndexFastScan::check_compatible_for_merge(const Index& otherIndex) const {
|
||||
const IndexFastScan* other =
|
||||
dynamic_cast<const IndexFastScan*>(&otherIndex);
|
||||
FAISS_THROW_IF_NOT(other);
|
||||
FAISS_THROW_IF_NOT(other->M == M);
|
||||
FAISS_THROW_IF_NOT(other->bbs == bbs);
|
||||
FAISS_THROW_IF_NOT(other->d == d);
|
||||
FAISS_THROW_IF_NOT(other->code_size == code_size);
|
||||
FAISS_THROW_IF_NOT_MSG(
|
||||
typeid(*this) == typeid(*other),
|
||||
"can only merge indexes of the same type");
|
||||
}
|
||||
|
||||
void IndexFastScan::merge_from(Index& otherIndex, idx_t add_id) {
|
||||
check_compatible_for_merge(otherIndex);
|
||||
IndexFastScan* other = static_cast<IndexFastScan*>(&otherIndex);
|
||||
ntotal2 = roundup(ntotal + other->ntotal, bbs);
|
||||
codes.resize(ntotal2 * M2 / 2);
|
||||
for (int i = 0; i < other->ntotal; i++) {
|
||||
for (int sq = 0; sq < M; sq++) {
|
||||
uint8_t code =
|
||||
pq4_get_packed_element(other->codes.data(), bbs, M, i, sq);
|
||||
pq4_set_packed_element(codes.data(), code, bbs, M, ntotal + i, sq);
|
||||
}
|
||||
}
|
||||
ntotal += other->ntotal;
|
||||
other->reset();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
template <class C, typename dis_t, class Scaler>
|
||||
|
|
|
@ -16,6 +16,8 @@ namespace faiss {
|
|||
*
|
||||
* The codes are not stored sequentially but grouped in blocks of size bbs.
|
||||
* This makes it possible to compute distances quickly with SIMD instructions.
|
||||
* The trailing codes (padding codes that are added to complete the last code)
|
||||
* are garbage.
|
||||
*
|
||||
* Implementations:
|
||||
* 12: blocked loop with internal loop on Q with qbs
|
||||
|
@ -123,6 +125,9 @@ struct IndexFastScan : Index {
|
|||
const Scaler& scaler) const;
|
||||
|
||||
void reconstruct(idx_t key, float* recons) const override;
|
||||
size_t remove_ids(const IDSelector& sel) override;
|
||||
void merge_from(Index& otherIndex, idx_t add_id = 0) override;
|
||||
void check_compatible_for_merge(const Index& otherIndex) const override;
|
||||
};
|
||||
|
||||
struct FastScanStats {
|
||||
|
|
|
@ -122,30 +122,70 @@ void pq4_pack_codes_range(
|
|||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// get the specific address of the vector inside a block
|
||||
// shift is used for determine the if the saved in bits 0..3 (false) or
|
||||
// bits 4..7 (true)
|
||||
uint8_t get_vector_specific_address(
|
||||
size_t bbs,
|
||||
size_t vector_id,
|
||||
size_t sq,
|
||||
bool& shift) {
|
||||
// get the vector_id inside the block
|
||||
vector_id = vector_id % bbs;
|
||||
shift = vector_id > 15;
|
||||
vector_id = vector_id & 15;
|
||||
|
||||
// get the address of the vector in sq
|
||||
size_t address;
|
||||
if (vector_id < 8) {
|
||||
address = vector_id << 1;
|
||||
} else {
|
||||
address = ((vector_id - 8) << 1) + 1;
|
||||
}
|
||||
if (sq & 1) {
|
||||
address += 16;
|
||||
}
|
||||
return (sq >> 1) * bbs + address;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
uint8_t pq4_get_packed_element(
|
||||
const uint8_t* data,
|
||||
size_t bbs,
|
||||
size_t nsq,
|
||||
size_t i,
|
||||
size_t vector_id,
|
||||
size_t sq) {
|
||||
// move to correct bbs-sized block
|
||||
data += (i / bbs * (nsq / 2) + sq / 2) * bbs;
|
||||
sq = sq & 1;
|
||||
i = i % bbs;
|
||||
|
||||
// another step
|
||||
data += (i / 32) * 32;
|
||||
i = i % 32;
|
||||
|
||||
if (sq == 1) {
|
||||
data += 16;
|
||||
}
|
||||
const uint8_t iperm0[16] = {
|
||||
0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
|
||||
if (i < 16) {
|
||||
return data[iperm0[i]] & 15;
|
||||
// number of blocks * block size
|
||||
data += (vector_id / bbs) * (((nsq + 1) / 2) * bbs);
|
||||
bool shift;
|
||||
size_t address = get_vector_specific_address(bbs, vector_id, sq, shift);
|
||||
if (shift) {
|
||||
return data[address] >> 4;
|
||||
} else {
|
||||
return data[iperm0[i - 16]] >> 4;
|
||||
return data[address] & 15;
|
||||
}
|
||||
}
|
||||
|
||||
void pq4_set_packed_element(
|
||||
uint8_t* data,
|
||||
uint8_t code,
|
||||
size_t bbs,
|
||||
size_t nsq,
|
||||
size_t vector_id,
|
||||
size_t sq) {
|
||||
// move to correct bbs-sized block
|
||||
// number of blocks * block size
|
||||
data += (vector_id / bbs) * (((nsq + 1) / 2) * bbs);
|
||||
bool shift;
|
||||
size_t address = get_vector_specific_address(bbs, vector_id, sq, shift);
|
||||
if (shift) {
|
||||
data[address] = (code << 4) | (data[address] & 15);
|
||||
} else {
|
||||
data[address] = code | (data[address] & ~15);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -61,14 +61,27 @@ void pq4_pack_codes_range(
|
|||
|
||||
/** get a single element from a packed codes table
|
||||
*
|
||||
* @param i vector id
|
||||
* @param vector_id vector id
|
||||
* @param sq subquantizer (< nsq)
|
||||
*/
|
||||
uint8_t pq4_get_packed_element(
|
||||
const uint8_t* data,
|
||||
size_t bbs,
|
||||
size_t nsq,
|
||||
size_t i,
|
||||
size_t vector_id,
|
||||
size_t sq);
|
||||
|
||||
/** set a single element "code" into a packed codes table
|
||||
*
|
||||
* @param vector_id vector id
|
||||
* @param sq subquantizer (< nsq)
|
||||
*/
|
||||
void pq4_set_packed_element(
|
||||
uint8_t* data,
|
||||
uint8_t code,
|
||||
size_t bbs,
|
||||
size_t nsq,
|
||||
size_t vector_id,
|
||||
size_t sq);
|
||||
|
||||
/** Pack Look-up table for consumption by the kernel.
|
||||
|
|
|
@ -16,6 +16,44 @@ import platform
|
|||
|
||||
from common_faiss_tests import get_dataset_2
|
||||
|
||||
|
||||
class TestRemoveFastScan(unittest.TestCase):
|
||||
def do_test(self, ntotal, removed):
|
||||
d = 20
|
||||
xt, xb, _ = get_dataset_2(d, ntotal, ntotal, 0)
|
||||
index = faiss.index_factory(20, 'IDMap2,PQ5x4fs')
|
||||
index.train(xt)
|
||||
index.add_with_ids(xb, np.arange(ntotal).astype("int64"))
|
||||
before = index.reconstruct_n(0, ntotal)
|
||||
index.remove_ids(np.array(removed))
|
||||
for i in range(ntotal):
|
||||
if i in removed:
|
||||
# should throw RuntimeError as this vector should be removed
|
||||
try:
|
||||
after = index.reconstruct(i)
|
||||
assert False
|
||||
except RuntimeError:
|
||||
pass
|
||||
else:
|
||||
after = index.reconstruct(i)
|
||||
np.testing.assert_array_equal(before[i], after)
|
||||
assert index.ntotal == ntotal - len(removed)
|
||||
|
||||
def test_remove_last_vector(self):
|
||||
self.do_test(993, [992])
|
||||
|
||||
# test remove element from every address 0 -> 31
|
||||
# [0, 32 + 1, 2 * 32 + 2, ....]
|
||||
# [0, 33 , 66 , 99, 132, .....]
|
||||
def test_remove_every_address(self):
|
||||
removed = (33 * np.arange(32)).tolist()
|
||||
self.do_test(1100, removed)
|
||||
|
||||
# test remove range of vectors and leave ntotal divisible by 32
|
||||
def test_leave_complete_block(self):
|
||||
self.do_test(1000, np.arange(8).tolist())
|
||||
|
||||
|
||||
class TestRemove(unittest.TestCase):
|
||||
|
||||
def do_merge_then_remove(self, ondisk):
|
||||
|
@ -171,7 +209,6 @@ class TestRemove(unittest.TestCase):
|
|||
assert False, 'should have raised an exception'
|
||||
|
||||
|
||||
|
||||
class TestRangeSearch(unittest.TestCase):
|
||||
|
||||
def test_range_search_id_map(self):
|
||||
|
@ -331,6 +368,7 @@ class TestTransformChain(unittest.TestCase):
|
|||
|
||||
assert np.all(I == I2)
|
||||
|
||||
|
||||
@unittest.skipIf(platform.system() == 'Windows', \
|
||||
'Mmap not supported on Windows.')
|
||||
class TestRareIO(unittest.TestCase):
|
||||
|
@ -504,6 +542,7 @@ class TestSerialize(unittest.TestCase):
|
|||
Dnew, Inew = index3.search(xq, 5)
|
||||
assert np.all(Dnew == Dref) and np.all(Inew == Iref)
|
||||
|
||||
|
||||
@unittest.skipIf(platform.system() == 'Windows',
|
||||
'OnDiskInvertedLists is unsupported on Windows.')
|
||||
class TestRenameOndisk(unittest.TestCase):
|
||||
|
@ -638,7 +677,5 @@ class TestInvlistMeta(unittest.TestCase):
|
|||
index.replace_invlists(il, True)
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -167,3 +167,34 @@ class IndexFlatCodes_merge(unittest.TestCase):
|
|||
|
||||
def test_merge_IndexScalarQuantizer(self):
|
||||
self.do_flat_codes_test("SQ4")
|
||||
|
||||
|
||||
# Test merge_from method for IndexFastScan Types
|
||||
class IndexFastScan_merge(unittest.TestCase):
|
||||
|
||||
def do_fast_scan_test(self, index, size1):
|
||||
ds = SyntheticDataset(110, 1000, 1000, 100)
|
||||
index1 = faiss.index_factory(ds.d, index)
|
||||
index1.train(ds.get_train())
|
||||
index1.add(ds.get_database())
|
||||
_, Iref = index1.search(ds.get_queries(), 5)
|
||||
index1.reset()
|
||||
index2 = faiss.index_factory(ds.d, index)
|
||||
index2.train(ds.get_train())
|
||||
index1.add(ds.get_database()[:size1])
|
||||
index2.add(ds.get_database()[size1:])
|
||||
index1.merge_from(index2)
|
||||
_, Inew = index1.search(ds.get_queries(), 5)
|
||||
np.testing.assert_array_equal(Inew, Iref)
|
||||
|
||||
def test_merge_IndexFastScan_complete_block(self):
|
||||
self.do_fast_scan_test("PQ5x4fs", 320)
|
||||
|
||||
def test_merge_IndexFastScan_not_complete_block(self):
|
||||
self.do_fast_scan_test("PQ11x4fs", 310)
|
||||
|
||||
def test_merge_IndexFastScan_even_M(self):
|
||||
self.do_fast_scan_test("PQ10x4fs", 500)
|
||||
|
||||
def test_merge_IndexAdditiveQuantizerFastScan(self):
|
||||
self.do_fast_scan_test("RQ10x4fs_32_Nrq2x4", 330)
|
||||
|
|
|
@ -11,7 +11,9 @@
|
|||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <faiss/IndexPQFastScan.h>
|
||||
#include <faiss/impl/ProductQuantizer.h>
|
||||
#include <faiss/impl/pq4_fast_scan.h>
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -24,6 +26,15 @@ const std::vector<uint64_t> random_vector(size_t s) {
|
|||
return v;
|
||||
}
|
||||
|
||||
const std::vector<float> random_vector_float(size_t s) {
|
||||
std::vector<float> v(s, 0);
|
||||
for (size_t i = 0; i < s; ++i) {
|
||||
v[i] = rand();
|
||||
}
|
||||
|
||||
return v;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST(PQEncoderGeneric, encode) {
|
||||
|
@ -91,3 +102,44 @@ TEST(PQEncoder16, encode) {
|
|||
EXPECT_EQ(values[i] & mask, v);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PQFastScan, set_packed_element) {
|
||||
int d = 20, ntotal = 1000, M = 5, nbits = 4;
|
||||
const std::vector<float> ds = random_vector_float(ntotal * d);
|
||||
faiss::IndexPQFastScan index(d, M, nbits);
|
||||
index.train(ntotal, ds.data());
|
||||
index.add(ntotal, ds.data());
|
||||
|
||||
for (int j = 0; j < 10; j++) {
|
||||
int vector_id = rand() % ntotal;
|
||||
std::vector<uint8_t> old(ntotal * M);
|
||||
std::vector<uint8_t> code(M);
|
||||
for (int i = 0; i < ntotal; i++) {
|
||||
for (int sq = 0; sq < M; sq++) {
|
||||
old[i * M + sq] = faiss::pq4_get_packed_element(
|
||||
index.codes.data(), index.bbs, M, i, sq);
|
||||
}
|
||||
}
|
||||
for (int sq = 0; sq < M; sq++) {
|
||||
faiss::pq4_set_packed_element(
|
||||
index.codes.data(),
|
||||
((old[vector_id * M + sq] + 3) % 16),
|
||||
index.bbs,
|
||||
M,
|
||||
vector_id,
|
||||
sq);
|
||||
}
|
||||
for (int i = 0; i < ntotal; i++) {
|
||||
for (int sq = 0; sq < M; sq++) {
|
||||
uint8_t newcode = faiss::pq4_get_packed_element(
|
||||
index.codes.data(), index.bbs, M, i, sq);
|
||||
uint8_t oldcode = old[i * M + sq];
|
||||
if (i == vector_id) {
|
||||
EXPECT_EQ(newcode, (oldcode + 3) % 16);
|
||||
} else {
|
||||
EXPECT_EQ(newcode, oldcode);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue