Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
IVFUtilsSelect2.cu
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved.
11 
12 #include "IVFUtils.cuh"
13 #include "../utils/DeviceUtils.h"
14 #include "../utils/Select.cuh"
15 #include "../utils/StaticUtils.h"
16 #include "../utils/Tensor.cuh"
17 #include <limits>
18 
19 //
20 // This kernel is split into a separate compilation unit to cut down
21 // on compile time
22 //
23 
24 namespace faiss { namespace gpu {
25 
// Sentinel "worst" values used to initialize the k-selection queues:
// kMax when selecting the smallest distances (Dir == false) and kMin when
// selecting the largest (Dir == true).
//
// kMin must be the most negative finite float. std::numeric_limits<float>::min()
// is the smallest *positive* normalized float (~1.18e-38), so using it here
// would make every negative score (e.g. a negative inner product) compare as
// worse than an empty queue slot, presumably preventing it from ever being
// selected.
constexpr auto kMax = std::numeric_limits<float>::max();
constexpr auto kMin = std::numeric_limits<float>::lowest();
28 
// Maps a candidate offset `val` back to the bucket (probe) that produced it.
//
// `prefixSumOffsets[0, size)` is a prefix-sum array (assumed non-decreasing).
// Returns the index of the first entry strictly greater than `val`, i.e. the
// bucket whose half-open range [prefix[b - 1], prefix[b]) contains `val`.
// Asserts that such a bucket exists.
//
// Lanes may take different numbers of iterations, so this diverges within a
// warp; that is tolerated because it only runs as a final step, a small
// number of times.
inline __device__ int binarySearchForBucket(int* prefixSumOffsets,
                                            int size,
                                            int val) {
  int lo = 0;
  int hi = size;

  // Invariant (given a non-decreasing array): entries before `lo` are
  // <= val, entries at or after `hi` are > val
  while (lo < hi) {
    int pivot = lo + (hi - lo) / 2;

    if (prefixSumOffsets[pivot] > val) {
      hi = pivot;
    } else {
      lo = pivot + 1;
    }
  }

  // `val` has to fall inside one of the buckets
  assert(lo != size);

  return lo;
}
55 
// Second pass of IVF k-selection. Launched with one block per query
// (blockIdx.x is the query id); the BlockSelect instance is parameterized by
// ThreadsPerBlock, which is assumed to equal blockDim.x.
//
// For its query row, the block selects the k best entries of
// `heapDistances` (largest if Dir, smallest otherwise). It writes:
//   - outDistances[query][0, k): the selected distances;
//   - outIndices[query][0, k):   the matching user indices, or -1 for slots
//                                that no candidate filled.
//
// Each selected candidate is a position within the query's row.
// `heapIndices` maps that position to an offset in the scan output.
// `prefixSumOffsets` and `topQueryToCentroid` map the offset to (list id,
// offset within list), which is resolved according to `opt`.
//
// Static shared memory: kNumWarps * NumWarpQ floats + as many ints.
// Assumes NumWarpQ >= k (guaranteed by the host launcher in this file).
template <int ThreadsPerBlock,
          int NumWarpQ,
          int NumThreadQ,
          bool Dir>
__global__ void
pass2SelectLists(Tensor<float, 2, true> heapDistances,
                 Tensor<int, 2, true> heapIndices,
                 void** listIndices,
                 Tensor<int, 2, true> prefixSumOffsets,
                 Tensor<int, 2, true> topQueryToCentroid,
                 int k,
                 IndicesOptions opt,
                 Tensor<float, 2, true> outDistances,
                 Tensor<long, 2, true> outIndices) {
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  // Storage for the block-wide selection; after heap.reduce() the first k
  // entries hold the final (distance, candidate position) pairs
  __shared__ float smemK[kNumWarps * NumWarpQ];
  __shared__ int smemV[kNumWarps * NumWarpQ];

  // Empty queue slots start at the worst value for the selection direction
  // and carry the payload -1
  constexpr auto kInit = Dir ? kMin : kMax;
  BlockSelect<float, int, Dir, Comparator<float>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(kInit, -1, smemK, smemV, k);

  auto queryId = blockIdx.x;
  auto candidateDist = heapDistances[queryId];

  int numCandidates = heapDistances.getSize(1);
  int warpFullEnd = utils::roundDown(numCandidates, kWarpSize);

  // heap.add() may not be called from warp-divergent code, so it only
  // covers the prefix whose length is a multiple of the warp size...
  int pos = threadIdx.x;
  while (pos < warpFullEnd) {
    heap.add(candidateDist[pos], pos);
    pos += blockDim.x;
  }

  // ...and the ragged tail (fewer than kWarpSize entries) goes through
  // addThreadQ, which is used here for the divergent remainder
  if (pos < numCandidates) {
    heap.addThreadQ(candidateDist[pos], pos);
  }

  // Merge every thread/warp queue into the final block result in smem
  heap.reduce();

  for (int slot = threadIdx.x; slot < k; slot += blockDim.x) {
    outDistances[queryId][slot] = smemK[slot];

    // The selection payload is a position in this query's row of
    // heapIndices, not a user index. Intermediate results are kept in that
    // compact form to save temporary memory and global write traffic during
    // list scanning; translating here is divergent but only happens
    // #queries x k times.
    int candidate = smemV[slot];

    if (candidate == -1) {
      // Slot still holds the sentinel: no candidate filled it
      outIndices[queryId][slot] = -1L;
      continue;
    }

    // Offset of this intermediate result as computed by the list scan
    int scanOffset = heapIndices[queryId][candidate];

    // Find which probe of this query the offset belongs to, then the
    // inverted list that probe scanned
    int probe = binarySearchForBucket(prefixSumOffsets[queryId].data(),
                                      prefixSumOffsets.getSize(1),
                                      scanOffset);
    int listId = topQueryToCentroid[queryId][probe];

    // Start of that probe's range. The caller guarantees a 0 value right
    // before the prefix-sum data (offset -1).
    // NOTE(review): for probe 0 of query rows > 0 this reads the previous
    // row's last entry, which is only correct if the prefix sum runs across
    // all queries -- confirm against the caller.
    int bucketStart = *(prefixSumOffsets[queryId][probe].data() - 1);
    int offsetInList = scanOffset - bucketStart;

    long userIndex;
    switch (opt) {
      case INDICES_32_BIT:
        userIndex = (long) ((int*) listIndices[listId])[offsetInList];
        break;
      case INDICES_64_BIT:
        userIndex = ((long*) listIndices[listId])[offsetInList];
        break;
      default:
        // Any other mode: pack (list id, offset within list) with the list
        // id in the upper bits
        userIndex = ((long) listId << 32 | (long) offsetInList);
        break;
    }

    outIndices[queryId][slot] = userIndex;
  }
}
150 
// Host-side launcher for pass2SelectLists.
//
// Launches one block of kThreadsPerBlock threads per query (row of
// topQueryToCentroid) on `stream`. It picks the smallest compiled
// (NumWarpQ, NumThreadQ) instantiation whose warp queue can hold k results.
// `chooseLargest` selects the k largest distances instead of the smallest.
//
// - With zero queries there is nothing to do and no kernel is launched (a
//   zero-sized grid would be an invalid launch configuration).
// - k > 1024 has no instantiation and trips FAISS_ASSERT.
// - The launch is asynchronous. Launch-configuration errors are checked
//   immediately via cudaGetLastError(); faults during kernel execution only
//   surface at a later synchronizing call.
void
runPass2SelectLists(Tensor<float, 2, true>& heapDistances,
                    Tensor<int, 2, true>& heapIndices,
                    thrust::device_vector<void*>& listIndices,
                    IndicesOptions indicesOptions,
                    Tensor<int, 2, true>& prefixSumOffsets,
                    Tensor<int, 2, true>& topQueryToCentroid,
                    int k,
                    bool chooseLargest,
                    Tensor<float, 2, true>& outDistances,
                    Tensor<long, 2, true>& outIndices,
                    cudaStream_t stream) {
  constexpr auto kThreadsPerBlock = 128;

  // One block per query; a grid of size 0 is rejected by the runtime
  auto numQueries = topQueryToCentroid.getSize(0);
  if (numQueries == 0) {
    return;
  }

  auto grid = dim3(numQueries);
  auto block = dim3(kThreadsPerBlock);

#define RUN_PASS(NUM_WARP_Q, NUM_THREAD_Q, DIR)                         \
  do {                                                                  \
    pass2SelectLists<kThreadsPerBlock,                                  \
                     NUM_WARP_Q, NUM_THREAD_Q, DIR>                     \
      <<<grid, block, 0, stream>>>(heapDistances,                       \
                                   heapIndices,                         \
                                   listIndices.data().get(),            \
                                   prefixSumOffsets,                    \
                                   topQueryToCentroid,                  \
                                   k,                                   \
                                   indicesOptions,                      \
                                   outDistances,                        \
                                   outIndices);                         \
    /* Surface launch errors here, not at some unrelated later call */  \
    auto launchErr__ = cudaGetLastError();                              \
    FAISS_ASSERT(launchErr__ == cudaSuccess);                           \
    return; /* success */                                               \
  } while (0)

#define RUN_PASS_DIR(DIR)                       \
  do {                                          \
    if (k == 1) {                               \
      RUN_PASS(1, 1, DIR);                      \
    } else if (k <= 32) {                       \
      RUN_PASS(32, 2, DIR);                     \
    } else if (k <= 64) {                       \
      RUN_PASS(64, 3, DIR);                     \
    } else if (k <= 128) {                      \
      RUN_PASS(128, 3, DIR);                    \
    } else if (k <= 256) {                      \
      RUN_PASS(256, 4, DIR);                    \
    } else if (k <= 512) {                      \
      RUN_PASS(512, 8, DIR);                    \
    } else if (k <= 1024) {                     \
      RUN_PASS(1024, 8, DIR);                   \
    }                                           \
  } while (0)

  if (chooseLargest) {
    RUN_PASS_DIR(true);
  } else {
    RUN_PASS_DIR(false);
  }

  // unimplemented / too many resources (k > 1024)
  FAISS_ASSERT(false);

#undef RUN_PASS_DIR
#undef RUN_PASS
}
215 
216 } } // namespace