diff --git a/faiss/python/__init__.py b/faiss/python/__init__.py index e67df6f4c..52e8410a5 100644 --- a/faiss/python/__init__.py +++ b/faiss/python/__init__.py @@ -105,6 +105,23 @@ def add_ref_in_method(the_class, method_name, parameter_no): setattr(the_class, method_name, replacement_method) +def add_ref_in_method_explicit_own(the_class, method_name): + # for methods of format set_XXX(object, own) + original_method = getattr(the_class, method_name) + + def replacement_method(self, ref, own=False): + if not own: + if not hasattr(self, 'referenced_objects'): + self.referenced_objects = [ref] + else: + self.referenced_objects.append(ref) + else: + # transfer ownership to C++ class + ref.this.disown() + return original_method(self, ref, own) + setattr(the_class, method_name, replacement_method) + + def add_ref_in_function(function_name, parameter_no): # assumes the function returns an object original_function = getattr(this_module, function_name) @@ -128,6 +145,9 @@ add_ref_in_constructor(IndexIVFResidualQuantizer, 0) add_ref_in_constructor(IndexIVFLocalSearchQuantizer, 0) add_ref_in_constructor(IndexIVFResidualQuantizerFastScan, 0) add_ref_in_constructor(IndexIVFLocalSearchQuantizerFastScan, 0) +add_ref_in_constructor(IndexIVFSpectralHash, 0) +add_ref_in_method_explicit_own(IndexIVFSpectralHash, "replace_vt") + add_ref_in_constructor(Index2Layer, 0) add_ref_in_constructor(Level1Quantizer, 0) add_ref_in_constructor(IndexIVFScalarQuantizer, 0) diff --git a/tests/test_factory.py b/tests/test_factory.py index 735d25e3c..d26b8f4c6 100644 --- a/tests/test_factory.py +++ b/tests/test_factory.py @@ -6,6 +6,7 @@ import numpy as np import unittest +import gc import faiss from faiss.contrib import factory_tools @@ -296,3 +297,17 @@ class TestQuantizerClone(unittest.TestCase): codes2 = quant2.compute_codes(ds.get_database()) np.testing.assert_array_equal(codes, codes2) + + +class TestIVFSpectralHashOwnerhsip(unittest.TestCase): + + def test_constructor(self): + index = faiss.IndexIVFSpectralHash(faiss.IndexFlat(10), 10, 20, 10, 1) + gc.collect() + index.quantizer.ntotal # this should not crash + + def test_replace_vt(self): + index = faiss.IndexIVFSpectralHash(faiss.IndexFlat(10), 10, 20, 10, 1) + index.replace_vt(faiss.ITQTransform(10, 10)) + gc.collect() + index.vt.d_out # this should not crash