faiss/benchs/bench_all_ivf/bench_all_ivf.py

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import sys
import time
import numpy as np
import faiss
import argparse
import datasets
from datasets import sanitize
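
# datasets.py (in this directory) provides the dataset loading and
# ground-truth helpers; sanitize() converts arrays to the contiguous
# float32 layout that faiss expects.
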
######################################################
# Command-line parsing
######################################################
parser = argparse.ArgumentParser()


def aa(*args, **kwargs):
    group.add_argument(*args, **kwargs)
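
# aa() adds an option to whichever argument group is currently bound to the
# `group` global; the groups below only change how --help output is organized.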
group = parser.add_argument_group('dataset options')
aa('--db', default='deep1M', help='dataset')
aa('--compute_gt', default=False, action='store_true',
   help='compute and store the groundtruth')
group = parser.add_argument_group('index construction')
aa('--indexkey', default='HNSW32', help='index_factory type')
aa('--efConstruction', default=200, type=int,
   help='HNSW construction factor')
aa('--M0', default=-1, type=int, help='size of base level')
aa('--maxtrain', default=256 * 256, type=int,
   help='maximum number of training points (0 to set automatically)')
aa('--indexfile', default='', help='file to read or write index from')
aa('--add_bs', default=-1, type=int,
   help='add elements to the index in batches of this size')
aa('--no_precomputed_tables', action='store_true', default=False,
   help='disable precomputed tables (uses less memory)')
aa('--clustering_niter', default=-1, type=int,
   help='number of clustering iterations (-1 = leave default)')
aa('--train_on_gpu', default=False, action='store_true',
   help='do training on GPU')
aa('--get_centroids_from', default='',
   help='get the centroids from this index (to speed up training)')
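
# --indexkey accepts any faiss index_factory string; note that despite the
# HNSW32 default, the script asserts an IVF index below, so IVF-style keys
# are expected in practice, e.g. "IVF4096,PQ64", "OPQ64,IVF4096,PQ64" or
# "IMI2x10,PQ64".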
group = parser.add_argument_group('searching')
aa('--k', default=100, type=int, help='nb of nearest neighbors')
aa('--searchthreads', default=-1, type=int,
   help='nb of threads to use at search time')
aa('--searchparams', nargs='+', default=['autotune'],
   help="search parameters to use (can be autotune or a list of params)")
aa('--n_autotune', default=500, type=int,
   help="max nb of autotune experiments")
aa('--autotune_max', default=[], nargs='*',
   help='set max value for autotune variables format "var:val" (exclusive)')
aa('--autotune_range', default=[], nargs='*',
   help='set complete autotune range, format "var:val1,val2,..."')
aa('--min_test_duration', default=0, type=float,
   help='run each test for at least this many seconds to avoid timing jitter')
args = parser.parse_args()
print("args:", args)
os.system('echo -n "nb processors "; '
          'cat /proc/cpuinfo | grep ^processor | wc -l; '
          'cat /proc/cpuinfo | grep ^"model name" | tail -1')
######################################################
# Load dataset
######################################################
xt, xb, xq, gt = datasets.load_data(
    dataset=args.db, compute_gt=args.compute_gt)

print("dataset sizes: train %s base %s query %s GT %s" % (
    xt.shape, xb.shape, xq.shape, gt.shape))
nq, d = xq.shape
nb, d = xb.shape
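
# gt holds one row of ground-truth neighbor ids per query; only column 0
# (the id of the true nearest neighbor) is used below, both by the recall
# computation in eval_setting and by the autotune criterion.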
######################################################
# Make index
######################################################
if args.indexfile and os.path.exists(args.indexfile):
    print("reading", args.indexfile)
    index = faiss.read_index(args.indexfile)

    if isinstance(index, faiss.IndexPreTransform):
        index_ivf = faiss.downcast_index(index.index)
    else:
        index_ivf = index
    assert isinstance(index_ivf, faiss.IndexIVF)
    vec_transform = lambda x: x
else:
    print("build index, key=", args.indexkey)
    index = faiss.index_factory(d, args.indexkey)

    if isinstance(index, faiss.IndexPreTransform):
        index_ivf = faiss.downcast_index(index.index)
        vec_transform = index.chain.at(0).apply_py
    else:
        index_ivf = index
        vec_transform = lambda x: x
    assert isinstance(index_ivf, faiss.IndexIVF)

    index_ivf.verbose = True
    index_ivf.quantizer.verbose = True
    index_ivf.cp.verbose = True
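
    # with --maxtrain 0, derive the training set size from the index type:
    # an IMI quantizer trains two codebooks of sqrt(nlist) centroids each,
    # so 256 * 2**(log2(nlist)/2) = 256 * sqrt(nlist) points (256 per
    # codebook centroid) are enough; plain IVF uses the common rule of
    # thumb of ~50 training points per centroid.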
    maxtrain = args.maxtrain
    if maxtrain == 0:
        if 'IMI' in args.indexkey:
            maxtrain = int(256 * 2 ** (np.log2(index_ivf.nlist) / 2))
        else:
            maxtrain = 50 * index_ivf.nlist
        print("setting maxtrain to %d" % maxtrain)
        args.maxtrain = maxtrain

    xt2 = sanitize(xt[:args.maxtrain])
    assert np.all(np.isfinite(xt2))

    print("train, size", xt2.shape)
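
    # train the coarse quantizer, either from scratch (optionally on a GPU
    # flat index) or by recycling centroids from a previously built index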
    if args.get_centroids_from == '':
        if args.clustering_niter >= 0:
            print(("setting nb of clustering iterations to %d" %
                   args.clustering_niter))
            index_ivf.cp.niter = args.clustering_niter

        if args.train_on_gpu:
            print("add a training index on GPU")
            train_index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(d))
            index_ivf.clustering_index = train_index
    else:
        print("Getting centroids from", args.get_centroids_from)
        src_index = faiss.read_index(args.get_centroids_from)
        src_quant = faiss.downcast_index(src_index.quantizer)
        centroids = faiss.vector_to_array(src_quant.xb)
        centroids = centroids.reshape(-1, d)
        print("  centroid table shape", centroids.shape)

        if isinstance(index, faiss.IndexPreTransform):
            print("  training vector transform")
            assert index.chain.size() == 1
            vt = index.chain.at(0)
            vt.train(xt2)
            print("  transform centroids")
            centroids = vt.apply_py(centroids)

        print("  add centroids to quantizer")
        index_ivf.quantizer.add(centroids)
        del src_index

    t0 = time.time()
    index.train(xt2)
    print("  train in %.3f s" % (time.time() - t0))
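
    # populate the index; with --add_bs the vectors are added in fixed-size
    # batches, which bounds the memory used by each sanitize()/add() call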
print("adding")
2018-12-20 21:43:36 +08:00
t0 = time.time()
if args.add_bs == -1:
index.add(sanitize(xb))
else:
for i0 in range(0, nb, args.add_bs):
i1 = min(nb, i0 + args.add_bs)
print(" adding %d:%d / %d" % (i0, i1, nb))
2018-12-20 21:43:36 +08:00
index.add(sanitize(xb[i0:i1]))
print(" add in %.3f s" % (time.time() - t0))
2018-12-20 21:43:36 +08:00
if args.indexfile:
print("storing", args.indexfile)
2018-12-20 21:43:36 +08:00
faiss.write_index(index, args.indexfile)
if args.no_precomputed_tables:
    if isinstance(index_ivf, faiss.IndexIVFPQ):
        print("disabling precomputed table")
        index_ivf.use_precomputed_table = -1
        index_ivf.precomputed_table.clear()

if args.indexfile:
    print("index size on disk: ", os.stat(args.indexfile).st_size)

print("current RSS:", faiss.get_mem_usage_kb() * 1024)

precomputed_table_size = 0
if hasattr(index_ivf, 'precomputed_table'):
    precomputed_table_size = index_ivf.precomputed_table.size() * 4

print("precomputed tables size:", precomputed_table_size)
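# (precomputed_table stores 4-byte floats, hence the "* 4" above to convert
# the entry count to bytes)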
#############################################################
# Index is ready
#############################################################
xq = sanitize(xq)
if args.searchthreads != -1:
    print("Setting nb of threads to", args.searchthreads)
    faiss.omp_set_num_threads(args.searchthreads)
ps = faiss.ParameterSpace()
ps.initialize(index)
parametersets = args.searchparams
header = '%-40s     R@1    R@10   R@100   time(ms/q)   nb distances  #runs' % "parameters"
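
# eval_setting repeats the search for at least min_time seconds to smooth out
# timing jitter, then prints 1-recall at ranks 1/10/100 (the fraction of
# queries whose true nearest neighbor appears in the top-k results), the time
# per query, and the number of distance computations per run taken from the
# global IVF search stats.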
def eval_setting(index, xq, gt, min_time):
    nq = xq.shape[0]
    ivf_stats = faiss.cvar.indexIVF_stats
    ivf_stats.reset()
    nrun = 0
    t0 = time.time()
    while True:
        D, I = index.search(xq, 100)
        nrun += 1
        t1 = time.time()
        if t1 - t0 > min_time:
            break
    ms_per_query = ((t1 - t0) * 1000.0 / nq / nrun)
    for rank in 1, 10, 100:
        n_ok = (I[:, :rank] == gt[:, :1]).sum()
        print("%.4f" % (n_ok / float(nq)), end=' ')
    print("   %8.3f  " % ms_per_query, end=' ')
    print("%12d  " % (ivf_stats.ndis / nrun), end=' ')
    print(nrun)
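
# with --searchparams autotune, ParameterSpace explores combinations of the
# index's tunable parameters; --autotune_max / --autotune_range restrict the
# explored values, e.g. (hypothetical parameter names and values)
#   --autotune_max nprobe:512 --autotune_range ht:64,128,256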
if parametersets == ['autotune']:

    ps.n_experiments = args.n_autotune
    ps.min_test_duration = args.min_test_duration

    for kv in args.autotune_max:
        k, vmax = kv.split(':')
        vmax = float(vmax)
        print("limiting %s to %g" % (k, vmax))
        pr = ps.add_range(k)
        values = faiss.vector_to_array(pr.values)
        values = np.array([v for v in values if v < vmax])
        faiss.copy_array_to_vector(values, pr.values)

    for kv in args.autotune_range:
        k, vals = kv.split(':')
        vals = np.fromstring(vals, sep=',')
        print("setting %s to %s" % (k, vals))
        pr = ps.add_range(k)
        faiss.copy_array_to_vector(vals, pr.values)
    # setup the Criterion object: optimize for 1-R@1
    crit = faiss.OneRecallAtRCriterion(nq, 1)
    # by default, the criterion will request only 1 NN
    crit.nnn = 100
    crit.set_groundtruth(None, gt.astype('int64'))

    # then we let Faiss find the optimal parameters by itself
    print("exploring operating points")
    ps.display()

    t0 = time.time()
    op = ps.explore(index, xq, crit)
    print("Done in %.3f s, available OPs:" % (time.time() - t0))
    op.display()

    print(header)
    opv = op.optimal_pts
    for i in range(opv.size()):
        opt = opv.at(i)
        ps.set_index_parameters(index, opt.key)
        print("%-40s " % opt.key, end=' ')
        sys.stdout.flush()
        eval_setting(index, xq, gt, args.min_test_duration)
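# (without autotune, each --searchparams entry is an explicit parameter string
# understood by ParameterSpace.set_index_parameters, e.g. "nprobe=64" with a
# hypothetical value)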
else:
    print(header)

    for param in parametersets:
        print("%-40s " % param, end=' ')
        sys.stdout.flush()
        ps.set_index_parameters(index, param)
        eval_setting(index, xq, gt, args.min_test_duration)