Faiss GPU: bfloat16 brute-force kNN support (#4018)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/4018

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/4014

This diff adds support for bfloat16 vector/query data types with the GPU brute-force k-nearest neighbor function (`bfKnn`).

The change is largely just plumbing the new data type through the template hierarchy (so distances can be computed in bfloat16).

Of note, by design, all final distance results are produced in float32 regardless of the input data type (float32, float16, bfloat16). This is because the true nearest neighbors in many data sets can differ by only ~1000 float32 ULPs in terms of distance, so computing or reducing the final distances at lower precision can make distinct neighbors appear equidistant (false ties). This seems to be one area where lossy compression/quantization throughout the pipeline does not work as well (and is also why `CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION` is set in `StandardGpuResources.cpp`). However, given that Ampere+ architectures have native bf16 x bf16 = fp32 tensor core support, the matrix multiplication itself should still use the tensor cores.

As bfloat16 support is quite lacking on AMD/ROCm (see [here](https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Device_API_supported_by_HIP.html); very few bf16 device functions are implemented), bf16 functionality is completely disabled / not compiled for AMD ROCm.
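For reference, a minimal sketch of how the new path could be exercised through the public C++ API (the helper name, pointer names and the L2 metric choice are illustrative, not part of this diff; the `GpuDistanceParams` fields mirror the test added below):

```cpp
#include <faiss/gpu/GpuDistance.h>
#include <faiss/gpu/StandardGpuResources.h>

// Sketch: xb / xq are bfloat16 arrays of shape [nb, d] and [nq, d]
// (host or device), outDist / outIdx hold nq * k results.
void knnBF16Example(
        const void* xb, size_t nb,
        const void* xq, size_t nq,
        int d, int k,
        float* outDist, faiss::idx_t* outIdx) {
    faiss::gpu::StandardGpuResources res;

    // bf16 needs an Ampere-or-newer GPU; bfKnn throws for BF16 inputs otherwise
    if (!res.supportsBFloat16CurrentDevice()) {
        return;
    }

    faiss::gpu::GpuDistanceParams args;
    args.metric = faiss::METRIC_L2;
    args.k = k;
    args.dims = d;
    args.vectors = xb;
    args.vectorType = faiss::gpu::DistanceDataType::BF16;
    args.vectorsRowMajor = true;
    args.numVectors = nb;
    args.queries = xq;
    args.queryType = faiss::gpu::DistanceDataType::BF16;
    args.queriesRowMajor = true;
    args.numQueries = nq;
    args.outDistances = outDist; // always float32, regardless of input type
    args.outIndices = outIdx;

    faiss::gpu::bfKnn(&res, args);
}
```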

Reviewed By: mdouze

Differential Revision: D65459723

fbshipit-source-id: 8a6aec843f7e37c205d95f2485442a26c402a3b0
Jeff Johnson 2024-11-19 19:37:03 -08:00 committed by Facebook GitHub Bot
parent 844e3ce769
commit eaab46c870
20 changed files with 879 additions and 59 deletions


@ -56,6 +56,13 @@ def swig_ptr_from_FloatTensor(x):
return faiss.cast_integer_to_float_ptr(
x.untyped_storage().data_ptr() + x.storage_offset() * 4)
def swig_ptr_from_BFloat16Tensor(x):
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
assert x.is_contiguous()
assert x.dtype == torch.bfloat16
return faiss.cast_integer_to_void_ptr(
x.untyped_storage().data_ptr() + x.storage_offset() * 2)
def swig_ptr_from_IntTensor(x):
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
@ -606,8 +613,11 @@ def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRI
elif xb.dtype == torch.float16:
xb_type = faiss.DistanceDataType_F16
xb_ptr = swig_ptr_from_HalfTensor(xb)
elif xb.dtype == torch.bfloat16:
xb_type = faiss.DistanceDataType_BF16
xb_ptr = swig_ptr_from_BFloat16Tensor(xb)
else:
raise TypeError('xb must be f32 or f16')
raise TypeError('xb must be float32, float16 or bfloat16')
nq, d2 = xq.size()
assert d2 == d
@ -625,8 +635,11 @@ def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRI
elif xq.dtype == torch.float16:
xq_type = faiss.DistanceDataType_F16
xq_ptr = swig_ptr_from_HalfTensor(xq)
elif xq.dtype == torch.bfloat16:
xq_type = faiss.DistanceDataType_BF16
xq_ptr = swig_ptr_from_BFloat16Tensor(xq)
else:
raise TypeError('xq must be f32 or f16')
raise TypeError('xq must be float32, float16 or bfloat16')
if D is None:
D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)


@ -30,6 +30,7 @@
#include <faiss/gpu/utils/ConversionOperators.cuh>
#include <faiss/gpu/utils/CopyUtils.cuh>
#include <faiss/gpu/utils/DeviceTensor.cuh>
#include <faiss/gpu/utils/Float16.cuh>
#include <optional>
#if defined USE_NVIDIA_CUVS
@ -231,7 +232,7 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
FAISS_THROW_IF_NOT_MSG(
args.vectorType == args.queryType,
"limitation: both vectorType and queryType must currently "
"be the same (F32 or F16");
"be the same (F32 / F16 / BF16");
#if defined USE_NVIDIA_CUVS
// Note: For now, cuVS bfknn requires queries and vectors to be same layout
@ -400,6 +401,17 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
bfKnnConvert<float>(prov, args);
} else if (args.vectorType == DistanceDataType::F16) {
bfKnnConvert<half>(prov, args);
} else if (args.vectorType == DistanceDataType::BF16) {
// no bf16 support for AMD
#ifndef USE_AMD_ROCM
if (prov->getResources()->supportsBFloat16CurrentDevice()) {
bfKnnConvert<__nv_bfloat16>(prov, args);
} else {
FAISS_THROW_MSG("not compiled with bfloat16 support");
}
#else
FAISS_THROW_MSG("no AMD bfloat16 support");
#endif
} else {
FAISS_THROW_MSG("unknown vectorType");
}
@ -466,8 +478,10 @@ void bfKnn_single_query_shard(
args.k > 0,
"bfKnn_tiling: tiling vectors is only supported for k > 0");
size_t distance_size = args.vectorType == DistanceDataType::F32 ? 4
: args.vectorType == DistanceDataType::F16 ? 2
: 0;
: (args.vectorType == DistanceDataType::F16 ||
args.vectorType == DistanceDataType::BF16)
? 2
: 0;
FAISS_THROW_IF_NOT_MSG(
distance_size > 0, "bfKnn_tiling: unknown vectorType");
size_t shard_size = vectorsMemoryLimit / (args.dims * distance_size);
@ -524,8 +538,10 @@ void bfKnn_tiling(
args.k > 0,
"bfKnn_tiling: tiling queries is only supported for k > 0");
size_t distance_size = args.queryType == DistanceDataType::F32 ? 4
: args.queryType == DistanceDataType::F16 ? 2
: 0;
: (args.queryType == DistanceDataType::F16 ||
args.queryType == DistanceDataType::BF16)
? 2
: 0;
FAISS_THROW_IF_NOT_MSG(
distance_size > 0, "bfKnn_tiling: unknown queryType");
size_t label_size = args.outIndicesType == IndicesDataType::I64 ? 8


@ -19,6 +19,7 @@ class GpuResourcesProvider;
enum class DistanceDataType {
F32 = 1,
F16,
BF16,
};
// Scalar type of the indices data


@ -161,6 +161,10 @@ GpuMemoryReservation::~GpuMemoryReservation() {
GpuResources::~GpuResources() = default;
bool GpuResources::supportsBFloat16CurrentDevice() {
return supportsBFloat16(getCurrentDevice());
}
cublasHandle_t GpuResources::getBlasHandleCurrentDevice() {
return getBlasHandle(getCurrentDevice());
}


@ -205,6 +205,9 @@ class GpuResources {
/// of demand
virtual void initializeForDevice(int device) = 0;
/// Does the given GPU support bfloat16?
virtual bool supportsBFloat16(int device) = 0;
/// Returns the cuBLAS handle that we use for the given device
virtual cublasHandle_t getBlasHandle(int device) = 0;
@ -252,6 +255,9 @@ class GpuResources {
/// Functions provided by default
///
/// Does the current GPU support bfloat16?
bool supportsBFloat16CurrentDevice();
/// Calls getBlasHandle with the current device
cublasHandle_t getBlasHandleCurrentDevice();


@ -206,6 +206,13 @@ size_t StandardGpuResourcesImpl::getDefaultTempMemForGPU(
return requested;
}
/// Does the given GPU support bfloat16?
bool StandardGpuResourcesImpl::supportsBFloat16(int device) {
initializeForDevice(device);
auto& prop = getDeviceProperties(device);
return prop.major >= 8;
}
void StandardGpuResourcesImpl::noTempMemory() {
setTempMemory(0);
}
@ -701,6 +708,14 @@ std::shared_ptr<GpuResources> StandardGpuResources::getResources() {
return res_;
}
bool StandardGpuResources::supportsBFloat16(int device) {
return res_->supportsBFloat16(device);
}
bool StandardGpuResources::supportsBFloat16CurrentDevice() {
return res_->supportsBFloat16CurrentDevice();
}
void StandardGpuResources::noTempMemory() {
res_->noTempMemory();
}


@ -48,6 +48,9 @@ class StandardGpuResourcesImpl : public GpuResources {
~StandardGpuResourcesImpl() override;
/// Does the given GPU support bfloat16?
bool supportsBFloat16(int device) override;
/// Disable allocation of temporary memory; all temporary memory
/// requests will call cudaMalloc / cudaFree at the point of use
void noTempMemory();
@ -199,6 +202,12 @@ class StandardGpuResources : public GpuResourcesProvider {
std::shared_ptr<GpuResources> getResources() override;
/// Whether or not the given device supports native bfloat16 arithmetic
bool supportsBFloat16(int device);
/// Whether or not the current device supports native bfloat16 arithmetic
bool supportsBFloat16CurrentDevice();
/// Disable allocation of temporary memory; all temporary memory
/// requests will call cudaMalloc / cudaFree at the point of use
void noTempMemory();
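Since `supportsBFloat16` is now part of the public `StandardGpuResources` surface, a caller can gate its choice of storage type on it; a minimal sketch, assuming a float16 fallback policy that this diff does not itself prescribe:

```cpp
#include <faiss/gpu/GpuDistance.h>
#include <faiss/gpu/StandardGpuResources.h>

// Choose a reduced-precision storage type for bfKnn on `device`:
// BF16 where the device supports it (compute capability 8.0+),
// otherwise fall back to F16 (the fallback policy is illustrative).
faiss::gpu::DistanceDataType pickStorageType(
        faiss::gpu::StandardGpuResources& res,
        int device) {
    return res.supportsBFloat16(device)
            ? faiss::gpu::DistanceDataType::BF16
            : faiss::gpu::DistanceDataType::F16;
}
```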


@ -504,6 +504,30 @@ void runAllPairwiseL2Distance(
outDistances);
}
// no bf16 support for AMD
#ifndef USE_AMD_ROCM
void runAllPairwiseL2Distance(
GpuResources* res,
cudaStream_t stream,
Tensor<__nv_bfloat16, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<float, 1, true>* vectorNorms,
Tensor<__nv_bfloat16, 2, true>& queries,
bool queriesRowMajor,
Tensor<float, 2, true>& outDistances) {
runAllPairwiseDistance<__nv_bfloat16>(
true,
res,
stream,
vectors,
vectorsRowMajor,
vectorNorms,
queries,
queriesRowMajor,
outDistances);
}
#endif // USE_AMD_ROCM
void runAllPairwiseIPDistance(
GpuResources* res,
cudaStream_t stream,
@ -544,6 +568,29 @@ void runAllPairwiseIPDistance(
outDistances);
}
// no bf16 support for AMD
#ifndef USE_AMD_ROCM
void runAllPairwiseIPDistance(
GpuResources* res,
cudaStream_t stream,
Tensor<__nv_bfloat16, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<__nv_bfloat16, 2, true>& queries,
bool queriesRowMajor,
Tensor<float, 2, true>& outDistances) {
runAllPairwiseDistance<__nv_bfloat16>(
false,
res,
stream,
vectors,
vectorsRowMajor,
nullptr,
queries,
queriesRowMajor,
outDistances);
}
#endif // USE_AMD_ROCM
void runL2Distance(
GpuResources* res,
cudaStream_t stream,
@ -596,6 +643,35 @@ void runL2Distance(
ignoreOutDistances);
}
// no bf16 support for AMD
#ifndef USE_AMD_ROCM
void runL2Distance(
GpuResources* res,
cudaStream_t stream,
Tensor<__nv_bfloat16, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<float, 1, true>* vectorNorms,
Tensor<__nv_bfloat16, 2, true>& queries,
bool queriesRowMajor,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<idx_t, 2, true>& outIndices,
bool ignoreOutDistances) {
runL2Distance<__nv_bfloat16>(
res,
stream,
vectors,
vectorsRowMajor,
vectorNorms,
queries,
queriesRowMajor,
k,
outDistances,
outIndices,
ignoreOutDistances);
}
#endif // USE_AMD_ROCM
void runIPDistance(
GpuResources* res,
cudaStream_t stream,
@ -640,5 +716,30 @@ void runIPDistance(
outIndices);
}
// no bf16 support for AMD
#ifndef USE_AMD_ROCM
void runIPDistance(
GpuResources* res,
cudaStream_t stream,
Tensor<__nv_bfloat16, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<__nv_bfloat16, 2, true>& queries,
bool queriesRowMajor,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<idx_t, 2, true>& outIndices) {
runIPDistance<__nv_bfloat16>(
res,
stream,
vectors,
vectorsRowMajor,
queries,
queriesRowMajor,
k,
outDistances,
outIndices);
}
#endif // USE_AMD_ROCM
} // namespace gpu
} // namespace faiss


@ -41,6 +41,19 @@ void runAllPairwiseL2Distance(
bool queriesRowMajor,
Tensor<float, 2, true>& outDistances);
// no bf16 support for AMD
#ifndef USE_AMD_ROCM
void runAllPairwiseL2Distance(
GpuResources* res,
cudaStream_t stream,
Tensor<__nv_bfloat16, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<float, 1, true>* vectorNorms,
Tensor<__nv_bfloat16, 2, true>& queries,
bool queriesRowMajor,
Tensor<float, 2, true>& outDistances);
#endif // USE_AMD_ROCM
void runAllPairwiseIPDistance(
GpuResources* res,
cudaStream_t stream,
@ -59,6 +72,18 @@ void runAllPairwiseIPDistance(
bool queriesRowMajor,
Tensor<float, 2, true>& outDistances);
// no bf16 support for AMD
#ifndef USE_AMD_ROCM
void runAllPairwiseIPDistance(
GpuResources* res,
cudaStream_t stream,
Tensor<__nv_bfloat16, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<__nv_bfloat16, 2, true>& queries,
bool queriesRowMajor,
Tensor<float, 2, true>& outDistances);
#endif // USE_AMD_ROCM
/// Calculates brute-force L2 distance between `vectors` and
/// `queries`, returning the k closest results seen
void runL2Distance(
@ -91,6 +116,22 @@ void runL2Distance(
Tensor<idx_t, 2, true>& outIndices,
bool ignoreOutDistances = false);
// no bf16 support for AMD
#ifndef USE_AMD_ROCM
void runL2Distance(
GpuResources* resources,
cudaStream_t stream,
Tensor<__nv_bfloat16, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<float, 1, true>* vectorNorms,
Tensor<__nv_bfloat16, 2, true>& queries,
bool queriesRowMajor,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<idx_t, 2, true>& outIndices,
bool ignoreOutDistances = false);
#endif // USE_AMD_ROCM
/// Calculates brute-force inner product distance between `vectors`
/// and `queries`, returning the k closest results seen
void runIPDistance(
@ -115,6 +156,20 @@ void runIPDistance(
Tensor<float, 2, true>& outDistances,
Tensor<idx_t, 2, true>& outIndices);
// no bf16 support for AMD
#ifndef USE_AMD_ROCM
void runIPDistance(
GpuResources* resources,
cudaStream_t stream,
Tensor<__nv_bfloat16, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<__nv_bfloat16, 2, true>& queries,
bool queriesRowMajor,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<idx_t, 2, true>& outIndices);
#endif // USE_AMD_ROCM
//
// General distance implementation, assumes that all arguments are on the
// device. This is the top-level internal distance function to call to dispatch


@ -151,10 +151,10 @@ __launch_bounds__(TILE_SIZE* TILE_SIZE) __global__ void generalDistance(
bool kInBounds = k < query.getSize(1);
queryTileBase[threadIdx.x + i * TILE_SIZE] =
kInBounds ? queryBase[k] : ConvertTo<T>::to(0);
kInBounds ? queryBase[k] : ConvertTo<T>::to(0.0f);
vecTileBase[threadIdx.x + i * TILE_SIZE] =
kInBounds ? vecBase[k] : ConvertTo<T>::to(0);
kInBounds ? vecBase[k] : ConvertTo<T>::to(0.0f);
}
__syncthreads();
@ -185,10 +185,10 @@ __launch_bounds__(TILE_SIZE* TILE_SIZE) __global__ void generalDistance(
for (idx_t k = threadIdx.x; k < limit; k += TILE_SIZE) {
// Load query tile
queryTileBase[threadIdx.x] =
queryThreadInBounds ? queryBase[k] : ConvertTo<T>::to(0);
queryThreadInBounds ? queryBase[k] : ConvertTo<T>::to(0.0f);
vecTileBase[threadIdx.x] =
vecThreadInBoundsLoad ? vecBase[k] : ConvertTo<T>::to(0);
vecThreadInBoundsLoad ? vecBase[k] : ConvertTo<T>::to(0.0f);
__syncthreads();
@ -211,11 +211,11 @@ __launch_bounds__(TILE_SIZE* TILE_SIZE) __global__ void generalDistance(
// Load query tile
queryTileBase[threadIdx.x] = queryThreadInBounds && kInBounds
? queryBase[k]
: ConvertTo<T>::to(0);
: ConvertTo<T>::to(0.0f);
vecTileBase[threadIdx.x] = vecThreadInBoundsLoad && kInBounds
? vecBase[k]
: ConvertTo<T>::to(0);
: ConvertTo<T>::to(0.0f);
__syncthreads();


@ -154,7 +154,7 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1> {
inline __device__ void decode(void* data, idx_t vec, int d, float* out)
const {
half* p = (half*)&((uint8_t*)data)[vec * bytesPerVec];
out[0] = Convert<half, float>()(p[d]);
out[0] = ConvertTo<float>::to(p[d]);
}
inline __device__ float decodePartial(
@ -172,7 +172,7 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1> {
int d,
float v[kDimPerIter]) const {
half* p = (half*)&((uint8_t*)data)[vec * bytesPerVec];
p[d] = Convert<float, half>()(v[0]);
p[d] = ConvertTo<half>::to(v[0]);
}
inline __device__ void encodePartial(
@ -191,11 +191,11 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1> {
static constexpr int kEncodeBits = 16;
inline __device__ EncodeT encodeNew(int dim, float v) const {
return Convert<float, half>()(v);
return ConvertTo<half>::to(v);
}
inline __device__ float decodeNew(int dim, EncodeT v) const {
return Convert<half, float>()(v);
return ConvertTo<float>::to(v);
}
int bytesPerVec;


@ -11,7 +11,6 @@
#include <faiss/gpu/impl/L2Norm.cuh>
#include <faiss/gpu/utils/ConversionOperators.cuh>
#include <faiss/gpu/utils/DeviceDefs.cuh>
#include <faiss/gpu/utils/Float16.cuh>
#include <faiss/gpu/utils/MathOperators.cuh>
#include <faiss/gpu/utils/PtxUtils.cuh>
#include <faiss/gpu/utils/Reductions.cuh>
@ -276,5 +275,18 @@ void runL2Norm(
runL2Norm<half, half2>(input, inputRowMajor, output, normSquared, stream);
}
// no bf16 support for AMD
#ifndef USE_AMD_ROCM
void runL2Norm(
Tensor<__nv_bfloat16, 2, true>& input,
bool inputRowMajor,
Tensor<float, 1, true>& output,
bool normSquared,
cudaStream_t stream) {
runL2Norm<__nv_bfloat16, __nv_bfloat162>(
input, inputRowMajor, output, normSquared, stream);
}
#endif
} // namespace gpu
} // namespace faiss


@ -7,7 +7,7 @@
#pragma once
#include <cuda_fp16.h>
#include <faiss/gpu/utils/Float16.cuh>
#include <faiss/gpu/utils/Tensor.cuh>
namespace faiss {
@ -27,5 +27,15 @@ void runL2Norm(
bool normSquared,
cudaStream_t stream);
// no bf16 support for AMD
#ifndef USE_AMD_ROCM
void runL2Norm(
Tensor<__nv_bfloat16, 2, true>& input,
bool inputRowMajor,
Tensor<float, 1, true>& output,
bool normSquared,
cudaStream_t stream);
#endif
} // namespace gpu
} // namespace faiss


@ -114,10 +114,8 @@ __global__ void gatherReconstructByIds(
auto vec = vecs[id];
auto outVec = out[blockIdx.x];
Convert<T, float> conv;
for (idx_t i = threadIdx.x; i < vecs.getSize(1); i += blockDim.x) {
outVec[i] = id == idx_t(-1) ? 0.0f : conv(vec[i]);
outVec[i] = id == idx_t(-1) ? 0.0f : ConvertTo<float>::to(vec[i]);
}
}
@ -131,10 +129,8 @@ __global__ void gatherReconstructByRange(
auto vec = vecs[id];
auto outVec = out[blockIdx.x];
Convert<T, float> conv;
for (idx_t i = threadIdx.x; i < vecs.getSize(1); i += blockDim.x) {
outVec[i] = id == idx_t(-1) ? 0.0f : conv(vec[i]);
outVec[i] = id == idx_t(-1) ? 0.0f : ConvertTo<float>::to(vec[i]);
}
}


@ -32,6 +32,13 @@
#include <sstream>
#include <vector>
enum class TestThresholds {
Normal,
BF16,
// Linf has worse error than the other metrics for bf16
BF16_Linf,
};
void evaluate_bfknn(
faiss::gpu::GpuDistanceParams& args,
faiss::gpu::GpuResourcesProvider* res,
@ -43,16 +50,39 @@ void evaluate_bfknn(
int k,
bool colMajorVecs,
bool colMajorQueries,
faiss::MetricType metric) {
faiss::MetricType metric,
TestThresholds thresh = TestThresholds::Normal) {
using namespace faiss::gpu;
bfKnn(res, args);
std::stringstream str;
str << "using cuVS " << args.use_cuvs << "metric " << metric
str << "using cuVS " << args.use_cuvs << " metric " << metric
<< " colMajorVecs " << colMajorVecs << " colMajorQueries "
<< colMajorQueries;
float maxRelativeError;
float pctMaxDiff1;
float pctMaxDiffN;
switch (thresh) {
case TestThresholds::Normal:
maxRelativeError = 6e-3f;
pctMaxDiff1 = 0.1f;
pctMaxDiffN = 0.015f;
break;
case TestThresholds::BF16:
maxRelativeError = 1.5e-2f;
pctMaxDiff1 = 0.3f;
pctMaxDiffN = 0.1f;
break;
case TestThresholds::BF16_Linf:
maxRelativeError = 1.5e-2f;
pctMaxDiff1 = 0.53f;
pctMaxDiffN = 0.2f;
break;
}
compareLists(
cpuDistance.data(),
cpuIndices.data(),
@ -64,9 +94,9 @@ void evaluate_bfknn(
false,
false,
true,
6e-3f,
0.1f,
0.015f);
maxRelativeError,
pctMaxDiff1,
pctMaxDiffN);
}
void testTransposition(
@ -82,6 +112,10 @@ void testTransposition(
StandardGpuResources res;
res.noTempMemory();
// The transpose and distance code assumes the desired device is already set
DeviceScope scope(device);
auto stream = res.getDefaultStream(device);
int dim = randVal(20, 150);
int numVecs = randVal(10, 30000);
int numQuery = randVal(1, 1024);
@ -120,10 +154,6 @@ void testTransposition(
cpuIndex.search(
numQuery, queries.data(), k, cpuDistance.data(), cpuIndices.data());
// The transpose and distance code assumes the desired device is already set
DeviceScope scope(device);
auto stream = res.getDefaultStream(device);
// Copy input data to GPU, and pre-transpose both vectors and queries for
// passing
auto gpuVecs = toDeviceNonTemporary<float, 2>(
@ -191,12 +221,161 @@ void testTransposition(
metric);
}
void testTransposition_bf16(
bool colMajorVecs,
bool colMajorQueries,
faiss::MetricType metric,
bool use_raft = false,
float metricArg = 0) {
using namespace faiss::gpu;
#ifdef USE_AMD_ROCM
std::cout << "skipping bfloat16 test (no bfloat16 support on AMD)\n";
EXPECT_TRUE(true);
return;
#else
int device = randVal(0, getNumDevices() - 1);
StandardGpuResources res;
if (!res.supportsBFloat16(device)) {
std::cout << "skipping bfloat16 test (no bfloat16 support on device)\n";
return;
}
res.noTempMemory();
// The transpose and distance code assumes the desired device is already set
DeviceScope scope(device);
auto stream = res.getDefaultStream(device);
int dim = randVal(20, 150);
int numVecs = randVal(10, 30000);
int numQuery = randVal(1, 1024);
int k = std::min(numVecs, randVal(20, 70));
// Input data for CPU
std::vector<float> vecs = randVecs(numVecs, dim);
std::vector<float> queries = randVecs(numQuery, dim);
if ((metric == faiss::MetricType::METRIC_JensenShannon) ||
(metric == faiss::MetricType::METRIC_Jaccard)) {
// make values positive
for (auto& v : vecs) {
v = std::abs(v);
if (v == 0) {
v = 1e-6;
}
}
for (auto& q : queries) {
q = std::abs(q);
if (q == 0) {
q = 1e-6;
}
}
}
// The CPU index is our reference for the results
faiss::IndexFlat cpuIndex(dim, metric);
cpuIndex.metric_arg = metricArg;
cpuIndex.add(numVecs, vecs.data());
std::vector<float> cpuDistance(numQuery * k, 0);
std::vector<faiss::idx_t> cpuIndices(numQuery * k, -1);
cpuIndex.search(
numQuery, queries.data(), k, cpuDistance.data(), cpuIndices.data());
// Convert float32 data to bfloat16 via truncation not rounding
// (just copy high 2 bytes)
std::vector<uint16_t> bf16_vecs(vecs.size());
std::vector<uint16_t> bf16_queries(queries.size());
auto fn_f32_bf16 = [](float v) {
uint32_t vi;
std::memcpy(&vi, &v, sizeof(uint32_t));
return uint16_t(vi >> 16);
};
std::transform(vecs.begin(), vecs.end(), bf16_vecs.begin(), fn_f32_bf16);
std::transform(
queries.begin(), queries.end(), bf16_queries.begin(), fn_f32_bf16);
// Copy input data to GPU, and pre-transpose both vectors and queries for
// passing. Just use uint16_t in lieu of __nv_bfloat16
auto gpuVecs = toDeviceNonTemporary<uint16_t, 2>(
res.getResources().get(),
device,
bf16_vecs.data(),
stream,
{numVecs, dim});
auto gpuQueries = toDeviceNonTemporary<uint16_t, 2>(
res.getResources().get(),
device,
bf16_queries.data(),
stream,
{numQuery, dim});
DeviceTensor<uint16_t, 2, true> vecsT(
res.getResources().get(),
makeDevAlloc(AllocType::Other, stream),
{dim, numVecs});
runTransposeAny(gpuVecs, 0, 1, vecsT, stream);
DeviceTensor<uint16_t, 2, true> queriesT(
res.getResources().get(),
makeDevAlloc(AllocType::Other, stream),
{dim, numQuery});
runTransposeAny(gpuQueries, 0, 1, queriesT, stream);
std::vector<float> gpuDistance(numQuery * k, 0);
std::vector<faiss::idx_t> gpuIndices(numQuery * k, -1);
GpuDistanceParams args;
args.metric = metric;
args.metricArg = metricArg;
args.k = k;
args.dims = dim;
args.vectors = colMajorVecs ? vecsT.data() : gpuVecs.data();
args.vectorType = DistanceDataType::BF16;
args.vectorsRowMajor = !colMajorVecs;
args.numVectors = numVecs;
args.queries = colMajorQueries ? queriesT.data() : gpuQueries.data();
args.queryType = DistanceDataType::BF16;
args.queriesRowMajor = !colMajorQueries;
args.numQueries = numQuery;
args.outDistances = gpuDistance.data();
args.outIndices = gpuIndices.data();
args.device = device;
evaluate_bfknn(
args,
&res,
cpuDistance,
cpuIndices,
gpuDistance,
gpuIndices,
numQuery,
k,
colMajorVecs,
colMajorQueries,
metric,
metric == faiss::MetricType::METRIC_Linf ? TestThresholds::BF16_Linf
: TestThresholds::BF16);
#endif
}
// Test different memory layouts for brute-force k-NN
TEST(TestGpuDistance, Transposition_RR) {
testTransposition(false, false, faiss::MetricType::METRIC_L2);
testTransposition(false, false, faiss::MetricType::METRIC_INNER_PRODUCT);
}
TEST(TestGpuDistance, Transposition_RR_BF16) {
testTransposition_bf16(false, false, faiss::MetricType::METRIC_L2);
testTransposition_bf16(
false, false, faiss::MetricType::METRIC_INNER_PRODUCT);
}
#if defined USE_NVIDIA_CUVS
TEST(TestCuvsGpuDistance, Transposition_RR) {
testTransposition(false, false, faiss::MetricType::METRIC_L2, true);
@ -209,6 +388,10 @@ TEST(TestGpuDistance, Transposition_RC) {
testTransposition(false, true, faiss::MetricType::METRIC_L2);
}
TEST(TestGpuDistance, Transposition_RC_BF16) {
testTransposition_bf16(false, true, faiss::MetricType::METRIC_L2);
}
#if defined USE_NVIDIA_CUVS
TEST(TestCuvsGpuDistance, Transposition_RC) {
testTransposition(false, true, faiss::MetricType::METRIC_L2, true);
@ -219,6 +402,10 @@ TEST(TestGpuDistance, Transposition_CR) {
testTransposition(true, false, faiss::MetricType::METRIC_L2);
}
TEST(TestGpuDistance, Transposition_CR_BF16) {
testTransposition_bf16(true, false, faiss::MetricType::METRIC_L2);
}
#if defined USE_NVIDIA_CUVS
TEST(TestCuvsGpuDistance, Transposition_CR) {
testTransposition(true, false, faiss::MetricType::METRIC_L2, true);
@ -229,6 +416,10 @@ TEST(TestGpuDistance, Transposition_CC) {
testTransposition(true, true, faiss::MetricType::METRIC_L2);
}
TEST(TestGpuDistance, Transposition_CC_BF16) {
testTransposition_bf16(true, true, faiss::MetricType::METRIC_L2);
}
#if defined USE_NVIDIA_CUVS
TEST(TestCuvsGpuDistance, Transposition_CC) {
testTransposition(true, true, faiss::MetricType::METRIC_L2, true);
@ -239,6 +430,10 @@ TEST(TestGpuDistance, L1) {
testTransposition(false, false, faiss::MetricType::METRIC_L1);
}
TEST(TestGpuDistance, L1_BF16) {
testTransposition_bf16(false, false, faiss::MetricType::METRIC_L1);
}
#if defined USE_NVIDIA_CUVS
TEST(TestCuvsGpuDistance, L1) {
testTransposition(false, false, faiss::MetricType::METRIC_L1, true);
@ -257,10 +452,18 @@ TEST(TestCuvsGpuDistance, L1_RC) {
}
#endif
TEST(TestGpuDistance, L1_RC_BF16) {
testTransposition_bf16(false, true, faiss::MetricType::METRIC_L1);
}
TEST(TestGpuDistance, L1_CR) {
testTransposition(true, false, faiss::MetricType::METRIC_L1);
}
TEST(TestGpuDistance, L1_CR_BF16) {
testTransposition_bf16(true, false, faiss::MetricType::METRIC_L1);
}
#if defined USE_NVIDIA_CUVS
TEST(TestCuvsGpuDistance, L1_CR) {
testTransposition(true, false, faiss::MetricType::METRIC_L1, true);
@ -271,6 +474,10 @@ TEST(TestGpuDistance, L1_CC) {
testTransposition(true, true, faiss::MetricType::METRIC_L1);
}
TEST(TestGpuDistance, L1_CC_BF16) {
testTransposition_bf16(true, true, faiss::MetricType::METRIC_L1);
}
#if defined USE_NVIDIA_CUVS
TEST(TestCuvsGpuDistance, L1_CC) {
testTransposition(true, true, faiss::MetricType::METRIC_L1, true);
@ -289,10 +496,19 @@ TEST(TestCuvsGpuDistance, Linf) {
}
#endif
TEST(TestGpuDistance, Linf_BF16) {
testTransposition_bf16(false, false, faiss::MetricType::METRIC_Linf);
}
TEST(TestGpuDistance, Lp) {
testTransposition(false, false, faiss::MetricType::METRIC_Lp, false, 3);
}
TEST(TestGpuDistance, Lp_BF16) {
testTransposition_bf16(
false, false, faiss::MetricType::METRIC_Lp, false, 3);
}
#if defined USE_NVIDIA_CUVS
TEST(TestCuvsGpuDistance, Lp) {
testTransposition(false, false, faiss::MetricType::METRIC_Lp, true, 3);
@ -303,6 +519,10 @@ TEST(TestGpuDistance, Canberra) {
testTransposition(false, false, faiss::MetricType::METRIC_Canberra);
}
TEST(TestGpuDistance, Canberra_BF16) {
testTransposition_bf16(false, false, faiss::MetricType::METRIC_Canberra);
}
#if defined USE_NVIDIA_CUVS
TEST(TestCuvsGpuDistance, Canberra) {
testTransposition(false, false, faiss::MetricType::METRIC_Canberra, true);
@ -313,10 +533,19 @@ TEST(TestGpuDistance, BrayCurtis) {
testTransposition(false, false, faiss::MetricType::METRIC_BrayCurtis);
}
TEST(TestGpuDistance, BrayCurtis_BF16) {
testTransposition_bf16(false, false, faiss::MetricType::METRIC_BrayCurtis);
}
TEST(TestGpuDistance, JensenShannon) {
testTransposition(false, false, faiss::MetricType::METRIC_JensenShannon);
}
TEST(TestGpuDistance, JensenShannon_BF16) {
testTransposition_bf16(
false, false, faiss::MetricType::METRIC_JensenShannon);
}
#if defined USE_NVIDIA_CUVS
TEST(TestCuvsGpuDistance, JensenShannon) {
testTransposition(
@ -328,6 +557,10 @@ TEST(TestGpuDistance, Jaccard) {
testTransposition(false, false, faiss::MetricType::METRIC_Jaccard);
}
TEST(TestGpuDistance, Jaccard_BF16) {
testTransposition_bf16(false, false, faiss::MetricType::METRIC_Jaccard);
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);


@ -5,7 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/
#include <cuda_fp16.h>
#include <faiss/gpu/test/TestUtils.h>
#include <faiss/utils/random.h>
#include <gtest/gtest.h>
@ -18,6 +17,77 @@
namespace faiss {
namespace gpu {
inline float half2float(const unsigned short h) {
unsigned int sign = ((static_cast<unsigned int>(h) >> 15U) & 1U);
unsigned int exponent = ((static_cast<unsigned int>(h) >> 10U) & 0x1fU);
unsigned int mantissa = ((static_cast<unsigned int>(h) & 0x3ffU) << 13U);
float f;
if (exponent == 0x1fU) { /* NaN or Inf */
/* discard sign of a NaN */
sign = ((mantissa != 0U) ? (sign >> 1U) : sign);
mantissa = ((mantissa != 0U) ? 0x7fffffU : 0U);
exponent = 0xffU;
} else if (exponent == 0U) { /* Denorm or Zero */
if (mantissa != 0U) {
unsigned int msb;
exponent = 0x71U;
do {
msb = (mantissa & 0x400000U);
mantissa <<= 1U; /* normalize */
--exponent;
} while (msb == 0U);
mantissa &= 0x7fffffU; /* 1.mantissa is implicit */
}
} else {
exponent += 0x70U;
}
const unsigned int u = ((sign << 31U) | (exponent << 23U) | mantissa);
std::memcpy(&f, &u, sizeof(u));
return f;
}
unsigned short float2half(const float f) {
unsigned int sign;
unsigned int remainder;
unsigned int x;
unsigned int u;
unsigned int result;
(void)std::memcpy(&x, &f, sizeof(f));
u = (x & 0x7fffffffU);
sign = ((x >> 16U) & 0x8000U);
// NaN/+Inf/-Inf
if (u >= 0x7f800000U) {
remainder = 0U;
result = ((u == 0x7f800000U) ? (sign | 0x7c00U) : 0x7fffU);
} else if (u > 0x477fefffU) { // Overflows
remainder = 0x80000000U;
result = (sign | 0x7bffU);
} else if (u >= 0x38800000U) { // Normal numbers
remainder = u << 19U;
u -= 0x38000000U;
result = (sign | (u >> 13U));
} else if (u < 0x33000001U) { // +0/-0
remainder = u;
result = sign;
} else { // Denormal numbers
const unsigned int exponent = u >> 23U;
const unsigned int shift = 0x7eU - exponent;
unsigned int mantissa = (u & 0x7fffffU);
mantissa |= 0x800000U;
remainder = mantissa << (32U - shift);
result = (sign | (mantissa >> shift));
result &= 0x0000FFFFU;
}
if ((remainder > 0x80000000U) ||
((remainder == 0x80000000U) && ((result & 0x1U) != 0U))) {
return static_cast<unsigned short>(result) + 1;
} else {
return static_cast<unsigned short>(result);
}
}
inline float relativeError(float a, float b) {
return std::abs(a - b) / (0.5f * (std::abs(a) + std::abs(b)));
}
@ -78,7 +148,7 @@ std::vector<unsigned char> randBinaryVecs(size_t num, size_t dim) {
std::vector<float> roundToHalf(const std::vector<float>& v) {
auto out = std::vector<float>(v.size());
for (int i = 0; i < v.size(); ++i) {
out[i] = __half2float(__float2half(v[i]));
out[i] = half2float(float2half(v[i]));
}
return out;


@ -22,29 +22,13 @@ namespace gpu {
// Conversion utilities
//
template <typename From, typename To>
struct Convert {
inline __device__ To operator()(From v) const {
return (To)v;
}
};
template <>
struct Convert<float, half> {
inline __device__ half operator()(float v) const {
return __float2half(v);
}
};
template <>
struct Convert<half, float> {
inline __device__ float operator()(half v) const {
return __half2float(v);
}
};
template <typename T>
struct ConvertTo {};
struct ConvertTo {
template <typename U>
static inline __device__ T to(U v) {
return T(v);
}
};
template <>
struct ConvertTo<float> {
@ -54,6 +38,12 @@ struct ConvertTo<float> {
static inline __device__ float to(half v) {
return __half2float(v);
}
#ifndef USE_AMD_ROCM
static inline __device__ float to(__nv_bfloat16 v) {
return __bfloat162float(v);
}
#endif // !USE_AMD_ROCM
};
template <>
@ -106,6 +96,31 @@ struct ConvertTo<Half4> {
}
};
// no bf16 support for AMD
#ifndef USE_AMD_ROCM
template <>
struct ConvertTo<__nv_bfloat16> {
static inline __device__ __nv_bfloat16 to(float v) {
return __float2bfloat16(v);
}
static inline __device__ __nv_bfloat16 to(half v) {
return __float2bfloat16(__half2float(v));
}
static inline __device__ __nv_bfloat16 to(__nv_bfloat16 v) {
return v;
}
};
#endif // USE_AMD_ROCM
template <typename From, typename To>
struct Convert {
inline __device__ To operator()(From v) const {
return ConvertTo<To>::to(v);
}
};
// Tensor conversion
template <typename From, typename To>
void runConvert(const From* in, To* out, size_t num, cudaStream_t stream) {


@ -16,7 +16,21 @@
#define FAISS_USE_FULL_FLOAT16 1
#endif // __CUDA_ARCH__ types
// Some compute capabilities have full bfloat16 ALUs.
// FIXME: no support in ROCm yet
#if __CUDA_ARCH__ >= 800 // || defined(USE_AMD_ROCM)
#define FAISS_USE_FULL_BFLOAT16 1
#endif // __CUDA_ARCH__ types
#include <cuda_fp16.h>
#if !defined(USE_AMD_ROCM)
#include <cuda_bf16.h>
#endif
// #else
// FIXME: no support in ROCm yet
// #include <amd_hip_bf16.h>
// #include <amd_hip_fp16.h>
// #endif // !defined(USE_AMD_ROCM)
namespace faiss {
namespace gpu {


@ -13,7 +13,7 @@
//
// Templated wrappers to express math for different scalar and vector
// types, so kernels can have the same written form but can operate
// over half and float, and on vector types transparently
// over half, bfloat16 and float, and on vector types transparently
//
namespace faiss {
@ -556,5 +556,240 @@ struct Math<Half8> {
}
};
#ifndef USE_AMD_ROCM
template <>
struct Math<__nv_bfloat16> {
typedef __nv_bfloat16 ScalarType;
static inline __device__ __nv_bfloat16
add(__nv_bfloat16 a, __nv_bfloat16 b) {
#ifdef FAISS_USE_FULL_BFLOAT16
return __hadd(a, b);
#else
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b));
#endif
}
static inline __device__ __nv_bfloat16
sub(__nv_bfloat16 a, __nv_bfloat16 b) {
#ifdef FAISS_USE_FULL_BFLOAT16
return __hsub(a, b);
#else
return __float2bfloat16(__bfloat162float(a) - __bfloat162float(b));
#endif
}
static inline __device__ __nv_bfloat16
mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#ifdef FAISS_USE_FULL_BFLOAT16
return __hmul(a, b);
#else
return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b));
#endif
}
static inline __device__ __nv_bfloat16 neg(__nv_bfloat16 v) {
#ifdef FAISS_USE_FULL_BFLOAT16
return __hneg(v);
#else
return __float2bfloat16(-__bfloat162float(v));
#endif
}
static inline __device__ float reduceAdd(__nv_bfloat16 v) {
return ConvertTo<float>::to(v);
}
static inline __device__ bool lt(__nv_bfloat16 a, __nv_bfloat16 b) {
#ifdef FAISS_USE_FULL_BFLOAT16
return __hlt(a, b);
#else
return __bfloat162float(a) < __bfloat162float(b);
#endif
}
static inline __device__ bool gt(__nv_bfloat16 a, __nv_bfloat16 b) {
#ifdef FAISS_USE_FULL_BFLOAT16
return __hgt(a, b);
#else
return __bfloat162float(a) > __bfloat162float(b);
#endif
}
static inline __device__ bool eq(__nv_bfloat16 a, __nv_bfloat16 b) {
#ifdef FAISS_USE_FULL_BFLOAT16
return __heq(a, b);
#else
return __bfloat162float(a) == __bfloat162float(b);
#endif
}
static inline __device__ __nv_bfloat16 zero() {
#if CUDA_VERSION >= 9000
return 0.0f;
#else
__nv_bfloat16 h;
h.x = 0;
return h;
#endif
}
};
template <>
struct Math<__nv_bfloat162> {
typedef __nv_bfloat16 ScalarType;
#ifndef FAISS_USE_FULL_BFLOAT16
// define a few conversion functions that don't exist on cuda 11
// this overrides their definition in cuda 12 but we use native bf16 on this
// platform anyways.
static inline __device__ float2 __bfloat1622float2(__nv_bfloat162 a) {
float2 af;
af.x = __bfloat162float(a.x);
af.y = __bfloat162float(a.y);
return af;
}
static inline __device__ __nv_bfloat162 __float22bfloat162_rn(float2 af) {
__nv_bfloat162 a;
a.x = __float2bfloat16_rn(af.x);
a.y = __float2bfloat16_rn(af.y);
return a;
}
static inline __device__ __nv_bfloat162
__bfloat162bfloat162(__nv_bfloat16 v) {
__nv_bfloat162 a;
a.x = v;
a.y = v;
return a;
}
#endif
static inline __device__ __nv_bfloat162
add(__nv_bfloat162 a, __nv_bfloat162 b) {
#ifdef FAISS_USE_FULL_BFLOAT16
return __hadd2(a, b);
#else
float2 af = __bfloat1622float2(a);
float2 bf = __bfloat1622float2(b);
af.x += bf.x;
af.y += bf.y;
return __float22bfloat162_rn(af);
#endif
}
static inline __device__ __nv_bfloat162
sub(__nv_bfloat162 a, __nv_bfloat162 b) {
#ifdef FAISS_USE_FULL_BFLOAT16
return __hsub2(a, b);
#else
float2 af = __bfloat1622float2(a);
float2 bf = __bfloat1622float2(b);
af.x -= bf.x;
af.y -= bf.y;
return __float22bfloat162_rn(af);
#endif
}
static inline __device__ __nv_bfloat162
add(__nv_bfloat162 a, __nv_bfloat16 b) {
#ifdef FAISS_USE_FULL_BFLOAT16
__nv_bfloat162 b2 = __bfloat162bfloat162(b);
return __hadd2(a, b2);
#else
float2 af = __bfloat1622float2(a);
float bf = __bfloat162float(b);
af.x += bf;
af.y += bf;
return __float22bfloat162_rn(af);
#endif
}
static inline __device__ __nv_bfloat162
sub(__nv_bfloat162 a, __nv_bfloat16 b) {
#ifdef FAISS_USE_FULL_BFLOAT16
__nv_bfloat162 b2 = __bfloat162bfloat162(b);
return __hsub2(a, b2);
#else
float2 af = __bfloat1622float2(a);
float bf = __bfloat162float(b);
af.x -= bf;
af.y -= bf;
return __float22bfloat162_rn(af);
#endif
}
static inline __device__ __nv_bfloat162
mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#ifdef FAISS_USE_FULL_BFLOAT16
return __hmul2(a, b);
#else
float2 af = __bfloat1622float2(a);
float2 bf = __bfloat1622float2(b);
af.x *= bf.x;
af.y *= bf.y;
return __float22bfloat162_rn(af);
#endif
}
static inline __device__ __nv_bfloat162
mul(__nv_bfloat162 a, __nv_bfloat16 b) {
#ifdef FAISS_USE_FULL_BFLOAT16
__nv_bfloat162 b2 = __bfloat162bfloat162(b);
return __hmul2(a, b2);
#else
float2 af = __bfloat1622float2(a);
float bf = __bfloat162float(b);
af.x *= bf;
af.y *= bf;
return __float22bfloat162_rn(af);
#endif
}
static inline __device__ __nv_bfloat162 neg(__nv_bfloat162 v) {
#ifdef FAISS_USE_FULL_BFLOAT16
return __hneg2(v);
#else
float2 vf = __bfloat1622float2(v);
vf.x = -vf.x;
vf.y = -vf.y;
return __float22bfloat162_rn(vf);
#endif
}
static inline __device__ float reduceAdd(__nv_bfloat162 v) {
float2 vf = __bfloat1622float2(v);
vf.x += vf.y;
return vf.x;
}
// not implemented for vector types
// static inline __device__ bool lt(__nv_bfloat162 a, __nv_bfloat162 b);
// static inline __device__ bool gt(__nv_bfloat162 a, __nv_bfloat162 b);
// static inline __device__ bool eq(__nv_bfloat162 a, __nv_bfloat162 b);
static inline __device__ __nv_bfloat162 zero() {
return __bfloat162bfloat162(Math<__nv_bfloat16>::zero());
}
};
#endif // !USE_AMD_ROCM
} // namespace gpu
} // namespace faiss
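These trait specializations exist so that kernels written once against `Math<T>` / `ConvertTo<T>` instantiate unchanged for float, half and now `__nv_bfloat16`; a toy sketch of that pattern (not a kernel from this diff, names are illustrative):

```cpp
#include <faiss/gpu/utils/ConversionOperators.cuh>
#include <faiss/gpu/utils/MathOperators.cuh>

namespace faiss {
namespace gpu {

// Toy single-thread kernel (launch as toyDot<<<1, 1>>>): elementwise math is
// done through Math<T>, the running sum is kept in float via reduceAdd, and
// the same source instantiates for float, half and __nv_bfloat16.
template <typename T>
__global__ void toyDot(const T* a, const T* b, int d, float* out) {
    float acc = 0.0f;
    for (int i = 0; i < d; ++i) {
        acc += Math<T>::reduceAdd(Math<T>::mul(a[i], b[i]));
    }
    *out = acc;
}

} // namespace gpu
} // namespace faiss
```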


@ -21,6 +21,7 @@ template <typename T>
struct GetCudaType;
#ifdef USE_AMD_ROCM
template <>
struct GetCudaType<float> {
static constexpr hipblasDatatype_t Type = HIPBLAS_R_32F;
@ -30,7 +31,15 @@ template <>
struct GetCudaType<half> {
static constexpr hipblasDatatype_t Type = HIPBLAS_R_16F;
};
// FIXME: no AMD support for bf16
// template <>
// struct GetCudaType<__nv_bfloat16> {
// static constexpr hipblasDatatype_t Type = HIPBLAS_R_16B;
// };
#else
template <>
struct GetCudaType<float> {
static constexpr cudaDataType_t Type = CUDA_R_32F;
@ -40,6 +49,12 @@ template <>
struct GetCudaType<half> {
static constexpr cudaDataType_t Type = CUDA_R_16F;
};
template <>
struct GetCudaType<__nv_bfloat16> {
static constexpr cudaDataType_t Type = CUDA_R_16BF;
};
#endif
template <typename AT, typename BT>