Skip to content

Commit

Permalink
Merge pull request #73 from CarperAI/rel-puf04
Browse files Browse the repository at this point in the history
Incorporate puffer 0.4 changes
  • Loading branch information
jsuarez5341 authored Sep 20, 2023
2 parents bb93bf9 + 59fe20d commit eb1048d
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 9 deletions.
7 changes: 3 additions & 4 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@
from pufferlib.frameworks import cleanrl
import pufferlib.policy_ranker
import pufferlib.utils
import clean_pufferl

import environment

from reinforcement_learning import config
from reinforcement_learning import config, clean_pufferl

def setup_policy_store(policy_store_dir):
# CHECK ME: can be custom models with different architectures loaded here?
Expand Down Expand Up @@ -62,7 +61,7 @@ def save_replays(policy_store_dir, save_dir):
from reinforcement_learning import policy # import your policy
def make_policy(envs):
learner_policy = policy.Baseline(
envs._driver_env,
envs.driver_env,
input_size=args.input_size,
hidden_size=args.hidden_size,
task_size=args.task_size
Expand Down Expand Up @@ -172,7 +171,7 @@ def rank_policies(policy_store_dir, eval_curriculum_file, device):
from reinforcement_learning import policy # import your policy
def make_policy(envs):
learner_policy = policy.Baseline(
envs,
envs.driver_env,
input_size=args.input_size,
hidden_size=args.hidden_size,
task_size=args.task_size
Expand Down
2 changes: 1 addition & 1 deletion reinforcement_learning/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def critic(self, hidden):

class Baseline(pufferlib.models.Policy):
def __init__(self, env, input_size=256, hidden_size=256, task_size=4096):
super().__init__()
super().__init__(env)

self.flat_observation_space = env.flat_observation_space
self.flat_observation_structure = env.flat_observation_structure
Expand Down
6 changes: 2 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@
from pufferlib.vectorization import Serial, Multiprocessing
from pufferlib.policy_store import DirectoryPolicyStore
from pufferlib.frameworks import cleanrl
import clean_pufferl

import environment

from reinforcement_learning import policy
from reinforcement_learning import config
from reinforcement_learning import clean_pufferl, policy, config

# NOTE: this file changes when running curriculum generation track
# Run test_task_encoder.py to regenerate this file (or get it from the repo)
Expand All @@ -31,7 +29,7 @@ def setup_env(args):

def make_policy(envs):
learner_policy = policy.Baseline(
envs._driver_env,
envs.driver_env,
input_size=args.input_size,
hidden_size=args.hidden_size,
task_size=args.task_size
Expand Down

0 comments on commit eb1048d

Please sign in to comment.