9 #include "IVFUtils.cuh"
10 #include "../utils/DeviceDefs.cuh"
11 #include "../utils/DeviceUtils.h"
12 #include "../utils/Limits.cuh"
13 #include "../utils/Select.cuh"
14 #include "../utils/StaticUtils.h"
15 #include "../utils/Tensor.cuh"
22 namespace faiss {
namespace gpu {
// First-pass selection kernel: each thread block selects up to k
// (distance, index) pairs from one slice of one query's probed lists and
// writes them to heapDistances / heapIndices at [queryId][sliceId][0..k).
//
// Template parameters:
//   ThreadsPerBlock - divided by kWarpSize to obtain the number of warps, and
//                     forwarded to BlockSelect.
//   NumWarpQ        - forwarded to BlockSelect; also sizes the shared-memory
//                     key/value arrays (kNumWarps * NumWarpQ entries each).
//   NumThreadQ      - forwarded to BlockSelect.
//   Dir             - forwarded to BlockSelect; picks the initial key value
//                     (kFloatMin when true, kFloatMax when false).
//
// Launch layout as used in the body: blockIdx.y is the query, blockIdx.x is
// the slice, gridDim.x is the slice count, and threads stride by blockDim.x.
//
// NOTE(review): `nprobe`, `k` and the loop variable `i` are used below but
// their declarations are not among the lines visible here (the original line
// numbers skip 25, 28-29 and 61-63) - confirm against the full file.
24 template <
int ThreadsPerBlock,
int NumWarpQ,
int NumThreadQ,
bool Dir>
26 pass1SelectLists(Tensor<int, 2, true> prefixSumOffsets,
27 Tensor<float, 1, true> distance,
30 Tensor<float, 3, true> heapDistances,
31 Tensor<int, 3, true> heapIndices) {
// Number of warps in the block, derived from the compile-time block size.
32 constexpr
int kNumWarps = ThreadsPerBlock / kWarpSize;
// Shared-memory backing storage handed to BlockSelect: keys (distances) ...
34 __shared__
float smemK[kNumWarps * NumWarpQ];
// ... and values (indices), with the same number of entries.
35 __shared__
int smemV[kNumWarps * NumWarpQ];
// Initial key for the selection structure, chosen by the Dir flag.
37 constexpr
auto kInit = Dir ? kFloatMin : kFloatMax;
// Block-wide selection structure over (float key, int value) pairs,
// initialised with key kInit and value -1.
38 BlockSelect<float, int, Dir, Comparator<float>,
39 NumWarpQ, NumThreadQ, ThreadsPerBlock>
40 heap(kInit, -1, smemK, smemV, k);
// Map the 2-D grid onto (slice, query).
42 auto queryId = blockIdx.y;
43 auto sliceId = blockIdx.x;
44 auto numSlices = gridDim.x;
// Split the nprobe entries evenly across slices (integer division); the last
// slice runs to nprobe and so takes any remainder.
46 int sliceSize = (nprobe / numSlices);
47 int sliceStart = sliceSize * sliceId;
48 int sliceEnd = sliceId == (numSlices - 1) ? nprobe :
49 sliceStart + sliceSize;
// Raw pointer to this query's row of prefix-sum offsets.
50 auto offsets = prefixSumOffsets[queryId].data();
// Reads the entry immediately before offsets[sliceStart]. For sliceStart == 0
// this is one element before the start of the row.
// NOTE(review): presumably the caller guarantees a valid (zero) entry there -
// confirm where prefixSumOffsets is allocated.
53 int start = *(&offsets[sliceStart] - 1);
// Offset stored for the last entry of this slice.
54 int end = offsets[sliceEnd - 1];
// Number of distances this block has to examine.
56 int num = end - start;
// NOTE(review): presumably the largest multiple of kWarpSize <= num; the
// definition of utils::roundDown is not visible here.
57 int limit = utils::roundDown(num, kWarpSize);
// Pointer to the first distance belonging to this slice.
60 auto distanceStart = distance[start].data();
// Main loop over [.., limit), striding by blockDim.x; each element is added
// together with its position (start + i) in the flat distance array.
64 for (; i < limit; i += blockDim.x) {
65 heap.add(distanceStart[i], start + i);
// NOTE(review): the guard around this call and the lines that follow it
// (original lines 66-69 and 71-77) are not visible here; presumably this
// handles the elements past `limit`.
70 heap.addThreadQ(distanceStart[i], start + i);
// Copy the first k entries of the shared-memory key/value arrays to this
// (query, slice) row of the outputs, threads striding by blockDim.x.
78 for (
int i = threadIdx.x; i < k; i += blockDim.x) {
79 heapDistances[queryId][sliceId][i] = smemK[i];
80 heapIndices[queryId][sliceId][i] = smemV[i];
85 runPass1SelectLists(Tensor<int, 2, true>& prefixSumOffsets,
86 Tensor<float, 1, true>& distance,
90 Tensor<float, 3, true>& heapDistances,
91 Tensor<int, 3, true>& heapIndices,
92 cudaStream_t stream) {
94 FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);
96 auto grid = dim3(heapDistances.getSize(1), prefixSumOffsets.getSize(0));
98 #define RUN_PASS(BLOCK, NUM_WARP_Q, NUM_THREAD_Q, DIR) \
100 pass1SelectLists<BLOCK, NUM_WARP_Q, NUM_THREAD_Q, DIR> \
101 <<<grid, BLOCK, 0, stream>>>(prefixSumOffsets, \
111 #if GPU_MAX_SELECTION_K >= 2048
114 #define RUN_PASS_DIR(DIR) \
117 RUN_PASS(128, 1, 1, DIR); \
118 } else if (k <= 32) { \
119 RUN_PASS(128, 32, 2, DIR); \
120 } else if (k <= 64) { \
121 RUN_PASS(128, 64, 3, DIR); \
122 } else if (k <= 128) { \
123 RUN_PASS(128, 128, 3, DIR); \
124 } else if (k <= 256) { \
125 RUN_PASS(128, 256, 4, DIR); \
126 } else if (k <= 512) { \
127 RUN_PASS(128, 512, 8, DIR); \
128 } else if (k <= 1024) { \
129 RUN_PASS(128, 1024, 8, DIR); \
130 } else if (k <= 2048) { \
131 RUN_PASS(64, 2048, 8, DIR); \
137 #define RUN_PASS_DIR(DIR) \
140 RUN_PASS(128, 1, 1, DIR); \
141 } else if (k <= 32) { \
142 RUN_PASS(128, 32, 2, DIR); \
143 } else if (k <= 64) { \
144 RUN_PASS(128, 64, 3, DIR); \
145 } else if (k <= 128) { \
146 RUN_PASS(128, 128, 3, DIR); \
147 } else if (k <= 256) { \
148 RUN_PASS(128, 256, 4, DIR); \
149 } else if (k <= 512) { \
150 RUN_PASS(128, 512, 8, DIR); \
151 } else if (k <= 1024) { \
152 RUN_PASS(128, 1024, 8, DIR); \
156 #endif // GPU_MAX_SELECTION_K