#include "PQScanMultiPassPrecomputed.cuh"
#include "../GpuResources.h"
#include "PQCodeLoad.cuh"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceTensor.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/StaticUtils.h"

#include <algorithm>
#include <limits>
#include <thrust/device_vector.h>

namespace faiss { namespace gpu {
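// Scans the encoded vectors of each probed IVF list, computing distances
// from precomputed tables. With residual encoding y = y_C + y_R, the
// distance ||x - y_C - y_R||^2 decomposes into:
//   term 1: ||x - y_C||^2             (per query x probe)
//   term 2: ||y_R||^2 + 2 (y_C | y_R) (per list x code; query-independent)
//   term 3: -2 (x | y_R)              (per query x code)
// Term 1 is a scalar per (query, probe); terms 2 and 3 are summed into a
// single shared-memory lookup table below.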
template <typename LookupT, typename LookupVecT>
inline __device__ void
loadPrecomputedTerm(LookupT* smem,
                    LookupT* term2Start,
                    LookupT* term3Start,
                    int numCodes) {
  constexpr int kWordSize = sizeof(LookupVecT) / sizeof(LookupT);
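  // Vectorized loads are only safe if the data is aligned; the codes are
  // innermost, so an evenly divisible count means every slice is aligned.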
  if (numCodes % kWordSize == 0) {
    constexpr int kUnroll = 2;

    int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x);
    limitVec *= kUnroll * blockDim.x;
    LookupVecT* smemV = (LookupVecT*) smem;
    LookupVecT* term2StartV = (LookupVecT*) term2Start;
    LookupVecT* term3StartV = (LookupVecT*) term3Start;
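    // Each thread cooperatively loads, sums, and stores kUnroll vector
    // words per iteration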
    for (int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) {
      LookupVecT vals[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] = LoadStore<LookupVecT>::load(&term2StartV[i + j * blockDim.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        LookupVecT q = LoadStore<LookupVecT>::load(&term3StartV[i + j * blockDim.x]);
        vals[j] = Math<LookupVecT>::add(vals[j], q);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        LoadStore<LookupVecT>::store(&smemV[i + j * blockDim.x], vals[j]);
      }
    }
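    // Scalar cleanup for the tail that did not fill a whole
    // kUnroll x blockDim.x stride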
    int remainder = limitVec * kWordSize;

    for (int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) {
      smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
    }
  } else {
    // Potentially unaligned data; fall back to scalar loads
    constexpr int kUnroll = 4;

    int limit = utils::roundDown(numCodes, kUnroll * blockDim.x);

    int i = threadIdx.x;
    for (; i < limit; i += kUnroll * blockDim.x) {
      LookupT vals[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] = term2Start[i + j * blockDim.x];
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] = Math<LookupT>::add(vals[j], term3Start[i + j * blockDim.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        smem[i + j * blockDim.x] = vals[j];
      }
    }

    for (; i < numCodes; i += blockDim.x) {
      smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
    }
  }
}
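// One thread block scans one (query, probe) pair: blockIdx.y selects the
// query and blockIdx.x selects the probe. Codes are read with a
// double-buffered, block-strided loop, and distances are written to the
// slice of the output given by the prefix-sum offsets.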
template <int NumSubQuantizers, typename LookupT, typename LookupVecT>
__global__ void
pqScanPrecomputedMultiPass(Tensor<float, 2, true> queries,
                           Tensor<float, 2, true> precompTerm1,
                           Tensor<LookupT, 3, true> precompTerm2,
                           Tensor<LookupT, 3, true> precompTerm3,
                           Tensor<int, 2, true> topQueryToCentroid,
                           void** listCodes,
                           int* listLengths,
                           Tensor<int, 2, true> prefixSumOffsets,
                           Tensor<float, 1, true> distance) {
  // Shared-memory storage for the summed (term 2 + term 3) table,
  // indexed by (sub-quantizer, code id)
  extern __shared__ char smemTerm23[];
  LookupT* term23 = (LookupT*) smemTerm23;
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;
  auto codesPerSubQuantizer = precompTerm2.getSize(2);
  auto precompTermSize = precompTerm2.getSize(1) * codesPerSubQuantizer;

  // Where we start writing out data; the prefix-sum array is guaranteed to
  // hold a 0 immediately before its start, so reading at offset -1 is safe
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);
  float* distanceOut = distance[outBase].data();
  auto listId = topQueryToCentroid[queryId][probeId];

  unsigned char* codeList = (unsigned char*) listCodes[listId];
  int limit = listLengths[listId];
  constexpr int kNumCode32 = NumSubQuantizers <= 4 ? 1 :
    (NumSubQuantizers / 4);
  unsigned int code32[kNumCode32];
  unsigned int nextCode32[kNumCode32];
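  // Double-buffer the code loads: prefetch the codes for the next
  // iteration while computing distances for the current one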
  if (threadIdx.x < limit) {
    LoadCode32<NumSubQuantizers>::load(code32, codeList, threadIdx.x);
  }
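  // Load term 1 (a per-(query, probe) scalar) and sum terms 2 and 3 into
  // the shared-memory lookup table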
  float term1 = precompTerm1[queryId][probeId];
  loadPrecomputedTerm<LookupT, LookupVecT>(term23,
                                           precompTerm2[listId].data(),
                                           precompTerm3[queryId].data(),
                                           precompTermSize);

  // Prevent WAR dependencies on the shared-memory table
  __syncthreads();

  // Each thread handles one code in the list, with a block-wide stride
  for (int codeIndex = threadIdx.x;
       codeIndex < limit;
       codeIndex += blockDim.x) {
    // Prefetch the next batch of codes
    if (codeIndex + blockDim.x < limit) {
      LoadCode32<NumSubQuantizers>::load(
        nextCode32, codeList, codeIndex + blockDim.x);
    }

    float dist = term1;

#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      constexpr int kBytesPerCode32 =
        NumSubQuantizers < 4 ? NumSubQuantizers : 4;
      if (kBytesPerCode32 == 1) {
        // Single sub-quantizer: one table lookup, added on top of term 1
        auto code = code32[0];
        dist += ConvertTo<float>::to(term23[code]);
      } else {
#pragma unroll
        for (int byte = 0; byte < kBytesPerCode32; ++byte) {
          auto code = getByte(code32[word], byte * 8, 8);

          auto offset =
            codesPerSubQuantizer * (word * kBytesPerCode32 + byte);

          dist += ConvertTo<float>::to(term23[offset + code]);
        }
      }
    }
    // Write the intermediate distance; indices are implicit in position and
    // are recovered during the final selection pass
    distanceOut[codeIndex] = dist;
    // Rotate the double buffer
#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      code32[word] = nextCode32[word];
    }
  }
}
void
runMultiPassTile(Tensor<float, 2, true>& queries,
                 Tensor<float, 2, true>& precompTerm1,
                 NoTypeTensor<3, true>& precompTerm2,
                 NoTypeTensor<3, true>& precompTerm3,
                 Tensor<int, 2, true>& topQueryToCentroid,
                 bool useFloat16Lookup,
                 int bytesPerCode,
                 int numSubQuantizers,
                 int numSubQuantizerCodes,
                 thrust::device_vector<void*>& listCodes,
                 thrust::device_vector<void*>& listIndices,
                 IndicesOptions indicesOptions,
                 thrust::device_vector<int>& listLengths,
                 Tensor<char, 1, true>& thrustMem,
                 Tensor<int, 2, true>& prefixSumOffsets,
                 Tensor<float, 1, true>& allDistances,
                 Tensor<float, 3, true>& heapDistances,
                 Tensor<int, 3, true>& heapIndices,
                 int k,
                 Tensor<float, 2, true>& outDistances,
                 Tensor<long, 2, true>& outIndices,
                 cudaStream_t stream) {
  // Compute prefix-sum offsets so each (query, probe) pair knows where to
  // write its intermediate distances
  runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets,
                     thrustMem, stream);
  auto kThreadsPerBlock = 256;

  auto grid = dim3(topQueryToCentroid.getSize(1),
                   topQueryToCentroid.getSize(0));
  auto block = dim3(kThreadsPerBlock);

  // The entire (term 2 + term 3) lookup table must fit in shared memory
  auto smem = sizeof(float);
#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    smem = sizeof(half);
  }
#endif
  smem *= numSubQuantizers * numSubQuantizerCodes;
  FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice());
#define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T)               \
  do {                                                              \
    auto precompTerm2T = precompTerm2.toTensor<LOOKUP_T>();         \
    auto precompTerm3T = precompTerm3.toTensor<LOOKUP_T>();         \
                                                                    \
    pqScanPrecomputedMultiPass<NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T>   \
      <<<grid, block, smem, stream>>>(                              \
        queries,                                                    \
        precompTerm1,                                               \
        precompTerm2T,                                              \
        precompTerm3T,                                              \
        topQueryToCentroid,                                         \
        listCodes.data().get(),                                     \
        listLengths.data().get(),                                   \
        prefixSumOffsets,                                           \
        allDistances);                                              \
  } while (0)
#ifdef FAISS_USE_FLOAT16
#define RUN_PQ(NUM_SUB_Q)                     \
  do {                                        \
    if (useFloat16Lookup) {                   \
      RUN_PQ_OPT(NUM_SUB_Q, half, Half8);     \
    } else {                                  \
      RUN_PQ_OPT(NUM_SUB_Q, float, float4);   \
    }                                         \
  } while (0)
#else
#define RUN_PQ(NUM_SUB_Q)                     \
  do {                                        \
    RUN_PQ_OPT(NUM_SUB_Q, float, float4);     \
  } while (0)
#endif // FAISS_USE_FLOAT16
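  // Dispatch on the number of bytes per encoded vector; each supported
  // size is a separate template instantiation of the kernel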
  switch (bytesPerCode) {
    case 1:
      RUN_PQ(1);
      break;
    case 2:
      RUN_PQ(2);
      break;
    case 4:
      RUN_PQ(4);
      break;
    // ... one case per additional supported code size, elided in this
    // excerpt ...
    default:
      FAISS_ASSERT(false);
      break;
  }

#undef RUN_PQ
#undef RUN_PQ_OPT
  // First pass of k-selection: reduce each query's distances in chunks to
  // increase parallelism
  runPass1SelectLists(prefixSumOffsets,
                      allDistances,
                      topQueryToCentroid.getSize(1),
                      k,
                      false, // L2 distance chooses the smallest
                      heapDistances,
                      heapIndices,
                      stream);

  // Second pass of k-selection: reduce the per-chunk heaps to the final
  // k results per query
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      topQueryToCentroid,
                      k,
                      false, // L2 distance chooses the smallest
                      outDistances,
                      outIndices,
                      stream);
}
void runPQScanMultiPassPrecomputed(Tensor<float, 2, true>& queries,
                                   Tensor<float, 2, true>& precompTerm1,
                                   NoTypeTensor<3, true>& precompTerm2,
                                   NoTypeTensor<3, true>& precompTerm3,
                                   Tensor<int, 2, true>& topQueryToCentroid,
                                   bool useFloat16Lookup,
                                   int bytesPerCode,
                                   int numSubQuantizers,
                                   int numSubQuantizerCodes,
                                   thrust::device_vector<void*>& listCodes,
                                   thrust::device_vector<void*>& listIndices,
                                   IndicesOptions indicesOptions,
                                   thrust::device_vector<int>& listLengths,
                                   int maxListLength,
                                   int k,
                                   Tensor<float, 2, true>& outDistances,
                                   Tensor<long, 2, true>& outIndices,
                                   GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;
  int nprobe = topQueryToCentroid.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();
  // Reserve scratch space for Thrust's cross-block reductions
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};
  size_t sizeAvailable = mem.getSizeAvailable();
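  // k-selection runs in two passes; the first pass reduces at most
  // kNProbeSplit chunks of probes per query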
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));
  // Temporary storage needed per query, counting both pipelined buffers
  size_t sizePerQuery =
    2 * // # of streams
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     sizeForFirstSelectPass);
  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }
  // FIXME: we should adjust queryTileSize to guarantee this, since all
  // intermediate indexing is done in int32
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <=
               std::numeric_limits<int>::max());
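  // Temporary memory buffers, allocated in pairs so two query tiles can be
  // in flight on alternate streams. The prefix-sum space has one extra
  // leading element that is kept at 0, so the kernel's read at offset -1
  // needs no boundary check.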
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};
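  // Zero the sentinel element before each prefix-sum view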
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0,
                              sizeof(int),
                              stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0,
                              sizeof(int),
                              stream));
  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};
  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;
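  // Process query tiles, alternating between the two buffer sets and
  // streams so work for adjacent tiles can overlap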
  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto coarseIndicesView =
      topQueryToCentroid.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);
    auto term1View =
      precompTerm1.narrowOutermost(query, numQueriesInTile);
    auto term3View =
      precompTerm3.narrowOutermost(query, numQueriesInTile);
    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);
    runMultiPassTile(queryView,
                     term1View,
                     precompTerm2,
                     term3View,
                     coarseIndicesView,
                     useFloat16Lookup,
                     bytesPerCode,
                     numSubQuantizers,
                     numSubQuantizerCodes,
                     listCodes,
                     listIndices,
                     indicesOptions,
                     listLengths,
                     *thrustMem[curStream],
                     prefixSumOffsetsView,
                     *allDistances[curStream],
                     heapDistancesView,
                     heapIndicesView,
                     k,
                     outDistanceView,
                     outIndicesView,
                     streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  streamWait({stream}, streams);
}

} } // namespace faiss::gpu