diff --git a/config.yaml b/config.yaml index fe8373c..6ad611c 100644 --- a/config.yaml +++ b/config.yaml @@ -61,6 +61,7 @@ env: auto_remove_all_nonuseful_items: True auto_pokeflute: True infinite_money: True + use_global_map: False train: diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 60e6f12..3af4b2d 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -103,6 +103,7 @@ def __init__(self, env_config: pufferlib.namespace): self.auto_remove_all_nonuseful_items = env_config.auto_remove_all_nonuseful_items self.auto_pokeflute = env_config.auto_pokeflute self.infinite_money = env_config.infinite_money + self.use_global_map = env_config.use_global_map self.action_space = ACTION_SPACE # Obs space-related. TODO: avoid hardcoding? @@ -139,38 +140,34 @@ def __init__(self, env_config: pufferlib.namespace): v: i for i, v in enumerate([40, 0, 12, 1, 13, 51, 2, 54, 14, 59, 60, 61, 15, 3, 65]) } - self.observation_space = spaces.Dict( - { - "screen": spaces.Box( - low=0, high=255, shape=self.screen_output_shape, dtype=np.uint8 - ), - "visited_mask": spaces.Box( - low=0, high=255, shape=self.screen_output_shape, dtype=np.uint8 - ), - # "global_map": spaces.Box( - # low=0, high=255, shape=self.screen_output_shape, dtype=np.uint8 - # ), - "global_map": spaces.Box( - low=0, high=255, shape=self.global_map_shape, dtype=np.uint8 - ), - # Discrete is more apt, but pufferlib is slower at processing Discrete - "direction": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8), - "blackout_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), - "battle_type": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8), - "cut_event": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), - "cut_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), - # "x": spaces.Box(low=0, high=255, shape=(1,), dtype=np.u`int8), - # "y": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), - "map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), - # "badges": spaces.Box(low=0, high=np.iinfo(np.uint16).max, shape=(1,), dtype=np.uint16), - "badges": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), - "wJoyIgnore": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), - "bag_items": spaces.Box( - low=0, high=max(Items._value2member_map_.keys()), shape=(20,), dtype=np.uint8 - ), - "bag_quantity": spaces.Box(low=0, high=100, shape=(20,), dtype=np.uint8), - } - ) + obs_dict = { + "screen": spaces.Box(low=0, high=255, shape=self.screen_output_shape, dtype=np.uint8), + "visited_mask": spaces.Box( + low=0, high=255, shape=self.screen_output_shape, dtype=np.uint8 + ), + # Discrete is more apt, but pufferlib is slower at processing Discrete + "direction": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8), + "blackout_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), + "battle_type": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8), + "cut_event": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), + "cut_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), + # "x": spaces.Box(low=0, high=255, shape=(1,), dtype=np.u`int8), + # "y": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), + "map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), + # "badges": spaces.Box(low=0, high=np.iinfo(np.uint16).max, shape=(1,), dtype=np.uint16), + "badges": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), + "wJoyIgnore": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), + "bag_items": spaces.Box( + low=0, high=max(Items._value2member_map_.keys()), shape=(20,), dtype=np.uint8 + ), + "bag_quantity": spaces.Box(low=0, high=100, shape=(20,), dtype=np.uint8), + } + + if self.use_global_map: + obs_dict["global_map"] = spaces.Box( + low=0, high=255, shape=self.global_map_shape, dtype=np.uint8 + ) + self.observation_space = spaces.Dict(obs_dict) self.pyboy = PyBoy( env_config.gb_path, @@ -434,10 +431,11 @@ def render(self): axis=-1, ).astype(np.uint8) """ - global_map = np.expand_dims( - 255 * self.explore_map, - axis=-1, - ).astype(np.uint8) + if self.use_global_map: + global_map = np.expand_dims( + 255 * self.explore_map, + axis=-1, + ).astype(np.uint8) if self.two_bit: game_pixels_render = ( @@ -463,25 +461,24 @@ def render(self): .reshape(game_pixels_render.shape) .astype(np.uint8) ) - global_map = ( - ( - np.digitize( - global_map.reshape((-1, 4)), - np.array([0, 64, 128, 255], dtype=np.uint8), - right=True, - ).astype(np.uint8) - << np.array([6, 4, 2, 0], dtype=np.uint8) + if self.use_global_map: + global_map = ( + ( + np.digitize( + global_map.reshape((-1, 4)), + np.array([0, 64, 128, 255], dtype=np.uint8), + right=True, + ).astype(np.uint8) + << np.array([6, 4, 2, 0], dtype=np.uint8) + ) + .sum(axis=1, dtype=np.uint8) + .reshape(self.global_map_shape) ) - .sum(axis=1, dtype=np.uint8) - .reshape(self.global_map_shape) - .astype(np.uint8) - ) return { "screen": game_pixels_render, "visited_mask": visited_mask, - "global_map": global_map, - } + } | ({"global_map": global_map} if self.use_global_map else {}) def _get_obs(self): # player_x, player_y, map_n = self.get_game_coords() diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index 0a170b1..f65ec29 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -64,6 +64,7 @@ def __init__( self.value_fn = nn.LazyLinear(1) self.two_bit = env.unwrapped.env.two_bit + self.use_global_map = env.unwrapped.env.use_global_map self.register_buffer( "screen_buckets", torch.tensor(PIXEL_VALUES, dtype=torch.uint8), persistent=False @@ -95,14 +96,15 @@ def encode_observations(self, observations): screen = observations["screen"] visited_mask = observations["visited_mask"] - global_map = observations["global_map"] restored_shape = (screen.shape[0], screen.shape[1], screen.shape[2] * 4, screen.shape[3]) - restored_global_map_shape = ( - global_map.shape[0], - global_map.shape[1], - global_map.shape[2] * 4, - global_map.shape[3], - ) + if self.use_global_map: + global_map = observations["global_map"] + restored_global_map_shape = ( + global_map.shape[0], + global_map.shape[1], + global_map.shape[2] * 4, + global_map.shape[3], + ) if self.two_bit: screen = torch.index_select( @@ -117,13 +119,14 @@ def encode_observations(self, observations): .flatten() .int(), ).reshape(restored_shape) - global_map = torch.index_select( - self.linear_buckets, - 0, - ((global_map.reshape((-1, 1)) & self.unpack_mask) >> self.unpack_shift) - .flatten() - .int(), - ).reshape(restored_global_map_shape) + if self.use_global_map: + global_map = torch.index_select( + self.linear_buckets, + 0, + ((global_map.reshape((-1, 1)) & self.unpack_mask) >> self.unpack_shift) + .flatten() + .int(), + ).reshape(restored_global_map_shape) badges = self.badge_buffer <= observations["badges"] map_id = self.map_embeddings(observations["map_id"].long()) blackout_map_id = self.map_embeddings(observations["blackout_map_id"].long()) @@ -137,32 +140,39 @@ def encode_observations(self, observations): image_observation = torch.cat((screen, visited_mask), dim=-1) if self.channels_last: image_observation = image_observation.permute(0, 3, 1, 2) - global_map = global_map.permute(0, 3, 1, 2) + if self.use_global_map: + global_map = global_map.permute(0, 3, 1, 2) if self.downsample > 1: image_observation = image_observation[:, :, :: self.downsample, :: self.downsample] - return self.encode_linear( - torch.cat( + cat_obs = torch.cat( + ( + self.screen_network(image_observation.float() / 255.0).squeeze(1), + one_hot(observations["direction"].long(), 4).float().squeeze(1), + # one_hot(observations["reset_map_id"].long(), 0xF7).float().squeeze(1), + one_hot(observations["battle_type"].long(), 4).float().squeeze(1), + observations["cut_event"].float(), + observations["cut_in_party"].float(), + # observations["x"].float(), + # observations["y"].float(), + # one_hot(observations["map_id"].long(), 0xF7).float().squeeze(1), + badges.float().squeeze(1), + map_id.squeeze(1), + blackout_map_id.squeeze(1), + observations["wJoyIgnore"].float(), + items.flatten(start_dim=1), + ), + dim=-1, + ) + if self.use_global_map: + cat_obs = torch.cat( ( - self.screen_network(image_observation.float() / 255.0).squeeze(1), + cat_obs, self.global_map_network(global_map.float() / 255.0).squeeze(1), - one_hot(observations["direction"].long(), 4).float().squeeze(1), - # one_hot(observations["reset_map_id"].long(), 0xF7).float().squeeze(1), - one_hot(observations["battle_type"].long(), 4).float().squeeze(1), - observations["cut_event"].float(), - observations["cut_in_party"].float(), - # observations["x"].float(), - # observations["y"].float(), - # one_hot(observations["map_id"].long(), 0xF7).float().squeeze(1), - badges.float().squeeze(1), - map_id.squeeze(1), - blackout_map_id.squeeze(1), - observations["wJoyIgnore"].float(), - items.flatten(start_dim=1), ), dim=-1, ) - ), None + return self.encode_linear(cat_obs), None def decode_actions(self, flat_hidden, lookup, concat=None): action = self.actor(flat_hidden)