From 1397b161cdc820822d7492056cd899387fa40f74 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Wed, 3 Apr 2024 06:42:06 +0900 Subject: [PATCH] Reenable checkpointing --- pokemonred_puffer/cleanrl_puffer.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index dfd98bc..176a22f 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -221,6 +221,10 @@ def __init__( self.num_agents = self.pool.agents_per_env total_agents = self.num_agents * config.num_envs + self.agent = pufferlib.emulation.make_object( + agent, agent_creator, [self.pool.driver_env], agent_kwargs + ) + # If data_dir is provided, load the resume state resume_state = {} path = os.path.join(config.data_dir, exp_name) @@ -228,15 +232,11 @@ def __init__( trainer_path = os.path.join(path, "trainer_state.pt") resume_state = torch.load(trainer_path) model_path = os.path.join(path, resume_state["model_name"]) - self.agent = torch.load(model_path, map_location=self.device) + self.agent.load_state_dict(torch.load(model_path, map_location=self.device)) print( f'Resumed from update {resume_state["update"]} ' f'with policy {resume_state["model_name"]}' ) - else: - self.agent = pufferlib.emulation.make_object( - agent, agent_creator, [self.pool.driver_env], agent_kwargs - ) self.global_step = resume_state.get("global_step", 0) self.agent_step = resume_state.get("agent_step", 0) @@ -251,7 +251,7 @@ def __init__( # TODO: Figure out how to compile the optimizer! # self.calculate_loss = torch.compile(self.calculate_loss, mode=config.compile_mode) - if self.opt_state is not None: + if config.load_optimizer_state is True and self.opt_state is not None: self.optimizer.load_state_dict(resume_state["optimizer_state_dict"]) # Create policy pool @@ -662,7 +662,6 @@ def train(self): def close(self): self.pool.close() - """ if self.wandb is not None: artifact_name = f"{self.exp_name}_model" artifact = self.wandb.Artifact(artifact_name, type="model") @@ -670,7 +669,6 @@ def close(self): artifact.add_file(model_path) self.wandb.run.log_artifact(artifact) self.wandb.finish() - """ def save_checkpoint(self): if self.config.save_checkpoint is False: @@ -680,14 +678,15 @@ def save_checkpoint(self): if not os.path.exists(path): os.makedirs(path) - model_name = f"model_{self.update:06d}.pt" + model_name = f"model_{self.update:06d}_state.pth" model_path = os.path.join(path, model_name) # Already saved if os.path.exists(model_path): return model_path - torch.save(self.agent, model_path) + # To handleboth uncompiled and compiled self.agent, when getting state_dict() + torch.save(getattr(self.agent, "_orig_mod", self.agent).state_dict(), model_path) state = { "optimizer_state_dict": self.optimizer.state_dict(), @@ -704,6 +703,9 @@ def save_checkpoint(self): torch.save(state, state_path + ".tmp") os.rename(state_path + ".tmp", state_path) + # Also save a copy + torch.save(state, os.path.join(path, f"trainer_state_{self.update:06d}.pt")) + print(f"Model saved to {model_path}") return model_path