-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
91 lines (77 loc) · 2.74 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import pickle
import numpy as np
from baselines.clinical_bandit import ClinicalBandit
from baselines.constant_bandit import ConstantBandit
from baselines.random_bandit import RandomBandit
from knn_ucb import KNNKLBandit, KNNUCBBandit
from lasso import LassoBandit
from linucb import WarfarinLinUCB
from runner import BaseRunner, HyperRunner, RandomRunner
from thompson import WarfarinThompsonSeparate
from utils import get_argument_parser
np.random.seed(seed=10)
# usage: main.py dataset/clean.csv linear --alpha 0.5
BANDIT_MAP = {
"constant": ConstantBandit,
"clinical": ClinicalBandit,
"linear": WarfarinLinUCB,
"thompson": WarfarinThompsonSeparate,
"lasso": LassoBandit,
"knnucb": KNNUCBBandit,
"knnkl": KNNKLBandit,
"random": RandomBandit
}
NON_CONSTANT_BANDITS = [
WarfarinThompsonSeparate, WarfarinLinUCB, ClinicalBandit, RandomBandit,
LassoBandit, KNNUCBBandit
]
def run(args):
# run once
if args.bandit == "hyper":
policies = [policy() for policy in NON_CONSTANT_BANDITS]
runner = HyperRunner(args.datafile, args.alpha, args.process, policies,
[1, 3])
result, policy_counts = runner.run()
elif args.bandit == "randhyper":
policies = [policy() for policy in NON_CONSTANT_BANDITS]
runner = RandomRunner(args.datafile, args.alpha, args.process,
policies, [1, 3])
result, policy_counts = runner.run()
else:
assert args.bandit in BANDIT_MAP
runner = BaseRunner(args.datafile, args.alpha, args.process)
bandit = BANDIT_MAP[args.bandit]()
result = runner.run_bandit(bandit)
policy_counts = None
return result, policy_counts
def runten(args):
regrets = []
corr_fracs = []
counts = []
policy_counts = []
for i in range(10):
print("running %d" % i)
result, policy_count = run(args)
regret, corr_frac, count = result
regrets.append(regret[-1])
corr_fracs.append(corr_frac[-1])
counts.append(count)
policy_counts.append(policy_count)
f = open(args.result_file, "wb")
pickle.dump((regrets, corr_fracs, counts, policy_counts), f)
print("Regret: ", np.mean(regrets))
print("Correct Fraction: ", np.mean(corr_fracs))
def run_single(args):
result, policy_count = run(args)
regret, corr_frac, count = result
f = open(args.result_file, "wb")
pickle.dump((regret, corr_frac, count, policy_count), f)
print("Regret: %s, Correct fraction %s, counter: %s" %
(regret[-1], corr_frac[-1], str(count.items())))
if __name__ == "__main__":
parser = get_argument_parser()
args = parser.parse_args()
if args.runten:
runten(args)
else:
run_single(args)