diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py
index c071fd0..b74b193 100644
--- a/pokemonred_puffer/cleanrl_puffer.py
+++ b/pokemonred_puffer/cleanrl_puffer.py
@@ -455,7 +455,7 @@ def evaluate(self):
 
         # Index alive mask with policy pool idxs...
         # TODO: Find a way to avoid having to do this
-        learner_mask = torch.Tensor(mask * self.policy_pool.mask)
+        learner_mask = torch.as_tensor(mask * self.policy_pool.mask)
 
         # Ensure indices do not exceed batch size
         indices = torch.where(learner_mask)[0][: config.batch_size - ptr + 1].numpy()
@@ -592,11 +592,11 @@ def train(self):
         )
 
         # Flatten the batch
-        self.b_obs = b_obs = torch.tensor(self.obs_ary[b_idxs], dtype=torch.uint8)
-        b_actions = torch.tensor(self.actions_ary[b_idxs]).to(self.device, non_blocking=True)
-        b_logprobs = torch.tensor(self.logprobs_ary[b_idxs]).to(self.device, non_blocking=True)
-        # b_dones = torch.Tensor(self.dones_ary[b_idxs]).to(self.device, non_blocking=True)
-        b_values = torch.tensor(self.values_ary[b_idxs]).to(self.device, non_blocking=True)
+        self.b_obs = b_obs = torch.as_tensor(self.obs_ary[b_idxs], dtype=torch.uint8)
+        b_actions = torch.as_tensor(self.actions_ary[b_idxs]).to(self.device, non_blocking=True)
+        b_logprobs = torch.as_tensor(self.logprobs_ary[b_idxs]).to(self.device, non_blocking=True)
+        # b_dones = torch.as_tensor(self.dones_ary[b_idxs]).to(self.device, non_blocking=True)
+        b_values = torch.as_tensor(self.values_ary[b_idxs]).to(self.device, non_blocking=True)
         b_advantages = advantages.reshape(
             config.batch_rows, num_minibatches, config.bptt_horizon
         ).transpose(0, 1)
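
For context on the behavioral difference this diff relies on: `torch.tensor` (and the legacy `torch.Tensor` constructor) always copies its input, whereas `torch.as_tensor` reuses the underlying NumPy buffer when the dtype already matches and the data stays on CPU. Below is a minimal sketch of that distinction; the `obs` array is a hypothetical stand-in for a rollout-buffer slice like `self.obs_ary[b_idxs]`, not code from this repo.

```python
import numpy as np
import torch

# Hypothetical stand-in for a rollout buffer slice such as self.obs_ary[b_idxs].
obs = np.zeros((4, 128), dtype=np.uint8)

copied = torch.tensor(obs, dtype=torch.uint8)     # torch.tensor always copies the data
shared = torch.as_tensor(obs, dtype=torch.uint8)  # reuses the NumPy buffer (matching dtype, CPU)

obs[0, 0] = 1
print(copied[0, 0].item())  # 0 -- the copy does not see the in-place write
print(shared[0, 0].item())  # 1 -- the shared view does
```

The later `.to(self.device, non_blocking=True)` calls still perform a host-to-device copy, so the saving here is presumably the extra host-side copy that `torch.tensor` would have made before the transfer.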