diff --git a/config.yaml b/config.yaml index 020142c..d372b4d 100644 --- a/config.yaml +++ b/config.yaml @@ -78,6 +78,8 @@ env: exploration_inc: 1.0 exploration_max: 1.0 max_steps_scaling: 0 # 0.2 # every 10 events or items gained, multiply max_steps by 2 + map_id_scalefactor: 5.0 # multiply map ids whose events have not been completed by 5 + @@ -295,6 +297,33 @@ rewards: a_press: 0.0 # 0.00001 explore_warps: 0.05 use_surf: 0.05 + + baseline.ObjectRewardRequiredEventsMapIds: + reward: + event: 1.0 + seen_pokemon: 4.0 + caught_pokemon: 4.0 + moves_obtained: 4.0 + hm_count: 10.0 + level: 1.0 + badges: 5.0 + cut_coords: 0.0 + cut_tiles: 0.0 + start_menu: 0.0 + pokemon_menu: 0.0 + stats_menu: 0.0 + bag_menu: 0.0 + explore_hidden_objs: 0.01 + explore_signs: 0.015 + seen_action_bag_menu: 0.0 + required_event: 5.0 + required_item: 5.0 + useful_item: 1.0 + pokecenter_heal: 0.2 + exploration: 0.02 + a_press: 0.0 # 0.00001 + explore_warps: 0.01 + use_surf: 0.5 diff --git a/pokemonred_puffer/data/map.py b/pokemonred_puffer/data/map.py index 5a7a347..3ce0f83 100644 --- a/pokemonred_puffer/data/map.py +++ b/pokemonred_puffer/data/map.py @@ -271,3 +271,36 @@ class MapIds(Enum): 0x10, # Route 10 (Rock Tunnel) 0xE9, # Silph Co 9F (Heal station) } + +MAP_ID_COMPLETION_EVENTS = { + MapIds.PEWTER_GYM: "EVENT_BEAT_BROCK", + MapIds.CERULEAN_GYM: "EVENT_BEAT_MISTY", + MapIds.VERMILION_GYM: "EVENT_BEAT_LT_SURGE", + MapIds.CELADON_GYM: "EVENT_BEAT_ERIKA", + MapIds.SAFFRON_GYM: "EVENT_BEAT_SABRINA", + MapIds.FUCHSIA_GYM: "EVENT_BEAT_KOGA", + MapIds.CINNABAR_GYM: "EVENT_BEAT_BLAINE", + MapIds.VIRIDIAN_GYM: "EVENT_BEAT_VIRIDIAN_GYM_GIOVANNI", + MapIds.GAME_CORNER: "EVENT_BEAT_ROCKET_HIDEOUT_GIOVANNI", + MapIds.ROCKET_HIDEOUT_B1F: "EVENT_BEAT_ROCKET_HIDEOUT_GIOVANNI", + MapIds.ROCKET_HIDEOUT_B2F: "EVENT_BEAT_ROCKET_HIDEOUT_GIOVANNI", + MapIds.ROCKET_HIDEOUT_B3F: "EVENT_BEAT_ROCKET_HIDEOUT_GIOVANNI", + MapIds.ROCKET_HIDEOUT_B4F: "EVENT_BEAT_ROCKET_HIDEOUT_GIOVANNI", + MapIds.ROCKET_HIDEOUT_ELEVATOR: "EVENT_BEAT_ROCKET_HIDEOUT_GIOVANNI", + MapIds.SILPH_CO_1F: "EVENT_BEAT_SILPH_CO_GIOVANNI", + MapIds.SILPH_CO_2F: "EVENT_BEAT_SILPH_CO_GIOVANNI", + MapIds.SILPH_CO_3F: "EVENT_BEAT_SILPH_CO_GIOVANNI", + MapIds.SILPH_CO_4F: "EVENT_BEAT_SILPH_CO_GIOVANNI", + MapIds.SILPH_CO_5F: "EVENT_BEAT_SILPH_CO_GIOVANNI", + MapIds.SILPH_CO_6F: "EVENT_BEAT_SILPH_CO_GIOVANNI", + MapIds.SILPH_CO_7F: "EVENT_BEAT_SILPH_CO_GIOVANNI", + MapIds.SILPH_CO_8F: "EVENT_BEAT_SILPH_CO_GIOVANNI", + MapIds.SILPH_CO_9F: "EVENT_BEAT_SILPH_CO_GIOVANNI", + MapIds.SILPH_CO_10F: "EVENT_BEAT_SILPH_CO_GIOVANNI", + MapIds.SILPH_CO_11F: "EVENT_BEAT_SILPH_CO_GIOVANNI", + MapIds.SILPH_CO_ELEVATOR: "EVENT_BEAT_SILPH_CO_GIOVANNI", + MapIds.POKEMON_MANSION_1F: "HS_POKEMON_MANSION_B1F_ITEM_5", + MapIds.POKEMON_MANSION_2F: "HS_POKEMON_MANSION_B1F_ITEM_5", + MapIds.POKEMON_MANSION_3F: "HS_POKEMON_MANSION_B1F_ITEM_5", + MapIds.POKEMON_MANSION_B1F: "HS_POKEMON_MANSION_B1F_ITEM_5", +} diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 2baddb1..a07febe 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -33,7 +33,7 @@ USEFUL_ITEMS, Items, ) -from pokemonred_puffer.data.map import MapIds +from pokemonred_puffer.data.map import MAP_ID_COMPLETION_EVENTS, MapIds from pokemonred_puffer.data.missable_objects import MissableFlags from pokemonred_puffer.data.party import PartyMons from pokemonred_puffer.data.strength_puzzles import STRENGTH_SOLUTIONS @@ -130,6 +130,7 @@ def __init__(self, env_config: pufferlib.namespace): self.exploration_inc = env_config.exploration_inc self.exploration_max = env_config.exploration_max self.max_steps_scaling = env_config.max_steps_scaling + self.map_id_scalefactor = env_config.map_id_scalefactor self.action_space = ACTION_SPACE # Obs space-related. TODO: avoid hardcoding? @@ -343,6 +344,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = self.caught_pokemon.fill(0) self.moves_obtained.fill(0) self.explore_map *= 0 + self.reward_explore_map *= 0 self.cut_explore_map *= 0 self.reset_mem() @@ -388,6 +390,7 @@ def init_mem(self): # All map ids have the same size, right? self.seen_coords: dict[int, dict[tuple[int, int, int], int]] = {} self.explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32) + self.reward_explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32) self.cut_explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32) self.seen_map_ids = np.zeros(256) self.seen_npcs = {} @@ -1466,6 +1469,10 @@ def update_seen_coords(self): self.explore_map[local_to_global(y_pos, x_pos, map_n)] + inc, self.exploration_max, ) + self.reward_explore_map[local_to_global(y_pos, x_pos, map_n)] = min( + self.explore_map[local_to_global(y_pos, x_pos, map_n)] + inc, + self.exploration_max, + ) * self.map_id_scaling(map_n) # self.seen_global_coords[local_to_global(y_pos, x_pos, map_n)] = 1 self.seen_map_ids[map_n] = 1 @@ -1687,3 +1694,20 @@ def get_events_sum(self): - int(self.read_bit(*MUSEUM_TICKET)), 0, ) + + def map_id_scaling(self, map_n: int) -> float: + map_id = MapIds(map_n) + if map_id not in MAP_ID_COMPLETION_EVENTS: + return 1.0 + + event_or_missable = MAP_ID_COMPLETION_EVENTS[map_id] + if ( + event_or_missable.startswith("EVENT_") + and not self.events.get_event(event_or_missable) + or ( + event_or_missable.startswith("HS_") + and not self.missables.get_missable(event_or_missable) + ) + ): + return self.map_id_scalefactor + return 1.0 diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py index c67e491..23581d3 100644 --- a/pokemonred_puffer/rewards/baseline.py +++ b/pokemonred_puffer/rewards/baseline.py @@ -361,3 +361,65 @@ def get_levels_reward(self): return self.max_level_sum else: return 15 + (self.max_level_sum - 15) / 4 + + +class ObjectRewardRequiredEventsMapIds(BaselineRewardEnv): + def get_game_state_reward(self): + _, wBagItems = self.pyboy.symbol_lookup("wBagItems") + numBagItems = self.read_m("wNumBagItems") + bag_item_ids = set(self.pyboy.memory[wBagItems : wBagItems + 2 * numBagItems : 2]) + + return ( + { + "event": self.reward_config["event"] * self.update_max_event_rew(), + "seen_pokemon": self.reward_config["seen_pokemon"] * sum(self.seen_pokemon), + "caught_pokemon": self.reward_config["caught_pokemon"] * sum(self.caught_pokemon), + "moves_obtained": self.reward_config["moves_obtained"] * sum(self.moves_obtained), + "hm_count": self.reward_config["hm_count"] * self.get_hm_count(), + "level": self.reward_config["level"] * self.get_levels_reward(), + "badges": self.reward_config["badges"] * self.get_badges(), + "cut_coords": self.reward_config["cut_coords"] * sum(self.cut_coords.values()), + "cut_tiles": self.reward_config["cut_tiles"] * sum(self.cut_tiles.values()), + "start_menu": self.reward_config["start_menu"] * self.seen_start_menu, + "pokemon_menu": self.reward_config["pokemon_menu"] * self.seen_pokemon_menu, + "stats_menu": self.reward_config["stats_menu"] * self.seen_stats_menu, + "bag_menu": self.reward_config["bag_menu"] * self.seen_bag_menu, + "explore_hidden_objs": sum(self.seen_hidden_objs.values()), + "explore_signs": sum(self.seen_signs.values()) + * self.reward_config["explore_signs"], + "seen_action_bag_menu": self.seen_action_bag_menu + * self.reward_config["seen_action_bag_menu"], + "pokecenter_heal": self.pokecenter_heal * self.reward_config["pokecenter_heal"], + "rival3": self.reward_config["required_event"] + * int(self.read_m("wSSAnne2FCurScript") == 4), + "game_corner_rocket": self.reward_config["required_event"] + * float(self.missables.get_missable("HS_GAME_CORNER_ROCKET")), + "saffron_guard": self.reward_config["required_event"] + * float(self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK")), + "a_press": len(self.a_press) * self.reward_config["a_press"], + "warps": len(self.seen_warps) * self.reward_config["explore_warps"], + "use_surf": self.reward_config["use_surf"] * self.use_surf, + "exploration": self.reward_config["exploration"] * np.sum(self.reward_explore_map), + } + | { + event: self.reward_config["required_event"] * float(self.events.get_event(event)) + for event in REQUIRED_EVENTS + } + | { + item.name: self.reward_config["required_item"] * float(item.value in bag_item_ids) + for item in REQUIRED_ITEMS + } + | { + item.name: self.reward_config["useful_item"] * float(item.value in bag_item_ids) + for item in USEFUL_ITEMS + } + ) + + def get_levels_reward(self): + party_size = self.read_m("wPartyCount") + party_levels = [self.read_m(f"wPartyMon{i+1}Level") for i in range(party_size)] + self.max_level_sum = max(self.max_level_sum, sum(party_levels)) + if self.max_level_sum < 15: + return self.max_level_sum + else: + return 15 + (self.max_level_sum - 15) / 4