115 lines
3.3 KiB
Python
115 lines
3.3 KiB
Python
# Copyright (c) 2015-present, Facebook, Inc.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD+Patents license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
#!/usr/bin/env python2
|
|
|
|
import time
|
|
import numpy as np
|
|
import faiss
|
|
|
|
#################################################################
|
|
# I/O functions
|
|
#################################################################
|
|
|
|
|
|
def ivecs_read(fname):
    """Read an ``.ivecs`` file into a 2-D int32 array.

    The file is loaded as a flat stream of int32 values.  The first value is
    taken as the vector dimension ``d``; the stream is then cut into records
    of ``d + 1`` values and the leading value of every record (its dimension
    header) is dropped.  Only the first header is inspected; the headers of
    later records are discarded without being checked.

    Args:
        fname: path of the file to read (passed to ``numpy.fromfile``).

    Returns:
        A newly allocated ``int32`` array of shape ``(n, d)``.  An empty
        file holds no vectors and yields an array of shape ``(0, 0)``.

    Raises:
        ValueError: if the leading dimension is negative or the number of
            values in the file is not a multiple of ``d + 1``.
    """
    a = np.fromfile(fname, dtype='int32')
    if a.size == 0:
        # There is no header to read the dimension from.
        return np.empty((0, 0), dtype='int32')

    d = int(a[0])
    if d < 0 or a.size % (d + 1) != 0:
        raise ValueError(
            "%s: malformed vecs file (leading dimension %d, %d int32 values)"
            % (fname, d, a.size))

    # Copy so the result owns its data instead of being a strided view that
    # keeps the whole raw buffer (header column included) alive.
    return a.reshape(-1, d + 1)[:, 1:].copy()
|
|
|
|
|
|
def fvecs_read(fname):
    """Read an ``.fvecs`` file into a 2-D float32 array.

    The file is parsed with ``ivecs_read`` and the resulting int32 array is
    reinterpreted in place as float32 (no numeric conversion takes place).

    Args:
        fname: path of the file to read.

    Returns:
        The array returned by ``ivecs_read(fname)``, viewed as ``float32``.
    """
    raw = ivecs_read(fname)
    return raw.view('float32')
|
|
|
|
|
|
#################################################################
|
|
# Main program
|
|
#################################################################
|
|
|
|
print("load data")

# xt is later passed to index.train(), xb to index.add() and xq to
# index.search().  The files are read in this order: learn, base, query.
xt, xb, xq = (
    fvecs_read(path)
    for path in (
        "sift1M/sift_learn.fvecs",
        "sift1M/sift_base.fvecs",
        "sift1M/sift_query.fvecs",
    )
)

# Optional truncation of the query / database sets (disabled):
# xq = xq[:1000]
# xb = xb[:100000]

# Number of query vectors and their dimension.
nq, d = xq.shape

print("load GT")
gt = ivecs_read("sift1M/sift_groundtruth.ivecs")

# Matching truncation of the ground truth (disabled):
# gt = gt[:1000]
|
|
|
|
|
|
# Passed as the third constructor argument of every IVF index built below.
ncent = 256

# (name, value) pairs for every attribute of faiss.ScalarQuantizer whose
# name starts with 'QT_', in the order dir() reports them.
variants = []
for name in dir(faiss.ScalarQuantizer):
    if name.startswith('QT_'):
        variants.append((name, getattr(faiss.ScalarQuantizer, name)))

# A single flat L2 index, handed as first argument to every IVF index
# constructed below (all of them share this one object).
quantizer = faiss.IndexFlatL2(d)
# quantizer.add(np.zeros((1, d), dtype='float32'))
|
|
|
|
if False:
    # Disabled experiment: compare an IVF index storing raw vectors ('flat')
    # against one IVF scalar-quantizer index per QT_* variant, printing
    # elapsed-time stamps for each stage followed by a line of three scores.
    for name, qtype in [('flat', 0)] + variants:

        # Items are joined with one space, as a comma-separated print does.
        print(" ".join(("============== test", name)))
        t0 = time.time()

        index = (
            faiss.IndexIVFFlat(quantizer, d, ncent, faiss.METRIC_L2)
            if name == 'flat' else
            faiss.IndexIVFScalarQuantizer(quantizer, d, ncent,
                                          qtype, faiss.METRIC_L2)
        )
        index.nprobe = 16

        print("[%.3f s] train" % (time.time() - t0))
        index.train(xt)
        print("[%.3f s] add" % (time.time() - t0))
        index.add(xb)
        print("[%.3f s] search" % (time.time() - t0))
        D, I = index.search(xq, 100)
        print("[%.3f s] eval" % (time.time() - t0))

        # For each cut-off, count the entries in the first `rank` result
        # columns that equal the first ground-truth column, divided by the
        # number of queries.
        scores = []
        for rank in 1, 10, 100:
            n_ok = (I[:, :rank] == gt[:, :1]).sum()
            scores.append("%.4f" % (n_ok / float(nq)))
        print(" ".join(scores))
|
|
|
|
if True:
    # Sweep the range-estimation settings of the scalar quantizer: for every
    # QT_* variant, try each RS_* statistic with several values of its
    # argument and print one result line per combination:
    #   <statistic> <argument>  <score@1> <score@10> <score@100>  <search time>
    for name, qtype in variants:

        # Items are joined with one space, as a comma-separated print does.
        print(" ".join(("============== test", name)))

        for rsname, vals in [
                ('RS_minmax', [-0.4, -0.2, -0.1, -0.05, 0.0, 0.1, 0.5]),
                ('RS_meanstd', [0.8, 1.0, 1.5, 2.0, 3.0, 5.0, 10.0]),
                ('RS_quantiles', [0.02, 0.05, 0.1, 0.15]),
                ('RS_optim', [0.0])]:
            for val in vals:
                index = faiss.IndexIVFScalarQuantizer(quantizer, d, ncent,
                                                      qtype, faiss.METRIC_L2)
                index.nprobe = 16
                index.sq.rangestat = getattr(faiss.ScalarQuantizer, rsname)
                # rangestat_arg is a field of the scalar quantizer, next to
                # rangestat.  Assigning it on the index itself only creates
                # an unused Python attribute, so the swept value would never
                # reach training.
                index.sq.rangestat_arg = val

                index.train(xt)
                index.add(xb)

                # Only the search call is timed.
                t0 = time.time()
                D, I = index.search(xq, 100)
                t1 = time.time()

                # Build the whole output line, then print it once.  Fields
                # are separated by single spaces, matching what a chain of
                # comma-terminated print statements emits.
                fields = ["%-15s %5g " % (rsname, val)]
                for rank in 1, 10, 100:
                    # Entries in the first `rank` result columns that equal
                    # the first ground-truth column, per query.
                    n_ok = (I[:, :rank] == gt[:, :1]).sum()
                    fields.append("%.4f" % (n_ok / float(nq)))
                fields.append(" %.3f s" % (t1 - t0))
                print(" ".join(fields))
|