8 #include "warpselect/WarpSelectImpl.cuh"
9 #include "DeviceDefs.cuh"
11 namespace faiss {
namespace gpu {
13 #ifdef FAISS_USE_FLOAT16
// Declarations of the half-precision warp-select instantiations used by the
// runWarpSelect dispatcher below: one per (direction, WARP_Q) pair, where
// WARP_Q is the queue length passed as the macro's third argument.
// WARP_SELECT_DECL comes from warpselect/WarpSelectImpl.cuh; the matching
// definitions are not in this file (presumably one instantiation per
// translation unit under warpselect/ — TODO confirm).

// Instantiations for dir == true.
WARP_SELECT_DECL(half, true, 1);
WARP_SELECT_DECL(half, true, 32);
WARP_SELECT_DECL(half, true, 64);
WARP_SELECT_DECL(half, true, 128);
WARP_SELECT_DECL(half, true, 256);
WARP_SELECT_DECL(half, true, 512);
WARP_SELECT_DECL(half, true, 1024);
#if GPU_MAX_SELECTION_K >= 2048
// Only declared when the build's selection limit admits k up to 2048
// (GPU_MAX_SELECTION_K is presumably defined in DeviceDefs.cuh).
WARP_SELECT_DECL(half, true, 2048);
#endif // GPU_MAX_SELECTION_K >= 2048

// Instantiations for dir == false (same queue lengths as above).
// These must NOT sit inside the 2048 guard above: they are needed
// unconditionally.
WARP_SELECT_DECL(half, false, 1);
WARP_SELECT_DECL(half, false, 32);
WARP_SELECT_DECL(half, false, 64);
WARP_SELECT_DECL(half, false, 128);
WARP_SELECT_DECL(half, false, 256);
WARP_SELECT_DECL(half, false, 512);
WARP_SELECT_DECL(half, false, 1024);
#if GPU_MAX_SELECTION_K >= 2048
WARP_SELECT_DECL(half, false, 2048);
#endif // GPU_MAX_SELECTION_K >= 2048
47 void runWarpSelect(Tensor<half, 2, true>& in,
48 Tensor<half, 2, true>& outK,
49 Tensor<int, 2, true>& outV,
50 bool dir,
int k, cudaStream_t stream) {
51 FAISS_ASSERT(k <= 1024);
55 WARP_SELECT_CALL(half,
true, 1);
57 WARP_SELECT_CALL(half,
true, 32);
59 WARP_SELECT_CALL(half,
true, 64);
60 }
else if (k <= 128) {
61 WARP_SELECT_CALL(half,
true, 128);
62 }
else if (k <= 256) {
63 WARP_SELECT_CALL(half,
true, 256);
64 }
else if (k <= 512) {
65 WARP_SELECT_CALL(half,
true, 512);
66 }
else if (k <= 1024) {
67 WARP_SELECT_CALL(half,
true, 1024);
68 #if GPU_MAX_SELECTION_K >= 2048
69 }
else if (k <= 2048) {
70 WARP_SELECT_CALL(half,
true, 2048);
75 WARP_SELECT_CALL(half,
false, 1);
77 WARP_SELECT_CALL(half,
false, 32);
79 WARP_SELECT_CALL(half,
false, 64);
80 }
else if (k <= 128) {
81 WARP_SELECT_CALL(half,
false, 128);
82 }
else if (k <= 256) {
83 WARP_SELECT_CALL(half,
false, 256);
84 }
else if (k <= 512) {
85 WARP_SELECT_CALL(half,
false, 512);
86 }
else if (k <= 1024) {
87 WARP_SELECT_CALL(half,
false, 1024);
88 #if GPU_MAX_SELECTION_K >= 2048
89 }
else if (k <= 2048) {
90 WARP_SELECT_CALL(half,
false, 2048);