Skip to content

Commit

Permalink
warp rewards
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Aug 4, 2024
1 parent 0272247 commit 289c74e
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 1 deletion.
3 changes: 2 additions & 1 deletion config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ debug:
headless: False
stream_wrapper: False
init_state: victory_road
max_steps: 16
max_steps: 20480
log_frequency: 1
disable_wild_encounters: True
disable_ai_actions: True
Expand Down Expand Up @@ -289,6 +289,7 @@ rewards:
exploration_plateau: 0.025
exploration_lobby: 0.035 # for game corner
a_press: 0.00001
explore_warps: 0.03



Expand Down
14 changes: 14 additions & 0 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def register_hooks(self):
self.setup_disable_wild_encounters()
self.pyboy.hook_register(None, "AnimateHealingMachine", self.pokecenter_heal_hook, None)
self.pyboy.hook_register(None, "OverworldLoopLessDelay", self.overworld_loop_hook, None)
self.pyboy.hook_register(None, "CheckWarpsNoCollisionLoop", self.update_warps_hook, None)

def setup_disable_wild_encounters(self):
bank, addr = self.pyboy.symbol_lookup("TryDoWildEncounter.gotWildEncounterType")
Expand Down Expand Up @@ -371,6 +372,7 @@ def init_mem(self):
self.cut_explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32)
self.seen_map_ids = np.zeros(256)
self.seen_npcs = {}
self.seen_warps = {}

self.cut_coords = {}
self.cut_tiles = {}
Expand Down Expand Up @@ -652,6 +654,7 @@ def step(self, action):
elif self.step_count % self.log_frequency == 0:
info = info | self.agent_stats(action)
self.required_events = required_events
print(self.seen_warps)

obs = self._get_obs()

Expand Down Expand Up @@ -1136,6 +1139,16 @@ def pokecenter_heal_hook(self, *args, **kwargs):
def overworld_loop_hook(self, *args, **kwargs):
    """Set ``self.user_control`` to ``True``.

    Registered in ``register_hooks`` on the ``OverworldLoopLessDelay``
    symbol. The positional and keyword arguments are accepted and ignored.
    """
    self.user_control = True

def update_warps_hook(self, *args, **kwargs):
    """Record a warp in ``self.seen_warps``.

    Registered in ``register_hooks`` on the ``CheckWarpsNoCollisionLoop``
    symbol. Reads ``wCurMap``, ``hWarpDestinationMap`` and
    ``wDestinationWarpID`` through ``self.read_m`` and stores the triple
    ``(current map id, destination map id, warp id)`` as a key with value
    ``1``. Nothing is stored when the warp id reads ``0xFF``.

    The positional and keyword arguments are accepted and ignored.
    """
    current_map_id = self.read_m("wCurMap")
    destination_map_id = self.read_m("hWarpDestinationMap")
    warp_id = self.read_m("wDestinationWarpID")

    # NOTE(review): 0xFF is presumably the game's "no warp" sentinel;
    # confirm against the disassembly.
    if warp_id == 0xFF:
        return

    self.seen_warps[(current_map_id, destination_map_id, warp_id)] = 1

def cut_hook(self, context):
player_direction = self.pyboy.memory[
self.pyboy.symbol_lookup("wSpritePlayerStateData1FacingDirection")[1]
Expand Down Expand Up @@ -1194,6 +1207,7 @@ def agent_stats(self, action):
"ptypes": self.read_party(),
"hp": self.read_hp_fraction(),
"coord": sum(sum(tileset.values()) for tileset in self.seen_coords.values()),
"warps": len(self.seen_warps),
"a_press": len(self.a_press),
"map_id": np.sum(self.seen_map_ids),
"npc": sum(self.seen_npcs.values()),
Expand Down
1 change: 1 addition & 0 deletions pokemonred_puffer/rewards/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def get_game_state_reward(self):
"saffron_guard": self.reward_config["required_event"]
* float(self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK")),
"a_press": len(self.a_press) * self.reward_config["a_press"],
"warps": len(self.seen_warps) * self.reward_config["explore_warps"],
}
| {
f"exploration_{tileset.name.lower()}": self.reward_config.get(
Expand Down
10 changes: 10 additions & 0 deletions pokemonred_puffer/wrappers/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def step_forget_explore(self):
(k, max(0.15, v * (self.step_forgetting_factor["npc"])))
for k, v in self.env.unwrapped.seen_npcs.items()
)
self.env.unwrapped.seen_warps.update(
(k, max(0.15, v * (self.step_forgetting_factor["coords"])))
for k, v in self.env.unwrapped.seen_warps.items()
)
self.env.unwrapped.explore_map *= self.step_forgetting_factor["explore"]
self.env.unwrapped.explore_map[self.env.unwrapped.explore_map > 0] = np.clip(
self.env.unwrapped.explore_map[self.env.unwrapped.explore_map > 0], 0.15, 1
Expand Down Expand Up @@ -113,6 +117,7 @@ def reset(self, *args, **kwargs):
self.env.unwrapped.seen_npcs.clear()
self.env.unwrapped.cut_coords.clear()
self.env.unwrapped.cut_tiles.clear()
self.env.unwrapped.seen_warps.clear()
self.counter += 1
return self.env.reset(*args, **kwargs)

Expand Down Expand Up @@ -146,4 +151,9 @@ def reset(self, *args, **kwargs):
self.env.unwrapped.cut_explore_map[self.env.unwrapped.cut_explore_map > 0] = (
self.fixed_value["cut"]
)
self.env.unwrapped.seen_warps.update(
(k, self.fixed_value["coords"])
for k, v in self.env.unwrapped.seen_warps.items()
if v > 0
)
return self.env.reset(*args, **kwargs)

0 comments on commit 289c74e

Please sign in to comment.