Automatic type conversions for Python API (#2274)
Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2274 Previously, all input matrices had to be of the correct type and C-contiguous. This diff passes the array arguments of the main API entry points through `np.ascontiguousarray`, so that parameters are transparently converted to the suitable dtype and layout when needed. We did not do this before because users should be aware of the performance impact of the hidden copy, but usability arguably matters more. This diff is an alternative to D35007365 (https://github.com/facebookresearch/faiss/pull/2250). Reviewed By: beauby Differential Revision: D35009612 fbshipit-source-id: fa0d5cfdfbff6b0916d47bd33c620e3ca9d5dd40
parent b32abc95c2
commit 1806c6af27
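As a rough sketch of the new behavior (illustrative dimensions and random data; `IndexFlatL2`, `add` and `search` are the standard faiss entry points), inputs that are not float32 or not C-contiguous are now copied into the right layout instead of being rejected:

    import numpy as np
    import faiss

    d = 32
    index = faiss.IndexFlatL2(d)

    # float64 data previously had to be converted by the caller; the wrapper
    # now routes it through np.ascontiguousarray (at the cost of a hidden copy)
    xb = np.random.rand(1000, d)        # default dtype is float64
    index.add(xb)

    # a Fortran-ordered (non C-contiguous) query matrix is copied to C order
    xq = np.asfortranarray(np.random.rand(5, d).astype('float32'))
    D, I = index.search(xq, 4)
    print(I.shape)                      # (5, 4)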
@@ -28,6 +28,16 @@ __version__ = "%d.%d.%d" % (FAISS_VERSION_MAJOR,
 # The C++ version of the classnames will be suffixed with _c
 ##################################################################
 
+# For most arrays we force the conversion to the target type with
+# np.ascontiguousarray, but for uint8 codes, we raise a type error
+# because it is unclear how the conversion should occur: with a view
+# (= cast) or conversion?
+def _check_dtype_uint8(codes):
+    if codes.dtype != 'uint8':
+        raise TypeError("Input argument %s must be ndarray of dtype "
+                        "uint8, but found %s" % ("codes", codes.dtype))
+    return np.ascontiguousarray(codes)
+
 
 def replace_method(the_class, name, replacement, ignore_missing=False):
     """ Replaces a method in a class with another version. The old method
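The view-versus-conversion ambiguity the comment refers to is easy to demonstrate; a minimal sketch (plain NumPy, nothing faiss-specific):

    import numpy as np

    x = np.array([[1.0, 2.0]], dtype='float32')

    # view (= cast): reinterpret the same bytes as uint8
    x.view('uint8')      # shape (1, 8), the raw IEEE-754 bytes

    # conversion: round each value to a uint8
    x.astype('uint8')    # shape (1, 2), values [1, 2]

    # shapes and contents differ, so the wrapper refuses to guess
    # and raises TypeError instead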
@@ -60,8 +70,10 @@ def handle_Clustering():
             average to obtain the centroid (default is 1 for all training vectors).
         """
         n, d = x.shape
+        x = np.ascontiguousarray(x, dtype='float32')
         assert d == self.d
         if weights is not None:
+            weights = np.ascontiguousarray(weights, dtype='float32')
             assert weights.shape == (n, )
             self.train_c(n, swig_ptr(x), index, swig_ptr(weights))
         else:
@@ -84,9 +96,11 @@ def handle_Clustering():
             average to obtain the centroid (default is 1 for all training vectors).
         """
         n, d = x.shape
+        x = _check_dtype_uint8(x)
         assert d == codec.sa_code_size()
         assert codec.d == index.d
         if weights is not None:
+            weights = np.ascontiguousarray(weights, dtype='float32')
             assert weights.shape == (n, )
             self.train_encoded_c(n, swig_ptr(x), codec, index, swig_ptr(weights))
         else:
@@ -110,6 +124,7 @@ def handle_Clustering1D():
             Training vectors, shape (n, 1). `dtype` must be float32.
         """
         n, d = x.shape
+        x = np.ascontiguousarray(x, dtype='float32')
         assert d == self.d
         self.train_exact_c(n, swig_ptr(x))
 
@@ -130,6 +145,7 @@ def handle_Quantizer(the_class):
             Training vectors, shape (n, self.d). `dtype` must be float32.
         """
         n, d = x.shape
+        x = np.ascontiguousarray(x, dtype='float32')
         assert d == self.d
         self.train_c(n, swig_ptr(x))
@@ -148,6 +164,7 @@ def handle_Quantizer(the_class):
             and `dtype` uint8.
         """
         n, d = x.shape
+        x = np.ascontiguousarray(x, dtype='float32')
         assert d == self.d
         codes = np.empty((n, self.code_size), dtype='uint8')
         self.compute_codes_c(swig_ptr(x), swig_ptr(codes), n)
@@ -166,6 +183,7 @@ def handle_Quantizer(the_class):
             Reconstructed vectors for each code, shape `(n, d)` and `dtype` float32.
         """
         n, cs = codes.shape
+        codes = _check_dtype_uint8(codes)
         assert cs == self.code_size
         x = np.empty((n, self.d), dtype='float32')
         self.decode_c(swig_ptr(codes), swig_ptr(x), n)
@@ -190,6 +208,8 @@ def handle_NSG(the_class):
         assert graph.ndim == 2
         assert graph.shape[0] == n
         K = graph.shape[1]
+        x = np.ascontiguousarray(x, dtype='float32')
+        graph = np.ascontiguousarray(graph, dtype='int64')
         self.build_c(n, swig_ptr(x), swig_ptr(graph), K)
 
     replace_method(the_class, 'build', replacement_build)
@@ -212,6 +232,7 @@ def handle_Index(the_class):
 
         n, d = x.shape
         assert d == self.d
+        x = np.ascontiguousarray(x, dtype='float32')
         self.add_c(n, swig_ptr(x))
 
     def replacement_add_with_ids(self, x, ids):
@@ -230,7 +251,8 @@ def handle_Index(the_class):
         """
         n, d = x.shape
         assert d == self.d
 
+        x = np.ascontiguousarray(x, dtype='float32')
         ids = np.ascontiguousarray(ids, dtype='int64')
         assert ids.shape == (n, ), 'not same nb of vectors as ids'
         self.add_with_ids_c(n, swig_ptr(x), swig_ptr(ids))
 
@@ -256,6 +278,7 @@ def handle_Index(the_class):
         """
         n, d = x.shape
         assert d == self.d
+        x = np.ascontiguousarray(x, dtype='float32')
 
         if labels is None:
             labels = np.empty((n, k), dtype=np.int64)
@@ -277,6 +300,7 @@ def handle_Index(the_class):
         """
         n, d = x.shape
         assert d == self.d
+        x = np.ascontiguousarray(x, dtype='float32')
         self.train_c(n, swig_ptr(x))
 
     def replacement_search(self, x, k, D=None, I=None):
@@ -305,6 +329,7 @@ def handle_Index(the_class):
         """
 
         n, d = x.shape
+        x = np.ascontiguousarray(x, dtype='float32')
         assert d == self.d
 
         assert k > 0
@@ -353,6 +378,7 @@ def handle_Index(the_class):
         """
         n, d = x.shape
         assert d == self.d
+        x = np.ascontiguousarray(x, dtype='float32')
 
         assert k > 0
 
@@ -398,6 +424,7 @@ def handle_Index(the_class):
         else:
             assert x.ndim == 1
             index_ivf = try_extract_index_ivf(self)
+            x = np.ascontiguousarray(x, dtype='int64')
             if index_ivf and index_ivf.direct_map.type == DirectMap.Hashtable:
                 sel = IDSelectorArray(x.size, swig_ptr(x))
             else:
@@ -457,7 +484,8 @@ def handle_Index(the_class):
         n = keys.size
         assert keys.shape == (n, )
         assert x.shape == (n, self.d)
 
+        x = np.ascontiguousarray(x, dtype='float32')
+        keys = np.ascontiguousarray(keys, dtype='int64')
         self.update_vectors_c(n, swig_ptr(keys), swig_ptr(x))
 
         # The CPU does not support passed-in output buffers
@@ -488,6 +516,7 @@ def handle_Index(the_class):
         """
         n, d = x.shape
         assert d == self.d
+        x = np.ascontiguousarray(x, dtype='float32')
 
         res = RangeSearchResult(n)
         self.range_search_c(n, swig_ptr(x), thresh, res)
@@ -501,6 +530,7 @@ def handle_Index(the_class):
     def replacement_sa_encode(self, x, codes=None):
         n, d = x.shape
         assert d == self.d
+        x = np.ascontiguousarray(x, dtype='float32')
 
         if codes is None:
             codes = np.empty((n, self.sa_code_size()), dtype=np.uint8)
@@ -513,6 +543,7 @@ def handle_Index(the_class):
     def replacement_sa_decode(self, codes, x=None):
         n, cs = codes.shape
         assert cs == self.sa_code_size()
+        codes = _check_dtype_uint8(codes)
 
         if x is None:
             x = np.empty((n, self.d), dtype=np.float32)
@@ -525,6 +556,8 @@ def handle_Index(the_class):
     def replacement_add_sa_codes(self, codes, ids=None):
         n, cs = codes.shape
         assert cs == self.sa_code_size()
+        codes = _check_dtype_uint8(codes)
+
         if ids is not None:
             assert ids.shape == (n,)
             ids = swig_ptr(ids)
@@ -568,17 +601,21 @@ def handle_IndexBinary(the_class):
 
     def replacement_add(self, x):
         n, d = x.shape
+        x = _check_dtype_uint8(x)
         assert d * 8 == self.d
         self.add_c(n, swig_ptr(x))
 
     def replacement_add_with_ids(self, x, ids):
         n, d = x.shape
+        x = _check_dtype_uint8(x)
+        ids = np.ascontiguousarray(ids, dtype='int64')
         assert d * 8 == self.d
         assert ids.shape == (n, ), 'not same nb of vectors as ids'
         self.add_with_ids_c(n, swig_ptr(x), swig_ptr(ids))
 
     def replacement_train(self, x):
         n, d = x.shape
+        x = _check_dtype_uint8(x)
         assert d * 8 == self.d
         self.train_c(n, swig_ptr(x))
 
@@ -588,6 +625,7 @@ def handle_IndexBinary(the_class):
         return x
 
     def replacement_search(self, x, k):
+        x = _check_dtype_uint8(x)
         n, d = x.shape
         assert d * 8 == self.d
         assert k > 0
@@ -600,6 +638,7 @@ def handle_IndexBinary(the_class):
 
     def replacement_range_search(self, x, thresh):
         n, d = x.shape
+        x = _check_dtype_uint8(x)
         assert d * 8 == self.d
         res = RangeSearchResult(n)
         self.range_search_c(n, swig_ptr(x), thresh, res)
@@ -615,6 +654,7 @@ def handle_IndexBinary(the_class):
             sel = x
         else:
             assert x.ndim == 1
+            x = np.ascontiguousarray(x, dtype='int64')
             sel = IDSelectorBatch(x.size, swig_ptr(x))
         return self.remove_ids_c(sel)
 
@@ -631,6 +671,7 @@ def handle_VectorTransform(the_class):
 
     def apply_method(self, x):
         n, d = x.shape
+        x = np.ascontiguousarray(x, dtype='float32')
         assert d == self.d_in
         y = np.empty((n, self.d_out), dtype=np.float32)
         self.apply_noalloc(n, swig_ptr(x), swig_ptr(y))
@@ -638,6 +679,7 @@ def handle_VectorTransform(the_class):
 
     def replacement_reverse_transform(self, x):
         n, d = x.shape
+        x = np.ascontiguousarray(x, dtype='float32')
         assert d == self.d_out
         y = np.empty((n, self.d_in), dtype=np.float32)
         self.reverse_transform_c(n, swig_ptr(x), swig_ptr(y))
@@ -645,6 +687,7 @@ def handle_VectorTransform(the_class):
 
     def replacement_vt_train(self, x):
         n, d = x.shape
+        x = np.ascontiguousarray(x, dtype='float32')
         assert d == self.d_in
         self.train_c(n, swig_ptr(x))
 
@@ -676,6 +719,7 @@ def handle_AutoTuneCriterion(the_class):
 def handle_ParameterSpace(the_class):
     def replacement_explore(self, index, xq, crit):
         assert xq.shape == (crit.nq, index.d)
+        xq = np.ascontiguousarray(xq, dtype='float32')
         ops = OperatingPoints()
         self.explore_c(index, crit.nq, swig_ptr(xq),
                        crit, ops)
@@ -688,6 +732,7 @@ def handle_MatrixStats(the_class):
 
     def replacement_init(self, m):
         assert len(m.shape) == 2
+        m = np.ascontiguousarray(m, dtype='float32')
         original_init(self, m.shape[0], m.shape[1], swig_ptr(m))
 
     the_class.__init__ = replacement_init
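For instance, `MatrixStats` can now be fed a float64 matrix directly; a small usage sketch (the `comments` field is the textual report MatrixStats exposes, sizes are illustrative):

    import numpy as np
    import faiss

    m = np.random.rand(1000, 64)       # float64, converted by the wrapper
    stats = faiss.MatrixStats(m)
    print(stats.comments)              # textual report on the matrix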
@@ -937,7 +982,8 @@ def knn_gpu(res, xq, xb, k, D=None, I=None, metric=METRIC_L2):
         xq = xq.T
         xq_row_major = False
     else:
-        raise TypeError('xq matrix should be row (C) or column-major (Fortran)')
+        xq = np.ascontiguousarray(xq, dtype='float32')
+        xq_row_major = True
 
     xq_ptr = swig_ptr(xq)
 
@@ -956,7 +1002,8 @@ def knn_gpu(res, xq, xb, k, D=None, I=None, metric=METRIC_L2):
         xb = xb.T
         xb_row_major = False
     else:
-        raise TypeError('xb matrix should be row (C) or column-major (Fortran)')
+        xb = np.ascontiguousarray(xb, dtype='float32')
+        xb_row_major = True
 
     xb_ptr = swig_ptr(xb)
 
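The branches above (the `if`/`elif` heads sit just outside the hunk) appear to decide between keeping C order, transposing an F-ordered view, and, new in this diff, copying anything else. A standalone sketch of that decision, with an illustrative helper name `as_row_major`:

    import numpy as np

    def as_row_major(mat):
        # keep C-order as-is, transpose F-order views, copy everything else
        if mat.flags.c_contiguous:
            return mat, True
        if mat.flags.f_contiguous:
            return mat.T, False
        return np.ascontiguousarray(mat, dtype='float32'), True

    a = np.random.rand(10, 4).astype('float32')
    print(as_row_major(a)[1])          # True: already C-contiguous
    print(as_row_major(a.T)[1])        # False: a.T is Fortran-contiguous
    print(as_row_major(a[:, ::2])[1])  # True, via an explicit copy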
@@ -1053,7 +1100,8 @@ def pairwise_distance_gpu(res, xq, xb, D=None, metric=METRIC_L2):
     elif xq.dtype == np.float16:
         xq_type = DistanceDataType_F16
     else:
-        raise TypeError('xq must be float32 or float16')
+        xq = np.ascontiguousarray(xq, dtype='float32')
+        xq_row_major = True
 
     nb, d2 = xb.shape
     assert d2 == d
@@ -1063,7 +1111,8 @@ def pairwise_distance_gpu(res, xq, xb, D=None, metric=METRIC_L2):
         xb = xb.T
         xb_row_major = False
     else:
-        raise TypeError('xb matrix should be row (C) or column-major (Fortran)')
+        xb = np.ascontiguousarray(xb, dtype='float32')
+        xb_row_major = True
 
     xb_ptr = swig_ptr(xb)
 
@@ -1212,6 +1261,7 @@ def AlignedTable_to_array(v):
 def kmin(array, k):
     """return k smallest values (and their indices) of the lines of a
     float32 array"""
+    array = np.ascontiguousarray(array, dtype='float32')
     m, n = array.shape
     I = np.zeros((m, k), dtype='int64')
     D = np.zeros((m, k), dtype='float32')
@@ -1229,6 +1279,7 @@ def kmin(array, k):
 def kmax(array, k):
     """return k largest values (and their indices) of the lines of a
     float32 array"""
+    array = np.ascontiguousarray(array, dtype='float32')
     m, n = array.shape
     I = np.zeros((m, k), dtype='int64')
     D = np.zeros((m, k), dtype='float32')
@@ -1246,6 +1297,8 @@ def kmax(array, k):
 def pairwise_distances(xq, xb, mt=METRIC_L2, metric_arg=0):
     """compute the whole pairwise distance matrix between two sets of
     vectors"""
+    xq = np.ascontiguousarray(xq, dtype='float32')
+    xb = np.ascontiguousarray(xb, dtype='float32')
     nq, d = xq.shape
     nb, d2 = xb.shape
     assert d == d2
@@ -1296,6 +1349,8 @@ def rand_smooth_vectors(n, d, seed=1234):
 
 def eval_intersection(I1, I2):
     """ size of intersection between each line of two result tables"""
+    I1 = np.ascontiguousarray(I1, dtype='int64')
+    I2 = np.ascontiguousarray(I2, dtype='int64')
     n = I1.shape[0]
     assert I2.shape[0] == n
     k1, k2 = I1.shape[1], I2.shape[1]
@@ -1334,6 +1389,7 @@ replace_method(MapLong2Long, 'search_multiple', replacement_map_search_multiple)
 search_with_parameters_c = search_with_parameters
 
 def search_with_parameters(index, x, k, params=None, output_stats=False):
+    x = np.ascontiguousarray(x, dtype='float32')
     n, d = x.shape
     assert d == index.d
     if not params:
@@ -1366,6 +1422,7 @@ def search_with_parameters(index, x, k, params=None, output_stats=False):
 range_search_with_parameters_c = range_search_with_parameters
 
 def range_search_with_parameters(index, x, radius, params=None, output_stats=False):
+    x = np.ascontiguousarray(x, dtype='float32')
     n, d = x.shape
     assert d == index.d
     if not params:
@@ -1427,6 +1484,8 @@ def knn(xq, xb, k, metric=METRIC_L2):
     I : array_like
         Labels of the nearest neighbors, shape (nq, k)
     """
+    xq = np.ascontiguousarray(xq, dtype='float32')
+    xb = np.ascontiguousarray(xb, dtype='float32')
     nq, d = xq.shape
     nb, d2 = xb.shape
     assert d == d2
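With the conversions in place, `faiss.knn` accepts plain float64 arrays; a quick sketch (random data, illustrative sizes):

    import numpy as np
    import faiss

    xb = np.random.rand(1000, 16)      # database; float64 is fine now
    xq = np.random.rand(10, 16)        # queries
    D, I = faiss.knn(xq, xb, 5)        # brute-force k-NN, METRIC_L2 by default
    print(D.shape, I.shape)            # (10, 5) (10, 5)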
@@ -1537,6 +1596,7 @@ class Kmeans:
             final optimization objective
 
         """
+        x = np.ascontiguousarray(x, dtype='float32')
        n, d = x.shape
        assert d == self.d
 
@@ -1581,6 +1641,7 @@ class Kmeans:
         return self.obj[-1] if self.obj.size > 0 else 0.0
 
     def assign(self, x):
+        x = np.ascontiguousarray(x, dtype='float32')
         assert self.centroids is not None, "should train before assigning"
         self.index.reset()
         self.index.add(self.centroids)
@@ -1647,6 +1708,8 @@ class ResultHeap:
     def add_result(self, D, I):
         """D, I do not need to be in a particular order (heap or sorted)"""
         nq, kd = D.shape
+        D = np.ascontiguousarray(D, dtype='float32')
+        I = np.ascontiguousarray(I, dtype='int64')
         assert I.shape == (nq, kd)
         assert nq == self.nq
         self.heaps.addn_with_ids(
@@ -340,5 +340,6 @@ class TestTorchUtilsCPU(unittest.TestCase):
         with self.assertRaises(AssertionError):
             index.add(xb)
 
-        with self.assertRaises(ValueError):
-            index.add(xb.numpy())
+        # disabled since we now accept non-contiguous arrays
+        # with self.assertRaises(ValueError):
+        #     index.add(xb.numpy())
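The disabled check mirrors the new behavior: the numpy view of a non-contiguous torch tensor used to raise ValueError and is now copied and ingested. A hedged sketch of the same scenario outside the test harness (dimension and sizes are illustrative):

    import numpy as np
    import torch
    import faiss

    d = 16
    index = faiss.IndexFlatL2(d)

    xb = torch.rand(d, 100).T           # transposed: not C-contiguous
    assert not xb.numpy().flags.c_contiguous
    index.add(xb.numpy())               # formerly ValueError, now a hidden copy
    print(index.ntotal)                 # 100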