Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
WarpSelectImpl.cuh
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #include "../WarpSelectKernel.cuh"
9 #include "../Limits.cuh"
10 
/**
 * Declares, without defining, the host function
 * runWarpSelect_<TYPE>_<DIR>_<WARP_Q>_ whose name is built by pasting the
 * macro arguments together. The signature matches the definition emitted
 * by WARP_SELECT_IMPL for the same (TYPE, DIR, WARP_Q).
 *
 * The expansion has no trailing semicolon; the user of the macro supplies it.
 */
#define WARP_SELECT_DECL(TYPE, DIR, WARP_Q)                          \
    extern void runWarpSelect_##TYPE##_##DIR##_##WARP_Q##_(          \
            Tensor<TYPE, 2, true>& in,                               \
            Tensor<TYPE, 2, true>& outK,                             \
            Tensor<int, 2, true>& outV,                              \
            bool dir,                                                \
            int k,                                                   \
            cudaStream_t stream)
19 
/**
 * Defines the host function runWarpSelect_<TYPE>_<DIR>_<WARP_Q>_, which
 * launches the warpSelect kernel instantiated as
 * warpSelect<TYPE, int, DIR, WARP_Q, THREAD_Q, 128>.
 *
 * Parameters of the generated function:
 *   in     - 2-d input tensor of TYPE; its size along dimension 0 determines
 *            the grid size (presumably one warp handles one row -- confirm
 *            against the warpSelect kernel).
 *   outK   - 2-d output tensor of TYPE, passed through to the kernel.
 *   outV   - 2-d output tensor of int, passed through to the kernel.
 *   dir    - must equal DIR (asserted); chooses the initial key value:
 *            Limits<TYPE>::getMin() when true, Limits<TYPE>::getMax()
 *            when false.
 *   k      - must not exceed WARP_Q (asserted); passed through to the kernel.
 *   stream - CUDA stream on which the kernel is launched.
 *
 * The arguments are validated first. If `in` has no rows the function then
 * returns without launching anything, because a grid dimension of zero is
 * not a valid launch configuration.
 */
#define WARP_SELECT_IMPL(TYPE, DIR, WARP_Q, THREAD_Q)                         \
    void runWarpSelect_##TYPE##_##DIR##_##WARP_Q##_(                          \
            Tensor<TYPE, 2, true>& in,                                        \
            Tensor<TYPE, 2, true>& outK,                                      \
            Tensor<int, 2, true>& outV,                                       \
            bool dir,                                                         \
            int k,                                                            \
            cudaStream_t stream) {                                            \
        /* Check the arguments before anything else so that misuse is */      \
        /* reported even for an empty input. */                               \
        FAISS_ASSERT(k <= WARP_Q);                                            \
        FAISS_ASSERT(dir == DIR);                                             \
                                                                              \
        /* Zero rows would produce a grid dimension of 0, which the CUDA */   \
        /* runtime rejects as an invalid configuration; nothing to do. */     \
        if (in.getSize(0) == 0) {                                             \
            return;                                                           \
        }                                                                     \
                                                                              \
        constexpr int kWarpSelectNumThreads = 128;                            \
        /* Blocks needed when each block covers */                            \
        /* (kWarpSelectNumThreads / kWarpSize) rows of `in`. */               \
        auto grid = dim3(utils::divUp(                                        \
                in.getSize(0), (kWarpSelectNumThreads / kWarpSize)));         \
        auto block = dim3(kWarpSelectNumThreads);                             \
                                                                              \
        /* Initial key/value handed to the kernel. */                         \
        auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax();   \
        auto vInit = -1;                                                      \
                                                                              \
        warpSelect<TYPE, int, DIR, WARP_Q, THREAD_Q, kWarpSelectNumThreads>   \
                <<<grid, block, 0, stream>>>(                                 \
                        in, outK, outV, kInit, vInit, k);                     \
        CUDA_TEST_ERROR();                                                    \
    }
44 
/**
 * Expands to a call of runWarpSelect_<TYPE>_<DIR>_<WARP_Q>_.
 *
 * The arguments are not macro parameters: the identifiers `in`, `outK`,
 * `outV`, `dir`, `k` and `stream` are taken verbatim from the scope in
 * which the macro is expanded, so they must exist there under exactly
 * those names. The expansion has no trailing semicolon.
 */
#define WARP_SELECT_CALL(TYPE, DIR, WARP_Q)                          \
    runWarpSelect_##TYPE##_##DIR##_##WARP_Q##_(                      \
            in, outK, outV, dir, k, stream)