diff --git a/benchs/link_and_code/README.md b/benchs/link_and_code/README.md index 697c7bdfc..0c04cadac 100644 --- a/benchs/link_and_code/README.md +++ b/benchs/link_and_code/README.md @@ -21,138 +21,5 @@ graph to improve the reconstruction. It is described in ArXiV [here](https://arxiv.org/abs/1804.09996) -Code structure --------------- - -The test runs with 3 files: - -- `bench_link_and_code.py`: driver script - -- `datasets.py`: code to load the datasets. The example code runs on the - deep1b and bigann datasets. See the [toplevel README](../README.md) - on how to download them. They should be put in a directory, edit - datasets.py to set the path. - -- `neighbor_codec.py`: this is where the representation is trained. - -The code runs on top of Faiss. The HNSW index can be extended with a -`ReconstructFromNeighbors` C++ object that refines the distances. The -training is implemented in Python. - -Update: 2023-12-28: the current Faiss dropped support for reconstruction with -this method. - -Reproducing Table 2 in the paper --------------------------------- - -The results of table 2 (accuracy on deep100M) in the paper can be -obtained with: - -```bash -python bench_link_and_code.py \ - --db deep100M \ - --M0 6 \ - --indexkey OPQ36_144,HNSW32_PQ36 \ - --indexfile $bdir/deep100M_PQ36_L6.index \ - --beta_nsq 4 \ - --beta_centroids $bdir/deep100M_PQ36_L6_nsq4.npy \ - --neigh_recons_codes $bdir/deep100M_PQ36_L6_nsq4_codes.npy \ - --k_reorder 0,5 --efSearch 1,1024 -``` - -Set `bdir` to a scratch directory. - -Explanation of the flags: - -- `--db deep1M`: dataset to process - -- `--M0 6`: number of links on the base level (L6) - -- `--indexkey OPQ36_144,HNSW32_PQ36`: Faiss index key to construct the - HNSW structure. It means that vectors are transformed by OPQ and - encoded with PQ 36x8 (with an intermediate size of 144D). The HNSW - level>0 nodes have 32 links (theses ones are "cheap" to store - because there are fewer nodes in the upper levels. - -- `--indexfile $bdir/deep1M_PQ36_M6.index`: name of the index file - (without information for the L&C extension) - -- `--beta_nsq 4`: number of bytes to allocate for the codes (M in the - paper) - -- `--beta_centroids $bdir/deep1M_PQ36_M6_nsq4.npy`: filename to store - the trained beta centroids - -- `--neigh_recons_codes $bdir/deep1M_PQ36_M6_nsq4_codes.npy`: filename - for the encoded weights (beta) of the combination - -- `--k_reorder 0,5`: number of results to reorder. 0 = baseline - without reordering, 5 = value used throughout the paper - -- `--efSearch 1,1024`: number of nodes to visit (T in the paper) - -The script will proceed with the following steps: - -0. load dataset (and possibly compute the ground-truth if the -ground-truth file is not provided) - -1. train the OPQ encoder - -2. build the index and store it - -3. compute the residuals and train the beta vocabulary to do the reconstruction - -4. encode the vertices - -5. search and evaluate the search results. - -With option `--exhaustive` the results of the exhaustive column can be -obtained. - -The run above should output: -```bash -... -setting k_reorder=5 -... -efSearch=1024 0.3132 ms per query, R@1: 0.4283 R@10: 0.6337 R@100: 0.6520 ndis 40941919 nreorder 50000 - -``` -which matches the paper's table 2. - -Note that in multi-threaded mode, the building of the HNSW structure -is not deterministic. Therefore, the results across runs may not be exactly the same. 
- -Reproducing Figure 5 in the paper ---------------------------------- -Figure 5 just evaluates the combination of HNSW and PQ. For example, the operating point L6&OPQ40 can be obtained with - -```bash -python bench_link_and_code.py \ - --db deep1M \ - --M0 6 \ - --indexkey OPQ40_160,HNSW32_PQ40 \ - --indexfile $bdir/deep1M_PQ40_M6.index \ - --beta_nsq 1 --beta_k 1 \ - --beta_centroids $bdir/deep1M_PQ40_M6_nsq0.npy \ - --neigh_recons_codes $bdir/deep1M_PQ36_M6_nsq0_codes.npy \ - --k_reorder 0 --efSearch 16,64,256,1024 -``` - -The arguments are similar to the previous table. Note that nsq = 0 is -simulated by setting beta_nsq = 1 and beta_k = 1 (ie a code with a single -reproduction value). - -The output should look like: - -```bash -setting k_reorder=0 -efSearch=16 0.0147 ms per query, R@1: 0.3409 R@10: 0.4388 R@100: 0.4394 ndis 2629735 nreorder 0 -efSearch=64 0.0122 ms per query, R@1: 0.4836 R@10: 0.6490 R@100: 0.6509 ndis 4623221 nreorder 0 -efSearch=256 0.0344 ms per query, R@1: 0.5730 R@10: 0.7915 R@100: 0.7951 ndis 11090176 nreorder 0 -efSearch=1024 0.2656 ms per query, R@1: 0.6212 R@10: 0.8722 R@100: 0.8765 ndis 33501951 nreorder 0 -``` - -The results with k_reorder=5 are not reported in the paper, they -represent the performance of a "free coding" version of the algorithm. +The necessary code for this paper was removed from Faiss in version 1.8.0. +For a functioning version, use Faiss 1.7.4. diff --git a/benchs/link_and_code/bench_link_and_code.py b/benchs/link_and_code/bench_link_and_code.py deleted file mode 100755 index ed8f86d63..000000000 --- a/benchs/link_and_code/bench_link_and_code.py +++ /dev/null @@ -1,300 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree.
- -from __future__ import print_function -import os -import sys -import time -import numpy as np -import faiss -import argparse -import datasets -from datasets import sanitize -import neighbor_codec - -###################################################### -# Command-line parsing -###################################################### - - -parser = argparse.ArgumentParser() - -def aa(*args, **kwargs): - group.add_argument(*args, **kwargs) - -group = parser.add_argument_group('dataset options') - -aa('--db', default='deep1M', help='dataset') -aa( '--compute_gt', default=False, action='store_true', - help='compute and store the groundtruth') - -group = parser.add_argument_group('index consturction') - -aa('--indexkey', default='HNSW32', help='index_factory type') -aa('--efConstruction', default=200, type=int, - help='HNSW construction factor') -aa('--M0', default=-1, type=int, help='size of base level') -aa('--maxtrain', default=256 * 256, type=int, - help='maximum number of training points') -aa('--indexfile', default='', help='file to read or write index from') -aa('--add_bs', default=-1, type=int, - help='add elements index by batches of this size') -aa('--link_singletons', default=False, action='store_true', - help='do a pass to link in the singletons') - -group = parser.add_argument_group( - 'searching (reconstruct_from_neighbors options)') - -aa('--beta_centroids', default='', - help='file with codebook') -aa('--neigh_recons_codes', default='', - help='file with codes for reconstruction') -aa('--beta_ntrain', default=250000, type=int, help='') -aa('--beta_k', default=256, type=int, help='beta codebook size') -aa('--beta_nsq', default=1, type=int, help='number of beta sub-vectors') -aa('--beta_niter', default=10, type=int, help='') -aa('--k_reorder', default='-1', help='') - -group = parser.add_argument_group('searching') - -aa('--k', default=100, type=int, help='nb of nearest neighbors') -aa('--exhaustive', default=False, action='store_true', - help='report the exhaustive search topline') -aa('--searchthreads', default=-1, type=int, - help='nb of threads to use at search time') -aa('--efSearch', default='', type=str, - help='comma-separated values of efSearch to try') - -args = parser.parse_args() - -print("args:", args) - - -###################################################### -# Load dataset -###################################################### - -xt, xb, xq, gt = datasets.load_data( - dataset=args.db, compute_gt=args.compute_gt) - -nq, d = xq.shape -nb, d = xb.shape - - -###################################################### -# Make index -###################################################### - -if os.path.exists(args.indexfile): - - print("reading", args.indexfile) - index = faiss.read_index(args.indexfile) - - if isinstance(index, faiss.IndexPreTransform): - index_hnsw = faiss.downcast_index(index.index) - vec_transform = index.chain.at(0).apply_py - else: - index_hnsw = index - vec_transform = lambda x:x - - hnsw = index_hnsw.hnsw - hnsw_stats = faiss.cvar.hnsw_stats - -else: - - print("build index, key=", args.indexkey) - - index = faiss.index_factory(d, args.indexkey) - - if isinstance(index, faiss.IndexPreTransform): - index_hnsw = faiss.downcast_index(index.index) - vec_transform = index.chain.at(0).apply_py - else: - index_hnsw = index - vec_transform = lambda x:x - - hnsw = index_hnsw.hnsw - hnsw.efConstruction = args.efConstruction - hnsw_stats = faiss.cvar.hnsw_stats - index.verbose = True - index_hnsw.verbose = True - index_hnsw.storage.verbose = True - - if args.M0 
!= -1: - print("set level 0 nb of neighbors to", args.M0) - hnsw.set_nb_neighbors(0, args.M0) - - xt2 = sanitize(xt[:args.maxtrain]) - assert np.all(np.isfinite(xt2)) - - print("train, size", xt.shape) - t0 = time.time() - index.train(xt2) - print(" train in %.3f s" % (time.time() - t0)) - - print("adding") - t0 = time.time() - if args.add_bs == -1: - index.add(sanitize(xb)) - else: - for i0 in range(0, nb, args.add_bs): - i1 = min(nb, i0 + args.add_bs) - print(" adding %d:%d / %d" % (i0, i1, nb)) - index.add(sanitize(xb[i0:i1])) - - print(" add in %.3f s" % (time.time() - t0)) - print("storing", args.indexfile) - faiss.write_index(index, args.indexfile) - - -###################################################### -# Train beta centroids and encode dataset -###################################################### - -if args.beta_centroids: - print("reordering links") - index_hnsw.reorder_links() - - if os.path.exists(args.beta_centroids): - print("load", args.beta_centroids) - beta_centroids = np.load(args.beta_centroids) - nsq, k, M1 = beta_centroids.shape - assert M1 == hnsw.nb_neighbors(0) + 1 - - rfn = faiss.ReconstructFromNeighbors(index_hnsw, k, nsq) - else: - print("train beta centroids") - rfn = faiss.ReconstructFromNeighbors( - index_hnsw, args.beta_k, args.beta_nsq) - - xb_full = vec_transform(sanitize(xb[:args.beta_ntrain])) - - beta_centroids = neighbor_codec.train_beta_codebook( - rfn, xb_full, niter=args.beta_niter) - - print(" storing", args.beta_centroids) - np.save(args.beta_centroids, beta_centroids) - - - faiss.copy_array_to_vector(beta_centroids.ravel(), - rfn.codebook) - index_hnsw.reconstruct_from_neighbors = rfn - - if rfn.k == 1: - pass # no codes to take care of - elif os.path.exists(args.neigh_recons_codes): - print("loading neigh codes", args.neigh_recons_codes) - codes = np.load(args.neigh_recons_codes) - assert codes.size == rfn.code_size * index.ntotal - faiss.copy_array_to_vector(codes.astype('uint8'), - rfn.codes) - rfn.ntotal = index.ntotal - else: - print("encoding neigh codes") - t0 = time.time() - - bs = 1000000 if args.add_bs == -1 else args.add_bs - - for i0 in range(0, nb, bs): - i1 = min(i0 + bs, nb) - print(" encode %d:%d / %d [%.3f s]\r" % ( - i0, i1, nb, time.time() - t0), end=' ') - sys.stdout.flush() - xbatch = vec_transform(sanitize(xb[i0:i1])) - rfn.add_codes(i1 - i0, faiss.swig_ptr(xbatch)) - print() - - print("storing %s" % args.neigh_recons_codes) - codes = faiss.vector_to_array(rfn.codes) - np.save(args.neigh_recons_codes, codes) - -###################################################### -# Exhaustive evaluation -###################################################### - -if args.exhaustive: - print("exhaustive evaluation") - xq_tr = vec_transform(sanitize(xq)) - index2 = faiss.IndexFlatL2(index_hnsw.d) - accu_recons_error = 0.0 - - if faiss.get_num_gpus() > 0: - print("do eval on GPU") - co = faiss.GpuMultipleClonerOptions() - co.shard = False - index2 = faiss.index_cpu_to_all_gpus(index2, co) - - # process in batches in case the dataset does not fit in RAM - rh = datasets.ResultHeap(xq_tr.shape[0], 100) - t0 = time.time() - bs = 500000 - for i0 in range(0, nb, bs): - i1 = min(nb, i0 + bs) - print(' handling batch %d:%d' % (i0, i1)) - - xb_recons = np.empty( - (i1 - i0, index_hnsw.d), dtype='float32') - rfn.reconstruct_n(i0, i1 - i0, faiss.swig_ptr(xb_recons)) - - accu_recons_error += ( - (vec_transform(sanitize(xb[i0:i1])) - - xb_recons)**2).sum() - - index2.reset() - index2.add(xb_recons) - D, I = index2.search(xq_tr, 100) - 
rh.add_batch_result(D, I, i0) - - rh.finalize() - del index2 - t1 = time.time() - print("done in %.3f s" % (t1 - t0)) - print("total reconstruction error: ", accu_recons_error) - print("eval retrieval:") - datasets.evaluate_DI(rh.D, rh.I, gt) - - -def get_neighbors(hnsw, i, level): - " list the neighbors for node i at level " - assert i < hnsw.levels.size() - assert level < hnsw.levels.at(i) - be = np.empty(2, 'uint64') - hnsw.neighbor_range(i, level, faiss.swig_ptr(be), faiss.swig_ptr(be[1:])) - return [hnsw.neighbors.at(j) for j in range(be[0], be[1])] - - -############################################################# -# Index is ready -############################################################# - -xq = sanitize(xq) - -if args.searchthreads != -1: - print("Setting nb of threads to", args.searchthreads) - faiss.omp_set_num_threads(args.searchthreads) - - -if gt is None: - print("no valid groundtruth -- exit") - sys.exit() - - -k_reorders = [int(x) for x in args.k_reorder.split(',')] -efSearchs = [int(x) for x in args.efSearch.split(',')] - - -for k_reorder in k_reorders: - - if index_hnsw.reconstruct_from_neighbors: - print("setting k_reorder=%d" % k_reorder) - index_hnsw.reconstruct_from_neighbors.k_reorder = k_reorder - - for efSearch in efSearchs: - print("efSearch=%-4d" % efSearch, end=' ') - hnsw.efSearch = efSearch - hnsw_stats.reset() - datasets.evaluate(xq, gt, index, k=args.k, endl=False) - - print("ndis %d nreorder %d" % (hnsw_stats.ndis, hnsw_stats.nreorder)) diff --git a/benchs/link_and_code/datasets.py b/benchs/link_and_code/datasets.py deleted file mode 100755 index a043eb888..000000000 --- a/benchs/link_and_code/datasets.py +++ /dev/null @@ -1,236 +0,0 @@ -#! /usr/bin/env python2 - -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -""" -Common functions to load datasets and compute their ground-truth -""" -from __future__ import print_function - -import time -import numpy as np -import faiss -import pdb -import sys - -# set this to the directory that contains the datafiles. 
-# deep1b data should be at simdir + 'deep1b' -# bigann data should be at simdir + 'bigann' -simdir = '/mnt/vol/gfsai-east/ai-group/datasets/simsearch/' - -################################################################# -# Small I/O functions -################################################################# - - -def ivecs_read(fname): - a = np.fromfile(fname, dtype='int32') - d = a[0] - return a.reshape(-1, d + 1)[:, 1:].copy() - - -def fvecs_read(fname): - return ivecs_read(fname).view('float32') - - -def ivecs_mmap(fname): - a = np.memmap(fname, dtype='int32', mode='r') - d = a[0] - return a.reshape(-1, d + 1)[:, 1:] - - -def fvecs_mmap(fname): - return ivecs_mmap(fname).view('float32') - - -def bvecs_mmap(fname): - x = np.memmap(fname, dtype='uint8', mode='r') - d = x[:4].view('int32')[0] - return x.reshape(-1, d + 4)[:, 4:] - - -def ivecs_write(fname, m): - n, d = m.shape - m1 = np.empty((n, d + 1), dtype='int32') - m1[:, 0] = d - m1[:, 1:] = m - m1.tofile(fname) - - -def fvecs_write(fname, m): - m = m.astype('float32') - ivecs_write(fname, m.view('int32')) - - -################################################################# -# Dataset -################################################################# - -def sanitize(x): - return np.ascontiguousarray(x, dtype='float32') - - -class ResultHeap: - """ Combine query results from a sliced dataset """ - - def __init__(self, nq, k): - " nq: number of query vectors, k: number of results per query " - self.I = np.zeros((nq, k), dtype='int64') - self.D = np.zeros((nq, k), dtype='float32') - self.nq, self.k = nq, k - heaps = faiss.float_maxheap_array_t() - heaps.k = k - heaps.nh = nq - heaps.val = faiss.swig_ptr(self.D) - heaps.ids = faiss.swig_ptr(self.I) - heaps.heapify() - self.heaps = heaps - - def add_batch_result(self, D, I, i0): - assert D.shape == (self.nq, self.k) - assert I.shape == (self.nq, self.k) - I += i0 - self.heaps.addn_with_ids( - self.k, faiss.swig_ptr(D), - faiss.swig_ptr(I), self.k) - - def finalize(self): - self.heaps.reorder() - - - -def compute_GT_sliced(xb, xq, k): - print("compute GT") - t0 = time.time() - nb, d = xb.shape - nq, d = xq.shape - rh = ResultHeap(nq, k) - bs = 10 ** 5 - - xqs = sanitize(xq) - - db_gt = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(d)) - - # compute ground-truth by blocks of bs, and add to heaps - for i0 in range(0, nb, bs): - i1 = min(nb, i0 + bs) - xsl = sanitize(xb[i0:i1]) - db_gt.add(xsl) - D, I = db_gt.search(xqs, k) - rh.add_batch_result(D, I, i0) - db_gt.reset() - print("\r %d/%d, %.3f s" % (i0, nb, time.time() - t0), end=' ') - sys.stdout.flush() - print() - rh.finalize() - gt_I = rh.I - - print("GT time: %.3f s" % (time.time() - t0)) - return gt_I - - -def do_compute_gt(xb, xq, k): - print("computing GT") - nb, d = xb.shape - index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(d)) - if nb < 100 * 1000: - print(" add") - index.add(np.ascontiguousarray(xb, dtype='float32')) - print(" search") - D, I = index.search(np.ascontiguousarray(xq, dtype='float32'), k) - else: - I = compute_GT_sliced(xb, xq, k) - - return I.astype('int32') - - -def load_data(dataset='deep1M', compute_gt=False): - - print("load data", dataset) - - if dataset == 'sift1M': - basedir = simdir + 'sift1M/' - - xt = fvecs_read(basedir + "sift_learn.fvecs") - xb = fvecs_read(basedir + "sift_base.fvecs") - xq = fvecs_read(basedir + "sift_query.fvecs") - gt = ivecs_read(basedir + "sift_groundtruth.ivecs") - - elif dataset.startswith('bigann'): - basedir = simdir + 'bigann/' - - dbsize = 1000 if dataset == 
"bigann1B" else int(dataset[6:-1]) - xb = bvecs_mmap(basedir + 'bigann_base.bvecs') - xq = bvecs_mmap(basedir + 'bigann_query.bvecs') - xt = bvecs_mmap(basedir + 'bigann_learn.bvecs') - # trim xb to correct size - xb = xb[:dbsize * 1000 * 1000] - gt = ivecs_read(basedir + 'gnd/idx_%dM.ivecs' % dbsize) - - elif dataset.startswith("deep"): - basedir = simdir + 'deep1b/' - szsuf = dataset[4:] - if szsuf[-1] == 'M': - dbsize = 10 ** 6 * int(szsuf[:-1]) - elif szsuf == '1B': - dbsize = 10 ** 9 - elif szsuf[-1] == 'k': - dbsize = 1000 * int(szsuf[:-1]) - else: - assert False, "did not recognize suffix " + szsuf - - xt = fvecs_mmap(basedir + "learn.fvecs") - xb = fvecs_mmap(basedir + "base.fvecs") - xq = fvecs_read(basedir + "deep1B_queries.fvecs") - - xb = xb[:dbsize] - - gt_fname = basedir + "%s_groundtruth.ivecs" % dataset - if compute_gt: - gt = do_compute_gt(xb, xq, 100) - print("store", gt_fname) - ivecs_write(gt_fname, gt) - - gt = ivecs_read(gt_fname) - - else: - assert False - - print("dataset %s sizes: B %s Q %s T %s" % ( - dataset, xb.shape, xq.shape, xt.shape)) - - return xt, xb, xq, gt - -################################################################# -# Evaluation -################################################################# - - -def evaluate_DI(D, I, gt): - nq = gt.shape[0] - k = I.shape[1] - rank = 1 - while rank <= k: - recall = (I[:, :rank] == gt[:, :1]).sum() / float(nq) - print("R@%d: %.4f" % (rank, recall), end=' ') - rank *= 10 - - -def evaluate(xq, gt, index, k=100, endl=True): - t0 = time.time() - D, I = index.search(xq, k) - t1 = time.time() - nq = xq.shape[0] - print("\t %8.4f ms per query, " % ( - (t1 - t0) * 1000.0 / nq), end=' ') - rank = 1 - while rank <= k: - recall = (I[:, :rank] == gt[:, :1]).sum() / float(nq) - print("R@%d: %.4f" % (rank, recall), end=' ') - rank *= 10 - if endl: - print() - return D, I diff --git a/benchs/link_and_code/neighbor_codec.py b/benchs/link_and_code/neighbor_codec.py deleted file mode 100755 index 54cad8168..000000000 --- a/benchs/link_and_code/neighbor_codec.py +++ /dev/null @@ -1,241 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -""" -This is the training code for the link and code. Especially the -neighbors_kmeans function implements the EM-algorithm to find the -appropriate weightings and cluster them. 
-""" -from __future__ import print_function - -import time -import numpy as np -import faiss - -#---------------------------------------------------------- -# Utils -#---------------------------------------------------------- - -def sanitize(x): - return np.ascontiguousarray(x, dtype='float32') - - -def train_kmeans(x, k, ngpu, max_points_per_centroid=256): - "Runs kmeans on one or several GPUs" - d = x.shape[1] - clus = faiss.Clustering(d, k) - clus.verbose = True - clus.niter = 20 - clus.max_points_per_centroid = max_points_per_centroid - - if ngpu == 0: - index = faiss.IndexFlatL2(d) - else: - res = [faiss.StandardGpuResources() for i in range(ngpu)] - - flat_config = [] - for i in range(ngpu): - cfg = faiss.GpuIndexFlatConfig() - cfg.useFloat16 = False - cfg.device = i - flat_config.append(cfg) - - if ngpu == 1: - index = faiss.GpuIndexFlatL2(res[0], d, flat_config[0]) - else: - indexes = [faiss.GpuIndexFlatL2(res[i], d, flat_config[i]) - for i in range(ngpu)] - index = faiss.IndexReplicas() - for sub_index in indexes: - index.addIndex(sub_index) - - # perform the training - clus.train(x, index) - centroids = faiss.vector_float_to_array(clus.centroids) - - stats = clus.iteration_stats - stats = [stats.at(i) for i in range(stats.size())] - obj = np.array([st.obj for st in stats]) - print("final objective: %.4g" % obj[-1]) - - return centroids.reshape(k, d) - - -#---------------------------------------------------------- -# Learning the codebook from neighbors -#---------------------------------------------------------- - - -# works with both a full Inn table and dynamically generated neighbors - -def get_Inn_shape(Inn): - if type(Inn) != tuple: - return Inn.shape - return Inn[:2] - -def get_neighbor_table(x_coded, Inn, i): - if type(Inn) != tuple: - return x_coded[Inn[i,:],:] - rfn = x_coded - M, d = rfn.M, rfn.index.d - out = np.zeros((M + 1, d), dtype='float32') - int_i = int(i) - rfn.get_neighbor_table(int_i, faiss.swig_ptr(out)) - _, _, sq = Inn - return out[:, sq * rfn.dsub : (sq + 1) * rfn.dsub] - - -# Function that produces the best regression values from the vector -# and its neighbors -def regress_from_neighbors (x, x_coded, Inn): - (N, knn) = get_Inn_shape(Inn) - betas = np.zeros((N,knn)) - t0 = time.time() - for i in range (N): - xi = x[i,:] - NNi = get_neighbor_table(x_coded, Inn, i) - betas[i,:] = np.linalg.lstsq(NNi.transpose(), xi, rcond=0.01)[0] - if i % (N / 10) == 0: - print ("[%d:%d] %6.3fs" % (i, i + N / 10, time.time() - t0)) - return betas - - - -# find the best beta minimizing ||x-x_coded[Inn,:]*beta||^2 -def regress_opt_beta (x, x_coded, Inn): - (N, knn) = get_Inn_shape(Inn) - d = x.shape[1] - - # construct the linear system to be solved - X = np.zeros ((d*N)) - Y = np.zeros ((d*N, knn)) - for i in range (N): - X[i*d:(i+1)*d] = x[i,:] - neighbor_table = get_neighbor_table(x_coded, Inn, i) - Y[i*d:(i+1)*d, :] = neighbor_table.transpose() - beta_opt = np.linalg.lstsq(Y, X, rcond=0.01)[0] - return beta_opt - - -# Find the best encoding by minimizing the reconstruction error using -# a set of pre-computed beta values -def assign_beta (beta_centroids, x, x_coded, Inn, verbose=True): - if type(Inn) == tuple: - return assign_beta_2(beta_centroids, x, x_coded, Inn) - (N, knn) = Inn.shape - x_ibeta = np.zeros ((N), dtype='int32') - t0= time.time() - for i in range (N): - NNi = x_coded[Inn[i,:]] - # Consider all possible betas for the encoding and compute the - # encoding error - x_reg_all = np.dot (beta_centroids, NNi) - err = ((x_reg_all - x[i,:]) ** 2).sum(axis=1) - 
x_ibeta[i] = err.argmin() - if verbose: - if i % (N / 10) == 0: - print ("[%d:%d] %6.3fs" % (i, i + N / 10, time.time() - t0)) - return x_ibeta - - -# Reconstruct a set of vectors using the beta_centroids, the -# assignment, the encoded neighbors identified by the list Inn (which -# includes the vector itself) -def recons_from_neighbors (beta_centroids, x_ibeta, x_coded, Inn): - (N, knn) = Inn.shape - x_rec = np.zeros(x_coded.shape) - t0= time.time() - for i in range (N): - NNi = x_coded[Inn[i,:]] - x_rec[i, :] = np.dot (beta_centroids[x_ibeta[i]], NNi) - if i % (N / 10) == 0: - print ("[%d:%d] %6.3fs" % (i, i + N / 10, time.time() - t0)) - return x_rec - - -# Compute a EM-like algorithm trying at optimizing the beta such as they -# minimize the reconstruction error from the neighbors -def neighbors_kmeans (x, x_coded, Inn, K, ngpus=1, niter=5): - # First compute centroids using a regular k-means algorithm - betas = regress_from_neighbors (x, x_coded, Inn) - beta_centroids = train_kmeans( - sanitize(betas), K, ngpus, max_points_per_centroid=1000000) - _, knn = get_Inn_shape(Inn) - d = x.shape[1] - - rs = np.random.RandomState() - for iter in range(niter): - print('iter', iter) - idx = assign_beta (beta_centroids, x, x_coded, Inn, verbose=False) - - hist = np.bincount(idx) - for cl0 in np.where(hist == 0)[0]: - print(" cluster %d empty, split" % cl0, end=' ') - cl1 = idx[np.random.randint(idx.size)] - pos = np.nonzero (idx == cl1)[0] - pos = rs.choice(pos, pos.size / 2) - print(" cl %d -> %d + %d" % (cl1, len(pos), hist[cl1] - len(pos))) - idx[pos] = cl0 - hist = np.bincount(idx) - - tot_err = 0 - for k in range (K): - pos = np.nonzero (idx == k)[0] - npos = pos.shape[0] - - X = np.zeros (d*npos) - Y = np.zeros ((d*npos, knn)) - - for i in range(npos): - X[i*d:(i+1)*d] = x[pos[i],:] - neighbor_table = get_neighbor_table(x_coded, Inn, pos[i]) - Y[i*d:(i+1)*d, :] = neighbor_table.transpose() - sol, residuals, _, _ = np.linalg.lstsq(Y, X, rcond=0.01) - if residuals.size > 0: - tot_err += residuals.sum() - beta_centroids[k, :] = sol - print(' err=%g' % tot_err) - return beta_centroids - - -# assign the betas in C++ -def assign_beta_2(beta_centroids, x, rfn, Inn): - _, _, sq = Inn - if rfn.k == 1: - return np.zeros(x.shape[0], dtype=int) - # add dummy dimensions to beta_centroids and x - all_beta_centroids = np.zeros( - (rfn.nsq, rfn.k, rfn.M + 1), dtype='float32') - all_beta_centroids[sq] = beta_centroids - all_x = np.zeros((len(x), rfn.d), dtype='float32') - all_x[:, sq * rfn.dsub : (sq + 1) * rfn.dsub] = x - rfn.codes.clear() - rfn.ntotal = 0 - faiss.copy_array_to_vector( - all_beta_centroids.ravel(), rfn.codebook) - rfn.add_codes(len(x), faiss.swig_ptr(all_x)) - codes = faiss.vector_to_array(rfn.codes) - codes = codes.reshape(-1, rfn.nsq) - return codes[:, sq] - - -####################################################### -# For usage from bench_storages.py - -def train_beta_codebook(rfn, xb_full, niter=10): - beta_centroids = [] - for sq in range(rfn.nsq): - d0, d1 = sq * rfn.dsub, (sq + 1) * rfn.dsub - print("training subquantizer %d/%d on dimensions %d:%d" % ( - sq, rfn.nsq, d0, d1)) - beta_centroids_i = neighbors_kmeans( - xb_full[:, d0:d1], rfn, (xb_full.shape[0], rfn.M + 1, sq), - rfn.k, - ngpus=0, niter=niter) - beta_centroids.append(beta_centroids_i) - rfn.ntotal = 0 - rfn.codes.clear() - rfn.codebook.clear() - return np.stack(beta_centroids) diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 9a67332d6..3325c8c0e 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ 
-307,7 +307,7 @@ void hnsw_search( FAISS_THROW_IF_NOT_MSG(params, "params type invalid"); efSearch = params->efSearch; } - size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0; + size_t n1 = 0, n2 = 0, ndis = 0; idx_t check_period = InterruptCallback::get_period_hint( hnsw.max_level * index->d * efSearch); @@ -323,7 +323,7 @@ void hnsw_search( std::unique_ptr<DistanceComputer> dis( storage_distance_computer(index->storage)); -#pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder) schedule(guided) +#pragma omp for reduction(+ : n1, n2, ndis) schedule(guided) for (idx_t i = i0; i < i1; i++) { res.begin(i); dis->set_query(x + i * index->d); @@ -331,16 +331,14 @@ void hnsw_search( HNSWStats stats = hnsw.search(*dis, res, vt, params); n1 += stats.n1; n2 += stats.n2; - n3 += stats.n3; ndis += stats.ndis; - nreorder += stats.nreorder; res.end(); } } InterruptCallback::check(); } - hnsw_stats.combine({n1, n2, n3, ndis, nreorder}); + hnsw_stats.combine({n1, n2, ndis}); } } // anonymous namespace @@ -800,7 +798,7 @@ void IndexHNSW2Level::search( IndexHNSW::search(n, x, k, distances, labels); } else { // "mixed" search - size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0; + size_t n1 = 0, n2 = 0, ndis = 0; const IndexIVFPQ* index_ivfpq = dynamic_cast<const IndexIVFPQ*>(storage); @@ -832,7 +830,7 @@ void IndexHNSW2Level::search( int candidates_size = hnsw.upper_beam; MinimaxHeap candidates(candidates_size); -#pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder) +#pragma omp for reduction(+ : n1, n2, ndis) for (idx_t i = 0; i < n; i++) { idx_t* idxi = labels + i * k; float* simi = distances + i * k; @@ -877,9 +875,7 @@ void IndexHNSW2Level::search( k); n1 += search_stats.n1; n2 += search_stats.n2; - n3 += search_stats.n3; ndis += search_stats.ndis; - nreorder += search_stats.nreorder; vt.advance(); vt.advance(); @@ -888,7 +884,7 @@ void IndexHNSW2Level::search( } } - hnsw_stats.combine({n1, n2, n3, ndis, nreorder}); + hnsw_stats.combine({n1, n2, ndis}); } } diff --git a/faiss/gpu/GpuIndex.h b/faiss/gpu/GpuIndex.h index 36de98c09..cc10f2158 100644 --- a/faiss/gpu/GpuIndex.h +++ b/faiss/gpu/GpuIndex.h @@ -84,19 +84,14 @@ class GpuIndex : public faiss::Index { /// `x` and `labels` can be resident on the CPU or any GPU; copies are /// performed as needed - void assign( - idx_t n, - const float* x, - idx_t* labels, - // faiss::Index has idx_t for k - idx_t k = 1) const override; + void assign(idx_t n, const float* x, idx_t* labels, idx_t k = 1) + const override; /// `x`, `distances` and `labels` can be resident on the CPU or any /// GPU; copies are performed as needed void search( idx_t n, const float* x, - // faiss::Index has idx_t for k idx_t k, float* distances, idx_t* labels, @@ -107,7 +102,6 @@ class GpuIndex : public faiss::Index { void search_and_reconstruct( idx_t n, const float* x, - // faiss::Index has idx_t for k idx_t k, float* distances, idx_t* labels, diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index a9fb9daf5..b1324e121 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -664,7 +664,7 @@ int search_from_candidates( if (candidates.size() == 0) { stats.n2++; } - stats.n3 += ndis; + stats.ndis += ndis; } return nres; @@ -793,7 +793,7 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded( if (candidates.size() == 0) { ++stats.n2; } - stats.n3 += ndis; + stats.ndis += ndis; return top_candidates; } diff --git a/faiss/impl/HNSW.h b/faiss/impl/HNSW.h index cb6b422c3..8261423cd 100644 --- a/faiss/impl/HNSW.h +++ b/faiss/impl/HNSW.h @@ -230,30 +230,20 @@ struct HNSW { }; struct HNSWStats { - size_t n1,
n2, n3; - size_t ndis; - size_t nreorder; - - HNSWStats( - size_t n1 = 0, - size_t n2 = 0, - size_t n3 = 0, - size_t ndis = 0, - size_t nreorder = 0) - : n1(n1), n2(n2), n3(n3), ndis(ndis), nreorder(nreorder) {} + size_t n1 = 0; /// number of vectors searched + size_t n2 = + 0; /// number of queries for which the candidate list is exhausted + size_t ndis = 0; /// number of distances computed void reset() { - n1 = n2 = n3 = 0; + n1 = n2 = 0; ndis = 0; - nreorder = 0; } void combine(const HNSWStats& other) { n1 += other.n1; n2 += other.n2; - n3 += other.n3; ndis += other.ndis; - nreorder += other.nreorder; } }; diff --git a/tests/test_graph_based.py b/tests/test_graph_based.py index 914fac3ff..dd4212d71 100644 --- a/tests/test_graph_based.py +++ b/tests/test_graph_based.py @@ -123,6 +123,16 @@ class TestHNSW(unittest.TestCase): mask = Iref[:, 0] == Ihnsw[:, 0] assert np.allclose(Dref[mask, 0], Dhnsw[mask, 0]) + def test_ndis_stats(self): + d = self.xq.shape[1] + + index = faiss.IndexHNSWFlat(d, 16) + index.add(self.xb) + stats = faiss.cvar.hnsw_stats + stats.reset() + Dhnsw, Ihnsw = index.search(self.xq, 1) + self.assertGreater(stats.ndis, len(self.xq) * index.hnsw.efSearch) + class TestNSG(unittest.TestCase):
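The surviving counters can be read from Python the same way the new `test_ndis_stats` unit test does. A minimal sketch on synthetic data (the dimension, dataset sizes, and seed below are illustrative only, not from the benchmark):

```python
import numpy as np
import faiss

# small synthetic dataset
d = 32
rs = np.random.RandomState(123)
xb = rs.rand(10000, d).astype('float32')
xq = rs.rand(100, d).astype('float32')

# IndexHNSWFlat(d, M): M = number of links per node
index = faiss.IndexHNSWFlat(d, 16)
index.add(xb)

# hnsw_stats is the global counter object this patch trims down to
# n1 / n2 / ndis (n3 and nreorder are removed along with the
# link-and-code reconstruction support)
stats = faiss.cvar.hnsw_stats
stats.reset()
D, I = index.search(xq, 10)
print("vectors searched (n1): %d, distances computed (ndis): %d"
      % (stats.n1, stats.ndis))
```

The assertion in the added test follows the same logic: with the default `efSearch` of 16, every query computes at least `efSearch` distances at the base level, so `ndis` must exceed `nq * efSearch`.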