From 4f1f8cb4f630961b34822da1652c5240e5e8c1bf Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Mon, 3 Jun 2024 08:21:33 -0400 Subject: [PATCH] Cut event obs --- pokemonred_puffer/environment.py | 2 ++ pokemonred_puffer/policies/multi_convolutional.py | 1 + 2 files changed, 3 insertions(+) diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 83e39c8..2a618f5 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -238,6 +238,7 @@ def __init__(self, env_config: pufferlib.namespace): "direction": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8), # "reset_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), "battle_type": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8), + "cut_event": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), "cut_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), # "x": spaces.Box(low=0, high=255, shape=(1,), dtype=np.u`int8), # "y": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), @@ -547,6 +548,7 @@ def _get_obs(self): ), # "reset_map_id": np.array(self.read_m("wLastBlackoutMap"), dtype=np.uint8), "battle_type": np.array(self.read_m("wIsInBattle") + 1, dtype=np.uint8), + "cut_event": np.array(self.read_bit(0xD803, 0), dtype=np.uint8), "cut_in_party": np.array(self.check_if_party_has_cut(), dtype=np.uint8), # "x": np.array(player_x, dtype=np.uint8), # "y": np.array(player_y, dtype=np.uint8), diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index 81f8a6e..2da3d4a 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -115,6 +115,7 @@ def encode_observations(self, observations): one_hot(observations["direction"].long(), 4).float().squeeze(1), # one_hot(observations["reset_map_id"].long(), 0xF7).float().squeeze(1), one_hot(observations["battle_type"].long(), 4).float().squeeze(1), + observations["cut_event"].float(), observations["cut_in_party"].float(), # observations["x"].float(), # observations["y"].float(),