Skip to content

Commit

Permalink
sign rewards
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Aug 18, 2024
1 parent 9783924 commit 128db1c
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 13 deletions.
1 change: 1 addition & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ rewards:
stats_menu: 0.0
bag_menu: 0.0
explore_hidden_objs: 0.02
explore_signs: 0.02
seen_action_bag_menu: 0.0
required_event: 5.0
required_item: 5.0
Expand Down
26 changes: 14 additions & 12 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,12 +250,6 @@ def register_hooks(self):
self.pyboy.hook_register(
None, "CheckForHiddenObject.foundMatchingObject", self.hidden_object_hook, None
)
"""
_, addr = self.pyboy.symbol_lookup("IsSpriteOrSignInFrontOfPlayer.retry")
self.pyboy.hook_register(
None, addr-1, self.sign_hook, None
)
"""
self.pyboy.hook_register(None, "HandleBlackOut", self.blackout_hook, None)
self.pyboy.hook_register(None, "SetLastBlackoutMap.done", self.blackout_update_hook, None)
# self.pyboy.hook_register(None, "UsedCut.nothingToCut", self.cut_hook, context=True)
Expand All @@ -265,6 +259,13 @@ def register_hooks(self):
self.pyboy.hook_register(None, "AnimateHealingMachine", self.pokecenter_heal_hook, None)
# self.pyboy.hook_register(None, "OverworldLoopLessDelay", self.overworld_loop_hook, None)
self.pyboy.hook_register(None, "CheckWarpsNoCollisionLoop", self.update_warps_hook, None)
signBank, signAddr = self.pyboy.symbol_lookup("IsSpriteOrSignInFrontOfPlayer.retry")
self.pyboy.hook_register(
signBank,
signAddr - 1,
self.sign_hook,
None,
)

def setup_disable_wild_encounters(self):
bank, addr = self.pyboy.symbol_lookup("TryDoWildEncounter.gotWildEncounterType")
Expand Down Expand Up @@ -292,8 +293,6 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] =
self.recent_screens = deque()
self.recent_actions = deque()
# We only init seen hidden objs once cause they can only be found once!
self.seen_hidden_objs = {}
self.seen_signs = {}
self.a_press = set()
if options.get("state", None) is not None:
self.pyboy.load_state(io.BytesIO(options["state"]))
Expand Down Expand Up @@ -397,6 +396,9 @@ def init_mem(self):
self.cut_coords = {}
self.cut_tiles = {}

self.seen_hidden_objs = {}
self.seen_signs = {}

self.seen_start_menu = 0
self.seen_pokemon_menu = 0
self.seen_stats_menu = 0
Expand Down Expand Up @@ -1191,10 +1193,9 @@ def skip_safari_zone_atn(self):
self.pyboy.memory[wNumBagItems] = numBagItems

def sign_hook(self, *args, **kwargs):
    """Record that the sign currently being checked has been seen.

    Installed as an emulator hook by ``register_hooks`` at the address one
    byte before the ``IsSpriteOrSignInFrontOfPlayer.retry`` symbol.

    Args:
        *args: Ignored; accepted so the hook can be invoked with whatever
            positional arguments the emulator passes to hook callbacks.
        **kwargs: Ignored, for the same reason as ``*args``.

    Returns:
        None. The only side effect is a write to ``self.seen_signs``.
    """
    sign_id = self.read_m("hSpriteIndexOrTextID")
    map_id = self.read_m("wCurMap")
    # Signs live in their own table, keyed by (map id, sign/text id). They
    # must not be written into ``seen_hidden_objs``: that dict feeds the
    # separate hidden-object stat and reward, which would be inflated by
    # every sign the player reads.
    self.seen_signs[(map_id, sign_id)] = 1

def hidden_object_hook(self, *args, **kwargs):
hidden_object_id = self.pyboy.memory[self.pyboy.symbol_lookup("wHiddenObjectIndex")[1]]
Expand Down Expand Up @@ -1311,6 +1312,7 @@ def agent_stats(self, action):
"map_id": np.sum(self.seen_map_ids),
"npc": sum(self.seen_npcs.values()),
"hidden_obj": sum(self.seen_hidden_objs.values()),
"sign": sum(self.seen_signs.values()),
"deaths": self.died_count,
"badge": self.get_badges(),
"healr": self.total_heal_health,
Expand Down
3 changes: 2 additions & 1 deletion pokemonred_puffer/rewards/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,8 @@ def get_game_state_reward(self):
"pokemon_menu": self.reward_config["pokemon_menu"] * self.seen_pokemon_menu,
"stats_menu": self.reward_config["stats_menu"] * self.seen_stats_menu,
"bag_menu": self.reward_config["bag_menu"] * self.seen_bag_menu,
"explore_hidden_objs": sum(self.seen_hidden_objs.values())
"explore_hidden_objs": sum(self.seen_hidden_objs.values()),
"explore_signs": sum(self.seen_signs.values())
* self.reward_config["explore_hidden_objs"],
"seen_action_bag_menu": self.seen_action_bag_menu
* self.reward_config["seen_action_bag_menu"],
Expand Down
18 changes: 18 additions & 0 deletions pokemonred_puffer/wrappers/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ def step_forget_explore(self):
self.env.unwrapped.explore_map[self.env.unwrapped.explore_map > 0] = np.clip(
self.env.unwrapped.explore_map[self.env.unwrapped.explore_map > 0], 0.15, 1
)
self.env.unwrapped.seen_hidden_objs.update(
(k, max(0.15, v * (self.step_forgetting_factor["hidden_objs"])))
for k, v in self.env.unwrapped.seen_coords.items()
)
self.env.unwrapped.seen_signs.update(
(k, max(0.15, v * (self.step_forgetting_factor["signs"])))
for k, v in self.env.unwrapped.seen_coords.items()
)

if self.env.unwrapped.read_m("wIsInBattle") == 0:
self.env.unwrapped.seen_start_menu *= self.step_forgetting_factor["start_menu"]
Expand Down Expand Up @@ -118,6 +126,8 @@ def reset(self, *args, **kwargs):
self.env.unwrapped.cut_coords.clear()
self.env.unwrapped.cut_tiles.clear()
self.env.unwrapped.seen_warps.clear()
self.env.unwrapped.seen_hidden_objs.clear()
self.env.unwrapped.seen_signs.clear()
self.counter += 1
return self.env.reset(*args, **kwargs)

Expand Down Expand Up @@ -156,4 +166,12 @@ def reset(self, *args, **kwargs):
for k, v in self.env.unwrapped.seen_warps.items()
if v > 0
)
self.env.unwrapped.seen_hidden_objs.update(
(k, self.fixed_value["hidden_objs"])
for k, v in self.env.unwrapped.seen_npcs.items()
if v > 0
)
self.env.unwrapped.seen_signs.update(
(k, self.fixed_value["signs"]) for k, v in self.env.unwrapped.seen_npcs.items() if v > 0
)
return self.env.reset(*args, **kwargs)

0 comments on commit 128db1c

Please sign in to comment.