Skip to content

Commit

Permalink
torch.as_tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Apr 25, 2024
1 parent ff618f9 commit dc20780
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def evaluate(self):

# Index alive mask with policy pool idxs...
# TODO: Find a way to avoid having to do this
learner_mask = torch.Tensor(mask * self.policy_pool.mask)
learner_mask = torch.as_tensor(mask * self.policy_pool.mask)

# Ensure indices do not exceed batch size
indices = torch.where(learner_mask)[0][: config.batch_size - ptr + 1].numpy()
Expand Down Expand Up @@ -592,11 +592,11 @@ def train(self):
)

# Flatten the batch
self.b_obs = b_obs = torch.tensor(self.obs_ary[b_idxs], dtype=torch.uint8)
b_actions = torch.tensor(self.actions_ary[b_idxs]).to(self.device, non_blocking=True)
b_logprobs = torch.tensor(self.logprobs_ary[b_idxs]).to(self.device, non_blocking=True)
# b_dones = torch.Tensor(self.dones_ary[b_idxs]).to(self.device, non_blocking=True)
b_values = torch.tensor(self.values_ary[b_idxs]).to(self.device, non_blocking=True)
self.b_obs = b_obs = torch.as_tensor(self.obs_ary[b_idxs], dtype=torch.uint8)
b_actions = torch.as_tensor(self.actions_ary[b_idxs]).to(self.device, non_blocking=True)
b_logprobs = torch.as_tensor(self.logprobs_ary[b_idxs]).to(self.device, non_blocking=True)
# b_dones = torch.as_tensor(self.dones_ary[b_idxs]).to(self.device, non_blocking=True)
b_values = torch.as_tensor(self.values_ary[b_idxs]).to(self.device, non_blocking=True)
b_advantages = advantages.reshape(
config.batch_rows, num_minibatches, config.bptt_horizon
).transpose(0, 1)
Expand Down

0 comments on commit dc20780

Please sign in to comment.