/** * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ // -*- c++ -*- #include #include #include namespace faiss { IndexBinaryFromFloat::IndexBinaryFromFloat() {} IndexBinaryFromFloat::IndexBinaryFromFloat(Index *index) : IndexBinary(index->d), index(index), own_fields(false) { is_trained = index->is_trained; ntotal = index->ntotal; } IndexBinaryFromFloat::~IndexBinaryFromFloat() { if (own_fields) { delete index; } } void IndexBinaryFromFloat::add(idx_t n, const uint8_t *x) { constexpr idx_t bs = 32768; std::unique_ptr xf(new float[bs * d]); for (idx_t b = 0; b < n; b += bs) { idx_t bn = std::min(bs, n - b); binary_to_real(bn * d, x + b * code_size, xf.get()); index->add(bn, xf.get()); } ntotal = index->ntotal; } void IndexBinaryFromFloat::reset() { index->reset(); ntotal = index->ntotal; } void IndexBinaryFromFloat::search(idx_t n, const uint8_t *x, idx_t k, int32_t *distances, idx_t *labels) const { constexpr idx_t bs = 32768; std::unique_ptr xf(new float[bs * d]); std::unique_ptr df(new float[bs * k]); for (idx_t b = 0; b < n; b += bs) { idx_t bn = std::min(bs, n - b); binary_to_real(bn * d, x + b * code_size, xf.get()); index->search(bn, xf.get(), k, df.get(), labels + b * k); for (int i = 0; i < bn * k; ++i) { distances[b * k + i] = int32_t(std::round(df[i] / 4.0)); } } } void IndexBinaryFromFloat::train(idx_t n, const uint8_t *x) { std::unique_ptr xf(new float[n * d]); binary_to_real(n * d, x, xf.get()); index->train(n, xf.get()); is_trained = true; ntotal = index->ntotal; } } // namespace faiss