BlockSelectKernel.cuh
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include "Float16.cuh"
#include "Select.cuh"

namespace faiss { namespace gpu {

template <typename K,
          typename IndexType,
          bool Dir,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
__global__ void blockSelect(Tensor<K, 2, true> in,
                            Tensor<K, 2, true> outK,
                            Tensor<IndexType, 2, true> outV,
                            K initK,
                            IndexType initV,
                            int k) {
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __shared__ K smemK[kNumWarps * NumWarpQ];
  __shared__ IndexType smemV[kNumWarps * NumWarpQ];

  BlockSelect<K, IndexType, Dir, Comparator<K>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(initK, initV, smemK, smemV, k);
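
  // BlockSelect (see Select.cuh) keeps a per-thread queue of NumThreadQ
  // candidates and per-warp queues of NumWarpQ elements; smemK/smemV are
  // the shared-memory scratch used to merge the kNumWarps warp queues.
  // After reduce(), the best k (key, index) pairs for this row sit sorted
  // at the front of smemK/smemV, which the epilogue below copies out.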

  // Grid is exactly sized to rows available
  int row = blockIdx.x;

  int i = threadIdx.x;
  K* inStart = in[row][i].data();

  // Whole warps must participate in the selection
  int limit = utils::roundDown(in.getSize(1), kWarpSize);

  for (; i < limit; i += ThreadsPerBlock) {
    heap.add(*inStart, (IndexType) i);
    inStart += ThreadsPerBlock;
  }

  // Handle the remainder (a partial warp's worth of elements)
  if (i < in.getSize(1)) {
    heap.addThreadQ(*inStart, (IndexType) i);
  }

  heap.reduce();

  for (int i = threadIdx.x; i < k; i += ThreadsPerBlock) {
    outK[row][i] = smemK[i];
    outV[row][i] = smemV[i];
  }
}
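
A caller instantiates this kernel with compile-time tile parameters and launches one block per row. The sketch below is illustrative only, not the dispatch code behind the wrappers declared further down: the block size, the queue lengths, the Dir = true "keep the largest" convention, and std::numeric_limits<float>::lowest() as the initial key are all assumptions.

#include <limits> // sentinel key for the sketch

// Hypothetical launcher: keep the k largest float values per row of `in`.
inline void exampleBlockSelectLargest(Tensor<float, 2, true>& in,
                                      Tensor<float, 2, true>& outK,
                                      Tensor<int, 2, true>& outV,
                                      int k,
                                      cudaStream_t stream) {
  constexpr int kThreadsPerBlock = 128; // assumed block size
  constexpr int kNumWarpQ = 32;         // assumed; must be >= k
  constexpr int kNumThreadQ = 2;        // assumed per-thread queue length

  dim3 grid(in.getSize(0)); // one block per row, as the kernel expects
  dim3 block(kThreadsPerBlock);

  // With Dir = true taken to mean "select maxima", initK is the worst
  // possible key and -1 marks a not-yet-filled index slot.
  blockSelect<float, int, true, kNumWarpQ, kNumThreadQ, kThreadsPerBlock>
      <<<grid, block, 0, stream>>>(
          in, outK, outV, std::numeric_limits<float>::lowest(), -1, k);
}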

template <typename K,
          typename IndexType,
          bool Dir,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
__global__ void blockSelectPair(Tensor<K, 2, true> inK,
                                Tensor<IndexType, 2, true> inV,
                                Tensor<K, 2, true> outK,
                                Tensor<IndexType, 2, true> outV,
                                K initK,
                                IndexType initV,
                                int k) {
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __shared__ K smemK[kNumWarps * NumWarpQ];
  __shared__ IndexType smemV[kNumWarps * NumWarpQ];

  BlockSelect<K, IndexType, Dir, Comparator<K>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(initK, initV, smemK, smemV, k);

  // Grid is exactly sized to rows available
  int row = blockIdx.x;

  int i = threadIdx.x;
  K* inKStart = inK[row][i].data();
  IndexType* inVStart = inV[row][i].data();

  // Whole warps must participate in the selection
  int limit = utils::roundDown(inK.getSize(1), kWarpSize);

  for (; i < limit; i += ThreadsPerBlock) {
    heap.add(*inKStart, *inVStart);
    inKStart += ThreadsPerBlock;
    inVStart += ThreadsPerBlock;
  }

  // Handle the remainder (a partial warp's worth of elements)
  if (i < inK.getSize(1)) {
    heap.addThreadQ(*inKStart, *inVStart);
  }

  heap.reduce();

  for (int i = threadIdx.x; i < k; i += ThreadsPerBlock) {
    outK[row][i] = smemK[i];
    outV[row][i] = smemV[i];
  }
}
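
Unlike blockSelect, the pair variant reads indices from inV rather than synthesizing column positions, which lets selections be chained: an earlier pass can emit partial (key, index) results that still carry their original indices, and blockSelectPair then k-selects across them. A minimal sketch of such a launch, under the same assumed tile parameters and Dir convention as above:

// Hypothetical second-pass launch over partial results already carrying
// their original indices in tileV; all parameters below are assumed.
inline void exampleBlockSelectPairLargest(Tensor<float, 2, true>& tileK,
                                          Tensor<int, 2, true>& tileV,
                                          Tensor<float, 2, true>& outK,
                                          Tensor<int, 2, true>& outV,
                                          int k,
                                          cudaStream_t stream) {
  constexpr int kThreadsPerBlock = 128; // assumed, as above

  blockSelectPair<float, int, true, 32, 2, kThreadsPerBlock>
      <<<tileK.getSize(0), kThreadsPerBlock, 0, stream>>>(
          tileK, tileV, outK, outV,
          std::numeric_limits<float>::lowest(), -1, k);
}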

void runBlockSelect(Tensor<float, 2, true>& in,
                    Tensor<float, 2, true>& outKeys,
                    Tensor<int, 2, true>& outIndices,
                    bool dir, int k, cudaStream_t stream);

void runBlockSelectPair(Tensor<float, 2, true>& inKeys,
                        Tensor<int, 2, true>& inIndices,
                        Tensor<float, 2, true>& outKeys,
                        Tensor<int, 2, true>& outIndices,
                        bool dir, int k, cudaStream_t stream);
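
These wrappers can be driven from raw device buffers via non-owning Tensor views. The sketch below assumes Tensor's pointer-plus-sizes constructor and that dir = true requests the maxima; it is a usage illustration, not code from this header.

// Hypothetical caller: select the k largest entries per row of a raw
// device matrix (buffer shapes noted per parameter).
inline void exampleRunBlockSelect(float* dIn,        // [numRows][numCols]
                                  float* dOutKeys,   // [numRows][k]
                                  int* dOutIndices,  // [numRows][k]
                                  int numRows, int numCols, int k,
                                  cudaStream_t stream) {
  Tensor<float, 2, true> in(dIn, {numRows, numCols});
  Tensor<float, 2, true> outKeys(dOutKeys, {numRows, k});
  Tensor<int, 2, true> outIndices(dOutIndices, {numRows, k});

  // dir = true assumed to select the largest values; false the smallest.
  runBlockSelect(in, outKeys, outIndices, true, k, stream);
}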

#ifdef FAISS_USE_FLOAT16
void runBlockSelect(Tensor<half, 2, true>& in,
                    Tensor<half, 2, true>& outKeys,
                    Tensor<int, 2, true>& outIndices,
                    bool dir, int k, cudaStream_t stream);

void runBlockSelectPair(Tensor<half, 2, true>& inKeys,
                        Tensor<int, 2, true>& inIndices,
                        Tensor<half, 2, true>& outKeys,
                        Tensor<int, 2, true>& outIndices,
                        bool dir, int k, cudaStream_t stream);
#endif

} } // namespace