PQ4 fast scan benchmarks (#1555)

Summary: Code + scripts for Faiss benchmarks around the Fast scan codes.

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1555
Test Plan: buck test //faiss/tests/:test_refine
Reviewed By: wickedfoo
Differential Revision: D25546505
Pulled By: mdouze
fbshipit-source-id: 902486b7f47e36221a2671d124df8c114f25db58

parent 90c891b616
commit c5975cda72
================ bench_all_ivf.py ================

@@ -1,5 +1,3 @@
-#!/usr/bin/env python2
-
 # Copyright (c) Facebook, Inc. and its affiliates.
 #
 # This source code is licensed under the MIT license found in the

@@ -8,12 +6,15 @@
 import os
 import sys
 import time
 import pdb
 import numpy as np
 import faiss
 import argparse
 import datasets
 from datasets import sanitize


 ######################################################
 # Command-line parsing
 ######################################################

@@ -34,8 +35,8 @@ aa('--compute_gt', default=False, action='store_true',
 group = parser.add_argument_group('index consturction')

 aa('--indexkey', default='HNSW32', help='index_factory type')
-aa('--efConstruction', default=200, type=int,
-   help='HNSW construction factor')
+aa('--by_residual', default=-1, type=int,
+   help="set if index should use residuals (default=unchanged)")
 aa('--M0', default=-1, type=int, help='size of base level')
 aa('--maxtrain', default=256 * 256, type=int,
    help='maximum number of training points (0 to set automatically)')

@@ -54,6 +55,8 @@ aa('--get_centroids_from', default='',
 group = parser.add_argument_group('searching')

 aa('--k', default=100, type=int, help='nb of nearest neighbors')
+aa('--inter', default=False, action='store_true',
+   help='use intersection measure instead of 1-recall as metric')
 aa('--searchthreads', default=-1, type=int,
    help='nb of threads to use at search time')
 aa('--searchparams', nargs='+', default=['autotune'],

@@ -64,7 +67,7 @@ aa('--autotune_max', default=[], nargs='*',
    help='set max value for autotune variables format "var:val" (exclusive)')
 aa('--autotune_range', default=[], nargs='*',
    help='set complete autotune range, format "var:val1,val2,..."')
-aa('--min_test_duration', default=0, type=float,
+aa('--min_test_duration', default=3.0, type=float,
    help='run test at least for so long to avoid jitter')

 args = parser.parse_args()

@@ -79,64 +82,126 @@ os.system('echo -n "nb processors "; '
 # Load dataset
 ######################################################

-xt, xb, xq, gt = datasets.load_data(
-    dataset=args.db, compute_gt=args.compute_gt)
+ds = datasets.load_dataset(
+    dataset=args.db, compute_gt=args.compute_gt)

-print("dataset sizes: train %s base %s query %s GT %s" % (
-    xt.shape, xb.shape, xq.shape, gt.shape))
+print(ds)

-nq, d = xq.shape
-nb, d = xb.shape
+nq, d = ds.nq, ds.d
+nb, d = ds.nb, ds.d


 ######################################################
 # Make index
 ######################################################

+def unwind_index_ivf(index):
+    if isinstance(index, faiss.IndexPreTransform):
+        assert index.chain.size() == 1
+        vt = index.chain.at(0)
+        index_ivf, vt2 = unwind_index_ivf(faiss.downcast_index(index.index))
+        assert vt2 is None
+        return index_ivf, vt
+    if hasattr(faiss, "IndexRefine") and isinstance(index, faiss.IndexRefine):
+        return unwind_index_ivf(faiss.downcast_index(index.base_index))
+    if isinstance(index, faiss.IndexIVF):
+        return index, None
+    else:
+        return None, None


 if args.indexfile and os.path.exists(args.indexfile):

     print("reading", args.indexfile)
     index = faiss.read_index(args.indexfile)

-    if isinstance(index, faiss.IndexPreTransform):
-        index_ivf = faiss.downcast_index(index.index)
-    else:
-        index_ivf = index
-    assert isinstance(index_ivf, faiss.IndexIVF)
+    index_ivf, vec_transform = unwind_index_ivf(index)
+    if vec_transform is None:
+        vec_transform = lambda x: x
+    assert isinstance(index_ivf, faiss.IndexIVF)

 else:

     print("build index, key=", args.indexkey)

-    index = faiss.index_factory(d, args.indexkey)
+    index = faiss.index_factory(
+        d, args.indexkey, faiss.METRIC_L2 if ds.metric == "L2" else
+        faiss.METRIC_INNER_PRODUCT
+    )

-    if isinstance(index, faiss.IndexPreTransform):
-        index_ivf = faiss.downcast_index(index.index)
-        vec_transform = index.chain.at(0).apply_py
-    else:
-        index_ivf = index
-        vec_transform = lambda x: x
-    assert isinstance(index_ivf, faiss.IndexIVF)
-    index_ivf.verbose = True
-    index_ivf.quantizer.verbose = True
-    index_ivf.cp.verbose = True
+    index_ivf, vec_transform = unwind_index_ivf(index)
+    if vec_transform is None:
+        vec_transform = lambda x: x
+    else:
+        vec_transform = faiss.downcast_VectorTransform(vec_transform)

     if args.by_residual != -1:
         by_residual = args.by_residual == 1
         print("setting by_residual = ", by_residual)
         index_ivf.by_residual   # check that the field exists
         index_ivf.by_residual = by_residual

+    if index_ivf:
+        print("Update add-time parameters")
+        # adjust default parameters used at add time for quantizers
+        # because otherwise the assignment is inaccurate
+        quantizer = faiss.downcast_index(index_ivf.quantizer)
+        if isinstance(quantizer, faiss.IndexRefine):
+            print(" update quantizer k_factor=", quantizer.k_factor, end=" -> ")
+            quantizer.k_factor = 32 if index_ivf.nlist < 1e6 else 64
+            print(quantizer.k_factor)
+            base_index = faiss.downcast_index(quantizer.base_index)
+            if isinstance(base_index, faiss.IndexIVF):
+                print(" update quantizer nprobe=", base_index.nprobe, end=" -> ")
+                base_index.nprobe = (
+                    16 if base_index.nlist < 1e5 else
+                    32 if base_index.nlist < 4e6 else
+                    64)
+                print(base_index.nprobe)
+        elif isinstance(quantizer, faiss.IndexHNSW):
+            print(" update quantizer efSearch=", quantizer.hnsw.efSearch, end=" -> ")
+            quantizer.hnsw.efSearch = 40 if index_ivf.nlist < 4e6 else 64
+            print(quantizer.hnsw.efSearch)

+    if index_ivf:
+        index_ivf.verbose = True
+        index_ivf.quantizer.verbose = True
+        index_ivf.cp.verbose = True
+    else:
+        index.verbose = True

     maxtrain = args.maxtrain
     if maxtrain == 0:
         if 'IMI' in args.indexkey:
             maxtrain = int(256 * 2 ** (np.log2(index_ivf.nlist) / 2))
-        else:
+        elif index_ivf:
             maxtrain = 50 * index_ivf.nlist
+        else:
+            # just guess...
+            maxtrain = 256 * 100
         maxtrain = max(maxtrain, 256 * 100)
         print("setting maxtrain to %d" % maxtrain)
-        args.maxtrain = maxtrain

-    xt2 = sanitize(xt[:args.maxtrain])
-    assert np.all(np.isfinite(xt2))
+    try:
+        xt2 = ds.get_train(maxtrain=maxtrain)
+    except NotImplementedError:
+        print("No training set: training on database")
+        xt2 = ds.get_database()[:maxtrain]

+    print("train, size", xt2.shape)
+    assert np.all(np.isfinite(xt2))

+    if (isinstance(vec_transform, faiss.OPQMatrix) and
+            isinstance(index_ivf, faiss.IndexIVFPQFastScan)):
+        print(" Forcing OPQ training PQ to PQ4")
+        ref_pq = index_ivf.pq
+        training_pq = faiss.ProductQuantizer(
+            ref_pq.d, ref_pq.M, ref_pq.nbits
+        )
+        vec_transform.pq   # check that the field exists
+        vec_transform.pq = training_pq


     if args.get_centroids_from == '':

@@ -147,7 +212,8 @@ else:

         if args.train_on_gpu:
             print("add a training index on GPU")
-            train_index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(d))
+            train_index = faiss.index_cpu_to_all_gpus(
+                faiss.IndexFlatL2(index_ivf.d))
             index_ivf.clustering_index = train_index

     else:

@@ -158,13 +224,15 @@ else:
         centroids = centroids.reshape(-1, d)
         print(" centroid table shape", centroids.shape)

-        if isinstance(index, faiss.IndexPreTransform):
+        if isinstance(vec_transform, faiss.VectorTransform):
             print(" training vector transform")
-            assert index.chain.size() == 1
-            vt = index.chain.at(0)
-            vt.train(xt2)
+            vec_transform.train(xt2)
             print(" transform centroids")
-            centroids = vt.apply_py(centroids)
+            centroids = vec_transform.apply_py(centroids)

         if not index_ivf.quantizer.is_trained:
             print(" training quantizer")
             index_ivf.quantizer.train(centroids)

         print(" add centroids to quantizer")
         index_ivf.quantizer.add(centroids)

@@ -177,12 +245,16 @@ else:
     print("adding")
     t0 = time.time()
     if args.add_bs == -1:
-        index.add(sanitize(xb))
+        index.add(sanitize(ds.get_database()))
     else:
-        for i0 in range(0, nb, args.add_bs):
-            i1 = min(nb, i0 + args.add_bs)
-            print(" adding %d:%d / %d" % (i0, i1, nb))
-            index.add(sanitize(xb[i0:i1]))
+        i0 = 0
+        for xblock in ds.database_iterator(bs=args.add_bs):
+            i1 = i0 + len(xblock)
+            print(" adding %d:%d / %d [%.3f s, RSS %d kiB] " % (
+                i0, i1, ds.nb, time.time() - t0,
+                faiss.get_mem_usage_kb()))
+            index.add(xblock)
+            i0 = i1

     print(" add in %.3f s" % (time.time() - t0))
     if args.indexfile:

@@ -211,39 +283,65 @@ print("precomputed tables size:", precomputed_table_size)
 # Index is ready
 #############################################################

-xq = sanitize(xq)
+xq = sanitize(ds.get_queries())
+gt = ds.get_groundtruth(k=args.k)
+assert gt.shape[1] == args.k, pdb.set_trace()

 if args.searchthreads != -1:
     print("Setting nb of threads to", args.searchthreads)
     faiss.omp_set_num_threads(args.searchthreads)


 ps = faiss.ParameterSpace()
 ps.initialize(index)


 parametersets = args.searchparams

-header = '%-40s R@1 R@10 R@100 time(ms/q) nb distances #runs' % "parameters"
+if args.inter:
+    header = (
+        '%-40s inter@%3d time(ms/q) nb distances #runs' %
+        ("parameters", args.k)
+    )
+else:
+    header = (
+        '%-40s R@1 R@10 R@100 time(ms/q) nb distances #runs' %
+        "parameters"
+    )

+def compute_inter(a, b):
+    nq, rank = a.shape
+    ninter = sum(
+        np.intersect1d(a[i, :rank], b[i, :rank]).size
+        for i in range(nq)
+    )
+    return ninter / a.size


-def eval_setting(index, xq, gt, min_time):
+def eval_setting(index, xq, gt, k, inter, min_time):
     nq = xq.shape[0]
     ivf_stats = faiss.cvar.indexIVF_stats
     ivf_stats.reset()
     nrun = 0
     t0 = time.time()
     while True:
-        D, I = index.search(xq, 100)
+        D, I = index.search(xq, k)
         nrun += 1
         t1 = time.time()
         if t1 - t0 > min_time:
             break
     ms_per_query = ((t1 - t0) * 1000.0 / nq / nrun)
-    for rank in 1, 10, 100:
-        n_ok = (I[:, :rank] == gt[:, :1]).sum()
-        print("%.4f" % (n_ok / float(nq)), end=' ')
-    print(" %8.3f " % ms_per_query, end=' ')
+    if inter:
+        rank = k
+        inter_measure = compute_inter(gt[:, :rank], I[:, :rank])
+        print("%.4f" % inter_measure, end=' ')
+    else:
+        for rank in 1, 10, 100:
+            n_ok = (I[:, :rank] == gt[:, :1]).sum()
+            print("%.4f" % (n_ok / float(nq)), end=' ')
+    print(" %9.5f " % ms_per_query, end=' ')
     print("%12d " % (ivf_stats.ndis / nrun), end=' ')
     print(nrun)

@@ -269,15 +367,20 @@ if parametersets == ['autotune']:
         pr = ps.add_range(k)
         faiss.copy_array_to_vector(vals, pr.values)

-    # setup the Criterion object: optimize for 1-R@1
-    crit = faiss.OneRecallAtRCriterion(nq, 1)
+    # setup the Criterion object
+    if args.inter:
+        print("Optimize for intersection @ ", args.k)
+        crit = faiss.IntersectionCriterion(nq, args.k)
+    else:
+        print("Optimize for 1-recall @ 1")
+        crit = faiss.OneRecallAtRCriterion(nq, 1)

     # by default, the criterion will request only 1 NN
-    crit.nnn = 100
+    crit.nnn = args.k
     crit.set_groundtruth(None, gt.astype('int64'))

     # then we let Faiss find the optimal parameters by itself
-    print("exploring operating points")
+    print("exploring operating points, %d threads" % faiss.omp_get_max_threads())
     ps.display()

     t0 = time.time()

@@ -286,17 +389,19 @@ if parametersets == ['autotune']:

     op.display()

+    print("Re-running evaluation on selected OPs")
     print(header)
     opv = op.optimal_pts
+    maxw = max(max(len(opv.at(i).key) for i in range(opv.size())), 40)
     for i in range(opv.size()):
         opt = opv.at(i)

         ps.set_index_parameters(index, opt.key)

-        print("%-40s " % opt.key, end=' ')
+        print(opt.key.ljust(maxw), end=' ')
         sys.stdout.flush()

-        eval_setting(index, xq, gt, args.min_test_duration)
+        eval_setting(index, xq, gt, args.k, args.inter, args.min_test_duration)

 else:
     print(header)

@@ -305,4 +410,4 @@ else:
         sys.stdout.flush()
         ps.set_index_parameters(index, param)

-        eval_setting(index, xq, gt, args.min_test_duration)
+        eval_setting(index, xq, gt, args.k, args.inter, args.min_test_duration)
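The fast-scan variants exercised above encode vectors with 4-bit PQ and scan the codes with SIMD register-resident look-up tables. As a point of reference, here is a minimal self-contained sketch of the index type the benchmark targets, with made-up sizes and random data (this is illustrative only, not part of the PR):

    import numpy as np
    import faiss

    d = 64
    rs = np.random.RandomState(123)
    xb = rs.rand(10000, d).astype('float32')
    xq = rs.rand(100, d).astype('float32')

    # "PQ32x4fs": 32 sub-quantizers of 4 bits each, fast-scan memory layout
    index = faiss.index_factory(d, "IVF64,PQ32x4fs", faiss.METRIC_L2)
    index.train(xb)
    index.add(xb)
    index.nprobe = 8
    D, I = index.search(xq, 10)

    # 1-recall@10 against exact search, computed the same way as eval_setting
    ref = faiss.IndexFlatL2(d)
    ref.add(xb)
    _, Igt = ref.search(xq, 1)
    print("R@10: %.4f" % ((I == Igt[:, :1]).sum() / len(xq)))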
================ (k-means clustering benchmark) ================

@@ -1,11 +1,8 @@
-#!/usr/bin/env python2
-
 # Copyright (c) Facebook, Inc. and its affiliates.
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from __future__ import print_function
 import os
 import numpy as np
 import faiss

@@ -98,6 +95,7 @@ clustering.seed = args.seed
 clustering.max_points_per_centroid = 10**6
 clustering.min_points_per_centroid = 1

+centroids = None

 for iter0 in range(0, args.niter, args.eval_freq):
     iter1 = min(args.niter, iter0 + args.eval_freq)
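The script above drives faiss.Clustering directly so it can evaluate intermediate states every eval_freq iterations. For reference, a hedged sketch of the same knobs through the faiss.Kmeans convenience wrapper, assuming (as in current faiss) that extra keyword arguments are forwarded to the underlying ClusteringParameters; sizes are illustrative:

    import numpy as np
    import faiss

    d, nb, ncent = 32, 20000, 256
    x = np.random.RandomState(0).rand(nb, d).astype('float32')

    km = faiss.Kmeans(
        d, ncent, niter=10, verbose=True,
        max_points_per_centroid=10**6,  # use all points instead of subsampling
        min_points_per_centroid=1,      # silence the "too few points" warning
        seed=123,
    )
    km.train(x)
    print("centroids:", km.centroids.shape, "final objective:", km.obj[-1])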
================ cmp_with_scann.py (new file) ================

@@ -0,0 +1,308 @@
import time
import sys
import os
import argparse

import numpy as np


def eval_recalls(name, I, gt, times):
    k = I.shape[1]
    s = "%-40s recall" % name
    nq = len(gt)
    for rank in 1, 10, 100, 1000:
        if rank > k:
            break
        recall = (I[:, :rank] == gt[:, :1]).sum() / nq
        s += "@%d: %.4f " % (rank, recall)
    s += "time: %.4f s (± %.4f)" % (np.mean(times), np.std(times))
    print(s)


def eval_inters(name, I, gt, times):
    k = I.shape[1]
    s = "%-40s inter" % name
    nq = len(gt)
    for rank in 1, 10, 100, 1000:
        if rank > k:
            break
        ninter = 0
        for i in range(nq):
            ninter += np.intersect1d(I[i, :rank], gt[i, :rank]).size
        inter = ninter / (nq * rank)
        s += "@%d: %.4f " % (rank, inter)
    s += "time: %.4f s (± %.4f)" % (np.mean(times), np.std(times))
    print(s)


def main():

    parser = argparse.ArgumentParser()

    def aa(*args, **kwargs):
        group.add_argument(*args, **kwargs)

    group = parser.add_argument_group('dataset options')

    aa('--db', default='deep1M', help='dataset')
    aa('--measure', default="1-recall",
       help="perf measure to use: 1-recall or inter")
    aa('--download', default=False, action="store_true")
    aa('--lib', default='faiss', help='library to use (faiss or scann)')
    aa('--thenscann', default=False, action="store_true")
    aa('--base_dir', default='/checkpoint/matthijs/faiss_improvements/cmp_ivf_scan_2')

    group = parser.add_argument_group('searching')
    aa('--k', default=10, type=int, help='nb of nearest neighbors')
    aa('--pre_reorder_k', default="0,10,100,1000", help='values for reorder_k')
    aa('--nprobe', default="1,2,5,10,20,50,100,200", help='values for nprobe')
    aa('--nrun', default=5, type=int, help='nb of runs to perform')
    args = parser.parse_args()

    print("args:", args)
    pre_reorder_k_tab = [int(x) for x in args.pre_reorder_k.split(',')]
    nprobe_tab = [int(x) for x in args.nprobe.split(',')]

    os.system('echo -n "nb processors "; '
              'cat /proc/cpuinfo | grep ^processor | wc -l; '
              'cat /proc/cpuinfo | grep ^"model name" | tail -1')

    cache_dir = args.base_dir + "/" + args.db + "/"
    k = args.k
    nrun = args.nrun

    if args.lib == "faiss":
        # prepare cache
        import faiss
        from datasets import load_dataset

        ds = load_dataset(args.db, download=args.download)
        print(ds)
        if not os.path.exists(cache_dir + "xb.npy"):
            # store for SCANN
            os.system(f"rm -rf {cache_dir}; mkdir -p {cache_dir}")
            tosave = dict(
                # xt = ds.get_train(10),
                xb = ds.get_database(),
                xq = ds.get_queries(),
                gt = ds.get_groundtruth()
            )
            for name, v in tosave.items():
                fname = cache_dir + "/" + name + ".npy"
                print("save", fname)
                np.save(fname, v)

            open(cache_dir + "metric", "w").write(ds.metric)

        name1_to_metric = {
            "IP": faiss.METRIC_INNER_PRODUCT,
            "L2": faiss.METRIC_L2
        }

        index_fname = cache_dir + "index.faiss"
        if not os.path.exists(index_fname):
            index = faiss_make_index(
                ds.get_database(), name1_to_metric[ds.metric], index_fname)
        else:
            index = faiss.read_index(index_fname)

        xb = ds.get_database()
        xq = ds.get_queries()
        gt = ds.get_groundtruth()

        faiss_eval_search(
            index, xq, xb, nprobe_tab, pre_reorder_k_tab, k, gt,
            nrun, args.measure
        )

    if args.lib == "scann":
        from scann.scann_ops.py import scann_ops_pybind

        dataset = {}
        for kn in "xb xq gt".split():
            fname = cache_dir + "/" + kn + ".npy"
            print("load", fname)
            dataset[kn] = np.load(fname)
        name1_to_name2 = {
            "IP": "dot_product",
            "L2": "squared_l2"
        }
        distance_measure = name1_to_name2[open(cache_dir + "metric").read()]

        xb = dataset["xb"]
        xq = dataset["xq"]
        gt = dataset["gt"]

        scann_dir = cache_dir + "/scann1.1.1_serialized"
        if os.path.exists(scann_dir + "/scann_config.pb"):
            searcher = scann_ops_pybind.load_searcher(scann_dir)
        else:
            searcher = scann_make_index(xb, distance_measure, scann_dir, 0)

        scann_dir = cache_dir + "/scann1.1.1_serialized_reorder"
        if os.path.exists(scann_dir + "/scann_config.pb"):
            searcher_reo = scann_ops_pybind.load_searcher(scann_dir)
        else:
            searcher_reo = scann_make_index(xb, distance_measure, scann_dir, 100)

        scann_eval_search(
            searcher, searcher_reo,
            xq, xb, nprobe_tab, pre_reorder_k_tab, k, gt,
            nrun, args.measure
        )

    if args.lib != "scann" and args.thenscann:
        # just append --lib scann, that will override the previous cmdline
        # options
        cmdline = " ".join(sys.argv) + " --lib scann"
        cmdline = (
            ". ~/anaconda3/etc/profile.d/conda.sh ; " +
            "conda activate scann_1.1.1; "
            "python -u " + cmdline)

        print("running", cmdline)

        os.system(cmdline)


###############################################################
# SCANN
###############################################################

def scann_make_index(xb, distance_measure, scann_dir, reorder_k):
    import scann

    print("build index")

    if distance_measure == "dot_product":
        thr = 0.2
    else:
        thr = 0
    k = 10
    sb = scann.scann_ops_pybind.builder(xb, k, distance_measure)
    sb = sb.tree(num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000)
    sb = sb.score_ah(2, anisotropic_quantization_threshold=thr)

    if reorder_k > 0:
        sb = sb.reorder(reorder_k)

    searcher = sb.build()

    print("done")

    print("write index to", scann_dir)

    os.system(f"rm -rf {scann_dir}; mkdir -p {scann_dir}")
    # os.mkdir(scann_dir)
    searcher.serialize(scann_dir)
    return searcher


def scann_eval_search(
        searcher, searcher_reo,
        xq, xb, nprobe_tab, pre_reorder_k_tab, k, gt,
        nrun, measure):

    # warmup
    for _run in range(5):
        searcher.search_batched(xq)

    for nprobe in nprobe_tab:
        for pre_reorder_k in pre_reorder_k_tab:
            times = []
            for _run in range(nrun):
                if pre_reorder_k == 0:
                    t0 = time.time()
                    I, D = searcher.search_batched(
                        xq, leaves_to_search=nprobe, final_num_neighbors=k
                    )
                    t1 = time.time()
                else:
                    t0 = time.time()
                    I, D = searcher_reo.search_batched(
                        xq, leaves_to_search=nprobe, final_num_neighbors=k,
                        pre_reorder_num_neighbors=pre_reorder_k
                    )
                    t1 = time.time()

                times.append(t1 - t0)
            header = "SCANN nprobe=%4d reo=%4d" % (nprobe, pre_reorder_k)
            if measure == "1-recall":
                eval_recalls(header, I, gt, times)
            else:
                eval_inters(header, I, gt, times)


###############################################################
# Faiss
###############################################################

def faiss_make_index(xb, metric_type, fname):
    import faiss

    d = xb.shape[1]
    M = d // 2
    index = faiss.index_factory(d, f"IVF2000,PQ{M}x4fs", metric_type)
    # if not by_residual:
    #     print("setting no residual")
    #     index.by_residual = False

    print("train")
    # index.train(ds.get_train())
    index.train(xb[:250000])
    print("add")
    index.add(xb)
    print("write index", fname)
    faiss.write_index(index, fname)

    return index


def faiss_eval_search(
        index, xq, xb, nprobe_tab, pre_reorder_k_tab,
        k, gt, nrun, measure
):
    import faiss

    print("use precomputed table=", index.use_precomputed_table,
          "by residual=", index.by_residual)

    print("adding a refine index")
    index_refine = faiss.IndexRefineFlat(index, faiss.swig_ptr(xb))

    print("set single thread")
    faiss.omp_set_num_threads(1)

    print("warmup")
    for _run in range(5):
        index.search(xq, k)

    print("run timing")
    for nprobe in nprobe_tab:
        for pre_reorder_k in pre_reorder_k_tab:
            index.nprobe = nprobe
            times = []
            for _run in range(nrun):
                if pre_reorder_k == 0:
                    t0 = time.time()
                    D, I = index.search(xq, k)
                    t1 = time.time()
                else:
                    index_refine.k_factor = pre_reorder_k / k
                    t0 = time.time()
                    D, I = index_refine.search(xq, k)
                    t1 = time.time()

                times.append(t1 - t0)

            header = "Faiss nprobe=%4d reo=%4d" % (nprobe, pre_reorder_k)
            if measure == "1-recall":
                eval_recalls(header, I, gt, times)
            else:
                eval_inters(header, I, gt, times)


if __name__ == "__main__":
    main()
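A note on the re-ranking setup in faiss_eval_search above: IndexRefineFlat retrieves k * k_factor candidates with the fast-scan index and re-ranks them with exact distances, so setting k_factor = pre_reorder_k / k reproduces SCANN's pre_reorder_num_neighbors. A small sketch with made-up sizes (not from the PR, same API calls as the script):

    import numpy as np
    import faiss

    d = 32
    rs = np.random.RandomState(0)
    xb = rs.rand(20000, d).astype('float32')
    xq = rs.rand(100, d).astype('float32')

    index = faiss.index_factory(d, "IVF256,PQ16x4fs", faiss.METRIC_L2)
    index.train(xb)
    index.add(xb)
    index.nprobe = 16

    k, pre_reorder_k = 10, 100
    # re-rank against the raw vectors, passed as a pointer to avoid a copy
    index_refine = faiss.IndexRefineFlat(index, faiss.swig_ptr(xb))
    index_refine.k_factor = pre_reorder_k / k
    D, I = index_refine.search(xq, k)  # 100 candidates re-ranked, top 10 kept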
================ datasets.py ================

@@ -1,5 +1,3 @@
-#! /usr/bin/env python2
-
 # Copyright (c) Facebook, Inc. and its affiliates.
 #
 # This source code is licensed under the MIT license found in the

@@ -9,168 +7,83 @@
 Common functions to load datasets and compute their ground-truth
 """

-from __future__ import print_function
 import time
 import numpy as np
 import faiss
 import sys

-# set this to the directory that contains the datafiles.
-# deep1b data should be at simdir + 'deep1b'
-# bigann data should be at simdir + 'bigann'
-simdir = '/mnt/vol/gfsai-east/ai-group/datasets/simsearch/'
+from faiss.contrib import datasets as faiss_datasets

-#################################################################
-# Small I/O functions
-#################################################################
+print("path:", faiss_datasets.__file__)

+faiss_datasets.dataset_basedir = '/checkpoint/matthijs/simsearch/'

-def ivecs_read(fname):
-    a = np.fromfile(fname, dtype='int32')
-    d = a[0]
-    return a.reshape(-1, d + 1)[:, 1:].copy()
-
-
-def fvecs_read(fname):
-    return ivecs_read(fname).view('float32')
-
-
-def ivecs_mmap(fname):
-    a = np.memmap(fname, dtype='int32', mode='r')
-    d = a[0]
-    return a.reshape(-1, d + 1)[:, 1:]
-
-
-def fvecs_mmap(fname):
-    return ivecs_mmap(fname).view('float32')
-
-
-def bvecs_mmap(fname):
-    x = np.memmap(fname, dtype='uint8', mode='r')
-    d = x[:4].view('int32')[0]
-    return x.reshape(-1, d + 4)[:, 4:]
-
-
-def ivecs_write(fname, m):
-    n, d = m.shape
-    m1 = np.empty((n, d + 1), dtype='int32')
-    m1[:, 0] = d
-    m1[:, 1:] = m
-    m1.tofile(fname)
-
-
-def fvecs_write(fname, m):
-    m = m.astype('float32')
-    ivecs_write(fname, m.view('int32'))
-
-
-def sanitize(x):
-    return np.ascontiguousarray(x, dtype='float32')


 #################################################################
 # Dataset
 #################################################################

+def sanitize(x):
+    return np.ascontiguousarray(x, dtype='float32')
+
+
+class DatasetCentroids(faiss_datasets.Dataset):
+
+    def __init__(self, ds, indexfile):
+        self.d = ds.d
+        self.metric = ds.metric
+        self.nq = ds.nq
+        self.xq = ds.get_queries()
+
+        # get the xb set
+        src_index = faiss.read_index(indexfile)
+        src_quant = faiss.downcast_index(src_index.quantizer)
+        centroids = faiss.vector_to_array(src_quant.xb)
+        self.xb = centroids.reshape(-1, self.d)
+        self.nb = self.nt = len(self.xb)
+
+    def get_queries(self):
+        return self.xq
+
+    def get_database(self):
+        return self.xb
+
+    def get_train(self, maxtrain=None):
+        return self.xb
+
+    def get_groundtruth(self, k=100):
+        return faiss.knn(
+            self.xq, self.xb, k,
+            faiss.METRIC_L2 if self.metric == 'L2' else faiss.METRIC_INNER_PRODUCT
+        )[1]

-class ResultHeap:
-    """ Combine query results from a sliced dataset """
-
-    def __init__(self, nq, k):
-        " nq: number of query vectors, k: number of results per query "
-        self.I = np.zeros((nq, k), dtype='int64')
-        self.D = np.zeros((nq, k), dtype='float32')
-        self.nq, self.k = nq, k
-        heaps = faiss.float_maxheap_array_t()
-        heaps.k = k
-        heaps.nh = nq
-        heaps.val = faiss.swig_ptr(self.D)
-        heaps.ids = faiss.swig_ptr(self.I)
-        heaps.heapify()
-        self.heaps = heaps
-
-    def add_batch_result(self, D, I, i0):
-        assert D.shape == (self.nq, self.k)
-        assert I.shape == (self.nq, self.k)
-        I += i0
-        self.heaps.addn_with_ids(
-            self.k, faiss.swig_ptr(D),
-            faiss.swig_ptr(I), self.k)
-
-    def finalize(self):
-        self.heaps.reorder()
-
-
-def compute_GT_sliced(xb, xq, k):
-    print("compute GT")
-    t0 = time.time()
-    nb, d = xb.shape
-    nq, d = xq.shape
-    rh = ResultHeap(nq, k)
-    bs = 10 ** 5
-
-    xqs = sanitize(xq)
-
-    db_gt = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(d))
-
-    # compute ground-truth by blocks of bs, and add to heaps
-    for i0 in range(0, nb, bs):
-        i1 = min(nb, i0 + bs)
-        xsl = sanitize(xb[i0:i1])
-        db_gt.add(xsl)
-        D, I = db_gt.search(xqs, k)
-        rh.add_batch_result(D, I, i0)
-        db_gt.reset()
-        print("\r %d/%d, %.3f s" % (i0, nb, time.time() - t0), end=' ')
-        sys.stdout.flush()
-    print()
-    rh.finalize()
-    gt_I = rh.I
-
-    print("GT time: %.3f s" % (time.time() - t0))
-    return gt_I
-
-
-def do_compute_gt(xb, xq, k):
-    print("computing GT")
-    nb, d = xb.shape
-    index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(d))
-    if nb < 100 * 1000:
-        print(" add")
-        index.add(np.ascontiguousarray(xb, dtype='float32'))
-        print(" search")
-        D, I = index.search(np.ascontiguousarray(xq, dtype='float32'), k)
-    else:
-        I = compute_GT_sliced(xb, xq, k)
-
-    return I.astype('int32')


-def load_data(dataset='deep1M', compute_gt=False):
+def load_dataset(dataset='deep1M', compute_gt=False, download=False):

     print("load data", dataset)

     if dataset == 'sift1M':
-        basedir = simdir + 'sift1M/'
-
-        xt = fvecs_read(basedir + "sift_learn.fvecs")
-        xb = fvecs_read(basedir + "sift_base.fvecs")
-        xq = fvecs_read(basedir + "sift_query.fvecs")
-        gt = ivecs_read(basedir + "sift_groundtruth.ivecs")
+        return faiss_datasets.DatasetSIFT1M()

     elif dataset.startswith('bigann'):
-        basedir = simdir + 'bigann/'
-
         dbsize = 1000 if dataset == "bigann1B" else int(dataset[6:-1])
-        xb = bvecs_mmap(basedir + 'bigann_base.bvecs')
-        xq = bvecs_mmap(basedir + 'bigann_query.bvecs')
-        xt = bvecs_mmap(basedir + 'bigann_learn.bvecs')
-        # trim xb to correct size
-        xb = xb[:dbsize * 1000 * 1000]
-        gt = ivecs_read(basedir + 'gnd/idx_%dM.ivecs' % dbsize)
+        return faiss_datasets.DatasetBigANN(nb_M=dbsize)
+
+    elif dataset.startswith("deep_centroids_"):
+        ncent = int(dataset[len("deep_centroids_"):])
+        centdir = "/checkpoint/matthijs/bench_all_ivf/precomputed_clusters"
+        return DatasetCentroids(
+            faiss_datasets.DatasetDeep1B(nb=1000000),
+            f"{centdir}/clustering.dbdeep1M.IVF{ncent}.faissindex"
+        )

     elif dataset.startswith("deep"):
-        basedir = simdir + 'deep1b/'
-
         szsuf = dataset[4:]
         if szsuf[-1] == 'M':
             dbsize = 10 ** 6 * int(szsuf[:-1])

@@ -180,28 +93,17 @@ def load_data(dataset='deep1M', compute_gt=False):
             dbsize = 1000 * int(szsuf[:-1])
         else:
             assert False, "did not recognize suffix " + szsuf
+        return faiss_datasets.DatasetDeep1B(nb=dbsize)

-        xt = fvecs_mmap(basedir + "learn.fvecs")
-        xb = fvecs_mmap(basedir + "base.fvecs")
-        xq = fvecs_read(basedir + "deep1B_queries.fvecs")
-
-        xb = xb[:dbsize]
-
-        gt_fname = basedir + "%s_groundtruth.ivecs" % dataset
-        if compute_gt:
-            gt = do_compute_gt(xb, xq, 100)
-            print("store", gt_fname)
-            ivecs_write(gt_fname, gt)
-
-        gt = ivecs_read(gt_fname)
+    elif dataset == "music-100":
+        return faiss_datasets.DatasetMusic100()
+
+    elif dataset == "glove":
+        return faiss_datasets.DatasetGlove(download=download)

     else:
         assert False

-    print("dataset %s sizes: B %s Q %s T %s" % (
-        dataset, xb.shape, xq.shape, xt.shape))
-
-    return xt, xb, xq, gt


 #################################################################
 # Evaluation
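The dataset plumbing now delegates to faiss.contrib.datasets. For running these benchmarks without access to the internal data paths, the contrib module also ships a synthetic dataset with the same Dataset interface used above (d, nq, metric, get_queries, get_database, get_train, get_groundtruth) — a hedged sketch, assuming SyntheticDataset is available in the installed faiss version:

    from faiss.contrib import datasets as faiss_datasets

    # arguments: d, nt (train size), nb (database size), nq (query count)
    ds = faiss_datasets.SyntheticDataset(64, 10000, 100000, 1000)
    print(ds)
    xq = ds.get_queries()
    xb = ds.get_database()
    xt = ds.get_train()
    gt = ds.get_groundtruth(k=100)
    print(xq.shape, xb.shape, xt.shape, gt.shape)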
================ (Deep1B ground-truth computation script, new file) ================

@@ -0,0 +1,27 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

# https://stackoverflow.com/questions/7016056/python-logging-not-outputting-anything
logging.basicConfig()
logger = logging.getLogger('faiss.contrib.exhaustive_search')
logger.setLevel(logging.INFO)

from faiss.contrib import datasets
from faiss.contrib.exhaustive_search import knn_ground_truth
from faiss.contrib import vecs_io

ds = datasets.DatasetDeep1B(nb=int(1e9))

print("computing GT matches for", ds)

D, I = knn_ground_truth(
    ds.get_queries(),
    ds.database_iterator(bs=65536),
    k=100
)

vecs_io.ivecs_write("/tmp/tt.ivecs", I)
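The script above streams the database from disk, but the same helper works with any iterable of database blocks, as the ds.database_iterator call suggests. A minimal sketch with in-memory synthetic data (illustrative sizes, not from the PR):

    import numpy as np
    from faiss.contrib.exhaustive_search import knn_ground_truth

    d = 32
    rs = np.random.RandomState(0)
    xb = rs.rand(100000, d).astype('float32')
    xq = rs.rand(100, d).astype('float32')

    def blocks(x, bs=16384):
        # yield the database in slices, mimicking database_iterator
        for i0 in range(0, len(x), bs):
            yield x[i0:i0 + bs]

    D, I = knn_ground_truth(xq, blocks(xb), k=10)
    print(I.shape)  # (100, 10)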
================ (log parsing and plotting script) ================

@@ -1,54 +1,18 @@
-#! /usr/bin/env python2
-
 # Copyright (c) Facebook, Inc. and its affiliates.
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from __future__ import print_function
 import os
 import numpy as np
 from collections import defaultdict
 from matplotlib import pyplot

 import re

 from argparse import Namespace


-# the directory used in run_on_cluster.bash
-basedir = '/mnt/vol/gfsai-east/ai-group/users/matthijs/bench_all_ivf/'
-logdir = basedir + 'logs/'
-
-
-# which plot to output
-db = 'bigann1B'
-code_size = 8
-
-
-def unitsize(indexkey):
-    """ size of one vector in the index """
-    mo = re.match('.*,PQ(\\d+)', indexkey)
-    if mo:
-        return int(mo.group(1))
-    if indexkey.endswith('SQ8'):
-        bits_per_d = 8
-    elif indexkey.endswith('SQ4'):
-        bits_per_d = 4
-    elif indexkey.endswith('SQfp16'):
-        bits_per_d = 16
-    else:
-        assert False
-    mo = re.match('PCAR(\\d+),.*', indexkey)
-    if mo:
-        return bits_per_d * int(mo.group(1)) / 8
-    mo = re.match('OPQ\\d+_(\\d+),.*', indexkey)
-    if mo:
-        return bits_per_d * int(mo.group(1)) / 8
-    mo = re.match('RR(\\d+),.*', indexkey)
-    if mo:
-        return bits_per_d * int(mo.group(1)) / 8
-    assert False
+from faiss.contrib.factory_tools import get_code_size as unitsize


 def dbsize_from_name(dbname):

@@ -84,10 +48,20 @@ def parse_result_file(fname):
     keys = []
     stats = {}
     stats['run_version'] = fname[-8]
+    indexkey = None
     for l in open(fname):
-        if st == 0:
-            if l.startswith('CHRONOS_JOB_INSTANCE_ID'):
-                stats['CHRONOS_JOB_INSTANCE_ID'] = l.split()[-1]
+        if l.startswith('CHRONOS_JOB_INSTANCE_ID'):
+            stats['CHRONOS_JOB_INSTANCE_ID'] = l.split()[-1]
+        if l.startswith("srun:"):
+            # looks like a crash...
+            if indexkey is None:
+                raise RuntimeError("instant crash")
+            break
+        elif st == 0:
+            if l.startswith("dataset in dimension"):
+                fi = l.split()
+                stats["d"] = int(fi[3][:-1])
+                stats["nq"] = int(fi[9])
+                stats["nb"] = int(fi[11])
+                stats["nt"] = int(fi[13])
             if l.startswith('index size on disk:'):
                 stats['index_size'] = int(l.split()[-1])
             if l.startswith('current RSS:'):

@@ -101,7 +75,25 @@ def parse_result_file(fname):
             if l.startswith('args:'):
                 args = eval(l[l.find(' '):])
                 indexkey = args.indexkey
-            elif 'R@1 R@10 R@100' in l:
+            elif "time(ms/q)" in l:
+                # result header
+                if 'R@1 R@10 R@100' in l:
+                    stats["measure"] = "recall"
+                    stats["ranks"] = [1, 10, 100]
+                elif 'I@1 I@10 I@100' in l:
+                    stats["measure"] = "inter"
+                    stats["ranks"] = [1, 10, 100]
+                elif 'inter@' in l:
+                    stats["measure"] = "inter"
+                    fi = l.split()
+                    if fi[1] == "inter@":
+                        rank = int(fi[2])
+                    else:
+                        rank = int(fi[1][len("inter@"):])
+                    stats["ranks"] = [rank]
+                else:
+                    assert False
                 st = 1
             elif 'index size on disk:' in l:
                 index_size = int(l.split()[-1])

@@ -109,115 +101,106 @@ def parse_result_file(fname):
             st = 2
         elif st == 2:
             fi = l.split()
+            if l[0] == " ":
+                # means there are 0 parameters
+                fi = [""] + fi
             keys.append(fi[0])
             res.append([float(x) for x in fi[1:]])
     return indexkey, np.array(res), keys, stats


-# run parsing
-allres = {}
-allstats = {}
-nts = []
-missing = []
-versions = {}
-
-fnames = keep_latest_stdout(os.listdir(logdir))
-# print fnames
-# filenames are in the form <key>.x.stdout
-# where x is a version number (from a to z)
-# keep only latest version of each name
-
-for fname in fnames:
-    if not ('db' + db in fname and fname.endswith('.stdout')):
-        continue
-    indexkey, res, _, stats = parse_result_file(logdir + fname)
-    if res.size == 0:
-        missing.append(fname)
-        errorline = open(
-            logdir + fname.replace('.stdout', '.stderr')).readlines()
-        if len(errorline) > 0:
-            errorline = errorline[-1]
-        else:
-            errorline = 'NO STDERR'
-        print(fname, stats['CHRONOS_JOB_INSTANCE_ID'], errorline)
-    else:
-        if indexkey in allres:
-            if allstats[indexkey]['run_version'] > stats['run_version']:
-                # don't use this run
-                continue
-        n_threads = stats.get('n_threads', 1)
-        nts.append(n_threads)
-        allres[indexkey] = res
-        allstats[indexkey] = stats
-
-assert len(set(nts)) == 1
-n_threads = nts[0]
+# the directory used in run_on_cluster.bash
+basedir = "/checkpoint/matthijs/bench_all_ivf/"
+logdir = basedir + 'logs/'
+
+
+def collect_results_for(db='deep1M', prefix="autotune."):
+    # run parsing
+    allres = {}
+    allstats = {}
+    missing = []
+
+    fnames = keep_latest_stdout(os.listdir(logdir))
+    # print fnames
+    # filenames are in the form <key>.x.stdout
+    # where x is a version number (from a to z)
+    # keep only latest version of each name
+
+    for fname in fnames:
+        if not (
+                'db' + db in fname and
+                fname.startswith(prefix) and
+                fname.endswith('.stdout')
+        ):
+            continue
+        print("parse", fname, end=" ", flush=True)
+        try:
+            indexkey, res, _, stats = parse_result_file(logdir + fname)
+        except RuntimeError as e:
+            print("FAIL %s" % e)
+            res = np.zeros((2, 0))
+        except Exception as e:
+            print("PARSE ERROR %s" % e)
+            res = np.zeros((2, 0))
+        print(len(res), "results")
+        if res.size == 0:
+            missing.append(fname)
+        else:
+            if indexkey in allres:
+                if allstats[indexkey]['run_version'] > stats['run_version']:
+                    # don't use this run
+                    continue
+            allres[indexkey] = res
+            allstats[indexkey] = stats
+
+    return allres, allstats


-def plot_tradeoffs(allres, code_size, recall_rank):
-    dbsize = dbsize_from_name(db)
-    recall_idx = int(np.log10(recall_rank))
+def extract_pareto_optimal(allres, keys, recall_idx=0, times_idx=3):
     bigtab = []
-    names = []
-
-    for k, v in sorted(allres.items()):
-        if v.ndim != 2: continue
-        us = unitsize(k)
-        if us != code_size: continue
+    for i, k in enumerate(keys):
+        v = allres[k]
         perf = v[:, recall_idx]
-        times = v[:, 3]
+        times = v[:, times_idx]
         bigtab.append(
             np.vstack((
-                np.ones(times.size, dtype=int) * len(names),
+                np.ones(times.size) * i,
                 perf, times
             ))
         )
-        names.append(k)
+    if bigtab == []:
+        return [], np.zeros((3, 0))

     bigtab = np.hstack(bigtab)

     # sort by perf
     perm = np.argsort(bigtab[1, :])
-    bigtab = bigtab[:, perm]
+    bigtab_sorted = bigtab[:, perm]
+    best_times = np.minimum.accumulate(bigtab_sorted[2, ::-1])[::-1]
+    selection, = np.where(bigtab_sorted[2, :] == best_times)
+    selected_keys = [
+        keys[i] for i in
+        np.unique(bigtab_sorted[0, selection].astype(int))
+    ]
+    ops = bigtab_sorted[:, selection]

-    times = np.minimum.accumulate(bigtab[2, ::-1])[::-1]
-    selection = np.where(bigtab[2, :] == times)
+    return selected_keys, ops

-    selected_methods = [names[i] for i in
-                        np.unique(bigtab[0, selection].astype(int))]
-    not_selected = list(set(names) - set(selected_methods))
-
-    print("methods without an optimal OP: ", not_selected)
-
-    nq = 10000
-    pyplot.title('database ' + db + ' code_size=%d' % code_size)
-
-    # grayed out lines
-
-    for k in not_selected:
-        v = allres[k]
-        if v.ndim != 2: continue
-        us = unitsize(k)
-        if us != code_size: continue
-
-        linestyle = (':' if 'PQ' in k else
-                     '-.' if 'SQ4' in k else
-                     '--' if 'SQ8' in k else '-')
-
-        pyplot.semilogy(v[:, recall_idx], v[:, 3], label=None,
-                        linestyle=linestyle,
-                        marker='o' if 'HNSW' in k else '+',
-                        color='#cccccc', linewidth=0.2)

+def plot_subset(
+        allres, allstats, selected_methods, recall_idx, times_idx=3,
+        report=["overhead", "build time"]):
+
     # important methods
     for k in selected_methods:
         v = allres[k]

         stats = allstats[k]
-        tot_size = stats['index_size'] + stats['tables_size']
+        d = stats["d"]
+        dbsize = stats["nb"]
+        if "index_size" in stats and "tables_size" in stats:
+            tot_size = stats['index_size'] + stats['tables_size']
+        else:
+            tot_size = -1
         id_size = 8  # 64 bit

         addt = ''

@@ -230,18 +213,107 @@ def plot_tradeoffs(allres, code_size, recall_rank):
             add_sec = int(add_time)
             addt = ', %dm%02d' % (add_sec / 60, add_sec % 60)

+        code_size = unitsize(d, k)

-        label = k + ' (size+%.1f%%%s)' % (
-            tot_size / float((code_size + id_size) * dbsize) * 100 - 100,
-            addt)
+        label = k

+        if "code_size" in report:
+            label += " %d bytes" % code_size

+        tight_size = (code_size + id_size) * dbsize

+        if tot_size < 0 or "overhead" not in report:
+            pass  # don't know what the index size is
+        elif tot_size > 10 * tight_size:
+            label += " overhead x%.1f" % (tot_size / tight_size)
+        else:
+            label += " overhead+%.1f%%" % (
+                tot_size / tight_size * 100 - 100)

+        if "build time" in report:
+            label += " " + addt

-        linestyle = (':' if 'PQ' in k else
-                     '-.' if 'SQ4' in k else
-                     '--' if 'SQ8' in k else '-')
+        linestyle = (':' if 'Refine' in k or 'RFlat' in k else
+                     '-.' if 'SQ' in k else
+                     '-' if '4fs' in k else
+                     '-')
+        print(k, linestyle)

-        pyplot.semilogy(v[:, recall_idx], v[:, 3], label=label,
-                        linestyle=linestyle,
-                        marker='o' if 'HNSW' in k else '+')
+        pyplot.semilogy(v[:, recall_idx], 1000 / v[:, times_idx], label=label,
+                        linestyle=linestyle,
+                        marker='o' if '4fs' in k else '+')

+    recall_rank = stats["ranks"][recall_idx]
+    if stats["measure"] == "recall":
+        pyplot.xlabel('1-recall at %d' % recall_rank)
+    elif stats["measure"] == "inter":
+        pyplot.xlabel('inter @ %d' % recall_rank)
+    else:
+        assert False
+    pyplot.ylabel('QPS (%d threads)' % stats["n_threads"])


+def plot_tradeoffs(db, allres, allstats, code_size, recall_rank):
+    stat0 = next(iter(allstats.values()))
+    d = stat0["d"]
+    n_threads = stat0["n_threads"]
+    nq = stat0["nq"]
+    recall_idx = stat0["ranks"].index(recall_rank)
+    # times come after the perf measures
+    times_idx = len(stat0["ranks"])
+
+    if type(code_size) == int:
+        if code_size == 0:
+            code_size = [0, 1e50]
+            code_size_name = "any code size"
+        else:
+            code_size_name = "code_size=%d" % code_size
+            code_size = [code_size, code_size]
+    elif type(code_size) == tuple:
+        code_size_name = "code_size in [%d, %d]" % code_size
+    else:
+        assert False
+
+    names_maxperf = []
+
+    for k in sorted(allres):
+        v = allres[k]
+        if v.ndim != 2: continue
+        us = unitsize(d, k)
+        if not code_size[0] <= us <= code_size[1]: continue
+        names_maxperf.append((v[-1, recall_idx], k))
+
+    # sort from lowest to highest topline accuracy
+    names_maxperf.sort()
+    names = [name for mp, name in names_maxperf]
+
+    selected_methods, optimal_points = \
+        extract_pareto_optimal(allres, names, recall_idx, times_idx)
+
+    not_selected = list(set(names) - set(selected_methods))
+
+    print("methods without an optimal OP: ", not_selected)
+
+    pyplot.title('database ' + db + ' ' + code_size_name)
+
+    # grayed out lines
+
+    for k in not_selected:
+        v = allres[k]
+        if v.ndim != 2: continue
+        us = unitsize(d, k)
+        if not code_size[0] <= us <= code_size[1]: continue
+
+        linestyle = (':' if 'PQ' in k else
+                     '-.' if 'SQ4' in k else
+                     '--' if 'SQ8' in k else '-')
+
+        pyplot.semilogy(v[:, recall_idx], 1000 / v[:, times_idx], label=None,
+                        linestyle=linestyle,
+                        marker='o' if 'HNSW' in k else '+',
+                        color='#cccccc', linewidth=0.2)
+
+    plot_subset(allres, allstats, selected_methods, recall_idx, times_idx)

     if len(not_selected) == 0:
         om = ''

@@ -255,15 +327,175 @@ def plot_tradeoffs(allres, code_size, recall_rank):
             om += ' ' + m
             nc += len(m) + 1

+    # pyplot.semilogy(optimal_points[1, :], optimal_points[2, :], marker="s")
+    # print(optimal_points[0, :])
     pyplot.xlabel('1-recall at %d %s' % (recall_rank, om))
-    pyplot.ylabel('search time per query (ms, %d threads)' % n_threads)
+    pyplot.ylabel('QPS (%d threads)' % n_threads)
     pyplot.legend()
     pyplot.grid()
-    pyplot.savefig('figs/tradeoffs_%s_cs%d_r%d.png' % (
-        db, code_size, recall_rank))
+    return selected_methods, not_selected


-pyplot.gcf().set_size_inches(15, 10)
-
-plot_tradeoffs(allres, code_size=code_size, recall_rank=1)
+if __name__ == "__main__xx":
+    # tests on centroids indexing (v1)
+
+    for k in 1, 32, 128:
+        pyplot.gcf().set_size_inches(15, 10)
+        i = 1
+        for ncent in 65536, 262144, 1048576, 4194304:
+            db = f'deep_centroids_{ncent}.k{k}.'
+            allres, allstats = collect_results_for(
+                db=db, prefix="cent_index.")
+
+            pyplot.subplot(2, 2, i)
+            plot_subset(
+                allres, allstats, list(allres.keys()),
+                recall_idx=0,
+                times_idx=1,
+                report=["code_size"]
+            )
+            i += 1
+            pyplot.title(f"{ncent} centroids")
+            pyplot.legend()
+            pyplot.xlim([0.95, 1])
+            pyplot.grid()
+
+        pyplot.savefig('figs/deep1B_centroids_k%d.png' % k)
+
+
+if __name__ == "__main__xx":
+    # centroids plot per k
+
+    pyplot.gcf().set_size_inches(15, 10)
+
+    i = 1
+    for ncent in 65536, 262144, 1048576, 4194304:
+
+        xyd = defaultdict(list)
+
+        for k in 1, 4, 8, 16, 32, 64, 128, 256:
+
+            db = f'deep_centroids_{ncent}.k{k}.'
+            allres, allstats = collect_results_for(db=db, prefix="cent_index.")
+
+            for indexkey, res in allres.items():
+                idx, = np.where(res[:, 0] >= 0.99)
+                if idx.size > 0:
+                    xyd[indexkey].append((k, 1000 / res[idx[0], 1]))
+
+        pyplot.subplot(2, 2, i)
+        i += 1
+        for indexkey, xy in xyd.items():
+            xy = np.array(xy)
+            pyplot.loglog(xy[:, 0], xy[:, 1], 'o-', label=indexkey)
+
+        pyplot.title(f"{ncent} centroids")
+        pyplot.xlabel("k")
+        xt = 2**np.arange(9)
+        pyplot.xticks(xt, ["%d" % x for x in xt])
+        pyplot.ylabel("QPS (32 threads)")
+        pyplot.legend()
+        pyplot.grid()
+
+    pyplot.savefig('../plots/deep1B_centroids_min99.png')
+
+
+if __name__ == "__main__xx":
+    # main indexing plots
+
+    i = 0
+    for db in 'bigann10M', 'deep10M', 'bigann100M', 'deep100M', 'deep1B', 'bigann1B':
+        allres, allstats = collect_results_for(
+            db=db, prefix="autotune.")
+
+        for cs in 8, 16, 32, 64:
+            pyplot.figure(i)
+            i += 1
+            pyplot.gcf().set_size_inches(15, 10)
+
+            cs_range = (
+                (0, 8) if cs == 8 else (cs // 2 + 1, cs)
+            )
+
+            plot_tradeoffs(
+                db, allres, allstats, code_size=cs_range, recall_rank=1)
+            pyplot.savefig('../plots/tradeoffs_%s_cs%d_r1.png' % (
+                db, cs))
+
+
+if __name__ == "__main__":
+    # 1M indexes
+    i = 0
+    for db in "glove", "music-100":
+        pyplot.figure(i)
+        pyplot.gcf().set_size_inches(15, 10)
+        i += 1
+        allres, allstats = collect_results_for(db=db, prefix="autotune.")
+        plot_tradeoffs(db, allres, allstats, code_size=0, recall_rank=1)
+        pyplot.savefig('../plots/1M_tradeoffs_' + db + ".png")
+
+    for db in "sift1M", "deep1M":
+        allres, allstats = collect_results_for(db=db, prefix="autotune.")
+        pyplot.figure(i)
+        pyplot.gcf().set_size_inches(15, 10)
+        i += 1
+        plot_tradeoffs(db, allres, allstats, code_size=(0, 64), recall_rank=1)
+        pyplot.savefig('../plots/1M_tradeoffs_' + db + "_small.png")
+
+        pyplot.figure(i)
+        pyplot.gcf().set_size_inches(15, 10)
+        i += 1
+        plot_tradeoffs(db, allres, allstats, code_size=(65, 10000), recall_rank=1)
+        pyplot.savefig('../plots/1M_tradeoffs_' + db + "_large.png")
+
+
+if __name__ == "__main__xx":
+    db = 'sift1M'
+    allres, allstats = collect_results_for(db=db, prefix="autotune.")
+    pyplot.gcf().set_size_inches(15, 10)
+
+    keys = [
+        "IVF1024,PQ32x8",
+        "IVF1024,PQ64x4",
+        "IVF1024,PQ64x4fs",
+        "IVF1024,PQ64x4fsr",
+        "IVF1024,SQ4",
+        "IVF1024,SQ8"
+    ]
+
+    plot_subset(allres, allstats, keys, recall_idx=0, report=["code_size"])
+
+    pyplot.legend()
+    pyplot.title(db)
+    pyplot.xlabel("1-recall@1")
+    pyplot.ylabel("QPS (32 threads)")
+    pyplot.grid()
+
+    pyplot.savefig('../plots/ivf1024_variants.png')
+
+    pyplot.figure(2)
+    pyplot.gcf().set_size_inches(15, 10)
+
+    keys = [
+        "HNSW32",
+        "IVF1024,PQ64x4fs",
+        "IVF1024,PQ64x4fsr",
+        "IVF1024,PQ64x4fs,RFlat",
+        "IVF1024,PQ64x4fs,Refine(SQfp16)",
+        "IVF1024,PQ64x4fs,Refine(SQ8)",
+    ]
+
+    plot_subset(allres, allstats, keys, recall_idx=0, report=["code_size"])
+
+    pyplot.legend()
+    pyplot.title(db)
+    pyplot.xlabel("1-recall@1")
+    pyplot.ylabel("QPS (32 threads)")
+    pyplot.grid()
+
+    pyplot.savefig('../plots/ivf1024_rerank.png')
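To make the selection rule in extract_pareto_optimal above concrete, here is a toy run of the same np.minimum.accumulate trick (illustrative only, not from the PR): an operating point survives only if no other point reaches at least its accuracy in less time.

    import numpy as np

    # columns: (accuracy, time); (0.5, 9.0) is dominated by (0.6, 2.0)
    pts = np.array([[0.2, 1.0], [0.5, 9.0], [0.6, 2.0], [0.9, 5.0]])
    order = np.argsort(pts[:, 0])                  # sort by accuracy
    t = pts[order, 1]
    best_t = np.minimum.accumulate(t[::-1])[::-1]  # best time at >= this accuracy
    keep = t == best_t
    print(pts[order][keep])                        # [[0.2 1.] [0.6 2.] [0.9 5.]]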
@ -1,3 +1,5 @@
|
|||
set -e
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
|
@ -18,95 +20,452 @@
|
|||
#
|
||||
# the stdout of the command should be stored in $logdir/<name>.stdout
|
||||
|
||||
function run_on_1machine () {
|
||||
# To be implemented
|
||||
|
||||
function run_on ()
|
||||
{
|
||||
sys="$1"
|
||||
shift
|
||||
name="$1"
|
||||
shift
|
||||
script="$logdir/$name.sh"
|
||||
|
||||
if [ -e "$script" ]; then
|
||||
echo script "$script" exists
|
||||
return
|
||||
fi
|
||||
|
||||
# srun handles special characters fine, but the shell interpreter
|
||||
# does not
|
||||
escaped_cmd=$( printf "%q " "$@" )
|
||||
|
||||
cat > $script <<EOF
|
||||
#! /bin/bash
|
||||
srun $escaped_cmd
|
||||
EOF
|
||||
|
||||
echo -n "$logdir/$name.stdout "
|
||||
sbatch -n1 -J "$name" \
|
||||
$sys \
|
||||
--comment='priority is the only one that works' \
|
||||
--output="$logdir/$name.stdout" \
|
||||
"$script"
|
||||
|
||||
}
|
||||
|
||||
|
||||
function run_on_1machine {
|
||||
run_on "--cpus-per-task=80 --gres=gpu:0 --mem=500G --time=70:00:00 --partition=priority" "$@"
|
||||
}
|
||||
|
||||
function run_on_1machine_1h {
|
||||
run_on "--cpus-per-task=80 --gres=gpu:2 --mem=100G --time=1:00:00 --partition=priority" "$@"
|
||||
}
|
||||
|
||||
function run_on_1machine_3h {
|
||||
run_on "--cpus-per-task=80 --gres=gpu:2 --mem=100G --time=3:00:00 --partition=priority" "$@"
|
||||
}
|
||||
|
||||
function run_on_4gpu_3h {
|
||||
run_on "--cpus-per-task=40 --gres=gpu:4 --mem=100G --time=3:00:00 --partition=priority" "$@"
|
||||
}
|
||||
|
||||
function run_on_8gpu () {
|
||||
# To be implemented
|
||||
run_on "--cpus-per-task=80 --gres=gpu:8 --mem=100G --time=70:00:00 --partition=priority" "$@"
|
||||
}
|
||||
|
||||
|
||||
# prepare output directories
|
||||
# set to some directory where all indexes, can be written.
|
||||
basedir=XXXXX
|
||||
basedir=/checkpoint/matthijs/bench_all_ivf
|
||||
|
||||
logdir=$basedir/logs
|
||||
indexdir=$basedir/indexes
|
||||
centdir=$basedir/precomputed_clusters
|
||||
|
||||
mkdir -p $lars $logdir $indexdir
|
||||
mkdir -p $logdir $indexdir
|
||||
|
||||
|
||||
############################### 1M experiments
|
||||
# adds an option to use a pretrained quantizer
|
||||
function add_precomputed_quantizer () {
|
||||
local db="$1"
|
||||
local coarse="$2"
|
||||
|
||||
for db in sift1M deep1M bigann1M; do
|
||||
case $db in
|
||||
bigann*) rname=bigann ;;
|
||||
deep*) rname=deep ;;
|
||||
sift1M) return;;
|
||||
music-100) return ;;
|
||||
glove) return ;;
|
||||
*) echo "bad db"; exit 1;;
|
||||
esac
|
||||
|
||||
for coarse in IMI2x9 IMI2x10 IVF1024_HNSW32 IVF4096_HNSW32 IVF16384_HNSW32
|
||||
case $coarse in
|
||||
IVF65536*)
|
||||
cname=clustering.db${rname}1M.IVF65536.faissindex
|
||||
copt="--get_centroids_from $centdir/$cname"
|
||||
;;
|
||||
IVF262144*)
|
||||
cname=clustering.db${rname}1M.IVF262144.faissindex
|
||||
copt="--get_centroids_from $centdir/$cname"
|
||||
;;
|
||||
IVF1048576*)
|
||||
cname=clustering.db${rname}1M.IVF1048576.faissindex
|
||||
copt="--get_centroids_from $centdir/$cname"
|
||||
;;
|
||||
IVF4194304*)
|
||||
cname=clustering.db${rname}1M.IVF4194304.faissindex
|
||||
copt="--get_centroids_from $centdir/$cname"
|
||||
;;
|
||||
*)
|
||||
copt="" ;;
|
||||
esac
|
||||
|
||||
echo $copt
|
||||
}
|
||||
|
||||
function get_db_dim () {
|
||||
local db="$1"
|
||||
case $db in
|
||||
sift1M) dim=128;;
|
||||
bigann*) dim=128;;
|
||||
deep*) dim=96;;
|
||||
music-100) dim=100;;
|
||||
glove) dim=100;;
|
||||
*) echo "bad db"; exit 1;;
|
||||
esac
|
||||
echo $dim
|
||||
}
|
||||
|
||||
|
||||
# replace HD = half dim with the half of the dimension we need to handle
|
||||
# relying that variables are global by default...
|
||||
function replace_coarse_PQHD () {
|
||||
local coarse="$1"
|
||||
local dim=$2
|
||||
|
||||
|
||||
coarseD=${coarse//PQHD/PQ$((dim/2))}
|
||||
coarse16=${coarse//PQHD/PQ8}
|
||||
coarse32=${coarse//PQHD/PQ16}
|
||||
coarse64=${coarse//PQHD/PQ32}
|
||||
coarse128=${coarse//PQHD/PQ64}
|
||||
coarse256=${coarse//PQHD/PQ128}
|
||||
coarse112=${coarse//PQHD/PQ56}
|
||||
|
||||
}


if false; then


###############################################
# comparison with SCANN

for db in sift1M deep1M glove music-100
do
    opt=""
    if [ $db == glove ]; then
        opt="--measure inter"
    fi

    run_on_1machine_1h cmp_with_scann.$db.c \
        python -u cmp_with_scann.py --db $db \
            --lib faiss $opt --thenscann

done


############################### Preliminary SIFT1M experiment

for db in sift1M ; do

    for coarse in IVF1024
    do
        indexkeys="
            HNSW32
            $coarse,SQfp16
            $coarse,SQ4
            $coarse,SQ8
            $coarse,PQ32x8
            $coarse,PQ64x4
            $coarse,PQ64x4fs
            $coarse,PQ64x4fs,RFlat
            $coarse,PQ64x4fs,Refine(SQfp16)
            $coarse,PQ64x4fs,Refine(SQ8)
            OPQ64,$coarse,PQ64x4fs
            OPQ64,$coarse,PQ64x4fs,RFlat
        "
        indexkeys="
            $coarse,PQ64x4fsr
            $coarse,PQ64x4fsr,RFlat
        "

        for indexkey in \
            OPQ8_64,$coarse,PQ8 \
            PCAR16,$coarse,SQ4 \
            OPQ16_64,$coarse,PQ16 \
            PCAR32,$coarse,SQ4 \
            PCAR16,$coarse,SQ8 \
            OPQ32_128,$coarse,PQ32 \
            PCAR64,$coarse,SQ4 \
            PCAR32,$coarse,SQ8 \
            PCAR16,$coarse,SQfp16 \
            PCAR64,$coarse,SQ8 \
            PCAR32,$coarse,SQfp16 \
            PCAR128,$coarse,SQ4
        # OPQ actually degrades the results on SIFT1M, so let's ignore it

        for indexkey in $indexkeys
        do
            key=autotune.db$db.${indexkey//,/_}
            run_on_1machine $key \
            # escape nasty characters
            key="autotune.db$db.${indexkey//,/_}"
            key="${key//(/_}"
            key="${key//)/_}"
            run_on_1machine_1h $key.a \
                python -u bench_all_ivf.py \
                    --db $db \
                    --indexkey $indexkey \
                    --indexkey "$indexkey" \
                    --maxtrain 0 \
                    --indexfile $indexdir/$key.faissindex
                    --indexfile $indexdir/$key.faissindex \
                    --searchthreads 32
        done
    done
done


############################### 1M experiments

fi
# for db in sift1M deep1M music-100 glove; do

for db in glove music-100; do

    dim=$( get_db_dim $db )

    for coarse in IVF1024 IVF4096_HNSW32
    do

        replace_coarse_PQHD "$coarse" $dim

        indexkeys="
            $coarseD,PQ$((dim/2))x4fs
            $coarseD,PQ$((dim/2))x4fsr

            OPQ8_64,$coarse64,PQ8
            PCAR16,$coarse16,SQ4
            OPQ16_64,$coarse64,PQ16x4fs
            OPQ16_64,$coarse64,PQ16x4fsr

            OPQ16_64,$coarse64,PQ16
            PCAR16,$coarse16,SQ8
            PCAR32,$coarse32,SQ4
            OPQ32_64,$coarse64,PQ32x4fs
            OPQ32_64,$coarse64,PQ32x4fsr

            OPQ32_128,$coarse128,PQ32
            PCAR32,$coarse32,SQ8
            PCAR64,$coarse64,SQ4
            PCAR16,$coarse16,SQfp16
            OPQ64_128,$coarse128,PQ64x4fs
            OPQ64_128,$coarse128,PQ64x4fsr

            OPQ64_128,$coarse128,PQ64
            PCAR64,$coarse64,SQ8
            PCAR32,$coarse32,SQfp16
            PCAR128,$coarse128,SQ4
            OPQ128_256,$coarse256,PQ128x4fs
            OPQ128_256,$coarse256,PQ128x4fsr
            OPQ16_64,$coarse64,PQ16x4fs,Refine(OPQ56_112,PQ56)
            OPQ16_64,$coarse64,PQ16x4fs,Refine(PCAR72,SQ6)
            OPQ32_64,$coarse64,PQ16x4fs,Refine(PCAR64,SQ6)
            OPQ32_64,$coarse64,PQ32x4fs,Refine(OPQ48_96,PQ48)
            OPQ64_128,$coarse,PQ64x12

            OPQ64_128,$coarse,PQ64x4fs,RFlat
            OPQ64_128,$coarse,PQ64x4fs,Refine(SQfp16)
            OPQ64_128,$coarse,PQ64x4fs,Refine(SQ8)
            OPQ64_128,$coarse,PQ64x4fs,Refine(SQ6)
            OPQ64_128,$coarse,PQ64x4fs,Refine(SQ4)
            OPQ32_64,$coarse,PQ32x4fs,Refine(SQfp16)
            OPQ32_64,$coarse,PQ32x4fs,Refine(SQ8)
            OPQ32_64,$coarse,PQ32x4fs,Refine(SQ6)
            OPQ32_64,$coarse,PQ32x4fs,Refine(SQ4)
        "

        indexkeys="
            $coarseD,PQ$((dim/2))x4fs
            $coarseD,PQ$((dim/2))x4fsr
            $coarseD,PQ$((dim/2))x4fsr,RFlat
            $coarseD,PQ$((dim/2))x4fsr,Refine(SQfp16)
            $coarseD,PQ$((dim/2))x4fsr,Refine(SQ8)
            $coarseD,PQ$((dim/4))x4fs
            $coarseD,PQ$((dim/4))x4fsr
            $coarseD,PQ$((dim/4))x4fsr,RFlat
            $coarseD,PQ$((dim/4))x4fsr,Refine(SQfp16)
            $coarseD,PQ$((dim/4))x4fsr,Refine(SQ8)
            $coarseD,PQ$((dim/2))
            $coarseD,PQ$((dim/4))
            HNSW32,Flat
        "

        indexkeys="HNSW32,Flat"

        for indexkey in $indexkeys
        do
            key=autotune.db$db.${indexkey//,/_}
            key="${key//(/_}"
            key="${key//)/_}"
            run_on_1machine_3h $key.q \
                python -u bench_all_ivf.py \
                    --db $db \
                    --indexkey "$indexkey" \
                    --maxtrain 0 \
                    --indexfile "$indexdir/$key.faissindex" \
                    $( add_precomputed_quantizer $db $coarse ) \
                    --searchthreads 32 \
                    --min_test_duration 3
        done

    done
done


if false; then

############################################
# precompute centroids on GPU for large vocabularies

for db in deep1M bigann1M; do

    for ncent in 262144 65536 1048576 4194304; do

        key=clustering.db$db.IVF$ncent
        run_on_4gpu_3h $key.e \
            python -u bench_all_ivf.py \
                --db $db \
                --indexkey IVF$ncent,SQ8 \
                --maxtrain 100000000 \
                --indexfile $centdir/$key.faissindex \
                --searchthreads 32 \
                --min_test_duration 3 \
                --add_bs 1000000 \
                --train_on_gpu

    done
done


###############################
## coarse quantizer experiments on the centroids of deep1B

for k in 4 8 16 64 256; do

    for ncent in 65536 262144 1048576 4194304; do
        db=deep_centroids_$ncent
        # compute sncent = smallest power of 2 whose square reaches ncent
        # (i.e. roughly the square root of ncent)
        for(( ls=0; ncent > (1 << (2 * ls)); ls++)); do
            echo -n
        done
        sncent=$(( 1 << ls ))
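
For reference, the same computation restated as a short Python sketch
(illustrative only, not part of the script):

    ncent = 262144
    ls = 0
    while ncent > (1 << (2 * ls)):  # stop at the first ls with 4**ls >= ncent
        ls += 1
    sncent = 1 << ls                # here 512, i.e. sqrt(262144)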

        indexkeys="
            IVF$((sncent/2)),PQ48x4fs,RFlat
            IVF$((sncent*2)),PQ48x4fs,RFlat
            HNSW32
            PQ48x4fs
            PQ48x4fs,RFlat
            IVF$sncent,PQ48x4fs,RFlat
        "

        for indexkey in $indexkeys; do
            key="cent_index.db$db.k$k.$indexkey"
            run_on_1machine_1h "$key.b" \
                python -u bench_all_ivf.py \
                    --db $db \
                    --indexkey "$indexkey" \
                    --maxtrain 0 \
                    --inter \
                    --searchthreads 32 \
                    --k $k
        done

    done
done


############################### 10M experiments

for db in deep10M bigann10M; do

    for coarse in \
        IMI2x10 IMI2x11 IMI2x12 IMI2x13 IVF4096_HNSW32 \
        IVF16384_HNSW32 IVF65536_HNSW32 IVF262144_HNSW32
    coarses="
        IVF65536(IVF256,PQHDx4fs,RFlat)
        IVF16384_HNSW32
        IVF65536_HNSW32
        IVF262144_HNSW32
        IVF262144(IVF512,PQHDx4fs,RFlat)
    "

    dim=$( get_db_dim $db )

    for coarse in $coarses
    do

        for indexkey in \
            OPQ8_64,$coarse,PQ8 \
            PCAR16,$coarse,SQ4 \
            OPQ16_64,$coarse,PQ16 \
            PCAR32,$coarse,SQ4 \
            PCAR16,$coarse,SQ8 \
            OPQ32_128,$coarse,PQ32 \
            PCAR64,$coarse,SQ4 \
            PCAR32,$coarse,SQ8 \
            PCAR16,$coarse,SQfp16 \
            PCAR64,$coarse,SQ8 \
            PCAR32,$coarse,SQfp16 \
            PCAR128,$coarse,SQ4 \
            OPQ64_128,$coarse,PQ64
        replace_coarse_PQHD "$coarse" $dim

        indexkeys="
            $coarseD,PQ$((dim/2))x4fs

            OPQ8_64,$coarse64,PQ8
            PCAR16,$coarse16,SQ4
            OPQ16_64,$coarse64,PQ16x4fs
            OPQ16_64,$coarse64,PQ16x4fsr

            OPQ16_64,$coarse64,PQ16
            PCAR16,$coarse16,SQ8
            PCAR32,$coarse32,SQ4
            OPQ32_64,$coarse64,PQ32x4fs
            OPQ32_64,$coarse64,PQ32x4fsr

            OPQ32_128,$coarse128,PQ32
            PCAR32,$coarse32,SQ8
            PCAR64,$coarse64,SQ4
            PCAR16,$coarse16,SQfp16
            OPQ64_128,$coarse128,PQ64x4fs
            OPQ64_128,$coarse128,PQ64x4fsr

            OPQ64_128,$coarse128,PQ64
            PCAR64,$coarse64,SQ8
            PCAR32,$coarse32,SQfp16
            PCAR128,$coarse128,SQ4
            OPQ128_256,$coarse256,PQ128x4fs
            OPQ128_256,$coarse256,PQ128x4fsr
            OPQ56_112,$coarse112,PQ7+56
            OPQ16_64,$coarse64,PQ16x4fs,Refine(OPQ56_112,PQ56)
            OPQ16_64,$coarse64,PQ16x4fs,Refine(PCAR72,SQ6)
            OPQ32_64,$coarse64,PQ16x4fs,Refine(PCAR64,SQ6)
            OPQ32_64,$coarse64,PQ32x4fs,Refine(OPQ48_96,PQ48)
        "

        indexkeys="
            OPQ16_64,$coarse64,PQ16x4fsr
            OPQ32_64,$coarse64,PQ32x4fsr
            OPQ64_128,$coarse128,PQ64x4fsr
            OPQ128_256,$coarse256,PQ128x4fsr
        "


        for indexkey in $indexkeys
        do
            key=autotune.db$db.${indexkey//,/_}
            run_on_1machine $key \
                python -u bench_all_ivf.py \
            key="${key//(/_}"
            key="${key//)/_}"
            run_on_1machine_3h $key.l \
                python -u bench_all_ivf.py \
                    --db $db \
                    --indexkey $indexkey \
                    --indexkey "$indexkey" \
                    --maxtrain 0 \
                    --indexfile $indexdir/$key.faissindex \
                    --searchthreads 16 \
                    --indexfile "$indexdir/$key.faissindex" \
                    $( add_precomputed_quantizer $db $coarse ) \
                    --searchthreads 32 \
                    --min_test_duration 3 \
                    --autotune_max nprobe:2000
        done
    done
done

@@ -115,135 +474,130 @@ done

############################### 100M experiments

for db in deep100M bigann100M; do
    coarses="
        IVF65536_HNSW32
        IVF262144_HNSW32
        IVF262144(IVF512,PQHDx4fs,RFlat)
        IVF1048576_HNSW32
        IVF1048576(IVF1024,PQHDx4fs,RFlat)
    "
    dim=$( get_db_dim $db )

    for coarse in IMI2x11 IMI2x12 IVF65536_HNSW32 IVF262144_HNSW32
    for coarse in $coarses
    do
        replace_coarse_PQHD "$coarse" $dim

        for indexkey in \
            OPQ8_64,$coarse,PQ8 \
            OPQ16_64,$coarse,PQ16 \
            PCAR32,$coarse,SQ4 \
            OPQ32_128,$coarse,PQ32 \
            PCAR64,$coarse,SQ4 \
            PCAR32,$coarse,SQ8 \
            PCAR64,$coarse,SQ8 \
            PCAR32,$coarse,SQfp16 \
            PCAR128,$coarse,SQ4 \
            OPQ64_128,$coarse,PQ64
        indexkeys="
            OPQ8_64,$coarse64,PQ8
            OPQ16_64,$coarse64,PQ16x4fs

            PCAR32,$coarse32,SQ4
            OPQ16_64,$coarse64,PQ16
            OPQ32_64,$coarse64,PQ32x4fs

            OPQ32_128,$coarse128,PQ32
            PCAR64,$coarse64,SQ4
            PCAR32,$coarse32,SQ8
            OPQ64_128,$coarse128,PQ64x4fs

            PCAR128,$coarse128,SQ4
            OPQ64_128,$coarse128,PQ64

            PCAR32,$coarse32,SQfp16
            PCAR64,$coarse64,SQ8
            OPQ128_256,$coarse256,PQ128x4fs

            OPQ56_112,$coarse112,PQ7+56
            OPQ16_64,$coarse64,PQ16x4fs,Refine(OPQ56_112,PQ56)

            $coarseD,PQ$((dim/2))x4fs
        "

        indexkeys="
            OPQ128_256,$coarse256,PQ128x4fsr
            OPQ64_128,$coarse128,PQ64x4fsr
            OPQ32_64,$coarse64,PQ32x4fsr
            OPQ16_64,$coarse64,PQ16x4fsr
            OPQ16_64,$coarse64,PQ16x4fsr,Refine(OPQ56_112,PQ56)
        "

        for indexkey in $indexkeys
        do
            key=autotune.db$db.${indexkey//,/_}
            run_on_1machine $key \
            key="${key//(/_}"
            key="${key//)/_}"
            run_on_1machine $key.e \
                python -u bench_all_ivf.py \
                    --db $db \
                    --indexkey $indexkey \
                    --indexkey "$indexkey" \
                    --maxtrain 0 \
                    --indexfile $indexdir/$key.faissindex \
                    --searchthreads 16 \
                    --searchthreads 32 \
                    --min_test_duration 3 \
                    --add_bs 1000000
                    $( add_precomputed_quantizer $db $coarse ) \
                    --add_bs 1000000 \
                    --autotune_max nprobe:2000

        done
    done
done


############################### 1B experiments

for db in deep1B bigann1B; do

    for coarse in IMI2x12 IMI2x13 IVF262144_HNSW32
    do

        for indexkey in \
            OPQ8_64,$coarse,PQ8 \
            OPQ16_64,$coarse,PQ16 \
            PCAR32,$coarse,SQ4 \
            OPQ32_128,$coarse,PQ32 \
            PCAR64,$coarse,SQ4 \
            PCAR32,$coarse,SQ8 \
            PCAR64,$coarse,SQ8 \
            PCAR32,$coarse,SQfp16 \
            PCAR128,$coarse,SQ4 \
            PQ64_128,$coarse,PQ64 \
            RR128,$coarse,SQ4
        do
            key=autotune.db$db.${indexkey//,/_}
            run_on_1machine $key \
                python -u bench_all_ivf.py \
                    --db $db \
                    --indexkey $indexkey \
                    --maxtrain 0 \
                    --indexfile $indexdir/$key.faissindex \
                    --searchthreads 16 \
                    --min_test_duration 3 \
                    --add_bs 1000000

        done
    done

done

############################################
# precompute centroids on GPU for large vocabularies


for db in deep1M bigann1M; do

    for ncent in 1048576 4194304; do

        key=clustering.db$db.IVF$ncent
        run_on_8gpu $key \
            python -u bench_all_ivf.py \
                --db $db \
                --indexkey IVF$ncent,SQ8 \
                --maxtrain 100000000 \
                --indexfile $indexdir/$key.faissindex \
                --searchthreads 16 \
                --min_test_duration 3 \
                --add_bs 1000000 \
                --train_on_gpu

    done
done


#################################
# Run actual experiment
# 1B-scale experiment


for db in deep1B bigann1B; do
    coarses="
        IVF1048576_HNSW32
        IVF4194304_HNSW32
        IVF4194304(IVF1024,PQHDx4fs,RFlat)
    "
    dim=$( get_db_dim $db )

    for ncent in 1048576 4194304; do
    coarse=IVF${ncent}_HNSW32
    centroidsname=clustering.db${db/1B/1M}.IVF${ncent}.faissindex
    for coarse in $coarses; do

        for indexkey in \
            OPQ8_64,$coarse,PQ8 \
            OPQ16_64,$coarse,PQ16 \
            PCAR32,$coarse,SQ4 \
            OPQ32_128,$coarse,PQ32 \
            PCAR64,$coarse,SQ4 \
            PCAR32,$coarse,SQ8 \
            PCAR64,$coarse,SQ8 \
            PCAR32,$coarse,SQfp16 \
            OPQ64_128,$coarse,PQ64 \
            RR128,$coarse,SQ4 \
            OPQ64_128,$coarse,PQ64 \
            RR128,$coarse,SQ4
        replace_coarse_PQHD "$coarse" $dim


        indexkeys="
            OPQ8_64,$coarse64,PQ8
            OPQ16_64,$coarse64,PQ16x4fsr

            OPQ16_64,$coarse64,PQ16
            OPQ32_64,$coarse64,PQ32x4fsr

            OPQ32_128,$coarse128,PQ32
            OPQ64_128,$coarse128,PQ64x4fsr

            OPQ64_128,$coarse128,PQ64
            OPQ128_256,$coarse256,PQ128x4fsr
            OPQ56_112,$coarse112,PQ7+56
            OPQ16_64,$coarse64,PQ16x4fs,Refine(OPQ56_112,PQ56)

            $coarseD,PQ$((dim/2))x4fs
        "

        for indexkey in $indexkeys
        do
            key=autotune.db$db.${indexkey//,/_}

            run_on_1machine $key.c $key \
            key="${key//(/_}"
            key="${key//)/_}"
            run_on_1machine $key.d \
                python -u bench_all_ivf.py \
                    --db $db \
                    --indexkey $indexkey \
                    --maxtrain 256000 \
                    --indexkey "$indexkey" \
                    --maxtrain 0 \
                    --indexfile $indexdir/$key.faissindex \
                    --get_centroids_from $indexdir/$centroidsname \
                    --searchthreads 16 \
                    --searchthreads 32 \
                    --min_test_duration 3 \
                    --add_bs 1000000
                    $( add_precomputed_quantizer $db $coarse ) \
                    --add_bs 1000000 \
                    --autotune_max nprobe:3000
        done
    done

done

fi

@@ -50,3 +50,7 @@ fields and converting them to the proper python array.

(may require h5py)

Definition of how to access data for some standard datasets.

### factory_tools.py

Functions related to factory strings.


@@ -96,7 +96,7 @@ class SyntheticDataset(Dataset):
        return self.xq

    def get_train(self, maxtrain=None):
        maxtrain = maxtrain or self.nt
        maxtrain = maxtrain if maxtrain is not None else self.nt
        return self.xt[:maxtrain]

    def get_database(self):

@@ -140,7 +140,7 @@ class DatasetSIFT1M(Dataset):
        return fvecs_read(self.basedir + "sift_query.fvecs")

    def get_train(self, maxtrain=None):
        maxtrain = maxtrain or self.nt
        maxtrain = maxtrain if maxtrain is not None else self.nt
        return fvecs_read(self.basedir + "sift_learn.fvecs")[:maxtrain]

    def get_database(self):

@@ -176,7 +176,7 @@ class DatasetBigANN(Dataset):
        return sanitize(bvecs_mmap(self.basedir + 'bigann_query.bvecs')[:])

    def get_train(self, maxtrain=None):
        maxtrain = maxtrain or self.nt
        maxtrain = maxtrain if maxtrain is not None else self.nt
        return sanitize(bvecs_mmap(self.basedir + 'bigann_learn.bvecs')[:maxtrain])

    def get_groundtruth(self, k=None):

@@ -224,7 +224,7 @@ class DatasetDeep1B(Dataset):
        return sanitize(fvecs_read(self.basedir + "deep1B_queries.fvecs"))

    def get_train(self, maxtrain=None):
        maxtrain = maxtrain or self.nt
        maxtrain = maxtrain if maxtrain is not None else self.nt
        return sanitize(fvecs_mmap(self.basedir + "learn.fvecs")[:maxtrain])

    def get_groundtruth(self, k=None):

@@ -251,8 +251,9 @@ class DatasetGlove(Dataset):
    Data from http://ann-benchmarks.com/glove-100-angular.hdf5
    """

    def __init__(self, loc=None):
    def __init__(self, loc=None, download=False):
        import h5py
        assert not download, "not implemented"
        if not loc:
            loc = dataset_basedir + 'glove/glove-100-angular.hdf5'
        self.glove_h5py = h5py.File(loc, 'r')

@@ -16,6 +16,7 @@ def knn_ground_truth(xq, db_iterator, k):
    does not fit in RAM but for which we have an iterator that
    returns it block by block.
    """
    LOG.info("knn_ground_truth queries size %s k=%d" % (xq.shape, k))
    t0 = time.time()
    nq, d = xq.shape
    rh = faiss.ResultHeap(nq, k)
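
A hedged usage sketch of this helper (the block iterator and array names below
are illustrative, and the function is assumed to return the distances/ids
arrays of the result heap):

    import numpy as np

    def blocks(xb, bs=10000):
        # yield the database in chunks, as knn_ground_truth expects
        for i0 in range(0, xb.shape[0], bs):
            yield xb[i0:i0 + bs]

    xb = np.random.rand(100000, 32).astype('float32')
    xq = np.random.rand(100, 32).astype('float32')
    gt_dis, gt_ids = knn_ground_truth(xq, blocks(xb), k=10)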

@@ -0,0 +1,100 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import faiss
import re


def get_code_size(d, indexkey):
    """ size of one vector in an index in dimension d
    constructed with factory string indexkey"""

    if indexkey == "Flat":
        return d * 4

    if indexkey.endswith(",RFlat"):
        return d * 4 + get_code_size(d, indexkey[:-len(",RFlat")])

    mo = re.match("IVF\\d+(_HNSW32)?,(.*)$", indexkey)
    if mo:
        return get_code_size(d, mo.group(2))

    mo = re.match("IVF\\d+\\(.*\\)?,(.*)$", indexkey)
    if mo:
        return get_code_size(d, mo.group(1))

    mo = re.match("IMI\\d+x2,(.*)$", indexkey)
    if mo:
        return get_code_size(d, mo.group(1))

    mo = re.match("(.*),Refine\\((.*)\\)$", indexkey)
    if mo:
        return get_code_size(d, mo.group(1)) + get_code_size(d, mo.group(2))

    mo = re.match('PQ(\\d+)x(\\d+)(fs|fsr)?$', indexkey)
    if mo:
        return (int(mo.group(1)) * int(mo.group(2)) + 7) // 8

    mo = re.match('PQ(\\d+)\\+(\\d+)$', indexkey)
    if mo:
        return (int(mo.group(1)) + int(mo.group(2)))

    mo = re.match('PQ(\\d+)$', indexkey)
    if mo:
        return int(mo.group(1))

    if indexkey == "HNSW32" or indexkey == "HNSW32,Flat":
        return d * 4 + 64 * 4  # roughly

    if indexkey == 'SQ8':
        return d
    elif indexkey == 'SQ4':
        return (d + 1) // 2
    elif indexkey == 'SQ6':
        return (d * 6 + 7) // 8
    elif indexkey == 'SQfp16':
        return d * 2

    mo = re.match('PCAR?(\\d+),(.*)$', indexkey)
    if mo:
        return get_code_size(int(mo.group(1)), mo.group(2))
    mo = re.match('OPQ\\d+_(\\d+),(.*)$', indexkey)
    if mo:
        return get_code_size(int(mo.group(1)), mo.group(2))
    mo = re.match('OPQ\\d+,(.*)$', indexkey)
    if mo:
        return get_code_size(d, mo.group(1))
    mo = re.match('RR(\\d+),(.*)$', indexkey)
    if mo:
        return get_code_size(int(mo.group(1)), mo.group(2))
    raise RuntimeError("cannot parse " + indexkey)
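
A few sanity checks on the parser, with values derived from the rules above:

    assert get_code_size(128, "PQ64x4fs") == 32          # 64 sub-codes * 4 bits
    assert get_code_size(128, "OPQ32_128,PQ32") == 32    # OPQ keeps d=128, PQ32 -> 32 bytes
    assert get_code_size(96, "SQ8") == 96                # 1 byte per dimension
    assert get_code_size(128, "PQ64x4fs,RFlat") == 544   # 32 + 128 * 4 for the refine Flat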


def reverse_index_factory(index):
    """
    attempts to get the factory string the index was built with
    """
    index = faiss.downcast_index(index)
    if isinstance(index, faiss.IndexFlat):
        return "Flat"
    if isinstance(index, faiss.IndexIVF):
        quantizer = faiss.downcast_index(index.quantizer)

        if isinstance(quantizer, faiss.IndexFlat):
            prefix = "IVF%d" % index.nlist
        elif isinstance(quantizer, faiss.MultiIndexQuantizer):
            prefix = "IMI%dx%d" % (quantizer.pq.M, quantizer.pq.nbit)
        elif isinstance(quantizer, faiss.IndexHNSW):
            prefix = "IVF%d_HNSW%d" % (index.nlist, quantizer.hnsw.M)
        else:
            prefix = "IVF%d(%s)" % (index.nlist, reverse_index_factory(quantizer))

        if isinstance(index, faiss.IndexIVFFlat):
            return prefix + ",Flat"
        if isinstance(index, faiss.IndexIVFScalarQuantizer):
            return prefix + ",SQ8"

    raise NotImplementedError()
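
A hedged round-trip sketch (it only works for the index types handled above):

    import faiss
    index = faiss.index_factory(64, "IVF128,Flat")
    assert reverse_index_factory(index) == "IVF128,Flat"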

@@ -7,6 +7,7 @@ import numpy as np

"""
I/O functions in fvecs, bvecs, ivecs formats
definition of the formats here: http://corpus-texmex.irisa.fr/
"""


@@ -34,3 +35,16 @@ def bvecs_mmap(fname):
    x = np.memmap(fname, dtype='uint8', mode='r')
    d = x[:4].view('int32')[0]
    return x.reshape(-1, d + 4)[:, 4:]


def ivecs_write(fname, m):
    n, d = m.shape
    m1 = np.empty((n, d + 1), dtype='int32')
    m1[:, 0] = d
    m1[:, 1:] = m
    m1.tofile(fname)


def fvecs_write(fname, m):
    m = m.astype('float32')
    ivecs_write(fname, m.view('int32'))
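
A round-trip sketch: writing with fvecs_write and reading back with fvecs_read
(defined earlier in this module) should be the identity:

    import numpy as np
    x = np.random.rand(10, 4).astype('float32')
    fvecs_write("/tmp/test.fvecs", x)
    assert np.array_equal(fvecs_read("/tmp/test.fvecs"), x)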

@@ -15,6 +15,7 @@

#include <cinttypes>
#include <cmath>
#include <typeinfo>

#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/utils.h>

@@ -32,6 +33,7 @@
#include <faiss/MetaIndexes.h>
#include <faiss/IndexScalarQuantizer.h>
#include <faiss/IndexHNSW.h>
#include <faiss/IndexRefine.h>

#include <faiss/IndexBinaryFlat.h>
#include <faiss/IndexBinaryHNSW.h>

@@ -234,7 +236,7 @@ void OperatingPoints::display (bool only_optimal) const
{
    const std::vector<OperatingPoint> &pts =
        only_optimal ? optimal_pts : all_pts;
    printf("Tested %zd operating points, %zd ones are optimal:\n",
    printf("Tested %zd operating points, %zd ones are Pareto-optimal:\n",
           all_pts.size(), optimal_pts.size());

    for (int i = 0; i < pts.size(); i++) {

@@ -333,7 +335,7 @@ static void init_pq_ParameterRange (const ProductQuantizer & pq,
    pr.values.push_back (pq.code_size * 8);
}

ParameterRange &ParameterSpace::add_range(const char * name)
ParameterRange &ParameterSpace::add_range(const std::string & name)
{
    for (auto & pr : parameter_ranges) {
        if (pr.name == name) {

@@ -346,13 +348,13 @@ ParameterRange &ParameterSpace::add_range(const char * name)
}


/// initialize with reasonable parameters for the index
/// initialize with reasonable parameters for this type of index
void ParameterSpace::initialize (const Index * index)
{
    if (DC (IndexPreTransform)) {
        index = ix->index;
    }
    if (DC (IndexRefineFlat)) {
    if (DC (IndexRefine)) {
        ParameterRange & pr = add_range("k_factor_rf");
        for (int i = 0; i <= 6; i++) {
            pr.values.push_back (1 << i);

@@ -372,12 +374,14 @@ void ParameterSpace::initialize (const Index * index)
            pr.values.push_back (nprobe);
        }
    }
        if (dynamic_cast<const IndexHNSW*>(ix->quantizer)) {
            ParameterRange & pr = add_range("efSearch");
            for (int i = 2; i <= 9; i++) {
                pr.values.push_back (1 << i);
            }
        ParameterSpace ivf_pspace;
        ivf_pspace.initialize(ix->quantizer);

        for (const ParameterRange & p: ivf_pspace.parameter_ranges) {
            ParameterRange & pr = add_range("quantizer_" + p.name);
            pr.values = p.values;
        }

    }
    if (DC (IndexPQ)) {
        ParameterRange & pr = add_range("ht");

@@ -457,44 +461,38 @@ void ParameterSpace::set_index_parameters (
void ParameterSpace::set_index_parameter (
        Index * index, const std::string & name, double val) const
{
    if (verbose > 1)
        printf("    set %s=%g\n", name.c_str(), val);
    if (verbose > 1) {
        printf("    set_index_parameter %s=%g\n", name.c_str(), val);
    }

    if (name == "verbose") {
        index->verbose = int(val);
        // and fall through to also enable it on sub-indexes
    }
    if (DC (IndexIDMap)) {
        set_index_parameter (ix->index, name, val);
        return;
    }
    if (DC (IndexPreTransform)) {
        set_index_parameter (ix->index, name, val);
        return;
    }
    if (DC (IndexShards)) {
    if (DC (ThreadedIndex<Index>)) {
        // call on all sub-indexes
        auto fn =
            [this, name, val](int, Index* subIndex) {
                set_index_parameter(subIndex, name, val);
            [this, name, val](int /* no */, Index* subIndex) {
                set_index_parameter(subIndex, name, val);
            };

        ix->runOnIndex(fn);
        return;
    }
    if (DC (IndexReplicas)) {
        // call on all sub-indexes
        auto fn =
            [this, name, val](int, Index* subIndex) {
                set_index_parameter(subIndex, name, val);
            };

        ix->runOnIndex(fn);
        return;
    }
    if (DC (IndexRefineFlat)) {
    if (DC (IndexRefine)) {
        if (name == "k_factor_rf") {
            ix->k_factor = int(val);
            return;
        }
        // otherwise it is for the sub-index
        set_index_parameter (&ix->refine_index, name, val);
        set_index_parameter (ix->base_index, name, val);
        return;
    }

@@ -504,10 +502,7 @@ void ParameterSpace::set_index_parameter (
    }

    if (name == "nprobe") {
        if (DC (IndexIDMap)) {
            set_index_parameter (ix->index, name, val);
            return;
        } else if (DC (IndexIVF)) {
        if (DC (IndexIVF)) {
            ix->nprobe = int(val);
            return;
        }

@@ -559,6 +554,14 @@ void ParameterSpace::set_index_parameter (
        }
    }

    if (name.find("quantizer_") == 0) {
        if (DC(IndexIVF)) {
            std::string sub_name = name.substr(strlen("quantizer_"));
            set_index_parameter(ix->quantizer, sub_name, val);
            return;
        }
    }

    FAISS_THROW_FMT ("ParameterSpace::set_index_parameter:"
                     "could not set parameter %s",
                     name.c_str());
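
The new quantizer_ prefix above forwards a parameter to the coarse quantizer of
an IVF index. A hedged Python sketch of how this is used from the benchmark
side (the factory string is illustrative):

    import faiss
    index = faiss.index_factory(96, "IVF16384_HNSW32,PQ48x4fs")
    ps = faiss.ParameterSpace()
    # nprobe goes to the IVF index, efSearch to its HNSW coarse quantizer
    ps.set_index_parameters(index, "nprobe=64,quantizer_efSearch=256")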

@@ -707,8 +710,8 @@ void ParameterSpace::explore (Index *index,
        bool keep = ops->add (perf, t_search, combination_name (cno), cno);

        if (verbose)
            printf("  perf %.3f t %.3f (%d runs) %s\n",
                   perf, t_search, nrun,
            printf("  perf %.3f t %.3f (%d %s) %s\n",
                   perf, t_search, nrun, nrun >= 2 ? "runs" : "run",
                   keep ? "*" : "");
    }
}

@@ -81,7 +81,10 @@ struct IntersectionCriterion: AutoTuneCriterion {
/**
 * Maintains a list of experimental results. Each operating point is a
 * (perf, t, key) triplet, where higher perf and lower t is
 * better. The key field is an arbitrary identifier for the operating point
 * better. The key field is an arbitrary identifier for the operating point.
 *
 * Includes primitives to extract the Pareto-optimal operating points in the
 * (perf, t) space.
 */

struct OperatingPoint {

@@ -168,7 +171,7 @@ struct ParameterSpace {
    void display () const;

    /// add a new parameter (or return it if it exists)
    ParameterRange &add_range(const char * name);
    ParameterRange &add_range(const std::string & name);

    /// initialize with reasonable parameters for the index
    virtual void initialize (const Index * index);

@@ -179,7 +182,7 @@ struct ParameterSpace {
    /// set a combination of parameters described by a string
    void set_index_parameters (Index *index, const char *param_string) const;

    /// set one of the parameters
    /// set one of the parameters, returns whether setting was successful
    virtual void set_index_parameter (
        Index * index, const std::string & name, double val) const;


@@ -32,6 +32,7 @@ add_library(faiss
  IndexReplicas.cpp
  IndexScalarQuantizer.cpp
  IndexShards.cpp
  IndexRefine.cpp
  MatrixStats.cpp
  MetaIndexes.cpp
  VectorTransform.cpp

@@ -94,6 +95,7 @@ set(FAISS_HEADERS
  IndexReplicas.h
  IndexScalarQuantizer.h
  IndexShards.h
  IndexRefine.h
  MatrixStats.h
  MetaIndexes.h
  MetricType.h


@@ -227,144 +227,6 @@ void IndexFlat::sa_decode (idx_t n, const uint8_t *bytes, float *x) const


/***************************************************
 * IndexRefineFlat
 ***************************************************/

IndexRefineFlat::IndexRefineFlat (Index *base_index):
    Index (base_index->d, base_index->metric_type),
    refine_index (base_index->d, base_index->metric_type),
    base_index (base_index), own_fields (false),
    k_factor (1)
{
    is_trained = base_index->is_trained;
    FAISS_THROW_IF_NOT_MSG (base_index->ntotal == 0,
                            "base_index should be empty in the beginning");
}


IndexRefineFlat::IndexRefineFlat (Index *base_index, const float *xb):
    Index (base_index->d, base_index->metric_type),
    refine_index (base_index->d, base_index->metric_type),
    base_index (base_index), own_fields (false),
    k_factor (1)
{
    is_trained = base_index->is_trained;
    refine_index.add (base_index->ntotal, xb);
    ntotal = base_index->ntotal;
}


IndexRefineFlat::IndexRefineFlat () {
    base_index = nullptr;
    own_fields = false;
    k_factor = 1;
}


void IndexRefineFlat::train (idx_t n, const float *x)
{
    base_index->train (n, x);
    is_trained = true;
}

void IndexRefineFlat::add (idx_t n, const float *x) {
    FAISS_THROW_IF_NOT (is_trained);
    base_index->add (n, x);
    refine_index.add (n, x);
    ntotal = refine_index.ntotal;
}

void IndexRefineFlat::reset ()
{
    base_index->reset ();
    refine_index.reset ();
    ntotal = 0;
}

namespace {

typedef faiss::Index::idx_t idx_t;

template<class C>
static void reorder_2_heaps (
    idx_t n,
    idx_t k, idx_t *labels, float *distances,
    idx_t k_base, const idx_t *base_labels, const float *base_distances)
{
#pragma omp parallel for
    for (idx_t i = 0; i < n; i++) {
        idx_t *idxo = labels + i * k;
        float *diso = distances + i * k;
        const idx_t *idxi = base_labels + i * k_base;
        const float *disi = base_distances + i * k_base;

        heap_heapify<C> (k, diso, idxo, disi, idxi, k);
        if (k_base != k) { // add remaining elements
            heap_addn<C> (k, diso, idxo, disi + k, idxi + k, k_base - k);
        }
        heap_reorder<C> (k, diso, idxo);
    }
}


} // anonymous namespace


void IndexRefineFlat::search (
        idx_t n, const float *x, idx_t k,
        float *distances, idx_t *labels) const
{
    FAISS_THROW_IF_NOT (is_trained);
    idx_t k_base = idx_t (k * k_factor);
    idx_t * base_labels = labels;
    float * base_distances = distances;
    ScopeDeleter<idx_t> del1;
    ScopeDeleter<float> del2;


    if (k != k_base) {
        base_labels = new idx_t [n * k_base];
        del1.set (base_labels);
        base_distances = new float [n * k_base];
        del2.set (base_distances);
    }

    base_index->search (n, x, k_base, base_distances, base_labels);

    for (int i = 0; i < n * k_base; i++)
        assert (base_labels[i] >= -1 &&
                base_labels[i] < ntotal);

    // compute refined distances
    refine_index.compute_distance_subset (
        n, x, k_base, base_distances, base_labels);

    // sort and store result
    if (metric_type == METRIC_L2) {
        typedef CMax <float, idx_t> C;
        reorder_2_heaps<C> (
            n, k, labels, distances,
            k_base, base_labels, base_distances);

    } else if (metric_type == METRIC_INNER_PRODUCT) {
        typedef CMin <float, idx_t> C;
        reorder_2_heaps<C> (
            n, k, labels, distances,
            k_base, base_labels, base_distances);
    } else {
        FAISS_THROW_MSG("Metric type not supported");
    }

}


IndexRefineFlat::~IndexRefineFlat ()
{
    if (own_fields) delete base_index;
}

/***************************************************
 * IndexFlat1D

@@ -94,47 +94,6 @@ struct IndexFlatL2:IndexFlat {


/** Index that queries in a base_index (a fast one) and refines the
 * results with an exact search, hopefully improving the results.
 */
struct IndexRefineFlat: Index {

    /// storage for full vectors
    IndexFlat refine_index;

    /// faster index to pre-select the vectors that should be filtered
    Index *base_index;
    bool own_fields;       ///< should the base index be deallocated?

    /// factor between k requested in search and the k requested from
    /// the base_index (should be >= 1)
    float k_factor;
    /// initialize from empty index
    explicit IndexRefineFlat (Index *base_index);

    /// initialize from index and corresponding data
    IndexRefineFlat(Index *base_index, const float *xb);

    IndexRefineFlat ();

    void train(idx_t n, const float* x) override;

    void add(idx_t n, const float* x) override;

    void reset() override;

    void search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels) const override;

    ~IndexRefineFlat() override;
};
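
This refinement pattern is what the Refine(...) suffixes in the benchmark
factory strings rely on. A hedged Python sketch, with synthetic data and
illustrative sizes:

    import numpy as np
    import faiss

    xb = np.random.rand(10000, 64).astype('float32')
    index = faiss.index_factory(64, "IVF256,PQ32x4fs,Refine(SQ8)")
    index.train(xb)
    index.add(xb)
    # k_factor_rf controls how many base results are re-ranked exactly
    faiss.ParameterSpace().set_index_parameters(index, "nprobe=16,k_factor_rf=4")
    D, I = index.search(xb[:5], 10)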


/// optimized version for 1D "vectors".
struct IndexFlat1D:IndexFlatL2 {
    bool continuous_update; ///< is the permutation updated continuously?

@@ -88,12 +88,19 @@ void Level1Quantizer::train_q1 (size_t n, const float *x, bool verbose, MetricTy
        }
        quantizer->is_trained = true;
    } else if (quantizer_trains_alone == 2) {
        if (verbose)
        if (verbose) {
            printf (
                "Training L2 quantizer on %zd vectors in %zdD%s\n",
                n, d,
                clustering_index ? "(user provided index)" : "");
        FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
        }
        // also accept spherical centroids because in that case
        // L2 and IP are equivalent
        FAISS_THROW_IF_NOT (
            metric_type == METRIC_L2 ||
            (metric_type == METRIC_INNER_PRODUCT && cp.spherical)
        );

        Clustering clus (d, nlist, cp);
        if (!clustering_index) {
            IndexFlatL2 assigner (d);

@@ -263,23 +270,76 @@ void IndexIVF::set_direct_map_type (DirectMap::Type type)
    direct_map.set_type (type, invlists, ntotal);
}

/** It is a sad fact of software that a conceptually simple function like this
 * becomes very complex when you factor in several ways of parallelizing +
 * interrupt/error handling + collecting stats + min/max collection. The
 * codepath that is used 95% of the time is the one for parallel_mode = 0 */
void IndexIVF::search (idx_t n, const float *x, idx_t k,
                       float *distances, idx_t *labels) const
{
    std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
    std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);

    double t0 = getmillisecs();
    quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get());
    indexIVF_stats.quantization_time += getmillisecs() - t0;

    t0 = getmillisecs();
    invlists->prefetch_lists (idx.get(), n * nprobe);
    // search function for a subset of queries
    auto sub_search_func = [this, k]
        (idx_t n, const float *x, float *distances, idx_t *labels,
         IndexIVFStats *ivf_stats) {

        std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
        std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);

        double t0 = getmillisecs();
        quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get());

        double t1 = getmillisecs();
        invlists->prefetch_lists (idx.get(), n * nprobe);

        search_preassigned (n, x, k, idx.get(), coarse_dis.get(),
                            distances, labels, false, nullptr, ivf_stats);
        double t2 = getmillisecs();
        ivf_stats->quantization_time += t1 - t0;
        ivf_stats->search_time += t2 - t0;
    };


    if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) {
        int nt = std::min(omp_get_max_threads(), int(n));
        std::vector<IndexIVFStats> stats(nt);
        std::mutex exception_mutex;
        std::string exception_string;

#pragma omp parallel for if (nt > 1)
        for(idx_t slice = 0; slice < nt; slice++) {
            IndexIVFStats local_stats;
            idx_t i0 = n * slice / nt;
            idx_t i1 = n * (slice + 1) / nt;
            if (i1 > i0) {
                try {
                    sub_search_func(
                        i1 - i0, x + i0 * d,
                        distances + i0 * k, labels + i0 * k,
                        &stats[slice]
                    );
                } catch(const std::exception & e) {
                    std::lock_guard<std::mutex> lock(exception_mutex);
                    exception_string = e.what();
                }
            }
        }

        if (!exception_string.empty()) {
            FAISS_THROW_MSG (exception_string.c_str());
        }

        // collect stats
        for(idx_t slice = 0; slice < nt; slice++) {
            indexIVF_stats.add(stats[slice]);
        }
    } else {
        // handle parallelization at the level below (or don't run in parallel at all)
        sub_search_func(n, x, distances, labels, &indexIVF_stats);
    }


    search_preassigned (n, x, k, idx.get(), coarse_dis.get(),
                        distances, labels, false);
    indexIVF_stats.search_time += getmillisecs() - t0;
}


@@ -288,7 +348,8 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
                                   const float *coarse_dis ,
                                   float *distances, idx_t *labels,
                                   bool store_pairs,
                                   const IVFSearchParameters *params) const
                                   const IVFSearchParameters *params,
                                   IndexIVFStats *ivf_stats) const
{
    long nprobe = params ? params->nprobe : this->nprobe;
    long max_codes = params ? params->max_codes : this->max_codes;

@@ -305,13 +366,12 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
    int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
    bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT);

    // don't start parallel section if single query
    bool do_parallel = omp_get_max_threads() >= 2 && (
            pmode == 0 ? n > 1 :
            pmode == 0 ? false :
            pmode == 3 ? n > 1 :
            pmode == 1 ? nprobe > 1 :
            nprobe * n > 1);

#pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis, nheap)
    {
        InvertedListScanner *scanner = get_InvertedListScanner(store_pairs);

@@ -409,7 +469,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
         * Actual loops, depending on parallel_mode
         ****************************************************/

        if (pmode == 0) {
        if (pmode == 0 || pmode == 3) {

#pragma omp for
            for (idx_t i = 0; i < n; i++) {

@@ -527,11 +587,12 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
        }
    }

    indexIVF_stats.nq += n;
    indexIVF_stats.nlist += nlistv;
    indexIVF_stats.ndis += ndis;
    indexIVF_stats.nheap_updates += nheap;

    if (ivf_stats) {
        ivf_stats->nq += n;
        ivf_stats->nlist += nlistv;
        ivf_stats->ndis += ndis;
        ivf_stats->nheap_updates += nheap;
    }
}


@@ -551,7 +612,7 @@ void IndexIVF::range_search (idx_t nx, const float *x, float radius,
    invlists->prefetch_lists (keys.get(), nx * nprobe);

    range_search_preassigned (nx, x, radius, keys.get (), coarse_dis.get (),
                              result);
                              result, false, nullptr, &indexIVF_stats);

    indexIVF_stats.search_time += getmillisecs() - t0;
}

@@ -561,7 +622,8 @@ void IndexIVF::range_search_preassigned (
        const idx_t *keys, const float *coarse_dis,
        RangeSearchResult *result,
        bool store_pairs,
        const IVFSearchParameters *params) const
        const IVFSearchParameters *params,
        IndexIVFStats *stats) const
{
    long nprobe = params ? params->nprobe : this->nprobe;
    long max_codes = params ? params->max_codes : this->max_codes;

@@ -574,7 +636,15 @@ void IndexIVF::range_search_preassigned (

    std::vector<RangeSearchPartialResult *> all_pres (omp_get_max_threads());

#pragma omp parallel reduction(+: nlistv, ndis)
    int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
    // don't start parallel section if single query
    bool do_parallel = omp_get_max_threads() >= 2 && (
            pmode == 3 ? false :
            pmode == 0 ? nx > 1 :
            pmode == 1 ? nprobe > 1 :
            nprobe * nx > 1);

#pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis)
    {
        RangeSearchPartialResult pres(result);
        std::unique_ptr<InvertedListScanner> scanner

@@ -680,9 +750,11 @@ void IndexIVF::range_search_preassigned (
        }
    }

    indexIVF_stats.nq += nx;
    indexIVF_stats.nlist += nlistv;
    indexIVF_stats.ndis += ndis;
    if (stats) {
        stats->nq += nx;
        stats->nlist += nlistv;
        stats->ndis += ndis;
    }
}


@@ -975,6 +1047,17 @@ void IndexIVFStats::reset()
    memset ((void*)this, 0, sizeof (*this));
}

void IndexIVFStats::add (const IndexIVFStats & other)
{
    nq += other.nq;
    nlist += other.nlist;
    ndis += other.ndis;
    nheap_updates += other.nheap_updates;
    quantization_time += other.quantization_time;
    search_time += other.search_time;
}


IndexIVFStats indexIVF_stats;


@@ -76,6 +76,7 @@ struct IVFSearchParameters {


struct InvertedListScanner;
struct IndexIVFStats;
/** Index based on an inverted file (IVF)
 *
@@ -109,9 +110,10 @@ struct IndexIVF: Index, Level1Quantizer {

    /** Parallel mode determines how queries are parallelized with OpenMP
     *
     * 0 (default): parallelize over queries
     * 0 (default): split over queries
     * 1: parallelize over inverted lists
     * 2: parallelize over both
     * 3: split over queries with a finer granularity
     *
     * PARALLEL_MODE_NO_HEAP_INIT: binary OR with one of the above to
     * prevent the heap from being initialized and finalized
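
From Python, the mode is selected by setting the attribute directly; a hedged
sketch (assuming the constant is exposed by the wrapper):

    import faiss
    index = faiss.index_factory(64, "IVF256,Flat")
    index.parallel_mode = 3  # finer-grained split over queries
    # optionally skip heap init/finalization:
    # index.parallel_mode = 3 | index.PARALLEL_MODE_NO_HEAP_INIT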
|
||||
|
@ -178,13 +180,15 @@ struct IndexIVF: Index, Level1Quantizer {
|
|||
* instead in upper/lower 32 bit of result,
|
||||
* instead of ids (used for reranking).
|
||||
* @param params used to override the object's search parameters
|
||||
* @param stats search stats to be updated (can be null)
|
||||
*/
|
||||
virtual void search_preassigned (
|
||||
idx_t n, const float *x, idx_t k,
|
||||
const idx_t *assign, const float *centroid_dis,
|
||||
float *distances, idx_t *labels,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters *params=nullptr
|
||||
const IVFSearchParameters *params=nullptr,
|
||||
IndexIVFStats *stats=nullptr
|
||||
) const;
|
||||
|
||||
/** assign the vectors, then call search_preassign */
|
||||
|
@ -199,7 +203,8 @@ struct IndexIVF: Index, Level1Quantizer {
|
|||
const idx_t *keys, const float *coarse_dis,
|
||||
RangeSearchResult *result,
|
||||
bool store_pairs=false,
|
||||
const IVFSearchParameters *params=nullptr) const;
|
||||
const IVFSearchParameters *params=nullptr,
|
||||
IndexIVFStats *stats=nullptr) const;
|
||||
|
||||
/// get a scanner for this index (store_pairs means ignore labels)
|
||||
virtual InvertedListScanner *get_InvertedListScanner (
|
||||
|
@ -365,6 +370,7 @@ struct IndexIVFStats {
|
|||
|
||||
IndexIVFStats () {reset (); }
|
||||
void reset ();
|
||||
void add (const IndexIVFStats & other);
|
||||
};
|
||||
|
||||
// global var that collects them all
|
||||
|
|
|
@ -317,7 +317,8 @@ void IndexIVFFlatDedup::search_preassigned (
|
|||
const float *centroid_dis,
|
||||
float *distances, idx_t *labels,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters *params) const
|
||||
const IVFSearchParameters *params,
|
||||
IndexIVFStats *stats) const
|
||||
{
|
||||
FAISS_THROW_IF_NOT_MSG (
|
||||
!store_pairs, "store_pairs not supported in IVFDedup");
|
||||
|
|
|
@ -77,7 +77,8 @@ struct IndexIVFFlatDedup: IndexIVFFlat {
|
|||
const float *centroid_dis,
|
||||
float *distances, idx_t *labels,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters *params=nullptr
|
||||
const IVFSearchParameters *params=nullptr,
|
||||
IndexIVFStats *stats=nullptr
|
||||
) const override;
|
||||
|
||||
size_t remove_ids(const IDSelector& sel) override;
|
||||
|
|
|
@ -154,7 +154,7 @@ void IndexIVFPQFastScan::train_residual (idx_t n, const float *x_in)
|
|||
|
||||
if (verbose) {
|
||||
printf ("training %zdx%zd product quantizer on %zd vectors in %dD\n",
|
||||
pq.M, pq.ksub, n, d);
|
||||
pq.M, pq.ksub, long(n), d);
|
||||
}
|
||||
pq.verbose = verbose;
|
||||
pq.train (n, trainset);
|
||||
|
@ -413,17 +413,23 @@ void IndexIVFPQFastScan::compute_LUT(
|
|||
AlignedTable<float> ip_table(n * dim12);
|
||||
pq.compute_inner_prod_tables (n, x, ip_table.get());
|
||||
|
||||
#pragma omp parallel if (n * nprobe > 8000)
|
||||
for(idx_t i = 0; i < n; i++) {
|
||||
for(idx_t j = 0; j < nprobe; j++) {
|
||||
size_t ij = i * nprobe + j;
|
||||
#pragma omp parallel for if (n * nprobe > 8000)
|
||||
for(idx_t ij = 0; ij < n * nprobe; ij++) {
|
||||
idx_t i = ij / nprobe;
|
||||
float *tab = dis_tables.get() + ij * dim12;
|
||||
idx_t cij = coarse_ids[ij];
|
||||
|
||||
if (cij >= 0) {
|
||||
fvec_madd_avx (
|
||||
dim12,
|
||||
precomputed_table.get() + coarse_ids[ij] * dim12,
|
||||
precomputed_table.get() + cij * dim12,
|
||||
-2, ip_table.get() + i * dim12,
|
||||
dis_tables.get() + ij * dim12
|
||||
tab
|
||||
);
|
||||
} else {
|
||||
// fill with NaNs so that they are ignored during
|
||||
// LUT quantization
|
||||
memset (tab, -1, sizeof(float) * dim12);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -433,12 +439,18 @@ void IndexIVFPQFastScan::compute_LUT(
|
|||
biases.resize(n * nprobe);
|
||||
memset(biases.get(), 0, sizeof(float) * n * nprobe);
|
||||
|
||||
#pragma omp parallel if (n > 8000)
|
||||
for(idx_t i = 0; i < n; i++) {
|
||||
for(idx_t j = 0; j < nprobe; j++) {
|
||||
#pragma omp parallel for if (n * nprobe > 8000)
|
||||
for(idx_t ij = 0; ij < n * nprobe; ij++) {
|
||||
idx_t i = ij / nprobe;
|
||||
float *xij = &xrel[ij * d];
|
||||
idx_t cij = coarse_ids[ij];
|
||||
|
||||
if (cij >= 0) {
|
||||
ivfpq.quantizer->compute_residual(
|
||||
x + i * d, &xrel[(i * nprobe + j) * d],
|
||||
coarse_ids[i * nprobe + j]);
|
||||
x + i * d, xij, cij);
|
||||
} else {
|
||||
// will fill with NaNs
|
||||
memset(xij, -1, sizeof(float) * d);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -503,8 +515,8 @@ void IndexIVFPQFastScan::compute_LUT_uint8(
|
|||
}
|
||||
uint64_t t1 = get_cy();
|
||||
|
||||
#pragma omp parallel if (n > 8000)
|
||||
for(size_t i = 0; i < n; i++) {
|
||||
#pragma omp parallel for if (n > 100)
|
||||
for(int64_t i = 0; i < n; i++) {
|
||||
const float *t_in = dis_tables_float.get() + i * dim123;
|
||||
const float *b_in = nullptr;
|
||||
uint8_t *t_out = dis_tables.get() + i * dim123_2;
|
||||
|
@ -569,8 +581,8 @@ void IndexIVFPQFastScan::search_dispatch_implem(
|
|||
|
||||
} else if (impl >= 10 && impl <= 13) {
|
||||
size_t ndis = 0, nlist_visited = 0;
|
||||
int nt = std::min(omp_get_max_threads(), int(n));
|
||||
if (nt < 2) {
|
||||
|
||||
if (n < 2) {
|
||||
if (impl == 12 || impl == 13) {
|
||||
search_implem_12<C>
|
||||
(n, x, k, distances, labels, impl, &ndis, &nlist_visited);
|
||||
|
@ -580,10 +592,27 @@ void IndexIVFPQFastScan::search_dispatch_implem(
|
|||
}
|
||||
} else {
|
||||
// explicitly slice over threads
|
||||
#pragma omp parallel for num_threads(nt) reduction(+: ndis, nlist_visited)
|
||||
for (int slice = 0; slice < nt; slice++) {
|
||||
idx_t i0 = n * slice / nt;
|
||||
idx_t i1 = n * (slice + 1) / nt;
|
||||
int nslice;
|
||||
if (n <= omp_get_max_threads()) {
|
||||
nslice = n;
|
||||
} else if (by_residual && metric_type == METRIC_L2) {
|
||||
// make sure we don't make too big LUT tables
|
||||
size_t lut_size_per_query =
|
||||
pq.M * pq.ksub * nprobe * (sizeof(float) + sizeof(uint8_t));
|
||||
|
||||
size_t max_lut_size = precomputed_table_max_bytes;
|
||||
// how many queries we can handle within mem budget
|
||||
size_t nq_ok = std::max(max_lut_size / lut_size_per_query, size_t(1));
|
||||
nslice = roundup(std::max(size_t(n / nq_ok), size_t(1)), omp_get_max_threads());
|
||||
} else {
|
||||
// LUTs unlikely to be a limiting factor
|
||||
nslice = omp_get_max_threads();
|
||||
}
|
||||
|
||||
#pragma omp parallel for reduction(+: ndis, nlist_visited)
|
||||
for (int slice = 0; slice < nslice; slice++) {
|
||||
idx_t i0 = n * slice / nslice;
|
||||
idx_t i1 = n * (slice + 1) / nslice;
|
||||
float *dis_i = distances + i0 * k;
|
||||
idx_t *lab_i = labels + i0 * k;
|
||||
if (impl == 12 || impl == 13) {
|
||||
|
@ -921,9 +950,9 @@ void IndexIVFPQFastScan::search_implem_12(
|
|||
TIC;
|
||||
|
||||
struct QC {
|
||||
int qno;
|
||||
int list_no;
|
||||
int rank;
|
||||
int qno; // sequence number of the query
|
||||
int list_no; // list to visit
|
||||
int rank; // this is the rank'th result of the coarse quantizer
|
||||
};
|
||||
bool single_LUT = !(by_residual && metric_type == METRIC_L2);
|
||||
|
||||
|
@ -947,6 +976,8 @@ void IndexIVFPQFastScan::search_implem_12(
|
|||
}
|
||||
TIC;
|
||||
|
||||
// prepare the result handlers
|
||||
|
||||
std::unique_ptr<SIMDResultHandler<C, true> > handler;
|
||||
AlignedTable<uint16_t> tmp_distances;
|
||||
|
||||
|
@ -979,6 +1010,7 @@ void IndexIVFPQFastScan::search_implem_12(
|
|||
while (i0 < qcs.size()) {
|
||||
uint64_t tt0 = get_cy();
|
||||
|
||||
// find all queries that access this inverted list
|
||||
int list_no = qcs[i0].list_no;
|
||||
size_t i1 = i0 + 1;
|
||||
|
||||
|
@ -996,6 +1028,7 @@ void IndexIVFPQFastScan::search_implem_12(
|
|||
continue;
|
||||
}
|
||||
|
||||
// re-organize LUTs and biases into the right order
|
||||
int nc = i1 - i0;
|
||||
|
||||
std::vector<int> q_map(nc), lut_entries(nc);
|
||||
|
@ -1017,11 +1050,15 @@ void IndexIVFPQFastScan::search_implem_12(
|
|||
LUT.get()
|
||||
);
|
||||
|
||||
// access the inverted list
|
||||
|
||||
ndis += (i1 - i0) * list_size;
|
||||
|
||||
InvertedLists::ScopedCodes codes(invlists, list_no);
|
||||
InvertedLists::ScopedIds ids(invlists, list_no);
|
||||
|
||||
// prepare the handler
|
||||
|
||||
handler->ntotal = list_size;
|
||||
handler->q_map = q_map.data();
|
||||
handler->id_map = ids.get();
|
||||
|
@ -1039,14 +1076,16 @@ void IndexIVFPQFastScan::search_implem_12(
|
|||
else DISPATCH(ReservoirHC)
|
||||
else DISPATCH(SingleResultHC)
|
||||
|
||||
// prepare for next loop
|
||||
i0 = i1;
|
||||
|
||||
uint64_t tt2 = get_cy();
|
||||
t_copy_pack += tt1 - tt0;
|
||||
t_scan += tt2 - tt1;
|
||||
i0 = i1;
|
||||
}
|
||||
TIC;
|
||||
|
||||
// labels is the same array
|
||||
// labels is in-place for HeapHC
|
||||
handler->to_flat_arrays(
|
||||
distances, labels,
|
||||
skip & 16 ? nullptr : normalizers.get()
|
||||
|
|
|
@ -97,13 +97,13 @@ void IndexIVFPQR::add_core (idx_t n, const float *x, const idx_t *xids,
|
|||
#define TOC get_cycles () - t0
|
||||
|
||||
|
||||
void IndexIVFPQR::search_preassigned (idx_t n, const float *x, idx_t k,
|
||||
const idx_t *idx,
|
||||
const float *L1_dis,
|
||||
float *distances, idx_t *labels,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters *params
|
||||
) const
|
||||
void IndexIVFPQR::search_preassigned (
|
||||
idx_t n, const float *x, idx_t k,
|
||||
const idx_t *idx, const float *L1_dis,
|
||||
float *distances, idx_t *labels,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters *params, IndexIVFStats *stats
|
||||
) const
|
||||
{
|
||||
uint64_t t0;
|
||||
TIC;
|
||||
|
|
|
@ -55,7 +55,8 @@ struct IndexIVFPQR: IndexIVFPQ {
|
|||
const float *centroid_dis,
|
||||
float *distances, idx_t *labels,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters *params=nullptr
|
||||
const IVFSearchParameters *params=nullptr,
|
||||
IndexIVFStats *stats=nullptr
|
||||
) const override;
|
||||
|
||||
IndexIVFPQR();
|
||||
|
|
|
@ -129,9 +129,10 @@ void IndexPQ::reconstruct (idx_t key, float * recons) const
|
|||
|
||||
namespace {
|
||||
|
||||
|
||||
struct PQDis: DistanceComputer {
|
||||
template<class PQDecoder>
|
||||
struct PQDistanceComputer: DistanceComputer {
|
||||
size_t d;
|
||||
MetricType metric;
|
||||
Index::idx_t nb;
|
||||
const uint8_t *codes;
|
||||
size_t code_size;
|
||||
|
@ -144,10 +145,11 @@ struct PQDis: DistanceComputer {
|
|||
{
|
||||
const uint8_t *code = codes + i * code_size;
|
||||
const float *dt = precomputed_table.data();
|
||||
PQDecoder decoder(code, pq.nbits);
|
||||
float accu = 0;
|
||||
for (int j = 0; j < pq.M; j++) {
|
||||
accu += dt[*code++];
|
||||
dt += 256;
|
||||
accu += dt[decoder.decode()];
|
||||
dt += 1 << decoder.nbits;
|
||||
}
|
||||
ndis++;
|
||||
return accu;
|
||||
|
@ -155,33 +157,43 @@ struct PQDis: DistanceComputer {
|
|||
|
||||
float symmetric_dis(idx_t i, idx_t j) override
|
||||
{
|
||||
FAISS_THROW_IF_NOT(sdc);
|
||||
const float * sdci = sdc;
|
||||
float accu = 0;
|
||||
const uint8_t *codei = codes + i * code_size;
|
||||
const uint8_t *codej = codes + j * code_size;
|
||||
PQDecoder codei (codes + i * code_size, pq.nbits);
|
||||
PQDecoder codej (codes + j * code_size, pq.nbits);
|
||||
|
||||
for (int l = 0; l < pq.M; l++) {
|
||||
accu += sdci[(*codei++) + (*codej++) * 256];
|
||||
sdci += 256 * 256;
|
||||
accu += sdci[codei.decode() + (codej.decode() << codei.nbits)];
|
||||
sdci += uint64_t(1) << (2 * codei.nbits);
|
||||
}
|
||||
ndis++;
|
||||
return accu;
|
||||
}
|
||||
|
||||
explicit PQDis(const IndexPQ& storage, const float* /*q*/ = nullptr)
|
||||
: pq(storage.pq) {
|
||||
explicit PQDistanceComputer(const IndexPQ& storage)
|
||||
: pq(storage.pq) {
|
||||
precomputed_table.resize(pq.M * pq.ksub);
|
||||
nb = storage.ntotal;
|
||||
d = storage.d;
|
||||
metric = storage.metric_type;
|
||||
codes = storage.codes.data();
|
||||
code_size = pq.code_size;
|
||||
FAISS_ASSERT(pq.ksub == 256);
|
||||
FAISS_ASSERT(pq.sdc_table.size() == pq.ksub * pq.ksub * pq.M);
|
||||
sdc = pq.sdc_table.data();
|
||||
if (pq.sdc_table.size() == pq.ksub * pq.ksub * pq.M) {
|
||||
sdc = pq.sdc_table.data();
|
||||
} else {
|
||||
sdc = nullptr;
|
||||
}
|
||||
ndis = 0;
|
||||
}
|
||||
|
||||
void set_query(const float *x) override {
|
||||
pq.compute_distance_table(x, precomputed_table.data());
|
||||
if (metric == METRIC_L2) {
|
||||
pq.compute_distance_table(x, precomputed_table.data());
|
||||
} else {
|
||||
pq.compute_inner_prod_table(x, precomputed_table.data());
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -190,8 +202,13 @@ struct PQDis: DistanceComputer {
|
|||
|
||||
|
||||
DistanceComputer * IndexPQ::get_distance_computer() const {
|
||||
FAISS_THROW_IF_NOT(pq.nbits == 8);
|
||||
return new PQDis(*this);
|
||||
if (pq.nbits == 8) {
|
||||
return new PQDistanceComputer<PQDecoder8>(*this);
|
||||
} else if (pq.nbits == 16) {
|
||||
return new PQDistanceComputer<PQDecoder16>(*this);
|
||||
} else {
|
||||
return new PQDistanceComputer<PQDecoderGeneric>(*this);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
#include <omp.h>
|
||||
|
||||
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
#include <faiss/utils/utils.h>
|
||||
#include <faiss/utils/random.h>
|
||||
|
@ -38,6 +39,7 @@ IndexPQFastScan::IndexPQFastScan(
|
|||
Index(d, metric), pq(d, M, nbits),
|
||||
bbs(bbs), ntotal2(0), M2(roundup(M, 2))
|
||||
{
|
||||
FAISS_THROW_IF_NOT(nbits == 4);
|
||||
is_trained = false;
|
||||
}
|
||||
|
||||
|
@ -231,6 +233,7 @@ void IndexPQFastScan::search_dispatch_implem(
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
if (implem == 1) {
|
||||
FAISS_THROW_IF_NOT(orig_codes);
|
||||
FAISS_THROW_IF_NOT(is_max);
|
||||
|
|
|
@@ -85,6 +85,7 @@ struct IndexPQFastScan: Index {

         idx_t n, const float* x, idx_t k,
         float* distances, idx_t* labels) const;

+    template<class C>
     void search_implem_12(
         idx_t n, const float* x, idx_t k,
@@ -15,6 +15,7 @@

 #include <memory>

 #include <faiss/impl/FaissAssert.h>
+#include <faiss/impl/AuxIndexStructures.h>

 namespace faiss {
@@ -282,6 +283,52 @@ void IndexPreTransform::sa_decode (idx_t n, const uint8_t *bytes,

     }
 }

+namespace {
+
+struct PreTransformDistanceComputer: DistanceComputer {
+    const IndexPreTransform *index;
+    std::unique_ptr<DistanceComputer> sub_dc;
+    std::unique_ptr<const float []> query;
+
+    explicit PreTransformDistanceComputer(const IndexPreTransform *index):
+        index(index),
+        sub_dc(index->index->get_distance_computer())
+    {}
+
+    void set_query(const float *x) override {
+        const float *xt = index->apply_chain (1, x);
+        if (xt == x) {
+            sub_dc->set_query (x);
+        } else {
+            query.reset(xt);
+            sub_dc->set_query (xt);
+        }
+    }
+
+    float symmetric_dis(idx_t i, idx_t j) override
+    {
+        return sub_dc->symmetric_dis(i, j);
+    }
+
+    float operator () (idx_t i) override
+    {
+        return (*sub_dc)(i);
+    }
+};
+
+} // anonymous namespace
+
+DistanceComputer * IndexPreTransform::get_distance_computer() const {
+    if (chain.empty()) {
+        return index->get_distance_computer();
+    } else {
+        return new PreTransformDistanceComputer(this);
+    }
+}
+
 } // namespace faiss
@@ -77,6 +77,8 @@ struct IndexPreTransform: Index {

     void reverse_chain (idx_t n, const float* xt, float* x) const;

+    DistanceComputer * get_distance_computer() const override;
+
     /* standalone codec interface */
     size_t sa_code_size () const override;
     void sa_encode (idx_t n, const float *x,
@@ -0,0 +1,253 @@

/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <faiss/IndexRefine.h>

#include <faiss/utils/distances.h>
#include <faiss/utils/utils.h>
#include <faiss/utils/Heap.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/IndexFlat.h>

namespace faiss {

/***************************************************
 * IndexRefine
 ***************************************************/

IndexRefine::IndexRefine (Index *base_index, Index *refine_index):
    Index (base_index->d, base_index->metric_type),
    base_index (base_index),
    refine_index (refine_index)
{
    own_fields = own_refine_index = false;
    FAISS_THROW_IF_NOT (base_index->d == refine_index->d);
    FAISS_THROW_IF_NOT (base_index->metric_type == refine_index->metric_type);
    is_trained = base_index->is_trained && refine_index->is_trained;
    FAISS_THROW_IF_NOT (base_index->ntotal == refine_index->ntotal);
    ntotal = base_index->ntotal;
}

IndexRefine::IndexRefine ():
    base_index(nullptr), refine_index(nullptr),
    own_fields(false), own_refine_index(false)
{
}

void IndexRefine::train (idx_t n, const float *x)
{
    base_index->train (n, x);
    refine_index->train (n, x);
    is_trained = true;
}

void IndexRefine::add (idx_t n, const float *x) {
    FAISS_THROW_IF_NOT (is_trained);
    base_index->add (n, x);
    refine_index->add (n, x);
    ntotal = refine_index->ntotal;
}

void IndexRefine::reset ()
{
    base_index->reset ();
    refine_index->reset ();
    ntotal = 0;
}

namespace {

typedef faiss::Index::idx_t idx_t;

template<class C>
static void reorder_2_heaps (
    idx_t n,
    idx_t k, idx_t *labels, float *distances,
    idx_t k_base, const idx_t *base_labels, const float *base_distances)
{
#pragma omp parallel for
    for (idx_t i = 0; i < n; i++) {
        idx_t *idxo = labels + i * k;
        float *diso = distances + i * k;
        const idx_t *idxi = base_labels + i * k_base;
        const float *disi = base_distances + i * k_base;

        heap_heapify<C> (k, diso, idxo, disi, idxi, k);
        if (k_base != k) { // add remaining elements
            heap_addn<C> (k, diso, idxo, disi + k, idxi + k, k_base - k);
        }
        heap_reorder<C> (k, diso, idxo);
    }
}

} // anonymous namespace

void IndexRefine::search (
        idx_t n, const float *x, idx_t k,
        float *distances, idx_t *labels) const
{
    FAISS_THROW_IF_NOT (is_trained);
    idx_t k_base = idx_t (k * k_factor);
    idx_t * base_labels = labels;
    float * base_distances = distances;
    ScopeDeleter<idx_t> del1;
    ScopeDeleter<float> del2;

    if (k != k_base) {
        base_labels = new idx_t [n * k_base];
        del1.set (base_labels);
        base_distances = new float [n * k_base];
        del2.set (base_distances);
    }

    base_index->search (n, x, k_base, base_distances, base_labels);

    for (int i = 0; i < n * k_base; i++)
        assert (base_labels[i] >= -1 &&
                base_labels[i] < ntotal);

    // parallelize over queries
#pragma omp parallel if (n > 1)
    {
        std::unique_ptr<DistanceComputer> dc(
            refine_index->get_distance_computer()
        );
#pragma omp for
        for (idx_t i = 0; i < n; i++) {
            dc->set_query(x + i * d);
            idx_t ij = i * k_base;
            for (idx_t j = 0; j < k_base; j++) {
                idx_t idx = base_labels[ij];
                if (idx < 0) break;
                base_distances[ij] = (*dc)(idx);
                ij++;
            }
        }
    }

    // sort and store result
    if (metric_type == METRIC_L2) {
        typedef CMax <float, idx_t> C;
        reorder_2_heaps<C> (
            n, k, labels, distances,
            k_base, base_labels, base_distances);

    } else if (metric_type == METRIC_INNER_PRODUCT) {
        typedef CMin <float, idx_t> C;
        reorder_2_heaps<C> (
            n, k, labels, distances,
            k_base, base_labels, base_distances);
    } else {
        FAISS_THROW_MSG("Metric type not supported");
    }
}

void IndexRefine::reconstruct (idx_t key, float * recons) const {
    refine_index->reconstruct (key, recons);
}

IndexRefine::~IndexRefine ()
{
    if (own_fields) delete base_index;
    if (own_refine_index) delete refine_index;
}

/***************************************************
 * IndexRefineFlat
 ***************************************************/

IndexRefineFlat::IndexRefineFlat (Index *base_index):
    IndexRefine(base_index, new IndexFlat(base_index->d, base_index->metric_type))
{
    is_trained = base_index->is_trained;
    own_refine_index = true;
    FAISS_THROW_IF_NOT_MSG (base_index->ntotal == 0,
                            "base_index should be empty in the beginning");
}

IndexRefineFlat::IndexRefineFlat (Index *base_index, const float *xb):
    IndexRefine (base_index, new IndexFlat(base_index->d, base_index->metric_type))
{
    is_trained = base_index->is_trained;
    own_refine_index = true;
    refine_index->add (base_index->ntotal, xb);
    ntotal = base_index->ntotal;
}

IndexRefineFlat::IndexRefineFlat():
    IndexRefine()
{
    own_refine_index = true;
}

void IndexRefineFlat::search (
        idx_t n, const float *x, idx_t k,
        float *distances, idx_t *labels) const
{
    FAISS_THROW_IF_NOT (is_trained);
    idx_t k_base = idx_t (k * k_factor);
    idx_t * base_labels = labels;
    float * base_distances = distances;
    ScopeDeleter<idx_t> del1;
    ScopeDeleter<float> del2;

    if (k != k_base) {
        base_labels = new idx_t [n * k_base];
        del1.set (base_labels);
        base_distances = new float [n * k_base];
        del2.set (base_distances);
    }

    base_index->search (n, x, k_base, base_distances, base_labels);

    for (int i = 0; i < n * k_base; i++)
        assert (base_labels[i] >= -1 &&
                base_labels[i] < ntotal);

    // compute refined distances
    auto rf = dynamic_cast<const IndexFlat *>(refine_index);
    FAISS_THROW_IF_NOT(rf);

    rf->compute_distance_subset (
        n, x, k_base, base_distances, base_labels);

    // sort and store result
    if (metric_type == METRIC_L2) {
        typedef CMax <float, idx_t> C;
        reorder_2_heaps<C> (
            n, k, labels, distances,
            k_base, base_labels, base_distances);

    } else if (metric_type == METRIC_INNER_PRODUCT) {
        typedef CMin <float, idx_t> C;
        reorder_2_heaps<C> (
            n, k, labels, distances,
            k_base, base_labels, base_distances);
    } else {
        FAISS_THROW_MSG("Metric type not supported");
    }
}

} // namespace faiss
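Per query, reorder_2_heaps above amounts to a partial sort of the k_base refined candidates, keeping the best k. A numpy sketch of the METRIC_L2 (CMax, ascending) case:

    import numpy as np

    k, k_base = 10, 40
    base_distances = np.random.rand(k_base).astype('float32')
    base_labels = np.random.permutation(1000)[:k_base].astype('int64')

    order = np.argsort(base_distances)[:k]   # keep the k smallest distances
    distances = base_distances[order]
    labels = base_labels[order]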
@@ -0,0 +1,73 @@

/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <faiss/Index.h>

namespace faiss {

/** Index that queries in a base_index (a fast one) and refines the
 * results with an exact search, hopefully improving the results.
 */
struct IndexRefine: Index {

    /// faster index to pre-select the vectors that should be filtered
    Index *base_index;

    /// refinement index
    Index *refine_index;

    bool own_fields;       ///< should the base index be deallocated?
    bool own_refine_index; ///< same with the refinement index

    /// factor between k requested in search and the k requested from
    /// the base_index (should be >= 1)
    float k_factor = 1;

    /// initialize from empty index
    IndexRefine (Index *base_index, Index *refine_index);

    IndexRefine ();

    void train(idx_t n, const float* x) override;

    void add(idx_t n, const float* x) override;

    void reset() override;

    void search(
        idx_t n, const float* x, idx_t k,
        float* distances, idx_t* labels) const override;

    // reconstruct is routed to the refine_index
    void reconstruct (idx_t key, float * recons) const override;

    ~IndexRefine() override;
};

/** Version where the refinement index is an IndexFlat. It has one additional
 * constructor that takes a table of elements to add to the flat refinement
 * index */
struct IndexRefineFlat: IndexRefine {
    explicit IndexRefineFlat (Index *base_index);
    IndexRefineFlat(Index *base_index, const float *xb);

    IndexRefineFlat();

    void search(
        idx_t n, const float* x, idx_t k,
        float* distances, idx_t* labels) const override;
};

} // namespace faiss
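A minimal Python sketch of the two-stage search this class implements (random data, hypothetical sizes; "PQ4x4np" as the lossy first stage, as in the updated TestRefine below):

    import faiss
    import numpy as np

    d = 32
    xb = np.random.rand(2000, d).astype('float32')
    xq = np.random.rand(10, d).astype('float32')

    base = faiss.index_factory(d, "PQ4x4np")   # fast, compressed first stage
    base.train(xb)
    base.add(xb)

    flat = faiss.IndexFlat(d)                  # exact second stage
    flat.add(xb)

    index = faiss.IndexRefine(base, flat)
    index.k_factor = 10.0                      # fetch 10*k candidates, re-rank, keep k
    D, I = index.search(xq, 10)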
@@ -192,7 +192,7 @@ void IndexIVFScalarQuantizer::encode_vectors(idx_t n, const float* x,

     size_t coarse_size = include_listnos ? coarse_code_size () : 0;
     memset(codes, 0, (code_size + coarse_size) * n);

-#pragma omp parallel if(n > 1)
+#pragma omp parallel if(n > 1000)
     {
         std::vector<float> residual (d);

@@ -222,7 +222,7 @@ void IndexIVFScalarQuantizer::sa_decode (idx_t n, const uint8_t *codes,

     std::unique_ptr<ScalarQuantizer::Quantizer> squant (sq.select_quantizer ());
     size_t coarse_size = coarse_code_size ();

-#pragma omp parallel if(n > 1)
+#pragma omp parallel if(n > 1000)
     {
         std::vector<float> residual (d);
@@ -82,7 +82,7 @@ struct IndexScalarQuantizer: Index {

 /** An IVF implementation where the components of the residuals are
- * encoded with a scalar uniform quantizer. All distance computations
+ * encoded with a scalar quantizer. All distance computations
  * are asymmetric, so the encoded vectors are decoded and approximate
  * distances are computed.
  */
@@ -204,7 +204,8 @@ struct RangeSearchPartialResult: BufferList {

 struct DistanceComputer {
     using idx_t = Index::idx_t;

-    /// called before computing distances
+    /// called before computing distances. Pointer x should remain valid
+    /// while operator () is called
     virtual void set_query(const float *x) = 0;

     /// compute distance of vector i to current query
@@ -219,12 +219,14 @@ struct PQDecoderGeneric {

 };

 struct PQDecoder8 {
+    static const int nbits = 8;
     const uint8_t *code;
     PQDecoder8(const uint8_t *code, int nbits);
     uint64_t decode();
 };

 struct PQDecoder16 {
+    static const int nbits = 16;
     const uint16_t *code;
     PQDecoder16(const uint8_t *code, int nbits);
     uint64_t decode();
@@ -39,8 +39,12 @@ namespace faiss {

  * that hides the template mess.
  ********************************************************************/

-#if defined(__F16C__) && defined(__AVX2__)
+#ifdef __AVX2__
+#ifdef __F16C__
 #define USE_F16C
+#else
+#warning "Cannot enable AVX optimizations in scalar quantizer if -mf16c is not set as well"
+#endif
 #endif
@@ -39,6 +39,7 @@

 #include <faiss/IndexLattice.h>
 #include <faiss/IndexPQFastScan.h>
 #include <faiss/IndexIVFPQFastScan.h>
+#include <faiss/IndexRefine.h>

 #include <faiss/IndexBinaryFlat.h>
 #include <faiss/IndexBinaryFromFloat.h>
@@ -551,14 +552,20 @@ Index *read_index (IOReader *f, int io_flags) {

         read_ProductQuantizer (&imiq->pq, f);
         idx = imiq;
     } else if(h == fourcc ("IxRF")) {
-        IndexRefineFlat *idxrf = new IndexRefineFlat ();
+        IndexRefine *idxrf = new IndexRefine ();
         read_index_header (idxrf, f);
         idxrf->base_index = read_index(f, io_flags);
-        idxrf->own_fields = true;
-        IndexFlat *rf = dynamic_cast<IndexFlat*> (read_index (f, io_flags));
-        std::swap (*rf, idxrf->refine_index);
-        delete rf;
+        idxrf->refine_index = read_index(f, io_flags);
         READ1 (idxrf->k_factor);
+        if (dynamic_cast<IndexFlat*>(idxrf->refine_index)) {
+            // then make a RefineFlat with it
+            IndexRefine *idxrf_old = idxrf;
+            idxrf = new IndexRefineFlat();
+            *idxrf = *idxrf_old;
+            delete idxrf_old;
+        }
+        idxrf->own_fields = true;
         idxrf->own_refine_index = true;
         idx = idxrf;
     } else if(h == fourcc ("IxMp") || h == fourcc ("IxM2")) {
         bool is_map2 = h == fourcc ("IxM2");
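In practice this means a refined index survives a write/read roundtrip, and the reader upgrades it to IndexRefineFlat when the stored refinement index turns out to be flat. A sketch (the file name is arbitrary):

    import faiss
    import numpy as np

    d = 16
    xb = np.random.rand(500, d).astype('float32')
    index = faiss.index_factory(d, "SQ8,RFlat")
    index.train(xb)
    index.add(xb)

    faiss.write_index(index, "refine_demo.index")   # written with fourcc "IxRF"
    index2 = faiss.read_index("refine_demo.index")  # comes back as IndexRefineFlat
    print(type(index2))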
@@ -39,6 +39,7 @@

 #include <faiss/IndexLattice.h>
 #include <faiss/IndexPQFastScan.h>
 #include <faiss/IndexIVFPQFastScan.h>
+#include <faiss/IndexRefine.h>

 #include <faiss/IndexBinaryFlat.h>
 #include <faiss/IndexBinaryFromFloat.h>
@@ -392,13 +393,13 @@ void write_index (const Index *idx, IOWriter *f) {

         WRITE1 (h);
         write_index_header (imiq, f);
         write_ProductQuantizer (&imiq->pq, f);
-    } else if(const IndexRefineFlat * idxrf =
-              dynamic_cast<const IndexRefineFlat *> (idx)) {
+    } else if(const IndexRefine * idxrf =
+              dynamic_cast<const IndexRefine *> (idx)) {
         uint32_t h = fourcc ("IxRF");
         WRITE1 (h);
         write_index_header (idxrf, f);
         write_index (idxrf->base_index, f);
-        write_index (&idxrf->refine_index, f);
+        write_index (idxrf->refine_index, f);
         WRITE1 (idxrf->k_factor);
     } else if(const IndexIDMap * idxmap =
               dynamic_cast<const IndexIDMap *> (idx)) {
@@ -36,6 +36,8 @@

 #include <faiss/IndexLattice.h>
 #include <faiss/IndexPQFastScan.h>
+#include <faiss/IndexIVFPQFastScan.h>
+#include <faiss/IndexRefine.h>

 #include <faiss/IndexBinaryFlat.h>
 #include <faiss/IndexBinaryHNSW.h>
@@ -64,11 +66,53 @@ struct VTChain {

 /// what kind of training does this coarse quantizer require?
 char get_trains_alone(const Index *coarse_quantizer) {
     return
         dynamic_cast<const IndexFlat*>(coarse_quantizer) ? 0 :
         // multi index just needs to be quantized
         dynamic_cast<const MultiIndexQuantizer*>(coarse_quantizer) ? 1 :
-        dynamic_cast<const IndexHNSWFlat*>(coarse_quantizer) ? 2 :
-        0;
+        2; // for complicated indexes, we assume they can't be used as a kmeans index
 }

+bool str_ends_with(const std::string& s, const std::string& suffix)
+{
+    return s.rfind(suffix) == std::abs(int(s.size()-suffix.size()));
+}
+
+// check if ends with suffix followed by digits
+bool str_ends_with_digits(const std::string& s, const std::string& suffix)
+{
+    int i;
+    for(i = s.length() - 1; i >= 0; i--) {
+        if (!isdigit(s[i])) break;
+    }
+    return str_ends_with(s.substr(0, i + 1), suffix);
+}
+
+void find_matching_parentheses(const std::string &s, int & i0, int & i1) {
+    int st = 0;
+    for (int i = 0; i < s.length(); i++) {
+        if (s[i] == '(') {
+            if (st == 0) {
+                i0 = i;
+            }
+            st++;
+        }
+
+        if (s[i] == ')') {
+            st--;
+            if (st == 0) {
+                i1 = i;
+                return;
+            }
+            if (st < 0) {
+                FAISS_THROW_FMT("factory string %s: unbalanced parentheses", s.c_str());
+            }
+        }
+    }
+    FAISS_THROW_FMT("factory string %s: unbalanced parentheses st=%d", s.c_str(), st);
+}

 } // anonymous namespace
@@ -78,31 +122,32 @@ Index *index_factory (int d, const char *description_in, MetricType metric)

                         metric == METRIC_INNER_PRODUCT);
     VTChain vts;
     Index *coarse_quantizer = nullptr;
-    std::unique_ptr<Index> parenthesis_index;
+    std::string parenthesis_ivf, parenthesis_refine;
     Index *index = nullptr;
     bool add_idmap = false;
-    bool make_IndexRefineFlat = false;
+    int d_in = d;

     ScopeDeleter1<Index> del_coarse_quantizer, del_index;

     std::string description(description_in);
     char *ptr;

-    if (description.find('(') != std::string::npos) {
+    // handle indexes in parentheses
+    while (description.find('(') != std::string::npos) {
         // then we make a sub-index and remove the () from the description
-        int i0 = description.find('(');
-        int i1 = description.find(')');
-        FAISS_THROW_IF_NOT_MSG(
-            i1 != std::string::npos, "string must contain closing parenthesis");
+        int i0, i1;
+        find_matching_parentheses(description, i0, i1);

         std::string sub_description = description.substr(i0 + 1, i1 - i0 - 1);
         // printf("substring=%s\n", sub_description.c_str());

-        parenthesis_index.reset(index_factory(d, sub_description.c_str(), metric));
+        if (str_ends_with_digits(description.substr(0, i0), "IVF")) {
+            parenthesis_ivf = sub_description;
+        } else if (str_ends_with(description.substr(0, i0), "Refine")) {
+            parenthesis_refine = sub_description;
+        } else {
+            FAISS_THROW_MSG("don't know what to do with parenthesis index");
+        }
         description = description.erase(i0, i1 - i0 + 1);

         // printf("new description=%s\n", description.c_str());
     }

     int64_t ncentroids = -1;
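The parenthesis syntax handled here is what the new factory tests below exercise, e.g.:

    import faiss

    # coarse quantizer spelled out in parentheses after IVF
    index = faiss.index_factory(50, "PCA30,IVF32(PQ15),Flat")

    # refinement index in parentheses after Refine
    index = faiss.index_factory(50, "IVF32,Flat,Refine(PQ25x12)")

    # parentheses may nest
    index = faiss.index_factory(50, "IVF1000(IVF20,SQ4,Refine(SQ8)),Flat")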
@@ -162,12 +207,14 @@ Index *index_factory (int d, const char *description_in, MetricType metric)

     // coarse quantizers
     } else if (!coarse_quantizer &&
               sscanf (tok, "IVF%" PRId64 "_HNSW%d", &ncentroids, &M) == 2) {
-        coarse_quantizer_1 = new IndexHNSWFlat (d, M);
+        coarse_quantizer_1 = new IndexHNSWFlat (d, M, metric);

     } else if (!coarse_quantizer &&
               sscanf (tok, "IVF%" PRId64, &ncentroids) == 1) {
-        if (parenthesis_index) {
-            coarse_quantizer_1 = parenthesis_index.release();
+        if (!parenthesis_ivf.empty()) {
+            coarse_quantizer_1 =
+                index_factory(d, parenthesis_ivf.c_str(), metric);

         } else if (metric == METRIC_L2) {
             coarse_quantizer_1 = new IndexFlatL2 (d);
         } else {
@@ -254,11 +301,13 @@ Index *index_factory (int d, const char *description_in, MetricType metric)

         index_ivf->own_fields = true;
         index_1 = index_ivf;
     } else if (!index && (
-              sscanf (tok, "PQ%dx4fs_%d", &M, &bbs) == 2 ||
-              (sscanf (tok, "PQ%dx4f%c", &M, &c) == 2 && c == 's') )) {
+              sscanf (tok, "PQ%dx4fs_%d", &M, &bbs) == 2 ||
+              (sscanf (tok, "PQ%dx4f%c", &M, &c) == 2 && c == 's') ||
+              (sscanf (tok, "PQ%dx4fs%c", &M, &c) == 2 && c == 'r'))) {
         if (bbs == -1) {
             bbs = 32;
         }
+        bool by_residual = str_ends_with(stok, "fsr");
         if (coarse_quantizer) {
             IndexIVFPQFastScan *index_ivf = new IndexIVFPQFastScan(
                 coarse_quantizer, d, ncentroids, M, 4, metric, bbs

@@ -266,13 +315,14 @@ Index *index_factory (int d, const char *description_in, MetricType metric)

             index_ivf->quantizer_trains_alone =
                 get_trains_alone (coarse_quantizer);
             index_ivf->metric_type = metric;
+            index_ivf->by_residual = by_residual;
             index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
             del_coarse_quantizer.release ();
             index_ivf->own_fields = true;
             index_1 = index_ivf;
         } else {
             IndexPQFastScan *index_pq = new IndexPQFastScan (
-                d, M, nbit, metric, bbs
+                d, M, 4, metric, bbs
             );
             index_1 = index_pq;
         }
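The new trailing "r" spelling selects residual encoding for the IVF fast-scan variant (see test_residual below):

    import faiss

    index = faiss.index_factory(50, "IVF1000,PQ25x4fsr")
    assert index.by_residual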
@@ -347,7 +397,12 @@ Index *index_factory (int d, const char *description_in, MetricType metric)

         FAISS_THROW_IF_NOT(!coarse_quantizer);
         index_1 = new IndexLattice(d, M, nbit, r2);
     } else if (stok == "RFlat") {
-        make_IndexRefineFlat = true;
+        parenthesis_refine = "Flat";
+    } else if (stok == "Refine") {
+        FAISS_THROW_IF_NOT_MSG(
+            !parenthesis_refine.empty(),
+            "Refine index should be provided in parentheses"
+        );
     } else {
         FAISS_THROW_FMT( "could not parse token \"%s\" in %s\n",
                          tok, description_in);

@@ -404,8 +459,10 @@ Index *index_factory (int d, const char *description_in, MetricType metric)

         index = index_pt;
     }

-    if (make_IndexRefineFlat) {
-        IndexRefineFlat *index_rf = new IndexRefineFlat (index);
+    if (!parenthesis_refine.empty()) {
+        Index *refine_index = index_factory(d_in, parenthesis_refine.c_str(), metric);
+        IndexRefine *index_rf = new IndexRefine(index, refine_index);
+        index_rf->own_refine_index = true;
         index_rf->own_fields = true;
         index = index_rf;
     }
@@ -494,7 +494,8 @@ add_ref_in_constructor(IndexIDMap2, 0)

 add_ref_in_constructor(IndexHNSW, 0)
 add_ref_in_method(IndexShards, 'add_shard', 0)
 add_ref_in_method(IndexBinaryShards, 'add_shard', 0)
-# add_ref_in_constructor(IndexRefineFlat, 0)
+add_ref_in_constructor(IndexRefineFlat, {2:[0], 1:[0]})
+add_ref_in_constructor(IndexRefine, {2:[0, 1]})

 add_ref_in_constructor(IndexBinaryIVF, 0)
 add_ref_in_constructor(IndexBinaryFromFloat, 0)

@@ -510,24 +511,6 @@ add_ref_in_constructor(BufferedIOReader, 0)

 # seems really marginal...
 # remove_ref_from_method(IndexReplicas, 'removeIndex', 0)

-def handle_IndexRefineFlat(the_class):
-
-    original_init = the_class.__init__
-
-    def replacement_init(self, *args):
-        if len(args) == 2:
-            index, xb = args
-            assert xb.shape == (index.ntotal, index.d)
-            xb = swig_ptr(xb)
-            args = (index, xb)
-
-        original_init(self, *args)
-        self.referenced_objects = [args[0]]
-
-    the_class.__init__ = replacement_init
-
-handle_IndexRefineFlat(IndexRefineFlat)
-
 ###########################################
 # GPU functions
 ###########################################
@@ -88,6 +88,8 @@ typedef uint64_t size_t;

 #include <faiss/impl/HNSW.h>
 #include <faiss/IndexHNSW.h>
 #include <faiss/MetaIndexes.h>
+#include <faiss/IndexRefine.h>
+
 #include <faiss/impl/FaissAssert.h>

 #include <faiss/IndexBinaryFlat.h>
@@ -356,6 +358,8 @@ void gpu_sync_all_devices()

 %include <faiss/utils/random.h>

 %include <faiss/MetricType.h>

+%newobject *::get_distance_computer() const;
+
 %include <faiss/Index.h>
 %include <faiss/Clustering.h>

@@ -368,6 +372,7 @@ void gpu_sync_all_devices()

 %include <faiss/VectorTransform.h>
 %include <faiss/IndexPreTransform.h>
 %include <faiss/IndexFlat.h>
+%include <faiss/IndexRefine.h>
 %include <faiss/IndexLSH.h>
 %include <faiss/impl/PolysemousTraining.h>
 %include <faiss/IndexPQ.h>
@@ -436,6 +441,18 @@ void gpu_sync_all_devices()

 %template(IndexIDMap2) faiss::IndexIDMap2Template<faiss::Index>;
 %template(IndexBinaryIDMap2) faiss::IndexIDMap2Template<faiss::IndexBinary>;

+%ignore faiss::BufferList::Buffer;
+%ignore faiss::RangeSearchPartialResult::QueryResult;
+%ignore faiss::IDSelectorBatch::set;
+%ignore faiss::IDSelectorBatch::bloom;
+%ignore faiss::InterruptCallback::instance;
+%ignore faiss::InterruptCallback::lock;
+
+%include <faiss/impl/AuxIndexStructures.h>
+
 #ifdef GPU_WRAPPER

 // quiet SWIG warnings
@@ -513,6 +530,7 @@ void gpu_sync_all_devices()

         DOWNCAST ( IndexIVF )
         DOWNCAST ( IndexFlat )
         DOWNCAST ( IndexRefineFlat )
+        DOWNCAST ( IndexRefine )
         DOWNCAST ( IndexPQFastScan )
         DOWNCAST ( IndexPQ )
         DOWNCAST ( IndexScalarQuantizer )
@@ -907,15 +925,6 @@ void * cast_integer_to_void_ptr (long long x) {

  * Range search interface
  *******************************************************************/

-%ignore faiss::BufferList::Buffer;
-%ignore faiss::RangeSearchPartialResult::QueryResult;
-%ignore faiss::IDSelectorBatch::set;
-%ignore faiss::IDSelectorBatch::bloom;
-
-%ignore faiss::InterruptCallback::instance;
-%ignore faiss::InterruptCallback::lock;
-%include <faiss/impl/AuxIndexStructures.h>
-
 %inline %{
@@ -42,7 +42,7 @@ float round_uint8_and_mul(float *tab, size_t n) {

     return multiplier;
 }

 // there can be NaNs in tables, they should be ignored
 float tab_min(const float *tab, size_t n) {
     float min = HUGE_VAL;
     for(int i = 0; i < n; i++) {

@@ -185,6 +185,7 @@ void quantize_LUT_and_bias(

         round_tab(bias, nprobe, a, bias_min, biasq);

+    } else if (biasq) {
         // LUT is 3D
         std::vector<float> mins(nprobe * M);
         std::vector<float> bias2(nprobe);
         float bias_min = tab_min(bias, nprobe);
@@ -0,0 +1,47 @@

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import unittest
import faiss


class TestParameterSpace(unittest.TestCase):

    def test_nprobe(self):
        index = faiss.index_factory(32, "IVF32,Flat")
        ps = faiss.ParameterSpace()
        ps.set_index_parameter(index, "nprobe", 5)
        self.assertEqual(index.nprobe, 5)

    def test_nprobe_2(self):
        index = faiss.index_factory(32, "IDMap,IVF32,Flat")
        ps = faiss.ParameterSpace()
        ps.set_index_parameter(index, "nprobe", 5)
        index2 = faiss.downcast_index(index.index)
        self.assertEqual(index2.nprobe, 5)

    def test_nprobe_3(self):
        index = faiss.index_factory(32, "IVF32,SQ8,RFlat")
        ps = faiss.ParameterSpace()
        ps.set_index_parameter(index, "nprobe", 5)
        index2 = faiss.downcast_index(index.base_index)
        self.assertEqual(index2.nprobe, 5)

    def test_nprobe_4(self):
        index = faiss.index_factory(32, "PCAR32,IVF32,SQ8,RFlat")
        ps = faiss.ParameterSpace()

        ps.set_index_parameter(index, "nprobe", 5)
        index2 = faiss.downcast_index(index.base_index)
        index2 = faiss.downcast_index(index2.index)
        self.assertEqual(index2.nprobe, 5)

    def test_efSearch(self):
        index = faiss.index_factory(32, "IVF32_HNSW32,SQ8")
        ps = faiss.ParameterSpace()
        ps.set_index_parameter(index, "quantizer_efSearch", 5)
        index2 = faiss.downcast_index(index.quantizer)
        self.assertEqual(index2.hnsw.efSearch, 5)
@@ -3,12 +3,13 @@

 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

 from __future__ import absolute_import, division, print_function

 import numpy as np
 import unittest
 import faiss

+from faiss.contrib import factory_tools


 class TestFactory(unittest.TestCase):
@@ -75,6 +76,7 @@ class TestFactory(unittest.TestCase):

     def test_factory_fast_scan(self):
         index = faiss.index_factory(56, "PQ28x4fs")
         self.assertEqual(index.bbs, 32)
+        self.assertEqual(index.pq.nbits, 4)
         index = faiss.index_factory(56, "PQ28x4fs_64")
         self.assertEqual(index.bbs, 64)
         index = faiss.index_factory(56, "IVF50,PQ28x4fs_64", faiss.METRIC_INNER_PRODUCT)
@@ -89,6 +91,45 @@ class TestFactory(unittest.TestCase):

         quantizer = faiss.downcast_index(index.quantizer)
         self.assertEqual(quantizer.pq.M, 25)

+    def test_parenthesis_2(self):
+        index = faiss.index_factory(50, "PCA30,IVF32(PQ15),Flat")
+        index_ivf = faiss.extract_index_ivf(index)
+        quantizer = faiss.downcast_index(index_ivf.quantizer)
+        self.assertEqual(quantizer.pq.M, 15)
+        self.assertEqual(quantizer.d, 30)
+
+    def test_parenthesis_refine(self):
+        index = faiss.index_factory(50, "IVF32,Flat,Refine(PQ25x12)")
+        rf = faiss.downcast_index(index.refine_index)
+        self.assertEqual(rf.pq.M, 25)
+        self.assertEqual(rf.pq.nbits, 12)
+
+    def test_parenthesis_refine_2(self):
+        # Refine applies on the whole index including pre-transforms
+        index = faiss.index_factory(50, "PCA32,IVF32,Flat,Refine(PQ25x12)")
+        rf = faiss.downcast_index(index.refine_index)
+        self.assertEqual(rf.pq.M, 25)
+
+    def test_nested_parenteses(self):
+        index = faiss.index_factory(50, "IVF1000(IVF20,SQ4,Refine(SQ8)),Flat")
+        q = faiss.downcast_index(index.quantizer)
+        qref = faiss.downcast_index(q.refine_index)
+        # check we can access the scalar quantizer
+        self.assertEqual(qref.sq.code_size, 50)
+
+    def test_residual(self):
+        index = faiss.index_factory(50, "IVF1000,PQ25x4fsr")
+        self.assertTrue(index.by_residual)
+
+class TestCodeSize(unittest.TestCase):
+
+    def test_1(self):
+        self.assertEqual(
+            factory_tools.get_code_size(50, "IVF32,Flat,Refine(PQ25x12)"),
+            50 * 4 + (25 * 12 + 7) // 8
+        )
+
 class TestCloneSize(unittest.TestCase):
@@ -388,7 +388,7 @@ class TestSearchAndReconstruct(unittest.TestCase):

         R_ref = index.reconstruct_n(0, n)
         D, I, R = index.search_and_reconstruct(xq, k)

-        self.assertTrue((D == D_ref).all())
+        np.testing.assert_almost_equal(D, D_ref, decimal=5)
         self.assertTrue((I == I_ref).all())
         self.assertEqual(R.shape[:2], I.shape)
         self.assertEqual(R.shape[2], d)
@@ -303,7 +303,7 @@ class TestSQFlavors(unittest.TestCase):

         index.nprobe = 4  # hopefully more robust than 1
         Dref, Iref = index.search(xq, 10)

-        for pm in 1, 2:
+        for pm in 1, 2, 3:
             index.parallel_mode = pm

             Dnew, Inew = index.search(xq, 10)
@@ -692,7 +692,6 @@ class TestRefine(unittest.TestCase):

         d = 32
         xt, xb, xq = get_dataset_2(d, 2000, 1000, 200)
         index1 = faiss.index_factory(d, "PQ4x4np", metric)
-
         Dref, Iref = faiss.knn(xq, xb, 10, metric)

         index1.train(xt)

@@ -703,7 +702,10 @@ class TestRefine(unittest.TestCase):

         recall1 = (I1 == Iref[:, :1]).sum()

         # add refine index on top
-        index2 = faiss.IndexRefineFlat(index1, xb)
+        index_flat = faiss.IndexFlat(d, metric)
+        index_flat.add(xb)
+
+        index2 = faiss.IndexRefine(index1, index_flat)
         index2.k_factor = 10.0
         D2, I2 = index2.search(xq, 10)
@@ -0,0 +1,53 @@

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import numpy as np
import unittest
import faiss

from faiss.contrib import datasets


class TestDistanceComputer(unittest.TestCase):

    def do_test(self, factory_string, metric_type=faiss.METRIC_L2):
        ds = datasets.SyntheticDataset(32, 1000, 200, 20)

        index = faiss.index_factory(32, factory_string, metric_type)
        index.train(ds.get_train())
        index.add(ds.get_database())
        xq = ds.get_queries()
        Dref, Iref = index.search(xq, 10)
        dc = index.get_distance_computer()
        self.assertTrue(dc.this.own())
        for q in range(ds.nq):
            dc.set_query(faiss.swig_ptr(xq[q]))
            for j in range(10):
                ref_dis = Dref[q, j]
                new_dis = dc(int(Iref[q, j]))
                np.testing.assert_almost_equal(
                    new_dis, ref_dis, decimal=5)

    def test_distance_computer_PQ(self):
        self.do_test("PQ8np")

    def test_distance_computer_SQ(self):
        self.do_test("SQ8")

    def test_distance_computer_SQ6(self):
        self.do_test("SQ6")

    def test_distance_computer_PQbit6(self):
        self.do_test("PQ8x6np")

    def test_distance_computer_PQbit6_ip(self):
        self.do_test("PQ8x6np", faiss.METRIC_INNER_PRODUCT)

    def test_distance_computer_VT(self):
        self.do_test("PCA20,SQ8")