10 #include "../BlockSelectKernel.cuh"
11 #include "../Limits.cuh"
13 #define BLOCK_SELECT_DECL(TYPE, DIR, WARP_Q) \
14 extern void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \
15 Tensor<TYPE, 2, true>& in, \
16 Tensor<TYPE, 2, true>& outK, \
17 Tensor<int, 2, true>& outV, \
22 #define BLOCK_SELECT_IMPL(TYPE, DIR, WARP_Q, THREAD_Q) \
23 void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \
24 Tensor<TYPE, 2, true>& in, \
25 Tensor<TYPE, 2, true>& outK, \
26 Tensor<int, 2, true>& outV, \
29 cudaStream_t stream) { \
30 auto grid = dim3(in.getSize(0)); \
32 constexpr int kBlockSelectNumThreads = 128; \
33 auto block = dim3(kBlockSelectNumThreads); \
35 FAISS_ASSERT(k <= WARP_Q); \
36 FAISS_ASSERT(dir == DIR); \
38 auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax(); \
41 blockSelect<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads> \
42 <<<grid, block, 0, stream>>>(in, outK, outV, kInit, vInit, k); \
// BLOCK_SELECT_CALL(TYPE, DIR, WARP_Q)
//
// Expands to a call of the function whose name is built by token pasting as
//   runBlockSelect_<TYPE>_<DIR>_<WARP_Q>_
// (the same naming scheme spelled out in BLOCK_SELECT_DECL / BLOCK_SELECT_IMPL).
// The call arguments are the identifiers `in`, `outK`, `outV`, `dir`, `k` and
// `stream`, in that order. They are NOT macro parameters, so they must name
// variables visible at the point of expansion. The expansion carries no
// trailing semicolon; the caller writes it.
#define BLOCK_SELECT_CALL(TYPE, DIR, WARP_Q)                                \
  runBlockSelect_##TYPE##_##DIR##_##WARP_Q##_(in, outK, outV, dir, k, stream)