12 #include "../BlockSelectKernel.cuh"
13 #include "../Limits.cuh"
// BLOCK_SELECT_DECL(TYPE, DIR, WARP_Q)
//
// Declares the two host-side launchers for one (TYPE, DIR, WARP_Q)
// specialization; BLOCK_SELECT_IMPL supplies the matching definitions.
// The parameter lists follow the argument order used by BLOCK_SELECT_CALL
// and BLOCK_SELECT_PAIR_CALL: (..., dir, k, stream).
// Like those macros, the expansion omits the final ';' so the invocation
// site supplies it.
#define BLOCK_SELECT_DECL(TYPE, DIR, WARP_Q)                               \
  extern void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(     \
    Tensor<TYPE, 2, true>& in,                                             \
    Tensor<TYPE, 2, true>& outK,                                           \
    Tensor<int, 2, true>& outV,                                            \
    bool dir,                                                              \
    int k,                                                                 \
    cudaStream_t stream);                                                  \
                                                                           \
  extern void runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \
    Tensor<TYPE, 2, true>& inK,                                            \
    Tensor<int, 2, true>& inV,                                             \
    Tensor<TYPE, 2, true>& outK,                                           \
    Tensor<int, 2, true>& outV,                                            \
    bool dir,                                                              \
    int k,                                                                 \
    cudaStream_t stream)
// BLOCK_SELECT_IMPL(TYPE, DIR, WARP_Q, THREAD_Q)
//
// Defines the two host launchers declared by BLOCK_SELECT_DECL for one
// (TYPE, DIR, WARP_Q) specialization:
//
//   runBlockSelect_*      launches blockSelect on `in`, with outK / outV
//                         (k columns each) as outputs.
//   runBlockSelectPair_*  launches blockSelectPair on the (inK, inV) pair,
//                         with outK / outV (k columns each) as outputs.
//
// Both launchers do the following:
//   - They assert that row counts agree, that outK/outV are k wide, and that
//     k <= WARP_Q. They also assert that the runtime `dir` equals the DIR
//     this specialization was compiled for.
//   - They return without launching when the input has no rows, because a
//     zero-sized grid is an invalid launch configuration.
//   - They launch one 128-thread block per input row on `stream`. The launch
//     is asynchronous; nothing here synchronizes the stream.
//   - They check for launch errors with cudaGetLastError().
//
// NOTE(review): presumably WARP_Q / THREAD_Q are the per-warp and per-thread
// queue lengths of the kernels; confirm in BlockSelectKernel.cuh.
#define BLOCK_SELECT_IMPL(TYPE, DIR, WARP_Q, THREAD_Q)                     \
  void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(            \
    Tensor<TYPE, 2, true>& in,                                             \
    Tensor<TYPE, 2, true>& outK,                                           \
    Tensor<int, 2, true>& outV,                                            \
    bool dir,                                                              \
    int k,                                                                 \
    cudaStream_t stream) {                                                 \
    FAISS_ASSERT(in.getSize(0) == outK.getSize(0));                        \
    FAISS_ASSERT(in.getSize(0) == outV.getSize(0));                        \
    FAISS_ASSERT(outK.getSize(1) == k);                                    \
    FAISS_ASSERT(outV.getSize(1) == k);                                    \
    FAISS_ASSERT(k <= WARP_Q);                                             \
    /* This specialization is compiled for a single direction. */          \
    FAISS_ASSERT(dir == DIR);                                              \
                                                                           \
    /* One block per row: zero rows would be an invalid empty grid, */     \
    /* and there is nothing to write anyway. */                            \
    if (in.getSize(0) == 0) {                                              \
      return;                                                              \
    }                                                                      \
                                                                           \
    constexpr int kBlockSelectNumThreads = 128;                            \
    auto grid = dim3(in.getSize(0));                                       \
    auto block = dim3(kBlockSelectNumThreads);                             \
                                                                           \
    /* Initial key: Limits min when dir is true, max otherwise */          \
    /* (presumably the worst value for the selection direction). */        \
    auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax();    \
    /* NOTE(review): -1 assumed to mean "no index"; confirm against */     \
    /* the blockSelect kernel and its callers. */                          \
    int vInit = -1;                                                        \
                                                                           \
    blockSelect<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads>  \
      <<<grid, block, 0, stream>>>(in, outK, outV, kInit, vInit, k);       \
    /* Launch-configuration errors are only observable here. */            \
    FAISS_ASSERT(cudaGetLastError() == cudaSuccess);                       \
  }                                                                        \
                                                                           \
  void runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(        \
    Tensor<TYPE, 2, true>& inK,                                            \
    Tensor<int, 2, true>& inV,                                             \
    Tensor<TYPE, 2, true>& outK,                                           \
    Tensor<int, 2, true>& outV,                                            \
    bool dir,                                                              \
    int k,                                                                 \
    cudaStream_t stream) {                                                 \
    FAISS_ASSERT(inK.isSameSize(inV));                                     \
    FAISS_ASSERT(outK.isSameSize(outV));                                   \
    /* Same row/width contract as runBlockSelect_*: the grid has one */    \
    /* block per inK row, presumably each writing one output row. */       \
    FAISS_ASSERT(inK.getSize(0) == outK.getSize(0));                       \
    FAISS_ASSERT(outK.getSize(1) == k);                                    \
    FAISS_ASSERT(k <= WARP_Q);                                             \
    /* This specialization is compiled for a single direction. */          \
    FAISS_ASSERT(dir == DIR);                                              \
                                                                           \
    /* Zero rows would be an invalid empty grid; nothing to do. */         \
    if (inK.getSize(0) == 0) {                                             \
      return;                                                              \
    }                                                                      \
                                                                           \
    constexpr int kBlockSelectNumThreads = 128;                            \
    auto grid = dim3(inK.getSize(0));                                      \
    auto block = dim3(kBlockSelectNumThreads);                             \
                                                                           \
    /* Initial key: Limits min when dir is true, max otherwise. */         \
    auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax();    \
    /* NOTE(review): -1 assumed to mean "no index"; confirm against */     \
    /* the blockSelectPair kernel and its callers. */                      \
    int vInit = -1;                                                        \
                                                                           \
    blockSelectPair<TYPE, int, DIR, WARP_Q, THREAD_Q,                      \
                    kBlockSelectNumThreads>                                \
      <<<grid, block, 0, stream>>>(inK, inV, outK, outV, kInit, vInit, k); \
    /* Launch-configuration errors are only observable here. */            \
    FAISS_ASSERT(cudaGetLastError() == cudaSuccess);                       \
  }
// Calls the runBlockSelect_<TYPE>_<DIR>_<WARP_Q>_ launcher. The expansion
// refers to the caller's local names in, outK, outV, dir, k and stream, and
// the caller supplies the trailing ';'.
#define BLOCK_SELECT_CALL(TYPE, DIR, WARP_Q) \
  runBlockSelect_##TYPE##_##DIR##_##WARP_Q##_(in, outK, outV, dir, k, stream)
// Calls the runBlockSelectPair_<TYPE>_<DIR>_<WARP_Q>_ launcher. The
// expansion refers to the caller's local names inK, inV, outK, outV, dir, k
// and stream, and the caller supplies the trailing ';'.
#define BLOCK_SELECT_PAIR_CALL(TYPE, DIR, WARP_Q) \
  runBlockSelectPair_##TYPE##_##DIR##_##WARP_Q##_(inK, inV, outK, outV, dir, k, stream)