/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */

// -*- c++ -*-

#include "VectorTransform.h"

#include <cstdio>
#include <cmath>
#include <cstring>

#include "utils.h"
#include "FaissAssert.h"
#include "IndexPQ.h"

using namespace faiss;


extern "C" {

// this is to keep the clang syntax checker happy
#ifndef FINTEGER
#define FINTEGER int
#endif


/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */

int sgemm_ (
        const char *transa, const char *transb, FINTEGER *m,
        FINTEGER *n, FINTEGER *k, const float *alpha, const float *a,
        FINTEGER *lda, const float *b,
        FINTEGER *ldb, float *beta,
        float *c, FINTEGER *ldc);

int ssyrk_ (
        const char *uplo, const char *trans, FINTEGER *n, FINTEGER *k,
        float *alpha, float *a, FINTEGER *lda,
        float *beta, float *c, FINTEGER *ldc);

/* Lapack functions from http://www.netlib.org/clapack/old/single/ */

int ssyev_ (
        const char *jobz, const char *uplo, FINTEGER *n, float *a,
        FINTEGER *lda, float *w, float *work, FINTEGER *lwork,
        FINTEGER *info);

int dsyev_ (
        const char *jobz, const char *uplo, FINTEGER *n, double *a,
        FINTEGER *lda, double *w, double *work, FINTEGER *lwork,
        FINTEGER *info);

int sgesvd_(
        const char *jobu, const char *jobvt, FINTEGER *m, FINTEGER *n,
        float *a, FINTEGER *lda, float *s, float *u, FINTEGER *ldu, float *vt,
        FINTEGER *ldvt, float *work, FINTEGER *lwork, FINTEGER *info);

}
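
/* Note on layout: BLAS/LAPACK use the Fortran column-major convention,
 * while Faiss stores matrices row-major. A row-major m*n matrix is seen
 * by BLAS as its n*m transpose, which is why the "Transposed" /
 * "Not transposed" flags in the calls below often look inverted with
 * respect to the mathematical operation performed. */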

/*********************************************
 * VectorTransform
 *********************************************/



float * VectorTransform::apply (Index::idx_t n, const float * x) const
{
    float * xt = new float[n * d_out];
    apply_noalloc (n, x, xt);
    return xt;
}
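
// Typical usage (sketch): apply() allocates, so the caller owns the result.
//
//   float *xt = vt->apply (n, x);
//   ... use the n * d_out floats in xt ...
//   delete [] xt;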


void VectorTransform::train (idx_t, const float *) {
    // does nothing by default
}


void VectorTransform::reverse_transform (
             idx_t , const float *,
             float *) const
{
    FAISS_THROW_MSG ("reverse transform not implemented");
}




/*********************************************
 * LinearTransform
 *********************************************/

/// both d_in > d_out and d_in < d_out are supported
LinearTransform::LinearTransform (int d_in, int d_out,
                                  bool have_bias):
    VectorTransform (d_in, d_out), have_bias (have_bias),
    is_orthonormal (false), verbose (false)
{}

void LinearTransform::apply_noalloc (Index::idx_t n, const float * x,
                               float * xt) const
{
    FAISS_THROW_IF_NOT_MSG(is_trained, "Transformation not trained yet");

    float c_factor;
    if (have_bias) {
        FAISS_THROW_IF_NOT_MSG (b.size() == d_out, "Bias not initialized");
        float * xi = xt;
        for (idx_t i = 0; i < n; i++)
            for(int j = 0; j < d_out; j++)
                *xi++ = b[j];
        c_factor = 1.0;
    } else {
        c_factor = 0.0;
    }

    FAISS_THROW_IF_NOT_MSG (A.size() == d_out * d_in,
                      "Transformation matrix not initialized");

    float one = 1;
    FINTEGER nbiti = d_out, ni = n, di = d_in;
    sgemm_ ("Transposed", "Not transposed",
            &nbiti, &ni, &di,
            &one, A.data(), &di, x, &di, &c_factor, xt, &nbiti);

}


void LinearTransform::transform_transpose (idx_t n, const float * y,
                                           float *x) const
{
    if (have_bias) { // allocate buffer to store bias-corrected data
        float *y_new = new float [n * d_out];
        const float *yr = y;
        float *yw = y_new;
        for (idx_t i = 0; i < n; i++) {
            for (int j = 0; j < d_out; j++) {
                *yw++ = *yr++ - b [j];
            }
        }
        y = y_new;
    }
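
    // multiply by A^T; when A is orthonormal this inverts apply_noalloc()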

    {
        FINTEGER dii = d_in, doi = d_out, ni = n;
        float one = 1.0, zero = 0.0;
        sgemm_ ("Not", "Not", &dii, &ni, &doi,
                &one, A.data (), &dii, y, &doi, &zero, x, &dii);
    }

    if (have_bias) delete [] y;
}

void LinearTransform::set_is_orthonormal ()
{
    if (d_out > d_in) {
        // not clear what we should do in this case
        is_orthonormal = false;
        return;
    }
    if (d_out == 0) { // borderline case: an empty matrix is trivially orthonormal
        is_orthonormal = true;
        return;
    }

    double eps = 4e-5;
    FAISS_ASSERT(A.size() >= d_out * d_in);
    {
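        // compute A * A^T (in the row-major view) and compare it to the
        // identity: the rows of A must be orthonormal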
        std::vector<float> ATA(d_out * d_out);
        FINTEGER dii = d_in, doi = d_out;
        float one = 1.0, zero = 0.0;

        sgemm_ ("Transposed", "Not", &doi, &doi, &dii,
                &one, A.data (), &dii,
                A.data(), &dii,
                &zero, ATA.data(), &doi);

        is_orthonormal = true;
        for (long i = 0; i < d_out; i++) {
            for (long j = 0; j < d_out; j++) {
                float v = ATA[i + j * d_out];
                if (i == j) v-= 1;
                if (fabs(v) > eps) {
                    is_orthonormal = false;
                }
            }
        }
    }

}


void LinearTransform::reverse_transform (idx_t n, const float * xt,
                                         float *x) const
{
    if (is_orthonormal) {
        transform_transpose (n, xt, x);
    } else {
        FAISS_THROW_MSG ("reverse transform not implemented for non-orthonormal matrices");
    }
}



/*********************************************
 * RandomRotationMatrix
 *********************************************/

void RandomRotationMatrix::init (int seed)
{
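    // strategy: orthonormalize a random Gaussian matrix with a QR
    // factorization; when d_out > d_in, generate a square d_out*d_out
    // rotation and keep only the first d_in columns of each row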

    if(d_out <= d_in) {
        A.resize (d_out * d_in);
        float *q = A.data();
        float_randn(q, d_out * d_in, seed);
        matrix_qr(d_in, d_out, q);
    } else {
        A.resize (d_out * d_out);
        float *q = A.data();
        float_randn(q, d_out * d_out, seed);
        matrix_qr(d_out, d_out, q);
        // remove columns
        int i, j;
        for (i = 0; i < d_out; i++) {
            for(j = 0; j < d_in; j++) {
                q[i * d_in + j] = q[i * d_out + j];
            }
        }
        A.resize(d_in * d_out);
    }
    is_orthonormal = true;
}

/*********************************************
 * PCAMatrix
 *********************************************/

PCAMatrix::PCAMatrix (int d_in, int d_out,
                      float eigen_power, bool random_rotation):
    LinearTransform(d_in, d_out, true),
    eigen_power(eigen_power), random_rotation(random_rotation)
{
    is_trained = false;
    max_points_per_d = 1000;
    balanced_bins = 0;
}


namespace {

/// Compute the eigenvalue decomposition of symmetric matrix cov,
/// dimensions d_in-by-d_in. Output eigenvectors in cov.

void eig(size_t d_in, double *cov, double *eigenvalues, int verbose)
{
    { // compute eigenvalues and vectors
        FINTEGER info = 0, lwork = -1, di = d_in;
        double workq;
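
        // LAPACK workspace-size query: with lwork = -1, dsyev only
        // returns the optimal workspace size in workq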

        dsyev_ ("Vectors as well", "Upper",
                &di, cov, &di, eigenvalues, &workq, &lwork, &info);
        lwork = FINTEGER(workq);
        double *work = new double[lwork];

        dsyev_ ("Vectors as well", "Upper",
                &di, cov, &di, eigenvalues, work, &lwork, &info);

        delete [] work;

        if (info != 0) {
            fprintf (stderr, "WARN dsyev info returns %d, "
                     "a very bad PCA matrix is learnt\n",
                     int(info));
            // do not throw exception, as the matrix could still be useful
        }


        if(verbose && d_in <= 10) {
            printf("info=%ld new eigvals=[", long(info));
            for(int j = 0; j < d_in; j++) printf("%g ", eigenvalues[j]);
            printf("]\n");

            double *ci = cov;
            printf("eigenvecs=\n");
            for(int i = 0; i < d_in; i++) {
                for(int j = 0; j < d_in; j++)
                    printf("%10.4g ", *ci++);
                printf("\n");
            }
        }

    }

    // dsyev returns eigenvalues in ascending order: reverse the order of
    // eigenvectors & values to sort them by decreasing eigenvalue

    for(int i = 0; i < d_in / 2; i++) {

        std::swap(eigenvalues[i], eigenvalues[d_in - 1 - i]);
        double *v1 = cov + i * d_in;
        double *v2 = cov + (d_in - 1 - i) * d_in;
        for(int j = 0; j < d_in; j++)
            std::swap(v1[j], v2[j]);
    }

}


}

void PCAMatrix::train (Index::idx_t n, const float *x)
{
    const float * x_in = x;

    x = fvecs_maybe_subsample (d_in, (size_t*)&n,
                               max_points_per_d * d_in, x, verbose);

    ScopeDeleter<float> del_x (x != x_in ? x : nullptr);

    // compute mean
    mean.clear(); mean.resize(d_in, 0.0);
    if (have_bias) { // we may want to skip the bias
        const float *xi = x;
        for (idx_t i = 0; i < n; i++) {
            for(int j = 0; j < d_in; j++)
                mean[j] += *xi++;
        }
        for(int j = 0; j < d_in; j++)
            mean[j] /= n;
    }
    if(verbose) {
        printf("mean=[");
        for(int j = 0; j < d_in; j++) printf("%g ", mean[j]);
        printf("]\n");
    }

    if(n >= d_in) {
        // compute covariance matrix, store it in PCA matrix
        PCAMat.resize(d_in * d_in);
        float * cov = PCAMat.data();
        { // initialize with the -n * mean * mean^T term
            float *ci = cov;
            for(int i = 0; i < d_in; i++) {
                for(int j = 0; j < d_in; j++)
                    *ci++ = - n * mean[i] * mean[j];
            }
        }
        {
            FINTEGER di = d_in, ni = n;
            float one = 1.0;
            ssyrk_ ("Up", "Non transposed",
                    &di, &ni, &one, (float*)x, &di, &one, cov, &di);

        }
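        // ssyrk fills only the "Up" triangle of cov; this is consistent
        // with eig(), which tells dsyev to read the same "Upper" triangle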
        if(verbose && d_in <= 10) {
            float *ci = cov;
            printf("cov=\n");
            for(int i = 0; i < d_in; i++) {
                for(int j = 0; j < d_in; j++)
                    printf("%10g ", *ci++);
                printf("\n");
            }
        }

        std::vector<double> covd (d_in * d_in);
        for (size_t i = 0; i < d_in * d_in; i++) covd [i] = cov [i];

        std::vector<double> eigenvaluesd (d_in);

        eig (d_in, covd.data (), eigenvaluesd.data (), verbose);

        for (size_t i = 0; i < d_in * d_in; i++) PCAMat [i] = covd [i];
        eigenvalues.resize (d_in);

        for (size_t i = 0; i < d_in; i++)
            eigenvalues [i] = eigenvaluesd [i];


    } else {

        std::vector<float> xc (n * d_in);

        for (size_t i = 0; i < n; i++)
            for(size_t j = 0; j < d_in; j++)
                xc [i * d_in + j] = x [i * d_in + j] - mean[j];
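
        // With n < d_in the covariance has rank <= n, so we work in the
        // dual: if u is an eigenvector of the n*n Gram matrix X X^T, then
        // X^T u is an (unnormalized) eigenvector of the covariance X^T X,
        // with the same eigenvalue.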

        // compute Gram matrix
        std::vector<float> gram (n * n);
        {
            FINTEGER di = d_in, ni = n;
            float one = 1.0, zero = 0.0;
            ssyrk_ ("Up", "Transposed",
                    &ni, &di, &one, xc.data(), &di, &zero, gram.data(), &ni);
        }

        if(verbose && d_in <= 10) {
            float *ci = gram.data();
            printf("gram=\n");
            for(int i = 0; i < n; i++) {
                for(int j = 0; j < n; j++)
                    printf("%10g ", *ci++);
                printf("\n");
            }
        }

        std::vector<double> gramd (n * n);
        for (size_t i = 0; i < n * n; i++)
            gramd [i] = gram [i];

        std::vector<double> eigenvaluesd (n);

        // eig will fill in only the first n eigenvalues

        eig (n, gramd.data (), eigenvaluesd.data (), verbose);

        PCAMat.resize(d_in * n);

        for (size_t i = 0; i < n * n; i++)
            gram [i] = gramd [i];

        eigenvalues.resize (d_in);
        // only the first n are filled in
        for (size_t i = 0; i < n; i++)
            eigenvalues [i] = eigenvaluesd [i];

        { // compute PCAMat = x' * v
            FINTEGER di = d_in, ni = n;
            float one = 1.0;

            sgemm_ ("Non", "Non Trans",
                    &di, &ni, &ni,
                    &one, xc.data(), &di, gram.data(), &ni,
                    &one, PCAMat.data(), &di);
        }

        if(verbose && d_in <= 10) {
            float *ci = PCAMat.data();
            printf("PCAMat=\n");
            for(int i = 0; i < n; i++) {
                for(int j = 0; j < d_in; j++)
                    printf("%10g ", *ci++);
                printf("\n");
            }
        }
        fvec_renorm_L2 (d_in, n, PCAMat.data());

    }

    prepare_Ab();
    is_trained = true;
}

void PCAMatrix::copy_from (const PCAMatrix & other)
{
    FAISS_THROW_IF_NOT (other.is_trained);
    mean = other.mean;
    eigenvalues = other.eigenvalues;
    PCAMat = other.PCAMat;
    prepare_Ab ();
    is_trained = true;
}

void PCAMatrix::prepare_Ab ()
{
    FAISS_THROW_IF_NOT_FMT (
            d_out * d_in <= PCAMat.size(),
            "PCA matrix cannot output %d dimensions from %d ",
            d_out, d_in);

    if (!random_rotation) {
        A = PCAMat;
        A.resize(d_out * d_in); // strip off useless dimensions

        // first scale the components
        if (eigen_power != 0) {
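            // e.g. eigen_power = -0.5 whitens the components (unit
            // variance in all output directions)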
            float *ai = A.data();
            for (int i = 0; i < d_out; i++) {
                float factor = pow(eigenvalues[i], eigen_power);
                for(int j = 0; j < d_in; j++)
                    *ai++ *= factor;
            }
        }

        if (balanced_bins != 0) {
            FAISS_THROW_IF_NOT (d_out % balanced_bins == 0);
            int dsub = d_out / balanced_bins;
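            // spread the eigenvalue "energy" evenly over balanced_bins
            // consecutive blocks of dsub rows, so that each sub-quantizer
            // of a downstream PQ sees comparable variance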
            std::vector <float> Ain;
            std::swap(A, Ain);
            A.resize(d_out * d_in);

            std::vector <float> accu(balanced_bins);
            std::vector <int> counter(balanced_bins);

            // greedy assignment
            for (int i = 0; i < d_out; i++) {
                // find best bin
                int best_j = -1;
                float min_w = 1e30;
                for (int j = 0; j < balanced_bins; j++) {
                    if (counter[j] < dsub && accu[j] < min_w) {
                        min_w = accu[j];
                        best_j = j;
                    }
                }
                int row_dst = best_j * dsub + counter[best_j];
                accu[best_j] += eigenvalues[i];
                counter[best_j] ++;
                memcpy (&A[row_dst * d_in], &Ain[i * d_in],
                        d_in * sizeof (A[0]));
            }

            if (verbose) {
                printf("  bin accu=[");
                for (int i = 0; i < balanced_bins; i++)
                    printf("%g ", accu[i]);
                printf("]\n");
            }
        }


    } else {
        FAISS_THROW_IF_NOT_MSG (balanced_bins == 0,
             "both balancing bins and applying a random rotation "
             "does not make sense");
        RandomRotationMatrix rr(d_out, d_out);

        rr.init(5);

        // apply scaling on the rotation matrix (right multiplication)
        if (eigen_power != 0) {
            for (int i = 0; i < d_out; i++) {
                float factor = pow(eigenvalues[i], eigen_power);
                for(int j = 0; j < d_out; j++)
                   rr.A[j * d_out + i] *= factor;
            }
        }

        A.resize(d_in * d_out);
        {
            FINTEGER dii = d_in, doo = d_out;
            float one = 1.0, zero = 0.0;

            sgemm_ ("Not", "Not", &dii, &doo, &doo,
                    &one, PCAMat.data(), &dii, rr.A.data(), &doo, &zero,
                    A.data(), &dii);

        }

    }

    b.clear(); b.resize(d_out);

    for (int i = 0; i < d_out; i++) {
        float accu = 0;
        for (int j = 0; j < d_in; j++)
            accu -= mean[j] * A[j + i * d_in];
        b[i] = accu;
    }

    is_orthonormal = eigen_power == 0;

}

/*********************************************
 * OPQMatrix
 *********************************************/


OPQMatrix::OPQMatrix (int d, int M, int d2):
    LinearTransform (d, d2 == -1 ? d : d2, false), M(M),
    niter (50),
    niter_pq (4), niter_pq_0 (40),
    verbose(false)
{
    is_trained = false;
    // OPQ is quite expensive to train, so cap the number of training points
    max_train_points = 256 * 256;
}



void OPQMatrix::train (Index::idx_t n, const float *x)
{

    const float * x_in = x;

    x = fvecs_maybe_subsample (d_in, (size_t*)&n,
                               max_train_points, x, verbose);

    ScopeDeleter<float> del_x (x != x_in ? x : nullptr);

    // To support d_out > d_in, we pad input vectors with 0s to d_out
    size_t d = d_out <= d_in ? d_in : d_out;
    size_t d2 = d_out;

#if 0
    // what this test shows: the only way of getting bit-exact
    // reproducible results with sgeqrf and sgesvd seems to be forcing
    // single-threading.
    { // test repro
        std::vector<float> r (d * d);
        float * rotation = r.data();
        float_randn (rotation, d * d, 1234);
        printf("CS0: %016lx\n",
               ivec_checksum (128*128, (int*)rotation));
        matrix_qr (d, d, rotation);
        printf("CS1: %016lx\n",
               ivec_checksum (128*128, (int*)rotation));
        return;
    }
#endif

    if (verbose) {
        printf ("OPQMatrix::train: training an OPQ rotation matrix "
                "for M=%d from %ld vectors in %dD -> %dD\n",
                M, n, d_in, d_out);
    }

    std::vector<float> xtrain (n * d);
    // center x
    {
        std::vector<float> sum (d);
        const float *xi = x;
        for (size_t i = 0; i < n; i++) {
            for (int j = 0; j < d_in; j++)
                sum [j] += *xi++;
        }
        for (int i = 0; i < d; i++) sum[i] /= n;
        float *yi = xtrain.data();
        xi = x;
        for (size_t i = 0; i < n; i++) {
            for (int j = 0; j < d_in; j++)
                *yi++ = *xi++ - sum[j];
            yi += d - d_in;
        }
    }
    float *rotation;

    if (A.size () == 0) {
        A.resize (d * d);
        rotation = A.data();
        if (verbose)
            printf("  OPQMatrix::train: making random %ld*%ld rotation\n",
                   d, d);
        float_randn (rotation, d * d, 1234);
        matrix_qr (d, d, rotation);
        // we use only the d * d2 upper part of the matrix
        A.resize (d * d2);
    } else {
        FAISS_THROW_IF_NOT (A.size() == d * d2);
        rotation = A.data();
    }


    std::vector<float>
        xproj (d2 * n), pq_recons (d2 * n), xxr (d * n),
        tmp(d * d * 4);

    std::vector<uint8_t> codes (M * n);
    ProductQuantizer pq_regular (d2, M, 8);
    double t0 = getmillisecs();
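
    // Alternating optimization as in the OPQ paper: (1) train a PQ on the
    // rotated data, (2) update the rotation with the SVD-based orthogonal
    // Procrustes solution that best aligns the data with its PQ
    // reconstruction.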
    for (int iter = 0; iter < niter; iter++) {

        { // torch.mm(xtrain, rotation:t())
            FINTEGER di = d, d2i = d2, ni = n;
            float zero = 0, one = 1;
            sgemm_ ("Transposed", "Not transposed",
                    &d2i, &ni, &di,
                    &one, rotation, &di,
                    xtrain.data(), &di,
                    &zero, xproj.data(), &d2i);
        }

        pq_regular.cp.max_points_per_centroid = 1000;
        pq_regular.cp.niter = iter == 0 ? niter_pq_0 : niter_pq;
        pq_regular.cp.verbose = verbose;
        pq_regular.train (n, xproj.data());

        pq_regular.compute_codes (xproj.data(), codes.data(), n);
        pq_regular.decode (codes.data(), pq_recons.data(), n);

        float pq_err = fvec_L2sqr (pq_recons.data(), xproj.data(), n * d2) / n;

        if (verbose)
            printf ("    Iteration %d (%d PQ iterations):"
                    "%.3f s, obj=%g\n", iter, pq_regular.cp.niter,
                    (getmillisecs () - t0) / 1000.0, pq_err);

        {
            float *u = tmp.data(), *vt = &tmp [d * d];
            float *sing_val = &tmp [2 * d * d];
            FINTEGER di = d, d2i = d2, ni = n;
            float one = 1, zero = 0;

            // torch.mm(xtrain:t(), pq_recons)
            sgemm_ ("Not", "Transposed",
                    &d2i, &di, &ni,
                    &one, pq_recons.data(), &d2i,
                    xtrain.data(), &di,
                    &zero, xxr.data(), &d2i);


            FINTEGER lwork = -1, info = -1;
            float worksz;
            // workspace query
            sgesvd_ ("All", "All",
                     &d2i, &di, xxr.data(), &d2i,
                     sing_val,
                     vt, &d2i, u, &di,
                     &worksz, &lwork, &info);

            lwork = int(worksz);
            std::vector<float> work (lwork);
            // u and vt are swapped: LAPACK sees the column-major
            // transpose of the matrix, so its U is our V and vice versa
            sgesvd_ ("All", "All",
                     &d2i, &di, xxr.data(), &d2i,
                     sing_val,
                     vt, &d2i, u, &di,
                     work.data(), &lwork, &info);

            sgemm_ ("Transposed", "Transposed",
                    &di, &d2i, &d2i,
                    &one, u, &di, vt, &d2i,
                    &zero, rotation, &di);

        }
        pq_regular.train_type = ProductQuantizer::Train_hot_start;
    }

    // if the input was padded with zeros, keep only the first d_in columns of A
    if (d > d_in) {
        for (long i = 0; i < d_out; i++)
            memmove (&A[i * d_in], &A[i * d], sizeof(A[0]) * d_in);
        A.resize (d_in * d_out);
    }

    is_trained = true;
    is_orthonormal = true;
}


/*********************************************
 * NormalizationTransform
 *********************************************/

NormalizationTransform::NormalizationTransform (int d, float norm):
    VectorTransform (d, d), norm (norm)
{
}

NormalizationTransform::NormalizationTransform ():
    VectorTransform (-1, -1), norm (-1)
{
}

void NormalizationTransform::apply_noalloc
      (idx_t n, const float* x, float* xt) const
{
    if (norm == 2.0) {
        memcpy (xt, x, sizeof (x[0]) * n * d_in);
        fvec_renorm_L2 (d_in, n, xt);
    } else {
        FAISS_THROW_MSG ("not implemented");
    }
}
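
// Note: L2 normalization is not invertible, so the reverse transform is
// (approximately) the identity.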

void NormalizationTransform::reverse_transform (idx_t n, const float* xt,
                                                float* x) const
{
    memcpy (x, xt, sizeof (xt[0]) * n * d_in);
}

/*********************************************
 * IndexPreTransform
 *********************************************/
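
// Typical usage (sketch, assuming an IndexFlatL2 from IndexFlat.h):
//
//   faiss::PCAMatrix pca (128, 64);
//   faiss::IndexFlatL2 sub_index (64);
//   faiss::IndexPreTransform index (&pca, &sub_index);
//   index.train (n, xb);   // trains the PCA, then the sub-index
//   index.add (n, xb);     // vectors are transformed before being added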

IndexPreTransform::IndexPreTransform ():
    index(nullptr), own_fields (false)
{
}


IndexPreTransform::IndexPreTransform (
        Index * index):
    Index (index->d, index->metric_type),
    index (index), own_fields (false)
{
    is_trained = index->is_trained;
}


IndexPreTransform::IndexPreTransform (
        VectorTransform * ltrans,
        Index * index):
    Index (index->d, index->metric_type),
    index (index), own_fields (false)
{
    is_trained = index->is_trained;
    prepend_transform (ltrans);
}

void IndexPreTransform::prepend_transform (VectorTransform *ltrans)
{
    FAISS_THROW_IF_NOT (ltrans->d_out == d);
    is_trained = is_trained && ltrans->is_trained;
    chain.insert (chain.begin(), ltrans);
    d = ltrans->d_in;
}


IndexPreTransform::~IndexPreTransform ()
{
    if (own_fields) {
        for (int i = 0; i < chain.size(); i++)
            delete chain[i];
        delete index;
    }
}




void IndexPreTransform::train (idx_t n, const float *x)
{
    int last_untrained = 0;
    if (!index->is_trained) {
        last_untrained = chain.size();
    } else {
        for (int i = chain.size() - 1; i >= 0; i--) {
            if (!chain[i]->is_trained) {
                last_untrained = i;
                break;
            }
        }
    }
    const float *prev_x = x;
    ScopeDeleter<float> del;

    if (verbose) {
        printf("IndexPreTransform::train: training chain 0 to %d\n",
               last_untrained);
    }

    for (int i = 0; i <= last_untrained; i++) {

        if (i < chain.size()) {
            VectorTransform *ltrans = chain [i];
            if (!ltrans->is_trained) {
                if (verbose) {
                    printf("   Training chain component %d/%zd\n",
                           i, chain.size());
                    if (OPQMatrix *opqm = dynamic_cast<OPQMatrix*>(ltrans)) {
                        opqm->verbose = true;
                    }
                }
                ltrans->train (n, prev_x);
            }
        } else {
            if (verbose) {
                printf("   Training sub-index\n");
            }
            index->train (n, prev_x);
        }
        if (i == last_untrained) break;
        if (verbose) {
            printf("   Applying transform %d/%zd\n",
                   i, chain.size());
        }

        float * xt = chain[i]->apply (n, prev_x);

        if (prev_x != x) delete [] prev_x;
        prev_x = xt;
        del.set(xt);
    }

    is_trained = true;
}


const float *IndexPreTransform::apply_chain (idx_t n, const float *x) const
{
    const float *prev_x = x;
    ScopeDeleter<float> del;

    for (int i = 0; i < chain.size(); i++) {
        float * xt = chain[i]->apply (n, prev_x);
        ScopeDeleter<float> del2 (xt);
        del2.swap (del);
        prev_x = xt;
    }
    del.release ();
    return prev_x;
}

void IndexPreTransform::reverse_chain (idx_t n, const float* xt, float* x) const
{
    const float* next_x = xt;
    ScopeDeleter<float> del;

    for (int i = chain.size() - 1; i >= 0; i--) {
        float* prev_x = (i == 0) ? x : new float [n * chain[i]->d_in];
        ScopeDeleter<float> del2 ((prev_x == x) ? nullptr : prev_x);
        chain [i]->reverse_transform (n, next_x, prev_x);
        del2.swap (del);
        next_x = prev_x;
    }
}

void IndexPreTransform::add (idx_t n, const float *x)
{
    FAISS_THROW_IF_NOT (is_trained);
    const float *xt = apply_chain (n, x);
    ScopeDeleter<float> del(xt == x ? nullptr : xt);
    index->add (n, xt);
    ntotal = index->ntotal;
}

void IndexPreTransform::add_with_ids (idx_t n, const float * x,
                                      const long *xids)
{
    FAISS_THROW_IF_NOT (is_trained);
    const float *xt = apply_chain (n, x);
    ScopeDeleter<float> del(xt == x ? nullptr : xt);
    index->add_with_ids (n, xt, xids);
    ntotal = index->ntotal;
}




void IndexPreTransform::search (idx_t n, const float *x, idx_t k,
                               float *distances, idx_t *labels) const
{
    FAISS_THROW_IF_NOT (is_trained);
    const float *xt = apply_chain (n, x);
    ScopeDeleter<float> del(xt == x ? nullptr : xt);
    index->search (n, xt, k, distances, labels);
}


void IndexPreTransform::reset () {
    index->reset();
    ntotal = 0;
}

long IndexPreTransform::remove_ids (const IDSelector & sel) {
    long nremove = index->remove_ids (sel);
    ntotal = index->ntotal;
    return nremove;
}


void IndexPreTransform::reconstruct (idx_t key, float * recons) const
{
    float *x = chain.empty() ? recons : new float [index->d];
    ScopeDeleter<float> del (recons == x ? nullptr : x);
    // Initial reconstruction
    index->reconstruct (key, x);

    // Revert transformations from last to first
    reverse_chain (1, x, recons);
}


void IndexPreTransform::reconstruct_n (idx_t i0, idx_t ni, float *recons) const
{
    float *x = chain.empty() ? recons : new float [ni * index->d];
    ScopeDeleter<float> del (recons == x ? nullptr : x);
    // Initial reconstruction
    index->reconstruct_n (i0, ni, x);

    // Revert transformations from last to first
    reverse_chain (ni, x, recons);
}


void IndexPreTransform::search_and_reconstruct (
      idx_t n, const float *x, idx_t k,
      float *distances, idx_t *labels, float* recons) const
{
    FAISS_THROW_IF_NOT (is_trained);

    const float* xt = apply_chain (n, x);
    ScopeDeleter<float> del ((xt == x) ? nullptr : xt);

    float* recons_temp = chain.empty() ? recons : new float [n * k * index->d];
    ScopeDeleter<float> del2 ((recons_temp == recons) ? nullptr : recons_temp);
    index->search_and_reconstruct (n, xt, k, distances, labels, recons_temp);

    // Revert transformations from last to first
    reverse_chain (n * k, recons_temp, recons);
}


/*********************************************
 * RemapDimensionsTransform
 *********************************************/


RemapDimensionsTransform::RemapDimensionsTransform (
        int d_in, int d_out, const int *map_in):
    VectorTransform (d_in, d_out)
{
    map.resize (d_out);
    for (int i = 0; i < d_out; i++) {
        map[i] = map_in[i];
        FAISS_THROW_IF_NOT (map[i] == -1 || (map[i] >= 0 && map[i] < d_in));
    }
}

RemapDimensionsTransform::RemapDimensionsTransform (
      int d_in, int d_out, bool uniform): VectorTransform (d_in, d_out)
{
    map.resize (d_out, -1);
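
    // uniform example: d_in=4, d_out=2 keeps input dims {0, 2};
    // d_in=2, d_out=4 maps them to output dims {0, 2}, the rest stay 0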

    if (uniform) {
        if (d_in < d_out) {
            for (int i = 0; i < d_in; i++) {
                map [i * d_out / d_in] = i;
            }
        } else {
            for (int i = 0; i < d_out; i++) {
                map [i] = i * d_in / d_out;
            }
        }
    } else {
        for (int i = 0; i < d_in && i < d_out; i++)
            map [i] = i;
    }
}


void RemapDimensionsTransform::apply_noalloc (idx_t n, const float * x,
                                              float *xt) const
{
    for (idx_t i = 0; i < n; i++) {
        for (int j = 0; j < d_out; j++) {
            xt[j] = map[j] < 0 ? 0 : x[map[j]];
        }
        x += d_in;
        xt += d_out;
    }
}

void RemapDimensionsTransform::reverse_transform (idx_t n, const float * xt,
                                                  float *x) const
{
    memset (x, 0, sizeof (*x) * n * d_in);
    for (idx_t i = 0; i < n; i++) {
        for (int j = 0; j < d_out; j++) {
            if (map[j] >= 0) x[map[j]] = xt[j];
        }
        x += d_in;
        xt += d_out;
    }
}