PQ with pytorch (#4116)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/4116

This diff implements Product Quantization (PQ) using PyTorch only.

Reviewed By: mdouze

Differential Revision: D67766798

fbshipit-source-id: fe2d44a674fc2056f7e2082e9765052c98fdc8f8
Maria Lomeli 2025-01-06 09:48:32 -08:00 committed by Facebook GitHub Bot
parent 0cbc2a885c
commit 9590ad2746
3 changed files with 82 additions and 13 deletions
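
For context, a minimal usage sketch of the PyTorch ProductQuantizer added by this diff (the random data and sizes are illustrative; see the quantization changes below for the actual implementation):

    import torch
    from faiss.contrib.torch import quantization

    d, M, nbits = 64, 4, 8              # dimension, subvectors, bits per subvector index
    x = torch.rand(10000, d)            # illustrative training / database vectors

    pq = quantization.ProductQuantizer(d, M, nbits)
    pq.train(x)                         # k-means per subspace builds the codebook
    codes = pq.encode(x)                # (10000, 4) uint8 codes, one byte per subvector
    x_rec = pq.decode(codes)            # (10000, 64) reconstruction
    err = ((x - x_rec) ** 2).sum()      # total squared reconstruction error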


@@ -13,6 +13,7 @@ import torch
# the kmeans can produce both torch and numpy centroids
from faiss.contrib.clustering import kmeans

class DatasetAssign:
    """Wrapper for a tensor that offers a function to assign the vectors
    to centroids. All other implementations offer the same interface"""


@@ -7,33 +7,47 @@
This contrib module contains PyTorch code for quantization.
"""

import numpy as np
import torch
import faiss
from faiss.contrib import torch_utils
+import math
+from faiss.contrib.torch import clustering

# the kmeans can produce both torch and numpy centroids

class Quantizer:

    def __init__(self, d, code_size):
        """
        d: dimension of vectors
        code_size: nb of bytes of the code (per vector)
        """
        self.d = d
        self.code_size = code_size

    def train(self, x):
        """
        takes an n-by-d array and performs training
        """
        pass

    def encode(self, x):
        """
        takes an n-by-d float array, encodes to an n-by-code_size uint8 array
        """
        pass
-    def decode(self, x):
+    def decode(self, codes):
        """
        takes an n-by-code_size uint8 array, returns an n-by-d array
        """
        pass


class VectorQuantizer(Quantizer):

    def __init__(self, d, k):
-        code_size = int(torch.ceil(torch.log2(k) / 8))
+        code_size = int(math.ceil(math.log2(k) / 8))  # bytes needed for an index in [0, k)
        Quantizer.__init__(self, d, code_size)
        self.k = k
@@ -42,12 +56,41 @@ class VectorQuantizer(Quantizer):

class ProductQuantizer(Quantizer):

    def __init__(self, d, M, nbits):
-        code_size = int(torch.ceil(M * nbits / 8))
-        Quantizer.__init__(d, code_size)
+        """ M: number of subvectors (d % M == 0)
+            nbits: number of bits each subvector index is encoded on (only 8 supported)
+        """
+        assert d % M == 0
+        assert nbits == 8  # TODO: implement other nbits values
+        code_size = int(math.ceil(M * nbits / 8))
+        Quantizer.__init__(self, d, code_size)
        self.M = M
        self.nbits = nbits
        self.code_size = code_size

    def train(self, x):
-        pass
+        nc = 2 ** self.nbits   # number of centroids per sub-quantizer
+        sd = self.d // self.M  # dimension of each subvector
+        dev = x.device
+        dtype = x.dtype
+        self.codebook = torch.zeros((self.M, nc, sd), device=dev, dtype=dtype)
+        # run k-means independently on each block of sd dimensions
+        for m in range(self.M):
+            xsub = x[:, m * self.d // self.M: (m + 1) * self.d // self.M]
+            data = clustering.DatasetAssign(xsub.contiguous())
+            self.codebook[m] = clustering.kmeans(2 ** self.nbits, data)

    def encode(self, x):
+        codes = torch.zeros((x.shape[0], self.code_size), dtype=torch.uint8)
+        for m in range(self.M):
+            xsub = x[:, m * self.d // self.M: (m + 1) * self.d // self.M]
+            # index of the nearest centroid in subspace m, for each vector
+            _, I = faiss.knn(xsub.contiguous(), self.codebook[m], 1)
+            codes[:, m] = I.ravel()
+        return codes

    def decode(self, codes):
+        idxs = [codes[:, m].long() for m in range(self.M)]
+        # look up the selected centroid of each subspace codebook
+        vectors = [self.codebook[m, idxs[m], :] for m in range(self.M)]
+        stacked_vectors = torch.stack(vectors, dim=1)
+        cbd = self.codebook.shape[-1]
+        x_rec = stacked_vectors.reshape(-1, cbd * self.M)
+        return x_rec
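
For reference, decode gathers one centroid per subspace by code index, stacks them, and flattens back to d = M * sd dimensions. A standalone shape check of that logic (all sizes illustrative):

    import torch

    M, nbits, sd = 4, 8, 16                    # subspaces, bits per index, subvector dim
    codebook = torch.randn(M, 2 ** nbits, sd)  # (4, 256, 16) centroids
    codes = torch.randint(0, 2 ** nbits, (10, M), dtype=torch.uint8)

    # same gather + stack + reshape as ProductQuantizer.decode above
    vectors = [codebook[m, codes[:, m].long(), :] for m in range(M)]
    x_rec = torch.stack(vectors, dim=1).reshape(-1, M * sd)
    assert x_rec.shape == (10, M * sd)         # (10, 64)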


@@ -10,7 +10,8 @@ import numpy as np  # usort: skip
import faiss # usort: skip
import faiss.contrib.torch_utils # usort: skip
from faiss.contrib import datasets
-from faiss.contrib.torch import clustering
+from faiss.contrib.torch import clustering, quantization
@@ -400,3 +401,27 @@ class TestClustering(unittest.TestCase):
        # 33498.332 33380.477
        # print(err, err2) 1/0
        self.assertLess(err2, err * 1.1)


class TestQuantization(unittest.TestCase):

    def test_python_product_quantization(self):
        """ Test the python implementation of product quantization """
        d = 64
        n = 10000
        cs = 4      # number of sub-quantizers for the reference faiss PQ
        nbits = 8
        M = 4       # number of subvectors for the PyTorch PQ
        x = np.random.random(size=(n, d)).astype('float32')
        # reference: the C++ ProductQuantizer via the faiss Python API
        pq = faiss.ProductQuantizer(d, cs, nbits)
        pq.train(x)
        codes = pq.compute_codes(x)
        x2 = pq.decode(codes)
        diff = ((x - x2) ** 2).sum()
        # vs. the pure PyTorch implementation
        xt = torch.from_numpy(x)
        my_pq = quantization.ProductQuantizer(d, M, nbits)
        my_pq.train(xt)
        my_codes = my_pq.encode(xt)
        xt2 = my_pq.decode(my_codes)
        my_diff = ((xt - xt2) ** 2).sum()
        self.assertLess(abs(diff - my_diff), 100)
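
Note that the test compares total squared reconstruction error rather than exact codes: the two implementations run independent k-means, so their codebooks will differ, but with the same number of sub-quantizers (4) and 8-bit codes on the same data the two error totals are expected to land close together; 100 is an absolute tolerance on that difference.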