Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
IVFUtilsSelect1.cu
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved.
11 
12 #include "IVFUtils.cuh"
13 #include "../utils/DeviceUtils.h"
14 #include "../utils/Select.cuh"
15 #include "../utils/StaticUtils.h"
16 #include "../utils/Tensor.cuh"
17 #include <limits>
18 
19 //
20 // This kernel is split into a separate compilation unit to cut down
21 // on compile time
22 //
23 
24 namespace faiss { namespace gpu {
25 
// Sentinel values used to seed the k-selection heap before any real distance
// has been seen.
//
// kMax is the largest finite float; kMin is the most negative finite float.
// Note that std::numeric_limits<float>::min() is the smallest *positive*
// normalized value (about 1.18e-38), not the most negative one, so it must not
// be used as the "worse than anything" seed when selecting the largest
// values: any real distance <= 0 would then lose against the seed.
// lowest() (== -max() for float) is the correct lower bound.
constexpr auto kMax = std::numeric_limits<float>::max();
constexpr auto kMin = std::numeric_limits<float>::lowest();
28 
// First-pass k-selection kernel.
//
// Launch layout, as read from the indexing below:
//   - blockIdx.y selects the query (row of prefixSumOffsets)
//   - blockIdx.x selects the slice of that query's nprobe entries;
//     gridDim.x is the number of slices
//   - the block is indexed in x only, and ThreadsPerBlock sizes the
//     per-warp queues held in static shared memory
//
// Each block k-selects over the contiguous run of `distance` that belongs to
// its slice and writes the k winners (value and flat index into `distance`)
// to heapDistances[query][slice][0..k) and heapIndices[query][slice][0..k).
//
// Dir == true selects the largest values (heap seeded with kMin); Dir ==
// false selects the smallest (heap seeded with kMax).
template <int ThreadsPerBlock, int NumWarpQ, int NumThreadQ, bool Dir>
__global__ void
pass1SelectLists(Tensor<int, 2, true> prefixSumOffsets,
                 Tensor<float, 1, true> distance,
                 int nprobe,
                 int k,
                 Tensor<float, 3, true> heapDistances,
                 Tensor<int, 3, true> heapIndices) {
  constexpr int kWarpsPerBlock = ThreadsPerBlock / kWarpSize;
  constexpr int kQueueEntries = kWarpsPerBlock * NumWarpQ;

  // Backing storage for the per-warp queues; after the final reduce the
  // selected entries are read back from the front of these arrays
  __shared__ float smemK[kQueueEntries];
  __shared__ int smemV[kQueueEntries];

  // Value every heap slot starts with; -1 is the matching "no index" marker
  constexpr auto kSentinel = Dir ? kMin : kMax;

  typedef BlockSelect<float, int, Dir, Comparator<float>,
                      NumWarpQ, NumThreadQ, ThreadsPerBlock> Selector;
  Selector heap(kSentinel, -1, smemK, smemV, k);

  const auto query = blockIdx.y;
  const auto slice = blockIdx.x;
  const auto sliceCount = gridDim.x;

  // Split the nprobe entries evenly across slices; the last slice also
  // absorbs whatever is left over from the integer division
  const int probesPerSlice = (nprobe / sliceCount);
  const int probeBegin = probesPerSlice * slice;

  int probeEnd;
  if (slice == (sliceCount - 1)) {
    probeEnd = nprobe;
  } else {
    probeEnd = probeBegin + probesPerSlice;
  }

  auto listEnds = prefixSumOffsets[query].data();

  // The entry just before a slice marks where its data begins. For the very
  // first slice this reads index -1: the caller is expected to guarantee a 0
  // is stored immediately before the array.
  const int start = listEnds[probeBegin - 1];
  const int end = listEnds[probeEnd - 1];

  const int count = end - start;
  // Largest multiple of the warp size not exceeding count
  const int warpAligned = utils::roundDown(count, kWarpSize);

  auto sliceDistances = distance[start].data();

  // BlockSelect::add must not be called under warp divergence, so only the
  // warp-aligned prefix goes through it; the tail is handled afterwards
  int idx = threadIdx.x;
  while (idx < warpAligned) {
    heap.add(sliceDistances[idx], start + idx);
    idx += blockDim.x;
  }

  // Remainder (fewer than one warp of elements): only some lanes have work,
  // so push into the per-thread queue only
  if (idx < count) {
    heap.addThreadQ(sliceDistances[idx], start + idx);
  }

  // Combine all thread/warp queues into the final block-wide result
  heap.reduce();

  // The k selected entries sit contiguously at the front of shared memory;
  // copy them out to this (query, slice) row
  for (int out = threadIdx.x; out < k; out += blockDim.x) {
    heapDistances[query][slice][out] = smemK[out];
    heapIndices[query][slice][out] = smemV[out];
  }
}
88 
// Host-side launcher for pass1SelectLists.
//
// Launch configuration:
//   - grid.x = heapDistances.getSize(1), grid.y = prefixSumOffsets.getSize(0)
//   - 128 threads per block, no dynamic shared memory
//   - the kernel is enqueued on `stream`; this function does not synchronize
//
// chooseLargest picks the selection direction (true -> keep the largest
// values, false -> keep the smallest). The kernel instantiation is chosen
// from k: k == 1 gets a dedicated variant, otherwise the smallest queue size
// in {32, 64, 128, 256, 512, 1024} that is >= k is used. A k above 1024 is
// not supported and trips FAISS_ASSERT.
//
// The launch itself is checked through cudaGetLastError(), so an invalid
// launch configuration is reported here instead of being silently dropped.
void
runPass1SelectLists(Tensor<int, 2, true>& prefixSumOffsets,
                    Tensor<float, 1, true>& distance,
                    int nprobe,
                    int k,
                    bool chooseLargest,
                    Tensor<float, 3, true>& heapDistances,
                    Tensor<int, 3, true>& heapIndices,
                    cudaStream_t stream) {
  constexpr auto kThreadsPerBlock = 128;

  auto grid = dim3(heapDistances.getSize(1), prefixSumOffsets.getSize(0));
  auto block = dim3(kThreadsPerBlock);

  // Launches one concrete instantiation, verifies that the launch was
  // accepted, and returns from runPass1SelectLists
#define RUN_PASS(NUM_WARP_Q, NUM_THREAD_Q, DIR)                         \
  do {                                                                  \
    pass1SelectLists<kThreadsPerBlock, NUM_WARP_Q, NUM_THREAD_Q, DIR>   \
      <<<grid, block, 0, stream>>>(prefixSumOffsets,                    \
                                   distance,                            \
                                   nprobe,                              \
                                   k,                                   \
                                   heapDistances,                       \
                                   heapIndices);                        \
    /* kernel launches return no status; query it explicitly */         \
    FAISS_ASSERT(cudaGetLastError() == cudaSuccess);                    \
    return; /* success */                                               \
  } while (0)

  // Maps k onto the (warp queue, thread queue) sizes of the instantiation
#define RUN_PASS_DIR(DIR)                                \
  do {                                                   \
    if (k == 1) {                                        \
      RUN_PASS(1, 1, DIR);                               \
    } else if (k <= 32) {                                \
      RUN_PASS(32, 2, DIR);                              \
    } else if (k <= 64) {                                \
      RUN_PASS(64, 3, DIR);                              \
    } else if (k <= 128) {                               \
      RUN_PASS(128, 3, DIR);                             \
    } else if (k <= 256) {                               \
      RUN_PASS(256, 4, DIR);                             \
    } else if (k <= 512) {                               \
      RUN_PASS(512, 8, DIR);                             \
    } else if (k <= 1024) {                              \
      RUN_PASS(1024, 8, DIR);                            \
    }                                                    \
  } while (0)

  if (chooseLargest) {
    RUN_PASS_DIR(true);
  } else {
    RUN_PASS_DIR(false);
  }

  // Only reached when no branch above launched (k > 1024):
  // unimplemented / too many resources
  FAISS_ASSERT(false);

#undef RUN_PASS_DIR
#undef RUN_PASS
}
146 
147 } } // namespace