2024-09-20 09:15:27 -07:00
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
|
|
#
|
|
|
|
# This source code is licensed under the MIT license found in the
|
|
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
"""
|
|
|
|
This contrib module contains PyTorch code for quantization.
|
|
|
|
"""
|
|
|
|
|
|
|
|
import numpy as np
|
2024-09-20 09:15:27 -07:00
|
|
|
import torch
|
2024-09-20 09:15:27 -07:00
|
|
|
import faiss
|
|
|
|
|
|
|
|
from faiss.contrib import torch_utils
|
|
|
|
|
|
|
|
|
2024-09-20 09:15:27 -07:00
|
|
|
class Quantizer:
    """Common interface for the quantizers defined in this module.

    Subclasses override ``train``, ``encode`` and ``decode``; the base
    implementations do nothing and return ``None``.
    """

    def __init__(self, d, code_size):
        """
        d: presumably the dimensionality of the input vectors; stored as given
        code_size: size of one vector's code in bytes (subclasses compute it
            as a bit count rounded up to whole bytes); stored as given
        """
        self.d, self.code_size = d, code_size

    def train(self, x):
        """Train on ``x``. The base class does nothing and returns None."""
        return None

    def encode(self, x):
        """Encode ``x`` into codes. The base class does nothing and returns None."""
        return None

    def decode(self, x):
        """Decode codes ``x`` into vectors. The base class does nothing and returns None."""
        return None
|
|
|
|
|
|
|
|
|
2024-09-20 09:15:27 -07:00
|
|
|
class VectorQuantizer(Quantizer):
    """Quantizer with ``k`` possible codes per vector.

    The code size is the number of whole bytes needed to store an integer
    in ``[0, k)``, i.e. ``ceil(log2(k) / 8)``.
    """

    def __init__(self, d, k):
        """
        d: forwarded unchanged to ``Quantizer.__init__``
        k: number of distinct codes (e.g. centroids); must be >= 1.
            Stored unchanged as ``self.k``.

        Raises ValueError if ``k`` < 1.
        """
        n_codes = int(k)
        if n_codes < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        # (k - 1).bit_length() == ceil(log2(k)) bits for k >= 1; round up to
        # whole bytes. Integer arithmetic is exact and works for plain Python
        # ints (torch.log2 / torch.ceil only accept tensor inputs).
        code_size = ((n_codes - 1).bit_length() + 7) // 8
        super().__init__(d, code_size)
        self.k = k

    def train(self, x):
        """Placeholder: training is not implemented yet; returns None."""
        pass
|
|
|
|
|
|
|
|
|
2024-09-20 09:15:27 -07:00
|
|
|
class ProductQuantizer(Quantizer):
    """Quantizer whose code is ``M * nbits`` bits long, rounded up to bytes.

    Presumably ``M`` sub-quantizers of ``nbits`` bits each (the usual
    product-quantization layout); that split is not enforced here.
    """

    def __init__(self, d, M, nbits):
        """
        d: forwarded unchanged to ``Quantizer.__init__``
        M: number of sub-codes per vector; must be >= 1
        nbits: bits per sub-code; must be >= 1

        The resulting ``code_size`` is ``ceil(M * nbits / 8)`` bytes.
        Raises ValueError if ``M`` or ``nbits`` is < 1.
        """
        m, b = int(M), int(nbits)
        if m < 1 or b < 1:
            raise ValueError(
                f"M and nbits must be >= 1, got M={M}, nbits={nbits}")
        # Integer ceiling division: ceil(m * b / 8) without going through
        # floats (torch.ceil does not accept Python numbers).
        code_size = (m * b + 7) // 8
        super().__init__(d, code_size)
        self.M = M
        self.nbits = nbits

    def train(self, x):
        """Placeholder: training is not implemented yet; returns None."""
        pass
|