Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
/tmp/faiss/IndexBinaryFromFloat.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // -*- c++ -*-
10 
#include "IndexBinaryFromFloat.h"

#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include "utils.h"
15 
16 namespace faiss {
17 
18 
19 IndexBinaryFromFloat::IndexBinaryFromFloat() {}
20 
21 IndexBinaryFromFloat::IndexBinaryFromFloat(Index *index)
22  : IndexBinary(index->d),
23  index(index),
24  own_fields(false) {
25  is_trained = index->is_trained;
26  ntotal = index->ntotal;
27 }
28 
29 IndexBinaryFromFloat::~IndexBinaryFromFloat() {
30  if (own_fields) {
31  delete index;
32  }
33 }
34 
35 void IndexBinaryFromFloat::add(idx_t n, const uint8_t *x) {
36  constexpr idx_t bs = 32768;
37  std::unique_ptr<float[]> xf(new float[bs * d]);
38 
39  for (idx_t b = 0; b < n; b += bs) {
40  idx_t bn = std::min(bs, n - b);
41  binary_to_real(bn * d, x + b * code_size, xf.get());
42 
43  index->add(bn, xf.get());
44  }
45  ntotal = index->ntotal;
46 }
47 
49  index->reset();
50  ntotal = index->ntotal;
51 }
52 
53 void IndexBinaryFromFloat::search(idx_t n, const uint8_t *x, idx_t k,
54  int32_t *distances, idx_t *labels) const {
55  constexpr idx_t bs = 32768;
56  std::unique_ptr<float[]> xf(new float[bs * d]);
57  std::unique_ptr<float[]> df(new float[bs * k]);
58 
59  for (idx_t b = 0; b < n; b += bs) {
60  idx_t bn = std::min(bs, n - b);
61  binary_to_real(bn * d, x + b * code_size, xf.get());
62 
63  index->search(bn, xf.get(), k, df.get(), labels + b * k);
64  for (int i = 0; i < bn * k; ++i) {
65  distances[b * k + i] = int32_t(std::round(df[i] / 4.0));
66  }
67  }
68 }
69 
70 void IndexBinaryFromFloat::train(idx_t n, const uint8_t *x) {
71  std::unique_ptr<float[]> xf(new float[n * d]);
72  binary_to_real(n * d, x, xf.get());
73 
74  index->train(n, xf.get());
75  is_trained = true;
76  ntotal = index->ntotal;
77 }
78 
79 } // namespace faiss
void add(idx_t n, const uint8_t *x) override
virtual void reset()=0
Removes all elements from the database.
bool own_fields
Whether object owns the index pointer.
virtual void train(idx_t n, const float *x)
Definition: Index.cpp:24
bool is_trained
Set if the index does not require training, or if training has already been done.
Definition: IndexBinary.h:46
int code_size
Number of bytes per vector (= d / 8).
Definition: IndexBinary.h:41
int d
vector dimension
Definition: IndexBinary.h:40
void search(idx_t n, const uint8_t *x, idx_t k, int32_t *distances, idx_t *labels) const override
virtual void add(idx_t n, const float *x)=0
idx_t ntotal
Total number of indexed vectors.
Definition: Index.h:67
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const =0
void reset() override
Removes all elements from the database.
idx_t ntotal
Total number of indexed vectors.
Definition: IndexBinary.h:42
long idx_t
All indices are of this type.
Definition: IndexBinary.h:38
void binary_to_real(size_t d, const uint8_t *x_in, float *x_out)
Definition: utils.cpp:1552
void train(idx_t n, const uint8_t *x) override