From f8c8cf6c467b4669c8d3d575eee36385e60c0403 Mon Sep 17 00:00:00 2001
From: kywch
Date: Sun, 10 Sep 2023 17:29:14 -0700
Subject: [PATCH] fixed replay to work with new serialized puffer

---
 evaluate.py | 31 ++++++++++++++++---------------
 1 file changed, 16 insertions(+), 15 deletions(-)

diff --git a/evaluate.py b/evaluate.py
index 2b7bb8e..06bc241 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -58,7 +58,7 @@ def save_replays(policy_store_dir, save_dir):
     args.early_stop_agent_num = 0  # run the full episode
     args.resilient_population = 0  # no resilient agents
 
-    # TODO: custom models will require different policy creation functions
+    # NOTE: This creates a dummy learner agent. Is it necessary?
     from reinforcement_learning import policy  # import your policy
     def make_policy(envs):
         learner_policy = policy.Baseline(
@@ -87,8 +87,10 @@ def make_policy(envs):
 
     # Load the policies into the policy pool
     evaluator.policy_pool.update_policies({
-        p.name: p.policy(make_policy, evaluator.buffers[0], evaluator.device)
-        for p in policy_store._all_policies().values()
+        p.name: p.policy(
+            policy_args=[evaluator.buffers[0]],
+            device=evaluator.device
+        ) for p in list(policy_store._all_policies().values())
     })
 
     # Set up the replay helper
@@ -166,7 +168,7 @@ def rank_policies(policy_store_dir, eval_curriculum_file, device):
     args.resilient_population = 0  # no resilient agents
     args.tasks_path = eval_curriculum_file  # task-conditioning
 
-    # TODO: custom models will require different policy creation functions
+    # NOTE: This creates a dummy learner agent. Is it necessary?
     from reinforcement_learning import policy  # import your policy
     def make_policy(envs):
         learner_policy = policy.Baseline(
@@ -255,7 +257,7 @@ def make_policy(envs):
 
     -p, --policy-store-dir: Directory to load policy checkpoints from for evaluation/ranking
     -s, --replay-save-dir: Directory to save replays (Default: replays/)
-    -e, --eval-mode: Evaluate mode (Default: False)
+    -r, --replay-mode: Replay save mode (Default: False)
    -d, --device: Device to use for evaluation/ranking (Default: cuda if available, otherwise cpu)
 
     To generate replay from your checkpoints, put them together in policy_store_dir, run the following command,
@@ -289,12 +291,11 @@ def make_policy(envs):
         help="Directory to save replays (Default: replays/)",
     )
     parser.add_argument(
-        "-e",
-        "--eval-mode",
-        dest="eval_mode",
-        type=bool,
-        default=True,
-        help="Evaluate mode (Default: True). To generate replay, set to False",
+        "-r",
+        "--replay-mode",
+        dest="replay_mode",
+        action="store_true",
+        help="Replay mode (Default: False)",
     )
     parser.add_argument(
         "-d",
@@ -317,10 +318,10 @@ def make_policy(envs):
     eval_args = parser.parse_args()
     assert eval_args.policy_store_dir is not None, "Policy store directory must be specified"
 
-    if eval_args.eval_mode:
+    if getattr(eval_args, "replay_mode", False):
+        logging.info("Generating replays from the checkpoints in %s", eval_args.policy_store_dir)
+        save_replays(eval_args.policy_store_dir, eval_args.replay_save_dir)
+    else:
         logging.info("Ranking checkpoints from %s", eval_args.policy_store_dir)
         logging.info("Replays will NOT be generated")
         rank_policies(eval_args.policy_store_dir, eval_args.task_file, eval_args.device)
-    else:
-        logging.info("Generating replays from the checkpoints in %s", eval_args.policy_store_dir)
-        save_replays(eval_args.policy_store_dir, eval_args.replay_save_dir)
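
A note on the argument-parser change in the patch, not part of the patch itself: argparse's type=bool converts the raw command-line string with bool(), and any non-empty string (including "False") is truthy, so the old --eval-mode option could never actually be turned off from the command line. Switching to action="store_true" is the idiomatic fix. The minimal standalone sketch below illustrates that behavior; it only reuses the flag names from the patch and is not code from evaluate.py.

    import argparse

    # Minimal sketch (assumption: standalone illustration, not part of the patch):
    # why action="store_true" replaces type=bool for a boolean flag.
    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--replay-mode", dest="replay_mode", action="store_true",
                        help="Replay mode (Default: False)")

    print(parser.parse_args([]).replay_mode)      # False when the flag is omitted
    print(parser.parse_args(["-r"]).replay_mode)  # True when the flag is present

    # The old pattern: bool() on any non-empty string is True,
    # so passing "--eval-mode False" still evaluated to True.
    print(bool("False"))  # True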