# faiss/tests/torch_test_neural_net.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch # usort: skip
from torch import nn # usort: skip
import unittest # usort: skip
import numpy as np # usort: skip
import faiss # usort: skip
from faiss.contrib import datasets # usort: skip
from faiss.contrib.inspect_tools import get_additive_quantizer_codebooks # usort: skip
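
# These tests check that the Faiss neural-net primitives (faiss.Embedding,
# faiss.Linear, faiss.QINCoStep, faiss.QINCo, faiss.IndexQINCo) reproduce the
# results of the reference PyTorch implementations.
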
class TestLayer(unittest.TestCase):
@torch.no_grad()
def test_Embedding(self):
""" verify that the Faiss Embedding works the same as in Pytorch """
torch.manual_seed(123)
emb = nn.Embedding(40, 50)
idx = torch.randint(40, (25, ))
ref_batch = emb(idx)
emb2 = faiss.Embedding(emb)
idx2 = faiss.Int32Tensor2D(idx[:, None].to(dtype=torch.int32))
new_batch = emb2(idx2)
new_batch = new_batch.numpy()
np.testing.assert_allclose(ref_batch.numpy(), new_batch, atol=2e-6)
@torch.no_grad()
def do_test_Linear(self, bias):
""" verify that the Faiss Linear works the same as in Pytorch """
torch.manual_seed(123)
linear = nn.Linear(50, 40, bias=bias)
x = torch.randn(25, 50)
ref_y = linear(x)
linear2 = faiss.Linear(linear)
x2 = faiss.Tensor2D(x)
y = linear2(x2)
np.testing.assert_allclose(ref_y.numpy(), y.numpy(), atol=2e-6)
def test_Linear(self):
self.do_test_Linear(True)
def test_Linear_nobias(self):
self.do_test_Linear(False)
######################################################
# QINCo PyTorch implementation copied from
# https://github.com/facebookresearch/Qinco/blob/main/model_qinco.py
#
# The implementation is copied here to avoid introducing an additional
# dependency.
######################################################
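
# The distance helpers below expand the squared Euclidean distance as
# ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>, so all pairwise distances come
# out of a single matrix product (equivalent to torch.cdist(a, b) ** 2 up to
# floating-point rounding).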
def pairwise_distances(a, b):
anorms = (a**2).sum(-1)
bnorms = (b**2).sum(-1)
return anorms[:, None] + bnorms - 2 * a @ b.T
def compute_batch_distances(a, b):
anorms = (a**2).sum(-1)
bnorms = (b**2).sum(-1)
return (
anorms.unsqueeze(-1) + bnorms.unsqueeze(1) - 2 * torch.bmm(a, b.transpose(2, 1))
)
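
# Given vectors x [bs, d] and per-vector candidates zqs [bs, K, d], return for
# each vector the index of its nearest candidate and the candidate itself.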
def assign_batch_multiple(x, zqs):
bs, d = x.shape
bs, K, d = zqs.shape
    L2distances = compute_batch_distances(x.unsqueeze(1), zqs).squeeze(1)  # [bs, K]
    idx = torch.argmin(L2distances, dim=1).unsqueeze(1)  # [bs, 1]
quantized = torch.gather(zqs, dim=1, index=idx.unsqueeze(-1).repeat(1, 1, d))
return idx.squeeze(1), quantized.squeeze(1)
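
# Nearest-centroid assignment of x [nq, d] to the rows of c [nb, d]. Small
# problems use the full distance table; larger ones are tiled in blocks of bs
# rows and columns to bound peak memory.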
def assign_to_codebook(x, c, bs=16384):
nq, d = x.shape
nb, d2 = c.shape
assert d == d2
if nq * nb < bs * bs:
# small enough to represent the whole distance table
dis = pairwise_distances(x, c)
return dis.argmin(1)
# otherwise tile computation to avoid OOM
res = torch.empty((nq,), dtype=torch.int64, device=x.device)
cnorms = (c**2).sum(1)
for i in range(0, nq, bs):
xnorms = (x[i : i + bs] ** 2).sum(1, keepdim=True)
for j in range(0, nb, bs):
dis = xnorms + cnorms[j : j + bs] - 2 * x[i : i + bs] @ c[j : j + bs].T
dmini, imini = dis.min(1)
if j == 0:
dmin = dmini
imin = imini
else:
(mask,) = torch.where(dmini < dmin)
dmin[mask] = dmini[mask]
imin[mask] = imini[mask] + j
res[i : i + bs] = imin
return res
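
# A QINCoStep decodes a code by taking its codebook entry and refining it with
# an MLP conditioned (by concatenation) on the partial reconstruction xhat,
# followed by L residual MLP blocks. Encoding runs the same network on all K
# codebook entries and keeps, per vector, the candidate closest to the target x.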
class QINCoStep(nn.Module):
"""
One quantization step for QINCo.
Contains the codebook, concatenation block, and residual blocks
"""
def __init__(self, d, K, L, h):
nn.Module.__init__(self)
self.d, self.K, self.L, self.h = d, K, L, h
self.codebook = nn.Embedding(K, d)
self.MLPconcat = nn.Linear(2 * d, d)
self.residual_blocks = []
for l in range(L):
residual_block = nn.Sequential(
nn.Linear(d, h, bias=False), nn.ReLU(), nn.Linear(h, d, bias=False)
)
self.add_module(f"residual_block{l}", residual_block)
self.residual_blocks.append(residual_block)
def decode(self, xhat, codes):
zqs = self.codebook(codes)
cc = torch.concatenate((zqs, xhat), 1)
zqs = zqs + self.MLPconcat(cc)
for residual_block in self.residual_blocks:
zqs = zqs + residual_block(zqs)
return zqs
def encode(self, xhat, x):
# we are trying out the whole codebook
zqs = self.codebook.weight
K, d = zqs.shape
bs, d = xhat.shape
# repeat so that they are of size bs * K
zqs_r = zqs.repeat(bs, 1, 1).reshape(bs * K, d)
xhat_r = xhat.reshape(bs, 1, d).repeat(1, K, 1).reshape(bs * K, d)
# pass on batch of size bs * K
cc = torch.concatenate((zqs_r, xhat_r), 1)
zqs_r = zqs_r + self.MLPconcat(cc)
for residual_block in self.residual_blocks:
zqs_r = zqs_r + residual_block(zqs_r)
# possible next steps
zqs_r = zqs_r.reshape(bs, K, d) + xhat.reshape(bs, 1, d)
codes, xhat_next = assign_batch_multiple(x, zqs_r)
return codes, xhat_next - xhat
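
# QINCo chains M quantization steps: codebook0 gives the first approximation
# and each QINCoStep adds a refinement to the running reconstruction xhat.
# decode() sums the step outputs, encode() greedily selects one code per step.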
class QINCo(nn.Module):
"""
QINCo quantizer, built from a chain of residual quantization steps
"""
def __init__(self, d, K, L, M, h):
nn.Module.__init__(self)
self.d, self.K, self.L, self.M, self.h = d, K, L, M, h
self.codebook0 = nn.Embedding(K, d)
self.steps = []
for m in range(1, M):
step = QINCoStep(d, K, L, h)
self.add_module(f"step{m}", step)
self.steps.append(step)
def decode(self, codes):
xhat = self.codebook0(codes[:, 0])
for i, step in enumerate(self.steps):
xhat = xhat + step.decode(xhat, codes[:, i + 1])
return xhat
def encode(self, x, code0=None):
"""
Encode a batch of vectors x to codes of length M.
        If this function is called from IVF-QINCo, the codes are one index longer,
        because the first index is the IVF index and codebook0 is the IVF codebook.
"""
M = len(self.steps) + 1
bs, d = x.shape
codes = torch.zeros(bs, M, dtype=int, device=x.device)
if code0 is None:
# at IVF training time, the code0 is fixed (and precomputed)
code0 = assign_to_codebook(x, self.codebook0.weight)
codes[:, 0] = code0
xhat = self.codebook0.weight[code0]
for i, step in enumerate(self.steps):
codes[:, i + 1], toadd = step.encode(xhat, x)
xhat = xhat + toadd
return codes, xhat
######################################################
# QINCo tests
######################################################
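
# copy_QINCoStep copies the PyTorch parameters into the Faiss object field by
# field; the tests below also exercise the direct conversion constructors
# faiss.QINCoStep(step) and faiss.QINCo(qinco).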
def copy_QINCoStep(step):
step2 = faiss.QINCoStep(step.d, step.K, step.L, step.h)
step2.codebook.from_torch(step.codebook)
step2.MLPconcat.from_torch(step.MLPconcat)
for l in range(step.L):
src = step.residual_blocks[l]
dest = step2.get_residual_block(l)
dest.linear1.from_torch(src[0])
dest.linear2.from_torch(src[2])
return step2
class TestQINCoStep(unittest.TestCase):
@torch.no_grad()
def test_decode(self):
torch.manual_seed(123)
step = QINCoStep(d=16, K=20, L=2, h=8)
codes = torch.randint(0, 20, (10, ))
xhat = torch.randn(10, 16)
ref_decode = step.decode(xhat, codes)
# step2 = copy_QINCoStep(step)
step2 = faiss.QINCoStep(step)
codes2 = faiss.Int32Tensor2D(codes[:, None].to(dtype=torch.int32))
np.testing.assert_array_equal(
step.codebook(codes).numpy(),
step2.codebook(codes2).numpy()
)
xhat2 = faiss.Tensor2D(xhat)
# xhat2 = faiss.Tensor2D(len(codes), step2.d)
new_decode = step2.decode(xhat2, codes2)
np.testing.assert_allclose(
ref_decode.numpy(),
new_decode.numpy(),
atol=2e-6
)
@torch.no_grad()
def test_encode(self):
torch.manual_seed(123)
step = QINCoStep(d=16, K=20, L=2, h=8)
# create plausible x for testing starting from actual codes
codes = torch.randint(0, 20, (10, ))
xhat = torch.zeros(10, 16)
x = step.decode(xhat, codes)
del codes
ref_codes, toadd = step.encode(xhat, x)
step2 = copy_QINCoStep(step)
xhat2 = faiss.Tensor2D(xhat)
x2 = faiss.Tensor2D(x)
toadd2 = faiss.Tensor2D(10, 16)
new_codes = step2.encode(xhat2, x2, toadd2)
np.testing.assert_allclose(
ref_codes.numpy(),
new_codes.numpy().ravel(),
atol=2e-6
)
np.testing.assert_allclose(toadd.numpy(), toadd2.numpy(), atol=2e-6)
class TestQINCo(unittest.TestCase):
@torch.no_grad()
def test_decode(self):
torch.manual_seed(123)
qinco = QINCo(d=16, K=20, L=2, M=3, h=8)
codes = torch.randint(0, 20, (10, 3))
x_ref = qinco.decode(codes)
qinco2 = faiss.QINCo(qinco)
codes2 = faiss.Int32Tensor2D(codes.to(dtype=torch.int32))
x_new = qinco2.decode(codes2)
np.testing.assert_allclose(x_ref.numpy(), x_new.numpy(), atol=2e-6)
@torch.no_grad()
def test_encode(self):
torch.manual_seed(123)
qinco = QINCo(d=16, K=20, L=2, M=3, h=8)
codes = torch.randint(0, 20, (10, 3))
x = qinco.decode(codes)
del codes
ref_codes, _ = qinco.encode(x)
qinco2 = faiss.QINCo(qinco)
x2 = faiss.Tensor2D(x)
new_codes = qinco2.encode(x2)
np.testing.assert_allclose(ref_codes.numpy(), new_codes.numpy(), atol=2e-6)
######################################################
# Test index
######################################################
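
# With L = 0 residual blocks and MLPconcat weights at zero, a QINCo step adds a
# plain codebook entry, so IndexQINCo should reproduce a ResidualQuantizer that
# shares the same codebooks, up to float rounding.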
class TestIndexQINCo(unittest.TestCase):
def test_search(self):
"""
        QINCo cannot be trained with Faiss alone, so we train a ResidualQuantizer
        instead and plug its codebooks into a QINCo with L = 0 residual blocks.
"""
ds = datasets.SyntheticDataset(32, 1000, 100, 0)
# prepare reference quantizer
M = 5
index_ref = faiss.index_factory(ds.d, "RQ5x4")
rq = index_ref.rq
# rq = faiss.ResidualQuantizer(ds.d, M, 4)
rq.train_type = faiss.ResidualQuantizer.Train_default
rq.max_beam_size = 1 # beam search not implemented for QINCo (yet)
index_ref.train(ds.get_train())
codebooks = get_additive_quantizer_codebooks(rq)
# convert to QINCo index
qinco_index = faiss.IndexQINCo(ds.d, M, 4, 0, ds.d)
qinco = qinco_index.qinco
qinco.codebook0.from_array(codebooks[0])
for i in range(1, qinco.M):
step = qinco.get_step(i - 1)
step.codebook.from_array(codebooks[i])
        # MLPconcat is left at zero: its output is added to the codebook entry,
        # so each step reduces to plain additive (RQ-style) quantization
qinco_index.is_trained = True
# verify that the encoding gives the same results
ref_codes = rq.compute_codes(ds.get_database())
ref_decoded = rq.decode(ref_codes)
new_decoded = qinco_index.sa_decode(ref_codes)
np.testing.assert_allclose(ref_decoded, new_decoded, atol=2e-6)
new_codes = qinco_index.sa_encode(ds.get_database())
np.testing.assert_array_equal(ref_codes, new_codes)
# verify that search gives the same results
Dref, Iref = index_ref.search(ds.get_queries(), 5)
Dnew, Inew = qinco_index.search(ds.get_queries(), 5)
np.testing.assert_array_equal(Iref, Inew)
np.testing.assert_allclose(Dref, Dnew, atol=2e-6)
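
if __name__ == "__main__":
    # optional entry point for running this file directly; the suite is
    # normally driven by an external test runner, so this is a convenience
    # assumption rather than part of the tested behaviour
    unittest.main()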