
Commit

added task-conditioned replay, tweaked req, pufferl
kywch committed Nov 2, 2023
1 parent 15b1fb1 commit 083a1c7
Showing 3 changed files with 54 additions and 10 deletions.
56 changes: 52 additions & 4 deletions evaluate.py
@@ -9,11 +9,13 @@
from dataclasses import asdict
from itertools import cycle

import dill
import numpy as np
import torch
import pandas as pd

from nmmo.render.replay_helper import FileReplayHelper
from nmmo.task.task_spec import make_task_from_spec

import pufferlib
from pufferlib.vectorization import Serial, Multiprocessing
@@ -36,9 +38,10 @@ def setup_policy_store(policy_store_dir):
policy_store = DirectoryPolicyStore(policy_store_dir)
return policy_store

def save_replays(policy_store_dir, save_dir):
def save_replays(policy_store_dir, save_dir, curriculum_file, task_to_assign=None):
# load the checkpoints into the policy store
policy_store = setup_policy_store(policy_store_dir)
policy_ranker = create_policy_ranker(policy_store_dir)
num_policies = len(policy_store._all_policies())

# setup the replay path
@@ -56,6 +59,7 @@ def save_replays(policy_store_dir, save_dir):
args.selfplay_num_policies = num_policies + 1
args.early_stop_agent_num = 0 # run the full episode
args.resilient_population = 0 # no resilient agents
args.tasks_path = curriculum_file # task-conditioning

# NOTE: This creates a dummy learner agent. Is it necessary?
from reinforcement_learning import policy # import your policy
@@ -81,6 +85,7 @@ def make_policy(envs):
selfplay_learner_weight=args.learner_weight,
selfplay_num_policies=args.selfplay_num_policies,
policy_store=policy_store,
policy_ranker=policy_ranker, # so that a new ranker is created
data_dir=save_dir,
)

@@ -97,9 +102,23 @@ def make_policy(envs):
replay_helper = FileReplayHelper()
nmmo_env = evaluator.buffers[0].envs[0].envs[0].env
nmmo_env.realm.record_replay(replay_helper)
replay_helper.reset()

if task_to_assign is not None:
with open(curriculum_file, 'rb') as f:
task_with_embedding = dill.load(f) # a list of TaskSpec
assert 0 <= task_to_assign < len(task_with_embedding), "Task index out of range"
select_task = task_with_embedding[task_to_assign]

# Assign the task to the env
tasks = make_task_from_spec(nmmo_env.possible_agents,
[select_task] * len(nmmo_env.possible_agents))
#nmmo_env.reset(make_task_fn=lambda: tasks)
nmmo_env.tasks = tasks # this is a hack
print("seed:", args.seed,
", task:", nmmo_env.tasks[0].spec_name)

# Run an episode to generate the replay
replay_helper.reset()
while True:
with torch.no_grad():
actions, logprob, value, _ = evaluator.policy_pool.forwards(
@@ -112,10 +131,28 @@ def make_policy(envs):
o, r, d, i = evaluator.buffers[0].recv()

num_alive = len(nmmo_env.realm.players)
print('Tick:', nmmo_env.realm.tick, ", alive agents:", num_alive)
task_done = sum(1 for task in nmmo_env.tasks if task.completed)
alive_done = sum(1 for task in nmmo_env.tasks
if task.completed and task.assignee[0] in nmmo_env.realm.players)
print("Tick:", nmmo_env.realm.tick, ", alive agents:", num_alive, ", task done:", task_done)
if num_alive == alive_done:
print("All alive agents completed the task.")
break
if num_alive == 0 or nmmo_env.realm.tick == args.max_episode_length:
print("All agents died or reached the max episode length.")
break

# Count how many agents completed the task
print("--------------------------------------------------")
print("Task:", nmmo_env.tasks[0].spec_name)
num_completed = sum(1 for task in nmmo_env.tasks if task.completed)
print("Number of agents completed the task:", num_completed)
avg_progress = np.mean([task.progress_info["max_progress"] for task in nmmo_env.tasks])
print(f"Average maximum progress (max=1): {avg_progress:.3f}")
avg_completed_tick = np.mean([task.progress_info["completed_tick"]
for task in nmmo_env.tasks if task.completed])
print(f"Average completed tick: {avg_completed_tick:.1f}")

# Save the replay file
replay_file = os.path.join(save_dir, f"replay_{time.strftime('%Y%m%d_%H%M%S')}")
logging.info("Saving replay to %s", replay_file)
@@ -243,6 +280,8 @@ def make_policy(envs):
-s, --replay-save-dir: Directory to save replays (Default: replays/)
-r, --replay-mode: Replay save mode (Default: False)
-d, --device: Device to use for evaluation/ranking (Default: cuda if available, otherwise cpu)
-t, --task-file: Task file to use for evaluation (Default: reinforcement_learning/eval_task_with_embedding.pkl)
-i, --task-index: The index of the task to assign in the curriculum file (Default: None)
To generate replays from your checkpoints, put them together in policy_store_dir and run the following command;
the replays will be saved under replays/. The script uses only one environment.
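An illustrative invocation, as a sketch only: the -p flag for the policy-store directory and the switch form of -r are assumptions not visible in this excerpt, and the task file and index are placeholders taken from the defaults above.

    python evaluate.py -p <policy_store_dir> -r -t reinforcement_learning/eval_task_with_embedding.pkl -i 0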
@@ -297,14 +336,23 @@ def make_policy(envs):
default="reinforcement_learning/eval_task_with_embedding.pkl",
help="Task file to use for evaluation",
)
parser.add_argument(
"-i",
"--task-index",
dest="task_index",
type=int,
default=None,
help="The index of the task to assign in the curriculum file",
)

# Parse and check the arguments
eval_args = parser.parse_args()
assert eval_args.policy_store_dir is not None, "Policy store directory must be specified"

if getattr(eval_args, "replay_mode", False):
logging.info("Generating replays from the checkpoints in %s", eval_args.policy_store_dir)
save_replays(eval_args.policy_store_dir, eval_args.replay_save_dir)
save_replays(eval_args.policy_store_dir, eval_args.replay_save_dir,
eval_args.task_file, eval_args.task_index)
else:
logging.info("Ranking checkpoints from %s", eval_args.policy_store_dir)
logging.info("Replays will NOT be generated")
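The task-conditioning branch added above assumes the curriculum file is a dill-pickled list of nmmo.task.task_spec.TaskSpec objects and that --task-index selects one of them. A minimal sketch of inspecting such a file under that assumption (the path is the default from the help text; everything beyond dill.load is illustrative):

    import dill

    CURRICULUM_FILE = "reinforcement_learning/eval_task_with_embedding.pkl"

    # Assumption: the file holds a dill-pickled list of TaskSpec, as the
    # loading code in save_replays() implies.
    with open(CURRICULUM_FILE, "rb") as f:
        task_specs = dill.load(f)

    print(f"{len(task_specs)} task specs in the curriculum")
    # Any index in [0, len(task_specs)) is a valid value for --task-index.
    print("example spec:", task_specs[0])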
6 changes: 2 additions & 4 deletions reinforcement_learning/clean_pufferl.py
@@ -144,10 +144,8 @@ def __post_init__(self, *args, **kwargs):
# Create policy ranker
if self.policy_ranker is None:
if self.data_dir is not None:
self.policy_ranker = pufferlib.policy_ranker.OpenSkillRanker(
os.path.join(self.data_dir, "openskill.pickle"),
"anchor",
)
db_file = os.path.join(self.data_dir, "ranking.sqlite")
self.policy_ranker = pufferlib.policy_ranker.OpenSkillRanker(db_file, "anchor")
if "learner" not in self.policy_ranker.ratings():
self.policy_ranker.add_policy("learner")

2 changes: 0 additions & 2 deletions requirements.txt
@@ -5,12 +5,10 @@ openelm
pandas==2.0.3
plotly==5.15.0
psutil==5.9.3
ray==2.6.1
scikit-learn==1.3.0
tensorboard==2.11.2
tiktoken==0.4.0
torch==1.13.1
torchtyping==0.1.4
traitlets==5.9.0
transformers==4.31.0
wandb==0.13.7
