storage.py
import torch


class PCTRolloutStorage(object):
    """Rollout storage for n-step on-policy training."""

    def __init__(self, num_steps, num_processes, obs_shape, gamma):
        # Buffers of length num_steps + 1 keep an extra slot for the bootstrap entry.
        self.obs = torch.zeros(num_steps + 1, num_processes, *obs_shape)
        self.rewards = torch.zeros(num_steps, num_processes, 1)
        self.returns = torch.zeros(num_steps + 1, num_processes, 1)
        self.action_log_probs = torch.zeros(num_steps, num_processes, 1)
        self.actions = torch.zeros(num_steps, num_processes, 1).long()
        # masks[t] = 0 if the episode ended at step t - 1, else 1.
        self.masks = torch.ones(num_steps + 1, num_processes, 1)
        self.num_steps = num_steps
        self.gamma = gamma
        self.step = 0

    def to(self, device):
        # Move all buffers to the given device.
        self.obs = self.obs.to(device)
        self.rewards = self.rewards.to(device)
        self.returns = self.returns.to(device)
        self.action_log_probs = self.action_log_probs.to(device)
        self.actions = self.actions.to(device)
        self.masks = self.masks.to(device)

    def cuda(self):
        # Legacy helper; equivalent to self.to('cuda').
        self.obs = self.obs.cuda()
        self.rewards = self.rewards.cuda()
        self.returns = self.returns.cuda()
        self.action_log_probs = self.action_log_probs.cuda()
        self.actions = self.actions.cuda()
        self.masks = self.masks.cuda()

    def insert(self, obs, actions, action_log_probs, rewards, masks):
        # Write one transition and advance the circular step pointer.
        self.obs[self.step + 1].copy_(obs)
        self.actions[self.step].copy_(actions)
        self.action_log_probs[self.step].copy_(action_log_probs)
        self.rewards[self.step].copy_(rewards)
        self.masks[self.step + 1].copy_(masks)
        self.step = (self.step + 1) % self.num_steps

    def after_update(self):
        # Carry the last observation and mask over as the start of the next rollout.
        self.obs[0].copy_(self.obs[-1])
        self.masks[0].copy_(self.masks[-1])

    def compute_returns(self, next_value):
        # Discounted n-step returns: R_t = r_t + gamma * mask_{t+1} * R_{t+1},
        # bootstrapped from next_value at the end of the rollout.
        self.returns[-1] = next_value
        for step in reversed(range(self.rewards.size(0))):
            self.returns[step] = self.returns[step + 1] * \
                self.gamma * self.masks[step + 1] + self.rewards[step]
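

if __name__ == "__main__":
    # Hypothetical smoke test, not part of the original file: fill the buffer with
    # random data instead of a real policy/environment, just to sketch how the
    # storage API is typically driven.
    num_steps, num_processes, obs_shape, gamma = 5, 4, (8,), 0.99
    rollouts = PCTRolloutStorage(num_steps, num_processes, obs_shape, gamma)
    rollouts.obs[0].copy_(torch.randn(num_processes, *obs_shape))
    for _ in range(num_steps):
        obs = torch.randn(num_processes, *obs_shape)
        actions = torch.randint(0, 10, (num_processes, 1))
        action_log_probs = torch.randn(num_processes, 1)
        rewards = torch.randn(num_processes, 1)
        masks = torch.ones(num_processes, 1)  # no episode ends in this toy run
        rollouts.insert(obs, actions, action_log_probs, rewards, masks)
    next_value = torch.zeros(num_processes, 1)  # bootstrap value of the final state
    rollouts.compute_returns(next_value)
    print(rollouts.returns.shape)  # torch.Size([6, 4, 1])
    rollouts.after_update()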