Skip to content

Commit

Permalink
uint16 over float32 again
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jul 1, 2024
1 parent 1790c08 commit 7d52d79
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,16 +172,16 @@ def __init__(self, env_config: pufferlib.namespace):
"saffron_guard": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
# This could be a dict within a sequence, but we'll do it like this and concat later
"species": spaces.Box(low=0, high=0xBE, shape=(6,), dtype=np.uint8),
"hp": spaces.Box(low=0, high=714, shape=(6,), dtype=np.float32),
"hp": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16),
"status": spaces.Box(low=0, high=7, shape=(6,), dtype=np.uint8),
"type1": spaces.Box(low=0, high=0x1A, shape=(6,), dtype=np.uint8),
"type2": spaces.Box(low=0, high=0x1A, shape=(6,), dtype=np.uint8),
"level": spaces.Box(low=0, high=100, shape=(6,), dtype=np.uint8),
"maxHP": spaces.Box(low=0, high=714, shape=(6,), dtype=np.float32),
"attack": spaces.Box(low=0, high=714, shape=(6,), dtype=np.float32),
"defense": spaces.Box(low=0, high=714, shape=(6,), dtype=np.float32),
"speed": spaces.Box(low=0, high=714, shape=(6,), dtype=np.float32),
"special": spaces.Box(low=0, high=714, shape=(6,), dtype=np.float32),
"maxHP": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16),
"attack": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16),
"defense": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16),
"speed": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16),
"special": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16),
"moves": spaces.Box(low=0, high=0xA4, shape=(6, 4), dtype=np.uint8),
} | {
event: spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8)
Expand Down Expand Up @@ -537,16 +537,16 @@ def _get_obs(self):
self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK"), dtype=np.uint8
),
"species": np.array([self.party[i].Species for i in range(6)], dtype=np.uint8),
"hp": np.array([self.party[i].HP for i in range(6)], dtype=np.float32),
"hp": np.array([self.party[i].HP for i in range(6)], dtype=np.uint16),
"status": np.array([self.party[i].Status for i in range(6)], dtype=np.uint8),
"type1": np.array([self.party[i].Type1 for i in range(6)], dtype=np.uint8),
"type2": np.array([self.party[i].Type2 for i in range(6)], dtype=np.uint8),
"level": np.array([self.party[i].Level for i in range(6)], dtype=np.uint8),
"maxHP": np.array([self.party[i].MaxHP for i in range(6)], dtype=np.float32),
"attack": np.array([self.party[i].Attack for i in range(6)], dtype=np.float32),
"defense": np.array([self.party[i].Defense for i in range(6)], dtype=np.float32),
"speed": np.array([self.party[i].Speed for i in range(6)], dtype=np.float32),
"special": np.array([self.party[i].Special for i in range(6)], dtype=np.float32),
"maxHP": np.array([self.party[i].MaxHP for i in range(6)], dtype=np.uint16),
"attack": np.array([self.party[i].Attack for i in range(6)], dtype=np.uint16),
"defense": np.array([self.party[i].Defense for i in range(6)], dtype=np.uint16),
"speed": np.array([self.party[i].Speed for i in range(6)], dtype=np.uint16),
"special": np.array([self.party[i].Special for i in range(6)], dtype=np.uint16),
"moves": np.array([self.party[i].Moves for i in range(6)], dtype=np.uint8),
}
| {event: np.array(self.events.get_event(event)) for event in REQUIRED_EVENTS}
Expand Down

0 comments on commit 7d52d79

Please sign in to comment.