Support LSQ on GPU (#1978)

Summary:
## Description

This PR adds support for LSQ (LocalSearchQuantizer) on GPU. Only the encoding part runs on the GPU; the other stages still run on the CPU.

Multi-GPU is also supported.

## Usage

``` python
lsq = faiss.LocalSearchQuantizer(d, M, nbits)
ngpus = faiss.get_num_gpus()
lsq.icm_encoder_factory = faiss.GpuIcmEncoderFactory(ngpus)  # we use all gpus

lsq.train(xt)
codes = lsq.compute_codes(xb)
decoded = lsq.decode(codes)
```
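
For a quick sanity check, reconstruction quality can be measured the same way the new unit test does (a minimal sketch; `xb` and `lsq` are the objects from above, and the exact metric printed by the benchmark may be normalized differently):

``` python
# per-vector squared reconstruction error, as in the new TestLSQIcmEncoder
codes = lsq.compute_codes(xb)
decoded = lsq.decode(codes)
err = ((xb - decoded) ** 2).sum()
print("reconstruction error:", err)
```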

## Performance on SIFT1M

On 1 GPU:
```
===== lsq-gpu:
        mean square error = 17337.878528
        training time: 40.9857234954834 s
        encoding time: 27.12640070915222 s
```

On 2 GPUs:
```
===== lsq-gpu:
        mean square error = 17364.658176
        training time: 25.832106113433838 s
        encoding time: 14.879548072814941 s
```

On CPU:
```
===== lsq:
        mean square error = 17305.880576
        training time: 152.57522344589233 s
        encoding time: 110.01779270172119 s
```
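
Relative to the CPU baseline, encoding is roughly 4x faster on one GPU and over 7x faster on two GPUs (training roughly 4x and 6x faster), at essentially the same mean square error.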

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1978

Test Plan: buck test mode/dev-nosan //faiss/gpu/test/:test_gpu_index_py -- TestLSQIcmEncoder

Reviewed By: wickedfoo

Differential Revision: D29609763

Pulled By: mdouze

fbshipit-source-id: b6ffa2a3c02bf696a4e52348132affa0dd838870
Chengqi Deng 2021-09-09 09:10:37 -07:00 committed by Facebook GitHub Bot
parent 8a2860c1dc
commit eba1cb1a90
12 changed files with 983 additions and 170 deletions


@@ -9,6 +9,8 @@ We try to indicate most contributions here with the contributor names who are not of
the Facebook Faiss team. Feel free to add entries here if you submit a PR.
## [Unreleased]
### Added
- Support LSQ on GPU (by @KinglittleQ)
## [1.7.1] - 2021-05-27
### Added


@@ -82,6 +82,13 @@ nt, d = xt.shape
# fastest to slowest
if 'lsq-gpu' in todo:
lsq = faiss.LocalSearchQuantizer(d, M, nbits)
ngpus = faiss.get_num_gpus()
lsq.icm_encoder_factory = faiss.GpuIcmEncoderFactory(ngpus)
lsq.verbose = True
eval_quantizer(lsq, xb, xt, 'lsq-gpu')
if 'pq' in todo:
pq = faiss.ProductQuantizer(d, M, nbits)
print("===== PQ")


@@ -9,6 +9,7 @@ set(FAISS_GPU_SRC
GpuCloner.cpp
GpuClonerOptions.cpp
GpuDistance.cu
GpuIcmEncoder.cu
GpuIndex.cu
GpuIndexBinaryFlat.cu
GpuIndexFlat.cu
@@ -46,6 +47,7 @@ set(FAISS_GPU_SRC
impl/scan/IVFInterleaved512.cu
impl/scan/IVFInterleaved1024.cu
impl/scan/IVFInterleaved2048.cu
impl/IcmEncoder.cu
utils/BlockSelectFloat.cu
utils/DeviceUtils.cu
utils/StackDeviceMemory.cpp
@@ -80,6 +82,7 @@ set(FAISS_GPU_HEADERS
GpuCloner.h
GpuClonerOptions.h
GpuDistance.h
GpuIcmEncoder.h
GpuFaissAssert.h
GpuIndex.h
GpuIndexBinaryFlat.h
@@ -118,6 +121,7 @@ set(FAISS_GPU_HEADERS
impl/RemapIndices.h
impl/VectorResidual.cuh
impl/scan/IVFInterleavedImpl.cuh
impl/IcmEncoder.cuh
utils/BlockSelectKernel.cuh
utils/Comparators.cuh
utils/ConversionOperators.cuh


@@ -0,0 +1,120 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/gpu/GpuIcmEncoder.h>
#include <faiss/gpu/StandardGpuResources.h>
#include <faiss/utils/WorkerThread.h>
#include <faiss/gpu/impl/IcmEncoder.cuh>
#include <algorithm>
namespace faiss {
namespace gpu {
///< A helper structure to support multi-GPU
struct IcmEncoderShards {
std::vector<std::pair<
std::unique_ptr<IcmEncoderImpl>,
std::unique_ptr<WorkerThread>>>
workers;
void add(IcmEncoderImpl* encoder) {
workers.emplace_back(std::make_pair(
std::unique_ptr<IcmEncoderImpl>(encoder),
std::unique_ptr<WorkerThread>(new WorkerThread)));
}
IcmEncoderImpl* at(int idx) {
return workers[idx].first.get();
}
///< call f(idx, encoder) for each encoder
void runOnShards(std::function<void(int, IcmEncoderImpl*)> f) {
std::vector<std::future<bool>> v;
for (int i = 0; i < this->workers.size(); ++i) {
auto& p = this->workers[i];
auto encoder = p.first.get();
v.emplace_back(p.second->add([f, i, encoder]() { f(i, encoder); }));
}
for (int i = 0; i < v.size(); ++i) {
auto& fut = v[i];
fut.get(); // no exception handling; crash if any worker thread dies
}
}
size_t size() {
return workers.size();
}
};
GpuIcmEncoder::GpuIcmEncoder(
const LocalSearchQuantizer* lsq,
const std::vector<GpuResourcesProvider*>& provs,
const std::vector<int>& devices)
: lsq::IcmEncoder(lsq), shards(new IcmEncoderShards()) {
// create an IcmEncoderImpl instance for each device.
for (size_t i = 0; i < provs.size(); i++) {
shards->add(new IcmEncoderImpl(
lsq->M, lsq->K, lsq->d, provs[i], devices[i]));
}
}
GpuIcmEncoder::~GpuIcmEncoder() {}
void GpuIcmEncoder::set_binary_term() {
auto fn = [=](int idx, IcmEncoderImpl* encoder) {
encoder->setBinaryTerm(lsq->codebooks.data());
};
shards->runOnShards(fn);
}
void GpuIcmEncoder::encode(
int32_t* codes,
const float* x,
std::mt19937& gen,
size_t n,
size_t ils_iters) const {
size_t nshards = shards->size();
size_t shard_size = (n + nshards - 1) / nshards;
auto codebooks = lsq->codebooks.data();
auto M = lsq->M;
auto d = lsq->d;
auto nperts = lsq->nperts;
auto icm_iters = lsq->icm_iters;
auto seed = gen();
// split input data
auto fn = [=](int idx, IcmEncoderImpl* encoder) {
size_t i0 = idx * shard_size;
size_t ni = std::min(shard_size, n - i0);
auto xi = x + i0 * d;
auto ci = codes + i0 * M;
std::mt19937 geni(idx + seed); // different seed for each shard
encoder->encode(
ci, xi, codebooks, geni, ni, nperts, ils_iters, icm_iters);
};
shards->runOnShards(fn);
}
GpuIcmEncoderFactory::GpuIcmEncoderFactory(int ngpus) {
for (int i = 0; i < ngpus; i++) {
provs.push_back(new StandardGpuResources());
devices.push_back(i);
}
}
lsq::IcmEncoder* GpuIcmEncoderFactory::get(const LocalSearchQuantizer* lsq) {
return new GpuIcmEncoder(lsq, provs, devices);
}
} // namespace gpu
} // namespace faiss
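
The multi-GPU path above assigns each shard a contiguous chunk of `shard_size = ceil(n / nshards)` vectors and a distinct RNG seed. A small Python sketch of the same split arithmetic (`shard_ranges` is a hypothetical helper, for illustration only):

``` python
def shard_ranges(n, nshards):
    # mirrors: shard_size = (n + nshards - 1) / nshards in GpuIcmEncoder::encode
    shard_size = (n + nshards - 1) // nshards
    for idx in range(nshards):
        i0 = idx * shard_size
        ni = min(shard_size, n - i0)
        if ni > 0:
            yield i0, i0 + ni

# e.g. 1000 vectors on 3 GPUs -> [(0, 334), (334, 668), (668, 1000)]
print(list(shard_ranges(1000, 3)))
```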


@@ -0,0 +1,60 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <faiss/impl/LocalSearchQuantizer.h>
#include <memory>
namespace faiss {
namespace gpu {
class GpuResourcesProvider;
struct IcmEncoderShards;
/** Perform LSQ encoding on GPU.
*
* Split input vectors across devices and call IcmEncoderImpl::encode
* to encode them
*/
class GpuIcmEncoder : public lsq::IcmEncoder {
public:
GpuIcmEncoder(
const LocalSearchQuantizer* lsq,
const std::vector<GpuResourcesProvider*>& provs,
const std::vector<int>& devices);
~GpuIcmEncoder();
GpuIcmEncoder(const GpuIcmEncoder&) = delete;
GpuIcmEncoder& operator=(const GpuIcmEncoder&) = delete;
void set_binary_term() override;
void encode(
int32_t* codes,
const float* x,
std::mt19937& gen,
size_t n,
size_t ils_iters) const override;
private:
std::unique_ptr<IcmEncoderShards> shards;
};
struct GpuIcmEncoderFactory : public lsq::IcmEncoderFactory {
explicit GpuIcmEncoderFactory(int ngpus = 1);
lsq::IcmEncoder* get(const LocalSearchQuantizer* lsq) override;
std::vector<GpuResourcesProvider*> provs;
std::vector<int> devices;
};
} // namespace gpu
} // namespace faiss


@@ -0,0 +1,379 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/gpu/impl/IcmEncoder.cuh>
#include <faiss/gpu/GpuResources.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/gpu/impl/L2Norm.cuh>
#include <faiss/gpu/utils/CopyUtils.cuh>
#include <faiss/gpu/utils/DeviceDefs.cuh>
#include <faiss/gpu/utils/DeviceTensor.cuh>
#include <faiss/gpu/utils/MatrixMult.cuh>
#include <faiss/gpu/utils/Pair.cuh>
#include <faiss/gpu/utils/Reductions.cuh>
#include <curand_kernel.h>
namespace faiss {
namespace gpu {
extern __shared__ char smem[];
/** encode using iterative conditional mode
*
* For subcode cm of a vector, we fix the other subcodes cj (j != m)
* and then find the value of cm (cm = 1,...,K) that minimizes
* the objective function.
*
* @param uterm precomputed unary terms, size (M, n, K)
* @param bterm precomputed binary terms, size (M1, M2, K1, K2)
* @param codes output vector encodings, size (n, M)
* @param M number of codebooks
* @param K number of codewords in a codebook
* @param m index of the subcode to condition on
*/
__global__ void runIcmEncodeStep(
const float* uterm,
const float* bterm,
int32_t* codes,
int M,
int K,
int m) {
using KVPair = Pair<float, int>;
int id = blockIdx.x; // each block takes care of one vector
int code = threadIdx.x; // each thread takes care of one possible code
// compute the objective value using the precomputed look-up tables
KVPair obj(0.0f, code);
obj.k = uterm[id * K + code];
#pragma unroll
for (int m2 = 0; m2 < M; m2++) {
if (m2 == m) {
continue;
}
int32_t code2 = codes[id * M + m2];
obj.k += bterm[m2 * K * K + code * K + code2];
}
// find the minimum objective value and the corresponding code
__syncthreads();
obj = blockReduceAll<KVPair, Min<KVPair>, false, false>(
obj, Min<KVPair>(), (KVPair*)smem);
if (code == 0) {
codes[id * M + m] = obj.v;
}
}
/** compute reconstruction error for each vector
*
* decoded_x[i] = \sum codebooks[m][codes[i][m]], m = 1,...,M
* obj[i] = ||x[i] - decoded_x[i]||^2
*
* @param x input vectors, size [n, dims]
* @param codebooks codebooks, size [M, K, dims]
* @param codes vector codes, size [n, M]
* @param obj output reconstruction errors, size [n]
* @param n number of input vectors
* @param K number of codewords in a codebook
* @param M number of codebooks
*/
__global__ void runEvaluation(
const float* x,
const float* codebooks,
const int32_t* codes,
float* obj, // output
int n,
int M,
int K,
int dims) {
int id = blockIdx.x; // each block takes care of one vector
int d = threadIdx.x; // each thread takes care of one dimension
float acc = 0.0f;
#pragma unroll
for (int m = 0; m < M; m++) {
int32_t code = codes[id * M + m];
acc += codebooks[m * K * dims + code * dims + d];
}
acc -= x[id * dims + d];
acc = acc * acc;
// sum values of all dimensions together
__syncthreads();
acc = blockReduceAllSum<float, false, false>(acc, (float*)smem);
if (d == 0) {
obj[id] = acc;
}
}
/** perturb vector codes
*
* repeat nperts times:
* codes[i][randint(0, M)] = randint(0, K)
*
* @param seed random seed
* @param codes vector codes, size [n, M]
* @param n number of input vectors
* @param M number of codebooks
* @param K number of codewords in a codebook
* @param nperts number of subcodes to perturb in each vector
*/
__global__ void runCodesPerturbation(
int seed,
int32_t* codes,
int n,
int M,
int K,
int nperts) {
// each thread takes care of one vector
int id = blockIdx.x * blockDim.x + threadIdx.x;
if (id >= n) {
return;
}
// we have to initialize the state
curandState_t state;
curand_init(seed, id, 0, &state);
for (int i = 0; i < nperts; i++) {
int pos = int(curand_uniform(&state) * M);
int32_t val = int32_t(curand_uniform(&state) * K);
codes[id * M + pos] = val;
}
}
/** select the best codes by reconstruction errors
*
* if objs[i] < best_objs[i]:
* best_objs[i] = objs[i]
* best_codes[i] = codes[i]
*
* @param bestCodes the best codes we've encountered, size [n, M]
* @param bestObjs min reconstruction errors we've encountered, size [n]
* @param codes input vector codes, size [n, M]
* @param objs reconstruction errors of input vector codes, size [n]
* @param n number of input vectors
*/
__global__ void runCodesSelection(
int32_t* bestCodes,
float* bestObjs,
const int32_t* codes,
const float* objs,
int n,
int M) {
// each thread takes care of one vector
int id = blockIdx.x * blockDim.x + threadIdx.x;
if (id >= n || objs[id] >= bestObjs[id]) {
return;
}
bestObjs[id] = objs[id];
#pragma unroll
for (int m = 0; m < M; m++) {
bestCodes[id * M + m] = codes[id * M + m];
}
}
/** add L2 norm of codewords in a codebook to the unary terms
*
* uterm[i][k] += norm[k]
*
* @param uterm unary terms, size [n, K]
* @param norm L2 norm of each codeword in a codebook, size [K]
* @param K number of codewords in a codebook
*/
__global__ void runNormAddition(float* uterm, const float* norm, int K) {
int id = blockIdx.x;
int code = threadIdx.x;
uterm[id * K + code] += norm[code];
}
IcmEncoderImpl::IcmEncoderImpl(
int M,
int K,
int dims,
GpuResourcesProvider* prov,
int device)
: M(M), K(K), dims(dims), prov(prov), device(device) {
res = prov->getResources();
}
void IcmEncoderImpl::computeUnaryTerms(
float* uterm, // output, [M, n, K]
const float* x, // [n, d]
const float* codebooks, // [M, K, d]
int n) const {
auto stream = res->getDefaultStreamCurrentDevice();
auto handle = res->getBlasHandleCurrentDevice();
DeviceTensor<float, 2, true> vecs(const_cast<float*>(x), {n, dims});
for (int m = 0; m < M; m++) {
auto cPtr = const_cast<float*>(codebooks + m * K * dims);
auto bPtr = uterm + m * n * K;
DeviceTensor<float, 2, true> ci(cPtr, {K, dims});
DeviceTensor<float, 2, true> bi(bPtr, {n, K});
runMatrixMult(
bi, false, vecs, false, ci, true, -2.0f, 0.0f, handle, stream);
}
DeviceTensor<float, 2, true> c(
const_cast<float*>(codebooks), {M * K, dims});
DeviceTensor<float, 1, true> norm(
res.get(), makeTempAlloc(AllocType::Other, stream), {M * K});
runL2Norm(c, true, norm, true, stream);
for (int m = 0; m < M; m++) {
auto uPtr = uterm + m * n * K;
auto nPtr = norm.data() + m * K;
runNormAddition<<<n, K, 0, stream>>>(uPtr, nPtr, K);
}
}
void IcmEncoderImpl::computeBinaryTerms(float* bterm, const float* codebooks)
const {
auto stream = res->getDefaultStreamCurrentDevice();
auto handle = res->getBlasHandleCurrentDevice();
for (int m1 = 0; m1 < M; m1++) {
for (int m2 = 0; m2 < M; m2++) {
auto ptr1 = const_cast<float*>(codebooks + m1 * K * dims);
auto ptr2 = const_cast<float*>(codebooks + m2 * K * dims);
auto ptr3 = bterm + m1 * M * K * K + m2 * K * K;
DeviceTensor<float, 2, true> c1(ptr1, {K, dims});
DeviceTensor<float, 2, true> c2(ptr2, {K, dims});
DeviceTensor<float, 2, true> b(ptr3, {K, K});
runMatrixMult(
b, false, c1, false, c2, true, 2.0f, 0.0f, handle, stream);
}
}
}
void IcmEncoderImpl::setBinaryTerm(const float* codebooksHost) {
DeviceScope scope(device);
auto device = getCurrentDevice();
auto stream = res->getDefaultStreamCurrentDevice();
// copy from host to device memory
codebooks = toDeviceNonTemporary<float, 3>(
res.get(),
device,
const_cast<float*>(codebooksHost),
stream,
{M, K, dims});
bterm = DeviceTensor<float, 4, true>(
res.get(), makeDevAlloc(AllocType::Other, stream), {M, M, K, K});
computeBinaryTerms(bterm.data(), codebooks.data());
}
void IcmEncoderImpl::encode(
int32_t* codesHost,
const float* xHost,
const float* codebooksHost,
std::mt19937& gen,
int n,
int nperts,
int ilsIters,
int icmIters) const {
DeviceScope scope(device);
auto device = getCurrentDevice();
auto stream = res->getDefaultStreamCurrentDevice();
// copy from host to device memory
auto codes = toDeviceTemporary<int32_t, 2>(
res.get(), device, const_cast<int32_t*>(codesHost), stream, {n, M});
auto x = toDeviceTemporary<float, 2>(
res.get(), device, const_cast<float*>(xHost), stream, {n, dims});
// compute unary terms
DeviceTensor<float, 3, true> uterm(
res.get(), makeTempAlloc(AllocType::Other, stream), {M, n, K});
computeUnaryTerms(uterm.data(), x.data(), codebooks.data(), n);
DeviceTensor<int32_t, 2, true> bestCodes(
res.get(), makeTempAlloc(AllocType::Other, stream), {n, M});
fromDevice<int32_t, 2>(codes, bestCodes.data(), stream);
DeviceTensor<float, 1, true> bestObjs(
res.get(), makeTempAlloc(AllocType::Other, stream), {n});
DeviceTensor<float, 1, true> objs(
res.get(), makeTempAlloc(AllocType::Other, stream), {n});
// compute how much shared memory we need
const int evaluateSmem = sizeof(float) * (dims + kWarpSize - 1) / kWarpSize;
const int encodeSmem =
sizeof(Pair<float, int>) * (K + kWarpSize - 1) / kWarpSize;
// compute the reconstruction error for each vector
runEvaluation<<<n, dims, evaluateSmem, stream>>>(
x.data(),
codebooks.data(),
codes.data(),
bestObjs.data(),
n,
M,
K,
dims);
int blockSize = 256;
int numBlocks = (n + blockSize - 1) / blockSize;
for (int i = 0; i < ilsIters; i++) {
runCodesPerturbation<<<numBlocks, blockSize, 0, stream>>>(
gen(), codes.data(), n, M, K, nperts);
// perform icm encoding
for (int j = 0; j < icmIters; j++) {
for (int m = 0; m < M; m++) {
runIcmEncodeStep<<<n, K, encodeSmem, stream>>>(
uterm[m].data(),
bterm[m].data(),
codes.data(),
M,
K,
m);
}
}
// compute the reconstruction error for each vector given codes
runEvaluation<<<n, dims, evaluateSmem, stream>>>(
x.data(),
codebooks.data(),
codes.data(),
objs.data(),
n,
M,
K,
dims);
// if objs[i] < best_objs[i], replace best_codes[i] with codes[i]
runCodesSelection<<<numBlocks, blockSize, 0, stream>>>(
bestCodes.data(),
bestObjs.data(),
codes.data(),
objs.data(),
n,
M);
codes.copyFrom(bestCodes, stream);
}
// copy back to host memory
fromDevice<int32_t, 2>(bestCodes, codesHost, stream);
}
} // namespace gpu
} // namespace faiss
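
For reference, one launch of `runIcmEncodeStep` per subcode `m` is equivalent to this numpy version of a full ICM pass (adapted from `icm_encode_step_ref` in the updated tests; note the `[M, n, K]` unary layout):

``` python
import numpy as np

def icm_encode_step(unaries, binaries, codes):
    """unaries: [M, n, K], binaries: [M, M, K, K], codes: [n, M], updated in place."""
    M, n, K = unaries.shape
    for m in range(M):
        objs = unaries[m].copy()                    # [n, K]
        for m2 in range(M):                         # add binary terms from the fixed subcodes
            if m2 == m:
                continue
            # objs[i, k] += binaries[m, m2, k, codes[i, m2]]
            objs += binaries[m, m2][:, codes[:, m2]].T
        codes[:, m] = objs.argmin(axis=1)           # best value of the m-th subcode
    return codes
```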


@@ -0,0 +1,79 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <faiss/gpu/GpuResources.h>
#include <faiss/gpu/utils/DeviceTensor.cuh>
#include <random>
namespace faiss {
namespace gpu {
struct IcmEncoderImpl {
int M; ///< number of codebooks
int K; ///< number of codewords in a codebook
int dims; ///< dimensions of a codeword
GpuResourcesProvider* prov;
std::shared_ptr<GpuResources> res;
int device;
DeviceTensor<float, 4, true> bterm; ///< binary terms, size [M, M, K, K]
DeviceTensor<float, 3, true> codebooks; ///< codebooks, size [M, K, dims]
IcmEncoderImpl(
int M,
int K,
int dims,
GpuResourcesProvider* prov,
int device);
~IcmEncoderImpl() {}
///< copy codebooks to device memory and precompute binary terms
void setBinaryTerm(const float* codebooks);
/** Compute unary terms.
*
* uterm[i] = x * codebook[i]^T, i = 1,...,M
*
* @param uterm output unary terms, size [M, n, K]
* @param x input vectors, size [n, dims]
* @param codebooks codebooks, size [M, K, dims]
* @param n number of input vectors
*/
void computeUnaryTerms(
float* uterm,
const float* x,
const float* codebooks,
int n) const;
/** Compute binary terms.
*
* bterm[i][j] = codebooks[i] * codebooks[j]^T. i, j = 1,...,M
*
* @param bterm output binary terms, size [M, M, K, K]
* @param codebooks codebooks, size [M, K, dims]
*/
void computeBinaryTerms(float* bterm, const float* codebooks) const;
///< icm encode method
void encode(
int32_t* codes,
const float* x,
const float* codebooks,
std::mt19937& gen,
int n,
int nperts,
int ilsIters,
int icmIters) const;
};
} // namespace gpu
} // namespace faiss


@@ -10,6 +10,7 @@ import time
import unittest
import numpy as np
import faiss
from faiss.contrib import datasets
class EvalIVFPQAccuracy(unittest.TestCase):
@@ -462,5 +463,44 @@ class TestInvalidParams(unittest.TestCase):
self.assertTrue(np.array_equal(xb_indices[10:20], I[:, 0]))
class TestLSQIcmEncoder(unittest.TestCase):
@staticmethod
def eval_codec(q, xb):
codes = q.compute_codes(xb)
decoded = q.decode(codes)
return ((xb - decoded) ** 2).sum()
def subtest_gpu_encoding(self, ngpus):
"""check that the error is in the same as cpu."""
ds = datasets.SyntheticDataset(32, 1000, 1000, 0)
xt = ds.get_train()
xb = ds.get_database()
M = 4
nbits = 8
lsq = faiss.LocalSearchQuantizer(ds.d, M, nbits)
lsq.train(xt)
err_cpu = self.eval_codec(lsq, xb)
lsq = faiss.LocalSearchQuantizer(ds.d, M, nbits)
lsq.train(xt)
lsq.icm_encoder_factory = faiss.GpuIcmEncoderFactory(ngpus)
err_gpu = self.eval_codec(lsq, xb)
# 13804.411 vs 13814.794, 1 gpu
print(err_gpu, err_cpu)
self.assertLess(err_gpu, err_cpu * 1.05)
def test_one_gpu(self):
self.subtest_gpu_encoding(1)
def test_multiple_gpu(self):
ngpu = faiss.get_num_gpus()
self.subtest_gpu_encoding(ngpu)
if __name__ == '__main__':
unittest.main()
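
Outside of buck, the same cases should be runnable with plain `unittest` (assuming the OSS layout, from `faiss/gpu/test`): `python -m unittest test_gpu_index.TestLSQIcmEncoder`.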


@@ -141,7 +141,8 @@ void random_int32(
namespace faiss {
LSQTimer lsq_timer;
lsq::LSQTimer lsq_timer;
using lsq::LSQTimerScope;
LocalSearchQuantizer::LocalSearchQuantizer(
size_t d,
@@ -168,6 +169,12 @@ LocalSearchQuantizer::LocalSearchQuantizer(
random_seed = 0x12345;
std::srand(random_seed);
icm_encoder_factory = nullptr;
}
LocalSearchQuantizer::~LocalSearchQuantizer() {
delete icm_encoder_factory;
}
LocalSearchQuantizer::LocalSearchQuantizer() : LocalSearchQuantizer(0, 0, 0) {}
@@ -177,8 +184,8 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
FAISS_THROW_IF_NOT(nperts <= M);
lsq_timer.reset();
LSQTimerScope scope(&lsq_timer, "train");
if (verbose) {
lsq_timer.start("train");
printf("Training LSQ, with %zd subcodes on %zd %zdD vectors\n",
M,
n,
@@ -241,7 +248,7 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
}
// refine codes
icm_encode(x, codes.data(), n, train_ils_iters, gen);
icm_encode(codes.data(), x, n, train_ils_iters, gen);
if (verbose) {
float obj = evaluate(codes.data(), x, n);
@@ -269,13 +276,13 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
}
if (verbose) {
lsq_timer.end("train");
float obj = evaluate(codes.data(), x, n);
scope.finish();
printf("After training: obj = %lf\n", obj);
printf("Time statistic:\n");
for (const auto& it : lsq_timer.duration) {
printf("\t%s time: %lf s\n", it.first.data(), it.second);
for (const auto& it : lsq_timer.t) {
printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
}
}
}
@@ -284,7 +291,7 @@ void LocalSearchQuantizer::perturb_codebooks(
float T,
const std::vector<float>& stddev,
std::mt19937& gen) {
lsq_timer.start("perturb_codebooks");
LSQTimerScope scope(&lsq_timer, "perturb_codebooks");
std::vector<std::normal_distribution<float>> distribs;
for (size_t i = 0; i < d; i++) {
@@ -298,8 +305,6 @@ void LocalSearchQuantizer::perturb_codebooks(
}
}
}
lsq_timer.end("perturb_codebooks");
}
void LocalSearchQuantizer::compute_codes(
@@ -307,23 +312,26 @@
uint8_t* codes_out,
size_t n) const {
FAISS_THROW_IF_NOT_MSG(is_trained, "LSQ is not trained yet.");
lsq_timer.reset();
LSQTimerScope scope(&lsq_timer, "encode");
if (verbose) {
lsq_timer.reset();
printf("Encoding %zd vectors...\n", n);
lsq_timer.start("encode");
}
std::vector<int32_t> codes(n * M);
std::mt19937 gen(random_seed);
random_int32(codes, 0, K - 1, gen);
icm_encode(x, codes.data(), n, encode_ils_iters, gen);
icm_encode(codes.data(), x, n, encode_ils_iters, gen);
pack_codes(n, codes.data(), codes_out);
if (verbose) {
lsq_timer.end("encode");
double t = lsq_timer.get("encode");
printf("Time to encode %zd vectors: %lf s\n", n, t);
scope.finish();
printf("Time statistic:\n");
for (const auto& it : lsq_timer.t) {
printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
}
}
}
@@ -347,7 +355,7 @@ void LocalSearchQuantizer::update_codebooks(
const float* x,
const int32_t* codes,
size_t n) {
lsq_timer.start("update_codebooks");
LSQTimerScope scope(&lsq_timer, "update_codebooks");
if (!update_codebooks_with_double) {
// allocate memory
@@ -485,8 +493,6 @@ void LocalSearchQuantizer::update_codebooks(
codebooks[i] = (float)d_codebooks[i];
}
}
lsq_timer.end("update_codebooks");
}
/** encode using iterative conditional mode
@@ -508,15 +514,23 @@ void LocalSearchQuantizer::update_codebooks(
* These two terms can be precomputed and stored in a look-up table.
*/
void LocalSearchQuantizer::icm_encode(
const float* x,
int32_t* codes,
const float* x,
size_t n,
size_t ils_iters,
std::mt19937& gen) const {
lsq_timer.start("icm_encode");
LSQTimerScope scope(&lsq_timer, "icm_encode");
std::vector<float> binaries(M * M * K * K); // [M, M, K, K]
compute_binary_terms(binaries.data());
auto factory = icm_encoder_factory;
std::unique_ptr<lsq::IcmEncoder> icm_encoder;
if (factory == nullptr) {
icm_encoder.reset(lsq::IcmEncoderFactory().get(this));
} else {
icm_encoder.reset(factory->get(this));
}
// precompute binary terms for all chunks
icm_encoder->set_binary_term();
const size_t n_chunks = (n + chunk_size - 1) / chunk_size;
for (size_t i = 0; i < n_chunks; i++) {
@@ -532,23 +546,20 @@ void LocalSearchQuantizer::icm_encode(
const float* xi = x + i * chunk_size * d;
int32_t* codesi = codes + i * chunk_size * M;
icm_encode_partial(i, xi, codesi, ni, binaries.data(), ils_iters, gen);
InterruptCallback::check();
icm_encoder->verbose = (verbose && i == 0);
icm_encoder->encode(codesi, xi, gen, ni, ils_iters);
}
lsq_timer.end("icm_encode");
}
void LocalSearchQuantizer::icm_encode_partial(
size_t index,
const float* x,
void LocalSearchQuantizer::icm_encode_impl(
int32_t* codes,
size_t n,
const float* x,
const float* binaries,
std::mt19937& gen,
size_t n,
size_t ils_iters,
std::mt19937& gen) const {
std::vector<float> unaries(n * M * K); // [n, M, K]
bool verbose) const {
std::vector<float> unaries(n * M * K); // [M, n, K]
compute_unary_terms(x, unaries.data(), n);
std::vector<int32_t> best_codes;
@@ -562,9 +573,7 @@ void LocalSearchQuantizer::icm_encode_partial(
// add perturbation to codes
perturb_codes(codes, n, gen);
for (size_t iter2 = 0; iter2 < icm_iters; iter2++) {
icm_encode_step(unaries.data(), binaries, codes, n);
}
icm_encode_step(codes, unaries.data(), binaries, n, icm_iters);
std::vector<float> icm_objs(n, 0.0f);
evaluate(codes, x, n, icm_objs.data());
@@ -587,7 +596,7 @@ void LocalSearchQuantizer::icm_encode_partial(
memcpy(codes, best_codes.data(), sizeof(int32_t) * n * M);
if (verbose && index == 0) {
if (verbose) {
printf("\tils_iter %zd: obj = %lf, n_betters/n = %zd/%zd\n",
iter1,
mean_obj,
@@ -598,61 +607,67 @@ void LocalSearchQuantizer::icm_encode_partial(
}
void LocalSearchQuantizer::icm_encode_step(
int32_t* codes,
const float* unaries,
const float* binaries,
int32_t* codes,
size_t n) const {
// condition on the m-th subcode
for (size_t m = 0; m < M; m++) {
std::vector<float> objs(n * K);
#pragma omp parallel for
for (int64_t i = 0; i < n; i++) {
auto u = unaries + i * (M * K) + m * K;
memcpy(objs.data() + i * K, u, sizeof(float) * K);
}
// compute objective function by adding unary
// and binary terms together
for (size_t other_m = 0; other_m < M; other_m++) {
if (other_m == m) {
continue;
}
size_t n,
size_t n_iters) const {
FAISS_THROW_IF_NOT(M != 0 && K != 0);
FAISS_THROW_IF_NOT(binaries != nullptr);
for (size_t iter = 0; iter < n_iters; iter++) {
// condition on the m-th subcode
for (size_t m = 0; m < M; m++) {
std::vector<float> objs(n * K);
#pragma omp parallel for
for (int64_t i = 0; i < n; i++) {
for (int32_t code = 0; code < K; code++) {
int32_t code2 = codes[i * M + other_m];
size_t binary_idx =
m * M * K * K + other_m * K * K + code * K + code2;
// binaries[m, other_m, code, code2]
objs[i * K + code] += binaries[binary_idx];
}
auto u = unaries + m * n * K + i * K;
memcpy(objs.data() + i * K, u, sizeof(float) * K);
}
}
// find the optimal value of the m-th subcode
// compute objective function by adding unary
// and binary terms together
for (size_t other_m = 0; other_m < M; other_m++) {
if (other_m == m) {
continue;
}
#pragma omp parallel for
for (int64_t i = 0; i < n; i++) {
float best_obj = HUGE_VALF;
int32_t best_code = 0;
for (size_t code = 0; code < K; code++) {
float obj = objs[i * K + code];
if (obj < best_obj) {
best_obj = obj;
best_code = code;
for (int64_t i = 0; i < n; i++) {
for (int32_t code = 0; code < K; code++) {
int32_t code2 = codes[i * M + other_m];
size_t binary_idx = m * M * K * K + other_m * K * K +
code * K + code2;
// binaries[m, other_m, code, code2]
objs[i * K + code] += binaries[binary_idx];
}
}
}
codes[i * M + m] = best_code;
}
} // loop M
// find the optimal value of the m-th subcode
#pragma omp parallel for
for (int64_t i = 0; i < n; i++) {
float best_obj = HUGE_VALF;
int32_t best_code = 0;
for (size_t code = 0; code < K; code++) {
float obj = objs[i * K + code];
if (obj < best_obj) {
best_obj = obj;
best_code = code;
}
}
codes[i * M + m] = best_code;
}
} // loop M
}
}
void LocalSearchQuantizer::perturb_codes(
int32_t* codes,
size_t n,
std::mt19937& gen) const {
lsq_timer.start("perturb_codes");
LSQTimerScope scope(&lsq_timer, "perturb_codes");
std::uniform_int_distribution<size_t> m_distrib(0, M - 1);
std::uniform_int_distribution<int32_t> k_distrib(0, K - 1);
@@ -663,12 +678,10 @@ void LocalSearchQuantizer::perturb_codes(
codes[i * M + m] = k_distrib(gen);
}
}
lsq_timer.end("perturb_codes");
}
void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
lsq_timer.start("compute_binary_terms");
LSQTimerScope scope(&lsq_timer, "compute_binary_terms");
#pragma omp parallel for
for (int64_t m12 = 0; m12 < M * M; m12++) {
@@ -686,52 +699,53 @@ void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
}
}
}
lsq_timer.end("compute_binary_terms");
}
void LocalSearchQuantizer::compute_unary_terms(
const float* x,
float* unaries,
float* unaries, // [M, n, K]
size_t n) const {
lsq_timer.start("compute_unary_terms");
LSQTimerScope scope(&lsq_timer, "compute_unary_terms");
// compute x * codebooks^T
// compute x * codebook^T for each codebook
//
// NOTE: LAPACK uses column-major order
// out = alpha * op(A) * op(B) + beta * C
FINTEGER nrows_A = M * K;
FINTEGER ncols_A = d;
FINTEGER nrows_B = d;
FINTEGER ncols_B = n;
for (size_t m = 0; m < M; m++) {
FINTEGER nrows_A = K;
FINTEGER ncols_A = d;
float alpha = -2.0f;
float beta = 0.0f;
sgemm_("Transposed",
"Not Transposed",
&nrows_A, // nrows of op(A)
&ncols_B, // ncols of op(B)
&ncols_A, // ncols of op(A)
&alpha,
codebooks.data(),
&ncols_A, // nrows of A
x,
&nrows_B, // nrows of B
&beta,
unaries,
&nrows_A); // nrows of output
FINTEGER nrows_B = d;
FINTEGER ncols_B = n;
float alpha = -2.0f;
float beta = 0.0f;
sgemm_("Transposed",
"Not Transposed",
&nrows_A, // nrows of op(A)
&ncols_B, // ncols of op(B)
&ncols_A, // ncols of op(A)
&alpha,
codebooks.data() + m * K * d,
&ncols_A, // nrows of A
x,
&nrows_B, // nrows of B
&beta,
unaries + m * n * K,
&nrows_A); // nrows of output
}
std::vector<float> norms(M * K);
fvec_norms_L2sqr(norms.data(), codebooks.data(), d, M * K);
#pragma omp parallel for
for (int64_t i = 0; i < n; i++) {
float* u = unaries + i * (M * K);
fvec_add(M * K, u, norms.data(), u);
for (size_t m = 0; m < M; m++) {
float* u = unaries + m * n * K + i * K;
fvec_add(K, u, norms.data() + m * K, u);
}
}
lsq_timer.end("compute_unary_terms");
}
float LocalSearchQuantizer::evaluate(
@@ -739,7 +753,7 @@ float LocalSearchQuantizer::evaluate(
const float* x,
size_t n,
float* objs) const {
lsq_timer.start("evaluate");
LSQTimerScope scope(&lsq_timer, "evaluate");
// decode
std::vector<float> decoded_x(n * d, 0.0f);
@@ -755,7 +769,7 @@ float LocalSearchQuantizer::evaluate(
fvec_add(d, decoded_i, c, decoded_i);
}
float err = fvec_L2sqr(x + i * d, decoded_i, d);
float err = faiss::fvec_L2sqr(x + i * d, decoded_i, d);
obj += err;
if (objs) {
@@ -763,34 +777,68 @@
}
}
lsq_timer.end("evaluate");
obj = obj / n;
return obj;
}
namespace lsq {
IcmEncoder::IcmEncoder(const LocalSearchQuantizer* lsq)
: verbose(false), lsq(lsq) {}
void IcmEncoder::set_binary_term() {
auto M = lsq->M;
auto K = lsq->K;
binaries.resize(M * M * K * K);
lsq->compute_binary_terms(binaries.data());
}
void IcmEncoder::encode(
int32_t* codes,
const float* x,
std::mt19937& gen,
size_t n,
size_t ils_iters) const {
lsq->icm_encode_impl(codes, x, binaries.data(), gen, n, ils_iters, verbose);
}
double LSQTimer::get(const std::string& name) {
return duration[name];
if (t.count(name) == 0) {
return 0.0;
} else {
return t[name];
}
}
void LSQTimer::start(const std::string& name) {
FAISS_THROW_IF_NOT_MSG(!started[name], " timer is already running");
started[name] = true;
t0[name] = getmillisecs();
}
void LSQTimer::end(const std::string& name) {
FAISS_THROW_IF_NOT_MSG(started[name], " timer is not running");
double t1 = getmillisecs();
double sec = (t1 - t0[name]) / 1000;
duration[name] += sec;
started[name] = false;
void LSQTimer::add(const std::string& name, double delta) {
if (t.count(name) == 0) {
t[name] = delta;
} else {
t[name] += delta;
}
}
void LSQTimer::reset() {
duration.clear();
t0.clear();
started.clear();
t.clear();
}
LSQTimerScope::LSQTimerScope(LSQTimer* timer, std::string name)
: timer(timer), name(name), finished(false) {
t0 = getmillisecs();
}
void LSQTimerScope::finish() {
if (!finished) {
auto delta = getmillisecs() - t0;
timer->add(name, delta);
finished = true;
}
}
LSQTimerScope::~LSQTimerScope() {
finish();
}
} // namespace lsq
} // namespace faiss
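
A note on the layout change running through this file: unary terms moved from `[n, M, K]` to `[M, n, K]`, so each codebook gets one contiguous block (one `sgemm_` per codebook on CPU, and per-codebook tensors on GPU). In numpy terms, consistent with `compute_unary_terms_ref` in the tests:

``` python
import numpy as np

def compute_unary_terms(codebooks, x):
    """codebooks: [M, K, d], x: [n, d] -> unaries: [M, n, K]."""
    # -2 * x @ C_m^T for each codebook, then add the squared codeword norms
    unaries = np.stack([-2.0 * x @ c.T for c in codebooks])  # [M, n, K]
    unaries += (codebooks ** 2).sum(axis=2)[:, None, :]      # broadcast ||c_k||^2 over n
    return unaries
```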


@@ -20,6 +20,12 @@
namespace faiss {
namespace lsq {
struct IcmEncoderFactory;
} // namespace lsq
/** Implementation of LSQ/LSQ++ described in the following two papers:
*
* Revisiting additive quantization
@@ -36,7 +42,6 @@ namespace faiss {
* The trained codes are stored in `codebooks` which is called
* `centroids` in PQ and RQ.
*/
struct LocalSearchQuantizer : AdditiveQuantizer {
size_t K; ///< number of codes per codebook
@@ -54,6 +59,9 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
int random_seed; ///< seed for random generator
size_t nperts; ///< number of perturbations in each code
///< if non-NULL, use this factory to build the ICM encoder
lsq::IcmEncoderFactory* icm_encoder_factory;
bool update_codebooks_with_double = true;
LocalSearchQuantizer(
@@ -66,6 +74,8 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
LocalSearchQuantizer();
~LocalSearchQuantizer() override;
// Train the local search quantizer
void train(size_t n, const float* x) override;
@@ -73,6 +83,7 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
*
* @param x vectors to encode, size n * d
* @param codes output codes, size n * code_size
* @param n number of vectors
*/
void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
@@ -80,36 +91,46 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
*
* @param x training vectors, size n * d
* @param codes encoded training vectors, size n * M
* @param n number of vectors
*/
void update_codebooks(const float* x, const int32_t* codes, size_t n);
/** Encode vectors given codebooks using iterative conditional mode (icm).
*
* @param x vectors to encode, size n * d
* @param codes output codes, size n * M
* @param codes output codes, size n * M
* @param x vectors to encode, size n * d
* @param n number of vectors
* @param ils_iters number of iterations of iterative local search
*/
void icm_encode(
const float* x,
int32_t* codes,
const float* x,
size_t n,
size_t ils_iters,
std::mt19937& gen) const;
void icm_encode_partial(
size_t index,
const float* x,
void icm_encode_impl(
int32_t* codes,
const float* x,
const float* unaries,
std::mt19937& gen,
size_t n,
const float* binaries,
size_t ils_iters,
std::mt19937& gen) const;
bool verbose) const;
void icm_encode_step(
int32_t* codes,
const float* unaries,
const float* binaries,
int32_t* codes,
size_t n) const;
size_t n,
size_t n_iters) const;
/** Add some perturbation to codes
*
* @param codes codes to be perturbed, size n * M
* @param n number of vectors
*/
void perturb_codes(int32_t* codes, size_t n, std::mt19937& gen) const;
/** Add some perturbation to codebooks
*
@@ -121,12 +142,6 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
const std::vector<float>& stddev,
std::mt19937& gen);
/** Add some perturbation to codes
*
* @param codes codes to be perturbed, size n * M
*/
void perturb_codes(int32_t* codes, size_t n, std::mt19937& gen) const;
/** Compute binary terms
*
* @param binaries binary terms, size M * M * K * K
@@ -135,6 +150,7 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
/** Compute unary terms
*
* @param n number of vectors
* @param x vectors to encode, size n * d
* @param unaries unary terms, size n * M * K
*/
@@ -142,8 +158,9 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
/** Helper function to compute reconstruction error
*
* @param x vectors to encode, size n * d
* @param codes encoded codes, size n * M
* @param x vectors to encode, size n * d
* @param n number of vectors
* @param objs if it is not null, store reconstruction
error of each vector into it, size n
*/
@@ -154,13 +171,50 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
float* objs = nullptr) const;
};
namespace lsq {
struct IcmEncoder {
std::vector<float> binaries;
bool verbose;
const LocalSearchQuantizer* lsq;
explicit IcmEncoder(const LocalSearchQuantizer* lsq);
virtual ~IcmEncoder() {}
///< compute binary terms
virtual void set_binary_term();
/** Encode vectors given codebooks
*
* @param codes output codes, size n * M
* @param x vectors to encode, size n * d
* @param gen random generator
* @param n number of vectors
* @param ils_iters number of iterations of iterative local search
*/
virtual void encode(
int32_t* codes,
const float* x,
std::mt19937& gen,
size_t n,
size_t ils_iters) const;
};
struct IcmEncoderFactory {
virtual IcmEncoder* get(const LocalSearchQuantizer* lsq) {
return new IcmEncoder(lsq);
}
virtual ~IcmEncoderFactory() {}
};
/** A helper struct to measure the time consumed during training.
* It is NOT thread-safe.
*/
struct LSQTimer {
std::unordered_map<std::string, double> duration;
std::unordered_map<std::string, double> t0;
std::unordered_map<std::string, bool> started;
std::unordered_map<std::string, double> t;
LSQTimer() {
reset();
@@ -168,13 +222,24 @@ struct LSQTimer {
double get(const std::string& name);
void start(const std::string& name);
void end(const std::string& name);
void add(const std::string& name, double delta);
void reset();
};
FAISS_API extern LSQTimer lsq_timer; ///< timer to count consuming time
struct LSQTimerScope {
double t0;
LSQTimer* timer;
std::string name;
bool finished;
LSQTimerScope(LSQTimer* timer, std::string name);
void finish();
~LSQTimerScope();
};
} // namespace lsq
} // namespace faiss
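
For completeness, the binary terms that `set_binary_term()` precomputes are (twice) the pairwise inner products between codewords of different codebooks; a numpy sketch consistent with the GPU `computeBinaryTerms` above (which uses alpha = 2):

``` python
import numpy as np

def compute_binary_terms(codebooks):
    """codebooks: [M, K, d] -> binaries: [M, M, K, K]."""
    # binaries[m1, m2, k1, k2] = 2 * <codebooks[m1, k1], codebooks[m2, k2]>
    return 2.0 * np.einsum('ikd,jld->ijkl', codebooks, codebooks)
```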


@@ -287,6 +287,7 @@ void gpu_sync_all_devices();
#include <faiss/gpu/GpuAutoTune.h>
#include <faiss/gpu/GpuCloner.h>
#include <faiss/gpu/GpuDistance.h>
#include <faiss/gpu/GpuIcmEncoder.h>
int get_num_gpus()
{
@@ -490,6 +491,7 @@ void gpu_sync_all_devices()
%include <faiss/gpu/GpuIndexIVFScalarQuantizer.h>
%include <faiss/gpu/GpuIndexBinaryFlat.h>
%include <faiss/gpu/GpuDistance.h>
%include <faiss/gpu/GpuIcmEncoder.h>
#endif


@@ -15,6 +15,9 @@ import unittest
from faiss.contrib import datasets
sp = faiss.swig_ptr
def construct_sparse_matrix(codes, K):
n, M = codes.shape
B = np.zeros((n, M * K), dtype=np.float32)
@@ -54,18 +57,18 @@ def compute_binary_terms_ref(codebooks):
def compute_unary_terms_ref(codebooks, x):
codebooks_t = np.swapaxes(codebooks, 1, 2) # [M, d, K]
unaries = -2 * x.dot(codebooks_t) # [n, M, K]
code_norms = np.sum(codebooks * codebooks, axis=2) # [M, K]
unaries += code_norms
unaries = np.swapaxes(unaries, 0, 1) # [M, n, K]
return unaries
def icm_encode_step_ref(unaries, binaries, codes):
n, M, K = unaries.shape
M, n, K = unaries.shape
for m in range(M):
objs = unaries[:, m].copy() # [n, K]
objs = unaries[m].copy() # [n, K]
for m2 in range(M): # pair, m2 != m
if m2 == m:
@@ -131,8 +134,8 @@ class TestComponents(unittest.TestCase):
# decode x
pack_codes = np.zeros((n, lsq.code_size)).astype(np.uint8)
decoded_x = np.zeros((n, d)).astype(np.float32)
lsq.pack_codes(n, faiss.swig_ptr(codes), faiss.swig_ptr(pack_codes))
lsq.decode_c(faiss.swig_ptr(pack_codes), faiss.swig_ptr(decoded_x), n)
lsq.pack_codes(n, sp(codes), sp(pack_codes))
lsq.decode_c(sp(pack_codes), sp(decoded_x), n)
# decode in Python
codebooks = faiss.vector_float_to_array(lsq.codebooks)
@@ -163,7 +166,7 @@ class TestComponents(unittest.TestCase):
codebooks = faiss.vector_float_to_array(lsq.codebooks)
codebooks = codebooks.reshape(M, K, d).copy()
lsq.update_codebooks(faiss.swig_ptr(x), faiss.swig_ptr(codes), n)
lsq.update_codebooks(sp(x), sp(codes), n)
new_codebooks = faiss.vector_float_to_array(lsq.codebooks)
new_codebooks = new_codebooks.reshape(M, K, d).copy()
@@ -209,7 +212,7 @@ class TestComponents(unittest.TestCase):
lsq = faiss.LocalSearchQuantizer(d, M, nbits)
lsq.train(x) # just for allocating memory for codebooks
lsq.compute_binary_terms(faiss.swig_ptr(binaries))
lsq.compute_binary_terms(sp(binaries))
codebooks = faiss.vector_float_to_array(lsq.codebooks)
codebooks = codebooks.reshape(M, K, d).copy()
@@ -226,12 +229,12 @@ class TestComponents(unittest.TestCase):
rs = np.random.RandomState(123)
x = rs.rand(n, d).astype(np.float32)
unaries = np.zeros((n, M, K)).astype(np.float32)
unaries = np.zeros((M, n, K)).astype(np.float32)
lsq = faiss.LocalSearchQuantizer(d, M, nbits)
lsq.train(x) # just for allocating memory for codebooks
lsq.compute_unary_terms(faiss.swig_ptr(x), faiss.swig_ptr(unaries), n)
lsq.compute_unary_terms(sp(x), sp(unaries), n)
codebooks = faiss.vector_float_to_array(lsq.codebooks)
codebooks = codebooks.reshape(M, K, d).copy()
@@ -251,15 +254,17 @@ class TestComponents(unittest.TestCase):
# randomly generate codes, binary terms and unary terms
codes = rs.randint(0, K, (n, M)).astype(np.int32)
new_codes = codes.copy()
unaries = rs.rand(n, M, K).astype(np.float32)
unaries = rs.rand(M, n, K).astype(np.float32)
binaries = rs.rand(M, M, K, K).astype(np.float32)
# do icm encoding given binary and unary terms
lsq = faiss.LocalSearchQuantizer(d, M, nbits)
lsq.icm_encode_step(
faiss.swig_ptr(unaries),
faiss.swig_ptr(binaries),
faiss.swig_ptr(new_codes), n)
sp(new_codes),
sp(unaries),
sp(binaries),
n,
1)
# do icm encoding given binary and unary terms in Python
ref_codes = icm_encode_step_ref(unaries, binaries, codes)
@@ -280,11 +285,11 @@ class TestComponents(unittest.TestCase):
# compute binary terms
binaries = np.zeros((M, M, K, K)).astype(np.float32)
lsq.compute_binary_terms(faiss.swig_ptr(binaries))
lsq.compute_binary_terms(sp(binaries))
# compute unary terms
unaries = np.zeros((n, M, K)).astype(np.float32)
lsq.compute_unary_terms(faiss.swig_ptr(x), faiss.swig_ptr(unaries), n)
unaries = np.zeros((M, n, K)).astype(np.float32)
lsq.compute_unary_terms(sp(x), sp(unaries), n)
# randomly generate codes
codes = rs.randint(0, K, (n, M)).astype(np.int32)
@@ -292,9 +297,11 @@ class TestComponents(unittest.TestCase):
# do icm encoding given binary and unary terms
lsq.icm_encode_step(
faiss.swig_ptr(unaries),
faiss.swig_ptr(binaries),
faiss.swig_ptr(new_codes), n)
sp(new_codes),
sp(unaries),
sp(binaries),
n,
1)
# do icm encoding without pre-computed unary and binary terms in Python
codebooks = faiss.vector_float_to_array(lsq.codebooks)