Fix polysemous OOM

Summary: Polysemous training can OOM because it uses tables of size n^2, where n is 2**nbits of the PQ. This change throws an exception when the tables threaten to become too large. It also reduces the number of threads when doing so makes it possible to fit the computation within max_memory bytes.

Reviewed By: wickedfoo

Differential Revision: D26856747

fbshipit-source-id: bd98e60293494e2f4b2b6d48eb1200efb1ce683c
This commit is contained in:
Matthijs Douze 2021-03-06 00:38:38 -08:00 committed by Facebook GitHub Bot
parent f2464141a7
commit 189aecb224
3 changed files with 53 additions and 3 deletions

View File

@ -8,8 +8,11 @@
// -*- c++ -*-
#include <faiss/impl/PolysemousTraining.h>
#include "faiss/impl/FaissAssert.h"
#include <omp.h>
#include <stdint.h>
#include <cmath>
#include <cstdlib>
#include <cstring>
@ -760,6 +763,8 @@ PolysemousTraining::PolysemousTraining() {
optimization_type = OT_ReproduceDistances_affine;
ntrain_permutation = 0;
dis_weight_factor = log(2);
// max 20 G RAM
max_memory = (size_t)(20) * 1024 * 1024 * 1024;
}
void PolysemousTraining::optimize_reproduce_distances(
@ -769,7 +774,22 @@ void PolysemousTraining::optimize_reproduce_distances(
int n = pq.ksub;
int nbits = pq.nbits;
#pragma omp parallel for
size_t mem1 = memory_usage_per_thread(pq);
int nt = std::min(omp_get_max_threads(), int(pq.M));
FAISS_THROW_IF_NOT_FMT(
mem1 < max_memory,
"Polysemous training will use %zd bytes per thread, while the max is set to %zd",
mem1,
max_memory);
if (mem1 * nt > max_memory) {
nt = max_memory / mem1;
fprintf(stderr,
"Polysemous training: WARN, reducing number of threads to %d to save memory",
nt);
}
#pragma omp parallel for num_threads(nt)
for (int m = 0; m < pq.M; m++) {
std::vector<double> dis_table;
@ -824,7 +844,6 @@ void PolysemousTraining::optimize_ranking(
size_t n,
const float* x) const {
int dsub = pq.dsub;
int nbits = pq.nbits;
std::vector<uint8_t> all_codes(pq.code_size * n);
@ -833,8 +852,9 @@ void PolysemousTraining::optimize_ranking(
FAISS_THROW_IF_NOT(pq.nbits == 8);
if (n == 0)
if (n == 0) {
pq.compute_sdc_table();
}
#pragma omp parallel for
for (int m = 0; m < pq.M; m++) {
@ -943,4 +963,18 @@ void PolysemousTraining::optimize_pq_for_hamming(
pq.compute_sdc_table();
}
/// Estimate the per-thread RAM (in bytes) that polysemous training will
/// allocate for the given PQ. The tables are sized by ksub = 2**nbits,
/// so memory grows as ksub^2 or ksub^3 depending on the optimization type.
/// Used by the callers to cap thread count / refuse training (max_memory).
size_t PolysemousTraining::memory_usage_per_thread(
const ProductQuantizer& pq) const {
size_t n = pq.ksub;
switch (optimization_type) {
case OT_None:
// no per-thread tables allocated
return 0;
case OT_ReproduceDistances_affine:
// three n x n tables of doubles
return n * n * sizeof(double) * 3;
case OT_Ranking_weighted_diff:
// one n x n x n table of floats
return n * n * n * sizeof(float);
}
// Defensive: the switch above must stay exhaustive. Without this,
// an unhandled enum value would flow off the end of a non-void
// function, which is undefined behavior in C++.
FAISS_THROW_MSG("unknown optimization_type");
}
} // namespace faiss

View File

@ -127,6 +127,9 @@ struct PolysemousTraining : SimulatedAnnealingParameters {
int ntrain_permutation;
double dis_weight_factor; ///< decay of exp that weights distance loss
/// refuse to train if it would require more than that amount of RAM
size_t max_memory;
// filename pattern for the logging of iterations
std::string log_pattern;
@ -142,6 +145,9 @@ struct PolysemousTraining : SimulatedAnnealingParameters {
void optimize_ranking(ProductQuantizer& pq, size_t n, const float* x) const;
/// called by optimize_pq_for_hamming
void optimize_reproduce_distances(ProductQuantizer& pq) const;
/// make sure we don't blow up the memory
size_t memory_usage_per_thread(const ProductQuantizer& pq) const;
};
} // namespace faiss

View File

@ -157,6 +157,16 @@ class IndexAccuracy(unittest.TestCase):
# should give 0.234 0.236 0.236
assert e[10] > 0.235
def test_polysemous_OOM(self):
""" this used to cause OOM when training polysemous with a large
number of bits (per-thread tables scale with 2**nbits); with
max_memory capped, train() must raise instead of exhausting RAM """
d = 32
# training set only; no database or query vectors needed
xt, xb, xq = get_dataset_2(d, 10000, 0, 0)
# NOTE(review): M is not defined in this hunk — presumably a module- or
# class-level constant of the test file; confirm it is in scope.
# 13-bit codes -> ksub = 2**13, far too large for the 128 MiB cap below.
index = faiss.IndexPQ(d, M, 13)
index.do_polysemous_training = True
# skip k-means iterations: the point of the test is the memory check
index.pq.cp.niter = 0
# cap polysemous training at 128 MiB so the size check must trip
index.polysemous_training.max_memory = 128 * 1024 * 1024
self.assertRaises(RuntimeError, index.train, xt)
class TestSQFlavors(unittest.TestCase):