Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
ReductionOperators.cuh
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 
9 #pragma once
10 
11 #include <cuda.h>
12 #include "Limits.cuh"
13 #include "MathOperators.cuh"
14 #include "Pair.cuh"
15 
16 namespace faiss { namespace gpu {
17 
/// Additive reduction functor for device code. Combining uses Math<T>::add,
/// and the reduction's starting (identity) value comes from Math<T>::zero().
template <typename T>
struct Sum {
  /// Starting value for an additive reduction.
  inline __device__ T identity() const {
    return Math<T>::zero();
  }

  /// Combines two operands via Math<T>::add.
  __device__ inline T operator()(T a, T b) const {
    const T total = Math<T>::add(a, b);
    return total;
  }
};
28 
/// Minimum reduction functor for device code. The left operand wins only when
/// Math<T>::lt(a, b) holds; in every other case (including ties) the right
/// operand is returned.
template <typename T>
struct Min {
  /// Starting value for a min-reduction. Limits<T>::getMax() is presumably the
  /// largest representable T -- TODO confirm against Limits.cuh.
  inline __device__ T identity() const {
    return Limits<T>::getMax();
  }

  /// Returns `a` if Math<T>::lt(a, b), otherwise `b`.
  __device__ inline T operator()(T a, T b) const {
    if (Math<T>::lt(a, b)) {
      return a;
    }
    return b;
  }
};
39 
/// Maximum reduction functor for device code. The left operand wins only when
/// Math<T>::gt(a, b) holds; in every other case (including ties) the right
/// operand is returned.
template <typename T>
struct Max {
  /// Starting value for a max-reduction. Limits<T>::getMin() is presumably the
  /// smallest representable T -- TODO confirm against Limits.cuh.
  inline __device__ T identity() const {
    return Limits<T>::getMin();
  }

  /// Returns `a` if Math<T>::gt(a, b), otherwise `b`.
  __device__ inline T operator()(T a, T b) const {
    if (Math<T>::gt(a, b)) {
      return a;
    }
    return b;
  }
};
50 
/// Used for producing segmented prefix scans. Each element is a
/// Pair<T, bool>: the key `k` carries the data being reduced, and the value
/// `v` is true when that element starts a new segment.
///
/// ReduceOp must provide `T operator()(T, T) const` and `T identity() const`
/// (as Sum, Min and Max in this header do).
template <typename T, typename ReduceOp>
struct SegmentedReduce {
  inline __device__ SegmentedReduce(const ReduceOp& o)
      : op(o) {
  }

  /// Combines left operand `a` with right operand `b`. The operation is
  /// asymmetric in the segment flag:
  ///  - if `b.v` is set, `b` starts a new segment, so the result's key is
  ///    b.k and `a.k` is discarded;
  ///  - otherwise the keys are folded with op(a.k, b.k).
  /// The result's flag is set if either operand's flag is set, so a combined
  /// range remembers that it contains a segment start.
  __device__
  inline Pair<T, bool>
  operator()(const Pair<T, bool>& a, const Pair<T, bool>& b) const {
    return Pair<T, bool>(b.v ? b.k : op(a.k, b.k),
                         a.v || b.v);
  }

  /// Identity element: the wrapped op's identity with no segment start.
  inline __device__ Pair<T, bool> identity() const {
    return Pair<T, bool>(op.identity(), false);
  }

  /// Reduction applied to keys within a segment.
  ReduceOp op;
};
72 
73 } } // namespace
A simple pair type for CUDA device usage.
Definition: Pair.cuh:19