Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
IndexWrapper-inl.h
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the CC-by-NC license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 
11 #include "../../FaissAssert.h"
12 
13 namespace faiss { namespace gpu {
14 
15 template <typename GpuIndex>
16 IndexWrapper<GpuIndex>::IndexWrapper(
17  int numGpus,
18  std::function<std::unique_ptr<GpuIndex>(GpuResources*, int)> init) {
19  FAISS_ASSERT(numGpus <= faiss::gpu::getNumDevices());
20  for (int i = 0; i < numGpus; ++i) {
21  auto res = std::unique_ptr<faiss::gpu::StandardGpuResources>(
22  new StandardGpuResources);
23 
24  subIndex.emplace_back(init(res.get(), i));
25  resources.emplace_back(std::move(res));
26  }
27 
28  if (numGpus > 1) {
29  // create proxy
30  proxyIndex =
31  std::unique_ptr<faiss::gpu::IndexProxy>(new faiss::gpu::IndexProxy);
32 
33  for (auto& index : subIndex) {
34  proxyIndex->addIndex(index.get());
35  }
36  }
37 }
38 
39 template <typename GpuIndex>
41 IndexWrapper<GpuIndex>::getIndex() {
42  if ((bool) proxyIndex) {
43  return proxyIndex.get();
44  } else {
45  FAISS_ASSERT(!subIndex.empty());
46  return subIndex.front().get();
47  }
48 }
49 
50 template <typename GpuIndex>
51 void
52 IndexWrapper<GpuIndex>::runOnIndices(std::function<void(GpuIndex*)> f) {
53 
54  if ((bool) proxyIndex) {
55  proxyIndex->runOnIndex(
56  [f](faiss::Index* index) {
57  f(dynamic_cast<GpuIndex*>(index));
58  });
59  } else {
60  FAISS_ASSERT(!subIndex.empty());
61  f(subIndex.front().get());
62  }
63 }
64 
65 template <typename GpuIndex>
66 void
67 IndexWrapper<GpuIndex>::setNumProbes(int nprobe) {
68  runOnIndices([nprobe](GpuIndex* index) {
69  index->setNumProbes(nprobe);
70  });
71 }
72 
73 } }