Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
ReductionOperators.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the CC-by-NC license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 
11 #pragma once
12 
13 #include <cuda.h>
14 #include "Limits.cuh"
15 #include "MathOperators.cuh"
16 #include "Pair.cuh"
17 
18 namespace faiss { namespace gpu {
19 
/// Device-side binary functor that combines two values with Math<T>::add;
/// its reduction seed is Math<T>::zero().
template <typename T>
struct Sum {
  /// Returns Math<T>::add(lhs, rhs).
  inline __device__ T operator()(T lhs, T rhs) const {
    T combined = Math<T>::add(lhs, rhs);
    return combined;
  }

  /// Starting value for a sum reduction: Math<T>::zero().
  __device__ inline T identity() const {
    return Math<T>::zero();
  }
};
30 
/// Device-side binary functor that keeps the smaller of two values as
/// judged by Math<T>::lt; its reduction seed is Limits<T>::getMax().
template <typename T>
struct Min {
  /// Returns lhs when Math<T>::lt(lhs, rhs) holds, otherwise rhs
  /// (so on a tie the right-hand operand is returned).
  inline __device__ T operator()(T lhs, T rhs) const {
    if (Math<T>::lt(lhs, rhs)) {
      return lhs;
    }
    return rhs;
  }

  /// Starting value for a min reduction: Limits<T>::getMax().
  __device__ inline T identity() const {
    return Limits<T>::getMax();
  }
};
41 
/// Device-side binary functor that keeps the larger of two values as
/// judged by Math<T>::gt; its reduction seed is Limits<T>::getMin().
template <typename T>
struct Max {
  /// Returns lhs when Math<T>::gt(lhs, rhs) holds, otherwise rhs
  /// (so on a tie the right-hand operand is returned).
  inline __device__ T operator()(T lhs, T rhs) const {
    if (Math<T>::gt(lhs, rhs)) {
      return lhs;
    }
    return rhs;
  }

  /// Starting value for a max reduction: Limits<T>::getMin().
  __device__ inline T identity() const {
    return Limits<T>::getMin();
  }
};
52 
/// Binary operator for producing segmented prefix scans over Pair<T, bool>.
/// Pair::k holds the value being reduced with the wrapped ReduceOp, and
/// Pair::v is a flag marking the start of a new segment: when the
/// right-hand operand's flag is set, the running result restarts from that
/// operand's value instead of being combined with what came before it.
template <typename T, typename ReduceOp>
struct SegmentedReduce {
  /// Wraps an unsegmented reduction operator (e.g. Sum/Min/Max) that
  /// provides operator()(T, T) and identity().
  inline __device__ SegmentedReduce(const ReduceOp& o)
      : op(o) {
  }

  /// Combines left operand `a` with right operand `b`.
  /// NOTE(review): assumes `a` precedes `b` in scan order, as the restart
  /// logic below implies.
  /// - If b starts a segment (b.v), the result value is b.k alone.
  /// - Otherwise the result value is op(a.k, b.k).
  /// The result's flag is set if either operand's flag was set.
  __device__
  inline Pair<T, bool>
  operator()(const Pair<T, bool>& a, const Pair<T, bool>& b) const {
    return Pair<T, bool>(b.v ? b.k : op(a.k, b.k),
                         a.v || b.v);
  }

  /// Seed for the segmented scan: the wrapped op's identity value with the
  /// segment-start flag cleared.
  inline __device__ Pair<T, bool> identity() const {
    return Pair<T, bool>(op.identity(), false);
  }

  /// The underlying (unsegmented) reduction operator.
  ReduceOp op;
};
74 
75 } } // namespace
A simple pair type for CUDA device usage.
Definition: Pair.cuh:21