11 #include "PQScanMultiPassNoPrecomputed.cuh"
12 #include "../GpuResources.h"
13 #include "PQCodeDistances.cuh"
14 #include "PQCodeLoad.cuh"
15 #include "IVFUtils.cuh"
16 #include "../utils/ConversionOperators.cuh"
17 #include "../utils/DeviceTensor.cuh"
18 #include "../utils/DeviceUtils.h"
19 #include "../utils/Float16.cuh"
20 #include "../utils/LoadStoreOperators.cuh"
21 #include "../utils/NoTypeTensor.cuh"
22 #include "../utils/StaticUtils.h"
24 #include "../utils/HostTensor.cuh"
namespace faiss { namespace gpu {
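// Returns whether the given per-sub-quantizer dimensionality is supported by
// the no-precomputed-codes distance kernels.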
bool isSupportedNoPrecomputedSubDimSize(int dims) {
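// Cooperatively copies the (sub-quantizer code -> distance) lookup table for
// one (query, probe) pair from global memory into shared memory, using all
// threads of the block.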
template <typename LookupT, typename LookupVecT>
struct LoadCodeDistances {
  static inline __device__ void load(LookupT* smem,
                                     LookupT* codes,
                                     int numCodes) {
    constexpr int kWordSize = sizeof(LookupVecT) / sizeof(LookupT);
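    // If the number of codes is a multiple of the vector word size, the data
    // is aligned and we can copy with vectorized loads; otherwise fall back
    // to scalar copies.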
    if (numCodes % kWordSize == 0) {
      constexpr int kUnroll = 2;

      // Number of elements we can copy in whole unrolled, vectorized steps
      int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x);
      limitVec *= kUnroll * blockDim.x;

      LookupVecT* smemV = (LookupVecT*) smem;
      LookupVecT* codesV = (LookupVecT*) codes;

      for (int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) {
        LookupVecT vals[kUnroll];

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          vals[j] = LoadStore<LookupVecT>::load(&codesV[i + j * blockDim.x]);
        }

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          LoadStore<LookupVecT>::store(&smemV[i + j * blockDim.x], vals[j]);
        }
      }

      // Copy the remainder that does not fit the unrolled vector loop
      int remainder = limitVec * kWordSize;

      for (int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) {
        smem[i] = codes[i];
      }
    } else {
      // Unaligned fallback: copy element by element with a small unroll
      constexpr int kUnroll = 4;
      int limit = utils::roundDown(numCodes, kUnroll * blockDim.x);

      int i = threadIdx.x;
      for (; i < limit; i += kUnroll * blockDim.x) {
        LookupT vals[kUnroll];

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          vals[j] = codes[i + j * blockDim.x];
        }

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          smem[i + j * blockDim.x] = vals[j];
        }
      }

      for (; i < numCodes; i += blockDim.x) {
        smem[i] = codes[i];
      }
    }
  }
};
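// Scan kernel for the no-precomputed-codes case: each thread block walks the
// inverted list selected by one (query, probe) pair, decodes each PQ code
// byte by byte against the shared-memory lookup table, and writes the summed
// distance for every encoded vector into the intermediate distance buffer.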
template <int NumSubQuantizers, typename LookupT, typename LookupVecT>
__global__ void
pqScanNoPrecomputedMultiPass(Tensor<float, 2, true> queries,
                             Tensor<float, 3, true> pqCentroids,
                             Tensor<int, 2, true> topQueryToCentroid,
                             Tensor<LookupT, 4, true> codeDistances,
                             void** listCodes,
                             int* listLengths,
                             Tensor<int, 2, true> prefixSumOffsets,
                             Tensor<float, 1, true> distance) {
  const auto codesPerSubQuantizer = pqCentroids.getSize(2);

  // Shared memory holding the code -> distance lookup table
  extern __shared__ char smemCodeDistances[];
  LookupT* codeDist = (LookupT*) smemCodeDistances;
  // Each block handles a single (query, probe) pair
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;

  // Where we start writing out intermediate distances; the element at
  // offset -1 of prefixSumOffsets is guaranteed to be zero
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);
  float* distanceOut = distance[outBase].data();

  auto listId = topQueryToCentroid[queryId][probeId];

  unsigned char* codeList = (unsigned char*) listCodes[listId];
  int limit = listLengths[listId];
  constexpr int kNumCode32 = NumSubQuantizers <= 4 ? 1 :
    (NumSubQuantizers / 4);
  unsigned int code32[kNumCode32];
  unsigned int nextCode32[kNumCode32];

  // Double-buffer the code loads to hide global memory latency
  if (threadIdx.x < limit) {
    LoadCode32<NumSubQuantizers>::load(code32, codeList, threadIdx.x);
  }

  LoadCodeDistances<LookupT, LookupVecT>::load(
    codeDist,
    codeDistances[queryId][probeId].data(),
    codeDistances.getSize(2) * codeDistances.getSize(3));

  // Make sure the lookup table is in shared memory before it is read
  __syncthreads();
  // Each thread handles one code in the list, striding by the block size
  for (int codeIndex = threadIdx.x;
       codeIndex < limit;
       codeIndex += blockDim.x) {
    // Prefetch the next code while the current one is being processed
    if (codeIndex + blockDim.x < limit) {
      LoadCode32<NumSubQuantizers>::load(
        nextCode32, codeList, codeIndex + blockDim.x);
    }

    float dist = 0.0f;

#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      constexpr int kBytesPerCode32 =
        NumSubQuantizers < 4 ? NumSubQuantizers : 4;

      if (kBytesPerCode32 == 1) {
        auto code = code32[0];
        dist = ConvertTo<float>::to(codeDist[code]);
      } else {
#pragma unroll
        for (int byte = 0; byte < kBytesPerCode32; ++byte) {
          auto code = getByte(code32[word], byte * 8, 8);

          auto offset =
            codesPerSubQuantizer * (word * kBytesPerCode32 + byte);

          dist += ConvertTo<float>::to(codeDist[offset + code]);
        }
      }
    }

    // Write out the intermediate distance; indices are recovered in a
    // later selection pass
    distanceOut[codeIndex] = dist;

    // Rotate the double buffer
#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      code32[word] = nextCode32[word];
    }
  }
}
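// Runs one tile of queries end to end: compute write offsets, build the code
// distance tables, scan the inverted lists, then k-select the results in two
// passes.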
void
runMultiPassTile(Tensor<float, 2, true>& queries,
                 Tensor<float, 2, true>& centroids,
                 Tensor<float, 3, true>& pqCentroidsInnermostCode,
                 NoTypeTensor<4, true>& codeDistances,
                 Tensor<int, 2, true>& topQueryToCentroid,
                 bool useFloat16Lookup,
                 int bytesPerCode,
                 int numSubQuantizers,
                 int numSubQuantizerCodes,
                 thrust::device_vector<void*>& listCodes,
                 thrust::device_vector<void*>& listIndices,
                 IndicesOptions indicesOptions,
                 thrust::device_vector<int>& listLengths,
                 Tensor<char, 1, true>& thrustMem,
                 Tensor<int, 2, true>& prefixSumOffsets,
                 Tensor<float, 1, true>& allDistances,
                 Tensor<float, 3, true>& heapDistances,
                 Tensor<int, 3, true>& heapIndices,
                 int k,
                 Tensor<float, 2, true>& outDistances,
                 Tensor<long, 2, true>& outIndices,
                 cudaStream_t stream) {
#ifndef FAISS_USE_FLOAT16
  FAISS_ASSERT(!useFloat16Lookup);
#endif
  // Compute per-(query, probe) write offsets for the intermediate distances
  runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets,
                     thrustMem, stream);

  // Compute the query-to-PQ-centroid distance tables (no precomputed codes)
  runPQCodeDistances(pqCentroidsInnermostCode, queries, centroids,
                     topQueryToCentroid, codeDistances, useFloat16Lookup,
                     stream);
  // Convert all codes to distances and write out the intermediate results
  // for each (query, probe) pair
  auto kThreadsPerBlock = 256;

  auto grid = dim3(topQueryToCentroid.getSize(1),
                   topQueryToCentroid.getSize(0));
  auto block = dim3(kThreadsPerBlock);

  // Shared memory size needed for the code distance lookup table
  auto smem = sizeof(float);
#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    smem = sizeof(half);
  }
#endif
  smem *= numSubQuantizers * numSubQuantizerCodes;
  FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice());
#define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T)                  \
  do {                                                                 \
    auto codeDistancesT = codeDistances.toTensor<LOOKUP_T>();          \
    pqScanNoPrecomputedMultiPass<NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T>    \
      <<<grid, block, smem, stream>>>(                                 \
        queries,                                                       \
        pqCentroidsInnermostCode,                                      \
        topQueryToCentroid,                                            \
        codeDistancesT,                                                \
        listCodes.data().get(),                                        \
        listLengths.data().get(),                                      \
        prefixSumOffsets,                                              \
        allDistances);                                                 \
  } while (0)

#ifdef FAISS_USE_FLOAT16
#define RUN_PQ(NUM_SUB_Q)                       \
  do {                                          \
    if (useFloat16Lookup) {                     \
      RUN_PQ_OPT(NUM_SUB_Q, half, Half8);       \
    } else {                                    \
      RUN_PQ_OPT(NUM_SUB_Q, float, float4);     \
    }                                           \
  } while (0)
#else
#define RUN_PQ(NUM_SUB_Q)                       \
  do {                                          \
    RUN_PQ_OPT(NUM_SUB_Q, float, float4);       \
  } while (0)
#endif // FAISS_USE_FLOAT16
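  // Dispatch the kernel instantiation on the number of PQ code bytes per
  // encoded vector via RUN_PQ.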
  switch (bytesPerCode) {
  runPass1SelectLists(prefixSumOffsets, allDistances,
                      topQueryToCentroid.getSize(1), k,
                      false, heapDistances, heapIndices, stream);

  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances, flatHeapIndices, listIndices,
                      indicesOptions, prefixSumOffsets, topQueryToCentroid,
                      k, false, outDistances, outIndices, stream);
}
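// Top-level entry point: sizes a query tile to fit the available temporary
// memory, allocates double-buffered scratch tensors, and processes the
// queries tile by tile on alternating streams.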
void runPQScanMultiPassNoPrecomputed(Tensor<float, 2, true>& queries,
                                     Tensor<float, 2, true>& centroids,
                                     Tensor<float, 3, true>& pqCentroidsInnermostCode,
                                     Tensor<int, 2, true>& topQueryToCentroid,
                                     bool useFloat16Lookup,
                                     int bytesPerCode,
                                     int numSubQuantizers,
                                     int numSubQuantizerCodes,
                                     thrust::device_vector<void*>& listCodes,
                                     thrust::device_vector<void*>& listIndices,
                                     IndicesOptions indicesOptions,
                                     thrust::device_vector<int>& listLengths,
                                     int maxListLength,
                                     int k,
                                     Tensor<float, 2, true>& outDistances,
                                     Tensor<long, 2, true>& outIndices,
                                     GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;
  int nprobe = topQueryToCentroid.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  // Reserve scratch space for Thrust's temporary allocations
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};
  // How much temporary memory is available for a tile of queries?
  size_t sizeAvailable = mem.getSizeAvailable();

  // The first selection pass is chunked over groups of probes
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));
  // Temporary memory needed per query, double-buffered across the two streams
  size_t sizePerQuery =
    2 *
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     nprobe * numSubQuantizers * numSubQuantizerCodes * sizeof(float) +
     sizeForFirstSelectPass);
  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // Intermediate results are indexed with 32-bit ints, so the tile must fit
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <
               std::numeric_limits<int>::max());
  // Allocate one extra leading element so that reading at offset -1 of each
  // prefixSumOffsets view is always valid
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0, sizeof(int), stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0, sizeof(int), stream));
  int codeDistanceTypeSize = sizeof(float);
#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    codeDistanceTypeSize = sizeof(half);
  }
#else
  FAISS_ASSERT(!useFloat16Lookup);
  int codeSize = sizeof(float);
#endif
  int totalCodeDistancesSize =
    queryTileSize * nprobe * numSubQuantizers * numSubQuantizerCodes *
    codeDistanceTypeSize;

  DeviceTensor<char, 1, true> codeDistances1Mem(
    mem, {totalCodeDistancesSize}, stream);
  NoTypeTensor<4, true> codeDistances1(
    codeDistances1Mem.data(),
    codeDistanceTypeSize,
    {queryTileSize, nprobe, numSubQuantizers, numSubQuantizerCodes});

  DeviceTensor<char, 1, true> codeDistances2Mem(
    mem, {totalCodeDistancesSize}, stream);
  NoTypeTensor<4, true> codeDistances2(
    codeDistances2Mem.data(),
    codeDistanceTypeSize,
    {queryTileSize, nprobe, numSubQuantizers, numSubQuantizerCodes});

  NoTypeTensor<4, true>* codeDistances[2] =
    {&codeDistances1, &codeDistances2};
  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};
  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;
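  // Process the queries in tiles, ping-ponging between the two sets of
  // temporary buffers and the two alternate streams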
  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto codeDistancesView =
      codeDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto coarseIndicesView =
      topQueryToCentroid.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);
    runMultiPassTile(queryView, centroids, pqCentroidsInnermostCode,
                     codeDistancesView, coarseIndicesView,
                     useFloat16Lookup, bytesPerCode,
                     numSubQuantizers, numSubQuantizerCodes,
                     listCodes, listIndices, indicesOptions, listLengths,
                     *thrustMem[curStream],
                     prefixSumOffsetsView,
                     *allDistances[curStream],
                     heapDistancesView, heapIndicesView, k,
                     outDistanceView, outIndicesView,
                     streams[curStream]);

    curStream = (curStream + 1) % 2;
  }
  streamWait({stream}, streams);
}

} } // namespace