Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
IVFUtilsSelect1.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 
11 #include "IVFUtils.cuh"
12 #include "../utils/DeviceUtils.h"
13 #include "../utils/Limits.cuh"
14 #include "../utils/Select.cuh"
15 #include "../utils/StaticUtils.h"
16 #include "../utils/Tensor.cuh"
17 
18 //
19 // This kernel is split into a separate compilation unit to cut down
20 // on compile time
21 //
22 
23 namespace faiss { namespace gpu {
24 
/// First-pass k-selection kernel.
///
/// Launch layout, as read from the indices used below:
///   - blockIdx.y selects the query,
///   - blockIdx.x selects a slice of that query's `nprobe` entries,
///   - gridDim.x is the number of slices per query.
/// Each block selects up to `k` (distance, index) pairs from the contiguous
/// span of `distance` that belongs to its slice and writes them to
/// heapDistances[query][slice][0..k) and heapIndices[query][slice][0..k).
///
/// Static shared memory: (ThreadsPerBlock / kWarpSize) * NumWarpQ floats plus
/// the same number of ints.
///
/// Preconditions taken from the code:
///   - the int stored immediately before each row of `prefixSumOffsets` must
///     be readable (the first slice reads it as its start offset);
///   - NOTE(review): presumably blockDim.x == ThreadsPerBlock; the launcher
///     in this file passes the same constant for both -- confirm for any
///     other caller.
template <int ThreadsPerBlock, int NumWarpQ, int NumThreadQ, bool Dir>
__global__ void
pass1SelectLists(Tensor<int, 2, true> prefixSumOffsets,
                 Tensor<float, 1, true> distance,
                 int nprobe,
                 int k,
                 Tensor<float, 3, true> heapDistances,
                 Tensor<int, 3, true> heapIndices) {
  constexpr int kWarpsPerBlock = ThreadsPerBlock / kWarpSize;

  // Per-block scratch handed to the block-wide selector; after reduce() the
  // selected entries are read back from the front of these arrays
  __shared__ float sharedKeys[kWarpsPerBlock * NumWarpQ];
  __shared__ int sharedVals[kWarpsPerBlock * NumWarpQ];

  // Initial key handed to the selector; the initial index is -1
  constexpr auto kSentinel = Dir ? kFloatMin : kFloatMax;
  BlockSelect<float, int, Dir, Comparator<float>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    selector(kSentinel, -1, sharedKeys, sharedVals, k);

  const auto queryId = blockIdx.y;
  const auto sliceId = blockIdx.x;
  const auto sliceCount = gridDim.x;

  // Split [0, nprobe) evenly across the slices; the last slice also takes
  // whatever remainder the integer division leaves over
  const int probesPerSlice = (nprobe / sliceCount);
  const int probeBegin = probesPerSlice * sliceId;
  int probeEnd;
  if (sliceId == (sliceCount - 1)) {
    probeEnd = nprobe;
  } else {
    probeEnd = probeBegin + probesPerSlice;
  }

  auto offsets = prefixSumOffsets[queryId].data();

  // The element one before this slice's first offset marks where the slice's
  // distances begin. We ensure that before the array (at offset -1), there
  // is a 0 value, so this read is valid for the very first slice as well
  auto sliceOffsets = &offsets[probeBegin];
  const int start = sliceOffsets[-1];
  const int end = offsets[probeEnd - 1];

  const int num = end - start;
  // Largest multiple of the warp size that does not exceed num
  const int limit = utils::roundDown(num, kWarpSize);

  auto distanceStart = distance[start].data();

  // BlockSelect add cannot be used in a warp divergent circumstance, so the
  // main loop only covers the warp-aligned prefix [0, limit)
  int idx = threadIdx.x;
  while (idx < limit) {
    selector.add(distanceStart[idx], start + idx);
    idx += blockDim.x;
  }

  // Remainder in [limit, num): only some lanes of a warp have an element
  // here, so it goes through the per-thread queue insert instead
  if (idx < num) {
    selector.addThreadQ(distanceStart[idx], start + idx);
  }

  // Merge all final results
  selector.reduce();

  // Write out the final k-selected values; they should be all together at
  // the front of the shared arrays
  for (int out = threadIdx.x; out < k; out += blockDim.x) {
    heapDistances[queryId][sliceId][out] = sharedKeys[out];
    heapIndices[queryId][sliceId][out] = sharedVals[out];
  }
}
84 
/// Host-side launcher for the pass1SelectLists kernel.
///
/// Launch configuration:
///   - grid  = (heapDistances.getSize(1), prefixSumOffsets.getSize(0))
///   - block = 128 threads, no dynamic shared memory, enqueued on `stream`.
///
/// The kernel instantiation is chosen from `k` (k == 1, then the ranges up to
/// 32, 64, 128, 256, 512 and 1024) and from `chooseLargest`, which is
/// forwarded as the kernel's `Dir` template argument. Any `k` not matched by
/// one of those cases ends up at the assertion at the bottom.
void
runPass1SelectLists(Tensor<int, 2, true>& prefixSumOffsets,
                    Tensor<float, 1, true>& distance,
                    int nprobe,
                    int k,
                    bool chooseLargest,
                    Tensor<float, 3, true>& heapDistances,
                    Tensor<int, 3, true>& heapIndices,
                    cudaStream_t stream) {
  constexpr int kBlockThreads = 128;

  const dim3 gridDims(heapDistances.getSize(1), prefixSumOffsets.getSize(0));
  const dim3 blockDims(kBlockThreads);

  // Launches one concrete instantiation, checks for errors and leaves the
  // function; control only continues past a use of this macro if it was
  // never reached
#define LAUNCH_PASS1(WARP_Q, THREAD_Q, SELECT_DIR)                          \
  do {                                                                      \
    pass1SelectLists<kBlockThreads, WARP_Q, THREAD_Q, SELECT_DIR>           \
      <<<gridDims, blockDims, 0, stream>>>(prefixSumOffsets,                \
                                           distance,                        \
                                           nprobe,                          \
                                           k,                               \
                                           heapDistances,                   \
                                           heapIndices);                    \
    CUDA_TEST_ERROR();                                                      \
    return; /* success */                                                   \
  } while (0)

  // Picks the queue sizes for the given k. Every LAUNCH_PASS1 returns, so
  // these independent checks behave like a first-match-wins chain
#define DISPATCH_ON_K(SELECT_DIR)                                           \
  do {                                                                      \
    if (k == 1)    { LAUNCH_PASS1(1, 1, SELECT_DIR); }                      \
    if (k <= 32)   { LAUNCH_PASS1(32, 2, SELECT_DIR); }                     \
    if (k <= 64)   { LAUNCH_PASS1(64, 3, SELECT_DIR); }                     \
    if (k <= 128)  { LAUNCH_PASS1(128, 3, SELECT_DIR); }                    \
    if (k <= 256)  { LAUNCH_PASS1(256, 4, SELECT_DIR); }                    \
    if (k <= 512)  { LAUNCH_PASS1(512, 8, SELECT_DIR); }                    \
    if (k <= 1024) { LAUNCH_PASS1(1024, 8, SELECT_DIR); }                   \
  } while (0)

  // The selection direction is a template argument, so the runtime flag has
  // to be turned into a compile-time constant here
  if (!chooseLargest) {
    DISPATCH_ON_K(false);
  } else {
    DISPATCH_ON_K(true);
  }

  // Only reached when no case above matched k:
  // unimplemented / too many resources
  FAISS_ASSERT_FMT(false, "unimplemented k value (%d)", k);

#undef DISPATCH_ON_K
#undef LAUNCH_PASS1
}
143 
144 } } // namespace