Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
IVFUtilsSelect1.cu
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 
9 #include "IVFUtils.cuh"
10 #include "../utils/DeviceDefs.cuh"
11 #include "../utils/DeviceUtils.h"
12 #include "../utils/Limits.cuh"
13 #include "../utils/Select.cuh"
14 #include "../utils/StaticUtils.h"
15 #include "../utils/Tensor.cuh"
16 
17 //
18 // This kernel is split into a separate compilation unit to cut down
19 // on compile time
20 //
21 
22 namespace faiss { namespace gpu {
23 
// First-pass k-selection kernel: each thread block k-selects over the
// distances that belong to one slice of one query's probed lists.
//
// Launch layout, as read by the kernel:
//   blockIdx.y -> query index
//   blockIdx.x -> slice index; gridDim.x is the number of slices that the
//                 nprobe lists are divided into (the last slice additionally
//                 takes whatever the integer division leaves over)
//
// Static shared memory: kNumWarps * NumWarpQ floats plus the same number of
// ints, where kNumWarps = ThreadsPerBlock / kWarpSize.
//
// Template parameters:
//   ThreadsPerBlock, NumWarpQ, NumThreadQ -> forwarded to BlockSelect
//   Dir -> forwarded to BlockSelect; also picks the initial key
//          (kFloatMin when true, kFloatMax when false)
//
// The k selected (distance, index) pairs are written to
// heapDistances[query][slice][0..k) and heapIndices[query][slice][0..k).
template <int ThreadsPerBlock, int NumWarpQ, int NumThreadQ, bool Dir>
__global__ void
pass1SelectLists(Tensor<int, 2, true> prefixSumOffsets,
                 Tensor<float, 1, true> distance,
                 int nprobe,
                 int k,
                 Tensor<float, 3, true> heapDistances,
                 Tensor<int, 3, true> heapIndices) {
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  // Backing storage handed to BlockSelect; the results are read back from
  // the front of these arrays after reduce()
  __shared__ float sharedKeys[kNumWarps * NumWarpQ];
  __shared__ int sharedVals[kNumWarps * NumWarpQ];

  constexpr auto kSentinel = Dir ? kFloatMin : kFloatMax;
  BlockSelect<float, int, Dir, Comparator<float>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(kSentinel, -1, sharedKeys, sharedVals, k);

  const auto query = blockIdx.y;
  const auto slice = blockIdx.x;
  const auto sliceCount = gridDim.x;

  // Range of probed lists [firstList, lastList) covered by this block
  const int listsPerSlice = (nprobe / sliceCount);
  const int firstList = listsPerSlice * slice;
  int lastList = firstList + listsPerSlice;
  if (slice == (sliceCount - 1)) {
    // The final slice picks up the remainder lists
    lastList = nprobe;
  }

  auto listOffsets = prefixSumOffsets[query].data();

  // Reads one element before the first list of the slice; it is ensured
  // elsewhere that a 0 value sits before the array (at offset -1)
  const int begin = *(listOffsets + firstList - 1);
  const int finish = listOffsets[lastList - 1];

  const int count = finish - begin;
  const int warpAligned = utils::roundDown(count, kWarpSize);

  auto sliceDistances = distance[begin].data();
  int pos = threadIdx.x;

  // BlockSelect::add cannot be used when a warp is divergent, so only the
  // warp-aligned prefix goes through it; the tail is handled afterwards
  while (pos < warpAligned) {
    heap.add(sliceDistances[pos], begin + pos);
    pos += blockDim.x;
  }

  // Remainder elements (fewer than a full warp) are inserted per thread
  if (pos < count) {
    heap.addThreadQ(sliceDistances[pos], begin + pos);
  }

  // Merge all per-thread / per-warp results into the final selection
  heap.reduce();

  // The k selected entries are contiguous at the front of shared memory;
  // copy them out to this (query, slice) row
  for (int out = threadIdx.x; out < k; out += blockDim.x) {
    heapDistances[query][slice][out] = sharedKeys[out];
    heapIndices[query][slice][out] = sharedVals[out];
  }
}
83 
// Host-side launcher for pass1SelectLists.
//
// Picks the kernel instantiation (block size, warp queue length, thread
// queue length) from k, and the selection direction from chooseLargest,
// then launches on `stream` with
//   grid = (heapDistances.getSize(1), prefixSumOffsets.getSize(0))
// i.e. gridDim.x is the number of slices and gridDim.y the number of
// queries as consumed by the kernel.
//
// Parameters:
//   prefixSumOffsets, distance -> kernel inputs, passed through unchanged
//   nprobe                     -> number of probed lists per query
//   k                          -> number of results to select; must be
//                                 <= GPU_MAX_SELECTION_K
//   chooseLargest              -> true instantiates the kernel with
//                                 Dir = true, false with Dir = false
//   heapDistances, heapIndices -> kernel outputs
//   stream                     -> CUDA stream the kernel is launched on
//
// Asserts (FAISS_ASSERT) if k exceeds GPU_MAX_SELECTION_K or if no kernel
// instantiation covers k.
void
runPass1SelectLists(Tensor<int, 2, true>& prefixSumOffsets,
                    Tensor<float, 1, true>& distance,
                    int nprobe,
                    int k,
                    bool chooseLargest,
                    Tensor<float, 3, true>& heapDistances,
                    Tensor<int, 3, true>& heapIndices,
                    cudaStream_t stream) {
  // This is caught at a higher level
  FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);

  auto grid = dim3(heapDistances.getSize(1), prefixSumOffsets.getSize(0));

  // Launches one instantiation and returns from runPass1SelectLists
#define RUN_PASS(BLOCK, NUM_WARP_Q, NUM_THREAD_Q, DIR)                  \
  do {                                                                  \
    pass1SelectLists<BLOCK, NUM_WARP_Q, NUM_THREAD_Q, DIR>              \
      <<<grid, BLOCK, 0, stream>>>(prefixSumOffsets,                    \
                                   distance,                            \
                                   nprobe,                              \
                                   k,                                   \
                                   heapDistances,                       \
                                   heapIndices);                        \
    CUDA_TEST_ERROR();                                                  \
    return; /* success */                                               \
  } while (0)

#if GPU_MAX_SELECTION_K >= 2048

  // block size 128 for k <= 1024, 64 for k = 2048
#define RUN_PASS_DIR(DIR)                                \
  do {                                                   \
    if (k == 1) {                                        \
      RUN_PASS(128, 1, 1, DIR);                          \
    } else if (k <= 32) {                                \
      RUN_PASS(128, 32, 2, DIR);                         \
    } else if (k <= 64) {                                \
      RUN_PASS(128, 64, 3, DIR);                         \
    } else if (k <= 128) {                               \
      RUN_PASS(128, 128, 3, DIR);                        \
    } else if (k <= 256) {                               \
      RUN_PASS(128, 256, 4, DIR);                        \
    } else if (k <= 512) {                               \
      RUN_PASS(128, 512, 8, DIR);                        \
    } else if (k <= 1024) {                              \
      RUN_PASS(128, 1024, 8, DIR);                       \
    } else if (k <= 2048) {                              \
      RUN_PASS(64, 2048, 8, DIR);                        \
    }                                                    \
  } while (0)

#else

#define RUN_PASS_DIR(DIR)                                \
  do {                                                   \
    if (k == 1) {                                        \
      RUN_PASS(128, 1, 1, DIR);                          \
    } else if (k <= 32) {                                \
      RUN_PASS(128, 32, 2, DIR);                         \
    } else if (k <= 64) {                                \
      RUN_PASS(128, 64, 3, DIR);                         \
    } else if (k <= 128) {                               \
      RUN_PASS(128, 128, 3, DIR);                        \
    } else if (k <= 256) {                               \
      RUN_PASS(128, 256, 4, DIR);                        \
    } else if (k <= 512) {                               \
      RUN_PASS(128, 512, 8, DIR);                        \
    } else if (k <= 1024) {                              \
      RUN_PASS(128, 1024, 8, DIR);                       \
    }                                                    \
  } while (0)

#endif // GPU_MAX_SELECTION_K

  if (chooseLargest) {
    RUN_PASS_DIR(true);
  } else {
    RUN_PASS_DIR(false);
  }

#undef RUN_PASS_DIR
#undef RUN_PASS

  // Every branch that launches a kernel returns from inside RUN_PASS, so
  // control only reaches this point when no instantiation covered k (for
  // example, GPU_MAX_SELECTION_K was raised without extending the dispatch
  // table above). Fail loudly instead of returning with the output heaps
  // never written.
  FAISS_ASSERT(false);
}
167 
168 } } // namespace