Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
/tmp/faiss/IndexLSH.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // -*- c++ -*-
10 
11 #include "IndexLSH.h"
12 
13 #include <cstdio>
14 #include <cstring>
15 
16 #include <algorithm>
17 
18 #include "utils.h"
19 #include "hamming.h"
20 #include "FaissAssert.h"
21 
22 
23 namespace faiss {
24 
25 /***************************************************************
26  * IndexLSH
27  ***************************************************************/
28 
29 
30 IndexLSH::IndexLSH (idx_t d, int nbits, bool rotate_data, bool train_thresholds):
31  Index(d), nbits(nbits), rotate_data(rotate_data),
32  train_thresholds (train_thresholds), rrot(d, nbits)
33 {
34  is_trained = !train_thresholds;
35 
36  bytes_per_vec = (nbits + 7) / 8;
37 
38  if (rotate_data) {
39  rrot.init(5);
40  } else {
41  FAISS_THROW_IF_NOT (d >= nbits);
42  }
43 }
44 
45 IndexLSH::IndexLSH ():
46  nbits (0), bytes_per_vec(0), rotate_data (false), train_thresholds (false)
47 {
48 }
49 
50 
51 const float * IndexLSH::apply_preprocess (idx_t n, const float *x) const
52 {
53 
54  float *xt = nullptr;
55  if (rotate_data) {
56  // also applies bias if exists
57  xt = rrot.apply (n, x);
58  } else if (d != nbits) {
59  xt = new float [nbits * n];
60  float *xp = xt;
61  for (idx_t i = 0; i < n; i++) {
62  const float *xl = x + i * d;
63  for (int j = 0; j < nbits; j++)
64  *xp++ = xl [j];
65  }
66  }
67 
68  if (train_thresholds) {
69 
70  if (xt == NULL) {
71  xt = new float [nbits * n];
72  memcpy (xt, x, sizeof(*x) * n * nbits);
73  }
74 
75  float *xp = xt;
76  for (idx_t i = 0; i < n; i++)
77  for (int j = 0; j < nbits; j++)
78  *xp++ -= thresholds [j];
79  }
80 
81  return xt ? xt : x;
82 }
83 
84 
85 
86 void IndexLSH::train (idx_t n, const float *x)
87 {
88  if (train_thresholds) {
89  thresholds.resize (nbits);
90  train_thresholds = false;
91  const float *xt = apply_preprocess (n, x);
92  ScopeDeleter<float> del (xt == x ? nullptr : xt);
93  train_thresholds = true;
94 
95  float * transposed_x = new float [n * nbits];
96  ScopeDeleter<float> del2 (transposed_x);
97 
98  for (idx_t i = 0; i < n; i++)
99  for (idx_t j = 0; j < nbits; j++)
100  transposed_x [j * n + i] = xt [i * nbits + j];
101 
102  for (idx_t i = 0; i < nbits; i++) {
103  float *xi = transposed_x + i * n;
104  // std::nth_element
105  std::sort (xi, xi + n);
106  if (n % 2 == 1)
107  thresholds [i] = xi [n / 2];
108  else
109  thresholds [i] = (xi [n / 2 - 1] + xi [n / 2]) / 2;
110 
111  }
112  }
113  is_trained = true;
114 }
115 
116 
117 void IndexLSH::add (idx_t n, const float *x)
118 {
119  FAISS_THROW_IF_NOT (is_trained);
120  const float *xt = apply_preprocess (n, x);
121  ScopeDeleter<float> del (xt == x ? nullptr : xt);
122 
123  codes.resize ((ntotal + n) * bytes_per_vec);
124  fvecs2bitvecs (xt, &codes[ntotal * bytes_per_vec], nbits, n);
125  ntotal += n;
126 }
127 
128 
130  idx_t n,
131  const float *x,
132  idx_t k,
133  float *distances,
134  idx_t *labels) const
135 {
136  FAISS_THROW_IF_NOT (is_trained);
137  const float *xt = apply_preprocess (n, x);
138  ScopeDeleter<float> del (xt == x ? nullptr : xt);
139 
140  uint8_t * qcodes = new uint8_t [n * bytes_per_vec];
141  ScopeDeleter<uint8_t> del2 (qcodes);
142 
143  fvecs2bitvecs (xt, qcodes, nbits, n);
144 
145  int * idistances = new int [n * k];
146  ScopeDeleter<int> del3 (idistances);
147 
148  int_maxheap_array_t res = { size_t(n), size_t(k), labels, idistances};
149 
150  hammings_knn_hc (&res, qcodes, codes.data(),
151  ntotal, bytes_per_vec, true);
152 
153 
154  // convert distances to floats
155  for (int i = 0; i < k * n; i++)
156  distances[i] = idistances[i];
157 
158 }
159 
160 
162  if (!train_thresholds) return;
163  FAISS_THROW_IF_NOT (nbits == vt->d_out);
164  if (!vt->have_bias) {
165  vt->b.resize (nbits, 0);
166  vt->have_bias = true;
167  }
168  for (int i = 0; i < nbits; i++)
169  vt->b[i] -= thresholds[i];
170  train_thresholds = false;
171  thresholds.clear();
172 }
173 
175  codes.clear();
176  ntotal = 0;
177 }
178 
179 
180 } // namespace faiss
void hammings_knn_hc(int_maxheap_array_t *ha, const uint8_t *a, const uint8_t *b, size_t nb, size_t ncodes, int order)
Definition: hamming.cpp:518
int bytes_per_vec
number of bytes (8-bit units) per encoded vector
Definition: IndexLSH.h:27
std::vector< float > thresholds
thresholds to compare with
Definition: IndexLSH.h:33
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Definition: IndexLSH.cpp:129
int d
vector dimension
Definition: Index.h:66
std::vector< float > b
bias vector, size d_out
RandomRotationMatrix rrot
optional random rotation
Definition: IndexLSH.h:31
void transfer_thresholds(LinearTransform *vt)
Definition: IndexLSH.cpp:161
long idx_t
all indices are this type
Definition: Index.h:64
idx_t ntotal
total nb of indexed vectors
Definition: Index.h:67
void reset() override
removes all elements from the database.
Definition: IndexLSH.cpp:174
void add(idx_t n, const float *x) override
Definition: IndexLSH.cpp:117
void train(idx_t n, const float *x) override
Definition: IndexLSH.cpp:86
int d_out
output dimension
int nbits
nb of bits per vector
Definition: IndexLSH.h:26
const float * apply_preprocess(idx_t n, const float *x) const
Definition: IndexLSH.cpp:51
bool is_trained
set if the Index does not require training, or if training is done already
Definition: Index.h:71
float * apply(idx_t n, const float *x) const
std::vector< uint8_t > codes
encoded dataset
Definition: IndexLSH.h:36