training.py
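"""Training script for Azul A2C agents.

Builds a list of agents according to PLAYING_MODE (self-play, adversarial,
vs. a random baseline, vs. MCTS, or vs. a human), then plays N_EPISODES games
on the gym_azul environment, updating and periodically saving the learning
agent(s).
"""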
import gym
import os
import torch.optim as optim
from agents import A2CAgent, RandomAgent, HumanAgent, MCTSAgent
# learning parameters
N_EPISODES = 150 # a few hours on a local PC
PLAYING_MODE = 'self'
# 'self' -> self-play,
# 'adversarial' -> several networks learn against each other,
# 'random' -> against a random baseline,
# 'mcts' -> against an MCTS baseline,
# 'manual' -> against a human
TRAINING = True
GAMMA = 0.98 # discount factor (weight given to future rewards)
UPDATE_EVERY = 15 # how many moves to play before updating (15 = update about twice a game)
SAVE_EVERY = 1000 # save a model every 1000 episodes (only triggers if N_EPISODES >= SAVE_EVERY)
# learning rates
ACTOR_LR = 1e-3
CRITIC_LR = 1e-4
# network parameters
HIDDEN_DIM = 256
# game parameters
N_PLAYERS = 2
# paths
save_to = 'checkpoints_against_random' # folder to save the models into
load_from = 'checkpoints/144999.pt' # pretrained weights to start from; set to '' or None to train from scratch
# ================== SET UP ENVIRONMENT ====================================
env = gym.make("gym_azul:azul-v0", n_players=N_PLAYERS)
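# the "gym_azul:azul-v0" id makes gym import the gym_azul package before
# resolving azul-v0, so the custom Azul environment must be installed/importable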
if not os.path.exists(save_to):
    os.mkdir(save_to)
# =================== SET UP AGENTS ========================================
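# every agent below is used through the same interface in the training loop:
# play(state, env, id), next_game(won) and a .stats attribute; the A2C
# agents are additionally load()-ed, update()-d and save()-d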
agents = []
if PLAYING_MODE == 'adversarial':
    # n learning agents
    for i in range(N_PLAYERS):
        actor_optim = optim.Adam
        critic_optim = optim.Adam
        agent = A2CAgent(env, HIDDEN_DIM, actor_optim, critic_optim, ACTOR_LR, CRITIC_LR, GAMMA)
        if load_from:
            agent.load(load_from)
        agents.append(agent)
if PLAYING_MODE == 'self':
    # n times the same learning agent
    actor_optim = optim.Adam
    critic_optim = optim.Adam
    agent = A2CAgent(env, HIDDEN_DIM, actor_optim, critic_optim, ACTOR_LR, CRITIC_LR, GAMMA,
                     nb_channels=N_PLAYERS)
    if load_from:
        agent.load(load_from)
    for i in range(N_PLAYERS):
        agents.append(agent)
if PLAYING_MODE == 'random':
    # 1 learning agent and n-1 random agents
    actor_optim = optim.Adam
    critic_optim = optim.Adam
    agent = A2CAgent(env, HIDDEN_DIM, actor_optim, critic_optim, ACTOR_LR, CRITIC_LR, GAMMA)
    if load_from:
        agent.load(load_from)
    agents.append(agent)
    for i in range(N_PLAYERS - 1):
        agents.append(RandomAgent())
if PLAYING_MODE == 'mcts':
    assert N_PLAYERS == 2 # only player count supported so far
    # 1 learning agent and 1 MCTS agent
    actor_optim = optim.Adam
    critic_optim = optim.Adam
    agent = A2CAgent(env, HIDDEN_DIM, actor_optim, critic_optim, ACTOR_LR, CRITIC_LR, GAMMA)
    if load_from:
        agent.load(load_from)
    agents.append(agent)
    agents.append(MCTSAgent())
if PLAYING_MODE == 'manual':
    # 1 learning agent and 1 human
    assert N_PLAYERS == 2
    actor_optim = optim.Adam
    critic_optim = optim.Adam
    agent = A2CAgent(env, HIDDEN_DIM, actor_optim, critic_optim, ACTOR_LR, CRITIC_LR, GAMMA)
    if load_from:
        agent.load(load_from)
    agents.append(agent)
    agents.append(HumanAgent())
# ==================== ACTUAL TRAINING ==========================================
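# one episode = one full game: agents play in turn until the game is over,
# the learning agent(s) run an A2C update every UPDATE_EVERY rounds (when
# TRAINING is True), and checkpoints are written every SAVE_EVERY episodes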
for ep in range(N_EPISODES):
    state = env.reset()
    done = False
    counter = 0
    print('Game {}/{}'.format(ep + 1, N_EPISODES))
    while not done:
        update = not ((counter + 1) % UPDATE_EVERY)
        counter += 1
        for id, agent in enumerate(agents):
            if done: break
            state, done = agent.play(state, env, id)
            if update and TRAINING:
                agent.update()
    winner, score = env.get_winner()
    for id, agent in enumerate(agents):
        # has no meaning if agent is self-playing
        agent.next_game(winner == id)
    print('Game completed in {} moves. Agent {} won with score {}'.format(counter + 1, winner, score))
    if not ((ep + 1) % SAVE_EVERY):
        for agent in agents:
            agent.save('{}/{}.pt'.format(save_to, ep))
for id, agent in enumerate(agents):
# has no meaning if agent is self-playing
print('Agent {} stats: {}'.format(id, agent.stats))