Skip to content

Commit

Permalink
2-bit encoding of screen
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Mar 24, 2024
1 parent b79446c commit c444aad
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 61 deletions.
5 changes: 1 addition & 4 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ env:
frame_stacks: 1
perfect_ivs: True
reduce_res: True
two_bit: True
log_frequency: 2000

train:
Expand Down Expand Up @@ -147,16 +148,12 @@ rewards:
taught_cut: 10.0
explore_npcs: 0.02
explore_hidden_objs: 0.02
seen_pokemon: 4.0
caught_pokemon: 4.0
level: 1.0


policies:
multi_convolutional.MultiConvolutionalPolicy:
policy:
hidden_size: 512
output_size: 512

recurrent:
# Assumed to be in the same module as the policy
Expand Down
114 changes: 70 additions & 44 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import pufferlib
from pokemonred_puffer.global_map import GLOBAL_MAP_SHAPE, local_to_global

PIXEL_VALUES = np.array([0, 85, 153, 255], dtype=np.uint8)

EVENT_FLAGS_START = 0xD747
EVENTS_FLAGS_LENGTH = 320
MUSEUM_TICKET = (0xD754, 0)
Expand Down Expand Up @@ -150,18 +152,24 @@ def __init__(self, env_config: pufferlib.namespace):
self.max_steps = env_config.max_steps
self.save_video = env_config.save_video
self.fast_video = env_config.fast_video
self.frame_stacks = env_config.frame_stacks
self.perfect_ivs = env_config.perfect_ivs
self.reduce_res = env_config.reduce_res
self.gb_path = env_config.gb_path
self.log_frequency = env_config.log_frequency
self.two_bit = env_config.two_bit
self.action_space = ACTION_SPACE

# Obs space-related. TODO: avoid hardcoding?
if self.reduce_res:
self.screen_output_shape = (72, 80, 3 * self.frame_stacks)
self.screen_output_shape = (72, 80, 1)
else:
self.screen_output_shape = (144, 160, 3 * self.frame_stacks)
self.screen_output_shape = (144, 160, 1)
if self.two_bit:
self.screen_output_shape = (
self.screen_output_shape[0],
self.screen_output_shape[1] // 4,
1,
)
self.coords_pad = 12
self.enc_freqs = 8

Expand All @@ -188,6 +196,12 @@ def __init__(self, env_config: pufferlib.namespace):
"screen": spaces.Box(
low=0, high=255, shape=self.screen_output_shape, dtype=np.uint8
),
"visited_mask": spaces.Box(
low=0, high=255, shape=self.screen_output_shape, dtype=np.uint8
),
"global_map": spaces.Box(
low=0, high=255, shape=self.screen_output_shape, dtype=np.uint8
),
# Discrete is more apt, but pufferlib is slower at processing Discrete
"direction": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8),
# "reset_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8),
Expand Down Expand Up @@ -247,8 +261,8 @@ def reset(self, seed: Optional[int] = None):
# restart game, skipping credits
self.explore_map_dim = 384
if self.first:
self.recent_screens = deque() # np.zeros(self.output_shape, dtype=np.uint8)
self.recent_actions = deque() # np.zeros((self.frame_stacks,), dtype=np.uint8)
self.recent_screens = deque()
self.recent_actions = deque()
self.seen_pokemon = np.zeros(152, dtype=np.uint8)
self.caught_pokemon = np.zeros(152, dtype=np.uint8)
self.moves_obtained = np.zeros(0xA5, dtype=np.uint8)
Expand Down Expand Up @@ -338,9 +352,10 @@ def reset_mem(self):

def render(self):
# (144, 160, 3)
game_pixels_render = self.screen.ndarray[:, :, 0:1]
game_pixels_render = np.expand_dims(self.screen.ndarray[:, :, 1], axis=-1)
if self.reduce_res:
game_pixels_render = game_pixels_render[::2, ::2, :]

