/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

// -*- c++ -*-

#include <faiss/Clustering.h>
#include <faiss/impl/AuxIndexStructures.h>

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <memory>

#include <omp.h>

#include <faiss/utils/utils.h>
#include <faiss/utils/random.h>
#include <faiss/utils/distances.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/IndexFlat.h>

namespace faiss {

ClusteringParameters::ClusteringParameters ():
    niter(25),
    nredo(1),
    verbose(false),
    spherical(false),
    int_centroids(false),
    update_index(false),
    frozen_centroids(false),
    min_points_per_centroid(39),
    max_points_per_centroid(256),
    seed(1234),
    decode_block_size(32768)
{}
// 39 corresponds to 10000 / 256 -> to avoid warnings on PQ tests with randu10k
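
/* Example of adjusting these defaults (illustrative sketch, not part of the
 * original source; d and k stand for caller-supplied values):
 *
 *   faiss::ClusteringParameters cp;
 *   cp.niter = 50;        // more k-means iterations than the default 25
 *   cp.nredo = 3;         // keep the best of 3 random restarts
 *   cp.spherical = true;  // L2-renormalize centroids after each iteration
 *   faiss::Clustering clus(d, k, cp);
 */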

Clustering::Clustering (int d, int k):
    d(d), k(k) {}

Clustering::Clustering (int d, int k, const ClusteringParameters &cp):
    ClusteringParameters (cp), d(d), k(k) {}


static double imbalance_factor (int n, int k, int64_t *assign) {
    std::vector<int> hist(k, 0);
    for (int i = 0; i < n; i++)
        hist[assign[i]]++;

    double tot = 0, uf = 0;

    for (int i = 0 ; i < k ; i++) {
        tot += hist[i];
        uf += hist[i] * (double) hist[i];
    }
    uf = uf * k / (tot * tot);

    return uf;
}
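
// Worked example (illustrative, not from the original source): for n = 4
// points assigned to k = 2 clusters as {0, 0, 0, 1}, hist = {3, 1} and
// uf = (3*3 + 1*1) * 2 / (4*4) = 1.25; the perfectly balanced assignment
// {0, 0, 1, 1} gives (2*2 + 2*2) * 2 / 16 = 1.0, the minimum value.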

void Clustering::post_process_centroids ()
{
    if (spherical) {
        fvec_renorm_L2 (d, k, centroids.data());
    }

    if (int_centroids) {
        for (size_t i = 0; i < centroids.size(); i++)
            centroids[i] = roundf (centroids[i]);
    }
}


void Clustering::train (idx_t nx, const float *x_in, Index & index,
                        const float *weights) {
    train_encoded (nx, reinterpret_cast<const uint8_t *>(x_in), nullptr,
                   index, weights);
}
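
/* Minimal usage sketch for the float-vector entry point (illustrative only;
 * assumes the caller fills `data` and uses the default null weights):
 *
 *   size_t d = 64, n = 10000, k = 100;
 *   std::vector<float> data(n * d);      // training vectors, row-major
 *   faiss::Clustering clus(d, k);
 *   faiss::IndexFlatL2 assign_index(d);  // used for the assignment step
 *   clus.train(n, data.data(), assign_index);
 *   // clus.centroids now holds k * d floats; assign_index contains them too.
 */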

namespace {

using idx_t = Clustering::idx_t;

idx_t subsample_training_set(
        const Clustering &clus, idx_t nx, const uint8_t *x,
        size_t line_size, const float * weights,
        uint8_t **x_out,
        float **weights_out)
{
    if (clus.verbose) {
        printf("Sampling a subset of %ld / %ld for training\n",
               clus.k * clus.max_points_per_centroid, nx);
    }
    std::vector<int> perm (nx);
    rand_perm (perm.data (), nx, clus.seed);
    nx = clus.k * clus.max_points_per_centroid;
    uint8_t * x_new = new uint8_t [nx * line_size];
    *x_out = x_new;
    for (idx_t i = 0; i < nx; i++) {
        memcpy (x_new + i * line_size, x + perm[i] * line_size, line_size);
    }
    if (weights) {
        float *weights_new = new float[nx];
        for (idx_t i = 0; i < nx; i++) {
            weights_new[i] = weights[perm[i]];
        }
        *weights_out = weights_new;
    } else {
        *weights_out = nullptr;
    }
    return nx;
}

/** compute centroids as (weighted) sum of training points
 *
 * @param x            training vectors, size n * code_size (from codec)
 * @param codec        how to decode the vectors (if NULL then cast to float*)
 * @param weights      per-training vector weight, size n (or NULL)
 * @param assign       nearest centroid for each training vector, size n
 * @param k_frozen     do not update the k_frozen first centroids
 * @param centroids    centroid vectors (output only), size k * d
 * @param hassign      histogram of assignments per centroid (size k),
 *                     should be 0 on input
 *
 */
void compute_centroids (size_t d, size_t k, size_t n,
                        size_t k_frozen,
                        const uint8_t * x, const Index *codec,
                        const int64_t * assign,
                        const float * weights,
                        float * hassign,
                        float * centroids)
{
    k -= k_frozen;
    centroids += k_frozen * d;

    memset (centroids, 0, sizeof(*centroids) * d * k);

    size_t line_size = codec ? codec->sa_code_size() : d * sizeof (float);

#pragma omp parallel
    {
        int nt = omp_get_num_threads();
        int rank = omp_get_thread_num();

        // this thread is taking care of centroids c0:c1
        size_t c0 = (k * rank) / nt;
        size_t c1 = (k * (rank + 1)) / nt;
        std::vector<float> decode_buffer (d);

        for (size_t i = 0; i < n; i++) {
            int64_t ci = assign[i];
            assert (ci >= 0 && ci < k + k_frozen);
            ci -= k_frozen;
            if (ci >= c0 && ci < c1) {
                float * c = centroids + ci * d;
                const float * xi;
                if (!codec) {
                    xi = reinterpret_cast<const float*>(x + i * line_size);
                } else {
                    float *xif = decode_buffer.data();
                    codec->sa_decode (1, x + i * line_size, xif);
                    xi = xif;
                }
                if (weights) {
                    float w = weights[i];
                    hassign[ci] += w;
                    for (size_t j = 0; j < d; j++) {
                        c[j] += xi[j] * w;
                    }
                } else {
                    hassign[ci] += 1.0;
                    for (size_t j = 0; j < d; j++) {
                        c[j] += xi[j];
                    }
                }
            }
        }

    }

#pragma omp parallel for
    for (size_t ci = 0; ci < k; ci++) {
        if (hassign[ci] == 0) {
            continue;
        }
        float norm = 1 / hassign[ci];
        float * c = centroids + ci * d;
        for (size_t j = 0; j < d; j++) {
            c[j] *= norm;
        }
    }

}
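
/* Worked example (illustrative, not from the original source): with d = 1,
 * k = 2, n = 3, points {1, 3, 10}, assignments {0, 0, 1} and weights
 * {1, 1, 2}, the accumulation pass gives hassign = {2, 2} and centroid sums
 * {4, 20}; the normalization pass then yields centroids {2, 10}.
 */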

// a bit above machine epsilon for float16
#define EPS (1 / 1024.)

/** Handle empty clusters by splitting larger ones.
 *
 * It works by slightly changing the centroids to make 2 clusters from
 * a single one. Takes the same arguments as compute_centroids.
 *
 * @return number of splitting operations (larger is worse)
 */
int split_clusters (size_t d, size_t k, size_t n,
                    size_t k_frozen,
                    float * hassign,
                    float * centroids)
{
    k -= k_frozen;
    centroids += k_frozen * d;

    /* Take care of void clusters */
    size_t nsplit = 0;
    RandomGenerator rng (1234);
    for (size_t ci = 0; ci < k; ci++) {
        if (hassign[ci] == 0) { /* need to redefine a centroid */
            size_t cj;
            for (cj = 0; 1; cj = (cj + 1) % k) {
                /* probability to pick this cluster for split */
                float p = (hassign[cj] - 1.0) / (float) (n - k);
                float r = rng.rand_float ();
                if (r < p) {
                    break; /* found our cluster to be split */
                }
            }
            memcpy (centroids+ci*d, centroids+cj*d, sizeof(*centroids) * d);

            /* small symmetric perturbation */
            for (size_t j = 0; j < d; j++) {
                if (j % 2 == 0) {
                    centroids[ci * d + j] *= 1 + EPS;
                    centroids[cj * d + j] *= 1 - EPS;
                } else {
                    centroids[ci * d + j] *= 1 - EPS;
                    centroids[cj * d + j] *= 1 + EPS;
                }
            }

            /* assume even split of the cluster */
            hassign[ci] = hassign[cj] / 2;
            hassign[cj] -= hassign[ci];
            nsplit++;
        }
    }

    return nsplit;

}
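
/* Note on the sampling loop above (explanatory, not from the original
 * source): a cluster cj is picked as split donor with probability
 * proportional to hassign[cj] - 1, so larger clusters are more likely to
 * be split and singletons (hassign[cj] == 1) never are. For example, with
 * n = 100, k = 10 and a cluster of 31 points, p = 30 / 90 = 1/3 per visit.
 */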

} // anonymous namespace

void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
                                const Index * codec, Index & index,
                                const float *weights) {

    FAISS_THROW_IF_NOT_FMT (nx >= k,
             "Number of training points (%ld) should be at least "
             "as large as number of clusters (%ld)", nx, k);

    FAISS_THROW_IF_NOT_FMT ((!codec || codec->d == d),
             "Codec dimension %d not the same as data dimension %d",
             int(codec->d), int(d));

    FAISS_THROW_IF_NOT_FMT (index.d == d,
            "Index dimension %d not the same as data dimension %d",
            int(index.d), int(d));

    double t0 = getmillisecs();

    if (!codec) {
        // Check for NaNs in input data. Normally it is the user's
        // responsibility, but it may spare us some hard-to-debug
        // reports.
        const float *x = reinterpret_cast<const float *>(x_in);
        for (size_t i = 0; i < nx * d; i++) {
            FAISS_THROW_IF_NOT_MSG (finite (x[i]),
                                    "input contains NaN's or Inf's");
        }
    }

    const uint8_t *x = x_in;
    std::unique_ptr<uint8_t []> del1;
    std::unique_ptr<float []> del3;
    size_t line_size = codec ? codec->sa_code_size() : sizeof(float) * d;

    if (nx > k * max_points_per_centroid) {
        uint8_t *x_new;
        float *weights_new;
        nx = subsample_training_set (*this, nx, x, line_size, weights,
                                     &x_new, &weights_new);
        del1.reset (x_new); x = x_new;
        del3.reset (weights_new); weights = weights_new;
    } else if (nx < k * min_points_per_centroid) {
        fprintf (stderr,
                 "WARNING clustering %ld points to %ld centroids: "
                 "please provide at least %ld training points\n",
                 nx, k, idx_t(k) * min_points_per_centroid);
    }

    if (nx == k) {
        // this is a corner case, just copy training set to clusters
        if (verbose) {
            printf("Number of training points (%ld) same as number of "
                   "clusters, just copying\n", nx);
        }
        centroids.resize (d * k);
        if (!codec) {
            memcpy (centroids.data(), x_in, sizeof (float) * d * k);
        } else {
            codec->sa_decode (nx, x_in, centroids.data());
        }

        // one fake iteration...
        ClusteringIterationStats stats = { 0.0, 0.0, 0.0, 1.0, 0 };
        iteration_stats.push_back (stats);

        index.reset();
        index.add(k, centroids.data());
        return;
    }

    if (verbose) {
        printf("Clustering %d points in %ldD to %ld clusters, "
               "redo %d times, %d iterations\n",
               int(nx), d, k, nredo, niter);
        if (codec) {
            printf("Input data encoded in %ld bytes per vector\n",
                   codec->sa_code_size ());
        }
    }

    std::unique_ptr<idx_t []> assign(new idx_t[nx]);
    std::unique_ptr<float []> dis(new float[nx]);

    // remember best iteration for redo
    float best_err = HUGE_VALF;
    std::vector<ClusteringIterationStats> best_obj;
    std::vector<float> best_centroids;

    // support input centroids

    FAISS_THROW_IF_NOT_MSG (
        centroids.size() % d == 0,
        "size of provided input centroids not a multiple of dimension"
    );

    size_t n_input_centroids = centroids.size() / d;

    if (verbose && n_input_centroids > 0) {
        printf ("  Using %zd centroids provided as input (%sfrozen)\n",
                n_input_centroids, frozen_centroids ? "" : "not ");
    }

    double t_search_tot = 0;
    if (verbose) {
        printf("  Preprocessing in %.2f s\n",
               (getmillisecs() - t0) / 1000.);
    }
    t0 = getmillisecs();

    // temporary buffer to decode vectors during the optimization
    std::vector<float> decode_buffer
        (codec ? d * decode_block_size : 0);

    for (int redo = 0; redo < nredo; redo++) {

        if (verbose && nredo > 1) {
            printf("Outer iteration %d / %d\n", redo, nredo);
        }

        // initialize (remaining) centroids with random points from the dataset
        centroids.resize (d * k);
        std::vector<int> perm (nx);

        rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);

        if (!codec) {
            for (int i = n_input_centroids; i < k ; i++) {
                memcpy (&centroids[i * d], x + perm[i] * line_size, line_size);
            }
        } else {
            for (int i = n_input_centroids; i < k ; i++) {
                codec->sa_decode (1, x + perm[i] * line_size, &centroids[i * d]);
            }
        }

        post_process_centroids ();

        // prepare the index

        if (index.ntotal != 0) {
            index.reset();
        }

        if (!index.is_trained) {
            index.train (k, centroids.data());
        }

        index.add (k, centroids.data());

        // k-means iterations

        float err = 0;
        for (int i = 0; i < niter; i++) {
            double t0s = getmillisecs();

            if (!codec) {
                index.search (nx, reinterpret_cast<const float *>(x), 1,
                              dis.get(), assign.get());
            } else {
                // search by blocks of decode_block_size vectors
                size_t code_size = codec->sa_code_size ();
                for (size_t i0 = 0; i0 < nx; i0 += decode_block_size) {
                    size_t i1 = i0 + decode_block_size;
                    if (i1 > nx) { i1 = nx; }
                    codec->sa_decode (i1 - i0, x + code_size * i0,
                                      decode_buffer.data ());
                    index.search (i1 - i0, decode_buffer.data (), 1,
                                  dis.get() + i0, assign.get() + i0);
                }
            }

            InterruptCallback::check();
            t_search_tot += getmillisecs() - t0s;

            // accumulate error
            err = 0;
            for (int j = 0; j < nx; j++) {
                err += dis[j];
            }

            // update the centroids
            std::vector<float> hassign (k);

            size_t k_frozen = frozen_centroids ? n_input_centroids : 0;
            compute_centroids (
                d, k, nx, k_frozen,
                x, codec, assign.get(), weights,
                hassign.data(), centroids.data()
            );

            int nsplit = split_clusters (
                d, k, nx, k_frozen,
                hassign.data(), centroids.data()
            );

            // collect statistics
            ClusteringIterationStats stats =
                { err, (getmillisecs() - t0) / 1000.0,
                  t_search_tot / 1000, imbalance_factor (nx, k, assign.get()),
                  nsplit };
            iteration_stats.push_back(stats);

            if (verbose) {
                printf ("  Iteration %d (%.2f s, search %.2f s): "
                        "objective=%g imbalance=%.3f nsplit=%d \r",
                        i, stats.time, stats.time_search, stats.obj,
                        stats.imbalance_factor, nsplit);
                fflush (stdout);
            }

            post_process_centroids ();

            // add centroids to index for the next iteration (or for output)

            index.reset ();
            if (update_index) {
                index.train (k, centroids.data());
            }

            index.add (k, centroids.data());
            InterruptCallback::check ();
        }

        if (verbose) printf("\n");
        if (nredo > 1) {
            if (err < best_err) {
                if (verbose) {
                    printf ("Objective improved: keep new clusters\n");
                }
                best_centroids = centroids;
                best_obj = iteration_stats;
                best_err = err;
            }
            index.reset ();
        }
    }
    if (nredo > 1) {
        centroids = best_centroids;
        iteration_stats = best_obj;
        index.reset();
        index.add(k, best_centroids.data());
    }

}
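
/* Sketch of train_encoded with a codec (illustrative only; IndexPQ is one
 * possible sa-encoding codec, and the parameters below are arbitrary):
 *
 *   faiss::IndexPQ codec(d, 8, 8);       // 8 sub-quantizers of 8 bits
 *   codec.train(n, data.data());
 *   std::vector<uint8_t> codes(n * codec.sa_code_size());
 *   codec.sa_encode(n, data.data(), codes.data());
 *
 *   faiss::Clustering clus(d, k);
 *   faiss::IndexFlatL2 assign_index(d);
 *   clus.train_encoded(n, codes.data(), &codec, assign_index);
 */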

float kmeans_clustering (size_t d, size_t n, size_t k,
                         const float *x,
                         float *centroids)
{
    Clustering clus (d, k);
    clus.verbose = d * n * k > (1L << 30);
    // display logs if > 1Gflop per iteration
    IndexFlatL2 index (d);
    clus.train (n, x, index);
    memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
    return clus.iteration_stats.back().obj;
}
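
/* Usage sketch (illustrative only): cluster n 32-dimensional points into
 * k centroids through this simplified interface.
 *
 *   size_t d = 32, n = 10000, k = 10;
 *   std::vector<float> x(n * d);          // filled by the caller
 *   std::vector<float> centroids(k * d);
 *   float obj = faiss::kmeans_clustering(d, n, k, x.data(), centroids.data());
 *   // obj is the final k-means objective (sum of squared L2 distances).
 */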

} // namespace faiss