Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
WarpSelectKernel.cuh
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include "Float16.cuh"
11 #include "Select.cuh"
12 
13 namespace faiss { namespace gpu {
14 
/// k-selection kernel: each warp independently selects the k best
/// (direction given by Dir) elements of one row of `in`, writing the
/// selected keys to outK[row] and their column indices to outV[row].
///
/// Launch layout: ThreadsPerBlock threads per block, i.e.
/// ThreadsPerBlock / kWarpSize rows handled per block; the grid covers
/// ceil(numRows / warpsPerBlock) blocks, and surplus warps exit early.
/// initK/initV seed the selection structure (the identity for the
/// comparison direction).
template <typename K,
          typename IndexType,
          bool Dir,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
__global__ void warpSelect(Tensor<K, 2, true> in,
                           Tensor<K, 2, true> outK,
                           Tensor<IndexType, 2, true> outV,
                           K initK,
                           IndexType initV,
                           int k) {
  // How many independent row-processing warps live in this block.
  constexpr int kWarpsPerBlock = ThreadsPerBlock / kWarpSize;

  WarpSelect<K, IndexType, Dir, Comparator<K>,
             NumWarpQ, NumThreadQ, ThreadsPerBlock>
      heap(initK, initV, k);

  int warpInBlock = threadIdx.x / kWarpSize;
  int row = blockIdx.x * kWarpsPerBlock + warpInBlock;

  // The grid may contain more warps than rows; whole surplus warps
  // leave together, so no warp-collective call below is divergent.
  if (row >= in.getSize(0)) {
    return;
  }

  int col = getLaneId();
  K* cur = in[row][col].data();

  // heap.add() is a whole-warp operation, so the strided main loop only
  // covers the largest warp-multiple prefix of the row.
  int wholeWarps = utils::roundDown(in.getSize(1), kWarpSize);

  for (; col < wholeWarps; col += kWarpSize) {
    heap.add(*cur, (IndexType) col);
    cur += kWarpSize;
  }

  // Tail of the row: fewer than kWarpSize elements remain, so only the
  // lanes that still have data insert, via the per-thread queue path.
  if (col < in.getSize(1)) {
    heap.addThreadQ(*cur, (IndexType) col);
  }

  heap.reduce();
  heap.writeOut(outK[row].data(),
                outV[row].data(), k);
}
60 
/// Host entry point: runs warp-based k-selection over each row of `in`
/// (float32 keys), producing the selected keys in `outKeys` and their
/// source column indices in `outIndices`. `dir` picks the comparison
/// direction; work is enqueued on `stream`. Defined in the matching .cu.
void runWarpSelect(Tensor<float, 2, true>& in,
                   Tensor<float, 2, true>& outKeys,
                   Tensor<int, 2, true>& outIndices,
                   bool dir, int k, cudaStream_t stream);

#ifdef FAISS_USE_FLOAT16
/// Half-precision (float16) overload of the above; only available when
/// the library is built with FAISS_USE_FLOAT16.
void runWarpSelect(Tensor<half, 2, true>& in,
                   Tensor<half, 2, true>& outKeys,
                   Tensor<int, 2, true>& outIndices,
                   bool dir, int k, cudaStream_t stream);
#endif
72 
73 } } // namespace