10 #include "Distance.cuh"
11 #include "BroadcastSum.cuh"
13 #include "L2Select.cuh"
14 #include "../../FaissAssert.h"
15 #include "../GpuResources.h"
16 #include "../utils/DeviceUtils.h"
17 #include "../utils/Limits.cuh"
18 #include "../utils/MatrixMult.cuh"
19 #include "../utils/BlockSelectKernel.cuh"
22 #include <thrust/fill.h>
23 #include <thrust/for_each.h>
24 #include <thrust/device_ptr.h>
25 #include <thrust/execution_policy.h>
27 namespace faiss {
namespace gpu {
// Returns a view covering `num` centroids starting at `startCentroid`.
// When a transposed copy of the centroids is supplied, the view is taken from
// that copy (narrowed along dimension 1); otherwise it is taken from
// `centroids` (narrowed along dimension 0).
32 Tensor<T, 2, true> sliceCentroids(Tensor<T, 2, true>& centroids,
33 Tensor<T, 2, true>* centroidsTransposed,
// The request spans every centroid, so the transposed tensor (if any) can be
// handed back without narrowing
36 if (startCentroid == 0 && num == centroids.getSize(0)) {
37 if (centroidsTransposed) {
38 return *centroidsTransposed;
// Partial range: narrow whichever layout is in use
44 if (centroidsTransposed) {
// Transposed layout: the centroid index is dimension 1
46 return centroidsTransposed->narrow(1, startCentroid, num);
// Non-transposed layout: the centroid index is dimension 0
48 return centroids.narrow(0, startCentroid, num);
// Kernel that offsets the entries of `indices` in place.
// Launch layout as used by the indexing below: blockIdx.y selects the row of
// `indices`, blockIdx.x selects a group of k consecutive columns beginning at
// column blockIdx.x * k. Every entry in that group has
// blockIdx.x * increment added to it.
54 __global__
void incrementIndex(Tensor<T, 2, true> indices,
// The threads of a block cover the k entries of their group, stepping by
// blockDim.x, so any block size up to and beyond k is handled
57 for (
int i = threadIdx.x; i < k; i += blockDim.x) {
58 indices[blockIdx.y][blockIdx.x * k + i] += blockIdx.x * increment;
65 void runIncrementIndex(Tensor<T, 2, true>& indices,
68 cudaStream_t stream) {
69 dim3 grid(indices.getSize(1) / k, indices.getSize(0));
70 int block = std::min(k, 512);
73 FAISS_ASSERT(grid.x * k == indices.getSize(1));
75 incrementIndex<<<grid, block, 0, stream>>>(indices, k, increment);
77 cudaDeviceSynchronize();
// Chooses the tile dimensions (tileRows x tileCols) for processing
// numQueries x numCentroids work in pieces. The budget is derived from the
// total global memory of the current device.
// NOTE(review): tempMemAvailable is accepted as a parameter; confirm whether
// it is meant to influence the budget.
83 void chooseTileSize(
int numQueries,
87 size_t tempMemAvailable,
98 auto totalMem = getCurrentDeviceProperties().totalGlobalMem;
// Scale the budget with device memory: devices with at most 4 GiB get
// 512 MiB, devices with at most 8 GiB get 768 MiB, and 1 GiB is the largest
// value assigned
102 if (totalMem <= ((
size_t) 4) * 1024 * 1024 * 1024) {
103 targetUsage = 512 * 1024 * 1024;
104 }
else if (totalMem <= ((
size_t) 8) * 1024 * 1024 * 1024) {
105 targetUsage = 768 * 1024 * 1024;
107 targetUsage = 1024 * 1024 * 1024;
// Convert the byte budget into an element count, halved as well
// (presumably because the caller allocates two distance buffers -- confirm)
110 targetUsage /= 2 * elementSize;
// Default row-tile height; it is raised to 1024 in some cases
// NOTE(review): confirm the condition under which 1024 is selected
116 int preferredTileRows = 512;
118 preferredTileRows = 1024;
// Never use more rows than there are queries
121 tileRows = std::min(preferredTileRows, numQueries);
// Columns take whatever remains of the element budget at the preferred row
// count, capped at the number of centroids
124 tileCols = std::min(targetUsage / preferredTileRows, numCentroids);
// Shared implementation behind the L2 and inner-product searches.
// For each row of `queries`, k results are selected from the rows of
// `centroids` and written to outDistances / outIndices (each sized
// [numQueries x k], as asserted below).
//
// computeL2 switches between the two modes: when true, centroid norms are
// required (computed here if not supplied) and the matrix product is scaled
// by -2; when false the product is scaled by 1.
// centroidsTransposed and centroidNorms are optional (may be null).
//
// The work is tiled over queries and centroids and alternates between two
// sets of temporary buffers / alternate streams.
129 template <
typename T>
130 void runDistance(
bool computeL2,
131 GpuResources* resources,
132 Tensor<T, 2, true>& centroids,
133 Tensor<T, 2, true>* centroidsTransposed,
134 Tensor<T, 1, true>* centroidNorms,
135 Tensor<T, 2, true>& queries,
137 Tensor<T, 2, true>& outDistances,
138 Tensor<int, 2, true>& outIndices,
140 bool ignoreOutDistances) {
// Outputs must have one row per query and exactly k columns
141 FAISS_ASSERT(outDistances.getSize(0) == queries.getSize(0));
142 FAISS_ASSERT(outIndices.getSize(0) == queries.getSize(0));
143 FAISS_ASSERT(outDistances.getSize(1) == k);
144 FAISS_ASSERT(outIndices.getSize(1) == k);
146 auto& mem = resources->getMemoryManagerCurrentDevice();
147 auto defaultStream = resources->getDefaultStreamCurrentDevice();
// Nothing to search against: fill the outputs with sentinel values on the
// default stream (distances get Limits<T>::getMax())
150 if (centroids.numElements() == 0) {
151 thrust::fill(thrust::cuda::par.on(defaultStream),
152 outDistances.data(), outDistances.end(),
153 Limits<T>::getMax());
155 thrust::fill(thrust::cuda::par.on(defaultStream),
156 outIndices.data(), outIndices.end(),
// L2 mode without caller-supplied centroid norms: compute them into a
// temporary and point centroidNorms at it for the rest of the function
163 DeviceTensor<T, 1, true> cNorms;
164 if (computeL2 && !centroidNorms) {
165 cNorms = std::move(DeviceTensor<T, 1, true>(
167 {centroids.getSize(0)}, defaultStream));
168 runL2Norm(centroids, cNorms,
true, defaultStream);
169 centroidNorms = &cNorms;
// One norm per query row, computed with runL2Norm on the default stream
175 int qNormSize[1] = {queries.getSize(0)};
176 DeviceTensor<T, 1, true> queryNorms(mem, qNormSize, defaultStream);
180 runL2Norm(queries, queryNorms,
true, defaultStream);
// Pick the query-tile and centroid-tile sizes
187 chooseTileSize(queries.getSize(0),
188 centroids.getSize(0),
191 mem.getSizeAvailable(),
// Number of centroid (column) tiles, rounding up for a partial last tile
195 int numColTiles = utils::divUp(centroids.getSize(0), tileCols);
197 FAISS_ASSERT(k <= centroids.getSize(0));
198 FAISS_ASSERT(k <= 1024);
// Temporary buffers come in pairs; the pair member in use is picked by
// curStream, which alternates between 0 and 1 per query tile.
// distanceBuf*: one [tileRows x tileCols] tile of the matrix product
201 DeviceTensor<T, 2, true> distanceBuf1(
202 mem, {tileRows, tileCols}, defaultStream);
203 DeviceTensor<T, 2, true> distanceBuf2(
204 mem, {tileRows, tileCols}, defaultStream);
205 DeviceTensor<T, 2, true>* distanceBufs[2] =
206 {&distanceBuf1, &distanceBuf2};
// outDistanceBuf* / outIndexBuf*: k results per column tile, laid out side
// by side ([tileRows x numColTiles * k])
208 DeviceTensor<T, 2, true> outDistanceBuf1(
209 mem, {tileRows, numColTiles * k}, defaultStream);
210 DeviceTensor<T, 2, true> outDistanceBuf2(
211 mem, {tileRows, numColTiles * k}, defaultStream);
212 DeviceTensor<T, 2, true>* outDistanceBufs[2] =
213 {&outDistanceBuf1, &outDistanceBuf2};
215 DeviceTensor<int, 2, true> outIndexBuf1(
216 mem, {tileRows, numColTiles * k}, defaultStream);
217 DeviceTensor<int, 2, true> outIndexBuf2(
218 mem, {tileRows, numColTiles * k}, defaultStream);
219 DeviceTensor<int, 2, true>* outIndexBufs[2] =
220 {&outIndexBuf1, &outIndexBuf2};
// The alternate streams must not start before work already queued on the
// default stream (norms, buffer setup) has completed
222 auto streams = resources->getAlternateStreamsCurrentDevice();
223 streamWait(streams, {defaultStream});
// Outer loop: one tile of query rows per iteration
228 for (
int i = 0; i < queries.getSize(0); i += tileRows) {
// The last tile may hold fewer than tileRows queries
229 int curQuerySize = std::min(tileRows, queries.getSize(0) - i);
// Views of the outputs, queries and query norms restricted to this tile
231 auto outDistanceView =
232 outDistances.narrow(0, i, curQuerySize);
234 outIndices.narrow(0, i, curQuerySize);
237 queries.narrow(0, i, curQuerySize);
239 queryNorms.narrow(0, i, curQuerySize);
// Rows of the per-tile result buffers actually used by this query tile
241 auto outDistanceBufRowView =
242 outDistanceBufs[curStream]->narrow(0, 0, curQuerySize);
243 auto outIndexBufRowView =
244 outIndexBufs[curStream]->narrow(0, 0, curQuerySize);
// Inner loop: one tile of centroids per iteration
247 for (
int j = 0; j < centroids.getSize(0); j += tileCols) {
// The last tile may hold fewer than tileCols centroids
248 int curCentroidSize = std::min(tileCols, centroids.getSize(0) - j);
250 int curColTile = j / tileCols;
253 sliceCentroids(centroids, centroidsTransposed, j, curCentroidSize);
// Portion of the distance buffer covered by this (query, centroid) tile
255 auto distanceBufView = distanceBufs[curStream]->
256 narrow(0, 0, curQuerySize).narrow(1, 0, curCentroidSize);
// Slot of width k in the per-tile result buffers for this column tile
258 auto outDistanceBufColView =
259 outDistanceBufRowView.narrow(1, k * curColTile, k);
260 auto outIndexBufColView =
261 outIndexBufRowView.narrow(1, k * curColTile, k);
// Matrix product into distanceBufView. The centroid operand is flagged as
// needing a transpose only when no pre-transposed copy was supplied; the
// scale factor is -2 in L2 mode and 1 otherwise.
266 runMatrixMult(distanceBufView,
false,
269 centroidsTransposed ?
false :
true,
270 computeL2 ? -2.0f : 1.0f, 0.0f, useHgemm,
271 resources->getBlasHandleCurrentDevice(),
// Selection via runL2SelectMin.
// A single column tile spans every centroid here, so no cross-tile merge is
// needed (presumably the results go straight to the output views -- confirm)
282 if (tileCols == centroids.getSize(0)) {
284 runL2SelectMin(distanceBufView,
// Distances are only finished off with the query norms when the caller wants
// them
291 if (!ignoreOutDistances) {
294 runSumAlongRows(queryNormNiew,
// Multiple column tiles: use the norms for this tile's centroids and write
// the k results into this tile's slot of the intermediate buffers
301 auto centroidNormsView =
302 centroidNorms->narrow(0, j, curCentroidSize);
305 runL2SelectMin(distanceBufView,
307 outDistanceBufColView,
312 if (!ignoreOutDistances) {
315 runSumAlongRows(queryNormNiew,
316 outDistanceBufColView,
// Selection via runBlockSelect, with the same single-tile / multi-tile split.
// NOTE(review): confirm the condition that chooses between the
// runL2SelectMin path above and this runBlockSelect path.
324 if (tileCols == centroids.getSize(0)) {
326 runBlockSelect(distanceBufView,
329 true, k, streams[curStream]);
332 runBlockSelect(distanceBufView,
333 outDistanceBufColView,
335 true, k, streams[curStream]);
// With more than one column tile, the per-tile indices are tile-local:
// offset group g of the index buffer by g * tileCols, then run a final
// selection across the concatenated per-tile results
342 if (tileCols != centroids.getSize(0)) {
345 runIncrementIndex(outIndexBufRowView, k, tileCols, streams[curStream]);
347 runBlockSelectPair(outDistanceBufRowView,
351 computeL2 ?
false :
true, k, streams[curStream]);
// Switch to the other stream / buffer set for the next query tile
354 curStream = (curStream + 1) % 2;
// Make the default stream wait for all work issued on the alternate streams
358 streamWait({defaultStream}, streams);
// Type-generic L2 distance entry point; the float and half overloads further
// down in this file forward to it.
// centroidsTransposed and centroidNorms are passed by pointer (optional).
// ignoreOutDistances defaults to false.
361 template <
typename T>
362 void runL2Distance(GpuResources* resources,
363 Tensor<T, 2, true>& centroids,
364 Tensor<T, 2, true>* centroidsTransposed,
365 Tensor<T, 1, true>* centroidNorms,
366 Tensor<T, 2, true>& queries,
368 Tensor<T, 2, true>& outDistances,
369 Tensor<int, 2, true>& outIndices,
371 bool ignoreOutDistances =
false) {
// Type-generic inner-product entry point. Unlike the L2 variant it takes no
// centroid norms; it forwards to runDistance<T> with computeL2 = false.
385 template <
typename T>
386 void runIPDistance(GpuResources* resources,
387 Tensor<T, 2, true>& centroids,
388 Tensor<T, 2, true>* centroidsTransposed,
389 Tensor<T, 2, true>& queries,
391 Tensor<T, 2, true>& outDistances,
392 Tensor<int, 2, true>& outIndices,
// First argument is computeL2: false selects the inner-product mode
394 runDistance<T>(
false,
// float overload of the inner-product search; forwards to
// runIPDistance<float>. vectorsTransposed is passed by pointer (optional).
412 runIPDistance(GpuResources* resources,
413 Tensor<float, 2, true>& vectors,
414 Tensor<float, 2, true>* vectorsTransposed,
415 Tensor<float, 2, true>& queries,
417 Tensor<float, 2, true>& outDistances,
418 Tensor<int, 2, true>& outIndices) {
419 runIPDistance<float>(resources,
429 #ifdef FAISS_USE_FLOAT16
// half overload of the inner-product search, compiled only when
// FAISS_USE_FLOAT16 is defined; forwards to runIPDistance<half>.
431 runIPDistance(GpuResources* resources,
432 Tensor<half, 2, true>& vectors,
433 Tensor<half, 2, true>* vectorsTransposed,
434 Tensor<half, 2, true>& queries,
436 Tensor<half, 2, true>& outDistances,
437 Tensor<int, 2, true>& outIndices,
439 runIPDistance<half>(resources,
// float overload of the L2 search; forwards to runL2Distance<float>.
// vectorsTransposed and vectorNorms are passed by pointer (optional).
// Unlike the template, ignoreOutDistances has no default here.
451 runL2Distance(GpuResources* resources,
452 Tensor<float, 2, true>& vectors,
453 Tensor<float, 2, true>* vectorsTransposed,
454 Tensor<float, 1, true>* vectorNorms,
455 Tensor<float, 2, true>& queries,
457 Tensor<float, 2, true>& outDistances,
458 Tensor<int, 2, true>& outIndices,
459 bool ignoreOutDistances) {
460 runL2Distance<float>(resources,
472 #ifdef FAISS_USE_FLOAT16
474 runL2Distance(GpuResources* resources,
475 Tensor<half, 2, true>& vectors,
476 Tensor<half, 2, true>* vectorsTransposed,
477 Tensor<half, 1, true>* vectorNorms,
478 Tensor<half, 2, true>& queries,
480 Tensor<half, 2, true>& outDistances,
481 Tensor<int, 2, true>& outIndices,
483 bool ignoreOutDistances) {
484 runL2Distance<half>(resources,