Mirror of https://github.com/facebookresearch/faiss.git
Fix polysemous OOM
Summary: Polysemous training can OOM because it uses tables of size n^2, where n is 2**nbit of the PQ. This change throws an exception when the table threatens to become too large. It also reduces the number of threads when doing so makes it possible to fit the computation within max_memory bytes.

Reviewed By: wickedfoo

Differential Revision: D26856747

fbshipit-source-id: bd98e60293494e2f4b2b6d48eb1200efb1ce683c
This commit is contained in:
parent f2464141a7
commit 189aecb224
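
The tables the summary refers to grow quadratically or cubically in n = 2**nbit: the new memory_usage_per_thread() helper in this diff charges three n*n tables of doubles to the affine distance-reproduction objective and one n*n*n table of floats to the weighted-ranking objective. A rough back-of-the-envelope sketch in Python, mirroring that helper (the function name estimate_bytes_per_thread is illustrative only, not part of the faiss API):

def estimate_bytes_per_thread(nbits, optimization_type="affine"):
    # Mirrors PolysemousTraining::memory_usage_per_thread from the diff below.
    n = 2 ** nbits  # centroids per sub-quantizer (pq.ksub)
    if optimization_type == "none":      # OT_None
        return 0
    if optimization_type == "affine":    # OT_ReproduceDistances_affine
        return n * n * 8 * 3             # three n*n tables of double
    if optimization_type == "ranking":   # OT_Ranking_weighted_diff
        return n * n * n * 4             # one n*n*n table of float
    raise ValueError(optimization_type)

print(estimate_bytes_per_thread(8))   # 1572864 (~1.5 MiB): the usual 8-bit PQ is harmless
print(estimate_bytes_per_thread(13))  # 1610612736 (~1.5 GiB) per thread for 13-bit codes

At 13 bits a single thread already exceeds the 128 MiB cap set by the new unit test, and a many-core machine can exceed the default 20 GiB budget, which is what the exception and thread-count reduction below guard against.
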
faiss/impl/PolysemousTraining.cpp

@@ -8,8 +8,11 @@
// -*- c++ -*-

#include <faiss/impl/PolysemousTraining.h>
#include "faiss/impl/FaissAssert.h"

#include <omp.h>
#include <stdint.h>

#include <cmath>
#include <cstdlib>
#include <cstring>

@@ -760,6 +763,8 @@ PolysemousTraining::PolysemousTraining() {
    optimization_type = OT_ReproduceDistances_affine;
    ntrain_permutation = 0;
    dis_weight_factor = log(2);
    // max 20 G RAM
    max_memory = (size_t)(20) * 1024 * 1024 * 1024;
}

void PolysemousTraining::optimize_reproduce_distances(

@@ -769,7 +774,22 @@ void PolysemousTraining::optimize_reproduce_distances(
    int n = pq.ksub;
    int nbits = pq.nbits;

#pragma omp parallel for
    size_t mem1 = memory_usage_per_thread(pq);
    int nt = std::min(omp_get_max_threads(), int(pq.M));
    FAISS_THROW_IF_NOT_FMT(
            mem1 < max_memory,
            "Polysemous training will use %zd bytes per thread, while the max is set to %zd",
            mem1,
            max_memory);

    if (mem1 * nt > max_memory) {
        nt = max_memory / mem1;
        fprintf(stderr,
                "Polysemous training: WARN, reducing number of threads to %d to save memory",
                nt);
    }

#pragma omp parallel for num_threads(nt)
    for (int m = 0; m < pq.M; m++) {
        std::vector<double> dis_table;

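Tracing the arithmetic of the check above with made-up numbers (the 20 GiB default and the per-thread formula come from this diff; the 48 hardware threads and M = 64 sub-quantizers are only for illustration):

max_memory = 20 * 1024 * 1024 * 1024   # default budget set in the constructor above
mem1 = 8192 * 8192 * 8 * 3             # per-thread bytes for 13-bit codes, affine objective
nt = min(48, 64)                       # min(omp_get_max_threads(), pq.M)

if mem1 >= max_memory:
    # corresponds to the FAISS_THROW_IF_NOT_FMT path: not even one thread fits
    raise RuntimeError("Polysemous training table too large for max_memory")
if mem1 * nt > max_memory:
    nt = max_memory // mem1            # shrink the thread team until it fits
print(nt)                              # 13 threads instead of 48
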
@@ -824,7 +844,6 @@ void PolysemousTraining::optimize_ranking(
        size_t n,
        const float* x) const {
    int dsub = pq.dsub;

    int nbits = pq.nbits;

    std::vector<uint8_t> all_codes(pq.code_size * n);

@@ -833,8 +852,9 @@ void PolysemousTraining::optimize_ranking(

    FAISS_THROW_IF_NOT(pq.nbits == 8);

    if (n == 0)
    if (n == 0) {
        pq.compute_sdc_table();
    }

#pragma omp parallel for
    for (int m = 0; m < pq.M; m++) {

@@ -943,4 +963,18 @@ void PolysemousTraining::optimize_pq_for_hamming(
    pq.compute_sdc_table();
}

size_t PolysemousTraining::memory_usage_per_thread(
        const ProductQuantizer& pq) const {
    size_t n = pq.ksub;

    switch (optimization_type) {
        case OT_None:
            return 0;
        case OT_ReproduceDistances_affine:
            return n * n * sizeof(double) * 3;
        case OT_Ranking_weighted_diff:
            return n * n * n * sizeof(float);
    }
}

} // namespace faiss

faiss/impl/PolysemousTraining.h

@@ -127,6 +127,9 @@ struct PolysemousTraining : SimulatedAnnealingParameters {
    int ntrain_permutation;
    double dis_weight_factor; ///< decay of exp that weights distance loss

    /// refuse to train if it would require more than that amount of RAM
    size_t max_memory;

    // filename pattern for the logging of iterations
    std::string log_pattern;

@@ -142,6 +145,9 @@ struct PolysemousTraining : SimulatedAnnealingParameters {
    void optimize_ranking(ProductQuantizer& pq, size_t n, const float* x) const;
    /// called by optimize_pq_for_hamming
    void optimize_reproduce_distances(ProductQuantizer& pq) const;

    /// make sure we don't blow up the memory
    size_t memory_usage_per_thread(const ProductQuantizer& pq) const;
};

} // namespace faiss

tests/test_index_accuracy.py

@@ -157,6 +157,16 @@ class IndexAccuracy(unittest.TestCase):
        # should give 0.234 0.236 0.236
        assert e[10] > 0.235

    def test_polysemous_OOM(self):
        """ this used to cause OOM when training polysemous with large
        nb bits"""
        d = 32
        xt, xb, xq = get_dataset_2(d, 10000, 0, 0)
        index = faiss.IndexPQ(d, M, 13)
        index.do_polysemous_training = True
        index.pq.cp.niter = 0
        index.polysemous_training.max_memory = 128 * 1024 * 1024
        self.assertRaises(RuntimeError, index.train, xt)


class TestSQFlavors(unittest.TestCase):