Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jul 31, 2024
1 parent ecf3dad commit d23a097
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
8 changes: 7 additions & 1 deletion pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,17 @@ def evaluate(self):
# V2 implementation
# check if we have a new highest required_events_count with N save states available
# If we do, migrate 100% of states to one of the states
max_event_count, new_state_key = max(self.states.keys())
max_event_count = 0
new_state_key = ""
for key in self.states.keys():
if len(key) > max_event_count:
max_event_count = len(key)
new_state_key = key
max_state: deque = self.states[key]
to_migrate_keys = []
if max_event_count > self.max_event_count and len(max_state) == max_state.maxlen:
to_migrate_keys = self.event_tracker.keys()
self.max_event_count = max_event_count

# Need a way not to reset the env id counter for the driver env
# Until then env ids are 1-indexed
Expand Down
2 changes: 1 addition & 1 deletion pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def run_action_on_emulator(self, action):
self.pyboy.tick(self.action_freq, render=False)
while self.read_m("wJoyIgnore"):
self.pyboy.button("a", delay=8)
self.pyboy.tick(24, render=False)
self.pyboy.tick(self.action_freq, render=False)

if self.events.get_event("EVENT_GOT_HM01"):
if self.auto_teach_cut and not self.check_if_party_has_hm(0x0F):
Expand Down
5 changes: 3 additions & 2 deletions pokemonred_puffer/policies/multi_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,13 @@ def encode_observations(self, observations):
image_observation = image_observation[:, :, :: self.downsample, :: self.downsample]

# party network
species = self.species_embeddings(observations["species"].squeeze(1).int()).float()
species = self.species_embeddings(observations["species"].int()).float().squeeze(1)
status = one_hot(observations["status"].int(), 7).float().squeeze(1)
type1 = self.type_embeddings(observations["type1"].int()).squeeze(1)
type2 = self.type_embeddings(observations["type2"].int()).squeeze(1)
moves = (
self.moves_embeddings(observations["moves"].squeeze(1).int())
self.moves_embeddings(observations["moves"].int())
.squeeze(1)
.float()
.reshape((-1, 6, 4 * self.moves_embeddings.embedding_dim))
)
Expand Down

0 comments on commit d23a097

Please sign in to comment.