8 #include "blockselect/BlockSelectImpl.cuh"
9 #include "DeviceDefs.cuh"
11 namespace faiss {
namespace gpu {
// Forward declarations for every float block-select variant that the
// dispatch functions in this file refer to. The macro arguments are
// (element type, direction flag, k bucket); the same triples are used by the
// BLOCK_SELECT_CALL / BLOCK_SELECT_PAIR_CALL sites in this file.
// NOTE(review): what BLOCK_SELECT_DECL expands to is defined in
// blockselect/BlockSelectImpl.cuh and is not visible here -- presumably one
// function declaration per triple; confirm against that header.

// Direction flag == true.
BLOCK_SELECT_DECL(float, true, 1);
BLOCK_SELECT_DECL(float, true, 32);
BLOCK_SELECT_DECL(float, true, 64);
BLOCK_SELECT_DECL(float, true, 128);
BLOCK_SELECT_DECL(float, true, 256);
BLOCK_SELECT_DECL(float, true, 512);
BLOCK_SELECT_DECL(float, true, 1024);
// The 2048 bucket exists only when the build allows k that large. The guard
// must be closed right after the declaration; otherwise everything that
// follows would be swallowed by the conditional.
#if GPU_MAX_SELECTION_K >= 2048
BLOCK_SELECT_DECL(float, true, 2048);
#endif

// Direction flag == false.
BLOCK_SELECT_DECL(float, false, 1);
BLOCK_SELECT_DECL(float, false, 32);
BLOCK_SELECT_DECL(float, false, 64);
BLOCK_SELECT_DECL(float, false, 128);
BLOCK_SELECT_DECL(float, false, 256);
BLOCK_SELECT_DECL(float, false, 512);
BLOCK_SELECT_DECL(float, false, 1024);
#if GPU_MAX_SELECTION_K >= 2048
BLOCK_SELECT_DECL(float, false, 2048);
#endif
// Host-side dispatcher for float block selection.
//
// Asserts that `k` does not exceed GPU_MAX_SELECTION_K, then invokes exactly
// one BLOCK_SELECT_CALL(float, <dir>, <bucket>) where <dir> equals the
// runtime `dir` flag and <bucket> is the smallest of
// 1, 32, 64, 128, 256, 512, 1024 (and 2048 when GPU_MAX_SELECTION_K >= 2048)
// that is >= k.
//
// Parameters:
//   in     - 2-D float tensor handed to the selected variant.
//   outK   - 2-D float tensor handed to the selected variant
//            (presumably receives the selected values -- TODO confirm).
//   outV   - 2-D int tensor handed to the selected variant
//            (presumably receives the matching indices -- TODO confirm).
//   dir    - selects between the `true` and `false` template variants; the
//            ordering each one implies is defined in BlockSelectImpl.cuh.
//   k      - number of entries to select; must be <= GPU_MAX_SELECTION_K.
//   stream - CUDA stream passed on to the selected variant.
//
// NOTE(review): BLOCK_SELECT_CALL takes no explicit runtime arguments, so it
// is expected to pick up in/outK/outV/dir/k/stream by name from this scope;
// do not rename the parameters without checking the macro.
// NOTE(review): the heads of both chains (`if (dir)`, `if (k == 1)`), the
// 32/64 arms and the `#endif` lines were absent from the reviewed text and
// are restored here following the arms that were present; the exact `k == 1`
// test should be confirmed against the upstream file.
void runBlockSelect(Tensor<float, 2, true>& in,
                    Tensor<float, 2, true>& outK,
                    Tensor<int, 2, true>& outV,
                    bool dir, int k, cudaStream_t stream) {
  FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);

  if (dir) {
    if (k == 1) {
      BLOCK_SELECT_CALL(float, true, 1);
    } else if (k <= 32) {
      BLOCK_SELECT_CALL(float, true, 32);
    } else if (k <= 64) {
      BLOCK_SELECT_CALL(float, true, 64);
    } else if (k <= 128) {
      BLOCK_SELECT_CALL(float, true, 128);
    } else if (k <= 256) {
      BLOCK_SELECT_CALL(float, true, 256);
    } else if (k <= 512) {
      BLOCK_SELECT_CALL(float, true, 512);
    } else if (k <= 1024) {
      BLOCK_SELECT_CALL(float, true, 1024);
      // The 2048 arm is spliced into the chain only when the build supports
      // it; the guard sits between the 1024 body and the closing brace.
#if GPU_MAX_SELECTION_K >= 2048
    } else if (k <= 2048) {
      BLOCK_SELECT_CALL(float, true, 2048);
#endif
    }
  } else {
    if (k == 1) {
      BLOCK_SELECT_CALL(float, false, 1);
    } else if (k <= 32) {
      BLOCK_SELECT_CALL(float, false, 32);
    } else if (k <= 64) {
      BLOCK_SELECT_CALL(float, false, 64);
    } else if (k <= 128) {
      BLOCK_SELECT_CALL(float, false, 128);
    } else if (k <= 256) {
      BLOCK_SELECT_CALL(float, false, 256);
    } else if (k <= 512) {
      BLOCK_SELECT_CALL(float, false, 512);
    } else if (k <= 1024) {
      BLOCK_SELECT_CALL(float, false, 1024);
#if GPU_MAX_SELECTION_K >= 2048
    } else if (k <= 2048) {
      BLOCK_SELECT_CALL(float, false, 2048);
#endif
    }
  }
}
// Host-side dispatcher for float block selection over (key, value) pairs.
//
// Asserts that `k` does not exceed GPU_MAX_SELECTION_K, then invokes exactly
// one BLOCK_SELECT_PAIR_CALL(float, <dir>, <bucket>) where <dir> equals the
// runtime `dir` flag and <bucket> is the smallest of
// 1, 32, 64, 128, 256, 512, 1024 (and 2048 when GPU_MAX_SELECTION_K >= 2048)
// that is >= k.
//
// Parameters:
//   inK    - 2-D float tensor handed to the selected variant
//            (presumably the input keys -- TODO confirm).
//   inV    - 2-D int tensor handed to the selected variant
//            (presumably the values paired with inK -- TODO confirm).
//   outK   - 2-D float tensor handed to the selected variant.
//   outV   - 2-D int tensor handed to the selected variant.
//   dir    - selects between the `true` and `false` template variants; the
//            ordering each one implies is defined in BlockSelectImpl.cuh.
//   k      - number of entries to select; must be <= GPU_MAX_SELECTION_K.
//   stream - CUDA stream passed on to the selected variant.
//
// NOTE(review): BLOCK_SELECT_PAIR_CALL takes no explicit runtime arguments,
// so it is expected to pick up inK/inV/outK/outV/dir/k/stream by name from
// this scope; do not rename the parameters without checking the macro.
// NOTE(review): the `if (dir)` / `if (k == 1)` heads, the `else` transition,
// the `#endif` lines and the closing braces were absent from the reviewed
// text and are restored here following the arms that were present; confirm
// the `k == 1` test against the upstream file.
void runBlockSelectPair(Tensor<float, 2, true>& inK,
                        Tensor<int, 2, true>& inV,
                        Tensor<float, 2, true>& outK,
                        Tensor<int, 2, true>& outV,
                        bool dir, int k, cudaStream_t stream) {
  FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);

  if (dir) {
    if (k == 1) {
      BLOCK_SELECT_PAIR_CALL(float, true, 1);
    } else if (k <= 32) {
      BLOCK_SELECT_PAIR_CALL(float, true, 32);
    } else if (k <= 64) {
      BLOCK_SELECT_PAIR_CALL(float, true, 64);
    } else if (k <= 128) {
      BLOCK_SELECT_PAIR_CALL(float, true, 128);
    } else if (k <= 256) {
      BLOCK_SELECT_PAIR_CALL(float, true, 256);
    } else if (k <= 512) {
      BLOCK_SELECT_PAIR_CALL(float, true, 512);
    } else if (k <= 1024) {
      BLOCK_SELECT_PAIR_CALL(float, true, 1024);
      // The 2048 arm is spliced into the chain only when the build supports
      // it; the guard sits between the 1024 body and the closing brace.
#if GPU_MAX_SELECTION_K >= 2048
    } else if (k <= 2048) {
      BLOCK_SELECT_PAIR_CALL(float, true, 2048);
#endif
    }
  } else {
    if (k == 1) {
      BLOCK_SELECT_PAIR_CALL(float, false, 1);
    } else if (k <= 32) {
      BLOCK_SELECT_PAIR_CALL(float, false, 32);
    } else if (k <= 64) {
      BLOCK_SELECT_PAIR_CALL(float, false, 64);
    } else if (k <= 128) {
      BLOCK_SELECT_PAIR_CALL(float, false, 128);
    } else if (k <= 256) {
      BLOCK_SELECT_PAIR_CALL(float, false, 256);
    } else if (k <= 512) {
      BLOCK_SELECT_PAIR_CALL(float, false, 512);
    } else if (k <= 1024) {
      BLOCK_SELECT_PAIR_CALL(float, false, 1024);
#if GPU_MAX_SELECTION_K >= 2048
    } else if (k <= 2048) {
      BLOCK_SELECT_PAIR_CALL(float, false, 2048);
#endif
    }
  }
}