-
Notifications
You must be signed in to change notification settings - Fork 0
/
chain_experiment.py
74 lines (57 loc) · 2.5 KB
/
chain_experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
'''
Script to run tabular experiments in batch mode.
author: [email protected]
'''
import numpy as np
import pandas as pd
import argparse
import sys
import environment
import finite_tabular_agents
from feature_extractor import FeatureTrueState
from experiment import run_finite_tabular_experiment
def main():
    """Run one tabular RL experiment configured from command-line flags.

    Positional command-line arguments, in order:
        chainLen -- chain length passed to ``environment.make_stochasticChain``.
        alg      -- agent name; must be one of the keys of ``alg_dict`` below.
        scaling  -- forwarded to the agent constructor as ``scaling=``.
        seed     -- forwarded to ``run_finite_tabular_experiment``.
        nEps     -- number of episodes, forwarded to
                    ``run_finite_tabular_experiment``.

    The path handed to ``run_finite_tabular_experiment`` as ``targetPath`` is
    a ``.csv`` file in the current directory whose name encodes chainLen,
    alg, scaling and seed.
    """
    # Map command-line names to agent constructors. Built before parsing so
    # argparse can reject an unknown name with a usage message listing the
    # valid choices, instead of a bare KeyError traceback after parsing.
    alg_dict = {'PSRL': finite_tabular_agents.PSRL,
                'PSRLunif': finite_tabular_agents.PSRLunif,
                'OptimisticPSRL': finite_tabular_agents.OptimisticPSRL,
                'GaussianPSRL': finite_tabular_agents.GaussianPSRL,
                'UCBVI': finite_tabular_agents.UCBVI,
                'BEB': finite_tabular_agents.BEB,
                'BOLT': finite_tabular_agents.BOLT,
                'UCRL2': finite_tabular_agents.UCRL2,
                'UCFH': finite_tabular_agents.UCFH,
                'EpsilonGreedy': finite_tabular_agents.EpsilonGreedy}

    # Take in command line flags
    parser = argparse.ArgumentParser(description='Run tabular RL experiment')
    parser.add_argument('chainLen', help='length of chain', type=int)
    parser.add_argument('alg', help='Agent constructor', type=str,
                        choices=sorted(alg_dict))
    parser.add_argument('scaling', help='scaling', type=float)
    parser.add_argument('seed', help='random seed', type=int)
    parser.add_argument('nEps', help='number of episodes', type=int)
    args = parser.parse_args()

    # Make a filename to identify flags.
    # NOTE(review): nEps is not part of the name, so runs that differ only in
    # nEps share the same targetPath (presumably overwriting each other --
    # confirm against run_finite_tabular_experiment).
    fileName = ('chainLen'
                + '_len=' + '%03.f' % args.chainLen
                + '_alg=' + str(args.alg)
                + '_scal=' + '%03.2f' % args.scaling
                + '_seed=' + str(args.seed)
                + '.csv')
    folderName = './'
    targetPath = folderName + fileName

    print ('******************************************************************')
    print (fileName)
    print ('******************************************************************')

    # Make the environment
    env = environment.make_stochasticChain(args.chainLen)

    # Make the feature extractor. The last argument reuses env.nState,
    # presumably as the feature dimension of a one-hot true-state
    # encoding -- confirm against FeatureTrueState.
    f_ext = FeatureTrueState(env.epLen, env.nState, env.nAction, env.nState)

    # Make the agent (args.alg is guaranteed to be a key by `choices`).
    agent_constructor = alg_dict[args.alg]
    agent = agent_constructor(env.nState, env.nAction, env.epLen,
                              scaling=args.scaling)

    # Run the experiment; the meaning of recFreq/fileFreq is defined by
    # experiment.run_finite_tabular_experiment.
    run_finite_tabular_experiment(agent, env, f_ext, args.nEps, args.seed,
                                  recFreq=100, fileFreq=1000,
                                  targetPath=targetPath)


if __name__ == '__main__':
    main()