Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
/data/users/hoss/faiss/IndexBinaryFromFloat.cpp
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 // -*- c++ -*-
9 
#include "IndexBinaryFromFloat.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <memory>
#include "utils.h"
14 
15 namespace faiss {
16 
17 
18 IndexBinaryFromFloat::IndexBinaryFromFloat() {}
19 
20 IndexBinaryFromFloat::IndexBinaryFromFloat(Index *index)
21  : IndexBinary(index->d),
22  index(index),
23  own_fields(false) {
24  is_trained = index->is_trained;
25  ntotal = index->ntotal;
26 }
27 
28 IndexBinaryFromFloat::~IndexBinaryFromFloat() {
29  if (own_fields) {
30  delete index;
31  }
32 }
33 
34 void IndexBinaryFromFloat::add(idx_t n, const uint8_t *x) {
35  constexpr idx_t bs = 32768;
36  std::unique_ptr<float[]> xf(new float[bs * d]);
37 
38  for (idx_t b = 0; b < n; b += bs) {
39  idx_t bn = std::min(bs, n - b);
40  binary_to_real(bn * d, x + b * code_size, xf.get());
41 
42  index->add(bn, xf.get());
43  }
44  ntotal = index->ntotal;
45 }
46 
48  index->reset();
49  ntotal = index->ntotal;
50 }
51 
52 void IndexBinaryFromFloat::search(idx_t n, const uint8_t *x, idx_t k,
53  int32_t *distances, idx_t *labels) const {
54  constexpr idx_t bs = 32768;
55  std::unique_ptr<float[]> xf(new float[bs * d]);
56  std::unique_ptr<float[]> df(new float[bs * k]);
57 
58  for (idx_t b = 0; b < n; b += bs) {
59  idx_t bn = std::min(bs, n - b);
60  binary_to_real(bn * d, x + b * code_size, xf.get());
61 
62  index->search(bn, xf.get(), k, df.get(), labels + b * k);
63  for (int i = 0; i < bn * k; ++i) {
64  distances[b * k + i] = int32_t(std::round(df[i] / 4.0));
65  }
66  }
67 }
68 
69 void IndexBinaryFromFloat::train(idx_t n, const uint8_t *x) {
70  std::unique_ptr<float[]> xf(new float[n * d]);
71  binary_to_real(n * d, x, xf.get());
72 
73  index->train(n, xf.get());
74  is_trained = true;
75  ntotal = index->ntotal;
76 }
77 
78 } // namespace faiss
void add(idx_t n, const uint8_t *x) override
virtual void reset()=0
Removes all elements from the database.
bool own_fields
Whether object owns the index pointer.
virtual void train(idx_t n, const float *x)
Definition: Index.cpp:23
bool is_trained
Set if the index does not require training, or if it has already been trained.
Definition: IndexBinary.h:47
int code_size
Number of bytes per vector (= d / 8).
Definition: IndexBinary.h:42
Index::idx_t idx_t
all indices are this type
Definition: IndexBinary.h:37
int d
vector dimension
Definition: IndexBinary.h:41
void search(idx_t n, const uint8_t *x, idx_t k, int32_t *distances, idx_t *labels) const override
virtual void add(idx_t n, const float *x)=0
idx_t ntotal
Total number of indexed vectors.
Definition: Index.h:67
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const =0
void reset() override
Removes all elements from the database.
idx_t ntotal
Total number of indexed vectors.
Definition: IndexBinary.h:43
void binary_to_real(size_t d, const uint8_t *x_in, float *x_out)
Definition: utils.cpp:1564
void train(idx_t n, const uint8_t *x) override