# faiss/benchs/bench_all_ivf/bench_all_ivf.py
#
# 464 lines
# 15 KiB
# Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import sys
import time
import pdb
import numpy as np
import faiss
import argparse
import datasets
from datasets import sanitize
######################################################
# Command-line parsing
######################################################
parser = argparse.ArgumentParser()


def aa(*args, **kwargs):
    """Add one option to the argument group currently bound to ``group``.

    All positional and keyword arguments are forwarded unchanged to
    ``group.add_argument``. The global ``group`` is looked up at call
    time, so every ``group = parser.add_argument_group(...)`` assignment
    below redirects the ``aa`` calls that follow it to the new group.
    """
    group.add_argument(*args, **kwargs)


group = parser.add_argument_group('dataset options')
aa('--db', default='deep1M', help='dataset')
aa('--compute_gt', default=False, action='store_true',
   help='compute and store the groundtruth')
aa('--force_IP', default=False, action="store_true",
   help='force IP search instead of L2')

# group title typo fixed: it is shown verbatim in the --help output
group = parser.add_argument_group('index construction')
aa('--indexkey', default='HNSW32', help='index_factory type')
aa('--maxtrain', default=256 * 256, type=int,
   help='maximum number of training points (0 to set automatically)')
aa('--indexfile', default='', help='file to read or write index from')
aa('--add_bs', default=-1, type=int,
   help='add elements index by batches of this size')

group = parser.add_argument_group('IVF options')
aa('--by_residual', default=-1, type=int,
   help="set if index should use residuals (default=unchanged)")
aa('--no_precomputed_tables', action='store_true', default=False,
   help='disable precomputed tables (uses less memory)')
aa('--get_centroids_from', default='',
   help='get the centroids from this index (to speed up training)')
aa('--clustering_niter', default=-1, type=int,
   help='number of clustering iterations (-1 = leave default)')
aa('--train_on_gpu', default=False, action='store_true',
   help='do training on GPU')

group = parser.add_argument_group('index-specific options')
aa('--M0', default=-1, type=int, help='size of base level for HNSW')
aa('--RQ_train_default', default=False, action="store_true",
   help='disable progressive dim training for RQ')
aa('--RQ_beam_size', default=-1, type=int,
   help='set beam size at add time')
aa('--LSQ_encode_ils_iters', default=-1, type=int,
   help='ILS iterations for LSQ')
aa('--RQ_use_beam_LUT', default=-1, type=int,
   help='use beam LUT at add time')

group = parser.add_argument_group('searching')
aa('--k', default=100, type=int, help='nb of nearest neighbors')
aa('--inter', default=False, action='store_true',
   help='use intersection measure instead of 1-recall as metric')
aa('--searchthreads', default=-1, type=int,
   help='nb of threads to use at search time')
aa('--searchparams', nargs='+', default=['autotune'],
   help="search parameters to use (can be autotune or a list of params)")
aa('--n_autotune', default=500, type=int,
   help="max nb of autotune experiments")
aa('--autotune_max', default=[], nargs='*',
   help='set max value for autotune variables format "var:val" (exclusive)')
aa('--autotune_range', default=[], nargs='*',
   help='set complete autotune range, format "var:val1,val2,..."')
aa('--min_test_duration', default=3.0, type=float,
   help='run test at least for so long to avoid jitter')

# Parse sys.argv and echo the resulting namespace so that the benchmark
# log records the exact configuration that was used.
args = parser.parse_args()
print("args:", args)
# Log the processor count and the CPU model name by shelling out; the
# three commands read /proc/cpuinfo and are chained with "; " into a
# single shell command line.
os.system('; '.join((
    'echo -n "nb processors "',
    'cat /proc/cpuinfo | grep ^processor | wc -l',
    'cat /proc/cpuinfo | grep ^"model name" | tail -1',
)))
######################################################
# Load dataset
######################################################
# Load the dataset selected with --db; --compute_gt is forwarded to the
# loader unchanged.
ds = datasets.load_dataset(
    dataset=args.db, compute_gt=args.compute_gt)

if args.force_IP:
    # --force_IP overrides whatever metric the dataset declares
    ds.metric = "IP"

print(ds)

