Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
GpuDistance.cu
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 
9 #include "GpuDistance.h"
10 #include "../FaissAssert.h"
11 #include "GpuResources.h"
12 #include "impl/Distance.cuh"
13 #include "utils/ConversionOperators.cuh"
14 #include "utils/CopyUtils.cuh"
15 #include "utils/DeviceUtils.h"
16 #include "utils/DeviceTensor.cuh"
17 
18 #include <thrust/execution_policy.h>
19 #include <thrust/transform.h>
20 
namespace faiss { namespace gpu {

/// Exact (brute-force) k-nearest-neighbor search on the current GPU device.
///
/// Every query vector is compared against every database vector under
/// `metric`. The k best matches per query are written to `outDistances` and
/// `outIndices`. All device work is enqueued on the current device's default
/// stream obtained from `resources`.
///
/// @param resources       GPU resources. They supply the default stream and
///                        the temporary-memory manager for the current device.
///                        Must be non-null.
/// @param metric          METRIC_L2 or METRIC_INNER_PRODUCT. Any other value
///                        throws.
/// @param vectors         Database matrix. It is numVectors x dims (dims
///                        innermost) when `vectorsRowMajor` is true, and
///                        dims x numVectors otherwise.
/// @param vectorsRowMajor Layout flag for `vectors`, as described above.
/// @param numVectors      Number of database vectors; must be > 0.
/// @param queries         Query matrix. It is numQueries x dims (dims
///                        innermost) when `queriesRowMajor` is true, and
///                        dims x numQueries otherwise.
/// @param queriesRowMajor Layout flag for `queries`, as described above.
/// @param numQueries      Number of query vectors; must be >= 0. When it is 0
///                        the call is a no-op.
/// @param dims            Dimensionality of every vector; must be > 0.
/// @param k               Number of neighbors to return per query; must be > 0.
/// @param outDistances    Output, numQueries x k, k innermost.
/// @param outIndices      Output, numQueries x k, k innermost. The distance
///                        routines produce int indices, which are widened to
///                        idx_t here.
///
/// Throws (via FAISS_THROW_MSG) on an unsupported metric or invalid
/// arguments. All checks run before any data is moved to the GPU.
///
/// NOTE(review): toDevice/fromDevice presumably copy only when a pointer is
/// not already resident on `device`. If the outputs live in device memory,
/// the caller presumably has to synchronize `stream` before reading them.
/// Confirm both points in utils/CopyUtils.cuh.
void bruteForceKnn(GpuResources* resources,
                   faiss::MetricType metric,
                   const float* vectors,
                   bool vectorsRowMajor,
                   int numVectors,
                   const float* queries,
                   bool queriesRowMajor,
                   int numQueries,
                   int dims,
                   int k,
                   float* outDistances,
                   faiss::Index::idx_t* outIndices) {
  // Validate everything before touching the GPU. That way bad arguments fail
  // fast, and no host -> device transfer or temporary allocation is wasted on
  // a call that will throw anyway.
  if (metric != faiss::MetricType::METRIC_L2 &&
      metric != faiss::MetricType::METRIC_INNER_PRODUCT) {
    FAISS_THROW_MSG("metric should be METRIC_L2 or METRIC_INNER_PRODUCT");
  }
  if (!resources) {
    FAISS_THROW_MSG("bruteForceKnn: resources must be non-null");
  }
  if (k <= 0) {
    FAISS_THROW_MSG("bruteForceKnn: k must be > 0");
  }
  if (dims <= 0) {
    FAISS_THROW_MSG("bruteForceKnn: dims must be > 0");
  }
  if (numVectors <= 0) {
    FAISS_THROW_MSG("bruteForceKnn: numVectors must be > 0");
  }
  if (numQueries < 0) {
    FAISS_THROW_MSG("bruteForceKnn: numQueries must be >= 0");
  }

  // No queries means no output rows. Return instead of feeding zero-sized
  // tensors to the distance / top-k machinery.
  if (numQueries == 0) {
    return;
  }

  if (!vectors || !queries) {
    FAISS_THROW_MSG("bruteForceKnn: vectors and queries must be non-null");
  }
  if (!outDistances || !outIndices) {
    FAISS_THROW_MSG(
      "bruteForceKnn: outDistances and outIndices must be non-null");
  }

  auto device = getCurrentDevice();
  auto stream = resources->getDefaultStreamCurrentDevice();
  auto& mem = resources->getMemoryManagerCurrentDevice();

  // The tensor shape follows the caller-declared layout:
  // (rows x dims) when row-major, and (dims x rows) otherwise.
  auto tVectors = toDevice<float, 2>(resources,
                                     device,
                                     const_cast<float*>(vectors),
                                     stream,
                                     {vectorsRowMajor ? numVectors : dims,
                                      vectorsRowMajor ? dims : numVectors});
  auto tQueries = toDevice<float, 2>(resources,
                                     device,
                                     const_cast<float*>(queries),
                                     stream,
                                     {queriesRowMajor ? numQueries : dims,
                                      queriesRowMajor ? dims : numQueries});

  auto tOutDistances = toDevice<float, 2>(resources,
                                          device,
                                          outDistances,
                                          stream,
                                          {numQueries, k});

  // runL2Distance / runIPDistance only produce int indices. Allocate a
  // temporary int buffer from the device memory manager and widen it to
  // idx_t afterwards.
  DeviceTensor<int, 2, true> tOutIntIndices(mem, {numQueries, k}, stream);

  // Do the work. The metric was validated above, so one of these runs.
  if (metric == faiss::MetricType::METRIC_L2) {
    runL2Distance(resources,
                  tVectors,
                  vectorsRowMajor,
                  nullptr, // compute norms in temp memory
                  tQueries,
                  queriesRowMajor,
                  k,
                  tOutDistances,
                  tOutIntIndices);
  } else {
    runIPDistance(resources,
                  tVectors,
                  vectorsRowMajor,
                  tQueries,
                  queriesRowMajor,
                  k,
                  tOutDistances,
                  tOutIntIndices);
  }

  // Device-side view of the caller's idx_t output buffer.
  auto tOutIndices = toDevice<faiss::Index::idx_t, 2>(resources,
                                                      device,
                                                      outIndices,
                                                      stream,
                                                      {numQueries, k});

  // Widen int -> idx_t on the same stream as the distance computation.
  thrust::transform(thrust::cuda::par.on(stream),
                    tOutIntIndices.data(),
                    tOutIntIndices.end(),
                    tOutIndices.data(),
                    IntToIdxType());

  // Copy results back to the caller's buffers (presumably a no-op when the
  // caller's pointers are already device-resident; see NOTE above).
  fromDevice<float, 2>(tOutDistances, outDistances, stream);
  fromDevice<faiss::Index::idx_t, 2>(tOutIndices, outIndices, stream);
}

} } // namespace
long idx_t
all indices are this type
Definition: Index.h:62
MetricType
Some algorithms support both an inner product version and an L2 search version.
Definition: Index.h:44