Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
IndexWrapper-inl.h
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved.
11 
12 #include "../../FaissAssert.h"
13 
14 namespace faiss { namespace gpu {
15 
16 template <typename GpuIndex>
17 IndexWrapper<GpuIndex>::IndexWrapper(
18  int numGpus,
19  std::function<std::unique_ptr<GpuIndex>(GpuResources*, int)> init) {
20  FAISS_ASSERT(numGpus <= faiss::gpu::getNumDevices());
21  for (int i = 0; i < numGpus; ++i) {
22  auto res = std::unique_ptr<faiss::gpu::StandardGpuResources>(
23  new StandardGpuResources);
24 
25  subIndex.emplace_back(init(res.get(), i));
26  resources.emplace_back(std::move(res));
27  }
28 
29  if (numGpus > 1) {
30  // create proxy
31  proxyIndex =
32  std::unique_ptr<faiss::gpu::IndexProxy>(new faiss::gpu::IndexProxy);
33 
34  for (auto& index : subIndex) {
35  proxyIndex->addIndex(index.get());
36  }
37  }
38 }
39 
40 template <typename GpuIndex>
42 IndexWrapper<GpuIndex>::getIndex() {
43  if ((bool) proxyIndex) {
44  return proxyIndex.get();
45  } else {
46  FAISS_ASSERT(!subIndex.empty());
47  return subIndex.front().get();
48  }
49 }
50 
51 template <typename GpuIndex>
52 void
53 IndexWrapper<GpuIndex>::runOnIndices(std::function<void(GpuIndex*)> f) {
54 
55  if ((bool) proxyIndex) {
56  proxyIndex->runOnIndex(
57  [f](faiss::Index* index) {
58  f(dynamic_cast<GpuIndex*>(index));
59  });
60  } else {
61  FAISS_ASSERT(!subIndex.empty());
62  f(subIndex.front().get());
63  }
64 }
65 
66 template <typename GpuIndex>
67 void
68 IndexWrapper<GpuIndex>::setNumProbes(int nprobe) {
69  runOnIndices([nprobe](GpuIndex* index) {
70  index->setNumProbes(nprobe);
71  });
72 }
73 
74 } }