# faiss/benchs/link_and_code/bench_link_and_code.py
#
# 301 lines
# 9.1 KiB
# Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import print_function
import os
import sys
import time
import numpy as np
import faiss
import argparse
import datasets
from datasets import sanitize
import neighbor_codec
######################################################
# Command-line parsing
######################################################

parser = argparse.ArgumentParser()


def aa(*args, **kwargs):
    """Add an option to the argparse group currently bound to ``group``.

    ``group`` is a module-level name that is rebound before each section of
    options below, so every option lands in the most recently created group.
    """
    group.add_argument(*args, **kwargs)


group = parser.add_argument_group('dataset options')

aa('--db', default='deep1M', help='dataset')
aa('--compute_gt', default=False, action='store_true',
   help='compute and store the groundtruth')

group = parser.add_argument_group('index construction')

aa('--indexkey', default='HNSW32', help='index_factory type')
aa('--efConstruction', default=200, type=int,
   help='HNSW construction factor')
aa('--M0', default=-1, type=int, help='size of base level')
aa('--maxtrain', default=256 * 256, type=int,
   help='maximum number of training points')
aa('--indexfile', default='', help='file to read or write index from')
aa('--add_bs', default=-1, type=int,
   help='add elements index by batches of this size')
# NOTE(review): args.link_singletons does not appear to be read anywhere in
# this script -- confirm before relying on this flag.
aa('--link_singletons', default=False, action='store_true',
   help='do a pass to link in the singletons')

group = parser.add_argument_group(
    'searching (reconstruct_from_neighbors options)')

aa('--beta_centroids', default='',
   help='file with codebook')
aa('--neigh_recons_codes', default='',
   help='file with codes for reconstruction')
aa('--beta_ntrain', default=250000, type=int,
   help='nb of database vectors used to train the beta codebook')
aa('--beta_k', default=256, type=int, help='beta codebook size')
aa('--beta_nsq', default=1, type=int, help='number of beta sub-vectors')
aa('--beta_niter', default=10, type=int,
   help='nb of iterations when training the beta codebook')
aa('--k_reorder', default='-1',
   help='comma-separated values of k_reorder to try')

group = parser.add_argument_group('searching')

aa('--k', default=100, type=int, help='nb of nearest neighbors')
aa('--exhaustive', default=False, action='store_true',
   help='report the exhaustive search topline')
aa('--searchthreads', default=-1, type=int,
   help='nb of threads to use at search time')
aa('--efSearch', default='', type=str,
   help='comma-separated values of efSearch to try')

args = parser.parse_args()

print("args:", args)
######################################################
# Load dataset
######################################################

# training, database and query vectors, plus the groundtruth (gt may be
# None -- the search section checks for that)
xt, xb, xq, gt = datasets.load_data(dataset=args.db,
                                    compute_gt=args.compute_gt)

# nq / nb are the query and database sizes; d is left holding the
# database dimensionality (the second unpacking overwrites the first)
(nq, d), (nb, d) = xq.shape, xb.shape
######################################################
# Make index
######################################################


def _unwrap_hnsw(index):
    """Split ``index`` into its HNSW part and the input vector transform.

    Returns ``(index_hnsw, vec_transform)``. When ``index`` is an
    IndexPreTransform, ``index_hnsw`` is the downcast inner index and
    ``vec_transform`` is the ``apply_py`` of the first transform in the
    chain; otherwise ``index`` itself is returned with an identity transform.
    """
    if isinstance(index, faiss.IndexPreTransform):
        # NOTE(review): only chain element 0 is applied; an index with
        # several pre-transforms would need all of them -- confirm.
        return faiss.downcast_index(index.index), index.chain.at(0).apply_py
    return index, lambda x: x


if os.path.exists(args.indexfile):
    print("reading", args.indexfile)
    index = faiss.read_index(args.indexfile)
    index_hnsw, vec_transform = _unwrap_hnsw(index)
    hnsw = index_hnsw.hnsw
