diff --git a/config.yaml b/config.yaml index 6e34d9b..218d5ef 100644 --- a/config.yaml +++ b/config.yaml @@ -253,6 +253,7 @@ rewards: required_event: 5.0 required_item: 5.0 useful_item: 1.0 + pokecenter_heal: 1.0 diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 86bb3fa..380c747 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -218,6 +218,7 @@ def register_hooks(self): # self.pyboy.hook_register(None, "UsedCut.canCut", self.cut_hook, context=False) if self.disable_wild_encounters: self.setup_disable_wild_encounters() + self.pyboy.hook_register(None, "AnimateHealingMachine", self.pokecenter_heal_hook, None) def setup_disable_wild_encounters(self): bank, addr = self.pyboy.symbol_lookup("TryDoWildEncounter.gotWildEncounterType") @@ -331,6 +332,7 @@ def init_mem(self): self.seen_stats_menu = 0 self.seen_bag_menu = 0 self.seen_action_bag_menu = 0 + self.pokecenter_heal = 0 def reset_mem(self): self.seen_start_menu = 0 @@ -338,6 +340,7 @@ def reset_mem(self): self.seen_stats_menu = 0 self.seen_bag_menu = 0 self.seen_action_bag_menu = 0 + self.pokecenter_heal = 0 def render(self): # (144, 160, 3) @@ -1032,6 +1035,9 @@ def blackout_hook(self, *args, **kwargs): def blackout_update_hook(self, *args, **kwargs): self.blackout_check = self.read_m("wLastBlackoutMap") + def pokecenter_heal_hook(self, *args, **kwargs): + self.pokecenter_heal = 1 + def cut_hook(self, context): player_direction = self.pyboy.memory[ self.pyboy.symbol_lookup("wSpritePlayerStateData1FacingDirection")[1] @@ -1113,6 +1119,7 @@ def agent_stats(self, action): "reset_count": self.reset_count, "blackout_count": self.blackout_count, "pokecenter": np.sum(self.pokecenters), + "pokecenter_heal": self.pokecenter_heal, } | {f"badge_{i+1}": bool(badges & (1 << i)) for i in range(8)}, "events": {event: self.events.get_event(event) for event in REQUIRED_EVENTS} diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py index bd77d78..223f5dc 100644 --- a/pokemonred_puffer/rewards/baseline.py +++ b/pokemonred_puffer/rewards/baseline.py @@ -257,6 +257,7 @@ def get_game_state_reward(self): * self.reward_config["explore_hidden_objs"], "seen_action_bag_menu": self.seen_action_bag_menu * self.reward_config["seen_action_bag_menu"], + "pokecenter_heal": self.pokecenter_heal * self.reward_config["pokecenter_heal"], "rival3": self.reward_config["required_event"] * int(self.read_m("wSSAnne2FCurScript") == 4), "game_corner_rocket": self.reward_config["required_event"]