8 #include "../WarpSelectKernel.cuh"
9 #include "../Limits.cuh"
// WARP_SELECT_DECL(TYPE, DIR, WARP_Q)
//
// Emits an `extern` declaration for the host-side launcher named
// runWarpSelect_<TYPE>_<DIR>_<WARP_Q>_ (the name is built by token pasting,
// so each argument must be a single preprocessor token that is valid inside
// an identifier).
//
// Parameters visible in this macro:
//   in   - Tensor<TYPE, 2, true>&, the 2-d input tensor
//   outK - Tensor<TYPE, 2, true>&, 2-d output tensor with the same element
//          type as `in`
//   outV - Tensor<int, 2, true>&, 2-d output tensor of ints
//
// NOTE(review): the remainder of the parameter list and the closing `)` are
// not visible in this view. WARP_SELECT_CALL in this file passes
// (in, outK, outV, dir, k, stream), so the tail is presumably
// `bool dir, int k, cudaStream_t stream` -- confirm against the full file.
11 #define WARP_SELECT_DECL(TYPE, DIR, WARP_Q) \
12 extern void runWarpSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \
13 Tensor<TYPE, 2, true>& in, \
14 Tensor<TYPE, 2, true>& outK, \
15 Tensor<int, 2, true>& outV, \
// WARP_SELECT_IMPL(TYPE, DIR, WARP_Q, THREAD_Q)
//
// Emits the definition of the host-side launcher
// runWarpSelect_<TYPE>_<DIR>_<WARP_Q>_, which launches the
// warpSelect<TYPE, int, DIR, WARP_Q, THREAD_Q, 128> kernel on `stream`.
//
// What the generated function does, in order:
//   * Fixes the block size at kWarpSelectNumThreads = 128 threads.
//   * Sizes the 1-d grid as divUp(in.getSize(0), 128 / kWarpSize), i.e. the
//     size of dimension 0 of `in` divided (rounding up) by the number of
//     warps in a block. This suggests one warp handles one entry along
//     dimension 0 -- confirm against the warpSelect kernel. kWarpSize is
//     defined outside the visible lines.
//   * Asserts k <= WARP_Q and dir == DIR via FAISS_ASSERT, so the runtime
//     arguments must agree with the compile-time instantiation.
//   * Chooses the initial key value kInit: Limits<TYPE>::getMin() when `dir`
//     is true, Limits<TYPE>::getMax() otherwise.
//   * Launches the kernel with no dynamic shared memory (third launch
//     argument is 0), passing (in, outK, outV, kInit, vInit, k).
//
// NOTE(review): `dir`, `k` and `vInit` are used below but their
// declarations are not visible in this view, and neither is the closing
// brace of the function; lines appear to be missing from this chunk.
// Confirm against the full file before editing the body.
// NOTE(review): no launch-error check (e.g. cudaGetLastError) is visible
// after the kernel launch; it may be on a line outside this view.
20 #define WARP_SELECT_IMPL(TYPE, DIR, WARP_Q, THREAD_Q) \
21 void runWarpSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \
22 Tensor<TYPE, 2, true>& in, \
23 Tensor<TYPE, 2, true>& outK, \
24 Tensor<int, 2, true>& outV, \
27 cudaStream_t stream) { \
29 constexpr int kWarpSelectNumThreads = 128; \
30 auto grid = dim3(utils::divUp(in.getSize(0), \
31 (kWarpSelectNumThreads / kWarpSize))); \
32 auto block = dim3(kWarpSelectNumThreads); \
34 FAISS_ASSERT(k <= WARP_Q); \
35 FAISS_ASSERT(dir == DIR); \
37 auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax(); \
40 warpSelect<TYPE, int, DIR, WARP_Q, THREAD_Q, kWarpSelectNumThreads> \
41 <<<grid, block, 0, stream>>>(in, outK, outV, kInit, vInit, k); \
// WARP_SELECT_CALL(TYPE, DIR, WARP_Q)
//
// Expands to a call expression on the launcher named
// runWarpSelect_<TYPE>_<DIR>_<WARP_Q>_, forwarding the identifiers
// `in`, `outK`, `outV`, `dir`, `k` and `stream` verbatim. Those six names
// must therefore be in scope wherever the macro is expanded.
//
// The expansion has no trailing semicolon; the caller supplies it.
45 #define WARP_SELECT_CALL(TYPE, DIR, WARP_Q) \
46 runWarpSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \
47 in, outK, outV, dir, k, stream)