Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
IVFUtilsSelect1.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 
10 #include "IVFUtils.cuh"
11 #include "../utils/DeviceUtils.h"
12 #include "../utils/Limits.cuh"
13 #include "../utils/Select.cuh"
14 #include "../utils/StaticUtils.h"
15 #include "../utils/Tensor.cuh"
16 
17 //
18 // This kernel is split into a separate compilation unit to cut down
19 // on compile time
20 //
21 
22 namespace faiss { namespace gpu {
23 
// Pass 1 of IVF list k-selection.
//
// One block per (slice, query) pair: blockIdx.y picks the query and
// blockIdx.x picks a slice of that query's nprobe probed lists. The lists
// are divided into gridDim.x slices of (nprobe / gridDim.x) lists each, and
// the final slice additionally absorbs any remainder. The block scans every
// distance that belongs to its slice, keeps the k best according to Dir
// (true: keep largest, false: keep smallest), and writes them to
// heapDistances[query][slice][0..k) / heapIndices[query][slice][0..k).
//
// The "index" recorded for each result is its position inside `distance`,
// not a vector id.
//
// Expectations (not checked here):
//  - blockDim.x == ThreadsPerBlock, a multiple of kWarpSize (the launcher
//    uses 128 threads).
//  - Within a query row of prefixSumOffsets, entry j is the end position in
//    `distance` of probe j's results, and the element immediately before
//    the row is readable and holds 0 (the caller is expected to guarantee
//    this; not visible in this file).
//  - NOTE(review): k is presumably required to be <= NumWarpQ, matching the
//    launcher's dispatch buckets; confirm against BlockSelect.
//
// Shared memory: statically sized, kNumWarps * NumWarpQ floats plus the
// same number of ints.
template <int ThreadsPerBlock, int NumWarpQ, int NumThreadQ, bool Dir>
__global__ void
pass1SelectLists(Tensor<int, 2, true> prefixSumOffsets,
                 Tensor<float, 1, true> distance,
                 int nprobe,
                 int k,
                 Tensor<float, 3, true> heapDistances,
                 Tensor<int, 3, true> heapIndices) {
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  // Backing storage for the block-wide selection queue.
  __shared__ float smemK[kNumWarps * NumWarpQ];
  __shared__ int smemV[kNumWarps * NumWarpQ];

  // Empty queue slots start at the worst possible key for the chosen
  // direction, paired with index -1.
  constexpr auto kInit = Dir ? kFloatMin : kFloatMax;
  BlockSelect<float, int, Dir, Comparator<float>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(kInit, -1, smemK, smemV, k);

  auto query = blockIdx.y;
  auto slice = blockIdx.x;
  auto lastSlice = gridDim.x - 1;

  // Probed lists [firstList, endList) belong to this slice.
  int listsPerSlice = nprobe / gridDim.x;
  int firstList = listsPerSlice * slice;
  int endList = firstList + listsPerSlice;
  if (slice == lastSlice) {
    endList = nprobe;
  }

  auto offsets = prefixSumOffsets[query].data();

  // offsets[firstList - 1] may be offsets[-1], which the caller guarantees
  // is a readable 0. That yields this slice's first position in `distance`.
  int base = offsets[firstList - 1];
  int count = offsets[endList - 1] - base;

  // Rounding down to a multiple of kWarpSize keeps the main loop's exit
  // condition identical for every lane of a warp. BlockSelect::add must
  // not be called from a warp-divergent context.
  int warpAlignedCount = utils::roundDown(count, kWarpSize);

  auto sliceDistances = distance[base].data();

  int idx = threadIdx.x;
  while (idx < warpAlignedCount) {
    heap.add(sliceDistances[idx], base + idx);
    idx += blockDim.x;
  }

  // Fewer than kWarpSize entries remain past warpAlignedCount; each thread
  // owns at most one of them. Those entries go into the thread-local queue,
  // which tolerates divergence.
  if (idx < count) {
    heap.addThreadQ(sliceDistances[idx], base + idx);
  }

  // Combine all per-thread and per-warp queues into the block result.
  heap.reduce();

  // After reduce(), the block's k results occupy smemK[0..k) / smemV[0..k).
  for (int slot = threadIdx.x; slot < k; slot += blockDim.x) {
    heapDistances[query][slice][slot] = smemK[slot];
    heapIndices[query][slice][slot] = smemV[slot];
  }
}
83 
// Host-side launcher for pass1SelectLists on `stream`.
//
// Launch layout: one 128-thread block per (slice, query) pair. The grid is
// dim3(heapDistances.getSize(1), prefixSumOffsets.getSize(0)), i.e.
// x = slices per query and y = queries. The BlockSelect queue sizes
// (NumWarpQ, NumThreadQ) are chosen from k:
//   k == 1 -> (1, 1),    k <= 32  -> (32, 2),   k <= 64   -> (64, 3),
//   k <= 128 -> (128, 3), k <= 256 -> (256, 4),  k <= 512  -> (512, 8),
//   k <= 1024 -> (1024, 8).
// chooseLargest selects the direction template argument. Each launch is
// followed by CUDA_TEST_ERROR(), after which the function returns. A k with
// no matching specialization falls through to FAISS_ASSERT_FMT.
//
// NOTE(review): gridDim.y is limited by the hardware (65535 on current
// GPUs); presumably callers tile queries to stay under it. Confirm at the
// call site.
void
runPass1SelectLists(Tensor<int, 2, true>& prefixSumOffsets,
                    Tensor<float, 1, true>& distance,
                    int nprobe,
                    int k,
                    bool chooseLargest,
                    Tensor<float, 3, true>& heapDistances,
                    Tensor<int, 3, true>& heapIndices,
                    cudaStream_t stream) {
  constexpr auto kThreadsPerBlock = 128;

  auto numSlices = heapDistances.getSize(1);
  auto numQueries = prefixSumOffsets.getSize(0);
  auto grid = dim3(numSlices, numQueries);
  auto block = dim3(kThreadsPerBlock);

// Launch one fully specialized kernel, check for launch errors, and leave
// the function.
#define LAUNCH_PASS1(NUM_WARP_Q, NUM_THREAD_Q, DIR)                      \
  do {                                                                   \
    pass1SelectLists<kThreadsPerBlock, NUM_WARP_Q, NUM_THREAD_Q, DIR>    \
      <<<grid, block, 0, stream>>>(prefixSumOffsets, distance, nprobe,   \
                                   k, heapDistances, heapIndices);       \
    CUDA_TEST_ERROR();                                                   \
    return; /* success */                                                \
  } while (0)

// Pick the direction specialization for a given queue-size bucket.
#define LAUNCH_PASS1_BOTH_DIRS(NUM_WARP_Q, NUM_THREAD_Q)                 \
  do {                                                                   \
    if (chooseLargest) {                                                 \
      LAUNCH_PASS1(NUM_WARP_Q, NUM_THREAD_Q, true);                      \
    } else {                                                             \
      LAUNCH_PASS1(NUM_WARP_Q, NUM_THREAD_Q, false);                     \
    }                                                                    \
  } while (0)

  if (k == 1) {
    LAUNCH_PASS1_BOTH_DIRS(1, 1);
  } else if (k <= 32) {
    LAUNCH_PASS1_BOTH_DIRS(32, 2);
  } else if (k <= 64) {
    LAUNCH_PASS1_BOTH_DIRS(64, 3);
  } else if (k <= 128) {
    LAUNCH_PASS1_BOTH_DIRS(128, 3);
  } else if (k <= 256) {
    LAUNCH_PASS1_BOTH_DIRS(256, 4);
  } else if (k <= 512) {
    LAUNCH_PASS1_BOTH_DIRS(512, 8);
  } else if (k <= 1024) {
    LAUNCH_PASS1_BOTH_DIRS(1024, 8);
  }

  // Reached only when no specialization covers k (unimplemented / would
  // need too many resources).
  FAISS_ASSERT_FMT(false, "unimplemented k value (%d)", k);

#undef LAUNCH_PASS1_BOTH_DIRS
#undef LAUNCH_PASS1
}
142 
143 } } // namespace