11 #include "warpselect/WarpSelectImpl.cuh"
13 namespace faiss {
namespace gpu {
// Declarations for the float warp-select variants used by runWarpSelect.
//
// WARP_SELECT_DECL is provided by warpselect/WarpSelectImpl.cuh; its
// definition is not visible in this file.
// NOTE(review): presumably each invocation declares the implementation that
// the WARP_SELECT_CALL with the same three arguments dispatches to -- confirm
// against the header.
//
// Arguments, in order: element type, a boolean flag, and a size that is one
// of 1, 32, 64, 128, 256, 512 or 1024 (the same triples that appear in the
// WARP_SELECT_CALL invocations inside runWarpSelect).

// Flag == true, one declaration per supported size.
WARP_SELECT_DECL(float, true, 1);
WARP_SELECT_DECL(float, true, 32);
WARP_SELECT_DECL(float, true, 64);
WARP_SELECT_DECL(float, true, 128);
WARP_SELECT_DECL(float, true, 256);
WARP_SELECT_DECL(float, true, 512);
WARP_SELECT_DECL(float, true, 1024);

// Flag == false, same set of sizes.
WARP_SELECT_DECL(float, false, 1);
WARP_SELECT_DECL(float, false, 32);
WARP_SELECT_DECL(float, false, 64);
WARP_SELECT_DECL(float, false, 128);
WARP_SELECT_DECL(float, false, 256);
WARP_SELECT_DECL(float, false, 512);
WARP_SELECT_DECL(float, false, 1024);
40 void runWarpSelect(Tensor<float, 2, true>& in,
41 Tensor<float, 2, true>& outK,
42 Tensor<int, 2, true>& outV,
43 bool dir,
int k, cudaStream_t stream) {
44 FAISS_ASSERT(k <= 1024);
48 WARP_SELECT_CALL(
float,
true, 1);
50 WARP_SELECT_CALL(
float,
true, 32);
52 WARP_SELECT_CALL(
float,
true, 64);
53 }
else if (k <= 128) {
54 WARP_SELECT_CALL(
float,
true, 128);
55 }
else if (k <= 256) {
56 WARP_SELECT_CALL(
float,
true, 256);
57 }
else if (k <= 512) {
58 WARP_SELECT_CALL(
float,
true, 512);
59 }
else if (k <= 1024) {
60 WARP_SELECT_CALL(
float,
true, 1024);
64 WARP_SELECT_CALL(
float,
false, 1);
66 WARP_SELECT_CALL(
float,
false, 32);
68 WARP_SELECT_CALL(
float,
false, 64);
69 }
else if (k <= 128) {
70 WARP_SELECT_CALL(
float,
false, 128);
71 }
else if (k <= 256) {
72 WARP_SELECT_CALL(
float,
false, 256);
73 }
else if (k <= 512) {
74 WARP_SELECT_CALL(
float,
false, 512);
75 }
else if (k <= 1024) {
76 WARP_SELECT_CALL(
float,
false, 1024);