faiss/benchs/bench_quantizer.py

65 lines
1.5 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys
import faiss
import time
try:
from faiss.contrib.datasets_fb import DatasetSIFT1M
except ImportError:
from faiss.contrib.datasets import DatasetSIFT1M
def eval_codec(q, xb):
t0 = time.time()
codes = q.compute_codes(xb)
t1 = time.time()
decoded = q.decode(codes)
return ((xb - decoded) ** 2).sum() / xb.shape[0], t1 - t0
def eval_quantizer(q, xb, xt, name):
t0 = time.time()
q.train(xt)
t1 = time.time()
train_t = t1 - t0
err, encode_t = eval_codec(q, xb)
print(f'===== {name}:')
print(f'\tmean square error = {err}')
print(f'\ttraining time: {train_t} s')
print(f'\tencoding time: {encode_t} s')
todo = sys.argv[1:]
ds = DatasetSIFT1M()
xq = ds.get_queries()
xb = ds.get_database()
gt = ds.get_groundtruth()
xt = ds.get_train()
nb, d = xb.shape
nq, d = xq.shape
nt, d = xt.shape
M = 8
nbits = 8
if 'lsq' in todo:
lsq = faiss.LocalSearchQuantizer(d, M, nbits)
lsq.log_level = 2 # show detailed training progress
eval_quantizer(lsq, xb, xt, 'lsq')
if 'pq' in todo:
pq = faiss.ProductQuantizer(d, M, nbits)
eval_quantizer(pq, xb, xt, 'pq')
if 'rq' in todo:
rq = faiss.ResidualQuantizer(d, M, nbits)
rq.train_type = faiss.ResidualQuantizer.Train_default
rq.verbose = True
eval_quantizer(rq, xb, xt, 'rq')