Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
WarpSelectImpl.cuh
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved.
11 #include "../WarpSelectKernel.cuh"
12 #include "../Limits.cuh"
13 
// Declares, without defining, the host-side launcher named
// runWarpSelect_<TYPE>_<DIR>_<WARP_Q>_ so that it can be called from a
// translation unit other than the one that defines it.
//
// The generated prototype takes the input tensor, the two output tensors
// (keys of TYPE, values of int), the runtime direction flag, k, and the
// stream. The macro emits no trailing semicolon; the user supplies it.
#define WARP_SELECT_DECL(TYPE, DIR, WARP_Q)                       \
  extern void runWarpSelect_##TYPE##_##DIR##_##WARP_Q##_(         \
      Tensor<TYPE, 2, true>& in,                                  \
      Tensor<TYPE, 2, true>& outK,                                \
      Tensor<int, 2, true>& outV,                                 \
      bool dir,                                                   \
      int k,                                                      \
      cudaStream_t stream)
22 
// Defines the host-side launcher runWarpSelect_<TYPE>_<DIR>_<WARP_Q>_, which
// launches the warpSelect kernel on `stream`.
//
// Macro parameters:
//   TYPE     - element type of `in` and `outK`.
//   DIR      - compile-time direction flag; the runtime `dir` argument is
//              asserted to be equal to it.
//   WARP_Q   - forwarded to warpSelect; `k` is asserted to be <= WARP_Q.
//   THREAD_Q - forwarded to warpSelect as a template argument.
//
// Launch configuration: 128 threads per block, and
// divUp(in.getSize(0), 128 / kWarpSize) blocks, so the grid holds at least
// one warp per index along dimension 0 of `in`.
//
// The key initializer handed to the kernel is Limits<TYPE>::getMin() when
// `dir` is true and Limits<TYPE>::getMax() otherwise; the value initializer
// is -1.
//
// If dimension 0 of `in` is empty the function returns without launching,
// because a grid of zero blocks is not a valid launch configuration. After a
// launch, the CUDA error state is checked so that a rejected launch is
// reported here instead of surfacing at some later, unrelated CUDA call.
#define WARP_SELECT_IMPL(TYPE, DIR, WARP_Q, THREAD_Q)                       \
  void runWarpSelect_##TYPE##_##DIR##_##WARP_Q##_(                          \
      Tensor<TYPE, 2, true>& in,                                            \
      Tensor<TYPE, 2, true>& outK,                                          \
      Tensor<int, 2, true>& outV,                                           \
      bool dir,                                                             \
      int k,                                                                \
      cudaStream_t stream) {                                                \
                                                                            \
    /* Argument checks run even for empty input. */                         \
    FAISS_ASSERT(k <= WARP_Q);                                              \
    FAISS_ASSERT(dir == DIR);                                               \
                                                                            \
    /* Nothing to select from; avoid launching a zero-block grid. */        \
    if (in.getSize(0) == 0) {                                               \
      return;                                                               \
    }                                                                       \
                                                                            \
    constexpr int kWarpSelectNumThreads = 128;                              \
    auto grid = dim3(utils::divUp(in.getSize(0),                            \
                                  (kWarpSelectNumThreads / kWarpSize)));    \
    auto block = dim3(kWarpSelectNumThreads);                               \
                                                                            \
    auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax();     \
    auto vInit = -1;                                                        \
                                                                            \
    warpSelect<TYPE, int, DIR, WARP_Q, THREAD_Q, kWarpSelectNumThreads>     \
      <<<grid, block, 0, stream>>>(in, outK, outV, kInit, vInit, k);        \
                                                                            \
    /* Kernel launches return no status; query the runtime explicitly. */   \
    FAISS_ASSERT(cudaGetLastError() == cudaSuccess);                        \
  }
46 
// Expands to a call of the launcher runWarpSelect_<TYPE>_<DIR>_<WARP_Q>_.
//
// The call arguments are not macro parameters: the identifiers `in`, `outK`,
// `outV`, `dir`, `k` and `stream` are taken verbatim from the scope in which
// the macro is expanded, so they must all be visible there under exactly
// those names. No trailing semicolon is emitted.
#define WARP_SELECT_CALL(TYPE, DIR, WARP_Q)                       \
  runWarpSelect_##TYPE##_##DIR##_##WARP_Q##_(                     \
      in, outK, outV, dir, k, stream)