Skip to content

Commit

Permalink
Add in pokeflute and surf coords
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Dec 18, 2024
1 parent ef5e4f3 commit b99f4a1
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 6 deletions.
37 changes: 34 additions & 3 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ env:
auto_teach_cut: False
auto_use_cut: False
auto_use_strength: True
auto_use_surf: True
auto_teach_surf: True
auto_use_surf: False
auto_teach_surf: False
auto_teach_strength: True
auto_solve_strength_puzzles: True
auto_remove_all_nonuseful_items: True
auto_pokeflute: True
auto_pokeflute: False
auto_next_elevator_floor: True
skip_safari_zone: False
infinite_safari_steps: False
Expand Down Expand Up @@ -341,6 +341,37 @@ rewards:
use_surf: 0.4
useful_item: 0.825
safari_zone: 3.4493650422686217

baseline.ObjectRewardRequiredEventsMapIdsFieldMoves:
reward:
a_press: 0.0 # 0.00001
badges: 3.0
bag_menu: 0.0
caught_pokemon: 2.5
valid_cut_coords: 0.75
valid_surf_coords: .75
event: .75
exploration: 0.018999755680454297
explore_hidden_objs: 0.00009999136567868017
explore_signs: 0.015025767686371013
explore_warps: 0.010135211705238394
hm_count: 7.5
invalid_cut_coords: 0.0001
invalid_surf_coords: 0.0001
level: 1.05
moves_obtained: 4.0
pokecenter_heal: 0.47
pokeflute_coords: 0.0001
pokemon_menu: 0.0
required_event: 7.0
required_item: 3.0
seen_action_bag_menu: 0.0
seen_pokemon: 2.5
start_menu: 0.0
stats_menu: 0.0
use_surf: 0.0
useful_item: 0.825
safari_zone: 3.4493650422686217

policies:
multi_convolutional.MultiConvolutionalPolicy:
Expand Down
55 changes: 53 additions & 2 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,18 @@ def register_hooks(self):
)
self.pyboy.hook_register(None, "HandleBlackOut", self.blackout_hook, None)
self.pyboy.hook_register(None, "SetLastBlackoutMap.done", self.blackout_update_hook, None)
self.pyboy.hook_register(None, "UsedCut.nothingToCut", self.cut_hook, context=True)
self.pyboy.hook_register(None, "UsedCut.canCut", self.cut_hook, context=False)
if not self.auto_use_cut:
self.pyboy.hook_register(None, "UsedCut.nothingToCut", self.cut_hook, context=True)
self.pyboy.hook_register(None, "UsedCut.canCut", self.cut_hook, context=False)
# there is already an event for waking up the snorlax. No need to make a hookd for it
if not self.auto_pokeflute:
self.pyboy.hook_register(
None, "ItemUsePokeFlute.noSnorlaxToWakeUp", self.pokeflute_hook, None
)
if not self.auto_use_surf:
self.pyboy.hook_register(None, "SurfingAttemptFailed", self.surf_hook, context=False)
self.pyboy.hook_register(None, "ItemUseSurfboard.surf", self.surf_hook, context=True)

if self.disable_wild_encounters:
self.setup_disable_wild_encounters()
self.pyboy.hook_register(None, "AnimateHealingMachine", self.pokecenter_heal_hook, None)
Expand Down Expand Up @@ -423,6 +433,11 @@ def init_mem(self):
self.valid_cut_coords = {}
self.invalid_cut_coords = {}

self.pokeflute_coords = {}

self.valid_surf_coords = {}
self.invalid_surf_coords = {}

self.seen_hidden_objs = {}
self.seen_signs = {}

Expand Down Expand Up @@ -1375,6 +1390,39 @@ def cut_hook(self, context: bool):

self.cut_explore_map[local_to_global(y, x, map_id)] = 1

def pokeflute_hook(self, *args, **kwargs):
player_direction = self.pyboy.memory[
self.pyboy.symbol_lookup("wSpritePlayerStateData1FacingDirection")[1]
]
x, y, map_id = self.get_game_coords() # x, y, map_id
if player_direction == 0: # down
coords = (x, y + 1, map_id)
if player_direction == 4:
coords = (x, y - 1, map_id)
if player_direction == 8:
coords = (x - 1, y, map_id)
if player_direction == 0xC:
coords = (x + 1, y, map_id)
self.pokeflute_coords[coords] = 1

