diff --git a/config.yaml b/config.yaml index 4242895..b1be210 100644 --- a/config.yaml +++ b/config.yaml @@ -228,6 +228,26 @@ rewards: rocket_hideout_found: 5.0 explore_hidden_objs: 0.02 seen_action_bag_menu: 0.1 + + baseline.CutWithObjectRewardsRequiredEventsEnv: + reward: + event: 1.0 + seen_pokemon: 4.0 + caught_pokemon: 4.0 + moves_obtained: 4.0 + hm_count: 10.0 + level: 1.0 + badges: 5.0 + exploration: 0.02 + cut_coords: 0.0 + cut_tiles: 0.0 + start_menu: 0.00 + pokemon_menu: 0.0 + stats_menu: 0.0 + bag_menu: 0.1 + explore_hidden_objs: 0.02 + seen_action_bag_menu: 0.1 + required_event: 5.0 diff --git a/pokemonred_puffer/data/events.py b/pokemonred_puffer/data/events.py index c21468c..fdf46ba 100644 --- a/pokemonred_puffer/data/events.py +++ b/pokemonred_puffer/data/events.py @@ -2583,3 +2583,75 @@ def __init__(self, emu: PyBoy): def get_event(self, event_name: str) -> bool: return bool(getattr(self.b, event_name)) + + +REQUIRED_EVENTS = { + "EVENT_FOLLOWED_OAK_INTO_LAB", + "EVENT_PALLET_AFTER_GETTING_POKEBALLS", + "EVENT_FOLLOWED_OAK_INTO_LAB_2", + "EVENT_OAK_ASKED_TO_CHOOSE_MON", + "EVENT_GOT_STARTER", + "EVENT_BATTLED_RIVAL_IN_OAKS_LAB", + "EVENT_GOT_POKEDEX", + "EVENT_OAK_GOT_PARCEL", + "EVENT_GOT_OAKS_PARCEL", + "EVENT_BEAT_VIRIDIAN_GYM_GIOVANNI", + "EVENT_BEAT_BROCK", + "EVENT_BEAT_CERULEAN_RIVAL", + "EVENT_BEAT_CERULEAN_ROCKET_THIEF", + "EVENT_BEAT_MISTY", + "EVENT_GOT_BICYCLE", + "EVENT_BEAT_POKEMON_TOWER_RIVAL", + "EVENT_BEAT_GHOST_MAROWAK", + "EVENT_RESCUED_MR_FUJI_2", + "EVENT_GOT_POKE_FLUTE", + "EVENT_GOT_BIKE_VOUCHER", + "EVENT_2ND_LOCK_OPENED", + "EVENT_1ST_LOCK_OPENED", + "EVENT_BEAT_LT_SURGE", + "EVENT_BEAT_ERIKA", + "EVENT_FOUND_ROCKET_HIDEOUT", + "EVENT_GOT_HM04", + "EVENT_GAVE_GOLD_TEETH", + "EVENT_BEAT_KOGA", + "EVENT_BEAT_BLAINE", + "EVENT_BEAT_SABRINA", + # "EVENT_GOT_HM05", + "EVENT_FIGHT_ROUTE12_SNORLAX", + "EVENT_BEAT_ROUTE12_SNORLAX", + "EVENT_FIGHT_ROUTE16_SNORLAX", + "EVENT_BEAT_ROUTE16_SNORLAX", + "EVENT_GOT_HM02", + "EVENT_RESCUED_MR_FUJI", + "EVENT_2ND_ROUTE22_RIVAL_BATTLE", + "EVENT_BEAT_ROUTE22_RIVAL_2ND_BATTLE", + "EVENT_PASSED_CASCADEBADGE_CHECK", + "EVENT_PASSED_THUNDERBADGE_CHECK", + "EVENT_PASSED_RAINBOWBADGE_CHECK", + "EVENT_PASSED_SOULBADGE_CHECK", + "EVENT_PASSED_MARSHBADGE_CHECK", + "EVENT_PASSED_VOLCANOBADGE_CHECK", + "EVENT_PASSED_EARTHBADGE_CHECK", + "EVENT_USED_CELL_SEPARATOR_ON_BILL", + "EVENT_GOT_SS_TICKET", + "EVENT_MET_BILL_2", + "EVENT_BILL_SAID_USE_CELL_SEPARATOR", + "EVENT_LEFT_BILLS_HOUSE_AFTER_HELPING", + "EVENT_BEAT_MT_MOON_EXIT_SUPER_NERD", + "EVENT_GOT_DOME_FOSSIL", + "EVENT_GOT_HELIX_FOSSIL", + "EVENT_GOT_HM01", + "EVENT_RUBBED_CAPTAINS_BACK", + "EVENT_ROCKET_HIDEOUT_4_DOOR_UNLOCKED", + "EVENT_ROCKET_DROPPED_LIFT_KEY", + "EVENT_BEAT_ROCKET_HIDEOUT_GIOVANNI", + "EVENT_BEAT_SILPH_CO_RIVAL", + "EVENT_BEAT_SILPH_CO_GIOVANNI", + # "EVENT_GOT_HM03" + "EVENT_BEAT_LORELEIS_ROOM_TRAINER_0", + "EVENT_BEAT_BRUNOS_ROOM_TRAINER_0", + "EVENT_BEAT_AGATHAS_ROOM_TRAINER_0", + "EVENT_BEAT_LANCE", + "EVENT_BEAT_CHAMPION_RIVAL", + "ELITE4_CHAMPION_EVENTS_END", +} diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 870614e..b871477 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -20,6 +20,7 @@ EVENT_FLAGS_START, EVENTS_FLAGS_LENGTH, MUSEUM_TICKET, + REQUIRED_EVENTS, EventFlags, ) from pokemonred_puffer.data.field_moves import FieldMoves @@ -1095,37 +1096,25 @@ def agent_stats(self, action): "seen_pokemon": int(sum(self.seen_pokemon)), "moves_obtained": int(sum(self.moves_obtained)), "opponent_level": self.max_opponent_level, - "met_bill": int(self.events.get_event("EVENT_MET_BILL")), - "used_cell_separator_on_bill": int( - self.events.get_event("EVENT_USED_CELL_SEPARATOR_ON_BILL") - ), - "ss_ticket": int(self.events.get_event("EVENT_GOT_SS_TICKET")), - "met_bill_2": int(self.events.get_event("EVENT_MET_BILL_2")), - "bill_said_use_cell_separator": int( - self.events.get_event("EVENT_BILL_SAID_USE_CELL_SEPARATOR") - ), - "left_bills_house_after_helping": int( - self.events.get_event("EVENT_LEFT_BILLS_HOUSE_AFTER_HELPING") - ), - "got_hm01": int(self.events.get_event("EVENT_GOT_HM01")), - "rubbed_captains_back": int(self.events.get_event("EVENT_RUBBED_CAPTAINS_BACK")), "taught_cut": int(self.check_if_party_has_hm(0xF)), "cut_coords": sum(self.cut_coords.values()), "cut_tiles": len(self.cut_tiles), - "start_menu": self.seen_start_menu, - "pokemon_menu": self.seen_pokemon_menu, - "stats_menu": self.seen_stats_menu, - "bag_menu": self.seen_bag_menu, - "action_bag_menu": self.seen_action_bag_menu, + "menu": { + "start_menu": self.seen_start_menu, + "pokemon_menu": self.seen_pokemon_menu, + "stats_menu": self.seen_stats_menu, + "bag_menu": self.seen_bag_menu, + "action_bag_menu": self.seen_action_bag_menu, + }, "blackout_check": self.blackout_check, "item_count": self.read_m(0xD31D), "reset_count": self.reset_count, "blackout_count": self.blackout_count, "pokecenter": np.sum(self.pokecenters), - "rival3": int(self.read_m(0xD665) == 4), - "rocket_hideout_found": self.events.get_event("EVENT_FOUND_ROCKET_HIDEOUT"), } | {f"badge_{i+1}": bool(badges & (1 << i)) for i in range(8)}, + "events": {self.events.get_event(event) for event in REQUIRED_EVENTS} + | {"rival3": int(self.read_m(0xD665) == 4)}, "reward": self.get_game_state_reward(), "reward/reward_sum": sum(self.get_game_state_reward().values()), "pokemon_exploration_map": explore_map, diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py index fe93469..a5d0dfb 100644 --- a/pokemonred_puffer/rewards/baseline.py +++ b/pokemonred_puffer/rewards/baseline.py @@ -1,4 +1,5 @@ import pufferlib +from pokemonred_puffer.data.events import REQUIRED_EVENTS from pokemonred_puffer.environment import ( EVENT_FLAGS_START, EVENTS_FLAGS_LENGTH, @@ -221,3 +222,40 @@ def get_levels_reward(self): return self.max_level_sum else: return 15 + (self.max_level_sum - 15) / 4 + + +class CutWithObjectRewardRequiredEventsEnv(BaselineRewardEnv): + def get_game_state_reward(self): + return { + "event": self.reward_config["event"] * self.update_max_event_rew(), + "seen_pokemon": self.reward_config["seen_pokemon"] * sum(self.seen_pokemon), + "caught_pokemon": self.reward_config["caught_pokemon"] * sum(self.caught_pokemon), + "moves_obtained": self.reward_config["moves_obtained"] * sum(self.moves_obtained), + "hm_count": self.reward_config["hm_count"] * self.get_hm_count(), + "level": self.reward_config["level"] * self.get_levels_reward(), + "badges": self.reward_config["badges"] * self.get_badges(), + "exploration": self.reward_config["exploration"] * sum(self.seen_coords.values()), + "cut_coords": self.reward_config["cut_coords"] * sum(self.cut_coords.values()), + "cut_tiles": self.reward_config["cut_tiles"] * sum(self.cut_tiles.values()), + "start_menu": self.reward_config["start_menu"] * self.seen_start_menu, + "pokemon_menu": self.reward_config["pokemon_menu"] * self.seen_pokemon_menu, + "stats_menu": self.reward_config["stats_menu"] * self.seen_stats_menu, + "bag_menu": self.reward_config["bag_menu"] * self.seen_bag_menu, + "explore_hidden_objs": sum(self.seen_hidden_objs.values()) + * self.reward_config["explore_hidden_objs"], + "seen_action_bag_menu": self.seen_action_bag_menu + * self.reward_config["seen_action_bag_menu"], + "rival3": self.reward_config["event"] * int(self.read_m("wSSAnne2FCurScript") == 4), + } | { + event: self.events.get_event(event) * self.reward_config["required_event"] + for event in REQUIRED_EVENTS + } + + def get_levels_reward(self): + party_size = self.read_m("wPartyCount") + party_levels = [self.read_m(f"wPartyMon{i+1}Level") for i in range(party_size)] + self.max_level_sum = max(self.max_level_sum, sum(party_levels)) + if self.max_level_sum < 15: + return self.max_level_sum + else: + return 15 + (self.max_level_sum - 15) / 4