/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the CC-by-NC license found in the
 * LICENSE file in the root directory of this source tree.
 */

/* Copyright 2004-present Facebook. All Rights Reserved.
   kmeans clustering routines
 */

#include "Clustering.h"

#include <cmath>
#include <cstdio>
#include <cstring>
#include <vector>

#include "utils.h"
#include "FaissAssert.h"
#include "IndexFlat.h"

namespace faiss {

ClusteringParameters::ClusteringParameters ():
    niter(25),
    nredo(1),
    verbose(false), spherical(false),
    update_index(false),
    min_points_per_centroid(39),
    max_points_per_centroid(256),
    seed(1234)
{}
// 39 corresponds to 10000 / 256 -> to avoid warnings on PQ tests with randu10k


Clustering::Clustering (int d, int k):
    d(d), k(k) {}

Clustering::Clustering (int d, int k, const ClusteringParameters &cp):
    ClusteringParameters (cp), d(d), k(k) {}


/** Imbalance factor of an assignment, computed as
 *  k * sum_c(hist[c]^2) / (sum_c hist[c])^2
 * where hist[c] is the number of points assigned to centroid c.
 * It equals 1 when every centroid receives the same number of points and
 * reaches k when all points go to a single centroid.
 *
 * @param n       number of assigned points
 * @param k       number of centroids
 * @param assign  centroid id of each point (size n, values in [0, k))
 */
static double imbalance_factor (long n, long k, const long *assign) {
    // long counts: a single cluster may hold more than INT_MAX points
    std::vector<long> hist (k, 0);
    for (long i = 0; i < n; i++)
        hist[assign[i]]++;

    double tot = 0, uf = 0;
    for (long i = 0 ; i < k ; i++) {
        tot += hist[i];
        uf += hist[i] * (double) hist[i];
    }
    uf = uf * k / (tot * tot);
    return uf;
}


/** Run k-means on nx vectors of dimension d, using `index` for the
 * nearest-centroid assignment step.
 *
 * The input is subsampled to at most k * max_points_per_centroid points.
 * A warning is printed when there are fewer than
 * k * min_points_per_centroid points.
 *
 * With nredo == 1, this->centroids is used in place, so a non-empty
 * this->centroids acts as a user-provided initialization. With nredo > 1,
 * each run starts from a fresh random initialization, and the run with the
 * lowest final objective is kept in this->centroids.
 *
 * Per-iteration objective values (sum of the distances returned by
 * index.search) are appended to this->obj for every run.
 *
 * On return, `index` contains the k final centroids.
 *
 * @param nx    number of training vectors (must be >= k)
 * @param x_in  training vectors, size nx * d; must be finite
 * @param index assignment index; must be empty (ntotal == 0) on entry
 */
void Clustering::train (idx_t nx, const float *x_in, Index & index) {
    FAISS_ASSERT (nx >= k ||
                  !"need at least as many training points as clusters");

    double t0 = getmillisecs();

    // yes it is the user's responsibility, but it may spare us some
    // hard-to-debug reports.
    for (size_t i = 0; i < nx * d; i++) {
        FAISS_ASSERT (std::isfinite (x_in[i]) ||
                      !"input contains NaN's or Inf's");
    }

    const float *x = x_in;
    // owns the subsampled training set when subsampling happens;
    // x then points into it (released automatically on return)
    std::vector<float> x_sub;

    const idx_t max_nx = idx_t(k) * max_points_per_centroid;
    const idx_t min_nx = idx_t(k) * min_points_per_centroid;

    if (nx > max_nx) {
        if (verbose)
            printf("Sampling a subset of %ld / %ld for training\n",
                   max_nx, nx);
        std::vector<int> perm (nx);
        rand_perm (perm.data(), nx, seed);
        nx = max_nx;
        x_sub.resize (nx * d);
        for (idx_t i = 0; i < nx; i++)
            memcpy (x_sub.data() + i * d, x + perm[i] * d,
                    sizeof(x_sub[0]) * d);
        x = x_sub.data();
    } else if (nx < min_nx) {
        fprintf (stderr,
                 "WARNING clustering %ld points to %ld centroids: "
                 "please provide at least %ld training points\n",
                 nx, idx_t(k), min_nx);
    }

    if (verbose)
        printf("Clustering %ld points in %ldD to %ld clusters, "
               "redo %d times, %d iterations\n",
               nx, idx_t(d), idx_t(k), nredo, niter);

    std::vector<idx_t> assign (nx);
    std::vector<float> dis (nx);

    // best final objective over the runs; HUGE_VALF is representable in a
    // float (unlike the former 1e50, whose conversion was undefined)
    float best_err = HUGE_VALF;
    double t_search_tot = 0;
    if (verbose) {
        printf(" Preprocessing in %5g s\n",
               (getmillisecs() - t0)/1000.);
    }
    t0 = getmillisecs();

    for (int redo = 0; redo < nredo; redo++) {

        // single run: work directly in this->centroids;
        // several runs: work in a scratch buffer and keep only the best
        std::vector<float> buf_centroids;
        std::vector<float> &cur_centroids =
            nredo == 1 ? centroids : buf_centroids;

        if (verbose && nredo > 1) {
            printf("Outer iteration %d / %d\n", redo, nredo);
        }

        if (cur_centroids.size() == 0) {
            // initialize centroids with random points from the dataset
            cur_centroids.resize (d * k);
            std::vector<int> perm (nx);
            rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);
#pragma omp parallel for
            for (idx_t i = 0; i < idx_t(k); i++)
                memcpy (&cur_centroids[i * d], x + perm[i] * d,
                        d * sizeof (float));
        } else {
            // assume user provides some meaningful initialization
            FAISS_ASSERT (cur_centroids.size() == d * k);
            FAISS_ASSERT (nredo == 1 ||
                          !"will redo with same initialization");
        }

        if (spherical)
            fvec_renorm_L2 (d, k, cur_centroids.data());

        if (!index.is_trained)
            index.train (k, cur_centroids.data());

        FAISS_ASSERT (index.ntotal == 0);
        index.add (k, cur_centroids.data());

        float err = 0;
        for (int i = 0; i < niter; i++) {
            double t0s = getmillisecs();
            index.search (nx, x, 1, dis.data(), assign.data());
            t_search_tot += getmillisecs() - t0s;

            // accumulate in double: summing many floats in a float
            // loses precision
            double err_sum = 0;
            for (idx_t j = 0; j < nx; j++)
                err_sum += dis[j];
            err = err_sum;
            obj.push_back (err);

            int nsplit = km_update_centroids (
                  x, cur_centroids.data(),
                  assign.data(), d, k, nx);

            if (verbose) {
                printf (" Iteration %d (%5g s, search %5g s): "
                        "objective=%g imbalance=%g nsplit=%d \r",
                        i, (getmillisecs() - t0) / 1000.0,
                        t_search_tot / 1000,
                        err, imbalance_factor (nx, k, assign.data()),
                        nsplit);
                fflush (stdout);
            }

            if (spherical)
                fvec_renorm_L2 (d, k, cur_centroids.data());

            index.reset ();
            if (update_index)
                index.train (k, cur_centroids.data());

            FAISS_ASSERT (index.ntotal == 0);
            // must be the working set of this run: `centroids` differs
            // from it (empty or stale) when nredo > 1
            index.add (k, cur_centroids.data());
        }
        if (verbose) printf("\n");

        if (nredo > 1) {
            // always keep the first run so centroids is never left empty
            if (redo == 0 || err < best_err) {
                if (verbose)
                    printf ("Keep new clusters\n");
                centroids = cur_centroids;
                best_err = err;
            }
            // empty the index so the next run (and the final refill)
            // starts from ntotal == 0
            index.reset ();
        }
    }

    if (nredo > 1) {
        // leave the index holding the kept centroids, as after a single run
        if (update_index)
            index.train (k, centroids.data());
        index.add (k, centroids.data());
    }
}


/** Convenience wrapper: k-means with default ClusteringParameters and an
 * IndexFlatL2 as assignment index.
 *
 * @param d          dimension of the vectors
 * @param n          number of training vectors in x
 * @param k          number of centroids
 * @param x          training vectors, size n * d
 * @param centroids  output buffer, size k * d
 * @return           the objective (sum of the search distances) recorded
 *                   at the last k-means iteration
 */
float kmeans_clustering (size_t d, size_t n, size_t k,
                         const float *x,
                         float *centroids)
{
    Clustering clus (d, k);
    // display logs if > 1Gflop per iteration
    clus.verbose = d * n * k > (1L << 30);
    IndexFlatL2 index (d);
    clus.train (n, x, index);
    memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
    return clus.obj.back();
}

} // namespace faiss