Search and return codes (#3143)
Summary: This PR adds the ability to search an IVF index and return the codes corresponding to the result vectors. It also adds a few functions to compress int arrays into a bit-compact representation.

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3143

Test Plan:
```
buck test //faiss/tests/:test_index_composite -- TestSearchAndReconstruct
buck test //faiss/tests/:test_standalone_codec -- test_arrays
```

Reviewed By: algoriddle
Differential Revision: D51544613
Pulled By: mdouze
fbshipit-source-id: 875f72d0f9140096851592422570efa0f65431fc
parent 467f70edbf
commit b109d086a2
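Before the diff itself, a minimal usage sketch of the new API, based on the Python wrapper and tests added below (the factory string and random data are arbitrary):

```
# Minimal sketch: search an IVF index, get the stored codes of the results
# back, and decode them with sa_decode. With include_listnos=True the coarse
# (list id) code is prepended, which makes the codes sa_decode-compatible.
import faiss
import numpy as np

d = 32
xt = np.random.rand(1000, d).astype('float32')   # training vectors
xb = np.random.rand(1000, d).astype('float32')   # database vectors
xq = np.random.rand(10, d).astype('float32')     # query vectors

index = faiss.index_factory(d, "IVF20,PQ4x4np")
index.train(xt)
index.add(xb)
index.nprobe = 10

# D, I as in a regular search; codes has shape (nq, k, code_size)
D, I, codes = index.search_and_return_codes(xq, 10, include_listnos=True)
recons = index.sa_decode(codes[0])   # approximate vectors of query 0's results
```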
@@ -7,6 +7,7 @@ import argparse
import os
import sys
import time
import json

import faiss
import numpy as np
@@ -19,105 +20,6 @@ except ModuleNotFoundError:
sanitize = datasets.sanitize


######################################################
# Command-line parsing
######################################################


parser = argparse.ArgumentParser()


def aa(*args, **kwargs):
    group.add_argument(*args, **kwargs)


group = parser.add_argument_group('dataset options')

aa('--db', default='deep1M', help='dataset')
aa('--compute_gt', default=False, action='store_true',
   help='compute and store the groundtruth')
aa('--force_IP', default=False, action="store_true",
   help='force IP search instead of L2')

group = parser.add_argument_group('index construction')

aa('--indexkey', default='HNSW32', help='index_factory type')
aa('--maxtrain', default=256 * 256, type=int,
   help='maximum number of training points (0 to set automatically)')
aa('--indexfile', default='', help='file to read or write index from')
aa('--add_bs', default=-1, type=int,
   help='add elements index by batches of this size')


group = parser.add_argument_group('IVF options')
aa('--by_residual', default=-1, type=int,
   help="set if index should use residuals (default=unchanged)")
aa('--no_precomputed_tables', action='store_true', default=False,
   help='disable precomputed tables (uses less memory)')
aa('--get_centroids_from', default='',
   help='get the centroids from this index (to speed up training)')
aa('--clustering_niter', default=-1, type=int,
   help='number of clustering iterations (-1 = leave default)')
aa('--train_on_gpu', default=False, action='store_true',
   help='do training on GPU')


group = parser.add_argument_group('index-specific options')
aa('--M0', default=-1, type=int, help='size of base level for HNSW')
aa('--RQ_train_default', default=False, action="store_true",
   help='disable progressive dim training for RQ')
aa('--RQ_beam_size', default=-1, type=int,
   help='set beam size at add time')
aa('--LSQ_encode_ils_iters', default=-1, type=int,
   help='ILS iterations for LSQ')
aa('--RQ_use_beam_LUT', default=-1, type=int,
   help='use beam LUT at add time')

group = parser.add_argument_group('searching')

aa('--k', default=100, type=int, help='nb of nearest neighbors')
aa('--inter', default=False, action='store_true',
   help='use intersection measure instead of 1-recall as metric')
aa('--searchthreads', default=-1, type=int,
   help='nb of threads to use at search time')
aa('--searchparams', nargs='+', default=['autotune'],
   help="search parameters to use (can be autotune or a list of params)")
aa('--n_autotune', default=500, type=int,
   help="max nb of autotune experiments")
aa('--autotune_max', default=[], nargs='*',
   help='set max value for autotune variables format "var:val" (exclusive)')
aa('--autotune_range', default=[], nargs='*',
   help='set complete autotune range, format "var:val1,val2,..."')
aa('--min_test_duration', default=3.0, type=float,
   help='run test at least for so long to avoid jitter')

args = parser.parse_args()

print("args:", args)

os.system('echo -n "nb processors "; '
          'cat /proc/cpuinfo | grep ^processor | wc -l; '
          'cat /proc/cpuinfo | grep ^"model name" | tail -1')

######################################################
# Load dataset
######################################################

ds = datasets.load_dataset(
    dataset=args.db, compute_gt=args.compute_gt)

if args.force_IP:
    ds.metric = "IP"

print(ds)

nq, d = ds.nq, ds.d
nb, d = ds.nb, ds.d


######################################################
# Make index
######################################################

def unwind_index_ivf(index):
    if isinstance(index, faiss.IndexPreTransform):
@@ -125,6 +27,10 @@ def unwind_index_ivf(index):
        vt = index.chain.at(0)
        index_ivf, vt2 = unwind_index_ivf(faiss.downcast_index(index.index))
        assert vt2 is None
        if vt is None:
            vt = lambda x: x
        else:
            vt = faiss.downcast_VectorTransform(vt)
        return index_ivf, vt
    if hasattr(faiss, "IndexRefine") and isinstance(index, faiss.IndexRefine):
        return unwind_index_ivf(faiss.downcast_index(index.base_index))
@@ -157,16 +63,50 @@ def apply_AQ_options(index, args):
        index.rq.use_beam_LUT = args.RQ_use_beam_LUT


if args.indexfile and os.path.exists(args.indexfile):

    print("reading", args.indexfile)
    index = faiss.read_index(args.indexfile)

def eval_setting(index, xq, gt, k, inter, min_time):
    """ evaluate searching in terms of precision vs. speed """
    nq = xq.shape[0]
    ivf_stats = faiss.cvar.indexIVF_stats
    ivf_stats.reset()
    nrun = 0
    t0 = time.time()
    while True:
        D, I = index.search(xq, k)
        nrun += 1
        t1 = time.time()
        if t1 - t0 > min_time:
            break
    ms_per_query = ((t1 - t0) * 1000.0 / nq / nrun)
    res = {
        "ms_per_query": ms_per_query,
        "nrun": nrun
    }
    res["n"] = ms_per_query
    if inter:
        rank = k
        inter_measure = faiss.eval_intersection(gt[:, :rank], I[:, :rank]) / (nq * rank)
        print("%.4f" % inter_measure, end=' ')
        res["inter_measure"] = inter_measure
    else:
        res["recalls"] = {}
        for rank in 1, 10, 100:
            recall = (I[:, :rank] == gt[:, :1]).sum() / float(nq)
            print("%.4f" % recall, end=' ')
            res["recalls"][rank] = recall
    print("   %9.5f  " % ms_per_query, end=' ')
    print("%12d   " % (ivf_stats.ndis / nrun), end=' ')
    print(nrun)
    res["ndis"] = ivf_stats.ndis / nrun
    return res

index_ivf, vec_transform = unwind_index_ivf(index)
if vec_transform is None:
    vec_transform = lambda x: x
######################################################
# Training
######################################################

else:
def run_train(args, ds, res):
    nq, d = ds.nq, ds.d
    nb, d = ds.nb, ds.d

    print("build index, key=", args.indexkey)
@@ -176,10 +116,6 @@ else:
    )

    index_ivf, vec_transform = unwind_index_ivf(index)
    if vec_transform is None:
        vec_transform = lambda x: x
    else:
        vec_transform = faiss.downcast_VectorTransform(vec_transform)

    if args.by_residual != -1:
        by_residual = args.by_residual == 1
@@ -205,9 +141,14 @@ else:
            64)
        print(base_index.nprobe)
    elif isinstance(quantizer, faiss.IndexHNSW):
        print("  update quantizer efSearch=", quantizer.hnsw.efSearch, end=" -> ")
        quantizer.hnsw.efSearch = 40 if index_ivf.nlist < 4e6 else 64
        print(quantizer.hnsw.efSearch)
        hnsw = quantizer.hnsw
        print(
            f"  update HNSW quantizer options, before: "
            f"{hnsw.efSearch=:} {hnsw.efConstruction=:}"
        )
        hnsw.efSearch = 40 if index_ivf.nlist < 4e6 else 64
        hnsw.efConstruction = 200
        print(f"  after: {hnsw.efSearch=:} {hnsw.efConstruction=:}")

    apply_AQ_options(index_ivf or index, args)
@@ -286,182 +227,341 @@ else:

    t0 = time.time()
    index.train(xt2)
    print("  train in %.3f s" % (time.time() - t0))
    res.train_time = time.time() - t0
    print("  train in %.3f s" % res.train_time)
    return index

######################################################
# Populating index
######################################################

def run_add(args, ds, index, res):

    print("adding")
    t0 = time.time()
    if args.add_bs == -1:
        assert args.split == [1, 0], "split not supported with full batch add"
        index.add(sanitize(ds.get_database()))
    else:
        totn = ds.nb // args.split[0]  # approximate
        i0 = 0
        for xblock in ds.database_iterator(bs=args.add_bs):
        print(f"Adding in block sizes {args.add_bs} with split {args.split}")
        for xblock in ds.database_iterator(bs=args.add_bs, split=args.split):
            i1 = i0 + len(xblock)
            print("  adding %d:%d / %d [%.3f s, RSS %d kiB] " % (
                i0, i1, ds.nb, time.time() - t0,
                i0, i1, totn, time.time() - t0,
                faiss.get_mem_usage_kb()))
            index.add(xblock)
            i0 = i1

    print("  add in %.3f s" % (time.time() - t0))
    res.t_add = time.time() - t0
    print(f"  add in {res.t_add:.3f} s index size {index.ntotal}")


######################################################
# Search
######################################################

def run_search(args, ds, index, res):

    index_ivf, vec_transform = unwind_index_ivf(index)

    if args.no_precomputed_tables:
        if isinstance(index_ivf, faiss.IndexIVFPQ):
            print("disabling precomputed table")
            index_ivf.use_precomputed_table = -1
            index_ivf.precomputed_table.clear()

    if args.indexfile:
        print("storing", args.indexfile)
        faiss.write_index(index, args.indexfile)
        print("index size on disk: ", os.stat(args.indexfile).st_size)

    if args.no_precomputed_tables:
        if isinstance(index_ivf, faiss.IndexIVFPQ):
            print("disabling precomputed table")
            index_ivf.use_precomputed_table = -1
            index_ivf.precomputed_table.clear()
    if hasattr(index, "code_size"):
        print("vector code_size", index.code_size)

    if args.indexfile:
        print("index size on disk: ", os.stat(args.indexfile).st_size)
    if hasattr(index_ivf, "code_size"):
        print("vector code_size (IVF)", index_ivf.code_size)

    if hasattr(index, "code_size"):
        print("vector code_size", index.code_size)
    print("current RSS:", faiss.get_mem_usage_kb() * 1024)

    if hasattr(index_ivf, "code_size"):
        print("vector code_size (IVF)", index_ivf.code_size)
    precomputed_table_size = 0
    if hasattr(index_ivf, 'precomputed_table'):
        precomputed_table_size = index_ivf.precomputed_table.size() * 4

    print("current RSS:", faiss.get_mem_usage_kb() * 1024)
    print("precomputed tables size:", precomputed_table_size)

    precomputed_table_size = 0
    if hasattr(index_ivf, 'precomputed_table'):
        precomputed_table_size = index_ivf.precomputed_table.size() * 4
    # Index is ready

    print("precomputed tables size:", precomputed_table_size)
    xq = sanitize(ds.get_queries())
    nq, d = xq.shape
    gt = ds.get_groundtruth(k=args.k)

    if not args.accept_short_gt:  # Deep1B has only a single NN per query
        assert gt.shape[1] == args.k

    #############################################################
    # Index is ready
    #############################################################

    xq = sanitize(ds.get_queries())
    gt = ds.get_groundtruth(k=args.k)
    assert gt.shape[1] == args.k

    if args.searchthreads != -1:
        print("Setting nb of threads to", args.searchthreads)
        faiss.omp_set_num_threads(args.searchthreads)
    else:
        print("nb search threads: ", faiss.omp_get_max_threads())

    ps = faiss.ParameterSpace()
    ps.initialize(index)

    parametersets = args.searchparams



    if args.inter:
        header = (
            '%-40s     inter@%3d time(ms/q)   nb distances #runs' %
            ("parameters", args.k)
        )
    else:

        header = (
            '%-40s     R@1    R@10   R@100  time(ms/q)   nb distances #runs' %
            "parameters"
        )

    def compute_inter(a, b):
        nq, rank = a.shape
        ninter = sum(
            np.intersect1d(a[i, :rank], b[i, :rank]).size
            for i in range(nq)
        )
        return ninter / a.size



    def eval_setting(index, xq, gt, k, inter, min_time):
        nq = xq.shape[0]
        ivf_stats = faiss.cvar.indexIVF_stats
        ivf_stats.reset()
        nrun = 0
        t0 = time.time()
        while True:
            D, I = index.search(xq, k)
            nrun += 1
            t1 = time.time()
            if t1 - t0 > min_time:
                break
        ms_per_query = ((t1 - t0) * 1000.0 / nq / nrun)
        if inter:
            rank = k
            inter_measure = compute_inter(gt[:, :rank], I[:, :rank])
            print("%.4f" % inter_measure, end=' ')
    if args.searchthreads != -1:
        print("Setting nb of threads to", args.searchthreads)
        faiss.omp_set_num_threads(args.searchthreads)
        else:
            for rank in 1, 10, 100:
                n_ok = (I[:, :rank] == gt[:, :1]).sum()
                print("%.4f" % (n_ok / float(nq)), end=' ')
        print("   %9.5f  " % ms_per_query, end=' ')
        print("%12d   " % (ivf_stats.ndis / nrun), end=' ')
        print(nrun)
    else:
        print("nb search threads: ", faiss.omp_get_max_threads())

    ps = faiss.ParameterSpace()
    ps.initialize(index)

    if parametersets == ['autotune']:
    parametersets = args.searchparams

        ps.n_experiments = args.n_autotune
        ps.min_test_duration = args.min_test_duration

        for kv in args.autotune_max:
            k, vmax = kv.split(':')
            vmax = float(vmax)
            print("limiting %s to %g" % (k, vmax))
            pr = ps.add_range(k)
            values = faiss.vector_to_array(pr.values)
            values = np.array([v for v in values if v < vmax])
            faiss.copy_array_to_vector(values, pr.values)

        for kv in args.autotune_range:
            k, vals = kv.split(':')
            vals = np.fromstring(vals, sep=',')
            print("setting %s to %s" % (k, vals))
            pr = ps.add_range(k)
            faiss.copy_array_to_vector(vals, pr.values)

        # setup the Criterion object
        if args.inter:
            print("Optimize for intersection @ ", args.k)
            crit = faiss.IntersectionCriterion(nq, args.k)
    header = (
        '%-40s     inter@%3d time(ms/q)   nb distances #runs' %
        ("parameters", args.k)
    )
        else:
            print("Optimize for 1-recall @ 1")
            crit = faiss.OneRecallAtRCriterion(nq, 1)

        # by default, the criterion will request only 1 NN
        crit.nnn = args.k
        crit.set_groundtruth(None, gt.astype('int64'))
    header = (
        '%-40s     R@1    R@10   R@100  time(ms/q)   nb distances #runs' %
        "parameters"
    )

        # then we let Faiss find the optimal parameters by itself
        print("exploring operating points, %d threads" % faiss.omp_get_max_threads())
        ps.display()

        t0 = time.time()
        op = ps.explore(index, xq, crit)
        print("Done in %.3f s, available OPs:" % (time.time() - t0))
    res.search_results = {}
    if parametersets == ['autotune']:

        op.display()
        ps.n_experiments = args.n_autotune
        ps.min_test_duration = args.min_test_duration

        print("Re-running evaluation on selected OPs")
        print(header)
        opv = op.optimal_pts
        maxw = max(max(len(opv.at(i).key) for i in range(opv.size())), 40)
        for i in range(opv.size()):
            opt = opv.at(i)
        for kv in args.autotune_max:
            k, vmax = kv.split(':')
            vmax = float(vmax)
            print("limiting %s to %g" % (k, vmax))
            pr = ps.add_range(k)
            values = faiss.vector_to_array(pr.values)
            values = np.array([v for v in values if v < vmax])
            faiss.copy_array_to_vector(values, pr.values)

            ps.set_index_parameters(index, opt.key)
        for kv in args.autotune_range:
            k, vals = kv.split(':')
            vals = np.fromstring(vals, sep=',')
            print("setting %s to %s" % (k, vals))
            pr = ps.add_range(k)
            faiss.copy_array_to_vector(vals, pr.values)

            print(opt.key.ljust(maxw), end=' ')
            sys.stdout.flush()
        # setup the Criterion object
        if args.inter:
            print("Optimize for intersection @ ", args.k)
            crit = faiss.IntersectionCriterion(nq, args.k)
        else:
            print("Optimize for 1-recall @ 1")
            crit = faiss.OneRecallAtRCriterion(nq, 1)

            eval_setting(index, xq, gt, args.k, args.inter, args.min_test_duration)
        # by default, the criterion will request only 1 NN
        crit.nnn = args.k
        crit.set_groundtruth(None, gt.astype('int64'))

    else:
        print(header)
        for param in parametersets:
            print("%-40s " % param, end=' ')
            sys.stdout.flush()
            ps.set_index_parameters(index, param)
        # then we let Faiss find the optimal parameters by itself
        print("exploring operating points, %d threads" % faiss.omp_get_max_threads())
        ps.display()

            eval_setting(index, xq, gt, args.k, args.inter, args.min_test_duration)
        t0 = time.time()
        op = ps.explore(index, xq, crit)
        res.t_explore = time.time() - t0
        print("Done in %.3f s, available OPs:" % res.t_explore)

        op.display()

        print("Re-running evaluation on selected OPs")
        print(header)
        opv = op.optimal_pts
        maxw = max(max(len(opv.at(i).key) for i in range(opv.size())), 40)
        for i in range(opv.size()):
            opt = opv.at(i)

            ps.set_index_parameters(index, opt.key)

            print(opt.key.ljust(maxw), end=' ')
            sys.stdout.flush()

            res_i = eval_setting(index, xq, gt, args.k, args.inter, args.min_test_duration)
            res.search_results[opt.key] = res_i

    else:
        print(header)
        for param in parametersets:
            print("%-40s " % param, end=' ')
            sys.stdout.flush()
            ps.set_index_parameters(index, param)

            res_i = eval_setting(index, xq, gt, args.k, args.inter, args.min_test_duration)
            res.search_results[param] = res_i



######################################################
# Driver function
######################################################

def main():

    parser = argparse.ArgumentParser()

    def aa(*args, **kwargs):
        group.add_argument(*args, **kwargs)

    group = parser.add_argument_group('general options')
    aa('--nthreads', default=-1, type=int,
       help='nb of threads to use at train and add time')
    aa('--json', default=False, action="store_true",
       help="output stats in JSON format at the end")
    aa('--todo', default=["check_files"],
       choices=["train", "add", "search", "check_files"],
       nargs="+", help='what to do (check_files means decide depending on which index files exist)')

    group = parser.add_argument_group('dataset options')
    aa('--db', default='deep1M', help='dataset')
    aa('--compute_gt', default=False, action='store_true',
       help='compute and store the groundtruth')
    aa('--force_IP', default=False, action="store_true",
       help='force IP search instead of L2')
    aa('--accept_short_gt', default=False, action='store_true',
       help='work around a problem with Deep1B GT')

    group = parser.add_argument_group('index construction')
    aa('--indexkey', default='HNSW32', help='index_factory type')
    aa('--trained_indexfile', default='',
       help='file to read or write a trained index from')
    aa('--maxtrain', default=256 * 256, type=int,
       help='maximum number of training points (0 to set automatically)')
    aa('--indexfile', default='', help='file to read or write index from')
    aa('--split', default=[1, 0], type=int, nargs=2, help="database split")
    aa('--add_bs', default=-1, type=int,
       help='add elements index by batches of this size')

    group = parser.add_argument_group('IVF options')
    aa('--by_residual', default=-1, type=int,
       help="set if index should use residuals (default=unchanged)")
    aa('--no_precomputed_tables', action='store_true', default=False,
       help='disable precomputed tables (uses less memory)')
    aa('--get_centroids_from', default='',
       help='get the centroids from this index (to speed up training)')
    aa('--clustering_niter', default=-1, type=int,
       help='number of clustering iterations (-1 = leave default)')
    aa('--train_on_gpu', default=False, action='store_true',
       help='do training on GPU')

    group = parser.add_argument_group('index-specific options')
    aa('--M0', default=-1, type=int, help='size of base level for HNSW')
    aa('--RQ_train_default', default=False, action="store_true",
       help='disable progressive dim training for RQ')
    aa('--RQ_beam_size', default=-1, type=int,
       help='set beam size at add time')
    aa('--LSQ_encode_ils_iters', default=-1, type=int,
       help='ILS iterations for LSQ')
    aa('--RQ_use_beam_LUT', default=-1, type=int,
       help='use beam LUT at add time')

    group = parser.add_argument_group('searching')
    aa('--k', default=100, type=int, help='nb of nearest neighbors')
    aa('--inter', default=False, action='store_true',
       help='use intersection measure instead of 1-recall as metric')
    aa('--searchthreads', default=-1, type=int,
       help='nb of threads to use at search time')
    aa('--searchparams', nargs='+', default=['autotune'],
       help="search parameters to use (can be autotune or a list of params)")
    aa('--n_autotune', default=500, type=int,
       help="max nb of autotune experiments")
    aa('--autotune_max', default=[], nargs='*',
       help='set max value for autotune variables format "var:val" (exclusive)')
    aa('--autotune_range', default=[], nargs='*',
       help='set complete autotune range, format "var:val1,val2,..."')
    aa('--min_test_duration', default=3.0, type=float,
       help='run test at least for so long to avoid jitter')
    aa('--indexes_to_merge', default=[], nargs="*",
       help="load these indexes to search and merge them before searching")

    args = parser.parse_args()

    if args.todo == ["check_files"]:
        if os.path.exists(args.indexfile):
            args.todo = ["search"]
        elif os.path.exists(args.trained_indexfile):
            args.todo = ["add", "search"]
        else:
            args.todo = ["train", "add", "search"]
        print("setting todo to", args.todo)

    print("args:", args)

    os.system('echo -n "nb processors "; '
              'cat /proc/cpuinfo | grep ^processor | wc -l; '
              'cat /proc/cpuinfo | grep ^"model name" | tail -1')

    # object to collect results
    res = argparse.Namespace()
    res.args = args.__dict__

    res.cpu_model = [
        l for l in open("/proc/cpuinfo", "r")
        if "model name" in l][0]

    print("Load dataset")

    ds = datasets.load_dataset(
        dataset=args.db, compute_gt=args.compute_gt)

    if args.force_IP:
        ds.metric = "IP"

    print(ds)

    if args.nthreads != -1:
        print("Set nb of threads to", args.nthreads)
        faiss.omp_set_num_threads(args.nthreads)
    else:
        print("nb threads: ", faiss.omp_get_max_threads())

    index = None
    if "train" in args.todo:
        print("================== Training index")
        index = run_train(args, ds, res)
        if args.trained_indexfile:
            print("storing trained index", args.trained_indexfile)
            faiss.write_index(index, args.trained_indexfile)

    if "add" in args.todo:
        if not index:
            assert args.trained_indexfile
            print("reading trained index", args.trained_indexfile)
            index = faiss.read_index(args.trained_indexfile)

        print("================== Adding vectors to index")
        run_add(args, ds, index, res)
        if args.indexfile:
            print("storing", args.indexfile)
            faiss.write_index(index, args.indexfile)

    if "search" in args.todo:
        if not index:
            if args.indexfile:
                print("reading index", args.indexfile)
                index = faiss.read_index(args.indexfile)
            elif args.indexes_to_merge:
                print(f"Merging {len(args.indexes_to_merge)} indexes")
                sz = 0
                for fname in args.indexes_to_merge:
                    print(f"    reading {fname} (current size {sz})")
                    index_i = faiss.read_index(fname)
                    if index is None:
                        index = index_i
                    else:
                        index.merge_from(index_i, index.ntotal)
                    sz = index.ntotal
            else:
                assert False, "provide --indexfile"

        print("================== Searching")
        run_search(args, ds, index, res)

    if args.json:
        print("JSON results:", json.dumps(res.__dict__))


if __name__ == "__main__":
    main()
@@ -530,14 +530,7 @@ def main():
        raise RuntimeError()

    totex = op.num_experiments()
    rs = np.random.RandomState(123)
    if totex < args.n_autotune:
        experiments = rs.permutation(totex - 2) + 1
    else:
        experiments = rs.randint(
            totex - 2, size=args.n_autotune - 2, replace=False)

    experiments = [0, totex - 1] + list(experiments)
    experiments = op.sample_experiments()
    print(f"total nb experiments {totex}, running {len(experiments)}")

    print("perform search")
@@ -380,7 +380,23 @@ class OperatingPointsWithRanges(OperatingPoints):
        return np.zeros(len(self.ranges), dtype=int)

    def num_experiments(self):
        return np.prod([len(values) for name, values in self.ranges])
        return int(np.prod([len(values) for name, values in self.ranges]))

    def sample_experiments(self, n_autotune, rs=np.random):
        """ sample a set of experiments of max size n_autotune
        (run all experiments in random order if n_autotune is 0)
        """
        assert n_autotune == 0 or n_autotune >= 2
        totex = self.num_experiments()
        rs = np.random.RandomState(123)
        if n_autotune == 0 or totex < n_autotune:
            experiments = rs.permutation(totex - 2)
        else:
            experiments = rs.choice(
                totex - 2, size=n_autotune - 2, replace=False)

        experiments = [0, totex - 1] + [int(cno) + 1 for cno in experiments]
        return experiments

    def cno_to_key(self, cno):
        """Convert a sequential experiment number to a key"""
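A short sketch of how this sampling is meant to be used. The parameter grid is hypothetical, and the import location of `OperatingPointsWithRanges` (faiss.contrib.evaluation) is an assumption:

```
# Hypothetical grid: 4 nprobe values x 3 efSearch values = 12 experiments.
# sample_experiments always keeps the first (cheapest) and last (most
# expensive) experiment and draws the rest at random from the interior.
from faiss.contrib.evaluation import OperatingPointsWithRanges

op = OperatingPointsWithRanges()
op.add_range("nprobe", [1, 4, 16, 64])
op.add_range("efSearch", [16, 32, 64])

print(op.num_experiments())               # 12
experiments = op.sample_experiments(n_autotune=8)
print(experiments)                        # 8 experiment numbers, 0 and 11 included
keys = [op.cno_to_key(cno) for cno in experiments]
```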
@@ -977,14 +977,12 @@ void IndexIVF::search_and_reconstruct(
            std::min(nlist, params ? params->nprobe : this->nprobe);
    FAISS_THROW_IF_NOT(nprobe > 0);

    idx_t* idx = new idx_t[n * nprobe];
    ScopeDeleter<idx_t> del(idx);
    float* coarse_dis = new float[n * nprobe];
    ScopeDeleter<float> del2(coarse_dis);
    std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
    std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);

    quantizer->search(n, x, nprobe, coarse_dis, idx);
    quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());

    invlists->prefetch_lists(idx, n * nprobe);
    invlists->prefetch_lists(idx.get(), n * nprobe);

    // search_preassigned() with `store_pairs` enabled to obtain the list_no
    // and offset into `codes` for reconstruction
@@ -992,29 +990,94 @@
            n,
            x,
            k,
            idx,
            coarse_dis,
            idx.get(),
            coarse_dis.get(),
            distances,
            labels,
            true /* store_pairs */,
            params);
    for (idx_t i = 0; i < n; ++i) {
        for (idx_t j = 0; j < k; ++j) {
            idx_t ij = i * k + j;
            idx_t key = labels[ij];
            float* reconstructed = recons + ij * d;
            if (key < 0) {
                // Fill with NaNs
                memset(reconstructed, -1, sizeof(*reconstructed) * d);
            } else {
                int list_no = lo_listno(key);
                int offset = lo_offset(key);
#pragma omp parallel for if (n * k > 1000)
    for (idx_t ij = 0; ij < n * k; ij++) {
        idx_t key = labels[ij];
        float* reconstructed = recons + ij * d;
        if (key < 0) {
            // Fill with NaNs
            memset(reconstructed, -1, sizeof(*reconstructed) * d);
        } else {
            int list_no = lo_listno(key);
            int offset = lo_offset(key);

            // Update label to the actual id
            labels[ij] = invlists->get_single_id(list_no, offset);
            // Update label to the actual id
            labels[ij] = invlists->get_single_id(list_no, offset);

            reconstruct_from_offset(list_no, offset, reconstructed);
            reconstruct_from_offset(list_no, offset, reconstructed);
        }
    }
}

void IndexIVF::search_and_return_codes(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        uint8_t* codes,
        bool include_listno,
        const SearchParameters* params_in) const {
    const IVFSearchParameters* params = nullptr;
    if (params_in) {
        params = dynamic_cast<const IVFSearchParameters*>(params_in);
        FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
    }
    const size_t nprobe =
            std::min(nlist, params ? params->nprobe : this->nprobe);
    FAISS_THROW_IF_NOT(nprobe > 0);

    std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
    std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);

    quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());

    invlists->prefetch_lists(idx.get(), n * nprobe);

    // search_preassigned() with `store_pairs` enabled to obtain the list_no
    // and offset into `codes` for reconstruction
    search_preassigned(
            n,
            x,
            k,
            idx.get(),
            coarse_dis.get(),
            distances,
            labels,
            true /* store_pairs */,
            params);

    size_t code_size_1 = code_size;
    if (include_listno) {
        code_size_1 += coarse_code_size();
    }

#pragma omp parallel for if (n * k > 1000)
    for (idx_t ij = 0; ij < n * k; ij++) {
        idx_t key = labels[ij];
        uint8_t* code1 = codes + ij * code_size_1;

        if (key < 0) {
            // Fill with 0xff
            memset(code1, -1, code_size_1);
        } else {
            int list_no = lo_listno(key);
            int offset = lo_offset(key);
            const uint8_t* cc = invlists->get_single_code(list_no, offset);

            labels[ij] = invlists->get_single_id(list_no, offset);

            if (include_listno) {
                encode_listno(list_no, code1);
                code1 += code_size_1 - code_size;
            }
            memcpy(code1, cc, code_size);
        }
    }
}
@@ -357,6 +357,24 @@ struct IndexIVF : Index, IndexIVFInterface {
            float* recons,
            const SearchParameters* params = nullptr) const override;

    /** Similar to search, but also returns the codes corresponding to the
     * stored vectors for the search results.
     *
     * @param codes      codes (n, k, code_size)
     * @param include_listno
     *                   include the list ids in the code (in this case add
     *                   ceil(log8(nlist)) to the code size)
     */
    void search_and_return_codes(
            idx_t n,
            const float* x,
            idx_t k,
            float* distances,
            idx_t* labels,
            uint8_t* recons,
            bool include_listno = false,
            const SearchParameters* params = nullptr) const;

    /** Reconstruct a vector given the location in terms of (inv list index +
     * inv list offset) instead of the id.
     *
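To make the buffer contract above concrete, here is a small sketch of how a caller would size the `codes` output, mirroring the Python wrapper later in this diff (index type and sizes are arbitrary):

```
# Each of the n * k results gets one fixed-size code; when include_listno is
# set, the coarse code (the encoded inverted list id) is prepended, so the
# per-result size grows by coarse_code_size() bytes.
import faiss
import numpy as np

d = 16
index = faiss.index_factory(d, "IVF10,PQ4x4np")
index.train(np.random.rand(500, d).astype('float32'))

include_listno = True
n, k = 100, 10                                   # hypothetical batch sizes
code_size_1 = index.code_size
if include_listno:
    code_size_1 += index.coarse_code_size()      # bytes needed for a list id
codes = np.empty((n, k, code_size_1), dtype=np.uint8)  # buffer for the C++ call
```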
@@ -149,6 +149,7 @@ struct AQInvertedListScanner : InvertedListScanner {
    const float* q;
    /// following codes come from this inverted list
    void set_list(idx_t list_no, float coarse_dis) override {
        this->list_no = list_no;
        if (ia.metric_type == METRIC_L2 && ia.by_residual) {
            ia.quantizer->compute_residual(q0, tmp.data(), list_no);
            q = tmp.data();
@@ -261,7 +261,7 @@ void AdditiveQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
            is_trained, "The additive quantizer is not trained yet.");

    // standard additive quantizer decoding
#pragma omp parallel for if (n > 1000)
#pragma omp parallel for if (n > 100)
    for (int64_t i = 0; i < n; i++) {
        BitstringReader bsr(code + i * code_size, code_size);
        float* xi = x + i * d;
@@ -306,7 +306,8 @@ void ProductQuantizer::decode(const uint8_t* code, float* x) const {
}

void ProductQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
    for (size_t i = 0; i < n; i++) {
#pragma omp parallel for if (n > 100)
    for (int64_t i = 0; i < n; i++) {
        this->decode(code + code_size * i, x + d * i);
    }
}
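The hunk above only moves the decode loop under OpenMP; the Python-visible behavior is unchanged. A small sketch (arbitrary sizes) of the batch decode path that now runs in parallel for n > 100:

```
import faiss
import numpy as np

d, M, nbit = 32, 4, 8
pq = faiss.ProductQuantizer(d, M, nbit)
pq.train(np.random.rand(2000, d).astype('float32'))

codes = pq.compute_codes(np.random.rand(500, d).astype('float32'))
x = pq.decode(codes)        # n = 500 > 100: decoded via the parallel loop
print(x.shape)              # (500, 32)
```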
@@ -22,7 +22,8 @@ from faiss.array_conversions import *
from faiss.extra_wrappers import kmin, kmax, pairwise_distances, rand, randint, \
    lrand, randn, rand_smooth_vectors, eval_intersection, normalize_L2, \
    ResultHeap, knn, Kmeans, checksum, matrix_bucket_sort_inplace, bucket_sort, \
    merge_knn_results, MapInt64ToInt64, knn_hamming
    merge_knn_results, MapInt64ToInt64, knn_hamming, \
    pack_bitstrings, unpack_bitstrings


__version__ = "%d.%d.%d" % (FAISS_VERSION_MAJOR,
@@ -402,6 +402,74 @@ def handle_Index(the_class):
        )
        return D, I, R

    def replacement_search_and_return_codes(
            self, x, k, *,
            include_listnos=False, params=None, D=None, I=None, codes=None):
        """Find the k nearest neighbors of the set of vectors x in the index,
        and return the codes stored for these vectors.

        Parameters
        ----------
        x : array_like
            Query vectors, shape (n, d) where d is appropriate for the index.
            `dtype` must be float32.
        k : int
            Number of nearest neighbors.
        params : SearchParameters
            Search parameters of the current search (overrides the class-level params)
        include_listnos : bool, optional
            whether to include the list ids in the first bytes of each code
        D : array_like, optional
            Distance array to store the result.
        I : array_like, optional
            Labels array to store the result.
        codes : array_like, optional
            Codes array to store the result.

        Returns
        -------
        D : array_like
            Distances of the nearest neighbors, shape (n, k). When not enough results are found
            the label is set to +Inf or -Inf.
        I : array_like
            Labels of the nearest neighbors, shape (n, k). When not enough results are found,
            the label is set to -1
        codes : array_like
            Codes of the nearest neighbors, shape (n, k, code_size).
        """
        n, d = x.shape
        assert d == self.d
        x = np.ascontiguousarray(x, dtype='float32')

        assert k > 0

        if D is None:
            D = np.empty((n, k), dtype=np.float32)
        else:
            assert D.shape == (n, k)

        if I is None:
            I = np.empty((n, k), dtype=np.int64)
        else:
            assert I.shape == (n, k)

        code_size_1 = self.code_size
        if include_listnos:
            code_size_1 += self.coarse_code_size()

        if codes is None:
            codes = np.empty((n, k, code_size_1), dtype=np.uint8)
        else:
            assert codes.shape == (n, k, code_size_1)

        self.search_and_return_codes_c(
            n, swig_ptr(x),
            k, swig_ptr(D),
            swig_ptr(I), swig_ptr(codes), include_listnos,
            params
        )
        return D, I, codes

    def replacement_remove_ids(self, x):
        """Remove some ids from the index.
        This is a O(ntotal) operation by default, so could be expensive.
@@ -734,6 +802,8 @@ def handle_Index(the_class):
                   ignore_missing=True)
    replace_method(the_class, 'search_and_reconstruct',
                   replacement_search_and_reconstruct, ignore_missing=True)
    replace_method(the_class, 'search_and_return_codes',
                   replacement_search_and_return_codes, ignore_missing=True)

    # these ones are IVF-specific
    replace_method(the_class, 'search_preassigned',
@@ -14,6 +14,9 @@ from faiss.loader import *

import faiss

import collections.abc


###########################################
# Wrapper for a few functions
###########################################
@@ -579,3 +582,72 @@ class Kmeans:
        self.index.add(self.centroids)
        D, I = self.index.search(x, 1)
        return D.ravel(), I.ravel()


###########################################
# Packing and unpacking bitstrings
###########################################

def is_sequence(x):
    return isinstance(x, collections.abc.Sequence)

pack_bitstrings_c = pack_bitstrings

def pack_bitstrings(a, nbit):
    """
    Pack a set of integers (i, j) where i=0:n and j=0:M into
    n bitstrings.
    Output is an uint8 array of size (n, code_size), where code_size is
    such that at most 7 bits per code are wasted.

    If nbit is an integer: all entries take nbit bits.
    If nbit is an array: entry (i, j) takes nbit[j] bits.
    """
    n, M = a.shape
    a = np.ascontiguousarray(a, dtype='int32')
    if is_sequence(nbit):
        nbit = np.ascontiguousarray(nbit, dtype='int32')
        assert nbit.shape == (M,)
        code_size = int((nbit.sum() + 7) // 8)
        b = np.empty((n, code_size), dtype='uint8')
        pack_bitstrings_c(
            n, M, swig_ptr(nbit), swig_ptr(a), swig_ptr(b), code_size)
    else:
        code_size = (M * nbit + 7) // 8
        b = np.empty((n, code_size), dtype='uint8')
        pack_bitstrings_c(n, M, nbit, swig_ptr(a), swig_ptr(b), code_size)
    return b

unpack_bitstrings_c = unpack_bitstrings

def unpack_bitstrings(b, M_or_nbits, nbit=None):
    """
    Unpack a set of integers (i, j) where i=0:n and j=0:M from
    n bitstrings (encoded as uint8s).
    Input is an uint8 array of size (n, code_size), where code_size is
    such that at most 7 bits per code are wasted.

    Two forms:
    - when called with (array, M, nbit): there are M entries of size
      nbit per row
    - when called with (array, nbits): element (i, j) is encoded in
      nbits[j] bits
    """
    n, code_size = b.shape
    if nbit is None:
        nbit = np.ascontiguousarray(M_or_nbits, dtype='int32')
        M = len(nbit)
        min_code_size = int((nbit.sum() + 7) // 8)
        assert code_size >= min_code_size
        a = np.empty((n, M), dtype='int32')
        unpack_bitstrings_c(
            n, M, swig_ptr(nbit),
            swig_ptr(b), code_size, swig_ptr(a))
    else:
        M = M_or_nbits
        min_code_size = (M * nbit + 7) // 8
        assert code_size >= min_code_size
        a = np.empty((n, M), dtype='int32')
        unpack_bitstrings_c(
            n, M, nbit, swig_ptr(b), code_size, swig_ptr(a))
    return a
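A worked example of the code_size arithmetic in these wrappers, with values chosen to match the unit test added below: 10 entries of 5 bits each need 50 bits, i.e. ceil(50/8) = 7 bytes per bitstring, so 6 of the 56 bits are wasted.

```
import faiss
import numpy as np

rs = np.random.RandomState(0)
a = rs.randint(1 << 5, size=(4, 10)).astype('int32')   # 5-bit values, n=4, M=10
b = faiss.pack_bitstrings(a, 5)
print(b.shape)                                         # (4, 7): ceil(10 * 5 / 8) = 7
c = faiss.unpack_bitstrings(b, 10, 5)                  # fixed-size form: (array, M, nbit)
np.testing.assert_array_equal(a, c)
```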
@@ -681,4 +681,88 @@ void generalized_hammings_knn_hc(
    ha->reorder();
}

void pack_bitstrings(
        size_t n,
        size_t M,
        int nbit,
        const int32_t* unpacked,
        uint8_t* packed,
        size_t code_size) {
    FAISS_THROW_IF_NOT(code_size >= (M * nbit + 7) / 8);
#pragma omp parallel for if (n > 1000)
    for (int64_t i = 0; i < n; i++) {
        const int32_t* in = unpacked + i * M;
        uint8_t* out = packed + i * code_size;
        BitstringWriter wr(out, code_size);
        for (int j = 0; j < M; j++) {
            wr.write(in[j], nbit);
        }
    }
}

void pack_bitstrings(
        size_t n,
        size_t M,
        const int32_t* nbit,
        const int32_t* unpacked,
        uint8_t* packed,
        size_t code_size) {
    int totbit = 0;
    for (int j = 0; j < M; j++) {
        totbit += nbit[j];
    }
    FAISS_THROW_IF_NOT(code_size >= (totbit + 7) / 8);
#pragma omp parallel for if (n > 1000)
    for (int64_t i = 0; i < n; i++) {
        const int32_t* in = unpacked + i * M;
        uint8_t* out = packed + i * code_size;
        BitstringWriter wr(out, code_size);
        for (int j = 0; j < M; j++) {
            wr.write(in[j], nbit[j]);
        }
    }
}

void unpack_bitstrings(
        size_t n,
        size_t M,
        int nbit,
        const uint8_t* packed,
        size_t code_size,
        int32_t* unpacked) {
    FAISS_THROW_IF_NOT(code_size >= (M * nbit + 7) / 8);
#pragma omp parallel for if (n > 1000)
    for (int64_t i = 0; i < n; i++) {
        const uint8_t* in = packed + i * code_size;
        int32_t* out = unpacked + i * M;
        BitstringReader rd(in, code_size);
        for (int j = 0; j < M; j++) {
            out[j] = rd.read(nbit);
        }
    }
}

void unpack_bitstrings(
        size_t n,
        size_t M,
        const int32_t* nbit,
        const uint8_t* packed,
        size_t code_size,
        int32_t* unpacked) {
    int totbit = 0;
    for (int j = 0; j < M; j++) {
        totbit += nbit[j];
    }
    FAISS_THROW_IF_NOT(code_size >= (totbit + 7) / 8);
#pragma omp parallel for if (n > 1000)
    for (int64_t i = 0; i < n; i++) {
        const uint8_t* in = packed + i * code_size;
        int32_t* out = unpacked + i * M;
        BitstringReader rd(in, code_size);
        for (int j = 0; j < M; j++) {
            out[j] = rd.read(nbit[j]);
        }
    }
}

} // namespace faiss
@@ -222,6 +222,64 @@
        size_t code_size,
        int ordered = true);

/** Pack a set of n codes of size M * nbit
 *
 * @param n          number of codes to pack
 * @param M          number of elementary codes per code
 * @param nbit       number of bits per elementary code
 * @param unpacked   input unpacked codes, size (n, M)
 * @param packed     output packed codes, size (n, code_size)
 * @param code_size  should be >= ceil(M * nbit / 8)
 */
void pack_bitstrings(
        size_t n,
        size_t M,
        int nbit,
        const int32_t* unpacked,
        uint8_t* packed,
        size_t code_size);

/** Pack a set of n codes of variable sizes
 *
 * @param nbit  number of bits per entry (size M)
 */
void pack_bitstrings(
        size_t n,
        size_t M,
        const int32_t* nbits,
        const int32_t* unpacked,
        uint8_t* packed,
        size_t code_size);

/** Unpack a set of n codes of size M * nbit
 *
 * @param n          number of codes to unpack
 * @param M          number of elementary codes per code
 * @param nbit       number of bits per elementary code
 * @param packed     input packed codes, size (n, code_size)
 * @param unpacked   output unpacked codes, size (n, M)
 * @param code_size  should be >= ceil(M * nbit / 8)
 */
void unpack_bitstrings(
        size_t n,
        size_t M,
        int nbit,
        const uint8_t* packed,
        size_t code_size,
        int32_t* unpacked);

/** Unpack a set of n codes of variable sizes
 *
 * @param nbit  number of bits per entry (size M)
 */
void unpack_bitstrings(
        size_t n,
        size_t M,
        const int32_t* nbits,
        const uint8_t* packed,
        size_t code_size,
        int32_t* unpacked);

} // namespace faiss

#include <faiss/utils/hamming-inl.h>
@@ -14,7 +14,7 @@ import shutil
import tempfile
import platform

from common_faiss_tests import get_dataset_2
from common_faiss_tests import get_dataset_2, get_dataset
from faiss.contrib.datasets import SyntheticDataset
from faiss.contrib.inspect_tools import make_LinearTransform_matrix
from faiss.contrib.evaluation import check_ref_knn_with_draws
@@ -822,3 +822,158 @@ class TestIndependentQuantizer(unittest.TestCase):

        np.testing.assert_array_equal(Dnew, D2)
        np.testing.assert_array_equal(Inew, I2)


class TestSearchAndReconstruct(unittest.TestCase):

    def run_search_and_reconstruct(self, index, xb, xq, k=10, eps=None):
        n, d = xb.shape
        assert xq.shape[1] == d
        assert index.d == d

        D_ref, I_ref = index.search(xq, k)
        R_ref = index.reconstruct_n(0, n)
        D, I, R = index.search_and_reconstruct(xq, k)

        np.testing.assert_almost_equal(D, D_ref, decimal=5)
        self.assertTrue((I == I_ref).all())
        self.assertEqual(R.shape[:2], I.shape)
        self.assertEqual(R.shape[2], d)

        # (n, k, ..) -> (n * k, ..)
        I_flat = I.reshape(-1)
        R_flat = R.reshape(-1, d)
        # Filter out -1s when not enough results
        R_flat = R_flat[I_flat >= 0]
        I_flat = I_flat[I_flat >= 0]

        recons_ref_err = np.mean(np.linalg.norm(R_flat - R_ref[I_flat]))
        self.assertLessEqual(recons_ref_err, 1e-6)

        def norm1(x):
            return np.sqrt((x ** 2).sum(axis=1))

        recons_err = np.mean(norm1(R_flat - xb[I_flat]))

        print('Reconstruction error = %.3f' % recons_err)
        if eps is not None:
            self.assertLessEqual(recons_err, eps)

        return D, I, R

    def test_IndexFlat(self):
        d = 32
        nb = 1000
        nt = 1500
        nq = 200

        (xt, xb, xq) = get_dataset(d, nb, nt, nq)

        index = faiss.IndexFlatL2(d)
        index.add(xb)

        self.run_search_and_reconstruct(index, xb, xq, eps=0.0)

    def test_IndexIVFFlat(self):
        d = 32
        nb = 1000
        nt = 1500
        nq = 200

        (xt, xb, xq) = get_dataset(d, nb, nt, nq)

        quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFFlat(quantizer, d, 32, faiss.METRIC_L2)
        index.cp.min_points_per_centroid = 5    # quiet warning
        index.nprobe = 4
        index.train(xt)
        index.add(xb)

        self.run_search_and_reconstruct(index, xb, xq, eps=0.0)

    def test_IndexIVFPQ(self):
        d = 32
        nb = 1000
        nt = 1500
        nq = 200

        (xt, xb, xq) = get_dataset(d, nb, nt, nq)

        quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFPQ(quantizer, d, 32, 8, 8)
        index.cp.min_points_per_centroid = 5    # quiet warning
        index.nprobe = 4
        index.train(xt)
        index.add(xb)

        self.run_search_and_reconstruct(index, xb, xq, eps=1.0)

    def test_MultiIndex(self):
        d = 32
        nb = 1000
        nt = 1500
        nq = 200

        (xt, xb, xq) = get_dataset(d, nb, nt, nq)

        index = faiss.index_factory(d, "IMI2x5,PQ8np")
        faiss.ParameterSpace().set_index_parameter(index, "nprobe", 4)
        index.train(xt)
        index.add(xb)

        self.run_search_and_reconstruct(index, xb, xq, eps=1.0)

    def test_IndexTransform(self):
        d = 32
        nb = 1000
        nt = 1500
        nq = 200

        (xt, xb, xq) = get_dataset(d, nb, nt, nq)

        index = faiss.index_factory(d, "L2norm,PCA8,IVF32,PQ8np")
        faiss.ParameterSpace().set_index_parameter(index, "nprobe", 4)
        index.train(xt)
        index.add(xb)

        self.run_search_and_reconstruct(index, xb, xq)


class TestSearchAndGetCodes(unittest.TestCase):

    def do_test(self, factory_string):
        ds = SyntheticDataset(32, 1000, 100, 10)

        index = faiss.index_factory(ds.d, factory_string)

        index.train(ds.get_train())
        index.add(ds.get_database())

        index.nprobe
        index.nprobe = 10
        Dref, Iref = index.search(ds.get_queries(), 10)

        # print(index.search_and_return_codes)
        D, I, codes = index.search_and_return_codes(
            ds.get_queries(), 10, include_listnos=True)

        np.testing.assert_array_equal(I, Iref)
        np.testing.assert_array_equal(D, Dref)

        # verify that we get the same distances when decompressing from
        # returned codes (the codes are compatible with sa_decode)
        for qi in range(ds.nq):
            q = ds.get_queries()[qi]
            xbi = index.sa_decode(codes[qi])
            D2 = ((q - xbi) ** 2).sum(1)
            np.testing.assert_allclose(D2, D[qi], rtol=1e-5)

    def test_ivfpq(self):
        self.do_test("IVF20,PQ4x4np")

    def test_ivfsq(self):
        self.do_test("IVF20,SQ8")

    def test_ivfrq(self):
        self.do_test("IVF20,RQ3x4")
@@ -266,9 +266,9 @@ class LatticeTest(unittest.TestCase):


class TestBitstring(unittest.TestCase):
    """ Low-level bit string tests """

    def test_rw(self):
        """ Low-level bit string tests """
        rs = np.random.RandomState(1234)
        nbyte = 1000
        sz = 0
@@ -311,6 +311,26 @@ class TestBitstring(unittest.TestCase):
            # print('nbit %d xref %x xnew %x' % (nbit, xref, xnew))
            self.assertTrue(xnew == xref)

    def test_arrays(self):
        nbit = 5
        M = 10
        n = 20
        rs = np.random.RandomState(123)
        a = rs.randint(1 << nbit, size=(n, M), dtype='int32')
        b = faiss.pack_bitstrings(a, nbit)
        c = faiss.unpack_bitstrings(b, M, nbit)
        np.testing.assert_array_equal(a, c)

    def test_arrays_variable_size(self):
        nbits = [10, 5, 3, 12, 6, 7, 4]
        n = 20
        rs = np.random.RandomState(123)
        a = rs.randint(1 << 16, size=(n, len(nbits)), dtype='int32')
        a &= (1 << np.array(nbits)) - 1
        b = faiss.pack_bitstrings(a, nbits)
        c = faiss.unpack_bitstrings(b, nbits)
        np.testing.assert_array_equal(a, c)


class TestIVFTransfer(unittest.TestCase):