/** * Copyright (c) 2015-present, Facebook, Inc. * All rights reserved. * * This source code is licensed under the BSD+Patents license found in the * LICENSE file in the root directory of this source tree. */ #include "Distance.cuh" #include "BroadcastSum.cuh" #include "L2Norm.cuh" #include "L2Select.cuh" #include "../../FaissAssert.h" #include "../GpuResources.h" #include "../utils/DeviceUtils.h" #include "../utils/Limits.cuh" #include "../utils/MatrixMult.cuh" #include "../utils/BlockSelectKernel.cuh" #include #include #include #include #include #include namespace faiss { namespace gpu { namespace { template Tensor sliceCentroids(Tensor& centroids, Tensor* centroidsTransposed, int startCentroid, int num) { if (startCentroid == 0 && num == centroids.getSize(0)) { if (centroidsTransposed) { return *centroidsTransposed; } else { return centroids; } } if (centroidsTransposed) { // (dim, num) return centroidsTransposed->narrow(1, startCentroid, num); } else { return centroids.narrow(0, startCentroid, num); } } // For each chunk of k indices, increment the index by chunk * increment template __global__ void incrementIndex(Tensor indices, int k, int increment) { for (int i = threadIdx.x; i < k; i += blockDim.x) { indices[blockIdx.y][blockIdx.x * k + i] += blockIdx.x * increment; } } // Used to update result indices in distance computation where the number of // centroids is high, and is tiled template void runIncrementIndex(Tensor& indices, int k, int increment, cudaStream_t stream) { dim3 grid(indices.getSize(1) / k, indices.getSize(0)); int block = std::min(k, 512); // should be exact FAISS_ASSERT(grid.x * k == indices.getSize(1)); incrementIndex<<>>(indices, k, increment); cudaDeviceSynchronize(); } // If the inner size (dim) of the vectors is small, we want a larger query tile // size, like 1024 void chooseTileSize(int numQueries, int numCentroids, int dim, int elementSize, size_t tempMemAvailable, int& tileRows, int& tileCols) { // The matrix multiplication should be large enough to be efficient, but if it // is too large, we seem to lose efficiency as opposed to double-streaming. // Each tile size here defines 1/2 of the memory use due to double streaming. // We ignore available temporary memory, as that is adjusted independently by // the user and can thus meet these requirements (or not). // For <= 4 GB GPUs, prefer 512 MB of usage. // For <= 8 GB GPUs, prefer 768 MB of usage. // Otherwise, prefer 1 GB of usage. auto totalMem = getCurrentDeviceProperties().totalGlobalMem; int targetUsage = 0; if (totalMem <= ((size_t) 4) * 1024 * 1024 * 1024) { targetUsage = 512 * 1024 * 1024; } else if (totalMem <= ((size_t) 8) * 1024 * 1024 * 1024) { targetUsage = 768 * 1024 * 1024; } else { targetUsage = 1024 * 1024 * 1024; } targetUsage /= 2 * elementSize; // 512 seems to be a batch size sweetspot for float32. // If we are on float16, increase to 512. // If the k size (vec dim) of the matrix multiplication is small (<= 32), // increase to 1024. int preferredTileRows = 512; if (dim <= 32) { preferredTileRows = 1024; } tileRows = std::min(preferredTileRows, numQueries); // tileCols is the remainder size tileCols = std::min(targetUsage / preferredTileRows, numCentroids); } } template void runDistance(bool computeL2, GpuResources* resources, Tensor& centroids, Tensor* centroidsTransposed, Tensor* centroidNorms, Tensor& queries, int k, Tensor& outDistances, Tensor& outIndices, bool useHgemm, bool ignoreOutDistances) { FAISS_ASSERT(outDistances.getSize(0) == queries.getSize(0)); FAISS_ASSERT(outIndices.getSize(0) == queries.getSize(0)); FAISS_ASSERT(outDistances.getSize(1) == k); FAISS_ASSERT(outIndices.getSize(1) == k); auto& mem = resources->getMemoryManagerCurrentDevice(); auto defaultStream = resources->getDefaultStreamCurrentDevice(); // If we're quering against a 0 sized set, just return empty results if (centroids.numElements() == 0) { thrust::fill(thrust::cuda::par.on(defaultStream), outDistances.data(), outDistances.end(), Limits::getMax()); thrust::fill(thrust::cuda::par.on(defaultStream), outIndices.data(), outIndices.end(), -1); return; } // L2: If ||c||^2 is not pre-computed, calculate it DeviceTensor cNorms; if (computeL2 && !centroidNorms) { cNorms = std::move(DeviceTensor( mem, {centroids.getSize(0)}, defaultStream)); runL2Norm(centroids, cNorms, true, defaultStream); centroidNorms = &cNorms; } // // Prepare norm vector ||q||^2; ||c||^2 is already pre-computed // int qNormSize[1] = {queries.getSize(0)}; DeviceTensor queryNorms(mem, qNormSize, defaultStream); // ||q||^2 if (computeL2) { runL2Norm(queries, queryNorms, true, defaultStream); } // By default, aim to use up to 512 MB of memory for the processing, with both // number of queries and number of centroids being at least 512. int tileRows = 0; int tileCols = 0; chooseTileSize(queries.getSize(0), centroids.getSize(0), queries.getSize(1), sizeof(T), mem.getSizeAvailable(), tileRows, tileCols); int numColTiles = utils::divUp(centroids.getSize(0), tileCols); FAISS_ASSERT(k <= centroids.getSize(0)); FAISS_ASSERT(k <= 1024); // select limitation // Temporary output memory space we'll use DeviceTensor distanceBuf1( mem, {tileRows, tileCols}, defaultStream); DeviceTensor distanceBuf2( mem, {tileRows, tileCols}, defaultStream); DeviceTensor* distanceBufs[2] = {&distanceBuf1, &distanceBuf2}; DeviceTensor outDistanceBuf1( mem, {tileRows, numColTiles * k}, defaultStream); DeviceTensor outDistanceBuf2( mem, {tileRows, numColTiles * k}, defaultStream); DeviceTensor* outDistanceBufs[2] = {&outDistanceBuf1, &outDistanceBuf2}; DeviceTensor outIndexBuf1( mem, {tileRows, numColTiles * k}, defaultStream); DeviceTensor outIndexBuf2( mem, {tileRows, numColTiles * k}, defaultStream); DeviceTensor* outIndexBufs[2] = {&outIndexBuf1, &outIndexBuf2}; auto streams = resources->getAlternateStreamsCurrentDevice(); streamWait(streams, {defaultStream}); int curStream = 0; // Tile over the input queries for (int i = 0; i < queries.getSize(0); i += tileRows) { int curQuerySize = std::min(tileRows, queries.getSize(0) - i); auto outDistanceView = outDistances.narrow(0, i, curQuerySize); auto outIndexView = outIndices.narrow(0, i, curQuerySize); auto queryView = queries.narrow(0, i, curQuerySize); auto queryNormNiew = queryNorms.narrow(0, i, curQuerySize); auto outDistanceBufRowView = outDistanceBufs[curStream]->narrow(0, 0, curQuerySize); auto outIndexBufRowView = outIndexBufs[curStream]->narrow(0, 0, curQuerySize); // Tile over the centroids for (int j = 0; j < centroids.getSize(0); j += tileCols) { int curCentroidSize = std::min(tileCols, centroids.getSize(0) - j); int curColTile = j / tileCols; auto centroidsView = sliceCentroids(centroids, centroidsTransposed, j, curCentroidSize); auto distanceBufView = distanceBufs[curStream]-> narrow(0, 0, curQuerySize).narrow(1, 0, curCentroidSize); auto outDistanceBufColView = outDistanceBufRowView.narrow(1, k * curColTile, k); auto outIndexBufColView = outIndexBufRowView.narrow(1, k * curColTile, k); // L2: distance is ||c||^2 - 2qc + ||q||^2, we compute -2qc // IP: just compute qc // (query id x dim) x (centroid id, dim)' = (query id, centroid id) runMatrixMult(distanceBufView, false, queryView, false, centroidsView, centroidsTransposed ? false : true, computeL2 ? -2.0f : 1.0f, 0.0f, useHgemm, resources->getBlasHandleCurrentDevice(), streams[curStream]); if (computeL2) { // For L2 distance, we use this fused kernel that performs both // adding ||c||^2 to -2qc and k-selection, so we only need two // passes (one write by the gemm, one read here) over the huge // region of output memory // // If we aren't tiling along the number of centroids, we can perform the // output work directly if (tileCols == centroids.getSize(0)) { // Write into the final output runL2SelectMin(distanceBufView, *centroidNorms, outDistanceView, outIndexView, k, streams[curStream]); if (!ignoreOutDistances) { // expand (query id) to (query id, k) by duplicating along rows // top-k ||c||^2 - 2qc + ||q||^2 in the form (query id, k) runSumAlongRows(queryNormNiew, outDistanceView, true, // L2 distances should not go below zero due // to roundoff error streams[curStream]); } } else { auto centroidNormsView = centroidNorms->narrow(0, j, curCentroidSize); // Write into our intermediate output runL2SelectMin(distanceBufView, centroidNormsView, outDistanceBufColView, outIndexBufColView, k, streams[curStream]); if (!ignoreOutDistances) { // expand (query id) to (query id, k) by duplicating along rows // top-k ||c||^2 - 2qc + ||q||^2 in the form (query id, k) runSumAlongRows(queryNormNiew, outDistanceBufColView, true, // L2 distances should not go below zero due // to roundoff error streams[curStream]); } } } else { // For IP, just k-select the output for this tile if (tileCols == centroids.getSize(0)) { // Write into the final output runBlockSelect(distanceBufView, outDistanceView, outIndexView, true, k, streams[curStream]); } else { // Write into the intermediate output runBlockSelect(distanceBufView, outDistanceBufColView, outIndexBufColView, true, k, streams[curStream]); } } } // As we're finished with processing a full set of centroids, perform the // final k-selection if (tileCols != centroids.getSize(0)) { // The indices are tile-relative; for each tile of k, we need to add // tileCols to the index runIncrementIndex(outIndexBufRowView, k, tileCols, streams[curStream]); runBlockSelectPair(outDistanceBufRowView, outIndexBufRowView, outDistanceView, outIndexView, computeL2 ? false : true, k, streams[curStream]); } curStream = (curStream + 1) % 2; } // Have the desired ordering stream wait on the multi-stream streamWait({defaultStream}, streams); } template void runL2Distance(GpuResources* resources, Tensor& centroids, Tensor* centroidsTransposed, Tensor* centroidNorms, Tensor& queries, int k, Tensor& outDistances, Tensor& outIndices, bool useHgemm, bool ignoreOutDistances = false) { runDistance(true, // L2 resources, centroids, centroidsTransposed, centroidNorms, queries, k, outDistances, outIndices, useHgemm, ignoreOutDistances); } template void runIPDistance(GpuResources* resources, Tensor& centroids, Tensor* centroidsTransposed, Tensor& queries, int k, Tensor& outDistances, Tensor& outIndices, bool useHgemm) { runDistance(false, // IP resources, centroids, centroidsTransposed, nullptr, queries, k, outDistances, outIndices, useHgemm, false); } // // Instantiations of the distance templates // void runIPDistance(GpuResources* resources, Tensor& vectors, Tensor* vectorsTransposed, Tensor& queries, int k, Tensor& outDistances, Tensor& outIndices) { runIPDistance(resources, vectors, vectorsTransposed, queries, k, outDistances, outIndices, false); } #ifdef FAISS_USE_FLOAT16 void runIPDistance(GpuResources* resources, Tensor& vectors, Tensor* vectorsTransposed, Tensor& queries, int k, Tensor& outDistances, Tensor& outIndices, bool useHgemm) { runIPDistance(resources, vectors, vectorsTransposed, queries, k, outDistances, outIndices, useHgemm); } #endif void runL2Distance(GpuResources* resources, Tensor& vectors, Tensor* vectorsTransposed, Tensor* vectorNorms, Tensor& queries, int k, Tensor& outDistances, Tensor& outIndices, bool ignoreOutDistances) { runL2Distance(resources, vectors, vectorsTransposed, vectorNorms, queries, k, outDistances, outIndices, false, ignoreOutDistances); } #ifdef FAISS_USE_FLOAT16 void runL2Distance(GpuResources* resources, Tensor& vectors, Tensor* vectorsTransposed, Tensor* vectorNorms, Tensor& queries, int k, Tensor& outDistances, Tensor& outIndices, bool useHgemm, bool ignoreOutDistances) { runL2Distance(resources, vectors, vectorsTransposed, vectorNorms, queries, k, outDistances, outIndices, useHgemm, ignoreOutDistances); } #endif } } // namespace