Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
WarpSelectFloat.cu
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #include "warpselect/WarpSelectImpl.cuh"
9 #include "DeviceDefs.cuh"
10 
11 namespace faiss { namespace gpu {
12 
13 // warp Q to thread Q:
14 // 1, 1
15 // 32, 2
16 // 64, 3
17 // 128, 3
18 // 256, 4
19 // 512, 8
20 // 1024, 8
21 // 2048, 8
22 
// Float warp-select variants, grouped by capacity. Each capacity is listed
// once per direction flag (true, then false). WARP_SELECT_DECL is presumably
// provided by warpselect/WarpSelectImpl.cuh -- confirm there.
WARP_SELECT_DECL(float, true, 1);
WARP_SELECT_DECL(float, false, 1);

WARP_SELECT_DECL(float, true, 32);
WARP_SELECT_DECL(float, false, 32);

WARP_SELECT_DECL(float, true, 64);
WARP_SELECT_DECL(float, false, 64);

WARP_SELECT_DECL(float, true, 128);
WARP_SELECT_DECL(float, false, 128);

WARP_SELECT_DECL(float, true, 256);
WARP_SELECT_DECL(float, false, 256);

WARP_SELECT_DECL(float, true, 512);
WARP_SELECT_DECL(float, false, 512);

WARP_SELECT_DECL(float, true, 1024);
WARP_SELECT_DECL(float, false, 1024);

// The 2048 capacity exists only when the build-time selection limit
// permits it; both direction flags share a single guard.
#if GPU_MAX_SELECTION_K >= 2048
WARP_SELECT_DECL(float, true, 2048);
WARP_SELECT_DECL(float, false, 2048);
#endif
44 
// Dispatches a float warp-select request to the WARP_SELECT_CALL variant
// whose capacity is the smallest supported bucket (1, 32, 64, 128, 256, 512,
// 1024 and, when compiled in, 2048) that is >= k.
//
// Parameters:
//   in, outK, outV, stream - not referenced by name in this function;
//       presumably consumed by the WARP_SELECT_CALL expansion (confirm in
//       warpselect/WarpSelectImpl.cuh).
//   dir - chooses between the `true` and `false` variants; the meaning of
//       each direction is defined by the implementation header.
//   k   - number of elements to select; must not exceed GPU_MAX_SELECTION_K.
void runWarpSelect(Tensor<float, 2, true>& in,
                   Tensor<float, 2, true>& outK,
                   Tensor<int, 2, true>& outV,
                   bool dir, int k, cudaStream_t stream) {
  // Check against the build-time limit rather than a literal 2048: the 2048
  // branches below are compiled only when GPU_MAX_SELECTION_K >= 2048, so a
  // literal bound would let k in (1024, 2048] through on smaller builds and
  // the call would return without dispatching anything.
  FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);

  if (dir) {
    if (k == 1) {
      WARP_SELECT_CALL(float, true, 1);
    } else if (k <= 32) {
      WARP_SELECT_CALL(float, true, 32);
    } else if (k <= 64) {
      WARP_SELECT_CALL(float, true, 64);
    } else if (k <= 128) {
      WARP_SELECT_CALL(float, true, 128);
    } else if (k <= 256) {
      WARP_SELECT_CALL(float, true, 256);
    } else if (k <= 512) {
      WARP_SELECT_CALL(float, true, 512);
    } else if (k <= 1024) {
      WARP_SELECT_CALL(float, true, 1024);
#if GPU_MAX_SELECTION_K >= 2048
    } else if (k <= 2048) {
      WARP_SELECT_CALL(float, true, 2048);
#endif
    }
  } else {
    if (k == 1) {
      WARP_SELECT_CALL(float, false, 1);
    } else if (k <= 32) {
      WARP_SELECT_CALL(float, false, 32);
    } else if (k <= 64) {
      WARP_SELECT_CALL(float, false, 64);
    } else if (k <= 128) {
      WARP_SELECT_CALL(float, false, 128);
    } else if (k <= 256) {
      WARP_SELECT_CALL(float, false, 256);
    } else if (k <= 512) {
      WARP_SELECT_CALL(float, false, 512);
    } else if (k <= 1024) {
      WARP_SELECT_CALL(float, false, 1024);
#if GPU_MAX_SELECTION_K >= 2048
    } else if (k <= 2048) {
      WARP_SELECT_CALL(float, false, 2048);
#endif
    }
  }
}
93 
94 } } // namespace