8 #include "blockselect/BlockSelectImpl.cuh"
9 #include "DeviceDefs.cuh"
11 namespace faiss {
namespace gpu {
13 #ifdef FAISS_USE_FLOAT16
// Explicit BLOCK_SELECT_DECL instantiations for `half` keys: one per
// (boolean flag, k bucket) pair. The buckets are 1, 32, 64, 128, 256, 512 and
// 1024, plus 2048 when GPU_MAX_SELECTION_K allows it.
//
// NOTE(review): the boolean argument presumably encodes the selection
// direction -- confirm against the macro definition in BlockSelectImpl.cuh.
//
// Each `#if GPU_MAX_SELECTION_K >= 2048` guard is closed immediately after the
// single instantiation it protects, so it cannot swallow the declarations
// that follow it.

BLOCK_SELECT_DECL(half, true, 1);
BLOCK_SELECT_DECL(half, true, 32);
BLOCK_SELECT_DECL(half, true, 64);
BLOCK_SELECT_DECL(half, true, 128);
BLOCK_SELECT_DECL(half, true, 256);
BLOCK_SELECT_DECL(half, true, 512);
BLOCK_SELECT_DECL(half, true, 1024);
#if GPU_MAX_SELECTION_K >= 2048
BLOCK_SELECT_DECL(half, true, 2048);
#endif

BLOCK_SELECT_DECL(half, false, 1);
BLOCK_SELECT_DECL(half, false, 32);
BLOCK_SELECT_DECL(half, false, 64);
BLOCK_SELECT_DECL(half, false, 128);
BLOCK_SELECT_DECL(half, false, 256);
BLOCK_SELECT_DECL(half, false, 512);
BLOCK_SELECT_DECL(half, false, 1024);
#if GPU_MAX_SELECTION_K >= 2048
BLOCK_SELECT_DECL(half, false, 2048);
#endif

// Dispatches a k-selection over `half` keys to the BLOCK_SELECT_CALL
// instantiation whose bucket is the smallest one that can hold `k`.
//
// Parameters:
//   in     - input keys.
//   outK   - output keys.
//   outV   - output int values paired with outK.
//   dir    - picks between the `true` and `false` families of instantiations.
//   k      - number of elements requested; must not exceed
//            GPU_MAX_SELECTION_K (enforced by FAISS_ASSERT).
//   stream - CUDA stream handed to this launcher.
//
// NOTE(review): BLOCK_SELECT_CALL is an opaque macro here; it presumably picks
// up in/outK/outV/k/stream from the enclosing scope -- confirm against
// BlockSelectImpl.cuh.
// NOTE(review): the `k == 1` head of each chain is not visible in this view
// (an exact-match test rather than `k <= 1`, so k == 0 would take the 32
// bucket) -- confirm against the upstream source.
void runBlockSelect(Tensor<half, 2, true>& in,
                    Tensor<half, 2, true>& outK,
                    Tensor<int, 2, true>& outV,
                    bool dir, int k, cudaStream_t stream) {
  FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);

  if (dir) {
    if (k == 1) {
      BLOCK_SELECT_CALL(half, true, 1);
    } else if (k <= 32) {
      BLOCK_SELECT_CALL(half, true, 32);
    } else if (k <= 64) {
      BLOCK_SELECT_CALL(half, true, 64);
    } else if (k <= 128) {
      BLOCK_SELECT_CALL(half, true, 128);
    } else if (k <= 256) {
      BLOCK_SELECT_CALL(half, true, 256);
    } else if (k <= 512) {
      BLOCK_SELECT_CALL(half, true, 512);
    } else if (k <= 1024) {
      BLOCK_SELECT_CALL(half, true, 1024);
      // The 2048 bucket only exists when the build allows k that large; the
      // guard is closed right after the branch so the chain stays balanced.
#if GPU_MAX_SELECTION_K >= 2048
    } else if (k <= 2048) {
      BLOCK_SELECT_CALL(half, true, 2048);
#endif
    }
  } else {
    if (k == 1) {
      BLOCK_SELECT_CALL(half, false, 1);
    } else if (k <= 32) {
      BLOCK_SELECT_CALL(half, false, 32);
    } else if (k <= 64) {
      BLOCK_SELECT_CALL(half, false, 64);
    } else if (k <= 128) {
      BLOCK_SELECT_CALL(half, false, 128);
    } else if (k <= 256) {
      BLOCK_SELECT_CALL(half, false, 256);
    } else if (k <= 512) {
      BLOCK_SELECT_CALL(half, false, 512);
    } else if (k <= 1024) {
      BLOCK_SELECT_CALL(half, false, 1024);
#if GPU_MAX_SELECTION_K >= 2048
    } else if (k <= 2048) {
      BLOCK_SELECT_CALL(half, false, 2048);
#endif
    }
  }
}

96 void runBlockSelectPair(Tensor<half, 2, true>& inK,
97 Tensor<int, 2, true>& inV,
98 Tensor<half, 2, true>& outK,
99 Tensor<int, 2, true>& outV,
100 bool dir,
int k, cudaStream_t stream) {
101 FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);
105 BLOCK_SELECT_PAIR_CALL(half,
true, 1);
106 }
else if (k <= 32) {
107 BLOCK_SELECT_PAIR_CALL(half,
true, 32);
108 }
else if (k <= 64) {
109 BLOCK_SELECT_PAIR_CALL(half,
true, 64);
110 }
else if (k <= 128) {
111 BLOCK_SELECT_PAIR_CALL(half,
true, 128);
112 }
else if (k <= 256) {
113 BLOCK_SELECT_PAIR_CALL(half,
true, 256);
114 }
else if (k <= 512) {
115 BLOCK_SELECT_PAIR_CALL(half,
true, 512);
116 }
else if (k <= 1024) {
117 BLOCK_SELECT_PAIR_CALL(half,
true, 1024);
118 #if GPU_MAX_SELECTION_K >= 2048
119 }
else if (k <= 2048) {
120 BLOCK_SELECT_PAIR_CALL(half,
true, 2048);
125 BLOCK_SELECT_PAIR_CALL(half,
false, 1);
126 }
else if (k <= 32) {
127 BLOCK_SELECT_PAIR_CALL(half,
false, 32);
128 }
else if (k <= 64) {
129 BLOCK_SELECT_PAIR_CALL(half,
false, 64);
130 }
else if (k <= 128) {
131 BLOCK_SELECT_PAIR_CALL(half,
false, 128);
132 }
else if (k <= 256) {
133 BLOCK_SELECT_PAIR_CALL(half,
false, 256);
134 }
else if (k <= 512) {
135 BLOCK_SELECT_PAIR_CALL(half,
false, 512);
136 }
else if (k <= 1024) {
137 BLOCK_SELECT_PAIR_CALL(half,
false, 1024);
138 #if GPU_MAX_SELECTION_K >= 2048
139 }
else if (k <= 2048) {
140 BLOCK_SELECT_PAIR_CALL(half,
false, 2048);