diff --git a/config.yaml b/config.yaml index 39f4659..a2b7427 100644 --- a/config.yaml +++ b/config.yaml @@ -12,6 +12,7 @@ debug: log_frequency: 1 disable_wild_encounters: True disable_ai_actions: True + use_global_map: True train: device: cpu compile: False diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 3b130ec..c7c8ad8 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -1063,6 +1063,7 @@ def disable_wild_encounter_hook(self, *args, **kwargs): def agent_stats(self, action): levels = [self.read_m(f"wPartyMon{i+1}Level") for i in range(self.read_m("wPartyCount"))] badges = self.read_m("wObtainedBadges") + explore_map = self.explore_map[self.explore_map > 0] = 1 return { "stats": { "step": self.step_count + self.reset_count * self.max_steps, @@ -1113,7 +1114,7 @@ def agent_stats(self, action): | {f"badge_{i+1}": bool(badges & (1 << i)) for i in range(8)}, "reward": self.get_game_state_reward(), "reward/reward_sum": sum(self.get_game_state_reward().values()), - "pokemon_exploration_map": self.explore_map, + "pokemon_exploration_map": explore_map, "cut_exploration_map": self.cut_explore_map, } @@ -1159,6 +1160,8 @@ def update_seen_coords(self): if not (self.read_m("wd736") & 0b1000_0000): x_pos, y_pos, map_n = self.get_game_coords() self.seen_coords[(x_pos, y_pos, map_n)] = 1 + # TODO: Turn into a wrapper? + self.explore_map[self.explore_map > 0] = 0.5 self.explore_map[local_to_global(y_pos, x_pos, map_n)] = 1 # self.seen_global_coords[local_to_global(y_pos, x_pos, map_n)] = 1 self.seen_map_ids[map_n] = 1 diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index e9fde41..37cfdc2 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -69,6 +69,8 @@ def __init__( nn.LazyConv2d(64, 3, stride=1), nn.ReLU(), nn.Flatten(), + nn.LazyLinear(480), + nn.ReLU(), ) # if channels_last: # self.global_map_network = self.global_map_network.to(