From a11bac959732a9bc97b31afcacdd134402d7bb27 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Thu, 4 Apr 2024 22:59:58 +0900 Subject: [PATCH] Fixes to checkpointing --- config.yaml | 1 + pokemonred_puffer/cleanrl_puffer.py | 27 ++++++++++++++++++--------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/config.yaml b/config.yaml index a9a2b69..a5ec122 100644 --- a/config.yaml +++ b/config.yaml @@ -43,6 +43,7 @@ env: reduce_res: True two_bit: True log_frequency: 2000 + load_optimizer_state: False train: seed: 1 diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index 176a22f..95fa5c1 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -1,4 +1,5 @@ import os +import pathlib import random import time from collections import deque @@ -227,16 +228,24 @@ def __init__( # If data_dir is provided, load the resume state resume_state = {} - path = os.path.join(config.data_dir, exp_name) - if os.path.exists(path): - trainer_path = os.path.join(path, "trainer_state.pt") + path = pathlib.Path(config.data_dir) / exp_name + trainer_path = path / "trainer_state.pt" + if trainer_path.exists(): resume_state = torch.load(trainer_path) - model_path = os.path.join(path, resume_state["model_name"]) - self.agent.load_state_dict(torch.load(model_path, map_location=self.device)) - print( - f'Resumed from update {resume_state["update"]} ' - f'with policy {resume_state["model_name"]}' - ) + + model_version = str(resume_state["update"]).zfill(6) + model_filename = f"model_{model_version}_state.pth" + model_path = path / model_filename + if model_path.exists(): + self.agent.load_state_dict(torch.load(model_path, map_location=self.device)) + print( + f'Resumed from update {resume_state["update"]} ' + f'with policy {resume_state["model_name"]}' + ) + else: + print("No checkpoint found. Starting fresh.") + else: + print("No checkpoint found. Starting fresh.") self.global_step = resume_state.get("global_step", 0) self.agent_step = resume_state.get("agent_step", 0)