train.py
import numpy as np


def collect_gameplay_experiences(env, agents, game_count, winner=True):
    """
    Plays game_count games in env with the given agents and collects the
    resulting gameplay experiences for training.

    :param env: the game environment
    :param agents: the list of agents that play the game
    :param game_count: the number of games to collect experiences from
    :param winner: if True, keep only the winning player's trajectory;
        otherwise keep the first player's trajectory
    :return: tuple of (state_batch, action_batch, score_batch, done_batch)
    """
    # Each state is the player's hand plus the current score (4 values).
    state_batch = np.zeros((0, 4), dtype=int)
    action_batch = []
    score_batch = []
    done_batch = []
    env.set_agents(agents)
    game = 0
    while game < game_count:
        env.reset()
        trajectories = env.run(is_training=False)
        winner_id = 0
        if winner:
            # Find the index of the winning trajectory for training.
            for trajectory in trajectories:
                if trajectory[-1]['win']:
                    break
                winner_id += 1
            # If no player won this game, skip it and play another one.
            if winner_id >= len(trajectories):
                continue
        game += 1
        states = []
        actions = []
        scores = []
        dones = []
        for transition in trajectories[winner_id]:
            states.append(transition['state']['hand'] +
                          [transition['state']['score']])
            actions.append(transition['action'])
            scores.append(transition['state']['score'])
            dones.append(transition['done'])
        state_batch = np.concatenate((state_batch, np.array(states)))
        action_batch += actions
        score_batch += scores
        done_batch += dones
    return (state_batch, action_batch,
            score_batch, done_batch)


def evaluate_training_result(env, agents, game_num):
    """
    Evaluates the current agents by letting them play game_num episodes
    of the game and computing each agent's average reward, counting 1.0
    for a win and 0.0 otherwise, i.e. each agent's win rate. The higher
    an agent's win rate, the better it performs.

    :param env: the game environment
    :param agents: the list of agents to evaluate
    :param game_num: the number of evaluation episodes to play
    :return: list of average rewards (win rates), one per agent
    """
    average_rewards = [0.0] * len(agents)
    env.set_agents(agents)
    for _ in range(game_num):
        env.reset()
        trajectories = env.run(is_training=False)
        # The last transition of each trajectory records whether that
        # player won the game.
        for j in range(len(trajectories)):
            if trajectories[j][-1]['win']:
                average_rewards[j] += 1.0 / game_num
    return average_rewards
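

# A minimal smoke-test sketch for the two helpers above, under stated
# assumptions: this file does not show the real environment or agents, so
# DummyEnv and DummyAgent are hypothetical stand-ins that only mimic the
# interface used here. env.run() is assumed to return one trajectory per
# agent, where each transition is a dict carrying 'state' (with 'hand'
# and 'score'), 'action', 'done', and, on the last transition, 'win'.
class DummyAgent:
    """Placeholder agent; the env sketch below never queries it."""


class DummyEnv:
    """Hypothetical env exposing set_agents(), reset() and run()."""

    def set_agents(self, agents):
        self.agents = agents

    def reset(self):
        pass

    def run(self, is_training=False):
        # One single-transition trajectory per agent; agent 0 always wins.
        def transition(win):
            return {'state': {'hand': [1, 2, 3], 'score': 0},
                    'action': 0, 'done': True, 'win': win}
        return [[transition(i == 0)] for i in range(len(self.agents))]


if __name__ == '__main__':
    env = DummyEnv()
    agents = [DummyAgent(), DummyAgent()]
    # Gather training data from the winning trajectories of 10 games.
    states, actions, scores, dones = collect_gameplay_experiences(
        env, agents, game_count=10)
    print('collected %d transitions' % states.shape[0])
    # Estimate each agent's win rate over 10 evaluation games.
    print('win rates:', evaluate_training_result(env, agents, game_num=10))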