#include "PQScanMultiPassNoPrecomputed.cuh"
#include "../GpuResources.h"
#include "PQCodeDistances.cuh"
#include "PQCodeLoad.cuh"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceTensor.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/NoTypeTensor.cuh"
#include "../utils/StaticUtils.h"

#include "../utils/HostTensor.cuh"
namespace faiss { namespace gpu {
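// Whether the no-precomputed-codes scan path supports this many dimensions
// per PQ sub-quantizer; the residual distance kernels are only specialized
// for a fixed set of sub-dimension sizes.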
bool isSupportedNoPrecomputedSubDimSize(int dims) {
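// Cooperatively loads the (sub-quantizer x code) -> distance lookup table
// from global memory into shared memory across the whole thread block,
// using vectorized loads when the data is suitably aligned.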
template <typename LookupT, typename LookupVecT>
struct LoadCodeDistances {
  static inline __device__ void load(LookupT* smem,
                                     LookupT* codes,
                                     int numCodes) {
    constexpr int kWordSize = sizeof(LookupVecT) / sizeof(LookupT);
    // We can only use the vectorized path if the data is guaranteed to be
    // aligned: the codes are innermost, so if the count is evenly divisible
    // by the vector width, any slice will be aligned as well.
    if (numCodes % kWordSize == 0) {
      constexpr int kUnroll = 2;

      // Number of whole vector words we can load in full
      // kUnroll x blockDim.x passes
      int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x);
      limitVec *= kUnroll * blockDim.x;

      LookupVecT* smemV = (LookupVecT*) smem;
      LookupVecT* codesV = (LookupVecT*) codes;

      for (int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) {
        LookupVecT vals[kUnroll];

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          vals[j] = LoadStore<LookupVecT>::load(&codesV[i + j * blockDim.x]);
        }

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          LoadStore<LookupVecT>::store(&smemV[i + j * blockDim.x], vals[j]);
        }
      }

      // Copy the tail that does not fill a whole kUnroll x blockDim.x
      // vector chunk
      int remainder = limitVec * kWordSize;

      for (int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) {
        smem[i] = codes[i];
      }
    } else {
      // Unaligned fallback: copy element by element
      constexpr int kUnroll = 4;

      int limit = utils::roundDown(numCodes, kUnroll * blockDim.x);

      int i = threadIdx.x;
      for (; i < limit; i += kUnroll * blockDim.x) {
        LookupT vals[kUnroll];

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          vals[j] = codes[i + j * blockDim.x];
        }

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          smem[i + j * blockDim.x] = vals[j];
        }
      }

      for (; i < numCodes; i += blockDim.x) {
        smem[i] = codes[i];
      }
    }
  }
};
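// Scan kernel: each block handles one (query, probe) pair. The per-code
// lookup table of residual distances is first staged in shared memory;
// each thread then walks the inverted list with a block-wide stride,
// decoding the PQ codes and summing the corresponding table entries into
// a single distance per encoded vector.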
template <int NumSubQuantizers, typename LookupT, typename LookupVecT>
__global__ void
pqScanNoPrecomputedMultiPass(Tensor<float, 2, true> queries,
                             Tensor<float, 3, true> pqCentroids,
                             Tensor<int, 2, true> topQueryToCentroid,
                             Tensor<LookupT, 4, true> codeDistances,
                             void** listCodes,
                             int* listLengths,
                             Tensor<int, 2, true> prefixSumOffsets,
                             Tensor<float, 1, true> distance) {
  const auto codesPerSubQuantizer = pqCentroids.getSize(2);
  // Where the pq code -> residual distance lookup table is staged
  extern __shared__ char smemCodeDistances[];
  LookupT* codeDist = (LookupT*) smemCodeDistances;

  // Each block handles a single (query, probe) pair
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;

  // This is where we start writing out data for this (query, probe);
  // the element at offset -1 is guaranteed to be 0
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);
  float* distanceOut = distance[outBase].data();

  auto listId = topQueryToCentroid[queryId][probeId];

  unsigned char* codeList = (unsigned char*) listCodes[listId];
  int limit = listLengths[listId];

  constexpr int kNumCode32 = NumSubQuantizers <= 4 ? 1 :
    (NumSubQuantizers / 4);
  unsigned int code32[kNumCode32];
  unsigned int nextCode32[kNumCode32];
  // Double-buffer the code loads: prime the first batch of codes while the
  // lookup table is being staged in shared memory
  if (threadIdx.x < limit) {
    LoadCode32<NumSubQuantizers>::load(code32, codeList, threadIdx.x);
  }

  LoadCodeDistances<LookupT, LookupVecT>::load(
    codeDist,
    codeDistances[queryId][probeId].data(),
    codeDistances.getSize(3) * codeDistances.getSize(2));

  // Make sure the shared-memory table is fully populated before use
  __syncthreads();
  // Each thread handles one code in the list at a time, with a block-wide
  // stride
  for (int codeIndex = threadIdx.x;
       codeIndex < limit;
       codeIndex += blockDim.x) {
    // Prefetch the next batch of codes while we work on the current one
    if (codeIndex + blockDim.x < limit) {
      LoadCode32<NumSubQuantizers>::load(
        nextCode32, codeList, codeIndex + blockDim.x);
    }

    float dist = 0.0f;
#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      constexpr int kBytesPerCode32 =
        NumSubQuantizers < 4 ? NumSubQuantizers : 4;

      if (kBytesPerCode32 == 1) {
        // Only one sub-quantizer: the single code byte indexes the table
        // directly
        auto code = code32[0];
        dist = ConvertTo<float>::to(codeDist[code]);
      } else {
#pragma unroll
        for (int byte = 0; byte < kBytesPerCode32; ++byte) {
          auto code = getByte(code32[word], byte * 8, 8);

          auto offset =
            codesPerSubQuantizer * (word * kBytesPerCode32 + byte);

          dist += ConvertTo<float>::to(codeDist[offset + code]);
        }
      }
    }

    // Write out the intermediate distance for this encoded vector; indices
    // are recovered later from the prefix-sum offsets during k-selection
    distanceOut[codeIndex] = dist;
    // Rotate buffers: the prefetched codes become the current ones
#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      code32[word] = nextCode32[word];
    }
  }
}
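// Runs one tile of queries end to end on the given stream:
// 1. compute per-(query, probe) write offsets via a prefix sum over the
//    probed inverted list lengths,
// 2. compute the residual code -> distance lookup tables,
// 3. scan the inverted lists with the kernel above, producing a flat
//    distance array,
// 4. k-select in two passes (per chunk of lists, then across chunks) to
//    obtain the final (distance, index) results.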
void
runMultiPassTile(Tensor<float, 2, true>& queries,
                 Tensor<float, 2, true>& centroids,
                 Tensor<float, 3, true>& pqCentroidsInnermostCode,
                 NoTypeTensor<4, true>& codeDistances,
                 Tensor<int, 2, true>& topQueryToCentroid,
                 bool useFloat16Lookup,
                 int bytesPerCode,
                 int numSubQuantizers,
                 int numSubQuantizerCodes,
                 thrust::device_vector<void*>& listCodes,
                 thrust::device_vector<void*>& listIndices,
                 IndicesOptions indicesOptions,
                 thrust::device_vector<int>& listLengths,
                 Tensor<char, 1, true>& thrustMem,
                 Tensor<int, 2, true>& prefixSumOffsets,
                 Tensor<float, 1, true>& allDistances,
                 Tensor<float, 3, true>& heapDistances,
                 Tensor<int, 3, true>& heapIndices,
                 int k,
                 Tensor<float, 2, true>& outDistances,
                 Tensor<long, 2, true>& outIndices,
                 cudaStream_t stream) {
#ifndef FAISS_USE_FLOAT16
  FAISS_ASSERT(!useFloat16Lookup);
#endif
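  // Compute the prefix-sum offsets that tell each (query, probe) pair where
  // to write its intermediate distances in the flat output array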
  runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets,
                     thrustMem, stream);
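  // Compute the residual code distance lookup tables, since this path does
  // not use precomputed codes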
  runPQCodeDistances(pqCentroidsInnermostCode,
                     queries,
                     centroids,
                     topQueryToCentroid,
                     codeDistances,
                     useFloat16Lookup,
                     stream);
  // Convert all codes to distances, writing out a (distance, implicit index)
  // value for every scanned code
  auto kThreadsPerBlock = 256;

  auto grid = dim3(topQueryToCentroid.getSize(1),
                   topQueryToCentroid.getSize(0));
  auto block = dim3(kThreadsPerBlock);

  // Size of the per-block shared-memory lookup table
  auto smem = sizeof(float);
#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    smem = sizeof(half);
  }
#endif
  smem *= numSubQuantizers * numSubQuantizerCodes;
  FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice());
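// The macros below instantiate and launch the scan kernel for a given
// number of sub-quantizers and lookup-table element type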
#define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T)                   \
  do {                                                                  \
    auto codeDistancesT = codeDistances.toTensor<LOOKUP_T>();           \
                                                                        \
    pqScanNoPrecomputedMultiPass<NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T>     \
      <<<grid, block, smem, stream>>>(                                  \
        queries,                                                        \
        pqCentroidsInnermostCode,                                       \
        topQueryToCentroid,                                             \
        codeDistancesT,                                                 \
        listCodes.data().get(),                                         \
        listLengths.data().get(),                                       \
        prefixSumOffsets,                                               \
        allDistances);                                                  \
  } while (0)
#ifdef FAISS_USE_FLOAT16
#define RUN_PQ(NUM_SUB_Q)                       \
  do {                                          \
    if (useFloat16Lookup) {                     \
      RUN_PQ_OPT(NUM_SUB_Q, half, Half8);       \
    } else {                                    \
      RUN_PQ_OPT(NUM_SUB_Q, float, float4);     \
    }                                           \
  } while (0)
#else
#define RUN_PQ(NUM_SUB_Q)                       \
  do {                                          \
    RUN_PQ_OPT(NUM_SUB_Q, float, float4);       \
  } while (0)
#endif // FAISS_USE_FLOAT16
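  // Dispatch on the number of bytes per encoded vector; each supported size
  // maps to RUN_PQ(n) for the matching kernel instantiation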
  switch (bytesPerCode) {
  // First k-selection pass: select within each chunk of probed lists, to
  // increase parallelism
  runPass1SelectLists(prefixSumOffsets, allDistances,
                      topQueryToCentroid.getSize(1), k,
                      false, // L2 distance chooses smallest
                      heapDistances, heapIndices, stream);

  // Second k-selection pass: reduce the per-chunk heaps to the final output
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances, flatHeapIndices,
                      listIndices, indicesOptions,
                      prefixSumOffsets, topQueryToCentroid, k,
                      false, // L2 distance chooses smallest
                      outDistances, outIndices, stream);

  CUDA_VERIFY(cudaGetLastError());
}
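// Host-side driver: splits the queries into tiles sized to fit the available
// temporary GPU memory and processes the tiles on two alternate streams,
// double-buffering all temporary allocations.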
void runPQScanMultiPassNoPrecomputed(Tensor<float, 2, true>& queries,
                                     Tensor<float, 2, true>& centroids,
                                     Tensor<float, 3, true>& pqCentroidsInnermostCode,
                                     Tensor<int, 2, true>& topQueryToCentroid,
                                     bool useFloat16Lookup,
                                     int bytesPerCode,
                                     int numSubQuantizers,
                                     int numSubQuantizerCodes,
                                     thrust::device_vector<void*>& listCodes,
                                     thrust::device_vector<void*>& listIndices,
                                     IndicesOptions indicesOptions,
                                     thrust::device_vector<int>& listLengths,
                                     int maxListLength,
                                     int k,
                                     Tensor<float, 2, true>& outDistances,
                                     Tensor<long, 2, true>& outIndices,
                                     GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;
  int nprobe = topQueryToCentroid.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  // Scratch space handed to the Thrust-based offset calculation
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};
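  // Size the query tile so that all per-tile temporaries (double-buffered
  // across the two streams) fit in the temporary memory currently available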
  size_t sizeAvailable = mem.getSizeAvailable();

  // We run two passes of heap selection; this is the number of chunks the
  // first pass splits the probed lists into
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));
  // Temporary storage needed per query in the tile
  size_t sizePerQuery =
    2 * // double-buffered across the two streams
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     // residual code distance tables
     nprobe * numSubQuantizers * numSubQuantizerCodes * sizeof(float) +
     sizeForFirstSelectPass);
  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // Intermediate results are indexed in int32
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <
               std::numeric_limits<int>::max());
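  // All temporary buffers below are allocated twice, so that one query tile
  // can be in flight on each of the two alternate streams at the same time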
  // Allocate one extra element in front of each offsets buffer; it is zeroed
  // below, so that the kernel can read at offset -1 without a branch
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};
  // Zero the sentinel element that precedes each offsets tensor
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0, sizeof(int), stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0, sizeof(int), stream));
  int codeDistanceTypeSize = sizeof(float);
#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    codeDistanceTypeSize = sizeof(half);
  }
#else
  FAISS_ASSERT(!useFloat16Lookup);
  int codeSize = sizeof(float);
#endif
  int totalCodeDistancesSize =
    queryTileSize * nprobe * numSubQuantizers * numSubQuantizerCodes *
    codeDistanceTypeSize;

  DeviceTensor<char, 1, true> codeDistances1Mem(
    mem, {totalCodeDistancesSize}, stream);
  NoTypeTensor<4, true> codeDistances1(
    codeDistances1Mem.data(),
    codeDistanceTypeSize,
    {queryTileSize, nprobe, numSubQuantizers, numSubQuantizerCodes});

  DeviceTensor<char, 1, true> codeDistances2Mem(
    mem, {totalCodeDistancesSize}, stream);
  NoTypeTensor<4, true> codeDistances2(
    codeDistances2Mem.data(),
    codeDistanceTypeSize,
    {queryTileSize, nprobe, numSubQuantizers, numSubQuantizerCodes});

  NoTypeTensor<4, true>* codeDistances[2] =
    {&codeDistances1, &codeDistances2};
  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};
  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;
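  // Process the queries tile by tile, alternating between the two buffer
  // sets and their associated streams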
  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto codeDistancesView =
      codeDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto coarseIndicesView =
      topQueryToCentroid.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);
    runMultiPassTile(queryView, centroids, pqCentroidsInnermostCode,
                     codeDistancesView, coarseIndicesView,
                     useFloat16Lookup, bytesPerCode,
                     numSubQuantizers, numSubQuantizerCodes,
                     listCodes, listIndices, indicesOptions, listLengths,
                     *thrustMem[curStream],
                     prefixSumOffsetsView,
                     *allDistances[curStream],
                     heapDistancesView, heapIndicesView, k,
                     outDistanceView, outIndicesView,
                     streams[curStream]);

    curStream = (curStream + 1) % 2;
  }
  streamWait({stream}, streams);
}

} } // namespace