From 32961944b016206398e06ba7849908aa6b37668b Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Fri, 12 Jul 2024 09:33:36 -0400 Subject: [PATCH] archive states if requested --- config.yaml | 1 + pokemonred_puffer/cleanrl_puffer.py | 24 +++++++++++++++++++----- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/config.yaml b/config.yaml index 99842f8..78b3784 100644 --- a/config.yaml +++ b/config.yaml @@ -118,6 +118,7 @@ train: load_optimizer_state: False use_rnn: True async_wrapper: True + archive_states: True # swarm_frequency: 500 # swarm_keep_pct: .8 diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index 1ddbc71..084ed1b 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -1,9 +1,11 @@ import argparse import ast +from datetime import datetime from functools import partial import heapq import math import os +import pathlib import random import time from collections import defaultdict, deque @@ -201,6 +203,10 @@ def __post_init__(self): self.taught_cut = False self.log = False + if self.config.archive_states: + self.archive_path = pathlib.Path(datetime.now().strftime("%Y%m%d-%H%M%S")) + self.archive_path.mkdir(exist_ok=False) + @pufferlib.utils.profile def evaluate(self): # states are managed separately so dont worry about deleting them @@ -253,7 +259,16 @@ def evaluate(self): for k, v in pufferlib.utils.unroll_nested_dict(i): if "state/" in k: _, key = k.split("/") - self.states[ast.literal_eval(key)].append(v) + key: tuple[str] = ast.literal_eval(key) + self.states[key].append(v) + if self.config.archive_states: + state_dir = self.archive_path / str(hash(key)) + if not state_dir.exists(): + state_dir.mkdir(exist_ok=True) + with open(state_dir / "desc.txt", "w") as f: + f.write(str(key)) + with open(state_dir / f"{hash(v)}.state", "wb") as f: + f.write(v) elif "required_events_count" == k: for count, eid in zip( self.infos["required_events_count"], self.infos["env_id"] @@ -292,7 +307,6 @@ def evaluate(self): key=lambda x: x[1][0], ) ] - waiting_for = [] # find the envs not in the largest to_migrate_keys = set(self.event_tracker.keys()) - set(largest) @@ -310,13 +324,13 @@ def evaluate(self): print(f"\tEvents count: {self.event_tracker[key]} -> {len(new_state_key)}") print(f"\tNew events: {new_state_key}") self.env_recv_queues[key].put(new_state) - waiting_for.append(key) # Now copy the hidden state over # This may be a little slow, but so is this whole process # self.next_lstm_state[0][:, i, :] = self.next_lstm_state[0][:, new_state, :] # self.next_lstm_state[1][:, i, :] = self.next_lstm_state[1][:, new_state, :] - for i in waiting_for: - self.env_send_queues[i].get() + for key in to_migrate_keys: + print(f"\tWaiting for message from env-id {key}") + self.env_send_queues[key].get() print("State migration complete") self.stats = {}