Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
/tmp/faiss/IVFlib.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // -*- c++ -*-
10 
11 #include "IVFlib.h"
12 
13 #include <memory>
14 
15 #include "VectorTransform.h"
16 #include "FaissAssert.h"
17 
18 
19 
20 namespace faiss { namespace ivflib {
21 
22 
23 void check_compatible_for_merge (const Index * index0,
24  const Index * index1)
25 {
26 
27  const faiss::IndexPreTransform *pt0 =
28  dynamic_cast<const faiss::IndexPreTransform *>(index0);
29 
30  if (pt0) {
31  const faiss::IndexPreTransform *pt1 =
32  dynamic_cast<const faiss::IndexPreTransform *>(index1);
33  FAISS_THROW_IF_NOT_MSG (pt1, "both indexes should be pretransforms");
34 
35  FAISS_THROW_IF_NOT (pt0->chain.size() == pt1->chain.size());
36  for (int i = 0; i < pt0->chain.size(); i++) {
37  FAISS_THROW_IF_NOT (typeid(pt0->chain[i]) == typeid(pt1->chain[i]));
38  }
39 
40  index0 = pt0->index;
41  index1 = pt1->index;
42  }
43  FAISS_THROW_IF_NOT (typeid(index0) == typeid(index1));
44  FAISS_THROW_IF_NOT (index0->d == index1->d &&
45  index0->metric_type == index1->metric_type);
46 
47  const faiss::IndexIVF *ivf0 = dynamic_cast<const faiss::IndexIVF *>(index0);
48  if (ivf0) {
49  const faiss::IndexIVF *ivf1 =
50  dynamic_cast<const faiss::IndexIVF *>(index1);
51  FAISS_THROW_IF_NOT (ivf1);
52 
53  ivf0->check_compatible_for_merge (*ivf1);
54  }
55 
56  // TODO: check as thoroughfully for other index types
57 
58 }
59 
60 const IndexIVF * extract_index_ivf (const Index * index)
61 {
62  if (auto *pt =
63  dynamic_cast<const IndexPreTransform *>(index)) {
64  index = pt->index;
65  }
66 
67  auto *ivf = dynamic_cast<const IndexIVF *>(index);
68 
69  FAISS_THROW_IF_NOT (ivf);
70 
71  return ivf;
72 }
73 
74 IndexIVF * extract_index_ivf (Index * index) {
75  return const_cast<IndexIVF*> (extract_index_ivf ((const Index*)(index)));
76 }
77 
78 void merge_into(faiss::Index *index0, faiss::Index *index1, bool shift_ids) {
79 
80  check_compatible_for_merge (index0, index1);
81  IndexIVF * ivf0 = extract_index_ivf (index0);
82  IndexIVF * ivf1 = extract_index_ivf (index1);
83 
84  ivf0->merge_from (*ivf1, shift_ids ? ivf0->ntotal : 0);
85 
86  // useful for IndexPreTransform
87  index0->ntotal = ivf0->ntotal;
88  index1->ntotal = ivf1->ntotal;
89 }
90 
91 
92 
93 void search_centroid(faiss::Index *index,
94  const float* x, int n,
95  idx_t* centroid_ids)
96 {
97  std::unique_ptr<float[]> del;
98  if (auto index_pre = dynamic_cast<faiss::IndexPreTransform*>(index)) {
99  x = index_pre->apply_chain(n, x);
100  del.reset((float*)x);
101  index = index_pre->index;
102  }
103  faiss::IndexIVF* index_ivf = dynamic_cast<faiss::IndexIVF*>(index);
104  assert(index_ivf);
105  index_ivf->quantizer->assign(n, x, centroid_ids);
106 }
107 
108 
109 
110 void search_and_return_centroids(faiss::Index *index,
111  size_t n,
112  const float* xin,
113  long k,
114  float *distances,
115  idx_t* labels,
116  idx_t* query_centroid_ids,
117  idx_t* result_centroid_ids)
118 {
119  const float *x = xin;
120  std::unique_ptr<float []> del;
121  if (auto index_pre = dynamic_cast<faiss::IndexPreTransform*>(index)) {
122  x = index_pre->apply_chain(n, x);
123  del.reset((float*)x);
124  index = index_pre->index;
125  }
126  faiss::IndexIVF* index_ivf = dynamic_cast<faiss::IndexIVF*>(index);
127  assert(index_ivf);
128 
129  size_t nprobe = index_ivf->nprobe;
130  std::vector<idx_t> cent_nos (n * nprobe);
131  std::vector<float> cent_dis (n * nprobe);
132  index_ivf->quantizer->search(
133  n, x, nprobe, cent_dis.data(), cent_nos.data());
134 
135  if (query_centroid_ids) {
136  for (size_t i = 0; i < n; i++)
137  query_centroid_ids[i] = cent_nos[i * nprobe];
138  }
139 
140  index_ivf->search_preassigned (n, x, k,
141  cent_nos.data(), cent_dis.data(),
142  distances, labels, true);
143 
144  for (size_t i = 0; i < n * k; i++) {
145  idx_t label = labels[i];
146  if (label < 0) {
147  if (result_centroid_ids)
148  result_centroid_ids[i] = -1;
149  } else {
150  long list_no = label >> 32;
151  long list_index = label & 0xffffffff;
152  if (result_centroid_ids)
153  result_centroid_ids[i] = list_no;
154  labels[i] = index_ivf->invlists->get_single_id(list_no, list_index);
155  }
156  }
157 }
158 
159 
161  n_slice = 0;
162  IndexIVF* index_ivf = const_cast<IndexIVF*>(extract_index_ivf (index));
163  ils = dynamic_cast<ArrayInvertedLists *> (index_ivf->invlists);
164  nlist = ils->nlist;
165  FAISS_THROW_IF_NOT_MSG (ils,
166  "only supports indexes with ArrayInvertedLists");
167  sizes.resize(nlist);
168 }
169 
/**
 * Drop the first `remove` elements of dst, then append all of src.
 *
 * Uses the vector's own erase/insert so that it is valid for any element
 * type (not only trivially copyable ones) and for empty vectors.
 *
 * @param dst     vector modified in place
 * @param remove  number of leading elements to drop; must be <= dst.size()
 * @param src     elements appended to dst; must not alias dst
 */
template<class T>
static void shift_and_add (std::vector<T> & dst,
                           size_t remove,
                           const std::vector<T> & src)
{
    dst.erase (dst.begin(), dst.begin() + remove);
    dst.insert (dst.end(), src.begin(), src.end());
}
182 
/**
 * Erase the first `remove` elements of v in place (no-op when remove is 0).
 */
template<class T>
static void remove_from_begin (std::vector<T> & v,
                               size_t remove)
{
    if (remove == 0) {
        return;
    }
    const auto first = v.begin();
    v.erase (first, first + remove);
}
190 
191 void SlidingIndexWindow::step(const Index *sub_index, bool remove_oldest) {
192 
193  FAISS_THROW_IF_NOT_MSG (!remove_oldest || n_slice > 0,
194  "cannot remove slice: there is none");
195 
196  const ArrayInvertedLists *ils2 = nullptr;
197  if(sub_index) {
198  check_compatible_for_merge (index, sub_index);
199  ils2 = dynamic_cast<const ArrayInvertedLists*>(
200  extract_index_ivf (sub_index)->invlists);
201  FAISS_THROW_IF_NOT_MSG (ils2, "supports only ArrayInvertedLists");
202  }
203  IndexIVF *index_ivf = extract_index_ivf (index);
204 
205  if (remove_oldest && ils2) {
206  for (int i = 0; i < nlist; i++) {
207  std::vector<size_t> & sizesi = sizes[i];
208  size_t amount_to_remove = sizesi[0];
209  index_ivf->ntotal += ils2->ids[i].size() - amount_to_remove;
210 
211  shift_and_add (ils->ids[i], amount_to_remove, ils2->ids[i]);
212  shift_and_add (ils->codes[i], amount_to_remove * ils->code_size,
213  ils2->codes[i]);
214  for (int j = 0; j + 1 < n_slice; j++) {
215  sizesi[j] = sizesi[j + 1] - amount_to_remove;
216  }
217  sizesi[n_slice - 1] = ils->ids[i].size();
218  }
219  } else if (ils2) {
220  for (int i = 0; i < nlist; i++) {
221  index_ivf->ntotal += ils2->ids[i].size();
222  shift_and_add (ils->ids[i], 0, ils2->ids[i]);
223  shift_and_add (ils->codes[i], 0, ils2->codes[i]);
224  sizes[i].push_back(ils->ids[i].size());
225  }
226  n_slice++;
227  } else if (remove_oldest) {
228  for (int i = 0; i < nlist; i++) {
229  size_t amount_to_remove = sizes[i][0];
230  index_ivf->ntotal -= amount_to_remove;
231  remove_from_begin (ils->ids[i], amount_to_remove);
232  remove_from_begin (ils->codes[i],
233  amount_to_remove * ils->code_size);
234  for (int j = 0; j + 1 < n_slice; j++) {
235  sizes[i][j] = sizes[i][j + 1] - amount_to_remove;
236  }
237  sizes[i].resize(sizes[i].size() - 1);
238  }
239  n_slice--;
240  } else {
241  FAISS_THROW_MSG ("nothing to do???");
242  }
243  index->ntotal = index_ivf->ntotal;
244 }
245 
246 
247 
248 // Get a subset of inverted lists [i0, i1). Works on IndexIVF's and
249 // IndexIVF's embedded in a IndexPreTransform
250 
252 get_invlist_range (const Index *index, long i0, long i1)
253 {
254  const IndexIVF *ivf = extract_index_ivf (index);
255 
256  FAISS_THROW_IF_NOT (0 <= i0 && i0 <= i1 && i1 <= ivf->nlist);
257 
258  const InvertedLists *src = ivf->invlists;
259 
260  ArrayInvertedLists * il = new ArrayInvertedLists(i1 - i0, src->code_size);
261 
262  for (long i = i0; i < i1; i++) {
263  il->add_entries(i - i0, src->list_size(i),
264  InvertedLists::ScopedIds (src, i).get(),
265  InvertedLists::ScopedCodes (src, i).get());
266  }
267  return il;
268 }
269 
270 
271 
272 void set_invlist_range (Index *index, long i0, long i1,
273  ArrayInvertedLists * src)
274 {
275  IndexIVF *ivf = extract_index_ivf (index);
276 
277  FAISS_THROW_IF_NOT (0 <= i0 && i0 <= i1 && i1 <= ivf->nlist);
278 
279  ArrayInvertedLists *dst = dynamic_cast<ArrayInvertedLists *>(ivf->invlists);
280  FAISS_THROW_IF_NOT_MSG (dst, "only ArrayInvertedLists supported");
281  FAISS_THROW_IF_NOT (src->nlist == i1 - i0 &&
282  dst->code_size == src->code_size);
283 
284  size_t ntotal = index->ntotal;
285  for (long i = i0 ; i < i1; i++) {
286  ntotal -= dst->list_size (i);
287  ntotal += src->list_size (i - i0);
288  std::swap (src->codes[i - i0], dst->codes[i]);
289  std::swap (src->ids[i - i0], dst->ids[i]);
290  }
291  ivf->ntotal = index->ntotal = ntotal;
292 }
293 
294 
295 void search_with_parameters (const Index *index,
296  idx_t n, const float *x, idx_t k,
297  float *distances, idx_t *labels,
298  IVFSearchParameters *params)
299 {
300  FAISS_THROW_IF_NOT (params);
301  const float *prev_x = x;
302  ScopeDeleter<float> del;
303 
304  if (auto ip = dynamic_cast<const IndexPreTransform *> (index)) {
305  x = ip->apply_chain (n, x);
306  if (x != prev_x) {
307  del.set(x);
308  }
309  index = ip->index;
310  }
311 
312  std::vector<idx_t> Iq(params->nprobe * n);
313  std::vector<float> Dq(params->nprobe * n);
314 
315  const IndexIVF *index_ivf = dynamic_cast<const IndexIVF *>(index);
316  FAISS_THROW_IF_NOT (index_ivf);
317 
318  index_ivf->quantizer->search(n, x, params->nprobe,
319  Dq.data(), Iq.data());
320 
321  index_ivf->search_preassigned(n, x, k, Iq.data(), Dq.data(),
322  distances, labels,
323  false, params);
324 }
325 
326 
327 
328 } } // namespace faiss::ivflib
Index * index
! chain of transforms
virtual void search_preassigned(idx_t n, const float *x, idx_t k, const idx_t *assign, const float *centroid_dis, float *distances, idx_t *labels, bool store_pairs, const IVFSearchParameters *params=nullptr) const
Definition: IndexIVF.cpp:189
simple (default) implementation as an array of inverted lists
void check_compatible_for_merge(const IndexIVF &other) const
Definition: IndexIVF.cpp:461
size_t nprobe
number of probes at query time
Definition: IndexIVF.h:98
void assign(idx_t n, const float *x, idx_t *labels, idx_t k=1)
Definition: Index.cpp:35
virtual size_t list_size(size_t list_no) const =0
get the size of a list
size_t nlist
same as index->nlist
Definition: IVFlib.h:95
virtual idx_t get_single_id(size_t list_no, size_t offset) const
size_t code_size
code size per vector in bytes
Definition: InvertedLists.h:36
ArrayInvertedLists * ils
InvertedLists of index.
Definition: IVFlib.h:89
int n_slice
number of slices currently in index
Definition: IVFlib.h:92
std::vector< std::vector< size_t > > sizes
cumulative list sizes at each slice
Definition: IVFlib.h:98
SlidingIndexWindow(Index *index)
index should be initially empty and trained
Definition: IVFlib.cpp:160
idx_t ntotal
total nb of indexed vectors
Definition: Index.h:67
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const =0
void step(const Index *sub_index, bool remove_oldest)
Definition: IVFlib.cpp:191
size_t nlist
number of possible key values
Definition: InvertedLists.h:35
InvertedLists * invlists
Access to the actual data.
Definition: IndexIVF.h:93
std::vector< std::vector< idx_t > > ids
Inverted lists for indexes.
Index * quantizer
quantizer that maps vectors to inverted lists
Definition: IndexIVF.h:33