12 #include "Distance.cuh"
13 #include "BroadcastSum.cuh"
15 #include "L2Select.cuh"
16 #include "../../FaissAssert.h"
17 #include "../GpuResources.h"
18 #include "../utils/DeviceUtils.h"
19 #include "../utils/Limits.cuh"
20 #include "../utils/MatrixMult.cuh"
21 #include "../utils/BlockSelectKernel.cuh"
24 #include <thrust/fill.h>
25 #include <thrust/for_each.h>
26 #include <thrust/device_ptr.h>
27 #include <thrust/execution_policy.h>
29 namespace faiss {
namespace gpu {
// Default starting tile size used by chooseTileSize; that function doubles
// it when sizeof(T) < 4.
33 constexpr
int kDefaultTileSize = 256;
// Chooses the tile size for the tiled distance computations below.
//
// A positive tileSizeOverride is returned unchanged. Otherwise the visible
// code starts from kDefaultTileSize (doubled when sizeof(T) < 4) and, while
// the candidate is above 64, compares
//   2 * tileSize * numCentroids * sizeof(T)
// bytes against tempMemAvailable.
//
// NOTE(review): parts of this definition are not visible in this view (the
// template header, the numCentroids parameter, the tileSize declaration,
// what happens when the requirement does or does not fit, and the return
// statements) -- confirm against the full file.
36 int chooseTileSize(
int tileSizeOverride,
38 size_t tempMemAvailable) {
// An explicit override always wins.
39 if (tileSizeOverride > 0) {
40 return tileSizeOverride;
// Types narrower than 4 bytes start from twice the default tile size.
44 sizeof(T) < 4 ? kDefaultTileSize * 2 : kDefaultTileSize;
46 while (tileSize > 64) {
// Bytes needed for 2 * tileSize * numCentroids elements of T.
47 size_t memRequirement = 2 * tileSize * numCentroids *
sizeof(T);
49 if (memRequirement <= tempMemAvailable) {
// The tile size is expected to be at least 64 at this point.
59 FAISS_ASSERT(tileSize >= 64);
// Templated L2 search of `queries` against `centroids`, processed in tiles
// of query rows and written to outDistances / outIndices.
//
// What the visible code does:
//  - asserts that both outputs have queries.getSize(0) rows and k columns;
//  - if `centroids` has no elements, fills both outputs with thrust::fill on
//    the default stream (the fill values are not all visible here);
//  - if `centroidNorms` is null, allocates a temporary of size
//    centroids.getSize(0) and fills it via runL2Norm;
//  - computes per-query norms via runL2Norm;
//  - walks the queries in chunks of maxQueriesPerIteration, issuing
//    runMatrixMult, runL2SelectMin and (unless ignoreOutDistances is set)
//    runSumAlongRows on streams[curStream], alternating between two streams
//    and two scratch buffers.
//
// NOTE(review): the template header, the `k` parameter, and several
// statements and call arguments are not visible in this view -- confirm
// against the full file.
70 void runL2Distance(GpuResources* resources,
71 Tensor<T, 2, true>& centroids,
// Optional; when non-null it is passed to runMatrixMult instead of
// `centroids`.
72 Tensor<T, 2, true>* centroidsTransposed,
// Optional; when null the norms are computed locally below.
73 Tensor<T, 1, true>* centroidNorms,
74 Tensor<T, 2, true>& queries,
76 Tensor<T, 2, true>& outDistances,
77 Tensor<int, 2, true>& outIndices,
// When true, the runSumAlongRows step on the output distances is skipped.
78 bool ignoreOutDistances =
false,
// Forwarded to tile-size selection -- presumably chooseTileSize, where a
// non-positive value means "choose automatically"; the call itself is not
// visible here.
79 int tileSizeOverride = -1) {
// Output shape checks: one row per query, k columns.
80 FAISS_ASSERT(outDistances.getSize(0) == queries.getSize(0));
81 FAISS_ASSERT(outIndices.getSize(0) == queries.getSize(0));
82 FAISS_ASSERT(outDistances.getSize(1) == k);
83 FAISS_ASSERT(outIndices.getSize(1) == k);
85 auto& mem = resources->getMemoryManagerCurrentDevice();
86 auto defaultStream = resources->getDefaultStreamCurrentDevice();
// Nothing to search against: fill the outputs on the default stream.
// The fill values (and presumably an early return) are not visible here.
89 if (centroids.numElements() == 0) {
90 thrust::fill(thrust::cuda::par.on(defaultStream),
91 outDistances.data(), outDistances.end(),
94 thrust::fill(thrust::cuda::par.on(defaultStream),
95 outIndices.data(), outIndices.end(),
// Centroid norms were not supplied: compute them into a local tensor and
// point centroidNorms at it.
102 DeviceTensor<T, 1, true> cNorms;
103 if (!centroidNorms) {
104 cNorms = std::move(DeviceTensor<T, 1, true>(
106 {centroids.getSize(0)}, defaultStream));
107 runL2Norm(centroids, cNorms,
true, defaultStream);
108 centroidNorms = &cNorms;
// One norm per query row.
114 int qNormSize[1] = {queries.getSize(0)};
115 DeviceTensor<T, 1, true> queryNorms(mem, qNormSize, defaultStream);
118 runL2Norm(queries, queryNorms,
true, defaultStream);
// k cannot exceed the number of centroids, and is capped at 1024
// (presumably a limit of the selection kernel -- confirm).
125 FAISS_ASSERT(k <= centroids.getSize(0));
126 FAISS_ASSERT(k <= 1024);
// Trailing arguments of the tile-size selection call; the start of that
// statement is not visible in this view.
131 centroids.getSize(0),
132 resources->getMemoryManagerCurrentDevice().getSizeAvailable());
// Never process more rows per tile than there are queries.
134 int maxQueriesPerIteration = std::min(tileSize, queries.getSize(0));
// Two scratch buffers of maxQueriesPerIteration x centroids.getSize(0),
// indexed by curStream in the loop below.
137 DeviceTensor<T, 2, true> distanceBuf1(
138 mem, {maxQueriesPerIteration, centroids.getSize(0)}, defaultStream);
139 DeviceTensor<T, 2, true> distanceBuf2(
140 mem, {maxQueriesPerIteration, centroids.getSize(0)}, defaultStream);
141 DeviceTensor<T, 2, true>* distanceBufs[2] =
142 {&distanceBuf1, &distanceBuf2};
// Presumably orders the alternate streams after work already queued on the
// default stream -- confirm against streamWait.
144 auto streams = resources->getAlternateStreamsCurrentDevice();
145 streamWait(streams, {defaultStream});
// Walk the queries one tile at a time; the last tile may be shorter.
149 for (
int i = 0; i < queries.getSize(0); i += maxQueriesPerIteration) {
150 int numQueriesForIteration = std::min(maxQueriesPerIteration,
151 queries.getSize(0) - i);
// Views restricted to this tile's rows. The scratch view always starts at
// row 0; the input and output views start at row i.
153 auto distanceBufView =
154 distanceBufs[curStream]->narrowOutermost(0, numQueriesForIteration);
156 queries.narrowOutermost(i, numQueriesForIteration);
157 auto outDistanceView =
158 outDistances.narrowOutermost(i, numQueriesForIteration);
160 outIndices.narrowOutermost(i, numQueriesForIteration);
162 queryNorms.narrowOutermost(i, numQueriesForIteration);
// Matrix multiply into the scratch view. The pre-transposed centroids are
// used when supplied; otherwise `centroids` is passed with the following
// flag set to true (presumably a transpose flag -- confirm against
// runMatrixMult). Some arguments are not visible here.
168 runMatrixMult(distanceBufView,
false,
170 centroidsTransposed ? *centroidsTransposed : centroids,
171 centroidsTransposed ?
false :
true,
173 resources->getBlasHandleCurrentDevice(),
// Selection over the scratch view; remaining arguments are not visible.
180 runL2SelectMin(distanceBufView,
187 if (!ignoreOutDistances) {
// NOTE(review): `queryNormNiew` looks like a typo of `queryNormView`; its
// declaration is not visible here, so verify both sites before renaming.
190 runSumAlongRows(queryNormNiew, outDistanceView, streams[curStream]);
// Alternate between stream/buffer 0 and 1 for the next tile.
193 curStream = (curStream + 1) % 2;
// Presumably makes the default stream wait for the alternate streams before
// returning -- confirm against streamWait.
197 streamWait({defaultStream}, streams);
// Templated search of `queries` against `centroids` (IP presumably stands
// for inner product -- confirm), processed in tiles of query rows and
// written to outDistances / outIndices.
//
// What the visible code does:
//  - asserts that both outputs have queries.getSize(0) rows and k columns;
//  - if `centroids` has no elements, fills outDistances with
//    Limits<T>::getMax() and fills outIndices (value not visible) on the
//    default stream;
//  - walks the queries in chunks of maxQueriesPerIteration, issuing
//    runMatrixMult followed by runBlockSelect on streams[curStream],
//    alternating between two streams and two scratch buffers.
//
// NOTE(review): the `k` parameter and several statements and call arguments
// are not visible in this view -- confirm against the full file.
200 template <
typename T>
201 void runIPDistance(GpuResources* resources,
202 Tensor<T, 2, true>& centroids,
// Optional; when non-null it is passed to runMatrixMult instead of
// `centroids`.
203 Tensor<T, 2, true>* centroidsTransposed,
204 Tensor<T, 2, true>& queries,
206 Tensor<T, 2, true>& outDistances,
207 Tensor<int, 2, true>& outIndices,
// Forwarded to tile-size selection -- presumably chooseTileSize, where a
// non-positive value means "choose automatically"; the call itself is not
// visible here.
208 int tileSizeOverride = -1) {
// Output shape checks: one row per query, k columns.
209 FAISS_ASSERT(outDistances.getSize(0) == queries.getSize(0));
210 FAISS_ASSERT(outIndices.getSize(0) == queries.getSize(0));
211 FAISS_ASSERT(outDistances.getSize(1) == k);
212 FAISS_ASSERT(outIndices.getSize(1) == k);
214 auto& mem = resources->getMemoryManagerCurrentDevice();
215 auto defaultStream = resources->getDefaultStreamCurrentDevice();
// Nothing to search against: fill the outputs on the default stream.
// The index fill value (and presumably an early return) is not visible here.
218 if (centroids.numElements() == 0) {
219 thrust::fill(thrust::cuda::par.on(defaultStream),
220 outDistances.data(), outDistances.end(),
221 Limits<T>::getMax());
223 thrust::fill(thrust::cuda::par.on(defaultStream),
224 outIndices.data(), outIndices.end(),
// k cannot exceed the number of centroids, and is capped at 1024
// (presumably a limit of the selection kernel -- confirm).
235 FAISS_ASSERT(k <= centroids.getSize(0));
236 FAISS_ASSERT(k <= 1024);
// Trailing arguments of the tile-size selection call; the start of that
// statement is not visible in this view.
241 centroids.getSize(0),
242 resources->getMemoryManagerCurrentDevice().getSizeAvailable());
// Never process more rows per tile than there are queries.
244 int maxQueriesPerIteration = std::min(tileSize, queries.getSize(0));
// Two scratch buffers of maxQueriesPerIteration x centroids.getSize(0),
// indexed by curStream in the loop below.
247 DeviceTensor<T, 2, true> distanceBuf1(
248 mem, {maxQueriesPerIteration, centroids.getSize(0)}, defaultStream);
249 DeviceTensor<T, 2, true> distanceBuf2(
250 mem, {maxQueriesPerIteration, centroids.getSize(0)}, defaultStream);
251 DeviceTensor<T, 2, true>* distanceBufs[2] =
252 {&distanceBuf1, &distanceBuf2};
// Presumably orders the alternate streams after work already queued on the
// default stream -- confirm against streamWait.
254 auto streams = resources->getAlternateStreamsCurrentDevice();
255 streamWait(streams, {defaultStream});
// Walk the queries one tile at a time; the last tile may be shorter.
259 for (
int i = 0; i < queries.getSize(0); i += maxQueriesPerIteration) {
260 int numQueriesForIteration = std::min(maxQueriesPerIteration,
261 queries.getSize(0) - i);
// Views restricted to this tile's rows. The scratch view always starts at
// row 0; the input and output views start at row i.
263 auto distanceBufView =
264 distanceBufs[curStream]->narrowOutermost(0, numQueriesForIteration);
266 queries.narrowOutermost(i, numQueriesForIteration);
267 auto outDistanceView =
268 outDistances.narrowOutermost(i, numQueriesForIteration);
270 outIndices.narrowOutermost(i, numQueriesForIteration);
// Matrix multiply into the scratch view. The pre-transposed centroids are
// used when supplied; otherwise `centroids` is passed with the following
// flag set to true (presumably a transpose flag -- confirm against
// runMatrixMult). Some arguments are not visible here.
273 runMatrixMult(distanceBufView,
false,
275 centroidsTransposed ? *centroidsTransposed : centroids,
276 centroidsTransposed ?
false :
true,
278 resources->getBlasHandleCurrentDevice(),
// k-selection over the scratch view on the current stream. The meaning of
// the `true` flag is not visible here -- presumably it selects the largest
// values; confirm against runBlockSelect.
283 runBlockSelect(distanceBufView,
286 true, k, streams[curStream]);
// Alternate between stream/buffer 0 and 1 for the next tile.
288 curStream = (curStream + 1) % 2;
// Presumably makes the default stream wait for the alternate streams before
// returning -- confirm against streamWait.
291 streamWait({defaultStream}, streams);
// Non-template float overload; forwards to runIPDistance<float>.
// NOTE(review): the return type, the `k` parameter and the remainder of the
// forwarded argument list are not visible in this view.
299 runIPDistance(GpuResources* resources,
300 Tensor<float, 2, true>& vectors,
301 Tensor<float, 2, true>* vectorsTransposed,
302 Tensor<float, 2, true>& queries,
304 Tensor<float, 2, true>& outDistances,
305 Tensor<int, 2, true>& outIndices,
306 int tileSizeOverride) {
307 runIPDistance<float>(resources,
// Half-precision overload, compiled only when FAISS_USE_FLOAT16 is defined;
// forwards to runIPDistance<half>.
// NOTE(review): the return type, the `k` parameter, the remainder of the
// forwarded argument list and the matching #endif are not visible in this
// view.
317 #ifdef FAISS_USE_FLOAT16
319 runIPDistance(GpuResources* resources,
320 Tensor<half, 2, true>& vectors,
321 Tensor<half, 2, true>* vectorsTransposed,
322 Tensor<half, 2, true>& queries,
324 Tensor<half, 2, true>& outDistances,
325 Tensor<int, 2, true>& outIndices,
326 int tileSizeOverride) {
327 runIPDistance<half>(resources,
// Non-template float overload; forwards to runL2Distance<float>.
// NOTE(review): the return type, the `k` parameter and the remainder of the
// forwarded argument list are not visible in this view.
339 runL2Distance(GpuResources* resources,
340 Tensor<float, 2, true>& vectors,
341 Tensor<float, 2, true>* vectorsTransposed,
// Optional precomputed norms; the template version computes them when this
// is null.
342 Tensor<float, 1, true>* vectorNorms,
343 Tensor<float, 2, true>& queries,
345 Tensor<float, 2, true>& outDistances,
346 Tensor<int, 2, true>& outIndices,
347 bool ignoreOutDistances,
348 int tileSizeOverride) {
349 runL2Distance<float>(resources,
361 #ifdef FAISS_USE_FLOAT16
363 runL2Distance(GpuResources* resources,
364 Tensor<half, 2, true>& vectors,
365 Tensor<half, 2, true>* vectorsTransposed,
366 Tensor<half, 1, true>* vectorNorms,
367 Tensor<half, 2, true>& queries,
369 Tensor<half, 2, true>& outDistances,
370 Tensor<int, 2, true>& outIndices,
371 bool ignoreOutDistances,
372 int tileSizeOverride) {
373 runL2Distance<half>(resources,