diff --git a/contrib/torch/clustering.py b/contrib/torch/clustering.py
index 743853460..8ff5cb24c 100644
--- a/contrib/torch/clustering.py
+++ b/contrib/torch/clustering.py
@@ -13,6 +13,7 @@ import torch
 # the kmeans can produce both torch and numpy centroids
 from faiss.contrib.clustering import kmeans
 
+
 class DatasetAssign:
     """Wrapper for a tensor that offers a function to assign the vectors
     to centroids. All other implementations offer the same interface"""
diff --git a/contrib/torch/quantization.py b/contrib/torch/quantization.py
index a1d8f7dd8..2ae6599a0 100644
--- a/contrib/torch/quantization.py
+++ b/contrib/torch/quantization.py
@@ -7,33 +7,47 @@
 This contrib module contains Pytorch code for quantization.
 """
 
-import numpy as np
 import torch
 import faiss
-
-from faiss.contrib import torch_utils
+import math
+from faiss.contrib.torch import clustering
+# the kmeans can produce both torch and numpy centroids
 
 
 class Quantizer:
 
     def __init__(self, d, code_size):
+        """
+        d: dimension of vectors
+        code_size: nb of bytes of the code (per vector)
+        """
         self.d = d
         self.code_size = code_size
 
     def train(self, x):
+        """
+        takes an n-by-d array and performs training
+        """
         pass
 
     def encode(self, x):
+        """
+        takes an n-by-d float array, encodes to an n-by-code_size uint8 array
+        """
         pass
 
-    def decode(self, x):
+    def decode(self, codes):
+        """
+        takes an n-by-code_size uint8 array, returns an n-by-d array
+        """
         pass
 
 
 class VectorQuantizer(Quantizer):
 
     def __init__(self, d, k):
-        code_size = int(torch.ceil(torch.log2(k) / 8))
-        Quantizer.__init__(d, code_size)
+        # k is a plain int, so use math rather than torch ops;
+        # the base-class call also needs self
+        code_size = int(math.ceil(math.log2(k) / 8))
+        Quantizer.__init__(self, d, code_size)
         self.k = k
@@ -42,12 +56,41 @@ class VectorQuantizer(Quantizer):
 
 
 class ProductQuantizer(Quantizer):
 
-    def __init__(self, d, M, nbits):
-        code_size = int(torch.ceil(M * nbits / 8))
-        Quantizer.__init__(d, code_size)
+    def __init__(self, d, M, nbits):
+        """ M: number of subvectors, d%M == 0
+        nbits: number of bits that each vector is encoded into
+        """
+        assert d % M == 0
+        assert nbits == 8  # todo: implement other nbits values
+        code_size = int(math.ceil(M * nbits / 8))
+        Quantizer.__init__(self, d, code_size)
         self.M = M
         self.nbits = nbits
+        self.code_size = code_size
 
     def train(self, x):
-        pass
+        nc = 2 ** self.nbits    # number of centroids per sub-quantizer
+        sd = self.d // self.M   # dimension of each sub-vector
+        dev = x.device
+        dtype = x.dtype
+        self.codebook = torch.zeros((self.M, nc, sd), device=dev, dtype=dtype)
+        for m in range(self.M):
+            xsub = x[:, m * sd: (m + 1) * sd]
+            data = clustering.DatasetAssign(xsub.contiguous())
+            self.codebook[m] = clustering.kmeans(nc, data)
+
+    def encode(self, x):
+        sd = self.d // self.M
+        codes = torch.zeros(
+            (x.shape[0], self.code_size), dtype=torch.uint8, device=x.device)
+        for m in range(self.M):
+            xsub = x[:, m * sd: (m + 1) * sd]
+            _, I = faiss.knn(xsub.contiguous(), self.codebook[m], 1)
+            codes[:, m] = I.ravel()
+        return codes
+
+    def decode(self, codes):
+        idxs = [codes[:, m].long() for m in range(self.M)]
+        vectors = [self.codebook[m, idxs[m], :] for m in range(self.M)]
+        stacked_vectors = torch.stack(vectors, dim=1)
+        cbd = self.codebook.shape[-1]
+        x_rec = stacked_vectors.reshape(-1, cbd * self.M)
+        return x_rec
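
For context, a minimal usage sketch of the new pure-torch ProductQuantizer
(not part of the diff; the values of d, M, nbits and the random data are
illustrative, and it assumes faiss.contrib.torch_utils has been imported, as
in the test below, so that faiss.knn accepts torch tensors):

    import torch
    import faiss.contrib.torch_utils  # makes faiss.knn accept torch tensors
    from faiss.contrib.torch import quantization

    d, M, nbits = 64, 4, 8        # 4 sub-quantizers, 256 centroids each
    x = torch.rand(10000, d)      # synthetic training / test vectors

    pq = quantization.ProductQuantizer(d, M, nbits)
    pq.train(x)                   # k-means per sub-vector -> (M, 256, d//M) codebook
    codes = pq.encode(x)          # (10000, M) uint8 codes
    x_rec = pq.decode(codes)      # approximate reconstruction, shape (10000, d)
    err = ((x - x_rec) ** 2).sum()  # total squared reconstruction error
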
diff --git a/tests/torch_test_contrib.py b/tests/torch_test_contrib.py
index 26c381b3c..3eb6c6c43 100644
--- a/tests/torch_test_contrib.py
+++ b/tests/torch_test_contrib.py
@@ -4,13 +4,14 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch  # usort: skip
-import unittest # usort: skip
-import numpy as np # usort: skip
+import unittest  # usort: skip
+import numpy as np  # usort: skip
 
-import faiss # usort: skip
+import faiss  # usort: skip
 import faiss.contrib.torch_utils  # usort: skip
 
 from faiss.contrib import datasets
-from faiss.contrib.torch import clustering
+from faiss.contrib.torch import clustering, quantization
+
@@ -400,3 +401,27 @@ class TestClustering(unittest.TestCase):
         # 33498.332 33380.477
         # print(err, err2) 1/0
         self.assertLess(err2, err * 1.1)
+
+
+class TestQuantization(unittest.TestCase):
+    def test_python_product_quantization(self):
+        """ Test the python implementation of product quantization """
+        d = 64
+        n = 10000
+        cs = 4   # nb of sub-quantizers for the C++ PQ, same role as M below
+        nbits = 8
+        M = 4
+        x = np.random.random(size=(n, d)).astype('float32')
+        pq = faiss.ProductQuantizer(d, cs, nbits)
+        pq.train(x)
+        codes = pq.compute_codes(x)
+        x2 = pq.decode(codes)
+        diff = ((x - x2)**2).sum()
+        # vs pure pytorch impl
+        xt = torch.from_numpy(x)
+        my_pq = quantization.ProductQuantizer(d, M, nbits)
+        my_pq.train(xt)
+        my_codes = my_pq.encode(xt)
+        xt2 = my_pq.decode(my_codes)
+        my_diff = ((xt - xt2)**2).sum()
+        self.assertLess(abs(diff - my_diff), 100)
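
To exercise just the new test in isolation, a small runner along these lines
should work from the tests/ directory (a sketch, not part of the diff; it
assumes a faiss build with the Python contrib installed, and plain unittest
discovery or pytest work equally well):

    # hypothetical helper script, run from tests/
    import unittest

    suite = unittest.defaultTestLoader.loadTestsFromName(
        "torch_test_contrib.TestQuantization.test_python_product_quantization"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)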