Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
BlockSelectFloat.cu
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #include "blockselect/BlockSelectImpl.cuh"
9 #include "DeviceDefs.cuh"
10 
11 namespace faiss { namespace gpu {
12 
// Queue sizes compiled for float k-selection, written as
// warp-queue length (WARP_Q) -> per-thread queue length (THREAD_Q):
//
//   WARP_Q   THREAD_Q
//        1          1
//       32          2
//       64          3
//      128          3
//      256          4
//      512          8
//     1024          8
//     2048          8   (only when GPU_MAX_SELECTION_K >= 2048)
//
// Only WARP_Q appears in the macro arguments below. The THREAD_Q pairing is
// presumably applied inside the BLOCK_SELECT_DECL/IMPL machinery; confirm in
// blockselect/BlockSelectImpl.cuh.
//
// Each BLOCK_SELECT_DECL(type, dir, warpQ) presumably declares one
// specialization that is defined in a separate translation unit, so that
// runBlockSelect / runBlockSelectPair can dispatch to it. TODO confirm
// against BlockSelectImpl.cuh.

// Ascending and descending variants, grouped by queue length.
BLOCK_SELECT_DECL(float, true, 1);
BLOCK_SELECT_DECL(float, false, 1);

BLOCK_SELECT_DECL(float, true, 32);
BLOCK_SELECT_DECL(float, false, 32);

BLOCK_SELECT_DECL(float, true, 64);
BLOCK_SELECT_DECL(float, false, 64);

BLOCK_SELECT_DECL(float, true, 128);
BLOCK_SELECT_DECL(float, false, 128);

BLOCK_SELECT_DECL(float, true, 256);
BLOCK_SELECT_DECL(float, false, 256);

BLOCK_SELECT_DECL(float, true, 512);
BLOCK_SELECT_DECL(float, false, 512);

BLOCK_SELECT_DECL(float, true, 1024);
BLOCK_SELECT_DECL(float, false, 1024);

// The largest queue is only built when the configured selection limit
// (GPU_MAX_SELECTION_K, from DeviceDefs.cuh) permits it.
#if GPU_MAX_SELECTION_K >= 2048
BLOCK_SELECT_DECL(float, true, 2048);
BLOCK_SELECT_DECL(float, false, 2048);
#endif
44 
/// Runs a block-wide k-selection on `in`, writing the selected values to
/// `outK` and the matching int entries (presumably column indices) to `outV`.
///
/// Dispatch picks the smallest compiled warp-queue length that can hold k:
/// 1 for k == 1, otherwise the first of 32, 64, 128, 256, 512, 1024 (and 2048
/// when GPU_MAX_SELECTION_K >= 2048) that is >= k.
///
/// @param in      input values; assumed (numRows, numCols) — TODO confirm
/// @param outK    output values; assumed (numRows, k) — TODO confirm
/// @param outV    output int entries paired with outK; assumed (numRows, k)
/// @param dir     selection direction forwarded to the kernel as a
///                compile-time flag (presumably true = largest, false =
///                smallest; confirm in BlockSelectImpl.cuh)
/// @param k       number of entries to select; must be <= GPU_MAX_SELECTION_K
/// @param stream  CUDA stream the selection kernel is issued on
void runBlockSelect(Tensor<float, 2, true>& in,
                    Tensor<float, 2, true>& outK,
                    Tensor<int, 2, true>& outV,
                    bool dir, int k, cudaStream_t stream) {
  FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);

  // The BLOCK_SELECT_CALL arguments are pasted into a specialization name, so
  // both the direction and the queue length must be literal tokens. That is
  // why the chain is spelled out once per direction.
  if (dir) {
    if (k == 1) {
      BLOCK_SELECT_CALL(float, true, 1);
    } else if (k <= 32) {
      BLOCK_SELECT_CALL(float, true, 32);
    } else if (k <= 64) {
      BLOCK_SELECT_CALL(float, true, 64);
    } else if (k <= 128) {
      BLOCK_SELECT_CALL(float, true, 128);
    } else if (k <= 256) {
      BLOCK_SELECT_CALL(float, true, 256);
    } else if (k <= 512) {
      BLOCK_SELECT_CALL(float, true, 512);
    } else if (k <= 1024) {
      BLOCK_SELECT_CALL(float, true, 1024);
#if GPU_MAX_SELECTION_K >= 2048
    } else if (k <= 2048) {
      BLOCK_SELECT_CALL(float, true, 2048);
#endif
    } else {
      // Reached only if GPU_MAX_SELECTION_K allows a k larger than the
      // biggest compiled queue. Fail loudly rather than silently returning
      // with outK/outV never written.
      FAISS_ASSERT(false);
    }
  } else {
    if (k == 1) {
      BLOCK_SELECT_CALL(float, false, 1);
    } else if (k <= 32) {
      BLOCK_SELECT_CALL(float, false, 32);
    } else if (k <= 64) {
      BLOCK_SELECT_CALL(float, false, 64);
    } else if (k <= 128) {
      BLOCK_SELECT_CALL(float, false, 128);
    } else if (k <= 256) {
      BLOCK_SELECT_CALL(float, false, 256);
    } else if (k <= 512) {
      BLOCK_SELECT_CALL(float, false, 512);
    } else if (k <= 1024) {
      BLOCK_SELECT_CALL(float, false, 1024);
#if GPU_MAX_SELECTION_K >= 2048
    } else if (k <= 2048) {
      BLOCK_SELECT_CALL(float, false, 2048);
#endif
    } else {
      // See the matching branch above: no compiled queue can hold k.
      FAISS_ASSERT(false);
    }
  }
}
93 
/// Pair variant of runBlockSelect. It selects k entries from the key/value
/// inputs (`inK`, `inV`) and writes the chosen keys to `outK` and their
/// paired values to `outV`.
///
/// Dispatch picks the smallest compiled warp-queue length that can hold k:
/// 1 for k == 1, otherwise the first of 32, 64, 128, 256, 512, 1024 (and 2048
/// when GPU_MAX_SELECTION_K >= 2048) that is >= k.
///
/// @param inK     input keys; assumed (numRows, numCols) — TODO confirm
/// @param inV     int values paired element-wise with inK (presumably labels
///                or indices); assumed same shape as inK
/// @param outK    output keys; assumed (numRows, k) — TODO confirm
/// @param outV    output values paired with outK; assumed (numRows, k)
/// @param dir     selection direction forwarded to the kernel as a
///                compile-time flag (presumably true = largest, false =
///                smallest; confirm in BlockSelectImpl.cuh)
/// @param k       number of entries to select; must be <= GPU_MAX_SELECTION_K
/// @param stream  CUDA stream the selection kernel is issued on
void runBlockSelectPair(Tensor<float, 2, true>& inK,
                        Tensor<int, 2, true>& inV,
                        Tensor<float, 2, true>& outK,
                        Tensor<int, 2, true>& outV,
                        bool dir, int k, cudaStream_t stream) {
  FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);

  // Direction and queue length are token-pasted by BLOCK_SELECT_PAIR_CALL, so
  // each combination must appear literally.
  if (dir) {
    if (k == 1) {
      BLOCK_SELECT_PAIR_CALL(float, true, 1);
    } else if (k <= 32) {
      BLOCK_SELECT_PAIR_CALL(float, true, 32);
    } else if (k <= 64) {
      BLOCK_SELECT_PAIR_CALL(float, true, 64);
    } else if (k <= 128) {
      BLOCK_SELECT_PAIR_CALL(float, true, 128);
    } else if (k <= 256) {
      BLOCK_SELECT_PAIR_CALL(float, true, 256);
    } else if (k <= 512) {
      BLOCK_SELECT_PAIR_CALL(float, true, 512);
    } else if (k <= 1024) {
      BLOCK_SELECT_PAIR_CALL(float, true, 1024);
#if GPU_MAX_SELECTION_K >= 2048
    } else if (k <= 2048) {
      BLOCK_SELECT_PAIR_CALL(float, true, 2048);
#endif
    } else {
      // Reached only if GPU_MAX_SELECTION_K allows a k larger than the
      // biggest compiled queue. Fail loudly rather than silently returning
      // with outK/outV never written.
      FAISS_ASSERT(false);
    }
  } else {
    if (k == 1) {
      BLOCK_SELECT_PAIR_CALL(float, false, 1);
    } else if (k <= 32) {
      BLOCK_SELECT_PAIR_CALL(float, false, 32);
    } else if (k <= 64) {
      BLOCK_SELECT_PAIR_CALL(float, false, 64);
    } else if (k <= 128) {
      BLOCK_SELECT_PAIR_CALL(float, false, 128);
    } else if (k <= 256) {
      BLOCK_SELECT_PAIR_CALL(float, false, 256);
    } else if (k <= 512) {
      BLOCK_SELECT_PAIR_CALL(float, false, 512);
    } else if (k <= 1024) {
      BLOCK_SELECT_PAIR_CALL(float, false, 1024);
#if GPU_MAX_SELECTION_K >= 2048
    } else if (k <= 2048) {
      BLOCK_SELECT_PAIR_CALL(float, false, 2048);
#endif
    } else {
      // See the matching branch above: no compiled queue can hold k.
      FAISS_ASSERT(false);
    }
  }
}
143 
144 } } // namespace