else:
    print("build index, key=", args.indexkey)
    index = faiss.index_factory(d, args.indexkey)
    index_hnsw, vec_transform = _unwrap_hnsw(index)
    hnsw = index_hnsw.hnsw
    hnsw.efConstruction = args.efConstruction
    index.verbose = True
    index_hnsw.verbose = True
    index_hnsw.storage.verbose = True

    if args.M0 != -1:
        print("set level 0 nb of neighbors to", args.M0)
        hnsw.set_nb_neighbors(0, args.M0)

    xt2 = sanitize(xt[:args.maxtrain])
    assert np.all(np.isfinite(xt2))

    # report the size of the data actually used for training (at most
    # --maxtrain rows), not the size of the whole training set
    print("train, size", xt2.shape)
    t0 = time.time()
    index.train(xt2)
    print(" train in %.3f s" % (time.time() - t0))

    print("adding")
    t0 = time.time()
    # add_bs <= 0 means "add everything at once"; previously 0 crashed
    # (range step 0) and other negatives silently added nothing
    if args.add_bs <= 0:
        index.add(sanitize(xb))
    else:
        for i0 in range(0, nb, args.add_bs):
            i1 = min(nb, i0 + args.add_bs)
            print(" adding %d:%d / %d" % (i0, i1, nb))
            index.add(sanitize(xb[i0:i1]))
    print(" add in %.3f s" % (time.time() - t0))

    # --indexfile defaults to '', for which write_index cannot open a file
    if args.indexfile:
        print("storing", args.indexfile)
        faiss.write_index(index, args.indexfile)
    else:
        print("no --indexfile given, index not stored")

hnsw_stats = faiss.cvar.hnsw_stats
######################################################
# Train beta centroids and encode dataset
######################################################


def _save_array(path, arr):
    """Save ``arr`` in .npy format at exactly ``path``.

    ``np.save`` appends ".npy" to a file name that lacks that suffix, so a
    later ``os.path.exists(path)`` / ``np.load(path)`` would miss the file
    and everything would be retrained. Writing through an open file object
    keeps the name unchanged; ``np.load`` reads it back by content.
    """
    with open(path, 'wb') as f:
        np.save(f, arr)


if args.beta_centroids:
    print("reordering links")
    index_hnsw.reorder_links()

    if os.path.exists(args.beta_centroids):
        print("load", args.beta_centroids)
        beta_centroids = np.load(args.beta_centroids)
        nsq, k, M1 = beta_centroids.shape
        # the codebook's last dimension must match the level-0 neighbor
        # count + 1 of the current HNSW graph
        assert M1 == hnsw.nb_neighbors(0) + 1, (M1, hnsw.nb_neighbors(0))
        rfn = faiss.ReconstructFromNeighbors(index_hnsw, k, nsq)
    else:
        print("train beta centroids")
        rfn = faiss.ReconstructFromNeighbors(
            index_hnsw, args.beta_k, args.beta_nsq)
        xb_full = vec_transform(sanitize(xb[:args.beta_ntrain]))
        beta_centroids = neighbor_codec.train_beta_codebook(
            rfn, xb_full, niter=args.beta_niter)
        print(" storing", args.beta_centroids)
        _save_array(args.beta_centroids, beta_centroids)

    faiss.copy_array_to_vector(beta_centroids.ravel(), rfn.codebook)
    index_hnsw.reconstruct_from_neighbors = rfn

    if rfn.k == 1:
        pass  # no codes to take care of
    elif os.path.exists(args.neigh_recons_codes):
        print("loading neigh codes", args.neigh_recons_codes)
        codes = np.load(args.neigh_recons_codes)
        assert codes.size == rfn.code_size * index.ntotal
        faiss.copy_array_to_vector(codes.astype('uint8'), rfn.codes)
        rfn.ntotal = index.ntotal
    else:
        print("encoding neigh codes")
        t0 = time.time()
        # same convention as index construction: add_bs <= 0 -> default
        bs = 1000000 if args.add_bs <= 0 else args.add_bs
        for i0 in range(0, nb, bs):
            i1 = min(i0 + bs, nb)
            print(" encode %d:%d / %d [%.3f s]\r" % (
                i0, i1, nb, time.time() - t0), end=' ')
            sys.stdout.flush()
            xbatch = vec_transform(sanitize(xb[i0:i1]))
            rfn.add_codes(i1 - i0, faiss.swig_ptr(xbatch))
        print()

        # --neigh_recons_codes defaults to ''; saving that would create a
        # stray ".npy" file in the working directory
        if args.neigh_recons_codes:
            print("storing %s" % args.neigh_recons_codes)
            codes = faiss.vector_to_array(rfn.codes)
            _save_array(args.neigh_recons_codes, codes)
        else:
            print("no --neigh_recons_codes given, codes not stored")
######################################################
# Exhaustive evaluation
######################################################

