3893 - Fix index factory order of idmap and refinement (#3928)
Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3928 Fix issue in T203425107 Reviewed By: asadoughi Differential Revision: D64068971 fbshipit-source-id: 56db439793539570a102773ff2c7158d48feb7a9pull/3929/head
parent
c5aed7c359
commit
af70c5bcce
|
@ -679,6 +679,24 @@ std::unique_ptr<Index> index_factory_sub(
|
|||
// for the current match
|
||||
std::smatch sm;
|
||||
|
||||
// IndexIDMap -- it turns out is was used both as a prefix and a suffix, so
|
||||
// support both
|
||||
if (re_match(description, "(.+),IDMap2", sm) ||
|
||||
re_match(description, "IDMap2,(.+)", sm)) {
|
||||
IndexIDMap2* idmap2 = new IndexIDMap2(
|
||||
index_factory_sub(d, sm[1].str(), metric).release());
|
||||
idmap2->own_fields = true;
|
||||
return std::unique_ptr<Index>(idmap2);
|
||||
}
|
||||
|
||||
if (re_match(description, "(.+),IDMap", sm) ||
|
||||
re_match(description, "IDMap,(.+)", sm)) {
|
||||
IndexIDMap* idmap = new IndexIDMap(
|
||||
index_factory_sub(d, sm[1].str(), metric).release());
|
||||
idmap->own_fields = true;
|
||||
return std::unique_ptr<Index>(idmap);
|
||||
}
|
||||
|
||||
// handle refines
|
||||
if (re_match(description, "(.+),RFlat", sm) ||
|
||||
re_match(description, "(.+),Refine\\((.+)\\)", sm)) {
|
||||
|
@ -755,24 +773,6 @@ std::unique_ptr<Index> index_factory_sub(
|
|||
d);
|
||||
}
|
||||
|
||||
// IndexIDMap -- it turns out is was used both as a prefix and a suffix, so
|
||||
// support both
|
||||
if (re_match(description, "(.+),IDMap2", sm) ||
|
||||
re_match(description, "IDMap2,(.+)", sm)) {
|
||||
IndexIDMap2* idmap2 = new IndexIDMap2(
|
||||
index_factory_sub(d, sm[1].str(), metric).release());
|
||||
idmap2->own_fields = true;
|
||||
return std::unique_ptr<Index>(idmap2);
|
||||
}
|
||||
|
||||
if (re_match(description, "(.+),IDMap", sm) ||
|
||||
re_match(description, "IDMap,(.+)", sm)) {
|
||||
IndexIDMap* idmap = new IndexIDMap(
|
||||
index_factory_sub(d, sm[1].str(), metric).release());
|
||||
idmap->own_fields = true;
|
||||
return std::unique_ptr<Index>(idmap);
|
||||
}
|
||||
|
||||
{ // handle basic index types
|
||||
Index* index = parse_other_indexes(description, d, metric);
|
||||
if (index) {
|
||||
|
|
|
@ -12,6 +12,7 @@ import faiss
|
|||
from faiss.contrib import factory_tools
|
||||
from faiss.contrib import datasets
|
||||
|
||||
|
||||
class TestFactory(unittest.TestCase):
|
||||
|
||||
def test_factory_1(self):
|
||||
|
@ -40,7 +41,6 @@ class TestFactory(unittest.TestCase):
|
|||
index = faiss.index_factory(12, "SQ8")
|
||||
assert index.code_size == 12
|
||||
|
||||
|
||||
def test_factory_3(self):
|
||||
|
||||
index = faiss.index_factory(12, "IVF10,PQ4")
|
||||
|
@ -73,7 +73,8 @@ class TestFactory(unittest.TestCase):
|
|||
def test_factory_HNSW_newstyle(self):
|
||||
index = faiss.index_factory(12, "HNSW32,Flat")
|
||||
assert index.storage.sa_code_size() == 12 * 4
|
||||
index = faiss.index_factory(12, "HNSW32,SQ8", faiss.METRIC_INNER_PRODUCT)
|
||||
index = faiss.index_factory(12, "HNSW32,SQ8",
|
||||
faiss.METRIC_INNER_PRODUCT)
|
||||
assert index.storage.sa_code_size() == 12
|
||||
assert index.metric_type == faiss.METRIC_INNER_PRODUCT
|
||||
index = faiss.index_factory(12, "HNSW,PQ4")
|
||||
|
@ -131,7 +132,8 @@ class TestFactory(unittest.TestCase):
|
|||
self.assertEqual(index.pq.nbits, 4)
|
||||
index = faiss.index_factory(56, "PQ28x4fs_64")
|
||||
self.assertEqual(index.bbs, 64)
|
||||
index = faiss.index_factory(56, "IVF50,PQ28x4fs_64", faiss.METRIC_INNER_PRODUCT)
|
||||
index = faiss.index_factory(56, "IVF50,PQ28x4fs_64",
|
||||
faiss.METRIC_INNER_PRODUCT)
|
||||
self.assertEqual(index.bbs, 64)
|
||||
self.assertEqual(index.nlist, 50)
|
||||
self.assertTrue(index.cp.spherical)
|
||||
|
@ -158,7 +160,6 @@ class TestFactory(unittest.TestCase):
|
|||
self.assertEqual(rf.pq.M, 25)
|
||||
self.assertEqual(rf.pq.nbits, 12)
|
||||
|
||||
|
||||
def test_parenthesis_refine_2(self):
|
||||
# Refine applies on the whole index including pre-transforms
|
||||
index = faiss.index_factory(50, "PCA32,IVF32,Flat,Refine(PQ25x12)")
|
||||
|
@ -264,6 +265,19 @@ class TestFactoryV2(unittest.TestCase):
|
|||
index = faiss.downcast_index(index)
|
||||
self.assertEqual(index.__class__, faiss.IndexIDMap2)
|
||||
|
||||
def test_idmap_refine(self):
|
||||
index = faiss.index_factory(8, "IDMap,PQ4x4fs,RFlat")
|
||||
self.assertEqual(index.__class__, faiss.IndexIDMap)
|
||||
refine_index = faiss.downcast_index(index.index)
|
||||
self.assertEqual(refine_index.__class__, faiss.IndexRefineFlat)
|
||||
base_index = faiss.downcast_index(refine_index.base_index)
|
||||
self.assertEqual(base_index.__class__, faiss.IndexPQFastScan)
|
||||
|
||||
# Index now works with add_with_ids, but not with add
|
||||
index.train(np.zeros((16, 8)))
|
||||
index.add_with_ids(np.zeros((16, 8)), np.arange(16))
|
||||
self.assertRaises(RuntimeError, index.add, np.zeros((16, 8)))
|
||||
|
||||
def test_ivf_hnsw(self):
|
||||
index = faiss.index_factory(123, "IVF100_HNSW,Flat")
|
||||
quantizer = faiss.downcast_index(index.quantizer)
|
||||
|
@ -337,4 +351,4 @@ class TestIVFSpectralHashOwnership(unittest.TestCase):
|
|||
index = faiss.IndexIVFSpectralHash(faiss.IndexFlat(10), 10, 20, 10, 1)
|
||||
index.replace_vt(faiss.ITQTransform(10, 10))
|
||||
gc.collect()
|
||||
index.vt.d_out # this should not crash
|
||||
index.vt.d_out # this should not crash
|
||||
|
|
Loading…
Reference in New Issue