Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
IVFUtils.cu
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 
9 #include "IVFUtils.cuh"
10 #include "../utils/DeviceUtils.h"
11 #include "../utils/StaticUtils.h"
12 #include "../utils/Tensor.cuh"
13 #include "../utils/ThrustAllocator.cuh"
14 #include <thrust/scan.h>
15 #include <thrust/execution_policy.h>
16 
17 namespace faiss { namespace gpu {
18 
// Writes, for every (query, probe) slot of topQueryToCentroid, the length of
// the IVF list that slot refers to into the matching slot of `length`. These
// per-slot counts are the number of intermediate distances each query will
// consider for each probed list.
//
// Launch layout: 1-D grid with at least totalSize threads in total; thread
// `i` handles flat (row-major) element `i`, and surplus threads do nothing.
__global__ void
getResultLengths(Tensor<int, 2, true> topQueryToCentroid,
                 int* listLengths,
                 int totalSize,
                 Tensor<int, 2, true> length) {
  int flatIdx = blockIdx.x * blockDim.x + threadIdx.x;

  if (flatIdx < totalSize) {
    int probesPerQuery = topQueryToCentroid.getSize(1);
    int query = flatIdx / probesPerQuery;
    int probe = flatIdx % probesPerQuery;

    int list = topQueryToCentroid[query][probe];

    // A list id of -1 means no list was selected for this slot (e.g. NaNs in
    // the input); such a slot contributes zero intermediate results.
    int count;
    if (list == -1) {
      count = 0;
    } else {
      count = listLengths[list];
    }

    length[query][probe] = count;
  }
}
40 
// Computes per-(query, probe) offsets into the intermediate result buffer.
//
// First, getResultLengths writes the length of each probed IVF list into
// prefixSumOffsets. Then an inclusive prefix sum is taken over those lengths,
// in row-major order over all totalSize elements. Afterwards, element i holds
// the sum of the list lengths of slots 0..i.
//
// topQueryToCentroid: list ids probed by each query; -1 marks a slot with no
//                     list id, which contributes a length of 0.
// listLengths:        per-list lengths on the device, indexed by list id.
// prefixSumOffsets:   output; must have the same shape as topQueryToCentroid.
// thrustMem:          scratch memory handed to Thrust so the scan does not
//                     call cudaMalloc/cudaFree itself.
// stream:             both the kernel and the scan are issued on this stream.
void runCalcListOffsets(Tensor<int, 2, true>& topQueryToCentroid,
                        thrust::device_vector<int>& listLengths,
                        Tensor<int, 2, true>& prefixSumOffsets,
                        Tensor<char, 1, true>& thrustMem,
                        cudaStream_t stream) {
  FAISS_ASSERT(topQueryToCentroid.getSize(0) == prefixSumOffsets.getSize(0));
  FAISS_ASSERT(topQueryToCentroid.getSize(1) == prefixSumOffsets.getSize(1));

  int totalSize = topQueryToCentroid.numElements();

  // Empty batch (no queries or nprobe == 0): there is nothing to compute.
  // Without this guard numThreads would be 0. divUp would then divide by zero,
  // and a 0-thread block is an invalid launch configuration.
  if (totalSize == 0) {
    return;
  }

  int numThreads = std::min(totalSize, getMaxThreadsCurrentDevice());
  int numBlocks = utils::divUp(totalSize, numThreads);

  auto grid = dim3(numBlocks);
  auto block = dim3(numThreads);

  getResultLengths<<<grid, block, 0, stream>>>(
    topQueryToCentroid,
    listLengths.data().get(),
    totalSize,
    prefixSumOffsets);
  CUDA_TEST_ERROR();

  // Prefix sum of the lengths, so we know where the intermediate results for
  // each slot should be placed. Thrust wants a place for its temporary
  // allocations; provide one so it won't call cudaMalloc/cudaFree.
  GpuResourcesThrustAllocator alloc(thrustMem.data(),
                                    thrustMem.getSizeInBytes());

  thrust::inclusive_scan(thrust::cuda::par(alloc).on(stream),
                         prefixSumOffsets.data(),
                         prefixSumOffsets.data() + totalSize,
                         prefixSumOffsets.data());
  CUDA_TEST_ERROR();
}
77 
78 } } // namespace