308 lines
9.1 KiB
Python
308 lines
9.1 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import time
|
|
import sys
|
|
import os
|
|
import argparse
|
|
|
|
import numpy as np
|
|
|
|
|
|
def eval_recalls(name, I, gt, times):
    """Print 1-recall at ranks 1/10/100/1000 together with timing statistics.

    For each reported rank r, the score is the number of queries whose first
    ground-truth id (``gt[:, 0]``) appears in the first r columns of ``I``,
    divided by the number of queries. Reporting stops at the first rank that
    exceeds the number of columns of ``I``.

    Args:
        name: label printed left-aligned in a 40-character column.
        I: 2-D array of result ids, one row per query.
        gt: 2-D array of ground-truth ids; ``len(gt)`` is the query count.
        times: per-run wall-clock times, summarized as mean and std.
    """
    n_cols = I.shape[1]
    parts = ["%-40s recall" % name]
    n_queries = len(gt)
    for rank in (1, 10, 100, 1000):
        if rank > n_cols:
            break
        hits = (I[:, :rank] == gt[:, :1]).sum()
        parts.append("@%d: %.4f " % (rank, hits / n_queries))
    parts.append("time: %.4f s (± %.4f)" % (np.mean(times), np.std(times)))
    print("".join(parts))
|
|
|
|
def eval_inters(name, I, gt, times):
    """Print the k-NN intersection measure at ranks 1/10/100/1000 and timing stats.

    For each reported rank r, the ids in row q of ``I[:, :r]`` and
    ``gt[:, :r]`` are intersected with ``np.intersect1d`` (so duplicate ids
    count once). The intersection sizes are summed over all queries and
    divided by ``nq * r``. Reporting stops at the first rank that exceeds
    the number of columns of ``I``.

    Args:
        name: label printed left-aligned in a 40-character column.
        I: 2-D array of result ids, one row per query.
        gt: 2-D array of ground-truth ids; ``len(gt)`` is the query count.
        times: per-run wall-clock times, summarized as mean and std.
    """
    n_cols = I.shape[1]
    parts = ["%-40s inter" % name]
    n_queries = len(gt)
    for rank in (1, 10, 100, 1000):
        if rank > n_cols:
            break
        common = sum(
            np.intersect1d(I[q, :rank], gt[q, :rank]).size
            for q in range(n_queries)
        )
        parts.append("@%d: %.4f " % (rank, common / (n_queries * rank)))
    parts.append("time: %.4f s (± %.4f)" % (np.mean(times), np.std(times)))
    print("".join(parts))
