170 lines
4.2 KiB
Python
170 lines
4.2 KiB
Python
#!/usr/bin/env python2
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from __future__ import print_function
|
|
import os
|
|
import time
|
|
import numpy as np
|
|
|
|
try:
|
|
import matplotlib
|
|
matplotlib.use('Agg')
|
|
from matplotlib import pyplot
|
|
graphical_output = True
|
|
except ImportError:
|
|
graphical_output = False
|
|
|
|
import faiss
|
|
|
|
#################################################################
|
|
# Small I/O functions
|
|
#################################################################
|
|
|
|
def ivecs_read(fname):
    """Read an .ivecs file into a 2-D int32 array.

    The file is read as a flat stream of int32 values.  The first value is
    taken as the vector dimension ``d``; the stream is then reshaped into rows
    of ``d + 1`` values and the first column of every row is dropped.

    Args:
        fname: path of the file to read.

    Returns:
        A newly allocated int32 array with ``d`` columns, one row per vector.

    Raises:
        ValueError: if the file contains no data (the reshape also raises
            ValueError when the data does not split into rows of ``d + 1``).
    """
    a = np.fromfile(fname, dtype="int32")
    if a.size == 0:
        # a[0] below would otherwise fail with an opaque IndexError
        raise ValueError(
            "%s: empty file, cannot read the vector dimension" % fname)
    # plain Python int so that the reshape argument is not a numpy scalar
    d = int(a[0])
    # copy() detaches the result from the larger buffer read from disk
    return a.reshape(-1, d + 1)[:, 1:].copy()
|
|
|
|
def fvecs_read(fname):
    """Read an .fvecs file as a float32 array.

    The file is parsed with ``ivecs_read`` and the resulting array is
    reinterpreted (viewed, not converted) as float32.
    """
    raw = ivecs_read(fname)
    return raw.view('float32')
