From 245afa621d951817b47de4783f2d5ca5450141b5 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Mon, 5 Aug 2024 12:09:09 -0400 Subject: [PATCH] swarm on required events and items --- pokemonred_puffer/cleanrl_puffer.py | 8 ++++---- pokemonred_puffer/environment.py | 23 +++++++++++++++++++---- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index 284f1b9..3763d0e 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -268,9 +268,9 @@ def evaluate(self): f.write(str(key)) with open(state_dir / f"{hash(v)}.state", "wb") as f: f.write(v) - elif "required_events_count" == k: + elif "required_count" == k: for count, eid in zip( - self.infos["required_events_count"], self.infos["env_id"] + self.infos["required_count"], self.infos["env_id"] ): self.event_tracker[eid] = count self.infos[k].append(v) @@ -294,7 +294,7 @@ def evaluate(self): and hasattr(self.config, "swarm_frequency") and hasattr(self.config, "swarm_keep_pct") # and self.epoch % self.config.swarm_frequency == 0 - and "required_events_count" in self.infos + and "required_count" in self.infos and self.states ): """ @@ -321,7 +321,7 @@ def evaluate(self): """ # V2 implementation - # check if we have a new highest required_events_count with N save states available + # check if we have a new highest required_count with N save states available # If we do, migrate 100% of states to one of the states max_event_count = 0 new_state_key = "" diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 1fa531e..6c071fa 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -299,6 +299,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = self.wd728 = Wd728Flags(self.pyboy) self.party = PartyMons(self.pyboy) self.required_events = self.get_required_events() + self.required_items = self.get_required_items() self.seen_pokemon = np.zeros(152, dtype=np.uint8) self.caught_pokemon = np.zeros(152, dtype=np.uint8) self.moves_obtained = np.zeros(0xA5, dtype=np.uint8) @@ -309,8 +310,12 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = self.pyboy.save_state(state) state.seek(0) infos |= { - "state": {tuple(self.required_events): state.read()}, - "required_events_count": len(self.required_events), + "state": { + tuple( + sorted(list(self.required_events) + list(self.required_items)) + ): state.read() + }, + "required_count": len(self.required_events) + len(self.required_items), "env_id": self.env_id, } # lazy random seed setting @@ -343,6 +348,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = self.max_opponent_level = 0 self.max_event_rew = 0 self.required_events = self.get_required_events() + self.required_items = self.get_required_items() self.max_level_rew = 0 self.max_level_sum = 0 self.last_health = 1 @@ -642,18 +648,21 @@ def step(self, action): info = {} required_events = self.get_required_events() + required_items = self.get_required_items() new_required_events = required_events - self.required_events - if self.save_state and new_required_events: + new_required_items = required_items - self.required_items + if self.save_state and (new_required_events or new_required_items): state = io.BytesIO() self.pyboy.save_state(state) state.seek(0) info["state"] = {tuple(required_events): state.read()} - info["required_events_count"] = len(required_events) + info["required_count"] = len(required_events) + len(required_items) info["env_id"] = self.env_id info = info | self.agent_stats(action) elif self.step_count % self.log_frequency == 0: info = info | self.agent_stats(action) self.required_events = required_events + self.required_items = required_items obs = self._get_obs() @@ -1517,6 +1526,12 @@ def get_required_events(self) -> set[str]: | ({"saffron_guard"} if self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK") else set()) ) + def get_required_items(self) -> set[str]: + _, wNumBagItems = self.pyboy.symbol_lookup("wNumBagItems") + _, wBagItems = self.pyboy.symbol_lookup("wBagItems") + bag_items = self.pyboy.memory[wBagItems : wBagItems + wNumBagItems * 2] + return {Items(item).name for item in bag_items[::2] if Items(item) in REQUIRED_ITEMS} + def get_events_sum(self): # adds up all event flags, exclude museum ticket return max(