Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
LoadStoreOperators.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 
11 #pragma once
12 
13 #include "Float16.cuh"
14 
15 //
16 // Templated wrappers to express load/store for different scalar and vector
17 // types, so kernels can have the same written form but can operate
18 // over half and float, and on vector types transparently
19 //
20 
21 namespace faiss { namespace gpu {
22 
// Generic load/store wrapper: transfers one value of type T through an
// untyped pointer using an ordinary dereference.
template <typename T>
struct LoadStore {
  // Returns, by value, the T stored at address `p`.
  static inline __device__ T load(void* p) {
    T* typed = static_cast<T*>(p);
    return *typed;
  }

  // Assigns `v` to the T stored at address `p`.
  static inline __device__ void store(void* p, const T& v) {
    T* typed = static_cast<T*>(p);
    *typed = v;
  }
};
33 
34 #ifdef FAISS_USE_FLOAT16
35 
// Specialization for Half4: the whole value is moved with a single
// vectorized PTX access of two 32-bit words (v2.u32), one word per half2
// member (`a.x`, then `b.x`).
//
// PTX vector accesses require the address to be aligned to the access size,
// i.e. `p` must be 8-byte aligned here.
template <>
struct LoadStore<Half4> {
  // Loads a Half4 from `p`. The instruction is `ld.global`, so `p` must
  // refer to global memory.
  static inline __device__ Half4 load(void* p) {
    Half4 out;
    // `volatile` + "memory": the asm reads memory through a pointer operand,
    // which the compiler cannot see. Without these it may treat the load as
    // a pure function of `p` and merge or move it across stores to the same
    // location.
    asm volatile("ld.global.v2.u32 {%0, %1}, [%2];"
                 : "=r"(out.a.x), "=r"(out.b.x)
                 : "l"(p)
                 : "memory");
    return out;
  }

  // Stores `v` to `p`. The instruction is a generic-address `st` (no state
  // space qualifier).
  static inline __device__ void store(void* p, const Half4& v) {
    // "memory" tells the compiler that memory is modified behind its back,
    // so values previously read from `p` are not reused after the store.
    asm volatile("st.v2.u32 [%0], {%1, %2};"
                 :
                 : "l"(p), "r"(v.a.x), "r"(v.b.x)
                 : "memory");
  }
};
49 
// Specialization for Half8: the whole value is moved with a single
// vectorized PTX access of four 32-bit words (v4.u32), one word per half2
// member (`a.a.x`, `a.b.x`, `b.a.x`, `b.b.x`, in that order).
//
// PTX vector accesses require the address to be aligned to the access size,
// i.e. `p` must be 16-byte aligned here.
template <>
struct LoadStore<Half8> {
  // Loads a Half8 from `p`. The instruction is `ld.global`, so `p` must
  // refer to global memory.
  static inline __device__ Half8 load(void* p) {
    Half8 out;
    // `volatile` + "memory": the asm reads memory through a pointer operand,
    // which the compiler cannot see. Without these it may treat the load as
    // a pure function of `p` and merge or move it across stores to the same
    // location.
    asm volatile("ld.global.v4.u32 {%0, %1, %2, %3}, [%4];"
                 : "=r"(out.a.a.x), "=r"(out.a.b.x),
                   "=r"(out.b.a.x), "=r"(out.b.b.x)
                 : "l"(p)
                 : "memory");
    return out;
  }

  // Stores `v` to `p`. The instruction is a generic-address `st` (no state
  // space qualifier).
  static inline __device__ void store(void* p, const Half8& v) {
    // "memory" tells the compiler that memory is modified behind its back,
    // so values previously read from `p` are not reused after the store.
    asm volatile("st.v4.u32 [%0], {%1, %2, %3, %4};"
                 :
                 : "l"(p), "r"(v.a.a.x), "r"(v.a.b.x),
                   "r"(v.b.a.x), "r"(v.b.b.x)
                 : "memory");
  }
};
65 
66 #endif // FAISS_USE_FLOAT16
67 
68 } } // namespace