/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the CC-by-NC license found in the
 * LICENSE file in the root directory of this source tree.
 */
|
|
|
|
// Copyright 2004-present Facebook. All Rights Reserved.
|
|
|
|
#pragma once
|
|
|
|
#include <cuda.h>
|
|
#include "MathOperators.cuh"
|
|
#include "WarpShuffles.cuh"
|
|
|
|
namespace faiss { namespace gpu {
|
|
|
|
/// Key/value pair intended for use in CUDA device code.
///
/// Comparisons are delegated to the `Math<K>` / `Math<V>` helpers:
///  - `==` requires both the keys and the values to be equal;
///  - `<` and `>` are lexicographic: the key decides first, and the value is
///    consulted only when the keys compare equal.
template <typename K, typename V>
struct Pair {
  K k;
  V v;

  /// Default constructor; intentionally leaves `k` and `v` uninitialized.
  /// NOTE(review): presumably kept empty so Pair can be declared in
  /// `__shared__` arrays (which forbid dynamic initialization) — confirm.
  constexpr __device__ inline Pair() {
  }

  /// Constructs a pair holding `key` and `value`.
  constexpr __device__ inline Pair(K key, V value)
      : k(key), v(value) {
  }

  /// True when both fields compare equal.
  __device__ inline bool operator==(const Pair& rhs) const {
    if (!Math<K>::eq(k, rhs.k)) {
      return false;
    }
    return Math<V>::eq(v, rhs.v);
  }

  /// Logical negation of operator==.
  __device__ inline bool operator!=(const Pair& rhs) const {
    const bool same = this->operator==(rhs);
    return !same;
  }

  /// Lexicographic less-than: smaller key wins; on equal keys, smaller value.
  __device__ inline bool operator<(const Pair& rhs) const {
    if (Math<K>::lt(k, rhs.k)) {
      return true;
    }
    // Key is not strictly smaller: the value matters only on an exact key tie.
    return Math<K>::eq(k, rhs.k) && Math<V>::lt(v, rhs.v);
  }

  /// Lexicographic greater-than: larger key wins; on equal keys, larger value.
  __device__ inline bool operator>(const Pair& rhs) const {
    if (Math<K>::gt(k, rhs.k)) {
      return true;
    }
    // Key is not strictly larger: the value matters only on an exact key tie.
    return Math<K>::eq(k, rhs.k) && Math<V>::gt(v, rhs.v);
  }
};
|
|
|
|
/// Pair overload of shfl_up: shuffles the key and the value separately with
/// the scalar shfl_up overloads, using the same `delta` and `width` for both,
/// and reassembles the results into a Pair.
/// `width` defaults to kWarpSize.
/// NOTE(review): as a warp shuffle, presumably every participating lane must
/// make this call together — see the scalar overloads for the exact contract.
template <typename T, typename U>
inline __device__ Pair<T, U>
shfl_up(const Pair<T, U>& pair, unsigned int delta, int width = kWarpSize) {
  const T movedKey = shfl_up(pair.k, delta, width);
  const U movedValue = shfl_up(pair.v, delta, width);
  return Pair<T, U>(movedKey, movedValue);
}
|
|
|
|
/// Pair overload of shfl_xor: exchanges the key and the value separately with
/// the scalar shfl_xor overloads, using the same `laneMask` and `width` for
/// both, and reassembles the results into a Pair.
/// `width` defaults to kWarpSize.
/// NOTE(review): as a warp shuffle, presumably every participating lane must
/// make this call together — see the scalar overloads for the exact contract.
template <typename T, typename U>
inline __device__ Pair<T, U>
shfl_xor(const Pair<T, U>& pair, int laneMask, int width = kWarpSize) {
  const T exchangedKey = shfl_xor(pair.k, laneMask, width);
  const U exchangedValue = shfl_xor(pair.v, laneMask, width);
  return Pair<T, U>(exchangedKey, exchangedValue);
}
|
|
|
|
} } // namespace
|