ppo_model.py
from typing import List, Tuple

import pytorch_lightning as pl
from networks import create_mlp, ActorCriticAgent, ActorCategorical, ActorContinous
from data import ExperienceSourceDataset
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.optimizer import Optimizer

try:
    import gym
except ModuleNotFoundError:
    _GYM_AVAILABLE = False
else:
    _GYM_AVAILABLE = True


class PPO(pl.LightningModule):
    """
    PyTorch Lightning implementation of `PPO
    <https://arxiv.org/abs/1707.06347>`_

    Paper authors: John Schulman, Filip Wolski, Prafulla Dhariwal, Alec Radford, Oleg Klimov

    Example:
        model = PPO("CartPole-v0")

    Train:
        trainer = Trainer()
        trainer.fit(model)

    Note:
        This example is based on:
        https://github.com/openai/baselines/blob/master/baselines/ppo2/ppo2.py
        https://github.com/PyTorchLightning/pytorch-lightning-bolts/blob/master/pl_bolts/models/rl/reinforce_model.py
    """

    def __init__(
        self,
        env: str,
        gamma: float = 0.99,
        lam: float = 0.95,
        lr_actor: float = 3e-4,
        lr_critic: float = 1e-3,
        max_episode_len: int = 1000,
        batch_size: int = 512,
        steps_per_epoch: int = 2048,
        nb_optim_iters: int = 4,
        clip_ratio: float = 0.2,
    ) -> None:
        """
        Args:
            env: gym environment tag
            gamma: discount factor
            lam: advantage discount factor (lambda in the paper)
            lr_actor: learning rate of the actor network
            lr_critic: learning rate of the critic network
            max_episode_len: maximum number of interactions (actions) in an episode
            batch_size: batch size used when training the networks; together with
                steps_per_epoch it determines the number of policy updates per epoch
            steps_per_epoch: how many state-action pairs to roll out for trajectory collection per epoch
            nb_optim_iters: how many steps of gradient descent to perform on each batch
            clip_ratio: hyperparameter for clipping in the policy objective
        """
        super().__init__()

        if not _GYM_AVAILABLE:
            raise ModuleNotFoundError('This module requires gym, which is not installed yet.')

        # Hyperparameters
        self.lr_actor = lr_actor
        self.lr_critic = lr_critic
        self.steps_per_epoch = steps_per_epoch
        self.nb_optim_iters = nb_optim_iters
        self.batch_size = batch_size
        self.gamma = gamma
        self.lam = lam
        self.max_episode_len = max_episode_len
        self.clip_ratio = clip_ratio
        self.save_hyperparameters()

        self.env = gym.make(env)

        # value network
        self.critic = create_mlp(self.env.observation_space.shape, 1)

        # policy network (agent)
        if isinstance(self.env.action_space, gym.spaces.box.Box):
            act_dim = self.env.action_space.shape[0]
            actor_mlp = create_mlp(self.env.observation_space.shape, act_dim)
            self.actor = ActorContinous(actor_mlp, act_dim)
        elif isinstance(self.env.action_space, gym.spaces.discrete.Discrete):
            actor_mlp = create_mlp(self.env.observation_space.shape, self.env.action_space.n)
            self.actor = ActorCategorical(actor_mlp)
        else:
            raise NotImplementedError(
                'Env action space should be of type Box (continuous) or Discrete (categorical). '
                f'Got type: {type(self.env.action_space)}'
            )

        self.agent = ActorCriticAgent(self.actor, self.critic)

        # buffers holding a full epoch of trajectory data (steps_per_epoch transitions)
        self.batch_states = []
        self.batch_actions = []
        self.batch_adv = []
        self.batch_qvals = []
        self.batch_logp = []

        # per-episode buffers, reset whenever an episode ends
        self.ep_rewards = []
        self.ep_values = []
        self.epoch_rewards = []

        self.episode_step = 0
        self.avg_ep_reward = 0
        self.avg_ep_len = 0
        self.avg_reward = 0

        self.state = torch.FloatTensor(self.env.reset())

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Passes the state x through the network and returns the policy, a sampled action,
        and the critic's value estimate

        Args:
            x: environment state

        Returns:
            Tuple of policy, action, and value
        """
        pi, action = self.actor(x)
        value = self.critic(x)

        return pi, action, value

    def discount_rewards(self, rewards: List[float], discount: float) -> List[float]:
        """Calculate the discounted cumulative rewards for a list of rewards

        Args:
            rewards: list of rewards/advantages
            discount: discount factor

        Returns:
            list of discounted rewards/advantages
        """
        assert isinstance(rewards[0], float)

        cumul_reward = []
        sum_r = 0.0

        for r in reversed(rewards):
            sum_r = (sum_r * discount) + r
            cumul_reward.append(sum_r)

        return list(reversed(cumul_reward))
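
    # Illustrative example (not from the original file): with discount = 0.5 and
    # rewards = [1.0, 1.0, 1.0], discount_rewards returns [1.75, 1.5, 1.0], since
    # each entry is its own reward plus 0.5 times the discounted sum of everything after it.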

    def calc_advantage(self, rewards: List[float], values: List[float], last_value: float) -> List[float]:
        """Calculate the advantage given rewards, state values, and the last value of the episode

        Args:
            rewards: list of episode rewards
            values: list of state values from critic
            last_value: value of last state of episode

        Returns:
            list of advantages
        """
        rews = rewards + [last_value]
        vals = values + [last_value]

        # GAE
        delta = [rews[i] + self.gamma * vals[i + 1] - vals[i] for i in range(len(rews) - 1)]
        adv = self.discount_rewards(delta, self.gamma * self.lam)

        return adv
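
    # Illustrative example (not from the original file): with gamma = 0.99, rewards = [1.0],
    # values = [0.5], and last_value = 0.2, the single TD error is
    # delta = 1.0 + 0.99 * 0.2 - 0.5 = 0.698, so adv = [0.698]. For longer episodes the
    # deltas are further discounted by gamma * lam (generalized advantage estimation).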

    def train_batch(
        self,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """
        Contains the logic for generating trajectory data to train the policy and value networks

        Yields:
            Tuple of Lists containing tensors for states, actions, log probs, qvals and advantage
        """
        for step in range(self.steps_per_epoch):
            pi, action, log_prob, value = self.agent(self.state, self.device)
            next_state, reward, done, _ = self.env.step(action.cpu().numpy())

            self.episode_step += 1

            self.batch_states.append(self.state)
            self.batch_actions.append(action)
            self.batch_logp.append(log_prob)

            self.ep_rewards.append(reward)
            self.ep_values.append(value.item())

            self.state = torch.FloatTensor(next_state)

            epoch_end = step == (self.steps_per_epoch - 1)
            terminal = len(self.ep_rewards) == self.max_episode_len

            if epoch_end or done or terminal:
                # if the trajectory ends abruptly, bootstrap the value of the next state
                if (terminal or epoch_end) and not done:
                    with torch.no_grad():
                        _, _, _, value = self.agent(self.state, self.device)
                        last_value = value.item()
                        steps_before_cutoff = self.episode_step
                else:
                    last_value = 0
                    steps_before_cutoff = 0

                # discounted cumulative reward
                self.batch_qvals += self.discount_rewards(self.ep_rewards + [last_value], self.gamma)[:-1]
                # advantage
                self.batch_adv += self.calc_advantage(self.ep_rewards, self.ep_values, last_value)
                # logs
                self.epoch_rewards.append(sum(self.ep_rewards))
                # reset params
                self.ep_rewards = []
                self.ep_values = []
                self.episode_step = 0
                self.state = torch.FloatTensor(self.env.reset())

            if epoch_end:
                train_data = zip(
                    self.batch_states, self.batch_actions, self.batch_logp,
                    self.batch_qvals, self.batch_adv)

                for state, action, logp_old, qval, adv in train_data:
                    yield state, action, logp_old, qval, adv

                self.batch_states.clear()
                self.batch_actions.clear()
                self.batch_adv.clear()
                self.batch_logp.clear()
                self.batch_qvals.clear()

                # logging
                self.avg_reward = sum(self.epoch_rewards) / self.steps_per_epoch

                # if the epoch ended abruptly, exclude the last cut-short episode to prevent skewed stats
                epoch_rewards = self.epoch_rewards
                if not done:
                    epoch_rewards = epoch_rewards[:-1]

                total_epoch_reward = sum(epoch_rewards)
                nb_episodes = len(epoch_rewards)

                self.avg_ep_reward = total_epoch_reward / nb_episodes
                self.avg_ep_len = (self.steps_per_epoch - steps_before_cutoff) / nb_episodes

                self.epoch_rewards.clear()
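
    # Data-flow note (added for clarity): each call to train_batch rolls out
    # steps_per_epoch transitions and then yields them one at a time.
    # ExperienceSourceDataset wraps this generator, and the DataLoader built in
    # _dataloader collates batch_size samples per training batch (with the
    # defaults, 2048 / 512 = 4 batches per rollout).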

    def actor_loss(self, state, action, logp_old, qval, adv) -> torch.Tensor:
        pi, _ = self.actor(state)
        logp = self.actor.get_log_prob(pi, action)
        ratio = torch.exp(logp - logp_old)
        clip_adv = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * adv
        loss_actor = -(torch.min(ratio * adv, clip_adv)).mean()
        return loss_actor
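
    # The method above implements the clipped surrogate objective from the PPO paper:
    #   L_CLIP = E[ min(r(theta) * A, clip(r(theta), 1 - eps, 1 + eps) * A) ],
    # where r(theta) = pi_theta(a|s) / pi_theta_old(a|s) (computed here as
    # exp(logp - logp_old)) and eps = clip_ratio; the returned loss is the negative
    # of this expectation so that minimising the loss maximises the objective.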

    def critic_loss(self, state, action, logp_old, qval, adv) -> torch.Tensor:
        value = self.critic(state)
        loss_critic = (qval - value).pow(2).mean()
        return loss_critic
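
    # The critic loss above is the mean squared error between the discounted return
    # (qval) and the critic's value prediction; action, logp_old, and adv are unused
    # here but kept so both loss functions share the same signature.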

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx, optimizer_idx):
        """
        Carries out a single update to the actor and critic networks from a batch of replay buffer data

        Args:
            batch: batch of replay buffer/trajectory data
            batch_idx: not used
            optimizer_idx: idx that controls optimizing the actor or critic network

        Returns:
            loss
        """
        state, action, old_logp, qval, adv = batch

        # normalize advantages
        adv = (adv - adv.mean()) / adv.std()

        self.log("avg_ep_len", self.avg_ep_len, prog_bar=True, on_step=False, on_epoch=True)
        self.log("avg_ep_reward", self.avg_ep_reward, prog_bar=True, on_step=False, on_epoch=True)
        self.log("avg_reward", self.avg_reward, prog_bar=True, on_step=False, on_epoch=True)

        if optimizer_idx == 0:
            loss_actor = self.actor_loss(state, action, old_logp, qval, adv)
            self.log('loss_actor', loss_actor, on_step=False, on_epoch=True, prog_bar=True, logger=True)
            return loss_actor

        elif optimizer_idx == 1:
            loss_critic = self.critic_loss(state, action, old_logp, qval, adv)
            self.log('loss_critic', loss_critic, on_step=False, on_epoch=True, prog_bar=False, logger=True)
            return loss_critic

    def configure_optimizers(self) -> List[Optimizer]:
        """Initialize the Adam optimizers for the actor and critic networks"""
        optimizer_actor = optim.Adam(self.actor.parameters(), lr=self.lr_actor)
        optimizer_critic = optim.Adam(self.critic.parameters(), lr=self.lr_critic)

        return [optimizer_actor, optimizer_critic]

    def optimizer_step(self, *args, **kwargs):
        """
        Run 'nb_optim_iters' iterations of gradient descent on the actor and critic
        for each data batch.
        """
        for i in range(self.nb_optim_iters):
            super().optimizer_step(*args, **kwargs)

    def _dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences"""
        dataset = ExperienceSourceDataset(self.train_batch)
        dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size)
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader"""
        return self._dataloader()
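

# Minimal usage sketch (illustrative, not part of the original module). It follows the
# Example/Train snippet in the class docstring; the environment id "CartPole-v0" comes
# from that docstring, while the seed value and max_epochs are assumptions chosen here.
if __name__ == "__main__":
    pl.seed_everything(123)               # assumed seed, only for reproducibility of the sketch
    model = PPO("CartPole-v0")            # environment tag taken from the class docstring example
    trainer = pl.Trainer(max_epochs=10)   # assumed epoch budget
    trainer.fit(model)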