IVFFlatScan.cu
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the CC-by-NC license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.

#include "IVFFlatScan.cuh"
#include "../GpuResources.h"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/DeviceTensor.cuh"
#include "../utils/Float16.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/PtxUtils.cuh"
#include "../utils/Reductions.cuh"
#include "../utils/StaticUtils.h"
#include <thrust/host_vector.h>

namespace faiss { namespace gpu {

template <typename T>
inline __device__ typename Math<T>::ScalarType l2Distance(T a, T b) {
  a = Math<T>::sub(a, b);
  a = Math<T>::mul(a, a);
  return Math<T>::reduceAdd(a);
}

template <typename T>
inline __device__ typename Math<T>::ScalarType ipDistance(T a, T b) {
  return Math<T>::reduceAdd(Math<T>::mul(a, b));
}
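
// Note: Math<T>::sub/mul operate elementwise and Math<T>::reduceAdd is a
// horizontal sum over the vector type's components, so e.g. calling
// l2Distance on a float2 pair yields (a.x - b.x)^2 + (a.y - b.y)^2 as a
// single scalar.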

// For list scanning, even if the input data is `half`, we perform all
// math in float32, because the code is memory b/w bound, and the
// added precision for accumulation is useful

/// The class that we use to provide scan specializations
template <int Dims, bool L2, typename T>
struct IVFFlatScan {
};
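
// The partial specializations below are selected on (Dims, T):
//  * Dims == -1: generic fallback for any dimension (strided over threads)
//  * Dims ==  0: assumes dim == blockDim.x (one thread per dimension)
//  * Dims == 64/128/256: vectorized warp-per-vector loads
//    (float2/float4/Half4/Half8) for those exact dimensions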

// Fallback implementation: works for any dimension size
template <bool L2, typename T>
struct IVFFlatScan<-1, L2, T> {
  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    extern __shared__ float smem[];
    T* vecs = (T*) vecData;

    for (int vec = 0; vec < numVecs; ++vec) {
      // Reduce in dist
      float dist = 0.0f;

      for (int d = threadIdx.x; d < dim; d += blockDim.x) {
        float vecVal = ConvertTo<float>::to(vecs[vec * dim + d]);
        float queryVal = query[d];
        float curDist;

        if (L2) {
          curDist = l2Distance(queryVal, vecVal);
        } else {
          curDist = ipDistance(queryVal, vecVal);
        }

        dist += curDist;
      }

      // Reduce distance within block
      dist = blockReduceAllSum<float, false, true>(dist, smem);

      if (threadIdx.x == 0) {
        distanceOut[vec] = dist;
      }
    }
  }
};

// Implementation: works for # dims == blockDim.x
template <bool L2, typename T>
struct IVFFlatScan<0, L2, T> {
  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    extern __shared__ float smem[];
    T* vecs = (T*) vecData;

    float queryVal = query[threadIdx.x];

    constexpr int kUnroll = 4;
    int limit = utils::roundDown(numVecs, kUnroll);

    for (int i = 0; i < limit; i += kUnroll) {
      float vecVal[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vecVal[j] = ConvertTo<float>::to(vecs[(i + j) * dim + threadIdx.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          vecVal[j] = l2Distance(queryVal, vecVal[j]);
        } else {
          vecVal[j] = ipDistance(queryVal, vecVal[j]);
        }
      }

      blockReduceAllSum<kUnroll, float, false, true>(vecVal, smem);

      if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j] = vecVal[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit; i < numVecs; ++i) {
      float vecVal = ConvertTo<float>::to(vecs[i * dim + threadIdx.x]);

      if (L2) {
        vecVal = l2Distance(queryVal, vecVal);
      } else {
        vecVal = ipDistance(queryVal, vecVal);
      }

      vecVal = blockReduceAllSum<float, false, true>(vecVal, smem);

      if (threadIdx.x == 0) {
        distanceOut[i] = vecVal;
      }
    }
  }
};
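
// In the specialization above, each thread owns exactly one dimension of the
// query (dim == blockDim.x), so one block-wide sum produces a full distance;
// unrolling by kUnroll = 4 lets a single blockReduceAllSum call reduce four
// vectors' partial sums at once, amortizing the reduction cost.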

// 64-d float32 implementation
template <bool L2>
struct IVFFlatScan<64, L2, float> {
  static constexpr int kDims = 64;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    // Each warp reduces a single 64-d vector; each lane loads a float2
    float* vecs = (float*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    float2 queryVal = *(float2*) &query[laneId * 2];

    constexpr int kUnroll = 4;
    float2 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        // Vector we are loading from is i
        // Dim we are loading from is laneId * 2
        vecVal[j] = *(float2*) &vecs[(i + j * numWarps) * kDims + laneId * 2];
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryVal, vecVal[j]);
        } else {
          dist[j] = ipDistance(queryVal, vecVal[j]);
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = *(float2*) &vecs[i * kDims + laneId * 2];
      float dist;
      if (L2) {
        dist = l2Distance(queryVal, vecVal[0]);
      } else {
        dist = ipDistance(queryVal, vecVal[0]);
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};
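
// Layout for the 64-d case: lane l of a warp holds query dims [2l, 2l + 1] as
// a float2, so 32 lanes cover all 64 dimensions of one database vector per
// iteration; warpReduceAllSum combines the 32 per-lane partial distances and
// lane 0 writes the result.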

#ifdef FAISS_USE_FLOAT16

// float16 implementation
template <bool L2>
struct IVFFlatScan<64, L2, half> {
  static constexpr int kDims = 64;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    // Each warp reduces a single 64-d vector; each lane loads a half2
    half* vecs = (half*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    float2 queryVal = *(float2*) &query[laneId * 2];

    constexpr int kUnroll = 4;

    half2 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        // Vector we are loading from is i
        // Dim we are loading from is laneId * 2
        vecVal[j] = *(half2*) &vecs[(i + j * numWarps) * kDims + laneId * 2];
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryVal, __half22float2(vecVal[j]));
        } else {
          dist[j] = ipDistance(queryVal, __half22float2(vecVal[j]));
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = *(half2*) &vecs[i * kDims + laneId * 2];

      float dist;
      if (L2) {
        dist = l2Distance(queryVal, __half22float2(vecVal[0]));
      } else {
        dist = ipDistance(queryVal, __half22float2(vecVal[0]));
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};

#endif

// 128-d float32 implementation
template <bool L2>
struct IVFFlatScan<128, L2, float> {
  static constexpr int kDims = 128;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    // Each warp reduces a single 128-d vector; each lane loads a float4
    float* vecs = (float*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    float4 queryVal = *(float4*) &query[laneId * 4];

    constexpr int kUnroll = 4;
    float4 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        // Vector we are loading from is i
        // Dim we are loading from is laneId * 4
        vecVal[j] = *(float4*) &vecs[(i + j * numWarps) * kDims + laneId * 4];
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryVal, vecVal[j]);
        } else {
          dist[j] = ipDistance(queryVal, vecVal[j]);
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = *(float4*) &vecs[i * kDims + laneId * 4];
      float dist;
      if (L2) {
        dist = l2Distance(queryVal, vecVal[0]);
      } else {
        dist = ipDistance(queryVal, vecVal[0]);
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};
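
// Same warp-per-vector layout as the 64-d case, but lane l now holds query
// dims [4l, 4l + 3] as a float4 (32 lanes x 4 floats = 128 dims).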

#ifdef FAISS_USE_FLOAT16

// float16 implementation
template <bool L2>
struct IVFFlatScan<128, L2, half> {
  static constexpr int kDims = 128;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    // Each warp reduces a single 128-d vector; each lane loads a Half4
    half* vecs = (half*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    float4 queryVal = *(float4*) &query[laneId * 4];

    constexpr int kUnroll = 4;

    Half4 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        // Vector we are loading from is i
        // Dim we are loading from is laneId * 4
        vecVal[j] =
          LoadStore<Half4>::load(&vecs[(i + j * numWarps) * kDims + laneId * 4]);
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryVal, half4ToFloat4(vecVal[j]));
        } else {
          dist[j] = ipDistance(queryVal, half4ToFloat4(vecVal[j]));
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = LoadStore<Half4>::load(&vecs[i * kDims + laneId * 4]);

      float dist;
      if (L2) {
        dist = l2Distance(queryVal, half4ToFloat4(vecVal[0]));
      } else {
        dist = ipDistance(queryVal, half4ToFloat4(vecVal[0]));
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};

#endif

// 256-d float32 implementation
template <bool L2>
struct IVFFlatScan<256, L2, float> {
  static constexpr int kDims = 256;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    // A specialization here to load per-warp seems to be worse, since
    // we're already running at near memory b/w peak
    IVFFlatScan<0, L2, float>::scan(query,
                                    vecData,
                                    numVecs,
                                    dim,
                                    distanceOut);
  }
};
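
// The <0, L2, float> kernel this delegates to assumes dim == blockDim.x,
// which holds here: for dim == 256 the launch below uses
// min(dim, kMaxThreadsIVF) == 256 threads per block. Only the float16 variant
// below gets a dedicated per-warp Half8 specialization.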

#ifdef FAISS_USE_FLOAT16

// float16 implementation
template <bool L2>
struct IVFFlatScan<256, L2, half> {
  static constexpr int kDims = 256;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    // Each warp reduces a single 256-d vector; each lane loads a Half8
    half* vecs = (half*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    // This is not a contiguous load, but we only have to load these two
    // values, so that we can load by Half8 below
    float4 queryValA = *(float4*) &query[laneId * 8];
    float4 queryValB = *(float4*) &query[laneId * 8 + 4];

    constexpr int kUnroll = 4;

    Half8 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        // Vector we are loading from is i
        // Dim we are loading from is laneId * 8
        vecVal[j] =
          LoadStore<Half8>::load(&vecs[(i + j * numWarps) * kDims + laneId * 8]);
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryValA, half4ToFloat4(vecVal[j].a));
          dist[j] += l2Distance(queryValB, half4ToFloat4(vecVal[j].b));
        } else {
          dist[j] = ipDistance(queryValA, half4ToFloat4(vecVal[j].a));
          dist[j] += ipDistance(queryValB, half4ToFloat4(vecVal[j].b));
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = LoadStore<Half8>::load(&vecs[i * kDims + laneId * 8]);

      float dist;
      if (L2) {
        dist = l2Distance(queryValA, half4ToFloat4(vecVal[0].a));
        dist += l2Distance(queryValB, half4ToFloat4(vecVal[0].b));
      } else {
        dist = ipDistance(queryValA, half4ToFloat4(vecVal[0].a));
        dist += ipDistance(queryValB, half4ToFloat4(vecVal[0].b));
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};

#endif

template <int Dims, bool L2, typename T>
__global__ void
ivfFlatScan(Tensor<float, 2, true> queries,
            Tensor<int, 2, true> listIds,
            void** allListData,
            int* listLengths,
            Tensor<int, 2, true> prefixSumOffsets,
            Tensor<float, 1, true> distance) {
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;

  // This is where we start writing out data
  // We ensure that before the array (at offset -1), there is a 0 value
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);

  auto listId = listIds[queryId][probeId];
  // Safety guard in case NaNs in input cause no list ID to be generated
  if (listId == -1) {
    return;
  }

  auto query = queries[queryId].data();
  auto vecs = allListData[listId];
  auto numVecs = listLengths[listId];
  auto dim = queries.getSize(1);
  auto distanceOut = distance[outBase].data();

  IVFFlatScan<Dims, L2, T>::scan(query, vecs, numVecs, dim, distanceOut);
}
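
// Launch geometry (set up in runIVFFlatScanTile below): blockIdx.x indexes
// the probe and blockIdx.y the query, so each block scans one (query, probe)
// inverted list. outBase is the exclusive prefix sum of list lengths for this
// (query, probe) pair, i.e. where this list's distances start within the flat
// `distance` buffer.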

void
runIVFFlatScanTile(Tensor<float, 2, true>& queries,
                   Tensor<int, 2, true>& listIds,
                   thrust::device_vector<void*>& listData,
                   thrust::device_vector<void*>& listIndices,
                   IndicesOptions indicesOptions,
                   thrust::device_vector<int>& listLengths,
                   Tensor<char, 1, true>& thrustMem,
                   Tensor<int, 2, true>& prefixSumOffsets,
                   Tensor<float, 1, true>& allDistances,
                   Tensor<float, 3, true>& heapDistances,
                   Tensor<int, 3, true>& heapIndices,
                   int k,
                   bool l2Distance,
                   bool useFloat16,
                   Tensor<float, 2, true>& outDistances,
                   Tensor<long, 2, true>& outIndices,
                   cudaStream_t stream) {
  // Calculate offset lengths, so we know where to write out
  // intermediate results
  runCalcListOffsets(listIds, listLengths, prefixSumOffsets, thrustMem, stream);

  // Calculate distances for vectors within our chunk of lists
  constexpr int kMaxThreadsIVF = 512;

  // FIXME: if `half` and # dims is multiple of 2, halve the
  // threadblock size

  int dim = queries.getSize(1);
  int numThreads = std::min(dim, kMaxThreadsIVF);

  auto grid = dim3(listIds.getSize(1),
                   listIds.getSize(0));
  auto block = dim3(numThreads);
  // All exact dim kernels are unrolled by 4, hence the `4`
  auto smem = sizeof(float) * utils::divUp(numThreads, kWarpSize) * 4;
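
  // The shared memory sizing mirrors what the block-wide reductions consume:
  // one float per warp for each of the up to 4 values reduced at once (the
  // unroll factor noted above). For example, with the maximum of 512 threads
  // that is 512 / 32 = 16 warps, i.e. 16 * 4 floats = 256 bytes of dynamic
  // shared memory per block.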

#define RUN_IVF_FLAT(DIMS, L2, T)               \
  do {                                          \
    ivfFlatScan<DIMS, L2, T>                    \
      <<<grid, block, smem, stream>>>(          \
        queries,                                \
        listIds,                                \
        listData.data().get(),                  \
        listLengths.data().get(),               \
        prefixSumOffsets,                       \
        allDistances);                          \
  } while (0)

#ifdef FAISS_USE_FLOAT16

#define HANDLE_DIM_CASE(DIMS)                   \
  do {                                          \
    if (l2Distance) {                           \
      if (useFloat16) {                         \
        RUN_IVF_FLAT(DIMS, true, half);         \
      } else {                                  \
        RUN_IVF_FLAT(DIMS, true, float);        \
      }                                         \
    } else {                                    \
      if (useFloat16) {                         \
        RUN_IVF_FLAT(DIMS, false, half);        \
      } else {                                  \
        RUN_IVF_FLAT(DIMS, false, float);       \
      }                                         \
    }                                           \
  } while (0)
#else

#define HANDLE_DIM_CASE(DIMS)                   \
  do {                                          \
    if (l2Distance) {                           \
      if (useFloat16) {                         \
        FAISS_ASSERT(false);                    \
      } else {                                  \
        RUN_IVF_FLAT(DIMS, true, float);        \
      }                                         \
    } else {                                    \
      if (useFloat16) {                         \
        FAISS_ASSERT(false);                    \
      } else {                                  \
        RUN_IVF_FLAT(DIMS, false, float);       \
      }                                         \
    }                                           \
  } while (0)

#endif // FAISS_USE_FLOAT16

  if (dim == 64) {
    HANDLE_DIM_CASE(64);
  } else if (dim == 128) {
    HANDLE_DIM_CASE(128);
  } else if (dim == 256) {
    HANDLE_DIM_CASE(256);
  } else if (dim <= kMaxThreadsIVF) {
    HANDLE_DIM_CASE(0);
  } else {
    HANDLE_DIM_CASE(-1);
  }
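
  // Dispatch summary: dims of exactly 64/128/256 use the vectorized
  // warp-per-vector specializations; any other dim up to kMaxThreadsIVF (512)
  // uses the one-thread-per-dimension kernel (Dims == 0); larger dims fall
  // back to the generic kernel (Dims == -1), which strides each vector's
  // dimensions across blockDim.x threads.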

#undef HANDLE_DIM_CASE
#undef RUN_IVF_FLAT

  // k-select the output in chunks, to increase parallelism
  runPass1SelectLists(prefixSumOffsets,
                      allDistances,
                      listIds.getSize(1),
                      k,
                      !l2Distance, // L2 distance chooses smallest
                      heapDistances,
                      heapIndices,
                      stream);

  // k-select final output
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      listIds,
                      k,
                      !l2Distance, // L2 distance chooses smallest
                      outDistances,
                      outIndices,
                      stream);

  CUDA_VERIFY(cudaGetLastError());
}

void
runIVFFlatScan(Tensor<float, 2, true>& queries,
               Tensor<int, 2, true>& listIds,
               thrust::device_vector<void*>& listData,
               thrust::device_vector<void*>& listIndices,
               IndicesOptions indicesOptions,
               thrust::device_vector<int>& listLengths,
               int maxListLength,
               int k,
               bool l2Distance,
               bool useFloat16,
               // output
               Tensor<float, 2, true>& outDistances,
               // output
               Tensor<long, 2, true>& outIndices,
               GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;

  int nprobe = listIds.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  // Make a reservation for Thrust to do its dirty work (global memory
  // cross-block reduction space); hopefully this is large enough.
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};

  // How much temporary storage is available?
  // If possible, we'd like to fit within the space available.
  size_t sizeAvailable = mem.getSizeAvailable();

  // We run two passes of heap selection
  // This is the size of the first-level heap passes
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  // How much temporary storage we need per each query
  size_t sizePerQuery =
    2 * // # streams
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     sizeForFirstSelectPass);
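
  // Illustrative sizing (assumed values, not from the code): with nprobe = 32,
  // maxListLength = 10000, k = 10 and pass2Chunks = 8, sizePerQuery is roughly
  // 2 * (132 + 1,280,000 + 640) bytes, i.e. about 2.5 MB per query, so 256 MB
  // of temporary space would give a tile of roughly 100 queries, which already
  // lies within the [kMinQueryTileSize, kMaxQueryTileSize] = [8, 128] clamp
  // applied below.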

  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // FIXME: we should adjust queryTileSize to deal with this, since
  // indexing is in int32
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <
               std::numeric_limits<int>::max());

  // Temporary memory buffers
  // Make sure there is space prior to the start which will be 0, and
  // will handle the boundary condition without branches
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};

  // Make sure the element before prefixSumOffsets is 0, since we
  // depend upon simple, boundary-less indexing to get proper results
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0,
                              sizeof(int),
                              stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0,
                              sizeof(int),
                              stream));

  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};

  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;

  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto listIdsView =
      listIds.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runIVFFlatScanTile(queryView,
                       listIdsView,
                       listData,
                       listIndices,
                       indicesOptions,
                       listLengths,
                       *thrustMem[curStream],
                       prefixSumOffsetsView,
                       *allDistances[curStream],
                       heapDistancesView,
                       heapIndicesView,
                       k,
                       l2Distance,
                       useFloat16,
                       outDistanceView,
                       outIndicesView,
                       streams[curStream]);

    curStream = (curStream + 1) % 2;
  }
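
  // Query tiles alternate between the two alternate streams, each with its
  // own set of temporary buffers above, so work for tile i+1 can be queued
  // while tile i is still executing. The final streamWait makes the ordered
  // stream wait on both alternate streams, so subsequent work sees all tile
  // results.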

  streamWait({stream}, streams);
}

} } // namespace