Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
WarpSelectHalf.cu
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #include "warpselect/WarpSelectImpl.cuh"
9 #include "DeviceDefs.cuh"
10 
11 namespace faiss { namespace gpu {
12 
13 #ifdef FAISS_USE_FLOAT16
14 
15 // warp Q to thread Q:
16 // 1, 1
17 // 32, 2
18 // 64, 3
19 // 128, 3
20 // 256, 4
21 // 512, 8
22 // 1024, 8
23 // 2048, 8
24 
// Declarations of the half-precision warp-select launchers used by
// runWarpSelect() below: one per (direction flag, warp queue size) pair.
// The two direction variants of each queue size are listed side by side.
// NOTE(review): the matching definitions are not in this file; presumably
// they are emitted elsewhere via the companion macros in
// warpselect/WarpSelectImpl.cuh -- confirm.
WARP_SELECT_DECL(half, true, 1);
WARP_SELECT_DECL(half, false, 1);

WARP_SELECT_DECL(half, true, 32);
WARP_SELECT_DECL(half, false, 32);

WARP_SELECT_DECL(half, true, 64);
WARP_SELECT_DECL(half, false, 64);

WARP_SELECT_DECL(half, true, 128);
WARP_SELECT_DECL(half, false, 128);

WARP_SELECT_DECL(half, true, 256);
WARP_SELECT_DECL(half, false, 256);

WARP_SELECT_DECL(half, true, 512);
WARP_SELECT_DECL(half, false, 512);

WARP_SELECT_DECL(half, true, 1024);
WARP_SELECT_DECL(half, false, 1024);

// The 2048-wide variants exist only when the build allows selecting that
// many elements.
#if GPU_MAX_SELECTION_K >= 2048
WARP_SELECT_DECL(half, true, 2048);
WARP_SELECT_DECL(half, false, 2048);
#endif
46 
// Runs warp-based k-selection on half-precision input by dispatching to the
// launcher instantiation with the smallest warp queue size that is >= k.
//
// Parameters:
//   in      input tensor
//   outK    output tensor of the same element type as `in`
//   outV    output tensor of int
//   dir     chooses between the `true` and `false` launcher families
//           (NOTE(review): presumably largest-vs-smallest selection; confirm
//           against WarpSelectImpl.cuh)
//   k       number of elements to select; must not exceed
//           GPU_MAX_SELECTION_K
//   stream  CUDA stream for the launch
//
// NOTE(review): WARP_SELECT_CALL takes no explicit arguments here, so it
// presumably forwards in/outK/outV/dir/k/stream by name -- keep the parameter
// names unchanged.
void runWarpSelect(Tensor<half, 2, true>& in,
                   Tensor<half, 2, true>& outK,
                   Tensor<int, 2, true>& outV,
                   bool dir, int k, cudaStream_t stream) {
  // Bound k by the same macro that gates the 2048-wide branches below, so the
  // precondition always matches the largest instantiation compiled in. A
  // hard-coded 1024 would make the `k <= 2048` branches unreachable.
  FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);

  if (dir) {
    if (k == 1) {
      WARP_SELECT_CALL(half, true, 1);
    } else if (k <= 32) {
      WARP_SELECT_CALL(half, true, 32);
    } else if (k <= 64) {
      WARP_SELECT_CALL(half, true, 64);
    } else if (k <= 128) {
      WARP_SELECT_CALL(half, true, 128);
    } else if (k <= 256) {
      WARP_SELECT_CALL(half, true, 256);
    } else if (k <= 512) {
      WARP_SELECT_CALL(half, true, 512);
    } else if (k <= 1024) {
      WARP_SELECT_CALL(half, true, 1024);
#if GPU_MAX_SELECTION_K >= 2048
    } else if (k <= 2048) {
      WARP_SELECT_CALL(half, true, 2048);
#endif
    }
  } else {
    if (k == 1) {
      WARP_SELECT_CALL(half, false, 1);
    } else if (k <= 32) {
      WARP_SELECT_CALL(half, false, 32);
    } else if (k <= 64) {
      WARP_SELECT_CALL(half, false, 64);
    } else if (k <= 128) {
      WARP_SELECT_CALL(half, false, 128);
    } else if (k <= 256) {
      WARP_SELECT_CALL(half, false, 256);
    } else if (k <= 512) {
      WARP_SELECT_CALL(half, false, 512);
    } else if (k <= 1024) {
      WARP_SELECT_CALL(half, false, 1024);
#if GPU_MAX_SELECTION_K >= 2048
    } else if (k <= 2048) {
      WARP_SELECT_CALL(half, false, 2048);
#endif
    }
  }
}
95 
96 #endif
97 
98 } } // namespace