12 #include "../../FaissAssert.h"
14 namespace faiss {
namespace gpu {
// Constructor: builds one sub-index per GPU by calling `init` once for each
// value of i in [0, numGpus).
//
// init: factory invoked as init(res, i). It receives the GpuResources
//       object created for iteration i together with i itself, and returns
//       the owning pointer that is appended to subIndex.
//
// NOTE(review): numGpus is used below but its declaration is not among the
// lines visible here -- presumably a leading constructor parameter; confirm
// against the class declaration.
16 template <
typename GpuIndex>
17 IndexWrapper<GpuIndex>::IndexWrapper(
19 std::function<std::unique_ptr<GpuIndex>(GpuResources*,
int)> init) {
// Asserts that the requested GPU count does not exceed what
// getNumDevices() reports.
20 FAISS_ASSERT(numGpus <= faiss::gpu::getNumDevices());
21 for (
int i = 0; i < numGpus; ++i) {
// A new StandardGpuResources object is allocated on every iteration.
22 auto res = std::unique_ptr<faiss::gpu::StandardGpuResources>(
23 new StandardGpuResources);
// init is handed only the raw pointer; the owning unique_ptr is moved into
// `resources` immediately afterwards. Both containers grow by one element
// per iteration.
25 subIndex.emplace_back(init(res.get(), i));
26 resources.emplace_back(std::move(res));
// Each sub-index is passed to proxyIndex->addIndex() as the raw pointer
// obtained via get().
// NOTE(review): the statements that create proxyIndex and decide when this
// loop runs are not visible here -- confirm before relying on this.
34 for (
auto& index : subIndex) {
35 proxyIndex->addIndex(index.get());
// Returns the index object callers should operate on: proxyIndex when it is
// set, otherwise the first element of subIndex.
// NOTE(review): the return-type line of this definition is not visible here;
// check the class declaration for the exact pointer type returned.
40 template <
typename GpuIndex>
42 IndexWrapper<GpuIndex>::getIndex() {
// A non-null proxyIndex takes precedence over the individual sub-indices.
43 if ((
bool) proxyIndex) {
44 return proxyIndex.get();
// No proxy: fall back to the first sub-index, asserting that one exists.
46 FAISS_ASSERT(!subIndex.empty());
47 return subIndex.front().get();
// Applies `f` to the wrapped indices.
//
// f: callable taking a GpuIndex*. With a proxy it is invoked through
//    proxyIndex->runOnIndex(); without one it is invoked exactly once, on
//    the first element of subIndex.
51 template <
typename GpuIndex>
53 IndexWrapper<GpuIndex>::runOnIndices(std::function<
void(GpuIndex*)> f) {
55 if ((
bool) proxyIndex) {
// The pointer supplied by runOnIndex is downcast to GpuIndex* with
// dynamic_cast before being handed to f; the cast result is not checked
// for null in the lines visible here.
// NOTE(review): the header of the callback passed to runOnIndex (which
// declares `index`) is not visible here.
56 proxyIndex->runOnIndex(
58 f(dynamic_cast<GpuIndex*>(index));
// No proxy: assert that a sub-index exists and call f on the first one.
61 FAISS_ASSERT(!subIndex.empty());
62 f(subIndex.front().get());
// Forwards `nprobe` to setNumProbes() on every index that runOnIndices()
// visits.
//
// nprobe: value passed unchanged to GpuIndex::setNumProbes; it is captured
//         by value in the lambda below.
66 template <
typename GpuIndex>
68 IndexWrapper<GpuIndex>::setNumProbes(
int nprobe) {
69 runOnIndices([nprobe](GpuIndex* index) {
70 index->setNumProbes(nprobe);