Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
BlockSelectImpl.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 #pragma once
11 
12 #include "../BlockSelectKernel.cuh"
13 #include "../Limits.cuh"
14 
// Declares (without defining) the pair of launchers that BLOCK_SELECT_IMPL
// defines for one (TYPE, DIR, WARP_Q) combination:
//   runBlockSelect_<TYPE>_<DIR>_<WARP_Q>_      reads `in`, writes outK/outV
//   runBlockSelectPair_<TYPE>_<DIR>_<WARP_Q>_  reads inK/inV, writes outK/outV
// The expansion ends without a semicolon; the use site supplies it, e.g.
//   BLOCK_SELECT_DECL(float, true, 32);
#define BLOCK_SELECT_DECL(TYPE, DIR, WARP_Q)                                 \
  extern void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(       \
      Tensor<TYPE, 2, true>& in, Tensor<TYPE, 2, true>& outK,                \
      Tensor<int, 2, true>& outV, bool dir, int k, cudaStream_t stream);     \
  extern void runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(   \
      Tensor<TYPE, 2, true>& inK, Tensor<int, 2, true>& inV,                 \
      Tensor<TYPE, 2, true>& outK, Tensor<int, 2, true>& outV,               \
      bool dir, int k, cudaStream_t stream)
32 
// Defines the two host-side launchers declared by BLOCK_SELECT_DECL for a
// fixed (TYPE, DIR, WARP_Q), using THREAD_Q as the kernels' per-thread queue
// length template argument.
//
// Both launchers enqueue one 128-thread block per input row (rows =
// getSize(0) of the input) on `stream`, then call CUDA_TEST_ERROR():
//   runBlockSelect_*      launches blockSelect     on `in`        -> outK, outV
//   runBlockSelectPair_*  launches blockSelectPair on `inK`/`inV` -> outK, outV
//
// Preconditions, checked with FAISS_ASSERT:
//   - outK/outV have the same row count as the input and exactly k columns;
//   - (pair only) inK/inV have identical shapes;
//   - k <= WARP_Q and dir == DIR, since this instantiation's queue size and
//     direction are compile-time constants.
//
// kInit is Limits<TYPE>::getMin() when dir is true and getMax() otherwise;
// vInit is -1. Both are forwarded to the kernel (presumably as queue
// sentinels; NOTE(review): confirm against BlockSelectKernel.cuh).
//
// An input with zero rows is a no-op, because a zero-sized grid is an invalid
// launch configuration that would otherwise make CUDA_TEST_ERROR fail.
//
// Comments inside the macro body use /* */ on purpose: a // comment would
// swallow the next line once the trailing backslashes are spliced.
#define BLOCK_SELECT_IMPL(TYPE, DIR, WARP_Q, THREAD_Q)                        \
  void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(               \
    Tensor<TYPE, 2, true>& in,                                                \
    Tensor<TYPE, 2, true>& outK,                                              \
    Tensor<int, 2, true>& outV,                                               \
    bool dir,                                                                 \
    int k,                                                                    \
    cudaStream_t stream) {                                                    \
    /* One output row of k entries per input row. */                          \
    FAISS_ASSERT(in.getSize(0) == outK.getSize(0));                           \
    FAISS_ASSERT(in.getSize(0) == outV.getSize(0));                           \
    FAISS_ASSERT(outK.getSize(1) == k);                                       \
    FAISS_ASSERT(outV.getSize(1) == k);                                       \
                                                                              \
    /* This instantiation only supports its compile-time parameters. */      \
    FAISS_ASSERT(k <= WARP_Q);                                                \
    FAISS_ASSERT(dir == DIR);                                                 \
                                                                              \
    /* A zero-sized grid is an invalid launch; nothing to select. */          \
    if (in.getSize(0) == 0) {                                                 \
      return;                                                                 \
    }                                                                         \
                                                                              \
    constexpr int kBlockSelectNumThreads = 128;                               \
    auto grid = dim3(in.getSize(0)); /* one block per input row */            \
    auto block = dim3(kBlockSelectNumThreads);                                \
                                                                              \
    auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax();       \
    auto vInit = -1;                                                          \
                                                                              \
    blockSelect<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads>     \
      <<<grid, block, 0, stream>>>(in, outK, outV, kInit, vInit, k);          \
    CUDA_TEST_ERROR();                                                        \
  }                                                                           \
                                                                              \
  void runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(           \
    Tensor<TYPE, 2, true>& inK,                                               \
    Tensor<int, 2, true>& inV,                                                \
    Tensor<TYPE, 2, true>& outK,                                              \
    Tensor<int, 2, true>& outV,                                               \
    bool dir,                                                                 \
    int k,                                                                    \
    cudaStream_t stream) {                                                    \
    /* Keys and values travel together, on input and on output. */           \
    FAISS_ASSERT(inK.isSameSize(inV));                                        \
    FAISS_ASSERT(outK.isSameSize(outV));                                      \
    /* Same output contract as runBlockSelect_*: one row of k per input */    \
    /* row (outV is covered by isSameSize above). */                          \
    FAISS_ASSERT(inK.getSize(0) == outK.getSize(0));                          \
    FAISS_ASSERT(outK.getSize(1) == k);                                       \
                                                                              \
    /* This instantiation only supports its compile-time parameters. */      \
    FAISS_ASSERT(k <= WARP_Q);                                                \
    FAISS_ASSERT(dir == DIR);                                                 \
                                                                              \
    /* A zero-sized grid is an invalid launch; nothing to select. */          \
    if (inK.getSize(0) == 0) {                                                \
      return;                                                                 \
    }                                                                         \
                                                                              \
    constexpr int kBlockSelectNumThreads = 128;                               \
    auto grid = dim3(inK.getSize(0)); /* one block per input row */           \
    auto block = dim3(kBlockSelectNumThreads);                                \
                                                                              \
    auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax();       \
    auto vInit = -1;                                                          \
                                                                              \
    blockSelectPair<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads> \
      <<<grid, block, 0, stream>>>(inK, inV, outK, outV, kInit, vInit, k);    \
    CUDA_TEST_ERROR();                                                        \
  }
88 
89 
// Expands to a call of the runBlockSelect_<TYPE>_<DIR>_<WARP_Q>_ launcher.
// The arguments are not macro parameters: the names in, outK, outV, dir, k
// and stream are taken from the scope where the macro is used. No trailing
// semicolon is emitted.
#define BLOCK_SELECT_CALL(TYPE, DIR, WARP_Q) \
  runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(in, outK, outV, dir, k, stream)
93 
// Expands to a call of the runBlockSelectPair_<TYPE>_<DIR>_<WARP_Q>_ launcher.
// The arguments are not macro parameters: the names inK, inV, outK, outV,
// dir, k and stream are taken from the scope where the macro is used. No
// trailing semicolon is emitted.
#define BLOCK_SELECT_PAIR_CALL(TYPE, DIR, WARP_Q) \
  runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(inK, inV, outK, outV, dir, k, stream)