#include "IVFFlatScan.cuh"
#include "../GpuResources.h"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/DeviceTensor.cuh"
#include "../utils/Float16.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/PtxUtils.cuh"
#include "../utils/Reductions.cuh"
#include "../utils/StaticUtils.h"
#include <thrust/host_vector.h>
#include <algorithm>
#include <limits>
namespace faiss { namespace gpu {
template <typename T>
inline __device__ typename Math<T>::ScalarType l2Distance(T a, T b) {
  a = Math<T>::sub(a, b);
  a = Math<T>::mul(a, a);
  return Math<T>::reduceAdd(a);
}
template <typename T>
inline __device__ typename Math<T>::ScalarType ipDistance(T a, T b) {
  return Math<T>::reduceAdd(Math<T>::mul(a, b));
}
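// For list scanning, even if the stored data is half precision, the math is
// done in float32: the scan is memory bandwidth bound, so the extra
// accumulation precision comes at little cost.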
// The class that we use to provide scan specializations
template <int Dims, bool L2, typename T>
struct IVFFlatScan {
};
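// Specializations are selected via the Dims template parameter at launch
// time: Dims == -1 is the general fallback below, where each thread strides
// over the dimensions, while Dims == 0 covers dim == blockDim.x exactly.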
// The fallback implementation: works for any dimension size
template <bool L2, typename T>
struct IVFFlatScan<-1, L2, T> {
  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    extern __shared__ float smem[];
    T* vecs = (T*) vecData;

    for (int vec = 0; vec < numVecs; ++vec) {
      // Partial distance accumulated by this thread
      float dist = 0.0f;

      for (int d = threadIdx.x; d < dim; d += blockDim.x) {
        float vecVal = ConvertTo<float>::to(vecs[vec * dim + d]);
        float queryVal = query[d];
        float curDist;
        if (L2) {
          curDist = l2Distance(queryVal, vecVal);
        } else {
          curDist = ipDistance(queryVal, vecVal);
        }
        dist += curDist;
      }

      // Sum the per-thread partial distances across the block
      dist = blockReduceAllSum<float, false, true>(dist, smem);

      if (threadIdx.x == 0) {
        distanceOut[vec] = dist;
      }
    }
  }
};
// The specialization for dimensions that fit exactly within the block:
// dim == blockDim.x, one dimension per thread
template <bool L2, typename T>
struct IVFFlatScan<0, L2, T> {
  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    extern __shared__ float smem[];
    T* vecs = (T*) vecData;

    float queryVal = query[threadIdx.x];
    constexpr int kUnroll = 4;
    int limit = utils::roundDown(numVecs, kUnroll);
    for (int i = 0; i < limit; i += kUnroll) {
      float vecVal[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vecVal[j] = ConvertTo<float>::to(vecs[(i + j) * dim + threadIdx.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          vecVal[j] = l2Distance(queryVal, vecVal[j]);
        } else {
          vecVal[j] = ipDistance(queryVal, vecVal[j]);
        }
      }

      blockReduceAllSum<kUnroll, float, false, true>(vecVal, smem);

      if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j] = vecVal[j];
        }
      }
    }
    // Handle the remainder if numVecs is not a multiple of kUnroll
    for (int i = limit; i < numVecs; ++i) {
      float vecVal = ConvertTo<float>::to(vecs[i * dim + threadIdx.x]);

      if (L2) {
        vecVal = l2Distance(queryVal, vecVal);
      } else {
        vecVal = ipDistance(queryVal, vecVal);
      }

      vecVal = blockReduceAllSum<float, false, true>(vecVal, smem);

      if (threadIdx.x == 0) {
        distanceOut[i] = vecVal;
      }
    }
  }
};
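// The scan kernel: one thread block is launched per (query, probe) pair;
// blockIdx.y selects the query and blockIdx.x selects the probed inverted
// list, and the block writes one distance per vector in that list.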
template <int Dims, bool L2, typename T>
__global__ void
ivfFlatScan(Tensor<float, 2, true> queries,
            Tensor<int, 2, true> listIds,
            void** allListData,
            int* listLengths,
            Tensor<int, 2, true> prefixSumOffsets,
            Tensor<float, 1, true> distance) {
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;

  // This is where we start writing out data for this (query, probe) pair;
  // the element at offset -1 of the prefix sum is always 0
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);

  auto listId = listIds[queryId][probeId];

  auto query = queries[queryId].data();
  auto vecs = allListData[listId];
  auto numVecs = listLengths[listId];
  auto dim = queries.getSize(1);
  auto distanceOut = distance[outBase].data();

  IVFFlatScan<Dims, L2, T>::scan(query, vecs, numVecs, dim, distanceOut);
}
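// Handles one tile of queries end-to-end: computes the per-(query, probe)
// write offsets, launches the scan kernel over all probed lists, then
// k-selects the results in two passes (within chunks of lists, then across
// chunks).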
void
runIVFFlatScanTile(Tensor<float, 2, true>& queries,
                   Tensor<int, 2, true>& listIds,
                   thrust::device_vector<void*>& listData,
                   thrust::device_vector<void*>& listIndices,
                   IndicesOptions indicesOptions,
                   thrust::device_vector<int>& listLengths,
                   Tensor<char, 1, true>& thrustMem,
                   Tensor<int, 2, true>& prefixSumOffsets,
                   Tensor<float, 1, true>& allDistances,
                   Tensor<float, 3, true>& heapDistances,
                   Tensor<int, 3, true>& heapIndices,
                   int k,
                   bool l2Distance,
                   bool useFloat16,
                   Tensor<float, 2, true>& outDistances,
                   Tensor<long, 2, true>& outIndices,
                   cudaStream_t stream) {
  // Calculate the offsets at which each (query, probe) pair writes its
  // intermediate distances
  runCalcListOffsets(listIds, listLengths, prefixSumOffsets, thrustMem, stream);
  constexpr int kMaxThreadsIVF = 512;

  int dim = queries.getSize(1);
  int numThreads = std::min(dim, kMaxThreadsIVF);
  // One block per (query, probe) pair
  auto grid = dim3(listIds.getSize(1),
                   listIds.getSize(0));
  auto block = dim3(numThreads);

  // The exact-dimension kernel reduces kUnroll = 4 values at a time, hence
  // the factor of 4 in the per-warp reduction scratch
  auto smem = sizeof(float) * utils::divUp(numThreads, kWarpSize) * 4;
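// Instantiates and launches the scan kernel for one (dims, metric, type)
// combination; HANDLE_DIM_CASE below picks the combination at runtime from
// the l2Distance and useFloat16 flags.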
#define RUN_IVF_FLAT(DIMS, L2, T)                                       \
  do {                                                                  \
    ivfFlatScan<DIMS, L2, T>                                            \
      <<<grid, block, smem, stream>>>(                                  \
        queries,                                                        \
        listIds,                                                        \
        listData.data().get(),                                          \
        listLengths.data().get(),                                       \
        prefixSumOffsets,                                               \
        allDistances);                                                  \
  } while (0)
#ifdef FAISS_USE_FLOAT16

#define HANDLE_DIM_CASE(DIMS)                               \
  do {                                                       \
    if (l2Distance) {                                        \
      if (useFloat16) { RUN_IVF_FLAT(DIMS, true, half); }    \
      else            { RUN_IVF_FLAT(DIMS, true, float); }   \
    } else {                                                 \
      if (useFloat16) { RUN_IVF_FLAT(DIMS, false, half); }   \
      else            { RUN_IVF_FLAT(DIMS, false, float); }  \
    }                                                        \
  } while (0)

#else
#define HANDLE_DIM_CASE(DIMS)                                \
  do {                                                       \
    if (l2Distance) {                                        \
      if (useFloat16) { FAISS_ASSERT(false); }               \
      else            { RUN_IVF_FLAT(DIMS, true, float); }   \
    } else {                                                 \
      if (useFloat16) { FAISS_ASSERT(false); }               \
      else            { RUN_IVF_FLAT(DIMS, false, float); }  \
    }                                                        \
  } while (0)

#endif // FAISS_USE_FLOAT16
  if (dim <= kMaxThreadsIVF) {
    // A single block covers the vector exactly
    HANDLE_DIM_CASE(0);
  } else {
    // General case: each thread strides over the dimensions
    HANDLE_DIM_CASE(-1);
  }

#undef HANDLE_DIM_CASE
#undef RUN_IVF_FLAT
  // k-select the output in chunks, to increase parallelism
  runPass1SelectLists(prefixSumOffsets,
                      allDistances,
                      listIds.getSize(1),
                      k,
                      !l2Distance, // L2 distance chooses the smallest
                      heapDistances,
                      heapIndices,
                      stream);
  // k-select the final output from the per-chunk heaps
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      listIds,
                      k,
                      !l2Distance, // L2 distance chooses the smallest
                      outDistances,
                      outIndices,
                      stream);
}
void
runIVFFlatScan(Tensor<float, 2, true>& queries,
               Tensor<int, 2, true>& listIds,
               thrust::device_vector<void*>& listData,
               thrust::device_vector<void*>& listIndices,
               IndicesOptions indicesOptions,
               thrust::device_vector<int>& listLengths,
               int maxListLength,
               int k,
               bool l2Distance,
               bool useFloat16,
               Tensor<float, 2, true>& outDistances,
               Tensor<long, 2, true>& outIndices,
               GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;
  int nprobe = listIds.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};
  size_t sizeAvailable = mem.getSizeAvailable();
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);
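  // k-selection runs in two passes: pass 1 selects k candidates within each
  // of (at most) kNProbeSplit chunks of probed lists per query, and pass 2
  // reduces the per-chunk heaps to the final k results.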
  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  // Temporary storage needed per query
  size_t sizePerQuery =
    2 * // # of streams
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     sizeForFirstSelectPass);
  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <
               std::numeric_limits<int>::max());
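  // All temporaries below come in pairs, one per alternate stream, so that
  // consecutive query tiles can be processed concurrently.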
  // Reserve one extra element before the start of each offsets view; it is
  // zeroed so the kernel can read offset -1 without a boundary check
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};

  // Zero the element before each offsets view
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0,
                              sizeof(int),
                              stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0,
                              sizeof(int),
                              stream));
  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};
  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});
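  // Process the queries in tiles; successive tiles alternate between the two
  // sets of temporaries and the two alternate streams.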
  int curStream = 0;

  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto listIdsView =
      listIds.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runIVFFlatScanTile(queryView,
                       listIdsView,
                       listData,
                       listIndices,
                       indicesOptions,
                       listLengths,
                       *thrustMem[curStream],
                       prefixSumOffsetsView,
                       *allDistances[curStream],
                       heapDistancesView,
                       heapIndicesView,
                       k,
                       l2Distance,
                       useFloat16,
                       outDistanceView,
                       outIndicesView,
                       streams[curStream]);

    curStream = (curStream + 1) % 2;
  }
  streamWait({stream}, streams);
}

} } // namespace