# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch  # usort: skip
from torch import nn  # usort: skip
import unittest  # usort: skip
import numpy as np  # usort: skip

import faiss  # usort: skip

from faiss.contrib import datasets  # usort: skip
from faiss.contrib.inspect_tools import get_additive_quantizer_codebooks  # usort: skip


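# The layer tests below build a layer in PyTorch, mirror it with the
# corresponding faiss object (faiss.Embedding, faiss.Linear), and check that
# both produce the same output on the same batch.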
class TestLayer(unittest.TestCase):

    @torch.no_grad()
    def test_Embedding(self):
        """ verify that the Faiss Embedding works the same as in Pytorch """
        torch.manual_seed(123)

        emb = nn.Embedding(40, 50)
        idx = torch.randint(40, (25, ))
        ref_batch = emb(idx)

        emb2 = faiss.Embedding(emb)
        idx2 = faiss.Int32Tensor2D(idx[:, None].to(dtype=torch.int32))
        new_batch = emb2(idx2)

        new_batch = new_batch.numpy()
        np.testing.assert_allclose(ref_batch.numpy(), new_batch, atol=2e-6)

    @torch.no_grad()
    def do_test_Linear(self, bias):
        """ verify that the Faiss Linear works the same as in Pytorch """
        torch.manual_seed(123)
        linear = nn.Linear(50, 40, bias=bias)
        x = torch.randn(25, 50)
        ref_y = linear(x)

        linear2 = faiss.Linear(linear)
        x2 = faiss.Tensor2D(x)
        y = linear2(x2)
        np.testing.assert_allclose(ref_y.numpy(), y.numpy(), atol=2e-6)

    def test_Linear(self):
        self.do_test_Linear(True)

    def test_Linear_nobias(self):
        self.do_test_Linear(False)


######################################################
# QINCo Pytorch implementation copied from
# https://github.com/facebookresearch/Qinco/blob/main/model_qinco.py
#
# The implementation is copied here to avoid introducing an additional
# dependency.
######################################################


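# Both distance helpers below use the expansion
#   ||a_i - b_j||^2 = ||a_i||^2 + ||b_j||^2 - 2 <a_i, b_j>
# so the full distance table can be computed with one matrix multiplication
# plus broadcasting of the squared norms.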
def pairwise_distances(a, b):
    anorms = (a**2).sum(-1)
    bnorms = (b**2).sum(-1)
    return anorms[:, None] + bnorms - 2 * a @ b.T


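# Batched variant: a is [bs, n, d] and b is [bs, m, d]; the result is the
# [bs, n, m] tensor of squared distances, computed with torch.bmm.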
def compute_batch_distances(a, b):
    anorms = (a**2).sum(-1)
    bnorms = (b**2).sum(-1)
    return (
        anorms.unsqueeze(-1) + bnorms.unsqueeze(1) - 2 * torch.bmm(a, b.transpose(2, 1))
    )


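# For each row of x, pick the nearest of its K candidate vectors in zqs
# ([bs, K, d]); return the selected index and the selected vector.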
def assign_batch_multiple(x, zqs):
    bs, d = x.shape
    bs, K, d = zqs.shape

    L2distances = compute_batch_distances(x.unsqueeze(1), zqs).squeeze(1)  # [bs x ksq]
    idx = torch.argmin(L2distances, dim=1).unsqueeze(1)  # [bsx1]
    quantized = torch.gather(zqs, dim=1, index=idx.unsqueeze(-1).repeat(1, 1, d))
    return idx.squeeze(1), quantized.squeeze(1)


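# Nearest-centroid assignment of the rows of x to the codebook c. Small
# problems use the full distance table; larger ones are tiled into
# bs-by-bs blocks so the table never exceeds roughly bs^2 entries.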
def assign_to_codebook(x, c, bs=16384):
    nq, d = x.shape
    nb, d2 = c.shape
    assert d == d2
    if nq * nb < bs * bs:
        # small enough to represent the whole distance table
        dis = pairwise_distances(x, c)
        return dis.argmin(1)

    # otherwise tile computation to avoid OOM
    res = torch.empty((nq,), dtype=torch.int64, device=x.device)
    cnorms = (c**2).sum(1)
    for i in range(0, nq, bs):
        xnorms = (x[i : i + bs] ** 2).sum(1, keepdim=True)
        for j in range(0, nb, bs):
            dis = xnorms + cnorms[j : j + bs] - 2 * x[i : i + bs] @ c[j : j + bs].T
            dmini, imini = dis.min(1)
            if j == 0:
                dmin = dmini
                imin = imini
            else:
                (mask,) = torch.where(dmini < dmin)
                dmin[mask] = dmini[mask]
                imin[mask] = imini[mask] + j
        res[i : i + bs] = imin
    return res


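# One QINCo step refines a partial reconstruction xhat: decode() looks up the
# codebook entry for the given code, mixes it with xhat through MLPconcat and
# the residual blocks, and returns the refinement to add to xhat. encode()
# evaluates that refinement for all K codebook entries and keeps the one whose
# resulting reconstruction is closest to the target x.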
class QINCoStep(nn.Module):
    """
    One quantization step for QINCo.
    Contains the codebook, concatenation block, and residual blocks
    """

    def __init__(self, d, K, L, h):
        nn.Module.__init__(self)

        self.d, self.K, self.L, self.h = d, K, L, h

        self.codebook = nn.Embedding(K, d)
        self.MLPconcat = nn.Linear(2 * d, d)

        self.residual_blocks = []
        for l in range(L):
            residual_block = nn.Sequential(
                nn.Linear(d, h, bias=False), nn.ReLU(), nn.Linear(h, d, bias=False)
            )
            self.add_module(f"residual_block{l}", residual_block)
            self.residual_blocks.append(residual_block)

    def decode(self, xhat, codes):
        zqs = self.codebook(codes)
        cc = torch.concatenate((zqs, xhat), 1)
        zqs = zqs + self.MLPconcat(cc)

        for residual_block in self.residual_blocks:
            zqs = zqs + residual_block(zqs)

        return zqs

    def encode(self, xhat, x):
        # we are trying out the whole codebook
        zqs = self.codebook.weight
        K, d = zqs.shape
        bs, d = xhat.shape

        # repeat so that they are of size bs * K
        zqs_r = zqs.repeat(bs, 1, 1).reshape(bs * K, d)
        xhat_r = xhat.reshape(bs, 1, d).repeat(1, K, 1).reshape(bs * K, d)

        # pass on batch of size bs * K
        cc = torch.concatenate((zqs_r, xhat_r), 1)
        zqs_r = zqs_r + self.MLPconcat(cc)

        for residual_block in self.residual_blocks:
            zqs_r = zqs_r + residual_block(zqs_r)

        # possible next steps
        zqs_r = zqs_r.reshape(bs, K, d) + xhat.reshape(bs, 1, d)
        codes, xhat_next = assign_batch_multiple(x, zqs_r)

        return codes, xhat_next - xhat


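# The full quantizer chains codebook0 with M - 1 QINCoStep modules:
#   xhat_0 = codebook0[codes[:, 0]]
#   xhat_m = xhat_{m-1} + step_m.decode(xhat_{m-1}, codes[:, m])
# encode() picks codes[:, 0] by nearest-neighbor assignment to codebook0
# (or takes it from the caller in the IVF case), then greedily picks each
# remaining code with the corresponding step.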
class QINCo(nn.Module):
    """
    QINCo quantizer, built from a chain of residual quantization steps
    """

    def __init__(self, d, K, L, M, h):
        nn.Module.__init__(self)

        self.d, self.K, self.L, self.M, self.h = d, K, L, M, h

        self.codebook0 = nn.Embedding(K, d)

        self.steps = []
        for m in range(1, M):
            step = QINCoStep(d, K, L, h)
            self.add_module(f"step{m}", step)
            self.steps.append(step)

    def decode(self, codes):
        xhat = self.codebook0(codes[:, 0])
        for i, step in enumerate(self.steps):
            xhat = xhat + step.decode(xhat, codes[:, i + 1])
        return xhat

    def encode(self, x, code0=None):
        """
        Encode a batch of vectors x to codes of length M.
        If this function is called from IVF-QINCo, codes are 1 index longer,
        due to the first index being the IVF index, and codebook0 is the IVF codebook.
        """
        M = len(self.steps) + 1
        bs, d = x.shape
        codes = torch.zeros(bs, M, dtype=int, device=x.device)

        if code0 is None:
            # at IVF training time, the code0 is fixed (and precomputed)
            code0 = assign_to_codebook(x, self.codebook0.weight)

        codes[:, 0] = code0
        xhat = self.codebook0.weight[code0]

        for i, step in enumerate(self.steps):
            codes[:, i + 1], toadd = step.encode(xhat, x)
            xhat = xhat + toadd

        return codes, xhat


######################################################
# QINCo tests
######################################################

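# Manually mirror a PyTorch QINCoStep into a faiss.QINCoStep by copying the
# codebook, the concatenation MLP and each residual block's two linear layers
# (test_decode below uses the faiss.QINCoStep(step) constructor for the same
# purpose).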
def copy_QINCoStep(step):
    step2 = faiss.QINCoStep(step.d, step.K, step.L, step.h)
    step2.codebook.from_torch(step.codebook)
    step2.MLPconcat.from_torch(step.MLPconcat)

    for l in range(step.L):
        src = step.residual_blocks[l]
        dest = step2.get_residual_block(l)
        dest.linear1.from_torch(src[0])
        dest.linear2.from_torch(src[2])
    return step2


class TestQINCoStep(unittest.TestCase):
    @torch.no_grad()
    def test_decode(self):
        torch.manual_seed(123)
        step = QINCoStep(d=16, K=20, L=2, h=8)

        codes = torch.randint(0, 20, (10, ))
        xhat = torch.randn(10, 16)
        ref_decode = step.decode(xhat, codes)

        # step2 = copy_QINCoStep(step)
        step2 = faiss.QINCoStep(step)
        codes2 = faiss.Int32Tensor2D(codes[:, None].to(dtype=torch.int32))

        np.testing.assert_array_equal(
            step.codebook(codes).numpy(),
            step2.codebook(codes2).numpy()
        )

        xhat2 = faiss.Tensor2D(xhat)
        # xhat2 = faiss.Tensor2D(len(codes), step2.d)

        new_decode = step2.decode(xhat2, codes2)

        np.testing.assert_allclose(
            ref_decode.numpy(),
            new_decode.numpy(),
            atol=2e-6
        )

    @torch.no_grad()
    def test_encode(self):
        torch.manual_seed(123)
        step = QINCoStep(d=16, K=20, L=2, h=8)

        # create plausible x for testing starting from actual codes
        codes = torch.randint(0, 20, (10, ))
        xhat = torch.zeros(10, 16)
        x = step.decode(xhat, codes)
        del codes
        ref_codes, toadd = step.encode(xhat, x)

        step2 = copy_QINCoStep(step)
        xhat2 = faiss.Tensor2D(xhat)
        x2 = faiss.Tensor2D(x)
        toadd2 = faiss.Tensor2D(10, 16)

        new_codes = step2.encode(xhat2, x2, toadd2)

        np.testing.assert_allclose(
            ref_codes.numpy(),
            new_codes.numpy().ravel(),
            atol=2e-6
        )
        np.testing.assert_allclose(toadd.numpy(), toadd2.numpy(), atol=2e-6)


class TestQINCo(unittest.TestCase):

    @torch.no_grad()
    def test_decode(self):
        torch.manual_seed(123)
        qinco = QINCo(d=16, K=20, L=2, M=3, h=8)
        codes = torch.randint(0, 20, (10, 3))
        x_ref = qinco.decode(codes)

        qinco2 = faiss.QINCo(qinco)
        codes2 = faiss.Int32Tensor2D(codes.to(dtype=torch.int32))
        x_new = qinco2.decode(codes2)

        np.testing.assert_allclose(x_ref.numpy(), x_new.numpy(), atol=2e-6)

    @torch.no_grad()
    def test_encode(self):
        torch.manual_seed(123)
        qinco = QINCo(d=16, K=20, L=2, M=3, h=8)
        codes = torch.randint(0, 20, (10, 3))
        x = qinco.decode(codes)
        del codes

        ref_codes, _ = qinco.encode(x)

        qinco2 = faiss.QINCo(qinco)
        x2 = faiss.Tensor2D(x)

        new_codes = qinco2.encode(x2)

        np.testing.assert_allclose(ref_codes.numpy(), new_codes.numpy(), atol=2e-6)


######################################################
# Test index
######################################################

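# With L = 0 residual blocks and MLPconcat left at zero, each QINCo step
# decodes to the plain codebook entry, so the QINCo index reduces to an
# additive (residual) quantizer and should reproduce the RQ results.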
class TestIndexQINCo(unittest.TestCase):

    def test_search(self):
        """
        We can't train QINCo with Faiss alone, so we train an RQ and re-use its
        codebooks in a QINCo with L = 0 residual blocks
        """
        ds = datasets.SyntheticDataset(32, 1000, 100, 0)

        # prepare reference quantizer
        M = 5
        index_ref = faiss.index_factory(ds.d, "RQ5x4")
        rq = index_ref.rq
        # rq = faiss.ResidualQuantizer(ds.d, M, 4)
        rq.train_type = faiss.ResidualQuantizer.Train_default
        rq.max_beam_size = 1  # beam search not implemented for QINCo (yet)
        index_ref.train(ds.get_train())
        codebooks = get_additive_quantizer_codebooks(rq)

        # convert to QINCo index
        qinco_index = faiss.IndexQINCo(ds.d, M, 4, 0, ds.d)
        qinco = qinco_index.qinco
        qinco.codebook0.from_array(codebooks[0])
        for i in range(1, qinco.M):
            step = qinco.get_step(i - 1)
            step.codebook.from_array(codebooks[i])
            # MLPConcat left at zero -- it's added to the backbone
        qinco_index.is_trained = True

        # verify that the encoding gives the same results
        ref_codes = rq.compute_codes(ds.get_database())
        ref_decoded = rq.decode(ref_codes)
        new_decoded = qinco_index.sa_decode(ref_codes)
        np.testing.assert_allclose(ref_decoded, new_decoded, atol=2e-6)

        new_codes = qinco_index.sa_encode(ds.get_database())
        np.testing.assert_array_equal(ref_codes, new_codes)

        # verify that search gives the same results
        Dref, Iref = index_ref.search(ds.get_queries(), 5)
        Dnew, Inew = qinco_index.search(ds.get_queries(), 5)

        np.testing.assert_array_equal(Iref, Inew)
        np.testing.assert_allclose(Dref, Dnew, atol=2e-6)
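
# Standard unittest entry point so the file can also be run directly with
# python; test runners that discover tests do not need it.
if __name__ == "__main__":
    unittest.main()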