Skip to content

Commit

Permalink
visited mask is now based on global coords
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jun 27, 2024
1 parent ff35074 commit 75dc986
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ def render(self):
# If not in battle, set the visited mask. There's no reason to process it when in battle
scale = 2 if self.reduce_res else 1
if self.read_m(0xD057) == 0:
'''
for y in range(-72 // 16, 72 // 16):
for x in range(-80 // 16, 80 // 16):
# y-y1 = m (x-x1)
Expand Down Expand Up @@ -418,16 +419,19 @@ def render(self):
)
)
"""
"""
gr, gc = local_to_global(player_y, player_x, map_n)
visited_mask = (
255
* np.repeat(
np.repeat(self.seen_global_coords[gr - 4 : gr + 5, gc - 4 : gc + 6], 16, 0), 16, -1
)
).astype(np.uint8)
visited_mask = np.expand_dims(visited_mask, -1)
'''
gr, gc = local_to_global(player_y, player_x, map_n)
visited_mask = (
255
* np.repeat(
np.repeat(self.explore_map[gr - 4 : gr + 6, gc - 4 : gc + 6], 16 // scale, 0),
16 // scale,
-1,
)
).astype(np.uint8)[6 // scale : -10 // scale, :]
visited_mask = np.expand_dims(visited_mask, -1)

"""
global_map = np.expand_dims(
255 * resize(self.explore_map, game_pixels_render.shape, anti_aliasing=False),
axis=-1,
Expand Down Expand Up @@ -1063,8 +1067,6 @@ def disable_wild_encounter_hook(self, *args, **kwargs):
def agent_stats(self, action):
levels = [self.read_m(f"wPartyMon{i+1}Level") for i in range(self.read_m("wPartyCount"))]
badges = self.read_m("wObtainedBadges")
explore_map = self.explore_map
explore_map[explore_map > 0] = 1

_, wBagItems = self.pyboy.symbol_lookup("wBagItems")
bag = np.array(self.pyboy.memory[wBagItems : wBagItems + 40], dtype=np.uint8)
Expand Down Expand Up @@ -1123,8 +1125,8 @@ def agent_stats(self, action):
"useful_items": {item.name: item.value in bag_item_ids for item in USEFUL_ITEMS},
"reward": self.get_game_state_reward(),
"reward/reward_sum": sum(self.get_game_state_reward().values()),
"pokemon_exploration_map": explore_map,
"cut_exploration_map": self.cut_explore_map,
"pokemon_exploration_map": self.explore_map,
# "cut_exploration_map": self.cut_explore_map,
}

def start_video(self):
Expand Down Expand Up @@ -1170,7 +1172,6 @@ def update_seen_coords(self):
x_pos, y_pos, map_n = self.get_game_coords()
self.seen_coords[(x_pos, y_pos, map_n)] = 1
# TODO: Turn into a wrapper?
self.explore_map[self.explore_map > 0] = 0.5
self.explore_map[local_to_global(y_pos, x_pos, map_n)] = 1
# self.seen_global_coords[local_to_global(y_pos, x_pos, map_n)] = 1
self.seen_map_ids[map_n] = 1
Expand Down

0 comments on commit 75dc986

Please sign in to comment.