-
Notifications
You must be signed in to change notification settings - Fork 0
/
vose.cc
73 lines (56 loc) · 1.56 KB
/
vose.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include <cstddef>
#include <deque>
#include <iterator>
#include <queue>
#include <random>
#include <utility>
#include <vector>
#include "vose.h"
// Builds a sampler over `dist`, where entry i is the (unnormalized) weight of
// index i, and builds the alias table eagerly.
//
// The vector is taken by value and moved into the member, so an lvalue
// argument is copied once and an rvalue argument is not copied at all.
// Dropping the top-level `const` does not change the constructor's signature.
//
// NOTE(review): `gen(rd())` assumes `rd` is declared before `gen` in vose.h.
// Members are initialized in declaration order, not in mem-initializer order.
// Confirm the header's ordering.
vose::vose(std::vector<uint64_t> dist) : gen(rd()), dist(std::move(dist)), total(0) {
  // The parameter has been moved from; sum the member instead.
  for (const auto &weight : this->dist)
    total += weight;
  rebuild_alias_table();
}
void vose::rebuild_alias_table() {
std::deque<std::vector<vose_entry>::iterator> small;
std::deque<std::vector<vose_entry>::iterator> large;
table.resize(dist.size());
stale_table = false;
if (total == 0)
return;
const double slot_size = total / dist.size();
for (int i = 0; i < dist.size(); i ++) {
table[i] = { (double) dist[i], -1 };
if (dist[i] > slot_size)
large.push_back(table.begin() + i);
if (dist[i] < slot_size)
small.push_back(table.begin() + i);
}
while (!small.empty() && !large.empty()) {
auto lg = large.front();
large.pop_front();
auto sm = small.front();
small.pop_front();
lg->main_p -= (slot_size - sm->main_p);
sm->alt_i = std::distance(table.begin(), lg);
if (lg->main_p > slot_size)
large.push_front(lg);
if (lg->main_p < slot_size)
small.push_front(lg);
}
}
// Draws one index in [0, dist.size()) in proportion to the weights in
// `dist`. If update()/delta_update() marked the alias table stale, it is
// rebuilt first.
//
// A single uniform draw u in [0, n) supplies both values:
//   - the slot, floor(u);
//   - the coin, the fractional part of u scaled to [0, slot_size).
// The coin decides between the slot's own index and its alias.
//
// Returns -1 if `dist` is empty. When total == 0 the coin is 0, so the
// uniformly drawn slot is returned as long as its main_p is non-negative.
int vose::sample() {
  if (stale_table)
    rebuild_alias_table();
  const std::size_t n = dist.size();
  if (n == 0)
    return -1;  // nothing to sample from; avoids indexing an empty table
  std::uniform_real_distribution<> unif_dis(0.0, static_cast<double>(n));
  const double u = unif_dis(gen);
  // Some standard-library implementations can return the upper bound of the
  // range (LWG issue 2524). Clamp so the slot index stays in bounds.
  std::size_t a = static_cast<std::size_t>(u);
  if (a >= n)
    a = n - 1;
  const double b = (u - static_cast<double>(a)) / static_cast<double>(n) * total;
  // A negative alt_i means the slot has no alias, so it keeps its own index.
  if (b <= table[a].main_p || table[a].alt_i < 0)
    return static_cast<int>(a);
  return table[a].alt_i;
}
// Sets the weight of index `idx` to `value` and marks the alias table stale.
// The rebuild is deferred to the next sample().
//
// `total` is adjusted by the value as actually stored in `dist`, not by the
// raw argument. This keeps `total` equal to the sum of the stored weights
// even when the element type cannot represent `value` exactly, for example
// an integer weight vector truncating a fractional value.
//
// Preconditions (not checked): 0 <= idx < dist.size().
// NOTE(review): if `dist` holds unsigned integers, a negative `value` cannot
// be represented. Confirm that callers never pass one.
void vose::update(int idx, double value) {
  total -= dist[idx];
  dist[idx] = value;
  total += dist[idx];
  stale_table = true;
}
// Shifts the weight of index `idx` by `delta`.
// Routes through update() so the running total and the stale-table flag are
// maintained in one place.
void vose::delta_update(int idx, double delta) {
  const double shifted = dist[idx] + delta;
  update(idx, shifted);
}