diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index 345dee9..02884d6 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -303,7 +303,7 @@ def evaluate(self): # pull a list of states corresponding to a required event completion state new_state = random.choice(list(self.states)) # pull a state within that list - new_state = random.choice(new_state) + new_state = random.choice(self.states[new_state]) # TODO: Fill in more information about the new state print(f"\t {key}") self.env_recv_queues[key].put(new_state) diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 9491555..f510cb0 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -296,6 +296,16 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = self.caught_pokemon = np.zeros(152, dtype=np.uint8) self.moves_obtained = np.zeros(0xA5, dtype=np.uint8) self.pokecenters = np.zeros(252, dtype=np.uint8) + + if self.save_state: + state = io.BytesIO() + self.pyboy.save_state(state) + state.seek(0) + infos |= { + "state": {hash("".join(self.required_events)): state.read()}, + "required_events_count": len(self.required_events), + "env_id": self.env_id, + } # lazy random seed setting # if not seed: # seed = random.randint(0, 4096) @@ -344,14 +354,6 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = self.first = False - if self.save_state: - state = io.BytesIO() - self.pyboy.save_state(state) - state.seek(0) - infos |= { - "state": {hash("".join(self.required_events)): state.read()}, - "required_events_count": len(self.required_events), - } return self._get_obs(), infos def init_mem(self):