428 lines
14 KiB
Python
428 lines
14 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import torch # usort: skip
|
|
import unittest # usort: skip
|
|
import numpy as np # usort: skip
|
|
|
|
import faiss # usort: skip
|
|
import faiss.contrib.torch_utils # usort: skip
|
|
from faiss.contrib import datasets
|
|
from faiss.contrib.torch import clustering, quantization
|
|
|
|
|
|
|
|
|
|
class TestTorchUtilsCPU(unittest.TestCase):
    """Check that faiss.contrib.torch_utils lets CPU indexes accept torch
    tensors interchangeably with numpy arrays, producing identical results."""

    # tests add, search
    def test_lookup(self):
        d = 128
        index = faiss.IndexFlatL2(d)

        # Add to CPU index with torch CPU
        xb_torch = torch.rand(10000, d)
        index.add(xb_torch)

        # Test reconstruct
        y_torch = index.reconstruct(10)
        self.assertTrue(torch.equal(y_torch, xb_torch[10]))

        # Add to CPU index with numpy CPU
        xb_np = torch.rand(500, d).numpy()
        index.add(xb_np)
        self.assertEqual(index.ntotal, 10500)

        # reconstruct into a caller-provided numpy buffer
        y_np = np.zeros(d, dtype=np.float32)
        index.reconstruct(10100, y_np)
        self.assertTrue(np.array_equal(y_np, xb_np[100]))

        # Search with np cpu
        xq_torch = torch.rand(10, d, dtype=torch.float32)
        d_np, I_np = index.search(xq_torch.numpy(), 5)

        # Search with torch cpu
        d_torch, I_torch = index.search(xq_torch, 5)

        # The two should be equivalent
        self.assertTrue(np.array_equal(d_np, d_torch.numpy()))
        self.assertTrue(np.array_equal(I_np, I_torch.numpy()))

        # Search with np cpu using pre-allocated arrays
        d_np_input = np.zeros((10, 5), dtype=np.float32)
        I_np_input = np.zeros((10, 5), dtype=np.int64)
        index.search(xq_torch.numpy(), 5, d_np_input, I_np_input)

        self.assertTrue(np.array_equal(d_np, d_np_input))
        self.assertTrue(np.array_equal(I_np, I_np_input))

        # Search with torch cpu using pre-allocated arrays
        d_torch_input = torch.zeros(10, 5, dtype=torch.float32)
        I_torch_input = torch.zeros(10, 5, dtype=torch.int64)
        index.search(xq_torch, 5, d_torch_input, I_torch_input)

        self.assertTrue(np.array_equal(d_torch_input.numpy(), d_np))
        self.assertTrue(np.array_equal(I_torch_input.numpy(), I_np))

    # tests train, add_with_ids
    def test_train_add_with_ids(self):
        d = 32
        nlist = 5

        quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)
        xb = torch.rand(1000, d, dtype=torch.float32)
        index.train(xb)

        # Test add_with_ids with torch cpu
        ids = torch.arange(1000, 1000 + xb.shape[0], dtype=torch.int64)
        index.add_with_ids(xb, ids)
        # nearest neighbor of each database vector is itself, so its
        # returned id must be the one we assigned
        _, I = index.search(xb[10:20], 1)
        self.assertTrue(torch.equal(I.view(10), ids[10:20]))

        # Test add_with_ids with numpy
        index.reset()
        index.train(xb.numpy())
        index.add_with_ids(xb.numpy(), ids.numpy())
        _, I = index.search(xb.numpy()[10:20], 1)
        self.assertTrue(np.array_equal(I.reshape(10), ids.numpy()[10:20]))

    # tests reconstruct, reconstruct_n
    def test_reconstruct(self):
        d = 32
        index = faiss.IndexFlatL2(d)

        xb = torch.rand(100, d, dtype=torch.float32)
        index.add(xb)

        # Test reconstruct with torch cpu (native return)
        y = index.reconstruct(7)
        self.assertTrue(torch.equal(xb[7], y))

        # Test reconstruct with numpy output provided
        y = np.empty(d, dtype=np.float32)
        index.reconstruct(11, y)
        self.assertTrue(np.array_equal(xb.numpy()[11], y))

        # Test reconstruct with torch cpu output provided
        y = torch.empty(d, dtype=torch.float32)
        index.reconstruct(12, y)
        self.assertTrue(torch.equal(xb[12], y))

        # Test reconstruct_n with torch cpu (native return)
        y = index.reconstruct_n(10, 10)
        self.assertTrue(torch.equal(xb[10:20], y))

        # Test reconstruct_n with numpy output provided
        y = np.empty((10, d), dtype=np.float32)
        index.reconstruct_n(20, 10, y)
        self.assertTrue(np.array_equal(xb.cpu().numpy()[20:30], y))

        # Test reconstruct_n with torch cpu output provided
        y = torch.empty(10, d, dtype=torch.float32)
        index.reconstruct_n(40, 10, y)
        self.assertTrue(torch.equal(xb[40:50].cpu(), y))

    # tests assign
    def test_assign(self):
        d = 32
        index = faiss.IndexFlatL2(d)
        xb = torch.rand(1000, d, dtype=torch.float32)
        index.add(xb)

        # reference index built from the same data via numpy
        index_ref = faiss.IndexFlatL2(d)
        index_ref.add(xb.numpy())

        # Test assign with native cpu output
        xq = torch.rand(10, d, dtype=torch.float32)
        labels = index.assign(xq, 5)
        labels_ref = index_ref.assign(xq.cpu(), 5)

        self.assertTrue(torch.equal(labels, labels_ref))

        # Test assign with np input
        labels = index.assign(xq.numpy(), 5)
        labels_ref = index_ref.assign(xq.numpy(), 5)
        self.assertTrue(np.array_equal(labels, labels_ref))

        # Test assign with numpy output provided
        labels = np.empty((xq.shape[0], 5), dtype='int64')
        index.assign(xq.numpy(), 5, labels)
        self.assertTrue(np.array_equal(labels, labels_ref))

        # Test assign with torch cpu output provided
        labels = torch.empty(xq.shape[0], 5, dtype=torch.int64)
        index.assign(xq, 5, labels)
        labels_ref = index_ref.assign(xq, 5)
        self.assertTrue(torch.equal(labels, labels_ref))

    # tests remove_ids
    def test_remove_ids(self):
        # only implemented for cpu index + numpy at the moment
        d = 32
        quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFFlat(quantizer, d, 5)
        # direct map is needed so reconstruct() works on an IVF index
        index.make_direct_map()
        index.set_direct_map_type(faiss.DirectMap.Hashtable)

        xb = torch.rand(1000, d, dtype=torch.float32)
        ids = torch.arange(1000, 1000 + xb.shape[0], dtype=torch.int64)
        index.train(xb)
        index.add_with_ids(xb, ids)

        ids_remove = np.array([1010], dtype=np.int64)
        index.remove_ids(ids_remove)

        # We should find this
        y = index.reconstruct(1011)
        self.assertTrue(np.array_equal(xb[11].numpy(), y))

        # We should not find this
        with self.assertRaises(RuntimeError):
            y = index.reconstruct(1010)

        # Torch not yet supported
        ids_remove = torch.tensor([1012], dtype=torch.int64)
        with self.assertRaises(AssertionError):
            index.remove_ids(ids_remove)

    # tests update_vectors
    def test_update_vectors(self):
        d = 32
        quantizer_np = faiss.IndexFlatL2(d)
        index_np = faiss.IndexIVFFlat(quantizer_np, d, 5)
        index_np.make_direct_map()
        index_np.set_direct_map_type(faiss.DirectMap.Hashtable)

        quantizer_torch = faiss.IndexFlatL2(d)
        index_torch = faiss.IndexIVFFlat(quantizer_torch, d, 5)
        index_torch.make_direct_map()
        index_torch.set_direct_map_type(faiss.DirectMap.Hashtable)

        xb = torch.rand(1000, d, dtype=torch.float32)
        ids = torch.arange(1000, 1000 + xb.shape[0], dtype=torch.int64)

        index_np.train(xb.numpy())
        index_np.add_with_ids(xb.numpy(), ids.numpy())

        index_torch.train(xb)
        index_torch.add_with_ids(xb, ids)

        # update the same 10 vectors through both code paths
        xb_up = torch.rand(10, d, dtype=torch.float32)
        ids_up = ids[0:10]

        index_np.update_vectors(ids_up.numpy(), xb_up.numpy())
        index_torch.update_vectors(ids_up, xb_up)

        xq = torch.rand(10, d, dtype=torch.float32)

        D_np, I_np = index_np.search(xq.numpy(), 5)
        D_torch, I_torch = index_torch.search(xq, 5)

        self.assertTrue(np.array_equal(D_np, D_torch.numpy()))
        self.assertTrue(np.array_equal(I_np, I_torch.numpy()))

    # tests range_search
    def test_range_search(self):
        torch.manual_seed(10)
        d = 32
        index = faiss.IndexFlatL2(d)
        xb = torch.rand(100, d, dtype=torch.float32)
        index.add(xb)

        # torch cpu as ground truth
        thresh = 2.9
        xq = torch.rand(10, d, dtype=torch.float32)
        lims, D, I = index.range_search(xq, thresh)

        # compare against np
        lims_np, D_np, I_np = index.range_search(xq.numpy(), thresh)

        self.assertTrue(np.array_equal(lims.numpy(), lims_np))
        self.assertTrue(np.array_equal(D.numpy(), D_np))
        self.assertTrue(np.array_equal(I.numpy(), I_np))

    # tests search_and_reconstruct
    def test_search_and_reconstruct(self):
        d = 32
        nlist = 10
        M = 4
        k = 5
        quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFPQ(quantizer, d, nlist, M, 4)

        xb = torch.rand(1000, d, dtype=torch.float32)
        index.train(xb)

        # different set
        xb = torch.rand(500, d, dtype=torch.float32)
        index.add(xb)

        # torch cpu as ground truth
        xq = torch.rand(10, d, dtype=torch.float32)
        D, I, R = index.search_and_reconstruct(xq, k)

        # compare against numpy
        D_np, I_np, R_np = index.search_and_reconstruct(xq.numpy(), k)

        self.assertTrue(np.array_equal(D.numpy(), D_np))
        self.assertTrue(np.array_equal(I.numpy(), I_np))
        self.assertTrue(np.array_equal(R.numpy(), R_np))

        # numpy input values
        D_input = np.zeros((xq.shape[0], k), dtype=np.float32)
        I_input = np.zeros((xq.shape[0], k), dtype=np.int64)
        R_input = np.zeros((xq.shape[0], k, d), dtype=np.float32)

        index.search_and_reconstruct(xq.numpy(), k, D_input, I_input, R_input)

        self.assertTrue(np.array_equal(D.numpy(), D_input))
        self.assertTrue(np.array_equal(I.numpy(), I_input))
        self.assertTrue(np.array_equal(R.numpy(), R_input))

        # torch input values
        D_input = torch.zeros(xq.shape[0], k, dtype=torch.float32)
        I_input = torch.zeros(xq.shape[0], k, dtype=torch.int64)
        R_input = torch.zeros(xq.shape[0], k, d, dtype=torch.float32)

        index.search_and_reconstruct(xq, k, D_input, I_input, R_input)

        self.assertTrue(torch.equal(D, D_input))
        self.assertTrue(torch.equal(I, I_input))
        self.assertTrue(torch.equal(R, R_input))

    def test_search_preassigned(self):
        ds = datasets.SyntheticDataset(32, 1000, 100, 10)
        index = faiss.index_factory(32, "IVF20,PQ4np")
        index.train(ds.get_train())
        index.add(ds.get_database())
        index.nprobe = 4
        Dref, Iref = index.search(ds.get_queries(), 10)
        quantizer = faiss.clone_index(index.quantizer)

        # mutilate the index' quantizer so only the pre-assigned lists
        # can produce the reference result
        index.quantizer.reset()
        index.quantizer.add(np.zeros((20, 32), dtype='float32'))

        # test numpy codepath
        Dq, Iq = quantizer.search(ds.get_queries(), 4)
        Dref2, Iref2 = index.search_preassigned(ds.get_queries(), 10, Iq, Dq)
        np.testing.assert_array_equal(Iref, Iref2)
        np.testing.assert_array_equal(Dref, Dref2)

        # test torch codepath
        xq = torch.from_numpy(ds.get_queries())
        Dq, Iq = quantizer.search(xq, 4)
        Dref2, Iref2 = index.search_preassigned(xq, 10, Iq, Dq)
        np.testing.assert_array_equal(Iref, Iref2.numpy())
        np.testing.assert_array_equal(Dref, Dref2.numpy())

    # tests sa_encode, sa_decode
    def test_sa_encode_decode(self):
        d = 16
        index = faiss.IndexScalarQuantizer(d, faiss.ScalarQuantizer.QT_8bit)

        xb = torch.rand(1000, d, dtype=torch.float32)
        index.train(xb)

        # torch cpu as ground truth
        nq = 10
        xq = torch.rand(nq, d, dtype=torch.float32)
        encoded_torch = index.sa_encode(xq)

        # numpy cpu
        encoded_np = index.sa_encode(xq.numpy())

        self.assertTrue(np.array_equal(encoded_torch.numpy(), encoded_np))

        decoded_torch = index.sa_decode(encoded_torch)
        decoded_np = index.sa_decode(encoded_np)

        self.assertTrue(torch.equal(decoded_torch, torch.from_numpy(decoded_np)))

        # torch cpu as output parameter
        encoded_torch_param = torch.zeros(nq, d, dtype=torch.uint8)
        index.sa_encode(xq, encoded_torch_param)

        # FIX: previously compared encoded_torch with itself, which is
        # always true; compare against the pre-allocated output instead
        self.assertTrue(torch.equal(encoded_torch, encoded_torch_param))

        decoded_torch_param = torch.zeros(nq, d, dtype=torch.float32)
        index.sa_decode(encoded_torch, decoded_torch_param)

        self.assertTrue(torch.equal(decoded_torch, decoded_torch_param))

        # np as output parameter
        encoded_np_param = np.zeros((nq, d), dtype=np.uint8)
        index.sa_encode(xq.numpy(), encoded_np_param)

        self.assertTrue(np.array_equal(encoded_torch.numpy(), encoded_np_param))

        decoded_np_param = np.zeros((nq, d), dtype=np.float32)
        index.sa_decode(encoded_np_param, decoded_np_param)

        self.assertTrue(np.array_equal(decoded_np, decoded_np_param))

    def test_non_contiguous(self):
        d = 128
        index = faiss.IndexFlatL2(d)

        # transposing makes the tensor non-contiguous in memory
        xb = torch.rand(d, 100).transpose(0, 1)

        with self.assertRaises(AssertionError):
            index.add(xb)

        # disabled since we now accept non-contiguous arrays
        # with self.assertRaises(ValueError):
        #     index.add(xb.numpy())
