faiss/benchs/bench_scalar_quantizer.py

115 lines
3.3 KiB
Python

# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD+Patents license found in the
# LICENSE file in the root directory of this source tree.
#!/usr/bin/env python2
import time
import numpy as np
import faiss
#################################################################
# I/O functions
#################################################################
def ivecs_read(fname):
    """Load an .ivecs file into a 2-D int32 array, one vector per row.

    The file is read as a flat stream of int32 values. The first value is
    taken as the dimension ``d``; the stream is then cut into records of
    ``d + 1`` values and the leading value of every record is dropped.

    Args:
        fname: path of the file to read.

    Returns:
        A freshly allocated (copied) array of shape ``(n, d)``, dtype int32.
    """
    raw = np.fromfile(fname, dtype='int32')
    dim = raw[0]
    # Each record is [header, v_0, ..., v_{dim-1}]; strip the header column.
    records = raw.reshape(-1, dim + 1)
    payload = records[:, 1:]
    # Copy so the result does not keep the strided view of the raw buffer.
    return payload.copy()
def fvecs_read(fname):
    """Load an .fvecs file into a 2-D float32 array, one vector per row.

    The file is parsed by ``ivecs_read`` and the resulting int32 payload is
    reinterpreted bit-for-bit as float32 (a ``view``, not a numeric cast).

    Args:
        fname: path of the file to read.

    Returns:
        The array returned by ``ivecs_read``, viewed as float32.
    """
    as_int32 = ivecs_read(fname)
    return as_int32.view('float32')
#################################################################
# Main program
#################################################################
# Load the SIFT1M vectors and set up what both experiments below share.
# The print calls take a single parenthesised argument so the script parses
# under Python 3 while producing the same output under Python 2.
print("load data")

xt = fvecs_read("sift1M/sift_learn.fvecs")   # training vectors
xb = fvecs_read("sift1M/sift_base.fvecs")    # vectors added to the index
xq = fvecs_read("sift1M/sift_query.fvecs")   # query vectors

# Optional truncation of the queries / database (kept disabled).
# xq = xq[:1000]
# xb = xb[:100000]

nq, d = xq.shape

print("load GT")

gt = ivecs_read("sift1M/sift_groundtruth.ivecs")
# Matching truncation of the ground truth (kept disabled).
# gt = gt[:1000]

# Passed as the list-count argument to every IVF index built below.
ncent = 256

# (name, value) for every ScalarQuantizer attribute whose name starts
# with 'QT_'.
variants = [
    (name, getattr(faiss.ScalarQuantizer, name))
    for name in dir(faiss.ScalarQuantizer)
    if name.startswith('QT_')
]

# One flat L2 index handed as coarse quantizer to every IVF index below.
# NOTE(review): presumably shared on purpose so the coarse centroids are
# trained once and reused by later indexes -- confirm.
quantizer = faiss.IndexFlatL2(d)
# quantizer.add(np.zeros((1, d), dtype='float32'))
# Experiment 1 (disabled): compare an IVF-flat baseline against every
# scalar-quantizer type, printing elapsed time per stage and recall.
# Flip the condition to True to run it.
if False:
    for name, qtype in [('flat', 0)] + variants:

        print("============== test " + name)
        t0 = time.time()

        if name == 'flat':
            # Baseline entry; its qtype placeholder (0) is not used.
            index = faiss.IndexIVFFlat(quantizer, d, ncent,
                                       faiss.METRIC_L2)
        else:
            index = faiss.IndexIVFScalarQuantizer(quantizer, d, ncent,
                                                  qtype, faiss.METRIC_L2)

        index.nprobe = 16

        # Each timestamp is printed *before* the stage it announces.
        print("[%.3f s] train" % (time.time() - t0))
        index.train(xt)
        print("[%.3f s] add" % (time.time() - t0))
        index.add(xb)
        print("[%.3f s] search" % (time.time() - t0))
        D, I = index.search(xq, 100)
        print("[%.3f s] eval" % (time.time() - t0))

        # Fraction of queries whose first ground-truth id appears among
        # the top-`rank` results, for rank 1, 10 and 100.
        recalls = []
        for rank in 1, 10, 100:
            n_ok = (I[:, :rank] == gt[:, :1]).sum()
            recalls.append("%.4f" % (n_ok / float(nq)))
        # One space-separated line, identical to what the Python 2
        # trailing-comma prints followed by a bare `print` produced.
        print(" ".join(recalls))
# Experiment 2: for every scalar-quantizer type, sweep the range-estimation
# mode (`rangestat`) and its argument (`rangestat_arg`), reporting recall
# at ranks 1/10/100 and the search time.
if True:
    # (ScalarQuantizer attribute name, argument values to try)
    range_settings = [
        ('RS_minmax', [-0.4, -0.2, -0.1, -0.05, 0.0, 0.1, 0.5]),
        ('RS_meanstd', [0.8, 1.0, 1.5, 2.0, 3.0, 5.0, 10.0]),
        ('RS_quantiles', [0.02, 0.05, 0.1, 0.15]),
        ('RS_optim', [0.0]),
    ]

    for name, qtype in variants:

        print("============== test " + name)

        for rsname, vals in range_settings:
            for val in vals:
                fields = ["%-15s %5g " % (rsname, val)]

                index = faiss.IndexIVFScalarQuantizer(
                    quantizer, d, ncent, qtype, faiss.METRIC_L2)
                index.nprobe = 16

                # Both range parameters are set on the embedded scalar
                # quantizer (`index.sq`). The previous code assigned
                # `rangestat_arg` on the index object itself, so the swept
                # value never reached the quantizer.
                index.sq.rangestat = getattr(faiss.ScalarQuantizer, rsname)
                index.sq.rangestat_arg = val

                index.train(xt)
                index.add(xb)

                # Only the search call is timed.
                t0 = time.time()
                D, I = index.search(xq, 100)
                t1 = time.time()

                # Fraction of queries whose first ground-truth id appears
                # among the top-`rank` results.
                for rank in 1, 10, 100:
                    n_ok = (I[:, :rank] == gt[:, :1]).sum()
                    fields.append("%.4f" % (n_ok / float(nq)))

                fields.append(" %.3f s" % (t1 - t0))

                # Joined with single spaces this is the same line the
                # Python 2 trailing-comma prints produced; it is emitted in
                # one piece once the run has finished.
                print(" ".join(fields))