Fix sub-object ownership of python interface of IVFSpectralHash

Summary: Code would crash when deallocating the coarse quantizer for a IVFSpectralHash.

Reviewed By: algoriddle

Differential Revision: D40053030

fbshipit-source-id: 6a2987a6983f0e5fc5c5b6296d9000354176af83
pull/2512/head
Matthijs Douze 2022-10-04 07:54:00 -07:00 committed by Facebook GitHub Bot
parent c5b49b79df
commit a64c76fadd
2 changed files with 35 additions and 0 deletions

View File

@ -105,6 +105,23 @@ def add_ref_in_method(the_class, method_name, parameter_no):
setattr(the_class, method_name, replacement_method)
def add_ref_in_method_explicit_own(the_class, method_name):
# for methods of format set_XXX(object, own)
original_method = getattr(the_class, method_name)
def replacement_method(self, ref, own=False):
if not own:
if not hasattr(self, 'referenced_objects'):
self.referenced_objects = [ref]
else:
self.referenced_objects.append(ref)
else:
# transfer ownership to C++ class
ref.this.disown()
return original_method(self, ref, own)
setattr(the_class, method_name, replacement_method)
def add_ref_in_function(function_name, parameter_no):
# assumes the function returns an object
original_function = getattr(this_module, function_name)
@ -128,6 +145,9 @@ add_ref_in_constructor(IndexIVFResidualQuantizer, 0)
add_ref_in_constructor(IndexIVFLocalSearchQuantizer, 0)
add_ref_in_constructor(IndexIVFResidualQuantizerFastScan, 0)
add_ref_in_constructor(IndexIVFLocalSearchQuantizerFastScan, 0)
add_ref_in_constructor(IndexIVFSpectralHash, 0)
add_ref_in_method_explicit_own(IndexIVFSpectralHash, "replace_vt")
add_ref_in_constructor(Index2Layer, 0)
add_ref_in_constructor(Level1Quantizer, 0)
add_ref_in_constructor(IndexIVFScalarQuantizer, 0)

View File

@ -6,6 +6,7 @@
import numpy as np
import unittest
import gc
import faiss
from faiss.contrib import factory_tools
@ -296,3 +297,17 @@ class TestQuantizerClone(unittest.TestCase):
codes2 = quant2.compute_codes(ds.get_database())
np.testing.assert_array_equal(codes, codes2)
class TestIVFSpectralHashOwnerhsip(unittest.TestCase):
def test_constructor(self):
index = faiss.IndexIVFSpectralHash(faiss.IndexFlat(10), 10, 20, 10, 1)
gc.collect()
index.quantizer.ntotal # this should not crash
def test_replace_vt(self):
index = faiss.IndexIVFSpectralHash(faiss.IndexFlat(10), 10, 20, 10, 1)
index.replace_vt(faiss.ITQTransform(10, 10))
gc.collect()
index.vt.d_out # this should not crash