|
|
|
|
|
|
class TestClustering(unittest.TestCase):

    def test_python_kmeans(self):
        """Compare the pure-python (torch) kmeans against faiss.Kmeans."""
        ds = datasets.SyntheticDataset(32, 10000, 0, 0)
        base = ds.get_train()

        # duplicate one point over half the set: a pathological
        # distribution that stress-tests the cluster split code
        train_x = base[:10000].copy()
        train_x[:5000] = base[0]

        # reference run with the C++ implementation
        km_ref = faiss.Kmeans(ds.d, 100, niter=10)
        km_ref.train(train_x)
        ref_err = faiss.knn(train_x, km_ref.centroids, 1)[0].sum()

        # python/torch implementation on the same data
        torch_data = clustering.DatasetAssign(torch.from_numpy(train_x))
        torch_centroids = clustering.kmeans(100, torch_data, 10).numpy()
        torch_err = faiss.knn(train_x, torch_centroids, 1)[0].sum()

        # observed values: 33498.332 vs 33380.477 — allow 10% slack
        self.assertLess(torch_err, ref_err * 1.1)
|
|
|
|
|
|
class TestQuantization(unittest.TestCase):
    def test_python_product_quantization(self):
        """Test the python implementation of product quantization.

        Trains a faiss C++ ProductQuantizer and the pure-pytorch
        quantization.ProductQuantizer on the same data and checks that
        their total reconstruction errors are close.
        """
        d = 64
        n = 10000
        nbits = 8
        # number of sub-quantizers; previously the C++ quantizer was built
        # with a separate constant `cs` (also 4) — use a single M so both
        # implementations are guaranteed to be configured identically
        M = 4
        x = np.random.random(size=(n, d)).astype('float32')

        # reference: faiss C++ implementation
        pq = faiss.ProductQuantizer(d, M, nbits)
        pq.train(x)
        codes = pq.compute_codes(x)
        x2 = pq.decode(codes)
        diff = ((x - x2)**2).sum()

        # vs pure pytorch impl
        xt = torch.from_numpy(x)
        my_pq = quantization.ProductQuantizer(d, M, nbits)
        my_pq.train(xt)
        my_codes = my_pq.encode(xt)
        xt2 = my_pq.decode(my_codes)
        my_diff = ((xt - xt2)**2).sum()

        # total squared reconstruction errors should be comparable
        self.assertLess(abs(diff - my_diff), 100)
|