BlockSelectKernel.cuh
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include "Float16.cuh"
#include "Select.cuh"

namespace faiss { namespace gpu {

// Block-wide k-selection over each row of `in`: one block per row. Each
// thread strides over the row's keys and keeps candidates in a
// warp-coalesced BlockSelect heap; the final k (key, column index) pairs,
// ordered by Dir via Comparator<K>, are written to outK/outV.
template <typename K,
          typename IndexType,
          bool Dir,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
__global__ void blockSelect(Tensor<K, 2, true> in,
                            Tensor<K, 2, true> outK,
                            Tensor<IndexType, 2, true> outV,
                            K initK,
                            IndexType initV,
                            int k) {
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __shared__ K smemK[kNumWarps * NumWarpQ];
  __shared__ IndexType smemV[kNumWarps * NumWarpQ];

  BlockSelect<K, IndexType, Dir, Comparator<K>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(initK, initV, smemK, smemV, k);

  // Grid is exactly sized to rows available
  int row = blockIdx.x;

  int i = threadIdx.x;
  K* inStart = in[row][i].data();

  // Whole warps must participate in the selection
  int limit = utils::roundDown(in.getSize(1), kWarpSize);

  for (; i < limit; i += ThreadsPerBlock) {
    heap.add(*inStart, (IndexType) i);
    inStart += ThreadsPerBlock;
  }

  // Handle last remainder fraction of a warp of elements
  if (i < in.getSize(1)) {
    heap.addThreadQ(*inStart, (IndexType) i);
  }

  heap.reduce();

  for (int i = threadIdx.x; i < k; i += ThreadsPerBlock) {
    outK[row][i] = smemK[i];
    outV[row][i] = smemV[i];
  }
}

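// A minimal launch sketch, for illustration only. The variables, the
// concrete template parameters, and the reading of Dir = false as
// "select the k smallest keys" are assumptions of this sketch, not
// guarantees of this header; real code should go through runBlockSelect()
// below, which picks a suitable instantiation for the given k.
//
//   // One block per row, 128 threads, warp queue of 32 (so k <= 32),
//   // thread queue of 2. initK is a sentinel any real key must beat;
//   // for a k-smallest selection that is the maximum representable key
//   // (Limits<float>::getMax() from Limits.cuh).
//   blockSelect<float, int, false, 32, 2, 128>
//       <<<in.getSize(0), 128, 0, stream>>>(
//           in, outK, outV, Limits<float>::getMax(), -1, k);
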
// Same block-wide k-selection, but over explicit (key, value) pairs read
// from inK/inV instead of synthesizing the value from the column index;
// this is the form used when merging partial results that already carry
// their indices.
template <typename K,
          typename IndexType,
          bool Dir,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
__global__ void blockSelectPair(Tensor<K, 2, true> inK,
                                Tensor<IndexType, 2, true> inV,
                                Tensor<K, 2, true> outK,
                                Tensor<IndexType, 2, true> outV,
                                K initK,
                                IndexType initV,
                                int k) {
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __shared__ K smemK[kNumWarps * NumWarpQ];
  __shared__ IndexType smemV[kNumWarps * NumWarpQ];

  BlockSelect<K, IndexType, Dir, Comparator<K>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(initK, initV, smemK, smemV, k);

  // Grid is exactly sized to rows available
  int row = blockIdx.x;

  int i = threadIdx.x;
  K* inKStart = inK[row][i].data();
  IndexType* inVStart = inV[row][i].data();

  // Whole warps must participate in the selection
  int limit = utils::roundDown(inK.getSize(1), kWarpSize);

  for (; i < limit; i += ThreadsPerBlock) {
    heap.add(*inKStart, *inVStart);
    inKStart += ThreadsPerBlock;
    inVStart += ThreadsPerBlock;
  }

  // Handle last remainder fraction of a warp of elements
  if (i < inK.getSize(1)) {
    heap.addThreadQ(*inKStart, *inVStart);
  }

  heap.reduce();

  for (int i = threadIdx.x; i < k; i += ThreadsPerBlock) {
    outK[row][i] = smemK[i];
    outV[row][i] = smemV[i];
  }
}

// Host-side launchers: these dispatch to a kernel instantiation
// appropriate for k; `dir` chooses the comparison direction of the
// selection.
void runBlockSelect(Tensor<float, 2, true>& in,
                    Tensor<float, 2, true>& outKeys,
                    Tensor<int, 2, true>& outIndices,
                    bool dir, int k, cudaStream_t stream);

void runBlockSelectPair(Tensor<float, 2, true>& inKeys,
                        Tensor<int, 2, true>& inIndices,
                        Tensor<float, 2, true>& outKeys,
                        Tensor<int, 2, true>& outIndices,
                        bool dir, int k, cudaStream_t stream);
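
// Illustrative host-side call (variable names are assumptions for this
// example, and dir = false is assumed here to request the k smallest
// keys):
//
//   // distances: [numQueries][numVectors] float matrix on the GPU
//   // outKeys / outIndices: [numQueries][k] outputs
//   runBlockSelect(distances, outKeys, outIndices,
//                  /* dir = */ false, k, stream);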

#ifdef FAISS_USE_FLOAT16
void runBlockSelect(Tensor<half, 2, true>& in,
                    Tensor<half, 2, true>& outKeys,
                    Tensor<int, 2, true>& outIndices,
                    bool dir, int k, cudaStream_t stream);

void runBlockSelectPair(Tensor<half, 2, true>& inKeys,
                        Tensor<int, 2, true>& inIndices,
                        Tensor<half, 2, true>& outKeys,
                        Tensor<int, 2, true>& outIndices,
                        bool dir, int k, cudaStream_t stream);
#endif

} } // namespace