# nq and d come from the dataset object. nb must be read from ds.nb (the
# database size also reported by the batched-add progress line); it was
# previously taken from ds.nq by mistake.
nq, d = ds.nq, ds.d
nb = ds.nb
######################################################
# Make index
######################################################
def unwind_index_ivf(index):
    """Locate the IndexIVF wrapped inside ``index``.

    Returns a pair ``(index_ivf, vt)``:

    * for an ``IndexPreTransform`` (which must hold exactly one transform
      in its chain), the wrapped index is unwound recursively and the
      pair is the IVF index found inside plus that single transform; the
      wrapped index must not contribute a transform of its own;
    * for an ``IndexRefine`` (when this faiss build exposes that class),
      the result of unwinding its ``base_index``;
    * for an ``IndexIVF``, ``(index, None)``;
    * for anything else, ``(None, None)``.
    """
    if isinstance(index, faiss.IndexPreTransform):
        assert index.chain.size() == 1
        transform = index.chain.at(0)
        wrapped = faiss.downcast_index(index.index)
        inner_ivf, inner_transform = unwind_index_ivf(wrapped)
        assert inner_transform is None
        return inner_ivf, transform

    # IndexRefine may be missing from the faiss module, hence the lookup
    # with a default instead of a direct attribute access
    refine_cls = getattr(faiss, "IndexRefine", None)
    if refine_cls is not None and isinstance(index, refine_cls):
        return unwind_index_ivf(faiss.downcast_index(index.base_index))

    if not isinstance(index, faiss.IndexIVF):
        return None, None
    return index, None
def apply_AQ_options(index, args):
    """Copy additive-quantizer command-line options onto ``index``.

    ``args`` is the parsed command line. Each option is applied only when
    it differs from its "unset" value (``False`` for RQ_train_default,
    ``-1`` for the integer options):

    * RQ_train_default     -> index.rq.train_type
    * RQ_beam_size         -> index.rq.max_beam_size
    * LSQ_encode_ils_iters -> index.lsq.encode_ils_iters
    * RQ_use_beam_LUT      -> index.rq.use_beam_LUT

    No check on the type of ``index`` is performed. Every assignment is
    preceded by a plain read of the same attribute, so a missing field
    raises AttributeError instead of being set (presumably to avoid
    silently creating a new Python attribute on the wrapper object).
    Returns None.
    """
    if args.RQ_train_default:
        print("set default training for RQ")
        index.rq.train_type  # fails here if the field does not exist
        index.rq.train_type = faiss.ResidualQuantizer.Train_default

    beam_size = args.RQ_beam_size
    if beam_size != -1:
        print("set RQ beam size to", beam_size)
        index.rq.max_beam_size  # fails here if the field does not exist
        index.rq.max_beam_size = beam_size

    ils_iters = args.LSQ_encode_ils_iters
    if ils_iters != -1:
        print("set LSQ ils iterations to", ils_iters)
        index.lsq.encode_ils_iters  # fails here if the field does not exist
        index.lsq.encode_ils_iters = ils_iters

    beam_lut = args.RQ_use_beam_LUT
    if beam_lut != -1:
        print("set RQ beam LUT to", beam_lut)
        index.rq.use_beam_LUT  # fails here if the field does not exist
        index.rq.use_beam_LUT = beam_lut
# Obtain the index: reuse the file given by --indexfile when it already
# exists, otherwise build a new index from --indexkey, train it, add the
# database vectors and (if --indexfile is set) store the result.
if args.indexfile and os.path.exists(args.indexfile):
    print("reading", args.indexfile)
    index = faiss.read_index(args.indexfile)
    # index_ivf is None when no IndexIVF is found inside the loaded index
    index_ivf, vec_transform = unwind_index_ivf(index)
    if vec_transform is None:
        # no pre-transform: use an identity function as placeholder
        vec_transform = lambda x: x
