Faiss
LoadStoreOperators.cuh
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */


#pragma once

#include "Float16.cuh"

#ifndef __HALF2_TO_UI
// cuda_fp16.hpp doesn't export this
#define __HALF2_TO_UI(var) *(reinterpret_cast<unsigned int *>(&(var)))
#endif

//
// Templated wrappers to express load/store for different scalar and vector
// types, so kernels can be written in the same form yet operate transparently
// over half and float, and over vector types
//

namespace faiss { namespace gpu {

template <typename T>
struct LoadStore {
  static inline __device__ T load(void* p) {
    return *((T*) p);
  }

  static inline __device__ void store(void* p, const T& v) {
    *((T*) p) = v;
  }
};

#ifdef FAISS_USE_FLOAT16

// Specialization for Half4: load/store 4 half values with a single
// vectorized 64-bit (2 x u32) transaction via inline PTX.
template <>
struct LoadStore<Half4> {
  static inline __device__ Half4 load(void* p) {
    Half4 out;
#if CUDA_VERSION >= 9000
    asm("ld.global.v2.u32 {%0, %1}, [%2];" :
        "=r"(__HALF2_TO_UI(out.a)), "=r"(__HALF2_TO_UI(out.b)) : "l"(p));
#else
    asm("ld.global.v2.u32 {%0, %1}, [%2];" :
        "=r"(out.a.x), "=r"(out.b.x) : "l"(p));
#endif
    return out;
  }

  static inline __device__ void store(void* p, Half4& v) {
#if CUDA_VERSION >= 9000
    asm("st.v2.u32 [%0], {%1, %2};" : : "l"(p),
        "r"(__HALF2_TO_UI(v.a)), "r"(__HALF2_TO_UI(v.b)));
#else
    asm("st.v2.u32 [%0], {%1, %2};" : : "l"(p), "r"(v.a.x), "r"(v.b.x));
#endif
  }
};

// Specialization for Half8: load/store 8 half values with a single
// vectorized 128-bit (4 x u32) transaction via inline PTX.
template <>
struct LoadStore<Half8> {
  static inline __device__ Half8 load(void* p) {
    Half8 out;
#if CUDA_VERSION >= 9000
    asm("ld.global.v4.u32 {%0, %1, %2, %3}, [%4];" :
        "=r"(__HALF2_TO_UI(out.a.a)), "=r"(__HALF2_TO_UI(out.a.b)),
        "=r"(__HALF2_TO_UI(out.b.a)), "=r"(__HALF2_TO_UI(out.b.b)) : "l"(p));
#else
    asm("ld.global.v4.u32 {%0, %1, %2, %3}, [%4];" :
        "=r"(out.a.a.x), "=r"(out.a.b.x),
        "=r"(out.b.a.x), "=r"(out.b.b.x) : "l"(p));
#endif
    return out;
  }

  static inline __device__ void store(void* p, Half8& v) {
#if CUDA_VERSION >= 9000
    asm("st.v4.u32 [%0], {%1, %2, %3, %4};"
        : : "l"(p), "r"(__HALF2_TO_UI(v.a.a)), "r"(__HALF2_TO_UI(v.a.b)),
        "r"(__HALF2_TO_UI(v.b.a)), "r"(__HALF2_TO_UI(v.b.b)));
#else
    asm("st.v4.u32 [%0], {%1, %2, %3, %4};"
        : : "l"(p), "r"(v.a.a.x), "r"(v.a.b.x), "r"(v.b.a.x), "r"(v.b.b.x));
#endif
  }
};

#endif // FAISS_USE_FLOAT16

} } // namespace
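
Usage illustration (not part of the header): a minimal sketch of a kernel written once against LoadStore<T>, which can then be instantiated for a plain vector type such as float4 or, when FAISS_USE_FLOAT16 is defined, for Half4/Half8. The kernel name copyVectorsKernel and its parameters are hypothetical and shown only for illustration.

#include "LoadStoreOperators.cuh"

// Copies numVecs elements of type VecT from `in` to `out`.
template <typename VecT>
__global__ void copyVectorsKernel(VecT* in, VecT* out, int numVecs) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < numVecs) {
    // For plain types such as float4 this resolves to a simple dereference;
    // for Half4/Half8 it uses the vectorized PTX specializations above.
    VecT v = faiss::gpu::LoadStore<VecT>::load(&in[i]);
    faiss::gpu::LoadStore<VecT>::store(&out[i], v);
  }
}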