CodeSet for deduping large datasets (#2949)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2949

A more scalable alternative to `np.unique` for deduping large datasets with a quantized code.

Reviewed By: mlomeli1

Differential Revision: D47443953

fbshipit-source-id: 4a1554d4d4200b5fa657e9d8b7395bba9856a8e3
pull/2958/head^2
Gergely Szilvasy 2023-07-19 10:05:46 -07:00 committed by Facebook GitHub Bot
parent 43d86e3073
commit 821a401ae9
5 changed files with 52 additions and 0 deletions

View File

@ -41,6 +41,7 @@ class_wrappers.handle_MapLong2Long(MapLong2Long)
# Patch the IDSelector subset classes; handle_IDSelectorSubset installs a
# replacement __init__ on each class passed in.
class_wrappers.handle_IDSelectorSubset(IDSelectorBatch, class_owns=True)
class_wrappers.handle_IDSelectorSubset(IDSelectorArray, class_owns=False)
class_wrappers.handle_IDSelectorSubset(IDSelectorBitmap, class_owns=False, force_int64=False)
# Give CodeSet an `insert` method that accepts numpy arrays.
class_wrappers.handle_CodeSet(CodeSet)
# Keep a reference to this module object itself.
this_module = sys.modules[__name__]

View File

@ -1102,3 +1102,21 @@ def handle_IDSelectorSubset(the_class, class_owns, force_int64=True):
self.original_init(*args)
the_class.__init__ = replacement_init
def handle_CodeSet(the_class):
    """Install a numpy-aware `insert` method on `the_class`.

    The new method is registered through `replace_method`; it forwards to
    `self.insert_c`, which is presumably the original SWIG binding kept
    under that name by `replace_method` (confirm against its definition).
    """

    def replacement_insert(self, codes, inserted=None):
        """Insert the rows of `codes` and report which ones were added.

        codes:    2D array of shape (n, self.d); it is converted to a
                  C-contiguous uint8 array before being handed to C++.
        inserted: optional output array of shape (n,); a fresh boolean
                  array is allocated when it is omitted.

        Returns the `inserted` array after the C++ call has filled it.
        """
        num_codes, code_size = codes.shape
        assert code_size == self.d
        packed = np.ascontiguousarray(codes, dtype=np.uint8)

        if inserted is not None:
            # caller-provided output buffer: only its shape is checked here
            assert inserted.shape == (num_codes, )
        else:
            inserted = np.empty(num_codes, dtype=bool)

        self.insert_c(num_codes, swig_ptr(packed), swig_ptr(inserted))
        return inserted

    replace_method(the_class, 'insert', replacement_insert)

View File

@ -28,6 +28,7 @@
#include <omp.h>
#include <algorithm>
#include <set>
#include <type_traits>
#include <vector>
@ -623,4 +624,12 @@ void CombinerRangeKNN<T>::write_result(T* D_res, int64_t* I_res) {
// Explicitly instantiate CombinerRangeKNN for float and int16_t.
template struct CombinerRangeKNN<float>;
template struct CombinerRangeKNN<int16_t>;
/// Insert n codes of d bytes each, stored back to back in `codes`.
/// On return, inserted[i] is true when code i was not yet in the set
/// (so a code repeated within the same call is flagged only once, at its
/// first position) and false otherwise.
void CodeSet::insert(size_t n, const uint8_t* codes, bool* inserted) {
    for (size_t idx = 0; idx != n; ++idx) {
        const uint8_t* first = codes + idx * d;
        const uint8_t* last = first + d;
        // std::set::insert returns {iterator, bool}; the bool tells whether
        // a new element was actually added.
        inserted[idx] = s.insert(std::vector<uint8_t>(first, last)).second;
    }
}
} // namespace faiss

View File

@ -17,7 +17,9 @@
#define FAISS_utils_h
#include <stdint.h>
#include <set>
#include <string>
#include <vector>
#include <faiss/impl/platform_macros.h>
#include <faiss/utils/Heap.h>
@ -209,6 +211,14 @@ struct CombinerRangeKNN {
void write_result(T* D_res, int64_t* I_res);
};
/// Set of fixed-size byte codes, used to tell new codes from ones that
/// were already seen.
struct CodeSet {
    /// size of one code, in bytes
    size_t d;

    /// the distinct codes collected so far, one vector of d bytes per code
    std::set<std::vector<uint8_t>> s;

    /// @param d  size of one code, in bytes
    explicit CodeSet(size_t d) : d{d} {}

    /** Add a batch of codes to the set.
     *
     * @param n         number of codes in the batch
     * @param codes     input codes, size n * d bytes
     * @param inserted  output array of size n; inserted[i] is set to true
     *                  iff code i was not already present in the set
     */
    void insert(size_t n, const uint8_t* codes, bool* inserted);
};
} // namespace faiss
#endif /* FAISS_utils_h */

View File

@ -630,3 +630,17 @@ class TestInvlistSort(unittest.TestCase):
np.testing.assert_equal(Dnew, Dref)
Inew_remap = perm[Inew]
np.testing.assert_equal(Inew_remap, Iref)
class TestCodeSet(unittest.TestCase):
def test_code_set(self):
""" CodeSet and np.unique should produce the same output """
d = 8
n = 1000 # > 256 and using only 0 or 1 so there must be duplicates
codes = np.random.randint(0, 2, (n, d), dtype=np.uint8)
s = faiss.CodeSet(d)
inserted = s.insert(codes)
np.testing.assert_equal(
np.sort(np.unique(codes, axis=0), axis=None),
np.sort(codes[inserted], axis=None))