From ff618f95b219f634980855a7fbaa9bb3c39e3e6a Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Wed, 24 Apr 2024 21:23:24 -0400 Subject: [PATCH] copy hidden state --- pokemonred_puffer/cleanrl_puffer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index f7cc984..c071fd0 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -386,6 +386,10 @@ def evaluate(self): ) self.env_recv_queues[i + 1].put(self.infos["learner"]["state"][new_state]) waiting_for.append(i + 1) + # Now copy the hidden state over + # This may be a little slow, but so is this whole process + self.next_lstm_state[0][:, i, :] = self.next_lstm_state[0][:, new_state, :] + self.next_lstm_state[1][:, i, :] = self.next_lstm_state[1][:, new_state, :] for i in waiting_for: self.env_send_queues[i].get() print("State migration complete")