/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.

#include "IVFFlatScan.cuh"
#include "../GpuResources.h"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/DeviceTensor.cuh"
#include "../utils/Float16.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/PtxUtils.cuh"
#include "../utils/Reductions.cuh"
#include "../utils/StaticUtils.h"
#include <thrust/host_vector.h>

namespace faiss { namespace gpu {

template <typename T>
inline __device__ typename Math<T>::ScalarType l2Distance(T a, T b) {
  a = Math<T>::sub(a, b);
  a = Math<T>::mul(a, a);
  return Math<T>::reduceAdd(a);
}

template <typename T>
inline __device__ typename Math<T>::ScalarType ipDistance(T a, T b) {
  return Math<T>::reduceAdd(Math<T>::mul(a, b));
}

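// Note: Math<T> is specialized for packed types as well; with T = float4,
// Math<T>::reduceAdd sums the four packed lanes, so l2Distance(float4 a,
// float4 b) yields a lane-local partial squared-L2 over four dimensions,
// which the warp/block reductions below then combine.
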
// For list scanning, even if the input data is `half`, we perform all
// math in float32, because the code is memory b/w bound, and the
// added precision for accumulation is useful

/// The class that we use to provide scan specializations
template <int Dims, bool L2, typename T>
struct IVFFlatScan {
};

// Fallback implementation: works for any dimension size
template <bool L2, typename T>
struct IVFFlatScan<-1, L2, T> {
  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    extern __shared__ float smem[];
    T* vecs = (T*) vecData;

    for (int vec = 0; vec < numVecs; ++vec) {
      // Reduce in dist
      float dist = 0.0f;

      for (int d = threadIdx.x; d < dim; d += blockDim.x) {
        float vecVal = ConvertTo<float>::to(vecs[vec * dim + d]);
        float queryVal = query[d];
        float curDist;

        if (L2) {
          curDist = l2Distance(queryVal, vecVal);
        } else {
          curDist = ipDistance(queryVal, vecVal);
        }

        dist += curDist;
      }

      // Reduce distance within block
      dist = blockReduceAllSum<float, false, true>(dist, smem);

      if (threadIdx.x == 0) {
        distanceOut[vec] = dist;
      }
    }
  }
};

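// In the fallback, the whole block cooperates on one vector at a time:
// thread t accumulates dimensions t, t + blockDim.x, ..., and the partial
// sums are combined by the block-wide reduction, so no particular
// relationship between `dim` and the block size is required.
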
// Implementation: works when # dims == blockDim.x
template <bool L2, typename T>
struct IVFFlatScan<0, L2, T> {
  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    extern __shared__ float smem[];
    T* vecs = (T*) vecData;

    float queryVal = query[threadIdx.x];

    constexpr int kUnroll = 4;
    int limit = utils::roundDown(numVecs, kUnroll);

    for (int i = 0; i < limit; i += kUnroll) {
      float vecVal[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vecVal[j] = ConvertTo<float>::to(vecs[(i + j) * dim + threadIdx.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          vecVal[j] = l2Distance(queryVal, vecVal[j]);
        } else {
          vecVal[j] = ipDistance(queryVal, vecVal[j]);
        }
      }

      blockReduceAllSum<kUnroll, float, false, true>(vecVal, smem);

      if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j] = vecVal[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit; i < numVecs; ++i) {
      float vecVal = ConvertTo<float>::to(vecs[i * dim + threadIdx.x]);

      if (L2) {
        vecVal = l2Distance(queryVal, vecVal);
      } else {
        vecVal = ipDistance(queryVal, vecVal);
      }

      vecVal = blockReduceAllSum<float, false, true>(vecVal, smem);

      if (threadIdx.x == 0) {
        distanceOut[i] = vecVal;
      }
    }
  }
};

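// A note on the Dims == 0 case: since blockDim.x == dim, thread t always
// reads dimension t of each vector, so consecutive threads touch
// consecutive addresses and loads should coalesce; kUnroll batches four
// vectors per block-wide reduction to amortize its synchronization cost.
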
// 64-d float32 implementation
template <bool L2>
struct IVFFlatScan<64, L2, float> {
  static constexpr int kDims = 64;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    // Each warp reduces a single 64-d vector; each lane loads a float2
    float* vecs = (float*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    float2 queryVal = *(float2*) &query[laneId * 2];

    constexpr int kUnroll = 4;
    float2 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        // Vector we are loading from is i
        // Dim we are loading from is laneId * 2
        vecVal[j] = *(float2*) &vecs[(i + j * numWarps) * kDims + laneId * 2];
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryVal, vecVal[j]);
        } else {
          dist[j] = ipDistance(queryVal, vecVal[j]);
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = *(float2*) &vecs[i * kDims + laneId * 2];
      float dist;
      if (L2) {
        dist = l2Distance(queryVal, vecVal[0]);
      } else {
        dist = ipDistance(queryVal, vecVal[0]);
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};

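// Layout note: 32 lanes per warp times one float2 per lane covers exactly
// 64 dimensions, so each lane issues a single vectorized load per vector.
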
#ifdef FAISS_USE_FLOAT16

// float16 implementation
template <bool L2>
struct IVFFlatScan<64, L2, half> {
  static constexpr int kDims = 64;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    // Each warp reduces a single 64-d vector; each lane loads a half2
    half* vecs = (half*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    float2 queryVal = *(float2*) &query[laneId * 2];

    constexpr int kUnroll = 4;

    half2 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        // Vector we are loading from is i
        // Dim we are loading from is laneId * 2
        vecVal[j] = *(half2*) &vecs[(i + j * numWarps) * kDims + laneId * 2];
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryVal, __half22float2(vecVal[j]));
        } else {
          dist[j] = ipDistance(queryVal, __half22float2(vecVal[j]));
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = *(half2*) &vecs[i * kDims + laneId * 2];

      float dist;
      if (L2) {
        dist = l2Distance(queryVal, __half22float2(vecVal[0]));
      } else {
        dist = ipDistance(queryVal, __half22float2(vecVal[0]));
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};

#endif

// 128-d float32 implementation
template <bool L2>
struct IVFFlatScan<128, L2, float> {
  static constexpr int kDims = 128;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    // Each warp reduces a single 128-d vector; each lane loads a float4
    float* vecs = (float*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    float4 queryVal = *(float4*) &query[laneId * 4];

    constexpr int kUnroll = 4;
    float4 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        // Vector we are loading from is i
        // Dim we are loading from is laneId * 4
        vecVal[j] = *(float4*) &vecs[(i + j * numWarps) * kDims + laneId * 4];
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryVal, vecVal[j]);
        } else {
          dist[j] = ipDistance(queryVal, vecVal[j]);
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = *(float4*) &vecs[i * kDims + laneId * 4];
      float dist;
      if (L2) {
        dist = l2Distance(queryVal, vecVal[0]);
      } else {
        dist = ipDistance(queryVal, vecVal[0]);
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};

#ifdef FAISS_USE_FLOAT16

// float16 implementation
template <bool L2>
struct IVFFlatScan<128, L2, half> {
  static constexpr int kDims = 128;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    // Each warp reduces a single 128-d vector; each lane loads a Half4
    half* vecs = (half*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    float4 queryVal = *(float4*) &query[laneId * 4];

    constexpr int kUnroll = 4;

    Half4 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        // Vector we are loading from is i
        // Dim we are loading from is laneId * 4
        vecVal[j] =
          LoadStore<Half4>::load(
            &vecs[(i + j * numWarps) * kDims + laneId * 4]);
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryVal, half4ToFloat4(vecVal[j]));
        } else {
          dist[j] = ipDistance(queryVal, half4ToFloat4(vecVal[j]));
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = LoadStore<Half4>::load(&vecs[i * kDims + laneId * 4]);

      float dist;
      if (L2) {
        dist = l2Distance(queryVal, half4ToFloat4(vecVal[0]));
      } else {
        dist = ipDistance(queryVal, half4ToFloat4(vecVal[0]));
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};

#endif

// 256-d float32 implementation
template <bool L2>
struct IVFFlatScan<256, L2, float> {
  static constexpr int kDims = 256;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    // A specialization here to load per-warp seems to be worse, since
    // we're already running at near memory b/w peak
    IVFFlatScan<0, L2, float>::scan(query,
                                    vecData,
                                    numVecs,
                                    dim,
                                    distanceOut);
  }
};

#ifdef FAISS_USE_FLOAT16

// float16 implementation
template <bool L2>
struct IVFFlatScan<256, L2, half> {
  static constexpr int kDims = 256;

  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    // Each warp reduces a single 256-d vector; each lane loads a Half8
    half* vecs = (half*) vecData;

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    int numWarps = blockDim.x / kWarpSize;

    // This is not a contiguous load, but we only have to load these two
    // values, so that we can load by Half8 below
    float4 queryValA = *(float4*) &query[laneId * 8];
    float4 queryValB = *(float4*) &query[laneId * 8 + 4];

    constexpr int kUnroll = 4;

    Half8 vecVal[kUnroll];

    int limit = utils::roundDown(numVecs, kUnroll * numWarps);

    for (int i = warpId; i < limit; i += kUnroll * numWarps) {
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        // Vector we are loading from is i
        // Dim we are loading from is laneId * 8
        vecVal[j] =
          LoadStore<Half8>::load(
            &vecs[(i + j * numWarps) * kDims + laneId * 8]);
      }

      float dist[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          dist[j] = l2Distance(queryValA, half4ToFloat4(vecVal[j].a));
          dist[j] += l2Distance(queryValB, half4ToFloat4(vecVal[j].b));
        } else {
          dist[j] = ipDistance(queryValA, half4ToFloat4(vecVal[j].a));
          dist[j] += ipDistance(queryValB, half4ToFloat4(vecVal[j].b));
        }
      }

      // Reduce within the warp
#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        dist[j] = warpReduceAllSum(dist[j]);
      }

      if (laneId == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j * numWarps] = dist[j];
        }
      }
    }

    // Handle remainder
    for (int i = limit + warpId; i < numVecs; i += numWarps) {
      vecVal[0] = LoadStore<Half8>::load(&vecs[i * kDims + laneId * 8]);

      float dist;
      if (L2) {
        dist = l2Distance(queryValA, half4ToFloat4(vecVal[0].a));
        dist += l2Distance(queryValB, half4ToFloat4(vecVal[0].b));
      } else {
        dist = ipDistance(queryValA, half4ToFloat4(vecVal[0].a));
        dist += ipDistance(queryValB, half4ToFloat4(vecVal[0].b));
      }

      dist = warpReduceAllSum(dist);

      if (laneId == 0) {
        distanceOut[i] = dist;
      }
    }
  }
};

#endif

template <int Dims, bool L2, typename T>
__global__ void
ivfFlatScan(Tensor<float, 2, true> queries,
            Tensor<int, 2, true> listIds,
            void** allListData,
            int* listLengths,
            Tensor<int, 2, true> prefixSumOffsets,
            Tensor<float, 1, true> distance) {
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;

  // This is where we start writing out data
  // We ensure that before the array (at offset -1), there is a 0 value
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);

  auto listId = listIds[queryId][probeId];
  // Safety guard in case NaNs in input cause no list ID to be generated
  if (listId == -1) {
    return;
  }

  auto query = queries[queryId].data();
  auto vecs = allListData[listId];
  auto numVecs = listLengths[listId];
  auto dim = queries.getSize(1);
  auto distanceOut = distance[outBase].data();

  IVFFlatScan<Dims, L2, T>::scan(query, vecs, numVecs, dim, distanceOut);
}

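// Launch geometry (set up in runIVFFlatScanTile below): one block per
// (query, probe) pair, with gridDim.x = nprobe and gridDim.y = # queries,
// which is why blockIdx.x names the probe and blockIdx.y the query above.
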
void
runIVFFlatScanTile(Tensor<float, 2, true>& queries,
                   Tensor<int, 2, true>& listIds,
                   thrust::device_vector<void*>& listData,
                   thrust::device_vector<void*>& listIndices,
                   IndicesOptions indicesOptions,
                   thrust::device_vector<int>& listLengths,
                   Tensor<char, 1, true>& thrustMem,
                   Tensor<int, 2, true>& prefixSumOffsets,
                   Tensor<float, 1, true>& allDistances,
                   Tensor<float, 3, true>& heapDistances,
                   Tensor<int, 3, true>& heapIndices,
                   int k,
                   bool l2Distance,
                   bool useFloat16,
                   Tensor<float, 2, true>& outDistances,
                   Tensor<long, 2, true>& outIndices,
                   cudaStream_t stream) {
  // Calculate offset lengths, so we know where to write out
  // intermediate results
  runCalcListOffsets(listIds, listLengths, prefixSumOffsets, thrustMem, stream);

  // Calculate distances for vectors within our chunk of lists
  constexpr int kMaxThreadsIVF = 512;

  // FIXME: if `half` and # dims is multiple of 2, halve the
  // threadblock size

  int dim = queries.getSize(1);
  int numThreads = std::min(dim, kMaxThreadsIVF);

  auto grid = dim3(listIds.getSize(1),
                   listIds.getSize(0));
  auto block = dim3(numThreads);
  // All exact dim kernels are unrolled by 4, hence the `4`
  auto smem = sizeof(float) * utils::divUp(numThreads, kWarpSize) * 4;

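  // Shared memory sizing sketch: blockReduceAllSum needs scratch per warp,
  // and the exact-dim kernels reduce kUnroll == 4 values at once, hence
  // numWarps * 4 floats. As an illustrative (not source-derived) example,
  // 512 threads -> 16 warps -> 64 floats = 256 bytes of smem per block.
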
#define RUN_IVF_FLAT(DIMS, L2, T)               \
  do {                                          \
    ivfFlatScan<DIMS, L2, T>                    \
      <<<grid, block, smem, stream>>>(          \
        queries,                                \
        listIds,                                \
        listData.data().get(),                  \
        listLengths.data().get(),               \
        prefixSumOffsets,                       \
        allDistances);                          \
  } while (0)

#ifdef FAISS_USE_FLOAT16

#define HANDLE_DIM_CASE(DIMS)                   \
  do {                                          \
    if (l2Distance) {                           \
      if (useFloat16) {                         \
        RUN_IVF_FLAT(DIMS, true, half);         \
      } else {                                  \
        RUN_IVF_FLAT(DIMS, true, float);        \
      }                                         \
    } else {                                    \
      if (useFloat16) {                         \
        RUN_IVF_FLAT(DIMS, false, half);        \
      } else {                                  \
        RUN_IVF_FLAT(DIMS, false, float);       \
      }                                         \
    }                                           \
  } while (0)

#else

#define HANDLE_DIM_CASE(DIMS)                   \
  do {                                          \
    if (l2Distance) {                           \
      if (useFloat16) {                         \
        FAISS_ASSERT(false);                    \
      } else {                                  \
        RUN_IVF_FLAT(DIMS, true, float);        \
      }                                         \
    } else {                                    \
      if (useFloat16) {                         \
        FAISS_ASSERT(false);                    \
      } else {                                  \
        RUN_IVF_FLAT(DIMS, false, float);       \
      }                                         \
    }                                           \
  } while (0)

#endif // FAISS_USE_FLOAT16

  if (dim == 64) {
    HANDLE_DIM_CASE(64);
  } else if (dim == 128) {
    HANDLE_DIM_CASE(128);
  } else if (dim == 256) {
    HANDLE_DIM_CASE(256);
  } else if (dim <= kMaxThreadsIVF) {
    HANDLE_DIM_CASE(0);
  } else {
    HANDLE_DIM_CASE(-1);
  }

  CUDA_TEST_ERROR();

#undef HANDLE_DIM_CASE
#undef RUN_IVF_FLAT

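  // Two-pass selection: pass 1 below k-selects independently within each
  // chunk of the scanned distances (more blocks in flight), and pass 2
  // merges the per-chunk heaps into the final k results for each query.
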
  // k-select the output in chunks, to increase parallelism
  runPass1SelectLists(prefixSumOffsets,
                      allDistances,
                      listIds.getSize(1),
                      k,
                      !l2Distance, // L2 distance chooses smallest
                      heapDistances,
                      heapIndices,
                      stream);

  // k-select final output
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      listIds,
                      k,
                      !l2Distance, // L2 distance chooses smallest
                      outDistances,
                      outIndices,
                      stream);
}

void
runIVFFlatScan(Tensor<float, 2, true>& queries,
               Tensor<int, 2, true>& listIds,
               thrust::device_vector<void*>& listData,
               thrust::device_vector<void*>& listIndices,
               IndicesOptions indicesOptions,
               thrust::device_vector<int>& listLengths,
               int maxListLength,
               int k,
               bool l2Distance,
               bool useFloat16,
               // output
               Tensor<float, 2, true>& outDistances,
               // output
               Tensor<long, 2, true>& outIndices,
               GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;

  int nprobe = listIds.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  // Make a reservation for Thrust to do its dirty work (global memory
  // cross-block reduction space); hopefully this is large enough.
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};

  // How much temporary storage is available?
  // If possible, we'd like to fit within the space available.
  size_t sizeAvailable = mem.getSizeAvailable();

  // We run two passes of heap selection
  // This is the size of the first-level heap passes
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  // How much temporary storage we need per each query
  size_t sizePerQuery =
    2 * // # streams
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     sizeForFirstSelectPass);

  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // FIXME: we should adjust queryTileSize to deal with this, since
  // indexing is in int32
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <
               std::numeric_limits<int>::max());

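  // Illustrative sizing (assumed numbers, not from the source): with
  // nprobe = 32, maxListLength = 10000, k = 10, sizePerQuery is roughly
  // 2 * ((32*4 + 4) + 32*10000*4 + 8*10*8) bytes ~= 2.44 MiB, so e.g.
  // 256 MiB of scratch would allow a tile of ~104 queries, within [8, 128].
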
  // Temporary memory buffers
  // Make sure there is space prior to the start which will be 0, and
  // will handle the boundary condition without branches
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};

  // Make sure the element before prefixSumOffsets is 0, since we
  // depend upon simple, boundary-less indexing to get proper results
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0,
                              sizeof(int),
                              stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0,
                              sizeof(int),
                              stream));

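  // Sentinel trick: prefixSumOffsets1/2 alias element 1 of their backing
  // space, so the `data() - 1` read in ivfFlatScan for the first (query,
  // probe) pair lands on the zeroed element 0 rather than out of bounds.
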
  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};

  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;

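  // Double-buffering: successive tiles alternate between the two alternate
  // streams and their paired scratch buffers, so one tile's offset
  // computation and scan can overlap the previous tile's work; the final
  // streamWait below joins both streams back to the default stream.
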
  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto listIdsView =
      listIds.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runIVFFlatScanTile(queryView,
                       listIdsView,
                       listData,
                       listIndices,
                       indicesOptions,
                       listLengths,
                       *thrustMem[curStream],
                       prefixSumOffsetsView,
                       *allDistances[curStream],
                       heapDistancesView,
                       heapIndicesView,
                       k,
                       l2Distance,
                       useFloat16,
                       outDistanceView,
                       outIndicesView,
                       streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  streamWait({stream}, streams);
}

} } // namespace