#include "PolysemousTraining.h"

#include <cstdio>
#include <cstdlib>
#include <cstdint>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <vector>

#include "utils.h"
#include "FaissAssert.h"

namespace faiss {
SimulatedAnnealingParameters::SimulatedAnnealingParameters ()
{
    // set some reasonable defaults for the optimization
    init_temperature = 0.7;
    temperature_decay = pow (0.9, 1/500.); // reduce by a factor 0.9 every 500 iterations
    n_iter = 500000;
    n_redo = 2;
    seed = 123;
    verbose = 0;
    only_bit_flips = false;
    init_random = false;
}
// what would the cost update be if entries iw and jw of the permutation
// were swapped? Default implementation: recompute both costs.
double PermutationObjective::cost_update (
        const int *perm, int iw, int jw) const
{
    double orig_cost = compute_cost (perm);

    std::vector<int> perm2 (n);
    for (int i = 0; i < n; i++)
        perm2[i] = perm[i];
    std::swap (perm2[iw], perm2[jw]);

    double new_cost = compute_cost (perm2.data());
    return new_cost - orig_cost;
}
SimulatedAnnealingOptimizer::SimulatedAnnealingOptimizer (
        PermutationObjective *obj,
        const SimulatedAnnealingParameters &p):
    SimulatedAnnealingParameters (p), obj (obj), n (obj->n), logfile (nullptr)
{
    rnd = new RandomGenerator (p.seed);
    FAISS_THROW_IF_NOT (n < 100000 && n >= 0);
}

SimulatedAnnealingOptimizer::~SimulatedAnnealingOptimizer ()
{
    delete rnd;
}
// run n_redo annealing runs and keep the result with the lowest cost
double SimulatedAnnealingOptimizer::run_optimization (int * best_perm)
{
    double min_cost = 1e30;

    for (int it = 0; it < n_redo; it++) {
        std::vector<int> perm (n);
        for (int i = 0; i < n; i++)
            perm[i] = i;
        if (init_random) {
            for (int i = 0; i < n; i++) {
                int j = i + rnd->rand_int (n - i);
                std::swap (perm[i], perm[j]);
            }
        }
        float cost = optimize (perm.data());
        if (logfile) fprintf (logfile, "\n");
        if (verbose > 1) {
            printf ("    optimization run %d: cost=%g %s\n",
                    it, cost, cost < min_cost ? "keep" : "");
        }
        if (cost < min_cost) {
            memcpy (best_perm, perm.data(), sizeof(perm[0]) * n);
            min_cost = cost;
        }
    }
    return min_cost;
}
// perform the annealing loop, modifying the permutation in-place
double SimulatedAnnealingOptimizer::optimize (int *perm)
{
    double cost = init_cost = obj->compute_cost (perm);
    int log2n = 0;
    while (!(n <= (1 << log2n))) log2n++;
    double temperature = init_temperature;
    int n_swap = 0, n_hot = 0;
    for (int it = 0; it < n_iter; it++) {
        temperature = temperature * temperature_decay;
        int iw, jw;
        if (only_bit_flips) {
            iw = rnd->rand_int (n);
            jw = iw ^ (1 << rnd->rand_int (log2n));
        } else {
            // pick two distinct indices to swap
            iw = rnd->rand_int (n);
            jw = rnd->rand_int (n - 1);
            if (jw >= iw) jw++;
        }
        double delta_cost = obj->cost_update (perm, iw, jw);
        if (delta_cost < 0 || rnd->rand_float () < temperature) {
            std::swap (perm[iw], perm[jw]);
            cost += delta_cost;
            n_swap++;
            if (delta_cost >= 0) n_hot++;
        }
        if (verbose > 2 || (verbose > 1 && it % 10000 == 0)) {
            printf ("      iteration %d cost %g temp %g n_swap %d "
                    "(%d hot)      \r",
                    it, cost, temperature, n_swap, n_hot);
            fflush (stdout);
        }
        if (logfile) {
            fprintf (logfile, "%d %g %g %d %d\n",
                     it, cost, temperature, n_swap, n_hot);
        }
    }
    if (verbose > 1) printf("\n");
    return cost;
}
static inline int hamming_dis (uint64_t a, uint64_t b)
{
    return __builtin_popcountl (a ^ b);
}
/// optimize the permutation so that Hamming distances between codes
/// reproduce a given distance table
struct ReproduceWithHammingObjective : PermutationObjective {
    int nbits;
    double dis_weight_factor;

    static double sqr (double x) { return x * x; }

    // weighting of distances: it is more important to reproduce small
    // distances well
    double dis_weight (double x) const
    {
        return exp (-dis_weight_factor * x);
    }

    std::vector<double> target_dis; // wanted distances (size n^2)
    std::vector<double> weights;    // weights for each distance (size n^2)
    // cost = weighted quadratic difference between wanted and actual
    // Hamming distances
    double compute_cost (const int* perm) const override {
        double cost = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                double wanted = target_dis[i * n + j];
                double w = weights[i * n + j];
                double actual = hamming_dis(perm[i], perm[j]);
                cost += w * sqr(wanted - actual);
            }
        }
        return cost;
    }
    // what would the cost update be if iw and jw were swapped?
    // computed in O(n) instead of O(n^2) for a full re-computation
    double cost_update (const int* perm, int iw, int jw) const override {
        double delta_cost = 0;

        for (int i = 0; i < n; i++) {
            if (i == iw) {
                for (int j = 0; j < n; j++) {
                    double wanted = target_dis[i * n + j],
                        w = weights[i * n + j];
                    double actual = hamming_dis(perm[i], perm[j]);
                    delta_cost -= w * sqr(wanted - actual);
                    double new_actual = hamming_dis(
                        perm[jw],
                        perm[j == iw ? jw : j == jw ? iw : j]);
                    delta_cost += w * sqr(wanted - new_actual);
                }
            } else if (i == jw) {
                for (int j = 0; j < n; j++) {
                    double wanted = target_dis[i * n + j],
                        w = weights[i * n + j];
                    double actual = hamming_dis(perm[i], perm[j]);
                    delta_cost -= w * sqr(wanted - actual);
                    double new_actual = hamming_dis(
                        perm[iw],
                        perm[j == iw ? jw : j == jw ? iw : j]);
                    delta_cost += w * sqr(wanted - new_actual);
                }
            } else {
                int j = iw;
                {
                    double wanted = target_dis[i * n + j],
                        w = weights[i * n + j];
                    double actual = hamming_dis(perm[i], perm[j]);
                    delta_cost -= w * sqr(wanted - actual);
                    double new_actual = hamming_dis(perm[i], perm[jw]);
                    delta_cost += w * sqr(wanted - new_actual);
                }
                j = jw;
                {
                    double wanted = target_dis[i * n + j],
                        w = weights[i * n + j];
                    double actual = hamming_dis(perm[i], perm[j]);
                    delta_cost -= w * sqr(wanted - actual);
                    double new_actual = hamming_dis(perm[i], perm[iw]);
                    delta_cost += w * sqr(wanted - new_actual);
                }
            }
        }
        return delta_cost;
    }
    ReproduceWithHammingObjective (
            int nbits,
            const std::vector<double> & dis_table,
            double dis_weight_factor):
        nbits (nbits), dis_weight_factor (dis_weight_factor)
    {
        n = 1 << nbits;
        FAISS_THROW_IF_NOT (dis_table.size() == n * n);
        set_affine_target_dis (dis_table);
    }
    // map the distance table to the range of nbits-bit Hamming
    // distances with an affine transform, and compute the weights
    void set_affine_target_dis (const std::vector<double> & dis_table)
    {
        double sum = 0, sum2 = 0;
        int n2 = n * n;
        for (int i = 0; i < n2; i++) {
            sum += dis_table [i];
            sum2 += dis_table [i] * dis_table [i];
        }
        double mean = sum / n2;
        double stddev = sqrt(sum2 / n2 - (sum / n2) * (sum / n2));

        target_dis.resize (n2);

        for (int i = 0; i < n2; i++) {
            // the mapping function
            double td = (dis_table [i] - mean) / stddev * sqrt(nbits / 4) +
                nbits / 2;
            target_dis[i] = td;
            // compute the weight
            weights.push_back (dis_weight (td));
        }
    }
    ~ReproduceWithHammingObjective() override {}
};
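/* Note on the affine mapping above: with mean mu and standard deviation
 * sigma of the input distance table, each entry d becomes
 *
 *     td = (d - mu) / sigma * sqrt(nbits / 4) + nbits / 2
 *
 * which centers the targets on nbits / 2 with standard deviation
 * sqrt(nbits / 4), matching the Hamming distance distribution of random
 * nbits-bit codes. Smaller targets receive larger weights through
 * dis_weight().
 */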
// weighting of distances: it is more important to reproduce small
// distances well
double ReproduceDistancesObjective::dis_weight (double x) const
{
    return exp (-dis_weight_factor * x);
}

double ReproduceDistancesObjective::get_source_dis (int i, int j) const
{
    return source_dis [i * n + j];
}
// cost = weighted quadratic difference between wanted and actual distances
double ReproduceDistancesObjective::compute_cost (const int *perm) const
{
    double cost = 0;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            double wanted = target_dis [i * n + j];
            double w = weights [i * n + j];
            double actual = get_source_dis (perm[i], perm[j]);
            cost += w * sqr (wanted - actual);
        }
    }
    return cost;
}
// what would the cost update be if iw and jw were swapped?
// computed in quadratic instead of cubic time
double ReproduceDistancesObjective::cost_update (
        const int *perm, int iw, int jw) const
{
    double delta_cost = 0;
    for (int i = 0; i < n; i++) {
        if (i == iw) {
            for (int j = 0; j < n; j++) {
                double wanted = target_dis [i * n + j],
                    w = weights [i * n + j];
                double actual = get_source_dis (perm[i], perm[j]);
                delta_cost -= w * sqr (wanted - actual);
                double new_actual = get_source_dis (
                    perm[jw],
                    perm[j == iw ? jw : j == jw ? iw : j]);
                delta_cost += w * sqr (wanted - new_actual);
            }
        } else if (i == jw) {
            for (int j = 0; j < n; j++) {
                double wanted = target_dis [i * n + j],
                    w = weights [i * n + j];
                double actual = get_source_dis (perm[i], perm[j]);
                delta_cost -= w * sqr (wanted - actual);
                double new_actual = get_source_dis (
                    perm[iw],
                    perm[j == iw ? jw : j == jw ? iw : j]);
                delta_cost += w * sqr (wanted - new_actual);
            }
        } else {
            int j = iw;
            {
                double wanted = target_dis [i * n + j],
                    w = weights [i * n + j];
                double actual = get_source_dis (perm[i], perm[j]);
                delta_cost -= w * sqr (wanted - actual);
                double new_actual = get_source_dis (perm[i], perm[jw]);
                delta_cost += w * sqr (wanted - new_actual);
            }
            j = jw;
            {
                double wanted = target_dis [i * n + j],
                    w = weights [i * n + j];
                double actual = get_source_dis (perm[i], perm[j]);
                delta_cost -= w * sqr (wanted - actual);
                double new_actual = get_source_dis (perm[i], perm[iw]);
                delta_cost += w * sqr (wanted - new_actual);
            }
        }
    }
    return delta_cost;
}
ReproduceDistancesObjective::ReproduceDistancesObjective (
        int n,
        const double *source_dis_in,
        const double *target_dis_in,
        double dis_weight_factor):
    dis_weight_factor (dis_weight_factor),
    target_dis (target_dis_in)
{
    this->n = n;
    set_affine_target_dis (source_dis_in);
}
void ReproduceDistancesObjective::compute_mean_stdev (
        const double *tab, size_t n2,
        double *mean_out, double *stddev_out)
{
    double sum = 0, sum2 = 0;
    for (int i = 0; i < n2; i++) {
        sum += tab [i];
        sum2 += tab [i] * tab [i];
    }
    double mean = sum / n2;
    double stddev = sqrt(sum2 / n2 - (sum / n2) * (sum / n2));
    *mean_out = mean;
    *stddev_out = stddev;
}
void ReproduceDistancesObjective::set_affine_target_dis (
        const double *source_dis_in)
{
    int n2 = n * n;

    double mean_src, stddev_src;
    compute_mean_stdev (source_dis_in, n2, &mean_src, &stddev_src);

    double mean_target, stddev_target;
    compute_mean_stdev (target_dis, n2, &mean_target, &stddev_target);

    printf ("map mean %g std %g -> mean %g std %g\n",
            mean_src, stddev_src, mean_target, stddev_target);

    source_dis.resize (n2);
    weights.resize (n2);

    for (int i = 0; i < n2; i++) {
        // the mapping function
        source_dis[i] = (source_dis_in[i] - mean_src) / stddev_src
            * stddev_target + mean_target;
        // the weights
        weights[i] = dis_weight (target_dis[i]);
    }
}
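/* Illustrative usage sketch (not part of the original file):
 * ReproduceDistancesObjective can also be driven directly to find a
 * permutation that makes an arbitrary "source" distance table mimic a
 * "target" one. The tables below are placeholders to be filled by the
 * caller.
 *
 *   int n = 256;                        // size of the permutation
 *   std::vector<double> source (n * n), target (n * n);
 *   // ... fill source and target with n x n distance tables ...
 *   ReproduceDistancesObjective obj (
 *       n, source.data(), target.data(), log(2));
 *   SimulatedAnnealingParameters params;
 *   SimulatedAnnealingOptimizer optim (&obj, params);
 *   std::vector<int> perm (n);
 *   optim.run_optimization (perm.data());
 */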
/// Maintains a 3D table of elementary costs: n_gt (i, j, k) counts, for
/// queries x with PQ code i, how often d_gt(x, y-) < d_gt(x, y+) where
/// y- has PQ code j and y+ has PQ code k.
template <typename Ttab, typename Taccu>
struct Score3Computer: PermutationObjective {

    int nc;
    std::vector<Ttab> n_gt;

    /// the score is a triple loop on the nc * nc * nc matrix of entries
    Taccu compute (const int * perm) const
    {
        Taccu accu = 0;
        const Ttab *p = n_gt.data();
        for (int i = 0; i < nc; i++) {
            int ip = perm [i];
            for (int j = 0; j < nc; j++) {
                int jp = perm [j];
                for (int k = 0; k < nc; k++) {
                    int kp = perm [k];
                    if (hamming_dis (ip, jp) <
                        hamming_dis (ip, kp)) {
                        accu += *p;
                    }
                    p++;
                }
            }
        }
        return accu;
    }
    /// score update if entries iw and jw of the permutation were swapped;
    /// only visits the cells of the nc^3 cube that can actually change
    Taccu compute_update (const int *perm, int iw, int jw) const
    {
        if (iw > jw) std::swap (iw, jw);

        Taccu accu = 0;
        const Ttab * n_gt_i = n_gt.data();
        for (int i = 0; i < nc; i++) {
            int ip0 = perm [i];
            int ip = perm [i == iw ? jw : i == jw ? iw : i];

            accu += update_i_cross (perm, iw, jw, ip0, ip, n_gt_i);

            if (ip != ip0)
                accu += update_i_plane (perm, iw, jw, ip0, ip, n_gt_i);

            n_gt_i += nc * nc;
        }
        return accu;
    }
    /// reference update for one plane i (full nc x nc scan)
    Taccu update_i (const int *perm, int iw, int jw,
                    int ip0, int ip, const Ttab * n_gt_i) const
    {
        Taccu accu = 0;
        const Ttab *n_gt_ij = n_gt_i;
        for (int j = 0; j < nc; j++) {
            int jp0 = perm [j];
            int jp = perm [j == iw ? jw : j == jw ? iw : j];
            for (int k = 0; k < nc; k++) {
                int kp0 = perm [k];
                int kp = perm [k == iw ? jw : k == jw ? iw : k];
                int ng = n_gt_ij [k];
                if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
                    accu += ng;
                }
                if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp0)) {
                    accu -= ng;
                }
            }
            n_gt_ij += nc;
        }
        return accu;
    }
    /// update for the case where i moves but j and k are unchanged
    Taccu update_i_plane (const int *perm, int iw, int jw,
                          int ip0, int ip, const Ttab * n_gt_i) const
    {
        Taccu accu = 0;
        const Ttab *n_gt_ij = n_gt_i;

        for (int j = 0; j < nc; j++) {
            if (j != iw && j != jw) {
                int jp = perm [j];
                for (int k = 0; k < nc; k++) {
                    if (k != iw && k != jw) {
                        int kp = perm [k];
                        Ttab ng = n_gt_ij [k];
                        if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
                            accu += ng;
                        }
                        if (hamming_dis (ip0, jp) < hamming_dis (ip0, kp)) {
                            accu -= ng;
                        }
                    }
                }
            }
            n_gt_ij += nc;
        }
        return accu;
    }
    /// used for the 8 cells where the 3 indices are swapped
    inline Taccu update_k (const int *perm, int iw, int jw,
                           int ip0, int ip, int jp0, int jp,
                           int k, const Ttab * n_gt_ij) const
    {
        Taccu accu = 0;
        int kp0 = perm [k];
        int kp = perm [k == iw ? jw : k == jw ? iw : k];
        Ttab ng = n_gt_ij [k];
        if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
            accu += ng;
        }
        if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp0)) {
            accu -= ng;
        }
        return accu;
    }
    /// compute update on a line of k's, where i and j are swapped
    Taccu update_j_line (const int *perm, int iw, int jw,
                         int ip0, int ip, int jp0, int jp,
                         const Ttab * n_gt_ij) const
    {
        Taccu accu = 0;
        for (int k = 0; k < nc; k++) {
            if (k == iw || k == jw) continue;
            int kp = perm [k];
            Ttab ng = n_gt_ij [k];
            if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
                accu += ng;
            }
            if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp)) {
                accu -= ng;
            }
        }
        return accu;
    }
    /// considers the 2 pairs of crossing lines j = iw or jw and k = iw or jw
    Taccu update_i_cross (const int *perm, int iw, int jw,
                          int ip0, int ip, const Ttab * n_gt_i) const
    {
        Taccu accu = 0;
        const Ttab *n_gt_ij = n_gt_i;

        for (int j = 0; j < nc; j++) {
            int jp0 = perm [j];
            int jp = perm [j == iw ? jw : j == jw ? iw : j];

            accu += update_k (perm, iw, jw, ip0, ip, jp0, jp, iw, n_gt_ij);
            accu += update_k (perm, iw, jw, ip0, ip, jp0, jp, jw, n_gt_ij);

            if (jp != jp0)
                accu += update_j_line (perm, iw, jw, ip0, ip, jp0, jp,
                                       n_gt_ij);

            n_gt_ij += nc;
        }
        return accu;
    }
    /// PermutationObjective implementation (negates the scores so that
    /// the optimizer minimizes them)
    double compute_cost (const int* perm) const override {
        return -compute (perm);
    }

    double cost_update (const int* perm, int iw, int jw) const override {
        return -compute_update (perm, iw, jw);
    }

    ~Score3Computer() override {}
};
// sort indices by the values of an external table
struct IndirectSort {
    const float *tab;
    bool operator () (int a, int b) { return tab[a] < tab[b]; }
};
struct RankingScore2: Score3Computer<float, double> {
    int nbits;
    int nq, nb;
    const uint32_t *qcodes, *bcodes;
    const float *gt_distances;

    RankingScore2 (int nbits, int nq, int nb,
                   const uint32_t *qcodes, const uint32_t *bcodes,
                   const float *gt_distances):
        nbits(nbits), nq(nq), nb(nb), qcodes(qcodes),
        bcodes(bcodes), gt_distances(gt_distances)
    {
        n = nc = 1 << nbits;
        n_gt.resize (nc * nc * nc);
        init_n_gt ();
    }
    /// the weight of a result at rank r
    double rank_weight (int r)
    {
        return 1.0 / (r + 1);
    }

    /// rank-weighted count of pairs (ai, bk) with bk > ai;
    /// a and b must be sorted on input
    double accum_gt_weight_diff (const std::vector<int> & a,
                                 const std::vector<int> & b)
    {
        int nb = b.size(), na = a.size();

        double accu = 0;
        int j = 0;
        for (int i = 0; i < na; i++) {
            int ai = a[i];
            while (j < nb && ai >= b[j]) j++;

            double accu_i = 0;
            for (int k = j; k < b.size(); k++)
                accu_i += rank_weight (b[k] - ai);

            accu += rank_weight (ai) * accu_i;
        }
        return accu;
    }
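    /* Worked example (illustrative): with a = {0, 2} and b = {1, 3},
     * the pairs (ai, bk) with bk > ai are (0,1), (0,3) and (2,3), so
     * the result is
     *   rank_weight(0) * (rank_weight(1) + rank_weight(3))
     *     + rank_weight(2) * rank_weight(1)
     *   = 1 * (1/2 + 1/4) + 1/3 * 1/2 = 11/12.
     */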
    void init_n_gt ()
    {
        for (int q = 0; q < nq; q++) {
            const float *gtd = gt_distances + q * nb;
            const uint32_t *cb = bcodes; // all queries share the same db codes
            float * n_gt_q = & n_gt [qcodes[q] * nc * nc];

            printf("init gt for q=%d/%d    \r", q, nq); fflush(stdout);

            std::vector<int> rankv (nb);
            int * ranks = rankv.data();

            // elements in each code bin, ordered by rank
            std::vector<std::vector<int> > tab (nc);

            { // build rank table
                IndirectSort s = {gtd};
                for (int j = 0; j < nb; j++) ranks[j] = j;
                std::sort (ranks, ranks + nb, s);
            }

            for (int rank = 0; rank < nb; rank++) {
                int i = ranks [rank];
                tab [cb[i]].push_back (rank);
            }

            for (int i = 0; i < nc; i++) {
                std::vector<int> & di = tab[i];
                for (int j = 0; j < nc; j++) {
                    std::vector<int> & dj = tab[j];
                    n_gt_q [i * nc + j] += accum_gt_weight_diff (di, dj);
                }
            }
        }
    }
};
PolysemousTraining::PolysemousTraining ()
{
    optimization_type = OT_ReproduceDistances_affine;
    ntrain_permutation = 0;
    dis_weight_factor = log(2);
}
void PolysemousTraining::optimize_reproduce_distances (
        ProductQuantizer &pq) const
{
    int dsub = pq.dsub;
    int n = pq.ksub;
    int nbits = pq.nbits;

#pragma omp parallel for
    for (int m = 0; m < pq.M; m++) {
        std::vector<double> dis_table;

        float * centroids = pq.get_centroids (m, 0);

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                dis_table.push_back (fvec_L2sqr (centroids + i * dsub,
                                                 centroids + j * dsub,
                                                 dsub));
            }
        }

        std::vector<int> perm (n);
        ReproduceWithHammingObjective obj (
            nbits, dis_table, dis_weight_factor);

        SimulatedAnnealingOptimizer optim (&obj, *this);

        if (log_pattern.size()) {
            char fname[256];
            snprintf (fname, 256, log_pattern.c_str(), m);
            printf ("opening log file %s\n", fname);
            optim.logfile = fopen (fname, "w");
            FAISS_THROW_IF_NOT_MSG (optim.logfile, "could not open logfile");
        }
        double final_cost = optim.run_optimization (perm.data());

        if (verbose > 0) {
            printf ("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
                    m, optim.init_cost, final_cost);
        }

        if (log_pattern.size()) fclose (optim.logfile);

        std::vector<float> centroids_copy;
        for (int i = 0; i < dsub * n; i++)
            centroids_copy.push_back (centroids[i]);

        // apply the permutation to the centroids of subquantizer m
        for (int i = 0; i < n; i++)
            memcpy (centroids + perm[i] * dsub,
                    centroids_copy.data() + i * dsub,
                    dsub * sizeof(centroids[0]));
    }
}
void PolysemousTraining::optimize_ranking (
        ProductQuantizer &pq, size_t n, const float *x) const
{
    int dsub = pq.dsub;
    int nbits = pq.nbits;

    std::vector<uint8_t> all_codes (pq.code_size * n);
    pq.compute_codes (x, all_codes.data(), n);

    FAISS_THROW_IF_NOT (pq.nbits == 8);

    if (n == 0)
        pq.compute_sdc_table ();

#pragma omp parallel for
    for (int m = 0; m < pq.M; m++) {
        size_t nq, nb;
        std::vector <uint32_t> codes;     // query codes, then db codes
        std::vector <float> gt_distances; // nq * nb matrix of distances

        if (n > 0) {
            std::vector<float> xtrain (n * dsub);
            for (int i = 0; i < n; i++)
                memcpy (xtrain.data() + i * dsub,
                        x + i * pq.d + m * dsub,
                        sizeof(float) * dsub);

            codes.resize (n);
            for (int i = 0; i < n; i++)
                codes [i] = all_codes [i * pq.code_size + m];

            nq = n / 4; nb = n - nq;
            const float *xq = xtrain.data();
            const float *xb = xq + nq * dsub;

            gt_distances.resize (nq * nb);

            pairwise_L2sqr (dsub,
                            nq, xq,
                            nb, xb,
                            gt_distances.data());
        } else {
            nq = nb = pq.ksub;
            codes.resize (2 * nq);
            for (int i = 0; i < nq; i++)
                codes[i] = codes [i + nq] = i;

            gt_distances.resize (nq * nb);

            memcpy (gt_distances.data (),
                    pq.sdc_table.data () + m * nq * nb,
                    sizeof (float) * nq * nb);
        }

        double t0 = getmillisecs ();

        RankingScore2 obj (nbits, nq, nb,
                           codes.data(), codes.data() + nq,
                           gt_distances.data ());

        if (verbose > 0) {
            printf("   m=%d, nq=%ld, nb=%ld, initialize RankingScore "
                   "in %.3f ms\n",
                   m, nq, nb, getmillisecs () - t0);
        }

        SimulatedAnnealingOptimizer optim (&obj, *this);

        if (log_pattern.size()) {
            char fname[256];
            snprintf (fname, 256, log_pattern.c_str(), m);
            printf ("opening log file %s\n", fname);
            optim.logfile = fopen (fname, "w");
            FAISS_THROW_IF_NOT_FMT (optim.logfile,
                "could not open logfile %s", fname);
        }

        std::vector<int> perm (pq.ksub);

        double final_cost = optim.run_optimization (perm.data());
        printf ("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
                m, optim.init_cost, final_cost);

        if (log_pattern.size()) fclose (optim.logfile);

        float * centroids = pq.get_centroids (m, 0);

        std::vector<float> centroids_copy;
        for (int i = 0; i < dsub * pq.ksub; i++)
            centroids_copy.push_back (centroids[i]);

        // apply the permutation to the centroids of subquantizer m
        for (int i = 0; i < pq.ksub; i++)
            memcpy (centroids + perm[i] * dsub,
                    centroids_copy.data() + i * dsub,
                    dsub * sizeof(centroids[0]));
    }
}
void PolysemousTraining::optimize_pq_for_hamming (
        ProductQuantizer &pq, size_t n, const float *x) const
{
    if (optimization_type == OT_None) {
        // nothing to do
    } else if (optimization_type == OT_ReproduceDistances_affine) {
        optimize_reproduce_distances (pq);
    } else {
        optimize_ranking (pq, n, x);
    }

    pq.compute_sdc_table ();
}

} // namespace faiss
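/* Illustrative usage sketch (not part of the original file): typical use
 * is to train a ProductQuantizer first, then re-order its centroids so
 * that Hamming distances between codes reflect distances between
 * centroids. The ProductQuantizer constructor and train() signatures
 * below are assumed from the faiss headers.
 *
 *   size_t d = 64, M = 8, nbits = 8;
 *   faiss::ProductQuantizer pq (d, M, nbits);
 *   pq.train (n, x);                       // n training vectors of dim d
 *
 *   faiss::PolysemousTraining pt;
 *   pt.optimization_type =
 *       faiss::PolysemousTraining::OT_ReproduceDistances_affine;
 *   pt.optimize_pq_for_hamming (pq, n, x); // permutes pq's centroids
 */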