Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
/data/users/hoss/faiss/Clustering.cpp
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 // -*- c++ -*-
9 
10 #include "Clustering.h"
11 #include "AuxIndexStructures.h"
12 
13 
14 #include <cmath>
15 #include <cstdio>
16 #include <cstring>
17 
18 #include "utils.h"
19 #include "FaissAssert.h"
20 #include "IndexFlat.h"
21 
22 namespace faiss {
23 
25  niter(25),
26  nredo(1),
27  verbose(false),
28  spherical(false),
29  int_centroids(false),
30  update_index(false),
31  frozen_centroids(false),
32  min_points_per_centroid(39),
33  max_points_per_centroid(256),
34  seed(1234)
35 {}
36 // 39 corresponds to 10000 / 256 -> to avoid warnings on PQ tests with randu10k
37 
38 
39 Clustering::Clustering (int d, int k):
40  d(d), k(k) {}
41 
42 Clustering::Clustering (int d, int k, const ClusteringParameters &cp):
43  ClusteringParameters (cp), d(d), k(k) {}
44 
45 
46 
/// Imbalance factor of a cluster assignment.
///
/// Computes k * sum_c(h_c^2) / (sum_c h_c)^2, where h_c is the number of
/// points assigned to centroid c. The result is 1 for perfectly balanced
/// clusters and k when every point falls into a single cluster.
///
/// @param n       number of assigned points
/// @param k       number of centroids; every assign[i] must lie in [0, k)
/// @param assign  centroid id of each of the n points
///
/// Counts are 64-bit so that large training sets (train() passes an idx_t
/// point count and a size_t centroid count) are not truncated to int.
static double imbalance_factor (long n, long k, const long *assign) {
    std::vector<long> hist(static_cast<size_t>(k), 0);
    for (long i = 0; i < n; i++)
        hist[assign[i]]++;

    double tot = 0, uf = 0;
    for (long h : hist) {
        tot += h;
        uf += h * static_cast<double>(h);
    }
    return uf * k / (tot * tot);
}
62 
64 {
65 
66  if (spherical) {
67  fvec_renorm_L2 (d, k, centroids.data());
68  }
69 
70  if (int_centroids) {
71  for (size_t i = 0; i < centroids.size(); i++)
72  centroids[i] = roundf (centroids[i]);
73  }
74 }
75 
76 
77 void Clustering::train (idx_t nx, const float *x_in, Index & index) {
78  FAISS_THROW_IF_NOT_FMT (nx >= k,
79  "Number of training points (%ld) should be at least "
80  "as large as number of clusters (%ld)", nx, k);
81 
82  double t0 = getmillisecs();
83 
84  // yes it is the user's responsibility, but it may spare us some
85  // hard-to-debug reports.
86  for (size_t i = 0; i < nx * d; i++) {
87  FAISS_THROW_IF_NOT_MSG (finite (x_in[i]),
88  "input contains NaN's or Inf's");
89  }
90 
91  const float *x = x_in;
93 
94  if (nx > k * max_points_per_centroid) {
95  if (verbose)
96  printf("Sampling a subset of %ld / %ld for training\n",
97  k * max_points_per_centroid, nx);
98  std::vector<int> perm (nx);
99  rand_perm (perm.data (), nx, seed);
100  nx = k * max_points_per_centroid;
101  float * x_new = new float [nx * d];
102  for (idx_t i = 0; i < nx; i++)
103  memcpy (x_new + i * d, x + perm[i] * d, sizeof(x_new[0]) * d);
104  x = x_new;
105  del1.set (x);
106  } else if (nx < k * min_points_per_centroid) {
107  fprintf (stderr,
108  "WARNING clustering %ld points to %ld centroids: "
109  "please provide at least %ld training points\n",
110  nx, k, idx_t(k) * min_points_per_centroid);
111  }
112 
113 
114  if (nx == k) {
115  if (verbose) {
116  printf("Number of training points (%ld) same as number of "
117  "clusters, just copying\n", nx);
118  }
119  // this is a corner case, just copy training set to clusters
120  centroids.resize (d * k);
121  memcpy (centroids.data(), x_in, sizeof (*x_in) * d * k);
122  index.reset();
123  index.add(k, x_in);
124  return;
125  }
126 
127 
128  if (verbose)
129  printf("Clustering %d points in %ldD to %ld clusters, "
130  "redo %d times, %d iterations\n",
131  int(nx), d, k, nredo, niter);
132 
133  idx_t * assign = new idx_t[nx];
134  ScopeDeleter<idx_t> del (assign);
135  float * dis = new float[nx];
136  ScopeDeleter<float> del2(dis);
137 
138  // for redo
139  float best_err = HUGE_VALF;
140  std::vector<float> best_obj;
141  std::vector<float> best_centroids;
142 
143  // support input centroids
144 
145  FAISS_THROW_IF_NOT_MSG (
146  centroids.size() % d == 0,
147  "size of provided input centroids not a multiple of dimension");
148 
149  size_t n_input_centroids = centroids.size() / d;
150 
151  if (verbose && n_input_centroids > 0) {
152  printf (" Using %zd centroids provided as input (%sfrozen)\n",
153  n_input_centroids, frozen_centroids ? "" : "not ");
154  }
155 
156  double t_search_tot = 0;
157  if (verbose) {
158  printf(" Preprocessing in %.2f s\n",
159  (getmillisecs() - t0) / 1000.);
160  }
161  t0 = getmillisecs();
162 
163  for (int redo = 0; redo < nredo; redo++) {
164 
165  if (verbose && nredo > 1) {
166  printf("Outer iteration %d / %d\n", redo, nredo);
167  }
168 
169  // initialize remaining centroids with random points from the dataset
170  centroids.resize (d * k);
171  std::vector<int> perm (nx);
172 
173  rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);
174  for (int i = n_input_centroids; i < k ; i++)
175  memcpy (&centroids[i * d], x + perm[i] * d,
176  d * sizeof (float));
177 
179 
180  if (index.ntotal != 0) {
181  index.reset();
182  }
183 
184  if (!index.is_trained) {
185  index.train (k, centroids.data());
186  }
187 
188  index.add (k, centroids.data());
189  float err = 0;
190  for (int i = 0; i < niter; i++) {
191  double t0s = getmillisecs();
192  index.search (nx, x, 1, dis, assign);
194  t_search_tot += getmillisecs() - t0s;
195 
196  err = 0;
197  for (int j = 0; j < nx; j++)
198  err += dis[j];
199  obj.push_back (err);
200 
201  int nsplit = km_update_centroids (
202  x, centroids.data(),
203  assign, d, k, nx, frozen_centroids ? n_input_centroids : 0);
204 
205  if (verbose) {
206  printf (" Iteration %d (%.2f s, search %.2f s): "
207  "objective=%g imbalance=%.3f nsplit=%d \r",
208  i, (getmillisecs() - t0) / 1000.0,
209  t_search_tot / 1000,
210  err, imbalance_factor (nx, k, assign),
211  nsplit);
212  fflush (stdout);
213  }
214 
216 
217  index.reset ();
218  if (update_index)
219  index.train (k, centroids.data());
220 
221  assert (index.ntotal == 0);
222  index.add (k, centroids.data());
224  }
225  if (verbose) printf("\n");
226  if (nredo > 1) {
227  if (err < best_err) {
228  if (verbose)
229  printf ("Objective improved: keep new clusters\n");
230  best_centroids = centroids;
231  best_obj = obj;
232  best_err = err;
233  }
234  index.reset ();
235  }
236  }
237  if (nredo > 1) {
238  centroids = best_centroids;
239  obj = best_obj;
240  index.reset();
241  index.add(k, best_centroids.data());
242  }
243 
244 }
245 
246 float kmeans_clustering (size_t d, size_t n, size_t k,
247  const float *x,
248  float *centroids)
249 {
250  Clustering clus (d, k);
251  clus.verbose = d * n * k > (1L << 30);
252  // display logs if > 1Gflop per iteration
253  IndexFlatL2 index (d);
254  clus.train (n, x, index);
255  memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
256  return clus.obj.back();
257 }
258 
259 } // namespace faiss
int km_update_centroids(const float *x, float *centroids, long *assign, size_t d, size_t k, size_t n, size_t k_frozen)
Definition: utils.cpp:1078
int niter
clustering iterations
Definition: Clustering.h:23
int nredo
redo clustering this many times and keep best
Definition: Clustering.h:24
ClusteringParameters()
sets reasonable defaults
Definition: Clustering.cpp:24
virtual void reset()=0
removes all elements from the database.
Clustering(int d, int k)
the only mandatory parameters are k and d
Definition: Clustering.cpp:39
virtual void train(idx_t n, const float *x)
Definition: Index.cpp:23
size_t k
nb of centroids
Definition: Clustering.h:59
int seed
seed for the random number generator
Definition: Clustering.h:35
bool frozen_centroids
use the centroids provided as input and do not change them during iterations
Definition: Clustering.h:30
int min_points_per_centroid
minimum number of training points per centroid; with fewer than k * min_points_per_centroid points, train() prints a warning
Definition: Clustering.h:32
virtual void add(idx_t n, const float *x)=0
void post_process_centroids()
Definition: Clustering.cpp:63
std::vector< float > obj
objective value (sum of the assignment distances) recorded at each k-means iteration
Definition: Clustering.h:66
float kmeans_clustering(size_t d, size_t n, size_t k, const float *x, float *centroids)
Definition: Clustering.cpp:246
idx_t ntotal
total nb of indexed vectors
Definition: Index.h:67
double getmillisecs()
ms elapsed since some arbitrary epoch
Definition: utils.cpp:69
std::vector< float > centroids
centroids (k * d)
Definition: Clustering.h:62
size_t d
dimension of the vectors
Definition: Clustering.h:58
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const =0
bool update_index
update index after each iteration?
Definition: Clustering.h:29
bool int_centroids
round centroids coordinates to integer
Definition: Clustering.h:28
virtual void train(idx_t n, const float *x, faiss::Index &index)
Index is used during the assignment stage.
Definition: Clustering.cpp:77
bool is_trained
set if the Index does not require training, or if training is done already
Definition: Index.h:71
bool spherical
do we want normalized centroids?
Definition: Clustering.h:27
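int max_points_per_centroid
to limit the size of the training set: with more than k * max_points_per_centroid points, train() uses a random subset of that size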
Definition: Clustering.h:33