diff --git a/config.yaml b/config.yaml index 25d3084..7265c8f 100644 --- a/config.yaml +++ b/config.yaml @@ -8,7 +8,7 @@ debug: headless: False stream_wrapper: False init_state: victory_road - max_steps: 16 + max_steps: 20480 log_frequency: 1 disable_wild_encounters: True disable_ai_actions: True @@ -289,6 +289,7 @@ rewards: exploration_plateau: 0.025 exploration_lobby: 0.035 # for game corner a_press: 0.00001 + explore_warps: 0.03 diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index ff4083d..ba3347f 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -251,6 +251,7 @@ def register_hooks(self): self.setup_disable_wild_encounters() self.pyboy.hook_register(None, "AnimateHealingMachine", self.pokecenter_heal_hook, None) self.pyboy.hook_register(None, "OverworldLoopLessDelay", self.overworld_loop_hook, None) + self.pyboy.hook_register(None, "CheckWarpsNoCollisionLoop", self.update_warps_hook, None) def setup_disable_wild_encounters(self): bank, addr = self.pyboy.symbol_lookup("TryDoWildEncounter.gotWildEncounterType") @@ -371,6 +372,7 @@ def init_mem(self): self.cut_explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32) self.seen_map_ids = np.zeros(256) self.seen_npcs = {} + self.seen_warps = {} self.cut_coords = {} self.cut_tiles = {} @@ -652,6 +654,7 @@ def step(self, action): elif self.step_count % self.log_frequency == 0: info = info | self.agent_stats(action) self.required_events = required_events + print(self.seen_warps) obs = self._get_obs() @@ -1136,6 +1139,16 @@ def pokecenter_heal_hook(self, *args, **kwargs): def overworld_loop_hook(self, *args, **kwargs): self.user_control = True + def update_warps_hook(self, *args, **kwargs): + # current map id, destiation map id, warp id + key = ( + self.read_m("wCurMap"), + self.read_m("hWarpDestinationMap"), + self.read_m("wDestinationWarpID"), + ) + if key[-1] != 0xFF: + self.seen_warps[key] = 1 + def cut_hook(self, context): player_direction = self.pyboy.memory[ self.pyboy.symbol_lookup("wSpritePlayerStateData1FacingDirection")[1] @@ -1194,6 +1207,7 @@ def agent_stats(self, action): "ptypes": self.read_party(), "hp": self.read_hp_fraction(), "coord": sum(sum(tileset.values()) for tileset in self.seen_coords.values()), + "warps": len(self.seen_warps), "a_press": len(self.a_press), "map_id": np.sum(self.seen_map_ids), "npc": sum(self.seen_npcs.values()), diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py index e16b864..5024828 100644 --- a/pokemonred_puffer/rewards/baseline.py +++ b/pokemonred_puffer/rewards/baseline.py @@ -330,6 +330,7 @@ def get_game_state_reward(self): "saffron_guard": self.reward_config["required_event"] * float(self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK")), "a_press": len(self.a_press) * self.reward_config["a_press"], + "warps": len(self.seen_warps) * self.reward_config["explore_warps"], } | { f"exploration_{tileset.name.lower()}": self.reward_config.get( diff --git a/pokemonred_puffer/wrappers/exploration.py b/pokemonred_puffer/wrappers/exploration.py index 747a8ad..05a1427 100644 --- a/pokemonred_puffer/wrappers/exploration.py +++ b/pokemonred_puffer/wrappers/exploration.py @@ -60,6 +60,10 @@ def step_forget_explore(self): (k, max(0.15, v * (self.step_forgetting_factor["npc"]))) for k, v in self.env.unwrapped.seen_npcs.items() ) + self.env.unwrapped.seen_warps.update( + (k, max(0.15, v * (self.step_forgetting_factor["coords"]))) + for k, v in self.env.unwrapped.seen_warps.items() + ) self.env.unwrapped.explore_map *= self.step_forgetting_factor["explore"] self.env.unwrapped.explore_map[self.env.unwrapped.explore_map > 0] = np.clip( self.env.unwrapped.explore_map[self.env.unwrapped.explore_map > 0], 0.15, 1 @@ -113,6 +117,7 @@ def reset(self, *args, **kwargs): self.env.unwrapped.seen_npcs.clear() self.env.unwrapped.cut_coords.clear() self.env.unwrapped.cut_tiles.clear() + self.env.unwrapped.seen_warps.clear() self.counter += 1 return self.env.reset(*args, **kwargs) @@ -146,4 +151,9 @@ def reset(self, *args, **kwargs): self.env.unwrapped.cut_explore_map[self.env.unwrapped.cut_explore_map > 0] = ( self.fixed_value["cut"] ) + self.env.unwrapped.seen_warps.update( + (k, self.fixed_value["coords"]) + for k, v in self.env.unwrapped.seen_warps.items() + if v > 0 + ) return self.env.reset(*args, **kwargs)