#include "IVFFlatScan.cuh"
#include "../GpuResources.h"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/DeviceTensor.cuh"
#include "../utils/Float16.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/PtxUtils.cuh"
#include "../utils/Reductions.cuh"
#include "../utils/StaticUtils.h"
#include <thrust/host_vector.h>
#include <algorithm> // std::min
#include <limits>    // std::numeric_limits
namespace faiss { namespace gpu {
template <typename T>
inline __device__ typename Math<T>::ScalarType l2Distance(T a, T b) {
  a = Math<T>::sub(a, b);
  a = Math<T>::mul(a, a);
  return Math<T>::reduceAdd(a);
}
template <typename T>
inline __device__ typename Math<T>::ScalarType ipDistance(T a, T b) {
  return Math<T>::reduceAdd(Math<T>::mul(a, b));
}
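
// For list scanning, even if the input data is `half`, all math is
// performed in float32: the scan is memory-bandwidth bound, and the
// added precision for accumulation is useful.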
/// The class that we use to provide scan specializations
template <int Dims, bool L2, typename T>
struct IVFFlatScan {
};
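
// Specializations of IVFFlatScan follow: Dims == -1 is the general
// fallback for any dimension, while Dims == 0 assumes one thread per
// dimension (dim == blockDim.x), permitting a simpler, unrolled scan.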
// Fallback implementation: works for any # of dims
template <bool L2, typename T>
struct IVFFlatScan<-1, L2, T> {
  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    extern __shared__ float smem[];
    T* vecs = (T*) vecData;

    for (int vec = 0; vec < numVecs; ++vec) {
      // Each thread accumulates its partial distance into `dist`
      float dist = 0.0f;

      for (int d = threadIdx.x; d < dim; d += blockDim.x) {
        float vecVal = ConvertTo<float>::to(vecs[vec * dim + d]);
        float queryVal = query[d];
        float curDist;

        if (L2) {
          curDist = l2Distance(queryVal, vecVal);
        } else {
          curDist = ipDistance(queryVal, vecVal);
        }

        dist += curDist;
      }

      // Reduce the distance within the block
      dist = blockReduceAllSum<float, false, true>(dist, smem);

      if (threadIdx.x == 0) {
        distanceOut[vec] = dist;
      }
    }
  }
};
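
// Example for the fallback scan above: with dim = 1024 and blockDim.x =
// 512, each thread accumulates two strided dimensions (d = threadIdx.x
// and d = threadIdx.x + 512) before the block-wide sum produces the
// final distance for each vector.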
// Optimized implementation: assumes # dims == blockDim.x, so each
// thread owns exactly one dimension
template <bool L2, typename T>
struct IVFFlatScan<0, L2, T> {
  static __device__ void scan(float* query,
                              void* vecData,
                              int numVecs,
                              int dim,
                              float* distanceOut) {
    extern __shared__ float smem[];
    T* vecs = (T*) vecData;

    float queryVal = query[threadIdx.x];

    constexpr int kUnroll = 4;
    int limit = utils::roundDown(numVecs, kUnroll);

    for (int i = 0; i < limit; i += kUnroll) {
      float vecVal[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vecVal[j] = ConvertTo<float>::to(vecs[(i + j) * dim + threadIdx.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        if (L2) {
          vecVal[j] = l2Distance(queryVal, vecVal[j]);
        } else {
          vecVal[j] = ipDistance(queryVal, vecVal[j]);
        }
      }

      blockReduceAllSum<kUnroll, float, false, true>(vecVal, smem);

      if (threadIdx.x == 0) {
#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          distanceOut[i + j] = vecVal[j];
        }
      }
    }

    // Handle the remainder if numVecs is not a multiple of kUnroll
    for (int i = limit; i < numVecs; ++i) {
      float vecVal = ConvertTo<float>::to(vecs[i * dim + threadIdx.x]);

      if (L2) {
        vecVal = l2Distance(queryVal, vecVal);
      } else {
        vecVal = ipDistance(queryVal, vecVal);
      }

      vecVal = blockReduceAllSum<float, false, true>(vecVal, smem);

      if (threadIdx.x == 0) {
        distanceOut[i] = vecVal;
      }
    }
  }
};
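
// In the specialization above, kUnroll = 4 amortizes the cost of the
// block-wide reduction: the multi-value blockReduceAllSum reduces four
// distances in shared memory per pass instead of one.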
template <int Dims, bool L2, typename T>
__global__ void
ivfFlatScan(Tensor<float, 2, true> queries,
            Tensor<int, 2, true> listIds,
            void** allListData,
            int* listLengths,
            Tensor<int, 2, true> prefixSumOffsets,
            Tensor<float, 1, true> distance) {
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;

  // This is where we start writing out data
  // We ensure that before the array (at offset -1), there is a 0 value
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);

  auto listId = listIds[queryId][probeId];

  auto query = queries[queryId].data();
  auto vecs = allListData[listId];
  auto numVecs = listLengths[listId];
  auto dim = queries.getSize(1);
  auto distanceOut = distance[outBase].data();

  IVFFlatScan<Dims, L2, T>::scan(query, vecs, numVecs, dim, distanceOut);
}
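
// Launch geometry: one block per (query, probe) pair; blockIdx.x selects
// the probed IVF list and blockIdx.y selects the query, matching the
// grid = dim3(listIds.getSize(1), listIds.getSize(0)) launch below.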
void
runIVFFlatScanTile(Tensor<float, 2, true>& queries,
                   Tensor<int, 2, true>& listIds,
                   thrust::device_vector<void*>& listData,
                   thrust::device_vector<void*>& listIndices,
                   IndicesOptions indicesOptions,
                   thrust::device_vector<int>& listLengths,
                   Tensor<char, 1, true>& thrustMem,
                   Tensor<int, 2, true>& prefixSumOffsets,
                   Tensor<float, 1, true>& allDistances,
                   Tensor<float, 3, true>& heapDistances,
                   Tensor<int, 3, true>& heapIndices,
                   int k,
                   bool l2Distance,
                   bool useFloat16,
                   Tensor<float, 2, true>& outDistances,
                   Tensor<long, 2, true>& outIndices,
                   cudaStream_t stream) {
  // Calculate offset lengths, so we know where to write out
  // intermediate results
  runCalcListOffsets(listIds, listLengths, prefixSumOffsets, thrustMem, stream);
  // Calculate distances for vectors within our chunk of lists
  constexpr int kMaxThreadsIVF = 512;

  int dim = queries.getSize(1);
  int numThreads = std::min(dim, kMaxThreadsIVF);

  auto grid = dim3(listIds.getSize(1),
                   listIds.getSize(0));
  auto block = dim3(numThreads);
  // All exact dim kernels are unrolled by 4, hence the `4`
  auto smem = sizeof(float) * utils::divUp(numThreads, kWarpSize) * 4;
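
  // Example: numThreads = 512 and kWarpSize = 32 give 16 warps, so the
  // exact-dim kernels get 16 * 4 = 64 floats (256 bytes) of shared
  // memory for the 4-way unrolled block reduction.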
#define RUN_IVF_FLAT(DIMS, L2, T)                       \
  do {                                                  \
    ivfFlatScan<DIMS, L2, T>                            \
      <<<grid, block, smem, stream>>>(                  \
        queries, listIds,                               \
        listData.data().get(),                          \
        listLengths.data().get(),                       \
        prefixSumOffsets, allDistances);                \
  } while (0)
#ifdef FAISS_USE_FLOAT16
#define HANDLE_DIM_CASE(DIMS)                              \
  do {                                                     \
    if (l2Distance) {                                      \
      if (useFloat16) { RUN_IVF_FLAT(DIMS, true, half); }  \
      else { RUN_IVF_FLAT(DIMS, true, float); }            \
    } else {                                               \
      if (useFloat16) { RUN_IVF_FLAT(DIMS, false, half); } \
      else { RUN_IVF_FLAT(DIMS, false, float); }           \
    }                                                      \
  } while (0)
#else
// Without float16 compiled in, a float16 request is an error
#define HANDLE_DIM_CASE(DIMS)                              \
  do {                                                     \
    if (l2Distance) {                                      \
      if (useFloat16) { FAISS_ASSERT(false); }             \
      else { RUN_IVF_FLAT(DIMS, true, float); }            \
    } else {                                               \
      if (useFloat16) { FAISS_ASSERT(false); }             \
      else { RUN_IVF_FLAT(DIMS, false, float); }           \
    }                                                      \
  } while (0)
#endif // FAISS_USE_FLOAT16
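
// For instance, HANDLE_DIM_CASE(0) with l2Distance && useFloat16 expands
// to RUN_IVF_FLAT(0, true, half), launching ivfFlatScan<0, true, half>
// when FAISS_USE_FLOAT16 is defined, and asserting otherwise.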
  if (dim <= kMaxThreadsIVF) {
    HANDLE_DIM_CASE(0);
  } else {
    HANDLE_DIM_CASE(-1);
  }

#undef HANDLE_DIM_CASE
#undef RUN_IVF_FLAT
  // k-select the output in chunks, to increase parallelism
  runPass1SelectLists(prefixSumOffsets,
                      allDistances,
                      listIds.getSize(1),
                      k,
                      !l2Distance, // L2 distance chooses smallest
                      heapDistances,
                      heapIndices,
                      stream);
  // k-select the final output
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      listIds,
                      k,
                      !l2Distance, // L2 distance chooses smallest
                      outDistances,
                      outIndices,
                      stream);

  CUDA_VERIFY(cudaGetLastError());
}
void
runIVFFlatScan(Tensor<float, 2, true>& queries,
               Tensor<int, 2, true>& listIds,
               thrust::device_vector<void*>& listData,
               thrust::device_vector<void*>& listIndices,
               IndicesOptions indicesOptions,
               thrust::device_vector<int>& listLengths,
               int maxListLength,
               int k,
               bool l2Distance,
               bool useFloat16,
               Tensor<float, 2, true>& outDistances,
               Tensor<long, 2, true>& outIndices,
               GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;
  int nprobe = listIds.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  // Make a reservation for Thrust to do its dirty work (global memory
  // cross-block reductions)
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};
  // How much temporary storage is available?
  size_t sizeAvailable = mem.getSizeAvailable();

  // We run two passes of heap selection; this is the size of the
  // first-level heap passes
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  // How much temporary storage we need per each query
  size_t sizePerQuery =
    2 * // two sets of buffers, one per alternate stream
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     sizeForFirstSelectPass);
  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // Temporary buffers are indexed with int32; make sure we cannot overflow
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <
               std::numeric_limits<int>::max());
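
  // Example: queryTileSize = 128, nprobe = 1024, maxListLength = 16384
  // yields 2^31 elements, which exceeds int32's maximum; the assert
  // above rejects such configurations rather than silently mis-indexing.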
  // Make sure there is space prior to the start which will be 0, to
  // handle the boundary condition without branches
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};
  // Make sure the element before prefixSumOffsets is 0, since we
  // depend upon simple, boundary-less indexing to get proper results
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0, sizeof(int), stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0, sizeof(int), stream));
  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};
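
  // All temporaries above are double-buffered: each of the two alternate
  // streams owns its own prefixSumOffsets / allDistances / heap buffers,
  // so consecutive query tiles can be processed concurrently below.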
  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;
  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto listIdsView =
      listIds.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runIVFFlatScanTile(queryView,
                       listIdsView,
                       listData,
                       listIndices,
                       indicesOptions,
                       listLengths,
                       *thrustMem[curStream],
                       prefixSumOffsetsView,
                       *allDistances[curStream],
                       heapDistancesView,
                       heapIndicesView,
                       k,
                       l2Distance,
                       useFloat16,
                       outDistanceView,
                       outIndicesView,
                       streams[curStream]);

    curStream = (curStream + 1) % 2;
  }
  // Have the caller's stream wait on our work before returning
  streamWait({stream}, streams);
}

} } // namespace faiss::gpu
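
// Illustrative call site (a sketch; the member names here are
// hypothetical, not taken from this file): an IVF flat index would first
// run its coarse quantizer to produce `listIds` [numQueries x nprobe],
// then invoke:
//
//   runIVFFlatScan(queries, listIds,
//                  deviceListDataPointers, deviceListIndexPointers,
//                  indicesOptions, deviceListLengths,
//                  maxListLength, k,
//                  /*l2Distance=*/true, /*useFloat16=*/false,
//                  outDistances, outIndices, resources);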