Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
BlockSelectImpl.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include "../BlockSelectKernel.cuh"
12 #include "../Limits.cuh"
13 
// Declares (does not define) the two host-side launchers for one
// (TYPE, DIR, WARP_Q) instantiation of block select:
//   runBlockSelect_<TYPE>_<DIR>_<WARP_Q>_     -- single input tensor `in`
//   runBlockSelectPair_<TYPE>_<DIR>_<WARP_Q>_ -- key/value input (inK, inV)
// Their definitions are generated by BLOCK_SELECT_IMPL with the same
// (TYPE, DIR, WARP_Q) arguments. The second declaration is not terminated;
// the use site supplies the final `;`, e.g.
//   BLOCK_SELECT_DECL(float, true, 32);
#define BLOCK_SELECT_DECL(TYPE, DIR, WARP_Q)                                 \
  extern void runBlockSelect_##TYPE##_##DIR##_##WARP_Q##_(                   \
      Tensor<TYPE, 2, true>& in, Tensor<TYPE, 2, true>& outK,                \
      Tensor<int, 2, true>& outV, bool dir, int k, cudaStream_t stream);     \
                                                                             \
  extern void runBlockSelectPair_##TYPE##_##DIR##_##WARP_Q##_(               \
      Tensor<TYPE, 2, true>& inK, Tensor<int, 2, true>& inV,                 \
      Tensor<TYPE, 2, true>& outK, Tensor<int, 2, true>& outV,               \
      bool dir, int k, cudaStream_t stream)
31 
// Defines the two host-side launchers for one (TYPE, DIR, WARP_Q)
// instantiation of block select. THREAD_Q is forwarded to the kernels as a
// template argument.
//
//   runBlockSelect_<TYPE>_<DIR>_<WARP_Q>_(in, outK, outV, dir, k, stream)
//     Launches blockSelect over `in`. The grid has one thread block per row
//     of `in`. outK / outV must have the same row count as `in` and exactly
//     k columns (presumably the k selected keys per row and their int
//     indices -- the selection semantics live in the kernel, not here).
//
//   runBlockSelectPair_<TYPE>_<DIR>_<WARP_Q>_(inK, inV, outK, outV, dir, k,
//                                             stream)
//     Launches blockSelectPair over the (inK, inV) pairs. The grid has one
//     thread block per row of inK. inK/inV must match in size, outK/outV
//     must match in size, and the outputs must have one row per input row
//     and exactly k columns.
//
// Both launchers:
//   - FAISS_ASSERT that k <= WARP_Q and that the runtime `dir` equals the
//     compile-time DIR of this instantiation;
//   - use 128 threads per block, no dynamic shared memory, and `stream`;
//   - seed keys with Limits<TYPE>::getMin() when dir is true, otherwise
//     Limits<TYPE>::getMax(), and seed values with -1;
//   - return without launching when the input has zero rows, because a
//     zero-sized grid is an invalid launch configuration;
//   - check for launch errors with CUDA_TEST_ERROR() after the launch.
//
// Only block comments are used inside the macro body. With backslash line
// continuation, a `//` comment would swallow every following line.
#define BLOCK_SELECT_IMPL(TYPE, DIR, WARP_Q, THREAD_Q)                       \
  void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(              \
    Tensor<TYPE, 2, true>& in,                                               \
    Tensor<TYPE, 2, true>& outK,                                             \
    Tensor<int, 2, true>& outV,                                              \
    bool dir,                                                                \
    int k,                                                                   \
    cudaStream_t stream) {                                                   \
    FAISS_ASSERT(in.getSize(0) == outK.getSize(0));                          \
    FAISS_ASSERT(in.getSize(0) == outV.getSize(0));                          \
    FAISS_ASSERT(outK.getSize(1) == k);                                      \
    FAISS_ASSERT(outV.getSize(1) == k);                                      \
    FAISS_ASSERT(k <= WARP_Q);                                               \
    FAISS_ASSERT(dir == DIR);                                                \
                                                                             \
    /* A zero-sized grid is an invalid launch configuration. */              \
    if (in.getSize(0) == 0) {                                                \
      return;                                                                \
    }                                                                        \
                                                                             \
    auto grid = dim3(in.getSize(0));                                         \
                                                                             \
    constexpr int kBlockSelectNumThreads = 128;                              \
    auto block = dim3(kBlockSelectNumThreads);                               \
                                                                             \
    auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax();      \
    auto vInit = -1;                                                         \
                                                                             \
    blockSelect<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads>    \
      <<<grid, block, 0, stream>>>(in, outK, outV, kInit, vInit, k);         \
    CUDA_TEST_ERROR();                                                       \
  }                                                                          \
                                                                             \
  void runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(          \
    Tensor<TYPE, 2, true>& inK,                                              \
    Tensor<int, 2, true>& inV,                                               \
    Tensor<TYPE, 2, true>& outK,                                             \
    Tensor<int, 2, true>& outV,                                              \
    bool dir,                                                                \
    int k,                                                                   \
    cudaStream_t stream) {                                                   \
    FAISS_ASSERT(inK.isSameSize(inV));                                       \
    FAISS_ASSERT(outK.isSameSize(outV));                                     \
    /* Same output-shape checks as runBlockSelect: the grid is sized from */ \
    /* inK, so the outputs need one row per input row and k columns. */      \
    FAISS_ASSERT(inK.getSize(0) == outK.getSize(0));                         \
    FAISS_ASSERT(outK.getSize(1) == k);                                      \
    FAISS_ASSERT(k <= WARP_Q);                                               \
    FAISS_ASSERT(dir == DIR);                                                \
                                                                             \
    /* A zero-sized grid is an invalid launch configuration. */              \
    if (inK.getSize(0) == 0) {                                               \
      return;                                                                \
    }                                                                        \
                                                                             \
    auto grid = dim3(inK.getSize(0));                                        \
                                                                             \
    constexpr int kBlockSelectNumThreads = 128;                              \
    auto block = dim3(kBlockSelectNumThreads);                               \
                                                                             \
    auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax();      \
    auto vInit = -1;                                                         \
                                                                             \
    blockSelectPair<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads> \
      <<<grid, block, 0, stream>>>(inK, inV, outK, outV, kInit, vInit, k);   \
    CUDA_TEST_ERROR();                                                       \
  }
87 
88 
// Expands to a call of the runBlockSelect_<TYPE>_<DIR>_<WARP_Q>_ launcher
// (declared by BLOCK_SELECT_DECL). The call arguments are NOT macro
// parameters: the expansion names `in`, `outK`, `outV`, `dir`, `k` and
// `stream` directly, so it binds to whatever variables with those names are
// in scope at the use site. No trailing `;` is emitted.
#define BLOCK_SELECT_CALL(TYPE, DIR, WARP_Q)                                 \
  runBlockSelect_##TYPE##_##DIR##_##WARP_Q##_(in, outK, outV, dir, k, stream)
92 
// Expands to a call of the runBlockSelectPair_<TYPE>_<DIR>_<WARP_Q>_ launcher
// (declared by BLOCK_SELECT_DECL). The call arguments are NOT macro
// parameters: the expansion names `inK`, `inV`, `outK`, `outV`, `dir`, `k`
// and `stream` directly, so it binds to whatever variables with those names
// are in scope at the use site. No trailing `;` is emitted.
#define BLOCK_SELECT_PAIR_CALL(TYPE, DIR, WARP_Q)                            \
  runBlockSelectPair_##TYPE##_##DIR##_##WARP_Q##_(                           \
      inK, inV, outK, outV, dir, k, stream)