if args.exhaustive:
    print("exhaustive evaluation")
    # The neighbor-based reconstructor ``rfn`` is only created when
    # --beta_centroids is given; without it the loop below would die with a
    # NameError, so stop with an explicit message instead.
    if not args.beta_centroids:
        sys.exit("--exhaustive requires --beta_centroids")

    xq_tr = vec_transform(sanitize(xq))
    index2 = faiss.IndexFlatL2(index_hnsw.d)
    accu_recons_error = 0.0

    if faiss.get_num_gpus() > 0:
        print("do eval on GPU")
        co = faiss.GpuMultipleClonerOptions()
        co.shard = False
        index2 = faiss.index_cpu_to_all_gpus(index2, co)

    # nb of results kept per query for the exhaustive topline
    exhaustive_k = 100

    # process in batches in case the dataset does not fit in RAM: each slice
    # is reconstructed from the neighbor codes, searched with the flat index,
    # and its results merged into rh (presumably with ids offset by i0)
    rh = datasets.ResultHeap(xq_tr.shape[0], exhaustive_k)
    t0 = time.time()
    bs = 500000
    for i0 in range(0, nb, bs):
        i1 = min(nb, i0 + bs)
        print(' handling batch %d:%d' % (i0, i1))
        xb_recons = np.empty((i1 - i0, index_hnsw.d), dtype='float32')
        rfn.reconstruct_n(i0, i1 - i0, faiss.swig_ptr(xb_recons))
        # squared L2 error between the transformed originals and their
        # reconstructions, summed over the slice
        accu_recons_error += (
            (vec_transform(sanitize(xb[i0:i1])) - xb_recons) ** 2).sum()
        index2.reset()
        index2.add(xb_recons)
        D, I = index2.search(xq_tr, exhaustive_k)
        rh.add_batch_result(D, I, i0)

    rh.finalize()
    del index2
    t1 = time.time()
    print("done in %.3f s" % (t1 - t0))
    print("total reconstruction error: ", accu_recons_error)
    print("eval retrieval:")
    datasets.evaluate_DI(rh.D, rh.I, gt)
def get_neighbors(hnsw, i, level):
    """Return the neighbor ids of node ``i`` at graph level ``level``.

    ``hnsw`` is presumably a faiss HNSW structure (e.g. ``index_hnsw.hnsw``)
    exposing ``levels``, ``neighbors`` and ``neighbor_range`` -- confirm.
    Asserts that ``i`` is a valid node and that ``level`` is below
    ``hnsw.levels.at(i)``.
    """
    assert i < hnsw.levels.size()
    assert level < hnsw.levels.at(i)
    # neighbor_range fills two output slots with the [begin, end) offsets of
    # this node's neighbor list inside the flat ``neighbors`` array
    bounds = np.empty(2, 'uint64')
    hnsw.neighbor_range(i, level,
                        faiss.swig_ptr(bounds), faiss.swig_ptr(bounds[1:]))
    begin, end = bounds
    return [hnsw.neighbors.at(j) for j in range(begin, end)]
#############################################################
# Index is ready
#############################################################

xq = sanitize(xq)

if args.searchthreads != -1:
    print("Setting nb of threads to", args.searchthreads)
    faiss.omp_set_num_threads(args.searchthreads)

if gt is None:
    print("no valid groundtruth -- exit")
    sys.exit()

k_reorders = [int(x) for x in args.k_reorder.split(',')]

# --efSearch defaults to '', and int('') raises ValueError; skip empty
# entries and, when none remain, evaluate with the index's current efSearch.
efSearchs = [int(x) for x in args.efSearch.split(',') if x.strip()]
if not efSearchs:
    efSearchs = [hnsw.efSearch]

for k_reorder in k_reorders:
    # k_reorder only has an effect when a neighbor-based reconstructor is
    # attached to the HNSW index
    if index_hnsw.reconstruct_from_neighbors:
        print("setting k_reorder=%d" % k_reorder)
        index_hnsw.reconstruct_from_neighbors.k_reorder = k_reorder

    for efSearch in efSearchs:
        print("efSearch=%-4d" % efSearch, end=' ')
        hnsw.efSearch = efSearch
        hnsw_stats.reset()
        datasets.evaluate(xq, gt, index, k=args.k, endl=False)
        print("ndis %d nreorder %d" % (hnsw_stats.ndis, hnsw_stats.nreorder))