/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.

#include "IVFUtils.cuh"
|
|
#include "../utils/DeviceUtils.h"
|
|
#include "../utils/Limits.cuh"
|
|
#include "../utils/Select.cuh"
|
|
#include "../utils/StaticUtils.h"
|
|
#include "../utils/Tensor.cuh"
|
|
|
|
//
// This kernel is split into a separate compilation unit to cut down
// on compile time
//

namespace faiss { namespace gpu {

// This is warp divergence central, but it is only a final step
// and only happens a small number of times
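// Worked example (illustrative, not from the original source): with
// inclusive prefix sums {4, 7, 10} and val = 5, this returns 1, since
// bucket 1 spans intermediate offsets [4, 7)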
inline __device__ int binarySearchForBucket(int* prefixSumOffsets,
                                            int size,
                                            int val) {
  int start = 0;
  int end = size;

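  // Invariant: the first index with prefixSumOffsets[idx] > val always
  // lies within [start, end]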
  while (end - start > 0) {
    int mid = start + (end - start) / 2;

    int midVal = prefixSumOffsets[mid];

    // Find the first bucket whose (exclusive) upper bound is > val
    if (midVal <= val) {
      start = mid + 1;
    } else {
      end = mid;
    }
  }

  // We must find the bucket that `val` falls in
  assert(start != size);

  return start;
}

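// Second pass of the two-pass k-selection: pass 1 selected the best
// candidates within each scanned chunk per query; this kernel merges
// those per-chunk candidates into the final top-k per query and
// translates each intermediate offset back into a user index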
template <int ThreadsPerBlock,
          int NumWarpQ,
          int NumThreadQ,
          bool Dir>
__global__ void
pass2SelectLists(Tensor<float, 2, true> heapDistances,
                 Tensor<int, 2, true> heapIndices,
                 void** listIndices,
                 Tensor<int, 2, true> prefixSumOffsets,
                 Tensor<int, 2, true> topQueryToCentroid,
                 int k,
                 IndicesOptions opt,
                 Tensor<float, 2, true> outDistances,
                 Tensor<long, 2, true> outIndices) {
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

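  // Per-warp candidate queues are staged in shared memory for the
  // block-wide merge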
  __shared__ float smemK[kNumWarps * NumWarpQ];
  __shared__ int smemV[kNumWarps * NumWarpQ];

  constexpr auto kInit = Dir ? kFloatMin : kFloatMax;
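  // Block-wide selection of the k best (distance, index) pairs;
  // kInit is the sentinel "worst possible" value for direction Dir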
  BlockSelect<float, int, Dir, Comparator<float>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(kInit, -1, smemK, smemV, k);

  auto queryId = blockIdx.x;
  int num = heapDistances.getSize(1);
  int limit = utils::roundDown(num, kWarpSize);

  int i = threadIdx.x;
  auto heapDistanceStart = heapDistances[queryId];

  // BlockSelect add cannot be used in a warp-divergent circumstance; we
  // handle the remainder warp below
  for (; i < limit; i += blockDim.x) {
    heap.add(heapDistanceStart[i], i);
  }

  // Handle warp divergence separately
  if (i < num) {
    heap.addThreadQ(heapDistanceStart[i], i);
  }

  // Merge all final results
  heap.reduce();

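  // After reduce(), smemK/smemV hold the final k selected results for
  // this query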
  for (int i = threadIdx.x; i < k; i += blockDim.x) {
    outDistances[queryId][i] = smemK[i];

    // `v` is the index in `heapIndices`
    // We need to translate this into an original user index. The
    // reason why we don't maintain intermediate results in terms of
    // user indices is to substantially reduce temporary memory
    // requirements and global memory write traffic for the list
    // scanning.
    // This code is highly divergent, but it's probably ok, since this
    // is the very last step and it is happening a small number of
    // times (#queries x k).
    int v = smemV[i];
    long index = -1;

    if (v != -1) {
      // `offset` is the offset of the intermediate result, as
      // calculated by the original scan.
      int offset = heapIndices[queryId][v];

      // In order to determine the actual user index, we need to first
      // determine what list it was in.
      // We do this by binary search in the prefix sum list.
      int probe = binarySearchForBucket(prefixSumOffsets[queryId].data(),
                                        prefixSumOffsets.getSize(1),
                                        offset);

      // This is then the probe for the query; we can find the actual
      // list ID from this
      int listId = topQueryToCentroid[queryId][probe];

      // Now, we need to know the offset within the list
      // We are guaranteed that a 0 value is stored just before the
      // array (at offset -1), so reading there is always valid
      int listStart = *(prefixSumOffsets[queryId][probe].data() - 1);
      int listOffset = offset - listStart;

      // This gives us our final index
      if (opt == INDICES_32_BIT) {
        index = (long) ((int*) listIndices[listId])[listOffset];
      } else if (opt == INDICES_64_BIT) {
        index = ((long*) listIndices[listId])[listOffset];
      } else {
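        // The remaining options don't store user indices on the GPU;
        // pack (list id, list offset) into a single 64-bit value for
        // later translation back to a user index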
        index = ((long) listId << 32 | (long) listOffset);
      }
    }

    outIndices[queryId][i] = index;
  }
}

void
runPass2SelectLists(Tensor<float, 2, true>& heapDistances,
                    Tensor<int, 2, true>& heapIndices,
                    thrust::device_vector<void*>& listIndices,
                    IndicesOptions indicesOptions,
                    Tensor<int, 2, true>& prefixSumOffsets,
                    Tensor<int, 2, true>& topQueryToCentroid,
                    int k,
                    bool chooseLargest,
                    Tensor<float, 2, true>& outDistances,
                    Tensor<long, 2, true>& outIndices,
                    cudaStream_t stream) {
  constexpr auto kThreadsPerBlock = 128;

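  // One block per query: each block independently merges that query's
  // pass-1 candidates into its final top-k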
  auto grid = dim3(topQueryToCentroid.getSize(0));
  auto block = dim3(kThreadsPerBlock);

#define RUN_PASS(NUM_WARP_Q, NUM_THREAD_Q, DIR)                 \
  do {                                                          \
    pass2SelectLists<kThreadsPerBlock,                          \
                     NUM_WARP_Q, NUM_THREAD_Q, DIR>             \
      <<<grid, block, 0, stream>>>(heapDistances,               \
                                   heapIndices,                 \
                                   listIndices.data().get(),    \
                                   prefixSumOffsets,            \
                                   topQueryToCentroid,          \
                                   k,                           \
                                   indicesOptions,              \
                                   outDistances,                \
                                   outIndices);                 \
    CUDA_TEST_ERROR();                                          \
    return; /* success */                                       \
  } while (0)

#define RUN_PASS_DIR(DIR)                       \
  do {                                          \
    if (k == 1) {                               \
      RUN_PASS(1, 1, DIR);                      \
    } else if (k <= 32) {                       \
      RUN_PASS(32, 2, DIR);                     \
    } else if (k <= 64) {                       \
      RUN_PASS(64, 3, DIR);                     \
    } else if (k <= 128) {                      \
      RUN_PASS(128, 3, DIR);                    \
    } else if (k <= 256) {                      \
      RUN_PASS(256, 4, DIR);                    \
    } else if (k <= 512) {                      \
      RUN_PASS(512, 8, DIR);                    \
    } else if (k <= 1024) {                     \
      RUN_PASS(1024, 8, DIR);                   \
    }                                           \
  } while (0)

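  // Choose the comparison direction at runtime, then dispatch on k to
  // the smallest warp queue size that can hold k results; a successful
  // launch returns from inside RUN_PASS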
  if (chooseLargest) {
    RUN_PASS_DIR(true);
  } else {
    RUN_PASS_DIR(false);
  }

  // If we get here, no kernel was launched: the requested k is
  // unimplemented or would require too many resources
  FAISS_ASSERT_FMT(false, "unimplemented k value (%d)", k);

#undef RUN_PASS_DIR
#undef RUN_PASS
}

} } // namespace