Generalize DistanceComputer for flat indexes (#2255)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2255

The `DistanceComputer` object is derived from an Index (obtained with `get_distance_computer()`). It maintains a current query and quickly computes distances from that query to any item in the database. This is useful, e.g., for IndexHNSW and IndexNSG, which rely on query-to-point comparisons within the dataset.

This diff introduces the `FlatCodesDistanceComputer`, which inherits from `DistanceComputer` and is used for flat indexes. In addition to the distance-to-item function, it adds a `distance_to_code` method that computes the distance from the current query to any code, even one that is not stored in the index.

This is implemented for all FlatCode indexes (IndexFlat, IndexPQ, IndexScalarQuantizer and IndexAdditiveQuantizer).

In the process, the two classes were extracted to their own header file `impl/DistanceComputer.h`

Reviewed By: beauby

Differential Revision: D34863609

fbshipit-source-id: 39d8c66475e55c3223c4a6a210827aa48bca292d
pull/2276/head^2
Matthijs Douze 2022-03-20 23:43:33 -07:00 committed by Facebook GitHub Bot
parent add3705c11
commit 291353c5a9
40 changed files with 589 additions and 178 deletions

View File

@ -11,6 +11,7 @@ the Facebook Faiss team. Feel free to add entries here if you submit a PR.
## [Unreleased]
- Added sparse k-means routines and moved the generic kmeans to contrib
- Added FlatCodesDistanceComputer for all FlatCodes indexes
## [1.7.2] - 2021-12-15
### Added

View File

@ -10,6 +10,7 @@
#include "AuxIndexStructures_c.h"
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/DistanceComputer.h>
#include <iostream>
#include "../macros_impl.h"

View File

@ -131,6 +131,7 @@ set(FAISS_HEADERS
index_io.h
impl/AdditiveQuantizer.h
impl/AuxIndexStructures.h
impl/DistanceComputer.h
impl/FaissAssert.h
impl/FaissException.h
impl/HNSW.h

View File

@ -10,6 +10,7 @@
#include <faiss/Index.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/DistanceComputer.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/distances.h>

View File

@ -38,7 +38,8 @@
namespace faiss {
/// Forward declarations see AuxIndexStructures.h
/// Forward declarations see impl/AuxIndexStructures.h and
/// impl/DistanceComputer.h
struct IDSelector;
struct RangeSearchResult;
struct DistanceComputer;

View File

@ -40,6 +40,90 @@ IndexAdditiveQuantizer::IndexAdditiveQuantizer(
namespace {
/************************************************************
* DistanceComputer implementation
************************************************************/
template <class VectorDistance>
struct AQDistanceComputerDecompress: FlatCodesDistanceComputer {
std::vector<float> tmp;
const AdditiveQuantizer & aq;
VectorDistance vd;
size_t d;
AQDistanceComputerDecompress(const IndexAdditiveQuantizer &iaq, VectorDistance vd):
FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
tmp(iaq.d * 2),
aq(*iaq.aq),
vd(vd),
d(iaq.d)
{}
const float *q;
void set_query(const float* x) final {
q = x;
}
float symmetric_dis(idx_t i, idx_t j) final {
aq.decode(codes + i * d, tmp.data(), 1);
aq.decode(codes + j * d, tmp.data() + d, 1);
return vd(tmp.data(), tmp.data() + d);
}
float distance_to_code(const uint8_t *code) final {
aq.decode(code, tmp.data(), 1);
return vd(q, tmp.data());
}
virtual ~AQDistanceComputerDecompress() {}
};
/// Distance computer for an IndexAdditiveQuantizer that evaluates
/// query-to-code distances through a per-query look-up table.
///
/// Not thread-safe: the LUT buffer (and its scratch tail) is per-object.
template <bool is_IP, AdditiveQuantizer::Search_type_t st>
struct AQDistanceComputerLUT : FlatCodesDistanceComputer {
    /// total_codebook_size look-up entries followed by 2 * d extra floats
    std::vector<float> LUT;
    const AdditiveQuantizer& aq;
    /// dimension of the decoded vectors
    size_t d;

    explicit AQDistanceComputerLUT(const IndexAdditiveQuantizer& iaq)
            : FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
              LUT(iaq.aq->total_codebook_size + iaq.d * 2),
              aq(*iaq.aq),
              d(iaq.d) {}

    /// constant term added to every LUT distance (0 for inner product,
    /// squared norm of the query otherwise)
    float bias = 0;

    void set_query(const float* x) final {
        // this is quite sub-optimal for multiple queries
        aq.compute_LUT(1, x, LUT.data());
        if (is_IP) {
            bias = 0;
        } else {
            bias = fvec_norm_L2sqr(x, d);
        }
    }

    /// Squared L2 distance between the decoded stored vectors i and j.
    /// NOTE(review): this returns an L2 distance even when is_IP is true;
    /// confirm whether inner product is expected in that case.
    float symmetric_dis(idx_t i, idx_t j) final {
        // Decode into the 2 * d floats appended after the look-up entries
        // (presumably reserved for exactly this purpose), so that the table
        // built by set_query() is not overwritten.
        float* tmp = LUT.data() + aq.total_codebook_size;
        // Stored codes are code_size bytes apart, not d bytes apart.
        aq.decode(codes + i * code_size, tmp, 1);
        aq.decode(codes + j * code_size, tmp + d, 1);
        return fvec_L2sqr(tmp, tmp + d, d);
    }

    /// Distance from the current query to an arbitrary encoded vector.
    float distance_to_code(const uint8_t* code) final {
        return bias + aq.compute_1_distance_LUT<is_IP, st>(code, LUT.data());
    }

    virtual ~AQDistanceComputerLUT() {}
};
/************************************************************
* scanning implementation for search
************************************************************/
template <class VectorDistance, class ResultHandler>
void search_with_decompress(
const IndexAdditiveQuantizer& ir,
@ -111,12 +195,58 @@ void search_with_LUT(
} // anonymous namespace
/// Allocate a distance computer that matches the quantizer's search type
/// and the index metric. The object is created with `new` and returned as
/// a raw pointer, so the caller is responsible for deleting it.
///
/// Throws if the metric (decompress mode) or the search type (LUT mode)
/// is not supported.
FlatCodesDistanceComputer* IndexAdditiveQuantizer::
        get_FlatCodesDistanceComputer() const {
    if (aq->search_type == AdditiveQuantizer::ST_decompress) {
        // decode the codes and apply a generic vector distance
        if (metric_type == METRIC_L2) {
            using VD = VectorDistance<METRIC_L2>;
            VD vd = {size_t(d), metric_arg};
            return new AQDistanceComputerDecompress<VD>(*this, vd);
        } else if (metric_type == METRIC_INNER_PRODUCT) {
            using VD = VectorDistance<METRIC_INNER_PRODUCT>;
            VD vd = {size_t(d), metric_arg};
            return new AQDistanceComputerDecompress<VD>(*this, vd);
        } else {
            FAISS_THROW_MSG("unsupported metric");
        }
    } else {
        // look-up table based distance computation
        if (metric_type == METRIC_INNER_PRODUCT) {
            return new AQDistanceComputerLUT<
                    true,
                    AdditiveQuantizer::ST_LUT_nonorm>(*this);
        } else {
            switch (aq->search_type) {
// each case returns directly, so no break is needed after it
#define DISPATCH(st)              \
    case AdditiveQuantizer::st:   \
        return new AQDistanceComputerLUT<false, AdditiveQuantizer::st>(*this);
                DISPATCH(ST_norm_float)
                DISPATCH(ST_LUT_nonorm)
                DISPATCH(ST_norm_qint8)
                DISPATCH(ST_norm_qint4)
                DISPATCH(ST_norm_cqint4)
#undef DISPATCH
                // these three search types share the cqint8 implementation
                case AdditiveQuantizer::ST_norm_cqint8:
                case AdditiveQuantizer::ST_norm_lsq2x4:
                case AdditiveQuantizer::ST_norm_rq2x4:
                    return new AQDistanceComputerLUT<
                            false,
                            AdditiveQuantizer::ST_norm_cqint8>(*this);
                default:
                    FAISS_THROW_FMT(
                            "search type %d not supported", aq->search_type);
            }
        }
    }
}
void IndexAdditiveQuantizer::search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const {
if (aq->search_type == AdditiveQuantizer::ST_decompress) {
if (metric_type == METRIC_L2) {
using VD = VectorDistance<METRIC_L2>;
@ -135,22 +265,23 @@ void IndexAdditiveQuantizer::search(
search_with_LUT<true, AdditiveQuantizer::ST_LUT_nonorm> (*this, x, rh);
} else {
HeapResultHandler<CMax<float, idx_t> > rh(n, distances, labels, k);
if (aq->search_type == AdditiveQuantizer::ST_norm_float) {
search_with_LUT<false, AdditiveQuantizer::ST_norm_float> (*this, x, rh);
} else if (aq->search_type == AdditiveQuantizer::ST_LUT_nonorm) {
search_with_LUT<false, AdditiveQuantizer::ST_norm_float> (*this, x, rh);
} else if (aq->search_type == AdditiveQuantizer::ST_norm_qint8) {
search_with_LUT<false, AdditiveQuantizer::ST_norm_qint8> (*this, x, rh);
} else if (aq->search_type == AdditiveQuantizer::ST_norm_qint4) {
search_with_LUT<false, AdditiveQuantizer::ST_norm_qint4> (*this, x, rh);
} else if (aq->search_type == AdditiveQuantizer::ST_norm_cqint8 ||
aq->search_type == AdditiveQuantizer::ST_norm_lsq2x4 ||
aq->search_type == AdditiveQuantizer::ST_norm_rq2x4) {
switch(aq->search_type) {
#define DISPATCH(st) \
case AdditiveQuantizer::st: \
search_with_LUT<false, AdditiveQuantizer::st> (*this, x, rh);\
break;
DISPATCH(ST_norm_float)
DISPATCH(ST_LUT_nonorm)
DISPATCH(ST_norm_qint8)
DISPATCH(ST_norm_qint4)
DISPATCH(ST_norm_cqint4)
case AdditiveQuantizer::ST_norm_cqint8:
case AdditiveQuantizer::ST_norm_lsq2x4:
case AdditiveQuantizer::ST_norm_rq2x4:
search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint8> (*this, x, rh);
} else if (aq->search_type == AdditiveQuantizer::ST_norm_cqint4) {
search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint4> (*this, x, rh);
} else {
break;
#undef DISPATCH
default:
FAISS_THROW_FMT("search type %d not supported", aq->search_type);
}
}

View File

@ -43,6 +43,8 @@ struct IndexAdditiveQuantizer : IndexFlatCodes {
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
};
/** Index based on a residual quantizer. Stored vectors are

View File

@ -26,6 +26,7 @@
#include <faiss/IndexBinaryFlat.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/DistanceComputer.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/Heap.h>
#include <faiss/utils/hamming.h>

View File

@ -83,16 +83,16 @@ void IndexFlat::compute_distance_subset(
namespace {
struct FlatL2Dis : DistanceComputer {
struct FlatL2Dis : FlatCodesDistanceComputer {
size_t d;
Index::idx_t nb;
const float* q;
const float* b;
size_t ndis;
float operator()(idx_t i) override {
float distance_to_code(const uint8_t* code) final {
ndis++;
return fvec_L2sqr(q, b + i * d, d);
return fvec_L2sqr(q, (float*)code, d);
}
float symmetric_dis(idx_t i, idx_t j) override {
@ -100,7 +100,10 @@ struct FlatL2Dis : DistanceComputer {
}
explicit FlatL2Dis(const IndexFlat& storage, const float* q = nullptr)
: d(storage.d),
: FlatCodesDistanceComputer(
storage.codes.data(),
storage.code_size),
d(storage.d),
nb(storage.ntotal),
q(q),
b(storage.get_xb()),
@ -111,24 +114,27 @@ struct FlatL2Dis : DistanceComputer {
}
};
struct FlatIPDis : DistanceComputer {
struct FlatIPDis : FlatCodesDistanceComputer {
size_t d;
Index::idx_t nb;
const float* q;
const float* b;
size_t ndis;
float operator()(idx_t i) override {
ndis++;
return fvec_inner_product(q, b + i * d, d);
}
float symmetric_dis(idx_t i, idx_t j) override {
return fvec_inner_product(b + j * d, b + i * d, d);
}
float distance_to_code(const uint8_t* code) final {
ndis++;
return fvec_inner_product(q, (float*)code, d);
}
explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr)
: d(storage.d),
: FlatCodesDistanceComputer(
storage.codes.data(),
storage.code_size),
d(storage.d),
nb(storage.ntotal),
q(q),
b(storage.get_xb()),
@ -141,7 +147,7 @@ struct FlatIPDis : DistanceComputer {
} // namespace
DistanceComputer* IndexFlat::get_distance_computer() const {
FlatCodesDistanceComputer* IndexFlat::get_FlatCodesDistanceComputer() const {
if (metric_type == METRIC_L2) {
return new FlatL2Dis(*this);
} else if (metric_type == METRIC_INNER_PRODUCT) {

View File

@ -60,7 +60,7 @@ struct IndexFlat : IndexFlatCodes {
IndexFlat() {}
DistanceComputer* get_distance_computer() const override;
FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
/* The standalone codec interface (just memcopies in this case) */
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;

View File

@ -8,6 +8,7 @@
#include <faiss/IndexFlatCodes.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/DistanceComputer.h>
#include <faiss/impl/FaissAssert.h>
namespace faiss {
@ -64,4 +65,9 @@ void IndexFlatCodes::reconstruct(idx_t key, float* recons) const {
reconstruct_n(key, 1, recons);
}
/// Base-class fallback: always throws, since this class provides no
/// distance computer of its own.
FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer() const {
    FAISS_THROW_MSG("not implemented");
}
} // namespace faiss

View File

@ -10,6 +10,7 @@
#pragma once
#include <faiss/Index.h>
#include <faiss/impl/DistanceComputer.h>
#include <vector>
namespace faiss {
@ -42,6 +43,13 @@ struct IndexFlatCodes : Index {
* indexing structure, the semantics of this operation are
* different from the usual ones: the new ids are shifted */
size_t remove_ids(const IDSelector& sel) override;
/** a FlatCodesDistanceComputer offers a distance_to_code method */
virtual FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const;
/// Generic entry point: forwards to get_FlatCodesDistanceComputer() and
/// returns the result through the base DistanceComputer pointer type.
DistanceComputer* get_distance_computer() const override {
    FlatCodesDistanceComputer* flat_dc = get_FlatCodesDistanceComputer();
    return flat_dc;
}
};
} // namespace faiss

View File

@ -5,9 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/
// quiet the noise
// XXclang-format off
#include <faiss/IndexIVFAdditiveQuantizer.h>
#include <algorithm>

View File

@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include <faiss/IndexPQ.h>
#include <cinttypes>
@ -17,7 +15,7 @@
#include <algorithm>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/DistanceComputer.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/hamming.h>
@ -73,19 +71,16 @@ void IndexPQ::train(idx_t n, const float* x) {
namespace {
template <class PQDecoder>
struct PQDistanceComputer : DistanceComputer {
struct PQDistanceComputer : FlatCodesDistanceComputer {
size_t d;
MetricType metric;
Index::idx_t nb;
const uint8_t* codes;
size_t code_size;
const ProductQuantizer& pq;
const float* sdc;
std::vector<float> precomputed_table;
size_t ndis;
float operator()(idx_t i) override {
const uint8_t* code = codes + i * code_size;
float distance_to_code(const uint8_t* code) final {
const float* dt = precomputed_table.data();
PQDecoder decoder(code, pq.nbits);
float accu = 0;
@ -112,13 +107,15 @@ struct PQDistanceComputer : DistanceComputer {
return accu;
}
explicit PQDistanceComputer(const IndexPQ& storage) : pq(storage.pq) {
explicit PQDistanceComputer(const IndexPQ& storage)
: FlatCodesDistanceComputer(
storage.codes.data(),
storage.code_size),
pq(storage.pq) {
precomputed_table.resize(pq.M * pq.ksub);
nb = storage.ntotal;
d = storage.d;
metric = storage.metric_type;
codes = storage.codes.data();
code_size = pq.code_size;
if (pq.sdc_table.size() == pq.ksub * pq.ksub * pq.M) {
sdc = pq.sdc_table.data();
} else {
@ -138,7 +135,7 @@ struct PQDistanceComputer : DistanceComputer {
} // namespace
DistanceComputer* IndexPQ::get_distance_computer() const {
FlatCodesDistanceComputer* IndexPQ::get_FlatCodesDistanceComputer() const {
if (pq.nbits == 8) {
return new PQDistanceComputer<PQDecoder8>(*this);
} else if (pq.nbits == 16) {

View File

@ -52,7 +52,7 @@ struct IndexPQ : IndexFlatCodes {
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
DistanceComputer* get_distance_computer() const override;
FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
/******************************************************
* Polysemous codes implementation

View File

@ -15,6 +15,7 @@
#include <memory>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/DistanceComputer.h>
#include <faiss/impl/FaissAssert.h>
namespace faiss {

View File

@ -85,7 +85,8 @@ void IndexScalarQuantizer::search(
}
}
DistanceComputer* IndexScalarQuantizer::get_distance_computer() const {
FlatCodesDistanceComputer* IndexScalarQuantizer::get_FlatCodesDistanceComputer()
const {
ScalarQuantizer::SQDistanceComputer* dc =
sq.get_distance_computer(metric_type);
dc->code_size = sq.code_size;

View File

@ -20,11 +20,8 @@
namespace faiss {
/**
* The uniform quantizer has a range [vmin, vmax]. The range can be
* the same for all dimensions (uniform) or specific per dimension
* (default).
* Flat index built on a scalar quantizer.
*/
struct IndexScalarQuantizer : IndexFlatCodes {
/// Used to encode the vectors
ScalarQuantizer sq;
@ -51,7 +48,7 @@ struct IndexScalarQuantizer : IndexFlatCodes {
float* distances,
idx_t* labels) const override;
DistanceComputer* get_distance_computer() const override;
FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
/* standalone codec interface */
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;

View File

@ -83,24 +83,27 @@ void AdditiveQuantizer::set_derived_values() {
}
total_codebook_size = codebook_offsets[M];
switch (search_type) {
case ST_decompress:
case ST_LUT_nonorm:
case ST_norm_from_LUT:
break; // nothing to add
case ST_norm_float:
tot_bits += 32;
norm_bits = 32;
break;
case ST_norm_qint8:
case ST_norm_cqint8:
case ST_norm_lsq2x4:
case ST_norm_rq2x4:
tot_bits += 8;
norm_bits = 8;
break;
case ST_norm_qint4:
case ST_norm_cqint4:
tot_bits += 4;
norm_bits = 4;
break;
case ST_decompress:
case ST_LUT_nonorm:
case ST_norm_from_LUT:
default:
norm_bits = 0;
break;
}
tot_bits += norm_bits;
// convert bits to bytes
code_size = (tot_bits + 7) / 8;
@ -195,6 +198,28 @@ float AdditiveQuantizer::decode_qcint(uint32_t c) const {
return qnorm.get_xb()[c];
}
/// Convert a norm to the integer that is stored alongside the code, using
/// the representation selected by search_type. Returns 0 for search types
/// that are not listed below.
uint64_t AdditiveQuantizer::encode_norm(float norm) const {
    switch (search_type) {
        case ST_norm_float: {
            // store the raw 32-bit pattern of the float
            // NOTE(review): pointer-cast type punning; a memcpy-based copy
            // would avoid the aliasing concern.
            const uint32_t* raw_bits = (const uint32_t*)&norm;
            return *raw_bits;
        }
        case ST_norm_qint8:
            return encode_qint8(norm, norm_min, norm_max);
        case ST_norm_qint4:
            return encode_qint4(norm, norm_min, norm_max);
        case ST_norm_cqint8:
        case ST_norm_cqint4:
        case ST_norm_lsq2x4:
        case ST_norm_rq2x4:
            return encode_qcint(norm);
        default:
            // includes ST_decompress, ST_LUT_nonorm and ST_norm_from_LUT
            return 0;
    }
}
void AdditiveQuantizer::pack_codes(
size_t n,
const int32_t* codes,
@ -230,36 +255,8 @@ void AdditiveQuantizer::pack_codes(
for (int m = 0; m < M; m++) {
bsw.write(codes1[m], nbits[m]);
}
switch (search_type) {
case ST_decompress:
case ST_LUT_nonorm:
case ST_norm_from_LUT:
break;
case ST_norm_float:
bsw.write(*(uint32_t*)&norms[i], 32);
break;
case ST_norm_qint8: {
uint8_t b = encode_qint8(norms[i], norm_min, norm_max);
bsw.write(b, 8);
break;
}
case ST_norm_qint4: {
uint8_t b = encode_qint4(norms[i], norm_min, norm_max);
bsw.write(b, 4);
break;
}
case ST_norm_lsq2x4:
case ST_norm_rq2x4:
case ST_norm_cqint8: {
uint32_t b = encode_qcint(norms[i]);
bsw.write(b, 8);
break;
}
case ST_norm_cqint4: {
uint32_t b = encode_qcint(norms[i]);
bsw.write(b, 4);
break;
}
if (norm_bits != 0) {
bsw.write(encode_norm(norms[i]), norm_bits);
}
}
}

View File

@ -30,7 +30,8 @@ struct AdditiveQuantizer {
// derived values
std::vector<uint64_t> codebook_offsets;
size_t code_size; ///< code size in bytes
size_t tot_bits; ///< total number of bits
size_t tot_bits; ///< total number of bits (indexes + norms)
size_t norm_bits; ///< bits allocated for the norms
size_t total_codebook_size; ///< size of the codebook in vectors
bool only_8bit; ///< are all nbits = 8 (use faster decoder)
@ -41,6 +42,9 @@ struct AdditiveQuantizer {
std::vector<float> norm_tabs; ///< store norms of codebook entries for 4-bit
///< fastscan search
/// encode a norm into norm_bits bits
uint64_t encode_norm(float norm) const;
uint32_t encode_qcint(
float x) const; ///< encode norm by non-uniform scalar quantization

View File

@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
// Auxiliary index structures, that are used in indexes but that can
// be forward-declared
@ -186,30 +184,6 @@ struct RangeSearchPartialResult : BufferList {
bool do_delete = true);
};
/***********************************************************
* The distance computer maintains a current query and computes
* distances to elements in an index that supports random access.
*
* The DistanceComputer is not intended to be thread-safe (eg. because
* it maintains counters) so the distance functions are not const,
* instantiate one from each thread if needed.
***********************************************************/
struct DistanceComputer {
using idx_t = Index::idx_t;
/// called before computing distances. Pointer x should remain valid
/// while operator () is called
virtual void set_query(const float* x) = 0;
/// compute distance of vector i to current query
virtual float operator()(idx_t i) = 0;
/// compute distance between two stored vectors
virtual float symmetric_dis(idx_t i, idx_t j) = 0;
virtual ~DistanceComputer() {}
};
/***********************************************************
* Interrupt callback
***********************************************************/

View File

@ -0,0 +1,64 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <faiss/Index.h>
namespace faiss {
/***********************************************************
 * A DistanceComputer holds a current query and evaluates distances
 * between that query and elements of an index that supports random
 * access.
 *
 * It is not meant to be thread-safe (e.g. implementations may keep
 * counters), which is why the distance functions are non-const: create
 * one instance per thread when needed.
 *
 * The counterpart for IVF indexes is the InvertedListScanner, which has
 * additional methods to handle the inverted list context.
 ***********************************************************/
struct DistanceComputer {
    using idx_t = Index::idx_t;

    /// Must be called before any distance is computed. The pointer x
    /// should remain valid for as long as operator() is being called.
    virtual void set_query(const float* x) = 0;

    /// Distance between the current query and stored vector i.
    virtual float operator()(idx_t i) = 0;

    /// Distance between the two stored vectors i and j.
    virtual float symmetric_dis(idx_t i, idx_t j) = 0;

    virtual ~DistanceComputer() = default;
};
/*************************************************************
 * DistanceComputer specialization for the case where the codes are
 * known to be laid out in a flat index: item i starts at
 * codes + i * code_size.
 */
struct FlatCodesDistanceComputer : DistanceComputer {
    /// start of the flat code array
    const uint8_t* codes;
    /// stride between consecutive codes
    size_t code_size;

    /// Empty computer: no codes attached yet.
    FlatCodesDistanceComputer() : codes(nullptr), code_size(0) {}

    FlatCodesDistanceComputer(const uint8_t* codes, size_t code_size)
            : codes(codes), code_size(code_size) {}

    /// Distance from the current query to stored item i, obtained by
    /// locating its code and delegating to distance_to_code().
    float operator()(idx_t i) final {
        const uint8_t* code_i = codes + i * code_size;
        return distance_to_code(code_i);
    }

    /// Distance from the current query to an encoded vector.
    virtual float distance_to_code(const uint8_t* code) = 0;

    virtual ~FlatCodesDistanceComputer() {}
};
} // namespace faiss

View File

@ -12,6 +12,7 @@
#include <string>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/DistanceComputer.h>
namespace faiss {

View File

@ -13,6 +13,7 @@
#include <string>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/DistanceComputer.h>
namespace faiss {

View File

@ -14,7 +14,7 @@
#include <mutex>
#include <stack>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/DistanceComputer.h>
namespace faiss {

View File

@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include <faiss/impl/ResidualQuantizer.h>
#include <algorithm>
@ -97,6 +95,39 @@ ResidualQuantizer::ResidualQuantizer(
Search_type_t search_type)
: ResidualQuantizer(d, std::vector<size_t>(M, nbits), search_type) {}
void ResidualQuantizer::initialize_from(
const ResidualQuantizer& other,
int skip_M) {
FAISS_THROW_IF_NOT(M + skip_M <= other.M);
FAISS_THROW_IF_NOT(skip_M >= 0);
Search_type_t this_search_type = search_type;
int this_M = M;
// a first good approximation: override everything
*this = other;
// adjust derived values
M = this_M;
search_type = this_search_type;
nbits.resize(M);
memcpy(nbits.data(),
other.nbits.data() + skip_M,
nbits.size() * sizeof(nbits[0]));
set_derived_values();
// resize codebooks if trained
if (codebooks.size() > 0) {
FAISS_THROW_IF_NOT(codebooks.size() == other.total_codebook_size * d);
codebooks.resize(total_codebook_size * d);
memcpy(codebooks.data(),
other.codebooks.data() + other.codebook_offsets[skip_M] * d,
codebooks.size() * sizeof(codebooks[0]));
// TODO: norm_tabs?
}
}
void beam_search_encode_step(
size_t d,
size_t K,

View File

@ -28,7 +28,7 @@ struct ResidualQuantizer : AdditiveQuantizer {
// Was enum but that does not work so well with bitmasks
using train_type_t = int;
/// Or of the Train_* flags below
/// Binary or of the Train_* flags below
train_type_t train_type;
/// regular k-means (minimal amount of computation)
@ -47,7 +47,7 @@ struct ResidualQuantizer : AdditiveQuantizer {
* first element of the beam (faster but less accurate) */
static const int Train_top_beam = 1024;
/** set this bit to not autmatically compute the codebook tables
/** set this bit to *not* automatically compute the codebook tables
* after training */
static const int Skip_codebook_tables = 2048;
@ -83,6 +83,9 @@ struct ResidualQuantizer : AdditiveQuantizer {
/// Train the residual quantizer
void train(size_t n, const float* x) override;
/// Copy the M codebook levels from other, starting from skip_M
void initialize_from(const ResidualQuantizer& other, int skip_M = 0);
/** Encode the vectors and compute codebook that minimizes the quantization
* error on these codes
*

View File

@ -911,11 +911,6 @@ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer {
q = x;
}
/// compute distance of vector i to current query
float operator()(idx_t i) final {
return query_to_code(codes + i * code_size);
}
float symmetric_dis(idx_t i, idx_t j) override {
return compute_code_distance(
codes + i * code_size, codes + j * code_size);
@ -963,11 +958,6 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
q = x;
}
/// compute distance of vector i to current query
float operator()(idx_t i) final {
return query_to_code(codes + i * code_size);
}
float symmetric_dis(idx_t i, idx_t j) override {
return compute_code_distance(
codes + i * code_size, codes + j * code_size);
@ -1021,11 +1011,6 @@ struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
return compute_code_distance(tmp.data(), code);
}
/// compute distance of vector i to current query
float operator()(idx_t i) final {
return query_to_code(codes + i * code_size);
}
float symmetric_dis(idx_t i, idx_t j) override {
return compute_code_distance(
codes + i * code_size, codes + j * code_size);
@ -1089,11 +1074,6 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
return compute_code_distance(tmp.data(), code);
}
/// compute distance of vector i to current query
float operator()(idx_t i) final {
return query_to_code(codes + i * code_size);
}
float symmetric_dis(idx_t i, idx_t j) override {
return compute_code_distance(
codes + i * code_size, codes + j * code_size);

View File

@ -11,6 +11,7 @@
#include <faiss/IndexIVF.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/DistanceComputer.h>
namespace faiss {
@ -105,14 +106,16 @@ struct ScalarQuantizer {
Quantizer* select_quantizer() const;
struct SQDistanceComputer : DistanceComputer {
struct SQDistanceComputer : FlatCodesDistanceComputer {
const float* q;
const uint8_t* codes;
size_t code_size;
SQDistanceComputer() : q(nullptr), codes(nullptr), code_size(0) {}
SQDistanceComputer() : q(nullptr) {}
virtual float query_to_code(const uint8_t* code) const = 0;
float distance_to_code(const uint8_t* code) final {
return query_to_code(code);
}
};
SQDistanceComputer* get_distance_computer(

View File

@ -1287,6 +1287,12 @@ def randn(n, seed=12345):
float_randn(swig_ptr(res), res.size, seed)
return res
rand_smooth_vectors_c = rand_smooth_vectors
def rand_smooth_vectors(n, d, seed=1234):
    """Return an (n, d) float32 array filled in by the wrapped C++
    rand_smooth_vectors routine, using the given seed."""
    out = np.empty((n, d), dtype='float32')
    rand_smooth_vectors_c(n, d, swig_ptr(out), seed)
    return out
def eval_intersection(I1, I2):
""" size of intersection between each line of two result tables"""
@ -1429,24 +1435,14 @@ def knn(xq, xb, k, metric=METRIC_L2):
D = np.empty((nq, k), dtype='float32')
if metric == METRIC_L2:
heaps = float_maxheap_array_t()
heaps.k = k
heaps.nh = nq
heaps.val = swig_ptr(D)
heaps.ids = swig_ptr(I)
knn_L2sqr(
swig_ptr(xq), swig_ptr(xb),
d, nq, nb, heaps
d, nq, nb, k, swig_ptr(D), swig_ptr(I)
)
elif metric == METRIC_INNER_PRODUCT:
heaps = float_minheap_array_t()
heaps.k = k
heaps.nh = nq
heaps.val = swig_ptr(D)
heaps.ids = swig_ptr(I)
knn_inner_product(
swig_ptr(xq), swig_ptr(xb),
d, nq, nb, heaps
d, nq, nb, k, swig_ptr(D), swig_ptr(I)
)
else:
raise NotImplementedError("only L2 and INNER_PRODUCT are supported")

View File

@ -126,6 +126,7 @@ typedef uint64_t size_t;
#include <faiss/utils/AlignedTable.h>
#include <faiss/utils/partitioning.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/DistanceComputer.h>
#include <faiss/impl/AdditiveQuantizer.h>
#include <faiss/impl/ResidualQuantizer.h>
#include <faiss/impl/LocalSearchQuantizer.h>
@ -378,6 +379,10 @@ void gpu_sync_all_devices()
%newobject *::get_distance_computer() const;
%include <faiss/Index.h>
%include <faiss/impl/DistanceComputer.h>
%newobject *::get_FlatCodesDistanceComputer() const;
%include <faiss/IndexFlatCodes.h>
%include <faiss/IndexFlat.h>
%include <faiss/Clustering.h>
@ -484,7 +489,8 @@ void gpu_sync_all_devices()
%ignore faiss::InterruptCallback::instance;
%ignore faiss::InterruptCallback::lock;
%include <faiss/impl/AuxIndexStructures.h>
%include <faiss/impl/AuxIndexStructures.h>
#ifdef GPU_WRAPPER

View File

@ -333,6 +333,19 @@ void knn_inner_product(
}
}
void knn_inner_product(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
size_t k,
float* distances,
int64_t* indexes) {
float_minheap_array_t heaps = {nx, k, indexes, distances};
knn_inner_product(x, y, d, nx, ny, &heaps);
}
void knn_L2sqr(
const float* x,
const float* y,
@ -361,6 +374,20 @@ void knn_L2sqr(
}
}
void knn_L2sqr(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
size_t k,
float* distances,
int64_t* indexes,
const float* y_norm2) {
float_maxheap_array_t heaps = {nx, k, indexes, distances};
knn_L2sqr(x, y, d, nx, ny, &heaps, y_norm2);
}
/***************************************************************************
* Range search
***************************************************************************/

View File

@ -198,11 +198,11 @@ FAISS_API extern int distance_compute_blas_database_bs;
FAISS_API extern int distance_compute_min_k_reservoir;
/** Return the k nearest neighors of each of the nx vectors x among the ny
* vector y, w.r.t to max inner product
* vector y, w.r.t to max inner product.
*
* @param x query vectors, size nx * d
* @param y database vectors, size ny * d
* @param res result array, which also provides k. Sorted on output
* @param res result heap structure, which also provides k. Sorted on output
*/
void knn_inner_product(
const float* x,
@ -212,8 +212,30 @@ void knn_inner_product(
size_t ny,
float_minheap_array_t* res);
/** Same as knn_inner_product, for the L2 distance
* @param y_norm2 norms for the y vectors (nullptr or size ny)
/** Return the k nearest neighbors of each of the nx vectors x among the ny
* vectors y, for the inner product metric.
*
* @param x query vectors, size nx * d
* @param y database vectors, size ny * d
* @param distances output distances, size nx * k
* @param indexes output vector ids, size nx * k
*/
void knn_inner_product(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
size_t k,
float* distances,
int64_t* indexes);
/** Return the k nearest neighbors of each of the nx vectors x among the ny
* vectors y, for the L2 distance
* @param x query vectors, size nx * d
* @param y database vectors, size ny * d
* @param res result heap structure, which also provides k. Sorted on output
* @param y_norm2 (optional) norms for the y vectors (nullptr or size ny)
*/
void knn_L2sqr(
const float* x,
@ -224,6 +246,26 @@ void knn_L2sqr(
float_maxheap_array_t* res,
const float* y_norm2 = nullptr);
/** Return the k nearest neighbors of each of the nx vectors x among the ny
* vectors y, for the L2 distance
*
* @param x query vectors, size nx * d
* @param y database vectors, size ny * d
* @param distances output distances, size nx * k
* @param indexes output vector ids, size nx * k
* @param y_norm2 (optional) norms for the y vectors (nullptr or size ny)
*/
void knn_L2sqr(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
size_t k,
float* distances,
int64_t* indexes,
const float* y_norm2 = nullptr);
/* Find the nearest neighbors for nx queries in a set of ny vectors
* indexed by ids. May be useful for re-ranking a pre-selected vector list
*/

View File

@ -14,6 +14,7 @@
#include <cmath>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/DistanceComputer.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/utils.h>
@ -89,18 +90,18 @@ void knn_extra_metrics_template(
}
template <class VD>
struct ExtraDistanceComputer : DistanceComputer {
struct ExtraDistanceComputer : FlatCodesDistanceComputer {
VD vd;
Index::idx_t nb;
const float* q;
const float* b;
float operator()(idx_t i) override {
return vd(q, b + i * vd.d);
float symmetric_dis(idx_t i, idx_t j) final {
return vd(b + j * vd.d, b + i * vd.d);
}
float symmetric_dis(idx_t i, idx_t j) override {
return vd(b + j * vd.d, b + i * vd.d);
float distance_to_code(const uint8_t* code) final {
return vd(q, (float*)code);
}
ExtraDistanceComputer(
@ -108,7 +109,11 @@ struct ExtraDistanceComputer : DistanceComputer {
const float* xb,
size_t nb,
const float* q = nullptr)
: vd(vd), nb(nb), q(q), b(xb) {}
: FlatCodesDistanceComputer((uint8_t*)xb, vd.d * sizeof(float)),
vd(vd),
nb(nb),
q(q),
b(xb) {}
void set_query(const float* x) override {
q = x;
@ -188,7 +193,7 @@ void knn_extra_metrics(
}
}
DistanceComputer* get_extra_distance_computer(
FlatCodesDistanceComputer* get_extra_distance_computer(
size_t d,
MetricType mt,
float metric_arg,

View File

@ -18,6 +18,8 @@
namespace faiss {
struct FlatCodesDistanceComputer;
void pairwise_extra_distances(
int64_t d,
int64_t nq,
@ -43,7 +45,7 @@ void knn_extra_metrics(
/** get a DistanceComputer that refers to this type of distance and
* indexes a flat array of size nb */
DistanceComputer* get_extra_distance_computer(
FlatCodesDistanceComputer* get_extra_distance_computer(
size_t d,
MetricType mt,
float metric_arg,

View File

@ -9,6 +9,23 @@
#include <faiss/utils/random.h>
// Forward declaration, with C linkage, of the sgemm_ routine that is defined
// outside this file. Every argument, including the scalar ones, is passed by
// pointer.
// NOTE(review): presumably the Fortran BLAS SGEMM entry point
// (C = alpha * op(A) * op(B) + beta * C on column-major matrices) -- confirm
// against the BLAS library the project links with. FINTEGER is declared
// elsewhere in the project.
extern "C" {
int sgemm_(
const char* transa,
const char* transb,
FINTEGER* m,
FINTEGER* n,
FINTEGER* k,
const float* alpha,
const float* a,
FINTEGER* lda,
const float* b,
FINTEGER* ldb,
float* beta,
float* c,
FINTEGER* ldc);
}
namespace faiss {
/**************************************************
@ -165,4 +182,40 @@ void byte_rand(uint8_t* x, size_t n, int64_t seed) {
}
}
/** Fill x with n pseudo-random "smooth" vectors of dimension d.
 *
 * Generation happens in three steps:
 *  1. n latent vectors of dimension 10 are drawn with float_randn,
 *  2. they are mapped to dimension d with a random d x 10 matrix drawn with
 *     float_rand (a single sgemm_ call),
 *  3. every output component j goes through a sine whose frequency
 *     (scales[j] * 4 + 0.1) depends on j.
 *
 * @param n     number of vectors to generate
 * @param d     dimension of the generated vectors
 * @param x     output buffer of size n * d, one vector per row
 * @param seed  base seed; seed, seed + 1 and seed + 2 feed the three random
 *              draws, so the output is deterministic for a given seed
 */
void rand_smooth_vectors(size_t n, size_t d, float* x, int64_t seed) {
    // Nothing to generate. Returning early also avoids handing sgemm_ a zero
    // leading dimension (lda = ldc = d), which BLAS rejects as an illegal
    // argument.
    if (n == 0 || d == 0) {
        return;
    }

    // dimension of the latent space the vectors are generated in
    size_t d1 = 10;

    // latent vectors, n rows of d1 components
    std::vector<float> x1(n * d1);
    float_randn(x1.data(), x1.size(), seed);

    // random matrix mapping the latent space to the output space
    std::vector<float> rot(d1 * d);
    float_rand(rot.data(), rot.size(), seed + 1);

    { // x = x1 * rot, expressed in the column-major convention of sgemm_
        FINTEGER di = d, d1i = d1, ni = n;
        float one = 1.0, zero = 0.0;

        sgemm_("Not transposed",
               "Not transposed", // natural order
               &di,
               &ni,
               &d1i,
               &one,
               rot.data(),
               &di, // rotation matrix
               x1.data(),
               &d1i, // second term
               &zero,
               x,
               &di);
    }

    // one random frequency factor per output dimension
    std::vector<float> scales(d);
    float_rand(scales.data(), d, seed + 2);

    // the loop index is kept signed, as in the original OpenMP loop
#pragma omp parallel for if (n * d > 10000)
    for (int64_t i = 0; i < n; i++) {
        for (size_t j = 0; j < d; j++) {
            x[i * d + j] = sinf(x[i * d + j] * (scales[j] * 4 + 0.1));
        }
    }
}
} // namespace faiss

View File

@ -54,4 +54,9 @@ void int64_rand_max(int64_t* x, size_t n, uint64_t max, int64_t seed);
/* random permutation */
void rand_perm(int* perm, size_t n, int64_t seed);
/** Random set of vectors with intrinsic dimensionality 10 that is harder to
* index than a subspace of dim 10 but easier than uniform data in dimension d
*
* @param n number of vectors to generate
* @param d dimension of the generated vectors
* @param x output vectors, size n * d
* @param seed seed for the random generators
*/
void rand_smooth_vectors(size_t n, size_t d, float* x, int64_t seed);
} // namespace faiss

View File

@ -325,6 +325,7 @@ class TestScalarQuantizer(unittest.TestCase):
# print(dis, D[i, j])
assert abs(D[i, j] - dis) / dis < 1e-5
class TestRandom(unittest.TestCase):
def test_rand(self):
@ -340,6 +341,24 @@ class TestRandom(unittest.TestCase):
print(c)
assert c.max() - c.min() < 50 * 2
def test_rand_vector(self):
    """Sanity check for rand_smooth_vectors: a small PQ index built on the
    generated vectors should recover a number of true nearest neighbors
    that falls in a fixed range."""
    vectors = faiss.rand_smooth_vectors(1300, 32)
    # 1000 training vectors, 200 database vectors, 100 queries
    train = vectors[:1000]
    base = vectors[1000:1200]
    queries = vectors[1200:]

    # exact neighbors used as ground truth
    _, gt_ids = faiss.knn(queries, base, 10)

    pq_index = faiss.IndexPQ(32, 4, 4)
    pq_index.train(train)
    pq_index.add(base)
    _, found_ids = pq_index.search(queries, 10)

    overlap = faiss.eval_intersection(found_ids, gt_ids)

    # 445 for SyntheticDataset
    self.assertGreater(overlap, 420)
    self.assertLess(overlap, 460)
class TestPairwiseDis(unittest.TestCase):

View File

@ -21,15 +21,22 @@ class TestDistanceComputer(unittest.TestCase):
index.add(ds.get_database())
xq = ds.get_queries()
Dref, Iref = index.search(xq, 10)
dc = index.get_distance_computer()
self.assertTrue(dc.this.own())
for q in range(ds.nq):
dc.set_query(faiss.swig_ptr(xq[q]))
for j in range(10):
ref_dis = Dref[q, j]
new_dis = dc(int(Iref[q, j]))
np.testing.assert_almost_equal(
new_dis, ref_dis, decimal=5)
for is_FlatCodesDistanceComputer in False, True:
if not is_FlatCodesDistanceComputer:
dc = index.get_distance_computer()
else:
if not isinstance(index, faiss.IndexFlatCodes):
continue
dc = index.get_FlatCodesDistanceComputer()
self.assertTrue(dc.this.own())
for q in range(ds.nq):
dc.set_query(faiss.swig_ptr(xq[q]))
for j in range(10):
ref_dis = Dref[q, j]
new_dis = dc(int(Iref[q, j]))
np.testing.assert_almost_equal(
new_dis, ref_dis, decimal=5)
# Runs the shared do_test check with the "PQ8np" description
# (presumably an index factory string -- confirm in do_test).
def test_distance_computer_PQ(self):
self.do_test("PQ8np")
@ -49,5 +56,11 @@ class TestDistanceComputer(unittest.TestCase):
# Runs the shared do_test check with the "PCA20,SQ8" description
# (presumably a vector transform followed by a scalar quantizer index).
def test_distance_computer_VT(self):
self.do_test("PCA20,SQ8")
# Runs the shared do_test check with the "RQ3x4" description, which
# exercises the decompress path of the distance computer.
def test_distance_computer_AQ_decompress(self):
self.do_test("RQ3x4") # test decompress path
# Runs the shared do_test check with the "RQ3x4_Nqint8" description, which
# exercises the look-up-table (LUT) path of the distance computer.
def test_distance_computer_AQ_LUT(self):
self.do_test("RQ3x4_Nqint8") # test LUT path
# Same "RQ3x4_Nqint8" description as the LUT test, but passes
# faiss.METRIC_INNER_PRODUCT as the second argument of do_test.
def test_distance_computer_AQ_LUT_IP(self):
self.do_test("RQ3x4_Nqint8", faiss.METRIC_INNER_PRODUCT)

View File

@ -258,6 +258,38 @@ class TestResidualQuantizer(unittest.TestCase):
for c0, c1 in zip(cb0, cb1):
self.assertTrue(np.all(c0 == c1))
def test_clipping(self):
    """ verify that a clipped residual quantizer gives the same
    code prefix + suffix as the full RQ

    A 5-step quantizer is trained, then two shorter quantizers are
    initialized from it: rq2 with 2 steps and rq3 with 3 steps
    (presumably the last 3 steps of rq, given the offset argument 2 --
    confirm against ResidualQuantizer.initialize_from). Their codes are
    compared bit-wise with the codes of the full quantizer.
    """
    ds = datasets.SyntheticDataset(32, 1000, 100, 0)

    rq = faiss.ResidualQuantizer(ds.d, 5, 4)
    rq.train_type = faiss.ResidualQuantizer.Train_default
    rq.max_beam_size = 5
    rq.train(ds.get_train())

    rq.max_beam_size = 1    # is not the same for a large beam size
    codes = rq.compute_codes(ds.get_database())

    rq2 = faiss.ResidualQuantizer(ds.d, 2, 4)
    rq2.initialize_from(rq)
    self.assertEqual(rq2.M, 2)
    codes2 = rq2.compute_codes(ds.get_database())

    rq3 = faiss.ResidualQuantizer(ds.d, 3, 4)
    rq3.initialize_from(rq, 2)
    self.assertEqual(rq3.M, 3)
    # rq3 encodes what is left after subtracting the rq2 reconstruction
    codes3 = rq3.compute_codes(ds.get_database() - rq2.decode(codes2))

    for i in range(ds.nb):
        br = faiss.BitstringReader(
            faiss.swig_ptr(codes[i]), rq.code_size)

        # verify that the beginning of the codes are the same
        br2 = faiss.BitstringReader(
            faiss.swig_ptr(codes2[i]), rq2.code_size)
        self.assertEqual(br.read(rq2.tot_bits), br2.read(rq2.tot_bits))

        # br has already consumed the prefix, so this read covers the
        # remaining bits of the full code
        br3 = faiss.BitstringReader(
            faiss.swig_ptr(codes3[i]), rq3.code_size)
        self.assertEqual(br.read(rq3.tot_bits), br3.read(rq3.tot_bits))
###########################################################
# Test index, index factory sa_encode / sa_decode
@ -318,6 +350,7 @@ def retrain_AQ_codebook(index, xt):
return C, B
class TestIndexResidualQuantizer(unittest.TestCase):
def test_io(self):