11 #include "IVFUtils.cuh"
12 #include "../utils/DeviceUtils.h"
13 #include "../utils/Select.cuh"
14 #include "../utils/StaticUtils.h"
15 #include "../utils/Tensor.cuh"
23 namespace faiss {
namespace gpu {
// Initial keys for the selection heap; the kernel below picks kMin when its
// `Dir` template flag is true and kMax otherwise.
//
// kMax is the largest finite float. kMin is the most negative finite float:
// std::numeric_limits<float>::lowest() is used rather than ::min(), because
// for floating-point types ::min() is the smallest *positive* normalized
// value (about 1.18e-38), which is not a lower bound for zero or negative
// keys.
constexpr auto kMax = std::numeric_limits<float>::max();
constexpr auto kMin = std::numeric_limits<float>::lowest();
// Kernel: for one (query, slice) pair, feeds a contiguous range of `distance`
// into a BlockSelect heap and then copies the first k entries of the heap's
// shared-memory buffers into heapDistances / heapIndices.
//
// Block layout, as read from the indexing below:
//   blockIdx.y -> query id, blockIdx.x -> slice id, gridDim.x -> slice count.
//
// Template parameters:
//   ThreadsPerBlock - divided by kWarpSize to get the warp count; presumably
//                     must match blockDim.x at launch (TODO confirm).
//   NumWarpQ        - forwarded to BlockSelect; also sizes smemK / smemV.
//   NumThreadQ      - forwarded to BlockSelect.
//   Dir             - forwarded to BlockSelect; selects the initial key
//                     (kMin when true, kMax when false).
//
// NOTE(review): `nprobe`, `k` and the loop variable `i` are used below, but
// their declarations are not visible in this excerpt - confirm against the
// full definition.
28 template <
int ThreadsPerBlock,
int NumWarpQ,
int NumThreadQ,
bool Dir>
30 pass1SelectLists(Tensor<int, 2, true> prefixSumOffsets,
31 Tensor<float, 1, true> distance,
34 Tensor<float, 3, true> heapDistances,
35 Tensor<int, 3, true> heapIndices) {
// Number of warps in the block, derived from the compile-time block size.
36 constexpr
int kNumWarps = ThreadsPerBlock / kWarpSize;
// Statically sized shared-memory key/value storage handed to the heap:
// NumWarpQ entries per warp.
38 __shared__
float smemK[kNumWarps * NumWarpQ];
39 __shared__
int smemV[kNumWarps * NumWarpQ];
// Initial key depends on the selection direction; the initial value/index
// passed alongside it is -1.
41 constexpr
auto kInit = Dir ? kMin : kMax;
42 BlockSelect<float, int, Dir, Comparator<float>,
43 NumWarpQ, NumThreadQ, ThreadsPerBlock>
44 heap(kInit, -1, smemK, smemV, k);
46 auto queryId = blockIdx.y;
47 auto sliceId = blockIdx.x;
48 auto numSlices = gridDim.x;
// nprobe is split evenly (integer division) across the slices; the last
// slice's end is clamped to nprobe, so it absorbs any remainder.
50 int sliceSize = (nprobe / numSlices);
51 int sliceStart = sliceSize * sliceId;
52 int sliceEnd = sliceId == (numSlices - 1) ? nprobe :
53 sliceStart + sliceSize;
54 auto offsets = prefixSumOffsets[queryId].data();
// `start` is read from the int stored immediately BEFORE offsets[sliceStart].
// For sliceStart == 0 that is the element preceding this query's row -
// presumably the caller guarantees a valid value there (TODO confirm).
// `end` is the offset stored for the last entry of this slice.
57 int start = *(&offsets[sliceStart] - 1);
58 int end = offsets[sliceEnd - 1];
// Number of distances in this slice, and (presumably) that count rounded
// down to a multiple of kWarpSize - verify utils::roundDown semantics.
60 int num = end - start;
61 int limit = utils::roundDown(num, kWarpSize);
64 auto distanceStart = distance[start].data();
// Main loop: advance by blockDim.x while below `limit`, adding each distance
// together with its absolute position (start + i) in `distance`.
68 for (; i < limit; i += blockDim.x) {
69 heap.add(distanceStart[i], start + i);
// An element is added through addThreadQ instead of add; the guarding
// condition is not visible in this excerpt.
74 heap.addThreadQ(distanceStart[i], start + i);
// Write-out: threads stride by blockDim.x over the first k shared-memory
// entries and store them at [queryId][sliceId][i] of the output tensors.
82 for (
int i = threadIdx.x; i < k; i += blockDim.x) {
83 heapDistances[queryId][sliceId][i] = smemK[i];
84 heapIndices[queryId][sliceId][i] = smemV[i];
// Host-side launcher for the pass1SelectLists kernel on `stream`.
//
// NOTE(review): `k` is used by the dispatch and the assertion message below,
// but its declaration is not visible in this excerpt, nor are the remaining
// kernel arguments and the end of the function - confirm against the full
// definition.
89 runPass1SelectLists(Tensor<int, 2, true>& prefixSumOffsets,
90 Tensor<float, 1, true>& distance,
94 Tensor<float, 3, true>& heapDistances,
95 Tensor<int, 3, true>& heapIndices,
96 cudaStream_t stream) {
// Fixed block size; also passed to the kernel as its ThreadsPerBlock
// template argument.
97 constexpr
auto kThreadsPerBlock = 128;
// grid.x = heapDistances.getSize(1), grid.y = prefixSumOffsets.getSize(0).
// The kernel reads blockIdx.x as the slice id and blockIdx.y as the query id.
99 auto grid = dim3(heapDistances.getSize(1), prefixSumOffsets.getSize(0));
100 auto block = dim3(kThreadsPerBlock);
// RUN_PASS launches pass1SelectLists with the given template arguments using
// `grid`, `block`, 0 bytes of dynamic shared memory and `stream`.
// RUN_PASS_DIR picks the (NUM_WARP_Q, NUM_THREAD_Q) pair from k. The visible
// branches are: (1, 1) for the first case (its condition is not visible),
// then k <= 32 -> (32, 2), k <= 64 -> (64, 3), k <= 128 -> (128, 3),
// k <= 256 -> (256, 4), k <= 512 -> (512, 8), k <= 1024 -> (1024, 8).
// A FAISS_ASSERT_FMT(false, ...) reporting "unimplemented k value" follows.
102 #define RUN_PASS(NUM_WARP_Q, NUM_THREAD_Q, DIR) \
104 pass1SelectLists<kThreadsPerBlock, NUM_WARP_Q, NUM_THREAD_Q, DIR> \
105 <<<grid, block, 0, stream>>>(prefixSumOffsets, \
115 #define RUN_PASS_DIR(DIR) \
118 RUN_PASS(1, 1, DIR); \
119 } else if (k <= 32) { \
120 RUN_PASS(32, 2, DIR); \
121 } else if (k <= 64) { \
122 RUN_PASS(64, 3, DIR); \
123 } else if (k <= 128) { \
124 RUN_PASS(128, 3, DIR); \
125 } else if (k <= 256) { \
126 RUN_PASS(256, 4, DIR); \
127 } else if (k <= 512) { \
128 RUN_PASS(512, 8, DIR); \
129 } else if (k <= 1024) { \
130 RUN_PASS(1024, 8, DIR); \
141 FAISS_ASSERT_FMT(
false,
"unimplemented k value (%d)", k);