10 #include "../BlockSelectKernel.cuh"
11 #include "../Limits.cuh"
// BLOCK_SELECT_DECL(TYPE, DIR, WARP_Q)
//
// Emits `extern void` declarations for two host functions whose names are
// built by token pasting from the macro arguments:
//   runBlockSelect_<TYPE>_<DIR>_<WARP_Q>_
//   runBlockSelectPair_<TYPE>_<DIR>_<WARP_Q>_
//
// Parameters visible here:
//   runBlockSelect_*    : in, outK (Tensor<TYPE, 2, true>&),
//                         outV (Tensor<int, 2, true>&), stream (cudaStream_t)
//   runBlockSelectPair_*: inK, outK (Tensor<TYPE, 2, true>&),
//                         inV, outV (Tensor<int, 2, true>&)
//
// NOTE(review): the embedded listing numbers skip (17 -> 20, 26 -> 31), so
// some lines appear to be missing from this text. BLOCK_SELECT_CALL passes
// `dir` and `k` between `outV` and `stream`, so presumably those are the
// absent parameters; the end of the second declaration is also not visible
// (the last line still ends in a continuation backslash). Confirm against
// the complete file before relying on these signatures.
13 #define BLOCK_SELECT_DECL(TYPE, DIR, WARP_Q) \
14 extern void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \
15 Tensor<TYPE, 2, true>& in, \
16 Tensor<TYPE, 2, true>& outK, \
17 Tensor<int, 2, true>& outV, \
20 cudaStream_t stream); \
22 extern void runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \
23 Tensor<TYPE, 2, true>& inK, \
24 Tensor<int, 2, true>& inV, \
25 Tensor<TYPE, 2, true>& outK, \
26 Tensor<int, 2, true>& outV, \
// BLOCK_SELECT_IMPL(TYPE, DIR, WARP_Q, THREAD_Q)
//
// Emits the definitions of runBlockSelect_<TYPE>_<DIR>_<WARP_Q>_ and
// runBlockSelectPair_<TYPE>_<DIR>_<WARP_Q>_ (names built by token pasting).
// THREAD_Q is only forwarded as a template argument to the kernels.
//
// runBlockSelect_*:
//   - asserts that in, outK and outV agree on getSize(0), that outK and outV
//     have getSize(1) == k, that k <= WARP_Q, and that dir == DIR;
//   - launches blockSelect<TYPE, int, DIR, WARP_Q, THREAD_Q, N> on `stream`
//     with in.getSize(0) blocks of N threads and 0 bytes of dynamic shared
//     memory, where N (kBlockSelectNumThreads) is 128 when WARP_Q <= 1024
//     and 64 otherwise;
//   - passes kInit = Limits<TYPE>::getMin() when dir is true, otherwise
//     Limits<TYPE>::getMax(), together with vInit and k.
//
// runBlockSelectPair_*:
//   - asserts inK.isSameSize(inV), outK.isSameSize(outV), k <= WARP_Q and
//     dir == DIR;
//   - launches blockSelectPair with the same template arguments and the same
//     launch shape (inK.getSize(0) blocks), passing inK, inV, outK, outV,
//     kInit, vInit and k.
//
// NOTE(review): `k`, `dir` and `vInit` are used below but no declaration of
// them is visible, and neither function's closing brace appears; the embedded
// listing numbers skip (35 -> 38, 52 -> 55, 56 -> 60, 64 -> 67, 79 -> 82,
// 83 -> 88), so those lines are presumably missing from this text rather than
// from the real file. Confirm against the complete file.
// NOTE(review): no launch-error check (e.g. cudaGetLastError) is visible after
// either <<<...>>> launch; it may sit on one of the missing lines. Verify.
31 #define BLOCK_SELECT_IMPL(TYPE, DIR, WARP_Q, THREAD_Q) \
32 void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \
33 Tensor<TYPE, 2, true>& in, \
34 Tensor<TYPE, 2, true>& outK, \
35 Tensor<int, 2, true>& outV, \
38 cudaStream_t stream) { \
39 FAISS_ASSERT(in.getSize(0) == outK.getSize(0)); \
40 FAISS_ASSERT(in.getSize(0) == outV.getSize(0)); \
41 FAISS_ASSERT(outK.getSize(1) == k); \
42 FAISS_ASSERT(outV.getSize(1) == k); \
44 auto grid = dim3(in.getSize(0)); \
46 constexpr int kBlockSelectNumThreads = (WARP_Q <= 1024) ? 128 : 64; \
47 auto block = dim3(kBlockSelectNumThreads); \
49 FAISS_ASSERT(k <= WARP_Q); \
50 FAISS_ASSERT(dir == DIR); \
52 auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax(); \
55 blockSelect<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads> \
56 <<<grid, block, 0, stream>>>(in, outK, outV, kInit, vInit, k); \
60 void runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \
61 Tensor<TYPE, 2, true>& inK, \
62 Tensor<int, 2, true>& inV, \
63 Tensor<TYPE, 2, true>& outK, \
64 Tensor<int, 2, true>& outV, \
67 cudaStream_t stream) { \
68 FAISS_ASSERT(inK.isSameSize(inV)); \
69 FAISS_ASSERT(outK.isSameSize(outV)); \
71 auto grid = dim3(inK.getSize(0)); \
73 constexpr int kBlockSelectNumThreads = (WARP_Q <= 1024) ? 128 : 64; \
74 auto block = dim3(kBlockSelectNumThreads); \
76 FAISS_ASSERT(k <= WARP_Q); \
77 FAISS_ASSERT(dir == DIR); \
79 auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax(); \
82 blockSelectPair<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads> \
83 <<<grid, block, 0, stream>>>(inK, inV, outK, outV, kInit, vInit, k); \
// BLOCK_SELECT_CALL(TYPE, DIR, WARP_Q)
//
// Expands to a call expression (no trailing semicolon) of
// runBlockSelect_<TYPE>_<DIR>_<WARP_Q>_. The arguments are the identifiers
// in, outK, outV, dir, k and stream, which must already be in scope wherever
// the macro is expanded.
#define BLOCK_SELECT_CALL(TYPE, DIR, WARP_Q)                              \
    runBlockSelect_##TYPE##_##DIR##_##WARP_Q##_(in, outK, outV, dir, k,   \
                                                stream)
// BLOCK_SELECT_PAIR_CALL(TYPE, DIR, WARP_Q)
//
// Expands to a call expression (no trailing semicolon) of
// runBlockSelectPair_<TYPE>_<DIR>_<WARP_Q>_. The arguments are the
// identifiers inK, inV, outK, outV, dir, k and stream, which must already be
// in scope wherever the macro is expanded.
#define BLOCK_SELECT_PAIR_CALL(TYPE, DIR, WARP_Q)                         \
    runBlockSelectPair_##TYPE##_##DIR##_##WARP_Q##_(inK, inV, outK, outV, \
                                                    dir, k, stream)