-
Notifications
You must be signed in to change notification settings - Fork 11
/
model_trainer.py
129 lines (95 loc) · 4.15 KB
/
model_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
Train a Model on NHL 94
"""
import warnings
warnings.filterwarnings("ignore")
import os
import sys
import retro
import time
import datetime
import argparse
import logging
import numpy as np
from common import get_model_file_name, com_print, init_logger, create_output_dir
from models import init_model, print_model_info, get_num_parameters
from envs import init_env, init_play_env
def parse_cmdline(argv):
    """Parse the command-line options for training / playing a model.

    Args:
        argv: list of argument strings WITHOUT the program name
            (``main`` passes ``sys.argv[1:]``).

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--alg', type=str, default='ppo2')
    parser.add_argument('--nn', type=str, default='CnnPolicy')
    parser.add_argument('--nnsize', type=int, default=256)
    parser.add_argument('--env', type=str, default='NHL941on1-Genesis')
    parser.add_argument('--state', type=str, default=None)
    parser.add_argument('--num_players', type=int, default=1)
    parser.add_argument('--num_env', type=int, default=24)
    parser.add_argument('--num_timesteps', type=int, default=6000000)
    parser.add_argument('--output_basedir', type=str, default='~/OUTPUT')
    parser.add_argument('--load_p1_model', type=str, default='')
    parser.add_argument('--display_width', type=int, default=1440)
    parser.add_argument('--display_height', type=int, default=810)

    # These three flags default to True, so a plain ``store_true`` option can
    # never switch them off. Each gets a ``--no_*`` companion that writes
    # False to the same destination; the original ``--flag`` spelling keeps
    # working unchanged.
    parser.add_argument('--alg_verbose', default=True, action='store_true')
    parser.add_argument('--no_alg_verbose', dest='alg_verbose', action='store_false',
                        help='disable --alg_verbose')
    parser.add_argument('--info_verbose', default=True, action='store_true')
    parser.add_argument('--no_info_verbose', dest='info_verbose', action='store_false',
                        help='disable --info_verbose')

    parser.add_argument('--play', default=False, action='store_true')
    parser.add_argument('--rf', type=str, default='')

    parser.add_argument('--deterministic', default=True, action='store_true')
    parser.add_argument('--no_deterministic', dest='deterministic', action='store_false',
                        help='sample actions stochastically during play')

    print(argv)
    args = parser.parse_args(argv)
    return args
class ModelTrainer:
    """Build the training env and player-1 model from parsed args; train, save and play it."""

    def __init__(self, args, logger):
        """Create the output directory, the environment and the player-1 model.

        Args:
            args: parsed command-line options (see ``parse_cmdline``).
            logger: logger object; only forwarded to ``init_model``.
        """
        self.args = args
        self.output_fullpath = create_output_dir(args)
        model_savefile_name = get_model_file_name(args)
        self.model_savepath = os.path.join(self.output_fullpath, model_savefile_name)
        self.env = init_env(self.output_fullpath, args.num_env, args.state, args.num_players, args)
        self.p1_model = init_model(self.output_fullpath, args.load_p1_model, args.alg, args, self.env, logger)

        com_print('OUTPUT PATH: %s' % self.output_fullpath)
        com_print('ENV: %s' % args.env)
        com_print('STATE: %s' % args.state)
        com_print('NN: %s' % args.nn)
        com_print('ALGO: %s' % args.alg)
        com_print('NUM TIMESTEPS: %s' % args.num_timesteps)
        com_print('NUM ENV: %s' % args.num_env)
        com_print('NUM PLAYERS: %s' % args.num_players)
        print(self.env.observation_space)

    def train(self):
        """Train for ``args.num_timesteps`` steps, save the model, and return its path.

        Returns:
            Path (``self.model_savepath``) the model was saved to.
        """
        com_print('========= Start Training ==========')
        self.p1_model.learn(total_timesteps=self.args.num_timesteps)
        com_print('========= End Training ==========')
        self.p1_model.save(self.model_savepath)
        com_print('Model saved to:%s' % self.model_savepath)
        return self.model_savepath

    def play(self, args=None, continuous=True):
        """Render the env while the player-1 model acts in it.

        Args:
            args: options supplying ``deterministic``; defaults to ``self.args``.
            continuous: if True, loop forever, resetting the env after each
                episode. If False, return after the first episode ends.

        Returns:
            The ``info`` from the step that ended the episode (only when
            ``continuous`` is False; otherwise this never returns).
        """
        if args is None:
            args = self.args

        com_print('========= Start Play Loop ==========')
        state = self.env.reset()
        while True:
            self.env.render(mode='human')
            # predict() result is indexed with [0]; presumably (action, model_state)
            # as in stable-baselines -- TODO confirm against init_model.
            p1_actions = self.p1_model.predict(state, deterministic=args.deterministic)
            state, _, done, info = self.env.step(p1_actions[0])
            time.sleep(0.01)
            # ``done`` holds one flag per env, so test the first env's flag.
            # The former ``done is True`` check was never true for an array, which
            # made continuous=False loop forever.
            if done[0]:
                state = self.env.reset()
                if not continuous:
                    return info
def main(argv):
    """Script entry point: parse options, train the model, and optionally play it.

    Args:
        argv: full argument vector including the program name (``sys.argv``).
    """
    options = parse_cmdline(argv[1:])
    log = init_logger(options)

    com_print("=========== Params ===========")
    com_print(options)

    model_trainer = ModelTrainer(options, log)
    model_trainer.train()

    if not options.play:
        return
    model_trainer.play(options)
if __name__ == '__main__':
    # Hand over the full argv; main() strips the program name itself.
    main(argv=sys.argv)