Skip to content

Commit

Permalink
pin all the things
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Mar 24, 2024
1 parent 405481c commit 96ffe87
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,16 +273,18 @@ def __init__(
if hasattr(self.agent, "lstm"):
shape = (self.agent.lstm.num_layers, total_agents, self.agent.lstm.hidden_size)
self.next_lstm_state = (
torch.zeros(shape, device=self.device),
torch.zeros(shape, device=self.device),
torch.zeros(shape, device=self.device, pin_memory=True),
torch.zeros(shape, device=self.device, pin_memory=True),
)
self.obs = torch.zeros(config.batch_size + 1, *obs_shape, dtype=torch.uint8)
self.actions = torch.zeros(config.batch_size + 1, *atn_shape, dtype=int)
self.logprobs = torch.zeros(config.batch_size + 1)
self.rewards = torch.zeros(config.batch_size + 1)
self.dones = torch.zeros(config.batch_size + 1)
self.truncateds = torch.zeros(config.batch_size + 1)
self.values = torch.zeros(config.batch_size + 1)
self.obs = torch.zeros(
config.batch_size + 1, *obs_shape, dtype=torch.uint8, pin_memory=True
)
self.actions = torch.zeros(config.batch_size + 1, *atn_shape, dtype=int, pin_memory=True)
self.logprobs = torch.zeros(config.batch_size + 1, pin_memory=True)
self.rewards = torch.zeros(config.batch_size + 1, pin_memory=True)
self.dones = torch.zeros(config.batch_size + 1, pin_memory=True)
self.truncateds = torch.zeros(config.batch_size + 1, pin_memory=True)
self.values = torch.zeros(config.batch_size + 1, pin_memory=True)

self.obs_ary = np.asarray(self.obs, dtype=np.uint8)
self.actions_ary = np.asarray(self.actions)
Expand Down Expand Up @@ -534,7 +536,7 @@ def train(self):
)

# Flatten the batch
self.b_obs = b_obs = torch.tensor(self.obs_ary[b_idxs], dtype=torch.uint8)
self.b_obs = b_obs = torch.tensor(self.obs_ary[b_idxs], dtype=torch.uint8, pin_memory=True)
b_actions = torch.tensor(self.actions_ary[b_idxs]).to(self.device, non_blocking=True)
b_logprobs = torch.tensor(self.logprobs_ary[b_idxs]).to(self.device, non_blocking=True)
# b_dones = torch.Tensor(self.dones_ary[b_idxs]).to(self.device, non_blocking=True)
Expand All @@ -547,15 +549,11 @@ def train(self):
# Optimizing the policy and value network
train_time = time.time()
pg_losses, entropy_losses, v_losses, clipfracs, old_kls, kls = [], [], [], [], [], []
mb_obs_buffer = torch.zeros_like(
b_obs[0], pin_memory=(self.device == "cuda"), dtype=torch.uint8
)

for epoch in range(config.update_epochs):
lstm_state = None
for mb in range(num_minibatches):
mb_obs_buffer.copy_(b_obs[mb], non_blocking=True)
mb_obs = mb_obs_buffer.to(self.device, non_blocking=True)
mb_obs = b_obs[mb].to(self.device, non_blocking=True)
mb_actions = b_actions[mb].contiguous()
mb_values = b_values[mb].reshape(-1)
mb_advantages = b_advantages[mb].reshape(-1)
Expand Down

0 comments on commit 96ffe87

Please sign in to comment.