-
Notifications
You must be signed in to change notification settings - Fork 0
/
experiment.py
91 lines (70 loc) · 2.61 KB
/
experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
'''
Script to run simple tabular RL experiments.
author: [email protected]
'''
import os
from shutil import copyfile

import numpy as np
import pandas as pd
def run_finite_tabular_experiment(agent, env, f_ext, nEps, seed=1,
                                  recFreq=100, fileFreq=1000,
                                  targetPath='tmp.csv'):
    '''
    Run a finite tabular MDP experiment, periodically logging to a csv file.

    Every episode resets env, asks the agent to update its policy and then
    steps the environment until env.advance reports pContinue <= 0. Reward
    and regret totals are accumulated across episodes; a snapshot row is
    recorded every few episodes and the full table of rows is rewritten to
    targetPath at the file-writing interval.

    Args:
        agent - finite tabular agent; update_policy(ep), pick_action(state, h)
                and update_obs(...) are called on it
        env - finite TabularMDP; compute_qVals(), reset(), advance(action)
              and the attributes timestep / state are used
        f_ext - trivial FeatureTrueState; get_feat(env) must return (h, state)
        nEps - number of episodes to run
               NOTE(review): the loop covers episodes 1..nEps + 1 inclusive,
               i.e. one more than nEps - confirm this is intended
        seed - numpy random seed
        recFreq - how many episodes between logging
                  NOTE(review): the value passed in is currently superseded
                  by the fixed schedule below (100 / 1000 / 10000)
        fileFreq - how many episodes between writing file (the effective
                   interval is max(fileFreq, logging interval))
        targetPath - where to write the csv

    Returns:
        None - data is output to targetPath as csv file
    '''
    data = []
    qVals, qMax = env.compute_qVals()
    np.random.seed(seed)

    cumRegret = 0
    cumReward = 0
    empRegret = 0

    # Scratch file lives next to the target and is unique per target, so the
    # default targetPath no longer collides with it and parallel runs with
    # different targets cannot overwrite each other's intermediate output.
    scratchPath = targetPath + '.tmp'

    for ep in range(1, nEps + 2):
        # Reset the environment
        env.reset()
        # Optimal value of the episode's start state, used for empRegret
        epMaxVal = qMax[env.timestep][env.state]

        agent.update_policy(ep)

        epReward = 0
        epRegret = 0
        pContinue = 1

        while pContinue > 0:
            # Step through the episode
            h, oldState = f_ext.get_feat(env)

            action = agent.pick_action(oldState, h)
            # Regret of this step: best Q-value minus Q-value of chosen action
            epRegret += qVals[oldState, h].max() - qVals[oldState, h][action]

            reward, newState, pContinue = env.advance(action)
            epReward += reward

            agent.update_obs(oldState, action, reward, newState, pContinue, h)

        cumReward += epReward
        cumRegret += epRegret
        empRegret += (epMaxVal - epReward)

        # Variable granularity: log less often as the run gets longer
        if ep < 1e4:
            recFreq = 100
        elif ep < 1e5:
            recFreq = 1000
        else:
            recFreq = 10000

        # Logging to dataframe
        if ep % recFreq == 0:
            data.append([ep, epReward, cumReward, cumRegret, empRegret])
            print ('episode:', ep, 'epReward:', epReward, 'cumRegret:', cumRegret)

        if ep % max(fileFreq, recFreq) == 0:
            dt = pd.DataFrame(data,
                              columns=['episode', 'epReward', 'cumReward',
                                       'cumRegret', 'empRegret'])
            print ('Writing to file ' + targetPath)
            # Write the whole table to the scratch file first, then move it
            # into place in one step so targetPath is never half-written.
            dt.to_csv(scratchPath, index=False, float_format='%.2f')
            os.replace(scratchPath, targetPath)
            print ('****************************')

    print ('**************************************************')
    print ('Experiment complete')
    print ('**************************************************')