9 #include "IVFUtils.cuh"
10 #include "../utils/DeviceDefs.cuh"
11 #include "../utils/DeviceUtils.h"
12 #include "../utils/Limits.cuh"
13 #include "../utils/Select.cuh"
14 #include "../utils/StaticUtils.h"
15 #include "../utils/Tensor.cuh"
22 namespace faiss {
namespace gpu {
// Find the bucket that contains the flat offset `val`: returns the
// first index i in [0, size) such that prefixSumOffsets[i] > val.
// `prefixSumOffsets` holds `size` monotonically non-decreasing
// exclusive end offsets (a prefix sum of bucket lengths).
// NOTE(review): extraction dropped the parameter list, the search
// update branches and the return; reconstructed as a standard
// lower-bound binary search matching the visible skeleton.
inline __device__
int binarySearchForBucket(int* prefixSumOffsets,
                          int size,
                          int val) {
  int start = 0;
  int end = size;

  // Invariant: the containing bucket index lies in [start, end)
  while (end - start > 0) {
    int mid = start + (end - start) / 2;

    int midVal = prefixSumOffsets[mid];

    // If this bucket ends at or before `val`, the containing bucket
    // must be strictly to the right of mid; otherwise mid or left
    if (midVal <= val) {
      start = mid + 1;
    } else {
      end = mid;
    }
  }

  // `val` must fall within some bucket
  assert(start != size);

  return start;
}
// Second-pass k-selection kernel: merges each query's partial
// candidate results (distances plus intermediate flat offsets from
// the first-pass list scan) into the final top-k, and translates the
// winning offsets back into user indices.
// Launch layout: one block per query (grid.x == #queries),
// ThreadsPerBlock threads; shared memory is statically sized from the
// template parameters.
// NOTE(review): extraction dropped several lines (template params,
// `void** listIndices` / `int k` / `IndicesOptions opt` parameters,
// the remainder guard, `heap.reduce()`, and the `v != -1` scaffolding);
// reconstructed from the visible uses of `opt`, `listIndices`, `k`,
// `smemV` and the surviving control-flow skeleton.
template <int ThreadsPerBlock,
          int NumWarpQ,
          int NumThreadQ,
          bool Dir>
__global__ void
pass2SelectLists(Tensor<float, 2, true> heapDistances,
                 Tensor<int, 2, true> heapIndices,
                 void** listIndices,
                 Tensor<int, 2, true> prefixSumOffsets,
                 Tensor<int, 2, true> topQueryToCentroid,
                 int k,
                 IndicesOptions opt,
                 Tensor<float, 2, true> outDistances,
                 Tensor<long, 2, true> outIndices) {
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __shared__ float smemK[kNumWarps * NumWarpQ];
  __shared__ int smemV[kNumWarps * NumWarpQ];

  // Dir == true selects the largest values, so seed the heap with the
  // minimum sentinel; otherwise with the maximum
  constexpr auto kInit = Dir ? kFloatMin : kFloatMax;
  BlockSelect<float, int, Dir, Comparator<float>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(kInit, -1, smemK, smemV, k);

  auto queryId = blockIdx.x;
  int num = heapDistances.getSize(1);
  int limit = utils::roundDown(num, kWarpSize);

  int i = threadIdx.x;
  auto heapDistanceStart = heapDistances[queryId];

  // heap.add() must be called with a full warp; only the
  // warp-size-aligned portion of the input is handled here
  for (; i < limit; i += blockDim.x) {
    heap.add(heapDistanceStart[i], i);
  }

  // Handle the non-warp-multiple remainder separately, via the
  // divergence-tolerant thread-queue path
  if (i < num) {
    heap.addThreadQ(heapDistanceStart[i], i);
  }

  // Merge all per-warp results into smemK / smemV
  heap.reduce();

  for (int i = threadIdx.x; i < k; i += blockDim.x) {
    outDistances[queryId][i] = smemK[i];

    // `v` is the index within `heapIndices` of the candidate that won
    // this output slot; -1 means the slot was never filled.
    // This code is highly divergent, but that is acceptable: it runs
    // only (#queries x k) times as the very last step.
    int v = smemV[i];
    long index = -1;

    if (v != -1) {
      // `offset` is the intermediate flat offset produced by the
      // first-pass list scan
      int offset = heapIndices[queryId][v];

      // Determine which of this query's probed lists the offset falls
      // in by binary searching the per-query prefix sum of list sizes
      int probe = binarySearchForBucket(prefixSumOffsets[queryId].data(),
                                        prefixSumOffsets.getSize(1),
                                        offset);

      // The probe ordinal maps back to the IVF list that was scanned
      int listId = topQueryToCentroid[queryId][probe];

      // Offset within the list is relative to the bucket start; the
      // value stored immediately before this query's first bucket end
      // is assumed to be 0 by construction of prefixSumOffsets
      int listStart = *(prefixSumOffsets[queryId][probe].data() - 1);
      int listOffset = offset - listStart;

      // Translate (listId, listOffset) to the user index, depending
      // on how indices are stored for this index
      if (opt == INDICES_32_BIT) {
        index = (long) ((int*) listIndices[listId])[listOffset];
      } else if (opt == INDICES_64_BIT) {
        index = ((long*) listIndices[listId])[listOffset];
      } else {
        // Indices live on the CPU (or nowhere); encode the location
        // as (listId << 32 | listOffset) for later resolution
        index = ((long) listId << 32 | (long) listOffset);
      }
    }

    outIndices[queryId][i] = index;
  }
}
// Host-side dispatcher for pass2SelectLists: selects compile-time
// k-selection parameters (block size, warp-queue and thread-queue
// lengths) from the runtime `k` and direction `chooseLargest`, then
// launches the kernel with one block per query on `stream`.
// Asserts (fatally) if `k` exceeds the compiled-in selection limit.
// NOTE(review): extraction dropped the return type, the `int k` /
// `bool chooseLargest` parameters, the macro do/while scaffolding,
// several launch arguments, the `#else`, the final dispatch and the
// closing brace; reconstructed to match the surviving k-threshold
// tables and the `#if GPU_MAX_SELECTION_K >= 2048` structure.
void
runPass2SelectLists(Tensor<float, 2, true>& heapDistances,
                    Tensor<int, 2, true>& heapIndices,
                    thrust::device_vector<void*>& listIndices,
                    IndicesOptions indicesOptions,
                    Tensor<int, 2, true>& prefixSumOffsets,
                    Tensor<int, 2, true>& topQueryToCentroid,
                    int k,
                    bool chooseLargest,
                    Tensor<float, 2, true>& outDistances,
                    Tensor<long, 2, true>& outIndices,
                    cudaStream_t stream) {
  // One block per query
  auto grid = dim3(topQueryToCentroid.getSize(0));

  // Launch one concrete template instantiation and return on success;
  // the error check surfaces launch-configuration failures
#define RUN_PASS(BLOCK, NUM_WARP_Q, NUM_THREAD_Q, DIR)                  \
  do {                                                                  \
    pass2SelectLists<BLOCK, NUM_WARP_Q, NUM_THREAD_Q, DIR>              \
      <<<grid, BLOCK, 0, stream>>>(heapDistances,                       \
                                   heapIndices,                         \
                                   listIndices.data().get(),            \
                                   prefixSumOffsets,                    \
                                   topQueryToCentroid,                  \
                                   k,                                   \
                                   indicesOptions,                      \
                                   outDistances,                        \
                                   outIndices);                         \
    CUDA_TEST_ERROR();                                                  \
    return; /* success */                                               \
  } while (0)

#if GPU_MAX_SELECTION_K >= 2048
  // k == 2048 requires a smaller block (64 threads) so the larger
  // per-warp queues still fit in registers / shared memory
#define RUN_PASS_DIR(DIR)                                \
  do {                                                   \
    if (k == 1) {                                        \
      RUN_PASS(128, 1, 1, DIR);                          \
    } else if (k <= 32) {                                \
      RUN_PASS(128, 32, 2, DIR);                         \
    } else if (k <= 64) {                                \
      RUN_PASS(128, 64, 3, DIR);                         \
    } else if (k <= 128) {                               \
      RUN_PASS(128, 128, 3, DIR);                        \
    } else if (k <= 256) {                               \
      RUN_PASS(128, 256, 4, DIR);                        \
    } else if (k <= 512) {                               \
      RUN_PASS(128, 512, 8, DIR);                        \
    } else if (k <= 1024) {                              \
      RUN_PASS(128, 1024, 8, DIR);                       \
    } else if (k <= 2048) {                              \
      RUN_PASS(64, 2048, 8, DIR);                        \
    }                                                    \
  } while (0)
#else
#define RUN_PASS_DIR(DIR)                                \
  do {                                                   \
    if (k == 1) {                                        \
      RUN_PASS(128, 1, 1, DIR);                          \
    } else if (k <= 32) {                                \
      RUN_PASS(128, 32, 2, DIR);                         \
    } else if (k <= 64) {                                \
      RUN_PASS(128, 64, 3, DIR);                         \
    } else if (k <= 128) {                               \
      RUN_PASS(128, 128, 3, DIR);                        \
    } else if (k <= 256) {                               \
      RUN_PASS(128, 256, 4, DIR);                        \
    } else if (k <= 512) {                               \
      RUN_PASS(128, 512, 8, DIR);                        \
    } else if (k <= 1024) {                              \
      RUN_PASS(128, 1024, 8, DIR);                       \
    }                                                    \
  } while (0)
#endif // GPU_MAX_SELECTION_K

  if (chooseLargest) {
    RUN_PASS_DIR(true);
  } else {
    RUN_PASS_DIR(false);
  }

#undef RUN_PASS
#undef RUN_PASS_DIR

  // Reaching here means no configuration matched `k`
  FAISS_ASSERT_FMT(false, "unimplemented k value (%d)", k);
}