diff --git a/evaluate.py b/evaluate.py index 06bc2410..c9dc6171 100644 --- a/evaluate.py +++ b/evaluate.py @@ -62,7 +62,7 @@ def save_replays(policy_store_dir, save_dir): from reinforcement_learning import policy # import your policy def make_policy(envs): learner_policy = policy.Baseline( - envs, + envs._driver_env, input_size=args.input_size, hidden_size=args.hidden_size, task_size=args.task_size