faiss/gpu/impl/Distance.cu

499 lines
17 KiB
Plaintext

/**
* Copyright (c) 2015-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD+Patents license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "Distance.cuh"
#include "BroadcastSum.cuh"
#include "L2Norm.cuh"
#include "L2Select.cuh"
#include "../../FaissAssert.h"
#include "../GpuResources.h"
#include "../utils/DeviceUtils.h"
#include "../utils/Limits.cuh"
#include "../utils/MatrixMult.cuh"
#include "../utils/BlockSelectKernel.cuh"
#include <memory>
#include <algorithm>
#include <thrust/fill.h>
#include <thrust/for_each.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
namespace faiss { namespace gpu {
namespace {
// Returns the view of the centroid set covering `num` centroids starting at
// `startCentroid`.
//
// If `centroidsTransposed` is non-null the view is taken from it (narrowing
// along dimension 1, since that tensor is laid out as (dim, num)); otherwise
// it is taken from `centroids` (narrowing along dimension 0). When the
// requested range is the entire set, the tensor is returned as-is without
// narrowing.
template <typename T>
Tensor<T, 2, true> sliceCentroids(Tensor<T, 2, true>& centroids,
                                  Tensor<T, 2, true>* centroidsTransposed,
                                  int startCentroid,
                                  int num) {
  const bool coversAll =
    startCentroid == 0 && num == centroids.getSize(0);

  if (!centroidsTransposed) {
    if (coversAll) {
      return centroids;
    }

    return centroids.narrow(0, startCentroid, num);
  }

  if (coversAll) {
    return *centroidsTransposed;
  }

  // (dim, num)
  return centroidsTransposed->narrow(1, startCentroid, num);
}
// For each chunk of k indices, increment the index by chunk * increment.
//
// Launch layout as used by the kernel body: blockIdx.y selects the row of
// `indices`, blockIdx.x selects the chunk of k columns within that row, and
// the threads of the block stride over the k entries of the chunk.
template <typename T>
__global__ void incrementIndex(Tensor<T, 2, true> indices,
                               int k,
                               int increment) {
  // Both products are invariant across the loop; `auto` keeps them in the
  // same (unsigned) arithmetic type the inline expressions would have
  const auto chunkStart = blockIdx.x * k;
  const auto chunkDelta = blockIdx.x * increment;

  int col = threadIdx.x;
  while (col < k) {
    indices[blockIdx.y][chunkStart + col] += chunkDelta;
    col += blockDim.x;
  }
}
// Used to update result indices in distance computation where the number of
// centroids is high, and is tiled.
//
// `indices` holds, per row, a sequence of chunks of k tile-relative indices;
// chunk c has c * increment added to every entry (see incrementIndex).
//
// indices:   tensor updated in place; its dimension 1 must be an exact
//            multiple of k
// k:         chunk width; must be positive
// increment: value scaled by the chunk number and added to each index
// stream:    stream the kernel is enqueued on; the call does not wait for
//            the kernel to finish
template <typename T>
void runIncrementIndex(Tensor<T, 2, true>& indices,
                       int k,
                       int increment,
                       cudaStream_t stream) {
  // k is used as a divisor below
  FAISS_ASSERT(k > 0);

  // One block per (chunk, row) pair
  dim3 grid(indices.getSize(1) / k, indices.getSize(0));
  int block = std::min(k, 512);

  // should be exact
  FAISS_ASSERT(grid.x * k == indices.getSize(1));

  incrementIndex<<<grid, block, 0, stream>>>(indices, k, increment);

  // Report launch failures without blocking the host. A device-wide
  // synchronization here would stall every stream on each call; work that
  // consumes `indices` is ordered after this kernel by being enqueued on the
  // same stream.
  FAISS_ASSERT(cudaGetLastError() == cudaSuccess);
}
// Chooses the tile dimensions used when tiling the (query x centroid)
// distance matrix.
//
// The matrix multiplication should be large enough to be efficient, but if it
// is too large we seem to lose efficiency as opposed to double-streaming, so
// the budget below is halved to account for the two buffers in flight.
//
// Memory budget, selected from the total global memory of the current device:
//   <= 4 GB : 512 MB
//   <= 8 GB : 768 MB
//   larger  : 1 GB
//
// numQueries:       upper bound for tileRows
// numCentroids:     upper bound for tileCols
// dim:              vector dimension; when <= 32 the preferred number of
//                   query rows is 1024 instead of 512
// elementSize:      size in bytes of one matrix element
// tempMemAvailable: accepted but not consulted; temporary memory is sized
//                   independently by the user
// tileRows (out):   min(preferred rows, numQueries)
// tileCols (out):   min(budget in elements / preferred rows, numCentroids);
//                   note the divisor is the preferred row count, not tileRows
void chooseTileSize(int numQueries,
                    int numCentroids,
                    int dim,
                    int elementSize,
                    size_t tempMemAvailable,
                    int& tileRows,
                    int& tileCols) {
  const size_t gib = ((size_t) 1024) * 1024 * 1024;
  const int mib = 1024 * 1024;

  const auto deviceMem = getCurrentDeviceProperties().totalGlobalMem;

  // Start from the large-device budget and shrink it for smaller devices
  int elemBudget = 1024 * mib;
  if (deviceMem <= 4 * gib) {
    elemBudget = 512 * mib;
  } else if (deviceMem <= 8 * gib) {
    elemBudget = 768 * mib;
  }

  // Bytes -> elements, split across the two double-streamed buffers
  elemBudget /= 2 * elementSize;

  // 512 rows by default; a small inner (vector) dimension gets 1024
  const int wantedRows = (dim <= 32) ? 1024 : 512;

  tileRows = std::min(wantedRows, numQueries);

  // tileCols is the remainder size
  tileCols = std::min(elemBudget / wantedRows, numCentroids);
}
}
// Computes, for every row of `queries`, the k best matches among the rows of
// `centroids`, tiling the work over both queries and centroids.
//
// computeL2:           true selects the L2 path (GEMM scaled by -2, fused
//                      with ||c||^2 in runL2SelectMin, then ||q||^2 added);
//                      false selects the inner-product path (plain GEMM
//                      followed by runBlockSelect)
// resources:           provides the temporary memory manager, the default
//                      and alternate streams and the BLAS handle of the
//                      current device
// centroids:           centroid set, one centroid per index of dimension 0
// centroidsTransposed: optional; when non-null it is sliced and fed to the
//                      GEMM (without the transpose flag) in place of
//                      `centroids`
// centroidNorms:       optional ||c||^2; when null on the L2 path it is
//                      computed here into a temporary
// queries:             (query id, dim)
// k:                   results per query; asserted <= number of centroids
//                      and <= 1024
// outDistances:        asserted to be (number of queries, k)
// outIndices:          asserted to be (number of queries, k)
// useHgemm:            forwarded to runMatrixMult
// ignoreOutDistances:  L2 path only; when true the ||q||^2 term is not added
//                      to the selected distances
template <typename T>
void runDistance(bool computeL2,
GpuResources* resources,
Tensor<T, 2, true>& centroids,
Tensor<T, 2, true>* centroidsTransposed,
Tensor<T, 1, true>* centroidNorms,
Tensor<T, 2, true>& queries,
int k,
Tensor<T, 2, true>& outDistances,
Tensor<int, 2, true>& outIndices,
bool useHgemm,
bool ignoreOutDistances) {
// Output tensors must both be (number of queries, k)
FAISS_ASSERT(outDistances.getSize(0) == queries.getSize(0));
FAISS_ASSERT(outIndices.getSize(0) == queries.getSize(0));
FAISS_ASSERT(outDistances.getSize(1) == k);
FAISS_ASSERT(outIndices.getSize(1) == k);
auto& mem = resources->getMemoryManagerCurrentDevice();
auto defaultStream = resources->getDefaultStreamCurrentDevice();
// If we're querying against a 0 sized set, just return empty results:
// every distance is Limits<T>::getMax() and every index is -1
if (centroids.numElements() == 0) {
thrust::fill(thrust::cuda::par.on(defaultStream),
outDistances.data(), outDistances.end(),
Limits<T>::getMax());
thrust::fill(thrust::cuda::par.on(defaultStream),
outIndices.data(), outIndices.end(),
-1);
return;
}
// L2: If ||c||^2 is not pre-computed, calculate it
// cNorms lives for the rest of the function so that centroidNorms can point
// at it.
// NOTE(review): std::move on a temporary is redundant here.
DeviceTensor<T, 1, true> cNorms;
if (computeL2 && !centroidNorms) {
cNorms = std::move(DeviceTensor<T, 1, true>(
mem,
{centroids.getSize(0)}, defaultStream));
runL2Norm(centroids, cNorms, true, defaultStream);
centroidNorms = &cNorms;
}
//
// Prepare norm vector ||q||^2; ||c||^2 is already pre-computed
//
// NOTE(review): queryNorms is allocated even when computeL2 is false, in
// which case it is never filled or read.
int qNormSize[1] = {queries.getSize(0)};
DeviceTensor<T, 1, true> queryNorms(mem, qNormSize, defaultStream);
// ||q||^2
if (computeL2) {
runL2Norm(queries, queryNorms, true, defaultStream);
}
// Pick the query (row) and centroid (column) tile sizes. chooseTileSize
// derives its budget from the device's total memory; the available temporary
// memory passed here is not consulted by it.
int tileRows = 0;
int tileCols = 0;
chooseTileSize(queries.getSize(0),
centroids.getSize(0),
queries.getSize(1),
sizeof(T),
mem.getSizeAvailable(),
tileRows,
tileCols);
// Number of centroid tiles needed to cover the whole centroid set
int numColTiles = utils::divUp(centroids.getSize(0), tileCols);
FAISS_ASSERT(k <= centroids.getSize(0));
FAISS_ASSERT(k <= 1024); // select limitation
// Temporary output memory space we'll use
// Everything below is allocated twice, one copy per alternating stream
// (indexed by curStream).
// Raw GEMM output for one (query tile, centroid tile) pair
DeviceTensor<T, 2, true> distanceBuf1(
mem, {tileRows, tileCols}, defaultStream);
DeviceTensor<T, 2, true> distanceBuf2(
mem, {tileRows, tileCols}, defaultStream);
DeviceTensor<T, 2, true>* distanceBufs[2] =
{&distanceBuf1, &distanceBuf2};
// Per-centroid-tile top-k distances: k columns for each centroid tile
DeviceTensor<T, 2, true> outDistanceBuf1(
mem, {tileRows, numColTiles * k}, defaultStream);
DeviceTensor<T, 2, true> outDistanceBuf2(
mem, {tileRows, numColTiles * k}, defaultStream);
DeviceTensor<T, 2, true>* outDistanceBufs[2] =
{&outDistanceBuf1, &outDistanceBuf2};
// Per-centroid-tile top-k indices, same layout as the distances above
DeviceTensor<int, 2, true> outIndexBuf1(
mem, {tileRows, numColTiles * k}, defaultStream);
DeviceTensor<int, 2, true> outIndexBuf2(
mem, {tileRows, numColTiles * k}, defaultStream);
DeviceTensor<int, 2, true>* outIndexBufs[2] =
{&outIndexBuf1, &outIndexBuf2};
// NOTE(review): presumably makes the alternate streams wait for the work
// queued so far on the default stream (the norm computations) — confirm
// against streamWait
auto streams = resources->getAlternateStreamsCurrentDevice();
streamWait(streams, {defaultStream});
int curStream = 0;
// Tile over the input queries
for (int i = 0; i < queries.getSize(0); i += tileRows) {
// The last query tile may be smaller than tileRows
int curQuerySize = std::min(tileRows, queries.getSize(0) - i);
auto outDistanceView =
outDistances.narrow(0, i, curQuerySize);
auto outIndexView =
outIndices.narrow(0, i, curQuerySize);
auto queryView =
queries.narrow(0, i, curQuerySize);
auto queryNormNiew =
queryNorms.narrow(0, i, curQuerySize);
auto outDistanceBufRowView =
outDistanceBufs[curStream]->narrow(0, 0, curQuerySize);
auto outIndexBufRowView =
outIndexBufs[curStream]->narrow(0, 0, curQuerySize);
// Tile over the centroids
for (int j = 0; j < centroids.getSize(0); j += tileCols) {
// The last centroid tile may be smaller than tileCols
int curCentroidSize = std::min(tileCols, centroids.getSize(0) - j);
int curColTile = j / tileCols;
auto centroidsView =
sliceCentroids(centroids, centroidsTransposed, j, curCentroidSize);
auto distanceBufView = distanceBufs[curStream]->
narrow(0, 0, curQuerySize).narrow(1, 0, curCentroidSize);
// The k output columns reserved for this centroid tile
auto outDistanceBufColView =
outDistanceBufRowView.narrow(1, k * curColTile, k);
auto outIndexBufColView =
outIndexBufRowView.narrow(1, k * curColTile, k);
// L2: distance is ||c||^2 - 2qc + ||q||^2, we compute -2qc
// IP: just compute qc
// (query id x dim) x (centroid id, dim)' = (query id, centroid id)
// The transpose flag is only set when the non-transposed centroids are used
runMatrixMult(distanceBufView, false,
queryView, false,
centroidsView,
centroidsTransposed ? false : true,
computeL2 ? -2.0f : 1.0f, 0.0f, useHgemm,
resources->getBlasHandleCurrentDevice(),
streams[curStream]);
if (computeL2) {
// For L2 distance, we use this fused kernel that performs both
// adding ||c||^2 to -2qc and k-selection, so we only need two
// passes (one write by the gemm, one read here) over the huge
// region of output memory
//
// If we aren't tiling along the number of centroids, we can perform the
// output work directly
if (tileCols == centroids.getSize(0)) {
// Write into the final output
runL2SelectMin(distanceBufView,
*centroidNorms,
outDistanceView,
outIndexView,
k,
streams[curStream]);
if (!ignoreOutDistances) {
// expand (query id) to (query id, k) by duplicating along rows
// top-k ||c||^2 - 2qc + ||q||^2 in the form (query id, k)
runSumAlongRows(queryNormNiew,
outDistanceView,
true, // L2 distances should not go below zero due
// to roundoff error
streams[curStream]);
}
} else {
// Only the norms of the centroids in this tile are needed
auto centroidNormsView =
centroidNorms->narrow(0, j, curCentroidSize);
// Write into our intermediate output
runL2SelectMin(distanceBufView,
centroidNormsView,
outDistanceBufColView,
outIndexBufColView,
k,
streams[curStream]);
if (!ignoreOutDistances) {
// expand (query id) to (query id, k) by duplicating along rows
// top-k ||c||^2 - 2qc + ||q||^2 in the form (query id, k)
runSumAlongRows(queryNormNiew,
outDistanceBufColView,
true, // L2 distances should not go below zero due
// to roundoff error
streams[curStream]);
}
}
} else {
// For IP, just k-select the output for this tile
// NOTE(review): the `true` flag presumably requests the largest values —
// confirm against runBlockSelect
if (tileCols == centroids.getSize(0)) {
// Write into the final output
runBlockSelect(distanceBufView,
outDistanceView,
outIndexView,
true, k, streams[curStream]);
} else {
// Write into the intermediate output
runBlockSelect(distanceBufView,
outDistanceBufColView,
outIndexBufColView,
true, k, streams[curStream]);
}
}
}
// As we're finished with processing a full set of centroids, perform the
// final k-selection
if (tileCols != centroids.getSize(0)) {
// The indices are tile-relative; the group of k results produced by
// centroid tile t gets t * tileCols added to each index
runIncrementIndex(outIndexBufRowView, k, tileCols, streams[curStream]);
// Reduce the numColTiles * k candidates per query to the final k.
// NOTE(review): the direction flag is false for L2 and true for IP,
// presumably smallest-first vs largest-first — confirm against
// runBlockSelectPair
runBlockSelectPair(outDistanceBufRowView,
outIndexBufRowView,
outDistanceView,
outIndexView,
computeL2 ? false : true, k, streams[curStream]);
}
// Alternate between the two streams (and their buffer sets) for the next
// query tile
curStream = (curStream + 1) % 2;
}
// Have the desired ordering stream wait on the multi-stream
streamWait({defaultStream}, streams);
}
// L2 flavour of runDistance: every argument is forwarded unchanged with the
// L2 path selected. `ignoreOutDistances` defaults to false.
template <typename T>
void runL2Distance(GpuResources* resources,
                   Tensor<T, 2, true>& centroids,
                   Tensor<T, 2, true>* centroidsTransposed,
                   Tensor<T, 1, true>* centroidNorms,
                   Tensor<T, 2, true>& queries,
                   int k,
                   Tensor<T, 2, true>& outDistances,
                   Tensor<int, 2, true>& outIndices,
                   bool useHgemm,
                   bool ignoreOutDistances = false) {
  const bool computeL2 = true;

  runDistance<T>(computeL2, resources,
                 centroids, centroidsTransposed, centroidNorms,
                 queries, k,
                 outDistances, outIndices,
                 useHgemm, ignoreOutDistances);
}
// Inner-product flavour of runDistance: selects the non-L2 path, supplies no
// centroid norms and never asks for the output distances to be ignored.
template <typename T>
void runIPDistance(GpuResources* resources,
                   Tensor<T, 2, true>& centroids,
                   Tensor<T, 2, true>* centroidsTransposed,
                   Tensor<T, 2, true>& queries,
                   int k,
                   Tensor<T, 2, true>& outDistances,
                   Tensor<int, 2, true>& outIndices,
                   bool useHgemm) {
  const bool computeL2 = false;
  const bool ignoreOutDistances = false;

  runDistance<T>(computeL2, resources,
                 centroids, centroidsTransposed,
                 nullptr, // no centroid norms on the IP path
                 queries, k,
                 outDistances, outIndices,
                 useHgemm, ignoreOutDistances);
}
//
// Concrete (non-template) overloads that forward to the distance templates
//
// float32 inner-product entry point; forwards to runIPDistance<float> with
// useHgemm fixed to false.
void
runIPDistance(GpuResources* resources,
              Tensor<float, 2, true>& vectors,
              Tensor<float, 2, true>* vectorsTransposed,
              Tensor<float, 2, true>& queries,
              int k,
              Tensor<float, 2, true>& outDistances,
              Tensor<int, 2, true>& outIndices) {
  const bool useHgemm = false;

  runIPDistance<float>(resources,
                       vectors, vectorsTransposed,
                       queries, k,
                       outDistances, outIndices,
                       useHgemm);
}
#ifdef FAISS_USE_FLOAT16
// float16 inner-product entry point (only built with FAISS_USE_FLOAT16);
// forwards all arguments, including the caller's useHgemm choice, to
// runIPDistance<half>.
void
runIPDistance(GpuResources* resources,
              Tensor<half, 2, true>& vectors,
              Tensor<half, 2, true>* vectorsTransposed,
              Tensor<half, 2, true>& queries,
              int k,
              Tensor<half, 2, true>& outDistances,
              Tensor<int, 2, true>& outIndices,
              bool useHgemm) {
  runIPDistance<half>(resources,
                      vectors, vectorsTransposed,
                      queries, k,
                      outDistances, outIndices,
                      useHgemm);
}
#endif
// float32 L2 entry point; forwards to runL2Distance<float> with useHgemm
// fixed to false.
void
runL2Distance(GpuResources* resources,
              Tensor<float, 2, true>& vectors,
              Tensor<float, 2, true>* vectorsTransposed,
              Tensor<float, 1, true>* vectorNorms,
              Tensor<float, 2, true>& queries,
              int k,
              Tensor<float, 2, true>& outDistances,
              Tensor<int, 2, true>& outIndices,
              bool ignoreOutDistances) {
  const bool useHgemm = false;

  runL2Distance<float>(resources,
                       vectors, vectorsTransposed, vectorNorms,
                       queries, k,
                       outDistances, outIndices,
                       useHgemm, ignoreOutDistances);
}
#ifdef FAISS_USE_FLOAT16
// float16 L2 entry point (only built with FAISS_USE_FLOAT16); forwards all
// arguments, including the caller's useHgemm choice, to runL2Distance<half>.
void
runL2Distance(GpuResources* resources,
              Tensor<half, 2, true>& vectors,
              Tensor<half, 2, true>* vectorsTransposed,
              Tensor<half, 1, true>* vectorNorms,
              Tensor<half, 2, true>& queries,
              int k,
              Tensor<half, 2, true>& outDistances,
              Tensor<int, 2, true>& outIndices,
              bool useHgemm,
              bool ignoreOutDistances) {
  runL2Distance<half>(resources,
                      vectors, vectorsTransposed, vectorNorms,
                      queries, k,
                      outDistances, outIndices,
                      useHgemm, ignoreOutDistances);
}
#endif
} } // namespace