Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
IndexWrapper-inl.h
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 
10 #include "../../FaissAssert.h"
11 
12 namespace faiss { namespace gpu {
13 
14 template <typename GpuIndex>
15 IndexWrapper<GpuIndex>::IndexWrapper(
16  int numGpus,
17  std::function<std::unique_ptr<GpuIndex>(GpuResources*, int)> init) {
18  FAISS_ASSERT(numGpus <= faiss::gpu::getNumDevices());
19  for (int i = 0; i < numGpus; ++i) {
20  auto res = std::unique_ptr<faiss::gpu::StandardGpuResources>(
21  new StandardGpuResources);
22 
23  subIndex.emplace_back(init(res.get(), i));
24  resources.emplace_back(std::move(res));
25  }
26 
27  if (numGpus > 1) {
28  // create proxy
29  proxyIndex =
30  std::unique_ptr<faiss::gpu::IndexProxy>(new faiss::gpu::IndexProxy);
31 
32  for (auto& index : subIndex) {
33  proxyIndex->addIndex(index.get());
34  }
35  }
36 }
37 
38 template <typename GpuIndex>
40 IndexWrapper<GpuIndex>::getIndex() {
41  if ((bool) proxyIndex) {
42  return proxyIndex.get();
43  } else {
44  FAISS_ASSERT(!subIndex.empty());
45  return subIndex.front().get();
46  }
47 }
48 
49 template <typename GpuIndex>
50 void
51 IndexWrapper<GpuIndex>::runOnIndices(std::function<void(GpuIndex*)> f) {
52 
53  if ((bool) proxyIndex) {
54  proxyIndex->runOnIndex(
55  [f](faiss::Index* index) {
56  f(dynamic_cast<GpuIndex*>(index));
57  });
58  } else {
59  FAISS_ASSERT(!subIndex.empty());
60  f(subIndex.front().get());
61  }
62 }
63 
64 template <typename GpuIndex>
65 void
66 IndexWrapper<GpuIndex>::setNumProbes(int nprobe) {
67  runOnIndices([nprobe](GpuIndex* index) {
68  index->setNumProbes(nprobe);
69  });
70 }
71 
72 } }