diff --git a/config.yaml b/config.yaml index f30be7b..261561b 100644 --- a/config.yaml +++ b/config.yaml @@ -7,7 +7,7 @@ debug: env: headless: False stream_wrapper: False - init_state: cut3 + init_state: cut max_steps: 1_000_000 train: device: cpu @@ -157,25 +157,24 @@ rewards: explore_npcs: 0.02 explore_hidden_objs: 0.02 - baseline.RockTunnelReplicationEnv: + baseline.CutWithObjectRewardsEnv: reward: - level: 1.0 - exploration: 0.02 - taught_cut: 10.0 - event: 3.0 - seen_pokemon: 4.0 - caught_pokemon: 4.0 + event: 1.0 + bill_saved: 5.0 moves_obtained: 4.0 + hm_count: 10.0 + badges: 10.0 + exploration: 0.02 cut_coords: 1.0 cut_tiles: 1.0 - start_menu: 0.005 - pokemon_menu: 0.05 - stats_menu: 0.05 - bag_menu: 0.05 - pokecenter: 5.0 - # Really an addition to event reward - badges: 2.0 - bill_saved: 2.0 + start_menu: 0.0 + pokemon_menu: 0.0 + stats_menu: 0.0 + bag_menu: 0.0 + taught_cut: 10.0 + explore_npcs: 0.02 + explore_hidden_objs: 0.02 + rocket_hideout_found: 5.0 diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 2a618f5..e9531a2 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -287,6 +287,12 @@ def register_hooks(self): self.pyboy.hook_register( None, "CheckForHiddenObject.foundMatchingObject", self.hidden_object_hook, None ) + """ + _, addr = self.pyboy.symbol_lookup("IsSpriteOrSignInFrontOfPlayer.retry") + self.pyboy.hook_register( + None, addr-1, self.sign_hook, None + ) + """ self.pyboy.hook_register(None, "HandleBlackOut", self.blackout_hook, None) self.pyboy.hook_register(None, "SetLastBlackoutMap.done", self.blackout_update_hook, None) # self.pyboy.hook_register(None, "UsedCut.nothingToCut", self.cut_hook, context=True) @@ -306,6 +312,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = self.init_mem() # We only init seen hidden objs once cause they can only be found once! self.seen_hidden_objs = {} + self.seen_signs = {} if options.get("state", None) is not None: self.pyboy.load_state(io.BytesIO(options["state"])) self.reset_count += 1 @@ -553,7 +560,7 @@ def _get_obs(self): # "x": np.array(player_x, dtype=np.uint8), # "y": np.array(player_y, dtype=np.uint8), # "map_id": np.array(map_n, dtype=np.uint8), - "badges": np.array(self.read_m("wObtainedBadges"), dtype=np.uint8), + "badges": np.array(self.read_short("wObtainedBadges").bit_count(), dtype=np.uint8), } def set_perfect_iv_dvs(self): @@ -711,6 +718,12 @@ def cut_if_next(self): self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_A, delay=8) self.pyboy.tick(4 * self.action_freq, render=True) + def sign_hook(self, *args, **kwargs): + sign_id = self.pyboy.memory[self.pyboy.symbol_lookup("hSpriteIndexOrTextID")[1]] + map_id = self.pyboy.memory[self.pyboy.symbol_lookup("wCurMap")[1]] + # We will store this by map id, y, x, + self.seen_hidden_objs[(map_id, sign_id)] = 1 + def hidden_object_hook(self, *args, **kwargs): hidden_object_id = self.pyboy.memory[self.pyboy.symbol_lookup("wHiddenObjectIndex")[1]] map_id = self.pyboy.memory[self.pyboy.symbol_lookup("wCurMap")[1]] diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index 2da3d4a..07f2933 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -69,7 +69,6 @@ def __init__( self.register_buffer( "unpack_shift", torch.tensor([6, 4, 2, 0], dtype=torch.uint8), persistent=False ) - self.register_buffer("binary_mask", torch.tensor([2**i for i in range(8)])) def encode_observations(self, observations): observations = unpack_batched_obs(observations, self.unflatten_context) @@ -99,8 +98,7 @@ def encode_observations(self, observations): .flatten() .int(), ).reshape(restored_shape) - # > 0 doesn't risk a type conversion - badges = (observations["badges"] & self.binary_mask) > 0 + badges = (torch.arange(8) + 1) <= observations["badges"] image_observation = torch.cat((screen, visited_mask, global_map), dim=-1) if self.channels_last: diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py index c6da5d2..b914752 100644 --- a/pokemonred_puffer/rewards/baseline.py +++ b/pokemonred_puffer/rewards/baseline.py @@ -5,8 +5,6 @@ RedGymEnv, ) -import numpy as np - MUSEUM_TICKET = (0xD754, 0) @@ -165,16 +163,30 @@ def get_levels_reward(self): return 15 + (self.max_level_sum - 15) / 4 -class RockTunnelReplicationEnv(BaselineRewardEnv): +class CutWithObjectRewardsEnv(BaselineRewardEnv): def get_game_state_reward(self): return { - "level": self.reward_config["level"] * self.get_levels_reward(), - "exploration": self.reward_config["exploration"] * sum(self.seen_coords.values()), - "taught_cut": self.reward_config["taught_cut"] * int(self.taught_cut), "event": self.reward_config["event"] * self.update_max_event_rew(), - "seen_pokemon": self.reward_config["seen_pokemon"] * np.sum(self.seen_pokemon), - "caught_pokemon": self.reward_config["caught_pokemon"] * np.sum(self.caught_pokemon), - "moves_obtained": self.reward_config["moves_obtained"] * np.sum(self.moves_obtained), + "met_bill": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F1, 0)), + "used_cell_separator_on_bill": ( + self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 3)) + ), + "ss_ticket": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 4)), + "met_bill_2": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 5)), + "bill_said_use_cell_separator": ( + self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 6)) + ), + "left_bills_house_after_helping": ( + self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 7)) + ), + "moves_obtained": self.reward_config["moves_obtained"] * sum(self.moves_obtained), + "hm_count": self.reward_config["hm_count"] * self.get_hm_count(), + "badges": self.reward_config["badges"] * self.get_badges(), + "exploration": self.reward_config["exploration"] * sum(self.seen_coords.values()), + "explore_npcs": self.reward_config["explore_npcs"] * sum(self.seen_npcs.values()), + "explore_hidden_objs": ( + self.reward_config["explore_hidden_objs"] * sum(self.seen_hidden_objs.values()) + ), "cut_coords": self.reward_config["cut_coords"] * sum(self.cut_coords.values()), "cut_tiles": self.reward_config["cut_tiles"] * sum(self.cut_tiles), "start_menu": ( @@ -187,18 +199,12 @@ def get_game_state_reward(self): self.reward_config["stats_menu"] * self.seen_stats_menu * int(self.taught_cut) ), "bag_menu": self.reward_config["bag_menu"] * self.seen_bag_menu * int(self.taught_cut), - # "pokecenter": self.reward_config["pokecenter"] * np.sum(self.pokecenters), - "badges": self.reward_config["badges"] * self.get_badges(), - "met_bill": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F1, 0)), - "used_cell_separator_on_bill": self.reward_config["bill_saved"] - * int(self.read_bit(0xD7F2, 3)), - "ss_ticket": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 4)), - "met_bill_2": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 5)), - "bill_said_use_cell_separator": self.reward_config["bill_saved"] - * int(self.read_bit(0xD7F2, 6)), - "left_bills_house_after_helping": self.reward_config["bill_saved"] - * int(self.read_bit(0xD7F2, 7)), - "rival3": self.reward_config["event"] * int(self.read_m(0xD665) == 4), + "taught_cut": self.reward_config["taught_cut"] * int(self.taught_cut), + "seen_pokemon": self.reward_config["seen_pokemon"] * sum(self.seen_pokemon), + "caught_pokemon": self.reward_config["caught_pokemon"] * sum(self.caught_pokemon), + "level": self.reward_config["level"] * self.get_levels_reward(), + "rocket_hideout_found": self.reward_config["rocket_hideout_found"] + * int(self.read_bit(0xD77E, 1)), } def get_levels_reward(self):