From 5673afaf37c2e92922a71faef85a70a3e43f23c6 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Mon, 24 Jun 2024 12:51:48 -0400 Subject: [PATCH] More event obs --- config.yaml | 2 +- pokemonred_puffer/environment.py | 39 ++++++++++--------- .../policies/multi_convolutional.py | 12 +++--- pokemonred_puffer/rewards/baseline.py | 8 ++-- 4 files changed, 32 insertions(+), 29 deletions(-) diff --git a/config.yaml b/config.yaml index 0deb0e4..35583b5 100644 --- a/config.yaml +++ b/config.yaml @@ -227,7 +227,7 @@ rewards: bag_menu: 0.1 rocket_hideout_found: 5.0 explore_hidden_objs: 0.02 - seen_action_bag_menu: 0.1G + seen_action_bag_menu: 0.1 baseline.CutWithObjectRewardRequiredEventsEnv: reward: diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index ce344c3..2dcd7d3 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -155,18 +155,19 @@ def __init__(self, env_config: pufferlib.namespace): "direction": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8), "blackout_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), "battle_type": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8), - "cut_event": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), "cut_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), # "x": spaces.Box(low=0, high=255, shape=(1,), dtype=np.u`int8), # "y": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), "map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), # "badges": spaces.Box(low=0, high=np.iinfo(np.uint16).max, shape=(1,), dtype=np.uint16), - "badges": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), "wJoyIgnore": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), "bag_items": spaces.Box( low=0, high=max(Items._value2member_map_.keys()), shape=(20,), dtype=np.uint8 ), "bag_quantity": spaces.Box(low=0, high=100, shape=(20,), dtype=np.uint8), + } | { + event: spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8) + for event in REQUIRED_EVENTS } if self.use_global_map: @@ -498,22 +499,24 @@ def _get_obs(self): # item ids start at 1 so using 0 as the nothing value is okay bag[2 * numBagItems :] = 0 - return self.render() | { - "direction": np.array( - self.read_m("wSpritePlayerStateData1FacingDirection") // 4, dtype=np.uint8 - ), - "blackout_map_id": np.array(self.read_m("wLastBlackoutMap"), dtype=np.uint8), - "battle_type": np.array(self.read_m("wIsInBattle") + 1, dtype=np.uint8), - "cut_event": np.array(self.events.get_event("EVENT_GOT_HM01"), dtype=np.uint8), - "cut_in_party": np.array(self.check_if_party_has_hm(0xF), dtype=np.uint8), - # "x": np.array(player_x, dtype=np.uint8), - # "y": np.array(player_y, dtype=np.uint8), - "map_id": np.array(self.read_m(0xD35E), dtype=np.uint8), - "badges": np.array(self.read_short("wObtainedBadges").bit_count(), dtype=np.uint8), - "wJoyIgnore": np.array(self.read_m("wJoyIgnore"), dtype=np.uint8), - "bag_items": bag[::2].copy(), - "bag_quantity": bag[1::2].copy(), - } + return ( + self.render() + | { + "direction": np.array( + self.read_m("wSpritePlayerStateData1FacingDirection") // 4, dtype=np.uint8 + ), + "blackout_map_id": np.array(self.read_m("wLastBlackoutMap"), dtype=np.uint8), + "battle_type": np.array(self.read_m("wIsInBattle") + 1, dtype=np.uint8), + "cut_in_party": np.array(self.check_if_party_has_hm(0xF), dtype=np.uint8), + # "x": np.array(player_x, dtype=np.uint8), + # "y": np.array(player_y, dtype=np.uint8), + "map_id": np.array(self.read_m(0xD35E), dtype=np.uint8), + "wJoyIgnore": np.array(self.read_m("wJoyIgnore"), dtype=np.uint8), + "bag_items": bag[::2].copy(), + "bag_quantity": bag[1::2].copy(), + } + | {event: np.array(self.events.get_event(event)) for event in REQUIRED_EVENTS} + ) def set_perfect_iv_dvs(self): party_size = self.read_m("wPartyCount") diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index 37cfdc2..706c9b9 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -4,6 +4,7 @@ import torch from torch import nn +from pokemonred_puffer.data.events import REQUIRED_EVENTS from pokemonred_puffer.data.items import Items from pokemonred_puffer.environment import PIXEL_VALUES @@ -91,7 +92,7 @@ def __init__( self.register_buffer( "unpack_shift", torch.tensor([6, 4, 2, 0], dtype=torch.uint8), persistent=False ) - self.register_buffer("badge_buffer", torch.arange(8) + 1, persistent=False) + # self.register_buffer("badge_buffer", torch.arange(8) + 1, persistent=False) # pokemon has 0xF7 map ids # Lets start with 4 dims for now. Could try 8 @@ -144,7 +145,7 @@ def encode_observations(self, observations): .flatten() .int(), ).reshape(restored_global_map_shape) - badges = self.badge_buffer <= observations["badges"] + # badges = self.badge_buffer <= observations["badges"] map_id = self.map_embeddings(observations["map_id"].long()) blackout_map_id = self.map_embeddings(observations["blackout_map_id"].long()) # The bag quantity can be a value between 1 and 99 @@ -170,17 +171,18 @@ def encode_observations(self, observations): one_hot(observations["direction"].long(), 4).float().squeeze(1), # one_hot(observations["reset_map_id"].long(), 0xF7).float().squeeze(1), one_hot(observations["battle_type"].long(), 4).float().squeeze(1), - observations["cut_event"].float(), + # observations["cut_event"].float(), observations["cut_in_party"].float(), # observations["x"].float(), # observations["y"].float(), # one_hot(observations["map_id"].long(), 0xF7).float().squeeze(1), - badges.float().squeeze(1), + # badges.float().squeeze(1), map_id.squeeze(1), blackout_map_id.squeeze(1), observations["wJoyIgnore"].float(), items.flatten(start_dim=1), - ), + ) + + tuple(observations[event].float() for event in REQUIRED_EVENTS), dim=-1, ) if self.use_global_map: diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py index ecb79eb..bc979ed 100644 --- a/pokemonred_puffer/rewards/baseline.py +++ b/pokemonred_puffer/rewards/baseline.py @@ -260,14 +260,12 @@ def get_game_state_reward(self): "rival3": self.reward_config["event"] * int(self.read_m("wSSAnne2FCurScript") == 4), } | { - event: self.events.get_event(event) * self.reward_config["required_event"] + event: self.reward_config["required_event"] * float(self.events.get_event(event)) for event in REQUIRED_EVENTS } | { - "required_items": { - item.name: int(item.value in bag_item_ids) * self.reward_config["required_item"] - for item in REQUIRED_ITEMS - }, + item.name: self.reward_config["required_item"] * float(item.value in bag_item_ids) + for item in REQUIRED_ITEMS } )