// Faiss — IVFUtilsSelect2.cu
// (doxygen listing navigation header, commented out so the file compiles)
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 
9 #include "IVFUtils.cuh"
10 #include "../utils/DeviceDefs.cuh"
11 #include "../utils/DeviceUtils.h"
12 #include "../utils/Limits.cuh"
13 #include "../utils/Select.cuh"
14 #include "../utils/StaticUtils.h"
15 #include "../utils/Tensor.cuh"
16 
17 //
18 // This kernel is split into a separate compilation unit to cut down
19 // on compile time
20 //
21 
22 namespace faiss { namespace gpu {
23 
24 // This is warp divergence central, but this is really a final step
25 // and happening a small number of times
// Locate which bucket a flattened intermediate offset `val` belongs to,
// by binary search over the (exclusive) prefix-sum offsets array.
// Returns the index of the first bucket whose prefix sum exceeds `val`.
// This is heavily warp-divergent, but it only runs once per final
// result, so the cost is negligible.
inline __device__ int binarySearchForBucket(int* prefixSumOffsets,
                                            int size,
                                            int val) {
  int lo = 0;
  int hi = size;

  // Invariant: the answer lies in [lo, hi]; every bucket before `lo`
  // has prefixSumOffsets[...] <= val.
  while (lo < hi) {
    int probe = lo + (hi - lo) / 2;

    if (prefixSumOffsets[probe] <= val) {
      // Bucket `probe` ends at or before `val`; look to the right
      lo = probe + 1;
    } else {
      // Bucket `probe` ends after `val`; it remains a candidate
      hi = probe;
    }
  }

  // `val` must fall inside some bucket; reaching `size` would mean it
  // lies past the final prefix sum
  assert(lo != size);

  return lo;
}
50 
// Second pass of the two-pass IVF k-selection.
//
// Launch layout (see runPass2SelectLists): one block per query, with
// blockIdx.x == query id, ThreadsPerBlock threads per block. Uses only
// the statically allocated shared memory declared below (smemK/smemV).
//
// For each query, selects the best k of the intermediate results in
// heapDistances[queryId], then translates each winning slot back into a
// user-visible index via heapIndices, the per-query prefixSumOffsets and
// the coarse assignment topQueryToCentroid.
//
// Dir == true selects the largest distances, Dir == false the smallest
// (kInit below is the identity element for the chosen direction).
template <int ThreadsPerBlock,
          int NumWarpQ,
          int NumThreadQ,
          bool Dir>
__global__ void
pass2SelectLists(Tensor<float, 2, true> heapDistances,
                 Tensor<int, 2, true> heapIndices,
                 void** listIndices,
                 Tensor<int, 2, true> prefixSumOffsets,
                 Tensor<int, 2, true> topQueryToCentroid,
                 int k,
                 IndicesOptions opt,
                 Tensor<float, 2, true> outDistances,
                 Tensor<long, 2, true> outIndices) {
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  // Per-warp selection scratch for BlockSelect
  __shared__ float smemK[kNumWarps * NumWarpQ];
  __shared__ int smemV[kNumWarps * NumWarpQ];

  constexpr auto kInit = Dir ? kFloatMin : kFloatMax;
  BlockSelect<float, int, Dir, Comparator<float>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(kInit, -1, smemK, smemV, k);

  auto queryId = blockIdx.x;
  int num = heapDistances.getSize(1);
  // Largest multiple of kWarpSize <= num; the whole-warp loop below may
  // not be executed in a warp-divergent fashion
  int limit = utils::roundDown(num, kWarpSize);

  int i = threadIdx.x;
  auto heapDistanceStart = heapDistances[queryId];

  // BlockSelect add cannot be used in a warp divergent circumstance; we
  // handle the remainder warp below
  for (; i < limit; i += blockDim.x) {
    heap.add(heapDistanceStart[i], i);
  }

  // Handle warp divergence separately
  if (i < num) {
    heap.addThreadQ(heapDistanceStart[i], i);
  }

  // Merge all final results
  heap.reduce();

  // Write out the final top-k; the loop strides by blockDim.x so k may
  // exceed the block size
  for (int i = threadIdx.x; i < k; i += blockDim.x) {
    outDistances[queryId][i] = smemK[i];

    // `v` is the index in `heapIndices`
    // We need to translate this into an original user index. The
    // reason why we don't maintain intermediate results in terms of
    // user indices is to substantially reduce temporary memory
    // requirements and global memory write traffic for the list
    // scanning.
    // This code is highly divergent, but it's probably ok, since this
    // is the very last step and it is happening a small number of
    // times (#queries x k).
    int v = smemV[i];
    long index = -1;

    // v == -1 marks an unfilled selection slot (fewer than k candidates);
    // its output index stays -1
    if (v != -1) {
      // `offset` is the offset of the intermediate result, as
      // calculated by the original scan.
      int offset = heapIndices[queryId][v];

      // In order to determine the actual user index, we need to first
      // determine what list it was in.
      // We do this by binary search in the prefix sum list.
      int probe = binarySearchForBucket(prefixSumOffsets[queryId].data(),
                                        prefixSumOffsets.getSize(1),
                                        offset);

      // This is then the probe for the query; we can find the actual
      // list ID from this
      int listId = topQueryToCentroid[queryId][probe];

      // Now, we need to know the offset within the list
      // We ensure that before the array (at offset -1), there is a 0 value
      int listStart = *(prefixSumOffsets[queryId][probe].data() - 1);
      int listOffset = offset - listStart;

      // This gives us our final index
      if (opt == INDICES_32_BIT) {
        // User indices stored as 32-bit ints on the GPU; widen to long
        index = (long) ((int*) listIndices[listId])[listOffset];
      } else if (opt == INDICES_64_BIT) {
        // User indices stored as 64-bit on the GPU; read directly
        index = ((long*) listIndices[listId])[listOffset];
      } else {
        // Indices not stored on the GPU: pack (listId, listOffset) into
        // a single 64-bit value for later host-side translation
        index = ((long) listId << 32 | (long) listOffset);
      }
    }

    outIndices[queryId][i] = index;
  }
}
145 
// Host-side launcher for pass2SelectLists.
//
// Dispatches on `k` to a template instantiation with appropriately sized
// warp/thread selection queues, and on `chooseLargest` for the selection
// direction. Launches one block per query (grid.x == number of queries)
// on `stream`.
//
// Control-flow note: RUN_PASS launches the kernel, checks for launch
// errors, and *returns from this function* on success. Falling through
// every k case therefore reaches the FAISS_ASSERT_FMT at the bottom,
// which fires for unsupported (too large) k values.
void
runPass2SelectLists(Tensor<float, 2, true>& heapDistances,
                    Tensor<int, 2, true>& heapIndices,
                    thrust::device_vector<void*>& listIndices,
                    IndicesOptions indicesOptions,
                    Tensor<int, 2, true>& prefixSumOffsets,
                    Tensor<int, 2, true>& topQueryToCentroid,
                    int k,
                    bool chooseLargest,
                    Tensor<float, 2, true>& outDistances,
                    Tensor<long, 2, true>& outIndices,
                    cudaStream_t stream) {
  // One block per query
  auto grid = dim3(topQueryToCentroid.getSize(0));

// Launch a single instantiation; on success, return from the enclosing
// function (comments cannot live inside the macro due to the line
// continuations)
#define RUN_PASS(BLOCK, NUM_WARP_Q, NUM_THREAD_Q, DIR)                  \
  do {                                                                  \
    pass2SelectLists<BLOCK, NUM_WARP_Q, NUM_THREAD_Q, DIR>              \
      <<<grid, BLOCK, 0, stream>>>(heapDistances,                       \
                                   heapIndices,                         \
                                   listIndices.data().get(),            \
                                   prefixSumOffsets,                    \
                                   topQueryToCentroid,                  \
                                   k,                                   \
                                   indicesOptions,                      \
                                   outDistances,                        \
                                   outIndices);                         \
    CUDA_TEST_ERROR();                                                  \
    return; /* success */                                               \
  } while (0)

#if GPU_MAX_SELECTION_K >= 2048

  // block size 128 for k <= 1024, 64 for k = 2048
#define RUN_PASS_DIR(DIR)                                               \
  do {                                                                  \
    if (k == 1) {                                                       \
      RUN_PASS(128, 1, 1, DIR);                                         \
    } else if (k <= 32) {                                               \
      RUN_PASS(128, 32, 2, DIR);                                        \
    } else if (k <= 64) {                                               \
      RUN_PASS(128, 64, 3, DIR);                                        \
    } else if (k <= 128) {                                              \
      RUN_PASS(128, 128, 3, DIR);                                       \
    } else if (k <= 256) {                                              \
      RUN_PASS(128, 256, 4, DIR);                                       \
    } else if (k <= 512) {                                              \
      RUN_PASS(128, 512, 8, DIR);                                       \
    } else if (k <= 1024) {                                             \
      RUN_PASS(128, 1024, 8, DIR);                                      \
    } else if (k <= 2048) {                                             \
      RUN_PASS(64, 2048, 8, DIR);                                       \
    }                                                                   \
  } while (0)

#else

// Same dispatch without the k == 2048 tier when the build caps
// selection at k <= 1024
#define RUN_PASS_DIR(DIR)                                               \
  do {                                                                  \
    if (k == 1) {                                                       \
      RUN_PASS(128, 1, 1, DIR);                                         \
    } else if (k <= 32) {                                               \
      RUN_PASS(128, 32, 2, DIR);                                        \
    } else if (k <= 64) {                                               \
      RUN_PASS(128, 64, 3, DIR);                                        \
    } else if (k <= 128) {                                              \
      RUN_PASS(128, 128, 3, DIR);                                       \
    } else if (k <= 256) {                                              \
      RUN_PASS(128, 256, 4, DIR);                                       \
    } else if (k <= 512) {                                              \
      RUN_PASS(128, 512, 8, DIR);                                       \
    } else if (k <= 1024) {                                             \
      RUN_PASS(128, 1024, 8, DIR);                                      \
    }                                                                   \
  } while (0)

#endif // GPU_MAX_SELECTION_K

  if (chooseLargest) {
    RUN_PASS_DIR(true);
  } else {
    RUN_PASS_DIR(false);
  }

  // unimplemented / too many resources
  FAISS_ASSERT_FMT(false, "unimplemented k value (%d)", k);

#undef RUN_PASS_DIR
#undef RUN_PASS
}
235 
236 } } // namespace