faiss/benchs/bench_all_ivf/cmp_with_scann.py
Commit adc9d1a0cd by chasingegg: Refactor prepare cache code in cmp_with_scann benchmark (#2573)
Summary:
In ```cmp_with_scann.py```, the base vectors, query vectors and ground truth are saved as npy files. However, this was only done when the lib is faiss; running the script directly with the scann lib fails because the files do not exist.
Therefore, the code is refactored to save the npy files up front, so that either lib can be run first.

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2573

Reviewed By: mdouze

Differential Revision: D42338435

Pulled By: algoriddle

fbshipit-source-id: 9227f95e1ff79f5329f6206a0cb7ca169185fdb3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
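
# Benchmark comparing a Faiss IVF fast-scan PQ index with ScaNN on the same
# dataset. The database, queries and ground truth are cached as .npy files so
# that the ScaNN run (typically launched in a separate conda environment via
# --thenscann) reads exactly the same data. Example invocation (paths are
# site-specific):
#   python cmp_with_scann.py --db deep1M --base_dir /tmp/cmp_ivf_scann --thenscann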
import time
import sys
import os
import argparse

import numpy as np
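
# print 1-recall@rank for rank in 1, 10, 100, 1000 (fraction of queries whose
# true nearest neighbor appears in the top-rank results), together with the
# mean and standard deviation of the per-run search times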
def eval_recalls(name, I, gt, times):
    k = I.shape[1]
    s = "%-40s recall" % name
    nq = len(gt)
    for rank in 1, 10, 100, 1000:
        if rank > k:
            break
        recall = (I[:, :rank] == gt[:, :1]).sum() / nq
        s += "@%d: %.4f " % (rank, recall)
    s += "time: %.4f s (± %.4f)" % (np.mean(times), np.std(times))
    print(s)
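
# print the intersection measure (average overlap between the top-rank results
# and the top-rank ground-truth neighbors), together with the mean and
# standard deviation of the per-run search times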
def eval_inters(name, I, gt, times):
    k = I.shape[1]
    s = "%-40s inter" % name
    nq = len(gt)
    for rank in 1, 10, 100, 1000:
        if rank > k:
            break
        ninter = 0
        for i in range(nq):
            ninter += np.intersect1d(I[i, :rank], gt[i, :rank]).size
        inter = ninter / (nq * rank)
        s += "@%d: %.4f " % (rank, inter)
    s += "time: %.4f s (± %.4f)" % (np.mean(times), np.std(times))
    print(s)
def main():
    parser = argparse.ArgumentParser()

    def aa(*args, **kwargs):
        group.add_argument(*args, **kwargs)

    group = parser.add_argument_group('dataset options')
    aa('--db', default='deep1M', help='dataset')
    aa('--measure', default="1-recall",
       help="perf measure to use: 1-recall or inter")
    aa('--download', default=False, action="store_true")
    aa('--lib', default='faiss', help='library to use (faiss or scann)')
    aa('--thenscann', default=False, action="store_true")
    aa('--base_dir', default='/checkpoint/matthijs/faiss_improvements/cmp_ivf_scan_2')

    group = parser.add_argument_group('searching')
    aa('--k', default=10, type=int, help='nb of nearest neighbors')
    aa('--pre_reorder_k', default="0,10,100,1000", help='values for reorder_k')
    aa('--nprobe', default="1,2,5,10,20,50,100,200", help='values for nprobe')
    aa('--nrun', default=5, type=int, help='nb of runs to perform')

    args = parser.parse_args()
    print("args:", args)

    pre_reorder_k_tab = [int(x) for x in args.pre_reorder_k.split(',')]
    nprobe_tab = [int(x) for x in args.nprobe.split(',')]

    os.system('echo -n "nb processors "; '
              'cat /proc/cpuinfo | grep ^processor | wc -l; '
              'cat /proc/cpuinfo | grep ^"model name" | tail -1')

    cache_dir = args.base_dir + "/" + args.db + "/"
    k = args.k
    nrun = args.nrun
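
    # cache the database, query and ground-truth arrays as .npy files (plus
    # the metric name) so that a later --lib scann run can reload them
    # without going through the faiss-side dataset loading code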
    if not os.path.exists(cache_dir + "xb.npy"):
        # prepare cache
        from datasets import load_dataset
        ds = load_dataset(args.db, download=args.download)
        print(ds)
        # store for SCANN
        os.system(f"rm -rf {cache_dir}; mkdir -p {cache_dir}")
        tosave = dict(
            xb=ds.get_database(),
            xq=ds.get_queries(),
            gt=ds.get_groundtruth()
        )
        for name, v in tosave.items():
            fname = cache_dir + "/" + name + ".npy"
            print("save", fname)
            np.save(fname, v)
        open(cache_dir + "metric", "w").write(ds.metric)

    dataset = {}
    for kn in "xb xq gt".split():
        fname = cache_dir + "/" + kn + ".npy"
        print("load", fname)
        dataset[kn] = np.load(fname)
    xb = dataset["xb"]
    xq = dataset["xq"]
    gt = dataset["gt"]
    distance_measure = open(cache_dir + "metric").read()

    if args.lib == "faiss":
        import faiss

        name1_to_metric = {
            "IP": faiss.METRIC_INNER_PRODUCT,
            "L2": faiss.METRIC_L2
        }

        index_fname = cache_dir + "index.faiss"
        if not os.path.exists(index_fname):
            index = faiss_make_index(
                xb, name1_to_metric[distance_measure], index_fname)
        else:
            index = faiss.read_index(index_fname)

        faiss_eval_search(
            index, xq, xb, nprobe_tab, pre_reorder_k_tab, k, gt,
            nrun, args.measure
        )
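
    # ScaNN path: build (or reload from disk) two serialized searchers, one
    # without reordering and one with reordering enabled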
    if args.lib == "scann":
        from scann.scann_ops.py import scann_ops_pybind

        name1_to_name2 = {
            "IP": "dot_product",
            "L2": "squared_l2"
        }

        scann_dir = cache_dir + "/scann1.1.1_serialized"
        if os.path.exists(scann_dir + "/scann_config.pb"):
            searcher = scann_ops_pybind.load_searcher(scann_dir)
        else:
            searcher = scann_make_index(xb, name1_to_name2[distance_measure], scann_dir, 0)

        scann_dir = cache_dir + "/scann1.1.1_serialized_reorder"
        if os.path.exists(scann_dir + "/scann_config.pb"):
            searcher_reo = scann_ops_pybind.load_searcher(scann_dir)
        else:
            searcher_reo = scann_make_index(xb, name1_to_name2[distance_measure], scann_dir, 100)

        scann_eval_search(
            searcher, searcher_reo,
            xq, xb, nprobe_tab, pre_reorder_k_tab, k, gt,
            nrun, args.measure
        )

    if args.lib != "scann" and args.thenscann:
        # just append --lib scann, that will override the previous cmdline
        # options
        cmdline = " ".join(sys.argv) + " --lib scann"
        cmdline = (
            ". ~/anaconda3/etc/profile.d/conda.sh ; " +
            "conda activate scann_1.1.1; "
            "python -u " + cmdline)
        print("running", cmdline)
        os.system(cmdline)
###############################################################
# SCANN
###############################################################
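
# build a ScaNN searcher: 2000-leaf partitioning tree, asymmetric hashing (AH)
# scoring, and optional exact reordering when reorder_k > 0; the searcher is
# serialized to scann_dir so later runs can reload it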
def scann_make_index(xb, distance_measure, scann_dir, reorder_k):
    import scann

    print("build index")
    if distance_measure == "dot_product":
        thr = 0.2
    else:
        thr = 0
    k = 10
    sb = scann.scann_ops_pybind.builder(xb, k, distance_measure)
    sb = sb.tree(num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000)
    sb = sb.score_ah(2, anisotropic_quantization_threshold=thr)
    if reorder_k > 0:
        sb = sb.reorder(reorder_k)

    searcher = sb.build()
    print("done")
    print("write index to", scann_dir)
    os.system(f"rm -rf {scann_dir}; mkdir -p {scann_dir}")
    # os.mkdir(scann_dir)
    searcher.serialize(scann_dir)
    return searcher
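
# time ScaNN searches over the grid of leaves_to_search (nprobe) and
# pre_reorder_k values, and report recall or intersection measure for each
# setting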
def scann_eval_search(
        searcher, searcher_reo,
        xq, xb, nprobe_tab, pre_reorder_k_tab, k, gt,
        nrun, measure):

    # warmup
    for _run in range(5):
        searcher.search_batched(xq)

    for nprobe in nprobe_tab:
        for pre_reorder_k in pre_reorder_k_tab:
            times = []
            for _run in range(nrun):
                if pre_reorder_k == 0:
                    t0 = time.time()
                    I, D = searcher.search_batched(
                        xq, leaves_to_search=nprobe, final_num_neighbors=k
                    )
                    t1 = time.time()
                else:
                    t0 = time.time()
                    I, D = searcher_reo.search_batched(
                        xq, leaves_to_search=nprobe, final_num_neighbors=k,
                        pre_reorder_num_neighbors=pre_reorder_k
                    )
                    t1 = time.time()
                times.append(t1 - t0)

            header = "SCANN nprobe=%4d reo=%4d" % (nprobe, pre_reorder_k)
            if measure == "1-recall":
                eval_recalls(header, I, gt, times)
            else:
                eval_inters(header, I, gt, times)
###############################################################
# Faiss
###############################################################
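
# build an IVF2000,PQ(d/2)x4fs index (fast-scan PQ with 4-bit codes), train it
# on the first 250k database vectors, add the database and write the index to
# fname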
def faiss_make_index(xb, metric_type, fname):
    import faiss

    d = xb.shape[1]
    M = d // 2
    index = faiss.index_factory(d, f"IVF2000,PQ{M}x4fs", metric_type)
    # if not by_residual:
    #     print("setting no residual")
    #     index.by_residual = False

    print("train")
    index.train(xb[:250000])
    print("add")
    index.add(xb)
    print("write index", fname)
    faiss.write_index(index, fname)
    return index
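
# time single-threaded Faiss searches over the grid of nprobe and
# pre_reorder_k values; when pre_reorder_k > 0 the search goes through an
# IndexRefineFlat with k_factor = pre_reorder_k / k, and recall or
# intersection is reported for each setting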
def faiss_eval_search(
        index, xq, xb, nprobe_tab, pre_reorder_k_tab,
        k, gt, nrun, measure
):
    import faiss

    print("use precomputed table=", index.use_precomputed_table,
          "by residual=", index.by_residual)

    print("adding a refine index")
    index_refine = faiss.IndexRefineFlat(index, faiss.swig_ptr(xb))

    print("set single thread")
    faiss.omp_set_num_threads(1)

    print("warmup")
    for _run in range(5):
        index.search(xq, k)

    print("run timing")
    for nprobe in nprobe_tab:
        for pre_reorder_k in pre_reorder_k_tab:
            index.nprobe = nprobe
            times = []
            for _run in range(nrun):
                if pre_reorder_k == 0:
                    t0 = time.time()
                    D, I = index.search(xq, k)
                    t1 = time.time()
                else:
                    index_refine.k_factor = pre_reorder_k / k
                    t0 = time.time()
                    D, I = index_refine.search(xq, k)
                    t1 = time.time()
                times.append(t1 - t0)

            header = "Faiss nprobe=%4d reo=%4d" % (nprobe, pre_reorder_k)
            if measure == "1-recall":
                eval_recalls(header, I, gt, times)
            else:
                eval_inters(header, I, gt, times)
if __name__ == "__main__":
    main()