Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
BlockSelectHalf.cu
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #include "blockselect/BlockSelectImpl.cuh"
9 #include "DeviceDefs.cuh"
10 
11 namespace faiss { namespace gpu {
12 
13 #ifdef FAISS_USE_FLOAT16
14 
15 // warp Q to thread Q:
16 // 1, 1
17 // 32, 2
18 // 64, 3
19 // 128, 3
20 // 256, 4
21 // 512, 8
22 // 1024, 8
23 // 2048, 8
24 
// Declarations for every half-precision block-select instantiation that the
// dispatchers in this file can reach, listed as one true/false pair per size.
// BLOCK_SELECT_DECL is defined outside this file; it presumably declares the
// function that the matching BLOCK_SELECT_CALL / BLOCK_SELECT_PAIR_CALL
// invokes (TODO confirm against blockselect/BlockSelectImpl.cuh).
BLOCK_SELECT_DECL(half, true, 1);
BLOCK_SELECT_DECL(half, false, 1);

BLOCK_SELECT_DECL(half, true, 32);
BLOCK_SELECT_DECL(half, false, 32);

BLOCK_SELECT_DECL(half, true, 64);
BLOCK_SELECT_DECL(half, false, 64);

BLOCK_SELECT_DECL(half, true, 128);
BLOCK_SELECT_DECL(half, false, 128);

BLOCK_SELECT_DECL(half, true, 256);
BLOCK_SELECT_DECL(half, false, 256);

BLOCK_SELECT_DECL(half, true, 512);
BLOCK_SELECT_DECL(half, false, 512);

BLOCK_SELECT_DECL(half, true, 1024);
BLOCK_SELECT_DECL(half, false, 1024);

// The 2048 variants are declared only when the build-time selection limit
// reaches 2048.
#if GPU_MAX_SELECTION_K >= 2048
BLOCK_SELECT_DECL(half, true, 2048);
BLOCK_SELECT_DECL(half, false, 2048);
#endif
46 
// Dispatches a half-precision block-select to the instantiation chosen by
// `dir` and `k`.
//
// BLOCK_SELECT_CALL is defined outside this file; it is expected to pick up
// `in`, `outK`, `outV`, `dir`, `k` and `stream` by name (TODO confirm), so the
// parameter names must stay exactly as they are.
//
// k is matched against the sizes 1, 32, 64, 128, 256, 512, 1024 and (only when
// GPU_MAX_SELECTION_K >= 2048) 2048: size 1 requires k == 1 exactly, every
// other size accepts any k not already matched that is <= that size. Any k
// other than 1 that is <= 32 (including k <= 0) therefore lands on size 32.
// A k larger than GPU_MAX_SELECTION_K trips the assertion.
void runBlockSelect(Tensor<half, 2, true>& in,
                    Tensor<half, 2, true>& outK,
                    Tensor<int, 2, true>& outV,
                    bool dir, int k, cudaStream_t stream) {
  FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);

  if (!dir) {
    // dir == false instantiations
    if (k == 1) {
      BLOCK_SELECT_CALL(half, false, 1);
      return;
    }
    if (k <= 32) {
      BLOCK_SELECT_CALL(half, false, 32);
      return;
    }
    if (k <= 64) {
      BLOCK_SELECT_CALL(half, false, 64);
      return;
    }
    if (k <= 128) {
      BLOCK_SELECT_CALL(half, false, 128);
      return;
    }
    if (k <= 256) {
      BLOCK_SELECT_CALL(half, false, 256);
      return;
    }
    if (k <= 512) {
      BLOCK_SELECT_CALL(half, false, 512);
      return;
    }
    if (k <= 1024) {
      BLOCK_SELECT_CALL(half, false, 1024);
      return;
    }
#if GPU_MAX_SELECTION_K >= 2048
    if (k <= 2048) {
      BLOCK_SELECT_CALL(half, false, 2048);
      return;
    }
#endif
    // No size matched: nothing is launched.
    return;
  }

  // dir == true instantiations
  if (k == 1) {
    BLOCK_SELECT_CALL(half, true, 1);
    return;
  }
  if (k <= 32) {
    BLOCK_SELECT_CALL(half, true, 32);
    return;
  }
  if (k <= 64) {
    BLOCK_SELECT_CALL(half, true, 64);
    return;
  }
  if (k <= 128) {
    BLOCK_SELECT_CALL(half, true, 128);
    return;
  }
  if (k <= 256) {
    BLOCK_SELECT_CALL(half, true, 256);
    return;
  }
  if (k <= 512) {
    BLOCK_SELECT_CALL(half, true, 512);
    return;
  }
  if (k <= 1024) {
    BLOCK_SELECT_CALL(half, true, 1024);
    return;
  }
#if GPU_MAX_SELECTION_K >= 2048
  if (k <= 2048) {
    BLOCK_SELECT_CALL(half, true, 2048);
    return;
  }
#endif
  // No size matched: nothing is launched.
}
95 
// Dispatches a half-precision block-select over key/value input (`inK`,
// `inV`) to the instantiation chosen by `dir` and `k`.
//
// BLOCK_SELECT_PAIR_CALL is defined outside this file; it is expected to pick
// up `inK`, `inV`, `outK`, `outV`, `dir`, `k` and `stream` by name (TODO
// confirm), so the parameter names must stay exactly as they are.
//
// k is matched against the sizes 1, 32, 64, 128, 256, 512, 1024 and (only when
// GPU_MAX_SELECTION_K >= 2048) 2048: size 1 requires k == 1 exactly, every
// other size accepts any k not already matched that is <= that size. Any k
// other than 1 that is <= 32 (including k <= 0) therefore lands on size 32.
// A k larger than GPU_MAX_SELECTION_K trips the assertion.
void runBlockSelectPair(Tensor<half, 2, true>& inK,
                        Tensor<int, 2, true>& inV,
                        Tensor<half, 2, true>& outK,
                        Tensor<int, 2, true>& outV,
                        bool dir, int k, cudaStream_t stream) {
  FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);

  if (!dir) {
    // dir == false instantiations
    if (k == 1) {
      BLOCK_SELECT_PAIR_CALL(half, false, 1);
      return;
    }
    if (k <= 32) {
      BLOCK_SELECT_PAIR_CALL(half, false, 32);
      return;
    }
    if (k <= 64) {
      BLOCK_SELECT_PAIR_CALL(half, false, 64);
      return;
    }
    if (k <= 128) {
      BLOCK_SELECT_PAIR_CALL(half, false, 128);
      return;
    }
    if (k <= 256) {
      BLOCK_SELECT_PAIR_CALL(half, false, 256);
      return;
    }
    if (k <= 512) {
      BLOCK_SELECT_PAIR_CALL(half, false, 512);
      return;
    }
    if (k <= 1024) {
      BLOCK_SELECT_PAIR_CALL(half, false, 1024);
      return;
    }
#if GPU_MAX_SELECTION_K >= 2048
    if (k <= 2048) {
      BLOCK_SELECT_PAIR_CALL(half, false, 2048);
      return;
    }
#endif
    // No size matched: nothing is launched.
    return;
  }

  // dir == true instantiations
  if (k == 1) {
    BLOCK_SELECT_PAIR_CALL(half, true, 1);
    return;
  }
  if (k <= 32) {
    BLOCK_SELECT_PAIR_CALL(half, true, 32);
    return;
  }
  if (k <= 64) {
    BLOCK_SELECT_PAIR_CALL(half, true, 64);
    return;
  }
  if (k <= 128) {
    BLOCK_SELECT_PAIR_CALL(half, true, 128);
    return;
  }
  if (k <= 256) {
    BLOCK_SELECT_PAIR_CALL(half, true, 256);
    return;
  }
  if (k <= 512) {
    BLOCK_SELECT_PAIR_CALL(half, true, 512);
    return;
  }
  if (k <= 1024) {
    BLOCK_SELECT_PAIR_CALL(half, true, 1024);
    return;
  }
#if GPU_MAX_SELECTION_K >= 2048
  if (k <= 2048) {
    BLOCK_SELECT_PAIR_CALL(half, true, 2048);
    return;
  }
#endif
  // No size matched: nothing is launched.
}
145 
146 #endif
147 
148 } } // namespace