
Commit

load linref and normalizer on cuda device when resuming from checkpoint (#813)

* load linref and normalizer on cuda device when resuming from checkpoint

* move to() outside one level
misko authored Aug 16, 2024
1 parent ef2a4bc commit 7877671
Showing 1 changed file with 24 additions and 15 deletions.
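In short, the change makes the trainer move the normalizers and linear element references it rebuilds from a checkpoint onto the same device as the rest of the model, instead of leaving them on CPU. A minimal, self-contained sketch of the pattern, using a hypothetical Normalizer module and checkpoint layout rather than the fairchem API:

import torch

class Normalizer(torch.nn.Module):
    """Toy stand-in for a dataset normalizer (hypothetical, for illustration)."""
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("mean", torch.zeros(1))
        self.register_buffer("std", torch.ones(1))

def load_normalizer(checkpoint_path: str, cpu: bool = False) -> Normalizer:
    # Resolve the target device once, before any state is restored, so every
    # module rebuilt from the checkpoint can be moved onto it.
    use_cpu = cpu or not torch.cuda.is_available()
    map_location = torch.device("cpu") if use_cpu else torch.device("cuda")
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    normalizer = Normalizer()
    normalizer.load_state_dict(checkpoint["normalizer"])
    # Without this, the restored buffers stay on CPU while the model runs on
    # CUDA, and later forward passes mix devices.
    return normalizer.to(map_location)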
39 changes: 24 additions & 15 deletions src/fairchem/core/trainers/base_trainer.py
@@ -160,7 +160,11 @@ def __init__(
                 "%j", self.config["slurm"]["job_id"]
             )
             if distutils.is_master():
-                add_timestamp_id_to_submission_pickle(self.config["slurm"]["folder"], self.config["slurm"]["job_id"], self.timestamp_id)
+                add_timestamp_id_to_submission_pickle(
+                    self.config["slurm"]["folder"],
+                    self.config["slurm"]["job_id"],
+                    self.timestamp_id,
+                )
 
         # Define datasets
         if isinstance(dataset, list):
@@ -426,19 +430,23 @@ def load_references_and_normalizers(self):
                 elementref_config,
                 dataset=self.train_dataset,
                 seed=self.config["cmd"]["seed"],
-                checkpoint_dir=self.config["cmd"]["checkpoint_dir"]
-                if not self.is_debug
-                else None,
+                checkpoint_dir=(
+                    self.config["cmd"]["checkpoint_dir"]
+                    if not self.is_debug
+                    else None
+                ),
             )
 
         if norms_config is not None:
             normalizers = load_normalizers_from_config(
                 norms_config,
                 dataset=self.train_dataset,
                 seed=self.config["cmd"]["seed"],
-                checkpoint_dir=self.config["cmd"]["checkpoint_dir"]
-                if not self.is_debug
-                else None,
+                checkpoint_dir=(
+                    self.config["cmd"]["checkpoint_dir"]
+                    if not self.is_debug
+                    else None
+                ),
                 element_references=elementrefs,
             )
 
@@ -487,15 +495,15 @@ def load_task(self):
                         ][target_name].get("level", "system")
                     if "train_on_free_atoms" not in self.output_targets[subtarget]:
                         self.output_targets[subtarget]["train_on_free_atoms"] = (
-                            self.config[
-                                "outputs"
-                            ][target_name].get("train_on_free_atoms", True)
+                            self.config["outputs"][target_name].get(
+                                "train_on_free_atoms", True
+                            )
                         )
                     if "eval_on_free_atoms" not in self.output_targets[subtarget]:
                         self.output_targets[subtarget]["eval_on_free_atoms"] = (
-                            self.config[
-                                "outputs"
-                            ][target_name].get("eval_on_free_atoms", True)
+                            self.config["outputs"][target_name].get(
+                                "eval_on_free_atoms", True
+                            )
                         )
 
         # TODO: Assert that all targets, loss fn, metrics defined are consistent
@@ -551,13 +559,13 @@ def _unwrapped_model(self):
     def load_checkpoint(
         self, checkpoint_path: str, checkpoint: dict | None = None
     ) -> None:
+        map_location = torch.device("cpu") if self.cpu else self.device
         if checkpoint is None:
             if not os.path.isfile(checkpoint_path):
                 raise FileNotFoundError(
                     errno.ENOENT, "Checkpoint file not found", checkpoint_path
                 )
             logging.info(f"Loading checkpoint from: {checkpoint_path}")
-            map_location = torch.device("cpu") if self.cpu else self.device
             checkpoint = torch.load(checkpoint_path, map_location=map_location)
 
         self.epoch = checkpoint.get("epoch", 0)
@@ -600,13 +608,14 @@ def load_checkpoint(
                 mkeys = self.normalizers[target_key].load_state_dict(
                     checkpoint["normalizers"][key]
                 )
+                self.normalizers[target_key].to(map_location)
                 assert len(mkeys.missing_keys) == 0
                 assert len(mkeys.unexpected_keys) == 0
 
         for key, state_dict in checkpoint.get("elementrefs", {}).items():
             elementrefs = LinearReferences(
                 max_num_elements=len(state_dict["element_references"]) - 1
-            )
+            ).to(map_location)
             mkeys = elementrefs.load_state_dict(state_dict)
             self.elementrefs[key] = elementrefs
             assert len(mkeys.missing_keys) == 0
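
Note on the change above: map_location is now computed before the "if checkpoint is None:" guard because the new .to(map_location) calls on the normalizers and element references run even when a pre-loaded checkpoint dict is passed in; with the old placement the device would only have been defined when the checkpoint was read from disk.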

