-
Notifications
You must be signed in to change notification settings - Fork 0
/
Strategy.java
88 lines (74 loc) · 2.52 KB
/
Strategy.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import java.lang.StringBuilder;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Collections;
import java.util.Formatter;
import java.util.HashMap;
import java.util.HashSet;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.EnumSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
/**
 * A strategy defined over a {@link GameTree}: for every information set that belongs to a
 * non-leaf {@link ActionNode}, it holds a {@link Distribution} over that node's bet amounts.
 *
 * <p>Instances are mutable: {@link #getDist} returns the live distribution objects stored here,
 * and {@link #setAction}, {@link #average} and {@link #setDist} modify state in place. No
 * synchronization is performed.
 */
public class Strategy {
    /** Orders information sets by their string form so {@link #toString()} output is stable. */
    private static final Comparator<InfoSet> BY_STRING = new Comparator<InfoSet>() {
        @Override
        public int compare(InfoSet s1, InfoSet s2) {
            return s1.toString().compareTo(s2.toString());
        }
    };

    /** One distribution over bet amounts per information set found in the tree. */
    private final HashMap<InfoSet, Distribution<Double>> strategy;

    /**
     * Builds a strategy with one entry per information set of every non-leaf action node in
     * {@code gt}. Each distribution is constructed from the set of that node's bets; its initial
     * weights are whatever {@code Distribution}'s set constructor assigns (presumably uniform -
     * TODO confirm).
     *
     * @param gt the game tree to cover
     */
    public Strategy(GameTree gt) {
        this.strategy = new HashMap<InfoSet, Distribution<Double>>();
        init(gt.getRoot());
    }

    /**
     * Walks the subtree rooted at {@code n} in pre-order and registers a fresh distribution for
     * each non-leaf action node's information set. When several nodes share an information set,
     * the last one visited replaces the earlier entries.
     *
     * @param n root of the subtree to visit
     */
    private void init(GameNode n) {
        // Leaf nodes get no entry, even when they are ActionNodes.
        if (n.getChildren().size() == 0) {
            return;
        }
        if (n instanceof ActionNode) {
            ActionNode an = (ActionNode) n;
            Set<Double> bets = Sets.newHashSet(an.getBets());
            strategy.put(an.getInfoSet(), new Distribution<Double>(bets));
        }
        for (GameNode child : n.getChildren()) {
            init(child);
        }
    }

    /**
     * Returns the stored distribution for {@code iset}.
     *
     * @param iset the information set to look up
     * @return the live (not copied) distribution, or {@code null} if {@code iset} has no entry
     */
    public Distribution<Double> getDist(InfoSet iset) {
        return strategy.get(iset);
    }

    /**
     * Stores {@code dist} as the distribution for {@code iset}, replacing any existing entry or
     * adding a new one.
     *
     * @param iset the information set to update
     * @param dist the distribution to store (stored by reference, not copied)
     */
    public void setDist(InfoSet iset, Distribution<Double> dist) {
        strategy.put(iset, dist);
    }

    /**
     * Returns the probability that the distribution at {@code iset} assigns to bet {@code d}.
     *
     * @param iset the information set to look up
     * @param d the bet amount whose probability is requested
     * @return the probability reported by the distribution
     * @throws IllegalArgumentException if {@code iset} has no distribution
     * @throws RuntimeException wrapping any failure raised while reading the value from the
     *     distribution; the message names the bet, the information set and the distribution
     */
    public double getProb(InfoSet iset, Double d) {
        Distribution<Double> dist = strategy.get(iset);
        if (dist == null) {
            throw new IllegalArgumentException(
                    "No distribution for info set " + iset + " (bet " + d + ")");
        }
        try {
            return dist.get(d);
        } catch (Exception e) {
            // Diagnostics travel with the exception instead of being printed to stdout.
            throw new RuntimeException(
                    "Cannot get probability of bet " + d + " at info set " + iset
                            + " from " + dist, e);
        }
    }

    /**
     * Calls {@code unilaterally(d)} on the distribution at {@code iset} (presumably making
     * {@code d} the only action with positive probability - TODO confirm against Distribution).
     *
     * @param iset the information set to update
     * @param d the bet amount to select
     * @throws IllegalArgumentException if {@code iset} has no distribution
     */
    public void setAction(InfoSet iset, Double d) {
        Distribution<Double> dist = strategy.get(iset);
        if (dist == null) {
            throw new IllegalArgumentException("No distribution for info set " + iset);
        }
        dist.unilaterally(d);
    }

    /**
     * Averages every distribution of this strategy in place with the matching distribution of
     * {@code s}, via {@code Distribution.average(other, w)} (presumably a weighted blend where
     * {@code w} weights one side - TODO confirm). Information sets present only in {@code s}
     * are ignored.
     *
     * @param s the strategy to average with; must contain every information set of this one
     * @param w the weight passed through to {@code Distribution.average}
     * @throws IllegalArgumentException if {@code s} lacks one of this strategy's information
     *     sets; in that case no distribution is modified
     */
    public void average(Strategy s, double w) {
        assert strategy.keySet().size() == s.strategy.keySet().size();
        // Validate first so a mismatch cannot leave this strategy partially averaged.
        for (InfoSet is : strategy.keySet()) {
            if (!s.strategy.containsKey(is)) {
                throw new IllegalArgumentException(
                        "Strategy to average with has no distribution for info set " + is);
            }
        }
        for (InfoSet is : strategy.keySet()) {
            strategy.get(is).average(s.strategy.get(is), w);
        }
    }

    /**
     * Renders one line per information set, sorted by the information set's string form: the
     * information set left-justified in 60 columns, a space, then its distribution. Lines are
     * separated by {@code '\n'} with no trailing newline; an empty strategy renders as "".
     */
    @Override
    public String toString() {
        List<InfoSet> isets = new ArrayList<InfoSet>(strategy.keySet());
        Collections.sort(isets, BY_STRING);
        StringBuilder sb = new StringBuilder();
        Formatter formatter = new Formatter(sb);
        for (InfoSet iset : isets) {
            formatter.format("%-60s %s\n", iset, strategy.get(iset));
        }
        // Drop the trailing newline; an empty strategy has none, and setLength(-1) would throw.
        if (sb.length() > 0) {
            sb.setLength(sb.length() - 1);
        }
        return sb.toString();
    }
}