diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index 9c8f8d1..c055a03 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -1,4 +1,5 @@ import argparse +import ast from functools import partial import heapq import math @@ -252,7 +253,7 @@ def evaluate(self): for k, v in pufferlib.utils.unroll_nested_dict(i): if "state/" in k: _, key = k.split("/") - self.states[key].append(v) + self.states[ast.literal_eval(key)].append(v) elif "required_events_count" == k: for count, eid in zip( self.infos["required_events_count"], self.infos["env_id"] @@ -284,11 +285,11 @@ def evaluate(self): ): # collect the top swarm_keep_pct % of the envs in the batch largest = [ - x[1][0] + x[1][1] for x in heapq.nlargest( math.ceil(len(self.event_tracker) * self.config.swarm_keep_pct), enumerate(self.event_tracker.items()), - key=lambda x: x[1][1], + key=lambda x: x[1][0], ) ] waiting_for = [] @@ -301,13 +302,12 @@ def evaluate(self): for key in to_migrate_keys: # we store states in a weird format # pull a list of states corresponding to a required event completion state - new_state_key = random.choice(list(self.states)) + new_state_key = random.choice(list(self.states.keys())) # pull a state within that list new_state = random.choice(self.states[new_state_key]) - # TODO: Fill in more information about the new state - print(f"\tOld events count: {len(key)}") - print(f"\tOld events: {key}") - print(f"\tNew events count: {len(new_state_key)}") + + print(f"Environment ID: {key}") + print(f"\tEvents count: {self.event_tracker[key]} -> {len(new_state_key)}") print(f"\tNew events: {new_state_key}") self.env_recv_queues[key].put(new_state) waiting_for.append(key)