Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
IVFUtilsSelect2.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 
11 #include "IVFUtils.cuh"
12 #include "../utils/DeviceUtils.h"
13 #include "../utils/Limits.cuh"
14 #include "../utils/Select.cuh"
15 #include "../utils/StaticUtils.h"
16 #include "../utils/Tensor.cuh"
17 
18 //
19 // This kernel is split into a separate compilation unit to cut down
20 // on compile time
21 //
22 
23 namespace faiss { namespace gpu {
24 
// Locates the bucket that contains position `val`.
//
// Bucket b is taken to cover the half-open range
// [prefixSumOffsets[b - 1], prefixSumOffsets[b]), so the result is the index
// of the FIRST entry whose inclusive prefix end is strictly greater than
// `val`. Empty buckets (equal consecutive prefix values) are skipped over.
//
// Params:
//   prefixSumOffsets - `size` inclusive prefix sums; assumed non-decreasing
//                      (binary-search precondition). Only read, never written.
//   size             - number of buckets / entries in prefixSumOffsets
//   val              - the position to locate
//
// Returns the bucket index in [0, size). `val` must be below the last prefix
// value; otherwise the assert fires (a device-side trap when asserts are
// enabled).
//
// Threads of a warp may take different numbers of iterations here (warp
// divergence), but this only runs in the final step, a small number of times.
// Callable from host code as well, which also allows host-side unit testing.
inline __host__ __device__ int binarySearchForBucket(const int* prefixSumOffsets,
                                                     int size,
                                                     int val) {
  int start = 0;
  int end = size;

  while (end - start > 0) {
    // Overflow-safe midpoint.
    int mid = start + (end - start) / 2;

    if (prefixSumOffsets[mid] <= val) {
      // Bucket `mid` ends at or before `val`; the answer lies to the right.
      start = mid + 1;
    } else {
      // Bucket `mid` ends after `val`; it is a candidate, search the left.
      end = mid;
    }
  }

  // The position must fall inside some bucket.
  assert(start != size);

  return start;
}
51 
// Second-pass k-selection for IVF list scanning.
//
// Each thread block handles one query (blockIdx.x). It feeds every candidate
// distance in heapDistances[query] into a BlockSelect, writes the k best
// distances to outDistances[query], and translates each winner back into a
// user index written to outIndices[query].
//
// Launch layout, as used by runPass2SelectLists: gridDim.x == number of
// queries, blockDim.x expected to equal ThreadsPerBlock. Shared memory is
// static: kNumWarps * NumWarpQ (float, int) pairs handed to BlockSelect.
//
// Template parameters:
//   ThreadsPerBlock       - block size used to size the BlockSelect storage
//   NumWarpQ, NumThreadQ  - BlockSelect queue sizes (the warp queue must hold k)
//   Dir                   - true keeps the largest distances, false the smallest
//
// Index translation uses:
//   heapIndices        - per-candidate offset from the first scan pass
//   prefixSumOffsets   - per-query prefix sums over probed list lengths; the
//                        element just before the whole array must hold 0
//   topQueryToCentroid - probe number -> list id for each query
//   listIndices, opt   - per-list user-index storage and how to read it
template <int ThreadsPerBlock,
          int NumWarpQ,
          int NumThreadQ,
          bool Dir>
__global__ void
pass2SelectLists(Tensor<float, 2, true> heapDistances,
                 Tensor<int, 2, true> heapIndices,
                 void** listIndices,
                 Tensor<int, 2, true> prefixSumOffsets,
                 Tensor<int, 2, true> topQueryToCentroid,
                 int k,
                 IndicesOptions opt,
                 Tensor<float, 2, true> outDistances,
                 Tensor<long, 2, true> outIndices) {
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __shared__ float smemK[kNumWarps * NumWarpQ];
  __shared__ int smemV[kNumWarps * NumWarpQ];

  // Worst possible key for the chosen direction; -1 marks "no candidate".
  constexpr auto kSentinel = Dir ? kFloatMin : kFloatMax;
  BlockSelect<float, int, Dir, Comparator<float>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    selector(kSentinel, -1, smemK, smemV, k);

  auto query = blockIdx.x;
  auto candidateDistances = heapDistances[query];

  int numCandidates = heapDistances.getSize(1);
  // Positions below this bound are visited by whole warps, so the
  // warp-cooperative BlockSelect::add is never reached by a divergent warp.
  int warpAlignedEnd = utils::roundDown(numCandidates, kWarpSize);

  int pos = threadIdx.x;
  while (pos < warpAlignedEnd) {
    selector.add(candidateDistances[pos], pos);
    pos += blockDim.x;
  }

  // The tail [warpAlignedEnd, numCandidates) is shorter than one warp; the
  // thread whose next position falls there inserts it into its thread queue.
  if (pos < numCandidates) {
    selector.addThreadQ(candidateDistances[pos], pos);
  }

  // Merge all per-thread / per-warp queues into the final block result.
  selector.reduce();

  for (int slot = threadIdx.x; slot < k; slot += blockDim.x) {
    outDistances[query][slot] = smemK[slot];

    // `heapSlot` is a position in heapIndices[query]. Intermediate results are
    // kept as offsets rather than user indices to save temporary memory and
    // global write traffic during list scanning, so they are translated here.
    // This is divergent, but it runs only #queries x k times.
    int heapSlot = smemV[slot];
    if (heapSlot == -1) {
      // Slot still holds the sentinel given to BlockSelect.
      outIndices[query][slot] = -1;
      continue;
    }

    // Offset of the intermediate result, as produced by the first scan.
    int scanOffset = heapIndices[query][heapSlot];

    // Find which probed list contains this offset, then its list id.
    int* queryPrefix = prefixSumOffsets[query].data();
    int probe = binarySearchForBucket(queryPrefix,
                                      prefixSumOffsets.getSize(1),
                                      scanOffset);
    int listId = topQueryToCentroid[query][probe];

    // The entry before `probe` is where that list starts (for the very first
    // entry this reads the 0 stored just before the array).
    int listOffset = scanOffset - queryPrefix[probe - 1];

    long userIndex;
    switch (opt) {
      case INDICES_32_BIT:
        userIndex = (long) ((int*) listIndices[listId])[listOffset];
        break;
      case INDICES_64_BIT:
        userIndex = ((long*) listIndices[listId])[listOffset];
        break;
      default:
        // No stored indices: encode (list id, offset within list).
        userIndex = ((long) listId << 32 | (long) listOffset);
        break;
    }

    outIndices[query][slot] = userIndex;
  }
}
146 
// Host-side launcher for pass2SelectLists on `stream`.
//
// Launches one block of kThreadsPerBlock (128) threads per query, with
// grid.x == topQueryToCentroid.getSize(0). The BlockSelect queue sizes are
// chosen from `k`:
//   - k == 1 uses a 1-entry warp queue;
//   - otherwise the smallest warp queue in {32, 64, 128, 256, 512, 1024}
//     with k <= size is used.
// `chooseLargest` selects the largest distances instead of the smallest.
//
// Errors:
//   - any k greater than 1024 reaches FAISS_ASSERT_FMT ("unimplemented k
//     value");
//   - launch errors are checked with CUDA_TEST_ERROR.
//
// An empty query batch launches nothing: a grid dimension of 0 is an invalid
// CUDA launch configuration, and there is no output to write. The k dispatch
// still runs, so unsupported k values are rejected regardless of batch size.
void
runPass2SelectLists(Tensor<float, 2, true>& heapDistances,
                    Tensor<int, 2, true>& heapIndices,
                    thrust::device_vector<void*>& listIndices,
                    IndicesOptions indicesOptions,
                    Tensor<int, 2, true>& prefixSumOffsets,
                    Tensor<int, 2, true>& topQueryToCentroid,
                    int k,
                    bool chooseLargest,
                    Tensor<float, 2, true>& outDistances,
                    Tensor<long, 2, true>& outIndices,
                    cudaStream_t stream) {
  constexpr auto kThreadsPerBlock = 128;

  auto grid = dim3(topQueryToCentroid.getSize(0));
  auto block = dim3(kThreadsPerBlock);

  /* Launches only when there is at least one query; returns either way. */
#define RUN_PASS(NUM_WARP_Q, NUM_THREAD_Q, DIR)                         \
  do {                                                                  \
    if (grid.x > 0) {                                                   \
      pass2SelectLists<kThreadsPerBlock,                                \
                       NUM_WARP_Q, NUM_THREAD_Q, DIR>                   \
        <<<grid, block, 0, stream>>>(heapDistances,                     \
                                     heapIndices,                       \
                                     listIndices.data().get(),          \
                                     prefixSumOffsets,                  \
                                     topQueryToCentroid,                \
                                     k,                                 \
                                     indicesOptions,                    \
                                     outDistances,                      \
                                     outIndices);                       \
      CUDA_TEST_ERROR();                                                \
    }                                                                   \
    return; /* success */                                               \
  } while (0)

#define RUN_PASS_DIR(DIR)                                               \
  do {                                                                  \
    if (k == 1) {                                                       \
      RUN_PASS(1, 1, DIR);                                              \
    } else if (k <= 32) {                                               \
      RUN_PASS(32, 2, DIR);                                             \
    } else if (k <= 64) {                                               \
      RUN_PASS(64, 3, DIR);                                             \
    } else if (k <= 128) {                                              \
      RUN_PASS(128, 3, DIR);                                            \
    } else if (k <= 256) {                                              \
      RUN_PASS(256, 4, DIR);                                            \
    } else if (k <= 512) {                                              \
      RUN_PASS(512, 8, DIR);                                            \
    } else if (k <= 1024) {                                             \
      RUN_PASS(1024, 8, DIR);                                           \
    }                                                                   \
  } while (0)

  if (chooseLargest) {
    RUN_PASS_DIR(true);
  } else {
    RUN_PASS_DIR(false);
  }

  // Only reached when no queue configuration handles `k`
  // (unimplemented / too many resources).
  FAISS_ASSERT_FMT(false, "unimplemented k value (%d)", k);

#undef RUN_PASS_DIR
#undef RUN_PASS
}
212 
213 } } // namespace