Skip to content

Commit

Permalink
obs gets all events
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jul 29, 2024
1 parent 72b2639 commit b87fb26
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
6 changes: 6 additions & 0 deletions pokemonred_puffer/data/events.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ctypes import c_uint8, LittleEndianStructure, Union
import re

from pyboy import PyBoy

Expand Down Expand Up @@ -2585,6 +2586,11 @@ def get_event(self, event_name: str) -> bool:
return bool(getattr(self.b, event_name))


EVENTS = {
event for event, _ in EventFlagsBits._fields_ if not re.match("EVENT_[0-9,A-F]{3}$", event)
}


REQUIRED_EVENTS = {
"EVENT_FOLLOWED_OAK_INTO_LAB",
"EVENT_PALLET_AFTER_GETTING_POKEBALLS",
Expand Down
9 changes: 4 additions & 5 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pufferlib
from pokemonred_puffer.data.events import (
EVENT_FLAGS_START,
EVENTS,
EVENTS_FLAGS_LENGTH,
MUSEUM_TICKET,
REQUIRED_EVENTS,
Expand Down Expand Up @@ -186,9 +187,7 @@ def __init__(self, env_config: pufferlib.namespace):
"special": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint32),
"moves": spaces.Box(low=0, high=0xA4, shape=(6, 4), dtype=np.uint8),
# Add 3 for rival_3, game corner rocket and saffron guard
"required_events": spaces.Box(
low=0, high=1, shape=(len(REQUIRED_EVENTS) + 3,), dtype=np.uint8
),
"events": spaces.Box(low=0, high=1, shape=(len(EVENTS) + 3,), dtype=np.uint8),
}

if self.use_global_map:
Expand Down Expand Up @@ -570,8 +569,8 @@ def _get_obs(self):
"speed": np.array([self.party[i].Speed for i in range(6)], dtype=np.uint32),
"special": np.array([self.party[i].Special for i in range(6)], dtype=np.uint32),
"moves": np.array([self.party[i].Moves for i in range(6)], dtype=np.uint8),
"required_events": np.array(
[self.events.get_event(event) for event in REQUIRED_EVENTS]
"events": np.array(
[self.events.get_event(event) for event in EVENTS]
+ [
self.read_m("wSSAnne2FCurScript") == 4, # rival 3
self.missables.get_missable("HS_GAME_CORNER_ROCKET"), # game corner rocket
Expand Down
4 changes: 2 additions & 2 deletions pokemonred_puffer/policies/multi_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def __init__(
self.moves_embeddings = nn.Embedding(0xA4, int(0xA4**0.25) + 1, dtype=torch.float32)

# event embeddings
n_events = env.env.observation_space["required_events"].shape[0]
n_events = env.env.observation_space["events"].shape[0]
self.event_embeddings = nn.Embedding(n_events, int(n_events**0.25) + 1, dtype=torch.float32)

def forward(self, observations):
Expand Down Expand Up @@ -205,7 +205,7 @@ def encode_observations(self, observations):
party_latent = self.party_network(party_obs)

event_obs = (
observations["required_events"].float() @ self.event_embeddings.weight
observations["events"].float() @ self.event_embeddings.weight
) / self.event_embeddings.weight.shape[0]
cat_obs = torch.cat(
(
Expand Down

0 comments on commit b87fb26

Please sign in to comment.