faiss/gpu/impl/PQScanMultiPassPrecomputed.cu

/**
* Copyright (c) 2015-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the CC-by-NC license found in the
* LICENSE file in the root directory of this source tree.
*/
// Copyright 2004-present Facebook. All Rights Reserved.
#include "PQScanMultiPassPrecomputed.cuh"
#include "../GpuResources.h"
#include "PQCodeLoad.cuh"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceTensor.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/StaticUtils.h"
#include <limits>
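// This file implements the multi-pass IVFPQ scan using precomputed tables:
// a scan kernel first converts the codes of every (query, probe) pair into
// distances written to a temporary buffer, and two k-selection passes then
// reduce those distances to the final top-k results per query.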
namespace faiss { namespace gpu {
// For precomputed codes, this calculates and loads code distances
// into smem
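// Each entry is the elementwise sum smem[i] = term2Start[i] + term3Start[i]
// for i in [0, numCodes), computed cooperatively by the whole block; a
// vectorized path is used when numCodes is a multiple of the vector width.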
template <typename LookupT, typename LookupVecT>
inline __device__ void
loadPrecomputedTerm(LookupT* smem,
LookupT* term2Start,
LookupT* term3Start,
int numCodes) {
constexpr int kWordSize = sizeof(LookupVecT) / sizeof(LookupT);
// We can only use vector loads if the data is guaranteed to be aligned.
// The codes are innermost, so if numCodes is evenly divisible by the
// vector width, then any slice will be aligned.
if (numCodes % kWordSize == 0) {
constexpr int kUnroll = 2;
// Load the data by vector word (float4 / Half8) for efficiency, and then
// handle any remainder.
// limitVec is the number of whole vector words we can load with the entire
// block participating in each load
int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x);
limitVec *= kUnroll * blockDim.x;
LookupVecT* smemV = (LookupVecT*) smem;
LookupVecT* term2StartV = (LookupVecT*) term2Start;
LookupVecT* term3StartV = (LookupVecT*) term3Start;
for (int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) {
LookupVecT vals[kUnroll];
#pragma unroll
for (int j = 0; j < kUnroll; ++j) {
vals[j] =
LoadStore<LookupVecT>::load(&term2StartV[i + j * blockDim.x]);
}
#pragma unroll
for (int j = 0; j < kUnroll; ++j) {
LookupVecT q =
LoadStore<LookupVecT>::load(&term3StartV[i + j * blockDim.x]);
vals[j] = Math<LookupVecT>::add(vals[j], q);
}
#pragma unroll
for (int j = 0; j < kUnroll; ++j) {
LoadStore<LookupVecT>::store(&smemV[i + j * blockDim.x], vals[j]);
}
}
// This is where we start loading the remainder that does not evenly
// fit into kUnroll x blockDim.x
int remainder = limitVec * kWordSize;
for (int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) {
smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
}
} else {
// Potential unaligned load
constexpr int kUnroll = 4;
int limit = utils::roundDown(numCodes, kUnroll * blockDim.x);
int i = threadIdx.x;
for (; i < limit; i += kUnroll * blockDim.x) {
LookupT vals[kUnroll];
#pragma unroll
for (int j = 0; j < kUnroll; ++j) {
vals[j] = term2Start[i + j * blockDim.x];
}
#pragma unroll
for (int j = 0; j < kUnroll; ++j) {
vals[j] = Math<LookupT>::add(vals[j], term3Start[i + j * blockDim.x]);
}
#pragma unroll
for (int j = 0; j < kUnroll; ++j) {
smem[i + j * blockDim.x] = vals[j];
}
}
for (; i < numCodes; i += blockDim.x) {
smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
}
}
}
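// One thread block scans a single (query, probe) pair. Distances follow the
// standard IVFPQ precomputed-table decomposition: for a database vector with
// coarse centroid y_C and PQ-encoded residual y_R,
//   d(x, y)^2 = ||x - y_C||^2              (term 1: per (query, probe))
//             + ||y_R||^2 + 2 <y_C, y_R>   (term 2: per (list, sub-code))
//             - 2 <x, y_R>                 (term 3: per (query, sub-code))
// Terms 2 and 3 are summed into a shared-memory table up front; each code is
// then scored with one table lookup per sub-quantizer plus the scalar term 1.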
template <int NumSubQuantizers, typename LookupT, typename LookupVecT>
__global__ void
pqScanPrecomputedMultiPass(Tensor<float, 2, true> queries,
Tensor<float, 2, true> precompTerm1,
Tensor<LookupT, 3, true> precompTerm2,
Tensor<LookupT, 3, true> precompTerm3,
Tensor<int, 2, true> topQueryToCentroid,
void** listCodes,
int* listLengths,
Tensor<int, 2, true> prefixSumOffsets,
Tensor<float, 1, true> distance) {
// precomputed term 2 + 3 storage
// (sub q)(code id)
extern __shared__ char smemTerm23[];
LookupT* term23 = (LookupT*) smemTerm23;
// Each block handles a single query
auto queryId = blockIdx.y;
auto probeId = blockIdx.x;
auto codesPerSubQuantizer = precompTerm2.getSize(2);
auto precompTermSize = precompTerm2.getSize(1) * codesPerSubQuantizer;
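// Number of (term 2 + term 3) lookup entries to stage in shared memory:
// one per (sub-quantizer, code) pair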
// This is where we start writing out data
// We ensure that before the array (at offset -1), there is a 0 value
int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);
float* distanceOut = distance[outBase].data();
auto listId = topQueryToCentroid[queryId][probeId];
// Safety guard in case NaNs in input cause no list ID to be generated
if (listId == -1) {
return;
}
unsigned char* codeList = (unsigned char*) listCodes[listId];
int limit = listLengths[listId];
constexpr int kNumCode32 = NumSubQuantizers <= 4 ? 1 :
(NumSubQuantizers / 4);
unsigned int code32[kNumCode32];
unsigned int nextCode32[kNumCode32];
// We double-buffer the code loading, which improves memory utilization
if (threadIdx.x < limit) {
LoadCode32<NumSubQuantizers>::load(code32, codeList, threadIdx.x);
}
// Load precomputed terms 1, 2, 3
float term1 = precompTerm1[queryId][probeId];
loadPrecomputedTerm<LookupT, LookupVecT>(term23,
precompTerm2[listId].data(),
precompTerm3[queryId].data(),
precompTermSize);
// Prevent WAR dependencies
__syncthreads();
// Each thread handles one code element in the list, with a
// block-wide stride
for (int codeIndex = threadIdx.x;
codeIndex < limit;
codeIndex += blockDim.x) {
// Prefetch next codes
if (codeIndex + blockDim.x < limit) {
LoadCode32<NumSubQuantizers>::load(
nextCode32, codeList, codeIndex + blockDim.x);
}
float dist = term1;
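// Each byte of the packed code selects one entry from the corresponding
// sub-quantizer's table in shared memory; the selected (term 2 + term 3)
// values are accumulated onto term 1 to form the full distance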
#pragma unroll
for (int word = 0; word < kNumCode32; ++word) {
constexpr int kBytesPerCode32 =
NumSubQuantizers < 4 ? NumSubQuantizers : 4;
if (kBytesPerCode32 == 1) {
// Single sub-quantizer case: the one code byte selects the lone
// (term 2 + term 3) entry, which is accumulated onto term 1 just like
// the multi-byte path below
auto code = code32[0];
dist += ConvertTo<float>::to(term23[code]);
} else {
#pragma unroll
for (int byte = 0; byte < kBytesPerCode32; ++byte) {
auto code = getByte(code32[word], byte * 8, 8);
auto offset =
codesPerSubQuantizer * (word * kBytesPerCode32 + byte);
dist += ConvertTo<float>::to(term23[offset + code]);
}
}
}
// Write out intermediate distance result
// We do not maintain indices here, in order to reduce global
// memory traffic. Those are recovered in the final selection step.
distanceOut[codeIndex] = dist;
// Rotate buffers
#pragma unroll
for (int word = 0; word < kNumCode32; ++word) {
code32[word] = nextCode32[word];
}
}
}
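// Runs one tile of queries: computes per-(query, probe) output offsets,
// launches the scan kernel to fill allDistances, then performs the two-pass
// k-selection into outDistances / outIndices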
void
runMultiPassTile(Tensor<float, 2, true>& queries,
Tensor<float, 2, true>& precompTerm1,
NoTypeTensor<3, true>& precompTerm2,
NoTypeTensor<3, true>& precompTerm3,
Tensor<int, 2, true>& topQueryToCentroid,
bool useFloat16Lookup,
int bytesPerCode,
int numSubQuantizers,
int numSubQuantizerCodes,
thrust::device_vector<void*>& listCodes,
thrust::device_vector<void*>& listIndices,
IndicesOptions indicesOptions,
thrust::device_vector<int>& listLengths,
Tensor<char, 1, true>& thrustMem,
Tensor<int, 2, true>& prefixSumOffsets,
Tensor<float, 1, true>& allDistances,
Tensor<float, 3, true>& heapDistances,
Tensor<int, 3, true>& heapIndices,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<long, 2, true>& outIndices,
cudaStream_t stream) {
// Calculate offset lengths, so we know where to write out
// intermediate results
runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets,
thrustMem, stream);
// Convert all codes to a distance, and write out (distance,
// index) values for all intermediate results
{
auto kThreadsPerBlock = 256;
auto grid = dim3(topQueryToCentroid.getSize(1),
topQueryToCentroid.getSize(0));
auto block = dim3(kThreadsPerBlock);
// pq precomputed terms (2 + 3)
auto smem = sizeof(float);
#ifdef FAISS_USE_FLOAT16
if (useFloat16Lookup) {
smem = sizeof(half);
}
#endif
smem *= numSubQuantizers * numSubQuantizerCodes;
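// The entire (term 2 + term 3) table for one (query, probe) pair must fit
// in shared memory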
FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice());
#define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T) \
do { \
auto precompTerm2T = precompTerm2.toTensor<LOOKUP_T>(); \
auto precompTerm3T = precompTerm3.toTensor<LOOKUP_T>(); \
\
pqScanPrecomputedMultiPass<NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T> \
<<<grid, block, smem, stream>>>( \
queries, \
precompTerm1, \
precompTerm2T, \
precompTerm3T, \
topQueryToCentroid, \
listCodes.data().get(), \
listLengths.data().get(), \
prefixSumOffsets, \
allDistances); \
} while (0)
#ifdef FAISS_USE_FLOAT16
#define RUN_PQ(NUM_SUB_Q) \
do { \
if (useFloat16Lookup) { \
RUN_PQ_OPT(NUM_SUB_Q, half, Half8); \
} else { \
RUN_PQ_OPT(NUM_SUB_Q, float, float4); \
} \
} while (0)
#else
#define RUN_PQ(NUM_SUB_Q) \
do { \
RUN_PQ_OPT(NUM_SUB_Q, float, float4); \
} while (0)
#endif // FAISS_USE_FLOAT16
switch (bytesPerCode) {
case 1:
RUN_PQ(1);
break;
case 2:
RUN_PQ(2);
break;
case 3:
RUN_PQ(3);
break;
case 4:
RUN_PQ(4);
break;
case 8:
RUN_PQ(8);
break;
case 12:
RUN_PQ(12);
break;
case 16:
RUN_PQ(16);
break;
case 20:
RUN_PQ(20);
break;
case 24:
RUN_PQ(24);
break;
case 28:
RUN_PQ(28);
break;
case 32:
RUN_PQ(32);
break;
case 40:
RUN_PQ(40);
break;
case 48:
RUN_PQ(48);
break;
case 56:
RUN_PQ(56);
break;
case 64:
RUN_PQ(64);
break;
default:
FAISS_ASSERT(false);
break;
}
#undef RUN_PQ
#undef RUN_PQ_OPT
}
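// Two-pass k-selection: pass 1 reduces each query's nprobe candidate ranges
// to a fixed number of partial top-k heaps; pass 2 merges those heaps and
// translates flat offsets back into user-visible indices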
// k-select the output in chunks, to increase parallelism
runPass1SelectLists(prefixSumOffsets,
allDistances,
topQueryToCentroid.getSize(1),
k,
false, // L2 distance chooses smallest
heapDistances,
heapIndices,
stream);
// k-select final output
auto flatHeapDistances = heapDistances.downcastInner<2>();
auto flatHeapIndices = heapIndices.downcastInner<2>();
runPass2SelectLists(flatHeapDistances,
flatHeapIndices,
listIndices,
indicesOptions,
prefixSumOffsets,
topQueryToCentroid,
k,
false, // L2 distance chooses smallest
outDistances,
outIndices,
stream);
CUDA_VERIFY(cudaGetLastError());
}
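// Top-level entry point: picks a query tile size that keeps the temporary
// buffers within the available scratch memory, then processes the query set
// tile by tile, alternating between two streams so work on consecutive tiles
// can overlap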
void runPQScanMultiPassPrecomputed(Tensor<float, 2, true>& queries,
Tensor<float, 2, true>& precompTerm1,
NoTypeTensor<3, true>& precompTerm2,
NoTypeTensor<3, true>& precompTerm3,
Tensor<int, 2, true>& topQueryToCentroid,
bool useFloat16Lookup,
int bytesPerCode,
int numSubQuantizers,
int numSubQuantizerCodes,
thrust::device_vector<void*>& listCodes,
thrust::device_vector<void*>& listIndices,
IndicesOptions indicesOptions,
thrust::device_vector<int>& listLengths,
int maxListLength,
int k,
// output
Tensor<float, 2, true>& outDistances,
// output
Tensor<long, 2, true>& outIndices,
GpuResources* res) {
constexpr int kMinQueryTileSize = 8;
constexpr int kMaxQueryTileSize = 128;
constexpr int kThrustMemSize = 16384;
int nprobe = topQueryToCentroid.getSize(1);
auto& mem = res->getMemoryManagerCurrentDevice();
auto stream = res->getDefaultStreamCurrentDevice();
// Make a reservation for Thrust to do its dirty work (global memory
// cross-block reduction space); hopefully this is large enough.
DeviceTensor<char, 1, true> thrustMem1(
mem, {kThrustMemSize}, stream);
DeviceTensor<char, 1, true> thrustMem2(
mem, {kThrustMemSize}, stream);
DeviceTensor<char, 1, true>* thrustMem[2] =
{&thrustMem1, &thrustMem2};
// How much temporary storage is available?
// If possible, we'd like to fit within the space available.
size_t sizeAvailable = mem.getSizeAvailable();
// We run two passes of heap selection
// This is the size of the first-level heap passes
constexpr int kNProbeSplit = 8;
int pass2Chunks = std::min(nprobe, kNProbeSplit);
size_t sizeForFirstSelectPass =
pass2Chunks * k * (sizeof(float) + sizeof(int));
// How much temporary storage we need per each query
size_t sizePerQuery =
2 * // # streams
((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
nprobe * maxListLength * sizeof(float) + // allDistances
sizeForFirstSelectPass);
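// Example with hypothetical sizes (nprobe = 32, maxListLength = 10000,
// k = 10): prefixSumOffsets needs 33 * 4 = 132 B, allDistances needs
// 32 * 10000 * 4 = 1.28 MB, and the first select pass needs
// min(32, 8) * 10 * 8 = 640 B per query, so sizePerQuery is about 2.56 MB
// and 256 MB of scratch yields a queryTileSize of about 100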
int queryTileSize = (int) (sizeAvailable / sizePerQuery);
if (queryTileSize < kMinQueryTileSize) {
queryTileSize = kMinQueryTileSize;
} else if (queryTileSize > kMaxQueryTileSize) {
queryTileSize = kMaxQueryTileSize;
}
// FIXME: we should adjust queryTileSize to deal with this, since
// indexing is in int32
FAISS_ASSERT(queryTileSize * nprobe * maxListLength <=
std::numeric_limits<int>::max());
// Temporary memory buffers
// Reserve one extra element before the start of each offsets array; it is
// set to 0 so that the boundary condition at offset -1 is handled without
// branches
DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
mem, {queryTileSize * nprobe + 1}, stream);
DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
mem, {queryTileSize * nprobe + 1}, stream);
DeviceTensor<int, 2, true> prefixSumOffsets1(
prefixSumOffsetSpace1[1].data(),
{queryTileSize, nprobe});
DeviceTensor<int, 2, true> prefixSumOffsets2(
prefixSumOffsetSpace2[1].data(),
{queryTileSize, nprobe});
DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
{&prefixSumOffsets1, &prefixSumOffsets2};
// Make sure the element before prefixSumOffsets is 0, since we
// depend upon simple, boundary-less indexing to get proper results
CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
0,
sizeof(int),
stream));
CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
0,
sizeof(int),
stream));
DeviceTensor<float, 1, true> allDistances1(
mem, {queryTileSize * nprobe * maxListLength}, stream);
DeviceTensor<float, 1, true> allDistances2(
mem, {queryTileSize * nprobe * maxListLength}, stream);
DeviceTensor<float, 1, true>* allDistances[2] =
{&allDistances1, &allDistances2};
DeviceTensor<float, 3, true> heapDistances1(
mem, {queryTileSize, pass2Chunks, k}, stream);
DeviceTensor<float, 3, true> heapDistances2(
mem, {queryTileSize, pass2Chunks, k}, stream);
DeviceTensor<float, 3, true>* heapDistances[2] =
{&heapDistances1, &heapDistances2};
DeviceTensor<int, 3, true> heapIndices1(
mem, {queryTileSize, pass2Chunks, k}, stream);
DeviceTensor<int, 3, true> heapIndices2(
mem, {queryTileSize, pass2Chunks, k}, stream);
DeviceTensor<int, 3, true>* heapIndices[2] =
{&heapIndices1, &heapIndices2};
auto streams = res->getAlternateStreamsCurrentDevice();
streamWait(streams, {stream});
int curStream = 0;
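// Temporaries are double-buffered: consecutive tiles alternate between the
// two buffer sets and the two alternate streams, so tile i+1 can be queued
// while tile i is still executing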
for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
int numQueriesInTile =
std::min(queryTileSize, queries.getSize(0) - query);
auto prefixSumOffsetsView =
prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);
auto coarseIndicesView =
topQueryToCentroid.narrowOutermost(query, numQueriesInTile);
auto queryView =
queries.narrowOutermost(query, numQueriesInTile);
auto term1View =
precompTerm1.narrowOutermost(query, numQueriesInTile);
auto term3View =
precompTerm3.narrowOutermost(query, numQueriesInTile);
auto heapDistancesView =
heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
auto heapIndicesView =
heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);
auto outDistanceView =
outDistances.narrowOutermost(query, numQueriesInTile);
auto outIndicesView =
outIndices.narrowOutermost(query, numQueriesInTile);
runMultiPassTile(queryView,
term1View,
precompTerm2,
term3View,
coarseIndicesView,
useFloat16Lookup,
bytesPerCode,
numSubQuantizers,
numSubQuantizerCodes,
listCodes,
listIndices,
indicesOptions,
listLengths,
*thrustMem[curStream],
prefixSumOffsetsView,
*allDistances[curStream],
heapDistancesView,
heapIndicesView,
k,
outDistanceView,
outIndicesView,
streams[curStream]);
curStream = (curStream + 1) % 2;
}
streamWait({stream}, streams);
}
} } // namespace