diff --git a/config.yaml b/config.yaml index 605d275..f6092a9 100644 --- a/config.yaml +++ b/config.yaml @@ -288,6 +288,7 @@ rewards: exploration_facility: 0.05 exploration_plateau: 0.03 exploration_lobby: 0.03 # for game corner + a_press: 0.02 diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 6a10e47..f0c7902 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -281,6 +281,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = # We only init seen hidden objs once cause they can only be found once! self.seen_hidden_objs = {} self.seen_signs = {} + self.a_press = set() if options.get("state", None) is not None: self.pyboy.load_state(io.BytesIO(options["state"])) self.reset_count += 1 @@ -321,6 +322,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = self.recent_screens.clear() self.recent_actions.clear() + self.a_press.clear() self.seen_pokemon.fill(0) self.caught_pokemon.fill(0) self.moves_obtained.fill(0) @@ -616,6 +618,9 @@ def step(self, action): if self.disable_wild_encounters: self.pyboy.memory[self.pyboy.symbol_lookup("wRepelRemainingSteps")[1]] = 0xFF + # update the a press before we use it so we dont trigger the font loaded early return + if VALID_ACTIONS[action] == WindowEvent.PRESS_BUTTON_A: + self.update_a_press() self.run_action_on_emulator(action) self.events = EventFlags(self.pyboy) self.missables = MissableFlags(self.pyboy) @@ -1182,6 +1187,7 @@ def agent_stats(self, action): "ptypes": self.read_party(), "hp": self.read_hp_fraction(), "coord": sum(sum(tileset.values()) for tileset in self.seen_coords.values()), + "a_press": len(self.a_press), "map_id": np.sum(self.seen_map_ids), "npc": sum(self.seen_npcs.values()), "hidden_obj": sum(self.seen_hidden_objs.values()), @@ -1292,6 +1298,22 @@ def update_seen_coords(self): # self.seen_global_coords[local_to_global(y_pos, x_pos, map_n)] = 1 self.seen_map_ids[map_n] = 1 + def update_a_press(self): + if self.read_m("wIsInBattle") != 0 or self.read_m("wFontLoaded"): + return + + direction = self.read_m("wSpritePlayerStateData1FacingDirection") + x_pos, y_pos, map_n = self.get_game_coords() + if direction == 0: + y_pos += 1 + if direction == 4: + y_pos -= 1 + if direction == 8: + x_pos -= 1 + if direction == 0xC: + x_pos += 1 + self.a_press.add((x_pos, y_pos, map_n)) + def get_explore_map(self): explore_map = np.zeros(GLOBAL_MAP_SHAPE) for (x, y, map_n), v in self.seen_coords.items(): diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py index 65f7fcf..0e89b71 100644 --- a/pokemonred_puffer/rewards/baseline.py +++ b/pokemonred_puffer/rewards/baseline.py @@ -329,6 +329,7 @@ def get_game_state_reward(self): * float(self.missables.get_missable("HS_GAME_CORNER_ROCKET")), "saffron_guard": self.reward_config["required_event"] * float(self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK")), + "a_press": len(self.a_press) * self.reward_config["a_press"], } | { f"exploration_{tileset.name.lower()}": self.reward_config.get(