Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
FlatIndex.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 
11 #pragma once
12 
13 #include "../utils/DeviceTensor.cuh"
14 #include "../utils/DeviceVector.cuh"
15 #include "../utils/Float16.cuh"
16 #include "../utils/MemorySpace.h"
17 
18 namespace faiss { namespace gpu {
19 
20 class GpuResources;
21 
22 /// Holder of GPU resources for a particular flat index
23 class FlatIndex {
24  public:
25  FlatIndex(GpuResources* res,
26  int dim,
27  bool l2Distance,
28  bool useFloat16,
29  bool useFloat16Accumulator,
30  bool storeTransposed,
31  MemorySpace space);
32 
33  bool getUseFloat16() const;
34 
35  /// Returns the number of vectors we contain
36  int getSize() const;
37 
38  int getDim() const;
39 
40  /// Reserve storage that can contain at least this many vectors
41  void reserve(size_t numVecs, cudaStream_t stream);
42 
43  /// Returns a reference to our vectors currently in use
44  Tensor<float, 2, true>& getVectorsFloat32Ref();
45 
46 #ifdef FAISS_USE_FLOAT16
47  /// Returns a reference to our vectors currently in use (useFloat16 mode)
48  Tensor<half, 2, true>& getVectorsFloat16Ref();
49 #endif
50 
51  /// Performs a copy of the vectors on the given device, converting
52  /// as needed from float16
53  DeviceTensor<float, 2, true> getVectorsFloat32Copy(cudaStream_t stream);
54 
55  /// Returns only a subset of the vectors
56  DeviceTensor<float, 2, true> getVectorsFloat32Copy(int from,
57  int num,
58  cudaStream_t stream);
59 
60  void query(Tensor<float, 2, true>& vecs,
61  int k,
62  Tensor<float, 2, true>& outDistances,
63  Tensor<int, 2, true>& outIndices,
64  bool exactDistance);
65 
66 #ifdef FAISS_USE_FLOAT16
67  void query(Tensor<half, 2, true>& vecs,
68  int k,
69  Tensor<half, 2, true>& outDistances,
70  Tensor<int, 2, true>& outIndices,
71  bool exactDistance);
72 #endif
73 
74  /// Add vectors to ourselves; the pointer passed can be on the host
75  /// or the device
76  void add(const float* data, int numVecs, cudaStream_t stream);
77 
78  /// Free all storage
79  void reset();
80 
81  private:
82  /// Collection of GPU resources that we use
83  GpuResources* resources_;
84 
85  /// Dimensionality of our vectors
86  const int dim_;
87 
88  /// Float16 data format
89  const bool useFloat16_;
90 
91  /// For supporting hardware, whether or not we use Hgemm
92  const bool useFloat16Accumulator_;
93 
94  /// Store vectors in transposed layout for speed; makes addition to
95  /// the index slower
96  const bool storeTransposed_;
97 
98  /// L2 or inner product distance?
99  bool l2Distance_;
100 
101  /// Memory space for our allocations
102  MemorySpace space_;
103 
104  /// How many vectors we have
105  int num_;
106 
107  /// The underlying expandable storage
108  DeviceVector<char> rawData_;
109 
110  /// Vectors currently in rawData_
111  DeviceTensor<float, 2, true> vectors_;
112  DeviceTensor<float, 2, true> vectorsTransposed_;
113 
114 #ifdef FAISS_USE_FLOAT16
115  /// Vectors currently in rawData_, float16 form
116  DeviceTensor<half, 2, true> vectorsHalf_;
117  DeviceTensor<half, 2, true> vectorsHalfTransposed_;
118 #endif
119 
120  /// Precomputed L2 norms
121  DeviceTensor<float, 1, true> norms_;
122 
123 #ifdef FAISS_USE_FLOAT16
124  /// Precomputed L2 norms, float16 form
125  DeviceTensor<half, 1, true> normsHalf_;
126 #endif
127 };
128 
129 } } // namespace
DeviceTensor< float, 2, true > getVectorsFloat32Copy(cudaStream_t stream)
Performs a copy of the vectors on the given device, converting as needed from float16.
Definition: FlatIndex.cu:91
int getSize() const
Returns the number of vectors we contain.
Definition: FlatIndex.cu:47
faiss::gpu::FlatIndex
Holder of GPU resources for a particular flat index.
Definition: FlatIndex.cuh:23
void reserve(size_t numVecs, cudaStream_t stream)
Reserve storage that can contain at least this many vectors.
Definition: FlatIndex.cu:68
void add(const float *data, int numVecs, cudaStream_t stream)
Add vectors to ourselves; the pointer passed can be on the host or the device.
Definition: FlatIndex.cu:195
faiss::gpu::Tensor
Our tensor type.
Definition: Tensor.cuh:30
Tensor< float, 2, true > & getVectorsFloat32Ref()
Returns a reference to our vectors currently in use.
Definition: FlatIndex.cu:79
void reset()
Free all storage.
Definition: FlatIndex.cu:269