faiss/c_api/Clustering_c.cpp

171 lines
4.8 KiB
C++

/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include "Clustering_c.h"
#include <faiss/Clustering.h>
#include <faiss/Index.h>
#include <vector>
#include "macros_impl.h"
extern "C" {
using faiss::Clustering;
using faiss::ClusteringIterationStats;
using faiss::ClusteringParameters;
using faiss::Index;
// Field accessors for faiss::Clustering.
// Each DEFINE_GETTER invocation names the owning class, the type handed back
// to the C caller, and the member being read. The macro itself is not defined
// in this file (presumably it comes from macros_impl.h and emits a function
// named faiss_<Class>_<field> -- confirm against that header).
DEFINE_GETTER(Clustering, int, niter)
DEFINE_GETTER(Clustering, int, nredo)
DEFINE_GETTER(Clustering, int, verbose)
DEFINE_GETTER(Clustering, int, spherical)
DEFINE_GETTER(Clustering, int, int_centroids)
DEFINE_GETTER(Clustering, int, update_index)
DEFINE_GETTER(Clustering, int, frozen_centroids)
DEFINE_GETTER(Clustering, int, min_points_per_centroid)
DEFINE_GETTER(Clustering, int, max_points_per_centroid)
DEFINE_GETTER(Clustering, int, seed)
DEFINE_GETTER(Clustering, size_t, decode_block_size)
/// getter for d
DEFINE_GETTER(Clustering, size_t, d)
/// getter for k
DEFINE_GETTER(Clustering, size_t, k)
// Field accessors for faiss::ClusteringIterationStats, one per member
// (obj, time, time_search, imbalance_factor, nsplit). Same DEFINE_GETTER
// pattern as used elsewhere in this file: class, exposed type, member name.
DEFINE_GETTER(ClusteringIterationStats, float, obj)
DEFINE_GETTER(ClusteringIterationStats, double, time)
DEFINE_GETTER(ClusteringIterationStats, double, time_search)
DEFINE_GETTER(ClusteringIterationStats, double, imbalance_factor)
DEFINE_GETTER(ClusteringIterationStats, int, nsplit)
/// Populate `*params` with the values held by a default-constructed
/// faiss::ClusteringParameters.
///
/// @param params  destination structure; it is written through
///                unconditionally, so it must point to valid storage.
void faiss_ClusteringParameters_init(FaissClusteringParameters* params) {
    ClusteringParameters defaults;
    FaissClusteringParameters& out = *params;

    out.niter = defaults.niter;
    out.nredo = defaults.nredo;
    out.verbose = defaults.verbose;
    out.spherical = defaults.spherical;
    out.int_centroids = defaults.int_centroids;
    out.update_index = defaults.update_index;
    out.frozen_centroids = defaults.frozen_centroids;
    out.min_points_per_centroid = defaults.min_points_per_centroid;
    out.max_points_per_centroid = defaults.max_points_per_centroid;
    out.seed = defaults.seed;
    out.decode_block_size = defaults.decode_block_size;
}
// Build a faiss::ClusteringParameters from its C API counterpart by copying
// each field individually. This conversion is required because the two types
// are not memory-compatible, so the pointer cannot simply be cast.
//
// The helper is `static` on purpose: it lives inside the extern "C" block and
// returns a C++ class by value. With external linkage it would be emitted as
// an unmangled `from_faiss_c` symbol (able to clash with a same-named helper
// in another translation unit) and would draw Clang's
// -Wreturn-type-c-linkage warning if the return type is not a POD. Internal
// linkage avoids both.
//
// `params` is dereferenced unconditionally and must not be null.
static inline ClusteringParameters from_faiss_c(
        const FaissClusteringParameters* params) {
    ClusteringParameters o;
    o.frozen_centroids = params->frozen_centroids;
    o.max_points_per_centroid = params->max_points_per_centroid;
    o.min_points_per_centroid = params->min_points_per_centroid;
    o.niter = params->niter;
    o.nredo = params->nredo;
    o.seed = params->seed;
    o.spherical = params->spherical;
    o.update_index = params->update_index;
    o.int_centroids = params->int_centroids;
    o.verbose = params->verbose;
    o.decode_block_size = params->decode_block_size;
    return o;
}
/// getter for centroids (size = k * d)
///
/// Both output arguments are optional; a null pointer means "not wanted".
/// @param clustering  clustering object to query
/// @param centroids   if non-null, receives the address of the centroid
///                    vector's own storage (no copy is made)
/// @param size        if non-null, receives the number of floats in that
///                    vector
void faiss_Clustering_centroids(
        FaissClustering* clustering,
        float** centroids,
        size_t* size) {
    Clustering* const impl = reinterpret_cast<Clustering*>(clustering);
    std::vector<float>& buffer = impl->centroids;

    if (centroids != nullptr) {
        *centroids = buffer.data();
    }
    if (size != nullptr) {
        *size = buffer.size();
    }
}
/// getter for iteration stats
///
/// Both output arguments are optional; a null pointer means "not wanted".
/// @param clustering       clustering object to query
/// @param iteration_stats  if non-null, receives the address of the first
///                         element of the object's iteration_stats vector,
///                         viewed through the opaque C handle type
/// @param size             if non-null, receives the number of entries
void faiss_Clustering_iteration_stats(
        FaissClustering* clustering,
        FaissClusteringIterationStats** iteration_stats,
        size_t* size) {
    Clustering* const impl = reinterpret_cast<Clustering*>(clustering);
    std::vector<ClusteringIterationStats>& stats = impl->iteration_stats;

    if (iteration_stats != nullptr) {
        ClusteringIterationStats* const first = stats.data();
        *iteration_stats =
                reinterpret_cast<FaissClusteringIterationStats*>(first);
    }
    if (size != nullptr) {
        *size = stats.size();
    }
}
/// Allocate a faiss::Clustering; the only mandatory parameters are k and d.
///
/// @param p_clustering  receives the opaque handle of the new object
/// @param d             forwarded as the first constructor argument
/// @param k             forwarded as the second constructor argument
/// @return 0 on success; exceptions are turned into the return value by
///         CATCH_AND_HANDLE (defined outside this file).
int faiss_Clustering_new(FaissClustering** p_clustering, int d, int k) {
    try {
        *p_clustering =
                reinterpret_cast<FaissClustering*>(new Clustering(d, k));
        return 0;
    }
    CATCH_AND_HANDLE
}
/// Allocate a faiss::Clustering configured from a C parameter structure.
///
/// @param p_clustering  receives the opaque handle of the new object
/// @param d             forwarded as the first constructor argument
/// @param k             forwarded as the second constructor argument
/// @param cp            C-side parameters; converted field by field with
///                      from_faiss_c before being handed to the constructor
/// @return 0 on success; exceptions are turned into the return value by
///         CATCH_AND_HANDLE (defined outside this file).
int faiss_Clustering_new_with_params(
        FaissClustering** p_clustering,
        int d,
        int k,
        const FaissClusteringParameters* cp) {
    try {
        Clustering* const created = new Clustering(d, k, from_faiss_c(cp));
        FaissClustering* const handle =
                reinterpret_cast<FaissClustering*>(created);
        *p_clustering = handle;
        return 0;
    }
    CATCH_AND_HANDLE
}
/// Run Clustering::train on the given data.
/// Index is used during the assignment stage.
///
/// @param clustering  handle of the clustering object to train
/// @param n           forwarded to train() as its first argument
/// @param x           forwarded to train() as its second argument
/// @param index       handle of the index; it is dereferenced and passed to
///                    train() by reference, so it must not be null
/// @return 0 on success; exceptions are turned into the return value by
///         CATCH_AND_HANDLE (defined outside this file).
int faiss_Clustering_train(
        FaissClustering* clustering,
        idx_t n,
        const float* x,
        FaissIndex* index) {
    try {
        Clustering* const impl = reinterpret_cast<Clustering*>(clustering);
        Index& assigner = *reinterpret_cast<Index*>(index);
        impl->train(n, x, assigner);
        return 0;
    }
    CATCH_AND_HANDLE
}
/// Destroy the faiss::Clustering behind `clustering` with `delete`.
/// The handle must not be used after this call.
void faiss_Clustering_free(FaissClustering* clustering) {
    Clustering* const impl = reinterpret_cast<Clustering*>(clustering);
    delete impl;
}
int faiss_kmeans_clustering(
size_t d,
size_t n,
size_t k,
const float* x,
float* centroids,
float* q_error) {
try {
float out = faiss::kmeans_clustering(d, n, k, x, centroids);
if (q_error) {
*q_error = out;
}
return 0;
}
CATCH_AND_HANDLE
}
}