10 #include "PQCodeDistances.cuh"
12 #include "BroadcastSum.cuh"
13 #include "Distance.cuh"
15 #include "../utils/DeviceDefs.cuh"
16 #include "../utils/DeviceUtils.h"
17 #include "../utils/Float16.cuh"
18 #include "../utils/MatrixMult.cuh"
19 #include "../utils/PtxUtils.cuh"
20 #include "../utils/StaticUtils.h"
21 #include "../utils/Transpose.cuh"
23 namespace faiss {
namespace gpu {
// Converter<T>::to(float) helpers: convert a float distance into the element
// type of the output code-distance tensor. The float16 variant (compiled only
// under FAISS_USE_FLOAT16) converts via __float2half; the float variant is the
// identity. The kernel below stores results via Converter<OutCodeT>::to(dist).
// NOTE(review): the enclosing `struct Converter<...>` specialization headers,
// their closing braces, and the #endif matching this #ifdef are not visible in
// this excerpt -- confirm they exist in the full file.
29 #ifdef FAISS_USE_FLOAT16
// float -> half conversion for the float16 lookup-table path.
32 inline static __device__ half to(
float v) {
return __float2half(v); }
// Identity conversion for the float32 lookup-table path.
38 inline static __device__
float to(
float v) {
return v; }
// pqCodeDistances: for each query handled by this block, each probed coarse
// centroid of that query, and this block's sub-quantizer, computes one value
// per PQ code. It stores that value, converted via Converter<OutCodeT>::to, in
// outCodeDistances[queryId][coarse][subQuantizer][code].
//
// Launch layout, as read from the indexing below:
//   - blockIdx.x selects a group of `queriesPerBlock` consecutive queries;
//   - blockIdx.y selects the sub-quantizer;
//   - threads [0, codesPerSubQuantizer) each own one PQ code;
//   - threads >= codesPerSubQuantizer are "loading" threads that fill the
//     shared residual buffers.
// Dynamic shared memory: 3 * DimsPerSubQuantizer floats (query slice plus two
// residual buffers), followed by one int per entry of topQueryToCentroid's
// second dimension (this query's probed coarse ids).
//
// The per-code value is built from (residual - pqCentroid) per dimension.
// Presumably it is the squared L2 distance, but the statements that square
// the differences and accumulate them into `dist` are not visible here.
//
// NOTE(review): this text looks like a lossy extraction. The following are
// missing from the visible text:
//   - declarations of `queriesPerBlock`, `vals` and `dist`, which are used
//     below;
//   - the `__global__ void` qualifier;
//   - closing braces;
//   - any __syncthreads() barriers.
// Shared memory (smemQuery, coarseIds, smemResidual1/2) is written by one set
// of threads and read by another, so barriers are required between those
// phases. Confirm against the full file.
43 template <
typename OutCodeT,
int DimsPerSubQuantizer>
45 __launch_bounds__(288, 4)
46 pqCodeDistances(
Tensor<
float, 2, true> queries,
48 Tensor<
float, 2, true> coarseCentroids,
49 Tensor<
float, 3, true> pqCentroids,
50 Tensor<
int, 2, true> topQueryToCentroid,
52 Tensor<OutCodeT, 4, true> outCodeDistances) {
// Codebook sizes, indexed as pqCentroids[subQuantizer][dim][code]. The
// compile-time dimension count must match the runtime one.
53 const auto numSubQuantizers = pqCentroids.getSize(0);
54 const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
55 assert(DimsPerSubQuantizer == dimsPerSubQuantizer);
56 const auto codesPerSubQuantizer = pqCentroids.getSize(2);
// Threads past the last code do no distance work; they load residuals
// instead. loadingThreadId is a loading thread's 0-based index among the
// loading threads (it is negative for code threads).
58 bool isLoadingThread = threadIdx.x >= codesPerSubQuantizer;
59 int loadingThreadId = threadIdx.x - codesPerSubQuantizer;
61 extern __shared__
float smem[];
// Per-thread copy of all DimsPerSubQuantizer components of this thread's
// code centroid for the current sub-quantizer.
64 float subQuantizerData[DimsPerSubQuantizer];
66 auto code = threadIdx.x;
67 auto subQuantizer = blockIdx.y;
// NOTE(review): as visible, every thread executes this load, including
// loading threads whose `code` is >= codesPerSubQuantizer. Those threads
// would read past the last code. Confirm a `!isLoadingThread` guard exists
// in the full file.
72 for (
int i = 0; i < DimsPerSubQuantizer; ++i) {
73 subQuantizerData[i] = pqCentroids[subQuantizer][i][code].ldg();
// Carve up the dynamic shared buffer in this order: query slice, two residual
// buffers (swapped between coarse-centroid iterations), then coarse ids.
77 float* smemQuery = smem;
81 float* smemResidual1 = &smemQuery[DimsPerSubQuantizer];
82 float* smemResidual2 = &smemResidual1[DimsPerSubQuantizer];
85 int* coarseIds = (
int*) &smemResidual2[DimsPerSubQuantizer];
// This block handles queries [startQueryId, startQueryId + numQueries),
// clamped at the end of `queries`.
91 auto startQueryId = blockIdx.x * queriesPerBlock;
92 auto numQueries = queries.getSize(0) - startQueryId;
93 if (numQueries > queriesPerBlock) {
94 numQueries = queriesPerBlock;
97 for (
int query = 0; query < numQueries; ++query) {
98 auto queryId = startQueryId + query;
// Pointer to this query's components for the current sub-quantizer.
100 auto querySubQuantizer =
101 queries[queryId][subQuantizer * DimsPerSubQuantizer].data();
// Stage the query slice and this query's probed coarse ids in shared memory.
// All threads in the block participate, striding by blockDim.x.
104 for (
int i = threadIdx.x; i < DimsPerSubQuantizer; i += blockDim.x) {
105 smemQuery[i] = querySubQuantizer[i];
109 for (
int i = threadIdx.x;
110 i < topQueryToCentroid.getSize(1); i += blockDim.x) {
111 coarseIds[i] = topQueryToCentroid[queryId][i];
// Loading threads compute the residual (query - coarse centroid) for the
// first probed centroid into smemResidual1. A coarse id of -1 is replaced by
// 0 so the centroid load stays in bounds. Presumably -1 marks an
// empty/invalid probe; confirm against the coarse quantizer.
119 if (isLoadingThread) {
120 for (
int i = loadingThreadId;
121 i < DimsPerSubQuantizer;
122 i += blockDim.x - codesPerSubQuantizer) {
123 auto coarseId = coarseIds[0];
125 coarseId = coarseId == -1 ? 0 : coarseId;
126 auto coarseCentroidSubQuantizer =
127 coarseCentroids[coarseId][subQuantizer * dimsPerSubQuantizer].data();
129 smemResidual1[i] = smemQuery[i] - coarseCentroidSubQuantizer[i];
// For each probed coarse centroid, loading threads prefetch the residual for
// the NEXT centroid into smemResidual2, while code threads consume
// smemResidual1. The buffers are swapped at the end of the iteration.
134 for (
int coarse = 0; coarse < topQueryToCentroid.getSize(1); ++coarse) {
138 if (isLoadingThread) {
140 for (
int i = loadingThreadId;
141 i < DimsPerSubQuantizer;
142 i += blockDim.x - codesPerSubQuantizer) {
// There is nothing to prefetch after the last probed centroid.
145 if (coarse != (topQueryToCentroid.getSize(1) - 1)) {
146 auto coarseId = coarseIds[coarse + 1];
148 coarseId = coarseId == -1 ? 0 : coarseId;
150 auto coarseCentroidSubQuantizer =
151 coarseCentroids[coarseId][subQuantizer * dimsPerSubQuantizer].data();
153 smemResidual2[i] = smemQuery[i] - coarseCentroidSubQuantizer[i];
// Manual unroll of the per-dimension work by kUnroll. The final
// DimsPerSubQuantizer % kUnroll dimensions are handled after the main loop.
160 constexpr
int kUnroll = 4;
161 constexpr
int kRemainder = DimsPerSubQuantizer % kUnroll;
162 constexpr
int kRemainderBase = DimsPerSubQuantizer - kRemainder;
// Main chunks: vals[j] = residual[d] - pqCentroid[d] for d in the chunk.
170 for (
int i = 0; i < DimsPerSubQuantizer / kUnroll; ++i) {
173 for (
int j = 0; j < kUnroll; ++j) {
174 vals[j] = smemResidual1[i * kUnroll + j];
178 for (
int j = 0; j < kUnroll; ++j) {
179 vals[j] -= subQuantizerData[i * kUnroll + j];
183 for (
int j = 0; j < kUnroll; ++j) {
188 for (
int j = 0; j < kUnroll; ++j) {
// Tail: dimensions [kRemainderBase, DimsPerSubQuantizer).
195 for (
int j = 0; j < kRemainder; ++j) {
196 vals[j] = smemResidual1[kRemainderBase + j];
200 for (
int j = 0; j < kRemainder; ++j) {
201 vals[j] -= subQuantizerData[kRemainderBase + j];
205 for (
int j = 0; j < kRemainder; ++j) {
210 for (
int j = 0; j < kRemainder; ++j) {
// Store this code's value for (query, probed coarse centroid, sub-quantizer).
215 outCodeDistances[queryId][coarse][subQuantizer][code] =
216 Converter<OutCodeT>::to(dist);
// Swap the residual buffers so the prefetched residual becomes current. The
// assignment completing the swap is not visible in this excerpt.
220 float* tmp = smemResidual1;
221 smemResidual1 = smemResidual2;
// residualVector: each grid block handles one (queryId = blockIdx.x,
// centroidId = blockIdx.y) pair. It looks up the real coarse centroid id in
// topQueryToCentroid. Threads stride over the query dimensions by blockDim.x
// and write one value per dimension into
// residual[dim / numSubDim][queryId][centroidId][dim % numSubDim].
// This regroups the full-dimension residual by sub-quantizer, with numSubDim
// dimensions per sub-quantizer.
// The stored value is computed from q (query component) and c (coarse
// centroid component); presumably q - c, but that expression is not visible
// here.
// NOTE(review): the template/__global__ header, the `numSubDim` parameter
// declaration and the closing braces are missing from this excerpt.
// NOTE(review): pqCodeDistances maps a coarse id of -1 to 0, but no such
// guard is visible here. If topQueryToCentroid can contain -1, this indexes
// coarseCentroids out of bounds. Confirm.
228 residualVector(Tensor<float, 2, true> queries,
229 Tensor<float, 2, true> coarseCentroids,
230 Tensor<int, 2, true> topQueryToCentroid,
234 Tensor<float, 4, true> residual) {
238 auto queryId = blockIdx.x;
239 auto centroidId = blockIdx.y;
// Map the probe slot to the actual coarse centroid index.
241 int realCentroidId = topQueryToCentroid[queryId][centroidId];
// Block-strided loop over all query dimensions.
243 for (
int dim = threadIdx.x; dim < queries.getSize(1); dim += blockDim.x) {
244 float q = queries[queryId][dim];
245 float c = coarseCentroids[realCentroidId][dim];
// Scatter into the sub-quantizer-major layout.
247 residual[dim / numSubDim][queryId][centroidId][dim % numSubDim] =
// runResidualVector: launches residualVector on `stream` with one block per
// (query, probed coarse centroid) pair, i.e.
// grid = (topQueryToCentroid.getSize(0), topQueryToCentroid.getSize(1)).
// The kernel fills `residual` with per-sub-quantizer residuals. In the
// visible lines, pqCentroids is used only for getSize(1), which is passed to
// the kernel as its per-sub-quantizer dimension count.
// NOTE(review): this excerpt is missing the return type, the `auto grid =`
// line, the end of the launch argument list and the closing brace. No launch
// error check is visible after the kernel launch.
253 runResidualVector(Tensor<float, 3, true>& pqCentroids,
254 Tensor<float, 2, true>& queries,
255 Tensor<float, 2, true>& coarseCentroids,
256 Tensor<int, 2, true>& topQueryToCentroid,
257 Tensor<float, 4, true>& residual,
258 cudaStream_t stream) {
260 dim3(topQueryToCentroid.getSize(0), topQueryToCentroid.getSize(1));
// One thread per query dimension, capped at the device's maximum threads per
// block. The kernel's strided loop covers any remaining dimensions.
261 auto block = dim3(std::min(queries.getSize(1), getMaxThreadsCurrentDevice()));
// No dynamic shared memory is used by this launch.
263 residualVector<<<grid, block, 0, stream>>>(
264 queries, coarseCentroids, topQueryToCentroid, pqCentroids.getSize(1),
// runPQCodeDistancesMM: matrix-multiply formulation of the PQ code-distance
// table. Visible steps:
//   1. residual[subQ][query][probe][subDim]: per-sub-quantizer residuals of
//      each query w.r.t. each probed coarse centroid (runResidualVector).
//   2. residualNorms: one norm per residual row (runL2Norm with flag `true`,
//      presumably squared norms).
//   3. residualDistance[subQ][query*probe][code]: batched product of the
//      residuals with pqCentroids (runIteratedMatrixMult; the scaling
//      arguments are not visible here).
//   4. Add residualNorms, one value per row, to residualDistance
//      (runSumAlongRows).
//   5. Transpose dims 0 and 1 of residualDistance into the float output view
//      [query*probe][outCodeDistances.getSize(2)][outCodeDistances.getSize(3)].
//   6. Add the per-(subQ, code) PQ centroid norms, computed from a
//      [subQ][code][dim] transpose of the codebook (runSumAlongColumns).
//   7. On the float16 lookup path, convert the float scratch result to half.
// Presumably this evaluates ||r||^2 + ||c||^2 - 2<r, c> = ||r - c||^2.
// Confirm the -2 factor in the matmul arguments, which are not visible.
// When useFloat16Lookup is set, a float scratch tensor
// (outCodeDistancesFloatMem) is used and converted at the end. Otherwise
// outCodeDistances is written directly as float.
// NOTE(review): the `mem` parameter used for scratch allocations, several
// argument lines, the #else/#endif directives and closing braces are missing
// from this excerpt.
271 runPQCodeDistancesMM(Tensor<float, 3, true>& pqCentroids,
272 Tensor<float, 2, true>& queries,
273 Tensor<float, 2, true>& coarseCentroids,
274 Tensor<int, 2, true>& topQueryToCentroid,
275 NoTypeTensor<4, true>& outCodeDistances,
276 bool useFloat16Lookup,
278 cublasHandle_t handle,
279 cudaStream_t stream) {
// Step 1: residuals laid out [subQ][query][probe][subDim].
282 DeviceTensor<float, 4, true> residual(
284 {pqCentroids.getSize(0),
285 topQueryToCentroid.getSize(0),
286 topQueryToCentroid.getSize(1),
287 pqCentroids.getSize(1)},
290 runResidualVector(pqCentroids, queries,
291 coarseCentroids, topQueryToCentroid,
// Step 2: one norm per (subQ, query, probe) residual row.
295 DeviceTensor<float, 1, true> residualNorms(
297 {pqCentroids.getSize(0) *
298 topQueryToCentroid.getSize(0) *
299 topQueryToCentroid.getSize(1)},
302 auto residualView2 = residual.view<2>(
303 {pqCentroids.getSize(0) *
304 topQueryToCentroid.getSize(0) *
305 topQueryToCentroid.getSize(1),
306 pqCentroids.getSize(1)});
308 runL2Norm(residualView2, residualNorms,
true, stream);
// Step 3: one batch per sub-quantizer, multiplying [query*probe x subDim]
// residuals against that sub-quantizer's [subDim x code] centroids.
313 auto residualView3 = residual.view<3>(
314 {pqCentroids.getSize(0),
315 topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
316 pqCentroids.getSize(1)});
318 DeviceTensor<float, 3, true> residualDistance(
320 {pqCentroids.getSize(0),
321 topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
322 pqCentroids.getSize(2)},
325 runIteratedMatrixMult(residualDistance,
false,
326 residualView3,
false,
// Step 4: add each residual row's norm to every code in that row.
333 auto residualDistanceView2 = residualDistance.view<2>(
334 {pqCentroids.getSize(0) *
335 topQueryToCentroid.getSize(0) *
336 topQueryToCentroid.getSize(1),
337 pqCentroids.getSize(2)});
339 runSumAlongRows(residualNorms, residualDistanceView2,
false, stream);
// Choose where float results go: a scratch tensor when the final output is
// half, otherwise the output tensor itself.
341 Tensor<float, 4, true> outCodeDistancesF;
342 DeviceTensor<float, 4, true> outCodeDistancesFloatMem;
344 #ifdef FAISS_USE_FLOAT16
345 if (useFloat16Lookup) {
346 outCodeDistancesFloatMem = DeviceTensor<float, 4, true>(
347 mem, {outCodeDistances.getSize(0),
348 outCodeDistances.getSize(1),
349 outCodeDistances.getSize(2),
350 outCodeDistances.getSize(3)},
353 outCodeDistancesF = outCodeDistancesFloatMem;
357 if (!useFloat16Lookup) {
358 outCodeDistancesF = outCodeDistances.toTensor<
float>();
// Step 5: reorder [subQ][query*probe][code] -> [query*probe][subQ][code].
363 auto outCodeDistancesView = outCodeDistancesF.view<3>(
364 {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
365 outCodeDistances.getSize(2),
366 outCodeDistances.getSize(3)});
368 runTransposeAny(residualDistance, 0, 1, outCodeDistancesView, stream);
// Step 6: norms of each PQ centroid, taken from a [subQ][code][dim] copy of
// the codebook so that each centroid is a contiguous row.
373 DeviceTensor<float, 3, true> pqCentroidsTranspose(
375 {pqCentroids.getSize(0), pqCentroids.getSize(2), pqCentroids.getSize(1)},
378 runTransposeAny(pqCentroids, 1, 2, pqCentroidsTranspose, stream);
380 auto pqCentroidsTransposeView = pqCentroidsTranspose.view<2>(
381 {pqCentroids.getSize(0) * pqCentroids.getSize(2),
382 pqCentroids.getSize(1)});
384 DeviceTensor<float, 1, true> pqCentroidsNorm(
386 {pqCentroids.getSize(0) * pqCentroids.getSize(2)},
389 runL2Norm(pqCentroidsTransposeView, pqCentroidsNorm,
true, stream);
// View the output as [query*probe] rows of (subQ * code) columns, then add
// the per-(subQ, code) centroid norm vector across the columns.
393 auto outDistancesCodeViewCols = outCodeDistancesView.view<2>(
394 {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
395 outCodeDistances.getSize(2) * outCodeDistances.getSize(3)});
397 runSumAlongColumns(pqCentroidsNorm, outDistancesCodeViewCols, stream);
// Step 7: the float16 path converts the float scratch result into the half
// output tensor.
399 #ifdef FAISS_USE_FLOAT16
400 if (useFloat16Lookup) {
402 auto outCodeDistancesH = outCodeDistances.toTensor<half>();
403 toHalf(stream, outCodeDistancesF, outCodeDistancesH);
// runPQCodeDistances: direct (non-matmul) path. Launches the templated
// pqCodeDistances kernel on `stream`, instantiated for two things:
//   - the output element type: half when useFloat16Lookup is set under
//     FAISS_USE_FLOAT16, otherwise float;
//   - the dims-per-sub-quantizer, via the CODE_DISTANCE macro. Presumably
//     each case of the switch on dimsPerSubQuantizer expands it; the cases
//     are not visible here.
// NOTE(review): this excerpt ends inside the switch and is missing lines,
// e.g. the grid's second dimension and the closing lines of the macro bodies.
409 runPQCodeDistances(Tensor<float, 3, true>& pqCentroids,
410 Tensor<float, 2, true>& queries,
411 Tensor<float, 2, true>& coarseCentroids,
412 Tensor<int, 2, true>& topQueryToCentroid,
413 NoTypeTensor<4, true>& outCodeDistances,
414 bool useFloat16Lookup,
415 cudaStream_t stream) {
416 const auto numSubQuantizers = pqCentroids.getSize(0);
417 const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
418 const auto codesPerSubQuantizer = pqCentroids.getSize(2);
// Each block processes up to kQueriesPerBlock queries. This value is passed
// as the kernel's second launch argument.
423 constexpr
int kQueriesPerBlock = 8;
// grid.x covers all queries in groups of kQueriesPerBlock. The kernel reads
// blockIdx.y as the sub-quantizer; the grid's y extent is not visible here.
425 auto grid = dim3(utils::divUp(queries.getSize(0), kQueriesPerBlock),
// The block has one thread per PQ code, plus a warp-rounded group of loading
// threads sized to the sub-quantizer dimensionality.
// NOTE(review): the kernel is declared __launch_bounds__(288, 4). A launch
// with more than 288 threads per block fails. No check that
// codesPerSubQuantizer + loadingThreads <= 288 is visible here; confirm one
// exists upstream.
430 auto loadingThreads = utils::roundUp(dimsPerSubQuantizer, kWarpSize);
431 auto block = dim3(codesPerSubQuantizer + loadingThreads);
// Dynamic shared memory: the query slice and two residual buffers
// (3 * dims floats), plus one int coarse id per probed centroid. This
// matches the kernel's shared-memory carve-up.
433 auto smem = (3 * dimsPerSubQuantizer) *
sizeof(
float)
434 + topQueryToCentroid.getSize(1) *
sizeof(int);
// CODE_DISTANCE(DIMS) launches pqCodeDistances<OutT, DIMS> with the grid,
// block, shared-memory size and stream configured above.
436 #ifdef FAISS_USE_FLOAT16
437 #define CODE_DISTANCE(DIMS) \
439 if (useFloat16Lookup) { \
440 auto outCodeDistancesT = outCodeDistances.toTensor<half>(); \
442 pqCodeDistances<half, DIMS><<<grid, block, smem, stream>>>( \
443 queries, kQueriesPerBlock, \
444 coarseCentroids, pqCentroids, \
445 topQueryToCentroid, outCodeDistancesT); \
447 auto outCodeDistancesT = outCodeDistances.toTensor<float>(); \
449 pqCodeDistances<float, DIMS><<<grid, block, smem, stream>>>( \
450 queries, kQueriesPerBlock, \
451 coarseCentroids, pqCentroids, \
452 topQueryToCentroid, outCodeDistancesT); \
456 #define CODE_DISTANCE(DIMS) \
458 if (!useFloat16Lookup) { \
459 auto outCodeDistancesT = outCodeDistances.toTensor<float>(); \
461 pqCodeDistances<float, DIMS><<<grid, block, smem, stream>>>( \
462 queries, kQueriesPerBlock, \
463 coarseCentroids, pqCentroids, \
464 topQueryToCentroid, outCodeDistancesT); \
469 switch (dimsPerSubQuantizer) {