Skip to content

Commit

Permalink
actually filter down the number of envs
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jul 8, 2024
1 parent 49042a4 commit b89c96c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,11 +284,11 @@ def evaluate(self):
):
# collect the top swarm_keep_pct % of the envs in the batch
largest = [
x[0]
x[1][0]
for x in heapq.nlargest(
math.ceil(len(self.event_tracker) * self.config.swarm_keep_pct),
enumerate(self.event_tracker),
key=lambda x: x[1],
enumerate(self.event_tracker.items()),
key=lambda x: x[1][1],
)
]
print("Migrating states:")
Expand Down
4 changes: 2 additions & 2 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] =
self.pyboy.save_state(state)
state.seek(0)
infos |= {
"state": {hash("".join(self.required_events)): state.read()},
"state": {tuple(self.required_events): state.read()},
"required_events_count": len(self.required_events),
"env_id": self.env_id,
}
Expand Down Expand Up @@ -647,7 +647,7 @@ def step(self, action):
state = io.BytesIO()
self.pyboy.save_state(state)
state.seek(0)
info["state"] = {hash("".join(required_events)): state.read()}
info["state"] = {tuple(required_events): state.read()}
info["required_events_count"] = len(required_events)
info["env_id"] = self.env_id
info = info | self.agent_stats(action)
Expand Down

0 comments on commit b89c96c

Please sign in to comment.