181 lines
4.2 KiB
Python
181 lines
4.2 KiB
Python
# Copyright (c) 2015-present, Facebook, Inc.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD+Patents license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
#!/usr/bin/env python2
|
|
|
|
import time
|
|
import sys
|
|
import numpy as np
|
|
import faiss
|
|
|
|
|
|
#################################################################
|
|
# Small I/O functions
|
|
#################################################################
|
|
|
|
|
|
def ivecs_read(fname):
    """Load an .ivecs file into a 2-D int32 array.

    The file is read as a flat stream of int32 values.  The first value
    is taken as the vector dimension d, the stream is reshaped into rows
    of d + 1 values, and the leading column of every row (the per-row
    dimension field) is dropped.

    Args:
        fname: path of the file to read.

    Returns:
        A freshly allocated int32 array of shape (n, d).  An empty file
        yields an array of shape (0, 0) instead of raising IndexError.

    Raises:
        ValueError: if the number of values in the file is not a
            multiple of d + 1 (raised by reshape).
    """
    a = np.fromfile(fname, dtype='int32')
    if a.size == 0:
        # no header to read the dimension from
        return a.reshape(0, 0)
    d = int(a[0])
    # copy() so the result is a compact array that no longer carries the
    # dimension column of the raw buffer
    return a.reshape(-1, d + 1)[:, 1:].copy()
|
|
|
|
|
|
def fvecs_read(fname):
    """Load an .fvecs file into a 2-D float32 array.

    The file is parsed by ivecs_read() and the resulting int32 array is
    reinterpreted, without conversion, as float32.

    Args:
        fname: path of the file to read.

    Returns:
        A float32 view of the array returned by ivecs_read(fname).
    """
    as_int32 = ivecs_read(fname)
    return as_int32.view('float32')
|
|
|
|
|
|
#################################################################
|
|
# Main program
|
|
#################################################################
|
|
|
|
# Single-argument print(...) behaves identically as a Python 2 print
# statement and as the Python 3 print function.
print("load data")

# xt is passed to index.train(), xb to index.add() / clus.train() and
# xq to index.search() further down.
xt = fvecs_read("sift1M/sift_learn.fvecs")
xb = fvecs_read("sift1M/sift_base.fvecs")
xq = fvecs_read("sift1M/sift_query.fvecs")

nq, d = xq.shape

print("load GT")

gt = ivecs_read("sift1M/sift_groundtruth.ivecs")

# names of the experiments to run, taken from the command line
todo = sys.argv[1:]

if not todo:
    # no experiment requested: run all of them
    todo = 'hnsw hnsw_sq ivf ivf_hnsw_quantizer kmeans kmeans_hnsw'.split()
|
|
|
|
|
|
def evaluate(index):
    """Time a 1-nearest-neighbor search and print speed and recall.

    Searches the module-level query set xq with index.search(xq, 1),
    then prints the average wall-clock time per query in milliseconds
    and the recall at rank 1 measured against the module-level ground
    truth gt.  nq (the number of queries) is also read from module
    scope.

    Args:
        index: object exposing search(queries, k) that returns a
            (distances, ids) pair, as used below.
    """
    # for timing with a single core
    # faiss.omp_set_num_threads(1)

    t0 = time.time()
    D, I = index.search(xq, 1)
    t1 = time.time()

    # fraction of queries whose top result equals the first ground-truth id
    recall_at_1 = (I == gt[:, :1]).sum() / float(nq)
    # single-argument print(...) works on both Python 2 and Python 3
    print("\t %7.3f ms per query, R@1 %.4f" % (
        (t1 - t0) * 1000.0 / nq, recall_at_1))
|
|
|
|
|
|
if 'hnsw' in todo:

    print("Testing HNSW Flat")

    index = faiss.IndexHNSWFlat(d, 32)

    # training is not needed

    # this is the default, higher is more accurate and slower to
    # construct
    index.hnsw.efConstruction = 40

    print("add")
    # to see progress
    index.verbose = True
    index.add(xb)

    print("search")
    for efSearch in 16, 32, 64, 128, 256:
        # no newline: evaluate() completes the line with its timing
        # report.  sys.stdout.write is used instead of a trailing-comma
        # print statement so the script parses on Python 2 and Python 3.
        sys.stdout.write("efSearch %d " % efSearch)
        index.hnsw.efSearch = efSearch
        evaluate(index)
