11 #include "warpselect/WarpSelectImpl.cuh"
13 namespace faiss {
namespace gpu {
15 #ifdef FAISS_USE_FLOAT16
// Half-precision instantiations of WARP_SELECT_DECL: one per (flag, size)
// pair, for both values of the boolean flag and for the sizes
// 1, 32, 64, 128, 256, 512 and 1024. The same (half, flag, size) triples are
// the ones handed to WARP_SELECT_CALL by runWarpSelect in this file, so the
// two lists have to be kept in sync.
// NOTE(review): the macro itself is not defined in this file (presumably it
// comes from warpselect/WarpSelectImpl.cuh); the flag presumably corresponds
// to the `dir` argument of runWarpSelect -- confirm against the macro.

// flag == true
WARP_SELECT_DECL(half, true, 1);
WARP_SELECT_DECL(half, true, 32);
WARP_SELECT_DECL(half, true, 64);
WARP_SELECT_DECL(half, true, 128);
WARP_SELECT_DECL(half, true, 256);
WARP_SELECT_DECL(half, true, 512);
WARP_SELECT_DECL(half, true, 1024);

// flag == false
WARP_SELECT_DECL(half, false, 1);
WARP_SELECT_DECL(half, false, 32);
WARP_SELECT_DECL(half, false, 64);
WARP_SELECT_DECL(half, false, 128);
WARP_SELECT_DECL(half, false, 256);
WARP_SELECT_DECL(half, false, 512);
WARP_SELECT_DECL(half, false, 1024);
42 void runWarpSelect(Tensor<half, 2, true>& in,
43 Tensor<half, 2, true>& outK,
44 Tensor<int, 2, true>& outV,
45 bool dir,
int k, cudaStream_t stream) {
46 FAISS_ASSERT(k <= 1024);
50 WARP_SELECT_CALL(half,
true, 1);
52 WARP_SELECT_CALL(half,
true, 32);
54 WARP_SELECT_CALL(half,
true, 64);
55 }
else if (k <= 128) {
56 WARP_SELECT_CALL(half,
true, 128);
57 }
else if (k <= 256) {
58 WARP_SELECT_CALL(half,
true, 256);
59 }
else if (k <= 512) {
60 WARP_SELECT_CALL(half,
true, 512);
61 }
else if (k <= 1024) {
62 WARP_SELECT_CALL(half,
true, 1024);
66 WARP_SELECT_CALL(half,
false, 1);
68 WARP_SELECT_CALL(half,
false, 32);
70 WARP_SELECT_CALL(half,
false, 64);
71 }
else if (k <= 128) {
72 WARP_SELECT_CALL(half,
false, 128);
73 }
else if (k <= 256) {
74 WARP_SELECT_CALL(half,
false, 256);
75 }
else if (k <= 512) {
76 WARP_SELECT_CALL(half,
false, 512);
77 }
else if (k <= 1024) {
78 WARP_SELECT_CALL(half,
false, 1024);