Support LSQ on GPU (#1978)
Summary: ## Description This PR added support for LSQ on GPU. Only the encoding part is running on GPU and the others are still running on CPU. Multi-GPU is also supported. ## Usage ``` python lsq = faiss.LocalSearchQuantizer(d, M, nbits) ngpus = faiss.get_num_gpus() lsq.icm_encoder_factory = faiss.GpuIcmEncoderFactory(ngpus) # we use all gpus lsq.train(xt) codes = lsq.compute_codes(xb) decoded = lsq.decode(codes) ``` ## Performance on SIFT1M On 1 GPU: ``` ===== lsq-gpu: mean square error = 17337.878528 training time: 40.9857234954834 s encoding time: 27.12640070915222 s ``` On 2 GPUs: ``` ===== lsq-gpu: mean square error = 17364.658176 training time: 25.832106113433838 s encoding time: 14.879548072814941 s ``` On CPU: ``` ===== lsq: mean square error = 17305.880576 training time: 152.57522344589233 s encoding time: 110.01779270172119 s ``` Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1978 Test Plan: buck test mode/dev-nosan //faiss/gpu/test/:test_gpu_index_py -- TestLSQIcmEncoder Reviewed By: wickedfoo Differential Revision: D29609763 Pulled By: mdouze fbshipit-source-id: b6ffa2a3c02bf696a4e52348132affa0dd838870pull/2055/head
parent
8a2860c1dc
commit
eba1cb1a90
|
@ -9,6 +9,8 @@ We try to indicate most contributions here with the contributor names who are no
|
|||
the Facebook Faiss team. Feel free to add entries here if you submit a PR.
|
||||
|
||||
## [Unreleased]
|
||||
### Added
|
||||
- Support LSQ on GPU (by @KinglittleQ)
|
||||
|
||||
## [1.7.1] - 2021-05-27
|
||||
### Added
|
||||
|
|
|
@ -82,6 +82,13 @@ nt, d = xt.shape
|
|||
|
||||
# fastest to slowest
|
||||
|
||||
if 'lsq-gpu' in todo:
|
||||
lsq = faiss.LocalSearchQuantizer(d, M, nbits)
|
||||
ngpus = faiss.get_num_gpus()
|
||||
lsq.icm_encoder_factory = faiss.GpuIcmEncoderFactory(ngpus)
|
||||
lsq.verbose = True
|
||||
eval_quantizer(lsq, xb, xt, 'lsq-gpu')
|
||||
|
||||
if 'pq' in todo:
|
||||
pq = faiss.ProductQuantizer(d, M, nbits)
|
||||
print("===== PQ")
|
||||
|
|
|
@ -9,6 +9,7 @@ set(FAISS_GPU_SRC
|
|||
GpuCloner.cpp
|
||||
GpuClonerOptions.cpp
|
||||
GpuDistance.cu
|
||||
GpuIcmEncoder.cu
|
||||
GpuIndex.cu
|
||||
GpuIndexBinaryFlat.cu
|
||||
GpuIndexFlat.cu
|
||||
|
@ -46,6 +47,7 @@ set(FAISS_GPU_SRC
|
|||
impl/scan/IVFInterleaved512.cu
|
||||
impl/scan/IVFInterleaved1024.cu
|
||||
impl/scan/IVFInterleaved2048.cu
|
||||
impl/IcmEncoder.cu
|
||||
utils/BlockSelectFloat.cu
|
||||
utils/DeviceUtils.cu
|
||||
utils/StackDeviceMemory.cpp
|
||||
|
@ -80,6 +82,7 @@ set(FAISS_GPU_HEADERS
|
|||
GpuCloner.h
|
||||
GpuClonerOptions.h
|
||||
GpuDistance.h
|
||||
GpuIcmEncoder.h
|
||||
GpuFaissAssert.h
|
||||
GpuIndex.h
|
||||
GpuIndexBinaryFlat.h
|
||||
|
@ -118,6 +121,7 @@ set(FAISS_GPU_HEADERS
|
|||
impl/RemapIndices.h
|
||||
impl/VectorResidual.cuh
|
||||
impl/scan/IVFInterleavedImpl.cuh
|
||||
impl/IcmEncoder.cuh
|
||||
utils/BlockSelectKernel.cuh
|
||||
utils/Comparators.cuh
|
||||
utils/ConversionOperators.cuh
|
||||
|
|
|
@ -0,0 +1,120 @@
|
|||
/**
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <faiss/gpu/GpuIcmEncoder.h>
|
||||
|
||||
#include <faiss/gpu/StandardGpuResources.h>
|
||||
#include <faiss/utils/WorkerThread.h>
|
||||
#include <faiss/gpu/impl/IcmEncoder.cuh>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
namespace faiss {
|
||||
namespace gpu {
|
||||
|
||||
///< A helper structure to support multi-GPU
|
||||
struct IcmEncoderShards {
|
||||
std::vector<std::pair<
|
||||
std::unique_ptr<IcmEncoderImpl>,
|
||||
std::unique_ptr<WorkerThread>>>
|
||||
workers;
|
||||
|
||||
void add(IcmEncoderImpl* encoder) {
|
||||
workers.emplace_back(std::make_pair(
|
||||
std::unique_ptr<IcmEncoderImpl>(encoder),
|
||||
std::unique_ptr<WorkerThread>(new WorkerThread)));
|
||||
}
|
||||
|
||||
IcmEncoderImpl* at(int idx) {
|
||||
return workers[idx].first.get();
|
||||
}
|
||||
|
||||
///< call f(idx, encoder) for each encoder
|
||||
void runOnShards(std::function<void(int, IcmEncoderImpl*)> f) {
|
||||
std::vector<std::future<bool>> v;
|
||||
|
||||
for (int i = 0; i < this->workers.size(); ++i) {
|
||||
auto& p = this->workers[i];
|
||||
auto encoder = p.first.get();
|
||||
v.emplace_back(p.second->add([f, i, encoder]() { f(i, encoder); }));
|
||||
}
|
||||
|
||||
for (int i = 0; i < v.size(); ++i) {
|
||||
auto& fut = v[i];
|
||||
fut.get(); // no exception handle, crash if any thread down
|
||||
}
|
||||
}
|
||||
|
||||
size_t size() {
|
||||
return workers.size();
|
||||
}
|
||||
};
|
||||
|
||||
GpuIcmEncoder::GpuIcmEncoder(
|
||||
const LocalSearchQuantizer* lsq,
|
||||
const std::vector<GpuResourcesProvider*>& provs,
|
||||
const std::vector<int>& devices)
|
||||
: lsq::IcmEncoder(lsq), shards(new IcmEncoderShards()) {
|
||||
// create an IcmEncoderImpl instance for each device.
|
||||
for (size_t i = 0; i < provs.size(); i++) {
|
||||
shards->add(new IcmEncoderImpl(
|
||||
lsq->M, lsq->K, lsq->d, provs[i], devices[i]));
|
||||
}
|
||||
}
|
||||
|
||||
GpuIcmEncoder::~GpuIcmEncoder() {}
|
||||
|
||||
void GpuIcmEncoder::set_binary_term() {
|
||||
auto fn = [=](int idx, IcmEncoderImpl* encoder) {
|
||||
encoder->setBinaryTerm(lsq->codebooks.data());
|
||||
};
|
||||
shards->runOnShards(fn);
|
||||
}
|
||||
|
||||
void GpuIcmEncoder::encode(
|
||||
int32_t* codes,
|
||||
const float* x,
|
||||
std::mt19937& gen,
|
||||
size_t n,
|
||||
size_t ils_iters) const {
|
||||
size_t nshards = shards->size();
|
||||
size_t shard_size = (n + nshards - 1) / nshards;
|
||||
|
||||
auto codebooks = lsq->codebooks.data();
|
||||
auto M = lsq->M;
|
||||
auto d = lsq->d;
|
||||
auto nperts = lsq->nperts;
|
||||
auto icm_iters = lsq->icm_iters;
|
||||
|
||||
auto seed = gen();
|
||||
|
||||
// split input data
|
||||
auto fn = [=](int idx, IcmEncoderImpl* encoder) {
|
||||
size_t i0 = idx * shard_size;
|
||||
size_t ni = std::min(shard_size, n - i0);
|
||||
auto xi = x + i0 * d;
|
||||
auto ci = codes + i0 * M;
|
||||
std::mt19937 geni(idx + seed); // different seed for each shard
|
||||
encoder->encode(
|
||||
ci, xi, codebooks, geni, ni, nperts, ils_iters, icm_iters);
|
||||
};
|
||||
shards->runOnShards(fn);
|
||||
}
|
||||
|
||||
GpuIcmEncoderFactory::GpuIcmEncoderFactory(int ngpus) {
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
provs.push_back(new StandardGpuResources());
|
||||
devices.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
lsq::IcmEncoder* GpuIcmEncoderFactory::get(const LocalSearchQuantizer* lsq) {
|
||||
return new GpuIcmEncoder(lsq, provs, devices);
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace faiss
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <faiss/impl/LocalSearchQuantizer.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace faiss {
|
||||
namespace gpu {
|
||||
|
||||
class GpuResourcesProvider;
|
||||
struct IcmEncoderShards;
|
||||
|
||||
/** Perform LSQ encoding on GPU.
|
||||
*
|
||||
* Split input vectors to different devices and call IcmEncoderImpl::encode
|
||||
* to encode them
|
||||
*/
|
||||
class GpuIcmEncoder : public lsq::IcmEncoder {
|
||||
public:
|
||||
GpuIcmEncoder(
|
||||
const LocalSearchQuantizer* lsq,
|
||||
const std::vector<GpuResourcesProvider*>& provs,
|
||||
const std::vector<int>& devices);
|
||||
|
||||
~GpuIcmEncoder();
|
||||
|
||||
GpuIcmEncoder(const GpuIcmEncoder&) = delete;
|
||||
GpuIcmEncoder& operator=(const GpuIcmEncoder&) = delete;
|
||||
|
||||
void set_binary_term() override;
|
||||
|
||||
void encode(
|
||||
int32_t* codes,
|
||||
const float* x,
|
||||
std::mt19937& gen,
|
||||
size_t n,
|
||||
size_t ils_iters) const override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<IcmEncoderShards> shards;
|
||||
};
|
||||
|
||||
struct GpuIcmEncoderFactory : public lsq::IcmEncoderFactory {
|
||||
explicit GpuIcmEncoderFactory(int ngpus = 1);
|
||||
|
||||
lsq::IcmEncoder* get(const LocalSearchQuantizer* lsq) override;
|
||||
|
||||
std::vector<GpuResourcesProvider*> provs;
|
||||
std::vector<int> devices;
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace faiss
|
|
@ -0,0 +1,379 @@
|
|||
/**
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <faiss/gpu/impl/IcmEncoder.cuh>
|
||||
|
||||
#include <faiss/gpu/GpuResources.h>
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
#include <faiss/gpu/impl/L2Norm.cuh>
|
||||
#include <faiss/gpu/utils/CopyUtils.cuh>
|
||||
#include <faiss/gpu/utils/DeviceDefs.cuh>
|
||||
#include <faiss/gpu/utils/DeviceTensor.cuh>
|
||||
#include <faiss/gpu/utils/MatrixMult.cuh>
|
||||
#include <faiss/gpu/utils/Pair.cuh>
|
||||
#include <faiss/gpu/utils/Reductions.cuh>
|
||||
|
||||
#include <curand_kernel.h>
|
||||
|
||||
namespace faiss {
|
||||
namespace gpu {
|
||||
|
||||
extern __shared__ char smem[];
|
||||
|
||||
/** encode using iterative conditional mode
|
||||
*
|
||||
* For subcode cm of a vector, we fix the other subcodes cj (j != m)
|
||||
* and then find the optimal value of cm (cm = 1,...,K) such that
|
||||
* minimizing the objective function.
|
||||
*
|
||||
* @param uterm precomputed unary terms, size (M, n, K)
|
||||
* @param bterm precomputed binary terms, size (M1, M2, K1, K2)
|
||||
* @param codes output vector encodings, size (n, M)
|
||||
* @param M number of codebooks
|
||||
* @param K number of codewords in a codebook
|
||||
* @param m identify which subcode to condition on
|
||||
*/
|
||||
__global__ void runIcmEncodeStep(
|
||||
const float* uterm,
|
||||
const float* bterm,
|
||||
int32_t* codes,
|
||||
int M,
|
||||
int K,
|
||||
int m) {
|
||||
using KVPair = Pair<float, int>;
|
||||
|
||||
int id = blockIdx.x; // each block takes care of one vector
|
||||
int code = threadIdx.x; // each thread takes care of one possible code
|
||||
|
||||
// compute the objective value by look-up tables
|
||||
KVPair obj(0.0f, code);
|
||||
obj.k = uterm[id * K + code];
|
||||
|
||||
#pragma unroll
|
||||
for (int m2 = 0; m2 < M; m2++) {
|
||||
if (m2 == m) {
|
||||
continue;
|
||||
}
|
||||
int32_t code2 = codes[id * M + m2];
|
||||
obj.k += bterm[m2 * K * K + code * K + code2];
|
||||
}
|
||||
|
||||
// find the minimum objective value and the corresponding code
|
||||
__syncthreads();
|
||||
obj = blockReduceAll<KVPair, Min<KVPair>, false, false>(
|
||||
obj, Min<KVPair>(), (KVPair*)smem);
|
||||
|
||||
if (code == 0) {
|
||||
codes[id * M + m] = obj.v;
|
||||
}
|
||||
}
|
||||
|
||||
/** compute reconstruction error for each vector
|
||||
*
|
||||
* decoded_x[i] = \sum codebooks[m][codes[i][m]], m = 1,..,M
|
||||
* obj[i] = ||x[i] - decoded_x[i]||^2
|
||||
*
|
||||
* @param x input vectors, size [n, dims]
|
||||
* @param codebooks codebooks, size [M, K, dims]
|
||||
* @param codes vector codes, size [n, M]
|
||||
* @param obj output reconstruction errors, size [n]
|
||||
* @param n number of input vectors
|
||||
* @param K number of codewords in a codebook
|
||||
* @param M number of codebooks
|
||||
*/
|
||||
__global__ void runEvaluation(
|
||||
const float* x,
|
||||
const float* codebooks,
|
||||
const int32_t* codes,
|
||||
float* obj, // output
|
||||
int n,
|
||||
int M,
|
||||
int K,
|
||||
int dims) {
|
||||
int id = blockIdx.x; // each block takes care of one vector
|
||||
int d = threadIdx.x; // each thread takes care of one dimension
|
||||
float acc = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int m = 0; m < M; m++) {
|
||||
int32_t code = codes[id * M + m];
|
||||
acc += codebooks[m * K * dims + code * dims + d];
|
||||
}
|
||||
|
||||
acc -= x[id * dims + d];
|
||||
acc = acc * acc;
|
||||
|
||||
// sum values of all dimensions together
|
||||
__syncthreads();
|
||||
acc = blockReduceAllSum<float, false, false>(acc, (float*)smem);
|
||||
|
||||
if (d == 0) {
|
||||
obj[id] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
/** perturb vector codes
|
||||
*
|
||||
* repeat nperts times:
|
||||
* codes[i][randint(0, M)] = randint(0, K)
|
||||
*
|
||||
* @param seed random seed
|
||||
* @param codes vector codes, size [n, M]
|
||||
* @param n number of input vectors
|
||||
* @param M number of codebooks
|
||||
* @param K number of codewords in a codebook
|
||||
* @param nperts number of subcode to be perturbed in a vector
|
||||
*/
|
||||
__global__ void runCodesPerturbation(
|
||||
int seed,
|
||||
int32_t* codes,
|
||||
int n,
|
||||
int M,
|
||||
int K,
|
||||
int nperts) {
|
||||
// each thread takes care of one vector
|
||||
int id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (id >= n) {
|
||||
return;
|
||||
}
|
||||
|
||||
// we have to initialize the state
|
||||
curandState_t state;
|
||||
curand_init(seed, id, 0, &state);
|
||||
|
||||
for (int i = 0; i < nperts; i++) {
|
||||
int pos = int(curand_uniform(&state) * M);
|
||||
int32_t val = int32_t(curand_uniform(&state) * K);
|
||||
codes[id * M + pos] = val;
|
||||
}
|
||||
}
|
||||
|
||||
/** select the best codes by reconstruction errors
|
||||
*
|
||||
* if objs[i] < best_objs[i]:
|
||||
* best_objs[i] = objs[i]
|
||||
* best_codes[i] = codes[i]
|
||||
*
|
||||
* @param bestCodes the best codes we've encountered, size [n, M]
|
||||
* @param bestObjs min reconstruction errors we've encountered, size [n]
|
||||
* @param codes input vector codes, size [n, M]
|
||||
* @param objs reconstruction errors of input vector codes, size [n]
|
||||
* @param n number of input vectors
|
||||
*/
|
||||
__global__ void runCodesSelection(
|
||||
int32_t* bestCodes,
|
||||
float* bestObjs,
|
||||
const int32_t* codes,
|
||||
const float* objs,
|
||||
int n,
|
||||
int M) {
|
||||
// each thread takes care of one vector
|
||||
int id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (id >= n || objs[id] >= bestObjs[id]) {
|
||||
return;
|
||||
}
|
||||
|
||||
bestObjs[id] = objs[id];
|
||||
#pragma unroll
|
||||
for (int m = 0; m < M; m++) {
|
||||
bestCodes[id * M + m] = codes[id * M + m];
|
||||
}
|
||||
}
|
||||
|
||||
/** add L2 norm of codewords in a codebook to the unary terms
|
||||
*
|
||||
* uterm[i][k] = norm[k]
|
||||
*
|
||||
* @param uterm unary terms, size [n, K]
|
||||
* @param norm L2 norm of each codeword in a codebook, size [K]
|
||||
* @param K number of codewords in a codebook
|
||||
*/
|
||||
__global__ void runNormAddition(float* uterm, const float* norm, int K) {
|
||||
int id = blockIdx.x;
|
||||
int code = threadIdx.x;
|
||||
|
||||
uterm[id * K + code] += norm[code];
|
||||
}
|
||||
|
||||
IcmEncoderImpl::IcmEncoderImpl(
|
||||
int M,
|
||||
int K,
|
||||
int dims,
|
||||
GpuResourcesProvider* prov,
|
||||
int device)
|
||||
: M(M), K(K), dims(dims), prov(prov), device(device) {
|
||||
res = prov->getResources();
|
||||
}
|
||||
|
||||
void IcmEncoderImpl::computeUnaryTerms(
|
||||
float* uterm, // output, [M, n, K]
|
||||
const float* x, // [n, d]
|
||||
const float* codebooks, // [M, K, d]
|
||||
int n) const {
|
||||
auto stream = res->getDefaultStreamCurrentDevice();
|
||||
auto handle = res->getBlasHandleCurrentDevice();
|
||||
|
||||
DeviceTensor<float, 2, true> vecs(const_cast<float*>(x), {n, dims});
|
||||
for (int m = 0; m < M; m++) {
|
||||
auto cPtr = const_cast<float*>(codebooks + m * K * dims);
|
||||
auto bPtr = uterm + m * n * K;
|
||||
DeviceTensor<float, 2, true> ci(cPtr, {K, dims});
|
||||
DeviceTensor<float, 2, true> bi(bPtr, {n, K});
|
||||
runMatrixMult(
|
||||
bi, false, vecs, false, ci, true, -2.0f, 0.0f, handle, stream);
|
||||
}
|
||||
|
||||
DeviceTensor<float, 2, true> c(
|
||||
const_cast<float*>(codebooks), {M * K, dims});
|
||||
DeviceTensor<float, 1, true> norm(
|
||||
res.get(), makeTempAlloc(AllocType::Other, stream), {M * K});
|
||||
runL2Norm(c, true, norm, true, stream);
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
auto uPtr = uterm + m * n * K;
|
||||
auto nPtr = norm.data() + m * K;
|
||||
runNormAddition<<<n, K, 0, stream>>>(uPtr, nPtr, K);
|
||||
}
|
||||
}
|
||||
|
||||
void IcmEncoderImpl::computeBinaryTerms(float* bterm, const float* codebooks)
|
||||
const {
|
||||
auto stream = res->getDefaultStreamCurrentDevice();
|
||||
auto handle = res->getBlasHandleCurrentDevice();
|
||||
|
||||
for (int m1 = 0; m1 < M; m1++) {
|
||||
for (int m2 = 0; m2 < M; m2++) {
|
||||
auto ptr1 = const_cast<float*>(codebooks + m1 * K * dims);
|
||||
auto ptr2 = const_cast<float*>(codebooks + m2 * K * dims);
|
||||
auto ptr3 = bterm + m1 * M * K * K + m2 * K * K;
|
||||
DeviceTensor<float, 2, true> c1(ptr1, {K, dims});
|
||||
DeviceTensor<float, 2, true> c2(ptr2, {K, dims});
|
||||
DeviceTensor<float, 2, true> b(ptr3, {K, K});
|
||||
runMatrixMult(
|
||||
b, false, c1, false, c2, true, 2.0f, 0.0f, handle, stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void IcmEncoderImpl::setBinaryTerm(const float* codebooksHost) {
|
||||
DeviceScope scope(device);
|
||||
auto device = getCurrentDevice();
|
||||
auto stream = res->getDefaultStreamCurrentDevice();
|
||||
|
||||
// copy from host to device memory
|
||||
codebooks = toDeviceNonTemporary<float, 3>(
|
||||
res.get(),
|
||||
device,
|
||||
const_cast<float*>(codebooksHost),
|
||||
stream,
|
||||
{M, K, dims});
|
||||
bterm = DeviceTensor<float, 4, true>(
|
||||
res.get(), makeDevAlloc(AllocType::Other, stream), {M, M, K, K});
|
||||
computeBinaryTerms(bterm.data(), codebooks.data());
|
||||
}
|
||||
|
||||
void IcmEncoderImpl::encode(
|
||||
int32_t* codesHost,
|
||||
const float* xHost,
|
||||
const float* codebooksHost,
|
||||
std::mt19937& gen,
|
||||
int n,
|
||||
int nperts,
|
||||
int ilsIters,
|
||||
int icmIters) const {
|
||||
DeviceScope scope(device);
|
||||
auto device = getCurrentDevice();
|
||||
auto stream = res->getDefaultStreamCurrentDevice();
|
||||
|
||||
// copy from host to device memory
|
||||
auto codes = toDeviceTemporary<int32_t, 2>(
|
||||
res.get(), device, const_cast<int32_t*>(codesHost), stream, {n, M});
|
||||
auto x = toDeviceTemporary<float, 2>(
|
||||
res.get(), device, const_cast<float*>(xHost), stream, {n, dims});
|
||||
|
||||
// compute unary terms
|
||||
DeviceTensor<float, 3, true> uterm(
|
||||
res.get(), makeTempAlloc(AllocType::Other, stream), {M, n, K});
|
||||
computeUnaryTerms(uterm.data(), x.data(), codebooks.data(), n);
|
||||
|
||||
DeviceTensor<int32_t, 2, true> bestCodes(
|
||||
res.get(), makeTempAlloc(AllocType::Other, stream), {n, M});
|
||||
fromDevice<int32_t, 2>(codes, bestCodes.data(), stream);
|
||||
|
||||
DeviceTensor<float, 1, true> bestObjs(
|
||||
res.get(), makeTempAlloc(AllocType::Other, stream), {n});
|
||||
|
||||
DeviceTensor<float, 1, true> objs(
|
||||
res.get(), makeTempAlloc(AllocType::Other, stream), {n});
|
||||
|
||||
// compute how much shared memory we need
|
||||
const int evaluateSmem = sizeof(float) * (dims + kWarpSize - 1) / kWarpSize;
|
||||
const int encodeSmem =
|
||||
sizeof(Pair<float, int>) * (K + kWarpSize - 1) / kWarpSize;
|
||||
|
||||
// compute the reconstruction error for each vector
|
||||
runEvaluation<<<n, dims, evaluateSmem, stream>>>(
|
||||
x.data(),
|
||||
codebooks.data(),
|
||||
codes.data(),
|
||||
bestObjs.data(),
|
||||
n,
|
||||
M,
|
||||
K,
|
||||
dims);
|
||||
|
||||
int blockSize = 256;
|
||||
int numBlocks = (n + blockSize - 1) / blockSize;
|
||||
|
||||
for (int i = 0; i < ilsIters; i++) {
|
||||
runCodesPerturbation<<<numBlocks, blockSize, 0, stream>>>(
|
||||
gen(), codes.data(), n, M, K, nperts);
|
||||
|
||||
// perform icm encoding
|
||||
for (int j = 0; j < icmIters; j++) {
|
||||
for (int m = 0; m < M; m++) {
|
||||
runIcmEncodeStep<<<n, K, encodeSmem, stream>>>(
|
||||
uterm[m].data(),
|
||||
bterm[m].data(),
|
||||
codes.data(),
|
||||
M,
|
||||
K,
|
||||
m);
|
||||
}
|
||||
}
|
||||
|
||||
// compute the reconstruction error for each vector given codes
|
||||
runEvaluation<<<n, dims, evaluateSmem, stream>>>(
|
||||
x.data(),
|
||||
codebooks.data(),
|
||||
codes.data(),
|
||||
objs.data(),
|
||||
n,
|
||||
M,
|
||||
K,
|
||||
dims);
|
||||
|
||||
// if objs[i] < best_objs[i], replace best_codes[i] with codes[i]
|
||||
runCodesSelection<<<numBlocks, blockSize, 0, stream>>>(
|
||||
bestCodes.data(),
|
||||
bestObjs.data(),
|
||||
codes.data(),
|
||||
objs.data(),
|
||||
n,
|
||||
M);
|
||||
|
||||
codes.copyFrom(bestCodes, stream);
|
||||
}
|
||||
|
||||
// copy back to host memory
|
||||
fromDevice<int32_t, 2>(bestCodes, codesHost, stream);
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace faiss
|
|
@ -0,0 +1,79 @@
|
|||
/**
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <faiss/gpu/GpuResources.h>
|
||||
#include <faiss/gpu/utils/DeviceTensor.cuh>
|
||||
|
||||
#include <random>
|
||||
|
||||
namespace faiss {
|
||||
namespace gpu {
|
||||
|
||||
struct IcmEncoderImpl {
|
||||
int M; ///< number of codebooks
|
||||
int K; ///< number of codewords in a codebook
|
||||
int dims; ///< dimensions of a codeword
|
||||
|
||||
GpuResourcesProvider* prov;
|
||||
std::shared_ptr<GpuResources> res;
|
||||
int device;
|
||||
|
||||
DeviceTensor<float, 4, true> bterm; ///< bianry terms, size [M, M, K, K]
|
||||
DeviceTensor<float, 3, true> codebooks; ///< codebooks, size [M, K, dims]
|
||||
|
||||
IcmEncoderImpl(
|
||||
int M,
|
||||
int K,
|
||||
int dims,
|
||||
GpuResourcesProvider* prov,
|
||||
int device);
|
||||
|
||||
~IcmEncoderImpl() {}
|
||||
|
||||
///< copy codebooks to device memory and compute unary terms
|
||||
void setBinaryTerm(const float* codebooks);
|
||||
|
||||
/** Compute unary terms.
|
||||
*
|
||||
* uterm[i] = x * codebook[i]^T, i = 1,...,M
|
||||
*
|
||||
* @param uterm output unary terms, size [M, n, K]
|
||||
* @param x input vectors, size [n, dims]
|
||||
* @param codebooks codebooks, size [M, K, dims]
|
||||
* @param n number of input vectors
|
||||
*/
|
||||
void computeUnaryTerms(
|
||||
float* bterm,
|
||||
const float* x,
|
||||
const float* codebooks,
|
||||
int n) const;
|
||||
|
||||
/** Compute binary terms.
|
||||
*
|
||||
* bterm[i][j] = codebooks[i] * codebooks[j]^T. i, j = 1,...,M
|
||||
*
|
||||
* @param bterm output binary terms, size [M, M, K, K]
|
||||
* @param codebooks codebooks, size [M, K, dims]
|
||||
*/
|
||||
void computeBinaryTerms(float* bterm, const float* codebooks) const;
|
||||
|
||||
///< icm encode method
|
||||
void encode(
|
||||
int32_t* codes,
|
||||
const float* x,
|
||||
const float* codebooks,
|
||||
std::mt19937& gen,
|
||||
int n,
|
||||
int nperts,
|
||||
int ilsIters,
|
||||
int icmIters) const;
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace faiss
|
|
@ -10,6 +10,7 @@ import time
|
|||
import unittest
|
||||
import numpy as np
|
||||
import faiss
|
||||
from faiss.contrib import datasets
|
||||
|
||||
class EvalIVFPQAccuracy(unittest.TestCase):
|
||||
|
||||
|
@ -462,5 +463,44 @@ class TestInvalidParams(unittest.TestCase):
|
|||
self.assertTrue(np.array_equal(xb_indices[10:20], I[:, 0]))
|
||||
|
||||
|
||||
class TestLSQIcmEncoder(unittest.TestCase):
|
||||
|
||||
@staticmethod
|
||||
def eval_codec(q, xb):
|
||||
codes = q.compute_codes(xb)
|
||||
decoded = q.decode(codes)
|
||||
return ((xb - decoded) ** 2).sum()
|
||||
|
||||
def subtest_gpu_encoding(self, ngpus):
|
||||
"""check that the error is in the same as cpu."""
|
||||
ds = datasets.SyntheticDataset(32, 1000, 1000, 0)
|
||||
|
||||
xt = ds.get_train()
|
||||
xb = ds.get_database()
|
||||
|
||||
M = 4
|
||||
nbits = 8
|
||||
|
||||
lsq = faiss.LocalSearchQuantizer(ds.d, M, nbits)
|
||||
lsq.train(xt)
|
||||
err_cpu = self.eval_codec(lsq, xb)
|
||||
|
||||
lsq = faiss.LocalSearchQuantizer(ds.d, M, nbits)
|
||||
lsq.train(xt)
|
||||
lsq.icm_encoder_factory = faiss.GpuIcmEncoderFactory(ngpus)
|
||||
err_gpu = self.eval_codec(lsq, xb)
|
||||
|
||||
# 13804.411 vs 13814.794, 1 gpu
|
||||
print(err_gpu, err_cpu)
|
||||
self.assertLess(err_gpu, err_cpu * 1.05)
|
||||
|
||||
def test_one_gpu(self):
|
||||
self.subtest_gpu_encoding(1)
|
||||
|
||||
def test_multiple_gpu(self):
|
||||
ngpu = faiss.get_num_gpus()
|
||||
self.subtest_gpu_encoding(ngpu)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -141,7 +141,8 @@ void random_int32(
|
|||
|
||||
namespace faiss {
|
||||
|
||||
LSQTimer lsq_timer;
|
||||
lsq::LSQTimer lsq_timer;
|
||||
using lsq::LSQTimerScope;
|
||||
|
||||
LocalSearchQuantizer::LocalSearchQuantizer(
|
||||
size_t d,
|
||||
|
@ -168,6 +169,12 @@ LocalSearchQuantizer::LocalSearchQuantizer(
|
|||
|
||||
random_seed = 0x12345;
|
||||
std::srand(random_seed);
|
||||
|
||||
icm_encoder_factory = nullptr;
|
||||
}
|
||||
|
||||
LocalSearchQuantizer::~LocalSearchQuantizer() {
|
||||
delete icm_encoder_factory;
|
||||
}
|
||||
|
||||
LocalSearchQuantizer::LocalSearchQuantizer() : LocalSearchQuantizer(0, 0, 0) {}
|
||||
|
@ -177,8 +184,8 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
|
|||
FAISS_THROW_IF_NOT(nperts <= M);
|
||||
|
||||
lsq_timer.reset();
|
||||
LSQTimerScope scope(&lsq_timer, "train");
|
||||
if (verbose) {
|
||||
lsq_timer.start("train");
|
||||
printf("Training LSQ, with %zd subcodes on %zd %zdD vectors\n",
|
||||
M,
|
||||
n,
|
||||
|
@ -241,7 +248,7 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
|
|||
}
|
||||
|
||||
// refine codes
|
||||
icm_encode(x, codes.data(), n, train_ils_iters, gen);
|
||||
icm_encode(codes.data(), x, n, train_ils_iters, gen);
|
||||
|
||||
if (verbose) {
|
||||
float obj = evaluate(codes.data(), x, n);
|
||||
|
@ -269,13 +276,13 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
|
|||
}
|
||||
|
||||
if (verbose) {
|
||||
lsq_timer.end("train");
|
||||
float obj = evaluate(codes.data(), x, n);
|
||||
scope.finish();
|
||||
printf("After training: obj = %lf\n", obj);
|
||||
|
||||
printf("Time statistic:\n");
|
||||
for (const auto& it : lsq_timer.duration) {
|
||||
printf("\t%s time: %lf s\n", it.first.data(), it.second);
|
||||
for (const auto& it : lsq_timer.t) {
|
||||
printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -284,7 +291,7 @@ void LocalSearchQuantizer::perturb_codebooks(
|
|||
float T,
|
||||
const std::vector<float>& stddev,
|
||||
std::mt19937& gen) {
|
||||
lsq_timer.start("perturb_codebooks");
|
||||
LSQTimerScope scope(&lsq_timer, "perturb_codebooks");
|
||||
|
||||
std::vector<std::normal_distribution<float>> distribs;
|
||||
for (size_t i = 0; i < d; i++) {
|
||||
|
@ -298,8 +305,6 @@ void LocalSearchQuantizer::perturb_codebooks(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
lsq_timer.end("perturb_codebooks");
|
||||
}
|
||||
|
||||
void LocalSearchQuantizer::compute_codes(
|
||||
|
@ -307,23 +312,26 @@ void LocalSearchQuantizer::compute_codes(
|
|||
uint8_t* codes_out,
|
||||
size_t n) const {
|
||||
FAISS_THROW_IF_NOT_MSG(is_trained, "LSQ is not trained yet.");
|
||||
|
||||
lsq_timer.reset();
|
||||
LSQTimerScope scope(&lsq_timer, "encode");
|
||||
if (verbose) {
|
||||
lsq_timer.reset();
|
||||
printf("Encoding %zd vectors...\n", n);
|
||||
lsq_timer.start("encode");
|
||||
}
|
||||
|
||||
std::vector<int32_t> codes(n * M);
|
||||
std::mt19937 gen(random_seed);
|
||||
random_int32(codes, 0, K - 1, gen);
|
||||
|
||||
icm_encode(x, codes.data(), n, encode_ils_iters, gen);
|
||||
icm_encode(codes.data(), x, n, encode_ils_iters, gen);
|
||||
pack_codes(n, codes.data(), codes_out);
|
||||
|
||||
if (verbose) {
|
||||
lsq_timer.end("encode");
|
||||
double t = lsq_timer.get("encode");
|
||||
printf("Time to encode %zd vectors: %lf s\n", n, t);
|
||||
scope.finish();
|
||||
printf("Time statistic:\n");
|
||||
for (const auto& it : lsq_timer.t) {
|
||||
printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -347,7 +355,7 @@ void LocalSearchQuantizer::update_codebooks(
|
|||
const float* x,
|
||||
const int32_t* codes,
|
||||
size_t n) {
|
||||
lsq_timer.start("update_codebooks");
|
||||
LSQTimerScope scope(&lsq_timer, "update_codebooks");
|
||||
|
||||
if (!update_codebooks_with_double) {
|
||||
// allocate memory
|
||||
|
@ -485,8 +493,6 @@ void LocalSearchQuantizer::update_codebooks(
|
|||
codebooks[i] = (float)d_codebooks[i];
|
||||
}
|
||||
}
|
||||
|
||||
lsq_timer.end("update_codebooks");
|
||||
}
|
||||
|
||||
/** encode using iterative conditional mode
|
||||
|
@ -508,15 +514,23 @@ void LocalSearchQuantizer::update_codebooks(
|
|||
* These two terms can be precomputed and store in a look up table.
|
||||
*/
|
||||
void LocalSearchQuantizer::icm_encode(
|
||||
const float* x,
|
||||
int32_t* codes,
|
||||
const float* x,
|
||||
size_t n,
|
||||
size_t ils_iters,
|
||||
std::mt19937& gen) const {
|
||||
lsq_timer.start("icm_encode");
|
||||
LSQTimerScope scope(&lsq_timer, "icm_encode");
|
||||
|
||||
std::vector<float> binaries(M * M * K * K); // [M, M, K, K]
|
||||
compute_binary_terms(binaries.data());
|
||||
auto factory = icm_encoder_factory;
|
||||
std::unique_ptr<lsq::IcmEncoder> icm_encoder;
|
||||
if (factory == nullptr) {
|
||||
icm_encoder.reset(lsq::IcmEncoderFactory().get(this));
|
||||
} else {
|
||||
icm_encoder.reset(factory->get(this));
|
||||
}
|
||||
|
||||
// precompute binary terms for all chunks
|
||||
icm_encoder->set_binary_term();
|
||||
|
||||
const size_t n_chunks = (n + chunk_size - 1) / chunk_size;
|
||||
for (size_t i = 0; i < n_chunks; i++) {
|
||||
|
@ -532,23 +546,20 @@ void LocalSearchQuantizer::icm_encode(
|
|||
|
||||
const float* xi = x + i * chunk_size * d;
|
||||
int32_t* codesi = codes + i * chunk_size * M;
|
||||
icm_encode_partial(i, xi, codesi, ni, binaries.data(), ils_iters, gen);
|
||||
|
||||
InterruptCallback::check();
|
||||
icm_encoder->verbose = (verbose && i == 0);
|
||||
icm_encoder->encode(codesi, xi, gen, ni, ils_iters);
|
||||
}
|
||||
|
||||
lsq_timer.end("icm_encode");
|
||||
}
|
||||
|
||||
void LocalSearchQuantizer::icm_encode_partial(
|
||||
size_t index,
|
||||
const float* x,
|
||||
void LocalSearchQuantizer::icm_encode_impl(
|
||||
int32_t* codes,
|
||||
size_t n,
|
||||
const float* x,
|
||||
const float* binaries,
|
||||
std::mt19937& gen,
|
||||
size_t n,
|
||||
size_t ils_iters,
|
||||
std::mt19937& gen) const {
|
||||
std::vector<float> unaries(n * M * K); // [n, M, K]
|
||||
bool verbose) const {
|
||||
std::vector<float> unaries(n * M * K); // [M, n, K]
|
||||
compute_unary_terms(x, unaries.data(), n);
|
||||
|
||||
std::vector<int32_t> best_codes;
|
||||
|
@ -562,9 +573,7 @@ void LocalSearchQuantizer::icm_encode_partial(
|
|||
// add perturbation to codes
|
||||
perturb_codes(codes, n, gen);
|
||||
|
||||
for (size_t iter2 = 0; iter2 < icm_iters; iter2++) {
|
||||
icm_encode_step(unaries.data(), binaries, codes, n);
|
||||
}
|
||||
icm_encode_step(codes, unaries.data(), binaries, n, icm_iters);
|
||||
|
||||
std::vector<float> icm_objs(n, 0.0f);
|
||||
evaluate(codes, x, n, icm_objs.data());
|
||||
|
@ -587,7 +596,7 @@ void LocalSearchQuantizer::icm_encode_partial(
|
|||
|
||||
memcpy(codes, best_codes.data(), sizeof(int32_t) * n * M);
|
||||
|
||||
if (verbose && index == 0) {
|
||||
if (verbose) {
|
||||
printf("\tils_iter %zd: obj = %lf, n_betters/n = %zd/%zd\n",
|
||||
iter1,
|
||||
mean_obj,
|
||||
|
@ -598,61 +607,67 @@ void LocalSearchQuantizer::icm_encode_partial(
|
|||
}
|
||||
|
||||
void LocalSearchQuantizer::icm_encode_step(
|
||||
int32_t* codes,
|
||||
const float* unaries,
|
||||
const float* binaries,
|
||||
int32_t* codes,
|
||||
size_t n) const {
|
||||
// condition on the m-th subcode
|
||||
for (size_t m = 0; m < M; m++) {
|
||||
std::vector<float> objs(n * K);
|
||||
#pragma omp parallel for
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
auto u = unaries + i * (M * K) + m * K;
|
||||
memcpy(objs.data() + i * K, u, sizeof(float) * K);
|
||||
}
|
||||
|
||||
// compute objective function by adding unary
|
||||
// and binary terms together
|
||||
for (size_t other_m = 0; other_m < M; other_m++) {
|
||||
if (other_m == m) {
|
||||
continue;
|
||||
}
|
||||
size_t n,
|
||||
size_t n_iters) const {
|
||||
FAISS_THROW_IF_NOT(M != 0 && K != 0);
|
||||
FAISS_THROW_IF_NOT(binaries != nullptr);
|
||||
|
||||
for (size_t iter = 0; iter < n_iters; iter++) {
|
||||
// condition on the m-th subcode
|
||||
for (size_t m = 0; m < M; m++) {
|
||||
std::vector<float> objs(n * K);
|
||||
#pragma omp parallel for
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
for (int32_t code = 0; code < K; code++) {
|
||||
int32_t code2 = codes[i * M + other_m];
|
||||
size_t binary_idx =
|
||||
m * M * K * K + other_m * K * K + code * K + code2;
|
||||
// binaries[m, other_m, code, code2]
|
||||
objs[i * K + code] += binaries[binary_idx];
|
||||
}
|
||||
auto u = unaries + m * n * K + i * K;
|
||||
memcpy(objs.data() + i * K, u, sizeof(float) * K);
|
||||
}
|
||||
}
|
||||
|
||||
// find the optimal value of the m-th subcode
|
||||
// compute objective function by adding unary
|
||||
// and binary terms together
|
||||
for (size_t other_m = 0; other_m < M; other_m++) {
|
||||
if (other_m == m) {
|
||||
continue;
|
||||
}
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
float best_obj = HUGE_VALF;
|
||||
int32_t best_code = 0;
|
||||
for (size_t code = 0; code < K; code++) {
|
||||
float obj = objs[i * K + code];
|
||||
if (obj < best_obj) {
|
||||
best_obj = obj;
|
||||
best_code = code;
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
for (int32_t code = 0; code < K; code++) {
|
||||
int32_t code2 = codes[i * M + other_m];
|
||||
size_t binary_idx = m * M * K * K + other_m * K * K +
|
||||
code * K + code2;
|
||||
// binaries[m, other_m, code, code2]
|
||||
objs[i * K + code] += binaries[binary_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
codes[i * M + m] = best_code;
|
||||
}
|
||||
|
||||
} // loop M
|
||||
// find the optimal value of the m-th subcode
|
||||
#pragma omp parallel for
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
float best_obj = HUGE_VALF;
|
||||
int32_t best_code = 0;
|
||||
for (size_t code = 0; code < K; code++) {
|
||||
float obj = objs[i * K + code];
|
||||
if (obj < best_obj) {
|
||||
best_obj = obj;
|
||||
best_code = code;
|
||||
}
|
||||
}
|
||||
codes[i * M + m] = best_code;
|
||||
}
|
||||
|
||||
} // loop M
|
||||
}
|
||||
}
|
||||
|
||||
void LocalSearchQuantizer::perturb_codes(
|
||||
int32_t* codes,
|
||||
size_t n,
|
||||
std::mt19937& gen) const {
|
||||
lsq_timer.start("perturb_codes");
|
||||
LSQTimerScope scope(&lsq_timer, "perturb_codes");
|
||||
|
||||
std::uniform_int_distribution<size_t> m_distrib(0, M - 1);
|
||||
std::uniform_int_distribution<int32_t> k_distrib(0, K - 1);
|
||||
|
@ -663,12 +678,10 @@ void LocalSearchQuantizer::perturb_codes(
|
|||
codes[i * M + m] = k_distrib(gen);
|
||||
}
|
||||
}
|
||||
|
||||
lsq_timer.end("perturb_codes");
|
||||
}
|
||||
|
||||
void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
|
||||
lsq_timer.start("compute_binary_terms");
|
||||
LSQTimerScope scope(&lsq_timer, "compute_binary_terms");
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int64_t m12 = 0; m12 < M * M; m12++) {
|
||||
|
@ -686,52 +699,53 @@ void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
lsq_timer.end("compute_binary_terms");
|
||||
}
|
||||
|
||||
void LocalSearchQuantizer::compute_unary_terms(
|
||||
const float* x,
|
||||
float* unaries,
|
||||
float* unaries, // [M, n, K]
|
||||
size_t n) const {
|
||||
lsq_timer.start("compute_unary_terms");
|
||||
LSQTimerScope scope(&lsq_timer, "compute_unary_terms");
|
||||
|
||||
// compute x * codebooks^T
|
||||
// compute x * codebook^T for each codebook
|
||||
//
|
||||
// NOTE: LAPACK use column major order
|
||||
// out = alpha * op(A) * op(B) + beta * C
|
||||
FINTEGER nrows_A = M * K;
|
||||
FINTEGER ncols_A = d;
|
||||
|
||||
FINTEGER nrows_B = d;
|
||||
FINTEGER ncols_B = n;
|
||||
for (size_t m = 0; m < M; m++) {
|
||||
FINTEGER nrows_A = K;
|
||||
FINTEGER ncols_A = d;
|
||||
|
||||
float alpha = -2.0f;
|
||||
float beta = 0.0f;
|
||||
sgemm_("Transposed",
|
||||
"Not Transposed",
|
||||
&nrows_A, // nrows of op(A)
|
||||
&ncols_B, // ncols of op(B)
|
||||
&ncols_A, // ncols of op(A)
|
||||
&alpha,
|
||||
codebooks.data(),
|
||||
&ncols_A, // nrows of A
|
||||
x,
|
||||
&nrows_B, // nrows of B
|
||||
&beta,
|
||||
unaries,
|
||||
&nrows_A); // nrows of output
|
||||
FINTEGER nrows_B = d;
|
||||
FINTEGER ncols_B = n;
|
||||
|
||||
float alpha = -2.0f;
|
||||
float beta = 0.0f;
|
||||
sgemm_("Transposed",
|
||||
"Not Transposed",
|
||||
&nrows_A, // nrows of op(A)
|
||||
&ncols_B, // ncols of op(B)
|
||||
&ncols_A, // ncols of op(A)
|
||||
&alpha,
|
||||
codebooks.data() + m * K * d,
|
||||
&ncols_A, // nrows of A
|
||||
x,
|
||||
&nrows_B, // nrows of B
|
||||
&beta,
|
||||
unaries + m * n * K,
|
||||
&nrows_A); // nrows of output
|
||||
}
|
||||
|
||||
std::vector<float> norms(M * K);
|
||||
fvec_norms_L2sqr(norms.data(), codebooks.data(), d, M * K);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
float* u = unaries + i * (M * K);
|
||||
fvec_add(M * K, u, norms.data(), u);
|
||||
for (size_t m = 0; m < M; m++) {
|
||||
float* u = unaries + m * n * K + i * K;
|
||||
fvec_add(K, u, norms.data() + m * K, u);
|
||||
}
|
||||
}
|
||||
|
||||
lsq_timer.end("compute_unary_terms");
|
||||
}
|
||||
|
||||
float LocalSearchQuantizer::evaluate(
|
||||
|
@ -739,7 +753,7 @@ float LocalSearchQuantizer::evaluate(
|
|||
const float* x,
|
||||
size_t n,
|
||||
float* objs) const {
|
||||
lsq_timer.start("evaluate");
|
||||
LSQTimerScope scope(&lsq_timer, "evaluate");
|
||||
|
||||
// decode
|
||||
std::vector<float> decoded_x(n * d, 0.0f);
|
||||
|
@ -755,7 +769,7 @@ float LocalSearchQuantizer::evaluate(
|
|||
fvec_add(d, decoded_i, c, decoded_i);
|
||||
}
|
||||
|
||||
float err = fvec_L2sqr(x + i * d, decoded_i, d);
|
||||
float err = faiss::fvec_L2sqr(x + i * d, decoded_i, d);
|
||||
obj += err;
|
||||
|
||||
if (objs) {
|
||||
|
@ -763,34 +777,68 @@ float LocalSearchQuantizer::evaluate(
|
|||
}
|
||||
}
|
||||
|
||||
lsq_timer.end("evaluate");
|
||||
|
||||
obj = obj / n;
|
||||
return obj;
|
||||
}
|
||||
|
||||
namespace lsq {
|
||||
|
||||
IcmEncoder::IcmEncoder(const LocalSearchQuantizer* lsq)
|
||||
: verbose(false), lsq(lsq) {}
|
||||
|
||||
void IcmEncoder::set_binary_term() {
|
||||
auto M = lsq->M;
|
||||
auto K = lsq->K;
|
||||
binaries.resize(M * M * K * K);
|
||||
lsq->compute_binary_terms(binaries.data());
|
||||
}
|
||||
|
||||
void IcmEncoder::encode(
|
||||
int32_t* codes,
|
||||
const float* x,
|
||||
std::mt19937& gen,
|
||||
size_t n,
|
||||
size_t ils_iters) const {
|
||||
lsq->icm_encode_impl(codes, x, binaries.data(), gen, n, ils_iters, verbose);
|
||||
}
|
||||
|
||||
double LSQTimer::get(const std::string& name) {
|
||||
return duration[name];
|
||||
if (t.count(name) == 0) {
|
||||
return 0.0;
|
||||
} else {
|
||||
return t[name];
|
||||
}
|
||||
}
|
||||
|
||||
void LSQTimer::start(const std::string& name) {
|
||||
FAISS_THROW_IF_NOT_MSG(!started[name], " timer is already running");
|
||||
started[name] = true;
|
||||
t0[name] = getmillisecs();
|
||||
}
|
||||
|
||||
void LSQTimer::end(const std::string& name) {
|
||||
FAISS_THROW_IF_NOT_MSG(started[name], " timer is not running");
|
||||
double t1 = getmillisecs();
|
||||
double sec = (t1 - t0[name]) / 1000;
|
||||
duration[name] += sec;
|
||||
started[name] = false;
|
||||
void LSQTimer::add(const std::string& name, double delta) {
|
||||
if (t.count(name) == 0) {
|
||||
t[name] = delta;
|
||||
} else {
|
||||
t[name] += delta;
|
||||
}
|
||||
}
|
||||
|
||||
void LSQTimer::reset() {
|
||||
duration.clear();
|
||||
t0.clear();
|
||||
started.clear();
|
||||
t.clear();
|
||||
}
|
||||
|
||||
LSQTimerScope::LSQTimerScope(LSQTimer* timer, std::string name)
|
||||
: timer(timer), name(name), finished(false) {
|
||||
t0 = getmillisecs();
|
||||
}
|
||||
|
||||
void LSQTimerScope::finish() {
|
||||
if (!finished) {
|
||||
auto delta = getmillisecs() - t0;
|
||||
timer->add(name, delta);
|
||||
finished = true;
|
||||
}
|
||||
}
|
||||
|
||||
LSQTimerScope::~LSQTimerScope() {
|
||||
finish();
|
||||
}
|
||||
|
||||
} // namespace lsq
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -20,6 +20,12 @@
|
|||
|
||||
namespace faiss {
|
||||
|
||||
namespace lsq {
|
||||
|
||||
struct IcmEncoderFactory;
|
||||
|
||||
} // namespace lsq
|
||||
|
||||
/** Implementation of LSQ/LSQ++ described in the following two papers:
|
||||
*
|
||||
* Revisiting additive quantization
|
||||
|
@ -36,7 +42,6 @@ namespace faiss {
|
|||
* The trained codes are stored in `codebooks` which is called
|
||||
* `centroids` in PQ and RQ.
|
||||
*/
|
||||
|
||||
struct LocalSearchQuantizer : AdditiveQuantizer {
|
||||
size_t K; ///< number of codes per codebook
|
||||
|
||||
|
@ -54,6 +59,9 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
|
|||
int random_seed; ///< seed for random generator
|
||||
size_t nperts; ///< number of perturbation in each code
|
||||
|
||||
///< if non-NULL, use this encoder to encode
|
||||
lsq::IcmEncoderFactory* icm_encoder_factory;
|
||||
|
||||
bool update_codebooks_with_double = true;
|
||||
|
||||
LocalSearchQuantizer(
|
||||
|
@ -66,6 +74,8 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
|
|||
|
||||
LocalSearchQuantizer();
|
||||
|
||||
~LocalSearchQuantizer() override;
|
||||
|
||||
// Train the local search quantizer
|
||||
void train(size_t n, const float* x) override;
|
||||
|
||||
|
@ -73,6 +83,7 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
|
|||
*
|
||||
* @param x vectors to encode, size n * d
|
||||
* @param codes output codes, size n * code_size
|
||||
* @param n number of vectors
|
||||
*/
|
||||
void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
|
||||
|
||||
|
@ -80,36 +91,46 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
|
|||
*
|
||||
* @param x training vectors, size n * d
|
||||
* @param codes encoded training vectors, size n * M
|
||||
* @param n number of vectors
|
||||
*/
|
||||
void update_codebooks(const float* x, const int32_t* codes, size_t n);
|
||||
|
||||
/** Encode vectors given codebooks using iterative conditional mode (icm).
|
||||
*
|
||||
* @param x vectors to encode, size n * d
|
||||
* @param codes output codes, size n * M
|
||||
* @param codes output codes, size n * M
|
||||
* @param x vectors to encode, size n * d
|
||||
* @param n number of vectors
|
||||
* @param ils_iters number of iterations of iterative local search
|
||||
*/
|
||||
void icm_encode(
|
||||
const float* x,
|
||||
int32_t* codes,
|
||||
const float* x,
|
||||
size_t n,
|
||||
size_t ils_iters,
|
||||
std::mt19937& gen) const;
|
||||
|
||||
void icm_encode_partial(
|
||||
size_t index,
|
||||
const float* x,
|
||||
void icm_encode_impl(
|
||||
int32_t* codes,
|
||||
const float* x,
|
||||
const float* unaries,
|
||||
std::mt19937& gen,
|
||||
size_t n,
|
||||
const float* binaries,
|
||||
size_t ils_iters,
|
||||
std::mt19937& gen) const;
|
||||
bool verbose) const;
|
||||
|
||||
void icm_encode_step(
|
||||
int32_t* codes,
|
||||
const float* unaries,
|
||||
const float* binaries,
|
||||
int32_t* codes,
|
||||
size_t n) const;
|
||||
size_t n,
|
||||
size_t n_iters) const;
|
||||
|
||||
/** Add some perturbation to codes
|
||||
*
|
||||
* @param codes codes to be perturbed, size n * M
|
||||
* @param n number of vectors
|
||||
*/
|
||||
void perturb_codes(int32_t* codes, size_t n, std::mt19937& gen) const;
|
||||
|
||||
/** Add some perturbation to codebooks
|
||||
*
|
||||
|
@ -121,12 +142,6 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
|
|||
const std::vector<float>& stddev,
|
||||
std::mt19937& gen);
|
||||
|
||||
/** Add some perturbation to codes
|
||||
*
|
||||
* @param codes codes to be perturbed, size n * M
|
||||
*/
|
||||
void perturb_codes(int32_t* codes, size_t n, std::mt19937& gen) const;
|
||||
|
||||
/** Compute binary terms
|
||||
*
|
||||
* @param binaries binary terms, size M * M * K * K
|
||||
|
@ -135,6 +150,7 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
|
|||
|
||||
/** Compute unary terms
|
||||
*
|
||||
* @param n number of vectors
|
||||
* @param x vectors to encode, size n * d
|
||||
* @param unaries unary terms, size n * M * K
|
||||
*/
|
||||
|
@ -142,8 +158,9 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
|
|||
|
||||
/** Helper function to compute reconstruction error
|
||||
*
|
||||
* @param x vectors to encode, size n * d
|
||||
* @param codes encoded codes, size n * M
|
||||
* @param x vectors to encode, size n * d
|
||||
* @param n number of vectors
|
||||
* @param objs if it is not null, store reconstruction
|
||||
error of each vector into it, size n
|
||||
*/
|
||||
|
@ -154,13 +171,50 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
|
|||
float* objs = nullptr) const;
|
||||
};
|
||||
|
||||
namespace lsq {
|
||||
|
||||
struct IcmEncoder {
|
||||
std::vector<float> binaries;
|
||||
|
||||
bool verbose;
|
||||
|
||||
const LocalSearchQuantizer* lsq;
|
||||
|
||||
explicit IcmEncoder(const LocalSearchQuantizer* lsq);
|
||||
|
||||
virtual ~IcmEncoder() {}
|
||||
|
||||
///< compute binary terms
|
||||
virtual void set_binary_term();
|
||||
|
||||
/** Encode vectors given codebooks
|
||||
*
|
||||
* @param codes output codes, size n * M
|
||||
* @param x vectors to encode, size n * d
|
||||
* @param gen random generator
|
||||
* @param n number of vectors
|
||||
* @param ils_iters number of iterations of iterative local search
|
||||
*/
|
||||
virtual void encode(
|
||||
int32_t* codes,
|
||||
const float* x,
|
||||
std::mt19937& gen,
|
||||
size_t n,
|
||||
size_t ils_iters) const;
|
||||
};
|
||||
|
||||
struct IcmEncoderFactory {
|
||||
virtual IcmEncoder* get(const LocalSearchQuantizer* lsq) {
|
||||
return new IcmEncoder(lsq);
|
||||
}
|
||||
virtual ~IcmEncoderFactory() {}
|
||||
};
|
||||
|
||||
/** A helper struct to count consuming time during training.
|
||||
* It is NOT thread-safe.
|
||||
*/
|
||||
struct LSQTimer {
|
||||
std::unordered_map<std::string, double> duration;
|
||||
std::unordered_map<std::string, double> t0;
|
||||
std::unordered_map<std::string, bool> started;
|
||||
std::unordered_map<std::string, double> t;
|
||||
|
||||
LSQTimer() {
|
||||
reset();
|
||||
|
@ -168,13 +222,24 @@ struct LSQTimer {
|
|||
|
||||
double get(const std::string& name);
|
||||
|
||||
void start(const std::string& name);
|
||||
|
||||
void end(const std::string& name);
|
||||
void add(const std::string& name, double delta);
|
||||
|
||||
void reset();
|
||||
};
|
||||
|
||||
FAISS_API extern LSQTimer lsq_timer; ///< timer to count consuming time
|
||||
struct LSQTimerScope {
|
||||
double t0;
|
||||
LSQTimer* timer;
|
||||
std::string name;
|
||||
bool finished;
|
||||
|
||||
LSQTimerScope(LSQTimer* timer, std::string name);
|
||||
|
||||
void finish();
|
||||
|
||||
~LSQTimerScope();
|
||||
};
|
||||
|
||||
} // namespace lsq
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -287,6 +287,7 @@ void gpu_sync_all_devices();
|
|||
#include <faiss/gpu/GpuAutoTune.h>
|
||||
#include <faiss/gpu/GpuCloner.h>
|
||||
#include <faiss/gpu/GpuDistance.h>
|
||||
#include <faiss/gpu/GpuIcmEncoder.h>
|
||||
|
||||
int get_num_gpus()
|
||||
{
|
||||
|
@ -490,6 +491,7 @@ void gpu_sync_all_devices()
|
|||
%include <faiss/gpu/GpuIndexIVFScalarQuantizer.h>
|
||||
%include <faiss/gpu/GpuIndexBinaryFlat.h>
|
||||
%include <faiss/gpu/GpuDistance.h>
|
||||
%include <faiss/gpu/GpuIcmEncoder.h>
|
||||
|
||||
|
||||
#endif
|
||||
|
|
|
@ -15,6 +15,9 @@ import unittest
|
|||
from faiss.contrib import datasets
|
||||
|
||||
|
||||
sp = faiss.swig_ptr
|
||||
|
||||
|
||||
def construct_sparse_matrix(codes, K):
|
||||
n, M = codes.shape
|
||||
B = np.zeros((n, M * K), dtype=np.float32)
|
||||
|
@ -54,18 +57,18 @@ def compute_binary_terms_ref(codebooks):
|
|||
def compute_unary_terms_ref(codebooks, x):
|
||||
codebooks_t = np.swapaxes(codebooks, 1, 2) # [M, d, K]
|
||||
unaries = -2 * x.dot(codebooks_t) # [n, M, K]
|
||||
|
||||
code_norms = np.sum(codebooks * codebooks, axis=2) # [M, K]
|
||||
unaries += code_norms
|
||||
unaries = np.swapaxes(unaries, 0, 1) # [M, n, K]
|
||||
|
||||
return unaries
|
||||
|
||||
|
||||
def icm_encode_step_ref(unaries, binaries, codes):
|
||||
n, M, K = unaries.shape
|
||||
M, n, K = unaries.shape
|
||||
|
||||
for m in range(M):
|
||||
objs = unaries[:, m].copy() # [n, K]
|
||||
objs = unaries[m].copy() # [n, K]
|
||||
|
||||
for m2 in range(M): # pair, m2 != m
|
||||
if m2 == m:
|
||||
|
@ -131,8 +134,8 @@ class TestComponents(unittest.TestCase):
|
|||
# decode x
|
||||
pack_codes = np.zeros((n, lsq.code_size)).astype(np.uint8)
|
||||
decoded_x = np.zeros((n, d)).astype(np.float32)
|
||||
lsq.pack_codes(n, faiss.swig_ptr(codes), faiss.swig_ptr(pack_codes))
|
||||
lsq.decode_c(faiss.swig_ptr(pack_codes), faiss.swig_ptr(decoded_x), n)
|
||||
lsq.pack_codes(n, sp(codes), sp(pack_codes))
|
||||
lsq.decode_c(sp(pack_codes), sp(decoded_x), n)
|
||||
|
||||
# decode in Python
|
||||
codebooks = faiss.vector_float_to_array(lsq.codebooks)
|
||||
|
@ -163,7 +166,7 @@ class TestComponents(unittest.TestCase):
|
|||
codebooks = faiss.vector_float_to_array(lsq.codebooks)
|
||||
codebooks = codebooks.reshape(M, K, d).copy()
|
||||
|
||||
lsq.update_codebooks(faiss.swig_ptr(x), faiss.swig_ptr(codes), n)
|
||||
lsq.update_codebooks(sp(x), sp(codes), n)
|
||||
new_codebooks = faiss.vector_float_to_array(lsq.codebooks)
|
||||
new_codebooks = new_codebooks.reshape(M, K, d).copy()
|
||||
|
||||
|
@ -209,7 +212,7 @@ class TestComponents(unittest.TestCase):
|
|||
lsq = faiss.LocalSearchQuantizer(d, M, nbits)
|
||||
lsq.train(x) # just for allocating memory for codebooks
|
||||
|
||||
lsq.compute_binary_terms(faiss.swig_ptr(binaries))
|
||||
lsq.compute_binary_terms(sp(binaries))
|
||||
|
||||
codebooks = faiss.vector_float_to_array(lsq.codebooks)
|
||||
codebooks = codebooks.reshape(M, K, d).copy()
|
||||
|
@ -226,12 +229,12 @@ class TestComponents(unittest.TestCase):
|
|||
|
||||
rs = np.random.RandomState(123)
|
||||
x = rs.rand(n, d).astype(np.float32)
|
||||
unaries = np.zeros((n, M, K)).astype(np.float32)
|
||||
unaries = np.zeros((M, n, K)).astype(np.float32)
|
||||
|
||||
lsq = faiss.LocalSearchQuantizer(d, M, nbits)
|
||||
lsq.train(x) # just for allocating memory for codebooks
|
||||
|
||||
lsq.compute_unary_terms(faiss.swig_ptr(x), faiss.swig_ptr(unaries), n)
|
||||
lsq.compute_unary_terms(sp(x), sp(unaries), n)
|
||||
|
||||
codebooks = faiss.vector_float_to_array(lsq.codebooks)
|
||||
codebooks = codebooks.reshape(M, K, d).copy()
|
||||
|
@ -251,15 +254,17 @@ class TestComponents(unittest.TestCase):
|
|||
# randomly generate codes, binary terms and unary terms
|
||||
codes = rs.randint(0, K, (n, M)).astype(np.int32)
|
||||
new_codes = codes.copy()
|
||||
unaries = rs.rand(n, M, K).astype(np.float32)
|
||||
unaries = rs.rand(M, n, K).astype(np.float32)
|
||||
binaries = rs.rand(M, M, K, K).astype(np.float32)
|
||||
|
||||
# do icm encoding given binary and unary terms
|
||||
lsq = faiss.LocalSearchQuantizer(d, M, nbits)
|
||||
lsq.icm_encode_step(
|
||||
faiss.swig_ptr(unaries),
|
||||
faiss.swig_ptr(binaries),
|
||||
faiss.swig_ptr(new_codes), n)
|
||||
sp(new_codes),
|
||||
sp(unaries),
|
||||
sp(binaries),
|
||||
n,
|
||||
1)
|
||||
|
||||
# do icm encoding given binary and unary terms in Python
|
||||
ref_codes = icm_encode_step_ref(unaries, binaries, codes)
|
||||
|
@ -280,11 +285,11 @@ class TestComponents(unittest.TestCase):
|
|||
|
||||
# compute binary terms
|
||||
binaries = np.zeros((M, M, K, K)).astype(np.float32)
|
||||
lsq.compute_binary_terms(faiss.swig_ptr(binaries))
|
||||
lsq.compute_binary_terms(sp(binaries))
|
||||
|
||||
# compute unary terms
|
||||
unaries = np.zeros((n, M, K)).astype(np.float32)
|
||||
lsq.compute_unary_terms(faiss.swig_ptr(x), faiss.swig_ptr(unaries), n)
|
||||
unaries = np.zeros((M, n, K)).astype(np.float32)
|
||||
lsq.compute_unary_terms(sp(x), sp(unaries), n)
|
||||
|
||||
# randomly generate codes
|
||||
codes = rs.randint(0, K, (n, M)).astype(np.int32)
|
||||
|
@ -292,9 +297,11 @@ class TestComponents(unittest.TestCase):
|
|||
|
||||
# do icm encoding given binary and unary terms
|
||||
lsq.icm_encode_step(
|
||||
faiss.swig_ptr(unaries),
|
||||
faiss.swig_ptr(binaries),
|
||||
faiss.swig_ptr(new_codes), n)
|
||||
sp(new_codes),
|
||||
sp(unaries),
|
||||
sp(binaries),
|
||||
n,
|
||||
1)
|
||||
|
||||
# do icm encoding without pre-computed unary and bianry terms in Python
|
||||
codebooks = faiss.vector_float_to_array(lsq.codebooks)
|
||||
|
|
Loading…
Reference in New Issue