Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
BlockSelectKernel.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 #pragma once
11 
12 #include "Float16.cuh"
13 #include "Select.cuh"
14 
15 namespace faiss { namespace gpu {
16 
// Per-row k-selection kernel built on BlockSelect.
//
// Launch layout: one thread block per row of `in` (row = blockIdx.x). Within
// a block, thread t scans columns t, t + ThreadsPerBlock, t + 2*ThreadsPerBlock,
// ... of that row.
// NOTE(review): assumes blockDim.x == ThreadsPerBlock -- confirm at the launch
// site.
//
// The host is expected to launch exactly in.getSize(0) blocks. Any surplus
// block returns immediately, so an oversized grid cannot index past the last
// row. The early return is safe with respect to barriers: blockIdx.x is
// uniform across the block, so either every thread of the block exits or none
// does.
//
// Shared memory: statically allocates kNumWarps * NumWarpQ keys and the same
// number of indices, which are handed to BlockSelect as its scratch/result
// storage.
//
// Template parameters:
//   K, IndexType     key type and index (payload) type.
//   Dir              selection direction forwarded to BlockSelect
//                    (presumably true = keep largest keys -- confirm in
//                    Select.cuh).
//   NumWarpQ,
//   NumThreadQ       BlockSelect queue sizes.
//   ThreadsPerBlock  block size; must be a multiple of kWarpSize for
//                    kNumWarps to cover every thread.
//
// Parameters:
//   in     input keys; only row blockIdx.x is read by a given block.
//   outK   row blockIdx.x receives the k selected keys.
//   outV   row blockIdx.x receives the column index (within `in`) of each
//          selected key.
//   initK  sentinel key used to initialize the selection queues.
//   initV  sentinel index paired with initK.
//   k      number of results written per row.
//          NOTE(review): k must not exceed the shared-memory capacity
//          (kNumWarps * NumWarpQ; presumably <= NumWarpQ in practice). This
//          kernel does not check it -- the caller is assumed to.
template <typename K,
          typename IndexType,
          bool Dir,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
__global__ void blockSelect(Tensor<K, 2, true> in,
                            Tensor<K, 2, true> outK,
                            Tensor<IndexType, 2, true> outV,
                            K initK,
                            IndexType initV,
                            int k) {
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __shared__ K smemK[kNumWarps * NumWarpQ];
  __shared__ IndexType smemV[kNumWarps * NumWarpQ];

  // The grid should be exactly sized to the rows available. Guard anyway:
  // this condition is uniform for the whole block, so no thread is left
  // waiting at a barrier inside BlockSelect.
  int row = blockIdx.x;
  if (row >= in.getSize(0)) {
    return;
  }

  BlockSelect<K, IndexType, Dir, Comparator<K>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(initK, initV, smemK, smemV, k);

  int col = threadIdx.x;
  // Only forms an address; it is dereferenced only when col is in range.
  K* inStart = in[row][col].data();

  // Whole warps must participate in heap.add(). Because the limit is a
  // multiple of kWarpSize, every warp is either entirely below it or entirely
  // at/above it on each iteration.
  int limit = utils::roundDown(in.getSize(1), kWarpSize);

  for (; col < limit; col += ThreadsPerBlock) {
    heap.add(*inStart, (IndexType) col);
    inStart += ThreadsPerBlock;
  }

  // Handle the last remainder fraction of a warp of elements. These are
  // inserted per thread (addThreadQ), since only part of a warp has data.
  if (col < in.getSize(1)) {
    heap.addThreadQ(*inStart, (IndexType) col);
  }

  // NOTE(review): assumes reduce() merges all warp queues, leaves the final
  // ordered results at the front of smemK/smemV, and synchronizes the block
  // before returning -- confirm in Select.cuh.
  heap.reduce();

  for (int slot = threadIdx.x; slot < k; slot += ThreadsPerBlock) {
    outK[row][slot] = smemK[slot];
    outV[row][slot] = smemV[slot];
  }
}
64 
// Host-side entry point for block-wise k-selection on float keys.
// For each row of `in`, the selected keys go to `outKeys` and their column
// indices to `outIndices`. `dir` chooses the selection direction and `k` the
// number of results per row. Work is presumably enqueued on `stream`.
// The definition is not in this header -- see the corresponding .cu file.
void runBlockSelect(Tensor<float, 2, true>& in,
                    Tensor<float, 2, true>& outKeys,
                    Tensor<int, 2, true>& outIndices,
                    bool dir,
                    int k,
                    cudaStream_t stream);
69 
#ifdef FAISS_USE_FLOAT16
// Half-precision overload of runBlockSelect. It is only declared when the
// build defines FAISS_USE_FLOAT16. Keys are `half`, while the selected
// indices are still reported as `int`.
void runBlockSelect(Tensor<half, 2, true>& in,
                    Tensor<half, 2, true>& outKeys,
                    Tensor<int, 2, true>& outIndices,
                    bool dir,
                    int k,
                    cudaStream_t stream);
#endif // FAISS_USE_FLOAT16
76 
77 } } // namespace