diff --git a/config.yaml b/config.yaml index 575af92..dfeee75 100644 --- a/config.yaml +++ b/config.yaml @@ -194,9 +194,10 @@ rewards: start_menu: 0.00 pokemon_menu: 0.0 stats_menu: 0.0 - bag_menu: 0.0 + bag_menu: 0.1 rocket_hideout_found: 5.0 explore_hidden_objs: 0.02 + seen_action_bag_menu: 0.1 diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index f3d39e4..d250d41 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -146,6 +146,7 @@ WindowEvent.PRESS_BUTTON_A, WindowEvent.PRESS_BUTTON_B, WindowEvent.PRESS_BUTTON_START, + WindowEvent.PASS, ] VALID_RELEASE_ACTIONS = [ @@ -156,6 +157,7 @@ WindowEvent.RELEASE_BUTTON_A, WindowEvent.RELEASE_BUTTON_B, WindowEvent.RELEASE_BUTTON_START, + WindowEvent.PASS, ] VALID_ACTIONS_STR = ["down", "left", "right", "up", "a", "b", "start"] @@ -246,6 +248,7 @@ def __init__(self, env_config: pufferlib.namespace): "map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), # "badges": spaces.Box(low=0, high=np.iinfo(np.uint16).max, shape=(1,), dtype=np.uint16), "badges": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), + "wJoyIgnore": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), } ) @@ -555,6 +558,7 @@ def _get_obs(self): # "y": np.array(player_y, dtype=np.uint8), "map_id": np.array(self.read_m(0xD35E), dtype=np.uint8), "badges": np.array(self.read_short("wObtainedBadges").bit_count(), dtype=np.uint8), + "wJoyIgnore": np.array(self.read_m("wJoyIgnore"), dtype=np.uint8), } def set_perfect_iv_dvs(self): @@ -626,8 +630,8 @@ def run_action_on_emulator(self, action): self.action_hist[action] += 1 # press button then release after some steps # TODO: Add video saving logic - self.pyboy.send_input(VALID_ACTIONS[action]) - self.pyboy.send_input(VALID_RELEASE_ACTIONS[action], delay=8) + # self.pyboy.send_input(VALID_ACTIONS[action]) + # self.pyboy.send_input(VALID_RELEASE_ACTIONS[action], delay=8) self.pyboy.tick(self.action_freq, render=True) if self.read_bit(0xD803, 0): @@ -755,8 +759,8 @@ def start_menu_hook(self, *args, **kwargs): self.seen_start_menu = 1 def item_menu_hook(self, *args, **kwargs): - if self.read_m("wIsInBattle") == 0: - self.seen_bag_menu = 1 + # if self.read_m("wIsInBattle") == 0: + self.seen_bag_menu = 1 def pokemon_menu_hook(self, *args, **kwargs): if self.read_m("wIsInBattle") == 0: @@ -767,8 +771,8 @@ def chose_stats_hook(self, *args, **kwargs): self.seen_stats_menu = 1 def chose_item_hook(self, *args, **kwargs): - if self.read_m("wIsInBattle") == 0: - self.seen_action_bag_menu = 1 + # if self.read_m("wIsInBattle") == 0: + self.seen_action_bag_menu = 1 def blackout_hook(self, *args, **kwargs): self.blackout_count += 1 @@ -853,6 +857,8 @@ def agent_stats(self, action): "reset_count": self.reset_count, "blackout_count": self.blackout_count, "pokecenter": np.sum(self.pokecenters), + "rival3": int(self.read_m(0xD665) == 4), + "rocket_hideout_found": int(self.read_bit(0xD77E, 1)), }, "reward": self.get_game_state_reward(), "reward/reward_sum": sum(self.get_game_state_reward().values()), diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index 678f664..87ad8a7 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -128,6 +128,7 @@ def encode_observations(self, observations): badges.float().squeeze(1), map_id.squeeze(1), blackout_map_id.squeeze(1), + observations["wJoyIgnore"].float(), ), dim=-1, ) diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py index d64c3b8..23f48db 100644 --- a/pokemonred_puffer/rewards/baseline.py +++ b/pokemonred_puffer/rewards/baseline.py @@ -194,6 +194,8 @@ def get_game_state_reward(self): * int(self.read_bit(0xD77E, 1)), "explore_hidden_objs": sum(self.seen_hidden_objs.values()) * self.reward_config["explore_hidden_objs"], + "seen_action_bag_menu": self.seen_action_bag_menu + * self.reward_config["seen_action_bag_menu"], } def get_levels_reward(self):