pingpong_agent.py · 128 lines (89 loc) · 3.74 KB

# Reference: https://www.youtube.com/watch?v=L8ypSXwyBds&t=1358s
import torch
from collections import deque
import random
import numpy as np
from pingpong_game import Game, Paddle_dir, Point
from pingpong_model import DQN, DQN_trainer

MAX_MEMORY = 10_000   # replay buffer capacity
BATCH_SIZE = 100      # mini-batch size drawn from the replay buffer
LR = 0.001            # learning rate


class DQN_agent:
    def __init__(self):
        self.n_games = 0
        self.gamma = 0.9
        self.epsilon = 0
        self.model = DQN(input_shape=12, hidden_units=256, output_shape=3)
        self.trainer = DQN_trainer(lr=LR, gamma=self.gamma, model=self.model)
        self.memory = deque(maxlen=MAX_MEMORY)

    def get_state(self, game):
        # Encode the game as 12 binary features:
        # [ball_left, ball_straight, ball_right,
        #  ball_not_in_paddle_half, ball_in_paddle_half,
        #  ball_down_right, ball_up_left, ball_up_right, ball_down_left,
        #  paddle_left, paddle_stay, paddle_right]
        paddle_x = game.paddle.x
        ball_x, ball_y = game.ball.x, game.ball.y
        ball_dir_vertical = game.ball_move_val_y
        ball_dir_horizontal = game.ball_move_val_x
        state = [
            # location of ball relative to the paddle (left of it, in line with it, right of it)
            ball_x < paddle_x,
            paddle_x <= ball_x <= paddle_x + 100,
            ball_x > (paddle_x + 100),
            ball_y < game.h / 2,   # ball not in paddle's half
            ball_y > game.h / 2,   # ball in paddle's half
            # ball direction
            ball_dir_vertical > 0 and ball_dir_horizontal > 0,   # down & right
            ball_dir_vertical < 0 and ball_dir_horizontal < 0,   # up & left
            ball_dir_vertical < 0 and ball_dir_horizontal > 0,   # up & right
            ball_dir_vertical > 0 and ball_dir_horizontal < 0,   # down & left
            # current paddle direction
            game.direction == Paddle_dir.LEFT,
            game.direction == Paddle_dir.STAY,
            game.direction == Paddle_dir.RIGHT,
        ]
        return np.array(state, dtype=int)
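
    # Illustrative example (hypothetical values, not taken from a real game step):
    # with the paddle at x=200 moving right and the ball at (150, 400) travelling
    # down-right in a 600 px tall window, get_state returns
    # [1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1].
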
    def remember(self, state, action, reward, next_state, game_over):
        self.memory.append((state, action, reward, next_state, game_over))

    def train_short_memory(self, state, action, reward, next_state, game_over):
        self.trainer.train_step(state, action, reward, next_state, game_over)

    def train_long_memory(self):
        if len(self.memory) > BATCH_SIZE:
            mini_sample = random.sample(self.memory, k=BATCH_SIZE)
        else:
            mini_sample = self.memory
        states, actions, rewards, next_states, game_overs = zip(*mini_sample)
        self.trainer.train_step(states, actions, rewards, next_states, game_overs)

    def get_action(self, state):
        # Epsilon-greedy action selection: epsilon shrinks as more games are played,
        # so the agent explores early on and acts greedily after about 80 games.
        self.epsilon = 80 - self.n_games
        final_move = [0, 0, 0]
        if random.randint(0, 200) < self.epsilon:
            # explore: pick a random move
            random_idx = random.randint(0, 2)
            final_move[random_idx] = 1
        else:
            # exploit: pick the move with the highest predicted Q-value
            state = torch.tensor(state, dtype=torch.float)
            pred = self.model(state)
            idx = pred.argmax().item()
            final_move[idx] = 1
        return final_move


def train():
    record = 0
    agent = DQN_agent()
    game = Game()
    while True:
        # play one step of the game with the current policy
        state = agent.get_state(game=game)
        action = agent.get_action(state=state)
        reward, game_over, score = game.play_step(action)
        new_state = agent.get_state(game=game)

        # train on this single transition, then store it for experience replay
        agent.train_short_memory(state, action, reward, new_state, game_over)
        agent.remember(state, action, reward, new_state, game_over)

        if game_over:
            game.reset()
            agent.n_games += 1
            # replay a mini-batch of stored transitions after every game
            agent.train_long_memory()
            if score > record:
                record = score
                agent.model.save()
            print(f"Game: {agent.n_games} Score: {score} Record: {record}")


if __name__ == "__main__":
    train()
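
For reference, below is a minimal sketch of the DQN and DQN_trainer interfaces that pingpong_model.py is assumed to provide, reconstructed only from how they are used above (the constructor arguments, self.model(state), trainer.train_step(...), and model.save()); the actual module may differ.

# pingpong_model.py -- illustrative sketch, not the repository's actual implementation
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


class DQN(nn.Module):
    """Small feed-forward Q-network: 12 binary state features -> 3 action values."""

    def __init__(self, input_shape, hidden_units, output_shape):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_shape, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, output_shape),
        )

    def forward(self, x):
        return self.net(x)

    def save(self, file_name="model.pth"):
        # Hypothetical save location; only the existence of save() is implied by the agent.
        torch.save(self.state_dict(), file_name)


class DQN_trainer:
    """One gradient step of the Bellman update Q(s, a) <- r + gamma * max_a' Q(s', a')."""

    def __init__(self, lr, gamma, model):
        self.gamma = gamma
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()

    def train_step(self, state, action, reward, next_state, game_over):
        # Accept either a single transition or a batch of them (tuples of arrays/lists).
        state = torch.tensor(np.array(state), dtype=torch.float)
        next_state = torch.tensor(np.array(next_state), dtype=torch.float)
        action = torch.tensor(np.array(action), dtype=torch.long)
        reward = torch.tensor(np.array(reward), dtype=torch.float)
        if state.dim() == 1:
            # single transition: add a batch dimension so both paths look the same
            state = state.unsqueeze(0)
            next_state = next_state.unsqueeze(0)
            action = action.unsqueeze(0)
            reward = reward.unsqueeze(0)
            game_over = (game_over,)

        pred = self.model(state)          # predicted Q-values for each action
        target = pred.detach().clone()
        with torch.no_grad():
            next_q = self.model(next_state).max(dim=1).values
        for idx in range(len(game_over)):
            q_new = reward[idx]
            if not game_over[idx]:
                q_new = q_new + self.gamma * next_q[idx]
            # only the Q-value of the action actually taken gets the new target
            target[idx][torch.argmax(action[idx]).item()] = q_new

        self.optimizer.zero_grad()
        loss = self.criterion(pred, target)
        loss.backward()
        self.optimizer.step()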