Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
LoadStoreOperators.cuh
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved.
11 
12 #pragma once
13 
14 #include "Float16.cuh"
15 
16 //
17 // Templated wrappers to express load/store for different scalar and vector
18 // types, so kernels can have the same written form but can operate
19 // over half and float, and on vector types transparently
20 //
21 
22 namespace faiss { namespace gpu {
23 
// Generic load/store wrapper: accesses the object at an untyped address
// through an ordinary pointer dereference of type T. Specializations may
// replace this with explicit vectorized instructions.
template <typename T>
struct LoadStore {
  // Returns a copy of the T located at address p.
  static inline __device__ T load(void* p) {
    T* typed = static_cast<T*>(p);
    return *typed;
  }

  // Assigns v to the T located at address p.
  static inline __device__ void store(void* p, const T& v) {
    T* typed = static_cast<T*>(p);
    *typed = v;
  }
};
34 
35 #ifdef FAISS_USE_FLOAT16
36 
// Specialization for Half4: transfers the two 32-bit words (a.x and b.x)
// with a single two-element vector PTX instruction instead of separate
// scalar accesses.
//
// NOTE(review): the load uses the global state space (ld.global) while the
// store uses generic addressing (st) -- confirm callers only ever load from
// global memory. Vector PTX accesses presumably also require p to be
// aligned to the full 8-byte access width; verify at call sites.
template <>
struct LoadStore<Half4> {
  // Reads a Half4 from address p with one ld.global.v2.u32.
  static inline __device__ Half4 load(void* p) {
    Half4 out;
    // The asm reads memory that is not listed as an operand, so it is
    // marked volatile with a "memory" clobber; otherwise the compiler may
    // treat it as a pure function of p and merge it with, or move it
    // across, surrounding stores to the same location.
    asm volatile("ld.global.v2.u32 {%0, %1}, [%2];" :
                 "=r"(out.a.x), "=r"(out.b.x) : "l"(p) : "memory");
    return out;
  }

  // Writes v to address p with one st.v2.u32.
  static inline __device__ void store(void* p, const Half4& v) {
    // The "memory" clobber tells the compiler that this statement writes
    // memory, so ordinary loads/stores are not reordered around it.
    asm volatile("st.v2.u32 [%0], {%1, %2};" : :
                 "l"(p), "r"(v.a.x), "r"(v.b.x) : "memory");
  }
};
50 
// Specialization for Half8: transfers the four 32-bit words (a.a.x, a.b.x,
// b.a.x and b.b.x) with a single four-element vector PTX instruction
// instead of separate scalar accesses.
//
// NOTE(review): the load uses the global state space (ld.global) while the
// store uses generic addressing (st) -- confirm callers only ever load from
// global memory. Vector PTX accesses presumably also require p to be
// aligned to the full 16-byte access width; verify at call sites.
template <>
struct LoadStore<Half8> {
  // Reads a Half8 from address p with one ld.global.v4.u32.
  static inline __device__ Half8 load(void* p) {
    Half8 out;
    // The asm reads memory that is not listed as an operand, so it is
    // marked volatile with a "memory" clobber; otherwise the compiler may
    // treat it as a pure function of p and merge it with, or move it
    // across, surrounding stores to the same location.
    asm volatile("ld.global.v4.u32 {%0, %1, %2, %3}, [%4];" :
                 "=r"(out.a.a.x), "=r"(out.a.b.x),
                 "=r"(out.b.a.x), "=r"(out.b.b.x) : "l"(p) : "memory");
    return out;
  }

  // Writes v to address p with one st.v4.u32.
  static inline __device__ void store(void* p, const Half8& v) {
    // The "memory" clobber tells the compiler that this statement writes
    // memory, so ordinary loads/stores are not reordered around it.
    asm volatile("st.v4.u32 [%0], {%1, %2, %3, %4};" : :
                 "l"(p), "r"(v.a.a.x), "r"(v.a.b.x),
                 "r"(v.b.a.x), "r"(v.b.b.x) : "memory");
  }
};
66 
67 #endif // FAISS_USE_FLOAT16
68 
69 } } // namespace