diff --git a/evaluate.py b/evaluate.py
index c9dc617..cf777f2 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -21,11 +21,10 @@
 from pufferlib.frameworks import cleanrl
 import pufferlib.policy_ranker
 import pufferlib.utils
-import clean_pufferl
 
 import environment
 
-from reinforcement_learning import config
+from reinforcement_learning import config, clean_pufferl
 
 def setup_policy_store(policy_store_dir):
     # CHECK ME: can be custom models with different architectures loaded here?
@@ -62,7 +61,7 @@ def save_replays(policy_store_dir, save_dir):
     from reinforcement_learning import policy  # import your policy
     def make_policy(envs):
         learner_policy = policy.Baseline(
-            envs._driver_env,
+            envs.driver_env,
             input_size=args.input_size,
             hidden_size=args.hidden_size,
             task_size=args.task_size
@@ -172,7 +171,7 @@ def rank_policies(policy_store_dir, eval_curriculum_file, device):
     from reinforcement_learning import policy  # import your policy
     def make_policy(envs):
         learner_policy = policy.Baseline(
-            envs,
+            envs.driver_env,
             input_size=args.input_size,
             hidden_size=args.hidden_size,
             task_size=args.task_size
diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py
index 3bb4d58..39afde5 100644
--- a/reinforcement_learning/policy.py
+++ b/reinforcement_learning/policy.py
@@ -36,7 +36,7 @@ def critic(self, hidden):
 
 class Baseline(pufferlib.models.Policy):
     def __init__(self, env, input_size=256, hidden_size=256, task_size=4096):
-        super().__init__()
+        super().__init__(env)
 
         self.flat_observation_space = env.flat_observation_space
         self.flat_observation_structure = env.flat_observation_structure
diff --git a/train.py b/train.py
index 08f08c0..fa6da7e 100644
--- a/train.py
+++ b/train.py
@@ -5,12 +5,10 @@
 from pufferlib.vectorization import Serial, Multiprocessing
 from pufferlib.policy_store import DirectoryPolicyStore
 from pufferlib.frameworks import cleanrl
-import clean_pufferl
 
 import environment
 
-from reinforcement_learning import policy
-from reinforcement_learning import config
+from reinforcement_learning import clean_pufferl, policy, config
 
 # NOTE: this file changes when running curriculum generation track
 # Run test_task_encoder.py to regenerate this file (or get it from the repo)
@@ -31,7 +29,7 @@ def setup_env(args):
 
     def make_policy(envs):
         learner_policy = policy.Baseline(
-            envs._driver_env,
+            envs.driver_env,
             input_size=args.input_size,
             hidden_size=args.hidden_size,
             task_size=args.task_size
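
Context for the policy.py hunk: the change suggests that pufferlib.models.Policy requires the env in its constructor, so the bare super().__init__() would fail before any Baseline fields were set. Below is a minimal sketch of that call chain; the Policy base class shown is a simplified stand-in assumed for illustration, not the actual pufferlib source.

# Minimal sketch of the super().__init__(env) fix. The Policy class here is a
# simplified stand-in for pufferlib.models.Policy (assumed to accept and store
# the env); it is not the actual library source.
import torch.nn as nn

class Policy(nn.Module):  # stand-in for pufferlib.models.Policy
    def __init__(self, env):
        super().__init__()
        self.env = env  # framework wrappers later read spaces from the stored env

class Baseline(Policy):
    def __init__(self, env, input_size=256, hidden_size=256, task_size=4096):
        # Before the fix, super().__init__() dropped the required env argument,
        # raising a TypeError as soon as Baseline was instantiated.
        super().__init__(env)
        self.flat_observation_space = env.flat_observation_space
        self.flat_observation_structure = env.flat_observation_structure

The evaluate.py and train.py hunks are the matching call-site fixes: the policy must be built against the driver environment exposed by the vectorization wrapper (envs.driver_env, a public attribute, rather than the private envs._driver_env or the wrapper itself).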