Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
/data/users/matthijs/github_faiss/faiss/Clustering.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 /* Copyright 2004-present Facebook. All Rights Reserved.
10  kmeans clustering routines
11 */
12 
13 #include "Clustering.h"
14 
15 
16 
17 #include <cmath>
18 #include <cstdio>
19 #include <cstring>
20 
21 #include "utils.h"
22 #include "FaissAssert.h"
23 #include "IndexFlat.h"
24 
25 namespace faiss {
26 
28  niter(25),
29  nredo(1),
30  verbose(false), spherical(false),
31  update_index(false),
32  frozen_centroids(false),
33  min_points_per_centroid(39),
34  max_points_per_centroid(256),
35  seed(1234)
36 {}
37 // 39 corresponds to 10000 / 256 -> to avoid warnings on PQ tests with randu10k
38 
39 
40 Clustering::Clustering (int d, int k):
41  d(d), k(k) {}
42 
43 Clustering::Clustering (int d, int k, const ClusteringParameters &cp):
44  ClusteringParameters (cp), d(d), k(k) {}
45 
46 
47 
/* Imbalance factor of an assignment of n points to k clusters:
 *
 *     k * sum_c h_c^2 / (sum_c h_c)^2
 *
 * where h_c is the number of points assigned to cluster c. It equals 1 for
 * a perfectly balanced assignment and reaches k when every point falls into
 * a single cluster. assign[0..n-1] must hold cluster ids in [0, k), since
 * they are used directly as histogram indices.
 */
static double imbalance_factor (int n, int k, long *assign) {
    std::vector<int> sizes (k, 0);
    for (int i = 0; i < n; ++i) {
        ++sizes[assign[i]];
    }

    double total = 0;
    double sum_sq = 0;
    for (int c = 0; c < k; ++c) {
        const double h = sizes[c];
        total += h;
        sum_sq += h * h;
    }
    return sum_sq * k / (total * total);
}
63 
64 
65 
66 
67 void Clustering::train (idx_t nx, const float *x_in, Index & index) {
68  FAISS_THROW_IF_NOT_FMT (nx >= k,
69  "Number of training points (%ld) should be at least "
70  "as large as number of clusters (%ld)", nx, k);
71 
72  double t0 = getmillisecs();
73 
74  // yes it is the user's responsibility, but it may spare us some
75  // hard-to-debug reports.
76  for (size_t i = 0; i < nx * d; i++) {
77  FAISS_THROW_IF_NOT_MSG (finite (x_in[i]),
78  "input contains NaN's or Inf's");
79  }
80 
81  const float *x = x_in;
83 
84  if (nx > k * max_points_per_centroid) {
85  if (verbose)
86  printf("Sampling a subset of %ld / %ld for training\n",
87  k * max_points_per_centroid, nx);
88  std::vector<int> perm (nx);
89  rand_perm (perm.data (), nx, seed);
90  nx = k * max_points_per_centroid;
91  float * x_new = new float [nx * d];
92  for (idx_t i = 0; i < nx; i++)
93  memcpy (x_new + i * d, x + perm[i] * d, sizeof(x_new[0]) * d);
94  x = x_new;
95  del1.set (x);
96  } else if (nx < k * min_points_per_centroid) {
97  fprintf (stderr,
98  "WARNING clustering %ld points to %ld centroids: "
99  "please provide at least %ld training points\n",
100  nx, k, idx_t(k) * min_points_per_centroid);
101  }
102 
103 
104  if (nx == k) {
105  if (verbose) {
106  printf("Number of training points (%ld) same as number of "
107  "clusters, just copying\n", nx);
108  }
109  // this is a corner case, just copy training set to clusters
110  centroids.resize (d * k);
111  memcpy (centroids.data(), x_in, sizeof (*x_in) * d * k);
112  index.reset();
113  index.add(k, x_in);
114  return;
115  }
116 
117 
118  if (verbose)
119  printf("Clustering %d points in %ldD to %ld clusters, "
120  "redo %d times, %d iterations\n",
121  int(nx), d, k, nredo, niter);
122 
123 
124 
125 
126  idx_t * assign = new idx_t[nx];
127  ScopeDeleter<idx_t> del (assign);
128  float * dis = new float[nx];
129  ScopeDeleter<float> del2(dis);
130 
131  // for redo
132  float best_err = HUGE_VALF;
133  std::vector<float> best_obj;
134  std::vector<float> best_centroids;
135 
136  // support input centroids
137 
138  FAISS_THROW_IF_NOT_MSG (
139  centroids.size() % d == 0,
140  "size of provided input centroids not a multiple of dimension");
141 
142  size_t n_input_centroids = centroids.size() / d;
143 
144  if (verbose && n_input_centroids > 0) {
145  printf (" Using %zd centroids provided as input (%sfrozen)\n",
146  n_input_centroids, frozen_centroids ? "" : "not ");
147  }
148 
149  double t_search_tot = 0;
150  if (verbose) {
151  printf(" Preprocessing in %.2f s\n",
152  (getmillisecs() - t0)/1000.);
153  }
154  t0 = getmillisecs();
155 
156  for (int redo = 0; redo < nredo; redo++) {
157 
158  if (verbose && nredo > 1) {
159  printf("Outer iteration %d / %d\n", redo, nredo);
160  }
161 
162 
163  // initialize remaining centroids with random points from the dataset
164  centroids.resize (d * k);
165  std::vector<int> perm (nx);
166 
167  rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);
168  for (int i = n_input_centroids; i < k ; i++)
169  memcpy (&centroids[i * d], x + perm[i] * d,
170  d * sizeof (float));
171 
172  if (spherical) {
173  fvec_renorm_L2 (d, k, centroids.data());
174  }
175 
176  if (index.ntotal != 0) {
177  index.reset();
178  }
179 
180  if (!index.is_trained) {
181  index.train (k, centroids.data());
182  }
183 
184  index.add (k, centroids.data());
185  float err = 0;
186  for (int i = 0; i < niter; i++) {
187  double t0s = getmillisecs();
188  index.search (nx, x, 1, dis, assign);
189  t_search_tot += getmillisecs() - t0s;
190 
191  err = 0;
192  for (int j = 0; j < nx; j++)
193  err += dis[j];
194  obj.push_back (err);
195 
196  int nsplit = km_update_centroids (
197  x, centroids.data(),
198  assign, d, k, nx, frozen_centroids ? n_input_centroids : 0);
199 
200  if (verbose) {
201  printf (" Iteration %d (%.2f s, search %.2f s): "
202  "objective=%g imbalance=%.3f nsplit=%d \r",
203  i, (getmillisecs() - t0) / 1000.0,
204  t_search_tot / 1000,
205  err, imbalance_factor (nx, k, assign),
206  nsplit);
207  fflush (stdout);
208  }
209 
210  if (spherical)
211  fvec_renorm_L2 (d, k, centroids.data());
212 
213  index.reset ();
214  if (update_index)
215  index.train (k, centroids.data());
216 
217  assert (index.ntotal == 0);
218  index.add (k, centroids.data());
219  }
220  if (verbose) printf("\n");
221  if (nredo > 1) {
222  if (err < best_err) {
223  if (verbose)
224  printf ("Objective improved: keep new clusters\n");
225  best_centroids = centroids;
226  best_obj = obj;
227  best_err = err;
228  }
229  index.reset ();
230  }
231  }
232  if (nredo > 1) {
233  centroids = best_centroids;
234  obj = best_obj;
235  index.reset();
236  index.add(k, best_centroids.data());
237  }
238 
239 }
240 
241 float kmeans_clustering (size_t d, size_t n, size_t k,
242  const float *x,
243  float *centroids)
244 {
245  Clustering clus (d, k);
246  clus.verbose = d * n * k > (1L << 30);
247  // display logs if > 1Gflop per iteration
248  IndexFlatL2 index (d);
249  clus.train (n, x, index);
250  memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
251  return clus.obj.back();
252 }
253 
254 } // namespace faiss
int km_update_centroids(const float *x, float *centroids, long *assign, size_t d, size_t k, size_t n, size_t k_frozen)
Definition: utils.cpp:1401
int niter
number of clustering iterations
Definition: Clustering.h:25
int nredo
redo clustering this many times and keep the run with the lowest objective
Definition: Clustering.h:26
ClusteringParameters()
sets reasonable defaults
Definition: Clustering.cpp:27
virtual void reset()=0
removes all elements from the database.
Clustering(int d, int k)
the only mandatory parameters are k and d
Definition: Clustering.cpp:40
virtual void train(idx_t n, const float *x)
Definition: Index.cpp:23
size_t k
number of centroids
Definition: Clustering.h:60
int seed
seed for the random number generator
Definition: Clustering.h:36
bool frozen_centroids
use the centroids provided as input and do not change them during iterations
Definition: Clustering.h:31
int min_points_per_centroid
minimum number of training points per centroid; with fewer points a warning is printed
Definition: Clustering.h:33
virtual void add(idx_t n, const float *x)=0
std::vector< float > obj
objective value (sum of the distances reported by the index) at each iteration
Definition: Clustering.h:67
float kmeans_clustering(size_t d, size_t n, size_t k, const float *x, float *centroids)
Definition: Clustering.cpp:241
idx_t ntotal
total number of indexed vectors
Definition: Index.h:65
double getmillisecs()
ms elapsed since some arbitrary epoch
Definition: utils.cpp:74
std::vector< float > centroids
centroids (k * d)
Definition: Clustering.h:63
size_t d
dimension of the vectors
Definition: Clustering.h:59
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const =0
bool update_index
update index after each iteration?
Definition: Clustering.h:30
virtual void train(idx_t n, const float *x, faiss::Index &index)
Index is used during the assignment stage.
Definition: Clustering.cpp:67
bool is_trained
set if the Index does not require training, or if training is done already
Definition: Index.h:69
bool spherical
do we want normalized centroids?
Definition: Clustering.h:29
int max_points_per_centroid
maximum number of training points per centroid; larger training sets are randomly subsampled to k times this size
Definition: Clustering.h:34