Skip to content

Commit

Permalink
fixed replay to work with new serialized puffer
Browse files Browse the repository at this point in the history
  • Loading branch information
kywch committed Sep 11, 2023
1 parent 6e9c8a4 commit f8c8cf6
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def save_replays(policy_store_dir, save_dir):
args.early_stop_agent_num = 0 # run the full episode
args.resilient_population = 0 # no resilient agents

# TODO: custom models will require different policy creation functions
# NOTE: This creates a dummy learner agent. Is it necessary?
from reinforcement_learning import policy # import your policy
def make_policy(envs):
learner_policy = policy.Baseline(
Expand Down Expand Up @@ -87,8 +87,10 @@ def make_policy(envs):

# Load the policies into the policy pool
evaluator.policy_pool.update_policies({
p.name: p.policy(make_policy, evaluator.buffers[0], evaluator.device)
for p in policy_store._all_policies().values()
p.name: p.policy(
policy_args=[evaluator.buffers[0]],
device=evaluator.device
) for p in list(policy_store._all_policies().values())
})

# Set up the replay helper
Expand Down Expand Up @@ -166,7 +168,7 @@ def rank_policies(policy_store_dir, eval_curriculum_file, device):
args.resilient_population = 0 # no resilient agents
args.tasks_path = eval_curriculum_file # task-conditioning

# TODO: custom models will require different policy creation functions
# NOTE: This creates a dummy learner agent. Is it necessary?
from reinforcement_learning import policy # import your policy
def make_policy(envs):
learner_policy = policy.Baseline(
Expand Down Expand Up @@ -255,7 +257,7 @@ def make_policy(envs):
-p, --policy-store-dir: Directory to load policy checkpoints from for evaluation/ranking
-s, --replay-save-dir: Directory to save replays (Default: replays/)
-e, --eval-mode: Evaluate mode (Default: False)
-r, --replay-mode: Replay save mode (Default: False)
-d, --device: Device to use for evaluation/ranking (Default: cuda if available, otherwise cpu)
To generate replay from your checkpoints, put them together in policy_store_dir, run the following command,
Expand Down Expand Up @@ -289,12 +291,11 @@ def make_policy(envs):
help="Directory to save replays (Default: replays/)",
)
parser.add_argument(
"-e",
"--eval-mode",
dest="eval_mode",
type=bool,
default=True,
help="Evaluate mode (Default: True). To generate replay, set to False",
"-r",
"--replay-mode",
dest="replay_mode",
action="store_true",
help="Replay mode (Default: False)",
)
parser.add_argument(
"-d",
Expand All @@ -317,10 +318,10 @@ def make_policy(envs):
eval_args = parser.parse_args()
assert eval_args.policy_store_dir is not None, "Policy store directory must be specified"

if eval_args.eval_mode:
if getattr(eval_args, "replay_mode", False):
logging.info("Generating replays from the checkpoints in %s", eval_args.policy_store_dir)
save_replays(eval_args.policy_store_dir, eval_args.replay_save_dir)
else:
logging.info("Ranking checkpoints from %s", eval_args.policy_store_dir)
logging.info("Replays will NOT be generated")
rank_policies(eval_args.policy_store_dir, eval_args.task_file, eval_args.device)
else:
logging.info("Generating replays from the checkpoints in %s", eval_args.policy_store_dir)
save_replays(eval_args.policy_store_dir, eval_args.replay_save_dir)

0 comments on commit f8c8cf6

Please sign in to comment.