From 96ffe8704b268b7a1cafc78b29fec6e3b9a52e43 Mon Sep 17 00:00:00 2001
From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com>
Date: Sun, 24 Mar 2024 14:01:51 -0400
Subject: [PATCH] pin all the things

---
 pokemonred_puffer/cleanrl_puffer.py | 28 +++++++++++++---------------
 1 file changed, 13 insertions(+), 15 deletions(-)

diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py
index dfd98bc..c10505f 100644
--- a/pokemonred_puffer/cleanrl_puffer.py
+++ b/pokemonred_puffer/cleanrl_puffer.py
@@ -273,16 +273,18 @@ def __init__(
         if hasattr(self.agent, "lstm"):
             shape = (self.agent.lstm.num_layers, total_agents, self.agent.lstm.hidden_size)
             self.next_lstm_state = (
-                torch.zeros(shape, device=self.device),
-                torch.zeros(shape, device=self.device),
+                torch.zeros(shape, device=self.device, pin_memory=True),
+                torch.zeros(shape, device=self.device, pin_memory=True),
             )
-        self.obs = torch.zeros(config.batch_size + 1, *obs_shape, dtype=torch.uint8)
-        self.actions = torch.zeros(config.batch_size + 1, *atn_shape, dtype=int)
-        self.logprobs = torch.zeros(config.batch_size + 1)
-        self.rewards = torch.zeros(config.batch_size + 1)
-        self.dones = torch.zeros(config.batch_size + 1)
-        self.truncateds = torch.zeros(config.batch_size + 1)
-        self.values = torch.zeros(config.batch_size + 1)
+        self.obs = torch.zeros(
+            config.batch_size + 1, *obs_shape, dtype=torch.uint8, pin_memory=True
+        )
+        self.actions = torch.zeros(config.batch_size + 1, *atn_shape, dtype=int, pin_memory=True)
+        self.logprobs = torch.zeros(config.batch_size + 1, pin_memory=True)
+        self.rewards = torch.zeros(config.batch_size + 1, pin_memory=True)
+        self.dones = torch.zeros(config.batch_size + 1, pin_memory=True)
+        self.truncateds = torch.zeros(config.batch_size + 1, pin_memory=True)
+        self.values = torch.zeros(config.batch_size + 1, pin_memory=True)
 
         self.obs_ary = np.asarray(self.obs, dtype=np.uint8)
         self.actions_ary = np.asarray(self.actions)
@@ -534,7 +536,7 @@ def train(self):
         )
 
         # Flatten the batch
-        self.b_obs = b_obs = torch.tensor(self.obs_ary[b_idxs], dtype=torch.uint8)
+        self.b_obs = b_obs = torch.tensor(self.obs_ary[b_idxs], dtype=torch.uint8, pin_memory=True)
         b_actions = torch.tensor(self.actions_ary[b_idxs]).to(self.device, non_blocking=True)
         b_logprobs = torch.tensor(self.logprobs_ary[b_idxs]).to(self.device, non_blocking=True)
         # b_dones = torch.Tensor(self.dones_ary[b_idxs]).to(self.device, non_blocking=True)
@@ -547,15 +549,11 @@ def train(self):
 
         # Optimizing the policy and value network
         train_time = time.time()
         pg_losses, entropy_losses, v_losses, clipfracs, old_kls, kls = [], [], [], [], [], []
-        mb_obs_buffer = torch.zeros_like(
-            b_obs[0], pin_memory=(self.device == "cuda"), dtype=torch.uint8
-        )
         for epoch in range(config.update_epochs):
             lstm_state = None
             for mb in range(num_minibatches):
-                mb_obs_buffer.copy_(b_obs[mb], non_blocking=True)
-                mb_obs = mb_obs_buffer.to(self.device, non_blocking=True)
+                mb_obs = b_obs[mb].to(self.device, non_blocking=True)
                 mb_actions = b_actions[mb].contiguous()
                 mb_values = b_values[mb].reshape(-1)
                 mb_advantages = b_advantages[mb].reshape(-1)
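
For context on the pattern this patch adopts: pinned (page-locked) host memory is what lets a
.to(device, non_blocking=True) call become a genuinely asynchronous host-to-device copy, and
pin_memory=True applies only to CPU tensors. The snippet below is a minimal standalone sketch of
that pattern, not code taken from this patch; the names and sizes (batch_size, obs_shape) are
illustrative only.

import torch

# Illustrative sizes, not taken from the repository's config.
batch_size, obs_shape = 128, (3, 72, 80)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Allocate the staging buffer once on the host. Pinning is only valid for CPU
# tensors and only pays off when a GPU is present.
obs_buffer = torch.zeros(
    batch_size, *obs_shape, dtype=torch.uint8, pin_memory=torch.cuda.is_available()
)

# ... fill obs_buffer from the environments on the CPU side ...

# Because the source tensor is pinned, this copy can run asynchronously and
# overlap with other work queued on the current CUDA stream.
obs_gpu = obs_buffer.to(device, non_blocking=True)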