#include "PolysemousTraining.h"

#include <cmath>
#include <cstdio>
#include <cstring>
#include <cassert>
#include <algorithm>
#include <vector>

#include "utils.h"
#include "FaissAssert.h"

namespace faiss {

SimulatedAnnealingParameters::SimulatedAnnealingParameters ()
{
    // set some reasonable defaults for the optimization
    init_temperature = 0.7;
    temperature_decay = pow (0.9, 1/500.); // reduce by a factor 0.9 every 500 iterations
    n_iter = 500000;
    n_redo = 2;
    seed = 123;
    verbose = 0;
    only_bit_flips = false;
    init_random = false;
}
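// What would the cost change be if entries iw and jw of the permutation were
// swapped? This generic fallback simply recomputes the full cost on a copy of
// the permutation; subclasses override it with cheaper incremental updates.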
double PermutationObjective::cost_update (
        const int *perm, int iw, int jw) const
{
    double orig_cost = compute_cost (perm);

    std::vector<int> perm2 (n);
    for (int i = 0; i < n; i++)
        perm2[i] = perm[i];
    perm2[iw] = perm[jw];
    perm2[jw] = perm[iw];

    double new_cost = compute_cost (perm2.data());
    return new_cost - orig_cost;
}
SimulatedAnnealingOptimizer::SimulatedAnnealingOptimizer (
        PermutationObjective *obj, const SimulatedAnnealingParameters &p):
    SimulatedAnnealingParameters (p), obj (obj), n (obj->n), logfile (nullptr)
{
    rnd = new RandomGenerator (p.seed);
    FAISS_THROW_IF_NOT (n < 100000 && n >= 0);
}

SimulatedAnnealingOptimizer::~SimulatedAnnealingOptimizer ()
{
    delete rnd;
}
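// Run n_redo independent annealing runs and keep the permutation with the
// lowest final cost in best_perm; the best cost is returned.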
double SimulatedAnnealingOptimizer::run_optimization (int * best_perm)
{
    double min_cost = 1e30;

    for (int it = 0; it < n_redo; it++) {
        std::vector<int> perm(n);
        for (int i = 0; i < n; i++)
            perm[i] = i;
        if (init_random) {
            for (int i = 0; i < n; i++) {
                int j = i + rnd->rand_int (n - i);
                std::swap (perm[i], perm[j]);
            }
        }
        float cost = optimize (perm.data());
        if (logfile) fprintf (logfile, "\n");
        if (verbose > 1) {
            printf ("    optimization run %d: cost=%g %s\n",
                    it, cost, cost < min_cost ? "keep" : "");
        }
        if (cost < min_cost) {
            memcpy (best_perm, perm.data(), sizeof(perm[0]) * n);
            min_cost = cost;
        }
    }
    return min_cost;
}
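// Single annealing run, updating perm in place. At each iteration a pair
// (iw, jw) is drawn, the objective reports the cost change of swapping them,
// and the swap is applied if it decreases the cost or, with probability given
// by the current temperature, even if it increases it.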
double SimulatedAnnealingOptimizer::optimize (int *perm)
{
    double cost = init_cost = obj->compute_cost (perm);
    int log2n = 0;
    while (!(n <= (1 << log2n))) log2n++;
    double temperature = init_temperature;
    int n_swap = 0, n_hot = 0;
    for (int it = 0; it < n_iter; it++) {
        temperature = temperature * temperature_decay;
        int iw, jw;
        if (only_bit_flips) {
            iw = rnd->rand_int (n);
            jw = iw ^ (1 << rnd->rand_int (log2n));
        } else {
            iw = rnd->rand_int (n);
            jw = rnd->rand_int (n - 1);
            if (jw == iw) jw++;
        }
        double delta_cost = obj->cost_update (perm, iw, jw);
        if (delta_cost < 0 || rnd->rand_float () < temperature) {
            std::swap (perm[iw], perm[jw]);
            cost += delta_cost;
            n_swap++;
            if (delta_cost >= 0) n_hot++;
        }
        if (verbose > 2 || (verbose > 1 && it % 10000 == 0)) {
            printf ("      iteration %d cost %g temp %g n_swap %d "
                    "(%d hot)    \r",
                    it, cost, temperature, n_swap, n_hot);
            fflush (stdout);
        }
        if (logfile) {
            fprintf (logfile, "%d %g %g %d %d\n",
                     it, cost, temperature, n_swap, n_hot);
        }
    }
    if (verbose > 1) printf("\n");
    return cost;
}
static inline int hamming_dis (uint64_t a, uint64_t b)
{
    return __builtin_popcountl (a ^ b);
}
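/****************************************************
 * Cost functions: ReproduceDistanceTable
 ****************************************************/

// Objective over a permutation of PQ centroid indices: the Hamming distance
// between two (permuted) code values should reproduce an affine rescaling of
// the real distance between the corresponding centroids.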
struct ReproduceWithHammingObjective : PermutationObjective {
    int nbits;
    double dis_weight_factor;

    static double sqr (double x) { return x * x; }

    // weighting of the distances: it is more important to reproduce
    // small distances well
    double dis_weight (double x) const
    {
        return exp (-dis_weight_factor * x);
    }

    std::vector<double> target_dis; // wanted distances (size n^2)
    std::vector<double> weights;    // weights for each distance (size n^2)
    double compute_cost (const int* perm) const override {
        double cost = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                double wanted = target_dis[i * n + j];
                double w = weights[i * n + j];
                double actual = hamming_dis(perm[i], perm[j]);
                cost += w * sqr(wanted - actual);
            }
        }
        return cost;
    }
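    // Cost change if iw and jw were swapped: only terms where i or j equals
    // iw or jw are affected, so the update is O(n) instead of the O(n^2)
    // full recomputation done in compute_cost.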
    double cost_update (const int* perm, int iw, int jw) const override {
        double delta_cost = 0;

        for (int i = 0; i < n; i++) {
            if (i == iw) {
                for (int j = 0; j < n; j++) {
                    double wanted = target_dis[i * n + j], w = weights[i * n + j];
                    double actual = hamming_dis(perm[i], perm[j]);
                    delta_cost -= w * sqr(wanted - actual);
                    double new_actual = hamming_dis(
                            perm[jw], perm[j == iw ? jw : j == jw ? iw : j]);
                    delta_cost += w * sqr(wanted - new_actual);
                }
            } else if (i == jw) {
                for (int j = 0; j < n; j++) {
                    double wanted = target_dis[i * n + j], w = weights[i * n + j];
                    double actual = hamming_dis(perm[i], perm[j]);
                    delta_cost -= w * sqr(wanted - actual);
                    double new_actual = hamming_dis(
                            perm[iw], perm[j == iw ? jw : j == jw ? iw : j]);
                    delta_cost += w * sqr(wanted - new_actual);
                }
            } else {
                int j = iw;
                {
                    double wanted = target_dis[i * n + j], w = weights[i * n + j];
                    double actual = hamming_dis(perm[i], perm[j]);
                    delta_cost -= w * sqr(wanted - actual);
                    double new_actual = hamming_dis(perm[i], perm[jw]);
                    delta_cost += w * sqr(wanted - new_actual);
                }
                j = jw;
                {
                    double wanted = target_dis[i * n + j], w = weights[i * n + j];
                    double actual = hamming_dis(perm[i], perm[j]);
                    delta_cost -= w * sqr(wanted - actual);
                    double new_actual = hamming_dis(perm[i], perm[iw]);
                    delta_cost += w * sqr(wanted - new_actual);
                }
            }
        }
        return delta_cost;
    }
    ReproduceWithHammingObjective (
            int nbits,
            const std::vector<double> & dis_table,
            double dis_weight_factor):
        nbits (nbits), dis_weight_factor (dis_weight_factor)
    {
        n = 1 << nbits;
        FAISS_THROW_IF_NOT (dis_table.size() == n * n);
        set_affine_target_dis (dis_table);
    }
    void set_affine_target_dis (const std::vector<double> & dis_table)
    {
        double sum = 0, sum2 = 0;
        int n2 = n * n;
        for (int i = 0; i < n2; i++) {
            sum += dis_table [i];
            sum2 += dis_table [i] * dis_table [i];
        }
        double mean = sum / n2;
        double stddev = sqrt(sum2 / n2 - (sum / n2) * (sum / n2));

        target_dis.resize (n2);

        for (int i = 0; i < n2; i++) {
            // affine mapping of the input distance onto the Hamming range
            double td = (dis_table [i] - mean) / stddev * sqrt(nbits / 4) +
                nbits / 2;
            target_dis[i] = td;
            // weight of this distance in the cost
            weights.push_back (dis_weight (td));
        }
    }
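    // Note on the affine target (a sketch of the rationale, not an original
    // comment): the Hamming distance between two independent uniformly random
    // nbits-bit codes is Binomial(nbits, 1/2), i.e. mean nbits/2 and standard
    // deviation sqrt(nbits/4). Remapping the input distances to that mean and
    // standard deviation presumably places the targets in the range of values
    // a Hamming distance can actually reach.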
    ~ReproduceWithHammingObjective() override {}
};
// weighting of the distances: it is more important to reproduce
// small distances well
double ReproduceDistancesObjective::dis_weight (double x) const
{
    return exp (-dis_weight_factor * x);
}
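// ReproduceDistancesObjective works like the objective above, but both the
// source and the target distance tables (size n^2 each) are provided by the
// caller; the source table is affinely remapped onto the statistics of the
// target table.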
double ReproduceDistancesObjective::get_source_dis (int i, int j) const
{
    return source_dis [i * n + j];
}
// cost = weighted quadratic difference between wanted and actual distances
double ReproduceDistancesObjective::compute_cost (const int *perm) const
{
    double cost = 0;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            double wanted = target_dis [i * n + j];
            double w = weights [i * n + j];
            double actual = get_source_dis (perm[i], perm[j]);
            cost += w * sqr (wanted - actual);
        }
    }
    return cost;
}
// cost change if iw and jw were swapped: O(n) update, same structure as
// ReproduceWithHammingObjective::cost_update
double ReproduceDistancesObjective::cost_update(
        const int *perm, int iw, int jw) const
{
    double delta_cost = 0;
    for (int i = 0; i < n; i++) {
        if (i == iw) {
            for (int j = 0; j < n; j++) {
                double wanted = target_dis [i * n + j], w = weights [i * n + j];
                double actual = get_source_dis (perm[i], perm[j]);
                delta_cost -= w * sqr (wanted - actual);
                double new_actual = get_source_dis (
                        perm[jw],
                        perm[j == iw ? jw : j == jw ? iw : j]);
                delta_cost += w * sqr (wanted - new_actual);
            }
        } else if (i == jw) {
            for (int j = 0; j < n; j++) {
                double wanted = target_dis [i * n + j], w = weights [i * n + j];
                double actual = get_source_dis (perm[i], perm[j]);
                delta_cost -= w * sqr (wanted - actual);
                double new_actual = get_source_dis (
                        perm[iw],
                        perm[j == iw ? jw : j == jw ? iw : j]);
                delta_cost += w * sqr (wanted - new_actual);
            }
        } else {
            int j = iw;
            {
                double wanted = target_dis [i * n + j], w = weights [i * n + j];
                double actual = get_source_dis (perm[i], perm[j]);
                delta_cost -= w * sqr (wanted - actual);
                double new_actual = get_source_dis (perm[i], perm[jw]);
                delta_cost += w * sqr (wanted - new_actual);
            }
            j = jw;
            {
                double wanted = target_dis [i * n + j], w = weights [i * n + j];
                double actual = get_source_dis (perm[i], perm[j]);
                delta_cost -= w * sqr (wanted - actual);
                double new_actual = get_source_dis (perm[i], perm[iw]);
                delta_cost += w * sqr (wanted - new_actual);
            }
        }
    }
    return delta_cost;
}
ReproduceDistancesObjective::ReproduceDistancesObjective (
        int n,
        const double *source_dis_in,
        const double *target_dis_in,
        double dis_weight_factor):
    dis_weight_factor (dis_weight_factor),
    target_dis (target_dis_in)
{
    this->n = n;
    set_affine_target_dis (source_dis_in);
}
void ReproduceDistancesObjective::compute_mean_stdev (
        const double *tab, size_t n2,
        double *mean_out, double *stddev_out)
{
    double sum = 0, sum2 = 0;
    for (int i = 0; i < n2; i++) {
        sum += tab [i];
        sum2 += tab [i] * tab [i];
    }
    double mean = sum / n2;
    double stddev = sqrt(sum2 / n2 - (sum / n2) * (sum / n2));
    *mean_out = mean;
    *stddev_out = stddev;
}
void ReproduceDistancesObjective::set_affine_target_dis (
        const double *source_dis_in)
{
    int n2 = n * n;

    double mean_src, stddev_src;
    compute_mean_stdev (source_dis_in, n2, &mean_src, &stddev_src);

    double mean_target, stddev_target;
    compute_mean_stdev (target_dis, n2, &mean_target, &stddev_target);

    printf ("map mean %g std %g -> mean %g std %g\n",
            mean_src, stddev_src, mean_target, stddev_target);

    source_dis.resize (n2);
    weights.resize (n2);

    for (int i = 0; i < n2; i++) {
        // affine mapping of the source distances onto the target statistics
        source_dis[i] = (source_dis_in[i] - mean_src) / stddev_src
            * stddev_target + mean_target;

        // weight of each distance in the cost
        weights [i] = dis_weight (target_dis[i]);
    }
}
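/* Illustrative sketch (not code from the library): the objective above can be
 * driven directly by the annealing optimizer defined in this file, given an
 * n*n table of measured distances `src` and an n*n table of wanted distances
 * `tgt` (both placeholders here):
 *
 *     const int n = 256;
 *     std::vector<double> src (n * n), tgt (n * n);
 *     // ... fill src and tgt ...
 *     ReproduceDistancesObjective obj (n, src.data(), tgt.data(), log(2));
 *     SimulatedAnnealingParameters params;
 *     SimulatedAnnealingOptimizer optim (&obj, params);
 *     std::vector<int> perm (n);
 *     double final_cost = optim.run_optimization (perm.data());
 */

/****************************************************
 * Cost functions: RankingScore
 ****************************************************/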
/// Maintains a 3-index table of elementary scores for the ranking objective
template <typename Ttab, typename Taccu>
struct Score3Computer: PermutationObjective {

    int nc;

    // n_gt[(i*nc + j)*nc + k]: weighted count of ground-truth triplets
    // (query with code i, closer neighbor with code j, farther one with code k)
    std::vector<Ttab> n_gt;

    /// the score is a triple loop over the nc * nc * nc table of entries
    Taccu compute (const int * perm) const
    {
        Taccu accu = 0;
        const Ttab *p = n_gt.data();
        for (int i = 0; i < nc; i++) {
            int ip = perm [i];
            for (int j = 0; j < nc; j++) {
                int jp = perm [j];
                for (int k = 0; k < nc; k++) {
                    int kp = perm [k];
                    if (hamming_dis (ip, jp) <
                        hamming_dis (ip, kp)) {
                        accu += *p;
                    }
                    p++;
                }
            }
        }
        return accu;
    }
    /** score update if entries iw and jw of the permutation were swapped;
     * only the cells of the nc*nc*nc table that can actually change are
     * visited */
    Taccu compute_update (const int *perm, int iw, int jw) const
    {
        assert (iw != jw);
        if (iw > jw) std::swap (iw, jw);

        Taccu accu = 0;
        const Ttab * n_gt_i = n_gt.data();
        for (int i = 0; i < nc; i++) {
            int ip0 = perm [i];
            int ip = perm [i == iw ? jw : i == jw ? iw : i];

            accu += update_i_cross (perm, iw, jw, ip0, ip, n_gt_i);

            if (ip != ip0)
                accu += update_i_plane (perm, iw, jw,
                                        ip0, ip, n_gt_i);

            n_gt_i += nc * nc;
        }
        return accu;
    }

    /// reference (unoptimized) update of one slice i of the table
    Taccu update_i (const int *perm, int iw, int jw,
                    int ip0, int ip, const Ttab * n_gt_i) const
    {
        Taccu accu = 0;
        const Ttab *n_gt_ij = n_gt_i;
        for (int j = 0; j < nc; j++) {
            int jp0 = perm [j];
            int jp = perm [j == iw ? jw : j == jw ? iw : j];
            for (int k = 0; k < nc; k++) {
                int kp0 = perm [k];
                int kp = perm [k == iw ? jw : k == jw ? iw : k];
                int ng = n_gt_ij [k];
                if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
                    accu += ng;
                }
                if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp0)) {
                    accu -= ng;
                }
            }
            n_gt_ij += nc;
        }
        return accu;
    }
    /// 2 inner loops for the case ip0 != ip, visiting only cells where
    /// neither j nor k is one of the swapped entries
    Taccu update_i_plane (const int *perm, int iw, int jw,
                          int ip0, int ip, const Ttab * n_gt_i) const
    {
        Taccu accu = 0;
        const Ttab *n_gt_ij = n_gt_i;

        for (int j = 0; j < nc; j++) {
            if (j != iw && j != jw) {
                int jp = perm [j];
                for (int k = 0; k < nc; k++) {
                    if (k != iw && k != jw) {
                        int kp = perm [k];
                        Ttab ng = n_gt_ij [k];
                        if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
                            accu += ng;
                        }
                        if (hamming_dis (ip0, jp) < hamming_dis (ip0, kp)) {
                            accu -= ng;
                        }
                    }
                }
            }
            n_gt_ij += nc;
        }
        return accu;
    }
    /// used for the 8 cells where the 3 indices are swapped
    inline Taccu update_k (const int *perm, int iw, int jw,
                           int ip0, int ip, int jp0, int jp,
                           int k,
                           const Ttab * n_gt_ij) const
    {
        Taccu accu = 0;
        int kp0 = perm [k];
        int kp = perm [k == iw ? jw : k == jw ? iw : k];
        Ttab ng = n_gt_ij [k];
        if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
            accu += ng;
        }
        if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp0)) {
            accu -= ng;
        }
        return accu;
    }
    /// compute update on a line of k's, where i and j are swapped
    Taccu update_j_line (const int *perm, int iw, int jw,
                         int ip0, int ip, int jp0, int jp,
                         const Ttab * n_gt_ij) const
    {
        Taccu accu = 0;
        for (int k = 0; k < nc; k++) {
            if (k == iw || k == jw) continue;
            int kp = perm [k];
            Ttab ng = n_gt_ij [k];
            if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
                accu += ng;
            }
            if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp)) {
                accu -= ng;
            }
        }
        return accu;
    }
    /// considers the 2 pairs of crossing lines j = iw or jw and k = iw or jw
    Taccu update_i_cross (const int *perm, int iw, int jw,
                          int ip0, int ip, const Ttab * n_gt_i) const
    {
        Taccu accu = 0;
        const Ttab *n_gt_ij = n_gt_i;

        for (int j = 0; j < nc; j++) {
            int jp0 = perm [j];
            int jp = perm [j == iw ? jw : j == jw ? iw : j];

            accu += update_k (perm, iw, jw, ip0, ip, jp0, jp, iw, n_gt_ij);
            accu += update_k (perm, iw, jw, ip0, ip, jp0, jp, jw, n_gt_ij);

            if (jp != jp0)
                accu += update_j_line (perm, iw, jw, ip0, ip, jp0, jp,
                                       n_gt_ij);

            n_gt_ij += nc;
        }
        return accu;
    }
    /// PermutationObjective implementation (negates the score so that it can
    /// be minimized)
    double compute_cost (const int* perm) const override {
        return -compute (perm);
    }

    double cost_update (const int* perm, int iw, int jw) const override {
        return -compute_update (perm, iw, jw);
    }

    ~Score3Computer() override {}
};
/// sorts indices by increasing value of tab[]
struct IndirectSort {
    const float *tab;
    bool operator () (int a, int b) { return tab[a] < tab[b]; }
};
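// RankingScore2 fills the n_gt table from actual data: queries and database
// items are represented by their code on one subquantizer, and each triplet
// is weighted according to the ground-truth ranks of the two database items
// for that query.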
struct RankingScore2: Score3Computer<float, double> {
    int nbits;
    int nq, nb;
    const uint32_t *qcodes, *bcodes;
    const float *gt_distances;

    RankingScore2 (int nbits, int nq, int nb,
                   const uint32_t *qcodes, const uint32_t *bcodes,
                   const float *gt_distances):
        nbits(nbits), nq(nq), nb(nb), qcodes(qcodes),
        bcodes(bcodes), gt_distances(gt_distances)
    {
        n = nc = 1 << nbits;
        n_gt.resize (nc * nc * nc);
        init_n_gt ();
    }
    /// weight of a rank: earlier (smaller) ranks matter more
    double rank_weight (int r)
    {
        return 1.0 / (r + 1);
    }
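    // Weighted count of pairs (ai, bk) with ai in a, bk in b and ai < bk.
    // Both vectors must be sorted on input; they contain the ground-truth
    // ranks of the database items falling in two code bins.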
    double accum_gt_weight_diff (const std::vector<int> & a,
                                 const std::vector<int> & b)
    {
        int nb = b.size(), na = a.size();

        double accu = 0;
        int j = 0;
        for (int i = 0; i < na; i++) {
            int ai = a[i];
            while (j < nb && ai >= b[j]) j++;

            double accu_i = 0;
            for (int k = j; k < b.size(); k++)
                accu_i += rank_weight (b[k] - ai);

            accu += rank_weight (ai) * accu_i;
        }
        return accu;
    }
    void init_n_gt ()
    {
        for (int q = 0; q < nq; q++) {
            const float *gtd = gt_distances + q * nb;
            const uint32_t *cb = bcodes; // all queries share the same db codes
            float * n_gt_q = & n_gt [qcodes[q] * nc * nc];

            printf("init gt for q=%d/%d    \r", q, nq);
            fflush(stdout);

            std::vector<int> rankv (nb);
            int * ranks = rankv.data();

            // database items bucketed by code value; each bucket holds their
            // ground-truth ranks in increasing order
            std::vector<std::vector<int> > tab (nc);

            { // build the rank table
                IndirectSort s = {gtd};
                for (int j = 0; j < nb; j++) ranks[j] = j;
                std::sort (ranks, ranks + nb, s);
            }

            for (int rank = 0; rank < nb; rank++) {
                int i = ranks [rank];
                tab [cb[i]].push_back (rank);
            }

            for (int i = 0; i < nc; i++) {
                std::vector<int> & di = tab[i];
                for (int j = 0; j < nc; j++) {
                    std::vector<int> & dj = tab[j];
                    n_gt_q [i * nc + j] += accum_gt_weight_diff (di, dj);
                }
            }
        }
    }
};
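/*****************************************
 * PolysemousTraining
 ******************************************/

// PolysemousTraining reorders the centroids of each subquantizer of a
// ProductQuantizer so that Hamming distances between PQ codes become
// meaningful approximations of the distances between the encoded vectors.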
PolysemousTraining::PolysemousTraining ()
{
    optimization_type = OT_ReproduceDistances_affine;
    ntrain_permutation = 0;
    dis_weight_factor = log(2);
}
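// called by optimize_pq_for_hamming: for each subquantizer, build the table
// of squared L2 distances between its centroids and anneal a permutation of
// the centroid indices so that Hamming distances between code values
// reproduce (an affine rescaling of) that table.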
void PolysemousTraining::optimize_reproduce_distances (
        ProductQuantizer &pq) const
{
    int dsub = pq.dsub;
    int n = pq.ksub;
    int nbits = pq.nbits;

#pragma omp parallel for
    for (int m = 0; m < pq.M; m++) {
        std::vector<double> dis_table;

        float * centroids = pq.get_centroids (m, 0);

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                dis_table.push_back (fvec_L2sqr (centroids + i * dsub,
                                                 centroids + j * dsub,
                                                 dsub));
            }
        }

        std::vector<int> perm (n);
        ReproduceWithHammingObjective obj (
                nbits, dis_table, dis_weight_factor);

        SimulatedAnnealingOptimizer optim (&obj, *this);

        if (log_pattern.size()) {
            char fname[256];
            snprintf (fname, 256, log_pattern.c_str(), m);
            printf ("opening log file %s\n", fname);
            optim.logfile = fopen (fname, "w");
            FAISS_THROW_IF_NOT_MSG (optim.logfile, "could not open logfile");
        }
        double final_cost = optim.run_optimization (perm.data());

        if (verbose > 0) {
            printf ("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
                    m, optim.init_cost, final_cost);
        }

        if (log_pattern.size()) fclose (optim.logfile);

        // reorder the centroids according to the optimized permutation
        std::vector<float> centroids_copy;
        for (int i = 0; i < dsub * n; i++)
            centroids_copy.push_back (centroids[i]);

        for (int i = 0; i < n; i++)
            memcpy (centroids + perm[i] * dsub,
                    centroids_copy.data() + i * dsub,
                    dsub * sizeof(centroids[0]));
    }
}
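// called by optimize_pq_for_hamming: same annealing, but with the ranking
// objective. When training points are provided (n > 0) a quarter of them are
// used as queries against the remaining ones; otherwise the subquantizer's
// SDC table serves as the ground-truth distance matrix.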
void PolysemousTraining::optimize_ranking (
        ProductQuantizer &pq, size_t n, const float *x) const
{
    int dsub = pq.dsub;
    int nbits = pq.nbits;

    std::vector<uint8_t> all_codes (pq.code_size * n);
    pq.compute_codes (x, all_codes.data(), n);

    if (n == 0)
        pq.compute_sdc_table ();

#pragma omp parallel for
    for (int m = 0; m < pq.M; m++) {
        size_t nq, nb;
        std::vector <uint32_t> codes;     // query codes, then database codes
        std::vector <float> gt_distances; // nq * nb matrix of distances

        if (n > 0) {
            std::vector<float> xtrain (n * dsub);
            for (int i = 0; i < n; i++)
                memcpy (xtrain.data() + i * dsub,
                        x + i * pq.d + m * dsub,
                        sizeof(float) * dsub);

            codes.resize (n);
            for (int i = 0; i < n; i++)
                codes [i] = all_codes [i * pq.code_size + m];

            nq = n / 4; nb = n - nq;
            const float *xq = xtrain.data();
            const float *xb = xq + nq * dsub;

            gt_distances.resize (nq * nb);

            pairwise_L2sqr (dsub,
                            nq, xq,
                            nb, xb,
                            gt_distances.data());
        } else {
            nq = nb = pq.ksub;
            codes.resize (2 * nq);
            for (int i = 0; i < nq; i++)
                codes[i] = codes [i + nq] = i;

            gt_distances.resize (nq * nb);

            memcpy (gt_distances.data (),
                    pq.sdc_table.data () + m * nq * nb,
                    sizeof (float) * nq * nb);
        }

        double t0 = getmillisecs ();

        RankingScore2 obj (nbits, nq, nb,
                           codes.data(), codes.data() + nq,
                           gt_distances.data ());

        if (verbose > 0) {
            printf("   m=%d, nq=%ld, nb=%ld, initialize RankingScore "
                   "in %.3f ms\n",
                   m, nq, nb, getmillisecs () - t0);
        }

        SimulatedAnnealingOptimizer optim (&obj, *this);

        if (log_pattern.size()) {
            char fname[256];
            snprintf (fname, 256, log_pattern.c_str(), m);
            printf ("opening log file %s\n", fname);
            optim.logfile = fopen (fname, "w");
            FAISS_THROW_IF_NOT_FMT (optim.logfile,
                                    "could not open logfile %s", fname);
        }

        std::vector<int> perm (pq.ksub);

        double final_cost = optim.run_optimization (perm.data());
        printf ("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
                m, optim.init_cost, final_cost);

        if (log_pattern.size()) fclose (optim.logfile);

        // reorder the centroids according to the optimized permutation
        std::vector<float> centroids_copy;
        float * centroids = pq.get_centroids (m, 0);
        for (int i = 0; i < dsub * pq.ksub; i++)
            centroids_copy.push_back (centroids[i]);

        for (int i = 0; i < pq.ksub; i++)
            memcpy (centroids + perm[i] * dsub,
                    centroids_copy.data() + i * dsub,
                    dsub * sizeof(centroids[0]));
    }
}
void PolysemousTraining::optimize_pq_for_hamming (ProductQuantizer &pq,
                                                  size_t n, const float *x) const
{
    if (optimization_type == OT_None) {
        // nothing to do
    } else if (optimization_type == OT_ReproduceDistances_affine) {
        optimize_reproduce_distances (pq);
    } else {
        optimize_ranking (pq, n, x);
    }

    pq.compute_sdc_table ();
}

} // namespace faiss
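/* Illustrative usage sketch (assumptions: a ProductQuantizer built with the
 * usual (d, M, nbits) constructor and trained with train(), as elsewhere in
 * Faiss; this is not code from this file):
 *
 *     ProductQuantizer pq (d, M, 8);          // 8-bit sub-indices
 *     pq.train (n, x);                        // regular k-means training
 *
 *     PolysemousTraining pt;                  // defaults to the affine
 *                                             // distance-reproduction objective
 *     pt.optimize_pq_for_hamming (pq, n, x);  // reorders centroids in place
 *
 * IndexPQ performs an equivalent call internally when polysemous training is
 * enabled.
 */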