Skip to content

Commit

Permalink
Merge pull request #21 from CarperAI/release
Browse files Browse the repository at this point in the history
ratings
  • Loading branch information
jsuarez5341 authored Oct 11, 2023
2 parents 61ec922 + 26f64fe commit 40e7338
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 28 deletions.
41 changes: 13 additions & 28 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,22 +122,15 @@ def make_policy(envs):
replay_helper.save(replay_file, compress=False)
evaluator.close()

def create_policy_ranker(policy_store_dir, ranker_file="openskill.pickle"):
def create_policy_ranker(policy_store_dir, ranker_file="ranker.pickle", db_file="ranking.sqlite"):
file = os.path.join(policy_store_dir, ranker_file)
if os.path.exists(file):
if os.path.exists(file + ".lock"):
raise ValueError("Policy ranker file is locked. Delete the lock file.")
logging.info("Using policy ranker from %s", file)
policy_ranker = pufferlib.utils.PersistentObject(
file,
pufferlib.policy_ranker.OpenSkillRanker,
)
logging.info("Using existing policy ranker from %s", file)
policy_ranker = pufferlib.policy_ranker.OpenSkillRanker.load_from_file(file)
else:
policy_ranker = pufferlib.utils.PersistentObject(
file,
pufferlib.policy_ranker.OpenSkillRanker,
"anchor",
)
logging.info("Creating a new policy ranker and db under %s", policy_store_dir)
db_file = os.path.join(policy_store_dir, db_file)
policy_ranker = pufferlib.policy_ranker.OpenSkillRanker(db_file, "anchor")
return policy_ranker

class AllPolicySelector(pufferlib.policy_ranker.PolicySelector):
Expand Down Expand Up @@ -198,8 +191,10 @@ def make_policy(envs):
policy_selector=policy_selector,
)

rank_file = os.path.join(policy_store_dir, "ranking.txt")
with open(rank_file, "w") as f:
ranker_file = os.path.join(policy_store_dir, "ranker.pickle")
# This is for quick viewing of the ranks, not for the actual ranking
rank_txt = os.path.join(policy_store_dir, "ranking.txt")
with open(rank_txt, "w") as f:
pass

results = defaultdict(list)
Expand All @@ -214,12 +209,13 @@ def make_policy(envs):
ratings = evaluator.policy_ranker.ratings()
dataframe = pd.DataFrame(
{
("Rating"): [ratings.get(n).mu for n in ratings],
("Rating"): [ratings.get(n).get("mu") for n in ratings],
("Policy"): ratings.keys(),
}
)

with open(rank_file, "a") as f:
ratings = evaluator.policy_ranker.save_to_file(ranker_file)
with open(rank_txt, "a") as f:
f.write(
"\n\n"
+ dataframe.round(2)
Expand All @@ -228,17 +224,6 @@ def make_policy(envs):
+ "\n\n"
)

# Reset the envs and start the new episodes
# NOTE: The below line will probably end the episode in the middle,
# so we won't be able to sample scores from the successful agents.
# Thus, the scores will be biased towards the agents that die early.
# Still, the numbers we get this way is better than frequently
# updating the scores because the openskill ranking only takes the mean.
#evaluator.buffers[0]._async_reset()

# CHECK ME: delete the policy_ranker lock file
Path(evaluator.policy_ranker.lock.lock_file).unlink(missing_ok=True)

evaluator.close()
for pol, res in results.items():
aggregated = {}
Expand Down
Binary file not shown.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ ray==2.6.1
scikit-learn==1.3.0
tensorboard==2.11.2
tiktoken==0.4.0
torch==1.13.1
torchtyping==0.1.4
traitlets==5.9.0
transformers==4.31.0
Expand Down

0 comments on commit 40e7338

Please sign in to comment.