Skip to content

Commit

Permalink
More event obs
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jun 24, 2024
1 parent caa18b5 commit 5673afa
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 29 deletions.
2 changes: 1 addition & 1 deletion config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ rewards:
bag_menu: 0.1
rocket_hideout_found: 5.0
explore_hidden_objs: 0.02
seen_action_bag_menu: 0.1G
seen_action_bag_menu: 0.1

baseline.CutWithObjectRewardRequiredEventsEnv:
reward:
Expand Down
39 changes: 21 additions & 18 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,19 @@ def __init__(self, env_config: pufferlib.namespace):
"direction": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8),
"blackout_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8),
"battle_type": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8),
"cut_event": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
"cut_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
# "x": spaces.Box(low=0, high=255, shape=(1,), dtype=np.u`int8),
# "y": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8),
"map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8),
# "badges": spaces.Box(low=0, high=np.iinfo(np.uint16).max, shape=(1,), dtype=np.uint16),
"badges": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8),
"wJoyIgnore": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
"bag_items": spaces.Box(
low=0, high=max(Items._value2member_map_.keys()), shape=(20,), dtype=np.uint8
),
"bag_quantity": spaces.Box(low=0, high=100, shape=(20,), dtype=np.uint8),
} | {
event: spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8)
for event in REQUIRED_EVENTS
}

if self.use_global_map:
Expand Down Expand Up @@ -498,22 +499,24 @@ def _get_obs(self):
# item ids start at 1 so using 0 as the nothing value is okay
bag[2 * numBagItems :] = 0

return self.render() | {
"direction": np.array(
self.read_m("wSpritePlayerStateData1FacingDirection") // 4, dtype=np.uint8
),
"blackout_map_id": np.array(self.read_m("wLastBlackoutMap"), dtype=np.uint8),
"battle_type": np.array(self.read_m("wIsInBattle") + 1, dtype=np.uint8),
"cut_event": np.array(self.events.get_event("EVENT_GOT_HM01"), dtype=np.uint8),
"cut_in_party": np.array(self.check_if_party_has_hm(0xF), dtype=np.uint8),
# "x": np.array(player_x, dtype=np.uint8),
# "y": np.array(player_y, dtype=np.uint8),
"map_id": np.array(self.read_m(0xD35E), dtype=np.uint8),
"badges": np.array(self.read_short("wObtainedBadges").bit_count(), dtype=np.uint8),
"wJoyIgnore": np.array(self.read_m("wJoyIgnore"), dtype=np.uint8),
"bag_items": bag[::2].copy(),
"bag_quantity": bag[1::2].copy(),
}
return (
self.render()
| {
"direction": np.array(
self.read_m("wSpritePlayerStateData1FacingDirection") // 4, dtype=np.uint8
),
"blackout_map_id": np.array(self.read_m("wLastBlackoutMap"), dtype=np.uint8),
"battle_type": np.array(self.read_m("wIsInBattle") + 1, dtype=np.uint8),
"cut_in_party": np.array(self.check_if_party_has_hm(0xF), dtype=np.uint8),
# "x": np.array(player_x, dtype=np.uint8),
# "y": np.array(player_y, dtype=np.uint8),
"map_id": np.array(self.read_m(0xD35E), dtype=np.uint8),
"wJoyIgnore": np.array(self.read_m("wJoyIgnore"), dtype=np.uint8),
"bag_items": bag[::2].copy(),
"bag_quantity": bag[1::2].copy(),
}
| {event: np.array(self.events.get_event(event)) for event in REQUIRED_EVENTS}
)

def set_perfect_iv_dvs(self):
party_size = self.read_m("wPartyCount")
Expand Down
12 changes: 7 additions & 5 deletions pokemonred_puffer/policies/multi_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from torch import nn

from pokemonred_puffer.data.events import REQUIRED_EVENTS
from pokemonred_puffer.data.items import Items
from pokemonred_puffer.environment import PIXEL_VALUES

Expand Down Expand Up @@ -91,7 +92,7 @@ def __init__(
self.register_buffer(
"unpack_shift", torch.tensor([6, 4, 2, 0], dtype=torch.uint8), persistent=False
)
self.register_buffer("badge_buffer", torch.arange(8) + 1, persistent=False)
# self.register_buffer("badge_buffer", torch.arange(8) + 1, persistent=False)

# pokemon has 0xF7 map ids
# Lets start with 4 dims for now. Could try 8
Expand Down Expand Up @@ -144,7 +145,7 @@ def encode_observations(self, observations):
.flatten()
.int(),
).reshape(restored_global_map_shape)
badges = self.badge_buffer <= observations["badges"]
# badges = self.badge_buffer <= observations["badges"]
map_id = self.map_embeddings(observations["map_id"].long())
blackout_map_id = self.map_embeddings(observations["blackout_map_id"].long())
# The bag quantity can be a value between 1 and 99
Expand All @@ -170,17 +171,18 @@ def encode_observations(self, observations):
one_hot(observations["direction"].long(), 4).float().squeeze(1),
# one_hot(observations["reset_map_id"].long(), 0xF7).float().squeeze(1),
one_hot(observations["battle_type"].long(), 4).float().squeeze(1),
observations["cut_event"].float(),
# observations["cut_event"].float(),
observations["cut_in_party"].float(),
# observations["x"].float(),
# observations["y"].float(),
# one_hot(observations["map_id"].long(), 0xF7).float().squeeze(1),
badges.float().squeeze(1),
# badges.float().squeeze(1),
map_id.squeeze(1),
blackout_map_id.squeeze(1),
observations["wJoyIgnore"].float(),
items.flatten(start_dim=1),
),
)
+ tuple(observations[event].float() for event in REQUIRED_EVENTS),
dim=-1,
)
if self.use_global_map:
Expand Down
8 changes: 3 additions & 5 deletions pokemonred_puffer/rewards/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,14 +260,12 @@ def get_game_state_reward(self):
"rival3": self.reward_config["event"] * int(self.read_m("wSSAnne2FCurScript") == 4),
}
| {
event: self.events.get_event(event) * self.reward_config["required_event"]
event: self.reward_config["required_event"] * float(self.events.get_event(event))
for event in REQUIRED_EVENTS
}
| {
"required_items": {
item.name: int(item.value in bag_item_ids) * self.reward_config["required_item"]
for item in REQUIRED_ITEMS
},
item.name: self.reward_config["required_item"] * float(item.value in bag_item_ids)
for item in REQUIRED_ITEMS
}
)

Expand Down

0 comments on commit 5673afa

Please sign in to comment.