-
Notifications
You must be signed in to change notification settings - Fork 0
/
Vose.h
153 lines (124 loc) · 3.68 KB
/
Vose.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#ifndef _VOSE_H_
#define _VOSE_H_
#include <algorithm>
#include <functional>
#include <random>
#include <string>
#include <sstream>
#include <vector>
#include "VoseUtil.h"
/**
 * A labelled, weighted item stored in a Vose urn.
 */
struct marble {
/** Weight of this marble (summed into Vose::total_weight by the urn). */
long weight;
/** Label identifying the marble. */
std::string label;
/** Position of the marble — presumably its slot in Vose::scores; TODO confirm. */
size_t index;
/**
 * Strict weak ordering: compare by label first, then by weight.
 *
 * The previous `label < m.label && weight < m.weight` was not a strict weak
 * ordering (incomparability was not transitive), which is undefined
 * behaviour for std::sort / std::set. Every pair that comparison reported
 * as "less" is still "less" under this lexicographic order.
 */
bool operator<(const marble &m) const {
const int by_label = label.compare(m.label);
if (by_label != 0) {
return by_label < 0;
}
return weight < m.weight;
}
/** Render the marble as "<label,weight>". */
std::string toString() const {
std::stringstream ss;
ss << "<" << label;
ss << ",";
ss << weight << ">";
return ss.str();
}
};
class Vose {
private:
friend class test;
/** A random generator in the range of [0,scores.size() ) */
std::function<double()> random_scores_die;
/** A random number generator in the range [0,1] */
std::function<double()> random_prob_die;
protected:
/** All the marbles with their label and weight */
std::vector<marble> scores;
/** The total weight of all the marbles */
long total_weight;
std::vector<size_t> alias;
std::vector<double> prob;
public:
Vose(): total_weight(0L),
scores(std::vector<marble> ()),
alias(std::vector<size_t> ()),
prob(std::vector<double> ()),
random_scores_die(std::function<double()> ()),
random_prob_die(std::function<double()> ()) { }
/**
* Adds a new marble(label, weight) to the Urn.
* It updates all necessary internal structures.
* Returns the new contribution of te added marble.
*/
double add(marble m);
double add(std::string label, long weight);
/**
* Runs vose algorithm over the whole Urn of marbles.
*/
void init();
/**
* Make a random from from vose.
* Vose must be initialized with init.
*/
size_t rand();
/**
* Call rand() but return the actual marble
*/
marble rand_marble();
/**
* Merge this Vose structure with another and return the result.
*/
Vose merge(const Vose& vose);
Vose merge(const Vose &v1, const Vose &v2);
/**
* Given a key, find this marble and update the value. O(n)
* If new_val is zero, remove it.
*/
//virtual Vose update(std::string key, long new_val);
// Utility Methods -----------------------------------------
/**
* Compares two distributions of values
*/
static double kl(const std::vector<double> &a, const std::vector<double> &b);
/**
* Return a vector of the relative probabilities of the scores.
*/
inline std::vector<double> probabilities() {
std::vector<double> v;
std::for_each (scores.begin(), scores.end(), [&] (marble m) {
v.push_back((double) m.weight/total_weight);
});
return v;
}
/**
* Use the random int generator to generate a random in in the
* range [0, scores.size()).
*/
inline int random_scores_index() {
return floor(random_scores_die());
}
inline double random_scores_double() {
return random_scores_die();
}
void reset_scores_die() {
// Update the random number generators
std::default_random_engine generator(VoseUtil::RANDOM_SEED);
std::uniform_real_distribution<double>
real_distribution(0, scores.size()>0?scores.size():1);
random_scores_die = std::bind(real_distribution, generator);
}
/**
* Return a double from a uniform distribution in the range [0.0, 1.0].
*/
inline double random_double() {
return random_prob_die();
}
void reset_random_double() {
// Update the random number generators
std::default_random_engine generator(VoseUtil::RANDOM_SEED);
std::uniform_real_distribution<double> uniform_distribution(0.0,1.0);
random_prob_die = std::bind(uniform_distribution, generator);
}
long weight() { return total_weight; }
std::vector<marble> get_scores() const { return scores;}
std::string toString();
};
#endif // _VOSE_H_