Skip to content

Commit

Permalink
temporarily comment out global map obs
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jun 15, 2024
1 parent efdd5d1 commit 2873489
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
8 changes: 4 additions & 4 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,9 @@ def __init__(self, env_config: pufferlib.namespace):
"visited_mask": spaces.Box(
low=0, high=255, shape=self.screen_output_shape, dtype=np.uint8
),
"global_map": spaces.Box(
low=0, high=255, shape=self.screen_output_shape, dtype=np.uint8
),
# "global_map": spaces.Box(
# low=0, high=255, shape=self.screen_output_shape, dtype=np.uint8
# ),
# Discrete is more apt, but pufferlib is slower at processing Discrete
"direction": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8),
"blackout_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8),
Expand Down Expand Up @@ -464,7 +464,7 @@ def render(self):
return {
"screen": game_pixels_render,
"visited_mask": visited_mask,
"global_map": global_map,
# "global_map": global_map,
}

def _get_obs(self):
Expand Down
19 changes: 10 additions & 9 deletions pokemonred_puffer/policies/multi_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def encode_observations(self, observations):

screen = observations["screen"]
visited_mask = observations["visited_mask"]
global_map = observations["global_map"]
# global_map = observations["global_map"]
restored_shape = (screen.shape[0], screen.shape[1], screen.shape[2] * 4, screen.shape[3])

if self.two_bit:
Expand All @@ -96,18 +96,19 @@ def encode_observations(self, observations):
.flatten()
.int(),
).reshape(restored_shape)
global_map = torch.index_select(
self.linear_buckets,
0,
((global_map.reshape((-1, 1)) & self.unpack_mask) >> self.unpack_shift)
.flatten()
.int(),
).reshape(restored_shape)
# global_map = torch.index_select(
# self.linear_buckets,
# 0,
# ((global_map.reshape((-1, 1)) & self.unpack_mask) >> self.unpack_shift)
# .flatten()
# .int(),
# ).reshape(restored_shape)
badges = self.badge_buffer <= observations["badges"]
map_id = self.map_embeddings(observations["map_id"].long())
blackout_map_id = self.map_embeddings(observations["blackout_map_id"].long())

image_observation = torch.cat((screen, visited_mask, global_map), dim=-1)
# image_observation = torch.cat((screen, visited_mask, global_map), dim=-1)
image_observation = torch.cat((screen, visited_mask), dim=-1)
if self.channels_last:
image_observation = image_observation.permute(0, 3, 1, 2)
if self.downsample > 1:
Expand Down

0 comments on commit 2873489

Please sign in to comment.