Generalize DistanceComputer for flat indexes (#2255)
Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2255 The `DistanceComputer` object is derived from an Index (obtained with `get_distance_computer()`). It maintains a current query and quickly computes distances from that query to any item in the database. This is useful, eg. for the IndexHNSW and IndexNSG that rely on query-to-point comparisons in the datasets. This diff introduces the `FlatCodesDistanceComputer`, that inherits from `DistanceComputer` for Flat indexes. In addition to the distance-to-item function, it adds a `distance_to_code` that computes the distance from any code to the current query, even if it is not stored in the index. This is implemented for all FlatCode indexes (IndexFlat, IndexPQ, IndexScalarQuantizer and IndexAdditiveQuantizer). In the process, the two classes were extracted to their own header file `impl/DistanceComputer.h` Reviewed By: beauby Differential Revision: D34863609 fbshipit-source-id: 39d8c66475e55c3223c4a6a210827aa48bca292dpull/2276/head^2
parent
add3705c11
commit
291353c5a9
|
@ -11,6 +11,7 @@ the Facebook Faiss team. Feel free to add entries here if you submit a PR.
|
|||
## [Unreleased]
|
||||
|
||||
- Added sparse k-means routines and moved the generic kmeans to contrib
|
||||
- Added FlatDistanceComputer for all FlatCodes indexes
|
||||
|
||||
## [1.7.2] - 2021-12-15
|
||||
### Added
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
|
||||
#include "AuxIndexStructures_c.h"
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
#include <faiss/impl/DistanceComputer.h>
|
||||
#include <iostream>
|
||||
#include "../macros_impl.h"
|
||||
|
||||
|
|
|
@ -131,6 +131,7 @@ set(FAISS_HEADERS
|
|||
index_io.h
|
||||
impl/AdditiveQuantizer.h
|
||||
impl/AuxIndexStructures.h
|
||||
impl/DistanceComputer.h
|
||||
impl/FaissAssert.h
|
||||
impl/FaissException.h
|
||||
impl/HNSW.h
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include <faiss/Index.h>
|
||||
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
#include <faiss/impl/DistanceComputer.h>
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
#include <faiss/utils/distances.h>
|
||||
|
||||
|
|
|
@ -38,7 +38,8 @@
|
|||
|
||||
namespace faiss {
|
||||
|
||||
/// Forward declarations see AuxIndexStructures.h
|
||||
/// Forward declarations see impl/AuxIndexStructures.h and
|
||||
/// impl/DistanceComputer.h
|
||||
struct IDSelector;
|
||||
struct RangeSearchResult;
|
||||
struct DistanceComputer;
|
||||
|
|
|
@ -40,6 +40,90 @@ IndexAdditiveQuantizer::IndexAdditiveQuantizer(
|
|||
|
||||
namespace {
|
||||
|
||||
/************************************************************
|
||||
* DistanceComputer implementation
|
||||
************************************************************/
|
||||
|
||||
template <class VectorDistance>
|
||||
struct AQDistanceComputerDecompress: FlatCodesDistanceComputer {
|
||||
std::vector<float> tmp;
|
||||
const AdditiveQuantizer & aq;
|
||||
VectorDistance vd;
|
||||
size_t d;
|
||||
|
||||
AQDistanceComputerDecompress(const IndexAdditiveQuantizer &iaq, VectorDistance vd):
|
||||
FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
|
||||
tmp(iaq.d * 2),
|
||||
aq(*iaq.aq),
|
||||
vd(vd),
|
||||
d(iaq.d)
|
||||
{}
|
||||
|
||||
const float *q;
|
||||
void set_query(const float* x) final {
|
||||
q = x;
|
||||
}
|
||||
|
||||
float symmetric_dis(idx_t i, idx_t j) final {
|
||||
aq.decode(codes + i * d, tmp.data(), 1);
|
||||
aq.decode(codes + j * d, tmp.data() + d, 1);
|
||||
return vd(tmp.data(), tmp.data() + d);
|
||||
}
|
||||
|
||||
float distance_to_code(const uint8_t *code) final {
|
||||
aq.decode(code, tmp.data(), 1);
|
||||
return vd(q, tmp.data());
|
||||
}
|
||||
|
||||
virtual ~AQDistanceComputerDecompress() {}
|
||||
};
|
||||
|
||||
|
||||
template<bool is_IP, AdditiveQuantizer::Search_type_t st>
|
||||
struct AQDistanceComputerLUT: FlatCodesDistanceComputer {
|
||||
std::vector<float> LUT;
|
||||
const AdditiveQuantizer & aq;
|
||||
size_t d;
|
||||
|
||||
explicit AQDistanceComputerLUT(const IndexAdditiveQuantizer &iaq):
|
||||
FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
|
||||
LUT(iaq.aq->total_codebook_size + iaq.d * 2),
|
||||
aq(*iaq.aq),
|
||||
d(iaq.d)
|
||||
{}
|
||||
|
||||
float bias;
|
||||
void set_query(const float* x) final {
|
||||
// this is quite sub-optimal for multiple queries
|
||||
aq.compute_LUT(1, x, LUT.data());
|
||||
if (is_IP) {
|
||||
bias = 0;
|
||||
} else {
|
||||
bias = fvec_norm_L2sqr(x, d);
|
||||
}
|
||||
}
|
||||
|
||||
float symmetric_dis(idx_t i, idx_t j) final {
|
||||
float *tmp = LUT.data();
|
||||
aq.decode(codes + i * d, tmp, 1);
|
||||
aq.decode(codes + j * d, tmp + d, 1);
|
||||
return fvec_L2sqr(tmp, tmp + d, d);
|
||||
}
|
||||
|
||||
float distance_to_code(const uint8_t *code) final {
|
||||
return bias + aq.compute_1_distance_LUT<is_IP, st>(code, LUT.data());
|
||||
}
|
||||
|
||||
virtual ~AQDistanceComputerLUT() {}
|
||||
};
|
||||
|
||||
|
||||
|
||||
/************************************************************
|
||||
* scanning implementation for search
|
||||
************************************************************/
|
||||
|
||||
|
||||
template <class VectorDistance, class ResultHandler>
|
||||
void search_with_decompress(
|
||||
const IndexAdditiveQuantizer& ir,
|
||||
|
@ -111,12 +195,58 @@ void search_with_LUT(
|
|||
|
||||
} // anonymous namespace
|
||||
|
||||
|
||||
FlatCodesDistanceComputer * IndexAdditiveQuantizer::get_FlatCodesDistanceComputer() const {
|
||||
|
||||
if (aq->search_type == AdditiveQuantizer::ST_decompress) {
|
||||
if (metric_type == METRIC_L2) {
|
||||
using VD = VectorDistance<METRIC_L2>;
|
||||
VD vd = {size_t(d), metric_arg};
|
||||
return new AQDistanceComputerDecompress<VD>(*this, vd);
|
||||
} else if (metric_type == METRIC_INNER_PRODUCT) {
|
||||
using VD = VectorDistance<METRIC_INNER_PRODUCT>;
|
||||
VD vd = {size_t(d), metric_arg};
|
||||
return new AQDistanceComputerDecompress<VD>(*this, vd);
|
||||
} else {
|
||||
FAISS_THROW_MSG("unsupported metric");
|
||||
}
|
||||
} else {
|
||||
if (metric_type == METRIC_INNER_PRODUCT) {
|
||||
return new AQDistanceComputerLUT<true, AdditiveQuantizer::ST_LUT_nonorm>(*this);
|
||||
} else {
|
||||
switch(aq->search_type) {
|
||||
#define DISPATCH(st) \
|
||||
case AdditiveQuantizer::st: \
|
||||
return new AQDistanceComputerLUT<false, AdditiveQuantizer::st> (*this);\
|
||||
break;
|
||||
DISPATCH(ST_norm_float)
|
||||
DISPATCH(ST_LUT_nonorm)
|
||||
DISPATCH(ST_norm_qint8)
|
||||
DISPATCH(ST_norm_qint4)
|
||||
DISPATCH(ST_norm_cqint4)
|
||||
case AdditiveQuantizer::ST_norm_cqint8:
|
||||
case AdditiveQuantizer::ST_norm_lsq2x4:
|
||||
case AdditiveQuantizer::ST_norm_rq2x4:
|
||||
return new AQDistanceComputerLUT<false, AdditiveQuantizer::ST_norm_cqint8> (*this);\
|
||||
break;
|
||||
#undef DISPATCH
|
||||
default:
|
||||
FAISS_THROW_FMT("search type %d not supported", aq->search_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
void IndexAdditiveQuantizer::search(
|
||||
idx_t n,
|
||||
const float* x,
|
||||
idx_t k,
|
||||
float* distances,
|
||||
idx_t* labels) const {
|
||||
|
||||
if (aq->search_type == AdditiveQuantizer::ST_decompress) {
|
||||
if (metric_type == METRIC_L2) {
|
||||
using VD = VectorDistance<METRIC_L2>;
|
||||
|
@ -135,22 +265,23 @@ void IndexAdditiveQuantizer::search(
|
|||
search_with_LUT<true, AdditiveQuantizer::ST_LUT_nonorm> (*this, x, rh);
|
||||
} else {
|
||||
HeapResultHandler<CMax<float, idx_t> > rh(n, distances, labels, k);
|
||||
|
||||
if (aq->search_type == AdditiveQuantizer::ST_norm_float) {
|
||||
search_with_LUT<false, AdditiveQuantizer::ST_norm_float> (*this, x, rh);
|
||||
} else if (aq->search_type == AdditiveQuantizer::ST_LUT_nonorm) {
|
||||
search_with_LUT<false, AdditiveQuantizer::ST_norm_float> (*this, x, rh);
|
||||
} else if (aq->search_type == AdditiveQuantizer::ST_norm_qint8) {
|
||||
search_with_LUT<false, AdditiveQuantizer::ST_norm_qint8> (*this, x, rh);
|
||||
} else if (aq->search_type == AdditiveQuantizer::ST_norm_qint4) {
|
||||
search_with_LUT<false, AdditiveQuantizer::ST_norm_qint4> (*this, x, rh);
|
||||
} else if (aq->search_type == AdditiveQuantizer::ST_norm_cqint8 ||
|
||||
aq->search_type == AdditiveQuantizer::ST_norm_lsq2x4 ||
|
||||
aq->search_type == AdditiveQuantizer::ST_norm_rq2x4) {
|
||||
switch(aq->search_type) {
|
||||
#define DISPATCH(st) \
|
||||
case AdditiveQuantizer::st: \
|
||||
search_with_LUT<false, AdditiveQuantizer::st> (*this, x, rh);\
|
||||
break;
|
||||
DISPATCH(ST_norm_float)
|
||||
DISPATCH(ST_LUT_nonorm)
|
||||
DISPATCH(ST_norm_qint8)
|
||||
DISPATCH(ST_norm_qint4)
|
||||
DISPATCH(ST_norm_cqint4)
|
||||
case AdditiveQuantizer::ST_norm_cqint8:
|
||||
case AdditiveQuantizer::ST_norm_lsq2x4:
|
||||
case AdditiveQuantizer::ST_norm_rq2x4:
|
||||
search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint8> (*this, x, rh);
|
||||
} else if (aq->search_type == AdditiveQuantizer::ST_norm_cqint4) {
|
||||
search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint4> (*this, x, rh);
|
||||
} else {
|
||||
break;
|
||||
#undef DISPATCH
|
||||
default:
|
||||
FAISS_THROW_FMT("search type %d not supported", aq->search_type);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -43,6 +43,8 @@ struct IndexAdditiveQuantizer : IndexFlatCodes {
|
|||
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
|
||||
|
||||
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
||||
|
||||
FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
|
||||
};
|
||||
|
||||
/** Index based on a residual quantizer. Stored vectors are
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
|
||||
#include <faiss/IndexBinaryFlat.h>
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
#include <faiss/impl/DistanceComputer.h>
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
#include <faiss/utils/Heap.h>
|
||||
#include <faiss/utils/hamming.h>
|
||||
|
|
|
@ -83,16 +83,16 @@ void IndexFlat::compute_distance_subset(
|
|||
|
||||
namespace {
|
||||
|
||||
struct FlatL2Dis : DistanceComputer {
|
||||
struct FlatL2Dis : FlatCodesDistanceComputer {
|
||||
size_t d;
|
||||
Index::idx_t nb;
|
||||
const float* q;
|
||||
const float* b;
|
||||
size_t ndis;
|
||||
|
||||
float operator()(idx_t i) override {
|
||||
float distance_to_code(const uint8_t* code) final {
|
||||
ndis++;
|
||||
return fvec_L2sqr(q, b + i * d, d);
|
||||
return fvec_L2sqr(q, (float*)code, d);
|
||||
}
|
||||
|
||||
float symmetric_dis(idx_t i, idx_t j) override {
|
||||
|
@ -100,7 +100,10 @@ struct FlatL2Dis : DistanceComputer {
|
|||
}
|
||||
|
||||
explicit FlatL2Dis(const IndexFlat& storage, const float* q = nullptr)
|
||||
: d(storage.d),
|
||||
: FlatCodesDistanceComputer(
|
||||
storage.codes.data(),
|
||||
storage.code_size),
|
||||
d(storage.d),
|
||||
nb(storage.ntotal),
|
||||
q(q),
|
||||
b(storage.get_xb()),
|
||||
|
@ -111,24 +114,27 @@ struct FlatL2Dis : DistanceComputer {
|
|||
}
|
||||
};
|
||||
|
||||
struct FlatIPDis : DistanceComputer {
|
||||
struct FlatIPDis : FlatCodesDistanceComputer {
|
||||
size_t d;
|
||||
Index::idx_t nb;
|
||||
const float* q;
|
||||
const float* b;
|
||||
size_t ndis;
|
||||
|
||||
float operator()(idx_t i) override {
|
||||
ndis++;
|
||||
return fvec_inner_product(q, b + i * d, d);
|
||||
}
|
||||
|
||||
float symmetric_dis(idx_t i, idx_t j) override {
|
||||
return fvec_inner_product(b + j * d, b + i * d, d);
|
||||
}
|
||||
|
||||
float distance_to_code(const uint8_t* code) final {
|
||||
ndis++;
|
||||
return fvec_inner_product(q, (float*)code, d);
|
||||
}
|
||||
|
||||
explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr)
|
||||
: d(storage.d),
|
||||
: FlatCodesDistanceComputer(
|
||||
storage.codes.data(),
|
||||
storage.code_size),
|
||||
d(storage.d),
|
||||
nb(storage.ntotal),
|
||||
q(q),
|
||||
b(storage.get_xb()),
|
||||
|
@ -141,7 +147,7 @@ struct FlatIPDis : DistanceComputer {
|
|||
|
||||
} // namespace
|
||||
|
||||
DistanceComputer* IndexFlat::get_distance_computer() const {
|
||||
FlatCodesDistanceComputer* IndexFlat::get_FlatCodesDistanceComputer() const {
|
||||
if (metric_type == METRIC_L2) {
|
||||
return new FlatL2Dis(*this);
|
||||
} else if (metric_type == METRIC_INNER_PRODUCT) {
|
||||
|
|
|
@ -60,7 +60,7 @@ struct IndexFlat : IndexFlatCodes {
|
|||
|
||||
IndexFlat() {}
|
||||
|
||||
DistanceComputer* get_distance_computer() const override;
|
||||
FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
|
||||
|
||||
/* The stanadlone codec interface (just memcopies in this case) */
|
||||
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
#include <faiss/IndexFlatCodes.h>
|
||||
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
#include <faiss/impl/DistanceComputer.h>
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
|
||||
namespace faiss {
|
||||
|
@ -64,4 +65,9 @@ void IndexFlatCodes::reconstruct(idx_t key, float* recons) const {
|
|||
reconstruct_n(key, 1, recons);
|
||||
}
|
||||
|
||||
FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
|
||||
const {
|
||||
FAISS_THROW_MSG("not implemented");
|
||||
}
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <faiss/Index.h>
|
||||
#include <faiss/impl/DistanceComputer.h>
|
||||
#include <vector>
|
||||
|
||||
namespace faiss {
|
||||
|
@ -42,6 +43,13 @@ struct IndexFlatCodes : Index {
|
|||
* indexing structure, the semantics of this operation are
|
||||
* different from the usual ones: the new ids are shifted */
|
||||
size_t remove_ids(const IDSelector& sel) override;
|
||||
|
||||
/** a FlatCodesDistanceComputer offers a distance_to_code method */
|
||||
virtual FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const;
|
||||
|
||||
DistanceComputer* get_distance_computer() const override {
|
||||
return get_FlatCodesDistanceComputer();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -5,9 +5,6 @@
|
|||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
// quiet the noise
|
||||
// XXclang-format off
|
||||
|
||||
#include <faiss/IndexIVFAdditiveQuantizer.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
|
|
@ -5,8 +5,6 @@
|
|||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
// -*- c++ -*-
|
||||
|
||||
#include <faiss/IndexPQ.h>
|
||||
|
||||
#include <cinttypes>
|
||||
|
@ -17,7 +15,7 @@
|
|||
|
||||
#include <algorithm>
|
||||
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
#include <faiss/impl/DistanceComputer.h>
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
#include <faiss/utils/hamming.h>
|
||||
|
||||
|
@ -73,19 +71,16 @@ void IndexPQ::train(idx_t n, const float* x) {
|
|||
namespace {
|
||||
|
||||
template <class PQDecoder>
|
||||
struct PQDistanceComputer : DistanceComputer {
|
||||
struct PQDistanceComputer : FlatCodesDistanceComputer {
|
||||
size_t d;
|
||||
MetricType metric;
|
||||
Index::idx_t nb;
|
||||
const uint8_t* codes;
|
||||
size_t code_size;
|
||||
const ProductQuantizer& pq;
|
||||
const float* sdc;
|
||||
std::vector<float> precomputed_table;
|
||||
size_t ndis;
|
||||
|
||||
float operator()(idx_t i) override {
|
||||
const uint8_t* code = codes + i * code_size;
|
||||
float distance_to_code(const uint8_t* code) final {
|
||||
const float* dt = precomputed_table.data();
|
||||
PQDecoder decoder(code, pq.nbits);
|
||||
float accu = 0;
|
||||
|
@ -112,13 +107,15 @@ struct PQDistanceComputer : DistanceComputer {
|
|||
return accu;
|
||||
}
|
||||
|
||||
explicit PQDistanceComputer(const IndexPQ& storage) : pq(storage.pq) {
|
||||
explicit PQDistanceComputer(const IndexPQ& storage)
|
||||
: FlatCodesDistanceComputer(
|
||||
storage.codes.data(),
|
||||
storage.code_size),
|
||||
pq(storage.pq) {
|
||||
precomputed_table.resize(pq.M * pq.ksub);
|
||||
nb = storage.ntotal;
|
||||
d = storage.d;
|
||||
metric = storage.metric_type;
|
||||
codes = storage.codes.data();
|
||||
code_size = pq.code_size;
|
||||
if (pq.sdc_table.size() == pq.ksub * pq.ksub * pq.M) {
|
||||
sdc = pq.sdc_table.data();
|
||||
} else {
|
||||
|
@ -138,7 +135,7 @@ struct PQDistanceComputer : DistanceComputer {
|
|||
|
||||
} // namespace
|
||||
|
||||
DistanceComputer* IndexPQ::get_distance_computer() const {
|
||||
FlatCodesDistanceComputer* IndexPQ::get_FlatCodesDistanceComputer() const {
|
||||
if (pq.nbits == 8) {
|
||||
return new PQDistanceComputer<PQDecoder8>(*this);
|
||||
} else if (pq.nbits == 16) {
|
||||
|
|
|
@ -52,7 +52,7 @@ struct IndexPQ : IndexFlatCodes {
|
|||
|
||||
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
||||
|
||||
DistanceComputer* get_distance_computer() const override;
|
||||
FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
|
||||
|
||||
/******************************************************
|
||||
* Polysemous codes implementation
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include <memory>
|
||||
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
#include <faiss/impl/DistanceComputer.h>
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
|
||||
namespace faiss {
|
||||
|
|
|
@ -85,7 +85,8 @@ void IndexScalarQuantizer::search(
|
|||
}
|
||||
}
|
||||
|
||||
DistanceComputer* IndexScalarQuantizer::get_distance_computer() const {
|
||||
FlatCodesDistanceComputer* IndexScalarQuantizer::get_FlatCodesDistanceComputer()
|
||||
const {
|
||||
ScalarQuantizer::SQDistanceComputer* dc =
|
||||
sq.get_distance_computer(metric_type);
|
||||
dc->code_size = sq.code_size;
|
||||
|
|
|
@ -20,11 +20,8 @@
|
|||
namespace faiss {
|
||||
|
||||
/**
|
||||
* The uniform quantizer has a range [vmin, vmax]. The range can be
|
||||
* the same for all dimensions (uniform) or specific per dimension
|
||||
* (default).
|
||||
* Flat index built on a scalar quantizer.
|
||||
*/
|
||||
|
||||
struct IndexScalarQuantizer : IndexFlatCodes {
|
||||
/// Used to encode the vectors
|
||||
ScalarQuantizer sq;
|
||||
|
@ -51,7 +48,7 @@ struct IndexScalarQuantizer : IndexFlatCodes {
|
|||
float* distances,
|
||||
idx_t* labels) const override;
|
||||
|
||||
DistanceComputer* get_distance_computer() const override;
|
||||
FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
|
||||
|
||||
/* standalone codec interface */
|
||||
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
|
||||
|
|
|
@ -83,24 +83,27 @@ void AdditiveQuantizer::set_derived_values() {
|
|||
}
|
||||
total_codebook_size = codebook_offsets[M];
|
||||
switch (search_type) {
|
||||
case ST_decompress:
|
||||
case ST_LUT_nonorm:
|
||||
case ST_norm_from_LUT:
|
||||
break; // nothing to add
|
||||
case ST_norm_float:
|
||||
tot_bits += 32;
|
||||
norm_bits = 32;
|
||||
break;
|
||||
case ST_norm_qint8:
|
||||
case ST_norm_cqint8:
|
||||
case ST_norm_lsq2x4:
|
||||
case ST_norm_rq2x4:
|
||||
tot_bits += 8;
|
||||
norm_bits = 8;
|
||||
break;
|
||||
case ST_norm_qint4:
|
||||
case ST_norm_cqint4:
|
||||
tot_bits += 4;
|
||||
norm_bits = 4;
|
||||
break;
|
||||
case ST_decompress:
|
||||
case ST_LUT_nonorm:
|
||||
case ST_norm_from_LUT:
|
||||
default:
|
||||
norm_bits = 0;
|
||||
break;
|
||||
}
|
||||
tot_bits += norm_bits;
|
||||
|
||||
// convert bits to bytes
|
||||
code_size = (tot_bits + 7) / 8;
|
||||
|
@ -195,6 +198,28 @@ float AdditiveQuantizer::decode_qcint(uint32_t c) const {
|
|||
return qnorm.get_xb()[c];
|
||||
}
|
||||
|
||||
uint64_t AdditiveQuantizer::encode_norm(float norm) const {
|
||||
switch (search_type) {
|
||||
case ST_norm_float:
|
||||
return *(uint32_t*)&norm;
|
||||
case ST_norm_qint8:
|
||||
return encode_qint8(norm, norm_min, norm_max);
|
||||
case ST_norm_qint4:
|
||||
return encode_qint4(norm, norm_min, norm_max);
|
||||
case ST_norm_lsq2x4:
|
||||
case ST_norm_rq2x4:
|
||||
case ST_norm_cqint8:
|
||||
return encode_qcint(norm);
|
||||
case ST_norm_cqint4:
|
||||
return encode_qcint(norm);
|
||||
case ST_decompress:
|
||||
case ST_LUT_nonorm:
|
||||
case ST_norm_from_LUT:
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
void AdditiveQuantizer::pack_codes(
|
||||
size_t n,
|
||||
const int32_t* codes,
|
||||
|
@ -230,36 +255,8 @@ void AdditiveQuantizer::pack_codes(
|
|||
for (int m = 0; m < M; m++) {
|
||||
bsw.write(codes1[m], nbits[m]);
|
||||
}
|
||||
switch (search_type) {
|
||||
case ST_decompress:
|
||||
case ST_LUT_nonorm:
|
||||
case ST_norm_from_LUT:
|
||||
break;
|
||||
case ST_norm_float:
|
||||
bsw.write(*(uint32_t*)&norms[i], 32);
|
||||
break;
|
||||
case ST_norm_qint8: {
|
||||
uint8_t b = encode_qint8(norms[i], norm_min, norm_max);
|
||||
bsw.write(b, 8);
|
||||
break;
|
||||
}
|
||||
case ST_norm_qint4: {
|
||||
uint8_t b = encode_qint4(norms[i], norm_min, norm_max);
|
||||
bsw.write(b, 4);
|
||||
break;
|
||||
}
|
||||
case ST_norm_lsq2x4:
|
||||
case ST_norm_rq2x4:
|
||||
case ST_norm_cqint8: {
|
||||
uint32_t b = encode_qcint(norms[i]);
|
||||
bsw.write(b, 8);
|
||||
break;
|
||||
}
|
||||
case ST_norm_cqint4: {
|
||||
uint32_t b = encode_qcint(norms[i]);
|
||||
bsw.write(b, 4);
|
||||
break;
|
||||
}
|
||||
if (norm_bits != 0) {
|
||||
bsw.write(encode_norm(norms[i]), norm_bits);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,7 +30,8 @@ struct AdditiveQuantizer {
|
|||
// derived values
|
||||
std::vector<uint64_t> codebook_offsets;
|
||||
size_t code_size; ///< code size in bytes
|
||||
size_t tot_bits; ///< total number of bits
|
||||
size_t tot_bits; ///< total number of bits (indexes + norms)
|
||||
size_t norm_bits; ///< bits allocated for the norms
|
||||
size_t total_codebook_size; ///< size of the codebook in vectors
|
||||
bool only_8bit; ///< are all nbits = 8 (use faster decoder)
|
||||
|
||||
|
@ -41,6 +42,9 @@ struct AdditiveQuantizer {
|
|||
std::vector<float> norm_tabs; ///< store norms of codebook entries for 4-bit
|
||||
///< fastscan search
|
||||
|
||||
/// encode a norm into norm_bits bits
|
||||
uint64_t encode_norm(float norm) const;
|
||||
|
||||
uint32_t encode_qcint(
|
||||
float x) const; ///< encode norm by non-uniform scalar quantization
|
||||
|
||||
|
|
|
@ -5,8 +5,6 @@
|
|||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
// -*- c++ -*-
|
||||
|
||||
// Auxiliary index structures, that are used in indexes but that can
|
||||
// be forward-declared
|
||||
|
||||
|
@ -186,30 +184,6 @@ struct RangeSearchPartialResult : BufferList {
|
|||
bool do_delete = true);
|
||||
};
|
||||
|
||||
/***********************************************************
|
||||
* The distance computer maintains a current query and computes
|
||||
* distances to elements in an index that supports random access.
|
||||
*
|
||||
* The DistanceComputer is not intended to be thread-safe (eg. because
|
||||
* it maintains counters) so the distance functions are not const,
|
||||
* instantiate one from each thread if needed.
|
||||
***********************************************************/
|
||||
struct DistanceComputer {
|
||||
using idx_t = Index::idx_t;
|
||||
|
||||
/// called before computing distances. Pointer x should remain valid
|
||||
/// while operator () is called
|
||||
virtual void set_query(const float* x) = 0;
|
||||
|
||||
/// compute distance of vector i to current query
|
||||
virtual float operator()(idx_t i) = 0;
|
||||
|
||||
/// compute distance between two stored vectors
|
||||
virtual float symmetric_dis(idx_t i, idx_t j) = 0;
|
||||
|
||||
virtual ~DistanceComputer() {}
|
||||
};
|
||||
|
||||
/***********************************************************
|
||||
* Interrupt callback
|
||||
***********************************************************/
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <faiss/Index.h>
|
||||
|
||||
namespace faiss {
|
||||
|
||||
/***********************************************************
|
||||
* The distance computer maintains a current query and computes
|
||||
* distances to elements in an index that supports random access.
|
||||
*
|
||||
* The DistanceComputer is not intended to be thread-safe (eg. because
|
||||
* it maintains counters) so the distance functions are not const,
|
||||
* instantiate one from each thread if needed.
|
||||
*
|
||||
* Note that the equivalent for IVF indexes is the InvertedListScanner,
|
||||
* that has additional methods to handle the inverted list context.
|
||||
***********************************************************/
|
||||
struct DistanceComputer {
|
||||
using idx_t = Index::idx_t;
|
||||
|
||||
/// called before computing distances. Pointer x should remain valid
|
||||
/// while operator () is called
|
||||
virtual void set_query(const float* x) = 0;
|
||||
|
||||
/// compute distance of vector i to current query
|
||||
virtual float operator()(idx_t i) = 0;
|
||||
|
||||
/// compute distance between two stored vectors
|
||||
virtual float symmetric_dis(idx_t i, idx_t j) = 0;
|
||||
|
||||
virtual ~DistanceComputer() {}
|
||||
};
|
||||
|
||||
/*************************************************************
|
||||
* Specialized version of the DistanceComputer when we know that codes are
|
||||
* laid out in a flat index.
|
||||
*/
|
||||
struct FlatCodesDistanceComputer : DistanceComputer {
|
||||
const uint8_t* codes;
|
||||
size_t code_size;
|
||||
|
||||
FlatCodesDistanceComputer(const uint8_t* codes, size_t code_size)
|
||||
: codes(codes), code_size(code_size) {}
|
||||
|
||||
FlatCodesDistanceComputer() : codes(nullptr), code_size(0) {}
|
||||
|
||||
float operator()(idx_t i) final {
|
||||
return distance_to_code(codes + i * code_size);
|
||||
}
|
||||
|
||||
/// compute distance of current query to an encoded vector
|
||||
virtual float distance_to_code(const uint8_t* code) = 0;
|
||||
|
||||
virtual ~FlatCodesDistanceComputer() {}
|
||||
};
|
||||
|
||||
} // namespace faiss
|
|
@ -12,6 +12,7 @@
|
|||
#include <string>
|
||||
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
#include <faiss/impl/DistanceComputer.h>
|
||||
|
||||
namespace faiss {
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include <string>
|
||||
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
#include <faiss/impl/DistanceComputer.h>
|
||||
|
||||
namespace faiss {
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
#include <mutex>
|
||||
#include <stack>
|
||||
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
#include <faiss/impl/DistanceComputer.h>
|
||||
|
||||
namespace faiss {
|
||||
|
||||
|
|
|
@ -5,8 +5,6 @@
|
|||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
// -*- c++ -*-
|
||||
|
||||
#include <faiss/impl/ResidualQuantizer.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
@ -97,6 +95,39 @@ ResidualQuantizer::ResidualQuantizer(
|
|||
Search_type_t search_type)
|
||||
: ResidualQuantizer(d, std::vector<size_t>(M, nbits), search_type) {}
|
||||
|
||||
void ResidualQuantizer::initialize_from(
|
||||
const ResidualQuantizer& other,
|
||||
int skip_M) {
|
||||
FAISS_THROW_IF_NOT(M + skip_M <= other.M);
|
||||
FAISS_THROW_IF_NOT(skip_M >= 0);
|
||||
|
||||
Search_type_t this_search_type = search_type;
|
||||
int this_M = M;
|
||||
|
||||
// a first good approximation: override everything
|
||||
*this = other;
|
||||
|
||||
// adjust derived values
|
||||
M = this_M;
|
||||
search_type = this_search_type;
|
||||
nbits.resize(M);
|
||||
memcpy(nbits.data(),
|
||||
other.nbits.data() + skip_M,
|
||||
nbits.size() * sizeof(nbits[0]));
|
||||
|
||||
set_derived_values();
|
||||
|
||||
// resize codebooks if trained
|
||||
if (codebooks.size() > 0) {
|
||||
FAISS_THROW_IF_NOT(codebooks.size() == other.total_codebook_size * d);
|
||||
codebooks.resize(total_codebook_size * d);
|
||||
memcpy(codebooks.data(),
|
||||
other.codebooks.data() + other.codebook_offsets[skip_M] * d,
|
||||
codebooks.size() * sizeof(codebooks[0]));
|
||||
// TODO: norm_tabs?
|
||||
}
|
||||
}
|
||||
|
||||
void beam_search_encode_step(
|
||||
size_t d,
|
||||
size_t K,
|
||||
|
|
|
@ -28,7 +28,7 @@ struct ResidualQuantizer : AdditiveQuantizer {
|
|||
// Was enum but that does not work so well with bitmasks
|
||||
using train_type_t = int;
|
||||
|
||||
/// Or of the Train_* flags below
|
||||
/// Binary or of the Train_* flags below
|
||||
train_type_t train_type;
|
||||
|
||||
/// regular k-means (minimal amount of computation)
|
||||
|
@ -47,7 +47,7 @@ struct ResidualQuantizer : AdditiveQuantizer {
|
|||
* first element of the beam (faster but less accurate) */
|
||||
static const int Train_top_beam = 1024;
|
||||
|
||||
/** set this bit to not autmatically compute the codebook tables
|
||||
/** set this bit to *not* autmatically compute the codebook tables
|
||||
* after training */
|
||||
static const int Skip_codebook_tables = 2048;
|
||||
|
||||
|
@ -83,6 +83,9 @@ struct ResidualQuantizer : AdditiveQuantizer {
|
|||
/// Train the residual quantizer
|
||||
void train(size_t n, const float* x) override;
|
||||
|
||||
/// Copy the M codebook levels from other, starting from skip_M
|
||||
void initialize_from(const ResidualQuantizer& other, int skip_M = 0);
|
||||
|
||||
/** Encode the vectors and compute codebook that minimizes the quantization
|
||||
* error on these codes
|
||||
*
|
||||
|
|
|
@ -911,11 +911,6 @@ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer {
|
|||
q = x;
|
||||
}
|
||||
|
||||
/// compute distance of vector i to current query
|
||||
float operator()(idx_t i) final {
|
||||
return query_to_code(codes + i * code_size);
|
||||
}
|
||||
|
||||
float symmetric_dis(idx_t i, idx_t j) override {
|
||||
return compute_code_distance(
|
||||
codes + i * code_size, codes + j * code_size);
|
||||
|
@ -963,11 +958,6 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
|
|||
q = x;
|
||||
}
|
||||
|
||||
/// compute distance of vector i to current query
|
||||
float operator()(idx_t i) final {
|
||||
return query_to_code(codes + i * code_size);
|
||||
}
|
||||
|
||||
float symmetric_dis(idx_t i, idx_t j) override {
|
||||
return compute_code_distance(
|
||||
codes + i * code_size, codes + j * code_size);
|
||||
|
@ -1021,11 +1011,6 @@ struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
|
|||
return compute_code_distance(tmp.data(), code);
|
||||
}
|
||||
|
||||
/// compute distance of vector i to current query
|
||||
float operator()(idx_t i) final {
|
||||
return query_to_code(codes + i * code_size);
|
||||
}
|
||||
|
||||
float symmetric_dis(idx_t i, idx_t j) override {
|
||||
return compute_code_distance(
|
||||
codes + i * code_size, codes + j * code_size);
|
||||
|
@ -1089,11 +1074,6 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
|
|||
return compute_code_distance(tmp.data(), code);
|
||||
}
|
||||
|
||||
/// compute distance of vector i to current query
|
||||
float operator()(idx_t i) final {
|
||||
return query_to_code(codes + i * code_size);
|
||||
}
|
||||
|
||||
float symmetric_dis(idx_t i, idx_t j) override {
|
||||
return compute_code_distance(
|
||||
codes + i * code_size, codes + j * code_size);
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
|
||||
#include <faiss/IndexIVF.h>
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
#include <faiss/impl/DistanceComputer.h>
|
||||
|
||||
namespace faiss {
|
||||
|
||||
|
@ -105,14 +106,16 @@ struct ScalarQuantizer {
|
|||
|
||||
Quantizer* select_quantizer() const;
|
||||
|
||||
struct SQDistanceComputer : DistanceComputer {
|
||||
struct SQDistanceComputer : FlatCodesDistanceComputer {
|
||||
const float* q;
|
||||
const uint8_t* codes;
|
||||
size_t code_size;
|
||||
|
||||
SQDistanceComputer() : q(nullptr), codes(nullptr), code_size(0) {}
|
||||
SQDistanceComputer() : q(nullptr) {}
|
||||
|
||||
virtual float query_to_code(const uint8_t* code) const = 0;
|
||||
|
||||
float distance_to_code(const uint8_t* code) final {
|
||||
return query_to_code(code);
|
||||
}
|
||||
};
|
||||
|
||||
SQDistanceComputer* get_distance_computer(
|
||||
|
|
|
@ -1287,6 +1287,12 @@ def randn(n, seed=12345):
|
|||
float_randn(swig_ptr(res), res.size, seed)
|
||||
return res
|
||||
|
||||
rand_smooth_vectors_c = rand_smooth_vectors
|
||||
|
||||
def rand_smooth_vectors(n, d, seed=1234):
|
||||
res = np.empty((n, d), dtype='float32')
|
||||
rand_smooth_vectors_c(n, d, swig_ptr(res), seed)
|
||||
return res
|
||||
|
||||
def eval_intersection(I1, I2):
|
||||
""" size of intersection between each line of two result tables"""
|
||||
|
@ -1429,24 +1435,14 @@ def knn(xq, xb, k, metric=METRIC_L2):
|
|||
D = np.empty((nq, k), dtype='float32')
|
||||
|
||||
if metric == METRIC_L2:
|
||||
heaps = float_maxheap_array_t()
|
||||
heaps.k = k
|
||||
heaps.nh = nq
|
||||
heaps.val = swig_ptr(D)
|
||||
heaps.ids = swig_ptr(I)
|
||||
knn_L2sqr(
|
||||
swig_ptr(xq), swig_ptr(xb),
|
||||
d, nq, nb, heaps
|
||||
d, nq, nb, k, swig_ptr(D), swig_ptr(I)
|
||||
)
|
||||
elif metric == METRIC_INNER_PRODUCT:
|
||||
heaps = float_minheap_array_t()
|
||||
heaps.k = k
|
||||
heaps.nh = nq
|
||||
heaps.val = swig_ptr(D)
|
||||
heaps.ids = swig_ptr(I)
|
||||
knn_inner_product(
|
||||
swig_ptr(xq), swig_ptr(xb),
|
||||
d, nq, nb, heaps
|
||||
d, nq, nb, k, swig_ptr(D), swig_ptr(I)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("only L2 and INNER_PRODUCT are supported")
|
||||
|
|
|
@ -126,6 +126,7 @@ typedef uint64_t size_t;
|
|||
#include <faiss/utils/AlignedTable.h>
|
||||
#include <faiss/utils/partitioning.h>
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
#include <faiss/impl/DistanceComputer.h>
|
||||
#include <faiss/impl/AdditiveQuantizer.h>
|
||||
#include <faiss/impl/ResidualQuantizer.h>
|
||||
#include <faiss/impl/LocalSearchQuantizer.h>
|
||||
|
@ -378,6 +379,10 @@ void gpu_sync_all_devices()
|
|||
|
||||
%newobject *::get_distance_computer() const;
|
||||
%include <faiss/Index.h>
|
||||
|
||||
%include <faiss/impl/DistanceComputer.h>
|
||||
|
||||
%newobject *::get_FlatCodesDistanceComputer() const;
|
||||
%include <faiss/IndexFlatCodes.h>
|
||||
%include <faiss/IndexFlat.h>
|
||||
%include <faiss/Clustering.h>
|
||||
|
@ -484,7 +489,8 @@ void gpu_sync_all_devices()
|
|||
%ignore faiss::InterruptCallback::instance;
|
||||
%ignore faiss::InterruptCallback::lock;
|
||||
|
||||
%include <faiss/impl/AuxIndexStructures.h>
|
||||
%include <faiss/impl/AuxIndexStructures.h>
|
||||
|
||||
|
||||
|
||||
#ifdef GPU_WRAPPER
|
||||
|
|
|
@ -333,6 +333,19 @@ void knn_inner_product(
|
|||
}
|
||||
}
|
||||
|
||||
void knn_inner_product(
|
||||
const float* x,
|
||||
const float* y,
|
||||
size_t d,
|
||||
size_t nx,
|
||||
size_t ny,
|
||||
size_t k,
|
||||
float* distances,
|
||||
int64_t* indexes) {
|
||||
float_minheap_array_t heaps = {nx, k, indexes, distances};
|
||||
knn_inner_product(x, y, d, nx, ny, &heaps);
|
||||
}
|
||||
|
||||
void knn_L2sqr(
|
||||
const float* x,
|
||||
const float* y,
|
||||
|
@ -361,6 +374,20 @@ void knn_L2sqr(
|
|||
}
|
||||
}
|
||||
|
||||
void knn_L2sqr(
|
||||
const float* x,
|
||||
const float* y,
|
||||
size_t d,
|
||||
size_t nx,
|
||||
size_t ny,
|
||||
size_t k,
|
||||
float* distances,
|
||||
int64_t* indexes,
|
||||
const float* y_norm2) {
|
||||
float_maxheap_array_t heaps = {nx, k, indexes, distances};
|
||||
knn_L2sqr(x, y, d, nx, ny, &heaps, y_norm2);
|
||||
}
|
||||
|
||||
/***************************************************************************
|
||||
* Range search
|
||||
***************************************************************************/
|
||||
|
|
|
@ -198,11 +198,11 @@ FAISS_API extern int distance_compute_blas_database_bs;
|
|||
FAISS_API extern int distance_compute_min_k_reservoir;
|
||||
|
||||
/** Return the k nearest neighors of each of the nx vectors x among the ny
|
||||
* vector y, w.r.t to max inner product
|
||||
* vector y, w.r.t to max inner product.
|
||||
*
|
||||
* @param x query vectors, size nx * d
|
||||
* @param y database vectors, size ny * d
|
||||
* @param res result array, which also provides k. Sorted on output
|
||||
* @param res result heap structure, which also provides k. Sorted on output
|
||||
*/
|
||||
void knn_inner_product(
|
||||
const float* x,
|
||||
|
@ -212,8 +212,30 @@ void knn_inner_product(
|
|||
size_t ny,
|
||||
float_minheap_array_t* res);
|
||||
|
||||
/** Same as knn_inner_product, for the L2 distance
|
||||
* @param y_norm2 norms for the y vectors (nullptr or size ny)
|
||||
/** Return the k nearest neighors of each of the nx vectors x among the ny
|
||||
* vector y, for the inner product metric.
|
||||
*
|
||||
* @param x query vectors, size nx * d
|
||||
* @param y database vectors, size ny * d
|
||||
* @param distances output distances, size nq * k
|
||||
* @param indexes output vector ids, size nq * k
|
||||
*/
|
||||
void knn_inner_product(
|
||||
const float* x,
|
||||
const float* y,
|
||||
size_t d,
|
||||
size_t nx,
|
||||
size_t ny,
|
||||
size_t k,
|
||||
float* distances,
|
||||
int64_t* indexes);
|
||||
|
||||
/** Return the k nearest neighors of each of the nx vectors x among the ny
|
||||
* vector y, for the L2 distance
|
||||
* @param x query vectors, size nx * d
|
||||
* @param y database vectors, size ny * d
|
||||
* @param res result heap strcture, which also provides k. Sorted on output
|
||||
* @param y_norm2 (optional) norms for the y vectors (nullptr or size ny)
|
||||
*/
|
||||
void knn_L2sqr(
|
||||
const float* x,
|
||||
|
@ -224,6 +246,26 @@ void knn_L2sqr(
|
|||
float_maxheap_array_t* res,
|
||||
const float* y_norm2 = nullptr);
|
||||
|
||||
/** Return the k nearest neighors of each of the nx vectors x among the ny
|
||||
* vector y, for the L2 distance
|
||||
*
|
||||
* @param x query vectors, size nx * d
|
||||
* @param y database vectors, size ny * d
|
||||
* @param distances output distances, size nq * k
|
||||
* @param indexes output vector ids, size nq * k
|
||||
* @param y_norm2 (optional) norms for the y vectors (nullptr or size ny)
|
||||
*/
|
||||
void knn_L2sqr(
|
||||
const float* x,
|
||||
const float* y,
|
||||
size_t d,
|
||||
size_t nx,
|
||||
size_t ny,
|
||||
size_t k,
|
||||
float* distances,
|
||||
int64_t* indexes,
|
||||
const float* y_norm2 = nullptr);
|
||||
|
||||
/* Find the nearest neighbors for nx queries in a set of ny vectors
|
||||
* indexed by ids. May be useful for re-ranking a pre-selected vector list
|
||||
*/
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include <cmath>
|
||||
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
#include <faiss/impl/DistanceComputer.h>
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
#include <faiss/utils/utils.h>
|
||||
|
||||
|
@ -89,18 +90,18 @@ void knn_extra_metrics_template(
|
|||
}
|
||||
|
||||
template <class VD>
|
||||
struct ExtraDistanceComputer : DistanceComputer {
|
||||
struct ExtraDistanceComputer : FlatCodesDistanceComputer {
|
||||
VD vd;
|
||||
Index::idx_t nb;
|
||||
const float* q;
|
||||
const float* b;
|
||||
|
||||
float operator()(idx_t i) override {
|
||||
return vd(q, b + i * vd.d);
|
||||
float symmetric_dis(idx_t i, idx_t j) final {
|
||||
return vd(b + j * vd.d, b + i * vd.d);
|
||||
}
|
||||
|
||||
float symmetric_dis(idx_t i, idx_t j) override {
|
||||
return vd(b + j * vd.d, b + i * vd.d);
|
||||
float distance_to_code(const uint8_t* code) final {
|
||||
return vd(q, (float*)code);
|
||||
}
|
||||
|
||||
ExtraDistanceComputer(
|
||||
|
@ -108,7 +109,11 @@ struct ExtraDistanceComputer : DistanceComputer {
|
|||
const float* xb,
|
||||
size_t nb,
|
||||
const float* q = nullptr)
|
||||
: vd(vd), nb(nb), q(q), b(xb) {}
|
||||
: FlatCodesDistanceComputer((uint8_t*)xb, vd.d * sizeof(float)),
|
||||
vd(vd),
|
||||
nb(nb),
|
||||
q(q),
|
||||
b(xb) {}
|
||||
|
||||
void set_query(const float* x) override {
|
||||
q = x;
|
||||
|
@ -188,7 +193,7 @@ void knn_extra_metrics(
|
|||
}
|
||||
}
|
||||
|
||||
DistanceComputer* get_extra_distance_computer(
|
||||
FlatCodesDistanceComputer* get_extra_distance_computer(
|
||||
size_t d,
|
||||
MetricType mt,
|
||||
float metric_arg,
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
|
||||
namespace faiss {
|
||||
|
||||
struct FlatCodesDistanceComputer;
|
||||
|
||||
void pairwise_extra_distances(
|
||||
int64_t d,
|
||||
int64_t nq,
|
||||
|
@ -43,7 +45,7 @@ void knn_extra_metrics(
|
|||
|
||||
/** get a DistanceComputer that refers to this type of distance and
|
||||
* indexes a flat array of size nb */
|
||||
DistanceComputer* get_extra_distance_computer(
|
||||
FlatCodesDistanceComputer* get_extra_distance_computer(
|
||||
size_t d,
|
||||
MetricType mt,
|
||||
float metric_arg,
|
||||
|
|
|
@ -9,6 +9,23 @@
|
|||
|
||||
#include <faiss/utils/random.h>
|
||||
|
||||
extern "C" {
|
||||
int sgemm_(
|
||||
const char* transa,
|
||||
const char* transb,
|
||||
FINTEGER* m,
|
||||
FINTEGER* n,
|
||||
FINTEGER* k,
|
||||
const float* alpha,
|
||||
const float* a,
|
||||
FINTEGER* lda,
|
||||
const float* b,
|
||||
FINTEGER* ldb,
|
||||
float* beta,
|
||||
float* c,
|
||||
FINTEGER* ldc);
|
||||
}
|
||||
|
||||
namespace faiss {
|
||||
|
||||
/**************************************************
|
||||
|
@ -165,4 +182,40 @@ void byte_rand(uint8_t* x, size_t n, int64_t seed) {
|
|||
}
|
||||
}
|
||||
|
||||
void rand_smooth_vectors(size_t n, size_t d, float* x, int64_t seed) {
|
||||
size_t d1 = 10;
|
||||
std::vector<float> x1(n * d1);
|
||||
float_randn(x1.data(), x1.size(), seed);
|
||||
std::vector<float> rot(d1 * d);
|
||||
float_rand(rot.data(), rot.size(), seed + 1);
|
||||
|
||||
{ //
|
||||
FINTEGER di = d, d1i = d1, ni = n;
|
||||
float one = 1.0, zero = 0.0;
|
||||
sgemm_("Not transposed",
|
||||
"Not transposed", // natural order
|
||||
&di,
|
||||
&ni,
|
||||
&d1i,
|
||||
&one,
|
||||
rot.data(),
|
||||
&di, // rotation matrix
|
||||
x1.data(),
|
||||
&d1i, // second term
|
||||
&zero,
|
||||
x,
|
||||
&di);
|
||||
}
|
||||
|
||||
std::vector<float> scales(d);
|
||||
float_rand(scales.data(), d, seed + 2);
|
||||
|
||||
#pragma omp parallel for if (n * d > 10000)
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
for (size_t j = 0; j < d; j++) {
|
||||
x[i * d + j] = sinf(x[i * d + j] * (scales[j] * 4 + 0.1));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -54,4 +54,9 @@ void int64_rand_max(int64_t* x, size_t n, uint64_t max, int64_t seed);
|
|||
/* random permutation */
|
||||
void rand_perm(int* perm, size_t n, int64_t seed);
|
||||
|
||||
/* Random set of vectors with intrinsic dimensionality 10 that is harder to
|
||||
* index than a subspace of dim 10 but easier than uniform data in dimension d
|
||||
* */
|
||||
void rand_smooth_vectors(size_t n, size_t d, float* x, int64_t seed);
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -325,6 +325,7 @@ class TestScalarQuantizer(unittest.TestCase):
|
|||
# print(dis, D[i, j])
|
||||
assert abs(D[i, j] - dis) / dis < 1e-5
|
||||
|
||||
|
||||
class TestRandom(unittest.TestCase):
|
||||
|
||||
def test_rand(self):
|
||||
|
@ -340,6 +341,24 @@ class TestRandom(unittest.TestCase):
|
|||
print(c)
|
||||
assert c.max() - c.min() < 50 * 2
|
||||
|
||||
def test_rand_vector(self):
|
||||
""" test if the smooth_vectors function is reasonably compressible with
|
||||
a small PQ """
|
||||
x = faiss.rand_smooth_vectors(1300, 32)
|
||||
xt = x[:1000]
|
||||
xb = x[1000:1200]
|
||||
xq = x[1200:]
|
||||
_, gt = faiss.knn(xq, xb, 10)
|
||||
index = faiss.IndexPQ(32, 4, 4)
|
||||
index.train(xt)
|
||||
index.add(xb)
|
||||
D, I = index.search(xq, 10)
|
||||
ninter = faiss.eval_intersection(I, gt)
|
||||
# 445 for SyntheticDataset
|
||||
self.assertGreater(ninter, 420)
|
||||
self.assertLess(ninter, 460)
|
||||
|
||||
|
||||
|
||||
class TestPairwiseDis(unittest.TestCase):
|
||||
|
||||
|
|
|
@ -21,15 +21,22 @@ class TestDistanceComputer(unittest.TestCase):
|
|||
index.add(ds.get_database())
|
||||
xq = ds.get_queries()
|
||||
Dref, Iref = index.search(xq, 10)
|
||||
dc = index.get_distance_computer()
|
||||
self.assertTrue(dc.this.own())
|
||||
for q in range(ds.nq):
|
||||
dc.set_query(faiss.swig_ptr(xq[q]))
|
||||
for j in range(10):
|
||||
ref_dis = Dref[q, j]
|
||||
new_dis = dc(int(Iref[q, j]))
|
||||
np.testing.assert_almost_equal(
|
||||
new_dis, ref_dis, decimal=5)
|
||||
|
||||
for is_FlatCodesDistanceComputer in False, True:
|
||||
if not is_FlatCodesDistanceComputer:
|
||||
dc = index.get_distance_computer()
|
||||
else:
|
||||
if not isinstance(index, faiss.IndexFlatCodes):
|
||||
continue
|
||||
dc = index.get_FlatCodesDistanceComputer()
|
||||
self.assertTrue(dc.this.own())
|
||||
for q in range(ds.nq):
|
||||
dc.set_query(faiss.swig_ptr(xq[q]))
|
||||
for j in range(10):
|
||||
ref_dis = Dref[q, j]
|
||||
new_dis = dc(int(Iref[q, j]))
|
||||
np.testing.assert_almost_equal(
|
||||
new_dis, ref_dis, decimal=5)
|
||||
|
||||
def test_distance_computer_PQ(self):
|
||||
self.do_test("PQ8np")
|
||||
|
@ -49,5 +56,11 @@ class TestDistanceComputer(unittest.TestCase):
|
|||
def test_distance_computer_VT(self):
|
||||
self.do_test("PCA20,SQ8")
|
||||
|
||||
def test_distance_computer_AQ_decompress(self):
|
||||
self.do_test("RQ3x4") # test decompress path
|
||||
|
||||
def test_distance_computer_AQ_LUT(self):
|
||||
self.do_test("RQ3x4_Nqint8") # test LUT path
|
||||
|
||||
def test_distance_computer_AQ_LUT_IP(self):
|
||||
self.do_test("RQ3x4_Nqint8", faiss.METRIC_INNER_PRODUCT)
|
||||
|
|
|
@ -258,6 +258,38 @@ class TestResidualQuantizer(unittest.TestCase):
|
|||
for c0, c1 in zip(cb0, cb1):
|
||||
self.assertTrue(np.all(c0 == c1))
|
||||
|
||||
def test_clipping(self):
|
||||
""" verify that a clipped residual quantizer gives the same
|
||||
code prefix + suffix as the full RQ """
|
||||
ds = datasets.SyntheticDataset(32, 1000, 100, 0)
|
||||
|
||||
rq = faiss.ResidualQuantizer(ds.d, 5, 4)
|
||||
rq.train_type = faiss.ResidualQuantizer.Train_default
|
||||
rq.max_beam_size = 5
|
||||
rq.train(ds.get_train())
|
||||
|
||||
rq.max_beam_size = 1 # is not he same for a large beam size
|
||||
codes = rq.compute_codes(ds.get_database())
|
||||
|
||||
rq2 = faiss.ResidualQuantizer(ds.d, 2, 4)
|
||||
rq2.initialize_from(rq)
|
||||
self.assertEqual(rq2.M, 2)
|
||||
# verify that the beginning of the codes are the same
|
||||
codes2 = rq2.compute_codes(ds.get_database())
|
||||
|
||||
rq3 = faiss.ResidualQuantizer(ds.d, 3, 4)
|
||||
rq3.initialize_from(rq, 2)
|
||||
self.assertEqual(rq3.M, 3)
|
||||
codes3 = rq3.compute_codes(ds.get_database() - rq2.decode(codes2))
|
||||
|
||||
# verify that prefixes are the same
|
||||
for i in range(ds.nb):
|
||||
print(i, ds.nb)
|
||||
br = faiss.BitstringReader(faiss.swig_ptr(codes[i]), rq.code_size)
|
||||
br2 = faiss.BitstringReader(faiss.swig_ptr(codes2[i]), rq2.code_size)
|
||||
self.assertEqual(br.read(rq2.tot_bits), br2.read(rq2.tot_bits))
|
||||
br3 = faiss.BitstringReader(faiss.swig_ptr(codes3[i]), rq3.code_size)
|
||||
self.assertEqual(br.read(rq3.tot_bits), br3.read(rq3.tot_bits))
|
||||
|
||||
###########################################################
|
||||
# Test index, index factory sa_encode / sa_decode
|
||||
|
@ -318,6 +350,7 @@ def retrain_AQ_codebook(index, xt):
|
|||
|
||||
return C, B
|
||||
|
||||
|
||||
class TestIndexResidualQuantizer(unittest.TestCase):
|
||||
|
||||
def test_io(self):
|
||||
|
|
Loading…
Reference in New Issue