diff --git a/config.yaml b/config.yaml
index 24ff19a..79a56a9 100644
--- a/config.yaml
+++ b/config.yaml
@@ -5,18 +5,18 @@ wandb:

 debug:
   env:
-    headless: True
+    headless: False
     stream_wrapper: False
     init_state: cut
-    max_steps: 4
+    max_steps: 1_000_000
   train:
     device: cpu
     compile: False
     compile_mode: default
-    num_envs: 10
+    num_envs: 4
     envs_per_worker: 1
-    envs_per_batch: 1
-    batch_size: 8
+    envs_per_batch: 4
+    batch_size: 16
     batch_rows: 4
     bptt_horizon: 2
     total_timesteps: 100_000_000
@@ -28,8 +28,8 @@ debug:
     env_pool: False
     log_frequency: 5000
     load_optimizer_state: False
-    swarm_frequency: 5
-    swarm_keep_pct: .8
+    swarm_frequency: 10
+    swarm_keep_pct: .1

 env:
   headless: True
@@ -157,6 +157,27 @@ rewards:

       explore_npcs: 0.02
       explore_hidden_objs: 0.02

+  baseline.RockTunnelReplicationEnv:
+    reward:
+      level: 1.0
+      exploration: 0.02
+      taught_cut: 10.0
+      event: 3.0
+      seen_pokemon: 4.0
+      caught_pokemon: 4.0
+      moves_obtained: 4.0
+      cut_coords: 1.0
+      cut_tiles: 1.0
+      start_menu: 0.005
+      pokemon_menu: 0.05
+      stats_menu: 0.05
+      bag_menu: 0.05
+      pokecenter: 5.0
+      # Really an addition to event reward
+      badges: 2.0
+      bill_saved: 2.0
+
+
 policies:
   multi_convolutional.MultiConvolutionalPolicy:
diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py
index f7cc984..b74b193 100644
--- a/pokemonred_puffer/cleanrl_puffer.py
+++ b/pokemonred_puffer/cleanrl_puffer.py
@@ -386,6 +386,10 @@ def evaluate(self):
                     )
                     self.env_recv_queues[i + 1].put(self.infos["learner"]["state"][new_state])
                     waiting_for.append(i + 1)
+                    # Now copy the hidden state over
+                    # This may be a little slow, but so is this whole process
+                    self.next_lstm_state[0][:, i, :] = self.next_lstm_state[0][:, new_state, :]
+                    self.next_lstm_state[1][:, i, :] = self.next_lstm_state[1][:, new_state, :]
                 for i in waiting_for:
                     self.env_send_queues[i].get()
                 print("State migration complete")
@@ -451,7 +455,7 @@ def evaluate(self):
             # Index alive mask with policy pool idxs...
             # TODO: Find a way to avoid having to do this
-            learner_mask = torch.Tensor(mask * self.policy_pool.mask)
+            learner_mask = torch.as_tensor(mask * self.policy_pool.mask)

             # Ensure indices do not exceed batch size
             indices = torch.where(learner_mask)[0][: config.batch_size - ptr + 1].numpy()

@@ -588,11 +592,11 @@ def train(self):
             )

             # Flatten the batch
-            self.b_obs = b_obs = torch.tensor(self.obs_ary[b_idxs], dtype=torch.uint8)
-            b_actions = torch.tensor(self.actions_ary[b_idxs]).to(self.device, non_blocking=True)
-            b_logprobs = torch.tensor(self.logprobs_ary[b_idxs]).to(self.device, non_blocking=True)
-            # b_dones = torch.Tensor(self.dones_ary[b_idxs]).to(self.device, non_blocking=True)
-            b_values = torch.tensor(self.values_ary[b_idxs]).to(self.device, non_blocking=True)
+            self.b_obs = b_obs = torch.as_tensor(self.obs_ary[b_idxs], dtype=torch.uint8)
+            b_actions = torch.as_tensor(self.actions_ary[b_idxs]).to(self.device, non_blocking=True)
+            b_logprobs = torch.as_tensor(self.logprobs_ary[b_idxs]).to(self.device, non_blocking=True)
+            # b_dones = torch.as_tensor(self.dones_ary[b_idxs]).to(self.device, non_blocking=True)
+            b_values = torch.as_tensor(self.values_ary[b_idxs]).to(self.device, non_blocking=True)
             b_advantages = advantages.reshape(
                 config.batch_rows, num_minibatches, config.bptt_horizon
             ).transpose(0, 1)
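Note on the torch.Tensor -> torch.as_tensor swaps above: torch.Tensor always copies and silently casts to float32, and torch.tensor always copies, while torch.as_tensor reuses the underlying NumPy buffer when dtype and device already match, so these hot evaluate/train loops save one host-side copy per batch. A minimal standalone sketch of the difference, not part of the patch:

import numpy as np
import torch

arr = np.zeros((4, 4), dtype=np.int64)

t_copy = torch.tensor(arr)     # always copies
t_cast = torch.Tensor(arr)     # copies AND casts to float32
t_view = torch.as_tensor(arr)  # shares memory with arr (same dtype, CPU)

arr[0, 0] = 7
assert t_view[0, 0] == 7             # the view sees the write
assert t_copy[0, 0] == 0             # the copy does not
assert t_cast.dtype == torch.float32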
diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py
index aa15707..4ed86db 100644
--- a/pokemonred_puffer/environment.py
+++ b/pokemonred_puffer/environment.py
@@ -289,10 +289,11 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] =
             self.seen_pokemon = np.zeros(152, dtype=np.uint8)
             self.caught_pokemon = np.zeros(152, dtype=np.uint8)
             self.moves_obtained = np.zeros(0xA5, dtype=np.uint8)
+            self.pokecenters = np.zeros(252, dtype=np.uint8)
             # lazy random seed setting
-            if not seed:
-                seed = random.randint(0, 4096)
-            self.pyboy.tick(seed, render=False)
+            # if not seed:
+            #     seed = random.randint(0, 4096)
+            # self.pyboy.tick(seed, render=False)
         else:
             self.reset_count += 1
@@ -370,8 +371,10 @@ def reset_mem(self):
     def render(self):
         # (144, 160, 3)
         game_pixels_render = np.expand_dims(self.screen.ndarray[:, :, 1], axis=-1)
+
         if self.reduce_res:
             game_pixels_render = game_pixels_render[::2, ::2, :]
+            # game_pixels_render = skimage.measure.block_reduce(game_pixels_render, (2, 2, 1), np.min)

         # place an overlay on top of the screen greying out places we haven't visited
         # first get our location
@@ -551,11 +554,18 @@ def step(self, action):
         if self.perfect_ivs:
             self.set_perfect_iv_dvs()
         self.taught_cut = self.check_if_party_has_cut()
-
+        self.pokecenters[self.read_m("wLastBlackoutMap")] = 1
         info = {}
+
+        if self.get_events_sum() > self.max_event_rew:
+            state = io.BytesIO()
+            self.pyboy.save_state(state)
+            state.seek(0)
+            info["state"] = state.read()
+
         # TODO: Make log frequency a configuration parameter
         if self.step_count % self.log_frequency == 0:
-            info = self.agent_stats(action)
+            info = info | self.agent_stats(action)

         obs = self._get_obs()

@@ -635,9 +645,9 @@ def cut_hook(self, context):
             ]:
                 self.cut_coords[coords] = 10
             else:
-                self.cut_coords[coords] = 0.01
+                self.cut_coords[coords] = 0.001
         else:
-            self.cut_coords[coords] = 0.01
+            self.cut_coords[coords] = 0.001

         self.cut_explore_map[local_to_global(y, x, map_id)] = 1
         self.cut_tiles[wTileInFrontOfPlayer] = 1
@@ -687,6 +697,7 @@ def agent_stats(self, action):
                 "item_count": self.read_m(0xD31D),
                 "reset_count": self.reset_count,
                 "blackout_count": self.blackout_count,
+                "pokecenter": np.sum(self.pokecenters),
             },
             "reward": self.get_game_state_reward(),
             "reward/reward_sum": sum(self.get_game_state_reward().values()),
@@ -877,3 +888,17 @@ def get_levels_reward(self):
         else:
             level_reward = 30 + (self.max_level_sum - 30) / 4
         return level_reward
+
+    def get_events_sum(self):
+        # adds up all event flags, exclude museum ticket
+        return max(
+            sum(
+                [
+                    self.read_m(i).bit_count()
+                    for i in range(EVENT_FLAGS_START, EVENT_FLAGS_START + EVENTS_FLAGS_LENGTH)
+                ]
+            )
+            - self.base_event_flags
+            - int(self.read_bit(*MUSEUM_TICKET)),
+            0,
+        )
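The new get_events_sum leans on int.bit_count() (Python 3.10+) to total the set bits across the event-flag bytes, then subtracts the flags that were already set in the initial state plus the museum-ticket flag, clamping at zero. A standalone sketch of that counting logic over a plain byte buffer, with illustrative values in place of emulator memory:

event_flags = bytes([0b1010_0001, 0b0000_1111, 0x00, 0xFF])

# Same pattern as get_events_sum: popcount each byte in the flag range
set_flags = sum(byte.bit_count() for byte in event_flags)
assert set_flags == 3 + 4 + 0 + 8  # 15 flags set

base_event_flags = 2  # flags already set in the init state (illustrative)
museum_ticket = 1     # excluded from the reward, as in the diff
assert max(set_flags - base_event_flags - museum_ticket, 0) == 12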
diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py
index 33e03ee..c6da5d2 100644
--- a/pokemonred_puffer/rewards/baseline.py
+++ b/pokemonred_puffer/rewards/baseline.py
@@ -5,12 +5,15 @@
     RedGymEnv,
 )

+import numpy as np
+
 MUSEUM_TICKET = (0xD754, 0)


 class BaselineRewardEnv(RedGymEnv):
     def __init__(self, env_config: pufferlib.namespace, reward_config: pufferlib.namespace):
         super().__init__(env_config)
+        self.reward_config = reward_config

     # TODO: make the reward weights configurable
     def get_game_state_reward(self):
@@ -80,11 +83,7 @@ def get_levels_reward(self):
             return 15 + (self.max_level_sum - 15) / 4


-class TeachCutReplicationEnv(RedGymEnv):
-    def __init__(self, env_config: pufferlib.namespace, reward_config: pufferlib.namespace):
-        super().__init__(env_config)
-        self.reward_config = reward_config
-
+class TeachCutReplicationEnv(BaselineRewardEnv):
     def get_game_state_reward(self):
         return {
             "event": self.reward_config["event"] * self.update_max_event_rew(),
@@ -113,31 +112,8 @@ def get_game_state_reward(self):
             "rival3": self.reward_config["event"] * int(self.read_m(0xD665) == 4),
         }

-    def update_max_event_rew(self):
-        cur_rew = self.get_all_events_reward()
-        self.max_event_rew = max(cur_rew, self.max_event_rew)
-        return self.max_event_rew
-
-    def get_all_events_reward(self):
-        # adds up all event flags, exclude museum ticket
-        return max(
-            sum(
-                [
-                    self.read_m(i).bit_count()
-                    for i in range(EVENT_FLAGS_START, EVENT_FLAGS_START + EVENTS_FLAGS_LENGTH)
-                ]
-            )
-            - self.base_event_flags
-            - int(self.read_bit(*MUSEUM_TICKET)),
-            0,
-        )
-
-
-class TeachCutReplicationEnvFork(RedGymEnv):
-    def __init__(self, env_config: pufferlib.namespace, reward_config: pufferlib.namespace):
-        super().__init__(env_config)
-        self.reward_config = reward_config

+class TeachCutReplicationEnvFork(BaselineRewardEnv):
     def get_game_state_reward(self):
         return {
             "event": self.reward_config["event"] * self.update_max_event_rew(),
self.reward_config["level"] * self.get_levels_reward(), + "exploration": self.reward_config["exploration"] * sum(self.seen_coords.values()), + "taught_cut": self.reward_config["taught_cut"] * int(self.taught_cut), + "event": self.reward_config["event"] * self.update_max_event_rew(), + "seen_pokemon": self.reward_config["seen_pokemon"] * np.sum(self.seen_pokemon), + "caught_pokemon": self.reward_config["caught_pokemon"] * np.sum(self.caught_pokemon), + "moves_obtained": self.reward_config["moves_obtained"] * np.sum(self.moves_obtained), + "cut_coords": self.reward_config["cut_coords"] * sum(self.cut_coords.values()), + "cut_tiles": self.reward_config["cut_tiles"] * sum(self.cut_tiles), + "start_menu": ( + self.reward_config["start_menu"] * self.seen_start_menu * int(self.taught_cut) + ), + "pokemon_menu": ( + self.reward_config["pokemon_menu"] * self.seen_pokemon_menu * int(self.taught_cut) + ), + "stats_menu": ( + self.reward_config["stats_menu"] * self.seen_stats_menu * int(self.taught_cut) + ), + "bag_menu": self.reward_config["bag_menu"] * self.seen_bag_menu * int(self.taught_cut), + # "pokecenter": self.reward_config["pokecenter"] * np.sum(self.pokecenters), + "badges": self.reward_config["badges"] * self.get_badges(), + "met_bill": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F1, 0)), + "used_cell_separator_on_bill": self.reward_config["bill_saved"] + * int(self.read_bit(0xD7F2, 3)), + "ss_ticket": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 4)), + "met_bill_2": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 5)), + "bill_said_use_cell_separator": self.reward_config["bill_saved"] + * int(self.read_bit(0xD7F2, 6)), + "left_bills_house_after_helping": self.reward_config["bill_saved"] + * int(self.read_bit(0xD7F2, 7)), + "rival3": self.reward_config["event"] * int(self.read_m(0xD665) == 4), + } def get_levels_reward(self): party_size = self.read_m("wPartyCount")