forked from facebookresearch/faiss
-
Notifications
You must be signed in to change notification settings - Fork 11
/
PolysemousTraining.h
158 lines (105 loc) · 4.71 KB
/
PolysemousTraining.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#ifndef FAISS_POLYSEMOUS_TRAINING_INCLUDED
#define FAISS_POLYSEMOUS_TRAINING_INCLUDED
#include "ProductQuantizer.h"
namespace faiss {
/**
 * Tunable knobs of the simulated-annealing permutation search.
 *
 * The default constructor is declared here and defined out of line; it
 * assigns a "reasonable default" to every field.
 */
struct SimulatedAnnealingParameters {
    /* ---- optimization schedule ---- */
    double init_temperature;  ///< initial probability of accepting a bad swap
    double temperature_decay; ///< temperature is multiplied by this at each iteration
    int n_iter;               ///< number of iterations
    int n_redo;               ///< number of runs of the simulation
    int seed;                 ///< random seed
    int verbose;              ///< verbosity setting (NOTE(review): meaning of levels lives in the implementation)

    /* ---- search-space options ---- */
    bool only_bit_flips;      ///< restrict permutation changes to bit flips
    bool init_random;         ///< start from a random permutation instead of the identity

    /// Sets reasonable defaults for all fields (defined out of line).
    SimulatedAnnealingParameters();
};
/// Abstract loss function over permutations; used as the objective of
/// SimulatedAnnealingOptimizer.
struct PermutationObjective {
    /// Permutation size. Default-initialized to 0 so it is never
    /// indeterminate if a subclass forgets to set it.
    /// NOTE(review): presumably the number of entries in `perm` — confirm
    /// against the implementations.
    int n = 0;

    /// Cost of the permutation `perm`
    /// (presumably an array of `n` indices — TODO confirm).
    virtual double compute_cost(const int *perm) const = 0;

    /// What would the cost update be if entries `iw` and `jw` were swapped?
    /// The default implementation just computes both costs and their
    /// difference; subclasses may override it with a cheaper update.
    virtual double cost_update(const int *perm, int iw, int jw) const;

    virtual ~PermutationObjective() = default;
};
/**
 * Permutation objective whose cost is the quadratic difference between the
 * actual distances and the Hamming distances, with a per-distance weight.
 *
 * Every distance table handled here has n^2 entries.
 */
struct ReproduceDistancesObjective : PermutationObjective {
    /// @param n                  permutation size
    /// @param source_dis_in      source distances (n^2 entries)
    /// @param target_dis_in      wanted distances (n^2 entries)
    ///                           NOTE(review): presumably kept in the non-owning
    ///                           `target_dis` pointer, so it would have to outlive
    ///                           this object — confirm in the implementation.
    /// @param dis_weight_factor  factor used by dis_weight()
    ReproduceDistancesObjective(
            int n,
            const double *source_dis_in,
            const double *target_dis_in,
            double dis_weight_factor);

    ~ReproduceDistancesObjective() override = default;

    /* ---- data (declaration order is significant for layout) ---- */
    double dis_weight_factor;       ///< factor used when weighting distances
    std::vector<double> source_dis; ///< "real" corrected distances (size n^2)
    const double *target_dis;       ///< wanted distances (size n^2)
    std::vector<double> weights;    ///< weight of each distance (size n^2)

    /* ---- helpers ---- */
    static double sqr(double x) {
        return x * x;
    }

    /// Weighting of distances: reproducing small distances well matters
    /// more than reproducing large ones.
    double dis_weight(double x) const;

    /// Source distance for the pair (i, j)
    /// (presumably read from `source_dis` — see implementation).
    double get_source_dis(int i, int j) const;

    /// Full cost: quadratic difference between actual distance and Hamming
    /// distance (the O(n^2) full computation).
    double compute_cost(const int *perm) const override;

    /// Cost change if `iw` and `jw` were swapped, computed in O(n) instead
    /// of O(n^2) for the full re-computation.
    double cost_update(const int *perm, int iw, int jw) const override;

    /// Mean and standard deviation of the `n2` values in `tab`, written to
    /// `*mean_out` and `*stddev_out` (per the parameter names).
    static void compute_mean_stdev(const double *tab, size_t n2,
                                   double *mean_out, double *stddev_out);

    /// Sets the target distances from `source_dis_in`.
    /// NOTE(review): the name suggests an affine mapping — confirm in the
    /// implementation.
    void set_affine_target_dis(const double *source_dis_in);
};
struct RandomGenerator;

/// Simulated annealing optimization algorithm for permutations.
///
/// The annealing knobs (temperature schedule, iteration counts, seed, ...)
/// come from the SimulatedAnnealingParameters base. NOTE(review): presumably
/// copied from the constructor's `p` argument — confirm in the implementation.
///
/// All pointer and scalar members are default-initialized so that none of
/// them is ever indeterminate, whatever the out-of-line constructor sets.
struct SimulatedAnnealingOptimizer: SimulatedAnnealingParameters {
    /// Objective being optimized.
    /// NOTE(review): presumably not owned by the optimizer — confirm.
    PermutationObjective *obj = nullptr;

    int n = 0; ///< size of the permutation

    /// Logs values of the cost function.
    /// NOTE(review): presumably nullptr means "no logging" — confirm.
    FILE *logfile = nullptr;

    /// @param obj  objective to optimize
    /// @param p    simulated-annealing parameters
    SimulatedAnnealingOptimizer (PermutationObjective *obj,
                                 const SimulatedAnnealingParameters &p);

    /// Random generator used during the optimization.
    /// NOTE(review): if the out-of-line destructor deletes this pointer,
    /// copying an optimizer would double-free it; consider deleting the
    /// copy constructor/assignment once callers have been checked.
    RandomGenerator *rnd = nullptr;

    /// Remembers the initial cost of the optimization
    /// (0 until the implementation sets it).
    double init_cost = 0;

    /// Main entry point: performs the optimization loop, starting from and
    /// modifying `perm` in place.
    double optimize (int *perm);

    /// Runs the optimization and returns the best result in `best_perm`.
    double run_optimization (int * best_perm);

    virtual ~SimulatedAnnealingOptimizer ();
};
/**
 * Optimizes the order of the indices in a ProductQuantizer.
 *
 * The simulated-annealing knobs are inherited from
 * SimulatedAnnealingParameters.
 */
struct PolysemousTraining: SimulatedAnnealingParameters {
    /// Objective used to reorder the indices.
    enum Optimization_type_t {
        OT_None,
        OT_ReproduceDistances_affine, ///< default
        OT_Ranking_weighted_diff      ///< ranking variant that uses rank of y+ minus rank of y-
    };

    Optimization_type_t optimization_type; ///< selected objective

    /// Use 1/4 of the training points for the optimization, with at most
    /// ntrain_permutation points. If ntrain_permutation == 0, train on the
    /// centroids instead.
    int ntrain_permutation;

    double dis_weight_factor; ///< decay of the exp that weights the distance loss

    std::string log_pattern;  ///< filename pattern for logging the iterations

    /// Sets default values (defined out of line).
    PolysemousTraining();

    /// Reorders the centroids so that the Hamming distance becomes a good
    /// approximation of the SDC distance (called by train).
    void optimize_pq_for_hamming(ProductQuantizer &pq,
                                 size_t n, const float *x) const;

    /// Called by optimize_pq_for_hamming.
    void optimize_ranking(ProductQuantizer &pq, size_t n, const float *x) const;

    /// Called by optimize_pq_for_hamming.
    void optimize_reproduce_distances(ProductQuantizer &pq) const;
};
} // namespace faiss
#endif