#include "IVFFlatScan.cuh"
#include "../GpuResources.h"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/DeviceTensor.cuh"
#include "../utils/Float16.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/PtxUtils.cuh"
#include "../utils/Reductions.cuh"
#include "../utils/StaticUtils.h"
#include <thrust/host_vector.h>
#include <algorithm>
#include <limits>
namespace faiss { namespace gpu {
template <typename T>
inline __device__ typename Math<T>::ScalarType l2Distance(T a, T b) {
  a = Math<T>::sub(a, b);
  a = Math<T>::mul(a, a);
  return Math<T>::reduceAdd(a);
}

template <typename T>
inline __device__ typename Math<T>::ScalarType ipDistance(T a, T b) {
  return Math<T>::reduceAdd(Math<T>::mul(a, b));
}
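// Illustrative note (added, not from the original source): Math<T> operates
// componentwise, and Math<T>::reduceAdd(v) is a horizontal add over the
// components of a vector type, returning sum(v_i). So for T = float2,
// l2Distance(q, v) yields (q.x - v.x)^2 + (q.y - v.y)^2 and
// ipDistance(q, v) yields q.x * v.x + q.y * v.y. The scan kernels below call
// these helpers on per-thread slices of a vector and then sum the partial
// results across the warp or block.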
// For list scanning, even if the input data is `half`, we perform all math
// in float32, since the code is memory bandwidth bound and the added
// precision for accumulation is useful.

/// The class that we use to provide scan specializations
template <int Dims, bool L2, typename T>
struct IVFFlatScan {
};
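// Summary (added comment, inferred from the dispatch code in
// runIVFFlatScanTile below): exact-dimension specializations exist for 64,
// 128 and 256 dims, in float32 and (when FAISS_USE_FLOAT16 is defined)
// float16 storage; other dims <= kMaxThreadsIVF use the Dims == 0
// (dims == blockDim.x) specialization, and anything larger falls back to the
// fully general Dims == -1 implementation.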
// The fallback for any dimension size (Dims == -1): each thread handles a
// strided subset of the dimensions for one vector at a time.
template <bool L2, typename T>
struct IVFFlatScan<-1, L2, T> {
  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    extern __shared__ float smem[];
    T* vecs = (T*) vecData;

    for (int vec = 0; vec < numVecs; ++vec) {
      // Partial distance accumulated by this thread
      float dist = 0.0f;

      for (int d = threadIdx.x; d < dim; d += blockDim.x) {
        float vecVal = ConvertTo<float>::to(vecs[vec * dim + d]);
        float queryVal = query[d];
        float curDist;

        if (L2) {
          curDist = l2Distance(queryVal, vecVal);
        } else {
          curDist = ipDistance(queryVal, vecVal);
        }

        dist += curDist;
      }

      // Sum the per-thread partials across the block
      dist = blockReduceAllSum<float, false, true>(dist, smem);

      if (threadIdx.x == 0) {
        distanceOut[vec] = dist;
      }
    }
  }
};
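// Worked example (added): this fallback is chosen when dim exceeds
// kMaxThreadsIVF (512). With dim = 1000 the block has 512 threads, so thread
// 5 accumulates the contributions of dimensions 5 and 517 of the current
// vector, and blockReduceAllSum combines the 512 per-thread partials into
// the final distance written by thread 0.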
// The specialization for dims == blockDim.x: each thread owns exactly one
// dimension of the query.
template <bool L2, typename T>
struct IVFFlatScan<0, L2, T> {
  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    extern __shared__ float smem[];
    T* vecs = (T*) vecData;

    float queryVal = query[threadIdx.x];

    constexpr int kUnroll = 4;
    int limit = utils::roundDown(numVecs, kUnroll);

    for (int i = 0; i < limit; i += kUnroll) {
      float vecVal[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vecVal[j] = ConvertTo<float>::to(vecs[(i + j) * dim + threadIdx.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          vecVal[j] = l2Distance(queryVal, vecVal[j]);
        } else {
          vecVal[j] = ipDistance(queryVal, vecVal[j]);
        }
      }

      blockReduceAllSum<kUnroll, float, false, true>(vecVal, smem);

      if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j] = vecVal[j];
        }
      }
    }

    // Handle the remainder if numVecs is not a multiple of kUnroll
    for (int i = limit; i < numVecs; ++i) {
      float vecVal = ConvertTo<float>::to(vecs[i * dim + threadIdx.x]);

      if (L2) {
        vecVal = l2Distance(queryVal, vecVal);
      } else {
        vecVal = ipDistance(queryVal, vecVal);
      }

      vecVal = blockReduceAllSum<float, false, true>(vecVal, smem);

      if (threadIdx.x == 0) {
        distanceOut[i] = vecVal;
      }
    }
  }
};
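// Worked example (added): with kUnroll = 4 and numVecs = 10,
// limit = utils::roundDown(10, 4) = 8, so the unrolled loop handles vectors
// {0..3} and {4..7} four at a time, and the remainder loop handles vectors
// 8 and 9 one at a time.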
// 64-d float32 implementation: each of the 32 lanes loads a float2, so one
// warp covers a full 64-d vector; warps stride across the list's vectors.
template <bool L2>
struct IVFFlatScan<64, L2, float> {
  static constexpr int kDims = 64;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    float* vecs = (float*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    float2 queryVal = *(float2*) &query[laneId * 2];

    constexpr int kUnroll = 4;
    float2 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        // Vector we are loading from is i + j * numWarps;
        // dims we are loading are laneId * 2 and laneId * 2 + 1
        vecVal[j] = *(float2*) &vecs[(i + j * numWarps) * kDims + laneId * 2];
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryVal, vecVal[j]);
        } else {
          dist[j] = ipDistance(queryVal, vecVal[j]);
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle the remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = *(float2*) &vecs[i * kDims + laneId * 2];

      float dist;
      if (L2) {
        dist = l2Distance(queryVal, vecVal[0]);
      } else {
        dist = ipDistance(queryVal, vecVal[0]);
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};
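// Worked example (added): for dim = 64 the launch uses 64 threads, i.e. 2
// warps, so kUnroll * numWarps = 8 vectors are processed per outer
// iteration; warp 0 handles vectors 0, 2, 4, 6 and warp 1 handles vectors
// 1, 3, 5, 7 before both stride forward by 8.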
#ifdef FAISS_USE_FLOAT16
// 64-d float16 implementation: layout matches the float32 case above, but
// each lane loads a half2 and converts to float2 before the math.
template <bool L2>
struct IVFFlatScan<64, L2, half> {
  static constexpr int kDims = 64;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    half* vecs = (half*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    float2 queryVal = *(float2*) &query[laneId * 2];

    constexpr int kUnroll = 4;
    half2 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vecVal[j] = *(half2*) &vecs[(i + j * numWarps) * kDims + laneId * 2];
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryVal, __half22float2(vecVal[j]));
        } else {
          dist[j] = ipDistance(queryVal, __half22float2(vecVal[j]));
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle the remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = *(half2*) &vecs[i * kDims + laneId * 2];

      float dist;
      if (L2) {
        dist = l2Distance(queryVal, __half22float2(vecVal[0]));
      } else {
        dist = ipDistance(queryVal, __half22float2(vecVal[0]));
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};

#endif // FAISS_USE_FLOAT16
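// Note (added): the float16 variants read 2 bytes per dimension instead of
// 4, halving the memory traffic per scanned vector, while all arithmetic is
// still performed in float32 after conversion (see the note above the
// IVFFlatScan primary template).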
// 128-d float32 implementation: each lane loads a float4
// (4 x 32 lanes = 128 dims per warp).
template <bool L2>
struct IVFFlatScan<128, L2, float> {
  static constexpr int kDims = 128;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    float* vecs = (float*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    float4 queryVal = *(float4*) &query[laneId * 4];

    constexpr int kUnroll = 4;
    float4 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vecVal[j] = *(float4*) &vecs[(i + j * numWarps) * kDims + laneId * 4];
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryVal, vecVal[j]);
        } else {
          dist[j] = ipDistance(queryVal, vecVal[j]);
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle the remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = *(float4*) &vecs[i * kDims + laneId * 4];

      float dist;
      if (L2) {
        dist = l2Distance(queryVal, vecVal[0]);
      } else {
        dist = ipDistance(queryVal, vecVal[0]);
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};
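// Note (added): each float4 load is 16 bytes per lane, so a warp reads an
// entire 128-d vector (512 bytes) in one coalesced access per unroll step.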
#ifdef FAISS_USE_FLOAT16
// 128-d float16 implementation: each lane loads a Half4 and converts it to
// float4 before the math.
template <bool L2>
struct IVFFlatScan<128, L2, half> {
  static constexpr int kDims = 128;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    half* vecs = (half*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    float4 queryVal = *(float4*) &query[laneId * 4];

    constexpr int kUnroll = 4;
    Half4 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vecVal[j] = LoadStore<Half4>::load(
          &vecs[(i + j * numWarps) * kDims + laneId * 4]);
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryVal, half4ToFloat4(vecVal[j]));
        } else {
          dist[j] = ipDistance(queryVal, half4ToFloat4(vecVal[j]));
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle the remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = LoadStore<Half4>::load(&vecs[i * kDims + laneId * 4]);

      float dist;
      if (L2) {
        dist = l2Distance(queryVal, half4ToFloat4(vecVal[0]));
      } else {
        dist = ipDistance(queryVal, half4ToFloat4(vecVal[0]));
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};

#endif // FAISS_USE_FLOAT16
// 256-d float32 implementation: handled by the general dims == blockDim.x
// kernel.
template <bool L2>
struct IVFFlatScan<256, L2, float> {
  static constexpr int kDims = 256;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    IVFFlatScan<0, L2, float>::scan(query, vecData, numVecs, dim, distanceOut);
  }
};
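// Note (added): float16 storage still gets a dedicated 256-d kernel below,
// because one warp can cover all 256 dims with a single Half8 (16-byte) load
// per lane (32 lanes x 8 halves), whereas float32 at 256 dims would need two
// float4 loads per lane.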
#ifdef FAISS_USE_FLOAT16
// 256-d float16 implementation: each lane loads a Half8 (8 halves), so one
// warp covers all 256 dims; the distance is accumulated over the two Half4
// halves of each load.
template <bool L2>
struct IVFFlatScan<256, L2, half> {
  static constexpr int kDims = 256;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    half* vecs = (half*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    // Each lane holds its 8 query dims as two float4 values
    float4 queryValA = *(float4*) &query[laneId * 8];
    float4 queryValB = *(float4*) &query[laneId * 8 + 4];

    constexpr int kUnroll = 4;
    Half8 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vecVal[j] = LoadStore<Half8>::load(
          &vecs[(i + j * numWarps) * kDims + laneId * 8]);
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryValA, half4ToFloat4(vecVal[j].a));
          dist[j] += l2Distance(queryValB, half4ToFloat4(vecVal[j].b));
        } else {
          dist[j] = ipDistance(queryValA, half4ToFloat4(vecVal[j].a));
          dist[j] += ipDistance(queryValB, half4ToFloat4(vecVal[j].b));
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle the remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = LoadStore<Half8>::load(&vecs[i * kDims + laneId * 8]);

      float dist;
      if (L2) {
        dist = l2Distance(queryValA, half4ToFloat4(vecVal[0].a));
        dist += l2Distance(queryValB, half4ToFloat4(vecVal[0].b));
      } else {
        dist = ipDistance(queryValA, half4ToFloat4(vecVal[0].a));
        dist += ipDistance(queryValB, half4ToFloat4(vecVal[0].b));
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};

#endif // FAISS_USE_FLOAT16
template <int Dims, bool L2, typename T>
__global__ void
ivfFlatScan(Tensor<float, 2, true> queries,
            Tensor<int, 2, true> listIds,
            void** allListData,
            int* listLengths,
            Tensor<int, 2, true> prefixSumOffsets,
            Tensor<float, 1, true> distance) {
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;

  // This is where we begin writing out data for this (query, probe) pair.
  // The element before the start of prefixSumOffsets is guaranteed to be 0,
  // so reading offset -1 for the first pair is safe.
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);

  auto listId = listIds[queryId][probeId];
  auto query = queries[queryId].data();
  auto vecs = allListData[listId];
  auto numVecs = listLengths[listId];
  auto dim = queries.getSize(1);
  auto distanceOut = distance[outBase].data();

  IVFFlatScan<Dims, L2, T>::scan(query, vecs, numVecs, dim, distanceOut);
}
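// Note (added): the launch below uses a grid of
// (nprobe = listIds.getSize(1), numQueries = listIds.getSize(0)) blocks, so
// each block scans one inverted list for one query and writes its numVecs
// distances starting at the offset given by the exclusive prefix sum.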
void
runIVFFlatScanTile(Tensor<float, 2, true>& queries,
                   Tensor<int, 2, true>& listIds,
                   thrust::device_vector<void*>& listData,
                   thrust::device_vector<void*>& listIndices,
                   IndicesOptions indicesOptions,
                   thrust::device_vector<int>& listLengths,
                   Tensor<char, 1, true>& thrustMem,
                   Tensor<int, 2, true>& prefixSumOffsets,
                   Tensor<float, 1, true>& allDistances,
                   Tensor<float, 3, true>& heapDistances,
                   Tensor<int, 3, true>& heapIndices,
                   int k,
                   bool l2Distance,
                   bool useFloat16,
                   Tensor<float, 2, true>& outDistances,
                   Tensor<long, 2, true>& outIndices,
                   cudaStream_t stream) {
  // Calculate offset lengths, so we know where to write out the
  // intermediate results
  runCalcListOffsets(listIds, listLengths, prefixSumOffsets, thrustMem, stream);

  // Calculate distances for vectors within our chunk of lists
  constexpr int kMaxThreadsIVF = 512;

  int dim = queries.getSize(1);
  int numThreads = std::min(dim, kMaxThreadsIVF);

  auto grid = dim3(listIds.getSize(1),
                   listIds.getSize(0));
  auto block = dim3(numThreads);
  // The kernels unroll by at most 4 vectors at a time; the block-wide
  // reduction needs one float of shared memory per warp per unrolled value
  auto smem = sizeof(float) * utils::divUp(numThreads, kWarpSize) * 4;
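  // Worked example (added): with dim = 256, numThreads = 256, so there are
  // 256 / kWarpSize = 8 warps and smem = 4 bytes * 8 * 4 = 128 bytes, enough
  // for the kUnroll = 4 partial sums per warp used by the block-wide
  // reduction in the dims == blockDim.x kernel.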
#define RUN_IVF_FLAT(DIMS, L2, T)                 \
  do {                                            \
    ivfFlatScan<DIMS, L2, T>                      \
      <<<grid, block, smem, stream>>>(            \
        queries,                                  \
        listIds,                                  \
        listData.data().get(),                    \
        listLengths.data().get(),                 \
        prefixSumOffsets,                         \
        allDistances);                            \
  } while (0)
#ifdef FAISS_USE_FLOAT16
#define HANDLE_DIM_CASE(DIMS)                               \
  do {                                                      \
    if (l2Distance) {                                       \
      if (useFloat16) { RUN_IVF_FLAT(DIMS, true, half); }   \
      else            { RUN_IVF_FLAT(DIMS, true, float); }  \
    } else {                                                \
      if (useFloat16) { RUN_IVF_FLAT(DIMS, false, half); }  \
      else            { RUN_IVF_FLAT(DIMS, false, float); } \
    }                                                       \
  } while (0)
#else

#define HANDLE_DIM_CASE(DIMS)                               \
  do {                                                      \
    if (l2Distance) {                                       \
      if (useFloat16) { FAISS_ASSERT(false); }              \
      else            { RUN_IVF_FLAT(DIMS, true, float); }  \
    } else {                                                \
      if (useFloat16) { FAISS_ASSERT(false); }              \
      else            { RUN_IVF_FLAT(DIMS, false, float); } \
    }                                                       \
  } while (0)
#endif // FAISS_USE_FLOAT16
  if (dim == 64) {
    HANDLE_DIM_CASE(64);
  } else if (dim == 128) {
    HANDLE_DIM_CASE(128);
  } else if (dim == 256) {
    HANDLE_DIM_CASE(256);
  } else if (dim <= kMaxThreadsIVF) {
    HANDLE_DIM_CASE(0);
  } else {
    HANDLE_DIM_CASE(-1);
  }

#undef HANDLE_DIM_CASE
#undef RUN_IVF_FLAT
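  // Note (added): k-selection is done in two passes to increase parallelism.
  // Pass 1 selects the best k candidates within each of pass2Chunks groups
  // of probed lists per query; pass 2 merges those per-group heaps (and maps
  // the winners back to user indices) to produce the final top-k per query.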
  // k-select the output in chunks, to increase parallelism
  runPass1SelectLists(prefixSumOffsets,
                      allDistances,
                      listIds.getSize(1),
                      k,
                      !l2Distance, // L2 distance chooses the smallest
                      heapDistances,
                      heapIndices,
                      stream);

  // k-select the final output
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      listIds,
                      k,
                      !l2Distance, // L2 distance chooses the smallest
                      outDistances,
                      outIndices,
                      stream);

  CUDA_VERIFY(cudaGetLastError());
}
void
runIVFFlatScan(Tensor<float, 2, true>& queries,
               Tensor<int, 2, true>& listIds,
               thrust::device_vector<void*>& listData,
               thrust::device_vector<void*>& listIndices,
               IndicesOptions indicesOptions,
               thrust::device_vector<int>& listLengths,
               int maxListLength,
               int k,
               bool l2Distance,
               bool useFloat16,
               Tensor<float, 2, true>& outDistances,
               Tensor<long, 2, true>& outIndices,
               GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;

  int nprobe = listIds.getSize(1);
  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  // Reserve fixed scratch space for Thrust's internal allocations
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};
  // How much temporary storage is available? If possible, we'd like to fit
  // within the space available.
  size_t sizeAvailable = mem.getSizeAvailable();

  // We run two passes of heap selection; this is the number of chunks that
  // the nprobe lists are split into for the first pass
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  // How much temporary storage we need per each query
  size_t sizePerQuery =
    2 * // two sets of buffers, since we double-buffer across two streams
    ((nprobe * sizeof(int) + sizeof(int)) +      // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) +    // allDistances
     sizeForFirstSelectPass);                    // heapDistances + heapIndices
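  // Worked example (added, illustrative numbers): with nprobe = 32,
  // maxListLength = 10000, k = 100 and pass2Chunks = 8:
  //   prefixSumOffsets:  32 * 4 + 4           =       132 bytes
  //   allDistances:      32 * 10000 * 4       = 1,280,000 bytes
  //   first select pass: 8 * 100 * (4 + 4)    =     6,400 bytes
  // so sizePerQuery ~= 2 * 1.29 MB ~= 2.57 MB, and 256 MiB of scratch gives
  // queryTileSize ~= 104 before clamping to [8, 128].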
  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // The intermediate distance buffer is indexed with int32
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <
               std::numeric_limits<int>::max());
  // Temporary memory buffers.
  // Allocate one extra element; the tensors below are offset by one element
  // so that the value at index -1 (read by the scan kernel) exists and is 0.
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};
  // Make sure the element before prefixSumOffsets is 0, since we depend on
  // simple, boundary-less indexing to get proper results
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(), 0,
                              sizeof(int), stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(), 0,
                              sizeof(int), stream));
  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};
  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;
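  // Note (added): the tile loop below double-buffers across the two
  // alternate streams: tile t runs on streams[t % 2] with its own set of
  // scratch buffers, so work on adjacent tiles can overlap. streamWait above
  // orders the alternate streams after work already queued on the default
  // stream; the final streamWait below does the reverse.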
  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto listIdsView =
      listIds.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runIVFFlatScanTile(queryView,
                       listIdsView,
                       listData,
                       listIndices,
                       indicesOptions,
                       listLengths,
                       *thrustMem[curStream],
                       prefixSumOffsetsView,
                       *allDistances[curStream],
                       heapDistancesView,
                       heapIndicesView,
                       k,
                       l2Distance,
                       useFloat16,
                       outDistanceView,
                       outIndicesView,
                       streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  streamWait({stream}, streams);
}

} } // namespace gpu, namespace faiss