selector parameter for FastScan (#3362)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3362

Add test to Alex' PR

Reviewed By: junjieqi

Differential Revision: D56003946

fbshipit-source-id: 5a8a881d450bc97ae0777d73ce0ce8607ec6b686
This commit is contained in:
Matthijs Douze 2024-04-11 14:23:46 -07:00 committed by Facebook GitHub Bot
parent 17fbeb8d7e
commit 40e8643336

View File

@ -22,7 +22,7 @@ class TestSelector(unittest.TestCase):
combinations as possible. combinations as possible.
""" """
def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METRIC_L2): def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METRIC_L2, k=10):
""" Verify that the id selector returns the subset of results that are """ Verify that the id selector returns the subset of results that are
members according to the IDSelector. members according to the IDSelector.
Supports id_selector_type="batch", "bitmap", "range", "range_sorted", "and", "or", "xor" Supports id_selector_type="batch", "bitmap", "range", "range_sorted", "and", "or", "xor"
@ -30,7 +30,6 @@ class TestSelector(unittest.TestCase):
ds = datasets.SyntheticDataset(32, 1000, 100, 20) ds = datasets.SyntheticDataset(32, 1000, 100, 20)
index = faiss.index_factory(ds.d, index_key, mt) index = faiss.index_factory(ds.d, index_key, mt)
index.train(ds.get_train()) index.train(ds.get_train())
k = 10
# reference result # reference result
if "range" in id_selector_type: if "range" in id_selector_type:
@ -145,6 +144,16 @@ class TestSelector(unittest.TestCase):
def test_IVFPQ(self): def test_IVFPQ(self):
self.do_test_id_selector("IVF32,PQ4x4np") self.do_test_id_selector("IVF32,PQ4x4np")
def test_IVFPQfs(self):
self.do_test_id_selector("IVF32,PQ4x4fs")
def test_IVFPQfs_k1(self):
self.do_test_id_selector("IVF32,PQ4x4fs", k=1)
def test_IVFPQfs_k40(self):
# test reservoir codepath
self.do_test_id_selector("IVF32,PQ4x4fs", k=40)
def test_IVFSQ(self): def test_IVFSQ(self):
self.do_test_id_selector("IVF32,SQ8") self.do_test_id_selector("IVF32,SQ8")