#include "PQScanMultiPassPrecomputed.cuh"
#include "../GpuResources.h"
#include "PQCodeLoad.cuh"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceTensor.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/StaticUtils.h"

#include <limits>
namespace faiss { namespace gpu {
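
// Distance computation with precomputed codes relies on the IVFPQ
// decomposition of the L2 distance between a query x and a database
// vector with coarse centroid y_C and PQ reconstruction y_R:
//
//   ||x - y_C - y_R||^2 = ||x - y_C||^2            (term 1)
//                       + ||y_R||^2 + 2 <y_C, y_R> (term 2)
//                       - 2 <x, y_R>               (term 3)
//
// Term 1 depends on (query, probed list), term 2 only on the list and
// code (query independent, precomputed once), and term 3 on the query
// and code. Terms 2 and 3 share the (sub quantizer, code id) shape,
// so they are summed once into shared memory below; scanning a code
// then costs one table lookup per sub-quantizer byte.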
// For precomputed codes, this calculates and loads the summed
// (term 2 + term 3) code distances into shared memory
template <typename LookupT, typename LookupVecT>
inline __device__ void
loadPrecomputedTerm(LookupT* smem,
                    LookupT* term2Start,
                    LookupT* term3Start,
                    int numCodes) {
  constexpr int kWordSize = sizeof(LookupVecT) / sizeof(LookupT);
  // We can only use vectorized loads if the data is guaranteed to be
  // aligned: the codes are innermost, so an evenly divisible count
  // means every slice is aligned as well
  if (numCodes % kWordSize == 0) {
    constexpr int kUnroll = 2;

    // limitVec is the number of vector words covered by whole
    // kUnroll x blockDim.x strides of the block
    int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x);
    limitVec *= kUnroll * blockDim.x;

    LookupVecT* smemV = (LookupVecT*) smem;
    LookupVecT* term2StartV = (LookupVecT*) term2Start;
    LookupVecT* term3StartV = (LookupVecT*) term3Start;

    for (int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) {
      LookupVecT vals[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] =
          LoadStore<LookupVecT>::load(&term2StartV[i + j * blockDim.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        LookupVecT q =
          LoadStore<LookupVecT>::load(&term3StartV[i + j * blockDim.x]);

        vals[j] = Math<LookupVecT>::add(vals[j], q);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        LoadStore<LookupVecT>::store(&smemV[i + j * blockDim.x], vals[j]);
      }
    }

    // Handle the remainder that does not fit a whole
    // kUnroll x blockDim.x stride, element-wise
    int remainder = limitVec * kWordSize;

    for (int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) {
      smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
    }
  } else {
    // Potentially unaligned data; load and sum element-wise
    constexpr int kUnroll = 4;

    int limit = utils::roundDown(numCodes, kUnroll * blockDim.x);

    int i = threadIdx.x;
    for (; i < limit; i += kUnroll * blockDim.x) {
      LookupT vals[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] = term2Start[i + j * blockDim.x];
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] = Math<LookupT>::add(vals[j], term3Start[i + j * blockDim.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        smem[i + j * blockDim.x] = vals[j];
      }
    }

    for (; i < numCodes; i += blockDim.x) {
      smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
    }
  }
}
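
// For example, with LookupT = float and LookupVecT = float4,
// kWordSize above is 4 and each vectorized iteration moves 16 bytes
// per thread; with LookupT = half and LookupVecT = Half8 it is 8.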
template <int NumSubQuantizers, typename LookupT, typename LookupVecT>
__global__ void
pqScanPrecomputedMultiPass(Tensor<float, 2, true> queries,
                           Tensor<float, 2, true> precompTerm1,
                           Tensor<LookupT, 3, true> precompTerm2,
                           Tensor<LookupT, 3, true> precompTerm3,
                           Tensor<int, 2, true> topQueryToCentroid,
                           void** listCodes,
                           int* listLengths,
                           Tensor<int, 2, true> prefixSumOffsets,
                           Tensor<float, 1, true> distance) {
  // Shared memory storage for the summed precomputed terms 2 + 3,
  // laid out as (sub quantizer)(code id)
  extern __shared__ char smemTerm23[];
  LookupT* term23 = (LookupT*) smemTerm23;

  // Each block handles a single (query, probe) pair
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;
  auto codesPerSubQuantizer = precompTerm2.getSize(2);
  auto precompTermSize = precompTerm2.getSize(1) * codesPerSubQuantizer;

  // Where we start writing out data; the element before our section
  // (at offset -1) is guaranteed to be 0
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);
  float* distanceOut = distance[outBase].data();

  auto listId = topQueryToCentroid[queryId][probeId];
  // Safety guard in case NaNs in input cause no list ID to be generated
  if (listId == -1) {
    return;
  }

  unsigned char* codeList = (unsigned char*) listCodes[listId];
  int limit = listLengths[listId];

  constexpr int kNumCode32 = NumSubQuantizers <= 4 ? 1 :
    (NumSubQuantizers / 4);
  unsigned int code32[kNumCode32];
  unsigned int nextCode32[kNumCode32];

  // We double-buffer the code loading, which improves memory
  // utilization: issue the first load before the shared memory fill
  if (threadIdx.x < limit) {
    LoadCode32<NumSubQuantizers>::load(code32, codeList, threadIdx.x);
  }

  // Load precomputed terms 2 + 3 into shared memory; term 1 is a
  // single scalar per (query, probe) pair
  float term1 = precompTerm1[queryId][probeId];
  loadPrecomputedTerm<LookupT, LookupVecT>(term23,
                                           precompTerm2[listId].data(),
                                           precompTerm3[queryId].data(),
                                           precompTermSize);

  // Make sure the shared memory fill is visible to all threads before
  // we start reading from it
  __syncthreads();
  // Each thread handles one code element in the list, with a
  // block-wide stride
  for (int codeIndex = threadIdx.x;
       codeIndex < limit;
       codeIndex += blockDim.x) {
    // Prefetch the next code while we process the current one
    if (codeIndex + blockDim.x < limit) {
      LoadCode32<NumSubQuantizers>::load(
        nextCode32, codeList, codeIndex + blockDim.x);
    }

    float dist = term1;

#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      constexpr int kBytesPerCode32 =
        NumSubQuantizers < 4 ? NumSubQuantizers : 4;

      if (kBytesPerCode32 == 1) {
        auto code = code32[0];
        dist = ConvertTo<float>::to(term23[code]);
      } else {
#pragma unroll
        for (int byte = 0; byte < kBytesPerCode32; ++byte) {
          auto code = getByte(code32[word], byte * 8, 8);

          auto offset =
            codesPerSubQuantizer * (word * kBytesPerCode32 + byte);

          dist += ConvertTo<float>::to(term23[offset + code]);
        }
      }
    }

    // Write out the intermediate distance; indices are recovered
    // later from the prefix sum offsets, which keeps global memory
    // traffic down
    distanceOut[codeIndex] = dist;

    // Rotate the double buffer: the prefetched code becomes current
#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      code32[word] = nextCode32[word];
    }
  }
}
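
// Note: the grid launched in runMultiPassTile below is
// (nprobe, numQueries), so each block of the kernel above scans
// exactly one inverted list for one query, writing its section of the
// flat distance array delimited by prefixSumOffsets.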
void
runMultiPassTile(Tensor<float, 2, true>& queries,
                 Tensor<float, 2, true>& precompTerm1,
                 NoTypeTensor<3, true>& precompTerm2,
                 NoTypeTensor<3, true>& precompTerm3,
                 Tensor<int, 2, true>& topQueryToCentroid,
                 bool useFloat16Lookup,
                 int bytesPerCode,
                 int numSubQuantizers,
                 int numSubQuantizerCodes,
                 thrust::device_vector<void*>& listCodes,
                 thrust::device_vector<void*>& listIndices,
                 IndicesOptions indicesOptions,
                 thrust::device_vector<int>& listLengths,
                 Tensor<char, 1, true>& thrustMem,
                 Tensor<int, 2, true>& prefixSumOffsets,
                 Tensor<float, 1, true>& allDistances,
                 Tensor<float, 3, true>& heapDistances,
                 Tensor<int, 3, true>& heapIndices,
                 int k,
                 Tensor<float, 2, true>& outDistances,
                 Tensor<long, 2, true>& outIndices,
                 cudaStream_t stream) {
  // Calculate offset lengths, so we know where to write out
  // intermediate results
  runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets,
                     thrustMem, stream);
  // Convert all codes to distances, writing out the intermediate
  // result for every (query, probe) pair
  {
    auto kThreadsPerBlock = 256;

    auto grid = dim3(topQueryToCentroid.getSize(1),
                     topQueryToCentroid.getSize(0));
    auto block = dim3(kThreadsPerBlock);

    // Shared memory for the summed precomputed terms (2 + 3)
    auto smem = sizeof(float);
#ifdef FAISS_USE_FLOAT16
    if (useFloat16Lookup) {
      smem = sizeof(half);
    }
#endif
    smem *= numSubQuantizers * numSubQuantizerCodes;
    FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice());
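
    // As a sizing example: 64 sub quantizers x 256 codes is
    // 64 * 256 * 4 = 64 KiB of float32 lookups, which exceeds the
    // common 48 KiB static shared memory limit per block, but only
    // 32 KiB with float16 lookups; this is one reason the
    // useFloat16Lookup option exists.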
#define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T)                   \
    do {                                                                \
      auto precompTerm2T = precompTerm2.toTensor<LOOKUP_T>();           \
      auto precompTerm3T = precompTerm3.toTensor<LOOKUP_T>();           \
      pqScanPrecomputedMultiPass<NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T>     \
        <<<grid, block, smem, stream>>>(                                \
          queries,                                                      \
          precompTerm1,                                                 \
          precompTerm2T,                                                \
          precompTerm3T,                                                \
          topQueryToCentroid,                                           \
          listCodes.data().get(),                                       \
          listLengths.data().get(),                                     \
          prefixSumOffsets,                                             \
          allDistances);                                                \
    } while (0)

#ifdef FAISS_USE_FLOAT16
#define RUN_PQ(NUM_SUB_Q)                       \
    do {                                        \
      if (useFloat16Lookup) {                   \
        RUN_PQ_OPT(NUM_SUB_Q, half, Half8);     \
      } else {                                  \
        RUN_PQ_OPT(NUM_SUB_Q, float, float4);   \
      }                                         \
    } while (0)
#else
#define RUN_PQ(NUM_SUB_Q)                       \
    do {                                        \
      RUN_PQ_OPT(NUM_SUB_Q, float, float4);     \
    } while (0)
#endif // FAISS_USE_FLOAT16
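
    // Each RUN_PQ expansion instantiates a separate kernel
    // specialization: NUM_SUB_Q is a template parameter, so
    // kNumCode32 and the per-byte lookup loop in the kernel are
    // resolved and unrolled at compile time.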
    // One instantiation per supported PQ code size (bytes per encoded
    // vector); unsupported sizes were rejected at index construction
    switch (bytesPerCode) {
      case 1:
        RUN_PQ(1);
        break;
      case 2:
        RUN_PQ(2);
        break;
      case 3:
        RUN_PQ(3);
        break;
      case 4:
        RUN_PQ(4);
        break;
      case 8:
        RUN_PQ(8);
        break;
      case 12:
        RUN_PQ(12);
        break;
      case 16:
        RUN_PQ(16);
        break;
      case 20:
        RUN_PQ(20);
        break;
      case 24:
        RUN_PQ(24);
        break;
      case 28:
        RUN_PQ(28);
        break;
      case 32:
        RUN_PQ(32);
        break;
      case 40:
        RUN_PQ(40);
        break;
      case 48:
        RUN_PQ(48);
        break;
      case 56:
        RUN_PQ(56);
        break;
      case 64:
        RUN_PQ(64);
        break;
      case 96:
        RUN_PQ(96);
        break;
      default:
        FAISS_ASSERT(false);
        break;
    }

#undef RUN_PQ
#undef RUN_PQ_OPT
  }

  CUDA_VERIFY(cudaGetLastError());
  // k-select the output in chunks, to increase parallelism
  runPass1SelectLists(prefixSumOffsets,
                      allDistances,
                      topQueryToCentroid.getSize(1),
                      k,
                      false, // L2 distance chooses the smallest
                      heapDistances,
                      heapIndices,
                      stream);

  // k-select final output
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      topQueryToCentroid,
                      k,
                      false, // L2 distance chooses the smallest
                      outDistances,
                      outIndices,
                      stream);

  CUDA_VERIFY(cudaGetLastError());
}
void runPQScanMultiPassPrecomputed(Tensor<float, 2, true>& queries,
                                   Tensor<float, 2, true>& precompTerm1,
                                   NoTypeTensor<3, true>& precompTerm2,
                                   NoTypeTensor<3, true>& precompTerm3,
                                   Tensor<int, 2, true>& topQueryToCentroid,
                                   bool useFloat16Lookup,
                                   int bytesPerCode,
                                   int numSubQuantizers,
                                   int numSubQuantizerCodes,
                                   thrust::device_vector<void*>& listCodes,
                                   thrust::device_vector<void*>& listIndices,
                                   IndicesOptions indicesOptions,
                                   thrust::device_vector<int>& listLengths,
                                   int maxListLength,
                                   int k,
                                   // output
                                   Tensor<float, 2, true>& outDistances,
                                   // output
                                   Tensor<long, 2, true>& outIndices,
                                   GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;

  int nprobe = topQueryToCentroid.getSize(1);
  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  // Make a reservation for Thrust to do its dirty work (global memory
  // cross-block reductions)
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};
  // How much temporary storage is available? If possible, we'd like
  // to fit within the space available
  size_t sizeAvailable = mem.getSizeAvailable();

  // We run two passes of heap selection; this is the size of the
  // first-level heap passes
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));
  // How much temporary storage we need per each query
  size_t sizePerQuery =
    2 * // double-buffered across two streams
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     sizeForFirstSelectPass); // heapDistances + heapIndices

  int queryTileSize = (int) (sizeAvailable / sizePerQuery);
  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // All intermediate indexing is done in int32; make sure a tile
  // cannot overflow it
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <=
               std::numeric_limits<int>::max());
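
  // Worked example, assuming nprobe = 32, maxListLength = 10000 and
  // k = 10: sizeForFirstSelectPass = 8 * 10 * 8 = 640 bytes, and
  // sizePerQuery = 2 * (132 + 32 * 10000 * 4 + 640) ~= 2.45 MiB, so
  // 512 MiB of scratch would allow ~209 queries per tile, clamped
  // here to kMaxQueryTileSize = 128.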
  // Temporary memory buffers; reserve one element before the start of
  // each prefix sum array so that reading offset -1 for the first
  // (query, probe) pair hits a zero sentinel rather than a branch
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};
  // Make sure the element before prefixSumOffsets is 0, since we
  // depend upon simple, boundary-less indexing to get proper results
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0,
                              sizeof(int),
                              stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0,
                              sizeof(int),
                              stream));
  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};
  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;
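
  // Every scratch buffer above exists twice so that tile t+1 can be
  // queued on the alternate stream while tile t is still executing;
  // curStream ping-pongs between the two sets in the loop below.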
  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto coarseIndicesView =
      topQueryToCentroid.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);
    auto term1View =
      precompTerm1.narrowOutermost(query, numQueriesInTile);
    auto term3View =
      precompTerm3.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runMultiPassTile(queryView,
                     term1View,
                     precompTerm2,
                     term3View,
                     coarseIndicesView,
                     useFloat16Lookup,
                     bytesPerCode,
                     numSubQuantizers,
                     numSubQuantizerCodes,
                     listCodes,
                     listIndices,
                     indicesOptions,
                     listLengths,
                     *thrustMem[curStream],
                     prefixSumOffsetsView,
                     *allDistances[curStream],
                     heapDistancesView,
                     heapIndicesView,
                     k,
                     outDistanceView,
                     outIndicesView,
                     streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  streamWait({stream}, streams);
}

} } // namespace