else:
    print("build index, key=", args.indexkey)
    # metric follows ds.metric: L2 when it is "L2", inner product otherwise
    index = faiss.index_factory(
        d, args.indexkey, faiss.METRIC_L2 if ds.metric == "L2" else
        faiss.METRIC_INNER_PRODUCT
    )
    index_ivf, vec_transform = unwind_index_ivf(index)
    if vec_transform is None:
        # no pre-transform: use an identity function as placeholder
        vec_transform = lambda x: x
    else:
        # downcast so the isinstance checks on vec_transform below can
        # see the concrete transform class
        vec_transform = faiss.downcast_VectorTransform(vec_transform)
    # --by_residual: -1 leaves the index untouched, 1 means True and any
    # other value means False
    if args.by_residual != -1:
        by_residual = args.by_residual == 1
        print("setting by_residual = ", by_residual)
        index_ivf.by_residual # check if field exists
        index_ivf.by_residual = by_residual
    if index_ivf:
        print("Update add-time parameters")
        # adjust default parameters used at add time for quantizers
        # because otherwise the assignment is inaccurate
        quantizer = faiss.downcast_index(index_ivf.quantizer)
        if isinstance(quantizer, faiss.IndexRefine):
            # each print shows "old value -> new value" on a single line
            print(" update quantizer k_factor=", quantizer.k_factor, end=" -> ")
            quantizer.k_factor = 32 if index_ivf.nlist < 1e6 else 64
            print(quantizer.k_factor)
            base_index = faiss.downcast_index(quantizer.base_index)
            if isinstance(base_index, faiss.IndexIVF):
                print(" update quantizer nprobe=", base_index.nprobe, end=" -> ")
                # nprobe grows with the number of lists of the base index
                base_index.nprobe = (
                    16 if base_index.nlist < 1e5 else
                    32 if base_index.nlist < 4e6 else
                    64)
                print(base_index.nprobe)
        elif isinstance(quantizer, faiss.IndexHNSW):
            print(" update quantizer efSearch=", quantizer.hnsw.efSearch, end=" -> ")
            quantizer.hnsw.efSearch = 40 if index_ivf.nlist < 4e6 else 64
            print(quantizer.hnsw.efSearch)
    # the options go to the IVF index when there is one, otherwise to the
    # top-level index
    apply_AQ_options(index_ivf or index, args)
    # turn on verbose output for training
    if index_ivf:
        index_ivf.verbose = True
        index_ivf.quantizer.verbose = True
        index_ivf.cp.verbose = True
    else:
        index.verbose = True
    # --maxtrain 0 requests an automatic training-set size
    maxtrain = args.maxtrain
    if maxtrain == 0:
        if 'IMI' in args.indexkey:
            maxtrain = int(256 * 2 ** (np.log2(index_ivf.nlist) / 2))
        elif index_ivf:
            maxtrain = 50 * index_ivf.nlist
        else:
            # just guess...
            maxtrain = 256 * 100
        # never go below 256 * 100 training points in automatic mode
        maxtrain = max(maxtrain, 256 * 100)
        print("setting maxtrain to %d" % maxtrain)
    try:
        xt2 = ds.get_train(maxtrain=maxtrain)
    except NotImplementedError:
        # fall back to the first maxtrain database vectors
        print("No training set: training on database")
        xt2 = ds.get_database()[:maxtrain]
    print("train, size", xt2.shape)
    # refuse to train on NaN or infinite values
    assert np.all(np.isfinite(xt2))
    if (isinstance(vec_transform, faiss.OPQMatrix) and
            isinstance(index_ivf, faiss.IndexIVFPQFastScan)):
        print(" Forcing OPQ training PQ to PQ4")
        # give the OPQ transform a ProductQuantizer with the same d, M and
        # nbits as the PQ of the fast-scan index
        ref_pq = index_ivf.pq
        training_pq = faiss.ProductQuantizer(
            ref_pq.d, ref_pq.M, ref_pq.nbits
        )
        # bare read before the assignment; presumably the same "check if
        # field exists" idiom as for by_residual above
        vec_transform.pq
        vec_transform.pq = training_pq
    if args.get_centroids_from == '':
        # centroids will be trained: apply the clustering options
        if args.clustering_niter >= 0:
            print(("setting nb of clustering iterations to %d" %
                   args.clustering_niter))
            index_ivf.cp.niter = args.clustering_niter
        if args.train_on_gpu:
            print("add a training index on GPU")
            train_index = faiss.index_cpu_to_all_gpus(
                faiss.IndexFlatL2(index_ivf.d))
            index_ivf.clustering_index = train_index
    else:
        # reuse the centroids stored in the quantizer of another index
        print("Getting centroids from", args.get_centroids_from)
        src_index = faiss.read_index(args.get_centroids_from)
        src_quant = faiss.downcast_index(src_index.quantizer)
        centroids = faiss.vector_to_array(src_quant.xb)
        # one centroid per row, d columns
        centroids = centroids.reshape(-1, d)
        print(" centroid table shape", centroids.shape)
        if isinstance(vec_transform, faiss.VectorTransform):
            # the transform is trained on the training set and then
            # applied to the centroids before they reach the quantizer
            print(" training vector transform")
            vec_transform.train(xt2)
            print(" transform centroids")
            centroids = vec_transform.apply_py(centroids)
        if not index_ivf.quantizer.is_trained:
            print(" training quantizer")
            index_ivf.quantizer.train(centroids)
        print(" add centroids to quantizer")
        index_ivf.quantizer.add(centroids)
        del src_index
    # train the whole index and report the wall-clock time
    t0 = time.time()
    index.train(xt2)
    print(" train in %.3f s" % (time.time() - t0))
    print("adding")
    t0 = time.time()
    if args.add_bs == -1:
        # add the whole database in one call
        index.add(sanitize(ds.get_database()))
    else:
        # add in batches of --add_bs vectors, logging the progress, the
        # elapsed time and the memory usage before each batch
        i0 = 0
        for xblock in ds.database_iterator(bs=args.add_bs):
            i1 = i0 + len(xblock)
            print(" adding %d:%d / %d [%.3f s, RSS %d kiB] " % (
                i0, i1, ds.nb, time.time() - t0,
                faiss.get_mem_usage_kb()))
            index.add(xblock)
            i0 = i1
    print(" add in %.3f s" % (time.time() - t0))
    if args.indexfile:
        print("storing", args.indexfile)
        faiss.write_index(index, args.indexfile)