|
|
|
|
|
|
def main():
    """Entry point: parse options, prepare the dataset cache, run the benchmark.

    Options select a dataset (``--db``), the library to benchmark
    (``--lib faiss`` or ``--lib scann``) and the grid of ``--nprobe`` /
    ``--pre_reorder_k`` values to time. Dataset arrays and the metric name
    are cached under ``<base_dir>/<db>/``. The SCANN run, which is launched
    in a separate conda environment when ``--thenscann`` is given, can then
    read them without the project-local ``datasets`` module.
    """
    parser = argparse.ArgumentParser()

    def aa(*args, **kwargs):
        group.add_argument(*args, **kwargs)

    group = parser.add_argument_group('dataset options')
    aa('--db', default='deep1M', help='dataset')
    aa('--measure', default="1-recall",
       help="perf measure to use: 1-recall or inter")
    aa('--download', default=False, action="store_true")
    # choices makes a typo fail loudly instead of silently benchmarking nothing
    aa('--lib', default='faiss', choices=['faiss', 'scann'],
       help='library to use (faiss or scann)')
    aa('--thenscann', default=False, action="store_true")
    aa('--base_dir',
       default='/checkpoint/matthijs/faiss_improvements/cmp_ivf_scan_2')

    group = parser.add_argument_group('searching')
    aa('--k', default=10, type=int, help='nb of nearest neighbors')
    aa('--pre_reorder_k', default="0,10,100,1000", help='values for reorder_k')
    aa('--nprobe', default="1,2,5,10,20,50,100,200", help='values for nprobe')
    aa('--nrun', default=5, type=int, help='nb of runs to perform')
    args = parser.parse_args()

    print("args:", args)
    pre_reorder_k_tab = [int(x) for x in args.pre_reorder_k.split(',')]
    nprobe_tab = [int(x) for x in args.nprobe.split(',')]

    # Best-effort log of the CPU core count and model (reads /proc/cpuinfo).
    os.system('echo -n "nb processors "; '
              'cat /proc/cpuinfo | grep ^processor | wc -l; '
              'cat /proc/cpuinfo | grep ^"model name" | tail -1')

    cache_dir = args.base_dir + "/" + args.db + "/"
    k = args.k
    nrun = args.nrun

    xb, xq, gt, distance_measure = _load_cached_dataset(
        cache_dir, args.db, args.download)

    if args.lib == "faiss":
        import faiss

        name1_to_metric = {
            "IP": faiss.METRIC_INNER_PRODUCT,
            "L2": faiss.METRIC_L2
        }

        # reuse a previously built index if one is cached
        index_fname = cache_dir + "index.faiss"
        if not os.path.exists(index_fname):
            index = faiss_make_index(
                xb, name1_to_metric[distance_measure], index_fname)
        else:
            index = faiss.read_index(index_fname)

        faiss_eval_search(
            index, xq, xb, nprobe_tab, pre_reorder_k_tab, k, gt,
            nrun, args.measure
        )

    if args.lib == "scann":
        from scann.scann_ops.py import scann_ops_pybind

        name1_to_name2 = {
            "IP": "dot_product",
            "L2": "squared_l2"
        }

        # searcher without a reorder stage (reorder_k=0)
        scann_dir = cache_dir + "/scann1.1.1_serialized"
        if os.path.exists(scann_dir + "/scann_config.pb"):
            searcher = scann_ops_pybind.load_searcher(scann_dir)
        else:
            searcher = scann_make_index(
                xb, name1_to_name2[distance_measure], scann_dir, 0)

        # searcher built with reorder_k=100
        scann_dir = cache_dir + "/scann1.1.1_serialized_reorder"
        if os.path.exists(scann_dir + "/scann_config.pb"):
            searcher_reo = scann_ops_pybind.load_searcher(scann_dir)
        else:
            searcher_reo = scann_make_index(
                xb, name1_to_name2[distance_measure], scann_dir, 100)

        scann_eval_search(
            searcher, searcher_reo,
            xq, xb, nprobe_tab, pre_reorder_k_tab, k, gt,
            nrun, args.measure
        )

    if args.lib != "scann" and args.thenscann:
        _relaunch_with_scann()


def _load_cached_dataset(cache_dir, db, download):
    """Return ``(xb, xq, gt, metric_name)`` from cache_dir, filling the cache first if needed.

    When ``cache_dir + "xb.npy"`` is missing, the dataset is loaded through
    the project-local ``datasets.load_dataset``. The directory is then
    recreated from scratch, and the database / query / ground-truth arrays
    are written as .npy files together with a ``metric`` text file.
    """
    import shutil

    if not os.path.exists(cache_dir + "xb.npy"):
        # prepare cache
        from datasets import load_dataset
        ds = load_dataset(db, download=download)
        print(ds)

        # Store for SCANN. shutil is used instead of a shell command so that
        # paths containing spaces or shell metacharacters are handled safely.
        shutil.rmtree(cache_dir, ignore_errors=True)
        os.makedirs(cache_dir, exist_ok=True)
        tosave = dict(
            xb=ds.get_database(),
            xq=ds.get_queries(),
            gt=ds.get_groundtruth()
        )
        for name, v in tosave.items():
            fname = cache_dir + "/" + name + ".npy"
            print("save", fname)
            np.save(fname, v)

        with open(cache_dir + "metric", "w") as f:
            f.write(ds.metric)

    dataset = {}
    for kn in ("xb", "xq", "gt"):
        fname = cache_dir + "/" + kn + ".npy"
        print("load", fname)
        dataset[kn] = np.load(fname)

    # strip() tolerates a trailing newline in a hand-edited metric file
    with open(cache_dir + "metric") as f:
        distance_measure = f.read().strip()

    return dataset["xb"], dataset["xq"], dataset["gt"], distance_measure


