BlockSelectKernel.cuh
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.
#pragma once

#include "Float16.cuh"
#include "Select.cuh"

namespace faiss { namespace gpu {

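// blockSelect: each CUDA block performs a k-selection over a single row of
// `in`, writing the k selected keys and their column indices to outK / outV.
// Dir chooses the selection direction (largest vs. smallest values);
// NumWarpQ and NumThreadQ size the per-warp and per-thread selection queues,
// and initK / initV seed the queues with the appropriate sentinel values.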
template <typename K,
          typename IndexType,
          bool Dir,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
__global__ void blockSelect(Tensor<K, 2, true> in,
                            Tensor<K, 2, true> outK,
                            Tensor<IndexType, 2, true> outV,
                            K initK,
                            IndexType initV,
                            int k) {
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __shared__ K smemK[kNumWarps * NumWarpQ];
  __shared__ IndexType smemV[kNumWarps * NumWarpQ];

  BlockSelect<K, IndexType, Dir, Comparator<K>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(initK, initV, smemK, smemV, k);

  // The grid is sized exactly to the number of rows available
  int row = blockIdx.x;

  int i = threadIdx.x;
  K* inStart = in[row][i].data();

  // Whole warps must participate in the selection
  int limit = utils::roundDown(in.getSize(1), kWarpSize);

  for (; i < limit; i += ThreadsPerBlock) {
    heap.add(*inStart, (IndexType) i);
    inStart += ThreadsPerBlock;
  }

  // Handle the remainder: the final partial warp of elements
  if (i < in.getSize(1)) {
    heap.addThreadQ(*inStart, (IndexType) i);
  }

  heap.reduce();

  for (int i = threadIdx.x; i < k; i += ThreadsPerBlock) {
    outK[row][i] = smemK[i];
    outV[row][i] = smemV[i];
  }
}

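// A direct launch of blockSelect would look roughly like the sketch below.
// This is illustrative only: the actual dispatch lives in the corresponding
// .cu implementation files, which pick NumWarpQ / NumThreadQ / block size
// based on k. The queue sizes and kThreadsPerBlock used here are assumed
// values for the sketch, not prescribed by this header.
//
//   constexpr int kThreadsPerBlock = 128;
//   auto grid = dim3(in.getSize(0));     // one block per input row
//   auto block = dim3(kThreadsPerBlock);
//
//   // Select the k smallest floats per row (dir = false); the heap is
//   // seeded with +max so any real value displaces the sentinel.
//   blockSelect<float, int, false, 32, 2, kThreadsPerBlock>
//       <<<grid, block, 0, stream>>>(
//           in, outK, outV, std::numeric_limits<float>::max(), -1, k);
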
// blockSelectPair: as blockSelect, but selects over caller-provided
// (key, index) pairs from inK / inV rather than synthesizing column indices.
template <typename K,
          typename IndexType,
          bool Dir,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
__global__ void blockSelectPair(Tensor<K, 2, true> inK,
                                Tensor<IndexType, 2, true> inV,
                                Tensor<K, 2, true> outK,
                                Tensor<IndexType, 2, true> outV,
                                K initK,
                                IndexType initV,
                                int k) {
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __shared__ K smemK[kNumWarps * NumWarpQ];
  __shared__ IndexType smemV[kNumWarps * NumWarpQ];

  BlockSelect<K, IndexType, Dir, Comparator<K>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(initK, initV, smemK, smemV, k);

  // The grid is sized exactly to the number of rows available
  int row = blockIdx.x;

  int i = threadIdx.x;
  K* inKStart = inK[row][i].data();
  IndexType* inVStart = inV[row][i].data();

  // Whole warps must participate in the selection
  int limit = utils::roundDown(inK.getSize(1), kWarpSize);

  for (; i < limit; i += ThreadsPerBlock) {
    heap.add(*inKStart, *inVStart);
    inKStart += ThreadsPerBlock;
    inVStart += ThreadsPerBlock;
  }

  // Handle the remainder: the final partial warp of elements
  if (i < inK.getSize(1)) {
    heap.addThreadQ(*inKStart, *inVStart);
  }

  heap.reduce();

  for (int i = threadIdx.x; i < k; i += ThreadsPerBlock) {
    outK[row][i] = smemK[i];
    outV[row][i] = smemV[i];
  }
}

void runBlockSelect(Tensor<float, 2, true>& in,
                    Tensor<float, 2, true>& outKeys,
                    Tensor<int, 2, true>& outIndices,
                    bool dir, int k, cudaStream_t stream);

void runBlockSelectPair(Tensor<float, 2, true>& inKeys,
                        Tensor<int, 2, true>& inIndices,
                        Tensor<float, 2, true>& outKeys,
                        Tensor<int, 2, true>& outIndices,
                        bool dir, int k, cudaStream_t stream);
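
// Typical use of the float wrappers above (an illustrative sketch; the
// tensor names are placeholders): select the k smallest distances per
// query row, e.g. for nearest-neighbor search:
//
//   // distances:                [numQueries x numVectors] on the GPU
//   // outDistances, outIndices: [numQueries x k]
//   runBlockSelect(distances, outDistances, outIndices,
//                  false /* dir: ascending, i.e. k smallest */, k, stream);
//
// runBlockSelectPair behaves identically but consumes caller-provided
// (key, index) pairs via inKeys / inIndices.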

#ifdef FAISS_USE_FLOAT16
void runBlockSelect(Tensor<half, 2, true>& in,
                    Tensor<half, 2, true>& outKeys,
                    Tensor<int, 2, true>& outIndices,
                    bool dir, int k, cudaStream_t stream);

void runBlockSelectPair(Tensor<half, 2, true>& inKeys,
                        Tensor<int, 2, true>& inIndices,
                        Tensor<half, 2, true>& outKeys,
                        Tensor<int, 2, true>& outIndices,
                        bool dir, int k, cudaStream_t stream);
#endif

} } // namespace