Skip to content

Commit

Permalink
Attempt to fix async queues creation
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jul 5, 2024
1 parent e7688f2 commit 378d2a7
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
4 changes: 2 additions & 2 deletions pokemonred_puffer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ def train(
vec = pufferlib.vector.Ray

# TODO: Remove the +1 once the driver env doesn't permanently increase the env id
env_send_queues = [Queue() for _ in range(args.train.num_envs + 1)]
env_recv_queues = [Queue() for _ in range(args.train.num_envs + 1)]
env_send_queues = [Queue() for _ in range(args.train.num_envs + args.train.num_workers)]
env_recv_queues = [Queue() for _ in range(args.train.num_envs + args.train.num_workers)]

vecenv = pufferlib.vector.make(
env_creator,
Expand Down
1 change: 0 additions & 1 deletion pokemonred_puffer/wrappers/async_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
class AsyncWrapper(gym.Wrapper):
def __init__(self, env: RedGymEnv, send_queues: list[Queue], recv_queues: list[Queue]):
super().__init__(env)
# We need to -1 because the env id is one offset due to puffer's driver env
self.send_queue = send_queues[self.env.unwrapped.env_id]
self.recv_queue = recv_queues[self.env.unwrapped.env_id]
print(f"Initialized queues for {self.env.unwrapped.env_id}")
Expand Down

0 comments on commit 378d2a7

Please sign in to comment.