Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
/tmp/faiss/Heap.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // -*- c++ -*-
10 
11 /* Function for soft heap */
12 
13 #include "Heap.h"
14 
15 
16 namespace faiss {
17 
18 
19 template <typename C>
21 {
22 #pragma omp parallel for
23  for (size_t j = 0; j < nh; j++)
24  heap_heapify<C> (k, val + j * k, ids + j * k);
25 }
26 
27 template <typename C>
29 {
30 #pragma omp parallel for
31  for (size_t j = 0; j < nh; j++)
32  heap_reorder<C> (k, val + j * k, ids + j * k);
33 }
34 
35 template <typename C>
36 void HeapArray<C>::addn (size_t nj, const T *vin, TI j0,
37  size_t i0, long ni)
38 {
39  if (ni == -1) ni = nh;
40  assert (i0 >= 0 && i0 + ni <= nh);
41 #pragma omp parallel for
42  for (size_t i = i0; i < i0 + ni; i++) {
43  T * __restrict simi = get_val(i);
44  TI * __restrict idxi = get_ids (i);
45  const T *ip_line = vin + (i - i0) * nj;
46 
47  for (size_t j = 0; j < nj; j++) {
48  T ip = ip_line [j];
49  if (C::cmp(simi[0], ip)) {
50  heap_pop<C> (k, simi, idxi);
51  heap_push<C> (k, simi, idxi, ip, j + j0);
52  }
53  }
54  }
55 }
56 
57 template <typename C>
59  size_t nj, const T *vin, const TI *id_in,
60  long id_stride, size_t i0, long ni)
61 {
62  if (id_in == nullptr) {
63  addn (nj, vin, 0, i0, ni);
64  return;
65  }
66  if (ni == -1) ni = nh;
67  assert (i0 >= 0 && i0 + ni <= nh);
68 #pragma omp parallel for
69  for (size_t i = i0; i < i0 + ni; i++) {
70  T * __restrict simi = get_val(i);
71  TI * __restrict idxi = get_ids (i);
72  const T *ip_line = vin + (i - i0) * nj;
73  const TI *id_line = id_in + (i - i0) * id_stride;
74 
75  for (size_t j = 0; j < nj; j++) {
76  T ip = ip_line [j];
77  if (C::cmp(simi[0], ip)) {
78  heap_pop<C> (k, simi, idxi);
79  heap_push<C> (k, simi, idxi, ip, id_line [j]);
80  }
81  }
82  }
83 }
84 
85 template <typename C>
87  T * out_val,
88  TI * out_ids) const
89 {
90 #pragma omp parallel for
91  for (size_t j = 0; j < nh; j++) {
92  long imin = -1;
93  typename C::T xval = C::Crev::neutral ();
94  const typename C::T * x_ = val + j * k;
95  for (size_t i = 0; i < k; i++)
96  if (C::cmp (x_[i], xval)) {
97  xval = x_[i];
98  imin = i;
99  }
100  if (out_val)
101  out_val[j] = xval;
102 
103  if (out_ids) {
104  if (ids && imin != -1)
105  out_ids[j] = ids [j * k + imin];
106  else
107  out_ids[j] = imin;
108  }
109  }
110 }
111 
112 
113 
114 
115 // explicit instanciations
116 
117 template struct HeapArray<CMin <float, long> >;
118 template struct HeapArray<CMax <float, long> >;
119 template struct HeapArray<CMin <int, long> >;
120 template struct HeapArray<CMax <int, long> >;
121 
122 
123 } // END namespace fasis
void reorder()
reorder all the heaps
Definition: Heap.cpp:28
void per_line_extrema(T *vals_out, TI *idx_out) const
Definition: Heap.cpp:86
void addn_with_ids(size_t nj, const T *vin, const TI *id_in=nullptr, long id_stride=0, size_t i0=0, long ni=-1)
Definition: Heap.cpp:58
void heapify()
prepare all the heaps before adding
Definition: Heap.cpp:20
void addn(size_t nj, const T *vin, TI j0=0, size_t i0=0, long ni=-1)
Definition: Heap.cpp:36