331 lines
9.2 KiB
C++
331 lines
9.2 KiB
C++
/**
|
|
* Copyright (c) 2015-present, Facebook, Inc.
|
|
* All rights reserved.
|
|
*
|
|
* This source code is licensed under the BSD+Patents license found in the
|
|
* LICENSE file in the root directory of this source tree.
|
|
*/
|
|
|
|
// -*- c++ -*-
|
|
|
|
#include "IndexShards.h"
|
|
|
|
#include <cstdio>
|
|
#include <functional>
|
|
|
|
#include "FaissAssert.h"
|
|
#include "Heap.h"
|
|
#include "WorkerThread.h"
|
|
|
|
namespace faiss {
|
|
|
|
// subroutines
|
|
namespace {
|
|
|
|
typedef Index::idx_t idx_t;
|
|
|
|
|
|
// add translation to all valid labels
|
|
void translate_labels (long n, idx_t *labels, long translation)
|
|
{
|
|
if (translation == 0) return;
|
|
for (long i = 0; i < n; i++) {
|
|
if(labels[i] < 0) continue;
|
|
labels[i] += translation;
|
|
}
|
|
}
|
|
|
|
|
|
/** merge result tables from several shards.
 *
 * Each shard produced its own n * k result table; this routine merges
 * them row by row into a single n * k table, keeping the k best
 * entries per query according to comparator C (CMin for L2, CMax for
 * inner product) and applying a per-shard label offset.
 *
 * @param all_distances   size nshard * n * k, shard-major
 * @param all_labels      idem
 * @param translations    label translations to apply, size nshard
 */

template <class IndexClass, class C>
void merge_tables (long n, long k, long nshard,
                   typename IndexClass::distance_t *distances,
                   idx_t *labels,
                   const typename IndexClass::distance_t *all_distances,
                   idx_t *all_labels,
                   const long *translations)
{
    // nothing to merge when no results were requested
    if(k == 0) {
        return;
    }
    using distance_t = typename IndexClass::distance_t;

    // distance from one shard's row i to the next shard's row i
    long stride = n * k;
#pragma omp parallel
    {
        // per-thread scratch: pointer[s] is the read cursor into shard
        // s's current row; shard_ids holds the heap payload
        std::vector<int> buf (2 * nshard);
        int * pointer = buf.data();
        int * shard_ids = pointer + nshard;
        std::vector<distance_t> buf2 (nshard);
        distance_t * heap_vals = buf2.data();
#pragma omp for
        for (long i = 0; i < n; i++) {
            // the heap maps values to the shard where they are
            // produced.
            const distance_t *D_in = all_distances + i * k;
            const idx_t *I_in = all_labels + i * k;
            int heap_size = 0;

            // seed the heap with the first (best) entry of each shard;
            // negative labels mean the shard returned nothing there
            for (long s = 0; s < nshard; s++) {
                pointer[s] = 0;
                if (I_in[stride * s] >= 0)
                    heap_push<C> (++heap_size, heap_vals, shard_ids,
                                 D_in[stride * s], s);
            }

            distance_t *D = distances + i * k;
            idx_t *I = labels + i * k;

            // k-way merge: repeatedly pop the globally best entry and
            // advance the winning shard's cursor
            for (int j = 0; j < k; j++) {
                if (heap_size == 0) {
                    // all shards exhausted: pad with sentinels
                    I[j] = -1;
                    D[j] = C::neutral();
                } else {
                    // pop best element
                    int s = shard_ids[0];
                    int & p = pointer[s];
                    D[j] = heap_vals[0];
                    // remap the shard-local label to the global id space
                    I[j] = I_in[stride * s + p] + translations[s];

                    heap_pop<C> (heap_size--, heap_vals, shard_ids);
                    p++;
                    // refill from the same shard if it has more valid hits
                    if (p < k && I_in[stride * s + p] >= 0)
                        heap_push<C> (++heap_size, heap_vals, shard_ids,
                                     D_in[stride * s + p], s);
                }
            }
        }
    }
}
|
|
|
|
template<class IndexClass>
|
|
void runOnIndexes(bool threaded,
|
|
std::function<void(int no, IndexClass*)> f,
|
|
std::vector<IndexClass *> indexes)
|
|
{
|
|
FAISS_THROW_IF_NOT_MSG(!indexes.empty(), "no shards in index");
|
|
|
|
if (!threaded) {
|
|
for (int no = 0; no < indexes.size(); no++) {
|
|
IndexClass *index = indexes[no];
|
|
f(no, index);
|
|
}
|
|
} else {
|
|
std::vector<std::unique_ptr<WorkerThread> > threads;
|
|
std::vector<std::future<bool>> v;
|
|
|
|
for (int no = 0; no < indexes.size(); no++) {
|
|
IndexClass *index = indexes[no];
|
|
threads.emplace_back(new WorkerThread());
|
|
WorkerThread *wt = threads.back().get();
|
|
v.emplace_back(wt->add([no, index, f](){ f(no, index); }));
|
|
}
|
|
|
|
// Blocking wait for completion
|
|
for (auto& func : v) {
|
|
func.get();
|
|
}
|
|
}
|
|
|
|
};
|
|
|
|
} // anonymous namespace
|
|
|
|
|
|
// Construct an empty sharded index of dimension d. Shards are not
// owned by default (own_fields = false); `threaded` selects parallel
// execution of per-shard operations, and `successive_ids` selects
// automatic id shifting on add (see add_with_ids).
template<class IndexClass>
IndexShardsTemplate<IndexClass>::IndexShardsTemplate (idx_t d, bool threaded, bool successive_ids):
    IndexClass (d), own_fields (false),
    threaded (threaded), successive_ids (successive_ids)
{
}
|
|
|
|
|
|
template<class IndexClass>
|
|
void IndexShardsTemplate<IndexClass>::add_shard (IndexClass *idx)
|
|
{
|
|
shard_indexes.push_back (idx);
|
|
sync_with_shard_indexes ();
|
|
}
|
|
|
|
template<class IndexClass>
|
|
void IndexShardsTemplate<IndexClass>::sync_with_shard_indexes ()
|
|
{
|
|
if (shard_indexes.empty()) return;
|
|
IndexClass * index0 = shard_indexes[0];
|
|
this->d = index0->d;
|
|
this->metric_type = index0->metric_type;
|
|
this->is_trained = index0->is_trained;
|
|
this->ntotal = index0->ntotal;
|
|
for (int i = 1; i < shard_indexes.size(); i++) {
|
|
IndexClass * index = shard_indexes[i];
|
|
FAISS_THROW_IF_NOT (this->metric_type == index->metric_type);
|
|
FAISS_THROW_IF_NOT (this->d == index->d);
|
|
this->ntotal += index->ntotal;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template<class IndexClass>
|
|
void IndexShardsTemplate<IndexClass>::train (idx_t n, const component_t *x)
|
|
{
|
|
auto train_func = [n, x](int no, IndexClass *index)
|
|
{
|
|
if (index->verbose)
|
|
printf ("begin train shard %d on %ld points\n", no, n);
|
|
index->train(n, x);
|
|
if (index->verbose)
|
|
printf ("end train shard %d\n", no);
|
|
};
|
|
|
|
runOnIndexes<IndexClass> (threaded, train_func, shard_indexes);
|
|
sync_with_shard_indexes ();
|
|
}
|
|
|
|
// Add n vectors with implicit ids: delegates to add_with_ids with
// xids = nullptr (ids are then either shifted per successive_ids or
// generated from the running ntotal).
template<class IndexClass>
void IndexShardsTemplate<IndexClass>::add (idx_t n, const component_t *x)
{
    add_with_ids (n, x, nullptr);
}
|
|
|
|
template<class IndexClass>
|
|
void IndexShardsTemplate<IndexClass>::add_with_ids (idx_t n, const component_t * x, const idx_t *xids)
|
|
{
|
|
|
|
FAISS_THROW_IF_NOT_MSG(!(successive_ids && xids),
|
|
"It makes no sense to pass in ids and "
|
|
"request them to be shifted");
|
|
|
|
if (successive_ids) {
|
|
FAISS_THROW_IF_NOT_MSG(!xids,
|
|
"It makes no sense to pass in ids and "
|
|
"request them to be shifted");
|
|
FAISS_THROW_IF_NOT_MSG(this->ntotal == 0,
|
|
"when adding to IndexShards with sucessive_ids, "
|
|
"only add() in a single pass is supported");
|
|
}
|
|
|
|
long nshard = shard_indexes.size();
|
|
const idx_t *ids = xids;
|
|
ScopeDeleter<idx_t> del;
|
|
if (!ids && !successive_ids) {
|
|
idx_t *aids = new idx_t[n];
|
|
for (idx_t i = 0; i < n; i++)
|
|
aids[i] = this->ntotal + i;
|
|
ids = aids;
|
|
del.set (ids);
|
|
}
|
|
|
|
size_t components_per_vec =
|
|
sizeof(component_t) == 1 ? (this->d + 7) / 8 : this->d;
|
|
|
|
auto add_func = [n, ids, x, nshard, components_per_vec]
|
|
(int no, IndexClass *index) {
|
|
|
|
idx_t i0 = no * n / nshard;
|
|
idx_t i1 = (no + 1) * n / nshard;
|
|
|
|
auto x0 = x + i0 * components_per_vec;
|
|
|
|
if (index->verbose) {
|
|
printf ("begin add shard %d on %ld points\n", no, n);
|
|
}
|
|
if (ids) {
|
|
index->add_with_ids (i1 - i0, x0, ids + i0);
|
|
} else {
|
|
index->add (i1 - i0, x0);
|
|
}
|
|
if (index->verbose) {
|
|
printf ("end add shard %d on %ld points\n", no, i1 - i0);
|
|
}
|
|
};
|
|
|
|
runOnIndexes<IndexClass> (threaded, add_func, shard_indexes);
|
|
|
|
this->ntotal += n;
|
|
}
|
|
|
|
template<class IndexClass>
|
|
void IndexShardsTemplate<IndexClass>::reset ()
|
|
{
|
|
for (int i = 0; i < shard_indexes.size(); i++) {
|
|
shard_indexes[i]->reset ();
|
|
}
|
|
sync_with_shard_indexes ();
|
|
}
|
|
|
|
template<class IndexClass>
|
|
void IndexShardsTemplate<IndexClass>::search (
|
|
idx_t n, const component_t *x, idx_t k,
|
|
distance_t *distances, idx_t *labels) const
|
|
{
|
|
long nshard = shard_indexes.size();
|
|
distance_t *all_distances = new distance_t [nshard * k * n];
|
|
idx_t *all_labels = new idx_t [nshard * k * n];
|
|
ScopeDeleter<distance_t> del (all_distances);
|
|
ScopeDeleter<idx_t> del2 (all_labels);
|
|
|
|
auto query_func = [n, k, x, all_distances, all_labels]
|
|
(int no, IndexClass *index) {
|
|
|
|
if (index->verbose) {
|
|
printf ("begin query shard %d on %ld points\n", no, n);
|
|
}
|
|
index->search (n, x, k,
|
|
all_distances + no * k * n,
|
|
all_labels + no * k * n);
|
|
if (index->verbose) {
|
|
printf ("end query shard %d\n", no);
|
|
}
|
|
};
|
|
|
|
runOnIndexes<IndexClass> (threaded, query_func, shard_indexes);
|
|
|
|
std::vector<long> translations (nshard, 0);
|
|
if (successive_ids) {
|
|
translations[0] = 0;
|
|
for (int s = 0; s + 1 < nshard; s++)
|
|
translations [s + 1] = translations [s] +
|
|
shard_indexes [s]->ntotal;
|
|
}
|
|
|
|
if (this->metric_type == METRIC_L2) {
|
|
merge_tables<IndexClass, CMin<distance_t, int> > (
|
|
n, k, nshard, distances, labels,
|
|
all_distances, all_labels, translations.data ());
|
|
} else {
|
|
merge_tables<IndexClass, CMax<distance_t, int> > (
|
|
n, k, nshard, distances, labels,
|
|
all_distances, all_labels, translations.data ());
|
|
}
|
|
|
|
}
|
|
|
|
|
|
template<class IndexClass>
|
|
IndexShardsTemplate<IndexClass>::~IndexShardsTemplate ()
|
|
{
|
|
if (own_fields) {
|
|
for (int s = 0; s < shard_indexes.size(); s++)
|
|
delete shard_indexes [s];
|
|
}
|
|
}
|
|
|
|
// explicit instantiations: the template is compiled here once for the
// float (Index) and binary (IndexBinary) variants so other translation
// units only need the declarations
template struct IndexShardsTemplate<Index>;
template struct IndexShardsTemplate<IndexBinary>;
|
|
|
|
|
|
|
|
} // namespace faiss
|