Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
/data/users/hoss/faiss/Heap.cpp
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 // -*- c++ -*-
9 
10 /* Batched operations on arrays of heaps (HeapArray<C>) */
11 
12 #include "Heap.h"
13 
14 
15 namespace faiss {
16 
17 
18 template <typename C>
20 {
21 #pragma omp parallel for
22  for (size_t j = 0; j < nh; j++)
23  heap_heapify<C> (k, val + j * k, ids + j * k);
24 }
25 
26 template <typename C>
28 {
29 #pragma omp parallel for
30  for (size_t j = 0; j < nh; j++)
31  heap_reorder<C> (k, val + j * k, ids + j * k);
32 }
33 
34 template <typename C>
35 void HeapArray<C>::addn (size_t nj, const T *vin, TI j0,
36  size_t i0, long ni)
37 {
38  if (ni == -1) ni = nh;
39  assert (i0 >= 0 && i0 + ni <= nh);
40 #pragma omp parallel for
41  for (size_t i = i0; i < i0 + ni; i++) {
42  T * __restrict simi = get_val(i);
43  TI * __restrict idxi = get_ids (i);
44  const T *ip_line = vin + (i - i0) * nj;
45 
46  for (size_t j = 0; j < nj; j++) {
47  T ip = ip_line [j];
48  if (C::cmp(simi[0], ip)) {
49  heap_pop<C> (k, simi, idxi);
50  heap_push<C> (k, simi, idxi, ip, j + j0);
51  }
52  }
53  }
54 }
55 
56 template <typename C>
58  size_t nj, const T *vin, const TI *id_in,
59  long id_stride, size_t i0, long ni)
60 {
61  if (id_in == nullptr) {
62  addn (nj, vin, 0, i0, ni);
63  return;
64  }
65  if (ni == -1) ni = nh;
66  assert (i0 >= 0 && i0 + ni <= nh);
67 #pragma omp parallel for
68  for (size_t i = i0; i < i0 + ni; i++) {
69  T * __restrict simi = get_val(i);
70  TI * __restrict idxi = get_ids (i);
71  const T *ip_line = vin + (i - i0) * nj;
72  const TI *id_line = id_in + (i - i0) * id_stride;
73 
74  for (size_t j = 0; j < nj; j++) {
75  T ip = ip_line [j];
76  if (C::cmp(simi[0], ip)) {
77  heap_pop<C> (k, simi, idxi);
78  heap_push<C> (k, simi, idxi, ip, id_line [j]);
79  }
80  }
81  }
82 }
83 
84 template <typename C>
86  T * out_val,
87  TI * out_ids) const
88 {
89 #pragma omp parallel for
90  for (size_t j = 0; j < nh; j++) {
91  long imin = -1;
92  typename C::T xval = C::Crev::neutral ();
93  const typename C::T * x_ = val + j * k;
94  for (size_t i = 0; i < k; i++)
95  if (C::cmp (x_[i], xval)) {
96  xval = x_[i];
97  imin = i;
98  }
99  if (out_val)
100  out_val[j] = xval;
101 
102  if (out_ids) {
103  if (ids && imin != -1)
104  out_ids[j] = ids [j * k + imin];
105  else
106  out_ids[j] = imin;
107  }
108  }
109 }
110 
111 
112 
113 
114 // explicit instanciations
115 
116 template struct HeapArray<CMin <float, long> >;
117 template struct HeapArray<CMax <float, long> >;
118 template struct HeapArray<CMin <int, long> >;
119 template struct HeapArray<CMax <int, long> >;
120 
121 
122 } // END namespace fasis
void reorder()
reorder all the heaps
Definition: Heap.cpp:27
void per_line_extrema(T *vals_out, TI *idx_out) const
Definition: Heap.cpp:85
void addn_with_ids(size_t nj, const T *vin, const TI *id_in=nullptr, long id_stride=0, size_t i0=0, long ni=-1)
Definition: Heap.cpp:57
void heapify()
prepare all the heaps before adding
Definition: Heap.cpp:19
void addn(size_t nj, const T *vin, TI j0=0, size_t i0=0, long ni=-1)
Definition: Heap.cpp:35