10 #include "blockselect/BlockSelectImpl.cuh"
12 namespace faiss {
namespace gpu {
14 #ifdef FAISS_USE_FLOAT16
// One BLOCK_SELECT_DECL per (direction flag, size) combination for `half`.
// The same fourteen combinations are the ones named by the
// BLOCK_SELECT_CALL / BLOCK_SELECT_PAIR_CALL invocations in the dispatch
// functions of this file, so the two lists must be kept in sync.
// NOTE(review): BLOCK_SELECT_DECL is defined outside this file (presumably in
// blockselect/BlockSelectImpl.cuh); what it expands to is not visible here.

// Direction flag = true.
BLOCK_SELECT_DECL(half, true, 1);
BLOCK_SELECT_DECL(half, true, 32);
BLOCK_SELECT_DECL(half, true, 64);
BLOCK_SELECT_DECL(half, true, 128);
BLOCK_SELECT_DECL(half, true, 256);
BLOCK_SELECT_DECL(half, true, 512);
BLOCK_SELECT_DECL(half, true, 1024);

// Direction flag = false.
BLOCK_SELECT_DECL(half, false, 1);
BLOCK_SELECT_DECL(half, false, 32);
BLOCK_SELECT_DECL(half, false, 64);
BLOCK_SELECT_DECL(half, false, 128);
BLOCK_SELECT_DECL(half, false, 256);
BLOCK_SELECT_DECL(half, false, 512);
BLOCK_SELECT_DECL(half, false, 1024);
// Picks the BLOCK_SELECT_CALL variant that matches the requested direction
// and the smallest supported size bucket that can hold `k`.
//
// Parameters:
//   in, outK, outV, stream - not referenced directly in this body; presumably
//                            consumed by the BLOCK_SELECT_CALL expansion, which
//                            is defined outside this file (TODO confirm).
//   dir                    - selects between the `true` and `false` variants
//                            declared via BLOCK_SELECT_DECL.
//   k                      - requested selection size; must be <= 1024.
//
// NOTE(review): in the damaged source the `if (dir)` / `else` split, the
// `k == 1`, `k <= 32` and `k <= 64` branch headers, and the closing braces
// were missing. They are restored here to match the surviving
// `k <= 128 ... k <= 1024` branches and the size buckets {1, 32, 64, ...}
// used by the calls. Confirm the `k == 1` condition against upstream.
void runBlockSelect(Tensor<half, 2, true>& in,
                    Tensor<half, 2, true>& outK,
                    Tensor<int, 2, true>& outV,
                    bool dir,
                    int k,
                    cudaStream_t stream) {
  // 1024 is the largest size bucket instantiated for `half`.
  FAISS_ASSERT(k <= 1024);

  if (dir) {
    if (k == 1) {
      BLOCK_SELECT_CALL(half, true, 1);
    } else if (k <= 32) {
      BLOCK_SELECT_CALL(half, true, 32);
    } else if (k <= 64) {
      BLOCK_SELECT_CALL(half, true, 64);
    } else if (k <= 128) {
      BLOCK_SELECT_CALL(half, true, 128);
    } else if (k <= 256) {
      BLOCK_SELECT_CALL(half, true, 256);
    } else if (k <= 512) {
      BLOCK_SELECT_CALL(half, true, 512);
    } else if (k <= 1024) {
      // Last bucket: the assertion above already checks k <= 1024, so no
      // trailing `else` is provided.
      BLOCK_SELECT_CALL(half, true, 1024);
    }
  } else {
    if (k == 1) {
      BLOCK_SELECT_CALL(half, false, 1);
    } else if (k <= 32) {
      BLOCK_SELECT_CALL(half, false, 32);
    } else if (k <= 64) {
      BLOCK_SELECT_CALL(half, false, 64);
    } else if (k <= 128) {
      BLOCK_SELECT_CALL(half, false, 128);
    } else if (k <= 256) {
      BLOCK_SELECT_CALL(half, false, 256);
    } else if (k <= 512) {
      BLOCK_SELECT_CALL(half, false, 512);
    } else if (k <= 1024) {
      BLOCK_SELECT_CALL(half, false, 1024);
    }
  }
}
82 void runBlockSelectPair(Tensor<half, 2, true>& inK,
83 Tensor<int, 2, true>& inV,
84 Tensor<half, 2, true>& outK,
85 Tensor<int, 2, true>& outV,
86 bool dir,
int k, cudaStream_t stream) {
87 FAISS_ASSERT(k <= 1024);
91 BLOCK_SELECT_PAIR_CALL(half,
true, 1);
93 BLOCK_SELECT_PAIR_CALL(half,
true, 32);
95 BLOCK_SELECT_PAIR_CALL(half,
true, 64);
96 }
else if (k <= 128) {
97 BLOCK_SELECT_PAIR_CALL(half,
true, 128);
98 }
else if (k <= 256) {
99 BLOCK_SELECT_PAIR_CALL(half,
true, 256);
100 }
else if (k <= 512) {
101 BLOCK_SELECT_PAIR_CALL(half,
true, 512);
102 }
else if (k <= 1024) {
103 BLOCK_SELECT_PAIR_CALL(half,
true, 1024);
107 BLOCK_SELECT_PAIR_CALL(half,
false, 1);
108 }
else if (k <= 32) {
109 BLOCK_SELECT_PAIR_CALL(half,
false, 32);
110 }
else if (k <= 64) {
111 BLOCK_SELECT_PAIR_CALL(half,
false, 64);
112 }
else if (k <= 128) {
113 BLOCK_SELECT_PAIR_CALL(half,
false, 128);
114 }
else if (k <= 256) {
115 BLOCK_SELECT_PAIR_CALL(half,
false, 256);
116 }
else if (k <= 512) {
117 BLOCK_SELECT_PAIR_CALL(half,
false, 512);
118 }
else if (k <= 1024) {
119 BLOCK_SELECT_PAIR_CALL(half,
false, 1024);