Summary:
## Description:
This diff implemented Navigating Spreading-out Graph (NSG) which accepts a KNN graph as input.
Here is the interface of building an NSG graph:
``` c++
void IndexNSG::build(idx_t n, const float *x, idx_t *knn_graph, int GK);
```
where `GK` is the nb of neighbors per node and `knn_graph[i * GK + j]` is the j-th neighbor of node i.

The `add` method is not implemented yet.

The unit tests can be found in `tests/test_nsg.cpp`.

mdouze beauby Maybe I need some advice on how to design the interface and support python.

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1707

Test Plan: buck test //faiss/tests/:test_index -- TestNSG

Reviewed By: beauby

Differential Revision: D26748498

Pulled By: mdouze

fbshipit-source-id: 3280f705fb1b5f9c8cc5efeba63b904c3b832544
pull/1764/head
Check Deng 2021-03-10 15:01:06 -08:00 committed by Facebook GitHub Bot
parent 5e54fb57d8
commit b35103a138
13 changed files with 1480 additions and 6 deletions

View File

@ -9,6 +9,8 @@ at the moment.
### Added
- Support for building C bindings through the `FAISS_ENABLE_C_API` CMake option.
- Serializing the indexes with the python pickle module
- Support for the NNDescent k-NN graph building method
- Support for the NSG indexing method
### Changed
- The order of xb and xq was different between `faiss.knn` and `faiss.knn_gpu`.

View File

@ -7,18 +7,33 @@ import time
import sys
import numpy as np
import faiss
from datasets import load_sift1M
try:
from faiss.contrib.datasets_fb import DatasetSIFT1M
except ImportError:
from faiss.contrib.datasets import DatasetSIFT1M
# from datasets import load_sift1M
k = int(sys.argv[1])
todo = sys.argv[1:]
print("load data")
xb, xq, xt, gt = load_sift1M()
# xb, xq, xt, gt = load_sift1M()
ds = DatasetSIFT1M()
xq = ds.get_queries()
xb = ds.get_database()
gt = ds.get_groundtruth()
xt = ds.get_train()
nq, d = xq.shape
if todo == []:
todo = 'hnsw hnsw_sq ivf ivf_hnsw_quantizer kmeans kmeans_hnsw'.split()
todo = 'hnsw hnsw_sq ivf ivf_hnsw_quantizer kmeans kmeans_hnsw nsg'.split()
def evaluate(index):
@ -153,3 +168,25 @@ if 'kmeans_hnsw' in todo:
# clusters is too high.
index.hnsw.efSearch = 128
clus.train(xb, index)
if 'nsg' in todo:
print("Testing NSG Flat")
index = faiss.IndexNSGFlat(d, 32)
index.build_type = 1
# training is not needed
# this is the default, higher is more accurate and slower to
# construct
print("add")
# to see progress
index.verbose = True
index.add(xb)
print("search")
for search_L in -1, 16, 32, 64, 128, 256:
print("search_L", search_L, end=' ')
index.nsg.search_L = search_L
evaluate(index)

View File

