1
0
mirror of https://github.com/facebookresearch/faiss.git synced 2025-06-03 21:54:02 +08:00
faiss/tests/torch_test_contrib.py
Matthijs Douze 1806c6af27 Automatic type conversions for Python API ()
Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2274

All input matrices needed to be of the correct type and to be C-contiguous. This diff passes the main entry points of the api through `np.ascontiguousarray` so that the function parameters are transparently converted to the suitable format if needed.

We did not have this before because users need to be made aware of the performance impact, but it seems that maybe usability is more useful.

This diff is an alternative to

D35007365
https://github.com/facebookresearch/faiss/pull/2250

Reviewed By: beauby

Differential Revision: D35009612

fbshipit-source-id: fa0d5cfdfbff6b0916d47bd33c620e3ca9d5dd40
2022-03-30 05:42:08 -07:00

346 lines
12 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import faiss
import torch
import unittest
import numpy as np
import faiss.contrib.torch_utils
class TestTorchUtilsCPU(unittest.TestCase):
# tests add, search
def test_lookup(self):
d = 128
index = faiss.IndexFlatL2(d)
# Add to CPU index with torch CPU
xb_torch = torch.rand(10000, d)
index.add(xb_torch)
# Test reconstruct
y_torch = index.reconstruct(10)
self.assertTrue(torch.equal(y_torch, xb_torch[10]))
# Add to CPU index with numpy CPU
xb_np = torch.rand(500, d).numpy()
index.add(xb_np)
self.assertEqual(index.ntotal, 10500)
y_np = np.zeros(d, dtype=np.float32)
index.reconstruct(10100, y_np)
self.assertTrue(np.array_equal(y_np, xb_np[100]))
# Search with np cpu
xq_torch = torch.rand(10, d, dtype=torch.float32)
d_np, I_np = index.search(xq_torch.numpy(), 5)
# Search with torch cpu
d_torch, I_torch = index.search(xq_torch, 5)
# The two should be equivalent
self.assertTrue(np.array_equal(d_np, d_torch.numpy()))
self.assertTrue(np.array_equal(I_np, I_torch.numpy()))
# Search with np cpu using pre-allocated arrays
d_np_input = np.zeros((10, 5), dtype=np.float32)
I_np_input = np.zeros((10, 5), dtype=np.int64)
index.search(xq_torch.numpy(), 5, d_np_input, I_np_input)
self.assertTrue(np.array_equal(d_np, d_np_input))
self.assertTrue(np.array_equal(I_np, I_np_input))
# Search with torch cpu using pre-allocated arrays
d_torch_input = torch.zeros(10, 5, dtype=torch.float32)
I_torch_input = torch.zeros(10, 5, dtype=torch.int64)
index.search(xq_torch, 5, d_torch_input, I_torch_input)
self.assertTrue(np.array_equal(d_torch_input.numpy(), d_np))
self.assertTrue(np.array_equal(I_torch_input.numpy(), I_np))
# tests train, add_with_ids
def test_train_add_with_ids(self):
d = 32
nlist = 5
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)
xb = torch.rand(1000, d, dtype=torch.float32)
index.train(xb)
# Test add_with_ids with torch cpu
ids = torch.arange(1000, 1000 + xb.shape[0], dtype=torch.int64)
index.add_with_ids(xb, ids)
_, I = index.search(xb[10:20], 1)
self.assertTrue(torch.equal(I.view(10), ids[10:20]))
# Test add_with_ids with numpy
index.reset()
index.train(xb.numpy())
index.add_with_ids(xb.numpy(), ids.numpy())
_, I = index.search(xb.numpy()[10:20], 1)
self.assertTrue(np.array_equal(I.reshape(10), ids.numpy()[10:20]))
# tests reconstruct, reconstruct_n
def test_reconstruct(self):
d = 32
index = faiss.IndexFlatL2(d)
xb = torch.rand(100, d, dtype=torch.float32)
index.add(xb)
# Test reconstruct with torch cpu (native return)
y = index.reconstruct(7)
self.assertTrue(torch.equal(xb[7], y))
# Test reconstruct with numpy output provided
y = np.empty(d, dtype=np.float32)
index.reconstruct(11, y)
self.assertTrue(np.array_equal(xb.numpy()[11], y))
# Test reconstruct with torch cpu output providesd
y = torch.empty(d, dtype=torch.float32)
index.reconstruct(12, y)
self.assertTrue(torch.equal(xb[12], y))
# Test reconstruct_n with torch cpu (native return)
y = index.reconstruct_n(10, 10)
self.assertTrue(torch.equal(xb[10:20], y))
# Test reconstruct with numpy output provided
y = np.empty((10, d), dtype=np.float32)
index.reconstruct_n(20, 10, y)
self.assertTrue(np.array_equal(xb.cpu().numpy()[20:30], y))
# Test reconstruct_n with torch cpu output provided
y = torch.empty(10, d, dtype=torch.float32)
index.reconstruct_n(40, 10, y)
self.assertTrue(torch.equal(xb[40:50].cpu(), y))
# tests assign
def test_assign(self):
d = 32
index = faiss.IndexFlatL2(d)
xb = torch.rand(1000, d, dtype=torch.float32)
index.add(xb)
index_ref = faiss.IndexFlatL2(d)
index_ref.add(xb.numpy())
# Test assign with native cpu output
xq = torch.rand(10, d, dtype=torch.float32)
labels = index.assign(xq, 5)
labels_ref = index_ref.assign(xq.cpu(), 5)
self.assertTrue(torch.equal(labels, labels_ref))
# Test assign with np input
labels = index.assign(xq.numpy(), 5)
labels_ref = index_ref.assign(xq.numpy(), 5)
self.assertTrue(np.array_equal(labels, labels_ref))
# Test assign with numpy output provided
labels = np.empty((xq.shape[0], 5), dtype='int64')
index.assign(xq.numpy(), 5, labels)
self.assertTrue(np.array_equal(labels, labels_ref))
# Test assign with torch cpu output provided
labels = torch.empty(xq.shape[0], 5, dtype=torch.int64)
index.assign(xq, 5, labels)
labels_ref = index_ref.assign(xq, 5)
self.assertTrue(torch.equal(labels, labels_ref))
# tests remove_ids
def test_remove_ids(self):
# only implemented for cpu index + numpy at the moment
d = 32
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, 5)
index.make_direct_map()
index.set_direct_map_type(faiss.DirectMap.Hashtable)
xb = torch.rand(1000, d, dtype=torch.float32)
ids = torch.arange(1000, 1000 + xb.shape[0], dtype=torch.int64)
index.train(xb)
index.add_with_ids(xb, ids)
ids_remove = np.array([1010], dtype=np.int64)
index.remove_ids(ids_remove)
# We should find this
y = index.reconstruct(1011)
self.assertTrue(np.array_equal(xb[11].numpy(), y))
# We should not find this
with self.assertRaises(RuntimeError):
y = index.reconstruct(1010)
# Torch not yet supported
ids_remove = torch.tensor([1012], dtype=torch.int64)
with self.assertRaises(AssertionError):
index.remove_ids(ids_remove)
# tests update_vectors
def test_update_vectors(self):
d = 32
quantizer_np = faiss.IndexFlatL2(d)
index_np = faiss.IndexIVFFlat(quantizer_np, d, 5)
index_np.make_direct_map()
index_np.set_direct_map_type(faiss.DirectMap.Hashtable)
quantizer_torch = faiss.IndexFlatL2(d)
index_torch = faiss.IndexIVFFlat(quantizer_torch, d, 5)
index_torch.make_direct_map()
index_torch.set_direct_map_type(faiss.DirectMap.Hashtable)
xb = torch.rand(1000, d, dtype=torch.float32)
ids = torch.arange(1000, 1000 + xb.shape[0], dtype=torch.int64)
index_np.train(xb.numpy())
index_np.add_with_ids(xb.numpy(), ids.numpy())
index_torch.train(xb)
index_torch.add_with_ids(xb, ids)
xb_up = torch.rand(10, d, dtype=torch.float32)
ids_up = ids[0:10]
index_np.update_vectors(ids_up.numpy(), xb_up.numpy())
index_torch.update_vectors(ids_up, xb_up)
xq = torch.rand(10, d, dtype=torch.float32)
D_np, I_np = index_np.search(xq.numpy(), 5)
D_torch, I_torch = index_torch.search(xq, 5)
self.assertTrue(np.array_equal(D_np, D_torch.numpy()))
self.assertTrue(np.array_equal(I_np, I_torch.numpy()))
# tests range_search
def test_range_search(self):
torch.manual_seed(10)
d = 32
index = faiss.IndexFlatL2(d)
xb = torch.rand(100, d, dtype=torch.float32)
index.add(xb)
# torch cpu as ground truth
thresh = 2.9
xq = torch.rand(10, d, dtype=torch.float32)
lims, D, I = index.range_search(xq, thresh)
# compare against np
lims_np, D_np, I_np = index.range_search(xq.numpy(), thresh)
self.assertTrue(np.array_equal(lims.numpy(), lims_np))
self.assertTrue(np.array_equal(D.numpy(), D_np))
self.assertTrue(np.array_equal(I.numpy(), I_np))
# tests search_and_reconstruct
def test_search_and_reconstruct(self):
d = 32
nlist = 10
M = 4
k = 5
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFPQ(quantizer, d, nlist, M, 4)
xb = torch.rand(1000, d, dtype=torch.float32)
index.train(xb)
# different set
xb = torch.rand(500, d, dtype=torch.float32)
index.add(xb)
# torch cpu as ground truth
xq = torch.rand(10, d, dtype=torch.float32)
D, I, R = index.search_and_reconstruct(xq, k)
# compare against numpy
D_np, I_np, R_np = index.search_and_reconstruct(xq.numpy(), k)
self.assertTrue(np.array_equal(D.numpy(), D_np))
self.assertTrue(np.array_equal(I.numpy(), I_np))
self.assertTrue(np.array_equal(R.numpy(), R_np))
# numpy input values
D_input = np.zeros((xq.shape[0], k), dtype=np.float32)
I_input = np.zeros((xq.shape[0], k), dtype=np.int64)
R_input = np.zeros((xq.shape[0], k, d), dtype=np.float32)
index.search_and_reconstruct(xq.numpy(), k, D_input, I_input, R_input)
self.assertTrue(np.array_equal(D.numpy(), D_input))
self.assertTrue(np.array_equal(I.numpy(), I_input))
self.assertTrue(np.array_equal(R.numpy(), R_input))
# torch input values
D_input = torch.zeros(xq.shape[0], k, dtype=torch.float32)
I_input = torch.zeros(xq.shape[0], k, dtype=torch.int64)
R_input = torch.zeros(xq.shape[0], k, d, dtype=torch.float32)
index.search_and_reconstruct(xq, k, D_input, I_input, R_input)
self.assertTrue(torch.equal(D, D_input))
self.assertTrue(torch.equal(I, I_input))
self.assertTrue(torch.equal(R, R_input))
# tests sa_encode, sa_decode
def test_sa_encode_decode(self):
d = 16
index = faiss.IndexScalarQuantizer(d, faiss.ScalarQuantizer.QT_8bit)
xb = torch.rand(1000, d, dtype=torch.float32)
index.train(xb)
# torch cpu as ground truth
nq = 10
xq = torch.rand(nq, d, dtype=torch.float32)
encoded_torch = index.sa_encode(xq)
# numpy cpu
encoded_np = index.sa_encode(xq.numpy())
self.assertTrue(np.array_equal(encoded_torch.numpy(), encoded_np))
decoded_torch = index.sa_decode(encoded_torch)
decoded_np = index.sa_decode(encoded_np)
self.assertTrue(torch.equal(decoded_torch, torch.from_numpy(decoded_np)))
# torch cpu as output parameter
encoded_torch_param = torch.zeros(nq, d, dtype=torch.uint8)
index.sa_encode(xq, encoded_torch_param)
self.assertTrue(torch.equal(encoded_torch, encoded_torch))
decoded_torch_param = torch.zeros(nq, d, dtype=torch.float32)
index.sa_decode(encoded_torch, decoded_torch_param)
self.assertTrue(torch.equal(decoded_torch, decoded_torch_param))
# np as output parameter
encoded_np_param = np.zeros((nq, d), dtype=np.uint8)
index.sa_encode(xq.numpy(), encoded_np_param)
self.assertTrue(np.array_equal(encoded_torch.numpy(), encoded_np_param))
decoded_np_param = np.zeros((nq, d), dtype=np.float32)
index.sa_decode(encoded_np_param, decoded_np_param)
self.assertTrue(np.array_equal(decoded_np, decoded_np_param))
def test_non_contiguous(self):
d = 128
index = faiss.IndexFlatL2(d)
xb = torch.rand(d, 100).transpose(0, 1)
with self.assertRaises(AssertionError):
index.add(xb)
# disabled since we now accept non-contiguous arrays
# with self.assertRaises(ValueError):
# index.add(xb.numpy())