/** * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ #include namespace faiss { namespace gpu { template IndexWrapper::IndexWrapper( int numGpus, std::function(GpuResources*, int)> init) { FAISS_ASSERT(numGpus <= faiss::gpu::getNumDevices()); for (int i = 0; i < numGpus; ++i) { auto res = std::unique_ptr( new StandardGpuResources); subIndex.emplace_back(init(res.get(), i)); resources.emplace_back(std::move(res)); } if (numGpus > 1) { // create proxy replicaIndex = std::unique_ptr(new faiss::IndexReplicas); for (auto& index : subIndex) { replicaIndex->addIndex(index.get()); } } } template faiss::Index* IndexWrapper::getIndex() { if ((bool) replicaIndex) { return replicaIndex.get(); } else { FAISS_ASSERT(!subIndex.empty()); return subIndex.front().get(); } } template void IndexWrapper::runOnIndices(std::function f) { if ((bool) replicaIndex) { replicaIndex->runOnIndex( [f](int, faiss::Index* index) { f(dynamic_cast(index)); }); } else { FAISS_ASSERT(!subIndex.empty()); f(subIndex.front().get()); } } template void IndexWrapper::setNumProbes(int nprobe) { runOnIndices([nprobe](GpuIndex* index) { index->setNumProbes(nprobe); }); } } }