Fix HNSW stats (#3309)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3309

Make sure that the HNSW search stats work; remove stats for deprecated functionality.
Remove the code for the link and code paper, which is no longer supported.
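
For illustration, a minimal sketch (not part of this diff, toy data assumed) of how the repaired counters can be read from Python after a search; it mirrors the new unit test added below:

```python
import numpy as np
import faiss

# assumed toy data: 10k base vectors, 100 queries, 32 dimensions
xb = np.random.rand(10000, 32).astype('float32')
xq = np.random.rand(100, 32).astype('float32')

index = faiss.IndexHNSWFlat(32, 16)    # HNSW graph with 16 neighbors per node
index.add(xb)

stats = faiss.cvar.hnsw_stats          # global HNSW search statistics
stats.reset()
D, I = index.search(xq, 10)
# with this fix, ndis counts the distances computed during the searches
print("n1", stats.n1, "n2", stats.n2, "ndis", stats.ndis)
```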

Reviewed By: kuarora, junjieqi

Differential Revision: D55247802

fbshipit-source-id: 03f176be092bff6b2db359cc956905d8646ea702
pull/3312/head
Matthijs Douze 2024-03-22 12:55:30 -07:00 committed by Facebook GitHub Bot
parent b77061ff5e
commit fa1f39ec9f
9 changed files with 27 additions and 947 deletions

View File

@ -21,138 +21,5 @@ graph to improve the reconstruction. It is described in
arXiv [here](https://arxiv.org/abs/1804.09996)
Code structure
--------------
The test runs with 3 files:
- `bench_link_and_code.py`: driver script
- `datasets.py`: code to load the datasets. The example code runs on the
deep1b and bigann datasets. See the [toplevel README](../README.md)
on how to download them. They should be put in a directory; edit
`datasets.py` to set the path.
- `neighbor_codec.py`: this is where the representation is trained.
The code runs on top of Faiss. The HNSW index can be extended with a
`ReconstructFromNeighbors` C++ object that refines the distances. The
training is implemented in Python.
Update (2023-12-28): current Faiss has dropped support for reconstruction with
this method.
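For reference, here is a simplified sketch of how the extension used to be wired up from Python, condensed from `bench_link_and_code.py` (shown further down in this diff). It requires Faiss <= 1.7.4 (`ReconstructFromNeighbors` was removed in 1.8.0) and is therefore not runnable on current Faiss; the data and parameters are assumed:

```python
# Faiss <= 1.7.4 only: ReconstructFromNeighbors no longer exists in 1.8.0+.
import numpy as np
import faiss

d = 96                                              # assumed vector dimensionality
xb = np.random.rand(100000, d).astype('float32')    # assumed database vectors

index = faiss.index_factory(d, "OPQ36_144,HNSW32_PQ36")
index.train(xb)
index.add(xb)

index_hnsw = faiss.downcast_index(index.index)      # HNSW part of the pre-transform chain
rfn = faiss.ReconstructFromNeighbors(index_hnsw, 256, 4)  # beta_k=256 centroids, beta_nsq=4 sub-vectors
# ... train the beta codebook and encode the vertices (see neighbor_codec.py below) ...
index_hnsw.reconstruct_from_neighbors = rfn
rfn.k_reorder = 5   # re-rank the top 5 results with the refined distances at search time
```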
Reproducing Table 2 in the paper
--------------------------------
The results of table 2 (accuracy on deep100M) in the paper can be
obtained with:
```bash
python bench_link_and_code.py \
--db deep100M \
--M0 6 \
--indexkey OPQ36_144,HNSW32_PQ36 \
--indexfile $bdir/deep100M_PQ36_L6.index \
--beta_nsq 4 \
--beta_centroids $bdir/deep100M_PQ36_L6_nsq4.npy \
--neigh_recons_codes $bdir/deep100M_PQ36_L6_nsq4_codes.npy \
--k_reorder 0,5 --efSearch 1,1024
```
Set `bdir` to a scratch directory.
Explanation of the flags:
- `--db deep100M`: dataset to process
- `--M0 6`: number of links on the base level (L6)
- `--indexkey OPQ36_144,HNSW32_PQ36`: Faiss index key to construct the
HNSW structure. It means that vectors are transformed by OPQ and
encoded with PQ 36x8 (with an intermediate size of 144D). The HNSW
level>0 nodes have 32 links (these are "cheap" to store
because there are fewer nodes in the upper levels).
- `--indexfile $bdir/deep100M_PQ36_L6.index`: name of the index file
(without information for the L&C extension)
- `--beta_nsq 4`: number of bytes to allocate for the codes (M in the
paper)
- `--beta_centroids $bdir/deep100M_PQ36_L6_nsq4.npy`: filename to store
the trained beta centroids
- `--neigh_recons_codes $bdir/deep100M_PQ36_L6_nsq4_codes.npy`: filename
for the encoded weights (beta) of the combination
- `--k_reorder 0,5`: number of results to reorder. 0 = baseline
without reordering, 5 = value used throughout the paper
- `--efSearch 1,1024`: number of nodes to visit (T in the paper)
The script will proceed with the following steps:
0. load dataset (and possibly compute the ground-truth if the
ground-truth file is not provided)
1. train the OPQ encoder
2. build the index and store it
3. compute the residuals and train the beta vocabulary to do the reconstruction
4. encode the vertices
5. search and evaluate the search results.
With option `--exhaustive` the results of the exhaustive column can be
obtained.
The run above should output:
```bash
...
setting k_reorder=5
...
efSearch=1024 0.3132 ms per query, R@1: 0.4283 R@10: 0.6337 R@100: 0.6520 ndis 40941919 nreorder 50000
```
which matches the paper's table 2.
Note that in multi-threaded mode, the building of the HNSW structure
is not deterministic. Therefore, the results across runs may not be exactly the same.
Reproducing Figure 5 in the paper
---------------------------------
Figure 5 just evaluates the combination of HNSW and PQ. For example,
the operating point L6&OPQ40 can be obtained with
```bash
python bench_link_and_code.py \
--db deep1M \
--M0 6 \
--indexkey OPQ40_160,HNSW32_PQ40 \
--indexfile $bdir/deep1M_PQ40_M6.index \
--beta_nsq 1 --beta_k 1 \
--beta_centroids $bdir/deep1M_PQ40_M6_nsq0.npy \
--neigh_recons_codes $bdir/deep1M_PQ36_M6_nsq0_codes.npy \
--k_reorder 0 --efSearch 16,64,256,1024
```
The arguments are similar to those of the previous table. Note that nsq = 0 is
simulated by setting beta_nsq = 1 and beta_k = 1 (i.e. a code with a single
reproduction value).
The output should look like:
```bash
setting k_reorder=0
efSearch=16 0.0147 ms per query, R@1: 0.3409 R@10: 0.4388 R@100: 0.4394 ndis 2629735 nreorder 0
efSearch=64 0.0122 ms per query, R@1: 0.4836 R@10: 0.6490 R@100: 0.6509 ndis 4623221 nreorder 0
efSearch=256 0.0344 ms per query, R@1: 0.5730 R@10: 0.7915 R@100: 0.7951 ndis 11090176 nreorder 0
efSearch=1024 0.2656 ms per query, R@1: 0.6212 R@10: 0.8722 R@100: 0.8765 ndis 33501951 nreorder 0
```
The results with k_reorder=5 are not reported in the paper; they
represent the performance of a "free coding" version of the algorithm.
The code needed for this paper was removed from Faiss in version 1.8.0.
For a functioning version, use Faiss 1.7.4.

View File

@ -1,300 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import print_function
import os
import sys
import time
import numpy as np
import faiss
import argparse
import datasets
from datasets import sanitize
import neighbor_codec
######################################################
# Command-line parsing
######################################################
parser = argparse.ArgumentParser()
def aa(*args, **kwargs):
    group.add_argument(*args, **kwargs)
group = parser.add_argument_group('dataset options')
aa('--db', default='deep1M', help='dataset')
aa( '--compute_gt', default=False, action='store_true',
help='compute and store the groundtruth')
group = parser.add_argument_group('index construction')
aa('--indexkey', default='HNSW32', help='index_factory type')
aa('--efConstruction', default=200, type=int,
help='HNSW construction factor')
aa('--M0', default=-1, type=int, help='size of base level')
aa('--maxtrain', default=256 * 256, type=int,
help='maximum number of training points')
aa('--indexfile', default='', help='file to read or write index from')
aa('--add_bs', default=-1, type=int,
help='add elements index by batches of this size')
aa('--link_singletons', default=False, action='store_true',
help='do a pass to link in the singletons')
group = parser.add_argument_group(
'searching (reconstruct_from_neighbors options)')
aa('--beta_centroids', default='',
help='file with codebook')
aa('--neigh_recons_codes', default='',
help='file with codes for reconstruction')
aa('--beta_ntrain', default=250000, type=int, help='')
aa('--beta_k', default=256, type=int, help='beta codebook size')
aa('--beta_nsq', default=1, type=int, help='number of beta sub-vectors')
aa('--beta_niter', default=10, type=int, help='')
aa('--k_reorder', default='-1', help='')
group = parser.add_argument_group('searching')
aa('--k', default=100, type=int, help='nb of nearest neighbors')
aa('--exhaustive', default=False, action='store_true',
help='report the exhaustive search topline')
aa('--searchthreads', default=-1, type=int,
help='nb of threads to use at search time')
aa('--efSearch', default='', type=str,
help='comma-separated values of efSearch to try')
args = parser.parse_args()
print("args:", args)
######################################################
# Load dataset
######################################################
xt, xb, xq, gt = datasets.load_data(
dataset=args.db, compute_gt=args.compute_gt)
nq, d = xq.shape
nb, d = xb.shape
######################################################
# Make index
######################################################
if os.path.exists(args.indexfile):
    print("reading", args.indexfile)
    index = faiss.read_index(args.indexfile)

    if isinstance(index, faiss.IndexPreTransform):
        index_hnsw = faiss.downcast_index(index.index)
        vec_transform = index.chain.at(0).apply_py
    else:
        index_hnsw = index
        vec_transform = lambda x: x

    hnsw = index_hnsw.hnsw
    hnsw_stats = faiss.cvar.hnsw_stats

else:
    print("build index, key=", args.indexkey)
    index = faiss.index_factory(d, args.indexkey)

    if isinstance(index, faiss.IndexPreTransform):
        index_hnsw = faiss.downcast_index(index.index)
        vec_transform = index.chain.at(0).apply_py
    else:
        index_hnsw = index
        vec_transform = lambda x: x

    hnsw = index_hnsw.hnsw
    hnsw.efConstruction = args.efConstruction
    hnsw_stats = faiss.cvar.hnsw_stats
    index.verbose = True
    index_hnsw.verbose = True
    index_hnsw.storage.verbose = True

    if args.M0 != -1:
        print("set level 0 nb of neighbors to", args.M0)
        hnsw.set_nb_neighbors(0, args.M0)

    xt2 = sanitize(xt[:args.maxtrain])
    assert np.all(np.isfinite(xt2))

    print("train, size", xt.shape)
    t0 = time.time()
    index.train(xt2)
    print(" train in %.3f s" % (time.time() - t0))

    print("adding")
    t0 = time.time()
    if args.add_bs == -1:
        index.add(sanitize(xb))
    else:
        for i0 in range(0, nb, args.add_bs):
            i1 = min(nb, i0 + args.add_bs)
            print(" adding %d:%d / %d" % (i0, i1, nb))
            index.add(sanitize(xb[i0:i1]))

    print(" add in %.3f s" % (time.time() - t0))
    print("storing", args.indexfile)
    faiss.write_index(index, args.indexfile)
######################################################
# Train beta centroids and encode dataset
######################################################
if args.beta_centroids:
    print("reordering links")
    index_hnsw.reorder_links()

    if os.path.exists(args.beta_centroids):
        print("load", args.beta_centroids)
        beta_centroids = np.load(args.beta_centroids)
        nsq, k, M1 = beta_centroids.shape
        assert M1 == hnsw.nb_neighbors(0) + 1
        rfn = faiss.ReconstructFromNeighbors(index_hnsw, k, nsq)
    else:
        print("train beta centroids")
        rfn = faiss.ReconstructFromNeighbors(
            index_hnsw, args.beta_k, args.beta_nsq)
        xb_full = vec_transform(sanitize(xb[:args.beta_ntrain]))
        beta_centroids = neighbor_codec.train_beta_codebook(
            rfn, xb_full, niter=args.beta_niter)
        print(" storing", args.beta_centroids)
        np.save(args.beta_centroids, beta_centroids)

    faiss.copy_array_to_vector(beta_centroids.ravel(),
                               rfn.codebook)
    index_hnsw.reconstruct_from_neighbors = rfn

    if rfn.k == 1:
        pass  # no codes to take care of
    elif os.path.exists(args.neigh_recons_codes):
        print("loading neigh codes", args.neigh_recons_codes)
        codes = np.load(args.neigh_recons_codes)
        assert codes.size == rfn.code_size * index.ntotal
        faiss.copy_array_to_vector(codes.astype('uint8'),
                                   rfn.codes)
        rfn.ntotal = index.ntotal
    else:
        print("encoding neigh codes")
        t0 = time.time()
        bs = 1000000 if args.add_bs == -1 else args.add_bs
        for i0 in range(0, nb, bs):
            i1 = min(i0 + bs, nb)
            print(" encode %d:%d / %d [%.3f s]\r" % (
                i0, i1, nb, time.time() - t0), end=' ')
            sys.stdout.flush()
            xbatch = vec_transform(sanitize(xb[i0:i1]))
            rfn.add_codes(i1 - i0, faiss.swig_ptr(xbatch))
        print()
        print("storing %s" % args.neigh_recons_codes)
        codes = faiss.vector_to_array(rfn.codes)
        np.save(args.neigh_recons_codes, codes)
######################################################
# Exhaustive evaluation
######################################################
if args.exhaustive:
    print("exhaustive evaluation")
    xq_tr = vec_transform(sanitize(xq))
    index2 = faiss.IndexFlatL2(index_hnsw.d)
    accu_recons_error = 0.0

    if faiss.get_num_gpus() > 0:
        print("do eval on GPU")
        co = faiss.GpuMultipleClonerOptions()
        co.shard = False
        index2 = faiss.index_cpu_to_all_gpus(index2, co)

    # process in batches in case the dataset does not fit in RAM
    rh = datasets.ResultHeap(xq_tr.shape[0], 100)
    t0 = time.time()
    bs = 500000
    for i0 in range(0, nb, bs):
        i1 = min(nb, i0 + bs)
        print(' handling batch %d:%d' % (i0, i1))
        xb_recons = np.empty(
            (i1 - i0, index_hnsw.d), dtype='float32')
        rfn.reconstruct_n(i0, i1 - i0, faiss.swig_ptr(xb_recons))
        accu_recons_error += (
            (vec_transform(sanitize(xb[i0:i1])) -
             xb_recons)**2).sum()
        index2.reset()
        index2.add(xb_recons)
        D, I = index2.search(xq_tr, 100)
        rh.add_batch_result(D, I, i0)

    rh.finalize()
    del index2
    t1 = time.time()
    print("done in %.3f s" % (t1 - t0))
    print("total reconstruction error: ", accu_recons_error)
    print("eval retrieval:")
    datasets.evaluate_DI(rh.D, rh.I, gt)


def get_neighbors(hnsw, i, level):
    " list the neighbors for node i at level "
    assert i < hnsw.levels.size()
    assert level < hnsw.levels.at(i)
    be = np.empty(2, 'uint64')
    hnsw.neighbor_range(i, level, faiss.swig_ptr(be), faiss.swig_ptr(be[1:]))
    return [hnsw.neighbors.at(j) for j in range(be[0], be[1])]
#############################################################
# Index is ready
#############################################################
xq = sanitize(xq)
if args.searchthreads != -1:
    print("Setting nb of threads to", args.searchthreads)
    faiss.omp_set_num_threads(args.searchthreads)

if gt is None:
    print("no valid groundtruth -- exit")
    sys.exit()

k_reorders = [int(x) for x in args.k_reorder.split(',')]
efSearchs = [int(x) for x in args.efSearch.split(',')]

for k_reorder in k_reorders:

    if index_hnsw.reconstruct_from_neighbors:
        print("setting k_reorder=%d" % k_reorder)
        index_hnsw.reconstruct_from_neighbors.k_reorder = k_reorder

    for efSearch in efSearchs:
        print("efSearch=%-4d" % efSearch, end=' ')
        hnsw.efSearch = efSearch
        hnsw_stats.reset()
        datasets.evaluate(xq, gt, index, k=args.k, endl=False)
        print("ndis %d nreorder %d" % (hnsw_stats.ndis, hnsw_stats.nreorder))

View File

@ -1,236 +0,0 @@
#! /usr/bin/env python2
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Common functions to load datasets and compute their ground-truth
"""
from __future__ import print_function
import time
import numpy as np
import faiss
import pdb
import sys
# set this to the directory that contains the datafiles.
# deep1b data should be at simdir + 'deep1b'
# bigann data should be at simdir + 'bigann'
simdir = '/mnt/vol/gfsai-east/ai-group/datasets/simsearch/'
#################################################################
# Small I/O functions
#################################################################
def ivecs_read(fname):
    a = np.fromfile(fname, dtype='int32')
    d = a[0]
    return a.reshape(-1, d + 1)[:, 1:].copy()

def fvecs_read(fname):
    return ivecs_read(fname).view('float32')

def ivecs_mmap(fname):
    a = np.memmap(fname, dtype='int32', mode='r')
    d = a[0]
    return a.reshape(-1, d + 1)[:, 1:]

def fvecs_mmap(fname):
    return ivecs_mmap(fname).view('float32')

def bvecs_mmap(fname):
    x = np.memmap(fname, dtype='uint8', mode='r')
    d = x[:4].view('int32')[0]
    return x.reshape(-1, d + 4)[:, 4:]

def ivecs_write(fname, m):
    n, d = m.shape
    m1 = np.empty((n, d + 1), dtype='int32')
    m1[:, 0] = d
    m1[:, 1:] = m
    m1.tofile(fname)

def fvecs_write(fname, m):
    m = m.astype('float32')
    ivecs_write(fname, m.view('int32'))
#################################################################
# Dataset
#################################################################
def sanitize(x):
    return np.ascontiguousarray(x, dtype='float32')

class ResultHeap:
    """ Combine query results from a sliced dataset """

    def __init__(self, nq, k):
        " nq: number of query vectors, k: number of results per query "
        self.I = np.zeros((nq, k), dtype='int64')
        self.D = np.zeros((nq, k), dtype='float32')
        self.nq, self.k = nq, k
        heaps = faiss.float_maxheap_array_t()
        heaps.k = k
        heaps.nh = nq
        heaps.val = faiss.swig_ptr(self.D)
        heaps.ids = faiss.swig_ptr(self.I)
        heaps.heapify()
        self.heaps = heaps

    def add_batch_result(self, D, I, i0):
        assert D.shape == (self.nq, self.k)
        assert I.shape == (self.nq, self.k)
        I += i0
        self.heaps.addn_with_ids(
            self.k, faiss.swig_ptr(D),
            faiss.swig_ptr(I), self.k)

    def finalize(self):
        self.heaps.reorder()
def compute_GT_sliced(xb, xq, k):
    print("compute GT")
    t0 = time.time()
    nb, d = xb.shape
    nq, d = xq.shape
    rh = ResultHeap(nq, k)
    bs = 10 ** 5
    xqs = sanitize(xq)
    db_gt = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(d))
    # compute ground-truth by blocks of bs, and add to heaps
    for i0 in range(0, nb, bs):
        i1 = min(nb, i0 + bs)
        xsl = sanitize(xb[i0:i1])
        db_gt.add(xsl)
        D, I = db_gt.search(xqs, k)
        rh.add_batch_result(D, I, i0)
        db_gt.reset()
        print("\r %d/%d, %.3f s" % (i0, nb, time.time() - t0), end=' ')
        sys.stdout.flush()
    print()
    rh.finalize()
    gt_I = rh.I
    print("GT time: %.3f s" % (time.time() - t0))
    return gt_I

def do_compute_gt(xb, xq, k):
    print("computing GT")
    nb, d = xb.shape
    index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(d))
    if nb < 100 * 1000:
        print(" add")
        index.add(np.ascontiguousarray(xb, dtype='float32'))
        print(" search")
        D, I = index.search(np.ascontiguousarray(xq, dtype='float32'), k)
    else:
        I = compute_GT_sliced(xb, xq, k)
    return I.astype('int32')
def load_data(dataset='deep1M', compute_gt=False):
    print("load data", dataset)

    if dataset == 'sift1M':
        basedir = simdir + 'sift1M/'
        xt = fvecs_read(basedir + "sift_learn.fvecs")
        xb = fvecs_read(basedir + "sift_base.fvecs")
        xq = fvecs_read(basedir + "sift_query.fvecs")
        gt = ivecs_read(basedir + "sift_groundtruth.ivecs")

    elif dataset.startswith('bigann'):
        basedir = simdir + 'bigann/'
        dbsize = 1000 if dataset == "bigann1B" else int(dataset[6:-1])
        xb = bvecs_mmap(basedir + 'bigann_base.bvecs')
        xq = bvecs_mmap(basedir + 'bigann_query.bvecs')
        xt = bvecs_mmap(basedir + 'bigann_learn.bvecs')
        # trim xb to correct size
        xb = xb[:dbsize * 1000 * 1000]
        gt = ivecs_read(basedir + 'gnd/idx_%dM.ivecs' % dbsize)

    elif dataset.startswith("deep"):
        basedir = simdir + 'deep1b/'
        szsuf = dataset[4:]
        if szsuf[-1] == 'M':
            dbsize = 10 ** 6 * int(szsuf[:-1])
        elif szsuf == '1B':
            dbsize = 10 ** 9
        elif szsuf[-1] == 'k':
            dbsize = 1000 * int(szsuf[:-1])
        else:
            assert False, "did not recognize suffix " + szsuf
        xt = fvecs_mmap(basedir + "learn.fvecs")
        xb = fvecs_mmap(basedir + "base.fvecs")
        xq = fvecs_read(basedir + "deep1B_queries.fvecs")
        xb = xb[:dbsize]
        gt_fname = basedir + "%s_groundtruth.ivecs" % dataset
        if compute_gt:
            gt = do_compute_gt(xb, xq, 100)
            print("store", gt_fname)
            ivecs_write(gt_fname, gt)
        gt = ivecs_read(gt_fname)

    else:
        assert False

    print("dataset %s sizes: B %s Q %s T %s" % (
        dataset, xb.shape, xq.shape, xt.shape))
    return xt, xb, xq, gt
#################################################################
# Evaluation
#################################################################
def evaluate_DI(D, I, gt):
    nq = gt.shape[0]
    k = I.shape[1]
    rank = 1
    while rank <= k:
        recall = (I[:, :rank] == gt[:, :1]).sum() / float(nq)
        print("R@%d: %.4f" % (rank, recall), end=' ')
        rank *= 10

def evaluate(xq, gt, index, k=100, endl=True):
    t0 = time.time()
    D, I = index.search(xq, k)
    t1 = time.time()
    nq = xq.shape[0]
    print("\t %8.4f ms per query, " % (
        (t1 - t0) * 1000.0 / nq), end=' ')
    rank = 1
    while rank <= k:
        recall = (I[:, :rank] == gt[:, :1]).sum() / float(nq)
        print("R@%d: %.4f" % (rank, recall), end=' ')
        rank *= 10
    if endl:
        print()
    return D, I

View File

@ -1,241 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
This is the training code for the link and code. Especially the
neighbors_kmeans function implements the EM-algorithm to find the
appropriate weightings and cluster them.
"""
from __future__ import print_function
import time
import numpy as np
import faiss
#----------------------------------------------------------
# Utils
#----------------------------------------------------------
def sanitize(x):
    return np.ascontiguousarray(x, dtype='float32')

def train_kmeans(x, k, ngpu, max_points_per_centroid=256):
    "Runs kmeans on one or several GPUs"
    d = x.shape[1]
    clus = faiss.Clustering(d, k)
    clus.verbose = True
    clus.niter = 20
    clus.max_points_per_centroid = max_points_per_centroid

    if ngpu == 0:
        index = faiss.IndexFlatL2(d)
    else:
        res = [faiss.StandardGpuResources() for i in range(ngpu)]
        flat_config = []
        for i in range(ngpu):
            cfg = faiss.GpuIndexFlatConfig()
            cfg.useFloat16 = False
            cfg.device = i
            flat_config.append(cfg)

        if ngpu == 1:
            index = faiss.GpuIndexFlatL2(res[0], d, flat_config[0])
        else:
            indexes = [faiss.GpuIndexFlatL2(res[i], d, flat_config[i])
                       for i in range(ngpu)]
            index = faiss.IndexReplicas()
            for sub_index in indexes:
                index.addIndex(sub_index)

    # perform the training
    clus.train(x, index)
    centroids = faiss.vector_float_to_array(clus.centroids)

    stats = clus.iteration_stats
    stats = [stats.at(i) for i in range(stats.size())]
    obj = np.array([st.obj for st in stats])
    print("final objective: %.4g" % obj[-1])

    return centroids.reshape(k, d)
#----------------------------------------------------------
# Learning the codebook from neighbors
#----------------------------------------------------------
# works with both a full Inn table and dynamically generated neighbors
def get_Inn_shape(Inn):
    if type(Inn) != tuple:
        return Inn.shape
    return Inn[:2]

def get_neighbor_table(x_coded, Inn, i):
    if type(Inn) != tuple:
        return x_coded[Inn[i,:],:]
    rfn = x_coded
    M, d = rfn.M, rfn.index.d
    out = np.zeros((M + 1, d), dtype='float32')
    int_i = int(i)
    rfn.get_neighbor_table(int_i, faiss.swig_ptr(out))
    _, _, sq = Inn
    return out[:, sq * rfn.dsub : (sq + 1) * rfn.dsub]
# Function that produces the best regression values from the vector
# and its neighbors
def regress_from_neighbors (x, x_coded, Inn):
    (N, knn) = get_Inn_shape(Inn)
    betas = np.zeros((N,knn))
    t0 = time.time()
    for i in range (N):
        xi = x[i,:]
        NNi = get_neighbor_table(x_coded, Inn, i)
        betas[i,:] = np.linalg.lstsq(NNi.transpose(), xi, rcond=0.01)[0]
        if i % (N / 10) == 0:
            print ("[%d:%d] %6.3fs" % (i, i + N / 10, time.time() - t0))
    return betas
# find the best beta minimizing ||x-x_coded[Inn,:]*beta||^2
def regress_opt_beta (x, x_coded, Inn):
    (N, knn) = get_Inn_shape(Inn)
    d = x.shape[1]

    # construct the linear system to be solved
    X = np.zeros ((d*N))
    Y = np.zeros ((d*N, knn))
    for i in range (N):
        X[i*d:(i+1)*d] = x[i,:]
        neighbor_table = get_neighbor_table(x_coded, Inn, i)
        Y[i*d:(i+1)*d, :] = neighbor_table.transpose()
    beta_opt = np.linalg.lstsq(Y, X, rcond=0.01)[0]
    return beta_opt
# Find the best encoding by minimizing the reconstruction error using
# a set of pre-computed beta values
def assign_beta (beta_centroids, x, x_coded, Inn, verbose=True):
    if type(Inn) == tuple:
        return assign_beta_2(beta_centroids, x, x_coded, Inn)
    (N, knn) = Inn.shape
    x_ibeta = np.zeros ((N), dtype='int32')
    t0 = time.time()
    for i in range (N):
        NNi = x_coded[Inn[i,:]]
        # Consider all possible betas for the encoding and compute the
        # encoding error
        x_reg_all = np.dot (beta_centroids, NNi)
        err = ((x_reg_all - x[i,:]) ** 2).sum(axis=1)
        x_ibeta[i] = err.argmin()
        if verbose:
            if i % (N / 10) == 0:
                print ("[%d:%d] %6.3fs" % (i, i + N / 10, time.time() - t0))
    return x_ibeta
# Reconstruct a set of vectors using the beta_centroids, the
# assignment, the encoded neighbors identified by the list Inn (which
# includes the vector itself)
def recons_from_neighbors (beta_centroids, x_ibeta, x_coded, Inn):
    (N, knn) = Inn.shape
    x_rec = np.zeros(x_coded.shape)
    t0 = time.time()
    for i in range (N):
        NNi = x_coded[Inn[i,:]]
        x_rec[i, :] = np.dot (beta_centroids[x_ibeta[i]], NNi)
        if i % (N / 10) == 0:
            print ("[%d:%d] %6.3fs" % (i, i + N / 10, time.time() - t0))
    return x_rec
# Compute a EM-like algorithm trying at optimizing the beta such as they
# minimize the reconstruction error from the neighbors
def neighbors_kmeans (x, x_coded, Inn, K, ngpus=1, niter=5):
    # First compute centroids using a regular k-means algorithm
    betas = regress_from_neighbors (x, x_coded, Inn)
    beta_centroids = train_kmeans(
        sanitize(betas), K, ngpus, max_points_per_centroid=1000000)
    _, knn = get_Inn_shape(Inn)
    d = x.shape[1]

    rs = np.random.RandomState()
    for iter in range(niter):
        print('iter', iter)
        idx = assign_beta (beta_centroids, x, x_coded, Inn, verbose=False)

        hist = np.bincount(idx)
        for cl0 in np.where(hist == 0)[0]:
            print(" cluster %d empty, split" % cl0, end=' ')
            cl1 = idx[np.random.randint(idx.size)]
            pos = np.nonzero (idx == cl1)[0]
            pos = rs.choice(pos, pos.size / 2)
            print(" cl %d -> %d + %d" % (cl1, len(pos), hist[cl1] - len(pos)))
            idx[pos] = cl0
            hist = np.bincount(idx)

        tot_err = 0
        for k in range (K):
            pos = np.nonzero (idx == k)[0]
            npos = pos.shape[0]

            X = np.zeros (d*npos)
            Y = np.zeros ((d*npos, knn))

            for i in range(npos):
                X[i*d:(i+1)*d] = x[pos[i],:]
                neighbor_table = get_neighbor_table(x_coded, Inn, pos[i])
                Y[i*d:(i+1)*d, :] = neighbor_table.transpose()
            sol, residuals, _, _ = np.linalg.lstsq(Y, X, rcond=0.01)
            if residuals.size > 0:
                tot_err += residuals.sum()
            beta_centroids[k, :] = sol
        print(' err=%g' % tot_err)
    return beta_centroids
# assign the betas in C++
def assign_beta_2(beta_centroids, x, rfn, Inn):
    _, _, sq = Inn
    if rfn.k == 1:
        return np.zeros(x.shape[0], dtype=int)
    # add dummy dimensions to beta_centroids and x
    all_beta_centroids = np.zeros(
        (rfn.nsq, rfn.k, rfn.M + 1), dtype='float32')
    all_beta_centroids[sq] = beta_centroids
    all_x = np.zeros((len(x), rfn.d), dtype='float32')
    all_x[:, sq * rfn.dsub : (sq + 1) * rfn.dsub] = x
    rfn.codes.clear()
    rfn.ntotal = 0
    faiss.copy_array_to_vector(
        all_beta_centroids.ravel(), rfn.codebook)
    rfn.add_codes(len(x), faiss.swig_ptr(all_x))
    codes = faiss.vector_to_array(rfn.codes)
    codes = codes.reshape(-1, rfn.nsq)
    return codes[:, sq]
#######################################################
# For usage from bench_storages.py
def train_beta_codebook(rfn, xb_full, niter=10):
    beta_centroids = []
    for sq in range(rfn.nsq):
        d0, d1 = sq * rfn.dsub, (sq + 1) * rfn.dsub
        print("training subquantizer %d/%d on dimensions %d:%d" % (
            sq, rfn.nsq, d0, d1))
        beta_centroids_i = neighbors_kmeans(
            xb_full[:, d0:d1], rfn, (xb_full.shape[0], rfn.M + 1, sq),
            rfn.k,
            ngpus=0, niter=niter)
        beta_centroids.append(beta_centroids_i)
        rfn.ntotal = 0
        rfn.codes.clear()
        rfn.codebook.clear()
    return np.stack(beta_centroids)

View File

@ -307,7 +307,7 @@ void hnsw_search(
FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
efSearch = params->efSearch;
}
size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
size_t n1 = 0, n2 = 0, ndis = 0;
idx_t check_period = InterruptCallback::get_period_hint(
hnsw.max_level * index->d * efSearch);
@ -323,7 +323,7 @@ void hnsw_search(
std::unique_ptr<DistanceComputer> dis(
storage_distance_computer(index->storage));
#pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder) schedule(guided)
#pragma omp for reduction(+ : n1, n2, ndis) schedule(guided)
for (idx_t i = i0; i < i1; i++) {
res.begin(i);
dis->set_query(x + i * index->d);
@ -331,16 +331,14 @@ void hnsw_search(
HNSWStats stats = hnsw.search(*dis, res, vt, params);
n1 += stats.n1;
n2 += stats.n2;
n3 += stats.n3;
ndis += stats.ndis;
nreorder += stats.nreorder;
res.end();
}
}
InterruptCallback::check();
}
hnsw_stats.combine({n1, n2, n3, ndis, nreorder});
hnsw_stats.combine({n1, n2, ndis});
}
} // anonymous namespace
@ -800,7 +798,7 @@ void IndexHNSW2Level::search(
IndexHNSW::search(n, x, k, distances, labels);
} else { // "mixed" search
size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
size_t n1 = 0, n2 = 0, ndis = 0;
const IndexIVFPQ* index_ivfpq =
dynamic_cast<const IndexIVFPQ*>(storage);
@ -832,7 +830,7 @@ void IndexHNSW2Level::search(
int candidates_size = hnsw.upper_beam;
MinimaxHeap candidates(candidates_size);
#pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder)
#pragma omp for reduction(+ : n1, n2, ndis)
for (idx_t i = 0; i < n; i++) {
idx_t* idxi = labels + i * k;
float* simi = distances + i * k;
@ -877,9 +875,7 @@ void IndexHNSW2Level::search(
k);
n1 += search_stats.n1;
n2 += search_stats.n2;
n3 += search_stats.n3;
ndis += search_stats.ndis;
nreorder += search_stats.nreorder;
vt.advance();
vt.advance();
@ -888,7 +884,7 @@ void IndexHNSW2Level::search(
}
}
hnsw_stats.combine({n1, n2, n3, ndis, nreorder});
hnsw_stats.combine({n1, n2, ndis});
}
}

View File

@ -84,19 +84,14 @@ class GpuIndex : public faiss::Index {
/// `x` and `labels` can be resident on the CPU or any GPU; copies are
/// performed as needed
void assign(
idx_t n,
const float* x,
idx_t* labels,
// faiss::Index has idx_t for k
idx_t k = 1) const override;
void assign(idx_t n, const float* x, idx_t* labels, idx_t k = 1)
const override;
/// `x`, `distances` and `labels` can be resident on the CPU or any
/// GPU; copies are performed as needed
void search(
idx_t n,
const float* x,
// faiss::Index has idx_t for k
idx_t k,
float* distances,
idx_t* labels,
@ -107,7 +102,6 @@ class GpuIndex : public faiss::Index {
void search_and_reconstruct(
idx_t n,
const float* x,
// faiss::Index has idx_t for k
idx_t k,
float* distances,
idx_t* labels,

View File

@ -664,7 +664,7 @@ int search_from_candidates(
if (candidates.size() == 0) {
stats.n2++;
}
stats.n3 += ndis;
stats.ndis += ndis;
}
return nres;
@ -793,7 +793,7 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
if (candidates.size() == 0) {
++stats.n2;
}
stats.n3 += ndis;
stats.ndis += ndis;
return top_candidates;
}

View File

@ -230,30 +230,20 @@ struct HNSW {
};
struct HNSWStats {
size_t n1, n2, n3;
size_t ndis;
size_t nreorder;
HNSWStats(
size_t n1 = 0,
size_t n2 = 0,
size_t n3 = 0,
size_t ndis = 0,
size_t nreorder = 0)
: n1(n1), n2(n2), n3(n3), ndis(ndis), nreorder(nreorder) {}
size_t n1 = 0;   /// number of vectors searched
size_t n2 = 0;   /// number of queries for which the candidate list is exhausted
size_t ndis = 0; /// number of distances computed
void reset() {
n1 = n2 = n3 = 0;
n1 = n2 = 0;
ndis = 0;
nreorder = 0;
}
void combine(const HNSWStats& other) {
n1 += other.n1;
n2 += other.n2;
n3 += other.n3;
ndis += other.ndis;
nreorder += other.nreorder;
}
};

View File

@ -123,6 +123,16 @@ class TestHNSW(unittest.TestCase):
        mask = Iref[:, 0] == Ihnsw[:, 0]
        assert np.allclose(Dref[mask, 0], Dhnsw[mask, 0])

    def test_ndis_stats(self):
        d = self.xq.shape[1]
        index = faiss.IndexHNSWFlat(d, 16)
        index.add(self.xb)
        stats = faiss.cvar.hnsw_stats
        stats.reset()
        Dhnsw, Ihnsw = index.search(self.xq, 1)
        self.assertGreater(stats.ndis, len(self.xq) * index.hnsw.efSearch)

class TestNSG(unittest.TestCase):