Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
/data/users/matthijs/github_faiss/faiss/IndexLSH.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 
11 #include "IndexLSH.h"
12 
13 #include <cstdio>
14 #include <cstring>
15 
16 #include <algorithm>
17 
18 #include "utils.h"
19 #include "hamming.h"
20 #include "FaissAssert.h"
21 
22 namespace faiss {
23 
24 /***************************************************************
25  * IndexLSH
26  ***************************************************************/
27 
28 
29 IndexLSH::IndexLSH (idx_t d, int nbits, bool rotate_data, bool train_thresholds):
30  Index(d), nbits(nbits), rotate_data(rotate_data),
31  train_thresholds (train_thresholds), rrot(d, nbits)
32 {
33  is_trained = !train_thresholds;
34 
35  bytes_per_vec = (nbits + 7) / 8;
36 
37  if (rotate_data) {
38  rrot.init(5);
39  } else {
40  FAISS_THROW_IF_NOT (d >= nbits);
41  }
42 }
43 
44 IndexLSH::IndexLSH ():
45  nbits (0), bytes_per_vec(0), rotate_data (false), train_thresholds (false)
46 {
47 }
48 
49 
50 const float * IndexLSH::apply_preprocess (idx_t n, const float *x) const
51 {
52 
53  float *xt = nullptr;
54  if (rotate_data) {
55  // also applies bias if exists
56  xt = rrot.apply (n, x);
57  } else if (d != nbits) {
58  xt = new float [nbits * n];
59  float *xp = xt;
60  for (idx_t i = 0; i < n; i++) {
61  const float *xl = x + i * d;
62  for (int j = 0; j < nbits; j++)
63  *xp++ = xl [j];
64  }
65  }
66 
67  if (train_thresholds) {
68 
69  if (xt == NULL) {
70  xt = new float [nbits * n];
71  memcpy (xt, x, sizeof(*x) * n * nbits);
72  }
73 
74  float *xp = xt;
75  for (idx_t i = 0; i < n; i++)
76  for (int j = 0; j < nbits; j++)
77  *xp++ -= thresholds [j];
78  }
79 
80  return xt ? xt : x;
81 }
82 
83 
84 
85 void IndexLSH::train (idx_t n, const float *x)
86 {
87  if (train_thresholds) {
88  thresholds.resize (nbits);
89  train_thresholds = false;
90  const float *xt = apply_preprocess (n, x);
91  ScopeDeleter<float> del (xt == x ? nullptr : xt);
92  train_thresholds = true;
93 
94  float * transposed_x = new float [n * nbits];
95  ScopeDeleter<float> del2 (transposed_x);
96 
97  for (idx_t i = 0; i < n; i++)
98  for (idx_t j = 0; j < nbits; j++)
99  transposed_x [j * n + i] = xt [i * nbits + j];
100 
101  for (idx_t i = 0; i < nbits; i++) {
102  float *xi = transposed_x + i * n;
103  // std::nth_element
104  std::sort (xi, xi + n);
105  if (n % 2 == 1)
106  thresholds [i] = xi [n / 2];
107  else
108  thresholds [i] = (xi [n / 2 - 1] + xi [n / 2]) / 2;
109 
110  }
111  }
112  is_trained = true;
113 }
114 
115 
116 void IndexLSH::add (idx_t n, const float *x)
117 {
118  FAISS_THROW_IF_NOT (is_trained);
119  const float *xt = apply_preprocess (n, x);
120  ScopeDeleter<float> del (xt == x ? nullptr : xt);
121 
122  codes.resize ((ntotal + n) * bytes_per_vec);
123  fvecs2bitvecs (xt, &codes[ntotal * bytes_per_vec], nbits, n);
124  ntotal += n;
125 }
126 
127 
129  idx_t n,
130  const float *x,
131  idx_t k,
132  float *distances,
133  idx_t *labels) const
134 {
135  FAISS_THROW_IF_NOT (is_trained);
136  const float *xt = apply_preprocess (n, x);
137  ScopeDeleter<float> del (xt == x ? nullptr : xt);
138 
139  uint8_t * qcodes = new uint8_t [n * bytes_per_vec];
140  ScopeDeleter<uint8_t> del2 (qcodes);
141 
142  fvecs2bitvecs (xt, qcodes, nbits, n);
143 
144  int * idistances = new int [n * k];
145  ScopeDeleter<int> del3 (idistances);
146 
147  int_maxheap_array_t res = { size_t(n), size_t(k), labels, idistances};
148 
149  hammings_knn (&res, qcodes, codes.data(),
150  ntotal, bytes_per_vec, true);
151 
152 
153  // convert distances to floats
154  for (int i = 0; i < k * n; i++)
155  distances[i] = idistances[i];
156 
157 }
158 
159 
161  if (!train_thresholds) return;
162  FAISS_THROW_IF_NOT (nbits == vt->d_out);
163  if (!vt->have_bias) {
164  vt->b.resize (nbits, 0);
165  vt->have_bias = true;
166  }
167  for (int i = 0; i < nbits; i++)
168  vt->b[i] -= thresholds[i];
169  train_thresholds = false;
170  thresholds.clear();
171 }
172 
174  codes.clear();
175  ntotal = 0;
176 }
177 
178 
179 
180 } // namespace faiss
int bytes_per_vec
number of bytes per encoded vector
Definition: IndexLSH.h:28
std::vector< float > thresholds
thresholds to compare with
Definition: IndexLSH.h:34
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Definition: IndexLSH.cpp:128
int d
vector dimension
Definition: Index.h:64
std::vector< float > b
bias vector, size d_out
RandomRotationMatrix rrot
optional random rotation
Definition: IndexLSH.h:32
void transfer_thresholds(LinearTransform *vt)
Definition: IndexLSH.cpp:160
long idx_t
all indices are this type
Definition: Index.h:62
void hammings_knn(int_maxheap_array_t *ha, const uint8_t *a, const uint8_t *b, size_t nb, size_t ncodes, int order)
Definition: hamming.cpp:474
idx_t ntotal
total nb of indexed vectors
Definition: Index.h:65
void reset() override
removes all elements from the database.
Definition: IndexLSH.cpp:173
void add(idx_t n, const float *x) override
Definition: IndexLSH.cpp:116
void train(idx_t n, const float *x) override
Definition: IndexLSH.cpp:85
int d_out
output dimension
int nbits
nb of bits per vector
Definition: IndexLSH.h:27
const float * apply_preprocess(idx_t n, const float *x) const
Definition: IndexLSH.cpp:50
bool is_trained
set if the Index does not require training, or if training is done already
Definition: Index.h:69
float * apply(idx_t n, const float *x) const
std::vector< uint8_t > codes
encoded dataset
Definition: IndexLSH.h:37