// Faiss GPU — Distance.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 
11 #include "Distance.cuh"
12 #include "BroadcastSum.cuh"
13 #include "L2Norm.cuh"
14 #include "L2Select.cuh"
15 #include "../../FaissAssert.h"
16 #include "../GpuResources.h"
17 #include "../utils/DeviceUtils.h"
18 #include "../utils/Limits.cuh"
19 #include "../utils/MatrixMult.cuh"
20 #include "../utils/BlockSelectKernel.cuh"
21 
22 #include <memory>
23 #include <thrust/fill.h>
24 #include <thrust/for_each.h>
25 #include <thrust/device_ptr.h>
26 #include <thrust/execution_policy.h>
27 
28 namespace faiss { namespace gpu {
29 
30 namespace {
31 
// Returns a view over centroids [startCentroid, startCentroid + num), taken
// from the transposed (dim, num) layout when one is available, otherwise from
// the row-major (num, dim) layout. When the requested range covers all
// centroids, the tensor is returned as-is without narrowing.
template <typename T>
Tensor<T, 2, true> sliceCentroids(Tensor<T, 2, true>& centroids,
                                  Tensor<T, 2, true>* centroidsTransposed,
                                  int startCentroid,
                                  int num) {
  bool wholeRange =
    (startCentroid == 0) && (num == centroids.getSize(0));

  if (centroidsTransposed) {
    // Transposed layout is (dim, num): centroid index is dim 1
    return wholeRange ? *centroidsTransposed
                      : centroidsTransposed->narrow(1, startCentroid, num);
  }

  // Row-major layout is (num, dim): centroid index is dim 0
  return wholeRange ? centroids
                    : centroids.narrow(0, startCentroid, num);
}
52 
// For each chunk of k indices, increment the index by chunk * increment
//
// Expected launch configuration (see runIncrementIndex):
//   grid.x  = number of k-sized chunks per row (indices.getSize(1) / k)
//   grid.y  = number of rows (indices.getSize(0))
//   block.x = threads cooperating over one chunk
template <typename T>
__global__ void incrementIndex(Tensor<T, 2, true> indices,
                               int k,
                               int increment) {
  // Threads in the block stride over one chunk; all entries of chunk
  // blockIdx.x in row blockIdx.y get blockIdx.x * increment added, making
  // tile-relative indices global.
  for (int i = threadIdx.x; i < k; i += blockDim.x) {
    indices[blockIdx.y][blockIdx.x * k + i] += blockIdx.x * increment;
  }
}
62 
// Used to update result indices in distance computation where the number of
// centroids is high, and is tiled. Converts the tile-relative indices stored
// in each k-sized chunk of `indices` into global centroid indices by adding
// chunk * increment. Asynchronous with respect to the host; ordered on
// `stream`.
template <typename T>
void runIncrementIndex(Tensor<T, 2, true>& indices,
                       int k,
                       int increment,
                       cudaStream_t stream) {
  dim3 grid(indices.getSize(1) / k, indices.getSize(0));
  int block = std::min(k, 512);

  // should be exact
  FAISS_ASSERT(grid.x * k == indices.getSize(1));

  incrementIndex<<<grid, block, 0, stream>>>(indices, k, increment);

  // Only check for launch-configuration errors here; do NOT device-sync.
  // This helper runs inside the per-tile loop of runDistance, and a
  // device-wide cudaDeviceSynchronize() would serialize the two alternate
  // streams and defeat the double-buffered overlap. Stream ordering already
  // guarantees that subsequent work enqueued on `stream` (the final
  // k-selection) observes the updated indices.
  FAISS_ASSERT(cudaGetLastError() == cudaSuccess);
}
80 
81 // If the inner size (dim) of the vectors is small, we want a larger query tile
82 // size, like 1024
83 
// Chooses a (tileRows x tileCols) tile shape for the tiled distance
// computation, sized from the device's total memory and the vector dimension.
// tempMemAvailable is deliberately unused (see the comment below about
// ignoring temporary memory); the parameter is kept for interface stability.
void chooseTileSize(int numQueries,
                    int numCentroids,
                    int dim,
                    int elementSize,
                    size_t tempMemAvailable,
                    int& tileRows,
                    int& tileCols) {
  // The matrix multiplication should be large enough to be efficient, but if it
  // is too large, we seem to lose efficiency as opposed to double-streaming.
  // Each tile size here defines 1/2 of the memory use due to double streaming.
  // We ignore available temporary memory, as that is adjusted independently by
  // the user and can thus meet these requirements (or not).
  // For <= 4 GB GPUs, prefer 512 MB of usage.
  // For <= 8 GB GPUs, prefer 768 MB of usage.
  // Otherwise, prefer 1 GB of usage.
  auto totalMem = getCurrentDeviceProperties().totalGlobalMem;

  int targetUsage = 0;

  if (totalMem <= ((size_t) 4) * 1024 * 1024 * 1024) {
    targetUsage = 512 * 1024 * 1024;
  } else if (totalMem <= ((size_t) 8) * 1024 * 1024 * 1024) {
    targetUsage = 768 * 1024 * 1024;
  } else {
    targetUsage = 1024 * 1024 * 1024;
  }

  // Convert the byte budget into an element budget for a single tile; the
  // factor of 2 accounts for the two double-buffered tiles in flight.
  targetUsage /= 2 * elementSize;

  // 512 rows seems to be a batch size sweetspot.
  // If the k size (vec dim) of the matrix multiplication is small (<= 32),
  // increase to 1024. (A previous note here claimed float16 also raises this
  // to 512; the code keys only off dim, not element type.)
  int preferredTileRows = 512;
  if (dim <= 32) {
    preferredTileRows = 1024;
  }

  tileRows = std::min(preferredTileRows, numQueries);

  // tileCols is the remainder size: whatever element budget is left once the
  // row count is fixed
  tileCols = std::min(targetUsage / preferredTileRows, numCentroids);
}
127 
128 }
129 
// Shared implementation for tiled, k-selected distance computation between
// queries (numQueries x dim) and centroids (numCentroids x dim).
//
// computeL2 == true:  L2 distance via ||c||^2 - 2qc + ||q||^2
// computeL2 == false: inner product qc
//
// centroidsTransposed, if non-null, is the (dim, numCentroids) layout of the
// same data and is passed untransposed to the GEMM. centroidNorms, if
// non-null, holds precomputed ||c||^2 (L2 only; computed here otherwise).
// Results go to outDistances/outIndices, both (numQueries, k). If
// ignoreOutDistances is true, the ||q||^2 term is skipped for L2, so reported
// distances are offset per-query, but indices remain correct.
//
// Work is tiled over both queries and centroids and double-buffered across
// the device's two alternate streams; intermediate buffers come from the
// temporary memory manager.
template <typename T>
void runDistance(bool computeL2,
                 GpuResources* resources,
                 Tensor<T, 2, true>& centroids,
                 Tensor<T, 2, true>* centroidsTransposed,
                 Tensor<T, 1, true>* centroidNorms,
                 Tensor<T, 2, true>& queries,
                 int k,
                 Tensor<T, 2, true>& outDistances,
                 Tensor<int, 2, true>& outIndices,
                 bool useHgemm,
                 bool ignoreOutDistances) {
  FAISS_ASSERT(outDistances.getSize(0) == queries.getSize(0));
  FAISS_ASSERT(outIndices.getSize(0) == queries.getSize(0));
  FAISS_ASSERT(outDistances.getSize(1) == k);
  FAISS_ASSERT(outIndices.getSize(1) == k);

  auto& mem = resources->getMemoryManagerCurrentDevice();
  auto defaultStream = resources->getDefaultStreamCurrentDevice();

  // If we're querying against a 0 sized set, just return empty results:
  // max distances and -1 indices
  if (centroids.numElements() == 0) {
    thrust::fill(thrust::cuda::par.on(defaultStream),
                 outDistances.data(), outDistances.end(),
                 Limits<T>::getMax());

    thrust::fill(thrust::cuda::par.on(defaultStream),
                 outIndices.data(), outIndices.end(),
                 -1);

    return;
  }

  // L2: If ||c||^2 is not pre-computed, calculate it
  DeviceTensor<T, 1, true> cNorms;
  if (computeL2 && !centroidNorms) {
    // Move-assign from the temporary; no std::move needed on a prvalue
    cNorms = DeviceTensor<T, 1, true>(
      mem,
      {centroids.getSize(0)}, defaultStream);
    runL2Norm(centroids, cNorms, true, defaultStream);
    centroidNorms = &cNorms;
  }

  //
  // Prepare norm vector ||q||^2; ||c||^2 is already pre-computed
  //
  int qNormSize[1] = {queries.getSize(0)};
  DeviceTensor<T, 1, true> queryNorms(mem, qNormSize, defaultStream);

  // ||q||^2
  if (computeL2) {
    runL2Norm(queries, queryNorms, true, defaultStream);
  }

  // By default, aim to use up to 512 MB of memory for the processing, with both
  // number of queries and number of centroids being at least 512.
  int tileRows = 0;
  int tileCols = 0;
  chooseTileSize(queries.getSize(0),
                 centroids.getSize(0),
                 queries.getSize(1),
                 sizeof(T),
                 mem.getSizeAvailable(),
                 tileRows,
                 tileCols);

  int numColTiles = utils::divUp(centroids.getSize(0), tileCols);

  FAISS_ASSERT(k <= centroids.getSize(0));
  FAISS_ASSERT(k <= 1024); // select limitation

  // Temporary output memory space we'll use; two of each buffer so the two
  // alternate streams can work on independent tiles concurrently
  DeviceTensor<T, 2, true> distanceBuf1(
    mem, {tileRows, tileCols}, defaultStream);
  DeviceTensor<T, 2, true> distanceBuf2(
    mem, {tileRows, tileCols}, defaultStream);
  DeviceTensor<T, 2, true>* distanceBufs[2] =
    {&distanceBuf1, &distanceBuf2};

  DeviceTensor<T, 2, true> outDistanceBuf1(
    mem, {tileRows, numColTiles * k}, defaultStream);
  DeviceTensor<T, 2, true> outDistanceBuf2(
    mem, {tileRows, numColTiles * k}, defaultStream);
  DeviceTensor<T, 2, true>* outDistanceBufs[2] =
    {&outDistanceBuf1, &outDistanceBuf2};

  DeviceTensor<int, 2, true> outIndexBuf1(
    mem, {tileRows, numColTiles * k}, defaultStream);
  DeviceTensor<int, 2, true> outIndexBuf2(
    mem, {tileRows, numColTiles * k}, defaultStream);
  DeviceTensor<int, 2, true>* outIndexBufs[2] =
    {&outIndexBuf1, &outIndexBuf2};

  auto streams = resources->getAlternateStreamsCurrentDevice();
  streamWait(streams, {defaultStream});

  int curStream = 0;

  // Tile over the input queries
  for (int i = 0; i < queries.getSize(0); i += tileRows) {
    int curQuerySize = std::min(tileRows, queries.getSize(0) - i);

    auto outDistanceView =
      outDistances.narrow(0, i, curQuerySize);
    auto outIndexView =
      outIndices.narrow(0, i, curQuerySize);

    auto queryView =
      queries.narrow(0, i, curQuerySize);
    auto queryNormView =
      queryNorms.narrow(0, i, curQuerySize);

    auto outDistanceBufRowView =
      outDistanceBufs[curStream]->narrow(0, 0, curQuerySize);
    auto outIndexBufRowView =
      outIndexBufs[curStream]->narrow(0, 0, curQuerySize);

    // Tile over the centroids
    for (int j = 0; j < centroids.getSize(0); j += tileCols) {
      int curCentroidSize = std::min(tileCols, centroids.getSize(0) - j);

      int curColTile = j / tileCols;

      auto centroidsView =
        sliceCentroids(centroids, centroidsTransposed, j, curCentroidSize);

      auto distanceBufView = distanceBufs[curStream]->
        narrow(0, 0, curQuerySize).narrow(1, 0, curCentroidSize);

      auto outDistanceBufColView =
        outDistanceBufRowView.narrow(1, k * curColTile, k);
      auto outIndexBufColView =
        outIndexBufRowView.narrow(1, k * curColTile, k);

      // L2: distance is ||c||^2 - 2qc + ||q||^2, we compute -2qc
      // IP: just compute qc
      // (query id x dim) x (centroid id, dim)' = (query id, centroid id)
      runMatrixMult(distanceBufView, false,
                    queryView, false,
                    centroidsView,
                    centroidsTransposed ? false : true,
                    computeL2 ? -2.0f : 1.0f, 0.0f, useHgemm,
                    resources->getBlasHandleCurrentDevice(),
                    streams[curStream]);

      if (computeL2) {
        // For L2 distance, we use this fused kernel that performs both
        // adding ||c||^2 to -2qc and k-selection, so we only need two
        // passes (one write by the gemm, one read here) over the huge
        // region of output memory
        //
        // If we aren't tiling along the number of centroids, we can perform the
        // output work directly
        if (tileCols == centroids.getSize(0)) {
          // Write into the final output
          runL2SelectMin(distanceBufView,
                         *centroidNorms,
                         outDistanceView,
                         outIndexView,
                         k,
                         streams[curStream]);

          if (!ignoreOutDistances) {
            // expand (query id) to (query id, k) by duplicating along rows
            // top-k ||c||^2 - 2qc + ||q||^2 in the form (query id, k)
            runSumAlongRows(queryNormView, outDistanceView, streams[curStream]);
          }
        } else {
          auto centroidNormsView =
            centroidNorms->narrow(0, j, curCentroidSize);

          // Write into our intermediate output
          runL2SelectMin(distanceBufView,
                         centroidNormsView,
                         outDistanceBufColView,
                         outIndexBufColView,
                         k,
                         streams[curStream]);

          if (!ignoreOutDistances) {
            // expand (query id) to (query id, k) by duplicating along rows
            // top-k ||c||^2 - 2qc + ||q||^2 in the form (query id, k)
            runSumAlongRows(queryNormView,
                            outDistanceBufColView,
                            streams[curStream]);
          }
        }
      } else {
        // For IP, just k-select the output for this tile
        if (tileCols == centroids.getSize(0)) {
          // Write into the final output
          runBlockSelect(distanceBufView,
                         outDistanceView,
                         outIndexView,
                         true, k, streams[curStream]);
        } else {
          // Write into the intermediate output
          runBlockSelect(distanceBufView,
                         outDistanceBufColView,
                         outIndexBufColView,
                         true, k, streams[curStream]);
        }
      }
    }

    // As we're finished with processing a full set of centroids, perform the
    // final k-selection
    if (tileCols != centroids.getSize(0)) {
      // The indices are tile-relative; for each tile of k, we need to add
      // tileCols to the index
      runIncrementIndex(outIndexBufRowView, k, tileCols, streams[curStream]);

      // L2 selects minima (false = ascending), IP selects maxima
      runBlockSelectPair(outDistanceBufRowView,
                         outIndexBufRowView,
                         outDistanceView,
                         outIndexView,
                         computeL2 ? false : true, k, streams[curStream]);
    }

    // Alternate between the two streams / buffer sets for the next row tile
    curStream = (curStream + 1) % 2;
  }

  // Have the desired ordering stream wait on the multi-stream
  streamWait({defaultStream}, streams);
}
355 
// L2 entry point: forwards to the shared runDistance implementation with the
// L2 path selected. See runDistance for argument semantics.
template <typename T>
void runL2Distance(GpuResources* resources,
                   Tensor<T, 2, true>& centroids,
                   Tensor<T, 2, true>* centroidsTransposed,
                   Tensor<T, 1, true>* centroidNorms,
                   Tensor<T, 2, true>& queries,
                   int k,
                   Tensor<T, 2, true>& outDistances,
                   Tensor<int, 2, true>& outIndices,
                   bool useHgemm,
                   bool ignoreOutDistances = false) {
  runDistance<T>(true, // computeL2
                 resources, centroids, centroidsTransposed, centroidNorms,
                 queries, k, outDistances, outIndices,
                 useHgemm, ignoreOutDistances);
}
379 
// Inner-product entry point: forwards to the shared runDistance
// implementation with the IP path selected. No centroid norms are needed for
// IP, and distances are always produced.
template <typename T>
void runIPDistance(GpuResources* resources,
                   Tensor<T, 2, true>& centroids,
                   Tensor<T, 2, true>* centroidsTransposed,
                   Tensor<T, 2, true>& queries,
                   int k,
                   Tensor<T, 2, true>& outDistances,
                   Tensor<int, 2, true>& outIndices,
                   bool useHgemm) {
  runDistance<T>(false, // computeL2
                 resources, centroids, centroidsTransposed,
                 nullptr, // centroid norms unused for IP
                 queries, k, outDistances, outIndices, useHgemm,
                 false); // ignoreOutDistances
}
401 
402 //
403 // Instantiations of the distance templates
404 //
405 
// float32 inner-product instantiation. useHgemm is only exposed on the
// half-precision overload, so it is disabled here.
void
runIPDistance(GpuResources* resources,
              Tensor<float, 2, true>& vectors,
              Tensor<float, 2, true>* vectorsTransposed,
              Tensor<float, 2, true>& queries,
              int k,
              Tensor<float, 2, true>& outDistances,
              Tensor<int, 2, true>& outIndices) {
  runIPDistance<float>(resources, vectors, vectorsTransposed,
                       queries, k, outDistances, outIndices,
                       false); // useHgemm
}
423 
#ifdef FAISS_USE_FLOAT16
// float16 inner-product instantiation; the caller chooses the gemm math
// mode via useHgemm.
void
runIPDistance(GpuResources* resources,
              Tensor<half, 2, true>& vectors,
              Tensor<half, 2, true>* vectorsTransposed,
              Tensor<half, 2, true>& queries,
              int k,
              Tensor<half, 2, true>& outDistances,
              Tensor<int, 2, true>& outIndices,
              bool useHgemm) {
  runIPDistance<half>(resources, vectors, vectorsTransposed,
                      queries, k, outDistances, outIndices, useHgemm);
}
#endif
444 
// float32 L2 instantiation. useHgemm is only exposed on the half-precision
// overload, so it is disabled here.
void
runL2Distance(GpuResources* resources,
              Tensor<float, 2, true>& vectors,
              Tensor<float, 2, true>* vectorsTransposed,
              Tensor<float, 1, true>* vectorNorms,
              Tensor<float, 2, true>& queries,
              int k,
              Tensor<float, 2, true>& outDistances,
              Tensor<int, 2, true>& outIndices,
              bool ignoreOutDistances) {
  runL2Distance<float>(resources, vectors, vectorsTransposed, vectorNorms,
                       queries, k, outDistances, outIndices,
                       false, // useHgemm
                       ignoreOutDistances);
}
466 
#ifdef FAISS_USE_FLOAT16
// float16 L2 instantiation; the caller chooses the gemm math mode via
// useHgemm.
void
runL2Distance(GpuResources* resources,
              Tensor<half, 2, true>& vectors,
              Tensor<half, 2, true>* vectorsTransposed,
              Tensor<half, 1, true>* vectorNorms,
              Tensor<half, 2, true>& queries,
              int k,
              Tensor<half, 2, true>& outDistances,
              Tensor<int, 2, true>& outIndices,
              bool useHgemm,
              bool ignoreOutDistances) {
  runL2Distance<half>(resources, vectors, vectorsTransposed, vectorNorms,
                      queries, k, outDistances, outIndices,
                      useHgemm, ignoreOutDistances);
}
#endif
491 
492 } } // namespace