10 #include "../../FaissAssert.h"
12 namespace faiss {
namespace gpu {
14 template <
typename GpuIndex>
15 IndexWrapper<GpuIndex>::IndexWrapper(
17 std::function<std::unique_ptr<GpuIndex>(GpuResources*,
int)> init) {
18 FAISS_ASSERT(numGpus <= faiss::gpu::getNumDevices());
19 for (
int i = 0; i < numGpus; ++i) {
20 auto res = std::unique_ptr<faiss::gpu::StandardGpuResources>(
21 new StandardGpuResources);
23 subIndex.emplace_back(init(res.get(), i));
24 resources.emplace_back(std::move(res));
32 for (
auto& index : subIndex) {
33 proxyIndex->addIndex(index.get());
38 template <
typename GpuIndex>
40 IndexWrapper<GpuIndex>::getIndex() {
41 if ((
bool) proxyIndex) {
42 return proxyIndex.get();
44 FAISS_ASSERT(!subIndex.empty());
45 return subIndex.front().get();
49 template <
typename GpuIndex>
51 IndexWrapper<GpuIndex>::runOnIndices(std::function<
void(GpuIndex*)> f) {
53 if ((
bool) proxyIndex) {
54 proxyIndex->runOnIndex(
56 f(dynamic_cast<GpuIndex*>(index));
59 FAISS_ASSERT(!subIndex.empty());
60 f(subIndex.front().get());
64 template <
typename GpuIndex>
66 IndexWrapper<GpuIndex>::setNumProbes(
int nprobe) {
67 runOnIndices([nprobe](GpuIndex* index) {
68 index->setNumProbes(nprobe);