11 #include "IndexIVFSpectralHash.h"
18 #include "FaissAssert.h"
19 #include "AuxIndexStructures.h"
20 #include "VectorTransform.h"
// Main constructor. Forwards quantizer, d and nlist to IndexIVF, passing
// (nbit + 7) / 8 (nbit rounded up to whole bytes) as the fourth argument and
// METRIC_L2 as the metric. nbit and period are stored as given and
// threshold_type starts as Thresh_global.
25 IndexIVFSpectralHash::IndexIVFSpectralHash (
26 Index * quantizer,
size_t d,
size_t nlist,
27 int nbit,
float period):
28 IndexIVF (quantizer, d, nlist, (nbit + 7) / 8, METRIC_L2),
29 nbit (nbit), period (period), threshold_type (Thresh_global)
// Allocates a RandomRotationMatrix constructed with (d, nbit).
// NOTE(review): raw owning `new`; what happens to rr afterwards (init,
// assignment to a member, ownership flag) is not visible in this listing.
32 RandomRotationMatrix *rr =
new RandomRotationMatrix (d, nbit);
// Default constructor: default-constructs the IndexIVF base, leaves the
// transform unset (vt == nullptr, own_fields == false), zeroes nbit and
// period, and selects Thresh_global.
39 IndexIVFSpectralHash::IndexIVFSpectralHash():
40 IndexIVF(), vt(nullptr), own_fields(false),
41 nbit(0), period(0), threshold_type(Thresh_global)
44 IndexIVFSpectralHash::~IndexIVFSpectralHash ()
// Returns the mean of x[n / 2 - 1] and x[n / 2].
// x is taken as a non-const pointer. The statements between the signature and
// the return are not visible in this listing; presumably they reorder x
// (e.g. sort it) and may special-case some n — TODO confirm.
// NOTE(review): with the visible expression alone, n < 2 makes n / 2 - 1 wrap
// around (size_t); the visible caller only passes n >= 2.
54 float median (
size_t n,
float *x) {
59 return (x [n / 2 - 1] + x [n / 2]) / 2;
// Fragment of a training routine; its signature and several statements are
// not visible in this listing (presumably train_residual — TODO confirm).
// Thresh_global: the body of this branch is not visible here.
72 if (threshold_type == Thresh_global) {
75 }
else if (threshold_type == Thresh_centroid ||
76 threshold_type == Thresh_centroid_half) {
// Centroid-based thresholds: the quantizer is asked to reconstruct with
// arguments (0, nlist) into a buffer of nlist * d floats (presumably all
// centroids), and trained is sized to nlist * nbit entries.
78 std::vector<float> centroids (nlist * d);
79 quantizer->reconstruct_n (0, nlist, centroids.data());
80 trained.resize(nlist * nbit);
// Thresh_centroid_half: lower every threshold by a quarter of the period.
82 if (threshold_type == Thresh_centroid_half) {
83 for (
size_t i = 0; i < nlist * nbit; i++) {
84 trained[i] -= 0.25 * period;
// Ask the quantizer to assign each of the n training vectors; the results
// land in idx (one entry per vector).
92 std::unique_ptr<idx_t []> idx (
new idx_t [n]);
93 quantizer->assign (n, x, idx.get());
// sizes has nlist + 1 slots; every assignment must be non-negative.
95 std::vector<size_t> sizes(nlist + 1);
96 for (
size_t i = 0; i < n; i++) {
97 FAISS_THROW_IF_NOT (idx[i] >= 0);
// NOTE(review): the statements that fill sizes and the body of the loop
// below are not visible in this listing.
102 for (
int j = 0; j <
nlist; j++) {
// Apply the transform vt to the n input vectors; xt is read below as n rows
// of nbit values (xt[i * nbit + j]).
109 std::unique_ptr<float []> xt (vt->
apply (n, x));
// xo holds n * nbit floats, regrouped so that the values of bit j occupy
// the index range [n * j, n * j + n).
112 std::unique_ptr<float []> xo (
new float[n * nbit]);
114 for (
size_t i = 0; i < n; i++) {
// Slot for this vector, taken from the entry of sizes for its assigned list;
// the post-increment advances that entry for the next vector of the same list.
115 size_t idest = sizes[idx[i]]++;
116 for (
size_t j = 0; j < nbit; j++) {
117 xo[idest + n * j] = xt[i * nbit + j];
// NOTE(review): the visible loop below writes trained[i * nbit + j] for
// i < nlist, i.e. indices below nlist * nbit; that only matches n * nbit
// when n >= nlist — confirm the intended size.
121 trained.resize (n * nbit);
// For each list i, [i0, i1) is its slot range, read from consecutive
// entries of sizes (0 is used as the start for the first list).
124 for (
int i = 0; i <
nlist; i++) {
125 size_t i0 = i == 0 ? 0 : sizes[i - 1];
126 size_t i1 = sizes[i];
127 for (
int j = 0; j < nbit; j++) {
// Start of this list's values for bit j inside xo.
128 float *xoi = xo.get() + i0 + n * j;
// Threshold per (list, bit): 0.0 in the first branch (its condition is not
// visible; presumably the empty range), the single value when the range has
// exactly one element, and median(...) over the range in the remaining case.
130 trained[i * nbit + j] = 0.0;
131 }
else if (i1 == i0 + 1) {
132 trained[i * nbit + j] = xoi[0];
134 trained[i * nbit + j] = median(i1 - i0, xoi);
// Packs an nbit-bit code into codes from the values x relative to the
// thresholds c, using freq as a scale factor.
// Part of the parameter list and body are not visible in this listing
// (the codes parameter and the definition of bit among them).
143 void binarize_with_freq(
size_t nbit,
float freq,
144 const float *x,
const float *c,
// Clear the (nbit + 7) / 8 output bytes before bits are OR-ed in.
147 memset (codes, 0, (nbit + 7) / 8);
148 for (
size_t i = 0; i < nbit; i++) {
// Offset of x[i] from its threshold, then scaled by freq and floored to an int.
149 float xf = (x[i] - c[i]);
150 int xi = int(floor(xf * freq));
// Bit i is stored in byte i >> 3 at bit position i & 7.
// NOTE(review): how bit is derived from xi is not visible here.
152 codes[i >> 3] |= bit << (i & 7);
// Tail of a const member function's parameter list; the start of the
// signature is not visible here (presumably encode_vectors — TODO confirm).
162 const idx_t *list_nos,
163 uint8_t * codes)
const
// Scale factor handed to binarize_with_freq below.
166 float freq = 2.0 / period;
// Transform the n inputs with vt; the result is read below in rows of nbit
// values (x.get() + i * nbit).
169 std::unique_ptr<float []> x (vt->
apply (n, x_in));
// nbit value-initialised (zero) floats.
173 std::vector<float> zero (nbit);
177 for (
size_t i = 0; i < n; i++) {
178 long list_no = list_nos [i];
// NOTE(review): the declaration of c and the body of the Thresh_global
// branch are not visible; presumably c points at zero in that case and the
// assignment below belongs to the other branch — confirm.
182 if (threshold_type == Thresh_global) {
// Thresholds for this list start at offset list_no * nbit in trained.
185 c = trained.data() + list_no * nbit;
// Binarize row i of the transformed input against c; the remaining call
// arguments are not visible in this listing.
187 binarize_with_freq (nbit, freq,
188 x.get() + i * nbit, c,
// Scanner parameterised on the Hamming computer type. The struct header,
// some members and parts of several method bodies are not visible in this
// listing.
198 template<
class HammingComputer>
// q receives the transformed query, zero is a buffer of nbit zeros used as
// thresholds in the Thresh_global case, qcode holds the binarized query.
208 std::vector<float> q;
209 std::vector<float> zero;
210 std::vector<uint8_t> qcode;
// Constructor initialisers (partially visible): values are copied from the
// index, freq is 2.0 / period, q and zero get nbit entries, qcode gets
// code_size bytes, and hc is built over qcode.
218 code_size(index->code_size),
220 store_pairs(store_pairs),
221 period(index->period), freq(2.0 / index->period),
222 q(nbit), zero(nbit), qcode(code_size),
223 hc(qcode.data(), code_size)
// Sets the current query: requires a non-null pointer and q sized to nbit,
// then applies index->vt to one vector, writing into q.
228 void set_query (
const float *query)
override {
229 FAISS_THROW_IF_NOT(query);
230 FAISS_THROW_IF_NOT(q.size() == nbit);
231 index->vt->apply_noalloc (1, query, q.data());
// With global thresholds the query code does not depend on the list, so it
// is computed here against zero and hc is updated.
// NOTE(review): the callee name on the line before the argument list below
// is not visible; presumably binarize_with_freq.
233 if (index->threshold_type ==
234 IndexIVFSpectralHash::Thresh_global) {
236 (nbit, freq, q.data(), zero.data(), qcode.data());
237 hc.set (qcode.data(), code_size);
// Selects the inverted list to scan; the second (float) parameter is unnamed
// and unused. For non-global thresholds the query is re-binarized against
// this list's thresholds (offset list_no * nbit in index->trained) and hc is
// updated.
243 void set_list (idx_t list_no,
float )
override {
244 this->list_no = list_no;
245 if (index->threshold_type != IndexIVFSpectralHash::Thresh_global) {
246 const float *c = index->trained.data() + list_no * nbit;
247 binarize_with_freq (nbit, freq, q.data(), c, qcode.data());
248 hc.set (qcode.data(), code_size);
// Returns hc.hamming(code) for the given code.
252 float distance_to_code (
const uint8_t *code)
const final {
253 return hc.hamming (code);
// Scans list_size codes and updates the result heap (simi, idxi) of size k.
// The ids parameter, the advance of codes between iterations and the return
// value are not visible in this listing.
256 size_t scan_codes (
size_t list_size,
257 const uint8_t *codes,
259 float *simi, idx_t *idxi,
260 size_t k)
const override
263 for (
size_t j = 0; j < list_size; j++) {
265 float dis = hc.hamming (codes);
// Only results strictly below simi[0] replace the current heap top.
267 if (dis < simi [0]) {
268 maxheap_pop (k, simi, idxi);
// With store_pairs the id packs list_no in the bits above 32 and the
// position j in the low bits; otherwise the stored id is used.
269 long id = store_pairs ? (list_no << 32 | j) : ids[j];
270 maxheap_push (k, simi, idxi, dis,
id);
// Range-search variant of the scan; the radius test and the insertion into
// res are not visible in this listing.
278 void scan_codes_range (
size_t list_size,
279 const uint8_t *codes,
282 RangeQueryResult & res)
const override
284 for (
size_t j = 0; j < list_size; j++) {
285 float dis = hc.hamming (codes);
// Same id computation as in scan_codes.
287 long id = store_pairs ? (list_no << 32 | j) : ids[j];
// Tail of a const member function taking store_pairs; its name and return
// type are not visible here (presumably get_InvertedListScanner — TODO confirm).
300 (
bool store_pairs)
const
// HANDLE_CODE_SIZE(cs) returns a new IVFScanner specialised on the type
// named HammingComputer##cs (e.g. HammingComputer16). The part of the macro
// that guards the return (original line 304) and the first uses (306-307)
// are not visible in this listing.
303 #define HANDLE_CODE_SIZE(cs) \
305 return new IVFScanner<HammingComputer ## cs> (this, store_pairs)
308 HANDLE_CODE_SIZE(16);
309 HANDLE_CODE_SIZE(20);
310 HANDLE_CODE_SIZE(32);
311 HANDLE_CODE_SIZE(64);
312 #undef HANDLE_CODE_SIZE
// Fallbacks for the remaining code sizes: multiples of 8 bytes use
// HammingComputerM8, multiples of 4 bytes use HammingComputerM4; the
// throw below reports "not supported".
// NOTE(review): the scanner is returned as a raw `new` pointer; presumably
// the caller takes ownership — confirm.
314 if (code_size % 8 == 0) {
315 return new IVFScanner<HammingComputerM8>(
this, store_pairs);
316 }
else if (code_size % 4 == 0) {
317 return new IVFScanner<HammingComputerM4>(
this, store_pairs);
319 FAISS_THROW_MSG(
"not supported");
size_t code_size
code_size_1 + code_size_2
InvertedListScanner * get_InvertedListScanner(bool store_pairs) const override
get a scanner for this index (store_pairs means ignore labels)
long idx_t
all indices are this type
void train_residual(idx_t n, const float *x) override
void encode_vectors(idx_t n, const float *x, const idx_t *list_nos, uint8_t *codes) const override
bool is_trained
set if the Index does not require training, or if training is done already
size_t nlist
number of possible key values
size_t code_size
code size per vector in bytes