From 940f4c54c6f67b3079c770d15f4b26c864bf8e2f Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Tue, 6 Aug 2024 10:11:25 -0400 Subject: [PATCH] safari zone obs --- pokemonred_puffer/environment.py | 3 +++ pokemonred_puffer/policies/multi_convolutional.py | 1 + 2 files changed, 4 insertions(+) diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 6ee9e39..a9418a9 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -188,6 +188,8 @@ def __init__(self, env_config: pufferlib.namespace): "moves": spaces.Box(low=0, high=0xA4, shape=(6, 4), dtype=np.uint8), # Add 3 for rival_3, game corner rocket and saffron guard "events": spaces.Box(low=0, high=1, shape=(len(EVENTS) + 3,), dtype=np.uint8), + # can't use 16-bit types so might as well send the float32 + "safari_steps": spaces.Box(low=0, high=1.0, shape=(1,), dtype=np.float32), } if self.use_global_map: @@ -593,6 +595,7 @@ def _get_obs(self): ], dtype=np.uint8, ), + "safari_steps": np.array(self.read_short("wSafariSteps") / 502.0, dtype=np.float32), } def set_perfect_iv_dvs(self): diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index 476e4cf..d8c0ff1 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -225,6 +225,7 @@ def encode_observations(self, observations): items.flatten(start_dim=1), party_latent, observations["events"].float().squeeze(1), + observations["safari_steps"].float(), ), dim=-1, )