65 lines
1.5 KiB
Python
65 lines
1.5 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import sys
|
|
import faiss
|
|
import time
|
|
|
|
try:
|
|
from faiss.contrib.datasets_fb import DatasetSIFT1M
|
|
except ImportError:
|
|
from faiss.contrib.datasets import DatasetSIFT1M
|
|
|
|
|
|
def eval_codec(q, xb):
|
|
t0 = time.time()
|
|
codes = q.compute_codes(xb)
|
|
t1 = time.time()
|
|
decoded = q.decode(codes)
|
|
return ((xb - decoded) ** 2).sum() / xb.shape[0], t1 - t0
|
|
|
|
|
|
def eval_quantizer(q, xb, xt, name):
|
|
t0 = time.time()
|
|
q.train(xt)
|
|
t1 = time.time()
|
|
train_t = t1 - t0
|
|
err, encode_t = eval_codec(q, xb)
|
|
print(f'===== {name}:')
|
|
print(f'\tmean square error = {err}')
|
|
print(f'\ttraining time: {train_t} s')
|
|
print(f'\tencoding time: {encode_t} s')
|
|
|
|
|
|
todo = sys.argv[1:]
|
|
ds = DatasetSIFT1M()
|
|
|
|
xq = ds.get_queries()
|
|
xb = ds.get_database()
|
|
gt = ds.get_groundtruth()
|
|
xt = ds.get_train()
|
|
|
|
nb, d = xb.shape
|
|
nq, d = xq.shape
|
|
nt, d = xt.shape
|
|
|
|
M = 8
|
|
nbits = 8
|
|
|
|
if 'lsq' in todo:
|
|
lsq = faiss.LocalSearchQuantizer(d, M, nbits)
|
|
lsq.log_level = 2 # show detailed training progress
|
|
eval_quantizer(lsq, xb, xt, 'lsq')
|
|
|
|
if 'pq' in todo:
|
|
pq = faiss.ProductQuantizer(d, M, nbits)
|
|
eval_quantizer(pq, xb, xt, 'pq')
|
|
|
|
if 'rq' in todo:
|
|
rq = faiss.ResidualQuantizer(d, M, nbits)
|
|
rq.train_type = faiss.ResidualQuantizer.Train_default
|
|
rq.verbose = True
|
|
eval_quantizer(rq, xb, xt, 'rq')
|