"""
Play a pre-trained model on NHL 94
"""

# Standard library
import os
import sys
import datetime
import argparse
import logging

# Third-party
import retro
import numpy as np
import pygame

# Project modules
from common import get_model_file_name, com_print, init_logger
from models import print_model_info, get_num_parameters, get_model_probabilities
from envs import init_env, init_play_env
import game_wrappers_mgr as games


def parse_cmdline(argv):
    """Parse command-line options for the player-vs-model session."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--alg', type=str, default='ppo2')
    parser.add_argument('--model1_desc', type=str, default='CNN')
    parser.add_argument('--nn', type=str, default='MlpPolicy')
    parser.add_argument('--nnsize', type=int, default=256)
    parser.add_argument('--env', type=str, default='NHL941on1-Genesis')
    parser.add_argument('--state', type=str, default=None)
    parser.add_argument('--num_players', type=int, default=2)
    parser.add_argument('--num_env', type=int, default=1)
    parser.add_argument('--num_timesteps', type=int, default=0)
    parser.add_argument('--output_basedir', type=str, default='~/OUTPUT')
    parser.add_argument('--model_1', type=str, default='')
    parser.add_argument('--model_2', type=str, default='')
    parser.add_argument('--display_width', type=int, default=1920)
    parser.add_argument('--display_height', type=int, default=1080)
    parser.add_argument('--fullscreen', default=False, action='store_true')
    # Note: with default=True and action='store_true', passing --deterministic has
    # no effect; predictions are always deterministic unless this default is changed.
    parser.add_argument('--deterministic', default=True, action='store_true')
    parser.add_argument('--rf', type=str, default='')
    #parser.add_argument('--useframeskip', default=False, action='store_true')

    args = parser.parse_args(argv)

    return args


class PlayerVsModel:
    """Runs a game loop in which a human player faces a trained model."""

    def __init__(self, args, logger, need_display=True):
        # Environment the model observes (player 1) and the environment used for
        # display and human (player 2) input.
        self.p1_env = init_env(None, 1, args.state, 1, args, True)
        self.display_env = init_play_env(args, 2, False, need_display, False)

        self.ai_sys = games.wrappers.ai_sys(args, self.p1_env, logger)
        if args.model_1 != '' or args.model_2 != '':
            models = [args.model_1, args.model_2]
            self.ai_sys.SetModels(models)

        self.need_display = need_display
        self.args = args

    def play(self, continuous=True, need_reset=True):
        state = self.display_env.reset()
        total_rewards = 0
        skip_frames = 0
        p1_actions = []
        p2_actions = []
        info = None

        while True:
            # The model chooses player 1's actions; the human's actions for
            # player 2 come from the display environment's input handling.
            p1_actions = self.ai_sys.predict(state, info=info, deterministic=self.args.deterministic)
            p2_actions = self.display_env.player_actions

            self.display_env.action_probabilities = []
            actions = np.append(p1_actions, p2_actions)

            # Repeat the chosen actions for four consecutive frames.
            for i in range(4):
                self.display_env.set_ai_sys_info(self.ai_sys)
                state, reward, done, info = self.display_env.step([actions])
                total_rewards += reward

                if done:
                    if continuous:
                        if need_reset:
                            state = self.display_env.reset()
                    else:
                        return info, total_rewards


def main(argv):
    args = parse_cmdline(argv[1:])
    logger = init_logger(args)

    games.wrappers.init(args)

    player = PlayerVsModel(args, logger)

    com_print('========= Start of Game Loop ==========')
    com_print('Press ESC or Q to quit')

    player.play(need_reset=False)


if __name__ == '__main__':
    main(sys.argv)