Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
BlockSelectImpl.cuh
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include "../BlockSelectKernel.cuh"
11 #include "../Limits.cuh"
12 
// Declares (without defining) the two launcher functions that
// BLOCK_SELECT_IMPL defines for the same (TYPE, DIR, WARP_Q) triple:
//
//   runBlockSelect_<TYPE>_<DIR>_<WARP_Q>_
//   runBlockSelectPair_<TYPE>_<DIR>_<WARP_Q>_
//
// The first declaration is terminated inside the macro; the second is left
// without a trailing semicolon, so the use site must supply it.
#define BLOCK_SELECT_DECL(TYPE, DIR, WARP_Q) \
  void runBlockSelect_##TYPE##_##DIR##_##WARP_Q##_( \
      Tensor<TYPE, 2, true>& in, \
      Tensor<TYPE, 2, true>& outK, \
      Tensor<int, 2, true>& outV, \
      bool dir, \
      int k, \
      cudaStream_t stream); \
  \
  void runBlockSelectPair_##TYPE##_##DIR##_##WARP_Q##_( \
      Tensor<TYPE, 2, true>& inK, \
      Tensor<int, 2, true>& inV, \
      Tensor<TYPE, 2, true>& outK, \
      Tensor<int, 2, true>& outV, \
      bool dir, \
      int k, \
      cudaStream_t stream)
30 
// Defines the two host-side launchers for one (TYPE, DIR, WARP_Q) triple:
//
//   runBlockSelect_<TYPE>_<DIR>_<WARP_Q>_      -> launches blockSelect
//   runBlockSelectPair_<TYPE>_<DIR>_<WARP_Q>_  -> launches blockSelectPair
//
// Macro parameters:
//   TYPE     element type of the key tensors
//   DIR      compile-time direction flag; the runtime `dir` must equal it
//   WARP_Q   compile-time bound on `k` (asserted: k <= WARP_Q)
//   THREAD_Q forwarded unchanged as a kernel template argument
//
// Launch shape used by both launchers:
//   grid  = one block per entry along dimension 0 of the input tensor
//   block = 128 threads when WARP_Q <= 1024, otherwise 64
//   no dynamic shared memory; work is enqueued on `stream`
//
// The kernels receive an initial key of Limits<TYPE>::getMin() when `dir` is
// true and Limits<TYPE>::getMax() otherwise, and an initial value of -1.
//
// Preconditions are checked with FAISS_ASSERT. When the input has zero
// entries along dimension 0 the launchers return after validation without
// launching, because a grid dimension of 0 is an invalid launch
// configuration. CUDA_TEST_ERROR() is invoked after every launch.
#define BLOCK_SELECT_IMPL(TYPE, DIR, WARP_Q, THREAD_Q) \
  void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \
      Tensor<TYPE, 2, true>& in, \
      Tensor<TYPE, 2, true>& outK, \
      Tensor<int, 2, true>& outV, \
      bool dir, \
      int k, \
      cudaStream_t stream) { \
    /* Validate everything up front, before any launch state is built. */ \
    FAISS_ASSERT(in.getSize(0) == outK.getSize(0)); \
    FAISS_ASSERT(in.getSize(0) == outV.getSize(0)); \
    FAISS_ASSERT(outK.getSize(1) == k); \
    FAISS_ASSERT(outV.getSize(1) == k); \
    FAISS_ASSERT(k <= WARP_Q); \
    FAISS_ASSERT(dir == DIR); \
    \
    /* Empty batch: nothing to select, and a zero-sized grid cannot be */ \
    /* launched. */ \
    if (in.getSize(0) == 0) { \
      return; \
    } \
    \
    constexpr int kBlockSelectNumThreads = (WARP_Q <= 1024) ? 128 : 64; \
    auto grid = dim3(in.getSize(0)); \
    auto block = dim3(kBlockSelectNumThreads); \
    \
    auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax(); \
    auto vInit = -1; \
    \
    blockSelect<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads> \
        <<<grid, block, 0, stream>>>(in, outK, outV, kInit, vInit, k); \
    CUDA_TEST_ERROR(); \
  } \
  \
  void runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \
      Tensor<TYPE, 2, true>& inK, \
      Tensor<int, 2, true>& inV, \
      Tensor<TYPE, 2, true>& outK, \
      Tensor<int, 2, true>& outV, \
      bool dir, \
      int k, \
      cudaStream_t stream) { \
    FAISS_ASSERT(inK.isSameSize(inV)); \
    FAISS_ASSERT(outK.isSameSize(outV)); \
    /* One block is launched per entry of inK along dimension 0 and the */ \
    /* kernel is asked for k results. NOTE(review): presumably each block */ \
    /* writes k entries into its own output row (kernel not visible */ \
    /* here - confirm); require the outputs to be at least that large. */ \
    FAISS_ASSERT(inK.getSize(0) <= outK.getSize(0)); \
    FAISS_ASSERT(k <= outK.getSize(1)); \
    FAISS_ASSERT(k <= WARP_Q); \
    FAISS_ASSERT(dir == DIR); \
    \
    /* Empty batch: nothing to select, and a zero-sized grid cannot be */ \
    /* launched. */ \
    if (inK.getSize(0) == 0) { \
      return; \
    } \
    \
    constexpr int kBlockSelectNumThreads = (WARP_Q <= 1024) ? 128 : 64; \
    auto grid = dim3(inK.getSize(0)); \
    auto block = dim3(kBlockSelectNumThreads); \
    \
    auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax(); \
    auto vInit = -1; \
    \
    blockSelectPair<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads> \
        <<<grid, block, 0, stream>>>(inK, inV, outK, outV, kInit, vInit, k); \
    CUDA_TEST_ERROR(); \
  }
86 
87 
// Expands to a call of runBlockSelect_<TYPE>_<DIR>_<WARP_Q>_. The arguments
// are not macro parameters: identifiers named `in`, `outK`, `outV`, `dir`,
// `k` and `stream` must be in scope where the macro is expanded. No trailing
// semicolon is emitted.
#define BLOCK_SELECT_CALL(TYPE, DIR, WARP_Q) \
  runBlockSelect_##TYPE##_##DIR##_##WARP_Q##_(in, outK, outV, dir, k, stream)
91 
// Expands to a call of runBlockSelectPair_<TYPE>_<DIR>_<WARP_Q>_. The
// arguments are not macro parameters: identifiers named `inK`, `inV`, `outK`,
// `outV`, `dir`, `k` and `stream` must be in scope where the macro is
// expanded. No trailing semicolon is emitted.
#define BLOCK_SELECT_PAIR_CALL(TYPE, DIR, WARP_Q) \
  runBlockSelectPair_##TYPE##_##DIR##_##WARP_Q##_( \
      inK, inV, outK, outV, dir, k, stream)