Skip to content

Commit

Permalink
swarm on required events and items
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Aug 5, 2024
1 parent bffd0fa commit 245afa6
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
8 changes: 4 additions & 4 deletions pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,9 @@ def evaluate(self):
f.write(str(key))
with open(state_dir / f"{hash(v)}.state", "wb") as f:
f.write(v)
elif "required_events_count" == k:
elif "required_count" == k:
for count, eid in zip(
self.infos["required_events_count"], self.infos["env_id"]
self.infos["required_count"], self.infos["env_id"]
):
self.event_tracker[eid] = count
self.infos[k].append(v)
Expand All @@ -294,7 +294,7 @@ def evaluate(self):
and hasattr(self.config, "swarm_frequency")
and hasattr(self.config, "swarm_keep_pct")
# and self.epoch % self.config.swarm_frequency == 0
and "required_events_count" in self.infos
and "required_count" in self.infos
and self.states
):
"""
Expand All @@ -321,7 +321,7 @@ def evaluate(self):
"""

# V2 implementation
# check if we have a new highest required_events_count with N save states available
# check if we have a new highest required_count with N save states available
# If we do, migrate 100% of states to one of the states
max_event_count = 0
new_state_key = ""
Expand Down
23 changes: 19 additions & 4 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] =
self.wd728 = Wd728Flags(self.pyboy)
self.party = PartyMons(self.pyboy)
self.required_events = self.get_required_events()
self.required_items = self.get_required_items()
self.seen_pokemon = np.zeros(152, dtype=np.uint8)
self.caught_pokemon = np.zeros(152, dtype=np.uint8)
self.moves_obtained = np.zeros(0xA5, dtype=np.uint8)
Expand All @@ -309,8 +310,12 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] =
self.pyboy.save_state(state)
state.seek(0)
infos |= {
"state": {tuple(self.required_events): state.read()},
"required_events_count": len(self.required_events),
"state": {
tuple(
sorted(list(self.required_events) + list(self.required_items))
): state.read()
},
"required_count": len(self.required_events) + len(self.required_items),
"env_id": self.env_id,
}
# lazy random seed setting
Expand Down Expand Up @@ -343,6 +348,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] =
self.max_opponent_level = 0
self.max_event_rew = 0
self.required_events = self.get_required_events()
self.required_items = self.get_required_items()
self.max_level_rew = 0
self.max_level_sum = 0
self.last_health = 1
Expand Down Expand Up @@ -642,18 +648,21 @@ def step(self, action):
info = {}

required_events = self.get_required_events()
required_items = self.get_required_items()
new_required_events = required_events - self.required_events
if self.save_state and new_required_events:
new_required_items = required_items - self.required_items
if self.save_state and (new_required_events or new_required_items):
state = io.BytesIO()
self.pyboy.save_state(state)
state.seek(0)
info["state"] = {tuple(required_events): state.read()}
info["required_events_count"] = len(required_events)
info["required_count"] = len(required_events) + len(required_items)
info["env_id"] = self.env_id
info = info | self.agent_stats(action)
elif self.step_count % self.log_frequency == 0:
info = info | self.agent_stats(action)
self.required_events = required_events
self.required_items = required_items

obs = self._get_obs()

Expand Down Expand Up @@ -1517,6 +1526,12 @@ def get_required_events(self) -> set[str]:
| ({"saffron_guard"} if self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK") else set())
)

def get_required_items(self) -> set[str]:
_, wNumBagItems = self.pyboy.symbol_lookup("wNumBagItems")
_, wBagItems = self.pyboy.symbol_lookup("wBagItems")
bag_items = self.pyboy.memory[wBagItems : wBagItems + wNumBagItems * 2]
return {Items(item).name for item in bag_items[::2] if Items(item) in REQUIRED_ITEMS}

def get_events_sum(self):
# adds up all event flags, exclude museum ticket
return max(
Expand Down

0 comments on commit 245afa6

Please sign in to comment.