# File: faiss/benchs/bench_big_batch_ivf.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import time
import faiss
import numpy as np
from faiss.contrib.datasets import SyntheticDataset
from faiss.contrib.big_batch_search import big_batch_search
parser = argparse.ArgumentParser()


def aa(*args, **kwargs):
    """Add an argument to the option group currently bound to ``group``.

    All positional and keyword arguments are forwarded unchanged to
    ``group.add_argument``; ``group`` is the module-level variable that is
    rebound before each section of options below.
    """
    group.add_argument(*args, **kwargs)


group = parser.add_argument_group('dataset options')
# --dim used to be parsed and then ignored: the dataset was always built
# with a hard-coded dimension of 32. The option is now honored, and its
# default is 32 so that runs without --dim build the same dataset as before.
aa('--dim', type=int, default=32, help="dimension of the synthetic vectors")
aa('--size', default="S")

group = parser.add_argument_group('index options')
aa('--nlist', type=int, default=100)
aa('--factory_string', default="", help="overrides nlist")
aa('--k', type=int, default=10)
aa('--nprobe', type=int, default=5)
aa('--nt', type=int, default=-1, help="nb search threads")
aa('--method', default="pairwise_distances", help="")

args = parser.parse_args()
print("args:", args)

# Positional arguments that follow the dimension in the SyntheticDataset
# constructor, keyed by the value of --size.
_DATASET_SHAPES = {
    "S": (2000, 4000, 1000),
    "M": (20000, 40000, 10000),
    "L": (200000, 400000, 100000),
}

if args.size not in _DATASET_SHAPES:
    raise RuntimeError(f"dataset size {args.size} not supported")

ds = SyntheticDataset(args.dim, *_DATASET_SHAPES[args.size])

# Short aliases used by the rest of the script.
nlist = args.nlist
nprobe = args.nprobe
k = args.k
def tic(name):
    """Start timing a phase called *name*.

    Stores the phase name and the start timestamp in the module-level
    ``tictoc`` tuple (overwriting any previous one) and prints the name
    followed by a carriage return, so that the line printed by ``toc()``
    overwrites it on a terminal.
    """
    global tictoc
    # perf_counter() is a monotonic, high-resolution clock; unlike
    # time.time() it is not affected by system clock adjustments.
    tictoc = (name, time.perf_counter())
    print(name, end="\r", flush=True)


def toc():
    """Finish the phase started by the last ``tic()`` call.

    Prints ``"<name>: <seconds> s"`` and returns the elapsed time in
    seconds as a float. ``tic()`` must have been called beforehand,
    otherwise the ``tictoc`` global does not exist yet.
    """
    name, t0 = tictoc
    dt = time.perf_counter() - t0
    print(f"{name}: {dt:.3f} s")
    return dt
print(f"dataset {ds}, {nlist=:} {nprobe=:} {k=:}")

# An explicit --factory_string takes precedence over --nlist.
if args.factory_string == "":
    factory_string = f"IVF{nlist},Flat"
else:
    factory_string = args.factory_string

print(f"instantiate {factory_string}")
index = faiss.index_factory(ds.d, factory_string)

if args.factory_string != "":
    # The factory string decides the number of lists; read it back.
    nlist = index.nlist

print("nlist", nlist)

tic("train")
index.train(ds.get_train())
toc()

tic("add")
index.add(ds.get_database())
toc()

if args.nt != -1:
    print("setting nb of threads to", args.nt)
    faiss.omp_set_num_threads(args.nt)

# Fetch the queries once, outside the timed regions, so both searches are
# given the same array and dataset access is not part of the timings.
xq = ds.get_queries()

# Baseline: regular search on the index.
tic("reference search")
index.nprobe = nprobe
Dref, Iref = index.search(xq, k)
t_ref = toc()

# Same search through big_batch_search, using the method chosen by --method.
tic("block search")
Dnew, Inew = big_batch_search(
    index, xq,
    k, method=args.method, verbose=10
)
t_tot = toc()

# Sanity check: both searches must agree. A tiny fraction of differing ids
# is tolerated. This is an explicit check rather than a bare ``assert`` so
# that it still runs under ``python -O``.
mismatch_rate = (Inew != Iref).sum() / Iref.size
if not mismatch_rate < 1e-4:
    raise AssertionError(
        f"block search ids differ from reference: {mismatch_rate=}"
    )
np.testing.assert_almost_equal(Dnew, Dref, decimal=4)

print(f"total block search time {t_tot:.3f} s, speedup {t_ref / t_tot:.3f}x")