@ -27,6 +27,7 @@ add_library(faiss
IndexLSH.cpp
IndexNNDescent.cpp
IndexLattice.cpp
IndexNSG.cpp
IndexPQ.cpp
IndexPQFastScan.cpp
IndexPreTransform.cpp
@ -42,6 +43,7 @@ add_library(faiss
impl/AuxIndexStructures.cpp
impl/FaissException.cpp
impl/HNSW.cpp
impl/NSG.cpp
impl/PolysemousTraining.cpp
impl/ProductQuantizer.cpp
impl/ScalarQuantizer.cpp
@ -94,6 +96,7 @@ set(FAISS_HEADERS
IndexLSH.h
IndexLattice.h
IndexNNDescent.h
IndexNSG.h
IndexPQ.h
IndexPQFastScan.h
IndexPreTransform.h
@ -112,6 +115,7 @@ set(FAISS_HEADERS
impl/FaissAssert.h
impl/FaissException.h
impl/HNSW.h
impl/NSG.h
impl/PolysemousTraining.h
impl/ProductQuantizer-inl.h
impl/ProductQuantizer.h

View File

@ -9,6 +9,7 @@
#include <faiss/IndexNNDescent.h>
#include <inttypes.h>
#include <omp.h>
#include <cstdio>
#include <cstdlib>
@ -141,7 +142,9 @@ void IndexNNDescent::search(
"Please use IndexNNDescentFlat (or variants) "
"instead of IndexNNDescent directly");
if (verbose) {
printf("Parameters: k=%ld, search_L=%d\n", k, nndescent.search_L);
printf("Parameters: k=%" PRId64 ", search_L=%d\n",
k,
nndescent.search_L);
}
idx_t check_period =

302
faiss/IndexNSG.cpp 100644
View File

@ -0,0 +1,302 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include <faiss/IndexNSG.h>
#include <inttypes.h>
#include <omp.h>
#include <memory>
#include <faiss/IndexFlat.h>
#include <faiss/IndexNNDescent.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/Heap.h>
#include <faiss/utils/distances.h>
namespace faiss {
using idx_t = Index::idx_t;
using namespace nsg;
/**************************************************************
* IndexNSG implementation
**************************************************************/
/// Construct an NSG index of dimension `d` with at most `R` links per node.
/// No storage index is attached (`storage` is null), so the object is only
/// usable once a storage has been provided, e.g. through IndexNSGFlat.
IndexNSG::IndexNSG(int d, int R, MetricType metric)
        : Index(d, metric),
          nsg(R),
          own_fields(false),
          storage(nullptr),
          is_built(false),
          GK(64),
          build_type(0),
          // defaults for the NNDescent-based kNN graph construction
          nndescent_S(10),
          nndescent_R(100),
          nndescent_L(GK + 50),
          nndescent_iter(10) {}
/// Construct an NSG index on top of an existing storage index. Dimension and
/// metric are taken from `storage`; the storage is not owned by default
/// (`own_fields` is false).
IndexNSG::IndexNSG(Index* storage, int R)
        : Index(storage->d, storage->metric_type),
          nsg(R),
          own_fields(false),
          storage(storage),
          is_built(false),
          GK(64),
          build_type(0),
          // defaults for the NNDescent-based kNN graph construction
          nndescent_S(10),
          nndescent_R(100),
          nndescent_L(GK + 50),
          nndescent_iter(10) {}
/// Release the storage index, but only when this object owns it.
IndexNSG::~IndexNSG() {
    if (!own_fields) {
        return;
    }
    delete storage;
}
/// Train the underlying storage index on `n` vectors `x` and mark this
/// index as trained.
/// Throws if no storage index is attached.
void IndexNSG::train(idx_t n, const float* x) {
    FAISS_THROW_IF_NOT_MSG(
            storage,
            "Please use IndexNSGFlat (or variants) instead of IndexNSG directly");

    // Only the storage may have trainable state; the NSG link structure
    // itself needs no training.
    storage->train(n, x);
    is_trained = true;
}
/// Search the `k` nearest neighbors of each of the `n` query vectors in `x`.
///
/// @param n          number of queries
/// @param x          query vectors, size n * d
/// @param k          number of neighbors to return per query
/// @param distances  output distances, size n * k
/// @param labels     output ids, size n * k
///
/// Throws if no storage index is attached.
void IndexNSG::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels) const {
    FAISS_THROW_IF_NOT_MSG(
            storage,
            "Please use IndexNSGFlat (or variants) instead of IndexNSG directly");

    // NSG::search uses max(search_L, k) as its pool size. Use the same value
    // for the cost estimate so that a non-positive search_L (e.g. -1) does
    // not turn into a negative "flops" count.
    const int search_len = std::max(nsg.search_L, static_cast<int>(k));
    idx_t check_period = InterruptCallback::get_period_hint(d * search_len);

    for (idx_t i0 = 0; i0 < n; i0 += check_period) {
        idx_t i1 = std::min(i0 + check_period, n);

#pragma omp parallel
        {
            // per-thread scratch state
            VisitedTable vt(ntotal);
            std::unique_ptr<DistanceComputer> dis(
                    storage_distance_computer(storage));

#pragma omp for
            for (idx_t i = i0; i < i1; i++) {
                idx_t* idxi = labels + i * k;
                float* simi = distances + i * k;
                dis->set_query(x + i * d);

                // NSG::search writes the k results in ascending distance
                // order. They must not be post-processed as a heap: running
                // a heap reorder on an already sorted array permutes it.
                nsg.search(*dis, k, idxi, simi, vt);

                vt.advance();
            }
        }
        InterruptCallback::check();
    }

    if (metric_type == METRIC_INNER_PRODUCT) {
        // the distance computer negates inner products; undo the negation
        for (idx_t i = 0; i < k * n; i++) {
            distances[i] = -distances[i];
        }
    }
}
/// Build the NSG from a user-provided kNN graph.
///
/// @param n          number of vectors
/// @param x          vectors to index, size n * d
/// @param knn_graph  flattened kNN graph: knn_graph[i * GK + j] is the j-th
///                   neighbor of node i
/// @param GK         number of neighbors per node in `knn_graph`
///
/// Throws if no storage is attached, if the index is already built, or if
/// the kNN graph is rejected by check_knn_graph().
void IndexNSG::build(idx_t n, const float* x, idx_t* knn_graph, int GK) {
    FAISS_THROW_IF_NOT_MSG(
            storage,
            "Please use IndexNSGFlat (or variants) instead of IndexNSG directly");
    FAISS_THROW_IF_NOT_MSG(
            !is_built && ntotal == 0, "The IndexNSG is already built");

    // Validate the graph before touching the storage: if the graph is
    // rejected, the index must stay empty instead of holding vectors
    // without a link structure.
    check_knn_graph(knn_graph, n, GK);

    storage->add(n, x);
    ntotal = storage->ntotal;
    // same consistency requirement as in add(): the graph describes exactly
    // the vectors held by the storage
    FAISS_THROW_IF_NOT(ntotal == n);

    const nsg::Graph<idx_t> knng(knn_graph, n, GK);
    nsg.build(storage, n, knng, verbose);
    is_built = true;
}
/// Add `n` vectors and build the NSG in one shot.
///
/// A kNN graph with `GK` neighbors per node is first computed, either by
/// brute-force search on the storage (build_type == 0) or with NNDescent
/// (build_type == 1), and the NSG is then derived from it.
///
/// Throws if no storage is attached, if the index is not trained, if the
/// index already contains vectors (incremental addition is not supported),
/// or if build_type is neither 0 nor 1.
void IndexNSG::add(idx_t n, const float* x) {
    FAISS_THROW_IF_NOT_MSG(
            storage,
            "Please use IndexNSGFlat (or variants) "
            "instead of IndexNSG directly");
    FAISS_THROW_IF_NOT(is_trained);

    FAISS_THROW_IF_NOT_MSG(
            !is_built && ntotal == 0,
            "NSG does not support incremental addition");

    std::vector<idx_t> knng;
    if (verbose) {
        // %zu is the conversion specifier matching size_t
        printf("IndexNSG::add %zu vectors\n", size_t(n));
    }

    if (build_type == 0) { // build with brute force search

        if (verbose) {
            printf("  Build knn graph with brute force search on storage index\n");
        }

        storage->add(n, x);
        ntotal = storage->ntotal;
        FAISS_THROW_IF_NOT(ntotal == n);

        // search GK + 1 neighbors because the query itself is expected to be
        // among its own nearest neighbors
        knng.resize(ntotal * (GK + 1));
        storage->assign(ntotal, x, knng.data(), GK + 1);

        // Remove itself
        // - For metric distance, we just need to remove the first neighbor
        // - But for non-metric, e.g. inner product, we need to check
        // - each neighbor
        // In both cases rows are compacted in place from stride GK + 1 to
        // stride GK; the destination never runs ahead of the source.
        if (storage->metric_type == METRIC_INNER_PRODUCT) {
            for (idx_t i = 0; i < ntotal; i++) {
                int count = 0;
                for (int j = 0; j < GK + 1; j++) {
                    idx_t id = knng[i * (GK + 1) + j];
                    if (id != i) {
                        knng[i * GK + count] = id;
                        count += 1;
                    }
                    if (count == GK) {
                        break;
                    }
                }
            }
        } else {
            for (idx_t i = 0; i < ntotal; i++) {
                memmove(knng.data() + i * GK,
                        knng.data() + i * (GK + 1) + 1,
                        GK * sizeof(idx_t));
            }
        }

    } else if (build_type == 1) { // build with NNDescent
        IndexNNDescent index(storage, GK);
        index.nndescent.S = nndescent_S;
        index.nndescent.R = nndescent_R;
        index.nndescent.L = std::max(nndescent_L, GK + 50);
        index.nndescent.iter = nndescent_iter;
        index.verbose = verbose;

        if (verbose) {
            printf("  Build knn graph with NNdescent S=%d R=%d L=%d niter=%d\n",
                   index.nndescent.S,
                   index.nndescent.R,
                   index.nndescent.L,
                   index.nndescent.iter);
        }

        // prevent the temporary IndexNNDescent from deleting the storage
        index.own_fields = false;

        index.add(n, x);

        // the vectors are expected to have been added to the shared storage
        // by index.add() above, so storage->add is not called again here
        ntotal = storage->ntotal;
        FAISS_THROW_IF_NOT(ntotal == n);

        knng.resize(ntotal * GK);

        // widen the NNDescent graph entries from int to idx_t
        const int* knn_graph = index.nndescent.final_graph.data();

#pragma omp parallel for
        for (idx_t i = 0; i < ntotal * GK; i++) {
            knng[i] = knn_graph[i];
        }
    } else {
        FAISS_THROW_MSG("build_type should be 0 or 1");
    }

    if (verbose) {
        printf("  Check the knn graph\n");
    }

    // check the knn graph
    check_knn_graph(knng.data(), n, GK);

    if (verbose) {
        printf("  nsg building\n");
    }

    const nsg::Graph<idx_t> knn_graph(knng.data(), n, GK);
    nsg.build(storage, n, knn_graph, verbose);
    is_built = true;
}
void IndexNSG::reset() {
nsg.reset();
storage->reset();
ntotal = 0;
is_built = false;
}
/// Reconstruct the stored vector `key` into `recons` by delegating to the
/// storage index.
/// Throws if no storage index is attached, like the other entry points.
void IndexNSG::reconstruct(idx_t key, float* recons) const {
    FAISS_THROW_IF_NOT_MSG(
            storage,
            "Please use IndexNSGFlat (or variants) instead of IndexNSG directly");
    storage->reconstruct(key, recons);
}
/// Sanity-check a flattened kNN graph of `n` nodes with `K` neighbors each.
///
/// An entry is counted as invalid when it is negative, not smaller than `n`,
/// or equal to the node it belongs to. A warning is printed to stderr when
/// there is at least one invalid entry, and an exception is thrown when
/// their number reaches n / 10.
void IndexNSG::check_knn_graph(const idx_t* knn_graph, idx_t n, int K) const {
    idx_t total_count = 0;

#pragma omp parallel for reduction(+ : total_count)
    for (idx_t i = 0; i < n; i++) {
        int count = 0;
        for (int j = 0; j < K; j++) {
            idx_t id = knn_graph[i * K + j];
            if (id < 0 || id >= n || id == i) {
                count += 1;
            }
        }
        total_count += count;
    }

    if (total_count > 0) {
        fprintf(stderr,
                "WARNING: the input knn graph "
                "has %" PRId64 " invalid entries\n",
                total_count);
    }

    // A graph without any invalid entry is always acceptable. Without the
    // explicit zero test, n < 10 gives n / 10 == 0 and every graph, even a
    // perfectly valid one, would be rejected.
    FAISS_THROW_IF_NOT_MSG(
            total_count == 0 || total_count < n / 10,
            "There are too much invalid entries in the knn graph. "
            "It may be an invalid knn graph.");
}
/**************************************************************
* IndexNSGFlat implementation
**************************************************************/
/// Default constructor: an empty index (base-class defaults) that is
/// flagged as trained.
IndexNSGFlat::IndexNSGFlat() : IndexNSG() {
    is_trained = true;
}
/// Build an NSG index on top of a freshly allocated IndexFlat. The flat
/// storage is owned by this object and released by the IndexNSG destructor.
IndexNSGFlat::IndexNSGFlat(int d, int R, MetricType metric)
        : IndexNSG(new IndexFlat(d, metric), R) {
    is_trained = true;
    own_fields = true;
}
} // namespace faiss

85
faiss/IndexNSG.h 100644
View File

@ -0,0 +1,85 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#pragma once
#include <vector>
#include <faiss/IndexFlat.h>
#include <faiss/IndexNNDescent.h>
#include <faiss/impl/NSG.h>
#include <faiss/utils/utils.h>
namespace faiss {
/** The NSG index is a normal random-access index with a NSG
* link structure built on top.
*
* The vectors live in a separate `storage` index; the NSG itself only
* holds the neighbor links. The graph is built in one shot, either from a
* user-provided kNN graph (build()) or from a kNN graph computed
* internally (add()); adding vectors to a non-empty index is rejected. */
struct IndexNSG : Index {
/// the link structure
NSG nsg;
/// whether `storage` is deleted by the destructor
bool own_fields;
/// the sequential storage holding the vectors (may be null when the
/// base class is instantiated directly)
Index* storage;
/// the index is built or not
bool is_built;
/// K of KNN graph for building
int GK;
/// indicate how to build a knn graph
/// - 0: build NSG with brute force search
/// - 1: build NSG with NNDescent
char build_type;
/// parameters for nndescent, copied into the NNDescent object used by
/// add() when build_type == 1
int nndescent_S;
int nndescent_R;
int nndescent_L;
int nndescent_iter;
/// Construct without storage (storage is set to nullptr)
explicit IndexNSG(int d = 0, int R = 32, MetricType metric = METRIC_L2);
/// Construct on top of an existing storage index; dimension and metric
/// are taken from it
explicit IndexNSG(Index* storage, int R = 32);
~IndexNSG() override;
/// Build the NSG from a kNN graph where knn_graph[i * GK + j] is the
/// j-th neighbor of node i
void build(idx_t n, const float* x, idx_t* knn_graph, int GK);
/// Add the vectors, compute a kNN graph according to build_type and
/// build the NSG from it
void add(idx_t n, const float* x) override;
/// Trains the storage if needed
void train(idx_t n, const float* x) override;
/// entry point for search
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const override;
/// Reconstruct a stored vector by delegating to the storage
void reconstruct(idx_t key, float* recons) const override;
/// Reset the link structure and the storage
void reset() override;
/// Count the invalid entries (negative, >= n, or self-referencing) of a
/// kNN graph; warns on stderr and throws when there are too many
void check_knn_graph(const idx_t* knn_graph, idx_t n, int K) const;
};
/** Flat index topped with a NSG structure to access elements
* more efficiently.
*/
struct IndexNSGFlat : IndexNSG {
/// Default constructor; only marks the index as trained
IndexNSGFlat();
/// Allocates an IndexFlat(d, metric) as storage, owned by this index
IndexNSGFlat(int d, int R, MetricType metric = METRIC_L2);
};
} // namespace faiss

View File

@ -151,7 +151,7 @@ NNDescent::NNDescent(const int d, const int K) : K(K), random_seed(2021), d(d) {
has_built = false;
S = 10;
R = 100;
L = K;
L = K + 50;
iter = 10;
search_L = 0;
}

681
faiss/impl/NSG.cpp 100644
View File

@ -0,0 +1,681 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include <faiss/impl/NSG.h>
#include <algorithm>
#include <memory>
#include <mutex>
#include <stack>
#include <faiss/impl/AuxIndexStructures.h>
namespace faiss {
namespace nsg {
namespace {
/* Wrap the distance computer into one that negates the
   distances. This makes supporting INNER_PRODUCT search easier:
   the rest of the code can always treat smaller values as closer. */
struct NegativeDistanceComputer : DistanceComputer {
    using idx_t = Index::idx_t;

    /// owned by this
    DistanceComputer* basedis;

    /// Takes ownership of `basedis`.
    explicit NegativeDistanceComputer(DistanceComputer* basedis)
            : basedis(basedis) {}

    // The wrapped computer is deleted in the destructor, so a copy would
    // delete it a second time: forbid copying.
    NegativeDistanceComputer(const NegativeDistanceComputer&) = delete;
    NegativeDistanceComputer& operator=(const NegativeDistanceComputer&) =
            delete;

    void set_query(const float* x) override {
        basedis->set_query(x);
    }

    /// compute distance of vector i to current query
    float operator()(idx_t i) override {
        return -(*basedis)(i);
    }

    /// compute distance between two stored vectors
    float symmetric_dis(idx_t i, idx_t j) override {
        return -basedis->symmetric_dis(i, j);
    }

    ~NegativeDistanceComputer() override {
        delete basedis;
    }
};
} // namespace
/// Return a newly allocated distance computer for `storage` for which
/// smaller values always mean "closer": for inner-product storages the
/// computer is wrapped so that its outputs are negated.
/// The caller owns the returned object.
DistanceComputer* storage_distance_computer(const Index* storage) {
    DistanceComputer* base = storage->get_distance_computer();
    if (storage->metric_type != METRIC_INNER_PRODUCT) {
        return base;
    }
    // the wrapper takes ownership of `base`
    return new NegativeDistanceComputer(base);
}
} // namespace nsg
using namespace nsg;
using LockGuard = std::lock_guard<std::mutex>;
/// Candidate entry of a search pool: a node id, its distance to the query
/// and a flag (set to true by the graph search for entries whose neighbors
/// have not been expanded yet).
struct Neighbor {
    int id;
    float distance;
    bool flag;

    /// Leaves all members uninitialized.
    Neighbor() = default;

    Neighbor(int id, float distance, bool f)
            : id{id}, distance{distance}, flag{f} {}

    /// Order by increasing distance only; id and flag are ignored.
    bool operator<(const Neighbor& rhs) const {
        return rhs.distance > distance;
    }
};
/// Graph entry used during construction: a node id together with a
/// distance value.
struct Node {
    int id;
    float distance;

    /// Leaves all members uninitialized.
    Node() = default;

    Node(int id, float distance) : id{id}, distance{distance} {}

    /// Order by increasing distance only; the id is ignored.
    bool operator<(const Node& rhs) const {
        return rhs.distance > distance;
    }
};
// Insert `nn` into the candidate pool `addr`, keeping it ordered by
// increasing distance.
//
// Assumes addr[0..K) is sorted by increasing distance and that the buffer
// has room for at least K + 1 entries (addr[K] is written) -- the callers in
// this file size the pool as pool_size + 1.
//
// Returns the position at which `nn` was stored (0..K), or K + 1 when an
// entry with the same id was found and nothing was inserted.
inline int insert_into_pool(Neighbor* addr, int K, Neighbor nn) {
// find the location to insert
int left = 0, right = K - 1;
// closer than the current best: shift everything by one and put it in front
if (addr[left].distance > nn.distance) {
memmove(&addr[left + 1], &addr[left], K * sizeof(Neighbor));
addr[left] = nn;
return left;
}
// farther than the current worst: store it in the extra slot at the end
if (addr[right].distance < nn.distance) {
addr[K] = nn;
return K;
}
// binary search: narrow [left, right] down to two adjacent positions
while (left < right - 1) {
int mid = (left + right) / 2;
if (addr[mid].distance > nn.distance) {
right = mid;
} else {
left = mid;
}
}
// check equal ID
// walk back over the entries that are not closer than nn, looking for the
// same id
while (left > 0) {
if (addr[left].distance < nn.distance) {
break;
}
if (addr[left].id == nn.id) {
return K + 1;
}
left--;
}
if (addr[left].id == nn.id || addr[right].id == nn.id) {
return K + 1;
}
// make room at `right` and insert
memmove(&addr[right + 1], &addr[right], (K - right) * sizeof(Neighbor));
addr[right] = nn;
return right;
}
const int NSG::EMPTY_ID = -1;
/// Set up an empty NSG with at most `R` links per node. The construction
/// parameters L and C are derived from R; `enterpoint` is left unset until
/// the graph is built.
NSG::NSG(int R)
        : ntotal(0),
          R(R),
          L(R + 32),
          C(R + 100),
          search_L(16),
          is_built(false),
          rng(0x0903) {
    // seeds the C library generator, whose rand() is used by the graph search
    srand(0x1998);
}
/// Search the k nearest neighbors of the query currently set in `dis`.
///
/// The ids and distances of the k best candidates are written to `I` and
/// `D` in increasing distance order. `vt` is scratch space used to mark the
/// visited nodes. Throws if the graph has not been built.
void NSG::search(
        DistanceComputer& dis,
        int k,
        idx_t* I,
        float* D,
        VisitedTable& vt) const {
    FAISS_THROW_IF_NOT(is_built);
    FAISS_THROW_IF_NOT(final_graph);

    // the candidate pool must hold at least k entries
    const int pool_size = std::max(search_L, k);

    std::vector<Neighbor> candidates;
    std::vector<Node> unused_fullset;
    search_on_graph<false>(
            *final_graph,
            dis,
            vt,
            enterpoint,
            pool_size,
            candidates,
            unused_fullset);

    const auto first = candidates.begin();
    std::partial_sort(first, first + k, first + pool_size);

    for (int rank = 0; rank < k; rank++) {
        const Neighbor& best = candidates[rank];
        I[rank] = best.id;
        D[rank] = best.distance;
    }
}
/// Build the NSG over the `n` vectors of `storage`, starting from the given
/// kNN graph.
///
/// Steps: pick the enter point, compute the pruned links of every node,
/// copy them to the final compact graph, then attach the nodes that cannot
/// be reached from the enter point. Throws if the NSG is already built.
void NSG::build(
        Index* storage,
        idx_t n,
        const nsg::Graph<idx_t>& knn_graph,
        bool verbose) {
    FAISS_THROW_IF_NOT(!is_built && ntotal == 0);

    if (verbose) {
        printf("NSG::build R=%d, L=%d, C=%d\n", R, L, C);
    }

    ntotal = n;
    init_graph(storage, knn_graph);

    std::vector<int> degrees(n, 0);
    {
        // temporary graph with distances, released at the end of the scope
        nsg::Graph<Node> tmp_graph(n, R);

        link(storage, knn_graph, tmp_graph, verbose);

        final_graph = std::make_shared<nsg::Graph<int>>(n, R);
        std::fill_n(final_graph->data, n * R, EMPTY_ID);

#pragma omp parallel for
        for (int i = 0; i < n; i++) {
            // pack the non-empty links of node i at the front of its row
            int cnt = 0;
            for (int j = 0; j < R; j++) {
                const int id = tmp_graph.at(i, j).id;
                if (id == EMPTY_ID) {
                    continue;
                }
                final_graph->at(i, cnt) = id;
                cnt += 1;
            }
            degrees[i] = cnt;
        }
    }

    int num_attached = tree_grow(storage, degrees);
    check_graph();
    is_built = true;

    if (!verbose) {
        return;
    }

    // out-degree statistics of the final graph
    int max_degree = 0, min_degree = 1e6;
    double degree_sum = 0;

    for (int i = 0; i < n; i++) {
        int size = 0;
        while (size < R && final_graph->at(i, size) != EMPTY_ID) {
            size += 1;
        }
        max_degree = std::max(size, max_degree);
        min_degree = std::min(size, min_degree);
        degree_sum += size;
    }

    const double avg_degree = degree_sum / n;
    printf("Degree Statistics: Max = %d, Min = %d, Avg = %lf\n",
           max_degree,
           min_degree,
           avg_degree);
    printf("Attached nodes: %d\n", num_attached);
}
/// Drop the graph and return to the "not built" state.
void NSG::reset() {
    is_built = false;
    ntotal = 0;
    final_graph.reset();
}
/// Choose the enter point of the graph: the best result of a search for the
/// centroid of all stored vectors on the kNN graph, started from a random
/// node.
void NSG::init_graph(Index* storage, const nsg::Graph<idx_t>& knn_graph) {
    const int d = storage->d;
    const int n = storage->ntotal;

    // centroid of the stored vectors
    std::vector<float> centroid(d, 0.0f);
    std::vector<float> vec(d);
    for (int i = 0; i < n; i++) {
        storage->reconstruct(i, vec.data());
        for (int j = 0; j < d; j++) {
            centroid[j] += vec[j];
        }
    }
    for (int j = 0; j < d; j++) {
        centroid[j] /= n;
    }

    std::vector<Neighbor> retset;
    std::vector<Node> tmpset;

    // random initialize navigating point
    const int ep = rng.rand_int(n);

    std::unique_ptr<DistanceComputer> dis(storage_distance_computer(storage));
    dis->set_query(centroid.data());

    VisitedTable vt(ntotal);

    // Do not collect the visited nodes
    search_on_graph<false>(knn_graph, *dis, vt, ep, L, retset, tmpset);

    // the closest node found becomes the enter point
    enterpoint = retset[0].id;
}
/// Best-first search on `graph` for the query currently set in `dis`.
///
/// @param graph      graph to search on (kNN graph or NSG)
/// @param dis        distance computer with the query already set
/// @param vt         visited table, marks every node whose distance has
///                   been requested
/// @param ep         node whose neighbors seed the candidate pool
/// @param pool_size  number of candidates kept
/// @param retset     output: resized to pool_size + 1, the first pool_size
///                   entries are the candidates
/// @param fullset    output, only when collect_fullset is true: every node
///                   whose distance was computed is appended
///
/// NOTE(review): the random completion of the pool below only terminates if
/// at least pool_size distinct unvisited nodes exist, i.e. it presumably
/// requires ntotal >= pool_size -- confirm against callers.
template <bool collect_fullset, class index_t>
void NSG::search_on_graph(
        const nsg::Graph<index_t>& graph,
        DistanceComputer& dis,
        VisitedTable& vt,
        int ep,
        int pool_size,
        std::vector<Neighbor>& retset,
        std::vector<Node>& fullset) const {
    RandomGenerator gen(rand());
    retset.resize(pool_size + 1);
    std::vector<int> init_ids(pool_size);

    // Seed the pool with the valid neighbors of the enter point. Entries are
    // stored compactly at index num_ids: indexing by the neighbor position
    // would leave holes after a skipped (invalid) neighbor, and the random
    // completion below would then overwrite valid seeds.
    int num_ids = 0;
    for (int i = 0; i < pool_size && i < graph.K; i++) {
        int id = (int)graph.at(ep, i);
        if (id < 0 || id >= ntotal) {
            continue;
        }

        init_ids[num_ids] = id;
        vt.set(id);
        num_ids += 1;
    }

    // complete the pool with random unvisited nodes
    while (num_ids < pool_size) {
        int id = gen.rand_int(ntotal);
        if (vt.get(id)) {
            continue;
        }

        init_ids[num_ids] = id;
        num_ids++;
        vt.set(id);
    }

    for (int i = 0; i < pool_size; i++) {
        int id = init_ids[i];

        float dist = dis(id);
        retset[i] = Neighbor(id, dist, true);

        if (collect_fullset) {
            fullset.emplace_back(retset[i].id, retset[i].distance);
        }
    }

    std::sort(retset.begin(), retset.begin() + pool_size);

    // Expand the closest candidate that has not been expanded yet, until
    // every candidate of the pool has been expanded.
    int k = 0;
    while (k < pool_size) {
        int updated_pos = pool_size;

        if (retset[k].flag) {
            retset[k].flag = false;
            int n = retset[k].id;

            for (int m = 0; m < graph.K; m++) {
                int id = (int)graph.at(n, m);
                // valid ids are 0 .. ntotal - 1: id == ntotal must be
                // rejected as well, otherwise vt and dis are accessed out of
                // range
                if (id < 0 || id >= ntotal || vt.get(id)) {
                    continue;
                }
                vt.set(id);

                float dist = dis(id);
                Neighbor nn(id, dist, true);
                if (collect_fullset) {
                    fullset.emplace_back(id, dist);
                }

                if (dist >= retset[pool_size - 1].distance) {
                    continue;
                }

                int r = insert_into_pool(retset.data(), pool_size, nn);

                updated_pos = std::min(updated_pos, r);
            }
        }

        // restart from the best position that changed, otherwise move on
        k = (updated_pos <= k) ? updated_pos : (k + 1);
    }
}
/// Compute the links of every node and store them in `graph`.
///
/// First pass: for each node, search the kNN graph for the node's own
/// vector, collect the visited nodes and prune them down to at most R
/// links. Second pass: insert the reverse links, one mutex per node.
void NSG::link(
        Index* storage,
        const nsg::Graph<idx_t>& knn_graph,
        nsg::Graph<Node>& graph,
        bool /* verbose */) {
#pragma omp parallel
    {
        // per-thread scratch buffers
        std::unique_ptr<float[]> query(new float[storage->d]);
        std::vector<Node> candidates;
        std::vector<Neighbor> retset;
        VisitedTable visited(ntotal);
        std::unique_ptr<DistanceComputer> dis(
                storage_distance_computer(storage));

#pragma omp for schedule(dynamic, 100)
        for (int node = 0; node < ntotal; node++) {
            storage->reconstruct(node, query.get());
            dis->set_query(query.get());

            // Collect the visited nodes into the candidate list
            search_on_graph<true>(
                    knn_graph,
                    *dis,
                    visited,
                    enterpoint,
                    L,
                    retset,
                    candidates);

            sync_prune(node, candidates, *dis, visited, knn_graph, graph);

            candidates.clear();
            retset.clear();
            visited.advance();
        }
    } // omp parallel

    std::vector<std::mutex> locks(ntotal);
#pragma omp parallel
    {
        std::unique_ptr<DistanceComputer> dis(
                storage_distance_computer(storage));

#pragma omp for schedule(dynamic, 100)
        for (int node = 0; node < ntotal; ++node) {
            add_reverse_links(node, locks, *dis, graph);
        }
    } // omp parallel
}
/// Select at most R out-links for node `q` and store them in row q of
/// `graph`; unused slots are set to EMPTY_ID.
///
/// @param q          node being linked
/// @param pool       candidates collected by the graph search; the kNN
///                   neighbors of q that were not visited are appended, and
///                   the pool is sorted in place
/// @param dis        distance computer, used for distances between stored
///                   vectors
/// @param vt         visited table of the search that produced `pool`
/// @param knn_graph  input kNN graph
/// @param graph      output graph
///
/// Candidates are scanned by increasing distance (at most C of them); a
/// candidate is dropped when an already selected link is closer to it than
/// q is.
///
/// NOTE(review): only a leading occurrence of q in the sorted pool is
/// skipped; q could presumably still be selected from a later position
/// (e.g. with inner product) -- confirm whether this is intended.
void NSG::sync_prune(
        int q,
        std::vector<Node>& pool,
        DistanceComputer& dis,
        VisitedTable& vt,
        const nsg::Graph<idx_t>& knn_graph,
        nsg::Graph<Node>& graph) {
    // add the kNN neighbors of q that the search did not reach
    for (int i = 0; i < knn_graph.K; i++) {
        int id = knn_graph.at(q, i);
        if (id < 0 || id >= ntotal || vt.get(id)) {
            continue;
        }

        float dist = dis.symmetric_dis(q, id);
        pool.emplace_back(id, dist);
    }

    std::sort(pool.begin(), pool.end());

    std::vector<Node> result;

    const int pool_len = static_cast<int>(pool.size());
    int start = 0;
    // skip q itself if it comes first; the bounds checks protect against an
    // empty pool or a pool that only contains q
    if (start < pool_len && pool[start].id == q) {
        start++;
    }
    if (start < pool_len) {
        result.push_back(pool[start]);
    }

    while (result.size() < R && (++start) < pool_len && start < C) {
        auto& p = pool[start];
        bool occlude = false;
        for (int t = 0; t < result.size(); t++) {
            if (p.id == result[t].id) {
                occlude = true;
                break;
            }
            float djk = dis.symmetric_dis(result[t].id, p.id);
            if (djk < p.distance /* dik */) {
                occlude = true;
                break;
            }
        }
        if (!occlude) {
            result.push_back(p);
        }
    }

    for (size_t i = 0; i < R; i++) {
        if (i < result.size()) {
            graph.at(q, i).id = result[i].id;
            graph.at(q, i).distance = result[i].distance;
        } else {
            graph.at(q, i).id = EMPTY_ID;
        }
    }
}
// For every out-link q -> des stored in `graph`, try to add the reverse
// link des -> q.
//
// If row `des` still has a free slot, q is written into the first one.
// Otherwise the current links of des plus q are sorted by distance and
// pruned again with the same occlusion rule as in sync_prune, and the
// selected links are written back to the front of row `des`.
//
// `locks[des]` guards each read and each write of row `des`.
// NOTE(review): row q itself is read without holding locks[q], and the read
// of row des and its later write-back are two separate critical sections --
// confirm that concurrent updates in between are acceptable.
void NSG::add_reverse_links(
int q,
std::vector<std::mutex>& locks,
DistanceComputer& dis,
nsg::Graph<Node>& graph) {
for (size_t i = 0; i < R; i++) {
// links are packed at the front of the row: stop at the first empty slot
if (graph.at(q, i).id == EMPTY_ID) {
break;
}
// candidate reverse link: q, at the distance stored for q -> des
Node sn(q, graph.at(q, i).distance);
int des = graph.at(q, i).id;
std::vector<Node> tmp_pool;
int dup = 0;
{
// snapshot the current links of des
LockGuard guard(locks[des]);
for (int j = 0; j < R; j++) {
if (graph.at(des, j).id == EMPTY_ID) {
break;
}
if (q == graph.at(des, j).id) {
dup = 1;
break;
}
tmp_pool.push_back(graph.at(des, j));
}
}
// des already links to q
if (dup) {
continue;
}
tmp_pool.push_back(sn);
if (tmp_pool.size() > R) {
// row des is full: re-select at most R links among old links + q
std::vector<Node> result;
int start = 0;
std::sort(tmp_pool.begin(), tmp_pool.end());
result.push_back(tmp_pool[start]);
while (result.size() < R && (++start) < tmp_pool.size()) {
auto& p = tmp_pool[start];
bool occlude = false;
for (int t = 0; t < result.size(); t++) {
if (p.id == result[t].id) {
occlude = true;
break;
}
float djk = dis.symmetric_dis(result[t].id, p.id);
if (djk < p.distance /* dik */) {
occlude = true;
break;
}
}
if (!occlude) {
result.push_back(p);
}
}
{
LockGuard guard(locks[des]);
// only the first result.size() slots are rewritten; slots beyond that
// keep their previous content
for (int t = 0; t < result.size(); t++) {
graph.at(des, t) = result[t];
}
}
} else {
// there is room left: append q in the first empty slot
LockGuard guard(locks[des]);
for (int t = 0; t < R; t++) {
if (graph.at(des, t).id == EMPTY_ID) {
graph.at(des, t) = sn;
break;
}
}
}
}
}
/// Make every node reachable from the enter point.
///
/// Alternates a depth-first traversal that counts the reachable nodes with
/// the attachment of one unreached node, until all ntotal nodes have been
/// counted. Returns the number of attachments performed.
int NSG::tree_grow(Index* storage, std::vector<int>& degrees) {
    VisitedTable reached(ntotal);
    VisitedTable scratch(ntotal);

    int root = enterpoint;
    int num_attached = 0;

    for (int cnt = dfs(reached, root, 0); cnt < ntotal;
         cnt = dfs(reached, root, cnt)) {
        // the node that received the new link is the next traversal root
        root = attach_unlinked(storage, reached, scratch, degrees);
        scratch.advance();
        num_attached += 1;
    }

    return num_attached;
}
/// Iterative depth-first traversal of the final graph starting at `root`.
///
/// Every node visited for the first time is marked in `vt` and counted.
/// Returns `cnt` plus the number of newly visited nodes.
int NSG::dfs(VisitedTable& vt, int root, int cnt) const {
    // first neighbor of `from` that has not been visited, or EMPTY_ID
    auto first_unvisited = [&](int from) -> int {
        for (int i = 0; i < R; i++) {
            const int id = final_graph->at(from, i);
            if (id != EMPTY_ID && !vt.get(id)) {
                return id;
            }
        }
        return EMPTY_ID;
    };

    std::stack<int> path;
    path.push(root);
    if (!vt.get(root)) {
        cnt++;
    }
    vt.set(root);

    int node = root;
    while (!path.empty()) {
        const int next = first_unvisited(node);

        if (next != EMPTY_ID) {
            // go one level deeper
            node = next;
            vt.set(node);
            path.push(node);
            cnt++;
        } else {
            // dead end: backtrack
            path.pop();
            if (path.empty()) {
                break;
            }
            node = path.top();
        }
    }

    return cnt;
}
/// Attach one node that is not marked in `vt` to the graph and return the
/// node that received the new link (EMPTY_ID if every node is marked).
///
/// `vt2` is scratch space for the graph search; `degrees` holds the current
/// out-degree of every node and is updated.
int NSG::attach_unlinked(
        Index* storage,
        VisitedTable& vt,
        VisitedTable& vt2,
        std::vector<int>& degrees) {
    /* NOTE: This implementation is slightly different from the original paper.
     *
     * Instead of connecting the unlinked node to the nearest point in the
     * spanning tree which will increase the maximum degree of the graph and
     * also make the graph hard to maintain, this implementation links the
     * unlinked node to the nearest node of which the degree is smaller than R.
     * It will keep the degree of all nodes to be no more than `R`.
     */

    // first node that the traversal has not reached
    int id = EMPTY_ID;
    for (int i = 0; i < ntotal && id == EMPTY_ID; i++) {
        if (!vt.get(i)) {
            id = i;
        }
    }

    if (id == EMPTY_ID) {
        return EMPTY_ID; // No Unlinked Node
    }

    std::vector<Neighbor> tmp;
    std::vector<Node> pool;

    std::unique_ptr<DistanceComputer> dis(storage_distance_computer(storage));
    std::unique_ptr<float[]> vec(new float[storage->d]);

    storage->reconstruct(id, vec.get());
    dis->set_query(vec.get());

    // Collect the visited nodes into pool
    search_on_graph<true>(
            *final_graph, *dis, vt2, enterpoint, search_L, tmp, pool);

    std::sort(pool.begin(), pool.end());

    // closest visited node that still has a free slot
    int node = EMPTY_ID;
    for (const Node& candidate : pool) {
        if (degrees[candidate.id] < R && candidate.id != id) {
            node = candidate.id;
            break;
        }
    }

    // otherwise draw random nodes until an already reached one with a free
    // slot is found
    while (node == EMPTY_ID) {
        const int drawn = rng.rand_int(ntotal);
        if (vt.get(drawn) && degrees[drawn] < R && drawn != id) {
            node = drawn;
        }
    }

    // append the link node -> id in the first free slot of `node`
    const int pos = degrees[node];
    final_graph->at(node, pos) = id;
    degrees[node] += 1;

    return node;
}
void NSG::check_graph() const {
#pragma omp parallel for
for (int i = 0; i < ntotal; i++) {
for (int j = 0; j < R; j++) {
int id = final_graph->at(i, j);
FAISS_THROW_IF_NOT(id < ntotal && (id >= 0 || id == EMPTY_ID));
}
}
}
} // namespace faiss

197
faiss/impl/NSG.h 100644
View File

@ -0,0 +1,197 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#pragma once
#include <memory>
#include <mutex>
#include <vector>
#include <omp.h>
#include <faiss/Index.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/Heap.h>
#include <faiss/utils/random.h>
namespace faiss {
/** Implementation of the Navigating Spreading-out Graph (NSG)
* datastructure.
*
* Fast Approximate Nearest Neighbor Search With The
* Navigating Spreading-out Graph
*
* Cong Fu, Chao Xiang, Changxu Wang, Deng Cai, VLDB 2019
*
* This implementation is heavily influenced by the NSG
* implementation by ZJULearning Group
* (https://github.com/zjulearning/nsg)
*
* The NSG object stores only the neighbor link structure, see
* IndexNSG.h for the full index object.
*/
struct DistanceComputer; // from AuxIndexStructures
struct Neighbor;
struct Node;
namespace nsg {
/***********************************************************
 * Graph structure to store a graph.
 *
 * It is represented by an adjacency matrix `data`, where
 * data[i, j] is the j-th neighbor of node i.
 *
 * A Graph either wraps caller-owned memory (own_fields == false)
 * or owns its buffer (own_fields == true). Copies are always deep
 * and own their buffer, so two graphs never free the same memory.
 ***********************************************************/
template <class node_t>
struct Graph {
    node_t* data;    ///< the flattened adjacency matrix
    int K;           ///< nb of neighbors per node
    int N;           ///< total nb of nodes
    bool own_fields; ///< the underlying data owned by itself or not

    // construct from a known graph; `data` stays owned by the caller
    Graph(node_t* data, int N, int K)
            : data(data), K(K), N(N), own_fields(false) {}

    // construct an empty graph
    // NOTE: the newly allocated data needs to be destroyed at destruction time
    Graph(int N, int K) : K(K), N(N), own_fields(true) {
        // size computed in size_t: N * K can exceed the range of int
        data = new node_t[static_cast<size_t>(N) * K];
    }

    // deep copy: the new graph owns its own buffer
    Graph(const Graph& g) : Graph(g.N, g.K) {
        const size_t total = static_cast<size_t>(N) * K;
        for (size_t i = 0; i < total; i++) {
            data[i] = g.data[i];
        }
    }

    // deep copy assignment: afterwards this graph owns its own buffer
    Graph& operator=(const Graph& g) {
        if (this != &g) {
            const size_t total = static_cast<size_t>(g.N) * g.K;
            // allocate and fill first so that *this is untouched if the
            // allocation throws
            node_t* fresh = new node_t[total];
            for (size_t i = 0; i < total; i++) {
                fresh[i] = g.data[i];
            }
            if (own_fields) {
                delete[] data;
            }
            data = fresh;
            K = g.K;
            N = g.N;
            own_fields = true;
        }
        return *this;
    }

    // release the allocated memory if needed
    ~Graph() {
        if (own_fields) {
            delete[] data;
        }
    }

    // access the j-th neighbor of node i
    inline node_t at(int i, int j) const {
        return data[static_cast<size_t>(i) * K + j];
    }

    // access the j-th neighbor of node i by reference
    inline node_t& at(int i, int j) {
        return data[static_cast<size_t>(i) * K + j];
    }
};
DistanceComputer* storage_distance_computer(const Index* storage);
} // namespace nsg
/// Link structure of the Navigating Spreading-out Graph. It stores only the
/// neighbor lists; the vectors are accessed through the storage index and
/// the distance computers passed to the methods.
struct NSG {
/// internal storage of vectors (32 bits: this is expensive)
using storage_idx_t = int;
/// Faiss results are 64-bit
using idx_t = Index::idx_t;
/// Marker for an unused neighbor slot.
/// It needs to be smaller than 0
static const int EMPTY_ID;
int ntotal; ///< nb of nodes
/// construction-time parameters
int R; ///< nb of neighbors per node
int L; ///< length of the search path at construction time
int C; ///< candidate pool size at construction time
// search-time parameters
int search_L; ///< length of the search path
int enterpoint; ///< enterpoint
std::shared_ptr<nsg::Graph<int>> final_graph; ///< NSG graph structure
bool is_built; ///< NSG is built or not
RandomGenerator rng; ///< random generator
/// L and C are derived from R by the constructor
explicit NSG(int R = 32);
// build NSG from a KNN graph
// (throws if the NSG is already built)
void build(
Index* storage,
idx_t n,
const nsg::Graph<idx_t>& knn_graph,
bool verbose);
// reset the graph
void reset();
// search interface
// writes the k best ids / distances to I / D, sorted by increasing
// distance; `dis` must already hold the query, `vt` is scratch space
void search(
DistanceComputer& dis,
int k,
idx_t* I,
float* D,
VisitedTable& vt) const;
// Compute the center point
// and set the enterpoint to the node found closest to it
void init_graph(Index* storage, const nsg::Graph<idx_t>& knn_graph);
// Search on a built graph.
// If collect_fullset is true, the visited nodes will be
// collected in `fullset`.
template <bool collect_fullset, class index_t>
void search_on_graph(
const nsg::Graph<index_t>& graph,
DistanceComputer& dis,
VisitedTable& vt,
int ep,
int pool_size,
std::vector<Neighbor>& retset,
std::vector<Node>& fullset) const;
// Add reverse links
// for each link q -> p of `graph`, try to insert p -> q; `locks` holds one
// mutex per node
void add_reverse_links(
int q,
std::vector<std::mutex>& locks,
DistanceComputer& dis,
nsg::Graph<Node>& graph);
// Select at most R links for node q among the candidates in `pool` and
// write them to row q of `graph`
void sync_prune(
int q,
std::vector<Node>& pool,
DistanceComputer& dis,
VisitedTable& vt,
const nsg::Graph<idx_t>& knn_graph,
nsg::Graph<Node>& graph);
// Compute the links of all nodes (pruning, then reverse links)
void link(
Index* storage,
const nsg::Graph<idx_t>& knn_graph,
nsg::Graph<Node>& graph,
bool verbose);
// make NSG be fully connected
// returns the number of attachments performed
int tree_grow(Index* storage, std::vector<int>& degrees);
// count the size of the connected component
// using depth first search start by root
// (`cnt` is the running count, the updated count is returned)
int dfs(VisitedTable& vt, int root, int cnt) const;
// attach one unlinked node
// returns the node that received the new link
int attach_unlinked(
Index* storage,
VisitedTable& vt,
VisitedTable& vt2,
std::vector<int>& degrees);
// check the integrity of the NSG built
void check_graph() const;
};
} // namespace faiss

View File

@ -158,6 +158,19 @@ handle_Quantizer(ProductQuantizer)
handle_Quantizer(ScalarQuantizer)
def handle_NSG(the_class):
    """Replace the `build` method of `the_class` with a numpy-friendly
    wrapper around the SWIG-generated one."""

    def replacement_build(self, x, graph):
        """Build the index from the vectors `x` (2D array) and the kNN
        graph `graph` (2D array with one row of neighbor ids per vector)."""
        num_vectors, dim = x.shape
        assert dim == self.d
        assert graph.ndim == 2
        assert graph.shape[0] == num_vectors
        num_neighbors = graph.shape[1]
        self.build_c(
            num_vectors, swig_ptr(x), swig_ptr(graph), num_neighbors)

    replace_method(the_class, 'build', replacement_build)
def handle_Index(the_class):
def replacement_add(self, x):
@ -691,6 +704,9 @@ for symbol in dir(this_module):
if issubclass(the_class, ParameterSpace):
handle_ParameterSpace(the_class)
if issubclass(the_class, IndexNSG):
handle_NSG(the_class)
###########################################
# Add Python references to objects

View File

@ -91,6 +91,9 @@ typedef uint64_t size_t;
#include <faiss/impl/NNDescent.h>
#include <faiss/IndexNNDescent.h>
#include <faiss/impl/NSG.h>
#include <faiss/IndexNSG.h>
#include <faiss/MetaIndexes.h>
#include <faiss/IndexRefine.h>
@ -406,6 +409,8 @@ void gpu_sync_all_devices()
%include <faiss/IndexNNDescent.h>
%include <faiss/IndexIVFFlat.h>
%include <faiss/impl/NSG.h>
%include <faiss/IndexNSG.h>
#ifndef SWIGWIN
%warnfilter(401) faiss::OnDiskInvertedListsIOHook;
@ -552,6 +557,7 @@ void gpu_sync_all_devices()
DOWNCAST ( IndexHNSWSQ )
DOWNCAST ( IndexHNSW2Level )
DOWNCAST ( IndexNNDescentFlat )
DOWNCAST ( IndexNSGFlat )
DOWNCAST ( Index2Layer )
#ifdef GPU_WRAPPER
DOWNCAST_GPU ( GpuIndexIVFPQ )

View File

@ -592,6 +592,147 @@ class TestHNSW(unittest.TestCase):
assert np.allclose(Dref[mask, 0], Dhnsw[mask, 0])
class TestNSG(unittest.TestCase):
    """Tests for IndexNSGFlat: construction with add() and build(), 1-NN
    recall against a flat index, connectivity of the graph, and reset()."""

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        d = 32
        nt = 0
        nb = 1500
        nq = 500
        self.GK = 32

        _, self.xb, self.xq = get_dataset_2(d, nt, nb, nq)

    def make_knn_graph(self, metric):
        """Return an exact (n, GK) kNN graph of the database vectors,
        without self-references."""
        n = self.xb.shape[0]
        d = self.xb.shape[1]
        index = faiss.IndexFlat(d, metric)
        index.add(self.xb)
        _, I = index.search(self.xb, self.GK + 1)
        knn_graph = np.zeros((n, self.GK), dtype=np.int64)

        # For the inner product distance, the distance between a vector and
        # itself may not be the smallest, so it is not guaranteed that
        # I[:, 0] is the query itself.
        for i in range(n):
            cnt = 0
            for j in range(self.GK + 1):
                if I[i, j] != i:
                    knn_graph[i, cnt] = I[i, j]
                    cnt += 1
                if cnt == self.GK:
                    break
        return knn_graph

    def subtest_connectivity(self, index, nb):
        """All nb nodes must be reachable from the enter point."""
        vt = faiss.VisitedTable(nb)
        count = index.nsg.dfs(vt, index.nsg.enterpoint, 0)
        self.assertEqual(count, nb)

    def _reference_labels(self, metric):
        """Exact 1-NN labels of the queries, computed with a flat index."""
        d = self.xq.shape[1]
        flat_index = faiss.IndexFlat(d, metric)
        flat_index.add(self.xb)
        _, Iref = flat_index.search(self.xq, 1)
        return Iref

    def _check_recall_and_connectivity(self, index, Iref, thresh, metric):
        """Search the queries on `index`, require at least `thresh` exact
        1-NN matches with `Iref`, then check the graph connectivity."""
        metrics = {faiss.METRIC_L2: 'L2',
                   faiss.METRIC_INNER_PRODUCT: 'IP'}
        _, Insg = index.search(self.xq, 1)
        recalls = (Iref == Insg).sum()
        print('metric: {}, nb equal: {}'.format(metrics[metric], recalls))
        self.assertGreaterEqual(recalls, thresh)
        self.subtest_connectivity(index, self.xb.shape[0])

    def subtest_add(self, build_type, thresh, metric=faiss.METRIC_L2):
        """Build the index with add() using the given build_type."""
        d = self.xq.shape[1]
        Iref = self._reference_labels(metric)

        index = faiss.IndexNSGFlat(d, 16, metric)
        index.verbose = True
        index.build_type = build_type
        index.GK = self.GK
        index.add(self.xb)

        self._check_recall_and_connectivity(index, Iref, thresh, metric)

    def subtest_build(self, knn_graph, thresh, metric=faiss.METRIC_L2):
        """Build the index with build() from a user-provided kNN graph."""
        d = self.xq.shape[1]
        Iref = self._reference_labels(metric)

        index = faiss.IndexNSGFlat(d, 16, metric)
        index.verbose = True
        index.build(self.xb, knn_graph)

        self._check_recall_and_connectivity(index, Iref, thresh, metric)

    def test_add_bruteforce_L2(self):
        self.subtest_add(0, 475, faiss.METRIC_L2)

    def test_add_nndescent_L2(self):
        self.subtest_add(1, 475, faiss.METRIC_L2)

    def test_add_bruteforce_IP(self):
        self.subtest_add(0, 480, faiss.METRIC_INNER_PRODUCT)

    def test_add_nndescent_IP(self):
        self.subtest_add(1, 480, faiss.METRIC_INNER_PRODUCT)

    def test_build_L2(self):
        knn_graph = self.make_knn_graph(faiss.METRIC_L2)
        self.subtest_build(knn_graph, 475, faiss.METRIC_L2)

    def test_build_IP(self):
        knn_graph = self.make_knn_graph(faiss.METRIC_INNER_PRODUCT)
        self.subtest_build(knn_graph, 480, faiss.METRIC_INNER_PRODUCT)

    def test_build_invalid_knng(self):
        """Make some invalid entries in the input knn graph.

        It would cause a warning but IndexNSG should be able
        to handle this.
        """
        knn_graph = self.make_knn_graph(faiss.METRIC_L2)
        knn_graph[:100, 5] = -111
        self.subtest_build(knn_graph, 475, faiss.METRIC_L2)

        knn_graph = self.make_knn_graph(faiss.METRIC_INNER_PRODUCT)
        knn_graph[:100, 5] = -111
        self.subtest_build(knn_graph, 480, faiss.METRIC_INNER_PRODUCT)

    def test_reset(self):
        """test IndexNSG.reset()"""
        d = self.xq.shape[1]
        metric = faiss.METRIC_L2
        Iref = self._reference_labels(metric)

        index = faiss.IndexNSGFlat(d, 16)
        index.verbose = True
        index.GK = 32

        index.add(self.xb)
        self._check_recall_and_connectivity(index, Iref, 475, metric)

        # after a reset the index must be buildable again, with the same
        # quality
        index.reset()
        index.add(self.xb)
        self._check_recall_and_connectivity(index, Iref, 475, metric)
class TestDistancesPositive(unittest.TestCase):

View File

@ -419,7 +419,7 @@ class TestNNDescent(unittest.TestCase):
index.nndescent.iter = 5
index.verbose = False
index.nndescent.search_L = search_L;
index.nndescent.search_L = search_L
index.add(xb)
D, I = index.search(xq, topk)