From 2f10f1193d694bba48bf5a5613cdf56e68807f92 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Sat, 29 Jun 2024 11:03:10 -0400 Subject: [PATCH] Use float32 instead of uint16 for now --- pokemonred_puffer/environment.py | 24 +++++++++---------- .../policies/multi_convolutional.py | 2 -- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 84cfbb3..055b4d6 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -172,16 +172,16 @@ def __init__(self, env_config: pufferlib.namespace): "saffron_guard": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), # This could be a dict within a sequence, but we'll do it like this and concat later "species": spaces.Box(low=0, high=0xBE, shape=(6,), dtype=np.uint8), - "hp": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16), + "hp": spaces.Box(low=0, high=714, shape=(6,), dtype=np.float32), "status": spaces.Box(low=0, high=7, shape=(6,), dtype=np.uint8), "type1": spaces.Box(low=0, high=0x1A, shape=(6,), dtype=np.uint8), "type2": spaces.Box(low=0, high=0x1A, shape=(6,), dtype=np.uint8), "level": spaces.Box(low=0, high=100, shape=(6,), dtype=np.uint8), - "maxHP": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16), - "attack": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16), - "defense": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16), - "speed": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16), - "special": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16), + "maxHP": spaces.Box(low=0, high=714, shape=(6,), dtype=np.float32), + "attack": spaces.Box(low=0, high=714, shape=(6,), dtype=np.float32), + "defense": spaces.Box(low=0, high=714, shape=(6,), dtype=np.float32), + "speed": spaces.Box(low=0, high=714, shape=(6,), dtype=np.float32), + "special": spaces.Box(low=0, high=714, shape=(6,), dtype=np.float32), "moves": spaces.Box(low=0, high=0xA4, shape=(6, 4), dtype=np.uint8), } | { event: spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8) @@ -537,16 +537,16 @@ def _get_obs(self): self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK"), dtype=np.uint8 ), "species": np.array([self.party[i].Species for i in range(6)], dtype=np.uint8), - "hp": np.array([self.party[i].HP for i in range(6)], dtype=np.uint16), + "hp": np.array([self.party[i].HP for i in range(6)], dtype=np.float32), "status": np.array([self.party[i].Status for i in range(6)], dtype=np.uint8), "type1": np.array([self.party[i].Type1 for i in range(6)], dtype=np.uint8), "type2": np.array([self.party[i].Type2 for i in range(6)], dtype=np.uint8), "level": np.array([self.party[i].Level for i in range(6)], dtype=np.uint8), - "maxHP": np.array([self.party[i].MaxHP for i in range(6)], dtype=np.uint16), - "attack": np.array([self.party[i].Attack for i in range(6)], dtype=np.uint16), - "defense": np.array([self.party[i].Defense for i in range(6)], dtype=np.uint16), - "speed": np.array([self.party[i].Speed for i in range(6)], dtype=np.uint16), - "special": np.array([self.party[i].Special for i in range(6)], dtype=np.uint16), + "maxHP": np.array([self.party[i].MaxHP for i in range(6)], dtype=np.float32), + "attack": np.array([self.party[i].Attack for i in range(6)], dtype=np.float32), + "defense": np.array([self.party[i].Defense for i in range(6)], dtype=np.float32), + "speed": np.array([self.party[i].Speed for i in range(6)], dtype=np.float32), + "special": np.array([self.party[i].Special for i in range(6)], dtype=np.float32), "moves": np.array([self.party[i].Moves for i in range(6)], dtype=np.uint8), } | {event: np.array(self.events.get_event(event)) for event in REQUIRED_EVENTS} diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index 8ea5b97..c38ac63 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -8,8 +8,6 @@ from pokemonred_puffer.data.items import Items from pokemonred_puffer.environment import PIXEL_VALUES -pufferlib.pytorch.nativize_tensor = torch.compiler.disable(pufferlib.pytorch.nativize_tensor) - # Because torch.nn.functional.one_hot cannot be traced by torch as of 2.2.0 def one_hot(tensor, num_classes):