21 #include "FaissAssert.h"
30 IndexLSH::IndexLSH (idx_t d,
int nbits,
bool rotate_data,
bool train_thresholds):
31 Index(d), nbits(nbits), rotate_data(rotate_data),
32 train_thresholds (train_thresholds), rrot(d, nbits)
36 bytes_per_vec = (nbits + 7) / 8;
41 FAISS_ASSERT(d >= nbits);
46 IndexLSH::IndexLSH ():
47 nbits (0), bytes_per_vec(0), rotate_data (false), train_thresholds (false)
52 void IndexLSH::set_typename()
55 s <<
"LSH_" << nbits << (rotate_data ?
"r" :
"");
56 index_typename = s.str();
66 }
else if (d != nbits) {
67 xt =
new float [nbits * n];
69 for (
idx_t i = 0; i < n; i++) {
70 const float *xl = x + i *
d;
71 for (
int j = 0; j <
nbits; j++)
76 if (train_thresholds) {
79 xt =
new float [nbits * n];
80 memcpy (xt, x,
sizeof(*x) * n * nbits);
84 for (
idx_t i = 0; i < n; i++)
85 for (
int j = 0; j <
nbits; j++)
96 if (train_thresholds) {
98 train_thresholds =
false;
100 train_thresholds =
true;
102 float * transposed_x =
new float [n *
nbits];
104 for (
idx_t i = 0; i < n; i++)
106 transposed_x [j * n + i] = xt [i * nbits + j];
107 if (xt != x)
delete [] xt;
110 float *xi = transposed_x + i * n;
112 std::sort (xi, xi + n);
116 thresholds [i] = (xi [n / 2 - 1] + xi [n / 2]) / 2;
130 fvecs2bitvecs (xt, &
codes[
ntotal * bytes_per_vec], nbits, n);
148 fvecs2bitvecs (xt, qcodes, nbits, n);
153 int * idistances =
new int [n * k];
162 for (
int i = 0; i < k * n; i++)
163 distances[i] = idistances[i];
164 delete [] idistances;
170 if (!train_thresholds)
return;
171 FAISS_ASSERT (nbits == vt->
d_out);
172 if (!vt->have_bias) {
173 vt->
b.resize (nbits, 0);
174 vt->have_bias =
true;
176 for (
int i = 0; i <
nbits; i++)
178 train_thresholds =
false;
int bytes_per_vec
nb of 8-bits per encoded vector
std::vector< float > thresholds
thresholds to compare with
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
RandomRotationMatrix rrot
optional random rotation
void transfer_thresholds(LinearTransform *vt)
long idx_t
all indices are this type
void hammings_knn(int_maxheap_array_t *ha, const uint8_t *a, const uint8_t *b, size_t nb, size_t ncodes, int order)
idx_t ntotal
total nb of indexed vectors
virtual void reset() override
removes all elements from the database.
virtual void add(idx_t n, const float *x) override
virtual void train(idx_t n, const float *x) override
int nbits
nb of bits per vector
const float * apply_preprocess(idx_t n, const float *x) const
bool is_trained
set if the Index does not require training, or if training is done already
std::vector< uint8_t > codes
encoded dataset