/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <faiss/gpu/impl/BroadcastSum.cuh>
#include <faiss/gpu/impl/Distance.cuh>
#include <faiss/gpu/impl/L2Norm.cuh>
#include <faiss/gpu/utils/ConversionOperators.cuh>
#include <faiss/gpu/utils/DeviceDefs.cuh>
#include <faiss/gpu/utils/DeviceUtils.h>
#include <faiss/gpu/utils/Float16.cuh>
#include <faiss/gpu/utils/MatrixMult.cuh>
#include <faiss/gpu/utils/PtxUtils.cuh>
#include <faiss/gpu/utils/StaticUtils.h>
#include <faiss/gpu/utils/Transpose.cuh>
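
// Expository overview (added note): this file computes the per-query PQ code
// distance tables used by IVFPQ search. For each (query, coarse centroid)
// pair chosen by the coarse quantizer and for each sub-quantizer slice, we
// need the distance from the query to every PQ code centroid. Per slice,
// with query slice q, coarse centroid slice c and PQ centroid p[code]:
//
//   L2:            dist[code] = || (q - c) - p[code] ||^2
//   inner product: dist[code] = q . (c + p[code])
//
// The specialized kernel below computes this directly with one thread per
// code; runPQCodeDistancesMM computes the L2 form via an algebraic
// expansion and batched GEMM.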

namespace faiss { namespace gpu {

// Kernel responsible for calculating distance from residual vector to
// each product quantizer code centroid
template <typename OutCodeT,
          typename CentroidT,
          int DimsPerSubQuantizer,
          bool L2Distance>
__global__ void
__launch_bounds__(288, 4)
pqCodeDistances(Tensor<float, 2, true> queries,
                int queriesPerBlock,
                Tensor<CentroidT, 2, true> coarseCentroids,
                Tensor<float, 3, true> pqCentroids,
                Tensor<int, 2, true> topQueryToCentroid,
                // (query id)(coarse)(subquantizer)(code) -> dist
                Tensor<OutCodeT, 4, true> outCodeDistances) {
  const auto numSubQuantizers = pqCentroids.getSize(0);
  const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
  assert(DimsPerSubQuantizer == dimsPerSubQuantizer);
  const auto codesPerSubQuantizer = pqCentroids.getSize(2);

  bool isLoadingThread = threadIdx.x >= codesPerSubQuantizer;
  int loadingThreadId = threadIdx.x - codesPerSubQuantizer;
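
  // Expository example (assumes the typical 256 codes per sub-quantizer and
  // the launch shape chosen in runPQCodeDistances below): blockDim.x =
  // 256 + roundUp(dims, kWarpSize) = 288 for dims <= 32, so threads 0..255
  // each own one PQ code while threads 256..287 act purely as loaders,
  // prefetching residual data for the next coarse centroid.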

  extern __shared__ float smem[];

  // Each thread calculates a single code
  float subQuantizerData[DimsPerSubQuantizer];

  auto code = threadIdx.x;
  auto subQuantizer = blockIdx.y;

  // Each thread will load the pq centroid data for the code that it
  // is processing
#pragma unroll
  for (int i = 0; i < DimsPerSubQuantizer; ++i) {
    subQuantizerData[i] = pqCentroids[subQuantizer][i][code].ldg();
  }

  // Where we store our query vector
  float* smemQuery = smem;

  // Where we store our residual vector; this is double buffered so we
  // can be loading the next one while processing the current one
  float* smemResidual1 = &smemQuery[DimsPerSubQuantizer];
  float* smemResidual2 = &smemResidual1[DimsPerSubQuantizer];

  // Where we pre-load the coarse centroid IDs
  int* coarseIds = (int*) &smemResidual2[DimsPerSubQuantizer];
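
  // Shared memory layout sketch (expository; matches the smem size computed
  // in runPQCodeDistances below). With dims = DimsPerSubQuantizer and
  // nprobe = topQueryToCentroid.getSize(1):
  //
  //   floats [0, dims)        smemQuery      current query slice
  //   floats [dims, 2*dims)   smemResidual1  buffer being consumed
  //   floats [2*dims, 3*dims) smemResidual2  buffer being prefetched
  //   ints   [0, nprobe)      coarseIds      coarse centroid ids
  //
  // e.g. dims = 32, nprobe = 64 => 3 * 32 * 4 + 64 * 4 = 640 bytes.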

  // Each thread is calculating the distance for a single code,
  // performing the reductions locally

  // Handle multiple queries per block
  auto startQueryId = blockIdx.x * queriesPerBlock;
  auto numQueries = queries.getSize(0) - startQueryId;
  if (numQueries > queriesPerBlock) {
    numQueries = queriesPerBlock;
  }
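
  // e.g. (expository) with 100 queries and queriesPerBlock = 8, blocks
  // 0..11 handle 8 queries each and block 12 handles the remaining 4.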

  for (int query = 0; query < numQueries; ++query) {
    auto queryId = startQueryId + query;

    auto querySubQuantizer =
      queries[queryId][subQuantizer * DimsPerSubQuantizer].data();

    // Load current query vector
    for (int i = threadIdx.x; i < DimsPerSubQuantizer; i += blockDim.x) {
      smemQuery[i] = querySubQuantizer[i];
    }

    // Load list of coarse centroids found
    for (int i = threadIdx.x;
         i < topQueryToCentroid.getSize(1); i += blockDim.x) {
      coarseIds[i] = topQueryToCentroid[queryId][i];
    }

    // We need coarseIds below
    // FIXME: investigate loading separately, so we don't need this
    __syncthreads();

    // Preload first buffer of residual data
    if (isLoadingThread) {
      for (int i = loadingThreadId;
           i < DimsPerSubQuantizer;
           i += blockDim.x - codesPerSubQuantizer) {
        auto coarseId = coarseIds[0];
        // In case NaNs were in the original query data
        coarseId = coarseId == -1 ? 0 : coarseId;
        auto coarseCentroidSubQuantizer =
          coarseCentroids[coarseId][subQuantizer * dimsPerSubQuantizer].data();

        if (L2Distance) {
          // L2: buffer the residual (query - coarse centroid) slice
          smemResidual1[i] = smemQuery[i] -
            ConvertTo<float>::to(coarseCentroidSubQuantizer[i]);
        } else {
          // IP: buffer the coarse centroid slice itself
          smemResidual1[i] =
            ConvertTo<float>::to(coarseCentroidSubQuantizer[i]);
        }
      }
    }
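
    // From here on, the block pipelines the nprobe list: each iteration the
    // loading threads prefetch the residual for coarse centroid (coarse + 1)
    // into one buffer while the compute threads consume the other, and the
    // two buffers are swapped at the bottom of the loop.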

    // The block walks the list for a single query
    for (int coarse = 0; coarse < topQueryToCentroid.getSize(1); ++coarse) {
      // Wait for smemResidual1 to be loaded
      __syncthreads();

      if (isLoadingThread) {
        // Preload second buffer of residual data
        for (int i = loadingThreadId;
             i < DimsPerSubQuantizer;
             i += blockDim.x - codesPerSubQuantizer) {
          // FIXME: try always making this centroid id 0 so we can
          // terminate
          if (coarse != (topQueryToCentroid.getSize(1) - 1)) {
            auto coarseId = coarseIds[coarse + 1];
            // In case NaNs were in the original query data
            coarseId = coarseId == -1 ? 0 : coarseId;

            auto coarseCentroidSubQuantizer =
              coarseCentroids[coarseId]
                [subQuantizer * dimsPerSubQuantizer].data();

            if (L2Distance) {
              smemResidual2[i] = smemQuery[i] -
                ConvertTo<float>::to(coarseCentroidSubQuantizer[i]);
            } else {
              smemResidual2[i] =
                ConvertTo<float>::to(coarseCentroidSubQuantizer[i]);
            }
          }
        }
      } else {
        // These are the processing threads
        float dist = 0.0f;

        constexpr int kUnroll = 4;
        constexpr int kRemainder = DimsPerSubQuantizer % kUnroll;
        constexpr int kRemainderBase = DimsPerSubQuantizer - kRemainder;
        float vals[kUnroll];
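
        // e.g. (expository) DimsPerSubQuantizer = 10 gives kRemainder = 2
        // and kRemainderBase = 8: dims 0..7 go through the unrolled loop
        // below, dims 8..9 through the remainder loop after it.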

        // Calculate residual - pqCentroid for each dim that we're
        // processing

        // Unrolled loop
        if (L2Distance) {
#pragma unroll
          for (int i = 0; i < DimsPerSubQuantizer / kUnroll; ++i) {
#pragma unroll
            for (int j = 0; j < kUnroll; ++j) {
              vals[j] = smemResidual1[i * kUnroll + j];
            }

#pragma unroll
            for (int j = 0; j < kUnroll; ++j) {
              vals[j] -= subQuantizerData[i * kUnroll + j];
            }

#pragma unroll
            for (int j = 0; j < kUnroll; ++j) {
              vals[j] *= vals[j];
            }

#pragma unroll
            for (int j = 0; j < kUnroll; ++j) {
              dist += vals[j];
            }
          }
        } else {
          // Inner product: query slice against the reconstructed
          // sub-quantizer for this coarse cell
          // (query o (centroid + subQCentroid))
#pragma unroll
          for (int i = 0; i < DimsPerSubQuantizer / kUnroll; ++i) {
#pragma unroll
            for (int j = 0; j < kUnroll; ++j) {
              vals[j] = smemResidual1[i * kUnroll + j];
            }

#pragma unroll
            for (int j = 0; j < kUnroll; ++j) {
              vals[j] += subQuantizerData[i * kUnroll + j];
            }

#pragma unroll
            for (int j = 0; j < kUnroll; ++j) {
              vals[j] *= smemQuery[i * kUnroll + j];
            }

#pragma unroll
            for (int j = 0; j < kUnroll; ++j) {
              dist += vals[j];
            }
          }
        }

        // Remainder loop
        if (L2Distance) {
#pragma unroll
          for (int j = 0; j < kRemainder; ++j) {
            vals[j] = smemResidual1[kRemainderBase + j];
          }

#pragma unroll
          for (int j = 0; j < kRemainder; ++j) {
            vals[j] -= subQuantizerData[kRemainderBase + j];
          }

#pragma unroll
          for (int j = 0; j < kRemainder; ++j) {
            vals[j] *= vals[j];
          }
        } else {
          // Inner product: query slice against the reconstructed
          // sub-quantizer for this coarse cell
          // (query o (centroid + subQCentroid))
#pragma unroll
          for (int j = 0; j < kRemainder; ++j) {
            vals[j] = smemResidual1[kRemainderBase + j];
          }

#pragma unroll
          for (int j = 0; j < kRemainder; ++j) {
            vals[j] += subQuantizerData[kRemainderBase + j];
          }

#pragma unroll
          for (int j = 0; j < kRemainder; ++j) {
            vals[j] *= smemQuery[kRemainderBase + j];
          }
        }

#pragma unroll
        for (int j = 0; j < kRemainder; ++j) {
          dist += vals[j];
        }

        // We have the distance for our code; write it out
        outCodeDistances[queryId][coarse][subQuantizer][code] =
          ConvertTo<OutCodeT>::to(dist);
      } // !isLoadingThread

      // Swap residual buffers
      float* tmp = smemResidual1;
      smemResidual1 = smemResidual2;
      smemResidual2 = tmp;
    }
  }
}

template <typename CentroidT>
__global__ void
residualVector(Tensor<float, 2, true> queries,
               Tensor<CentroidT, 2, true> coarseCentroids,
               Tensor<int, 2, true> topQueryToCentroid,
               int numSubDim,
               // output is transposed:
               // (sub q)(query id)(centroid id)(sub dim)
               Tensor<float, 4, true> residual) {
  // block x is query id
  // block y is centroid id
  // thread x is dim
  auto queryId = blockIdx.x;
  auto centroidId = blockIdx.y;

  int realCentroidId = topQueryToCentroid[queryId][centroidId];

  for (int dim = threadIdx.x; dim < queries.getSize(1); dim += blockDim.x) {
    float q = queries[queryId][dim];
    float c = ConvertTo<float>::to(coarseCentroids[realCentroidId][dim]);

    residual[dim / numSubDim][queryId][centroidId][dim % numSubDim] = q - c;
  }
}
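
// Expository indexing example for the transposed output above: with
// numSubDim = 8, full dimension d = 19 lands at
// residual[19 / 8 = 2][queryId][centroidId][19 % 8 = 3], i.e. the residual
// is laid out contiguously per sub-quantizer slice, ready for the batched
// GEMM in runPQCodeDistancesMM.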

template <typename CentroidT>
void
runResidualVector(Tensor<float, 3, true>& pqCentroids,
                  Tensor<float, 2, true>& queries,
                  Tensor<CentroidT, 2, true>& coarseCentroids,
                  Tensor<int, 2, true>& topQueryToCentroid,
                  Tensor<float, 4, true>& residual,
                  cudaStream_t stream) {
  auto grid =
    dim3(topQueryToCentroid.getSize(0), topQueryToCentroid.getSize(1));
  auto block = dim3(std::min(queries.getSize(1), getMaxThreadsCurrentDevice()));

  residualVector<<<grid, block, 0, stream>>>(
    queries, coarseCentroids, topQueryToCentroid, pqCentroids.getSize(1),
    residual);

  CUDA_TEST_ERROR();
}

template <typename CentroidT>
void
runPQCodeDistancesMM(Tensor<float, 3, true>& pqCentroids,
                     Tensor<float, 2, true>& queries,
                     Tensor<CentroidT, 2, true>& coarseCentroids,
                     Tensor<int, 2, true>& topQueryToCentroid,
                     NoTypeTensor<4, true>& outCodeDistances,
                     bool useFloat16Lookup,
                     DeviceMemory& mem,
                     cublasHandle_t handle,
                     cudaStream_t stream) {
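  // Expository note: this GEMM-based variant expands the L2 distance
  // algebraically. For residual r = q - c and PQ centroid p,
  //
  //   ||r - p||^2 = ||r||^2 - 2 (r . p) + ||p||^2
  //
  // so below we compute the residuals and their squared norms, form the
  // -2 (r . p) terms with a batched matrix multiply, then broadcast-add
  // ||r||^2 along rows and ||p||^2 along columns of the output.
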
  // Calculate (q - c) residual vector
  // (sub q)(query id)(centroid id)(sub dim)
  DeviceTensor<float, 4, true> residual(
    mem,
    {pqCentroids.getSize(0),
     topQueryToCentroid.getSize(0),
     topQueryToCentroid.getSize(1),
     pqCentroids.getSize(1)},
    stream);

  runResidualVector(pqCentroids, queries,
                    coarseCentroids, topQueryToCentroid,
                    residual, stream);

  // Calculate ||q - c||^2
  DeviceTensor<float, 1, true> residualNorms(
    mem,
    {pqCentroids.getSize(0) *
     topQueryToCentroid.getSize(0) *
     topQueryToCentroid.getSize(1)},
    stream);

  auto residualView2 = residual.view<2>(
    {pqCentroids.getSize(0) *
     topQueryToCentroid.getSize(0) *
     topQueryToCentroid.getSize(1),
     pqCentroids.getSize(1)});

  runL2Norm(residualView2, true, residualNorms, true, stream);

  // Perform a batch MM:
  // (sub q) x {(q * c)(sub dim) x (sub dim)(code)} =>
  // (sub q) x {(q * c)(code)}
  auto residualView3 = residual.view<3>(
    {pqCentroids.getSize(0),
     topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     pqCentroids.getSize(1)});

  DeviceTensor<float, 3, true> residualDistance(
    mem,
    {pqCentroids.getSize(0),
     topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     pqCentroids.getSize(2)},
    stream);

  runIteratedMatrixMult(residualDistance, false,
                        residualView3, false,
                        pqCentroids, false,
                        -2.0f, 0.0f,
                        handle,
                        stream);

  // Sum ||q - c||^2 along rows
  auto residualDistanceView2 = residualDistance.view<2>(
    {pqCentroids.getSize(0) *
     topQueryToCentroid.getSize(0) *
     topQueryToCentroid.getSize(1),
     pqCentroids.getSize(2)});

  runSumAlongRows(residualNorms, residualDistanceView2, false, stream);

  Tensor<float, 4, true> outCodeDistancesF;
  DeviceTensor<float, 4, true> outCodeDistancesFloatMem;

  if (useFloat16Lookup) {
    // With float16 lookup tables, accumulate in float32 scratch first
    outCodeDistancesFloatMem = DeviceTensor<float, 4, true>(
      mem, {outCodeDistances.getSize(0),
            outCodeDistances.getSize(1),
            outCodeDistances.getSize(2),
            outCodeDistances.getSize(3)},
      stream);

    outCodeDistancesF = outCodeDistancesFloatMem;
  } else {
    outCodeDistancesF = outCodeDistances.toTensor<float>();
  }

  // Transpose -2(sub q)(q * c)(code) to -2(q * c)(sub q)(code) (which
  // is where we build our output distances)
  auto outCodeDistancesView = outCodeDistancesF.view<3>(
    {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     outCodeDistances.getSize(2),
     outCodeDistances.getSize(3)});

  runTransposeAny(residualDistance, 0, 1, outCodeDistancesView, stream);

  // Calculate code norms per each sub-dim
  // (sub q)(sub dim)(code) is pqCentroids
  // transpose to (sub q)(code)(sub dim)
  DeviceTensor<float, 3, true> pqCentroidsTranspose(
    mem,
    {pqCentroids.getSize(0), pqCentroids.getSize(2), pqCentroids.getSize(1)},
    stream);

  runTransposeAny(pqCentroids, 1, 2, pqCentroidsTranspose, stream);

  auto pqCentroidsTransposeView = pqCentroidsTranspose.view<2>(
    {pqCentroids.getSize(0) * pqCentroids.getSize(2),
     pqCentroids.getSize(1)});

  DeviceTensor<float, 1, true> pqCentroidsNorm(
    mem,
    {pqCentroids.getSize(0) * pqCentroids.getSize(2)},
    stream);

  runL2Norm(pqCentroidsTransposeView, true, pqCentroidsNorm, true, stream);

  // View output as (q * c)(sub q * code), and add centroid norm to
  // each row
  auto outDistancesCodeViewCols = outCodeDistancesView.view<2>(
    {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     outCodeDistances.getSize(2) * outCodeDistances.getSize(3)});

  runSumAlongColumns(pqCentroidsNorm, outDistancesCodeViewCols, stream);

  if (useFloat16Lookup) {
    // Need to convert back
    auto outCodeDistancesH = outCodeDistances.toTensor<half>();
    convertTensor<float, half, 4>(stream,
                                  outCodeDistancesF,
                                  outCodeDistancesH);
  }
}

template <typename CentroidT>
void
runPQCodeDistances(Tensor<float, 3, true>& pqCentroids,
                   Tensor<float, 2, true>& queries,
                   Tensor<CentroidT, 2, true>& coarseCentroids,
                   Tensor<int, 2, true>& topQueryToCentroid,
                   NoTypeTensor<4, true>& outCodeDistances,
                   bool l2Distance,
                   bool useFloat16Lookup,
                   cudaStream_t stream) {
  const auto numSubQuantizers = pqCentroids.getSize(0);
  const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
  const auto codesPerSubQuantizer = pqCentroids.getSize(2);

  // FIXME: tune
  // Reuse of pq centroid data depends on # of queries * nprobe, and we
  // should really be tiling in both dimensions
  constexpr int kQueriesPerBlock = 8;

  auto grid = dim3(utils::divUp(queries.getSize(0), kQueriesPerBlock),
                   numSubQuantizers);

  // Reserve one block of threads for double buffering
  // FIXME: probably impractical for large # of dims?
  auto loadingThreads = utils::roundUp(dimsPerSubQuantizer, kWarpSize);
  auto block = dim3(codesPerSubQuantizer + loadingThreads);

  auto smem = (3 * dimsPerSubQuantizer) * sizeof(float)
    + topQueryToCentroid.getSize(1) * sizeof(int);
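
  // Expository example (typical 8-bit PQ, codesPerSubQuantizer = 256):
  // dimsPerSubQuantizer = 32 and nprobe = 64 give
  // block = 256 + 32 = 288 threads (matching __launch_bounds__(288, 4))
  // and smem = 3 * 32 * 4 + 64 * 4 = 640 bytes per block.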

#define RUN_CODE(DIMS, L2)                                                   \
  do {                                                                       \
    if (useFloat16Lookup) {                                                  \
      auto outCodeDistancesT = outCodeDistances.toTensor<half>();            \
                                                                             \
      pqCodeDistances<half, CentroidT, DIMS, L2><<<grid, block, smem, stream>>>( \
        queries, kQueriesPerBlock,                                           \
        coarseCentroids, pqCentroids,                                        \
        topQueryToCentroid, outCodeDistancesT);                              \
    } else {                                                                 \
      auto outCodeDistancesT = outCodeDistances.toTensor<float>();           \
                                                                             \
      pqCodeDistances<float, CentroidT, DIMS, L2><<<grid, block, smem, stream>>>( \
        queries, kQueriesPerBlock,                                           \
        coarseCentroids, pqCentroids,                                        \
        topQueryToCentroid, outCodeDistancesT);                              \
    }                                                                        \
  } while (0)

#define CODE_L2(DIMS)                                                       \
  do {                                                                      \
    if (l2Distance) {                                                       \
      RUN_CODE(DIMS, true);                                                 \
    } else {                                                                \
      RUN_CODE(DIMS, false);                                                \
    }                                                                       \
  } while (0)

  switch (dimsPerSubQuantizer) {
    case 1:
      CODE_L2(1);
      break;
    case 2:
      CODE_L2(2);
      break;
    case 3:
      CODE_L2(3);
      break;
    case 4:
      CODE_L2(4);
      break;
    case 6:
      CODE_L2(6);
      break;
    case 8:
      CODE_L2(8);
      break;
    case 10:
      CODE_L2(10);
      break;
    case 12:
      CODE_L2(12);
      break;
    case 16:
      CODE_L2(16);
      break;
    case 20:
      CODE_L2(20);
      break;
    case 24:
      CODE_L2(24);
      break;
    case 28:
      CODE_L2(28);
      break;
    case 32:
      CODE_L2(32);
      break;
    // FIXME: larger sizes require too many registers - we need the
    // MM implementation working
    default:
      FAISS_THROW_MSG("Too many dimensions (>32) per subquantizer "
                      "not currently supported");
  }
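
  // Expository note: each compute thread keeps its code's centroid slice in
  // DimsPerSubQuantizer registers, so larger dims would exhaust the register
  // budget; the GEMM-based runPQCodeDistancesMM above appears to be the
  // fallback the FIXME refers to (it implements the L2 expansion).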

#undef RUN_CODE
#undef CODE_L2

  CUDA_TEST_ERROR();
}

} } // namespace