Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
ReductionOperators.cuh
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved.
11 
12 #pragma once
13 
14 #include <cuda.h>
15 #include "Limits.cuh"
16 #include "MathOperators.cuh"
17 #include "Pair.cuh"
18 
19 namespace faiss { namespace gpu {
20 
/// Reduction functor that combines two values by addition.
/// All arithmetic is delegated to Math<T>.
template <typename T>
struct Sum {
  /// Starting value for a reduction: whatever Math<T>::zero() yields.
  __device__ inline T identity() const {
    return Math<T>::zero();
  }

  /// Combines `lhs` and `rhs` via Math<T>::add and returns the result.
  __device__ inline T operator()(T lhs, T rhs) const {
    T total = Math<T>::add(lhs, rhs);
    return total;
  }
};
31 
/// Reduction functor that keeps the smaller of two values.
/// Ordering is decided by Math<T>::lt.
template <typename T>
struct Min {
  /// Starting value for a reduction: Limits<T>::getMax().
  __device__ inline T identity() const {
    return Limits<T>::getMax();
  }

  /// Returns `lhs` only when Math<T>::lt(lhs, rhs) holds; in every other
  /// case (including when the comparison is false for both orderings)
  /// `rhs` is returned.
  __device__ inline T operator()(T lhs, T rhs) const {
    if (Math<T>::lt(lhs, rhs)) {
      return lhs;
    }

    return rhs;
  }
};
42 
/// Reduction functor that keeps the larger of two values.
/// Ordering is decided by Math<T>::gt.
template <typename T>
struct Max {
  /// Starting value for a reduction: Limits<T>::getMin().
  __device__ inline T identity() const {
    return Limits<T>::getMin();
  }

  /// Returns `lhs` only when Math<T>::gt(lhs, rhs) holds; in every other
  /// case (including when the comparison is false for both orderings)
  /// `rhs` is returned.
  __device__ inline T operator()(T lhs, T rhs) const {
    if (Math<T>::gt(lhs, rhs)) {
      return lhs;
    }

    return rhs;
  }
};
53 
/// Used for producing segmented prefix scans. Each element is a
/// Pair<T, bool>: `k` carries the data value and `v` is a flag that, when
/// true, denotes the start of a new segment for the scan.
///
/// `ReduceOp` is the underlying binary operator; it must provide
/// `T operator()(T, T) const` and `T identity() const`, both callable from
/// device code (as Sum, Min and Max in this file do).
template <typename T, typename ReduceOp>
struct SegmentedReduce {
  /// Wraps a copy of the underlying reduction operator `o`.
  inline __device__ SegmentedReduce(const ReduceOp& o)
      : op(o) {
  }

  /// Combines the left operand `a` with the right operand `b`.
  ///
  /// If `b` is flagged as a segment start, the accumulated value from `a`
  /// is discarded and `b.k` is taken as-is; otherwise the two values are
  /// combined with the wrapped operator. The resulting flag is set if
  /// either operand's flag is set.
  __device__
  inline Pair<T, bool>
  operator()(const Pair<T, bool>& a, const Pair<T, bool>& b) const {
    return Pair<T, bool>(b.v ? b.k : op(a.k, b.k),
                         a.v || b.v);
  }

  /// Identity element: the wrapped operator's identity with the
  /// segment-start flag cleared.
  inline __device__ Pair<T, bool> identity() const {
    return Pair<T, bool>(op.identity(), false);
  }

  /// The underlying (non-segmented) reduction operator.
  ReduceOp op;
};
75 
76 } } // namespace
A simple pair type for CUDA device usage.
Definition: Pair.cuh:22