Skip to content

Commit

Permalink
maybe this will improve get events a little bit
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Oct 19, 2024
1 parent 8761d11 commit d6c9809
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 19 deletions.
8 changes: 8 additions & 0 deletions pokemonred_puffer/data/events.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ctypes import c_uint8, LittleEndianStructure, Union
import re
from typing import Iterator

from pyboy import PyBoy

Expand Down Expand Up @@ -2589,6 +2590,13 @@ def get_event(self, event_name: str) -> int:
"""
return getattr(self.b, event_name)

def get_events(self, event_names: Iterator[str]) -> Iterator[int]:
"""
1 if true, 0 if false
"""
for event_name in event_names:
yield getattr(self.b, event_name)

def set_event(self, event_name: str, value: bool):
# This is O(N) but it's so rare that I'm not too worried about it
idx = [x[0] for x in self.b._fields_].index(event_name)
Expand Down
32 changes: 20 additions & 12 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,14 +617,18 @@ def _get_obs(self):
"speed": np.array([self.party[i].Speed for i in range(6)], dtype=np.uint32),
"special": np.array([self.party[i].Special for i in range(6)], dtype=np.uint32),
"moves": np.array([self.party[i].Moves for i in range(6)], dtype=np.uint8),
"events": np.array(
[self.events.get_event(event) for event in EVENTS]
+ [
self.read_m("wSSAnne2FCurScript") == 4, # rival 3
self.missables.get_missable("HS_GAME_CORNER_ROCKET"), # game corner rocket
self.flags.get_bit("BIT_GAVE_SAFFRON_GUARDS_DRINK"), # saffron guard
self.flags.get_bit("BIT_GOT_LAPRAS"), # got lapras
],
"events": np.concatenate(
(
np.fromiter(self.events.get_events(EVENTS), dtype=np.uint8),
[
self.read_m("wSSAnne2FCurScript") == 4, # rival 3
self.missables.get_missable(
"HS_GAME_CORNER_ROCKET"
), # game corner rocket
self.flags.get_bit("BIT_GAVE_SAFFRON_GUARDS_DRINK"), # saffron guard
self.flags.get_bit("BIT_GOT_LAPRAS"), # got lapras
],
),
dtype=np.uint8,
),
}
Expand Down Expand Up @@ -1613,9 +1617,13 @@ def update_health(self):

def update_pokedex(self):
# TODO: Make a hook
size = 0xD30A - 0xD2F7
caught_mem = self.pyboy.memory[0xD2F7 : 0xD2F7 + size]
seen_mem = self.pyboy.memory[0xD30A : 0xD30A + size]
_, wPokedexOwned = self.pyboy.symbol_lookup("wPokedexOwned")
_, wPokedexOwnedEnd = self.pyboy.symbol_lookup("wPokedexOwnedEnd")
_, wPokedexSeen = self.pyboy.symbol_lookup("wPokedexSeen")
_, wPokedexSeenEnd = self.pyboy.symbol_lookup("wPokedexSeenEnd")

caught_mem = self.pyboy.memory[wPokedexOwned:wPokedexOwnedEnd]
seen_mem = self.pyboy.memory[wPokedexSeen:wPokedexSeenEnd]
self.caught_pokemon = np.unpackbits(np.array(caught_mem, dtype=np.uint8))
self.seen_pokemon = np.unpackbits(np.array(seen_mem, dtype=np.uint8))

Expand Down Expand Up @@ -1709,7 +1717,7 @@ def get_levels_reward(self):

def get_required_events(self) -> set[str]:
return (
{event for event in REQUIRED_EVENTS if self.events.get_event(event)}
set(self.events.get_events(REQUIRED_EVENTS))
| ({"rival3"} if (self.read_m("wSSAnne2FCurScript") == 4) else set())
| (
{"game_corner_rocket"}
Expand Down
16 changes: 9 additions & 7 deletions pokemonred_puffer/rewards/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,20 +376,22 @@ def get_game_state_reward(self):
return (
{
"event": self.reward_config["event"] * self.update_max_event_rew(),
"seen_pokemon": self.reward_config["seen_pokemon"] * sum(self.seen_pokemon),
"caught_pokemon": self.reward_config["caught_pokemon"] * sum(self.caught_pokemon),
"moves_obtained": self.reward_config["moves_obtained"] * sum(self.moves_obtained),
"seen_pokemon": self.reward_config["seen_pokemon"] * np.sum(self.seen_pokemon),
"caught_pokemon": self.reward_config["caught_pokemon"]
* np.sum(self.caught_pokemon),
"moves_obtained": self.reward_config["moves_obtained"]
* np.sum(self.moves_obtained),
"hm_count": self.reward_config["hm_count"] * self.get_hm_count(),
"level": self.reward_config["level"] * self.get_levels_reward(),
"badges": self.reward_config["badges"] * self.get_badges(),
"cut_coords": self.reward_config["cut_coords"] * sum(self.cut_coords.values()),
"cut_tiles": self.reward_config["cut_tiles"] * sum(self.cut_tiles.values()),
"cut_coords": self.reward_config["cut_coords"] * np.sum(self.cut_coords.values()),
"cut_tiles": self.reward_config["cut_tiles"] * np.sum(self.cut_tiles.values()),
"start_menu": self.reward_config["start_menu"] * self.seen_start_menu,
"pokemon_menu": self.reward_config["pokemon_menu"] * self.seen_pokemon_menu,
"stats_menu": self.reward_config["stats_menu"] * self.seen_stats_menu,
"bag_menu": self.reward_config["bag_menu"] * self.seen_bag_menu,
"explore_hidden_objs": sum(self.seen_hidden_objs.values()),
"explore_signs": sum(self.seen_signs.values())
"explore_hidden_objs": np.sum(self.seen_hidden_objs.values()),
"explore_signs": np.sum(self.seen_signs.values())
* self.reward_config["explore_signs"],
"seen_action_bag_menu": self.seen_action_bag_menu
* self.reward_config["seen_action_bag_menu"],
Expand Down

0 comments on commit d6c9809

Please sign in to comment.