Global map is its own obs
thatguy11325 committed Jun 18, 2024
1 parent c95ceb1 · commit 9135a62
Showing 2 changed files with 40 additions and 13 deletions.
18 changes: 14 additions & 4 deletions pokemonred_puffer/environment.py
@@ -13,7 +13,7 @@
from gymnasium import Env, spaces
from pyboy import PyBoy
from pyboy.utils import WindowEvent
from skimage.transform import resize
# from skimage.transform import resize

import pufferlib
from pokemonred_puffer.data.events import EVENT_FLAGS_START, EVENTS_FLAGS_LENGTH, MUSEUM_TICKET
@@ -106,6 +106,7 @@ def __init__(self, env_config: pufferlib.namespace):
self.action_space = ACTION_SPACE

# Obs space-related. TODO: avoid hardcoding?
self.global_map_shape = GLOBAL_MAP_SHAPE
if self.reduce_res:
self.screen_output_shape = (72, 80, 1)
else:
@@ -116,6 +117,7 @@ def __init__(self, env_config: pufferlib.namespace):
self.screen_output_shape[1] // 4,
1,
)
self.global_map_shape = (self.global_map_shape[0], self.global_map_shape[1] // 4, 1)
self.coords_pad = 12
self.enc_freqs = 8

@@ -148,6 +150,9 @@ def __init__(self, env_config: pufferlib.namespace):
# "global_map": spaces.Box(
# low=0, high=255, shape=self.screen_output_shape, dtype=np.uint8
# ),
"global_map": spaces.Box(
low=0, high=255, shape=self.global_map_shape, dtype=np.uint8
),
# Discrete is more apt, but pufferlib is slower at processing Discrete
"direction": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8),
"blackout_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8),
@@ -423,12 +428,16 @@ def render(self):
)
).astype(np.uint8)
visited_mask = np.expand_dims(visited_mask, -1)
"""
global_map = np.expand_dims(
255 * resize(self.explore_map, game_pixels_render.shape, anti_aliasing=False),
axis=-1,
).astype(np.uint8)
"""
global_map = np.expand_dims(
255 * self.explore_map,
axis=-1,
).astype(np.uint8)

if self.two_bit:
game_pixels_render = (
@@ -464,13 +473,14 @@ def render(self):
<< np.array([6, 4, 2, 0], dtype=np.uint8)
)
.sum(axis=1, dtype=np.uint8)
.reshape(game_pixels_render.shape)
.reshape(self.global_map_shape)
.astype(np.uint8)
)

return {
"screen": game_pixels_render,
"visited_mask": visited_mask,
# "global_map": global_map,
"global_map": global_map,
}

def _get_obs(self):
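
Note on the `two_bit` branch above: `render()` packs four 2-bit map values into each output byte (the `<< np.array([6, 4, 2, 0])` shift followed by `.sum(axis=1)`), which is why `global_map_shape` divides the map width by 4. Below is a minimal round-trip sketch of that packing, using a placeholder map size and skipping whatever quantization to 0–3 happens in the collapsed portion of the hunk.

```python
# Sketch only: pack four 2-bit values per byte, mirroring the
# `<< np.array([6, 4, 2, 0])` / `.sum(axis=1)` pattern in render() above.
# H and W are placeholders, not the real GLOBAL_MAP_SHAPE.
import numpy as np

H, W = 8, 16  # W must be divisible by 4
two_bit_map = np.random.randint(0, 4, size=(H, W), dtype=np.uint8)  # values 0..3

packed = (
    (two_bit_map.reshape((-1, 4)) << np.array([6, 4, 2, 0], dtype=np.uint8))
    .sum(axis=1, dtype=np.uint8)
    .reshape((H, W // 4, 1))  # width shrinks by 4, matching global_map_shape
)

# Round-trip check with the same masks/shifts the policy uses to restore it.
unpack_mask = np.array([0xC0, 0x30, 0x0C, 0x03], dtype=np.uint8)
unpack_shift = np.array([6, 4, 2, 0], dtype=np.uint8)
restored = ((packed.reshape((-1, 1)) & unpack_mask) >> unpack_shift).reshape((H, W))
assert np.array_equal(restored, two_bit_map)
```

Packing this way keeps the new `global_map` observation at a quarter of its unpacked width while remaining a plain uint8 `spaces.Box`, the same treatment `screen` and `visited_mask` already get.
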
35 changes: 26 additions & 9 deletions pokemonred_puffer/policies/multi_convolutional.py
@@ -45,6 +45,15 @@ def __init__(
nn.ReLU(),
nn.Flatten(),
)
self.global_map_network = nn.Sequential(
nn.LazyConv2d(32, 8, stride=4),
nn.ReLU(),
nn.LazyConv2d(64, 4, stride=2),
nn.ReLU(),
nn.LazyConv2d(64, 3, stride=1),
nn.ReLU(),
nn.Flatten(),
)

self.encode_linear = nn.Sequential(
nn.LazyLinear(hidden_size),
@@ -86,8 +95,14 @@ def encode_observations(self, observations):

screen = observations["screen"]
visited_mask = observations["visited_mask"]
# global_map = observations["global_map"]
global_map = observations["global_map"]
restored_shape = (screen.shape[0], screen.shape[1], screen.shape[2] * 4, screen.shape[3])
restored_global_map_shape = (
global_map.shape[0],
global_map.shape[1],
global_map.shape[2] * 4,
global_map.shape[3],
)

if self.two_bit:
screen = torch.index_select(
@@ -102,13 +117,13 @@
.flatten()
.int(),
).reshape(restored_shape)
# global_map = torch.index_select(
# self.linear_buckets,
# 0,
# ((global_map.reshape((-1, 1)) & self.unpack_mask) >> self.unpack_shift)
# .flatten()
# .int(),
# ).reshape(restored_shape)
global_map = torch.index_select(
self.linear_buckets,
0,
((global_map.reshape((-1, 1)) & self.unpack_mask) >> self.unpack_shift)
.flatten()
.int(),
).reshape(restored_global_map_shape)
badges = self.badge_buffer <= observations["badges"]
map_id = self.map_embeddings(observations["map_id"].long())
blackout_map_id = self.map_embeddings(observations["blackout_map_id"].long())
@@ -122,13 +137,15 @@
image_observation = torch.cat((screen, visited_mask), dim=-1)
if self.channels_last:
image_observation = image_observation.permute(0, 3, 1, 2)
global_map = global_map.permute(0, 3, 1, 2)
if self.downsample > 1:
image_observation = image_observation[:, :, :: self.downsample, :: self.downsample]

return self.encode_linear(
torch.cat(
(
(self.screen_network(image_observation.float() / 255.0).squeeze(1)),
self.screen_network(image_observation.float() / 255.0).squeeze(1),
self.global_map_network(global_map.float() / 255.0).squeeze(1),
one_hot(observations["direction"].long(), 4).float().squeeze(1),
# one_hot(observations["reset_map_id"].long(), 0xF7).float().squeeze(1),
one_hot(observations["battle_type"].long(), 4).float().squeeze(1),
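
On the decode side, here is a minimal sketch (with assumed buffer values) of the 2-bit unpacking that `encode_observations` performs before the packed map reaches the new `global_map_network`: each byte is masked and shifted back into four indices, then looked up in `linear_buckets` to recover approximate pixel intensities. `unpack_mask`, `unpack_shift`, and `linear_buckets` are registered on the policy in the real code; the `linear_buckets` values below are an assumption for illustration.

```python
# Sketch only: expand packed 2-bit values back to uint8, following the
# index_select pattern in encode_observations above. Buffer values are assumed.
import torch

unpack_mask = torch.tensor([0xC0, 0x30, 0x0C, 0x03], dtype=torch.uint8)
unpack_shift = torch.tensor([6, 4, 2, 0], dtype=torch.uint8)
linear_buckets = torch.tensor([0, 64, 128, 255], dtype=torch.uint8)  # assumed values

def unpack_two_bit(packed: torch.Tensor) -> torch.Tensor:
    """packed: (B, H, W // 4, 1) uint8 -> (B, H, W, 1) uint8."""
    b, h, w4, c = packed.shape
    restored_shape = (b, h, w4 * 4, c)  # same as restored_global_map_shape above
    return torch.index_select(
        linear_buckets,
        0,
        ((packed.reshape((-1, 1)) & unpack_mask) >> unpack_shift).flatten().int(),
    ).reshape(restored_shape)

# Example: a batch of one packed 4x2x1 map expands back to 4x8x1.
dummy = torch.randint(0, 256, (1, 4, 2, 1), dtype=torch.uint8)
print(unpack_two_bit(dummy).shape)  # torch.Size([1, 4, 8, 1])
```

After this expansion, the map is permuted to channels-first (when `channels_last` is set), divided by 255, and fed through the `global_map_network` defined earlier in the diff.
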
