10 #include "IndexShards.h"
15 #include "FaissAssert.h"
17 #include "WorkerThread.h"
28 void translate_labels (
long n, idx_t *labels,
long translation)
30 if (translation == 0)
return;
31 for (
long i = 0; i < n; i++) {
32 if(labels[i] < 0)
continue;
33 labels[i] += translation;
// Merge the per-shard result tables (all_distances / all_labels) into the
// caller's output tables, using a heap ordered by comparator C.
// For query i and shard s the shard's entries are read at offset
// stride * s + i * k; each emitted label has translations[s] added.
// NOTE(review): this listing is missing several original lines. The `labels`
// output parameter and the declarations of `stride`, `heap_size`, `p` and the
// `s` used inside the j-loop are not visible here; confirm against the full file.
44 template <
class IndexClass,
class C>
46 merge_tables(
long n,
long k,
long nshard,
47 typename IndexClass::distance_t *distances,
49 const std::vector<typename IndexClass::distance_t>& all_distances,
50 const std::vector<idx_t>& all_labels,
51 const std::vector<long>& translations) {
55 using distance_t =
typename IndexClass::distance_t;
// One int buffer of 2 * nshard entries is split in two halves:
// `pointer` (first nshard ints) and `shard_ids` (second nshard ints).
// `heap_vals` holds nshard distance values.
60 std::vector<int> buf (2 * nshard);
61 int * pointer = buf.data();
62 int * shard_ids = pointer + nshard;
63 std::vector<distance_t> buf2 (nshard);
64 distance_t * heap_vals = buf2.data();
// Loop over the n queries; D_in / I_in point at query i's first entry.
66 for (
long i = 0; i < n; i++) {
69 const distance_t *D_in = all_distances.data() + i * k;
70 const idx_t *I_in = all_labels.data() + i * k;
// Seed the heap: a shard is pushed only if its first label for this query
// is non-negative.
73 for (
long s = 0; s < nshard; s++) {
75 if (I_in[stride * s] >= 0) {
76 heap_push<C> (++heap_size, heap_vals, shard_ids,
81 distance_t *D = distances + i * k;
82 idx_t *I = labels + i * k;
// Fill the k output slots of query i (D / I are its output rows).
84 for (
int j = 0; j < k; j++) {
// Emit the label at position p of shard s, shifted by that shard's offset.
93 I[j] = I_in[stride * s + p] + translations[s];
// Remove the entry just consumed from the heap.
95 heap_pop<C> (heap_size--, heap_vals, shard_ids);
// Push shard s's entry at position p if p is still below k and that label
// is non-negative.
97 if (p < k && I_in[stride * s + p] >= 0) {
98 heap_push<C> (++heap_size, heap_vals, shard_ids,
99 D_in[stride * s + p], s);
// Three templated definitions of which only the template header and the last
// member initializer survive in this listing. Each one stores its
// `successive_ids` argument in the member of the same name.
// NOTE(review): presumably the IndexShardsTemplate constructor overloads;
// signatures and bodies are on lines not shown, so confirm against the header.
109 template <
typename IndexT>
114 successive_ids(successive_ids) {
117 template <
typename IndexT>
122 successive_ids(successive_ids) {
125 template <
typename IndexT>
129 successive_ids(successive_ids) {
// Member whose visible body only calls sync_with_shard_indexes().
// NOTE(review): the signature line is missing from this listing; presumably
// one of the onAfterAddIndex / onAfterRemoveIndex hooks. Confirm which one.
132 template <
typename IndexT>
135 sync_with_shard_indexes();
// Member whose visible body only calls sync_with_shard_indexes().
// NOTE(review): the signature line is missing from this listing; presumably
// the other of the onAfterAddIndex / onAfterRemoveIndex hooks. Confirm which one.
138 template <
typename IndexT>
141 sync_with_shard_indexes();
// Recompute this object's metric_type, is_trained and ntotal from its shards.
// NOTE(review): the signature line is missing from this listing; the name
// sync_with_shard_indexes is inferred from the call sites elsewhere in the file.
144 template <
typename IndexT>
// With no shards, is_trained is cleared.
// NOTE(review): the rest of this branch is on lines not shown; presumably it
// returns before at(0) is reached. Confirm.
147 if (!this->count()) {
148 this->is_trained =
false;
// Seed the aggregate fields from the first shard.
154 auto firstIndex = this->at(0);
155 this->metric_type = firstIndex->metric_type;
156 this->is_trained = firstIndex->is_trained;
157 this->ntotal = firstIndex->ntotal;
// Every remaining shard must agree with this object on metric_type and d;
// ntotal accumulates the shard sizes.
// NOTE(review): in the visible lines is_trained comes from shard 0 only and
// is not compared across shards.
159 for (
int i = 1; i < this->count(); ++i) {
160 auto index = this->at(i);
161 FAISS_THROW_IF_NOT(this->metric_type == index->
metric_type);
162 FAISS_THROW_IF_NOT(this->d == index->
d);
164 this->ntotal += index->
ntotal;
// Run a per-shard callable through runOnIndex, then refresh the aggregate
// state via sync_with_shard_indexes().
// @param n  number of points (printed in the "begin train" message)
// @param x  input data captured by the per-shard callable
168 template <
typename IndexT>
170 IndexShardsTemplate<IndexT>::train(idx_t n,
171 const component_t *x) {
// Per-shard callable: receives the shard number `no` and the shard `index`.
// NOTE(review): the statements between the two printf calls (presumably the
// call that trains `index`) and any guard around the printfs are on lines not
// shown in this listing.
173 [n, x](
int no, IndexT *index) {
175 printf(
"begin train shard %d on %ld points\n", no, n);
181 printf(
"end train shard %d\n", no);
// Execute the callable on the shards, then resynchronise aggregate fields.
185 this->runOnIndex(fn);
186 sync_with_shard_indexes();
// Forwards to add_with_ids with a null id array, i.e. no caller-supplied ids.
// NOTE(review): the first signature line is missing from this listing; the
// name `add` is inferred from the forwarding call.
189 template <
typename IndexT>
192 const component_t *x) {
193 add_with_ids(n, x,
nullptr);
// Split a batch of n vectors into contiguous slices, one per shard, and hand
// each slice to its shard through runOnIndex.
// NOTE(review): parts of the signature and several body lines are missing from
// this listing (e.g. the sizing of `aids`, any use of `ids` inside the
// per-shard callable, and whatever follows runOnIndex). Confirm against the
// full file before changing behaviour.
196 template <
typename IndexT>
199 const component_t * x,
// Caller-supplied ids and successive_ids are mutually exclusive.
202 FAISS_THROW_IF_NOT_MSG(!(successive_ids && xids),
203 "It makes no sense to pass in ids and "
204 "request them to be shifted");
// NOTE(review): the check above already rejects successive_ids && xids, so the
// `!xids` check below cannot fail when reached; it looks redundant.
206 if (successive_ids) {
207 FAISS_THROW_IF_NOT_MSG(!xids,
208 "It makes no sense to pass in ids and "
209 "request them to be shifted");
// With successive_ids, adding is only allowed while ntotal is still 0.
// NOTE(review): the message text spells "sucessive_ids" (typo); it is a
// runtime string and is left unchanged here.
210 FAISS_THROW_IF_NOT_MSG(this->ntotal == 0,
211 "when adding to IndexShards with sucessive_ids, "
212 "only add() in a single pass is supported");
215 idx_t nshard = this->count();
216 const idx_t *ids = xids;
218 std::vector<idx_t> aids;
// No ids supplied and no successive_ids: generate ids ntotal, ntotal + 1, ...
// NOTE(review): the resize of `aids` and the rebinding of `ids` to it are
// presumably on lines not shown; confirm.
220 if (!ids && !successive_ids) {
223 for (idx_t i = 0; i < n; i++) {
224 aids[i] = this->ntotal + i;
// Elements of x per vector: (d + 7) / 8 when component_t is one byte wide,
// otherwise d.
230 size_t components_per_vec =
231 sizeof(component_t) == 1 ? (this->d + 7) / 8 : this->d;
// Per-shard callable: shard `no` receives rows [i0, i1) of the batch, where
// i0 = no * n / nshard and i1 = (no + 1) * n / nshard.
234 [n, ids, x, nshard, components_per_vec](
int no, IndexT *index) {
235 idx_t i0 = (idx_t) no * n / nshard;
236 idx_t i1 = ((idx_t) no + 1) * n / nshard;
237 auto x0 = x + i0 * components_per_vec;
// NOTE(review): this message prints n (the whole batch) whereas the "end"
// message prints i1 - i0 (this shard's slice); likely an inconsistency.
240 printf (
"begin add shard %d on %ld points\n", no, n);
// NOTE(review): `ids` is captured but not used in the visible lines; the
// branch that passes ids to the shard is presumably on lines not shown.
246 index->
add (i1 - i0, x0);
250 printf (
"end add shard %d on %ld points\n", no, i1 - i0);
// Execute the callable on the shards.
254 this->runOnIndex(fn);
// Query every shard into a private slice of a scratch table, then merge the
// per-shard results into the caller's `distances` / `labels` via merge_tables.
// NOTE(review): the head of the signature and the head of the per-shard call
// are missing from this listing; the name `search` is inferred.
261 template <
typename IndexT>
264 const component_t *x,
266 distance_t *distances,
267 idx_t *labels)
const {
268 long nshard = this->count();
// Scratch tables sized nshard * k * n.
270 std::vector<distance_t> all_distances(nshard * k * n);
271 std::vector<idx_t> all_labels(nshard * k * n);
// Per-shard callable: shard `no` is given the slices starting at offset
// no * k * n of both scratch tables (they are the trailing arguments of a
// call whose first line is not shown).
274 [n, k, x, &all_distances, &all_labels](
int no,
const IndexT *index) {
276 printf (
"begin query shard %d on %ld points\n", no, n);
280 all_distances.data() + no * k * n,
281 all_labels.data() + no * k * n);
284 printf (
"end query shard %d\n", no);
// Execute the callable on the shards.
288 this->runOnIndex(fn);
// Per-shard label offsets, all zero by default.
290 std::vector<long> translations(nshard, 0);
// With successive_ids, shard s + 1 is offset by the running sum of the ntotal
// of shards 0..s.
294 if (successive_ids) {
297 for (
int s = 0; s + 1 < nshard; s++) {
298 translations[s + 1] = translations[s] + this->at(s)->ntotal;
// METRIC_L2 merges with the CMin comparator.
302 if (this->metric_type == METRIC_L2) {
303 merge_tables<IndexT, CMin<distance_t, int>>(
304 n, k, nshard, distances, labels,
305 all_distances, all_labels, translations);
// Second merge call uses CMax (presumably the else-branch for other metrics;
// the branch line itself is not shown).
307 merge_tables<IndexT, CMax<distance_t, int>>(
308 n, k, nshard, distances, labels,
309 all_distances, all_labels, translations);
// Explicit instantiations of IndexShardsTemplate for Index and IndexBinary.
314 template struct IndexShardsTemplate<Index>;
315 template struct IndexShardsTemplate<IndexBinary>;
long idx_t
all indices are this type
void train(idx_t n, const float *x) override
Trains the quantizer and calls train_residual to train sub-quantizers.
void add_with_ids(idx_t n, const component_t *x, const idx_t *xids) override
void add_with_ids(idx_t n, const float *x, const idx_t *xids) override
default implementation that calls encode_vectors
idx_t ntotal
total nb of indexed vectors
bool verbose
verbosity level
IndexShardsTemplate(bool threaded=false, bool successive_ids=true)
void add(idx_t n, const component_t *x) override
supported only for sub-indices that implement add_with_ids
void onAfterAddIndex(IndexT *index) override
Called just after an index is added.
MetricType metric_type
type of metric this index uses for search
void add(idx_t n, const float *x) override
Calls add_with_ids with NULL ids.
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
void onAfterRemoveIndex(IndexT *index) override
Called just after an index is removed.