/** * Copyright (c) 2015-present, Facebook, Inc. * All rights reserved. * * This source code is licensed under the BSD+Patents license found in the * LICENSE file in the root directory of this source tree. */ // -*- c++ -*- #include "IndexShards.h" #include #include #include "FaissAssert.h" #include "Heap.h" #include "WorkerThread.h" namespace faiss { // subroutines namespace { typedef Index::idx_t idx_t; // add translation to all valid labels void translate_labels (long n, idx_t *labels, long translation) { if (translation == 0) return; for (long i = 0; i < n; i++) { if(labels[i] < 0) continue; labels[i] += translation; } } /** merge result tables from several shards. * @param all_distances size nshard * n * k * @param all_labels idem * @param translartions label translations to apply, size nshard */ template void merge_tables (long n, long k, long nshard, typename IndexClass::distance_t *distances, idx_t *labels, const typename IndexClass::distance_t *all_distances, idx_t *all_labels, const long *translations) { if(k == 0) { return; } using distance_t = typename IndexClass::distance_t; long stride = n * k; #pragma omp parallel { std::vector buf (2 * nshard); int * pointer = buf.data(); int * shard_ids = pointer + nshard; std::vector buf2 (nshard); distance_t * heap_vals = buf2.data(); #pragma omp for for (long i = 0; i < n; i++) { // the heap maps values to the shard where they are // produced. const distance_t *D_in = all_distances + i * k; const idx_t *I_in = all_labels + i * k; int heap_size = 0; for (long s = 0; s < nshard; s++) { pointer[s] = 0; if (I_in[stride * s] >= 0) heap_push (++heap_size, heap_vals, shard_ids, D_in[stride * s], s); } distance_t *D = distances + i * k; idx_t *I = labels + i * k; for (int j = 0; j < k; j++) { if (heap_size == 0) { I[j] = -1; D[j] = C::neutral(); } else { // pop best element int s = shard_ids[0]; int & p = pointer[s]; D[j] = heap_vals[0]; I[j] = I_in[stride * s + p] + translations[s]; heap_pop (heap_size--, heap_vals, shard_ids); p++; if (p < k && I_in[stride * s + p] >= 0) heap_push (++heap_size, heap_vals, shard_ids, D_in[stride * s + p], s); } } } } } template void runOnIndexes(bool threaded, std::function f, std::vector indexes) { FAISS_THROW_IF_NOT_MSG(!indexes.empty(), "no shards in index"); if (!threaded) { for (int no = 0; no < indexes.size(); no++) { IndexClass *index = indexes[no]; f(no, index); } } else { std::vector > threads; std::vector> v; for (int no = 0; no < indexes.size(); no++) { IndexClass *index = indexes[no]; threads.emplace_back(new WorkerThread()); WorkerThread *wt = threads.back().get(); v.emplace_back(wt->add([no, index, f](){ f(no, index); })); } // Blocking wait for completion for (auto& func : v) { func.get(); } } }; } // anonymous namespace template IndexShardsTemplate::IndexShardsTemplate (idx_t d, bool threaded, bool successive_ids): IndexClass (d), own_fields (false), threaded (threaded), successive_ids (successive_ids) { } template void IndexShardsTemplate::add_shard (IndexClass *idx) { shard_indexes.push_back (idx); sync_with_shard_indexes (); } template void IndexShardsTemplate::sync_with_shard_indexes () { if (shard_indexes.empty()) return; IndexClass * index0 = shard_indexes[0]; this->d = index0->d; this->metric_type = index0->metric_type; this->is_trained = index0->is_trained; this->ntotal = index0->ntotal; for (int i = 1; i < shard_indexes.size(); i++) { IndexClass * index = shard_indexes[i]; FAISS_THROW_IF_NOT (this->metric_type == index->metric_type); FAISS_THROW_IF_NOT (this->d == index->d); this->ntotal += index->ntotal; } } template void IndexShardsTemplate::train (idx_t n, const component_t *x) { auto train_func = [n, x](int no, IndexClass *index) { if (index->verbose) printf ("begin train shard %d on %ld points\n", no, n); index->train(n, x); if (index->verbose) printf ("end train shard %d\n", no); }; runOnIndexes (threaded, train_func, shard_indexes); sync_with_shard_indexes (); } template void IndexShardsTemplate::add (idx_t n, const component_t *x) { add_with_ids (n, x, nullptr); } template void IndexShardsTemplate::add_with_ids (idx_t n, const component_t * x, const idx_t *xids) { FAISS_THROW_IF_NOT_MSG(!(successive_ids && xids), "It makes no sense to pass in ids and " "request them to be shifted"); if (successive_ids) { FAISS_THROW_IF_NOT_MSG(!xids, "It makes no sense to pass in ids and " "request them to be shifted"); FAISS_THROW_IF_NOT_MSG(this->ntotal == 0, "when adding to IndexShards with sucessive_ids, " "only add() in a single pass is supported"); } long nshard = shard_indexes.size(); const idx_t *ids = xids; ScopeDeleter del; if (!ids && !successive_ids) { idx_t *aids = new idx_t[n]; for (idx_t i = 0; i < n; i++) aids[i] = this->ntotal + i; ids = aids; del.set (ids); } size_t components_per_vec = sizeof(component_t) == 1 ? (this->d + 7) / 8 : this->d; auto add_func = [n, ids, x, nshard, components_per_vec] (int no, IndexClass *index) { idx_t i0 = no * n / nshard; idx_t i1 = (no + 1) * n / nshard; auto x0 = x + i0 * components_per_vec; if (index->verbose) { printf ("begin add shard %d on %ld points\n", no, n); } if (ids) { index->add_with_ids (i1 - i0, x0, ids + i0); } else { index->add (i1 - i0, x0); } if (index->verbose) { printf ("end add shard %d on %ld points\n", no, i1 - i0); } }; runOnIndexes (threaded, add_func, shard_indexes); this->ntotal += n; } template void IndexShardsTemplate::reset () { for (int i = 0; i < shard_indexes.size(); i++) { shard_indexes[i]->reset (); } sync_with_shard_indexes (); } template void IndexShardsTemplate::search ( idx_t n, const component_t *x, idx_t k, distance_t *distances, idx_t *labels) const { long nshard = shard_indexes.size(); distance_t *all_distances = new distance_t [nshard * k * n]; idx_t *all_labels = new idx_t [nshard * k * n]; ScopeDeleter del (all_distances); ScopeDeleter del2 (all_labels); auto query_func = [n, k, x, all_distances, all_labels] (int no, IndexClass *index) { if (index->verbose) { printf ("begin query shard %d on %ld points\n", no, n); } index->search (n, x, k, all_distances + no * k * n, all_labels + no * k * n); if (index->verbose) { printf ("end query shard %d\n", no); } }; runOnIndexes (threaded, query_func, shard_indexes); std::vector translations (nshard, 0); if (successive_ids) { translations[0] = 0; for (int s = 0; s + 1 < nshard; s++) translations [s + 1] = translations [s] + shard_indexes [s]->ntotal; } if (this->metric_type == METRIC_L2) { merge_tables > ( n, k, nshard, distances, labels, all_distances, all_labels, translations.data ()); } else { merge_tables > ( n, k, nshard, distances, labels, all_distances, all_labels, translations.data ()); } } template IndexShardsTemplate::~IndexShardsTemplate () { if (own_fields) { for (int s = 0; s < shard_indexes.size(); s++) delete shard_indexes [s]; } } // explicit instanciations template struct IndexShardsTemplate; template struct IndexShardsTemplate; } // namespace faiss