#include "PQScanMultiPassNoPrecomputed.cuh"
#include "../GpuResources.h"
#include "PQCodeDistances.cuh"
#include "PQCodeLoad.cuh"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceTensor.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/NoTypeTensor.cuh"
#include "../utils/StaticUtils.h"

#include "../utils/HostTensor.cuh"
namespace faiss { namespace gpu {
// This must be kept in sync with the per-dimension specializations in
// PQCodeDistances.cu
bool isSupportedNoPrecomputedSubDimSize(int dims) {
  switch (dims) {
    case 1:
    case 2:
    case 3:
    case 4:
    case 6:
    case 8:
    case 10:
    case 12:
    case 16:
    case 32:
      return true;
    default:
      // Larger sub-quantizer dimensions require too many registers per thread
      return false;
  }
}
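
// Loads the per-(query, probe) code -> distance lookup table into shared
// memory. Vectorized loads (LookupVecT) are used when the table size is a
// multiple of the vector width; otherwise it falls back to scalar loads.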
template <typename LookupT, typename LookupVecT>
struct LoadCodeDistances {
  static inline __device__ void load(LookupT* smem,
                                     LookupT* codes,
                                     int numCodes) {
    constexpr int kWordSize = sizeof(LookupVecT) / sizeof(LookupT);

    // We can only use the vector type if the data is guaranteed to be
    // aligned. The codes are innermost, so if the count is evenly divisible,
    // then any slice will be aligned.
    if (numCodes % kWordSize == 0) {
      constexpr int kUnroll = 2;

      // Load the data by vector words for efficiency, then handle any
      // remainder. limitVec is the number of whole vector words we can load,
      // in terms of whole blocks performing the load.
      int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x);
      limitVec *= kUnroll * blockDim.x;

      LookupVecT* smemV = (LookupVecT*) smem;
      LookupVecT* codesV = (LookupVecT*) codes;

      for (int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) {
        LookupVecT vals[kUnroll];

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          vals[j] = LoadStore<LookupVecT>::load(&codesV[i + j * blockDim.x]);
        }

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          LoadStore<LookupVecT>::store(&smemV[i + j * blockDim.x], vals[j]);
        }
      }

      // Handle the remainder that does not evenly fit into
      // kUnroll * kWordSize * blockDim.x
      int remainder = limitVec * kWordSize;

      for (int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) {
        smem[i] = codes[i];
      }
    } else {
      // Potentially unaligned; load one element per thread at a time
      constexpr int kUnroll = 4;

      int limit = utils::roundDown(numCodes, kUnroll * blockDim.x);

      int i = threadIdx.x;
      for (; i < limit; i += kUnroll * blockDim.x) {
        LookupT vals[kUnroll];

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          vals[j] = codes[i + j * blockDim.x];
        }

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          smem[i + j * blockDim.x] = vals[j];
        }
      }

      for (; i < numCodes; i += blockDim.x) {
        smem[i] = codes[i];
      }
    }
  }
};
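
// Scan kernel: each thread block handles one (query, probe) pair. The encoded
// vectors of the probed inverted list are read with double-buffered 32-bit
// code loads; each code byte indexes the shared-memory lookup table to
// accumulate a distance, which is written to the intermediate distance buffer.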
template <int NumSubQuantizers, typename LookupT, typename LookupVecT>
__global__ void
pqScanNoPrecomputedMultiPass(Tensor<float, 2, true> queries,
                             Tensor<float, 3, true> pqCentroids,
                             Tensor<int, 2, true> topQueryToCentroid,
                             Tensor<LookupT, 4, true> codeDistances,
                             void** listCodes,
                             int* listLengths,
                             Tensor<int, 2, true> prefixSumOffsets,
                             Tensor<float, 1, true> distance) {
  const auto codesPerSubQuantizer = pqCentroids.getSize(2);

  // Where the pq code -> residual distance lookup table is stored
  extern __shared__ char smemCodeDistances[];
  LookupT* codeDist = (LookupT*) smemCodeDistances;

  // Each block handles a single (query, probe) pair
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;

  // This is where we start writing out data.
  // We ensure that before the array (at offset -1), there is a 0 value.
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);
  float* distanceOut = distance[outBase].data();

  auto listId = topQueryToCentroid[queryId][probeId];

  unsigned char* codeList = (unsigned char*) listCodes[listId];
  int limit = listLengths[listId];

  constexpr int kNumCode32 = NumSubQuantizers <= 4 ? 1 :
    (NumSubQuantizers / 4);
  unsigned int code32[kNumCode32];
  unsigned int nextCode32[kNumCode32];

  // We double-buffer the code loading, which improves memory utilization
  if (threadIdx.x < limit) {
    LoadCode32<NumSubQuantizers>::load(code32, codeList, threadIdx.x);
  }

  LoadCodeDistances<LookupT, LookupVecT>::load(
    codeDist,
    codeDistances[queryId][probeId].data(),
    codeDistances.getSize(2) * codeDistances.getSize(3));

  // Prevent WAR dependencies on the shared-memory table
  __syncthreads();

  // Each thread handles one code element in the list, with a
  // block-wide stride
  for (int codeIndex = threadIdx.x;
       codeIndex < limit;
       codeIndex += blockDim.x) {
    // Prefetch the next codes while we compute on the current ones
    if (codeIndex + blockDim.x < limit) {
      LoadCode32<NumSubQuantizers>::load(
        nextCode32, codeList, codeIndex + blockDim.x);
    }

    float dist = 0.0f;

#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      constexpr int kBytesPerCode32 =
        NumSubQuantizers < 4 ? NumSubQuantizers : 4;

      if (kBytesPerCode32 == 1) {
        auto code = code32[0];
        dist = ConvertTo<float>::to(codeDist[code]);
      } else {
#pragma unroll
        for (int byte = 0; byte < kBytesPerCode32; ++byte) {
          auto code = getByte(code32[word], byte * 8, 8);

          auto offset =
            codesPerSubQuantizer * (word * kBytesPerCode32 + byte);

          dist += ConvertTo<float>::to(codeDist[offset + code]);
        }
      }
    }

    // Write out the intermediate distance result; indices are not written
    // here to reduce global memory traffic, and are recovered during selection
    distanceOut[codeIndex] = dist;

    // Rotate the double buffer
#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      code32[word] = nextCode32[word];
    }
  }
}
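
// Runs the scan for a single tile of queries: computes the per-list write
// offsets, builds the query/sub-quantizer code distance tables, launches the
// scan kernel above, then k-selects the results in two passes.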
void
runMultiPassTile(Tensor<float, 2, true>& queries,
                 Tensor<float, 2, true>& centroids,
                 Tensor<float, 3, true>& pqCentroidsInnermostCode,
                 NoTypeTensor<4, true>& codeDistances,
                 Tensor<int, 2, true>& topQueryToCentroid,
                 bool useFloat16Lookup,
                 int bytesPerCode,
                 int numSubQuantizers,
                 int numSubQuantizerCodes,
                 thrust::device_vector<void*>& listCodes,
                 thrust::device_vector<void*>& listIndices,
                 IndicesOptions indicesOptions,
                 thrust::device_vector<int>& listLengths,
                 Tensor<char, 1, true>& thrustMem,
                 Tensor<int, 2, true>& prefixSumOffsets,
                 Tensor<float, 1, true>& allDistances,
                 Tensor<float, 3, true>& heapDistances,
                 Tensor<int, 3, true>& heapIndices,
                 int k,
                 Tensor<float, 2, true>& outDistances,
                 Tensor<long, 2, true>& outIndices,
                 cudaStream_t stream) {
#ifndef FAISS_USE_FLOAT16
  FAISS_ASSERT(!useFloat16Lookup);
#endif

  // Calculate offset lengths, so we know where to write out
  // intermediate results
  runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets,
                     thrustMem, stream);

  // Calculate residual code distances, since this is without
  // precomputed codes
  runPQCodeDistances(pqCentroidsInnermostCode,
                     queries,
                     centroids,
                     topQueryToCentroid,
                     codeDistances,
                     useFloat16Lookup,
                     stream);

  // Convert all codes to a distance, and write out (distance, index)
  // values for all intermediate results
  auto kThreadsPerBlock = 256;

  auto grid = dim3(topQueryToCentroid.getSize(1),
                   topQueryToCentroid.getSize(0));
  auto block = dim3(kThreadsPerBlock);

  // Shared memory holds the pq centroid distance lookup table
  auto smem = sizeof(float);
#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    smem = sizeof(half);
  }
#endif
  smem *= numSubQuantizers * numSubQuantizerCodes;
  FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice());

#define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T)                  \
  do {                                                                 \
    auto codeDistancesT = codeDistances.toTensor<LOOKUP_T>();          \
    pqScanNoPrecomputedMultiPass<NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T>    \
      <<<grid, block, smem, stream>>>(                                 \
        queries,                                                       \
        pqCentroidsInnermostCode,                                      \
        topQueryToCentroid,                                            \
        codeDistancesT,                                                \
        listCodes.data().get(),                                        \
        listLengths.data().get(),                                      \
        prefixSumOffsets,                                              \
        allDistances);                                                 \
  } while (0)

#ifdef FAISS_USE_FLOAT16
#define RUN_PQ(NUM_SUB_Q)                       \
  do {                                          \
    if (useFloat16Lookup) {                     \
      RUN_PQ_OPT(NUM_SUB_Q, half, Half8);       \
    } else {                                    \
      RUN_PQ_OPT(NUM_SUB_Q, float, float4);     \
    }                                           \
  } while (0)
#else
#define RUN_PQ(NUM_SUB_Q)                       \
  do {                                          \
    RUN_PQ_OPT(NUM_SUB_Q, float, float4);       \
  } while (0)
#endif // FAISS_USE_FLOAT16

  // Dispatch on the number of bytes per encoded vector; each supported size
  // instantiates the scan kernel through RUN_PQ
  switch (bytesPerCode) {
    case 1:  RUN_PQ(1);  break;
    case 2:  RUN_PQ(2);  break;
    case 3:  RUN_PQ(3);  break;
    case 4:  RUN_PQ(4);  break;
    case 8:  RUN_PQ(8);  break;
    case 12: RUN_PQ(12); break;
    case 16: RUN_PQ(16); break;
    case 20: RUN_PQ(20); break;
    case 24: RUN_PQ(24); break;
    case 28: RUN_PQ(28); break;
    case 32: RUN_PQ(32); break;
    case 40: RUN_PQ(40); break;
    case 48: RUN_PQ(48); break;
    case 56: RUN_PQ(56); break;
    case 64: RUN_PQ(64); break;
    case 96: RUN_PQ(96); break;
    default:
      FAISS_ASSERT(false);
      break;
  }

#undef RUN_PQ
#undef RUN_PQ_OPT

  // k-select the output in chunks, to increase parallelism
  runPass1SelectLists(prefixSumOffsets,
                      allDistances,
                      topQueryToCentroid.getSize(1),
                      k,
                      false, // L2 distance chooses the smallest
                      heapDistances,
                      heapIndices,
                      stream);

  // k-select the final output
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      topQueryToCentroid,
                      k,
                      false, // L2 distance chooses the smallest
                      outDistances,
                      outIndices,
                      stream);
}
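
// Top-level entry point: sizes a query tile so that the temporary buffers fit
// in the available scratch memory, then processes the queries tile by tile,
// double-buffering all intermediate state across two streams so that
// consecutive tiles can overlap.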
void runPQScanMultiPassNoPrecomputed(Tensor<float, 2, true>& queries,
                                     Tensor<float, 2, true>& centroids,
                                     Tensor<float, 3, true>& pqCentroidsInnermostCode,
                                     Tensor<int, 2, true>& topQueryToCentroid,
                                     bool useFloat16Lookup,
                                     int bytesPerCode,
                                     int numSubQuantizers,
                                     int numSubQuantizerCodes,
                                     thrust::device_vector<void*>& listCodes,
                                     thrust::device_vector<void*>& listIndices,
                                     IndicesOptions indicesOptions,
                                     thrust::device_vector<int>& listLengths,
                                     int maxListLength,
                                     int k,
                                     // output
                                     Tensor<float, 2, true>& outDistances,
                                     // output
                                     Tensor<long, 2, true>& outIndices,
                                     GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;

  int nprobe = topQueryToCentroid.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  // Make a reservation for Thrust to do its dirty work (global memory
  // cross-block reductions)
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};

  // How much temporary memory is available? If possible, we'd like to fit
  // within the space available.
  size_t sizeAvailable = mem.getSizeAvailable();

  // We run two passes of heap selection; this is the chunking of the
  // first-level pass
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  // How much temporary storage we need per each query
  size_t sizePerQuery =
    2 * // streams
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     // residual code distances
     nprobe * numSubQuantizers * numSubQuantizerCodes * sizeof(float) +
     sizeForFirstSelectPass);
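
  // As a rough, hypothetical example: with nprobe = 32, maxListLength = 10000,
  // numSubQuantizers = 8, numSubQuantizerCodes = 256, k = 10 and
  // pass2Chunks = 8, the allDistances term (32 * 10000 * 4 bytes) dominates
  // and sizePerQuery comes to roughly 3 MB, so the tile size below is
  // effectively bounded by the allDistances scratch space.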

  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // The intermediate distance buffer is indexed with int; make sure a single
  // tile cannot overflow it
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <
               std::numeric_limits<int>::max());

  // Temporary memory buffers
  // Make sure there is space for the prefix sum base value (a 0 written
  // just before the start of each per-tile offset array)
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};

  // Make sure the element before prefixSumOffsets is 0, since we depend upon
  // simple, boundary-less indexing to get proper results
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0,
                              sizeof(int),
                              stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0,
                              sizeof(int),
                              stream));

  int codeDistanceTypeSize = sizeof(float);
#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    codeDistanceTypeSize = sizeof(half);
  }
#else
  FAISS_ASSERT(!useFloat16Lookup);
#endif

  int totalCodeDistancesSize =
    queryTileSize * nprobe * numSubQuantizers * numSubQuantizerCodes *
    codeDistanceTypeSize;

  DeviceTensor<char, 1, true> codeDistances1Mem(
    mem, {totalCodeDistancesSize}, stream);
  NoTypeTensor<4, true> codeDistances1(
    codeDistances1Mem.data(),
    codeDistanceTypeSize,
    {queryTileSize, nprobe, numSubQuantizers, numSubQuantizerCodes});

  DeviceTensor<char, 1, true> codeDistances2Mem(
    mem, {totalCodeDistancesSize}, stream);
  NoTypeTensor<4, true> codeDistances2(
    codeDistances2Mem.data(),
    codeDistanceTypeSize,
    {queryTileSize, nprobe, numSubQuantizers, numSubQuantizerCodes});

  NoTypeTensor<4, true>* codeDistances[2] =
    {&codeDistances1, &codeDistances2};

  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};

  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;
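
  // All per-tile temporaries above exist in two copies, indexed by curStream;
  // alternating between them lets one tile's scan proceed on one alternate
  // stream while the next tile's work is queued on the other.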

  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto codeDistancesView =
      codeDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto coarseIndicesView =
      topQueryToCentroid.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runMultiPassTile(queryView,
                     centroids,
                     pqCentroidsInnermostCode,
                     codeDistancesView,
                     coarseIndicesView,
                     useFloat16Lookup,
                     bytesPerCode,
                     numSubQuantizers,
                     numSubQuantizerCodes,
                     listCodes,
                     listIndices,
                     indicesOptions,
                     listLengths,
                     *thrustMem[curStream],
                     prefixSumOffsetsView,
                     *allDistances[curStream],
                     heapDistancesView,
                     heapIndicesView,
                     k,
                     outDistanceView,
                     outIndicesView,
                     streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  streamWait({stream}, streams);
}

} } // namespace