Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
IVFUtilsSelect1.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 
11 #include "IVFUtils.cuh"
12 #include "../utils/DeviceUtils.h"
13 #include "../utils/Select.cuh"
14 #include "../utils/StaticUtils.h"
15 #include "../utils/Tensor.cuh"
16 #include <limits>
17 
18 //
19 // This kernel is split into a separate compilation unit to cut down
20 // on compile time
21 //
22 
23 namespace faiss { namespace gpu {
24 
// Sentinel values used to seed the k-selection heap in pass1SelectLists.
//
// kMax seeds the heap when the smallest values are being selected.
constexpr auto kMax = std::numeric_limits<float>::max();
// kMin seeds the heap when the largest values are being selected, so it must
// be the most negative finite float. std::numeric_limits<float>::min() is the
// smallest *positive* normalized value, which would cause every zero or
// negative candidate to lose against the sentinel; lowest() is the correct
// bound (equal to -max() for float).
constexpr auto kMin = std::numeric_limits<float>::lowest();
27 
// First-pass k-selection over the distances of one slice of probed lists.
//
// Launch layout (as read from the built-in variables below):
//   blockIdx.y -> query index
//   blockIdx.x -> slice index within that query; gridDim.x is the slice count
// Each block k-selects over the contiguous range of `distance` covered by its
// slice of the `nprobe` lists and writes k (distance, index) pairs to
// heapDistances[query][slice] / heapIndices[query][slice].
//
// Template parameters:
//   ThreadsPerBlock      - number of threads the block is launched with
//   NumWarpQ, NumThreadQ - queue sizes forwarded to BlockSelect
//   Dir                  - selection direction forwarded to BlockSelect; the
//                          host launcher passes true when choosing the largest
//
// Static shared memory: (ThreadsPerBlock / kWarpSize) * NumWarpQ floats plus
// the same number of ints.
template <int ThreadsPerBlock, int NumWarpQ, int NumThreadQ, bool Dir>
__global__ void
pass1SelectLists(Tensor<int, 2, true> prefixSumOffsets,
                 Tensor<float, 1, true> distance,
                 int nprobe,
                 int k,
                 Tensor<float, 3, true> heapDistances,
                 Tensor<int, 3, true> heapIndices) {
  constexpr int kWarpsInBlock = ThreadsPerBlock / kWarpSize;
  constexpr int kQueueEntries = kWarpsInBlock * NumWarpQ;

  __shared__ float smemK[kQueueEntries];
  __shared__ int smemV[kQueueEntries];

  // Value every heap slot starts with; index slots start at -1
  constexpr auto kSentinel = Dir ? kMin : kMax;
  BlockSelect<float, int, Dir, Comparator<float>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(kSentinel, -1, smemK, smemV, k);

  const unsigned int query = blockIdx.y;
  const unsigned int slice = blockIdx.x;
  const unsigned int sliceCount = gridDim.x;

  // Probes are split evenly between slices; the final slice additionally
  // takes whatever remains when nprobe is not a multiple of the slice count
  int probesPerSlice = (nprobe / sliceCount);
  int probeBegin = probesPerSlice * slice;
  int probeEnd = probeBegin + probesPerSlice;
  if (slice == (sliceCount - 1)) {
    probeEnd = nprobe;
  }

  auto offsets = prefixSumOffsets[query].data();

  // The element just before the array (offset -1) is guaranteed to hold 0,
  // so reading one entry before the slice start is valid for slice 0 as well
  int first = *(&offsets[probeBegin] - 1);
  int last = offsets[probeEnd - 1];

  int count = last - first;
  int warpAligned = utils::roundDown(count, kWarpSize);

  auto sliceDistances = distance[first].data();
  int idx = threadIdx.x;

  // BlockSelect::add cannot be called under warp divergence, so only the
  // warp-aligned prefix of the range goes through it
  while (idx < warpAligned) {
    heap.add(sliceDistances[idx], first + idx);
    idx += blockDim.x;
  }

  // The leftover tail (fewer than a warp's worth of elements) is pushed into
  // the per-thread queues instead
  if (idx < count) {
    heap.addThreadQ(sliceDistances[idx], first + idx);
  }

  // Combine all per-thread / per-warp results
  heap.reduce();

  // The k selected entries are expected to sit together at the front of the
  // shared-memory arrays; copy them out
  int out = threadIdx.x;
  while (out < k) {
    heapDistances[query][slice][out] = smemK[out];
    heapIndices[query][slice][out] = smemV[out];
    out += blockDim.x;
  }
}
87 
// Host-side launcher for pass1SelectLists.
//
// Launches one block of 128 threads per (slice, query) pair:
//   grid.x = heapDistances.getSize(1), grid.y = prefixSumOffsets.getSize(0)
// on `stream`, picking the kernel instantiation from `k` (supported up to
// 1024) and `chooseLargest`. Any k above 1024 trips the assertion at the end.
void
runPass1SelectLists(Tensor<int, 2, true>& prefixSumOffsets,
                    Tensor<float, 1, true>& distance,
                    int nprobe,
                    int k,
                    bool chooseLargest,
                    Tensor<float, 3, true>& heapDistances,
                    Tensor<int, 3, true>& heapIndices,
                    cudaStream_t stream) {
  constexpr auto kBlockThreads = 128;

  auto launchBlock = dim3(kBlockThreads);
  auto launchGrid =
    dim3(heapDistances.getSize(1), prefixSumOffsets.getSize(0));

// Launches one instantiation, checks for errors and leaves the function
#define LAUNCH_SELECT(WARP_Q, THREAD_Q, LARGEST)                        \
  do {                                                                  \
    pass1SelectLists<kBlockThreads, WARP_Q, THREAD_Q, LARGEST>          \
      <<<launchGrid, launchBlock, 0, stream>>>(prefixSumOffsets,        \
                                               distance,                \
                                               nprobe,                  \
                                               k,                       \
                                               heapDistances,           \
                                               heapIndices);            \
    CUDA_TEST_ERROR();                                                  \
    return; /* success */                                               \
  } while (0)

// Picks the smallest queue configuration that can hold k. Every
// LAUNCH_SELECT returns, so at most one of these guards ever launches.
#define SELECT_BY_K(LARGEST)                                            \
  do {                                                                  \
    if (k == 1) {                                                       \
      LAUNCH_SELECT(1, 1, LARGEST);                                     \
    }                                                                   \
    if (k <= 32) {                                                      \
      LAUNCH_SELECT(32, 2, LARGEST);                                    \
    }                                                                   \
    if (k <= 64) {                                                      \
      LAUNCH_SELECT(64, 3, LARGEST);                                    \
    }                                                                   \
    if (k <= 128) {                                                     \
      LAUNCH_SELECT(128, 3, LARGEST);                                   \
    }                                                                   \
    if (k <= 256) {                                                     \
      LAUNCH_SELECT(256, 4, LARGEST);                                   \
    }                                                                   \
    if (k <= 512) {                                                     \
      LAUNCH_SELECT(512, 8, LARGEST);                                   \
    }                                                                   \
    if (k <= 1024) {                                                    \
      LAUNCH_SELECT(1024, 8, LARGEST);                                  \
    }                                                                   \
  } while (0)

  if (chooseLargest) {
    SELECT_BY_K(true);
  } else {
    SELECT_BY_K(false);
  }

  // Only reached when no instantiation matched k:
  // unimplemented / too many resources
  FAISS_ASSERT_FMT(false, "unimplemented k value (%d)", k);

#undef SELECT_BY_K
#undef LAUNCH_SELECT
}
146 
147 } } // namespace