faiss/IndexLSH.cpp

190 lines
4.2 KiB
C++
Raw Normal View History

2017-02-23 06:26:44 +08:00
/**
* Copyright (c) 2015-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the CC-by-NC license found in the
* LICENSE file in the root directory of this source tree.
*/
// Copyright 2004-present Facebook. All Rights Reserved.
#include "IndexLSH.h"
#include <cstdio>
#include <cstring>
#include <algorithm>
#include "utils.h"
#include "hamming.h"
#include "FaissAssert.h"
namespace faiss {
/***************************************************************
* IndexLSH
***************************************************************/
/** Construct an LSH index.
 *
 * @param d                 dimension of the input vectors (forwarded to Index)
 * @param nbits             number of bits in each stored code
 * @param rotate_data       when true, the rrot transform is initialized and
 *                          later applied to the vectors before binarization
 * @param train_thresholds  when true, the index starts untrained and
 *                          train() must compute per-bit thresholds
 */
IndexLSH::IndexLSH (idx_t d, int nbits, bool rotate_data, bool train_thresholds):
    Index(d),
    nbits(nbits),
    rotate_data(rotate_data),
    train_thresholds (train_thresholds),
    rrot(d, nbits)
{
    // nothing to learn unless thresholds were requested
    is_trained = !train_thresholds;

    // number of whole bytes needed to hold nbits bits (rounded up)
    bytes_per_vec = (nbits + 7) / 8;

    if (!rotate_data) {
        // without the rotation, apply_preprocess keeps the first nbits
        // components of each vector, so the input must have at least nbits
        FAISS_ASSERT(d >= nbits);
    } else {
        // NOTE(review): 5 is presumably a fixed random seed -- confirm
        // against the declaration of rrot's type
        rrot.init(5);
    }

    set_typename();
}
/// Default constructor: zero-bit codes, no rotation, no threshold training.
/// Members not listed here are left to their own default initialization.
IndexLSH::IndexLSH ():
    nbits (0),
    bytes_per_vec(0),
    rotate_data (false),
    train_thresholds (false)
{}
/// Rebuild index_typename as "LSH_<nbits>", followed by "r" when
/// rotate_data is set.
void IndexLSH::set_typename()
{
    std::stringstream name;
    name << "LSH_" << nbits;
    if (rotate_data) {
        name << "r";
    }
    index_typename = name.str();
}
/** Prepare n input vectors for binarization.
 *
 * Depending on the index settings this either applies rrot, keeps only the
 * first nbits components of every vector, or leaves the data as is; when
 * train_thresholds is set, thresholds[j] is then subtracted from component j.
 *
 * @param n  number of vectors
 * @param x  input vectors, read as n rows of d floats
 * @return   x itself when no transformation was needed; otherwise a buffer
 *           of n * nbits floats that the caller must release with delete[]
 *           (callers in this file compare the result with x to decide).
 */
const float * IndexLSH::apply_preprocess (idx_t n, const float *x) const
{
    float *out = nullptr;

    if (rotate_data) {
        // also applies bias if exists
        out = rrot.apply (n, x);
    } else if (d != nbits) {
        // truncate every row to its first nbits components
        out = new float [nbits * n];
        for (idx_t row = 0; row < n; row++) {
            const float *src = x + row * d;
            float *dst = out + row * nbits;
            for (int j = 0; j < nbits; j++)
                dst [j] = src [j];
        }
    }

    if (train_thresholds) {
        if (!out) {
            // the input must not be modified: work on a private copy
            out = new float [nbits * n];
            memcpy (out, x, sizeof(*x) * n * nbits);
        }
        for (idx_t row = 0; row < n; row++) {
            float *dst = out + row * nbits;
            for (int j = 0; j < nbits; j++)
                dst [j] -= thresholds [j];
        }
    }

    if (out == nullptr)
        return x;
    return out;
}
/** Train the index.
 *
 * When train_thresholds is set, the threshold of each bit is set to the
 * median of the corresponding preprocessed component over the n training
 * vectors (mean of the two middle values when n is even). In every case
 * the index is marked as trained on return.
 *
 * @param n  number of training vectors; must be > 0 when thresholds are
 *           trained, since a median of zero samples is undefined
 * @param x  training vectors, read as n rows of d floats
 */
void IndexLSH::train (idx_t n, const float *x)
{
    if (train_thresholds) {
        // the median computation below reads xi[n / 2] (and xi[n / 2 - 1]),
        // which is out of bounds for an empty training set
        FAISS_ASSERT (n > 0);

        thresholds.resize (nbits);

        // temporarily disable the flag so that apply_preprocess does not
        // subtract the (not yet computed) thresholds
        train_thresholds = false;
        const float *xt = apply_preprocess (n, x);
        train_thresholds = true;

        // transpose so that the n values of each component are contiguous
        float * transposed_x = new float [n * nbits];
        for (idx_t i = 0; i < n; i++)
            for (idx_t j = 0; j < nbits; j++)
                transposed_x [j * n + i] = xt [i * nbits + j];

        // apply_preprocess hands back a fresh buffer unless it returned x
        if (xt != x) delete [] xt;

        for (idx_t i = 0; i < nbits; i++) {
            float *xi = transposed_x + i * n;
            // std::nth_element
            std::sort (xi, xi + n);
            if (n % 2 == 1)
                thresholds [i] = xi [n / 2];
            else
                thresholds [i] = (xi [n / 2 - 1] + xi [n / 2]) / 2;
        }

        // release the scratch buffer (it was previously leaked)
        delete [] transposed_x;
    }
    is_trained = true;
}
/** Encode n vectors and append their codes to the index.
 *
 * @param n  number of vectors to add
 * @param x  input vectors, read as n rows of d floats
 *
 * Asserts that the index is trained.
 */
void IndexLSH::add (idx_t n, const float *x)
{
    FAISS_ASSERT (is_trained);

    const float *preprocessed = apply_preprocess (n, x);

    // grow the code storage, then write the new codes after the existing ones
    codes.resize ((ntotal + n) * bytes_per_vec);
    uint8_t *dest = &codes[ntotal * bytes_per_vec];
    fvecs2bitvecs (preprocessed, dest, nbits, n);
    ntotal += n;

    // apply_preprocess returns x itself when it did not allocate
    if (preprocessed != x) {
        delete [] preprocessed;
    }
}
/** Search the k nearest stored codes for each of the n query vectors.
 *
 * The queries are preprocessed and binarized the same way as in add(),
 * then compared against the stored codes by hammings_knn.
 *
 * @param n          number of query vectors
 * @param x          query vectors, read as n rows of d floats
 * @param k          number of results per query
 * @param distances  output, n * k entries: the integer distances produced
 *                   by hammings_knn, converted to float
 * @param labels     output, n * k entries, filled by hammings_knn
 *
 * Asserts that the index is trained.
 */
void IndexLSH::search (
        idx_t n,
        const float *x,
        idx_t k,
        float *distances,
        idx_t *labels) const
{
    FAISS_ASSERT (is_trained);
    const float *xt = apply_preprocess (n, x);

    // binarize the queries
    uint8_t * qcodes = new uint8_t [n * bytes_per_vec];
    fvecs2bitvecs (xt, qcodes, nbits, n);

    // apply_preprocess returns x itself when it did not allocate
    if (x != xt)
        delete [] xt;

    int * idistances = new int [n * k];
    int_maxheap_array_t res = { size_t(n), size_t(k), labels, idistances};

    // NOTE(review): the trailing `true` presumably asks for sorted results
    // -- confirm against the declaration of hammings_knn
    hammings_knn (&res, qcodes, codes.data(),
                  ntotal, bytes_per_vec, true);

    delete [] qcodes;

    // convert distances to floats; the counter has the same type as n * k
    // so that it cannot overflow when the product exceeds the range of int
    const idx_t n_results = n * k;
    for (idx_t i = 0; i < n_results; i++)
        distances[i] = idistances[i];

    delete [] idistances;
}
/** Move the trained thresholds into the bias of a linear transform.
 *
 * Does nothing when train_thresholds is not set. Otherwise thresholds[i]
 * is subtracted from vt->b[i] (a zero bias is created first if vt has
 * none), after which this index stops applying thresholds itself.
 *
 * @param vt  transform receiving the thresholds; its d_out must equal nbits
 */
void IndexLSH::transfer_thresholds (LinearTransform *vt) {
    if (train_thresholds) {
        FAISS_ASSERT (nbits == vt->d_out);

        if (!vt->have_bias) {
            // start from an all-zero bias
            vt->b.resize (nbits, 0);
            vt->have_bias = true;
        }

        for (int bit = 0; bit < nbits; bit++) {
            vt->b[bit] -= thresholds[bit];
        }

        // the thresholds now live in vt: drop them here
        train_thresholds = false;
        thresholds.clear();
    }
}
/// Remove all stored codes and set the vector count back to zero.
/// Only codes and ntotal are touched; thresholds and is_trained are kept.
void IndexLSH::reset() {
    ntotal = 0;
    codes.clear();
}
} // namespace faiss