-
Notifications
You must be signed in to change notification settings - Fork 8
/
gae.py
32 lines (29 loc) · 1.24 KB
/
gae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
"""
Code for computing Generalized Advantage Estimation.
"""
import numpy as np
def gae(rewards, values, episode_ends, gamma, lam):
    """Compute generalized advantage estimate.

    rewards: a list of rewards at each step.
    values: the value estimate of the state at each step.
    episode_ends: an array of the same shape as rewards, with a 1 if the
        episode ended at that step and a 0 otherwise.
    gamma: the discount factor.
    lam: the GAE lambda parameter.

    Returns an array of the same shape as rewards holding the advantage at
    each step. NOTE(review): the final timestep's advantage is always left
    at zero because there is no value estimate beyond the horizon to
    bootstrap from — presumably callers roll out one extra step; confirm.
    """
    # Flip the flag so it reads 1 while the episode continues, 0 where it
    # ended. This lets us multiply it in as a mask below.
    continues = 1 - episode_ends
    num_envs, horizon = rewards.shape
    advantages = np.zeros((num_envs, horizon))
    running_gae = np.zeros(num_envs)
    # Sweep backwards through time, accumulating discounted TD errors.
    for step in range(horizon - 2, -1, -1):
        mask = continues[:, step]
        # One-step TD error; the bootstrap term is zeroed where the
        # episode terminated at this step.
        td_error = (
            rewards[:, step]
            + gamma * mask * values[:, step + 1]
            - values[:, step]
        )
        # GAE recursion: discount the running accumulator, reset it to
        # zero across episode boundaries, then add this step's TD error.
        running_gae = td_error + gamma * lam * mask * running_gae
        advantages[:, step] = running_gae
    return advantages