-
Notifications
You must be signed in to change notification settings - Fork 21
/
train.py
67 lines (53 loc) · 1.7 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from os import path
import configparser
import numpy as np
import random
import gym
import gym_flock
import torch
import sys
from learner.gnn_cloning import train_cloning
from learner.gnn_dagger import train_dagger
from learner.gnn_baseline import train_baseline
def run_experiment(args):
    """Create the configured environment, seed the RNGs, and run one trainer.

    ``args`` is read through ``get``/``getint`` (``main`` passes a
    ``configparser`` section). The keys read here are ``env`` (gym environment
    id), ``seed`` (integer) and ``alg`` (trainer name, case-insensitive); the
    whole mapping is also forwarded to the selected trainer.

    Returns:
        Whatever the selected ``train_dagger`` / ``train_cloning`` /
        ``train_baseline`` call returns.

    Raises:
        Exception: if ``alg`` is not one of ``dagger``, ``cloning`` or
            ``baseline``.
    """
    environment = gym.make(args.get('env'))

    # Only the relative flocking environment is configured from the same args.
    if isinstance(environment.env, gym_flock.envs.FlockingRelativeEnv):
        environment.env.params_from_cfg(args)

    # Apply one seed to the environment and to every RNG library in use.
    rng_seed = args.getint('seed')
    for apply_seed in (environment.seed, random.seed, np.random.seed,
                       torch.manual_seed):
        apply_seed(rng_seed)

    # Use the first CUDA device when one is available, otherwise the CPU.
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    algorithm = args.get('alg').lower()
    if algorithm == 'dagger':
        return train_dagger(environment, args, device)
    if algorithm == 'cloning':
        return train_cloning(environment, args, device)
    if algorithm == 'baseline':
        # The baseline trainer is called without a device argument.
        return train_baseline(environment, args)
    raise Exception('Invalid algorithm/mode name')
def main():
    """Run the experiment(s) described by the config file given on the command line.

    ``sys.argv[1]`` is the config file name; it is resolved relative to the
    directory containing this script. When the file defines named sections,
    each section is run as a separate experiment: the ``header`` value of the
    first section is printed once (``None`` if that key is absent), followed
    by one ``"<section>, <mean>, <std>"`` line per section, taken from the
    ``'mean'`` and ``'std'`` entries of the returned stats. When the file has
    no named sections, the default section is run once and its result is
    printed as-is.

    Raises:
        SystemExit: if no config file name was supplied.
        FileNotFoundError: if the config file is missing or cannot be read.
    """
    if len(sys.argv) < 2:
        sys.exit("usage: train.py <config_file>")

    config_file = path.join(path.dirname(__file__), sys.argv[1])
    config = configparser.ConfigParser()

    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it actually parsed, so an empty result means the config
    # was never loaded. Fail here rather than deep inside the experiment.
    if not config.read(config_file):
        raise FileNotFoundError(
            "Could not read config file: {}".format(config_file))

    section_names = config.sections()
    if not section_names:
        print(run_experiment(config[config.default_section]))
        return

    # The header is taken from the first section and printed only once.
    print(config[section_names[0]].get('header'))
    for section_name in section_names:
        stats = run_experiment(config[section_name])
        print(section_name + ", " + str(stats['mean']) + ", " + str(stats['std']))
# Run the experiments only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()