// Faiss GPU — Distance.cu: brute-force L2 / inner-product distance computation
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 
9 #include "Distance.cuh"
10 #include "BroadcastSum.cuh"
11 #include "L2Norm.cuh"
12 #include "L2Select.cuh"
13 #include "../../FaissAssert.h"
14 #include "../../AuxIndexStructures.h"
15 #include "../GpuResources.h"
16 #include "../utils/DeviceDefs.cuh"
17 #include "../utils/DeviceUtils.h"
18 #include "../utils/Limits.cuh"
19 #include "../utils/MatrixMult.cuh"
20 #include "../utils/BlockSelectKernel.cuh"
21 
22 #include <memory>
23 #include <algorithm>
24 #include <thrust/fill.h>
25 #include <thrust/for_each.h>
26 #include <thrust/device_ptr.h>
27 #include <thrust/execution_policy.h>
28 
29 namespace faiss { namespace gpu {
30 
31 namespace {
32 
// Returns the sub-range [startCentroid, startCentroid + num) of `centroids`,
// honoring the memory layout. Row major is (num, dim); col major is
// (dim, num). If the requested slice covers the whole tensor, the tensor is
// returned unchanged.
template <typename T>
Tensor<T, 2, true> sliceCentroids(Tensor<T, 2, true>& centroids,
                                  bool centroidsRowMajor,
                                  int startCentroid,
                                  int num) {
  // Centroids are indexed along dim 0 when row major, dim 1 when col major
  int centroidDim = centroidsRowMajor ? 0 : 1;

  bool wholeRange =
    (startCentroid == 0) && (num == centroids.getSize(centroidDim));

  return wholeRange ?
    centroids :
    centroids.narrow(centroidDim, startCentroid, num);
}
47 
// For each chunk of k indices, increment the index by chunk * increment.
// Launch layout: blockIdx.x selects the k-sized chunk, blockIdx.y selects the
// row of `indices`; threads of the block stride across the k entries.
template <typename T>
__global__ void incrementIndex(Tensor<T, 2, true> indices,
                               int k,
                               int increment) {
  int chunkStart = blockIdx.x * k;
  int offset = blockIdx.x * increment;

  for (int idx = threadIdx.x; idx < k; idx += blockDim.x) {
    indices[blockIdx.y][chunkStart + idx] += offset;
  }
}
57 
// Used to update result indices in distance computation where the number of
// centroids is high, and is tiled. Launches on the caller-supplied `stream`;
// the caller is responsible for ordering, so no device-wide sync is done here.
template <typename T>
void runIncrementIndex(Tensor<T, 2, true>& indices,
                       int k,
                       int increment,
                       cudaStream_t stream) {
  // One block per (chunk, row); each block covers one run of k indices
  dim3 grid(indices.getSize(1) / k, indices.getSize(0));
  int block = std::min(k, 512);

  // should be exact
  FAISS_ASSERT(grid.x * k == indices.getSize(1));

  incrementIndex<<<grid, block, 0, stream>>>(indices, k, increment);

  // NOTE(review): the previous cudaDeviceSynchronize() here blocked the whole
  // device on every tile, defeating the double-stream overlap set up by
  // runDistance; only a launch-error check is needed since ordering is
  // guaranteed by the stream.
  CUDA_TEST_ERROR();
}
75 
// Chooses the (tileRows, tileCols) tiling of the query x centroid distance
// matrix. If the inner size (dim) of the vectors is small, we want a larger
// query tile size, like 1024.
void chooseTileSize(int numQueries,
                    int numCentroids,
                    int dim,
                    int elementSize,
                    size_t tempMemAvailable,
                    int& tileRows,
                    int& tileCols) {
  // The matrix multiplication should be large enough to be efficient, but if
  // it is too large, we seem to lose efficiency as opposed to
  // double-streaming. Each tile size here defines 1/2 of the memory use due
  // to double streaming. We ignore available temporary memory, as that is
  // adjusted independently by the user and can thus meet these requirements
  // (or not).
  // For <= 4 GB GPUs, prefer 512 MB of usage.
  // For <= 8 GB GPUs, prefer 768 MB of usage.
  // Otherwise, prefer 1 GB of usage.
  auto totalMem = getCurrentDeviceProperties().totalGlobalMem;

  const size_t kGiB = ((size_t) 1024) * 1024 * 1024;

  int targetUsage =
    totalMem <= 4 * kGiB ? 512 * 1024 * 1024 :
    totalMem <= 8 * kGiB ? 768 * 1024 * 1024 :
                           1024 * 1024 * 1024;

  // Half for double streaming, and convert bytes to element count
  targetUsage /= 2 * elementSize;

  // 512 seems to be a batch size sweetspot for float32.
  // If we are on float16, increase to 512.
  // If the k size (vec dim) of the matrix multiplication is small (<= 32),
  // increase to 1024.
  int preferredTileRows = (dim <= 32) ? 1024 : 512;

  tileRows = std::min(preferredTileRows, numQueries);

  // tileCols is the remainder size
  tileCols = std::min(targetUsage / preferredTileRows, numCentroids);
}
122 
123 }
124 
// Core tiled brute-force distance + k-selection.
// L2:  top-k of ||c||^2 - 2qc + ||q||^2 per query (||q||^2 added only when
//      !ignoreOutDistances, since it does not affect the ordering).
// IP:  top-k of qc per query.
// Tiles over both queries (tileRows) and centroids (tileCols), double-buffered
// across the two alternate streams; per-tile top-k results are merged with a
// final block-select pass when centroid tiling is needed.
template <typename T>
void runDistance(bool computeL2,
                 GpuResources* resources,
                 Tensor<T, 2, true>& centroids,
                 bool centroidsRowMajor,
                 Tensor<T, 1, true>* centroidNorms,
                 Tensor<T, 2, true>& queries,
                 bool queriesRowMajor,
                 int k,
                 Tensor<T, 2, true>& outDistances,
                 Tensor<int, 2, true>& outIndices,
                 bool useHgemm,
                 bool ignoreOutDistances) {
  // The # of centroids in `centroids` based on memory layout
  auto numCentroids = centroids.getSize(centroidsRowMajor ? 0 : 1);

  // The # of queries in `queries` based on memory layout
  auto numQueries = queries.getSize(queriesRowMajor ? 0 : 1);

  // The dimensions of the vectors to consider
  auto dim = queries.getSize(queriesRowMajor ? 1 : 0);
  FAISS_ASSERT((numQueries == 0 || numCentroids == 0) ||
               dim == centroids.getSize(centroidsRowMajor ? 1 : 0));

  FAISS_ASSERT(outDistances.getSize(0) == numQueries);
  FAISS_ASSERT(outIndices.getSize(0) == numQueries);
  FAISS_ASSERT(outDistances.getSize(1) == k);
  FAISS_ASSERT(outIndices.getSize(1) == k);

  auto& mem = resources->getMemoryManagerCurrentDevice();
  auto defaultStream = resources->getDefaultStreamCurrentDevice();

  // If we're querying against a 0 sized set, just return empty results
  if (centroids.numElements() == 0) {
    thrust::fill(thrust::cuda::par.on(defaultStream),
                 outDistances.data(), outDistances.end(),
                 Limits<T>::getMax());

    thrust::fill(thrust::cuda::par.on(defaultStream),
                 outIndices.data(), outIndices.end(),
                 -1);

    return;
  }

  // L2: If ||c||^2 is not pre-computed, calculate it
  DeviceTensor<T, 1, true> cNorms;
  if (computeL2 && !centroidNorms) {
    cNorms =
      std::move(DeviceTensor<T, 1, true>(mem,
                                         {numCentroids}, defaultStream));
    runL2Norm(centroids, centroidsRowMajor, cNorms, true, defaultStream);
    centroidNorms = &cNorms;
  }

  //
  // Prepare norm vector ||q||^2; ||c||^2 is already pre-computed
  //
  int qNormSize[1] = {numQueries};
  DeviceTensor<T, 1, true> queryNorms(mem, qNormSize, defaultStream);

  // ||q||^2
  if (computeL2) {
    runL2Norm(queries, queriesRowMajor, queryNorms, true, defaultStream);
  }

  // By default, aim to use up to 512 MB of memory for the processing, with
  // both number of queries and number of centroids being at least 512.
  int tileRows = 0;
  int tileCols = 0;
  chooseTileSize(numQueries,
                 numCentroids,
                 dim,
                 sizeof(T),
                 mem.getSizeAvailable(),
                 tileRows,
                 tileCols);

  int numColTiles = utils::divUp(numCentroids, tileCols);

  // We can have any number of vectors to query against, even less than k, in
  // which case we'll return -1 for the index
  FAISS_ASSERT(k <= GPU_MAX_SELECTION_K); // select limitation

  // Temporary output memory space we'll use
  DeviceTensor<T, 2, true> distanceBuf1(
    mem, {tileRows, tileCols}, defaultStream);
  DeviceTensor<T, 2, true> distanceBuf2(
    mem, {tileRows, tileCols}, defaultStream);
  DeviceTensor<T, 2, true>* distanceBufs[2] =
    {&distanceBuf1, &distanceBuf2};

  DeviceTensor<T, 2, true> outDistanceBuf1(
    mem, {tileRows, numColTiles * k}, defaultStream);
  DeviceTensor<T, 2, true> outDistanceBuf2(
    mem, {tileRows, numColTiles * k}, defaultStream);
  DeviceTensor<T, 2, true>* outDistanceBufs[2] =
    {&outDistanceBuf1, &outDistanceBuf2};

  DeviceTensor<int, 2, true> outIndexBuf1(
    mem, {tileRows, numColTiles * k}, defaultStream);
  DeviceTensor<int, 2, true> outIndexBuf2(
    mem, {tileRows, numColTiles * k}, defaultStream);
  DeviceTensor<int, 2, true>* outIndexBufs[2] =
    {&outIndexBuf1, &outIndexBuf2};

  auto streams = resources->getAlternateStreamsCurrentDevice();
  streamWait(streams, {defaultStream});

  int curStream = 0;
  bool interrupt = false;

  // Tile over the input queries
  for (int i = 0; i < numQueries; i += tileRows) {
    if (interrupt || InterruptCallback::is_interrupted()) {
      interrupt = true;
      break;
    }

    int curQuerySize = std::min(tileRows, numQueries - i);

    auto outDistanceView =
      outDistances.narrow(0, i, curQuerySize);
    auto outIndexView =
      outIndices.narrow(0, i, curQuerySize);

    auto queryView =
      queries.narrow(queriesRowMajor ? 0 : 1, i, curQuerySize);
    auto queryNormView =
      queryNorms.narrow(0, i, curQuerySize);

    auto outDistanceBufRowView =
      outDistanceBufs[curStream]->narrow(0, 0, curQuerySize);
    auto outIndexBufRowView =
      outIndexBufs[curStream]->narrow(0, 0, curQuerySize);

    // Tile over the centroids
    for (int j = 0; j < numCentroids; j += tileCols) {
      // NOTE(review): restored interrupt check that was lost in this inner
      // loop; mirrors the check on the outer query loop
      if (interrupt || InterruptCallback::is_interrupted()) {
        interrupt = true;
        break;
      }

      int curCentroidSize = std::min(tileCols, numCentroids - j);
      int curColTile = j / tileCols;

      auto centroidsView =
        sliceCentroids(centroids, centroidsRowMajor, j, curCentroidSize);

      auto distanceBufView = distanceBufs[curStream]->
        narrow(0, 0, curQuerySize).narrow(1, 0, curCentroidSize);

      auto outDistanceBufColView =
        outDistanceBufRowView.narrow(1, k * curColTile, k);
      auto outIndexBufColView =
        outIndexBufRowView.narrow(1, k * curColTile, k);

      // L2: distance is ||c||^2 - 2qc + ||q||^2, we compute -2qc
      // IP: just compute qc
      // (query id x dim) x (centroid id, dim)' = (query id, centroid id)
      runMatrixMult(distanceBufView,
                    false, // not transposed
                    queryView,
                    !queriesRowMajor, // transposed MM if col major
                    centroidsView,
                    centroidsRowMajor, // transposed MM if row major
                    computeL2 ? -2.0f : 1.0f,
                    0.0f,
                    useHgemm,
                    resources->getBlasHandleCurrentDevice(),
                    streams[curStream]);

      if (computeL2) {
        // For L2 distance, we use this fused kernel that performs both
        // adding ||c||^2 to -2qc and k-selection, so we only need two
        // passes (one write by the gemm, one read here) over the huge
        // region of output memory
        //
        // If we aren't tiling along the number of centroids, we can perform
        // the output work directly
        if (tileCols == numCentroids) {
          // Write into the final output
          runL2SelectMin(distanceBufView,
                         *centroidNorms,
                         outDistanceView,
                         outIndexView,
                         k,
                         streams[curStream]);

          if (!ignoreOutDistances) {
            // expand (query id) to (query id, k) by duplicating along rows
            // top-k ||c||^2 - 2qc + ||q||^2 in the form (query id, k)
            runSumAlongRows(queryNormView,
                            outDistanceView,
                            true, // L2 distances should not go below zero
                                  // due to roundoff error
                            streams[curStream]);
          }
        } else {
          auto centroidNormsView =
            centroidNorms->narrow(0, j, curCentroidSize);

          // Write into our intermediate output
          runL2SelectMin(distanceBufView,
                         centroidNormsView,
                         outDistanceBufColView,
                         outIndexBufColView,
                         k,
                         streams[curStream]);

          if (!ignoreOutDistances) {
            // expand (query id) to (query id, k) by duplicating along rows
            // top-k ||c||^2 - 2qc + ||q||^2 in the form (query id, k)
            runSumAlongRows(queryNormView,
                            outDistanceBufColView,
                            true, // L2 distances should not go below zero
                                  // due to roundoff error
                            streams[curStream]);
          }
        }
      } else {
        // For IP, just k-select the output for this tile
        if (tileCols == numCentroids) {
          // Write into the final output
          runBlockSelect(distanceBufView,
                         outDistanceView,
                         outIndexView,
                         true, k, streams[curStream]);
        } else {
          // Write into the intermediate output
          runBlockSelect(distanceBufView,
                         outDistanceBufColView,
                         outIndexBufColView,
                         true, k, streams[curStream]);
        }
      }
    }

    // As we're finished with processing a full set of centroids, perform the
    // final k-selection
    if (tileCols != numCentroids) {
      // The indices are tile-relative; for each tile of k, we need to add
      // tileCols to the index
      runIncrementIndex(outIndexBufRowView, k, tileCols, streams[curStream]);

      runBlockSelectPair(outDistanceBufRowView,
                         outIndexBufRowView,
                         outDistanceView,
                         outIndexView,
                         computeL2 ? false : true, k, streams[curStream]);
    }

    curStream = (curStream + 1) % 2;
  }

  // Have the desired ordering stream wait on the multi-stream
  streamWait({defaultStream}, streams);

  if (interrupt) {
    FAISS_THROW_MSG("interrupted");
  }
}
386 
// Convenience wrapper around runDistance for L2 distance with top-k
// selection. `centroidNorms` may be null, in which case ||c||^2 is computed
// internally by runDistance. When `ignoreOutDistances` is true, the reported
// distances may omit the ||q||^2 term (ordering is unaffected).
template <typename T>
void runL2Distance(GpuResources* resources,
                   Tensor<T, 2, true>& centroids,
                   bool centroidsRowMajor,
                   Tensor<T, 1, true>* centroidNorms,
                   Tensor<T, 2, true>& queries,
                   bool queriesRowMajor,
                   int k,
                   Tensor<T, 2, true>& outDistances,
                   Tensor<int, 2, true>& outIndices,
                   bool useHgemm,
                   bool ignoreOutDistances = false) {
  runDistance<T>(true, // L2
                 resources,
                 centroids,
                 centroidsRowMajor,
                 centroidNorms,
                 queries,
                 queriesRowMajor,
                 k,
                 outDistances,
                 outIndices,
                 useHgemm,
                 ignoreOutDistances);
}
412 
// Convenience wrapper around runDistance for inner-product similarity with
// top-k selection. No norms are needed for IP, so none are passed through.
template <typename T>
void runIPDistance(GpuResources* resources,
                   Tensor<T, 2, true>& centroids,
                   bool centroidsRowMajor,
                   Tensor<T, 2, true>& queries,
                   bool queriesRowMajor,
                   int k,
                   Tensor<T, 2, true>& outDistances,
                   Tensor<int, 2, true>& outIndices,
                   bool useHgemm) {
  runDistance<T>(false, // IP
                 resources,
                 centroids,
                 centroidsRowMajor,
                 nullptr, // no centroid norms provided
                 queries,
                 queriesRowMajor,
                 k,
                 outDistances,
                 outIndices,
                 useHgemm,
                 false);
}
436 
437 //
438 // Instantiations of the distance templates
439 //
440 
// float32 instantiation of the IP distance entry point; hgemm is not
// applicable to float32 so it is fixed to false.
void
runIPDistance(GpuResources* resources,
              Tensor<float, 2, true>& vectors,
              bool vectorsRowMajor,
              Tensor<float, 2, true>& queries,
              bool queriesRowMajor,
              int k,
              Tensor<float, 2, true>& outDistances,
              Tensor<int, 2, true>& outIndices) {
  runIPDistance<float>(resources,
                       vectors,
                       vectorsRowMajor,
                       queries,
                       queriesRowMajor,
                       k,
                       outDistances,
                       outIndices,
                       false);
}
460 
#ifdef FAISS_USE_FLOAT16
// float16 instantiation of the IP distance entry point; the caller chooses
// whether to use half-precision gemm via `useHgemm`.
void
runIPDistance(GpuResources* resources,
              Tensor<half, 2, true>& vectors,
              bool vectorsRowMajor,
              Tensor<half, 2, true>& queries,
              bool queriesRowMajor,
              int k,
              Tensor<half, 2, true>& outDistances,
              Tensor<int, 2, true>& outIndices,
              bool useHgemm) {
  runIPDistance<half>(resources,
                      vectors,
                      vectorsRowMajor,
                      queries,
                      queriesRowMajor,
                      k,
                      outDistances,
                      outIndices,
                      useHgemm);
}
#endif
483 
// float32 instantiation of the L2 distance entry point; hgemm is not
// applicable to float32 so it is fixed to false. `vectorNorms` may be null
// (norms are then computed internally).
void
runL2Distance(GpuResources* resources,
              Tensor<float, 2, true>& vectors,
              bool vectorsRowMajor,
              Tensor<float, 1, true>* vectorNorms,
              Tensor<float, 2, true>& queries,
              bool queriesRowMajor,
              int k,
              Tensor<float, 2, true>& outDistances,
              Tensor<int, 2, true>& outIndices,
              bool ignoreOutDistances) {
  runL2Distance<float>(resources,
                       vectors,
                       vectorsRowMajor,
                       vectorNorms,
                       queries,
                       queriesRowMajor,
                       k,
                       outDistances,
                       outIndices,
                       false,
                       ignoreOutDistances);
}
507 
#ifdef FAISS_USE_FLOAT16
// float16 instantiation of the L2 distance entry point; the caller chooses
// whether to use half-precision gemm via `useHgemm`. `vectorNorms` may be
// null (norms are then computed internally).
void
runL2Distance(GpuResources* resources,
              Tensor<half, 2, true>& vectors,
              bool vectorsRowMajor,
              Tensor<half, 1, true>* vectorNorms,
              Tensor<half, 2, true>& queries,
              bool queriesRowMajor,
              int k,
              Tensor<half, 2, true>& outDistances,
              Tensor<int, 2, true>& outIndices,
              bool useHgemm,
              bool ignoreOutDistances) {
  runL2Distance<half>(resources,
                      vectors,
                      vectorsRowMajor,
                      vectorNorms,
                      queries,
                      queriesRowMajor,
                      k,
                      outDistances,
                      outIndices,
                      useHgemm,
                      ignoreOutDistances);
}
#endif
534 
535 } } // namespace