/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include <cuda.h>
|
2019-09-21 00:59:10 +08:00
|
|
|
#include <faiss/gpu/utils/Float16.cuh>
|
2017-02-23 06:26:44 +08:00
|
|
|
|
|
|
|
namespace faiss { namespace gpu {
|
|
|
|
|
|
|
|
/// Generic device-side comparator.
///
/// Both predicates forward to T's own relational operators, so T must
/// provide `operator<` and `operator>` that are callable from __device__
/// code. Arguments are taken by value.
template <typename T>
struct Comparator {
  /// True iff `lhs < rhs` according to T's operator<.
  __device__ static inline bool lt(T lhs, T rhs) {
    const bool isLess = lhs < rhs;
    return isLess;
  }

  /// True iff `lhs > rhs` according to T's operator>.
  __device__ static inline bool gt(T lhs, T rhs) {
    const bool isGreater = lhs > rhs;
    return isGreater;
  }
};
|
|
|
|
|
|
|
|
/// Comparator specialization for CUDA `half` values.
///
/// With FAISS_USE_FULL_FLOAT16 nonzero, the predicates use the native
/// half-precision comparison intrinsics __hlt / __hgt. Otherwise each operand
/// is widened to float with __half2float and compared in single precision.
/// NOTE(review): FAISS_USE_FULL_FLOAT16 presumably comes from Float16.cuh and
/// is enabled only for targets with native half arithmetic — confirm there.
template <>
struct Comparator<half> {
#if FAISS_USE_FULL_FLOAT16
  // Native half-precision path.
  __device__ static inline bool lt(half lhs, half rhs) {
    return __hlt(lhs, rhs);
  }

  __device__ static inline bool gt(half lhs, half rhs) {
    return __hgt(lhs, rhs);
  }
#else
  // Fallback path: every half value is exactly representable as a float,
  // so comparing the widened values yields the same ordering.
  __device__ static inline bool lt(half lhs, half rhs) {
    const float l = __half2float(lhs);
    const float r = __half2float(rhs);
    return l < r;
  }

  __device__ static inline bool gt(half lhs, half rhs) {
    const float l = __half2float(lhs);
    const float r = __half2float(rhs);
    return l > r;
  }
#endif // FAISS_USE_FULL_FLOAT16
};
|
|
|
|
|
|
|
|
} } // namespace
|