diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index f7cc984..c071fd0 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -386,6 +386,10 @@ def evaluate(self): ) self.env_recv_queues[i + 1].put(self.infos["learner"]["state"][new_state]) waiting_for.append(i + 1) + # Now copy the hidden state over + # This may be a little slow, but so is this whole process + self.next_lstm_state[0][:, i, :] = self.next_lstm_state[0][:, new_state, :] + self.next_lstm_state[1][:, i, :] = self.next_lstm_state[1][:, new_state, :] for i in waiting_for: self.env_send_queues[i].get() print("State migration complete")