faiss/tests/test_partition.py

348 lines
10 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import faiss
import unittest
class PartitionTests:
def test_partition(self):
self.do_partition(160, 80)
def test_partition_manydups(self):
self.do_partition(160, 80, maxval=16)
def test_partition_lowq(self):
self.do_partition(160, 10, maxval=16)
def test_partition_highq(self):
self.do_partition(165, 155, maxval=16)
def test_partition_q10(self):
self.do_partition(32, 10, maxval=500)
def test_partition_q10_dups(self):
self.do_partition(32, 10, maxval=16)
def test_partition_q10_fuzzy(self):
self.do_partition(32, (10, 15), maxval=500)
def test_partition_fuzzy(self):
self.do_partition(160, (70, 80), maxval=500)
def test_partition_fuzzy_2(self):
self.do_partition(160, (70, 80))
class TestPartitioningFloat(unittest.TestCase, PartitionTests):
def do_partition(self, n, q, maxval=None, seed=None):
if seed is None:
for i in range(50):
self.do_partition(n, q, maxval, i + 1234)
# print("seed=", seed)
rs = np.random.RandomState(seed)
if maxval is None:
vals = rs.rand(n).astype('float32')
else:
vals = rs.randint(maxval, size=n).astype('float32')
ids = (rs.permutation(n) + 12345).astype('int64')
dic = dict(zip(ids, vals))
vals_orig = vals.copy()
sp = faiss.swig_ptr
if type(q) == int:
faiss.CMax_float_partition_fuzzy(
sp(vals), sp(ids), n,
q, q, None
)
else:
q_min, q_max = q
q = np.array([-1], dtype='uint64')
faiss.CMax_float_partition_fuzzy(
sp(vals), sp(ids), n,
q_min, q_max, sp(q)
)
q = q[0]
assert q_min <= q <= q_max
o = vals_orig.argsort()
thresh = vals_orig[o[q]]
n_eq = (vals_orig[o[:q]] == thresh).sum()
for i in range(q):
self.assertEqual(vals[i], dic[ids[i]])
self.assertLessEqual(vals[i], thresh)
if vals[i] == thresh:
n_eq -= 1
self.assertEqual(n_eq, 0)
class TestPartitioningFloatMin(unittest.TestCase, PartitionTests):
def do_partition(self, n, q, maxval=None, seed=None):
if seed is None:
for i in range(50):
self.do_partition(n, q, maxval, i + 1234)
# print("seed=", seed)
rs = np.random.RandomState(seed)
if maxval is None:
vals = rs.rand(n).astype('float32')
mirval = 1.0
else:
vals = rs.randint(maxval, size=n).astype('float32')
mirval = 65536
ids = (rs.permutation(n) + 12345).astype('int64')
dic = dict(zip(ids, vals))
vals_orig = vals.copy()
vals[:] = mirval - vals
sp = faiss.swig_ptr
if type(q) == int:
faiss.CMin_float_partition_fuzzy(
sp(vals), sp(ids), n,
q, q, None
)
else:
q_min, q_max = q
q = np.array([-1], dtype='uint64')
faiss.CMin_float_partition_fuzzy(
sp(vals), sp(ids), n,
q_min, q_max, sp(q)
)
q = q[0]
assert q_min <= q <= q_max
vals[:] = mirval - vals
o = vals_orig.argsort()
thresh = vals_orig[o[q]]
n_eq = (vals_orig[o[:q]] == thresh).sum()
for i in range(q):
np.testing.assert_almost_equal(vals[i], dic[ids[i]], decimal=5)
self.assertLessEqual(vals[i], thresh)
if vals[i] == thresh:
n_eq -= 1
self.assertEqual(n_eq, 0)
class TestPartitioningUint16(unittest.TestCase, PartitionTests):
def do_partition(self, n, q, maxval=65536, seed=None):
if seed is None:
for i in range(50):
self.do_partition(n, q, maxval, i + 1234)
# print("seed=", seed)
rs = np.random.RandomState(seed)
vals = rs.randint(maxval, size=n).astype('uint16')
ids = (rs.permutation(n) + 12345).astype('int64')
dic = dict(zip(ids, vals))
sp = faiss.swig_ptr
vals_orig = vals.copy()
tab_a = faiss.AlignedTableUint16()
faiss.copy_array_to_AlignedTable(vals, tab_a)
# print("tab a type", tab_a.get())
if type(q) == int:
faiss.CMax_uint16_partition_fuzzy(
tab_a.get(), sp(ids), n, q, q, None)
else:
q_min, q_max = q
q = np.array([-1], dtype='uint64')
faiss.CMax_uint16_partition_fuzzy(
tab_a.get(), sp(ids), n,
q_min, q_max, sp(q)
)
q = q[0]
assert q_min <= q <= q_max
vals = faiss.AlignedTable_to_array(tab_a)
o = vals_orig.argsort()
thresh = vals_orig[o[q]]
n_eq = (vals_orig[o[:q]] == thresh).sum()
for i in range(q):
self.assertEqual(vals[i], dic[ids[i]])
self.assertLessEqual(vals[i], thresh)
if vals[i] == thresh:
n_eq -= 1
self.assertEqual(n_eq, 0)
class TestPartitioningUint16Min(unittest.TestCase, PartitionTests):
def do_partition(self, n, q, maxval=65536, seed=None):
#seed = 1235
if seed is None:
for i in range(50):
self.do_partition(n, q, maxval, i + 1234)
# print("seed=", seed)
rs = np.random.RandomState(seed)
vals = rs.randint(maxval, size=n).astype('uint16')
ids = (rs.permutation(n) + 12345).astype('int64')
dic = dict(zip(ids, vals))
sp = faiss.swig_ptr
vals_orig = vals.copy()
tab_a = faiss.AlignedTableUint16()
vals_inv = (65535 - vals).astype('uint16')
faiss.copy_array_to_AlignedTable(vals_inv, tab_a)
# print("tab a type", tab_a.get())
if type(q) == int:
faiss.CMin_uint16_partition_fuzzy(
tab_a.get(), sp(ids), n, q, q, None)
else:
q_min, q_max = q
q = np.array([-1], dtype='uint64')
thresh2 = faiss.CMin_uint16_partition_fuzzy(
tab_a.get(), sp(ids), n,
q_min, q_max, sp(q)
)
q = q[0]
assert q_min <= q <= q_max
vals_inv = faiss.AlignedTable_to_array(tab_a)
vals = 65535 - vals_inv
o = vals_orig.argsort()
thresh = vals_orig[o[q]]
n_eq = (vals_orig[o[:q]] == thresh).sum()
for i in range(q):
self.assertEqual(vals[i], dic[ids[i]])
self.assertLessEqual(vals[i], thresh)
if vals[i] == thresh:
n_eq -= 1
self.assertEqual(n_eq, 0)
class TestHistograms(unittest.TestCase):
def do_test(self, nbin, n):
rs = np.random.RandomState(123)
tab = rs.randint(nbin, size=n).astype('uint16')
ref_histogram = np.bincount(tab, minlength=nbin)
tab_a = faiss.AlignedTableUint16()
faiss.copy_array_to_AlignedTable(tab, tab_a)
sp = faiss.swig_ptr
hist = np.zeros(nbin, 'int32')
if nbin == 8:
faiss.simd_histogram_8(tab_a.get(), n, 0, -1, sp(hist))
elif nbin == 16:
faiss.simd_histogram_16(tab_a.get(), n, 0, -1, sp(hist))
else:
raise AssertionError()
np.testing.assert_array_equal(hist, ref_histogram)
def test_8bin_even(self):
self.do_test(8, 5 * 16)
def test_8bin_odd(self):
self.do_test(8, 123)
def test_16bin_even(self):
self.do_test(16, 5 * 16)
def test_16bin_odd(self):
self.do_test(16, 123)
def do_test_bounded(self, nbin, n, shift=2, minv=500, rspan=None, seed=None):
if seed is None:
for run in range(50):
self.do_test_bounded(nbin, n, shift, minv, rspan, seed=123 + run)
return
if rspan is None:
rmin, rmax = 0, nbin * 6
else:
rmin, rmax = rspan
rs = np.random.RandomState(seed)
tab = rs.randint(rmin, rmax, size=n).astype('uint16')
bc = np.bincount(tab, minlength=65536)
binsize = 1 << shift
ref_histogram = bc[minv : minv + binsize * nbin]
def pad_and_reshape(x, m, n):
xout = np.zeros(m * n, dtype=x.dtype)
xout[:x.size] = x
return xout.reshape(m, n)
ref_histogram = pad_and_reshape(ref_histogram, nbin, binsize)
ref_histogram = ref_histogram.sum(1)
tab_a = faiss.AlignedTableUint16()
faiss.copy_array_to_AlignedTable(tab, tab_a)
sp = faiss.swig_ptr
hist = np.zeros(nbin, 'int32')
if nbin == 8:
faiss.simd_histogram_8(
tab_a.get(), n, minv, shift, sp(hist)
)
elif nbin == 16:
faiss.simd_histogram_16(
tab_a.get(), n, minv, shift, sp(hist)
)
else:
raise AssertionError()
np.testing.assert_array_equal(hist, ref_histogram)
def test_8bin_even_bounded(self):
self.do_test_bounded(8, 22 * 16)
def test_8bin_odd_bounded(self):
self.do_test_bounded(8, 10000)
def test_16bin_even_bounded(self):
self.do_test_bounded(16, 22 * 16)
def test_16bin_odd_bounded(self):
self.do_test_bounded(16, 10000)
def test_16bin_bounded_bigrange(self):
self.do_test_bounded(16, 1000, shift=12, rspan=(10, 65500))
def test_8bin_bounded_bigrange(self):
self.do_test_bounded(8, 1000, shift=13, rspan=(10, 65500))
def test_16bin_bounded_bigrange_2(self):
self.do_test_bounded(16, 10, shift=12, rspan=(65000, 65500))
def test_16bin_bounded_shift0(self):
self.do_test_bounded(16, 10000, shift=0, rspan=(10, 65500))
def test_8bin_bounded_shift0(self):
self.do_test_bounded(8, 10000, shift=0, rspan=(10, 65500))
def test_16bin_bounded_ignore_out_range(self):
self.do_test_bounded(16, 10000, shift=5, rspan=(100, 20000), minv=300)
def test_8bin_bounded_ignore_out_range(self):
self.do_test_bounded(8, 10000, shift=5, rspan=(100, 20000), minv=300)