# place an overlay on top of the screen greying out places we haven't visited
# first get our location
player_x, player_y, map_n = self.get_game_coords()
Expand Down Expand Up @@ -388,7 +403,7 @@ def render(self):
player_y + y + 1,
map_n,
),
0.15,
0,
)
* 255
)
Expand Down Expand Up @@ -422,31 +437,59 @@ def render(self):
).astype(np.uint8)
visited_mask = np.expand_dims(visited_mask, -1)
"""
# game_pixels_render = np.concatenate([game_pixels_render, visited_mask, cut_mask], axis=-1)
game_pixels_render = np.concatenate([game_pixels_render, visited_mask], axis=-1)

return game_pixels_render

def _get_screen_obs(self):
screen = self.render()
screen = np.concatenate(
[
screen,
np.expand_dims(
255 * resize(self.explore_map, screen.shape[:-1], anti_aliasing=False),
axis=-1,
).astype(np.uint8),
],

global_map = np.expand_dims(
255 * resize(self.explore_map, game_pixels_render.shape, anti_aliasing=False),
axis=-1,
)
).astype(np.uint8)

self.update_recent_screens(screen)
return screen
if self.two_bit:
game_pixels_render = (
(
np.digitize(
game_pixels_render.reshape((-1, 4)), PIXEL_VALUES, right=True
).astype(np.uint8)
<< np.array([6, 4, 2, 0], dtype=np.uint8)
)
.sum(axis=1, dtype=np.uint8)
.reshape((-1, game_pixels_render.shape[1] // 4, 1))
)
visited_mask = (
(
np.digitize(
visited_mask.reshape((-1, 4)),
np.array([0, 64, 128, 255], dtype=np.uint8),
right=True,
).astype(np.uint8)
<< np.array([6, 4, 2, 0], dtype=np.uint8)
)
.sum(axis=1, dtype=np.uint8)
.reshape(game_pixels_render.shape)
.astype(np.uint8)
)
global_map = (
(
np.digitize(
global_map.reshape((-1, 4)),
np.array([0, 64, 128, 255], dtype=np.uint8),
right=True,
).astype(np.uint8)
<< np.array([6, 4, 2, 0], dtype=np.uint8)
)
.sum(axis=1, dtype=np.uint8)
.reshape(game_pixels_render.shape)
)

return {
"screen": game_pixels_render,
"visited_mask": visited_mask,
"global_map": global_map,
}

def _get_obs(self):
player_x, player_y, map_n = self.get_game_coords()
# player_x, player_y, map_n = self.get_game_coords()
return {
"screen": self._get_screen_obs(),
**self.render(),
"direction": np.array(
self.read_m("wSpritePlayerStateData1FacingDirection") // 4, dtype=np.uint8
),
Expand Down Expand Up @@ -515,9 +558,6 @@ def run_action_on_emulator(self, action):
self.pyboy.send_input(VALID_RELEASE_ACTIONS[action], delay=8)
self.pyboy.tick(self.action_freq, render=True)

if self.save_video and self.fast_video:
self.add_video_frame()

def hidden_object_hook(self, *args, **kwargs):
hidden_object_id = self.pyboy.memory[self.pyboy.symbol_lookup("wHiddenObjectIndex")[1]]
map_id = self.pyboy.memory[self.pyboy.symbol_lookup("wCurMap")[1]]
Expand Down Expand Up @@ -693,20 +733,6 @@ def get_explore_map(self):

return explore_map

def update_recent_screens(self, cur_screen):
# self.recent_screens = np.roll(self.recent_screens, 1, axis=2)
# self.recent_screens[:, :, 0] = cur_screen[:, :, 0]
self.recent_screens.append(cur_screen)
if len(self.recent_screens) > self.frame_stacks:
self.recent_screens.popleft()

def update_recent_actions(self, action):
# self.recent_actions = np.roll(self.recent_actions, 1)
# self.recent_actions[0] = action
self.recent_actions.append(action)
if len(self.recent_actions) > self.frame_stacks:
self.recent_actions.popleft()

def update_reward(self):
# compute reward
self.progress_reward = self.get_game_state_reward()
Expand Down
76 changes: 63 additions & 13 deletions pokemonred_puffer/policies/multi_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import pufferlib.models
from pufferlib.emulation import unpack_batched_obs

from pokemonred_puffer.environment import PIXEL_VALUES

unpack_batched_obs = torch.compiler.disable(unpack_batched_obs)

# Because torch.nn.functional.one_hot cannot be traced by torch as of 2.2.0
Expand All @@ -26,7 +28,6 @@ def __init__(
self,
env,
hidden_size=512,
output_size=512,
channels_last: bool = True,
downsample: int = 1,
):
Expand All @@ -52,24 +53,73 @@ def __init__(
self.actor = nn.LazyLinear(self.num_actions)
self.value_fn = nn.LazyLinear(1)

self.two_bit = env.unwrapped.env.two_bit

self.register_buffer(
"screen_buckets", torch.tensor(PIXEL_VALUES, dtype=torch.uint8), persistent=False
)
self.register_buffer(
"linear_buckets", torch.tensor([0, 64, 128, 255], dtype=torch.uint8), persistent=False
)
self.register_buffer(
"unpack_mask",
torch.tensor([0xC0, 0x30, 0x0C, 0x03], dtype=torch.uint8),
persistent=False,
)
self.register_buffer(
"unpack_shift", torch.tensor([6, 4, 2, 0], dtype=torch.uint8), persistent=False
)

def encode_observations(self, observations):
observations = unpack_batched_obs(observations, self.unflatten_context)

output = []
for okey, network in zip(
("screen",),
(self.screen_network,),
):
observation = observations[okey]
if self.channels_last:
observation = observation.permute(0, 3, 1, 2)
if self.downsample > 1:
observation = observation[:, :, :: self.downsample, :: self.downsample]
output.append(network(observation.float() / 255.0).squeeze(1))
screen = observations["screen"]
visited_mask = observations["visited_mask"]
global_map = observations["global_map"]
restored_shape = (screen.shape[0], screen.shape[1], screen.shape[2] * 4, screen.shape[3])

if self.two_bit:
screen = torch.index_select(
self.screen_buckets,
0,
screen.flatten()
.unsqueeze(-1)
.bitwise_and(self.unpack_mask)
.bitwise_right_shift(self.unpack_shift)
.flatten()
.int(),
).reshape(restored_shape)
visited_mask = torch.index_select(
self.linear_buckets,
0,
visited_mask.flatten()
.unsqueeze(-1)
.bitwise_and(self.unpack_mask)
.bitwise_right_shift(self.unpack_shift)
.flatten()
.int(),
).reshape(restored_shape)
global_map = torch.index_select(
self.linear_buckets,
0,
global_map.flatten()
.unsqueeze(-1)
.bitwise_and(self.unpack_mask)
.bitwise_right_shift(self.unpack_shift)
.flatten()
.int(),
).reshape(restored_shape)

image_observation = torch.cat((screen, visited_mask, global_map), dim=-1)
if self.channels_last:
image_observation = image_observation.permute(0, 3, 1, 2)
if self.downsample > 1:
image_observation = image_observation[:, :, :: self.downsample, :: self.downsample]

return self.encode_linear(
torch.cat(
(
*output,
(self.screen_network(image_observation.float() / 255.0).squeeze(1)),
one_hot(observations["direction"].long(), 4).float().squeeze(1),
# one_hot(observations["reset_map_id"].long(), 0xF7).float().squeeze(1),
one_hot(observations["battle_type"].long(), 4).float().squeeze(1),
Expand Down

0 comments on commit c444aad

Please sign in to comment.