Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
BlockSelectImpl.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the CC-by-NC license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 #include "../BlockSelectKernel.cuh"
11 #include "../Limits.cuh"
12 
// Declares (does not define) the host-side launcher
//   runBlockSelect_<ELEM_T>_<SELECT_DIR>_<QUEUE_LEN>_
// with the same name and parameter list that BLOCK_SELECT_IMPL defines, so a
// translation unit can call an instantiation whose definition presumably lives
// in a separate .cu file (confirm against the build).
//
//   ELEM_T     - element type of the input/output key tensors.
//   SELECT_DIR - direction token pasted into the name (e.g. true / false).
//   QUEUE_LEN  - warp queue length token pasted into the name.
//
// The expansion carries no trailing semicolon; the use site supplies it.
#define BLOCK_SELECT_DECL(ELEM_T, SELECT_DIR, QUEUE_LEN)                     \
  extern void                                                                \
  runBlockSelect_ ## ELEM_T ## _ ## SELECT_DIR ## _ ## QUEUE_LEN ## _(       \
    Tensor<ELEM_T, 2, true>& in, Tensor<ELEM_T, 2, true>& outK,              \
    Tensor<int, 2, true>& outV, bool dir, int k, cudaStream_t stream)
21 
// Defines the host-side launcher
//   runBlockSelect_<TYPE>_<DIR>_<WARP_Q>_
// which validates its arguments and launches the blockSelect kernel
// asynchronously on `stream`.
//
// Macro parameters:
//   TYPE     - element type of `in` and `outK`; also used for Limits<TYPE>.
//   DIR      - compile-time direction; the runtime `dir` must equal it.
//   WARP_Q   - largest k this instantiation accepts (asserted below).
//   THREAD_Q - forwarded unchanged to blockSelect as a template argument.
//
// Generated function parameters:
//   in     - 2-D input; one thread block is launched per index of dimension 0
//            (i.e. per row), with 128 threads per block.
//   outK   - 2-D key output; must have in.getSize(0) rows and k columns.
//   outV   - 2-D int output; must have in.getSize(0) rows and k columns.
//   dir    - must equal DIR. When true, the initial key passed to the kernel
//            is Limits<TYPE>::getMin(), otherwise Limits<TYPE>::getMax()
//            (presumably true = select largest; confirm against blockSelect).
//   k      - number of values to select per row; must be <= WARP_Q.
//   stream - CUDA stream the kernel is enqueued on.
//
// An input with zero rows is a no-op. Launching with a zero-sized grid would
// be an invalid CUDA launch configuration.
//
// Only C-style comments are used inside the macro body below, because a `//`
// comment would absorb the trailing line-continuation backslash.
#define BLOCK_SELECT_IMPL(TYPE, DIR, WARP_Q, THREAD_Q)                       \
  void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(              \
    Tensor<TYPE, 2, true>& in,                                               \
    Tensor<TYPE, 2, true>& outK,                                             \
    Tensor<int, 2, true>& outV,                                              \
    bool dir,                                                                \
    int k,                                                                   \
    cudaStream_t stream) {                                                   \
    /* Outputs need one row per input row and exactly k columns. */          \
    FAISS_ASSERT(in.getSize(0) == outK.getSize(0));                          \
    FAISS_ASSERT(in.getSize(0) == outV.getSize(0));                          \
    FAISS_ASSERT(outK.getSize(1) == k);                                      \
    FAISS_ASSERT(outV.getSize(1) == k);                                      \
                                                                             \
    /* This instantiation only handles k <= WARP_Q in direction DIR. */      \
    FAISS_ASSERT(k <= WARP_Q);                                               \
    FAISS_ASSERT(dir == DIR);                                                \
                                                                             \
    /* Nothing to select; a zero-sized grid is an invalid launch. */         \
    if (in.getSize(0) == 0) {                                                \
      return;                                                                \
    }                                                                        \
                                                                             \
    auto grid = dim3(in.getSize(0));                                         \
                                                                             \
    constexpr int kBlockSelectNumThreads = 128;                              \
    auto block = dim3(kBlockSelectNumThreads);                               \
                                                                             \
    /* Initial key/index values handed to the kernel's queues. */            \
    auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax();      \
    auto vInit = -1;                                                         \
                                                                             \
    blockSelect<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads>    \
      <<<grid, block, 0, stream>>>(in, outK, outV, kInit, vInit, k);         \
    CUDA_TEST_ERROR();                                                       \
  }
45 
// Expands to a call of runBlockSelect_<ELEM_T>_<SELECT_DIR>_<QUEUE_LEN>_.
// The arguments are the identifiers in, outK, outV, dir, k and stream, which
// must all be in scope where the macro is used. The expansion carries no
// trailing semicolon.
#define BLOCK_SELECT_CALL(ELEM_T, SELECT_DIR, QUEUE_LEN)                     \
  runBlockSelect_ ## ELEM_T ## _ ## SELECT_DIR ## _ ## QUEUE_LEN ## _(       \
    in,                                                                      \
    outK,                                                                    \
    outV,                                                                    \
    dir,                                                                     \
    k,                                                                       \
    stream)