Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
/data/users/matthijs/github_faiss/faiss/Clustering.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the CC-by-NC license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 /* Copyright 2004-present Facebook. All Rights Reserved.
10  kmeans clustering routines
11 */
12 
13 #include "Clustering.h"
14 
15 
16 
17 #include <cmath>
18 #include <cstdio>
19 #include <cstring>
20 
21 #include "utils.h"
22 #include "FaissAssert.h"
23 #include "IndexFlat.h"
24 
25 namespace faiss {
26 
28  niter(25),
29  nredo(1),
30  verbose(false), spherical(false),
31  update_index(false),
32  min_points_per_centroid(39),
33  max_points_per_centroid(256),
34  seed(1234)
35 {}
36 // 39 corresponds to 10000 / 256 -> to avoid warnings on PQ tests with randu10k
37 
38 
39 Clustering::Clustering (int d, int k):
40  d(d), k(k) {}
41 
42 Clustering::Clustering (int d, int k, const ClusteringParameters &cp):
43  ClusteringParameters (cp), d(d), k(k) {}
44 
45 
46 
/** Measures how unevenly n points are spread over k clusters.
 *
 * The value is k * sum_i(h_i^2) / (sum_i h_i)^2 where h_i is the number of
 * points assigned to cluster i: 1 for a perfectly balanced assignment, k
 * when all points fall in a single cluster.
 *
 * @param n       number of entries in assign
 * @param k       number of clusters
 * @param assign  cluster label of each point; labels outside [0, k) are
 *                ignored instead of being counted
 * @return the imbalance factor, or 1.0 when there is nothing to count
 *         (n <= 0, k <= 0, or no label in range)
 */
static double imbalance_factor (int n, int k, const long *assign) {
    if (n <= 0 || k <= 0)
        return 1.0;

    std::vector<int> hist(k, 0);
    for (int i = 0; i < n; i++) {
        const long label = assign[i];
        // an out-of-range label would index outside of the histogram
        if (label >= 0 && label < k)
            hist[label]++;
    }

    double tot = 0, uf = 0;

    for (int count : hist) {
        tot += count;
        uf += count * (double) count;
    }

    // avoid 0 / 0 when no point carried a valid label
    if (tot == 0)
        return 1.0;

    return uf * k / (tot * tot);
}
62 
63 
64 
65 
66 void Clustering::train (idx_t nx, const float *x_in, Index & index) {
67  FAISS_THROW_IF_NOT_MSG (nx >= k,
68  "need at least as many training points as clusters");
69 
70  double t0 = getmillisecs();
71 
72  // yes it is the user's responsibility, but it may spare us some
73  // hard-to-debug reports.
74  for (size_t i = 0; i < nx * d; i++) {
75  FAISS_THROW_IF_NOT_MSG (finite (x_in[i]),
76  "input contains NaN's or Inf's");
77  }
78 
79  const float *x = x_in;
81 
82  if (nx > k * max_points_per_centroid) {
83  if (verbose)
84  printf("Sampling a subset of %ld / %ld for training\n",
85  k * max_points_per_centroid, nx);
86  std::vector<int> perm (nx);
87  rand_perm (perm.data (), nx, seed);
88  nx = k * max_points_per_centroid;
89  float * x_new = new float [nx * d];
90  for (idx_t i = 0; i < nx; i++)
91  memcpy (x_new + i * d, x + perm[i] * d, sizeof(x_new[0]) * d);
92  x = x_new;
93  del1.set (x);
94  } else if (nx < k * min_points_per_centroid) {
95  fprintf (stderr,
96  "WARNING clustering %ld points to %ld centroids: "
97  "please provide at least %ld training points\n",
98  nx, k, idx_t(k) * min_points_per_centroid);
99  }
100 
101 
102  if (verbose)
103  printf("Clustering %d points in %ldD to %ld clusters, "
104  "redo %d times, %d iterations\n",
105  int(nx), d, k, nredo, niter);
106 
107 
108  idx_t * assign = new idx_t[nx];
109  ScopeDeleter<idx_t> del (assign);
110  float * dis = new float[nx];
111  ScopeDeleter<float> del2(dis);
112 
113  float best_err = 1e50;
114  double t_search_tot = 0;
115  if (verbose) {
116  printf(" Preprocessing in %.2f s\n",
117  (getmillisecs() - t0)/1000.);
118  }
119  t0 = getmillisecs();
120 
121  for (int redo = 0; redo < nredo; redo++) {
122 
123  std::vector<float> buf_centroids;
124 
125  std::vector<float> &cur_centroids =
126  nredo == 1 ? centroids : buf_centroids;
127 
128  if (verbose && nredo > 1) {
129  printf("Outer iteration %d / %d\n", redo, nredo);
130  }
131 
132  if (cur_centroids.size() == 0) {
133  // initialize centroids with random points from the dataset
134  cur_centroids.resize (d * k);
135  std::vector<int> perm (nx);
136 
137  rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);
138 #pragma omp parallel for
139  for (int i = 0; i < k ; i++)
140  memcpy (&cur_centroids[i * d], x + perm[i] * d,
141  d * sizeof (float));
142  } else { // assume user provides some meaningful initialization
143  FAISS_THROW_IF_NOT (cur_centroids.size() == d * k);
144  FAISS_THROW_IF_NOT_MSG (nredo == 1,
145  "will redo with same initialization");
146  }
147 
148  if (spherical)
149  fvec_renorm_L2 (d, k, cur_centroids.data());
150 
151  if (!index.is_trained)
152  index.train (k, cur_centroids.data());
153 
154  FAISS_THROW_IF_NOT (index.ntotal == 0);
155  index.add (k, cur_centroids.data());
156  float err = 0;
157  for (int i = 0; i < niter; i++) {
158  double t0s = getmillisecs();
159  index.search (nx, x, 1, dis, assign);
160  t_search_tot += getmillisecs() - t0s;
161 
162  err = 0;
163  for (int j = 0; j < nx; j++)
164  err += dis[j];
165  obj.push_back (err);
166 
167  int nsplit = km_update_centroids (x, cur_centroids.data(),
168  assign, d, k, nx);
169 
170  if (verbose) {
171  printf (" Iteration %d (%.2f s, search %.2f s): "
172  "objective=%g imbalance=%.3f nsplit=%d \r",
173  i, (getmillisecs() - t0) / 1000.0,
174  t_search_tot / 1000,
175  err, imbalance_factor (nx, k, assign),
176  nsplit);
177  fflush (stdout);
178  }
179 
180  if (spherical)
181  fvec_renorm_L2 (d, k, cur_centroids.data());
182 
183  index.reset ();
184  if (update_index)
185  index.train (k, cur_centroids.data());
186 
187  assert (index.ntotal == 0);
188  index.add (k, cur_centroids.data());
189  }
190  if (verbose) printf("\n");
191  if (nredo > 1) {
192  if (err < best_err) {
193  if (verbose)
194  printf ("Objective improved: keep new clusters\n");
195  centroids = buf_centroids;
196  best_err = err;
197  }
198  index.reset ();
199  }
200  }
201 
202 }
203 
204 float kmeans_clustering (size_t d, size_t n, size_t k,
205  const float *x,
206  float *centroids)
207 {
208  Clustering clus (d, k);
209  clus.verbose = d * n * k > (1L << 30);
210  // display logs if > 1Gflop per iteration
211  IndexFlatL2 index (d);
212  clus.train (n, x, index);
213  memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
214  return clus.obj.back();
215 }
216 
217 } // namespace faiss
int niter
clustering iterations
Definition: Clustering.h:25
int km_update_centroids(const float *x, float *centroids, long *assign, size_t d, size_t k, size_t n)
Definition: utils.cpp:1369
int nredo
redo clustering this many times and keep best
Definition: Clustering.h:26
ClusteringParameters()
sets reasonable defaults
Definition: Clustering.cpp:27
virtual void reset()=0
removes all elements from the database.
Clustering(int d, int k)
the only mandatory parameters are k and d
Definition: Clustering.cpp:39
size_t k
nb of centroids
Definition: Clustering.h:59
int seed
seed for the random number generator
Definition: Clustering.h:35
int min_points_per_centroid
otherwise you get a warning
Definition: Clustering.h:32
virtual void add(idx_t n, const float *x)=0
std::vector< float > obj
Definition: Clustering.h:66
float kmeans_clustering(size_t d, size_t n, size_t k, const float *x, float *centroids)
Definition: Clustering.cpp:204
idx_t ntotal
total nb of indexed vectors
Definition: Index.h:65
double getmillisecs()
ms elapsed since some arbitrary epoch
Definition: utils.cpp:70
std::vector< float > centroids
centroids (k * d)
Definition: Clustering.h:62
size_t d
dimension of the vectors
Definition: Clustering.h:58
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const =0
bool update_index
update index after each iteration?
Definition: Clustering.h:30
virtual void train(idx_t n, const float *x, faiss::Index &index)
Index is used during the assignment stage.
Definition: Clustering.cpp:66
bool is_trained
set if the Index does not require training, or if training is done already
Definition: Index.h:69
virtual void train(idx_t n, const float *x)
Definition: Index.h:89
bool spherical
do we want normalized centroids?
Definition: Clustering.h:29
int max_points_per_centroid
to limit size of dataset
Definition: Clustering.h:33