10 #include "warpselect/WarpSelectImpl.cuh"
12 namespace faiss {
namespace gpu {
// Invoke WARP_SELECT_DECL once for every (element type, direction flag, size)
// combination used by this file: element type float, direction flag
// true/false, and sizes 1, 32, 64, 128, 256, 512 and 1024. These are the same
// argument triples that runWarpSelect() below hands to WARP_SELECT_CALL, so
// the two lists must be kept in sync when a size is added or removed.
// NOTE(review): WARP_SELECT_DECL is presumably provided by
// warpselect/WarpSelectImpl.cuh and declares the per-size implementation that
// WARP_SELECT_CALL dispatches to -- confirm against that header.

// Direction flag == true.
WARP_SELECT_DECL(float, true, 1);
WARP_SELECT_DECL(float, true, 32);
WARP_SELECT_DECL(float, true, 64);
WARP_SELECT_DECL(float, true, 128);
WARP_SELECT_DECL(float, true, 256);
WARP_SELECT_DECL(float, true, 512);
WARP_SELECT_DECL(float, true, 1024);

// Direction flag == false.
WARP_SELECT_DECL(float, false, 1);
WARP_SELECT_DECL(float, false, 32);
WARP_SELECT_DECL(float, false, 64);
WARP_SELECT_DECL(float, false, 128);
WARP_SELECT_DECL(float, false, 256);
WARP_SELECT_DECL(float, false, 512);
WARP_SELECT_DECL(float, false, 1024);
39 void runWarpSelect(Tensor<float, 2, true>& in,
40 Tensor<float, 2, true>& outK,
41 Tensor<int, 2, true>& outV,
42 bool dir,
int k, cudaStream_t stream) {
43 FAISS_ASSERT(k <= 1024);
47 WARP_SELECT_CALL(
float,
true, 1);
49 WARP_SELECT_CALL(
float,
true, 32);
51 WARP_SELECT_CALL(
float,
true, 64);
52 }
else if (k <= 128) {
53 WARP_SELECT_CALL(
float,
true, 128);
54 }
else if (k <= 256) {
55 WARP_SELECT_CALL(
float,
true, 256);
56 }
else if (k <= 512) {
57 WARP_SELECT_CALL(
float,
true, 512);
58 }
else if (k <= 1024) {
59 WARP_SELECT_CALL(
float,
true, 1024);
63 WARP_SELECT_CALL(
float,
false, 1);
65 WARP_SELECT_CALL(
float,
false, 32);
67 WARP_SELECT_CALL(
float,
false, 64);
68 }
else if (k <= 128) {
69 WARP_SELECT_CALL(
float,
false, 128);
70 }
else if (k <= 256) {
71 WARP_SELECT_CALL(
float,
false, 256);
72 }
else if (k <= 512) {
73 WARP_SELECT_CALL(
float,
false, 512);
74 }
else if (k <= 1024) {
75 WARP_SELECT_CALL(
float,
false, 1024);