#include "PQScanMultiPassNoPrecomputed.cuh"
#include "../GpuResources.h"
#include "PQCodeDistances.cuh"
#include "PQCodeLoad.cuh"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceTensor.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/NoTypeTensor.cuh"
#include "../utils/StaticUtils.h"

#include "../utils/HostTensor.cuh"

namespace faiss { namespace gpu {

bool isSupportedNoPrecomputedSubDimSize(int dims) {

// Loads a block of code -> distance lookup values into shared memory,
// using vectorized loads when the element count permits an aligned
// vector view
template <typename LookupT, typename LookupVecT>
struct LoadCodeDistances {
  static inline __device__ void load(LookupT* smem,
                                     LookupT* codes,
                                     int numCodes) {
    constexpr int kWordSize = sizeof(LookupVecT) / sizeof(LookupT);

    if (numCodes % kWordSize == 0) {
      constexpr int kUnroll = 2;
      int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x);
      limitVec *= kUnroll * blockDim.x;

      LookupVecT* smemV = (LookupVecT*) smem;
      LookupVecT* codesV = (LookupVecT*) codes;

      for (int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) {
        LookupVecT vals[kUnroll];

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          vals[j] = LoadStore<LookupVecT>::load(&codesV[i + j * blockDim.x]);
        }

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          LoadStore<LookupVecT>::store(&smemV[i + j * blockDim.x], vals[j]);
        }
      }

      int remainder = limitVec * kWordSize;

      for (int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) {
        smem[i] = codes[i];
      }
    } else {
      constexpr int kUnroll = 4;
      int limit = utils::roundDown(numCodes, kUnroll * blockDim.x);

      int i = threadIdx.x;
      for (; i < limit; i += kUnroll * blockDim.x) {
        LookupT vals[kUnroll];

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          vals[j] = codes[i + j * blockDim.x];
        }

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          smem[i + j * blockDim.x] = vals[j];
        }
      }

      for (; i < numCodes; i += blockDim.x) {
        smem[i] = codes[i];
      }
    }
  }
};
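
// Scans one (query, probe) pair's inverted list: for each encoded vector,
// sums the per-byte code -> distance table entries and writes the resulting
// distance to its slot in the flat output array.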
template <int NumSubQuantizers, typename LookupT, typename LookupVecT>
__global__ void
pqScanNoPrecomputedMultiPass(Tensor<float, 2, true> queries,
                             Tensor<float, 3, true> pqCentroids,
                             Tensor<int, 2, true> topQueryToCentroid,
                             Tensor<LookupT, 4, true> codeDistances,
                             void** listCodes,
                             int* listLengths,
                             Tensor<int, 2, true> prefixSumOffsets,
                             Tensor<float, 1, true> distance) {
  const auto codesPerSubQuantizer = pqCentroids.getSize(2);

  // Shared memory holds the code -> distance lookup table for this
  // (query, probe) pair
  extern __shared__ char smemCodeDistances[];
  LookupT* codeDist = (LookupT*) smemCodeDistances;

  // Each block handles a single (query, probe) pair
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;

  // The prefix-sum entry just before ours gives the base offset at which
  // this list's distances are written
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);
  float* distanceOut = distance[outBase].data();

  auto listId = topQueryToCentroid[queryId][probeId];

  unsigned char* codeList = (unsigned char*) listCodes[listId];
  int limit = listLengths[listId];

  constexpr int kNumCode32 = NumSubQuantizers <= 4 ? 1 :
    (NumSubQuantizers / 4);
  unsigned int code32[kNumCode32];
  unsigned int nextCode32[kNumCode32];
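
  // The code loads are double-buffered: while distances for the current
  // codes are accumulated, the codes for the next block-strided position
  // are fetched into nextCode32.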
  if (threadIdx.x < limit) {
    LoadCode32<NumSubQuantizers>::load(code32, codeList, threadIdx.x);
  }

  LoadCodeDistances<LookupT, LookupVecT>::load(
    codeDist, codeDistances[queryId][probeId].data(),
    codeDistances.getSize(2) * codeDistances.getSize(3));
  __syncthreads();

  for (int codeIndex = threadIdx.x;
       codeIndex < limit;
       codeIndex += blockDim.x) {
    // Prefetch the codes for the next iteration
    if (codeIndex + blockDim.x < limit) {
      LoadCode32<NumSubQuantizers>::load(
        nextCode32, codeList, codeIndex + blockDim.x);
    }

    float dist = 0.0f;

#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      constexpr int kBytesPerCode32 =
        NumSubQuantizers < 4 ? NumSubQuantizers : 4;

      if (kBytesPerCode32 == 1) {
        auto code = code32[0];
        dist = ConvertTo<float>::to(codeDist[code]);
      } else {
#pragma unroll
        for (int byte = 0; byte < kBytesPerCode32; ++byte) {
          auto code = getByte(code32[word], byte * 8, 8);

          auto offset =
            codesPerSubQuantizer * (word * kBytesPerCode32 + byte);

          dist += ConvertTo<float>::to(codeDist[offset + code]);
        }
      }
    }

    distanceOut[codeIndex] = dist;

    for (int word = 0; word < kNumCode32; ++word) {
      code32[word] = nextCode32[word];
    }
  }
}

void
runMultiPassTile(Tensor<float, 2, true>& queries,
                 Tensor<float, 2, true>& centroids,
                 Tensor<float, 3, true>& pqCentroidsInnermostCode,
                 NoTypeTensor<4, true>& codeDistances,
                 Tensor<int, 2, true>& topQueryToCentroid,
                 bool useFloat16Lookup,
                 int bytesPerCode,
                 int numSubQuantizers,
                 int numSubQuantizerCodes,
                 thrust::device_vector<void*>& listCodes,
                 thrust::device_vector<void*>& listIndices,
                 IndicesOptions indicesOptions,
                 thrust::device_vector<int>& listLengths,
                 Tensor<char, 1, true>& thrustMem,
                 Tensor<int, 2, true>& prefixSumOffsets,
                 Tensor<float, 1, true>& allDistances,
                 Tensor<float, 3, true>& heapDistances,
                 Tensor<int, 3, true>& heapIndices,
                 int k,
                 Tensor<float, 2, true>& outDistances,
                 Tensor<long, 2, true>& outIndices,
                 cudaStream_t stream) {
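  // The tile runs in stages: compute the per-(query, probe) code distance
  // tables, scan the inverted lists into a flat distance array, then
  // k-select within probe chunks and finally across chunks per query.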
#ifndef FAISS_USE_FLOAT16
  FAISS_ASSERT(!useFloat16Lookup);
#endif

  runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets,
                     thrustMem, stream);

  runPQCodeDistances(pqCentroidsInnermostCode,

  auto kThreadsPerBlock = 256;

  auto grid = dim3(topQueryToCentroid.getSize(1),
                   topQueryToCentroid.getSize(0));
  auto block = dim3(kThreadsPerBlock);

  auto smem = sizeof(float);
#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    smem = sizeof(half);
  }
#endif
  smem *= numSubQuantizers * numSubQuantizerCodes;
  FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice());

#define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T)                  \
  do {                                                                 \
    auto codeDistancesT = codeDistances.toTensor<LOOKUP_T>();          \
    pqScanNoPrecomputedMultiPass<NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T>    \
      <<<grid, block, smem, stream>>>(                                 \
        queries,                                                       \
        pqCentroidsInnermostCode,                                      \
        topQueryToCentroid,                                            \
        codeDistancesT,                                                \
        listCodes.data().get(),                                        \
        listLengths.data().get(),                                      \
        prefixSumOffsets,                                              \
        allDistances);                                                 \
  } while (0)

#ifdef FAISS_USE_FLOAT16
#define RUN_PQ(NUM_SUB_Q)                       \
  do {                                          \
    if (useFloat16Lookup) {                     \
      RUN_PQ_OPT(NUM_SUB_Q, half, Half8);       \
    } else {                                    \
      RUN_PQ_OPT(NUM_SUB_Q, float, float4);     \
    }                                           \
  } while (0)
#else
#define RUN_PQ(NUM_SUB_Q)                       \
  do {                                          \
    RUN_PQ_OPT(NUM_SUB_Q, float, float4);       \
  } while (0)
#endif // FAISS_USE_FLOAT16
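
  // The switch below dispatches on the per-vector PQ code size; each
  // supported value of bytesPerCode instantiates the scan kernel via
  // RUN_PQ(n).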
  switch (bytesPerCode) {

  runPass1SelectLists(prefixSumOffsets,
                      topQueryToCentroid.getSize(1),

  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,

void runPQScanMultiPassNoPrecomputed(Tensor<float, 2, true>& queries,
                                     Tensor<float, 2, true>& centroids,
                                     Tensor<float, 3, true>& pqCentroidsInnermostCode,
                                     Tensor<int, 2, true>& topQueryToCentroid,
                                     bool useFloat16Lookup,
                                     int bytesPerCode,
                                     int numSubQuantizers,
                                     int numSubQuantizerCodes,
                                     thrust::device_vector<void*>& listCodes,
                                     thrust::device_vector<void*>& listIndices,
                                     IndicesOptions indicesOptions,
                                     thrust::device_vector<int>& listLengths,
                                     int maxListLength,
                                     int k,
                                     Tensor<float, 2, true>& outDistances,
                                     Tensor<long, 2, true>& outIndices,
                                     GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;

  int nprobe = topQueryToCentroid.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};
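
  // Each scratch buffer is allocated twice (the *1/*2 pairs above and
  // below) so that consecutive query tiles can run on alternate streams
  // without reusing each other's memory.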
  size_t sizeAvailable = mem.getSizeAvailable();

  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  size_t sizePerQuery =
    2 * // # of buffered streams
    ((nprobe * sizeof(int) + sizeof(int)) +                             // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) +                           // allDistances
     nprobe * numSubQuantizers * numSubQuantizerCodes * sizeof(float) + // codeDistances
     sizeForFirstSelectPass);                                           // heap distances/indices
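
  // Choose how many queries to process per tile given the per-query
  // scratch footprint, clamped to a reasonable range below.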
  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // Intermediate distances are indexed with int32
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <
               std::numeric_limits<int>::max());

  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};
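
  // The offset views begin one element into their backing allocations; the
  // element in front of each view is zeroed below, so reading data() - 1
  // for the first (query, probe) pair yields a base offset of 0.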
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0, sizeof(int), stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0, sizeof(int), stream));

  int codeDistanceTypeSize = sizeof(float);
#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    codeDistanceTypeSize = sizeof(half);
  }
#else
  FAISS_ASSERT(!useFloat16Lookup);
  int codeSize = sizeof(float);
#endif

  int totalCodeDistancesSize =
    queryTileSize * nprobe * numSubQuantizers * numSubQuantizerCodes *
    codeDistanceTypeSize;

  DeviceTensor<char, 1, true> codeDistances1Mem(
    mem, {totalCodeDistancesSize}, stream);
  NoTypeTensor<4, true> codeDistances1(
    codeDistances1Mem.data(),
    codeDistanceTypeSize,
    {queryTileSize, nprobe, numSubQuantizers, numSubQuantizerCodes});

  DeviceTensor<char, 1, true> codeDistances2Mem(
    mem, {totalCodeDistancesSize}, stream);
  NoTypeTensor<4, true> codeDistances2(
    codeDistances2Mem.data(),
    codeDistanceTypeSize,
    {queryTileSize, nprobe, numSubQuantizers, numSubQuantizerCodes});

  NoTypeTensor<4, true>* codeDistances[2] =
    {&codeDistances1, &codeDistances2};

  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};

  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;
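
  // Process the queries in tiles, alternating between the two buffer sets
  // and the two alternate streams so that consecutive tiles can overlap.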
  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto codeDistancesView =
      codeDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto coarseIndicesView =
      topQueryToCentroid.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runMultiPassTile(queryView,
                     centroids,
                     pqCentroidsInnermostCode,
                     codeDistancesView,
                     coarseIndicesView,
                     useFloat16Lookup,
                     bytesPerCode,
                     numSubQuantizers,
                     numSubQuantizerCodes,
                     listCodes,
                     listIndices,
                     indicesOptions,
                     listLengths,
                     *thrustMem[curStream],
                     prefixSumOffsetsView,
                     *allDistances[curStream],
                     *heapDistances[curStream],
                     *heapIndices[curStream],
                     k,
                     outDistanceView,
                     outIndicesView,
                     streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  streamWait({stream}, streams);
}

} } // namespace