faiss/demos/demo_weighted_kmeans.cpp

182 lines
4.8 KiB
C++

/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

#include <faiss/Clustering.h>
#include <faiss/IndexFlat.h>
#include <faiss/IndexHNSW.h>
#include <faiss/utils/distances.h>
#include <faiss/utils/random.h>
namespace {
// Index configuration used by weighted_kmeans_clustering() during training.
// The numeric values are spelled out because main() iterates over them as
// plain ints and compares them with the first command-line argument.
enum WeightedKMeansType {
    WKMT_FlatL2 = 0,           // IndexFlatL2
    WKMT_FlatIP = 1,           // IndexFlatIP
    WKMT_FlatIP_spherical = 2, // IndexFlatIP, with Clustering::spherical set
    WKMT_HNSW = 3,             // IndexHNSWFlat (M = 32, efSearch = 128)
};
float weighted_kmeans_clustering(
size_t d,
size_t n,
size_t k,
const float* input,
const float* weights,
float* centroids,
WeightedKMeansType index_num) {
using namespace faiss;
Clustering clus(d, k);
clus.verbose = true;
std::unique_ptr<Index> index;
switch (index_num) {
case WKMT_FlatL2:
index = std::make_unique<IndexFlatL2>(d);
break;
case WKMT_FlatIP:
index = std::make_unique<IndexFlatIP>(d);
break;
case WKMT_FlatIP_spherical:
index = std::make_unique<IndexFlatIP>(d);
clus.spherical = true;
break;
case WKMT_HNSW:
IndexHNSWFlat* ihnsw = new IndexHNSWFlat(d, 32);
ihnsw->hnsw.efSearch = 128;
index.reset(ihnsw);
break;
}
clus.train(n, input, *index.get(), weights);
// on output the index contains the centroids.
memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
return clus.iteration_stats.back().obj;
}
// Dimension of the generated vectors.
int d = 32;
// Scale factor applied to the random noise added around each generating
// centroid (see generate_trainset()).
float sigma = 0.1f;

// Remove this definition to switch to the small configuration below.
#define BIGTEST

#if defined(BIGTEST)
// the production setup = setting of https://fb.quip.com/CWgnAAYbwtgs
int nc = 200000; // k passed to k-means; 2 * nc generating centroids are drawn
int n_big = 4;   // points drawn around each of the first nc centroids
int n_small = 2; // points drawn around each of the last nc centroids
#else
int nc = 5;
int n_big = 100;
int n_small = 10;
#endif

// Number of training points; assigned by generate_trainset().
int n = 0;
/**
 * Build the synthetic training set used by the demo.
 *
 * Draws 2 * nc random centroids (renormalized with fvec_renorm_L2), then
 * places n_big points around each of the first nc centroids and n_small
 * points around each of the remaining nc. Every point is its centroid plus
 * random noise scaled by sigma. Points of the first group get weight 0.1,
 * points of the second group get weight 10.
 *
 * Also sets the global `n` to the total number of generated points.
 *
 * @param ccent    output: the 2 * nc generating centroids (d floats each)
 * @param x        output: the n training vectors (d floats each)
 * @param weights  output: one weight per training vector
 */
void generate_trainset(
        std::vector<float>& ccent,
        std::vector<float>& x,
        std::vector<float>& weights) {
    // same sampling as test_build_blocks.py test_weighted
    const int n_centroids = 2 * nc;
    ccent.resize(d * n_centroids);
    faiss::float_randn(ccent.data(), ccent.size(), 123);
    faiss::fvec_renorm_L2(d, n_centroids, ccent.data());

    n = nc * (n_big + n_small);
    x.resize(d * n);
    weights.resize(n);
    // x starts out as pure noise and is shifted in place below.
    faiss::float_randn(x.data(), x.size(), 1234);

    size_t point = 0; // index of the next training vector to fill
    for (int c = 0; c < n_centroids; ++c) {
        const bool is_big = c < nc;
        const int n_points = is_big ? n_big : n_small;
        const float weight = is_big ? 0.1f : 10.0f;
        const float* center = ccent.data() + static_cast<size_t>(c) * d;

        for (int p = 0; p < n_points; ++p, ++point) {
            float* vec = x.data() + point * d;
            for (int j = 0; j < d; ++j) {
                vec[j] = vec[j] * sigma + center[j];
            }
            weights[point] = weight;
        }
    }
}
} // namespace
int main(int argc, char** argv) {
std::vector<float> ccent;
std::vector<float> x;
std::vector<float> weights;
printf("generate training set\n");
generate_trainset(ccent, x, weights);
std::vector<float> centroids;
centroids.resize(nc * d);
int the_index_num = -1;
int the_with_weights = -1;
if (argc == 3) {
the_index_num = atoi(argv[1]);
the_with_weights = atoi(argv[2]);
}
for (int index_num = WKMT_FlatL2; index_num <= WKMT_HNSW; index_num++) {
if (the_index_num >= 0 && index_num != the_index_num) {
continue;
}
for (int with_weights = 0; with_weights <= 1; with_weights++) {
if (the_with_weights >= 0 && with_weights != the_with_weights) {
continue;
}
printf("=================== index_num=%d Run %s weights\n",
index_num,
with_weights ? "with" : "without");
weighted_kmeans_clustering(
d,
n,
nc,
x.data(),
with_weights ? weights.data() : nullptr,
centroids.data(),
(WeightedKMeansType)index_num);
{ // compute distance of points to centroids
faiss::IndexFlatL2 cent_index(d);
cent_index.add(nc, centroids.data());
std::vector<float> dis(n);
std::vector<faiss::idx_t> idx(n);
cent_index.search(
nc * 2, ccent.data(), 1, dis.data(), idx.data());
float dis1 = 0, dis2 = 0;
for (int i = 0; i < nc; i++) {
dis1 += dis[i];
}
printf("average distance of points from big clusters: %g\n",
dis1 / nc);
for (int i = 0; i < nc; i++) {
dis2 += dis[i + nc];
}
printf("average distance of points from small clusters: %g\n",
dis2 / nc);
}
}
}
return 0;
}