def _relaunch_with_scann():
    """Re-run this script with ``--lib scann`` inside the ``scann_1.1.1`` conda env."""
    import shlex

    # Appending --lib scann overrides any earlier --lib on the command line.
    # Quote each argument so values with spaces survive the shell.
    cmdline = " ".join(shlex.quote(arg) for arg in sys.argv) + " --lib scann"
    cmdline = (
        ". ~/anaconda3/etc/profile.d/conda.sh ; "
        "conda activate scann_1.1.1; "
        "python -u " + cmdline)

    print("running", cmdline)

    os.system(cmdline)
|
|
|
|
|
|
###############################################################
|
|
# SCANN
|
|
###############################################################
|
|
|
|
def scann_make_index(xb, distance_measure, scann_dir, reorder_k,
                     num_leaves=2000, num_leaves_to_search=100,
                     training_sample_size=250000):
    """Build a SCANN tree + asymmetric-hashing searcher on xb, serialize it, return it.

    Args:
        xb: database vectors.
        distance_measure: SCANN distance name ("dot_product" or "squared_l2"
            in this script).
        scann_dir: directory the searcher is serialized to; any existing
            content is removed first.
        reorder_k: size of the reorder stage; 0 disables reordering.
        num_leaves, num_leaves_to_search, training_sample_size: tree
            partitioning parameters. The defaults reproduce the original
            hard-coded configuration.

    Returns:
        The built SCANN searcher.
    """
    import shutil

    import scann

    print("build index")

    # anisotropic quantization threshold: 0.2 for dot product, 0 otherwise
    thr = 0.2 if distance_measure == "dot_product" else 0
    k = 10  # default neighbor count passed to the builder
    sb = scann.scann_ops_pybind.builder(xb, k, distance_measure)
    sb = sb.tree(num_leaves=num_leaves,
                 num_leaves_to_search=num_leaves_to_search,
                 training_sample_size=training_sample_size)
    sb = sb.score_ah(2, anisotropic_quantization_threshold=thr)

    if reorder_k > 0:
        sb = sb.reorder(reorder_k)

    searcher = sb.build()

    print("done")

    print("write index to", scann_dir)

    # Recreate the directory from scratch. shutil avoids interpolating the
    # path into a shell command, which breaks on spaces/metacharacters.
    shutil.rmtree(scann_dir, ignore_errors=True)
    os.makedirs(scann_dir, exist_ok=True)
    searcher.serialize(scann_dir)
    return searcher
|
|
|
|
def scann_eval_search(
        searcher, searcher_reo,
        xq, xb, nprobe_tab, pre_reorder_k_tab, k, gt,
        nrun, measure):
    """Time SCANN searches over a grid of nprobe x pre_reorder_k values and print results.

    Args:
        searcher: SCANN searcher used when pre_reorder_k == 0.
        searcher_reo: SCANN searcher with a reorder stage, used when
            pre_reorder_k > 0 (passed as ``pre_reorder_num_neighbors``).
        xq: query vectors.
        xb: unused here; kept for signature parity with faiss_eval_search.
        nprobe_tab: values passed as ``leaves_to_search``.
        pre_reorder_k_tab: reorder sizes to try; 0 means "no reordering".
        k: ``final_num_neighbors`` for every search.
        gt: ground-truth ids passed to the evaluation function.
        nrun: timed repetitions per configuration (must be >= 1).
        measure: "1-recall" reports via eval_recalls, anything else via
            eval_inters.

    Raises:
        ValueError: if nrun < 1 (there would be no result to evaluate).
    """
    if nrun < 1:
        raise ValueError("nrun must be >= 1, got %d" % nrun)

    # warmup: both searchers are timed below, so both are warmed up
    for _run in range(5):
        searcher.search_batched(xq)
        searcher_reo.search_batched(xq)

    for nprobe in nprobe_tab:
        for pre_reorder_k in pre_reorder_k_tab:
            # choose the searcher outside the timed region
            if pre_reorder_k == 0:
                active, extra = searcher, {}
            else:
                active = searcher_reo
                extra = {"pre_reorder_num_neighbors": pre_reorder_k}

            times = []
            for _run in range(nrun):
                # perf_counter is monotonic and high-resolution, unlike time.time
                t0 = time.perf_counter()
                I, _D = active.search_batched(
                    xq, leaves_to_search=nprobe, final_num_neighbors=k,
                    **extra
                )
                times.append(time.perf_counter() - t0)

            header = "SCANN nprobe=%4d reo=%4d" % (nprobe, pre_reorder_k)
            if measure == "1-recall":
                eval_recalls(header, I, gt, times)
            else:
                eval_inters(header, I, gt, times)
