11 #include "PQCodeDistances.cuh"
13 #include "BroadcastSum.cuh"
14 #include "Distance.cuh"
16 #include "../utils/DeviceDefs.cuh"
17 #include "../utils/DeviceUtils.h"
18 #include "../utils/Float16.cuh"
19 #include "../utils/MatrixMult.cuh"
20 #include "../utils/PtxUtils.cuh"
21 #include "../utils/StaticUtils.h"
22 #include "../utils/Transpose.cuh"
24 namespace faiss {
namespace gpu {
// Float -> half flavour of the `to` conversion helper: returns the argument
// converted with __float2half. Declared after the FAISS_USE_FLOAT16 guard.
// NOTE(review): the kernel below calls `Converter<OutCodeT>::to(...)`, so this
// is presumably a static member of a `Converter<half>` specialization whose
// struct header (and the matching #endif) are not visible here -- confirm
// against the full file.
// NOTE(review): the leading integers on these lines (30, 33) look like
// line-number residue from an extraction step rather than code -- confirm.
30 #ifdef FAISS_USE_FLOAT16
33 inline static __device__ half to(
float v) {
return __float2half(v); }
// Float -> float flavour of the `to` conversion helper: returns its argument
// unchanged.
// NOTE(review): the kernel below calls `Converter<OutCodeT>::to(...)`, so this
// is presumably a static member of a `Converter<float>` specialization whose
// struct header is not visible here -- confirm against the full file.
// NOTE(review): the leading integer (39) looks like line-number residue from
// an extraction step rather than code -- confirm.
39 inline static __device__
float to(
float v) {
return v; }
// Kernel that fills outCodeDistances[queryId][coarse][subQuantizer][code].
//
// From the visible lines, each thread holds the PQ centroid values for one
// `code` of one sub-quantizer and subtracts them, dimension by dimension, from
// a residual (query minus coarse centroid) staged in shared memory. The result
// `dist` is converted with Converter<OutCodeT>::to before being stored.
// NOTE(review): the statements that accumulate `dist` are not visible in this
// excerpt -- presumably a squared L2 distance; confirm against the full file.
//
// Template parameters:
//   OutCodeT            - element type of outCodeDistances.
//   DimsPerSubQuantizer - compile-time dimension count per sub-quantizer;
//                         asserted equal to pqCentroids.getSize(1).
//
// Launch layout as used by the visible code:
//   blockIdx.x  - selects a group of `queriesPerBlock` consecutive queries.
//   blockIdx.y  - sub-quantizer index.
//   threadIdx.x - code index; threads with threadIdx.x >= codesPerSubQuantizer
//                 are treated as "loading" threads instead.
//   dynamic shared memory - 3 * DimsPerSubQuantizer floats followed by
//                 topQueryToCentroid.getSize(1) ints (see pointer carving
//                 below).
//
// NOTE(review): `queriesPerBlock` is used below but does not appear in the
// visible parameter list, and the `__global__ void` specifier, barriers,
// closing braces and the declarations of `vals` / `dist` are likewise not
// visible. The leading integers on many lines look like line-number residue.
// Treat this text as an incomplete extraction and confirm against the
// original file.
44 template <
typename OutCodeT,
int DimsPerSubQuantizer>
// Launch bounds: at most 288 threads per block, at least 4 blocks per SM.
46 __launch_bounds__(288, 4)
47 pqCodeDistances(
Tensor<
float, 2, true> queries,
49 Tensor<
float, 2, true> coarseCentroids,
50 Tensor<
float, 3, true> pqCentroids,
51 Tensor<
int, 2, true> topQueryToCentroid,
53 Tensor<OutCodeT, 4, true> outCodeDistances) {
// pqCentroids is indexed as [subQuantizer][dim][code] below, so its three
// sizes are read as (sub-quantizers, dims per sub-quantizer, codes).
54 const auto numSubQuantizers = pqCentroids.getSize(0);
55 const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
// The template argument must agree with the runtime tensor shape.
56 assert(DimsPerSubQuantizer == dimsPerSubQuantizer);
57 const auto codesPerSubQuantizer = pqCentroids.getSize(2);
// Threads past the code range only stage residual data; loadingThreadId is
// their index within that extra group.
59 bool isLoadingThread = threadIdx.x >= codesPerSubQuantizer;
60 int loadingThreadId = threadIdx.x - codesPerSubQuantizer;
// Dynamically sized shared memory, carved into sub-buffers further down.
62 extern __shared__
float smem[];
// Per-thread copy of the PQ centroid for (subQuantizer, code).
65 float subQuantizerData[DimsPerSubQuantizer];
67 auto code = threadIdx.x;
68 auto subQuantizer = blockIdx.y;
// Load this thread's centroid values through ldg().
// NOTE(review): no guard on `code < codesPerSubQuantizer` is visible here, so
// loading threads would also index pqCentroids with an out-of-range code --
// confirm whether that is intended / safe in the full file.
73 for (
int i = 0; i < DimsPerSubQuantizer; ++i) {
74 subQuantizerData[i] = pqCentroids[subQuantizer][i][code].ldg();
// Shared-memory layout: [query | residual buffer 1 | residual buffer 2 | ids].
78 float* smemQuery = smem;
82 float* smemResidual1 = &smemQuery[DimsPerSubQuantizer];
83 float* smemResidual2 = &smemResidual1[DimsPerSubQuantizer];
// The coarse centroid ids are stored as ints directly after the float buffers.
86 int* coarseIds = (
int*) &smemResidual2[DimsPerSubQuantizer];
// Range of queries handled by this block, clamped to the number remaining.
92 auto startQueryId = blockIdx.x * queriesPerBlock;
93 auto numQueries = queries.getSize(0) - startQueryId;
94 if (numQueries > queriesPerBlock) {
95 numQueries = queriesPerBlock;
98 for (
int query = 0; query < numQueries; ++query) {
99 auto queryId = startQueryId + query;
// Pointer to the slice of the query vector belonging to this sub-quantizer.
101 auto querySubQuantizer =
102 queries[queryId][subQuantizer * DimsPerSubQuantizer].data();
// All threads cooperatively copy the query slice into shared memory.
105 for (
int i = threadIdx.x; i < DimsPerSubQuantizer; i += blockDim.x) {
106 smemQuery[i] = querySubQuantizer[i];
// All threads cooperatively copy this query's coarse centroid ids.
110 for (
int i = threadIdx.x;
111 i < topQueryToCentroid.getSize(1); i += blockDim.x) {
112 coarseIds[i] = topQueryToCentroid[queryId][i];
// NOTE(review): coarseIds / smemQuery are read below by other threads than
// the ones that wrote them, but no __syncthreads() is visible in this
// excerpt -- a barrier line is presumably missing; confirm.
// Loading threads fill the first residual buffer for coarse entry 0.
120 if (isLoadingThread) {
121 for (
int i = loadingThreadId;
122 i < DimsPerSubQuantizer;
123 i += blockDim.x - codesPerSubQuantizer) {
124 auto coarseId = coarseIds[0];
// An id of -1 is remapped to centroid 0 before indexing.
126 coarseId = coarseId == -1 ? 0 : coarseId;
127 auto coarseCentroidSubQuantizer =
128 coarseCentroids[coarseId][subQuantizer * dimsPerSubQuantizer].data();
130 smemResidual1[i] = smemQuery[i] - coarseCentroidSubQuantizer[i];
// Walk every coarse centroid selected for this query.
135 for (
int coarse = 0; coarse < topQueryToCentroid.getSize(1); ++coarse) {
// Loading threads stage the residual for the next coarse entry into the
// second buffer (skipped on the last iteration).
139 if (isLoadingThread) {
141 for (
int i = loadingThreadId;
142 i < DimsPerSubQuantizer;
143 i += blockDim.x - codesPerSubQuantizer) {
146 if (coarse != (topQueryToCentroid.getSize(1) - 1)) {
147 auto coarseId = coarseIds[coarse + 1];
// An id of -1 is remapped to centroid 0 before indexing.
149 coarseId = coarseId == -1 ? 0 : coarseId;
151 auto coarseCentroidSubQuantizer =
152 coarseCentroids[coarseId][subQuantizer * dimsPerSubQuantizer].data();
154 smemResidual2[i] = smemQuery[i] - coarseCentroidSubQuantizer[i];
// NOTE(review): the branch separating loading threads from the arithmetic
// below (and the declarations of `vals` and `dist`) is not visible here.
// Dimensions are processed in groups of kUnroll, with a compile-time remainder.
161 constexpr
int kUnroll = 4;
162 constexpr
int kRemainder = DimsPerSubQuantizer % kUnroll;
163 constexpr
int kRemainderBase = DimsPerSubQuantizer - kRemainder;
// Main part: full groups of kUnroll dimensions.
171 for (
int i = 0; i < DimsPerSubQuantizer / kUnroll; ++i) {
// Read the current residual values from shared memory.
174 for (
int j = 0; j < kUnroll; ++j) {
175 vals[j] = smemResidual1[i * kUnroll + j];
// Subtract this thread's PQ centroid values.
179 for (
int j = 0; j < kUnroll; ++j) {
180 vals[j] -= subQuantizerData[i * kUnroll + j];
// NOTE(review): the bodies of the next two loops are not visible in this
// excerpt.
184 for (
int j = 0; j < kUnroll; ++j) {
189 for (
int j = 0; j < kUnroll; ++j) {
// Remainder part: the trailing DimsPerSubQuantizer % kUnroll dimensions,
// handled with the same sequence of steps.
196 for (
int j = 0; j < kRemainder; ++j) {
197 vals[j] = smemResidual1[kRemainderBase + j];
201 for (
int j = 0; j < kRemainder; ++j) {
202 vals[j] -= subQuantizerData[kRemainderBase + j];
// NOTE(review): the bodies of the next two loops are not visible in this
// excerpt.
206 for (
int j = 0; j < kRemainder; ++j) {
211 for (
int j = 0; j < kRemainder; ++j) {
// Store the result for this (query, coarse, sub-quantizer, code), converted to
// the output element type.
216 outCodeDistances[queryId][coarse][subQuantizer][code] =
217 Converter<OutCodeT>::to(dist);
// Exchange the two residual buffers for the next coarse iteration.
// NOTE(review): only the first two statements of the swap are visible; the
// assignment back into smemResidual2 is presumably on a missing line.
221 float* tmp = smemResidual1;
222 smemResidual1 = smemResidual2;
// Kernel that writes per-dimension values derived from a query and its
// selected coarse centroid into `residual`, laid out as
// [dim / numSubDim][queryId][centroidId][dim % numSubDim].
//
// Launch layout as used by the visible code:
//   blockIdx.x  - query index.
//   blockIdx.y  - position within topQueryToCentroid's second dimension.
//   threadIdx.x - starting dimension; threads step by blockDim.x over
//                 queries.getSize(1).
//
// NOTE(review): `numSubDim` is used below but is not declared in the visible
// parameter list, and the right-hand side of the final assignment is not
// visible (presumably `q - c`, given the kernel name) -- confirm against the
// full file. The leading integers look like line-number residue.
229 residualVector(Tensor<float, 2, true> queries,
230 Tensor<float, 2, true> coarseCentroids,
231 Tensor<int, 2, true> topQueryToCentroid,
235 Tensor<float, 4, true> residual) {
239 auto queryId = blockIdx.x;
240 auto centroidId = blockIdx.y;
// Translate the per-query slot into an actual coarse centroid row.
// NOTE(review): unlike pqCodeDistances, no remapping of a -1 id is visible
// here before it is used as an index -- confirm callers never pass -1.
242 int realCentroidId = topQueryToCentroid[queryId][centroidId];
244 for (
int dim = threadIdx.x; dim < queries.getSize(1); dim += blockDim.x) {
245 float q = queries[queryId][dim];
246 float c = coarseCentroids[realCentroidId][dim];
// Destination is grouped by sub-quantizer first (dim / numSubDim), with the
// offset inside the sub-quantizer (dim % numSubDim) innermost.
248 residual[dim / numSubDim][queryId][centroidId][dim % numSubDim] =
// Host-side launcher for the residualVector kernel on `stream`.
//
// The launch uses a 2-D grid sized by topQueryToCentroid's two dimensions and
// a 1-D block of min(queries.getSize(1), getMaxThreadsCurrentDevice())
// threads, with no dynamic shared memory. pqCentroids.getSize(1) is forwarded
// as the fourth kernel argument.
//
// NOTE(review): the return type, the `auto grid =` part of the grid
// declaration and the tail of the kernel argument list are not visible in
// this excerpt; no launch error check is visible either -- confirm against
// the full file. The leading integers look like line-number residue.
254 runResidualVector(Tensor<float, 3, true>& pqCentroids,
255 Tensor<float, 2, true>& queries,
256 Tensor<float, 2, true>& coarseCentroids,
257 Tensor<int, 2, true>& topQueryToCentroid,
258 Tensor<float, 4, true>& residual,
259 cudaStream_t stream) {
// One block per (query, selected coarse centroid) pair.
261 dim3(topQueryToCentroid.getSize(0), topQueryToCentroid.getSize(1));
// Never request more threads than there are dimensions or than the device
// allows.
262 auto block = dim3(std::min(queries.getSize(1), getMaxThreadsCurrentDevice()));
264 residualVector<<<grid, block, 0, stream>>>(
265 queries, coarseCentroids, topQueryToCentroid, pqCentroids.getSize(1),
// Host-side routine that fills outCodeDistances using a sequence of library
// style steps (residual computation, norms, an iterated matrix multiply,
// broadcast sums and transposes) instead of the fused pqCodeDistances kernel.
//
// Parameters (as used by the visible code):
//   pqCentroids        - 3-D float tensor; its three sizes are used below as
//                        (sub-quantizers, dims per sub-quantizer, codes).
//   queries            - 2-D float tensor forwarded to runResidualVector.
//   coarseCentroids    - 2-D float tensor forwarded to runResidualVector.
//   topQueryToCentroid - 2-D int tensor; both sizes drive temporary shapes.
//   outCodeDistances   - untyped 4-D output, viewed as float or half.
//   useFloat16Lookup   - when true (and FAISS_USE_FLOAT16 is defined) results
//                        are produced in a float temporary and converted to
//                        half at the end.
//   handle, stream     - cuBLAS handle and CUDA stream for the work.
//
// NOTE(review): `mem` is used below but is not in the visible parameter list,
// and several constructor / call argument lists, #else / #endif lines and
// closing braces are not visible. The leading integers look like line-number
// residue. Confirm against the full file before relying on details.
272 runPQCodeDistancesMM(Tensor<float, 3, true>& pqCentroids,
273 Tensor<float, 2, true>& queries,
274 Tensor<float, 2, true>& coarseCentroids,
275 Tensor<int, 2, true>& topQueryToCentroid,
276 NoTypeTensor<4, true>& outCodeDistances,
277 bool useFloat16Lookup,
279 cublasHandle_t handle,
280 cudaStream_t stream) {
// Temporary for the residuals, shaped
// (sub-quantizer, query, coarse slot, dim within sub-quantizer).
283 DeviceTensor<float, 4, true> residual(
285 {pqCentroids.getSize(0),
286 topQueryToCentroid.getSize(0),
287 topQueryToCentroid.getSize(1),
288 pqCentroids.getSize(1)},
// Fill the residual temporary (remaining arguments not visible here).
291 runResidualVector(pqCentroids, queries,
292 coarseCentroids, topQueryToCentroid,
// One norm value per (sub-quantizer, query, coarse slot) row.
296 DeviceTensor<float, 1, true> residualNorms(
298 {pqCentroids.getSize(0) *
299 topQueryToCentroid.getSize(0) *
300 topQueryToCentroid.getSize(1)},
// Flatten the residuals to 2-D so each row is one sub-vector.
303 auto residualView2 = residual.view<2>(
304 {pqCentroids.getSize(0) *
305 topQueryToCentroid.getSize(0) *
306 topQueryToCentroid.getSize(1),
307 pqCentroids.getSize(1)});
// NOTE(review): the `true` flag presumably requests squared norms -- confirm
// against runL2Norm's declaration.
309 runL2Norm(residualView2, residualNorms,
true, stream);
// 3-D view grouping all (query, coarse slot) rows per sub-quantizer, used as
// an operand of the iterated matrix multiply.
314 auto residualView3 = residual.view<3>(
315 {pqCentroids.getSize(0),
316 topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
317 pqCentroids.getSize(1)});
// Result of the multiply: (sub-quantizer, query * coarse slot, code).
319 DeviceTensor<float, 3, true> residualDistance(
321 {pqCentroids.getSize(0),
322 topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
323 pqCentroids.getSize(2)},
// NOTE(review): the remaining arguments of this call (second operand, scaling
// factors, handle, stream) are not visible in this excerpt.
326 runIteratedMatrixMult(residualDistance,
false,
327 residualView3,
false,
// 2-D view with one row per (sub-quantizer, query, coarse slot), matching the
// length of residualNorms.
334 auto residualDistanceView2 = residualDistance.view<2>(
335 {pqCentroids.getSize(0) *
336 topQueryToCentroid.getSize(0) *
337 topQueryToCentroid.getSize(1),
338 pqCentroids.getSize(2)});
// Combine the residual norms with the multiply result via runSumAlongRows.
340 runSumAlongRows(residualNorms, residualDistanceView2, stream);
// outCodeDistancesF is the float tensor the remaining steps write to; it
// aliases either a float temporary or the caller's output.
342 Tensor<float, 4, true> outCodeDistancesF;
343 DeviceTensor<float, 4, true> outCodeDistancesFloatMem;
// Half output requested: work in a float temporary of the same shape.
345 #ifdef FAISS_USE_FLOAT16
346 if (useFloat16Lookup) {
347 outCodeDistancesFloatMem = DeviceTensor<float, 4, true>(
348 mem, {outCodeDistances.getSize(0),
349 outCodeDistances.getSize(1),
350 outCodeDistances.getSize(2),
351 outCodeDistances.getSize(3)},
354 outCodeDistancesF = outCodeDistancesFloatMem;
// Float output requested: write straight into the caller's tensor.
358 if (!useFloat16Lookup) {
359 outCodeDistancesF = outCodeDistances.toTensor<
float>();
// View the output as (query * coarse slot, dim 2, dim 3) and transpose the
// first two dimensions of residualDistance into it.
364 auto outCodeDistancesView = outCodeDistancesF.view<3>(
365 {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
366 outCodeDistances.getSize(2),
367 outCodeDistances.getSize(3)});
369 runTransposeAny(residualDistance, 0, 1, outCodeDistancesView, stream);
// Transposed copy of pqCentroids with its last two dimensions exchanged, so
// each (sub-quantizer, code) pair becomes a contiguous row of dims.
374 DeviceTensor<float, 3, true> pqCentroidsTranspose(
376 {pqCentroids.getSize(0), pqCentroids.getSize(2), pqCentroids.getSize(1)},
379 runTransposeAny(pqCentroids, 1, 2, pqCentroidsTranspose, stream);
381 auto pqCentroidsTransposeView = pqCentroidsTranspose.view<2>(
382 {pqCentroids.getSize(0) * pqCentroids.getSize(2),
383 pqCentroids.getSize(1)});
// One norm value per (sub-quantizer, code) pair.
385 DeviceTensor<float, 1, true> pqCentroidsNorm(
387 {pqCentroids.getSize(0) * pqCentroids.getSize(2)},
390 runL2Norm(pqCentroidsTransposeView, pqCentroidsNorm,
true, stream);
// Flatten the output to (query * coarse slot, dim 2 * dim 3) so the centroid
// norms can be combined in with runSumAlongColumns.
394 auto outDistancesCodeViewCols = outCodeDistancesView.view<2>(
395 {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
396 outCodeDistances.getSize(2) * outCodeDistances.getSize(3)});
398 runSumAlongColumns(pqCentroidsNorm, outDistancesCodeViewCols, stream);
// Half output requested: convert the float temporary into the caller's
// half-typed tensor.
400 #ifdef FAISS_USE_FLOAT16
401 if (useFloat16Lookup) {
403 auto outCodeDistancesH = outCodeDistances.toTensor<half>();
404 toHalf(stream, outCodeDistancesF, outCodeDistancesH);
410 runPQCodeDistances(Tensor<float, 3, true>& pqCentroids,
411 Tensor<float, 2, true>& queries,
412 Tensor<float, 2, true>& coarseCentroids,
413 Tensor<int, 2, true>& topQueryToCentroid,
414 NoTypeTensor<4, true>& outCodeDistances,
415 bool useFloat16Lookup,
416 cudaStream_t stream) {
417 const auto numSubQuantizers = pqCentroids.getSize(0);
418 const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
419 const auto codesPerSubQuantizer = pqCentroids.getSize(2);
424 constexpr
int kQueriesPerBlock = 8;
426 auto grid = dim3(utils::divUp(queries.getSize(0), kQueriesPerBlock),
431 auto loadingThreads = utils::roundUp(dimsPerSubQuantizer, kWarpSize);
432 auto block = dim3(codesPerSubQuantizer + loadingThreads);
434 auto smem = (3 * dimsPerSubQuantizer) *
sizeof(
float)
435 + topQueryToCentroid.getSize(1) *
sizeof(int);
437 #ifdef FAISS_USE_FLOAT16
438 #define CODE_DISTANCE(DIMS) \
440 if (useFloat16Lookup) { \
441 auto outCodeDistancesT = outCodeDistances.toTensor<half>(); \
443 pqCodeDistances<half, DIMS><<<grid, block, smem, stream>>>( \
444 queries, kQueriesPerBlock, \
445 coarseCentroids, pqCentroids, \
446 topQueryToCentroid, outCodeDistancesT); \
448 auto outCodeDistancesT = outCodeDistances.toTensor<float>(); \
450 pqCodeDistances<float, DIMS><<<grid, block, smem, stream>>>( \
451 queries, kQueriesPerBlock, \
452 coarseCentroids, pqCentroids, \
453 topQueryToCentroid, outCodeDistancesT); \
457 #define CODE_DISTANCE(DIMS) \
459 if (!useFloat16Lookup) { \
460 auto outCodeDistancesT = outCodeDistances.toTensor<float>(); \
462 pqCodeDistances<float, DIMS><<<grid, block, smem, stream>>>( \
463 queries, kQueriesPerBlock, \
464 coarseCentroids, pqCentroids, \
465 topQueryToCentroid, outCodeDistancesT); \
470 switch (dimsPerSubQuantizer) {