|
|
|
|
if 'hnsw_sq' in todo:

    print("Testing HNSW with a scalar quantizer")
    # also set M so that the vectors and links both use 128 bytes per
    # entry (total 256 bytes)
    index = faiss.IndexHNSWSQ(d, faiss.ScalarQuantizer.QT_8bit, 16)

    print("training")
    # training for the scalar quantizer
    index.train(xt)

    # this is the default, higher is more accurate and slower to
    # construct
    index.hnsw.efConstruction = 40

    print("add")
    # to see progress
    index.verbose = True
    index.add(xb)

    print("search")
    for efSearch in 16, 32, 64, 128, 256:
        # no newline: evaluate() completes the line with its timing
        # report.  sys.stdout.write is used instead of a trailing-comma
        # print statement so the script parses on Python 2 and Python 3.
        sys.stdout.write("efSearch %d " % efSearch)
        index.hnsw.efSearch = efSearch
        evaluate(index)
|
|
|
|
if 'ivf' in todo:

    print("Testing IVF Flat (baseline)")
    # exact L2 index used as the coarse quantizer of the IVF index
    quantizer = faiss.IndexFlatL2(d)
    index = faiss.IndexIVFFlat(quantizer, d, 16384)
    index.cp.min_points_per_centroid = 5   # quiet warning

    # to see progress
    index.verbose = True

    print("training")
    index.train(xt)

    print("add")
    index.add(xb)

    print("search")
    for nprobe in 1, 4, 16, 64, 256:
        # no newline: evaluate() completes the line with its timing
        # report.  sys.stdout.write is used instead of a trailing-comma
        # print statement so the script parses on Python 2 and Python 3.
        sys.stdout.write("nprobe %d " % nprobe)
        index.nprobe = nprobe
        evaluate(index)
|
|
|
|
if 'ivf_hnsw_quantizer' in todo:

    print("Testing IVF Flat with HNSW quantizer")
    # same setup as the 'ivf' experiment, but the coarse quantizer is an
    # HNSW index instead of an exact flat index
    quantizer = faiss.IndexHNSWFlat(d, 32)
    index = faiss.IndexIVFFlat(quantizer, d, 16384)
    index.cp.min_points_per_centroid = 5   # quiet warning
    # NOTE(review): per the faiss IndexIVF documentation, mode 2 runs the
    # k-means training on a flat index and then adds the centroids to
    # the quantizer -- confirm against the installed faiss version.
    index.quantizer_trains_alone = 2

    # to see progress
    index.verbose = True

    print("training")
    index.train(xt)

    print("add")
    index.add(xb)

    print("search")
    quantizer.hnsw.efSearch = 64
    for nprobe in 1, 4, 16, 64, 256:
        # no newline: evaluate() completes the line with its timing
        # report.  sys.stdout.write is used instead of a trailing-comma
        # print statement so the script parses on Python 2 and Python 3.
        sys.stdout.write("nprobe %d " % nprobe)
        index.nprobe = nprobe
        evaluate(index)
|
|
|
|
# Bonus: 2 kmeans tests
|
|
|
|
if 'kmeans' in todo:
    # single-argument print(...) works on both Python 2 and Python 3
    print("Performing kmeans on sift1M database vectors (baseline)")
    clus = faiss.Clustering(d, 16384)
    clus.verbose = True
    clus.niter = 10
    # exact L2 index handed to the clustering object for training
    index = faiss.IndexFlatL2(d)
    clus.train(xb, index)
|
|
|
|
|
|
if 'kmeans_hnsw' in todo:
    # single-argument print(...) works on both Python 2 and Python 3
    print("Performing kmeans on sift1M using HNSW assignment")
    clus = faiss.Clustering(d, 16384)
    clus.verbose = True
    clus.niter = 10
    index = faiss.IndexHNSWFlat(d, 32)
    # increase the default efSearch, otherwise the number of empty
    # clusters is too high.
    index.hnsw.efSearch = 128
    clus.train(xb, index)
|