9 #include "../../FaissAssert.h"
11 namespace faiss {
namespace gpu {
13 template <
typename GpuIndex>
14 IndexWrapper<GpuIndex>::IndexWrapper(
16 std::function<std::unique_ptr<GpuIndex>(GpuResources*,
int)> init) {
17 FAISS_ASSERT(numGpus <= faiss::gpu::getNumDevices());
18 for (
int i = 0; i < numGpus; ++i) {
19 auto res = std::unique_ptr<faiss::gpu::StandardGpuResources>(
20 new StandardGpuResources);
22 subIndex.emplace_back(init(res.get(), i));
23 resources.emplace_back(std::move(res));
31 for (
auto& index : subIndex) {
32 replicaIndex->addIndex(index.get());
37 template <
typename GpuIndex>
39 IndexWrapper<GpuIndex>::getIndex() {
40 if ((
bool) replicaIndex) {
41 return replicaIndex.get();
43 FAISS_ASSERT(!subIndex.empty());
44 return subIndex.front().get();
48 template <
typename GpuIndex>
50 IndexWrapper<GpuIndex>::runOnIndices(std::function<
void(GpuIndex*)> f) {
52 if ((
bool) replicaIndex) {
53 replicaIndex->runOnIndex(
55 f(dynamic_cast<GpuIndex*>(index));
58 FAISS_ASSERT(!subIndex.empty());
59 f(subIndex.front().get());
63 template <
typename GpuIndex>
65 IndexWrapper<GpuIndex>::setNumProbes(
int nprobe) {
66 runOnIndices([nprobe](GpuIndex* index) {
67 index->setNumProbes(nprobe);