Faiss
IVFFlatScan.cu
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.

#include "IVFFlatScan.cuh"
#include "../GpuResources.h"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/DeviceTensor.cuh"
#include "../utils/Float16.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/PtxUtils.cuh"
#include "../utils/Reductions.cuh"
#include "../utils/StaticUtils.h"
#include <thrust/host_vector.h>

namespace faiss { namespace gpu {

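// Computes the squared L2 distance between a and b; for vector types,
// the per-component squared differences are summed via reduceAdd.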
template <typename T>
inline __device__ typename Math<T>::ScalarType l2Distance(T a, T b) {
  a = Math<T>::sub(a, b);
  a = Math<T>::mul(a, a);
  return Math<T>::reduceAdd(a);
}

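// Computes the inner product between a and b; for vector types, the
// per-component products are summed via reduceAdd.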
template <typename T>
inline __device__ typename Math<T>::ScalarType ipDistance(T a, T b) {
  return Math<T>::reduceAdd(Math<T>::mul(a, b));
}

// For list scanning, even if the input data is `half`, we perform all
// math in float32, because the code is memory b/w bound, and the
// added precision for accumulation is useful

/// The class that we use to provide scan specializations
template <int Dims, bool L2, typename T>
struct IVFFlatScan {
};

// Fallback implementation: works for any dimension size
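// The whole block cooperates on one vector at a time: each thread
// accumulates a strided subset of the dimensions, and the partial sums
// are then combined with a block-wide reduction in shared memory.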
template <bool L2, typename T>
struct IVFFlatScan<-1, L2, T> {
  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    extern __shared__ float smem[];
    T* vecs = (T*) vecData;

    for (int vec = 0; vec < numVecs; ++vec) {
      // Reduce in dist
      float dist = 0.0f;

      for (int d = threadIdx.x; d < dim; d += blockDim.x) {
        float vecVal = ConvertTo<float>::to(vecs[vec * dim + d]);
        float queryVal = query[d];
        float curDist;

        if (L2) {
          curDist = l2Distance(queryVal, vecVal);
        } else {
          curDist = ipDistance(queryVal, vecVal);
        }

        dist += curDist;
      }

      // Reduce distance within block
      dist = blockReduceAllSum<float, false, true>(dist, smem);

      if (threadIdx.x == 0) {
        distanceOut[vec] = dist;
      }
    }
  }
};

// implementation: works for # dims == blockDim.x
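// Each thread owns exactly one dimension of the query; kUnroll vectors
// are loaded per iteration and their distances reduced together across
// the block.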
template <bool L2, typename T>
struct IVFFlatScan<0, L2, T> {
  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    extern __shared__ float smem[];
    T* vecs = (T*) vecData;

    float queryVal = query[threadIdx.x];

    constexpr int kUnroll = 4;
    int limit = utils::roundDown(numVecs, kUnroll);

    for (int i = 0; i < limit; i += kUnroll) {
      float vecVal[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vecVal[j] = ConvertTo<float>::to(vecs[(i + j) * dim + threadIdx.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          vecVal[j] = l2Distance(queryVal, vecVal[j]);
        } else {
          vecVal[j] = ipDistance(queryVal, vecVal[j]);
        }
      }

      blockReduceAllSum<kUnroll, float, false, true>(vecVal, smem);

      if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j] = vecVal[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit; i < numVecs; ++i) {
      float vecVal = ConvertTo<float>::to(vecs[i * dim + threadIdx.x]);

      if (L2) {
        vecVal = l2Distance(queryVal, vecVal);
      } else {
        vecVal = ipDistance(queryVal, vecVal);
      }

      vecVal = blockReduceAllSum<float, false, true>(vecVal, smem);

      if (threadIdx.x == 0) {
        distanceOut[i] = vecVal;
      }
    }
  }
};

// 64-d float32 implementation
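// Each warp handles its own set of vectors independently: a lane loads two
// consecutive dimensions as a float2 (32 lanes x 2 = 64 dims), and the
// per-lane partial distances are combined with a warp-wide reduction
// (warpReduceAllSum), so no shared memory is needed.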
template <bool L2>
struct IVFFlatScan<64, L2, float> {
  static constexpr int kDims = 64;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    // Each warp reduces a single 64-d vector; each lane loads a float2
    float* vecs = (float*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    float2 queryVal = *(float2*) &query[laneId * 2];

    constexpr int kUnroll = 4;
    float2 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        // Vector we are loading from is i
        // Dim we are loading from is laneId * 2
        vecVal[j] = *(float2*) &vecs[(i + j * numWarps) * kDims + laneId * 2];
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryVal, vecVal[j]);
        } else {
          dist[j] = ipDistance(queryVal, vecVal[j]);
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = *(float2*) &vecs[i * kDims + laneId * 2];
      float dist;
      if (L2) {
        dist = l2Distance(queryVal, vecVal[0]);
      } else {
        dist = ipDistance(queryVal, vecVal[0]);
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};

#ifdef FAISS_USE_FLOAT16

// 64-d float16 implementation
template <bool L2>
struct IVFFlatScan<64, L2, half> {
  static constexpr int kDims = 64;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    // Each warp reduces a single 64-d vector; each lane loads a half2
    half* vecs = (half*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    float2 queryVal = *(float2*) &query[laneId * 2];

    constexpr int kUnroll = 4;

    half2 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        // Vector we are loading from is i
        // Dim we are loading from is laneId * 2
        vecVal[j] = *(half2*) &vecs[(i + j * numWarps) * kDims + laneId * 2];
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryVal, __half22float2(vecVal[j]));
        } else {
          dist[j] = ipDistance(queryVal, __half22float2(vecVal[j]));
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = *(half2*) &vecs[i * kDims + laneId * 2];

      float dist;
      if (L2) {
        dist = l2Distance(queryVal, __half22float2(vecVal[0]));
      } else {
        dist = ipDistance(queryVal, __half22float2(vecVal[0]));
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};

#endif

// 128-d float32 implementation
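// Same warp-per-vector scheme as the 64-d case, but each lane loads four
// consecutive dimensions as a float4 (32 lanes x 4 = 128 dims).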
template <bool L2>
struct IVFFlatScan<128, L2, float> {
  static constexpr int kDims = 128;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    // Each warp reduces a single 128-d vector; each lane loads a float4
    float* vecs = (float*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    float4 queryVal = *(float4*) &query[laneId * 4];

    constexpr int kUnroll = 4;
    float4 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        // Vector we are loading from is i
        // Dim we are loading from is laneId * 4
        vecVal[j] = *(float4*) &vecs[(i + j * numWarps) * kDims + laneId * 4];
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryVal, vecVal[j]);
        } else {
          dist[j] = ipDistance(queryVal, vecVal[j]);
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = *(float4*) &vecs[i * kDims + laneId * 4];
      float dist;
      if (L2) {
        dist = l2Distance(queryVal, vecVal[0]);
      } else {
        dist = ipDistance(queryVal, vecVal[0]);
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};

#ifdef FAISS_USE_FLOAT16

// 128-d float16 implementation
template <bool L2>
struct IVFFlatScan<128, L2, half> {
  static constexpr int kDims = 128;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    // Each warp reduces a single 128-d vector; each lane loads a Half4
    half* vecs = (half*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    float4 queryVal = *(float4*) &query[laneId * 4];

    constexpr int kUnroll = 4;

    Half4 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        // Vector we are loading from is i
        // Dim we are loading from is laneId * 4
        vecVal[j] =
          LoadStore<Half4>::load(&vecs[(i + j * numWarps) * kDims + laneId * 4]);
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryVal, half4ToFloat4(vecVal[j]));
        } else {
          dist[j] = ipDistance(queryVal, half4ToFloat4(vecVal[j]));
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = LoadStore<Half4>::load(&vecs[i * kDims + laneId * 4]);

      float dist;
      if (L2) {
        dist = l2Distance(queryVal, half4ToFloat4(vecVal[0]));
      } else {
        dist = ipDistance(queryVal, half4ToFloat4(vecVal[0]));
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};

#endif

// 256-d float32 implementation
template <bool L2>
struct IVFFlatScan<256, L2, float> {
  static constexpr int kDims = 256;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    // A specialization here to load per-warp seems to be worse, since
    // we're already running at near memory b/w peak
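    // Just defer to the general dims == blockDim.x implementation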
    IVFFlatScan<0, L2, float>::scan(query,
                                    vecData,
                                    numVecs,
                                    dim,
                                    distanceOut);
  }
};

#ifdef FAISS_USE_FLOAT16

// 256-d float16 implementation
template <bool L2>
struct IVFFlatScan<256, L2, half> {
  static constexpr int kDims = 256;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    // Each warp reduces a single 256-d vector; each lane loads a Half8
    half* vecs = (half*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    // This is not a contiguous load, but we only have to load these two
    // values, so that we can load by Half8 below
    float4 queryValA = *(float4*) &query[laneId * 8];
    float4 queryValB = *(float4*) &query[laneId * 8 + 4];

    constexpr int kUnroll = 4;

    Half8 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        // Vector we are loading from is i
        // Dim we are loading from is laneId * 8
        vecVal[j] =
          LoadStore<Half8>::load(&vecs[(i + j * numWarps) * kDims + laneId * 8]);
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryValA, half4ToFloat4(vecVal[j].a));
          dist[j] += l2Distance(queryValB, half4ToFloat4(vecVal[j].b));
        } else {
          dist[j] = ipDistance(queryValA, half4ToFloat4(vecVal[j].a));
          dist[j] += ipDistance(queryValB, half4ToFloat4(vecVal[j].b));
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = LoadStore<Half8>::load(&vecs[i * kDims + laneId * 8]);

      float dist;
      if (L2) {
        dist = l2Distance(queryValA, half4ToFloat4(vecVal[0].a));
        dist += l2Distance(queryValB, half4ToFloat4(vecVal[0].b));
      } else {
        dist = ipDistance(queryValA, half4ToFloat4(vecVal[0].a));
        dist += ipDistance(queryValB, half4ToFloat4(vecVal[0].b));
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};

#endif

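// One threadblock is launched per (query, probe) pair: blockIdx.y selects
// the query and blockIdx.x selects the probe within that query's set of
// coarse centroids. Each block scans one inverted list and writes its
// distances to a contiguous slice of `distance`.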
template <int Dims, bool L2, typename T>
__global__ void
ivfFlatScan(Tensor<float, 2, true> queries,
            Tensor<int, 2, true> listIds,
            void** allListData,
            int* listLengths,
            Tensor<int, 2, true> prefixSumOffsets,
            Tensor<float, 1, true> distance) {
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;

  // This is where we start writing out data
  // We ensure that before the array (at offset -1), there is a 0 value
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);

  auto listId = listIds[queryId][probeId];
  // Safety guard in case NaNs in input cause no list ID to be generated
  if (listId == -1) {
    return;
  }

  auto query = queries[queryId].data();
  auto vecs = allListData[listId];
  auto numVecs = listLengths[listId];
  auto dim = queries.getSize(1);
  auto distanceOut = distance[outBase].data();

  IVFFlatScan<Dims, L2, T>::scan(query, vecs, numVecs, dim, distanceOut);
}

void
runIVFFlatScanTile(Tensor<float, 2, true>& queries,
                   Tensor<int, 2, true>& listIds,
                   thrust::device_vector<void*>& listData,
                   thrust::device_vector<void*>& listIndices,
                   IndicesOptions indicesOptions,
                   thrust::device_vector<int>& listLengths,
                   Tensor<char, 1, true>& thrustMem,
                   Tensor<int, 2, true>& prefixSumOffsets,
                   Tensor<float, 1, true>& allDistances,
                   Tensor<float, 3, true>& heapDistances,
                   Tensor<int, 3, true>& heapIndices,
                   int k,
                   bool l2Distance,
                   bool useFloat16,
                   Tensor<float, 2, true>& outDistances,
                   Tensor<long, 2, true>& outIndices,
                   cudaStream_t stream) {
  // Calculate offset lengths, so we know where to write out
  // intermediate results
  runCalcListOffsets(listIds, listLengths, prefixSumOffsets, thrustMem, stream);

  // Calculate distances for vectors within our chunk of lists
  constexpr int kMaxThreadsIVF = 512;

  // FIXME: if `half` and # dims is multiple of 2, halve the
  // threadblock size

  int dim = queries.getSize(1);
  int numThreads = std::min(dim, kMaxThreadsIVF);

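  // Launch one block per (query, probe) pair: x spans the probes (nprobe),
  // y spans the queries in this tile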
  auto grid = dim3(listIds.getSize(1),
                   listIds.getSize(0));
  auto block = dim3(numThreads);
  // All exact dim kernels are unrolled by 4, hence the `4`
  auto smem = sizeof(float) * utils::divUp(numThreads, kWarpSize) * 4;

#define RUN_IVF_FLAT(DIMS, L2, T)               \
  do {                                          \
    ivfFlatScan<DIMS, L2, T>                    \
      <<<grid, block, smem, stream>>>(          \
        queries,                                \
        listIds,                                \
        listData.data().get(),                  \
        listLengths.data().get(),               \
        prefixSumOffsets,                       \
        allDistances);                          \
  } while (0)

#ifdef FAISS_USE_FLOAT16

#define HANDLE_DIM_CASE(DIMS)                   \
  do {                                          \
    if (l2Distance) {                           \
      if (useFloat16) {                         \
        RUN_IVF_FLAT(DIMS, true, half);         \
      } else {                                  \
        RUN_IVF_FLAT(DIMS, true, float);        \
      }                                         \
    } else {                                    \
      if (useFloat16) {                         \
        RUN_IVF_FLAT(DIMS, false, half);        \
      } else {                                  \
        RUN_IVF_FLAT(DIMS, false, float);       \
      }                                         \
    }                                           \
  } while (0)
#else

#define HANDLE_DIM_CASE(DIMS)                   \
  do {                                          \
    if (l2Distance) {                           \
      if (useFloat16) {                         \
        FAISS_ASSERT(false);                    \
      } else {                                  \
        RUN_IVF_FLAT(DIMS, true, float);        \
      }                                         \
    } else {                                    \
      if (useFloat16) {                         \
        FAISS_ASSERT(false);                    \
      } else {                                  \
        RUN_IVF_FLAT(DIMS, false, float);       \
      }                                         \
    }                                           \
  } while (0)

#endif // FAISS_USE_FLOAT16

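  // Dispatch on dimension: exact-dim specializations for 64/128/256, the
  // dims == blockDim.x kernel when each thread can own one dimension, and
  // the general fallback for anything larger than kMaxThreadsIVF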
  if (dim == 64) {
    HANDLE_DIM_CASE(64);
  } else if (dim == 128) {
    HANDLE_DIM_CASE(128);
  } else if (dim == 256) {
    HANDLE_DIM_CASE(256);
  } else if (dim <= kMaxThreadsIVF) {
    HANDLE_DIM_CASE(0);
  } else {
    HANDLE_DIM_CASE(-1);
  }

  CUDA_TEST_ERROR();

#undef HANDLE_DIM_CASE
#undef RUN_IVF_FLAT

  // k-select the output in chunks, to increase parallelism
  runPass1SelectLists(prefixSumOffsets,
                      allDistances,
                      listIds.getSize(1),
                      k,
                      !l2Distance, // L2 distance chooses smallest
                      heapDistances,
                      heapIndices,
                      stream);

  // k-select final output
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      listIds,
                      k,
                      !l2Distance, // L2 distance chooses smallest
                      outDistances,
                      outIndices,
                      stream);
}

void
runIVFFlatScan(Tensor<float, 2, true>& queries,
               Tensor<int, 2, true>& listIds,
               thrust::device_vector<void*>& listData,
               thrust::device_vector<void*>& listIndices,
               IndicesOptions indicesOptions,
               thrust::device_vector<int>& listLengths,
               int maxListLength,
               int k,
               bool l2Distance,
               bool useFloat16,
               // output
               Tensor<float, 2, true>& outDistances,
               // output
               Tensor<long, 2, true>& outIndices,
               GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;

  int nprobe = listIds.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  // Make a reservation for Thrust to do its dirty work (global memory
  // cross-block reduction space); hopefully this is large enough.
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};

  // How much temporary storage is available?
  // If possible, we'd like to fit within the space available.
  size_t sizeAvailable = mem.getSizeAvailable();

  // We run two passes of heap selection
  // This is the size of the first-level heap passes
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  // How much temporary storage we need per each query
  size_t sizePerQuery =
    2 * // # streams
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     sizeForFirstSelectPass);

  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // FIXME: we should adjust queryTileSize to deal with this, since
  // indexing is in int32
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <
               std::numeric_limits<int>::max());

  // Temporary memory buffers
  // Make sure there is space prior to the start which will be 0, and
  // will handle the boundary condition without branches
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};

  // Make sure the element before prefixSumOffsets is 0, since we
  // depend upon simple, boundary-less indexing to get proper results
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0,
                              sizeof(int),
                              stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0,
                              sizeof(int),
                              stream));

  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};

  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;

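  // Process the queries in tiles, double-buffered across the two alternate
  // streams: each tile uses its own set of temporaries so the next tile's
  // work can be queued while the previous tile is still executing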
  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto listIdsView =
      listIds.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runIVFFlatScanTile(queryView,
                       listIdsView,
                       listData,
                       listIndices,
                       indicesOptions,
                       listLengths,
                       *thrustMem[curStream],
                       prefixSumOffsetsView,
                       *allDistances[curStream],
                       heapDistancesView,
                       heapIndicesView,
                       k,
                       l2Distance,
                       useFloat16,
                       outDistanceView,
                       outIndicesView,
                       streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  streamWait({stream}, streams);
}

} } // namespace