Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
ConversionOperators.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 
10 #pragma once
11 
12 #include <cuda.h>
13 #include "../../Index.h"
14 #include "Float16.cuh"
15 
16 namespace faiss { namespace gpu {
17 
18 //
19 // Conversion utilities
20 //
21 
// Device-side unary functor: widens a 32-bit `int` into
// `faiss::Index::idx_t`, the index type declared in Index.h.
// NOTE(review): presumably used as an element-wise transform op when
// copying int indices into idx_t buffers — confirm against callers.
struct IntToIdxType {
  inline __device__ faiss::Index::idx_t operator()(int value) const {
    using IdxT = faiss::Index::idx_t;
    return static_cast<IdxT>(value);
  }
};
27 
// Primary template for `ConvertTo<Target>`: intentionally empty. Only the
// explicit specializations below provide a static `to(...)` member, so
// naming `ConvertTo<T>::to` for any other T fails to compile.
template <typename T>
struct ConvertTo {};
31 
// Conversion into `float`.
// - from float: returned unchanged (identity overload);
// - from half (only when FAISS_USE_FLOAT16 is defined): widened with
//   the CUDA intrinsic __half2float.
// NOTE(review): the identity overload presumably lets templated device code
// call ConvertTo<float>::to(x) regardless of the storage type of x.
template <>
struct ConvertTo<float> {
  static inline __device__ float to(float x) {
    return x;
  }

#ifdef FAISS_USE_FLOAT16
  static inline __device__ float to(half x) {
    return __half2float(x);
  }
#endif
};
39 
// Conversion into `float2`.
// - from float2: returned unchanged (identity overload);
// - from half2 (only when FAISS_USE_FLOAT16 is defined): widened with
//   the CUDA intrinsic __half22float2.
template <>
struct ConvertTo<float2> {
  static inline __device__ float2 to(float2 pair) {
    return pair;
  }

#ifdef FAISS_USE_FLOAT16
  static inline __device__ float2 to(half2 pair) {
    return __half22float2(pair);
  }
#endif
};
47 
// Conversion into `float4`.
// - from float4: returned unchanged (identity overload);
// - from Half4 (only when FAISS_USE_FLOAT16 is defined): widened with
//   half4ToFloat4.
// NOTE(review): Half4 and half4ToFloat4 are presumably declared in
// Float16.cuh (the only project header included besides Index.h) — confirm.
template <>
struct ConvertTo<float4> {
  static inline __device__ float4 to(float4 quad) {
    return quad;
  }

#ifdef FAISS_USE_FLOAT16
  static inline __device__ float4 to(Half4 quad) {
    return half4ToFloat4(quad);
  }
#endif
};
55 
#ifdef FAISS_USE_FLOAT16
// Half-precision conversion targets. They are compiled only when
// FAISS_USE_FLOAT16 is defined. Each specialization accepts either the
// matching float-based type (narrowed) or the half-based type itself
// (returned unchanged).

// Conversion into `half`: narrows a float with the CUDA intrinsic
// __float2half, or passes a half through unchanged.
template <>
struct ConvertTo<half> {
  static inline __device__ half to(half x) {
    return x;
  }

  static inline __device__ half to(float x) {
    return __float2half(x);
  }
};

// Conversion into `half2`: narrows a float2 with __float22half2_rn
// (the `_rn` variant, i.e. round-to-nearest), or passes a half2 through.
template <>
struct ConvertTo<half2> {
  static inline __device__ half2 to(half2 pair) {
    return pair;
  }

  static inline __device__ half2 to(float2 pair) {
    return __float22half2_rn(pair);
  }
};

// Conversion into `Half4`: narrows a float4 with float4ToHalf4, or passes a
// Half4 through.
// NOTE(review): Half4 / float4ToHalf4 are presumably declared in
// Float16.cuh — confirm.
template <>
struct ConvertTo<Half4> {
  static inline __device__ Half4 to(Half4 quad) {
    return quad;
  }

  static inline __device__ Half4 to(float4 quad) {
    return float4ToHalf4(quad);
  }
};
#endif
75 
76 
77 } } // namespace
`idx_t` (a typedef of `long`): all indices are this type.
Defined in Index.h, line 64.