def surf_hook(self, context: bool, *args, **kwargs):
player_direction = self.pyboy.memory[
self.pyboy.symbol_lookup("wSpritePlayerStateData1FacingDirection")[1]
]
x, y, map_id = self.get_game_coords() # x, y, map_id
if player_direction == 0: # down
coords = (x, y + 1, map_id)
if player_direction == 4:
coords = (x, y - 1, map_id)
if player_direction == 8:
coords = (x - 1, y, map_id)
if player_direction == 0xC:
coords = (x + 1, y, map_id)
if context:
self.valid_surf_coords[coords] = 1
else:
self.invalid_surf_coords[coords] = 1

def disable_wild_encounter_hook(self, *args, **kwargs):
if (
self.disable_wild_encounters
Expand Down Expand Up @@ -1429,6 +1477,9 @@ def agent_stats(self, action):
"taught_strength": int(self.check_if_party_has_hm(TmHmMoves.STRENGTH.value)),
"valid_cut_coords": len(self.valid_cut_coords),
"invalid_cut_coords": len(self.invalid_cut_coords),
"pokeflute_coords": len(self.pokeflute_coords),
"valid_surf_coords": len(self.valid_surf_coords),
"invalid_surf_coords": len(self.invalid_surf_coords),
"menu": {
"start_menu": self.seen_start_menu,
"pokemon_menu": self.seen_pokemon_menu,
Expand Down
14 changes: 13 additions & 1 deletion pokemonred_puffer/rewards/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def get_levels_reward(self):


class ObjectRewardRequiredEventsMapIds(BaselineRewardEnv):
def get_game_state_reward(self):
def get_game_state_reward(self) -> dict[str, float]:
_, wBagItems = self.pyboy.symbol_lookup("wBagItems")
numBagItems = self.read_m("wNumBagItems")
bag_item_ids = set(self.pyboy.memory[wBagItems : wBagItems + 2 * numBagItems : 2])
Expand Down Expand Up @@ -430,3 +430,15 @@ def get_levels_reward(self):
return self.max_level_sum
else:
return 15 + (self.max_level_sum - 15) / 4


class ObjectRewardRequiredEventsMapIdsFieldMoves(ObjectRewardRequiredEventsMapIds):
def get_game_state_reward(self) -> dict[str, float]:
return super().get_game_state_reward() | {
"pokeflute_coords": self.reward_config["pokeflute_coords"]
* len(self.pokeflute_coords.values()),
"valid_surf_coords": self.reward_config["valid_surf_coords"]
* len(self.valid_surf_coords.values()),
"invalid_surf_coords": self.reward_config["invalid_surf_coords"]
* len(self.invalid_cut_coords.values()),
}
18 changes: 18 additions & 0 deletions pokemonred_puffer/wrappers/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def step(self, action):
self.env.unwrapped.seen_npcs.clear()
self.env.unwrapped.valid_cut_coords.clear()
self.env.unwrapped.invalid_cut_coords.clear()
self.env.unwrapped.pokeflute_coords.clear()
self.env.unwrapped.valid_surf_coords.clear()
self.env.unwrapped.invalid_surf_coords.clear()
self.env.unwrapped.seen_warps.clear()
self.env.unwrapped.seen_hidden_objs.clear()
self.env.unwrapped.seen_signs.clear()
Expand Down Expand Up @@ -166,6 +169,21 @@ def step(self, action):
for k, v in self.env.unwrapped.seen_npcs.items()
if v > 0
)
self.env.unwrapped.pokeflute_coords.update(
(k, self.fixed_value["pokeflute"])
for k, v in self.env.unwrapped.seen_npcs.items()
if v > 0
)
self.env.unwrapped.valid_surf_coords.update(
(k, self.fixed_value["valid_surf"])
for k, v in self.env.unwrapped.seen_npcs.items()
if v > 0
)
self.env.unwrapped.invalid_surf_coords.update(
(k, self.fixed_value["invalid_surf"])
for k, v in self.env.unwrapped.seen_npcs.items()
if v > 0
)
self.env.unwrapped.explore_map[self.env.unwrapped.explore_map > 0] = self.fixed_value[
"explore"
]
Expand Down

0 comments on commit b99f4a1

Please sign in to comment.