9 #include "PQCodeDistances.cuh"
11 #include "BroadcastSum.cuh"
12 #include "Distance.cuh"
14 #include "../utils/DeviceDefs.cuh"
15 #include "../utils/DeviceUtils.h"
16 #include "../utils/Float16.cuh"
17 #include "../utils/MatrixMult.cuh"
18 #include "../utils/PtxUtils.cuh"
19 #include "../utils/StaticUtils.h"
20 #include "../utils/Transpose.cuh"
22 namespace faiss {
namespace gpu {
28 #ifdef FAISS_USE_FLOAT16
// Device-only conversion of a float value to half precision via
// __float2half(). The kernel below invokes this as Converter<OutCodeT>::to();
// the enclosing Converter declaration is not visible in this excerpt.
31 inline static __device__ half to(
float v) {
return __float2half(v); }
// Device-only identity conversion: returns the float argument unchanged.
// Counterpart of the half-returning to() above, so callers can write the
// same Converter<OutCodeT>::to(value) expression for either output type.
37 inline static __device__
float to(
float v) {
return v; }
// Kernel that produces one output value per (query, coarse centroid,
// sub-quantizer, code) and stores it in
// outCodeDistances[queryId][coarse][subQuantizer][code].
//
// Layout assumed by the code below:
//   - blockIdx.y selects the sub-quantizer.
//   - blockIdx.x selects a group of `queriesPerBlock` consecutive queries.
//   - Threads with threadIdx.x < codesPerSubQuantizer each handle the code
//     equal to their threadIdx.x; the remaining threads ("loading threads")
//     fill the shared-memory residual buffers.
//   - Dynamic shared memory holds 3 * DimsPerSubQuantizer floats (query,
//     residual buffer 1, residual buffer 2) followed by
//     topQueryToCentroid.getSize(1) ints (coarse ids).
//
// Template parameters:
//   OutCodeT            element type of outCodeDistances; values are produced
//                       through Converter<OutCodeT>::to().
//   DimsPerSubQuantizer compile-time copy of pqCentroids.getSize(1).
//
// NOTE(review): `queriesPerBlock`, `vals` and `dist` are used below but their
// declarations are not visible in this excerpt (the launch sites pass
// kQueriesPerBlock as the second argument) -- confirm against the full file.
42 template <
typename OutCodeT,
int DimsPerSubQuantizer>
// Launch bounds: at most 288 threads per block, with a request for at least
// 4 resident blocks per multiprocessor.
44 __launch_bounds__(288, 4)
45 pqCodeDistances(
Tensor<
float, 2, true> queries,
47 Tensor<
float, 2, true> coarseCentroids,
48 Tensor<
float, 3, true> pqCentroids,
49 Tensor<
int, 2, true> topQueryToCentroid,
51 Tensor<OutCodeT, 4, true> outCodeDistances) {
// pqCentroids is indexed as [sub-quantizer][dimension][code].
52 const auto numSubQuantizers = pqCentroids.getSize(0);
53 const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
// The template parameter must agree with the runtime tensor size.
54 assert(DimsPerSubQuantizer == dimsPerSubQuantizer);
55 const auto codesPerSubQuantizer = pqCentroids.getSize(2);
// Threads past the first codesPerSubQuantizer act as loaders; loadingThreadId
// is their zero-based index within that group.
57 bool isLoadingThread = threadIdx.x >= codesPerSubQuantizer;
58 int loadingThreadId = threadIdx.x - codesPerSubQuantizer;
// Dynamically sized shared memory; partitioned further below.
60 extern __shared__
float smem[];
// Per-thread copy of the PQ centroid for this thread's code.
63 float subQuantizerData[DimsPerSubQuantizer];
65 auto code = threadIdx.x;
66 auto subQuantizer = blockIdx.y;
// Read this thread's PQ centroid components through the .ldg() accessor.
// NOTE(review): no isLoadingThread guard is visible around this loop, so
// loading threads would index pqCentroids with code >= codesPerSubQuantizer
// -- verify that this read stays in bounds.
71 for (
int i = 0; i < DimsPerSubQuantizer; ++i) {
72 subQuantizerData[i] = pqCentroids[subQuantizer][i][code].ldg();
// Shared-memory partition: the query slice comes first ...
76 float* smemQuery = smem;
// ... then two residual buffers of DimsPerSubQuantizer floats each; one is
// read while the other is filled for the next coarse centroid ...
80 float* smemResidual1 = &smemQuery[DimsPerSubQuantizer];
81 float* smemResidual2 = &smemResidual1[DimsPerSubQuantizer];
// ... and finally the coarse centroid ids for the current query, as ints.
84 int* coarseIds = (
int*) &smemResidual2[DimsPerSubQuantizer];
// This block handles up to queriesPerBlock queries starting at startQueryId;
// the count is clamped for the last block.
90 auto startQueryId = blockIdx.x * queriesPerBlock;
91 auto numQueries = queries.getSize(0) - startQueryId;
92 if (numQueries > queriesPerBlock) {
93 numQueries = queriesPerBlock;
96 for (
int query = 0; query < numQueries; ++query) {
97 auto queryId = startQueryId + query;
// Pointer to the slice of this query belonging to the current sub-quantizer.
99 auto querySubQuantizer =
100 queries[queryId][subQuantizer * DimsPerSubQuantizer].data();
// All threads cooperatively copy the query slice into shared memory.
103 for (
int i = threadIdx.x; i < DimsPerSubQuantizer; i += blockDim.x) {
104 smemQuery[i] = querySubQuantizer[i];
// All threads cooperatively copy this query's coarse centroid ids.
108 for (
int i = threadIdx.x;
109 i < topQueryToCentroid.getSize(1); i += blockDim.x) {
110 coarseIds[i] = topQueryToCentroid[queryId][i];
// NOTE(review): smemQuery and coarseIds are read below by threads other than
// the ones that wrote them, which requires a block-wide barrier here; none is
// visible in this excerpt -- confirm it exists in the full file.
// Loading threads fill the first residual buffer for coarse index 0:
// residual = query slice - coarse centroid slice.
118 if (isLoadingThread) {
119 for (
int i = loadingThreadId;
120 i < DimsPerSubQuantizer;
121 i += blockDim.x - codesPerSubQuantizer) {
122 auto coarseId = coarseIds[0];
// An id of -1 is remapped to centroid 0 so the lookup below uses a valid row.
124 coarseId = coarseId == -1 ? 0 : coarseId;
125 auto coarseCentroidSubQuantizer =
126 coarseCentroids[coarseId][subQuantizer * dimsPerSubQuantizer].data();
128 smemResidual1[i] = smemQuery[i] - coarseCentroidSubQuantizer[i];
// Walk every coarse centroid selected for this query.
133 for (
int coarse = 0; coarse < topQueryToCentroid.getSize(1); ++coarse) {
// NOTE(review): smemResidual1 must be fully written before it is read below;
// the required barrier is not visible in this excerpt -- confirm.
137 if (isLoadingThread) {
// Loading threads fill the other buffer with the residual for coarse + 1.
139 for (
int i = loadingThreadId;
140 i < DimsPerSubQuantizer;
141 i += blockDim.x - codesPerSubQuantizer) {
// Nothing to prefetch after the last coarse centroid.
144 if (coarse != (topQueryToCentroid.getSize(1) - 1)) {
145 auto coarseId = coarseIds[coarse + 1];
// Same -1 -> 0 remap as for the first buffer.
147 coarseId = coarseId == -1 ? 0 : coarseId;
149 auto coarseCentroidSubQuantizer =
150 coarseCentroids[coarseId][subQuantizer * dimsPerSubQuantizer].data();
152 smemResidual2[i] = smemQuery[i] - coarseCentroidSubQuantizer[i];
// Dimensions are handled in groups of kUnroll, followed by a tail of
// kRemainder dimensions starting at index kRemainderBase.
159 constexpr
int kUnroll = 4;
160 constexpr
int kRemainder = DimsPerSubQuantizer % kUnroll;
161 constexpr
int kRemainderBase = DimsPerSubQuantizer - kRemainder;
// Main loop over full groups of kUnroll dimensions.
169 for (
int i = 0; i < DimsPerSubQuantizer / kUnroll; ++i) {
// Load the current residual components ...
172 for (
int j = 0; j < kUnroll; ++j) {
173 vals[j] = smemResidual1[i * kUnroll + j];
// ... and subtract this thread's PQ centroid components.
177 for (
int j = 0; j < kUnroll; ++j) {
178 vals[j] -= subQuantizerData[i * kUnroll + j];
// NOTE(review): the bodies of the next two loops are not visible in this
// excerpt; presumably they fold vals[] into `dist` -- confirm.
182 for (
int j = 0; j < kUnroll; ++j) {
187 for (
int j = 0; j < kUnroll; ++j) {
// Tail: same sequence for the last kRemainder dimensions.
194 for (
int j = 0; j < kRemainder; ++j) {
195 vals[j] = smemResidual1[kRemainderBase + j];
199 for (
int j = 0; j < kRemainder; ++j) {
200 vals[j] -= subQuantizerData[kRemainderBase + j];
// NOTE(review): bodies of the next two loops are likewise not visible here.
204 for (
int j = 0; j < kRemainder; ++j) {
209 for (
int j = 0; j < kRemainder; ++j) {
// Store the result for this (query, coarse, sub-quantizer, code), converted
// to the output element type.
214 outCodeDistances[queryId][coarse][subQuantizer][code] =
215 Converter<OutCodeT>::to(dist);
// Exchange the roles of the two residual buffers for the next iteration
// (the final assignment of the swap is not visible in this excerpt).
219 float* tmp = smemResidual1;
220 smemResidual1 = smemResidual2;
// Kernel that writes per-dimension data for each (query, coarse centroid)
// pair into `residual`, split by sub-quantizer: element `dim` of the full
// vector lands at residual[dim / numSubDim][queryId][centroidId][dim % numSubDim].
//
// Layout: blockIdx.x is the query id and blockIdx.y indexes the query's list
// of coarse centroids; threads stride over the vector dimensions.
//
// NOTE(review): `numSubDim` and the right-hand side of the store are not
// visible in this excerpt. The launcher passes pqCentroids.getSize(1) after
// topQueryToCentroid, and the name suggests the stored value is q - c --
// confirm against the full file.
227 residualVector(Tensor<float, 2, true> queries,
228 Tensor<float, 2, true> coarseCentroids,
229 Tensor<int, 2, true> topQueryToCentroid,
233 Tensor<float, 4, true> residual) {
237 auto queryId = blockIdx.x;
238 auto centroidId = blockIdx.y;
// Translate the per-query list position into an actual coarse centroid row.
// NOTE(review): pqCodeDistances remaps an id of -1 to 0 before indexing
// coarseCentroids; no such remap is visible here -- confirm whether -1 can
// reach this kernel.
240 int realCentroidId = topQueryToCentroid[queryId][centroidId];
// Each thread handles dimensions threadIdx.x, threadIdx.x + blockDim.x, ...
242 for (
int dim = threadIdx.x; dim < queries.getSize(1); dim += blockDim.x) {
243 float q = queries[queryId][dim];
244 float c = coarseCentroids[realCentroidId][dim];
// Scatter into the sub-quantizer-major layout.
246 residual[dim / numSubDim][queryId][centroidId][dim % numSubDim] =
// Host-side launcher for the residualVector kernel on `stream`.
//
// Uses one block per entry of topQueryToCentroid (grid x = getSize(0),
// grid y = getSize(1)) and no dynamic shared memory. pqCentroids is only
// consulted for getSize(1), which is forwarded to the kernel.
//
// NOTE(review): the declaration of `grid` and the tail of the kernel argument
// list are not visible in this excerpt.
252 runResidualVector(Tensor<float, 3, true>& pqCentroids,
253 Tensor<float, 2, true>& queries,
254 Tensor<float, 2, true>& coarseCentroids,
255 Tensor<int, 2, true>& topQueryToCentroid,
256 Tensor<float, 4, true>& residual,
257 cudaStream_t stream) {
259 dim3(topQueryToCentroid.getSize(0), topQueryToCentroid.getSize(1));
// One thread per query dimension, capped at the current device's maximum
// threads per block; the kernel strides if the vector is longer.
260 auto block = dim3(std::min(queries.getSize(1), getMaxThreadsCurrentDevice()));
262 residualVector<<<grid, block, 0, stream>>>(
263 queries, coarseCentroids, topQueryToCentroid, pqCentroids.getSize(1),
// Host-side routine that fills outCodeDistances through a sequence of library
// calls (residual computation, row norms, iterated matrix multiply,
// transposes and broadcast sums) instead of the pqCodeDistances kernel.
//
// Sizes used throughout:
//   pqCentroids.getSize(0/1/2)      sub-quantizers / dims per sub-quantizer /
//                                   codes per sub-quantizer
//   topQueryToCentroid.getSize(0/1) queries / coarse centroids per query
//
// When useFloat16Lookup is set (and FAISS_USE_FLOAT16 is defined) the result
// is accumulated in a temporary float tensor and converted to half at the end;
// otherwise outCodeDistances is written directly as float.
//
// NOTE(review): presumably this expands ||r - c||^2 into norm and dot-product
// terms; the scaling arguments passed to the matrix multiply are not visible
// in this excerpt, so that is unconfirmed. `mem`, used below, is also not
// visible in the parameter list shown here.
270 runPQCodeDistancesMM(Tensor<float, 3, true>& pqCentroids,
271 Tensor<float, 2, true>& queries,
272 Tensor<float, 2, true>& coarseCentroids,
273 Tensor<int, 2, true>& topQueryToCentroid,
274 NoTypeTensor<4, true>& outCodeDistances,
275 bool useFloat16Lookup,
277 cublasHandle_t handle,
278 cudaStream_t stream) {
// Residual storage laid out as (sub-quantizer)(query)(coarse)(sub-dim).
281 DeviceTensor<float, 4, true> residual(
283 {pqCentroids.getSize(0),
284 topQueryToCentroid.getSize(0),
285 topQueryToCentroid.getSize(1),
286 pqCentroids.getSize(1)},
// Fill `residual` via the residualVector kernel.
289 runResidualVector(pqCentroids, queries,
290 coarseCentroids, topQueryToCentroid,
// One norm slot per (sub-quantizer, query, coarse) row of the residual.
294 DeviceTensor<float, 1, true> residualNorms(
296 {pqCentroids.getSize(0) *
297 topQueryToCentroid.getSize(0) *
298 topQueryToCentroid.getSize(1)},
// Flatten the residual to (sub-quantizer * query * coarse) x (sub-dim) so
// there is one row per norm slot.
301 auto residualView2 = residual.view<2>(
302 {pqCentroids.getSize(0) *
303 topQueryToCentroid.getSize(0) *
304 topQueryToCentroid.getSize(1),
305 pqCentroids.getSize(1)});
// The meaning of the two boolean flags is defined by runL2Norm (not visible
// here).
307 runL2Norm(residualView2,
true, residualNorms,
true, stream);
// Same residual data viewed as (sub-quantizer)(query * coarse)(sub-dim), one
// matrix per sub-quantizer.
312 auto residualView3 = residual.view<3>(
313 {pqCentroids.getSize(0),
314 topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
315 pqCentroids.getSize(1)});
// Matrix-multiply output: (sub-quantizer)(query * coarse)(code).
317 DeviceTensor<float, 3, true> residualDistance(
319 {pqCentroids.getSize(0),
320 topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
321 pqCentroids.getSize(2)},
// Per-sub-quantizer matrix multiply into residualDistance; the remaining
// arguments are not visible in this excerpt.
324 runIteratedMatrixMult(residualDistance,
false,
325 residualView3,
false,
// Flatten to (sub-quantizer * query * coarse) x (code), giving one row per
// residualNorms entry.
332 auto residualDistanceView2 = residualDistance.view<2>(
333 {pqCentroids.getSize(0) *
334 topQueryToCentroid.getSize(0) *
335 topQueryToCentroid.getSize(1),
336 pqCentroids.getSize(2)});
// Broadcast-add residualNorms along rows of the flattened view.
338 runSumAlongRows(residualNorms, residualDistanceView2,
false, stream);
// Float tensor that the remaining steps write into; it is either temporary
// storage (half output) or an alias of the caller's output (float output).
340 Tensor<float, 4, true> outCodeDistancesF;
341 DeviceTensor<float, 4, true> outCodeDistancesFloatMem;
343 #ifdef FAISS_USE_FLOAT16
344 if (useFloat16Lookup) {
// Half output requested: stage the result in float memory of the same shape.
345 outCodeDistancesFloatMem = DeviceTensor<float, 4, true>(
346 mem, {outCodeDistances.getSize(0),
347 outCodeDistances.getSize(1),
348 outCodeDistances.getSize(2),
349 outCodeDistances.getSize(3)},
352 outCodeDistancesF = outCodeDistancesFloatMem;
// Float output requested: write straight into the caller's tensor.
356 if (!useFloat16Lookup) {
357 outCodeDistancesF = outCodeDistances.toTensor<
float>();
// Merge the first two output dimensions, giving
// (query * coarse)(sub-quantizer)(code) ...
362 auto outCodeDistancesView = outCodeDistancesF.view<3>(
363 {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
364 outCodeDistances.getSize(2),
365 outCodeDistances.getSize(3)});
// ... and transpose dimensions 0 and 1 of residualDistance into that layout.
367 runTransposeAny(residualDistance, 0, 1, outCodeDistancesView, stream);
// PQ centroids rearranged to (sub-quantizer)(code)(sub-dim) so each code's
// vector forms a contiguous row.
372 DeviceTensor<float, 3, true> pqCentroidsTranspose(
374 {pqCentroids.getSize(0), pqCentroids.getSize(2), pqCentroids.getSize(1)},
377 runTransposeAny(pqCentroids, 1, 2, pqCentroidsTranspose, stream);
// Flatten to (sub-quantizer * code) x (sub-dim), one row per code.
379 auto pqCentroidsTransposeView = pqCentroidsTranspose.view<2>(
380 {pqCentroids.getSize(0) * pqCentroids.getSize(2),
381 pqCentroids.getSize(1)});
// One norm slot per (sub-quantizer, code).
383 DeviceTensor<float, 1, true> pqCentroidsNorm(
385 {pqCentroids.getSize(0) * pqCentroids.getSize(2)},
388 runL2Norm(pqCentroidsTransposeView,
true, pqCentroidsNorm,
true, stream);
// View the output as (query * coarse) x (sub-quantizer * code) so its columns
// line up with the pqCentroidsNorm entries.
392 auto outDistancesCodeViewCols = outCodeDistancesView.view<2>(
393 {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
394 outCodeDistances.getSize(2) * outCodeDistances.getSize(3)});
// Broadcast-add the per-code norms along columns.
396 runSumAlongColumns(pqCentroidsNorm, outDistancesCodeViewCols, stream);
398 #ifdef FAISS_USE_FLOAT16
399 if (useFloat16Lookup) {
// Convert the staged float result into the caller's half tensor.
401 auto outCodeDistancesH = outCodeDistances.toTensor<half>();
402 toHalf(stream, outCodeDistancesF, outCodeDistancesH);
408 runPQCodeDistances(Tensor<float, 3, true>& pqCentroids,
409 Tensor<float, 2, true>& queries,
410 Tensor<float, 2, true>& coarseCentroids,
411 Tensor<int, 2, true>& topQueryToCentroid,
412 NoTypeTensor<4, true>& outCodeDistances,
413 bool useFloat16Lookup,
414 cudaStream_t stream) {
415 const auto numSubQuantizers = pqCentroids.getSize(0);
416 const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
417 const auto codesPerSubQuantizer = pqCentroids.getSize(2);
422 constexpr
int kQueriesPerBlock = 8;
424 auto grid = dim3(utils::divUp(queries.getSize(0), kQueriesPerBlock),
429 auto loadingThreads = utils::roundUp(dimsPerSubQuantizer, kWarpSize);
430 auto block = dim3(codesPerSubQuantizer + loadingThreads);
432 auto smem = (3 * dimsPerSubQuantizer) *
sizeof(
float)
433 + topQueryToCentroid.getSize(1) *
sizeof(int);
435 #ifdef FAISS_USE_FLOAT16
436 #define CODE_DISTANCE(DIMS) \
438 if (useFloat16Lookup) { \
439 auto outCodeDistancesT = outCodeDistances.toTensor<half>(); \
441 pqCodeDistances<half, DIMS><<<grid, block, smem, stream>>>( \
442 queries, kQueriesPerBlock, \
443 coarseCentroids, pqCentroids, \
444 topQueryToCentroid, outCodeDistancesT); \
446 auto outCodeDistancesT = outCodeDistances.toTensor<float>(); \
448 pqCodeDistances<float, DIMS><<<grid, block, smem, stream>>>( \
449 queries, kQueriesPerBlock, \
450 coarseCentroids, pqCentroids, \
451 topQueryToCentroid, outCodeDistancesT); \
455 #define CODE_DISTANCE(DIMS) \
457 if (!useFloat16Lookup) { \
458 auto outCodeDistancesT = outCodeDistances.toTensor<float>(); \
460 pqCodeDistances<float, DIMS><<<grid, block, smem, stream>>>( \
461 queries, kQueriesPerBlock, \
462 coarseCentroids, pqCentroids, \
463 topQueryToCentroid, outCodeDistancesT); \
468 switch (dimsPerSubQuantizer) {