Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
IVFUtilsSelect2.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 
11 #include "IVFUtils.cuh"
12 #include "../utils/DeviceUtils.h"
13 #include "../utils/Select.cuh"
14 #include "../utils/StaticUtils.h"
15 #include "../utils/Tensor.cuh"
16 #include <limits>
17 
18 //
19 // This kernel is split into a separate compilation unit to cut down
20 // on compile time
21 //
22 
23 namespace faiss { namespace gpu {
24 
// Sentinel distances used to seed the k-selection below.
// kMax is the largest finite float; it seeds the selection when the smallest
// values are being kept.
constexpr auto kMax = std::numeric_limits<float>::max();
// kMin seeds the selection when the largest values are being kept, so it must
// be the most negative finite float. std::numeric_limits<float>::min() is the
// smallest *positive* normalized value (~1.18e-38), which would outrank every
// zero or negative distance; lowest() (== -max()) is the correct sentinel.
constexpr auto kMin = std::numeric_limits<float>::lowest();
27 
// Locates the bucket that contains `val` within an ascending array of
// prefix sums: returns the index of the first entry that is strictly
// greater than `val`.
//
// prefixSumOffsets: the `size` prefix-sum entries to search
// size:             number of entries in `prefixSumOffsets`
// val:              the offset to locate
//
// Different threads of a warp will generally take different paths through
// this loop; per the original author this is acceptable because it is a
// final step that runs only a small number of times.
inline __device__ int binarySearchForBucket(int* prefixSumOffsets,
                                            int size,
                                            int val) {
  int lo = 0;
  int hi = size;

  // Invariant: every entry before `lo` is <= val and every entry at or
  // after `hi` is > val.
  while (lo < hi) {
    const int pivot = lo + (hi - lo) / 2;

    if (val < prefixSumOffsets[pivot]) {
      // The answer is at `pivot` or to its left
      hi = pivot;
    } else {
      // `pivot` is still <= val; the answer lies strictly to its right
      lo = pivot + 1;
    }
  }

  // The value must fall inside one of the buckets
  assert(lo != size);

  return lo;
}
54 
// Second-pass selection kernel: for the query handled by this block
// (blockIdx.x), picks the best `k` entries of that query's row of
// `heapDistances` and writes the chosen distances and the corresponding
// user-facing indices to `outDistances` / `outIndices`.
//
// Template parameters:
//   ThreadsPerBlock      - block size this kernel is instantiated for
//   NumWarpQ, NumThreadQ - queue sizes forwarded to BlockSelect
//   Dir                  - forwarded to BlockSelect; also picks the initial
//                          sentinel (kMin when true, kMax when false)
//
// Parameters:
//   heapDistances      - per-query candidate distances
//   heapIndices        - per-query intermediate offsets for each candidate
//   listIndices        - per-list index storage, interpreted according to `opt`
//   prefixSumOffsets   - per-query prefix sums used to map an offset to a probe
//   topQueryToCentroid - per-query mapping from probe to list id
//   k                  - number of results to write per query
//   opt                - how user indices are stored in `listIndices`
//   outDistances       - output: selected distances
//   outIndices         - output: selected user indices (-1 for unfilled slots)
template <int ThreadsPerBlock,
          int NumWarpQ,
          int NumThreadQ,
          bool Dir>
__global__ void
pass2SelectLists(Tensor<float, 2, true> heapDistances,
                 Tensor<int, 2, true> heapIndices,
                 void** listIndices,
                 Tensor<int, 2, true> prefixSumOffsets,
                 Tensor<int, 2, true> topQueryToCentroid,
                 int k,
                 IndicesOptions opt,
                 Tensor<float, 2, true> outDistances,
                 Tensor<long, 2, true> outIndices) {
  constexpr int kWarpsPerBlock = ThreadsPerBlock / kWarpSize;

  // Shared-memory key/value storage handed to the block-wide selector
  __shared__ float sharedKeys[kWarpsPerBlock * NumWarpQ];
  __shared__ int sharedVals[kWarpsPerBlock * NumWarpQ];

  constexpr auto kSentinel = Dir ? kMin : kMax;
  BlockSelect<float, int, Dir, Comparator<float>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    selector(kSentinel, -1, sharedKeys, sharedVals, k);

  const auto query = blockIdx.x;
  const int numCandidates = heapDistances.getSize(1);
  // Largest multiple of the warp size that fits in the candidate count
  const int fullWarpEnd = utils::roundDown(numCandidates, kWarpSize);

  auto candidateDistances = heapDistances[query];

  // BlockSelect::add must not be called under warp divergence, so only the
  // warp-aligned prefix of the candidates goes through it.
  int pos = threadIdx.x;
  while (pos < fullWarpEnd) {
    selector.add(candidateDistances[pos], pos);
    pos += blockDim.x;
  }

  // The leftover (partial-warp) candidates are pushed into the per-thread
  // queue only.
  if (pos < numCandidates) {
    selector.addThreadQ(candidateDistances[pos], pos);
  }

  // Combine everything into the final result held in shared memory
  selector.reduce();

  for (int slot = threadIdx.x; slot < k; slot += blockDim.x) {
    outDistances[query][slot] = sharedKeys[slot];

    // The selected value is a position within this query's row of
    // `heapIndices`, not a user index. Intermediate results are kept in this
    // compact form to reduce temporary memory and global write traffic
    // during list scanning; the translation to a user index happens here.
    // This is highly divergent, but it is the very last step and runs only
    // (#queries x k) times.
    const int candidate = sharedVals[slot];
    long userIndex = -1;

    if (candidate != -1) {
      // Offset of the intermediate result as computed by the original scan
      const int scanOffset = heapIndices[query][candidate];

      // Binary search the prefix sums to find which probe (and therefore
      // which list) this offset belongs to
      const int probe =
        binarySearchForBucket(prefixSumOffsets[query].data(),
                              prefixSumOffsets.getSize(1),
                              scanOffset);

      const int listId = topQueryToCentroid[query][probe];

      // Start of this list within the scan is the preceding prefix sum; the
      // element before the array (offset -1) is guaranteed to hold 0
      const int listBegin = prefixSumOffsets[query][probe].data()[-1];
      const int listOffset = scanOffset - listBegin;

      // Resolve the final user index according to the storage option
      switch (opt) {
        case INDICES_32_BIT:
          userIndex = (long) ((int*) listIndices[listId])[listOffset];
          break;
        case INDICES_64_BIT:
          userIndex = ((long*) listIndices[listId])[listOffset];
          break;
        default:
          // No index storage to read from: pack (list id, offset in list)
          userIndex = (((long) listId) << 32) | ((long) listOffset);
          break;
      }
    }

    outIndices[query][slot] = userIndex;
  }
}
149 
// Host-side launcher for pass2SelectLists.
//
// Launches one block of 128 threads per row of `topQueryToCentroid` on
// `stream`, choosing the kernel instantiation from `k` (supported up to 1024)
// and from `chooseLargest` (forwarded as the kernel's `Dir` template
// argument). Returns right after the launch and its error check; values of
// `k` with no matching instantiation hit the assertion at the end.
void
runPass2SelectLists(Tensor<float, 2, true>& heapDistances,
                    Tensor<int, 2, true>& heapIndices,
                    thrust::device_vector<void*>& listIndices,
                    IndicesOptions indicesOptions,
                    Tensor<int, 2, true>& prefixSumOffsets,
                    Tensor<int, 2, true>& topQueryToCentroid,
                    int k,
                    bool chooseLargest,
                    Tensor<float, 2, true>& outDistances,
                    Tensor<long, 2, true>& outIndices,
                    cudaStream_t stream) {
  constexpr auto kThreadsPerBlock = 128;

  auto grid = dim3(topQueryToCentroid.getSize(0));
  auto block = dim3(kThreadsPerBlock);

#define PASS2_LAUNCH(WARP_Q, THREAD_Q, LARGEST)                         \
  pass2SelectLists<kThreadsPerBlock, WARP_Q, THREAD_Q, LARGEST>         \
    <<<grid, block, 0, stream>>>(heapDistances,                         \
                                 heapIndices,                           \
                                 listIndices.data().get(),              \
                                 prefixSumOffsets,                      \
                                 topQueryToCentroid,                    \
                                 k,                                     \
                                 indicesOptions,                        \
                                 outDistances,                          \
                                 outIndices)

#define PASS2_SELECT(WARP_Q, THREAD_Q)                  \
  do {                                                  \
    if (chooseLargest) {                                \
      PASS2_LAUNCH(WARP_Q, THREAD_Q, true);             \
    } else {                                            \
      PASS2_LAUNCH(WARP_Q, THREAD_Q, false);            \
    }                                                   \
    CUDA_TEST_ERROR();                                  \
    return; /* success */                               \
  } while (0)

  // Pick the smallest queue configuration that can hold k results
  if (k == 1) {
    PASS2_SELECT(1, 1);
  } else if (k <= 32) {
    PASS2_SELECT(32, 2);
  } else if (k <= 64) {
    PASS2_SELECT(64, 3);
  } else if (k <= 128) {
    PASS2_SELECT(128, 3);
  } else if (k <= 256) {
    PASS2_SELECT(256, 4);
  } else if (k <= 512) {
    PASS2_SELECT(512, 8);
  } else if (k <= 1024) {
    PASS2_SELECT(1024, 8);
  }

  // unimplemented / too many resources
  FAISS_ASSERT_FMT(false, "unimplemented k value (%d)", k);

#undef PASS2_SELECT
#undef PASS2_LAUNCH
}
215 
216 } } // namespace