|
|
|
|
|
|
|
|
|
|
###############################################################
|
|
# Faiss
|
|
###############################################################
|
|
|
|
|
|
def faiss_make_index(xb, metric_type, fname, nlist=2000, ntrain=250000):
    """Build an IVF + fast-scan PQ Faiss index on xb, write it to fname, return it.

    Args:
        xb: database vectors, one per row (``xb.shape[1]`` is the dimension).
        metric_type: Faiss metric constant (e.g. ``faiss.METRIC_L2``).
        fname: path the trained, populated index is written to.
        nlist: number of IVF lists.
        ntrain: number of leading rows of xb used for training.
        The nlist/ntrain defaults reproduce the original hard-coded setup.

    Returns:
        The trained and populated Faiss index.
    """
    import faiss

    d = xb.shape[1]
    # one PQ sub-quantizer per 2 dimensions, "x4fs" = 4-bit fast-scan codes
    M = d // 2
    index = faiss.index_factory(d, f"IVF{nlist},PQ{M}x4fs", metric_type)

    print("train")
    index.train(xb[:ntrain])
    print("add")
    index.add(xb)
    print("write index", fname)
    faiss.write_index(index, fname)

    return index
|
|
|
|
def faiss_eval_search(
        index, xq, xb, nprobe_tab, pre_reorder_k_tab,
        k, gt, nrun, measure
):
    """Time Faiss searches over a grid of nprobe x pre_reorder_k values and print results.

    Args:
        index: Faiss IVF index; its ``nprobe`` attribute is modified in place.
        xq: query vectors.
        xb: database vectors, used by an ``IndexRefineFlat`` wrapper for exact
            re-ranking. Its raw pointer is passed via ``faiss.swig_ptr``, so it
            presumably must be a C-contiguous float32 array that stays alive
            during the search (TODO confirm).
        nprobe_tab: nprobe values to try.
        pre_reorder_k_tab: number of candidates to re-rank exactly; 0 means
            search ``index`` directly without re-ranking.
        k: number of neighbors returned per query.
        gt: ground-truth ids passed to the evaluation function.
        nrun: timed repetitions per configuration (must be >= 1).
        measure: "1-recall" reports via eval_recalls, anything else via
            eval_inters.

    Raises:
        ValueError: if nrun < 1 (there would be no result to evaluate).

    Side effect: sets the Faiss OpenMP thread count to 1.
    """
    import faiss

    if nrun < 1:
        raise ValueError("nrun must be >= 1, got %d" % nrun)

    print("use precomputed table=", index.use_precomputed_table,
          "by residual=", index.by_residual)

    print("adding a refine index")
    index_refine = faiss.IndexRefineFlat(index, faiss.swig_ptr(xb))

    print("set single thread")
    faiss.omp_set_num_threads(1)

    print("warmup")
    for _run in range(5):
        index.search(xq, k)

    print("run timing")
    for nprobe in nprobe_tab:
        # index_refine wraps index, so this applies to both search paths
        index.nprobe = nprobe
        for pre_reorder_k in pre_reorder_k_tab:
            if pre_reorder_k == 0:
                active = index
            else:
                # chosen so that k * k_factor == pre_reorder_k
                index_refine.k_factor = pre_reorder_k / k
                active = index_refine

            times = []
            for _run in range(nrun):
                # perf_counter is monotonic and high-resolution, unlike time.time
                t0 = time.perf_counter()
                _D, I = active.search(xq, k)
                times.append(time.perf_counter() - t0)

            header = "Faiss nprobe=%4d reo=%4d" % (nprobe, pre_reorder_k)
            if measure == "1-recall":
                eval_recalls(header, I, gt, times)
            else:
                eval_inters(header, I, gt, times)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()
|