Skip to content

Commit

Permalink
Required events oh my
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jun 24, 2024
1 parent f22e52c commit 5e2784d
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 21 deletions.
20 changes: 20 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,26 @@ rewards:
rocket_hideout_found: 5.0
explore_hidden_objs: 0.02
seen_action_bag_menu: 0.1

baseline.CutWithObjectRewardsRequiredEventsEnv:
reward:
event: 1.0
seen_pokemon: 4.0
caught_pokemon: 4.0
moves_obtained: 4.0
hm_count: 10.0
level: 1.0
badges: 5.0
exploration: 0.02
cut_coords: 0.0
cut_tiles: 0.0
start_menu: 0.00
pokemon_menu: 0.0
stats_menu: 0.0
bag_menu: 0.1
explore_hidden_objs: 0.02
seen_action_bag_menu: 0.1
required_event: 5.0



Expand Down
72 changes: 72 additions & 0 deletions pokemonred_puffer/data/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -2583,3 +2583,75 @@ def __init__(self, emu: PyBoy):

def get_event(self, event_name: str) -> bool:
return bool(getattr(self.b, event_name))


REQUIRED_EVENTS = {
"EVENT_FOLLOWED_OAK_INTO_LAB",
"EVENT_PALLET_AFTER_GETTING_POKEBALLS",
"EVENT_FOLLOWED_OAK_INTO_LAB_2",
"EVENT_OAK_ASKED_TO_CHOOSE_MON",
"EVENT_GOT_STARTER",
"EVENT_BATTLED_RIVAL_IN_OAKS_LAB",
"EVENT_GOT_POKEDEX",
"EVENT_OAK_GOT_PARCEL",
"EVENT_GOT_OAKS_PARCEL",
"EVENT_BEAT_VIRIDIAN_GYM_GIOVANNI",
"EVENT_BEAT_BROCK",
"EVENT_BEAT_CERULEAN_RIVAL",
"EVENT_BEAT_CERULEAN_ROCKET_THIEF",
"EVENT_BEAT_MISTY",
"EVENT_GOT_BICYCLE",
"EVENT_BEAT_POKEMON_TOWER_RIVAL",
"EVENT_BEAT_GHOST_MAROWAK",
"EVENT_RESCUED_MR_FUJI_2",
"EVENT_GOT_POKE_FLUTE",
"EVENT_GOT_BIKE_VOUCHER",
"EVENT_2ND_LOCK_OPENED",
"EVENT_1ST_LOCK_OPENED",
"EVENT_BEAT_LT_SURGE",
"EVENT_BEAT_ERIKA",
"EVENT_FOUND_ROCKET_HIDEOUT",
"EVENT_GOT_HM04",
"EVENT_GAVE_GOLD_TEETH",
"EVENT_BEAT_KOGA",
"EVENT_BEAT_BLAINE",
"EVENT_BEAT_SABRINA",
# "EVENT_GOT_HM05",
"EVENT_FIGHT_ROUTE12_SNORLAX",
"EVENT_BEAT_ROUTE12_SNORLAX",
"EVENT_FIGHT_ROUTE16_SNORLAX",
"EVENT_BEAT_ROUTE16_SNORLAX",
"EVENT_GOT_HM02",
"EVENT_RESCUED_MR_FUJI",
"EVENT_2ND_ROUTE22_RIVAL_BATTLE",
"EVENT_BEAT_ROUTE22_RIVAL_2ND_BATTLE",
"EVENT_PASSED_CASCADEBADGE_CHECK",
"EVENT_PASSED_THUNDERBADGE_CHECK",
"EVENT_PASSED_RAINBOWBADGE_CHECK",
"EVENT_PASSED_SOULBADGE_CHECK",
"EVENT_PASSED_MARSHBADGE_CHECK",
"EVENT_PASSED_VOLCANOBADGE_CHECK",
"EVENT_PASSED_EARTHBADGE_CHECK",
"EVENT_USED_CELL_SEPARATOR_ON_BILL",
"EVENT_GOT_SS_TICKET",
"EVENT_MET_BILL_2",
"EVENT_BILL_SAID_USE_CELL_SEPARATOR",
"EVENT_LEFT_BILLS_HOUSE_AFTER_HELPING",
"EVENT_BEAT_MT_MOON_EXIT_SUPER_NERD",
"EVENT_GOT_DOME_FOSSIL",
"EVENT_GOT_HELIX_FOSSIL",
"EVENT_GOT_HM01",
"EVENT_RUBBED_CAPTAINS_BACK",
"EVENT_ROCKET_HIDEOUT_4_DOOR_UNLOCKED",
"EVENT_ROCKET_DROPPED_LIFT_KEY",
"EVENT_BEAT_ROCKET_HIDEOUT_GIOVANNI",
"EVENT_BEAT_SILPH_CO_RIVAL",
"EVENT_BEAT_SILPH_CO_GIOVANNI",
# "EVENT_GOT_HM03"
"EVENT_BEAT_LORELEIS_ROOM_TRAINER_0",
"EVENT_BEAT_BRUNOS_ROOM_TRAINER_0",
"EVENT_BEAT_AGATHAS_ROOM_TRAINER_0",
"EVENT_BEAT_LANCE",
"EVENT_BEAT_CHAMPION_RIVAL",
"ELITE4_CHAMPION_EVENTS_END",
}
31 changes: 10 additions & 21 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
EVENT_FLAGS_START,
EVENTS_FLAGS_LENGTH,
MUSEUM_TICKET,
REQUIRED_EVENTS,
EventFlags,
)
from pokemonred_puffer.data.field_moves import FieldMoves
Expand Down Expand Up @@ -1095,37 +1096,25 @@ def agent_stats(self, action):
"seen_pokemon": int(sum(self.seen_pokemon)),
"moves_obtained": int(sum(self.moves_obtained)),
"opponent_level": self.max_opponent_level,
"met_bill": int(self.events.get_event("EVENT_MET_BILL")),
"used_cell_separator_on_bill": int(
self.events.get_event("EVENT_USED_CELL_SEPARATOR_ON_BILL")
),
"ss_ticket": int(self.events.get_event("EVENT_GOT_SS_TICKET")),
"met_bill_2": int(self.events.get_event("EVENT_MET_BILL_2")),
"bill_said_use_cell_separator": int(
self.events.get_event("EVENT_BILL_SAID_USE_CELL_SEPARATOR")
),
"left_bills_house_after_helping": int(
self.events.get_event("EVENT_LEFT_BILLS_HOUSE_AFTER_HELPING")
),
"got_hm01": int(self.events.get_event("EVENT_GOT_HM01")),
"rubbed_captains_back": int(self.events.get_event("EVENT_RUBBED_CAPTAINS_BACK")),
"taught_cut": int(self.check_if_party_has_hm(0xF)),
"cut_coords": sum(self.cut_coords.values()),
"cut_tiles": len(self.cut_tiles),
"start_menu": self.seen_start_menu,
"pokemon_menu": self.seen_pokemon_menu,
"stats_menu": self.seen_stats_menu,
"bag_menu": self.seen_bag_menu,
"action_bag_menu": self.seen_action_bag_menu,
"menu": {
"start_menu": self.seen_start_menu,
"pokemon_menu": self.seen_pokemon_menu,
"stats_menu": self.seen_stats_menu,
"bag_menu": self.seen_bag_menu,
"action_bag_menu": self.seen_action_bag_menu,
},
"blackout_check": self.blackout_check,
"item_count": self.read_m(0xD31D),
"reset_count": self.reset_count,
"blackout_count": self.blackout_count,
"pokecenter": np.sum(self.pokecenters),
"rival3": int(self.read_m(0xD665) == 4),
"rocket_hideout_found": self.events.get_event("EVENT_FOUND_ROCKET_HIDEOUT"),
}
| {f"badge_{i+1}": bool(badges & (1 << i)) for i in range(8)},
"events": {self.events.get_event(event) for event in REQUIRED_EVENTS}
| {"rival3": int(self.read_m(0xD665) == 4)},
"reward": self.get_game_state_reward(),
"reward/reward_sum": sum(self.get_game_state_reward().values()),
"pokemon_exploration_map": explore_map,
Expand Down
38 changes: 38 additions & 0 deletions pokemonred_puffer/rewards/baseline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pufferlib
from pokemonred_puffer.data.events import REQUIRED_EVENTS
from pokemonred_puffer.environment import (
EVENT_FLAGS_START,
EVENTS_FLAGS_LENGTH,
Expand Down Expand Up @@ -221,3 +222,40 @@ def get_levels_reward(self):
return self.max_level_sum
else:
return 15 + (self.max_level_sum - 15) / 4


class CutWithObjectRewardRequiredEventsEnv(BaselineRewardEnv):
def get_game_state_reward(self):
return {
"event": self.reward_config["event"] * self.update_max_event_rew(),
"seen_pokemon": self.reward_config["seen_pokemon"] * sum(self.seen_pokemon),
"caught_pokemon": self.reward_config["caught_pokemon"] * sum(self.caught_pokemon),
"moves_obtained": self.reward_config["moves_obtained"] * sum(self.moves_obtained),
"hm_count": self.reward_config["hm_count"] * self.get_hm_count(),
"level": self.reward_config["level"] * self.get_levels_reward(),
"badges": self.reward_config["badges"] * self.get_badges(),
"exploration": self.reward_config["exploration"] * sum(self.seen_coords.values()),
"cut_coords": self.reward_config["cut_coords"] * sum(self.cut_coords.values()),
"cut_tiles": self.reward_config["cut_tiles"] * sum(self.cut_tiles.values()),
"start_menu": self.reward_config["start_menu"] * self.seen_start_menu,
"pokemon_menu": self.reward_config["pokemon_menu"] * self.seen_pokemon_menu,
"stats_menu": self.reward_config["stats_menu"] * self.seen_stats_menu,
"bag_menu": self.reward_config["bag_menu"] * self.seen_bag_menu,
"explore_hidden_objs": sum(self.seen_hidden_objs.values())
* self.reward_config["explore_hidden_objs"],
"seen_action_bag_menu": self.seen_action_bag_menu
* self.reward_config["seen_action_bag_menu"],
"rival3": self.reward_config["event"] * int(self.read_m("wSSAnne2FCurScript") == 4),
} | {
event: self.events.get_event(event) * self.reward_config["required_event"]
for event in REQUIRED_EVENTS
}

def get_levels_reward(self):
party_size = self.read_m("wPartyCount")
party_levels = [self.read_m(f"wPartyMon{i+1}Level") for i in range(party_size)]
self.max_level_sum = max(self.max_level_sum, sum(party_levels))
if self.max_level_sum < 15:
return self.max_level_sum
else:
return 15 + (self.max_level_sum - 15) / 4

0 comments on commit 5e2784d

Please sign in to comment.