Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
GpuDistance.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 
10 #include "GpuDistance.h"
11 #include "../FaissAssert.h"
12 #include "GpuResources.h"
13 #include "impl/Distance.cuh"
14 #include "utils/ConversionOperators.cuh"
15 #include "utils/CopyUtils.cuh"
16 #include "utils/DeviceUtils.h"
17 #include "utils/DeviceTensor.cuh"
18 
19 #include <thrust/execution_policy.h>
20 #include <thrust/transform.h>
21 
namespace faiss { namespace gpu {

/// Exhaustive (brute-force) k-nearest-neighbor search of `queries` against
/// `vectors`, executed on the current device.
///
/// All work is issued on the stream returned by
/// `resources->getDefaultStreamCurrentDevice()`.
///
/// @param resources     GPU resources providing the stream and the memory
///                      manager for the current device; must not be null
/// @param metric        METRIC_L2 or METRIC_INNER_PRODUCT; any other value
///                      raises an exception before any GPU work is issued
/// @param vectors       region of memory of size numVectors x dims, with dims
///                      innermost
/// @param numVectors    number of database vectors (>= 0)
/// @param queries       region of memory of size numQueries x dims, with dims
///                      innermost
/// @param numQueries    number of query vectors (>= 0)
/// @param dims          dimensionality shared by vectors and queries (>= 0)
/// @param k             number of neighbors reported per query (>= 0)
/// @param outDistances  region of memory of size numQueries x k, with k
///                      innermost; receives the distances
/// @param outIndices    region of memory of size numQueries x k, with k
///                      innermost; receives the indices
///
/// If numQueries == 0 or k == 0 there is nothing to write and the function
/// returns without touching the GPU.
///
/// NOTE(review): no explicit stream synchronization is performed here; whether
/// the outputs are complete when this function returns depends on the
/// toDevice / fromDevice helpers, which are not visible in this file -- confirm.
void bruteForceKnn(GpuResources* resources,
                   faiss::MetricType metric,
                   const float* vectors,
                   int numVectors,
                   const float* queries,
                   int numQueries,
                   int dims,
                   int k,
                   float* outDistances,
                   faiss::Index::idx_t* outIndices) {
  // Reject an unsupported metric up front, so that no allocation or transfer
  // is issued for a request that is guaranteed to fail.
  if (metric != faiss::MetricType::METRIC_L2 &&
      metric != faiss::MetricType::METRIC_INNER_PRODUCT) {
    FAISS_THROW_MSG("metric should be METRIC_L2 or METRIC_INNER_PRODUCT");
  }

  // Negative extents would otherwise flow straight into the tensor sizes.
  if (numVectors < 0 || numQueries < 0 || dims < 0 || k < 0) {
    FAISS_THROW_MSG(
      "numVectors, numQueries, dims and k should be non-negative");
  }

  // The output region is numQueries x k; if it is empty there is no result
  // to produce.
  if (numQueries == 0 || k == 0) {
    return;
  }

  if (!resources) {
    FAISS_THROW_MSG("resources should not be null");
  }

  if (!outDistances || !outIndices) {
    FAISS_THROW_MSG("outDistances and outIndices should not be null");
  }

  // Input regions may only be null when they are empty.
  if (dims > 0 && (!queries || (numVectors > 0 && !vectors))) {
    FAISS_THROW_MSG("vectors and queries should not be null");
  }

  auto device = getCurrentDevice();
  auto stream = resources->getDefaultStreamCurrentDevice();
  auto& mem = resources->getMemoryManagerCurrentDevice();

  // Obtain tensors usable on `device` for the inputs. The const_cast only
  // satisfies the toDevice signature; nothing in this function writes to
  // these tensors directly.
  auto tVectors = toDevice<float, 2>(resources,
                                     device,
                                     const_cast<float*>(vectors),
                                     stream,
                                     {numVectors, dims});
  auto tQueries = toDevice<float, 2>(resources,
                                     device,
                                     const_cast<float*>(queries),
                                     stream,
                                     {numQueries, dims});

  auto tOutDistances = toDevice<float, 2>(resources,
                                          device,
                                          outDistances,
                                          stream,
                                          {numQueries, k});

  // FlatIndex only supports an interface returning int indices, allocate
  // temporary memory for it
  DeviceTensor<int, 2, true> tOutIntIndices(mem, {numQueries, k}, stream);

  // Do the work; the metric was validated above, so exactly one of the two
  // supported metrics remains.
  if (metric == faiss::MetricType::METRIC_L2) {
    runL2Distance(resources,
                  tVectors,
                  nullptr,
                  nullptr, // compute norms in temp memory
                  tQueries,
                  k,
                  tOutDistances,
                  tOutIntIndices);
  } else {
    runIPDistance(resources,
                  tVectors,
                  nullptr,
                  tQueries,
                  k,
                  tOutDistances,
                  tOutIntIndices);
  }

  // Convert and copy int indices out
  auto tOutIndices = toDevice<faiss::Index::idx_t, 2>(resources,
                                                      device,
                                                      outIndices,
                                                      stream,
                                                      {numQueries, k});

  // Convert int to idx_t
  thrust::transform(thrust::cuda::par.on(stream),
                    tOutIntIndices.data(),
                    tOutIntIndices.end(),
                    tOutIndices.data(),
                    IntToIdxType());

  // Copy back if necessary
  fromDevice<float, 2>(tOutDistances, outDistances, stream);
  fromDevice<faiss::Index::idx_t, 2>(tOutIndices, outIndices, stream);
}

} } // namespace
long idx_t
all indices are this type
Definition: Index.h:64
MetricType
Some algorithms support both an inner product version and a L2 search version.
Definition: Index.h:45