Merge branch 'release' into rel-hpc
kywch committed Sep 13, 2023
2 parents 0877269 + 1222980 commit f7ba912
Showing 3 changed files with 8 additions and 5 deletions.
evaluate.py (2 changes: 1 addition & 1 deletion)

@@ -62,7 +62,7 @@ def save_replays(policy_store_dir, save_dir):
     from reinforcement_learning import policy # import your policy
     def make_policy(envs):
         learner_policy = policy.Baseline(
-            envs,
+            envs._driver_env,
             input_size=args.input_size,
             hidden_size=args.hidden_size,
             task_size=args.task_size
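The same one-line change is applied in train.py below. A minimal sketch of the calling convention after this commit, assuming envs is the pufferlib vectorized wrapper these scripts build and that it keeps the single underlying NMMO environment on _driver_env (both names are taken from the diff; make_policy here is only an illustration, with the default sizes from Baseline.__init__):

    from reinforcement_learning import policy

    def make_policy(envs):
        # Pass the unwrapped driver env: Baseline now reads
        # flat_observation_space / flat_observation_structure from it
        # (see the policy.py hunks below) rather than holding the wrapper.
        return policy.Baseline(
            envs._driver_env,
            input_size=256,   # defaults from Baseline.__init__ below
            hidden_size=256,
            task_size=4096,
        )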
reinforcement_learning/policy.py (9 changes: 6 additions & 3 deletions)

@@ -35,9 +35,11 @@ def critic(self, hidden):
 
 
 class Baseline(pufferlib.models.Policy):
-    def __init__(self, envs, input_size=256, hidden_size=256, task_size=4096):
+    def __init__(self, env, input_size=256, hidden_size=256, task_size=4096):
         super().__init__()
-        self.envs = envs
+
+        self.flat_observation_space = env.flat_observation_space
+        self.flat_observation_structure = env.flat_observation_structure
 
         self.tile_encoder = TileEncoder(input_size)
         self.player_encoder = PlayerEncoder(input_size, hidden_size)

@@ -50,7 +52,8 @@ def __init__(self, envs, input_size=256, hidden_size=256, task_size=4096):
         self.value_head = torch.nn.Linear(hidden_size, 1)
 
     def encode_observations(self, flat_observations):
-        env_outputs = self.envs.unpack_batched_obs(flat_observations)
+        env_outputs = pufferlib.emulation.unpack_batched_obs(flat_observations,
+            self.flat_observation_space, self.flat_observation_structure)
         tile = self.tile_encoder(env_outputs["Tile"])
         player_embeddings, my_agent = self.player_encoder(
             env_outputs["Entity"], env_outputs["AgentId"][:, 0]
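A minimal sketch of the new unpacking path, using only names that appear in the hunks above; the ObsUnpacker helper itself is hypothetical and exists only to isolate the pattern. Baseline now caches the driver env's flat observation space and structure at construction time and calls the module-level pufferlib.emulation.unpack_batched_obs, instead of keeping a reference to the env and calling a method on it:

    import pufferlib.emulation

    class ObsUnpacker:
        """Hypothetical helper showing the pattern Baseline now follows."""

        def __init__(self, driver_env):
            # Cache what unpack_batched_obs needs instead of holding the env.
            self.flat_observation_space = driver_env.flat_observation_space
            self.flat_observation_structure = driver_env.flat_observation_structure

        def unpack(self, flat_observations):
            # Module-level helper replaces the old self.envs.unpack_batched_obs(...);
            # it rebuilds the nested dict of observation tensors ("Tile", "Entity",
            # "AgentId", ...) that the encoders index into.
            return pufferlib.emulation.unpack_batched_obs(
                flat_observations,
                self.flat_observation_space,
                self.flat_observation_structure,
            )

Holding only these two attributes keeps the policy decoupled from the env object itself, which is consistent with the call sites in evaluate.py and train.py now passing envs._driver_env rather than the vectorized wrapper.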
train.py (2 changes: 1 addition & 1 deletion)

@@ -31,7 +31,7 @@ def setup_env(args):
 
     def make_policy(envs):
         learner_policy = policy.Baseline(
-            envs,
+            envs._driver_env,
            input_size=args.input_size,
            hidden_size=args.hidden_size,
            task_size=args.task_size
