Fix HNSW stats (#3309)
Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3309

Make sure that the HNSW search stats work, and remove stats for deprecated functionality. Remove the code of the link and code paper, which is not supported anymore.

Reviewed By: kuarora, junjieqi

Differential Revision: D55247802

fbshipit-source-id: 03f176be092bff6b2db359cc956905d8646ea702
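For context, the repaired search stats can be read from Python through the global `faiss.cvar.hnsw_stats` object, roughly as in the test added below (a minimal sketch; the dimensions, data and HNSW parameters are illustrative):

```python
import numpy as np
import faiss

d = 32
xb = np.random.rand(10000, d).astype('float32')  # illustrative database vectors
xq = np.random.rand(100, d).astype('float32')    # illustrative query vectors

index = faiss.IndexHNSWFlat(d, 16)  # HNSW index with 16 links per node
index.add(xb)

stats = faiss.cvar.hnsw_stats  # global counters shared by all HNSW searches
stats.reset()
D, I = index.search(xq, 1)
print("distances computed:", stats.ndis)  # populated again after this fix
```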
parent b77061ff5e
commit fa1f39ec9f
@@ -21,138 +21,5 @@ graph to improve the reconstruction. It is described in
arXiv [here](https://arxiv.org/abs/1804.09996)

Code structure
--------------

The test runs with 3 files:

- `bench_link_and_code.py`: driver script

- `datasets.py`: code to load the datasets. The example code runs on the
  deep1b and bigann datasets. See the [toplevel README](../README.md)
  on how to download them. They should be put in a directory; edit
  `datasets.py` to set the path.

- `neighbor_codec.py`: this is where the representation is trained.

The code runs on top of Faiss. The HNSW index can be extended with a
`ReconstructFromNeighbors` C++ object that refines the distances. The
training is implemented in Python.

Update 2023-12-28: current Faiss versions have dropped support for reconstruction with
this method.
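For reference, on Faiss <= 1.7.4 the extension was wired into the index roughly as follows (a sketch condensed from `bench_link_and_code.py` below; the dimension and the `beta_k`/`beta_nsq` values are illustrative):

```python
import faiss

# only works on Faiss <= 1.7.4, where ReconstructFromNeighbors still exists
d = 96  # illustrative vector dimension
index = faiss.index_factory(d, "OPQ36_144,HNSW32_PQ36")
index_hnsw = faiss.downcast_index(index.index)  # unwrap the IndexPreTransform

# beta codebook with 256 entries and 4 sub-vectors (beta_k=256, beta_nsq=4)
rfn = faiss.ReconstructFromNeighbors(index_hnsw, 256, 4)
index_hnsw.reconstruct_from_neighbors = rfn
# rfn.codebook and rfn.codes are then filled by the training code in
# neighbor_codec.py, and rfn.k_reorder controls re-ranking at search time.
```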
Reproducing Table 2 in the paper
--------------------------------

The results of Table 2 (accuracy on deep100M) in the paper can be
obtained with:

```bash
python bench_link_and_code.py \
    --db deep100M \
    --M0 6 \
    --indexkey OPQ36_144,HNSW32_PQ36 \
    --indexfile $bdir/deep100M_PQ36_L6.index \
    --beta_nsq 4 \
    --beta_centroids $bdir/deep100M_PQ36_L6_nsq4.npy \
    --neigh_recons_codes $bdir/deep100M_PQ36_L6_nsq4_codes.npy \
    --k_reorder 0,5 --efSearch 1,1024
```

Set `bdir` to a scratch directory.

Explanation of the flags:

- `--db deep1M`: dataset to process

- `--M0 6`: number of links on the base level (L6)

- `--indexkey OPQ36_144,HNSW32_PQ36`: Faiss index key to construct the
  HNSW structure. It means that vectors are transformed by OPQ and
  encoded with PQ 36x8 (with an intermediate size of 144D). The HNSW
  level>0 nodes have 32 links (these are "cheap" to store
  because there are fewer nodes in the upper levels).

- `--indexfile $bdir/deep1M_PQ36_M6.index`: name of the index file
  (without information for the L&C extension)

- `--beta_nsq 4`: number of bytes to allocate for the codes (M in the
  paper)

- `--beta_centroids $bdir/deep1M_PQ36_M6_nsq4.npy`: filename to store
  the trained beta centroids

- `--neigh_recons_codes $bdir/deep1M_PQ36_M6_nsq4_codes.npy`: filename
  for the encoded weights (beta) of the combination

- `--k_reorder 0,5`: number of results to reorder. 0 = baseline
  without reordering, 5 = value used throughout the paper

- `--efSearch 1,1024`: number of nodes to visit (T in the paper)

The script will proceed with the following steps:

0. load the dataset (and possibly compute the ground-truth if the
   ground-truth file is not provided)

1. train the OPQ encoder

2. build the index and store it

3. compute the residuals and train the beta vocabulary to do the reconstruction

4. encode the vertices

5. search and evaluate the search results (the search-time parameters
   are applied as in the sketch after this list).
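In the last step, the search-time parameters map directly onto the index objects; the end of `bench_link_and_code.py` does roughly the following (a sketch for Faiss <= 1.7.4; `index`, `index_hnsw`, `xq` and `gt` are assumed to come from the earlier steps):

```python
import faiss

hnsw = index_hnsw.hnsw
hnsw_stats = faiss.cvar.hnsw_stats

for k_reorder in [0, 5]:                 # --k_reorder 0,5
    index_hnsw.reconstruct_from_neighbors.k_reorder = k_reorder
    for efSearch in [1, 1024]:           # --efSearch 1,1024
        hnsw.efSearch = efSearch
        hnsw_stats.reset()
        D, I = index.search(xq, 100)     # recall is then measured against gt
        print("ndis %d nreorder %d" % (hnsw_stats.ndis, hnsw_stats.nreorder))
```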
With option `--exhaustive` the results of the exhaustive column can be
obtained.

The run above should output:

```bash
...
setting k_reorder=5
...
efSearch=1024 0.3132 ms per query, R@1: 0.4283 R@10: 0.6337 R@100: 0.6520 ndis 40941919 nreorder 50000
```

which matches the paper's Table 2.

Note that in multi-threaded mode, the building of the HNSW structure
is not deterministic. Therefore, the results across runs may not be exactly the same.

Reproducing Figure 5 in the paper
---------------------------------

Figure 5 just evaluates the combination of HNSW and PQ. For example,
the operating point L6&OPQ40 can be obtained with

```bash
python bench_link_and_code.py \
    --db deep1M \
    --M0 6 \
    --indexkey OPQ40_160,HNSW32_PQ40 \
    --indexfile $bdir/deep1M_PQ40_M6.index \
    --beta_nsq 1 --beta_k 1 \
    --beta_centroids $bdir/deep1M_PQ40_M6_nsq0.npy \
    --neigh_recons_codes $bdir/deep1M_PQ36_M6_nsq0_codes.npy \
    --k_reorder 0 --efSearch 16,64,256,1024
```

The arguments are similar to the previous table. Note that nsq = 0 is
simulated by setting beta_nsq = 1 and beta_k = 1 (i.e. a code with a single
reproduction value).

The output should look like:

```bash
setting k_reorder=0
efSearch=16 0.0147 ms per query, R@1: 0.3409 R@10: 0.4388 R@100: 0.4394 ndis 2629735 nreorder 0
efSearch=64 0.0122 ms per query, R@1: 0.4836 R@10: 0.6490 R@100: 0.6509 ndis 4623221 nreorder 0
efSearch=256 0.0344 ms per query, R@1: 0.5730 R@10: 0.7915 R@100: 0.7951 ndis 11090176 nreorder 0
efSearch=1024 0.2656 ms per query, R@1: 0.6212 R@10: 0.8722 R@100: 0.8765 ndis 33501951 nreorder 0
```

The results with k_reorder=5 are not reported in the paper; they
represent the performance of a "free coding" version of the algorithm.

The necessary code for this paper was removed from Faiss in version 1.8.0.
For a functioning version, use Faiss 1.7.4.
@@ -1,300 +0,0 @@
```python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import print_function
import os
import sys
import time
import numpy as np
import faiss
import argparse
import datasets
from datasets import sanitize
import neighbor_codec

######################################################
# Command-line parsing
######################################################


parser = argparse.ArgumentParser()


def aa(*args, **kwargs):
    group.add_argument(*args, **kwargs)


group = parser.add_argument_group('dataset options')

aa('--db', default='deep1M', help='dataset')
aa('--compute_gt', default=False, action='store_true',
   help='compute and store the groundtruth')

group = parser.add_argument_group('index consturction')

aa('--indexkey', default='HNSW32', help='index_factory type')
aa('--efConstruction', default=200, type=int,
   help='HNSW construction factor')
aa('--M0', default=-1, type=int, help='size of base level')
aa('--maxtrain', default=256 * 256, type=int,
   help='maximum number of training points')
aa('--indexfile', default='', help='file to read or write index from')
aa('--add_bs', default=-1, type=int,
   help='add elements index by batches of this size')
aa('--link_singletons', default=False, action='store_true',
   help='do a pass to link in the singletons')

group = parser.add_argument_group(
    'searching (reconstruct_from_neighbors options)')

aa('--beta_centroids', default='',
   help='file with codebook')
aa('--neigh_recons_codes', default='',
   help='file with codes for reconstruction')
aa('--beta_ntrain', default=250000, type=int, help='')
aa('--beta_k', default=256, type=int, help='beta codebook size')
aa('--beta_nsq', default=1, type=int, help='number of beta sub-vectors')
aa('--beta_niter', default=10, type=int, help='')
aa('--k_reorder', default='-1', help='')

group = parser.add_argument_group('searching')

aa('--k', default=100, type=int, help='nb of nearest neighbors')
aa('--exhaustive', default=False, action='store_true',
   help='report the exhaustive search topline')
aa('--searchthreads', default=-1, type=int,
   help='nb of threads to use at search time')
aa('--efSearch', default='', type=str,
   help='comma-separated values of efSearch to try')

args = parser.parse_args()

print("args:", args)


######################################################
# Load dataset
######################################################

xt, xb, xq, gt = datasets.load_data(
    dataset=args.db, compute_gt=args.compute_gt)

nq, d = xq.shape
nb, d = xb.shape


######################################################
# Make index
######################################################

if os.path.exists(args.indexfile):

    print("reading", args.indexfile)
    index = faiss.read_index(args.indexfile)

    if isinstance(index, faiss.IndexPreTransform):
        index_hnsw = faiss.downcast_index(index.index)
        vec_transform = index.chain.at(0).apply_py
    else:
        index_hnsw = index
        vec_transform = lambda x: x

    hnsw = index_hnsw.hnsw
    hnsw_stats = faiss.cvar.hnsw_stats

else:

    print("build index, key=", args.indexkey)

    index = faiss.index_factory(d, args.indexkey)

    if isinstance(index, faiss.IndexPreTransform):
        index_hnsw = faiss.downcast_index(index.index)
        vec_transform = index.chain.at(0).apply_py
    else:
        index_hnsw = index
        vec_transform = lambda x: x

    hnsw = index_hnsw.hnsw
    hnsw.efConstruction = args.efConstruction
    hnsw_stats = faiss.cvar.hnsw_stats
    index.verbose = True
    index_hnsw.verbose = True
    index_hnsw.storage.verbose = True

    if args.M0 != -1:
        print("set level 0 nb of neighbors to", args.M0)
        hnsw.set_nb_neighbors(0, args.M0)

    xt2 = sanitize(xt[:args.maxtrain])
    assert np.all(np.isfinite(xt2))

    print("train, size", xt.shape)
    t0 = time.time()
    index.train(xt2)
    print("  train in %.3f s" % (time.time() - t0))

    print("adding")
    t0 = time.time()
    if args.add_bs == -1:
        index.add(sanitize(xb))
    else:
        for i0 in range(0, nb, args.add_bs):
            i1 = min(nb, i0 + args.add_bs)
            print("  adding %d:%d / %d" % (i0, i1, nb))
            index.add(sanitize(xb[i0:i1]))

    print("  add in %.3f s" % (time.time() - t0))
    print("storing", args.indexfile)
    faiss.write_index(index, args.indexfile)


######################################################
# Train beta centroids and encode dataset
######################################################

if args.beta_centroids:
    print("reordering links")
    index_hnsw.reorder_links()

    if os.path.exists(args.beta_centroids):
        print("load", args.beta_centroids)
        beta_centroids = np.load(args.beta_centroids)
        nsq, k, M1 = beta_centroids.shape
        assert M1 == hnsw.nb_neighbors(0) + 1

        rfn = faiss.ReconstructFromNeighbors(index_hnsw, k, nsq)
    else:
        print("train beta centroids")
        rfn = faiss.ReconstructFromNeighbors(
            index_hnsw, args.beta_k, args.beta_nsq)

        xb_full = vec_transform(sanitize(xb[:args.beta_ntrain]))

        beta_centroids = neighbor_codec.train_beta_codebook(
            rfn, xb_full, niter=args.beta_niter)

        print("  storing", args.beta_centroids)
        np.save(args.beta_centroids, beta_centroids)

    faiss.copy_array_to_vector(beta_centroids.ravel(),
                               rfn.codebook)
    index_hnsw.reconstruct_from_neighbors = rfn

    if rfn.k == 1:
        pass     # no codes to take care of
    elif os.path.exists(args.neigh_recons_codes):
        print("loading neigh codes", args.neigh_recons_codes)
        codes = np.load(args.neigh_recons_codes)
        assert codes.size == rfn.code_size * index.ntotal
        faiss.copy_array_to_vector(codes.astype('uint8'),
                                   rfn.codes)
        rfn.ntotal = index.ntotal
    else:
        print("encoding neigh codes")
        t0 = time.time()

        bs = 1000000 if args.add_bs == -1 else args.add_bs

        for i0 in range(0, nb, bs):
            i1 = min(i0 + bs, nb)
            print("   encode %d:%d / %d [%.3f s]\r" % (
                i0, i1, nb, time.time() - t0), end=' ')
            sys.stdout.flush()
            xbatch = vec_transform(sanitize(xb[i0:i1]))
            rfn.add_codes(i1 - i0, faiss.swig_ptr(xbatch))
        print()

        print("storing %s" % args.neigh_recons_codes)
        codes = faiss.vector_to_array(rfn.codes)
        np.save(args.neigh_recons_codes, codes)

######################################################
# Exhaustive evaluation
######################################################

if args.exhaustive:
    print("exhaustive evaluation")
    xq_tr = vec_transform(sanitize(xq))
    index2 = faiss.IndexFlatL2(index_hnsw.d)
    accu_recons_error = 0.0

    if faiss.get_num_gpus() > 0:
        print("do eval on GPU")
        co = faiss.GpuMultipleClonerOptions()
        co.shard = False
        index2 = faiss.index_cpu_to_all_gpus(index2, co)

    # process in batches in case the dataset does not fit in RAM
    rh = datasets.ResultHeap(xq_tr.shape[0], 100)
    t0 = time.time()
    bs = 500000
    for i0 in range(0, nb, bs):
        i1 = min(nb, i0 + bs)
        print('   handling batch %d:%d' % (i0, i1))

        xb_recons = np.empty(
            (i1 - i0, index_hnsw.d), dtype='float32')
        rfn.reconstruct_n(i0, i1 - i0, faiss.swig_ptr(xb_recons))

        accu_recons_error += (
            (vec_transform(sanitize(xb[i0:i1])) -
             xb_recons)**2).sum()

        index2.reset()
        index2.add(xb_recons)
        D, I = index2.search(xq_tr, 100)
        rh.add_batch_result(D, I, i0)

    rh.finalize()
    del index2
    t1 = time.time()
    print("done in %.3f s" % (t1 - t0))
    print("total reconstruction error: ", accu_recons_error)
    print("eval retrieval:")
    datasets.evaluate_DI(rh.D, rh.I, gt)


def get_neighbors(hnsw, i, level):
    " list the neighbors for node i at level "
    assert i < hnsw.levels.size()
    assert level < hnsw.levels.at(i)
    be = np.empty(2, 'uint64')
    hnsw.neighbor_range(i, level, faiss.swig_ptr(be), faiss.swig_ptr(be[1:]))
    return [hnsw.neighbors.at(j) for j in range(be[0], be[1])]


#############################################################
# Index is ready
#############################################################

xq = sanitize(xq)

if args.searchthreads != -1:
    print("Setting nb of threads to", args.searchthreads)
    faiss.omp_set_num_threads(args.searchthreads)


if gt is None:
    print("no valid groundtruth -- exit")
    sys.exit()


k_reorders = [int(x) for x in args.k_reorder.split(',')]
efSearchs = [int(x) for x in args.efSearch.split(',')]


for k_reorder in k_reorders:

    if index_hnsw.reconstruct_from_neighbors:
        print("setting k_reorder=%d" % k_reorder)
        index_hnsw.reconstruct_from_neighbors.k_reorder = k_reorder

    for efSearch in efSearchs:
        print("efSearch=%-4d" % efSearch, end=' ')
        hnsw.efSearch = efSearch
        hnsw_stats.reset()
        datasets.evaluate(xq, gt, index, k=args.k, endl=False)

        print("ndis %d nreorder %d" % (hnsw_stats.ndis, hnsw_stats.nreorder))
```
@@ -1,236 +0,0 @@
```python
#! /usr/bin/env python2

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Common functions to load datasets and compute their ground-truth
"""
from __future__ import print_function

import time
import numpy as np
import faiss
import pdb
import sys

# set this to the directory that contains the datafiles.
# deep1b data should be at simdir + 'deep1b'
# bigann data should be at simdir + 'bigann'
simdir = '/mnt/vol/gfsai-east/ai-group/datasets/simsearch/'

#################################################################
# Small I/O functions
#################################################################


def ivecs_read(fname):
    a = np.fromfile(fname, dtype='int32')
    d = a[0]
    return a.reshape(-1, d + 1)[:, 1:].copy()


def fvecs_read(fname):
    return ivecs_read(fname).view('float32')


def ivecs_mmap(fname):
    a = np.memmap(fname, dtype='int32', mode='r')
    d = a[0]
    return a.reshape(-1, d + 1)[:, 1:]


def fvecs_mmap(fname):
    return ivecs_mmap(fname).view('float32')


def bvecs_mmap(fname):
    x = np.memmap(fname, dtype='uint8', mode='r')
    d = x[:4].view('int32')[0]
    return x.reshape(-1, d + 4)[:, 4:]


def ivecs_write(fname, m):
    n, d = m.shape
    m1 = np.empty((n, d + 1), dtype='int32')
    m1[:, 0] = d
    m1[:, 1:] = m
    m1.tofile(fname)


def fvecs_write(fname, m):
    m = m.astype('float32')
    ivecs_write(fname, m.view('int32'))


#################################################################
# Dataset
#################################################################

def sanitize(x):
    return np.ascontiguousarray(x, dtype='float32')


class ResultHeap:
    """ Combine query results from a sliced dataset """

    def __init__(self, nq, k):
        " nq: number of query vectors, k: number of results per query "
        self.I = np.zeros((nq, k), dtype='int64')
        self.D = np.zeros((nq, k), dtype='float32')
        self.nq, self.k = nq, k
        heaps = faiss.float_maxheap_array_t()
        heaps.k = k
        heaps.nh = nq
        heaps.val = faiss.swig_ptr(self.D)
        heaps.ids = faiss.swig_ptr(self.I)
        heaps.heapify()
        self.heaps = heaps

    def add_batch_result(self, D, I, i0):
        assert D.shape == (self.nq, self.k)
        assert I.shape == (self.nq, self.k)
        I += i0
        self.heaps.addn_with_ids(
            self.k, faiss.swig_ptr(D),
            faiss.swig_ptr(I), self.k)

    def finalize(self):
        self.heaps.reorder()


def compute_GT_sliced(xb, xq, k):
    print("compute GT")
    t0 = time.time()
    nb, d = xb.shape
    nq, d = xq.shape
    rh = ResultHeap(nq, k)
    bs = 10 ** 5

    xqs = sanitize(xq)

    db_gt = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(d))

    # compute ground-truth by blocks of bs, and add to heaps
    for i0 in range(0, nb, bs):
        i1 = min(nb, i0 + bs)
        xsl = sanitize(xb[i0:i1])
        db_gt.add(xsl)
        D, I = db_gt.search(xqs, k)
        rh.add_batch_result(D, I, i0)
        db_gt.reset()
        print("\r   %d/%d, %.3f s" % (i0, nb, time.time() - t0), end=' ')
        sys.stdout.flush()
    print()
    rh.finalize()
    gt_I = rh.I

    print("GT time: %.3f s" % (time.time() - t0))
    return gt_I


def do_compute_gt(xb, xq, k):
    print("computing GT")
    nb, d = xb.shape
    index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(d))
    if nb < 100 * 1000:
        print("   add")
        index.add(np.ascontiguousarray(xb, dtype='float32'))
        print("   search")
        D, I = index.search(np.ascontiguousarray(xq, dtype='float32'), k)
    else:
        I = compute_GT_sliced(xb, xq, k)

    return I.astype('int32')


def load_data(dataset='deep1M', compute_gt=False):

    print("load data", dataset)

    if dataset == 'sift1M':
        basedir = simdir + 'sift1M/'

        xt = fvecs_read(basedir + "sift_learn.fvecs")
        xb = fvecs_read(basedir + "sift_base.fvecs")
        xq = fvecs_read(basedir + "sift_query.fvecs")
        gt = ivecs_read(basedir + "sift_groundtruth.ivecs")

    elif dataset.startswith('bigann'):
        basedir = simdir + 'bigann/'

        dbsize = 1000 if dataset == "bigann1B" else int(dataset[6:-1])
        xb = bvecs_mmap(basedir + 'bigann_base.bvecs')
        xq = bvecs_mmap(basedir + 'bigann_query.bvecs')
        xt = bvecs_mmap(basedir + 'bigann_learn.bvecs')
        # trim xb to correct size
        xb = xb[:dbsize * 1000 * 1000]
        gt = ivecs_read(basedir + 'gnd/idx_%dM.ivecs' % dbsize)

    elif dataset.startswith("deep"):
        basedir = simdir + 'deep1b/'
        szsuf = dataset[4:]
        if szsuf[-1] == 'M':
            dbsize = 10 ** 6 * int(szsuf[:-1])
        elif szsuf == '1B':
            dbsize = 10 ** 9
        elif szsuf[-1] == 'k':
            dbsize = 1000 * int(szsuf[:-1])
        else:
            assert False, "did not recognize suffix " + szsuf

        xt = fvecs_mmap(basedir + "learn.fvecs")
        xb = fvecs_mmap(basedir + "base.fvecs")
        xq = fvecs_read(basedir + "deep1B_queries.fvecs")

        xb = xb[:dbsize]

        gt_fname = basedir + "%s_groundtruth.ivecs" % dataset
        if compute_gt:
            gt = do_compute_gt(xb, xq, 100)
            print("store", gt_fname)
            ivecs_write(gt_fname, gt)

        gt = ivecs_read(gt_fname)

    else:
        assert False

    print("dataset %s sizes: B %s Q %s T %s" % (
        dataset, xb.shape, xq.shape, xt.shape))

    return xt, xb, xq, gt

#################################################################
# Evaluation
#################################################################


def evaluate_DI(D, I, gt):
    nq = gt.shape[0]
    k = I.shape[1]
    rank = 1
    while rank <= k:
        recall = (I[:, :rank] == gt[:, :1]).sum() / float(nq)
        print("R@%d: %.4f" % (rank, recall), end=' ')
        rank *= 10


def evaluate(xq, gt, index, k=100, endl=True):
    t0 = time.time()
    D, I = index.search(xq, k)
    t1 = time.time()
    nq = xq.shape[0]
    print("\t %8.4f ms per query, " % (
        (t1 - t0) * 1000.0 / nq), end=' ')
    rank = 1
    while rank <= k:
        recall = (I[:, :rank] == gt[:, :1]).sum() / float(nq)
        print("R@%d: %.4f" % (rank, recall), end=' ')
        rank *= 10
    if endl:
        print()
    return D, I
```
@@ -1,241 +0,0 @@
```python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
This is the training code for the link and code. Especially the
neighbors_kmeans function implements the EM-algorithm to find the
appropriate weightings and cluster them.
"""
from __future__ import print_function

import time
import numpy as np
import faiss

#----------------------------------------------------------
# Utils
#----------------------------------------------------------

def sanitize(x):
    return np.ascontiguousarray(x, dtype='float32')


def train_kmeans(x, k, ngpu, max_points_per_centroid=256):
    "Runs kmeans on one or several GPUs"
    d = x.shape[1]
    clus = faiss.Clustering(d, k)
    clus.verbose = True
    clus.niter = 20
    clus.max_points_per_centroid = max_points_per_centroid

    if ngpu == 0:
        index = faiss.IndexFlatL2(d)
    else:
        res = [faiss.StandardGpuResources() for i in range(ngpu)]

        flat_config = []
        for i in range(ngpu):
            cfg = faiss.GpuIndexFlatConfig()
            cfg.useFloat16 = False
            cfg.device = i
            flat_config.append(cfg)

        if ngpu == 1:
            index = faiss.GpuIndexFlatL2(res[0], d, flat_config[0])
        else:
            indexes = [faiss.GpuIndexFlatL2(res[i], d, flat_config[i])
                       for i in range(ngpu)]
            index = faiss.IndexReplicas()
            for sub_index in indexes:
                index.addIndex(sub_index)

    # perform the training
    clus.train(x, index)
    centroids = faiss.vector_float_to_array(clus.centroids)

    stats = clus.iteration_stats
    stats = [stats.at(i) for i in range(stats.size())]
    obj = np.array([st.obj for st in stats])
    print("final objective: %.4g" % obj[-1])

    return centroids.reshape(k, d)


#----------------------------------------------------------
# Learning the codebook from neighbors
#----------------------------------------------------------


# works with both a full Inn table and dynamically generated neighbors

def get_Inn_shape(Inn):
    if type(Inn) != tuple:
        return Inn.shape
    return Inn[:2]

def get_neighbor_table(x_coded, Inn, i):
    if type(Inn) != tuple:
        return x_coded[Inn[i,:],:]
    rfn = x_coded
    M, d = rfn.M, rfn.index.d
    out = np.zeros((M + 1, d), dtype='float32')
    int_i = int(i)
    rfn.get_neighbor_table(int_i, faiss.swig_ptr(out))
    _, _, sq = Inn
    return out[:, sq * rfn.dsub : (sq + 1) * rfn.dsub]


# Function that produces the best regression values from the vector
# and its neighbors
def regress_from_neighbors(x, x_coded, Inn):
    (N, knn) = get_Inn_shape(Inn)
    betas = np.zeros((N, knn))
    t0 = time.time()
    for i in range(N):
        xi = x[i,:]
        NNi = get_neighbor_table(x_coded, Inn, i)
        betas[i,:] = np.linalg.lstsq(NNi.transpose(), xi, rcond=0.01)[0]
        if i % (N / 10) == 0:
            print("[%d:%d]  %6.3fs" % (i, i + N / 10, time.time() - t0))
    return betas


# find the best beta minimizing ||x-x_coded[Inn,:]*beta||^2
def regress_opt_beta(x, x_coded, Inn):
    (N, knn) = get_Inn_shape(Inn)
    d = x.shape[1]

    # construct the linear system to be solved
    X = np.zeros((d*N))
    Y = np.zeros((d*N, knn))
    for i in range(N):
        X[i*d:(i+1)*d] = x[i,:]
        neighbor_table = get_neighbor_table(x_coded, Inn, i)
        Y[i*d:(i+1)*d, :] = neighbor_table.transpose()
    beta_opt = np.linalg.lstsq(Y, X, rcond=0.01)[0]
    return beta_opt


# Find the best encoding by minimizing the reconstruction error using
# a set of pre-computed beta values
def assign_beta(beta_centroids, x, x_coded, Inn, verbose=True):
    if type(Inn) == tuple:
        return assign_beta_2(beta_centroids, x, x_coded, Inn)
    (N, knn) = Inn.shape
    x_ibeta = np.zeros((N), dtype='int32')
    t0 = time.time()
    for i in range(N):
        NNi = x_coded[Inn[i,:]]
        # Consider all possible betas for the encoding and compute the
        # encoding error
        x_reg_all = np.dot(beta_centroids, NNi)
        err = ((x_reg_all - x[i,:]) ** 2).sum(axis=1)
        x_ibeta[i] = err.argmin()
        if verbose:
            if i % (N / 10) == 0:
                print("[%d:%d]  %6.3fs" % (i, i + N / 10, time.time() - t0))
    return x_ibeta


# Reconstruct a set of vectors using the beta_centroids, the
# assignment, the encoded neighbors identified by the list Inn (which
# includes the vector itself)
def recons_from_neighbors(beta_centroids, x_ibeta, x_coded, Inn):
    (N, knn) = Inn.shape
    x_rec = np.zeros(x_coded.shape)
    t0 = time.time()
    for i in range(N):
        NNi = x_coded[Inn[i,:]]
        x_rec[i, :] = np.dot(beta_centroids[x_ibeta[i]], NNi)
        if i % (N / 10) == 0:
            print("[%d:%d]  %6.3fs" % (i, i + N / 10, time.time() - t0))
    return x_rec


# Compute a EM-like algorithm trying at optimizing the beta such as they
# minimize the reconstruction error from the neighbors
def neighbors_kmeans(x, x_coded, Inn, K, ngpus=1, niter=5):
    # First compute centroids using a regular k-means algorithm
    betas = regress_from_neighbors(x, x_coded, Inn)
    beta_centroids = train_kmeans(
        sanitize(betas), K, ngpus, max_points_per_centroid=1000000)
    _, knn = get_Inn_shape(Inn)
    d = x.shape[1]

    rs = np.random.RandomState()
    for iter in range(niter):
        print('iter', iter)
        idx = assign_beta(beta_centroids, x, x_coded, Inn, verbose=False)

        hist = np.bincount(idx)
        for cl0 in np.where(hist == 0)[0]:
            print("  cluster %d empty, split" % cl0, end=' ')
            cl1 = idx[np.random.randint(idx.size)]
            pos = np.nonzero(idx == cl1)[0]
            pos = rs.choice(pos, pos.size / 2)
            print("   cl %d -> %d + %d" % (cl1, len(pos), hist[cl1] - len(pos)))
            idx[pos] = cl0
            hist = np.bincount(idx)

        tot_err = 0
        for k in range(K):
            pos = np.nonzero(idx == k)[0]
            npos = pos.shape[0]

            X = np.zeros(d*npos)
            Y = np.zeros((d*npos, knn))

            for i in range(npos):
                X[i*d:(i+1)*d] = x[pos[i],:]
                neighbor_table = get_neighbor_table(x_coded, Inn, pos[i])
                Y[i*d:(i+1)*d, :] = neighbor_table.transpose()
            sol, residuals, _, _ = np.linalg.lstsq(Y, X, rcond=0.01)
            if residuals.size > 0:
                tot_err += residuals.sum()
            beta_centroids[k, :] = sol
        print('   err=%g' % tot_err)
    return beta_centroids


# assign the betas in C++
def assign_beta_2(beta_centroids, x, rfn, Inn):
    _, _, sq = Inn
    if rfn.k == 1:
        return np.zeros(x.shape[0], dtype=int)
    # add dummy dimensions to beta_centroids and x
    all_beta_centroids = np.zeros(
        (rfn.nsq, rfn.k, rfn.M + 1), dtype='float32')
    all_beta_centroids[sq] = beta_centroids
    all_x = np.zeros((len(x), rfn.d), dtype='float32')
    all_x[:, sq * rfn.dsub : (sq + 1) * rfn.dsub] = x
    rfn.codes.clear()
    rfn.ntotal = 0
    faiss.copy_array_to_vector(
        all_beta_centroids.ravel(), rfn.codebook)
    rfn.add_codes(len(x), faiss.swig_ptr(all_x))
    codes = faiss.vector_to_array(rfn.codes)
    codes = codes.reshape(-1, rfn.nsq)
    return codes[:, sq]


#######################################################
# For usage from bench_storages.py

def train_beta_codebook(rfn, xb_full, niter=10):
    beta_centroids = []
    for sq in range(rfn.nsq):
        d0, d1 = sq * rfn.dsub, (sq + 1) * rfn.dsub
        print("training subquantizer %d/%d on dimensions %d:%d" % (
            sq, rfn.nsq, d0, d1))
        beta_centroids_i = neighbors_kmeans(
            xb_full[:, d0:d1], rfn, (xb_full.shape[0], rfn.M + 1, sq),
            rfn.k,
            ngpus=0, niter=niter)
        beta_centroids.append(beta_centroids_i)
    rfn.ntotal = 0
    rfn.codes.clear()
    rfn.codebook.clear()
    return np.stack(beta_centroids)
```
@@ -307,7 +307,7 @@ void hnsw_search(
```c++
        FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
        efSearch = params->efSearch;
    }
    size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
    size_t n1 = 0, n2 = 0, ndis = 0;

    idx_t check_period = InterruptCallback::get_period_hint(
            hnsw.max_level * index->d * efSearch);
```

@@ -323,7 +323,7 @@ void hnsw_search(
```c++
            std::unique_ptr<DistanceComputer> dis(
                    storage_distance_computer(index->storage));

#pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder) schedule(guided)
#pragma omp for reduction(+ : n1, n2, ndis) schedule(guided)
            for (idx_t i = i0; i < i1; i++) {
                res.begin(i);
                dis->set_query(x + i * index->d);
```

@@ -331,16 +331,14 @@ void hnsw_search(
```c++
                HNSWStats stats = hnsw.search(*dis, res, vt, params);
                n1 += stats.n1;
                n2 += stats.n2;
                n3 += stats.n3;
                ndis += stats.ndis;
                nreorder += stats.nreorder;
                res.end();
            }
        }
        InterruptCallback::check();
    }

    hnsw_stats.combine({n1, n2, n3, ndis, nreorder});
    hnsw_stats.combine({n1, n2, ndis});
}

} // anonymous namespace
```

@@ -800,7 +798,7 @@ void IndexHNSW2Level::search(
```c++
        IndexHNSW::search(n, x, k, distances, labels);

    } else { // "mixed" search
        size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
        size_t n1 = 0, n2 = 0, ndis = 0;

        const IndexIVFPQ* index_ivfpq =
                dynamic_cast<const IndexIVFPQ*>(storage);
```

@@ -832,7 +830,7 @@ void IndexHNSW2Level::search(
```c++
            int candidates_size = hnsw.upper_beam;
            MinimaxHeap candidates(candidates_size);

#pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder)
#pragma omp for reduction(+ : n1, n2, ndis)
            for (idx_t i = 0; i < n; i++) {
                idx_t* idxi = labels + i * k;
                float* simi = distances + i * k;
```

@@ -877,9 +875,7 @@ void IndexHNSW2Level::search(
```c++
                        k);
                n1 += search_stats.n1;
                n2 += search_stats.n2;
                n3 += search_stats.n3;
                ndis += search_stats.ndis;
                nreorder += search_stats.nreorder;

                vt.advance();
                vt.advance();
```

@@ -888,7 +884,7 @@ void IndexHNSW2Level::search(
```c++
            }
        }

        hnsw_stats.combine({n1, n2, n3, ndis, nreorder});
        hnsw_stats.combine({n1, n2, ndis});
    }
}
```
@@ -84,19 +84,14 @@ class GpuIndex : public faiss::Index {
```c++

    /// `x` and `labels` can be resident on the CPU or any GPU; copies are
    /// performed as needed
    void assign(
            idx_t n,
            const float* x,
            idx_t* labels,
            // faiss::Index has idx_t for k
            idx_t k = 1) const override;
    void assign(idx_t n, const float* x, idx_t* labels, idx_t k = 1)
            const override;

    /// `x`, `distances` and `labels` can be resident on the CPU or any
    /// GPU; copies are performed as needed
    void search(
            idx_t n,
            const float* x,
            // faiss::Index has idx_t for k
            idx_t k,
            float* distances,
            idx_t* labels,
```

@@ -107,7 +102,6 @@ class GpuIndex : public faiss::Index {
```c++
    void search_and_reconstruct(
            idx_t n,
            const float* x,
            // faiss::Index has idx_t for k
            idx_t k,
            float* distances,
            idx_t* labels,
```
@@ -664,7 +664,7 @@ int search_from_candidates(
```c++
        if (candidates.size() == 0) {
            stats.n2++;
        }
        stats.n3 += ndis;
        stats.ndis += ndis;
    }

    return nres;
```

@@ -793,7 +793,7 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
```c++
    if (candidates.size() == 0) {
        ++stats.n2;
    }
    stats.n3 += ndis;
    stats.ndis += ndis;

    return top_candidates;
}
```
@@ -230,30 +230,20 @@ struct HNSW {
```c++
};

struct HNSWStats {
    size_t n1, n2, n3;
    size_t ndis;
    size_t nreorder;

    HNSWStats(
            size_t n1 = 0,
            size_t n2 = 0,
            size_t n3 = 0,
            size_t ndis = 0,
            size_t nreorder = 0)
            : n1(n1), n2(n2), n3(n3), ndis(ndis), nreorder(nreorder) {}
    size_t n1 = 0; /// number of vectors searched
    size_t n2 =
            0; /// number of queries for which the candidate list is exhausted
    size_t ndis = 0; /// number of distances computed

    void reset() {
        n1 = n2 = n3 = 0;
        n1 = n2 = 0;
        ndis = 0;
        nreorder = 0;
    }

    void combine(const HNSWStats& other) {
        n1 += other.n1;
        n2 += other.n2;
        n3 += other.n3;
        ndis += other.ndis;
        nreorder += other.nreorder;
    }
};
```
@@ -123,6 +123,16 @@ class TestHNSW(unittest.TestCase):
```python
        mask = Iref[:, 0] == Ihnsw[:, 0]
        assert np.allclose(Dref[mask, 0], Dhnsw[mask, 0])

    def test_ndis_stats(self):
        d = self.xq.shape[1]

        index = faiss.IndexHNSWFlat(d, 16)
        index.add(self.xb)
        stats = faiss.cvar.hnsw_stats
        stats.reset()
        Dhnsw, Ihnsw = index.search(self.xq, 1)
        self.assertGreater(stats.ndis, len(self.xq) * index.hnsw.efSearch)


class TestNSG(unittest.TestCase):
```