Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
/data/users/hoss/faiss/IndexIVFSpectralHash.cpp
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 // -*- c++ -*-
9 
10 
#include "IndexIVFSpectralHash.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <memory>

#include "hamming.h"
#include "utils.h"
#include "FaissAssert.h"
#include "AuxIndexStructures.h"
#include "VectorTransform.h"
21 
22 namespace faiss {
23 
24 
25 IndexIVFSpectralHash::IndexIVFSpectralHash (
26  Index * quantizer, size_t d, size_t nlist,
27  int nbit, float period):
28  IndexIVF (quantizer, d, nlist, (nbit + 7) / 8, METRIC_L2),
29  nbit (nbit), period (period), threshold_type (Thresh_global)
30 {
31  FAISS_THROW_IF_NOT (code_size % 4 == 0);
32  RandomRotationMatrix *rr = new RandomRotationMatrix (d, nbit);
33  rr->init (1234);
34  vt = rr;
35  own_fields = true;
36  is_trained = false;
37 }
38 
39 IndexIVFSpectralHash::IndexIVFSpectralHash():
40  IndexIVF(), vt(nullptr), own_fields(false),
41  nbit(0), period(0), threshold_type(Thresh_global)
42 {}
43 
44 IndexIVFSpectralHash::~IndexIVFSpectralHash ()
45 {
46  if (own_fields) {
47  delete vt;
48  }
49 }
50 
namespace {


/** Median of n floats.
 *
 * The input buffer is used as scratch space: its elements are reordered
 * (partitioned around the median), not preserved.
 *
 * @param n  number of values
 * @param x  values, size n; reordered in place
 * @return   the middle value for odd n, the mean of the two middle values
 *           for even n, and 0 when n == 0
 */
float median (size_t n, float *x) {
    if (n == 0) {
        // no data: avoid indexing x[n / 2 - 1] with an underflowed index
        return 0.0f;
    }
    float *mid = x + n / 2;
    // selection is enough, a full sort is not needed
    std::nth_element (x, mid, x + n);
    if (n % 2 == 1) {
        return *mid;
    }
    // after nth_element everything left of mid is <= *mid, so the largest
    // of those values is the lower middle element
    const float below = *std::max_element (x, mid);
    return (below + *mid) / 2;
}

} // anonymous namespace
64 
65 
67 {
68  if (!vt->is_trained) {
69  vt->train (n, x);
70  }
71 
72  if (threshold_type == Thresh_global) {
73  // nothing to do
74  return;
75  } else if (threshold_type == Thresh_centroid ||
76  threshold_type == Thresh_centroid_half) {
77  // convert all centroids with vt
78  std::vector<float> centroids (nlist * d);
79  quantizer->reconstruct_n (0, nlist, centroids.data());
80  trained.resize(nlist * nbit);
81  vt->apply_noalloc (nlist, centroids.data(), trained.data());
82  if (threshold_type == Thresh_centroid_half) {
83  for (size_t i = 0; i < nlist * nbit; i++) {
84  trained[i] -= 0.25 * period;
85  }
86  }
87  return;
88  }
89  // otherwise train medians
90 
91  // assign
92  std::unique_ptr<idx_t []> idx (new idx_t [n]);
93  quantizer->assign (n, x, idx.get());
94 
95  std::vector<size_t> sizes(nlist + 1);
96  for (size_t i = 0; i < n; i++) {
97  FAISS_THROW_IF_NOT (idx[i] >= 0);
98  sizes[idx[i]]++;
99  }
100 
101  size_t ofs = 0;
102  for (int j = 0; j < nlist; j++) {
103  size_t o0 = ofs;
104  ofs += sizes[j];
105  sizes[j] = o0;
106  }
107 
108  // transform
109  std::unique_ptr<float []> xt (vt->apply (n, x));
110 
111  // transpose + reorder
112  std::unique_ptr<float []> xo (new float[n * nbit]);
113 
114  for (size_t i = 0; i < n; i++) {
115  size_t idest = sizes[idx[i]]++;
116  for (size_t j = 0; j < nbit; j++) {
117  xo[idest + n * j] = xt[i * nbit + j];
118  }
119  }
120 
121  trained.resize (n * nbit);
122  // compute medians
123 #pragma omp for
124  for (int i = 0; i < nlist; i++) {
125  size_t i0 = i == 0 ? 0 : sizes[i - 1];
126  size_t i1 = sizes[i];
127  for (int j = 0; j < nbit; j++) {
128  float *xoi = xo.get() + i0 + n * j;
129  if (i0 == i1) { // nothing to train
130  trained[i * nbit + j] = 0.0;
131  } else if (i1 == i0 + 1) {
132  trained[i * nbit + j] = xoi[0];
133  } else {
134  trained[i * nbit + j] = median(i1 - i0, xoi);
135  }
136  }
137  }
138 }
139 
140 
namespace {

/** Binarize a vector against per-dimension thresholds.
 *
 * Bit i of the output is the parity of floor((x[i] - c[i]) * freq), i.e.
 * it flips every time the shifted, scaled value crosses an integer.
 * Bits are packed least-significant first, 8 per byte; the
 * (nbit + 7) / 8 output bytes are cleared before being filled.
 *
 * The floored value is converted to int, so (x[i] - c[i]) * freq is
 * assumed to fit in the int range.
 *
 * @param nbit   number of bits to produce (= length of x and c)
 * @param freq   scale applied to the shifted values
 * @param x      input values, size nbit
 * @param c      thresholds subtracted from x, size nbit
 * @param codes  output buffer, (nbit + 7) / 8 bytes
 */
void binarize_with_freq (size_t nbit, float freq,
                         const float *x, const float *c,
                         uint8_t *codes)
{
    std::memset (codes, 0, (nbit + 7) / 8);
    for (size_t i = 0; i < nbit; i++) {
        const float xf = x[i] - c[i];
        const int xi = static_cast<int> (std::floor (xf * freq));
        // lowest bit = parity, also for negative values (two's complement)
        const int bit = xi & 1;
        codes[i >> 3] |= bit << (i & 7);
    }
}


} // anonymous namespace
158 
159 
160 
161 void IndexIVFSpectralHash::encode_vectors(idx_t n, const float* x_in,
162  const idx_t *list_nos,
163  uint8_t * codes) const
164 {
165  FAISS_THROW_IF_NOT (is_trained);
166  float freq = 2.0 / period;
167 
168  // transform with vt
169  std::unique_ptr<float []> x (vt->apply (n, x_in));
170 
171 #pragma omp parallel
172  {
173  std::vector<float> zero (nbit);
174 
175  // each thread takes care of a subset of lists
176 #pragma omp for
177  for (size_t i = 0; i < n; i++) {
178  long list_no = list_nos [i];
179 
180  if (list_no >= 0) {
181  const float *c;
182  if (threshold_type == Thresh_global) {
183  c = zero.data();
184  } else {
185  c = trained.data() + list_no * nbit;
186  }
187  binarize_with_freq (nbit, freq,
188  x.get() + i * nbit, c,
189  codes + i * code_size) ;
190  }
191  }
192  }
193 }
194 
195 namespace {
196 
197 
198 template<class HammingComputer>
199 struct IVFScanner: InvertedListScanner {
200 
201  // copied from index structure
202  const IndexIVFSpectralHash *index;
203  size_t code_size;
204  size_t nbit;
205  bool store_pairs;
206 
207  float period, freq;
208  std::vector<float> q;
209  std::vector<float> zero;
210  std::vector<uint8_t> qcode;
211  HammingComputer hc;
212 
213  using idx_t = Index::idx_t;
214 
215  IVFScanner (const IndexIVFSpectralHash * index,
216  bool store_pairs):
217  index (index),
218  code_size(index->code_size),
219  nbit(index->nbit),
220  store_pairs(store_pairs),
221  period(index->period), freq(2.0 / index->period),
222  q(nbit), zero(nbit), qcode(code_size),
223  hc(qcode.data(), code_size)
224  {
225  }
226 
227 
228  void set_query (const float *query) override {
229  FAISS_THROW_IF_NOT(query);
230  FAISS_THROW_IF_NOT(q.size() == nbit);
231  index->vt->apply_noalloc (1, query, q.data());
232 
233  if (index->threshold_type ==
234  IndexIVFSpectralHash::Thresh_global) {
235  binarize_with_freq
236  (nbit, freq, q.data(), zero.data(), qcode.data());
237  hc.set (qcode.data(), code_size);
238  }
239  }
240 
241  idx_t list_no;
242 
243  void set_list (idx_t list_no, float /*coarse_dis*/) override {
244  this->list_no = list_no;
245  if (index->threshold_type != IndexIVFSpectralHash::Thresh_global) {
246  const float *c = index->trained.data() + list_no * nbit;
247  binarize_with_freq (nbit, freq, q.data(), c, qcode.data());
248  hc.set (qcode.data(), code_size);
249  }
250  }
251 
252  float distance_to_code (const uint8_t *code) const final {
253  return hc.hamming (code);
254  }
255 
256  size_t scan_codes (size_t list_size,
257  const uint8_t *codes,
258  const idx_t *ids,
259  float *simi, idx_t *idxi,
260  size_t k) const override
261  {
262  size_t nup = 0;
263  for (size_t j = 0; j < list_size; j++) {
264 
265  float dis = hc.hamming (codes);
266 
267  if (dis < simi [0]) {
268  maxheap_pop (k, simi, idxi);
269  long id = store_pairs ? (list_no << 32 | j) : ids[j];
270  maxheap_push (k, simi, idxi, dis, id);
271  nup++;
272  }
273  codes += code_size;
274  }
275  return nup;
276  }
277 
278  void scan_codes_range (size_t list_size,
279  const uint8_t *codes,
280  const idx_t *ids,
281  float radius,
282  RangeQueryResult & res) const override
283  {
284  for (size_t j = 0; j < list_size; j++) {
285  float dis = hc.hamming (codes);
286  if (dis < radius) {
287  long id = store_pairs ? (list_no << 32 | j) : ids[j];
288  res.add (dis, id);
289  }
290  codes += code_size;
291  }
292  }
293 
294 
295 };
296 
297 } // anonymous namespace
298 
300  (bool store_pairs) const
301 {
302  switch (code_size) {
303 #define HANDLE_CODE_SIZE(cs) \
304  case cs: \
305  return new IVFScanner<HammingComputer ## cs> (this, store_pairs)
306  HANDLE_CODE_SIZE(4);
307  HANDLE_CODE_SIZE(8);
308  HANDLE_CODE_SIZE(16);
309  HANDLE_CODE_SIZE(20);
310  HANDLE_CODE_SIZE(32);
311  HANDLE_CODE_SIZE(64);
312 #undef HANDLE_CODE_SIZE
313  default:
314  if (code_size % 8 == 0) {
315  return new IVFScanner<HammingComputerM8>(this, store_pairs);
316  } else if (code_size % 4 == 0) {
317  return new IVFScanner<HammingComputerM4>(this, store_pairs);
318  } else {
319  FAISS_THROW_MSG("not supported");
320  }
321  }
322 
323 }
324 
325 
326 
327 } // namespace faiss
size_t code_size
code_size_1 + code_size_2
Definition: IndexIVFPQ.h:221
InvertedListScanner * get_InvertedListScanner(bool store_pairs) const override
get a scanner for this index (store_pairs means ignore labels)
long idx_t
all indices are this type
Definition: Index.h:62
void train_residual(idx_t n, const float *x) override
virtual void train(idx_t n, const float *x)
void encode_vectors(idx_t n, const float *x, const idx_t *list_nos, uint8_t *codes) const override
bool is_trained
set if the Index does not require training, or if training is done already
Definition: Index.h:71
float * apply(idx_t n, const float *x) const
size_t nlist
number of possible key values
Definition: IndexIVF.h:33
virtual void apply_noalloc(idx_t n, const float *x, float *xt) const =0
same as apply, but result is pre-allocated
size_t code_size
code size per vector in bytes
Definition: IndexIVF.h:95