diff --git a/config.yaml b/config.yaml
index 03e9fe3..a9a2b69 100644
--- a/config.yaml
+++ b/config.yaml
@@ -5,7 +5,7 @@ wandb:
 
 debug:
   env:
-    headless: False
+    headless: True
     stream_wrapper: False
     init_state: cut
   train:
@@ -23,7 +23,7 @@ debug:
     checkpoint_interval: 4
     save_overlay: True
    overlay_interval: 4
-    verbose: True
+    verbose: True
     env_pool: False
     log_frequency: 5000
diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py
index 17acbb7..dfd98bc 100644
--- a/pokemonred_puffer/cleanrl_puffer.py
+++ b/pokemonred_puffer/cleanrl_puffer.py
@@ -276,7 +276,7 @@ def __init__(
             torch.zeros(shape, device=self.device),
             torch.zeros(shape, device=self.device),
         )
-        self.obs = torch.zeros(config.batch_size + 1, *obs_shape)
+        self.obs = torch.zeros(config.batch_size + 1, *obs_shape, dtype=torch.uint8)
         self.actions = torch.zeros(config.batch_size + 1, *atn_shape, dtype=int)
         self.logprobs = torch.zeros(config.batch_size + 1)
         self.rewards = torch.zeros(config.batch_size + 1)
@@ -284,7 +284,7 @@ def __init__(
         self.truncateds = torch.zeros(config.batch_size + 1)
         self.values = torch.zeros(config.batch_size + 1)
 
-        self.obs_ary = np.asarray(self.obs)
+        self.obs_ary = np.asarray(self.obs, dtype=np.uint8)
         self.actions_ary = np.asarray(self.actions)
         self.logprobs_ary = np.asarray(self.logprobs)
         self.rewards_ary = np.asarray(self.rewards)
@@ -534,11 +534,11 @@ def train(self):
        )
 
        # Flatten the batch
-        self.b_obs = b_obs = torch.Tensor(self.obs_ary[b_idxs])
-        b_actions = torch.Tensor(self.actions_ary[b_idxs]).to(self.device, non_blocking=True)
-        b_logprobs = torch.Tensor(self.logprobs_ary[b_idxs]).to(self.device, non_blocking=True)
+        self.b_obs = b_obs = torch.tensor(self.obs_ary[b_idxs], dtype=torch.uint8)
+        b_actions = torch.tensor(self.actions_ary[b_idxs]).to(self.device, non_blocking=True)
+        b_logprobs = torch.tensor(self.logprobs_ary[b_idxs]).to(self.device, non_blocking=True)
        # b_dones = torch.Tensor(self.dones_ary[b_idxs]).to(self.device, non_blocking=True)
-        b_values = torch.Tensor(self.values_ary[b_idxs]).to(self.device, non_blocking=True)
+        b_values = torch.tensor(self.values_ary[b_idxs]).to(self.device, non_blocking=True)
        b_advantages = advantages.reshape(
            config.batch_rows, num_minibatches, config.bptt_horizon
        ).transpose(0, 1)
@@ -547,7 +547,9 @@ def train(self):
        # Optimizing the policy and value network
        train_time = time.time()
        pg_losses, entropy_losses, v_losses, clipfracs, old_kls, kls = [], [], [], [], [], []
-        mb_obs_buffer = torch.zeros_like(b_obs[0], pin_memory=(self.device == "cuda"))
+        mb_obs_buffer = torch.zeros_like(
+            b_obs[0], pin_memory=(self.device == "cuda"), dtype=torch.uint8
+        )
 
        for epoch in range(config.update_epochs):
            lstm_state = None
diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py
index f3cc06b..ced205e 100644
--- a/pokemonred_puffer/policies/multi_convolutional.py
+++ b/pokemonred_puffer/policies/multi_convolutional.py
@@ -82,30 +82,19 @@ def encode_observations(self, observations):
         screen = torch.index_select(
             self.screen_buckets,
             0,
-            screen.flatten()
-            .unsqueeze(-1)
-            .bitwise_and(self.unpack_mask)
-            .bitwise_right_shift(self.unpack_shift)
-            .flatten()
-            .int(),
+            ((screen.reshape((-1, 1)) & self.unpack_mask) >> self.unpack_shift).flatten().int(),
         ).reshape(restored_shape)
         visited_mask = torch.index_select(
             self.linear_buckets,
             0,
-            visited_mask.flatten()
-            .unsqueeze(-1)
-            .bitwise_and(self.unpack_mask)
-            .bitwise_right_shift(self.unpack_shift)
+            ((visited_mask.reshape((-1, 1)) & self.unpack_mask) >> self.unpack_shift)
             .flatten()
             .int(),
         ).reshape(restored_shape)
         global_map = torch.index_select(
             self.linear_buckets,
             0,
-            global_map.flatten()
-            .unsqueeze(-1)
-            .bitwise_and(self.unpack_mask)
-            .bitwise_right_shift(self.unpack_shift)
+            ((global_map.reshape((-1, 1)) & self.unpack_mask) >> self.unpack_shift)
             .flatten()
             .int(),
         ).reshape(restored_shape)
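
For reference, the policy change above only swaps the method-chain bit-unpacking (unsqueeze / bitwise_and / bitwise_right_shift) for the equivalent operator form, while the trainer change keeps observations in uint8 end to end rather than materializing float32 copies. Below is a minimal, standalone sanity check of that equivalence; the mask and shift values and the tensor shape are illustrative assumptions, not taken from the repository.

import torch

# Illustrative unpack constants (assumed MSB-first); the repo defines its own
# unpack_mask / unpack_shift buffers on the policy.
unpack_mask = torch.tensor([128, 64, 32, 16, 8, 4, 2, 1], dtype=torch.uint8)
unpack_shift = torch.tensor([7, 6, 5, 4, 3, 2, 1, 0], dtype=torch.uint8)

# A fake packed uint8 observation; any shape works since both forms flatten it.
packed = torch.randint(0, 256, (4, 36, 40), dtype=torch.uint8)

# Old form: method-chain bitwise ops, broadcasting each byte against the mask.
old = (
    packed.flatten()
    .unsqueeze(-1)
    .bitwise_and(unpack_mask)
    .bitwise_right_shift(unpack_shift)
    .flatten()
    .int()
)

# New form from the diff: reshape to a column and use operator syntax.
new = ((packed.reshape((-1, 1)) & unpack_mask) >> unpack_shift).flatten().int()

assert torch.equal(old, new)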