add remove and merge features for IndexFastScan (#2497)

Summary:
* Modify pq4_get_packed_element to make it not depend on an auxiliary table
* Create pq4_set_packed_element which sets a single element in codes in packed format
(These methods are used by merge and remove for IndexFastScan;
the get method is also used by FastScan indices for reconstruction)
* Add remove feature for IndexFastScan
* Add merge feature for IndexFastScan

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2497

Test Plan:
cd build && make -j
make test
cd faiss/python && python setup.py build && cd ../../..
PYTHONPATH="$(ls -d ./build/faiss/python/build/lib*/)" pytest tests/test_*.py

Reviewed By: mdouze

Differential Revision: D39927403

Pulled By: mdouze

fbshipit-source-id: 45271b98419203dfb1cea4f4e7eaf0662523a5b5
pull/2532/head
Abdelrahman Elmeniawy 2022-10-11 04:14:29 -07:00 committed by Facebook GitHub Bot
parent 16d5ec755f
commit 47a9953a35
7 changed files with 256 additions and 22 deletions

View File

@ -14,6 +14,7 @@
#include <omp.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/impl/LookupTableScaler.h>
#include <faiss/impl/ResultHandler.h>
#include <faiss/utils/distances.h>
@ -97,6 +98,61 @@ void IndexFastScan::add(idx_t n, const float* x) {
ntotal += n;
}
size_t IndexFastScan::remove_ids(const IDSelector& sel) {
idx_t j = 0;
for (idx_t i = 0; i < ntotal; i++) {
if (sel.is_member(i)) {
// should be removed
} else {
if (i > j) {
for (int sq = 0; sq < M; sq++) {
uint8_t code =
pq4_get_packed_element(codes.data(), bbs, M, i, sq);
pq4_set_packed_element(codes.data(), code, bbs, M, j, sq);
}
}
j++;
}
}
size_t nremove = ntotal - j;
if (nremove > 0) {
ntotal = j;
ntotal2 = roundup(ntotal, bbs);
size_t new_size = ntotal2 * M2 / 2;
codes.resize(new_size);
}
return nremove;
}
// Check that otherIndex can be merged into this index.
//
// otherIndex must be an IndexFastScan (the dynamic_cast succeeds), must have
// the same M, bbs, d and code_size as this index, and must have exactly the
// same dynamic type. Every condition goes through a FAISS_THROW_IF_NOT* macro
// (defined outside this view; presumably it raises an exception on failure).
void IndexFastScan::check_compatible_for_merge(const Index& otherIndex) const {
const IndexFastScan* other =
dynamic_cast<const IndexFastScan*>(&otherIndex);
// a null result means otherIndex is not an IndexFastScan at all
FAISS_THROW_IF_NOT(other);
FAISS_THROW_IF_NOT(other->M == M);
FAISS_THROW_IF_NOT(other->bbs == bbs);
FAISS_THROW_IF_NOT(other->d == d);
FAISS_THROW_IF_NOT(other->code_size == code_size);
// a subclass also passes the dynamic_cast above, so compare the dynamic
// types of the two objects as well
FAISS_THROW_IF_NOT_MSG(
typeid(*this) == typeid(*other),
"can only merge indexes of the same type");
}
// Move all the vectors of otherIndex to the end of this index.
//
// Compatibility is verified first with check_compatible_for_merge(). The
// codes table is then grown to hold both sets of vectors and the codes of
// otherIndex are copied one element at a time, because the position of a
// code inside the packed table depends on the id of its vector. On return
// otherIndex has been reset().
//
// @param otherIndex  index to take the vectors from
// @param add_id      not used by this implementation
void IndexFastScan::merge_from(Index& otherIndex, idx_t add_id) {
    check_compatible_for_merge(otherIndex);
    IndexFastScan* other = static_cast<IndexFastScan*>(&otherIndex);

    const idx_t n_other = other->ntotal;
    ntotal2 = roundup(ntotal + n_other, bbs);
    codes.resize(ntotal2 * M2 / 2);

    // idx_t (not int) so that the counter has the same width as ntotal
    for (idx_t i = 0; i < n_other; i++) {
        for (int sq = 0; sq < M; sq++) {
            const uint8_t code = pq4_get_packed_element(
                    other->codes.data(), bbs, M, i, sq);
            pq4_set_packed_element(
                    codes.data(), code, bbs, M, ntotal + i, sq);
        }
    }

    ntotal += n_other;
    other->reset();
}
namespace {
template <class C, typename dis_t, class Scaler>

View File

@ -16,6 +16,8 @@ namespace faiss {
*
* The codes are not stored sequentially but grouped in blocks of size bbs.
* This makes it possible to compute distances quickly with SIMD instructions.
* The trailing codes (padding codes that are added to complete the last block)
* are garbage.
*
* Implementations:
* 12: blocked loop with internal loop on Q with qbs
@ -123,6 +125,9 @@ struct IndexFastScan : Index {
const Scaler& scaler) const;
void reconstruct(idx_t key, float* recons) const override;
size_t remove_ids(const IDSelector& sel) override;
void merge_from(Index& otherIndex, idx_t add_id = 0) override;
void check_compatible_for_merge(const Index& otherIndex) const override;
};
struct FastScanStats {

View File

@ -122,30 +122,70 @@ void pq4_pack_codes_range(
}
}
namespace {
// Compute the byte offset, relative to the start of its bbs-vector block, of
// the byte that holds the 4-bit code of one vector for one sub-quantizer.
//
// Inside a block, each pair of sub-quantizers takes bbs bytes; those bytes are
// made of 32-byte groups, one per run of 32 consecutive vectors. In a group
// the first 16 bytes belong to the even sub-quantizer and the last 16 to the
// odd one, and each byte holds the codes of two vectors that are 16 apart.
//
// @param bbs        number of vectors per block
// @param vector_id  vector id (only its position inside the block is used)
// @param sq         sub-quantizer index
// @param shift      output: false if the code is stored in bits 0..3 of the
//                   byte, true if it is stored in bits 4..7
// @return           byte offset from the beginning of the block. This is a
//                   size_t because the offset exceeds 255 as soon as
//                   (sq / 2) * bbs does.
size_t get_vector_specific_address(
        size_t bbs,
        size_t vector_id,
        size_t sq,
        bool& shift) {
    // position of the vector inside its block
    const size_t in_block = vector_id % bbs;
    // 32-vector group inside the block, and position inside that group
    const size_t group = in_block / 32;
    const size_t in_group = in_block % 32;

    shift = in_group > 15;
    const size_t lane = in_group & 15;

    // lanes 0..7 use the even bytes, lanes 8..15 the odd bytes
    size_t address = lane < 8 ? (lane << 1) : (((lane - 8) << 1) + 1);
    if (sq & 1) {
        // odd sub-quantizers use the second half of the 32-byte group
        address += 16;
    }
    return (sq >> 1) * bbs + group * 32 + address;
}
} // anonymous namespace
// Read the 4-bit code of one vector for one sub-quantizer from a packed
// codes table.
//
// NOTE(review): this span shows the removed and the added lines of the change
// interleaved, so it does not read as one function body. The lines written in
// terms of `i` and the `iperm0` table belong to the previous implementation;
// the lines written in terms of `vector_id` and get_vector_specific_address()
// belong to the new one, which no longer needs the auxiliary table.
uint8_t pq4_get_packed_element(
const uint8_t* data,
size_t bbs,
size_t nsq,
size_t i,
size_t vector_id,
size_t sq) {
// move to correct bbs-sized block
data += (i / bbs * (nsq / 2) + sq / 2) * bbs;
sq = sq & 1;
i = i % bbs;
// another step: move to the 32-byte group that holds vector i
data += (i / 32) * 32;
i = i % 32;
if (sq == 1) {
data += 16;
}
// (previous implementation) table giving the byte index for each of the
// 16 positions
const uint8_t iperm0[16] = {
0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
if (i < 16) {
return data[iperm0[i]] & 15;
// number of blocks * block size
data += (vector_id / bbs) * (((nsq + 1) / 2) * bbs);
bool shift;
size_t address = get_vector_specific_address(bbs, vector_id, sq, shift);
// shift tells which half of the byte holds the code
if (shift) {
return data[address] >> 4;
} else {
return data[iperm0[i - 16]] >> 4;
return data[address] & 15;
}
}
// Store the 4-bit code of one vector for one sub-quantizer into a packed
// codes table, leaving the other half of the byte untouched.
//
// @param data       packed codes table, modified in place
// @param code       value to store; only its 4 low bits are used
// @param bbs        number of vectors per block
// @param nsq        number of sub-quantizers
// @param vector_id  vector id
// @param sq         sub-quantizer (< nsq)
void pq4_set_packed_element(
        uint8_t* data,
        uint8_t code,
        size_t bbs,
        size_t nsq,
        size_t vector_id,
        size_t sq) {
    // move to the block that holds vector_id:
    // number of preceding blocks * size in bytes of one block
    data += (vector_id / bbs) * (((nsq + 1) / 2) * bbs);

    bool shift;
    const size_t address =
            get_vector_specific_address(bbs, vector_id, sq, shift);

    // masking keeps stray high bits of `code` from overwriting the code of
    // the other vector that shares this byte
    const uint8_t nibble = code & 15;
    if (shift) {
        // code goes to bits 4..7, bits 0..3 are preserved
        data[address] = (nibble << 4) | (data[address] & 15);
    } else {
        // code goes to bits 0..3, bits 4..7 are preserved
        data[address] = nibble | (data[address] & ~15);
    }
}

View File

@ -61,14 +61,27 @@ void pq4_pack_codes_range(
/** get a single element from a packed codes table
*
* NOTE(review): the old (`i`) and the new (`vector_id`) spelling of the vector
* id parameter both appear below because this view interleaves the removed
* and the added lines of the change.
*
* @param data packed codes table
* @param bbs size of the blocks the codes are grouped in
* @param nsq number of subquantizers
* @param i vector id
* @param vector_id vector id
* @param sq subquantizer (< nsq)
*/
uint8_t pq4_get_packed_element(
const uint8_t* data,
size_t bbs,
size_t nsq,
size_t i,
size_t vector_id,
size_t sq);
/** set a single element "code" into a packed codes table
*
* @param data packed codes table, modified in place
* @param code value to store; it occupies one 4-bit half of a byte
* @param bbs size of the blocks the codes are grouped in
* @param nsq number of subquantizers
* @param vector_id vector id
* @param sq subquantizer (< nsq)
*/
void pq4_set_packed_element(
uint8_t* data,
uint8_t code,
size_t bbs,
size_t nsq,
size_t vector_id,
size_t sq);
/** Pack Look-up table for consumption by the kernel.

View File

@ -16,6 +16,44 @@ import platform
from common_faiss_tests import get_dataset_2
class TestRemoveFastScan(unittest.TestCase):
    """Check remove_ids on a 'IDMap2,PQ5x4fs' index."""

    def do_test(self, ntotal, removed):
        """Add ntotal vectors with ids 0..ntotal-1, remove the ids listed in
        `removed`, then check that the removed ids cannot be reconstructed
        any more and that all the other vectors reconstruct to the same
        values as before the removal.
        """
        d = 20
        xt, xb, _ = get_dataset_2(d, ntotal, ntotal, 0)
        index = faiss.index_factory(d, 'IDMap2,PQ5x4fs')
        index.train(xt)
        index.add_with_ids(xb, np.arange(ntotal).astype("int64"))
        before = index.reconstruct_n(0, ntotal)

        index.remove_ids(np.array(removed))

        removed_ids = set(removed)
        for i in range(ntotal):
            if i in removed_ids:
                # should throw RuntimeError as this vector should be removed
                with self.assertRaises(RuntimeError):
                    index.reconstruct(i)
            else:
                after = index.reconstruct(i)
                np.testing.assert_array_equal(before[i], after)

        self.assertEqual(index.ntotal, ntotal - len(removed))

    def test_remove_last_vector(self):
        self.do_test(993, [992])

    # test remove element from every address 0 -> 31
    # [0, 32 + 1, 2 * 32 + 2, ....]
    # [0, 33 , 66 , 99, 132, .....]
    def test_remove_every_address(self):
        removed = (33 * np.arange(32)).tolist()
        self.do_test(1100, removed)

    # test remove range of vectors and leave ntotal divisible by 32
    def test_leave_complete_block(self):
        self.do_test(1000, np.arange(8).tolist())
class TestRemove(unittest.TestCase):
def do_merge_then_remove(self, ondisk):
@ -171,7 +209,6 @@ class TestRemove(unittest.TestCase):
assert False, 'should have raised an exception'
class TestRangeSearch(unittest.TestCase):
def test_range_search_id_map(self):
@ -331,6 +368,7 @@ class TestTransformChain(unittest.TestCase):
assert np.all(I == I2)
@unittest.skipIf(platform.system() == 'Windows', \
'Mmap not supported on Windows.')
class TestRareIO(unittest.TestCase):
@ -504,6 +542,7 @@ class TestSerialize(unittest.TestCase):
Dnew, Inew = index3.search(xq, 5)
assert np.all(Dnew == Dref) and np.all(Inew == Iref)
@unittest.skipIf(platform.system() == 'Windows',
'OnDiskInvertedLists is unsupported on Windows.')
class TestRenameOndisk(unittest.TestCase):
@ -638,7 +677,5 @@ class TestInvlistMeta(unittest.TestCase):
index.replace_invlists(il, True)
if __name__ == '__main__':
unittest.main()

View File

@ -167,3 +167,34 @@ class IndexFlatCodes_merge(unittest.TestCase):
def test_merge_IndexScalarQuantizer(self):
self.do_flat_codes_test("SQ4")
# Test merge_from method for IndexFastScan Types
class IndexFastScan_merge(unittest.TestCase):
    """merge_from on fast-scan indexes: adding a database in two parts and
    merging must give the same search results as adding it in one go."""

    def do_fast_scan_test(self, index, size1):
        """`index` is the factory string of the index to build, `size1` is
        the number of database vectors that go into the first index."""
        ds = SyntheticDataset(110, 1000, 1000, 100)

        # reference results: the whole database in a single index
        merged = faiss.index_factory(ds.d, index)
        merged.train(ds.get_train())
        merged.add(ds.get_database())
        _, expected_ids = merged.search(ds.get_queries(), 5)

        # same database, split between two indexes
        merged.reset()
        tail = faiss.index_factory(ds.d, index)
        tail.train(ds.get_train())
        merged.add(ds.get_database()[:size1])
        tail.add(ds.get_database()[size1:])

        merged.merge_from(tail)
        _, merged_ids = merged.search(ds.get_queries(), 5)
        np.testing.assert_array_equal(merged_ids, expected_ids)

    def test_merge_IndexFastScan_complete_block(self):
        self.do_fast_scan_test("PQ5x4fs", 320)

    def test_merge_IndexFastScan_not_complete_block(self):
        self.do_fast_scan_test("PQ11x4fs", 310)

    def test_merge_IndexFastScan_even_M(self):
        self.do_fast_scan_test("PQ10x4fs", 500)

    def test_merge_IndexAdditiveQuantizerFastScan(self):
        self.do_fast_scan_test("RQ10x4fs_32_Nrq2x4", 330)

View File

@ -11,7 +11,9 @@
#include <gtest/gtest.h>
#include <faiss/IndexPQFastScan.h>
#include <faiss/impl/ProductQuantizer.h>
#include <faiss/impl/pq4_fast_scan.h>
namespace {
@ -24,6 +26,15 @@ const std::vector<uint64_t> random_vector(size_t s) {
return v;
}
// Build a vector of `s` floats, each one taken from a call to rand().
const std::vector<float> random_vector_float(size_t s) {
    std::vector<float> out(s);
    for (float& value : out) {
        value = static_cast<float>(rand());
    }
    return out;
}
} // namespace
TEST(PQEncoderGeneric, encode) {
@ -91,3 +102,44 @@ TEST(PQEncoder16, encode) {
EXPECT_EQ(values[i] & mask, v);
}
}
// Check that pq4_set_packed_element changes the codes of the targeted vector
// and nothing else: for 10 randomly chosen vectors, every code of the vector
// is replaced by (old + 3) % 16 and the whole table is read back with
// pq4_get_packed_element and compared with a snapshot taken before the update.
TEST(PQFastScan, set_packed_element) {
    int d = 20, ntotal = 1000, M = 5, nbits = 4;
    const std::vector<float> ds = random_vector_float(ntotal * d);
    faiss::IndexPQFastScan index(d, M, nbits);
    index.train(ntotal, ds.data());
    index.add(ntotal, ds.data());

    for (int j = 0; j < 10; j++) {
        int vector_id = rand() % ntotal;

        // snapshot of all the codes before the update
        std::vector<uint8_t> old(ntotal * M);
        for (int i = 0; i < ntotal; i++) {
            for (int sq = 0; sq < M; sq++) {
                old[i * M + sq] = faiss::pq4_get_packed_element(
                        index.codes.data(), index.bbs, M, i, sq);
            }
        }

        // overwrite every code of the chosen vector with a different value
        for (int sq = 0; sq < M; sq++) {
            faiss::pq4_set_packed_element(
                    index.codes.data(),
                    ((old[vector_id * M + sq] + 3) % 16),
                    index.bbs,
                    M,
                    vector_id,
                    sq);
        }

        // only the codes of vector_id may have changed
        for (int i = 0; i < ntotal; i++) {
            for (int sq = 0; sq < M; sq++) {
                uint8_t newcode = faiss::pq4_get_packed_element(
                        index.codes.data(), index.bbs, M, i, sq);
                uint8_t oldcode = old[i * M + sq];
                if (i == vector_id) {
                    EXPECT_EQ(newcode, (oldcode + 3) % 16);
                } else {
                    EXPECT_EQ(newcode, oldcode);
                }
            }
        }
    }
}