11 #include "IVFUtils.cuh"
12 #include "../utils/DeviceUtils.h"
13 #include "../utils/Select.cuh"
14 #include "../utils/StaticUtils.h"
15 #include "../utils/Tensor.cuh"
23 namespace faiss {
namespace gpu {
// Sentinel initial values for the k-selection heap in pass2SelectLists,
// which initializes BlockSelect with `Dir ? kMin : kMax`.
//
// A sentinel must compare worse than every real score. kMin therefore has to
// be the most negative finite float (`lowest()`). It must NOT be
// `numeric_limits<float>::min()`, which for floating types is the smallest
// *positive* normalized value (~1.18e-38). With that value, any candidate at
// or below ~1.2e-38 (e.g. a negative inner-product score) could not beat the
// initial sentinel when `Dir` is set, and would be lost from the results.
constexpr auto kMax = std::numeric_limits<float>::max();
constexpr auto kMin = std::numeric_limits<float>::lowest();
// Device helper: binary search over one row of prefix-summed offsets.
//
// From its call in pass2SelectLists:
//   - `prefixSumOffsets` is a single query's row (`prefixSumOffsets[queryId].data()`);
//   - the size argument is the row length (`prefixSumOffsets.getSize(1)`);
//   - the returned index is used as the probe number into
//     `topQueryToCentroid[queryId]`.
// The loop narrows [start, end) until it is empty. It then asserts that the
// final position is not `size`, i.e. the searched value must fall inside
// one of the buckets.
// NOTE(review): the remaining parameters, the comparison/update step and the
// return statement are not visible in this excerpt. Confirm the exact bucket
// boundary convention (< vs <=) there.
30 inline __device__
int binarySearchForBucket(
int* prefixSumOffsets,
36 while (end - start > 0) {
// Overflow-safe midpoint (avoids computing start + end directly).
37 int mid = start + (end - start) / 2;
39 int midVal = prefixSumOffsets[mid];
50 assert(start != size);
// Second-pass k-selection kernel. It reduces one query's candidate row to the
// final top-k and resolves user indices.
//
// Launch layout (see runPass2SelectLists): one thread block per query
// (queryId = blockIdx.x), with ThreadsPerBlock threads per block.
// Visible steps:
//   1. Stream the query's row of `heapDistances` (presumably candidates
//      produced by a first pass; confirm) through a block-wide BlockSelect.
//   2. Write the k selected distances from `smemK` to `outDistances[queryId]`.
//   3. For each selected entry:
//      - map its position `v` in `heapIndices` to a flat offset;
//      - find the probe/list it belongs to via `prefixSumOffsets`;
//      - fetch the user index from `listIndices` according to the indices
//        option.
// Shared memory (static): kNumWarps * NumWarpQ floats (smemK) and
// kNumWarps * NumWarpQ ints (smemV).
// NOTE(review): several template parameters, the `k` / `listIndices` / `opt`
// parameters and several statements of this kernel are not visible in this
// excerpt.
55 template <
int ThreadsPerBlock,
60 pass2SelectLists(Tensor<float, 2, true> heapDistances,
61 Tensor<int, 2, true> heapIndices,
63 Tensor<int, 2, true> prefixSumOffsets,
64 Tensor<int, 2, true> topQueryToCentroid,
67 Tensor<float, 2, true> outDistances,
68 Tensor<long, 2, true> outIndices) {
// Shared-memory key/value storage handed to BlockSelect below.
69 constexpr
int kNumWarps = ThreadsPerBlock / kWarpSize;
71 __shared__
float smemK[kNumWarps * NumWarpQ];
72 __shared__
int smemV[kNumWarps * NumWarpQ];
// Heap sentinel: kMin when Dir is set, kMax otherwise. -1 is the sentinel
// value index. The sentinel must compare worse than any real distance for the
// chosen direction; confirm against the kMin/kMax definitions.
74 constexpr
auto kInit = Dir ? kMin : kMax;
75 BlockSelect<float, int, Dir, Comparator<float>,
76 NumWarpQ, NumThreadQ, ThreadsPerBlock>
77 heap(kInit, -1, smemK, smemV, k);
// One block per query.
79 auto queryId = blockIdx.x;
80 int num = heapDistances.getSize(1);
// The main loop covers a warp-size multiple of elements via heap.add. The
// tail goes through heap.addThreadQ, presumably because add() must be
// called uniformly by whole warps.
81 int limit = utils::roundDown(num, kWarpSize);
84 auto heapDistanceStart = heapDistances[queryId];
88 for (; i < limit; i += blockDim.x) {
89 heap.add(heapDistanceStart[i], i);
94 heap.addThreadQ(heapDistanceStart[i], i);
// Emit the k selected distances. The final reduce/sync that fills smemK is
// not visible in this excerpt.
100 for (
int i = threadIdx.x; i < k; i += blockDim.x) {
101 outDistances[queryId][i] = smemK[i];
// `v` (assigned on a line not shown) is the selected entry's position in
// heapIndices. `offset` is its position in the concatenation of this query's
// probed lists.
118 int offset = heapIndices[queryId][v];
123 int probe = binarySearchForBucket(prefixSumOffsets[queryId].data(),
124 prefixSumOffsets.getSize(1),
// Probe number -> IVF list id for this query.
129 int listId = topQueryToCentroid[queryId][probe];
// The list's starting offset is the element just before `probe` in the
// prefix-sum row.
// NOTE(review): when probe == 0 this reads one element before the row start.
// Presumably the buffer is laid out with a leading zero so that read is
// valid; confirm at the allocation site.
133 int listStart = *(prefixSumOffsets[queryId][probe].data() - 1);
134 int listOffset = offset - listStart;
// User-index resolution:
//   INDICES_32_BIT -> stored int ids widened to long;
//   INDICES_64_BIT -> stored long ids;
//   otherwise (presumably the elided else-branch) -> (listId << 32) | offset.
// NOTE(review): `Tensor<long, ...>` and `(long) listId << 32` assume a
// 64-bit `long` (LP64). On LLP64 platforms `long` is 32 bits and the shift is
// undefined.
137 if (opt == INDICES_32_BIT) {
138 index = (long) ((
int*) listIndices[listId])[listOffset];
139 }
else if (opt == INDICES_64_BIT) {
140 index = ((
long*) listIndices[listId])[listOffset];
142 index = ((long) listId << 32 | (
long) listOffset);
146 outIndices[queryId][i] = index;
// Host-side launcher for pass2SelectLists on `stream`.
//
// It launches one 128-thread block per row of `topQueryToCentroid` (one per
// query). The kernel instantiation is chosen by `k`: the smallest visible
// NumWarpQ bucket covering it (32, 64, 128, 256, 512, 1024), or a
// RUN_PASS(1, 1, DIR) case whose condition is not shown. Larger k falls
// through to FAISS_ASSERT_FMT("unimplemented k value (%d)").
// `listIndices` is passed to the kernel as a raw device pointer array via
// `listIndices.data().get()`.
// NOTE(review): no launch error check (e.g. cudaGetLastError) is visible in
// this excerpt; confirm one exists in the elided lines.
151 runPass2SelectLists(Tensor<float, 2, true>& heapDistances,
152 Tensor<int, 2, true>& heapIndices,
153 thrust::device_vector<void*>& listIndices,
154 IndicesOptions indicesOptions,
155 Tensor<int, 2, true>& prefixSumOffsets,
156 Tensor<int, 2, true>& topQueryToCentroid,
159 Tensor<float, 2, true>& outDistances,
160 Tensor<long, 2, true>& outIndices,
161 cudaStream_t stream) {
// 128 threads per block (4 warps, assuming 32-thread warps).
162 constexpr
auto kThreadsPerBlock = 128;
// One block per query.
164 auto grid = dim3(topQueryToCentroid.getSize(0));
165 auto block = dim3(kThreadsPerBlock);
// RUN_PASS launches one (NumWarpQ, NumThreadQ, Dir) instantiation.
// RUN_PASS_DIR dispatches on k.
167 #define RUN_PASS(NUM_WARP_Q, NUM_THREAD_Q, DIR) \
169 pass2SelectLists<kThreadsPerBlock, \
170 NUM_WARP_Q, NUM_THREAD_Q, DIR> \
171 <<<grid, block, 0, stream>>>(heapDistances, \
173 listIndices.data().get(), \
175 topQueryToCentroid, \
184 #define RUN_PASS_DIR(DIR) \
187 RUN_PASS(1, 1, DIR); \
188 } else if (k <= 32) { \
189 RUN_PASS(32, 2, DIR); \
190 } else if (k <= 64) { \
191 RUN_PASS(64, 3, DIR); \
192 } else if (k <= 128) { \
193 RUN_PASS(128, 3, DIR); \
194 } else if (k <= 256) { \
195 RUN_PASS(256, 4, DIR); \
196 } else if (k <= 512) { \
197 RUN_PASS(512, 8, DIR); \
198 } else if (k <= 1024) { \
199 RUN_PASS(1024, 8, DIR); \
210 FAISS_ASSERT_FMT(
false,
"unimplemented k value (%d)", k);