diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index 084ed1b..a4573d4 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -2,8 +2,6 @@ import ast from datetime import datetime from functools import partial -import heapq -import math import os import pathlib import random @@ -144,6 +142,7 @@ class CleanPuffeRL: infos: dict = field(default_factory=lambda: defaultdict(list)) states: dict = field(default_factory=lambda: defaultdict(partial(deque, maxlen=5))) event_tracker: dict = field(default_factory=lambda: {}) + max_event_count: int = 0 def __post_init__(self): seed_everything(self.config.seed, self.config.torch_deterministic) @@ -294,11 +293,14 @@ def evaluate(self): self.config.async_wrapper and hasattr(self.config, "swarm_frequency") and hasattr(self.config, "swarm_keep_pct") - and self.epoch % self.config.swarm_frequency == 0 + # and self.epoch % self.config.swarm_frequency == 0 and "required_events_count" in self.infos and self.states ): - # collect the top swarm_keep_pct % of the envs in the batch + """ + # V1 implementation - + # collect the top swarm_keep_pct % of the envs in the batch + # migrate the envs not in the largest keep pct to one of the top states largest = [ x[1][0] for x in heapq.nlargest( @@ -308,17 +310,29 @@ def evaluate(self): ) ] - # find the envs not in the largest to_migrate_keys = set(self.event_tracker.keys()) - set(largest) print(f"Migrating {len(to_migrate_keys)} states:") - # Need a way not to reset the env id counter for the driver env - # Until then env ids are 1-indexed for key in to_migrate_keys: # we store states in a weird format # pull a list of states corresponding to a required event completion state new_state_key = random.choice(list(self.states.keys())) # pull a state within that list new_state = random.choice(self.states[new_state_key]) + """ + + # V2 implementation + # check if we have a new highest required_events_count with N save states available + # If we do, migrate 100% of states to one of the states + max_event_count, new_state_key = max(self.states.keys()) + max_state: deque = self.states[key] + to_migrate_keys = [] + if max_event_count > self.max_event_count and len(max_state) == max_state.maxlen: + to_migrate_keys = self.event_tracker.keys() + + # Need a way not to reset the env id counter for the driver env + # Until then env ids are 1-indexed + for key in to_migrate_keys: + new_state = random.choice(self.states[new_state_key]) print(f"Environment ID: {key}") print(f"\tEvents count: {self.event_tracker[key]} -> {len(new_state_key)}")