Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
FlatIndex.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 
11 #pragma once
12 
13 #include "../utils/DeviceTensor.cuh"
14 #include "../utils/DeviceVector.cuh"
15 #include "../utils/Float16.cuh"
16 #include "../utils/MemorySpace.h"
17 
18 namespace faiss { namespace gpu {
19 
20 class GpuResources;
21 
22 /// Holder of GPU resources for a particular flat index
///
/// The visible declaration offers size/dimension accessors, storage
/// management (reserve / add / reset), a float query overload plus a half
/// overload that exists only when FAISS_USE_FLOAT16 is defined, and private
/// members typed DeviceVector / DeviceTensor for the raw data, the vector
/// tensors and the precomputed L2 norms.
23 class FlatIndex {
24  public:
  // NOTE(review): listing line 25 is absent from this extract. The parameter
  // list below presumably belongs to the FlatIndex constructor; its opening
  // line (name and any leading parameter) is not visible here -- restore it
  // from the original FlatIndex.cuh.
26  int dim,
27  bool l2Distance,
28  bool useFloat16,
29  bool useFloat16Accumulator,
30  bool storeTransposed,
31  MemorySpace space);
32 
  // NOTE(review): undocumented in the original; presumably reports
  // useFloat16_ -- confirm in FlatIndex.cu.
33  bool getUseFloat16() const;
34 
35  /// Returns the number of vectors we contain
36  int getSize() const;
37 
  // NOTE(review): undocumented in the original; presumably reports dim_ --
  // confirm in FlatIndex.cu.
38  int getDim() const;
39 
40  /// Reserve storage that can contain at least this many vectors
41  void reserve(size_t numVecs, cudaStream_t stream);
42 
43  /// Returns a reference to our vectors currently in use
  // NOTE(review): listing line 44 (the declaration this comment documents) is
  // absent from this extract. The reference list at the end of this page
  // pairs the same description with
  // `Tensor<float, 2, true>& getVectorsFloat32Ref()`.
45 
46 #ifdef FAISS_USE_FLOAT16
47  /// Returns a reference to our vectors currently in use (useFloat16 mode)
48  Tensor<half, 2, true>& getVectorsFloat16Ref();
49 #endif
50 
51  /// Performs a copy of the vectors on the given device, converting
52  /// as needed from float16
  // NOTE(review): listing line 53 (the declaration this comment documents) is
  // absent from this extract. The reference list at the end of this page
  // shows `DeviceTensor<float, 2, true> getVectorsFloat32Copy(cudaStream_t
  // stream)`, which is presumably the missing declaration -- confirm.
54 
55  /// Returns only a subset of the vectors
  // NOTE(review): listing line 56 is absent from this extract, so the return
  // type, name and leading parameter(s) of this declaration are not visible;
  // only the trailing `int num, cudaStream_t stream` parameters remain.
57  int num,
58  cudaStream_t stream);
59 
  // NOTE(review): undocumented in the original. Presumably searches the
  // stored vectors for the k best matches of each row of `vecs`, writing
  // results into outDistances / outIndices -- confirm in FlatIndex.cu.
  // tileSize defaults to -1; what -1 selects is not visible in this header.
60  void query(Tensor<float, 2, true>& vecs,
61  int k,
62  Tensor<float, 2, true>& outDistances,
63  Tensor<int, 2, true>& outIndices,
64  bool exactDistance,
65  int tileSize = -1);
66 
67 #ifdef FAISS_USE_FLOAT16
  // Overload taking half-typed queries and distances; declared only when
  // FAISS_USE_FLOAT16 is defined. Same parameter order and tileSize default
  // as the float overload above.
68  void query(Tensor<half, 2, true>& vecs,
69  int k,
70  Tensor<half, 2, true>& outDistances,
71  Tensor<int, 2, true>& outIndices,
72  bool exactDistance,
73  int tileSize = -1);
74 #endif
75 
76  /// Add vectors to ourselves; the pointer passed can be on the host
77  /// or the device
78  void add(const float* data, int numVecs, cudaStream_t stream);
79 
80  /// Free all storage
81  void reset();
82 
83  private:
84  /// Collection of GPU resources that we use
85  GpuResources* resources_;
86 
87  /// Dimensionality of our vectors
88  const int dim_;
89 
90  /// Float16 data format
91  const bool useFloat16_;
92 
93  /// For supporting hardware, whether or not we use Hgemm
94  const bool useFloat16Accumulator_;
95 
96  /// Store vectors in transposed layout for speed; makes addition to
97  /// the index slower
98  const bool storeTransposed_;
99 
100  /// L2 or inner product distance?
101  bool l2Distance_;
102 
103  /// Memory space for our allocations
104  MemorySpace space_;
105 
106  /// How many vectors we have
107  int num_;
108 
109  /// The underlying expandable storage
110  DeviceVector<char> rawData_;
111 
112  /// Vectors currently in rawData_
  // NOTE(review): listing line 113 is absent from this extract; presumably it
  // declared the non-transposed float tensor that parallels vectorsHalf_
  // below -- restore it from the original FlatIndex.cuh.
114  DeviceTensor<float, 2, true> vectorsTransposed_;
115 
116 #ifdef FAISS_USE_FLOAT16
117  /// Vectors currently in rawData_, float16 form
118  DeviceTensor<half, 2, true> vectorsHalf_;
119  DeviceTensor<half, 2, true> vectorsHalfTransposed_;
120 #endif
121 
122  /// Precomputed L2 norms
  // NOTE(review): listing line 123 (the member this comment documents) is
  // absent from this extract; presumably the float counterpart of normsHalf_
  // below -- restore it from the original FlatIndex.cuh.
124 
125 #ifdef FAISS_USE_FLOAT16
126  /// Precomputed L2 norms, float16 form
127  DeviceTensor<half, 1, true> normsHalf_;
128 #endif
129 };
130 
131 } } // namespace
DeviceTensor< float, 2, true > getVectorsFloat32Copy(cudaStream_t stream)
Definition: FlatIndex.cu:91
int getSize() const
Returns the number of vectors we contain.
Definition: FlatIndex.cu:47
faiss::gpu::FlatIndex
Holder of GPU resources for a particular flat index.
Definition: FlatIndex.cuh:23
void reserve(size_t numVecs, cudaStream_t stream)
Reserve storage that can contain at least this many vectors.
Definition: FlatIndex.cu:68
void add(const float *data, int numVecs, cudaStream_t stream)
Definition: FlatIndex.cu:201
Tensor
Our tensor type.
Definition: Tensor.cuh:30
Tensor< float, 2, true > & getVectorsFloat32Ref()
Returns a reference to our vectors currently in use.
Definition: FlatIndex.cu:79
void reset()
Free all storage.
Definition: FlatIndex.cu:273