#include "IVFFlatScan.cuh"
#include "../GpuResources.h"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/DeviceTensor.cuh"
#include "../utils/Float16.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/PtxUtils.cuh"
#include "../utils/Reductions.cuh"
#include "../utils/StaticUtils.h"
#include <thrust/host_vector.h>
#include <algorithm>
#include <limits>
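
// This file provides the list-scanning kernels for the IVF flat index:
// per-dimension distance specializations, the ivfFlatScan kernel that scans
// one inverted list per (query, probe) block, and the host-side tiling and
// two-pass k-selection driver.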
namespace faiss { namespace gpu {
template <typename T>
inline __device__
typename Math<T>::ScalarType l2Distance(T a, T b) {
  a = Math<T>::sub(a, b);
  a = Math<T>::mul(a, a);

  // Math<T>::reduceAdd is a horizontal add over the vector type,
  // returning sum(v_i)
  return Math<T>::reduceAdd(a);
}
template <typename T>
inline __device__
typename Math<T>::ScalarType ipDistance(T a, T b) {
  return Math<T>::reduceAdd(Math<T>::mul(a, b));
}
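
// Note: even when the list data is stored as float16, the specializations
// below convert to float32 (via __half22float2 / half4ToFloat4) before the
// distance math, so accumulation always happens at full precision.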
/// The class that we use to provide scan specializations
template <
  int Dims,
  bool L2,
  typename T>
struct IVFFlatScan {
};
// Fallback implementation: works for any dimension size, with each thread
// handling every blockDim.x-th dimension
template <bool L2, typename T>
struct IVFFlatScan<-1, L2, T> {
  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    extern __shared__ float smem[];
    T* vecs = (T*) vecData;

    for (int vec = 0; vec < numVecs; ++vec) {
      // Reduce in dist
      float dist = 0.0f;

      for (int d = threadIdx.x; d < dim; d += blockDim.x) {
        float vecVal = ConvertTo<float>::to(vecs[vec * dim + d]);
        float queryVal = query[d];
        float curDist;

        if (L2) {
          curDist = l2Distance(queryVal, vecVal);
        } else {
          curDist = ipDistance(queryVal, vecVal);
        }

        dist += curDist;
      }

      // Reduce the partial distances within the block
      dist = blockReduceAllSum<float, false, true>(dist, smem);

      if (threadIdx.x == 0) {
        distanceOut[vec] = dist;
      }
    }
  }
};
// Implementation for the case where # dims == blockDim.x exactly, with one
// thread per dimension
template <bool L2, typename T>
struct IVFFlatScan<0, L2, T> {
  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    extern __shared__ float smem[];
    T* vecs = (T*) vecData;

    float queryVal = query[threadIdx.x];

    constexpr int kUnroll = 4;
    int limit = utils::roundDown(numVecs, kUnroll);

    for (int i = 0; i < limit; i += kUnroll) {
      float vecVal[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vecVal[j] = ConvertTo<float>::to(vecs[(i + j) * dim + threadIdx.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          vecVal[j] = l2Distance(queryVal, vecVal[j]);
        } else {
          vecVal[j] = ipDistance(queryVal, vecVal[j]);
        }
      }

      blockReduceAllSum<kUnroll, float, false, true>(vecVal, smem);

      if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j] = vecVal[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit; i < numVecs; ++i) {
      float vecVal = ConvertTo<float>::to(vecs[i * dim + threadIdx.x]);

      if (L2) {
        vecVal = l2Distance(queryVal, vecVal);
      } else {
        vecVal = ipDistance(queryVal, vecVal);
      }

      vecVal = blockReduceAllSum<float, false, true>(vecVal, smem);

      if (threadIdx.x == 0) {
        distanceOut[i] = vecVal;
      }
    }
  }
};
// 64-d float32 implementation: each warp handles whole vectors, with each
// lane loading 2 contiguous dims (32 lanes * float2 = 64 dims)
template <bool L2>
struct IVFFlatScan<64, L2, float> {
  static constexpr int kDims = 64;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    float* vecs = (float*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    float2 queryVal = *(float2*) &query[laneId * 2];

    constexpr int kUnroll = 4;
    float2 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vecVal[j] = *(float2*) &vecs[(i + j * numWarps) * kDims + laneId * 2];
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryVal, vecVal[j]);
        } else {
          dist[j] = ipDistance(queryVal, vecVal[j]);
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = *(float2*) &vecs[i * kDims + laneId * 2];

      float dist;
      if (L2) {
        dist = l2Distance(queryVal, vecVal[0]);
      } else {
        dist = ipDistance(queryVal, vecVal[0]);
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};
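
// The float16 specializations below mirror the float32 ones: data is loaded
// as half2 / Half4 / Half8 and converted to float before the distance math.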
#ifdef FAISS_USE_FLOAT16

// 64-d float16 implementation: same layout as the float32 version above,
// but each lane loads a half2 and converts to float2 before the math
template <bool L2>
struct IVFFlatScan<64, L2, half> {
  static constexpr int kDims = 64;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    half* vecs = (half*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    float2 queryVal = *(float2*) &query[laneId * 2];

    constexpr int kUnroll = 4;
    half2 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vecVal[j] = *(half2*) &vecs[(i + j * numWarps) * kDims + laneId * 2];
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryVal, __half22float2(vecVal[j]));
        } else {
          dist[j] = ipDistance(queryVal, __half22float2(vecVal[j]));
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = *(half2*) &vecs[i * kDims + laneId * 2];

      float dist;
      if (L2) {
        dist = l2Distance(queryVal, __half22float2(vecVal[0]));
      } else {
        dist = ipDistance(queryVal, __half22float2(vecVal[0]));
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};

#endif // FAISS_USE_FLOAT16
// 128-d float32 implementation: each lane loads 4 contiguous dims as a
// float4 (32 lanes * float4 = 128 dims)
template <bool L2>
struct IVFFlatScan<128, L2, float> {
  static constexpr int kDims = 128;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    float* vecs = (float*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    float4 queryVal = *(float4*) &query[laneId * 4];

    constexpr int kUnroll = 4;
    float4 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vecVal[j] = *(float4*) &vecs[(i + j * numWarps) * kDims + laneId * 4];
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryVal, vecVal[j]);
        } else {
          dist[j] = ipDistance(queryVal, vecVal[j]);
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = *(float4*) &vecs[i * kDims + laneId * 4];

      float dist;
      if (L2) {
        dist = l2Distance(queryVal, vecVal[0]);
      } else {
        dist = ipDistance(queryVal, vecVal[0]);
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};
#ifdef FAISS_USE_FLOAT16

// 128-d float16 implementation: each lane loads 4 contiguous dims as a Half4
// and converts to float4 before the math
template <bool L2>
struct IVFFlatScan<128, L2, half> {
  static constexpr int kDims = 128;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    half* vecs = (half*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    float4 queryVal = *(float4*) &query[laneId * 4];

    constexpr int kUnroll = 4;
    Half4 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vecVal[j] = LoadStore<Half4>::load(
          &vecs[(i + j * numWarps) * kDims + laneId * 4]);
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryVal, half4ToFloat4(vecVal[j]));
        } else {
          dist[j] = ipDistance(queryVal, half4ToFloat4(vecVal[j]));
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = LoadStore<Half4>::load(&vecs[i * kDims + laneId * 4]);

      float dist;
      if (L2) {
        dist = l2Distance(queryVal, half4ToFloat4(vecVal[0]));
      } else {
        dist = ipDistance(queryVal, half4ToFloat4(vecVal[0]));
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};

#endif // FAISS_USE_FLOAT16
// 256-d float32 implementation: falls back to the one-thread-per-dimension
// scan above
template <bool L2>
struct IVFFlatScan<256, L2, float> {
  static constexpr int kDims = 256;

  static __device__ void scan(float* query, void* vecData, int numVecs,
                              int dim, float* distanceOut) {
    IVFFlatScan<0, L2, float>::scan(query, vecData, numVecs, dim, distanceOut);
  }
};
#ifdef FAISS_USE_FLOAT16

// 256-d float16 implementation: each lane loads 8 contiguous dims as a Half8
// (two Half4 halves), converting each half to float4 before the math
template <bool L2>
struct IVFFlatScan<256, L2, half> {
  static constexpr int kDims = 256;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    half* vecs = (half*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    // Each lane is responsible for 8 dims of the query
    float4 queryValA = *(float4*) &query[laneId * 8];
    float4 queryValB = *(float4*) &query[laneId * 8 + 4];

    constexpr int kUnroll = 4;
    Half8 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vecVal[j] = LoadStore<Half8>::load(
          &vecs[(i + j * numWarps) * kDims + laneId * 8]);
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryValA, half4ToFloat4(vecVal[j].a));
          dist[j] += l2Distance(queryValB, half4ToFloat4(vecVal[j].b));
        } else {
          dist[j] = ipDistance(queryValA, half4ToFloat4(vecVal[j].a));
          dist[j] += ipDistance(queryValB, half4ToFloat4(vecVal[j].b));
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = LoadStore<Half8>::load(&vecs[i * kDims + laneId * 8]);

      float dist;
      if (L2) {
        dist = l2Distance(queryValA, half4ToFloat4(vecVal[0].a));
        dist += l2Distance(queryValB, half4ToFloat4(vecVal[0].b));
      } else {
        dist = ipDistance(queryValA, half4ToFloat4(vecVal[0].a));
        dist += ipDistance(queryValB, half4ToFloat4(vecVal[0].b));
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};

#endif // FAISS_USE_FLOAT16
template <int Dims, bool L2, typename T>
__global__ void
ivfFlatScan(Tensor<float, 2, true> queries,
            Tensor<int, 2, true> listIds,
            void** allListData,
            int* listLengths,
            Tensor<int, 2, true> prefixSumOffsets,
            Tensor<float, 1, true> distance) {
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;

  // This is where we begin writing out data; the prefix-sum space is sized
  // so that reading one element before the first offset always yields 0
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);

  auto listId = listIds[queryId][probeId];

  auto query = queries[queryId].data();
  auto vecs = allListData[listId];
  auto numVecs = listLengths[listId];
  auto dim = queries.getSize(1);
  auto distanceOut = distance[outBase].data();

  IVFFlatScan<Dims, L2, T>::scan(query, vecs, numVecs, dim, distanceOut);
}
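
// The kernel above is launched with gridDim = (nprobe, numQueries) (see
// runIVFFlatScanTile below), so each thread block scans exactly one
// (query, probe) inverted list and writes its distances at the offset given
// by the prefix sum over list lengths.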
void
runIVFFlatScanTile(Tensor<float, 2, true>& queries,
                   Tensor<int, 2, true>& listIds,
                   thrust::device_vector<void*>& listData,
                   thrust::device_vector<void*>& listIndices,
                   IndicesOptions indicesOptions,
                   thrust::device_vector<int>& listLengths,
                   Tensor<char, 1, true>& thrustMem,
                   Tensor<int, 2, true>& prefixSumOffsets,
                   Tensor<float, 1, true>& allDistances,
                   Tensor<float, 3, true>& heapDistances,
                   Tensor<int, 3, true>& heapIndices,
                   int k,
                   bool l2Distance,
                   bool useFloat16,
                   Tensor<float, 2, true>& outDistances,
                   Tensor<long, 2, true>& outIndices,
                   cudaStream_t stream) {
  // Calculate offset lengths, so we know where to write out
  // intermediate results
  runCalcListOffsets(listIds, listLengths, prefixSumOffsets, thrustMem, stream);
  // Calculate distances for the vectors within our chunk of lists
  constexpr int kMaxThreadsIVF = 512;

  int dim = queries.getSize(1);
  int numThreads = std::min(dim, kMaxThreadsIVF);

  auto grid = dim3(listIds.getSize(1),
                   listIds.getSize(0));
  auto block = dim3(numThreads);
  auto smem = sizeof(float) * utils::divUp(numThreads, kWarpSize) * 4;
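  // Note: the factor of 4 in the shared-memory reservation above matches
  // kUnroll = 4 in the scan specializations; blockReduceAllSum<kUnroll, ...>
  // needs kUnroll floats of scratch per warp.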
#define RUN_IVF_FLAT(DIMS, L2, T)                                       \
  do {                                                                  \
    ivfFlatScan<DIMS, L2, T>                                            \
      <<<grid, block, smem, stream>>>(                                  \
        queries,                                                        \
        listIds,                                                        \
        listData.data().get(),                                          \
        listLengths.data().get(),                                       \
        prefixSumOffsets,                                               \
        allDistances);                                                  \
  } while (0)
#ifdef FAISS_USE_FLOAT16

#define HANDLE_DIM_CASE(DIMS)                                     \
  do {                                                            \
    if (l2Distance) {                                             \
      if (useFloat16) { RUN_IVF_FLAT(DIMS, true, half); }         \
      else            { RUN_IVF_FLAT(DIMS, true, float); }        \
    } else {                                                      \
      if (useFloat16) { RUN_IVF_FLAT(DIMS, false, half); }        \
      else            { RUN_IVF_FLAT(DIMS, false, float); }       \
    }                                                             \
  } while (0)

#else

// Without float16 support compiled in, a float16 request is an error
#define HANDLE_DIM_CASE(DIMS)                                     \
  do {                                                            \
    if (l2Distance) {                                             \
      if (useFloat16) { FAISS_ASSERT(false); }                    \
      else            { RUN_IVF_FLAT(DIMS, true, float); }        \
    } else {                                                      \
      if (useFloat16) { FAISS_ASSERT(false); }                    \
      else            { RUN_IVF_FLAT(DIMS, false, float); }       \
    }                                                             \
  } while (0)

#endif // FAISS_USE_FLOAT16
  if (dim == 64) {
    HANDLE_DIM_CASE(64);
  } else if (dim == 128) {
    HANDLE_DIM_CASE(128);
  } else if (dim == 256) {
    HANDLE_DIM_CASE(256);
  } else if (dim <= kMaxThreadsIVF) {
    // One thread per dimension
    HANDLE_DIM_CASE(0);
  } else {
    // Fallback: dim exceeds the maximum block size
    HANDLE_DIM_CASE(-1);
  }

#undef HANDLE_DIM_CASE
#undef RUN_IVF_FLAT
  // First pass: k-select within each chunk of probes; second pass reduces
  // the per-chunk heaps to the final k results per query
  runPass1SelectLists(prefixSumOffsets, allDistances, listIds.getSize(1), k,
                      !l2Distance, heapDistances, heapIndices, stream);

  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances, flatHeapIndices, listIndices,
                      indicesOptions, prefixSumOffsets, listIds, k,
                      !l2Distance, outDistances, outIndices, stream);
}
void
runIVFFlatScan(Tensor<float, 2, true>& queries,
               Tensor<int, 2, true>& listIds,
               thrust::device_vector<void*>& listData,
               thrust::device_vector<void*>& listIndices,
               IndicesOptions indicesOptions,
               thrust::device_vector<int>& listLengths,
               int maxListLength,
               int k,
               bool l2Distance,
               bool useFloat16,
               // output
               Tensor<float, 2, true>& outDistances,
               // output
               Tensor<long, 2, true>& outIndices,
               GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;

  int nprobe = listIds.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};

  size_t sizeAvailable = mem.getSizeAvailable();
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);
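  // The nprobe scanned lists are k-selected in two passes: pass 1 selects k
  // results within each of the (at most kNProbeSplit) chunks of probes, and
  // pass 2 reduces those per-chunk heaps to the final k results per query.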
  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  // How much temporary storage we need per each query
  size_t sizePerQuery =
    2 * // two sets of buffers, one per alternate stream
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     sizeForFirstSelectPass);
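  // For example (hypothetical sizes): with nprobe = 32, maxListLength = 10000
  // and k = 100, sizePerQuery is roughly 2 * (132 + 1,280,000 + 6,400) bytes,
  // about 2.5 MB, so ~256 MB of scratch yields a tile of roughly 100 queries
  // before clamping.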
  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // The intermediate distance buffer is indexed with int32
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <
               std::numeric_limits<int>::max());
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};
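
  // These views start at element 1 of the backing space; element 0 is zeroed
  // below so that the scan kernel can read the offset at index -1 (the
  // exclusive-scan base) without a branch.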
  CUDA_VERIFY(cudaMemsetAsync(
    prefixSumOffsetSpace1.data(), 0, sizeof(int), stream));
  CUDA_VERIFY(cudaMemsetAsync(
    prefixSumOffsetSpace2.data(), 0, sizeof(int), stream));
  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};
  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;
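  // The per-tile scratch above is double-buffered: even tiles use buffer set
  // 0 on streams[0], odd tiles buffer set 1 on streams[1], so work from
  // adjacent tiles can overlap.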
  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto listIdsView =
      listIds.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runIVFFlatScanTile(queryView,
                       listIdsView,
                       listData,
                       listIndices,
                       indicesOptions,
                       listLengths,
                       *thrustMem[curStream],
                       prefixSumOffsetsView,
                       *allDistances[curStream],
                       heapDistancesView,
                       heapIndicesView,
                       k,
                       l2Distance,
                       useFloat16,
                       outDistanceView,
                       outIndicesView,
                       streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  streamWait({stream}, streams);
}

} } // namespace