|
|
|
|
|
|
def plot_OperatingPoints(ops, nq, **kwargs):
    """Plot the optimal operating points of ``ops`` as a staircase curve.

    Args:
        ops: object exposing ``optimal_pts``, a container with ``size()`` and
            ``at(i)``; each point provides ``perf`` and ``t``.
        nq: divisor applied to each point's ``t`` (presumably the number of
            queries, so that y is a per-query time -- confirm with caller).
        **kwargs: forwarded unchanged to ``pyplot.plot``.
    """
    pts = ops.optimal_pts
    # every point but the last contributes two vertices, which gives the
    # curve its horizontal/vertical steps
    n_vertices = 2 * pts.size() - 1
    perfs = []
    scaled_times = []
    for j in range(n_vertices):
        perfs.append(pts.at(j // 2).perf)
        scaled_times.append(pts.at((j + 1) // 2).t / nq * 1000)
    pyplot.plot(perfs, scaled_times, **kwargs)
|
|
|
|
|
|
#################################################################
|
|
# prepare common data for all indexes
|
|
#################################################################
|
|
|
|
|
|
|
|
# reference time for the "[%.3f s]" progress messages
t0 = time.time()

print("load data")

# learn (training), base (database) and query vectors, read in that order
xt, xb, xq = map(fvecs_read, (
    "sift1M/sift_learn.fvecs",
    "sift1M/sift_base.fvecs",
    "sift1M/sift_query.fvecs",
))

# vector dimension
d = xt.shape[1]

print("load GT")

# ground truth, converted to int64
gt = ivecs_read("sift1M/sift_groundtruth.ivecs").astype('int64')
# number of ground-truth entries per row
k = gt.shape[1]

print("prepare criterion")

# criterion = 1-recall at 1
crit = faiss.OneRecallAtRCriterion(xq.shape[0], 1)
crit.set_groundtruth(None, gt)
crit.nnn = k
|
|
|
|
# indexes that are useful when there is no limitation on memory usage
unlimited_mem_keys = [
    "IMI2x10,Flat",
    "IMI2x11,Flat",
    "IVF4096,Flat",
    "IVF16384,Flat",
    "PCA64,IMI2x10,Flat",
]

# memory limited to 16 bytes / vector
keys_mem_16 = [
    "IMI2x10,PQ16",
    "IVF4096,PQ16",
    "IMI2x10,PQ8+8",
    "OPQ16_64,IMI2x10,PQ16",
]

# limited to 32 bytes / vector
keys_mem_32 = [
    "IMI2x10,PQ32",
    "IVF4096,PQ32",
    "IVF16384,PQ32",
    "IMI2x10,PQ16+16",
    "OPQ32,IVF4096,PQ32",
    "IVF4096,PQ16+16",
    "OPQ16,IMI2x10,PQ16+16",
]

# indexes that can run on the GPU
keys_gpu = [
    "PCA64,IVF4096,Flat",
    "PCA64,Flat",
    "Flat",
    "IVF4096,Flat",
    "IVF16384,Flat",
    "IVF4096,PQ32",
]

# list of index keys that is benchmarked, and whether to run on the GPU
keys_to_test = unlimited_mem_keys
use_gpu = False
|
|
|
|
|
|
if use_gpu:
    # If this fails, it means that the GPU version was not compiled (or its
    # extension module could not be loaded).  The check uses hasattr() rather
    # than `assert faiss.StandardGpuResources`: when the attribute is missing,
    # the attribute access itself raises AttributeError before the message
    # below can be reported, and asserts are skipped under `python -O`.
    if not hasattr(faiss, "StandardGpuResources"):
        raise RuntimeError(
            "Faiss was not compiled with GPU support, or loading _swigfaiss_gpu.so failed")
    # GPU resources object and device number used when moving indexes to GPU
    res = faiss.StandardGpuResources()
    dev_no = 0
|
|
|
|
# remember results from other index types
op_per_key = []

# keep track of optimal operating points seen so far
op = faiss.OperatingPoints()

for index_key in keys_to_test:

    print("============ key", index_key)

    # make the index described by the key
    index = faiss.index_factory(d, index_key)

    if use_gpu:
        # transfer to GPU (may be partial)
        index = faiss.index_cpu_to_gpu(res, dev_no, index)
        params = faiss.GpuParameterSpace()
    else:
        params = faiss.ParameterSpace()

    params.initialize(index)

    print("[%.3f s] train & add" % (time.time() - t0))

    index.train(xt)
    index.add(xb)

    print("[%.3f s] explore op points" % (time.time() - t0))

    # find operating points for this index
    opi = params.explore(index, xq, crit)

    print("[%.3f s] result operating points:" % (time.time() - t0))
    opi.display()

    # update best operating points so far
    op.merge_with(opi, index_key + " ")

    op_per_key.append((index_key, opi))

    if graphical_output:
        # graphical output (to tmp/ subdirectory); the figure is redrawn with
        # all keys processed so far and the same file is overwritten each time

        # savefig() does not create a missing output directory
        if not os.path.isdir('tmp'):
            os.makedirs('tmp')

        fig = pyplot.figure(figsize=(12, 9))
        pyplot.xlabel("1-recall at 1")
        pyplot.ylabel("search time (ms/query, %d threads)" % faiss.omp_get_max_threads())
        pyplot.gca().set_yscale('log')
        pyplot.grid()
        for i2, opi2 in op_per_key:
            plot_OperatingPoints(opi2, crit.nq, label=i2, marker='o')
        # plot_OperatingPoints(op, crit.nq, label = 'best', marker = 'o', color = 'r')
        pyplot.legend(loc=2)
        fig.savefig('tmp/demo_auto_tune.png')
        # a fresh figure is created for every key; release it once saved so
        # that figures do not accumulate over the run
        pyplot.close(fig)


print("[%.3f s] final result:" % (time.time() - t0))

op.display()
|