Skip to content

Commit

Permalink
Merge branch 'early-stop'
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Nov 16, 2024
2 parents 02dbe6c + 85d928d commit ae983ed
Show file tree
Hide file tree
Showing 6 changed files with 362 additions and 115 deletions.
8 changes: 8 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ train:
sqlite_wrapper: True
archive_states: True
swarm: True
early_stop:
# event name: minutes. If we dont satisfy each condition
# we early stop
# The defaults have a margin of error
EVENT_BEAT_BROCK: 30
EVENT_BEAT_MISTY: 90
EVENT_GOT_HM01: 180
one_epoch: True

wrappers:
empty:
Expand Down
98 changes: 71 additions & 27 deletions pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from pokemonred_puffer.eval import make_pokemon_red_overlay
from pokemonred_puffer.global_map import GLOBAL_MAP_SHAPE
from pokemonred_puffer.profile import Profile, Utilization
from pokemonred_puffer.wrappers.sqlite import SqliteStateResetWrapper

pyximport.install(setup_args={"include_dirs": np.get_include()})
from pokemonred_puffer.c_gae import compute_gae # type: ignore # noqa: E402
Expand Down Expand Up @@ -142,6 +143,7 @@ class CleanPuffeRL:
states: dict = field(default_factory=lambda: defaultdict(partial(deque, maxlen=1)))
event_tracker: dict = field(default_factory=lambda: {})
max_event_count: int = 0
early_stop: bool = False

def __post_init__(self):
seed_everything(self.config.seed, self.config.torch_deterministic)
Expand Down Expand Up @@ -281,6 +283,37 @@ def evaluate(self):
self.vecenv.send(actions)

with self.profile.eval_misc:
# TODO: use the event infos instead of the states.
# I'm always running with state saving on right now so it's alright
if self.states and self.config.early_stop:
to_delete = []
for event, minutes in self.config.early_stop.items():
if any(event in key for key in self.states.keys()):
to_delete.append(event)
elif (self.profile.uptime > (minutes * 60)) and all(
event not in key for key in self.states.keys()
):
print(
f"Early stopping. In {self.profile.uptime // 60} minutes, "
f"Event {event} was not found in any states within its"
f"{minutes} minutes time limit"
)
self.early_stop = True
break
else:
print(
f"Early stopping check. In {self.profile.uptime // 60} minutes, "
f"Event {event} was not found in any states within its"
f"{minutes} minutes time limit"
)
for event in to_delete:
print(
f"Satisified early stopping constraint for {event} within "
f"{self.config.early_stop[event]} minutes. "
f"Event found n {self.profile.uptime // 60} minutes"
)
del self.config.early_stop[event]

# now for a tricky bit:
# if we have swarm_frequency, we will migrate the bottom
# % of envs in the batch (by required events count)
Expand Down Expand Up @@ -347,36 +380,39 @@ def evaluate(self):
)
]
if self.sqlite_db:
with sqlite3.connect(self.sqlite_db) as conn:
cur = conn.cursor()
cur.executemany(
"""
UPDATE states
SET pyboy_state=:state,
reset=:reset
WHERE env_id=:env_id
""",
tuple(
[
{"state": state, "reset": 1, "env_id": env_id}
for state, env_id in zip(
new_states, self.event_tracker.keys()
)
]
),
)
self.vecenv.async_reset()
key_set = self.event_tracker.keys()
while True:
# We connect each time just in case we block the wrappers
with SqliteStateResetWrapper.DB_LOCK:
with sqlite3.connect(self.sqlite_db) as conn:
cur = conn.cursor()
resets = cur.execute(
cur.executemany(
"""
SELECT reset, env_id
FROM states
UPDATE states
SET pyboy_state=:state,
reset=:reset
WHERE env_id=:env_id
""",
).fetchall()
tuple(
[
{"state": state, "reset": 1, "env_id": env_id}
for state, env_id in zip(
new_states, self.event_tracker.keys()
)
]
),
)
self.vecenv.async_reset()
# drain any INFO
key_set = self.event_tracker.keys()
while True:
# We connect each time just in case we block the wrappers
with SqliteStateResetWrapper.DB_LOCK:
with sqlite3.connect(self.sqlite_db) as conn:
cur = conn.cursor()
resets = cur.execute(
"""
SELECT reset, env_id
FROM states
""",
).fetchall()
if all(not reset for reset, env_id in resets if env_id in key_set):
break
time.sleep(0.5)
Expand Down Expand Up @@ -649,7 +685,15 @@ def calculate_loss(self, pg_loss, entropy_loss, v_loss):
self.optimizer.step()

def done_training(self):
return self.global_step >= self.config.total_timesteps
return (
self.early_stop
or self.global_step >= self.config.total_timesteps
or (
self.config.one_epoch
and self.states
and any("EVENT_BEAT_CHAMPION_RIVAL" in key for key in self.states.keys())
)
)

def __enter__(self):
return self
Expand Down
8 changes: 8 additions & 0 deletions pokemonred_puffer/data/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,14 @@ class MapIds(Enum):
LORELEIS_ROOM = 0xF5
BRUNOS_ROOM = 0xF6
AGATHAS_ROOM = 0xF7
UNUSED_MAP_F8 = 0xF8
UNUSED_MAP_F9 = 0xF9
UNUSED_MAP_FA = 0xFA
UNUSED_MAP_FB = 0xFB
UNUSED_MAP_FC = 0xFC
UNUSED_MAP_FD = 0xFD
UNUSED_MAP_FE = 0xFE
UNUSED_MAP_FF = 0xFF


RESET_MAP_IDS = {
Expand Down
Loading

0 comments on commit ae983ed

Please sign in to comment.