Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
WarpSelectHalf.cu
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved.
11 #include "warpselect/WarpSelectImpl.cuh"
12 
namespace faiss { namespace gpu {

#ifdef FAISS_USE_FLOAT16

// Half-precision warp-select specializations, one per (direction, k bucket).
// Authors' note on how each warp queue length (the k bucket) maps to the
// per-thread queue length of the corresponding instantiation:
//
//   warp Q | thread Q
//   -------+---------
//        1 |  1
//       32 |  2
//       64 |  3
//      128 |  3
//      256 |  4
//      512 |  8
//     1024 |  8
//
// NOTE(review): the definitions are presumably generated from
// warpselect/WarpSelectImpl.cuh in separate translation units — confirm.
// Declarations are grouped by bucket; their relative order is irrelevant.
WARP_SELECT_DECL(half, true, 1);
WARP_SELECT_DECL(half, false, 1);

WARP_SELECT_DECL(half, true, 32);
WARP_SELECT_DECL(half, false, 32);

WARP_SELECT_DECL(half, true, 64);
WARP_SELECT_DECL(half, false, 64);

WARP_SELECT_DECL(half, true, 128);
WARP_SELECT_DECL(half, false, 128);

WARP_SELECT_DECL(half, true, 256);
WARP_SELECT_DECL(half, false, 256);

WARP_SELECT_DECL(half, true, 512);
WARP_SELECT_DECL(half, false, 512);

WARP_SELECT_DECL(half, true, 1024);
WARP_SELECT_DECL(half, false, 1024);

// Dispatches a half-precision warp select to the specialization whose bucket
// is the smallest one covering `k`:
//   k == 1            -> bucket 1
//   k <= 32           -> bucket 32 (this also catches k <= 0)
//   k <= 64, ..., 1024 -> buckets 64, 128, 256, 512, 1024
// `dir` selects the `true` or `false` direction specialization; presumably
// true means "select largest" — confirm against WarpSelectImpl.cuh.
// The work is issued on `stream`; outK/outV presumably receive the selected
// values and their int indices (TODO confirm against WARP_SELECT_CALL).
//
// Precondition: k <= 1024 (enforced by FAISS_ASSERT).
void runWarpSelect(Tensor<half, 2, true>& in,
                   Tensor<half, 2, true>& outK,
                   Tensor<int, 2, true>& outV,
                   bool dir, int k, cudaStream_t stream) {
  FAISS_ASSERT(k <= 1024);

  // WARP_SELECT_CALL is only ever given a literal true/false direction
  // (matching the WARP_SELECT_DECL instantiations), so the runtime `dir`
  // flag is turned into one of the two literal calls here, per bucket.
#define RUN_HALF_WARP_SELECT_BUCKET(WARP_Q)     \
  do {                                          \
    if (dir) {                                  \
      WARP_SELECT_CALL(half, true, WARP_Q);     \
    } else {                                    \
      WARP_SELECT_CALL(half, false, WARP_Q);    \
    }                                           \
  } while (0)

  if (k == 1) {
    RUN_HALF_WARP_SELECT_BUCKET(1);
  } else if (k <= 32) {
    RUN_HALF_WARP_SELECT_BUCKET(32);
  } else if (k <= 64) {
    RUN_HALF_WARP_SELECT_BUCKET(64);
  } else if (k <= 128) {
    RUN_HALF_WARP_SELECT_BUCKET(128);
  } else if (k <= 256) {
    RUN_HALF_WARP_SELECT_BUCKET(256);
  } else if (k <= 512) {
    RUN_HALF_WARP_SELECT_BUCKET(512);
  } else if (k <= 1024) {
    RUN_HALF_WARP_SELECT_BUCKET(1024);
  }

#undef RUN_HALF_WARP_SELECT_BUCKET
}

#endif

} } // namespace