diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 6ee9e39..a9418a9 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -188,6 +188,8 @@ def __init__(self, env_config: pufferlib.namespace): "moves": spaces.Box(low=0, high=0xA4, shape=(6, 4), dtype=np.uint8), # Add 3 for rival_3, game corner rocket and saffron guard "events": spaces.Box(low=0, high=1, shape=(len(EVENTS) + 3,), dtype=np.uint8), + # can't use 16-bit types so might as well send the float32 + "safari_steps": spaces.Box(low=0, high=1.0, shape=(1,), dtype=np.float32), } if self.use_global_map: @@ -593,6 +595,7 @@ def _get_obs(self): ], dtype=np.uint8, ), + "safari_steps": np.array(self.read_short("wSafariSteps") / 502.0, dtype=np.float32), } def set_perfect_iv_dvs(self): diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index 476e4cf..d8c0ff1 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -225,6 +225,7 @@ def encode_observations(self, observations): items.flatten(start_dim=1), party_latent, observations["events"].float().squeeze(1), + observations["safari_steps"].float(), ), dim=-1, )