Skip to content

Commit

Permalink
events obs is now unpacked on gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Oct 20, 2024
1 parent 8c201c8 commit 96a6825
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 22 deletions.
5 changes: 5 additions & 0 deletions pokemonred_puffer/data/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -2611,6 +2611,11 @@ def set_event(self, event_name: str, value: bool):
# Field names ending in ``EVENT_`` followed by exactly three hex digits are
# excluded from the exported event set (presumably unnamed placeholder flags --
# confirm against the EventFlagsBits definition).
# The pattern is a raw string and the character class is ``[0-9A-F]``: the
# previous ``[0-9,A-F]`` also accepted a literal comma as a "hex digit".
_UNNAMED_EVENT_RE = re.compile(r"EVENT_[0-9A-F]{3}$")

# Positions (in ``EventFlagsBits._fields_`` declaration order) of every field
# that survives the filter above.
EVENTS_IDXS = [
    i
    for i, (event, _, _) in enumerate(EventFlagsBits._fields_)
    if not _UNNAMED_EVENT_RE.search(event)
]

# Names of those same fields. Derived from EVENTS_IDXS so the name set and the
# index list can never disagree about which fields are included.
EVENTS = {EventFlagsBits._fields_[i][0] for i in EVENTS_IDXS}


REQUIRED_EVENTS = {
Expand Down
37 changes: 16 additions & 21 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from pokemonred_puffer.data.elevators import NEXT_ELEVATORS
from pokemonred_puffer.data.events import (
EVENT_FLAGS_START,
EVENTS,
EVENTS_FLAGS_LENGTH,
MUSEUM_TICKET,
REQUIRED_EVENTS,
Expand Down Expand Up @@ -212,7 +211,11 @@ def __init__(self, env_config: DictConfig):
"special": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint32),
"moves": spaces.Box(low=0, high=0xA4, shape=(6, 4), dtype=np.uint8),
# "events" holds the packed event-flag bytes; rival_3, game corner rocket, saffron guard and lapras are separate entries
"events": spaces.Box(low=0, high=1, shape=(len(EVENTS) + 4,), dtype=np.uint8),
"events": spaces.Box(low=0, high=1, shape=(320,), dtype=np.uint8),
"rival_3": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
"game_corner_rocket": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
"saffron_guard": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
"lapras": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
}
if not self.skip_safari_zone:
obs_dict["safari_steps"] = spaces.Box(low=0, high=502.0, shape=(1,), dtype=np.uint32)
Expand Down Expand Up @@ -617,25 +620,17 @@ def _get_obs(self):
"speed": np.array([self.party[i].Speed for i in range(6)], dtype=np.uint32),
"special": np.array([self.party[i].Special for i in range(6)], dtype=np.uint32),
"moves": np.array([self.party[i].Moves for i in range(6)], dtype=np.uint8),
"events": np.concatenate(
(
np.fromiter(self.events.get_events(EVENTS), dtype=np.uint8),
np.array(
[
self.read_m("wSSAnne2FCurScript") == 4, # rival 3
self.missables.get_missable(
"HS_GAME_CORNER_ROCKET"
), # game corner rocket
self.flags.get_bit(
"BIT_GAVE_SAFFRON_GUARDS_DRINK"
), # saffron guard
self.flags.get_bit("BIT_GOT_LAPRAS"), # got lapras
],
dtype=np.uint8,
),
),
dtype=np.uint8,
),
"events": np.array(self.events.asbytes, dtype=np.uint8),
"rival_3": np.array(
self.read_m("wSSAnne2FCurScript") == 4, dtype=np.uint8
), # rival 3
"game_corner_rocket": np.array(
self.missables.get_missable("HS_GAME_CORNER_ROCKET"), np.uint8
), # game corner rocket
"saffron_guard": np.array(
self.flags.get_bit("BIT_GAVE_SAFFRON_GUARDS_DRINK"), np.uint8
), # saffron guard
"lapras": np.array(self.flags.get_bit("BIT_GOT_LAPRAS"), np.uint8), # got lapras
}
| (
{}
Expand Down
30 changes: 29 additions & 1 deletion pokemonred_puffer/policies/multi_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from torch import nn

from pokemonred_puffer.data.events import EVENTS_IDXS
from pokemonred_puffer.data.items import Items
from pokemonred_puffer.environment import PIXEL_VALUES

Expand Down Expand Up @@ -92,6 +93,16 @@ def __init__(
self.register_buffer(
"unpack_shift", torch.tensor([6, 4, 2, 0], dtype=torch.uint8), persistent=False
)
self.register_buffer(
"unpack_bytes_mask",
torch.tensor([0x80, 0x40, 0x20, 0x10, 0x8, 0x4, 0x2, 0x1], dtype=torch.uint8),
persistent=False,
)
self.register_buffer(
"unpack_bytes_shift",
torch.tensor([7, 6, 5, 4, 3, 2, 1, 0], dtype=torch.uint8),
persistent=False,
)
# self.register_buffer("badge_buffer", torch.arange(8) + 1, persistent=False)

# pokemon has 0xF7 map ids
Expand Down Expand Up @@ -209,6 +220,19 @@ def encode_observations(self, observations):
# event_obs = (
# observations["events"].float() @ self.event_embeddings.weight
# ) / self.event_embeddings.weight.shape[0]
events_obs = (
(
(
(observations["events"].reshape((-1, 1)) & self.unpack_bytes_mask)
>> self.unpack_bytes_shift
)
.flatten()
.reshape((observations["events"].shape[0], -1))[:, EVENTS_IDXS]
)
.float()
.squeeze(1)
)

cat_obs = torch.cat(
(
self.screen_network(image_observation.float() / 255.0).squeeze(1),
Expand All @@ -224,7 +248,11 @@ def encode_observations(self, observations):
blackout_map_id.squeeze(1),
items.flatten(start_dim=1),
party_latent,
observations["events"].float().squeeze(1),
events_obs,
observations["rival_3"].float(),
observations["game_corner_rocket"].float(),
observations["saffron_guard"].float(),
observations["lapras"].float(),
),
dim=-1,
)
Expand Down

0 comments on commit 96a6825

Please sign in to comment.