2024-10-23 00:46:48 +08:00
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2022-12-10 00:53:13 +08:00
|
|
|
#
|
|
|
|
# This source code is licensed under the MIT license found in the
|
|
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import time
|
|
|
|
|
|
|
|
import faiss
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
from faiss.contrib.datasets import SyntheticDataset
|
2023-07-05 21:59:05 +08:00
|
|
|
from faiss.contrib.big_batch_search import big_batch_search
|
2022-12-10 00:53:13 +08:00
|
|
|
|
|
|
|
parser = argparse.ArgumentParser()
|
|
|
|
|
|
|
|
|
|
|
|
def aa(*args, **kwargs):
|
|
|
|
group.add_argument(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
group = parser.add_argument_group('dataset options')
|
|
|
|
aa('--dim', type=int, default=64)
|
|
|
|
aa('--size', default="S")
|
|
|
|
|
|
|
|
group = parser.add_argument_group('index options')
|
|
|
|
aa('--nlist', type=int, default=100)
|
|
|
|
aa('--factory_string', default="", help="overrides nlist")
|
|
|
|
aa('--k', type=int, default=10)
|
|
|
|
aa('--nprobe', type=int, default=5)
|
|
|
|
aa('--nt', type=int, default=-1, help="nb search threads")
|
|
|
|
aa('--method', default="pairwise_distances", help="")
|
|
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
print("args:", args)
|
|
|
|
|
|
|
|
if args.size == "S":
|
|
|
|
ds = SyntheticDataset(32, 2000, 4000, 1000)
|
|
|
|
elif args.size == "M":
|
|
|
|
ds = SyntheticDataset(32, 20000, 40000, 10000)
|
|
|
|
elif args.size == "L":
|
|
|
|
ds = SyntheticDataset(32, 200000, 400000, 100000)
|
|
|
|
else:
|
|
|
|
raise RuntimeError(f"dataset size {args.size} not supported")
|
|
|
|
|
|
|
|
nlist = args.nlist
|
|
|
|
nprobe = args.nprobe
|
|
|
|
k = args.k
|
|
|
|
|
|
|
|
|
|
|
|
def tic(name):
|
|
|
|
global tictoc
|
|
|
|
tictoc = (name, time.time())
|
|
|
|
print(name, end="\r", flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
def toc():
|
|
|
|
global tictoc
|
|
|
|
name, t0 = tictoc
|
|
|
|
dt = time.time() - t0
|
|
|
|
print(f"{name}: {dt:.3f} s")
|
|
|
|
return dt
|
|
|
|
|
|
|
|
|
|
|
|
print(f"dataset {ds}, {nlist=:} {nprobe=:} {k=:}")
|
|
|
|
|
|
|
|
if args.factory_string == "":
|
|
|
|
factory_string = f"IVF{nlist},Flat"
|
|
|
|
else:
|
|
|
|
factory_string = args.factory_string
|
|
|
|
|
2023-01-13 05:34:44 +08:00
|
|
|
print(f"instantiate {factory_string}")
|
2022-12-10 00:53:13 +08:00
|
|
|
index = faiss.index_factory(ds.d, factory_string)
|
|
|
|
|
|
|
|
if args.factory_string != "":
|
|
|
|
nlist = index.nlist
|
|
|
|
|
|
|
|
print("nlist", nlist)
|
|
|
|
|
|
|
|
tic("train")
|
|
|
|
index.train(ds.get_train())
|
|
|
|
toc()
|
|
|
|
|
|
|
|
tic("add")
|
|
|
|
index.add(ds.get_database())
|
|
|
|
toc()
|
|
|
|
|
|
|
|
if args.nt != -1:
|
|
|
|
print("setting nb of threads to", args.nt)
|
|
|
|
faiss.omp_set_num_threads(args.nt)
|
|
|
|
|
|
|
|
tic("reference search")
|
|
|
|
index.nprobe
|
|
|
|
index.nprobe = nprobe
|
|
|
|
Dref, Iref = index.search(ds.get_queries(), k)
|
|
|
|
t_ref = toc()
|
|
|
|
|
|
|
|
tic("block search")
|
|
|
|
Dnew, Inew = big_batch_search(
|
|
|
|
index, ds.get_queries(),
|
|
|
|
k, method=args.method, verbose=10
|
|
|
|
)
|
|
|
|
t_tot = toc()
|
|
|
|
|
|
|
|
assert (Inew != Iref).sum() / Iref.size < 1e-4
|
|
|
|
np.testing.assert_almost_equal(Dnew, Dref, decimal=4)
|
|
|
|
|
|
|
|
print(f"total block search time {t_tot:.3f} s, speedup {t_ref / t_tot:.3f}x")
|