11 #include "../../FaissAssert.h"
13 namespace faiss {
namespace gpu {
// Constructs one GPU index per device, each backed by its own
// StandardGpuResources instance.
//
// numGpus: number of devices to use; asserted to be <= getNumDevices().
//          NOTE(review): its declaration is on a line not visible here
//          (presumably the first constructor parameter) — confirm.
// init:    factory called once per device as init(resources, deviceIndex);
//          the returned std::unique_ptr<GpuIndex> is stored in `subIndex`.
//
// Each StandardGpuResources is moved into `resources` after its sub-index
// has been created from it. Afterwards, every sub-index is registered with
// `proxyIndex` via addIndex.
15 template <
typename GpuIndex>
16 IndexWrapper<GpuIndex>::IndexWrapper(
18 std::function<std::unique_ptr<GpuIndex>(GpuResources*,
int)> init) {
19 FAISS_ASSERT(numGpus <= faiss::gpu::getNumDevices());
20 for (
int i = 0; i < numGpus; ++i) {
21 auto res = std::unique_ptr<faiss::gpu::StandardGpuResources>(
22 new StandardGpuResources);
// init only borrows the resources pointer; ownership of `res` is
// transferred to `resources` on the following line.
24 subIndex.emplace_back(init(res.get(), i));
25 resources.emplace_back(std::move(res));
// NOTE(review): proxyIndex is dereferenced here. The lines that create it
// (and any guard such as a GPU-count check) are not visible — confirm it is
// non-null whenever this loop runs.
33 for (
auto& index : subIndex) {
34 proxyIndex->addIndex(index.get());
// Returns the index callers should operate on.
// - If a proxy index exists, returns it.
// - Otherwise returns the first per-GPU sub-index; FAISS_ASSERT checks that
//   `subIndex` is not empty.
// The result is a raw pointer obtained via .get(), so the wrapper keeps
// ownership. NOTE(review): the declared return type is on a line not visible
// here — confirm.
39 template <
typename GpuIndex>
41 IndexWrapper<GpuIndex>::getIndex() {
42 if ((
bool) proxyIndex) {
43 return proxyIndex.get();
45 FAISS_ASSERT(!subIndex.empty());
46 return subIndex.front().get();
// Applies `f` to the underlying GPU index or indices.
// - With a proxy index, iteration is delegated to proxyIndex->runOnIndex.
//   The callback presumably runs once per registered index, and it downcasts
//   each index to GpuIndex* before calling `f`.
// - Without a proxy, `f` is called once on the first sub-index, whose
//   existence is checked by FAISS_ASSERT.
50 template <
typename GpuIndex>
52 IndexWrapper<GpuIndex>::runOnIndices(std::function<
void(GpuIndex*)> f) {
54 if ((
bool) proxyIndex) {
55 proxyIndex->runOnIndex(
// NOTE(review): dynamic_cast yields nullptr if an index is not a GpuIndex,
// and `f` would then receive nullptr. Confirm every index added to the proxy
// is a GpuIndex.
57 f(dynamic_cast<GpuIndex*>(index));
60 FAISS_ASSERT(!subIndex.empty());
61 f(subIndex.front().get());
65 template <
typename GpuIndex>
67 IndexWrapper<GpuIndex>::setNumProbes(
int nprobe) {
68 runOnIndices([nprobe](GpuIndex* index) {
69 index->setNumProbes(nprobe);