diff --git a/config.yaml b/config.yaml
index 24ff19a..c02e688 100644
--- a/config.yaml
+++ b/config.yaml
@@ -157,6 +157,27 @@ rewards:
       explore_npcs: 0.02
       explore_hidden_objs: 0.02
 
+  baseline.RockTunnelReplicationEnv:
+    reward:
+      level: 1.0
+      exploration: 0.02
+      taught_cut: 10.0
+      event: 3.0
+      seen_pokemon: 4.0
+      caught_pokemon: 4.0
+      moves_obtained: 4.0
+      cut_coords: 1.0
+      cut_tiles: 1.0
+      start_menu: 0.005
+      pokemon_menu: 0.05
+      stats_menu: 0.05
+      bag_menu: 0.05
+      pokecenter: 5.0
+      # Really an addition to event reward
+      badges: 2.0
+      bill_saved: 2.0
+
+
 
 policies:
   multi_convolutional.MultiConvolutionalPolicy:
diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py
index 24894a6..2e3d72e 100644
--- a/pokemonred_puffer/environment.py
+++ b/pokemonred_puffer/environment.py
@@ -290,6 +290,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] =
             self.seen_pokemon = np.zeros(152, dtype=np.uint8)
             self.caught_pokemon = np.zeros(152, dtype=np.uint8)
             self.moves_obtained = np.zeros(0xA5, dtype=np.uint8)
+            self.pokecenters = np.zeros(255, dtype=np.uint8)
         # lazy random seed setting
         if not seed:
             seed = random.randint(0, 4096)
@@ -554,6 +555,7 @@ def step(self, action):
         if self.perfect_ivs:
             self.set_perfect_iv_dvs()
         self.taught_cut = self.check_if_party_has_cut()
+        self.pokecenters[self.read_m("wLastBlackoutMap")] = 1
 
         info = {}
         # TODO: Make log frequency a configuration parameter
@@ -638,9 +640,9 @@ def cut_hook(self, context):
                 ]:
                     self.cut_coords[coords] = 10
                 else:
-                    self.cut_coords[coords] = 0.01
+                    self.cut_coords[coords] = 0.001
             else:
-                self.cut_coords[coords] = 0.01
+                self.cut_coords[coords] = 0.001
 
         self.cut_explore_map[local_to_global(y, x, map_id)] = 1
         self.cut_tiles[wTileInFrontOfPlayer] = 1
@@ -875,8 +877,8 @@ def get_levels_reward(self):
         # Level reward
         party_levels = self.read_party()
         self.max_level_sum = max(self.max_level_sum, sum(party_levels))
-        if self.max_level_sum < 30:
+        if self.max_level_sum < 15:
             level_reward = 1 * self.max_level_sum
         else:
-            level_reward = 30 + (self.max_level_sum - 30) / 4
+            level_reward = 15 + (self.max_level_sum - 15) / 4
         return level_reward
diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py
index 33e03ee..a46603a 100644
--- a/pokemonred_puffer/rewards/baseline.py
+++ b/pokemonred_puffer/rewards/baseline.py
@@ -5,6 +5,8 @@
     RedGymEnv,
 )
 
+import numpy as np
+
 MUSEUM_TICKET = (0xD754, 0)
 
 
@@ -206,3 +208,45 @@ def get_levels_reward(self):
             return self.max_level_sum
         else:
             return 15 + (self.max_level_sum - 15) / 4
+
+
+class RockTunnelReplicationEnv(TeachCutReplicationEnv):
+    """Reward environment whose state reward includes pokecenter, badge and Bill-quest terms."""
+
+    def get_game_state_reward(self):
+        """Return a dict mapping each reward component name to its weighted value."""
+        # seen_coords and cut_coords are read through .values(); np.sum() on a dict view
+        # returns the view itself rather than a number, so the builtin sum() is used.
+        return {
+            "level": self.reward_config["level"] * self.get_levels_reward(),
+            "exploration": self.reward_config["exploration"] * sum(self.seen_coords.values()),
+            "taught_cut": self.reward_config["taught_cut"] * int(self.taught_cut),
+            "event": self.reward_config["event"] * self.update_max_event_rew(),
+            "seen_pokemon": self.reward_config["seen_pokemon"] * np.sum(self.seen_pokemon),
+            "caught_pokemon": self.reward_config["caught_pokemon"] * np.sum(self.caught_pokemon),
+            "moves_obtained": self.reward_config["moves_obtained"] * np.sum(self.moves_obtained),
+            "cut_coords": self.reward_config["cut_coords"] * sum(self.cut_coords.values()),
+            "cut_tiles": self.reward_config["cut_tiles"] * np.sum(self.cut_tiles),
+            "start_menu": (
+                self.reward_config["start_menu"] * self.seen_start_menu * int(self.taught_cut)
+            ),
+            "pokemon_menu": (
+                self.reward_config["pokemon_menu"] * self.seen_pokemon_menu * int(self.taught_cut)
+            ),
+            "stats_menu": (
+                self.reward_config["stats_menu"] * self.seen_stats_menu * int(self.taught_cut)
+            ),
+            "bag_menu": self.reward_config["bag_menu"] * self.seen_bag_menu * int(self.taught_cut),
+            "pokecenter": self.reward_config["pokecenter"] * np.sum(self.pokecenters),
+            "badges": self.reward_config["badges"] * self.get_badges(),
+            "met_bill": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F1, 0)),
+            "used_cell_separator_on_bill": self.reward_config["bill_saved"]
+            * int(self.read_bit(0xD7F2, 3)),
+            "ss_ticket": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 4)),
+            "met_bill_2": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 5)),
+            "bill_said_use_cell_separator": self.reward_config["bill_saved"]
+            * int(self.read_bit(0xD7F2, 6)),
+            "left_bills_house_after_helping": self.reward_config["bill_saved"]
+            * int(self.read_bit(0xD7F2, 7)),
+            "rival3": self.reward_config["event"] * int(self.read_m(0xD665) == 4),
+        }