2019-03-29 23:32:28 +08:00
|
|
|
/**
|
2019-05-28 22:17:22 +08:00
|
|
|
* Copyright (c) Facebook, Inc. and its affiliates.
|
2019-03-29 23:32:28 +08:00
|
|
|
*
|
2019-05-28 22:17:22 +08:00
|
|
|
* This source code is licensed under the MIT license found in the
|
2019-03-29 23:32:28 +08:00
|
|
|
* LICENSE file in the root directory of this source tree.
|
|
|
|
*/
|
|
|
|
|
|
|
|
// -*- c++ -*-
|
|
|
|
|
2019-09-21 00:59:10 +08:00
|
|
|
#include <faiss/IndexShards.h>
|
2019-03-29 23:32:28 +08:00
|
|
|
|
|
|
|
#include <cstdio>
|
|
|
|
#include <functional>
|
|
|
|
|
2019-09-21 00:59:10 +08:00
|
|
|
#include <faiss/impl/FaissAssert.h>
|
|
|
|
#include <faiss/utils/Heap.h>
|
|
|
|
#include <faiss/utils/WorkerThread.h>
|
2019-03-29 23:32:28 +08:00
|
|
|
|
|
|
|
namespace faiss {
|
|
|
|
|
|
|
|
// subroutines
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
typedef Index::idx_t idx_t;
|
|
|
|
|
|
|
|
|
|
|
|
// add translation to all valid labels
|
|
|
|
void translate_labels (long n, idx_t *labels, long translation)
|
|
|
|
{
|
|
|
|
if (translation == 0) return;
|
|
|
|
for (long i = 0; i < n; i++) {
|
|
|
|
if(labels[i] < 0) continue;
|
|
|
|
labels[i] += translation;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/** merge result tables from several shards.
|
|
|
|
* @param all_distances size nshard * n * k
|
|
|
|
* @param all_labels idem
|
|
|
|
* @param translartions label translations to apply, size nshard
|
|
|
|
*/
|
|
|
|
|
|
|
|
template <class IndexClass, class C>
|
2019-05-28 22:17:22 +08:00
|
|
|
void
|
|
|
|
merge_tables(long n, long k, long nshard,
|
|
|
|
typename IndexClass::distance_t *distances,
|
|
|
|
idx_t *labels,
|
|
|
|
const std::vector<typename IndexClass::distance_t>& all_distances,
|
|
|
|
const std::vector<idx_t>& all_labels,
|
|
|
|
const std::vector<long>& translations) {
|
|
|
|
if (k == 0) {
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
using distance_t = typename IndexClass::distance_t;
|
|
|
|
|
|
|
|
long stride = n * k;
|
2019-03-29 23:32:28 +08:00
|
|
|
#pragma omp parallel
|
2019-05-28 22:17:22 +08:00
|
|
|
{
|
|
|
|
std::vector<int> buf (2 * nshard);
|
|
|
|
int * pointer = buf.data();
|
|
|
|
int * shard_ids = pointer + nshard;
|
|
|
|
std::vector<distance_t> buf2 (nshard);
|
|
|
|
distance_t * heap_vals = buf2.data();
|
2019-03-29 23:32:28 +08:00
|
|
|
#pragma omp for
|
2019-05-28 22:17:22 +08:00
|
|
|
for (long i = 0; i < n; i++) {
|
|
|
|
// the heap maps values to the shard where they are
|
|
|
|
// produced.
|
|
|
|
const distance_t *D_in = all_distances.data() + i * k;
|
|
|
|
const idx_t *I_in = all_labels.data() + i * k;
|
|
|
|
int heap_size = 0;
|
|
|
|
|
|
|
|
for (long s = 0; s < nshard; s++) {
|
|
|
|
pointer[s] = 0;
|
|
|
|
if (I_in[stride * s] >= 0) {
|
|
|
|
heap_push<C> (++heap_size, heap_vals, shard_ids,
|
|
|
|
D_in[stride * s], s);
|
2019-03-29 23:32:28 +08:00
|
|
|
}
|
2019-05-28 22:17:22 +08:00
|
|
|
}
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
distance_t *D = distances + i * k;
|
|
|
|
idx_t *I = labels + i * k;
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
for (int j = 0; j < k; j++) {
|
|
|
|
if (heap_size == 0) {
|
|
|
|
I[j] = -1;
|
|
|
|
D[j] = C::neutral();
|
|
|
|
} else {
|
|
|
|
// pop best element
|
|
|
|
int s = shard_ids[0];
|
|
|
|
int & p = pointer[s];
|
|
|
|
D[j] = heap_vals[0];
|
|
|
|
I[j] = I_in[stride * s + p] + translations[s];
|
|
|
|
|
|
|
|
heap_pop<C> (heap_size--, heap_vals, shard_ids);
|
|
|
|
p++;
|
|
|
|
if (p < k && I_in[stride * s + p] >= 0) {
|
|
|
|
heap_push<C> (++heap_size, heap_vals, shard_ids,
|
|
|
|
D_in[stride * s + p], s);
|
|
|
|
}
|
2019-03-29 23:32:28 +08:00
|
|
|
}
|
2019-05-28 22:17:22 +08:00
|
|
|
}
|
2019-03-29 23:32:28 +08:00
|
|
|
}
|
2019-05-28 22:17:22 +08:00
|
|
|
}
|
|
|
|
}
|
2019-03-29 23:32:28 +08:00
|
|
|
|
|
|
|
} // anonymous namespace
|
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
template <typename IndexT>
|
|
|
|
IndexShardsTemplate<IndexT>::IndexShardsTemplate(idx_t d,
|
|
|
|
bool threaded,
|
|
|
|
bool successive_ids)
|
|
|
|
: ThreadedIndex<IndexT>(d, threaded),
|
|
|
|
successive_ids(successive_ids) {
|
|
|
|
}
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
template <typename IndexT>
|
|
|
|
IndexShardsTemplate<IndexT>::IndexShardsTemplate(int d,
|
|
|
|
bool threaded,
|
|
|
|
bool successive_ids)
|
|
|
|
: ThreadedIndex<IndexT>(d, threaded),
|
|
|
|
successive_ids(successive_ids) {
|
2019-03-29 23:32:28 +08:00
|
|
|
}
|
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
template <typename IndexT>
|
|
|
|
IndexShardsTemplate<IndexT>::IndexShardsTemplate(bool threaded,
|
|
|
|
bool successive_ids)
|
|
|
|
: ThreadedIndex<IndexT>(threaded),
|
|
|
|
successive_ids(successive_ids) {
|
|
|
|
}
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
template <typename IndexT>
|
|
|
|
void
|
|
|
|
IndexShardsTemplate<IndexT>::onAfterAddIndex(IndexT* index /* unused */) {
|
|
|
|
sync_with_shard_indexes();
|
2019-03-29 23:32:28 +08:00
|
|
|
}
|
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
template <typename IndexT>
|
|
|
|
void
|
|
|
|
IndexShardsTemplate<IndexT>::onAfterRemoveIndex(IndexT* index /* unused */) {
|
|
|
|
sync_with_shard_indexes();
|
2019-03-29 23:32:28 +08:00
|
|
|
}
|
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
template <typename IndexT>
|
|
|
|
void
|
|
|
|
IndexShardsTemplate<IndexT>::sync_with_shard_indexes() {
|
|
|
|
if (!this->count()) {
|
|
|
|
this->is_trained = false;
|
|
|
|
this->ntotal = 0;
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
return;
|
|
|
|
}
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
auto firstIndex = this->at(0);
|
|
|
|
this->metric_type = firstIndex->metric_type;
|
|
|
|
this->is_trained = firstIndex->is_trained;
|
|
|
|
this->ntotal = firstIndex->ntotal;
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
for (int i = 1; i < this->count(); ++i) {
|
|
|
|
auto index = this->at(i);
|
|
|
|
FAISS_THROW_IF_NOT(this->metric_type == index->metric_type);
|
|
|
|
FAISS_THROW_IF_NOT(this->d == index->d);
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
this->ntotal += index->ntotal;
|
|
|
|
}
|
|
|
|
}
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
template <typename IndexT>
|
|
|
|
void
|
|
|
|
IndexShardsTemplate<IndexT>::train(idx_t n,
|
|
|
|
const component_t *x) {
|
|
|
|
auto fn =
|
|
|
|
[n, x](int no, IndexT *index) {
|
|
|
|
if (index->verbose) {
|
|
|
|
printf("begin train shard %d on %ld points\n", no, n);
|
|
|
|
}
|
|
|
|
|
|
|
|
index->train(n, x);
|
|
|
|
|
|
|
|
if (index->verbose) {
|
|
|
|
printf("end train shard %d\n", no);
|
|
|
|
}
|
2019-03-29 23:32:28 +08:00
|
|
|
};
|
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
this->runOnIndex(fn);
|
|
|
|
sync_with_shard_indexes();
|
2019-03-29 23:32:28 +08:00
|
|
|
}
|
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
template <typename IndexT>
|
|
|
|
void
|
|
|
|
IndexShardsTemplate<IndexT>::add(idx_t n,
|
|
|
|
const component_t *x) {
|
|
|
|
add_with_ids(n, x, nullptr);
|
2019-03-29 23:32:28 +08:00
|
|
|
}
|
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
template <typename IndexT>
|
|
|
|
void
|
|
|
|
IndexShardsTemplate<IndexT>::add_with_ids(idx_t n,
|
|
|
|
const component_t * x,
|
|
|
|
const idx_t *xids) {
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
FAISS_THROW_IF_NOT_MSG(!(successive_ids && xids),
|
|
|
|
"It makes no sense to pass in ids and "
|
|
|
|
"request them to be shifted");
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
if (successive_ids) {
|
|
|
|
FAISS_THROW_IF_NOT_MSG(!xids,
|
|
|
|
"It makes no sense to pass in ids and "
|
|
|
|
"request them to be shifted");
|
|
|
|
FAISS_THROW_IF_NOT_MSG(this->ntotal == 0,
|
|
|
|
"when adding to IndexShards with sucessive_ids, "
|
|
|
|
"only add() in a single pass is supported");
|
|
|
|
}
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
idx_t nshard = this->count();
|
|
|
|
const idx_t *ids = xids;
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
std::vector<idx_t> aids;
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
if (!ids && !successive_ids) {
|
|
|
|
aids.resize(n);
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
for (idx_t i = 0; i < n; i++) {
|
|
|
|
aids[i] = this->ntotal + i;
|
|
|
|
}
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
ids = aids.data();
|
|
|
|
}
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
size_t components_per_vec =
|
|
|
|
sizeof(component_t) == 1 ? (this->d + 7) / 8 : this->d;
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
auto fn =
|
|
|
|
[n, ids, x, nshard, components_per_vec](int no, IndexT *index) {
|
|
|
|
idx_t i0 = (idx_t) no * n / nshard;
|
|
|
|
idx_t i1 = ((idx_t) no + 1) * n / nshard;
|
|
|
|
auto x0 = x + i0 * components_per_vec;
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
if (index->verbose) {
|
|
|
|
printf ("begin add shard %d on %ld points\n", no, n);
|
|
|
|
}
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
if (ids) {
|
|
|
|
index->add_with_ids (i1 - i0, x0, ids + i0);
|
|
|
|
} else {
|
|
|
|
index->add (i1 - i0, x0);
|
|
|
|
}
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
if (index->verbose) {
|
|
|
|
printf ("end add shard %d on %ld points\n", no, i1 - i0);
|
|
|
|
}
|
2019-03-29 23:32:28 +08:00
|
|
|
};
|
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
this->runOnIndex(fn);
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
// This is safe to do here because the current thread controls execution in
|
|
|
|
// all threads, and nothing else is happening
|
|
|
|
this->ntotal += n;
|
|
|
|
}
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
template <typename IndexT>
|
|
|
|
void
|
|
|
|
IndexShardsTemplate<IndexT>::search(idx_t n,
|
|
|
|
const component_t *x,
|
|
|
|
idx_t k,
|
|
|
|
distance_t *distances,
|
|
|
|
idx_t *labels) const {
|
|
|
|
long nshard = this->count();
|
|
|
|
|
|
|
|
std::vector<distance_t> all_distances(nshard * k * n);
|
|
|
|
std::vector<idx_t> all_labels(nshard * k * n);
|
|
|
|
|
|
|
|
auto fn =
|
|
|
|
[n, k, x, &all_distances, &all_labels](int no, const IndexT *index) {
|
|
|
|
if (index->verbose) {
|
|
|
|
printf ("begin query shard %d on %ld points\n", no, n);
|
|
|
|
}
|
|
|
|
|
|
|
|
index->search (n, x, k,
|
|
|
|
all_distances.data() + no * k * n,
|
|
|
|
all_labels.data() + no * k * n);
|
|
|
|
|
|
|
|
if (index->verbose) {
|
|
|
|
printf ("end query shard %d\n", no);
|
|
|
|
}
|
|
|
|
};
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
this->runOnIndex(fn);
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
std::vector<long> translations(nshard, 0);
|
2019-03-29 23:32:28 +08:00
|
|
|
|
2019-05-28 22:17:22 +08:00
|
|
|
// Because we just called runOnIndex above, it is safe to access the sub-index
|
|
|
|
// ntotal here
|
|
|
|
if (successive_ids) {
|
|
|
|
translations[0] = 0;
|
|
|
|
|
|
|
|
for (int s = 0; s + 1 < nshard; s++) {
|
|
|
|
translations[s + 1] = translations[s] + this->at(s)->ntotal;
|
2019-03-29 23:32:28 +08:00
|
|
|
}
|
2019-05-28 22:17:22 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
if (this->metric_type == METRIC_L2) {
|
|
|
|
merge_tables<IndexT, CMin<distance_t, int>>(
|
|
|
|
n, k, nshard, distances, labels,
|
|
|
|
all_distances, all_labels, translations);
|
|
|
|
} else {
|
|
|
|
merge_tables<IndexT, CMax<distance_t, int>>(
|
|
|
|
n, k, nshard, distances, labels,
|
|
|
|
all_distances, all_labels, translations);
|
|
|
|
}
|
2019-03-29 23:32:28 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// explicit instanciations
|
|
|
|
template struct IndexShardsTemplate<Index>;
|
|
|
|
template struct IndexShardsTemplate<IndexBinary>;
|
|
|
|
|
|
|
|
} // namespace faiss
|