Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
BlockSelectHalf.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 #include "blockselect/BlockSelectImpl.cuh"
11 
12 namespace faiss { namespace gpu {
13 
14 #ifdef FAISS_USE_FLOAT16
15 
// One BLOCK_SELECT_DECL per (direction, warp queue size) combination for
// half-precision keys. The two direction variants of each size are listed
// side by side so the full set of supported sizes is easy to audit against
// the dispatch in runBlockSelect() below.
//
// Warp queue size -> thread queue size, as recorded by the original author
// (presumably the thread queue length paired with each warp queue size by
// the implementation behind BlockSelectImpl.cuh -- TODO confirm):
//
//   warp Q | thread Q
//   -------+---------
//        1 | 1
//       32 | 2
//       64 | 3
//      128 | 3
//      256 | 4
//      512 | 8
//     1024 | 8

BLOCK_SELECT_DECL(half, true, 1);
BLOCK_SELECT_DECL(half, false, 1);

BLOCK_SELECT_DECL(half, true, 32);
BLOCK_SELECT_DECL(half, false, 32);

BLOCK_SELECT_DECL(half, true, 64);
BLOCK_SELECT_DECL(half, false, 64);

BLOCK_SELECT_DECL(half, true, 128);
BLOCK_SELECT_DECL(half, false, 128);

BLOCK_SELECT_DECL(half, true, 256);
BLOCK_SELECT_DECL(half, false, 256);

BLOCK_SELECT_DECL(half, true, 512);
BLOCK_SELECT_DECL(half, false, 512);

BLOCK_SELECT_DECL(half, true, 1024);
BLOCK_SELECT_DECL(half, false, 1024);
40 
/// Dispatches a half-precision block-select to the specialization that
/// matches the requested direction and the smallest supported queue size
/// that can hold `k` (1, 32, 64, 128, 256, 512 or 1024).
///
/// The arguments other than `dir` and `k` are not referenced directly in
/// this function; they are presumably consumed by name inside
/// BLOCK_SELECT_CALL (defined in BlockSelectImpl.cuh) -- TODO confirm.
///
/// @param in      input tensor of half values
/// @param outK    output tensor of half values (presumably the selected
///                keys -- confirm against BlockSelectImpl.cuh)
/// @param outV    output tensor of int values (presumably the indices of
///                the selected keys -- confirm against BlockSelectImpl.cuh)
/// @param dir     selects between the `true` and `false` specializations
/// @param k       number of elements to select; must not exceed 1024
/// @param stream  CUDA stream argument
void runBlockSelect(Tensor<half, 2, true>& in,
                    Tensor<half, 2, true>& outK,
                    Tensor<int, 2, true>& outV,
                    bool dir, int k, cudaStream_t stream) {
  FAISS_ASSERT(k <= 1024);

  // NOTE(review): there is no lower bound check on k; any k < 1 falls into
  // the 32-wide bucket below, exactly as it always has.

// Invokes the specialization for the given queue size, resolving the
// direction at the call site so that the size ladder is written only once.
#define RUN_BLOCK_SELECT_HALF_(WARP_Q)            \
  do {                                            \
    if (dir) {                                    \
      BLOCK_SELECT_CALL(half, true, WARP_Q);      \
    } else {                                      \
      BLOCK_SELECT_CALL(half, false, WARP_Q);     \
    }                                             \
  } while (0)

  // k == 1 has its own dedicated specialization.
  if (k == 1) {
    RUN_BLOCK_SELECT_HALF_(1);
    return;
  }

  // Otherwise pick the first bucket that is large enough for k.
  if (k <= 32) {
    RUN_BLOCK_SELECT_HALF_(32);
    return;
  }

  if (k <= 64) {
    RUN_BLOCK_SELECT_HALF_(64);
    return;
  }

  if (k <= 128) {
    RUN_BLOCK_SELECT_HALF_(128);
    return;
  }

  if (k <= 256) {
    RUN_BLOCK_SELECT_HALF_(256);
    return;
  }

  if (k <= 512) {
    RUN_BLOCK_SELECT_HALF_(512);
    return;
  }

  // Kept as an explicit test (rather than a bare fall-through) so that
  // nothing is launched for k > 1024 if the assertion above does not stop
  // execution.
  if (k <= 1024) {
    RUN_BLOCK_SELECT_HALF_(1024);
  }

#undef RUN_BLOCK_SELECT_HALF_
}
81 
82 #endif
83 
84 } } // namespace