From c444aad40468dc91e54441186ba6db7b4a3f7811 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Sat, 23 Mar 2024 23:54:36 -0400 Subject: [PATCH] 2-bit encoding of screen --- config.yaml | 5 +- pokemonred_puffer/environment.py | 114 +++++++++++------- .../policies/multi_convolutional.py | 76 ++++++++++-- 3 files changed, 134 insertions(+), 61 deletions(-) diff --git a/config.yaml b/config.yaml index 47c559f..03e9fe3 100644 --- a/config.yaml +++ b/config.yaml @@ -41,6 +41,7 @@ env: frame_stacks: 1 perfect_ivs: True reduce_res: True + two_bit: True log_frequency: 2000 train: @@ -147,16 +148,12 @@ rewards: taught_cut: 10.0 explore_npcs: 0.02 explore_hidden_objs: 0.02 - seen_pokemon: 4.0 - caught_pokemon: 4.0 - level: 1.0 policies: multi_convolutional.MultiConvolutionalPolicy: policy: hidden_size: 512 - output_size: 512 recurrent: # Assumed to be in the same module as the policy diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index ae56df3..fa609b8 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -17,6 +17,8 @@ import pufferlib from pokemonred_puffer.global_map import GLOBAL_MAP_SHAPE, local_to_global +PIXEL_VALUES = np.array([0, 85, 153, 255], dtype=np.uint8) + EVENT_FLAGS_START = 0xD747 EVENTS_FLAGS_LENGTH = 320 MUSEUM_TICKET = (0xD754, 0) @@ -150,18 +152,24 @@ def __init__(self, env_config: pufferlib.namespace): self.max_steps = env_config.max_steps self.save_video = env_config.save_video self.fast_video = env_config.fast_video - self.frame_stacks = env_config.frame_stacks self.perfect_ivs = env_config.perfect_ivs self.reduce_res = env_config.reduce_res self.gb_path = env_config.gb_path self.log_frequency = env_config.log_frequency + self.two_bit = env_config.two_bit self.action_space = ACTION_SPACE # Obs space-related. TODO: avoid hardcoding? 
if self.reduce_res: - self.screen_output_shape = (72, 80, 3 * self.frame_stacks) + self.screen_output_shape = (72, 80, 1) else: - self.screen_output_shape = (144, 160, 3 * self.frame_stacks) + self.screen_output_shape = (144, 160, 1) + if self.two_bit: + self.screen_output_shape = ( + self.screen_output_shape[0], + self.screen_output_shape[1] // 4, + 1, + ) self.coords_pad = 12 self.enc_freqs = 8 @@ -188,6 +196,12 @@ def __init__(self, env_config: pufferlib.namespace): "screen": spaces.Box( low=0, high=255, shape=self.screen_output_shape, dtype=np.uint8 ), + "visited_mask": spaces.Box( + low=0, high=255, shape=self.screen_output_shape, dtype=np.uint8 + ), + "global_map": spaces.Box( + low=0, high=255, shape=self.screen_output_shape, dtype=np.uint8 + ), # Discrete is more apt, but pufferlib is slower at processing Discrete "direction": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8), # "reset_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), @@ -247,8 +261,8 @@ def reset(self, seed: Optional[int] = None): # restart game, skipping credits self.explore_map_dim = 384 if self.first: - self.recent_screens = deque() # np.zeros(self.output_shape, dtype=np.uint8) - self.recent_actions = deque() # np.zeros((self.frame_stacks,), dtype=np.uint8) + self.recent_screens = deque() + self.recent_actions = deque() self.seen_pokemon = np.zeros(152, dtype=np.uint8) self.caught_pokemon = np.zeros(152, dtype=np.uint8) self.moves_obtained = np.zeros(0xA5, dtype=np.uint8) @@ -338,9 +352,10 @@ def reset_mem(self): def render(self): # (144, 160, 3) - game_pixels_render = self.screen.ndarray[:, :, 0:1] + game_pixels_render = np.expand_dims(self.screen.ndarray[:, :, 1], axis=-1) if self.reduce_res: game_pixels_render = game_pixels_render[::2, ::2, :] + # place an overlay on top of the screen greying out places we haven't visited # first get our location player_x, player_y, map_n = self.get_game_coords() @@ -388,7 +403,7 @@ def render(self): player_y + y + 1, map_n, ), - 0.15, + 0, ) * 255 ) @@ -422,31 +437,59 @@ def render(self): ).astype(np.uint8) visited_mask = np.expand_dims(visited_mask, -1) """ - # game_pixels_render = np.concatenate([game_pixels_render, visited_mask, cut_mask], axis=-1) - game_pixels_render = np.concatenate([game_pixels_render, visited_mask], axis=-1) - - return game_pixels_render - - def _get_screen_obs(self): - screen = self.render() - screen = np.concatenate( - [ - screen, - np.expand_dims( - 255 * resize(self.explore_map, screen.shape[:-1], anti_aliasing=False), - axis=-1, - ).astype(np.uint8), - ], + + global_map = np.expand_dims( + 255 * resize(self.explore_map, game_pixels_render.shape, anti_aliasing=False), axis=-1, - ) + ).astype(np.uint8) - self.update_recent_screens(screen) - return screen + if self.two_bit: + game_pixels_render = ( + ( + np.digitize( + game_pixels_render.reshape((-1, 4)), PIXEL_VALUES, right=True + ).astype(np.uint8) + << np.array([6, 4, 2, 0], dtype=np.uint8) + ) + .sum(axis=1, dtype=np.uint8) + .reshape((-1, game_pixels_render.shape[1] // 4, 1)) + ) + visited_mask = ( + ( + np.digitize( + visited_mask.reshape((-1, 4)), + np.array([0, 64, 128, 255], dtype=np.uint8), + right=True, + ).astype(np.uint8) + << np.array([6, 4, 2, 0], dtype=np.uint8) + ) + .sum(axis=1, dtype=np.uint8) + .reshape(game_pixels_render.shape) + .astype(np.uint8) + ) + global_map = ( + ( + np.digitize( + global_map.reshape((-1, 4)), + np.array([0, 64, 128, 255], dtype=np.uint8), + right=True, + ).astype(np.uint8) + << np.array([6, 4, 2, 0], dtype=np.uint8) + ) + 
.sum(axis=1, dtype=np.uint8) + .reshape(game_pixels_render.shape) + ) + + return { + "screen": game_pixels_render, + "visited_mask": visited_mask, + "global_map": global_map, + } def _get_obs(self): - player_x, player_y, map_n = self.get_game_coords() + # player_x, player_y, map_n = self.get_game_coords() return { - "screen": self._get_screen_obs(), + **self.render(), "direction": np.array( self.read_m("wSpritePlayerStateData1FacingDirection") // 4, dtype=np.uint8 ), @@ -515,9 +558,6 @@ def run_action_on_emulator(self, action): self.pyboy.send_input(VALID_RELEASE_ACTIONS[action], delay=8) self.pyboy.tick(self.action_freq, render=True) - if self.save_video and self.fast_video: - self.add_video_frame() - def hidden_object_hook(self, *args, **kwargs): hidden_object_id = self.pyboy.memory[self.pyboy.symbol_lookup("wHiddenObjectIndex")[1]] map_id = self.pyboy.memory[self.pyboy.symbol_lookup("wCurMap")[1]] @@ -693,20 +733,6 @@ def get_explore_map(self): return explore_map - def update_recent_screens(self, cur_screen): - # self.recent_screens = np.roll(self.recent_screens, 1, axis=2) - # self.recent_screens[:, :, 0] = cur_screen[:, :, 0] - self.recent_screens.append(cur_screen) - if len(self.recent_screens) > self.frame_stacks: - self.recent_screens.popleft() - - def update_recent_actions(self, action): - # self.recent_actions = np.roll(self.recent_actions, 1) - # self.recent_actions[0] = action - self.recent_actions.append(action) - if len(self.recent_actions) > self.frame_stacks: - self.recent_actions.popleft() - def update_reward(self): # compute reward self.progress_reward = self.get_game_state_reward() diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index fdc2e75..f3cc06b 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -4,6 +4,8 @@ import pufferlib.models from pufferlib.emulation import unpack_batched_obs +from pokemonred_puffer.environment import PIXEL_VALUES + unpack_batched_obs = torch.compiler.disable(unpack_batched_obs) # Because torch.nn.functional.one_hot cannot be traced by torch as of 2.2.0 @@ -26,7 +28,6 @@ def __init__( self, env, hidden_size=512, - output_size=512, channels_last: bool = True, downsample: int = 1, ): @@ -52,24 +53,73 @@ def __init__( self.actor = nn.LazyLinear(self.num_actions) self.value_fn = nn.LazyLinear(1) + self.two_bit = env.unwrapped.env.two_bit + + self.register_buffer( + "screen_buckets", torch.tensor(PIXEL_VALUES, dtype=torch.uint8), persistent=False + ) + self.register_buffer( + "linear_buckets", torch.tensor([0, 64, 128, 255], dtype=torch.uint8), persistent=False + ) + self.register_buffer( + "unpack_mask", + torch.tensor([0xC0, 0x30, 0x0C, 0x03], dtype=torch.uint8), + persistent=False, + ) + self.register_buffer( + "unpack_shift", torch.tensor([6, 4, 2, 0], dtype=torch.uint8), persistent=False + ) + def encode_observations(self, observations): observations = unpack_batched_obs(observations, self.unflatten_context) - output = [] - for okey, network in zip( - ("screen",), - (self.screen_network,), - ): - observation = observations[okey] - if self.channels_last: - observation = observation.permute(0, 3, 1, 2) - if self.downsample > 1: - observation = observation[:, :, :: self.downsample, :: self.downsample] - output.append(network(observation.float() / 255.0).squeeze(1)) + screen = observations["screen"] + visited_mask = observations["visited_mask"] + global_map = observations["global_map"] + restored_shape = 
(screen.shape[0], screen.shape[1], screen.shape[2] * 4, screen.shape[3]) + + if self.two_bit: + screen = torch.index_select( + self.screen_buckets, + 0, + screen.flatten() + .unsqueeze(-1) + .bitwise_and(self.unpack_mask) + .bitwise_right_shift(self.unpack_shift) + .flatten() + .int(), + ).reshape(restored_shape) + visited_mask = torch.index_select( + self.linear_buckets, + 0, + visited_mask.flatten() + .unsqueeze(-1) + .bitwise_and(self.unpack_mask) + .bitwise_right_shift(self.unpack_shift) + .flatten() + .int(), + ).reshape(restored_shape) + global_map = torch.index_select( + self.linear_buckets, + 0, + global_map.flatten() + .unsqueeze(-1) + .bitwise_and(self.unpack_mask) + .bitwise_right_shift(self.unpack_shift) + .flatten() + .int(), + ).reshape(restored_shape) + + image_observation = torch.cat((screen, visited_mask, global_map), dim=-1) + if self.channels_last: + image_observation = image_observation.permute(0, 3, 1, 2) + if self.downsample > 1: + image_observation = image_observation[:, :, :: self.downsample, :: self.downsample] + return self.encode_linear( torch.cat( ( - *output, + (self.screen_network(image_observation.float() / 255.0).squeeze(1)), one_hot(observations["direction"].long(), 4).float().squeeze(1), # one_hot(observations["reset_map_id"].long(), 0xF7).float().squeeze(1), one_hot(observations["battle_type"].long(), 4).float().squeeze(1),
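
Note (not part of the patch): the core of this change is that the environment packs each group of four 2-bit screen pixels into one byte (np.digitize against PIXEL_VALUES, shift by [6, 4, 2, 0], sum), and the policy undoes it on the GPU with the registered unpack_mask/unpack_shift buffers and index_select into the bucket values. The standalone sketch below mirrors that round trip in plain numpy so the bit layout is easy to verify; the helper names pack_2bit/unpack_2bit are illustrative only and do not exist in the repo.

import numpy as np

# The four grey levels the Game Boy screen produces (same as the patch's PIXEL_VALUES).
PIXEL_VALUES = np.array([0, 85, 153, 255], dtype=np.uint8)
SHIFTS = np.array([6, 4, 2, 0], dtype=np.uint8)          # pixel 0 lands in the high bits
MASKS = np.array([0xC0, 0x30, 0x0C, 0x03], dtype=np.uint8)


def pack_2bit(img: np.ndarray) -> np.ndarray:
    """Pack a (H, W, 1) uint8 screen into (H, W // 4, 1) bytes, four pixels per byte."""
    h, w, _ = img.shape
    # Map each grey level to a 2-bit code 0..3, then shift the four codes of a group
    # into disjoint bit positions and sum them into a single byte.
    codes = np.digitize(img.reshape((-1, 4)), PIXEL_VALUES, right=True).astype(np.uint8)
    return (codes << SHIFTS).sum(axis=1, dtype=np.uint8).reshape((h, w // 4, 1))


def unpack_2bit(packed: np.ndarray) -> np.ndarray:
    """Invert pack_2bit: recover the original (H, W, 1) grey levels."""
    h, w4, _ = packed.shape
    # Mask out each 2-bit field, shift it down, then look the code up in PIXEL_VALUES.
    codes = (packed.reshape((-1, 1)) & MASKS) >> SHIFTS   # (N, 4) codes, one per pixel
    return PIXEL_VALUES[codes].reshape((h, w4 * 4, 1))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    screen = PIXEL_VALUES[rng.integers(0, 4, size=(72, 80, 1))]   # reduced-res screen
    assert np.array_equal(unpack_2bit(pack_2bit(screen)), screen)  # lossless round trip
    assert pack_2bit(screen).nbytes == screen.nbytes // 4          # observation is 4x smaller

The policy does the equivalent of unpack_2bit with torch.bitwise_and / bitwise_right_shift and torch.index_select over the screen_buckets (or linear_buckets for visited_mask and global_map), so the environment only ships quarter-width byte planes while the network still sees full-width pixel values.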