diff --git a/config.yaml b/config.yaml
index b40fa09..cec536b 100644
--- a/config.yaml
+++ b/config.yaml
@@ -129,16 +129,31 @@ rewards:
       pokemon_menu: 0.1
       stats_menu: 0.1
       bag_menu: 0.1
+  baseline.TeachCutReplicationEnvFork:
+    reward:
+      event: 1.0
+      bill_saved: 5.0
+      seen_pokemon: 4.0
+      caught_pokemon: 4.0
+      moves_obtained: 4.0
+      hm_count: 10.0
+      level: 1.0
+      badges: 10.0
+      exploration: 0.02
+      cut_coords: 1.0
+      cut_tiles: 1.0
+      start_menu: 0.01
+      pokemon_menu: 0.1
+      stats_menu: 0.1
+      bag_menu: 0.1
+      taught_cut: 10.0
 
 policies:
   multi_convolutional.MultiConvolutionalPolicy:
     policy:
-      input_size: 512
       hidden_size: 512
       output_size: 512
-      framestack: 3
-      flat_size: 1928
 
     recurrent:
       # Assumed to be in the same module as the policy
diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py
index e2a7a44..eb49bf6 100644
--- a/pokemonred_puffer/cleanrl_puffer.py
+++ b/pokemonred_puffer/cleanrl_puffer.py
@@ -314,6 +314,9 @@ def __init__(
         self.reward_buffer = deque(maxlen=1_000)
         self.exploration_map_agg = np.zeros((config.num_envs, *GLOBAL_MAP_SHAPE), dtype=np.float32)
+        self.cut_exploration_map_agg = np.zeros(
+            (config.num_envs, *GLOBAL_MAP_SHAPE), dtype=np.float32
+        )
         self.taught_cut = False
         self.infos = {}
@@ -479,13 +482,18 @@ def evaluate(self):
         self.stats = {}
         self.max_stats = {}
         for k, v in self.infos["learner"].items():
-            if "pokemon_exploration_map" in k and config.save_overlay is True:
+            if "cut_exploration_map" in k and config.save_overlay is True:
+                if self.update % config.overlay_interval == 0:
+                    overlay = make_pokemon_red_overlay(np.stack(v, axis=0))
+                    if self.wandb is not None:
+                        self.stats["Media/aggregate_cut_exploration_map"] = self.wandb.Image(
+                            overlay
+                        )
+            elif "exploration_map" in k and config.save_overlay is True:
                 if self.update % config.overlay_interval == 0:
-                    # self.exploration_map_agg[env_id, :, :] = v
-                    # overlay = make_pokemon_red_overlay(self.exploration_map_agg)
                     overlay = make_pokemon_red_overlay(np.stack(v, axis=0))
                     if self.wandb is not None:
                         self.stats["Media/aggregate_exploration_map"] = self.wandb.Image(overlay)
             try:
                 # TODO: Better checks on log data types
                 # self.stats[f"Histogram/{k}"] = self.wandb.Histogram(v, num_bins=16)
                 self.stats[k] = np.mean(v)
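For context on the overlay hunk above: each env reports one global exploration map through `infos`, so `v` arrives as a sequence of per-env arrays and `np.stack(v, axis=0)` batches them into `(num_envs, H, W)` before `make_pokemon_red_overlay` renders the image. Note the `cut_exploration_map` key must be matched before the broader `"exploration_map" in k` substring check, or the cut map would be logged under both names. A minimal sketch of the aggregation step, with an illustrative map shape (the `render_aggregate` helper below is a hypothetical stand-in for the repo's overlay function, which also blends the result onto the game map):

import numpy as np

GLOBAL_MAP_SHAPE = (444, 436)  # illustrative (H, W); the repo defines its own constant

def render_aggregate(per_env_maps: list[np.ndarray]) -> np.ndarray:
    # Stack one exploration map per env into (num_envs, H, W) ...
    stacked = np.stack(per_env_maps, axis=0)
    # ... then average over envs so frequently visited tiles read brighter.
    return stacked.mean(axis=0)

maps = [np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32) for _ in range(4)]
overlay = render_aggregate(maps)  # shape (444, 436), ready to log as an image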
diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py
index 20144ff..4d083c3 100644
--- a/pokemonred_puffer/environment.py
+++ b/pokemonred_puffer/environment.py
@@ -253,6 +253,7 @@ def reset(self, seed: Optional[int] = None):
             self.caught_pokemon = np.zeros(152, dtype=np.uint8)
             self.moves_obtained = np.zeros(0xA5, dtype=np.uint8)
             self.explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32)
+            self.cut_explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32)
             self.init_mem()
             self.reset_count = 0
             with open(self.init_state_path, "rb") as f:
@@ -270,6 +271,7 @@ def reset(self, seed: Optional[int] = None):
             self.caught_pokemon.fill(0)
             self.moves_obtained.fill(0)
             self.explore_map *= 0
+            self.cut_explore_map *= 0
             self.reset_mem()
 
         self.taught_cut = self.check_if_party_has_cut()
@@ -543,7 +545,7 @@ def pokemon_menu_hook(self, *args, **kwargs):
         self.seen_pokemon_menu = 1
 
     def chose_stats_hook(self, *args, **kwargs):
-        self.chose_stats_hook = 1
+        self.seen_stats_menu = 1
 
     def chose_item_hook(self, *args, **kwargs):
         self.seen_action_bag_menu = 1
@@ -567,8 +569,12 @@ def cut_hook(self, context):
             coords = (x - 1, y, map_id)
         if player_direction == 0xC:
             coords = (x + 1, y, map_id)
+
+        wTileInFrontOfPlayer = self.pyboy.memory[
+            self.pyboy.symbol_lookup("wTileInFrontOfPlayer")[1]
+        ]
         if context:
-            if self.pyboy.memory[self.pyboy.symbol_lookup("wTileInFrontOfPlayer")[1]] in [
+            if wTileInFrontOfPlayer in [
                 0x3D,
                 0x50,
             ]:
@@ -578,6 +584,9 @@ def cut_hook(self, context):
             else:
                 self.cut_coords[coords] = 0.01
 
+        self.cut_explore_map[local_to_global(y, x, map_id)] = 1
+        self.cut_tiles[wTileInFrontOfPlayer] = 1
+
     def agent_stats(self, action):
         levels = [self.read_m(f"wPartyMon{i+1}Level") for i in range(self.read_m("wPartyCount"))]
         return {
@@ -626,7 +635,8 @@ def agent_stats(self, action):
             },
             "reward": self.get_game_state_reward(),
             "reward/reward_sum": sum(self.get_game_state_reward().values()),
-            "pokemon_exploration_map": self.explore_map,
+            "exploration_map": self.explore_map,
+            "cut_exploration_map": self.cut_explore_map,
         }
 
     def start_video(self):
@@ -717,12 +727,12 @@ def read_short(self, addr: str | int) -> int:
         if isinstance(addr, str):
             _, addr = self.pyboy.symbol_lookup(addr)
         data = self.pyboy.memory[addr : addr + 2]
-        return data[0] << 8 + data[1]
+        return int(data[0] << 8) + int(data[1])
 
     def read_bit(self, addr: str | int, bit: int) -> bool:
         # add padding so zero will read '0b100000000' instead of '0b0'
         # return bin(256 + self.read_m(addr))[-bit - 1] == "1"
-        return bool(self.read_m(addr) & 1 << (7 - bit))
+        return bool(int(self.read_m(addr)) & (1 << (7 - bit)))
 
     def read_event_bits(self):
         _, addr = self.pyboy.symbol_lookup("wEventFlags")
diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py
index 82f588c..423e86a 100644
--- a/pokemonred_puffer/policies/multi_convolutional.py
+++ b/pokemonred_puffer/policies/multi_convolutional.py
@@ -25,13 +25,6 @@ class MultiConvolutionalPolicy(pufferlib.models.Policy):
     def __init__(
         self,
         env,
-        screen_framestack: int = 3,
-        global_map_frame_stack: int = 1,
-        screen_flat_size: int = 1928,  # 14341,
-        global_map_flat_size: int = 1600,
-        input_size: int = 512,
-        framestack: int = 1,
-        flat_size: int = 1,
         hidden_size=512,
         output_size=512,
         channels_last: bool = True,
@@ -42,41 +35,22 @@ def __init__(
         self.channels_last = channels_last
         self.downsample = downsample
         self.screen_network = nn.Sequential(
-            pufferlib.pytorch.layer_init(nn.Conv2d(screen_framestack, 32, 8, stride=4)),
+            nn.LazyConv2d(32, 8, stride=4),
             nn.ReLU(),
-            pufferlib.pytorch.layer_init(nn.Conv2d(32, 64, 4, stride=2)),
+            nn.LazyConv2d(64, 4, stride=2),
             nn.ReLU(),
-            pufferlib.pytorch.layer_init(nn.Conv2d(64, 64, 3, stride=1)),
+            nn.LazyConv2d(64, 3, stride=1),
             nn.ReLU(),
             nn.Flatten(),
         )
-        """
-        self.global_map_network = nn.Sequential(
-            pufferlib.pytorch.layer_init(nn.Conv2d(global_map_frame_stack, 32, 16, stride=8)),
-            nn.ReLU(),
-            pufferlib.pytorch.layer_init(nn.Conv2d(32, 64, 8, stride=4)),
-            nn.ReLU(),
-            pufferlib.pytorch.layer_init(nn.Conv2d(64, 64, 4, stride=2)),
-            nn.ReLU(),
-            nn.Flatten(),
-        )
-        """
         self.encode_linear = nn.Sequential(
-            pufferlib.pytorch.layer_init(
-                nn.Linear(
-                    screen_flat_size,
-                    hidden_size,
-                ),
-            ),
+            nn.LazyLinear(hidden_size),
             nn.ReLU(),
         )
-        self.actor = pufferlib.pytorch.layer_init(
-            nn.Linear(output_size, self.num_actions), std=0.01
-        )
-        self.value_fn = pufferlib.pytorch.layer_init(nn.Linear(output_size, 1), std=1)
+        self.actor = nn.LazyLinear(self.num_actions)
+        self.value_fn = nn.LazyLinear(1)
 
     def encode_observations(self, observations):
         observations = unpack_batched_obs(observations, self.unflatten_context)
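The policy rewrite above leans on PyTorch's lazy modules: `nn.LazyConv2d` and `nn.LazyLinear` take only output sizes and infer their input channels/features from the first batch they see, which is what lets the hand-computed `screen_flat_size`/`flat_size`/`framestack` constructor arguments (and their config entries) disappear. Note the value head takes `nn.LazyLinear(1)`, since `out_features` is the first positional argument. A minimal sketch of the pattern, with illustrative observation shapes rather than the env's real ones:

import torch
import torch.nn as nn

# Only output sizes are specified; input sizes are inferred on first use.
net = nn.Sequential(
    nn.LazyConv2d(32, 8, stride=4),
    nn.ReLU(),
    nn.LazyConv2d(64, 4, stride=2),
    nn.ReLU(),
    nn.LazyConv2d(64, 3, stride=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.LazyLinear(512),  # replaces a hand-computed flat_size -> hidden_size Linear
)

# One dummy forward materializes the lazy parameters; do this before
# parameter counting, optimizer construction, or checkpoint loading.
with torch.no_grad():
    net(torch.zeros(1, 3, 72, 80))  # (batch, framestack, H, W) -- illustrative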
diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py
index 1592633..ef118e0 100644
--- a/pokemonred_puffer/rewards/baseline.py
+++ b/pokemonred_puffer/rewards/baseline.py
@@ -129,3 +129,53 @@ def get_all_events_reward(self):
             - int(self.read_bit(*MUSEUM_TICKET)),
             0,
         )
+
+
+class TeachCutReplicationEnvFork(RedGymEnv):
+    def __init__(self, env_config: pufferlib.namespace, reward_config: pufferlib.namespace):
+        super().__init__(env_config)
+        self.reward_config = reward_config
+
+    def get_game_state_reward(self):
+        return {
+            "event": self.reward_config["event"] * self.update_max_event_rew(),
+            "met_bill": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F1, 0)),
+            "used_cell_separator_on_bill": self.reward_config["bill_saved"]
+            * int(self.read_bit(0xD7F2, 3)),
+            "ss_ticket": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 4)),
+            "met_bill_2": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 5)),
+            "bill_said_use_cell_separator": self.reward_config["bill_saved"]
+            * int(self.read_bit(0xD7F2, 6)),
+            "left_bills_house_after_helping": self.reward_config["bill_saved"]
+            * int(self.read_bit(0xD7F2, 7)),
+            "moves_obtained": self.reward_config["moves_obtained"] * sum(self.moves_obtained),
+            "hm_count": self.reward_config["hm_count"] * self.get_hm_count(),
+            "badges": self.reward_config["badges"] * self.get_badges(),
+            "exploration": self.reward_config["exploration"] * sum(self.seen_coords.values()),
+            "cut_coords": self.reward_config["cut_coords"] * sum(self.cut_coords.values()),
+            "cut_tiles": self.reward_config["cut_tiles"] * sum(self.cut_tiles),
+            "start_menu": self.reward_config["start_menu"] * self.seen_start_menu,
+            "pokemon_menu": self.reward_config["pokemon_menu"] * self.seen_pokemon_menu,
+            "stats_menu": self.reward_config["stats_menu"] * self.seen_stats_menu,
+            "bag_menu": self.reward_config["bag_menu"] * self.seen_bag_menu,
+            "taught_cut": self.reward_config["taught_cut"] * int(self.check_if_party_has_cut()),
+        }
+
+    def update_max_event_rew(self):
+        cur_rew = self.get_all_events_reward()
+        self.max_event_rew = max(cur_rew, self.max_event_rew)
+        return self.max_event_rew
+
+    def get_all_events_reward(self):
+        # adds up all event flags, exclude museum ticket
+        return max(
+            sum(
+                [
+                    self.read_m(i).bit_count()
+                    for i in range(EVENT_FLAGS_START, EVENT_FLAGS_START + EVENTS_FLAGS_LENGTH)
+                ]
+            )
+            - self.base_event_flags
+            - int(self.read_bit(*MUSEUM_TICKET)),
+            0,
+        )
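For reference, `get_all_events_reward` above counts the set bits across the event-flag bytes, subtracts the count captured from the initial save state plus the museum-ticket flag, and clamps at zero. The same arithmetic in isolation (the two constants mirror the repo's names, but their values here are assumptions, not taken from this diff):

EVENT_FLAGS_START = 0xD747  # assumed WRAM address of the event-flag block
EVENTS_FLAGS_LENGTH = 320   # assumed number of flag bytes

def event_reward(read_m, base_event_flags: int, museum_ticket_set: bool) -> int:
    # int.bit_count() (Python 3.10+) counts the set bits in each flag byte.
    total = sum(
        read_m(addr).bit_count()
        for addr in range(EVENT_FLAGS_START, EVENT_FLAGS_START + EVENTS_FLAGS_LENGTH)
    )
    # Ignore flags already set at reset and the museum ticket, and never
    # let the reward go negative.
    return max(total - base_event_flags - int(museum_ticket_set), 0)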