#include "PQScanMultiPassPrecomputed.cuh"
#include "../GpuResources.h"
#include "PQCodeLoad.cuh"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceTensor.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/StaticUtils.h"

#include <algorithm>
#include <limits>
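
// Precomputed-table IVFPQ scan. For an L2 IVFADC index, the distance from a
// query x to an encoded vector (coarse centroid yC plus PQ-coded residual yR)
// splits into three terms:
//   term1 = ||x - yC||^2            (per query, per probed list)
//   term2 = ||yR||^2 + 2 <yC, yR>   (query-independent, per list)
//   term3 = -2 <x, yR>              (per query, per sub-quantizer code)
// The kernel below sums term2 and term3 into a shared-memory lookup table so
// that scanning a code costs one table lookup per sub-quantizer.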

namespace faiss { namespace gpu {

// Loads the precomputed term 2 and term 3 tables for one (list, query) pair
// into shared memory, summing them element-wise.
template <typename LookupT, typename LookupVecT>
inline __device__ void
loadPrecomputedTerm(LookupT* smem,
                    LookupT* term2Start,
                    LookupT* term3Start,
                    int numCodes) {
  constexpr int kWordSize = sizeof(LookupVecT) / sizeof(LookupT);

  // Vectorized loads are only safe if the number of codes is a multiple of
  // the vector width, which also guarantees alignment of each slice.
  if (numCodes % kWordSize == 0) {
    constexpr int kUnroll = 2;

    // Number of whole vector words the block can load with full unrolling
    int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x);
    limitVec *= kUnroll * blockDim.x;

    LookupVecT* smemV = (LookupVecT*) smem;
    LookupVecT* term2StartV = (LookupVecT*) term2Start;
    LookupVecT* term3StartV = (LookupVecT*) term3Start;

    for (int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) {
      LookupVecT vals[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] =
          LoadStore<LookupVecT>::load(&term2StartV[i + j * blockDim.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        LookupVecT q =
          LoadStore<LookupVecT>::load(&term3StartV[i + j * blockDim.x]);

        vals[j] = Math<LookupVecT>::add(vals[j], q);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        LoadStore<LookupVecT>::store(&smemV[i + j * blockDim.x], vals[j]);
      }
    }

    // Handle the remainder that did not fit into the vectorized loop
    int remainder = limitVec * kWordSize;

    for (int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) {
      smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
    }
  } else {
    // Potentially unaligned data: scalar loads with manual unrolling
    constexpr int kUnroll = 4;

    int limit = utils::roundDown(numCodes, kUnroll * blockDim.x);

    int i = threadIdx.x;
    for (; i < limit; i += kUnroll * blockDim.x) {
      LookupT vals[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] = term2Start[i + j * blockDim.x];
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] = Math<LookupT>::add(vals[j], term3Start[i + j * blockDim.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        smem[i + j * blockDim.x] = vals[j];
      }
    }

    // Remainder that did not fit into the unrolled loop
    for (; i < numCodes; i += blockDim.x) {
      smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
    }
  }
}
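
// Scan kernel: each block handles one (query, probe) pair. It builds the
// term2 + term3 lookup table in shared memory, then strides over the
// inverted list's PQ codes, accumulating one table lookup per sub-quantizer
// and writing the resulting distances at this pair's prefix-sum offset.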
template <int NumSubQuantizers, typename LookupT, typename LookupVecT>
__global__ void
pqScanPrecomputedMultiPass(Tensor<float, 2, true> queries,
                           Tensor<float, 2, true> precompTerm1,
                           Tensor<LookupT, 3, true> precompTerm2,
                           Tensor<LookupT, 3, true> precompTerm3,
                           Tensor<int, 2, true> topQueryToCentroid,
                           void** listCodes,
                           int* listLengths,
                           Tensor<int, 2, true> prefixSumOffsets,
                           Tensor<float, 1, true> distance) {
  // Shared-memory storage for the summed term 2 + term 3 lookup table,
  // laid out as (sub-quantizer, code id)
  extern __shared__ char smemTerm23[];
  LookupT* term23 = (LookupT*) smemTerm23;

  // Each block handles a single (query, probe) pair
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;
  auto codesPerSubQuantizer = precompTerm2.getSize(2);
  auto precompTermSize = precompTerm2.getSize(1) * codesPerSubQuantizer;

  // Where we start writing out distances; the prefix-sum array stores a 0
  // at offset -1, so the boundary case needs no branch
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);
  float* distanceOut = distance[outBase].data();

  auto listId = topQueryToCentroid[queryId][probeId];

  unsigned char* codeList = (unsigned char*) listCodes[listId];
  int limit = listLengths[listId];

  constexpr int kNumCode32 = NumSubQuantizers <= 4 ? 1 :
    (NumSubQuantizers / 4);
  unsigned int code32[kNumCode32];
  unsigned int nextCode32[kNumCode32];

  // Double-buffer the code loads to overlap them with the distance math
  if (threadIdx.x < limit) {
    LoadCode32<NumSubQuantizers>::load(code32, codeList, threadIdx.x);
  }

  // Load term 1 and build the summed term 2 + 3 table in shared memory
  float term1 = precompTerm1[queryId][probeId];
  loadPrecomputedTerm<LookupT, LookupVecT>(term23,
                                           precompTerm2[listId].data(),
                                           precompTerm3[queryId].data(),
                                           precompTermSize);

  // Make sure the shared-memory table is fully written before reading it
  __syncthreads();

  // Each thread scans codes with a block-wide stride
  for (int codeIndex = threadIdx.x;
       codeIndex < limit;
       codeIndex += blockDim.x) {
    // Prefetch the codes for the next iteration
    if (codeIndex + blockDim.x < limit) {
      LoadCode32<NumSubQuantizers>::load(
        nextCode32, codeList, codeIndex + blockDim.x);
    }

    float dist = term1;

#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      constexpr int kBytesPerCode32 =
        NumSubQuantizers < 4 ? NumSubQuantizers : 4;

      if (kBytesPerCode32 == 1) {
        auto code = code32[0];
        dist = ConvertTo<float>::to(term23[code]);
      } else {
#pragma unroll
        for (int byte = 0; byte < kBytesPerCode32; ++byte) {
          auto code = getByte(code32[word], byte * 8, 8);

          auto offset =
            codesPerSubQuantizer * (word * kBytesPerCode32 + byte);

          dist += ConvertTo<float>::to(term23[offset + code]);
        }
      }
    }

    // Write out the intermediate distance; indices are recovered later from
    // the prefix-sum offsets during the selection passes
    distanceOut[codeIndex] = dist;

    // Rotate the prefetched codes into place for the next iteration
#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      code32[word] = nextCode32[word];
    }
  }
}
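
// Runs one tile of queries on a single stream: computes per-(query, probe)
// write offsets, launches the scan kernel for the requested code size and
// lookup element type, then k-selects the flat distances in two passes.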
void
runMultiPassTile(Tensor<float, 2, true>& queries,
                 Tensor<float, 2, true>& precompTerm1,
                 NoTypeTensor<3, true>& precompTerm2,
                 NoTypeTensor<3, true>& precompTerm3,
                 Tensor<int, 2, true>& topQueryToCentroid,
                 bool useFloat16Lookup,
                 int bytesPerCode,
                 int numSubQuantizers,
                 int numSubQuantizerCodes,
                 thrust::device_vector<void*>& listCodes,
                 thrust::device_vector<void*>& listIndices,
                 IndicesOptions indicesOptions,
                 thrust::device_vector<int>& listLengths,
                 Tensor<char, 1, true>& thrustMem,
                 Tensor<int, 2, true>& prefixSumOffsets,
                 Tensor<float, 1, true>& allDistances,
                 Tensor<float, 3, true>& heapDistances,
                 Tensor<int, 3, true>& heapIndices,
                 int k,
                 Tensor<float, 2, true>& outDistances,
                 Tensor<long, 2, true>& outIndices,
                 cudaStream_t stream) {
  // Calculate offset lengths so we know where to write the intermediate
  // distances for each (query, probe) pair
  runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets,
                     thrustMem, stream);

  auto kThreadsPerBlock = 256;

  auto grid = dim3(topQueryToCentroid.getSize(1),
                   topQueryToCentroid.getSize(0));
  auto block = dim3(kThreadsPerBlock);

  // Shared memory for the precomputed term 2 + 3 lookup table
  auto smem = sizeof(float);
#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    smem = sizeof(half);
  }
#endif
  smem *= numSubQuantizers * numSubQuantizerCodes;
  FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice());
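
// The scan kernel is templated on the number of sub-quantizers and on the
// lookup-table element/vector types. RUN_PQ_OPT instantiates and launches a
// specific specialization; RUN_PQ picks float16 or float32 lookups.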
#define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T)             \
  do {                                                            \
    auto precompTerm2T = precompTerm2.toTensor<LOOKUP_T>();       \
    auto precompTerm3T = precompTerm3.toTensor<LOOKUP_T>();       \
                                                                  \
    pqScanPrecomputedMultiPass<NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T> \
      <<<grid, block, smem, stream>>>(                            \
        queries,                                                  \
        precompTerm1,                                             \
        precompTerm2T,                                            \
        precompTerm3T,                                            \
        topQueryToCentroid,                                       \
        listCodes.data().get(),                                   \
        listLengths.data().get(),                                 \
        prefixSumOffsets,                                         \
        allDistances);                                            \
  } while (0)

#ifdef FAISS_USE_FLOAT16
#define RUN_PQ(NUM_SUB_Q)                   \
  do {                                      \
    if (useFloat16Lookup) {                 \
      RUN_PQ_OPT(NUM_SUB_Q, half, Half8);   \
    } else {                                \
      RUN_PQ_OPT(NUM_SUB_Q, float, float4); \
    }                                       \
  } while (0)
#else
#define RUN_PQ(NUM_SUB_Q)                   \
  do {                                      \
    RUN_PQ_OPT(NUM_SUB_Q, float, float4);   \
  } while (0)
#endif // FAISS_USE_FLOAT16

  switch (bytesPerCode) {
    case 1:
      RUN_PQ(1);
      break;
    case 2:
      RUN_PQ(2);
      break;
    case 3:
      RUN_PQ(3);
      break;
    case 4:
      RUN_PQ(4);
      break;
    case 8:
      RUN_PQ(8);
      break;
    case 12:
      RUN_PQ(12);
      break;
    case 16:
      RUN_PQ(16);
      break;
    case 20:
      RUN_PQ(20);
      break;
    case 24:
      RUN_PQ(24);
      break;
    case 28:
      RUN_PQ(28);
      break;
    case 32:
      RUN_PQ(32);
      break;
    case 40:
      RUN_PQ(40);
      break;
    case 48:
      RUN_PQ(48);
      break;
    case 56:
      RUN_PQ(56);
      break;
    case 64:
      RUN_PQ(64);
      break;
    case 96:
      RUN_PQ(96);
      break;
    default:
      FAISS_ASSERT(false);
      break;
  }

#undef RUN_PQ
#undef RUN_PQ_OPT

  // k-select the intermediate output in chunks, to increase parallelism
  runPass1SelectLists(prefixSumOffsets,
                      allDistances,
                      topQueryToCentroid.getSize(1),
                      k,
                      false, // L2 distance chooses the smallest
                      heapDistances,
                      heapIndices,
                      stream);

  // k-select the final output
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      topQueryToCentroid,
                      k,
                      false, // L2 distance chooses the smallest
                      outDistances,
                      outIndices,
                      stream);

  CUDA_VERIFY(cudaGetLastError());
}
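
// Entry point: sizes a query tile to fit the available temporary memory,
// then processes the tiles double-buffered across two alternate streams.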
void runPQScanMultiPassPrecomputed(Tensor<float, 2, true>& queries,
                                   Tensor<float, 2, true>& precompTerm1,
                                   NoTypeTensor<3, true>& precompTerm2,
                                   NoTypeTensor<3, true>& precompTerm3,
                                   Tensor<int, 2, true>& topQueryToCentroid,
                                   bool useFloat16Lookup,
                                   int bytesPerCode,
                                   int numSubQuantizers,
                                   int numSubQuantizerCodes,
                                   thrust::device_vector<void*>& listCodes,
                                   thrust::device_vector<void*>& listIndices,
                                   IndicesOptions indicesOptions,
                                   thrust::device_vector<int>& listLengths,
                                   int maxListLength,
                                   int k,
                                   // output
                                   Tensor<float, 2, true>& outDistances,
                                   // output
                                   Tensor<long, 2, true>& outIndices,
                                   GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;

  int nprobe = topQueryToCentroid.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};
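
  // Pick the query tile size so that the double-buffered temporaries
  // (prefix-sum offsets, flat distances and first-pass heaps for both
  // streams) fit in the memory the allocator currently has available.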
  size_t sizeAvailable = mem.getSizeAvailable();

  // The first selection pass k-selects each query's distances in
  // kNProbeSplit-sized chunks of probes; the second pass reduces those
  // chunk-level results to the final k
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  // Temporary storage needed per query, counting both streams
  size_t sizePerQuery =
    2 * // # of streams
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     sizeForFirstSelectPass);

  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // Indexing into the flat distance array is done with 32-bit ints
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <=
               std::numeric_limits<int>::max());

  // Temporary memory buffers, double-buffered across the two streams.
  // Each prefix-sum buffer has one extra leading element that is kept at 0,
  // so the kernel can read offset -1 without a boundary check.
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};

  // Zero the leading element of each buffer
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0,
                              sizeof(int),
                              stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0,
                              sizeof(int),
                              stream));

  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};

  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});
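
  // Walk the queries tile by tile, alternating between the two streams so
  // one tile's kernels can overlap the next tile's setup work.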
  int curStream = 0;

  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto coarseIndicesView =
      topQueryToCentroid.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);
    auto term1View =
      precompTerm1.narrowOutermost(query, numQueriesInTile);
    auto term3View =
      precompTerm3.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runMultiPassTile(queryView,
                     term1View,
                     precompTerm2,
                     term3View,
                     coarseIndicesView,
                     useFloat16Lookup,
                     bytesPerCode,
                     numSubQuantizers,
                     numSubQuantizerCodes,
                     listCodes,
                     listIndices,
                     indicesOptions,
                     listLengths,
                     *thrustMem[curStream],
                     prefixSumOffsetsView,
                     *allDistances[curStream],
                     heapDistancesView,
                     heapIndicesView,
                     k,
                     outDistanceView,
                     outIndicesView,
                     streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  streamWait({stream}, streams);
}

} } // namespace