Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
WarpSelectFloat.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the CC-by-NC license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 #include "warpselect/WarpSelectImpl.cuh"
11 
12 namespace faiss { namespace gpu {
13 
// Warp queue length (the WARP_Q value passed to the macros below) ->
// per-thread queue length:
// 1, 1
// 32, 2
// 64, 3
// 128, 3
// 256, 4
// 512, 8
// 1024, 8
22 
// Declarations for every (direction, warp queue length) combination that
// runWarpSelect dispatches to. Each pair below covers one warp queue length,
// first for the `true` direction and then for the `false` direction.
// NOTE(review): the definitions presumably live in other translation units
// generated from WarpSelectImpl.cuh -- confirm before adding a new length.
WARP_SELECT_DECL(float, true, 1);
WARP_SELECT_DECL(float, false, 1);

WARP_SELECT_DECL(float, true, 32);
WARP_SELECT_DECL(float, false, 32);

WARP_SELECT_DECL(float, true, 64);
WARP_SELECT_DECL(float, false, 64);

WARP_SELECT_DECL(float, true, 128);
WARP_SELECT_DECL(float, false, 128);

WARP_SELECT_DECL(float, true, 256);
WARP_SELECT_DECL(float, false, 256);

WARP_SELECT_DECL(float, true, 512);
WARP_SELECT_DECL(float, false, 512);

WARP_SELECT_DECL(float, true, 1024);
WARP_SELECT_DECL(float, false, 1024);
38 
// Dispatches a warp-level selection over the 2-D float tensor `in` to the
// precompiled instantiation whose warp queue length is the smallest of
// {1, 32, 64, 128, 256, 512, 1024} that is >= k.
//
// Parameters:
//   in     - 2-D float input tensor.
//   outK   - 2-D float output tensor. Presumably receives the selected
//            values -- TODO confirm against WarpSelectImpl.cuh.
//   outV   - 2-D int output tensor. Presumably receives the matching
//            indices -- TODO confirm.
//   dir    - chooses the `true` or `false` instantiation. Presumably this is
//            largest-k vs. smallest-k selection -- TODO confirm.
//   k      - number of results; must be <= 1024 (asserted below).
//   stream - CUDA stream handed to the selected instantiation. Presumably
//            the work is enqueued there asynchronously -- confirm.
//
// NOTE(review): WARP_SELECT_CALL receives no data arguments. It presumably
// refers to `in`, `outK`, `outV`, `dir`, `k` and `stream` by name, so these
// parameters must not be renamed.
void runWarpSelect(Tensor<float, 2, true>& in,
                   Tensor<float, 2, true>& outK,
                   Tensor<int, 2, true>& outV,
                   bool dir, int k, cudaStream_t stream) {
  // No instantiation exists for a warp queue longer than 1024.
  FAISS_ASSERT(k <= 1024);

  // Choose the k bucket first, then the direction. The direction is a
  // compile-time macro token, so every bucket has to branch on `dir`
  // explicitly.
  // NOTE(review): k <= 0 satisfies `k <= 32` and lands in the 32 bucket.
  if (k == 1) {
    if (dir) {
      WARP_SELECT_CALL(float, true, 1);
    } else {
      WARP_SELECT_CALL(float, false, 1);
    }
  } else if (k <= 32) {
    if (dir) {
      WARP_SELECT_CALL(float, true, 32);
    } else {
      WARP_SELECT_CALL(float, false, 32);
    }
  } else if (k <= 64) {
    if (dir) {
      WARP_SELECT_CALL(float, true, 64);
    } else {
      WARP_SELECT_CALL(float, false, 64);
    }
  } else if (k <= 128) {
    if (dir) {
      WARP_SELECT_CALL(float, true, 128);
    } else {
      WARP_SELECT_CALL(float, false, 128);
    }
  } else if (k <= 256) {
    if (dir) {
      WARP_SELECT_CALL(float, true, 256);
    } else {
      WARP_SELECT_CALL(float, false, 256);
    }
  } else if (k <= 512) {
    if (dir) {
      WARP_SELECT_CALL(float, true, 512);
    } else {
      WARP_SELECT_CALL(float, false, 512);
    }
  } else if (k <= 1024) {
    if (dir) {
      WARP_SELECT_CALL(float, true, 1024);
    } else {
      WARP_SELECT_CALL(float, false, 1024);
    }
  }
  // k > 1024 makes no call here; the FAISS_ASSERT above is meant to reject it.
}
79 
80 } } // namespace