mirror of
https://github.com/facebookresearch/faiss.git
synced 2025-06-03 21:54:02 +08:00
Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1906 This PR implemented LSQ/LSQ++, a vector quantization technique described in the following two papers: 1. Revisiting additive quantization 2. LSQ++: Lower running time and higher recall in multi-codebook quantization Here is a benchmark running on SIFT1M for 64 bits encoding: ``` ===== lsq: mean square error = 17335.390208 training time: 312.729779958725 s encoding time: 244.6277096271515 s ===== pq: mean square error = 23743.004672 training time: 1.1610801219940186 s encoding time: 2.636141061782837 s ===== rq: mean square error = 20999.737344 training time: 31.813055515289307 s encoding time: 307.51959800720215 s ``` Changes: 1. Add LocalSearchQuantizer object 2. Fix an out of memory bug in ResidualQuantizer 3. Add a benchmark for evaluating quantizers 4. Add tests for LocalSearchQuantizer Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1862 Test Plan: ``` buck test //faiss/tests/:test_lsq buck run mode/opt //faiss/benchs/:bench_quantizer -- lsq pq rq ``` Reviewed By: beauby Differential Revision: D28376369 Pulled By: mdouze fbshipit-source-id: 2a394d38bf75b9de0a1c2cd6faddf7dd362a6fa8
65 lines
1.5 KiB
Python
65 lines
1.5 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import sys
|
|
import faiss
|
|
import time
|
|
|
|
try:
|
|
from faiss.contrib.datasets_fb import DatasetSIFT1M
|
|
except ImportError:
|
|
from faiss.contrib.datasets import DatasetSIFT1M
|
|
|
|
|
|
def eval_codec(q, xb):
    """Encode then decode ``xb`` with quantizer ``q``.

    Returns a tuple ``(mse, encode_time)``: the mean squared
    reconstruction error per database vector, and the wall-clock time
    spent in ``compute_codes`` (decoding is deliberately not timed).
    """
    encode_start = time.time()
    codes = q.compute_codes(xb)
    encode_time = time.time() - encode_start
    recons = q.decode(codes)
    # squared error summed over all dimensions, averaged over vectors
    mse = ((xb - recons) ** 2).sum() / xb.shape[0]
    return mse, encode_time
|
|
|
|
|
|
def eval_quantizer(q, xb, xt, name):
    """Train ``q`` on ``xt``, evaluate reconstruction on ``xb``, print a report.

    The report, labelled with ``name``, lists the mean squared error and
    the training / encoding wall-clock times.
    """
    train_start = time.time()
    q.train(xt)
    train_time = time.time() - train_start
    mse, encode_time = eval_codec(q, xb)
    print(f'===== {name}:')
    print(f'\tmean square error = {mse}')
    print(f'\ttraining time: {train_time} s')
    print(f'\tencoding time: {encode_time} s')
|
|
|
|
|
|
# Quantizer names to benchmark are taken from the command line
# (any subset of: lsq, pq, rq).
todo = sys.argv[1:]

ds = DatasetSIFT1M()

xq = ds.get_queries()
xb = ds.get_database()
gt = ds.get_groundtruth()
xt = ds.get_train()

nb, d = xb.shape
nq, d = xq.shape
nt, d = xt.shape

# common quantizer geometry: M codebooks of 2**nbits entries each
M = 8
nbits = 8


def _make_lsq():
    """Build a LocalSearchQuantizer with verbose training output."""
    q = faiss.LocalSearchQuantizer(d, M, nbits)
    q.log_level = 2  # show detailed training progress
    return q


def _make_pq():
    """Build a plain ProductQuantizer."""
    return faiss.ProductQuantizer(d, M, nbits)


def _make_rq():
    """Build a ResidualQuantizer with default training and verbose output."""
    q = faiss.ResidualQuantizer(d, M, nbits)
    q.train_type = faiss.ResidualQuantizer.Train_default
    q.verbose = True
    return q


# Run in a fixed order (lsq, pq, rq) regardless of argv order,
# constructing each quantizer only if it was requested.
for quant_name, factory in [('lsq', _make_lsq),
                            ('pq', _make_pq),
                            ('rq', _make_rq)]:
    if quant_name in todo:
        eval_quantizer(factory(), xb, xt, quant_name)