Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
IndexWrapper-inl.h
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 
9 #include "../../FaissAssert.h"
10 
11 namespace faiss { namespace gpu {
12 
13 template <typename GpuIndex>
14 IndexWrapper<GpuIndex>::IndexWrapper(
15  int numGpus,
16  std::function<std::unique_ptr<GpuIndex>(GpuResources*, int)> init) {
17  FAISS_ASSERT(numGpus <= faiss::gpu::getNumDevices());
18  for (int i = 0; i < numGpus; ++i) {
19  auto res = std::unique_ptr<faiss::gpu::StandardGpuResources>(
20  new StandardGpuResources);
21 
22  subIndex.emplace_back(init(res.get(), i));
23  resources.emplace_back(std::move(res));
24  }
25 
26  if (numGpus > 1) {
27  // create proxy
28  replicaIndex =
29  std::unique_ptr<faiss::IndexReplicas>(new faiss::IndexReplicas);
30 
31  for (auto& index : subIndex) {
32  replicaIndex->addIndex(index.get());
33  }
34  }
35 }
36 
37 template <typename GpuIndex>
39 IndexWrapper<GpuIndex>::getIndex() {
40  if ((bool) replicaIndex) {
41  return replicaIndex.get();
42  } else {
43  FAISS_ASSERT(!subIndex.empty());
44  return subIndex.front().get();
45  }
46 }
47 
48 template <typename GpuIndex>
49 void
50 IndexWrapper<GpuIndex>::runOnIndices(std::function<void(GpuIndex*)> f) {
51 
52  if ((bool) replicaIndex) {
53  replicaIndex->runOnIndex(
54  [f](int, faiss::Index* index) {
55  f(dynamic_cast<GpuIndex*>(index));
56  });
57  } else {
58  FAISS_ASSERT(!subIndex.empty());
59  f(subIndex.front().get());
60  }
61 }
62 
63 template <typename GpuIndex>
64 void
65 IndexWrapper<GpuIndex>::setNumProbes(int nprobe) {
66  runOnIndices([nprobe](GpuIndex* index) {
67  index->setNumProbes(nprobe);
68  });
69 }
70 
71 } }