From a64c76fadd5227deda291d124a913859f3684c1c Mon Sep 17 00:00:00 2001 From: Matthijs Douze Date: Tue, 4 Oct 2022 07:54:00 -0700 Subject: [PATCH] Fix sub-object ownership of python interface of IVFSpectralHash Summary: Code would crash when deallocating the coarse quantizer for a IVFSpectralHash. Reviewed By: algoriddle Differential Revision: D40053030 fbshipit-source-id: 6a2987a6983f0e5fc5c5b6296d9000354176af83 --- faiss/python/__init__.py | 20 ++++++++++++++++++++ tests/test_factory.py | 15 +++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/faiss/python/__init__.py b/faiss/python/__init__.py index e67df6f4c..52e8410a5 100644 --- a/faiss/python/__init__.py +++ b/faiss/python/__init__.py @@ -105,6 +105,23 @@ def add_ref_in_method(the_class, method_name, parameter_no): setattr(the_class, method_name, replacement_method) +def add_ref_in_method_explicit_own(the_class, method_name): + # for methods of format set_XXX(object, own) + original_method = getattr(the_class, method_name) + + def replacement_method(self, ref, own=False): + if not own: + if not hasattr(self, 'referenced_objects'): + self.referenced_objects = [ref] + else: + self.referenced_objects.append(ref) + else: + # transfer ownership to C++ class + ref.this.disown() + return original_method(self, ref, own) + setattr(the_class, method_name, replacement_method) + + def add_ref_in_function(function_name, parameter_no): # assumes the function returns an object original_function = getattr(this_module, function_name) @@ -128,6 +145,9 @@ add_ref_in_constructor(IndexIVFResidualQuantizer, 0) add_ref_in_constructor(IndexIVFLocalSearchQuantizer, 0) add_ref_in_constructor(IndexIVFResidualQuantizerFastScan, 0) add_ref_in_constructor(IndexIVFLocalSearchQuantizerFastScan, 0) +add_ref_in_constructor(IndexIVFSpectralHash, 0) +add_ref_in_method_explicit_own(IndexIVFSpectralHash, "replace_vt") + add_ref_in_constructor(Index2Layer, 0) add_ref_in_constructor(Level1Quantizer, 0) add_ref_in_constructor(IndexIVFScalarQuantizer, 0) diff --git a/tests/test_factory.py b/tests/test_factory.py index 735d25e3c..d26b8f4c6 100644 --- a/tests/test_factory.py +++ b/tests/test_factory.py @@ -6,6 +6,7 @@ import numpy as np import unittest +import gc import faiss from faiss.contrib import factory_tools @@ -296,3 +297,17 @@ class TestQuantizerClone(unittest.TestCase): codes2 = quant2.compute_codes(ds.get_database()) np.testing.assert_array_equal(codes, codes2) + + +class TestIVFSpectralHashOwnerhsip(unittest.TestCase): + + def test_constructor(self): + index = faiss.IndexIVFSpectralHash(faiss.IndexFlat(10), 10, 20, 10, 1) + gc.collect() + index.quantizer.ntotal # this should not crash + + def test_replace_vt(self): + index = faiss.IndexIVFSpectralHash(faiss.IndexFlat(10), 10, 20, 10, 1) + index.replace_vt(faiss.ITQTransform(10, 10)) + gc.collect() + index.vt.d_out # this should not crash