From 7e91cc7d40f809f9913d1f73a72eaa5f059e337f Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Thu, 21 Mar 2024 00:54:17 -0400 Subject: [PATCH 1/5] repro? --- config.yaml | 40 +++++++++++++++----- pokemonred_puffer/environment.py | 54 +++++++++++++++++---------- pokemonred_puffer/rewards/baseline.py | 36 ++++++++++++++++-- 3 files changed, 99 insertions(+), 31 deletions(-) diff --git a/config.yaml b/config.yaml index df64a8a..344ac50 100644 --- a/config.yaml +++ b/config.yaml @@ -33,7 +33,7 @@ env: state_dir: pyboy_states init_state: Bulbasaur action_freq: 24 - max_steps: 100_000_000 + max_steps: 20480 save_video: False fast_video: False frame_stacks: 1 @@ -48,9 +48,9 @@ train: compile_mode: "reduce-overhead" float32_matmul_precision: "high" total_timesteps: 100_000_000_000 - batch_size: 32768 - learning_rate: 2.5e-4 - anneal_lr: True + batch_size: 65536 + learning_rate: 2.0e-4 + anneal_lr: False gamma: 0.998 gae_lambda: 0.95 num_minibatches: 4 @@ -63,12 +63,12 @@ train: max_grad_norm: 0.5 target_kl: ~ batch_rows: 128 - bptt_horizon: 16 #8 - vf_clip_coef: 0.12 # 0.1 + bptt_horizon: 16 + vf_clip_coef: 0.1 - num_envs: 150 - envs_per_worker: 2 - envs_per_batch: 60 + num_envs: 96 + envs_per_worker: 1 + envs_per_batch: 32 env_pool: True verbose: True @@ -103,9 +103,31 @@ wrappers: - exploration.MaxLengthWrapper: capacity: 1750 + stream_only: + - stream_wrapper.StreamWrapper: + user: thatguy + rewards: baseline.BaselineRewardEnv: reward: + baseline.TeachCutReplicationEnv: + reward: + event: 1.0 + bill_saved: 5.0 + seen_pokemon: 4.0 + caught_pokemon: 4.0 + moves_obtained: 4.0 + hm_count: 10.0 + level: 1.0 + badges: 10.0 + exploration: 0.02 + cut_coords: 1.0 + cut_tiles: 1.0 + start_menu: 0.01 + pokemon_menu: 0.1 + stats_menu: 0.1 + bag_menu: 0.1 + policies: multi_convolutional.MultiConvolutionalPolicy: diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index ac60f6c..48a00d4 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -4,7 +4,7 @@ from collections import deque from multiprocessing import Lock, shared_memory from pathlib import Path -from typing import Optional +from typing import Iterable, Optional import uuid import mediapy as media @@ -83,6 +83,8 @@ ] ) +HM_ITEM_IDS = set([0xC4, 0xC5, 0xC6, 0xC7, 0xC8]) + RESET_MAP_IDS = set( [ 0x0, # Pallet Town @@ -242,23 +244,22 @@ def reset(self, seed: Optional[int] = None): self.explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32) self.init_mem() self.reset_count = 0 + with open(self.init_state_path, "rb") as f: + self.pyboy.load_state(f) + # lazy random seed setting + if not seed: + seed = random.randint(0, 4096) + self.pyboy.tick(seed, render=False) else: - self.recent_screens.clear() - self.recent_actions.clear() - self.seen_pokemon.fill(0) - self.caught_pokemon.fill(0) - self.moves_obtained.fill(0) - self.explore_map *= 0 - self.reset_mem() self.reset_count += 1 - with open(self.init_state_path, "rb") as f: - self.pyboy.load_state(f) - - # lazy random seed setting - if not seed: - seed = random.randint(0, 4096) - self.pyboy.tick(seed, render=False) + self.recent_screens.clear() + self.recent_actions.clear() + self.seen_pokemon.fill(0) + self.caught_pokemon.fill(0) + self.moves_obtained.fill(0) + self.explore_map *= 0 + self.reset_mem() self.taught_cut = self.check_if_party_has_cut() self.base_event_flags = sum( @@ -284,9 +285,6 @@ def reset(self, seed: Optional[int] = None): self.action_hist = np.zeros(len(VALID_ACTIONS)) - # experiment! - # self.max_steps += 128 - self.max_map_progress = 0 self.progress_reward = self.get_game_state_reward() self.total_reward = sum([val for _, val in self.progress_reward.items()]) @@ -497,7 +495,7 @@ def step(self, action): self.step_count += 1 reset = ( - self.step_count > self.max_steps # or + self.step_count >= self.max_steps # or # self.caught_pokemon[6] == 1 # squirtle ) @@ -799,3 +797,21 @@ def get_map_progress(self, map_idx): return self.essential_map_locations[map_idx] else: return -1 + + def get_items_in_bag(self) -> Iterable[int]: + num_bag_items = self.read_m("wNumBagItems") + _, addr = self.pyboy.symbol_lookup("wBagItems") + return self.pyboy.memory[addr : addr + 2 * num_bag_items][::2] + + def get_hm_count(self) -> int: + return len(HM_ITEM_IDS.intersection(self.get_items_in_bag())) + + def get_levels_reward(self): + # Level reward + party_levels = self.read_party() + self.max_level_sum = max(self.max_level_sum, sum(party_levels)) + if self.max_level_sum < 30: + level_reward = 1 * self.max_level_sum + else: + level_reward = 30 + (self.max_level_sum - 30) / 4 + return level_reward diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py index 9a5bab8..adcf3ec 100644 --- a/pokemonred_puffer/rewards/baseline.py +++ b/pokemonred_puffer/rewards/baseline.py @@ -12,9 +12,6 @@ class BaselineRewardEnv(RedGymEnv): def __init__(self, env_config: pufferlib.namespace, reward_config: pufferlib.namespace): super().__init__(env_config) - def step(self, action): - return super().step(action) - # TODO: make the reward weights configurable def get_game_state_reward(self): # addresses from https://datacrystal.romhacking.net/wiki/Pok%C3%A9mon_Red/Blue:RAM_map @@ -80,3 +77,36 @@ def get_levels_reward(self): return self.max_level_sum else: return 15 + (self.max_level_sum - 15) / 4 + + +class TeachCutReplicationEnv(RedGymEnv): + def __init__(self, env_config: pufferlib.namespace, reward_config: pufferlib.namespace): + super().__init__(env_config) + self.reward_config = reward_config + + def get_game_state_reward(self): + return { + "event": self.reward_config.event * self.update_max_event_rew(), + "met_bill": self.reward_config.bill_saved * int(self.read_bit(0xD7F1, 0)), + "used_cell_separator_on_bill": self.reward_config.bill_saved + * int(self.read_bit(0xD7F2, 3)), + "ss_ticket": self.reward_config.bill_saved * int(self.read_bit(0xD7F2, 4)), + "met_bill_2": self.reward_config.bill_saved * int(self.read_bit(0xD7F2, 5)), + "bill_said_use_cell_separator": self.reward_config.bill_saved + * int(self.read_bit(0xD7F2, 6)), + "left_bills_house_after_helping": self.reward_config.bill_saved + * int(self.read_bit(0xD7F2, 7)), + "seen_pokemon": self.reward_config.seen_pokemon * sum(self.seen_pokemon), + "caught_pokemon": self.reward_config.caught_pokemon * sum(self.caught_pokemon), + "moves_obtained": self.reward_config.moves_obtained * sum(self.moves_obtained), + "hm_count": self.reward_config.hm_count * self.get_hm_count(), + "level": self.reward_config.level * self.get_levels_reward(), + "badges": self.reward_config.badges * self.get_badges(), + "exploration": self.reward_config.exploration * len(self.seen_coords), + "cut_coords": self.reward_config.cut_coords * sum(self.cut_coords.values()), + "cut_tiles": self.reward_config.cut_tiles * sum(self.cut_tiles), + "start_menu": self.reward_config.start_menu * self.seen_start_menu, + "pokemon_menu": self.reward_config.pokemon_menu * self.seen_pokemon_menu, + "stats_menu": self.reward_config.stats_menu * self.seen_stats_menu, + "bag_menu": self.reward_config.bag_menu * self.seen_bag_menu, + } From 5dbf8fd31df9c756ce8448e724a08ad3324a3dc3 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Thu, 21 Mar 2024 01:02:03 -0400 Subject: [PATCH 2/5] Remove obs --- pokemonred_puffer/environment.py | 12 ++++++------ pokemonred_puffer/policies/multi_convolutional.py | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 48a00d4..4b08b09 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -179,13 +179,13 @@ def __init__(self, env_config: pufferlib.namespace): ), # Discrete is more apt, but pufferlib is slower at processing Discrete "direction": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8), - "reset_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), + # "reset_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), "battle_type": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8), - "cut_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), - "x": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), - "y": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), - "map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), - "badges": spaces.Box(low=0, high=8, shape=(1,), dtype=np.uint8), + # "cut_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), + # "x": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), + # "y": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), + # "map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), + # "badges": spaces.Box(low=0, high=8, shape=(1,), dtype=np.uint8), } ) diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index 9daba69..55dd91b 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -97,13 +97,13 @@ def encode_observations(self, observations): ( *output, one_hot(observations["direction"].long(), 4).float().squeeze(1), - one_hot(observations["reset_map_id"].long(), 0xF7).float().squeeze(1), + # one_hot(observations["reset_map_id"].long(), 0xF7).float().squeeze(1), one_hot(observations["battle_type"].long(), 4).float().squeeze(1), - observations["cut_in_party"].float(), - observations["x"].float(), - observations["y"].float(), - one_hot(observations["map_id"].long(), 0xF7).float().squeeze(1), - one_hot(observations["badges"].long(), 8).float().squeeze(1), + # observations["cut_in_party"].float(), + # observations["x"].float(), + # observations["y"].float(), + # one_hot(observations["map_id"].long(), 0xF7).float().squeeze(1), + # one_hot(observations["badges"].long(), 8).float().squeeze(1), ), dim=-1, ) From 972f17707ee1dbec59eaadc54033cca59a4e7aa1 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Thu, 21 Mar 2024 01:05:22 -0400 Subject: [PATCH 3/5] More remvoe obs --- pokemonred_puffer/environment.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 4b08b09..189623c 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -444,13 +444,13 @@ def _get_obs(self): "direction": np.array( self.read_m("wSpritePlayerStateData1FacingDirection") // 4, dtype=np.uint8 ), - "reset_map_id": np.array(self.read_m("wLastBlackoutMap"), dtype=np.uint8), + # "reset_map_id": np.array(self.read_m("wLastBlackoutMap"), dtype=np.uint8), "battle_type": np.array(self.read_m("wIsInBattle") + 1, dtype=np.uint8), - "cut_in_party": np.array(self.check_if_party_has_cut(), dtype=np.uint8), - "x": np.array(player_x, dtype=np.uint8), - "y": np.array(player_y, dtype=np.uint8), - "map_id": np.array(map_n, dtype=np.uint8), - "badges": np.array(self.get_badges(), dtype=np.uint8), + # "cut_in_party": np.array(self.check_if_party_has_cut(), dtype=np.uint8), + # "x": np.array(player_x, dtype=np.uint8), + # "y": np.array(player_y, dtype=np.uint8), + # "map_id": np.array(map_n, dtype=np.uint8), + # "badges": np.array(self.get_badges(), dtype=np.uint8), } def set_perfect_iv_dvs(self): From 6d8f87dac4bd524171f00f37ea7ca920004294da Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Thu, 21 Mar 2024 15:58:14 -0400 Subject: [PATCH 4/5] Fixes --- config.yaml | 1 + pokemonred_puffer/cleanrl_puffer.py | 2 +- pokemonred_puffer/environment.py | 16 +++++++- pokemonred_puffer/rewards/baseline.py | 59 ++++++++++++++++++--------- 4 files changed, 55 insertions(+), 23 deletions(-) diff --git a/config.yaml b/config.yaml index 344ac50..a8ea1f5 100644 --- a/config.yaml +++ b/config.yaml @@ -79,6 +79,7 @@ train: overlay_interval: 200 cpu_offload: True pool_kernel: [0] + log_frequency: 2000 wrappers: baseline: diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index 60e7b86..e2a7a44 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -487,7 +487,7 @@ def evaluate(self): if self.wandb is not None: self.stats["Media/aggregate_exploration_map"] = self.wandb.Image(overlay) try: # TODO: Better checks on log data types - self.stats[f"Histogram/{k}"] = self.wandb.Histogram(v, num_bins=16) + # self.stats[f"Histogram/{k}"] = self.wandb.Histogram(v, num_bins=16) self.stats[k] = np.mean(v) self.max_stats[k] = np.max(v) except: # noqa diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 189623c..4bfade5 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -114,6 +114,16 @@ WindowEvent.PRESS_BUTTON_START, ] +VALID_RELEASE_ACTIONS = [ + WindowEvent.RELEASE_ARROW_DOWN, + WindowEvent.RELEASE_ARROW_LEFT, + WindowEvent.RELEASE_ARROW_RIGHT, + WindowEvent.RELEASE_ARROW_UP, + WindowEvent.RELEASE_BUTTON_A, + WindowEvent.RELEASE_BUTTON_B, + WindowEvent.RELEASE_BUTTON_START, +] + VALID_ACTIONS_STR = ["down", "left", "right", "up", "a", "b", "start"] ACTION_SPACE = spaces.Discrete(len(VALID_ACTIONS)) @@ -144,6 +154,7 @@ def __init__(self, env_config: pufferlib.namespace): self.perfect_ivs = env_config.perfect_ivs self.reduce_res = env_config.reduce_res self.gb_path = env_config.gb_path + self.log_frequency = env_config.log_frequency self.action_space = ACTION_SPACE # Obs space-related. TODO: avoid hardcoding? @@ -488,7 +499,7 @@ def step(self, action): info = {} # TODO: Make log frequency a configuration parameter - if self.step_count % 2000 == 0: + if self.step_count % self.log_frequency == 0: info = self.agent_stats(action) obs = self._get_obs() @@ -505,7 +516,8 @@ def run_action_on_emulator(self, action): self.action_hist[action] += 1 # press button then release after some steps # TODO: Add video saving logic - self.pyboy.button(VALID_ACTIONS_STR[action], delay=8) + self.pyboy.send_input(action) + self.pyboy.send_input(VALID_RELEASE_ACTIONS[action], delay=8) self.pyboy.tick(self.action_freq, render=True) if self.save_video and self.fast_video: diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py index adcf3ec..1592633 100644 --- a/pokemonred_puffer/rewards/baseline.py +++ b/pokemonred_puffer/rewards/baseline.py @@ -86,27 +86,46 @@ def __init__(self, env_config: pufferlib.namespace, reward_config: pufferlib.nam def get_game_state_reward(self): return { - "event": self.reward_config.event * self.update_max_event_rew(), - "met_bill": self.reward_config.bill_saved * int(self.read_bit(0xD7F1, 0)), - "used_cell_separator_on_bill": self.reward_config.bill_saved + "event": self.reward_config["event"] * self.update_max_event_rew(), + "met_bill": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F1, 0)), + "used_cell_separator_on_bill": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 3)), - "ss_ticket": self.reward_config.bill_saved * int(self.read_bit(0xD7F2, 4)), - "met_bill_2": self.reward_config.bill_saved * int(self.read_bit(0xD7F2, 5)), - "bill_said_use_cell_separator": self.reward_config.bill_saved + "ss_ticket": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 4)), + "met_bill_2": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 5)), + "bill_said_use_cell_separator": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 6)), - "left_bills_house_after_helping": self.reward_config.bill_saved + "left_bills_house_after_helping": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 7)), - "seen_pokemon": self.reward_config.seen_pokemon * sum(self.seen_pokemon), - "caught_pokemon": self.reward_config.caught_pokemon * sum(self.caught_pokemon), - "moves_obtained": self.reward_config.moves_obtained * sum(self.moves_obtained), - "hm_count": self.reward_config.hm_count * self.get_hm_count(), - "level": self.reward_config.level * self.get_levels_reward(), - "badges": self.reward_config.badges * self.get_badges(), - "exploration": self.reward_config.exploration * len(self.seen_coords), - "cut_coords": self.reward_config.cut_coords * sum(self.cut_coords.values()), - "cut_tiles": self.reward_config.cut_tiles * sum(self.cut_tiles), - "start_menu": self.reward_config.start_menu * self.seen_start_menu, - "pokemon_menu": self.reward_config.pokemon_menu * self.seen_pokemon_menu, - "stats_menu": self.reward_config.stats_menu * self.seen_stats_menu, - "bag_menu": self.reward_config.bag_menu * self.seen_bag_menu, + "seen_pokemon": self.reward_config["seen_pokemon"] * sum(self.seen_pokemon), + "caught_pokemon": self.reward_config["caught_pokemon"] * sum(self.caught_pokemon), + "moves_obtained": self.reward_config["moves_obtained"] * sum(self.moves_obtained), + "hm_count": self.reward_config["hm_count"] * self.get_hm_count(), + "level": self.reward_config["level"] * self.get_levels_reward(), + "badges": self.reward_config["badges"] * self.get_badges(), + "exploration": self.reward_config["exploration"] * sum(self.seen_coords.values()), + "cut_coords": self.reward_config["cut_coords"] * sum(self.cut_coords.values()), + "cut_tiles": self.reward_config["cut_tiles"] * sum(self.cut_tiles), + "start_menu": self.reward_config["start_menu"] * self.seen_start_menu, + "pokemon_menu": self.reward_config["pokemon_menu"] * self.seen_pokemon_menu, + "stats_menu": self.reward_config["stats_menu"] * self.seen_stats_menu, + "bag_menu": self.reward_config["bag_menu"] * self.seen_bag_menu, } + + def update_max_event_rew(self): + cur_rew = self.get_all_events_reward() + self.max_event_rew = max(cur_rew, self.max_event_rew) + return self.max_event_rew + + def get_all_events_reward(self): + # adds up all event flags, exclude museum ticket + return max( + sum( + [ + self.read_m(i).bit_count() + for i in range(EVENT_FLAGS_START, EVENT_FLAGS_START + EVENTS_FLAGS_LENGTH) + ] + ) + - self.base_event_flags + - int(self.read_bit(*MUSEUM_TICKET)), + 0, + ) From dae87d934bf282fba07374887be7c497f384c298 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Thu, 21 Mar 2024 16:09:45 -0400 Subject: [PATCH 5/5] Repair flat size --- config.yaml | 2 +- pokemonred_puffer/environment.py | 2 +- pokemonred_puffer/policies/multi_convolutional.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/config.yaml b/config.yaml index a8ea1f5..b6b17de 100644 --- a/config.yaml +++ b/config.yaml @@ -137,7 +137,7 @@ policies: hidden_size: 512 output_size: 512 framestack: 3 - flat_size: 2184 + flat_size: 1928 recurrent: # Assumed to be in the same module as the policy diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 4bfade5..fedf7e3 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -516,7 +516,7 @@ def run_action_on_emulator(self, action): self.action_hist[action] += 1 # press button then release after some steps # TODO: Add video saving logic - self.pyboy.send_input(action) + self.pyboy.send_input(VALID_ACTIONS[action]) self.pyboy.send_input(VALID_RELEASE_ACTIONS[action], delay=8) self.pyboy.tick(self.action_freq, render=True) diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index 55dd91b..82f588c 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -27,7 +27,7 @@ def __init__( env, screen_framestack: int = 3, global_map_frame_stack: int = 1, - screen_flat_size: int = 2433, # 14341, + screen_flat_size: int = 1928, # 14341, global_map_flat_size: int = 1600, input_size: int = 512, framestack: int = 1,