9 #include "warpselect/WarpSelectImpl.cuh"
11 namespace faiss {
namespace gpu {
// Float warp-select declarations with the macro's second argument set to
// `true`, one per capacity bucket: 1, 32, 64, 128, 256, 512, 1024.
// runWarpSelect() in this file issues a WARP_SELECT_CALL with the same
// (type, flag, capacity) triple for every entry listed here, so the two lists
// are presumably required to stay in sync.
// NOTE(review): WARP_SELECT_DECL is not defined in this file -- presumably it
// is provided by warpselect/WarpSelectImpl.cuh; confirm there what the boolean
// flag selects and what the numeric argument bounds.
WARP_SELECT_DECL(float, true, 1);
WARP_SELECT_DECL(float, true, 32);
WARP_SELECT_DECL(float, true, 64);
WARP_SELECT_DECL(float, true, 128);
WARP_SELECT_DECL(float, true, 256);
WARP_SELECT_DECL(float, true, 512);
WARP_SELECT_DECL(float, true, 1024);
// Float warp-select declarations with the macro's second argument set to
// `false`, one per capacity bucket: 1, 32, 64, 128, 256, 512, 1024.
// runWarpSelect() in this file issues a WARP_SELECT_CALL with the same
// (type, flag, capacity) triple for every entry listed here, so the two lists
// are presumably required to stay in sync.
// NOTE(review): WARP_SELECT_DECL is not defined in this file -- presumably it
// is provided by warpselect/WarpSelectImpl.cuh; confirm there what the boolean
// flag selects and what the numeric argument bounds.
WARP_SELECT_DECL(float, false, 1);
WARP_SELECT_DECL(float, false, 32);
WARP_SELECT_DECL(float, false, 64);
WARP_SELECT_DECL(float, false, 128);
WARP_SELECT_DECL(float, false, 256);
WARP_SELECT_DECL(float, false, 512);
WARP_SELECT_DECL(float, false, 1024);
38 void runWarpSelect(Tensor<float, 2, true>& in,
39 Tensor<float, 2, true>& outK,
40 Tensor<int, 2, true>& outV,
41 bool dir,
int k, cudaStream_t stream) {
42 FAISS_ASSERT(k <= 1024);
46 WARP_SELECT_CALL(
float,
true, 1);
48 WARP_SELECT_CALL(
float,
true, 32);
50 WARP_SELECT_CALL(
float,
true, 64);
51 }
else if (k <= 128) {
52 WARP_SELECT_CALL(
float,
true, 128);
53 }
else if (k <= 256) {
54 WARP_SELECT_CALL(
float,
true, 256);
55 }
else if (k <= 512) {
56 WARP_SELECT_CALL(
float,
true, 512);
57 }
else if (k <= 1024) {
58 WARP_SELECT_CALL(
float,
true, 1024);
62 WARP_SELECT_CALL(
float,
false, 1);
64 WARP_SELECT_CALL(
float,
false, 32);
66 WARP_SELECT_CALL(
float,
false, 64);
67 }
else if (k <= 128) {
68 WARP_SELECT_CALL(
float,
false, 128);
69 }
else if (k <= 256) {
70 WARP_SELECT_CALL(
float,
false, 256);
71 }
else if (k <= 512) {
72 WARP_SELECT_CALL(
float,
false, 512);
73 }
else if (k <= 1024) {
74 WARP_SELECT_CALL(
float,
false, 1024);