# --no_precomputed_tables only has an effect on IndexIVFPQ: the table
# usage flag is set to -1 and the table itself is emptied.
if args.no_precomputed_tables and isinstance(index_ivf, faiss.IndexIVFPQ):
    print("disabling precomputed table")
    index_ivf.use_precomputed_table = -1
    index_ivf.precomputed_table.clear()

# Report storage figures for the index that will be benchmarked.
if args.indexfile:
    print("index size on disk: ", os.stat(args.indexfile).st_size)

if hasattr(index, "code_size"):
    print("vector code_size", index.code_size)

if hasattr(index_ivf, "code_size"):
    print("vector code_size (IVF)", index_ivf.code_size)

# get_mem_usage_kb() is scaled by 1024 before printing
print("current RSS:", faiss.get_mem_usage_kb() * 1024)

# number of table entries times 4 (presumably 4 bytes per entry), or 0
# when the index has no precomputed_table attribute
precomputed_table_size = (
    index_ivf.precomputed_table.size() * 4
    if hasattr(index_ivf, 'precomputed_table')
    else 0
)

print("precomputed tables size:", precomputed_table_size)
#############################################################
# Index is ready
#############################################################
# Queries and groundtruth used by every evaluation below.
xq = sanitize(ds.get_queries())
gt = ds.get_groundtruth(k=args.k)

# The groundtruth must provide exactly k neighbors per query. The check
# used to call pdb.set_trace() as its message, which blocks unattended
# runs in a debugger; it now fails with an explicit message instead.
assert gt.shape[1] == args.k, (
    "groundtruth has %d columns, expected k=%d" % (gt.shape[1], args.k))

# --searchthreads -1 keeps the current OpenMP thread count
if args.searchthreads != -1:
    print("Setting nb of threads to", args.searchthreads)
    faiss.omp_set_num_threads(args.searchthreads)
else:
    print("nb search threads: ", faiss.omp_get_max_threads())
# Parameter space used both to autotune and to apply explicit
# search-parameter strings to the index.
ps = faiss.ParameterSpace()
ps.initialize(index)

# either ['autotune'] or a list of explicit parameter strings
parametersets = args.searchparams

# Column titles of the result table; the accuracy columns depend on
# whether the intersection measure (--inter) or the recalls are reported.
header = (
    '%-40s inter@%3d time(ms/q) nb distances #runs' % (
        "parameters", args.k)
    if args.inter else
    '%-40s R@1 R@10 R@100 time(ms/q) nb distances #runs' % (
        "parameters",)
)
def compute_inter(a, b):
    """Return the intersection measure between two 2-D id arrays.

    ``a.shape`` gives the number of rows and the rank. For every row, the
    ids shared by ``a[i]`` and the first ``rank`` columns of ``b[i]`` are
    counted with ``np.intersect1d`` (so each distinct id counts once).
    The total count is divided by ``a.size``.
    """
    n_rows, width = a.shape
    shared = 0
    for row in range(n_rows):
        common = np.intersect1d(a[row, :width], b[row, :width])
        shared += common.size
    return shared / a.size
