
Reenable checkpointing
thatguy11325 committed Apr 2, 2024
1 parent 405481c commit 1397b16
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions pokemonred_puffer/cleanrl_puffer.py
@@ -221,22 +221,22 @@ def __init__(
         self.num_agents = self.pool.agents_per_env
         total_agents = self.num_agents * config.num_envs

+        self.agent = pufferlib.emulation.make_object(
+            agent, agent_creator, [self.pool.driver_env], agent_kwargs
+        )
+
         # If data_dir is provided, load the resume state
         resume_state = {}
         path = os.path.join(config.data_dir, exp_name)
         if os.path.exists(path):
             trainer_path = os.path.join(path, "trainer_state.pt")
             resume_state = torch.load(trainer_path)
             model_path = os.path.join(path, resume_state["model_name"])
-            self.agent = torch.load(model_path, map_location=self.device)
+            self.agent.load_state_dict(torch.load(model_path, map_location=self.device))
             print(
                 f'Resumed from update {resume_state["update"]} '
                 f'with policy {resume_state["model_name"]}'
             )
-        else:
-            self.agent = pufferlib.emulation.make_object(
-                agent, agent_creator, [self.pool.driver_env], agent_kwargs
-            )

         self.global_step = resume_state.get("global_step", 0)
         self.agent_step = resume_state.get("agent_step", 0)
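The key change in this hunk is the resume path: the policy is now constructed unconditionally and its weights are restored with load_state_dict, instead of unpickling the whole module with torch.load. Below is a minimal sketch of that pattern outside pufferlib; PolicyNet, build_agent, data_dir, and exp_name are illustrative names, not from the repository.

    import os

    import torch
    import torch.nn as nn


    class PolicyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.body = nn.Linear(16, 4)

        def forward(self, x):
            return self.body(x)


    def build_agent(data_dir: str, exp_name: str, device: str = "cpu") -> PolicyNet:
        # Always construct the module first, then restore weights if a prior run exists.
        agent = PolicyNet().to(device)
        path = os.path.join(data_dir, exp_name)
        if os.path.exists(path):
            resume_state = torch.load(os.path.join(path, "trainer_state.pt"))
            model_path = os.path.join(path, resume_state["model_name"])
            agent.load_state_dict(torch.load(model_path, map_location=device))
            print(f'Resumed from update {resume_state["update"]}')
        return agent

Restoring into a freshly built module also means the checkpoint only has to store tensors, not the pickled class, so it stays loadable after code refactors.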
@@ -251,7 +251,7 @@ def __init__(
         # TODO: Figure out how to compile the optimizer!
         # self.calculate_loss = torch.compile(self.calculate_loss, mode=config.compile_mode)

-        if self.opt_state is not None:
+        if config.load_optimizer_state is True and self.opt_state is not None:
             self.optimizer.load_state_dict(resume_state["optimizer_state_dict"])

         # Create policy pool
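The optimizer restore is now additionally gated on a config flag. A tiny sketch of the same idea follows; load_optimizer_state mirrors the flag in the diff, while the helper name is invented for illustration.

    def maybe_restore_optimizer(optimizer, resume_state: dict, load_optimizer_state: bool) -> None:
        # Only reuse the saved optimizer state when explicitly requested and present.
        if load_optimizer_state and "optimizer_state_dict" in resume_state:
            optimizer.load_state_dict(resume_state["optimizer_state_dict"])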
@@ -662,15 +662,13 @@ def train(self):
     def close(self):
         self.pool.close()

-        """
         if self.wandb is not None:
             artifact_name = f"{self.exp_name}_model"
             artifact = self.wandb.Artifact(artifact_name, type="model")
             model_path = self.save_checkpoint()
             artifact.add_file(model_path)
             self.wandb.run.log_artifact(artifact)
             self.wandb.finish()
-        """

     def save_checkpoint(self):
         if self.config.save_checkpoint is False:
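This hunk re-enables the wandb artifact upload that had been fenced off inside a docstring. A rough standalone sketch of that upload is below, assuming wandb.init has already been called; the helper name and arguments are illustrative.

    import wandb


    def log_model_artifact(exp_name: str, model_path: str) -> None:
        # Package the checkpoint file as a model artifact on the current run.
        artifact = wandb.Artifact(f"{exp_name}_model", type="model")
        artifact.add_file(model_path)
        wandb.run.log_artifact(artifact)
        wandb.finish()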
@@ -680,14 +678,15 @@ def save_checkpoint(self):
         if not os.path.exists(path):
             os.makedirs(path)

-        model_name = f"model_{self.update:06d}.pt"
+        model_name = f"model_{self.update:06d}_state.pth"
         model_path = os.path.join(path, model_name)

         # Already saved
         if os.path.exists(model_path):
             return model_path

-        torch.save(self.agent, model_path)
+        # To handle both uncompiled and compiled self.agent when getting the state_dict()
+        torch.save(getattr(self.agent, "_orig_mod", self.agent).state_dict(), model_path)

         state = {
             "optimizer_state_dict": self.optimizer.state_dict(),
@@ -704,6 +703,9 @@ def save_checkpoint(self):
         torch.save(state, state_path + ".tmp")
         os.rename(state_path + ".tmp", state_path)

+        # Also save a copy
+        torch.save(state, os.path.join(path, f"trainer_state_{self.update:06d}.pt"))
+
         print(f"Model saved to {model_path}")

         return model_path
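The trainer state is written with a write-then-rename pattern, and this commit adds a numbered per-update copy alongside it. A small sketch of both steps, with illustrative paths and helper name:

    import os

    import torch


    def save_trainer_state(path: str, state: dict, update: int) -> None:
        state_path = os.path.join(path, "trainer_state.pt")
        # Write to a temp file and rename so a crash mid-save cannot corrupt
        # the checkpoint that the next run will try to resume from.
        torch.save(state, state_path + ".tmp")
        os.rename(state_path + ".tmp", state_path)
        # Keep a numbered copy so earlier updates stay available for rollback.
        torch.save(state, os.path.join(path, f"trainer_state_{update:06d}.pt"))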
