From 4888c584b59e5f7ca9a44182d27f0d459feba5b7 Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Tue, 12 Sep 2023 21:32:03 +0000 Subject: [PATCH 1/2] Update policy for new pufferlib save --- reinforcement_learning/policy.py | 9 ++++++--- train.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py index 49314ec..3bb4d58 100644 --- a/reinforcement_learning/policy.py +++ b/reinforcement_learning/policy.py @@ -35,9 +35,11 @@ def critic(self, hidden): class Baseline(pufferlib.models.Policy): - def __init__(self, envs, input_size=256, hidden_size=256, task_size=4096): + def __init__(self, env, input_size=256, hidden_size=256, task_size=4096): super().__init__() - self.envs = envs + + self.flat_observation_space = env.flat_observation_space + self.flat_observation_structure = env.flat_observation_structure self.tile_encoder = TileEncoder(input_size) self.player_encoder = PlayerEncoder(input_size, hidden_size) @@ -50,7 +52,8 @@ def __init__(self, envs, input_size=256, hidden_size=256, task_size=4096): self.value_head = torch.nn.Linear(hidden_size, 1) def encode_observations(self, flat_observations): - env_outputs = self.envs.unpack_batched_obs(flat_observations) + env_outputs = pufferlib.emulation.unpack_batched_obs(flat_observations, + self.flat_observation_space, self.flat_observation_structure) tile = self.tile_encoder(env_outputs["Tile"]) player_embeddings, my_agent = self.player_encoder( env_outputs["Entity"], env_outputs["AgentId"][:, 0] diff --git a/train.py b/train.py index cc99df7..08f08c0 100644 --- a/train.py +++ b/train.py @@ -31,7 +31,7 @@ def setup_env(args): def make_policy(envs): learner_policy = policy.Baseline( - envs, + envs._driver_env, input_size=args.input_size, hidden_size=args.hidden_size, task_size=args.task_size From 1222980a8e6648465601b7c2ab61b1486e5a6328 Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Tue, 12 Sep 2023 23:12:49 +0000 Subject: [PATCH 2/2] Update evaluate for new policy --- evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evaluate.py b/evaluate.py index 06bc241..c9dc617 100644 --- a/evaluate.py +++ b/evaluate.py @@ -62,7 +62,7 @@ def save_replays(policy_store_dir, save_dir): from reinforcement_learning import policy # import your policy def make_policy(envs): learner_policy = policy.Baseline( - envs, + envs._driver_env, input_size=args.input_size, hidden_size=args.hidden_size, task_size=args.task_size