#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys
|
|
import numpy as np
|
|
import faiss
|
|
from faiss.contrib.ondisk import merge_ondisk
|
|
|
|
#################################################################
# Small I/O functions
#################################################################
|
|
|
|
|
|
def ivecs_read(fname):
    """Read an ``.ivecs`` file into a 2-D int32 array.

    The file is read as a flat sequence of int32 values organised in
    fixed-size records: one leading value holding the dimension ``d``
    followed by ``d`` components. The dimension is taken from the first
    record only, and the leading column is stripped from the result.

    Args:
        fname: file to read (anything ``np.fromfile`` accepts).

    Returns:
        A newly allocated int32 array of shape ``(n, d)``. An empty
        ``(0, 0)`` array is returned when the file contains no data.

    Raises:
        ValueError: if the leading dimension is negative or the number of
            values in the file is not a multiple of ``d + 1``.
    """
    raw = np.fromfile(fname, dtype='int32')
    if raw.size == 0:
        # No record at all, hence no header to read the dimension from.
        return np.empty((0, 0), dtype='int32')

    dim = int(raw[0])
    if dim < 0 or raw.size % (dim + 1) != 0:
        raise ValueError(
            "%s: %d int32 values cannot be split into records of "
            "dimension %d" % (fname, raw.size, dim)
        )

    # Drop the per-record dimension column; copy() returns an independent,
    # contiguous array instead of a strided view into the raw buffer.
    return raw.reshape(-1, dim + 1)[:, 1:].copy()
|
|
|
|
|
|
def fvecs_read(fname):
    """Read an ``.fvecs`` file into a 2-D float32 array.

    The file is parsed by ``ivecs_read`` and the resulting int32 data is
    reinterpreted as float32 through ``view`` (same bytes, no numeric
    conversion).
    """
    as_int32 = ivecs_read(fname)
    return as_int32.view('float32')
|
|
|
|
|
|
#################################################################
# Main program
#################################################################
|
|
|
|
# The script is driven by one positional argument, the stage number:
# 0 trains the index, 1-4 each add a quarter of the database, 5 merges
# the blocks on disk and 6 runs a search.
if len(sys.argv) < 2:
    # Fail with a usage line rather than an IndexError traceback.
    sys.exit(
        "usage: %s STAGE  (0=train, 1-4=add block, 5=merge, 6=search)"
        % sys.argv[0]
    )
stage = int(sys.argv[1])

# Prefix for every intermediate file; the stages concatenate file names
# directly onto it, so the trailing slash is required.
tmpdir = '/tmp/'
|
|
|
|
if stage == 0:
    # Stage 0: train an empty index on the learning set and save it so
    # the later stages can start from it.
    train_vectors = fvecs_read("sift1M/sift_learn.fvecs")
    dim = train_vectors.shape[1]
    index = faiss.index_factory(dim, "IVF4096,Flat")

    print("training index")
    index.train(train_vectors)

    trained_path = tmpdir + "trained.index"
    print("write " + trained_path)
    faiss.write_index(index, trained_path)
|
|
|
|
|
|
if 1 <= stage <= 4:
    # add 1/4 of the database to 4 independent indexes
    bno = stage - 1
    xb = fvecs_read("sift1M/sift_base.fvecs")

    # Shard boundaries, computed with integer floor division so no float
    # round-trip is involved.
    nb = xb.shape[0]
    i0, i1 = bno * nb // 4, (bno + 1) * nb // 4

    # Every shard starts from the same trained (still empty) index.
    index = faiss.read_index(tmpdir + "trained.index")
    print("adding vectors %d:%d" % (i0, i1))

    # The ids are the positions in the full database. Faiss ids are 64-bit,
    # whereas np.arange's default integer type depends on the platform, so
    # request int64 explicitly.
    index.add_with_ids(xb[i0:i1], np.arange(i0, i1, dtype='int64'))

    print("write " + tmpdir + "block_%d.index" % bno)
    faiss.write_index(index, tmpdir + "block_%d.index" % bno)
|
|
|
|
if stage == 5:
    # Stage 5: build the output index from the trained index and the four
    # block files written by stages 1-4.
    print('loading trained index')
    index = faiss.read_index(tmpdir + "trained.index")

    block_fnames = []
    for shard_no in range(4):
        block_fnames.append(tmpdir + "block_%d.index" % shard_no)

    # NOTE(review): presumably this moves the blocks' inverted lists into
    # the given .ivfdata file and attaches them to `index` -- confirm
    # against faiss.contrib.ondisk.
    ivfdata_path = tmpdir + "merged_index.ivfdata"
    merge_ondisk(index, block_fnames, ivfdata_path)

    populated_path = tmpdir + "populated.index"
    print("write " + populated_path)
    faiss.write_index(index, populated_path)
|
|
|
|
|
|
if stage == 6:
    # Stage 6: load the populated index written by stage 5 and search it.
    populated_path = tmpdir + "populated.index"
    print("read " + populated_path)
    index = faiss.read_index(populated_path)
    index.nprobe = 16

    # Query vectors and their ground-truth neighbors.
    xq = fvecs_read("sift1M/sift_query.fvecs")
    gt = ivecs_read("sift1M/sift_groundtruth.ivecs")

    # Ask for 5 results per query; only the first result column is scored.
    _distances, neighbors = index.search(xq, 5)

    # recall@1: fraction of queries whose first result equals the first
    # ground-truth entry.
    n_queries = xq.shape[0]
    top1_hits = (neighbors[:, :1] == gt[:, :1]).sum()
    recall_at_1 = top1_hits / float(n_queries)
    print("recall@1: %.3f" % recall_at_1)
|