Skip to content

Commit

Permalink
uint8 for obs
Browse files (browse the repository at this point in the history)
  • Loading branch information
thatguy11325 committed Mar 24, 2024
1 parent c444aad commit 405481c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 23 deletions.
4 changes: 2 additions & 2 deletions config.yaml
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,7 +5,7 @@ wandb:

debug:
env:
headless: False
headless: True
stream_wrapper: False
init_state: cut
train:
Expand All @@ -23,7 +23,7 @@ debug:
checkpoint_interval: 4
save_overlay: True
overlay_interval: 4
verbose: True
verbose: True
env_pool: False
log_frequency: 5000

Expand Down
16 changes: 9 additions & 7 deletions pokemonred_puffer/cleanrl_puffer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -276,15 +276,15 @@ def __init__(
torch.zeros(shape, device=self.device),
torch.zeros(shape, device=self.device),
)
self.obs = torch.zeros(config.batch_size + 1, *obs_shape)
self.obs = torch.zeros(config.batch_size + 1, *obs_shape, dtype=torch.uint8)
self.actions = torch.zeros(config.batch_size + 1, *atn_shape, dtype=int)
self.logprobs = torch.zeros(config.batch_size + 1)
self.rewards = torch.zeros(config.batch_size + 1)
self.dones = torch.zeros(config.batch_size + 1)
self.truncateds = torch.zeros(config.batch_size + 1)
self.values = torch.zeros(config.batch_size + 1)

self.obs_ary = np.asarray(self.obs)
self.obs_ary = np.asarray(self.obs, dtype=np.uint8)
self.actions_ary = np.asarray(self.actions)
self.logprobs_ary = np.asarray(self.logprobs)
self.rewards_ary = np.asarray(self.rewards)
Expand Down Expand Up @@ -534,11 +534,11 @@ def train(self):
)

# Flatten the batch
self.b_obs = b_obs = torch.Tensor(self.obs_ary[b_idxs])
b_actions = torch.Tensor(self.actions_ary[b_idxs]).to(self.device, non_blocking=True)
b_logprobs = torch.Tensor(self.logprobs_ary[b_idxs]).to(self.device, non_blocking=True)
self.b_obs = b_obs = torch.tensor(self.obs_ary[b_idxs], dtype=torch.uint8)
b_actions = torch.tensor(self.actions_ary[b_idxs]).to(self.device, non_blocking=True)
b_logprobs = torch.tensor(self.logprobs_ary[b_idxs]).to(self.device, non_blocking=True)
# b_dones = torch.Tensor(self.dones_ary[b_idxs]).to(self.device, non_blocking=True)
b_values = torch.Tensor(self.values_ary[b_idxs]).to(self.device, non_blocking=True)
b_values = torch.tensor(self.values_ary[b_idxs]).to(self.device, non_blocking=True)
b_advantages = advantages.reshape(
config.batch_rows, num_minibatches, config.bptt_horizon
).transpose(0, 1)
Expand All @@ -547,7 +547,9 @@ def train(self):
# Optimizing the policy and value network
train_time = time.time()
pg_losses, entropy_losses, v_losses, clipfracs, old_kls, kls = [], [], [], [], [], []
mb_obs_buffer = torch.zeros_like(b_obs[0], pin_memory=(self.device == "cuda"))
mb_obs_buffer = torch.zeros_like(
b_obs[0], pin_memory=(self.device == "cuda"), dtype=torch.uint8
)

for epoch in range(config.update_epochs):
lstm_state = None
Expand Down
17 changes: 3 additions & 14 deletions pokemonred_puffer/policies/multi_convolutional.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -82,30 +82,19 @@ def encode_observations(self, observations):
screen = torch.index_select(
self.screen_buckets,
0,
screen.flatten()
.unsqueeze(-1)
.bitwise_and(self.unpack_mask)
.bitwise_right_shift(self.unpack_shift)
.flatten()
.int(),
((screen.reshape((-1, 1)) & self.unpack_mask) >> self.unpack_shift).flatten().int(),
).reshape(restored_shape)
visited_mask = torch.index_select(
self.linear_buckets,
0,
visited_mask.flatten()
.unsqueeze(-1)
.bitwise_and(self.unpack_mask)
.bitwise_right_shift(self.unpack_shift)
((visited_mask.reshape((-1, 1)) & self.unpack_mask) >> self.unpack_shift)
.flatten()
.int(),
).reshape(restored_shape)
global_map = torch.index_select(
self.linear_buckets,
0,
global_map.flatten()
.unsqueeze(-1)
.bitwise_and(self.unpack_mask)
.bitwise_right_shift(self.unpack_shift)
((global_map.reshape((-1, 1)) & self.unpack_mask) >> self.unpack_shift)
.flatten()
.int(),
).reshape(restored_shape)
Expand Down

0 comments on commit 405481c

Please sign in to comment.