import copy

import numpy as np

import skml_config


class Random:
    """Agent that ignores observations and acts via a provided sampler."""

    def __init__(self, random_func):
        self.random_func = random_func

    def act(self, obs):
        return self.random_func()
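
# A minimal usage sketch (hedged: `env` is a hypothetical Gym-style
# environment, not defined in this file):
#
#     agent = Random(env.action_space.sample)
#     action = agent.act(obs)  # obs is ignored; a random action is returned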


class DQN:
    """Deep Q-Network agent with experience replay and a target network."""

    def __init__(
        self,
        q_function,
        replay_buffer,
        explorer,
        gamma=0.99,
        train_interval=1,
        sync_target_interval=2,
        replay_size=1024,
        batch_size=32,
        epochs=1,
    ):
        assert replay_size <= replay_buffer.capacity
        self.q_function = q_function
        self.replay_buffer = replay_buffer
        self.explorer = explorer
        self.gamma = gamma  # discount factor for future rewards
        self.train_interval = train_interval  # train every N episodes
        self.sync_target_interval = sync_target_interval  # sync target net every N episodes
        self.replay_size = replay_size  # transitions sampled per training pass
        self.batch_size = batch_size
        self.epochs = epochs
        self.t = 0  # episode counter
        self.last_obs = None
        self.last_action = None
        self.target_q_function = None  # set immediately by sync_target_q_function()
        self.sync_target_q_function()

    def act(self, obs):
        """Return the greedy action for a single observation."""
        obs = np.reshape(obs, (1, *obs.shape))  # add a batch dimension
        action = np.argmax(self.q_function.predict(obs)[0])
        return action
        # Earlier experiment, kept for reference:
        # sample_obs, *_ = self.replay_buffer.sample(1)
        # obs = [obs] + [np.array(sample_obs).reshape(-1)]
        # obs = np.array(obs)
        # action = np.argmax(self.q_function.predict(obs, True)[0])
        # return action

    def act_and_add_experience(self, obs, reward):
        """Select an action via the explorer and store the previous transition."""
        if self.last_obs is not None:
            self.replay_buffer.add(self.last_obs, self.last_action, reward, obs)
        action = self.explorer.select_action(self.t, obs, self.act)
        self.last_obs = obs
        self.last_action = action
        return action

    def stop_episode_and_train(self, reward):
        """Store the terminal transition (next_obs=None) and train periodically."""
        self.replay_buffer.add(self.last_obs, self.last_action, reward, None)
        self.last_obs = None
        self.last_action = None
        self.t += 1
        if self.t % self.train_interval == 0:
            self.train()
        if self.t % self.sync_target_interval == 0:
            self.sync_target_q_function()

    def train(self):
        """Fit the online network toward Double-DQN targets on a replay sample."""
        if len(self.replay_buffer) == 0:
            return
        targets = []
        observations, actions, rewards, next_observations = self.replay_buffer.sample(self.replay_size)
        for observation, action, reward, next_observation in zip(observations, actions, rewards, next_observations):
            target = reward
            if next_observation is not None:  # non-terminal transition: bootstrap
                next_observation = np.reshape(next_observation, (1, *next_observation.shape))
                # Double DQN style: the online network selects the next action,
                # the target network evaluates it.
                next_action = np.argmax(self.q_function.predict(next_observation)[0])
                target += self.gamma * self.target_q_function.predict(next_observation)[0][next_action]
            observation = np.reshape(observation, (1, *observation.shape))
            # Start from the current predictions so only the taken action's
            # Q-value is pushed toward the target.
            targets.append(self.q_function.predict(observation)[0])
            targets[-1][action] = target
        self.q_function.fit(np.array(observations), np.array(targets), batch_size=self.batch_size, epochs=self.epochs, verbose=0)

    def sync_target_q_function(self):
        """Refresh the target network with a deep copy of the online network."""
        self.target_q_function = copy.deepcopy(self.q_function)
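
# A minimal wiring sketch (hedged: `build_q_network`, `ReplayBuffer`,
# `EpsilonGreedy`, and `env` are hypothetical stand-ins for this project's
# actual Keras-style model, buffer, explorer, and environment; only the DQN
# interface above is taken from this file):
#
#     agent = DQN(
#         q_function=build_q_network(),            # object with .predict()/.fit()
#         replay_buffer=ReplayBuffer(capacity=4096),
#         explorer=EpsilonGreedy(epsilon=0.1),
#     )
#     for episode in range(n_episodes):
#         obs, reward, done = env.reset(), 0.0, False
#         while not done:
#             action = agent.act_and_add_experience(obs, reward)
#             obs, reward, done, _ = env.step(action)
#         agent.stop_episode_and_train(reward)    # train/sync per the intervals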