11 #include "../BlockSelectKernel.cuh"
12 #include "../Limits.cuh"
// Declares the two host-side launchers that BLOCK_SELECT_IMPL defines for one
// (TYPE, DIR, WARP_Q) combination:
//
//   runBlockSelect_<TYPE>_<DIR>_<WARP_Q>_(in, outK, outV, dir, k, stream)
//   runBlockSelectPair_<TYPE>_<DIR>_<WARP_Q>_(inK, inV, outK, outV,
//                                             dir, k, stream)
//
// The parameter lists mirror the argument lists that BLOCK_SELECT_CALL and
// BLOCK_SELECT_PAIR_CALL paste at their expansion sites, so the three macro
// families must be kept in sync.
//
// The macro already ends each prototype with `;`, so it can be used with or
// without a trailing semicolon at namespace scope.
#define BLOCK_SELECT_DECL(TYPE, DIR, WARP_Q)                            \
  extern void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(  \
    Tensor<TYPE, 2, true>& in,                                          \
    Tensor<TYPE, 2, true>& outK,                                        \
    Tensor<int, 2, true>& outV,                                         \
    bool dir,                                                           \
    int k,                                                              \
    cudaStream_t stream);                                               \
                                                                        \
  extern void runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \
    Tensor<TYPE, 2, true>& inK,                                         \
    Tensor<int, 2, true>& inV,                                          \
    Tensor<TYPE, 2, true>& outK,                                        \
    Tensor<int, 2, true>& outV,                                         \
    bool dir,                                                           \
    int k,                                                              \
    cudaStream_t stream);
// Defines the two host-side launchers for one (TYPE, DIR, WARP_Q)
// combination. THREAD_Q is forwarded unchanged to the kernels as a template
// argument.
//
//   runBlockSelect_<TYPE>_<DIR>_<WARP_Q>_(in, outK, outV, dir, k, stream)
//     Launches blockSelect on `stream` with one 128-thread block per entry
//     along dimension 0 of `in`. Requires outK/outV to have the same extent
//     as `in` along dimension 0 and extent `k` along dimension 1.
//
//   runBlockSelectPair_<TYPE>_<DIR>_<WARP_Q>_(inK, inV, outK, outV,
//                                             dir, k, stream)
//     Launches blockSelectPair on `stream` with one 128-thread block per
//     entry along dimension 0 of `inK`. Requires inK/inV to be the same size
//     and outK/outV to be the same size.
//
// Both launchers additionally require k <= WARP_Q and dir == DIR (the
// runtime flag must agree with the compile-time specialization). The initial
// key handed to the kernel is Limits<TYPE>::getMin() when `dir` is true and
// Limits<TYPE>::getMax() otherwise; the initial value is -1.
//
// A launch with zero blocks is skipped, since a zero-sized grid is rejected
// by the CUDA runtime. The launch status is checked right after the launch;
// no stream synchronization is performed here.
#define BLOCK_SELECT_IMPL(TYPE, DIR, WARP_Q, THREAD_Q)                  \
  void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(         \
    Tensor<TYPE, 2, true>& in,                                          \
    Tensor<TYPE, 2, true>& outK,                                        \
    Tensor<int, 2, true>& outV,                                         \
    bool dir,                                                           \
    int k,                                                              \
    cudaStream_t stream) {                                              \
    FAISS_ASSERT(in.getSize(0) == outK.getSize(0));                     \
    FAISS_ASSERT(in.getSize(0) == outV.getSize(0));                     \
    FAISS_ASSERT(outK.getSize(1) == k);                                 \
    FAISS_ASSERT(outV.getSize(1) == k);                                 \
    FAISS_ASSERT(k <= WARP_Q);                                          \
    FAISS_ASSERT(dir == DIR);                                           \
                                                                        \
    /* Nothing to select; a zero-sized grid would fail to launch. */    \
    if (in.getSize(0) == 0) {                                           \
      return;                                                           \
    }                                                                   \
                                                                        \
    auto grid = dim3(in.getSize(0));                                    \
                                                                        \
    constexpr int kBlockSelectNumThreads = 128;                         \
    auto block = dim3(kBlockSelectNumThreads);                          \
                                                                        \
    auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax(); \
    /* -1 is assumed to act as a "no element" index; confirm against */ \
    /* the kernel's handling of unfilled slots. */                      \
    auto vInit = -1;                                                    \
                                                                        \
    blockSelect<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads> \
      <<<grid, block, 0, stream>>>(in, outK, outV, kInit, vInit, k);    \
                                                                        \
    /* Kernel launches report configuration errors only through the */  \
    /* runtime's last-error state. */                                   \
    auto launchStatus = cudaGetLastError();                             \
    FAISS_ASSERT(launchStatus == cudaSuccess);                          \
  }                                                                     \
                                                                        \
  void runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(     \
    Tensor<TYPE, 2, true>& inK,                                         \
    Tensor<int, 2, true>& inV,                                          \
    Tensor<TYPE, 2, true>& outK,                                        \
    Tensor<int, 2, true>& outV,                                         \
    bool dir,                                                           \
    int k,                                                              \
    cudaStream_t stream) {                                              \
    FAISS_ASSERT(inK.isSameSize(inV));                                  \
    FAISS_ASSERT(outK.isSameSize(outV));                                \
    FAISS_ASSERT(k <= WARP_Q);                                          \
    FAISS_ASSERT(dir == DIR);                                           \
                                                                        \
    /* Nothing to select; a zero-sized grid would fail to launch. */    \
    if (inK.getSize(0) == 0) {                                          \
      return;                                                           \
    }                                                                   \
                                                                        \
    auto grid = dim3(inK.getSize(0));                                   \
                                                                        \
    constexpr int kBlockSelectNumThreads = 128;                         \
    auto block = dim3(kBlockSelectNumThreads);                          \
                                                                        \
    auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax(); \
    /* -1 is assumed to act as a "no element" index; confirm against */ \
    /* the kernel's handling of unfilled slots. */                      \
    auto vInit = -1;                                                    \
                                                                        \
    blockSelectPair<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads> \
      <<<grid, block, 0, stream>>>(inK, inV, outK, outV, kInit, vInit, k); \
                                                                        \
    /* Kernel launches report configuration errors only through the */  \
    /* runtime's last-error state. */                                   \
    auto launchStatus = cudaGetLastError();                             \
    FAISS_ASSERT(launchStatus == cudaSuccess);                          \
  }
// Expands to a call of runBlockSelect_<TYPE>_<DIR>_<WARP_Q>_. The arguments
// are not macro parameters: identifiers named `in`, `outK`, `outV`, `dir`,
// `k` and `stream` must be visible at the expansion site.
#define BLOCK_SELECT_CALL(TYPE, DIR, WARP_Q)                            \
  runBlockSelect_##TYPE##_##DIR##_##WARP_Q##_(in, outK, outV, dir, k, stream)
// Expands to a call of runBlockSelectPair_<TYPE>_<DIR>_<WARP_Q>_. The
// arguments are not macro parameters: identifiers named `inK`, `inV`, `outK`,
// `outV`, `dir`, `k` and `stream` must be visible at the expansion site.
#define BLOCK_SELECT_PAIR_CALL(TYPE, DIR, WARP_Q)                       \
  runBlockSelectPair_##TYPE##_##DIR##_##WARP_Q##_(                      \
      inK, inV, outK, outV, dir, k, stream)