Skip to content

Commit

Permalink
clean up event obs
Browse files Browse the repository at this point in the history
thatguy11325 committed Jul 25, 2024
1 parent a8f8a1f commit 97f65f2
Showing 3 changed files with 26 additions and 33 deletions.
6 changes: 3 additions & 3 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@ debug:
device: cpu
compile: False
compile_mode: default
num_envs: 16
num_envs: 1
envs_per_worker: 1
num_workers: 1
env_batch_size: 32
@@ -37,8 +37,8 @@ debug:
verbose: False
env_pool: False
load_optimizer_state: False
swarm_frequency: 1
swarm_keep_pct: .1
# swarm_frequency: 1
# swarm_keep_pct: .1

env:
headless: True
21 changes: 9 additions & 12 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
@@ -173,9 +173,6 @@ def __init__(self, env_config: pufferlib.namespace):
low=0, high=max(Items._value2member_map_.keys()), shape=(20,), dtype=np.uint8
),
"bag_quantity": spaces.Box(low=0, high=100, shape=(20,), dtype=np.uint8),
"rival_3": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
"game_corner_rocket": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
"saffron_guard": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
# This could be a dict within a sequence, but we'll do it like this and concat later
"species": spaces.Box(low=0, high=0xBE, shape=(6,), dtype=np.uint8),
"hp": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16),
@@ -189,8 +186,9 @@ def __init__(self, env_config: pufferlib.namespace):
"speed": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16),
"special": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16),
"moves": spaces.Box(low=0, high=0xA4, shape=(6, 4), dtype=np.uint8),
# The +3 covers the rival_3, game corner rocket and saffron guard flags appended after REQUIRED_EVENTS
"required_events": spaces.Box(
low=0, high=1, shape=(len(REQUIRED_EVENTS),), dtype=np.uint8
low=0, high=1, shape=(len(REQUIRED_EVENTS) + 3,), dtype=np.uint8
),
}

@@ -559,13 +557,6 @@ def _get_obs(self):
"wJoyIgnore": np.array(self.read_m("wJoyIgnore"), dtype=np.uint8),
"bag_items": bag[::2].copy(),
"bag_quantity": bag[1::2].copy(),
"rival_3": np.array(self.read_m("wSSAnne2FCurScript") == 4, dtype=np.uint8),
"game_corner_rocket": np.array(
self.missables.get_missable("HS_GAME_CORNER_ROCKET"), dtype=np.uint8
),
"saffron_guard": np.array(
self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK"), dtype=np.uint8
),
"species": np.array([self.party[i].Species for i in range(6)], dtype=np.uint8),
"hp": np.array([self.party[i].HP for i in range(6)], dtype=np.uint16),
"status": np.array([self.party[i].Status for i in range(6)], dtype=np.uint8),
@@ -579,7 +570,13 @@ def _get_obs(self):
"special": np.array([self.party[i].Special for i in range(6)], dtype=np.uint16),
"moves": np.array([self.party[i].Moves for i in range(6)], dtype=np.uint8),
"required_events": np.array(
[self.events.get_event(event) for event in REQUIRED_EVENTS], dtype=np.uint8
[self.events.get_event(event) for event in REQUIRED_EVENTS]
+ [
self.read_m("wSSAnne2FCurScript") == 4, # rival 3
self.missables.get_missable("HS_GAME_CORNER_ROCKET"), # game corner rocket
self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK"), # saffron guard
],
dtype=np.uint8,
),
}

32 changes: 14 additions & 18 deletions pokemonred_puffer/policies/multi_convolutional.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,6 @@
import torch
from torch import nn

from pokemonred_puffer.data.events import REQUIRED_EVENTS
from pokemonred_puffer.data.items import Items
from pokemonred_puffer.environment import PIXEL_VALUES

@@ -110,9 +109,8 @@ def __init__(
self.moves_embeddings = nn.Embedding(0xA4, int(0xA4**0.25) + 1, dtype=torch.float32)

# event embeddings
self.event_embeddings = nn.Embedding(
len(REQUIRED_EVENTS), int(len(REQUIRED_EVENTS) ** 0.25) + 1, dtype=torch.float32
)
n_events = env.env.observation_space["required_events"].shape[0]
self.event_embeddings = nn.Embedding(n_events, int(n_events**0.25) + 1, dtype=torch.float32)

def forward(self, observations):
hidden, lookup = self.encode_observations(observations)
@@ -157,13 +155,14 @@ def encode_observations(self, observations):
.int(),
).reshape(restored_global_map_shape)
# badges = self.badge_buffer <= observations["badges"]
map_id = self.map_embeddings(observations["map_id"].long())
blackout_map_id = self.map_embeddings(observations["blackout_map_id"].long())
map_id = self.map_embeddings(observations["map_id"].long()).squeeze(1)
blackout_map_id = self.map_embeddings(observations["blackout_map_id"].long()).squeeze(1)
# The bag quantity can be a value between 1 and 99
# TODO: Should items be positionally encoded? I don't think it matters
items = self.item_embeddings(observations["bag_items"].squeeze(1).long()).float() * (
observations["bag_quantity"].squeeze(1).float().unsqueeze(-1) / 100.0
)
items = (
self.item_embeddings(observations["bag_items"].long())
* (observations["bag_quantity"].float().unsqueeze(-1) / 100.0)
).squeeze(1)

# image_observation = torch.cat((screen, visited_mask, global_map), dim=-1)
image_observation = torch.cat((screen, visited_mask), dim=-1)
@@ -178,9 +177,9 @@ def encode_observations(self, observations):

# party network
species = self.species_embeddings(observations["species"].squeeze(1).long()).float()
status = one_hot(observations["status"].long(), 7).squeeze(1).float()
type1 = self.type_embeddings(observations["type1"].squeeze(1).long()).float()
type2 = self.type_embeddings(observations["type2"].squeeze(1).long()).float()
status = one_hot(observations["status"].long(), 7).float().squeeze(1)
type1 = self.type_embeddings(observations["type1"].long()).squeeze(1)
type2 = self.type_embeddings(observations["type2"].long()).squeeze(1)
moves = (
self.moves_embeddings(observations["moves"].squeeze(1).long())
.float()
@@ -205,9 +204,9 @@ def encode_observations(self, observations):
)
party_latent = self.party_network(party_obs)

event_obs = (observations["required_events"].float() @ self.event_embeddings.weight) / len(
REQUIRED_EVENTS
)
event_obs = (
observations["required_events"].float() @ self.event_embeddings.weight
) / self.event_embeddings.weight.shape[0]
cat_obs = torch.cat(
(
self.screen_network(image_observation.float() / 255.0).squeeze(1),
@@ -224,9 +223,6 @@ def encode_observations(self, observations):
blackout_map_id.squeeze(1),
observations["wJoyIgnore"].float(),
items.flatten(start_dim=1),
observations["rival_3"].float(),
observations["game_corner_rocket"].float(),
observations["saffron_guard"].float(),
party_latent,
event_obs,
),

0 comments on commit 97f65f2

Please sign in to comment.