PQ with pytorch (#4116)

Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/4116

This diff implements Product Quantization using PyTorch only.

Reviewed By: mdouze

Differential Revision: D67766798

fbshipit-source-id: fe2d44a674fc2056f7e2082e9765052c98fdc8f8

parent 0cbc2a885c
commit 9590ad2746
@@ -13,6 +13,7 @@ import torch
 # the kmeans can produce both torch and numpy centroids
+from faiss.contrib.clustering import kmeans


 class DatasetAssign:
     """Wrapper for a tensor that offers a function to assign the vectors
     to centroids. All other implementations offer the same interface"""
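For context, this hunk re-exports `kmeans` next to the `DatasetAssign` wrapper that the new quantizer builds on. A minimal usage sketch, assuming only the signatures visible in this diff (`DatasetAssign(tensor)` and `kmeans(k, data)` returning the centroid matrix); the sizes are illustrative:

    import torch
    from faiss.contrib.torch import clustering

    xsub = torch.rand(10000, 16)              # 10000 training vectors of dim 16
    data = clustering.DatasetAssign(xsub)     # wraps the tensor for assignment
    centroids = clustering.kmeans(256, data)  # 256 centroids, shape (256, 16)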
@@ -7,33 +7,47 @@
 This contrib module contains Pytorch code for quantization.
 """

 import numpy as np
 import torch
 import faiss

 from faiss.contrib import torch_utils
+import math
+from faiss.contrib.torch import clustering
+# the kmeans can produce both torch and numpy centroids


 class Quantizer:

     def __init__(self, d, code_size):
         """
         d: dimension of vectors
         code_size: nb of bytes of the code (per vector)
         """
         self.d = d
         self.code_size = code_size

     def train(self, x):
         """
         takes a n-by-d array and performs training
         """
         pass

     def encode(self, x):
         """
         takes a n-by-d float array, encodes to an n-by-code_size uint8 array
         """
         pass

-    def decode(self, x):
+    def decode(self, codes):
         """
         takes a n-by-code_size uint8 array, returns a n-by-d array
         """
         pass


 class VectorQuantizer(Quantizer):

     def __init__(self, d, k):
-        code_size = int(torch.ceil(torch.log2(k) / 8))
+        code_size = int(math.ceil(torch.log2(k) / 8))
         Quantizer.__init__(d, code_size)
         self.k = k
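As a worked example of the code-size formula above: a quantizer with k centroids needs ceil(log2(k) / 8) bytes per code. Note that `torch.log2` expects a tensor, so with a plain Python int `k` the equivalent computation uses `math.log2`:

    import math

    for k in (256, 1024, 65536):
        code_size = int(math.ceil(math.log2(k) / 8))
        print(k, '->', code_size)  # 256 -> 1 byte, 1024 -> 2, 65536 -> 2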
@@ -42,12 +56,41 @@ class VectorQuantizer(Quantizer):

 class ProductQuantizer(Quantizer):

     def __init__(self, d, M, nbits):
-        code_size = int(torch.ceil(M * nbits / 8))
-        Quantizer.__init__(d, code_size)
+        """ M: number of subvectors, d % M == 0
+        nbits: number of bits that each vector is encoded into
+        """
+        assert d % M == 0
+        assert nbits == 8  # todo: implement other nbits values
+        code_size = int(math.ceil(M * nbits / 8))
+        Quantizer.__init__(self, d, code_size)
         self.M = M
         self.nbits = nbits
         self.code_size = code_size

     def train(self, x):
-        pass
+        nc = 2 ** self.nbits
+        sd = self.d // self.M
+        dev = x.device
+        dtype = x.dtype
+        self.codebook = torch.zeros((self.M, nc, sd), device=dev, dtype=dtype)
+        for m in range(self.M):
+            xsub = x[:, m * self.d // self.M: (m + 1) * self.d // self.M]
+            data = clustering.DatasetAssign(xsub.contiguous())
+            self.codebook[m] = clustering.kmeans(2 ** self.nbits, data)
+
+    def encode(self, x):
+        codes = torch.zeros((x.shape[0], self.code_size), dtype=torch.uint8)
+        for m in range(self.M):
+            xsub = x[:, m * self.d // self.M:(m + 1) * self.d // self.M]
+            _, I = faiss.knn(xsub.contiguous(), self.codebook[m], 1)
+            codes[:, m] = I.ravel()
+        return codes
+
+    def decode(self, codes):
+        idxs = [codes[:, m].long() for m in range(self.M)]
+        vectors = [self.codebook[m, idxs[m], :] for m in range(self.M)]
+        stacked_vectors = torch.stack(vectors, dim=1)
+        cbd = self.codebook.shape[-1]
+        x_rec = stacked_vectors.reshape(-1, cbd * self.M)
+        return x_rec
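A usage sketch of the new pure-PyTorch ProductQuantizer, mirroring the test added below (dimensions are illustrative):

    import torch
    from faiss.contrib.torch import quantization

    d, M, nbits = 64, 4, 8
    x = torch.rand(10000, d)

    pq = quantization.ProductQuantizer(d, M, nbits)  # 4 sub-quantizers, 2**8 centroids each
    pq.train(x)               # k-means on each d/M-dimensional subspace
    codes = pq.encode(x)      # (10000, 4) uint8 codes, one byte per subvector
    x_rec = pq.decode(codes)  # (10000, 64) approximate reconstruction
    err = ((x - x_rec) ** 2).sum()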
@@ -4,13 +4,14 @@
 # LICENSE file in the root directory of this source tree.

+import torch  # usort: skip
 import unittest  # usort: skip
 import numpy as np  # usort: skip

 import faiss  # usort: skip
 import faiss.contrib.torch_utils  # usort: skip
 from faiss.contrib import datasets
-from faiss.contrib.torch import clustering
+from faiss.contrib.torch import clustering, quantization
@@ -400,3 +401,27 @@ class TestClustering(unittest.TestCase):
         # 33498.332 33380.477
         # print(err, err2) 1/0
         self.assertLess(err2, err * 1.1)
+
+
+class TestQuantization(unittest.TestCase):
+    def test_python_product_quantization(self):
+        """ Test the python implementation of product quantization """
+        d = 64
+        n = 10000
+        cs = 4
+        nbits = 8
+        M = 4
+        x = np.random.random(size=(n, d)).astype('float32')
+        pq = faiss.ProductQuantizer(d, cs, nbits)
+        pq.train(x)
+        codes = pq.compute_codes(x)
+        x2 = pq.decode(codes)
+        diff = ((x - x2)**2).sum()
+        # vs pure pytorch impl
+        xt = torch.from_numpy(x)
+        my_pq = quantization.ProductQuantizer(d, M, nbits)
+        my_pq.train(xt)
+        my_codes = my_pq.encode(xt)
+        xt2 = my_pq.decode(my_codes)
+        my_diff = ((xt - xt2)**2).sum()
+        self.assertLess(abs(diff - my_diff), 100)
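Assuming the new test lands in faiss's existing torch contrib test module (the TestClustering hunk above suggests tests/torch_test_contrib.py), it can be run on its own with something like:

    python -m unittest tests.torch_test_contrib.TestQuantization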