diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0d540d8aa..b452af6ba 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,8 @@ We try to indicate most contributions here with the contributor names who are not of
 the Facebook Faiss team. Feel free to add entries here if you submit a PR.
 
 ## [Unreleased]
+### Added
+- Support LSQ on GPU (by @KinglittleQ)
 
 ## [1.7.1] - 2021-05-27
 ### Added
diff --git a/benchs/bench_quantizer.py b/benchs/bench_quantizer.py
index 0b3aaf100..54c710ada 100644
--- a/benchs/bench_quantizer.py
+++ b/benchs/bench_quantizer.py
@@ -82,6 +82,13 @@ nt, d = xt.shape
 
 # fastest to slowest
 
+if 'lsq-gpu' in todo:
+    lsq = faiss.LocalSearchQuantizer(d, M, nbits)
+    ngpus = faiss.get_num_gpus()
+    lsq.icm_encoder_factory = faiss.GpuIcmEncoderFactory(ngpus)
+    lsq.verbose = True
+    eval_quantizer(lsq, xb, xt, 'lsq-gpu')
+
 if 'pq' in todo:
     pq = faiss.ProductQuantizer(d, M, nbits)
     print("===== PQ")
diff --git a/faiss/gpu/CMakeLists.txt b/faiss/gpu/CMakeLists.txt
index d8e22fa75..232590b1e 100644
--- a/faiss/gpu/CMakeLists.txt
+++ b/faiss/gpu/CMakeLists.txt
@@ -9,6 +9,7 @@ set(FAISS_GPU_SRC
   GpuCloner.cpp
   GpuClonerOptions.cpp
   GpuDistance.cu
+  GpuIcmEncoder.cu
   GpuIndex.cu
   GpuIndexBinaryFlat.cu
   GpuIndexFlat.cu
@@ -46,6 +47,7 @@ set(FAISS_GPU_SRC
   impl/scan/IVFInterleaved512.cu
   impl/scan/IVFInterleaved1024.cu
   impl/scan/IVFInterleaved2048.cu
+  impl/IcmEncoder.cu
   utils/BlockSelectFloat.cu
   utils/DeviceUtils.cu
   utils/StackDeviceMemory.cpp
@@ -80,6 +82,7 @@ set(FAISS_GPU_HEADERS
   GpuCloner.h
   GpuClonerOptions.h
   GpuDistance.h
+  GpuIcmEncoder.h
   GpuFaissAssert.h
   GpuIndex.h
   GpuIndexBinaryFlat.h
@@ -118,6 +121,7 @@ set(FAISS_GPU_HEADERS
   impl/RemapIndices.h
   impl/VectorResidual.cuh
   impl/scan/IVFInterleavedImpl.cuh
+  impl/IcmEncoder.cuh
   utils/BlockSelectKernel.cuh
   utils/Comparators.cuh
   utils/ConversionOperators.cuh
diff --git a/faiss/gpu/GpuIcmEncoder.cu b/faiss/gpu/GpuIcmEncoder.cu
new file mode 100644
index 000000000..434fae9e3
--- /dev/null
+++ b/faiss/gpu/GpuIcmEncoder.cu
@@ -0,0 +1,120 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <faiss/gpu/GpuIcmEncoder.h>
+
+#include <faiss/gpu/StandardGpuResources.h>
+#include <faiss/gpu/impl/IcmEncoder.cuh>
+#include <faiss/gpu/utils/DeviceUtils.h>
+
+#include <faiss/utils/WorkerThread.h>
+
+namespace faiss {
+namespace gpu {
+
+/// A helper structure to support multi-GPU
+struct IcmEncoderShards {
+    std::vector<std::pair<
+            std::unique_ptr<IcmEncoderImpl>,
+            std::unique_ptr<WorkerThread>>>
+            workers;
+
+    void add(IcmEncoderImpl* encoder) {
+        workers.emplace_back(std::make_pair(
+                std::unique_ptr<IcmEncoderImpl>(encoder),
+                std::unique_ptr<WorkerThread>(new WorkerThread)));
+    }
+
+    IcmEncoderImpl* at(int idx) {
+        return workers[idx].first.get();
+    }
+
+    /// call f(idx, encoder) for each encoder
+    void runOnShards(std::function<void(int, IcmEncoderImpl*)> f) {
+        std::vector<std::future<bool>> v;
+
+        for (int i = 0; i < this->workers.size(); ++i) {
+            auto& p = this->workers[i];
+            auto encoder = p.first.get();
+            v.emplace_back(p.second->add([f, i, encoder]() { f(i, encoder); }));
+        }
+
+        for (int i = 0; i < v.size(); ++i) {
+            auto& fut = v[i];
+            fut.get(); // no exception handling; crash if any thread goes down
+        }
+    }
+
+    size_t size() {
+        return workers.size();
+    }
+};
+
+GpuIcmEncoder::GpuIcmEncoder(
+        const LocalSearchQuantizer* lsq,
+        const std::vector<GpuResourcesProvider*>& provs,
+        const std::vector<int>& devices)
+        : lsq::IcmEncoder(lsq), shards(new IcmEncoderShards()) {
+    // create an IcmEncoderImpl instance for each device
+    for (size_t i = 0; i < provs.size(); i++) {
+        shards->add(new IcmEncoderImpl(
+                lsq->M, lsq->K, lsq->d, provs[i], devices[i]));
+    }
+}
+
+GpuIcmEncoder::~GpuIcmEncoder() {}
+
+void GpuIcmEncoder::set_binary_term() {
+    auto fn = [=](int idx, IcmEncoderImpl* encoder) {
+        encoder->setBinaryTerm(lsq->codebooks.data());
+    };
+    shards->runOnShards(fn);
+}
+
+void GpuIcmEncoder::encode(
+        int32_t* codes,
+        const float* x,
+        std::mt19937& gen,
+        size_t n,
+        size_t ils_iters) const {
+    size_t nshards = shards->size();
+    size_t shard_size = (n + nshards - 1) / nshards;
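+    // e.g. encoding n = 1000 vectors on 3 shards gives shard_size = 334;
+    // the shards then handle 334, 334 and 332 vectors respectively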
+
+    auto codebooks = lsq->codebooks.data();
+    auto M = lsq->M;
+    auto d = lsq->d;
+    auto nperts = lsq->nperts;
+    auto icm_iters = lsq->icm_iters;
+
+    auto seed = gen();
+
+    // split input data
+    auto fn = [=](int idx, IcmEncoderImpl* encoder) {
+        size_t i0 = idx * shard_size;
+        size_t ni = std::min(shard_size, n - i0);
+        auto xi = x + i0 * d;
+        auto ci = codes + i0 * M;
+        std::mt19937 geni(idx + seed); // different seed for each shard
+        encoder->encode(
+                ci, xi, codebooks, geni, ni, nperts, ils_iters, icm_iters);
+    };
+    shards->runOnShards(fn);
+}
+
+GpuIcmEncoderFactory::GpuIcmEncoderFactory(int ngpus) {
+    for (int i = 0; i < ngpus; i++) {
+        provs.push_back(new StandardGpuResources());
+        devices.push_back(i);
+    }
+}
+
+lsq::IcmEncoder* GpuIcmEncoderFactory::get(const LocalSearchQuantizer* lsq) {
+    return new GpuIcmEncoder(lsq, provs, devices);
+}
+
+} // namespace gpu
+} // namespace faiss
diff --git a/faiss/gpu/GpuIcmEncoder.h b/faiss/gpu/GpuIcmEncoder.h
new file mode 100644
index 000000000..463027bb6
--- /dev/null
+++ b/faiss/gpu/GpuIcmEncoder.h
@@ -0,0 +1,60 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <faiss/impl/LocalSearchQuantizer.h>
+
+#include <memory>
+
+namespace faiss {
+namespace gpu {
+
+class GpuResourcesProvider;
+struct IcmEncoderShards;
+
+/** Perform LSQ encoding on GPU.
+ *
+ * Splits the input vectors across devices and calls IcmEncoderImpl::encode
+ * to encode them.
+ */
+class GpuIcmEncoder : public lsq::IcmEncoder {
+   public:
+    GpuIcmEncoder(
+            const LocalSearchQuantizer* lsq,
+            const std::vector<GpuResourcesProvider*>& provs,
+            const std::vector<int>& devices);
+
+    ~GpuIcmEncoder();
+
+    GpuIcmEncoder(const GpuIcmEncoder&) = delete;
+    GpuIcmEncoder& operator=(const GpuIcmEncoder&) = delete;
+
+    void set_binary_term() override;
+
+    void encode(
+            int32_t* codes,
+            const float* x,
+            std::mt19937& gen,
+            size_t n,
+            size_t ils_iters) const override;
+
+   private:
+    std::unique_ptr<IcmEncoderShards> shards;
+};
+
+struct GpuIcmEncoderFactory : public lsq::IcmEncoderFactory {
+    explicit GpuIcmEncoderFactory(int ngpus = 1);
+
+    lsq::IcmEncoder* get(const LocalSearchQuantizer* lsq) override;
+
+    std::vector<GpuResourcesProvider*> provs;
+    std::vector<int> devices;
+};
+
+} // namespace gpu
+} // namespace faiss
diff --git a/faiss/gpu/impl/IcmEncoder.cu b/faiss/gpu/impl/IcmEncoder.cu
new file mode 100644
index 000000000..81f234f2b
--- /dev/null
+++ b/faiss/gpu/impl/IcmEncoder.cu
@@ -0,0 +1,379 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <faiss/gpu/impl/IcmEncoder.cuh>
+
+#include <faiss/gpu/GpuResources.h>
+#include <faiss/gpu/utils/CopyUtils.cuh>
+#include <faiss/gpu/utils/DeviceDefs.cuh>
+#include <faiss/gpu/utils/DeviceTensor.cuh>
+#include <faiss/gpu/utils/DeviceUtils.h>
+#include <faiss/gpu/utils/L2Norm.cuh>
+#include <faiss/gpu/utils/MatrixMult.cuh>
+#include <faiss/gpu/utils/Pair.cuh>
+#include <faiss/gpu/utils/Reductions.cuh>
+
+#include <curand_kernel.h>
+
+namespace faiss {
+namespace gpu {
+
+extern __shared__ char smem[];
+
+/** encode using iterative conditional mode
+ *
+ * For subcode cm of a vector, we fix the other subcodes cj (j != m)
+ * and then find the value of cm (cm = 1,...,K) that minimizes
+ * the objective function.
+ *
+ * @param uterm precomputed unary terms, size (M, n, K)
+ * @param bterm precomputed binary terms, size (M1, M2, K1, K2)
+ * @param codes output vector encodings, size (n, M)
+ * @param M     number of codebooks
+ * @param K     number of codewords in a codebook
+ * @param m     index of the subcode to condition on
+ */
+__global__ void runIcmEncodeStep(
+        const float* uterm,
+        const float* bterm,
+        int32_t* codes,
+        int M,
+        int K,
+        int m) {
+    using KVPair = Pair<float, int>;
+
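+    // in each KVPair, k holds the objective value and v the candidate code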
+    int id = blockIdx.x; // each block takes care of one vector
+    int code = threadIdx.x; // each thread takes care of one possible code
+
+    // compute the objective value by look-up tables
+    KVPair obj(0.0f, code);
+    obj.k = uterm[id * K + code];
+
+#pragma unroll
+    for (int m2 = 0; m2 < M; m2++) {
+        if (m2 == m) {
+            continue;
+        }
+        int32_t code2 = codes[id * M + m2];
+        obj.k += bterm[m2 * K * K + code * K + code2];
+    }
+
+    // find the minimum objective value and the corresponding code
+    __syncthreads();
+    obj = blockReduceAll<KVPair, Min<KVPair>, false, false>(
+            obj, Min<KVPair>(), (KVPair*)smem);
+
+    if (code == 0) {
+        codes[id * M + m] = obj.v;
+    }
+}
+
+/** compute reconstruction error for each vector
+ *
+ * decoded_x[i] = \sum codebooks[m][codes[i][m]], m = 1,..,M
+ * obj[i] = ||x[i] - decoded_x[i]||^2
+ *
+ * @param x         input vectors, size [n, dims]
+ * @param codebooks codebooks, size [M, K, dims]
+ * @param codes     vector codes, size [n, M]
+ * @param obj       output reconstruction errors, size [n]
+ * @param n         number of input vectors
+ * @param K         number of codewords in a codebook
+ * @param M         number of codebooks
+ */
+__global__ void runEvaluation(
+        const float* x,
+        const float* codebooks,
+        const int32_t* codes,
+        float* obj, // output
+        int n,
+        int M,
+        int K,
+        int dims) {
+    int id = blockIdx.x; // each block takes care of one vector
+    int d = threadIdx.x; // each thread takes care of one dimension
+    float acc = 0.0f;
+
+#pragma unroll
+    for (int m = 0; m < M; m++) {
+        int32_t code = codes[id * M + m];
+        acc += codebooks[m * K * dims + code * dims + d];
+    }
+
+    acc -= x[id * dims + d];
+    acc = acc * acc;
+
+    // sum values of all dimensions together
+    __syncthreads();
+    acc = blockReduceAllSum<float, false, false>(acc, (float*)smem);
+
+    if (d == 0) {
+        obj[id] = acc;
+    }
+}
+
+/** perturb vector codes
+ *
+ * repeat nperts times:
+ *     codes[i][randint(0, M)] = randint(0, K)
+ *
+ * @param seed   random seed
+ * @param codes  vector codes, size [n, M]
+ * @param n      number of input vectors
+ * @param M      number of codebooks
+ * @param K      number of codewords in a codebook
+ * @param nperts number of subcodes to perturb in each vector
+ */
+__global__ void runCodesPerturbation(
+        int seed,
+        int32_t* codes,
+        int n,
+        int M,
+        int K,
+        int nperts) {
+    // each thread takes care of one vector
+    int id = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (id >= n) {
+        return;
+    }
+
+    // we have to initialize the state
+    curandState_t state;
+    curand_init(seed, id, 0, &state);
+
+    for (int i = 0; i < nperts; i++) {
+        int pos = int(curand_uniform(&state) * M);
+        int32_t val = int32_t(curand_uniform(&state) * K);
+        codes[id * M + pos] = val;
+    }
+}
+
+/** select the best codes by reconstruction errors
+ *
+ * if objs[i] < best_objs[i]:
+ *     best_objs[i] = objs[i]
+ *     best_codes[i] = codes[i]
+ *
+ * @param bestCodes the best codes we've encountered, size [n, M]
+ * @param bestObjs  min reconstruction errors we've encountered, size [n]
+ * @param codes     input vector codes, size [n, M]
+ * @param objs      reconstruction errors of input vector codes, size [n]
+ * @param n         number of input vectors
+ */
+__global__ void runCodesSelection(
+        int32_t* bestCodes,
+        float* bestObjs,
+        const int32_t* codes,
+        const float* objs,
+        int n,
+        int M) {
+    // each thread takes care of one vector
+    int id = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (id >= n || objs[id] >= bestObjs[id]) {
+        return;
+    }
+
+    bestObjs[id] = objs[id];
+#pragma unroll
+    for (int m = 0; m < M; m++) {
+        bestCodes[id * M + m] = codes[id * M + m];
+    }
+}
+
+/** add L2 norm of codewords in a codebook to the unary terms
+ *
+ * uterm[i][k] += norm[k]
+ *
+ * @param uterm unary terms, size [n, K]
+ * @param norm  L2 norm of each codeword in a codebook, size [K]
+ * @param K     number of codewords in a codebook
+ */
+__global__ void runNormAddition(float* uterm, const float* norm, int K) {
+    int id = blockIdx.x;
+    int code = threadIdx.x;
+
+    uterm[id * K + code] += norm[code];
+}
+
+IcmEncoderImpl::IcmEncoderImpl(
+        int M,
+        int K,
+        int dims,
+        GpuResourcesProvider* prov,
+        int device)
+        : M(M), K(K), dims(dims), prov(prov), device(device) {
+    res = prov->getResources();
+}
+
+void IcmEncoderImpl::computeUnaryTerms(
+        float* uterm, // output, [M, n, K]
+        const float* x, // [n, d]
+        const float* codebooks, // [M, K, d]
+        int n) const {
+    auto stream = res->getDefaultStreamCurrentDevice();
+    auto handle = res->getBlasHandleCurrentDevice();
+
+    DeviceTensor<float, 2, true> vecs(const_cast<float*>(x), {n, dims});
+    for (int m = 0; m < M; m++) {
+        auto cPtr = const_cast<float*>(codebooks + m * K * dims);
+        auto bPtr = uterm + m * n * K;
+        DeviceTensor<float, 2, true> ci(cPtr, {K, dims});
+        DeviceTensor<float, 2, true> bi(bPtr, {n, K});
+        runMatrixMult(
+                bi, false, vecs, false, ci, true, -2.0f, 0.0f, handle, stream);
+    }
+
+    DeviceTensor<float, 2, true> c(
+            const_cast<float*>(codebooks), {M * K, dims});
+    DeviceTensor<float, 1, true> norm(
+            res.get(), makeTempAlloc(AllocType::Other, stream), {M * K});
+    runL2Norm(c, true, norm, true, stream);
+
+    for (int m = 0; m < M; m++) {
+        auto uPtr = uterm + m * n * K;
+        auto nPtr = norm.data() + m * K;
+        runNormAddition<<<n, K, 0, stream>>>(uPtr, nPtr, K);
+    }
+}
+
+void IcmEncoderImpl::computeBinaryTerms(float* bterm, const float* codebooks)
+        const {
+    auto stream = res->getDefaultStreamCurrentDevice();
+    auto handle = res->getBlasHandleCurrentDevice();
+
+    for (int m1 = 0; m1 < M; m1++) {
+        for (int m2 = 0; m2 < M; m2++) {
+            auto ptr1 = const_cast<float*>(codebooks + m1 * K * dims);
+            auto ptr2 = const_cast<float*>(codebooks + m2 * K * dims);
+            auto ptr3 = bterm + m1 * M * K * K + m2 * K * K;
+            DeviceTensor<float, 2, true> c1(ptr1, {K, dims});
+            DeviceTensor<float, 2, true> c2(ptr2, {K, dims});
+            DeviceTensor<float, 2, true> b(ptr3, {K, K});
+            runMatrixMult(
+                    b, false, c1, false, c2, true, 2.0f, 0.0f, handle, stream);
+        }
+    }
+}
+
+void IcmEncoderImpl::setBinaryTerm(const float* codebooksHost) {
+    DeviceScope scope(device);
+    auto device = getCurrentDevice();
+    auto stream = res->getDefaultStreamCurrentDevice();
+
+    // copy from host to device memory
+    codebooks = toDeviceNonTemporary<float, 3>(
+            res.get(),
+            device,
+            const_cast<float*>(codebooksHost),
+            stream,
+            {M, K, dims});
+    bterm = DeviceTensor<float, 4, true>(
+            res.get(), makeDevAlloc(AllocType::Other, stream), {M, M, K, K});
+    computeBinaryTerms(bterm.data(), codebooks.data());
+}
+
+void IcmEncoderImpl::encode(
+        int32_t* codesHost,
+        const float* xHost,
+        const float* codebooksHost,
+        std::mt19937& gen,
+        int n,
+        int nperts,
+        int ilsIters,
+        int icmIters) const {
+    DeviceScope scope(device);
+    auto device = getCurrentDevice();
+    auto stream = res->getDefaultStreamCurrentDevice();
+
+    // copy from host to device memory
+    auto codes = toDeviceTemporary<int32_t, 2>(
+            res.get(), device, const_cast<int32_t*>(codesHost), stream, {n, M});
+    auto x = toDeviceTemporary<float, 2>(
+            res.get(), device, const_cast<float*>(xHost), stream, {n, dims});
+
+    // compute unary terms
+    DeviceTensor<float, 3, true> uterm(
+            res.get(), makeTempAlloc(AllocType::Other, stream), {M, n, K});
+    computeUnaryTerms(uterm.data(), x.data(), codebooks.data(), n);
+
+    DeviceTensor<int32_t, 2, true> bestCodes(
+            res.get(), makeTempAlloc(AllocType::Other, stream), {n, M});
+    fromDevice(codes, bestCodes.data(), stream);
+
+    DeviceTensor<float, 1, true> bestObjs(
+            res.get(), makeTempAlloc(AllocType::Other, stream), {n});
+
+    DeviceTensor<float, 1, true> objs(
+            res.get(), makeTempAlloc(AllocType::Other, stream), {n});
+
+    // compute how much shared memory we need
+    const int evaluateSmem = sizeof(float) * (dims + kWarpSize - 1) / kWarpSize;
+    const int encodeSmem =
+            sizeof(Pair<float, int>) * (K + kWarpSize - 1) / kWarpSize;
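+    // the block-wide reductions stage one partial value per warp in shared
+    // memory, hence the sizes are rounded up in units of kWarpSize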
+
+    // compute the reconstruction error for each vector
+    runEvaluation<<<n, dims, evaluateSmem, stream>>>(
+            x.data(),
+            codebooks.data(),
+            codes.data(),
+            bestObjs.data(),
+            n,
+            M,
+            K,
+            dims);
+
+    int blockSize = 256;
+    int numBlocks = (n + blockSize - 1) / blockSize;
+
+    for (int i = 0; i < ilsIters; i++) {
+        runCodesPerturbation<<<numBlocks, blockSize, 0, stream>>>(
+                gen(), codes.data(), n, M, K, nperts);
+
+        // perform icm encoding
+        for (int j = 0; j < icmIters; j++) {
+            for (int m = 0; m < M; m++) {
+                runIcmEncodeStep<<<n, K, encodeSmem, stream>>>(
+                        uterm[m].data(),
+                        bterm[m].data(),
+                        codes.data(),
+                        M,
+                        K,
+                        m);
+            }
+        }
+
+        // compute the reconstruction error for each vector given codes
+        runEvaluation<<<n, dims, evaluateSmem, stream>>>(
+                x.data(),
+                codebooks.data(),
+                codes.data(),
+                objs.data(),
+                n,
+                M,
+                K,
+                dims);
+
+        // if objs[i] < best_objs[i], replace best_codes[i] with codes[i]
+        runCodesSelection<<<numBlocks, blockSize, 0, stream>>>(
+                bestCodes.data(),
+                bestObjs.data(),
+                codes.data(),
+                objs.data(),
+                n,
+                M);
+
+        codes.copyFrom(bestCodes, stream);
+    }
+
+    // copy back to host memory
+    fromDevice(bestCodes, codesHost, stream);
+}
+
+} // namespace gpu
+} // namespace faiss
diff --git a/faiss/gpu/impl/IcmEncoder.cuh b/faiss/gpu/impl/IcmEncoder.cuh
new file mode 100644
index 000000000..9577aae8b
--- /dev/null
+++ b/faiss/gpu/impl/IcmEncoder.cuh
@@ -0,0 +1,79 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <faiss/gpu/GpuResources.h>
+#include <faiss/gpu/utils/DeviceTensor.cuh>
+
+#include <random>
+
+namespace faiss {
+namespace gpu {
+
+struct IcmEncoderImpl {
+    int M;    ///< number of codebooks
+    int K;    ///< number of codewords in a codebook
+    int dims; ///< dimensions of a codeword
+
+    GpuResourcesProvider* prov;
+    std::shared_ptr<GpuResources> res;
+    int device;
+
+    DeviceTensor<float, 4, true> bterm;     ///< binary terms, size [M, M, K, K]
+    DeviceTensor<float, 3, true> codebooks; ///< codebooks, size [M, K, dims]
+
+    IcmEncoderImpl(
+            int M,
+            int K,
+            int dims,
+            GpuResourcesProvider* prov,
+            int device);
+
+    ~IcmEncoderImpl() {}
+
+    /// copy codebooks to device memory and compute binary terms
+    void setBinaryTerm(const float* codebooks);
+
+    /** Compute unary terms.
+     *
+     * uterm[i] = x * codebook[i]^T, i = 1,...,M
+     *
+     * @param uterm     output unary terms, size [M, n, K]
+     * @param x         input vectors, size [n, dims]
+     * @param codebooks codebooks, size [M, K, dims]
+     * @param n         number of input vectors
+     */
+    void computeUnaryTerms(
+            float* uterm,
+            const float* x,
+            const float* codebooks,
+            int n) const;
+
+    /** Compute binary terms.
+     *
+     * bterm[i][j] = codebooks[i] * codebooks[j]^T, i, j = 1,...,M
+     *
+     * @param bterm     output binary terms, size [M, M, K, K]
+     * @param codebooks codebooks, size [M, K, dims]
+     */
+    void computeBinaryTerms(float* bterm, const float* codebooks) const;
+
+    /// ICM encode method
+    void encode(
+            int32_t* codes,
+            const float* x,
+            const float* codebooks,
+            std::mt19937& gen,
+            int n,
+            int nperts,
+            int ilsIters,
+            int icmIters) const;
+};
+
+} // namespace gpu
+} // namespace faiss
diff --git a/faiss/gpu/test/test_gpu_index.py b/faiss/gpu/test/test_gpu_index.py
index 8fe898b92..937ef0660 100755
--- a/faiss/gpu/test/test_gpu_index.py
+++ b/faiss/gpu/test/test_gpu_index.py
@@ -10,6 +10,7 @@ import time
 import unittest
 import numpy as np
 import faiss
+from faiss.contrib import datasets
 
 
 class EvalIVFPQAccuracy(unittest.TestCase):
@@ -462,5 +463,44 @@ class TestInvalidParams(unittest.TestCase):
         self.assertTrue(np.array_equal(xb_indices[10:20], I[:, 0]))
 
 
+class TestLSQIcmEncoder(unittest.TestCase):
+
+    @staticmethod
+    def eval_codec(q, xb):
+        codes = q.compute_codes(xb)
+        decoded = q.decode(codes)
+        return ((xb - decoded) ** 2).sum()
+
+    def subtest_gpu_encoding(self, ngpus):
+        """check that the error is about the same as on CPU."""
+        ds = datasets.SyntheticDataset(32, 1000, 1000, 0)
+
+        xt = ds.get_train()
+        xb = ds.get_database()
+
+        M = 4
+        nbits = 8
+
+        lsq = faiss.LocalSearchQuantizer(ds.d, M, nbits)
+        lsq.train(xt)
+        err_cpu = self.eval_codec(lsq, xb)
+
+        lsq = faiss.LocalSearchQuantizer(ds.d, M, nbits)
+        lsq.train(xt)
+        lsq.icm_encoder_factory = faiss.GpuIcmEncoderFactory(ngpus)
+        err_gpu = self.eval_codec(lsq, xb)
+
+        # 13804.411 vs 13814.794, 1 gpu
+        print(err_gpu, err_cpu)
+        self.assertLess(err_gpu, err_cpu * 1.05)
+
+    def test_one_gpu(self):
+        self.subtest_gpu_encoding(1)
+
+    def test_multiple_gpu(self):
+        ngpu = faiss.get_num_gpus()
+        self.subtest_gpu_encoding(ngpu)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/faiss/impl/LocalSearchQuantizer.cpp b/faiss/impl/LocalSearchQuantizer.cpp
index b7e4c0f9d..b8163e836 100644
--- a/faiss/impl/LocalSearchQuantizer.cpp
+++ b/faiss/impl/LocalSearchQuantizer.cpp
@@ -141,7 +141,8 @@ void random_int32(
 
 namespace faiss {
 
-LSQTimer lsq_timer;
+lsq::LSQTimer lsq_timer;
+using lsq::LSQTimerScope;
 
 LocalSearchQuantizer::LocalSearchQuantizer(
         size_t d,
@@ -168,6 +169,12 @@ LocalSearchQuantizer::LocalSearchQuantizer(
 
     random_seed = 0x12345;
     std::srand(random_seed);
+
+    icm_encoder_factory = nullptr;
+}
+
+LocalSearchQuantizer::~LocalSearchQuantizer() {
+    delete icm_encoder_factory;
 }
 
 LocalSearchQuantizer::LocalSearchQuantizer() : LocalSearchQuantizer(0, 0, 0) {}
@@ -177,8 +184,8 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
     FAISS_THROW_IF_NOT(nperts <= M);
 
     lsq_timer.reset();
+    LSQTimerScope scope(&lsq_timer, "train");
     if (verbose) {
-        lsq_timer.start("train");
         printf("Training LSQ, with %zd subcodes on %zd %zdD vectors\n",
                M,
                n,
@@ -241,7 +248,7 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
         }
 
         // refine codes
-        icm_encode(x, codes.data(), n, train_ils_iters, gen);
+        icm_encode(codes.data(), x, n, train_ils_iters, gen);
 
         if (verbose) {
             float obj = evaluate(codes.data(), x, n);
@@ -269,13 +276,13 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
     }
 
     if (verbose) {
-        lsq_timer.end("train");
         float obj = evaluate(codes.data(), x, n);
+        scope.finish();
         printf("After training: obj = %lf\n", obj);
 
         printf("Time statistic:\n");
-        for (const auto& it : lsq_timer.duration) {
-            printf("\t%s time: %lf s\n", it.first.data(), it.second);
+        for (const auto& it : lsq_timer.t) {
+            printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
         }
     }
 }
@@ -284,7 +291,7 @@ void LocalSearchQuantizer::perturb_codebooks(
         float T,
         const std::vector<float>& stddev,
         std::mt19937& gen) {
-    lsq_timer.start("perturb_codebooks");
+    LSQTimerScope scope(&lsq_timer, "perturb_codebooks");
 
     std::vector<std::normal_distribution<float>> distribs;
     for (size_t i = 0; i < d; i++) {
@@ -298,8 +305,6 @@ void LocalSearchQuantizer::perturb_codebooks(
             }
         }
     }
-
-    lsq_timer.end("perturb_codebooks");
 }
 
 void LocalSearchQuantizer::compute_codes(
         const float* x,
         uint8_t* codes_out,
         size_t n) const {
     FAISS_THROW_IF_NOT_MSG(is_trained, "LSQ is not trained yet.");
+
+    lsq_timer.reset();
+    LSQTimerScope scope(&lsq_timer, "encode");
     if (verbose) {
-        lsq_timer.reset();
         printf("Encoding %zd vectors...\n", n);
-        lsq_timer.start("encode");
     }
 
     std::vector<int32_t> codes(n * M);
     std::mt19937 gen(random_seed);
     random_int32(codes, 0, K - 1, gen);
 
-    icm_encode(x, codes.data(), n, encode_ils_iters, gen);
+    icm_encode(codes.data(), x, n, encode_ils_iters, gen);
     pack_codes(n, codes.data(), codes_out);
 
     if (verbose) {
-        lsq_timer.end("encode");
-        double t = lsq_timer.get("encode");
-        printf("Time to encode %zd vectors: %lf s\n", n, t);
+        scope.finish();
+        printf("Time statistic:\n");
+        for (const auto& it : lsq_timer.t) {
+            printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
+        }
     }
 }
@@ -347,7 +355,7 @@ void LocalSearchQuantizer::update_codebooks(
         const float* x,
         const int32_t* codes,
         size_t n) {
-    lsq_timer.start("update_codebooks");
+    LSQTimerScope scope(&lsq_timer, "update_codebooks");
 
     if (!update_codebooks_with_double) {
         // allocate memory
@@ -485,8 +493,6 @@ void LocalSearchQuantizer::update_codebooks(
             codebooks[i] = (float)d_codebooks[i];
         }
     }
-
-    lsq_timer.end("update_codebooks");
 }
 
 /** encode using iterative conditional mode
  *
  * For subcode cm of a vector, we fix the other subcodes cj (j != m) and
  * then find the value of cm (cm = 1,...,K) that minimizes the objective
  * function. The objective decomposes into unary terms (between a subcode
  * and the input vector) and binary terms (between pairs of subcodes).
  * These two terms can be precomputed and stored in a look-up table.
  */
 void LocalSearchQuantizer::icm_encode(
-        const float* x,
         int32_t* codes,
+        const float* x,
         size_t n,
         size_t ils_iters,
         std::mt19937& gen) const {
-    lsq_timer.start("icm_encode");
+    LSQTimerScope scope(&lsq_timer, "icm_encode");
 
-    std::vector<float> binaries(M * M * K * K); // [M, M, K, K]
-    compute_binary_terms(binaries.data());
+    auto factory = icm_encoder_factory;
+    std::unique_ptr<lsq::IcmEncoder> icm_encoder;
+    if (factory == nullptr) {
+        icm_encoder.reset(lsq::IcmEncoderFactory().get(this));
+    } else {
+        icm_encoder.reset(factory->get(this));
+    }
+
+    // precompute binary terms for all chunks
+    icm_encoder->set_binary_term();
 
     const size_t n_chunks = (n + chunk_size - 1) / chunk_size;
     for (size_t i = 0; i < n_chunks; i++) {
@@ -532,23 +546,20 @@ void LocalSearchQuantizer::icm_encode(
         const float* xi = x + i * chunk_size * d;
         int32_t* codesi = codes + i * chunk_size * M;
-        icm_encode_partial(i, xi, codesi, ni, binaries.data(), ils_iters, gen);
-
-        InterruptCallback::check();
+        icm_encoder->verbose = (verbose && i == 0);
+        icm_encoder->encode(codesi, xi, gen, ni, ils_iters);
     }
-
-    lsq_timer.end("icm_encode");
 }
 
-void LocalSearchQuantizer::icm_encode_partial(
-        size_t index,
-        const float* x,
+void LocalSearchQuantizer::icm_encode_impl(
         int32_t* codes,
-        size_t n,
+        const float* x,
         const float* binaries,
+        std::mt19937& gen,
+        size_t n,
         size_t ils_iters,
-        std::mt19937& gen) const {
-    std::vector<float> unaries(n * M * K); // [n, M, K]
+        bool verbose) const {
+    std::vector<float> unaries(n * M * K); // [M, n, K]
     compute_unary_terms(x, unaries.data(), n);
 
     std::vector<int32_t> best_codes;
@@ -562,9 +573,7 @@ void LocalSearchQuantizer::icm_encode_impl(
         // add perturbation to codes
         perturb_codes(codes, n, gen);
 
-        for (size_t iter2 = 0; iter2 < icm_iters; iter2++) {
-            icm_encode_step(unaries.data(), binaries, codes, n);
-        }
+        icm_encode_step(codes, unaries.data(), binaries, n, icm_iters);
 
         std::vector<float> icm_objs(n, 0.0f);
         evaluate(codes, x, n, icm_objs.data());
@@ -587,7 +596,7 @@ void LocalSearchQuantizer::icm_encode_impl(
 
         memcpy(codes, best_codes.data(), sizeof(int32_t) * n * M);
 
-        if (verbose && index == 0) {
+        if (verbose) {
             printf("\tils_iter %zd: obj = %lf, n_betters/n = %zd/%zd\n",
                    iter1,
                    mean_obj,
@@ -598,61 +607,67 @@ void LocalSearchQuantizer::icm_encode_impl(
 }
 
 void LocalSearchQuantizer::icm_encode_step(
+        int32_t* codes,
         const float* unaries,
         const float* binaries,
-        int32_t* codes,
-        size_t n) const {
-    // condition on the m-th subcode
-    for (size_t m = 0; m < M; m++) {
-        std::vector<float> objs(n * K);
-#pragma omp parallel for
-        for (int64_t i = 0; i < n; i++) {
-            auto u = unaries + i * (M * K) + m * K;
-            memcpy(objs.data() + i * K, u, sizeof(float) * K);
-        }
-
-        // compute objective function by adding unary
-        // and binary terms together
-        for (size_t other_m = 0; other_m < M; other_m++) {
-            if (other_m == m) {
-                continue;
-            }
+        size_t n,
+        size_t n_iters) const {
+    FAISS_THROW_IF_NOT(M != 0 && K != 0);
+    FAISS_THROW_IF_NOT(binaries != nullptr);
 
+    for (size_t iter = 0; iter < n_iters; iter++) {
+        // condition on the m-th subcode
+        for (size_t m = 0; m < M; m++) {
+            std::vector<float> objs(n * K);
 #pragma omp parallel for
-        for (int64_t i = 0; i < n; i++) {
-            for (int32_t code = 0; code < K; code++) {
-                int32_t code2 = codes[i * M + other_m];
-                size_t binary_idx =
-                        m * M * K * K + other_m * K * K + code * K + code2;
-                // binaries[m, other_m, code, code2]
-                objs[i * K + code] += binaries[binary_idx];
-            }
+            for (int64_t i = 0; i < n; i++) {
+                auto u = unaries + m * n * K + i * K;
+                memcpy(objs.data() + i * K, u, sizeof(float) * K);
             }
-        }
 
-        // find the optimal value of the m-th subcode
+            // compute objective function by adding unary
+            // and binary terms together
+            for (size_t other_m = 0; other_m < M; other_m++) {
+                if (other_m == m) {
+                    continue;
+                }
+
 #pragma omp parallel for
-        for (int64_t i = 0; i < n; i++) {
-            float best_obj = HUGE_VALF;
-            int32_t best_code = 0;
-            for (size_t code = 0; code < K; code++) {
-                float obj = objs[i * K + code];
-                if (obj < best_obj) {
-                    best_obj = obj;
-                    best_code = code;
+                for (int64_t i = 0; i < n; i++) {
+                    for (int32_t code = 0; code < K; code++) {
+                        int32_t code2 = codes[i * M + other_m];
+                        size_t binary_idx = m * M * K * K + other_m * K * K +
+                                code * K + code2;
+                        // binaries[m, other_m, code, code2]
+                        objs[i * K + code] += binaries[binary_idx];
+                    }
                 }
             }
-            codes[i * M + m] = best_code;
-        }
-    } // loop M
 
+            // find the optimal value of the m-th subcode
+#pragma omp parallel for
+            for (int64_t i = 0; i < n; i++) {
+                float best_obj = HUGE_VALF;
+                int32_t best_code = 0;
+                for (size_t code = 0; code < K; code++) {
+                    float obj = objs[i * K + code];
+                    if (obj < best_obj) {
+                        best_obj = obj;
+                        best_code = code;
+                    }
+                }
+                codes[i * M + m] = best_code;
+            }
+
+        } // loop M
+    }
 }
 
 void LocalSearchQuantizer::perturb_codes(
         int32_t* codes,
         size_t n,
         std::mt19937& gen) const {
-    lsq_timer.start("perturb_codes");
+    LSQTimerScope scope(&lsq_timer, "perturb_codes");
 
     std::uniform_int_distribution<size_t> m_distrib(0, M - 1);
     std::uniform_int_distribution<int32_t> k_distrib(0, K - 1);
@@ -663,12 +678,10 @@ void LocalSearchQuantizer::perturb_codes(
             codes[i * M + m] = k_distrib(gen);
         }
     }
-
-    lsq_timer.end("perturb_codes");
 }
 
 void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
-    lsq_timer.start("compute_binary_terms");
+    LSQTimerScope scope(&lsq_timer, "compute_binary_terms");
 
 #pragma omp parallel for
     for (int64_t m12 = 0; m12 < M * M; m12++) {
@@ -686,52 +699,53 @@ void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
             }
         }
     }
-
-    lsq_timer.end("compute_binary_terms");
 }
 
 void LocalSearchQuantizer::compute_unary_terms(
         const float* x,
-        float* unaries,
+        float* unaries, // [M, n, K]
         size_t n) const {
-    lsq_timer.start("compute_unary_terms");
+    LSQTimerScope scope(&lsq_timer, "compute_unary_terms");
 
-    // compute x * codebooks^T
+    // compute x * codebook^T for each codebook
     //
     // NOTE: LAPACK uses column major order
     // out = alpha * op(A) * op(B) + beta * C
-    FINTEGER nrows_A = M * K;
-    FINTEGER ncols_A = d;
-    FINTEGER nrows_B = d;
-    FINTEGER ncols_B = n;
+    for (size_t m = 0; m < M; m++) {
+        FINTEGER nrows_A = K;
+        FINTEGER ncols_A = d;
 
-    float alpha = -2.0f;
-    float beta = 0.0f;
-    sgemm_("Transposed",
-           "Not Transposed",
-           &nrows_A, // nrows of op(A)
-           &ncols_B, // ncols of op(B)
-           &ncols_A, // ncols of op(A)
-           &alpha,
-           codebooks.data(),
-           &ncols_A, // nrows of A
-           x,
-           &nrows_B, // nrows of B
-           &beta,
-           unaries,
-           &nrows_A); // nrows of output
+        FINTEGER nrows_B = d;
+        FINTEGER ncols_B = n;
+
+        float alpha = -2.0f;
+        float beta = 0.0f;
+        sgemm_("Transposed",
+               "Not Transposed",
+               &nrows_A, // nrows of op(A)
+               &ncols_B, // ncols of op(B)
+               &ncols_A, // ncols of op(A)
+               &alpha,
+               codebooks.data() + m * K * d,
+               &ncols_A, // nrows of A
+               x,
+               &nrows_B, // nrows of B
+               &beta,
+               unaries + m * n * K,
+               &nrows_A); // nrows of output
+    }
 
     std::vector<float> norms(M * K);
     fvec_norms_L2sqr(norms.data(), codebooks.data(), d, M * K);
 
#pragma omp parallel for
     for (int64_t i = 0; i < n; i++) {
-        float* u = unaries + i * (M * K);
-        fvec_add(M * K, u, norms.data(), u);
+        for (size_t m = 0; m < M; m++) {
+            float* u = unaries + m * n * K + i * K;
+            fvec_add(K, u, norms.data() + m * K, u);
+        }
     }
-
-    lsq_timer.end("compute_unary_terms");
 }
 
 float LocalSearchQuantizer::evaluate(
@@ -739,7 +753,7 @@ float LocalSearchQuantizer::evaluate(
         const float* x,
         size_t n,
         float* objs) const {
-    lsq_timer.start("evaluate");
+    LSQTimerScope scope(&lsq_timer, "evaluate");
 
     // decode
     std::vector<float> decoded_x(n * d, 0.0f);
@@ -755,7 +769,7 @@ float LocalSearchQuantizer::evaluate(
             fvec_add(d, decoded_i, c, decoded_i);
         }
 
-        float err = fvec_L2sqr(x + i * d, decoded_i, d);
+        float err = faiss::fvec_L2sqr(x + i * d, decoded_i, d);
         obj += err;
 
         if (objs) {
@@ -763,34 +777,68 @@ float LocalSearchQuantizer::evaluate(
         }
     }
 
-    lsq_timer.end("evaluate");
-
     obj = obj / n;
     return obj;
 }
 
+namespace lsq {
+
+IcmEncoder::IcmEncoder(const LocalSearchQuantizer* lsq)
+        : verbose(false), lsq(lsq) {}
+
+void IcmEncoder::set_binary_term() {
+    auto M = lsq->M;
+    auto K = lsq->K;
+    binaries.resize(M * M * K * K);
+    lsq->compute_binary_terms(binaries.data());
+}
+
+void IcmEncoder::encode(
+        int32_t* codes,
+        const float* x,
+        std::mt19937& gen,
+        size_t n,
+        size_t ils_iters) const {
+    lsq->icm_encode_impl(codes, x, binaries.data(), gen, n, ils_iters, verbose);
+}
+
 double LSQTimer::get(const std::string& name) {
-    return duration[name];
+    if (t.count(name) == 0) {
+        return 0.0;
+    } else {
+        return t[name];
+    }
 }
 
-void LSQTimer::start(const std::string& name) {
-    FAISS_THROW_IF_NOT_MSG(!started[name], " timer is already running");
-    started[name] = true;
-    t0[name] = getmillisecs();
-}
-
-void LSQTimer::end(const std::string& name) {
-    FAISS_THROW_IF_NOT_MSG(started[name], " timer is not running");
-    double t1 = getmillisecs();
-    double sec = (t1 - t0[name]) / 1000;
-    duration[name] += sec;
-    started[name] = false;
+void LSQTimer::add(const std::string& name, double delta) {
+    if (t.count(name) == 0) {
+        t[name] = delta;
+    } else {
+        t[name] += delta;
+    }
 }
 
 void LSQTimer::reset() {
-    duration.clear();
-    t0.clear();
-    started.clear();
+    t.clear();
 }
 
+LSQTimerScope::LSQTimerScope(LSQTimer* timer, std::string name)
+        : timer(timer), name(name), finished(false) {
+    t0 = getmillisecs();
+}
+
+void LSQTimerScope::finish() {
+    if (!finished) {
+        auto delta = getmillisecs() - t0;
+        timer->add(name, delta);
+        finished = true;
+    }
+}
+
+LSQTimerScope::~LSQTimerScope() {
+    finish();
+}
+
+} // namespace lsq
+
 } // namespace faiss
diff --git a/faiss/impl/LocalSearchQuantizer.h b/faiss/impl/LocalSearchQuantizer.h
index 50f7396bf..c02bce33e 100644
--- a/faiss/impl/LocalSearchQuantizer.h
+++ b/faiss/impl/LocalSearchQuantizer.h
@@ -20,6 +20,12 @@
 
 namespace faiss {
 
+namespace lsq {
+
+struct IcmEncoderFactory;
+
+} // namespace lsq
+
 /** Implementation of LSQ/LSQ++ described in the following two papers:
  *
  * Revisiting additive quantization
@@ -36,7 +42,6 @@ namespace faiss {
  * The trained codes are stored in `codebooks` which is called
  * `centroids` in PQ and RQ.
  */
-
 struct LocalSearchQuantizer : AdditiveQuantizer {
     size_t K; ///< number of codes per codebook
 
@@ -54,6 +59,9 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
     int random_seed; ///< seed for random generator
     size_t nperts;   ///< number of perturbations in each code
 
+    /// if non-NULL, use this factory to build the ICM encoder
+    lsq::IcmEncoderFactory* icm_encoder_factory;
+
     bool update_codebooks_with_double = true;
 
     LocalSearchQuantizer(
@@ -66,6 +74,8 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
 
     LocalSearchQuantizer();
 
+    ~LocalSearchQuantizer() override;
+
     // Train the local search quantizer
     void train(size_t n, const float* x) override;
 
@@ -73,6 +83,7 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
      *
      * @param x     vectors to encode, size n * d
     * @param codes output codes, size n * code_size
+     * @param n     number of vectors
      */
     void compute_codes(const float* x, uint8_t* codes, size_t n)
             const override;
 
@@ -80,36 +91,46 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
      *
      * @param x     training vectors, size n * d
     * @param codes encoded training vectors, size n * M
+     * @param n     number of vectors
      */
     void update_codebooks(const float* x, const int32_t* codes, size_t n);
 
     /** Encode vectors given codebooks using iterative conditional mode (icm).
      *
-     * @param x         vectors to encode, size n * d
-     * @param codes     output codes, size n * M
+     * @param codes     output codes, size n * M
+     * @param x         vectors to encode, size n * d
+     * @param n         number of vectors
      * @param ils_iters number of iterations of iterative local search
      */
     void icm_encode(
-            const float* x,
             int32_t* codes,
+            const float* x,
             size_t n,
             size_t ils_iters,
             std::mt19937& gen) const;
 
-    void icm_encode_partial(
-            size_t index,
-            const float* x,
+    void icm_encode_impl(
             int32_t* codes,
+            const float* x,
+            const float* binaries,
+            std::mt19937& gen,
             size_t n,
-            const float* binaries,
             size_t ils_iters,
-            std::mt19937& gen) const;
+            bool verbose) const;
 
     void icm_encode_step(
+            int32_t* codes,
             const float* unaries,
             const float* binaries,
-            int32_t* codes,
-            size_t n) const;
+            size_t n,
+            size_t n_iters) const;
+
+    /** Add some perturbation to codes
+     *
+     * @param codes codes to be perturbed, size n * M
+     * @param n     number of vectors
+     */
+    void perturb_codes(int32_t* codes, size_t n, std::mt19937& gen) const;
 
     /** Add some perturbation to codebooks
      *
@@ -121,12 +142,6 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
             const std::vector<float>& stddev,
             std::mt19937& gen);
 
-    /** Add some perturbation to codes
-     *
-     * @param codes codes to be perturbed, size n * M
-     */
-    void perturb_codes(int32_t* codes, size_t n, std::mt19937& gen) const;
-
     /** Compute binary terms
      *
      * @param binaries binary terms, size M * M * K * K
@@ -135,6 +150,7 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
 
     /** Compute unary terms
      *
+     * @param n       number of vectors
     * @param x       vectors to encode, size n * d
      * @param unaries unary terms, size M * n * K
      */
@@ -142,8 +158,9 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
 
     /** Helper function to compute reconstruction error
      *
-     * @param x     vectors to encode, size n * d
     * @param codes encoded codes, size n * M
+     * @param x     vectors to encode, size n * d
+     * @param n     number of vectors
      * @param objs  if it is not null, store reconstruction error of each
      *              vector into it, size n
      */
@@ -154,13 +171,50 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
             float* objs = nullptr) const;
 };
 
+namespace lsq {
+
+struct IcmEncoder {
+    std::vector<float> binaries;
+
+    bool verbose;
+
+    const LocalSearchQuantizer* lsq;
+
+    explicit IcmEncoder(const LocalSearchQuantizer* lsq);
+
+    virtual ~IcmEncoder() {}
+
+    /// compute binary terms
+    virtual void set_binary_term();
+
+    /** Encode vectors given codebooks
+     *
+     * @param codes     output codes, size n * M
+     * @param x         vectors to encode, size n * d
+     * @param gen       random generator
+     * @param n         number of vectors
+     * @param ils_iters number of iterations of iterative local search
+     */
+    virtual void encode(
+            int32_t* codes,
+            const float* x,
+            std::mt19937& gen,
+            size_t n,
+            size_t ils_iters) const;
+};
+
+struct IcmEncoderFactory {
+    virtual IcmEncoder* get(const LocalSearchQuantizer* lsq) {
+        return new IcmEncoder(lsq);
+    }
+    virtual ~IcmEncoderFactory() {}
+};
+
 /** A helper struct to record the time spent in each step of training.
  * It is NOT thread-safe.
  */
 struct LSQTimer {
-    std::unordered_map<std::string, double> duration;
-    std::unordered_map<std::string, double> t0;
-    std::unordered_map<std::string, bool> started;
+    std::unordered_map<std::string, double> t;
 
     LSQTimer() {
         reset();
@@ -168,13 +222,24 @@ struct LSQTimer {
 
     double get(const std::string& name);
 
-    void start(const std::string& name);
-
-    void end(const std::string& name);
+    void add(const std::string& name, double delta);
 
     void reset();
 };
 
-FAISS_API extern LSQTimer lsq_timer; ///< timer to count consuming time
+struct LSQTimerScope {
+    double t0;
+    LSQTimer* timer;
+    std::string name;
+    bool finished;
+
+    LSQTimerScope(LSQTimer* timer, std::string name);
+
+    void finish();
+
+    ~LSQTimerScope();
+};
+
+} // namespace lsq
 
 } // namespace faiss
diff --git a/faiss/python/swigfaiss.swig b/faiss/python/swigfaiss.swig
index fef6705ed..a4623dfa0 100644
--- a/faiss/python/swigfaiss.swig
+++ b/faiss/python/swigfaiss.swig
@@ -287,6 +287,7 @@ void gpu_sync_all_devices();
 #include
 #include
 #include
+#include <faiss/gpu/GpuIcmEncoder.h>
 
 int get_num_gpus()
 {
@@ -490,6 +491,7 @@ void gpu_sync_all_devices()
 %include
 %include
 %include
+%include <faiss/gpu/GpuIcmEncoder.h>
 
 #endif
diff --git a/tests/test_lsq.py b/tests/test_lsq.py
index a29b2d104..2c301a43b 100644
--- a/tests/test_lsq.py
+++ b/tests/test_lsq.py
@@ -15,6 +15,9 @@ import unittest
 
 from faiss.contrib import datasets
 
+sp = faiss.swig_ptr
+
+
 def construct_sparse_matrix(codes, K):
     n, M = codes.shape
     B = np.zeros((n, M * K), dtype=np.float32)
@@ -54,18 +57,18 @@ def compute_binary_terms_ref(codebooks):
 def compute_unary_terms_ref(codebooks, x):
     codebooks_t = np.swapaxes(codebooks, 1, 2)  # [M, d, K]
     unaries = -2 * x.dot(codebooks_t)  # [n, M, K]
-
     code_norms = np.sum(codebooks * codebooks, axis=2)  # [M, K]
     unaries += code_norms
+    unaries = np.swapaxes(unaries, 0, 1)  # [M, n, K]
 
     return unaries
 
 
 def icm_encode_step_ref(unaries, binaries, codes):
-    n, M, K = unaries.shape
+    M, n, K = unaries.shape
 
     for m in range(M):
-        objs = unaries[:, m].copy()  # [n, K]
+        objs = unaries[m].copy()  # [n, K]
 
         for m2 in range(M):  # pair, m2 != m
             if m2 == m:
@@ -131,8 +134,8 @@ class TestComponents(unittest.TestCase):
         # decode x
         pack_codes = np.zeros((n, lsq.code_size)).astype(np.uint8)
         decoded_x = np.zeros((n, d)).astype(np.float32)
-        lsq.pack_codes(n, faiss.swig_ptr(codes), faiss.swig_ptr(pack_codes))
-        lsq.decode_c(faiss.swig_ptr(pack_codes), faiss.swig_ptr(decoded_x), n)
+        lsq.pack_codes(n, sp(codes), sp(pack_codes))
+        lsq.decode_c(sp(pack_codes), sp(decoded_x), n)
 
         # decode in Python
         codebooks = faiss.vector_float_to_array(lsq.codebooks)
@@ -163,7 +166,7 @@ class TestComponents(unittest.TestCase):
         codebooks = faiss.vector_float_to_array(lsq.codebooks)
         codebooks = codebooks.reshape(M, K, d).copy()
 
-        lsq.update_codebooks(faiss.swig_ptr(x), faiss.swig_ptr(codes), n)
+        lsq.update_codebooks(sp(x), sp(codes), n)
         new_codebooks = faiss.vector_float_to_array(lsq.codebooks)
         new_codebooks = new_codebooks.reshape(M, K, d).copy()
 
@@ -209,7 +212,7 @@ class TestComponents(unittest.TestCase):
         lsq = faiss.LocalSearchQuantizer(d, M, nbits)
         lsq.train(x)  # just for allocating memory for codebooks
 
-        lsq.compute_binary_terms(faiss.swig_ptr(binaries))
+        lsq.compute_binary_terms(sp(binaries))
 
         codebooks = faiss.vector_float_to_array(lsq.codebooks)
         codebooks = codebooks.reshape(M, K, d).copy()
@@ -226,12 +229,12 @@ class TestComponents(unittest.TestCase):
         rs = np.random.RandomState(123)
         x = rs.rand(n, d).astype(np.float32)
 
-        unaries = np.zeros((n, M, K)).astype(np.float32)
+        unaries = np.zeros((M, n, K)).astype(np.float32)
 
         lsq = faiss.LocalSearchQuantizer(d, M, nbits)
         lsq.train(x)  # just for allocating memory for codebooks
 
-        lsq.compute_unary_terms(faiss.swig_ptr(x), faiss.swig_ptr(unaries), n)
+        lsq.compute_unary_terms(sp(x), sp(unaries), n)
 
         codebooks = faiss.vector_float_to_array(lsq.codebooks)
         codebooks = codebooks.reshape(M, K, d).copy()
@@ -251,15 +254,17 @@ class TestComponents(unittest.TestCase):
         # randomly generate codes, binary terms and unary terms
         codes = rs.randint(0, K, (n, M)).astype(np.int32)
         new_codes = codes.copy()
-        unaries = rs.rand(n, M, K).astype(np.float32)
+        unaries = rs.rand(M, n, K).astype(np.float32)
         binaries = rs.rand(M, M, K, K).astype(np.float32)
 
         # do icm encoding given binary and unary terms
         lsq = faiss.LocalSearchQuantizer(d, M, nbits)
         lsq.icm_encode_step(
-            faiss.swig_ptr(unaries),
-            faiss.swig_ptr(binaries),
-            faiss.swig_ptr(new_codes), n)
+            sp(new_codes),
+            sp(unaries),
+            sp(binaries),
+            n,
+            1)
 
         # do icm encoding given binary and unary terms in Python
         ref_codes = icm_encode_step_ref(unaries, binaries, codes)
@@ -280,11 +285,11 @@ class TestComponents(unittest.TestCase):
 
         # compute binary terms
         binaries = np.zeros((M, M, K, K)).astype(np.float32)
-        lsq.compute_binary_terms(faiss.swig_ptr(binaries))
+        lsq.compute_binary_terms(sp(binaries))
 
         # compute unary terms
-        unaries = np.zeros((n, M, K)).astype(np.float32)
-        lsq.compute_unary_terms(faiss.swig_ptr(x), faiss.swig_ptr(unaries), n)
+        unaries = np.zeros((M, n, K)).astype(np.float32)
+        lsq.compute_unary_terms(sp(x), sp(unaries), n)
 
         # randomly generate codes
         codes = rs.randint(0, K, (n, M)).astype(np.int32)
@@ -292,9 +297,11 @@ class TestComponents(unittest.TestCase):
 
         # do icm encoding given binary and unary terms
         lsq.icm_encode_step(
-            faiss.swig_ptr(unaries),
-            faiss.swig_ptr(binaries),
-            faiss.swig_ptr(new_codes), n)
+            sp(new_codes),
+            sp(unaries),
+            sp(binaries),
+            n,
+            1)
 
         # do icm encoding without pre-computed unary and binary terms in Python
         codebooks = faiss.vector_float_to_array(lsq.codebooks)
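
A minimal usage sketch, not part of the patch: it is pieced together from benchs/bench_quantizer.py and the new test in faiss/gpu/test/test_gpu_index.py above, assuming a GPU-enabled build that exposes faiss.GpuIcmEncoderFactory through the SWIG changes; the sizes and random data are illustrative only.

    import faiss
    import numpy as np

    d, M, nbits = 32, 4, 8
    rs = np.random.RandomState(123)
    xt = rs.rand(1000, d).astype('float32')  # training vectors (illustrative)
    xb = rs.rand(1000, d).astype('float32')  # vectors to encode

    lsq = faiss.LocalSearchQuantizer(d, M, nbits)
    # route ICM encoding to the GPUs; train/compute_codes/decode are unchanged
    lsq.icm_encoder_factory = faiss.GpuIcmEncoderFactory(faiss.get_num_gpus())
    lsq.train(xt)
    codes = lsq.compute_codes(xb)
    decoded = lsq.decode(codes)
    print(((xb - decoded) ** 2).sum())  # reconstruction error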