faiss/ThreadedIndex-inl.h

193 lines
5.1 KiB
C++

/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "FaissAssert.h"
#include <exception>
#include <iostream>
namespace faiss {
/// Constructs an empty collection whose dimension is not yet fixed.
///
/// Delegates with dimension 0; addIndex() adopts the dimension of the first
/// sub-index added while the collection is empty and the dimension is 0.
template <typename IndexT>
ThreadedIndex<IndexT>::ThreadedIndex(bool threaded)
    : ThreadedIndex(0, threaded) {}
/// Constructs an empty collection of dimension `d`.
///
/// Sub-indexes are not owned by default (own_fields starts false). When
/// `threaded` is true, addIndex() creates one worker thread per sub-index.
template <typename IndexT>
ThreadedIndex<IndexT>::ThreadedIndex(int d, bool threaded)
    : IndexT(d), own_fields(false), isThreaded_(threaded) {}
/// Tears down the collection.
///
/// In threaded mode, each sub-index's worker thread is stopped (which also
/// flushes its pending work) and joined. Sub-indexes are deleted only when
/// own_fields is set.
template <typename IndexT>
ThreadedIndex<IndexT>::~ThreadedIndex() {
  for (auto& entry : indices_) {
    auto& worker = entry.second;

    if (!isThreaded_) {
      // Non-threaded mode never attaches a worker thread
      FAISS_ASSERT(worker == nullptr);
    } else {
      // Threaded mode attaches exactly one worker per sub-index
      FAISS_ASSERT(worker != nullptr);
      // stop() also flushes all pending work before the thread exits
      worker->stop();
      worker->waitForThreadExit();
    }

    if (own_fields) {
      delete entry.first;
    }
  }
}
/// Adds `index` to the collection.
///
/// If the collection is empty and our dimension is still 0, the dimension is
/// taken from `index`. Otherwise `index` must have our dimension, the same
/// metric type as the existing sub-indexes, and must not already be present.
/// In threaded mode a dedicated worker thread is created for the new entry.
/// onAfterAddIndex() is called once the entry has been registered. If
/// own_fields is set, the index is later deleted by removeIndex() or the
/// destructor.
///
/// Throws (via FAISS_THROW_*) if `index` is null, on a dimension or metric
/// mismatch, or if `index` is already in the collection.
template <typename IndexT>
void ThreadedIndex<IndexT>::addIndex(IndexT* index) {
  // Reject null up front; every check below dereferences `index`
  FAISS_THROW_IF_NOT_MSG(index != nullptr,
                         "addIndex: attempting to add a null index");

  // We inherit the dimension from the first index added to us if we don't
  // have a set dimension
  if (indices_.empty() && this->d == 0) {
    this->d = index->d;
  }

  // The new index must match our set dimension
  FAISS_THROW_IF_NOT_FMT(this->d == index->d,
                         "addIndex: dimension mismatch for "
                         "newly added index; expecting dim %d, "
                         "new index has dim %d",
                         this->d, index->d);

  if (!indices_.empty()) {
    auto& existing = indices_.front().first;

    FAISS_THROW_IF_NOT_MSG(index->metric_type == existing->metric_type,
                           "addIndex: newly added index is "
                           "of different metric type than old index");

    // Make sure this index is not duplicated
    for (auto& p : indices_) {
      FAISS_THROW_IF_NOT_MSG(p.first != index,
                             "addIndex: attempting to add index "
                             "that is already in the collection");
    }
  }

  indices_.emplace_back(
      std::make_pair(
          index,
          std::unique_ptr<WorkerThread>(isThreaded_ ?
                                        new WorkerThread : nullptr)));

  onAfterAddIndex(index);
}
/// Removes `index` from the collection.
///
/// In threaded mode, the worker thread attached to `index` is stopped and
/// joined before the entry is erased, so its work has finished by the time
/// this function returns. onAfterRemoveIndex() runs after the erase, and the
/// index is then deleted if own_fields is set.
///
/// Throws (via FAISS_THROW_MSG) if `index` is not in the collection.
template <typename IndexT>
void ThreadedIndex<IndexT>::removeIndex(IndexT* index) {
  for (auto it = indices_.begin(); it != indices_.end(); ++it) {
    if (it->first != index) {
      continue;
    }

    // This is our index; stop the worker thread before removing it,
    // to ensure that it has finished before function exit
    if (isThreaded_) {
      // should have worker thread
      FAISS_ASSERT(it->second != nullptr);
      it->second->stop();
      it->second->waitForThreadExit();
    } else {
      // should not have worker thread
      FAISS_ASSERT(it->second == nullptr);
    }

    indices_.erase(it);
    onAfterRemoveIndex(index);

    if (own_fields) {
      delete index;
    }
    return;
  }

  // could not find our index; name this class, not a specific subclass
  FAISS_THROW_MSG("ThreadedIndex::removeIndex: index not found");
}
/// Invokes `f(i, sub-index i)` for every sub-index, in collection order.
///
/// Non-threaded mode: calls run sequentially on the calling thread. An
/// exception from one call does not stop the remaining calls; every captured
/// exception, paired with its sub-index position, is passed to
/// handleExceptions() once all calls have run.
/// Threaded mode: each call is queued on that sub-index's worker thread, and
/// waitAndHandleFutures() then blocks until all queued calls have finished.
template <typename IndexT>
void ThreadedIndex<IndexT>::runOnIndex(std::function<void(int, IndexT*)> f) {
  if (!isThreaded_) {
    std::vector<std::pair<int, std::exception_ptr>> failures;
    for (size_t pos = 0; pos < this->indices_.size(); ++pos) {
      const int rank = static_cast<int>(pos);
      try {
        f(rank, this->indices_[pos].first);
      } catch (...) {
        failures.emplace_back(std::make_pair(rank, std::current_exception()));
      }
    }
    handleExceptions(failures);
    return;
  }

  std::vector<std::future<bool>> pending;
  for (size_t pos = 0; pos < this->indices_.size(); ++pos) {
    const int rank = static_cast<int>(pos);
    IndexT* subIndex = this->indices_[pos].first;
    // The job runs on another thread, so everything is captured by value
    pending.emplace_back(this->indices_[pos].second->add(
        [f, rank, subIndex]() { f(rank, subIndex); }));
  }
  waitAndHandleFutures(pending);
}
/// Read-only overload: runs `f` on every sub-index, handing each one over as
/// a `const IndexT*`.
///
/// Reuses the non-const overload through a const_cast and wraps `f` in an
/// adapter that takes the mutable pointer type.
template <typename IndexT>
void ThreadedIndex<IndexT>::runOnIndex(
    std::function<void(int, const IndexT*)> f) const {
  auto* const mutableSelf = const_cast<ThreadedIndex<IndexT>*>(this);
  mutableSelf->runOnIndex(
      [f](int rank, IndexT* subIndex) { f(rank, subIndex); });
}
/// Calls reset() on every sub-index (through runOnIndex), then sets this
/// object's ntotal to 0 and is_trained to false.
template <typename IndexT>
void ThreadedIndex<IndexT>::reset() {
  runOnIndex([](int, IndexT* subIndex) { subIndex->reset(); });

  this->ntotal = 0;
  // NOTE(review): is_trained is forced to false here regardless of the
  // sub-indexes' own training state after their reset() — confirm intended.
  this->is_trained = false;
}
/// Hook called by addIndex() after `index` has been registered. The default
/// implementation does nothing (presumably a customization point for
/// subclasses — confirm it is declared virtual in the header).
template <typename IndexT>
void ThreadedIndex<IndexT>::onAfterAddIndex(IndexT* /* index */) {
  // Intentionally empty
}
/// Hook called by removeIndex() after `index`'s entry has been erased (and
/// before it is deleted when own_fields is set). The default implementation
/// does nothing (presumably a customization point for subclasses — confirm
/// it is declared virtual in the header).
template <typename IndexT>
void ThreadedIndex<IndexT>::onAfterRemoveIndex(IndexT* /* index */) {
  // Intentionally empty
}
/// Blocks until every future in `v` has completed.
///
/// An exception raised by a future's get() is recorded together with that
/// future's position in `v`, and waiting continues with the next one. Once
/// all futures have been waited on, the recorded exceptions are passed to
/// handleExceptions().
template <typename IndexT>
void ThreadedIndex<IndexT>::waitAndHandleFutures(
    std::vector<std::future<bool>>& v) {
  std::vector<std::pair<int, std::exception_ptr>> failures;

  int pos = 0;
  for (auto& pending : v) {
    try {
      pending.get();
    } catch (...) {
      failures.emplace_back(std::make_pair(pos, std::current_exception()));
    }
    ++pos;
  }

  handleExceptions(failures);
}
} // namespace