faiss/c_api/VectorTransform_c.cpp

228 lines
5.8 KiB
C++

/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include "VectorTransform_c.h"
#include <faiss/VectorTransform.h>
#include "macros_impl.h"
extern "C" {
// Macro-generated accessors for the base faiss::VectorTransform handle: a
// destructor plus read-only getters for is_trained, d_in and d_out, all
// exposed to C as int.
DEFINE_DESTRUCTOR(VectorTransform)
DEFINE_GETTER(VectorTransform, int, is_trained)
DEFINE_GETTER(VectorTransform, int, d_in)
DEFINE_GETTER(VectorTransform, int, d_out)

/// Train the wrapped faiss::VectorTransform on a set of vectors.
///
/// @param vt  opaque handle to a faiss::VectorTransform
/// @param n   number of training vectors
/// @param x   training data (presumably n * d_in floats; confirm against the
///            C header)
/// @return    status code produced by CATCH_AND_HANDLE (presumably 0 on
///            success and nonzero if the C++ call threw; see macros_impl.h)
int faiss_VectorTransform_train(
        FaissVectorTransform* vt,
        idx_t n,
        const float* x) {
    try {
        auto* transform = reinterpret_cast<faiss::VectorTransform*>(vt);
        transform->train(n, x);
    }
    CATCH_AND_HANDLE
}
/// Apply the wrapped faiss::VectorTransform to n vectors and return a newly
/// allocated output buffer.
///
/// @param vt  opaque handle to a faiss::VectorTransform
/// @param n   number of input vectors
/// @param x   input vectors (presumably n * d_in floats; confirm)
/// @return    output buffer allocated by the C++ library, or NULL if the
///            underlying C++ call threw an exception.
///
/// NOTE(review): the returned buffer is allocated on the C++ side; confirm
/// which deallocation routine C callers are expected to use for it.
float* faiss_VectorTransform_apply(
        const FaissVectorTransform* vt,
        idx_t n,
        const float* x) {
    // This function has C linkage and returns a pointer, so it cannot use
    // CATCH_AND_HANDLE (which is only used in int-returning wrappers here).
    // A C++ exception must still never propagate into a C caller, so report
    // failure as a null pointer instead.
    try {
        return reinterpret_cast<const faiss::VectorTransform*>(vt)->apply(n, x);
    } catch (...) {
        return nullptr;
    }
}
/// Apply the wrapped faiss::VectorTransform to n vectors, writing the result
/// into a caller-supplied buffer.
///
/// @param vt  opaque handle to a faiss::VectorTransform
/// @param n   number of input vectors
/// @param x   input vectors
/// @param xt  caller-owned output buffer (presumably n * d_out floats; confirm)
///
/// NOTE(review): unlike the int-returning wrappers in this file, there is no
/// try/CATCH_AND_HANDLE here, so an exception thrown by apply_noalloc
/// propagates out of this extern "C" function. The void signature gives no
/// channel to report it.
void faiss_VectorTransform_apply_noalloc(
        const FaissVectorTransform* vt,
        idx_t n,
        const float* x,
        float* xt) {
    const auto* transform = reinterpret_cast<const faiss::VectorTransform*>(vt);
    transform->apply_noalloc(n, x, xt);
}
/// Forward to faiss::VectorTransform::reverse_transform, reading n
/// transformed vectors from xt and writing into the caller-supplied buffer x.
///
/// @param vt  opaque handle to a faiss::VectorTransform
/// @param n   number of vectors
/// @param xt  transformed input vectors
/// @param x   caller-owned output buffer (presumably n * d_in floats; confirm)
///
/// NOTE(review): no exception handling here; any exception thrown by the C++
/// method (e.g. if the transform does not support reversal) escapes this
/// extern "C" function, because the void signature cannot report errors.
void faiss_VectorTransform_reverse_transform(
        const FaissVectorTransform* vt,
        idx_t n,
        const float* xt,
        float* x) {
    const auto* transform = reinterpret_cast<const faiss::VectorTransform*>(vt);
    transform->reverse_transform(n, xt, x);
}
/*********************************************
* LinearTransform
*********************************************/
// Macro-generated accessors for faiss::LinearTransform handles: a destructor
// plus int getters for the have_bias and is_orthonormal fields.
DEFINE_DESTRUCTOR(LinearTransform)
DEFINE_GETTER(LinearTransform, int, have_bias)
DEFINE_GETTER(LinearTransform, int, is_orthonormal)

/// Forward to faiss::LinearTransform::transform_transpose, reading n vectors
/// from y and writing the result into the caller-supplied buffer x.
///
/// @param vt  opaque handle to a faiss::LinearTransform
/// @param n   number of vectors
/// @param y   input vectors
/// @param x   caller-owned output buffer
///
/// NOTE(review): no exception handling; a C++ exception escapes this
/// extern "C" function.
void faiss_LinearTransform_transform_transpose(
        const FaissLinearTransform* vt,
        idx_t n,
        const float* y,
        float* x) {
    const auto* linear = reinterpret_cast<const faiss::LinearTransform*>(vt);
    linear->transform_transpose(n, y, x);
}
/// Call faiss::LinearTransform::set_is_orthonormal() on the wrapped object
/// (presumably recomputes its is_orthonormal flag; confirm in VectorTransform.h).
///
/// @param vt  opaque handle to a mutable faiss::LinearTransform
void faiss_LinearTransform_set_is_orthonormal(FaissLinearTransform* vt) {
    auto* linear = reinterpret_cast<faiss::LinearTransform*>(vt);
    linear->set_is_orthonormal();
}
/*********************************************
* RandomRotationMatrix
*********************************************/
// Macro-generated destructor for faiss::RandomRotationMatrix handles.
DEFINE_DESTRUCTOR(RandomRotationMatrix)

/// Heap-allocate a faiss::RandomRotationMatrix and return its handle.
///
/// @param p_vt   out-parameter receiving the new handle; it is written only
///               after construction succeeds
/// @param d_in   input dimension
/// @param d_out  output dimension
/// @return       status code produced by CATCH_AND_HANDLE
int faiss_RandomRotationMatrix_new_with(
        FaissRandomRotationMatrix** p_vt,
        int d_in,
        int d_out) {
    try {
        auto* rotation = new faiss::RandomRotationMatrix(d_in, d_out);
        *p_vt = reinterpret_cast<FaissRandomRotationMatrix*>(rotation);
    }
    CATCH_AND_HANDLE
}
/*********************************************
* PCAMatrix
*********************************************/
// Macro-generated destructor for faiss::PCAMatrix handles.
DEFINE_DESTRUCTOR(PCAMatrix)

/// Heap-allocate a faiss::PCAMatrix and return its handle.
///
/// @param p_vt             out-parameter receiving the new handle; it is
///                         written only after construction succeeds
/// @param d_in             input dimension
/// @param d_out            output dimension
/// @param eigen_power      forwarded unchanged to the PCAMatrix constructor
/// @param random_rotation  C-style boolean; any nonzero value means true
/// @return                 status code produced by CATCH_AND_HANDLE
int faiss_PCAMatrix_new_with(
        FaissPCAMatrix** p_vt,
        int d_in,
        int d_out,
        float eigen_power,
        int random_rotation) {
    try {
        const bool use_random_rotation = random_rotation != 0;
        auto* pca = new faiss::PCAMatrix(
                d_in, d_out, eigen_power, use_random_rotation);
        *p_vt = reinterpret_cast<FaissPCAMatrix*>(pca);
    }
    CATCH_AND_HANDLE
}

// Macro-generated read-only getters for PCAMatrix configuration fields.
DEFINE_GETTER(PCAMatrix, float, eigen_power)
DEFINE_GETTER(PCAMatrix, int, random_rotation)
/*********************************************
* ITQMatrix and ITQTransform
*********************************************/
// Macro-generated destructor for faiss::ITQMatrix handles.
DEFINE_DESTRUCTOR(ITQMatrix)

/// Heap-allocate a faiss::ITQMatrix of dimension d and return its handle.
///
/// @param p_vt  out-parameter receiving the new handle; it is written only
///              after construction succeeds
/// @param d     dimension forwarded to the ITQMatrix constructor
/// @return      status code produced by CATCH_AND_HANDLE
int faiss_ITQMatrix_new_with(FaissITQMatrix** p_vt, int d) {
    try {
        auto* itq = new faiss::ITQMatrix(d);
        *p_vt = reinterpret_cast<FaissITQMatrix*>(itq);
    }
    CATCH_AND_HANDLE
}
// Macro-generated destructor for faiss::ITQTransform handles.
DEFINE_DESTRUCTOR(ITQTransform)

/// Heap-allocate a faiss::ITQTransform and return its handle.
///
/// @param p_vt    out-parameter receiving the new handle; it is written only
///                after construction succeeds
/// @param d_in    input dimension
/// @param d_out   output dimension
/// @param do_pca  C-style boolean; any nonzero value means true
/// @return        status code produced by CATCH_AND_HANDLE
int faiss_ITQTransform_new_with(
        FaissITQTransform** p_vt,
        int d_in,
        int d_out,
        int do_pca) {
    try {
        const bool with_pca = do_pca != 0;
        auto* itq = new faiss::ITQTransform(d_in, d_out, with_pca);
        *p_vt = reinterpret_cast<FaissITQTransform*>(itq);
    }
    CATCH_AND_HANDLE
}

// Macro-generated read-only getter for the do_pca flag, exposed as int.
DEFINE_GETTER(ITQTransform, int, do_pca)
/*********************************************
* OPQMatrix
*********************************************/
// Macro-generated destructor for faiss::OPQMatrix handles.
DEFINE_DESTRUCTOR(OPQMatrix)

/// Heap-allocate a faiss::OPQMatrix and return its handle.
///
/// @param p_vt  out-parameter receiving the new handle; it is written only
///              after construction succeeds
/// @param d     input dimension
/// @param M     forwarded to the OPQMatrix constructor (presumably the number
///              of PQ sub-quantizers; confirm)
/// @param d2    forwarded to the OPQMatrix constructor (presumably the output
///              dimension; confirm)
/// @return      status code produced by CATCH_AND_HANDLE
int faiss_OPQMatrix_new_with(FaissOPQMatrix** p_vt, int d, int M, int d2) {
    try {
        auto* opq = new faiss::OPQMatrix(d, M, d2);
        *p_vt = reinterpret_cast<FaissOPQMatrix*>(opq);
    }
    CATCH_AND_HANDLE
}

// Macro-generated get/set pairs for OPQMatrix training parameters, exposed as
// int.
DEFINE_GETTER(OPQMatrix, int, verbose)
DEFINE_SETTER(OPQMatrix, int, verbose)
DEFINE_GETTER(OPQMatrix, int, niter)
DEFINE_SETTER(OPQMatrix, int, niter)
DEFINE_GETTER(OPQMatrix, int, niter_pq)
DEFINE_SETTER(OPQMatrix, int, niter_pq)
/*********************************************
* RemapDimensionsTransform
*********************************************/
// Macro-generated destructor for faiss::RemapDimensionsTransform handles.
DEFINE_DESTRUCTOR(RemapDimensionsTransform)

/// Heap-allocate a faiss::RemapDimensionsTransform and return its handle.
///
/// @param p_vt     out-parameter receiving the new handle; it is written only
///                 after construction succeeds
/// @param d_in     input dimension
/// @param d_out    output dimension
/// @param uniform  C-style boolean; any nonzero value means true
/// @return         status code produced by CATCH_AND_HANDLE
int faiss_RemapDimensionsTransform_new_with(
        FaissRemapDimensionsTransform** p_vt,
        int d_in,
        int d_out,
        int uniform) {
    try {
        const bool is_uniform = uniform != 0;
        auto* remap =
                new faiss::RemapDimensionsTransform(d_in, d_out, is_uniform);
        *p_vt = reinterpret_cast<FaissRemapDimensionsTransform*>(remap);
    }
    CATCH_AND_HANDLE
}
/*********************************************
* NormalizationTransform
*********************************************/
// Macro-generated destructor for faiss::NormalizationTransform handles.
DEFINE_DESTRUCTOR(NormalizationTransform)

/// Heap-allocate a faiss::NormalizationTransform and return its handle.
///
/// @param p_vt  out-parameter receiving the new handle; it is written only
///              after construction succeeds
/// @param d     vector dimension
/// @param norm  forwarded unchanged to the NormalizationTransform constructor
/// @return      status code produced by CATCH_AND_HANDLE
int faiss_NormalizationTransform_new_with(
        FaissNormalizationTransform** p_vt,
        int d,
        float norm) {
    try {
        auto* normalizer = new faiss::NormalizationTransform(d, norm);
        *p_vt = reinterpret_cast<FaissNormalizationTransform*>(normalizer);
    }
    CATCH_AND_HANDLE
}

// Macro-generated read-only getter for the configured norm.
DEFINE_GETTER(NormalizationTransform, float, norm)
/*********************************************
* CenteringTransform
*********************************************/
// Macro-generated destructor for faiss::CenteringTransform handles.
DEFINE_DESTRUCTOR(CenteringTransform)

/// Heap-allocate a faiss::CenteringTransform of dimension d and return its
/// handle.
///
/// @param p_vt  out-parameter receiving the new handle; it is written only
///              after construction succeeds
/// @param d     vector dimension
/// @return      status code produced by CATCH_AND_HANDLE
int faiss_CenteringTransform_new_with(FaissCenteringTransform** p_vt, int d) {
    try {
        auto* centering = new faiss::CenteringTransform(d);
        *p_vt = reinterpret_cast<FaissCenteringTransform*>(centering);
    }
    CATCH_AND_HANDLE
}
}