def eval_setting(index, xq, gt, k, inter, min_time):
    """Time ``index.search(xq, k)`` and print one line of the result table.

    The IVF statistics are reset, then the search is repeated (at least
    once) until more than ``min_time`` seconds have elapsed. The accuracy
    is computed on the results of the last run only:

    * ``inter`` true: intersection measure between the first ``k``
      columns of ``gt`` and of the result ids;
    * otherwise: for ranks 1, 10 and 100, the fraction of queries whose
      first groundtruth id appears among the first ``rank`` results.

    The line printed is: accuracy value(s), milliseconds per query
    (averaged over all runs), ``indexIVF_stats.ndis`` divided by the
    number of runs, and the number of runs. Returns None.
    """
    n_queries = xq.shape[0]
    stats = faiss.cvar.indexIVF_stats
    stats.reset()

    n_runs = 0
    start = time.time()
    while True:
        _, labels = index.search(xq, k)
        n_runs += 1
        elapsed = time.time() - start
        if elapsed > min_time:
            break
    ms_per_query = (elapsed * 1000.0 / n_queries / n_runs)

    if inter:
        inter_measure = compute_inter(gt[:, :k], labels[:, :k])
        print("%.4f" % inter_measure, end=' ')
    else:
        first_gt = gt[:, :1]
        for rank in (1, 10, 100):
            n_hits = (labels[:, :rank] == first_gt).sum()
            print("%.4f" % (n_hits / float(n_queries)), end=' ')

    print(" %9.5f " % ms_per_query, end=' ')
    print("%12d " % (stats.ndis / n_runs), end=' ')
    print(n_runs)
# Evaluation driver: either run the explicit parameter strings given with
# --searchparams, or let the ParameterSpace explore operating points and
# re-evaluate the optimal ones.
if parametersets != ['autotune']:
    print(header)
    for param in parametersets:
        print("%-40s " % param, end=' ')
        sys.stdout.flush()
        ps.set_index_parameters(index, param)
        eval_setting(
            index, xq, gt, args.k, args.inter, args.min_test_duration)
else:
    ps.n_experiments = args.n_autotune
    ps.min_test_duration = args.min_test_duration

    # --autotune_max "var:val": keep only the values strictly below val
    for kv in args.autotune_max:
        k, vmax = kv.split(':')
        vmax = float(vmax)
        print("limiting %s to %g" % (k, vmax))
        pr = ps.add_range(k)
        current = faiss.vector_to_array(pr.values)
        values = np.array([v for v in current if v < vmax])
        faiss.copy_array_to_vector(values, pr.values)

    # --autotune_range "var:val1,val2,...": replace the whole range
    for kv in args.autotune_range:
        k, vals = kv.split(':')
        vals = np.fromstring(vals, sep=',')
        print("setting %s to %s" % (k, vals))
        pr = ps.add_range(k)
        faiss.copy_array_to_vector(vals, pr.values)

    # set up the criterion the exploration optimizes for
    if args.inter:
        print("Optimize for intersection @ ", args.k)
        crit = faiss.IntersectionCriterion(nq, args.k)
    else:
        print("Optimize for 1-recall @ 1")
        crit = faiss.OneRecallAtRCriterion(nq, 1)

    # by default, the criterion will request only 1 NN
    crit.nnn = args.k
    crit.set_groundtruth(None, gt.astype('int64'))

    # let Faiss explore the parameter space on its own
    print("exploring operating points, %d threads" %
          faiss.omp_get_max_threads())
    ps.display()

    t0 = time.time()
    op = ps.explore(index, xq, crit)
    print("Done in %.3f s, available OPs:" % (time.time() - t0))
    op.display()

    print("Re-running evaluation on selected OPs")
    print(header)
    opv = op.optimal_pts
    n_opt = opv.size()
    # the key column is at least 40 characters wide
    longest_key = max(len(opv.at(i).key) for i in range(n_opt))
    maxw = max(longest_key, 40)
    for i in range(n_opt):
        opt = opv.at(i)
        ps.set_index_parameters(index, opt.key)
        print(opt.key.ljust(maxw), end=' ')
        sys.stdout.flush()
        eval_setting(
            index, xq, gt, args.k, args.inter, args.min_test_duration)