Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
StaticUtils.h
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 
9 #pragma once
10 
11 #include <cuda.h>
12 
namespace faiss { namespace gpu { namespace utils {

/// Integer ceiling division: ceil(a / b) for non-negative a and positive b.
/// Computed as (a + b - 1) / b, so the intermediate `a + b - 1` must be
/// representable in the result type (e.g. an unsigned `a` close to its
/// type's maximum wraps and yields a wrong quotient).
/// The result type is the usual-arithmetic-conversion type of `a + b`.
template <typename U, typename V>
constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
  return (a + b - 1) / b;
}

/// Rounds `a` down to a multiple of `b` (non-negative a, positive b).
template <typename U, typename V>
constexpr __host__ __device__ auto roundDown(U a, V b) -> decltype(a + b) {
  return (a / b) * b;
}

/// Rounds `a` up to the next multiple of `b` (non-negative a, positive b).
/// Shares divUp's requirement that `a + b - 1` does not overflow.
template <typename U, typename V>
constexpr __host__ __device__ auto roundUp(U a, V b) -> decltype(a + b) {
  return divUp(a, b) * b;
}

/// Integer power n^power by repeated multiplication; returns 1 when
/// power <= 0. Written as a single-return recursive expression so it is a
/// valid C++11 constexpr function. Overflow is not checked.
template <class T>
constexpr __host__ __device__ T pow(T n, T power) {
  // Cast the decremented exponent back to T: for small integer types
  // `power - 1` promotes to int, which would conflict with `n` (type T)
  // during deduction of pow's single template parameter.
  return (power > 0 ? n * pow(n, (T) (power - 1)) : 1);
}

/// 2^n, computed in type T.
template <class T>
constexpr __host__ __device__ T pow2(T n) {
  // The base must have type T as well: a bare `2` (int) together with an
  // exponent of another type makes deduction of pow's T fail, so pow2
  // would only compile for T == int.
  return pow((T) 2, n);
}

static_assert(pow2(8) == 256, "pow2");
static_assert(pow2(10u) == 1024u, "pow2");
static_assert(pow2(40LL) == (1LL << 40), "pow2");

static_assert(divUp(7, 2) == 4, "divUp");
static_assert(roundDown(7, 4) == 4, "roundDown");
static_assert(roundUp(7, 4) == 8, "roundUp");

/// floor(log2(n)) for n >= 1, found by repeated halving.
/// Returns `p` when n <= 1, so log2(0) == log2(1) == 0 with the default.
/// `p` is the recursion's accumulator; callers normally omit it.
template <typename T>
constexpr __host__ __device__ int log2(T n, int p = 0) {
  return (n <= 1) ? p : log2(n / 2, p + 1);
}

static_assert(log2(2) == 1, "log2");
static_assert(log2(3) == 1, "log2");
static_assert(log2(4) == 2, "log2");

/// True iff `v` is a positive power of 2 (0 is not). Intended for unsigned
/// or non-negative values: for the minimum of a signed type, `v - 1`
/// overflows.
template <typename T>
constexpr __host__ __device__ bool isPowerOf2(T v) {
  return (v && !(v & (v - 1)));
}

static_assert(isPowerOf2(2048), "isPowerOf2");
static_assert(!isPowerOf2(3333), "isPowerOf2");

/// Smallest power of 2 strictly greater than `v`: 2 * v when `v` is already
/// a power of 2, otherwise 1 << (floor(log2(v)) + 1). For v == 0 this yields
/// 2, because log2(0) == 0. The result must fit in T: for example, an
/// unsigned int v >= 2^31 makes the multiply wrap or the shift exceed the
/// type width.
template <typename T>
constexpr __host__ __device__ T nextHighestPowerOf2(T v) {
  return (isPowerOf2(v) ? (T) 2 * v : ((T) 1 << (log2(v) + 1)));
}

static_assert(nextHighestPowerOf2(1) == 2, "nextHighestPowerOf2");
static_assert(nextHighestPowerOf2(2) == 4, "nextHighestPowerOf2");
static_assert(nextHighestPowerOf2(3) == 4, "nextHighestPowerOf2");
static_assert(nextHighestPowerOf2(4) == 8, "nextHighestPowerOf2");

static_assert(nextHighestPowerOf2(15) == 16, "nextHighestPowerOf2");
static_assert(nextHighestPowerOf2(16) == 32, "nextHighestPowerOf2");
static_assert(nextHighestPowerOf2(17) == 32, "nextHighestPowerOf2");

static_assert(nextHighestPowerOf2(1536000000u) == 2147483648u,
              "nextHighestPowerOf2");
static_assert(nextHighestPowerOf2((size_t) 2147483648ULL) ==
              (size_t) 4294967296ULL, "nextHighestPowerOf2");

} } } // namespace