From 1ee15ef3c34c6df1e72b806c00106b15fe35c3e5 Mon Sep 17 00:00:00 2001 From: Kaelen Haag Date: Wed, 8 Mar 2023 08:48:54 -0800 Subject: [PATCH] Proposal IDSelectorCombination (#2742) Summary: Adds support for an IDSelector that takes in two IDSelectors and can perform a boolean operation on their is_member outcomes. Current implementation is pretty naive and doesn't try to do any optimizations on the types of IDSelectors combined. Also test cases are definitely lacking but can add more once approach is agreed upon. Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2742 Reviewed By: algoriddle Differential Revision: D43904855 Pulled By: mdouze fbshipit-source-id: bbe687800a19b418ca30c9257fb0334c64ab5f52 --- faiss/impl/IDSelector.h | 39 +++++++++++++++++++++++++++++++++++++ faiss/python/__init__.py | 3 +++ tests/test_search_params.py | 38 +++++++++++++++++++++++++++++++++++- 3 files changed, 79 insertions(+), 1 deletion(-) diff --git a/faiss/impl/IDSelector.h b/faiss/impl/IDSelector.h index e00dbe8fb..913d9ff8b 100644 --- a/faiss/impl/IDSelector.h +++ b/faiss/impl/IDSelector.h @@ -131,4 +131,43 @@ struct IDSelectorAll : IDSelector { virtual ~IDSelectorAll() {} }; +/// does an AND operation on the two given IDSelectors' is_member +/// results. +struct IDSelectorAnd : IDSelector { + const IDSelector* lhs; + const IDSelector* rhs; + IDSelectorAnd(const IDSelector* lhs, const IDSelector* rhs) + : lhs(lhs), rhs(rhs) {} + bool is_member(idx_t id) const final { + return lhs->is_member(id) && rhs->is_member(id); + }; + virtual ~IDSelectorAnd() {} +}; + +/// does an OR operation on the two given IDSelectors' is_member +/// results. 
+struct IDSelectorOr : IDSelector { + const IDSelector* lhs; + const IDSelector* rhs; + IDSelectorOr(const IDSelector* lhs, const IDSelector* rhs) + : lhs(lhs), rhs(rhs) {} + bool is_member(idx_t id) const final { + return lhs->is_member(id) || rhs->is_member(id); + }; + virtual ~IDSelectorOr() {} +}; + +/// does an XOR operation on the two given IDSelectors' is_member +/// results. +struct IDSelectorXOr : IDSelector { + const IDSelector* lhs; + const IDSelector* rhs; + IDSelectorXOr(const IDSelector* lhs, const IDSelector* rhs) + : lhs(lhs), rhs(rhs) {} + bool is_member(idx_t id) const final { + return lhs->is_member(id) ^ rhs->is_member(id); + }; + virtual ~IDSelectorXOr() {} +}; + } // namespace faiss diff --git a/faiss/python/__init__.py b/faiss/python/__init__.py index 9e95881b5..92eeb3479 100644 --- a/faiss/python/__init__.py +++ b/faiss/python/__init__.py @@ -189,6 +189,9 @@ add_ref_in_constructor(BufferedIOWriter, 0) add_ref_in_constructor(BufferedIOReader, 0) add_ref_in_constructor(IDSelectorNot, 0) +add_ref_in_constructor(IDSelectorAnd, slice(2)) +add_ref_in_constructor(IDSelectorOr, slice(2)) +add_ref_in_constructor(IDSelectorXOr, slice(2)) # seems really marginal... # remove_ref_from_method(IndexReplicas, 'removeIndex', 0) diff --git a/tests/test_search_params.py b/tests/test_search_params.py index bea1818ef..bd7e813bf 100644 --- a/tests/test_search_params.py +++ b/tests/test_search_params.py @@ -23,7 +23,7 @@ class TestSelector(unittest.TestCase): def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METRIC_L2): """ Verify that the id selector returns the subset of results that are members according to the IDSelector. 
- Supports id_selector_type="batch", "bitmap", "range", "range_sorted", "not" + Supports id_selector_type="batch", "bitmap", "range", "range_sorted", "not", "and", "or", "xor" """ ds = datasets.SyntheticDataset(32, 1000, 100, 20) index = faiss.index_factory(ds.d, index_key, mt) @@ -33,6 +33,24 @@ class TestSelector(unittest.TestCase): # reference result if "range" in id_selector_type: subset = np.arange(30, 80).astype('int64') + elif id_selector_type == "or": + lhs_rs = np.random.RandomState(123) + lhs_subset = lhs_rs.choice(ds.nb, 50, replace=False).astype("int64") + rhs_rs = np.random.RandomState(456) + rhs_subset = rhs_rs.choice(ds.nb, 20, replace=False).astype("int64") + subset = np.union1d(lhs_subset, rhs_subset) + elif id_selector_type == "and": + lhs_rs = np.random.RandomState(123) + lhs_subset = lhs_rs.choice(ds.nb, 50, replace=False).astype("int64") + rhs_rs = np.random.RandomState(456) + rhs_subset = rhs_rs.choice(ds.nb, 10, replace=False).astype("int64") + subset = np.intersect1d(lhs_subset, rhs_subset) + elif id_selector_type == "xor": + lhs_rs = np.random.RandomState(123) + lhs_subset = lhs_rs.choice(ds.nb, 50, replace=False).astype("int64") + rhs_rs = np.random.RandomState(456) + rhs_subset = rhs_rs.choice(ds.nb, 40, replace=False).astype("int64") + subset = np.setxor1d(lhs_subset, rhs_subset) else: rs = np.random.RandomState(123) subset = rs.choice(ds.nb, 50, replace=False).astype("int64") @@ -81,6 +99,21 @@ class TestSelector(unittest.TestCase): if i not in ssubset ]).astype('int64') sel = faiss.IDSelectorNot(faiss.IDSelectorBatch(inverse_subset)) + elif id_selector_type == "or": + sel = faiss.IDSelectorOr( + faiss.IDSelectorBatch(lhs_subset), + faiss.IDSelectorBatch(rhs_subset) + ) + elif id_selector_type == "and": + sel = faiss.IDSelectorAnd( + faiss.IDSelectorBatch(lhs_subset), + faiss.IDSelectorBatch(rhs_subset) + ) + elif id_selector_type == "xor": + sel = faiss.IDSelectorXOr( + faiss.IDSelectorBatch(lhs_subset), + faiss.IDSelectorBatch(rhs_subset) + 
) else: sel = faiss.IDSelectorBatch(subset) @@ -148,6 +181,9 @@ class TestSelector(unittest.TestCase): def test_Flat_id_not(self): self.do_test_id_selector("Flat", id_selector_type="not") + + def test_Flat_id_or(self): + self.do_test_id_selector("Flat", id_selector_type="or") # not implemented